feat: implement tail call optimization (TCO)
Add trampoline-based tail call optimization to prevent stack overflow on deeply recursive tail-recursive functions. The implementation: - Extends EvalResult with TailCall variant for deferred evaluation - Adds trampoline loop in eval_expr() to handle tail calls iteratively - Propagates tail position through If, Let, Match, and Block expressions - Updates all builtin callbacks to handle tail calls via eval_call_to_value - Includes tests for deep recursion (10000+ calls) and accumulator patterns Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -423,10 +423,16 @@ impl Continuation {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Result of evaluation (either a value or an effect request)
|
/// Result of evaluation (either a value, effect request, or tail call)
|
||||||
pub enum EvalResult {
|
pub enum EvalResult {
|
||||||
Value(Value),
|
Value(Value),
|
||||||
Effect(EffectRequest),
|
Effect(EffectRequest),
|
||||||
|
/// Tail call optimization: instead of recursing, return the call to be trampolined
|
||||||
|
TailCall {
|
||||||
|
func: Value,
|
||||||
|
args: Vec<Value>,
|
||||||
|
span: Span,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Effect trace entry for debugging
|
/// Effect trace entry for debugging
|
||||||
@@ -882,18 +888,32 @@ impl Interpreter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Evaluate an expression
|
/// Evaluate an expression with tail call optimization (trampoline)
|
||||||
fn eval_expr(&mut self, expr: &Expr, env: &Env) -> Result<Value, RuntimeError> {
|
fn eval_expr(&mut self, expr: &Expr, env: &Env) -> Result<Value, RuntimeError> {
|
||||||
match self.eval_expr_inner(expr, env)? {
|
let mut result = self.eval_expr_inner(expr, env)?;
|
||||||
EvalResult::Value(v) => Ok(v),
|
|
||||||
EvalResult::Effect(req) => {
|
// Trampoline loop for tail call optimization
|
||||||
// Handle the effect
|
loop {
|
||||||
self.handle_effect(req)
|
match result {
|
||||||
|
EvalResult::Value(v) => return Ok(v),
|
||||||
|
EvalResult::Effect(req) => {
|
||||||
|
// Handle the effect
|
||||||
|
return self.handle_effect(req);
|
||||||
|
}
|
||||||
|
EvalResult::TailCall { func, args, span } => {
|
||||||
|
// Continue the tail call without growing the stack
|
||||||
|
result = self.eval_call(func, args, span)?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn eval_expr_inner(&mut self, expr: &Expr, env: &Env) -> Result<EvalResult, RuntimeError> {
|
fn eval_expr_inner(&mut self, expr: &Expr, env: &Env) -> Result<EvalResult, RuntimeError> {
|
||||||
|
self.eval_expr_tail(expr, env, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Evaluate an expression, with tail position tracking for TCO
|
||||||
|
fn eval_expr_tail(&mut self, expr: &Expr, env: &Env, tail: bool) -> Result<EvalResult, RuntimeError> {
|
||||||
match expr {
|
match expr {
|
||||||
Expr::Literal(lit) => Ok(EvalResult::Value(self.eval_literal(lit))),
|
Expr::Literal(lit) => Ok(EvalResult::Value(self.eval_literal(lit))),
|
||||||
|
|
||||||
@@ -930,7 +950,16 @@ impl Interpreter {
|
|||||||
.map(|a| self.eval_expr(a, env))
|
.map(|a| self.eval_expr(a, env))
|
||||||
.collect::<Result<_, _>>()?;
|
.collect::<Result<_, _>>()?;
|
||||||
|
|
||||||
self.eval_call(func_val, arg_vals, *span)
|
// If we're in tail position, return TailCall for trampoline
|
||||||
|
if tail {
|
||||||
|
Ok(EvalResult::TailCall {
|
||||||
|
func: func_val,
|
||||||
|
args: arg_vals,
|
||||||
|
span: *span,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
self.eval_call(func_val, arg_vals, *span)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Expr::EffectOp {
|
Expr::EffectOp {
|
||||||
@@ -1013,7 +1042,8 @@ impl Interpreter {
|
|||||||
let val = self.eval_expr(value, env)?;
|
let val = self.eval_expr(value, env)?;
|
||||||
let new_env = env.extend();
|
let new_env = env.extend();
|
||||||
new_env.define(&name.name, val);
|
new_env.define(&name.name, val);
|
||||||
self.eval_expr_inner(body, &new_env)
|
// Body of let is in tail position if the let itself is
|
||||||
|
self.eval_expr_tail(body, &new_env, tail)
|
||||||
}
|
}
|
||||||
|
|
||||||
Expr::If {
|
Expr::If {
|
||||||
@@ -1024,8 +1054,9 @@ impl Interpreter {
|
|||||||
} => {
|
} => {
|
||||||
let cond_val = self.eval_expr(condition, env)?;
|
let cond_val = self.eval_expr(condition, env)?;
|
||||||
match cond_val {
|
match cond_val {
|
||||||
Value::Bool(true) => self.eval_expr_inner(then_branch, env),
|
// Branches are in tail position if the if itself is
|
||||||
Value::Bool(false) => self.eval_expr_inner(else_branch, env),
|
Value::Bool(true) => self.eval_expr_tail(then_branch, env, tail),
|
||||||
|
Value::Bool(false) => self.eval_expr_tail(else_branch, env, tail),
|
||||||
_ => Err(RuntimeError {
|
_ => Err(RuntimeError {
|
||||||
message: format!("If condition must be Bool, got {}", cond_val.type_name()),
|
message: format!("If condition must be Bool, got {}", cond_val.type_name()),
|
||||||
span: Some(*span),
|
span: Some(*span),
|
||||||
@@ -1039,7 +1070,8 @@ impl Interpreter {
|
|||||||
span,
|
span,
|
||||||
} => {
|
} => {
|
||||||
let val = self.eval_expr(scrutinee, env)?;
|
let val = self.eval_expr(scrutinee, env)?;
|
||||||
self.eval_match(val, arms, env, *span)
|
// Match arms are in tail position if the match itself is
|
||||||
|
self.eval_match(val, arms, env, *span, tail)
|
||||||
}
|
}
|
||||||
|
|
||||||
Expr::Block {
|
Expr::Block {
|
||||||
@@ -1057,7 +1089,8 @@ impl Interpreter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
self.eval_expr_inner(result, &block_env)
|
// Block result is in tail position if the block itself is
|
||||||
|
self.eval_expr_tail(result, &block_env, tail)
|
||||||
}
|
}
|
||||||
|
|
||||||
Expr::Record { fields, .. } => {
|
Expr::Record { fields, .. } => {
|
||||||
@@ -1229,14 +1262,7 @@ impl Interpreter {
|
|||||||
},
|
},
|
||||||
BinaryOp::Pipe => {
|
BinaryOp::Pipe => {
|
||||||
// a |> f means f(a)
|
// a |> f means f(a)
|
||||||
self.eval_call(right, vec![left], span)
|
self.eval_call_to_value(right, vec![left], span)
|
||||||
.and_then(|r| match r {
|
|
||||||
EvalResult::Value(v) => Ok(v),
|
|
||||||
EvalResult::Effect(_) => Err(RuntimeError {
|
|
||||||
message: "Effect in pipe expression".to_string(),
|
|
||||||
span: Some(span),
|
|
||||||
}),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1285,7 +1311,8 @@ impl Interpreter {
|
|||||||
call_env.define(param, arg);
|
call_env.define(param, arg);
|
||||||
}
|
}
|
||||||
|
|
||||||
self.eval_expr_inner(&closure.body, &call_env)
|
// Evaluate body in tail position for TCO
|
||||||
|
self.eval_expr_tail(&closure.body, &call_env, true)
|
||||||
}
|
}
|
||||||
Value::Constructor { name, fields } => {
|
Value::Constructor { name, fields } => {
|
||||||
// Constructor application
|
// Constructor application
|
||||||
@@ -1304,6 +1331,31 @@ impl Interpreter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Fully evaluate a call, handling any tail calls via trampoline.
|
||||||
|
/// Used by builtins that need to call user functions and get a value back.
|
||||||
|
fn eval_call_to_value(
|
||||||
|
&mut self,
|
||||||
|
func: Value,
|
||||||
|
args: Vec<Value>,
|
||||||
|
span: Span,
|
||||||
|
) -> Result<Value, RuntimeError> {
|
||||||
|
let mut result = self.eval_call(func, args, span)?;
|
||||||
|
loop {
|
||||||
|
match result {
|
||||||
|
EvalResult::Value(v) => return Ok(v),
|
||||||
|
EvalResult::Effect(_) => {
|
||||||
|
return Err(RuntimeError {
|
||||||
|
message: "Effect in callback not supported".to_string(),
|
||||||
|
span: Some(span),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
EvalResult::TailCall { func, args, span } => {
|
||||||
|
result = self.eval_call(func, args, span)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn eval_builtin(
|
fn eval_builtin(
|
||||||
&mut self,
|
&mut self,
|
||||||
builtin: BuiltinFn,
|
builtin: BuiltinFn,
|
||||||
@@ -1322,10 +1374,8 @@ impl Interpreter {
|
|||||||
Self::expect_args_2::<Vec<Value>, Value>(&args, "List.map", span)?;
|
Self::expect_args_2::<Vec<Value>, Value>(&args, "List.map", span)?;
|
||||||
let mut result = Vec::with_capacity(list.len());
|
let mut result = Vec::with_capacity(list.len());
|
||||||
for item in list {
|
for item in list {
|
||||||
match self.eval_call(func.clone(), vec![item], span)? {
|
let v = self.eval_call_to_value(func.clone(), vec![item], span)?;
|
||||||
EvalResult::Value(v) => result.push(v),
|
result.push(v);
|
||||||
EvalResult::Effect(_) => return Err(err("Effect in List.map callback")),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
Ok(EvalResult::Value(Value::List(result)))
|
Ok(EvalResult::Value(Value::List(result)))
|
||||||
}
|
}
|
||||||
@@ -1335,16 +1385,16 @@ impl Interpreter {
|
|||||||
Self::expect_args_2::<Vec<Value>, Value>(&args, "List.filter", span)?;
|
Self::expect_args_2::<Vec<Value>, Value>(&args, "List.filter", span)?;
|
||||||
let mut result = Vec::new();
|
let mut result = Vec::new();
|
||||||
for item in list {
|
for item in list {
|
||||||
match self.eval_call(func.clone(), vec![item.clone()], span)? {
|
let v = self.eval_call_to_value(func.clone(), vec![item.clone()], span)?;
|
||||||
EvalResult::Value(Value::Bool(true)) => result.push(item),
|
match v {
|
||||||
EvalResult::Value(Value::Bool(false)) => {}
|
Value::Bool(true) => result.push(item),
|
||||||
EvalResult::Value(v) => {
|
Value::Bool(false) => {}
|
||||||
|
_ => {
|
||||||
return Err(err(&format!(
|
return Err(err(&format!(
|
||||||
"List.filter predicate must return Bool, got {}",
|
"List.filter predicate must return Bool, got {}",
|
||||||
v.type_name()
|
v.type_name()
|
||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
EvalResult::Effect(_) => return Err(err("Effect in List.filter callback")),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(EvalResult::Value(Value::List(result)))
|
Ok(EvalResult::Value(Value::List(result)))
|
||||||
@@ -1370,10 +1420,7 @@ impl Interpreter {
|
|||||||
let func = args[2].clone();
|
let func = args[2].clone();
|
||||||
|
|
||||||
for item in list {
|
for item in list {
|
||||||
match self.eval_call(func.clone(), vec![acc, item], span)? {
|
acc = self.eval_call_to_value(func.clone(), vec![acc, item], span)?;
|
||||||
EvalResult::Value(v) => acc = v,
|
|
||||||
EvalResult::Effect(_) => return Err(err("Effect in List.fold callback")),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
Ok(EvalResult::Value(acc))
|
Ok(EvalResult::Value(acc))
|
||||||
}
|
}
|
||||||
@@ -1538,13 +1585,11 @@ impl Interpreter {
|
|||||||
let (opt, func) = Self::expect_args_2::<Value, Value>(&args, "Option.map", span)?;
|
let (opt, func) = Self::expect_args_2::<Value, Value>(&args, "Option.map", span)?;
|
||||||
match opt {
|
match opt {
|
||||||
Value::Constructor { name, fields } if name == "Some" && !fields.is_empty() => {
|
Value::Constructor { name, fields } if name == "Some" && !fields.is_empty() => {
|
||||||
match self.eval_call(func, vec![fields[0].clone()], span)? {
|
let v = self.eval_call_to_value(func, vec![fields[0].clone()], span)?;
|
||||||
EvalResult::Value(v) => Ok(EvalResult::Value(Value::Constructor {
|
Ok(EvalResult::Value(Value::Constructor {
|
||||||
name: "Some".to_string(),
|
name: "Some".to_string(),
|
||||||
fields: vec![v],
|
fields: vec![v],
|
||||||
})),
|
}))
|
||||||
EvalResult::Effect(_) => Err(err("Effect in Option.map callback")),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
Value::Constructor { name, .. } if name == "None" => {
|
Value::Constructor { name, .. } if name == "None" => {
|
||||||
Ok(EvalResult::Value(Value::Constructor {
|
Ok(EvalResult::Value(Value::Constructor {
|
||||||
@@ -1564,10 +1609,8 @@ impl Interpreter {
|
|||||||
Self::expect_args_2::<Value, Value>(&args, "Option.flatMap", span)?;
|
Self::expect_args_2::<Value, Value>(&args, "Option.flatMap", span)?;
|
||||||
match opt {
|
match opt {
|
||||||
Value::Constructor { name, fields } if name == "Some" && !fields.is_empty() => {
|
Value::Constructor { name, fields } if name == "Some" && !fields.is_empty() => {
|
||||||
match self.eval_call(func, vec![fields[0].clone()], span)? {
|
let v = self.eval_call_to_value(func, vec![fields[0].clone()], span)?;
|
||||||
EvalResult::Value(v) => Ok(EvalResult::Value(v)),
|
Ok(EvalResult::Value(v))
|
||||||
EvalResult::Effect(_) => Err(err("Effect in Option.flatMap callback")),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
Value::Constructor { name, .. } if name == "None" => {
|
Value::Constructor { name, .. } if name == "None" => {
|
||||||
Ok(EvalResult::Value(Value::Constructor {
|
Ok(EvalResult::Value(Value::Constructor {
|
||||||
@@ -1636,13 +1679,11 @@ impl Interpreter {
|
|||||||
let (res, func) = Self::expect_args_2::<Value, Value>(&args, "Result.map", span)?;
|
let (res, func) = Self::expect_args_2::<Value, Value>(&args, "Result.map", span)?;
|
||||||
match res {
|
match res {
|
||||||
Value::Constructor { name, fields } if name == "Ok" && !fields.is_empty() => {
|
Value::Constructor { name, fields } if name == "Ok" && !fields.is_empty() => {
|
||||||
match self.eval_call(func, vec![fields[0].clone()], span)? {
|
let v = self.eval_call_to_value(func, vec![fields[0].clone()], span)?;
|
||||||
EvalResult::Value(v) => Ok(EvalResult::Value(Value::Constructor {
|
Ok(EvalResult::Value(Value::Constructor {
|
||||||
name: "Ok".to_string(),
|
name: "Ok".to_string(),
|
||||||
fields: vec![v],
|
fields: vec![v],
|
||||||
})),
|
}))
|
||||||
EvalResult::Effect(_) => Err(err("Effect in Result.map callback")),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
Value::Constructor { name, fields } if name == "Err" => {
|
Value::Constructor { name, fields } if name == "Err" => {
|
||||||
Ok(EvalResult::Value(Value::Constructor {
|
Ok(EvalResult::Value(Value::Constructor {
|
||||||
@@ -1662,10 +1703,8 @@ impl Interpreter {
|
|||||||
Self::expect_args_2::<Value, Value>(&args, "Result.flatMap", span)?;
|
Self::expect_args_2::<Value, Value>(&args, "Result.flatMap", span)?;
|
||||||
match res {
|
match res {
|
||||||
Value::Constructor { name, fields } if name == "Ok" && !fields.is_empty() => {
|
Value::Constructor { name, fields } if name == "Ok" && !fields.is_empty() => {
|
||||||
match self.eval_call(func, vec![fields[0].clone()], span)? {
|
let v = self.eval_call_to_value(func, vec![fields[0].clone()], span)?;
|
||||||
EvalResult::Value(v) => Ok(EvalResult::Value(v)),
|
Ok(EvalResult::Value(v))
|
||||||
EvalResult::Effect(_) => Err(err("Effect in Result.flatMap callback")),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
Value::Constructor { name, fields } if name == "Err" => {
|
Value::Constructor { name, fields } if name == "Err" => {
|
||||||
Ok(EvalResult::Value(Value::Constructor {
|
Ok(EvalResult::Value(Value::Constructor {
|
||||||
@@ -1804,6 +1843,7 @@ impl Interpreter {
|
|||||||
arms: &[MatchArm],
|
arms: &[MatchArm],
|
||||||
env: &Env,
|
env: &Env,
|
||||||
span: Span,
|
span: Span,
|
||||||
|
tail: bool,
|
||||||
) -> Result<EvalResult, RuntimeError> {
|
) -> Result<EvalResult, RuntimeError> {
|
||||||
for arm in arms {
|
for arm in arms {
|
||||||
if let Some(bindings) = self.match_pattern(&arm.pattern, &val) {
|
if let Some(bindings) = self.match_pattern(&arm.pattern, &val) {
|
||||||
@@ -1827,7 +1867,8 @@ impl Interpreter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return self.eval_expr_inner(&arm.body, &match_env);
|
// Match arm body is in tail position if the match itself is
|
||||||
|
return self.eval_expr_tail(&arm.body, &match_env, tail);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
36
src/main.rs
36
src/main.rs
@@ -1282,5 +1282,41 @@ c")"#;
|
|||||||
let result = eval(source);
|
let result = eval(source);
|
||||||
assert!(result.is_ok(), "Expected success but got: {:?}", result);
|
assert!(result.is_ok(), "Expected success but got: {:?}", result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tail_call_optimization() {
|
||||||
|
// This test verifies that tail-recursive functions don't overflow the stack.
|
||||||
|
// Without TCO, a countdown from 10000 would cause a stack overflow.
|
||||||
|
let source = r#"
|
||||||
|
fn countdown(n: Int): Int = if n <= 0 then 0 else countdown(n - 1)
|
||||||
|
let result = countdown(10000)
|
||||||
|
"#;
|
||||||
|
assert_eq!(eval(source).unwrap(), "0");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tail_call_with_accumulator() {
|
||||||
|
// Test TCO with an accumulator pattern (common for tail-recursive sum)
|
||||||
|
let source = r#"
|
||||||
|
fn sum_to(n: Int, acc: Int): Int = if n <= 0 then acc else sum_to(n - 1, acc + n)
|
||||||
|
let result = sum_to(1000, 0)
|
||||||
|
"#;
|
||||||
|
// Sum from 1 to 1000 = 1000 * 1001 / 2 = 500500
|
||||||
|
assert_eq!(eval(source).unwrap(), "500500");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tail_call_in_match() {
|
||||||
|
// Test that TCO works through match expressions
|
||||||
|
let source = r#"
|
||||||
|
fn process(opt: Option<Int>, acc: Int): Int = match opt {
|
||||||
|
Some(n) => if n <= 0 then acc else process(Some(n - 1), acc + n),
|
||||||
|
None => acc
|
||||||
|
}
|
||||||
|
let result = process(Some(100), 0)
|
||||||
|
"#;
|
||||||
|
// Sum from 1 to 100 = 5050
|
||||||
|
assert_eq!(eval(source).unwrap(), "5050");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user