diff --git a/src/codegen/c_backend.rs b/src/codegen/c_backend.rs index eaae92b..f84e7d4 100644 --- a/src/codegen/c_backend.rs +++ b/src/codegen/c_backend.rs @@ -15,6 +15,16 @@ pub struct CGenError { pub span: Option, } +/// Information about a closure to be emitted +#[derive(Debug, Clone)] +struct ClosureInfo { + id: usize, + env_fields: Vec<(String, String)>, // (var_name, c_type) + params: Vec<(String, String)>, // (param_name, c_type) + return_type: String, + body: Expr, +} + impl std::fmt::Display for CGenError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "C codegen error: {}", self.message) @@ -37,6 +47,12 @@ pub struct CBackend { name_counter: usize, /// Effects used in the program (for evidence struct) effects_used: HashSet, + /// Closures to emit (collected during expression generation) + closures: Vec, + /// Variables in scope (for free variable analysis) + local_vars: HashSet, + /// Functions that return closures + closure_returning_functions: HashSet, } impl CBackend { @@ -48,12 +64,16 @@ impl CBackend { types_emitted: HashSet::new(), name_counter: 0, effects_used: HashSet::new(), + closures: Vec::new(), + local_vars: HashSet::new(), + closure_returning_functions: HashSet::new(), } } /// Generate C code from a Lux program pub fn generate(&mut self, program: &Program) -> Result { self.output.clear(); + self.closures.clear(); self.emit_prelude(); // First pass: collect all function names and types @@ -61,6 +81,10 @@ impl CBackend { match decl { Declaration::Function(f) => { self.functions.insert(f.name.name.clone()); + // Check if this function returns a closure + if matches!(&f.return_type, TypeExpr::Function { .. }) { + self.closure_returning_functions.insert(f.name.name.clone()); + } } Declaration::Type(t) => { self.collect_type(t)?; @@ -72,10 +96,13 @@ impl CBackend { // Emit type definitions self.emit_type_definitions(program)?; - // Emit forward declarations + // Emit forward declarations for regular functions self.emit_forward_declarations(program)?; - // Emit function definitions + // Generate function bodies to a temporary buffer + // This collects closures that need to be emitted + let saved_output = std::mem::take(&mut self.output); + for decl in &program.declarations { match decl { Declaration::Function(f) => { @@ -91,12 +118,177 @@ impl CBackend { } } + let function_output = std::mem::replace(&mut self.output, saved_output); + + // Now emit closure definitions (collected during function generation) + self.emit_closures()?; + + // Append the function definitions + self.output.push_str(&function_output); + // Emit main wrapper if there's a main function or top-level expressions self.emit_main_wrapper(program)?; Ok(self.output.clone()) } + /// Emit all collected closure definitions + fn emit_closures(&mut self) -> Result<(), CGenError> { + if self.closures.is_empty() { + return Ok(()); + } + + self.writeln("// === Closure Definitions ==="); + self.writeln(""); + + // Take closures to avoid borrow issues + let closures = std::mem::take(&mut self.closures); + + for closure in &closures { + // Emit environment struct + if closure.env_fields.is_empty() { + // No captured variables - no env struct needed + } else { + self.writeln(&format!("typedef struct {{")); + self.indent += 1; + for (name, typ) in &closure.env_fields { + self.writeln(&format!("{} {};", typ, name)); + } + self.indent -= 1; + self.writeln(&format!("}} LuxEnv_{};", closure.id)); + self.writeln(""); + } + + // Emit lambda implementation + let params_str = if closure.params.is_empty() { + "void* _env".to_string() + } else { + let ps: Vec = closure.params.iter() + .map(|(name, typ)| format!("{} {}", typ, name)) + .collect(); + format!("void* _env, {}", ps.join(", ")) + }; + + self.writeln(&format!("static {} lambda_{}({}) {{", + closure.return_type, closure.id, params_str)); + self.indent += 1; + + // Cast env pointer if we have captured variables + if !closure.env_fields.is_empty() { + self.writeln(&format!("LuxEnv_{}* env = (LuxEnv_{}*)_env;", + closure.id, closure.id)); + } + + // Emit body with env-> substitution for captured variables + let body_result = self.emit_closure_body(&closure.body, &closure.env_fields)?; + self.writeln(&format!("return {};", body_result)); + + self.indent -= 1; + self.writeln("}"); + self.writeln(""); + } + + // Restore closures (in case more are collected) + self.closures = closures; + + Ok(()) + } + + /// Emit closure body with environment variable substitution + fn emit_closure_body(&mut self, expr: &Expr, env_fields: &[(String, String)]) -> Result { + // Create a set of captured variable names for quick lookup + let captured: HashSet<&str> = env_fields.iter().map(|(n, _)| n.as_str()).collect(); + self.emit_expr_with_env(expr, &captured) + } + + /// Emit expression, substituting captured variables with env->name + fn emit_expr_with_env(&mut self, expr: &Expr, captured: &HashSet<&str>) -> Result { + match expr { + Expr::Var(ident) => { + if captured.contains(ident.name.as_str()) { + Ok(format!("env->{}", ident.name)) + } else if self.functions.contains(&ident.name) { + Ok(self.mangle_name(&ident.name)) + } else { + Ok(ident.name.clone()) + } + } + Expr::Literal(lit) => self.emit_literal(lit), + Expr::BinaryOp { op, left, right, .. } => { + let l = self.emit_expr_with_env(left, captured)?; + let r = self.emit_expr_with_env(right, captured)?; + let op_str = match op { + BinaryOp::Add => "+", + BinaryOp::Sub => "-", + BinaryOp::Mul => "*", + BinaryOp::Div => "/", + BinaryOp::Mod => "%", + BinaryOp::Eq => "==", + BinaryOp::Ne => "!=", + BinaryOp::Lt => "<", + BinaryOp::Le => "<=", + BinaryOp::Gt => ">", + BinaryOp::Ge => ">=", + BinaryOp::And => "&&", + BinaryOp::Or => "||", + _ => return Err(CGenError { + message: format!("Unsupported binary operator in closure"), + span: None, + }), + }; + Ok(format!("({} {} {})", l, op_str, r)) + } + Expr::UnaryOp { op, operand, .. } => { + let o = self.emit_expr_with_env(operand, captured)?; + let op_str = match op { + UnaryOp::Neg => "-", + UnaryOp::Not => "!", + }; + Ok(format!("({}{})", op_str, o)) + } + Expr::If { condition, then_branch, else_branch, .. } => { + let c = self.emit_expr_with_env(condition, captured)?; + let t = self.emit_expr_with_env(then_branch, captured)?; + let e = self.emit_expr_with_env(else_branch, captured)?; + Ok(format!("({} ? {} : {})", c, t, e)) + } + Expr::Call { func, args, .. } => { + let arg_strs: Result, _> = args.iter() + .map(|a| self.emit_expr_with_env(a, captured)) + .collect(); + let args_str = arg_strs?.join(", "); + + match func.as_ref() { + Expr::Var(ident) if self.functions.contains(&ident.name) => { + let c_func_name = self.mangle_name(&ident.name); + Ok(format!("{}({})", c_func_name, args_str)) + } + _ => { + let closure_expr = self.emit_expr_with_env(func, captured)?; + let param_types: Vec<&str> = args.iter().map(|_| "LuxInt").collect(); + let params_str = if param_types.is_empty() { + String::new() + } else { + format!(", {}", param_types.join(", ")) + }; + let args_with_env = if args_str.is_empty() { + format!("({})->env", closure_expr) + } else { + format!("({})->env, {}", closure_expr, args_str) + }; + Ok(format!("((LuxInt(*)(void*{}))({})->fn_ptr)({})", + params_str, closure_expr, args_with_env)) + } + } + } + _ => { + // For other expressions, fall back to regular emit_expr + // This may not handle captured variables correctly for complex expressions + self.emit_expr(expr) + } + } + } + fn emit_prelude(&mut self) { self.writeln("// Generated by Lux compiler"); self.writeln("// Do not edit - regenerate from .lux source"); @@ -115,6 +307,9 @@ impl CBackend { self.writeln("typedef char* LuxString;"); self.writeln("typedef void* LuxUnit;"); self.writeln(""); + self.writeln("// Closure representation: env pointer + function pointer"); + self.writeln("typedef struct { void* env; void* fn_ptr; } LuxClosure;"); + self.writeln(""); self.writeln("// === String Operations ==="); self.writeln(""); self.writeln("static LuxString lux_string_concat(LuxString a, LuxString b) {"); @@ -364,25 +559,94 @@ impl CBackend { } Expr::Call { func, args, .. } => { - let func_name = match func.as_ref() { - Expr::Var(ident) => ident.name.clone(), - _ => return Err(CGenError { - message: "Only direct function calls supported".to_string(), - span: None, - }), - }; - let arg_strs: Result, _> = args.iter().map(|a| self.emit_expr(a)).collect(); let args_str = arg_strs?.join(", "); - // Mangle user-defined function names - let c_func_name = if self.functions.contains(&func_name) { - self.mangle_name(&func_name) - } else { - func_name - }; + match func.as_ref() { + Expr::Var(ident) if self.functions.contains(&ident.name) => { + // Direct call to a known function + let c_func_name = self.mangle_name(&ident.name); + Ok(format!("{}({})", c_func_name, args_str)) + } + _ => { + // Indirect call - treat as closure + let closure_expr = self.emit_expr(func)?; + // Build the cast for the function pointer + // For now, assume all args are LuxInt and return LuxInt + let param_types: Vec<&str> = args.iter().map(|_| "LuxInt").collect(); + let params_str = if param_types.is_empty() { + String::new() + } else { + format!(", {}", param_types.join(", ")) + }; + let args_with_env = if args_str.is_empty() { + format!("({})->env", closure_expr) + } else { + format!("({})->env, {}", closure_expr, args_str) + }; + Ok(format!("((LuxInt(*)(void*{}))({})->fn_ptr)({})", + params_str, closure_expr, args_with_env)) + } + } + } - Ok(format!("{}({})", c_func_name, args_str)) + Expr::Lambda { params, body, return_type, .. } => { + // Find free variables in the lambda body + let param_names: HashSet = params.iter() + .map(|p| p.name.name.clone()) + .collect(); + let free_vars = self.find_free_vars(body, ¶m_names); + + // Generate unique closure ID + let id = self.fresh_name(); + + // Determine parameter types + let param_pairs: Vec<(String, String)> = params.iter() + .map(|p| { + let typ = self.type_expr_to_c(&p.typ) + .unwrap_or_else(|_| "LuxInt".to_string()); + (p.name.name.clone(), typ) + }) + .collect(); + + // Determine return type (default to LuxInt) + let ret_type = return_type.as_ref() + .map(|t| self.type_expr_to_c(t).unwrap_or_else(|_| "LuxInt".to_string())) + .unwrap_or_else(|| "LuxInt".to_string()); + + // Determine captured variable types (default to LuxInt for now) + let env_fields: Vec<(String, String)> = free_vars.iter() + .map(|v| (v.clone(), "LuxInt".to_string())) + .collect(); + + // Store closure info for later emission + self.closures.push(ClosureInfo { + id, + env_fields: env_fields.clone(), + params: param_pairs, + return_type: ret_type, + body: (**body).clone(), + }); + + // Generate code to create the closure at runtime + let temp_env = format!("_env_{}", id); + let temp_closure = format!("_closure_{}", id); + + // Allocate and initialize environment struct + if env_fields.is_empty() { + self.writeln(&format!("LuxClosure* {} = malloc(sizeof(LuxClosure));", temp_closure)); + self.writeln(&format!("{}->env = NULL;", temp_closure)); + } else { + self.writeln(&format!("LuxEnv_{}* {} = malloc(sizeof(LuxEnv_{}));", id, temp_env, id)); + for (name, _) in &env_fields { + self.writeln(&format!("{}->{} = {};", temp_env, name, name)); + } + self.writeln(&format!("LuxClosure* {} = malloc(sizeof(LuxClosure));", temp_closure)); + self.writeln(&format!("{}->env = {};", temp_closure, temp_env)); + } + self.writeln(&format!("{}->fn_ptr = (void*)lambda_{};", temp_closure, id)); + + Ok(temp_closure) } Expr::Block { statements, result, .. } => { @@ -390,7 +654,13 @@ impl CBackend { match stmt { Statement::Let { name, value, .. } => { let val = self.emit_expr(value)?; - self.writeln(&format!("LuxInt {} = {};", name.name, val)); + // Infer type from value: closures return LuxClosure* + let typ = if val.starts_with("_closure_") || self.is_closure_returning_call(value) { + "LuxClosure*" + } else { + "LuxInt" + }; + self.writeln(&format!("{} {} = {};", typ, name.name, val)); } Statement::Expr(e) => { let _ = self.emit_expr(e)?; @@ -580,7 +850,7 @@ impl CBackend { } TypeExpr::Unit => Ok("void".to_string()), TypeExpr::Versioned { base, .. } => self.type_expr_to_c(base), - TypeExpr::Function { .. } => Ok("void*".to_string()), + TypeExpr::Function { .. } => Ok("LuxClosure*".to_string()), TypeExpr::Tuple(_) => Ok("void*".to_string()), TypeExpr::Record(_) => Ok("void*".to_string()), } @@ -591,6 +861,126 @@ impl CBackend { self.name_counter } + /// Check if an expression is a call to a function that returns a closure + fn is_closure_returning_call(&self, expr: &Expr) -> bool { + match expr { + Expr::Lambda { .. } => true, + Expr::Call { func, .. } => { + // Check if we're calling a function known to return a closure + if let Expr::Var(ident) = func.as_ref() { + self.closure_returning_functions.contains(&ident.name) + } else { + false + } + } + _ => false, + } + } + + /// Find free variables in an expression (variables not in bound set) + fn find_free_vars(&self, expr: &Expr, bound: &HashSet) -> Vec { + let mut free = Vec::new(); + self.collect_free_vars(expr, bound, &mut free); + // Remove duplicates while preserving order + let mut seen = HashSet::new(); + free.retain(|v| seen.insert(v.clone())); + free + } + + fn collect_free_vars(&self, expr: &Expr, bound: &HashSet, free: &mut Vec) { + match expr { + Expr::Var(ident) => { + // If not bound locally and not a known function, it's a free variable + if !bound.contains(&ident.name) && !self.functions.contains(&ident.name) { + free.push(ident.name.clone()); + } + } + Expr::Literal(_) => {} + Expr::BinaryOp { left, right, .. } => { + self.collect_free_vars(left, bound, free); + self.collect_free_vars(right, bound, free); + } + Expr::UnaryOp { operand, .. } => { + self.collect_free_vars(operand, bound, free); + } + Expr::If { condition, then_branch, else_branch, .. } => { + self.collect_free_vars(condition, bound, free); + self.collect_free_vars(then_branch, bound, free); + self.collect_free_vars(else_branch, bound, free); + } + Expr::Call { func, args, .. } => { + self.collect_free_vars(func, bound, free); + for arg in args { + self.collect_free_vars(arg, bound, free); + } + } + Expr::Block { statements, result, .. } => { + let mut inner_bound = bound.clone(); + for stmt in statements { + match stmt { + Statement::Let { name, value, .. } => { + self.collect_free_vars(value, &inner_bound, free); + inner_bound.insert(name.name.clone()); + } + Statement::Expr(e) => { + self.collect_free_vars(e, &inner_bound, free); + } + } + } + self.collect_free_vars(result, &inner_bound, free); + } + Expr::Lambda { params, body, .. } => { + let mut inner_bound = bound.clone(); + for p in params { + inner_bound.insert(p.name.name.clone()); + } + self.collect_free_vars(body, &inner_bound, free); + } + Expr::Record { fields, .. } => { + for (_, val) in fields { + self.collect_free_vars(val, bound, free); + } + } + Expr::Field { object, .. } => { + self.collect_free_vars(object, bound, free); + } + Expr::Match { scrutinee, arms, .. } => { + self.collect_free_vars(scrutinee, bound, free); + for arm in arms { + // TODO: Handle pattern bindings + self.collect_free_vars(&arm.body, bound, free); + } + } + Expr::EffectOp { args, .. } => { + for arg in args { + self.collect_free_vars(arg, bound, free); + } + } + Expr::Run { expr, .. } => { + self.collect_free_vars(expr, bound, free); + } + Expr::Tuple { elements, .. } => { + for e in elements { + self.collect_free_vars(e, bound, free); + } + } + Expr::List { elements, .. } => { + for e in elements { + self.collect_free_vars(e, bound, free); + } + } + Expr::Resume { value, .. } => { + self.collect_free_vars(value, bound, free); + } + Expr::Let { value, body, name, .. } => { + self.collect_free_vars(value, bound, free); + let mut inner_bound = bound.clone(); + inner_bound.insert(name.name.clone()); + self.collect_free_vars(body, &inner_bound, free); + } + } + } + fn writeln(&mut self, line: &str) { let indent = " ".repeat(self.indent); writeln!(self.output, "{}{}", indent, line).unwrap(); @@ -643,4 +1033,40 @@ mod tests { let c_code = generate(source).unwrap(); assert!(c_code.contains("lux_console_print")); } + + #[test] + fn test_closure_basic() { + let source = r#" + fn makeAdder(n: Int): fn(Int): Int = + fn(x: Int): Int => x + n + "#; + let c_code = generate(source).unwrap(); + // Should have closure type + assert!(c_code.contains("LuxClosure")); + // Should have environment struct + assert!(c_code.contains("LuxEnv_")); + // Should have lambda function + assert!(c_code.contains("lambda_")); + // Function should return LuxClosure* + assert!(c_code.contains("LuxClosure* makeAdder_lux")); + } + + #[test] + fn test_closure_call() { + let source = r#" + fn makeAdder(n: Int): fn(Int): Int = + fn(x: Int): Int => x + n + + fn test(): Int = { + let add5 = makeAdder(5) + add5(10) + } + "#; + let c_code = generate(source).unwrap(); + // Local should be typed as LuxClosure* + assert!(c_code.contains("LuxClosure* add5")); + // Should have indirect call syntax + assert!(c_code.contains("->fn_ptr")); + assert!(c_code.contains("->env")); + } }