Source file shape_expr.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
module Var = struct
  (** A symbolic dimension variable.

      [id] is the key under which the variable is looked up in evaluation
      bindings, [name] is a display label (possibly empty), and [min]/[max]
      are the inclusive bounds on the values it may take. *)
  type t = { id : int; name : string; min : int; max : int }

  (** [create ~id ~name ~min ~max] builds a variable ranging over
      [[min, max]].
      @raise Invalid_argument if [min < 0] (checked first) or if
      [min > max]. *)
  let create ~id ~name ~min:lower ~max:upper =
    if lower < 0 then
      invalid_arg "Shape_expr.Var.create: min must be non-negative"
    else if lower > upper then
      invalid_arg "Shape_expr.Var.create: min must be <= max"
    else { id; name; min = lower; max = upper }

  (* Field accessors. They are defined last so that [min]/[max] do not
     shadow [Stdlib.min]/[Stdlib.max] earlier in this module. *)
  let id { id; _ } = id
  let name { name; _ } = name
  let min { min; _ } = min
  let max { max; _ } = max
end
(** A symbolic integer expression describing a single shape dimension. *)
type expr =
| Const of int (** A literal integer. *)
| Var of Var.t (** A dimension variable, resolved by its id at evaluation. *)
| Add of expr * expr (** Sum of two sub-expressions. *)
| Mul of expr * expr (** Product of two sub-expressions. *)
| Neg of expr (** Arithmetic negation of a sub-expression. *)

(** A symbolic shape: one [expr] per dimension, where index [i] holds
    dimension [i]. *)
type shape = expr array
(* Smart constructors: thin wrappers over the [expr] constructors so that
   callers can build expressions with ordinary function application. *)
let const value = Const value
let var variable = Var variable
let add lhs rhs = Add (lhs, rhs)
let mul lhs rhs = Mul (lhs, rhs)
let neg operand = Neg operand

(** [of_int_array dims] lifts a concrete integer shape into a symbolic shape
    whose every dimension is a constant. *)
let of_int_array dims = Array.map (fun d -> Const d) dims
(** [to_string_expr e] renders [e] as a fully parenthesised infix string.
    Named variables print as [name#id]; unnamed ones (empty name) as [v<id>]. *)
let rec to_string_expr expr =
  match expr with
  | Const n -> string_of_int n
  | Var v -> (
      let id = Var.id v in
      match Var.name v with
      | "" -> Printf.sprintf "v%d" id
      | label -> Printf.sprintf "%s#%d" label id)
  | Add (a, b) ->
      let lhs = to_string_expr a and rhs = to_string_expr b in
      Printf.sprintf "(%s + %s)" lhs rhs
  | Mul (a, b) ->
      let lhs = to_string_expr a and rhs = to_string_expr b in
      Printf.sprintf "(%s * %s)" lhs rhs
  | Neg inner -> Printf.sprintf "(-%s)" (to_string_expr inner)
(** [to_string shape] renders each dimension with [to_string_expr] and joins
    them as ["[d0; d1; ...]"]. *)
let to_string shape =
  let parts = shape |> Array.map to_string_expr |> Array.to_list in
  Printf.sprintf "[%s]" (String.concat "; " parts)
(** [eval_expr bindings e] evaluates [e] to an integer. Each variable is
    looked up by its id in the association list [bindings]. Returns [None]
    when any variable reachable from [e] is unbound. *)
let rec eval_expr bindings expr =
  (* Apply [op] only when both operands resolved. *)
  let lift2 op lhs rhs =
    match (lhs, rhs) with Some x, Some y -> Some (op x y) | _ -> None
  in
  let recur = eval_expr bindings in
  match expr with
  | Const n -> Some n
  | Var v -> List.assoc_opt (Var.id v) bindings
  | Add (a, b) -> lift2 ( + ) (recur a) (recur b)
  | Mul (a, b) -> lift2 ( * ) (recur a) (recur b)
  | Neg inner -> Option.map ( ~- ) (recur inner)
(** [eval bindings shape] evaluates every dimension of [shape] with
    [eval_expr]; unresolved dimensions become [None]. *)
let eval bindings shape =
  let eval_dim = eval_expr bindings in
  Array.map eval_dim shape
(** [to_int_array_exn bindings shape] evaluates [shape] to concrete
    dimensions.
    @raise Invalid_argument naming the first (lowest-index) dimension that
    could not be resolved from [bindings]. *)
let to_int_array_exn bindings shape =
  let resolve dim = function
    | Some n -> n
    | None ->
        invalid_arg (Printf.sprintf "Shape_expr: dimension %d unresolved" dim)
  in
  eval bindings shape |> Array.mapi resolve
(** [map f shape] applies [f] to every dimension of [shape]. *)
let map f shape = shape |> Array.map f
(** [map2 f s1 s2] combines two shapes dimension-wise, applying [f] in
    index order.
    @raise Invalid_argument if the shapes have different ranks. *)
let map2 f s1 s2 =
  match Array.length s1 = Array.length s2 with
  | false -> invalid_arg "Shape_expr.map2: shape rank mismatch"
  | true -> Array.map2 f s1 s2
(** [fold f init shape] folds [f] over the dimensions of [shape] from left
    to right, starting from [init]. *)
let fold f init shape = Array.fold_left (fun acc dim -> f acc dim) init shape
(** [upper_bound_expr e] returns an integer [hi] such that [e] evaluates to
    at most [hi] whenever every variable [v] takes a value in
    [[Var.min v, Var.max v]].

    The bound is derived with interval arithmetic, so it remains sound when
    sub-expressions can be negative (under [Neg], or when a [Mul] factor is
    negative). For trees built only from non-negative constants, variables,
    [Add] and [Mul], it is the sum/product of the individual maxima. *)
let upper_bound_expr e =
  (* [interval e] returns [(lo, hi)] with [lo <= value <= hi] for every
     admissible assignment of the variables. *)
  let rec interval = function
    | Const n -> (n, n)
    | Var v -> (Var.min v, Var.max v)
    | Add (a, b) ->
        let lo_a, hi_a = interval a in
        let lo_b, hi_b = interval b in
        (lo_a + lo_b, hi_a + hi_b)
    | Mul (a, b) ->
        let lo_a, hi_a = interval a in
        let lo_b, hi_b = interval b in
        (* The extrema of x * y over a box are attained at its corners, so
           multiplying the maxima alone is wrong once a factor may be
           negative. *)
        let p1 = lo_a * lo_b in
        let p2 = lo_a * hi_b in
        let p3 = hi_a * lo_b in
        let p4 = hi_a * hi_b in
        (min (min p1 p2) (min p3 p4), max (max p1 p2) (max p3 p4))
    | Neg inner ->
        (* Negation swaps the roles of the bounds. *)
        let lo, hi = interval inner in
        (-hi, -lo)
  in
  snd (interval e)

(** [upper_bounds shape] applies [upper_bound_expr] to every dimension. *)
let upper_bounds shape = Array.map upper_bound_expr shape