[Init] Implementation of "Combining Syntactic and Semantic Bidirectionalization"
[darcs-mirror-sem_syn.git] / Type.hs
1 module Type where
2
3 import AST 
4
5 import Data.Graph 
6 import Control.Monad.State
7 import Control.Monad.Error
8 import Util
9 import Data.Maybe
10 import Data.List (nub,nubBy,union)
11
12
13 import Data.Map (Map)
14 import qualified Data.Map as Map
15
16
17 -- type inference
18
19 initTMap :: [ (Name, FType) ]
20 initTMap =
21     [ (Name "Z",   TFun [] [] (TCon (Name "Nat") [])),
22       (Name "S",   TFun [] [TCon (Name "Nat") []] (TCon (Name "Nat") [])),
23       (Name "Nil",  TFun [0] [] (TCon (Name "List") [TVar 0])),
24       (Name "Cons", TFun [0] [TVar 0, TCon (Name "List") [TVar 0]] 
25                 (TCon (Name "List") [TVar 0])) ]
26
27
28 typeInference (AST decls) = 
29     let mAst = do { (decls',_,_) <- 
30                         foldr (\decls m -> 
31                                    do (rdecls, tMap,  icount)  <- m
32                                       (decls', tMap', icount') <- inferenceStep decls tMap icount
33                                       return $ (decls'++rdecls, tMap', icount')
34                               ) (return ([],initTMap,initIcount)) declss
35                   ; return $ AST decls' } 
36     in case mAst of 
37          Left s  -> error s 
38          Right a ->  a 
39     where
40       initIcount = 100 -- FIXME 
41       declss = 
42           let scc = stronglyConnComp callGraph 
43           in  reverse $ map (\x -> case x of 
44                            AcyclicSCC f  -> 
45                                filter (\(Decl g _ _ _) -> f == g) decls
46                            CyclicSCC  fs -> 
47                                filter (\(Decl g _ _ _) -> g `elem` fs) decls) scc
48 --      callGraph = map (\f -> (f,f,snub $ f:funCallsE e)) $
49 --                     grupBy $ map (\(Decl f _ _ _) -> f) decls
50       callGraph = 
51           let fMap  = Map.fromListWith union $ 
52                        map (\(Decl f _ _ e) -> (f,f:funCallsE e)) decls 
53               fMap' = Map.map (snub) fMap 
54           in map (\(f,fs) -> (f,f,fs)) $ Map.toList fMap'
55       funCallsE (EVar _ _ v)    = []
56       funCallsE (EFun _ _ f es) = f:concatMap funCallsE es 
57       funCallsE (ECon _ _ _ es) = concatMap funCallsE es 
58
59
60 inferenceStep decls tmap icount = 
61     let (decls0,  (tmpMap, icount0))  
62             = runState (makeInitConstr tmap decls) ([],icount)
63         (decls' , (constr, icount')) 
64             = runState (mapM (assignTypeVars tmpMap tmap) decls0) ([],icount0)
65     in
66       do { (tmpMap', etypeMap') <- solveConstr tmpMap constr
67          ; let decls'' = map (repl tmpMap' etypeMap') decls'
68          ; return (decls'', tmpMap' ++ tmap, icount') }
69         where 
70           repl tM cM (Decl f ftype ps e) =
71               Decl f (fromJust $ lookup f tM) (map replP ps) (replE e)
72               where
73                 replP (PVar id (TVar i) v)    
74                     = PVar id (fromJust $ lookup i cM) v
75                 replP (PCon id (TVar i) c ps)
76                     = PCon id (fromJust $ lookup i cM) c (map replP ps)
77                 replE (EVar id (TVar i) v)
78                     = EVar id (fromJust $ lookup i cM) v
79                 replE (ECon id (TVar i) c es)
80                     = ECon id (fromJust $ lookup i cM) c (map replE es)
81                 replE (EFun id (TVar i) c es)
82                     = EFun id (fromJust $ lookup i cM) c (map replE es)
83           extractConstr ds = map (\(Decl f t _ _) -> (f,t)) $
84                                 nubBy isSameFunc ds
85
86
87
88 solveConstr tmpMap constr 
89     = substStep constr (tmpMap, rearrange constr)
90     where 
91       introForAll (k,TFun _ ts t) =
92           let vs = snub $ varsT t ++ concatMap varsT ts 
93           in (k,TFun vs ts t)
94       rearrange constr = 
95           let vs = nub $ concatMap (\(t1,t2) -> varsT t1 ++ varsT t2) constr 
96           in map (\x -> (x,TVar x)) vs                
97       varsT (TVar i)    = [i]
98       varsT (TCon _ ts) = concatMap varsT ts 
99       varsT (TUndet)    = []
100       substStep [] (tM,cM) = return (map introForAll tM,cM)
101       substStep ((t,t'):cs) (tM,cM) =
102           do { subs <- unify t t'
103              ; substStep
104                   (performSubstC subs cs)
105                   (performSubstTM subs tM, performSubstCM subs cM) }
106       performSubstC subs cs
107           = map (\(t1,t2) -> (performSubstT subs t1, performSubstT subs t2)) cs
108       performSubstTM subs tM 
109           = map (\(k,v) -> (k, performSubstFT subs v)) tM
110       performSubstCM subs cM
111           = map (\(k,v) -> (k, performSubstT subs v)) cM
112       performSubstFT subs (TFun is ts t) 
113           = TFun [] (map (performSubstT subs) ts) (performSubstT subs t)
114       performSubstT subs (TUndet) = TUndet 
115       performSubstT subs (TVar i) = 
116           case lookup (TVar i) subs of 
117             Just t' -> t'
118             _       -> TVar i
119       performSubstT subs (TCon c ts) =
120           TCon c (map (performSubstT subs) ts)
121       unify :: Type -> Type -> Either String [ (Type, Type) ]
122       unify (TVar i) t | not (i `elem` varsT t) = return [ (TVar i, t) ]
123       unify t (TVar i) | not (i `elem` varsT t) = return [ (TVar i, t) ]
124       unify (TVar i) (TVar j) | i == j = return []
125       unify (TCon c ts) (TCon c' ts') | c == c' 
126           = do { ss <- mapM (uncurry unify) $ zip ts ts'
127                ; return $ concat ss }
128       unify t t' = throwError $ "Can't unify types: " ++ show ( ppr (t,t'))
129                  
130     
131                
132
133 makeInitConstr tmap decls =
134     do { mapM_ (\(Decl f t ps e) ->
135                       do { tmpMap <- getTmpMap 
136                          ; case t of
137                              FTUndet -> 
138                                  case lookup f tmpMap of 
139                                    Just t' -> 
140                                        return ()
141                                    _ -> 
142                                        do { i  <- newTypeVar 
143                                           ; is <- mapM (\_->newTypeVar) ps 
144                                           ; let t' = TFun [] (map TVar is) (TVar i) 
145                                           ; putTmpMap ((f,t'):tmpMap)
146                                           ; return ()  }
147                              _ -> 
148                                  putTmpMap ((f,t):tmpMap)}) $ 
149          (nubBy isSameFunc decls)
150        ; tmpMap <- getTmpMap
151        ; let decls' = map (\(Decl f t ps e) -> 
152                              Decl f (fromJust $ lookup f tmpMap) ps e) decls
153        ; return decls' }
154     where getTmpMap    = do { (tmpMap,i) <- get; return tmpMap }
155           putTmpMap tm = do { (_,i) <- get; put (tm,i) }
156           newTypeVar   = do { (tm,i) <- get; put (tm,i+1); return i}
157
158     
159                
160
161 assignTypeVars tmpMap typeMap (Decl fname ftype ps e) =
162     do ps' <- mapM assignTypeVarsP ps
163        e'  <- assignTypeVarsE      e
164        unifyFT ftype (TFun [] (map typeofP ps') (typeofE e'))
165        let vtp = concatMap vtMapP ps'
166        let vte = vtMapE e'
167        mapM_ (\(x,t) -> case (lookup x vte) of 
168                           { Just t' -> unifyT t t'; _ -> return ()}) vtp 
169        mapM_ (\(x,t) -> case (lookup x vte) of 
170                           { Just t' -> unifyT t t' }) vte 
171        return $ Decl fname ftype ps' e'
172     where
173       vtMapP (PVar _ t x)    = [(x,t)]
174       vtMapP (PCon _ _ c ps) = concatMap vtMapP ps 
175       vtMapE (EVar _ t x)    = [(x,t)]
176       vtMapE (ECon _ _ c es) = concatMap vtMapE es
177       vtMapE (EFun _ _ c es) = concatMap vtMapE es
178 --      newTypeVar :: State ( [(Type,Type)], Int ) Int
179       newTypeVar = do { (constr, icount) <- get
180                       ; put (constr, icount+1)
181                       ; return icount }
182       addConstr s t = do { (constr, icount) <- get
183                            ; put ((s,t):constr, icount) }
184       assignTypeVarsP (PVar id t v) = 
185           do { i <- newTypeVar
186              ; unifyT t (TVar i) 
187              ; return $ PVar id (TVar i) v } 
188       assignTypeVarsP (PCon id t c ps) = 
189           do { i <- newTypeVar
190              ; case lookup c typeMap of
191                  Just t' -> 
192                      do { ps' <- mapM assignTypeVarsP ps 
193                         ; unifyFT t' (TFun [] (map typeofP ps') (TVar i))
194                         ; unifyT  t  (TVar i)
195                         ; return $ PCon id (TVar i) c ps' }}
196       assignTypeVarsE (EVar id t v) = 
197           do { i <- newTypeVar 
198              ; unifyT t (TVar i)
199              ; return $ EVar id (TVar i) v }
200       assignTypeVarsE (ECon id t c es) =
201           do { i <- newTypeVar
202              ; case lookup c typeMap of
203                  Just t' -> 
204                      do { es' <- mapM assignTypeVarsE es 
205                         ; unifyFT t' (TFun [] (map typeofE es') (TVar i))
206                         ; unifyT  t  (TVar i)
207                         ; return $ ECon id (TVar i) c es' }}
208       assignTypeVarsE (EFun id t f es) =
209           do { i <- newTypeVar
210              ; case lookup f (typeMap ++ tmpMap)  of
211                  Just t' -> 
212                      do { es' <- mapM assignTypeVarsE es 
213                         ; unifyFT t' (TFun [] (map typeofE es') (TVar i))
214                         ; unifyT  t  (TVar i)
215                         ; return $ EFun id (TVar i) f es' }
216                  _ ->
217                      error $ (show f ++ " is not in " ++ show (typeMap ++ tmpMap))
218              }
219 --      unifyT :: Type -> Type -> State ([(Type,Type)],Int) ()
220       unifyT (TUndet) _ = return ()
221       unifyT _ (TUndet) = return ()
222       unifyT (TVar i) (TVar j) | i == j = return ()
223       unifyT t t'       = addConstr t t'
224       unifyFT (FTUndet) _ = return ()
225       unifyFT _ (FTUndet) = return ()
226       unifyFT t t' = 
227           do { s  <- escapeForAll t 
228              ; s' <- escapeForAll t'
229              ; case (s,s') of 
230                  (TFun _ ts t, TFun _ ts' t') ->
231                      mapM_ (uncurry unifyT) $ zip (t:ts) (t':ts') }
232       escapeForAll (TFun is ts t) =
233           do { is' <- mapM (\_ -> newTypeVar) is 
234              ; let ts' = map (replaceTVar (zip is is')) ts
235              ; let t'   = replaceTVar (zip is is') t 
236              ; return $ TFun [] ts' t'}
237       replaceTVar table TUndet = TUndet
238       replaceTVar table (TVar i) =
239           case lookup i table of
240             Just j -> TVar j 
241             _      -> TVar i
242       replaceTVar table (TCon t ts) =
243           TCon t (map (replaceTVar table) ts)
244
245                      
246