Source file ppx_helper.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
open Base
open Ppxlib
(** Short alias for the [longident] type brought into scope by the opens above. NOTE(review): it is
    not referenced by any definition visible in this part of the file. *)
type li = longident
(** [collect_list accu expr] flattens a chain of [::] applications into the list of its element
    expressions, in source order. [accu] carries the elements gathered so far, most recent first.
    A terminating [[]] is dropped; any other tail expression is kept as the final element. *)
let rec collect_list accu expr =
  match expr with
  | [%expr [%e? head] :: [%e? rest]] ->
      (* Tail call: remember the head and keep walking down the spine. *)
      collect_list (head :: accu) rest
  | [%expr []] -> List.rev accu
  | other -> List.rev (other :: accu)
(** Renders an axis specification for use in error messages: the axis kind, the literal syntax that
    denotes it, and its dimension, e.g. ["input (tuple) of dim 3"]. *)
let dim_spec_to_string dim_spec =
  let kind, dim =
    match dim_spec with
    | `Input_dims dim -> ("input (tuple)", dim)
    | `Output_dims dim -> ("output (list)", dim)
    | `Batch_dims dim -> ("batch (array)", dim)
  in
  kind ^ " of dim " ^ Int.to_string dim
(** [ndarray_constant expr] flattens a nested literal expression into the pieces needed to build an
    ndarray.

    Each nesting level of the literal is one axis, and the syntax used at that level selects the
    axis kind: a tuple is an input axis, a list is an output axis, an array is a batch axis. The
    shape is inferred by following the first element of every level; all other elements are then
    checked against that shape.

    Returns [(values, batch_dims, output_dims, input_dims)] where:
    - [values] is [expr] with its description replaced by an array literal holding the leaf
      expressions in traversal order. Float constants are kept as they are, integer constants are
      wrapped in [Float.of_int], and any other leaf expression is kept unchanged.
    - [batch_dims], [output_dims] and [input_dims] are lists of integer-constant expressions giving
      the sizes of the axes of each kind, outermost first.

    Nothing is raised on a malformed literal. Instead an error extension node is placed among the
    values, located at the offending sub-expression. *)
let ndarray_constant expr =
  let loc = expr.pexp_loc in
  (* Infer the axes by descending through the first element of each nesting level. Axes are
     accumulated innermost-first; the order is restored when building [dims_spec]. *)
  let rec loop_dims accu = function
    | { pexp_desc = Pexp_tuple (exp :: _ as exps); _ } ->
        loop_dims (`Input_dims (List.length exps) :: accu) exp
    | { pexp_desc = Pexp_array (exp :: _ as exps); _ } ->
        loop_dims (`Batch_dims (List.length exps) :: accu) exp
    | { pexp_desc = Pexp_tuple []; _ } -> `Input_dims 0 :: accu
    | { pexp_desc = Pexp_array []; _ } -> `Batch_dims 0 :: accu
    | { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ } as expr -> (
        let exps = collect_list [] expr in
        match exps with
        | exp :: _ -> loop_dims (`Output_dims (List.length exps) :: accu) exp
        | [] -> `Output_dims 0 :: accu)
    | _ -> accu
  in
  (* Outermost axis at index 0. *)
  let dims_spec = Array.of_list_rev @@ loop_dims [] expr in
  let open Ast_builder.Default in
  (* Error node for an axis-like expression found where a number was expected. *)
  let expected_number ~loc found =
    pexp_extension ~loc
    @@ Location.error_extensionf ~loc "Arrayjit: ndarray literal found %s, expected number" found
  in
  (* Error node for a nesting level whose kind or size disagrees with the inferred shape. *)
  let axis_mismatch ~loc ~got ~expected =
    pexp_extension ~loc
    @@ Location.error_extensionf ~loc
         "Arrayjit: ndarray literal axis mismatch, got %s, expected %s" (dim_spec_to_string got)
         (dim_spec_to_string expected)
  in
  (* Collect the leaves, most recent first, validating every level against [dims_spec]. Each
     error-producing arm binds the location of the sub-expression it inspects, so that diagnostics
     point at the offending element rather than at the whole literal. *)
  let rec loop_values depth accu expr =
    if depth >= Array.length dims_spec then
      (* Past the innermost axis: only scalar leaves are acceptable here. *)
      match expr with
      | { pexp_desc = Pexp_constant (Pconst_float _); _ } -> expr :: accu
      | { pexp_desc = Pexp_constant (Pconst_integer _); _ } ->
          [%expr Float.of_int [%e expr]] :: accu
      | { pexp_desc = Pexp_tuple _; pexp_loc = loc; _ } ->
          expected_number ~loc "input axis (tuple)" :: accu
      | { pexp_desc = Pexp_array _; pexp_loc = loc; _ } ->
          expected_number ~loc "batch axis (array)" :: accu
      | { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); pexp_loc = loc; _ } ->
          expected_number ~loc "output axis (list)" :: accu
      | expr -> expr :: accu
    else
      let descend exps = List.fold_left exps ~init:accu ~f:(loop_values (depth + 1)) in
      match expr with
      | { pexp_desc = Pexp_tuple exps; pexp_loc = loc; _ } -> (
          match dims_spec.(depth) with
          | `Input_dims dim when dim = List.length exps -> descend exps
          | expected -> axis_mismatch ~loc ~got:(`Input_dims (List.length exps)) ~expected :: accu)
      | { pexp_desc = Pexp_array exps; pexp_loc = loc; _ } -> (
          match dims_spec.(depth) with
          | `Batch_dims dim when dim = List.length exps -> descend exps
          | expected -> axis_mismatch ~loc ~got:(`Batch_dims (List.length exps)) ~expected :: accu)
      | { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); pexp_loc = loc; _ } -> (
          let exps = collect_list [] expr in
          match dims_spec.(depth) with
          | `Output_dims dim when dim = List.length exps -> descend exps
          | expected -> axis_mismatch ~loc ~got:(`Output_dims (List.length exps)) ~expected :: accu)
      | { pexp_loc = loc; _ } ->
          (pexp_extension ~loc
          @@ Location.error_extensionf ~loc
               "Arrayjit: ndarray literal: expected an axis (tuple, list or array)")
          :: accu
  in
  let result = loop_values 0 [] expr in
  (* Reuse the location and attributes of the original literal for the flattened array. *)
  let values = { expr with pexp_desc = Pexp_array (List.rev result) } in
  (* Split the axis sizes by kind; each list is built in reverse and flipped back below. *)
  let batch_dims, output_dims, input_dims =
    Array.fold dims_spec ~init:([], [], [])
      ~f:(fun (batch_dims, output_dims, input_dims) -> function
      | `Input_dims dim -> (batch_dims, output_dims, eint ~loc dim :: input_dims)
      | `Output_dims dim -> (batch_dims, eint ~loc dim :: output_dims, input_dims)
      | `Batch_dims dim -> (eint ~loc dim :: batch_dims, output_dims, input_dims))
  in
  (values, List.rev batch_dims, List.rev output_dims, List.rev input_dims)
(** [substitute_identifiers_in_einsum_spec ~loc str_input] builds an expression that evaluates to
    [str_input] with every identifier that is followed by a star replaced by the decimal rendering
    of the integer variable of that name.

    An identifier starts with a letter or an underscore. If [str_input] contains a comma,
    identifiers may continue with letters, digits and underscores; otherwise every identifier is
    exactly one character long. Spaces, tabs and line breaks may separate the identifier from the
    star. Only the identifier itself is replaced: that whitespace and the star stay in the output.

    The generated expression is an empty string literal when there is nothing to emit, the single
    segment when there is only one, and a [String.concat ~sep:""] over all segments otherwise.

    Example: [substitute_identifiers_in_einsum_spec ~loc "a *x + b * y"] generates an expression
    equivalent to [String.concat ~sep:"" [ Int.to_string a; " *x + "; Int.to_string b; " * y" ]],
    so with [a = 2] and [b = 3] it evaluates to ["2 *x + 3 * y"]. *)
let substitute_identifiers_in_einsum_spec ~loc str_input =
  let multichar = String.contains str_input ',' in
  let open Ast_builder.Default in
  let len = String.length str_input in
  let is_identifier_start c = Char.is_alpha c || Char.equal c '_' in
  let is_identifier_char c = Char.is_alphanum c || Char.equal c '_' in
  let is_blank = function ' ' | '\t' | '\n' | '\r' -> true | _ -> false in
  (* First index at or after [i] whose character does not satisfy [pred], or [len]. *)
  let rec skip_while pred i =
    if i < len && pred str_input.[i] then skip_while pred (i + 1) else i
  in
  (* Index just past the identifier starting at [start]; [start] is known to hold a valid
     identifier start, so in single-character mode the identifier is that one character. *)
  let identifier_stop start =
    if multichar then skip_while is_identifier_char start else start + 1
  in
  (* Gather [(first, last, name)] for every identifier that is followed by a star. *)
  let rec scan i found =
    if i >= len then List.rev found
    else if not (is_identifier_start str_input.[i]) then scan (i + 1) found
    else
      let past_name = identifier_stop i in
      let past_blank = skip_while is_blank past_name in
      let found =
        if past_blank < len && Char.equal str_input.[past_blank] '*' then
          (i, past_name - 1, String.sub str_input ~pos:i ~len:(past_name - i)) :: found
        else found
      in
      (* Resume at the first character that was not consumed, so it is examined in turn. *)
      scan past_blank found
  in
  let substitutions = scan 0 [] in
  (* Prepend the literal text in the half-open range from [from] to [until], if non-empty. *)
  let literal ~from ~until rev_segments =
    if until > from then
      estring ~loc (String.sub str_input ~pos:from ~len:(until - from)) :: rev_segments
    else rev_segments
  in
  let tail_start, rev_segments =
    List.fold substitutions ~init:(0, [])
      ~f:(fun (pos, rev_segments) (start_pos, end_pos, identifier) ->
        let rev_segments = literal ~from:pos ~until:start_pos rev_segments in
        let value = pexp_ident ~loc (Located.mk ~loc (Lident identifier)) in
        (end_pos + 1, [%expr Int.to_string [%e value]] :: rev_segments))
  in
  match List.rev (literal ~from:tail_start ~until:len rev_segments) with
  | [] -> estring ~loc ""
  | [ only ] -> only
  | several -> [%expr String.concat ~sep:"" [%e elist ~loc several]]