Source file ppx_rapper.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
open Base
open Ppxlib
module Buildef = Ast_builder.Default
(** Parse the bare identifiers that follow the SQL string in
    [[%rapper get_many "SELECT * FROM USERS" record_in record_out]].

    Recognised arguments are [record_in], [record_out], [function_out] and
    [syntax_off]; their order does not matter.

    @return [Ok (input_kind, output_kind, syntax_off)] where [input_kind] is
    [`Record] when [record_in] is present and [`Labelled_args] otherwise;
    [output_kind] is [`Record] for [record_out], [`Function] for
    [function_out] and [`Tuple] when neither is present; and [syntax_off] is
    [true] iff [syntax_off] is present.
    [Error msg] when an argument is not recognised, or when [record_out] and
    [function_out] are both given (they select incompatible output shapes). *)
let parse_args args =
  let allowed_args =
    [ "record_in"; "record_out"; "function_out"; "syntax_off" ]
  in
  let has arg = List.mem args arg ~equal:String.equal in
  match
    List.find args ~f:(fun a ->
        not (List.mem allowed_args a ~equal:String.equal))
  with
  | Some unknown ->
      Error (Printf.sprintf "Unknown rapper argument '%s'" unknown)
  | None -> (
      let input_kind = if has "record_in" then `Record else `Labelled_args in
      let syntax_off = has "syntax_off" in
      match (has "record_out", has "function_out") with
      | true, true ->
          (* Report a user error instead of crashing the preprocessor. *)
          Error "record_out and function_out cannot be used together"
      | true, false -> Ok (input_kind, `Record, syntax_off)
      | false, true -> Ok (input_kind, `Function, syntax_off)
      | false, false -> Ok (input_kind, `Tuple, syntax_off))
(** Build the three expressions used by the non-[%list] code path, in
    order: the Caqti type of the query's input parameters, the Caqti type
    of its output parameters, and the SQL text as a string literal. *)
let component_expressions ~loc parsed_query =
  let { Query.in_params; out_params; sql; _ } = parsed_query in
  (* Built in the same order as before so any error surfaces identically. *)
  let inputs = Codegen.make_caqti_type_tup ~loc in_params in
  let outputs = Codegen.make_caqti_type_tup ~loc out_params in
  (inputs, outputs, Buildef.estring ~loc sql)
(** Build [(expand_get, expand_exec)], the two generators used by
    [expand_apply].

    [expand_get request_op make_function] yields the expression for
    [get_one], [get_opt] and [get_many]; [expand_exec request_op make_function]
    yields the one for [execute]. [request_op] is the [Caqti_request.Infix]
    operator expression to apply and [make_function] is the [Codegen]
    generator for the wrapping function. Both return [Error msg] when
    generating the wrapping function raises [Codegen.Error msg].

    When the query has a [%list] group, the SQL text is assembled at run time:
    the group's sub-query is repeated once per list element and the input
    Caqti type is built with [Rapper.Internal.Dynparam]. Because the SQL
    string then differs with the list length and is rebuilt on every call,
    both the get and the execute requests are created with [~oneshot:true]
    (see the [Caqti_request] documentation of [oneshot]).

    Raises a located error (via [Location.raise_errorf]) when the [%list]
    group does not contain exactly one parameter. *)
let make_expand_get_and_exec_expression ~loc parsed_query input_kind output_kind
    =
  let { Query.sql; in_params; out_params; list_params } = parsed_query in
  match list_params with
  | Some { subsql; string_index; param_index; params } ->
      (* Only a single parameter may appear inside a %list group. *)
      let list_param =
        match params with
        | [ param ] -> param
        | _ ->
            Location.raise_errorf ~loc
              "Error in ppx_rapper: %%list only supports one input parameter \
               currently"
      in
      let subsql_expr = Buildef.estring ~loc subsql in
      (* The SQL is split at the %list placeholder; the repeated sub-query is
         spliced in between at run time. *)
      let sql_before =
        Buildef.estring ~loc @@ String.sub sql ~pos:0 ~len:string_index
      in
      let sql_after =
        Buildef.estring ~loc
        @@ String.sub sql ~pos:string_index
             ~len:(String.length sql - string_index)
      in
      let params_before, params_after = List.split_n in_params param_index in
      let expression_contents =
        {
          Codegen.in_params = params_before @ params @ params_after;
          out_params;
          input_kind;
          output_kind;
        }
      in
      (* Input type: the static parameters around the list, with the
         dynamically packed list type ([packed_list_type]) in its place. *)
      let caqti_input_type =
        let exprs_before =
          List.map ~f:(Codegen.caqti_type_of_param ~loc) params_before
        in
        let exprs_after =
          List.map ~f:(Codegen.caqti_type_of_param ~loc) params_after
        in
        match (List.is_empty params_before, List.is_empty params_after) with
        | true, true -> [%expr packed_list_type]
        | true, false ->
            let expression =
              Codegen.caqti_type_tup_of_expressions ~loc
                ([%expr packed_list_type] :: exprs_after)
            in
            [%expr Caqti_type.([%e expression])]
        | false, true ->
            let expression =
              Codegen.caqti_type_tup_of_expressions ~loc
                (exprs_before @ [ [%expr packed_list_type] ])
            in
            [%expr Caqti_type.([%e expression])]
        | false, false ->
            let expression =
              Codegen.caqti_type_tup_of_expressions ~loc
                (exprs_before @ [ [%expr packed_list_type] ] @ exprs_after)
            in
            [%expr Caqti_type.([%e expression])]
      in
      let outputs_caqti_type = Codegen.make_caqti_type_tup ~loc out_params in
      let make_generic make_function query_expr =
        let body_fn body =
          let base =
            [%expr
              match
                [%e
                  Buildef.pexp_ident ~loc
                    (Codegen.lident_of_param ~loc list_param)]
              with
              | [] ->
                  Rapper_helper.fail
                    Caqti_error.(
                      encode_rejected ~uri:Uri.empty ~typ:Caqti_type.unit
                        (Msg "Empty list"))
              | elems ->
                  let subsqls =
                    Stdlib.List.map (fun _ -> [%e subsql_expr]) elems
                  in
                  let patch = Stdlib.String.concat ", " subsqls in
                  let sql = [%e sql_before] ^ patch ^ [%e sql_after] in
                  let open Rapper.Internal in
                  let (Dynparam.Pack
                        ( packed_list_type,
                          [%p Codegen.ppat_of_param ~loc list_param] )) =
                    Stdlib.List.fold_left
                      (fun pack item ->
                        Dynparam.add
                          (Caqti_type.(
                             [%e
                               Codegen.make_caqti_type_tup ~loc [ list_param ]])
                          [@ocaml.warning "-33"])
                          item pack)
                      Dynparam.empty elems
                  in
                  let query = [%e query_expr] in
                  [%e body]]
          in
          match output_kind with
          | `Function -> [%expr fun loaders -> [%e base]]
          | _ -> base
        in
        match output_kind with
        | `Function ->
            [%expr
              let wrapped =
                [%e make_function ~body_fn ~loc expression_contents]
              in
              wrapped loaders]
        | _ ->
            [%expr
              let wrapped =
                [%e make_function ~body_fn ~loc expression_contents]
              in
              wrapped]
      in
      let expand_get caqti_request_function_expr make_function =
        try
          Ok
            (make_generic make_function
               [%expr
                 [%e caqti_request_function_expr]
                   ~oneshot:true ([%e caqti_input_type] [@ocaml.warning "-33"])
                   (Caqti_type.([%e outputs_caqti_type]) [@ocaml.warning "-33"])
                   sql])
        with Codegen.Error s -> Error s
      in
      let expand_exec caqti_request_function_expr make_function =
        try
          Ok
            (make_generic make_function
               (* Same oneshot/warning treatment as [expand_get]: the SQL is
                  generated per call, so it must not be cached. *)
               [%expr
                 [%e caqti_request_function_expr]
                   ~oneshot:true ([%e caqti_input_type] [@ocaml.warning "-33"])
                   Caqti_type.unit sql])
        with Codegen.Error s -> Error s
      in
      (expand_get, expand_exec)
  | None ->
      let inputs_caqti_type, outputs_caqti_type, parsed_sql =
        component_expressions ~loc parsed_query
      in
      let expression_contents =
        Codegen.
          {
            in_params = parsed_query.in_params;
            out_params = parsed_query.out_params;
            input_kind;
            output_kind;
          }
      in
      let make_generic make_function query_expr =
        match output_kind with
        | `Function ->
            [%expr
              fun loaders ->
                let query = [%e query_expr] in
                let wrapped =
                  [%e
                    make_function ~body_fn:(fun x -> x) ~loc expression_contents]
                in
                wrapped loaders]
        | _ ->
            [%expr
              let query = [%e query_expr] in
              let wrapped =
                [%e
                  make_function ~body_fn:(fun x -> x) ~loc expression_contents]
              in
              wrapped]
      in
      let expand_get caqti_request_function_expr make_function =
        try
          Ok
            (make_generic make_function
               [%expr
                 [%e caqti_request_function_expr]
                   (Caqti_type.([%e inputs_caqti_type]) [@ocaml.warning "-33"])
                   (Caqti_type.([%e outputs_caqti_type]) [@ocaml.warning "-33"])
                   [%e parsed_sql]])
        with Codegen.Error s -> Error s
      in
      let expand_exec caqti_request_function_expr make_function =
        try
          Ok
            (make_generic make_function
               [%expr
                 [%e caqti_request_function_expr]
                   (Caqti_type.([%e inputs_caqti_type]) [@ocaml.warning "-33"])
                   Caqti_type.unit [%e parsed_sql]])
        with Codegen.Error s -> Error s
      in
      (expand_get, expand_exec)
(** Expand [[%rapper action "SQL" arg1 arg2 ...]] into OCaml code.

    The steps run in order and stop at the first failure:
    - validate the trailing arguments with [parse_args];
    - parse the SQL template with [Query.parse];
    - check the SQL with [Pg_query.parse], unless [syntax_off] was given;
    - generate code for [action], which must be [execute], [get_one],
      [get_opt] or [get_many].

    A failure in any step is raised as [Location.Error] at [loc], with the
    message prefixed by ["Error in ppx_rapper: "]. *)
let expand_apply ~loc ~path:_ action query args =
  (* Run [step]; on [Ok v] continue with [next v], otherwise stop. *)
  let ( >>= ) step next =
    match step with Ok value -> next value | Error msg -> Error msg
  in
  (* Dispatch on [action] and generate the final expression. *)
  let generate parsed_query input_kind output_kind =
    let expand_get, expand_exec =
      make_expand_get_and_exec_expression ~loc parsed_query input_kind
        output_kind
    in
    match action with
    | "get_one" ->
        expand_get [%expr Caqti_request.Infix.(->!)] Codegen.find_function
    | "get_opt" ->
        expand_get [%expr Caqti_request.Infix.(->?)] Codegen.find_opt_function
    | "get_many" ->
        expand_get [%expr Caqti_request.Infix.(->*)]
          Codegen.collect_list_function
    | "execute" -> (
        match output_kind with
        | `Tuple ->
            expand_exec [%expr Caqti_request.Infix.(->.)] Codegen.exec_function
        | `Record -> Error "record_out is not a valid argument for execute"
        | `Function -> Error "function_out is not a valid argument for execute")
    | _ -> Error "Supported actions are execute, get_one, get_opt and get_many"
  in
  let expression_result =
    parse_args args >>= fun (input_kind, output_kind, syntax_off) ->
    (match Query.parse query with
     | Ok parsed_query -> Ok parsed_query
     | Error error -> Error (Query.explain_error error))
    >>= fun parsed_query ->
    (if syntax_off then Ok ()
     else
       let sql = parsed_query.Query.sql in
       let query_sql =
         match parsed_query.Query.list_params with
         | None -> sql
         | Some { subsql; string_index; _ } ->
             (* Splice a single copy of the %list sub-query back in so the
                statement is complete enough for Pg_query to parse. *)
             let sql_before = String.sub sql ~pos:0 ~len:string_index in
             let sql_after =
               String.sub sql ~pos:string_index
                 ~len:(String.length sql - string_index)
             in
             sql_before ^ subsql ^ sql_after
       in
       match Pg_query.parse query_sql with
       | Ok _ -> Ok ()
       | Error msg -> Error (Printf.sprintf "Syntax error in SQL: '%s'" msg))
    >>= fun () -> generate parsed_query input_kind output_kind
  in
  match expression_result with
  | Ok expr -> expr
  | Error msg ->
      raise
        (Location.Error
           (Location.Error.createf ~loc "Error in ppx_rapper: %s" msg))
(** Expand [[let%rapper name = action "SQL" args...]] into the non-recursive
    structure item [let name = <expansion>], where the right-hand side is
    produced by [expand_apply] (which raises on invalid input). *)
let expand_let ~loc ~path vb_var action query args =
  let expr = expand_apply ~loc ~path action query args in
  let pat = Buildef.ppat_var ~loc (Buildef.Located.mk ~loc vb_var) in
  Buildef.pstr_value ~loc Nonrecursive
    [ Buildef.value_binding ~loc ~pat ~expr ]
(** Pattern for the payload of
    [\[%rapper get_one "SELECT id FROM things WHERE condition"\]]: an
    identifier (the action) applied to an unlabelled string literal (the SQL)
    followed by zero or more unlabelled identifiers (the arguments).
    Captures, in order: the action name, the SQL string and the list of
    argument names. *)
let apply_pattern () =
  let open Ast_pattern in
  (* Both helpers are functions, so each use gets fresh continuation types. *)
  let unlabelled arg = pair nolabel arg in
  let bare_ident () = pexp_ident (lident __) in
  pexp_apply (bare_ident ())
    (unlabelled (estring __) ^:: many (unlabelled (bare_ident ())))
(** Pattern for
    [\[let%rapper get_thing = get_one "SELECT id FROM things WHERE condition"\]]:
    a structure containing exactly one non-recursive [let] with exactly one
    binding, whose left side is a plain variable and whose right side matches
    [apply_pattern]. Captures the variable name, then [apply_pattern]'s
    captures. *)
let let_pattern () =
  let open Ast_pattern in
  let single p = p ^:: nil in
  let binding = value_binding ~pat:(ppat_var __) ~expr:(apply_pattern ()) in
  pstr (single (pstr_value nonrecursive (single binding)))
(** Extension name shared by [[%rapper ...]] and [[let%rapper ...]]. *)
let name = "rapper"

(** Expression extension: [[%rapper action "SQL" args...]]. *)
let apply_ext =
  let pattern = Ast_pattern.(single_expr_payload (apply_pattern ())) in
  Extension.declare name Extension.Context.expression pattern expand_apply

(** Structure-item extension: [[let%rapper name = action "SQL" args...]]. *)
let let_ext =
  let pattern = let_pattern () in
  Extension.declare name Extension.Context.structure_item pattern expand_let

(* Register both extensions with the ppxlib driver under [name]. *)
let () =
  let extensions = [ let_ext; apply_ext ] in
  Driver.register_transformation name ~extensions