110 lines
3.5 KiB
OCaml
110 lines
3.5 KiB
OCaml
open Expr
|
|
open Canonical
|
|
|
|
(** Result of {!common_subexpression_elimination}. *)
type cse_result = {
  final_expr : expr;
      (** The input expression, with repeated subexpressions replaced by
          [Var] references to their generated names. *)
  subexpressions : (string * expr) list;
      (** Each generated temporary name paired with the subexpression it
          stands for. *)
}
|
|
|
|
(** [common_subexpression_elimination expr] factors repeated subexpressions
    of [expr] out into named temporaries [cse_1], [cse_2], ...

    Occurrences are counted over every node of [expr], nested ones included.
    The tree is then rewritten top-down: the first non-leaf node (anything
    other than [Const], [SymConst] or [Var]) met on a path whose [hash] was
    seen at least twice is replaced by [Var "cse_<k>"]. Its interior is not
    visited further. Nodes that share a [hash] share a name.

    @return a record whose [final_expr] is the rewritten expression and whose
    [subexpressions] lists every [(name, e)] binding in the order the names
    were created (i.e. by increasing [k]). Each bound [e] is the original,
    un-rewritten subexpression.

    NOTE(review): subexpressions are identified solely by [hash e]. Two
    distinct expressions whose hashes collide would be mapped to the same
    temporary. Confirm that [hash] is collision-free (or canonical by
    design) for the intended inputs. *)
let common_subexpression_elimination expr =
  (* Direct children of a node. Must cover the same constructors as
     [map_children] below, so that every node [extract] inspects has been
     counted. *)
  let children = function
    | Add (a, b) | Sub (a, b) | Mul (a, b) | Div (a, b) | Pow (a, b)
    | Atan2 (a, b) | Log (a, b) ->
        [ a; b ]
    | Neg a | Sin a | Cos a | Tan a | Sinh a | Cosh a | Tanh a | Asin a
    | Acos a | Atan a | Exp a | Ln a | Sqrt a | Abs a ->
        [ a ]
    | _ -> []
  in
  (* Rebuild a node with [f] applied to each direct child. Leaves and
     unlisted constructors are returned unchanged. *)
  let map_children f = function
    | Add (a, b) -> Add (f a, f b)
    | Sub (a, b) -> Sub (f a, f b)
    | Mul (a, b) -> Mul (f a, f b)
    | Div (a, b) -> Div (f a, f b)
    | Pow (a, b) -> Pow (f a, f b)
    | Atan2 (a, b) -> Atan2 (f a, f b)
    | Log (a, b) -> Log (f a, f b)
    | Neg a -> Neg (f a)
    | Sin a -> Sin (f a)
    | Cos a -> Cos (f a)
    | Tan a -> Tan (f a)
    | Sinh a -> Sinh (f a)
    | Cosh a -> Cosh (f a)
    | Tanh a -> Tanh (f a)
    | Asin a -> Asin (f a)
    | Acos a -> Acos (f a)
    | Atan a -> Atan (f a)
    | Exp a -> Exp (f a)
    | Ln a -> Ln (f a)
    | Sqrt a -> Sqrt (f a)
    | Abs a -> Abs (f a)
    | leaf -> leaf
  in
  (* Leaves are never worth naming. *)
  let is_leaf = function Const _ | SymConst _ | Var _ -> true | _ -> false in
  (* Pass 1: number of occurrences of every hash, nested nodes included. *)
  let counts = Hashtbl.create 100 in
  let rec count e =
    let h = hash e in
    let seen = Option.value ~default:0 (Hashtbl.find_opt counts h) in
    Hashtbl.replace counts h (seen + 1);
    List.iter count (children e)
  in
  count expr;
  (* Pass 2: replace repeated nodes by temporaries. Bindings are kept in a
     list (newest first) so they can be returned in creation order; sorting
     the "cse_<k>" strings would put cse_10 before cse_2. *)
  let counter = ref 0 in
  let name_of_hash = Hashtbl.create 100 in
  let bindings = ref [] in
  let name_for h e =
    match Hashtbl.find_opt name_of_hash h with
    | Some name -> Var name
    | None ->
        incr counter;
        let name = Printf.sprintf "cse_%d" !counter in
        Hashtbl.replace name_of_hash h name;
        bindings := (name, e) :: !bindings;
        Var name
  in
  let rec extract e =
    if is_leaf e then e
    else
      let h = hash e in
      (* [find] cannot raise: pass 1 visited every node reachable here. *)
      if Hashtbl.find counts h >= 2 then name_for h e
      else map_children extract e
  in
  (* Bind first: record fields are evaluated in unspecified order, and
     [bindings] is only complete once [extract] has run. *)
  let final_expr = extract expr in
  { final_expr; subexpressions = List.rev !bindings }
|
|
|
|
(** [horner_form expr var] rewrites [expr], viewed as a polynomial in the
    variable named [var], into Horner form
    [a0 + var * (a1 + var * (a2 + ... + var * an))].

    [expr] is flattened along nested [Add] nodes. Each summand is classified
    as follows:
    - [Var var] has degree 1, coefficient 1.
    - [Mul (Const c, Var var)] has degree 1, coefficient [c].
    - [Pow (Var var, Const n)] has degree [n], coefficient 1.
    - [Mul (Const c, Pow (Var var, Const n))] has degree [n], coefficient [c].

    The two [Pow] forms apply only when [n] is a non-negative integral float.
    Any other summand, constants included, is kept verbatim as part of the
    degree-0 coefficient. The result is therefore always equal to [expr]; it
    is just less factored when the input is not a recognised polynomial.

    Coefficients of equal degree are summed (via [Simplify.simplify]), and
    the final Horner expression is passed through [Simplify.simplify]. *)
let horner_form expr var =
  (* A [Pow] exponent usable as an array index: integral and >= 0.
     Negative powers previously produced a negative index and raised. *)
  let is_degree n = Float.is_integer n && n >= 0.0 in
  (* Flatten the [Add] tree into (degree, coefficient) pairs, prepended to
     [acc] while preserving left-to-right summand order. Using an
     accumulator avoids the quadratic cost of repeated [@]. *)
  let rec collect e acc =
    match e with
    | Add (e1, e2) -> collect e1 (collect e2 acc)
    | Var v when v = var -> (1, Const 1.0) :: acc
    | Mul (Const c, Var v) when v = var -> (1, Const c) :: acc
    | Pow (Var v, Const n) when v = var && is_degree n ->
        (int_of_float n, Const 1.0) :: acc
    | Mul (Const c, Pow (Var v, Const n)) when v = var && is_degree n ->
        (int_of_float n, Const c) :: acc
    | other -> (0, other) :: acc
  in
  let terms = collect expr [] in
  let max_degree = List.fold_left (fun acc (deg, _) -> max acc deg) 0 terms in
  let coeffs = Array.make (max_degree + 1) (Const 0.0) in
  List.iter
    (fun (deg, coeff) ->
      coeffs.(deg) <- Simplify.simplify (Add (coeffs.(deg), coeff)))
    terms;
  (* Horner's scheme: start from the leading coefficient and fold downward
     with h <- a_k + var * h, so that a_0 ends up outermost. The previous
     version nested the other way round and reversed the coefficients. *)
  let rec fold k acc =
    if k < 0 then acc else fold (k - 1) (Add (coeffs.(k), Mul (Var var, acc)))
  in
  Simplify.simplify (fold (max_degree - 1) coeffs.(max_degree))
|