
module Main
where

-- | Binary operators that can appear in an 'XPrim' node.
-- Their concrete syntax is given by 'pprO'.
data Op
  = OAdd
  | OSub
  | OMul
  | OAnd
  | ONotEq
  | OEq
  | OGt
  deriving (Show, Eq)

-- | A literal value: either a boolean or a machine integer.
data Lit
  = LBool Bool
  | LInt  Int
  deriving (Show, Eq)

-- | An expression tree.
data Exp
  = XLit  Lit          -- ^ a literal leaf
  | XPrim Op  Exp Exp  -- ^ an operator applied to left and right operands
  | XIf   Exp Exp Exp  -- ^ condition, then-branch, else-branch
  deriving (Show, Eq)


-- | Render an operator as its infix symbol.
pprO :: Op -> String
pprO op = case op of
  OAdd   -> "+"
  OSub   -> "-"
  OMul   -> "*"
  OAnd   -> "&"
  ONotEq -> "/="
  OEq    -> "=="
  OGt    -> ">"

-- | Render a literal using the 'Show' instance of its payload,
-- e.g. @True@ or @42@.
pprL :: Lit -> String
pprL lit = case lit of
  LBool flag -> show flag
  LInt  n    -> show n

-- | Pretty print an expression.
--
-- Every primitive application is wrapped in parentheses. A conditional that
-- appears as an operand of a primitive is parenthesised as well. Without
-- this, @XPrim OAdd (XIf c a b) x@ would print as
-- @(if c then a else b + x)@, which reads as the @+ x@ belonging to the
-- else-branch.
pprX :: Exp -> String
pprX (XLit l)          = pprL l
pprX (XPrim o  x1 x2)  = "(" ++ pprArg x1 ++ " " ++ pprO o ++ " " ++ pprArg x2 ++ ")"
pprX (XIf   x1 x2 x3)  = "if " ++ pprX x1 ++ " then " ++ pprX x2 ++ " else " ++ pprX x3

-- | Pretty print an operand of a primitive, parenthesising conditionals.
pprArg :: Exp -> String
pprArg x@XIf{} = "(" ++ pprX x ++ ")"
pprArg x       = pprX x


-- | Extract the integer payload of a literal.
-- Calls 'error' if the literal is a boolean.
takeI :: Lit -> Int
takeI lit = case lit of
  LInt n -> n
  _      -> error "takeI: not an int!"

-- | Extract the boolean payload of a literal.
-- Calls 'error' if the literal is an integer.
takeB :: Lit -> Bool
takeB lit = case lit of
  LBool flag -> flag
  _          -> error "takeB: not a bool!"


-- | Evaluate an expression to a literal.
--
-- Every operator in 'Op' is handled. The original equations omitted
-- 'OAnd', 'OEq' and 'ONotEq', so evaluating them was a pattern-match
-- failure. Ill-typed programs (e.g. adding a boolean) still fail via
-- 'takeI' / 'takeB' / 'eqLit'.
evalX :: Exp -> Lit
evalX (XLit l)          = l
evalX (XPrim op x1 x2)  = evalPrim op (evalX x1) (evalX x2)
evalX (XIf x1 x2 x3)
  | takeB (evalX x1)    = evalX x2
  | otherwise           = evalX x3

-- | Apply a primitive operator to two already-evaluated operands.
-- Operands are lazy thunks, so 'OAnd' does not evaluate its right operand
-- when the left one is False.
evalPrim :: Op -> Lit -> Lit -> Lit
evalPrim OAdd   l1 l2 = LInt  (takeI l1 + takeI l2)
evalPrim OSub   l1 l2 = LInt  (takeI l1 - takeI l2)
evalPrim OMul   l1 l2 = LInt  (takeI l1 * takeI l2)
evalPrim OGt    l1 l2 = LBool (takeI l1 > takeI l2)
evalPrim OAnd   l1 l2 = LBool (takeB l1 && takeB l2)
evalPrim OEq    l1 l2 = LBool (eqLit l1 l2)
evalPrim ONotEq l1 l2 = LBool (not (eqLit l1 l2))

-- | Equality on literals of the same kind.
-- Comparing an int with a bool is treated as a type error, matching how
-- 'takeI' / 'takeB' reject mismatched operands.
eqLit :: Lit -> Lit -> Bool
eqLit (LInt  a) (LInt  b) = a == b
eqLit (LBool a) (LBool b) = a == b
eqLit _         _         = error "eqLit: cannot compare an int with a bool!"


-- | Rewrite an expression tree bottom-up.
-- The children of each node are rewritten first, then the function is
-- applied to the rebuilt node (leaves included).
mapX :: (Exp -> Exp) -> Exp -> Exp
mapX f = f . descend
  where
    descend node = case node of
      XLit lit          -> XLit lit
      XPrim op lhs rhs  -> XPrim op (mapX f lhs) (mapX f rhs)
      XIf cond yes no   -> XIf (mapX f cond) (mapX f yes) (mapX f no)


-- | Constant-fold a single node.
-- An addition or subtraction whose operands are both integer literals
-- becomes the resulting literal. Any other node is returned unchanged.
-- Children are not visited; combine with 'mapX' for a whole-tree pass.
shortX :: Exp -> Exp
shortX node = case node of
  XPrim OAdd (XLit (LInt a)) (XLit (LInt b)) -> XLit (LInt (a + b))
  XPrim OSub (XLit (LInt a)) (XLit (LInt b)) -> XLit (LInt (a - b))
  _                                          -> node


-- | An example expression, corresponding to the source program
--
-- >  if 5 * 3 > 14 then 7 - 3 else 2 + 5
exp1 :: Exp
exp1 =
  XIf
    (XPrim OGt
      (XPrim OMul (XLit (LInt 5)) (XLit (LInt 3)))
      (XLit (LInt 14)))
    (XPrim OSub
      (XLit (LInt 7))
      (XLit (LInt 3)))
    (XPrim OAdd
      (XLit (LInt 2))
      (XLit (LInt 5)))

-- | 'exp1' after applying 'shortX' to every node, bottom-up, via 'mapX'.
exp1r :: Exp
exp1r = mapX shortX exp1

-- | Entry point. A @module Main@ must define @main@ to be compiled as a
-- program. Prints the example before and after rewriting, then the value
-- of the rewritten expression.
main :: IO ()
main = do
  putStrLn (pprX exp1)
  putStrLn (pprX exp1r)
  putStrLn (pprL (evalX exp1r))



