diff --git a/src/exhaustiveness.rs b/src/exhaustiveness.rs new file mode 100644 index 0000000..a101522 --- /dev/null +++ b/src/exhaustiveness.rs @@ -0,0 +1,408 @@ +//! Pattern match exhaustiveness checking +//! +//! Implements the "usefulness" algorithm to check if pattern matches +//! cover all possible cases. + +use crate::ast::{Literal, LiteralKind, MatchArm, Pattern}; +use crate::types::{Type, TypeDef, TypeEnv, VariantDef}; +use std::collections::HashSet; + +/// Result of exhaustiveness checking +#[derive(Debug, Clone)] +pub struct ExhaustivenessResult { + pub is_exhaustive: bool, + pub missing_patterns: Vec, + pub redundant_arms: Vec, +} + +/// Check if a match expression is exhaustive +pub fn check_exhaustiveness( + scrutinee_type: &Type, + arms: &[MatchArm], + env: &TypeEnv, +) -> ExhaustivenessResult { + let patterns: Vec<&Pattern> = arms.iter().map(|arm| &arm.pattern).collect(); + + // Check for guards - patterns with guards don't guarantee coverage + let has_guards = arms.iter().any(|arm| arm.guard.is_some()); + + // Get missing patterns + let missing = find_missing_patterns(scrutinee_type, &patterns, env); + + // Find redundant arms (patterns that can never match) + let redundant = find_redundant_arms(&patterns); + + // If any pattern has a guard, we can't guarantee exhaustiveness + // unless there's an unconditional wildcard at the end + let is_exhaustive = if has_guards { + // Check if last pattern is an unconditional catch-all + arms.last() + .map(|arm| arm.guard.is_none() && is_catch_all(&arm.pattern)) + .unwrap_or(false) + } else { + missing.is_empty() + }; + + ExhaustivenessResult { + is_exhaustive, + missing_patterns: missing, + redundant_arms: redundant, + } +} + +/// Check if a pattern is a catch-all (matches anything) +fn is_catch_all(pattern: &Pattern) -> bool { + match pattern { + Pattern::Wildcard(_) => true, + Pattern::Var(_) => true, + _ => false, + } +} + +/// Find patterns that are missing from coverage +fn find_missing_patterns( + scrutinee_type: &Type, + patterns: &[&Pattern], + env: &TypeEnv, +) -> Vec { + // If any pattern is a catch-all, nothing is missing + if patterns.iter().any(|p| is_catch_all(p)) { + return Vec::new(); + } + + match scrutinee_type { + Type::Bool => check_bool_exhaustiveness(patterns), + Type::Option(inner) => check_option_exhaustiveness(patterns, inner, env), + Type::Named(name) => check_named_type_exhaustiveness(patterns, name, env), + Type::Tuple(elements) => check_tuple_exhaustiveness(patterns, elements, env), + // For other types (Int, String, etc.), we can't enumerate all values + // So we need a wildcard pattern + Type::Int | Type::Float | Type::String | Type::Char => { + vec!["_".to_string()] + } + // Unit type has exactly one value + Type::Unit => Vec::new(), + // For type variables and other complex types, assume exhaustive if there's any pattern + _ => { + if patterns.is_empty() { + vec!["_".to_string()] + } else { + Vec::new() + } + } + } +} + +/// Check Bool exhaustiveness +fn check_bool_exhaustiveness(patterns: &[&Pattern]) -> Vec { + let mut has_true = false; + let mut has_false = false; + + for pattern in patterns { + match pattern { + Pattern::Literal(Literal { + kind: LiteralKind::Bool(true), + .. + }) => has_true = true, + Pattern::Literal(Literal { + kind: LiteralKind::Bool(false), + .. + }) => has_false = true, + _ => {} + } + } + + let mut missing = Vec::new(); + if !has_true { + missing.push("true".to_string()); + } + if !has_false { + missing.push("false".to_string()); + } + missing +} + +/// Check Option exhaustiveness +fn check_option_exhaustiveness( + patterns: &[&Pattern], + _inner: &Type, + _env: &TypeEnv, +) -> Vec { + let mut has_none = false; + let mut has_some = false; + + for pattern in patterns { + if let Pattern::Constructor { name, .. } = pattern { + match name.name.as_str() { + "None" => has_none = true, + "Some" => has_some = true, + _ => {} + } + } + } + + let mut missing = Vec::new(); + if !has_none { + missing.push("None".to_string()); + } + if !has_some { + missing.push("Some(_)".to_string()); + } + missing +} + +/// Check named type (enum) exhaustiveness +fn check_named_type_exhaustiveness( + patterns: &[&Pattern], + type_name: &str, + env: &TypeEnv, +) -> Vec { + // Look up the type definition + let typedef = match env.types.get(type_name) { + Some(td) => td, + None => return Vec::new(), // Unknown type, assume exhaustive + }; + + // Handle Result specially since it's common + if type_name == "Result" { + return check_result_exhaustiveness(patterns); + } + + // Get all constructors for enum types + let constructors: Vec<&VariantDef> = match typedef { + TypeDef::Enum(variants) => variants.iter().collect(), + _ => return Vec::new(), // Not an enum, assume exhaustive + }; + + // Find which constructors are covered + let mut covered: HashSet<&str> = HashSet::new(); + for pattern in patterns { + if let Pattern::Constructor { name, .. } = pattern { + covered.insert(&name.name); + } + } + + // Find missing constructors + constructors + .iter() + .filter(|v| !covered.contains(v.name.as_str())) + .map(|v| format_constructor_pattern(v)) + .collect() +} + +/// Check Result exhaustiveness +fn check_result_exhaustiveness(patterns: &[&Pattern]) -> Vec { + let mut has_ok = false; + let mut has_err = false; + + for pattern in patterns { + if let Pattern::Constructor { name, .. } = pattern { + match name.name.as_str() { + "Ok" => has_ok = true, + "Err" => has_err = true, + _ => {} + } + } + } + + let mut missing = Vec::new(); + if !has_ok { + missing.push("Ok(_)".to_string()); + } + if !has_err { + missing.push("Err(_)".to_string()); + } + missing +} + +/// Check tuple exhaustiveness +fn check_tuple_exhaustiveness( + patterns: &[&Pattern], + _elements: &[Type], + _env: &TypeEnv, +) -> Vec { + // Tuples need a pattern that matches the whole tuple + // For simplicity, we just check if there's any tuple pattern + let has_tuple_pattern = patterns.iter().any(|p| matches!(p, Pattern::Tuple { .. })); + + if has_tuple_pattern { + Vec::new() + } else { + vec!["(_, ...)".to_string()] + } +} + +/// Format a variant as a pattern suggestion +fn format_constructor_pattern(variant: &VariantDef) -> String { + use crate::types::VariantFieldsDef; + + match &variant.fields { + VariantFieldsDef::Unit => variant.name.clone(), + VariantFieldsDef::Tuple(fields) => { + let wildcards: Vec<&str> = fields.iter().map(|_| "_").collect(); + format!("{}({})", variant.name, wildcards.join(", ")) + } + VariantFieldsDef::Record(fields) => { + let wildcards: Vec = fields.iter().map(|(n, _)| format!("{}: _", n)).collect(); + format!("{} {{ {} }}", variant.name, wildcards.join(", ")) + } + } +} + +/// Find redundant arms that can never match +fn find_redundant_arms(patterns: &[&Pattern]) -> Vec { + let mut redundant = Vec::new(); + let mut seen_catch_all = false; + + for (i, pattern) in patterns.iter().enumerate() { + if seen_catch_all { + // Any pattern after a catch-all is redundant + redundant.push(i); + } else if is_catch_all(pattern) { + seen_catch_all = true; + } + } + + redundant +} + +/// Generate a hint message for missing patterns +pub fn missing_patterns_hint(missing: &[String]) -> String { + if missing.is_empty() { + return String::new(); + } + + if missing.len() == 1 { + format!("Pattern '{}' is not covered.", missing[0]) + } else if missing.len() <= 3 { + format!("Patterns {} are not covered.", missing.join(", ")) + } else { + format!( + "Patterns {}, and {} more are not covered.", + missing[..2].join(", "), + missing.len() - 2 + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ast::{Ident, Span}; + + fn span() -> Span { + Span::default() + } + + fn make_ident(name: &str) -> Ident { + Ident { + name: name.to_string(), + span: span(), + } + } + + #[test] + fn test_bool_exhaustive() { + let patterns = vec![ + Pattern::Literal(Literal { + kind: LiteralKind::Bool(true), + span: span(), + }), + Pattern::Literal(Literal { + kind: LiteralKind::Bool(false), + span: span(), + }), + ]; + let refs: Vec<&Pattern> = patterns.iter().collect(); + let missing = check_bool_exhaustiveness(&refs); + assert!(missing.is_empty()); + } + + #[test] + fn test_bool_missing_true() { + let patterns = vec![Pattern::Literal(Literal { + kind: LiteralKind::Bool(false), + span: span(), + })]; + let refs: Vec<&Pattern> = patterns.iter().collect(); + let missing = check_bool_exhaustiveness(&refs); + assert_eq!(missing, vec!["true"]); + } + + #[test] + fn test_option_exhaustive() { + let patterns = vec![ + Pattern::Constructor { + name: make_ident("None"), + fields: vec![], + span: span(), + }, + Pattern::Constructor { + name: make_ident("Some"), + fields: vec![Pattern::Wildcard(span())], + span: span(), + }, + ]; + let refs: Vec<&Pattern> = patterns.iter().collect(); + let env = TypeEnv::new(); + let missing = check_option_exhaustiveness(&refs, &Type::Int, &env); + assert!(missing.is_empty()); + } + + #[test] + fn test_option_missing_none() { + let patterns = vec![Pattern::Constructor { + name: make_ident("Some"), + fields: vec![Pattern::Wildcard(span())], + span: span(), + }]; + let refs: Vec<&Pattern> = patterns.iter().collect(); + let env = TypeEnv::new(); + let missing = check_option_exhaustiveness(&refs, &Type::Int, &env); + assert!(missing.contains(&"None".to_string())); + } + + #[test] + fn test_wildcard_covers_all() { + let patterns = vec![Pattern::Wildcard(span())]; + let refs: Vec<&Pattern> = patterns.iter().collect(); + assert!(is_catch_all(&patterns[0])); + + let env = TypeEnv::new(); + let missing = find_missing_patterns(&Type::Bool, &refs, &env); + assert!(missing.is_empty()); + } + + #[test] + fn test_redundant_after_wildcard() { + let patterns = vec![ + Pattern::Wildcard(span()), + Pattern::Literal(Literal { + kind: LiteralKind::Bool(true), + span: span(), + }), + ]; + let refs: Vec<&Pattern> = patterns.iter().collect(); + let redundant = find_redundant_arms(&refs); + assert_eq!(redundant, vec![1]); + } + + #[test] + fn test_result_exhaustive() { + let patterns = vec![ + Pattern::Constructor { + name: make_ident("Ok"), + fields: vec![Pattern::Wildcard(span())], + span: span(), + }, + Pattern::Constructor { + name: make_ident("Err"), + fields: vec![Pattern::Wildcard(span())], + span: span(), + }, + ]; + let refs: Vec<&Pattern> = patterns.iter().collect(); + let missing = check_result_exhaustiveness(&refs); + assert!(missing.is_empty()); + } +} diff --git a/src/main.rs b/src/main.rs index 200ea9b..93b247e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ mod ast; mod diagnostics; +mod exhaustiveness; mod interpreter; mod lexer; mod modules; @@ -931,4 +932,114 @@ c")"#; assert_eq!(diag.title, "Purity Violation"); } } + + // Exhaustiveness checking tests + mod exhaustiveness_tests { + use super::*; + + #[test] + fn test_exhaustive_bool_match() { + let source = r#" + fn check(b: Bool): Int = match b { + true => 1, + false => 0 + } + let result = check(true) + "#; + let result = eval(source); + assert!(result.is_ok(), "Expected success but got: {:?}", result); + } + + #[test] + fn test_non_exhaustive_bool_match() { + let source = r#" + fn check(b: Bool): Int = match b { + true => 1 + } + let result = check(true) + "#; + let result = eval(source); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Non-exhaustive")); + } + + #[test] + fn test_exhaustive_option_match() { + let source = r#" + fn unwrap_or(opt: Option, default: Int): Int = match opt { + Some(x) => x, + None => default + } + let result = unwrap_or(Some(42), 0) + "#; + let result = eval(source); + assert!(result.is_ok(), "Expected success but got: {:?}", result); + } + + #[test] + fn test_non_exhaustive_option_missing_none() { + let source = r#" + fn get_value(opt: Option): Int = match opt { + Some(x) => x + } + let result = get_value(Some(1)) + "#; + let result = eval(source); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Non-exhaustive")); + } + + #[test] + fn test_wildcard_is_exhaustive() { + let source = r#" + fn classify(n: Int): String = match n { + 0 => "zero", + 1 => "one", + _ => "many" + } + let result = classify(5) + "#; + let result = eval(source); + assert!(result.is_ok(), "Expected success but got: {:?}", result); + } + + #[test] + fn test_variable_pattern_is_exhaustive() { + let source = r#" + fn identity(n: Int): Int = match n { + x => x + } + let result = identity(42) + "#; + let result = eval(source); + assert!(result.is_ok(), "Expected success but got: {:?}", result); + } + + #[test] + fn test_redundant_arm_warning() { + let source = r#" + fn test_fn(n: Int): Int = match n { + _ => 1, + 0 => 2 + } + let result = test_fn(0) + "#; + let result = eval(source); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Redundant")); + } + + #[test] + fn test_exhaustive_result_match() { + let source = r#" + fn handle_result(r: Result): Int = match r { + Ok(n) => n, + Err(_) => 0 + } + let result = handle_result(Ok(42)) + "#; + let result = eval(source); + assert!(result.is_ok(), "Expected success but got: {:?}", result); + } + } } diff --git a/src/typechecker.rs b/src/typechecker.rs index dc33fb1..e0d3d5b 100644 --- a/src/typechecker.rs +++ b/src/typechecker.rs @@ -8,6 +8,7 @@ use crate::ast::{ TypeDecl, TypeExpr, UnaryOp, VariantFields, }; use crate::diagnostics::{Diagnostic, Severity}; +use crate::exhaustiveness::{check_exhaustiveness, missing_patterns_hint}; use crate::modules::ModuleLoader; use crate::types::{ self, unify, EffectDef, EffectOpDef, EffectSet, HandlerDef, PropertySet, Type, TypeEnv, @@ -988,6 +989,30 @@ impl TypeChecker { } } + // Check exhaustiveness + let exhaustiveness = check_exhaustiveness(&scrutinee_type, arms, &self.env); + + if !exhaustiveness.is_exhaustive { + let hint = missing_patterns_hint(&exhaustiveness.missing_patterns); + self.errors.push(TypeError { + message: format!( + "Non-exhaustive pattern match. {}", + hint + ), + span, + }); + } + + // Warn about redundant arms + for idx in exhaustiveness.redundant_arms { + self.errors.push(TypeError { + message: format!( + "Redundant pattern: this arm will never be matched because previous patterns cover all cases" + ), + span: arms[idx].span, + }); + } + result_type.unwrap_or(Type::Error) }