//! Cranelift-based native compiler for Lux //! //! This module compiles Lux programs to native machine code using Cranelift. //! Currently supports a subset of the language for performance-critical code. #![allow(dead_code)] use crate::ast::{Expr, Program, Declaration, FunctionDecl, BinaryOp, UnaryOp, LiteralKind, Statement}; use cranelift_codegen::ir::{AbiParam, InstBuilder, Value, types}; use cranelift_codegen::ir::condcodes::IntCC; use cranelift_codegen::Context; use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext, Variable}; use cranelift_jit::{JITBuilder, JITModule}; use cranelift_module::{Module, Linkage, FuncId}; use std::collections::HashMap; /// Errors that can occur during compilation #[derive(Debug)] pub struct CompileError { pub message: String, } impl std::fmt::Display for CompileError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Compile error: {}", self.message) } } impl std::error::Error for CompileError {} /// The JIT compiler for Lux pub struct JitCompiler { /// The Cranelift JIT module module: JITModule, /// Builder context (reusable) builder_context: FunctionBuilderContext, /// Cranelift context (reusable) ctx: Context, /// Compiled function pointers functions: HashMap, /// Function IDs for linking func_ids: HashMap, } impl JitCompiler { /// Create a new JIT compiler pub fn new() -> Result { let builder = JITBuilder::new(cranelift_module::default_libcall_names()) .map_err(|e| CompileError { message: e.to_string() })?; let module = JITModule::new(builder); Ok(Self { module, builder_context: FunctionBuilderContext::new(), ctx: Context::new(), functions: HashMap::new(), func_ids: HashMap::new(), }) } /// Compile a Lux function to native code pub fn compile_function(&mut self, func: &FunctionDecl) -> Result<*const u8, CompileError> { // Check if already compiled if let Some(ptr) = self.functions.get(&func.name.name) { return Ok(*ptr); } // Create function signature let mut sig = self.module.make_signature(); for _ in &func.params { sig.params.push(AbiParam::new(types::I64)); } sig.returns.push(AbiParam::new(types::I64)); // Declare the function let func_id = self.module .declare_function(&func.name.name, Linkage::Local, &sig) .map_err(|e| CompileError { message: e.to_string() })?; self.func_ids.insert(func.name.name.clone(), func_id); // Clear context for reuse self.ctx.clear(); self.ctx.func.signature = sig; // Clone func_ids for use in closure let func_ids = self.func_ids.clone(); // Build the function body { let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context); // Create entry block let entry_block = builder.create_block(); builder.append_block_params_for_function_params(entry_block); builder.switch_to_block(entry_block); builder.seal_block(entry_block); // Map parameter names to values let mut variables: HashMap = HashMap::new(); let params = builder.block_params(entry_block).to_vec(); for (i, param) in func.params.iter().enumerate() { let var = Variable::from_u32(i as u32); builder.declare_var(var, types::I64); builder.def_var(var, params[i]); variables.insert(param.name.name.clone(), var); } // Compile the function body let var_count = variables.len(); let result = compile_expr(&mut builder, &func.body, &mut variables, var_count, &func_ids, &mut self.module)?; // Return the result builder.ins().return_(&[result]); builder.finalize(); } // Compile to machine code self.module .define_function(func_id, &mut self.ctx) .map_err(|e| CompileError { message: e.to_string() })?; self.module.clear_context(&mut self.ctx); // Finalize and get the function pointer self.module.finalize_definitions() .map_err(|e| CompileError { message: e.to_string() })?; let ptr = self.module.get_finalized_function(func_id); self.functions.insert(func.name.name.clone(), ptr); Ok(ptr) } /// Compile and run a simple function that takes no args and returns an i64 pub fn compile_and_run(&mut self, func: &FunctionDecl) -> Result { let ptr = self.compile_function(func)?; // Cast to function pointer and call let func_ptr: fn() -> i64 = unsafe { std::mem::transmute(ptr) }; Ok(func_ptr()) } /// Compile a program and return pointers to all compiled functions pub fn compile_program(&mut self, program: &Program) -> Result<(), CompileError> { // First pass: declare all functions for decl in &program.declarations { if let Declaration::Function(func) = decl { let mut sig = self.module.make_signature(); for _ in &func.params { sig.params.push(AbiParam::new(types::I64)); } sig.returns.push(AbiParam::new(types::I64)); let func_id = self.module .declare_function(&func.name.name, Linkage::Local, &sig) .map_err(|e| CompileError { message: e.to_string() })?; self.func_ids.insert(func.name.name.clone(), func_id); } } // Second pass: compile all functions for decl in &program.declarations { if let Declaration::Function(func) = decl { self.compile_function_body(func)?; } } // Finalize self.module.finalize_definitions() .map_err(|e| CompileError { message: e.to_string() })?; // Store function pointers for (name, func_id) in &self.func_ids { let ptr = self.module.get_finalized_function(*func_id); self.functions.insert(name.clone(), ptr); } Ok(()) } /// Compile a function body (assumes function is already declared) fn compile_function_body(&mut self, func: &FunctionDecl) -> Result<(), CompileError> { let func_id = *self.func_ids.get(&func.name.name).ok_or_else(|| CompileError { message: format!("Function not declared: {}", func.name.name), })?; // Create signature let mut sig = self.module.make_signature(); for _ in &func.params { sig.params.push(AbiParam::new(types::I64)); } sig.returns.push(AbiParam::new(types::I64)); // Clear and set up context self.ctx.clear(); self.ctx.func.signature = sig; // Clone func_ids for use in closure let func_ids = self.func_ids.clone(); // Build function { let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context); let entry_block = builder.create_block(); builder.append_block_params_for_function_params(entry_block); builder.switch_to_block(entry_block); builder.seal_block(entry_block); let mut variables: HashMap = HashMap::new(); let params = builder.block_params(entry_block).to_vec(); for (i, param) in func.params.iter().enumerate() { let var = Variable::from_u32(i as u32); builder.declare_var(var, types::I64); builder.def_var(var, params[i]); variables.insert(param.name.name.clone(), var); } let var_count = variables.len(); let result = compile_expr(&mut builder, &func.body, &mut variables, var_count, &func_ids, &mut self.module)?; builder.ins().return_(&[result]); builder.finalize(); } // Define the function self.module .define_function(func_id, &mut self.ctx) .map_err(|e| CompileError { message: e.to_string() })?; self.module.clear_context(&mut self.ctx); Ok(()) } /// Get a compiled function pointer by name pub fn get_function(&self, name: &str) -> Option<*const u8> { self.functions.get(name).copied() } /// Call a compiled function with given i64 arguments pub unsafe fn call_function(&self, name: &str, args: &[i64]) -> Result { let ptr = self.get_function(name).ok_or_else(|| CompileError { message: format!("Function not found: {}", name), })?; match args.len() { 0 => { let f: fn() -> i64 = std::mem::transmute(ptr); Ok(f()) } 1 => { let f: fn(i64) -> i64 = std::mem::transmute(ptr); Ok(f(args[0])) } 2 => { let f: fn(i64, i64) -> i64 = std::mem::transmute(ptr); Ok(f(args[0], args[1])) } 3 => { let f: fn(i64, i64, i64) -> i64 = std::mem::transmute(ptr); Ok(f(args[0], args[1], args[2])) } _ => Err(CompileError { message: format!("Too many arguments: {}", args.len()), }), } } } impl Default for JitCompiler { fn default() -> Self { Self::new().expect("Failed to create JIT compiler") } } /// Compile an expression to Cranelift IR (free function to avoid borrow issues) fn compile_expr( builder: &mut FunctionBuilder, expr: &Expr, variables: &mut HashMap, next_var: usize, func_ids: &HashMap, module: &mut JITModule, ) -> Result { match expr { Expr::Literal(lit) => { match &lit.kind { LiteralKind::Int(n) => { Ok(builder.ins().iconst(types::I64, *n)) } LiteralKind::Bool(b) => { Ok(builder.ins().iconst(types::I64, if *b { 1 } else { 0 })) } _ => Err(CompileError { message: "Unsupported literal type".to_string() }), } } Expr::Var(ident) => { let var = variables.get(&ident.name).ok_or_else(|| CompileError { message: format!("Undefined variable: {}", ident.name), })?; Ok(builder.use_var(*var)) } Expr::BinaryOp { op, left, right, .. } => { let lhs = compile_expr(builder, left, variables, next_var, func_ids, module)?; let rhs = compile_expr(builder, right, variables, next_var, func_ids, module)?; let result = match op { BinaryOp::Add => builder.ins().iadd(lhs, rhs), BinaryOp::Sub => builder.ins().isub(lhs, rhs), BinaryOp::Mul => builder.ins().imul(lhs, rhs), BinaryOp::Div => builder.ins().sdiv(lhs, rhs), BinaryOp::Mod => builder.ins().srem(lhs, rhs), BinaryOp::Eq => { let cmp = builder.ins().icmp(IntCC::Equal, lhs, rhs); builder.ins().uextend(types::I64, cmp) } BinaryOp::Ne => { let cmp = builder.ins().icmp(IntCC::NotEqual, lhs, rhs); builder.ins().uextend(types::I64, cmp) } BinaryOp::Lt => { let cmp = builder.ins().icmp(IntCC::SignedLessThan, lhs, rhs); builder.ins().uextend(types::I64, cmp) } BinaryOp::Le => { let cmp = builder.ins().icmp(IntCC::SignedLessThanOrEqual, lhs, rhs); builder.ins().uextend(types::I64, cmp) } BinaryOp::Gt => { let cmp = builder.ins().icmp(IntCC::SignedGreaterThan, lhs, rhs); builder.ins().uextend(types::I64, cmp) } BinaryOp::Ge => { let cmp = builder.ins().icmp(IntCC::SignedGreaterThanOrEqual, lhs, rhs); builder.ins().uextend(types::I64, cmp) } BinaryOp::And => builder.ins().band(lhs, rhs), BinaryOp::Or => builder.ins().bor(lhs, rhs), _ => return Err(CompileError { message: format!("Unsupported binary operator: {:?}", op), }), }; Ok(result) } Expr::UnaryOp { op, operand, .. } => { let val = compile_expr(builder, operand, variables, next_var, func_ids, module)?; let result = match op { UnaryOp::Neg => builder.ins().ineg(val), UnaryOp::Not => { let one = builder.ins().iconst(types::I64, 1); builder.ins().bxor(val, one) } }; Ok(result) } Expr::If { condition, then_branch, else_branch, .. } => { let cond_val = compile_expr(builder, condition, variables, next_var, func_ids, module)?; // Create blocks let then_block = builder.create_block(); let else_block = builder.create_block(); let merge_block = builder.create_block(); // Add block parameter for the result builder.append_block_param(merge_block, types::I64); // Branch based on condition let zero = builder.ins().iconst(types::I64, 0); let cmp = builder.ins().icmp(IntCC::NotEqual, cond_val, zero); builder.ins().brif(cmp, then_block, &[], else_block, &[]); // Then block builder.switch_to_block(then_block); builder.seal_block(then_block); let then_val = compile_expr(builder, then_branch, variables, next_var, func_ids, module)?; builder.ins().jump(merge_block, &[then_val]); // Else block builder.switch_to_block(else_block); builder.seal_block(else_block); let else_val = compile_expr(builder, else_branch, variables, next_var, func_ids, module)?; builder.ins().jump(merge_block, &[else_val]); // Merge block builder.switch_to_block(merge_block); builder.seal_block(merge_block); Ok(builder.block_params(merge_block)[0]) } Expr::Let { name, value, body, .. } => { // Compile the value let val = compile_expr(builder, value, variables, next_var, func_ids, module)?; // Create a new variable let var = Variable::from_u32(next_var as u32); builder.declare_var(var, types::I64); builder.def_var(var, val); variables.insert(name.name.clone(), var); // Compile the body with the new variable in scope compile_expr(builder, body, variables, next_var + 1, func_ids, module) } Expr::Call { func, args, .. } => { // Check if it's a direct function call if let Expr::Var(name) = func.as_ref() { // Look up the function let func_id = *func_ids.get(&name.name).ok_or_else(|| CompileError { message: format!("Unknown function: {}", name.name), })?; // Compile arguments let mut arg_values = Vec::new(); for arg in args { arg_values.push(compile_expr(builder, arg, variables, next_var, func_ids, module)?); } // Get function reference let func_ref = module.declare_func_in_func(func_id, builder.func); // Make the call let call = builder.ins().call(func_ref, &arg_values); Ok(builder.inst_results(call)[0]) } else { Err(CompileError { message: "Only direct function calls supported".to_string(), }) } } Expr::Block { statements, result: block_result, .. } => { let mut current_var = next_var; // Compile all statements for stmt in statements { match stmt { Statement::Let { name, value, .. } => { let val = compile_expr(builder, value, variables, current_var, func_ids, module)?; let var = Variable::from_u32(current_var as u32); builder.declare_var(var, types::I64); builder.def_var(var, val); variables.insert(name.name.clone(), var); current_var += 1; } Statement::Expr(expr) => { compile_expr(builder, expr, variables, current_var, func_ids, module)?; } } } // Compile and return the result expression compile_expr(builder, block_result, variables, current_var, func_ids, module) } _ => Err(CompileError { message: "Unsupported expression type".to_string(), }), } } #[cfg(test)] mod tests { use super::*; use crate::parser::Parser; fn parse_function(src: &str) -> FunctionDecl { let program = Parser::parse_source(src).expect("Parse error"); match &program.declarations[0] { Declaration::Function(f) => f.clone(), _ => panic!("Expected function"), } } #[test] fn test_simple_arithmetic() { let func = parse_function("fn test(): Int = 1 + 2 * 3"); let mut jit = JitCompiler::new().unwrap(); let result = jit.compile_and_run(&func).unwrap(); assert_eq!(result, 7); } #[test] fn test_conditionals() { let func = parse_function("fn test(): Int = if 1 > 0 then 42 else 0"); let mut jit = JitCompiler::new().unwrap(); let result = jit.compile_and_run(&func).unwrap(); assert_eq!(result, 42); } #[test] fn test_let_binding() { let func = parse_function("fn test(): Int = { let x = 10; let y = 20; x + y }"); let mut jit = JitCompiler::new().unwrap(); let result = jit.compile_and_run(&func).unwrap(); assert_eq!(result, 30); } #[test] fn test_recursive_fibonacci() { use std::time::Instant; // Compile a program with recursive fibonacci let src = r#" fn fib(n: Int): Int = if n <= 1 then n else fib(n - 1) + fib(n - 2) "#; let program = Parser::parse_source(src).expect("Parse error"); let mut jit = JitCompiler::new().unwrap(); let compile_start = Instant::now(); jit.compile_program(&program).unwrap(); let compile_time = compile_start.elapsed(); // Call fib(30) let exec_start = Instant::now(); let result = unsafe { jit.call_function("fib", &[30]).unwrap() }; let exec_time = exec_start.elapsed(); println!("\n=== JIT Benchmark ==="); println!("Compile time: {:?}", compile_time); println!("Execute fib(30) = {} in {:?}", result, exec_time); println!("Total: {:?}", compile_time + exec_time); assert_eq!(result, 832040); } }