64851fe869165010807aec82dbfc63a015885b96
[darcs-mirror-polyfix.git] / Expr.hs
1 {-# LANGUAGE PatternGuards, DeriveDataTypeable   #-}
2 module Expr where
3
4 import Data.List
5 import Data.Maybe
6 import ParseType
7
8 import Data.Generics hiding (typeOf)
9 import Data.Generics.Schemes
10
11 data TypedExpr = TypedExpr
12         { unTypeExpr    :: Expr
13         , typeOf        :: Typ
14         } deriving (Eq, Typeable, Data)
15
16 typedLeft, typedRight :: Expr -> Typ -> TypedExpr
17 typedLeft  e t = TypedExpr e (instType False t)
18 typedRight e t = TypedExpr e (instType True t)
19
20 data Expr
21         = Var String
22         | App Expr Expr
23         | Conc [Expr] -- Conc [] is Id
24         | Lambda TypedExpr Expr
25         | Pair Expr Expr
26         | Map
27             deriving (Eq, Typeable, Data)
28
29 data BoolExpr 
30         = Equal Expr Expr
31         | And [BoolExpr] -- And [] is True
32         | AllZipWith TypedExpr TypedExpr BoolExpr Expr Expr
33         | Condition [TypedExpr] BoolExpr BoolExpr
34         | UnpackPair TypedExpr TypedExpr TypedExpr BoolExpr
35         | TypeVarInst Int BoolExpr
36             deriving (Eq, Typeable, Data)
37
38 -- Smart constructors
39
40 -- | Try eta reduction
41 equal te1 te2 | typeOf te1 /= typeOf te2 = error "Type mismatch in equal"
42               | otherwise                = equal' (unTypeExpr te1) (unTypeExpr te2)
43
44 equal' e1 e2  | (Just (lf,lv)) <- isFunctionApplication e1
45               , (Just (rf,rv)) <- isFunctionApplication e2
46               , lv == rv 
47                                          = equal' lf rf
48               -- This makes it return True...
49               | e1 == e2                 = beTrue
50               | otherwise                = Equal e1 e2
51
52 isFunctionApplication (App f e') | (Just (inner,v)) <- isFunctionApplication e'
53                                  = Just (conc f inner, v)
54                                  | otherwise
55                                  = Just (Conc [f], e')
56 isFunctionApplication _          = Nothing
57
58
59 unpackPair = UnpackPair
60
61 allZipWith :: TypedExpr -> TypedExpr -> BoolExpr -> TypedExpr -> TypedExpr -> BoolExpr
62 allZipWith v1 v2 rel e1 e2 | Just v1' <- defFor v1 rel =
63                                 e1 `equal` amap (lambda v2 v1') e2
64                            | Just v2' <- defFor v2 rel =
65                                 amap (lambda v1 v2') e1 `equal` e2
66                            | otherwise =
67                                 AllZipWith v1 v2 rel (unTypeExpr e1) (unTypeExpr e2)
68
69 amap tf tl | Arrow t1 t2 <- typeOf tf
70            , List t      <- typeOf tl
71            , t1 == t
72            = let tMap = TypedExpr Map (Arrow (Arrow t1 t2) (Arrow (List t1) (List t2)))
73              in app (app tMap tf) tl
74 amap tf tl | otherwise = error "Type error in map"
75
76 -- Float out foralls without condition
77 aand (Condition vars beTrue concl) be = condition vars beTrue (aand concl be)
78 aand (And xs) (And ys) = And (xs  ++ ys)
79 aand (And xs) y        = And (xs  ++ [y])
80 aand x        (And ys) = And ([x] ++ ys)
81 aand x        y        = And ([x,y])
82
83 beTrue = And []
84
85 -- | Is any var (or part of var) defined in cond, and can be replaced in concl?
86 condition :: [TypedExpr] -> BoolExpr -> BoolExpr -> BoolExpr
87 -- empty condition
88 condition [] cond concl   | cond == beTrue
89                           = concl
90 -- float out conditions on the right
91 condition vars cond (Condition vars' cond' concl')
92                           = condition (vars ++ vars') (cond `aand` cond') concl'
93
94 -- Try to find variables that are functions of other variables, and remove them
95 condition vars cond concl | True -- set to false to disable
96                           , ((vars',cond',concl'):_) <- mapMaybe try vars
97                           = condition vars' cond' concl'
98   where try v | Just subst <- findReplacer v cond
99               = Just (delete v vars, subst cond, subst concl)
100               | not (unTypeExpr v `occursIn` cond || unTypeExpr v `occursIn` concl)
101               = Just (delete v vars, cond, concl)
102               | otherwise
103               = Nothing
104
105 -- Nothing left to optizmize
106 condition vars cond concl = Condition vars cond concl
107
108 -- | Replaces a Term in a BoolExpr
109 replaceTermBE :: Expr -> Expr -> BoolExpr -> BoolExpr
110 replaceTermBE d r = go
111   where go (e1 `Equal` e2) | d == e1 && r == e2 = beTrue
112                            | d == e2 && r == e1 = beTrue
113                            | otherwise          = go' e1 `equal'` go' e2
114         go (And es)        = foldr aand beTrue (map go es)
115         go (AllZipWith v1 v2 be e1 e2) 
116                            = AllZipWith v1 v2 (go be) (go' e1) (go' e2)
117         go (Condition vs cond concl)
118                            = condition vs (go cond) (go concl)
119         go (UnpackPair v1 v2 e be)
120                            = unpackPair v1 v2 (go' e) (go be)
121         go (TypeVarInst _ _) = error "TypeVarInst not expected here"
122         go' :: Data a => a -> a
123         go' = replaceExpr d r
124
125 replaceExpr :: Data a => Expr -> Expr -> a -> a
126 replaceExpr d r = everywhere (mkT go)
127   where go e | e == d    = r 
128              | otherwise = e
129
130 -- | Is inside the term a definition for the variable?
131 findReplacer :: TypedExpr -> BoolExpr -> Maybe (BoolExpr -> BoolExpr)
132 findReplacer tv be = findReplacer' (unTypeExpr tv) be
133         
134 -- | Find a definition, and return a substitution
135 findReplacer' :: Expr -> BoolExpr -> Maybe (BoolExpr -> BoolExpr)
136 -- For combined types, look up the components
137 findReplacer' (Pair x y) e | Just (delX) <- findReplacer' x e
138                            , Just (delY) <- findReplacer' y e
139                     = Just (delX . delY)
140 -- Find the definition
141 findReplacer' e (e1 `Equal` e2) | e == e1    = Just (replaceTermBE e e2)
142                                 | e == e2    = Just (replaceTermBE e e1)
143 findReplacer' e (And es)        = listToMaybe (mapMaybe (findReplacer' e) es)
144                                   -- assuming no two definitions can exist
145 findReplacer' _ _               = Nothing
146
147 -- | Is inside the term a definition for the variable?
148 defFor :: TypedExpr -> BoolExpr -> Maybe (TypedExpr)
149 defFor tv be | Just (e') <- defFor' (unTypeExpr tv) be
150                          = Just (TypedExpr e' (typeOf tv))
151              | otherwise = Nothing
152         
153 -- | Find a definition, and return it along the definition remover
154 defFor' :: Expr -> BoolExpr -> Maybe (Expr)
155 defFor' e (e1 `Equal` e2) | e == e1                 = Just (e2)
156                           | e == e2                 = Just (e1)
157 defFor' e (And es)        | [d]  <- mapMaybe (defFor' e) es -- exactly one definition
158                                                     = Just d
159 defFor' _ _                                         = Nothing
160
161 app te1 te2 | Arrow t1 t2 <- typeOf te1
162             , t3          <- typeOf te2 
163             , t1 == t3 
164             = TypedExpr (app' (unTypeExpr te1) (unTypeExpr te2)) t2
165  where app' Map (Conc []) = Conc []
166        app' (Conc []) v   = v
167        app' f v           = App f v
168 app te1 te2 | otherwise                          = error $ "Type mismatch in app: " ++
169                                                            show te1 ++ " " ++ show te2
170
171 {- dead code
172 unCond v (Equal l r) | (Just l') <- isApplOn (unTypeExpr v) l 
173                      , (Just r') <- isApplOn (unTypeExpr v) r = 
174         if v `occursIn` l' || v `occursIn` r'
175         then Condition [v] beTrue (Equal l' r')
176         else (Equal l' r')
177 unCond v e = Condition [v] beTrue e
178 -}
179
180 lambda tv e = TypedExpr inner (Arrow (typeOf tv) (typeOf e))
181   where inner | (Just e') <- isApplOn (unTypeExpr tv) (unTypeExpr e)
182               , not (unTypeExpr tv `occursIn` e')
183                           = e'
184               | tv == e   = Conc []
185               | otherwise = Lambda tv (unTypeExpr e)
186
187 conc (Conc xs) (Conc ys) = Conc (xs  ++ ys)
188 conc (Conc xs)  y        = Conc (xs  ++ [y])
189 conc x         (Conc ys) = Conc ([x] ++ ys)
190 conc x          y        = Conc ([x,y])
191
192 -- Helpers
193
194 isApplOn e e'         | e == e'                       = Nothing
195 isApplOn e (App f e') | e == e'                       = Just (Conc [f])
196 isApplOn e (App f e') | (Just inner) <- isApplOn e e' = Just (conc f inner)
197 isApplOn _ _                                          = Nothing
198
199 hasVar v (Var v')     = v == v'
200 hasVar v (App e1 e2)  = hasVar v e1 && hasVar v e2
201 hasVar v (Conc es)    = any (hasVar v) es
202 hasVar v (Lambda _ e) = hasVar v e
203 hasVar v Map          = False
204
205 e `occursIn` e'       = not (null (listify (==e) e'))
206
207 isTuple (TPair _ _) = True
208 isTuple _           = False
209
210
211 -- showing
212
213 -- Precedences:
214 -- 10 fun app
215 --  9 (.)
216 --  8 ==
217 --  7 ==>
218 --  6 forall
219
220 instance Show Expr where
221         showsPrec d (Var s)     = showString s
222         showsPrec d (App e1 e2) = showParen (d>10) $
223                 showsPrec 10 e1 . showChar ' ' . showsPrec 11 e2
224         showsPrec d (Conc [])   = showString "id"
225         showsPrec d (Conc [e])  = showsPrec d e
226         showsPrec d (Conc es)   = showParen (d>9) $
227                 showIntercalate (showString " . ") (map (showsPrec 10) es)
228         showsPrec d (Lambda tv e) = showParen True $ 
229                                     showString "\\" .
230                                     showsPrec 0 tv .
231                                     showString " -> ".
232                                     showsPrec 0 e 
233         showsPrec _ (Pair e1 e2) = showParen True $ 
234                                    showsPrec 0 e1 .
235                                    showString "," .
236                                    showsPrec 0 e2
237         showsPrec _ Map           = showString "map"
238
239 showIntercalate i []  = id
240 showIntercalate i [x] = x
241 showIntercalate i (x:xs) = x . i . showIntercalate i xs
242
243 instance Show TypedExpr where
244         showsPrec d (TypedExpr e t) = 
245                 showParen (d>10) $
246                         showsPrec 0 e .
247                         showString " :: " .
248                         showString (showTypePrec 0 t)
249
250 instance Show BoolExpr where
251         show (Equal e1 e2) = showsPrec 9 e1 $
252                              showString " == " $
253                              showsPrec 9 e2 ""
254         show (And [])      = "True"
255         show (And bes)     = intercalate " && " $ map show bes
256         show (AllZipWith v1 v2 be e1 e2) =
257                         "allZipWith " ++
258                         "( " ++
259                         "\\" ++
260                         showsPrec 11 v1 "" ++
261                         " " ++
262                         showsPrec 11 v2 "" ++
263                         " -> " ++
264                         show be ++
265                         ")" ++
266                         " " ++
267                         showsPrec 11 e1 "" ++
268                         " " ++
269                         showsPrec 11 e2 ""
270         show (Condition tvars be1 be2) = 
271                         "forall " ++
272                         intercalate ", " (map show tvars) ++
273                         ".\n" ++
274                         (if be1 /= beTrue then indent 2 (show be1) ++ "==>\n" else "") ++
275                         indent 2 (show be2)
276         show (UnpackPair v1 v2 e be) = 
277                         "let (" ++
278                         showsPrec 0 v1 "" ++
279                         "," ++
280                         showsPrec 0 v2 "" ++
281                         ") = " ++
282                         showsPrec 0 e "" ++
283                         " in\n" ++
284                         indent 2 (show be)
285         show (TypeVarInst i be) = 
286                         "forall types t" ++
287                         show (2*i-1) ++
288                         ", t" ++
289                         show (2*i) ++
290                         ", function g" ++
291                         show i ++
292                         " :: t" ++
293                         show (2*i-1) ++
294                         " -> t" ++
295                         show (2*i) ++ 
296                         ".\n" ++
297                         indent 2 (show be)
298
299 indent n = unlines . map (replicate n ' ' ++) . lines
300
301 showTypePrec :: Int -> Typ -> String
302 showTypePrec _ Int                          = "Int" 
303 showTypePrec _ (TVar (TypVar i))            = "a"++show i
304 showTypePrec _ (TVar (TypInst i b)) | not b = "t" ++  show (2*i-1)
305                                     |     b = "t" ++  show (2*i)
306 showTypePrec d (Arrow t1 t2)                = paren (d>9) $ 
307                                   showTypePrec 10 t1 ++ " -> " ++ showTypePrec 9 t2 
308 showTypePrec d (List t)                     = "[" ++ showTypePrec 0 t ++ "]"
309 showTypePrec d (TEither t1 t2)              = "Either " ++ showTypePrec 11 t1 ++ 
310                                                     " " ++ showTypePrec 11 t2
311 showTypePrec d (TPair t1 t2)                = "(" ++ showTypePrec 0 t1 ++
312                                               "," ++ showTypePrec 0 t2 ++ ")"
313
314 paren b p   =  if b then "(" ++ p ++ ")" else p