open Expr

(** [compare_expr e1 e2] is a structural order on expressions; it does NOT
    canonicalize (so [Add (a, b)] and [Add (b, a)] are generally distinct
    here — use {!compare} / {!equal} for canonical comparison).

    Expressions with different head constructors are ordered by the position
    of the constructor in the match: [Const] < [SymConst] < [Var] < [Neg] <
    [Add] < [Sub] < [Mul] < [Div] < [Pow] < [Sin] < [Cos] < [Tan] < [Sinh] <
    [Cosh] < [Tanh] < [Asin] < [Acos] < [Atan] < [Atan2] < [Exp] < [Ln] <
    [Log] < [Sqrt] < [Abs]. Same-constructor expressions compare payloads:
    floats with [Float.compare], variable names with [String.compare],
    symbolic constants with polymorphic [Stdlib.compare], and sub-expressions
    recursively (left operand first for two-operand nodes). *)
let rec compare_expr e1 e2 =
  (* Each "X _, _ -> -1 | _, X _ -> 1" pair removes constructor X from the
     remaining cases, so the match needs no catch-all arm. The final
     [Abs, Abs] arm presumably covers the last constructor of the expression
     type — confirm against Expr when adding constructors. *)
  match (e1, e2) with
  | Const a, Const b -> Float.compare a b
  | Const _, _ -> -1
  | _, Const _ -> 1
  | SymConst a, SymConst b -> Stdlib.compare a b
  | SymConst _, _ -> -1
  | _, SymConst _ -> 1
  | Var a, Var b -> String.compare a b
  | Var _, _ -> -1
  | _, Var _ -> 1
  | Neg a, Neg b -> compare_expr a b
  | Neg _, _ -> -1
  | _, Neg _ -> 1
  | Add (a1, a2), Add (b1, b2) -> compare_operands a1 a2 b1 b2
  | Add _, _ -> -1
  | _, Add _ -> 1
  | Sub (a1, a2), Sub (b1, b2) -> compare_operands a1 a2 b1 b2
  | Sub _, _ -> -1
  | _, Sub _ -> 1
  | Mul (a1, a2), Mul (b1, b2) -> compare_operands a1 a2 b1 b2
  | Mul _, _ -> -1
  | _, Mul _ -> 1
  | Div (a1, a2), Div (b1, b2) -> compare_operands a1 a2 b1 b2
  | Div _, _ -> -1
  | _, Div _ -> 1
  | Pow (a1, a2), Pow (b1, b2) -> compare_operands a1 a2 b1 b2
  | Pow _, _ -> -1
  | _, Pow _ -> 1
  | Sin a, Sin b -> compare_expr a b
  | Sin _, _ -> -1
  | _, Sin _ -> 1
  | Cos a, Cos b -> compare_expr a b
  | Cos _, _ -> -1
  | _, Cos _ -> 1
  | Tan a, Tan b -> compare_expr a b
  | Tan _, _ -> -1
  | _, Tan _ -> 1
  | Sinh a, Sinh b -> compare_expr a b
  | Sinh _, _ -> -1
  | _, Sinh _ -> 1
  | Cosh a, Cosh b -> compare_expr a b
  | Cosh _, _ -> -1
  | _, Cosh _ -> 1
  | Tanh a, Tanh b -> compare_expr a b
  | Tanh _, _ -> -1
  | _, Tanh _ -> 1
  | Asin a, Asin b -> compare_expr a b
  | Asin _, _ -> -1
  | _, Asin _ -> 1
  | Acos a, Acos b -> compare_expr a b
  | Acos _, _ -> -1
  | _, Acos _ -> 1
  | Atan a, Atan b -> compare_expr a b
  | Atan _, _ -> -1
  | _, Atan _ -> 1
  | Atan2 (a1, a2), Atan2 (b1, b2) -> compare_operands a1 a2 b1 b2
  | Atan2 _, _ -> -1
  | _, Atan2 _ -> 1
  | Exp a, Exp b -> compare_expr a b
  | Exp _, _ -> -1
  | _, Exp _ -> 1
  | Ln a, Ln b -> compare_expr a b
  | Ln _, _ -> -1
  | _, Ln _ -> 1
  | Log (a1, a2), Log (b1, b2) -> compare_operands a1 a2 b1 b2
  | Log _, _ -> -1
  | _, Log _ -> 1
  | Sqrt a, Sqrt b -> compare_expr a b
  | Sqrt _, _ -> -1
  | _, Sqrt _ -> 1
  | Abs a, Abs b -> compare_expr a b

(* Lexicographic comparison of two operand pairs: the first operands decide
   unless they compare equal, in which case the second operands decide. *)
and compare_operands a1 a2 b1 b2 =
  let c = compare_expr a1 b1 in
  if c <> 0 then c else compare_expr a2 b2

(** [canonicalize e] rewrites [e] into the normal form used by {!equal},
    {!compare} and {!hash}:
    - a double negation [Neg (Neg x)] is replaced by the canonical form of [x];
    - [Sub (a, b)] becomes [Add (a, Neg b)];
    - [Div (a, b)] becomes [Mul (a, Pow (b, Const (-1.)))];
    - the two operands of every [Add] and [Mul] node are ordered by
      {!compare_expr}.
    Every other constructor is kept, with its children canonicalized, so the
    result never contains [Sub] or [Div].

    Only the commutativity of a single node is normalized. Associativity is
    not, so [Add (a, Add (b, c))] and [Add (Add (a, b), c)] generally remain
    distinct. *)
let rec canonicalize = function
  | Const _ as c -> c
  | SymConst _ as s -> s
  | Var _ as v -> v
  | Neg (Neg e) -> canonicalize e
  | Neg e -> Neg (canonicalize e)
  | Sub (e1, e2) -> canonicalize (Add (e1, Neg e2))
  | Div (e1, e2) -> canonicalize (Mul (e1, Pow (e2, Const (-.1.0))))
  | Add (e1, e2) ->
      let c1 = canonicalize e1 in
      let c2 = canonicalize e2 in
      if compare_expr c1 c2 <= 0 then Add (c1, c2) else Add (c2, c1)
  | Mul (e1, e2) ->
      let c1 = canonicalize e1 in
      let c2 = canonicalize e2 in
      if compare_expr c1 c2 <= 0 then Mul (c1, c2) else Mul (c2, c1)
  | Pow (e1, e2) -> Pow (canonicalize e1, canonicalize e2)
  | Sin e -> Sin (canonicalize e)
  | Cos e -> Cos (canonicalize e)
  | Tan e -> Tan (canonicalize e)
  | Sinh e -> Sinh (canonicalize e)
  | Cosh e -> Cosh (canonicalize e)
  | Tanh e -> Tanh (canonicalize e)
  | Asin e -> Asin (canonicalize e)
  | Acos e -> Acos (canonicalize e)
  | Atan e -> Atan (canonicalize e)
  | Atan2 (e1, e2) -> Atan2 (canonicalize e1, canonicalize e2)
  | Exp e -> Exp (canonicalize e)
  | Ln e -> Ln (canonicalize e)
  | Log (e1, e2) -> Log (canonicalize e1, canonicalize e2)
  | Sqrt e -> Sqrt (canonicalize e)
  | Abs e -> Abs (canonicalize e)

(** [equal e1 e2] holds when [e1] and [e2] have structurally identical
    canonical forms (see {!canonicalize}). *)
let equal e1 e2 =
  let c1 = canonicalize e1 in
  let c2 = canonicalize e2 in
  compare_expr c1 c2 = 0

(** [compare e1 e2] orders expressions by {!compare_expr} on their canonical
    forms. It returns 0 exactly when [equal e1 e2] holds. *)
let compare e1 e2 =
  let c1 = canonicalize e1 in
  let c2 = canonicalize e2 in
  compare_expr c1 c2

(** [hash e] hashes the canonical form of [e], so [equal e1 e2] implies
    [hash e1 = hash e2]. That is the contract a hashed-type module given to
    [Hashtbl.Make] must satisfy. *)
let hash e =
  (* Structural hash of one tree. The [Sub] and [Div] arms are never reached
     from [hash] (canonical forms contain neither) but keep the match
     exhaustive. [Hashtbl.hash] gives equal hashes to values that
     [Stdlib.compare] considers equal, which covers the [Float.compare],
     [String.compare] and [Stdlib.compare] leaf comparisons in
     [compare_expr]. *)
  let rec go = function
    | Const f -> Hashtbl.hash ("Const", f)
    | SymConst s -> Hashtbl.hash ("SymConst", s)
    | Var v -> Hashtbl.hash ("Var", v)
    | Add (e1, e2) -> Hashtbl.hash ("Add", go e1, go e2)
    | Sub (e1, e2) -> Hashtbl.hash ("Sub", go e1, go e2)
    | Mul (e1, e2) -> Hashtbl.hash ("Mul", go e1, go e2)
    | Div (e1, e2) -> Hashtbl.hash ("Div", go e1, go e2)
    | Pow (e1, e2) -> Hashtbl.hash ("Pow", go e1, go e2)
    | Neg e -> Hashtbl.hash ("Neg", go e)
    | Sin e -> Hashtbl.hash ("Sin", go e)
    | Cos e -> Hashtbl.hash ("Cos", go e)
    | Tan e -> Hashtbl.hash ("Tan", go e)
    | Sinh e -> Hashtbl.hash ("Sinh", go e)
    | Cosh e -> Hashtbl.hash ("Cosh", go e)
    | Tanh e -> Hashtbl.hash ("Tanh", go e)
    | Asin e -> Hashtbl.hash ("Asin", go e)
    | Acos e -> Hashtbl.hash ("Acos", go e)
    | Atan e -> Hashtbl.hash ("Atan", go e)
    | Atan2 (e1, e2) -> Hashtbl.hash ("Atan2", go e1, go e2)
    | Exp e -> Hashtbl.hash ("Exp", go e)
    | Ln e -> Hashtbl.hash ("Ln", go e)
    | Log (e1, e2) -> Hashtbl.hash ("Log", go e1, go e2)
    | Sqrt e -> Hashtbl.hash ("Sqrt", go e)
    | Abs e -> Hashtbl.hash ("Abs", go e)
  in
  (* Hash the canonical form rather than [e] itself. Otherwise e.g.
     [Add (a, b)] and [Add (b, a)], or [Sub (a, b)] and [Add (a, Neg b)],
     would be [equal] but hash differently, breaking hash-table lookups. *)
  go (canonicalize e)