open Expr

(** Upper bound on the number of rewrite passes [simplify] performs before
    giving up, in case the rule set never reaches a fixed point. *)
let max_iterations = 100

(** [simplify_once expr] performs one bottom-up rewrite pass: every child is
    simplified first, then the node-specific rule set is applied to the
    rebuilt node. One pass is not guaranteed to reach a normal form;
    [simplify] iterates it. *)
let rec simplify_once = function
  | Const _ as c -> c
  | SymConst _ as s -> s
  | Var _ as v -> v
  | Add (e1, e2) -> simplify_add (simplify_once e1) (simplify_once e2)
  | Sub (e1, e2) -> simplify_sub (simplify_once e1) (simplify_once e2)
  | Mul (e1, e2) -> simplify_mul (simplify_once e1) (simplify_once e2)
  | Div (e1, e2) -> simplify_div (simplify_once e1) (simplify_once e2)
  | Pow (e1, e2) -> simplify_pow (simplify_once e1) (simplify_once e2)
  | Neg e -> simplify_neg (simplify_once e)
  | Sin e -> simplify_sin (simplify_once e)
  | Cos e -> simplify_cos (simplify_once e)
  | Tan e -> simplify_tan (simplify_once e)
  | Sinh e -> simplify_sinh (simplify_once e)
  | Cosh e -> simplify_cosh (simplify_once e)
  | Tanh e -> simplify_tanh (simplify_once e)
  | Asin e -> simplify_asin (simplify_once e)
  | Acos e -> simplify_acos (simplify_once e)
  | Atan e -> simplify_atan (simplify_once e)
  | Atan2 (e1, e2) -> simplify_atan2 (simplify_once e1) (simplify_once e2)
  | Exp e -> simplify_exp (simplify_once e)
  | Ln e -> simplify_ln (simplify_once e)
  | Log (e1, e2) -> simplify_log (simplify_once e1) (simplify_once e2)
  | Sqrt e -> simplify_sqrt (simplify_once e)
  | Abs e -> simplify_abs (simplify_once e)

(** Sum rules: additive identity, constant folding, and collection of like
    terms ([a*x + b*x -> (a+b)*x], [x + x -> 2*x], ...). Term equality is
    decided by [Canonical.equal]. *)
and simplify_add e1 e2 =
  match (e1, e2) with
  | Const 0.0, e | e, Const 0.0 -> e
  | Const a, Const b -> Const (a +. b)
  | Mul (Const a, e1), Mul (Const b, e2) when Canonical.equal e1 e2 ->
      simplify_mul (Const (a +. b)) e1
  | Mul (Const a, e1), e2 when Canonical.equal e1 e2 ->
      simplify_mul (Const (a +. 1.0)) e1
  | e1, Mul (Const b, e2) when Canonical.equal e1 e2 ->
      simplify_mul (Const (1.0 +. b)) e1
  | e1, e2 when Canonical.equal e1 e2 -> simplify_mul (Const 2.0) e1
  (* Merge a trailing constant into an existing [e + c] sum. *)
  | Add (e1, Const a), Const b -> simplify_add e1 (Const (a +. b))
  | _ -> Add (e1, e2)

(** Difference rules: [e - 0 -> e], constant folding, [e - e -> 0]. *)
and simplify_sub e1 e2 =
  match (e1, e2) with
  | e, Const 0.0 -> e
  | Const a, Const b -> Const (a -. b)
  | e1, e2 when Canonical.equal e1 e2 -> Const 0.0
  | _ -> Sub (e1, e2)

(** Product rules: zero/one/minus-one factors, constant folding and
    hoisting of constant coefficients, merging powers of a common base, and
    [exp a * exp b -> exp (a+b)]. *)
and simplify_mul e1 e2 =
  match (e1, e2) with
  | Const 0.0, _ | _, Const 0.0 -> Const 0.0
  | Const 1.0, e | e, Const 1.0 -> e
  | Const (-1.0), e -> simplify_neg e
  | e, Const (-1.0) -> simplify_neg e
  | Const a, Const b -> Const (a *. b)
  | Const a, Mul (Const b, e) -> simplify_mul (Const (a *. b)) e
  | Mul (Const a, e), Const b -> simplify_mul (Const (a *. b)) e
  | Const a, Mul (e1, Mul (Const b, e2)) ->
      simplify_mul (Const (a *. b)) (Mul (e1, e2))
  | Mul (Const a, e1), Mul (Const b, e2) ->
      simplify_mul (Const (a *. b)) (Mul (e1, e2))
  | Pow (e1, a), Pow (e2, b) when Canonical.equal e1 e2 ->
      simplify_pow e1 (simplify_add a b)
  | e1, Pow (e2, b) when Canonical.equal e1 e2 ->
      simplify_pow e1 (simplify_add (Const 1.0) b)
  | Pow (e1, a), e2 when Canonical.equal e1 e2 ->
      simplify_pow e1 (simplify_add a (Const 1.0))
  | e1, e2 when Canonical.equal e1 e2 -> simplify_pow e1 (Const 2.0)
  (* Build [Exp] directly rather than via [simplify_exp]: that function
     rewrites [exp (a + b)] back into [exp a * exp b] by calling this one,
     so routing through it would recurse without end. *)
  | Exp e1, Exp e2 -> Exp (simplify_add e1 e2)
  (* Ordering convention: put a bare variable in front of a function
     application, e.g. [sin y * x -> x * sin y]. *)
  | ( ( Sin _ | Cos _ | Tan _ | Sinh _ | Cosh _ | Tanh _ | Asin _ | Acos _
      | Atan _ | Exp _ | Ln _ | Log _ | Sqrt _ | Abs _ | Pow _ ),
      Var _ ) ->
      Mul (e2, e1)
  | _ -> Mul (e1, e2)

(** Quotient rules: [0 / e -> 0], [e / 1 -> e], constant folding,
    [e / e -> 1], and cancelling a factor of a product numerator. *)
and simplify_div e1 e2 =
  match (e1, e2) with
  | Const 0.0, _ -> Const 0.0
  | e, Const 1.0 -> e
  | Const a, Const b -> Const (a /. b)
  | e1, e2 when Canonical.equal e1 e2 -> Const 1.0
  | Mul (e1, e2), e3 when Canonical.equal e2 e3 -> e1
  | Mul (e1, e2), e3 when Canonical.equal e1 e3 -> e2
  | _ -> Div (e1, e2)

(** Power rules: trivial exponents/bases, constant folding, nested-power
    flattening, and the [sqrt]/[exp]/[ln] interplay.

    Constant folding runs before the [0^e -> 0] shortcut so that, e.g.,
    [0^(-1)] folds with float semantics instead of wrongly becoming [0]. It
    is skipped for a negative base with a fractional exponent, where [**]
    would return NaN; such terms stay symbolic, mirroring the [c >= 0.0]
    guard in [simplify_sqrt].

    [(e^a)^b -> e^(a*b)] only holds over the reals when [b] is an integer
    or [e] is positive; otherwise [(x^2)^0.5] would become [x], contradicting
    [simplify_sqrt]'s [sqrt (x^2) -> |x|]. Those cases fall through to the
    [e^0.5 -> sqrt e] rule instead. *)
and simplify_pow e1 e2 =
  let is_integer_const = function Const c -> Float.is_integer c | _ -> false in
  let is_positive_base = function
    | Const c -> c > 0.0
    | SymConst E -> true
    | Exp _ -> true
    | _ -> false
  in
  match (e1, e2) with
  | _, Const 0.0 -> Const 1.0
  | e, Const 1.0 -> e
  | Const a, Const b when a >= 0.0 || Float.is_integer b -> Const (a ** b)
  (* NOTE(review): [0^e -> 0] assumes a symbolic exponent [e] is positive. *)
  | Const 0.0, _ -> Const 0.0
  | Const 1.0, _ -> Const 1.0
  | Pow (e, a), b when is_integer_const b || is_positive_base e ->
      simplify_pow e (simplify_mul a b)
  | Sqrt e, Const 2.0 -> e
  | e, Const 0.5 -> simplify_sqrt e
  | SymConst E, Ln e -> e
  | _ -> Pow (e1, e2)

(** Negation rules: fold constants, cancel double negation, and push the
    sign into a constant coefficient. *)
and simplify_neg = function
  | Const c -> Const (-.c)
  | Neg e -> e
  | Mul (Const c, e) -> simplify_mul (Const (-.c)) e
  | Mul (e, Const c) -> simplify_mul (Const (-.c)) e
  | e -> Neg e

(** [sin]: constant folding, [sin (asin e) -> e], odd symmetry. *)
and simplify_sin = function
  | Const 0.0 -> Const 0.0
  | Const c -> Const (sin c)
  | Asin e -> e
  | Neg e -> simplify_neg (Sin e)
  | e -> Sin e

(** [cos]: constant folding, [cos (acos e) -> e], even symmetry. *)
and simplify_cos = function
  | Const 0.0 -> Const 1.0
  | Const c -> Const (cos c)
  | Acos e -> e
  | Neg e -> Cos e
  | e -> Cos e

(** [tan]: constant folding, [tan (atan e) -> e], odd symmetry. *)
and simplify_tan = function
  | Const 0.0 -> Const 0.0
  | Const c -> Const (tan c)
  | Atan e -> e
  | Neg e -> simplify_neg (Tan e)
  | e -> Tan e

(** [sinh]: constant folding and odd symmetry. *)
and simplify_sinh = function
  | Const 0.0 -> Const 0.0
  | Const c -> Const (sinh c)
  | Neg e -> simplify_neg (Sinh e)
  | e -> Sinh e

(** [cosh]: constant folding and even symmetry. *)
and simplify_cosh = function
  | Const 0.0 -> Const 1.0
  | Const c -> Const (cosh c)
  | Neg e -> Cosh e
  | e -> Cosh e

(** [tanh]: constant folding and odd symmetry. *)
and simplify_tanh = function
  | Const 0.0 -> Const 0.0
  | Const c -> Const (tanh c)
  | Neg e -> simplify_neg (Tanh e)
  | e -> Tanh e

(** [asin]: constant folding inside its real domain [-1, 1] only (outside
    it [asin] returns NaN, so the term stays symbolic).
    NOTE(review): [asin (sin e) -> e] only holds for e in [-pi/2, pi/2]. *)
and simplify_asin = function
  | Const c when abs_float c <= 1.0 -> Const (asin c)
  | Sin e -> e
  | e -> Asin e

(** [acos]: constant folding inside its real domain [-1, 1] only.
    NOTE(review): [acos (cos e) -> e] only holds for e in [0, pi]. *)
and simplify_acos = function
  | Const c when abs_float c <= 1.0 -> Const (acos c)
  | Cos e -> e
  | e -> Acos e

(** [atan]: constant folding.
    NOTE(review): [atan (tan e) -> e] only holds for e in (-pi/2, pi/2). *)
and simplify_atan = function
  | Const c -> Const (atan c)
  | Tan e -> e
  | e -> Atan e

(** [atan2]: constant folding only. *)
and simplify_atan2 e1 e2 =
  match (e1, e2) with
  | Const a, Const b -> Const (atan2 a b)
  | _ -> Atan2 (e1, e2)

(** [exp]: constant folding, [exp (ln e) -> e], and splitting sums and
    constant multiples of the exponent. *)
and simplify_exp = function
  | Const 0.0 -> Const 1.0
  | Const c -> Const (exp c)
  | Ln e -> e
  | Add (e1, e2) -> simplify_mul (Exp e1) (Exp e2)
  | Mul (Const c, e) -> simplify_pow (Exp e) (Const c)
  | e -> Exp e

(** [ln]: constant folding for positive constants, [ln (exp e) -> e],
    [ln E -> 1], and the product/quotient/power logarithm laws.
    NOTE(review): the product/quotient/power laws assume positive
    arguments. *)
and simplify_ln = function
  | Const 1.0 -> Const 0.0
  | Const c when c > 0.0 -> Const (log c)
  | Exp e -> e
  | SymConst E -> Const 1.0
  | Mul (e1, e2) -> simplify_add (Ln e1) (Ln e2)
  | Div (e1, e2) -> simplify_sub (Ln e1) (Ln e2)
  | Pow (e, Const c) -> simplify_mul (Const c) (Ln e)
  | e -> Ln e

(** [log base_e arg]: constant folding for a valid base and positive
    argument, [log_b b -> 1], and [log_b (b^c) -> c] for any base
    expression [b]. *)
and simplify_log base_e arg =
  match (base_e, arg) with
  | Const b, Const a when b > 0.0 && b <> 1.0 && a > 0.0 ->
      Const (log a /. log b)
  | b, a when Canonical.equal b a -> Const 1.0
  | b, Pow (e, c) when Canonical.equal b e -> c
  | _ -> Log (base_e, arg)

(** [sqrt]: constant folding for non-negative constants,
    [sqrt (e^2) -> |e|], and splitting products.
    NOTE(review): splitting a product assumes non-negative factors. *)
and simplify_sqrt = function
  | Const 0.0 -> Const 0.0
  | Const 1.0 -> Const 1.0
  | Const c when c >= 0.0 -> Const (sqrt c)
  | Pow (e, Const 2.0) -> simplify_abs e
  | Mul (e1, e2) -> simplify_mul (Sqrt e1) (Sqrt e2)
  | e -> Sqrt e

(** [abs]: constant folding, idempotence, [|-e| -> |e|], and pulling a
    constant coefficient out as its absolute value. *)
and simplify_abs = function
  | Const c -> Const (abs_float c)
  | Abs e -> Abs e
  | Neg e -> Abs e
  | Mul (Const c, e) -> simplify_mul (Const (abs_float c)) (Abs e)
  | e -> Abs e

(** [simplify expr] alternates [Canonical.canonicalize] and [simplify_once]
    until a pass leaves the canonical form unchanged (as judged by
    [Canonical.equal]), and returns that result. If [max_iterations] passes
    are exhausted first, the latest pass's output is returned as-is,
    without a final canonicalization. *)
let simplify expr =
  let rec fixed_point e count =
    if count >= max_iterations then e
    else
      let canonical = Canonical.canonicalize e in
      let simplified = simplify_once canonical in
      if Canonical.equal simplified canonical then simplified
      else fixed_point simplified (count + 1)
  in
  fixed_point expr 0