1 {-# LANGUAGE TupleSections #-}
3 module GHC.NT.Plugin where
10 import Control.Applicative
15 plugin = defaultPlugin {
16 installCoreToDos = install
19 install :: [CommandLineOption] -> [CoreToDo] -> CoreM [CoreToDo]
22 return $ CoreDoPasses [nt,nt2] : xs
23 where nt = CoreDoPluginPass "GHC.NT implementation" ntPass
24 nt2 = CoreDoPluginPass "GHC.NT.createNT implementation" nt2Pass
26 ntPass :: ModGuts -> CoreM ModGuts
27 ntPass g | moduleNameString (moduleName (mg_module g)) == "GHC.NT.Type" = do
28 let [oldTc] = mg_tcs g
29 nttc <- createNTTyCon (mg_module g) oldTc
30 tcs' <- mapM (replaceTyCon nttc) (mg_tcs g)
33 return $ g { mg_tcs = tcs' }
34 ntPass g | moduleNameString (moduleName (mg_module g)) == "GHC.NT" = do
35 nttc <- lookupNTTyCon (mg_module g)
36 binds' <- mapM (bind nttc) (mg_binds g)
41 return $ g { mg_binds = binds' }
-- | Core-to-Core pass behind the "GHC.NT.createNT implementation" plugin
-- pass.  Each top-level binding of the module is handed to 'bind2'; the
-- bindings are threaded through 'bindsOnlyPass', so only 'mg_binds' of the
-- 'ModGuts' is replaced.
nt2Pass :: ModGuts -> CoreM ModGuts
nt2Pass guts = bindsOnlyPass rewriteAll guts
  where
    -- Rewrite the bindings one after another, in module order.
    rewriteAll :: [CoreBind] -> CoreM [CoreBind]
    rewriteAll binds = mapM bind2 binds
47 createNTTyCon :: Module -> TyCon -> CoreM TyCon
48 createNTTyCon mod oldTyCon = do
51 let arg_tys = map mkTyVarTy [a,b]
52 let tyConU = tyConUnique oldTyCon
53 dataConU <- getUniqueM
54 dataConWorkerU <- getUniqueM
55 dataConWrapperU <- getUniqueM
56 let cot = mkCoercionType (mkTyVarTy a) (mkTyVarTy b)
57 rett = mkTyConApp t' arg_tys
58 dct = mkForAllTys [a,b] $ mkFunTy cot rett
59 -- Have to use the original name, otherwise we get a
60 -- urk! lookup local fingerprint
61 --tcn = mkExternalName tyConU mod (mkTcOcc "NT") noSrcSpan
62 tcn = tyConName oldTyCon
63 n = mkExternalName dataConU mod (mkDataOcc "NT") noSrcSpan
64 dataConWorkerN = mkSystemName dataConWorkerU (mkDataOcc "NT_work")
65 dataConWrapperN = mkSystemName dataConWrapperU (mkDataOcc "NT_wrap")
66 workId = mkGlobalId (DataConWrapId dc') dataConWorkerN dct vanillaIdInfo
67 dataConIds = mkDataConIds dataConWorkerN dataConWrapperN dc'
81 dataConIds -- (DCIds Nothing workId)
84 (mkArrowKinds [liftedTypeKind, liftedTypeKind] liftedTypeKind)
88 (DataTyCon [dc'] False)
94 replaceTyCon :: TyCon -> TyCon -> CoreM TyCon
96 | occNameString (nameOccName (tyConName t)) == "NT" = return nttc
97 | otherwise = return t
99 lookupNTTyCon :: Module -> CoreM TyCon
100 lookupNTTyCon mod = do
101 let packageId = modulePackageId mod -- HACK!
102 let ntTypeModule = mkModule packageId (mkModuleName "GHC.NT.Type")
103 nc <- getOrigNameCache
105 dflags <- getDynFlags
106 --putMsgS $ showSDoc dflags (ppr (moduleEnvKeys nc))
107 --putMsgS $ showSDoc dflags (ppr ntTypeModule)
110 let ntTypeModule = last (moduleEnvKeys nc) -- Why does the other not work?
111 --putMsgS $ showSDoc dflags (ppr ntTypeModule)
113 let Just occEnv = lookupModuleEnv nc ntTypeModule
115 --putMsgS $ showSDoc dflags (ppr (occEnv))
116 -- let Just ntTyConName = lookupOccEnv occEnv (mkTcOccFS (fsLit "NT")) -- Why does this not work?
118 let [ntTyConName] = occEnvElts occEnv
119 lookupTyCon ntTyConName
121 bind :: TyCon -> CoreBind -> CoreM CoreBind
122 bind nttc b@(NonRec v e) | getOccString v == "coerce" = do
123 dflags <- getDynFlags
130 let a = mkTyVar (mkSystemName au (mkTyVarOcc "a")) liftedTypeKind
131 b = mkTyVar (mkSystemName bu (mkTyVarOcc "b")) liftedTypeKind
132 ntt = mkTyConApp nttc [mkTyVarTy a, mkTyVarTy b]
133 nt = mkLocalVar VanillaId (mkSystemName ntu (mkVarOcc "nt")) ntt vanillaIdInfo
134 x = mkLocalVar VanillaId (mkSystemName xu (mkVarOcc "b")) (mkTyVarTy a) vanillaIdInfo
135 co = mkCoVar (mkSystemName cou (mkTyVarOcc "co")) (mkCoercionType (mkTyVarTy a) (mkTyVarTy b))
136 [dc] = tyConDataCons nttc
137 let e' = Lam a $ Lam b $ Lam nt $ Lam x $
138 Case (Var nt) nt (mkTyVarTy b)
140 Cast (Var x) (CoVarCo co)
144 bind nttc b@(NonRec v e) | getOccString v == "sym" = do
145 dflags <- getDynFlags
152 let ntt = mkTyConApp nttc [mkTyVarTy a, mkTyVarTy b]
153 ntt' = mkTyConApp nttc [mkTyVarTy b, mkTyVarTy a]
154 nt = mkLocalVar VanillaId (mkSystemName ntu (mkVarOcc "nt")) ntt vanillaIdInfo
155 co = mkCoVar (mkSystemName cou (mkTyVarOcc "co")) (mkCoercionType (mkTyVarTy a) (mkTyVarTy b))
156 [dc] = tyConDataCons nttc
157 let e' = Lam a $ Lam b $ Lam nt $
158 Case (Var nt) nt ntt'
159 [(DataAlt dc, [co], mkConApp dc
162 , Coercion (SymCo (CoVarCo co))
167 bind nttc b@(NonRec v e) | getOccString v == "listNT" = do
174 let ntt = mkTyConApp nttc [mkTyVarTy a, mkTyVarTy b]
175 ntt' = mkTyConApp nttc [mkTyConApp listTyCon [mkTyVarTy a], mkTyConApp listTyCon [mkTyVarTy b]]
176 nt = mkLocalVar VanillaId (mkSystemName ntu (mkVarOcc "nt")) ntt vanillaIdInfo
177 co = mkCoVar (mkSystemName cou (mkTyVarOcc "co")) (mkCoercionType (mkTyVarTy a) (mkTyVarTy b))
178 [dc] = tyConDataCons nttc
179 let e' = Lam a $ Lam b $ Lam nt $
180 Case (Var nt) nt ntt'
181 [(DataAlt dc, [co], mkConApp dc
182 [ Type (mkTyConApp listTyCon [mkTyVarTy a])
183 , Type (mkTyConApp listTyCon [mkTyVarTy b])
184 , Coercion (TyConAppCo listTyCon [CoVarCo co])
190 dflags <- getDynFlags
191 --putMsgS $ showSDoc dflags (ppr b)
-- | Rewrite one top-level Core binding for the @createNT@ pass.
--
-- The right-hand side of the binding is walked with the module's 'traverse',
-- using 'replaceCreateNT' as the rewriter.  Where it returns 'Just' a new
-- expression, that expression is used; 'Nothing' keeps the original.  For a
-- recursive group, every right-hand side is handled in order.
--
-- This is precisely 'traverseBind' specialised to 'replaceCreateNT'.  The
-- helper is reused here instead of repeating its 'NonRec' / 'Rec' case split,
-- so both stay in sync if the traversal ever changes.
bind2 :: CoreBind -> CoreM CoreBind
bind2 = traverseBind replaceCreateNT
199 replaceCreateNT :: CoreExpr -> CoreM (Maybe CoreExpr)
200 replaceCreateNT e@((App (App (Var f) (Type ta)) (Type tb)))
201 | getOccString f == "createNT" = do
202 -- We expect ta to be a newtype of tb
203 (tc,tyArgs) <- case splitTyConApp_maybe ta of
204 Nothing -> error "not a type application"
205 Just (tc,tyArgs) -> return (tc,tyArgs)
206 (vars,coa) <- case unwrapNewTyCon_maybe tc of
207 Nothing -> error "not a newtype"
208 Just (vars,_,co) -> return (vars,co)
210 -- TODO: Check that all constructors are in scope
211 -- TODO: Check that the expanded type of a is actually b
213 dflags <- getDynFlags
214 --putMsgS $ showSDoc dflags (ppr e)
215 -- Extract the tycon from the type of the whole application e
216 let nttc = tyConAppTyCon (exprType e)
217 let [dc] = tyConDataCons nttc
218 let e' = mkConApp dc [ Type ta, Type tb, Coercion (mkAxInstCo coa tyArgs)] :: CoreExpr
219 --putMsgS $ showSDoc dflags (ppr nttc)
220 --putMsgS $ showSDoc dflags (ppr (tyConDataCons nttc))
221 --putMsgS $ showSDoc dflags (ppr e')
224 --putMsgS $ getOccString f
226 replaceCreateNT e = do
227 --dflags <- getDynFlags
228 --putMsgS $ showSDoc dflags (ppr e)
231 traverse :: (Functor m, Applicative m, Monad m) => (Expr a -> m (Maybe (Expr a))) -> Expr a -> m (Expr a)
234 Type t -> return $ Type t
235 Coercion c -> return $ Coercion c
236 Lit lit -> return $ Lit lit
237 Var v -> return $ Var v
238 App fun a -> App <$> traverse f fun <*> traverse f a
239 Tick t e -> Tick t <$> traverse f e
240 Cast e co -> Cast <$> traverse f e <*> (return co)
241 Lam b e -> Lam b <$> traverse f e
242 Let bind e -> Let <$> traverseBind f bind <*> traverse f e
243 Case scrut bndr ty alts -> Case scrut bndr ty <$> mapM (\(a,b,c) -> (a,b,) <$> traverse f c) alts
246 return (fromMaybe x r)
-- | Run the expression rewriter @rewrite@ (through the module's 'traverse')
-- over the right-hand side of a binding.  Binders themselves are never
-- touched.  In a recursive group the right-hand sides are rewritten left to
-- right and the group keeps its original order.
traverseBind :: (Functor m, Applicative m, Monad m) => (Expr a -> m (Maybe (Expr a))) -> Bind a -> m (Bind a)
traverseBind rewrite binding =
  case binding of
    NonRec binder rhs -> fmap (NonRec binder) (walk rhs)
    Rec group         -> fmap Rec (mapM walkPair group)
  where
    -- Rewrite a single expression.
    walk = traverse rewrite
    -- Rewrite the right-hand side of one member of a recursive group.
    walkPair (binder, rhs) = do
      rhs' <- walk rhs
      return (binder, rhs')
252 createTyVar :: String -> CoreM TyVar
253 createTyVar name = do
255 return $ mkTyVar (mkSystemName u (mkTyVarOcc name)) liftedTypeKind