1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
(** View tracking for lazy tensor operations. *)

(** A tracker is an ordered list of views. New views are appended at the end
    of the list (see [add_view]), so the last element is the view that was
    applied most recently; it is the one consulted for the current shape and
    offset. *)
type t = { views : View.t list }
(** [views_to_real_strides views] tries to describe a stack of views with a
    single stride array.

    - An empty list gives [None].
    - A single view gives its own strides.
    - Otherwise adjacent views are merged wherever [View.merge] succeeds. If
      exactly one view remains, its strides are returned; if several remain,
      the strides of the last remaining view are returned, or [None] when
      that stride array is empty. *)
let views_to_real_strides views =
  (* Merge neighbouring views left to right; a pair that cannot be merged is
     left in place and merging resumes from its second element. *)
  let rec collapse = function
    | ([] | [ _ ]) as vs -> vs
    | a :: b :: tl -> (
        match View.merge a b with
        | Some ab -> collapse (ab :: tl)
        | None -> a :: collapse (b :: tl))
  in
  (* Strides of the final view, treating an empty stride array as absent. *)
  let strides_of_last vs =
    match List.rev vs with
    | [] -> None
    | last :: _ ->
        let s = View.strides last in
        if Array.length s > 0 then Some s else None
  in
  match views with
  | [] -> None
  | [ only ] -> Some (View.strides only)
  | _ -> (
      match collapse views with
      | [ only ] -> Some (View.strides only)
      | remaining -> strides_of_last remaining)
(** [create shape] is a tracker holding exactly one view, the one returned by
    [View.create shape]. *)
let create shape = { views = [ View.create shape ] }
(** [create_strided shape ~strides ~offset] is a tracker holding exactly one
    view, built by [View.create] from [shape] with the given [strides] and
    [offset]. *)
let create_strided shape ~strides ~offset =
  { views = [ View.create ~strides ~offset shape ] }
(** [shape t] is the shape of the last view of [t].

    An empty view list is reported through [Error.failed] with the operation
    name ["view_tracker.shape"]. *)
let shape t =
  match List.rev t.views with
  | last :: _ -> View.shape last
  | [] -> Error.failed ~op:"view_tracker.shape" ~what:"empty views list" ()
(** [ndim t] is the rank of [shape t], as computed by
    [Symbolic_shape.rank]. *)
let ndim t = t |> shape |> Symbolic_shape.rank
(** [numel t] is the product of all dimensions of [shape t], as a symbolic
    dimension.

    A rank-0 shape gives [Symbolic_shape.static 1]. Otherwise the first
    dimension seeds the product and the remaining ones are multiplied in, in
    order. *)
let numel t =
  let dims = shape t in
  let rank = Symbolic_shape.rank dims in
  (* When both operands evaluate to concrete integers the product is stored
     as a static dimension; otherwise it stays symbolic. *)
  let times acc dim =
    match (Symbolic_shape.eval_dim acc, Symbolic_shape.eval_dim dim) with
    | Some a, Some b -> Symbolic_shape.static (a * b)
    | _ -> Symbolic_shape.mul acc dim
  in
  let rec product i acc =
    if i >= rank then acc else product (i + 1) (times acc dims.(i))
  in
  if rank = 0 then Symbolic_shape.static 1 else product 1 dims.(0)
(** [offset t] is the offset of the last view of [t], wrapped as a static
    symbolic dimension with [Symbolic_shape.static].

    A tracker with no views yields [Symbolic_shape.static 0] instead of
    failing. *)
let offset t =
  match List.rev t.views with
  | [] -> Symbolic_shape.static 0
  | last :: _ ->
      (* The view's offset is passed straight through; no case analysis on
         its value is needed. *)
      Symbolic_shape.static (View.offset last)
(** [simplify t] is [t] with its view list reduced: adjacent views are merged
    wherever [View.merge] succeeds, and [View.simplify] is then applied to
    every view that remains. The order of the remaining views is kept. *)
let simplify t =
  (* A pair that cannot be merged is left in place and merging resumes from
     its second element. *)
  let rec fuse = function
    | ([] | [ _ ]) as vs -> vs
    | a :: b :: tl -> (
        match View.merge a b with
        | Some ab -> fuse (ab :: tl)
        | None -> a :: fuse (b :: tl))
  in
  { views = t.views |> fuse |> List.map View.simplify }

(** [is_contiguous t] is [true] exactly when [simplify t] consists of a single
    view and [View.is_c_contiguous] holds for it. A tracker that is empty or
    that still has several views after simplification gives [false]. *)
let is_contiguous t =
  match (simplify t).views with
  | [ view ] -> View.is_c_contiguous view
  | _ -> false
(** [add_view view t] is a new tracker whose view list is that of [t] with
    [view] appended at the end. [t] itself is not modified. *)
let add_view view t = { views = List.append t.views [ view ] }
(** [get_last_view t] is the final element of the view list of [t].

    An empty view list is reported through [Error.failed] with the operation
    name ["view_tracker"]. *)
let get_last_view t =
  let rec last = function
    | [] -> Error.failed ~op:"view_tracker" ~what:"empty views list" ()
    | [ view ] -> view
    | _ :: rest -> last rest
  in
  last t.views
(** [reshape new_shape t] applies [View.reshape] to the last view of [t] with
    [new_shape] and appends the resulting view to the tracker.

    Fails through [get_last_view] when [t] has no views. *)
let reshape new_shape t =
  let last = get_last_view t in
  t |> add_view (View.reshape last new_shape)
(** [permute axes t] applies [View.permute] to the last view of [t] with
    [axes] and appends the resulting view to the tracker.

    Fails through [get_last_view] when [t] has no views. *)
let permute axes t =
  let last = get_last_view t in
  t |> add_view (View.permute last axes)
(** [expand new_shape t] applies [View.expand] to the last view of [t] with
    [new_shape] and appends the resulting view to the tracker.

    Fails through [get_last_view] when [t] has no views. *)
let expand new_shape t =
  let last = get_last_view t in
  t |> add_view (View.expand last new_shape)
(** [shrink bounds t] applies [View.shrink] to the last view of [t] with
    [bounds] and appends the resulting view to the tracker.

    Fails through [get_last_view] when [t] has no views. *)
let shrink bounds t =
  let last = get_last_view t in
  t |> add_view (View.shrink last bounds)
(** [pad padding t] applies [View.pad] to the last view of [t] with [padding]
    and appends the resulting view to the tracker.

    Fails through [get_last_view] when [t] has no views. *)
let pad padding t =
  let last = get_last_view t in
  t |> add_view (View.pad last padding)
(** [flip axes_to_flip t] applies [View.flip] to the last view of [t] with
    [axes_to_flip] and appends the resulting view to the tracker.

    Fails through [get_last_view] when [t] has no views. *)
let flip axes_to_flip t =
  let last = get_last_view t in
  t |> add_view (View.flip last axes_to_flip)
(** [strides t] is a stride array for [t], computed on [simplify t].

    The result of [views_to_real_strides] is used when it is [Some _];
    otherwise the strides of the last simplified view are returned. The
    result is therefore [None] only when the simplified view list is
    empty. *)
let strides t =
  let simplified_views = (simplify t).views in
  match views_to_real_strides simplified_views with
  | Some _ as found -> found
  | None -> (
      match List.rev simplified_views with
      | [] -> None
      | last :: _ -> Some (View.strides last))
(** [can_get_strides t] is [true] exactly when [strides t] returns
    [Some _]. *)
let can_get_strides t =
  match strides t with
  | Some _ -> true
  | None -> false
(** [is_materializable t] is [true] when [t] has at least one view and
    [Symbolic_shape.is_static] holds for [shape t]. A tracker with no views
    gives [false]. *)
let is_materializable t =
  match t.views with
  | [] -> false
  | _ :: _ -> Symbolic_shape.is_static (shape t)
(** [compose t] tries to collapse every view of [t] into one by merging them
    left to right with [View.merge].

    Returns [None] for an empty tracker or as soon as one merge fails. A
    tracker with a single view returns that view unchanged, without calling
    [View.merge]. *)
let compose t =
  (* [acc] is the merge of all views seen so far. *)
  let rec merge_all acc = function
    | [] -> Some acc
    | next :: remaining -> (
        match View.merge acc next with
        | Some merged -> merge_all merged remaining
        | None -> None)
  in
  match t.views with
  | [] -> None
  | first :: rest -> merge_all first rest