Add comments
[nt-coerce.git] / GHC / NT / Plugin.hs
1 {-# LANGUAGE TupleSections #-}
2
3 module GHC.NT.Plugin where
4
5 import GhcPlugins
6 import MkId
7 import Pair
8 import Kind
9
10 import Control.Monad
11 import Control.Applicative
12 import Data.Functor
13 import Data.Maybe
14 import Data.List
15
16 --
17 -- General plugin pass setup
18 --
19
20 plugin :: Plugin
21 plugin = defaultPlugin {
22     installCoreToDos = install
23   }
24
25 install :: [CommandLineOption] -> [CoreToDo] -> CoreM [CoreToDo]
26 install _ xs = do
27     reinitializeGlobals
28     return $ CoreDoPasses [nt,nt2] : xs
29   where nt = CoreDoPluginPass "GHC.NT implementation" ntPass
30         nt2 = CoreDoPluginPass "GHC.NT.createNT implementation" nt2Pass
31
32 ntPass :: ModGuts -> CoreM ModGuts
33 ntPass g | moduleNameString (moduleName (mg_module g)) == "GHC.NT.Type" = do
34     let [oldTc] = mg_tcs g
35     nttc <- createNTTyCon (mg_module g) oldTc
36     tcs' <- mapM (replaceTyCon nttc) (mg_tcs g)
37
38     return $ g { mg_tcs = tcs' }
39 ntPass g | moduleNameString (moduleName (mg_module g)) == "GHC.NT" = do
40     nttc <- lookupNTTyCon (mg_rdr_env g)
41     binds' <- mapM (bind nttc) (mg_binds g)
42
43     return $ g { mg_binds = binds' }
44 ntPass g = return g
45
46 nt2Pass :: ModGuts -> CoreM ModGuts
47 nt2Pass g = do
48     nttc <- lookupNTTyCon (mg_rdr_env g)
49     --putMsg (ppr nttc)
50     binds' <- mapM (traverseBind (replaceDeriveThisNT nttc)) (mg_binds g)
51     return $ g { mg_binds = binds' }
52
53 --
54 -- Definition of the NT data constructor (which cannot be written in Haskell)
55 -- 
56
57 createNTTyCon :: Module -> TyCon -> CoreM TyCon
58 createNTTyCon mod oldTyCon = do
59     a <- createTyVar "a"
60     b <- createTyVar "b"
61     let arg_tys = map mkTyVarTy [a,b]
62     let tyConU = tyConUnique oldTyCon
63     dataConU <- getUniqueM
64     dataConWorkerU <- getUniqueM
65     dataConWrapperU <- getUniqueM
66     let cot = mkCoercionType (mkTyVarTy a) (mkTyVarTy b)
67         rett = mkTyConApp t' arg_tys
68         dct = mkForAllTys [a,b] $ mkFunTy cot rett
69         -- Have to use the original name, otherwise we get a 
70         -- urk! lookup local fingerprint
71         --tcn = mkExternalName tyConU mod (mkTcOcc "NT") noSrcSpan
72         tcn = tyConName oldTyCon
73         n = mkExternalName dataConU mod (mkDataOcc "NT") noSrcSpan
74         dataConWorkerN = mkSystemName dataConWorkerU (mkDataOcc "NT_work")
75         dataConWrapperN = mkSystemName dataConWrapperU (mkDataOcc "NT_wrap")
76         workId = mkGlobalId (DataConWrapId dc') dataConWorkerN dct vanillaIdInfo
77         dataConIds = mkDataConIds dataConWorkerN dataConWrapperN dc'
78         dc' = mkDataCon
79                 n
80                 False
81                 [ HsNoBang ]
82                 []
83                 [a,b]
84                 []
85                 []
86                 []
87                 [ cot ]
88                 rett
89                 t'
90                 []
91                 dataConIds -- (DCIds Nothing workId)
92         t' = mkAlgTyCon
93                tcn 
94                (mkArrowKinds [liftedTypeKind, liftedTypeKind] liftedTypeKind)
95                [a,b]
96                Nothing
97                []
98                (DataTyCon [dc'] False)
99                NoParentTyCon
100                NonRecursive
101                False
102     return t'
103
104 -- | This replaces the dummy NT type constuctor by our generated one
105 replaceTyCon :: TyCon -> TyCon -> CoreM TyCon
106 replaceTyCon nttc t 
107     | occNameString (nameOccName (tyConName t)) == "NT" = return nttc
108     | otherwise = return t
109
110 -- | In later modules, fetching the NT type constructor 
111 lookupNTTyCon :: GlobalRdrEnv -> CoreM TyCon
112 lookupNTTyCon env = do
113     let Just n = find isNT (map gre_name (concat (occEnvElts env)))
114     lookupTyCon n
115   where
116     isNT n = let oN = occName n in
117         occNameString oN == "NT" &&
118         occNameSpace oN == tcClsName &&
119         moduleNameString (moduleName (nameModule n)) == "GHC.NT.Type"
120
121 --
122 -- Implementation of the pass that produces GHC.NT
123 --
124
125 bind :: TyCon -> CoreBind -> CoreM CoreBind
126 bind nttc b@(NonRec v e) | getOccString v == "coerce" = do
127     NonRec v <$> do
128     tyLam "a" $ \a -> do
129     tyLam "b" $ \b -> do
130     lamNT nttc "co" (mkTyVarTy a) (mkTyVarTy b) $ \co -> do 
131     lam "x" (mkTyVarTy a) $ \x -> do
132     return $ Cast (Var x) (CoVarCo co)
133
134 bind nttc b@(NonRec v e) | getOccString v == "refl" = do
135     NonRec v <$> do
136     tyLam "a" $ \a ->
137         conNT nttc $
138             return $ Refl (mkTyVarTy a)
139
140 bind nttc b@(NonRec v e) | getOccString v == "sym" = do
141     NonRec v <$> do
142     tyLam "a" $ \a -> do
143     tyLam "b" $ \b -> do
144     lamNT nttc "co" (mkTyVarTy a) (mkTyVarTy b) $ \co -> do
145     conNT nttc $ do
146     return $ SymCo (CoVarCo co)
147
148 bind nttc b@(NonRec v e) | getOccString v == "trans" = do
149     NonRec v <$> do
150     tyLam "a" $ \a -> do
151     tyLam "b" $ \b -> do
152     tyLam "c" $ \c -> do
153     lamNT nttc "co1" (mkTyVarTy a) (mkTyVarTy b) $ \co1 -> do
154     lamNT nttc "co2" (mkTyVarTy b) (mkTyVarTy c) $ \co2 -> do
155     conNT nttc $ do
156     return $ TransCo (CoVarCo co1) (CoVarCo co2)
157
158 bind _ b = do
159     --putMsg (ppr b)
160     return b
161
162
163 --
164 -- Implementation of "deriving foo :: ... -> NT t1 t2"
165 --
166
167 -- Tries to find a coercion between the given types in the list of coercions
168 findCoercion :: Type -> Type -> [Coercion] -> Maybe Coercion
169 findCoercion t1 t2 = find go
170   where go c = let Pair t1' t2' = coercionKind c in t1' `eqType` t1 && t2' `eqType` t2
171
172 -- Given two types (and a few coercions to use), tries to construct a coercion
173 -- between them
174 deriveNT :: TyCon -> [Coercion] -> Type -> Type -> CoreM Coercion
175 deriveNT nttc cos t1 t2
176     | Just (tc1,tyArgs1) <- splitTyConApp_maybe t1,
177       Just (tc2,tyArgs2) <- splitTyConApp_maybe t2,
178       tc1 == tc2 = do
179         TyConAppCo tc1 <$> sequence (zipWith (deriveNT nttc cos) tyArgs1 tyArgs2)
180     | Just (tc,tyArgs) <- splitTyConApp_maybe t1 = do
181         case unwrapNewTyCon_maybe tc of
182             Just (tyVars, tyExpanded, coAxiom) -> do
183                 -- putMsg (ppr (unwrapNewTyCon_maybe tc))
184                 let rhs = newTyConInstRhs tc tyArgs
185                 if t2 `eqType` rhs
186                   then return $ mkAxInstCo coAxiom tyArgs
187                   else err_wrong_newtype rhs
188             Nothing -> err_not_newtype
189     | Just usable <- findCoercion t1 t2 cos = do
190         return usable
191     | otherwise = err_no_idea_what_to_do
192   where
193     err_wrong_newtype rhs =
194         pprPgmError "deriveThisNT does not know how to derive an NT value relating" $  
195             ppr t1 $$ ppr t2 $$ 
196             text "The former is a newtype of" $$ ppr rhs
197     err_not_newtype = 
198         pprPgmError "deriveThisNT does not know how to derive an NT value relating" $  
199             ppr t1 $$ ppr t2 $$ 
200             text "The former is not a newtype."
201     err_no_idea_what_to_do =
202         pprSorry "deriveThisNT does not know how to derive an NT value relating" $  
203             ppr t1 $$ ppr t2
204
205
206 -- Check if a type if of type NT t1 t2, and returns t1 and t2
207 isNTType :: TyCon -> Type -> Maybe (Type, Type)
208 isNTType nttc t | Just (tc,[t1,t2]) <- splitTyConApp_maybe t, tc == nttc = Just (t1,t2)
209                 | otherwise = Nothing
210
211
212 -- Creates the body of a "deriving foo :: ... -> NT t1 t2" function
213 deriveNTFun :: TyCon -> [Coercion] -> Type -> CoreM CoreExpr
214 deriveNTFun nttc cos t
215     | Just (at, rt) <- splitFunTy_maybe t = do
216         case isNTType nttc at of
217             Just (t1,t2) -> do
218                 lamNT nttc "nt" t1 t2 $ \co -> 
219                     deriveNTFun nttc (CoVarCo co:cos) rt
220             Nothing -> err_non_NT_argument at
221     | Just (t1,t2) <- isNTType nttc t = do
222         conNT nttc $ deriveNT nttc cos t1 t2
223     | otherwise = err_no_idea_what_to_do
224   where
225     err_non_NT_argument at = 
226         pprPgmError "deriveNTFun cannot handle arguments of non-NT-type:" $ ppr at
227     err_no_idea_what_to_do =
228         pprPgmError "deriveThisNT does not know how to derive code of type:" $  ppr t
229
230 -- Replace every occurrence of the magic 'deriveThisNT' by a valid implementation
231 replaceDeriveThisNT :: TyCon -> CoreExpr -> CoreM (Maybe CoreExpr)
232 replaceDeriveThisNT nttc e@(App (Var f) (Type t))
233     | getOccString f == "deriveThisNT" = Just <$> deriveNTFun nttc [] t
234 replaceDeriveThisNT _ e = do
235     --putMsg (ppr e)
236     return Nothing
237
238 --
239 -- General utilities
240 -- 
241
242 -- Replace an expression everywhere
243 traverse :: (Functor m, Applicative m, Monad m) => (Expr a -> m (Maybe (Expr a))) -> Expr a -> m (Expr a)
244 traverse f e
245     = f' =<< case e of
246         Type t               -> return $ Type t
247         Coercion c           -> return $ Coercion c
248         Lit lit              -> return $ Lit lit
249         Var v                -> return $ Var v
250         App fun a            -> App <$> traverse f fun <*> traverse f a
251         Tick t e             -> Tick t <$> traverse f e
252         Cast e co            -> Cast <$> traverse f e <*> (return co)
253         Lam b e              -> Lam b <$> traverse f e
254         Let bind e           -> Let <$> traverseBind f bind <*> traverse f e
255         Case scrut bndr ty alts -> Case scrut bndr ty <$> mapM (\(a,b,c) -> (a,b,) <$> traverse f c) alts 
256     where f' x = do
257             r <- f x
258             return (fromMaybe x r)
259
260 traverseBind :: (Functor m, Applicative m, Monad m) => (Expr a -> m (Maybe (Expr a))) -> Bind a -> m (Bind a)
261 traverseBind f (NonRec b e) = NonRec b <$> traverse f e
262 traverseBind f (Rec l) = Rec <$> mapM (\(a,b) -> (a,) <$> traverse f b) l
263
264 -- Convenient Core creating functions
265
266 createTyVar :: String -> CoreM TyVar
267 createTyVar name = do
268     u <- getUniqueM
269     return $ mkTyVar (mkSystemName u (mkTyVarOcc name)) liftedTypeKind
270
271 tyLam :: String -> (TyVar -> CoreM CoreExpr) -> CoreM CoreExpr
272 tyLam name body = do 
273     v <- createTyVar name
274     Lam v <$> body v
275
276 lam :: String -> Type -> (Var -> CoreM CoreExpr) -> CoreM CoreExpr
277 lam name ty body = do 
278     u <- getUniqueM
279     let v = mkLocalVar VanillaId (mkSystemName u (mkVarOcc name)) ty vanillaIdInfo
280     Lam v <$> body v
281     
282 deconNT :: String -> CoreExpr -> (CoVar -> CoreM CoreExpr) -> CoreM CoreExpr
283 deconNT name nt body = do
284     let ntType = exprType nt
285     let (nttc, [t1, t2]) = splitTyConApp ntType
286     cou <- getUniqueM
287     let co = mkCoVar (mkSystemName cou (mkTyVarOcc name)) (mkCoercionType t1 t2)
288         [dc] = tyConDataCons nttc
289     b <- body co
290     return $ mkWildCase nt ntType (exprType b) [(DataAlt dc, [co], b)]
291
292 lamNT :: TyCon -> String -> Type -> Type -> (CoVar -> CoreM CoreExpr) -> CoreM CoreExpr
293 lamNT nttc name t1 t2 body = do
294     lam (name ++ "nt") (mkTyConApp nttc [t1, t2]) $ \nt -> do
295     deconNT name (Var nt) $ body
296
297 conNT :: TyCon -> CoreM Coercion -> CoreM CoreExpr
298 conNT nttc body = do
299     co <- body 
300     let Pair t1 t2  = coercionKind co
301     return $ mkConApp dc [ Type t1 , Type t2 , Coercion co ]
302   where [dc] = tyConDataCons nttc