Also replace stuff in the case scrutineer
[nt-coerce.git] / GHC / NT / Plugin.hs
1 {-# LANGUAGE TupleSections #-}
2
3 module GHC.NT.Plugin where
4
5 import GhcPlugins
6 import MkId
7 import Pair
8 import Kind
9
10 import Control.Monad
11 import Control.Applicative
12 import Data.Functor
13 import Data.Maybe
14 import Data.List
15
16 --
17 -- General plugin pass setup
18 --
19
20 plugin :: Plugin
21 plugin = defaultPlugin {
22     installCoreToDos = install
23   }
24
25 install :: [CommandLineOption] -> [CoreToDo] -> CoreM [CoreToDo]
26 install _ xs = do
27     reinitializeGlobals
28     return $ CoreDoPasses [nt,nt2] : xs
29   where nt = CoreDoPluginPass "GHC.NT implementation" ntPass
30         nt2 = CoreDoPluginPass "GHC.NT.createNT implementation" nt2Pass
31
32 ntPass :: ModGuts -> CoreM ModGuts
33 ntPass g | moduleNameString (moduleName (mg_module g)) == "GHC.NT.Type" = do
34     let [oldTc] = mg_tcs g
35     nttc <- createNTTyCon (mg_module g) oldTc
36     tcs' <- mapM (replaceTyCon nttc) (mg_tcs g)
37
38     return $ g { mg_tcs = tcs' }
39 ntPass g | moduleNameString (moduleName (mg_module g)) == "GHC.NT" = do
40     nttc <- lookupNTTyCon (mg_rdr_env g)
41     binds' <- mapM (bind nttc) (mg_binds g)
42
43     return $ g { mg_binds = binds' }
44 ntPass g = return g
45
46 nt2Pass :: ModGuts -> CoreM ModGuts
47 nt2Pass g = do
48     nttc <- lookupNTTyCon (mg_rdr_env g)
49     --putMsg (ppr nttc)
50     binds' <- mapM (traverseBind (replaceDeriveThisNT (mg_rdr_env g) nttc)) (mg_binds g)
51     return $ g { mg_binds = binds' }
52
53 --
54 -- Definition of the NT data constructor (which cannot be written in Haskell)
55 -- 
56
57 createNTTyCon :: Module -> TyCon -> CoreM TyCon
58 createNTTyCon mod oldTyCon = do
59     a <- createTyVar "a"
60     b <- createTyVar "b"
61     let arg_tys = map mkTyVarTy [a,b]
62     let tyConU = tyConUnique oldTyCon
63     dataConU <- getUniqueM
64     dataConWorkerU <- getUniqueM
65     dataConWrapperU <- getUniqueM
66     let cot = mkCoercionType (mkTyVarTy a) (mkTyVarTy b)
67         rett = mkTyConApp t' arg_tys
68         dct = mkForAllTys [a,b] $ mkFunTy cot rett
69         -- Have to use the original name, otherwise we get a 
70         -- urk! lookup local fingerprint
71         --tcn = mkExternalName tyConU mod (mkTcOcc "NT") noSrcSpan
72         tcn = tyConName oldTyCon
73         n = mkExternalName dataConU mod (mkDataOcc "NT") noSrcSpan
74         dataConWorkerN = mkSystemName dataConWorkerU (mkDataOcc "NT_work")
75         dataConWrapperN = mkSystemName dataConWrapperU (mkDataOcc "NT_wrap")
76         workId = mkGlobalId (DataConWrapId dc') dataConWorkerN dct vanillaIdInfo
77         dataConIds = mkDataConIds dataConWorkerN dataConWrapperN dc'
78         dc' = mkDataCon
79                 n
80                 False
81                 [ HsNoBang ]
82                 []
83                 [a,b]
84                 []
85                 []
86                 []
87                 [ cot ]
88                 rett
89                 t'
90                 []
91                 dataConIds -- (DCIds Nothing workId)
92         t' = mkAlgTyCon
93                tcn 
94                (mkArrowKinds [liftedTypeKind, liftedTypeKind] liftedTypeKind)
95                [a,b]
96                Nothing
97                []
98                (DataTyCon [dc'] False)
99                NoParentTyCon
100                NonRecursive
101                False
102     return t'
103
104 -- | This replaces the dummy NT type constuctor by our generated one
105 replaceTyCon :: TyCon -> TyCon -> CoreM TyCon
106 replaceTyCon nttc tc = return $ if isNT (tyConName tc) then nttc else tc
107
108 -- | In later modules, fetching the NT type constructor 
109 lookupNTTyCon :: GlobalRdrEnv -> CoreM TyCon
110 lookupNTTyCon env = do
111     let Just n = find isNT (map gre_name (concat (occEnvElts env)))
112     lookupTyCon n
113
114 -- | Checks if the given name is the type constructor 'GHC.NT.Type.NT'
115 isNT :: Name -> Bool
116 isNT n = let oN = occName n in
117     occNameString oN == "NT" &&
118     occNameSpace oN == tcClsName &&
119     moduleNameString (moduleName (nameModule n)) == "GHC.NT.Type"
120
121 --
122 -- Implementation of the pass that produces GHC.NT
123 --
124
125 bind :: TyCon -> CoreBind -> CoreM CoreBind
126 bind nttc b@(NonRec v e) | getOccString v == "coerce" = do
127     NonRec v <$> do
128     tyLam "a" $ \a -> do
129     tyLam "b" $ \b -> do
130     lamNT nttc "co" (mkTyVarTy a) (mkTyVarTy b) $ \co -> do 
131     lam "x" (mkTyVarTy a) $ \x -> do
132     return $ Cast (Var x) (CoVarCo co)
133
134 bind nttc b@(NonRec v e) | getOccString v == "refl" = do
135     NonRec v <$> do
136     tyLam "a" $ \a ->
137         conNT nttc $
138             return $ Refl (mkTyVarTy a)
139
140 bind nttc b@(NonRec v e) | getOccString v == "sym" = do
141     NonRec v <$> do
142     tyLam "a" $ \a -> do
143     tyLam "b" $ \b -> do
144     lamNT nttc "co" (mkTyVarTy a) (mkTyVarTy b) $ \co -> do
145     conNT nttc $ do
146     return $ SymCo (CoVarCo co)
147
148 bind nttc b@(NonRec v e) | getOccString v == "trans" = do
149     NonRec v <$> do
150     tyLam "a" $ \a -> do
151     tyLam "b" $ \b -> do
152     tyLam "c" $ \c -> do
153     lamNT nttc "co1" (mkTyVarTy a) (mkTyVarTy b) $ \co1 -> do
154     lamNT nttc "co2" (mkTyVarTy b) (mkTyVarTy c) $ \co2 -> do
155     conNT nttc $ do
156     return $ TransCo (CoVarCo co1) (CoVarCo co2)
157
158 bind _ b = do
159     --putMsg (ppr b)
160     return b
161
162
163 --
164 -- Implementation of "deriving foo :: ... -> NT t1 t2"
165 --
166
167 -- Tries to find a coercion between the given types in the list of coercions
168 findCoercion :: Type -> Type -> [Coercion] -> Maybe Coercion
169 findCoercion t1 t2 = find go
170   where go c = let Pair t1' t2' = coercionKind c in t1' `eqType` t1 && t2' `eqType` t2
171
172 -- Check if the user is able to see the data constructors of the given type.
173 -- It seems that the type constructors for lists do not occur in the
174 -- GlobalRdrEnv, so we assume that they are always ok.
175 -- NOTE: It is not possible to have an abstract data type without type
176 -- constructors.
177 checkDataConsInScope :: GlobalRdrEnv -> TyCon -> CoreM ()
178 checkDataConsInScope env tc | tc == listTyCon = return ()
179 checkDataConsInScope env tc = mapM_ (checkInScope env . dataConName) (tyConDataCons tc)
180
181 checkInScope :: GlobalRdrEnv -> Name -> CoreM ()
182 checkInScope env n = case lookupGRE_Name env n of
183     [gre] -> return () -- TODO: Mark name as used (not possible in a plugin, I guess)
184     [] -> err_not_in_scope
185     _ -> panic "checkInScope: Got more GREs than expected "
186   where
187     err_not_in_scope =
188         pprPgmError "Cannot derive:" $
189             ppr n <+> text "Not in scope" $$ ppr (globalRdrEnvElts env)
190
191 -- Given two types (and a few coercions to use), tries to construct a coercion
192 -- between them
193 deriveNT :: GlobalRdrEnv -> TyCon -> [Coercion] -> Type -> Type -> CoreM Coercion
194 deriveNT env nttc cos t1 t2
195     | t1 `eqType` t2 = do
196         return $ Refl t1
197     | Just (tc1,tyArgs1) <- splitTyConApp_maybe t1,
198       Just (tc2,tyArgs2) <- splitTyConApp_maybe t2,
199       tc1 == tc2 = do
200         checkDataConsInScope env tc1
201         TyConAppCo tc1 <$> sequence (zipWith (deriveNT env nttc cos) tyArgs1 tyArgs2)
202     | Just (tc,tyArgs) <- splitTyConApp_maybe t1 = do
203         case unwrapNewTyCon_maybe tc of
204             Just (tyVars, tyExpanded, coAxiom) -> do
205                 checkDataConsInScope env tc
206                 -- putMsg (ppr (unwrapNewTyCon_maybe tc))
207                 let rhs = newTyConInstRhs tc tyArgs
208                 if t2 `eqType` rhs
209                   then return $ mkAxInstCo coAxiom tyArgs
210                   else err_wrong_newtype rhs
211             Nothing -> err_not_newtype
212     | Just usable <- findCoercion t1 t2 cos = do
213         return usable
214     | otherwise = err_no_idea_what_to_do
215   where
216     err_wrong_newtype rhs =
217         pprPgmError "deriveThisNT does not know how to derive an NT value relating" $  
218             ppr t1 $$ ppr t2 $$ 
219             text "The former is a newtype of" $$ ppr rhs
220     err_not_newtype = 
221         pprPgmError "deriveThisNT does not know how to derive an NT value relating" $  
222             ppr t1 $$ ppr t2 $$ 
223             text "The former is not a newtype."
224     err_no_idea_what_to_do =
225         pprSorry "deriveThisNT does not know how to derive an NT value relating" $  
226             ppr t1 $$ ppr t2
227
228
229 -- Check if a type if of type NT t1 t2, and returns t1 and t2
230 isNTType :: TyCon -> Type -> Maybe (Type, Type)
231 isNTType nttc t | Just (tc,[t1,t2]) <- splitTyConApp_maybe t, tc == nttc = Just (t1,t2)
232                 | otherwise = Nothing
233
234
235 -- Creates the body of a "deriving foo :: ... -> NT t1 t2" function
236 deriveNTFun :: GlobalRdrEnv -> TyCon -> [Coercion] -> Type -> CoreM CoreExpr
237 deriveNTFun env nttc cos t
238     | Just (at, rt) <- splitFunTy_maybe t = do
239         case isNTType nttc at of
240             Just (t1,t2) -> do
241                 lamNT nttc "nt" t1 t2 $ \co -> 
242                     deriveNTFun env nttc (CoVarCo co:cos) rt
243             Nothing -> err_non_NT_argument at
244     | Just (t1,t2) <- isNTType nttc t = do
245         conNT nttc $ deriveNT env nttc cos t1 t2
246     | otherwise = err_no_idea_what_to_do
247   where
248     err_non_NT_argument at = 
249         pprPgmError "deriveNTFun cannot handle arguments of non-NT-type:" $ ppr at
250     err_no_idea_what_to_do =
251         pprPgmError "deriveThisNT does not know how to derive code of type:" $  ppr t
252
253 -- Replace every occurrence of the magic 'deriveThisNT' by a valid implementation
254 replaceDeriveThisNT env nttc e@(App (Var f) (Type t))
255     | getOccString f == "deriveThisNT" = Just <$> deriveNTFun env nttc [] t
256 replaceDeriveThisNT _ _ e = do
257     --putMsg (ppr e)
258     return Nothing
259
260 --
261 -- General utilities
262 -- 
263
264 -- Replace an expression everywhere
265 traverse :: (Functor m, Applicative m, Monad m) => (Expr a -> m (Maybe (Expr a))) -> Expr a -> m (Expr a)
266 traverse f e
267     = f' =<< case e of
268         Type t               -> return $ Type t
269         Coercion c           -> return $ Coercion c
270         Lit lit              -> return $ Lit lit
271         Var v                -> return $ Var v
272         App fun a            -> App <$> traverse f fun <*> traverse f a
273         Tick t e             -> Tick t <$> traverse f e
274         Cast e co            -> Cast <$> traverse f e <*> (return co)
275         Lam b e              -> Lam b <$> traverse f e
276         Let bind e           -> Let <$> traverseBind f bind <*> traverse f e
277         Case scrut bndr ty alts -> Case <$> traverse f scrut <*> pure bndr <*> pure ty <*> mapM (\(a,b,c) -> (a,b,) <$> traverse f c) alts 
278     where f' x = do
279             r <- f x
280             return (fromMaybe x r)
281
282 traverseBind :: (Functor m, Applicative m, Monad m) => (Expr a -> m (Maybe (Expr a))) -> Bind a -> m (Bind a)
283 traverseBind f (NonRec b e) = NonRec b <$> traverse f e
284 traverseBind f (Rec l) = Rec <$> mapM (\(a,b) -> (a,) <$> traverse f b) l
285
286 -- Convenient Core creating functions
287
288 createTyVar :: String -> CoreM TyVar
289 createTyVar name = do
290     u <- getUniqueM
291     return $ mkTyVar (mkSystemName u (mkTyVarOcc name)) liftedTypeKind
292
293 tyLam :: String -> (TyVar -> CoreM CoreExpr) -> CoreM CoreExpr
294 tyLam name body = do 
295     v <- createTyVar name
296     Lam v <$> body v
297
298 lam :: String -> Type -> (Var -> CoreM CoreExpr) -> CoreM CoreExpr
299 lam name ty body = do 
300     u <- getUniqueM
301     let v = mkLocalVar VanillaId (mkSystemName u (mkVarOcc name)) ty vanillaIdInfo
302     Lam v <$> body v
303     
304 deconNT :: String -> CoreExpr -> (CoVar -> CoreM CoreExpr) -> CoreM CoreExpr
305 deconNT name nt body = do
306     let ntType = exprType nt
307     let (nttc, [t1, t2]) = splitTyConApp ntType
308     cou <- getUniqueM
309     let co = mkCoVar (mkSystemName cou (mkTyVarOcc name)) (mkCoercionType t1 t2)
310         [dc] = tyConDataCons nttc
311     b <- body co
312     return $ mkWildCase nt ntType (exprType b) [(DataAlt dc, [co], b)]
313
314 lamNT :: TyCon -> String -> Type -> Type -> (CoVar -> CoreM CoreExpr) -> CoreM CoreExpr
315 lamNT nttc name t1 t2 body = do
316     lam (name ++ "nt") (mkTyConApp nttc [t1, t2]) $ \nt -> do
317     deconNT name (Var nt) $ body
318
319 conNT :: TyCon -> CoreM Coercion -> CoreM CoreExpr
320 conNT nttc body = do
321     co <- body 
322     let Pair t1 t2  = coercionKind co
323     return $ mkConApp dc [ Type t1 , Type t2 , Coercion co ]
324   where [dc] = tyConDataCons nttc