22cc6cb9222f5e2fa51daeb28322c4c62fe037b6
[nt-coerce.git] / GHC / NT / Plugin.hs
1 {-# LANGUAGE TupleSections #-}
2
3 module GHC.NT.Plugin where
4
5 import GhcPlugins
6 import MkId
7 import Pair
8 import Kind
9
10 import Control.Monad
11 import Control.Applicative
12 import Data.Functor
13 import Data.Maybe
14
15 plugin :: Plugin
16 plugin = defaultPlugin {
17     installCoreToDos = install
18   }
19
20 install :: [CommandLineOption] -> [CoreToDo] -> CoreM [CoreToDo]
21 install _ xs = do
22     reinitializeGlobals
23     return $ CoreDoPasses [nt,nt2] : xs
24   where nt = CoreDoPluginPass "GHC.NT implementation" ntPass
25         nt2 = CoreDoPluginPass "GHC.NT.createNT implementation" nt2Pass
26
27 ntPass :: ModGuts -> CoreM ModGuts
28 ntPass g | moduleNameString (moduleName (mg_module g)) == "GHC.NT.Type" = do
29     let [oldTc] = mg_tcs g
30     nttc <- createNTTyCon (mg_module g) oldTc
31     tcs' <- mapM (replaceTyCon nttc) (mg_tcs g)
32
33     return $ g { mg_tcs = tcs' }
34 ntPass g | moduleNameString (moduleName (mg_module g)) == "GHC.NT" = do
35     nttc <- lookupNTTyCon (mg_rdr_env g) (mg_module g)
36     binds' <- mapM (bind nttc) (mg_binds g)
37
38     return $ g { mg_binds = binds' }
39 ntPass g = return g
40
41 nt2Pass :: ModGuts -> CoreM ModGuts
42 nt2Pass = bindsOnlyPass $ mapM (traverseBind replaceCreateNT)
43
44 createNTTyCon :: Module -> TyCon -> CoreM TyCon
45 createNTTyCon mod oldTyCon = do
46     a <- createTyVar "a"
47     b <- createTyVar "b"
48     let arg_tys = map mkTyVarTy [a,b]
49     let tyConU = tyConUnique oldTyCon
50     dataConU <- getUniqueM
51     dataConWorkerU <- getUniqueM
52     dataConWrapperU <- getUniqueM
53     let cot = mkCoercionType (mkTyVarTy a) (mkTyVarTy b)
54         rett = mkTyConApp t' arg_tys
55         dct = mkForAllTys [a,b] $ mkFunTy cot rett
56         -- Have to use the original name, otherwise we get a 
57         -- urk! lookup local fingerprint
58         --tcn = mkExternalName tyConU mod (mkTcOcc "NT") noSrcSpan
59         tcn = tyConName oldTyCon
60         n = mkExternalName dataConU mod (mkDataOcc "NT") noSrcSpan
61         dataConWorkerN = mkSystemName dataConWorkerU (mkDataOcc "NT_work")
62         dataConWrapperN = mkSystemName dataConWrapperU (mkDataOcc "NT_wrap")
63         workId = mkGlobalId (DataConWrapId dc') dataConWorkerN dct vanillaIdInfo
64         dataConIds = mkDataConIds dataConWorkerN dataConWrapperN dc'
65         dc' = mkDataCon
66                 n
67                 False
68                 [ HsNoBang ]
69                 []
70                 [a,b]
71                 []
72                 []
73                 []
74                 [ cot ]
75                 rett
76                 t'
77                 []
78                 dataConIds -- (DCIds Nothing workId)
79         t' = mkAlgTyCon
80                tcn 
81                (mkArrowKinds [liftedTypeKind, liftedTypeKind] liftedTypeKind)
82                [a,b]
83                Nothing
84                []
85                (DataTyCon [dc'] False)
86                NoParentTyCon
87                NonRecursive
88                False
89     return t'
90
91 replaceTyCon :: TyCon -> TyCon -> CoreM TyCon
92 replaceTyCon nttc t 
93     | occNameString (nameOccName (tyConName t)) == "NT" = return nttc
94     | otherwise = return t
95
96 lookupNTTyCon :: GlobalRdrEnv -> Module -> CoreM TyCon
97 lookupNTTyCon env mod = do
98     let packageId = modulePackageId mod -- HACK!
99     let ntTypeModule = mkModule packageId (mkModuleName "GHC.NT.Type")
100     let rdrName = mkRdrQual (mkModuleName "GHC.NT.Type") (mkTcOcc "NT")
101
102     let e' = head (head (occEnvElts env)) -- HACK
103     
104     {-
105     putMsg (ppr e')
106     putMsg (ppr rdrName)
107     putMsg (ppr (lookupGRE_RdrName rdrName env))
108     putMsg (ppr (lookupGRE_RdrName (nameRdrName (gre_name e')) env))
109     
110     let [e] = lookupGRE_RdrName rdrName env
111     -}
112
113     let n = gre_name e'
114     lookupTyCon n
115
116
117 bind :: TyCon -> CoreBind -> CoreM CoreBind
118 bind nttc b@(NonRec v e) | getOccString v == "coerce" = do
119     NonRec v <$> do
120     tyLam "a" $ \a -> do
121     tyLam "b" $ \b -> do
122     lamNT nttc "co" (mkTyVarTy a) (mkTyVarTy b) $ \co -> do 
123     lam "x" (mkTyVarTy a) $ \x -> do
124     return $ Cast (Var x) (CoVarCo co)
125
126 bind nttc b@(NonRec v e) | getOccString v == "refl" = do
127     NonRec v <$> do
128     tyLam "a" $ \a ->
129         conNT nttc $
130             return $ Refl (mkTyVarTy a)
131
132 bind nttc b@(NonRec v e) | getOccString v == "sym" = do
133     NonRec v <$> do
134     tyLam "a" $ \a -> do
135     tyLam "b" $ \b -> do
136     lamNT nttc "co" (mkTyVarTy a) (mkTyVarTy b) $ \co -> do
137     conNT nttc $ do
138     return $ SymCo (CoVarCo co)
139
140 bind nttc b@(NonRec v e) | getOccString v == "trans" = do
141     NonRec v <$> do
142     tyLam "a" $ \a -> do
143     tyLam "b" $ \b -> do
144     tyLam "c" $ \c -> do
145     lamNT nttc "co1" (mkTyVarTy a) (mkTyVarTy b) $ \co1 -> do
146     lamNT nttc "co2" (mkTyVarTy b) (mkTyVarTy c) $ \co2 -> do
147     conNT nttc $ do
148     return $ TransCo (CoVarCo co1) (CoVarCo co2)
149
150 bind nttc b@(NonRec v e) | getOccString v == "listNT" = do
151     NonRec v <$> do
152     tyLam "a" $ \a -> do
153     tyLam "b" $ \b -> do
154     lamNT nttc "co" (mkTyVarTy a) (mkTyVarTy b) $ \co -> do
155     conNT nttc $ do
156     return $ TyConAppCo listTyCon [CoVarCo co]
157
158 bind _ b = do
159     --putMsg (ppr b)
160     return b
161
162 replaceCreateNT :: CoreExpr -> CoreM (Maybe CoreExpr)
163 replaceCreateNT e@((App (App (Var f) (Type ta)) (Type tb)))
164     | getOccString f == "createNT" = do
165         -- We exepct ta to be a newtype of tb
166         (tc,tyArgs) <- case splitTyConApp_maybe ta of
167             Nothing -> error "not a type application"
168             Just (tc,tyArgs) -> return (tc,tyArgs)
169         (vars,coa) <- case unwrapNewTyCon_maybe tc of
170             Nothing -> error "not a newtype"
171             Just (vars,_,co) -> return (vars,co)
172
173         -- TODO: Check if all construtors are in scope
174         -- TODO: Check that the expanded type of a is actually b
175
176         -- Extract the typcon from f's type
177         let nttc = tyConAppTyCon (exprType e)
178
179         Just <$> do
180         conNT nttc $ do
181         return $ mkAxInstCo coa tyArgs
182     | otherwise = do
183         --putMsgS $ getOccString f
184         return Nothing
185 replaceCreateNT e = do
186     --putMsg (ppr e)
187     return Nothing
188
189 traverse :: (Functor m, Applicative m, Monad m) => (Expr a -> m (Maybe (Expr a))) -> Expr a -> m (Expr a)
190 traverse f e
191     = f' =<< case e of
192         Type t               -> return $ Type t
193         Coercion c           -> return $ Coercion c
194         Lit lit              -> return $ Lit lit
195         Var v                -> return $ Var v
196         App fun a            -> App <$> traverse f fun <*> traverse f a
197         Tick t e             -> Tick t <$> traverse f e
198         Cast e co            -> Cast <$> traverse f e <*> (return co)
199         Lam b e              -> Lam b <$> traverse f e
200         Let bind e           -> Let <$> traverseBind f bind <*> traverse f e
201         Case scrut bndr ty alts -> Case scrut bndr ty <$> mapM (\(a,b,c) -> (a,b,) <$> traverse f c) alts 
202     where f' x = do
203             r <- f x
204             return (fromMaybe x r)
205
206 traverseBind :: (Functor m, Applicative m, Monad m) => (Expr a -> m (Maybe (Expr a))) -> Bind a -> m (Bind a)
207 traverseBind f (NonRec b e) = NonRec b <$> traverse f e
208 traverseBind f (Rec l) = Rec <$> mapM (\(a,b) -> (a,) <$> traverse f b) l
209
210 createTyVar :: String -> CoreM TyVar
211 createTyVar name = do
212     u <- getUniqueM
213     return $ mkTyVar (mkSystemName u (mkTyVarOcc name)) liftedTypeKind
214
215 tyLam :: String -> (TyVar -> CoreM CoreExpr) -> CoreM CoreExpr
216 tyLam name body = do 
217     v <- createTyVar name
218     Lam v <$> body v
219
220 lam :: String -> Type -> (Var -> CoreM CoreExpr) -> CoreM CoreExpr
221 lam name ty body = do 
222     u <- getUniqueM
223     let v = mkLocalVar VanillaId (mkSystemName u (mkVarOcc name)) ty vanillaIdInfo
224     Lam v <$> body v
225     
226 deconNT :: String -> CoreExpr -> (CoVar -> CoreM CoreExpr) -> CoreM CoreExpr
227 deconNT name nt body = do
228     let ntType = exprType nt
229     let (nttc, [t1, t2]) = splitTyConApp ntType
230     cou <- getUniqueM
231     let co = mkCoVar (mkSystemName cou (mkTyVarOcc name)) (mkCoercionType t1 t2)
232         [dc] = tyConDataCons nttc
233     b <- body co
234     return $ mkWildCase nt ntType (exprType b) [(DataAlt dc, [co], b)]
235
236 lamNT :: TyCon -> String -> Type -> Type -> (CoVar -> CoreM CoreExpr) -> CoreM CoreExpr
237 lamNT nttc name t1 t2 body = do
238     lam (name ++ "nt") (mkTyConApp nttc [t1, t2]) $ \nt -> do
239     deconNT name (Var nt) $ body
240
241 conNT :: TyCon -> CoreM Coercion -> CoreM CoreExpr
242 conNT nttc body = do
243     co <- body 
244     let Pair t1 t2  = coercionKind co
245     return $ mkConApp dc [ Type t1 , Type t2 , Coercion co ]
246   where [dc] = tyConDataCons nttc