a49bc0969bd1cf6536ed5e3d9aaf542b93419695
[darcs-mirror-polyfix.git] / Expr.hs
1 {-# LANGUAGE PatternGuards, DeriveDataTypeable   #-}
2 module Expr where
3
4 import Data.List
5 import Data.Maybe
6 import ParseType
7
8 -- import Debug.Trace
9
10 import Data.Generics hiding (typeOf)
11 import Data.Generics.Schemes
12
13 data TypedExpr = TypedExpr
14         { unTypeExpr    :: Expr
15         , typeOf        :: Typ
16         } deriving (Eq, Typeable, Data)
17
18 typedLeft, typedRight :: Expr -> Typ -> TypedExpr
19 typedLeft  e t = TypedExpr e (instType False t)
20 typedRight e t = TypedExpr e (instType True t)
21
22 data Expr
23         = Var String
24         | App Expr Expr
25         | Conc [Expr] -- Conc [] is Id
26         | Lambda TypedExpr Expr
27         | Pair Expr Expr
28         | Map
29             deriving (Eq, Typeable, Data)
30
31 data LambdaBE = CurriedEquals
32               | LambdaBE TypedExpr TypedExpr BoolExpr
33             deriving (Eq, Typeable, Data)
34
35 data BoolExpr 
36         = Equal Expr Expr
37         | And [BoolExpr] -- And [] is True
38         | AllZipWith LambdaBE Expr Expr
39         | AndEither  LambdaBE LambdaBE Expr Expr
40         | Condition [TypedExpr] BoolExpr BoolExpr
41         | UnpackPair TypedExpr TypedExpr TypedExpr BoolExpr
42         | TypeVarInst Int BoolExpr
43             deriving (Eq, Typeable, Data)
44
45 -- Smart constructors
46
47 -- | Try eta reduction
48 equal :: TypedExpr -> TypedExpr -> BoolExpr
49 equal te1 te2 | typeOf te1 /= typeOf te2 = error "Type mismatch in equal"
50               | otherwise                = equal' (unTypeExpr te1) (unTypeExpr te2)
51
52 equal' :: Expr -> Expr -> BoolExpr
53 equal' e1 e2  | (Just (lf,lv)) <- isFunctionApplication e1
54               , (Just (rf,rv)) <- isFunctionApplication e2
55               , lv == rv 
56                                          = equal' lf rf
57               -- This makes it return True...
58               | e1 == e2                 = beTrue
59               | otherwise                = Equal e1 e2
60
61 -- | If e is of the type (app f1 (app f2 .. (app fn x)..)),
62 --   return Just (f1 . f2. ... . fn, x)
63 isFunctionApplication :: Expr -> Maybe (Expr, Expr)
64 isFunctionApplication (App f e') | (Just (inner,v)) <- isFunctionApplication e'
65                                  = Just (conc f inner, v)
66                                  | otherwise
67                                  = Just (f, e')
68 isFunctionApplication _          = Nothing
69
70
71 -- | If both bound variables are just functions, we can replace this
72 --   by a comparison
73 unpackPair :: TypedExpr -> TypedExpr -> TypedExpr -> BoolExpr -> BoolExpr
74 unpackPair v1 v2 te be | Just subst1 <- findReplacer v1 be
75                        , Just subst2 <- findReplacer v2 be
76                        = subst1. subst2 $ (pair v1 v2 `equal` te) `aand` be
77
78 -- | If the whole tuple is a function, we can replace this
79 --   by a comparison
80 unpackPair v1 v2 te be | Just subst <- findReplacer (pair v1 v2) be
81                        = subst $ (pair v1 v2 `equal` te) `aand` be
82
83 -- | Nothing to optimize
84 unpackPair v1 v2 te be = UnpackPair v1 v2 te be
85
86 pair :: TypedExpr -> TypedExpr -> TypedExpr
87 pair (TypedExpr e1 t1) (TypedExpr e2 t2) = TypedExpr (Pair e1 e2) (TPair t1 t2)
88
89 lambdaBE :: TypedExpr -> TypedExpr -> BoolExpr -> LambdaBE
90 lambdaBE v1 v2 rel | typeOf v1 == typeOf v2 
91                    , rel == v1 `equal` v2    = CurriedEquals
92                    | otherwise               = LambdaBE v1 v2 rel
93
94 andEither :: LambdaBE -> LambdaBE -> TypedExpr -> TypedExpr -> BoolExpr
95 andEither CurriedEquals CurriedEquals e1 e2 = e1 `equal` e2
96 andEither lbe1 lbe2 e1 e2 =
97         AndEither lbe1 lbe2 (unTypeExpr e1) (unTypeExpr e2)
98
99 allZipWith :: TypedExpr -> TypedExpr -> BoolExpr -> TypedExpr -> TypedExpr -> BoolExpr
100 allZipWith v1 v2 rel e1 e2 | Just v1' <- defFor v1 rel =
101                                 e1 `equal` amap (lambda v2 v1') e2
102                            | Just v2' <- defFor v2 rel =
103                                 amap (lambda v1 v2') e1 `equal` e2
104                            | otherwise =
105                                 AllZipWith (LambdaBE v1 v2 rel) (unTypeExpr e1) (unTypeExpr e2)
106
107 amap :: TypedExpr -> TypedExpr -> TypedExpr
108 amap tf tl | Arrow t1 t2 <- typeOf tf
109            , List t      <- typeOf tl
110            , t1 == t
111            = let tMap = TypedExpr Map (Arrow (Arrow t1 t2) (Arrow (List t1) (List t2)))
112              in app (app tMap tf) tl
113 amap _ _   | otherwise = error "Type error in map"
114
115 aand :: BoolExpr -> BoolExpr -> BoolExpr
116 aand (And xs) (And ys) = And (xs  ++ ys)
117 aand (And xs) y        = And (xs  ++ [y])
118 aand x        (And ys) = And ([x] ++ ys)
119 aand x        y        = And ([x,y])
120
121 beTrue :: BoolExpr
122 beTrue = And []
123
124 -- | Optimize a forall condition
125 condition :: [TypedExpr] -> BoolExpr -> BoolExpr -> BoolExpr
126 -- empty condition
127 condition [] cond concl   | cond == beTrue
128                           = concl
129 -- float out conditions on the right
130 condition vars cond (Condition vars' cond' concl')
131                           = condition (vars ++ vars') (cond `aand` cond') concl'
132
133 -- Try to find variables that are functions of other variables, and remove them
134 condition vars cond concl | True -- set to false to disable
135                           , ((vars',cond',concl'):_) <- mapMaybe try vars
136                           = condition vars' cond' concl'
137               -- A variable which can be replaced
138   where try v | Just subst <- findReplacer v cond
139               = -- trace ("Tested " ++ show v ++ ", can be replaced") $
140                 Just (delete v vars, subst cond, subst concl)
141  
142               -- A variable with can be removed
143               | not (unTypeExpr v `occursIn` cond || unTypeExpr v `occursIn` concl)
144               = -- trace ("Tested " ++ show v ++ ", can be reased") $
145                 Just (delete v vars, cond, concl)
146
147               -- Nothing to do with this variable
148               | otherwise
149               = -- trace ("Tested " ++ show v ++ " without success") $
150                 Nothing
151
152 -- Nothing left to optizmize
153 condition vars cond concl = Condition vars cond concl
154
155 -- | Replaces a Term in a BoolExpr
156 replaceTermBE :: Expr -> Expr -> BoolExpr -> BoolExpr
157 replaceTermBE d r = go
158   where go (e1 `Equal` e2) | d == e1 && r == e2 = beTrue
159                            | d == e2 && r == e1 = beTrue
160                            | otherwise          = go' e1 `equal'` go' e2
161         go (And es)        = foldr aand beTrue (map go es)
162         go (AllZipWith (LambdaBE v1 v2 be) e1 e2) 
163                            = AllZipWith (lambdaBE v1 v2 (go be)) (go' e1) (go' e2)
164         go (AndEither (LambdaBE l1 l2 be1) (LambdaBE r1 r2 be2) e1 e2)
165                            = AndEither (lambdaBE l1 l2 (go be1))
166                                        (lambdaBE r1 r2 (go be2))
167                                        (go' e1) (go' e2)
168         go (Condition vs cond concl)
169                            = condition vs (go cond) (go concl)
170         go (UnpackPair v1 v2 e be)
171                            = unpackPair v1 v2 (go' e) (go be)
172         go (TypeVarInst _ _) = error "TypeVarInst not expected here"
173         go' :: Data a => a -> a
174         go' = replaceExpr d r
175
176 replaceExpr :: Data a => Expr -> Expr -> a -> a
177 replaceExpr d r = everywhere (mkT go)
178   where go e | e == d    = r 
179              | otherwise = e
180
181 -- | Is inside the term a definition for the variable?
182 findReplacer :: TypedExpr -> BoolExpr -> Maybe (BoolExpr -> BoolExpr)
183 findReplacer tv be = findReplacer' (unTypeExpr tv) be
184         
185 -- | Find a definition, and return a substitution
186 findReplacer' :: Expr -> BoolExpr -> Maybe (BoolExpr -> BoolExpr)
187 -- For combined types, look up the components
188 findReplacer' (Pair x y) e | Just (delX) <- findReplacer' x e
189                            , Just (delY) <- findReplacer' y e
190                     = Just (delX . delY)
191 -- Find the definition
192 findReplacer' e (e1 `Equal` e2) | e == e1    = Just (replaceTermBE e e2)
193                                 | e == e2    = Just (replaceTermBE e e1)
194 findReplacer' e (And es)        = listToMaybe (mapMaybe (findReplacer' e) es)
195                                   -- assuming no two definitions can exist
196 findReplacer' _ _               = Nothing
197
198 -- | Is inside the term a definition for the variable?
199 defFor :: TypedExpr -> BoolExpr -> Maybe (TypedExpr)
200 defFor tv be | Just (e') <- defFor' (unTypeExpr tv) be
201                          = Just (TypedExpr e' (typeOf tv))
202              | otherwise = Nothing
203         
204 -- | Find a definition, and return it along the definition remover
205 defFor' :: Expr -> BoolExpr -> Maybe (Expr)
206 defFor' e (e1 `Equal` e2) | e == e1                 = Just (e2)
207                           | e == e2                 = Just (e1)
208 defFor' e (And es)        | [d]  <- mapMaybe (defFor' e) es -- exactly one definition
209                                                     = Just d
210 defFor' _ _                                         = Nothing
211
212 app :: TypedExpr -> TypedExpr -> TypedExpr
213 app te1 te2 | Arrow t1 t2 <- typeOf te1
214             , t3          <- typeOf te2 
215             , t1 == t3 
216             = TypedExpr (app' (unTypeExpr te1) (unTypeExpr te2)) t2
217  where app' Map (Conc []) = Conc []   -- map id = id
218        app' (Conc []) v   = v         -- id x   = x
219        app' f v           = App f v
220 app te1 te2 | otherwise                          = error $ "Type mismatch in app: " ++
221                                                            show te1 ++ " " ++ show te2
222
223 lambda :: TypedExpr -> TypedExpr -> TypedExpr
224 lambda tv e = TypedExpr inner (Arrow (typeOf tv) (typeOf e))
225   where inner | (Just e') <- isApplOn (unTypeExpr tv) (unTypeExpr e)
226               , not (unTypeExpr tv `occursIn` e')
227                           = e'
228               | tv == e   = Conc []
229               | otherwise = Lambda tv (unTypeExpr e)
230
231 conc :: Expr -> Expr -> Expr
232 conc (Conc xs) (Conc ys) = Conc (xs  ++ ys)
233 conc (Conc xs)  y        = Conc (xs  ++ [y])
234 conc x         (Conc ys) = Conc ([x] ++ ys)
235 conc x          y        = Conc ([x,y])
236
237 -- Helpers
238
239 isApplOn :: Expr -> Expr -> Maybe Expr
240 isApplOn e e'         | e == e'                       = Nothing
241 isApplOn e (App f e') | e == e'                       = Just (Conc [f])
242 isApplOn e (App f e') | (Just inner) <- isApplOn e e' = Just (conc f inner)
243 isApplOn _ _                                          = Nothing
244
245 hasVar :: String -> Expr -> Bool
246 hasVar v (Var v')     = v == v'
247 hasVar v (App e1 e2)  = hasVar v e1 || hasVar v e2
248 hasVar v (Conc es)    = any (hasVar v) es
249 hasVar v (Lambda _ e) = hasVar v e
250 hasVar v (Pair e1 e2) = hasVar v e1 || hasVar v e2
251 hasVar _ Map          = False
252
253 occursIn :: (Typeable a, Data a1, Eq a) => a -> a1 -> Bool
254 e `occursIn` e'       = not (null (listify (==e) e'))
255
256 isTuple :: Typ -> Bool
257 isTuple (TPair _ _) = True
258 isTuple _           = False
259
260
261 -- showing
262
263 -- Precedences:
264 -- 10 fun app
265 --  9 (.)
266 --  8 ==
267 --  7 ==>
268 --  6 forall
269
270 instance Show Expr where
271         showsPrec _ (Var s)     = showString s
272         showsPrec d (App e1 e2) = showParen (d>10) $
273                 showsPrec 10 e1 . showChar ' ' . showsPrec 11 e2
274         showsPrec _ (Conc [])   = showString "id"
275         showsPrec d (Conc [e])  = showsPrec d e
276         showsPrec d (Conc es)   = showParen (d>9) $
277                 showIntercalate (showString " . ") (map (showsPrec 10) es)
278         showsPrec _ (Lambda tv e) = showParen True $ 
279                                     showString "\\" .
280                                     showsPrec 0 tv .
281                                     showString " -> ".
282                                     showsPrec 0 e 
283         showsPrec _ (Pair e1 e2) = showParen True $ 
284                                    showsPrec 0 e1 .
285                                    showString "," .
286                                    showsPrec 0 e2
287         showsPrec _ Map           = showString "map"
288
289 showIntercalate :: ShowS -> [ShowS] -> ShowS
290 showIntercalate _ []  = id
291 showIntercalate _ [x] = x
292 showIntercalate i (x:xs) = x . i . showIntercalate i xs
293
294 instance Show TypedExpr where
295         showsPrec d (TypedExpr e t) = 
296                 showParen (d>10) $
297                         showsPrec 0 e .
298                         showString " :: " .
299                         showString (showTypePrec 0 t)
300
301 instance Show LambdaBE where
302         show (CurriedEquals) = 
303                         "(==)"
304         show (LambdaBE v1 v2 be) = 
305                         "(" ++
306                         "\\" ++
307                         showsPrec 11 (unTypeExpr v1) "" ++
308                         " " ++
309                         showsPrec 11 (unTypeExpr v2) "" ++
310                         " -> " ++
311                         show be ++
312                         ")"
313
314 instance Show BoolExpr where
315         show (Equal e1 e2) = showsPrec 9 e1 $
316                              showString " == " $
317                              showsPrec 9 e2 ""
318         show (And [])      = "True"
319         show (And bes)     = intercalate " && " $ map show bes
320         show (AllZipWith lbe e1 e2) =
321                         "allZipWith " ++
322                         show lbe ++
323                         " " ++
324                         showsPrec 11 e1 "" ++
325                         " " ++
326                         showsPrec 11 e2 ""
327         show (AndEither lbe1 lbe2 e1 e2) =
328                         "andEither " ++
329                         show lbe1 ++
330                         " " ++
331                         show lbe2 ++
332                         " " ++
333                         showsPrec 11 e1 "" ++
334                         " " ++
335                         showsPrec 11 e2 ""
336         show (Condition tvars be1 be2) = 
337                         "forall " ++
338                         intercalate ", " (map show tvars) ++
339                         ".\n" ++
340                         (if be1 /= beTrue then indent 2 (show be1) ++ "==>\n" else "") ++
341                         indent 2 (show be2)
342         show (UnpackPair v1 v2 e be) = 
343                         "let (" ++
344                         showsPrec 0 v1 "" ++
345                         "," ++
346                         showsPrec 0 v2 "" ++
347                         ") = " ++
348                         showsPrec 0 e "" ++
349                         " in\n" ++
350                         indent 2 (show be)
351         show (TypeVarInst i be) = 
352                         "forall types t" ++
353                         show (2*i-1) ++
354                         ", t" ++
355                         show (2*i) ++
356                         ", function g" ++
357                         show i ++
358                         " :: t" ++
359                         show (2*i-1) ++
360                         " -> t" ++
361                         show (2*i) ++ 
362                         ".\n" ++
363                         indent 2 (show be)
364
365 indent :: Int -> String -> String
366 indent n = unlines . map (replicate n ' ' ++) . lines
367
368 showTypePrec :: Int -> Typ -> String
369 showTypePrec _ Int                          = "Int" 
370 showTypePrec _ (TVar (TypVar i))            = "a"++show i
371 showTypePrec _ (TVar (TypInst i b)) | not b = "t" ++  show (2*i-1)
372                                     |     b = "t" ++  show (2*i)
373 showTypePrec d (Arrow t1 t2)                = paren (d>9) $ 
374                                                 showTypePrec 10 t1 ++
375                                                 " -> " ++
376                                                 showTypePrec 9 t2 
377 showTypePrec _ (List t)                     = "[" ++ showTypePrec 0 t ++ "]"
378 showTypePrec _ (TEither t1 t2)              = "Either " ++ showTypePrec 11 t1 ++ 
379                                                     " " ++ showTypePrec 11 t2
380 showTypePrec _ (TPair t1 t2)                = "(" ++ showTypePrec 0 t1 ++
381                                               "," ++ showTypePrec 0 t2 ++ ")"
382
383 paren :: Bool -> String -> String
384 paren b p   =  if b then "(" ++ p ++ ")" else p