Source file indexing.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
open Base

(** An opaque identifier wrapping an integer. Symbols serve as iterators in [axis_index] and as
    static indices in [static_symbol]; fresh ones are created by [get_symbol]. *)
type symbol = Symbol of int [@@deriving compare, equal, sexp, hash, variants]

(* The integer the next [get_symbol] call will hand out; starts at 1 and only grows. *)
let unique_id = ref 1

(** [get_symbol ()] returns a symbol carrying the current value of [unique_id], then bumps the
    counter, so successive calls yield distinct symbols (unless [unique_id] is reset elsewhere). *)
let get_symbol () =
  let fresh = Symbol !unique_id in
  Int.incr unique_id;
  fresh

(* Re-exports [symbol] as [t] together with its derived comparison, equality, sexp and hash
   functions, so that it can be passed to [Comparator.Make] below. *)
module CompareSymbol = struct
  type t = symbol = Symbol of int [@@deriving compare, equal, sexp, hash]
end

(** [symbol] with a comparator witness, usable as the key module for Base maps and sets, e.g.
    [Map.M(Symbol).t] or [Set.of_array (module Symbol)]. *)
module Symbol = struct
  include CompareSymbol
  include Comparator.Make (CompareSymbol)
end

(** [symbol_ident (Symbol n)] renders the symbol as an identifier: the letter ["i"] followed by
    the decimal digits of [n]. *)
let symbol_ident (Symbol n) = String.concat [ "i"; Int.to_string n ]

(** Maps keyed by symbols. *)
type 'a environment = 'a Map.M(Symbol).t [@@deriving sexp]

(** The environment with no bindings. *)
let empty_env : 'a environment = Map.empty (module Symbol)

(** A symbol standing for an index whose value is read, at run time, from an [int ref] cell
    (see [lowered_bindings]). Only [static_symbol] takes part in the derived [compare], [equal]
    and [hash]; the mutable [static_range] is ignored by them. *)
type static_symbol = {
  static_symbol : symbol;
  mutable static_range : int option; [@compare.ignore] [@equal.ignore] [@hash.ignore]
      (* NOTE(review): presumably the number of values the index can take, when known — confirm
         against the code that sets it. *)
}
[@@deriving compare, equal, sexp, hash]

(** A stack of static symbols; [get_static_symbol] pushes new symbols onto the front. The type
    parameter gains one [int ->] per level: in [Bind (s, rest) : 'a bindings], the tail [rest]
    has type [(int -> 'a) bindings]. *)
type 'a bindings = Empty | Bind of static_symbol * (int -> 'a) bindings [@@deriving sexp_of]

(** [bound_symbols bs] lists the static symbols of [bs] from the innermost [Bind] to the
    outermost. This is the same order in which [lowered_bindings] returns its pairs. *)
let bound_symbols bs =
  (* Polymorphic recursion: each step descends from ['a bindings] to [(int -> 'a) bindings].
     Consing onto the accumulator while walking outer-to-inner leaves the innermost first. *)
  let rec collect : 'a. static_symbol list -> 'a bindings -> static_symbol list =
   fun acc -> function Empty -> acc | Bind (s, rest) -> collect (s :: acc) rest
  in
  collect [] bs

(** Helps lowering the bindings: a function of type ['r] together with cells from which its
    arguments are read when [apply] runs.
    - [Result f]: no further parameters; [f] is the function itself.
    - [Param_idx (cell, more)]: an [int] index argument read from [cell]. Each [Param_idx]
      corresponds to one [Bind] of a [bindings] value (see [lowered_bindings]), and ['idcs]
      gains one [int ->] per [Param_idx], mirroring the [bindings] type.
    - [Param_1 (cell, more)] / [Param_2 (cell, more)]: an argument of type ['p1] / ['p2] read
      from [cell]; the cell must hold [Some _] when [apply] runs.
    - [Param_2f (f, cell, more)]: like [Param_2], but the cell holds a value of a hidden
      (existential) type that is converted to ['p2] by [f] before being passed. *)
type ('r, 'idcs, 'p1, 'p2) variadic =
  | Result of 'r
  | Param_idx of int ref * (int -> 'r, int -> 'idcs, 'p1, 'p2) variadic
  | Param_1 of 'p1 option ref * ('p1 -> 'r, 'idcs, 'p1, 'p2) variadic
  | Param_2 of 'p2 option ref * ('p2 -> 'r, 'idcs, 'p1, 'p2) variadic
  | Param_2f :
      ('p2f -> 'p2) * 'p2f option ref * ('p2 -> 'r, 'idcs, 'p1, 'p2) variadic
      -> ('r, 'idcs, 'p1, 'p2) variadic

(** [bindings] instantiated at [unit -> unit]: the tail of a top-level [Bind] then has type
    [(int -> unit -> unit) bindings], the next one [(int -> int -> unit -> unit) bindings], and
    so on. *)
type unit_bindings = (unit -> unit) bindings [@@deriving sexp_of]

(** Static symbols paired with the [int ref] cells holding their index values, as produced by
    [lowered_bindings] and looked up with [find_exn]. *)
type lowered_bindings = (static_symbol, int ref) List.Assoc.t [@@deriving sexp_of]

(** [apply run_variadic] returns the function stored in the innermost [Result], applied to the
    values currently held in the parameter cells of [run_variadic]. The parameters are applied in
    reverse order to how they appear in the [run_variadic] list: the innermost parameter is
    passed first and the outermost one last.

    @raise Invalid_argument if a [Param_1], [Param_2] or [Param_2f] cell holds [None]. *)
let rec apply : 'r 'idcs 'p1 'p2. ('r, 'idcs, 'p1, 'p2) variadic -> 'r =
 fun (type r idcs p1 p2) (f : (r, idcs, p1, p2) variadic) ->
  match f with
  | Result rf -> rf
  | Param_idx (i, more) -> apply more !i
  | Param_1 ({ contents = Some p1 }, more) -> apply more p1
  | Param_2 ({ contents = Some p2 }, more) -> apply more p2
  (* [x] has the constructor's existential type; [pf] converts it to the ['p2] parameter. *)
  | Param_2f (pf, { contents = Some x }, more) -> apply more @@ pf x
  | Param_1 ({ contents = None }, _) -> invalid_arg "Indexing.apply: Param_1 missing value"
  | Param_2 ({ contents = None }, _) -> invalid_arg "Indexing.apply: Param_2 missing value"
  | Param_2f (_, { contents = None }, _) -> invalid_arg "Indexing.apply: Param_2f missing value"

(** [lowered_bindings bs vs] pairs each static symbol of [bs] with the [int ref] cell of the
    [Param_idx] at the matching position in [vs]. [Param_1], [Param_2] and [Param_2f] parameters
    are skipped. The result runs from the innermost binding to the outermost, the same order as
    [bound_symbols], because [apply] also passes parameters in reverse order.

    Fails with [assert false] when [bs] and the [Param_idx] parameters of [vs] differ in
    number. *)
let lowered_bindings bs vs =
  let rec pair_up : 'r 'idcs. 'idcs bindings * ('r, 'idcs, 'p1, 'p2) variadic -> lowered_bindings =
   fun (type r idcs) ((binds, params) : idcs bindings * (r, idcs, 'p1, 'p2) variadic) ->
    match (binds, params) with
    | Empty, Result _ -> []
    | Bind (sym, rest_binds), Param_idx (cell, rest_params) ->
        (sym, cell) :: pair_up (rest_binds, rest_params)
    (* Non-index parameters have no matching [Bind]: keep the bindings, drop the parameter. *)
    | _, Param_1 (_, rest_params) -> pair_up (binds, rest_params)
    | _, Param_2 (_, rest_params) -> pair_up (binds, rest_params)
    | _, Param_2f (_, _, rest_params) -> pair_up (binds, rest_params)
    | Empty, Param_idx _ -> assert false
    | Bind _, Result _ -> assert false
  in
  pair_up (bs, vs) |> List.rev

(** [find_exn bs sym] returns the index cell paired with [sym] in [bs]. Symbols are compared with
    [equal_static_symbol], so [static_range] plays no part in the lookup. Raises if [sym] is not
    present. *)
let find_exn (bs : lowered_bindings) sym =
  List.Assoc.find_exn bs ~equal:equal_static_symbol sym

(** [get_static_symbol ?static_range bindings] creates a static symbol from a fresh [symbol] and
    the given range. It returns that symbol together with [bindings] extended by it at the
    front. *)
let get_static_symbol ?static_range bindings =
  let static_symbol = get_symbol () in
  let fresh = { static_symbol; static_range } in
  (fresh, Bind (fresh, bindings))

(** [dims_to_string dims] renders the dimension sizes joined by ["x"], e.g. ["1x2x3"] for
    [[|1; 2; 3|]] (with axes laid out as batch, output, input, that is batch dim 1, output dim 2,
    input dim 3). With [~with_axis_numbers:true] each size is prefixed by its axis position and
    the separator becomes [" x "], e.g. ["0:1 x 1:2 x 2:3"]. Empty [dims] render as ["-"]. *)
let dims_to_string ?(with_axis_numbers = false) dims =
  match (Array.is_empty dims, with_axis_numbers) with
  | true, _ -> "-"
  | false, true ->
      Array.mapi dims ~f:(fun axis size -> Int.to_string axis ^ ":" ^ Int.to_string size)
      |> String.concat_array ~sep:" x "
  | false, false -> Array.map dims ~f:Int.to_string |> String.concat_array ~sep:"x"

(** An index into a single axis of an array: either a constant position or an iterator. *)
type axis_index =
  | Fixed_idx of int  (** The specific position along an axis. *)
  | Iterator of symbol
      (** The given member of the [product_space] corresponding to some [product_iterators]. *)
[@@deriving compare, equal, sexp, variants]

(** Maps from strings to optional symbols. *)
type str_osym_map = (string, symbol option, Base.String.comparator_witness) Base.Map.t

(** Serializes the map as a list of [(key, value)] pairs in increasing key order. *)
let sexp_of_str_osym_map (map : str_osym_map) =
  let sexp_of_entry = [%sexp_of: string * symbol option] in
  Sexp.List (List.map (Map.to_alist map) ~f:sexp_of_entry)

(** Debugging information carried by [projections]. [trace] accumulates [(step name, unique id)]
    entries, newest first (e.g. [identity_projections] prepends one). NOTE(review): [spec] and
    [derived_for] presumably describe the projection specification and the entity it was
    derived for — confirm with the producers of this record. *)
type projections_debug = { spec : string; derived_for : Sexp.t; trace : (string * int) list }
[@@deriving sexp]

(** [unique_debug_id ()] returns 1, 2, 3, ... on successive calls, drawn from a counter private
    to this closure; used to tag entries of [projections_debug.trace]. *)
let unique_debug_id =
  let counter = ref 0 in
  fun () ->
    counter := !counter + 1;
    !counter

type projections = {
  product_space : int array;
      (** The product space dimensions that an operation should parallelize (map-reduce) over. *)
  lhs_dims : int array;  (** The dimensions of the LHS array. *)
  rhs_dims : int array array;
      (** The dimensions of the RHS arrays, needed for deriving projections from other projections. *)
  product_iterators : symbol array;
      (** The product space iterators (concatenation of the relevant batch, output, input axes) for
          iterating over the [product_space] axes, where same axes are at same array indices. *)
  project_lhs : axis_index array;
      (** A projection that takes a [product_space]-bound index and produces an index into the
          result of an operation. *)
  project_rhs : axis_index array array;
      (** [project_rhs.(i)] produces an index into the [i+1]th argument of an operation. *)
  debug_info : (projections_debug[@sexp.ignore] [@compare.ignore] [@equal.ignore]);
      (** Provenance information for debugging; ignored by the derived [compare] and [equal]. *)
}
[@@deriving compare, equal, sexp]
(** All the information relevant for code generation: the iteration space of an operation and how
    it maps onto indices of the LHS and RHS arrays. *)

(** [iterated dim] holds when an axis of size [dim] needs an iterator, i.e. when [dim > 1]. *)
let iterated dim = dim > 1

(** [opt_symbol d] is a fresh symbol when [d] is [iterated], and [None] otherwise (no symbol is
    allocated in that case). *)
let opt_symbol d = if not (iterated d) then None else Some (get_symbol ())

(** [opt_iterator] turns [Some sym] into [Iterator sym] and [None] into [Fixed_idx 0]. *)
let opt_iterator = function Some sym -> Iterator sym | None -> Fixed_idx 0

(** [is_bijective proj] holds when the set of iterators used in [proj.project_lhs] equals the set
    of [proj.product_iterators]. Every product iterator must index the result, and the result must
    use no other iterator. Fixed indices and repeated uses of an iterator are not examined. *)
let is_bijective proj =
  let lhs_symbols =
    Array.fold proj.project_lhs ~init:(Set.empty (module Symbol)) ~f:(fun acc -> function
      | Iterator s -> Set.add acc s
      | Fixed_idx _ -> acc)
  in
  let product_symbols = Set.of_array (module Symbol) proj.product_iterators in
  Set.equal lhs_symbols product_symbols

(** Projections for a pointwise unary operator: the single RHS has the same dimensions and the
    same projection as the LHS. Each LHS axis of size greater than 1 gets a fresh iterator and
    contributes to the product space. Every other axis is indexed with [Fixed_idx 0].
    [debug_info.trace] is extended with an ["identity_projections"] entry. *)
let identity_projections ~debug_info ~lhs_dims =
  (* One optional iterator per LHS axis; [None] for axes that are not [iterated]. *)
  let opt_iterators = Array.map lhs_dims ~f:opt_symbol in
  let project_lhs = Array.map opt_iterators ~f:opt_iterator in
  let product_space = Array.filter ~f:iterated lhs_dims in
  let product_iterators = Array.filter_map ~f:Fn.id opt_iterators in
  {
    product_space;
    lhs_dims;
    rhs_dims = [| lhs_dims |];
    product_iterators;
    project_lhs;
    project_rhs = [| project_lhs |];
    debug_info =
      { debug_info with trace = ("identity_projections", unique_debug_id ()) :: debug_info.trace };
  }

(** [derive_index ~product_syms ~projection] precomputes how to build a projected index, and
    returns [fun ~product -> ...]. Each [Iterator s] of [projection] whose [s] occurs in
    [product_syms] is replaced by the element of [product] at [s]'s position in [product_syms].
    All other entries (fixed indices and iterators absent from [product_syms]) are copied
    unchanged. Raises (via [Map.of_alist_exn]) if [product_syms] contains a duplicate symbol. *)
let derive_index ~product_syms ~(projection : axis_index array) =
  let sym_to_i =
    Array.to_list product_syms
    |> List.mapi ~f:(fun i s -> (s, i))
    |> Map.of_alist_exn (module Symbol)
  in
  (* [First p]: take position [p] of [product]; [Second idx]: keep [idx] as is. *)
  let positions =
    Array.map projection ~f:(fun idx ->
        match idx with
        | Iterator s -> (
            match Map.find sym_to_i s with Some p -> Either.First p | None -> Either.Second idx)
        | Fixed_idx _ -> Either.Second idx)
  in
  fun ~product ->
    Array.map positions ~f:(function Either.First p -> product.(p) | Either.Second idx -> idx)