Source file linear_algebra.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
(** Linear algebra module, copied from {{:
https://gitlab.com/nomadic-labs/privacy-team/-/blob/9e4050cb4a304848901c3434d61a8d7f0c7107c4/nuplompiler/linear_algebra.ml
} Nomadic Labs privacy team repository } *)
module type Ring_sig = sig
type t
val add : t -> t -> t
val mul : t -> t -> t
val negate : t -> t
val zero : t
val one : t
val eq : t -> t -> bool
end
module type Field_sig = sig
include Ring_sig
val inverse_exn : t -> t
end
(** This refers to the mathematical generalization of vector space called
"module", where the field of scalars is replaced by a ring *)
module type Module_sig = sig
type t
type matrix = t array array
(** [zeros r c] is a matrix with [r] rows and [c] columns filled with zeros *)
val zeros : int -> int -> matrix
(** [identity n] is the identity matrix of dimension [n] *)
val identity : int -> matrix
(** matrix equality *)
val equal : matrix -> matrix -> bool
(** matrix addition *)
val add : matrix -> matrix -> matrix
(** matrix multiplication *)
val mul : matrix -> matrix -> matrix
(** matrix transposition *)
val transpose : matrix -> matrix
(** [row_add ~coeff i j m] adds to the i-th row, the j-th row times coeff in m *)
val row_add : ?coeff:t -> int -> int -> matrix -> unit
(** [row_swap i j m] swaps the i-th and j-th rows of m *)
val row_swap : int -> int -> matrix -> unit
(** [row_mul coeff i m] multiplies the i-th row by coeff in m *)
val row_mul : t -> int -> matrix -> unit
(** [filter_cols f m] removes the columns of [m] whose index does not satisfy [f] *)
val filter_cols : (int -> bool) -> matrix -> matrix
(** splits matrix [m] into the first n columns and the rest, producing two matrices *)
val split_n : int -> matrix -> matrix * matrix
end
module type VectorSpace_sig = sig
include Module_sig
(** reduced row Echelon form of m *)
val reduced_row_echelon_form : matrix -> matrix
(** [inverse m] is the inverse matrix of m
@raise [Invalid_argument] if [m] is not invertible *)
val inverse : matrix -> matrix
end
module Make_Module (Ring : Ring_sig) : Module_sig with type t = Ring.t = struct
type t = Ring.t
type matrix = t array array
let zeros r c = Array.make_matrix r c Ring.zero
let identity n =
Array.(init n (fun i -> init n Ring.(fun j -> if i = j then one else zero)))
let equal = Array.(for_all2 (for_all2 Ring.eq))
let add = Array.(map2 (map2 Ring.add))
let mul m1 m2 =
let nb_rows = Array.length m1 in
let nb_cols = Array.length m2.(0) in
let n = Array.length m1.(0) in
assert (Array.length m2 = n) ;
let p = zeros nb_rows nb_cols in
for i = 0 to nb_rows - 1 do
for j = 0 to nb_cols - 1 do
for k = 0 to n - 1 do
p.(i).(j) <- Ring.(add p.(i).(j) @@ mul m1.(i).(k) m2.(k).(j))
done
done
done ;
p
let transpose m =
let nb_rows = Array.length m in
let nb_cols = Array.length m.(0) in
Array.(init nb_cols (fun i -> init nb_rows (fun j -> m.(j).(i))))
let row_add ?(coeff = Ring.one) i j m =
m.(i) <- Array.map2 Ring.(fun a b -> add a (mul coeff b)) m.(i) m.(j)
let row_swap i j m =
let aux = m.(i) in
m.(i) <- m.(j) ;
m.(j) <- aux
let row_mul coeff i m = m.(i) <- Array.map (Ring.mul coeff) m.(i)
let filter_cols f =
Array.map (fun row ->
List.filteri (fun i _ -> f i) (Array.to_list row) |> Array.of_list)
let split_n n m =
(filter_cols (fun i -> i < n) m, filter_cols (fun i -> i >= n) m)
end
module Make_VectorSpace (Field : Field_sig) :
VectorSpace_sig with type t = Field.t = struct
include Make_Module (Field)
let reduced_row_echelon_form m =
let n = Array.length m in
let find_pivot row =
let rec aux cnt = function
| [] -> None
| x :: xs -> if Field.(eq zero x) then aux (cnt + 1) xs else Some cnt
in
aux 0 (Array.to_list row)
in
let move_zeros_to_bottom m =
let is_non_zero_row = Array.exists (fun a -> not Field.(eq zero a)) in
let rec aux nonzeros zeros = function
| [] -> Array.of_list (List.rev nonzeros @ zeros)
| r :: rs ->
if is_non_zero_row r then aux (r :: nonzeros) zeros rs
else aux nonzeros (r :: zeros) rs
in
aux [] [] (Array.to_list m)
in
let rec aux k =
if k >= Array.length m then m
else
match find_pivot m.(k) with
| Some j when j < n ->
row_mul (Field.inverse_exn m.(k).(j)) k m ;
Array.iteri
(fun i _ ->
if i <> k then row_add ~coeff:Field.(negate @@ m.(i).(j)) i k m)
m ;
row_swap k j m ;
aux (k + 1)
| _ -> aux (k + 1)
in
aux 0 |> move_zeros_to_bottom
let inverse m =
let n = Array.length m in
let id_n = identity n in
let augmented = Array.(map2 append m id_n) in
let reduced = reduced_row_echelon_form augmented in
let residue, inv = split_n n reduced in
let is_zero_row = Array.for_all Field.(eq zero) in
if Array.exists is_zero_row residue then
raise @@ Invalid_argument "matrix [m] is not invertible"
else inv
end