(* Source file: atd_inherit.ml *)

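(*
  Perform inheritance: replace the [inherit] clauses of record and sum
  types by the fields or variants of the inherited type.
*)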


open Printf

open Atd_ast

(* Sets of strings; used by [keep_last_defined] to remember which
   field/variant names have already been seen. *)
module S = Set.Make (String)


(* Builds the type table used by [expand]. It starts from the table
   returned by [Atd_predef.make_table], then registers each type
   definition of [l] under its name as
   [(number of type parameters, Some definition)].
   Bindings are added with [Hashtbl.add], so when a name is defined more
   than once, [Hashtbl.find] returns the last definition found in [l]. *)
let load_defs l =
  let tbl = Atd_predef.make_table () in
  let register ((_, (name, params, _), _) as def) =
    Hashtbl.add tbl name (List.length params, Some def)
  in
  List.iter register l;
  tbl

(* [keep_last_defined get_name l] removes duplicates from [l], two
   elements being duplicates when [get_name] returns the same string for
   both. For each name only the LAST occurrence in [l] is kept, and the
   surviving elements stay in their original relative order. [expand]
   relies on this so that a later field (resp. variant) overrides an
   earlier one of the same name, including an inherited one.

   The list is walked from its end with [List.fold_left] over
   [List.rev l], which is tail-recursive (unlike [List.fold_right]), so
   long lists cannot overflow the stack. *)
let keep_last_defined get_name l =
  let keep (seen, kept) x =
    let k = get_name x in
    if S.mem k seen then (seen, kept)
    else (S.add k seen, x :: kept)
  in
  snd (List.fold_left keep (S.empty, []) (List.rev l))

(* Returns the name of a record field. It must only be applied to
   fields whose [inherit] clauses have already been expanded (as done in
   [expand] before calling [keep_last_defined]); meeting an [`Inherit]
   here is a programming error. *)
let get_field_name : field -> string = function
    `Field (_, (k, _, _), _) -> k
  | `Inherit _ -> assert false

(* Returns the name of a sum-type variant. It must only be applied to
   variants whose [inherit] clauses have already been expanded (as done
   in [expand] before calling [keep_last_defined]); meeting an
   [`Inherit] here is a programming error. *)
let get_variant_name : variant -> string = function
    `Variant (_, (k, _), _) -> k
  | `Inherit _ -> assert false


(* [expand ?inherit_fields ?inherit_variants tbl t0] returns a copy of
   the type expression [t0] in which:
   - when [inherit_fields] is true (the default), each [inherit] clause
     of a record is replaced by the fields of the inherited record type,
     and only the last field of each name is kept;
   - when [inherit_variants] is true (the default), the same is done for
     the [inherit] clauses and variants of sum types;
   - the type names [list], [option], [nullable], [shared] and [wrap]
     applied to one argument are turned into their dedicated
     constructors.

   [tbl] maps a type name to [(arity, def_opt)], as built by
   [load_defs]; an inherited type must be bound to [Some def] there.
   When an inheritance flag is false, the inherited type is still
   resolved and checked, but the [inherit] clause itself is kept.

   @raise Failure if an inherited type is missing from [tbl], has no
   definition, is not a record (resp. sum) type, is applied to the
   wrong number of type arguments, or inherits from itself directly or
   indirectly. *)
let expand ?(inherit_fields = true) ?(inherit_variants = true) tbl t0 =

  (* [subst deref visiting param t] replaces the type variables of [t]
     using the association list [param]. If [deref] is true and [t] is a
     type name, the name is replaced by its substituted definition; this
     is how the target of an [inherit] clause is resolved.
     [visiting] holds the names of the types currently being expanded
     through inheritance; it is used to detect cycles, which would
     otherwise make the expansion recurse forever. *)
  let rec subst deref visiting param (t : type_expr) : type_expr =
    match t with
        `Sum (loc, vl, a) ->
          let vl =
            List.flatten (List.map (subst_variant visiting param) vl) in
          let vl =
            if inherit_variants then
              keep_last_defined get_variant_name vl
            else
              vl
          in
          `Sum (loc, vl, a)

      | `Record (loc, fl, a) ->
          let fl =
            List.flatten (List.map (subst_field visiting param) fl) in
          let fl =
            if inherit_fields then
              keep_last_defined get_field_name fl
            else
              fl
          in
          `Record (loc, fl, a)

      | `Tuple (loc, tl, a) ->
          `Tuple (
            loc,
            List.map
              (fun (loc, x, a) -> (loc, subst false visiting param x, a))
              tl,
            a
          )

      | `List (loc, t, a)
      | `Name (loc, (_, "list", [t]), a) ->
          `List (loc, subst false visiting param t, a)

      | `Option (loc, t, a)
      | `Name (loc, (_, "option", [t]), a) ->
          `Option (loc, subst false visiting param t, a)

      | `Nullable (loc, t, a)
      | `Name (loc, (_, "nullable", [t]), a) ->
          `Nullable (loc, subst false visiting param t, a)

      | `Shared (loc, t, a)
      | `Name (loc, (_, "shared", [t]), a) ->
          `Shared (loc, subst false visiting param t, a)

      | `Wrap (loc, t, a)
      | `Name (loc, (_, "wrap", [t]), a) ->
          `Wrap (loc, subst false visiting param t, a)

      | `Tvar (_, s) ->
          (* Variables without a binding in [param] are left as is. *)
          (try List.assoc s param
           with Not_found -> t)

      | `Name (loc, (loc2, k, args), a) ->
          let expanded_args = List.map (subst false visiting param) args in
          if deref then
            inherit_from visiting k expanded_args
          else
            `Name (loc, (loc2, k, expanded_args), a)

  (* Resolves the definition of type [k] applied to the already expanded
     type arguments [args], then expands its body with [deref] still set
     so that an alias of an inheritable type is followed as well. *)
  and inherit_from visiting k args =
    if List.mem k visiting then
      failwith ("Cyclic inheritance involving type " ^ k);
    let vars, body =
      try
        match Hashtbl.find tbl k with
            _, Some (_, (_, vars, _), body) -> vars, body
          | _, None -> failwith ("Cannot inherit from type " ^ k)
      with Not_found ->
        failwith ("Missing type definition for " ^ k)
    in
    let n_vars = List.length vars and n_args = List.length args in
    (* Checked explicitly: [List.combine] would only raise an
       uninformative [Invalid_argument "List.combine"]. *)
    if n_vars <> n_args then
      failwith (
        sprintf "Type %s expects %i type argument(s) but is given %i"
          k n_vars n_args
      );
    subst true (k :: visiting) (List.combine vars args) body

  (* Expands one record field. A plain field has its type substituted;
     an [inherit] clause is resolved to the inherited record, whose
     fields replace it when [inherit_fields] is true. *)
  and subst_field visiting param = function
      `Field (loc, k, t) -> [ `Field (loc, k, subst false visiting param t) ]
    | `Inherit (_, t) as x ->
        (match subst true visiting param t with
             `Record (_, fl, _) ->
               if inherit_fields then fl
               else [ x ]
           | _ -> failwith "Not a record type"
        )

  (* Expands one variant. A variant's argument type, if any, is
     substituted; an [inherit] clause is resolved to the inherited sum
     type, whose variants replace it when [inherit_variants] is true. *)
  and subst_variant visiting param = function
      `Variant (loc, k, opt_t) as x ->
        (match opt_t with
             None -> [ x ]
           | Some t -> [ `Variant (loc, k, Some (subst false visiting param t)) ]
        )
    | `Inherit (_, t) as x ->
        (match subst true visiting param t with
             `Sum (_, vl, _) ->
               if inherit_variants then vl
               else [ x ]
           | _ -> failwith "Not a sum type"
        )
  in
  subst false [] [] t0


(* Expands inheritance in every type definition of a module body.
   All definitions of [l] are first loaded into a single table (see
   [load_defs]), so a type may inherit from a type defined anywhere in
   the module. Then the right-hand side of each definition goes through
   [expand] with the same optional flags, and the definitions are
   returned in their original order, each wrapped back in [`Type]. *)
let expand_module_body
    ?inherit_fields
    ?inherit_variants
    (l : Atd_ast.module_body) =
  let defs = List.map (fun (`Type def) -> def) l in
  let tbl = load_defs defs in
  let expand_def (loc, name, body) =
    `Type (loc, name, expand ?inherit_fields ?inherit_variants tbl body)
  in
  List.map expand_def defs