Add nice getComplete' function
[darcs-mirror-polyfix.git] / Expr.hs
1 {-# LANGUAGE PatternGuards, DeriveDataTypeable   #-}
2 module Expr where
3
4 import Data.List
5 import Data.Maybe
6 import ParseType
7
8 -- import Debug.Trace
9
10 import Data.Generics hiding (typeOf)
11 import Data.Generics.Schemes
12
13 data TypedExpr = TypedExpr
14         { unTypeExpr    :: Expr
15         , typeOf        :: Typ
16         } deriving (Eq, Typeable, Data)
17
18 typedLeft, typedRight :: Expr -> Typ -> TypedExpr
19 typedLeft  e t = TypedExpr e (instType False t)
20 typedRight e t = TypedExpr e (instType True t)
21
22 data Expr
23         = Var String
24         | App Expr Expr
25         | Conc [Expr] -- Conc [] is Id
26         | Lambda TypedExpr Expr
27         | Pair Expr Expr
28         | Map
29         | EitherMap
30             deriving (Eq, Typeable, Data)
31
32 data LambdaBE = CurriedEquals Typ
33               | LambdaBE TypedExpr TypedExpr BoolExpr
34             deriving (Eq, Typeable, Data)
35
36 data BoolExpr 
37         = Equal Expr Expr
38         | And [BoolExpr] -- And [] is True
39         | AllZipWith LambdaBE Expr Expr
40         | AndEither  LambdaBE LambdaBE Expr Expr
41         | Condition [TypedExpr] BoolExpr BoolExpr
42         | UnpackPair TypedExpr TypedExpr TypedExpr BoolExpr
43         | TypeVarInst Int BoolExpr
44             deriving (Eq, Typeable, Data)
45
46 -- Smart constructors
47
48 -- | Try eta reduction
49 equal :: TypedExpr -> TypedExpr -> BoolExpr
50 equal te1 te2 | typeOf te1 /= typeOf te2 = error "Type mismatch in equal"
51               | otherwise                = equal' (unTypeExpr te1) (unTypeExpr te2)
52
53 equal' :: Expr -> Expr -> BoolExpr
54 equal' e1 e2  | (Just (lf,lv)) <- isFunctionApplication e1
55               , (Just (rf,rv)) <- isFunctionApplication e2
56               , lv == rv 
57                                          = equal' lf rf
58               -- This makes it return True...
59               | e1 == e2                 = beTrue
60               | otherwise                = Equal e1 e2
61
62 -- | If e is of the type (app f1 (app f2 .. (app fn x)..)),
63 --   return Just (f1 . f2. ... . fn, x)
64 isFunctionApplication :: Expr -> Maybe (Expr, Expr)
65 isFunctionApplication (App f e') | (Just (inner,v)) <- isFunctionApplication e'
66                                  = Just (conc f inner, v)
67                                  | otherwise
68                                  = Just (f, e')
69 isFunctionApplication _          = Nothing
70
71
72 -- | If both bound variables are just functions, we can replace this
73 --   by a comparison
74 unpackPair :: TypedExpr -> TypedExpr -> TypedExpr -> BoolExpr -> BoolExpr
75 unpackPair v1 v2 te be | Just subst1 <- findReplacer v1 be
76                        , Just subst2 <- findReplacer v2 be
77                        = subst1. subst2 $ (pair v1 v2 `equal` te) `aand` be
78
79 -- | If the whole tuple is a function, we can replace this
80 --   by a comparison
81 unpackPair v1 v2 te be | Just subst <- findReplacer (pair v1 v2) be
82                        = subst $ (pair v1 v2 `equal` te) `aand` be
83
84 -- | Nothing to optimize
85 unpackPair v1 v2 te be = UnpackPair v1 v2 te be
86
87 pair :: TypedExpr -> TypedExpr -> TypedExpr
88 pair (TypedExpr e1 t1) (TypedExpr e2 t2) = TypedExpr (Pair e1 e2) (TPair t1 t2)
89
90 lambdaBE :: TypedExpr -> TypedExpr -> BoolExpr -> LambdaBE
91 lambdaBE v1 v2 rel | typeOf v1 == typeOf v2 
92                    , rel == v1 `equal` v2    = CurriedEquals (typeOf v1)
93                    | otherwise               = LambdaBE v1 v2 rel
94
95 andEither :: LambdaBE -> LambdaBE -> TypedExpr -> TypedExpr -> BoolExpr
96 andEither (CurriedEquals _) (CurriedEquals _) e1 e2 = e1 `equal` e2
97 andEither lbe1 lbe2 e1 e2 | Just f1 <- arg1IsFunc lbe1
98                           , Just f2 <- arg1IsFunc lbe2
99                           = e1 `equal` eitherE f1 f2 e2
100                           | Just f1 <- arg2IsFunc lbe1
101                           , Just f2 <- arg2IsFunc lbe2
102                           = eitherE f1 f2 e1 `equal` e2
103                           | otherwise
104                           = AndEither lbe1 lbe2 (unTypeExpr e1) (unTypeExpr e2)
105
106 arg1IsFunc (CurriedEquals t)    = Just $ TypedExpr (Conc []) (Arrow t t)
107 arg1IsFunc (LambdaBE v1 v2 rel) | Just v1' <- defFor v1 rel
108                                 = Just (lambda v2 v1')
109                                 | otherwise = Nothing
110
111 arg2IsFunc (CurriedEquals t)    = Just $ TypedExpr (Conc []) (Arrow t t)
112 arg2IsFunc (LambdaBE v1 v2 rel) | Just v2' <- defFor v2 rel
113                                 = Just (lambda v1 v2')
114                                 | otherwise = Nothing
115
116 allZipWith :: TypedExpr -> TypedExpr -> BoolExpr -> TypedExpr -> TypedExpr -> BoolExpr
117 allZipWith v1 v2 rel e1 e2 | Just v1' <- defFor v1 rel =
118                                 e1 `equal` amap (lambda v2 v1') e2
119                            | Just v2' <- defFor v2 rel =
120                                 amap (lambda v1 v2') e1 `equal` e2
121                            | otherwise =
122                                 AllZipWith (LambdaBE v1 v2 rel) (unTypeExpr e1) (unTypeExpr e2)
123
124 eitherE :: TypedExpr -> TypedExpr -> TypedExpr -> TypedExpr
125 eitherE f1 f2 e | Arrow lt1 lt2 <- typeOf f1
126                 , Arrow rt1 rt2 <- typeOf f2
127                 , TEither lt rt <- typeOf e
128                 , lt1 == lt
129                 , rt1 == rt
130         = let tEither = TypedExpr EitherMap (Arrow (typeOf f1) (Arrow (typeOf f2) (Arrow (typeOf e) (TEither lt2 rt2))))
131           in  app (app (app tEither f1) f2) e
132                 | otherwise = error $ "Type error in eitherE\n" ++ show (f1, f2, e)
133
134 amap :: TypedExpr -> TypedExpr -> TypedExpr
135 amap tf tl | Arrow t1 t2 <- typeOf tf
136            , List t      <- typeOf tl
137            , t1 == t
138            = let tMap = TypedExpr Map (Arrow (Arrow t1 t2) (Arrow (List t1) (List t2)))
139              in app (app tMap tf) tl
140            | otherwise = error "Type error in map"
141
142 aand :: BoolExpr -> BoolExpr -> BoolExpr
143 aand (And xs) (And ys) = And (xs  ++ ys)
144 aand (And xs) y        = And (xs  ++ [y])
145 aand x        (And ys) = And ([x] ++ ys)
146 aand x        y        = And ([x,y])
147
148 beTrue :: BoolExpr
149 beTrue = And []
150
151 -- | Optimize a forall condition
152 condition :: [TypedExpr] -> BoolExpr -> BoolExpr -> BoolExpr
153 -- empty condition
154 condition [] cond concl   | cond == beTrue
155                           = concl
156 -- float out conditions on the right
157 condition vars cond (Condition vars' cond' concl')
158                           = condition (vars ++ vars') (cond `aand` cond') concl'
159
160 -- Try to find variables that are functions of other variables, and remove them
161 condition vars cond concl | True -- set to false to disable
162                           , ((vars',cond',concl'):_) <- mapMaybe try vars
163                           = condition vars' cond' concl'
164               -- A variable which can be replaced
165   where try v | Just subst <- findReplacer v cond
166               = -- trace ("Tested " ++ show v ++ ", can be replaced") $
167                 Just (delete v vars, subst cond, subst concl)
168  
169               -- A variable with can be removed
170               | not (unTypeExpr v `occursIn` cond || unTypeExpr v `occursIn` concl)
171               = -- trace ("Tested " ++ show v ++ ", can be reased") $
172                 Just (delete v vars, cond, concl)
173
174               -- Nothing to do with this variable
175               | otherwise
176               = -- trace ("Tested " ++ show v ++ " without success") $
177                 Nothing
178
179 -- Nothing left to optizmize
180 condition vars cond concl = Condition vars cond concl
181
182 -- | Replaces a Term in a BoolExpr
183 replaceTermBE :: Expr -> Expr -> BoolExpr -> BoolExpr
184 replaceTermBE d r = go
185   where go (e1 `Equal` e2) | d == e1 && r == e2 = beTrue
186                            | d == e2 && r == e1 = beTrue
187                            | otherwise          = go' e1 `equal'` go' e2
188         go (And es)        = foldr aand beTrue (map go es)
189         go (AllZipWith lbe e1 e2) 
190                            = AllZipWith (goL lbe) (go' e1) (go' e2)
191         go (AndEither lbe1 lbe2 e1 e2)
192                            = AndEither (goL lbe1) (goL lbe2) (go' e1) (go' e2)
193         go (Condition vs cond concl)
194                            = condition vs (go cond) (go concl)
195         go (UnpackPair v1 v2 e be)
196                            = unpackPair v1 v2 (go' e) (go be)
197         go (TypeVarInst _ _) = error "TypeVarInst not expected here"
198
199         go' :: Data a => a -> a
200         go' = replaceExpr d r
201
202         goL (CurriedEquals t)   = (CurriedEquals t)
203         goL (LambdaBE v1 v2 be) = lambdaBE v1 v2 (go be)
204
205 replaceExpr :: Data a => Expr -> Expr -> a -> a
206 replaceExpr d r = everywhere (mkT go)
207   where go e | e == d    = r 
208              | otherwise = e
209
210 -- | Is inside the term a definition for the variable?
211 findReplacer :: TypedExpr -> BoolExpr -> Maybe (BoolExpr -> BoolExpr)
212 findReplacer tv be = findReplacer' (unTypeExpr tv) be
213         
214 -- | Find a definition, and return a substitution
215 findReplacer' :: Expr -> BoolExpr -> Maybe (BoolExpr -> BoolExpr)
216 -- For combined types, look up the components
217 findReplacer' (Pair x y) e | Just (delX) <- findReplacer' x e
218                            , Just (delY) <- findReplacer' y e
219                     = Just (delX . delY)
220 -- Find the definition
221 findReplacer' e (e1 `Equal` e2) | e == e1    = Just (replaceTermBE e e2)
222                                 | e == e2    = Just (replaceTermBE e e1)
223 findReplacer' e (And es)        = listToMaybe (mapMaybe (findReplacer' e) es)
224                                   -- assuming no two definitions can exist
225 findReplacer' _ _               = Nothing
226
227 -- | Is inside the term a definition for the variable?
228 defFor :: TypedExpr -> BoolExpr -> Maybe (TypedExpr)
229 defFor tv be | Just (e') <- defFor' (unTypeExpr tv) be
230                          = Just (TypedExpr e' (typeOf tv))
231              | otherwise = Nothing
232         
233 -- | Find a definition, and return it along the definition remover
234 defFor' :: Expr -> BoolExpr -> Maybe (Expr)
235 defFor' e (e1 `Equal` e2) | e == e1                 = Just (e2)
236                           | e == e2                 = Just (e1)
237 defFor' e (And es)        | [d]  <- mapMaybe (defFor' e) es -- exactly one definition
238                                                     = Just d
239 defFor' _ _                                         = Nothing
240
241 app :: TypedExpr -> TypedExpr -> TypedExpr
242 app te1 te2 | Arrow t1 t2 <- typeOf te1
243             , t3          <- typeOf te2 
244             , t1 == t3 
245             = TypedExpr (app' (unTypeExpr te1) (unTypeExpr te2)) t2
246  where app' Map (Conc []) = Conc []   -- map id = id
247        app' (Conc []) v   = v         -- id x   = x
248        app' f v           = App f v
249 app te1 te2 | otherwise                          = error $ "Type mismatch in app: " ++
250                                                            show te1 ++ " " ++ show te2
251
252 lambda :: TypedExpr -> TypedExpr -> TypedExpr
253 lambda tv e = TypedExpr inner (Arrow (typeOf tv) (typeOf e))
254   where inner | (Just e') <- isApplOn (unTypeExpr tv) (unTypeExpr e)
255               , not (unTypeExpr tv `occursIn` e')
256                           = e'
257               | tv == e   = Conc []
258               | otherwise = Lambda tv (unTypeExpr e)
259
260 conc :: Expr -> Expr -> Expr
261 conc (Conc xs) (Conc ys) = Conc (xs  ++ ys)
262 conc (Conc xs)  y        = Conc (xs  ++ [y])
263 conc x         (Conc ys) = Conc ([x] ++ ys)
264 conc x          y        = Conc ([x,y])
265
266 -- Helpers
267
268 isApplOn :: Expr -> Expr -> Maybe Expr
269 isApplOn e e'         | e == e'                       = Nothing
270 isApplOn e (App f e') | e == e'                       = Just (Conc [f])
271 isApplOn e (App f e') | (Just inner) <- isApplOn e e' = Just (conc f inner)
272 isApplOn _ _                                          = Nothing
273
274 hasVar :: String -> Expr -> Bool
275 hasVar v (Var v')     = v == v'
276 hasVar v (App e1 e2)  = hasVar v e1 || hasVar v e2
277 hasVar v (Conc es)    = any (hasVar v) es
278 hasVar v (Lambda _ e) = hasVar v e
279 hasVar v (Pair e1 e2) = hasVar v e1 || hasVar v e2
280 hasVar _ Map          = False
281
282 occursIn :: (Typeable a, Data a1, Eq a) => a -> a1 -> Bool
283 e `occursIn` e'       = not (null (listify (==e) e'))
284
285 isTuple :: Typ -> Bool
286 isTuple (TPair _ _) = True
287 isTuple _           = False
288
289
290 -- showing
291
292 -- Precedences:
293 -- 10 fun app
294 --  9 (.)
295 --  8 ==
296 --  7 ==>
297 --  6 forall
298
299 instance Show Expr where
300         showsPrec _ (Var s)     = showString s
301         showsPrec d (App e1 e2) = showParen (d>10) $
302                 showsPrec 10 e1 . showChar ' ' . showsPrec 11 e2
303         showsPrec _ (Conc [])   = showString "id"
304         showsPrec d (Conc [e])  = showsPrec d e
305         showsPrec d (Conc es)   = showParen (d>9) $
306                 showIntercalate (showString " . ") (map (showsPrec 10) es)
307         showsPrec _ (Lambda tv e) = showParen True $ 
308                                     showString "\\" .
309                                     showsPrec 0 tv .
310                                     showString " -> ".
311                                     showsPrec 0 e 
312         showsPrec _ (Pair e1 e2) = showParen True $ 
313                                    showsPrec 0 e1 .
314                                    showString "," .
315                                    showsPrec 0 e2
316         showsPrec _ Map           = showString "map"
317         showsPrec _ EitherMap     = showString "eitherMap"
318
319 showIntercalate :: ShowS -> [ShowS] -> ShowS
320 showIntercalate _ []  = id
321 showIntercalate _ [x] = x
322 showIntercalate i (x:xs) = x . i . showIntercalate i xs
323
324 instance Show TypedExpr where
325         showsPrec d (TypedExpr e t) = 
326                 showParen (d>10) $
327                         showsPrec 0 e .
328                         showString " :: " .
329                         showString (showTypePrec 0 t)
330
331 instance Show LambdaBE where
332         show (CurriedEquals _) = 
333                         "(==)"
334         show (LambdaBE v1 v2 be) = 
335                         "(" ++
336                         "\\" ++
337                         showsPrec 11 (unTypeExpr v1) "" ++
338                         " " ++
339                         showsPrec 11 (unTypeExpr v2) "" ++
340                         " -> " ++
341                         show be ++
342                         ")"
343
344 instance Show BoolExpr where
345         show (Equal e1 e2) = showsPrec 9 e1 $
346                              showString " == " $
347                              showsPrec 9 e2 ""
348         show (And [])      = "True"
349         show (And bes)     = intercalate " && " $ map show bes
350         show (AllZipWith lbe e1 e2) =
351                         "allZipWith " ++
352                         show lbe ++
353                         " " ++
354                         showsPrec 11 e1 "" ++
355                         " " ++
356                         showsPrec 11 e2 ""
357         show (AndEither lbe1 lbe2 e1 e2) =
358                         "andEither " ++
359                         show lbe1 ++
360                         " " ++
361                         show lbe2 ++
362                         " " ++
363                         showsPrec 11 e1 "" ++
364                         " " ++
365                         showsPrec 11 e2 ""
366         show (Condition tvars be1 be2) = 
367                         "forall " ++
368                         intercalate ", " (map show tvars) ++
369                         ".\n" ++
370                         (if be1 /= beTrue then indent 2 (show be1) ++ "==>\n" else "") ++
371                         indent 2 (show be2)
372         show (UnpackPair v1 v2 e be) = 
373                         "let (" ++
374                         showsPrec 0 v1 "" ++
375                         "," ++
376                         showsPrec 0 v2 "" ++
377                         ") = " ++
378                         showsPrec 0 e "" ++
379                         " in\n" ++
380                         indent 2 (show be)
381         show (TypeVarInst i be) = 
382                         "forall types t" ++
383                         show (2*i-1) ++
384                         ", t" ++
385                         show (2*i) ++
386                         ", function g" ++
387                         show i ++
388                         " :: t" ++
389                         show (2*i-1) ++
390                         " -> t" ++
391                         show (2*i) ++ 
392                         ".\n" ++
393                         indent 2 (show be)
394
395 indent :: Int -> String -> String
396 indent n = unlines . map (replicate n ' ' ++) . lines
397
398 showTypePrec :: Int -> Typ -> String
399 showTypePrec _ Int                          = "Int" 
400 showTypePrec _ (TVar (TypVar i))            = "a"++show i
401 showTypePrec _ (TVar (TypInst i b)) | not b = "t" ++  show (2*i-1)
402                                     |     b = "t" ++  show (2*i)
403 showTypePrec d (Arrow t1 t2)                = paren (d>9) $ 
404                                                 showTypePrec 10 t1 ++
405                                                 " -> " ++
406                                                 showTypePrec 9 t2 
407 showTypePrec _ (List t)                     = "[" ++ showTypePrec 0 t ++ "]"
408 showTypePrec _ (TEither t1 t2)              = "Either " ++ showTypePrec 11 t1 ++ 
409                                                     " " ++ showTypePrec 11 t2
410 showTypePrec _ (TPair t1 t2)                = "(" ++ showTypePrec 0 t1 ++
411                                               "," ++ showTypePrec 0 t2 ++ ")"
412 showTypePrec _ t                            = error $ "Did not expect to show " ++ show t
413
414 paren :: Bool -> String -> String
415 paren b p   =  if b then "(" ++ p ++ ")" else p