Deep Definition removal
[darcs-mirror-polyfix.git] / Expr.hs
diff --git a/Expr.hs b/Expr.hs
index 1fd49bc..727399a 100644 (file)
--- a/Expr.hs
+++ b/Expr.hs
@@ -2,6 +2,7 @@
 module Expr where
 
 import Data.List
+import Data.Maybe
 import ParseType
 
 import Data.Generics hiding (typeOf)
@@ -26,9 +27,8 @@ data Expr
             deriving (Eq, Typeable, Data)
 
 data BoolExpr 
-       = BETrue
-       | Equal Expr Expr
-       | And BoolExpr BoolExpr
+       = Equal Expr Expr
+       | And [BoolExpr] -- And [] is True
        | AllZipWith TypedExpr TypedExpr BoolExpr Expr Expr
        | Condition [TypedExpr] BoolExpr BoolExpr
        | UnpackPair TypedExpr TypedExpr TypedExpr BoolExpr
@@ -42,9 +42,10 @@ equal te1 te2 | typeOf te1 /= typeOf te2 = error "Type mismatch in equal"
 
 unpackPair = UnpackPair
 
-allZipWith v1 v2 rel e1 e2 | Just v1' <- defFor v1 rel =
+allZipWith :: TypedExpr -> TypedExpr -> BoolExpr -> TypedExpr -> TypedExpr -> BoolExpr
+allZipWith v1 v2 rel e1 e2 | Just (v1', _) <- defFor v1 rel =
                                e1 `equal` amap (lambda v2 v1') e2
-                           | Just v2' <- defFor v2 rel =
+                           | Just (v2', _) <- defFor v2 rel =
                                amap (lambda v1 v2') e1 `equal` e2
                            | otherwise =
                                AllZipWith v1 v2 rel (unTypeExpr e1) (unTypeExpr e2)
@@ -56,18 +57,64 @@ amap tf tl | Arrow t1 t2 <- typeOf tf
             in app (app tMap tf) tl
 amap tf tl | otherwise = error "Type error in map"
 
+aand (And xs) (And ys) = And (xs  ++ ys)
+aand (And xs) y        = And (xs  ++ [y])
+aand x        (And ys) = And ([x] ++ ys)
+aand x        y        = And ([x,y])
+
+beTrue = And []
+
+-- | Is any var (or part of var) defined in cond, and can be replaced in concl?
+condition :: [TypedExpr] -> BoolExpr -> BoolExpr -> BoolExpr
+condition vars cond concl | ((vars',cond',concl'):_) <- mapMaybe try vars
+                         = condition vars' cond' concl'
+                          | otherwise
+                          = Condition vars cond concl
+  where try v = do (def,del) <- defFor v cond --Maybe Monad
+                   return (delete v vars, del cond, del concl)
+
+-- | Replaces a Term in a BoolExpr
+replaceTermBE :: Expr -> Expr -> BoolExpr -> BoolExpr
+replaceTermBE d r = go
+  where go (e1 `Equal` e2) | d == e1 && r == e2 = beTrue
+                           | d == e2 && r == e1 = beTrue
+                           | otherwise          = go' e1 `Equal` go' e2
+        go (And es)        = foldr aand beTrue (map go es)
+        go (AllZipWith v1 v2 be e1 e2) 
+                           = AllZipWith v1 v2 (go be) (go' e1) (go' e2)
+       go (Condition vs cond concl)
+                          = condition vs (go cond) (go concl)
+       go (UnpackPair v1 v2 e be)
+                          = unpackPair v1 v2 (goT e) (go be)
+       go (TypeVarInst _ _) = error "TypeVarInst not expected here"
+       goT = replaceTypedExpr d r
+       go' = replaceExpr d r
+
+replaceExpr :: Expr -> Expr -> Expr -> Expr
+replaceExpr d r = everywhere (mkT go)
+  where go e | e == d    = r 
+             | otherwise = e
+
+replaceTypedExpr :: Expr -> Expr -> TypedExpr -> TypedExpr
+replaceTypedExpr d r = everywhere (mkT go)
+  where go e | unTypeExpr e == d = e { unTypeExpr = r }
+             | otherwise         = e
+
 -- | Is inside the term a definition for the variable?
-defFor :: TypedExpr -> BoolExpr -> Maybe TypedExpr
-defFor tv be | Just e' <- defFor' (unTypeExpr tv) be
-                         = Just (TypedExpr e' (typeOf tv))
+defFor :: TypedExpr -> BoolExpr -> Maybe (TypedExpr, BoolExpr -> BoolExpr)
+defFor tv be | Just (e', delDef) <- defFor' (unTypeExpr tv) be
+                         = Just (TypedExpr e' (typeOf tv), delDef)
              | otherwise = Nothing
        
-defFor' v (e1 `Equal` e2) | v == e1                 = Just e2
-                          | v == e2                 = Just e1
-defFor' v (e1 `And` e2)   | Just d  <- defFor' v e1
-                         , Nothing <- defFor' v e2 = Just d
-defFor' v (e1 `And` e2)   | Just d  <- defFor' v e2
-                         , Nothing <- defFor' v e1 = Just d
+-- | Find a definition, and return it along the definition remover
+defFor' :: Expr -> BoolExpr -> Maybe (Expr, BoolExpr -> BoolExpr)
+defFor' (Pair x y) e | Just (dx, delX) <- defFor' x e
+                     , Just (dy, delY) <- defFor' y e
+                    = Just ((Pair dx dy), delX . delY)
+defFor' e (e1 `Equal` e2) | e == e1                 = Just (e2, replaceTermBE e e2)
+                          | e == e2                 = Just (e1, replaceTermBE e e1)
+defFor' e (And es)        | [d]  <- mapMaybe (defFor' e) es -- exactly one definition
+                                                   = Just d
 defFor' _ _                                         = Nothing
 
 app te1 te2 | Arrow t1 t2 <- typeOf te1
@@ -83,9 +130,9 @@ app te1 te2 | otherwise                          = error $ "Type mismatch in app
 unCond v (Equal l r) | (Just l') <- isApplOn (unTypeExpr v) l 
                     , (Just r') <- isApplOn (unTypeExpr v) r = 
        if v `occursIn` l' || v `occursIn` r'
-       then Condition [v] BETrue (Equal l' r')
+       then Condition [v] beTrue (Equal l' r')
        else (Equal l' r')
-unCond v e = Condition [v] BETrue e
+unCond v e = Condition [v] beTrue e
 
 lambda tv e = TypedExpr inner (Arrow (typeOf tv) (typeOf e))
   where inner | (Just e') <- isApplOn (unTypeExpr tv) (unTypeExpr e)
@@ -158,9 +205,8 @@ instance Show BoolExpr where
        show (Equal e1 e2) = showsPrec 9 e1 $
                             showString " == " $
                             showsPrec 9 e2 ""
-       show (And be1 be2) = show be1 ++
-                            " && " ++
-                            show be2 
+       show (And [])      = show "True"
+        show (And bes)     = intercalate " && " $ map show bes
        show (AllZipWith v1 v2 be e1 e2) =
                        "allZipWith " ++
                        "( " ++
@@ -179,7 +225,7 @@ instance Show BoolExpr where
                        "forall " ++
                        intercalate ", " (map show tvars) ++
                        ".\n" ++
-                       (if be1 /= BETrue then indent 2 (show be1) ++ "==>\n" else "") ++
+                       (if be1 /= beTrue then indent 2 (show be1) ++ "==>\n" else "") ++
                        indent 2 (show be2)
        show (UnpackPair v1 v2 e be) = 
                        "let (" ++