Source file PUnif.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
(* This file is free software, part of Logtk. See file "license" for more details. *)

(** {1 Pragmatic variant of JP algorithm} *)

module U = Unif_subst
module T = Term
module H = HVar
module S = Subst
module P = PatternUnif
module Params = PragUnifParams
module I = Int32
module IntSet = CCSet.Make(CCInt)
module PUP = PragUnifParams

(* Ids of the fresh variables introduced by [elim_subsets_rule]. The oracle
   produces no bindings for a pair of terms whose common head variable has
   its id in this set. Emptied at the start of every [unify_scoped] call. *)
let elim_vars = ref IntSet.empty
(* Ids of the variables introduced by identification bindings in the oracle.
   [proj_imit_lr] generates no projections for a term whose head variable
   has its id in this set. Emptied at the start of every [unify_scoped]
   call. *)
let ident_vars = ref IntSet.empty


(* Kinds of bindings whose applications are counted inside a flag; every
   kind owns one counter field, described in [op_masks]. *)
type op =
  | ProjApp   (* projection onto an argument whose type expects arguments;
                 counted in [proj_lr] *)
  | ImitFlex  (* has a counter field, but nothing in the code visible here
                 increments it *)
  | ImitRigid (* imitation binding; counted in [proj_imit_lr] *)
  | Ident     (* identification binding; counted in the oracle *)
  | Elim      (* argument elimination; counted in [elim_rule] and in the
                 oracle *)

(* Infix/prefix shorthands for the [Int32] bit operations used to pack
   several small counters into one flag word. *)
let ( <<< ) word bits = I.shift_left word bits
let ( >>> ) word bits = I.shift_right_logical word bits
let ( &&& ) lhs rhs = I.logand lhs rhs
let ( ||| ) lhs rhs = I.logor lhs rhs
let ( ~~~ ) word = I.lognot word

(* 63 = 0b111111: the bit pattern covering one 6-bit counter field. *)
let i63 = I.of_int 63

(* Association list giving, for every operation, the mask selecting its
   6-bit counter inside a flag, the bit offset of that counter, and the
   label used by [pp_flag]. *)
let op_masks =
  (* a shift by 0 leaves [i63] unchanged, so the first field is [i63] itself *)
  let field offset label = (i63 <<< offset, offset, label) in
  [ ProjApp, field 0 "proj";
    ImitFlex, field 6 "imit_flex";
    ImitRigid, field 12 "imit_rigid";
    Ident, field 18 "ident";
    Elim, field 24 "elim" ]

(* [get_op flag op] reads, as an [int], the counter that [flag] stores for
   [op]. The field layout is looked up in [op_masks]. *)
let get_op flag op =
  let mask, offset, _label = List.assoc op op_masks in
  let field = flag &&& mask in
  I.to_int (field >>> offset)

(* [get_depth flag] is the sum of all five operation counters stored in
   [flag]. *)
let get_depth flag =
  [ProjApp; ImitFlex; ImitRigid; Ident; Elim]
  |> List.map (get_op flag)
  |> List.fold_left (+) 0

(* [inc_op flag op] returns [flag] with the counter of [op] incremented by
   one; all other counters are left untouched.
   The counter is 6 bits wide: incrementing a counter that already holds 63
   does not fit in the field and trips the assertion below. *)
let inc_op flag op =
  let mask, offset, _label = List.assoc op op_masks in
  let before = get_op flag op in
  (* extract the field, bump it, and move it back to its position *)
  let bumped_field = I.succ ((flag &&& mask) >>> offset) <<< offset in
  (* blank out the old field before merging the new one in *)
  let cleared = flag &&& (~~~ mask) in
  let updated = cleared ||| bumped_field in
  assert (before + 1 = get_op updated op);
  updated

(* [pp_flag out flag] prints every counter of [flag] on [out] as
   [|label:value], in the order of [op_masks]. *)
let pp_flag out flag =
  let print_counter (op, (_, _, label)) =
    CCFormat.fprintf out "|%s:%d" label (get_op flag op)
  in
  List.iter print_counter op_masks

(* Projection binding. Builds the term
     λ u1 ... um. u_i (H1 u1 ... um) ... (Hn u1 ... um)
   where [pref_types] are the types of u1 ... um, [i] is the (0-based)
   position of the projected binder, [type_ui] is its type τ1 -> ... -> τn -> τ,
   and every H_j is a fresh variable (drawn from [counter]) of type
   [pref_types -> τj]. *)
let project_hs_one ~counter pref_types i type_ui =
  let arity = List.length pref_types in
  (* de Bruijn index of the binder at position [pos], counted from the
     outermost lambda *)
  let db_index pos = arity - pos - 1 in
  let arg_tys_ui, _ = Type.open_fun type_ui in
  let bound_args =
    List.mapi (fun pos ty -> T.bvar ~ty (db_index pos)) pref_types in
  let fresh_head ty =
    T.var (H.fresh_cnt ~counter ~ty:(Type.arrow pref_types ty) ()) in
  (* one fresh head per argument expected by u_i, created left to right *)
  let heads = List.map fresh_head arg_tys_ui in
  let applied_heads = List.map (fun hd -> T.app hd bound_args) heads in
  let projected = T.bvar ~ty:type_ui (db_index i) in
  T.fun_l pref_types (T.app projected applied_heads)

(* Imitation binding: v |-> λ u1 ... um. f (H1 u1 ... um) ... (Hn u1 ... um)
   where f is a constant of type τ1 -> ... -> τn -> τ and the H_i are fresh
   variables of the appropriate types.
   Returns the first binding produced by [JP_unif.imitate_onesided].
   @raise Invalid_argument with message ["no_imits"] when [Not_found] is
   raised while fetching that first binding. *)
let imitate_one ~scope ~counter  s t =
  match OSeq.nth 0 (JP_unif.imitate_onesided ~scope ~counter s t) with
  | binding -> binding
  | exception Not_found -> invalid_arg "no_imits"

(* [proj_lr ~counter ~scope ~subst s t flag max_app_projs] computes the
   projection bindings for the head variable of [s] when unifying [s]
   against [t].

   [s] must be headed by a variable ([T.as_var_exn] is applied to its head).
   Every argument position [i] of the type of that variable is a candidate;
   a candidate is kept when
   - its type expects no arguments, or fewer than [max_app_projs]
     argument-applying projections are recorded in [flag];
   - the [i]-th actual argument of [s] (if present) does not clash with the
     head of [t]: both heads being distinct constants, or both being
     distinct bound variables, rules the position out;
   - the return type of the position unifies with the return type of the
     head variable (checked with [PatternUnif.unif_simple] on top of
     [subst]).

   The result is a list of pairs [(subst', flag')]: [subst'] is the result
   of the type unification extended with the projection binding, and
   [flag'] is [flag] with the [ProjApp] counter incremented when the
   projected position has a type that expects arguments. *)
let proj_lr ~counter ~scope ~subst s t flag max_app_projs = 
  let hd_s, args_s = CCPair.map_fst T.as_var_exn (T.as_app s) in
  let argss_arr = CCArray.of_list args_s in
  (* measured once; it used to be recomputed for every candidate *)
  let n_args_s = Array.length argss_arr in
  let hd_t,_ = T.as_app (snd (T.open_fun t)) in
  let pref_tys, hd_ret_ty = Type.open_fun (HVar.ty hd_s) in
  (* true when [ty] expects no arguments; avoids measuring the whole list *)
  let expects_no_args ty =
    match Type.expected_args ty with
    | [] -> true
    | _ :: _ -> false in
  pref_tys
  |> List.mapi (fun i ty -> i, ty)
  |> (fun l ->
      (* if we performed more than N projections that applied the
         bound variable we back off *)
      if get_op flag ProjApp < max_app_projs then l
      else List.filter (fun (_, ty) -> expects_no_args ty) l)
  (* If heads are different constants, do not project to those subterms *)
  |> CCList.filter_map (fun ((i, _) as p) -> 
      if i < n_args_s then (
        let arg_i = argss_arr.(i) in
        let s_i = snd (T.open_fun arg_i) in
        let hd_si = T.head_term s_i in
        if ((T.is_const hd_si && T.is_const hd_t && (not (T.equal hd_si hd_t))) 
            || (T.is_bvar arg_i && T.is_bvar hd_t && (not (T.equal arg_i hd_t)))) then None 
        else Some p
      ) else Some p
    )
  |> CCList.filter_map(fun (i, ty) ->
      let _, arg_ret_ty = Type.open_fun ty in
      match PatternUnif.unif_simple ~subst ~scope 
              (T.of_ty arg_ret_ty) (T.of_ty hd_ret_ty) with
      | Some subst' ->
        (* we project only to arguments of appropriate type *)
        let subst' = Unif_subst.subst subst' in
        let pr_bind = project_hs_one ~counter pref_tys i ty in
        (* only projections that apply the bound variable are counted *)
        let flag' = if expects_no_args ty then flag else inc_op flag ProjApp in
        Some (Subst.FO.bind' subst' (hd_s, scope) (pr_bind, scope), flag')
      | None -> None)

(* [proj_hs ~counter ~scope ~flex s] lists the projection substitutions that
   [proj_lr] computes for [flex] against [s], starting from the empty
   substitution and a zero flag, with no bound ([max_int]) on
   argument-applying projections. The flags returned by [proj_lr] are
   dropped. *)
let proj_hs ~counter ~scope ~flex s =
  proj_lr ~counter ~scope ~subst:Subst.empty flex s Int32.zero max_int
  |> CCList.map fst

(* [k_subset ~k l] enumerates, as a sequence, the sub-lists of [l] that
   contain exactly [k] elements. Each sub-list is accumulated by consing, so
   its elements appear in the reverse of their order in [l].
   Requires [k >= 0] (checked by an assertion). The sequence is empty when
   [k] exceeds the length of [l].

   The enumeration is produced on demand: a branch of the search is only
   explored when the consumer reaches it. *)
let k_subset ~k l =
  (* [remaining] is always the length of [rest]; threading it avoids
     re-measuring the list at every step *)
  let rec aux i acc rest remaining =
    if i = 0 then OSeq.return acc
    else if i > remaining then OSeq.empty
    else (
      match rest with
      | x :: xs ->
        (* Delayed behind a thunk: OCaml evaluates arguments eagerly, so
           calling [aux] directly here would walk the whole search tree
           before the first subset could be returned. *)
        (fun () ->
           OSeq.interleave
             (aux i acc xs (remaining - 1))
             (aux (i - 1) (x :: acc) xs (remaining - 1))
             ())
      | [] ->
        (* unreachable: here [0 < i <= remaining], hence [rest] is not empty *)
        assert false
    ) in

  assert (k >= 0);
  aux k [] l (List.length l)

(* [elim_subsets_rule ?max_elims ~elim_vars ~counter ~scope t u depth]
   handles two terms [t] and [u] headed by the same variable F (checked by
   assertions). It enumerates bindings of the form
     F |-> λ u1 ... um. G (chosen arguments)
   where G is a single new variable whose id is taken from [counter] (and
   recorded in [elim_vars]). G always receives the binders on which [t] and
   [u] carry equal arguments, plus a subset of size [k] of the remaining
   binders.

   [k] runs downwards from (number of differing binders - 1). The lower
   bound is 0 when [max_elims] is [None]; with [Some x] (where [x > 0] is
   asserted) it is (number of differing binders - x), clamped to 0.

   Each binding is paired with [depth] plus the number of binders that were
   dropped. *)
let elim_subsets_rule  ?(max_elims=None) ~elim_vars ~counter ~scope t u depth =
  let head_t = T.head_term t in
  let head_u = T.head_term u in
  let targs = Array.of_list (T.args t) in
  let uargs = Array.of_list (T.args u) in
  assert (T.is_var head_t);
  assert (T.is_var head_u);
  assert (T.equal head_t head_u);

  let head_var = T.as_var_exn head_t in
  (* reserve the id of the replacement variable and remember it *)
  let fresh_id = !counter in
  elim_vars := IntSet.add fresh_id !elim_vars;
  incr counter;

  let binder_tys, result_ty = Type.open_fun (T.ty head_t) in
  let arity = List.length binder_tys in

  (* both terms have an argument at [pos] and those arguments are equal *)
  let agree pos =
    pos < Array.length targs
    && pos < Array.length uargs
    && T.equal targs.(pos) uargs.(pos) in
  let kept, candidates =
    binder_tys
    |> List.mapi (fun pos ty ->
        let bv = T.bvar ~ty (arity - pos - 1) in
        if agree pos then `Left bv else `Right bv)
    |> CCList.partition_map CCFun.id in

  let n_candidates = List.length candidates in
  let lowest =
    match max_elims with
    | None -> 0
    | Some limit -> assert (limit > 0); n_candidates - limit in
  let highest = max (n_candidates - 1) 0 in

  (* binding that keeps exactly the [k] candidates listed in [chosen] *)
  let binding_for k chosen =
    assert (List.length chosen = k);
    let params = chosen @ kept in
    assert (List.length params <= arity);
    let ty = Type.arrow (List.map T.ty params) result_ty in
    let body = T.app (T.var (HVar.make ~ty fresh_id)) params in
    let value = T.fun_l binder_tys body in
    assert (T.DB.is_closed value);
    (Subst.FO.bind' Subst.empty (head_var, scope) (value, scope),
     depth + (n_candidates - k)) in

  CCList.range_by highest (max lowest 0) ~step:(-1)
  |> OSeq.of_list
  |> OSeq.flat_map (fun k ->
      OSeq.map (binding_for k) (k_subset ~k candidates))

(* [subset_elimination ~max_elims ~counter ~scope t u] runs
   [elim_subsets_rule] with the module-level [elim_vars] set and an initial
   depth of 0, so the second component of every result is the number of
   arguments the binding drops. *)
let subset_elimination ~max_elims ~counter ~scope t u =
  let bindings =
    elim_subsets_rule ~elim_vars ~max_elims ~counter ~scope t u 0 in
  OSeq.map (fun (sub, n_dropped) -> (sub, n_dropped)) bindings

(* Pragmatic unification procedure. All limits and switches (the [PUP.k_*]
   keys) are read from the state [St.st]. *)
module Make (St : sig val st : Flex_state.t end) = struct
  module PUP = PragUnifParams 
  module SU = SolidUnif.Make(St)

  (* Value stored in [St.st] under key [k], fetched with
     [Flex_state.get_exn]. *)
  let get_option k = Flex_state.get_exn k St.st 

  (* Currently a no-op: returns [res] unchanged and ignores its first
     argument. *)
  let delay _ res = res

  (* Create all possible projection and imitation bindings for the head
     variable of [s] against [t].

     The result is a sequence of [Some (subst, flag)] in this order:
     1. projections whose binding body (under its lambdas) is a bare bound
        variable;
     2. the imitation binding, if any;
     3. the remaining projections.

     No projection is produced when the head variable of [s] was introduced
     by an identification (its id is in [ident_vars]). The imitation is only
     attempted when [disable_imit] is false, the head of [t] is not a
     variable, and the [ImitRigid] counter of [flag] is below
     [PUP.k_max_rigid_imitations]; a failing [imitate_one] ("no_imits")
     simply yields no imitation.

     The sequence is empty when [Invalid_argument "as_var_exn"] is raised
     while computing the bindings (presumably because [s] is not headed by
     a variable). *)
  let proj_imit_lr ?(disable_imit=false) ~counter ~scope ~subst s t flag =
    try
      let simp_proj, func_proj = 
        let is_ident_last = 
          let hd_var_id = HVar.id (T.as_var_exn (T.head_term s)) in
          IntSet.mem hd_var_id !ident_vars in
        if is_ident_last then [],[]
        else (
          proj_lr ~counter ~scope ~subst s t flag (get_option PUP.k_max_app_projections)
          |> CCList.partition_map (fun ((sub,_) as r) -> 
              (* look at what the head of [s] is bound to *)
              let binding,_ = Subst.FO.deref sub (T.head_term s,scope) in
              let _,body = T.open_fun binding in
              if T.is_bvar body then `Left (Some r) else `Right (Some r))) in
      let imit_binding =
        try
          if not disable_imit && not (Term.is_var (T.head_term t)) &&
             get_op flag ImitRigid < get_option PUP.k_max_rigid_imitations then (
            let flag' = inc_op flag ImitRigid in
            [Some (U.subst @@ imitate_one ~scope ~counter s t, flag')])
          else []
        with Invalid_argument msg when String.equal msg "no_imits" -> [] in
      OSeq.append 
        (OSeq.append (OSeq.of_list simp_proj) (OSeq.of_list imit_binding) )
        ((OSeq.of_list func_proj))
    with Invalid_argument msg when String.equal msg "as_var_exn" ->
      OSeq.empty

  (* Single-argument elimination. When [t] is a variable applied to n > 0
     arguments, returns n bindings, one per position k in 0 .. n-1:
       v |-> λ u1 ... um. H (all binders except the k-th)
     with H a fresh variable. Each binding comes with [flag] whose [Elim]
     counter is incremented. Returns the empty list otherwise. The fourth
     argument is ignored. *)
  let elim_rule ~counter ~scope t _ flag = 
    let eliminate_at_idx v k =  
      let prefix_types, return_type = Type.open_fun (HVar.ty v) in
      let m = List.length prefix_types in
      let bvars = List.mapi (fun i ty -> T.bvar ~ty (m-1-i)) prefix_types in
      let prefix_types' = CCList.remove_at_idx k prefix_types in
      let new_ty = Type.arrow prefix_types' return_type in
      let bvars' = CCList.remove_at_idx k bvars in
      let matrix_head = T.var (H.fresh_cnt ~counter ~ty:new_ty ()) in
      let matrix = T.app matrix_head bvars' in
      let subst_value = T.fun_l prefix_types matrix in
      let subst = S.FO.bind' Subst.empty (v, scope) (subst_value, scope) in
      subst in 

    let eliminate_one t = 
      let hd, args = T.as_app t in
      if T.is_var hd && List.length args > 0 then (
        (* [CCList.range] includes both bounds *)
        CCList.range 0 ((List.length args)-1) 
        |> List.map (eliminate_at_idx (T.as_var_exn hd)))
      else [] in
    eliminate_one t
    |> List.map (fun x -> Some (x, inc_op flag Elim))

  (* removes all arguments of an applied variable
     v |-> λ u1 ... um. x
     where x is a fresh variable of the return type of v
  *)
  let elim_trivial ~scope ~counter v =  
    let prefix_types, return_type = Type.open_fun (HVar.ty v) in
    let matrix_head = T.var (H.fresh_cnt ~counter ~ty:return_type ()) in
    let subst_value = T.fun_l prefix_types matrix_head in
    let subst = Subst.FO.bind' Subst.empty (v, scope) (subst_value, scope) in
    subst

  (* Trivial solution for two different flex heads: binds both [x] and [y]
     to constant functions returning the same fresh variable,
       x |-> λ u1 ... um. z      y |-> λ w1 ... wn. z
     The return types of [x] and [y] must be equal (asserted). *)
  let flex_flex_diff_trivial ~scope ~counter x y  =  
    let prefix_types_x, return_type_x = Type.open_fun (HVar.ty x) in
    let prefix_types_y, return_type_y = Type.open_fun (HVar.ty y) in
    assert(Type.equal return_type_x return_type_y);
    let matrix_head = T.var (H.fresh_cnt ~counter ~ty:return_type_x ()) in
    let subst_value_x = T.fun_l prefix_types_x matrix_head in
    let subst_value_y = T.fun_l prefix_types_y matrix_head in
    let subst = Subst.FO.bind' Subst.empty (x, scope) (subst_value_x, scope) in
    let subst = Subst.FO.bind' subst (y, scope) (subst_value_y, scope) in
    subst

  (* Wrapper around [U.FO.rename_to_new_scope] that returns the renaming as
     a plain [Subst.t] instead of a [Unif_subst.t]. *)
  let renamer ~counter t0s t1s = 
    let lhs,rhs, unifscope, us = U.FO.rename_to_new_scope ~counter t0s t1s in
    lhs,rhs,unifscope,U.subst us

  (* Decision procedures for fragments, in the order fixpoint, pattern,
     solid. Each one is included only when its option
     ([PUP.k_fixpoint_decider], [PUP.k_pattern_decider],
     [PUP.k_solid_decider]) is set. Every procedure takes two scoped terms
     and a substitution and returns a list of substitutions. *)
  let deciders ~counter () =
    let pattern = 
      if get_option PUP.k_pattern_decider then 
        [(fun s t sub -> [(U.subst @@ PatternUnif.unify_scoped ~subst:(U.of_subst sub) ~counter s t)])] 
      else [] in
    let solid = 
      if get_option PUP.k_solid_decider then 
        [(fun s t sub -> (List.map U.subst @@ SU.unify_scoped ~subst:(U.of_subst sub) ~counter s t))] 
      else [] in
    let fixpoint = 
      if get_option PUP.k_fixpoint_decider then 
        [(fun s t sub -> [(U.subst @@ FixpointUnif.unify_scoped ~subst:(U.of_subst sub) ~counter s t)])] 
      else [] in
    fixpoint @ pattern @ solid

  (* [`Flex x] when the head of [s] is the variable [x], [`Rigid]
     otherwise. *)
  let head_classifier s =
    match T.view @@ T.head_term s with 
    | T.Var x -> `Flex x
    | _ -> `Rigid

  (* [oracle ~counter ~scope ~subst (s,_) (t,_) flag] proposes the bindings
     to try on the pair [s], [t], as a sequence of [Some (subst, flag')].

     - If both terms are headed by the same variable and that variable was
       introduced by subset elimination (its id is in [elim_vars]), nothing
       is proposed.
     - If the depth recorded in [flag] has reached [PUP.k_max_depth],
       nothing is proposed.
     - Same flex head on both sides: subset eliminations while the [Elim]
       counter is below [PUP.k_max_elims] (the counter grows by the number
       of dropped arguments), otherwise the single trivial elimination.
     - Two different flex heads: projections in both directions (no
       imitation), followed by identifications while the [Ident] counter is
       below [PUP.k_max_identifications]; if neither exists, the trivial
       flex-flex binding.
     - One flex and one rigid head: projections and imitations in both
       directions.
     - Two rigid heads are not expected here; the pair is printed and the
       function fails with [assert false]. *)
  let oracle ~counter ~scope ~subst (s,_) (t,_) (flag:I.t) =
    let hd_s, hd_t = T.head_term s, T.head_term t in
    (* This test runs first so that a pair which is skipped does not pay for
       building bindings, which would also consume ids from [counter] and
       add an unused id to [elim_vars]. *)
    if T.is_var hd_s && T.is_var hd_t && T.equal hd_s hd_t &&
       IntSet.mem (HVar.id @@ T.as_var_exn hd_s) !elim_vars then
      OSeq.empty
    else (
      let depth = get_depth flag in
      if depth < get_option PUP.k_max_depth then (
        match head_classifier s, head_classifier t with
        | `Flex x, `Flex y when HVar.equal Type.equal x y ->
          let num_elims = get_op flag Elim in
          let remaining_elims = get_option PUP.k_max_elims - num_elims in
          if remaining_elims > 0 then (
            subset_elimination ~max_elims:(Some remaining_elims) ~counter ~scope s t
            |> OSeq.map (fun (sub, inc) ->
                (* bump the [Elim] counter once per dropped argument *)
                let flag' = CCList.fold_left (fun acc _ -> inc_op acc Elim) 
                    flag (CCList.replicate inc None) in
                Some (sub, flag')))
          else OSeq.return (Some (elim_trivial ~counter ~scope x, flag))
        | `Flex x, `Flex y ->
          (* all rules  *)
          let ident = 
            if get_op flag Ident < get_option PUP.k_max_identifications then (
              JP_unif.identify ~scope ~counter s t []
              |> OSeq.map (fun us ->
                  let subst = U.subst us  in
                  (* variable introduced by identification *)
                  let subs_t = T.of_term_unsafe @@ fst (snd (List.hd (Subst.to_list subst))) in
                  let new_var, _ = T.as_app (snd (T.open_fun subs_t)) in
                  let new_var_id = HVar.id (T.as_var_exn new_var) in
                  (* remembering that we introduced this var in identification *)
                  ident_vars := IntSet.add new_var_id !ident_vars;
                  Some (subst, inc_op flag Ident)))
            else OSeq.empty in
          let projs =
            OSeq.append
              (proj_imit_lr ~disable_imit:true ~scope ~counter ~subst s t flag)
              (proj_imit_lr ~disable_imit:true ~scope ~counter ~subst t s flag) in
          let projs_ident =
            if OSeq.is_empty projs && OSeq.is_empty ident then OSeq.empty
            else OSeq.append projs (delay depth ident) in
          if not (OSeq.is_empty projs_ident) then projs_ident
          else OSeq.return (Some (flex_flex_diff_trivial ~scope ~counter x y, flag))
        | `Flex _, `Rigid
        | `Rigid, `Flex _ ->
          OSeq.append
            (proj_imit_lr ~counter ~scope ~subst s t flag)
            (proj_imit_lr ~counter ~scope ~subst t s flag)
        | _ -> 
          CCFormat.printf "Did not disassemble properly: [%a]\n[%a]@." T.pp s T.pp t;
          assert false)
      else OSeq.empty)

  (* Entry point: unifies two scoped terms and returns a sequence of
     optional unifiers ([Unif_subst.t option]).

     The procedure is an instance of [UnifFramework.Make], configured with
     the oracle, deciders and renamer above and with the oracle composer
     stored in the state. The variable counter is created once, when this
     value is defined, and is shared by all calls. [elim_vars] and
     [ident_vars] are emptied at the start of every call. *)
  let unify_scoped =  
    let counter = ref 0 in

    let module PragUnifParams = struct
      exception NotInFragment = PatternUnif.NotInFragment
      exception NotUnifiable = PatternUnif.NotUnifiable
      type flag_type = int32
      let flex_state = St.st
      let init_flag = (Int32.zero:flag_type)
      let identify_scope = renamer ~counter
      let frag_algs = deciders ~counter (*[]*)
      let pb_oracle s t (f:flag_type) subst scope = 
        oracle ~counter ~scope ~subst s t f
      let oracle_composer = Flex_state.get_exn PUP.k_oracle_composer St.st
    end in

    let module PragUnif = UnifFramework.Make(PragUnifParams) in
    (fun x y ->
       elim_vars := IntSet.empty;
       ident_vars := IntSet.empty;
       let res = PragUnif.unify_scoped x y in
       OSeq.map (CCOpt.map Unif_subst.of_subst) (res))
end