From 3734a17e5cc563d4db79b1bb5f01ce3f4db7a351 Mon Sep 17 00:00:00 2001 From: Brandon Lucas Date: Fri, 13 Feb 2026 09:20:36 -0500 Subject: [PATCH] feat: implement generic type parameters Add full support for user-defined generic types and functions: - Add type_params field to TypeChecker to track type parameters in scope - Update resolve_type() to resolve type parameters to their bound variables - Update function_type() to bind type parameters and return polymorphic TypeScheme - Update type declaration handling for generic ADTs (e.g., Pair) Generic functions and types now work: fn identity(x: T): T = x type Pair = | MkPair(A, B) fn first(p: Pair): A = ... Add examples/generics.lux demonstrating: - Generic identity function - Generic Pair type with first/second accessors - Generic mapOption function Co-Authored-By: Claude Opus 4.5 --- examples/generics.lux | 61 +++++++++++++++++++++++++++++ src/typechecker.rs | 89 +++++++++++++++++++++++++++++++++++++------ 2 files changed, 139 insertions(+), 11 deletions(-) create mode 100644 examples/generics.lux diff --git a/examples/generics.lux b/examples/generics.lux new file mode 100644 index 0000000..13f32ff --- /dev/null +++ b/examples/generics.lux @@ -0,0 +1,61 @@ +// Demonstrating generic type parameters in Lux +// +// Expected output: +// identity(42) = 42 +// identity("hello") = hello +// first(MkPair(1, "one")) = 1 +// second(MkPair(1, "one")) = one +// map(Some(21), double) = Some(42) + +// Generic identity function +fn identity(x: T): T = x + +// Generic pair type +type Pair = + | MkPair(A, B) + +fn first(p: Pair): A = + match p { + MkPair(a, _) => a + } + +fn second(p: Pair): B = + match p { + MkPair(_, b) => b + } + +// Generic map function for Option +fn mapOption(opt: Option, f: fn(T): U): Option = + match opt { + None => None, + Some(x) => Some(f(x)) + } + +// Helper function for testing +fn double(x: Int): Int = x * 2 + +// Test usage +let id_int = identity(42) +let id_str = identity("hello") + +let pair = MkPair(1, "one") +let fst = first(pair) +let snd = second(pair) + +let doubled = mapOption(Some(21), double) + +fn showOption(opt: Option): String = + match opt { + None => "None", + Some(x) => "Some(" + toString(x) + ")" + } + +fn printResults(): Unit with {Console} = { + Console.print("identity(42) = " + toString(id_int)) + Console.print("identity(\"hello\") = " + id_str) + Console.print("first(MkPair(1, \"one\")) = " + toString(fst)) + Console.print("second(MkPair(1, \"one\")) = " + snd) + Console.print("map(Some(21), double) = " + showOption(doubled)) +} + +let output = run printResults() with {} diff --git a/src/typechecker.rs b/src/typechecker.rs index e13e786..0975b74 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -2,6 +2,8 @@ #![allow(dead_code, unused_variables)] +use std::collections::HashMap; + use crate::ast::{ self, BinaryOp, Declaration, EffectDecl, Expr, FunctionDecl, HandlerDecl, Ident, ImplDecl, ImportDecl, LetDecl, Literal, LiteralKind, MatchArm, Parameter, Pattern, Program, Span, @@ -107,6 +109,8 @@ pub struct TypeChecker { /// Whether we're inferring effects (no explicit declaration) inferring_effects: bool, errors: Vec, + /// Type parameters in scope (maps "T" -> Type::Var(n) for generics) + type_params: HashMap, } impl TypeChecker { @@ -117,6 +121,7 @@ impl TypeChecker { inferred_effects: EffectSet::empty(), inferring_effects: false, errors: Vec::new(), + type_params: HashMap::new(), } } @@ -234,8 +239,8 @@ impl TypeChecker { fn collect_declaration(&mut self, decl: &Declaration) { match decl { Declaration::Function(func) => { - let typ = self.function_type(func); - self.env.bind(&func.name.name, TypeScheme::mono(typ)); + let scheme = self.function_type(func); + self.env.bind(&func.name.name, scheme); } Declaration::Effect(effect) => { let effect_def = self.effect_def(effect); @@ -244,17 +249,44 @@ impl TypeChecker { .insert(effect.name.name.clone(), effect_def); } Declaration::Type(type_decl) => { + // Save old type params for this scope + let old_params = std::mem::take(&mut self.type_params); + + // Bind type parameters to fresh type variables + let mut bound_vars = Vec::new(); + for param in &type_decl.type_params { + let var = Type::var(); + if let Type::Var(n) = &var { + bound_vars.push(*n); + } + self.type_params.insert(param.name.clone(), var); + } + + // Build the parameterized return type + let base_type = if type_decl.type_params.is_empty() { + Type::Named(type_decl.name.name.clone()) + } else { + Type::App { + constructor: Box::new(Type::Named(type_decl.name.name.clone())), + args: type_decl + .type_params + .iter() + .map(|p| self.type_params.get(&p.name).unwrap().clone()) + .collect(), + } + }; + + // Register the type definition let type_def = self.type_def(type_decl); self.env.types.insert(type_decl.name.name.clone(), type_def.clone()); - // Register ADT constructors as values in the type environment + // Register ADT constructors as values with polymorphic types if let ast::TypeDef::Enum(variants) = &type_decl.definition { - let type_name = Type::Named(type_decl.name.name.clone()); for variant in variants { let constructor_type = match &variant.fields { VariantFields::Unit => { // Unit variant is just the type itself - type_name.clone() + base_type.clone() } VariantFields::Tuple(field_types) => { // Tuple variant is a function from fields to the type @@ -262,7 +294,7 @@ impl TypeChecker { .iter() .map(|t| self.resolve_type(t)) .collect(); - Type::function(param_types, type_name.clone()) + Type::function(param_types, base_type.clone()) } VariantFields::Record(fields) => { // Record variant is a function from record to the type @@ -270,12 +302,20 @@ impl TypeChecker { .iter() .map(|f| (f.name.name.clone(), self.resolve_type(&f.typ))) .collect(); - Type::function(vec![Type::Record(field_types)], type_name.clone()) + Type::function(vec![Type::Record(field_types)], base_type.clone()) } }; - self.env.bind(&variant.name.name, TypeScheme::mono(constructor_type)); + // Wrap in polymorphic TypeScheme for generic types + let scheme = TypeScheme { + vars: bound_vars.clone(), + typ: constructor_type, + }; + self.env.bind(&variant.name.name, scheme); } } + + // Restore old type params + self.type_params = old_params; } Declaration::Handler(handler) => { let handler_def = self.handler_def(handler); @@ -1396,7 +1436,21 @@ impl TypeChecker { // Helper methods - fn function_type(&self, func: &FunctionDecl) -> Type { + fn function_type(&mut self, func: &FunctionDecl) -> TypeScheme { + // Save old type params and start fresh for this function's scope + let old_params = std::mem::take(&mut self.type_params); + + // Bind type parameters to fresh type variables + let mut bound_vars = Vec::new(); + for param in &func.type_params { + let var = Type::var(); + if let Type::Var(n) = &var { + bound_vars.push(*n); + } + self.type_params.insert(param.name.clone(), var); + } + + // Resolve parameter and return types (will use type_params for generics) let param_types: Vec = func .params .iter() @@ -1407,7 +1461,14 @@ impl TypeChecker { let effects = EffectSet::from_iter(func.effects.iter().map(|e| e.name.clone())); let properties = PropertySet::from_ast(&func.properties); - Type::function_with_properties(param_types, return_type, effects, properties) + // Restore old type params + self.type_params = old_params; + + // Return polymorphic type scheme with bound variables + TypeScheme { + vars: bound_vars, + typ: Type::function_with_properties(param_types, return_type, effects, properties), + } } fn effect_def(&self, effect: &EffectDecl) -> EffectDef { @@ -1643,7 +1704,13 @@ impl TypeChecker { "Char" => Type::Char, "Unit" => Type::Unit, "_" => Type::var(), - name => Type::Named(name.to_string()), + name => { + // Check if it's a type parameter in scope (for generics) + if let Some(var) = self.type_params.get(name) { + return var.clone(); + } + Type::Named(name.to_string()) + } }, TypeExpr::App(constructor, args) => { let resolved_args: Vec = args.iter().map(|a| self.resolve_type(a)).collect();