Initial check-in
authorJoachim Breitner <mail@joachim-breitner.de>
Mon, 1 Jul 2013 11:18:44 +0000 (13:18 +0200)
committerJoachim Breitner <mail@joachim-breitner.de>
Mon, 1 Jul 2013 11:18:44 +0000 (13:18 +0200)
GHC/NT.hs [new file with mode: 0644]
GHC/NT/Plugin.hs [new file with mode: 0644]
GHC/NT/Type.hs [new file with mode: 0644]
LICENSE [new file with mode: 0644]
Setup.hs [new file with mode: 0644]
nt-coerce.cabal [new file with mode: 0644]
test.hs [new file with mode: 0644]

diff --git a/GHC/NT.hs b/GHC/NT.hs
new file mode 100644 (file)
index 0000000..ffa9f47
--- /dev/null
+++ b/GHC/NT.hs
@@ -0,0 +1,25 @@
+{-# OPTIONS_GHC -fplugin GHC.NT.Plugin #-}
+
+module GHC.NT (NT, coerce, refl, sym, trans, createNT, listNT) where
+
+import GHC.NT.Type
+
+coerce :: NT a b -> a -> b
+coerce = error "GHC.NT.coerce"
+
+refl   :: NT a a
+refl = error "GHC.NT.refl"
+
+sym    :: NT a b -> NT b a
+sym = error "GHC.NT.sym"
+
+trans  :: NT a b -> NT b c -> NT a c
+trans = error "GHC.NT.trans"
+
+createNT :: NT a b
+createNT = error "GHC.NT.createNT"
+{-# NOINLINE createNT #-}
+
+listNT :: NT a b -> NT [a] [b]
+listNT = error "GHC.NT.liftNT"
+
diff --git a/GHC/NT/Plugin.hs b/GHC/NT/Plugin.hs
new file mode 100644 (file)
index 0000000..27f24c3
--- /dev/null
@@ -0,0 +1,255 @@
+{-# LANGUAGE TupleSections #-}
+
+module GHC.NT.Plugin where
+
+import GhcPlugins
+import MkId
+import Kind
+
+import Control.Monad
+import Control.Applicative
+import Data.Functor
+import Data.Maybe
+
+plugin :: Plugin
+plugin = defaultPlugin {
+    installCoreToDos = install
+  }
+
+install :: [CommandLineOption] -> [CoreToDo] -> CoreM [CoreToDo]
+install _ xs = do
+    reinitializeGlobals
+    return $ CoreDoPasses [nt,nt2] : xs
+  where nt = CoreDoPluginPass "GHC.NT implementation" ntPass
+        nt2 = CoreDoPluginPass "GHC.NT.createNT implementation" nt2Pass
+
+ntPass :: ModGuts -> CoreM ModGuts
+ntPass g | moduleNameString (moduleName (mg_module g)) == "GHC.NT.Type" = do
+    let [oldTc] = mg_tcs g
+    nttc <- createNTTyCon (mg_module g) oldTc
+    tcs' <- mapM (replaceTyCon nttc) (mg_tcs g)
+
+    dflags <- getDynFlags
+    return $ g { mg_tcs = tcs' }
+ntPass g | moduleNameString (moduleName (mg_module g)) == "GHC.NT" = do
+    nttc <- lookupNTTyCon (mg_module g)
+    binds' <- mapM (bind nttc) (mg_binds g)
+
+    dflags <- getDynFlags
+    dflags <- getDynFlags
+
+    return $ g { mg_binds = binds' }
+ntPass g = return g
+
+nt2Pass :: ModGuts -> CoreM ModGuts
+nt2Pass = bindsOnlyPass $ mapM bind2
+
+createNTTyCon :: Module -> TyCon -> CoreM TyCon
+createNTTyCon mod oldTyCon = do
+    a <- createTyVar "a"
+    b <- createTyVar "b"
+    let arg_tys = map mkTyVarTy [a,b]
+    let tyConU = tyConUnique oldTyCon
+    dataConU <- getUniqueM
+    dataConWorkerU <- getUniqueM
+    dataConWrapperU <- getUniqueM
+    let cot = mkCoercionType (mkTyVarTy a) (mkTyVarTy b)
+        rett = mkTyConApp t' arg_tys
+        dct = mkForAllTys [a,b] $ mkFunTy cot rett
+        -- Have to use the original name, otherwise we get a 
+        -- urk! lookup local fingerprint
+        --tcn = mkExternalName tyConU mod (mkTcOcc "NT") noSrcSpan
+        tcn = tyConName oldTyCon
+        n = mkExternalName dataConU mod (mkDataOcc "NT") noSrcSpan
+        dataConWorkerN = mkSystemName dataConWorkerU (mkDataOcc "NT_work")
+        dataConWrapperN = mkSystemName dataConWrapperU (mkDataOcc "NT_wrap")
+        workId = mkGlobalId (DataConWrapId dc') dataConWorkerN dct vanillaIdInfo
+        dataConIds = mkDataConIds dataConWorkerN dataConWrapperN dc'
+        dc' = mkDataCon
+                n
+                False
+                [ HsNoBang ]
+                []
+                [a,b]
+                []
+                []
+                []
+                [ cot ]
+                rett
+                t'
+                []
+                dataConIds -- (DCIds Nothing workId)
+        t' = mkAlgTyCon
+               tcn 
+               (mkArrowKinds [liftedTypeKind, liftedTypeKind] liftedTypeKind)
+               [a,b]
+               Nothing
+               []
+               (DataTyCon [dc'] False)
+               NoParentTyCon
+               NonRecursive
+               False
+    return t'
+
+replaceTyCon :: TyCon -> TyCon -> CoreM TyCon
+replaceTyCon nttc t 
+    | occNameString (nameOccName (tyConName t)) == "NT" = return nttc
+    | otherwise = return t
+
+lookupNTTyCon :: Module -> CoreM TyCon
+lookupNTTyCon mod = do
+    let packageId = modulePackageId mod -- HACK!
+    let ntTypeModule = mkModule packageId (mkModuleName "GHC.NT.Type")
+    nc <- getOrigNameCache
+
+    dflags <- getDynFlags
+    --putMsgS $ showSDoc dflags (ppr (moduleEnvKeys nc))
+    --putMsgS $ showSDoc dflags (ppr ntTypeModule)
+
+    -- HACK!
+    let ntTypeModule = last (moduleEnvKeys nc) -- Why does the other not work?
+    --putMsgS $ showSDoc dflags (ppr ntTypeModule)
+
+    let Just occEnv = lookupModuleEnv nc ntTypeModule
+
+    --putMsgS $ showSDoc dflags (ppr (occEnv)) 
+    -- let Just ntTyConName = lookupOccEnv occEnv (mkTcOccFS (fsLit "NT")) -- Why does this not work?
+    -- MORE HACKS!
+    let [ntTyConName] = occEnvElts occEnv
+    lookupTyCon ntTyConName
+
+bind :: TyCon -> CoreBind -> CoreM CoreBind
+bind nttc b@(NonRec v e) | getOccString v == "coerce" = do
+    dflags <- getDynFlags
+    au <- getUniqueM
+    bu <- getUniqueM
+    ntu <- getUniqueM
+    nttu <- getUniqueM
+    xu <- getUniqueM
+    cou <- getUniqueM
+    let a   = mkTyVar (mkSystemName au (mkTyVarOcc "a")) liftedTypeKind
+        b   = mkTyVar (mkSystemName bu (mkTyVarOcc "b")) liftedTypeKind
+        ntt = mkTyConApp nttc [mkTyVarTy a, mkTyVarTy b]
+        nt  = mkLocalVar VanillaId (mkSystemName ntu (mkVarOcc "nt")) ntt vanillaIdInfo
+        x   = mkLocalVar VanillaId (mkSystemName xu (mkVarOcc "b")) (mkTyVarTy a) vanillaIdInfo
+        co = mkCoVar (mkSystemName cou (mkTyVarOcc "co")) (mkCoercionType (mkTyVarTy a) (mkTyVarTy b))
+        [dc] = tyConDataCons nttc
+    let e' = Lam a $ Lam b $ Lam nt $ Lam x $
+                Case (Var nt) nt (mkTyVarTy b)
+                    [(DataAlt dc, [co], 
+                        Cast (Var x) (CoVarCo co)
+                    )]
+    return (NonRec v e')
+
+bind nttc b@(NonRec v e) | getOccString v == "sym" = do
+    dflags <- getDynFlags
+    a <- createTyVar "a"
+    b <- createTyVar "b"
+    ntu <- getUniqueM
+    nttu <- getUniqueM
+    ntt'u <- getUniqueM
+    cou <- getUniqueM
+    let ntt = mkTyConApp nttc [mkTyVarTy a, mkTyVarTy b]
+        ntt' = mkTyConApp nttc [mkTyVarTy b, mkTyVarTy a]
+        nt  = mkLocalVar VanillaId (mkSystemName ntu (mkVarOcc "nt")) ntt vanillaIdInfo
+        co = mkCoVar (mkSystemName cou (mkTyVarOcc "co")) (mkCoercionType (mkTyVarTy a) (mkTyVarTy b))
+        [dc] = tyConDataCons nttc
+    let e' = Lam a $ Lam b $ Lam nt $
+            Case (Var nt) nt ntt'
+                [(DataAlt dc, [co], mkConApp dc
+                    [ Type (mkTyVarTy b)
+                    , Type (mkTyVarTy a)
+                    ,  Coercion (SymCo (CoVarCo co))
+                    ]
+                )]
+    return (NonRec v e')
+
+bind nttc b@(NonRec v e) | getOccString v == "listNT" = do
+    a <- createTyVar "a"
+    b <- createTyVar "b"
+    ntu <- getUniqueM
+    nttu <- getUniqueM
+    ntt'u <- getUniqueM
+    cou <- getUniqueM
+    let ntt = mkTyConApp nttc [mkTyVarTy a, mkTyVarTy b]
+        ntt' = mkTyConApp nttc [mkTyConApp listTyCon [mkTyVarTy a], mkTyConApp listTyCon [mkTyVarTy b]]
+        nt  = mkLocalVar VanillaId (mkSystemName ntu (mkVarOcc "nt")) ntt vanillaIdInfo
+        co = mkCoVar (mkSystemName cou (mkTyVarOcc "co")) (mkCoercionType (mkTyVarTy a) (mkTyVarTy b))
+        [dc] = tyConDataCons nttc
+    let e' = Lam a $ Lam b $ Lam nt $
+            Case (Var nt) nt ntt' 
+                [(DataAlt dc, [co], mkConApp dc
+                    [ Type (mkTyConApp listTyCon [mkTyVarTy a])
+                    , Type (mkTyConApp listTyCon [mkTyVarTy b])
+                    ,  Coercion (TyConAppCo listTyCon [CoVarCo co])
+                    ]
+                )]
+    return (NonRec v e')
+
+bind _ b = do
+    dflags <- getDynFlags
+    --putMsgS $ showSDoc dflags (ppr b)
+    return b
+
+bind2 :: CoreBind -> CoreM CoreBind
+bind2 (NonRec v e) = NonRec v <$> traverse replaceCreateNT e
+bind2 (Rec binds) = Rec <$> mapM (\(v,e) -> (\e' -> (v,e')) <$> traverse replaceCreateNT e) binds
+
+
+replaceCreateNT :: CoreExpr -> CoreM (Maybe CoreExpr)
+replaceCreateNT e@((App (App (Var f) (Type ta)) (Type tb)))
+    | getOccString f == "createNT" = do
+        -- We exepct ta to be a newtype of tb
+        (tc,tyArgs) <- case splitTyConApp_maybe ta of
+            Nothing -> error "not a type application"
+            Just (tc,tyArgs) -> return (tc,tyArgs)
+        (vars,coa) <- case unwrapNewTyCon_maybe tc of
+            Nothing -> error "not a newtype"
+            Just (vars,_,co) -> return (vars,co)
+
+        -- TODO: Check if all construtors are in scope
+        -- TODO: Check that the expanded type of a is actually b
+
+        dflags <- getDynFlags
+        --putMsgS $ showSDoc dflags (ppr e)
+        -- Extract the typcon from f's type
+        let nttc = tyConAppTyCon (exprType e)
+        let [dc] = tyConDataCons nttc
+        let e' = mkConApp dc [ Type ta, Type tb, Coercion (mkAxInstCo coa tyArgs)] :: CoreExpr
+        --putMsgS $ showSDoc dflags (ppr nttc)
+        --putMsgS $ showSDoc dflags (ppr (tyConDataCons nttc))
+        --putMsgS $ showSDoc dflags (ppr e')
+        return (Just e')
+    | otherwise = do
+        --putMsgS $ getOccString f
+        return Nothing
+replaceCreateNT e = do
+    --dflags <- getDynFlags
+    --putMsgS $ showSDoc dflags (ppr e)
+    return Nothing
+
+traverse :: (Functor m, Applicative m, Monad m) => (Expr a -> m (Maybe (Expr a))) -> Expr a -> m (Expr a)
+traverse f e
+    = f' =<< case e of
+        Type t               -> return $ Type t
+        Coercion c           -> return $ Coercion c
+        Lit lit              -> return $ Lit lit
+        Var v                -> return $ Var v
+        App fun a            -> App <$> traverse f fun <*> traverse f a
+        Tick t e             -> Tick t <$> traverse f e
+        Cast e co            -> Cast <$> traverse f e <*> (return co)
+        Lam b e              -> Lam b <$> traverse f e
+        Let bind e           -> Let <$> traverseBind f bind <*> traverse f e
+        Case scrut bndr ty alts -> Case scrut bndr ty <$> mapM (\(a,b,c) -> (a,b,) <$> traverse f c) alts 
+    where f' x = do
+            r <- f x
+            return (fromMaybe x r)
+
+traverseBind :: (Functor m, Applicative m, Monad m) => (Expr a -> m (Maybe (Expr a))) -> Bind a -> m (Bind a)
+traverseBind f (NonRec b e) = NonRec b <$> traverse f e
+traverseBind f (Rec l) = Rec <$> mapM (\(a,b) -> (a,) <$> traverse f b) l
+
+createTyVar :: String -> CoreM TyVar
+createTyVar name = do
+    u <- getUniqueM
+    return $ mkTyVar (mkSystemName u (mkTyVarOcc name)) liftedTypeKind
diff --git a/GHC/NT/Type.hs b/GHC/NT/Type.hs
new file mode 100644 (file)
index 0000000..e8a7f6a
--- /dev/null
@@ -0,0 +1,6 @@
+{-# LANGUAGE EmptyDataDecls #-}
+{-# OPTIONS_GHC -fplugin GHC.NT.Plugin #-}
+
+module GHC.NT.Type where
+
+data NT a b -- = NT (a,b)
diff --git a/LICENSE b/LICENSE
new file mode 100644 (file)
index 0000000..2a93d81
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,30 @@
+Copyright (c) 2013, Joachim Breitner
+
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+    * Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+
+    * Redistributions in binary form must reproduce the above
+      copyright notice, this list of conditions and the following
+      disclaimer in the documentation and/or other materials provided
+      with the distribution.
+
+    * Neither the name of Joachim Breitner nor the names of other
+      contributors may be used to endorse or promote products derived
+      from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/Setup.hs b/Setup.hs
new file mode 100644 (file)
index 0000000..9a994af
--- /dev/null
+++ b/Setup.hs
@@ -0,0 +1,2 @@
+import Distribution.Simple
+main = defaultMain
diff --git a/nt-coerce.cabal b/nt-coerce.cabal
new file mode 100644 (file)
index 0000000..ecdbc95
--- /dev/null
@@ -0,0 +1,22 @@
+-- Initial nt-coerce.cabal generated by cabal init.  For further 
+-- documentation, see http://haskell.org/cabal/users-guide/
+
+name:                nt-coerce
+version:             0.1
+synopsis:            Zero-cost coercions for types with equal represention
+-- description:         
+license:             BSD3
+license-file:        LICENSE
+author:              Joachim Breitner
+maintainer:          mail@joachim-breitner.de
+-- copyright:           
+category:            Language
+build-type:          Simple
+cabal-version:       >=1.8
+
+library
+  exposed-modules:     
+    GHC.NT
+    GHC.NT.Plugin
+  -- other-modules:       
+  build-depends:       base ==4.6.*, ghc
diff --git a/test.hs b/test.hs
new file mode 100644 (file)
index 0000000..134d39a
--- /dev/null
+++ b/test.hs
@@ -0,0 +1,25 @@
+{-# OPTIONS_GHC -fplugin GHC.NT.Plugin #-}
+
+import GHC.NT
+
+newtype Age = Age Int deriving Show
+
+ageNT :: NT Age Int
+ageNT = createNT
+
+newtype MyList a = MyList [a] deriving Show
+
+myListNT :: NT (MyList a) [a]
+myListNT = createNT
+
+
+main = do
+    let n = 1 :: Int
+    let a = coerce (sym ageNT) 1
+    let l1 = [a]
+    let l2 = coerce (listNT ageNT) l1
+    let l3 = coerce (sym myListNT) l2
+    print a
+    print l2
+    print l3
+