diff --git a/pcf.cabal b/pcf.cabal index 99b29fd..a819f22 100644 --- a/pcf.cabal +++ b/pcf.cabal @@ -20,7 +20,8 @@ extra-source-files: README.md common shared - build-depends: base >=4.13.0.0 && <4.18.0.0.0 + build-depends: base >=4.13.0.0 && <4.18.0.0.0, + containers hs-source-dirs: src default-language: Haskell2010 diff --git a/src/Evaluator.hs b/src/Evaluator.hs index 2584d7f..7063ce3 100644 --- a/src/Evaluator.hs +++ b/src/Evaluator.hs @@ -6,31 +6,31 @@ import Misc -- small-step cbn semantics of PCF eval :: Term -> Term -- if-then-else -eval (ITE T e2 e3) = e2 -eval (ITE F e2 e3) = e3 -eval (ITE e1 e2 e3) = ITE (eval e1) e2 e3 +eval (Ite T e2 e3) = e2 +eval (Ite F e2 e3) = e3 +eval (Ite e1 e2 e3) = Ite (eval e1) e2 e3 -- arithmetic -eval (ADD (Number n1) (Number n2)) = Number $ n1 + n2 -eval (ADD e1@(Number n1) e2) = ADD e1 (eval e2) -eval (ADD e1 e2) = ADD (eval e1) e2 -eval (SUB (Number n1) (Number n2)) = Number $ n1 - n2 -eval (SUB e1@(Number n1) e2) = SUB e1 (eval e2) -eval (SUB e1 e2) = SUB (eval e1) e2 +eval (Add (Number n1) (Number n2)) = Number $ n1 + n2 +eval (Add e1@(Number n1) e2) = Add e1 (eval e2) +eval (Add e1 e2) = Add (eval e1) e2 +eval (Sub (Number n1) (Number n2)) = Number $ n1 - n2 +eval (Sub e1@(Number n1) e2) = Sub e1 (eval e2) +eval (Sub e1 e2) = Sub (eval e1) e2 -- equality -- cases where both arguments are evaluated already -eval (EQ (Number n1) (Number n2)) = if n1 == n2 then T else F -eval (EQ T T) = T -eval (EQ F F) = T -eval (EQ T F) = F -eval (EQ F T) = F +eval (Eql (Number n1) (Number n2)) = if n1 == n2 then T else F +eval (Eql T T) = T +eval (Eql F F) = T +eval (Eql T F) = F +eval (Eql F T) = F -- left argument is evaluated, eval the right one -eval (EQ e1@(Number n1) e2) = EQ e1 (eval e2) -eval (EQ T e2) = EQ T (eval e2) -eval (EQ F e2) = EQ F (eval e2) +eval (Eql e1@(Number n1) e2) = Eql e1 (eval e2) +eval (Eql T e2) = Eql T (eval e2) +eval (Eql F e2) = Eql F (eval e2) -- no argument is evaluated, eval the left one -eval (EQ e1 e2) = EQ (eval e1) e2 +eval (Eql e1 e2) = Eql (eval e1) e2 -- application -- beta reduction diff --git a/src/Misc.hs b/src/Misc.hs index d1f8195..dbc9404 100644 --- a/src/Misc.hs +++ b/src/Misc.hs @@ -1,12 +1,16 @@ module Misc where +import Syntax +import Data.List +import Data.Maybe + -- calculates free variables in a Term freeVars :: Term -> [Name] freeVars (Var x) = [x] -freeVars (ITE e1 e2 e3) = freeVars e1 ++ freeVars e2 ++ freeVars e3 -freeVars (ADD e1 e2) = freeVars e1 ++ freeVars e2 -freeVars (SUB e1 e2) = freeVars e1 ++ freeVars e2 -freeVars (EQ e1 e2) = freeVars e1 ++ freeVars e2 +freeVars (Ite e1 e2 e3) = freeVars e1 ++ freeVars e2 ++ freeVars e3 +freeVars (Add e1 e2) = freeVars e1 ++ freeVars e2 +freeVars (Sub e1 e2) = freeVars e1 ++ freeVars e2 +freeVars (Eql e1 e2) = freeVars e1 ++ freeVars e2 freeVars (App e1 e2) = freeVars e1 ++ freeVars e2 freeVars (Abs x p) = filter (/= x) $ freeVars p freeVars _ = [] @@ -18,11 +22,11 @@ vars = ['v' : show n | n <- [1..]] -- capture-avoiding substitution subst :: (Name -> Term) -> Term -> Term subst o (Var x) = o x -subst o (ITE e1 e2 e3) = ITE (subst o e1) (subst o e2) (subst o e3) +subst o (Ite e1 e2 e3) = Ite (subst o e1) (subst o e2) (subst o e3) subst o (Case e1 e2 e3) = Case (subst o e1) (subst o e2) (subst o e3) -subst o (ADD e1 e2) = ADD (subst o e1) (subst o e2) -subst o (SUB e1 e2) = SUB (subst o e1) (subst o e2) -subst o (EQ e1 e2) = EQ (subst o e1) (subst o e2) +subst o (Add e1 e2) = Add (subst o e1) (subst o e2) +subst o (Sub e1 e2) = Sub (subst o e1) (subst o e2) +subst o (Eql e1 e2) = Eql (subst o e1) (subst o e2) subst o (App e1 e2) = App (subst o e1) (subst o e2) subst o (Cons e1 e2) = Cons (subst o e1) (subst o e2) subst o (Pair e1 e2) = Pair (subst o e1) (subst o e2) diff --git a/src/Syntax.hs b/src/Syntax.hs index bf17788..78b1a4b 100644 --- a/src/Syntax.hs +++ b/src/Syntax.hs @@ -3,43 +3,55 @@ module Syntax where type Name = String data Term - = T | F - | ITE Term Term Term - | Number Int - | ADD Term Term - | SUB Term Term - | EQ Term Term - | Fst Term - | Snd Term - | Inl Term - | Inr Term - | Case Term Term Term - | Nil - | Cons Term Term - | Pair Term Term - | Var Name - | App Term Term - | Abs Name Term - | Fix - deriving (Eq) + = T + | F + | Ite Term Term Term + | Number Int + | Add Term Term + | Sub Term Term + | Eql Term Term + | Or Term Term + | Fst Term + | Snd Term + | Inl Term + | Inr Term + | Case Term Term Term + | Nil + | Cons Term Term + | Pair Term Term + | Var Name + | App Term Term + | Abs Name Term + | Fix + | Star + deriving (Eq) instance Show Term where - show T = "true" - show F = "false" - show (ITE t1 t2 t3) = "if (" ++ show t1 ++ ")\n then (" ++ show t2 ++ ")\n else (" ++ show t3 ++ ")" - show (Case t1 t2 t3) = "\n case (" ++ show t1 ++ ") of\n <0> => (" ++ show t2 ++ ");\n <1> => (" ++ show t3 ++ ")" - show (ADD t1 t2) = "(" ++ show t1 ++ ") + (" ++ show t2 ++ ")" - show (SUB t1 t2) = "(" ++ show t1 ++ ") - (" ++ show t2 ++ ")" - show (EQ t1 t2) = "(" ++ show t1 ++ ") == (" ++ show t2 ++ ")" - show (Number n) = show n - show (Fst t) = "fst " ++ show t - show (Snd t) = "snd " ++ show t - show (Inl t) = "inl " ++ show t - show (Inr t) = "inr " ++ show t - show Nil = "[]" - show (Cons t1 t2) = "(" ++ show t1 ++ ") : (" ++ show t2 ++ ")" - show (Pair t1 t2) = "(" ++ show t1 ++ ", " ++ show t2 ++ ")" - show (Var n) = n - show (App t1 t2) = "(" ++ show t1 ++ ") " ++ show t2 - show Fix = "Y" - show (Abs x p) = "\\" ++ x ++ ". " ++ show p \ No newline at end of file + show T = "true" + show F = "false" + show (Ite t1 t2 t3) = "if (" ++ show t1 ++ ")\n then (" ++ show t2 ++ ")\n else (" ++ show t3 ++ ")" + show (Case t1 t2 t3) = "\n case (" ++ show t1 ++ ") of\n <0> => (" ++ show t2 ++ ");\n <1> => (" ++ show t3 ++ ")" + show (Add t1 t2) = "(" ++ show t1 ++ ") + (" ++ show t2 ++ ")" + show (Sub t1 t2) = "(" ++ show t1 ++ ") - (" ++ show t2 ++ ")" + show (Eql t1 t2) = "(" ++ show t1 ++ ") == (" ++ show t2 ++ ")" + show (Number n) = show n + show (Fst t) = "fst " ++ show t + show (Snd t) = "snd " ++ show t + show (Inl t) = "inl " ++ show t + show (Inr t) = "inr " ++ show t + show Nil = "[]" + show (Cons t1 t2) = "(" ++ show t1 ++ ") : (" ++ show t2 ++ ")" + show (Pair t1 t2) = "(" ++ show t1 ++ ", " ++ show t2 ++ ")" + show (Var n) = n + show (App t1 t2) = "(" ++ show t1 ++ ") " ++ show t2 + show Fix = "Y" + show (Abs x p) = "\\" ++ x ++ ". " ++ show p + +data Type + = TBool + | TNat + | TUnit + | TProd Type Type + | TArrow Type Type + | THole + deriving (Eq, Show) \ No newline at end of file diff --git a/src/Typecheck.hs b/src/Typecheck.hs index ff8bc5e..95d6721 100644 --- a/src/Typecheck.hs +++ b/src/Typecheck.hs @@ -1,3 +1,135 @@ module Typecheck where -import Syntax \ No newline at end of file +import Control.Monad +import Data.Map.Strict +import qualified Data.Map.Strict as Map +import Syntax + +term1 = Abs "y" $ Add (App (Abs "x" $ Or (Var "x") (Var "y")) T) (Var "y") +term2 = Abs "x" $ Add (Number 5) (Var "x") +term3 = Abs "x" $ Eql (Var "x") (Number 5) + +type Context = Map Name Type + +typeCheck :: Term -> Maybe Type +typeCheck t = fst <$> inferType Map.empty t THole + +inferType :: Context -> Term -> Type -> Maybe (Type, Context) +inferType ctx T _ = Just (TBool, ctx) +inferType ctx F _ = Just (TBool, ctx) +inferType ctx (Ite b t1 t2) expected = do + (tb, ctxb) <- inferType ctx b TBool + (tp1, ctx1) <- inferType ctx t1 expected + (tp2, ctx2) <- inferType ctx t2 expected + ctx' <- mergeContexts ctxb ctx1 + ctx'' <- mergeContexts ctx' ctx2 + if tp1 == tp2 + then Just (tp1, ctx'') + else Nothing +inferType ctx (Number _) _ = Just (TNat, ctx) +inferType ctx (Add t1 t2) _ = do + (tp1, ctx1) <- inferType ctx t1 TNat + (tp2, ctx2) <- inferType ctx t2 TNat + ctx' <- mergeContexts ctx1 ctx2 + case (tp1, tp2) of + (TNat, TNat) -> Just (TNat, ctx') + _ -> Nothing +inferType ctx (Sub t1 t2) _ = do + (tp1, ctx1) <- inferType ctx t1 TNat + (tp2, ctx2) <- inferType ctx t2 TNat + ctx' <- mergeContexts ctx1 ctx2 + case (tp1, tp2) of + (TNat, TNat) -> Just (TNat, ctx') + _ -> Nothing +inferType ctx (Or t1 t2) _ = do + (tp1, ctx1) <- inferType ctx t1 TBool + (tp2, ctx2) <- inferType ctx t2 TBool + ctx' <- mergeContexts ctx1 ctx2 + case (tp1, tp2) of + (TBool, TBool) -> Just (TBool, ctx') + _ -> Nothing +inferType ctx (Eql t1@(Var v1) t2) _ = do + (tp2, ctx2) <- inferType ctx t2 THole + (tp1, ctx1) <- inferType ctx t1 tp2 + ctx' <- mergeContexts ctx1 ctx2 + if tp1 == tp2 + then case tp1 of + TBool -> Just (TBool, ctx') + TNat -> Just (TBool, ctx') + TUnit -> Just (TBool, ctx') + _ -> Nothing + else Nothing +inferType ctx (Eql t1 t2) _ = do + (tp1, ctx1) <- inferType ctx t1 THole + (tp2, ctx2) <- inferType ctx t2 tp1 + ctx' <- mergeContexts ctx1 ctx2 + if tp1 == tp2 + then case tp1 of + TBool -> Just (TBool, ctx') + TNat -> Just (TBool, ctx') + TUnit -> Just (TBool, ctx') + _ -> Nothing + else Nothing +inferType ctx (Pair t1 t2) _ = do + (tp1, ctx1) <- inferType ctx t1 THole + (tp2, ctx2) <- inferType ctx t2 THole + ctx' <- mergeContexts ctx1 ctx2 + Just (TProd tp1 tp2, ctx') +inferType ctx (Fst t) _ = do + (tp, ctx') <- inferType ctx t THole + case tp of + TProd tp1 tp2 -> Just (tp1, ctx') + _ -> Nothing +inferType ctx (Snd t) _ = do + (tp, ctx') <- inferType ctx t THole + case tp of + TProd tp1 tp2 -> Just (tp2, ctx') + _ -> Nothing +inferType ctx Star _ = Just (TUnit, ctx) +inferType ctx (Var v) expected = do + tp <- Map.lookup v ctx + case tp of + THole -> return (expected, Map.insert v expected ctx) + tp' | tp == tp' -> return (expected, ctx) + _ -> Nothing +inferType ctx (App Fix t2) _ = do + (tp2, ctx') <- inferType ctx t2 THole + case tp2 of + TArrow tpl tpr | tpl == tpr -> Just (tpl, ctx') + _ -> Nothing +inferType ctx (App t1 t2) _ = do + (tp1, ctx1) <- inferType ctx t1 THole + (tp2, ctx2) <- inferType ctx t2 THole + ctx' <- mergeContexts ctx1 ctx2 + case tp1 of + TArrow tpl tpr | tpl == tp2 -> Just (tpr, ctx') + _ -> Nothing +inferType ctx (Abs x t) _ = do + -- save old type of x (it might get overwritten here, but it also might be Nothing atm) + let tpxOld = Map.lookup x ctx + + -- do type inference in t, this might fill some holes we had, so we can't discard ctx'... + (tp, ctx') <- inferType (Map.insert x THole ctx) t THole + tpx <- Map.lookup x ctx' + + -- restore old type of x in new context + let ctx'' = case tpxOld of + Nothing -> ctx' + Just old -> Map.insert x old ctx' + return (TArrow tpx tp, ctx'') + +-- merge two contexts +mergeContexts :: Context -> Context -> Maybe Context +mergeContexts ctx1 ctx2 = do + ctx' <- foldrWithKey helper (Just Map.empty) ctx1 + return $ Map.union ctx' ctx2 + where + helper name tp Nothing = Nothing + helper name THole (Just ctx) = case Map.lookup name ctx2 of + Nothing -> return $ Map.insert name THole ctx + Just tp' -> return $ Map.insert name tp' ctx + helper name tp (Just ctx) = case Map.lookup name ctx2 of + Nothing -> return $ Map.insert name tp ctx + Just THole -> return $ Map.insert name tp ctx + Just tp' | tp == tp' -> Just ctx + Just tp' | otherwise -> Nothing