feat: add closure support to C backend

Implement closures/lambdas in the C code generator:

- Add LuxClosure struct (env pointer + function pointer)
- Add ClosureInfo to track closure metadata during generation
- Implement free variable analysis to find captured variables
- Generate environment structs for each lambda's captured vars
- Generate lambda implementation functions
- Support indirect closure calls via function pointer casting
- Track functions returning closures for proper type inference

Example: fn(x) => x + n compiles to a LuxEnv struct holding n,
a lambda_N function taking env + x, and closure allocation code.

Limitations: No memory management (closures leak), types
mostly hardcoded to LuxInt.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-02-14 03:53:38 -05:00
parent 67437b8273
commit 6ec1f3bdbb

View File

@@ -15,6 +15,16 @@ pub struct CGenError {
pub span: Option<Span>,
}
/// Information about a closure to be emitted
#[derive(Debug, Clone)]
struct ClosureInfo {
id: usize,
env_fields: Vec<(String, String)>, // (var_name, c_type)
params: Vec<(String, String)>, // (param_name, c_type)
return_type: String,
body: Expr,
}
impl std::fmt::Display for CGenError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "C codegen error: {}", self.message)
@@ -37,6 +47,12 @@ pub struct CBackend {
name_counter: usize,
/// Effects used in the program (for evidence struct)
effects_used: HashSet<String>,
/// Closures to emit (collected during expression generation)
closures: Vec<ClosureInfo>,
/// Variables in scope (for free variable analysis)
local_vars: HashSet<String>,
/// Functions that return closures
closure_returning_functions: HashSet<String>,
}
impl CBackend {
@@ -48,12 +64,16 @@ impl CBackend {
types_emitted: HashSet::new(),
name_counter: 0,
effects_used: HashSet::new(),
closures: Vec::new(),
local_vars: HashSet::new(),
closure_returning_functions: HashSet::new(),
}
}
/// Generate C code from a Lux program
pub fn generate(&mut self, program: &Program) -> Result<String, CGenError> {
self.output.clear();
self.closures.clear();
self.emit_prelude();
// First pass: collect all function names and types
@@ -61,6 +81,10 @@ impl CBackend {
match decl {
Declaration::Function(f) => {
self.functions.insert(f.name.name.clone());
// Check if this function returns a closure
if matches!(&f.return_type, TypeExpr::Function { .. }) {
self.closure_returning_functions.insert(f.name.name.clone());
}
}
Declaration::Type(t) => {
self.collect_type(t)?;
@@ -72,10 +96,13 @@ impl CBackend {
// Emit type definitions
self.emit_type_definitions(program)?;
// Emit forward declarations
// Emit forward declarations for regular functions
self.emit_forward_declarations(program)?;
// Emit function definitions
// Generate function bodies to a temporary buffer
// This collects closures that need to be emitted
let saved_output = std::mem::take(&mut self.output);
for decl in &program.declarations {
match decl {
Declaration::Function(f) => {
@@ -91,12 +118,177 @@ impl CBackend {
}
}
let function_output = std::mem::replace(&mut self.output, saved_output);
// Now emit closure definitions (collected during function generation)
self.emit_closures()?;
// Append the function definitions
self.output.push_str(&function_output);
// Emit main wrapper if there's a main function or top-level expressions
self.emit_main_wrapper(program)?;
Ok(self.output.clone())
}
/// Emit all collected closure definitions
fn emit_closures(&mut self) -> Result<(), CGenError> {
if self.closures.is_empty() {
return Ok(());
}
self.writeln("// === Closure Definitions ===");
self.writeln("");
// Take closures to avoid borrow issues
let closures = std::mem::take(&mut self.closures);
for closure in &closures {
// Emit environment struct
if closure.env_fields.is_empty() {
// No captured variables - no env struct needed
} else {
self.writeln(&format!("typedef struct {{"));
self.indent += 1;
for (name, typ) in &closure.env_fields {
self.writeln(&format!("{} {};", typ, name));
}
self.indent -= 1;
self.writeln(&format!("}} LuxEnv_{};", closure.id));
self.writeln("");
}
// Emit lambda implementation
let params_str = if closure.params.is_empty() {
"void* _env".to_string()
} else {
let ps: Vec<String> = closure.params.iter()
.map(|(name, typ)| format!("{} {}", typ, name))
.collect();
format!("void* _env, {}", ps.join(", "))
};
self.writeln(&format!("static {} lambda_{}({}) {{",
closure.return_type, closure.id, params_str));
self.indent += 1;
// Cast env pointer if we have captured variables
if !closure.env_fields.is_empty() {
self.writeln(&format!("LuxEnv_{}* env = (LuxEnv_{}*)_env;",
closure.id, closure.id));
}
// Emit body with env-> substitution for captured variables
let body_result = self.emit_closure_body(&closure.body, &closure.env_fields)?;
self.writeln(&format!("return {};", body_result));
self.indent -= 1;
self.writeln("}");
self.writeln("");
}
// Restore closures (in case more are collected)
self.closures = closures;
Ok(())
}
/// Emit closure body with environment variable substitution
fn emit_closure_body(&mut self, expr: &Expr, env_fields: &[(String, String)]) -> Result<String, CGenError> {
// Create a set of captured variable names for quick lookup
let captured: HashSet<&str> = env_fields.iter().map(|(n, _)| n.as_str()).collect();
self.emit_expr_with_env(expr, &captured)
}
/// Emit expression, substituting captured variables with env->name
fn emit_expr_with_env(&mut self, expr: &Expr, captured: &HashSet<&str>) -> Result<String, CGenError> {
match expr {
Expr::Var(ident) => {
if captured.contains(ident.name.as_str()) {
Ok(format!("env->{}", ident.name))
} else if self.functions.contains(&ident.name) {
Ok(self.mangle_name(&ident.name))
} else {
Ok(ident.name.clone())
}
}
Expr::Literal(lit) => self.emit_literal(lit),
Expr::BinaryOp { op, left, right, .. } => {
let l = self.emit_expr_with_env(left, captured)?;
let r = self.emit_expr_with_env(right, captured)?;
let op_str = match op {
BinaryOp::Add => "+",
BinaryOp::Sub => "-",
BinaryOp::Mul => "*",
BinaryOp::Div => "/",
BinaryOp::Mod => "%",
BinaryOp::Eq => "==",
BinaryOp::Ne => "!=",
BinaryOp::Lt => "<",
BinaryOp::Le => "<=",
BinaryOp::Gt => ">",
BinaryOp::Ge => ">=",
BinaryOp::And => "&&",
BinaryOp::Or => "||",
_ => return Err(CGenError {
message: format!("Unsupported binary operator in closure"),
span: None,
}),
};
Ok(format!("({} {} {})", l, op_str, r))
}
Expr::UnaryOp { op, operand, .. } => {
let o = self.emit_expr_with_env(operand, captured)?;
let op_str = match op {
UnaryOp::Neg => "-",
UnaryOp::Not => "!",
};
Ok(format!("({}{})", op_str, o))
}
Expr::If { condition, then_branch, else_branch, .. } => {
let c = self.emit_expr_with_env(condition, captured)?;
let t = self.emit_expr_with_env(then_branch, captured)?;
let e = self.emit_expr_with_env(else_branch, captured)?;
Ok(format!("({} ? {} : {})", c, t, e))
}
Expr::Call { func, args, .. } => {
let arg_strs: Result<Vec<_>, _> = args.iter()
.map(|a| self.emit_expr_with_env(a, captured))
.collect();
let args_str = arg_strs?.join(", ");
match func.as_ref() {
Expr::Var(ident) if self.functions.contains(&ident.name) => {
let c_func_name = self.mangle_name(&ident.name);
Ok(format!("{}({})", c_func_name, args_str))
}
_ => {
let closure_expr = self.emit_expr_with_env(func, captured)?;
let param_types: Vec<&str> = args.iter().map(|_| "LuxInt").collect();
let params_str = if param_types.is_empty() {
String::new()
} else {
format!(", {}", param_types.join(", "))
};
let args_with_env = if args_str.is_empty() {
format!("({})->env", closure_expr)
} else {
format!("({})->env, {}", closure_expr, args_str)
};
Ok(format!("((LuxInt(*)(void*{}))({})->fn_ptr)({})",
params_str, closure_expr, args_with_env))
}
}
}
_ => {
// For other expressions, fall back to regular emit_expr
// This may not handle captured variables correctly for complex expressions
self.emit_expr(expr)
}
}
}
fn emit_prelude(&mut self) {
self.writeln("// Generated by Lux compiler");
self.writeln("// Do not edit - regenerate from .lux source");
@@ -115,6 +307,9 @@ impl CBackend {
self.writeln("typedef char* LuxString;");
self.writeln("typedef void* LuxUnit;");
self.writeln("");
self.writeln("// Closure representation: env pointer + function pointer");
self.writeln("typedef struct { void* env; void* fn_ptr; } LuxClosure;");
self.writeln("");
self.writeln("// === String Operations ===");
self.writeln("");
self.writeln("static LuxString lux_string_concat(LuxString a, LuxString b) {");
@@ -364,25 +559,94 @@ impl CBackend {
}
Expr::Call { func, args, .. } => {
let func_name = match func.as_ref() {
Expr::Var(ident) => ident.name.clone(),
_ => return Err(CGenError {
message: "Only direct function calls supported".to_string(),
span: None,
}),
};
let arg_strs: Result<Vec<_>, _> = args.iter().map(|a| self.emit_expr(a)).collect();
let args_str = arg_strs?.join(", ");
// Mangle user-defined function names
let c_func_name = if self.functions.contains(&func_name) {
self.mangle_name(&func_name)
} else {
func_name
};
match func.as_ref() {
Expr::Var(ident) if self.functions.contains(&ident.name) => {
// Direct call to a known function
let c_func_name = self.mangle_name(&ident.name);
Ok(format!("{}({})", c_func_name, args_str))
}
_ => {
// Indirect call - treat as closure
let closure_expr = self.emit_expr(func)?;
// Build the cast for the function pointer
// For now, assume all args are LuxInt and return LuxInt
let param_types: Vec<&str> = args.iter().map(|_| "LuxInt").collect();
let params_str = if param_types.is_empty() {
String::new()
} else {
format!(", {}", param_types.join(", "))
};
let args_with_env = if args_str.is_empty() {
format!("({})->env", closure_expr)
} else {
format!("({})->env, {}", closure_expr, args_str)
};
Ok(format!("((LuxInt(*)(void*{}))({})->fn_ptr)({})",
params_str, closure_expr, args_with_env))
}
}
}
Ok(format!("{}({})", c_func_name, args_str))
Expr::Lambda { params, body, return_type, .. } => {
// Find free variables in the lambda body
let param_names: HashSet<String> = params.iter()
.map(|p| p.name.name.clone())
.collect();
let free_vars = self.find_free_vars(body, &param_names);
// Generate unique closure ID
let id = self.fresh_name();
// Determine parameter types
let param_pairs: Vec<(String, String)> = params.iter()
.map(|p| {
let typ = self.type_expr_to_c(&p.typ)
.unwrap_or_else(|_| "LuxInt".to_string());
(p.name.name.clone(), typ)
})
.collect();
// Determine return type (default to LuxInt)
let ret_type = return_type.as_ref()
.map(|t| self.type_expr_to_c(t).unwrap_or_else(|_| "LuxInt".to_string()))
.unwrap_or_else(|| "LuxInt".to_string());
// Determine captured variable types (default to LuxInt for now)
let env_fields: Vec<(String, String)> = free_vars.iter()
.map(|v| (v.clone(), "LuxInt".to_string()))
.collect();
// Store closure info for later emission
self.closures.push(ClosureInfo {
id,
env_fields: env_fields.clone(),
params: param_pairs,
return_type: ret_type,
body: (**body).clone(),
});
// Generate code to create the closure at runtime
let temp_env = format!("_env_{}", id);
let temp_closure = format!("_closure_{}", id);
// Allocate and initialize environment struct
if env_fields.is_empty() {
self.writeln(&format!("LuxClosure* {} = malloc(sizeof(LuxClosure));", temp_closure));
self.writeln(&format!("{}->env = NULL;", temp_closure));
} else {
self.writeln(&format!("LuxEnv_{}* {} = malloc(sizeof(LuxEnv_{}));", id, temp_env, id));
for (name, _) in &env_fields {
self.writeln(&format!("{}->{} = {};", temp_env, name, name));
}
self.writeln(&format!("LuxClosure* {} = malloc(sizeof(LuxClosure));", temp_closure));
self.writeln(&format!("{}->env = {};", temp_closure, temp_env));
}
self.writeln(&format!("{}->fn_ptr = (void*)lambda_{};", temp_closure, id));
Ok(temp_closure)
}
Expr::Block { statements, result, .. } => {
@@ -390,7 +654,13 @@ impl CBackend {
match stmt {
Statement::Let { name, value, .. } => {
let val = self.emit_expr(value)?;
self.writeln(&format!("LuxInt {} = {};", name.name, val));
// Infer type from value: closures return LuxClosure*
let typ = if val.starts_with("_closure_") || self.is_closure_returning_call(value) {
"LuxClosure*"
} else {
"LuxInt"
};
self.writeln(&format!("{} {} = {};", typ, name.name, val));
}
Statement::Expr(e) => {
let _ = self.emit_expr(e)?;
@@ -580,7 +850,7 @@ impl CBackend {
}
TypeExpr::Unit => Ok("void".to_string()),
TypeExpr::Versioned { base, .. } => self.type_expr_to_c(base),
TypeExpr::Function { .. } => Ok("void*".to_string()),
TypeExpr::Function { .. } => Ok("LuxClosure*".to_string()),
TypeExpr::Tuple(_) => Ok("void*".to_string()),
TypeExpr::Record(_) => Ok("void*".to_string()),
}
@@ -591,6 +861,126 @@ impl CBackend {
self.name_counter
}
/// Check if an expression is a call to a function that returns a closure
fn is_closure_returning_call(&self, expr: &Expr) -> bool {
match expr {
Expr::Lambda { .. } => true,
Expr::Call { func, .. } => {
// Check if we're calling a function known to return a closure
if let Expr::Var(ident) = func.as_ref() {
self.closure_returning_functions.contains(&ident.name)
} else {
false
}
}
_ => false,
}
}
/// Find free variables in an expression (variables not in bound set)
fn find_free_vars(&self, expr: &Expr, bound: &HashSet<String>) -> Vec<String> {
let mut free = Vec::new();
self.collect_free_vars(expr, bound, &mut free);
// Remove duplicates while preserving order
let mut seen = HashSet::new();
free.retain(|v| seen.insert(v.clone()));
free
}
fn collect_free_vars(&self, expr: &Expr, bound: &HashSet<String>, free: &mut Vec<String>) {
match expr {
Expr::Var(ident) => {
// If not bound locally and not a known function, it's a free variable
if !bound.contains(&ident.name) && !self.functions.contains(&ident.name) {
free.push(ident.name.clone());
}
}
Expr::Literal(_) => {}
Expr::BinaryOp { left, right, .. } => {
self.collect_free_vars(left, bound, free);
self.collect_free_vars(right, bound, free);
}
Expr::UnaryOp { operand, .. } => {
self.collect_free_vars(operand, bound, free);
}
Expr::If { condition, then_branch, else_branch, .. } => {
self.collect_free_vars(condition, bound, free);
self.collect_free_vars(then_branch, bound, free);
self.collect_free_vars(else_branch, bound, free);
}
Expr::Call { func, args, .. } => {
self.collect_free_vars(func, bound, free);
for arg in args {
self.collect_free_vars(arg, bound, free);
}
}
Expr::Block { statements, result, .. } => {
let mut inner_bound = bound.clone();
for stmt in statements {
match stmt {
Statement::Let { name, value, .. } => {
self.collect_free_vars(value, &inner_bound, free);
inner_bound.insert(name.name.clone());
}
Statement::Expr(e) => {
self.collect_free_vars(e, &inner_bound, free);
}
}
}
self.collect_free_vars(result, &inner_bound, free);
}
Expr::Lambda { params, body, .. } => {
let mut inner_bound = bound.clone();
for p in params {
inner_bound.insert(p.name.name.clone());
}
self.collect_free_vars(body, &inner_bound, free);
}
Expr::Record { fields, .. } => {
for (_, val) in fields {
self.collect_free_vars(val, bound, free);
}
}
Expr::Field { object, .. } => {
self.collect_free_vars(object, bound, free);
}
Expr::Match { scrutinee, arms, .. } => {
self.collect_free_vars(scrutinee, bound, free);
for arm in arms {
// TODO: Handle pattern bindings
self.collect_free_vars(&arm.body, bound, free);
}
}
Expr::EffectOp { args, .. } => {
for arg in args {
self.collect_free_vars(arg, bound, free);
}
}
Expr::Run { expr, .. } => {
self.collect_free_vars(expr, bound, free);
}
Expr::Tuple { elements, .. } => {
for e in elements {
self.collect_free_vars(e, bound, free);
}
}
Expr::List { elements, .. } => {
for e in elements {
self.collect_free_vars(e, bound, free);
}
}
Expr::Resume { value, .. } => {
self.collect_free_vars(value, bound, free);
}
Expr::Let { value, body, name, .. } => {
self.collect_free_vars(value, bound, free);
let mut inner_bound = bound.clone();
inner_bound.insert(name.name.clone());
self.collect_free_vars(body, &inner_bound, free);
}
}
}
fn writeln(&mut self, line: &str) {
let indent = " ".repeat(self.indent);
writeln!(self.output, "{}{}", indent, line).unwrap();
@@ -643,4 +1033,40 @@ mod tests {
let c_code = generate(source).unwrap();
assert!(c_code.contains("lux_console_print"));
}
#[test]
fn test_closure_basic() {
let source = r#"
fn makeAdder(n: Int): fn(Int): Int =
fn(x: Int): Int => x + n
"#;
let c_code = generate(source).unwrap();
// Should have closure type
assert!(c_code.contains("LuxClosure"));
// Should have environment struct
assert!(c_code.contains("LuxEnv_"));
// Should have lambda function
assert!(c_code.contains("lambda_"));
// Function should return LuxClosure*
assert!(c_code.contains("LuxClosure* makeAdder_lux"));
}
#[test]
fn test_closure_call() {
let source = r#"
fn makeAdder(n: Int): fn(Int): Int =
fn(x: Int): Int => x + n
fn test(): Int = {
let add5 = makeAdder(5)
add5(10)
}
"#;
let c_code = generate(source).unwrap();
// Local should be typed as LuxClosure*
assert!(c_code.contains("LuxClosure* add5"));
// Should have indirect call syntax
assert!(c_code.contains("->fn_ptr"));
assert!(c_code.contains("->env"));
}
}