1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
open Nx_core
open Bigarray
(** Flat storage for tensor elements: a one-dimensional, C-layout bigarray
    whose OCaml element type is ['a] and whose element kind is ['b]. *)
type ('a, 'b) buffer = ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t
(** Backend context carried by every tensor; it holds a [Parallel.pool]. *)
type context = { pool : Parallel.pool }
(** A tensor: a flat {!buffer} together with the [View.t] that the accessors
    below query for shape, strides and offset. *)
type ('a, 'b) t = {
context : context;  (** Context the tensor was created with. *)
dtype : ('a, 'b) Dtype.t;  (** Element type descriptor. *)
buffer : ('a, 'b) buffer;  (** Flat backing storage. *)
view : View.t;  (** Shape / strides / offset description of the storage. *)
}
(** [dtype t] is the element type descriptor stored in [t]. *)
let dtype t = t.dtype
(** [buffer t] is the flat backing storage of [t] (shared, not copied). *)
let buffer t = t.buffer
(** [view t] is the [View.t] stored in [t]. *)
let view t = t.view
(** [shape t] is [View.shape] of [t]'s view. *)
let shape t = View.shape t.view
(** [strides t] is [View.strides] of [t]'s view. *)
let strides t = View.strides t.view
(** [stride axis t] is [View.stride axis] applied to [t]'s view. *)
let stride axis t = View.stride axis t.view
(** [offset t] is [View.offset] of [t]'s view. *)
let offset t = View.offset t.view
(** [size t] is [View.numel] of [t]'s view; it returns the same value as
    [numel t]. *)
let size t = View.numel t.view
(** [numel t] is [View.numel] of [t]'s view. *)
let numel t = View.numel t.view
(** [dims t] is [View.shape] of [t]'s view; it returns the same value as
    [shape t]. *)
let dims t = View.shape t.view
(** [dim axis t] is [View.dim axis] applied to [t]'s view. *)
let dim axis t = View.dim axis t.view
(** [ndim t] is [View.ndim] of [t]'s view. *)
let ndim t = View.ndim t.view
(** [is_c_contiguous t] is [View.is_c_contiguous] of [t]'s view. *)
let is_c_contiguous t = View.is_c_contiguous t.view
(** [create_buffer_unsafe dt size_in_elements] allocates a one-dimensional
    C-layout bigarray of [size_in_elements] elements whose kind is
    [Dtype.to_bigarray_kind dt]. This function does not initialise the
    elements. *)
let create_buffer_unsafe (type a b) (dt : (a, b) Dtype.t)
    (size_in_elements : int) : (a, b) buffer =
  let kind = Dtype.to_bigarray_kind dt in
  Bigarray.Array1.create kind Bigarray.c_layout size_in_elements
(** [empty ctx dt shp] builds a tensor of shape [shp] and dtype [dt] in context
    [ctx]. Its buffer has one slot per element (the product of the entries of
    [shp]) and is not initialised here; its view is [View.create shp]. *)
let empty : type a b. context -> (a, b) Dtype.t -> int array -> (a, b) t =
 fun ctx dt shp ->
  (* Allocate storage first, then describe it; same order as before. *)
  let storage = shp |> Array.fold_left ( * ) 1 |> create_buffer_unsafe dt in
  let layout = View.create shp in
  { context = ctx; dtype = dt; buffer = storage; view = layout }
(** [full ctx dt shp value] builds a tensor like [empty ctx dt shp] and, when
    the shape holds at least one element, sets every buffer slot to [value]. *)
let full : type a b. context -> (a, b) Dtype.t -> int array -> a -> (a, b) t =
 fun ctx dt shp value ->
  let result = empty ctx dt shp in
  let element_count = Array.fold_left ( * ) 1 shp in
  (* Nothing to write for a zero-element shape. *)
  let () = if element_count > 0 then Array1.fill result.buffer value in
  result
(** [empty_like t] is [empty] called with the context, dtype and shape of
    [t]. *)
let empty_like { context; dtype; view; _ } =
  empty context dtype (View.shape view)
(** [copy t_src] returns a new tensor holding the elements of [t_src] in a
    freshly allocated buffer, described by a new [View.create] view of the same
    shape. Context and dtype are taken from [t_src]; the source is not
    modified.

    - A source with zero elements yields the new (empty) tensor directly.
    - A C-contiguous source is copied with one [Array1.blit] of the
      [View.numel] elements that start at the view's offset, so contiguous
      windows into a larger buffer also take the bulk path.
    - Any other layout is copied element by element through the source
      strides. *)
let copy : type a b. (a, b) t -> (a, b) t =
 fun t_src ->
  let src_view = view t_src in
  let src_buffer = buffer t_src in
  let src_offset = View.offset src_view in
  let total_elements = View.numel src_view in
  let new_buffer = create_buffer_unsafe (dtype t_src) total_elements in
  let new_view = View.create (View.shape src_view) in
  let new_t =
    {
      context = t_src.context;
      dtype = dtype t_src;
      buffer = new_buffer;
      view = new_view;
    }
  in
  if total_elements = 0 then new_t
  else if is_c_contiguous t_src then (
    (* NOTE(review): assumes [View.is_c_contiguous] means a dense row-major
       layout, i.e. the view occupies exactly
       [src_offset .. src_offset + total_elements - 1] of the buffer. *)
    Array1.blit (Array1.sub src_buffer src_offset total_elements) new_buffer;
    new_t)
  else
    let n_dims = View.ndim src_view in
    if n_dims = 0 then (
      (* Scalar view: a single element at the view offset. *)
      new_buffer.{View.offset new_view} <- src_buffer.{src_offset};
      new_t)
    else
      (* Loop-invariant layout data, looked up once instead of per element. *)
      let src_strides = View.strides src_view in
      let dst_strides = View.strides new_view in
      let dst_offset = View.offset new_view in
      (* Multi-dimensional index, mutated in place while walking the shape. *)
      let current_md_idx = Array.make n_dims 0 in
      let rec copy_slice dim =
        if dim = n_dims then
          let src_physical_idx =
            src_offset + Shape.ravel_index current_md_idx src_strides
          in
          let dst_physical_idx =
            dst_offset + Shape.ravel_index current_md_idx dst_strides
          in
          new_buffer.{dst_physical_idx} <- src_buffer.{src_physical_idx}
        else
          for i = 0 to View.dim dim src_view - 1 do
            current_md_idx.(dim) <- i;
            copy_slice (dim + 1)
          done
      in
      copy_slice 0;
      new_t
(** [fill value t_fill] sets every element addressed by [t_fill]'s view to
    [value], in place. Buffer slots that are not part of the view are left
    untouched.

    - A view with zero elements is a no-op.
    - A C-contiguous view is filled with one [Array1.fill] over the
      [View.numel] slots starting at the view's offset. Only that window is
      written, so a contiguous view over part of a larger buffer does not
      overwrite the rest of the buffer.
    - Any other layout is written element by element through the strides. *)
let fill : type a b. a -> (a, b) t -> unit =
 fun value t_fill ->
  let fill_view = view t_fill in
  let fill_buffer = buffer t_fill in
  let total_elements = View.numel fill_view in
  let base_offset = View.offset fill_view in
  if total_elements = 0 then ()
  else if is_c_contiguous t_fill then
    (* NOTE(review): assumes [View.is_c_contiguous] means a dense row-major
       layout, i.e. the view occupies exactly
       [base_offset .. base_offset + total_elements - 1] of the buffer. *)
    Array1.fill (Array1.sub fill_buffer base_offset total_elements) value
  else
    let n_dims = View.ndim fill_view in
    if n_dims = 0 then
      (* Scalar view: a single element at the view offset. *)
      fill_buffer.{base_offset} <- value
    else
      (* Strides do not change during the walk; look them up once. *)
      let fill_strides = View.strides fill_view in
      (* Multi-dimensional index, mutated in place while walking the shape. *)
      let current_md_idx = Array.make n_dims 0 in
      let rec fill_slice dim =
        if dim = n_dims then
          let physical_idx =
            base_offset + Shape.ravel_index current_md_idx fill_strides
          in
          fill_buffer.{physical_idx} <- value
        else
          for i = 0 to View.dim dim fill_view - 1 do
            current_md_idx.(dim) <- i;
            fill_slice (dim + 1)
          done
      in
      fill_slice 0
(** [blit src dst] copies every element addressed by [src]'s view into the
    position with the same multi-dimensional index in [dst]'s view, writing
    into [dst]'s buffer in place. Elements are copied one at a time, with the
    last dimension varying fastest.

    @raise Invalid_argument
      if [src] and [dst] differ in number of dimensions or in shape. *)
let blit : type a b. (a, b) t -> (a, b) t -> unit =
 fun src dst ->
  let src_view = view src in
  let dst_view = view dst in
  if View.ndim src_view <> View.ndim dst_view then
    invalid_arg "blit: tensors must have the same number of dimensions";
  if not (Shape.equal (View.shape src_view) (View.shape dst_view)) then
    invalid_arg "blit: tensors must have the same shape";
  let total_elements = View.numel src_view in
  if total_elements = 0 then ()
  else
    let src_buffer = buffer src in
    let dst_buffer = buffer dst in
    let src_offset = View.offset src_view in
    let dst_offset = View.offset dst_view in
    let n_dims = View.ndim src_view in
    if n_dims = 0 then
      (* Scalar views: one element on each side, at the view offsets. *)
      dst_buffer.{dst_offset} <- src_buffer.{src_offset}
    else
      (* Strides do not change during the walk; look them up once. *)
      let src_strides = View.strides src_view in
      let dst_strides = View.strides dst_view in
      (* Multi-dimensional index, mutated in place while walking the shape. *)
      let current_md_idx = Array.make n_dims 0 in
      let rec blit_slice dim =
        if dim = n_dims then
          let src_physical_offset =
            src_offset + Shape.ravel_index current_md_idx src_strides
          in
          let dst_physical_offset =
            dst_offset + Shape.ravel_index current_md_idx dst_strides
          in
          dst_buffer.{dst_physical_offset} <- src_buffer.{src_physical_offset}
        else
          for i = 0 to View.dim dim src_view - 1 do
            current_md_idx.(dim) <- i;
            blit_slice (dim + 1)
          done
      in
      blit_slice 0