leibniz/lib/canonical.ml
2026-01-19 03:37:26 +00:00

157 lines
4.9 KiB
OCaml

open Expr
(* Total structural order on expressions. [canonicalize] uses it to sort the
   operands of commutative nodes, and [equal]/[compare] apply it to canonical
   forms.

   - Two expressions with different head constructors are ordered by the
     head's position in [rank]: Const < SymConst < Var < Neg < Add < Sub < Mul
     < Div < Pow < Sin < ... < Sqrt < Abs.
   - Two expressions with the same head are ordered by their payload:
     [Float.compare] for [Const], polymorphic [Stdlib.compare] for [SymConst],
     [String.compare] for [Var], and a lexicographic comparison of the
     sub-expressions (left operand first) for every other constructor. *)
let rec compare_expr e1 e2 =
  (* Position of each head constructor in the ordering. *)
  let rank = function
    | Const _ -> 0
    | SymConst _ -> 1
    | Var _ -> 2
    | Neg _ -> 3
    | Add _ -> 4
    | Sub _ -> 5
    | Mul _ -> 6
    | Div _ -> 7
    | Pow _ -> 8
    | Sin _ -> 9
    | Cos _ -> 10
    | Tan _ -> 11
    | Sinh _ -> 12
    | Cosh _ -> 13
    | Tanh _ -> 14
    | Asin _ -> 15
    | Acos _ -> 16
    | Atan _ -> 17
    | Atan2 _ -> 18
    | Exp _ -> 19
    | Ln _ -> 20
    | Log _ -> 21
    | Sqrt _ -> 22
    | Abs _ -> 23
  in
  (* Lexicographic comparison: the first operands decide unless they tie. *)
  let compare_operands a1 a2 b1 b2 =
    match compare_expr a1 b1 with
    | 0 -> compare_expr a2 b2
    | c -> c
  in
  match (e1, e2) with
  | Const a, Const b -> Float.compare a b
  | SymConst a, SymConst b -> Stdlib.compare a b
  | Var a, Var b -> String.compare a b
  | ( (Neg a, Neg b)
    | (Sin a, Sin b)
    | (Cos a, Cos b)
    | (Tan a, Tan b)
    | (Sinh a, Sinh b)
    | (Cosh a, Cosh b)
    | (Tanh a, Tanh b)
    | (Asin a, Asin b)
    | (Acos a, Acos b)
    | (Atan a, Atan b)
    | (Exp a, Exp b)
    | (Ln a, Ln b)
    | (Sqrt a, Sqrt b)
    | (Abs a, Abs b) ) ->
      compare_expr a b
  | ( (Add (a1, a2), Add (b1, b2))
    | (Sub (a1, a2), Sub (b1, b2))
    | (Mul (a1, a2), Mul (b1, b2))
    | (Div (a1, a2), Div (b1, b2))
    | (Pow (a1, a2), Pow (b1, b2))
    | (Atan2 (a1, a2), Atan2 (b1, b2))
    | (Log (a1, a2), Log (b1, b2)) ) ->
      compare_operands a1 a2 b1 b2
  (* Every same-head pair is handled above, so the heads differ here. *)
  | _ -> Int.compare (rank e1) (rank e2)
(* Rewrite an expression into a normal form so that some algebraically
   equivalent spellings become structurally identical:
   - [Sub (a, b)] becomes [Add (a, Neg b)];
   - [Div (a, b)] becomes [Mul (a, Pow (b, Const (-1.0)))];
   - [Neg (Neg e)] collapses to [e];
   - the two operands of [Add] and [Mul] are put in ascending [compare_expr]
     order.
   All other nodes are rebuilt with canonicalized children, and leaves are
   returned as-is. The result never contains [Sub] or [Div].
   Associativity is not normalized: [Add (Add (a, b), c)] and
   [Add (a, Add (b, c))] keep different shapes. *)
let rec canonicalize expr =
  (* Canonicalize both operands of a commutative node, then rebuild it with
     the smaller operand (per [compare_expr]) on the left. *)
  let commutative make lhs rhs =
    let lhs = canonicalize lhs in
    let rhs = canonicalize rhs in
    if compare_expr lhs rhs > 0 then make rhs lhs else make lhs rhs
  in
  match expr with
  | Const _ | SymConst _ | Var _ -> expr
  | Neg (Neg inner) -> canonicalize inner
  | Neg inner -> Neg (canonicalize inner)
  | Sub (lhs, rhs) -> canonicalize (Add (lhs, Neg rhs))
  | Div (lhs, rhs) -> canonicalize (Mul (lhs, Pow (rhs, Const (-1.0))))
  | Add (lhs, rhs) -> commutative (fun x y -> Add (x, y)) lhs rhs
  | Mul (lhs, rhs) -> commutative (fun x y -> Mul (x, y)) lhs rhs
  | Pow (lhs, rhs) -> Pow (canonicalize lhs, canonicalize rhs)
  | Atan2 (lhs, rhs) -> Atan2 (canonicalize lhs, canonicalize rhs)
  | Log (lhs, rhs) -> Log (canonicalize lhs, canonicalize rhs)
  | Sin inner -> Sin (canonicalize inner)
  | Cos inner -> Cos (canonicalize inner)
  | Tan inner -> Tan (canonicalize inner)
  | Sinh inner -> Sinh (canonicalize inner)
  | Cosh inner -> Cosh (canonicalize inner)
  | Tanh inner -> Tanh (canonicalize inner)
  | Asin inner -> Asin (canonicalize inner)
  | Acos inner -> Acos (canonicalize inner)
  | Atan inner -> Atan (canonicalize inner)
  | Exp inner -> Exp (canonicalize inner)
  | Ln inner -> Ln (canonicalize inner)
  | Sqrt inner -> Sqrt (canonicalize inner)
  | Abs inner -> Abs (canonicalize inner)
(* [equal e1 e2] is [true] exactly when [e1] and [e2] reduce to the same
   canonical form, i.e. [compare_expr] reports their canonicalizations as
   equal. *)
let equal e1 e2 =
  match compare_expr (canonicalize e1) (canonicalize e2) with
  | 0 -> true
  | _ -> false
(* Total order on expressions modulo canonicalization. Both arguments are
   canonicalized and then ordered with [compare_expr]. The result is 0
   exactly when [equal e1 e2] holds. *)
let compare e1 e2 =
  let canon1 = canonicalize e1 and canon2 = canonicalize e2 in
  compare_expr canon1 canon2
(* Hash an expression consistently with [equal].

   [equal] compares canonical forms, so two expressions such as
   [Add (a, b)] and [Add (b, a)], or [Sub (a, b)] and [Add (a, Neg b)], are
   equal. Hashing the raw structure gave them different hashes, which breaks
   the [Hashtbl.HashedType] requirement that [equal x y] implies
   [hash x = hash y]. The expression is therefore canonicalized once, and the
   result is hashed structurally.

   [canonicalize] is idempotent, so an input that is already canonical hashes
   to the same value as before this change. Structurally equal inputs still
   get equal hashes. *)
let hash e =
  let rec structural = function
    | Const f -> Hashtbl.hash ("Const", f)
    | SymConst s -> Hashtbl.hash ("SymConst", s)
    | Var v -> Hashtbl.hash ("Var", v)
    | Add (e1, e2) -> Hashtbl.hash ("Add", structural e1, structural e2)
    (* [Sub] and [Div] never appear in canonical output. The arms are kept
       so the match stays exhaustive. *)
    | Sub (e1, e2) -> Hashtbl.hash ("Sub", structural e1, structural e2)
    | Mul (e1, e2) -> Hashtbl.hash ("Mul", structural e1, structural e2)
    | Div (e1, e2) -> Hashtbl.hash ("Div", structural e1, structural e2)
    | Pow (e1, e2) -> Hashtbl.hash ("Pow", structural e1, structural e2)
    | Neg e -> Hashtbl.hash ("Neg", structural e)
    | Sin e -> Hashtbl.hash ("Sin", structural e)
    | Cos e -> Hashtbl.hash ("Cos", structural e)
    | Tan e -> Hashtbl.hash ("Tan", structural e)
    | Sinh e -> Hashtbl.hash ("Sinh", structural e)
    | Cosh e -> Hashtbl.hash ("Cosh", structural e)
    | Tanh e -> Hashtbl.hash ("Tanh", structural e)
    | Asin e -> Hashtbl.hash ("Asin", structural e)
    | Acos e -> Hashtbl.hash ("Acos", structural e)
    | Atan e -> Hashtbl.hash ("Atan", structural e)
    | Atan2 (e1, e2) -> Hashtbl.hash ("Atan2", structural e1, structural e2)
    | Exp e -> Hashtbl.hash ("Exp", structural e)
    | Ln e -> Hashtbl.hash ("Ln", structural e)
    | Log (e1, e2) -> Hashtbl.hash ("Log", structural e1, structural e2)
    | Sqrt e -> Hashtbl.hash ("Sqrt", structural e)
    | Abs e -> Hashtbl.hash ("Abs", structural e)
  in
  (* Canonicalize only once at the top. Canonicalizing at every level would
     redo the work for each subtree. *)
  structural (canonicalize e)