Source file codec.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
(** A codec for values of type ['a]: a schema together with the functions
    that write such a value to an output and read one back from an input. *)
type 'a t = {
  (* Schema describing the values this codec handles. *)
  schema: Schema.t;
  (* [encode v out] writes [v] to [out]. *)
  encode: 'a -> Output.t -> unit;
  (* [decode inp] reads one value from [inp]. *)
  decode: Input.t -> 'a;
}

(** Codec with schema [Schema.Null]. The only value is [()]; encoding and
    decoding delegate to [Output.write_null] and [Input.read_null]. *)
let null =
  let encode () sink = Output.write_null sink () in
  let decode source = Input.read_null source in
  { schema = Schema.Null; encode; decode }

(** Codec with schema [Schema.Boolean]; delegates to [Output.write_boolean]
    and [Input.read_boolean]. *)
let boolean =
  let encode flag sink = Output.write_boolean sink flag in
  let decode source = Input.read_boolean source in
  { schema = Schema.Boolean; encode; decode }

(** Codec with schema [Schema.Int None]; delegates to [Output.write_int] and
    [Input.read_int]. *)
let int =
  let encode n sink = Output.write_int sink n in
  let decode source = Input.read_int source in
  { schema = Schema.Int None; encode; decode }

(** Codec with schema [Schema.Long None]; delegates to [Output.write_long]
    and [Input.read_long]. *)
let long =
  let encode n sink = Output.write_long sink n in
  let decode source = Input.read_long source in
  { schema = Schema.Long None; encode; decode }

(** Codec with schema [Schema.Float]; delegates to [Output.write_float] and
    [Input.read_float]. *)
let float =
  let encode x sink = Output.write_float sink x in
  let decode source = Input.read_float source in
  { schema = Schema.Float; encode; decode }

(** Codec with schema [Schema.Double]; delegates to [Output.write_double]
    and [Input.read_double]. *)
let double =
  let encode x sink = Output.write_double sink x in
  let decode source = Input.read_double source in
  { schema = Schema.Double; encode; decode }

(** Codec with schema [Schema.Bytes None]; delegates to [Output.write_bytes]
    and [Input.read_bytes]. *)
let bytes =
  let encode payload sink = Output.write_bytes sink payload in
  let decode source = Input.read_bytes source in
  { schema = Schema.Bytes None; encode; decode }

(** Codec with schema [Schema.String None]; delegates to
    [Output.write_string] and [Input.read_string]. *)
let string =
  let encode text sink = Output.write_string sink text in
  let decode source = Input.read_string source in
  { schema = Schema.String None; encode; decode }

(* TODO: warning on unused variable. *)
(* TODO: in general, reduce type conversions between List, Array, Bytes and String. *)
(** [fixed ?name size] is a codec for byte sequences of exactly [size] bytes.

    The schema is a [Schema.Fixed] named [name] (default ["fixed"]) with no
    doc, aliases or logical type.

    @raise Failure when encoding a value whose length differs from [size]. *)
let fixed ?(name = "fixed") size =
  let schema =
    Schema.Fixed {
      fixed_name = Type_name.simple name;
      size;
      fixed_doc = None;
      fixed_aliases = [];
      fixed_logical = None;
    }
  in
  let encode payload sink =
    let actual = Bytes.length payload in
    (* Reject the value before writing anything if its length is wrong. *)
    if actual <> size then
      failwith
        (Printf.sprintf "Fixed type size mismatch: expected %d bytes, got %d"
           size actual);
    Output.write_fixed sink payload
  in
  let decode source = Input.read_fixed source size in
  { schema; encode; decode }

(** [array codec] is a codec for arrays whose elements use [codec].

    Encoding writes a single [0L] for an empty array; otherwise it writes the
    element count as a long, every element in order, and a terminating [0L].

    Decoding reads blocks until a zero count is met. A negative count is
    followed by an extra long, which is read and ignored, and its absolute
    value is used as the number of elements in the block. *)
let array codec =
  let encode items sink =
    match Array.length items with
    | 0 -> Output.write_long sink 0L
    | n ->
        Output.write_long sink (Int64.of_int n);
        Array.iter (fun item -> codec.encode item sink) items;
        Output.write_long sink 0L
  in
  let decode source =
    (* Reads one block of [n] consecutive elements. *)
    let read_block n =
      Array.init (Int64.to_int n) (fun _ -> codec.decode source)
    in
    (* [blocks_rev] holds the blocks read so far, most recent first. *)
    let rec loop blocks_rev =
      let count = Input.read_long source in
      if count = 0L then
        Array.concat (List.rev blocks_rev)
      else if count < 0L then begin
        (* The long after a negative count is consumed but not used. *)
        let _byte_size = Input.read_long source in
        (loop[@tailcall]) (read_block (Int64.neg count) :: blocks_rev)
      end else
        (loop[@tailcall]) (read_block count :: blocks_rev)
    in
    loop []
  in
  { schema = Schema.Array codec.schema; encode; decode }

(** [map codec] is a codec for association lists of [(key, value)] pairs
    whose keys are written with [Output.write_string] and whose values use
    [codec].

    Encoding writes a single [0L] for an empty list; otherwise it writes the
    number of pairs as a long, every pair in order (key, then value), and a
    terminating [0L].

    Decoding reads blocks until a zero count is met and returns the pairs in
    the order they were read. A negative count is followed by an extra long,
    which is read and ignored, and its absolute value is used as the number
    of pairs in the block. *)
let map codec =
  let encode entries sink =
    match entries with
    | [] -> Output.write_long sink 0L
    | _ :: _ ->
        Output.write_long sink (Int64.of_int (List.length entries));
        List.iter
          (fun (key, value) ->
            Output.write_string sink key;
            codec.encode value sink)
          entries;
        Output.write_long sink 0L
  in
  let decode source =
    (* Reads one pair; the key must be read before the value. *)
    let read_entry _ =
      let key = Input.read_string source in
      let value = codec.decode source in
      (key, value)
    in
    (* [acc] holds every pair read so far, in reverse order. *)
    let rec loop acc =
      let count = Input.read_long source in
      if count = 0L then
        List.rev acc
      else begin
        let entry_count =
          if count < 0L then begin
            (* The long after a negative count is consumed but not used. *)
            let _byte_size = Input.read_long source in
            Int64.neg count
          end else
            count
        in
        let block = List.init (Int64.to_int entry_count) read_entry in
        loop (List.rev_append block acc)
      end
    in
    loop []
  in
  { schema = Schema.Map codec.schema; encode; decode }

(** [union codecs] is a codec for a union whose branches are [codecs], in
    order. Values are [(branch, value)] pairs, where [branch] is the
    zero-based position of the branch's codec in [codecs].

    Encoding writes [branch] as a long followed by [value] encoded with that
    branch's codec. The branch is validated before anything is written, so an
    invalid branch leaves the output untouched. Decoding reads the branch
    index, then the value with the matching codec.

    @raise Invalid_argument if the branch index is negative.
    @raise Failure if the branch index is not smaller than the number of
    branches. *)
let union codecs =
  (* Converted once so each lookup is O(1) instead of a list walk. *)
  let branches = Array.of_list codecs in
  let branch_count = Array.length branches in
  let codec_of_branch branch =
    if branch < 0 then
      invalid_arg
        (Printf.sprintf "Codec.union: negative branch index %d" branch)
    else if branch >= branch_count then
      failwith
        (Printf.sprintf
           "Codec.union: branch index %d out of range (union has %d branches)"
           branch branch_count)
    else
      branches.(branch)
  in
  {
    schema = Schema.Union (List.map (fun c -> c.schema) codecs);
    encode = (fun (branch, value) out ->
      (* Resolve the codec first: a bad branch must fail before any write. *)
      let codec = codec_of_branch branch in
      Output.write_long out (Int64.of_int branch);
      codec.encode value out
    );
    decode = (fun inp ->
      let branch = Int64.to_int (Input.read_long inp) in
      let codec = codec_of_branch branch in
      let value = codec.decode inp in
      (branch, value)
    );
  }

(** [option codec] is a codec for optional values, with the two-branch schema
    [Schema.Union [Schema.Null; codec.schema]].

    [None] is written as branch [0L] followed by a null; [Some v] is written
    as branch [1L] followed by [v] encoded with [codec].

    @raise Failure when decoding a branch index other than 0 or 1. *)
let option codec =
  let encode opt sink =
    match opt with
    | Some value ->
        Output.write_long sink 1L;
        codec.encode value sink
    | None ->
        Output.write_long sink 0L;
        Output.write_null sink ()
  in
  let decode source =
    match Int64.to_int (Input.read_long source) with
    | 0 ->
        Input.read_null source;
        None
    | 1 -> Some (codec.decode source)
    | _ -> failwith "Invalid union branch for option"
  in
  { schema = Schema.Union [Schema.Null; codec.schema]; encode; decode }

(** Intermediate state used while describing a record codec field by field.
    ['record] is the record type being encoded; ['constructor] is what is
    left of the constructor function once the fields added so far have been
    applied to it. *)
type ('record, 'constructor) builder = {
  (* Name given to the resulting record schema. *)
  type_name: Type_name.t;
  (* Constructor supplied to [record]; [field] and [field_opt] replace it
     with a placeholder. *)
  constructor: 'constructor;
  (* Field schemas added so far, most recent first. *)
  fields_rev: Schema.field list;
  (* Writes every field added so far, in the order they were added. *)
  encode: 'record -> Output.t -> unit;
  (* Reads every field added so far and applies the constructor to them. *)
  decode: Input.t -> 'constructor;
}

(** [record type_name constructor] starts a record builder with no fields.
    The initial encoder writes nothing and the initial decoder returns
    [constructor] without reading any input. *)
let record type_name constructor =
  let encode _record _sink = () in
  let decode _source = constructor in
  { type_name; constructor; fields_rev = []; encode; decode }

(** [field field_name field_codec getter builder] appends a field to
    [builder].

    The field's schema has no doc, no default and no aliases. When encoding,
    all previously added fields are written first, then [getter record]
    encoded with [field_codec]. When decoding, the previously added fields
    are read first and the resulting partial constructor is applied to the
    value read with [field_codec]. *)
let field field_name field_codec getter builder =
  let field_schema = {
    Schema.field_name;
    field_type = field_codec.schema;
    field_doc = None;
    field_default = None;
    field_aliases = [];
  } in
  let encode record sink =
    builder.encode record sink;
    field_codec.encode (getter record) sink
  in
  let decode source =
    let partial = builder.decode source in
    partial (field_codec.decode source)
  in
  {
    type_name = builder.type_name;
    (* Untyped placeholder: no value of the new constructor type exists at
       this point, and nothing in this file reads this field after [record]. *)
    constructor = Obj.magic ();
    fields_rev = field_schema :: builder.fields_rev;
    encode;
    decode;
  }

(** [field_opt field_name field_codec getter builder] appends an optional
    field to [builder].

    The value is encoded and decoded with [option field_codec], and the
    field's schema carries [Schema.Null_default] as its default. Fields are
    written and read in the same order as with [field]: previously added
    fields first, then this one. *)
let field_opt field_name field_codec getter builder =
  let wrapped = option field_codec in
  let field_schema = {
    Schema.field_name;
    field_type = wrapped.schema;
    field_doc = None;
    field_default = Some Schema.Null_default;
    field_aliases = [];
  } in
  let encode record sink =
    builder.encode record sink;
    wrapped.encode (getter record) sink
  in
  let decode source =
    let partial = builder.decode source in
    partial (wrapped.decode source)
  in
  {
    type_name = builder.type_name;
    (* Untyped placeholder: no value of the new constructor type exists at
       this point, and nothing in this file reads this field after [record]. *)
    constructor = Obj.magic ();
    fields_rev = field_schema :: builder.fields_rev;
    encode;
    decode;
  }

(** [finish builder] turns a completed builder into a codec.

    The schema is a [Schema.Record] named after the builder's [type_name],
    with the fields in the order they were added and no doc or aliases. The
    encoder and decoder are the ones accumulated by the builder. *)
let finish builder =
  let { type_name; fields_rev; encode; decode; constructor = _ } = builder in
  let schema =
    Schema.Record {
      name = type_name;
      (* Fields were accumulated most recent first; restore their order. *)
      fields = List.rev fields_rev;
      record_doc = None;
      record_aliases = [];
    }
  in
  { schema; encode; decode }

(** [encode_to_bytes codec value] encodes [value] with [codec] into a fresh
    output and returns what was written, via [Output.to_bytes]. *)
let encode_to_bytes (codec : 'a t) (value : 'a) : bytes =
  let sink = Output.create () in
  let () = codec.encode value sink in
  Output.to_bytes sink

(** [decode_from_bytes codec bytes] decodes one value with [codec] from an
    input built over [bytes] with [Input.of_bytes]. *)
let decode_from_bytes (codec : 'a t) (bytes : bytes) : 'a =
  bytes |> Input.of_bytes |> codec.decode

(** [encode_to_string codec value] encodes [value] with [codec] into a fresh
    output and returns what was written, via [Output.contents]. *)
let encode_to_string (codec : 'a t) (value : 'a) : string =
  let sink = Output.create () in
  let () = codec.encode value sink in
  Output.contents sink

(** [decode_from_string codec str] decodes one value with [codec] from an
    input built over [str] with [Input.of_string]. *)
let decode_from_string (codec : 'a t) (str : string) : 'a =
  str |> Input.of_string |> codec.decode