Source file SolverLo.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
(******************************************************************************)
(*                                                                            *)
(*                                  Inferno                                   *)
(*                                                                            *)
(*                       François Pottier, Inria Paris                        *)
(*                                                                            *)
(*  Copyright Inria. All rights reserved. This file is distributed under the  *)
(*  terms of the MIT License, as described in the file LICENSE.               *)
(*                                                                            *)
(******************************************************************************)

open UnifierSig
open SolverSig

module Make
  (X : TEVAR)
  (S : STRUCTURE)
  (O : OUTPUT with type 'a structure = 'a S.structure)
= struct

(* -------------------------------------------------------------------------- *)

(* The type [tevar] of term variables is provided by [X]. It is a transparent
   alias, so [tevar] and [X.tevar] are interchangeable. *)

type tevar =
  X.tevar

(* [XMap] provides finite maps keyed by term variables. The key type is
   [tevar]; everything else that [Map.Make] requires of its argument is
   taken from [X] through [include X]. *)

module XMap =
  Map.Make(struct include X type t = tevar end)

(* The generalization engine is parameterized over the inference structure
   [S]. Instantiating it yields the generalization module [G], which also
   exposes its underlying unifier as [G.U]; we abbreviate the latter as [U]. *)
module G = Generalization.Make(S)
module U = G.U

(* The type [variable] of constraint variables is [int]: a constraint variable
   is a plain integer identifier. It is distinct from the unifier's variables,
   of type [U.variable]; [solve] maintains the mapping between the two. *)

type variable = int

(* [fresh ()] returns a new integer identifier on every call: 0 on the first
   call, then 1, 2, and so on. The counter is private to this closure. *)
let fresh =
  let next = ref 0 in
  fun () ->
    let id = !next in
    next := id + 1;
    id

(* The type of type schemes, [G.scheme], is provided by the generalization
   engine [G]. *)

(* -------------------------------------------------------------------------- *)

(* Decoding types. *)

(* A variable is decoded to its unique integer identifier, which (via the
   function [O.variable]) is turned into an output type. *)

let decode_variable (x : U.variable) : O.tyvar =
  (* Decoding only makes sense once the solver has run, which is what the
     assertion [G.registered x] is meant to guarantee. At the time of writing,
     the API does not expose the decoder, so a client should have no way of
     tripping this assertion. *)
  let () = assert (G.registered x) in
  (* The output type variable is built from the unifier's identifier of [x]. *)
  x |> U.id |> O.solver_tyvar

(* [decode_variable_as_type x] decodes [x] to an output type variable, then
   injects that variable into output types with [O.variable]. *)
let decode_variable_as_type (x : U.variable) : O.ty =
  x |> decode_variable |> O.variable

(* A type decoder is a function that transforms a unifier variable into an
   output type. We choose to decode types in an eager manner; that is, we take
   care of invoking the decoder, so that the client never needs to perform this
   task. As a result, we do not even need to expose the decoder to the client
   (although we could, if desired). *)

(* ['v decoder] is the type of functions that turn a value of type ['v] into
   an output type [O.ty]. In this file it is instantiated at [U.variable]. *)
type 'v decoder =
  'v -> O.ty

(* A cyclic decoder must not keep its state across invocations, otherwise the
   [mu] binders could end up incorrectly placed in the output. This is why
   [decode_cyclic] builds a brand new cyclic decoder each time it is applied
   to a variable. *)

let decode_cyclic : U.variable decoder =
  fun v ->
    (* [bind_mu y body] wraps [body] in a [mu] binder for the variable [y]. *)
    let bind_mu y body = O.mu (decode_variable y) body in
    U.new_cyclic_decoder decode_variable_as_type O.structure bind_mu v

(* Because [O.ty] is a nominal representation of types, a type is decoded
  in the same way, regardless of how many type binders we have entered.
  This makes it possible for the state of an (acyclic) decoder to persist
  between invocations. Thanks to this property, the type decoding process
  requires only linear time and space, regardless of how many calls to the
  decoder are performed. *)

(* [new_decoder ~rectypes] returns a decoder. When [rectypes] is [true], the
   result is [decode_cyclic], which has no persistent state. When [rectypes]
   is [false], the result is a freshly created acyclic decoder, whose state
   persists between invocations. *)

let new_decoder ~rectypes =
  (* The acyclic decoder is always allocated, even when it ends up unused. *)
  let decode_acyclic : U.variable decoder =
    U.new_acyclic_decoder decode_variable_as_type O.structure
  in
  match rectypes with
  | true -> decode_cyclic
  | false -> decode_acyclic

(* [decode_scheme decode s] converts the scheme [s] into an output scheme: a
   pair of its quantifiers, each decoded with [decode_variable], and of its
   body, decoded with the type decoder [decode] passed as a parameter. *)

let decode_scheme (decode : U.variable decoder) (s : G.scheme) : O.scheme =
  (s |> G.quantifiers |> List.map decode_variable, s |> G.body |> decode)

(* -------------------------------------------------------------------------- *)

(* The syntax of constraints is as follows. *)

(* This syntax is exposed to the user in the low-level interface [SolverLo],
   but not in the high-level interface [SolverHi]. So, it could be easily
   modified if desired. *)

(* A [range] is a pair of lexing positions. It is attached to constraints via
   [CRange] and carried by the exceptions [Unify], [Cycle] and [Unbound].
   NOTE(review): presumably the start and end of a source span -- confirm. *)
type range =
  Lexing.position * Lexing.position

(* The parameter of ['a rawco] is the type of the witness that [solve]
   produces for the constraint. The per-constructor notes describe how
   [solve] treats each form. *)
type _ rawco =
(* [CRange (range, c)]: [c], with [range] as the range reported in errors. *)
| CRange : range * 'a rawco -> 'a rawco
(* [CTrue]: nothing to solve; the witness is [()]. *)
| CTrue : unit rawco
(* [CConj (c1, c2)]: [c1] is solved, then [c2]; the witness is the pair of
   their witnesses. *)
| CConj : 'a rawco * 'b rawco -> ('a * 'b) rawco
(* [CEq (v, w)]: the variables [v] and [w] are unified. *)
| CEq : variable * variable -> unit rawco
(* [CExist (v, s, c)]: [v] is bound to a fresh unifier variable, carrying
   the structure [s] if one is given, and then [c] is solved. *)
| CExist : variable * variable S.structure option * 'a rawco -> 'a rawco
(* [CWitness v]: the witness is the decoded type of [v]. *)
| CWitness : variable -> O.ty rawco
(* [CInstance (x, w)]: the scheme bound to [x] is instantiated and the
   instance is unified with [w]; the witness is the list of decoded
   instantiation witnesses. *)
| CInstance : tevar * variable -> O.ty list rawco
(* [CDef (x, v, c)]: [c] is solved in an environment where [x] is bound to
   the scheme [G.trivial] applied to (the unifier variable of) [v]. *)
| CDef : tevar * variable * 'a rawco -> 'a rawco
(* [CLet (xvs, c1, c2)]: [c1] is solved with the variables of [xvs] bound,
   schemes are then constructed for them, and [c2] is solved with each term
   variable of [xvs] bound to its scheme. The witness is made of the decoded
   generalizable variables, the decoded scheme of each term variable, the
   witness of [c1] and the witness of [c2]. *)
| CLet : (tevar * variable) list
         * 'a rawco
         * 'b rawco
       -> (O.tyvar list * (tevar * O.scheme) list * 'a * 'b) rawco
(* [CMap (c, f)]: the witness is [f] applied to the witness of [c]. *)
| CMap : 'a rawco * ('a -> 'b) -> 'b rawco

(* [pprint_rawco c] renders the constraint [c] as a PPrint document.

   The printer is layered by syntactic category: binding forms ([exi], [def],
   [let]), then conjunctions, then simple constraints, then a last layer that
   wraps the constraint in parentheses and starts over from the first layer.
   Every layer skips [CRange] and [CMap] nodes, which are not displayed, and
   hands the forms that it does not handle over to the next layer.

   Warning 4 (fragile pattern matching) is disabled in the body; presumably
   this is because of the wildcard cases used to fall through to the next
   layer. *)
let pprint_rawco c : PPrint.document = begin[@warning "-4"]
  let open! PPrint in
  (* Leaf printers for constraint variables and term variables. *)
  let string_of_var v = string_of_int v in
  let print_var v = string (string_of_var v) in
  let print_tevar x = string (X.to_string x) in
  (* A term variable followed by the constraint variable attached to it. *)
  let annot x v = separate space [print_tevar x; print_var v] in

  let rec print : type a . a rawco -> _ = fun c -> binders c

  (* Binding forms. *)
  and binders : type a . a rawco -> _ = fun c ->
    group (
      match c with
      | CRange (_, body) -> binders body
      | CMap (body, _) -> binders body
      | CExist (v, s, body) ->
        let structure =
          match s with
          | None -> empty
          | Some s ->
            space ^^ string "~" ^^ space ^^ string (S.to_string string_of_var s)
        in
        string "exi" ^^ space
        ^^ group (print_var v ^^ structure)
        ^^ space ^^ string "in"
        ^//^  binders body
      | CDef (x, v, body) ->
        separate space [
          string "def"; annot x v;
          string "in";
        ] ^//^ binders body
      | CLet (xvs, c1, c2) ->
        let bindings =
          separate_map (space ^^ string "and" ^^ space)
            (fun (x, v) -> annot x v) xvs
        in
        separate space [
          string "let";
          bindings;
          string "=";
          conj c1;
          string "in";
        ] ^//^ binders c2
      | _ -> conj c
    )

  (* Conjunctions. *)
  and conj : type a . a rawco -> _ = fun c ->
    group (
      match c with
      | CRange (_, body) -> conj body
      | CMap (body, _) -> conj body
      | CConj (c1, c2) -> conj c1 ^//^ separate space [string "*"; conj c2]
      | _ -> simple c
    )

  (* Simple constraints. *)
  and simple : type a . a rawco -> _ = fun c ->
    group (
      match c with
      | CRange (_, body) -> simple body
      | CMap (body, _) -> simple body
      | CTrue -> string "True"
      | CWitness v -> separate space [string "wit"; print_var v]
      | CEq (v1, v2) -> separate space [print_var v1; string "="; print_var v2]
      | CInstance (x, v) ->
        separate space [string "inst"; print_tevar x; print_var v]
      | _ -> parenthesized c
    )

  (* Last resort: print the constraint between parentheses, starting again
     from the first layer. All constructors are listed explicitly. *)
  and parenthesized : type a . a rawco -> _ = fun c ->
    match c with
    | CRange _ | CMap _ | CTrue | CConj _ | CEq _
    | CExist _ | CWitness _ | CInstance _ | CDef _ | CLet _
      ->
      surround 2 0 lparen (print c) rparen
  in
  print c
end

(* [print_rawco ppf c] pretty-prints the constraint [c] on the formatter
   [ppf], with the arguments [0.9] and [80] passed to
   [PPrint.ToFormatter.pretty], and then emits a newline on [ppf]. *)
let print_rawco ppf c =
  let doc = pprint_rawco c in
  PPrint.ToFormatter.pretty 0.9 80 ppf doc;
  Format.pp_print_newline ppf ()

(* -------------------------------------------------------------------------- *)

(* The exceptions [Unify] and [Cycle], raised by the unifier, must be caught
   and re-raised in a slightly different format, as the unifier does not know
   about ranges.

   Note that the cyclic decoder is used when decoding types arising in typing
   errors/exceptions, even if the solver was called with [~rectypes:false]:
   recursive types can appear before the occurs check is successfully run.
 *)

(* [Unify (range, ty1, ty2)] is raised by the wrapper [unify] when the unifier
   raises [U.Unify]; [ty1] and [ty2] are the decoded forms of the two
   variables carried by [U.Unify]. *)
exception Unify of range * O.ty * O.ty

(* [Cycle (range, ty)] is raised by the wrapper [exit] when [G.exit] raises
   [U.Cycle]; [ty] is the decoded form of the variable carried by [U.Cycle]. *)
exception Cycle of range * O.ty

(* [unify range v1 v2] unifies the unifier variables [v1] and [v2]. If the
   unifier raises [U.Unify], the two variables that it reports are decoded
   with a decoder obtained with [~rectypes:true], and the exception [Unify],
   which also carries [range], is raised instead. *)
let unify range v1 v2 =
  match U.unify v1 v2 with
  | result -> result
  | exception U.Unify (w1, w2) ->
      let decode = new_decoder ~rectypes:true in
      raise (Unify (range, decode w1, decode w2))

(* [exit range ~rectypes state vs] returns the result of
   [G.exit ~rectypes state vs]. If that call raises [U.Cycle], the variable
   that it reports is decoded with a decoder obtained with [~rectypes:true],
   and the exception [Cycle], which also carries [range], is raised
   instead. *)
let exit range ~rectypes state vs =
  match G.exit ~rectypes state vs with
  | result -> result
  | exception U.Cycle v ->
      let decode = new_decoder ~rectypes:true in
      raise (Cycle (range, decode v))

(* -------------------------------------------------------------------------- *)

(* The non-recursive wrapper function [solve] is parameterized by the
   flag [rectypes], which indicates whether recursive types are
   permitted.

   [solve] expects a constraint of type [a rawco] and solves it,
   returning a witness of type [a]. It proceeds in two phases:

   - resolution: first the constraint is traversed, with each part
     solved eagerly,

   - witness construction: if the resolution succeeded, a global
     solution is available, and we construct a witness from the solution.

   This is implemented by having the first phase return a closure of
   type ['a on_sol], that describes how the witnesses should be
   computed when the final value will be available.
*)
(* [Unbound (range, x)] is raised by [solve] when a [CInstance] constraint
   refers to a term variable [x] that is not bound in the environment;
   [range] is the range annotation that was most recently encountered. *)
exception Unbound of range * tevar

(* ['a on_sol] represents a value of type ['a] that only becomes available
   once the solver has succeeded in finding a global solution, as opposed to
   a value that is available as soon as the solver is done traversing a
   sub-constraint. It wraps a thunk of type [unit -> 'a]; the [@@unboxed]
   attribute asks the compiler to represent [On_sol f] like [f] itself. *)
type 'a on_sol = On_sol of (unit -> 'a) [@@unboxed]

(* [solve ~rectypes c] solves the constraint [c] and returns its witness.

   The constraint is first traversed by an inner recursive function, which
   performs unification and generalization eagerly and returns a suspended
   witness of type [a on_sol]. Once the whole constraint has been traversed
   without error, the suspended witness is forced and its value returned.

   [rectypes] is passed on to the generalization engine and selects the
   decoder that is used to build witnesses (see [new_decoder]).

   Exceptions:
   - [Unify] if a unification fails;
   - [Cycle] if [G.exit] reports a cycle;
   - [Unbound] if a [CInstance] constraint refers to a term variable that is
     not bound by an enclosing [CDef] or [CLet];
   - [Failure] with the message "Unbound variable <n>" if the constraint
     refers to a constraint variable that has not been bound by an enclosing
     [CExist] or [CLet]. *)
let solve ~(rectypes : bool) (type a) (c : a rawco) : a =

  (* Initialize the generalization engine. It has mutable state, so [state]
     does not need to be an explicit parameter of the recursive function
     [solve]. *)

  let state = G.init() in

  (* Constraint variables are immutable, they carry no mutable state.
     We map them to unification variables (uvars) by using a table
     indexed over the constraint variable integer identifier. *)
  let module VarTbl = Hashtbl.Make(struct
    type t = variable
    let hash v = Hashtbl.hash v
    let equal v1 v2 = Int.equal v1 v2
  end) in
  let uvars = VarTbl.create 8 in

  (* [uvar v] is the unification variable bound to the constraint variable
     [v]. The lookup uses [find_opt], so that a variable which has not been
     bound is reported by the explicit failure below rather than by a bare
     [Not_found] escaping from the table. *)
  let uvar v =
    match VarTbl.find_opt uvars v with
    | None -> Printf.ksprintf failwith "Unbound variable %d" v
    | Some uv -> uv
  in

  (* [bind_uvar v s] creates a fresh unification variable for the constraint
     variable [v], carrying the structure [s] (if any) translated to
     unification variables, records it in the table, and registers it with
     the generalization engine. [v] must not be bound already. *)
  let bind_uvar v s =
    assert (not (VarTbl.mem uvars v));
    let uv = G.fresh (Option.map (S.map uvar) s) in
    VarTbl.add uvars v uv;
    G.register state uv;
    uv
  in

  let decode = new_decoder ~rectypes in

  (* The recursive function [solve] is parameterized with an environment
     (which maps term variables to type schemes) and with a range (it is the
     range annotation that was most recently encountered on the way down). *)

  let rec solve : type a . G.scheme XMap.t -> range -> a rawco -> a on_sol =
    fun env range c -> match c with
    | CRange (range, c) ->
        solve env range c
    | CTrue ->
        On_sol (fun () -> ())
    | CMap (c, f) ->
        let (On_sol r) = solve env range c in
        On_sol (fun () -> f (r ()))
    | CConj (c1, c2) ->
        let (On_sol r1) = solve env range c1 in
        let (On_sol r2) = solve env range c2 in
        On_sol (fun () -> (r1 (), r2 ()))
    | CEq (v, w) ->
        unify range (uvar v) (uvar w);
        On_sol (fun () -> ())
    | CExist (v, s, c) ->
        ignore (bind_uvar v s);
        solve env range c
    | CWitness v ->
        (* The lookup of [v] is delayed until the witness is requested. *)
        On_sol (fun () -> decode (uvar v))
    | CInstance (x, w) ->
        (* The environment provides a type scheme for [x]. *)
        let s = try XMap.find x env with Not_found -> raise (Unbound (range, x)) in
        (* Instantiating this type scheme yields a variable [v], which we unify with
           [w]. It also yields a list of witnesses, which we record, as they will be
           useful during the decoding phase. *)
        let witnesses, v = G.instantiate state s in
        unify range v (uvar w);
        On_sol (fun () -> List.map decode witnesses)
    | CDef (x, v, c) ->
        let env = XMap.add x (G.trivial (uvar v)) env in
        solve env range c
    | CLet (xvs, c1, c2) ->
        (* Warn the generalization engine that we are entering the left-hand
           side of a [let] construct. *)
        G.enter state;
        (* Register the variables [vs] with the generalization engine, just as if
           they were existentially bound in [c1]. This is what they are, basically,
           but they also serve as named entry points. *)
        let vs = List.map (fun (_, v) -> bind_uvar v None) xvs in
        (* Solve the constraint [c1]. *)
        let (On_sol r1) = solve env range c1 in
        (* Ask the generalization engine to perform an occurs check, to adjust the
           ranks of the type variables in the young generation (i.e., all of the
           type variables that were registered since the call to [G.enter] above),
           and to construct a list [ss] of type schemes for our entry points. The
           generalization engine also produces a list [generalizable] of the young
           variables that should be universally quantified here. *)
        let generalizable, ss = exit range ~rectypes state vs in
        (* Extend the environment [env] and compute the type schemes [xss] *)
        let env, xss =
          List.fold_right2 (fun (x, _) s (env, xss) ->
            (XMap.add x s env, (x,s)::xss)
          ) xvs ss (env, [])
        in
        (* Proceed to solve [c2] in the extended environment. *)
        let (On_sol r2) = solve env range c2 in
        On_sol (fun () ->
          List.map decode_variable generalizable,
          List.map (fun (x, s) -> (x, decode_scheme decode s)) xss,
          r1 (),
          r2 ())
  in
  (* Start from the empty environment and a dummy range, then force the
     suspended witness now that the whole constraint has been solved. *)
  let env = XMap.empty
  and range = Lexing.(dummy_pos, dummy_pos) in
  let (On_sol witness) = solve env range c in
  witness ()
end