Reduce Uncurry when possible
[darcs-mirror-polyfix.git] / Expr.hs
1 {-# LANGUAGE PatternGuards, DeriveDataTypeable   #-}
2 module Expr where
3
4 import Data.List
5 import Data.Maybe
6 import ParseType
7
8 -- import Debug.Trace
9
10 import Data.Generics hiding (typeOf)
11 import Data.Generics.Schemes
12
13 data TypedExpr = TypedExpr
14         { unTypeExpr    :: Expr
15         , typeOf        :: Typ
16         } deriving (Eq, Typeable, Data)
17
18 typedLeft, typedRight :: Expr -> Typ -> TypedExpr
19 typedLeft  e t = TypedExpr e (instType False t)
20 typedRight e t = TypedExpr e (instType True t)
21
22 data EVar = F
23           | FromTypVar Int
24           | FromParam Int Bool
25             deriving (Eq, Ord, Typeable, Data)
26
27 data Expr
28         = Var EVar
29         | App Expr Expr
30         | Conc [Expr] -- Conc [] is Id
31         | Lambda Expr Expr
32         | Pair Expr Expr
33         | Map
34         | Const Expr
35         | ELeft Expr
36         | ERight Expr
37         | CaseUnit Expr Expr
38         | EitherMap
39         | Uncurry
40         | HeadMap
41         | EUnit
42         | Singleton Expr
43         | Zero
44         | Bottom
45             deriving (Eq, Typeable, Data)
46
47 data LambdaBE = CurriedEquals Typ
48               | LambdaBE TypedExpr TypedExpr BoolExpr
49             deriving (Eq, Typeable, Data)
50
51 data BoolExpr 
52         = Equal Expr Expr
53         | And [BoolExpr] -- And [] is True
54         | AllZipWith LambdaBE Expr Expr
55         | AndEither  LambdaBE LambdaBE Expr Expr
56         | Condition [TypedExpr] BoolExpr BoolExpr
57         | UnpackPair TypedExpr TypedExpr TypedExpr BoolExpr
58         | TypeVarInst Bool Int BoolExpr
59             deriving (Eq, Typeable, Data)
60
61 -- Smart constructors
62
63 -- | Try eta reduction
64 equal :: TypedExpr -> TypedExpr -> BoolExpr
65 equal te1 te2 | typeOf te1 /= typeOf te2 = error "Type mismatch in equal"
66               | otherwise                = equal' (unTypeExpr te1) (unTypeExpr te2)
67
68 equal' :: Expr -> Expr -> BoolExpr
69 equal' e1 e2  | (Just (lf,lv)) <- isFunctionApplication e1
70               , (Just (rf,rv)) <- isFunctionApplication e2
71               , lv == rv 
72               , False
73                                          = equal' lf rf
74               -- This makes it return True...
75               | e1 == e2                 = beTrue
76               | otherwise                = Equal e1 e2
77
78 -- | If e is of the type (app f1 (app f2 .. (app fn x)..)),
79 --   return Just (f1 . f2. ... . fn, x)
80 isFunctionApplication :: Expr -> Maybe (Expr, Expr)
81 isFunctionApplication (App f e') | (Just (inner,v)) <- isFunctionApplication e'
82                                  = Just (conc f inner, v)
83                                  | otherwise
84                                  = Just (f, e')
85 isFunctionApplication _          = Nothing
86
87
88 unpackPair :: TypedExpr -> TypedExpr -> TypedExpr -> BoolExpr -> BoolExpr
89 -- | if the te is already a tuple, then replace the variables
90 unpackPair v1 v2 te be | Pair pe1 pe2 <- unTypeExpr te
91                        = replaceTermBE (unTypeExpr v1) pe1 $
92                          replaceTermBE (unTypeExpr v2) pe2 $ be
93 -- | If both bound variables are just functions, we can replace this
94 --   by a comparison
95 unpackPair v1 v2 te be | Just subst1 <- findReplacer v1 be
96                        , Just subst2 <- findReplacer v2 be
97                        = subst1. subst2 $ (pair v1 v2 `equal` te) `aand` be
98
99 -- | Don’t unpack pair if vars are not used
100 unpackPair v1 v2 te be | not (unTypeExpr v1 `occursIn` be || unTypeExpr v2 `occursIn` be)
101                        = be
102 -- | If the whole tuple is a function, we can replace this
103 --   by a comparison
104 unpackPair v1 v2 te be | Just subst <- findReplacer (pair v1 v2) be
105                        = subst $ (pair v1 v2 `equal` te) `aand` be
106 -- | Nothing to optimize
107 unpackPair v1 v2 te be = UnpackPair v1 v2 te be
108
109 pair :: TypedExpr -> TypedExpr -> TypedExpr
110 pair (TypedExpr e1 t1) (TypedExpr e2 t2) = TypedExpr (Pair e1 e2) (TPair t1 t2)
111
112 lambdaBE :: TypedExpr -> TypedExpr -> BoolExpr -> LambdaBE
113 lambdaBE v1 v2 rel | typeOf v1 == typeOf v2 
114                    , rel == v1 `equal` v2    = CurriedEquals (typeOf v1)
115                    | otherwise               = LambdaBE v1 v2 rel
116
117 andEither :: LambdaBE -> LambdaBE -> TypedExpr -> TypedExpr -> BoolExpr
118 andEither (CurriedEquals _) (CurriedEquals _) e1 e2 = e1 `equal` e2
119 andEither lbe1 lbe2 e1 e2 | Just f1 <- arg1IsFunc lbe1
120                           , Just f2 <- arg1IsFunc lbe2
121                           = e1 `equal` eitherE f1 f2 e2
122                           | Just f1 <- arg2IsFunc lbe1
123                           , Just f2 <- arg2IsFunc lbe2
124                           = eitherE f1 f2 e1 `equal` e2
125                           | otherwise
126                           = andEither' lbe1 lbe2 (unTypeExpr e1) (unTypeExpr e2)
127
128 andEither' :: LambdaBE -> LambdaBE -> Expr -> Expr -> BoolExpr
129 andEither' (LambdaBE v1 v2 rel) _ (ELeft e1) (ELeft e2)
130        = replaceTermBE (unTypeExpr v1) e1 $ replaceTermBE (unTypeExpr v2) e2 $ rel
131 andEither' _ (LambdaBE v1 v2 rel) (ERight e1) (ERight e2)
132        = replaceTermBE (unTypeExpr v1) e1 $ replaceTermBE (unTypeExpr v2) e2 $ rel
133 andEither' lbe1 lbe2 e1 e2
134        = AndEither lbe1 lbe2 e1 e2
135
136
137 arg1IsFunc :: LambdaBE -> Maybe TypedExpr
138 arg1IsFunc (CurriedEquals t)    = Just $ TypedExpr (Conc []) (Arrow t t)
139 arg1IsFunc (LambdaBE v1 v2 rel) | Just v1' <- defFor v1 rel
140                                 = Just (lambda v2 v1')
141                                 | otherwise = Nothing
142
143 arg2IsFunc :: LambdaBE -> Maybe TypedExpr
144 arg2IsFunc (CurriedEquals t)    = Just $ TypedExpr (Conc []) (Arrow t t)
145 arg2IsFunc (LambdaBE v1 v2 rel) | Just v2' <- defFor v2 rel
146                                 = Just (lambda v1 v2')
147                                 | otherwise = Nothing
148
149 allZipWith :: TypedExpr -> TypedExpr -> BoolExpr -> TypedExpr -> TypedExpr -> BoolExpr
150 allZipWith v1 v2 rel e1 e2 | Just v1' <- defFor v1 rel =
151                                 e1 `equal` amap (lambda v2 v1') e2
152                            | Just v2' <- defFor v2 rel =
153                                 amap (lambda v1 v2') e1 `equal` e2
154                            | otherwise =
155                                 AllZipWith (LambdaBE v1 v2 rel) (unTypeExpr e1) (unTypeExpr e2)
156
157 eitherE :: TypedExpr -> TypedExpr -> TypedExpr -> TypedExpr
158 eitherE f1 f2 e | Arrow lt1 lt2 <- typeOf f1
159                 , Arrow rt1 rt2 <- typeOf f2
160                 , TEither lt rt <- typeOf e
161                 , lt1 == lt
162                 , rt1 == rt
163         = let tEither = TypedExpr EitherMap (Arrow (typeOf f1) (Arrow (typeOf f2) (Arrow (typeOf e) (TEither lt2 rt2))))
164           in  app (app (app tEither f1) f2) e
165                 | otherwise = error $ "Type error in eitherE\n" ++ show (f1, f2, e)
166
167 amap :: TypedExpr -> TypedExpr -> TypedExpr
168 amap tf tl | Arrow t1 t2 <- typeOf tf
169            , List t      <- typeOf tl
170            , t1 == t
171            = let tMap = TypedExpr Map (Arrow (Arrow t1 t2) (Arrow (List t1) (List t2)))
172              in app (app tMap tf) tl
173            | otherwise = error "Type error in map"
174
175 aand :: BoolExpr -> BoolExpr -> BoolExpr
176 aand (And xs) (And ys) = And (xs  ++ ys)
177 aand (And xs) y        = And (xs  ++ [y])
178 aand x        (And ys) = And ([x] ++ ys)
179 aand x        y        = And ([x,y])
180
181 beTrue :: BoolExpr
182 beTrue = And []
183
184 -- | Optimize a forall condition
185 condition :: [TypedExpr] -> BoolExpr -> BoolExpr -> BoolExpr
186 -- empty condition
187 condition [] cond concl   | cond == beTrue
188                           = concl
189 -- float out conditions on the right
190 condition vars cond (Condition vars' cond' concl')
191                           = condition (vars ++ vars') (cond `aand` cond') concl'
192
193 -- Try to find variables that are functions of other variables, and remove them
194 condition vars cond concl | True -- set to false to disable
195                           , ((vars',cond',concl'):_) <- mapMaybe try vars
196                           = condition vars' cond' concl'
197               -- A variable which can be replaced
198   where try v | Just subst <- findReplacer v cond
199               = -- trace ("Tested " ++ show v ++ ", can be replaced") $
200                 Just (delete v vars, subst cond, subst concl)
201  
202               -- A variable with can be removed
203               | not (unTypeExpr v `occursIn` cond || unTypeExpr v `occursIn` concl)
204               = -- trace ("Tested " ++ show v ++ ", can be reased") $
205                 Just (delete v vars, cond, concl)
206
207               -- Nothing to do with this variable
208               | otherwise
209               = -- trace ("Tested " ++ show v ++ " without success") $
210                 Nothing
211
212 -- Nothing left to optizmize
213 condition vars cond concl = Condition vars cond concl
214
215
216 caseUnit Bottom e = Bottom
217 caseUnit EUnit e  = e
218 caseUnit v e      = CaseUnit v e
219
220 -- | Replaces a Term in a BoolExpr
221 replaceTermBE :: Expr -> Expr -> BoolExpr -> BoolExpr
222 replaceTermBE d r = go
223   where go (e1 `Equal` e2) | d == e1 && r == e2 = beTrue
224                            | d == e2 && r == e1 = beTrue
225                            | otherwise          = go' e1 `equal'` go' e2
226         go (And es)        = foldr aand beTrue (map go es)
227         go (AllZipWith lbe e1 e2) 
228                            = AllZipWith (goL lbe) (go' e1) (go' e2)
229         go (AndEither lbe1 lbe2 e1 e2)
230                            = andEither' (goL lbe1) (goL lbe2) (go' e1) (go' e2)
231         go c@(Condition vs cond concl) -- shadowed definition
232                            | d `elem` map unTypeExpr vs 
233                            = c
234         go (Condition vs cond concl)
235                            = condition vs (go cond) (go concl)
236         go (UnpackPair v1 v2 e be)
237                            = unpackPair v1 v2 (goT e) (go be)
238         go (TypeVarInst _ _ _) = error "TypeVarInst not expected here"
239
240         go' = replaceExpr d r
241
242         goT te = te { unTypeExpr = go' (unTypeExpr te) }
243
244         goL (CurriedEquals t)   = (CurriedEquals t)
245         goL (LambdaBE v1 v2 be) = lambdaBE v1 v2 (go be)
246
247
248 replaceExpr :: Expr -> Expr -> Expr -> Expr
249 replaceExpr d r = go
250   where go e | e == d    = r
251         go (App e1 e2)   = app' (go e1) (go e2)
252         go (Conc es)     = foldr conc (Conc []) (map go es)
253         go (Lambda v e)  = lambda' v (go e)
254         go (Pair e1 e2)  = Pair (go e1) (go e2)
255         go (Const e)     = Const (go e)
256         go (CaseUnit v e)= caseUnit (go v) (go e)
257         go e             = e
258
259
260 -- | Is inside the term a definition for the variable?
261 findReplacer :: TypedExpr -> BoolExpr -> Maybe (BoolExpr -> BoolExpr)
262 findReplacer tv be = findReplacer' (unTypeExpr tv) be
263         
264 -- | Find a definition, and return a substitution
265 findReplacer' :: Expr -> BoolExpr -> Maybe (BoolExpr -> BoolExpr)
266 -- For combined types, look up the components
267 findReplacer' (Pair x y) e | Just (delX) <- findReplacer' x e
268                            , Just (delY) <- findReplacer' y e
269                     = Just (delX . delY)
270 -- Find the definition
271 findReplacer' e (e1 `Equal` e2) | e == e1    = Just (replaceTermBE e e2)
272                                 | e == e2    = Just (replaceTermBE e e1)
273 findReplacer' e (And es)        = listToMaybe (mapMaybe (findReplacer' e) es)
274                                   -- assuming no two definitions can exist
275 findReplacer' _ _               = Nothing
276
277 -- | Is inside the term a definition for the variable?
278 defFor :: TypedExpr -> BoolExpr -> Maybe (TypedExpr)
279 defFor tv be | Just (e') <- defFor' (unTypeExpr tv) be
280                          = Just (TypedExpr e' (typeOf tv))
281              | otherwise = Nothing
282         
283 -- | Find a definition, and return it along the definition remover
284 defFor' :: Expr -> BoolExpr -> Maybe (Expr)
285 defFor' e (e1 `Equal` e2) | e == e1                 = Just (e2)
286                           | e == e2                 = Just (e1)
287 defFor' e (And es)        | [d]  <- mapMaybe (defFor' e) es -- exactly one definition
288                                                     = Just d
289 defFor' _ _                                         = Nothing
290
291 app :: TypedExpr -> TypedExpr -> TypedExpr
292 app te1 te2 | Arrow t1 t2 <- typeOf te1
293             , t3          <- typeOf te2 
294             , t1 == t3 
295             = TypedExpr (app' (unTypeExpr te1) (unTypeExpr te2)) t2
296 app te1 te2 | otherwise                          = error $ "Type mismatch in app: " ++
297                                                            show te1 ++ " " ++ show te2
298
299 app' :: Expr -> Expr -> Expr
300 app' (App HeadMap f) Bottom                 = Bottom
301 app' (App HeadMap f) (Singleton e)          = app' f e
302 app' (App Uncurry _) Bottom                 = Bottom
303 app' (App Uncurry f) (Pair v1 v2)           = f `app'` v1 `app'` v2
304 app' (App (App EitherMap f1) f2) Bottom     = Bottom
305 app' (App (App EitherMap f1) f2) (ELeft v)  = ELeft (app' f1 v)
306 app' (App (App EitherMap f1) f2) (ERight v) = ERight (app' f2 v)
307 app' Bottom    _   = Bottom    -- _|_ x = _|_
308 app' (Lambda v e1) (e2) = replaceExpr v e2 e1 -- lambda application
309 app' (App Map f) (Singleton v) = Singleton (app' f v)
310 app' Map (Conc []) = Conc []   -- map id = id
311 app' (Const e) _   = e         -- const x y = x
312 app' (Conc []) v   = v         -- id x   = x
313 app' (Conc xs) v   = foldr app' v xs
314 app' f v           = App f v
315
316 lambda :: TypedExpr -> TypedExpr -> TypedExpr
317 lambda tv e = TypedExpr (lambda' (unTypeExpr tv) (unTypeExpr e))
318                         (Arrow (typeOf tv) (typeOf e))
319
320 lambda' :: Expr -> Expr -> Expr
321 lambda' v e  | e == EUnit           = Const EUnit
322               | (Just e') <- isApplOn v e
323               , not (v `occursIn` e')
324                                      = e'
325               | v == e   = Conc []
326               | otherwise            = Lambda v e
327
328 conc :: Expr -> Expr -> Expr
329 conc (Lambda v (CaseUnit v' e)) (Conc ((Const EUnit):r))
330                                 | v == v' = conc (Const e) (Conc r)
331 conc (Lambda v (CaseUnit v' e)) (Const EUnit) | v == v' = Const e
332 conc (Conc xs) (Conc ys)        | [x] <- xs ++ ys       = x
333                                 | otherwise             = Conc (xs  ++ ys)
334 conc (Conc xs)  y               = Conc (xs  ++ [y])
335 conc x         (Conc ys)        = Conc ([x] ++ ys)
336 conc x          y               = Conc ([x,y])
337
338
339 -- Specialization of g'
340
341 specialize :: BoolExpr -> BoolExpr
342 specialize (TypeVarInst strict i be') = 
343                 replaceTermBE (Var (FromTypVar i)) (if strict then Conc [] else Const EUnit) .
344                 everywhere (mkT $ go) $
345                 be
346         where be = specialize be'
347               go (TypInst i' _) | i' == i = TUnit
348               go tv                       = tv                 
349 -- No need to go further once we are through the quantors
350 specialize be = be
351
352 -- Helpers
353
354 isApplOn :: Expr -> Expr -> Maybe Expr
355 isApplOn e e'         | e == e'                       = Nothing
356 isApplOn e (App f e') | e == e'                       = Just (Conc [f])
357 isApplOn e (App f e') | (Just inner) <- isApplOn e e' = Just (conc f inner)
358 isApplOn _ _                                          = Nothing
359
360 occursIn :: (Typeable a, Data a1, Eq a) => a -> a1 -> Bool
361 e `occursIn` e'       = not (null (listify (==e) e'))
362
363 isTuple :: Typ -> Bool
364 isTuple (TPair _ _) = True
365 isTuple _           = False
366
367
368 -- showing
369
370 -- Precedences:
371 -- 10 fun app
372 --  9 (.)
373 --  8 ==
374 --  7 ==>
375 --  6 forall
376
377 instance Show EVar where
378         show F               = "f"
379         show (FromTypVar i)  = "g" ++ show i
380         show (FromParam i b) = "x" ++ show i ++ if b then "'" else ""
381
382
383 instance Show Expr where
384         showsPrec _ (Var v)     = showsPrec 11 v
385         showsPrec d (App e1 e2) = showParen (d>10) $
386                 showsPrec 10 e1 . showChar ' ' . showsPrec 11 e2
387         showsPrec _ (Conc [])   = showString "id"
388         showsPrec d (Conc [e])  = showsPrec d e
389         showsPrec d (Conc es)   = showParen (d>9) $
390                 showIntercalate (showString " . ") (map (showsPrec 10) es)
391         showsPrec _ (Lambda v e)  = showParen True $ 
392                                     showString "\\" .
393                                     showsPrec 11 v .
394                                     showString " -> ".
395                                     showsPrec 0 e 
396         showsPrec _ (Pair e1 e2) = showParen True $ 
397                                    showsPrec 0 e1 .
398                                    showString "," .
399                                    showsPrec 0 e2
400         showsPrec _ Zero          = showString "0"
401         showsPrec _ EUnit         = showString "()"
402         showsPrec _ (Singleton e) = showString "[" . showsPrec 0 e . showString "]"
403         showsPrec _ Map           = showString "map"
404         showsPrec d (ELeft e)     = showParen (d>10) $ 
405                                         showString "Left ".
406                                         showsPrec 11 e
407         showsPrec d (ERight e)    = showParen (d>10) $ 
408                                         showString "Right ".
409                                         showsPrec 11 e
410         showsPrec d (Const e)     = showParen (d>10) $ showString "const ".showsPrec 11 e
411         showsPrec d (CaseUnit t1 t2) = showParen (d>5) $
412                                         showString "case " .
413                                         showsPrec 0 t1 .
414                                         showString " of () ->  " .
415                                         showsPrec 11 t2
416         showsPrec _ EitherMap     = showString "eitherMap"
417         showsPrec _ HeadMap       = showString "headMap"
418         showsPrec _ Uncurry       = showString "uncurry"
419         showsPrec _ Bottom        = showString "_|_"
420
421 showIntercalate :: ShowS -> [ShowS] -> ShowS
422 showIntercalate _ []  = id
423 showIntercalate _ [x] = x
424 showIntercalate i (x:xs) = x . i . showIntercalate i xs
425
426 instance Show TypedExpr where
427         showsPrec d (TypedExpr e t) = 
428                 showParen (d>10) $
429                         showsPrec 0 e .
430                         showString " :: " .
431                         showString (showTypePrec 0 t)
432
433 instance Show LambdaBE where
434         show (CurriedEquals _) = 
435                         "(==)"
436         show (LambdaBE v1 v2 be) = 
437                         "(" ++
438                         "\\" ++
439                         showsPrec 11 (unTypeExpr v1) "" ++
440                         " " ++
441                         showsPrec 11 (unTypeExpr v2) "" ++
442                         " -> " ++
443                         show be ++
444                         ")"
445
446 instance Show BoolExpr where
447         show (Equal e1 e2) = showsPrec 9 e1 $
448                              showString " == " $
449                              showsPrec 9 e2 ""
450         show (And [])      = "True"
451         show (And bes)     = intercalate " && " $ map show bes
452         show (AllZipWith lbe e1 e2) =
453                         "allZipWith " ++
454                         show lbe ++
455                         " " ++
456                         showsPrec 11 e1 "" ++
457                         " " ++
458                         showsPrec 11 e2 ""
459         show (AndEither lbe1 lbe2 e1 e2) =
460                         "andEither " ++
461                         show lbe1 ++
462                         " " ++
463                         show lbe2 ++
464                         " " ++
465                         showsPrec 11 e1 "" ++
466                         " " ++
467                         showsPrec 11 e2 ""
468         show (Condition tvars be1 be2) = 
469                         "forall " ++
470                         intercalate ", " (map show tvars) ++
471                         ".\n" ++
472                         (if be1 /= beTrue then indent 2 (show be1) ++ "==>\n" else "") ++
473                         indent 2 (show be2)
474         show (UnpackPair v1 v2 e be) = 
475                         "let (" ++
476                         showsPrec 0 v1 "" ++
477                         "," ++
478                         showsPrec 0 v2 "" ++
479                         ") = " ++
480                         showsPrec 0 e "" ++
481                         " in\n" ++
482                         indent 2 (show be)
483         show (TypeVarInst strict i be) = 
484                         "forall types t" ++
485                         show (2*i-1) ++
486                         ", t" ++
487                         show (2*i) ++
488                         ", " ++
489                         (if strict then "strict " else "(strict) ") ++
490                         "functions g" ++
491                         show i ++
492                         " :: t" ++
493                         show (2*i-1) ++
494                         " -> t" ++
495                         show (2*i) ++ 
496                         ".\n" ++
497                         indent 2 (show be)
498
499 indent :: Int -> String -> String
500 indent n = unlines . map (replicate n ' ' ++) . lines
501
502 showTypePrec :: Int -> Typ -> String
503 showTypePrec _ Int                          = "Int" 
504 showTypePrec _ (TVar (TypVar i))            = "a"++show i
505 showTypePrec _ (TVar (TypInst i b)) | not b = "t" ++  show (2*i-1)
506                                     |     b = "t" ++  show (2*i)
507 showTypePrec _ (TVar (TUnit))               = "()"
508 showTypePrec d (Arrow t1 t2)                = paren (d>9) $ 
509                                                 showTypePrec 10 t1 ++
510                                                 " -> " ++
511                                                 showTypePrec 9 t2 
512 showTypePrec _ (List t)                     = "[" ++ showTypePrec 0 t ++ "]"
513 showTypePrec _ (TEither t1 t2)              = "Either " ++ showTypePrec 11 t1 ++ 
514                                                     " " ++ showTypePrec 11 t2
515 showTypePrec _ (TPair t1 t2)                = "(" ++ showTypePrec 0 t1 ++
516                                               "," ++ showTypePrec 0 t2 ++ ")"
517 showTypePrec _ t                            = error $ "Did not expect to show " ++ show t
518
519 paren :: Bool -> String -> String
520 paren b p   =  if b then "(" ++ p ++ ")" else p