Source file linear_algebra.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
let error_msgf fmt = Format.kasprintf (fun err -> Error (`Msg err)) fmt
let col_norm a column =
let acc = ref 0. in
for i = 0 to Array.length a - 1 do
let entry = a.(i).(column) in
acc := !acc +. (entry *. entry)
done;
sqrt !acc
let col_inner_prod t j1 j2 =
let acc = ref 0. in
for i = 0 to Array.length t - 1 do
acc := !acc +. (t.(i).(j1) *. t.(i).(j2))
done;
!acc
let qr_in_place a =
let m = Array.length a in
if m = 0 then ([||], [||])
else
let n = Array.length a.(0) in
let r = Array.make_matrix n n 0. in
for j = 0 to n - 1 do
let alpha = col_norm a j in
r.(j).(j) <- alpha;
let one_over_alpha = 1. /. alpha in
for i = 0 to m - 1 do
a.(i).(j) <- a.(i).(j) *. one_over_alpha
done;
for j2 = j + 1 to n - 1 do
let c = col_inner_prod a j j2 in
r.(j).(j2) <- c;
for i = 0 to m - 1 do
a.(i).(j2) <- a.(i).(j2) -. (c *. a.(i).(j))
done
done
done;
(a, r)
let qr ?(in_place = false) a =
let a = if in_place then a else Array.map Array.copy a in
qr_in_place a
let mul_mv ?(trans = false) a x =
let rows = Array.length a in
if rows = 0 then [||]
else
let cols = Array.length a.(0) in
let m, n, get =
if trans then
let get i j = a.(j).(i) in
(cols, rows, get)
else
let get i j = a.(i).(j) in
(rows, cols, get)
in
if n <> Array.length x then failwith "Dimension mismatch";
let result = Array.make m 0. in
for i = 0 to m - 1 do
let v, _ =
Array.fold_left
(fun (acc, j) x -> (acc +. (get i j *. x), succ j))
(0., 0) x
in
result.(i) <- v
done;
result
let is_nan v = match classify_float v with FP_nan -> true | _ -> false
let triu_solve r b =
let m = Array.length b in
if m <> Array.length r then
error_msgf
"triu_solve R b requires R to be square with same number of rows as b"
else if m = 0 then Ok [||]
else if m <> Array.length r.(0) then
error_msgf "triu_solve R b requires R to be a square"
else
let sol = Array.copy b in
for i = m - 1 downto 0 do
sol.(i) <- sol.(i) /. r.(i).(i);
for j = 0 to i - 1 do
sol.(j) <- sol.(j) -. (r.(j).(i) *. sol.(i))
done
done;
if Array.exists is_nan sol then error_msgf "triu_solve detected NaN result"
else Ok sol
let ols ?(in_place = false) a b =
let q, r = qr ~in_place a in
triu_solve r (mul_mv ~trans:true q b)