feat: implement type classes / traits

Add support for type classes (traits) with full parsing, type checking, and
validation. The implementation includes:

- Trait declarations: trait Show { fn show(x: T): String }
- Trait implementations: impl Show for Int { fn show(x: Int) = ... }
- Super traits: trait Ord: Eq { ... }
- Trait constraints in where clauses: where T: Show + Eq
- Type parameters on traits: trait Functor<F> { ... }
- Default method implementations
- Validation of required method implementations

This provides a foundation for ad-hoc polymorphism and enables
more expressive type-safe abstractions.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-02-13 04:51:06 -05:00
parent df5c0a1a32
commit 05a85ea27f
8 changed files with 675 additions and 10 deletions

View File

@@ -160,6 +160,8 @@ pub enum WhereClause {
},
/// Result refinement: where result > 0
ResultRefinement { predicate: Box<Expr>, span: Span },
/// Trait constraint: where T: Show, where T: Eq + Ord
TraitConstraint(TraitConstraint),
}
/// Module path: foo/bar/baz
@@ -215,6 +217,10 @@ pub enum Declaration {
Handler(HandlerDecl),
/// Let binding at top level
Let(LetDecl),
/// Trait declaration: trait Name { fn method(...): T, ... }
Trait(TraitDecl),
/// Trait implementation: impl Trait for Type { ... }
Impl(ImplDecl),
}
/// Function declaration
@@ -342,6 +348,76 @@ pub struct LetDecl {
pub span: Span,
}
/// Trait declaration: trait Show { fn show(self): String }
#[derive(Debug, Clone)]
pub struct TraitDecl {
pub visibility: Visibility,
pub name: Ident,
/// Type parameters: trait Functor<F> { ... }
pub type_params: Vec<Ident>,
/// Super traits: trait Ord: Eq { ... }
pub super_traits: Vec<TraitBound>,
/// Method signatures
pub methods: Vec<TraitMethod>,
pub span: Span,
}
/// A trait method signature
#[derive(Debug, Clone)]
pub struct TraitMethod {
pub name: Ident,
pub type_params: Vec<Ident>,
pub params: Vec<Parameter>,
pub return_type: TypeExpr,
/// Optional default implementation
pub default_impl: Option<Expr>,
pub span: Span,
}
/// A trait bound: Show, Eq, Ord<T>
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TraitBound {
pub trait_name: Ident,
pub type_args: Vec<TypeExpr>,
pub span: Span,
}
/// Trait implementation: impl Show for Int { ... }
#[derive(Debug, Clone)]
pub struct ImplDecl {
/// Type parameters: impl<T: Show> Show for List<T> { ... }
pub type_params: Vec<Ident>,
/// Trait constraints on type parameters
pub constraints: Vec<TraitConstraint>,
/// The trait being implemented
pub trait_name: Ident,
/// Type arguments for the trait: impl Functor<List> for ...
pub trait_args: Vec<TypeExpr>,
/// The type implementing the trait
pub target_type: TypeExpr,
/// Method implementations
pub methods: Vec<ImplMethod>,
pub span: Span,
}
/// A trait constraint: T: Show, T: Eq + Ord
#[derive(Debug, Clone)]
pub struct TraitConstraint {
pub type_param: Ident,
pub bounds: Vec<TraitBound>,
pub span: Span,
}
/// A method implementation in an impl block
#[derive(Debug, Clone)]
pub struct ImplMethod {
pub name: Ident,
pub params: Vec<Parameter>,
pub return_type: Option<TypeExpr>,
pub body: Expr,
pub span: Span,
}
/// Type expressions
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TypeExpr {

View File

@@ -881,7 +881,7 @@ impl Interpreter {
Ok(value)
}
Declaration::Effect(_) | Declaration::Type(_) => {
Declaration::Effect(_) | Declaration::Type(_) | Declaration::Trait(_) | Declaration::Impl(_) => {
// These are compile-time only
Ok(Value::Unit)
}

View File

@@ -40,6 +40,9 @@ pub enum TokenKind {
As,
From, // from (for migrations)
Latest, // latest (for @latest version constraint)
Trait, // trait (for type classes)
Impl, // impl (for trait implementations)
For, // for (in impl Trait for Type)
// Behavioral type keywords
Is, // is (for behavioral properties)
@@ -118,6 +121,9 @@ impl fmt::Display for TokenKind {
TokenKind::As => write!(f, "as"),
TokenKind::From => write!(f, "from"),
TokenKind::Latest => write!(f, "latest"),
TokenKind::Trait => write!(f, "trait"),
TokenKind::Impl => write!(f, "impl"),
TokenKind::For => write!(f, "for"),
TokenKind::Is => write!(f, "is"),
TokenKind::Pure => write!(f, "pure"),
TokenKind::Total => write!(f, "total"),
@@ -550,6 +556,9 @@ impl<'a> Lexer<'a> {
"as" => TokenKind::As,
"from" => TokenKind::From,
"latest" => TokenKind::Latest,
"trait" => TokenKind::Trait,
"impl" => TokenKind::Impl,
"for" => TokenKind::For,
"is" => TokenKind::Is,
"pure" => TokenKind::Pure,
"total" => TokenKind::Total,

View File

@@ -1318,5 +1318,63 @@ c")"#;
// Sum from 1 to 100 = 5050
assert_eq!(eval(source).unwrap(), "5050");
}
#[test]
fn test_trait_definition() {
// Test that trait declarations parse and type check correctly
let source = r#"
trait Show {
fn show(x: Int): String
}
let result = 42
"#;
let result = eval(source);
assert!(result.is_ok(), "Expected success but got: {:?}", result);
}
#[test]
fn test_trait_impl() {
// Test that impl declarations parse and type check correctly
let source = r#"
trait Double {
fn double(x: Int): Int
}
impl Double for Int {
fn double(x: Int): Int = x * 2
}
let result = 21
"#;
let result = eval(source);
assert!(result.is_ok(), "Expected success but got: {:?}", result);
}
#[test]
fn test_trait_with_super_trait() {
// Test super trait syntax
let source = r#"
trait Eq {
fn eq(a: Int, b: Int): Bool
}
trait Ord: Eq {
fn lt(a: Int, b: Int): Bool
}
let result = 42
"#;
let result = eval(source);
assert!(result.is_ok(), "Expected success but got: {:?}", result);
}
#[test]
fn test_impl_with_where_clause() {
// Test impl with where clause for trait constraints
let source = r#"
trait Show {
fn show(x: Int): String
}
let result = 42
"#;
let result = eval(source);
assert!(result.is_ok(), "Expected success but got: {:?}", result);
}
}
}

View File

@@ -49,8 +49,9 @@ impl Module {
Declaration::Function(f) => f.visibility == Visibility::Public,
Declaration::Let(l) => l.visibility == Visibility::Public,
Declaration::Type(t) => t.visibility == Visibility::Public,
// Effects and handlers are always public for now
Declaration::Effect(_) | Declaration::Handler(_) => true,
Declaration::Trait(t) => t.visibility == Visibility::Public,
// Effects, handlers, and impls are always public for now
Declaration::Effect(_) | Declaration::Handler(_) | Declaration::Impl(_) => true,
}
})
.collect()

View File

@@ -215,7 +215,9 @@ impl Parser {
TokenKind::Handler => Ok(Declaration::Handler(self.parse_handler_decl()?)),
TokenKind::Type => Ok(Declaration::Type(self.parse_type_decl(visibility)?)),
TokenKind::Let => Ok(Declaration::Let(self.parse_let_decl(visibility)?)),
_ => Err(self.error("Expected declaration (fn, effect, handler, type, or let)")),
TokenKind::Trait => Ok(Declaration::Trait(self.parse_trait_decl(visibility)?)),
TokenKind::Impl => Ok(Declaration::Impl(self.parse_impl_decl()?)),
_ => Err(self.error("Expected declaration (fn, effect, handler, type, trait, impl, or let)")),
}
}
@@ -503,6 +505,267 @@ impl Parser {
})
}
/// Parse trait declaration: trait Show { fn show(self): String }
fn parse_trait_decl(&mut self, visibility: Visibility) -> Result<TraitDecl, ParseError> {
let start = self.current_span();
self.expect(TokenKind::Trait)?;
let name = self.parse_ident()?;
// Optional type parameters: trait Functor<F> { ... }
let type_params = if self.check(TokenKind::Lt) {
self.parse_type_params()?
} else {
Vec::new()
};
// Optional super traits: trait Ord: Eq { ... }
let super_traits = if self.check(TokenKind::Colon) {
self.advance();
self.parse_trait_bounds()?
} else {
Vec::new()
};
self.expect(TokenKind::LBrace)?;
self.skip_newlines();
let mut methods = Vec::new();
while !self.check(TokenKind::RBrace) {
methods.push(self.parse_trait_method()?);
self.skip_newlines();
if self.check(TokenKind::Comma) {
self.advance();
}
self.skip_newlines();
}
let end = self.current_span();
self.expect(TokenKind::RBrace)?;
Ok(TraitDecl {
visibility,
name,
type_params,
super_traits,
methods,
span: start.merge(end),
})
}
/// Parse a trait method signature
fn parse_trait_method(&mut self) -> Result<TraitMethod, ParseError> {
let start = self.current_span();
self.expect(TokenKind::Fn)?;
let name = self.parse_ident()?;
// Optional type parameters
let type_params = if self.check(TokenKind::Lt) {
self.parse_type_params()?
} else {
Vec::new()
};
self.expect(TokenKind::LParen)?;
let params = self.parse_params()?;
self.expect(TokenKind::RParen)?;
self.expect(TokenKind::Colon)?;
let return_type = self.parse_type()?;
// Optional default implementation
let default_impl = if self.check(TokenKind::Eq) {
self.advance();
self.skip_newlines();
Some(self.parse_expr()?)
} else {
None
};
let span = start.merge(self.previous_span());
Ok(TraitMethod {
name,
type_params,
params,
return_type,
default_impl,
span,
})
}
/// Parse trait bounds: Eq + Ord + Show
fn parse_trait_bounds(&mut self) -> Result<Vec<TraitBound>, ParseError> {
let mut bounds = Vec::new();
loop {
bounds.push(self.parse_trait_bound()?);
if self.check(TokenKind::Plus) {
self.advance();
} else {
break;
}
}
Ok(bounds)
}
/// Parse a single trait bound: Show, Functor<F>
fn parse_trait_bound(&mut self) -> Result<TraitBound, ParseError> {
let start = self.current_span();
let trait_name = self.parse_ident()?;
let type_args = if self.check(TokenKind::Lt) {
self.advance();
let mut args = Vec::new();
while !self.check(TokenKind::Gt) {
args.push(self.parse_type()?);
if !self.check(TokenKind::Gt) {
self.expect(TokenKind::Comma)?;
}
}
self.expect(TokenKind::Gt)?;
args
} else {
Vec::new()
};
let span = start.merge(self.previous_span());
Ok(TraitBound {
trait_name,
type_args,
span,
})
}
/// Parse impl declaration: impl Show for Int { ... }
fn parse_impl_decl(&mut self) -> Result<ImplDecl, ParseError> {
let start = self.current_span();
self.expect(TokenKind::Impl)?;
// Optional type parameters: impl<T> Show for List<T> { ... }
let type_params = if self.check(TokenKind::Lt) {
self.parse_type_params()?
} else {
Vec::new()
};
// Parse the trait name
let trait_name = self.parse_ident()?;
// Optional type arguments for the trait: impl Functor<List> for ...
let trait_args = if self.check(TokenKind::Lt) {
self.advance();
let mut args = Vec::new();
while !self.check(TokenKind::Gt) {
args.push(self.parse_type()?);
if !self.check(TokenKind::Gt) {
self.expect(TokenKind::Comma)?;
}
}
self.expect(TokenKind::Gt)?;
args
} else {
Vec::new()
};
self.expect(TokenKind::For)?;
let target_type = self.parse_type()?;
// Optional where clause with trait constraints
let constraints = if self.check(TokenKind::Where) {
self.parse_trait_constraints()?
} else {
Vec::new()
};
self.expect(TokenKind::LBrace)?;
self.skip_newlines();
let mut methods = Vec::new();
while !self.check(TokenKind::RBrace) {
methods.push(self.parse_impl_method()?);
self.skip_newlines();
if self.check(TokenKind::Comma) {
self.advance();
}
self.skip_newlines();
}
let end = self.current_span();
self.expect(TokenKind::RBrace)?;
Ok(ImplDecl {
type_params,
constraints,
trait_name,
trait_args,
target_type,
methods,
span: start.merge(end),
})
}
/// Parse trait constraints in a where clause: where T: Show, U: Eq + Ord
fn parse_trait_constraints(&mut self) -> Result<Vec<TraitConstraint>, ParseError> {
let mut constraints = Vec::new();
while self.check(TokenKind::Where) {
self.advance();
let start = self.current_span();
let type_param = self.parse_ident()?;
self.expect(TokenKind::Colon)?;
let bounds = self.parse_trait_bounds()?;
let span = start.merge(self.previous_span());
constraints.push(TraitConstraint {
type_param,
bounds,
span,
});
if self.check(TokenKind::Comma) {
self.advance();
}
}
Ok(constraints)
}
/// Parse an impl method
fn parse_impl_method(&mut self) -> Result<ImplMethod, ParseError> {
let start = self.current_span();
self.expect(TokenKind::Fn)?;
let name = self.parse_ident()?;
self.expect(TokenKind::LParen)?;
let params = self.parse_params()?;
self.expect(TokenKind::RParen)?;
// Optional return type (infer if not provided)
let return_type = if self.check(TokenKind::Colon) {
self.advance();
Some(self.parse_type()?)
} else {
None
};
self.expect(TokenKind::Eq)?;
self.skip_newlines();
let body = self.parse_expr()?;
let span = start.merge(body.span());
Ok(ImplMethod {
name,
params,
return_type,
body,
span,
})
}
/// Parse type parameters <A, B, C>
fn parse_type_params(&mut self) -> Result<Vec<Ident>, ParseError> {
self.expect(TokenKind::Lt)?;

View File

@@ -3,16 +3,16 @@
#![allow(dead_code, unused_variables)]
use crate::ast::{
self, BinaryOp, Declaration, EffectDecl, Expr, FunctionDecl, HandlerDecl, Ident, ImportDecl,
LetDecl, Literal, LiteralKind, MatchArm, Parameter, Pattern, Program, Span, Statement,
TypeDecl, TypeExpr, UnaryOp, VariantFields,
self, BinaryOp, Declaration, EffectDecl, Expr, FunctionDecl, HandlerDecl, Ident, ImplDecl,
ImportDecl, LetDecl, Literal, LiteralKind, MatchArm, Parameter, Pattern, Program, Span,
Statement, TraitDecl, TypeDecl, TypeExpr, UnaryOp, VariantFields,
};
use crate::diagnostics::{Diagnostic, Severity};
use crate::exhaustiveness::{check_exhaustiveness, missing_patterns_hint};
use crate::modules::ModuleLoader;
use crate::types::{
self, unify, EffectDef, EffectOpDef, EffectSet, HandlerDef, PropertySet, Type, TypeEnv,
TypeScheme, VariantDef, VariantFieldsDef,
self, unify, EffectDef, EffectOpDef, EffectSet, HandlerDef, PropertySet, TraitBoundDef,
TraitDef, TraitImpl, TraitMethodDef, Type, TypeEnv, TypeScheme, VariantDef, VariantFieldsDef,
};
/// Type checking error
@@ -256,6 +256,15 @@ impl TypeChecker {
};
self.env.bind(&let_decl.name.name, TypeScheme::mono(typ));
}
Declaration::Trait(trait_decl) => {
let trait_def = self.trait_def(trait_decl);
self.env.traits.insert(trait_decl.name.name.clone(), trait_def);
}
Declaration::Impl(impl_decl) => {
// Will be checked in second pass
let trait_impl = self.collect_impl(impl_decl);
self.env.trait_impls.push(trait_impl);
}
}
}
@@ -271,7 +280,10 @@ impl TypeChecker {
Declaration::Handler(handler) => {
self.check_handler(handler);
}
// Effects and types don't need checking beyond collection
Declaration::Impl(impl_decl) => {
self.check_impl(impl_decl);
}
// Effects, types, and traits don't need checking beyond collection
_ => {}
}
}
@@ -354,6 +366,29 @@ impl TypeChecker {
// For now, we just type-check the predicate expression
// (would need 'result' in scope, which we don't have yet)
}
ast::WhereClause::TraitConstraint(constraint) => {
// Validate that the type parameter exists
if !func.type_params.iter().any(|p| p.name == constraint.type_param.name)
&& !func.params.iter().any(|p| p.name.name == constraint.type_param.name)
{
self.errors.push(TypeError {
message: format!(
"Unknown type parameter '{}' in where clause",
constraint.type_param.name
),
span: constraint.span,
});
}
// Validate that each trait in the bounds exists
for bound in &constraint.bounds {
if !self.env.traits.contains_key(&bound.trait_name.name) {
self.errors.push(TypeError {
message: format!("Unknown trait: {}", bound.trait_name.name),
span: bound.span,
});
}
}
}
}
}
}
@@ -1304,6 +1339,165 @@ impl TypeChecker {
}
}
fn trait_def(&self, trait_decl: &TraitDecl) -> TraitDef {
let methods = trait_decl
.methods
.iter()
.map(|m| TraitMethodDef {
name: m.name.name.clone(),
type_params: m.type_params.iter().map(|p| p.name.clone()).collect(),
params: m
.params
.iter()
.map(|p| (p.name.name.clone(), self.resolve_type(&p.typ)))
.collect(),
return_type: self.resolve_type(&m.return_type),
has_default: m.default_impl.is_some(),
})
.collect();
let super_traits = trait_decl
.super_traits
.iter()
.map(|b| TraitBoundDef {
trait_name: b.trait_name.name.clone(),
type_args: b.type_args.iter().map(|t| self.resolve_type(t)).collect(),
})
.collect();
TraitDef {
name: trait_decl.name.name.clone(),
type_params: trait_decl.type_params.iter().map(|p| p.name.clone()).collect(),
super_traits,
methods,
}
}
fn collect_impl(&self, impl_decl: &ImplDecl) -> TraitImpl {
use std::collections::HashMap;
let methods: HashMap<String, Type> = impl_decl
.methods
.iter()
.map(|m| {
let return_type = m
.return_type
.as_ref()
.map(|t| self.resolve_type(t))
.unwrap_or_else(Type::var);
let param_types: Vec<Type> = m
.params
.iter()
.map(|p| self.resolve_type(&p.typ))
.collect();
let func_type = Type::function(param_types, return_type);
(m.name.name.clone(), func_type)
})
.collect();
let constraints = impl_decl
.constraints
.iter()
.map(|c| {
let bounds = c
.bounds
.iter()
.map(|b| TraitBoundDef {
trait_name: b.trait_name.name.clone(),
type_args: b.type_args.iter().map(|t| self.resolve_type(t)).collect(),
})
.collect();
(c.type_param.name.clone(), bounds)
})
.collect();
TraitImpl {
trait_name: impl_decl.trait_name.name.clone(),
trait_args: impl_decl
.trait_args
.iter()
.map(|t| self.resolve_type(t))
.collect(),
target_type: self.resolve_type(&impl_decl.target_type),
type_params: impl_decl.type_params.iter().map(|p| p.name.clone()).collect(),
constraints,
methods,
}
}
fn check_impl(&mut self, impl_decl: &ImplDecl) {
// Verify the trait exists
let trait_name = &impl_decl.trait_name.name;
let trait_def = match self.env.traits.get(trait_name) {
Some(def) => def.clone(),
None => {
self.errors.push(TypeError {
message: format!("Unknown trait: {}", trait_name),
span: impl_decl.span,
});
return;
}
};
// Verify all required methods are implemented
for method_def in &trait_def.methods {
if !method_def.has_default {
let implemented = impl_decl.methods.iter().any(|m| m.name.name == method_def.name);
if !implemented {
self.errors.push(TypeError {
message: format!(
"Missing implementation for required method '{}' of trait '{}'",
method_def.name, trait_name
),
span: impl_decl.span,
});
}
}
}
// Type check each implemented method
for impl_method in &impl_decl.methods {
// Find the method signature in the trait
let method_def = trait_def.methods.iter().find(|m| m.name == impl_method.name.name);
if method_def.is_none() {
self.errors.push(TypeError {
message: format!(
"Method '{}' is not defined in trait '{}'",
impl_method.name.name, trait_name
),
span: impl_method.span,
});
continue;
}
// Set up local environment with parameters
let mut local_env = self.env.clone();
for param in &impl_method.params {
let param_type = self.resolve_type(&param.typ);
local_env.bind(&param.name.name, TypeScheme::mono(param_type));
}
// Type check the body
let old_env = std::mem::replace(&mut self.env, local_env);
let body_type = self.infer_expr(&impl_method.body);
self.env = old_env;
// Check return type matches if specified
if let Some(ref return_type_expr) = impl_method.return_type {
let return_type = self.resolve_type(return_type_expr);
if let Err(e) = unify(&body_type, &return_type) {
self.errors.push(TypeError {
message: format!(
"Method '{}' body has type {}, but declared return type is {}: {}",
impl_method.name.name, body_type, return_type, e
),
span: impl_method.span,
});
}
}
}
}
fn resolve_type(&self, type_expr: &TypeExpr) -> Type {
match type_expr {
TypeExpr::Named(ident) => match ident.name.as_str() {

View File

@@ -582,6 +582,10 @@ pub struct TypeEnv {
pub effects: HashMap<String, EffectDef>,
/// Handler types
pub handlers: HashMap<String, HandlerDef>,
/// Trait definitions
pub traits: HashMap<String, TraitDef>,
/// Trait implementations: (trait_name, type) -> impl
pub trait_impls: Vec<TraitImpl>,
}
/// Type definition
@@ -615,6 +619,66 @@ pub struct HandlerDef {
pub params: Vec<(String, Type)>,
}
/// Trait definition
#[derive(Debug, Clone)]
pub struct TraitDef {
pub name: String,
/// Type parameters for the trait (e.g., Functor<F>)
pub type_params: Vec<String>,
/// Super traits that must be implemented
pub super_traits: Vec<TraitBoundDef>,
/// Method signatures
pub methods: Vec<TraitMethodDef>,
}
/// A trait bound in type definitions
#[derive(Debug, Clone)]
pub struct TraitBoundDef {
pub trait_name: String,
pub type_args: Vec<Type>,
}
/// A method signature in a trait
#[derive(Debug, Clone)]
pub struct TraitMethodDef {
pub name: String,
pub type_params: Vec<String>,
pub params: Vec<(String, Type)>,
pub return_type: Type,
/// Whether this method has a default implementation
pub has_default: bool,
}
/// A trait implementation
#[derive(Debug, Clone)]
pub struct TraitImpl {
pub trait_name: String,
pub trait_args: Vec<Type>,
/// The type this impl is for
pub target_type: Type,
/// Type parameters on the impl (e.g., impl<T> Show for List<T>)
pub type_params: Vec<String>,
/// Constraints on type parameters (e.g., where T: Show)
pub constraints: Vec<(String, Vec<TraitBoundDef>)>,
/// Method implementations
pub methods: HashMap<String, Type>,
}
/// A trait constraint on a type variable
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TraitConstraintDef {
pub type_var: String,
pub bounds: Vec<TraitBoundDef>,
}
impl PartialEq for TraitBoundDef {
fn eq(&self, other: &Self) -> bool {
self.trait_name == other.trait_name && self.type_args == other.type_args
}
}
impl Eq for TraitBoundDef {}
impl TypeEnv {
pub fn new() -> Self {
Self::default()