leibniz/lib/simplify.ml
2026-01-19 03:37:26 +00:00

220 lines
6.7 KiB
OCaml

open Expr
(** Upper bound on the number of canonicalize-then-simplify rounds that
    [simplify] performs before it stops looking for a fixed point. *)
let max_iterations : int = 100
(** One bottom-up rewriting pass: every child is simplified first, then the
    node-specific rules are applied through the matching [simplify_*]
    helper. Leaves ([Const], [SymConst], [Var]) are returned unchanged. *)
let rec simplify_once expr =
  let go = simplify_once in
  match expr with
  | Const _ | SymConst _ | Var _ -> expr
  (* binary nodes *)
  | Add (l, r) -> simplify_add (go l) (go r)
  | Sub (l, r) -> simplify_sub (go l) (go r)
  | Mul (l, r) -> simplify_mul (go l) (go r)
  | Div (l, r) -> simplify_div (go l) (go r)
  | Pow (l, r) -> simplify_pow (go l) (go r)
  | Atan2 (l, r) -> simplify_atan2 (go l) (go r)
  | Log (l, r) -> simplify_log (go l) (go r)
  (* unary nodes *)
  | Neg x -> simplify_neg (go x)
  | Sin x -> simplify_sin (go x)
  | Cos x -> simplify_cos (go x)
  | Tan x -> simplify_tan (go x)
  | Sinh x -> simplify_sinh (go x)
  | Cosh x -> simplify_cosh (go x)
  | Tanh x -> simplify_tanh (go x)
  | Asin x -> simplify_asin (go x)
  | Acos x -> simplify_acos (go x)
  | Atan x -> simplify_atan (go x)
  | Exp x -> simplify_exp (go x)
  | Ln x -> simplify_ln (go x)
  | Sqrt x -> simplify_sqrt (go x)
  | Abs x -> simplify_abs (go x)
(* Addition: drop zero terms, fold constants, collect like terms
   (a*x + b*x, a*x + x, x + b*x, x + x) into a single coefficient times x,
   and merge a trailing constant into a constant addend. *)
and simplify_add e1 e2 =
  let same = Canonical.equal in
  match (e1, e2) with
  | Const 0.0, rest -> rest
  | rest, Const 0.0 -> rest
  | Const p, Const q -> Const (p +. q)
  | Mul (Const p, x), Mul (Const q, y) when same x y ->
      simplify_mul (Const (p +. q)) x
  | Mul (Const p, x), y when same x y -> simplify_mul (Const (p +. 1.0)) x
  | x, Mul (Const q, y) when same x y -> simplify_mul (Const (1.0 +. q)) x
  | x, y when same x y -> simplify_mul (Const 2.0) x
  (* (x + p) + q = x + (p + q) *)
  | Add (x, Const p), Const q -> simplify_add x (Const (p +. q))
  | _ -> Add (e1, e2)
(* Subtraction: drop a zero subtrahend, fold constants, rewrite 0 - x as a
   negation, and cancel x - x to 0. *)
and simplify_sub e1 e2 =
  match (e1, e2) with
  | e, Const 0.0 -> e
  | Const a, Const b -> Const (a -. b)
  (* 0 - x = -x, mirroring the zero-on-either-side rule of simplify_add. *)
  | Const 0.0, e -> simplify_neg e
  | e1, e2 when Canonical.equal e1 e2 -> Const 0.0
  | _ -> Sub (e1, e2)
(* Multiplication: annihilate by 0, drop factors of 1, turn -1 into a
   negation, fold and gather numeric coefficients, merge equal bases into
   powers (x * x = x^2, x^a * x^b = x^(a+b)), combine exp a * exp b into
   exp (a + b), and move a bare variable in front of a function
   application or power. *)
and simplify_mul e1 e2 =
  let same = Canonical.equal in
  let is_applied = function
    | Sin _ | Cos _ | Tan _ | Sinh _ | Cosh _ | Tanh _ | Asin _ | Acos _
    | Atan _ | Exp _ | Ln _ | Log _ | Sqrt _ | Abs _ | Pow _ ->
        true
    | _ -> false
  in
  match (e1, e2) with
  | Const 0.0, _ -> Const 0.0
  | _, Const 0.0 -> Const 0.0
  | Const 1.0, rest -> rest
  | rest, Const 1.0 -> rest
  | Const (-1.0), rest | rest, Const (-1.0) -> simplify_neg rest
  | Const p, Const q -> Const (p *. q)
  (* gather coefficients: p * (q * x) and (p * x) * q *)
  | Const p, Mul (Const q, rest) | Mul (Const p, rest), Const q ->
      simplify_mul (Const (p *. q)) rest
  (* p * (x * (q * y)) and (p * x) * (q * y) *)
  | Const p, Mul (x, Mul (Const q, y))
  | Mul (Const p, x), Mul (Const q, y) ->
      simplify_mul (Const (p *. q)) (Mul (x, y))
  (* same base: add the exponents *)
  | Pow (b1, p), Pow (b2, q) when same b1 b2 ->
      simplify_pow b1 (simplify_add p q)
  | base, Pow (b2, q) when same base b2 ->
      simplify_pow base (simplify_add (Const 1.0) q)
  | Pow (b1, p), base when same b1 base ->
      simplify_pow b1 (simplify_add p (Const 1.0))
  | x, y when same x y -> simplify_pow x (Const 2.0)
  | Exp a, Exp b -> Exp (simplify_add a b)
  | lhs, Var _ when is_applied lhs -> Mul (e2, e1)
  | _ -> Mul (e1, e2)
(* Division: keep division by a literal zero symbolic, simplify 0 / x to 0
   and x / 1 to x, fold constants, cancel x / x to 1 (this assumes x is
   nonzero), and cancel a factor shared by a product numerator and the
   divisor. *)
and simplify_div e1 e2 =
  match (e1, e2) with
  (* A literal zero divisor is undefined: do not fold it to inf/nan, and do
     not let the 0 / _ rule below turn 0 / 0 into 0. This matches how ln,
     log and sqrt leave out-of-domain constants unevaluated. *)
  | _, Const 0.0 -> Div (e1, e2)
  | Const 0.0, _ -> Const 0.0
  | e, Const 1.0 -> e
  | Const a, Const b -> Const (a /. b)
  | e1, e2 when Canonical.equal e1 e2 -> Const 1.0
  (* (a * b) / b = a and (a * b) / a = b *)
  | Mul (e1, e2), e3 when Canonical.equal e2 e3 -> e1
  | Mul (e1, e2), e3 when Canonical.equal e1 e3 -> e2
  | _ -> Div (e1, e2)
(* Exponentiation: x^0 = 1, x^1 = x, 0^b = 0 for constant b > 0, 1^b = 1,
   fold constant powers with finite results, flatten (x^a)^n for integer n,
   sqrt(x)^2 = x, x^0.5 = sqrt x, and e^(ln x) = x. *)
and simplify_pow e1 e2 =
  match (e1, e2) with
  | _, Const 0.0 -> Const 1.0
  | e, Const 1.0 -> e
  (* 0^b = 0 only for b > 0: 0^0 is 1 (handled above) and 0^b diverges for
     b < 0, so a non-positive or symbolic exponent stays symbolic. *)
  | Const 0.0, Const b when b > 0.0 -> Const 0.0
  | Const 1.0, _ -> Const 1.0
  | Const a, Const b ->
      (* Fold only real, finite results; e.g. (-8)^(1/3) and 0^(-1) stay
         unevaluated. This is the same policy ln, log and sqrt apply to
         out-of-domain constants. *)
      let r = a ** b in
      (match classify_float r with
       | FP_infinite | FP_nan -> Pow (e1, e2)
       | FP_normal | FP_subnormal | FP_zero -> Const r)
  (* (x^a)^b = x^(a*b) is an identity only for integer b: ((-1)^2)^0.5 = 1
     but (-1)^1 = -1. A non-integer b falls through, so (x^2)^0.5 reaches
     the x^0.5 rule below and becomes |x|, agreeing with sqrt (x^2). *)
  | Pow (e, a), Const b when floor b = b ->
      simplify_pow e (simplify_mul a (Const b))
  | Sqrt e, Const 2.0 -> e
  | e, Const 0.5 -> simplify_sqrt e
  | SymConst E, Ln e -> e
  | _ -> Pow (e1, e2)
(* Negation: fold constants, cancel double negation, and push the sign into
   a constant coefficient of a product. *)
and simplify_neg expr =
  match expr with
  | Const v -> Const (-.v)
  | Neg inner -> inner
  | Mul (Const v, inner) | Mul (inner, Const v) ->
      simplify_mul (Const (-.v)) inner
  | _ -> Neg expr
(* sin: fold constants, cancel sin (asin x) = x, and pull a negation out
   (sin is odd). *)
and simplify_sin arg =
  match arg with
  | Const 0.0 -> Const 0.0
  | Const v -> Const (sin v)
  | Asin inner -> inner
  | Neg inner -> simplify_neg (Sin inner)
  | _ -> Sin arg
(* cos: fold constants, cancel cos (acos x) = x, and drop a negation
   (cos is even). *)
and simplify_cos arg =
  match arg with
  | Const 0.0 -> Const 1.0
  | Const v -> Const (cos v)
  | Acos inner -> inner
  | Neg inner -> Cos inner
  | _ -> Cos arg
(* tan: fold constants, cancel tan (atan x) = x, and pull a negation out
   (tan is odd). *)
and simplify_tan arg =
  match arg with
  | Const 0.0 -> Const 0.0
  | Const v -> Const (tan v)
  | Atan inner -> inner
  | Neg inner -> simplify_neg (Tan inner)
  | _ -> Tan arg
(* sinh: fold constants and pull a negation out (sinh is odd). *)
and simplify_sinh arg =
  match arg with
  | Const 0.0 -> Const 0.0
  | Const v -> Const (sinh v)
  | Neg inner -> simplify_neg (Sinh inner)
  | _ -> Sinh arg
(* cosh: fold constants and drop a negation (cosh is even). *)
and simplify_cosh arg =
  match arg with
  | Const 0.0 -> Const 1.0
  | Const v -> Const (cosh v)
  | Neg inner -> Cosh inner
  | _ -> Cosh arg
(* tanh: fold constants and pull a negation out (tanh is odd). *)
and simplify_tanh arg =
  match arg with
  | Const 0.0 -> Const 0.0
  | Const v -> Const (tanh v)
  | Neg inner -> simplify_neg (Tanh inner)
  | _ -> Tanh arg
(* asin: fold constants inside its real domain [-1, 1] (outside it the fold
   would produce nan) and pull a negation out (asin is odd).
   asin (sin x) is deliberately NOT cancelled: it equals x only for x in
   [-pi/2, pi/2] (asin (sin pi) = 0), so cancelling it would contradict
   what constant folding yields for the same input. *)
and simplify_asin = function
  | Const c when abs_float c <= 1.0 -> Const (asin c)
  | Neg e -> simplify_neg (Asin e)
  | e -> Asin e
(* acos: fold constants inside its real domain [-1, 1] (outside it the fold
   would produce nan).
   acos (cos x) is deliberately NOT cancelled: it equals x only for x in
   [0, pi] (acos (cos (-1)) = 1), so cancelling it would contradict what
   constant folding yields for the same input. *)
and simplify_acos = function
  | Const c when abs_float c <= 1.0 -> Const (acos c)
  | e -> Acos e
(* atan: fold constants and pull a negation out (atan is odd).
   atan (tan x) is deliberately NOT cancelled: it equals x only for x in
   (-pi/2, pi/2) (atan (tan pi) is ~0, not pi), so cancelling it would
   contradict what constant folding yields for the same input. *)
and simplify_atan = function
  | Const c -> Const (atan c)
  | Neg e -> simplify_neg (Atan e)
  | e -> Atan e
(* atan2 y x: fold when both arguments are constants, otherwise rebuild the
   node from the (already simplified) arguments. *)
and simplify_atan2 e1 e2 =
  match e1, e2 with
  | Const y, Const x -> Const (atan2 y x)
  | _, _ -> Atan2 (e1, e2)
(* exp: fold constants, cancel exp (ln x) = x, split exp (a + b) into
   exp a * exp b, and rewrite exp (c * x) as (exp x)^c. The split hands raw
   [Exp] factors to simplify_mul, whose [Exp a, Exp b] rule typically
   recombines them into a single [Exp] of the sum. *)
and simplify_exp arg =
  match arg with
  | Const 0.0 -> Const 1.0
  | Const v -> Const (exp v)
  | Ln inner -> inner
  | Add (a, b) -> simplify_mul (Exp a) (Exp b)
  | Mul (Const k, inner) -> simplify_pow (Exp inner) (Const k)
  | _ -> Exp arg
(* ln: fold positive constants, evaluate ln e = 1, cancel ln (exp x) = x,
   and expand products, quotients and constant powers into sums,
   differences and multiples. The expansions build raw [Ln] children; the
   fixed-point loop in [simplify] simplifies them on a later pass.
   NOTE(review): ln (a*b) = ln a + ln b, ln (a/b) = ln a - ln b and
   ln (x^c) = c * ln x only hold for positive operands (in general
   ln (x^2) = 2 * ln |x|). Confirm that positivity is an intended
   assumption of this simplifier. *)
and simplify_ln arg =
  match arg with
  | Const 1.0 -> Const 0.0
  | Const v when v > 0.0 -> Const (log v)
  | SymConst E -> Const 1.0
  | Exp inner -> inner
  | Mul (l, r) -> simplify_add (Ln l) (Ln r)
  | Div (num, den) -> simplify_sub (Ln num) (Ln den)
  | Pow (base, Const k) -> simplify_mul (Const k) (Ln base)
  | _ -> Ln arg
(* log_base arg:
   - fold when both are constants in the real domain
     (arg > 0, base > 0, base <> 1);
   - return 1 when base and argument are equal;
   - cancel log_b (b^c) = c for a constant base b. *)
and simplify_log base_e arg =
  match (base_e, arg) with
  | Const b, Const a when a > 0.0 && b > 0.0 && b <> 1.0 ->
      Const (log a /. log b)
  | _ when Canonical.equal base_e arg -> Const 1.0
  | Const b, Pow (inner, exponent) when Canonical.equal (Const b) inner ->
      exponent
  | _ -> Log (base_e, arg)
(* sqrt: fold non-negative constants, rewrite sqrt (x^2) as |x|, and split a
   product into a product of square roots. Everything else is rebuilt.
   NOTE(review): sqrt (a*b) = sqrt a * sqrt b only holds when a, b >= 0
   (for negative a and b the left side is real but the right side is not).
   Confirm that nonnegative operands are an intended assumption. *)
and simplify_sqrt radicand =
  match radicand with
  | Const 0.0 -> Const 0.0
  | Const 1.0 -> Const 1.0
  | Const v when v >= 0.0 -> Const (sqrt v)
  | Pow (base, Const 2.0) -> simplify_abs base
  | Mul (l, r) -> simplify_mul (Sqrt l) (Sqrt r)
  | _ -> Sqrt radicand
(* |.|: fold constants, collapse nested absolute values and negations
   (||x|| = |x|, |-x| = |x|), drop the bars around expressions that are
   never negative, and pull a constant factor out as its magnitude. *)
and simplify_abs = function
  | Const c -> Const (abs_float c)
  | Abs e -> Abs e
  | Neg e -> Abs e
  (* exp x > 0 and sqrt x >= 0 wherever they are defined. *)
  | (Exp _ | Sqrt _) as e -> e
  | Mul (Const c, e) -> simplify_mul (Const (abs_float c)) (Abs e)
  | e -> Abs e
(** [simplify expr] rewrites [expr] towards a fixed point.

    Each round canonicalizes the current expression and applies one
    [simplify_once] pass. Iteration stops when a pass leaves the canonical
    form unchanged (per [Canonical.equal]), or after [max_iterations]
    rounds. In the latter case the most recent pass result is returned
    without being re-canonicalized. *)
let simplify expr =
  let rec iterate remaining current =
    if remaining <= 0 then current
    else
      let canonical = Canonical.canonicalize current in
      let step = simplify_once canonical in
      if Canonical.equal step canonical then step
      else iterate (remaining - 1) step
  in
  iterate max_iterations expr