Remove TypSamples
[darcs-mirror-polyfix.git] / Expr.hs
1 {-# LANGUAGE PatternGuards, DeriveDataTypeable   #-}
2 module Expr where
3
4 import Data.List
5 import Data.Maybe
6 import ParseType
7
8 -- import Debug.Trace
9
10 import Data.Generics hiding (typeOf)
11 import Data.Generics.Schemes
12
13 data TypedExpr = TypedExpr
14         { unTypeExpr    :: Expr
15         , typeOf        :: Typ
16         } deriving (Eq, Typeable, Data)
17
18 typedLeft, typedRight :: Expr -> Typ -> TypedExpr
19 typedLeft  e t = TypedExpr e (instType False t)
20 typedRight e t = TypedExpr e (instType True t)
21
22 data EVar = F
23           | FromTypVar Int
24           | FromParam Int Bool
25             deriving (Eq, Typeable, Data)
26
27 data Expr
28         = Var EVar
29         | App Expr Expr
30         | Conc [Expr] -- Conc [] is Id
31         | Lambda TypedExpr Expr
32         | Pair Expr Expr
33         | Map
34         | ConstUnit
35         | EitherMap
36         | EUnit
37             deriving (Eq, Typeable, Data)
38
39 data LambdaBE = CurriedEquals Typ
40               | LambdaBE TypedExpr TypedExpr BoolExpr
41             deriving (Eq, Typeable, Data)
42
43 data BoolExpr 
44         = Equal Expr Expr
45         | And [BoolExpr] -- And [] is True
46         | AllZipWith LambdaBE Expr Expr
47         | AndEither  LambdaBE LambdaBE Expr Expr
48         | Condition [TypedExpr] BoolExpr BoolExpr
49         | UnpackPair TypedExpr TypedExpr TypedExpr BoolExpr
50         | TypeVarInst Bool Int BoolExpr
51             deriving (Eq, Typeable, Data)
52
53 -- Smart constructors
54
55 -- | Try eta reduction
56 equal :: TypedExpr -> TypedExpr -> BoolExpr
57 equal te1 te2 | typeOf te1 /= typeOf te2 = error "Type mismatch in equal"
58               | otherwise                = equal' (unTypeExpr te1) (unTypeExpr te2)
59
60 equal' :: Expr -> Expr -> BoolExpr
61 equal' e1 e2  | (Just (lf,lv)) <- isFunctionApplication e1
62               , (Just (rf,rv)) <- isFunctionApplication e2
63               , lv == rv 
64                                          = equal' lf rf
65               -- This makes it return True...
66               | e1 == e2                 = beTrue
67               | otherwise                = Equal e1 e2
68
69 -- | If e is of the type (app f1 (app f2 .. (app fn x)..)),
70 --   return Just (f1 . f2. ... . fn, x)
71 isFunctionApplication :: Expr -> Maybe (Expr, Expr)
72 isFunctionApplication (App f e') | (Just (inner,v)) <- isFunctionApplication e'
73                                  = Just (conc f inner, v)
74                                  | otherwise
75                                  = Just (f, e')
76 isFunctionApplication _          = Nothing
77
78
79 -- | If both bound variables are just functions, we can replace this
80 --   by a comparison
81 unpackPair :: TypedExpr -> TypedExpr -> TypedExpr -> BoolExpr -> BoolExpr
82 unpackPair v1 v2 te be | Just subst1 <- findReplacer v1 be
83                        , Just subst2 <- findReplacer v2 be
84                        = subst1. subst2 $ (pair v1 v2 `equal` te) `aand` be
85
86 -- | If the whole tuple is a function, we can replace this
87 --   by a comparison
88 unpackPair v1 v2 te be | Just subst <- findReplacer (pair v1 v2) be
89                        = subst $ (pair v1 v2 `equal` te) `aand` be
90
91 -- | Nothing to optimize
92 unpackPair v1 v2 te be = UnpackPair v1 v2 te be
93
94 pair :: TypedExpr -> TypedExpr -> TypedExpr
95 pair (TypedExpr e1 t1) (TypedExpr e2 t2) = TypedExpr (Pair e1 e2) (TPair t1 t2)
96
97 lambdaBE :: TypedExpr -> TypedExpr -> BoolExpr -> LambdaBE
98 lambdaBE v1 v2 rel | typeOf v1 == typeOf v2 
99                    , rel == v1 `equal` v2    = CurriedEquals (typeOf v1)
100                    | otherwise               = LambdaBE v1 v2 rel
101
102 andEither :: LambdaBE -> LambdaBE -> TypedExpr -> TypedExpr -> BoolExpr
103 andEither (CurriedEquals _) (CurriedEquals _) e1 e2 = e1 `equal` e2
104 andEither lbe1 lbe2 e1 e2 | Just f1 <- arg1IsFunc lbe1
105                           , Just f2 <- arg1IsFunc lbe2
106                           = e1 `equal` eitherE f1 f2 e2
107                           | Just f1 <- arg2IsFunc lbe1
108                           , Just f2 <- arg2IsFunc lbe2
109                           = eitherE f1 f2 e1 `equal` e2
110                           | otherwise
111                           = AndEither lbe1 lbe2 (unTypeExpr e1) (unTypeExpr e2)
112
113 arg1IsFunc :: LambdaBE -> Maybe TypedExpr
114 arg1IsFunc (CurriedEquals t)    = Just $ TypedExpr (Conc []) (Arrow t t)
115 arg1IsFunc (LambdaBE v1 v2 rel) | Just v1' <- defFor v1 rel
116                                 = Just (lambda v2 v1')
117                                 | otherwise = Nothing
118
119 arg2IsFunc :: LambdaBE -> Maybe TypedExpr
120 arg2IsFunc (CurriedEquals t)    = Just $ TypedExpr (Conc []) (Arrow t t)
121 arg2IsFunc (LambdaBE v1 v2 rel) | Just v2' <- defFor v2 rel
122                                 = Just (lambda v1 v2')
123                                 | otherwise = Nothing
124
125 allZipWith :: TypedExpr -> TypedExpr -> BoolExpr -> TypedExpr -> TypedExpr -> BoolExpr
126 allZipWith v1 v2 rel e1 e2 | Just v1' <- defFor v1 rel =
127                                 e1 `equal` amap (lambda v2 v1') e2
128                            | Just v2' <- defFor v2 rel =
129                                 amap (lambda v1 v2') e1 `equal` e2
130                            | otherwise =
131                                 AllZipWith (LambdaBE v1 v2 rel) (unTypeExpr e1) (unTypeExpr e2)
132
133 eitherE :: TypedExpr -> TypedExpr -> TypedExpr -> TypedExpr
134 eitherE f1 f2 e | Arrow lt1 lt2 <- typeOf f1
135                 , Arrow rt1 rt2 <- typeOf f2
136                 , TEither lt rt <- typeOf e
137                 , lt1 == lt
138                 , rt1 == rt
139         = let tEither = TypedExpr EitherMap (Arrow (typeOf f1) (Arrow (typeOf f2) (Arrow (typeOf e) (TEither lt2 rt2))))
140           in  app (app (app tEither f1) f2) e
141                 | otherwise = error $ "Type error in eitherE\n" ++ show (f1, f2, e)
142
143 amap :: TypedExpr -> TypedExpr -> TypedExpr
144 amap tf tl | Arrow t1 t2 <- typeOf tf
145            , List t      <- typeOf tl
146            , t1 == t
147            = let tMap = TypedExpr Map (Arrow (Arrow t1 t2) (Arrow (List t1) (List t2)))
148              in app (app tMap tf) tl
149            | otherwise = error "Type error in map"
150
151 aand :: BoolExpr -> BoolExpr -> BoolExpr
152 aand (And xs) (And ys) = And (xs  ++ ys)
153 aand (And xs) y        = And (xs  ++ [y])
154 aand x        (And ys) = And ([x] ++ ys)
155 aand x        y        = And ([x,y])
156
157 beTrue :: BoolExpr
158 beTrue = And []
159
160 -- | Optimize a forall condition
161 condition :: [TypedExpr] -> BoolExpr -> BoolExpr -> BoolExpr
162 -- empty condition
163 condition [] cond concl   | cond == beTrue
164                           = concl
165 -- float out conditions on the right
166 condition vars cond (Condition vars' cond' concl')
167                           = condition (vars ++ vars') (cond `aand` cond') concl'
168
169 -- Try to find variables that are functions of other variables, and remove them
170 condition vars cond concl | True -- set to false to disable
171                           , ((vars',cond',concl'):_) <- mapMaybe try vars
172                           = condition vars' cond' concl'
173               -- A variable which can be replaced
174   where try v | Just subst <- findReplacer v cond
175               = -- trace ("Tested " ++ show v ++ ", can be replaced") $
176                 Just (delete v vars, subst cond, subst concl)
177  
178               -- A variable with can be removed
179               | not (unTypeExpr v `occursIn` cond || unTypeExpr v `occursIn` concl)
180               = -- trace ("Tested " ++ show v ++ ", can be reased") $
181                 Just (delete v vars, cond, concl)
182
183               -- Nothing to do with this variable
184               | otherwise
185               = -- trace ("Tested " ++ show v ++ " without success") $
186                 Nothing
187
188 -- Nothing left to optizmize
189 condition vars cond concl = Condition vars cond concl
190
191 -- | Replaces a Term in a BoolExpr
192 replaceTermBE :: Expr -> Expr -> BoolExpr -> BoolExpr
193 replaceTermBE d r = go
194   where go (e1 `Equal` e2) | d == e1 && r == e2 = beTrue
195                            | d == e2 && r == e1 = beTrue
196                            | otherwise          = go' e1 `equal'` go' e2
197         go (And es)        = foldr aand beTrue (map go es)
198         go (AllZipWith lbe e1 e2) 
199                            = AllZipWith (goL lbe) (go' e1) (go' e2)
200         go (AndEither lbe1 lbe2 e1 e2)
201                            = AndEither (goL lbe1) (goL lbe2) (go' e1) (go' e2)
202         go (Condition vs cond concl)
203                            = condition vs (go cond) (go concl)
204         go (UnpackPair v1 v2 e be)
205                            = unpackPair v1 v2 (goT e) (go be)
206         go (TypeVarInst _ _ _) = error "TypeVarInst not expected here"
207
208         go' = replaceExpr d r
209
210         goT te = te { unTypeExpr = go' (unTypeExpr te) }
211
212         goL (CurriedEquals t)   = (CurriedEquals t)
213         goL (LambdaBE v1 v2 be) = lambdaBE v1 v2 (go be)
214
215
216 replaceExpr :: Expr -> Expr -> Expr -> Expr
217 replaceExpr d r = go
218   where go e | e == d    = r
219         go (App e1 e2)   = app' (go e1) (go e2)
220         go (Conc es)     = foldr conc (Conc []) (map go es)
221         go (Lambda te e) = lambda' te (go e)
222         go (Pair e1 e2)  = Pair (go e1) (go e2)
223         go e             = e
224
225
226 -- | Is inside the term a definition for the variable?
227 findReplacer :: TypedExpr -> BoolExpr -> Maybe (BoolExpr -> BoolExpr)
228 findReplacer tv be = findReplacer' (unTypeExpr tv) be
229         
230 -- | Find a definition, and return a substitution
231 findReplacer' :: Expr -> BoolExpr -> Maybe (BoolExpr -> BoolExpr)
232 -- For combined types, look up the components
233 findReplacer' (Pair x y) e | Just (delX) <- findReplacer' x e
234                            , Just (delY) <- findReplacer' y e
235                     = Just (delX . delY)
236 -- Find the definition
237 findReplacer' e (e1 `Equal` e2) | e == e1    = Just (replaceTermBE e e2)
238                                 | e == e2    = Just (replaceTermBE e e1)
239 findReplacer' e (And es)        = listToMaybe (mapMaybe (findReplacer' e) es)
240                                   -- assuming no two definitions can exist
241 findReplacer' _ _               = Nothing
242
243 -- | Is inside the term a definition for the variable?
244 defFor :: TypedExpr -> BoolExpr -> Maybe (TypedExpr)
245 defFor tv be | Just (e') <- defFor' (unTypeExpr tv) be
246                          = Just (TypedExpr e' (typeOf tv))
247              | otherwise = Nothing
248         
249 -- | Find a definition, and return it along the definition remover
250 defFor' :: Expr -> BoolExpr -> Maybe (Expr)
251 defFor' e (e1 `Equal` e2) | e == e1                 = Just (e2)
252                           | e == e2                 = Just (e1)
253 defFor' e (And es)        | [d]  <- mapMaybe (defFor' e) es -- exactly one definition
254                                                     = Just d
255 defFor' _ _                                         = Nothing
256
257 app :: TypedExpr -> TypedExpr -> TypedExpr
258 app te1 te2 | Arrow t1 t2 <- typeOf te1
259             , t3          <- typeOf te2 
260             , t1 == t3 
261             = TypedExpr (app' (unTypeExpr te1) (unTypeExpr te2)) t2
262 app te1 te2 | otherwise                          = error $ "Type mismatch in app: " ++
263                                                            show te1 ++ " " ++ show te2
264
265 app' :: Expr -> Expr -> Expr
266 app' Map (Conc []) = Conc []   -- map id = id
267 app' ConstUnit _   = EUnit     -- id x   = x
268 app' (Conc []) v   = v         -- id x   = x
269 app' f v           = App f v
270
271 lambda :: TypedExpr -> TypedExpr -> TypedExpr
272 lambda tv e = TypedExpr (lambda' tv (unTypeExpr e)) (Arrow (typeOf tv) (typeOf e))
273
274 lambda' :: TypedExpr -> Expr -> Expr
275 lambda' tv e  | (Just e') <- isApplOn (unTypeExpr tv) e
276               , not (unTypeExpr tv `occursIn` e')
277                                      = e'
278               | unTypeExpr tv == e   = Conc []
279               | otherwise            = Lambda tv e
280
281 conc :: Expr -> Expr -> Expr
282 conc (Conc xs) (Conc ys) = Conc (xs  ++ ys)
283 conc (Conc xs)  y        = Conc (xs  ++ [y])
284 conc x         (Conc ys) = Conc ([x] ++ ys)
285 conc x          y        = Conc ([x,y])
286
287
288 -- Specialization of g'
289
290 specialize :: BoolExpr -> BoolExpr
291 specialize (TypeVarInst strict i be') = 
292                 replaceTermBE (Var (FromTypVar i)) (if strict then Conc [] else ConstUnit) .
293                 everywhere (mkT $ go) $
294                 be
295         where be = specialize be'
296               go (TypInst i' _) | i' == i = TUnit
297               go tv                       = tv                 
298 -- No need to go further once we are through the quantors
299 specialize be = be
300
301 -- Helpers
302
303 isApplOn :: Expr -> Expr -> Maybe Expr
304 isApplOn e e'         | e == e'                       = Nothing
305 isApplOn e (App f e') | e == e'                       = Just (Conc [f])
306 isApplOn e (App f e') | (Just inner) <- isApplOn e e' = Just (conc f inner)
307 isApplOn _ _                                          = Nothing
308
309 occursIn :: (Typeable a, Data a1, Eq a) => a -> a1 -> Bool
310 e `occursIn` e'       = not (null (listify (==e) e'))
311
312 isTuple :: Typ -> Bool
313 isTuple (TPair _ _) = True
314 isTuple _           = False
315
316
317 -- showing
318
319 -- Precedences:
320 -- 10 fun app
321 --  9 (.)
322 --  8 ==
323 --  7 ==>
324 --  6 forall
325
326 instance Show EVar where
327         show F               = "f"
328         show (FromTypVar i)  = "g" ++ show i
329         show (FromParam i b) = "x" ++ show i ++ if b then "'" else ""
330
331
332 instance Show Expr where
333         showsPrec _ (Var v)     = showsPrec 11 v
334         showsPrec d (App e1 e2) = showParen (d>10) $
335                 showsPrec 10 e1 . showChar ' ' . showsPrec 11 e2
336         showsPrec _ (Conc [])   = showString "id"
337         showsPrec d (Conc [e])  = showsPrec d e
338         showsPrec d (Conc es)   = showParen (d>9) $
339                 showIntercalate (showString " . ") (map (showsPrec 10) es)
340         showsPrec _ (Lambda tv e) = showParen True $ 
341                                     showString "\\" .
342                                     showsPrec 0 tv .
343                                     showString " -> ".
344                                     showsPrec 0 e 
345         showsPrec _ (Pair e1 e2) = showParen True $ 
346                                    showsPrec 0 e1 .
347                                    showString "," .
348                                    showsPrec 0 e2
349         showsPrec _ EUnit         = showString "()"
350         showsPrec _ Map           = showString "map"
351         showsPrec d ConstUnit     = showParen (d>10) $ showString "const ()"
352         showsPrec _ EitherMap     = showString "eitherMap"
353
354 showIntercalate :: ShowS -> [ShowS] -> ShowS
355 showIntercalate _ []  = id
356 showIntercalate _ [x] = x
357 showIntercalate i (x:xs) = x . i . showIntercalate i xs
358
359 instance Show TypedExpr where
360         showsPrec d (TypedExpr e t) = 
361                 showParen (d>10) $
362                         showsPrec 0 e .
363                         showString " :: " .
364                         showString (showTypePrec 0 t)
365
366 instance Show LambdaBE where
367         show (CurriedEquals _) = 
368                         "(==)"
369         show (LambdaBE v1 v2 be) = 
370                         "(" ++
371                         "\\" ++
372                         showsPrec 11 (unTypeExpr v1) "" ++
373                         " " ++
374                         showsPrec 11 (unTypeExpr v2) "" ++
375                         " -> " ++
376                         show be ++
377                         ")"
378
379 instance Show BoolExpr where
380         show (Equal e1 e2) = showsPrec 9 e1 $
381                              showString " == " $
382                              showsPrec 9 e2 ""
383         show (And [])      = "True"
384         show (And bes)     = intercalate " && " $ map show bes
385         show (AllZipWith lbe e1 e2) =
386                         "allZipWith " ++
387                         show lbe ++
388                         " " ++
389                         showsPrec 11 e1 "" ++
390                         " " ++
391                         showsPrec 11 e2 ""
392         show (AndEither lbe1 lbe2 e1 e2) =
393                         "andEither " ++
394                         show lbe1 ++
395                         " " ++
396                         show lbe2 ++
397                         " " ++
398                         showsPrec 11 e1 "" ++
399                         " " ++
400                         showsPrec 11 e2 ""
401         show (Condition tvars be1 be2) = 
402                         "forall " ++
403                         intercalate ", " (map show tvars) ++
404                         ".\n" ++
405                         (if be1 /= beTrue then indent 2 (show be1) ++ "==>\n" else "") ++
406                         indent 2 (show be2)
407         show (UnpackPair v1 v2 e be) = 
408                         "let (" ++
409                         showsPrec 0 v1 "" ++
410                         "," ++
411                         showsPrec 0 v2 "" ++
412                         ") = " ++
413                         showsPrec 0 e "" ++
414                         " in\n" ++
415                         indent 2 (show be)
416         show (TypeVarInst strict i be) = 
417                         "forall types t" ++
418                         show (2*i-1) ++
419                         ", t" ++
420                         show (2*i) ++
421                         ", " ++
422                         (if strict then "strict " else "(strict) ") ++
423                         "functions g" ++
424                         show i ++
425                         " :: t" ++
426                         show (2*i-1) ++
427                         " -> t" ++
428                         show (2*i) ++ 
429                         ".\n" ++
430                         indent 2 (show be)
431
432 indent :: Int -> String -> String
433 indent n = unlines . map (replicate n ' ' ++) . lines
434
435 showTypePrec :: Int -> Typ -> String
436 showTypePrec _ Int                          = "Int" 
437 showTypePrec _ (TVar (TypVar i))            = "a"++show i
438 showTypePrec _ (TVar (TypInst i b)) | not b = "t" ++  show (2*i-1)
439                                     |     b = "t" ++  show (2*i)
440 showTypePrec _ (TVar (TUnit))               = "()"
441 showTypePrec d (Arrow t1 t2)                = paren (d>9) $ 
442                                                 showTypePrec 10 t1 ++
443                                                 " -> " ++
444                                                 showTypePrec 9 t2 
445 showTypePrec _ (List t)                     = "[" ++ showTypePrec 0 t ++ "]"
446 showTypePrec _ (TEither t1 t2)              = "Either " ++ showTypePrec 11 t1 ++ 
447                                                     " " ++ showTypePrec 11 t2
448 showTypePrec _ (TPair t1 t2)                = "(" ++ showTypePrec 0 t1 ++
449                                               "," ++ showTypePrec 0 t2 ++ ")"
450 showTypePrec _ t                            = error $ "Did not expect to show " ++ show t
451
452 paren :: Bool -> String -> String
453 paren b p   =  if b then "(" ++ p ++ ")" else p