137 lines
3.8 KiB
OCaml
137 lines
3.8 KiB
OCaml
open Expr
|
|
open Simplify
|
|
|
|
(** A matrix of symbolic expressions, stored as an array of rows; the entry
    at row [i], column [j] of [m] is [m.(i).(j)]. *)
type matrix = expr array array
|
|
|
|
(** [create rows cols init] builds a [rows] x [cols] matrix whose entry at
    row [i], column [j] is [init i j]. Entries are produced row by row,
    left to right.
    @raise Invalid_argument if [rows] or [cols] is negative (raised by
    [Array.init]). *)
let create rows cols init =
  let build_row i = Array.init cols (fun j -> init i j) in
  Array.init rows build_row
|
|
|
|
(** [identity n] is the [n] x [n] matrix with [Const 1.0] on the main
    diagonal and [Const 0.0] everywhere else. *)
let identity n =
  let entry i j = if i = j then Const 1.0 else Const 0.0 in
  create n n entry
|
|
|
|
(** [rows mat] is the number of rows of [mat], i.e. the length of the
    outer array. *)
let rows mat = Array.length mat
|
|
(** [cols mat] is the number of columns of [mat], taken from the length of
    its first row; a matrix with no rows has 0 columns. Only the first row
    is inspected. *)
let cols mat =
  match mat with
  | [||] -> 0
  | _ -> Array.length mat.(0)
|
|
|
|
(** [get mat row col] is the entry of [mat] at row [row], column [col].
    @raise Invalid_argument if either index is out of bounds. *)
let get mat row col =
  let line = mat.(row) in
  line.(col)
|
|
(** [set mat row col value] overwrites, in place, the entry of [mat] at row
    [row], column [col] with [value].
    @raise Invalid_argument if either index is out of bounds. *)
let set mat row col value =
  let line = mat.(row) in
  line.(col) <- value
|
|
|
|
(** [map f m] returns a fresh matrix obtained by applying [f] to every
    entry of [m]; [m] itself is left unchanged. *)
let map f m =
  let map_row row = Array.map f row in
  Array.map map_row m
|
|
|
|
(** [transpose m] returns a new matrix with [cols m] rows and [rows m]
    columns whose entry at ([i], [j]) is the entry of [m] at ([j], [i]). *)
let transpose m =
  let src_rows = rows m in
  let src_cols = cols m in
  let flipped i j = m.(j).(i) in
  create src_cols src_rows flipped
|
|
|
|
(** [add m1 m2] is the entry-wise sum of [m1] and [m2]; every entry is
    built as [Add (a, b)] and then passed through [simplify].
    @raise Failure if the two matrices do not have the same number of rows
    and columns (as reported by [rows] and [cols]). *)
let add m1 m2 =
  let same_shape = rows m1 = rows m2 && cols m1 = cols m2 in
  if not same_shape then
    failwith "matrix dimensions must match for addition"
  else
    let sum_at i j = simplify (Add (m1.(i).(j), m2.(i).(j))) in
    create (rows m1) (cols m1) sum_at
|
|
|
|
(** [mult m1 m2] is the matrix product of [m1] and [m2]. Each entry is the
    left-nested sum [Add (... Add (Const 0.0, t0) ..., tk)] of the products
    [Mul (m1.(i).(k), m2.(k).(j))], passed through [simplify] once at the
    end.
    @raise Failure if [cols m1] differs from [rows m2]. *)
let mult m1 m2 =
  if cols m1 = rows m2 then begin
    let inner = cols m1 in
    let entry i j =
      (* Accumulate terms for k = 0 .. inner - 1, in increasing k. *)
      let rec accumulate k acc =
        if k >= inner then acc
        else accumulate (k + 1) (Add (acc, Mul (m1.(i).(k), m2.(k).(j))))
      in
      simplify (accumulate 0 (Const 0.0))
    in
    create (rows m1) (cols m2) entry
  end
  else failwith "incompatible matrix dimensions for multiplication"
|
|
|
|
(** [scalar_mult s m] multiplies every entry [e] of [m] by the expression
    [s], producing [simplify (Mul (s, e))] for each entry. *)
let scalar_mult s m =
  let scale entry = simplify (Mul (s, entry)) in
  map scale m
|
|
|
|
(** [det m] computes the determinant of the square matrix [m] symbolically,
    by cofactor expansion along the first row.

    - A 0x0 matrix has determinant [Const 1.0] (the empty product). This
      matters for callers that take the determinant of the minor of a 1x1
      matrix.
    - 1x1 and 2x2 matrices are handled by direct formulas.
    - Larger matrices recurse on [size] minors of dimension [size - 1], so
      the amount of work grows factorially with the dimension.

    @raise Failure if [m] is not square (as reported by [rows]/[cols]). *)
let det m =
  let n = rows m in
  if n <> cols m then failwith "determinant requires square matrix";
  let rec determinant mat size =
    match size with
    | 0 -> Const 1.0
    | 1 -> mat.(0).(0)
    | 2 ->
      simplify
        (Sub
           ( Mul (mat.(0).(0), mat.(1).(1)),
             Mul (mat.(0).(1), mat.(1).(0)) ))
    | _ ->
      let result = ref (Const 0.0) in
      for j = 0 to size - 1 do
        (* Minor of entry (0, j): drop row 0 (hence [i + 1]) and column j. *)
        let minor =
          create (size - 1) (size - 1) (fun i k ->
              let mk = if k < j then k else k + 1 in
              mat.(i + 1).(mk))
        in
        let cofactor = determinant minor (size - 1) in
        (* Signs alternate +, -, +, ... along the first row. *)
        let sign = if j mod 2 = 0 then Const 1.0 else Const (-1.0) in
        result := Add (!result, Mul (Mul (sign, mat.(0).(j)), cofactor))
      done;
      simplify !result
  in
  determinant m n
|
|
|
|
(** [inverse m] returns [Some inv], where [inv] is the symbolic inverse of
    [m] computed as the adjugate divided by the determinant, or [None] when
    [m] is not square or when [det m] is literally [Const 0.0].

    The singularity test is purely structural: a determinant that is
    mathematically zero but is not returned as [Const 0.0] by [det] is not
    detected, and the result then contains divisions by that expression.

    A 1x1 matrix is inverted directly as [1 / d], so the result does not
    depend on what [det] returns for the 0x0 minor. *)
let inverse m =
  let n = rows m in
  if n <> cols m then None
  else
    let d = det m in
    match d with
    | Const 0.0 -> None
    | _ ->
      if n = 1 then Some [| [| simplify (Div (Const 1.0, d)) |] |]
      else begin
        (* Signed determinant of [m] with row [i] and column [j] removed. *)
        let cofactor i j =
          let minor =
            create (n - 1) (n - 1) (fun mi mj ->
                let si = if mi < i then mi else mi + 1 in
                let sj = if mj < j then mj else mj + 1 in
                m.(si).(sj))
          in
          let sign = if (i + j) mod 2 = 0 then Const 1.0 else Const (-1.0) in
          simplify (Mul (sign, det minor))
        in
        (* The adjugate is the transpose of the cofactor matrix. *)
        let adjugate = transpose (create n n cofactor) in
        Some (map (fun e -> simplify (Div (e, d))) adjugate)
      end
|
|
|
|
(** [trace m] is the sum of the main-diagonal entries of [m], built as a
    left-nested [Add] chain starting from [Const 0.0] and passed through
    [simplify] once. For a non-square matrix only the first
    [min (rows m) (cols m)] diagonal entries are summed. *)
let trace m =
  let diag_len = min (rows m) (cols m) in
  let rec accumulate i acc =
    if i >= diag_len then acc
    else accumulate (i + 1) (Add (acc, m.(i).(i)))
  in
  simplify (accumulate 0 (Const 0.0))
|
|
|
|
(** [eigenvalues m] is currently a stub: the argument is ignored and the
    empty list is always returned. *)
let eigenvalues _ = []
|
|
|
|
(** [rank m] counts the pivots found by forward (row-echelon) elimination
    performed on a private copy of [m]; [m] itself is not modified.

    An entry is treated as zero only when it is literally [Const 0.0], so
    the result relies on [simplify] reducing cancelled entries to that
    form. *)
let rank m =
  let n_rows = rows m in
  let n_cols = cols m in
  (* Copy every row so that swaps and eliminations stay local. *)
  let work = Array.map Array.copy m in
  let is_zero = function
    | Const 0.0 -> true
    | _ -> false
  in
  (* Index of the first row at or below [i] with a non-zero entry in [col]. *)
  let rec first_nonzero_from i col =
    if i >= n_rows then None
    else if is_zero work.(i).(col) then first_nonzero_from (i + 1) col
    else Some i
  in
  let swap_rows a b =
    let saved = work.(a) in
    work.(a) <- work.(b);
    work.(b) <- saved
  in
  (* Subtract a multiple of the pivot row from every row below it. *)
  let eliminate_below row col pivot =
    for i = row + 1 to n_rows - 1 do
      let factor = Div (work.(i).(col), pivot) in
      for j = col to n_cols - 1 do
        work.(i).(j) <-
          simplify (Sub (work.(i).(j), Mul (factor, work.(row).(j))))
      done
    done
  in
  let rec sweep row col found =
    if row >= n_rows || col >= n_cols then found
    else
      let pivot = work.(row).(col) in
      if is_zero pivot then begin
        match first_nonzero_from (row + 1) col with
        | None -> sweep row (col + 1) found
        | Some i ->
          (* Bring a usable pivot into place and retry the same cell. *)
          swap_rows row i;
          sweep row col found
      end
      else begin
        eliminate_below row col pivot;
        sweep (row + 1) (col + 1) (found + 1)
      end
  in
  sweep 0 0 0
|