Simplify construction of CoreSyn
authorJoachim Breitner <mail@joachim-breitner.de>
Wed, 3 Jul 2013 12:23:43 +0000 (14:23 +0200)
committerJoachim Breitner <mail@joachim-breitner.de>
Wed, 3 Jul 2013 12:23:43 +0000 (14:23 +0200)
GHC/NT/Plugin.hs

index 3bb4e5e..3d5f120 100644 (file)
@@ -4,6 +4,7 @@ module GHC.NT.Plugin where
 
 import GhcPlugins
 import MkId
+import Pair
 import Kind
 
 import Control.Monad
@@ -112,108 +113,52 @@ lookupNTTyCon env mod = do
     let n = gre_name e'
     lookupTyCon n
 
+
 bind :: TyCon -> CoreBind -> CoreM CoreBind
 bind nttc b@(NonRec v e) | getOccString v == "coerce" = do
-    au <- getUniqueM
-    bu <- getUniqueM
-    ntu <- getUniqueM
-    nttu <- getUniqueM
-    xu <- getUniqueM
-    cou <- getUniqueM
-    let a   = mkTyVar (mkSystemName au (mkTyVarOcc "a")) liftedTypeKind
-        b   = mkTyVar (mkSystemName bu (mkTyVarOcc "b")) liftedTypeKind
-        ntt = mkTyConApp nttc [mkTyVarTy a, mkTyVarTy b]
-        nt  = mkLocalVar VanillaId (mkSystemName ntu (mkVarOcc "nt")) ntt vanillaIdInfo
-        x   = mkLocalVar VanillaId (mkSystemName xu (mkVarOcc "b")) (mkTyVarTy a) vanillaIdInfo
-        co = mkCoVar (mkSystemName cou (mkTyVarOcc "co")) (mkCoercionType (mkTyVarTy a) (mkTyVarTy b))
-        [dc] = tyConDataCons nttc
-    let e' = Lam a $ Lam b $ Lam nt $ Lam x $
-                Case (Var nt) nt (mkTyVarTy b)
-                    [(DataAlt dc, [co], Cast (Var x) (CoVarCo co))]
-    return (NonRec v e')
+    NonRec v <$> do
+    tyLam "a" $ \a -> do
+    tyLam "b" $ \b -> do
+    lam "nt" (mkTyConApp nttc [mkTyVarTy a, mkTyVarTy b]) $ \nt -> do
+    lam "x" (mkTyVarTy a) $ \x -> do
+    deconNT "co" (Var nt) $ \co -> do
+    return $ Cast (Var x) (CoVarCo co)
 
 bind nttc b@(NonRec v e) | getOccString v == "refl" = do
-    a <- createTyVar "a"
-    let [dc] = tyConDataCons nttc
-    let e' = Lam a $ mkConApp dc
-                    [ Type (mkTyVarTy a)
-                    , Type (mkTyVarTy a)
-                    , Coercion (Refl (mkTyVarTy a))
-                    ]
-    return (NonRec v e')
+    NonRec v <$> do
+    tyLam "a" $ \a ->
+        conNT nttc $
+            return $ Refl (mkTyVarTy a)
 
 bind nttc b@(NonRec v e) | getOccString v == "sym" = do
-    a <- createTyVar "a"
-    b <- createTyVar "b"
-    ntu <- getUniqueM
-    nttu <- getUniqueM
-    cou <- getUniqueM
-    let ntt = mkTyConApp nttc [mkTyVarTy a, mkTyVarTy b]
-        ntt' = mkTyConApp nttc [mkTyVarTy b, mkTyVarTy a]
-        nt  = mkLocalVar VanillaId (mkSystemName ntu (mkVarOcc "nt")) ntt vanillaIdInfo
-        co = mkCoVar (mkSystemName cou (mkTyVarOcc "co")) (mkCoercionType (mkTyVarTy a) (mkTyVarTy b))
-        [dc] = tyConDataCons nttc
-    let e' = Lam a $ Lam b $ Lam nt $
-            Case (Var nt) nt ntt'
-                [(DataAlt dc, [co], mkConApp dc
-                    [ Type (mkTyVarTy b)
-                    , Type (mkTyVarTy a)
-                    , Coercion (SymCo (CoVarCo co))
-                    ]
-                )]
-    return (NonRec v e')
+    NonRec v <$> do
+    tyLam "a" $ \a -> do
+    tyLam "b" $ \b -> do
+    lam "nt" (mkTyConApp nttc [mkTyVarTy a, mkTyVarTy b]) $ \nt -> do
+    deconNT "co" (Var nt) $ \co -> do
+    conNT nttc $ do
+    return $ SymCo (CoVarCo co)
 
 bind nttc b@(NonRec v e) | getOccString v == "trans" = do
-    a <- createTyVar "a"
-    b <- createTyVar "b"
-    c <- createTyVar "c"
-    ntu <- getUniqueM
-    nt2u <- getUniqueM
-    nttu <- getUniqueM
-    ntt'u <- getUniqueM
-    cou <- getUniqueM
-    co2u <- getUniqueM
-    let ntt = mkTyConApp nttc [mkTyVarTy a, mkTyVarTy b]
-        nt2t = mkTyConApp nttc [mkTyVarTy b, mkTyVarTy c]
-        ntt' = mkTyConApp nttc [mkTyVarTy a, mkTyVarTy c]
-        nt  = mkLocalVar VanillaId (mkSystemName ntu (mkVarOcc "nt")) ntt vanillaIdInfo
-        nt2  = mkLocalVar VanillaId (mkSystemName nt2u (mkVarOcc "nt2")) nt2t vanillaIdInfo
-        co = mkCoVar (mkSystemName cou (mkTyVarOcc "co")) (mkCoercionType (mkTyVarTy a) (mkTyVarTy b))
-        co2 = mkCoVar (mkSystemName co2u (mkTyVarOcc "co2")) (mkCoercionType (mkTyVarTy b) (mkTyVarTy c))
-        [dc] = tyConDataCons nttc
-    let e' = Lam a $ Lam b $ Lam c $ Lam nt $ Lam nt2 $
-            Case (Var nt) nt ntt'
-                [(DataAlt dc, [co], Case (Var nt2) nt2 ntt' 
-                    [(DataAlt dc, [co2], mkConApp dc
-                        [ Type (mkTyVarTy a)
-                        , Type (mkTyVarTy c)
-                        , Coercion (TransCo (CoVarCo co) (CoVarCo co2))
-                        ]
-                    )]
-                )]
-    return (NonRec v e')
+    NonRec v <$> do
+    tyLam "a" $ \a -> do
+    tyLam "b" $ \b -> do
+    tyLam "c" $ \c -> do
+    lam "nt1" (mkTyConApp nttc [mkTyVarTy a, mkTyVarTy b]) $ \nt1 -> do
+    lam "nt2" (mkTyConApp nttc [mkTyVarTy b, mkTyVarTy c]) $ \nt2 -> do
+    deconNT "co1" (Var nt1) $ \co1 -> do
+    deconNT "co2" (Var nt2) $ \co2 -> do
+    conNT nttc $ do
+    return $ TransCo (CoVarCo co1) (CoVarCo co2)
 
 bind nttc b@(NonRec v e) | getOccString v == "listNT" = do
-    a <- createTyVar "a"
-    b <- createTyVar "b"
-    ntu <- getUniqueM
-    nttu <- getUniqueM
-    ntt'u <- getUniqueM
-    cou <- getUniqueM
-    let ntt = mkTyConApp nttc [mkTyVarTy a, mkTyVarTy b]
-        ntt' = mkTyConApp nttc [mkTyConApp listTyCon [mkTyVarTy a], mkTyConApp listTyCon [mkTyVarTy b]]
-        nt  = mkLocalVar VanillaId (mkSystemName ntu (mkVarOcc "nt")) ntt vanillaIdInfo
-        co = mkCoVar (mkSystemName cou (mkTyVarOcc "co")) (mkCoercionType (mkTyVarTy a) (mkTyVarTy b))
-        [dc] = tyConDataCons nttc
-    let e' = Lam a $ Lam b $ Lam nt $
-            Case (Var nt) nt ntt' 
-                [(DataAlt dc, [co], mkConApp dc
-                    [ Type (mkTyConApp listTyCon [mkTyVarTy a])
-                    , Type (mkTyConApp listTyCon [mkTyVarTy b])
-                    ,  Coercion (TyConAppCo listTyCon [CoVarCo co])
-                    ]
-                )]
-    return (NonRec v e')
+    NonRec v <$> do
+    tyLam "a" $ \a -> do
+    tyLam "b" $ \b -> do
+    lam "nt" (mkTyConApp nttc [mkTyVarTy a, mkTyVarTy b]) $ \nt -> do
+    deconNT "co" (Var nt) $ \co -> do
+    conNT nttc $ do
+    return $ TyConAppCo listTyCon [CoVarCo co]
 
 bind _ b = do
     --putMsg (ppr b)
@@ -233,15 +178,12 @@ replaceCreateNT e@((App (App (Var f) (Type ta)) (Type tb)))
         -- TODO: Check if all construtors are in scope
         -- TODO: Check that the expanded type of a is actually b
 
-        --putMsg (ppr e)
         -- Extract the typcon from f's type
         let nttc = tyConAppTyCon (exprType e)
-        let [dc] = tyConDataCons nttc
-        let e' = mkConApp dc [ Type ta, Type tb, Coercion (mkAxInstCo coa tyArgs)] :: CoreExpr
-        --putMsg (ppr nttc)
-        --putMsg (ppr (tyConDataCons nttc))
-        --putMsg (ppr e')
-        return (Just e')
+
+        Just <$> do
+        conNT nttc $ do
+        return $ mkAxInstCo coa tyArgs
     | otherwise = do
         --putMsgS $ getOccString f
         return Nothing
@@ -274,3 +216,31 @@ createTyVar :: String -> CoreM TyVar
 createTyVar name = do
     u <- getUniqueM
     return $ mkTyVar (mkSystemName u (mkTyVarOcc name)) liftedTypeKind
+
+tyLam :: String -> (TyVar -> CoreM CoreExpr) -> CoreM CoreExpr
+tyLam name body = do 
+    v <- createTyVar name
+    Lam v <$> body v
+
+lam :: String -> Type -> (Var -> CoreM CoreExpr) -> CoreM CoreExpr
+lam name ty body = do 
+    u <- getUniqueM
+    let v = mkLocalVar VanillaId (mkSystemName u (mkVarOcc name)) ty vanillaIdInfo
+    Lam v <$> body v
+    
+deconNT :: String -> CoreExpr -> (CoVar -> CoreM CoreExpr) -> CoreM CoreExpr
+deconNT name nt body = do
+    let ntType = exprType nt
+    let (nttc, [t1, t2]) = splitTyConApp ntType
+    cou <- getUniqueM
+    let co = mkCoVar (mkSystemName cou (mkTyVarOcc name)) (mkCoercionType t1 t2)
+        [dc] = tyConDataCons nttc
+    b <- body co
+    return $ mkWildCase nt ntType (exprType b) [(DataAlt dc, [co], b)]
+
+conNT :: TyCon -> CoreM Coercion -> CoreM CoreExpr
+conNT nttc body = do
+    co <- body 
+    let Pair t1 t2  = coercionKind co
+    return $ mkConApp dc [ Type t1 , Type t2 , Coercion co ]
+  where [dc] = tyConDataCons nttc