leibniz/lib/matrix.ml
2026-01-19 03:37:26 +00:00

137 lines
3.8 KiB
OCaml

open Expr
open Simplify
(** A matrix of symbolic expressions, stored as an array of rows: entry
    [m.(i).(j)] is row [i], column [j]. Rows are expected to have equal
    length ([cols] inspects only the first row). *)
type matrix = expr array array
(** [create rows cols init] builds a [rows]x[cols] matrix whose entry at
    row [i], column [j] is [init i j]. Rows are built top to bottom and
    entries left to right. *)
let create rows cols init =
  let make_row i = Array.init cols (fun j -> init i j) in
  Array.init rows make_row
(** [identity n] is the [n]x[n] identity matrix: [Const 1.0] on the
    diagonal and [Const 0.0] elsewhere. *)
let identity n =
  create n n (fun row col -> Const (if row = col then 1.0 else 0.0))
(** [rows m] is the number of rows of [m]. *)
let rows m = m |> Array.length
(** [cols m] is the number of columns of [m], taken from its first row;
    an empty matrix has 0 columns. *)
let cols m =
  match m with
  | [||] -> 0
  | _ -> Array.length m.(0)
(** [get m i j] is the entry at row [i], column [j].
    @raise Invalid_argument if either index is out of bounds. *)
let get m i j =
  let row = m.(i) in
  row.(j)
(** [set m i j v] overwrites the entry at row [i], column [j] with [v],
    mutating [m] in place.
    @raise Invalid_argument if either index is out of bounds. *)
let set m i j v =
  let row = m.(i) in
  row.(j) <- v
(** [map f m] is a new matrix with [f] applied to every entry of [m],
    row by row, left to right. *)
let map f m =
  Array.map (fun row -> Array.map f row) m
(** [transpose m] is a new matrix whose entry at row [i], column [j] is
    [m.(j).(i)]. The result has [cols m] rows and [rows m] columns. *)
let transpose m =
  let out_rows = cols m and out_cols = rows m in
  create out_rows out_cols (fun i j -> m.(j).(i))
(** [add m1 m2] is the entrywise sum of [m1] and [m2], each entry passed
    through [simplify].
    @raise Failure if the two matrices differ in rows or columns. *)
let add m1 m2 =
  let r = rows m1 and c = cols m1 in
  if r = rows m2 && c = cols m2 then
    create r c (fun i j -> simplify (Add (m1.(i).(j), m2.(i).(j))))
  else failwith "matrix dimensions must match for addition"
(** [mult m1 m2] is the matrix product [m1 * m2]. Each result entry is
    [Const 0.0 + m1(i,0)*m2(0,j) + m1(i,1)*m2(1,j) + ...] (left-nested),
    passed through [simplify] once.
    @raise Failure if [cols m1 <> rows m2]. *)
let mult m1 m2 =
  let inner = cols m1 in
  if inner <> rows m2 then
    failwith "incompatible matrix dimensions for multiplication"
  else
    create (rows m1) (cols m2) (fun i j ->
        (* Fold the dot product of row i of m1 with column j of m2. *)
        let rec accumulate k acc =
          if k >= inner then acc
          else accumulate (k + 1) (Add (acc, Mul (m1.(i).(k), m2.(k).(j))))
        in
        simplify (accumulate 0 (Const 0.0)))
(** [scalar_mult s m] is a new matrix with every entry [e] of [m] replaced
    by [simplify (Mul (s, e))]. *)
let scalar_mult s m =
  let scale entry = simplify (Mul (s, entry)) in
  Array.map (fun row -> Array.map scale row) m
(** [det m] is the determinant of the square matrix [m], computed by
    cofactor (Laplace) expansion along the first row and simplified with
    [simplify]. The determinant of the empty 0x0 matrix is [Const 1.0]
    (the empty product).

    Each level of expansion recurses on [size] minors of dimension
    [size - 1], so the work grows factorially with the dimension; this is
    only suitable for small matrices.

    @raise Failure if [m] is not square. *)
let det m =
  let n = rows m in
  if n <> cols m then failwith "determinant requires square matrix";
  (* The minor obtained by deleting row 0 and column [j] of the
     [size]x[size] matrix [mat]: minor row [i] is source row [i + 1]. *)
  let minor_without_row0 mat size j =
    create (size - 1) (size - 1) (fun i k ->
        let src_col = if k < j then k else k + 1 in
        mat.(i + 1).(src_col))
  in
  let rec determinant mat size =
    match size with
    | 0 -> Const 1.0
    | 1 -> mat.(0).(0)
    | 2 ->
        simplify
          (Sub (Mul (mat.(0).(0), mat.(1).(1)), Mul (mat.(0).(1), mat.(1).(0))))
    | _ ->
        let result = ref (Const 0.0) in
        for j = 0 to size - 1 do
          let cofactor = determinant (minor_without_row0 mat size j) (size - 1) in
          (* Signs alternate +, -, +, ... along row 0. *)
          let sign = if j mod 2 = 0 then Const 1.0 else Const (-1.0) in
          result := Add (!result, Mul (Mul (sign, mat.(0).(j)), cofactor))
        done;
        simplify !result
  in
  determinant m n
(** [inverse m] is [Some (adj m / det m)], i.e. the transpose of the
    cofactor matrix with each entry divided by the determinant and
    simplified. It returns [None] when [m] is not square or when [det m]
    simplifies to exactly [Const 0.0].

    Only a determinant that is syntactically [Const 0.0] is recognised as
    singular; a symbolic determinant that is zero in value but not in form
    yields [Some] with entries divided by that expression. The inverse of
    the empty 0x0 matrix is the empty matrix. *)
let inverse m =
  let n = rows m in
  if n <> cols m then None
  else if n = 0 then Some [||]
  else
    let d = det m in
    match d with
    | Const 0.0 -> None
    | _ ->
        (* Cofactor C(i,j) = (-1)^(i+j) * det(m without row i and column j).
           For a 1x1 matrix that minor is empty, and its determinant is 1;
           it is given explicitly so [[a]] inverts to [[1/a]]. *)
        let cofactor i j =
          let minor_det =
            if n = 1 then Const 1.0
            else
              det
                (create (n - 1) (n - 1) (fun mi mj ->
                     let si = if mi < i then mi else mi + 1 in
                     let sj = if mj < j then mj else mj + 1 in
                     m.(si).(sj)))
          in
          let sign = if (i + j) mod 2 = 0 then Const 1.0 else Const (-1.0) in
          simplify (Mul (sign, minor_det))
        in
        (* The adjugate is the transpose of the cofactor matrix. *)
        let adj_t = transpose (create n n cofactor) in
        Some (map (fun e -> simplify (Div (e, d))) adj_t)
(** [trace m] is the simplified sum of the main-diagonal entries
    [m.(0).(0) + m.(1).(1) + ...], taken over the first
    [min (rows m) (cols m)] diagonal positions, so non-square matrices are
    accepted. *)
let trace m =
  let n = min (rows m) (cols m) in
  let rec go i acc =
    if i >= n then acc else go (i + 1) (Add (acc, m.(i).(i)))
  in
  simplify (go 0 (Const 0.0))
(** [eigenvalues m] is a placeholder: it ignores its argument and always
    returns the empty list.
    NOTE(review): callers cannot tell "not implemented" apart from "no
    eigenvalues"; consider raising or returning an option once an
    implementation exists. *)
let eigenvalues = fun _ -> []
(** [rank m] is the number of pivots found by forward Gaussian elimination
    on a copy of [m]; [m] itself is not modified.

    An entry counts as zero only when it is syntactically [Const 0.0].
    Symbolic or floating-point residues that are zero in value but not in
    form are treated as nonzero pivots, which can overcount the rank. *)
let rank m =
  let nrows = rows m and ncols = cols m in
  let work = Array.map Array.copy m in
  let is_zero = function Const 0.0 -> true | _ -> false in
  (* Exchange two rows of the working copy. *)
  let swap a b =
    let saved = work.(a) in
    work.(a) <- work.(b);
    work.(b) <- saved
  in
  (* Subtract multiples of pivot row [r] from every row below it,
     updating columns [c] onwards. *)
  let eliminate_below r c pivot =
    for i = r + 1 to nrows - 1 do
      let factor = Div (work.(i).(c), pivot) in
      for j = c to ncols - 1 do
        work.(i).(j) <- simplify (Sub (work.(i).(j), Mul (factor, work.(r).(j))))
      done
    done
  in
  (* First row at or below [i] with a nonzero entry in column [c]. *)
  let rec next_nonzero i c =
    if i >= nrows then None
    else if is_zero work.(i).(c) then next_nonzero (i + 1) c
    else Some i
  in
  let rec loop r c acc =
    if r >= nrows || c >= ncols then acc
    else if is_zero work.(r).(c) then (
      match next_nonzero (r + 1) c with
      | None -> loop r (c + 1) acc
      | Some i ->
          swap r i;
          loop r c acc)
    else (
      eliminate_below r c work.(r).(c);
      loop (r + 1) (c + 1) (acc + 1))
  in
  loop 0 0 0