1
{-# OPTIONS -fglasgow-exts #-}
2
module TypeChecker where
4
import Control.Monad.Reader
5
import Control.Monad.State
6
import Control.Monad.Error
8
import qualified Data.Map as Map
15
-- | Typing context: maps each variable in scope to its type.
type Context = Map.Map Ident Type

-- | The context with no variables in scope.
emptyContext :: Context
emptyContext = Map.empty
24
-- | Zipper over the tree of type-checker calls.
--
-- @Current c p ss cs@ is focused on the in-progress call @c@. It holds:
--
-- * @p@, the chain of enclosing calls;
-- * @ss@, the already-finished siblings that precede @c@;
-- * @cs@, the children @c@ has produced so far.
--
-- Both lists are newest first, because finished calls are prepended.
-- @TopLevel cs@ is the state with no call in progress; it holds the
-- finished top-level calls.
data Current = Current Call Parent [Sibling] [Child] | TopLevel [Child]

-- | Path from the focused call back to the top level. Each 'Parent' frame
-- records the enclosing call, that call's own parent, and that call's
-- finished siblings.
data Parent = Parent Call Parent [Sibling] | NoParent

-- | A finished (or force-closed) call together with the sub-calls it made,
-- newest first.
data Child = Child Call [Child]
31
-- | One type-checker operation as recorded in the trace.
--
-- The last field of every constructor is the call's result slot: it is
-- 'Nothing' while the call is in progress and 'Just' the result once the
-- call has returned (see 'inProgress').
data Call = Infer Exp (Maybe Type)
          | Check Exp Type (Maybe ())
          | LookupVar Ident (Maybe Type)
          | EqualType Type Type (Maybe ())
          | IsFunctionType Type (Maybe (Type,Type))
37
-- | Prefix every line of @s@ with @n@ spaces.
--
-- The lines are rejoined with 'unlines', so a non-empty result always ends
-- in a newline and an empty input gives an empty result.
-- NOTE(review): this listing lacks the lines between 'indent' and 'ind'
-- (presumably a @where@). 'ind' uses 'n', so it must be local to 'indent'.
indent :: Int -> String -> String
39
indent n s = unlines . concatMap ind . lines $ s
42
-- Indent a single line, returned as a one-element list for 'concatMap'.
ind s = [replicate n ' ' ++ s]
44
-- | Close every call that is still in progress, walking out through the
-- parent chain until the trace is back at the top level.
--
-- Each open call is recorded unchanged, i.e. with its result slot still
-- empty. The resulting trace therefore shows where checking stopped.
topLevelView :: Trace -> Trace
topLevelView t@(TopLevel _)      = t
topLevelView t@(Current c _ _ _) = topLevelView (updateCall c t)
48
-- Show a trace in two steps: first close every still-open call
-- ('topLevelView'), then show each resulting top-level call.
-- The 'Current' alternative is never needed here, because 'topLevelView'
-- only returns a 'TopLevel' trace.
-- NOTE(review): 'updateCall' prepends finished calls, so top-level calls
-- are stored newest first. Unlike the 'Child' instance, this list is not
-- reversed before printing; confirm the intended order.
instance Show Current where
49
show t = case topLevelView t of
50
TopLevel cs -> unlines $ map show cs
52
-- Show a call on its own line, followed by its sub-calls indented two
-- spaces beneath it.
-- Sub-calls are stored newest first (finished calls are prepended), so
-- they are reversed to print in the order they were made.
-- NOTE(review): the method head that binds 'c' and 'cs' is missing from
-- this listing (presumably @show (Child c cs) =@).
instance Show Child where
54
indent 0 (show c) ++ indent 2 (unlines $ map show $ reverse cs)
56
-- Render a single call for the trace printout, e.g. @infer e = result@.
-- 'Infer' and 'LookupVar' show their result via 'res'.
-- 'Check', 'EqualType' and 'IsFunctionType' use 'nores' instead.
-- NOTE(review): several lines are missing from this listing:
--   * the method head (presumably @show c = case c of@);
--   * the @where@;
--   * the head and remaining clause of 'par';
--   * the 'Nothing' clause of 'res';
--   * the definition of 'nores'.
-- Presumably the missing clauses render a call that has not finished.
instance Show Call where
58
Infer e r -> "infer " ++ show' e ++ " = " ++ res r
59
Check e t r -> "check " ++ show' e ++ " " ++ show' t ++ " " ++ nores r
60
LookupVar x r -> "lookupVar " ++ printTree x ++ " = " ++ res r
61
EqualType s t r -> show' s ++ " == " ++ show' t ++ " " ++ nores r
62
IsFunctionType t r -> "isFunctionType " ++ show' t ++ " " ++ nores r
64
-- Pretty-print a value through 'par'.
show' x = par $ printTree x
66
-- 'par': wrap text that contains a space in parentheses.
| ' ' `elem` s = "(" ++ s ++ ")"
69
-- A finished call prints its result.
res (Just r) = printTree r
73
-- | Whether a call is still running, i.e. its result slot is still
-- 'Nothing'.
inProgress :: Call -> Bool
inProgress c = case c of
  Infer _ r          -> isNothing r
  Check _ _ r        -> isNothing r
  LookupVar _ r      -> isNothing r
  EqualType _ _ r    -> isNothing r
  IsFunctionType _ r -> isNothing r
81
-- | Enter a new call @c@, which becomes the focus.
--
-- * The previous focus, if any, becomes the parent frame.
-- * The children produced so far become @c@'s preceding siblings.
-- * The new call starts with no children of its own.
newCall :: Call -> Trace -> Trace
newCall c (TopLevel cs)        = Current c NoParent cs []
newCall c (Current c' p ss cs) = Current c (Parent c' p ss) cs []
85
-- | Leave the focused call, recording it as @c@ (in practice the same call,
-- either with its result filled in or, from 'topLevelView', unchanged).
--
-- The finished call, together with its children, is prepended to its
-- siblings. Focus then moves to the parent, or to the top level when there
-- is no parent.
--
-- Calls 'error' if no call is in progress.
updateCall :: Call -> Trace -> Trace
updateCall _ (TopLevel _) = error "updateCall: no call in progress"
updateCall c (Current _ p ss cs) = case p of
  NoParent         -> TopLevel (Child c cs : ss)
  Parent c' p' ss' -> Current c' p' ss' (Child c cs : ss)
91
-- | Run a computation as a traced call built by @mkCall@.
--
-- Before the computation runs, the call is pushed onto the trace with an
-- empty result ('newCall'). Afterwards it is replaced by the call carrying
-- the result ('updateCall').
-- NOTE(review): the lines binding @mkCall@, the computation and @r@ (and
-- presumably a final @return r@) are missing from this listing. If the
-- computation fails in between, presumably the second 'modify' never runs,
-- so the call stays open in the trace.
call :: (Maybe r -> Call) -> TC r -> TC r
93
modify $ newCall (mkCall Nothing)
95
modify $ updateCall (mkCall $ Just r)
99
-- | The kinds of error the type checker can report.
data ErrorMsg = UnboundVar Ident          -- ^ variable not in the context
              | TypeMismatch Type Type    -- ^ two types that should be equal
              | NotFunctionType Type      -- ^ expected a function type
              | InternalError String      -- ^ checker bug or 'fail'
104
-- One-line rendering of an error message, without any trace context.
-- Compare 'displayError', which also searches the trace.
-- NOTE(review): the method head (presumably @show e = case e of@) is
-- missing from this listing.
instance Show ErrorMsg where
106
UnboundVar x -> "Unbound variable " ++ printTree x
107
TypeMismatch s t -> printTree s ++ " != " ++ printTree t
108
NotFunctionType t -> printTree t ++ " is not a function type"
109
InternalError s -> "Internal error: " ++ s
111
-- | A type error: a call trace paired with the error message.
type TypeError = (Trace, ErrorMsg)

-- | 'Error' instance (from "Control.Monad.Error") for 'TypeError'.
--
-- Errors created through this class carry no trace ('noTrace'). Their
-- message is wrapped in 'InternalError'.
instance Error TypeError where
  noMsg    = (noTrace, InternalError "")
  strMsg s = (noTrace, InternalError s)
118
-- | The type-checking monad.
--
-- Computations read the typing 'Context', thread the call 'Trace' as
-- state, and abort with a 'TypeError' through the underlying 'Either'.
newtype TC a = TC { unTC :: ReaderT Context (StateT Trace (Either TypeError)) a }
  deriving (MonadReader Context, MonadState Trace)
121
-- Monad instance for 'TC'. Bind unwraps and rewraps the newtype, and
-- 'fail' is routed through 'failure' as an 'InternalError'.
-- NOTE(review): 'return' is missing from this listing.
-- NOTE(review): on GHC >= 7.10 'TC' also needs Functor and Applicative
-- instances, and on GHC >= 8.8 'fail' belongs to 'MonadFail', not 'Monad'.
instance Monad TC where
123
TC m >>= k = TC $ m >>= unTC . k
124
fail = failure . InternalError
126
-- | Abort type checking with the given error paired with the trace @tr@.
-- The 'Left' is lifted past the ReaderT and StateT layers into the
-- underlying 'Either'.
-- NOTE(review): the equation head and the line binding @tr@ are missing
-- from this listing. @tr@ is presumably the current trace read from the
-- state.
failure :: ErrorMsg -> TC a
129
TC $ lift $ lift $ Left (tr, msg)
131
-- | Run a type-checking computation with an empty context and 'noTrace' as
-- the initial trace.
--
-- The final trace is discarded. A failure is returned as 'Left' together
-- with the trace stored in the error.
runTC :: TC a -> Either TypeError a
runTC (TC tc) = evalStateT (runReaderT tc emptyContext) noTrace
135
-- | Look up a variable's type in the context, recorded as a 'LookupVar'
-- call. Fails with 'UnboundVar' when the variable is not in scope.
-- NOTE(review): the line binding @ctx@ (presumably @ctx <- ask@) and the
-- 'Just' alternative are missing from this listing.
lookupVar :: Ident -> TC Type
136
lookupVar x = call (LookupVar x) $ do
138
case Map.lookup x ctx of
140
Nothing -> failure $ UnboundVar x
142
-- | Run a computation with variable @x@ bound to type @t@.
-- Any earlier binding of @x@ is shadowed for the duration of that
-- computation only.
addToContext :: Ident -> Type -> TC a -> TC a
addToContext x t = local (Map.insert x t)
145
-- | Infer the type of an expression, recorded as an 'Infer' call.
-- Literals have type @Nat@.
-- NOTE(review): most alternatives of the @case@ are only partially present
-- in this listing.
infer :: Exp -> TC Type
146
infer e = call (Infer e) $ case e of
148
Lit _ -> return $ ConT (Ident "Nat")
151
-- Presumably part of the application case: the function's type must
-- split into argument and result types.
(t1, t2) <- isFunctionType t
155
-- Presumably part of the lambda case: the body is inferred with @x@
-- bound to @t@.
t' <- addToContext x t $ infer e
158
-- | Check that an expression has the given type, recorded as a 'Check'
-- call.
-- NOTE(review): the body of the @do@ block is missing from this listing.
check :: Exp -> Type -> TC ()
159
check e t = call (Check e t) $ do
163
-- | Check two types for equality, recorded as an 'EqualType' call.
-- Any pair not handled by an earlier alternative fails with
-- 'TypeMismatch'.
-- NOTE(review): this listing lacks the pattern that goes with the
-- @x == y@ guard (presumably two named base types) and the body of the
-- 'FunT' case (presumably a component-wise comparison).
(===) :: Type -> Type -> TC ()
164
s === t = call (EqualType s t) $ case (s,t) of
166
| x == y -> return ()
167
(FunT s1 t1, FunT s2 t2) -> do
170
_ -> failure $ TypeMismatch s t
172
-- | Split a function type into its argument and result types, recorded as
-- an 'IsFunctionType' call. Fails with 'NotFunctionType' for any other
-- type.
isFunctionType :: Type -> TC (Type, Type)
isFunctionType t = call (IsFunctionType t) $ case t of
  FunT t1 t2 -> return (t1, t2)
  _          -> failure $ NotFunctionType t
178
-- | Like 'matchTrace', but the predicate sees only each enclosing call,
-- not that call's children.
-- NOTE(review): the definition of @f'@ is missing from this listing.
-- It is presumably @f' (Child c _) = f c@.
matchCall :: (Call -> Maybe a) -> Trace -> Maybe a
179
matchCall f = matchTrace f'
183
-- | Find the innermost open call, together with the children it has so
-- far, that @f@ accepts.
--
-- The search starts at the focused call and moves outward to its parent
-- on each step. The first 'Just' wins. Returns 'Nothing' once the top
-- level is reached without a match.
matchTrace :: (Child -> Maybe a) -> Trace -> Maybe a
matchTrace _ (TopLevel _) = Nothing
matchTrace f t@(Current c _ _ cs) =
  f (Child c cs) `mplus` matchTrace f (updateCall c t)
188
-- | Render a 'TypeError' as a multi-line, human-readable report.
--
-- For some errors the trace is searched for context:
--
-- * the nearest unfinished inference of an application ('matchCall' with
--   @isInferApp@);
-- * the nearest unfinished 'Check' whose latest sub-call is an equality
--   test ('matchTrace' with @isCheck@).
--
-- Calls 'error' when that context cannot be found in the trace.
-- NOTE(review): many lines of this definition are missing from this
-- listing, including the heads of the alternatives that produce the
-- function-type and mismatch reports.
displayError :: TypeError -> String
189
displayError (tr, e) = case e of
190
InternalError s -> unlines [ "internal error: " ++ s ]
191
UnboundVar x -> unlines [ "unbound variable " ++ printTree x ]
196
, "is applied to an argument, but has type"
198
, "which is not a function type"
201
-- Presumably the function part of the nearest open application.
f = case matchCall isInferApp tr of
203
Nothing -> error "displayError: can't find function"
204
-- Only unfinished (result 'Nothing') inferences of an application match.
isInferApp (Infer (App f _) Nothing) = Just f
205
isInferApp _ = Nothing
209
[ "When checking the type of"
211
, "the inferred type"
212
, " " ++ printTree t'
213
, "does not match the expected type"
214
, " " ++ printTree s'
216
, " " ++ printTree s ++ " != " ++ printTree t
220
-- Matches an unfinished 'Check' whose most recent sub-call (the head of
-- the newest-first child list) is an 'EqualType'. Yields the checked
-- expression and the two compared types.
isCheck (Child (Check e _ Nothing) (Child (EqualType s t _) _ : _)) = Just (e,s,t)
223
(e,s',t') = case matchTrace isCheck tr of
224
Just (e,s,t) -> (e,s,t)
225
Nothing -> error $ "displayError: weird type mismatch"