1 {-# LANGUAGE PatternGuards, DeriveDataTypeable #-}
10 import Data.Generics hiding (typeOf)
11 import Data.Generics.Schemes
13 data TypedExpr = TypedExpr
16 } deriving (Eq, Typeable, Data)
-- | Attach a type to an expression after instantiating it with 'instType'.
-- 'typedLeft' passes @False@ and 'typedRight' passes @True@; the meaning of
-- that flag is defined by 'instType' (presumably which side of the relation
-- the type is instantiated for -- confirm there).
typedLeft, typedRight :: Expr -> Typ -> TypedExpr
typedLeft e = TypedExpr e . instType False
typedRight e = TypedExpr e . instType True
25 deriving (Eq, Typeable, Data)
31 | Conc [Expr] -- Conc [] is Id
32 | Lambda TypedExpr Expr
38 deriving (Eq, Typeable, Data)
-- | A binary relation between two bound variables, as carried by
-- 'AllZipWith' and 'AndEither'.
data LambdaBE
    = CurriedEquals Typ
      -- ^ Plain equality of two values of the given type; 'lambdaBE'
      --   produces this when the relation is just @v1 == v2@.
    | LambdaBE TypedExpr TypedExpr BoolExpr
      -- ^ Two bound variables together with the relation between them.
    deriving (Eq, Typeable, Data)
46 | And [BoolExpr] -- And [] is True
47 | AllZipWith LambdaBE Expr Expr
48 | AndEither LambdaBE LambdaBE Expr Expr
49 | Condition [TypedExpr] BoolExpr BoolExpr
50 | UnpackPair TypedExpr TypedExpr TypedExpr BoolExpr
51 | TypeVarInst Bool Int BoolExpr
52 deriving (Eq, Typeable, Data)
-- | Build the proposition that two typed expressions are equal.
--
-- Both operands must have the same type. A mismatch is an internal error; the
-- message shows both operands (in the same format as 'app') so the offending
-- terms can be identified. The proposition itself is built by 'equal'' on the
-- untyped expressions, which first inspects both sides with
-- 'isFunctionApplication' (to try eta reduction) and otherwise yields a plain
-- 'Equal'.
equal :: TypedExpr -> TypedExpr -> BoolExpr
equal te1 te2
    | typeOf te1 /= typeOf te2 =
        error $ "Type mismatch in equal: " ++ show te1 ++ " " ++ show te2
    | otherwise = equal' (unTypeExpr te1) (unTypeExpr te2)
61 equal' :: Expr -> Expr -> BoolExpr
62 equal' e1 e2 | (Just (lf,lv)) <- isFunctionApplication e1
63 , (Just (rf,rv)) <- isFunctionApplication e2
66 -- This makes it return True...
68 | otherwise = Equal e1 e2
70 -- | If e is of the form (app f1 (app f2 .. (app fn x)..)),
71 -- return Just (f1 . f2 . ... . fn, x)
72 isFunctionApplication :: Expr -> Maybe (Expr, Expr)
73 isFunctionApplication (App f e') | (Just (inner,v)) <- isFunctionApplication e'
74 = Just (conc f inner, v)
77 isFunctionApplication _ = Nothing
80 -- | If both bound variables are just functions, we can replace this
82 unpackPair :: TypedExpr -> TypedExpr -> TypedExpr -> BoolExpr -> BoolExpr
83 unpackPair v1 v2 te be | Just subst1 <- findReplacer v1 be
84 , Just subst2 <- findReplacer v2 be
85 = subst1. subst2 $ (pair v1 v2 `equal` te) `aand` be
87 -- | If the whole tuple is a function, we can replace this
89 unpackPair v1 v2 te be | Just subst <- findReplacer (pair v1 v2) be
90 = subst $ (pair v1 v2 `equal` te) `aand` be
92 -- | Nothing to optimize
93 unpackPair v1 v2 te be = UnpackPair v1 v2 te be
-- | Pair two typed expressions; the result has the product ('TPair') of the
-- component types.
pair :: TypedExpr -> TypedExpr -> TypedExpr
pair (TypedExpr lhsE lhsT) (TypedExpr rhsE rhsT) = TypedExpr pairE pairT
  where
    pairE = Pair lhsE rhsE
    pairT = TPair lhsT rhsT
-- | Smart constructor for 'LambdaBE'. If both variables have the same type
-- and the relation is exactly their equality, it collapses to
-- 'CurriedEquals' at that type; otherwise the relation is kept verbatim.
-- The type check runs first, so 'equal' is never called on mismatched types.
lambdaBE :: TypedExpr -> TypedExpr -> BoolExpr -> LambdaBE
lambdaBE v1 v2 rel
    | sameType && rel == equal v1 v2 = CurriedEquals (typeOf v1)
    | otherwise                      = LambdaBE v1 v2 rel
  where
    sameType = typeOf v1 == typeOf v2
103 andEither :: LambdaBE -> LambdaBE -> TypedExpr -> TypedExpr -> BoolExpr
104 andEither (CurriedEquals _) (CurriedEquals _) e1 e2 = e1 `equal` e2
105 andEither lbe1 lbe2 e1 e2 | Just f1 <- arg1IsFunc lbe1
106 , Just f2 <- arg1IsFunc lbe2
107 = e1 `equal` eitherE f1 f2 e2
108 | Just f1 <- arg2IsFunc lbe1
109 , Just f2 <- arg2IsFunc lbe2
110 = eitherE f1 f2 e1 `equal` e2
112 = AndEither lbe1 lbe2 (unTypeExpr e1) (unTypeExpr e2)
-- If the relation determines its first variable as a function of the second,
-- return that function as a typed lambda over the second variable.
-- Plain equality yields the identity (Conc []) at type t -> t.
arg1IsFunc :: LambdaBE -> Maybe TypedExpr
arg1IsFunc (CurriedEquals t) = Just (TypedExpr (Conc []) (Arrow t t))
arg1IsFunc (LambdaBE v1 v2 rel) = fmap (lambda v2) (defFor v1 rel)
-- If the relation determines its second variable as a function of the first,
-- return that function as a typed lambda over the first variable.
-- Plain equality yields the identity (Conc []) at type t -> t.
arg2IsFunc :: LambdaBE -> Maybe TypedExpr
arg2IsFunc (CurriedEquals t) = Just (TypedExpr (Conc []) (Arrow t t))
arg2IsFunc (LambdaBE v1 v2 rel) = fmap (lambda v1) (defFor v2 rel)
124 allZipWith :: TypedExpr -> TypedExpr -> BoolExpr -> TypedExpr -> TypedExpr -> BoolExpr
125 allZipWith v1 v2 rel e1 e2 | Just v1' <- defFor v1 rel =
126 e1 `equal` amap (lambda v2 v1') e2
127 | Just v2' <- defFor v2 rel =
128 amap (lambda v1 v2') e1 `equal` e2
130 AllZipWith (LambdaBE v1 v2 rel) (unTypeExpr e1) (unTypeExpr e2)
132 eitherE :: TypedExpr -> TypedExpr -> TypedExpr -> TypedExpr
133 eitherE f1 f2 e | Arrow lt1 lt2 <- typeOf f1
134 , Arrow rt1 rt2 <- typeOf f2
135 , TEither lt rt <- typeOf e
138 = let tEither = TypedExpr EitherMap (Arrow (typeOf f1) (Arrow (typeOf f2) (Arrow (typeOf e) (TEither lt2 rt2))))
139 in app (app (app tEither f1) f2) e
140 | otherwise = error $ "Type error in eitherE\n" ++ show (f1, f2, e)
142 amap :: TypedExpr -> TypedExpr -> TypedExpr
143 amap tf tl | Arrow t1 t2 <- typeOf tf
144 , List t <- typeOf tl
146 = let tMap = TypedExpr Map (Arrow (Arrow t1 t2) (Arrow (List t1) (List t2)))
147 in app (app tMap tf) tl
148 | otherwise = error "Type error in map"
-- | Conjunction of two boolean expressions. Existing 'And' nodes on either
-- side are flattened into a single 'And' list (left conjuncts first); other
-- operands become single conjuncts. No further simplification is done.
aand :: BoolExpr -> BoolExpr -> BoolExpr
aand l r = case (l, r) of
    (And ls, And rs) -> And (ls ++ rs)
    (And ls, _)      -> And (ls ++ [r])
    (_, And rs)      -> And (l : rs)
    _                -> And [l, r]
159 -- | Optimize a forall condition
160 condition :: [TypedExpr] -> BoolExpr -> BoolExpr -> BoolExpr
162 condition [] cond concl | cond == beTrue
164 -- float out conditions on the right
165 condition vars cond (Condition vars' cond' concl')
166 = condition (vars ++ vars') (cond `aand` cond') concl'
168 -- Try to find variables that are functions of other variables, and remove them
169 condition vars cond concl | True -- set to false to disable
170 , ((vars',cond',concl'):_) <- mapMaybe try vars
171 = condition vars' cond' concl'
172 -- A variable which can be replaced
173 where try v | Just subst <- findReplacer v cond
174 = -- trace ("Tested " ++ show v ++ ", can be replaced") $
175 Just (delete v vars, subst cond, subst concl)
177 -- A variable which can be removed, as it occurs in neither cond nor concl
178 | not (unTypeExpr v `occursIn` cond || unTypeExpr v `occursIn` concl)
179 = -- trace ("Tested " ++ show v ++ ", can be reased") $
180 Just (delete v vars, cond, concl)
182 -- Nothing to do with this variable
184 = -- trace ("Tested " ++ show v ++ " without success") $
187 -- Nothing left to optimize
188 condition vars cond concl = Condition vars cond concl
190 -- | Replaces a Term in a BoolExpr
191 replaceTermBE :: Expr -> Expr -> BoolExpr -> BoolExpr
192 replaceTermBE d r = go
193 where go (e1 `Equal` e2) | d == e1 && r == e2 = beTrue
194 | d == e2 && r == e1 = beTrue
195 | otherwise = go' e1 `equal'` go' e2
196 go (And es) = foldr aand beTrue (map go es)
197 go (AllZipWith lbe e1 e2)
198 = AllZipWith (goL lbe) (go' e1) (go' e2)
199 go (AndEither lbe1 lbe2 e1 e2)
200 = AndEither (goL lbe1) (goL lbe2) (go' e1) (go' e2)
201 go (Condition vs cond concl)
202 = condition vs (go cond) (go concl)
203 go (UnpackPair v1 v2 e be)
204 = unpackPair v1 v2 (goT e) (go be)
205 go (TypeVarInst _ _ _) = error "TypeVarInst not expected here"
207 go' = replaceExpr d r
209 goT te = te { unTypeExpr = go' (unTypeExpr te) }
211 goL (CurriedEquals t) = (CurriedEquals t)
212 goL (LambdaBE v1 v2 be) = lambdaBE v1 v2 (go be)
215 replaceExpr :: Expr -> Expr -> Expr -> Expr
217 where go e | e == d = r
218 go (App e1 e2) = app' (go e1) (go e2)
219 go (Conc es) = foldr conc (Conc []) (map go es)
220 go (Lambda te e) = lambda' te (go e)
221 go (Pair e1 e2) = Pair (go e1) (go e2)
-- | Look for a definition of the given variable inside the term. On success,
-- the result is a rewrite that substitutes the variable by its definition
-- (see 'findReplacer'', which works on the untyped expression).
findReplacer :: TypedExpr -> BoolExpr -> Maybe (BoolExpr -> BoolExpr)
findReplacer = findReplacer' . unTypeExpr
229 -- | Find a definition, and return a substitution
230 findReplacer' :: Expr -> BoolExpr -> Maybe (BoolExpr -> BoolExpr)
231 -- For combined types, look up the components
232 findReplacer' (Pair x y) e | Just (delX) <- findReplacer' x e
233 , Just (delY) <- findReplacer' y e
235 -- Find the definition
236 findReplacer' e (e1 `Equal` e2) | e == e1 = Just (replaceTermBE e e2)
237 | e == e2 = Just (replaceTermBE e e1)
238 findReplacer' e (And es) = listToMaybe (mapMaybe (findReplacer' e) es)
239 -- assuming no two definitions can exist
240 findReplacer' _ _ = Nothing
-- | Look for a definition of the given variable inside the term. If one is
-- found (via 'defFor''), it is returned typed at the variable's own type.
defFor :: TypedExpr -> BoolExpr -> Maybe TypedExpr
defFor tv be = fmap withVarType (defFor' (unTypeExpr tv) be)
  where
    withVarType d = TypedExpr d (typeOf tv)
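248 -- | Find the definition of the expression inside the term: the other side of an equation mentioning it, or the single such definition among the conjuncts of an And (Nothing otherwise)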
249 defFor' :: Expr -> BoolExpr -> Maybe (Expr)
250 defFor' e (e1 `Equal` e2) | e == e1 = Just (e2)
251 | e == e2 = Just (e1)
252 defFor' e (And es) | [d] <- mapMaybe (defFor' e) es -- exactly one definition
254 defFor' _ _ = Nothing
256 app :: TypedExpr -> TypedExpr -> TypedExpr
257 app te1 te2 | Arrow t1 t2 <- typeOf te1
260 = TypedExpr (app' (unTypeExpr te1) (unTypeExpr te2)) t2
261 app te1 te2 | otherwise = error $ "Type mismatch in app: " ++
262 show te1 ++ " " ++ show te2
264 app' :: Expr -> Expr -> Expr
265 app' Map (Conc []) = Conc [] -- map id = id
266 app' ConstUnit v = EUnit -- id x = x
267 app' (Conc []) v = v -- id x = x
-- | Build a typed lambda abstraction of body @e@ over variable @tv@.
-- The untyped term is produced by 'lambda'' (which may simplify it, e.g. to
-- the identity when the body is the variable itself); the resulting type is
-- @typeOf tv -> typeOf e@.
lambda :: TypedExpr -> TypedExpr -> TypedExpr
lambda tv e = TypedExpr body (Arrow argT resT)
  where
    body = lambda' tv (unTypeExpr e)
    argT = typeOf tv
    resT = typeOf e
273 lambda' :: TypedExpr -> Expr -> Expr
274 lambda' tv e | (Just e') <- isApplOn (unTypeExpr tv) e
275 , not (unTypeExpr tv `occursIn` e')
277 | unTypeExpr tv == e = Conc []
278 | otherwise = Lambda tv e
-- | Function composition of two expressions. Existing 'Conc' chains on either
-- side are flattened into one 'Conc' list (left operand first); other operands
-- become single elements. @Conc []@ is the identity.
conc :: Expr -> Expr -> Expr
conc l r = case (l, r) of
    (Conc ls, Conc rs) -> Conc (ls ++ rs)
    (Conc ls, _)       -> Conc (ls ++ [r])
    (_, Conc rs)       -> Conc (l : rs)
    _                  -> Conc [l, r]
287 -- Specialization of g'
289 specialize (TypeVarInst strict i be') =
290 replaceTermBE (Var (FromTypVar i)) (if strict then Conc [] else ConstUnit) .
291 everywhere (mkT $ go) $
293 where be = specialize be'
294 go (TypInst i' _) | i' == i = TUnit
296 -- No need to go further once we are through the quantifiers
-- | If @expr@ is a chain of applications @f1 (f2 (... (fn v)))@ ending in the
-- variable @v@, return the composed function @f1 . f2 . ... . fn@.
-- The bare variable itself (no application at all) yields 'Nothing'.
isApplOn :: Expr -> Expr -> Maybe Expr
isApplOn v expr
    | v == expr = Nothing
    | otherwise = case expr of
        App f arg
            | v == arg  -> Just (Conc [f])
            | otherwise -> fmap (conc f) (isApplOn v arg)
        _ -> Nothing
-- | Does a value equal to @needle@ occur anywhere (at any depth, including the
-- root) inside @haystack@? Uses a generic (syb) traversal.
occursIn :: (Eq a, Typeable a, Data b) => a -> b -> Bool
needle `occursIn` haystack = everything (||) (mkQ False (== needle)) haystack
310 isTuple :: Typ -> Bool
311 isTuple (TPair _ _) = True
324 instance Show EVar where
326 show (FromTypVar i) = "g" ++ show i
327 show (FromParam i b) = "x" ++ show i ++ if b then "'" else ""
330 instance Show Expr where
331 showsPrec _ (Var v) = showsPrec 11 v
332 showsPrec d (App e1 e2) = showParen (d>10) $
333 showsPrec 10 e1 . showChar ' ' . showsPrec 11 e2
334 showsPrec _ (Conc []) = showString "id"
335 showsPrec d (Conc [e]) = showsPrec d e
336 showsPrec d (Conc es) = showParen (d>9) $
337 showIntercalate (showString " . ") (map (showsPrec 10) es)
338 showsPrec _ (Lambda tv e) = showParen True $
343 showsPrec _ (Pair e1 e2) = showParen True $
347 showsPrec _ Map = showString "map"
348 showsPrec d ConstUnit = showParen (d>10) $ showString "const ()"
349 showsPrec _ EitherMap = showString "eitherMap"
-- | Concatenate 'ShowS' pieces with a separator between consecutive pieces,
-- analogous to 'intercalate' for strings. An empty list shows nothing.
showIntercalate :: ShowS -> [ShowS] -> ShowS
showIntercalate _ [] = id
showIntercalate sep (s : ss) = s . foldr (\t rest -> sep . t . rest) id ss
356 instance Show TypedExpr where
357 showsPrec d (TypedExpr e t) =
361 showString (showTypePrec 0 t)
363 instance Show LambdaBE where
364 show (CurriedEquals _) =
366 show (LambdaBE v1 v2 be) =
369 showsPrec 11 (unTypeExpr v1) "" ++
371 showsPrec 11 (unTypeExpr v2) "" ++
376 instance Show BoolExpr where
377 show (Equal e1 e2) = showsPrec 9 e1 $
380 show (And []) = "True"
381 show (And bes) = intercalate " && " $ map show bes
382 show (AllZipWith lbe e1 e2) =
386 showsPrec 11 e1 "" ++
389 show (AndEither lbe1 lbe2 e1 e2) =
395 showsPrec 11 e1 "" ++
398 show (Condition tvars be1 be2) =
400 intercalate ", " (map show tvars) ++
402 (if be1 /= beTrue then indent 2 (show be1) ++ "==>\n" else "") ++
404 show (UnpackPair v1 v2 e be) =
413 show (TypeVarInst strict i be) =
419 (if strict then "strict " else "(strict) ") ++
-- | Prefix every line of the text with @n@ spaces. Every line, including the
-- last, ends with a newline in the result.
indent :: Int -> String -> String
indent n s = unlines [pad ++ l | l <- lines s]
  where
    pad = replicate n ' '
432 showTypePrec :: Int -> Typ -> String
433 showTypePrec _ Int = "Int"
434 showTypePrec _ (TVar (TypVar i)) = "a"++show i
435 showTypePrec _ (TVar (TypInst i b)) | not b = "t" ++ show (2*i-1)
436 | b = "t" ++ show (2*i)
437 showTypePrec _ (TVar (TUnit)) = "()"
438 showTypePrec d (Arrow t1 t2) = paren (d>9) $
439 showTypePrec 10 t1 ++
442 showTypePrec _ (List t) = "[" ++ showTypePrec 0 t ++ "]"
443 showTypePrec _ (TEither t1 t2) = "Either " ++ showTypePrec 11 t1 ++
444 " " ++ showTypePrec 11 t2
445 showTypePrec _ (TPair t1 t2) = "(" ++ showTypePrec 0 t1 ++
446 "," ++ showTypePrec 0 t2 ++ ")"
447 showTypePrec _ t = error $ "Did not expect to show " ++ show t
-- | Wrap the string in parentheses when the flag is set; otherwise return it
-- unchanged. String counterpart of 'showParen'.
paren :: Bool -> String -> String
paren wrap body
    | wrap      = '(' : body ++ ")"
    | otherwise = body