Optimize untupling lets
[darcs-mirror-polyfix.git] / Expr.hs
1 {-# LANGUAGE PatternGuards, DeriveDataTypeable   #-}
2 module Expr where
3
4 import Data.List
5 import Data.Maybe
6 import ParseType
7
8 import Debug.Trace
9
10 import Data.Generics hiding (typeOf)
11 import Data.Generics.Schemes
12
13 data TypedExpr = TypedExpr
14         { unTypeExpr    :: Expr
15         , typeOf        :: Typ
16         } deriving (Eq, Typeable, Data)
17
18 typedLeft, typedRight :: Expr -> Typ -> TypedExpr
19 typedLeft  e t = TypedExpr e (instType False t)
20 typedRight e t = TypedExpr e (instType True t)
21
22 data Expr
23         = Var String
24         | App Expr Expr
25         | Conc [Expr] -- Conc [] is Id
26         | Lambda TypedExpr Expr
27         | Pair Expr Expr
28         | Map
29             deriving (Eq, Typeable, Data)
30
31 data BoolExpr 
32         = Equal Expr Expr
33         | And [BoolExpr] -- And [] is True
34         | AllZipWith TypedExpr TypedExpr BoolExpr Expr Expr
35         | Condition [TypedExpr] BoolExpr BoolExpr
36         | UnpackPair TypedExpr TypedExpr TypedExpr BoolExpr
37         | TypeVarInst Int BoolExpr
38             deriving (Eq, Typeable, Data)
39
40 -- Smart constructors
41
42 -- | Try eta reduction
43 equal te1 te2 | typeOf te1 /= typeOf te2 = error "Type mismatch in equal"
44               | otherwise                = equal' (unTypeExpr te1) (unTypeExpr te2)
45
46 equal' e1 e2  | (Just (lf,lv)) <- isFunctionApplication e1
47               , (Just (rf,rv)) <- isFunctionApplication e2
48               , lv == rv 
49                                          = equal' lf rf
50               -- This makes it return True...
51               | e1 == e2                 = beTrue
52               | otherwise                = Equal e1 e2
53
54 isFunctionApplication (App f e') | (Just (inner,v)) <- isFunctionApplication e'
55                                  = Just (conc f inner, v)
56                                  | otherwise
57                                  = Just (f, e')
58 isFunctionApplication _          = Nothing
59
60
61 -- | If both bound variables are just functions, we can replace this
62 --   by a comparison
63 unpackPair v1 v2 te be | Just subst1 <- findReplacer v1 be
64                        , Just subst2 <- findReplacer v2 be
65                        = subst1. subst2 $ (pair v1 v2 `equal` te) `aand` be
66
67 -- | If the whole tuple is a function, we can replace this
68 --   by a comparison
69 unpackPair v1 v2 te be | Just subst <- findReplacer (pair v1 v2) be
70                        = subst $ (pair v1 v2 `equal` te) `aand` be
71
72 -- | Nothing to optimize
73 unpackPair v1 v2 te be = UnpackPair v1 v2 te be
74
75 pair (TypedExpr e1 t1) (TypedExpr e2 t2) = TypedExpr (Pair e1 e2) (TPair t1 t2)
76
77 allZipWith :: TypedExpr -> TypedExpr -> BoolExpr -> TypedExpr -> TypedExpr -> BoolExpr
78 allZipWith v1 v2 rel e1 e2 | Just v1' <- defFor v1 rel =
79                                 e1 `equal` amap (lambda v2 v1') e2
80                            | Just v2' <- defFor v2 rel =
81                                 amap (lambda v1 v2') e1 `equal` e2
82                            | otherwise =
83                                 AllZipWith v1 v2 rel (unTypeExpr e1) (unTypeExpr e2)
84
85 amap tf tl | Arrow t1 t2 <- typeOf tf
86            , List t      <- typeOf tl
87            , t1 == t
88            = let tMap = TypedExpr Map (Arrow (Arrow t1 t2) (Arrow (List t1) (List t2)))
89              in app (app tMap tf) tl
90 amap tf tl | otherwise = error "Type error in map"
91
92 aand (And xs) (And ys) = And (xs  ++ ys)
93 aand (And xs) y        = And (xs  ++ [y])
94 aand x        (And ys) = And ([x] ++ ys)
95 aand x        y        = And ([x,y])
96
97 beTrue = And []
98
99 -- | Optimize a forall condition
100 condition :: [TypedExpr] -> BoolExpr -> BoolExpr -> BoolExpr
101 -- empty condition
102 condition [] cond concl   | cond == beTrue
103                           = concl
104 -- float out conditions on the right
105 condition vars cond (Condition vars' cond' concl')
106                           = condition (vars ++ vars') (cond `aand` cond') concl'
107
108 -- Try to find variables that are functions of other variables, and remove them
109 condition vars cond concl | True -- set to false to disable
110                           , ((vars',cond',concl'):_) <- mapMaybe try vars
111                           = condition vars' cond' concl'
112               -- A variable which can be replaced
113   where try v | Just subst <- findReplacer v cond
114               = -- trace ("Tested " ++ show v ++ ", can be replaced") $
115                 Just (delete v vars, subst cond, subst concl)
116  
117               -- A variable with can be removed
118               | not (unTypeExpr v `occursIn` cond || unTypeExpr v `occursIn` concl)
119               = -- trace ("Tested " ++ show v ++ ", can be reased") $
120                 Just (delete v vars, cond, concl)
121
122               -- Nothing to do with this variable
123               | otherwise
124               = -- trace ("Tested " ++ show v ++ " without success") $
125                 Nothing
126
127 -- Nothing left to optizmize
128 condition vars cond concl = Condition vars cond concl
129
130 -- | Replaces a Term in a BoolExpr
131 replaceTermBE :: Expr -> Expr -> BoolExpr -> BoolExpr
132 replaceTermBE d r = go
133   where go (e1 `Equal` e2) | d == e1 && r == e2 = beTrue
134                            | d == e2 && r == e1 = beTrue
135                            | otherwise          = go' e1 `equal'` go' e2
136         go (And es)        = foldr aand beTrue (map go es)
137         go (AllZipWith v1 v2 be e1 e2) 
138                            = AllZipWith v1 v2 (go be) (go' e1) (go' e2)
139         go (Condition vs cond concl)
140                            = condition vs (go cond) (go concl)
141         go (UnpackPair v1 v2 e be)
142                            = unpackPair v1 v2 (go' e) (go be)
143         go (TypeVarInst _ _) = error "TypeVarInst not expected here"
144         go' :: Data a => a -> a
145         go' = replaceExpr d r
146
147 replaceExpr :: Data a => Expr -> Expr -> a -> a
148 replaceExpr d r = everywhere (mkT go)
149   where go e | e == d    = r 
150              | otherwise = e
151
152 -- | Is inside the term a definition for the variable?
153 findReplacer :: TypedExpr -> BoolExpr -> Maybe (BoolExpr -> BoolExpr)
154 findReplacer tv be = findReplacer' (unTypeExpr tv) be
155         
156 -- | Find a definition, and return a substitution
157 findReplacer' :: Expr -> BoolExpr -> Maybe (BoolExpr -> BoolExpr)
158 -- For combined types, look up the components
159 findReplacer' (Pair x y) e | Just (delX) <- findReplacer' x e
160                            , Just (delY) <- findReplacer' y e
161                     = Just (delX . delY)
162 -- Find the definition
163 findReplacer' e (e1 `Equal` e2) | e == e1    = Just (replaceTermBE e e2)
164                                 | e == e2    = Just (replaceTermBE e e1)
165 findReplacer' e (And es)        = listToMaybe (mapMaybe (findReplacer' e) es)
166                                   -- assuming no two definitions can exist
167 findReplacer' _ _               = Nothing
168
169 -- | Is inside the term a definition for the variable?
170 defFor :: TypedExpr -> BoolExpr -> Maybe (TypedExpr)
171 defFor tv be | Just (e') <- defFor' (unTypeExpr tv) be
172                          = Just (TypedExpr e' (typeOf tv))
173              | otherwise = Nothing
174         
175 -- | Find a definition, and return it along the definition remover
176 defFor' :: Expr -> BoolExpr -> Maybe (Expr)
177 defFor' e (e1 `Equal` e2) | e == e1                 = Just (e2)
178                           | e == e2                 = Just (e1)
179 defFor' e (And es)        | [d]  <- mapMaybe (defFor' e) es -- exactly one definition
180                                                     = Just d
181 defFor' _ _                                         = Nothing
182
183 app te1 te2 | Arrow t1 t2 <- typeOf te1
184             , t3          <- typeOf te2 
185             , t1 == t3 
186             = TypedExpr (app' (unTypeExpr te1) (unTypeExpr te2)) t2
187  where app' Map (Conc []) = Conc []
188        app' (Conc []) v   = v
189        app' f v           = App f v
190 app te1 te2 | otherwise                          = error $ "Type mismatch in app: " ++
191                                                            show te1 ++ " " ++ show te2
192
193 {- dead code
194 unCond v (Equal l r) | (Just l') <- isApplOn (unTypeExpr v) l 
195                      , (Just r') <- isApplOn (unTypeExpr v) r = 
196         if v `occursIn` l' || v `occursIn` r'
197         then Condition [v] beTrue (Equal l' r')
198         else (Equal l' r')
199 unCond v e = Condition [v] beTrue e
200 -}
201
202 lambda tv e = TypedExpr inner (Arrow (typeOf tv) (typeOf e))
203   where inner | (Just e') <- isApplOn (unTypeExpr tv) (unTypeExpr e)
204               , not (unTypeExpr tv `occursIn` e')
205                           = e'
206               | tv == e   = Conc []
207               | otherwise = Lambda tv (unTypeExpr e)
208
209 conc (Conc xs) (Conc ys) = Conc (xs  ++ ys)
210 conc (Conc xs)  y        = Conc (xs  ++ [y])
211 conc x         (Conc ys) = Conc ([x] ++ ys)
212 conc x          y        = Conc ([x,y])
213
214 -- Helpers
215
216 isApplOn e e'         | e == e'                       = Nothing
217 isApplOn e (App f e') | e == e'                       = Just (Conc [f])
218 isApplOn e (App f e') | (Just inner) <- isApplOn e e' = Just (conc f inner)
219 isApplOn _ _                                          = Nothing
220
221 hasVar v (Var v')     = v == v'
222 hasVar v (App e1 e2)  = hasVar v e1 && hasVar v e2
223 hasVar v (Conc es)    = any (hasVar v) es
224 hasVar v (Lambda _ e) = hasVar v e
225 hasVar v Map          = False
226
227 e `occursIn` e'       = not (null (listify (==e) e'))
228
229 isTuple (TPair _ _) = True
230 isTuple _           = False
231
232
233 -- showing
234
235 -- Precedences:
236 -- 10 fun app
237 --  9 (.)
238 --  8 ==
239 --  7 ==>
240 --  6 forall
241
242 instance Show Expr where
243         showsPrec d (Var s)     = showString s
244         showsPrec d (App e1 e2) = showParen (d>10) $
245                 showsPrec 10 e1 . showChar ' ' . showsPrec 11 e2
246         showsPrec d (Conc [])   = showString "id"
247         showsPrec d (Conc [e])  = showsPrec d e
248         showsPrec d (Conc es)   = showParen (d>9) $
249                 showIntercalate (showString " . ") (map (showsPrec 10) es)
250         showsPrec d (Lambda tv e) = showParen True $ 
251                                     showString "\\" .
252                                     showsPrec 0 tv .
253                                     showString " -> ".
254                                     showsPrec 0 e 
255         showsPrec _ (Pair e1 e2) = showParen True $ 
256                                    showsPrec 0 e1 .
257                                    showString "," .
258                                    showsPrec 0 e2
259         showsPrec _ Map           = showString "map"
260
261 showIntercalate i []  = id
262 showIntercalate i [x] = x
263 showIntercalate i (x:xs) = x . i . showIntercalate i xs
264
265 instance Show TypedExpr where
266         showsPrec d (TypedExpr e t) = 
267                 showParen (d>10) $
268                         showsPrec 0 e .
269                         showString " :: " .
270                         showString (showTypePrec 0 t)
271
272 instance Show BoolExpr where
273         show (Equal e1 e2) = showsPrec 9 e1 $
274                              showString " == " $
275                              showsPrec 9 e2 ""
276         show (And [])      = "True"
277         show (And bes)     = intercalate " && " $ map show bes
278         show (AllZipWith v1 v2 be e1 e2) =
279                         "allZipWith " ++
280                         "( " ++
281                         "\\" ++
282                         showsPrec 11 v1 "" ++
283                         " " ++
284                         showsPrec 11 v2 "" ++
285                         " -> " ++
286                         show be ++
287                         ")" ++
288                         " " ++
289                         showsPrec 11 e1 "" ++
290                         " " ++
291                         showsPrec 11 e2 ""
292         show (Condition tvars be1 be2) = 
293                         "forall " ++
294                         intercalate ", " (map show tvars) ++
295                         ".\n" ++
296                         (if be1 /= beTrue then indent 2 (show be1) ++ "==>\n" else "") ++
297                         indent 2 (show be2)
298         show (UnpackPair v1 v2 e be) = 
299                         "let (" ++
300                         showsPrec 0 v1 "" ++
301                         "," ++
302                         showsPrec 0 v2 "" ++
303                         ") = " ++
304                         showsPrec 0 e "" ++
305                         " in\n" ++
306                         indent 2 (show be)
307         show (TypeVarInst i be) = 
308                         "forall types t" ++
309                         show (2*i-1) ++
310                         ", t" ++
311                         show (2*i) ++
312                         ", function g" ++
313                         show i ++
314                         " :: t" ++
315                         show (2*i-1) ++
316                         " -> t" ++
317                         show (2*i) ++ 
318                         ".\n" ++
319                         indent 2 (show be)
320
321 indent n = unlines . map (replicate n ' ' ++) . lines
322
323 showTypePrec :: Int -> Typ -> String
324 showTypePrec _ Int                          = "Int" 
325 showTypePrec _ (TVar (TypVar i))            = "a"++show i
326 showTypePrec _ (TVar (TypInst i b)) | not b = "t" ++  show (2*i-1)
327                                     |     b = "t" ++  show (2*i)
328 showTypePrec d (Arrow t1 t2)                = paren (d>9) $ 
329                                   showTypePrec 10 t1 ++ " -> " ++ showTypePrec 9 t2 
330 showTypePrec d (List t)                     = "[" ++ showTypePrec 0 t ++ "]"
331 showTypePrec d (TEither t1 t2)              = "Either " ++ showTypePrec 11 t1 ++ 
332                                                     " " ++ showTypePrec 11 t2
333 showTypePrec d (TPair t1 t2)                = "(" ++ showTypePrec 0 t1 ++
334                                               "," ++ showTypePrec 0 t2 ++ ")"
335
336 paren b p   =  if b then "(" ++ p ++ ")" else p