feat: implement generic type parameters

Add full support for user-defined generic types and functions:

- Add type_params field to TypeChecker to track type parameters in scope
- Update resolve_type() to resolve type parameters to their bound variables
- Update function_type() to bind type parameters and return polymorphic TypeScheme
- Update type declaration handling for generic ADTs (e.g., Pair<A, B>)

Generic functions and types now work:
  fn identity<T>(x: T): T = x
  type Pair<A, B> = | MkPair(A, B)
  fn first<A, B>(p: Pair<A, B>): A = ...

Add examples/generics.lux demonstrating:
- Generic identity function
- Generic Pair type with first/second accessors
- Generic mapOption function

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-02-13 09:20:36 -05:00
parent 15a820a467
commit 3734a17e5c
2 changed files with 139 additions and 11 deletions

61
examples/generics.lux Normal file
View File

@@ -0,0 +1,61 @@
// Demonstrating generic type parameters in Lux
//
// Expected output:
// identity(42) = 42
// identity("hello") = hello
// first(MkPair(1, "one")) = 1
// second(MkPair(1, "one")) = one
// map(Some(21), double) = Some(42)
// Generic identity function
fn identity<T>(x: T): T = x
// Generic pair type
type Pair<A, B> =
| MkPair(A, B)
fn first<A, B>(p: Pair<A, B>): A =
match p {
MkPair(a, _) => a
}
fn second<A, B>(p: Pair<A, B>): B =
match p {
MkPair(_, b) => b
}
// Generic map function for Option
fn mapOption<T, U>(opt: Option<T>, f: fn(T): U): Option<U> =
match opt {
None => None,
Some(x) => Some(f(x))
}
// Helper function for testing
fn double(x: Int): Int = x * 2
// Test usage
let id_int = identity(42)
let id_str = identity("hello")
let pair = MkPair(1, "one")
let fst = first(pair)
let snd = second(pair)
let doubled = mapOption(Some(21), double)
fn showOption(opt: Option<Int>): String =
match opt {
None => "None",
Some(x) => "Some(" + toString(x) + ")"
}
fn printResults(): Unit with {Console} = {
Console.print("identity(42) = " + toString(id_int))
Console.print("identity(\"hello\") = " + id_str)
Console.print("first(MkPair(1, \"one\")) = " + toString(fst))
Console.print("second(MkPair(1, \"one\")) = " + snd)
Console.print("map(Some(21), double) = " + showOption(doubled))
}
let output = run printResults() with {}

View File

@@ -2,6 +2,8 @@
#![allow(dead_code, unused_variables)] #![allow(dead_code, unused_variables)]
use std::collections::HashMap;
use crate::ast::{ use crate::ast::{
self, BinaryOp, Declaration, EffectDecl, Expr, FunctionDecl, HandlerDecl, Ident, ImplDecl, self, BinaryOp, Declaration, EffectDecl, Expr, FunctionDecl, HandlerDecl, Ident, ImplDecl,
ImportDecl, LetDecl, Literal, LiteralKind, MatchArm, Parameter, Pattern, Program, Span, ImportDecl, LetDecl, Literal, LiteralKind, MatchArm, Parameter, Pattern, Program, Span,
@@ -107,6 +109,8 @@ pub struct TypeChecker {
/// Whether we're inferring effects (no explicit declaration) /// Whether we're inferring effects (no explicit declaration)
inferring_effects: bool, inferring_effects: bool,
errors: Vec<TypeError>, errors: Vec<TypeError>,
/// Type parameters in scope (maps "T" -> Type::Var(n) for generics)
type_params: HashMap<String, Type>,
} }
impl TypeChecker { impl TypeChecker {
@@ -117,6 +121,7 @@ impl TypeChecker {
inferred_effects: EffectSet::empty(), inferred_effects: EffectSet::empty(),
inferring_effects: false, inferring_effects: false,
errors: Vec::new(), errors: Vec::new(),
type_params: HashMap::new(),
} }
} }
@@ -234,8 +239,8 @@ impl TypeChecker {
fn collect_declaration(&mut self, decl: &Declaration) { fn collect_declaration(&mut self, decl: &Declaration) {
match decl { match decl {
Declaration::Function(func) => { Declaration::Function(func) => {
let typ = self.function_type(func); let scheme = self.function_type(func);
self.env.bind(&func.name.name, TypeScheme::mono(typ)); self.env.bind(&func.name.name, scheme);
} }
Declaration::Effect(effect) => { Declaration::Effect(effect) => {
let effect_def = self.effect_def(effect); let effect_def = self.effect_def(effect);
@@ -244,17 +249,44 @@ impl TypeChecker {
.insert(effect.name.name.clone(), effect_def); .insert(effect.name.name.clone(), effect_def);
} }
Declaration::Type(type_decl) => { Declaration::Type(type_decl) => {
// Save old type params for this scope
let old_params = std::mem::take(&mut self.type_params);
// Bind type parameters to fresh type variables
let mut bound_vars = Vec::new();
for param in &type_decl.type_params {
let var = Type::var();
if let Type::Var(n) = &var {
bound_vars.push(*n);
}
self.type_params.insert(param.name.clone(), var);
}
// Build the parameterized return type
let base_type = if type_decl.type_params.is_empty() {
Type::Named(type_decl.name.name.clone())
} else {
Type::App {
constructor: Box::new(Type::Named(type_decl.name.name.clone())),
args: type_decl
.type_params
.iter()
.map(|p| self.type_params.get(&p.name).unwrap().clone())
.collect(),
}
};
// Register the type definition
let type_def = self.type_def(type_decl); let type_def = self.type_def(type_decl);
self.env.types.insert(type_decl.name.name.clone(), type_def.clone()); self.env.types.insert(type_decl.name.name.clone(), type_def.clone());
// Register ADT constructors as values in the type environment // Register ADT constructors as values with polymorphic types
if let ast::TypeDef::Enum(variants) = &type_decl.definition { if let ast::TypeDef::Enum(variants) = &type_decl.definition {
let type_name = Type::Named(type_decl.name.name.clone());
for variant in variants { for variant in variants {
let constructor_type = match &variant.fields { let constructor_type = match &variant.fields {
VariantFields::Unit => { VariantFields::Unit => {
// Unit variant is just the type itself // Unit variant is just the type itself
type_name.clone() base_type.clone()
} }
VariantFields::Tuple(field_types) => { VariantFields::Tuple(field_types) => {
// Tuple variant is a function from fields to the type // Tuple variant is a function from fields to the type
@@ -262,7 +294,7 @@ impl TypeChecker {
.iter() .iter()
.map(|t| self.resolve_type(t)) .map(|t| self.resolve_type(t))
.collect(); .collect();
Type::function(param_types, type_name.clone()) Type::function(param_types, base_type.clone())
} }
VariantFields::Record(fields) => { VariantFields::Record(fields) => {
// Record variant is a function from record to the type // Record variant is a function from record to the type
@@ -270,12 +302,20 @@ impl TypeChecker {
.iter() .iter()
.map(|f| (f.name.name.clone(), self.resolve_type(&f.typ))) .map(|f| (f.name.name.clone(), self.resolve_type(&f.typ)))
.collect(); .collect();
Type::function(vec![Type::Record(field_types)], type_name.clone()) Type::function(vec![Type::Record(field_types)], base_type.clone())
} }
}; };
self.env.bind(&variant.name.name, TypeScheme::mono(constructor_type)); // Wrap in polymorphic TypeScheme for generic types
let scheme = TypeScheme {
vars: bound_vars.clone(),
typ: constructor_type,
};
self.env.bind(&variant.name.name, scheme);
} }
} }
// Restore old type params
self.type_params = old_params;
} }
Declaration::Handler(handler) => { Declaration::Handler(handler) => {
let handler_def = self.handler_def(handler); let handler_def = self.handler_def(handler);
@@ -1396,7 +1436,21 @@ impl TypeChecker {
// Helper methods // Helper methods
fn function_type(&self, func: &FunctionDecl) -> Type { fn function_type(&mut self, func: &FunctionDecl) -> TypeScheme {
// Save old type params and start fresh for this function's scope
let old_params = std::mem::take(&mut self.type_params);
// Bind type parameters to fresh type variables
let mut bound_vars = Vec::new();
for param in &func.type_params {
let var = Type::var();
if let Type::Var(n) = &var {
bound_vars.push(*n);
}
self.type_params.insert(param.name.clone(), var);
}
// Resolve parameter and return types (will use type_params for generics)
let param_types: Vec<Type> = func let param_types: Vec<Type> = func
.params .params
.iter() .iter()
@@ -1407,7 +1461,14 @@ impl TypeChecker {
let effects = EffectSet::from_iter(func.effects.iter().map(|e| e.name.clone())); let effects = EffectSet::from_iter(func.effects.iter().map(|e| e.name.clone()));
let properties = PropertySet::from_ast(&func.properties); let properties = PropertySet::from_ast(&func.properties);
Type::function_with_properties(param_types, return_type, effects, properties) // Restore old type params
self.type_params = old_params;
// Return polymorphic type scheme with bound variables
TypeScheme {
vars: bound_vars,
typ: Type::function_with_properties(param_types, return_type, effects, properties),
}
} }
fn effect_def(&self, effect: &EffectDecl) -> EffectDef { fn effect_def(&self, effect: &EffectDecl) -> EffectDef {
@@ -1643,7 +1704,13 @@ impl TypeChecker {
"Char" => Type::Char, "Char" => Type::Char,
"Unit" => Type::Unit, "Unit" => Type::Unit,
"_" => Type::var(), "_" => Type::var(),
name => Type::Named(name.to_string()), name => {
// Check if it's a type parameter in scope (for generics)
if let Some(var) = self.type_params.get(name) {
return var.clone();
}
Type::Named(name.to_string())
}
}, },
TypeExpr::App(constructor, args) => { TypeExpr::App(constructor, args) => {
let resolved_args: Vec<Type> = args.iter().map(|a| self.resolve_type(a)).collect(); let resolved_args: Vec<Type> = args.iter().map(|a| self.resolve_type(a)).collect();