leibniz/lib/substitute.ml
2026-01-19 03:37:26 +00:00

194 lines
6.5 KiB
OCaml

open Expr
(** [substitute var replacement expr] returns [expr] with every occurrence
    of [Var var] replaced by [replacement].

    Constants (numeric and symbolic) are returned unchanged, and variables
    other than [var] are left as they are. None of the constructors handled
    here bind variables, so the traversal is a plain structural map. *)
let rec substitute var replacement expr =
  (* Partially applied recursion; [var] and [replacement] never change. *)
  let sub = substitute var replacement in
  match expr with
  | Const _ | SymConst _ -> expr
  | Var v -> if v = var then replacement else Var v
  (* Binary constructors *)
  | Add (lhs, rhs) -> Add (sub lhs, sub rhs)
  | Sub (lhs, rhs) -> Sub (sub lhs, sub rhs)
  | Mul (lhs, rhs) -> Mul (sub lhs, sub rhs)
  | Div (lhs, rhs) -> Div (sub lhs, sub rhs)
  | Pow (lhs, rhs) -> Pow (sub lhs, sub rhs)
  | Atan2 (lhs, rhs) -> Atan2 (sub lhs, sub rhs)
  | Log (lhs, rhs) -> Log (sub lhs, sub rhs)
  (* Unary constructors *)
  | Neg arg -> Neg (sub arg)
  | Sin arg -> Sin (sub arg)
  | Cos arg -> Cos (sub arg)
  | Tan arg -> Tan (sub arg)
  | Sinh arg -> Sinh (sub arg)
  | Cosh arg -> Cosh (sub arg)
  | Tanh arg -> Tanh (sub arg)
  | Asin arg -> Asin (sub arg)
  | Acos arg -> Acos (sub arg)
  | Atan arg -> Atan (sub arg)
  | Exp arg -> Exp (sub arg)
  | Ln arg -> Ln (sub arg)
  | Sqrt arg -> Sqrt (sub arg)
  | Abs arg -> Abs (sub arg)
(** [substitute_many subs expr] applies each [(var, replacement)] pair of
    [subs] to [expr] with {!substitute}, one after another, in list order.

    The substitutions are sequential, not simultaneous: a later pair also
    rewrites variables that an earlier replacement introduced. *)
let substitute_many subs expr =
  let rec apply current = function
    | [] -> current
    | (var, repl) :: remaining -> apply (substitute var repl current) remaining
  in
  apply expr subs
(** Patterns over [expr] values. They are used by [matches] to recognise
    expression shapes, and as templates by [instantiate] to build new
    expressions from the bindings a match produced. *)
type pattern =
| PVar of string
    (** Matches any expression and binds it to the given name. As a
        template, it is replaced by the expression bound to that name. *)
| PWild
    (** Matches any expression without binding it. It cannot be used in a
        template: [instantiate] returns [None] for it. *)
| PConst of float
    (** Matches a [Const] whose value is equal (by [=]) to this float. *)
| PSymConst of sym_const
    (** Matches a [SymConst] equal (by [=]) to this symbolic constant. *)
| POp of string * pattern list
    (** [POp (name, args)] matches the expression constructor whose name is
        exactly [name] (e.g. ["Add"], ["Sin"], ["Atan2"]). The sub-patterns
        [args] are matched positionally against its arguments. An unknown
        name or a wrong number of arguments never matches and cannot be
        instantiated. *)
(** Association list from pattern-variable names to the expressions they
    matched. [instantiate] looks names up with [List.assoc_opt]. *)
type bindings = (string * expr) list
(* Combine the bindings produced by two sub-matches.

   A variable bound on both sides must be bound to structurally equal
   expressions, otherwise the combined match fails. Without this check a
   non-linear pattern such as [POp ("Sub", [PVar "x"; PVar "x"])] would
   match [Sub (a, b)] for any [a] and [b], and [instantiate] would silently
   use only the first binding.

   The result lists [b1]'s bindings first, then the variables that only [b2]
   binds. A variable that both sides bind to equal expressions is kept once. *)
let merge_bindings (b1 : bindings) (b2 : bindings) : bindings option =
  let rec go acc = function
    | [] -> Some (List.rev acc)
    | ((v, e) as binding) :: rest -> (
        match List.assoc_opt v acc with
        | None -> go (binding :: acc) rest
        | Some bound -> if bound = e then go acc rest else None)
  in
  go (List.rev b1) b2

(** [matches pat e] tries to match the expression [e] against [pat].

    Returns [Some bindings], mapping each pattern variable to the
    sub-expression it matched ([Some []] when [pat] has no variables).
    Returns [None] when [e] does not have the shape described by [pat].

    - [PWild] matches anything and binds nothing.
    - [PVar v] matches anything and binds it to [v]. A variable that occurs
      several times in [pat] must match structurally equal sub-expressions
      at every occurrence.
    - [PConst] and [PSymConst] match equal constants only.
    - [POp (name, args)] matches the constructor named [name] with the same
      arity. Any other name/arity combination fails. *)
let rec matches (pat : pattern) (e : expr) : bindings option =
  match (pat, e) with
  | PWild, _ -> Some []
  | PVar v, e -> Some [ (v, e) ]
  | PConst c1, Const c2 when c1 = c2 -> Some []
  | PSymConst s1, SymConst s2 when s1 = s2 -> Some []
  | POp ("Add", [ p1; p2 ]), Add (e1, e2) -> matches_binary p1 p2 e1 e2
  | POp ("Sub", [ p1; p2 ]), Sub (e1, e2) -> matches_binary p1 p2 e1 e2
  | POp ("Mul", [ p1; p2 ]), Mul (e1, e2) -> matches_binary p1 p2 e1 e2
  | POp ("Div", [ p1; p2 ]), Div (e1, e2) -> matches_binary p1 p2 e1 e2
  | POp ("Pow", [ p1; p2 ]), Pow (e1, e2) -> matches_binary p1 p2 e1 e2
  | POp ("Neg", [ p ]), Neg e -> matches p e
  | POp ("Sin", [ p ]), Sin e -> matches p e
  | POp ("Cos", [ p ]), Cos e -> matches p e
  | POp ("Tan", [ p ]), Tan e -> matches p e
  | POp ("Sinh", [ p ]), Sinh e -> matches p e
  | POp ("Cosh", [ p ]), Cosh e -> matches p e
  | POp ("Tanh", [ p ]), Tanh e -> matches p e
  | POp ("Asin", [ p ]), Asin e -> matches p e
  | POp ("Acos", [ p ]), Acos e -> matches p e
  | POp ("Atan", [ p ]), Atan e -> matches p e
  | POp ("Atan2", [ p1; p2 ]), Atan2 (e1, e2) -> matches_binary p1 p2 e1 e2
  | POp ("Exp", [ p ]), Exp e -> matches p e
  | POp ("Ln", [ p ]), Ln e -> matches p e
  | POp ("Log", [ p1; p2 ]), Log (e1, e2) -> matches_binary p1 p2 e1 e2
  | POp ("Sqrt", [ p ]), Sqrt e -> matches p e
  | POp ("Abs", [ p ]), Abs e -> matches p e
  | _ -> None

(* Match both operands of a binary constructor. The bindings are combined
   with [merge_bindings], so a variable shared by both sides must agree. *)
and matches_binary p1 p2 e1 e2 =
  match matches p1 e1 with
  | None -> None
  | Some b1 -> (
      match matches p2 e2 with
      | None -> None
      | Some b2 -> merge_bindings b1 b2)
(** [instantiate template bindings] builds an expression from [template].
    Each [PVar v] is replaced by the expression bound to [v] in [bindings].

    Returns [None] if the template contains [PWild], uses a variable that
    [bindings] does not bind, or contains a [POp] whose name/arity is not
    one of the recognised operators. *)
let rec instantiate (template : pattern) (bindings : bindings) : expr option =
  let build sub = instantiate sub bindings in
  (* Apply a one-argument constructor if its argument instantiates. *)
  let lift1 mk sub =
    match build sub with
    | None -> None
    | Some arg -> Some (mk arg)
  in
  (* Apply a two-argument constructor if both arguments instantiate. *)
  let lift2 mk left right =
    match (build left, build right) with
    | Some lhs, Some rhs -> Some (mk lhs rhs)
    | _ -> None
  in
  match template with
  | PWild -> None
  | PVar v -> List.assoc_opt v bindings
  | PConst c -> Some (Const c)
  | PSymConst s -> Some (SymConst s)
  | POp (op, [ sub ]) -> (
      match op with
      | "Neg" -> lift1 (fun x -> Neg x) sub
      | "Sin" -> lift1 (fun x -> Sin x) sub
      | "Cos" -> lift1 (fun x -> Cos x) sub
      | "Tan" -> lift1 (fun x -> Tan x) sub
      | "Sinh" -> lift1 (fun x -> Sinh x) sub
      | "Cosh" -> lift1 (fun x -> Cosh x) sub
      | "Tanh" -> lift1 (fun x -> Tanh x) sub
      | "Asin" -> lift1 (fun x -> Asin x) sub
      | "Acos" -> lift1 (fun x -> Acos x) sub
      | "Atan" -> lift1 (fun x -> Atan x) sub
      | "Exp" -> lift1 (fun x -> Exp x) sub
      | "Ln" -> lift1 (fun x -> Ln x) sub
      | "Sqrt" -> lift1 (fun x -> Sqrt x) sub
      | "Abs" -> lift1 (fun x -> Abs x) sub
      | _ -> None)
  | POp (op, [ left; right ]) -> (
      match op with
      | "Add" -> lift2 (fun a b -> Add (a, b)) left right
      | "Sub" -> lift2 (fun a b -> Sub (a, b)) left right
      | "Mul" -> lift2 (fun a b -> Mul (a, b)) left right
      | "Div" -> lift2 (fun a b -> Div (a, b)) left right
      | "Pow" -> lift2 (fun a b -> Pow (a, b)) left right
      | "Atan2" -> lift2 (fun a b -> Atan2 (a, b)) left right
      | "Log" -> lift2 (fun a b -> Log (a, b)) left right
      | _ -> None)
  | POp _ -> None
(** [rewrite pattern template expr] matches [expr] against [pattern]. On
    success it instantiates [template] with the resulting bindings.
    Returns [None] if the match fails or the template cannot be
    instantiated. *)
let rewrite pattern template expr =
  let outcome = matches pattern expr in
  match outcome with
  | Some found -> instantiate template found
  | None -> None