feat: implement tail call optimization (TCO)

Add trampoline-based tail call optimization to prevent stack overflow
on deeply recursive tail-recursive functions. The implementation:

- Extends EvalResult with TailCall variant for deferred evaluation
- Adds trampoline loop in eval_expr() to handle tail calls iteratively
- Propagates tail position through If, Let, Match, and Block expressions
- Updates all builtin callbacks to handle tail calls via eval_call_to_value
- Includes tests for deep recursion (10000+ calls) and accumulator patterns

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-02-13 04:42:53 -05:00
parent 052db9c88f
commit df5c0a1a32
2 changed files with 135 additions and 58 deletions

View File

@@ -423,10 +423,16 @@ impl Continuation {
}
}
/// Result of evaluation (either a value or an effect request)
/// Result of evaluation (either a value, effect request, or tail call)
pub enum EvalResult {
Value(Value),
Effect(EffectRequest),
/// Tail call optimization: instead of recursing, return the call to be trampolined
TailCall {
func: Value,
args: Vec<Value>,
span: Span,
},
}
/// Effect trace entry for debugging
@@ -882,18 +888,32 @@ impl Interpreter {
}
}
/// Evaluate an expression
/// Evaluate an expression with tail call optimization (trampoline)
fn eval_expr(&mut self, expr: &Expr, env: &Env) -> Result<Value, RuntimeError> {
match self.eval_expr_inner(expr, env)? {
EvalResult::Value(v) => Ok(v),
let mut result = self.eval_expr_inner(expr, env)?;
// Trampoline loop for tail call optimization
loop {
match result {
EvalResult::Value(v) => return Ok(v),
EvalResult::Effect(req) => {
// Handle the effect
self.handle_effect(req)
return self.handle_effect(req);
}
EvalResult::TailCall { func, args, span } => {
// Continue the tail call without growing the stack
result = self.eval_call(func, args, span)?;
}
}
}
}
fn eval_expr_inner(&mut self, expr: &Expr, env: &Env) -> Result<EvalResult, RuntimeError> {
self.eval_expr_tail(expr, env, false)
}
/// Evaluate an expression, with tail position tracking for TCO
fn eval_expr_tail(&mut self, expr: &Expr, env: &Env, tail: bool) -> Result<EvalResult, RuntimeError> {
match expr {
Expr::Literal(lit) => Ok(EvalResult::Value(self.eval_literal(lit))),
@@ -930,8 +950,17 @@ impl Interpreter {
.map(|a| self.eval_expr(a, env))
.collect::<Result<_, _>>()?;
// If we're in tail position, return TailCall for trampoline
if tail {
Ok(EvalResult::TailCall {
func: func_val,
args: arg_vals,
span: *span,
})
} else {
self.eval_call(func_val, arg_vals, *span)
}
}
Expr::EffectOp {
effect,
@@ -1013,7 +1042,8 @@ impl Interpreter {
let val = self.eval_expr(value, env)?;
let new_env = env.extend();
new_env.define(&name.name, val);
self.eval_expr_inner(body, &new_env)
// Body of let is in tail position if the let itself is
self.eval_expr_tail(body, &new_env, tail)
}
Expr::If {
@@ -1024,8 +1054,9 @@ impl Interpreter {
} => {
let cond_val = self.eval_expr(condition, env)?;
match cond_val {
Value::Bool(true) => self.eval_expr_inner(then_branch, env),
Value::Bool(false) => self.eval_expr_inner(else_branch, env),
// Branches are in tail position if the if itself is
Value::Bool(true) => self.eval_expr_tail(then_branch, env, tail),
Value::Bool(false) => self.eval_expr_tail(else_branch, env, tail),
_ => Err(RuntimeError {
message: format!("If condition must be Bool, got {}", cond_val.type_name()),
span: Some(*span),
@@ -1039,7 +1070,8 @@ impl Interpreter {
span,
} => {
let val = self.eval_expr(scrutinee, env)?;
self.eval_match(val, arms, env, *span)
// Match arms are in tail position if the match itself is
self.eval_match(val, arms, env, *span, tail)
}
Expr::Block {
@@ -1057,7 +1089,8 @@ impl Interpreter {
}
}
}
self.eval_expr_inner(result, &block_env)
// Block result is in tail position if the block itself is
self.eval_expr_tail(result, &block_env, tail)
}
Expr::Record { fields, .. } => {
@@ -1229,14 +1262,7 @@ impl Interpreter {
},
BinaryOp::Pipe => {
// a |> f means f(a)
self.eval_call(right, vec![left], span)
.and_then(|r| match r {
EvalResult::Value(v) => Ok(v),
EvalResult::Effect(_) => Err(RuntimeError {
message: "Effect in pipe expression".to_string(),
span: Some(span),
}),
})
self.eval_call_to_value(right, vec![left], span)
}
}
}
@@ -1285,7 +1311,8 @@ impl Interpreter {
call_env.define(param, arg);
}
self.eval_expr_inner(&closure.body, &call_env)
// Evaluate body in tail position for TCO
self.eval_expr_tail(&closure.body, &call_env, true)
}
Value::Constructor { name, fields } => {
// Constructor application
@@ -1304,6 +1331,31 @@ impl Interpreter {
}
}
/// Fully evaluate a call, handling any tail calls via trampoline.
/// Used by builtins that need to call user functions and get a value back.
fn eval_call_to_value(
&mut self,
func: Value,
args: Vec<Value>,
span: Span,
) -> Result<Value, RuntimeError> {
let mut result = self.eval_call(func, args, span)?;
loop {
match result {
EvalResult::Value(v) => return Ok(v),
EvalResult::Effect(_) => {
return Err(RuntimeError {
message: "Effect in callback not supported".to_string(),
span: Some(span),
});
}
EvalResult::TailCall { func, args, span } => {
result = self.eval_call(func, args, span)?;
}
}
}
}
fn eval_builtin(
&mut self,
builtin: BuiltinFn,
@@ -1322,10 +1374,8 @@ impl Interpreter {
Self::expect_args_2::<Vec<Value>, Value>(&args, "List.map", span)?;
let mut result = Vec::with_capacity(list.len());
for item in list {
match self.eval_call(func.clone(), vec![item], span)? {
EvalResult::Value(v) => result.push(v),
EvalResult::Effect(_) => return Err(err("Effect in List.map callback")),
}
let v = self.eval_call_to_value(func.clone(), vec![item], span)?;
result.push(v);
}
Ok(EvalResult::Value(Value::List(result)))
}
@@ -1335,16 +1385,16 @@ impl Interpreter {
Self::expect_args_2::<Vec<Value>, Value>(&args, "List.filter", span)?;
let mut result = Vec::new();
for item in list {
match self.eval_call(func.clone(), vec![item.clone()], span)? {
EvalResult::Value(Value::Bool(true)) => result.push(item),
EvalResult::Value(Value::Bool(false)) => {}
EvalResult::Value(v) => {
let v = self.eval_call_to_value(func.clone(), vec![item.clone()], span)?;
match v {
Value::Bool(true) => result.push(item),
Value::Bool(false) => {}
_ => {
return Err(err(&format!(
"List.filter predicate must return Bool, got {}",
v.type_name()
)))
}
EvalResult::Effect(_) => return Err(err("Effect in List.filter callback")),
}
}
Ok(EvalResult::Value(Value::List(result)))
@@ -1370,10 +1420,7 @@ impl Interpreter {
let func = args[2].clone();
for item in list {
match self.eval_call(func.clone(), vec![acc, item], span)? {
EvalResult::Value(v) => acc = v,
EvalResult::Effect(_) => return Err(err("Effect in List.fold callback")),
}
acc = self.eval_call_to_value(func.clone(), vec![acc, item], span)?;
}
Ok(EvalResult::Value(acc))
}
@@ -1538,13 +1585,11 @@ impl Interpreter {
let (opt, func) = Self::expect_args_2::<Value, Value>(&args, "Option.map", span)?;
match opt {
Value::Constructor { name, fields } if name == "Some" && !fields.is_empty() => {
match self.eval_call(func, vec![fields[0].clone()], span)? {
EvalResult::Value(v) => Ok(EvalResult::Value(Value::Constructor {
let v = self.eval_call_to_value(func, vec![fields[0].clone()], span)?;
Ok(EvalResult::Value(Value::Constructor {
name: "Some".to_string(),
fields: vec![v],
})),
EvalResult::Effect(_) => Err(err("Effect in Option.map callback")),
}
}))
}
Value::Constructor { name, .. } if name == "None" => {
Ok(EvalResult::Value(Value::Constructor {
@@ -1564,10 +1609,8 @@ impl Interpreter {
Self::expect_args_2::<Value, Value>(&args, "Option.flatMap", span)?;
match opt {
Value::Constructor { name, fields } if name == "Some" && !fields.is_empty() => {
match self.eval_call(func, vec![fields[0].clone()], span)? {
EvalResult::Value(v) => Ok(EvalResult::Value(v)),
EvalResult::Effect(_) => Err(err("Effect in Option.flatMap callback")),
}
let v = self.eval_call_to_value(func, vec![fields[0].clone()], span)?;
Ok(EvalResult::Value(v))
}
Value::Constructor { name, .. } if name == "None" => {
Ok(EvalResult::Value(Value::Constructor {
@@ -1636,13 +1679,11 @@ impl Interpreter {
let (res, func) = Self::expect_args_2::<Value, Value>(&args, "Result.map", span)?;
match res {
Value::Constructor { name, fields } if name == "Ok" && !fields.is_empty() => {
match self.eval_call(func, vec![fields[0].clone()], span)? {
EvalResult::Value(v) => Ok(EvalResult::Value(Value::Constructor {
let v = self.eval_call_to_value(func, vec![fields[0].clone()], span)?;
Ok(EvalResult::Value(Value::Constructor {
name: "Ok".to_string(),
fields: vec![v],
})),
EvalResult::Effect(_) => Err(err("Effect in Result.map callback")),
}
}))
}
Value::Constructor { name, fields } if name == "Err" => {
Ok(EvalResult::Value(Value::Constructor {
@@ -1662,10 +1703,8 @@ impl Interpreter {
Self::expect_args_2::<Value, Value>(&args, "Result.flatMap", span)?;
match res {
Value::Constructor { name, fields } if name == "Ok" && !fields.is_empty() => {
match self.eval_call(func, vec![fields[0].clone()], span)? {
EvalResult::Value(v) => Ok(EvalResult::Value(v)),
EvalResult::Effect(_) => Err(err("Effect in Result.flatMap callback")),
}
let v = self.eval_call_to_value(func, vec![fields[0].clone()], span)?;
Ok(EvalResult::Value(v))
}
Value::Constructor { name, fields } if name == "Err" => {
Ok(EvalResult::Value(Value::Constructor {
@@ -1804,6 +1843,7 @@ impl Interpreter {
arms: &[MatchArm],
env: &Env,
span: Span,
tail: bool,
) -> Result<EvalResult, RuntimeError> {
for arm in arms {
if let Some(bindings) = self.match_pattern(&arm.pattern, &val) {
@@ -1827,7 +1867,8 @@ impl Interpreter {
}
}
return self.eval_expr_inner(&arm.body, &match_env);
// Match arm body is in tail position if the match itself is
return self.eval_expr_tail(&arm.body, &match_env, tail);
}
}

View File

@@ -1282,5 +1282,41 @@ c")"#;
let result = eval(source);
assert!(result.is_ok(), "Expected success but got: {:?}", result);
}
#[test]
fn test_tail_call_optimization() {
// This test verifies that tail-recursive functions don't overflow the stack.
// Without TCO, a countdown from 10000 would cause a stack overflow.
let source = r#"
fn countdown(n: Int): Int = if n <= 0 then 0 else countdown(n - 1)
let result = countdown(10000)
"#;
assert_eq!(eval(source).unwrap(), "0");
}
#[test]
fn test_tail_call_with_accumulator() {
// Test TCO with an accumulator pattern (common for tail-recursive sum)
let source = r#"
fn sum_to(n: Int, acc: Int): Int = if n <= 0 then acc else sum_to(n - 1, acc + n)
let result = sum_to(1000, 0)
"#;
// Sum from 1 to 1000 = 1000 * 1001 / 2 = 500500
assert_eq!(eval(source).unwrap(), "500500");
}
#[test]
fn test_tail_call_in_match() {
// Test that TCO works through match expressions
let source = r#"
fn process(opt: Option<Int>, acc: Int): Int = match opt {
Some(n) => if n <= 0 then acc else process(Some(n - 1), acc + n),
None => acc
}
let result = process(Some(100), 0)
"#;
// Sum from 1 to 100 = 5050
assert_eq!(eval(source).unwrap(), "5050");
}
}
}