Source file owl_slicing.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# 1 "src/owl/core/owl_slicing.ml"
open Bigarray
open Owl_types
include Owl_base_slicing
(** [get_fancy_array_typ axis x] returns a freshly allocated C-layout array,
    of the same kind as [x], holding the part of [x] selected by the slice
    definition [axis].

    [axis] is first passed through [check_slice_definition] against the dims
    of [x]. The shape of the result is given by [calc_slice_shape] on the
    checked definition. The copy is performed by [Owl_slicing_basic.get] when
    [is_basic_slicing] holds for the checked definition, and by
    [Owl_slicing_fancy.get] otherwise. Both kernels receive the operands as
    rewritten by [optimise_input_shape]. *)
let get_fancy_array_typ axis x =
  let kind_x = Genarray.kind x in
  let checked = check_slice_definition axis (Genarray.dims x) in
  let out = Genarray.create kind_x C_layout (calc_slice_shape checked) in
  (* The kernels write through [y'], yet [out] is what gets returned. This
     presumes [optimise_input_shape] hands back a view sharing [out]'s
     storage; that helper is defined elsewhere, so verify there. *)
  let axis', x', y' = optimise_input_shape checked x out in
  if is_basic_slicing checked then Owl_slicing_basic.get kind_x axis' x' y'
  else Owl_slicing_fancy.get kind_x axis' x' y';
  out
(** [get_fancy_array_typ_ axis x y] is the destination-passing form of
    [get_fancy_array_typ]. It writes the part of [x] selected by [axis] into
    the caller-supplied [y] and allocates nothing itself.

    [axis] is checked against the dims of [x] with [check_slice_definition].
    The copy is done by [Owl_slicing_basic.get] when [is_basic_slicing] holds
    for the checked definition, and by [Owl_slicing_fancy.get] otherwise. The
    result is whatever the chosen kernel returns.

    @raise Assert_failure if the dims of [y] differ from
    [calc_slice_shape] of the checked definition (unless compiled with
    [-noassert]). This is the same check [set_fancy_array_typ] applies to
    its caller-supplied array. *)
let get_fancy_array_typ_ axis x y =
  let _kind = Genarray.kind x in
  let sx = Genarray.dims x in
  let axis = check_slice_definition axis sx in
  (* [y] is the kernel's destination. Reject a mis-shaped buffer here;
     nothing else in this function checks its shape before the kernel
     writes into it. *)
  let sy = calc_slice_shape axis in
  assert (Genarray.dims y = sy);
  let axis', x', y' = optimise_input_shape axis x y in
  if is_basic_slicing axis then Owl_slicing_basic.get _kind axis' x' y'
  else Owl_slicing_fancy.get _kind axis' x' y'
(** [set_fancy_array_typ axis x y] assigns [y] to the part of [x] selected by
    the slice definition [axis]. The write is performed by the [set] kernels.

    [axis] is checked against the dims of [x] with [check_slice_definition].
    [y] must have exactly the shape [calc_slice_shape] gives for the checked
    definition. [Owl_slicing_basic.set] is used when [is_basic_slicing]
    holds, and [Owl_slicing_fancy.set] otherwise. Both receive the operands
    as rewritten by [optimise_input_shape].

    @raise Assert_failure if the dims of [y] do not match the slice shape. *)
let set_fancy_array_typ axis x y =
  let kind_x = Genarray.kind x in
  let checked = check_slice_definition axis (Genarray.dims x) in
  let expected = calc_slice_shape checked in
  assert (expected = Genarray.dims y);
  let axis', x', y' = optimise_input_shape checked x y in
  match is_basic_slicing checked with
  | true -> Owl_slicing_basic.set kind_x axis' x' y'
  | false -> Owl_slicing_fancy.set kind_x axis' x' y'
(** [get_slice_array_typ axis x] returns a freshly allocated C-layout array,
    of the same kind as [x], holding the part of [x] selected by [axis].

    Unlike [get_fancy_array_typ], this always uses [Owl_slicing_basic.get];
    no [is_basic_slicing] test is made here. The result shape is
    [calc_slice_shape] of the definition checked by
    [check_slice_definition]. *)
let get_slice_array_typ axis x =
  let kind_x = Genarray.kind x in
  let checked = check_slice_definition axis (Genarray.dims x) in
  let out = Genarray.create kind_x C_layout (calc_slice_shape checked) in
  (* The kernel writes through [y']. Returning [out] presumes [y'] shares its
     storage (property of [optimise_input_shape], defined elsewhere). *)
  let axis', x', y' = optimise_input_shape checked x out in
  Owl_slicing_basic.get kind_x axis' x' y';
  out
(** [get_slice_array_typ_ axis x y] is the destination-passing form of
    [get_slice_array_typ]. [Owl_slicing_basic.get] writes the part of [x]
    selected by [axis] into the caller-supplied [y], and nothing is
    allocated here. The result is whatever that kernel returns.

    @raise Assert_failure if the dims of [y] differ from
    [calc_slice_shape] of the definition checked by
    [check_slice_definition] (unless compiled with [-noassert]). This is the
    same check [set_slice_array_typ] applies to its caller-supplied array. *)
let get_slice_array_typ_ axis x y =
  let _kind = Genarray.kind x in
  let sx = Genarray.dims x in
  let axis = check_slice_definition axis sx in
  (* [y] is the kernel's destination. Reject a mis-shaped buffer here;
     nothing else in this function checks its shape before the kernel
     writes into it. *)
  let sy = calc_slice_shape axis in
  assert (Genarray.dims y = sy);
  let axis', x', y' = optimise_input_shape axis x y in
  Owl_slicing_basic.get _kind axis' x' y'
(** [set_slice_array_typ axis x y] assigns [y] to the part of [x] selected by
    [axis], always through [Owl_slicing_basic.set]. No [is_basic_slicing]
    test is made here.

    @raise Assert_failure if the dims of [y] differ from [calc_slice_shape]
    of the definition checked by [check_slice_definition]. *)
let set_slice_array_typ axis x y =
  let kind_x = Genarray.kind x in
  let checked = check_slice_definition axis (Genarray.dims x) in
  let expected = calc_slice_shape checked in
  assert (expected = Genarray.dims y);
  let axis', x', y' = optimise_input_shape checked x y in
  Owl_slicing_basic.set kind_x axis' x' y'
(** List-based front end to [get_fancy_array_typ]. The slice definition is
    converted with [sdlist_to_sdarray] before delegating. *)
let get_fancy_list_typ axis x =
  let axis_arr = sdlist_to_sdarray axis in
  get_fancy_array_typ axis_arr x
(** List-based front end to [get_fancy_array_typ_] (destination [y]). The
    slice definition is converted with [sdlist_to_sdarray] first. *)
let get_fancy_list_typ_ axis x y =
  let axis_arr = sdlist_to_sdarray axis in
  get_fancy_array_typ_ axis_arr x y
(** List-based front end to [set_fancy_array_typ]. The slice definition is
    converted with [sdlist_to_sdarray] first. *)
let set_fancy_list_typ axis x y =
  let axis_arr = sdlist_to_sdarray axis in
  set_fancy_array_typ axis_arr x y
(** [get_slice_list_typ axis x] is [get_slice_array_typ] for a slice
    definition given as a list of int lists. Each inner list [l] becomes
    [R_ (Array.of_list l)], in the same order, before delegating. *)
let get_slice_list_typ axis x =
  let ranges = Array.map (fun l -> R_ (Array.of_list l)) (Array.of_list axis) in
  get_slice_array_typ ranges x
(** [get_slice_list_typ_ axis x y] is [get_slice_array_typ_] (destination
    [y]) for a slice definition given as a list of int lists. Each inner
    list becomes an [R_] of the corresponding array. *)
let get_slice_list_typ_ axis x y =
  let ranges = Array.map (fun l -> R_ (Array.of_list l)) (Array.of_list axis) in
  get_slice_array_typ_ ranges x y
(** [set_slice_list_typ axis x y] is [set_slice_array_typ] for a slice
    definition given as a list of int lists. Each inner list becomes an [R_]
    of the corresponding array. *)
let set_slice_list_typ axis x y =
  let ranges = Array.map (fun l -> R_ (Array.of_list l)) (Array.of_list axis) in
  set_slice_array_typ ranges x y