feat: implement generic type parameters
Add full support for user-defined generic types and functions: - Add type_params field to TypeChecker to track type parameters in scope - Update resolve_type() to resolve type parameters to their bound variables - Update function_type() to bind type parameters and return polymorphic TypeScheme - Update type declaration handling for generic ADTs (e.g., Pair<A, B>) Generic functions and types now work: fn identity<T>(x: T): T = x type Pair<A, B> = | MkPair(A, B) fn first<A, B>(p: Pair<A, B>): A = ... Add examples/generics.lux demonstrating: - Generic identity function - Generic Pair type with first/second accessors - Generic mapOption function Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
61
examples/generics.lux
Normal file
61
examples/generics.lux
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
// Demonstrating generic type parameters in Lux
|
||||||
|
//
|
||||||
|
// Expected output:
|
||||||
|
// identity(42) = 42
|
||||||
|
// identity("hello") = hello
|
||||||
|
// first(MkPair(1, "one")) = 1
|
||||||
|
// second(MkPair(1, "one")) = one
|
||||||
|
// map(Some(21), double) = Some(42)
|
||||||
|
|
||||||
|
// Generic identity function
|
||||||
|
fn identity<T>(x: T): T = x
|
||||||
|
|
||||||
|
// Generic pair type
|
||||||
|
type Pair<A, B> =
|
||||||
|
| MkPair(A, B)
|
||||||
|
|
||||||
|
fn first<A, B>(p: Pair<A, B>): A =
|
||||||
|
match p {
|
||||||
|
MkPair(a, _) => a
|
||||||
|
}
|
||||||
|
|
||||||
|
fn second<A, B>(p: Pair<A, B>): B =
|
||||||
|
match p {
|
||||||
|
MkPair(_, b) => b
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generic map function for Option
|
||||||
|
fn mapOption<T, U>(opt: Option<T>, f: fn(T): U): Option<U> =
|
||||||
|
match opt {
|
||||||
|
None => None,
|
||||||
|
Some(x) => Some(f(x))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function for testing
|
||||||
|
fn double(x: Int): Int = x * 2
|
||||||
|
|
||||||
|
// Test usage
|
||||||
|
let id_int = identity(42)
|
||||||
|
let id_str = identity("hello")
|
||||||
|
|
||||||
|
let pair = MkPair(1, "one")
|
||||||
|
let fst = first(pair)
|
||||||
|
let snd = second(pair)
|
||||||
|
|
||||||
|
let doubled = mapOption(Some(21), double)
|
||||||
|
|
||||||
|
fn showOption(opt: Option<Int>): String =
|
||||||
|
match opt {
|
||||||
|
None => "None",
|
||||||
|
Some(x) => "Some(" + toString(x) + ")"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn printResults(): Unit with {Console} = {
|
||||||
|
Console.print("identity(42) = " + toString(id_int))
|
||||||
|
Console.print("identity(\"hello\") = " + id_str)
|
||||||
|
Console.print("first(MkPair(1, \"one\")) = " + toString(fst))
|
||||||
|
Console.print("second(MkPair(1, \"one\")) = " + snd)
|
||||||
|
Console.print("map(Some(21), double) = " + showOption(doubled))
|
||||||
|
}
|
||||||
|
|
||||||
|
let output = run printResults() with {}
|
||||||
@@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
#![allow(dead_code, unused_variables)]
|
#![allow(dead_code, unused_variables)]
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use crate::ast::{
|
use crate::ast::{
|
||||||
self, BinaryOp, Declaration, EffectDecl, Expr, FunctionDecl, HandlerDecl, Ident, ImplDecl,
|
self, BinaryOp, Declaration, EffectDecl, Expr, FunctionDecl, HandlerDecl, Ident, ImplDecl,
|
||||||
ImportDecl, LetDecl, Literal, LiteralKind, MatchArm, Parameter, Pattern, Program, Span,
|
ImportDecl, LetDecl, Literal, LiteralKind, MatchArm, Parameter, Pattern, Program, Span,
|
||||||
@@ -107,6 +109,8 @@ pub struct TypeChecker {
|
|||||||
/// Whether we're inferring effects (no explicit declaration)
|
/// Whether we're inferring effects (no explicit declaration)
|
||||||
inferring_effects: bool,
|
inferring_effects: bool,
|
||||||
errors: Vec<TypeError>,
|
errors: Vec<TypeError>,
|
||||||
|
/// Type parameters in scope (maps "T" -> Type::Var(n) for generics)
|
||||||
|
type_params: HashMap<String, Type>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TypeChecker {
|
impl TypeChecker {
|
||||||
@@ -117,6 +121,7 @@ impl TypeChecker {
|
|||||||
inferred_effects: EffectSet::empty(),
|
inferred_effects: EffectSet::empty(),
|
||||||
inferring_effects: false,
|
inferring_effects: false,
|
||||||
errors: Vec::new(),
|
errors: Vec::new(),
|
||||||
|
type_params: HashMap::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -234,8 +239,8 @@ impl TypeChecker {
|
|||||||
fn collect_declaration(&mut self, decl: &Declaration) {
|
fn collect_declaration(&mut self, decl: &Declaration) {
|
||||||
match decl {
|
match decl {
|
||||||
Declaration::Function(func) => {
|
Declaration::Function(func) => {
|
||||||
let typ = self.function_type(func);
|
let scheme = self.function_type(func);
|
||||||
self.env.bind(&func.name.name, TypeScheme::mono(typ));
|
self.env.bind(&func.name.name, scheme);
|
||||||
}
|
}
|
||||||
Declaration::Effect(effect) => {
|
Declaration::Effect(effect) => {
|
||||||
let effect_def = self.effect_def(effect);
|
let effect_def = self.effect_def(effect);
|
||||||
@@ -244,17 +249,44 @@ impl TypeChecker {
|
|||||||
.insert(effect.name.name.clone(), effect_def);
|
.insert(effect.name.name.clone(), effect_def);
|
||||||
}
|
}
|
||||||
Declaration::Type(type_decl) => {
|
Declaration::Type(type_decl) => {
|
||||||
|
// Save old type params for this scope
|
||||||
|
let old_params = std::mem::take(&mut self.type_params);
|
||||||
|
|
||||||
|
// Bind type parameters to fresh type variables
|
||||||
|
let mut bound_vars = Vec::new();
|
||||||
|
for param in &type_decl.type_params {
|
||||||
|
let var = Type::var();
|
||||||
|
if let Type::Var(n) = &var {
|
||||||
|
bound_vars.push(*n);
|
||||||
|
}
|
||||||
|
self.type_params.insert(param.name.clone(), var);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the parameterized return type
|
||||||
|
let base_type = if type_decl.type_params.is_empty() {
|
||||||
|
Type::Named(type_decl.name.name.clone())
|
||||||
|
} else {
|
||||||
|
Type::App {
|
||||||
|
constructor: Box::new(Type::Named(type_decl.name.name.clone())),
|
||||||
|
args: type_decl
|
||||||
|
.type_params
|
||||||
|
.iter()
|
||||||
|
.map(|p| self.type_params.get(&p.name).unwrap().clone())
|
||||||
|
.collect(),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Register the type definition
|
||||||
let type_def = self.type_def(type_decl);
|
let type_def = self.type_def(type_decl);
|
||||||
self.env.types.insert(type_decl.name.name.clone(), type_def.clone());
|
self.env.types.insert(type_decl.name.name.clone(), type_def.clone());
|
||||||
|
|
||||||
// Register ADT constructors as values in the type environment
|
// Register ADT constructors as values with polymorphic types
|
||||||
if let ast::TypeDef::Enum(variants) = &type_decl.definition {
|
if let ast::TypeDef::Enum(variants) = &type_decl.definition {
|
||||||
let type_name = Type::Named(type_decl.name.name.clone());
|
|
||||||
for variant in variants {
|
for variant in variants {
|
||||||
let constructor_type = match &variant.fields {
|
let constructor_type = match &variant.fields {
|
||||||
VariantFields::Unit => {
|
VariantFields::Unit => {
|
||||||
// Unit variant is just the type itself
|
// Unit variant is just the type itself
|
||||||
type_name.clone()
|
base_type.clone()
|
||||||
}
|
}
|
||||||
VariantFields::Tuple(field_types) => {
|
VariantFields::Tuple(field_types) => {
|
||||||
// Tuple variant is a function from fields to the type
|
// Tuple variant is a function from fields to the type
|
||||||
@@ -262,7 +294,7 @@ impl TypeChecker {
|
|||||||
.iter()
|
.iter()
|
||||||
.map(|t| self.resolve_type(t))
|
.map(|t| self.resolve_type(t))
|
||||||
.collect();
|
.collect();
|
||||||
Type::function(param_types, type_name.clone())
|
Type::function(param_types, base_type.clone())
|
||||||
}
|
}
|
||||||
VariantFields::Record(fields) => {
|
VariantFields::Record(fields) => {
|
||||||
// Record variant is a function from record to the type
|
// Record variant is a function from record to the type
|
||||||
@@ -270,12 +302,20 @@ impl TypeChecker {
|
|||||||
.iter()
|
.iter()
|
||||||
.map(|f| (f.name.name.clone(), self.resolve_type(&f.typ)))
|
.map(|f| (f.name.name.clone(), self.resolve_type(&f.typ)))
|
||||||
.collect();
|
.collect();
|
||||||
Type::function(vec![Type::Record(field_types)], type_name.clone())
|
Type::function(vec![Type::Record(field_types)], base_type.clone())
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
self.env.bind(&variant.name.name, TypeScheme::mono(constructor_type));
|
// Wrap in polymorphic TypeScheme for generic types
|
||||||
|
let scheme = TypeScheme {
|
||||||
|
vars: bound_vars.clone(),
|
||||||
|
typ: constructor_type,
|
||||||
|
};
|
||||||
|
self.env.bind(&variant.name.name, scheme);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Restore old type params
|
||||||
|
self.type_params = old_params;
|
||||||
}
|
}
|
||||||
Declaration::Handler(handler) => {
|
Declaration::Handler(handler) => {
|
||||||
let handler_def = self.handler_def(handler);
|
let handler_def = self.handler_def(handler);
|
||||||
@@ -1396,7 +1436,21 @@ impl TypeChecker {
|
|||||||
|
|
||||||
// Helper methods
|
// Helper methods
|
||||||
|
|
||||||
fn function_type(&self, func: &FunctionDecl) -> Type {
|
fn function_type(&mut self, func: &FunctionDecl) -> TypeScheme {
|
||||||
|
// Save old type params and start fresh for this function's scope
|
||||||
|
let old_params = std::mem::take(&mut self.type_params);
|
||||||
|
|
||||||
|
// Bind type parameters to fresh type variables
|
||||||
|
let mut bound_vars = Vec::new();
|
||||||
|
for param in &func.type_params {
|
||||||
|
let var = Type::var();
|
||||||
|
if let Type::Var(n) = &var {
|
||||||
|
bound_vars.push(*n);
|
||||||
|
}
|
||||||
|
self.type_params.insert(param.name.clone(), var);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve parameter and return types (will use type_params for generics)
|
||||||
let param_types: Vec<Type> = func
|
let param_types: Vec<Type> = func
|
||||||
.params
|
.params
|
||||||
.iter()
|
.iter()
|
||||||
@@ -1407,7 +1461,14 @@ impl TypeChecker {
|
|||||||
let effects = EffectSet::from_iter(func.effects.iter().map(|e| e.name.clone()));
|
let effects = EffectSet::from_iter(func.effects.iter().map(|e| e.name.clone()));
|
||||||
let properties = PropertySet::from_ast(&func.properties);
|
let properties = PropertySet::from_ast(&func.properties);
|
||||||
|
|
||||||
Type::function_with_properties(param_types, return_type, effects, properties)
|
// Restore old type params
|
||||||
|
self.type_params = old_params;
|
||||||
|
|
||||||
|
// Return polymorphic type scheme with bound variables
|
||||||
|
TypeScheme {
|
||||||
|
vars: bound_vars,
|
||||||
|
typ: Type::function_with_properties(param_types, return_type, effects, properties),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn effect_def(&self, effect: &EffectDecl) -> EffectDef {
|
fn effect_def(&self, effect: &EffectDecl) -> EffectDef {
|
||||||
@@ -1643,7 +1704,13 @@ impl TypeChecker {
|
|||||||
"Char" => Type::Char,
|
"Char" => Type::Char,
|
||||||
"Unit" => Type::Unit,
|
"Unit" => Type::Unit,
|
||||||
"_" => Type::var(),
|
"_" => Type::var(),
|
||||||
name => Type::Named(name.to_string()),
|
name => {
|
||||||
|
// Check if it's a type parameter in scope (for generics)
|
||||||
|
if let Some(var) = self.type_params.get(name) {
|
||||||
|
return var.clone();
|
||||||
|
}
|
||||||
|
Type::Named(name.to_string())
|
||||||
|
}
|
||||||
},
|
},
|
||||||
TypeExpr::App(constructor, args) => {
|
TypeExpr::App(constructor, args) => {
|
||||||
let resolved_args: Vec<Type> = args.iter().map(|a| self.resolve_type(a)).collect();
|
let resolved_args: Vec<Type> = args.iter().map(|a| self.resolve_type(a)).collect();
|
||||||
|
|||||||
Reference in New Issue
Block a user