Combined-line mode checks bidirectionalizability
[darcs-mirror-sem_syn.git] / Type.hs
1 module Type where
2
3 import AST 
4
5 import Data.Graph 
6 import Control.Monad.State
7 import Control.Monad.Error
8 import Util
9 import Data.Maybe
10 import Data.List (nub,nubBy,union)
11
12
13 import Data.Map (Map)
14 import qualified Data.Map as Map
15
16 -- type erasure
17 eraseType (AST decls) =
18     AST $ map (\(Decl f ftype ps e) ->
19              Decl f FTUndet (map eraseTypeP ps) (eraseTypeE e)) decls 
20
21 eraseTypeP (PVar id t varname)   
22     = PVar id TUndet varname
23 eraseTypeP (PCon id t conname ps)
24     = PCon id TUndet conname (map eraseTypeP ps)
25
26 eraseTypeE (EVar id t varname) 
27     = EVar id TUndet varname 
28 eraseTypeE (ECon id t conname es)
29     = ECon id TUndet conname (map eraseTypeE es)
30 eraseTypeE (EFun id t funname es)
31     = EFun id TUndet funname (map eraseTypeE es)
32
33 eraseTypeT (TAST decls) = 
34     TAST $ map (\(TDecl f ps es bs) -> 
35                     TDecl f (map eraseTypeP ps) (map eraseTypeE es)bs) decls
36
37 -- type inference
38
39 initTMap :: [ (Name, FType) ]
40 initTMap =
41     [ (Name "Z",   TFun [] [] (TCon (Name "Nat") [])),
42       (Name "S",   TFun [] [TCon (Name "Nat") []] (TCon (Name "Nat") [])),
43       (Name "Nil",  TFun [0] [] (TCon (Name "List") [TVar 0])),
44       (Name "Cons", TFun [0] [TVar 0, TCon (Name "List") [TVar 0]] 
45                 (TCon (Name "List") [TVar 0])) ]
46
47
48 typeInference (AST decls) = 
49     let mAst = do { (decls',_,_) <- 
50                         foldr (\decls m -> 
51                                    do (rdecls, tMap,  icount)  <- m
52                                       (decls', tMap', icount') <- inferenceStep decls tMap icount
53                                       return $ (decls'++rdecls, tMap', icount')
54                               ) (return ([],initTMap,initIcount)) declss
55                   ; return $ AST decls' } 
56     in case mAst of 
57          Left s  -> error s 
58          Right a ->  a 
59     where
60       initIcount = 100 -- FIXME 
61       declss = 
62           let scc = stronglyConnComp callGraph 
63           in  reverse $ map (\x -> case x of 
64                            AcyclicSCC f  -> 
65                                filter (\(Decl g _ _ _) -> f == g) decls
66                            CyclicSCC  fs -> 
67                                filter (\(Decl g _ _ _) -> g `elem` fs) decls) scc
68 --      callGraph = map (\f -> (f,f,snub $ f:funCallsE e)) $
69 --                     grupBy $ map (\(Decl f _ _ _) -> f) decls
70       callGraph = 
71           let fMap  = Map.fromListWith union $ 
72                        map (\(Decl f _ _ e) -> (f,f:funCallsE e)) decls 
73               fMap' = Map.map (snub) fMap 
74           in map (\(f,fs) -> (f,f,fs)) $ Map.toList fMap'
75       funCallsE (EVar _ _ v)    = []
76       funCallsE (EFun _ _ f es) = f:concatMap funCallsE es 
77       funCallsE (ECon _ _ _ es) = concatMap funCallsE es 
78
79
80 inferenceStep decls tmap icount = 
81     let (decls0,  (tmpMap, icount0))  
82             = runState (makeInitConstr tmap decls) ([],icount)
83         (decls' , (constr, icount')) 
84             = runState (mapM (assignTypeVars tmpMap tmap) decls0) ([],icount0)
85     in
86       do { (tmpMap', etypeMap') <- solveConstr tmpMap constr
87          ; let decls'' = map (repl tmpMap' etypeMap') decls'
88          ; return (decls'', tmpMap' ++ tmap, icount') }
89         where 
90           repl tM cM (Decl f ftype ps e) =
91               Decl f (fromJust $ lookup f tM) (map replP ps) (replE e)
92               where
93                 replP (PVar id (TVar i) v)    
94                     = PVar id (fromJust $ lookup i cM) v
95                 replP (PCon id (TVar i) c ps)
96                     = PCon id (fromJust $ lookup i cM) c (map replP ps)
97                 replE (EVar id (TVar i) v)
98                     = EVar id (fromJust $ lookup i cM) v
99                 replE (ECon id (TVar i) c es)
100                     = ECon id (fromJust $ lookup i cM) c (map replE es)
101                 replE (EFun id (TVar i) c es)
102                     = EFun id (fromJust $ lookup i cM) c (map replE es)
103           extractConstr ds = map (\(Decl f t _ _) -> (f,t)) $
104                                 nubBy isSameFunc ds
105
106
107
108 solveConstr tmpMap constr 
109     = substStep constr (tmpMap, rearrange constr)
110     where 
111       introForAll (k,TFun _ ts t) =
112           let vs = snub $ varsT t ++ concatMap varsT ts 
113           in (k,TFun vs ts t)
114       rearrange constr = 
115           let vs = nub $ concatMap (\(t1,t2) -> varsT t1 ++ varsT t2) constr 
116           in map (\x -> (x,TVar x)) vs                
117       varsT (TVar i)    = [i]
118       varsT (TCon _ ts) = concatMap varsT ts 
119       varsT (TUndet)    = []
120       substStep [] (tM,cM) = return (map introForAll tM,cM)
121       substStep ((t,t'):cs) (tM,cM) =
122           do { subs <- unify t t'
123              ; substStep
124                   (performSubstC subs cs)
125                   (performSubstTM subs tM, performSubstCM subs cM) }
126       performSubstC subs cs
127           = map (\(t1,t2) -> (performSubstT subs t1, performSubstT subs t2)) cs
128       performSubstTM subs tM 
129           = map (\(k,v) -> (k, performSubstFT subs v)) tM
130       performSubstCM subs cM
131           = map (\(k,v) -> (k, performSubstT subs v)) cM
132       performSubstFT subs (TFun is ts t) 
133           = TFun [] (map (performSubstT subs) ts) (performSubstT subs t)
134       performSubstT subs (TUndet) = TUndet 
135       performSubstT subs (TVar i) = 
136           case lookup (TVar i) subs of 
137             Just t' -> t'
138             _       -> TVar i
139       performSubstT subs (TCon c ts) =
140           TCon c (map (performSubstT subs) ts)
141       unify :: Type -> Type -> Either String [ (Type, Type) ]
142       unify (TVar i) t | not (i `elem` varsT t) = return [ (TVar i, t) ]
143       unify t (TVar i) | not (i `elem` varsT t) = return [ (TVar i, t) ]
144       unify (TVar i) (TVar j) | i == j = return []
145       unify (TCon c ts) (TCon c' ts') | c == c' 
146           = do { ss <- mapM (uncurry unify) $ zip ts ts'
147                ; return $ concat ss }
148       unify t t' = throwError $ "Can't unify types: " ++ show ( ppr (t,t'))
149                  
150     
151                
152
153 makeInitConstr tmap decls =
154     do { mapM_ (\(Decl f t ps e) ->
155                       do { tmpMap <- getTmpMap 
156                          ; case t of
157                              FTUndet -> 
158                                  case lookup f tmpMap of 
159                                    Just t' -> 
160                                        return ()
161                                    _ -> 
162                                        do { i  <- newTypeVar 
163                                           ; is <- mapM (\_->newTypeVar) ps 
164                                           ; let t' = TFun [] (map TVar is) (TVar i) 
165                                           ; putTmpMap ((f,t'):tmpMap)
166                                           ; return ()  }
167                              _ -> 
168                                  putTmpMap ((f,t):tmpMap)}) $ 
169          (nubBy isSameFunc decls)
170        ; tmpMap <- getTmpMap
171        ; let decls' = map (\(Decl f t ps e) -> 
172                              Decl f (fromJust $ lookup f tmpMap) ps e) decls
173        ; return decls' }
174     where getTmpMap    = do { (tmpMap,i) <- get; return tmpMap }
175           putTmpMap tm = do { (_,i) <- get; put (tm,i) }
176           newTypeVar   = do { (tm,i) <- get; put (tm,i+1); return i}
177
178     
179                
180
181 assignTypeVars tmpMap typeMap (Decl fname ftype ps e) =
182     do ps' <- mapM assignTypeVarsP ps
183        e'  <- assignTypeVarsE      e
184        unifyFT ftype (TFun [] (map typeofP ps') (typeofE e'))
185        let vtp = concatMap vtMapP ps'
186        let vte = vtMapE e'
187        mapM_ (\(x,t) -> case (lookup x vte) of 
188                           { Just t' -> unifyT t t'; _ -> return ()}) vtp 
189        mapM_ (\(x,t) -> case (lookup x vte) of 
190                           { Just t' -> unifyT t t' }) vte 
191        return $ Decl fname ftype ps' e'
192     where
193       vtMapP (PVar _ t x)    = [(x,t)]
194       vtMapP (PCon _ _ c ps) = concatMap vtMapP ps 
195       vtMapE (EVar _ t x)    = [(x,t)]
196       vtMapE (ECon _ _ c es) = concatMap vtMapE es
197       vtMapE (EFun _ _ c es) = concatMap vtMapE es
198 --      newTypeVar :: State ( [(Type,Type)], Int ) Int
199       newTypeVar = do { (constr, icount) <- get
200                       ; put (constr, icount+1)
201                       ; return icount }
202       addConstr s t = do { (constr, icount) <- get
203                            ; put ((s,t):constr, icount) }
204       assignTypeVarsP (PVar id t v) = 
205           do { i <- newTypeVar
206              ; unifyT t (TVar i) 
207              ; return $ PVar id (TVar i) v } 
208       assignTypeVarsP (PCon id t c ps) = 
209           do { i <- newTypeVar
210              ; case lookup c typeMap of
211                  Just t' -> 
212                      do { ps' <- mapM assignTypeVarsP ps 
213                         ; unifyFT t' (TFun [] (map typeofP ps') (TVar i))
214                         ; unifyT  t  (TVar i)
215                         ; return $ PCon id (TVar i) c ps' }}
216       assignTypeVarsE (EVar id t v) = 
217           do { i <- newTypeVar 
218              ; unifyT t (TVar i)
219              ; return $ EVar id (TVar i) v }
220       assignTypeVarsE (ECon id t c es) =
221           do { i <- newTypeVar
222              ; case lookup c typeMap of
223                  Just t' -> 
224                      do { es' <- mapM assignTypeVarsE es 
225                         ; unifyFT t' (TFun [] (map typeofE es') (TVar i))
226                         ; unifyT  t  (TVar i)
227                         ; return $ ECon id (TVar i) c es' }
228                  Nothing -> fail $ "No type " ++ show c ++ " in type map"
229              }
230       assignTypeVarsE (EFun id t f es) =
231           do { i <- newTypeVar
232              ; case lookup f (typeMap ++ tmpMap)  of
233                  Just t' -> 
234                      do { es' <- mapM assignTypeVarsE es 
235                         ; unifyFT t' (TFun [] (map typeofE es') (TVar i))
236                         ; unifyT  t  (TVar i)
237                         ; return $ EFun id (TVar i) f es' }
238                  _ ->
239                      error $ (show f ++ " is not in " ++ show (typeMap ++ tmpMap))
240              }
241 --      unifyT :: Type -> Type -> State ([(Type,Type)],Int) ()
242       unifyT (TUndet) _ = return ()
243       unifyT _ (TUndet) = return ()
244       unifyT (TVar i) (TVar j) | i == j = return ()
245       unifyT t t'       = addConstr t t'
246       unifyFT (FTUndet) _ = return ()
247       unifyFT _ (FTUndet) = return ()
248       unifyFT t t' = 
249           do { s  <- escapeForAll t 
250              ; s' <- escapeForAll t'
251              ; case (s,s') of 
252                  (TFun _ ts t, TFun _ ts' t') ->
253                      mapM_ (uncurry unifyT) $ zip (t:ts) (t':ts') }
254       escapeForAll (TFun is ts t) =
255           do { is' <- mapM (\_ -> newTypeVar) is 
256              ; let ts' = map (replaceTVar (zip is is')) ts
257              ; let t'   = replaceTVar (zip is is') t 
258              ; return $ TFun [] ts' t'}
259       replaceTVar table TUndet = TUndet
260       replaceTVar table (TVar i) =
261           case lookup i table of
262             Just j -> TVar j 
263             _      -> TVar i
264       replaceTVar table (TCon t ts) =
265           TCon t (map (replaceTVar table) ts)
266
267                      
268