1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
exception Shape_mismatch of string
let cannot ~op ~what ~from ~to_ ?reason ?hint () =
let msg =
let base = Printf.sprintf "%s: cannot %s %s to %s" op what from to_ in
let with_reason =
match reason with
| None -> base
| Some r -> Printf.sprintf "%s (%s)" base r
in
match hint with
| None -> with_reason
| Some h -> Printf.sprintf "%s\nhint: %s" with_reason h
in
invalid_arg msg
let invalid ~op ~what ?reason ?hint () =
let msg =
let base = Printf.sprintf "%s: invalid %s" op what in
let with_reason =
match reason with
| None -> base
| Some r -> Printf.sprintf "%s (%s)" base r
in
match hint with
| None -> with_reason
| Some h -> Printf.sprintf "%s\nhint: %s" with_reason h
in
invalid_arg msg
let failed ~op ~what ?reason ?hint () =
let msg =
let base = Printf.sprintf "%s: %s" op what in
let with_reason =
match reason with
| None -> base
| Some r -> Printf.sprintf "%s (%s)" base r
in
match hint with
| None -> with_reason
| Some h -> Printf.sprintf "%s\nhint: %s" with_reason h
in
invalid_arg msg
let shape_to_string shape =
let elements = Array.map string_of_int shape |> Array.to_list in
Printf.sprintf "[%s]" (String.concat "," elements)
let shape_mismatch ~op ~expected ~actual ?hint () =
let expected_str = shape_to_string expected in
let actual_str = shape_to_string actual in
let expected_size = Array.fold_left ( * ) 1 expected in
let actual_size = Array.fold_left ( * ) 1 actual in
let reason =
if Array.length expected = Array.length actual then
let mismatches =
Array.mapi
(fun i (e, a) ->
if e <> a then Some (Printf.sprintf "dim %d: %d≠%d" i e a) else None)
(Array.combine expected actual)
|> Array.to_list
|> List.filter_map (fun x -> x)
in
String.concat ", " mismatches
else if expected_size <> actual_size then
Printf.sprintf "%d→%d elements" expected_size actual_size
else
Printf.sprintf "incompatible ranks %d and %d" (Array.length expected)
(Array.length actual)
in
cannot ~op ~what:"reshape" ~from:expected_str ~to_:actual_str ~reason ?hint ()
let broadcast_incompatible ~op ~shape1 ~shape2 ?hint () =
let shape1_str = shape_to_string shape1 in
let shape2_str = shape_to_string shape2 in
let ndim1 = Array.length shape1 in
let ndim2 = Array.length shape2 in
let max_ndim = max ndim1 ndim2 in
let mismatches = ref [] in
for i = 0 to max_ndim - 1 do
let dim1 = if i < ndim1 then shape1.(ndim1 - 1 - i) else 1 in
let dim2 = if i < ndim2 then shape2.(ndim2 - 1 - i) else 1 in
if dim1 <> dim2 && dim1 <> 1 && dim2 <> 1 then
mismatches :=
Printf.sprintf "dim %d: %d≠%d" (max_ndim - 1 - i) dim1 dim2
:: !mismatches
done;
let reason = String.concat ", " (List.rev !mismatches) in
let default_hint =
"broadcasting requires dimensions to be either equal or 1"
in
let hint = Option.value hint ~default:default_hint in
cannot ~op ~what:"broadcast" ~from:shape1_str ~to_:shape2_str ~reason ~hint ()
let dtype_mismatch ~op ~expected ~actual ?hint () =
let default_hint = Printf.sprintf "cast one array to %s" expected in
let hint = Option.value hint ~default:default_hint in
cannot ~op ~what:op ~from:expected ~to_:("with " ^ actual)
~reason:"dtype mismatch" ~hint ()
let axis_out_of_bounds ~op ~axis ~ndim ?hint () =
invalid ~op
~what:(Printf.sprintf "axis %d" axis)
~reason:(Printf.sprintf "out of bounds for %dD array" ndim)
?hint ()
let invalid_shape ~op ~shape ~reason ?hint () =
invalid ~op
~what:(Printf.sprintf "shape %s" (shape_to_string shape))
~reason ?hint ()
let empty_input ~op ~what = invalid ~op ~what ~reason:"cannot be empty" ()
let check_bounds ~op ~name ~value ?min ?max () =
let check_min =
match min with
| Some m when value < m -> Some (Printf.sprintf "%s=%d < %d" name value m)
| _ -> None
in
let check_max =
match max with
| Some m when value > m -> Some (Printf.sprintf "%s=%d > %d" name value m)
| _ -> None
in
match (check_min, check_max) with
| Some msg, _ | _, Some msg -> invalid ~op ~what:name ~reason:msg ()
| None, None -> ()
let check_axes ~op ~axes ~ndim =
Array.iter
(fun axis ->
if axis < -ndim || axis >= ndim then axis_out_of_bounds ~op ~axis ~ndim ())
axes
let multi_issue ~op issues =
let formatted_issues =
List.map
(fun (issue, detail) -> Printf.sprintf " - %s: %s" issue detail)
issues
in
let msg =
Printf.sprintf "%s: invalid configuration\n%s" op
(String.concat "\n" formatted_issues)
in
invalid_arg msg