Source file record.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
open Base
open Ppxlib
open Ast_builder.Default

(* Build the map handed to [Generic_map.build] as [~map]. For a recursive declaration it
   sends [type_name] to a reference to the [Naming.recurse] function. For a
   non-recursive declaration the map is empty, so no field is rewritten into a
   recursive call. *)
let create_type_name_map ~loc ~rec_flag ~type_name =
  match rec_flag with
  | Nonrecursive -> Map.empty (module String)
  | Recursive ->
    let recurse_fn = evar ~loc Naming.recurse in
    Map.singleton (module String) type_name recurse_fn
;;

(* Take a [(result, expr)] pair from [Generic_map.build] and return [expr]. If the
   expression was [Replaced], set [any_recursive] to [true]. The caller only marks the
   generated function [let rec] when such a recursive call was actually produced,
   because a type can be flagged recursive (the default) without really referring to
   itself. *)
let set_any_recursive_and_return_expr ~any_recursive (result, expr) =
  (match (result : Generic_map.replace_result) with
   | Replaced -> any_recursive := true
   | Unchanged -> ());
  expr
;;

(* Field-set cheat sheet, written from the point of view of the "to target" conversion
   (create_ast_structure_items also calls this function with source and target swapped):
   - source = the fields of the record that [@@deriving stable_record] is attached to
   - target = source + add - remove
   - add:[ a ] = the value comes from the ~a argument
   - set:[ s ] = the value comes from the ~s argument
   - modify:[ m ] = the value is the ~modify_m argument applied to the source value of m

   Consequently
     fields_from_args = set + (target - source) = set + (add - remove)
   is the set of fields whose values come from ~a and ~s arguments.

   The generated expression takes [_t] first, then the ~modify_* arguments, then one
   argument per field in fields_from_args. If any field conversion produced a recursive
   call, the body is wrapped in a [let rec] named [Naming.recurse]. *)
let convert_record
  ~loc
  ~fields
  ~source_fields
  ~target_fields
  ~modified_fields
  ~set_fields
  ~source_type
  ~target_type
  ~rec_flag
  ~type_name
  =
  let lident name = Ast_helpers.mk_lident ~loc name in
  (* A source field is bound only when its source value is reused in the target record.
     Everything else is matched with [_]. *)
  let bind_or_ignore name =
    if Set.mem target_fields name && not (Set.mem set_fields name)
    then ppat_var ~loc (Located.mk ~loc name)
    else ppat_any ~loc
  in
  let source_pattern =
    ppat_record
      ~loc
      (List.map (Set.to_list source_fields) ~f:(fun name -> lident name, bind_or_ignore name))
      Closed
  in
  let from_args = Set.union set_fields (Set.diff target_fields source_fields) in
  let decl_by_name =
    List.map fields ~f:(fun ld -> ld.pld_name.txt, ld) |> Map.of_alist_exn (module String)
  in
  let recursion_map = create_type_name_map ~loc ~rec_flag ~type_name in
  let made_recursive_call = ref false in
  (* Expression for one target field. Modification wins over argument-supplied values.
     Any other field is converted from its source value according to its declared type. *)
  let field_value name =
    match Set.mem modified_fields name, Set.mem from_args name with
    | true, _ ->
      pexp_apply
        ~loc
        (evar ~loc (Naming.modify_field name))
        [ Nolabel, evar ~loc name ]
    | false, true -> evar ~loc name
    | false, false ->
      let ld = Map.find_exn decl_by_name name in
      Generic_map.build ~loc ~map:recursion_map ld.pld_type (evar ~loc name)
      |> set_any_recursive_and_return_expr ~any_recursive:made_recursive_call
  in
  (* Must be built before [made_recursive_call] is read below: building it is what sets
     the flag. *)
  let target_expr =
    pexp_record
      ~loc
      (List.map (Set.to_list target_fields) ~f:(fun name -> lident name, field_value name))
      None
  in
  let body =
    if !made_recursive_call
    then
      [%expr
        let rec [%p pvar ~loc Naming.recurse] =
          fun ([%p source_pattern] : [%t source_type]) : [%t target_type] ->
          [%e target_expr]
        in
        [%e evar ~loc Naming.recurse] _t]
    else
      [%expr
        let ([%p source_pattern] : [%t source_type]) = _t in
        ([%e target_expr] : [%t target_type])]
  in
  let with_field_args =
    Set.fold from_args ~init:body ~f:(fun acc name -> Ast_helpers.mk_pexp_fun ~loc ~name acc)
  in
  let with_modify_args =
    Set.fold_right modified_fields ~init:with_field_args ~f:(fun field acc ->
      Ast_helpers.mk_pexp_fun ~loc ~name:(Naming.modify_field field) acc)
  in
  (* [_t] goes first to help record field disambiguation at the use site. *)
  [%expr fun (_t : [%t source_type]) -> [%e with_modify_args]]
;;

(* Generate the two structure items
     let <to target name> = ...
     let <of target name> = ...
   converting between the current record type and [target_type]. The "to" direction
   uses the current fields as source and [current + add - remove] as target. The "of"
   direction swaps the two.
   A missing [target_type] is reported with a located error. The [remove], [modify] and
   [set] field lists are passed to [Invariants.things_are_known] together with the
   current record's fields before any code is generated. *)
let create_ast_structure_items
  ~loc
  ~fields
  ~add
  ~remove
  ~modify
  ~set
  ~target_type
  ~current_type
  ~rec_flag
  ~type_name
  =
  let target_type =
    match target_type with
    | Some target_type -> target_type
    | None ->
      Location.raise_errorf
        ~loc
        "%s: missing target version"
        (Naming.ppx ~which_ppx:`Record)
  in
  let current_fields =
    List.map fields ~f:(fun ld -> ld.pld_name.txt) |> Set.of_list (module String)
  in
  List.iter
    [ Naming.removed, remove; Naming.modified, modify; Naming.set, set ]
    ~f:(fun (supposed_to_be, listed) ->
      Invariants.things_are_known
        ~thing_name:Naming.fields
        ~supposed_to_be
        ~loc
        ~all:current_fields
        listed);
  let target_version_fields = Set.diff (Set.union current_fields add) remove in
  let name_for dir =
    Naming.conversion_function ~dir ~source:type_name ~target:target_type
  in
  let to_target_name = name_for `To in
  let of_target_name = name_for `Of in
  (* Both directions share the same modify/set configuration; only the roles of the
     field sets and types change. *)
  let convert ~src_fields ~dst_fields ~src_type ~dst_type =
    convert_record
      ~loc
      ~fields
      ~source_fields:src_fields
      ~target_fields:dst_fields
      ~modified_fields:modify
      ~set_fields:set
      ~target_type:dst_type
      ~source_type:src_type
      ~rec_flag
      ~type_name
  in
  let to_target =
    convert
      ~src_fields:current_fields
      ~dst_fields:target_version_fields
      ~src_type:current_type
      ~dst_type:target_type
  in
  let of_target =
    convert
      ~src_fields:target_version_fields
      ~dst_fields:current_fields
      ~src_type:target_type
      ~dst_type:current_type
  in
  [ [%stri let [%p pvar ~loc to_target_name] = [%e to_target]]
  ; [%stri let [%p pvar ~loc of_target_name] = [%e of_target]]
  ]
;;