open Expr
open Canonical

(** Result of {!common_subexpression_elimination}.

    - [final_expr] is the input expression in which every extracted
      subexpression has been replaced by a [Var] carrying a generated name.
    - [subexpressions] pairs each generated name with the expression it
      stands for, listed in the order the names were created
      ([cse_1], [cse_2], ...). *)
type cse_result = {
  final_expr: expr;
  subexpressions: (string * expr) list;
}

(** [common_subexpression_elimination expr] replaces every non-atomic
    subexpression that is counted at least twice in [expr] by a fresh
    variable named [cse_N].

    Subexpressions are identified by the value of [hash], which is taken
    from one of the opened modules.  The expression stored for a name is the
    original subtree; it is not itself rewritten.

    NOTE(review): identification relies on [hash] alone, so two different
    expressions with the same hash would share one name.  Whether [hash] is
    collision-free (or intentionally canonical) cannot be told from this
    file -- confirm against its definition. *)
let common_subexpression_elimination expr =
  let counter = ref 0 in
  (* hash of an extracted subexpression -> generated variable name *)
  let name_map = Hashtbl.create 100 in
  (* (name, subexpression) pairs, most recently created first *)
  let bindings = ref [] in
  let get_or_create_name e =
    let h = hash e in
    match Hashtbl.find_opt name_map h with
    | Some name -> Var name
    | None ->
      incr counter;
      let name = Printf.sprintf "cse_%d" !counter in
      Hashtbl.add name_map h name;
      bindings := (name, e) :: !bindings;
      Var name
  in
  (* Atoms are never worth naming. *)
  let should_extract = function
    | Const _ | SymConst _ | Var _ -> false
    | _ -> true
  in
  (* hash -> number of times a node with that hash was visited *)
  let counts = Hashtbl.create 100 in
  let rec count_occurrences e =
    let h = hash e in
    let seen = Option.value ~default:0 (Hashtbl.find_opt counts h) in
    Hashtbl.replace counts h (seen + 1);
    match e with
    | Add (a, b) | Sub (a, b) | Mul (a, b) | Div (a, b) | Pow (a, b)
    | Atan2 (a, b) | Log (a, b) ->
      count_occurrences a;
      count_occurrences b
    | Neg a | Sin a | Cos a | Tan a | Sinh a | Cosh a | Tanh a
    | Asin a | Acos a | Atan a | Exp a | Ln a | Sqrt a | Abs a ->
      count_occurrences a
    | _ -> ()
  in
  count_occurrences expr;
  (* A missing entry is treated as zero occurrences instead of raising
     [Not_found]. *)
  let occurrences e =
    Option.value ~default:0 (Hashtbl.find_opt counts (hash e))
  in
  let rec extract e =
    if should_extract e && occurrences e >= 2 then get_or_create_name e
    else
      match e with
      | Add (a, b) -> Add (extract a, extract b)
      | Sub (a, b) -> Sub (extract a, extract b)
      | Mul (a, b) -> Mul (extract a, extract b)
      | Div (a, b) -> Div (extract a, extract b)
      | Pow (a, b) -> Pow (extract a, extract b)
      | Neg a -> Neg (extract a)
      | Sin a -> Sin (extract a)
      | Cos a -> Cos (extract a)
      | Tan a -> Tan (extract a)
      | Sinh a -> Sinh (extract a)
      | Cosh a -> Cosh (extract a)
      | Tanh a -> Tanh (extract a)
      | Asin a -> Asin (extract a)
      | Acos a -> Acos (extract a)
      | Atan a -> Atan (extract a)
      | Atan2 (a, b) -> Atan2 (extract a, extract b)
      | Exp a -> Exp (extract a)
      | Ln a -> Ln (extract a)
      | Log (a, b) -> Log (extract a, extract b)
      | Sqrt a -> Sqrt (extract a)
      | Abs a -> Abs (extract a)
      | other -> other
  in
  let final = extract expr in
  (* Names are numbered in creation order, so reversing the accumulator
     yields cse_1, cse_2, ..., cse_10, ... -- a plain string sort would put
     cse_10 before cse_2. *)
  { final_expr = final; subexpressions = List.rev !bindings }

(** [horner_form expr var] rewrites [expr], viewed as a polynomial in the
    variable named [var], into nested form
    [c0 + var * (c1 + var * (c2 + ...))] and passes the result through
    [Simplify.simplify].

    Recognised terms are constants, [var], [c * var], [var ^ n] and
    [c * var ^ n] where [n] is a non-negative integral constant, combined
    with [Add].  Anything else (including powers with negative or
    non-integral exponents) is kept as an opaque degree-0 term.

    The coefficients live in a dense array of size [max degree + 1], so a
    very large exponent allocates correspondingly. *)
let horner_form expr var =
  (* Only non-negative integral exponents can index the coefficient array. *)
  let is_degree n = Float.is_integer n && n >= 0.0 in
  let rec collect_poly e =
    match e with
    | Const c -> [(0, Const c)]
    | Var v when v = var -> [(1, Const 1.0)]
    | Mul (Const c, Var v) when v = var -> [(1, Const c)]
    | Pow (Var v, Const n) when v = var && is_degree n ->
      [(int_of_float n, Const 1.0)]
    | Mul (Const c, Pow (Var v, Const n)) when v = var && is_degree n ->
      [(int_of_float n, Const c)]
    | Add (e1, e2) -> collect_poly e1 @ collect_poly e2
    | _ -> [(0, e)]
  in
  let terms = collect_poly expr in
  let max_degree = List.fold_left (fun acc (deg, _) -> max acc deg) 0 terms in
  (* coeffs.(d) accumulates the sum of all coefficients of var^d *)
  let coeffs = Array.make (max_degree + 1) (Const 0.0) in
  List.iter
    (fun (deg, coeff) ->
      coeffs.(deg) <- Simplify.simplify (Add (coeffs.(deg), coeff)))
    terms;
  (* Nest from the lowest degree outward so that the highest-degree
     coefficient ends up innermost: c0 + x*(c1 + x*(... + x*cn)). *)
  let rec build_horner deg =
    if deg >= max_degree then coeffs.(max_degree)
    else Add (coeffs.(deg), Mul (Var var, build_horner (deg + 1)))
  in
  Simplify.simplify (build_horner 0)