Implemented typechecking (not 100% correct)

This commit is contained in:
reijix 2023-06-01 12:25:20 +02:00
parent c11d46956b
commit 3b18671079
5 changed files with 215 additions and 66 deletions

View file

@ -20,7 +20,8 @@ extra-source-files:
README.md
common shared
build-depends: base >=4.13.0.0 && <4.18.0.0.0
build-depends: base >=4.13.0.0 && <4.18.0.0.0,
containers
hs-source-dirs: src
default-language: Haskell2010

View file

@ -6,31 +6,31 @@ import Misc
-- small-step cbn semantics of PCF
eval :: Term -> Term
-- if-then-else
eval (ITE T e2 e3) = e2
eval (ITE F e2 e3) = e3
eval (ITE e1 e2 e3) = ITE (eval e1) e2 e3
eval (Ite T e2 e3) = e2
eval (Ite F e2 e3) = e3
eval (Ite e1 e2 e3) = Ite (eval e1) e2 e3
-- arithmetic
eval (ADD (Number n1) (Number n2)) = Number $ n1 + n2
eval (ADD e1@(Number n1) e2) = ADD e1 (eval e2)
eval (ADD e1 e2) = ADD (eval e1) e2
eval (SUB (Number n1) (Number n2)) = Number $ n1 - n2
eval (SUB e1@(Number n1) e2) = SUB e1 (eval e2)
eval (SUB e1 e2) = SUB (eval e1) e2
eval (Add (Number n1) (Number n2)) = Number $ n1 + n2
eval (Add e1@(Number n1) e2) = Add e1 (eval e2)
eval (Add e1 e2) = Add (eval e1) e2
eval (Sub (Number n1) (Number n2)) = Number $ n1 - n2
eval (Sub e1@(Number n1) e2) = Sub e1 (eval e2)
eval (Sub e1 e2) = Sub (eval e1) e2
-- equality
-- cases where both arguments are evaluated already
eval (EQ (Number n1) (Number n2)) = if n1 == n2 then T else F
eval (EQ T T) = T
eval (EQ F F) = T
eval (EQ T F) = F
eval (EQ F T) = F
eval (Eql (Number n1) (Number n2)) = if n1 == n2 then T else F
eval (Eql T T) = T
eval (Eql F F) = T
eval (Eql T F) = F
eval (Eql F T) = F
-- left argument is evaluated, eval the right one
eval (EQ e1@(Number n1) e2) = EQ e1 (eval e2)
eval (EQ T e2) = EQ T (eval e2)
eval (EQ F e2) = EQ F (eval e2)
eval (Eql e1@(Number n1) e2) = Eql e1 (eval e2)
eval (Eql T e2) = Eql T (eval e2)
eval (Eql F e2) = Eql F (eval e2)
-- no argument is evaluated, eval the left one
eval (EQ e1 e2) = EQ (eval e1) e2
eval (Eql e1 e2) = Eql (eval e1) e2
-- application
-- beta reduction

View file

@ -1,12 +1,16 @@
module Misc where
import Syntax
import Data.List
import Data.Maybe
-- calculates free variables in a Term
freeVars :: Term -> [Name]
freeVars (Var x) = [x]
freeVars (ITE e1 e2 e3) = freeVars e1 ++ freeVars e2 ++ freeVars e3
freeVars (ADD e1 e2) = freeVars e1 ++ freeVars e2
freeVars (SUB e1 e2) = freeVars e1 ++ freeVars e2
freeVars (EQ e1 e2) = freeVars e1 ++ freeVars e2
freeVars (Ite e1 e2 e3) = freeVars e1 ++ freeVars e2 ++ freeVars e3
freeVars (Add e1 e2) = freeVars e1 ++ freeVars e2
freeVars (Sub e1 e2) = freeVars e1 ++ freeVars e2
freeVars (Eql e1 e2) = freeVars e1 ++ freeVars e2
freeVars (App e1 e2) = freeVars e1 ++ freeVars e2
freeVars (Abs x p) = filter (/= x) $ freeVars p
freeVars _ = []
@ -18,11 +22,11 @@ vars = ['v' : show n | n <- [1..]]
-- capture-avoiding substitution
subst :: (Name -> Term) -> Term -> Term
subst o (Var x) = o x
subst o (ITE e1 e2 e3) = ITE (subst o e1) (subst o e2) (subst o e3)
subst o (Ite e1 e2 e3) = Ite (subst o e1) (subst o e2) (subst o e3)
subst o (Case e1 e2 e3) = Case (subst o e1) (subst o e2) (subst o e3)
subst o (ADD e1 e2) = ADD (subst o e1) (subst o e2)
subst o (SUB e1 e2) = SUB (subst o e1) (subst o e2)
subst o (EQ e1 e2) = EQ (subst o e1) (subst o e2)
subst o (Add e1 e2) = Add (subst o e1) (subst o e2)
subst o (Sub e1 e2) = Sub (subst o e1) (subst o e2)
subst o (Eql e1 e2) = Eql (subst o e1) (subst o e2)
subst o (App e1 e2) = App (subst o e1) (subst o e2)
subst o (Cons e1 e2) = Cons (subst o e1) (subst o e2)
subst o (Pair e1 e2) = Pair (subst o e1) (subst o e2)

View file

@ -3,43 +3,55 @@ module Syntax where
type Name = String
data Term
= T | F
| ITE Term Term Term
| Number Int
| ADD Term Term
| SUB Term Term
| EQ Term Term
| Fst Term
| Snd Term
| Inl Term
| Inr Term
| Case Term Term Term
| Nil
| Cons Term Term
| Pair Term Term
| Var Name
| App Term Term
| Abs Name Term
| Fix
deriving (Eq)
= T
| F
| Ite Term Term Term
| Number Int
| Add Term Term
| Sub Term Term
| Eql Term Term
| Or Term Term
| Fst Term
| Snd Term
| Inl Term
| Inr Term
| Case Term Term Term
| Nil
| Cons Term Term
| Pair Term Term
| Var Name
| App Term Term
| Abs Name Term
| Fix
| Star
deriving (Eq)
instance Show Term where
show T = "true"
show F = "false"
show (ITE t1 t2 t3) = "if (" ++ show t1 ++ ")\n then (" ++ show t2 ++ ")\n else (" ++ show t3 ++ ")"
show (Case t1 t2 t3) = "\n case (" ++ show t1 ++ ") of\n <0> => (" ++ show t2 ++ ");\n <1> => (" ++ show t3 ++ ")"
show (ADD t1 t2) = "(" ++ show t1 ++ ") + (" ++ show t2 ++ ")"
show (SUB t1 t2) = "(" ++ show t1 ++ ") - (" ++ show t2 ++ ")"
show (EQ t1 t2) = "(" ++ show t1 ++ ") == (" ++ show t2 ++ ")"
show (Number n) = show n
show (Fst t) = "fst " ++ show t
show (Snd t) = "snd " ++ show t
show (Inl t) = "inl " ++ show t
show (Inr t) = "inr " ++ show t
show Nil = "[]"
show (Cons t1 t2) = "(" ++ show t1 ++ ") : (" ++ show t2 ++ ")"
show (Pair t1 t2) = "(" ++ show t1 ++ ", " ++ show t2 ++ ")"
show (Var n) = n
show (App t1 t2) = "(" ++ show t1 ++ ") " ++ show t2
show Fix = "Y"
show (Abs x p) = "\\" ++ x ++ ". " ++ show p
show T = "true"
show F = "false"
show (Ite t1 t2 t3) = "if (" ++ show t1 ++ ")\n then (" ++ show t2 ++ ")\n else (" ++ show t3 ++ ")"
show (Case t1 t2 t3) = "\n case (" ++ show t1 ++ ") of\n <0> => (" ++ show t2 ++ ");\n <1> => (" ++ show t3 ++ ")"
show (Add t1 t2) = "(" ++ show t1 ++ ") + (" ++ show t2 ++ ")"
show (Sub t1 t2) = "(" ++ show t1 ++ ") - (" ++ show t2 ++ ")"
show (Eql t1 t2) = "(" ++ show t1 ++ ") == (" ++ show t2 ++ ")"
show (Number n) = show n
show (Fst t) = "fst " ++ show t
show (Snd t) = "snd " ++ show t
show (Inl t) = "inl " ++ show t
show (Inr t) = "inr " ++ show t
show Nil = "[]"
show (Cons t1 t2) = "(" ++ show t1 ++ ") : (" ++ show t2 ++ ")"
show (Pair t1 t2) = "(" ++ show t1 ++ ", " ++ show t2 ++ ")"
show (Var n) = n
show (App t1 t2) = "(" ++ show t1 ++ ") " ++ show t2
show Fix = "Y"
show (Abs x p) = "\\" ++ x ++ ". " ++ show p
data Type
= TBool
| TNat
| TUnit
| TProd Type Type
| TArrow Type Type
| THole
deriving (Eq, Show)

View file

@ -1,3 +1,135 @@
module Typecheck where
import Syntax
import Control.Monad
import Data.Map.Strict
import qualified Data.Map.Strict as Map
import Syntax
term1 = Abs "y" $ Add (App (Abs "x" $ Or (Var "x") (Var "y")) T) (Var "y")
term2 = Abs "x" $ Add (Number 5) (Var "x")
term3 = Abs "x" $ Eql (Var "x") (Number 5)
type Context = Map Name Type
typeCheck :: Term -> Maybe Type
typeCheck t = fst <$> inferType Map.empty t THole
inferType :: Context -> Term -> Type -> Maybe (Type, Context)
inferType ctx T _ = Just (TBool, ctx)
inferType ctx F _ = Just (TBool, ctx)
inferType ctx (Ite b t1 t2) expected = do
(tb, ctxb) <- inferType ctx b TBool
(tp1, ctx1) <- inferType ctx t1 expected
(tp2, ctx2) <- inferType ctx t2 expected
ctx' <- mergeContexts ctxb ctx1
ctx'' <- mergeContexts ctx' ctx2
if tp1 == tp2
then Just (tp1, ctx'')
else Nothing
inferType ctx (Number _) _ = Just (TNat, ctx)
inferType ctx (Add t1 t2) _ = do
(tp1, ctx1) <- inferType ctx t1 TNat
(tp2, ctx2) <- inferType ctx t2 TNat
ctx' <- mergeContexts ctx1 ctx2
case (tp1, tp2) of
(TNat, TNat) -> Just (TNat, ctx')
_ -> Nothing
inferType ctx (Sub t1 t2) _ = do
(tp1, ctx1) <- inferType ctx t1 TNat
(tp2, ctx2) <- inferType ctx t2 TNat
ctx' <- mergeContexts ctx1 ctx2
case (tp1, tp2) of
(TNat, TNat) -> Just (TNat, ctx')
_ -> Nothing
inferType ctx (Or t1 t2) _ = do
(tp1, ctx1) <- inferType ctx t1 TBool
(tp2, ctx2) <- inferType ctx t2 TBool
ctx' <- mergeContexts ctx1 ctx2
case (tp1, tp2) of
(TBool, TBool) -> Just (TBool, ctx')
_ -> Nothing
inferType ctx (Eql t1@(Var v1) t2) _ = do
(tp2, ctx2) <- inferType ctx t2 THole
(tp1, ctx1) <- inferType ctx t1 tp2
ctx' <- mergeContexts ctx1 ctx2
if tp1 == tp2
then case tp1 of
TBool -> Just (TBool, ctx')
TNat -> Just (TBool, ctx')
TUnit -> Just (TBool, ctx')
_ -> Nothing
else Nothing
inferType ctx (Eql t1 t2) _ = do
(tp1, ctx1) <- inferType ctx t1 THole
(tp2, ctx2) <- inferType ctx t2 tp1
ctx' <- mergeContexts ctx1 ctx2
if tp1 == tp2
then case tp1 of
TBool -> Just (TBool, ctx')
TNat -> Just (TBool, ctx')
TUnit -> Just (TBool, ctx')
_ -> Nothing
else Nothing
inferType ctx (Pair t1 t2) _ = do
(tp1, ctx1) <- inferType ctx t1 THole
(tp2, ctx2) <- inferType ctx t2 THole
ctx' <- mergeContexts ctx1 ctx2
Just (TProd tp1 tp2, ctx')
inferType ctx (Fst t) _ = do
(tp, ctx') <- inferType ctx t THole
case tp of
TProd tp1 tp2 -> Just (tp1, ctx')
_ -> Nothing
inferType ctx (Snd t) _ = do
(tp, ctx') <- inferType ctx t THole
case tp of
TProd tp1 tp2 -> Just (tp2, ctx')
_ -> Nothing
inferType ctx Star _ = Just (TUnit, ctx)
inferType ctx (Var v) expected = do
tp <- Map.lookup v ctx
case tp of
THole -> return (expected, Map.insert v expected ctx)
tp' | tp == tp' -> return (expected, ctx)
_ -> Nothing
inferType ctx (App Fix t2) _ = do
(tp2, ctx') <- inferType ctx t2 THole
case tp2 of
TArrow tpl tpr | tpl == tpr -> Just (tpl, ctx')
_ -> Nothing
inferType ctx (App t1 t2) _ = do
(tp1, ctx1) <- inferType ctx t1 THole
(tp2, ctx2) <- inferType ctx t2 THole
ctx' <- mergeContexts ctx1 ctx2
case tp1 of
TArrow tpl tpr | tpl == tp2 -> Just (tpr, ctx')
_ -> Nothing
inferType ctx (Abs x t) _ = do
-- save old type of x (it might get overwritten here, but it also might be Nothing atm)
let tpxOld = Map.lookup x ctx
-- do type inference in t, this might fill some holes we had, so we can't discard ctx'...
(tp, ctx') <- inferType (Map.insert x THole ctx) t THole
tpx <- Map.lookup x ctx'
-- restore old type of x in new context
let ctx'' = case tpxOld of
Nothing -> ctx'
Just old -> Map.insert x old ctx'
return (TArrow tpx tp, ctx'')
-- merge two contexts
mergeContexts :: Context -> Context -> Maybe Context
mergeContexts ctx1 ctx2 = do
ctx' <- foldrWithKey helper (Just Map.empty) ctx1
return $ Map.union ctx' ctx2
where
helper name tp Nothing = Nothing
helper name THole (Just ctx) = case Map.lookup name ctx2 of
Nothing -> return $ Map.insert name THole ctx
Just tp' -> return $ Map.insert name tp' ctx
helper name tp (Just ctx) = case Map.lookup name ctx2 of
Nothing -> return $ Map.insert name tp ctx
Just THole -> return $ Map.insert name tp ctx
Just tp' | tp == tp' -> Just ctx
Just tp' | otherwise -> Nothing