(* leibniz/lib/cse.ml
   2026-01-19 03:37:26 +00:00
   110 lines, 3.5 KiB, OCaml *)

open Expr
open Canonical
(** Output of {!common_subexpression_elimination}. *)
type cse_result = {
  final_expr : expr;
      (** The input expression with each shared subexpression replaced by a
          [Var] that names it. *)
  subexpressions : (string * expr) list;
      (** One [(name, subexpression)] pair per variable introduced in
          [final_expr]. *)
}
(** [common_subexpression_elimination expr] replaces every non-leaf
    subexpression of [expr] that occurs at least twice with a fresh variable
    [Var "cse_N"], and returns the rewritten expression together with the
    introduced bindings.

    Subexpressions are identified by their [hash] value (brought into scope by
    the [open]s above): all subexpressions with the same [hash] share one name,
    and the stored binding is the first one reached. NOTE(review): this is only
    sound if [hash] never gives the same key to two non-equivalent
    expressions; confirm against its definition.

    Extraction is top-down. Once a repeated subexpression is replaced by a
    variable, its interior is not rewritten, so a stored binding never refers
    to another [cse_*] name. Names are numbered in left-to-right, outermost-first
    order of first occurrence, and [subexpressions] is returned in that same
    numeric order (cse_1, cse_2, ..., cse_10, ...). *)
let common_subexpression_elimination expr =
  let counter = ref 0 in
  (* [hash] key -> name already given to that subexpression. *)
  let name_map = Hashtbl.create 100 in
  (* Introduced bindings, most recent first; reversed once at the end. This
     keeps numeric order, whereas sorting names as strings put cse_10 before
     cse_2. *)
  let bindings = ref [] in
  let get_or_create_name e =
    let h = hash e in
    match Hashtbl.find_opt name_map h with
    | Some name -> Var name
    | None ->
      incr counter;
      let name = Printf.sprintf "cse_%d" !counter in
      Hashtbl.replace name_map h name;
      bindings := (name, e) :: !bindings;
      Var name
  in
  (* Leaves are never worth naming. *)
  let should_extract = function
    | Const _ | SymConst _ | Var _ -> false
    | _ -> true
  in
  (* Occurrence count of every node of [expr], leaves included, keyed by
     [hash]. A node nested inside a repeated subtree is counted once per copy. *)
  let counts = Hashtbl.create 100 in
  let rec count_occurrences e =
    let h = hash e in
    let seen = Option.value ~default:0 (Hashtbl.find_opt counts h) in
    Hashtbl.replace counts h (seen + 1);
    match e with
    | Add (a, b) | Sub (a, b) | Mul (a, b) | Div (a, b) | Pow (a, b)
    | Atan2 (a, b) | Log (a, b) ->
      count_occurrences a;
      count_occurrences b
    | Neg x | Sin x | Cos x | Tan x | Sinh x | Cosh x | Tanh x
    | Asin x | Acos x | Atan x | Exp x | Ln x | Sqrt x | Abs x ->
      count_occurrences x
    | _ -> ()
  in
  count_occurrences expr;
  let rec extract e =
    (* [Hashtbl.find] cannot raise: [extract] visits only nodes that
       [count_occurrences] has already recorded. *)
    if should_extract e && Hashtbl.find counts (hash e) >= 2 then
      get_or_create_name e
    else
      match e with
      | Add (a, b) -> let a, b = both a b in Add (a, b)
      | Sub (a, b) -> let a, b = both a b in Sub (a, b)
      | Mul (a, b) -> let a, b = both a b in Mul (a, b)
      | Div (a, b) -> let a, b = both a b in Div (a, b)
      | Pow (a, b) -> let a, b = both a b in Pow (a, b)
      | Atan2 (a, b) -> let a, b = both a b in Atan2 (a, b)
      | Log (a, b) -> let a, b = both a b in Log (a, b)
      | Neg x -> Neg (extract x)
      | Sin x -> Sin (extract x)
      | Cos x -> Cos (extract x)
      | Tan x -> Tan (extract x)
      | Sinh x -> Sinh (extract x)
      | Cosh x -> Cosh (extract x)
      | Tanh x -> Tanh (extract x)
      | Asin x -> Asin (extract x)
      | Acos x -> Acos (extract x)
      | Atan x -> Atan (extract x)
      | Exp x -> Exp (extract x)
      | Ln x -> Ln (extract x)
      | Sqrt x -> Sqrt (extract x)
      | Abs x -> Abs (extract x)
      | e -> e
  (* Rewrite two children strictly left to right. OCaml leaves the evaluation
     order of constructor arguments unspecified, and that order decides which
     repeated subexpression receives which cse_N number. *)
  and both a b =
    let a = extract a in
    let b = extract b in
    (a, b)
  in
  let final_expr = extract expr in
  { final_expr; subexpressions = List.rev !bindings }
(** [horner_form expr var] reads [expr] as a polynomial in the variable named
    [var] and rebuilds it in Horner form
    [c0 + var * (c1 + var * (c2 + ... + var * cn))]. Each coefficient sum and
    the final result are passed through [Simplify.simplify].

    [expr] is split into terms along [Add] nodes. A term counts as a monomial
    of degree [d] when it is one of:
    - [var];
    - [var ^ d], where [d] is a non-negative integer-valued float;
    - either of those multiplied, on either side, by a [Const].

    Every other term is added to the degree-0 coefficient. This includes terms
    that still mention [var], such as [sin var], [var ^ -1] or [var ^ 0.5].
    The rebuilt expression therefore equals [expr], as long as
    [Simplify.simplify] preserves meaning. *)
let horner_form expr var =
  (* [Some d] when [e] is [var] or [var ^ d] with [d] a non-negative integer.
     Negative exponents are rejected: they used to index [coeffs] below 0. *)
  let monomial_degree = function
    | Var v when v = var -> Some 1
    | Pow (Var v, Const n) when v = var && Float.is_integer n && n >= 0.0 ->
      Some (int_of_float n)
    | _ -> None
  in
  (* Prepend the (degree, coefficient) terms of [e] onto [acc], keeping the
     left-to-right order of the [Add] chain without quadratic [@]. *)
  let rec collect e acc =
    match e with
    | Add (e1, e2) -> collect e1 (collect e2 acc)
    | Mul (Const c, m) | Mul (m, Const c) -> (
      match monomial_degree m with
      | Some d -> (d, Const c) :: acc
      | None -> (0, e) :: acc)
    | _ -> (
      match monomial_degree e with
      | Some d -> (d, Const 1.0) :: acc
      | None -> (0, e) :: acc)
  in
  let terms = collect expr [] in
  let max_degree = List.fold_left (fun acc (d, _) -> max acc d) 0 terms in
  let coeffs = Array.make (max_degree + 1) (Const 0.0) in
  List.iter
    (fun (d, c) -> coeffs.(d) <- Simplify.simplify (Add (coeffs.(d), c)))
    terms;
  (* Horner nesting: the leading coefficient is innermost and c0 outermost.
     The previous version nested them the other way round, reversing the
     polynomial's coefficients. *)
  let rec build d =
    if d >= max_degree then coeffs.(max_degree)
    else Add (coeffs.(d), Mul (Var var, build (d + 1)))
  in
  Simplify.simplify (build 0)