consistenly use pretty unicode apostrophies
[darcs-mirror-sem_syn.git] / Type.hs
1 module Type (eraseType, eraseTypeT, typeInference) where
2
3 import AST 
4
5 import Data.Graph 
6 import Control.Monad.State
7 import Control.Monad.Error
8 import Util
9 import Data.Maybe
10 import Data.List (nub,nubBy,union)
11
12 import Data.Map (Map)
13 import qualified Data.Map as Map
14
15 -- type erasure
16 eraseType (AST decls) =
17     AST $ map (\(Decl f ftype ps e) ->
18              Decl f FTUndet (map eraseTypeP ps) (eraseTypeE e)) decls 
19
20 eraseTypeP (PVar id t varname)   
21     = PVar id TUndet varname
22 eraseTypeP (PCon id t conname ps)
23     = PCon id TUndet conname (map eraseTypeP ps)
24
25 eraseTypeE (EVar id t varname) 
26     = EVar id TUndet varname 
27 eraseTypeE (ECon id t conname es)
28     = ECon id TUndet conname (map eraseTypeE es)
29 eraseTypeE (EFun id t funname es)
30     = EFun id TUndet funname (map eraseTypeE es)
31
32 eraseTypeT (TAST decls) = 
33     TAST $ map (\(TDecl f ps es bs) -> 
34                     TDecl f (map eraseTypeP ps) (map eraseTypeE es)bs) decls
35
36 -- type inference
37
38 initTMap :: [ (Name, FType) ]
39 initTMap =
40     [ (Name "Z",   TFun [] [] (TCon (Name "Nat") [])),
41       (Name "S",   TFun [] [TCon (Name "Nat") []] (TCon (Name "Nat") [])),
42       (Name "Nil",  TFun [0] [] (TCon (Name "List") [TVar 0])),
43       (Name "Cons", TFun [0] [TVar 0, TCon (Name "List") [TVar 0]] 
44                 (TCon (Name "List") [TVar 0])) ]
45
46
47 typeInference (AST decls) = 
48     do { (decls',_,_) <- 
49              foldr (\decls m -> 
50                         do (rdecls, tMap,  icount)  <- m
51                            (decls', tMap', icount') <- inferenceStep decls tMap icount
52                            return $ (decls'++rdecls, tMap', icount')
53                    ) (return ([],initTMap,initIcount)) declss
54        ; return $ AST decls' } 
55     where
56       initIcount = max 1 ((foldr max 0 $ map maxTVarCount decls)+1) -- FIXME 
57       declss = 
58           let scc = stronglyConnComp callGraph 
59           in reverse $ map (\x -> case x of 
60                            AcyclicSCC f  -> 
61                                filter (\(Decl g _ _ _) -> f == g) decls
62                            CyclicSCC  fs -> 
63                                filter (\(Decl g _ _ _) -> g `elem` fs) decls) scc
64 --      callGraph = map (\f -> (f,f,snub $ f:funCallsE e)) $
65 --                     grupBy $ map (\(Decl f _ _ _) -> f) decls
66       callGraph = 
67           let fMap  = Map.fromListWith union $ 
68                        map (\(Decl f _ _ e) -> (f,f:funCallsE e)) decls 
69               fMap' = Map.map (snub) fMap 
70           in map (\(f,fs) -> (f,f,fs)) $ Map.toList fMap'
71       funCallsE (EVar _ _ v)    = []
72       funCallsE (EFun _ _ f es) = f:concatMap funCallsE es 
73       funCallsE (ECon _ _ _ es) = concatMap funCallsE es 
74
75
76 maxTVarCount (Decl f t ps e) =
77     (maxTVarFT t) 
78     `max` (foldr max 0 $ map maxTVarP ps) 
79     `max` (maxTVarE e)
80     where
81       maxTVarFT FTUndet        = 0 
82       maxTVarFT (TFun is ts t) = foldr max 0 is
83       maxTVarP  (PVar _ t _)    = fromT t
84       maxTVarP  (PCon _ t _ ps) = fromT t `max` 
85                                   (foldr max 0 $ map maxTVarP ps)
86       maxTVarE  (EVar _ t _)    = fromT t 
87       maxTVarE  (EFun _ t _ es) = fromT t `max`
88                                   (foldr max 0 $ map maxTVarE es)
89       maxTVarE  (ECon _ t _ es) = fromT t `max`
90                                   (foldr max 0 $ map maxTVarE es)
91       fromT (TUndet) = 0
92       fromT (TVar i) = i 
93       fromT (TCon _ ts) = 
94           foldr max 0 $ map fromT ts 
95
96 inferenceStep decls tmap icount = 
97       do { (decls0,  (tmpMap, icount0)) <- runStateT (makeInitConstr tmap decls) ([],icount)
98          ; (decls' , (constr, icount')) <- runStateT (mapM (assignTypeVars tmpMap tmap) decls0) ([],icount0)
99          ; (tmpMap', etypeMap') <- solveConstr tmpMap constr
100          ; let decls'' = map (repl tmpMap' etypeMap') decls'
101          ; return (decls'', tmpMap' ++ tmap, icount') }
102         where 
103           repl tM cM (Decl f ftype ps e) =
104               Decl f (fromJust $ lookup f tM) (map replP ps) (replE e)
105               where
106                 replP (PVar id (TVar i) v)    
107                     = PVar id (fromJust $ lookup i cM) v
108                 replP (PCon id (TVar i) c ps)
109                     = PCon id (fromJust $ lookup i cM) c (map replP ps)
110                 replE (EVar id (TVar i) v)
111                     = EVar id (fromJust $ lookup i cM) v
112                 replE (ECon id (TVar i) c es)
113                     = ECon id (fromJust $ lookup i cM) c (map replE es)
114                 replE (EFun id (TVar i) c es)
115                     = EFun id (fromJust $ lookup i cM) c (map replE es)
116           extractConstr ds = map (\(Decl f t _ _) -> (f,t)) $
117                                 nubBy isSameFunc ds
118
119
120
121 solveConstr tmpMap constr 
122     = substStep constr (tmpMap, rearrange constr)
123     where 
124       introForAll (k,TFun _ ts t) =
125           let vs = snub $ varsT t ++ concatMap varsT ts 
126           in (k,TFun vs ts t)
127       rearrange constr = 
128           let vs = nub $ concatMap (\(t1,t2) -> varsT t1 ++ varsT t2) constr 
129           in map (\x -> (x,TVar x)) vs                
130       varsT (TVar i)    = [i]
131       varsT (TCon _ ts) = concatMap varsT ts 
132       varsT (TUndet)    = []
133       substStep [] (tM,cM) = return (map introForAll tM,cM)
134       substStep ((t,t'):cs) (tM,cM) =
135           do { subs <- unify t t'
136              ; substStep
137                   (performSubstC subs cs)
138                   (performSubstTM subs tM, performSubstCM subs cM) }
139       performSubstC subs cs
140           = map (\(t1,t2) -> (performSubstT subs t1, performSubstT subs t2)) cs
141       performSubstTM subs tM 
142           = map (\(k,v) -> (k, performSubstFT subs v)) tM
143       performSubstCM subs cM
144           = map (\(k,v) -> (k, performSubstT subs v)) cM
145       performSubstFT subs (TFun is ts t) 
146           = TFun [] (map (performSubstT subs) ts) (performSubstT subs t)
147       performSubstT subs (TUndet) = TUndet 
148       performSubstT subs (TVar i) = 
149           case lookup (TVar i) subs of 
150             Just t' -> t'
151             _       -> TVar i
152       performSubstT subs (TCon c ts) =
153           TCon c (map (performSubstT subs) ts)
154       unify :: Type -> Type -> Either String [ (Type, Type) ]
155       unify (TVar i) t | not (i `elem` varsT t) = return [ (TVar i, t) ]
156       unify t (TVar i) | not (i `elem` varsT t) = return [ (TVar i, t) ]
157       unify (TVar i) (TVar j) | i == j = return []
158       unify (TCon c ts) (TCon c' ts') | c == c' 
159           = do { ss <- mapM (uncurry unify) $ zip ts ts'
160                ; return $ concat ss }
161       unify t t' = throwError $ "Can't unify types: " ++ show ( ppr (t,t'))
162                  
163     
164                
165
166 makeInitConstr tmap decls =
167     do { mapM_ (\(Decl f t ps e) ->
168                       do { tmpMap <- getTmpMap 
169                          ; case t of
170                              FTUndet -> 
171                                  case lookup f tmpMap of 
172                                    Just t' -> 
173                                        return ()
174                                    _ -> 
175                                        do { i  <- newTypeVar 
176                                           ; is <- mapM (\_->newTypeVar) ps 
177                                           ; let t' = TFun [] (map TVar is) (TVar i) 
178                                           ; putTmpMap ((f,t'):tmpMap)
179                                           ; return ()  }
180                              _ -> 
181                                  putTmpMap ((f,t):tmpMap)}) $ 
182          (nubBy isSameFunc decls)
183        ; tmpMap <- getTmpMap
184        ; let decls' = map (\(Decl f t ps e) -> 
185                              Decl f (fromJust $ lookup f tmpMap) ps e) decls
186        ; return decls' }
187     where getTmpMap    = do { (tmpMap,i) <- get; return tmpMap }
188           putTmpMap tm = do { (_,i) <- get; put (tm,i) }
189           newTypeVar   = do { (tm,i) <- get; put (tm,i+1); return i}
190
191     
192                
193
194 assignTypeVars tmpMap typeMap (Decl fname ftype ps e) =
195     do ps' <- mapM assignTypeVarsP ps
196        e'  <- assignTypeVarsE      e
197        unifyFT ftype (TFun [] (map typeofP ps') (typeofE e'))
198        let vtp = concatMap vtMapP ps'
199        let vte = vtMapE e'
200        mapM_ (\(x,t) -> case (lookup x vte) of 
201                           { Just t' -> unifyT t t'; _ -> return ()}) vtp 
202        mapM_ (\(x,t) -> case (lookup x vte) of 
203                           { Just t' -> unifyT t t' }) vte 
204        return $ Decl fname ftype ps' e'
205     where
206       vtMapP (PVar _ t x)    = [(x,t)]
207       vtMapP (PCon _ _ c ps) = concatMap vtMapP ps 
208       vtMapE (EVar _ t x)    = [(x,t)]
209       vtMapE (ECon _ _ c es) = concatMap vtMapE es
210       vtMapE (EFun _ _ c es) = concatMap vtMapE es
211 --      newTypeVar :: State ( [(Type,Type)], Int ) Int
212       newTypeVar = do { (constr, icount) <- get
213                       ; put (constr, icount+1)
214                       ; return icount }
215       addConstr s t = do { (constr, icount) <- get
216                            ; put ((s,t):constr, icount) }
217       assignTypeVarsP (PVar id t v) = 
218           do { i <- newTypeVar
219              ; unifyT t (TVar i) 
220              ; return $ PVar id (TVar i) v } 
221       assignTypeVarsP (PCon id t c ps) = 
222           do { i <- newTypeVar
223              ; case lookup c typeMap of
224                  Just t' -> 
225                      do { ps' <- mapM assignTypeVarsP ps 
226                         ; unifyFT t' (TFun [] (map typeofP ps') (TVar i))
227                         ; unifyT  t  (TVar i)
228                         ; return $ PCon id (TVar i) c ps' }
229                  Nothing -> fail $ "No entry " ++ show c ++ " in type map"
230              }
231       assignTypeVarsE (EVar id t v) = 
232           do { i <- newTypeVar 
233              ; unifyT t (TVar i)
234              ; return $ EVar id (TVar i) v }
235       assignTypeVarsE (ECon id t c es) =
236           do { i <- newTypeVar
237              ; case lookup c typeMap of
238                  Just t' -> 
239                      do { es' <- mapM assignTypeVarsE es 
240                         ; unifyFT t' (TFun [] (map typeofE es') (TVar i))
241                         ; unifyT  t  (TVar i)
242                         ; return $ ECon id (TVar i) c es' }
243                  Nothing -> fail $ "No entry " ++ show c ++ " in type map"
244              }
245       assignTypeVarsE (EFun id t f es) =
246           do { i <- newTypeVar
247              ; case lookup f (typeMap ++ tmpMap)  of
248                  Just t' -> 
249                      do { es' <- mapM assignTypeVarsE es 
250                         ; unifyFT t' (TFun [] (map typeofE es') (TVar i))
251                         ; unifyT  t  (TVar i)
252                         ; return $ EFun id (TVar i) f es' }
253                  _ ->
254                      fail $ (show f ++ " is not in " ++ show (typeMap ++ tmpMap))
255              }
256 --      unifyT :: Type -> Type -> State ([(Type,Type)],Int) ()
257       unifyT (TUndet) _ = return ()
258       unifyT _ (TUndet) = return ()
259       unifyT (TVar i) (TVar j) | i == j = return ()
260       unifyT t t'       = addConstr t t'
261       unifyFT (FTUndet) _ = return ()
262       unifyFT _ (FTUndet) = return ()
263       unifyFT t t' = 
264           do { s  <- escapeForAll t 
265              ; s' <- escapeForAll t'
266              ; case (s,s') of 
267                  (TFun _ ts t, TFun _ ts' t') ->
268                      mapM_ (uncurry unifyT) $ zip (t:ts) (t':ts') }
269       escapeForAll (TFun is ts t) =
270           do { is' <- mapM (\_ -> newTypeVar) is 
271              ; let ts' = map (replaceTVar (zip is is')) ts
272              ; let t'   = replaceTVar (zip is is') t 
273              ; return $ TFun [] ts' t'}
274       replaceTVar table TUndet = TUndet
275       replaceTVar table (TVar i) =
276           case lookup i table of
277             Just j -> TVar j 
278             _      -> TVar i
279       replaceTVar table (TCon t ts) =
280           TCon t (map (replaceTVar table) ts)
281
282                      
283