(* Source file: optimizer.ml *)
(*****************************************************************************)
(*                                                                           *)
(* MIT License                                                               *)
(* Copyright (c) 2022 Nomadic Labs <contact@nomadic-labs.com>                *)
(*                                                                           *)
(* Permission is hereby granted, free of charge, to any person obtaining a   *)
(* copy of this software and associated documentation files (the "Software"),*)
(* to deal in the Software without restriction, including without limitation *)
(* the rights to use, copy, modify, merge, publish, distribute, sublicense,  *)
(* and/or sell copies of the Software, and to permit persons to whom the     *)
(* Software is furnished to do so, subject to the following conditions:      *)
(*                                                                           *)
(* The above copyright notice and this permission notice shall be included   *)
(* in all copies or substantial portions of the Software.                    *)
(*                                                                           *)
(* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR*)
(* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,  *)
(* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL   *)
(* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER*)
(* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING   *)
(* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER       *)
(* DEALINGS IN THE SOFTWARE.                                                 *)
(*                                                                           *)
(*****************************************************************************)

open Csir
open Optimizer_helpers

(** The optimizer simplifies a constraint system, producing an equivalent one
    with fewer constraints in essentially three ways:

    - 1. {!next_gate_selectors} [qlg], [qrg] and/or [qog], which
    lead to a more compact representation of sums (linear combinations).

    - 2. {!shared_wires} among different constraints.

    - 3. {!compacting_blocks} which operate on independent wires.

    {1 Optimization Rules}

    {2:next_gate_selectors Introducing next-gate selectors}

    In a system with 3-wires architecture, adding (in a linear combination) [n]
    elements requires [n-1] constraints. For example, computing
    [out = 1 + 2 x1 + 4 x2 + 8 x3 + 16 x4 + 32 x5] can be implemented in 4
    constraints as follows:
    {[
      { a = x1;   b = x2;   c = aux1; identity = [ql 2; qr 4; qo -1; qc 1] };
      { a = x3;   b = x4;   c = aux2; identity = [ql 8; qr 16; qo -1] };
      { a = x5;   b = aux1; c = aux3; identity = [ql 32; qr 1; qo -1] };
      { a = aux2; b = aux3; c = out;  identity = [ql 1; qr 1; qo -1] };
    ]}
    Instead, by using next-gate selectors, a linear combination of [n] terms
    can be implemented in [(n-1)/2] constraints. In this case, just two:
    {[
      { a = x1; b = x2; c = out; identity = [ql 2; qr 4; qlg 8; qrg 16; qog 32; qo -1; qc 1] };
      { a = x3; b = x4; c = x5;  identity = [] };
    ]}

    {2:shared_wires Reusing shared-wires}

    Remarkably, two linear combinations that involve some common wires can be
    implemented more efficiently if they are handled together.
    For example, [out1 = 5 x + 3 y + 9 z] and [out2 = 2 x - 3 y + 7 t] can be
    implemented with 3 constraints (instead of 4) as follows:
    {[
      { a = z;        c = out1; identity = [ql 9; qlg 5; qrg 3; qog -1] };
      { a = x; b = y; c = out2; identity = [ql 2; qr -3; qrg 7; qog -1] };
      {        b = t;           identity = [] };
    ]}
    Furthermore, observe how, conveniently, some wires in the above constraints
    are unbound, which can lead to our next optimization.

    {2:compacting_blocks Compacting blocks}

    Consider a gate which only takes one wire as input, e.g. for computing
    [out3 = 3 w + 1], implemented with constraint:
    {[{ a = w;        c = out3; identity = [ql 3; qo -1; qc 1] }]}
    This constraint can be merged with our previous block, which only uses
    wire b in its last constraint and has unbound selectors for it:
    {[
      { a = z;        c = out1; identity = [ql 9; qlg 5; qrg 3; qog -1] };
      { a = x; b = y; c = out2; identity = [ql 2; qr -3; qrg 7; qog -1] };
      { a = w; b = t; c = out3; identity = [ql 3; qo -1; qc 1] };
    ]}

    {1 Algorithm}

    The optimization proceeds in several steps:

    1. Collect information about all linear constraints/gates (see {!is_linear}),
    including the number of occurrences of their output wire among the circuit.

    2. Inline any {!type:linear_computation} whose output is used only once
    among the circuit, in the computation where it is used (only if the latter
    is also linear).

    3. Transform any non-inlined linear computation into a {!type:pseudo_block}
    and join pseudo blocks that share at least a wire (preferably two wires).

    4. Transform pseudo blocks to actual blocks (also considering the non-linear
    computations that were left aside) and join blocks that operate on
    independent wires together. *)

(* Number of wires per constraint in the target architecture (the "3-wires
   architecture" described in the header comment of this file). *)
let nb_wires_arch = 3

(* Local aliases for the raw constraint and scalar types exported by [Csir]. *)
type constr = Csir.CS.raw_constraint
type scalar = Csir.Scalar.t [@@deriving repr]

(* The goal of blocks and pseudo-blocks is to keep together constraints that
   depend on each other, e.g. those using next selectors. *)
type block = constr array

(* A term of a linear combination: coeff * wire *)
type term = scalar * int [@@deriving repr]

type trace_info = {
  free_wires : int list;
  assignments : (int * term list) list;
}
[@@deriving repr]
(** Type to store the trace_updater information.
    [free_wires] represents a list of indices of wires that have been
    removed from the system and have not been used again for auxiliary
    computations. If a free wire [w] is used to store an auxiliary linear
    computation on some [terms], it is removed from [free_wires] and the pair
    [(w, terms)] is added (at the front) to the list of [assignments], so the
    most recent assignment comes first. *)

type pseudo_constr = {
  pc_head : term list;
  pc_body : term list;
  pc_tail : term list;
  pc_const : Scalar.t;
}
(** A pseudo constraint represents an equation between an arbitrary number of
    terms. In particular, adding all terms from its head, body, tail and its
    constant term should produce 0.
    When a pseudo constraint is turned into an actual constraint (see
    [block_of_pseudo_block]), [pc_head] is encoded with the selectors of the
    constraint itself, [pc_tail] with its next-gate selectors, and the terms
    of [pc_body] still have to be moved to the head or the tail (or delegated
    to an auxiliary wire).
    The length of pc_head and pc_tail should never exceed [nb_wires_arch]. *)

type pseudo_block = pseudo_constr list
(** A pseudo block is a list of pseudo constraints that have been grouped
    in order to apply the {!shared_wires} heuristic.
    The constraints in a pseudo block satisfy (by construction) the invariant
    that the pc_head of a constraint starts with the pc_tail of the previous
    one (when considering the term variables, coefficients may be different).
    Note that the pc_head may contain extra terms that do not appear in the
    previous pc_tail.
    This invariant is important for [block_of_pseudo_block] to be correct. *)

(** [join_list ~join scope x xs] looks for a partner of [x] among the first
    [scope] elements of [xs]. For each candidate [c], [join x c] is attempted
    first and [join c x] second. On the first success, the union takes the
    place of the candidate and the resulting list is returned; [None] is
    returned if no candidate within [scope] could be joined with [x].
    The order of the elements of [xs] is not guaranteed to be preserved by
    the specification, callers should not rely on it. *)
let join_list :
    join:('a -> 'a -> 'a option) -> int -> 'a -> 'a list -> 'a list option =
 fun ~join scope x xs ->
  (* Try the join in both directions, [x] first. *)
  let attempt candidate =
    match join x candidate with
    | Some _ as merged -> merged
    | None -> join candidate x
  in
  let rec scan visited skipped remaining =
    match remaining with
    | [] -> None
    | _ when visited >= scope -> None
    | candidate :: others -> (
        match attempt candidate with
        | Some union -> Some (List.rev_append skipped (union :: others))
        | None -> scan (visited + 1) (candidate :: skipped) others)
  in
  scan 0 [] xs

(** [combine_quadratic ~scope ~join xs] takes a list of elements [xs] and
    returns a (potentially) shorter list where some of the elements have been
    combined according to [join]. [join x x'] is [None] if [x] and [x'] cannot
    be combined and it is [Some union] otherwise.
    The algorithm is greedy: every element is tentatively joined (through
    [join_list]) with one of the at most [scope] elements that follow it; when
    a union is found it replaces the partner and the element itself is
    dropped. With an unbounded [scope] this amounts to trying all the
    [n choose 2] pairings, where [n] is the length of [xs].
    Elements that could not be joined are accumulated in reverse order. *)
let combine_quadratic :
    scope:int -> join:('a -> 'a -> 'a option) -> 'a list -> 'a list =
 fun ~scope ~join groups ->
  let rec loop settled pending =
    match pending with
    | [] -> settled
    | candidate :: others -> (
        match join_list ~join scope candidate others with
        | Some merged_others -> loop settled merged_others
        | None -> loop (candidate :: settled) others)
  in
  loop [] groups

(** Similar to [combine_quadratic], but more efficient.
    Each element of [top_candidates] is tentatively joined (within [scope])
    with an element of [rest]. The candidates that found no partner are then
    combined among themselves with [combine_quadratic] and put back together
    with [rest]. This function may not preserve the order in [rest]. *)
let combine_quadratic_efficient ~scope ~join top_candidates rest =
  let rec loop unmatched candidates pool =
    match candidates with
    | x :: xs -> (
        match join_list ~join scope x pool with
        | Some pool -> loop unmatched xs pool
        | None -> loop (x :: unmatched) xs pool)
    | [] ->
        let leftovers = combine_quadratic ~scope ~join unmatched in
        List.rev_append leftovers pool
  in
  loop [] top_candidates rest

(** [add_terms terms1 terms2] is the sum of two linear combinations: the terms
    of both lists are gathered and terms on the same wire are merged into a
    single term by adding their coefficients. Terms whose coefficients cancel
    out are kept (with the resulting coefficient). *)
let add_terms : term list -> term list -> term list =
 fun terms1 terms2 ->
  let by_decreasing_wire (_, w1) (_, w2) = Int.compare w2 w1 in
  (* [merge] expects its input sorted by wire, so that terms on the same wire
     are adjacent. [(q, w)] is the term currently being accumulated. *)
  let rec merge acc (q, w) = function
    | [] -> (q, w) :: acc
    | (q', w') :: rest when w = w' -> merge acc (Scalar.add q q', w) rest
    | next :: rest -> merge ((q, w) :: acc) next rest
  in
  match List.sort by_decreasing_wire (List.rev_append terms1 terms2) with
  | [] -> []
  | first :: rest -> merge [] first rest

(** [inline w ~from ~into linear] works on an array of linear constraints
    (each one represented by its list of terms). It solves constraint
    [linear.(from)] for wire [w] and substitutes the result for [w] in
    [linear.(into)]; [linear.(from)] is emptied in the process.
    This routine should only be called if [w] does not appear in any other
    constraint. In that case, constraint [linear.(from)] can then be removed
    and wire [w] is said to be "free". *)
let inline w ~from ~into linear : unit =
  let is_on_w (_, wire) = w = wire in
  let (coeff_from, _), others_from = find_and_remove is_on_w linear.(from) in
  let (coeff_into, _), others_into = find_and_remove is_on_w linear.(into) in
  (* Each remaining term of [from] is scaled by [- coeff_into / coeff_from] *)
  let rescale (q, wire) =
    (Scalar.(negate q * coeff_into / coeff_from), wire)
  in
  linear.(from) <- [];
  linear.(into) <- add_terms (List.map rescale others_from) others_into

module IMap = Map.Make (Int)
(** Maps with integer keys. In [inline_linear] they are used in two ways:
    - to record wire occurrences: the keys are wire indices and the values
      are lists of constraint indices, such that the wire appears in each of
      the constraints;
    - to record pointers: the keys and the values are constraint indices. *)
(* TODO: sometimes for clarity I like to make the type of the values monomorphic
   This module can also be called Wire_occurrences *)

module ISet = Set.Make (Int)

(** Auxiliary function used in [inline_linear].

    [increase_occurrences map i terms] returns [map] where, for every
    [(_, w)] in [terms], [i] has been added at the front of the list of
    occurrences of wire [w] (a wire that was not bound yet gets the
    singleton list). *)
let increase_occurrences map i terms =
  let record_occurrence acc (_, w) =
    let prepend = function
      | None -> Some [ i ]
      | Some occurrences -> Some (i :: occurrences)
    in
    IMap.update w prepend acc
  in
  List.fold_left record_occurrence map terms

(** [pseudo_block_of_terms terms] is a pseudo block made of a single pseudo
    constraint whose head and tail are empty. The terms on a non-negative
    wire form its body, whereas the coefficient of the term on wire [-1]
    (if any) becomes its constant; the constant is zero otherwise.
    Fails with an assertion error if [terms] contains several terms on a
    negative wire, or a term on a negative wire other than [-1]. *)
let pseudo_block_of_terms terms =
  (* Thanks to [add_terms], a list of terms is expected to contain no
     duplicates with respect to the wire identifier, hence at most one
     constant term. *)
  let constants, wire_terms = List.partition (fun (_, w) -> w < 0) terms in
  let pc_const =
    match constants with
    | [] -> Scalar.zero
    | [ (q, -1) ] -> q
    | _ -> assert false
  in
  [ { pc_head = []; pc_body = wire_terms; pc_tail = []; pc_const } ]

(** [inline_linear ~nb_inputs gates] returns a triple
    [(pseudo_blocks, non_linear, free_wires)] where:
    - [pseudo_blocks] are the linear gates that remain after inlining,
      returned as pseudo blocks because they may need to be split again to
      fit 3 wires;
    - [non_linear] are the gates that were left unchanged, i.e., those that
      are not made of a single linear constraint;
    - [free_wires] is the sorted list of wires that were inlined (and thus
      freed).
    The wires selected for inlining must:
    - appear in exactly two linear gates of size one (a wire that appears in
      exactly one such gate is freed as well, and that gate is dropped)
    - not appear in any of the [non_linear] gates
    - not be inputs *)
let inline_linear ~nb_inputs gates =
  (* Only the gates made of a single linear constraint are candidates for
     inlining, any other gate is kept as it is in [non_linear]. *)
  let linear, non_linear =
    List.partition_map
      (fun gate ->
        if Array.length gate = 1 && CS.is_linear_raw_constr gate.(0) then
          Left gate.(0)
        else Right gate)
      gates
  in
  (* From now on, every linear constraint is represented by its list of terms,
     stored in an array that is updated in place by [inline]. *)
  let linear = Array.(map CS.linear_terms @@ of_list linear) in

  let non_linear_wires =
    List.concat_map CS.gate_wires non_linear |> ISet.of_list
  in

  (* Associate each wire to the indices (of array [linear]) of the constraints
     where the wire appears. The only wires considered must not:
     - be a circuit input
     - appear in a non-linear gate.
     Given that [increase_occurrences] adds indices at the front, every list
     of occurrences is sorted by decreasing constraint index.
     NOTE(review): the comparison with [nb_inputs] is strict, so wire
     [nb_inputs] itself is never considered either; confirm that this matches
     the numbering convention of input wires. *)
  let wire_occurrences =
    Array.fold_left
      (fun (map, i) gate -> (increase_occurrences map i gate, succ i))
      (IMap.empty, 0) linear
    |> fst
    |> IMap.filter (fun w _ ->
           w > nb_inputs && (not @@ ISet.mem w non_linear_wires))
  in

  (* During the process of inlining, some wire occurrences will point to
     constraints that have been removed (inlined).
     In [pointers] we carry the information of where they were inlined into.
     [updated_pointer occ] follows the chain of pointers that starts at [occ]
     up to a constraint that has not been inlined, shortening the chain on
     the way. *)
  let pointers = ref IMap.empty in
  let rec updated_pointer occ =
    match IMap.find_opt occ !pointers with
    | Some p ->
        let p' = updated_pointer p in
        (* Create a direct connection between [occ] and [p'] *)
        pointers := IMap.add occ p' !pointers;
        p'
    | None -> occ
  in

  (* If a wire occurs only on two constraints, we remove one
     of them, "inlining" it in the other, the wire disappears.
     Wires are visited by increasing index (the order of [IMap.fold]). *)
  let free_wires, removable_constrs =
    IMap.fold
      (fun w occs (free_wires, removable_constrs) ->
        match occs with
        | [ occ ] ->
            (* We could raise an exception saying that wire [w] is unnecessary *)
            (* The wire is freed and its only constraint is scheduled for
               removal (once the whole inlining has finished). *)
            let occ = updated_pointer occ in
            (w :: free_wires, occ :: removable_constrs)
        | [ occ1; occ2 ] ->
            let from = updated_pointer occ1 in
            let into = updated_pointer occ2 in
            (* NOTE(review): if both occurrences resolve to the same
               constraint ([from = into]), the binding added below makes
               [from] point to itself and a later call to [updated_pointer]
               on it would not terminate. Presumably this cannot happen;
               to be confirmed. *)
            inline w ~from ~into linear;
            (* References to [from] should refer to [into] instead *)
            pointers := IMap.add from into !pointers;
            (w :: free_wires, removable_constrs)
        | _ -> (free_wires, removable_constrs))
      wire_occurrences ([], [])
  in

  (* Drop removable constraints (only now that the inlining has finished) *)
  List.iter (fun c -> linear.(c) <- []) removable_constrs;
  (* Constraints that were inlined or dropped are empty at this point *)
  let linear = List.filter (fun ts -> ts <> []) (Array.to_list linear) in

  (* Sort free_vars to reach allocation order. Is this really needed? *)
  let free_wires = List.sort Int.compare free_wires in

  let pseudo_blocks = List.map pseudo_block_of_terms linear in

  (pseudo_blocks, non_linear, free_wires)

(** [place_over_pseudo_block ~perfectly b1 b2] tries to join pseudo block [b1]
    on top of [b2], i.e., link the last pseudo constraint of [b1] with the
    first pseudo constraint of [b2] through wires that appear in the bodies
    of both. It returns [Some block] with the combined pseudo block if the
    combination was successful and [None] otherwise; in particular, the
    combination fails if the last pseudo constraint of [b1] already has a
    tail or the first one of [b2] already has a head.
    [perfectly] is a Boolean argument that specifies whether only "perfect"
    matches should be accepted. In this context, "perfect" means that the two
    constraints to be combined share exactly [nb_wires_arch] wires. Otherwise
    sharing [nb_wires_arch - 1] wires is accepted too. *)
let place_over_pseudo_block ~perfectly b1 b2 =
  let upper_rest, upper_last = split_last b1 in
  let lower_first, lower_rest = split_first b2 in
  match (upper_last.pc_tail, lower_first.pc_head) with
  | [], [] ->
      let wires_of terms = List.map snd terms in
      let upper_wires = wires_of upper_last.pc_body in
      let lower_wires = wires_of lower_first.pc_body in
      let shared_wires =
        list_intersection ~equal:( = ) upper_wires lower_wires
      in

      (* The head of [lower_first] is empty, so [nb_wires_arch] terms can be
         placed on it; the total room depends on how occupied its tail is *)
      let lower_room = (2 * nb_wires_arch) - List.length lower_first.pc_tail in
      let lower_size = List.length lower_wires in

      (* Number of shared wires to be picked: if the body of [lower_first]
         fits completely into its head, there is no need to leave an extra
         space for an auxiliary variable *)
      let n =
        if lower_size <= lower_room && lower_size = nb_wires_arch then
          nb_wires_arch
        else nb_wires_arch - 1
      in
      let common = ISet.of_list (rev_first_n ~n shared_wires) in
      let nb_common = ISet.cardinal common in

      if perfectly && nb_common <> nb_wires_arch then None
      else if nb_wires_arch - nb_common > 1 then None
      else
        let is_common (_, w) = ISet.mem w common in
        let by_wire (_, w1) (_, w2) = Int.compare w1 w2 in
        (* The common terms are sorted by wire so that they appear in the
           same order in the tail of one constraint and the head of the
           other *)
        let extract body =
          let shared, others = List.partition is_common body in
          (List.sort by_wire shared, others)
        in
        let upper_tail, upper_body = extract upper_last.pc_body in
        let lower_head, lower_body = extract lower_first.pc_body in
        let upper =
          { upper_last with pc_tail = upper_tail; pc_body = upper_body }
        in
        let lower =
          { lower_first with pc_head = lower_head; pc_body = lower_body }
        in
        Some (upper_rest @ (upper :: lower :: lower_rest))
  | _ -> None

(** [combine_pseudo_blocks ~scope ~perfectly pseudo_blocks] greedily joins
    the given pseudo blocks with [combine_quadratic]. Two pseudo blocks are
    joined by placing one over the other with [place_over_pseudo_block]
    (both orders are attempted). *)
let combine_pseudo_blocks ~scope ~perfectly pseudo_blocks =
  let join b1 b2 =
    match place_over_pseudo_block ~perfectly b1 b2 with
    | Some _ as joined -> joined
    | None -> place_over_pseudo_block ~perfectly b2 b1
  in
  combine_quadratic ~scope ~join pseudo_blocks

(** [linear_combination ~this ~next qc] builds a linear constraint on the
    wires of [this] (padded with wire 0 up to [nb_wires_arch] wires).
    The coefficients of [this] are bound, in order, to the first linear
    selectors of the constraint itself, the coefficients of [next] to the
    first next-gate linear selectors, and [qc] to selector "qc". *)
let linear_combination ~this ~next qc =
  let nb_unused = nb_wires_arch - List.length this in
  let wires = List.map snd this @ List.init nb_unused (fun _ -> 0) in
  (* Pair the coefficients of [terms] with as many selector names *)
  let name_coeffs selector_names terms =
    let coeffs = List.map fst terms in
    List.combine (list_sub (List.length coeffs) selector_names) coeffs
  in
  let this_sels = name_coeffs CS.this_constr_linear_selectors this in
  let next_sels = name_coeffs CS.next_constr_linear_selectors next in
  CS.mk_linear_constr (wires, (("qc", qc) :: this_sels) @ next_sels)

(** [dummy_linear_combination terms] builds a linear constraint without any
    selector on the wires of [terms] (padded with wire 0 up to
    [nb_wires_arch] wires). The coefficients of [terms] are ignored. *)
let dummy_linear_combination terms =
  let wires = List.map snd terms in
  let padding = List.init (nb_wires_arch - List.length wires) (fun _ -> 0) in
  CS.mk_linear_constr (wires @ padding, [])

(** Given a pseudo block, i.e., a list of pseudo constraints, transform it into
    a block. Thanks to the invariant that the pc_head of a pseudo constraint
    starts with the pc_tail of the previous one, it is enough to add a constraint
    for each pc_head (and encode each pc_tail with next-gate selectors).
    Note that pc_bodies may need to be freed through auxiliary variables.

    [block_of_pseudo_block trace_info pseudo_block] returns a triple
    [(trace_info', block, residues)] where [trace_info'] is [trace_info]
    updated with the auxiliary wires that were taken from its [free_wires]
    (and their assignments), [block] is the resulting block and [residues]
    are the new pseudo blocks that define such auxiliary wires. *)
let block_of_pseudo_block trace_info pseudo_block =
  let pseudo_block = Array.of_list pseudo_block in
  (* The accumulator carries the trace information, the constraints produced
     so far, the residues produced so far and the index of the current pseudo
     constraint. *)
  Array.fold_left
    (fun (trace_info, block, residue_pseudo_blocks, i) b ->
      (* We need to move all terms from b.pc_body to its head or tail.
         If there are more terms than those that can be moved, we will
         need an auxiliary variable, to be put in the head *)

      (* TODO: select heuristically what to place and where, e.g., leave
         for the residue the terms that are more frequent among the circuit *)
      let is_last = i = Array.length pseudo_block - 1 in
      (* Room left in the head, plus the room left in the tail when this is
         the last pseudo constraint (otherwise the tail is already linked to
         the head of the next one and is left untouched). *)
      let free_room = nb_wires_arch - List.length b.pc_head in
      let free_room =
        free_room + if is_last then nb_wires_arch - List.length b.pc_tail else 0
      in
      let aux_var_needed = List.length b.pc_body > free_room in
      (* If an auxiliary variable is needed, a slot of the head is reserved
         for it. *)
      let n = if aux_var_needed then nb_wires_arch - 1 else nb_wires_arch in
      (* NOTE(review): [complete ~n l pending] presumably moves terms from
         [pending] to [l] until [l] has [n] terms and returns the completed
         list together with the terms that are still pending; confirm
         against [Optimizer_helpers]. *)
      let head, pending = complete ~n b.pc_head b.pc_body in
      let tail, pending =
        if is_last then complete ~n:nb_wires_arch b.pc_tail pending
        else (b.pc_tail, pending)
      in
      let trace_info, head, residues =
        if not aux_var_needed then (
          assert (pending = []);
          (trace_info, head, []))
        else
          (* Take a free wire [v] to hold the sum of the pending terms.
             NOTE(review): [split_first] presumably fails if no free wire is
             left; confirm that there are always enough free wires. *)
          let v, free_wires = split_first trace_info.free_wires in
          let assignments = (v, pending) :: trace_info.assignments in
          assert (List.length head < nb_wires_arch);
          let head = head @ [ (Scalar.one, v) ] in
          let trace_info = { free_wires; assignments } in
          (* The residue enforces that [v] equals the sum of the pending
             terms: - v + pending = 0 *)
          let residues =
            [ pseudo_block_of_terms @@ ((Scalar.(negate one), v) :: pending) ]
          in
          (trace_info, head, residues)
      in
      (* The next constraint will have a head that starts with this tail *)
      let constr1 = linear_combination ~this:head ~next:tail b.pc_const in
      (* The tail of the last pseudo constraint is not followed by any head,
         a constraint without selectors is added to hold its wires. *)
      let constrs =
        if is_last && tail <> [] then [ constr1; dummy_linear_combination tail ]
        else [ constr1 ]
      in
      (trace_info, block @ constrs, residues @ residue_pseudo_blocks, succ i))
    (trace_info, [], [], 0) pseudo_block
  |> fun (trace_info, block, residue_pseudo_blocks, _) ->
  (trace_info, Array.of_list block, residue_pseudo_blocks)

(** [blocks_of_pseudo_blocks ~scope free_wires pseudo_blocks] combines the
    given pseudo blocks if possible (first looking for perfect matches, then
    for any match) and converts them into blocks. This conversion may lead to
    residue pseudo blocks (if auxiliary variables, taken from [free_wires],
    were involved to free a pc_body), in which case the same process is
    applied on the residue pseudo blocks, until no residue is left.
    Observe that it is relatively rare that a residue pseudo block is produced
    after the conversion. Consequently the second iteration of this function
    will be among many fewer pseudo blocks and it quickly reaches termination.
    Returns the resulting trace information together with the blocks. *)
let blocks_of_pseudo_blocks ~scope free_wires pseudo_blocks =
  (* Convert one pseudo block, collecting its block and its residues *)
  let convert (trace_info, blocks, residues) pb =
    let trace_info, block, new_residues =
      block_of_pseudo_block trace_info pb
    in
    (trace_info, block :: blocks, List.rev_append new_residues residues)
  in
  let rec iteration trace_info blocks = function
    | [] -> (trace_info, blocks)
    | pending ->
        let joined =
          combine_pseudo_blocks ~scope ~perfectly:false
            (combine_pseudo_blocks ~scope ~perfectly:true pending)
        in
        let trace_info, blocks, residues =
          List.fold_left convert (trace_info, blocks, []) joined
        in
        iteration trace_info blocks residues
  in
  iteration { free_wires; assignments = [] } [] pseudo_blocks

(** [place_over_block ~perfectly b1 b2] tries to join block [b1] on top
    of [b2]. It works only if the last constraint of [b1] has no selectors
    and [b1] has at least two constraints (the next-gate selectors of its
    second-to-last constraint need to be renamed); [None] is returned
    otherwise, or if the wires of the two blocks cannot be adapted.
    [perfectly] returns only matches that have no empty selectors (no holes).
    It is relevant for a heuristic where we first want to find perfect matches
    and then any match.

    Neither [b1] nor [b2] is modified: the result is a fresh array.

    The joining is done by
    - finding a permutation that makes the wires of b1.last and b2.first match
      (being relaxed about unset wires)
    - checking if the permutation is perfect and failing if required
    - renaming the selectors of b1.last
    - renaming the next selectors of b1.second_last

    Example:
    [ (1 1 1)  [ql:a qlg:a]
      (3 2 -1) [] ]
    [ (2 3 4)  [ql:a]
      (1 1 1)  [ql:a] ]
    becomes
    [ (1 1 1)  [ql:a qlg:b]
      (2 3 4)  [ql:a]
      (1 1 1)  [ql:a] ]
*)
let place_over_block : perfectly:bool -> block -> block -> block option =
 fun ~perfectly b1 b2 ->
  let len1 = Array.length b1 in
  (* Both the last and the second-to-last constraints of [b1] are needed *)
  if len1 < 2 then None
  else
    let b1_k = b1.(len1 - 1) in

    if b1_k.sels <> [] then None
    else
      let wires1 = CS.wires_of_constr_i b1 (len1 - 1) in
      let wires2 = CS.wires_of_constr_i b2 0 in

      match adapt wires1 wires2 with
      | None -> None
      | Some perm ->
          let wires1 = permute_list perm wires1 in

          (* Wires that are unset (negative) in [b1] are taken from [b2] *)
          let combined =
            List.map2 (fun x y -> if x >= 0 then x else y) wires1 wires2
          in

          let is_perfect = List.for_all (fun x -> x >= 0) combined in
          if perfectly && not is_perfect then None
          else
            (* The wires of the last constraint of [b1] have been permuted,
               the next-gate selectors that refer to them must follow *)
            let rename n =
              let selectors_assoc =
                List.combine
                  (permute_list perm CS.next_constr_linear_selectors)
                  CS.next_constr_linear_selectors
              in
              assert (List.length selectors_assoc = nb_wires_arch);
              match List.find_opt (fun (s, _) -> s = n) selectors_assoc with
              | Some (_, s') -> s'
              | None -> n
            in
            let b1_k_1 = b1.(len1 - 2) in
            let sels_prev =
              List.map (fun (n, q) -> (rename n, q)) b1_k_1.sels
            in

            let a, b, c =
              (* TODO: Revisit [type raw_constraint] and how to represent unused
                 wires. Should the conversion from -1 to 0 be done here? *)
              match List.map (max 0) combined with
              | [ a; b; c ] -> (a, b, c)
              | _ -> assert false
            in
            let b2_0 = b2.(0) in
            let b2_rest = Array.sub b2 1 (Array.length b2 - 1) in
            (* [Array.append] returns a fresh array: the last two constraints
               of [b1] are overwritten in the copy, so that the array received
               as argument (possibly owned by the caller) is left untouched *)
            let merged : block = Array.append b1 b2_rest in
            merged.(len1 - 1) <-
              { a; b; c; sels = b2_0.sels; label = "hyb" :: b2_0.label };
            merged.(len1 - 2) <- { b1_k_1 with sels = sels_prev };
            Some merged

(** [combine_blocks ~scope ~perfectly blocks] greedily stacks blocks on top
    of each other with [place_over_block]. The blocks whose last constraint
    has no selectors are the candidates to be placed over another block;
    each of them is tentatively joined with one of the remaining blocks
    (within [scope]) through [combine_quadratic_efficient]. *)
let combine_blocks ~perfectly blocks ~scope =
  let last_has_no_selectors b =
    let open CS in
    match b.(Array.length b - 1).sels with [] -> true | _ :: _ -> false
  in
  let stackable, others = List.partition last_has_no_selectors blocks in
  combine_quadratic_efficient ~scope
    ~join:(place_over_block ~perfectly)
    stackable others

(** Takes {!type:trace_info} and returns a function that updates a given
    trace to make it compatible with the optimized constraints: every wire
    recorded in the assignments is set to the value of its linear
    combination, evaluated on the trace.
    The given array is updated in place and it is also returned. *)
let trace_updater : trace_info -> scalar array -> scalar array =
 fun trace_info trace ->
  (* Value of a linear combination on the current contents of the trace *)
  let eval terms =
    List.fold_left
      (fun acc (coeff, i) -> Scalar.add acc (Scalar.mul coeff trace.(i)))
      Scalar.zero terms
  in
  let assign (wire, terms) = trace.(wire) <- eval terms in
  (* Assignments are stored most recent first: they are replayed in reverse
     to ensure that variables are assigned in the correct order, in case
     they depend on each other. *)
  trace_info.assignments |> List.rev |> List.iter assign;
  trace

(** [optimize ~nb_inputs gates] takes a list of gates and returns an
    equivalent constraint system with potentially fewer constraints.
    As a second output, it returns the necessary information to build a
    function that updates the trace to make it compatible with the new
    constraints (see [trace_updater]). *)
let optimize : nb_inputs:int -> CS.gate list -> CS.gate list * trace_info =
 fun ~nb_inputs gates ->
  let linear_pseudo_blocks, non_linear, free_wires =
    inline_linear ~nb_inputs gates
  in
  (* A narrower search scope is used when there are many non-linear gates *)
  let scope =
    if List.compare_length_with non_linear 10_000 > 0 then 10 else 100
  in
  let trace_info, linear_blocks =
    blocks_of_pseudo_blocks ~scope free_wires linear_pseudo_blocks
  in
  let all_blocks = List.rev_append non_linear linear_blocks in
  let perfectly_combined = combine_blocks ~scope ~perfectly:true all_blocks in
  (* The blocks are shuffled to mix linear with non-linear ones and this way
     allow for more likely combinations; an arbitrary fixed seed is used for
     determinism *)
  let shuffled = perfectly_combined |> shuffle_list ~seed:1031 in
  let cs = combine_blocks ~scope ~perfectly:false shuffled in
  (cs, trace_info)