Source file discrimination_tree.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
open Elpi_util.Util
(* Tags ("constructors") stored in the most significant [k_bits] bits of a
   cell. *)
let kConstant = 0
let kPrimitive = 1
let kVariable = 2
let kOther = 3

(* Width, in bits, of the arity field and of the tag field. *)
let arity_bits = 4
let k_bits = 2

(* A cell is a native OCaml int laid out, from the most significant bit down:
     [ tag : k_bits | arity : arity_bits | data : all remaining bits ] *)
let cell_size = Sys.int_size
let k_lshift = cell_size - k_bits
let ka_lshift = cell_size - k_bits - arity_bits

(* One mask per field; the three masks are pairwise disjoint and together
   cover the whole cell. *)
let k_mask = ((1 lsl k_bits) - 1) lsl k_lshift
(* Exactly [arity_bits] bits wide, so the mask does not overlap the tag. *)
let arity_mask = ((1 lsl arity_bits) - 1) lsl ka_lshift
let data_mask = (1 lsl ka_lshift) - 1

(** [encode ~k ~arity ~data] packs a cell:
    - [k]     : the "constructor" (kConstant, kPrimitive, kVariable, kOther)
    - [arity] : the arity, stored in the [arity_bits] bits below the tag
    - [data]  : the payload, truncated to the low [ka_lshift] bits

    Only [data] is masked; [k] and [arity] are shifted as they are, so
    out-of-range values spill into the neighbouring field. *)
let encode ~k ~arity ~data =
  (k lsl k_lshift) lor (arity lsl ka_lshift) lor (data land data_mask)

(** [k_of n] is the tag of cell [n], a value in [0 .. 2^k_bits - 1]. *)
let k_of n = (n land k_mask) lsr k_lshift

(** [arity_of n] is the arity field of [n] when [n] is tagged [kConstant],
    and [0] for every other tag. *)
let arity_of n =
  let k = k_of n in
  if k == kConstant then (n land arity_mask) lsr ka_lshift else 0

(** [data_of n] is the payload field of cell [n]. *)
let data_of n = n land data_mask
(** [mkConstant ~safe ~data ~arity] builds a [kConstant] cell carrying [data]
    and [arity].

    When [safe] is true the arguments are first checked against the field
    widths: [anomaly] is called if [abs data] exceeds [data_mask] or if
    [arity] does not fit in [arity_bits] bits. With [safe] false no check is
    performed and the cell is built as is. *)
let mkConstant ~safe ~data ~arity =
  let overflows = abs data > data_mask || arity >= 1 lsl arity_bits in
  if safe && overflows then
    anomaly (Printf.sprintf "Indexing at depth > 1 is unsupported since constant %d/%d is too large or wide" data arity);
  encode ~k:kConstant ~data ~arity
(** [mkPrimitive c] builds a [kPrimitive] cell whose payload is the hash of
    [c] shifted left by [k_bits] (and then truncated by [encode]). *)
let mkPrimitive c =
  let hashed = CData.hash c lsl k_bits in
  encode ~k:kPrimitive ~data:hashed ~arity:0

(** The single cell tagged [kVariable]. *)
let mkVariable = encode ~k:kVariable ~data:0 ~arity:0

(* Marker cells: all tagged [kOther], arity 0, told apart by their payload
   (0 to 7, in the order of the names below). *)
let ( mkAny,
      mkInputMode,
      mkOutputMode,
      mkListTailVariable,
      mkListHead,
      mkListEnd,
      mkPathEnd,
      mkListTailVariableUnif ) =
  let other data = encode ~k:kOther ~data ~arity:0 in
  (other 0, other 1, other 2, other 3, other 4, other 5, other 6, other 7)
(* Recognisers for the fixed cells. Cells are immediate ints, so comparing
   with [==] against the corresponding constant is an integer comparison. *)
let isVariable c = mkVariable == c
let isAny c = mkAny == c
let isInput c = mkInputMode == c
let isOutput c = mkOutputMode == c
let isListHead c = mkListHead == c
let isListEnd c = mkListEnd == c
let isListTailVariable c = mkListTailVariable == c
let isListTailVariableUnif c = mkListTailVariableUnif == c
let isPathEnd c = mkPathEnd == c
(** A path cell: an int packing a tag, an arity and a payload. *)
type cell = int

(** [pp_cell fmt n] prints a human readable name for cell [n]:
    ["Constant(data,arity)"] for constants, ["Variable"], ["Primitive"], or
    the name of the [kOther] marker.
    @raise Failure if [n] is tagged [kOther] but is none of the known
    markers. *)
let pp_cell fmt n =
  let tag = k_of n in
  if tag == kVariable then Format.fprintf fmt "Variable"
  else if tag == kPrimitive then Format.fprintf fmt "Primitive"
  else if tag == kConstant then
    Format.fprintf fmt "Constant(%d,%d)"
      (n land data_mask)
      ((n land arity_mask) lsr ka_lshift)
  else if tag == kOther then begin
    (* The label is computed first so that an unknown marker raises before
       anything is printed. *)
    let label =
      if isInput n then "Input"
      else if isOutput n then "Output"
      else if isListTailVariable n then "ListTailVariable"
      else if isListHead n then "ListHead"
      else if isListEnd n then "ListEnd"
      else if isPathEnd n then "PathEnd"
      else if isListTailVariableUnif n then "ListTailVariableUnif"
      else if isAny n then "Other"
      else failwith "Invalid path construct..."
    in
    Format.pp_print_string fmt label
  end
  else Format.fprintf fmt "%o" tag

(** [show_cell n] is the text printed by [pp_cell] for [n]. *)
let show_cell n = Format.asprintf "%a" pp_cell n
module Path = struct
  (** A path is an array of cells. *)
  type t = cell array [@@deriving show]

  (** [get a i] is the cell at index [i] of [a] (bounds-checked access). *)
  let get a i = a.(i)

  (** Growable buffer used to build a path: [pos] is the index of the next
      cell to write, [path] the backing array. *)
  type builder = { mutable pos : int; mutable path : cell array }

  (** [get_builder_pos b] is the number of cells emitted into [b] so far. *)
  let get_builder_pos { pos } = pos

  (** [make size e] is an empty builder backed by an array of [size] cells,
      all initialised to [e]. *)
  let make size e = { pos = 0; path = Array.make size e }

  (** [emit p e] appends [e] to [p].

      The last slot of the backing array is never written. When only that
      slot is left the array is doubled, the new slots being filled with
      [mkPathEnd], and the write is retried. *)
  let rec emit p e =
    let len = Array.length p.path in
    if p.pos < len - 1 then begin
      Array.unsafe_set p.path p.pos e;
      p.pos <- p.pos + 1
    end else begin
      (* Doubling a zero-length array would never make progress, so always
         grow to at least one cell. *)
      let new_len = if len = 0 then 1 else 2 * len in
      let newpath = Array.make new_len mkPathEnd in
      Array.blit p.path 0 newpath 0 len;
      p.path <- newpath;
      emit p e
    end

  (** [stop b] is the backing array of [b], including the slots after the
      last emitted cell. *)
  let stop { path } = path

  (** [of_list l] is a path holding the cells of [l] in order, followed by
      [mkPathEnd] padding. The empty list yields a path made of a single
      [mkPathEnd] cell. *)
  let of_list l =
    (* The initial filler is never observable for a non-empty [l]: the first
       [emit] grows the array and then overwrites slot 0. *)
    let path = make 1 mkPathEnd in
    List.iter (emit path) l;
    stop path
end
module Trie = struct
  (** A trie node.
      - [data]: the values whose path ends at this node;
      - [other]: the sub-trie followed by [add] for a [Variable] or [Any]
        cell;
      - [listTailVariable]: the sub-trie followed by [add] for a
        [ListTailVariable] or [ListTailVariableUnif] cell;
      - [map]: the sub-tries for every other cell, keyed by the cell. *)
  type 'a t = Node of {
    data : 'a list;
    other : 'a t option;
    listTailVariable : 'a t option;
    map : 'a t Ptmap.t;
  }

  (** The node with no data and no children. *)
  let empty = Node { data = []; other = None; listTailVariable = None; map = Ptmap.empty }

  (** [is_empty x] holds when [x] is physically the [empty] value. A node
      rebuilt by [replace] or [remove] is never physically equal to [empty],
      even if it carries no data. *)
  let is_empty x = x == empty

  (** [replace p x t] is [t] where every stored value satisfying [p] is
      replaced by [x]. *)
  let rec replace p x = function
    | Node { data; other; listTailVariable; map } ->
        Node {
          data = data |> List.map (fun y -> if p y then x else y);
          other = other |> Option.map (replace p x);
          listTailVariable = listTailVariable |> Option.map (replace p x);
          map = map |> Ptmap.map (replace p x);
        }

  (** [remove f t] is [t] without the stored values satisfying [f]. Nodes
      left without data are kept. *)
  let rec remove f = function
    | Node { data; other; listTailVariable; map } ->
        Node {
          data = data |> List.filter (fun x -> not (f x));
          other = other |> Option.map (remove f);
          listTailVariable = listTailVariable |> Option.map (remove f);
          map = map |> Ptmap.map (remove f);
        }

  (** [add a v t] inserts [v] in [t] under path [a], reading [a] from index 0
      up to its first [PathEnd] cell. Returns the updated trie together with
      the index of that [PathEnd] cell in [a]. *)
  let add (a : Path.t) v t =
    let max = ref 0 in
    let rec ins ~pos =
      let x = Path.get a pos in
      function
      | Node ({ data } as t) when isPathEnd x ->
          max := pos;
          Node { t with data = v :: data }
      | Node ({ other } as t) when isVariable x || isAny x ->
          let t' = Option.value other ~default:empty in
          let t'' = ins ~pos:(pos + 1) t' in
          Node { t with other = Some t'' }
      | Node ({ listTailVariable } as t)
        when isListTailVariable x || isListTailVariableUnif x ->
          let t' = Option.value listTailVariable ~default:empty in
          let t'' = ins ~pos:(pos + 1) t' in
          Node { t with listTailVariable = Some t'' }
      | Node ({ map } as t) ->
          let t' = try Ptmap.find x map with Not_found -> empty in
          let t'' = ins ~pos:(pos + 1) t' in
          Node { t with map = Ptmap.add x t'' map }
    in
    let t = ins ~pos:0 t in
    t, !max

  (** [pp ppelem fmt t] prints [t] on [fmt], using [ppelem] for the stored
      values. *)
  let rec pp (ppelem : Format.formatter -> 'a -> unit) (fmt : Format.formatter)
      (Node { data; other; listTailVariable; map } : 'a t) : unit =
    Format.fprintf fmt "@[<v>[values:{";
    pplist ppelem "; " fmt data;
    Format.fprintf fmt "}@ other:{";
    (match other with None -> () | Some m -> pp ppelem fmt m);
    Format.fprintf fmt "}@ listTailVariable:{";
    (match listTailVariable with None -> () | Some m -> pp ppelem fmt m);
    Format.fprintf fmt "}@ key:{@[<hov 2>";
    Ptmap.to_list map
    |> pplist
         (fun fmt (k, v) ->
           pp_cell fmt k;
           pp ppelem fmt v)
         ";@ " fmt;
    Format.fprintf fmt "@]}]@]"

  (** [show ppelem t] is the text printed by [pp ppelem] for [t].

      [Format.asprintf] flushes its internal formatter before returning the
      string. Printing to a formatter created with
      [Format.formatter_of_buffer] and reading the buffer without a flush
      can return truncated or empty text. *)
  let show (fmt : Format.formatter -> 'a -> unit) (n : 'a t) : string =
    Format.asprintf "@[%a@]" (pp fmt) n
end
(** [update_par_count n k] is the list nesting depth after reading cell [k]
    at depth [n]: one more after a [ListHead], one less after a [ListEnd],
    [ListTailVariable] or [ListTailVariableUnif], unchanged otherwise. *)
let update_par_count n k =
  let closes_list =
    isListEnd k || isListTailVariable k || isListTailVariableUnif k
  in
  let delta = if isListHead k then 1 else if closes_list then -1 else 0 in
  n + delta
(** [skip ~pos path] is the index of the first cell after the term that
    starts at index [pos] of [path].

    A term starting with [ListHead] extends up to the cell that closes that
    list. Any other cell is followed by as many sub-terms as [arity_of]
    reports for it, each skipped in turn. *)
let skip ~pos path : int =
  (* Advance until the [open_lists] currently open lists are all closed. *)
  let rec past_list_end open_lists i =
    match open_lists with
    | 0 -> i
    | _ -> past_list_end (update_par_count open_lists (Path.get path i)) (i + 1)
  in
  (* Advance over [pending] consecutive terms starting at index [i]. *)
  let rec past_args pending i =
    if pending = 0 then i
    else
      let cell = Path.get path i in
      if isListHead cell then past_args (pending - 1) (past_list_end 1 (i + 1))
      else past_args (pending - 1 + arity_of cell) (i + 1)
  in
  let first = Path.get path pos in
  if isListHead first then past_list_end 1 (pos + 1)
  else past_args (arity_of first) (pos + 1)
(** [skip_listTailVariable ~pos path] skips the rest of the list that is open
    at index [pos] of [path].

    Starting with one open list, the cells from [pos] onwards are read,
    tracking the nesting depth with [update_par_count]. The result is the
    index just after the first [ListEnd], [ListTailVariable] or
    [ListTailVariableUnif] cell that has no matching [ListHead]. *)
let skip_listTailVariable ~pos path : int =
  let rec scan open_lists i =
    if open_lists = 0 then i
    else scan (update_par_count open_lists (Path.get path i)) (i + 1)
  in
  scan (update_par_count 1 (Path.get path pos)) (pos + 1)
(** A discrimination tree: the trie [t] plus bookkeeping.
    - [max_size]: largest [PathEnd] position reported by [Trie.add] so far;
    - [max_depths]: array allocated by [empty_dt];
    - [max_list_length]: largest value passed to [index] so far. *)
type 'a t = {
  t : 'a Trie.t;
  max_size : int;
  max_depths : int array;
  max_list_length : int;
}

(** [pp pp_a fmt dt] prints the trie of [dt], using [pp_a] for the values. *)
let pp pp_a fmt dt : unit = Trie.pp pp_a fmt dt.t

(** [show pp_a dt] is the trie of [dt] rendered as a string. *)
let show pp_a dt : string = Trie.show pp_a dt.t
(** [index dt ~max_list_length path data] is [dt] with [data] inserted under
    [path]. [max_size] and [max_list_length] are raised, when needed, to the
    position of the path end and to the given [~max_list_length]. *)
let index dt ~max_list_length path data =
  let t, path_end_pos = Trie.add path data dt.t in
  { dt with
    t;
    max_size = max dt.max_size path_end_pos;
    max_list_length = max max_list_length dt.max_list_length }
(** [max_path dt] is the [max_size] field of [dt]. *)
let max_path dt = dt.max_size

(** [max_depths dt] is the [max_depths] array of [dt] (not copied). *)
let max_depths dt = dt.max_depths

(** [max_list_length dt] is the [max_list_length] field of [dt]. *)
let max_list_length dt = dt.max_list_length
(** [call f] runs [f ~add_result] and returns every value passed to
    [add_result], most recently added first. *)
let call f =
  let collected = ref [] in
  f ~add_result:(fun x -> collected := x :: !collected);
  !collected
(* [skip_to_listEnd ~add_result mode node] reports, through [add_result], the
   sub-tries of [node] reached once the list currently open in the trie has
   been closed.

   From [node] the walk always follows [other], follows the keyed children
   only when [mode] is the output mode, and reports the [listTailVariable]
   child directly. *)
let skip_to_listEnd ~add_result mode (Trie.Node { other; map; listTailVariable }) =
  (* [walk depth n]: [depth] lists are still open when entering [n]. *)
  let rec walk depth (Trie.Node { other; map; listTailVariable } as node) =
    if depth = 0 then begin
      add_result node;
      Option.iter add_result other
    end else begin
      Option.iter (walk depth) other;
      Ptmap.iter (fun key child -> walk (update_par_count depth key) child) map;
      Option.iter add_result listTailVariable
    end
  in
  Option.iter (walk 1) other;
  if isOutput mode then
    Ptmap.iter (fun key child -> walk (update_par_count 1 key) child) map;
  Option.iter add_result listTailVariable

(* Same traversal, with the reported sub-tries collected in a list by
   [call]. *)
let skip_to_listEnd mode t =
  call (fun ~add_result -> skip_to_listEnd ~add_result mode t)
(** [get_all_children v mode] holds when cell [v] is [Any], or when it is a
    [Variable] read in output mode. *)
let get_all_children v mode =
  if isAny v then true else isVariable v && isOutput mode
(* [retrieve ~pos ~add_result mode path tree] walks [tree] along [path],
   starting at index [pos], and passes to [add_result] the values stored at
   the nodes reached when the path is at its [PathEnd] cell.
   [mode] is the last Input/Output cell read from the path. *)
let rec retrieve ~pos ~add_result mode path tree : unit =
let hd = Path.get path pos in
let Trie.Node {data; map; other; listTailVariable} = tree in
(* [Trie.is_empty] is a physical comparison with [Trie.empty]. *)
if Trie.is_empty tree then ()
else if isPathEnd hd then List.iter add_result data
(* A mode cell only switches the current mode; the tree is not advanced. *)
else if isInput hd || isOutput hd then
retrieve ~pos:(pos+1) ~add_result hd path tree
(* List tail variable in the path: jump in the tree to the nodes found by
   [skip_to_listEnd] and continue from each of them. The Unif variant makes
   that search run in output mode. *)
else if isListTailVariable hd || isListTailVariableUnif hd then
let sub_tries = skip_to_listEnd (if isListTailVariableUnif hd then mkOutputMode else mode) tree in
List.iter (retrieve ~pos:(pos+1) ~add_result mode path) sub_tries
else begin
begin
(* [Any], or a variable in output mode: explore every keyed child. *)
if get_all_children hd mode then
on_all_children ~pos:(pos+1) ~add_result mode path map
(* A variable in input mode follows no keyed child. *)
else if isInput mode && isVariable hd then ()
else
(* Follow the child keyed by [hd], if any. The handler encloses the whole
   recursive call, so a [Not_found] escaping from it is silenced as well. *)
try retrieve ~pos:(pos+1) ~add_result mode path (Ptmap.find hd map)
with Not_found -> ()
end;
(* The [other] child is followed after skipping, in the path, the whole term
   starting at [pos]; not done when [hd] is [ListEnd]. *)
if not(isListEnd hd) then Option.iter (fun a -> retrieve ~pos:(skip ~pos path) ~add_result mode path a) other;
(* The [listTailVariable] child is followed after skipping, in the path, the
   rest of the current list. *)
Option.iter (fun a -> retrieve ~pos:(skip_listTailVariable ~pos path) ~add_result mode path a) listTailVariable
end
(* [on_all_children ~pos ~add_result mode path map] goes down every entry of
   [map], skips one complete term in the tree below it, and resumes
   [retrieve] at index [pos] of [path] from the nodes reached. *)
and on_all_children ~pos ~add_result mode path map =
(* [skip_list par_count arity]: [par_count] lists are still open in the tree;
   once they are all closed, [arity] more terms remain to be skipped. *)
let rec skip_list par_count arity = function
| Trie.Node { other; map; listTailVariable } as tree ->
if par_count = 0 then begin
skip_functor arity tree
end else begin
Ptmap.iter (fun k v -> skip_list (update_par_count par_count k) arity v) map;
Option.iter (skip_list par_count arity) other;
Option.iter (skip_list (par_count - 1) arity) listTailVariable
end
(* [skip_functor arity]: [arity] more terms remain to be skipped in the tree
   before retrieval resumes. *)
and skip_functor arity = function
| Trie.Node { other; map } as tree ->
(* NOTE(review): [other] is retrieved at [pos] whatever the remaining
   [arity] is; confirm this is intended when [arity] > 0. *)
Option.iter (retrieve ~pos ~add_result mode path) other;
if arity = 0 then retrieve ~pos ~add_result mode path tree
else
Ptmap.iter (fun k v ->
if isListHead k then skip_list 1 (arity - 1) v
else skip_functor (arity - 1 + arity_of k) v)
map
in
(* A [ListHead] key opens a list to skip; any other key is skipped together
   with the [arity_of k] arguments that follow it. *)
Ptmap.iter (fun k v ->
if isListHead k then skip_list 1 0 v
else skip_functor (arity_of k) v)
map
(** [empty_dt args_depth] is a discrimination tree with no entries. Its
    [max_depths] array has one zero per element of [args_depth]; the other
    counters start at 0. *)
let empty_dt args_depth : 'a t =
  { t = Trie.empty;
    max_size = 0;
    max_depths = Array.make (List.length args_depth) 0;
    max_list_length = 0 }
(* [retrieve ~pos ~add_result path index] reads the mode cell found at index
   [pos] of [path], which must be Input or Output (checked by an assertion),
   and runs the retrieval from the next cell on. *)
let retrieve ~pos ~add_result path index =
  let mode = Path.get path pos in
  assert (isInput mode || isOutput mode);
  retrieve ~pos:(pos + 1) ~add_result mode path index

(** [retrieve cmp_data path dt] is the collection of values of [dt] selected
    by [path], read from index 0, sorted with [cmp_data]. *)
let retrieve cmp_data path { t } =
  call (fun ~add_result -> retrieve ~pos:0 ~add_result path t)
  |> List.sort cmp_data
  |> Bl.of_list
(** [replace p x dt] is [dt] where every stored value satisfying [p] is
    replaced by [x]; the counters are left unchanged. *)
let replace p x ({ t; _ } as dt) = { dt with t = Trie.replace p x t }

(** [remove drop dt] is [dt] without the stored values satisfying [drop];
    the counters are left unchanged. *)
let remove drop ({ t; _ } as dt) = { dt with t = Trie.remove drop t }
(* Re-export of the cell-level helpers under a single module name. *)
module Internal = struct
  (* Cell tags. *)
  let kConstant : int = kConstant
  let kPrimitive : int = kPrimitive
  let kVariable : int = kVariable
  let kOther : int = kOther

  (* Field decoders. *)
  let k_of : cell -> int = k_of
  let arity_of : cell -> int = arity_of
  let data_of : cell -> int = data_of

  (* Recognisers for the fixed cells. *)
  let isVariable : cell -> bool = isVariable
  let isAny : cell -> bool = isAny
  let isInput : cell -> bool = isInput
  let isOutput : cell -> bool = isOutput
  let isListHead : cell -> bool = isListHead
  let isListEnd : cell -> bool = isListEnd
  let isListTailVariable : cell -> bool = isListTailVariable
  let isListTailVariableUnif : cell -> bool = isListTailVariableUnif
  let isPathEnd : cell -> bool = isPathEnd
end