Implement Refl in deriveNT
[nt-coerce.git] / GHC / NT / Plugin.hs
1 {-# LANGUAGE TupleSections #-}
2
3 module GHC.NT.Plugin where
4
5 import GhcPlugins
6 import MkId
7 import Pair
8 import Kind
9
10 import Control.Monad
11 import Control.Applicative
12 import Data.Functor
13 import Data.Maybe
14 import Data.List
15
16 --
17 -- General plugin pass setup
18 --
19
20 plugin :: Plugin
21 plugin = defaultPlugin {
22     installCoreToDos = install
23   }
24
25 install :: [CommandLineOption] -> [CoreToDo] -> CoreM [CoreToDo]
26 install _ xs = do
27     reinitializeGlobals
28     return $ CoreDoPasses [nt,nt2] : xs
29   where nt = CoreDoPluginPass "GHC.NT implementation" ntPass
30         nt2 = CoreDoPluginPass "GHC.NT.createNT implementation" nt2Pass
31
32 ntPass :: ModGuts -> CoreM ModGuts
33 ntPass g | moduleNameString (moduleName (mg_module g)) == "GHC.NT.Type" = do
34     let [oldTc] = mg_tcs g
35     nttc <- createNTTyCon (mg_module g) oldTc
36     tcs' <- mapM (replaceTyCon nttc) (mg_tcs g)
37
38     return $ g { mg_tcs = tcs' }
39 ntPass g | moduleNameString (moduleName (mg_module g)) == "GHC.NT" = do
40     nttc <- lookupNTTyCon (mg_rdr_env g)
41     binds' <- mapM (bind nttc) (mg_binds g)
42
43     return $ g { mg_binds = binds' }
44 ntPass g = return g
45
46 nt2Pass :: ModGuts -> CoreM ModGuts
47 nt2Pass g = do
48     nttc <- lookupNTTyCon (mg_rdr_env g)
49     --putMsg (ppr nttc)
50     binds' <- mapM (traverseBind (replaceDeriveThisNT nttc)) (mg_binds g)
51     return $ g { mg_binds = binds' }
52
53 --
54 -- Definition of the NT data constructor (which cannot be written in Haskell)
55 -- 
56
57 createNTTyCon :: Module -> TyCon -> CoreM TyCon
58 createNTTyCon mod oldTyCon = do
59     a <- createTyVar "a"
60     b <- createTyVar "b"
61     let arg_tys = map mkTyVarTy [a,b]
62     let tyConU = tyConUnique oldTyCon
63     dataConU <- getUniqueM
64     dataConWorkerU <- getUniqueM
65     dataConWrapperU <- getUniqueM
66     let cot = mkCoercionType (mkTyVarTy a) (mkTyVarTy b)
67         rett = mkTyConApp t' arg_tys
68         dct = mkForAllTys [a,b] $ mkFunTy cot rett
69         -- Have to use the original name, otherwise we get a 
70         -- urk! lookup local fingerprint
71         --tcn = mkExternalName tyConU mod (mkTcOcc "NT") noSrcSpan
72         tcn = tyConName oldTyCon
73         n = mkExternalName dataConU mod (mkDataOcc "NT") noSrcSpan
74         dataConWorkerN = mkSystemName dataConWorkerU (mkDataOcc "NT_work")
75         dataConWrapperN = mkSystemName dataConWrapperU (mkDataOcc "NT_wrap")
76         workId = mkGlobalId (DataConWrapId dc') dataConWorkerN dct vanillaIdInfo
77         dataConIds = mkDataConIds dataConWorkerN dataConWrapperN dc'
78         dc' = mkDataCon
79                 n
80                 False
81                 [ HsNoBang ]
82                 []
83                 [a,b]
84                 []
85                 []
86                 []
87                 [ cot ]
88                 rett
89                 t'
90                 []
91                 dataConIds -- (DCIds Nothing workId)
92         t' = mkAlgTyCon
93                tcn 
94                (mkArrowKinds [liftedTypeKind, liftedTypeKind] liftedTypeKind)
95                [a,b]
96                Nothing
97                []
98                (DataTyCon [dc'] False)
99                NoParentTyCon
100                NonRecursive
101                False
102     return t'
103
104 -- | This replaces the dummy NT type constuctor by our generated one
105 replaceTyCon :: TyCon -> TyCon -> CoreM TyCon
106 replaceTyCon nttc t 
107     | occNameString (nameOccName (tyConName t)) == "NT" = return nttc
108     | otherwise = return t
109
110 -- | In later modules, fetching the NT type constructor 
111 lookupNTTyCon :: GlobalRdrEnv -> CoreM TyCon
112 lookupNTTyCon env = do
113     let Just n = find isNT (map gre_name (concat (occEnvElts env)))
114     lookupTyCon n
115   where
116     isNT n = let oN = occName n in
117         occNameString oN == "NT" &&
118         occNameSpace oN == tcClsName &&
119         moduleNameString (moduleName (nameModule n)) == "GHC.NT.Type"
120
121 --
122 -- Implementation of the pass that produces GHC.NT
123 --
124
125 bind :: TyCon -> CoreBind -> CoreM CoreBind
126 bind nttc b@(NonRec v e) | getOccString v == "coerce" = do
127     NonRec v <$> do
128     tyLam "a" $ \a -> do
129     tyLam "b" $ \b -> do
130     lamNT nttc "co" (mkTyVarTy a) (mkTyVarTy b) $ \co -> do 
131     lam "x" (mkTyVarTy a) $ \x -> do
132     return $ Cast (Var x) (CoVarCo co)
133
134 bind nttc b@(NonRec v e) | getOccString v == "refl" = do
135     NonRec v <$> do
136     tyLam "a" $ \a ->
137         conNT nttc $
138             return $ Refl (mkTyVarTy a)
139
140 bind nttc b@(NonRec v e) | getOccString v == "sym" = do
141     NonRec v <$> do
142     tyLam "a" $ \a -> do
143     tyLam "b" $ \b -> do
144     lamNT nttc "co" (mkTyVarTy a) (mkTyVarTy b) $ \co -> do
145     conNT nttc $ do
146     return $ SymCo (CoVarCo co)
147
148 bind nttc b@(NonRec v e) | getOccString v == "trans" = do
149     NonRec v <$> do
150     tyLam "a" $ \a -> do
151     tyLam "b" $ \b -> do
152     tyLam "c" $ \c -> do
153     lamNT nttc "co1" (mkTyVarTy a) (mkTyVarTy b) $ \co1 -> do
154     lamNT nttc "co2" (mkTyVarTy b) (mkTyVarTy c) $ \co2 -> do
155     conNT nttc $ do
156     return $ TransCo (CoVarCo co1) (CoVarCo co2)
157
158 bind _ b = do
159     --putMsg (ppr b)
160     return b
161
162
163 --
164 -- Implementation of "deriving foo :: ... -> NT t1 t2"
165 --
166
167 -- Tries to find a coercion between the given types in the list of coercions
168 findCoercion :: Type -> Type -> [Coercion] -> Maybe Coercion
169 findCoercion t1 t2 = find go
170   where go c = let Pair t1' t2' = coercionKind c in t1' `eqType` t1 && t2' `eqType` t2
171
172 -- Given two types (and a few coercions to use), tries to construct a coercion
173 -- between them
174 deriveNT :: TyCon -> [Coercion] -> Type -> Type -> CoreM Coercion
175 deriveNT nttc cos t1 t2
176     | t1 `eqType` t2 = do
177         return $ Refl t1
178     | Just (tc1,tyArgs1) <- splitTyConApp_maybe t1,
179       Just (tc2,tyArgs2) <- splitTyConApp_maybe t2,
180       tc1 == tc2 = do
181         TyConAppCo tc1 <$> sequence (zipWith (deriveNT nttc cos) tyArgs1 tyArgs2)
182     | Just (tc,tyArgs) <- splitTyConApp_maybe t1 = do
183         case unwrapNewTyCon_maybe tc of
184             Just (tyVars, tyExpanded, coAxiom) -> do
185                 -- putMsg (ppr (unwrapNewTyCon_maybe tc))
186                 let rhs = newTyConInstRhs tc tyArgs
187                 if t2 `eqType` rhs
188                   then return $ mkAxInstCo coAxiom tyArgs
189                   else err_wrong_newtype rhs
190             Nothing -> err_not_newtype
191     | Just usable <- findCoercion t1 t2 cos = do
192         return usable
193     | otherwise = err_no_idea_what_to_do
194   where
195     err_wrong_newtype rhs =
196         pprPgmError "deriveThisNT does not know how to derive an NT value relating" $  
197             ppr t1 $$ ppr t2 $$ 
198             text "The former is a newtype of" $$ ppr rhs
199     err_not_newtype = 
200         pprPgmError "deriveThisNT does not know how to derive an NT value relating" $  
201             ppr t1 $$ ppr t2 $$ 
202             text "The former is not a newtype."
203     err_no_idea_what_to_do =
204         pprSorry "deriveThisNT does not know how to derive an NT value relating" $  
205             ppr t1 $$ ppr t2
206
207
208 -- Check if a type if of type NT t1 t2, and returns t1 and t2
209 isNTType :: TyCon -> Type -> Maybe (Type, Type)
210 isNTType nttc t | Just (tc,[t1,t2]) <- splitTyConApp_maybe t, tc == nttc = Just (t1,t2)
211                 | otherwise = Nothing
212
213
214 -- Creates the body of a "deriving foo :: ... -> NT t1 t2" function
215 deriveNTFun :: TyCon -> [Coercion] -> Type -> CoreM CoreExpr
216 deriveNTFun nttc cos t
217     | Just (at, rt) <- splitFunTy_maybe t = do
218         case isNTType nttc at of
219             Just (t1,t2) -> do
220                 lamNT nttc "nt" t1 t2 $ \co -> 
221                     deriveNTFun nttc (CoVarCo co:cos) rt
222             Nothing -> err_non_NT_argument at
223     | Just (t1,t2) <- isNTType nttc t = do
224         conNT nttc $ deriveNT nttc cos t1 t2
225     | otherwise = err_no_idea_what_to_do
226   where
227     err_non_NT_argument at = 
228         pprPgmError "deriveNTFun cannot handle arguments of non-NT-type:" $ ppr at
229     err_no_idea_what_to_do =
230         pprPgmError "deriveThisNT does not know how to derive code of type:" $  ppr t
231
232 -- Replace every occurrence of the magic 'deriveThisNT' by a valid implementation
233 replaceDeriveThisNT :: TyCon -> CoreExpr -> CoreM (Maybe CoreExpr)
234 replaceDeriveThisNT nttc e@(App (Var f) (Type t))
235     | getOccString f == "deriveThisNT" = Just <$> deriveNTFun nttc [] t
236 replaceDeriveThisNT _ e = do
237     --putMsg (ppr e)
238     return Nothing
239
240 --
241 -- General utilities
242 -- 
243
244 -- Replace an expression everywhere
245 traverse :: (Functor m, Applicative m, Monad m) => (Expr a -> m (Maybe (Expr a))) -> Expr a -> m (Expr a)
246 traverse f e
247     = f' =<< case e of
248         Type t               -> return $ Type t
249         Coercion c           -> return $ Coercion c
250         Lit lit              -> return $ Lit lit
251         Var v                -> return $ Var v
252         App fun a            -> App <$> traverse f fun <*> traverse f a
253         Tick t e             -> Tick t <$> traverse f e
254         Cast e co            -> Cast <$> traverse f e <*> (return co)
255         Lam b e              -> Lam b <$> traverse f e
256         Let bind e           -> Let <$> traverseBind f bind <*> traverse f e
257         Case scrut bndr ty alts -> Case scrut bndr ty <$> mapM (\(a,b,c) -> (a,b,) <$> traverse f c) alts 
258     where f' x = do
259             r <- f x
260             return (fromMaybe x r)
261
262 traverseBind :: (Functor m, Applicative m, Monad m) => (Expr a -> m (Maybe (Expr a))) -> Bind a -> m (Bind a)
263 traverseBind f (NonRec b e) = NonRec b <$> traverse f e
264 traverseBind f (Rec l) = Rec <$> mapM (\(a,b) -> (a,) <$> traverse f b) l
265
266 -- Convenient Core creating functions
267
268 createTyVar :: String -> CoreM TyVar
269 createTyVar name = do
270     u <- getUniqueM
271     return $ mkTyVar (mkSystemName u (mkTyVarOcc name)) liftedTypeKind
272
273 tyLam :: String -> (TyVar -> CoreM CoreExpr) -> CoreM CoreExpr
274 tyLam name body = do 
275     v <- createTyVar name
276     Lam v <$> body v
277
278 lam :: String -> Type -> (Var -> CoreM CoreExpr) -> CoreM CoreExpr
279 lam name ty body = do 
280     u <- getUniqueM
281     let v = mkLocalVar VanillaId (mkSystemName u (mkVarOcc name)) ty vanillaIdInfo
282     Lam v <$> body v
283     
284 deconNT :: String -> CoreExpr -> (CoVar -> CoreM CoreExpr) -> CoreM CoreExpr
285 deconNT name nt body = do
286     let ntType = exprType nt
287     let (nttc, [t1, t2]) = splitTyConApp ntType
288     cou <- getUniqueM
289     let co = mkCoVar (mkSystemName cou (mkTyVarOcc name)) (mkCoercionType t1 t2)
290         [dc] = tyConDataCons nttc
291     b <- body co
292     return $ mkWildCase nt ntType (exprType b) [(DataAlt dc, [co], b)]
293
294 lamNT :: TyCon -> String -> Type -> Type -> (CoVar -> CoreM CoreExpr) -> CoreM CoreExpr
295 lamNT nttc name t1 t2 body = do
296     lam (name ++ "nt") (mkTyConApp nttc [t1, t2]) $ \nt -> do
297     deconNT name (Var nt) $ body
298
299 conNT :: TyCon -> CoreM Coercion -> CoreM CoreExpr
300 conNT nttc body = do
301     co <- body 
302     let Pair t1 t2  = coercionKind co
303     return $ mkConApp dc [ Type t1 , Type t2 , Coercion co ]
304   where [dc] = tyConDataCons nttc