diff --git a/src/interpreter.rs b/src/interpreter.rs index ddbaab4..166c11f 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -423,10 +423,16 @@ impl Continuation { } } -/// Result of evaluation (either a value or an effect request) +/// Result of evaluation (either a value, effect request, or tail call) pub enum EvalResult { Value(Value), Effect(EffectRequest), + /// Tail call optimization: instead of recursing, return the call to be trampolined + TailCall { + func: Value, + args: Vec, + span: Span, + }, } /// Effect trace entry for debugging @@ -882,18 +888,32 @@ impl Interpreter { } } - /// Evaluate an expression + /// Evaluate an expression with tail call optimization (trampoline) fn eval_expr(&mut self, expr: &Expr, env: &Env) -> Result { - match self.eval_expr_inner(expr, env)? { - EvalResult::Value(v) => Ok(v), - EvalResult::Effect(req) => { - // Handle the effect - self.handle_effect(req) + let mut result = self.eval_expr_inner(expr, env)?; + + // Trampoline loop for tail call optimization + loop { + match result { + EvalResult::Value(v) => return Ok(v), + EvalResult::Effect(req) => { + // Handle the effect + return self.handle_effect(req); + } + EvalResult::TailCall { func, args, span } => { + // Continue the tail call without growing the stack + result = self.eval_call(func, args, span)?; + } } } } fn eval_expr_inner(&mut self, expr: &Expr, env: &Env) -> Result { + self.eval_expr_tail(expr, env, false) + } + + /// Evaluate an expression, with tail position tracking for TCO + fn eval_expr_tail(&mut self, expr: &Expr, env: &Env, tail: bool) -> Result { match expr { Expr::Literal(lit) => Ok(EvalResult::Value(self.eval_literal(lit))), @@ -930,7 +950,16 @@ impl Interpreter { .map(|a| self.eval_expr(a, env)) .collect::>()?; - self.eval_call(func_val, arg_vals, *span) + // If we're in tail position, return TailCall for trampoline + if tail { + Ok(EvalResult::TailCall { + func: func_val, + args: arg_vals, + span: *span, + }) + } else { + self.eval_call(func_val, arg_vals, *span) + } } Expr::EffectOp { @@ -1013,7 +1042,8 @@ impl Interpreter { let val = self.eval_expr(value, env)?; let new_env = env.extend(); new_env.define(&name.name, val); - self.eval_expr_inner(body, &new_env) + // Body of let is in tail position if the let itself is + self.eval_expr_tail(body, &new_env, tail) } Expr::If { @@ -1024,8 +1054,9 @@ impl Interpreter { } => { let cond_val = self.eval_expr(condition, env)?; match cond_val { - Value::Bool(true) => self.eval_expr_inner(then_branch, env), - Value::Bool(false) => self.eval_expr_inner(else_branch, env), + // Branches are in tail position if the if itself is + Value::Bool(true) => self.eval_expr_tail(then_branch, env, tail), + Value::Bool(false) => self.eval_expr_tail(else_branch, env, tail), _ => Err(RuntimeError { message: format!("If condition must be Bool, got {}", cond_val.type_name()), span: Some(*span), @@ -1039,7 +1070,8 @@ impl Interpreter { span, } => { let val = self.eval_expr(scrutinee, env)?; - self.eval_match(val, arms, env, *span) + // Match arms are in tail position if the match itself is + self.eval_match(val, arms, env, *span, tail) } Expr::Block { @@ -1057,7 +1089,8 @@ impl Interpreter { } } } - self.eval_expr_inner(result, &block_env) + // Block result is in tail position if the block itself is + self.eval_expr_tail(result, &block_env, tail) } Expr::Record { fields, .. } => { @@ -1229,14 +1262,7 @@ impl Interpreter { }, BinaryOp::Pipe => { // a |> f means f(a) - self.eval_call(right, vec![left], span) - .and_then(|r| match r { - EvalResult::Value(v) => Ok(v), - EvalResult::Effect(_) => Err(RuntimeError { - message: "Effect in pipe expression".to_string(), - span: Some(span), - }), - }) + self.eval_call_to_value(right, vec![left], span) } } } @@ -1285,7 +1311,8 @@ impl Interpreter { call_env.define(param, arg); } - self.eval_expr_inner(&closure.body, &call_env) + // Evaluate body in tail position for TCO + self.eval_expr_tail(&closure.body, &call_env, true) } Value::Constructor { name, fields } => { // Constructor application @@ -1304,6 +1331,31 @@ impl Interpreter { } } + /// Fully evaluate a call, handling any tail calls via trampoline. + /// Used by builtins that need to call user functions and get a value back. + fn eval_call_to_value( + &mut self, + func: Value, + args: Vec, + span: Span, + ) -> Result { + let mut result = self.eval_call(func, args, span)?; + loop { + match result { + EvalResult::Value(v) => return Ok(v), + EvalResult::Effect(_) => { + return Err(RuntimeError { + message: "Effect in callback not supported".to_string(), + span: Some(span), + }); + } + EvalResult::TailCall { func, args, span } => { + result = self.eval_call(func, args, span)?; + } + } + } + } + fn eval_builtin( &mut self, builtin: BuiltinFn, @@ -1322,10 +1374,8 @@ impl Interpreter { Self::expect_args_2::, Value>(&args, "List.map", span)?; let mut result = Vec::with_capacity(list.len()); for item in list { - match self.eval_call(func.clone(), vec![item], span)? { - EvalResult::Value(v) => result.push(v), - EvalResult::Effect(_) => return Err(err("Effect in List.map callback")), - } + let v = self.eval_call_to_value(func.clone(), vec![item], span)?; + result.push(v); } Ok(EvalResult::Value(Value::List(result))) } @@ -1335,16 +1385,16 @@ impl Interpreter { Self::expect_args_2::, Value>(&args, "List.filter", span)?; let mut result = Vec::new(); for item in list { - match self.eval_call(func.clone(), vec![item.clone()], span)? { - EvalResult::Value(Value::Bool(true)) => result.push(item), - EvalResult::Value(Value::Bool(false)) => {} - EvalResult::Value(v) => { + let v = self.eval_call_to_value(func.clone(), vec![item.clone()], span)?; + match v { + Value::Bool(true) => result.push(item), + Value::Bool(false) => {} + _ => { return Err(err(&format!( "List.filter predicate must return Bool, got {}", v.type_name() ))) } - EvalResult::Effect(_) => return Err(err("Effect in List.filter callback")), } } Ok(EvalResult::Value(Value::List(result))) @@ -1370,10 +1420,7 @@ impl Interpreter { let func = args[2].clone(); for item in list { - match self.eval_call(func.clone(), vec![acc, item], span)? { - EvalResult::Value(v) => acc = v, - EvalResult::Effect(_) => return Err(err("Effect in List.fold callback")), - } + acc = self.eval_call_to_value(func.clone(), vec![acc, item], span)?; } Ok(EvalResult::Value(acc)) } @@ -1538,13 +1585,11 @@ impl Interpreter { let (opt, func) = Self::expect_args_2::(&args, "Option.map", span)?; match opt { Value::Constructor { name, fields } if name == "Some" && !fields.is_empty() => { - match self.eval_call(func, vec![fields[0].clone()], span)? { - EvalResult::Value(v) => Ok(EvalResult::Value(Value::Constructor { - name: "Some".to_string(), - fields: vec![v], - })), - EvalResult::Effect(_) => Err(err("Effect in Option.map callback")), - } + let v = self.eval_call_to_value(func, vec![fields[0].clone()], span)?; + Ok(EvalResult::Value(Value::Constructor { + name: "Some".to_string(), + fields: vec![v], + })) } Value::Constructor { name, .. } if name == "None" => { Ok(EvalResult::Value(Value::Constructor { @@ -1564,10 +1609,8 @@ impl Interpreter { Self::expect_args_2::(&args, "Option.flatMap", span)?; match opt { Value::Constructor { name, fields } if name == "Some" && !fields.is_empty() => { - match self.eval_call(func, vec![fields[0].clone()], span)? { - EvalResult::Value(v) => Ok(EvalResult::Value(v)), - EvalResult::Effect(_) => Err(err("Effect in Option.flatMap callback")), - } + let v = self.eval_call_to_value(func, vec![fields[0].clone()], span)?; + Ok(EvalResult::Value(v)) } Value::Constructor { name, .. } if name == "None" => { Ok(EvalResult::Value(Value::Constructor { @@ -1636,13 +1679,11 @@ impl Interpreter { let (res, func) = Self::expect_args_2::(&args, "Result.map", span)?; match res { Value::Constructor { name, fields } if name == "Ok" && !fields.is_empty() => { - match self.eval_call(func, vec![fields[0].clone()], span)? { - EvalResult::Value(v) => Ok(EvalResult::Value(Value::Constructor { - name: "Ok".to_string(), - fields: vec![v], - })), - EvalResult::Effect(_) => Err(err("Effect in Result.map callback")), - } + let v = self.eval_call_to_value(func, vec![fields[0].clone()], span)?; + Ok(EvalResult::Value(Value::Constructor { + name: "Ok".to_string(), + fields: vec![v], + })) } Value::Constructor { name, fields } if name == "Err" => { Ok(EvalResult::Value(Value::Constructor { @@ -1662,10 +1703,8 @@ impl Interpreter { Self::expect_args_2::(&args, "Result.flatMap", span)?; match res { Value::Constructor { name, fields } if name == "Ok" && !fields.is_empty() => { - match self.eval_call(func, vec![fields[0].clone()], span)? { - EvalResult::Value(v) => Ok(EvalResult::Value(v)), - EvalResult::Effect(_) => Err(err("Effect in Result.flatMap callback")), - } + let v = self.eval_call_to_value(func, vec![fields[0].clone()], span)?; + Ok(EvalResult::Value(v)) } Value::Constructor { name, fields } if name == "Err" => { Ok(EvalResult::Value(Value::Constructor { @@ -1804,6 +1843,7 @@ impl Interpreter { arms: &[MatchArm], env: &Env, span: Span, + tail: bool, ) -> Result { for arm in arms { if let Some(bindings) = self.match_pattern(&arm.pattern, &val) { @@ -1827,7 +1867,8 @@ impl Interpreter { } } - return self.eval_expr_inner(&arm.body, &match_env); + // Match arm body is in tail position if the match itself is + return self.eval_expr_tail(&arm.body, &match_env, tail); } } diff --git a/src/main.rs b/src/main.rs index 80e93df..06e6ab0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1282,5 +1282,41 @@ c")"#; let result = eval(source); assert!(result.is_ok(), "Expected success but got: {:?}", result); } + + #[test] + fn test_tail_call_optimization() { + // This test verifies that tail-recursive functions don't overflow the stack. + // Without TCO, a countdown from 10000 would cause a stack overflow. + let source = r#" + fn countdown(n: Int): Int = if n <= 0 then 0 else countdown(n - 1) + let result = countdown(10000) + "#; + assert_eq!(eval(source).unwrap(), "0"); + } + + #[test] + fn test_tail_call_with_accumulator() { + // Test TCO with an accumulator pattern (common for tail-recursive sum) + let source = r#" + fn sum_to(n: Int, acc: Int): Int = if n <= 0 then acc else sum_to(n - 1, acc + n) + let result = sum_to(1000, 0) + "#; + // Sum from 1 to 1000 = 1000 * 1001 / 2 = 500500 + assert_eq!(eval(source).unwrap(), "500500"); + } + + #[test] + fn test_tail_call_in_match() { + // Test that TCO works through match expressions + let source = r#" + fn process(opt: Option, acc: Int): Int = match opt { + Some(n) => if n <= 0 then acc else process(Some(n - 1), acc + n), + None => acc + } + let result = process(Some(100), 0) + "#; + // Sum from 1 to 100 = 5050 + assert_eq!(eval(source).unwrap(), "5050"); + } } }