Initial check-in
[nt-coerce.git] / GHC / NT / Plugin.hs
1 {-# LANGUAGE TupleSections #-}
2
3 module GHC.NT.Plugin where
4
5 import GhcPlugins
6 import MkId
7 import Kind
8
9 import Control.Monad
10 import Control.Applicative
11 import Data.Functor
12 import Data.Maybe
13
14 plugin :: Plugin
15 plugin = defaultPlugin {
16     installCoreToDos = install
17   }
18
19 install :: [CommandLineOption] -> [CoreToDo] -> CoreM [CoreToDo]
20 install _ xs = do
21     reinitializeGlobals
22     return $ CoreDoPasses [nt,nt2] : xs
23   where nt = CoreDoPluginPass "GHC.NT implementation" ntPass
24         nt2 = CoreDoPluginPass "GHC.NT.createNT implementation" nt2Pass
25
26 ntPass :: ModGuts -> CoreM ModGuts
27 ntPass g | moduleNameString (moduleName (mg_module g)) == "GHC.NT.Type" = do
28     let [oldTc] = mg_tcs g
29     nttc <- createNTTyCon (mg_module g) oldTc
30     tcs' <- mapM (replaceTyCon nttc) (mg_tcs g)
31
32     dflags <- getDynFlags
33     return $ g { mg_tcs = tcs' }
34 ntPass g | moduleNameString (moduleName (mg_module g)) == "GHC.NT" = do
35     nttc <- lookupNTTyCon (mg_module g)
36     binds' <- mapM (bind nttc) (mg_binds g)
37
38     dflags <- getDynFlags
39     dflags <- getDynFlags
40
41     return $ g { mg_binds = binds' }
42 ntPass g = return g
43
44 nt2Pass :: ModGuts -> CoreM ModGuts
45 nt2Pass = bindsOnlyPass $ mapM bind2
46
47 createNTTyCon :: Module -> TyCon -> CoreM TyCon
48 createNTTyCon mod oldTyCon = do
49     a <- createTyVar "a"
50     b <- createTyVar "b"
51     let arg_tys = map mkTyVarTy [a,b]
52     let tyConU = tyConUnique oldTyCon
53     dataConU <- getUniqueM
54     dataConWorkerU <- getUniqueM
55     dataConWrapperU <- getUniqueM
56     let cot = mkCoercionType (mkTyVarTy a) (mkTyVarTy b)
57         rett = mkTyConApp t' arg_tys
58         dct = mkForAllTys [a,b] $ mkFunTy cot rett
59         -- Have to use the original name, otherwise we get a 
60         -- urk! lookup local fingerprint
61         --tcn = mkExternalName tyConU mod (mkTcOcc "NT") noSrcSpan
62         tcn = tyConName oldTyCon
63         n = mkExternalName dataConU mod (mkDataOcc "NT") noSrcSpan
64         dataConWorkerN = mkSystemName dataConWorkerU (mkDataOcc "NT_work")
65         dataConWrapperN = mkSystemName dataConWrapperU (mkDataOcc "NT_wrap")
66         workId = mkGlobalId (DataConWrapId dc') dataConWorkerN dct vanillaIdInfo
67         dataConIds = mkDataConIds dataConWorkerN dataConWrapperN dc'
68         dc' = mkDataCon
69                 n
70                 False
71                 [ HsNoBang ]
72                 []
73                 [a,b]
74                 []
75                 []
76                 []
77                 [ cot ]
78                 rett
79                 t'
80                 []
81                 dataConIds -- (DCIds Nothing workId)
82         t' = mkAlgTyCon
83                tcn 
84                (mkArrowKinds [liftedTypeKind, liftedTypeKind] liftedTypeKind)
85                [a,b]
86                Nothing
87                []
88                (DataTyCon [dc'] False)
89                NoParentTyCon
90                NonRecursive
91                False
92     return t'
93
94 replaceTyCon :: TyCon -> TyCon -> CoreM TyCon
95 replaceTyCon nttc t 
96     | occNameString (nameOccName (tyConName t)) == "NT" = return nttc
97     | otherwise = return t
98
99 lookupNTTyCon :: Module -> CoreM TyCon
100 lookupNTTyCon mod = do
101     let packageId = modulePackageId mod -- HACK!
102     let ntTypeModule = mkModule packageId (mkModuleName "GHC.NT.Type")
103     nc <- getOrigNameCache
104
105     dflags <- getDynFlags
106     --putMsgS $ showSDoc dflags (ppr (moduleEnvKeys nc))
107     --putMsgS $ showSDoc dflags (ppr ntTypeModule)
108
109     -- HACK!
110     let ntTypeModule = last (moduleEnvKeys nc) -- Why does the other not work?
111     --putMsgS $ showSDoc dflags (ppr ntTypeModule)
112
113     let Just occEnv = lookupModuleEnv nc ntTypeModule
114
115     --putMsgS $ showSDoc dflags (ppr (occEnv)) 
116     -- let Just ntTyConName = lookupOccEnv occEnv (mkTcOccFS (fsLit "NT")) -- Why does this not work?
117     -- MORE HACKS!
118     let [ntTyConName] = occEnvElts occEnv
119     lookupTyCon ntTyConName
120
121 bind :: TyCon -> CoreBind -> CoreM CoreBind
122 bind nttc b@(NonRec v e) | getOccString v == "coerce" = do
123     dflags <- getDynFlags
124     au <- getUniqueM
125     bu <- getUniqueM
126     ntu <- getUniqueM
127     nttu <- getUniqueM
128     xu <- getUniqueM
129     cou <- getUniqueM
130     let a   = mkTyVar (mkSystemName au (mkTyVarOcc "a")) liftedTypeKind
131         b   = mkTyVar (mkSystemName bu (mkTyVarOcc "b")) liftedTypeKind
132         ntt = mkTyConApp nttc [mkTyVarTy a, mkTyVarTy b]
133         nt  = mkLocalVar VanillaId (mkSystemName ntu (mkVarOcc "nt")) ntt vanillaIdInfo
134         x   = mkLocalVar VanillaId (mkSystemName xu (mkVarOcc "b")) (mkTyVarTy a) vanillaIdInfo
135         co = mkCoVar (mkSystemName cou (mkTyVarOcc "co")) (mkCoercionType (mkTyVarTy a) (mkTyVarTy b))
136         [dc] = tyConDataCons nttc
137     let e' = Lam a $ Lam b $ Lam nt $ Lam x $
138                 Case (Var nt) nt (mkTyVarTy b)
139                     [(DataAlt dc, [co], 
140                         Cast (Var x) (CoVarCo co)
141                     )]
142     return (NonRec v e')
143
144 bind nttc b@(NonRec v e) | getOccString v == "sym" = do
145     dflags <- getDynFlags
146     a <- createTyVar "a"
147     b <- createTyVar "b"
148     ntu <- getUniqueM
149     nttu <- getUniqueM
150     ntt'u <- getUniqueM
151     cou <- getUniqueM
152     let ntt = mkTyConApp nttc [mkTyVarTy a, mkTyVarTy b]
153         ntt' = mkTyConApp nttc [mkTyVarTy b, mkTyVarTy a]
154         nt  = mkLocalVar VanillaId (mkSystemName ntu (mkVarOcc "nt")) ntt vanillaIdInfo
155         co = mkCoVar (mkSystemName cou (mkTyVarOcc "co")) (mkCoercionType (mkTyVarTy a) (mkTyVarTy b))
156         [dc] = tyConDataCons nttc
157     let e' = Lam a $ Lam b $ Lam nt $
158             Case (Var nt) nt ntt'
159                 [(DataAlt dc, [co], mkConApp dc
160                     [ Type (mkTyVarTy b)
161                     , Type (mkTyVarTy a)
162                     ,  Coercion (SymCo (CoVarCo co))
163                     ]
164                 )]
165     return (NonRec v e')
166
167 bind nttc b@(NonRec v e) | getOccString v == "listNT" = do
168     a <- createTyVar "a"
169     b <- createTyVar "b"
170     ntu <- getUniqueM
171     nttu <- getUniqueM
172     ntt'u <- getUniqueM
173     cou <- getUniqueM
174     let ntt = mkTyConApp nttc [mkTyVarTy a, mkTyVarTy b]
175         ntt' = mkTyConApp nttc [mkTyConApp listTyCon [mkTyVarTy a], mkTyConApp listTyCon [mkTyVarTy b]]
176         nt  = mkLocalVar VanillaId (mkSystemName ntu (mkVarOcc "nt")) ntt vanillaIdInfo
177         co = mkCoVar (mkSystemName cou (mkTyVarOcc "co")) (mkCoercionType (mkTyVarTy a) (mkTyVarTy b))
178         [dc] = tyConDataCons nttc
179     let e' = Lam a $ Lam b $ Lam nt $
180             Case (Var nt) nt ntt' 
181                 [(DataAlt dc, [co], mkConApp dc
182                     [ Type (mkTyConApp listTyCon [mkTyVarTy a])
183                     , Type (mkTyConApp listTyCon [mkTyVarTy b])
184                     ,  Coercion (TyConAppCo listTyCon [CoVarCo co])
185                     ]
186                 )]
187     return (NonRec v e')
188
189 bind _ b = do
190     dflags <- getDynFlags
191     --putMsgS $ showSDoc dflags (ppr b)
192     return b
193
194 bind2 :: CoreBind -> CoreM CoreBind
195 bind2 (NonRec v e) = NonRec v <$> traverse replaceCreateNT e
196 bind2 (Rec binds) = Rec <$> mapM (\(v,e) -> (\e' -> (v,e')) <$> traverse replaceCreateNT e) binds
197
198
199 replaceCreateNT :: CoreExpr -> CoreM (Maybe CoreExpr)
200 replaceCreateNT e@((App (App (Var f) (Type ta)) (Type tb)))
201     | getOccString f == "createNT" = do
202         -- We exepct ta to be a newtype of tb
203         (tc,tyArgs) <- case splitTyConApp_maybe ta of
204             Nothing -> error "not a type application"
205             Just (tc,tyArgs) -> return (tc,tyArgs)
206         (vars,coa) <- case unwrapNewTyCon_maybe tc of
207             Nothing -> error "not a newtype"
208             Just (vars,_,co) -> return (vars,co)
209
210         -- TODO: Check if all construtors are in scope
211         -- TODO: Check that the expanded type of a is actually b
212
213         dflags <- getDynFlags
214         --putMsgS $ showSDoc dflags (ppr e)
215         -- Extract the typcon from f's type
216         let nttc = tyConAppTyCon (exprType e)
217         let [dc] = tyConDataCons nttc
218         let e' = mkConApp dc [ Type ta, Type tb, Coercion (mkAxInstCo coa tyArgs)] :: CoreExpr
219         --putMsgS $ showSDoc dflags (ppr nttc)
220         --putMsgS $ showSDoc dflags (ppr (tyConDataCons nttc))
221         --putMsgS $ showSDoc dflags (ppr e')
222         return (Just e')
223     | otherwise = do
224         --putMsgS $ getOccString f
225         return Nothing
226 replaceCreateNT e = do
227     --dflags <- getDynFlags
228     --putMsgS $ showSDoc dflags (ppr e)
229     return Nothing
230
231 traverse :: (Functor m, Applicative m, Monad m) => (Expr a -> m (Maybe (Expr a))) -> Expr a -> m (Expr a)
232 traverse f e
233     = f' =<< case e of
234         Type t               -> return $ Type t
235         Coercion c           -> return $ Coercion c
236         Lit lit              -> return $ Lit lit
237         Var v                -> return $ Var v
238         App fun a            -> App <$> traverse f fun <*> traverse f a
239         Tick t e             -> Tick t <$> traverse f e
240         Cast e co            -> Cast <$> traverse f e <*> (return co)
241         Lam b e              -> Lam b <$> traverse f e
242         Let bind e           -> Let <$> traverseBind f bind <*> traverse f e
243         Case scrut bndr ty alts -> Case scrut bndr ty <$> mapM (\(a,b,c) -> (a,b,) <$> traverse f c) alts 
244     where f' x = do
245             r <- f x
246             return (fromMaybe x r)
247
248 traverseBind :: (Functor m, Applicative m, Monad m) => (Expr a -> m (Maybe (Expr a))) -> Bind a -> m (Bind a)
249 traverseBind f (NonRec b e) = NonRec b <$> traverse f e
250 traverseBind f (Rec l) = Rec <$> mapM (\(a,b) -> (a,) <$> traverse f b) l
251
252 createTyVar :: String -> CoreM TyVar
253 createTyVar name = do
254     u <- getUniqueM
255     return $ mkTyVar (mkSystemName u (mkTyVarOcc name)) liftedTypeKind