diff --git a/src/ast.rs b/src/ast.rs index 475ca1d..63af786 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -160,6 +160,8 @@ pub enum WhereClause { }, /// Result refinement: where result > 0 ResultRefinement { predicate: Box, span: Span }, + /// Trait constraint: where T: Show, where T: Eq + Ord + TraitConstraint(TraitConstraint), } /// Module path: foo/bar/baz @@ -215,6 +217,10 @@ pub enum Declaration { Handler(HandlerDecl), /// Let binding at top level Let(LetDecl), + /// Trait declaration: trait Name { fn method(...): T, ... } + Trait(TraitDecl), + /// Trait implementation: impl Trait for Type { ... } + Impl(ImplDecl), } /// Function declaration @@ -342,6 +348,76 @@ pub struct LetDecl { pub span: Span, } +/// Trait declaration: trait Show { fn show(self): String } +#[derive(Debug, Clone)] +pub struct TraitDecl { + pub visibility: Visibility, + pub name: Ident, + /// Type parameters: trait Functor { ... } + pub type_params: Vec, + /// Super traits: trait Ord: Eq { ... } + pub super_traits: Vec, + /// Method signatures + pub methods: Vec, + pub span: Span, +} + +/// A trait method signature +#[derive(Debug, Clone)] +pub struct TraitMethod { + pub name: Ident, + pub type_params: Vec, + pub params: Vec, + pub return_type: TypeExpr, + /// Optional default implementation + pub default_impl: Option, + pub span: Span, +} + +/// A trait bound: Show, Eq, Ord +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TraitBound { + pub trait_name: Ident, + pub type_args: Vec, + pub span: Span, +} + +/// Trait implementation: impl Show for Int { ... } +#[derive(Debug, Clone)] +pub struct ImplDecl { + /// Type parameters: impl Show for List { ... } + pub type_params: Vec, + /// Trait constraints on type parameters + pub constraints: Vec, + /// The trait being implemented + pub trait_name: Ident, + /// Type arguments for the trait: impl Functor for ... + pub trait_args: Vec, + /// The type implementing the trait + pub target_type: TypeExpr, + /// Method implementations + pub methods: Vec, + pub span: Span, +} + +/// A trait constraint: T: Show, T: Eq + Ord +#[derive(Debug, Clone)] +pub struct TraitConstraint { + pub type_param: Ident, + pub bounds: Vec, + pub span: Span, +} + +/// A method implementation in an impl block +#[derive(Debug, Clone)] +pub struct ImplMethod { + pub name: Ident, + pub params: Vec, + pub return_type: Option, + pub body: Expr, + pub span: Span, +} + /// Type expressions #[derive(Debug, Clone, PartialEq, Eq)] pub enum TypeExpr { diff --git a/src/interpreter.rs b/src/interpreter.rs index 166c11f..607407e 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -881,7 +881,7 @@ impl Interpreter { Ok(value) } - Declaration::Effect(_) | Declaration::Type(_) => { + Declaration::Effect(_) | Declaration::Type(_) | Declaration::Trait(_) | Declaration::Impl(_) => { // These are compile-time only Ok(Value::Unit) } diff --git a/src/lexer.rs b/src/lexer.rs index 076aef1..d8ed7d4 100644 --- a/src/lexer.rs +++ b/src/lexer.rs @@ -40,6 +40,9 @@ pub enum TokenKind { As, From, // from (for migrations) Latest, // latest (for @latest version constraint) + Trait, // trait (for type classes) + Impl, // impl (for trait implementations) + For, // for (in impl Trait for Type) // Behavioral type keywords Is, // is (for behavioral properties) @@ -118,6 +121,9 @@ impl fmt::Display for TokenKind { TokenKind::As => write!(f, "as"), TokenKind::From => write!(f, "from"), TokenKind::Latest => write!(f, "latest"), + TokenKind::Trait => write!(f, "trait"), + TokenKind::Impl => write!(f, "impl"), + TokenKind::For => write!(f, "for"), TokenKind::Is => write!(f, "is"), TokenKind::Pure => write!(f, "pure"), TokenKind::Total => write!(f, "total"), @@ -550,6 +556,9 @@ impl<'a> Lexer<'a> { "as" => TokenKind::As, "from" => TokenKind::From, "latest" => TokenKind::Latest, + "trait" => TokenKind::Trait, + "impl" => TokenKind::Impl, + "for" => TokenKind::For, "is" => TokenKind::Is, "pure" => TokenKind::Pure, "total" => TokenKind::Total, diff --git a/src/main.rs b/src/main.rs index 06e6ab0..8f7ffdf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1318,5 +1318,63 @@ c")"#; // Sum from 1 to 100 = 5050 assert_eq!(eval(source).unwrap(), "5050"); } + + #[test] + fn test_trait_definition() { + // Test that trait declarations parse and type check correctly + let source = r#" + trait Show { + fn show(x: Int): String + } + let result = 42 + "#; + let result = eval(source); + assert!(result.is_ok(), "Expected success but got: {:?}", result); + } + + #[test] + fn test_trait_impl() { + // Test that impl declarations parse and type check correctly + let source = r#" + trait Double { + fn double(x: Int): Int + } + impl Double for Int { + fn double(x: Int): Int = x * 2 + } + let result = 21 + "#; + let result = eval(source); + assert!(result.is_ok(), "Expected success but got: {:?}", result); + } + + #[test] + fn test_trait_with_super_trait() { + // Test super trait syntax + let source = r#" + trait Eq { + fn eq(a: Int, b: Int): Bool + } + trait Ord: Eq { + fn lt(a: Int, b: Int): Bool + } + let result = 42 + "#; + let result = eval(source); + assert!(result.is_ok(), "Expected success but got: {:?}", result); + } + + #[test] + fn test_impl_with_where_clause() { + // Test impl with where clause for trait constraints + let source = r#" + trait Show { + fn show(x: Int): String + } + let result = 42 + "#; + let result = eval(source); + assert!(result.is_ok(), "Expected success but got: {:?}", result); + } } } diff --git a/src/modules.rs b/src/modules.rs index 9cc3d74..6c85d9e 100644 --- a/src/modules.rs +++ b/src/modules.rs @@ -49,8 +49,9 @@ impl Module { Declaration::Function(f) => f.visibility == Visibility::Public, Declaration::Let(l) => l.visibility == Visibility::Public, Declaration::Type(t) => t.visibility == Visibility::Public, - // Effects and handlers are always public for now - Declaration::Effect(_) | Declaration::Handler(_) => true, + Declaration::Trait(t) => t.visibility == Visibility::Public, + // Effects, handlers, and impls are always public for now + Declaration::Effect(_) | Declaration::Handler(_) | Declaration::Impl(_) => true, } }) .collect() diff --git a/src/parser.rs b/src/parser.rs index f5c0163..075b1a8 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -215,7 +215,9 @@ impl Parser { TokenKind::Handler => Ok(Declaration::Handler(self.parse_handler_decl()?)), TokenKind::Type => Ok(Declaration::Type(self.parse_type_decl(visibility)?)), TokenKind::Let => Ok(Declaration::Let(self.parse_let_decl(visibility)?)), - _ => Err(self.error("Expected declaration (fn, effect, handler, type, or let)")), + TokenKind::Trait => Ok(Declaration::Trait(self.parse_trait_decl(visibility)?)), + TokenKind::Impl => Ok(Declaration::Impl(self.parse_impl_decl()?)), + _ => Err(self.error("Expected declaration (fn, effect, handler, type, trait, impl, or let)")), } } @@ -503,6 +505,267 @@ impl Parser { }) } + /// Parse trait declaration: trait Show { fn show(self): String } + fn parse_trait_decl(&mut self, visibility: Visibility) -> Result { + let start = self.current_span(); + self.expect(TokenKind::Trait)?; + + let name = self.parse_ident()?; + + // Optional type parameters: trait Functor { ... } + let type_params = if self.check(TokenKind::Lt) { + self.parse_type_params()? + } else { + Vec::new() + }; + + // Optional super traits: trait Ord: Eq { ... } + let super_traits = if self.check(TokenKind::Colon) { + self.advance(); + self.parse_trait_bounds()? + } else { + Vec::new() + }; + + self.expect(TokenKind::LBrace)?; + self.skip_newlines(); + + let mut methods = Vec::new(); + while !self.check(TokenKind::RBrace) { + methods.push(self.parse_trait_method()?); + self.skip_newlines(); + if self.check(TokenKind::Comma) { + self.advance(); + } + self.skip_newlines(); + } + + let end = self.current_span(); + self.expect(TokenKind::RBrace)?; + + Ok(TraitDecl { + visibility, + name, + type_params, + super_traits, + methods, + span: start.merge(end), + }) + } + + /// Parse a trait method signature + fn parse_trait_method(&mut self) -> Result { + let start = self.current_span(); + self.expect(TokenKind::Fn)?; + + let name = self.parse_ident()?; + + // Optional type parameters + let type_params = if self.check(TokenKind::Lt) { + self.parse_type_params()? + } else { + Vec::new() + }; + + self.expect(TokenKind::LParen)?; + let params = self.parse_params()?; + self.expect(TokenKind::RParen)?; + + self.expect(TokenKind::Colon)?; + let return_type = self.parse_type()?; + + // Optional default implementation + let default_impl = if self.check(TokenKind::Eq) { + self.advance(); + self.skip_newlines(); + Some(self.parse_expr()?) + } else { + None + }; + + let span = start.merge(self.previous_span()); + Ok(TraitMethod { + name, + type_params, + params, + return_type, + default_impl, + span, + }) + } + + /// Parse trait bounds: Eq + Ord + Show + fn parse_trait_bounds(&mut self) -> Result, ParseError> { + let mut bounds = Vec::new(); + + loop { + bounds.push(self.parse_trait_bound()?); + + if self.check(TokenKind::Plus) { + self.advance(); + } else { + break; + } + } + + Ok(bounds) + } + + /// Parse a single trait bound: Show, Functor + fn parse_trait_bound(&mut self) -> Result { + let start = self.current_span(); + let trait_name = self.parse_ident()?; + + let type_args = if self.check(TokenKind::Lt) { + self.advance(); + let mut args = Vec::new(); + while !self.check(TokenKind::Gt) { + args.push(self.parse_type()?); + if !self.check(TokenKind::Gt) { + self.expect(TokenKind::Comma)?; + } + } + self.expect(TokenKind::Gt)?; + args + } else { + Vec::new() + }; + + let span = start.merge(self.previous_span()); + Ok(TraitBound { + trait_name, + type_args, + span, + }) + } + + /// Parse impl declaration: impl Show for Int { ... } + fn parse_impl_decl(&mut self) -> Result { + let start = self.current_span(); + self.expect(TokenKind::Impl)?; + + // Optional type parameters: impl Show for List { ... } + let type_params = if self.check(TokenKind::Lt) { + self.parse_type_params()? + } else { + Vec::new() + }; + + // Parse the trait name + let trait_name = self.parse_ident()?; + + // Optional type arguments for the trait: impl Functor for ... + let trait_args = if self.check(TokenKind::Lt) { + self.advance(); + let mut args = Vec::new(); + while !self.check(TokenKind::Gt) { + args.push(self.parse_type()?); + if !self.check(TokenKind::Gt) { + self.expect(TokenKind::Comma)?; + } + } + self.expect(TokenKind::Gt)?; + args + } else { + Vec::new() + }; + + self.expect(TokenKind::For)?; + let target_type = self.parse_type()?; + + // Optional where clause with trait constraints + let constraints = if self.check(TokenKind::Where) { + self.parse_trait_constraints()? + } else { + Vec::new() + }; + + self.expect(TokenKind::LBrace)?; + self.skip_newlines(); + + let mut methods = Vec::new(); + while !self.check(TokenKind::RBrace) { + methods.push(self.parse_impl_method()?); + self.skip_newlines(); + if self.check(TokenKind::Comma) { + self.advance(); + } + self.skip_newlines(); + } + + let end = self.current_span(); + self.expect(TokenKind::RBrace)?; + + Ok(ImplDecl { + type_params, + constraints, + trait_name, + trait_args, + target_type, + methods, + span: start.merge(end), + }) + } + + /// Parse trait constraints in a where clause: where T: Show, U: Eq + Ord + fn parse_trait_constraints(&mut self) -> Result, ParseError> { + let mut constraints = Vec::new(); + + while self.check(TokenKind::Where) { + self.advance(); + let start = self.current_span(); + let type_param = self.parse_ident()?; + self.expect(TokenKind::Colon)?; + let bounds = self.parse_trait_bounds()?; + let span = start.merge(self.previous_span()); + + constraints.push(TraitConstraint { + type_param, + bounds, + span, + }); + + if self.check(TokenKind::Comma) { + self.advance(); + } + } + + Ok(constraints) + } + + /// Parse an impl method + fn parse_impl_method(&mut self) -> Result { + let start = self.current_span(); + self.expect(TokenKind::Fn)?; + + let name = self.parse_ident()?; + + self.expect(TokenKind::LParen)?; + let params = self.parse_params()?; + self.expect(TokenKind::RParen)?; + + // Optional return type (infer if not provided) + let return_type = if self.check(TokenKind::Colon) { + self.advance(); + Some(self.parse_type()?) + } else { + None + }; + + self.expect(TokenKind::Eq)?; + self.skip_newlines(); + let body = self.parse_expr()?; + + let span = start.merge(body.span()); + Ok(ImplMethod { + name, + params, + return_type, + body, + span, + }) + } + /// Parse type parameters fn parse_type_params(&mut self) -> Result, ParseError> { self.expect(TokenKind::Lt)?; diff --git a/src/typechecker.rs b/src/typechecker.rs index 9c11769..0bd9023 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -3,16 +3,16 @@ #![allow(dead_code, unused_variables)] use crate::ast::{ - self, BinaryOp, Declaration, EffectDecl, Expr, FunctionDecl, HandlerDecl, Ident, ImportDecl, - LetDecl, Literal, LiteralKind, MatchArm, Parameter, Pattern, Program, Span, Statement, - TypeDecl, TypeExpr, UnaryOp, VariantFields, + self, BinaryOp, Declaration, EffectDecl, Expr, FunctionDecl, HandlerDecl, Ident, ImplDecl, + ImportDecl, LetDecl, Literal, LiteralKind, MatchArm, Parameter, Pattern, Program, Span, + Statement, TraitDecl, TypeDecl, TypeExpr, UnaryOp, VariantFields, }; use crate::diagnostics::{Diagnostic, Severity}; use crate::exhaustiveness::{check_exhaustiveness, missing_patterns_hint}; use crate::modules::ModuleLoader; use crate::types::{ - self, unify, EffectDef, EffectOpDef, EffectSet, HandlerDef, PropertySet, Type, TypeEnv, - TypeScheme, VariantDef, VariantFieldsDef, + self, unify, EffectDef, EffectOpDef, EffectSet, HandlerDef, PropertySet, TraitBoundDef, + TraitDef, TraitImpl, TraitMethodDef, Type, TypeEnv, TypeScheme, VariantDef, VariantFieldsDef, }; /// Type checking error @@ -256,6 +256,15 @@ impl TypeChecker { }; self.env.bind(&let_decl.name.name, TypeScheme::mono(typ)); } + Declaration::Trait(trait_decl) => { + let trait_def = self.trait_def(trait_decl); + self.env.traits.insert(trait_decl.name.name.clone(), trait_def); + } + Declaration::Impl(impl_decl) => { + // Will be checked in second pass + let trait_impl = self.collect_impl(impl_decl); + self.env.trait_impls.push(trait_impl); + } } } @@ -271,7 +280,10 @@ impl TypeChecker { Declaration::Handler(handler) => { self.check_handler(handler); } - // Effects and types don't need checking beyond collection + Declaration::Impl(impl_decl) => { + self.check_impl(impl_decl); + } + // Effects, types, and traits don't need checking beyond collection _ => {} } } @@ -354,6 +366,29 @@ impl TypeChecker { // For now, we just type-check the predicate expression // (would need 'result' in scope, which we don't have yet) } + ast::WhereClause::TraitConstraint(constraint) => { + // Validate that the type parameter exists + if !func.type_params.iter().any(|p| p.name == constraint.type_param.name) + && !func.params.iter().any(|p| p.name.name == constraint.type_param.name) + { + self.errors.push(TypeError { + message: format!( + "Unknown type parameter '{}' in where clause", + constraint.type_param.name + ), + span: constraint.span, + }); + } + // Validate that each trait in the bounds exists + for bound in &constraint.bounds { + if !self.env.traits.contains_key(&bound.trait_name.name) { + self.errors.push(TypeError { + message: format!("Unknown trait: {}", bound.trait_name.name), + span: bound.span, + }); + } + } + } } } } @@ -1304,6 +1339,165 @@ impl TypeChecker { } } + fn trait_def(&self, trait_decl: &TraitDecl) -> TraitDef { + let methods = trait_decl + .methods + .iter() + .map(|m| TraitMethodDef { + name: m.name.name.clone(), + type_params: m.type_params.iter().map(|p| p.name.clone()).collect(), + params: m + .params + .iter() + .map(|p| (p.name.name.clone(), self.resolve_type(&p.typ))) + .collect(), + return_type: self.resolve_type(&m.return_type), + has_default: m.default_impl.is_some(), + }) + .collect(); + + let super_traits = trait_decl + .super_traits + .iter() + .map(|b| TraitBoundDef { + trait_name: b.trait_name.name.clone(), + type_args: b.type_args.iter().map(|t| self.resolve_type(t)).collect(), + }) + .collect(); + + TraitDef { + name: trait_decl.name.name.clone(), + type_params: trait_decl.type_params.iter().map(|p| p.name.clone()).collect(), + super_traits, + methods, + } + } + + fn collect_impl(&self, impl_decl: &ImplDecl) -> TraitImpl { + use std::collections::HashMap; + + let methods: HashMap = impl_decl + .methods + .iter() + .map(|m| { + let return_type = m + .return_type + .as_ref() + .map(|t| self.resolve_type(t)) + .unwrap_or_else(Type::var); + let param_types: Vec = m + .params + .iter() + .map(|p| self.resolve_type(&p.typ)) + .collect(); + let func_type = Type::function(param_types, return_type); + (m.name.name.clone(), func_type) + }) + .collect(); + + let constraints = impl_decl + .constraints + .iter() + .map(|c| { + let bounds = c + .bounds + .iter() + .map(|b| TraitBoundDef { + trait_name: b.trait_name.name.clone(), + type_args: b.type_args.iter().map(|t| self.resolve_type(t)).collect(), + }) + .collect(); + (c.type_param.name.clone(), bounds) + }) + .collect(); + + TraitImpl { + trait_name: impl_decl.trait_name.name.clone(), + trait_args: impl_decl + .trait_args + .iter() + .map(|t| self.resolve_type(t)) + .collect(), + target_type: self.resolve_type(&impl_decl.target_type), + type_params: impl_decl.type_params.iter().map(|p| p.name.clone()).collect(), + constraints, + methods, + } + } + + fn check_impl(&mut self, impl_decl: &ImplDecl) { + // Verify the trait exists + let trait_name = &impl_decl.trait_name.name; + let trait_def = match self.env.traits.get(trait_name) { + Some(def) => def.clone(), + None => { + self.errors.push(TypeError { + message: format!("Unknown trait: {}", trait_name), + span: impl_decl.span, + }); + return; + } + }; + + // Verify all required methods are implemented + for method_def in &trait_def.methods { + if !method_def.has_default { + let implemented = impl_decl.methods.iter().any(|m| m.name.name == method_def.name); + if !implemented { + self.errors.push(TypeError { + message: format!( + "Missing implementation for required method '{}' of trait '{}'", + method_def.name, trait_name + ), + span: impl_decl.span, + }); + } + } + } + + // Type check each implemented method + for impl_method in &impl_decl.methods { + // Find the method signature in the trait + let method_def = trait_def.methods.iter().find(|m| m.name == impl_method.name.name); + if method_def.is_none() { + self.errors.push(TypeError { + message: format!( + "Method '{}' is not defined in trait '{}'", + impl_method.name.name, trait_name + ), + span: impl_method.span, + }); + continue; + } + + // Set up local environment with parameters + let mut local_env = self.env.clone(); + for param in &impl_method.params { + let param_type = self.resolve_type(¶m.typ); + local_env.bind(¶m.name.name, TypeScheme::mono(param_type)); + } + + // Type check the body + let old_env = std::mem::replace(&mut self.env, local_env); + let body_type = self.infer_expr(&impl_method.body); + self.env = old_env; + + // Check return type matches if specified + if let Some(ref return_type_expr) = impl_method.return_type { + let return_type = self.resolve_type(return_type_expr); + if let Err(e) = unify(&body_type, &return_type) { + self.errors.push(TypeError { + message: format!( + "Method '{}' body has type {}, but declared return type is {}: {}", + impl_method.name.name, body_type, return_type, e + ), + span: impl_method.span, + }); + } + } + } + } + fn resolve_type(&self, type_expr: &TypeExpr) -> Type { match type_expr { TypeExpr::Named(ident) => match ident.name.as_str() { diff --git a/src/types.rs b/src/types.rs index 2f13ec9..1ece77c 100644 --- a/src/types.rs +++ b/src/types.rs @@ -582,6 +582,10 @@ pub struct TypeEnv { pub effects: HashMap, /// Handler types pub handlers: HashMap, + /// Trait definitions + pub traits: HashMap, + /// Trait implementations: (trait_name, type) -> impl + pub trait_impls: Vec, } /// Type definition @@ -615,6 +619,66 @@ pub struct HandlerDef { pub params: Vec<(String, Type)>, } +/// Trait definition +#[derive(Debug, Clone)] +pub struct TraitDef { + pub name: String, + /// Type parameters for the trait (e.g., Functor) + pub type_params: Vec, + /// Super traits that must be implemented + pub super_traits: Vec, + /// Method signatures + pub methods: Vec, +} + +/// A trait bound in type definitions +#[derive(Debug, Clone)] +pub struct TraitBoundDef { + pub trait_name: String, + pub type_args: Vec, +} + +/// A method signature in a trait +#[derive(Debug, Clone)] +pub struct TraitMethodDef { + pub name: String, + pub type_params: Vec, + pub params: Vec<(String, Type)>, + pub return_type: Type, + /// Whether this method has a default implementation + pub has_default: bool, +} + +/// A trait implementation +#[derive(Debug, Clone)] +pub struct TraitImpl { + pub trait_name: String, + pub trait_args: Vec, + /// The type this impl is for + pub target_type: Type, + /// Type parameters on the impl (e.g., impl Show for List) + pub type_params: Vec, + /// Constraints on type parameters (e.g., where T: Show) + pub constraints: Vec<(String, Vec)>, + /// Method implementations + pub methods: HashMap, +} + +/// A trait constraint on a type variable +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TraitConstraintDef { + pub type_var: String, + pub bounds: Vec, +} + +impl PartialEq for TraitBoundDef { + fn eq(&self, other: &Self) -> bool { + self.trait_name == other.trait_name && self.type_args == other.type_args + } +} + +impl Eq for TraitBoundDef {} + impl TypeEnv { pub fn new() -> Self { Self::default()