More general deriveThisNT
authorJoachim Breitner <mail@joachim-breitner.de>
Thu, 4 Jul 2013 11:52:00 +0000 (13:52 +0200)
committerJoachim Breitner <mail@joachim-breitner.de>
Thu, 4 Jul 2013 11:52:00 +0000 (13:52 +0200)
GHC/NT.hs
GHC/NT/Plugin.hs
test.hs

index ffa9f47..032ffc4 100644 (file)
--- a/GHC/NT.hs
+++ b/GHC/NT.hs
@@ -1,6 +1,6 @@
 {-# OPTIONS_GHC -fplugin GHC.NT.Plugin #-}
 
-module GHC.NT (NT, coerce, refl, sym, trans, createNT, listNT) where
+module GHC.NT (NT, coerce, refl, sym, trans, deriveThisNT) where
 
 import GHC.NT.Type
 
@@ -16,10 +16,7 @@ sym = error "GHC.NT.sym"
 trans  :: NT a b -> NT b c -> NT a c
 trans = error "GHC.NT.trans"
 
-createNT :: NT a b
-createNT = error "GHC.NT.createNT"
-{-# NOINLINE createNT #-}
-
-listNT :: NT a b -> NT [a] [b]
-listNT = error "GHC.NT.liftNT"
+deriveThisNT :: a
+deriveThisNT = error "left over deriveThis. Did GHC.NT.Plugin run?"
+{-# NOINLINE deriveThisNT #-}
 
index 22cc6cb..c4d32b8 100644 (file)
@@ -11,6 +11,7 @@ import Control.Monad
 import Control.Applicative
 import Data.Functor
 import Data.Maybe
+import Data.List
 
 plugin :: Plugin
 plugin = defaultPlugin {
@@ -39,7 +40,11 @@ ntPass g | moduleNameString (moduleName (mg_module g)) == "GHC.NT" = do
 ntPass g = return g
 
 nt2Pass :: ModGuts -> CoreM ModGuts
-nt2Pass = bindsOnlyPass $ mapM (traverseBind replaceCreateNT)
+nt2Pass g = do
+    nttc <- lookupNTTyCon (mg_rdr_env g) (mg_module g)
+    --putMsg (ppr nttc)
+    binds' <- mapM (traverseBind (replaceDeriveThisNT nttc)) (mg_binds g)
+    return $ g { mg_binds = binds' }
 
 createNTTyCon :: Module -> TyCon -> CoreM TyCon
 createNTTyCon mod oldTyCon = do
@@ -147,42 +152,62 @@ bind nttc b@(NonRec v e) | getOccString v == "trans" = do
     conNT nttc $ do
     return $ TransCo (CoVarCo co1) (CoVarCo co2)
 
-bind nttc b@(NonRec v e) | getOccString v == "listNT" = do
-    NonRec v <$> do
-    tyLam "a" $ \a -> do
-    tyLam "b" $ \b -> do
-    lamNT nttc "co" (mkTyVarTy a) (mkTyVarTy b) $ \co -> do
-    conNT nttc $ do
-    return $ TyConAppCo listTyCon [CoVarCo co]
-
 bind _ b = do
     --putMsg (ppr b)
     return b
 
-replaceCreateNT :: CoreExpr -> CoreM (Maybe CoreExpr)
-replaceCreateNT e@((App (App (Var f) (Type ta)) (Type tb)))
-    | getOccString f == "createNT" = do
-        -- We exepct ta to be a newtype of tb
-        (tc,tyArgs) <- case splitTyConApp_maybe ta of
-            Nothing -> error "not a type application"
-            Just (tc,tyArgs) -> return (tc,tyArgs)
-        (vars,coa) <- case unwrapNewTyCon_maybe tc of
-            Nothing -> error "not a newtype"
-            Just (vars,_,co) -> return (vars,co)
-
-        -- TODO: Check if all construtors are in scope
-        -- TODO: Check that the expanded type of a is actually b
-
-        -- Extract the typcon from f's type
-        let nttc = tyConAppTyCon (exprType e)
-
-        Just <$> do
-        conNT nttc $ do
-        return $ mkAxInstCo coa tyArgs
+findCoercion :: Type -> Type -> [Coercion] -> Maybe Coercion
+findCoercion t1 t2 = find go
+  where go c = let Pair t1' t2' = coercionKind c in t1' `eqType` t1 && t2' `eqType` t2
+
+deriveNT :: TyCon -> [Coercion] -> Type -> Type -> CoreM Coercion
+deriveNT nttc cos t1 t2
+    | Just (tc1,tyArgs1) <- splitTyConApp_maybe t1,
+      Just (tc2,tyArgs2) <- splitTyConApp_maybe t2,
+      tc1 == tc2 = do
+        TyConAppCo tc1 <$> sequence (zipWith (deriveNT nttc cos) tyArgs1 tyArgs2)
+    | Just (tc,tyArgs) <- splitTyConApp_maybe t1 = do
+        case unwrapNewTyCon_maybe tc of
+            Just (tyVars, tyExpanded, coAxiom) -> do
+                putMsg (ppr (unwrapNewTyCon_maybe tc))
+                let rhs = newTyConInstRhs tc tyArgs
+                if t2 `eqType` rhs
+                  then return $ mkAxInstCo coAxiom tyArgs
+                  else pprPgmError "deriveThisNT does not know how to derive an NT value relating" $  
+                        ppr t1 $$ ppr t2 $$ 
+                        text "The former is a newtype of" $$ ppr (newTyConInstRhs tc tyArgs)
+            Nothing -> 
+                pprPgmError "deriveThisNT does not know how to derive an NT value relating" $  
+                    ppr t1 $$ ppr t2 $$ 
+                    text "The former is not a newtype."
+    | Just usable <- findCoercion t1 t2 cos = do
+        return usable
     | otherwise = do
-        --putMsgS $ getOccString f
-        return Nothing
-replaceCreateNT e = do
+        pprSorry "deriveThisNT does not know how to derive an NT value relating" $  
+            ppr t1 $$ ppr t2
+
+isNTType :: TyCon -> Type -> Maybe (Type, Type)
+isNTType nttc t | Just (tc,[t1,t2]) <- splitTyConApp_maybe t, tc == nttc = Just (t1,t2)
+                | otherwise = Nothing
+
+
+deriveNTFun :: TyCon -> [Coercion] -> Type -> CoreM CoreExpr
+deriveNTFun nttc cos t
+    | Just (at, rt) <- splitFunTy_maybe t = do
+        case isNTType nttc at of
+            Just (t1,t2) -> do
+                lamNT nttc "nt" t1 t2 $ \co -> 
+                    deriveNTFun nttc (CoVarCo co:cos) rt
+            Nothing -> pprPgmError "deriveNTFun cannot handle arguments of non-NT-type:" $ ppr at
+    | Just (t1,t2) <- isNTType nttc t = do
+        conNT nttc $ deriveNT nttc cos t1 t2
+    | otherwise = do
+        pprPgmError "deriveThisNT does not know how to derive code of type:" $  ppr t
+
+replaceDeriveThisNT :: TyCon -> CoreExpr -> CoreM (Maybe CoreExpr)
+replaceDeriveThisNT nttc e@(App (Var f) (Type t))
+    | getOccString f == "deriveThisNT" = Just <$> deriveNTFun nttc [] t
+replaceDeriveThisNT _ e = do
     --putMsg (ppr e)
     return Nothing
 
diff --git a/test.hs b/test.hs
index 134d39a..bc681f5 100644 (file)
--- a/test.hs
+++ b/test.hs
@@ -2,16 +2,21 @@
 
 import GHC.NT
 
+listNT :: NT a b -> NT [a] [b]
+listNT = deriveThisNT
+
 newtype Age = Age Int deriving Show
 
 ageNT :: NT Age Int
-ageNT = createNT
+ageNT = deriveThisNT
 
 newtype MyList a = MyList [a] deriving Show
 
 myListNT :: NT (MyList a) [a]
-myListNT = createNT
+myListNT = deriveThisNT
 
+foo :: NT a b -> NT (MyList a) (MyList b)
+foo = deriveThisNT
 
 main = do
     let n = 1 :: Int
@@ -22,4 +27,4 @@ main = do
     print a
     print l2
     print l3
-
+    print $ coerce (foo (sym ageNT)) l3