~ubuntu-branches/ubuntu/wily/agda/wily-proposed

« back to all changes in this revision

Viewing changes to src/prototyping/trace/TypeChecker.hs

  • Committer: Package Import Robot
  • Author(s): Iain Lane
  • Date: 2014-08-05 06:38:12 UTC
  • mfrom: (1.1.6)
  • Revision ID: package-import@ubuntu.com-20140805063812-io8e77niomivhd49
Tags: 2.4.0.2-1
* [6e140ac] Imported Upstream version 2.4.0.2
* [2049fc8] Update Build-Depends to match control
* [93dc4d4] Install the new primitives
* [e48f40f] Fix typo dev→doc

Show diffs side-by-side

added added

removed removed

Lines of Context:
1
 
{-# OPTIONS -fglasgow-exts #-}
2
 
module TypeChecker where
3
 
 
4
 
import Control.Monad.Reader
5
 
import Control.Monad.State
6
 
import Control.Monad.Error
7
 
 
8
 
import qualified Data.Map as Map
9
 
import Data.Maybe
10
 
 
11
 
import Lambda.Abs
12
 
import Lambda.Print
13
 
 
14
 
 
15
 
type Context   = Map.Map Ident Type
16
 
 
17
 
emptyContext :: Context
18
 
emptyContext = Map.empty
19
 
 
20
 
 
21
 
type Trace   = Current
22
 
type Sibling = Child
23
 
 
24
 
data Current = Current Call Parent [Sibling] [Child] | TopLevel [Child]
25
 
data Parent  = Parent  Call Parent [Sibling] | NoParent
26
 
data Child   = Child   Call [Child]
27
 
 
28
 
noTrace :: Trace
29
 
noTrace = TopLevel []
30
 
 
31
 
data Call = Infer Exp           (Maybe Type)
32
 
          | Check Exp Type      (Maybe ())
33
 
          | LookupVar Ident     (Maybe Type)
34
 
          | EqualType Type Type (Maybe ())
35
 
          | IsFunctionType Type (Maybe (Type,Type))
36
 
 
37
 
indent :: Int -> String -> String
38
 
indent _ "" = ""
39
 
indent n s  = unlines . concatMap ind . lines $ s
40
 
    where
41
 
        ind "" = []
42
 
        ind s = [replicate n ' ' ++ s]
43
 
 
44
 
topLevelView :: Trace -> Trace
45
 
topLevelView t@(TopLevel _)      = t
46
 
topLevelView t@(Current c _ _ _) = topLevelView $ updateCall c t
47
 
 
48
 
instance Show Current where
49
 
    show t = case topLevelView t of
50
 
        TopLevel cs -> unlines $ map show cs
51
 
 
52
 
instance Show Child where
53
 
    show (Child c cs) =
54
 
        indent 0 (show c) ++ indent 2 (unlines $ map show $ reverse cs)
55
 
 
56
 
instance Show Call where
57
 
    show e = case e of
58
 
        Infer e r          -> "infer " ++ show' e ++ " = " ++ res r
59
 
        Check e t r        -> "check " ++ show' e ++ " " ++ show' t ++ " " ++ nores r
60
 
        LookupVar x r      -> "lookupVar " ++ printTree x ++ " = " ++ res r
61
 
        EqualType s t r    -> show' s ++ " == " ++ show' t ++ " " ++ nores r
62
 
        IsFunctionType t r -> "isFunctionType " ++ show' t ++ " " ++ nores r
63
 
        where
64
 
            show' x = par $ printTree x
65
 
            par s
66
 
                | ' ' `elem` s = "(" ++ s ++ ")"
67
 
                | otherwise    = s
68
 
            res Nothing = "?"
69
 
            res (Just r) = printTree r
70
 
            nores Nothing = "?"
71
 
            nores (Just _) = ""
72
 
 
73
 
inProgress :: Call -> Bool
74
 
inProgress c = case c of
75
 
    Infer _ r          -> isNothing r
76
 
    Check _ _ r        -> isNothing r
77
 
    LookupVar _ r      -> isNothing r
78
 
    EqualType _ _ r    -> isNothing r
79
 
    IsFunctionType _ r -> isNothing r
80
 
 
81
 
newCall :: Call -> Trace -> Trace
82
 
newCall c (TopLevel cs)        = Current c NoParent cs []
83
 
newCall c (Current c' p ss cs) = Current c (Parent c' p ss) cs []
84
 
 
85
 
updateCall :: Call -> Trace -> Trace
86
 
updateCall c (TopLevel _)        = error $ "updateCall: no call in progress"
87
 
updateCall c (Current _ p ss cs) = case p of
88
 
    NoParent         -> TopLevel $ Child c cs : ss
89
 
    Parent c' p' ss' -> Current c' p' ss' $ Child c cs : ss
90
 
 
91
 
call :: (Maybe r -> Call) -> TC r -> TC r
92
 
call mkCall m = do
93
 
    modify $ newCall (mkCall Nothing)
94
 
    r <- m
95
 
    modify $ updateCall (mkCall $ Just r)
96
 
    return r
97
 
 
98
 
 
99
 
data ErrorMsg = UnboundVar Ident
100
 
              | TypeMismatch Type Type
101
 
              | NotFunctionType Type
102
 
              | InternalError String
103
 
 
104
 
instance Show ErrorMsg where
105
 
    show e = case e of
106
 
        UnboundVar x      -> "Unbound variable " ++ printTree x
107
 
        TypeMismatch s t  -> printTree s ++ " != " ++ printTree t
108
 
        NotFunctionType t -> printTree t ++ " is not a function type"
109
 
        InternalError s   -> "Internal error: " ++ s
110
 
 
111
 
type TypeError = (Trace, ErrorMsg)
112
 
 
113
 
instance Error TypeError where
114
 
    noMsg    = (noTrace,InternalError "")
115
 
    strMsg s = (noTrace,InternalError s)
116
 
 
117
 
 
118
 
newtype TC a = TC { unTC :: ReaderT Context (StateT Trace (Either TypeError)) a }
119
 
    deriving (MonadReader Context, MonadState Trace)
120
 
 
121
 
instance Monad TC where
122
 
    return = TC . return
123
 
    TC m >>= k = TC $ m >>= unTC . k
124
 
    fail = failure . InternalError
125
 
 
126
 
failure :: ErrorMsg -> TC a
127
 
failure msg = do
128
 
    tr <- get
129
 
    TC $ lift $ lift $ Left (tr, msg)
130
 
 
131
 
runTC :: TC a -> Either TypeError a
132
 
runTC (TC tc) = evalStateT (runReaderT tc emptyContext) noTrace
133
 
 
134
 
 
135
 
lookupVar :: Ident -> TC Type
136
 
lookupVar x = call (LookupVar x) $ do
137
 
    ctx <- ask
138
 
    case Map.lookup x ctx of
139
 
        Just t  -> return t
140
 
        Nothing -> failure $ UnboundVar x
141
 
 
142
 
addToContext :: Ident -> Type -> TC a -> TC a
143
 
addToContext x t = local $ Map.insert x t
144
 
 
145
 
infer :: Exp -> TC Type
146
 
infer e = call (Infer e) $ case e of
147
 
    Var x     -> lookupVar x
148
 
    Lit _     -> return $ ConT (Ident "Nat")
149
 
    App e1 e2 -> do
150
 
        t        <- infer e1
151
 
        (t1, t2) <- isFunctionType t
152
 
        check e2 t1
153
 
        return t2
154
 
    Lam x t e -> do
155
 
        t' <- addToContext x t $ infer e
156
 
        return $ FunT t t'
157
 
 
158
 
check :: Exp -> Type -> TC ()
159
 
check e t = call (Check e t) $ do
160
 
    t' <- infer e
161
 
    t === t'
162
 
 
163
 
(===) :: Type -> Type -> TC ()
164
 
s === t = call (EqualType s t) $ case (s,t) of
165
 
    (ConT x, ConT y)
166
 
        | x == y     -> return ()
167
 
    (FunT s1 t1, FunT s2 t2) -> do
168
 
        s1 === s2
169
 
        t1 === t2
170
 
    _ -> failure $ TypeMismatch s t
171
 
 
172
 
isFunctionType :: Type -> TC (Type, Type)
173
 
isFunctionType t = call (IsFunctionType t) $ case t of
174
 
    FunT t1 t2  -> return (t1, t2)
175
 
    _           -> failure $ NotFunctionType t
176
 
 
177
 
 
178
 
matchCall :: (Call -> Maybe a) -> Trace -> Maybe a
179
 
matchCall f = matchTrace f'
180
 
    where
181
 
        f' (Child c _) = f c
182
 
 
183
 
matchTrace :: (Child -> Maybe a) -> Trace -> Maybe a
184
 
matchTrace f (TopLevel _) = Nothing
185
 
matchTrace f t@(Current c _ _ cs) =
186
 
    f (Child c cs) `mplus` matchTrace f (updateCall c t)
187
 
 
188
 
displayError :: TypeError -> String
189
 
displayError (tr, e) = case e of
190
 
    InternalError s     -> unlines [ "internal error: " ++ s ]
191
 
    UnboundVar x        -> unlines [ "unbound variable " ++ printTree x ]
192
 
    NotFunctionType t   ->
193
 
        indent 0 $ unlines
194
 
                [ "the expression"
195
 
                , "  " ++ printTree f
196
 
                , "is applied to an argument, but has type"
197
 
                , "  " ++ printTree t
198
 
                , "which is not a function type"
199
 
                ] -- show tr
200
 
        where
201
 
            f = case matchCall isInferApp tr of
202
 
                    Just f  -> f
203
 
                    Nothing -> error "displayError: can't find function"
204
 
            isInferApp (Infer (App f _) Nothing) = Just f
205
 
            isInferApp _                         = Nothing
206
 
 
207
 
    TypeMismatch s t ->
208
 
        indent 0 $ unlines
209
 
            [ "When checking the type of"
210
 
            , "  " ++ printTree e
211
 
            , "the inferred type"
212
 
            , "  " ++ printTree t'
213
 
            , "does not match the expected type"
214
 
            , "  " ++ printTree s'
215
 
            , "because"
216
 
            , "  " ++ printTree s ++ " != " ++ printTree t
217
 
            ]
218
 
        -- show tr
219
 
        where
220
 
            isCheck (Child (Check e _ Nothing) (Child (EqualType s t _) _ : _)) = Just (e,s,t)
221
 
            isCheck _ = Nothing
222
 
 
223
 
            (e,s',t') = case matchTrace isCheck tr of
224
 
                Just (e,s,t) -> (e,s,t)
225
 
                Nothing      -> error $ "displayError: weird type mismatch"