Source file owl_slicing.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# 1 "src/owl/core/owl_slicing.ml"
open Bigarray
open Owl_types
include Owl_base_slicing
(** [get_fancy_array_typ axis x] allocates a fresh C-layout ndarray whose
    shape is [calc_slice_shape] of the checked slice definition, hands it to
    the slicing backend together with [x], and returns it.

    [axis] is first passed through [check_slice_definition] along with the
    dimensions of [x]. [Owl_slicing_basic.get] is used when
    [is_basic_slicing] holds for the checked definition, and
    [Owl_slicing_fancy.get] otherwise. *)
let get_fancy_array_typ axis x =
  let kind = Genarray.kind x in
  let dims_x = Genarray.dims x in
  let checked = check_slice_definition axis dims_x in
  let out = Genarray.create kind C_layout (calc_slice_shape checked) in
  let axis', x', out' = optimise_input_shape checked x out in
  (* The backend is chosen from the checked definition, not the optimised
     one; the backend itself receives the optimised triple. *)
  (match is_basic_slicing checked with
   | true -> Owl_slicing_basic.get kind axis' x' out'
   | false -> Owl_slicing_fancy.get kind axis' x' out');
  out
(** [get_fancy_array_typ_ axis x y] is the variant of the fancy getter that
    takes a caller-supplied destination [y] instead of allocating one.

    [axis] is passed through [check_slice_definition] with the dimensions of
    [x]; the backend ([Owl_slicing_basic.get] or [Owl_slicing_fancy.get]) is
    selected by [is_basic_slicing] on the checked definition.

    NOTE(review): unlike the [set_*] functions, no shape check is performed
    on [y] here -- presumably the caller must supply a correctly shaped
    destination; confirm against callers. *)
let get_fancy_array_typ_ axis x y =
  let kind = Genarray.kind x in
  let checked = check_slice_definition axis (Genarray.dims x) in
  let axis', x', y' = optimise_input_shape checked x y in
  match is_basic_slicing checked with
  | true -> Owl_slicing_basic.get kind axis' x' y'
  | false -> Owl_slicing_fancy.get kind axis' x' y'
(** [set_fancy_array_typ axis x y] passes [x], [y] and the checked slice
    definition to [Owl_slicing_basic.set] or [Owl_slicing_fancy.set],
    depending on [is_basic_slicing].

    The dimensions of [y] must equal [calc_slice_shape] of the checked
    definition; this is enforced with [assert], so it raises
    [Assert_failure] on mismatch (and is skipped when assertions are
    compiled out). *)
let set_fancy_array_typ axis x y =
  let kind = Genarray.kind x in
  let checked = check_slice_definition axis (Genarray.dims x) in
  let expected_shape = calc_slice_shape checked in
  assert (Genarray.dims y = expected_shape);
  let axis', x', y' = optimise_input_shape checked x y in
  match is_basic_slicing checked with
  | true -> Owl_slicing_basic.set kind axis' x' y'
  | false -> Owl_slicing_fancy.set kind axis' x' y'
(** [get_slice_array_typ axis x] allocates a fresh C-layout ndarray shaped
    by [calc_slice_shape] of the checked slice definition, passes it with
    [x] to [Owl_slicing_basic.get], and returns it.

    Unlike the fancy getter, this always uses the basic backend; no
    [is_basic_slicing] test is made. *)
let get_slice_array_typ axis x =
  let kind = Genarray.kind x in
  let checked = Genarray.dims x |> check_slice_definition axis in
  let out = calc_slice_shape checked |> Genarray.create kind C_layout in
  let axis', x', out' = optimise_input_shape checked x out in
  let () = Owl_slicing_basic.get kind axis' x' out' in
  out
(** [get_slice_array_typ_ axis x y] is the variant of the slice getter that
    takes a caller-supplied destination [y]; it always goes through
    [Owl_slicing_basic.get].

    NOTE(review): the shape of [y] is not checked here -- presumably the
    caller guarantees it matches the slice; confirm against callers. *)
let get_slice_array_typ_ axis x y =
  let kind = Genarray.kind x in
  let checked = Genarray.dims x |> check_slice_definition axis in
  let axis', x', y' = optimise_input_shape checked x y in
  Owl_slicing_basic.get kind axis' x' y'
(** [set_slice_array_typ axis x y] passes [x], [y] and the checked slice
    definition to [Owl_slicing_basic.set].

    The dimensions of [y] must equal [calc_slice_shape] of the checked
    definition; this is enforced with [assert], so it raises
    [Assert_failure] on mismatch (and is skipped when assertions are
    compiled out). *)
let set_slice_array_typ axis x y =
  let kind = Genarray.kind x in
  let checked = Genarray.dims x |> check_slice_definition axis in
  let expected_shape = calc_slice_shape checked in
  assert (Genarray.dims y = expected_shape);
  let axis', x', y' = optimise_input_shape checked x y in
  Owl_slicing_basic.set kind axis' x' y'
(** [get_fancy_list_typ axis x] converts [axis] with [sdlist_to_sdarray] and
    delegates to {!get_fancy_array_typ}. *)
let get_fancy_list_typ axis x =
  let axis = sdlist_to_sdarray axis in
  get_fancy_array_typ axis x
(** [get_fancy_list_typ_ axis x y] converts [axis] with [sdlist_to_sdarray]
    and delegates to {!get_fancy_array_typ_}, passing the caller-supplied
    [y] through. *)
let get_fancy_list_typ_ axis x y =
  let axis = sdlist_to_sdarray axis in
  get_fancy_array_typ_ axis x y
(** [get_fancy_ext_idx_typ axis x] converts [axis] with [sdarray_to_sdarray]
    and delegates to {!get_fancy_array_typ}. *)
let get_fancy_ext_idx_typ axis x =
  let axis = sdarray_to_sdarray axis in
  get_fancy_array_typ axis x
(** [set_fancy_list_typ axis x y] converts [axis] with [sdlist_to_sdarray]
    and delegates to {!set_fancy_array_typ}. *)
let set_fancy_list_typ axis x y =
  let axis = sdlist_to_sdarray axis in
  set_fancy_array_typ axis x y
(** [set_fancy_ext_idx_typ axis x y] converts [axis] with
    [sdarray_to_sdarray] and delegates to {!set_fancy_array_typ}. *)
let set_fancy_ext_idx_typ axis x y =
  let axis = sdarray_to_sdarray axis in
  set_fancy_array_typ axis x y
(** [get_slice_list_typ axis x] turns a list of int lists into an array of
    [R_] entries (one per inner list, in order) and delegates to
    {!get_slice_array_typ}. *)
let get_slice_list_typ axis x =
  let to_range bounds = R_ (Array.of_list bounds) in
  let ranges = axis |> List.map to_range |> Array.of_list in
  get_slice_array_typ ranges x
(** [get_slice_list_typ_ axis x y] turns a list of int lists into an array
    of [R_] entries (one per inner list, in order) and delegates to
    {!get_slice_array_typ_}, passing the caller-supplied [y] through. *)
let get_slice_list_typ_ axis x y =
  let to_range bounds = R_ (Array.of_list bounds) in
  let ranges = axis |> List.map to_range |> Array.of_list in
  get_slice_array_typ_ ranges x y
(** [get_slice_ext_idx_typ axis x] maps an array of int lists to an array of
    [R_] entries and delegates to {!get_slice_array_typ}. *)
let get_slice_ext_idx_typ axis x =
  let to_range bounds = R_ (Array.of_list bounds) in
  let ranges = Array.map to_range axis in
  get_slice_array_typ ranges x
(** [set_slice_list_typ axis x y] turns a list of int lists into an array of
    [R_] entries (one per inner list, in order) and delegates to
    {!set_slice_array_typ}. *)
let set_slice_list_typ axis x y =
  let to_range bounds = R_ (Array.of_list bounds) in
  let ranges = axis |> List.map to_range |> Array.of_list in
  set_slice_array_typ ranges x y
(** [set_slice_ext_idx_typ axis x y] maps an array of int lists to an array
    of [R_] entries and delegates to {!set_slice_array_typ}. *)
let set_slice_ext_idx_typ axis x y =
  let to_range bounds = R_ (Array.of_list bounds) in
  let ranges = Array.map to_range axis in
  set_slice_array_typ ranges x y