1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
(** A shape: one extent per dimension, outermost dimension first (the last
    dimension is the fastest-varying one, see [c_contiguous_strides]). The
    empty array is the rank-0 (scalar) shape. Extents are presumably expected
    to be non-negative; this type alias does not enforce it. *)
type t = int array
(** [numel shape] is the number of elements described by [shape], i.e. the
    product of all extents. A rank-0 shape has exactly one element; any zero
    extent makes the count zero. *)
let numel shape =
  let product = ref 1 in
  Array.iter (fun extent -> product := !product * extent) shape;
  !product
(** [equal a b] is structural equality: for shapes, [true] iff both have the
    same rank and the same extent in every dimension. *)
let equal a b = a = b
(** [c_contiguous_strides shape] returns row-major (C-order) strides for
    [shape], measured in elements: the last dimension has stride 1 and each
    outer stride is the next inner stride times the next inner extent.
    A dimension of extent 0 gets stride 0, and because strides are built
    inward-out, every dimension outside it also gets stride 0. A rank-0
    shape yields [[||]]. *)
let c_contiguous_strides shape =
  let rank = Array.length shape in
  let strides = Array.make rank 0 in
  (* Walk from the innermost axis outward, reading the already-filled
     stride of the axis just inside. *)
  let rec fill axis =
    if axis >= 0 then begin
      strides.(axis) <-
        (if shape.(axis) = 0 then 0
         else if axis = rank - 1 then 1
         else strides.(axis + 1) * max 1 shape.(axis + 1));
      fill (axis - 1)
    end
  in
  fill (rank - 1);
  strides
(** [ravel_index indices strides] is the linear offset
    [indices.(0) * strides.(0) + ... + indices.(r-1) * strides.(r-1)].
    A length mismatch between [indices] and [strides] is reported through
    [Error.invalid]. Individual indices are not bounds-checked here. *)
let ravel_index indices strides =
  let rank = Array.length indices in
  let stride_rank = Array.length strides in
  if rank <> stride_rank then
    Error.invalid ~op:"ravel_index"
      ~what:
        (Printf.sprintf "rank mismatch: indices[%d] vs strides[%d]" rank
           stride_rank)
      ~reason:"dimensions must match" ();
  let rec accumulate axis offset =
    if axis = rank then offset
    else accumulate (axis + 1) (offset + (indices.(axis) * strides.(axis)))
  in
  accumulate 0 0
(** [unravel_index k shape] converts the row-major (C-order) linear index [k]
    into a multi-index for [shape], returned as a fresh array of length
    [Array.length shape].

    - Rank-0 shape: only [k = 0] is valid and yields [[||]].
    - Shape containing a zero extent: only [k = 0] is accepted and yields an
      all-zero index.
    - Otherwise [k] must satisfy [0 <= k < numel shape].

    Invalid [k] is reported through [Error.invalid]. *)
let unravel_index k shape =
  let n = Array.length shape in
  if n = 0 then
    if k = 0 then [||]
    else
      Error.invalid ~op:"unravel_index" ~what:"k"
        ~reason:(Printf.sprintf "%d out of bounds for scalar" k)
        ()
  else if Array.exists (( = ) 0) shape then
    if k = 0 then Array.make n 0
    else
      Error.invalid ~op:"unravel_index" ~what:"k"
        ~reason:
          (* A negative [k] is not "> 0"; report it as out of bounds rather
             than emitting a self-contradictory message. *)
          (if k < 0 then
             Printf.sprintf "%d out of bounds for zero-size shape" k
           else Printf.sprintf "%d > 0 for zero-size shape" k)
        ()
  else
    let total_elements = numel shape in
    if k < 0 || k >= total_elements then
      Error.invalid ~op:"unravel_index" ~what:"k"
        ~reason:
          (Printf.sprintf "%d out of bounds for shape (size %d)" k
             total_elements)
        ();
    let idx = Array.make n 0 in
    let temp_k = ref k in
    (* Peel off the fastest-varying (last) dimension first. *)
    for i = n - 1 downto 1 do
      let dim_size = shape.(i) in
      idx.(i) <- !temp_k mod dim_size;
      temp_k := !temp_k / dim_size
    done;
    idx.(0) <- !temp_k;
    (* Cannot fire for shapes with positive extents given the bounds check
       above; it still catches malformed shapes (e.g. negative extents). *)
    if idx.(0) >= shape.(0) then
      Error.invalid ~op:"unravel_index"
        ~what:(Printf.sprintf "calculated idx.(0)=%d" idx.(0))
        ~reason:(Printf.sprintf "out of bounds for shape.(0)=%d" shape.(0))
        ~hint:"this indicates an issue with k or logic" ();
    idx
(** [unravel_index_into k shape result] is the allocation-free variant of
    [unravel_index]: it writes the row-major multi-index of linear index [k]
    into slots [0 .. Array.length shape - 1] of [result].

    - Rank-0 shape: only [k = 0] is valid; [result] is left untouched.
    - Shape containing a zero extent: only [k = 0] is accepted and zeroes the
      first [Array.length shape] slots.
    - Otherwise [k] must satisfy [0 <= k < numel shape].

    Invalid [k] is reported through [Error.invalid]. [result] must have at
    least [Array.length shape] slots; this is not checked here, so a shorter
    array fails on the array access itself. *)
let unravel_index_into k shape result =
  let n = Array.length shape in
  if n = 0 then (
    if k <> 0 then
      Error.invalid ~op:"unravel_index_into" ~what:"k"
        ~reason:(Printf.sprintf "%d out of bounds for scalar" k)
        ())
  else if Array.exists (( = ) 0) shape then
    if k = 0 then
      for i = 0 to n - 1 do
        result.(i) <- 0
      done
    else
      Error.invalid ~op:"unravel_index_into" ~what:"k"
        ~reason:
          (* A negative [k] is not "> 0"; report it as out of bounds rather
             than emitting a self-contradictory message. *)
          (if k < 0 then
             Printf.sprintf "%d out of bounds for zero-size shape" k
           else Printf.sprintf "%d > 0 for zero-size shape" k)
        ()
  else
    let total_elements = numel shape in
    if k < 0 || k >= total_elements then
      Error.invalid ~op:"unravel_index_into" ~what:"k"
        ~reason:
          (Printf.sprintf "%d out of bounds for shape (size %d)" k
             total_elements)
        ()
    else (
      let temp_k = ref k in
      (* Peel off the fastest-varying (last) dimension first. *)
      for i = n - 1 downto 1 do
        let dim_size = shape.(i) in
        result.(i) <- !temp_k mod dim_size;
        temp_k := !temp_k / dim_size
      done;
      result.(0) <- !temp_k;
      (* Cannot fire for shapes with positive extents given the bounds check
         above; it still catches malformed shapes (e.g. negative extents). *)
      if result.(0) >= shape.(0) then
        Error.invalid ~op:"unravel_index_into"
          ~what:(Printf.sprintf "calculated result.(0)=%d" result.(0))
          ~reason:(Printf.sprintf "out of bounds for shape.(0)=%d" shape.(0))
          ())
(** [resolve_neg_one current_shape new_shape_spec] replaces a single [-1]
    ("infer this dimension") in a reshape target with the extent that makes
    the element count match [numel current_shape].

    - No [-1]: [new_shape_spec] is returned as-is (not validated here).
    - More than one [-1]: error via [Error.invalid].
    - The remaining extents multiply to 0: the [-1] becomes 0 when
      [current_shape] also has no elements; otherwise error via
      [Error.cannot].
    - [numel current_shape] not divisible by the product of the remaining
      extents: error via [Error.cannot].

    Other negative extents are not rejected here. *)
let resolve_neg_one current_shape new_shape_spec =
  let current_numel = numel current_shape in
  let is_unknown extent = extent = -1 in
  let unknown_count =
    Array.fold_left
      (fun count extent -> if is_unknown extent then count + 1 else count)
      0 new_shape_spec
  in
  (* Fill in the single unknown extent. *)
  let infer_unknown () =
    let known_numel =
      Array.fold_left
        (fun acc extent -> if is_unknown extent then acc else acc * extent)
        1 new_shape_spec
    in
    let substitute value =
      Array.map (fun extent -> if is_unknown extent then value else extent)
        new_shape_spec
    in
    if known_numel = 0 then
      if current_numel = 0 then substitute 0
      else
        Error.cannot ~op:"reshape" ~what:"infer -1"
          ~from:"shape with 0-size dimensions" ~to_:"non-zero total size"
          ~reason:"incompatible dimensions" ()
    else if current_numel mod known_numel <> 0 then
      Error.cannot ~op:"reshape" ~what:"reshape"
        ~from:(Printf.sprintf "%d elements" current_numel)
        ~to_:(Printf.sprintf "shape with %d elements" known_numel)
        ~reason:"size mismatch" ()
    else substitute (current_numel / known_numel)
  in
  match unknown_count with
  | 0 -> new_shape_spec
  | 1 -> infer_unknown ()
  | _ ->
      Error.invalid ~op:"reshape" ~what:"shape specification"
        ~reason:"multiple -1 dimensions"
        ~hint:"can only specify one unknown dimension" ()
(** [broadcast shape_a shape_b] is the broadcast result shape of the two
    shapes. They are aligned at their trailing dimensions, and the shorter
    one is left-padded with 1s. Per dimension, equal extents are kept and an
    extent of 1 stretches to the other extent. Any other pair is reported
    via [Error.broadcast_incompatible]. *)
let broadcast shape_a shape_b =
  let rank_a = Array.length shape_a in
  let rank_b = Array.length shape_b in
  let rank_out = max rank_a rank_b in
  (* Extent of [shape] (of rank [rank]) at output [axis] after padding. *)
  let padded shape rank axis =
    let offset = rank_out - rank in
    if axis < offset then 1 else shape.(axis - offset)
  in
  let result = Array.make rank_out 1 in
  for axis = 0 to rank_out - 1 do
    let a = padded shape_a rank_a axis in
    let b = padded shape_b rank_b axis in
    if a = b || b = 1 then result.(axis) <- a
    else if a = 1 then result.(axis) <- b
    else
      Error.broadcast_incompatible ~op:"broadcast" ~shape1:shape_a
        ~shape2:shape_b ()
  done;
  result
(** [broadcast_index target_multi_idx source_shape] maps a multi-index in a
    broadcast (target) shape back to the corresponding multi-index in
    [source_shape]. Dimensions are aligned at the trailing end. A source
    dimension of extent 1 always maps to index 0. A source dimension with
    no target counterpart (source rank greater than target rank) also maps
    to 0. Returns a fresh array of length [Array.length source_shape]. *)
let broadcast_index target_multi_idx source_shape =
  let offset = Array.length target_multi_idx - Array.length source_shape in
  Array.mapi
    (fun axis extent ->
      let target_axis = axis + offset in
      if target_axis < 0 || extent = 1 then 0
      else target_multi_idx.(target_axis))
    source_shape
(** [to_string shape] renders [shape] as its extents separated by commas,
    with no spaces, inside square brackets, e.g. ["[2,3]"]; the rank-0 shape
    renders as ["[]"]. *)
let to_string shape =
  let buf = Buffer.create 16 in
  Buffer.add_char buf '[';
  Array.iteri
    (fun i extent ->
      if i > 0 then Buffer.add_char buf ',';
      Buffer.add_string buf (string_of_int extent))
    shape;
  Buffer.add_char buf ']';
  Buffer.contents buf
(** [broadcast_index_into target_multi_idx source_shape result] is the
    in-place counterpart of [broadcast_index]. For each source dimension
    (aligned at the trailing end) it writes 0 when the extent is 1, and
    otherwise the matching target index. Slots for source dimensions with no
    target counterpart (source rank greater than target rank) are left
    untouched. *)
let broadcast_index_into target_multi_idx source_shape result =
  let offset = Array.length target_multi_idx - Array.length source_shape in
  Array.iteri
    (fun axis extent ->
      let target_axis = axis + offset in
      if target_axis < 0 then ()
      else if extent = 1 then result.(axis) <- 0
      else result.(axis) <- target_multi_idx.(target_axis))
    source_shape
(** [pp fmt shape] prints [shape] on [fmt] using the [to_string] rendering;
    suitable for use with the [%a] conversion. *)
let pp fmt shape =
  let rendered = to_string shape in
  Format.pp_print_string fmt rendered