Source file owl_slicing.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# 1 "src/owl/core/owl_slicing.ml"
(*
 * OWL - OCaml Scientific and Engineering Computing
 * Copyright (c) 2016-2019 Liang Wang <liang.wang@cl.cam.ac.uk>
 *)

open Bigarray

open Owl_types

(* include the base implementation of slicing function *)

include Owl_base_slicing


(* fancy slicing function *)

(* Extract the part of x selected by the slice definition axis into a newly
   created C-layout ndarray of the same kind as x, and return that ndarray. *)
let get_fancy_array_typ axis x =
  let elt_kind = Genarray.kind x in
  (* validate the definition against the dimensions of x and re-format it *)
  let slice = check_slice_definition axis (Genarray.dims x) in
  (* the destination is allocated with the shape computed for the slice *)
  let y = Genarray.create elt_kind C_layout (calc_slice_shape slice) in
  (* optimise the shape if possible *)
  let slice', x', y' = optimise_input_shape slice x y in
  (* The basic-vs-fancy decision is taken on the checked definition, not on
     the optimised one; only the kernel call differs between the two cases.
     NOTE(review): y is returned while y' is handed to the kernel, so y'
     presumably shares storage with y - confirm in optimise_input_shape. *)
  (if is_basic_slicing slice
   then Owl_slicing_basic.get elt_kind slice' x' y'
   else Owl_slicing_fancy.get elt_kind slice' x' y');
  y


(* Variant of get_fancy_array_typ that takes the destination y from the
   caller instead of allocating one. The result is whatever the selected
   kernel returns. This function itself performs no shape check on y. *)
let get_fancy_array_typ_ axis x y =
  let elt_kind = Genarray.kind x in
  (* validate the definition against the dimensions of x and re-format it *)
  let slice = check_slice_definition axis (Genarray.dims x) in
  (* optimise the shape if possible *)
  let slice', x', y' = optimise_input_shape slice x y in
  (* dispatch on the checked definition: basic slicing or fancy indexing *)
  if is_basic_slicing slice
  then Owl_slicing_basic.get elt_kind slice' x' y'
  else Owl_slicing_fancy.get elt_kind slice' x' y'


(* Setter counterpart of get_fancy_array_typ: passes x, y and the slice
   definition axis to the matching set kernel. The dimensions of y must be
   equal to the shape computed for the slice; this is enforced with assert. *)
let set_fancy_array_typ axis x y =
  let elt_kind = Genarray.kind x in
  (* validate the definition against the dimensions of x and re-format it *)
  let slice = check_slice_definition axis (Genarray.dims x) in
  (* the shape of y has to agree with the shape of the slice *)
  let expected_shape = calc_slice_shape slice in
  assert (Genarray.dims y = expected_shape);
  (* optimise the shape if possible *)
  let slice', x', y' = optimise_input_shape slice x y in
  (* dispatch on the checked definition: basic slicing or fancy indexing *)
  if is_basic_slicing slice
  then Owl_slicing_basic.set elt_kind slice' x' y'
  else Owl_slicing_fancy.set elt_kind slice' x' y'


(* Basic slicing function *)

(* Basic-slicing getter: always uses Owl_slicing_basic.get, without the
   is_basic_slicing test done by the fancy variants. Returns a newly created
   C-layout ndarray of the same kind as x. *)
let get_slice_array_typ axis x =
  let elt_kind = Genarray.kind x in
  (* validate the definition against the dimensions of x and re-format it *)
  let slice = check_slice_definition axis (Genarray.dims x) in
  (* the destination is allocated with the shape computed for the slice *)
  let y = Genarray.create elt_kind C_layout (calc_slice_shape slice) in
  (* optimise the shape if possible, then run the kernel on the result *)
  let slice', x', y' = optimise_input_shape slice x y in
  Owl_slicing_basic.get elt_kind slice' x' y';
  y


(* Variant of get_slice_array_typ that takes the destination y from the
   caller instead of allocating one. The result is whatever
   Owl_slicing_basic.get returns. This function itself performs no shape
   check on y. *)
let get_slice_array_typ_ axis x y =
  let elt_kind = Genarray.kind x in
  (* validate the definition against the dimensions of x and re-format it *)
  let slice = check_slice_definition axis (Genarray.dims x) in
  (* optimise the shape if possible, then run the kernel on the result *)
  let slice', x', y' = optimise_input_shape slice x y in
  Owl_slicing_basic.get elt_kind slice' x' y'


(* Basic-slicing setter: always uses Owl_slicing_basic.set. The dimensions
   of y must be equal to the shape computed for the slice; this is enforced
   with assert. *)
let set_slice_array_typ axis x y =
  let elt_kind = Genarray.kind x in
  (* validate the definition against the dimensions of x and re-format it *)
  let slice = check_slice_definition axis (Genarray.dims x) in
  (* the shape of y has to agree with the shape of the slice *)
  let expected_shape = calc_slice_shape slice in
  assert (Genarray.dims y = expected_shape);
  (* optimise the shape if possible, then run the kernel on the result *)
  let slice', x', y' = optimise_input_shape slice x y in
  Owl_slicing_basic.set elt_kind slice' x' y'


(* same as get_fancy_array_typ but takes the slice definition in list form;
   the list is converted with sdlist_to_sdarray before the call *)
let get_fancy_list_typ axis x =
  let definition = sdlist_to_sdarray axis in
  get_fancy_array_typ definition x


(* same as get_fancy_array_typ_ but takes the slice definition in list form;
   the list is converted with sdlist_to_sdarray before the call *)
let get_fancy_list_typ_ axis x y =
  let definition = sdlist_to_sdarray axis in
  get_fancy_array_typ_ definition x y


(* same as set_fancy_array_typ but takes the slice definition in list form;
   the list is converted with sdlist_to_sdarray before the call *)
let set_fancy_list_typ axis x y =
  let definition = sdlist_to_sdarray axis in
  set_fancy_array_typ definition x y


(* simplified get_slice function which accepts a list of lists as slice
   definition: each inner list is turned into an array and wrapped in R_,
   and the resulting array is passed to get_slice_array_typ *)
let get_slice_list_typ axis x =
  let ranges =
    Array.of_list axis
    |> Array.map (fun range -> R_ (Array.of_list range))
  in
  get_slice_array_typ ranges x


(* list-of-lists front end of get_slice_array_typ_: each inner list is turned
   into an array and wrapped in R_ before the call is forwarded together with
   the caller-supplied y *)
let get_slice_list_typ_ axis x y =
  let ranges =
    Array.of_list axis
    |> Array.map (fun range -> R_ (Array.of_list range))
  in
  get_slice_array_typ_ ranges x y


(* simplified set_slice function which accepts a list of lists as slice
   definition: each inner list is turned into an array and wrapped in R_,
   and the resulting array is passed to set_slice_array_typ *)
let set_slice_list_typ axis x y =
  let ranges =
    Array.of_list axis
    |> Array.map (fun range -> R_ (Array.of_list range))
  in
  set_slice_array_typ ranges x y



(* ends here *)