open Expr
open Simplify

(** Dense matrices of symbolic expressions, stored row-major as an array of
    rows. All rows are assumed to have the same length. *)
type matrix = expr array array

(** [create rows cols init] builds a [rows] x [cols] matrix whose entry at
    row [i], column [j] is [init i j]. *)
let create rows cols init =
  Array.init rows (fun i -> Array.init cols (fun j -> init i j))

(** [identity n] is the [n] x [n] identity matrix, with [Const 1.0] on the
    diagonal and [Const 0.0] elsewhere. *)
let identity n =
  create n n (fun i j -> if i = j then Const 1.0 else Const 0.0)

(** Number of rows of [m]. *)
let rows m = Array.length m

(** Number of columns of [m], taken from its first row; [0] when [m] has
    no rows. *)
let cols m = if Array.length m > 0 then Array.length m.(0) else 0

(** [get m i j] is the entry at row [i], column [j].
    @raise Invalid_argument if the indices are out of bounds. *)
let get m i j = m.(i).(j)

(** [set m i j v] overwrites, in place, the entry at row [i], column [j].
    @raise Invalid_argument if the indices are out of bounds. *)
let set m i j v = m.(i).(j) <- v

(** [map f m] is a fresh matrix obtained by applying [f] to every entry. *)
let map f m = Array.map (Array.map f) m

(** [transpose m] is a fresh matrix with rows and columns swapped. *)
let transpose m =
  let r = rows m in
  let c = cols m in
  create c r (fun i j -> m.(j).(i))

(** Entry-wise sum; each entry is passed through [simplify].
    @raise Failure if the dimensions differ. *)
let add m1 m2 =
  if rows m1 <> rows m2 || cols m1 <> cols m2 then
    failwith "matrix dimensions must match for addition"
  else
    create (rows m1) (cols m1) (fun i j ->
        simplify (Add (m1.(i).(j), m2.(i).(j))))

(** Matrix product [m1 * m2]. Each entry is built as the left-nested sum
    [((0 + a0*b0) + a1*b1) + ...] and then passed through [simplify].
    @raise Failure if [cols m1 <> rows m2]. *)
let mult m1 m2 =
  if cols m1 <> rows m2 then
    failwith "incompatible matrix dimensions for multiplication"
  else
    create (rows m1) (cols m2) (fun i j ->
        let sum = ref (Const 0.0) in
        for k = 0 to cols m1 - 1 do
          sum := Add (!sum, Mul (m1.(i).(k), m2.(k).(j)))
        done;
        simplify !sum)

(** [scalar_mult s m] multiplies every entry of [m] by the expression [s],
    simplifying each product. *)
let scalar_mult s m = map (fun e -> simplify (Mul (s, e))) m

(** [minor m i j] is the fresh submatrix of [m] obtained by deleting row [i]
    and column [j]. Expects [m] to have at least one row and one column. *)
let minor m i j =
  create (rows m - 1) (cols m - 1) (fun r c ->
      (* Skip over the deleted row/column by shifting later indices by one. *)
      let src_r = if r < i then r else r + 1 in
      let src_c = if c < j then c else c + 1 in
      m.(src_r).(src_c))

(** [det m] is the determinant of the square matrix [m], computed by cofactor
    (Laplace) expansion along the first row and passed through [simplify].
    The determinant of the empty 0 x 0 matrix is [Const 1.0], so that
    cofactors of 1 x 1 matrices come out right.
    @raise Failure if [m] is not square. *)
let det m =
  let n = rows m in
  if n <> cols m then failwith "determinant requires square matrix";
  let rec determinant mat size =
    match size with
    | 0 -> Const 1.0
    | 1 -> simplify mat.(0).(0)
    | 2 ->
        simplify
          (Sub
             ( Mul (mat.(0).(0), mat.(1).(1)),
               Mul (mat.(0).(1), mat.(1).(0)) ))
    | _ ->
        let result = ref (Const 0.0) in
        for j = 0 to size - 1 do
          let cofactor = determinant (minor mat 0 j) (size - 1) in
          (* Sign of entry (0, j) in the checkerboard pattern (-1)^j. *)
          let sign = if j mod 2 = 0 then Const 1.0 else Const (-1.0) in
          result := Add (!result, Mul (Mul (sign, mat.(0).(j)), cofactor))
        done;
        simplify !result
  in
  determinant m n

(** [inverse m] is the inverse of [m], computed as adjugate / determinant,
    with every entry simplified.

    It returns [None] when [m] is not square, or when its determinant
    simplifies to the literal [Const 0.0]. A symbolic determinant that is
    zero but not reduced to [Const 0.0] by [simplify] is not detected. In
    that case the entries are divided by it anyway. *)
let inverse m =
  let n = rows m in
  if n <> cols m then None
  else
    let d = det m in
    match d with
    | Const 0.0 -> None
    | _ ->
        let cofactor i j =
          let sign =
            if (i + j) mod 2 = 0 then Const 1.0 else Const (-1.0)
          in
          simplify (Mul (sign, det (minor m i j)))
        in
        (* The adjugate is the transpose of the cofactor matrix. *)
        let adjugate = transpose (create n n cofactor) in
        Some (map (fun e -> simplify (Div (e, d))) adjugate)

(** [trace m] is the simplified sum of the diagonal entries. For a
    non-square matrix only the leading [min rows cols] diagonal is summed. *)
let trace m =
  let n = min (rows m) (cols m) in
  let sum = ref (Const 0.0) in
  for i = 0 to n - 1 do
    sum := Add (!sum, m.(i).(i))
  done;
  simplify !sum

(** Not implemented: always returns the empty list. *)
let eigenvalues _m = []

(** [rank m] is the number of pivots found by Gaussian elimination on a copy
    of [m]; [m] itself is not modified.

    Zero detection is syntactic: only entries that are literally
    [Const 0.0] count as zero. The result therefore depends on how far
    [simplify] reduces each eliminated entry. *)
let rank m =
  let r = rows m in
  let c = cols m in
  (* Private copy: rows are swapped and entries overwritten below. *)
  let temp = Array.map Array.copy m in
  (* First row index >= [i] whose entry in column [col] is not [Const 0.0]. *)
  let rec find_pivot col i =
    if i >= r then None
    else
      match temp.(i).(col) with
      | Const 0.0 -> find_pivot col (i + 1)
      | _ -> Some i
  in
  let rec count_pivots row col acc =
    if row >= r || col >= c then acc
    else
      match temp.(row).(col) with
      | Const 0.0 -> (
          match find_pivot col (row + 1) with
          | None -> count_pivots row (col + 1) acc
          | Some i ->
              let tmp_row = temp.(row) in
              temp.(row) <- temp.(i);
              temp.(i) <- tmp_row;
              count_pivots row col acc)
      | pivot ->
          (* Eliminate column [col] in every row below the pivot row. *)
          for i = row + 1 to r - 1 do
            let factor = Div (temp.(i).(col), pivot) in
            for j = col to c - 1 do
              temp.(i).(j) <-
                simplify (Sub (temp.(i).(j), Mul (factor, temp.(row).(j))))
            done
          done;
          count_pivots (row + 1) (col + 1) (acc + 1)
  in
  count_pivots 0 0 0