feat: implement tail call optimization (TCO)
Add trampoline-based tail call optimization to prevent stack overflow on deeply recursive tail-recursive functions. The implementation: - Extends EvalResult with TailCall variant for deferred evaluation - Adds trampoline loop in eval_expr() to handle tail calls iteratively - Propagates tail position through If, Let, Match, and Block expressions - Updates all builtin callbacks to handle tail calls via eval_call_to_value - Includes tests for deep recursion (10000+ calls) and accumulator patterns Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -423,10 +423,16 @@ impl Continuation {
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of evaluation (either a value or an effect request)
|
||||
/// Result of evaluation (either a value, effect request, or tail call)
|
||||
pub enum EvalResult {
|
||||
Value(Value),
|
||||
Effect(EffectRequest),
|
||||
/// Tail call optimization: instead of recursing, return the call to be trampolined
|
||||
TailCall {
|
||||
func: Value,
|
||||
args: Vec<Value>,
|
||||
span: Span,
|
||||
},
|
||||
}
|
||||
|
||||
/// Effect trace entry for debugging
|
||||
@@ -882,18 +888,32 @@ impl Interpreter {
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate an expression
|
||||
/// Evaluate an expression with tail call optimization (trampoline)
|
||||
fn eval_expr(&mut self, expr: &Expr, env: &Env) -> Result<Value, RuntimeError> {
|
||||
match self.eval_expr_inner(expr, env)? {
|
||||
EvalResult::Value(v) => Ok(v),
|
||||
let mut result = self.eval_expr_inner(expr, env)?;
|
||||
|
||||
// Trampoline loop for tail call optimization
|
||||
loop {
|
||||
match result {
|
||||
EvalResult::Value(v) => return Ok(v),
|
||||
EvalResult::Effect(req) => {
|
||||
// Handle the effect
|
||||
self.handle_effect(req)
|
||||
return self.handle_effect(req);
|
||||
}
|
||||
EvalResult::TailCall { func, args, span } => {
|
||||
// Continue the tail call without growing the stack
|
||||
result = self.eval_call(func, args, span)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn eval_expr_inner(&mut self, expr: &Expr, env: &Env) -> Result<EvalResult, RuntimeError> {
|
||||
self.eval_expr_tail(expr, env, false)
|
||||
}
|
||||
|
||||
/// Evaluate an expression, with tail position tracking for TCO
|
||||
fn eval_expr_tail(&mut self, expr: &Expr, env: &Env, tail: bool) -> Result<EvalResult, RuntimeError> {
|
||||
match expr {
|
||||
Expr::Literal(lit) => Ok(EvalResult::Value(self.eval_literal(lit))),
|
||||
|
||||
@@ -930,8 +950,17 @@ impl Interpreter {
|
||||
.map(|a| self.eval_expr(a, env))
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
// If we're in tail position, return TailCall for trampoline
|
||||
if tail {
|
||||
Ok(EvalResult::TailCall {
|
||||
func: func_val,
|
||||
args: arg_vals,
|
||||
span: *span,
|
||||
})
|
||||
} else {
|
||||
self.eval_call(func_val, arg_vals, *span)
|
||||
}
|
||||
}
|
||||
|
||||
Expr::EffectOp {
|
||||
effect,
|
||||
@@ -1013,7 +1042,8 @@ impl Interpreter {
|
||||
let val = self.eval_expr(value, env)?;
|
||||
let new_env = env.extend();
|
||||
new_env.define(&name.name, val);
|
||||
self.eval_expr_inner(body, &new_env)
|
||||
// Body of let is in tail position if the let itself is
|
||||
self.eval_expr_tail(body, &new_env, tail)
|
||||
}
|
||||
|
||||
Expr::If {
|
||||
@@ -1024,8 +1054,9 @@ impl Interpreter {
|
||||
} => {
|
||||
let cond_val = self.eval_expr(condition, env)?;
|
||||
match cond_val {
|
||||
Value::Bool(true) => self.eval_expr_inner(then_branch, env),
|
||||
Value::Bool(false) => self.eval_expr_inner(else_branch, env),
|
||||
// Branches are in tail position if the if itself is
|
||||
Value::Bool(true) => self.eval_expr_tail(then_branch, env, tail),
|
||||
Value::Bool(false) => self.eval_expr_tail(else_branch, env, tail),
|
||||
_ => Err(RuntimeError {
|
||||
message: format!("If condition must be Bool, got {}", cond_val.type_name()),
|
||||
span: Some(*span),
|
||||
@@ -1039,7 +1070,8 @@ impl Interpreter {
|
||||
span,
|
||||
} => {
|
||||
let val = self.eval_expr(scrutinee, env)?;
|
||||
self.eval_match(val, arms, env, *span)
|
||||
// Match arms are in tail position if the match itself is
|
||||
self.eval_match(val, arms, env, *span, tail)
|
||||
}
|
||||
|
||||
Expr::Block {
|
||||
@@ -1057,7 +1089,8 @@ impl Interpreter {
|
||||
}
|
||||
}
|
||||
}
|
||||
self.eval_expr_inner(result, &block_env)
|
||||
// Block result is in tail position if the block itself is
|
||||
self.eval_expr_tail(result, &block_env, tail)
|
||||
}
|
||||
|
||||
Expr::Record { fields, .. } => {
|
||||
@@ -1229,14 +1262,7 @@ impl Interpreter {
|
||||
},
|
||||
BinaryOp::Pipe => {
|
||||
// a |> f means f(a)
|
||||
self.eval_call(right, vec![left], span)
|
||||
.and_then(|r| match r {
|
||||
EvalResult::Value(v) => Ok(v),
|
||||
EvalResult::Effect(_) => Err(RuntimeError {
|
||||
message: "Effect in pipe expression".to_string(),
|
||||
span: Some(span),
|
||||
}),
|
||||
})
|
||||
self.eval_call_to_value(right, vec![left], span)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1285,7 +1311,8 @@ impl Interpreter {
|
||||
call_env.define(param, arg);
|
||||
}
|
||||
|
||||
self.eval_expr_inner(&closure.body, &call_env)
|
||||
// Evaluate body in tail position for TCO
|
||||
self.eval_expr_tail(&closure.body, &call_env, true)
|
||||
}
|
||||
Value::Constructor { name, fields } => {
|
||||
// Constructor application
|
||||
@@ -1304,6 +1331,31 @@ impl Interpreter {
|
||||
}
|
||||
}
|
||||
|
||||
/// Fully evaluate a call, handling any tail calls via trampoline.
|
||||
/// Used by builtins that need to call user functions and get a value back.
|
||||
fn eval_call_to_value(
|
||||
&mut self,
|
||||
func: Value,
|
||||
args: Vec<Value>,
|
||||
span: Span,
|
||||
) -> Result<Value, RuntimeError> {
|
||||
let mut result = self.eval_call(func, args, span)?;
|
||||
loop {
|
||||
match result {
|
||||
EvalResult::Value(v) => return Ok(v),
|
||||
EvalResult::Effect(_) => {
|
||||
return Err(RuntimeError {
|
||||
message: "Effect in callback not supported".to_string(),
|
||||
span: Some(span),
|
||||
});
|
||||
}
|
||||
EvalResult::TailCall { func, args, span } => {
|
||||
result = self.eval_call(func, args, span)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn eval_builtin(
|
||||
&mut self,
|
||||
builtin: BuiltinFn,
|
||||
@@ -1322,10 +1374,8 @@ impl Interpreter {
|
||||
Self::expect_args_2::<Vec<Value>, Value>(&args, "List.map", span)?;
|
||||
let mut result = Vec::with_capacity(list.len());
|
||||
for item in list {
|
||||
match self.eval_call(func.clone(), vec![item], span)? {
|
||||
EvalResult::Value(v) => result.push(v),
|
||||
EvalResult::Effect(_) => return Err(err("Effect in List.map callback")),
|
||||
}
|
||||
let v = self.eval_call_to_value(func.clone(), vec![item], span)?;
|
||||
result.push(v);
|
||||
}
|
||||
Ok(EvalResult::Value(Value::List(result)))
|
||||
}
|
||||
@@ -1335,16 +1385,16 @@ impl Interpreter {
|
||||
Self::expect_args_2::<Vec<Value>, Value>(&args, "List.filter", span)?;
|
||||
let mut result = Vec::new();
|
||||
for item in list {
|
||||
match self.eval_call(func.clone(), vec![item.clone()], span)? {
|
||||
EvalResult::Value(Value::Bool(true)) => result.push(item),
|
||||
EvalResult::Value(Value::Bool(false)) => {}
|
||||
EvalResult::Value(v) => {
|
||||
let v = self.eval_call_to_value(func.clone(), vec![item.clone()], span)?;
|
||||
match v {
|
||||
Value::Bool(true) => result.push(item),
|
||||
Value::Bool(false) => {}
|
||||
_ => {
|
||||
return Err(err(&format!(
|
||||
"List.filter predicate must return Bool, got {}",
|
||||
v.type_name()
|
||||
)))
|
||||
}
|
||||
EvalResult::Effect(_) => return Err(err("Effect in List.filter callback")),
|
||||
}
|
||||
}
|
||||
Ok(EvalResult::Value(Value::List(result)))
|
||||
@@ -1370,10 +1420,7 @@ impl Interpreter {
|
||||
let func = args[2].clone();
|
||||
|
||||
for item in list {
|
||||
match self.eval_call(func.clone(), vec![acc, item], span)? {
|
||||
EvalResult::Value(v) => acc = v,
|
||||
EvalResult::Effect(_) => return Err(err("Effect in List.fold callback")),
|
||||
}
|
||||
acc = self.eval_call_to_value(func.clone(), vec![acc, item], span)?;
|
||||
}
|
||||
Ok(EvalResult::Value(acc))
|
||||
}
|
||||
@@ -1538,13 +1585,11 @@ impl Interpreter {
|
||||
let (opt, func) = Self::expect_args_2::<Value, Value>(&args, "Option.map", span)?;
|
||||
match opt {
|
||||
Value::Constructor { name, fields } if name == "Some" && !fields.is_empty() => {
|
||||
match self.eval_call(func, vec![fields[0].clone()], span)? {
|
||||
EvalResult::Value(v) => Ok(EvalResult::Value(Value::Constructor {
|
||||
let v = self.eval_call_to_value(func, vec![fields[0].clone()], span)?;
|
||||
Ok(EvalResult::Value(Value::Constructor {
|
||||
name: "Some".to_string(),
|
||||
fields: vec![v],
|
||||
})),
|
||||
EvalResult::Effect(_) => Err(err("Effect in Option.map callback")),
|
||||
}
|
||||
}))
|
||||
}
|
||||
Value::Constructor { name, .. } if name == "None" => {
|
||||
Ok(EvalResult::Value(Value::Constructor {
|
||||
@@ -1564,10 +1609,8 @@ impl Interpreter {
|
||||
Self::expect_args_2::<Value, Value>(&args, "Option.flatMap", span)?;
|
||||
match opt {
|
||||
Value::Constructor { name, fields } if name == "Some" && !fields.is_empty() => {
|
||||
match self.eval_call(func, vec![fields[0].clone()], span)? {
|
||||
EvalResult::Value(v) => Ok(EvalResult::Value(v)),
|
||||
EvalResult::Effect(_) => Err(err("Effect in Option.flatMap callback")),
|
||||
}
|
||||
let v = self.eval_call_to_value(func, vec![fields[0].clone()], span)?;
|
||||
Ok(EvalResult::Value(v))
|
||||
}
|
||||
Value::Constructor { name, .. } if name == "None" => {
|
||||
Ok(EvalResult::Value(Value::Constructor {
|
||||
@@ -1636,13 +1679,11 @@ impl Interpreter {
|
||||
let (res, func) = Self::expect_args_2::<Value, Value>(&args, "Result.map", span)?;
|
||||
match res {
|
||||
Value::Constructor { name, fields } if name == "Ok" && !fields.is_empty() => {
|
||||
match self.eval_call(func, vec![fields[0].clone()], span)? {
|
||||
EvalResult::Value(v) => Ok(EvalResult::Value(Value::Constructor {
|
||||
let v = self.eval_call_to_value(func, vec![fields[0].clone()], span)?;
|
||||
Ok(EvalResult::Value(Value::Constructor {
|
||||
name: "Ok".to_string(),
|
||||
fields: vec![v],
|
||||
})),
|
||||
EvalResult::Effect(_) => Err(err("Effect in Result.map callback")),
|
||||
}
|
||||
}))
|
||||
}
|
||||
Value::Constructor { name, fields } if name == "Err" => {
|
||||
Ok(EvalResult::Value(Value::Constructor {
|
||||
@@ -1662,10 +1703,8 @@ impl Interpreter {
|
||||
Self::expect_args_2::<Value, Value>(&args, "Result.flatMap", span)?;
|
||||
match res {
|
||||
Value::Constructor { name, fields } if name == "Ok" && !fields.is_empty() => {
|
||||
match self.eval_call(func, vec![fields[0].clone()], span)? {
|
||||
EvalResult::Value(v) => Ok(EvalResult::Value(v)),
|
||||
EvalResult::Effect(_) => Err(err("Effect in Result.flatMap callback")),
|
||||
}
|
||||
let v = self.eval_call_to_value(func, vec![fields[0].clone()], span)?;
|
||||
Ok(EvalResult::Value(v))
|
||||
}
|
||||
Value::Constructor { name, fields } if name == "Err" => {
|
||||
Ok(EvalResult::Value(Value::Constructor {
|
||||
@@ -1804,6 +1843,7 @@ impl Interpreter {
|
||||
arms: &[MatchArm],
|
||||
env: &Env,
|
||||
span: Span,
|
||||
tail: bool,
|
||||
) -> Result<EvalResult, RuntimeError> {
|
||||
for arm in arms {
|
||||
if let Some(bindings) = self.match_pattern(&arm.pattern, &val) {
|
||||
@@ -1827,7 +1867,8 @@ impl Interpreter {
|
||||
}
|
||||
}
|
||||
|
||||
return self.eval_expr_inner(&arm.body, &match_env);
|
||||
// Match arm body is in tail position if the match itself is
|
||||
return self.eval_expr_tail(&arm.body, &match_env, tail);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
36
src/main.rs
36
src/main.rs
@@ -1282,5 +1282,41 @@ c")"#;
|
||||
let result = eval(source);
|
||||
assert!(result.is_ok(), "Expected success but got: {:?}", result);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tail_call_optimization() {
|
||||
// This test verifies that tail-recursive functions don't overflow the stack.
|
||||
// Without TCO, a countdown from 10000 would cause a stack overflow.
|
||||
let source = r#"
|
||||
fn countdown(n: Int): Int = if n <= 0 then 0 else countdown(n - 1)
|
||||
let result = countdown(10000)
|
||||
"#;
|
||||
assert_eq!(eval(source).unwrap(), "0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tail_call_with_accumulator() {
|
||||
// Test TCO with an accumulator pattern (common for tail-recursive sum)
|
||||
let source = r#"
|
||||
fn sum_to(n: Int, acc: Int): Int = if n <= 0 then acc else sum_to(n - 1, acc + n)
|
||||
let result = sum_to(1000, 0)
|
||||
"#;
|
||||
// Sum from 1 to 1000 = 1000 * 1001 / 2 = 500500
|
||||
assert_eq!(eval(source).unwrap(), "500500");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tail_call_in_match() {
|
||||
// Test that TCO works through match expressions
|
||||
let source = r#"
|
||||
fn process(opt: Option<Int>, acc: Int): Int = match opt {
|
||||
Some(n) => if n <= 0 then acc else process(Some(n - 1), acc + n),
|
||||
None => acc
|
||||
}
|
||||
let result = process(Some(100), 0)
|
||||
"#;
|
||||
// Sum from 1 to 100 = 5050
|
||||
assert_eq!(eval(source).unwrap(), "5050");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user