(* Source file PatternUnif.ml *)
(** {1 Pattern unification algorithm implementation} *)
module T = Term
module US = Unif_subst
module H = HVar
(** Raised when the problem is in the supported fragment but has no
    unifier (e.g. the two sides have different types, rigid heads clash, or
    the flex-rigid construction fails). *)
exception NotUnifiable
(** Raised when the problem leaves the fragment this module handles, e.g. a
    side has a non-ground type or a flex head is applied to arguments that
    are not distinct bound variables. *)
exception NotInFragment
(** Substitutions produced and consumed by this module. *)
type subst = US.t
(** Helpers that apply the substitution wrapped in a {!US.t} using the empty
    renaming. *)
module S = struct
  (** [apply us x] applies the substitution inside [us] to [x] with
      [Subst.FO.apply] (presumably [x] is a scoped term). *)
  let apply us scoped_term =
    Subst.FO.apply Subst.Renaming.none (US.subst us) scoped_term

  (** [apply_ty us x] applies the substitution inside [us] to the scoped
      type [x] with [Subst.Ty.apply]. *)
  let apply_ty us scoped_ty =
    Subst.Ty.apply Subst.Renaming.none (US.subst us) scoped_ty

  (** Printer for unification substitutions (alias of [US.pp]). *)
  let pp = US.pp
end
(** [unif_simple ?subst ~scope t s] unifies [t] and [s], both taken in
    [scope], with the syntactic first-order unifier [Unif.FO.unify_syn],
    starting from [subst] (empty by default).

    @return [Some u], where [u] is the resulting substitution wrapped with
      [US.of_subst], or [None] when [Unif.Fail] is raised. *)
let unif_simple ?(subst=Subst.empty) ~scope t s =
  let lhs = (t, scope) and rhs = (s, scope) in
  try Some (US.of_subst (Unif.FO.unify_syn ~subst lhs rhs))
  with Unif.Fail -> None
(** [build_constraints args1 args2 rest] pairs [args1] and [args2]
    position-wise into new equations and merges them with [rest].

    The result is ordered as:
    - first, the pairs where both sides have a constant head;
    - then [rest];
    - then all other pairs.

    The lists are combined with [CCList.combine]. The callers in [unify]
    check that the lengths agree before calling. *)
let build_constraints args1 args2 rest =
  let both_const_headed (lhs, rhs) =
    T.is_const (T.head_term lhs) && T.is_const (T.head_term rhs)
  in
  let pairs = CCList.combine args1 args2 in
  let const_pairs, other_pairs = CCList.partition both_const_headed pairs in
  const_pairs @ rest @ other_pairs
(** [eta_expand_otf ~subst ~scope pref1 pref2 t1 t2] makes the bodies [t1]
    and [t2] of two lambda terms live under the same binders.

    [pref1] and [pref2] are the binder types of the two terms. When the
    prefixes have the same length, the inputs are returned unchanged.
    Otherwise, the body with the shorter prefix is eta-expanded:
    - its loose indices are shifted by the number [k] of missing binders;
    - it is then applied to the bound variables [k-1 … 0] for those binders.

    The types of the new bound variables are instantiated with [subst] in
    [scope].

    @return the (possibly expanded) bodies and the longer prefix. *)
let eta_expand_otf ~subst ~scope pref1 pref2 t1 t2 =
  (* Expand [body] over the binders of [types] beyond the first [already]. *)
  let expand already types body =
    let missing = CCList.drop already types in
    assert (not (CCList.is_empty missing));
    let k = List.length missing in
    let fresh_bvars =
      List.mapi
        (fun pos ty ->
           let ty =
             Subst.Ty.apply Subst.Renaming.none (US.subst subst) (ty, scope)
           in
           (* the outermost missing binder gets the largest index *)
           T.bvar ~ty (k - 1 - pos))
        missing
    in
    T.app (T.DB.shift k body) fresh_bvars
  in
  let len1 = List.length pref1 and len2 = List.length pref2 in
  if len1 = len2 then (t1, t2, pref1)
  else if len1 < len2 then (expand len1 pref2 t1, t2, pref2)
  else (t1, expand len2 pref1 t2, pref1)
(** Orders pairs by their first component only, using polymorphic [compare].
    Used to sort and binary-search the [(index, bvar)] arrays built by
    [get_bvars]. *)
let cmp x y = compare (fst x) (fst y)
(** [eligible_arg t] holds when [t] is built only from bound variables,
    lambdas and applications. Any constant, free variable or builtin makes
    it false. *)
let rec eligible_arg t =
  match T.view t with
  | T.DB _ -> true
  | T.Fun (_, body) -> eligible_arg body
  | T.App (hd, args) -> eligible_arg hd && List.for_all eligible_arg args
  | T.Const _ | T.Var _ | T.AppBuiltin _ -> false
(** [get_bvars args] checks that [args] is a list of pairwise-distinct bound
    variables.

    On success it returns [Some arr], where [arr] has one entry
    [(index, T.bvar (n-1-k))] per argument:
    - [index] is the argument's de Bruijn index;
    - [k] is its position in [args], and [n = List.length args];
    - the bvar therefore refers to position [k] under [n] fresh binders.

    [arr] is sorted by index with [cmp]. Returns [None] if some argument is
    not a bound variable or an index occurs twice. *)
let get_bvars args =
  if not (List.for_all T.is_bvar args) then None
  else begin
    let n = List.length args in
    let pairs =
      List.mapi
        (fun pos arg -> (T.as_bvar_exn arg, T.bvar ~ty:(T.ty arg) (n - 1 - pos)))
        args
    in
    (* [cmp] only looks at the index, so a repeated index shrinks the list *)
    let sorted = CCList.sort_uniq ~cmp pairs in
    if List.length sorted <> List.length pairs then None
    else Some (CCArray.of_list sorted)
  end
(** [norm t] weak-head normalises [t] with [Lambda.whnf] when its head is a
    lambda (a beta-redex). Otherwise [t] is returned unchanged. *)
let norm t =
  let head_is_lambda = T.is_fun (T.head_term t) in
  if not head_is_lambda then t else Lambda.whnf t
(** [norm_deref subst (t, sc)] dereferences and head-normalises the body of
    [t] under its lambda prefix, with respect to [subst] in scope [sc].

    - A variable body is replaced by its binding, repeatedly.
    - An application body has its head normalised recursively. The result is
      then reduced with [norm] until nothing changes.

    Returns [t] itself when the body is unchanged. Otherwise the new body is
    re-wrapped in the original prefix. *)
let rec norm_deref subst (t, sc) =
  let prefix, body = T.open_fun t in
  let reduce_body () =
    match T.view body with
    | T.Var _ ->
      let bound, _ = US.FO.deref subst (body, sc) in
      if T.equal body bound then bound else norm_deref subst (bound, sc)
    | T.App (hd, args) ->
      let hd' = norm_deref subst (hd, sc) in
      let rebuilt = if T.equal hd hd' then body else T.app hd' args in
      let reduced = norm rebuilt in
      if T.equal rebuilt reduced then rebuilt
      else norm_deref subst (reduced, sc)
    | _ -> body
  in
  let body' = reduce_body () in
  if T.equal body body' then t else T.fun_l prefix body'
(** [build_term ?depth ~all_args ~subst ~scope ~counter var bvar_map t]
    rebuilds [t], the rigid side of a flex-rigid equation, as the body of the
    solution for the flex head [var]. It returns the rebuilt term and the
    (possibly extended) substitution.

    - [bvar_map] is the sorted array built by [get_bvars] for the flex
      arguments. It maps the de Bruijn index of each flex argument to the
      bound variable that denotes that argument's position in the solution.
    - [depth] (default [0]) counts the binders entered inside [t]. Indices
      below [depth] are local to [t] and are kept as they are.
    - [all_args] is only passed down the recursive calls; it is not
      inspected here.

    An unbound flex subterm [G a1..an] whose arguments cannot all be
    rebuilt is pruned:
    - [G] is bound to [fun z1..zn. H z_k…] with [H] fresh, where [z_k…]
      are the kept positions;
    - the subterm is replaced by [H a'_k…].

    @raise NotInFragment in three cases:
      - [var] is the head of an application in [t];
      - [var] has a functional type and occurs bare in [t];
      - an unbound flex subterm's arguments are not distinct bound
        variables.
    @raise Failure in two cases:
      - [var] of non-functional type occurs in [t] (occurs check);
      - a loose bound variable of [t] is not one of the flex arguments. *)
let rec build_term ?(depth=0) ~all_args ~subst ~scope ~counter var bvar_map t =
(* [args_same new_args args] holds when every rebuilt argument is [Some]
   and equal to the original one. *)
let rec args_same ls1 ls2 =
match ls1, ls2 with
| ((Some x)::xs), (y::ys) ->
T.equal x y && args_same xs ys
| [], [] -> true
| _ -> false in
(* beta-reduce first if the head of [t] is a lambda *)
let t = norm t in
match T.view t with
| T.Var _ ->
(* follow the substitution; an unbound variable is kept as is, unless it
   is [var] itself (occurs check) *)
let t' = fst @@ US.FO.deref subst (t,scope) in
if T.equal t' t then (
if T.equal var t then (
if Type.is_fun (T.ty var) then raise NotInFragment
else raise (Failure "occurs check")
)
else (t', subst)
)
else build_term ~all_args ~subst ~scope ~counter ~depth var bvar_map t'
| T.Const _ -> (t, subst)
| T.App (hd, args) ->
if T.is_var hd then (
(* flex subterm [hd args] inside the rigid side *)
assert(not @@ CCList.is_empty args);
if T.equal hd var then
raise NotInFragment;
if not (US.FO.mem subst (Term.as_var_exn hd, scope)) then (
(* unbound flex head: its arguments must be distinct bound variables *)
if CCOpt.is_none (get_bvars args) then raise NotInFragment;
(* rebuild each argument; an argument that fails with [Failure] is
   recorded as [None] so that it can be pruned below *)
let new_args, subst =
List.fold_right (fun arg (l, subst) -> (
try (
let arg', subst = build_term ~all_args ~depth ~subst ~scope ~counter
var bvar_map arg in
Some arg' :: l, subst)
with Failure _ -> None :: l, subst))
args ([], subst) in
let pref_types = List.map Term.ty args in
let n = List.length pref_types in
let ret_type = Type.apply_unsafe (Term.ty hd)
((List.mapi (fun i x -> match x with
| Some t -> t
| None -> List.nth args i) new_args) :> InnerTerm.t list) in
(* bound variables (under [n] binders) for the positions that survive *)
let matrix =
CCList.filter_map CCFun.id (List.mapi (fun i opt_arg ->
(match opt_arg with
| Some arg -> Some (T.bvar ~ty:(Term.ty arg) (n-i-1))
| None -> None)) new_args) in
if List.length matrix != List.length args then (
(* some argument was pruned: hd := fun pref_types. new_hd matrix *)
let ty = Type.arrow (List.map Term.ty matrix) ret_type in
let new_hd = T.var @@ H.fresh_cnt ~counter ~ty () in
let hd_subs = T.fun_l pref_types (T.app new_hd matrix) in
let subst = US.FO.bind subst (T.as_var_exn hd, scope) (hd_subs, scope) in
let res_term = T.app new_hd (CCList.filter_map (fun x->x) new_args) in
res_term, subst
)
else (
(* nothing pruned: reuse [t] when no argument changed *)
if args_same new_args args then (t,subst)
else T.app hd (CCList.filter_map CCFun.id new_args), subst
)
)
else (
(* bound flex head: substitute it and retry; [norm] beta-reduces *)
let hd',_ = US.FO.deref subst (hd, scope) in
let t' = if T.equal hd hd' then t else T.app hd' args in
build_term ~all_args ~depth ~subst ~scope ~counter var bvar_map t'
)
) else (
(* rigid application: rebuild head and arguments, threading [subst] *)
let new_hd, subst = build_term ~all_args ~depth ~subst ~scope ~counter var bvar_map hd in
let new_args, subst =
List.fold_right (fun arg (l, subst) ->
let arg', subst = build_term ~all_args ~depth ~subst ~scope ~counter var bvar_map arg in
arg' :: l, subst
) args ([], subst) in
if T.equal new_hd hd && List.for_all2 T.equal args new_args then t,subst
else T.app new_hd new_args, subst
)
| T.Fun(ty, body) ->
(* one more local binder *)
let b', subst = build_term ~all_args ~depth:(depth+1) ~subst ~scope ~counter var bvar_map body in
if T.equal b' body then t,subst
else T.fun_ ty b', subst
| T.DB i ->
(* indices below [depth] are bound inside [t]; the others must be flex
   arguments, renamed to their solution position and shifted under the
   [depth] local binders *)
if i < depth then t,subst
else (
match CCArray.bsearch ~cmp (i-depth, Term.true_) bvar_map with
| `At idx ->
let val_,bvar = CCArray.get bvar_map idx in
assert(val_ = (i-depth));
T.DB.shift depth bvar, subst
| _ -> raise (Failure "Bound variable not argument to head")
)
| T.AppBuiltin(b,args) ->
let new_args, subst =
List.fold_right (fun arg (l, subst) ->
let arg', subst = build_term ~all_args ~depth ~subst ~scope ~counter var bvar_map arg in
arg' :: l, subst
) args ([], subst) in
if List.for_all2 T.equal args new_args then t,subst
else T.app_builtin ~ty:(Term.ty t) b new_args, subst
(** [unify ~scope ~counter ~subst problems] solves the equations in
    [problems] left to right. Each equation is a pair of terms, both
    interpreted in [scope]. The function extends [subst] and returns the
    final substitution.

    For each equation, both sides are first dereferenced and
    head-normalised ([norm_deref]). The body with the shorter lambda prefix
    is then eta-expanded ([eta_expand_otf]). The equation is dispatched on
    the two heads:
    - same flex head: [flex_same];
    - different flex heads: [flex_diff];
    - flex against a constant, bound variable or builtin (either
      orientation): [flex_rigid];
    - equal rigid heads with matching arity: decomposed with
      [build_constraints].

    @raise NotInFragment if a side has a non-ground type, or a flex
      subproblem is outside the pattern fragment.
    @raise NotUnifiable if the types differ or the rigid heads clash. *)
let rec unify ~scope ~counter ~subst = function
  | [] -> subst
  | (s, t) :: rest ->
    if not (Type.is_ground (T.ty s)) || not (Type.is_ground (T.ty t)) then
      raise NotInFragment;
    if not (Type.equal (T.ty s) (T.ty t)) then raise NotUnifiable;
    let s', t' = norm_deref subst (s, scope), norm_deref subst (t, scope) in
    if T.equal s' t' then unify ~subst ~counter ~scope rest
    else begin
      let pref_s, body_s = T.open_fun s' in
      let pref_t, body_t = T.open_fun t' in
      let body_s', body_t', pref_l =
        eta_expand_otf ~subst ~scope pref_s pref_t body_s body_t
      in
      let hd_s, args_s = T.as_app body_s' in
      let hd_t, args_t = T.as_app body_t' in
      match T.view hd_s, T.view hd_t with
      | T.Var _, T.Var _ ->
        let subst =
          if T.equal hd_s hd_t
          then flex_same ~counter ~scope ~subst hd_s args_s args_t
          else flex_diff ~counter ~scope ~subst hd_s hd_t args_s args_t
        in
        unify ~scope ~counter ~subst rest
      | T.Var _, (T.Const _ | T.DB _ | T.AppBuiltin _) ->
        let subst = flex_rigid ~pref_l ~subst ~counter ~scope body_s' body_t' in
        unify ~scope ~counter ~subst rest
      | (T.Const _ | T.DB _ | T.AppBuiltin _), T.Var _ ->
        let subst = flex_rigid ~pref_l ~subst ~counter ~scope body_t' body_s' in
        unify ~scope ~counter ~subst rest
      | T.Const f, T.Const g
        when ID.equal f g && List.length args_s = List.length args_t ->
        unify ~subst ~counter ~scope (build_constraints args_s args_t rest)
      | T.AppBuiltin (b_s, bargs_s), T.AppBuiltin (b_t, bargs_t)
        when Builtin.equal b_s b_t
             && List.length bargs_s + List.length args_s
                = List.length bargs_t + List.length args_t ->
        (* builtin arguments come before the applied arguments *)
        unify ~subst ~counter ~scope
          (build_constraints (bargs_s @ args_s) (bargs_t @ args_t) rest)
      | T.DB i, T.DB j when i = j && List.length args_s = List.length args_t ->
        unify ~subst ~counter ~scope (build_constraints args_s args_t rest)
      | _ -> raise NotUnifiable
    end

(** Flex-flex case with the same head [var], i.e. [F x1..xn = F y1..yn].

    Binds [F] to [fun z1..zn. G z_k…], where [G] is a fresh variable. The
    kept positions [k] are exactly those where [x_k] and [y_k] are the same
    bound variable.

    @raise NotInFragment if either argument list is not made of distinct
      bound variables. *)
and flex_same ~counter ~scope ~subst var args_s args_t =
  let bvar_s, bvar_t =
    match get_bvars args_s, get_bvars args_t with
    | Some b_s, Some b_t -> b_s, b_t
    | _ -> raise NotInFragment
  in
  assert (CCArray.length bvar_s = CCArray.length bvar_t);
  let v = T.as_var_exn var in
  let ret_ty = Type.apply_unsafe (T.ty var) (args_s :> InnerTerm.t list) in
  (* Both arrays are sorted by de Bruijn index. Each entry maps an
     argument's index to the bvar encoding its position. Position k agrees
     iff the same index is present in [bvar_t] with the same position bvar.
     The index must be looked up (binary search, as in [flex_diff]), not
     used as an array position: indices need not be 0..n-1, and doing so
     silently dropped agreeing positions, giving a less general unifier. *)
  let bvars =
    CCArray.to_list bvar_s
    |> CCList.filter_map (fun (i, s) ->
        match CCArray.bsearch ~cmp (i, T.true_) bvar_t with
        | `At idx when T.equal s (snd (CCArray.get bvar_t idx)) -> Some s
        | _ -> None)
  in
  let v_ty = Type.arrow (List.map T.ty bvars) ret_ty in
  let matrix = T.app (T.var (H.fresh_cnt ~counter ~ty:v_ty ())) bvars in
  let res_term = T.fun_l (List.map T.ty args_s) matrix in
  US.FO.bind subst (v, scope) (res_term, scope)

(** Flex-flex case with distinct heads, i.e. [F x1..xn = G y1..ym].

    - If both argument lists are empty, [F] is bound to [G].
    - Otherwise a fresh [H] is created, applied to the bound variables that
      occur in both argument lists. [F] and [G] are both bound to lambdas
      whose body is [H] applied to those variables' positions in [x̄] and
      [ȳ] respectively.

    @raise NotInFragment if either argument list is not made of distinct
      bound variables. *)
and flex_diff ~counter ~scope ~subst var_s var_t args_s args_t =
  if CCList.is_empty args_s && CCList.is_empty args_t then
    US.FO.bind subst (T.as_var_exn var_s, scope) (var_t, scope)
  else begin
    let bvar_s, bvar_t =
      match get_bvars args_s, get_bvars args_t with
      | Some b_s, Some b_t -> b_s, b_t
      | _ -> raise NotInFragment
    in
    (* pair the position bvars of every index occurring on both sides *)
    let new_bvars =
      CCArray.map
        (fun (i, pos_s) ->
           match CCArray.bsearch ~cmp (i, T.true_) bvar_t with
           | `At idx -> Some (pos_s, snd (CCArray.get bvar_t idx))
           | _ -> None)
        bvar_s
      |> CCArray.filter_map CCFun.id
      |> CCArray.to_list
    in
    let arg_types = List.map (fun (b1, _) -> T.ty b1) new_bvars in
    let ret_ty = Type.apply_unsafe (T.ty var_s) (args_s :> InnerTerm.t list) in
    let new_var_ty = Type.arrow arg_types ret_ty in
    let new_var = T.var @@ H.fresh_cnt ~counter ~ty:new_var_ty () in
    let matrix_s = T.app new_var (List.map fst new_bvars) in
    let matrix_t = T.app new_var (List.map snd new_bvars) in
    let subs_s = T.fun_l (List.map T.ty args_s) matrix_s in
    let subs_t = T.fun_l (List.map T.ty args_t) matrix_t in
    let v_s, v_t = T.as_var_exn var_s, T.as_var_exn var_t in
    let subst = US.FO.bind subst (v_s, scope) (subs_s, scope) in
    US.FO.bind subst (v_t, scope) (subs_t, scope)
  end

(** Flex-rigid case: [flex] is [F x1..xn] and [rigid] has a rigid head.

    Binds [F] to [fun z1..zn. M], where [M] is [rigid] rebuilt by
    [build_term]. [build_term] may also prune other flex subterms.
    [pref_l] is the common lambda prefix of the equation; it is only used
    to compute the [all_args] flag passed to [build_term].

    @raise NotInFragment if [x1..xn] are not distinct bound variables, or
      [build_term] reports the problem as out of fragment.
    @raise NotUnifiable if [build_term] fails with [Failure]: either the
      occurs check fails, or a bound variable of [rigid] is not among the
      flex arguments. *)
and flex_rigid ~pref_l ~subst ~counter ~scope flex rigid =
  let hd, args = T.as_app flex in
  assert (T.is_var hd);
  let bvars =
    match get_bvars args with
    | Some b -> b
    | None -> raise NotInFragment
  in
  let all_args = List.length pref_l = Array.length bvars in
  try
    let matrix, subst =
      build_term ~all_args ~subst ~scope ~counter hd bvars rigid
    in
    let new_subs_val = T.fun_l (List.map T.ty args) matrix in
    US.FO.bind subst (T.as_var_exn hd, scope) (new_subs_val, scope)
  with Failure _ -> raise NotUnifiable
(** [unify_scoped ?subst ?counter t0_s t1_s] computes a pattern unifier of
    the scoped terms [t0_s] and [t1_s].

    - If [subst] is empty (the default), both terms are first renamed into
      a common fresh scope with [US.FO.rename_to_new_scope]. Unification
      then starts from the substitution that call returns.
    - Otherwise both terms must already be in the same scope, and [subst]
      is extended.

    [counter] (default: a fresh [ref 0]) is passed to
    [US.FO.rename_to_new_scope] and to [unify].

    @raise Invalid_argument if [subst] is non-empty and the scopes differ.
    @raise NotUnifiable if the terms have no unifier.
    @raise NotInFragment if the problem is outside the supported fragment. *)
let unify_scoped ?(subst=US.empty) ?(counter = ref 0) t0_s t1_s =
  if US.is_empty subst then begin
    let t0', t1', scope, subst = US.FO.rename_to_new_scope ~counter t0_s t1_s in
    unify ~scope ~counter ~subst [ (t0', t1') ]
  end else begin
    let scope = Scoped.scope t0_s in
    (* structural comparison: physical [!=] only coincides with it on
       immediate values *)
    if scope <> Scoped.scope t1_s then invalid_arg "scopes should be the same";
    unify ~scope ~counter ~subst [ (fst t0_s, fst t1_s) ]
  end