1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
open! Stdlib
(** Result type used throughout this file. It has the same constructors as
    [Stdlib.result], but because it is declared without a type equation it is
    a distinct type; [Ok] and [Error] below refer to these constructors.
    NOTE(review): if these values are meant to interoperate with
    [Stdlib.result], the declaration would need an equation to it - confirm. *)
type ('a, 'b) result =
| Ok of 'a
| Error of 'b
(** An encoding alphabet, as built by [make_alphabet].
    [emap] maps a 6-bit value (its index) to the code of the character that
    encodes it. [dmap] has 256 entries mapping a byte to its 6-bit value, or to
    [none] when the byte is not part of the alphabet. *)
type alphabet =
{ emap : int array
; dmap : int array
}
(** A (buffer, offset, length) triple, the shape of the values returned by
    [encode_sub] and [decode_sub]. *)
type sub = string * int * int
(** [x // y] is [x / y] rounded up when [x > 0] (and [y > 0]); it is [0]
    whenever [x <= 0]. Used to count started groups of input. *)
let ( // ) x y = if x <= 0 then 0 else ((x - 1) / y) + 1
(* Unchecked single-byte accessors that read/write the byte as an int.
   [unsafe_set_uint8] still raises [Invalid_argument] through [Char.chr] when
   the value is outside 0..255. *)
let unsafe_get_uint8 buf pos = String.unsafe_get buf pos |> Char.code
let unsafe_set_uint8 buf pos byte = Char.chr byte |> Bytes.unsafe_set buf pos
(** Unchecked 16-bit store into [bytes], in native byte order (no bounds check). *)
external unsafe_set_uint16 : bytes -> int -> int -> unit = "%caml_bytes_set16u"
[@@noalloc]
(** Unchecked 16-bit load from a [string], in native byte order (no bounds check). *)
external unsafe_get_uint16 : string -> int -> int = "%caml_string_get16u" [@@noalloc]
(** Swaps the two low bytes of a 16-bit value. *)
external swap16 : int -> int = "%bswap16" [@@noalloc]
(* Marker stored in [dmap] for bytes that are not part of the alphabet. *)
let none = ~-1
(** [make_alphabet s] builds an alphabet from the 64 characters of [s]; the
    character at position [i] encodes the 6-bit value [i].
    @raise Invalid_argument if [s] is not exactly 64 characters long, if it
    contains the padding character '=', or if it contains the same character
    twice (a repeated character maps two 6-bit values to one byte, so encoded
    data could not be decoded back). *)
let make_alphabet alphabet =
  if String.length alphabet <> 64 then invalid_arg "Length of alphabet must be 64";
  if String.contains alphabet '='
  then invalid_arg "Alphabet can not contain padding character";
  let emap = Array.init (String.length alphabet) ~f:(fun i -> Char.code alphabet.[i]) in
  let dmap = Array.make 256 none in
  String.iteri
    ~f:(fun idx chr ->
      let code = Char.code chr in
      (* Every [dmap] entry starts as [none]; a second write for the same byte
         would overwrite the first index and break decoding. *)
      if dmap.(code) <> none
      then invalid_arg "Alphabet can not contain duplicate characters";
      dmap.(code) <- idx)
    alphabet;
  { emap; dmap }
(** Number of characters in the alphabet (64 for any alphabet built by
    [make_alphabet]). *)
let length_alphabet alpha = Array.length alpha.emap
(** [alphabet a] returns the character codes of [a], indexed by 6-bit value.
    A fresh copy is returned: the array backing [a] is mutable and is read by
    every encoding that uses [a] (e.g. [default_alphabet]), so handing it out
    directly would let a caller corrupt later encodings. *)
let alphabet { emap; _ } = Array.copy emap
(** Standard base64 alphabet (RFC 4648, section 4): [A-Z], [a-z], [0-9], then
    '+' and '/'. *)
let default_alphabet =
  make_alphabet
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"

(** URL- and filename-safe alphabet (RFC 4648, section 5): same as
    [default_alphabet] but with '-' and '_' in place of '+' and '/'. *)
let uri_safe_alphabet =
  make_alphabet
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
(** [unsafe_set_be_uint16 buf pos v] stores the 16-bit value [v] at [pos] in
    big-endian order, without bounds checks. The byte-order test runs once,
    when this value is defined, not on every call. *)
let unsafe_set_be_uint16 =
  if not Sys.big_endian
  then fun buf pos v -> unsafe_set_uint16 buf pos (swap16 v)
  else fun buf pos v -> unsafe_set_uint16 buf pos v
(** Raised by [get_uint8] for an offset outside the string. [decode_sub] also
    raises it internally (and reports it as "Wrong padding") when padding is
    required and a quantum runs past the end of the input. *)
exception Out_of_bounds
(** Bounds-checked byte read.
    @raise Out_of_bounds if [pos] is not a valid index of [buf]. *)
let get_uint8 buf pos =
  if pos >= 0 && pos < String.length buf
  then unsafe_get_uint8 buf pos
  else raise Out_of_bounds
(* Byte value of the padding character '='. *)
let padding = Char.code '='
(** [error_msgf fmt ...] formats a message as [Format.sprintf] would and wraps
    it as [Error (`Msg message)]. *)
let error_msgf fmt =
  let to_error message = Error (`Msg message) in
  Format.ksprintf to_error fmt
(** [encode_sub pad alphabet ?off ?len input] base64-encodes the [len] bytes of
    [input] starting at [off] ([off] defaults to 0 and [len] to the rest of
    [input]). This is the internal version, taking [pad] and the alphabet
    positionally; it is shadowed further down by a wrapper that makes them
    optional.

    On success it returns [Ok (buf, 0, l)]. [buf] holds 4 characters for every
    started 3-byte group of input, and [l] is the number of meaningful
    characters. With [pad = true] the trailing '=' are written and [l] is the
    whole buffer. With [pad = false], [l] excludes the padding positions.
    Returns [Error (`Msg "Invalid bounds")] when [off]/[len] do not describe a
    range inside [input]. *)
let encode_sub pad { emap; _ } ?(off = 0) ?len input =
let len =
match len with
| Some len -> len
| None -> String.length input - off
in
if len < 0 || off < 0 || off > String.length input - len
then error_msgf "Invalid bounds"
else
let n = len in
(* Every started group of 3 input bytes produces 4 output characters. *)
let n' = n // 3 * 4 in
let res = Bytes.create n' in
let emap i = Array.unsafe_get emap i in
(* [emit b1 b2 b3 i] writes the 4 characters encoding the 24 bits
   [b1 b2 b3] at offset [i] of [res], as two big-endian 16-bit stores of two
   characters each. Every index into [emap] is masked to 6 bits. *)
let emit b1 b2 b3 i =
unsafe_set_be_uint16
res
i
((emap ((b1 lsr 2) land 0x3f) lsl 8)
lor emap ((b1 lsl 4) lor (b2 lsr 4) land 0x3f));
unsafe_set_be_uint16
res
(i + 2)
((emap ((b2 lsl 2) lor (b3 lsr 6) land 0x3f) lsl 8) lor emap (b3 land 0x3f))
in
(* [enc j i] encodes input bytes from index [i] (relative to [off]) into [res]
   at index [j], 3 bytes to 4 characters at a time. A trailing group of 1 or 2
   bytes is encoded with zero bits in place of the missing bytes. The
   characters this produces are overwritten by [unsafe_fix] when padding is
   requested. Reads are unchecked; the bounds test above keeps [off + i]
   inside [input]. *)
let rec enc j i =
if i = n
then ()
else if i = n - 1
then emit (unsafe_get_uint8 input (off + i)) 0 0 j
else if i = n - 2
then
emit (unsafe_get_uint8 input (off + i)) (unsafe_get_uint8 input (off + i + 1)) 0 j
else (
emit
(unsafe_get_uint8 input (off + i))
(unsafe_get_uint8 input (off + i + 1))
(unsafe_get_uint8 input (off + i + 2))
j;
enc (j + 4) (i + 3))
in
(* [unsafe_fix k] overwrites the last [k] characters of [res] with '='. *)
let rec unsafe_fix = function
| 0 -> ()
| i ->
unsafe_set_uint8 res (n' - i) padding;
unsafe_fix (i - 1)
in
enc 0 0;
(* Number of padding characters: 2 if one byte is left over, 1 if two are,
   0 when [n] is a multiple of 3. *)
let pad_to_write = (3 - (n mod 3)) mod 3 in
if pad
then (
unsafe_fix pad_to_write;
Ok (Bytes.unsafe_to_string res, 0, n'))
else Ok (Bytes.unsafe_to_string res, 0, n' - pad_to_write)
(** [encode ?pad ?alphabet ?off ?len input] base64-encodes [len] bytes of
    [input] starting at [off] and returns them as a fresh string. [pad]
    (default [true]) controls whether the trailing '=' are emitted. [alphabet]
    defaults to [default_alphabet]. Returns [Error] for invalid bounds. *)
let encode ?(pad = true) ?(alphabet = default_alphabet) ?off ?len input =
  let encoded = encode_sub pad alphabet ?off ?len input in
  match encoded with
  | Error _ as err -> err
  | Ok (buf, pos, len) -> Ok (String.sub buf ~pos ~len)
(** [encode_string ?pad ?alphabet s] base64-encodes the whole of [s]. Encoding
    a whole string cannot fail, so no [result] is returned. *)
let encode_string ?pad ?alphabet input =
  match encode ?pad ?alphabet input with
  | Error _ ->
    (* With [off] and [len] left at their defaults the range is the whole
       string, which always passes the bounds check in [encode_sub]. *)
    assert false
  | Ok encoded -> encoded
(** Public [encode_sub]: like [encode], but returns the encoded buffer together
    with the offset and length of its meaningful part instead of copying it.
    [pad] defaults to [true] and [alphabet] to [default_alphabet]. *)
let encode_sub ?pad ?alphabet ?off ?len input =
  let pad = match pad with None -> true | Some pad -> pad in
  let alphabet = match alphabet with None -> default_alphabet | Some a -> a in
  encode_sub pad alphabet ?off ?len input
(** Like [encode], but raises instead of returning [Error].
    @raise Invalid_argument with the error message on invalid bounds. *)
let encode_exn ?pad ?alphabet ?off ?len input =
  match encode ?pad ?alphabet ?off ?len input with
  | Error (`Msg msg) -> invalid_arg msg
  | Ok encoded -> encoded
(** [decode_sub ?pad alphabet ?off ?len input] decodes the [len] characters of
    [input] starting at [off] ([off] defaults to 0 and [len] to the rest of
    [input]). This is the internal version, taking the alphabet positionally;
    it is shadowed further down by a wrapper that makes it optional.

    On success it returns [Ok (buf, 0, l)], where the first [l] bytes of [buf]
    are the decoded data. With [pad = true] (the default), reading a quantum
    that extends past [len] is an error. With [pad = false], characters
    missing at the end of the input are treated as '='.

    Errors:
    - "Invalid bounds" when [off]/[len] do not describe a range inside [input].
    - "Wrong padding" when [pad = true] and a quantum runs past [len].
    - "Malformed input" when a character is neither in the alphabet nor a '='
      in an allowed position.

    NOTE(review): once a quantum ends in 1-3 '=', the rest of the input is
    accepted as long as it consists only of '='. For example, with
    [pad = true], "YQ======" decodes to "a". RFC 4648 only allows padding at
    the end of the data; confirm this leniency is intended. *)
let decode_sub ?(pad = true) { dmap; _ } ?(off = 0) ?len input =
let len =
match len with
| Some len -> len
| None -> String.length input - off
in
if len < 0 || off < 0 || off > String.length input - len
then error_msgf "Invalid bounds"
else
(* Round the input length up to whole quanta of 4 characters; each quantum
   yields at most 3 bytes, so [res] may be trimmed afterwards. *)
let n = len // 4 * 4 in
let n' = n // 4 * 3 in
let res = Bytes.create n' in
(* Read input character [i] (relative to [off]). With [pad] set, a read at or
   past [len] raises [Out_of_bounds]. Without it, such reads return '=' so
   that missing padding is implied. *)
let get_uint8_or_padding =
if pad
then (fun t i ->
if i >= len then raise Out_of_bounds;
get_uint8 t (off + i))
else
fun t i ->
try if i < len then get_uint8 t (off + i) else padding
with Out_of_bounds -> padding
in
(* Bounds-checked big-endian 16-bit store. Depending on how much room is left
   in [t], it writes both bytes, only the high byte, or nothing. *)
let set_be_uint16 t off v =
if off < 0 || off + 1 > Bytes.length t
then ()
else if off < 0 || off + 2 > Bytes.length t
then unsafe_set_uint8 t off (v lsr 8)
else unsafe_set_be_uint16 t off v
in
(* Bounds-checked single-byte store; out-of-range writes are skipped. *)
let set_uint8 t off v =
if off < 0 || off >= Bytes.length t then () else unsafe_set_uint8 t off v
in
(* [emit a b c d j] packs four 6-bit values into 24 bits and writes them as 3
   bytes at offset [j] of [res]. *)
let emit a b c d j =
let x = (a lsl 18) lor (b lsl 12) lor (c lsl 6) lor d in
set_be_uint16 res j (x lsr 8);
set_uint8 res (j + 2) (x land 0xff)
in
(* Map a byte to its 6-bit value, raising [Not_found] for bytes outside the
   alphabet. '=' is never in an alphabet built by [make_alphabet]. *)
let dmap i =
let x = Array.unsafe_get dmap i in
if x = none then raise Not_found;
x
in
(* Called when a quantum that is not the last one contained padding. It checks
   that every remaining character (from [idx] up to [len]) is '=', raising
   [Not_found] otherwise. It returns [pad] plus 3 for each remaining, possibly
   partial, quantum: the number of bytes to drop from the end of [res].
   0x3d3d is two '=' characters in either byte order. *)
let only_padding pad idx =
let pad = ref (pad + 3) in
let idx = ref idx in
while !idx + 4 < len do
if unsafe_get_uint16 input (off + !idx) <> 0x3d3d
|| unsafe_get_uint16 input (off + !idx + 2) <> 0x3d3d
then raise Not_found;
idx := !idx + 4;
pad := !pad + 3
done;
while !idx < len do
if unsafe_get_uint8 input (off + !idx) <> padding then raise Not_found;
incr idx
done;
!pad
in
(* [dec j i] decodes the quantum at input index [i] into [res] at index [j].
   It returns the number of bytes to drop from the end of [res].
   The four characters are read last-to-first so that '=' is only accepted as
   a suffix: a padding character is accepted only when every character after
   it in the quantum was padding too. Otherwise the [Not_found] raised by
   [dmap] escapes. [pad] counts the padding characters seen so far (0 to 4). *)
let rec dec j i =
if i = n
then 0
else
let d, pad =
let x = get_uint8_or_padding input (i + 3) in
try dmap x, 0 with Not_found when x = padding -> 0, 1
in
let c, pad =
let x = get_uint8_or_padding input (i + 2) in
try dmap x, pad with Not_found when x = padding && pad = 1 -> 0, 2
in
let b, pad =
let x = get_uint8_or_padding input (i + 1) in
try dmap x, pad with Not_found when x = padding && pad = 2 -> 0, 3
in
let a, pad =
let x = get_uint8_or_padding input i in
try dmap x, pad with Not_found when x = padding && pad = 3 -> 0, 4
in
emit a b c d j;
(* Last quantum: drop as many of its 3 bytes as it had padding characters,
   counting a quantum of four '=' as 3. Otherwise, continue with the next
   quantum, or, once padding has been seen, require that only padding
   follows. *)
if i + 4 = n
then
match pad with
| 0 -> 0
| 4 -> 3
| pad -> pad
else
match pad with
| 0 -> dec (j + 3) (i + 4)
| 4 -> only_padding 3 (i + 4)
| pad -> only_padding pad (i + 4)
in
(* [Out_of_bounds] can only come from the [pad = true] reader. [Not_found]
   comes from [dmap] or [only_padding]. *)
match dec 0 0 with
| 0 -> Ok (Bytes.unsafe_to_string res, 0, n')
| pad -> Ok (Bytes.unsafe_to_string res, 0, n' - pad)
| exception Out_of_bounds ->
error_msgf "Wrong padding"
| exception Not_found ->
error_msgf "Malformed input"
(** [decode ?pad ?alphabet ?off ?len input] decodes base64 [input] and returns
    the decoded bytes as a fresh string. [pad] is passed through to
    [decode_sub]. [alphabet] defaults to [default_alphabet]. *)
let decode ?pad ?(alphabet = default_alphabet) ?off ?len input =
  let decoded = decode_sub ?pad alphabet ?off ?len input in
  match decoded with
  | Error _ as err -> err
  | Ok (buf, pos, len) -> Ok (String.sub buf ~pos ~len)
(** Public [decode_sub]: like [decode], but returns the decoded buffer together
    with the offset and length of its meaningful part instead of copying it.
    [alphabet] defaults to [default_alphabet]; [pad] is passed through. *)
let decode_sub ?pad ?alphabet ?off ?len input =
  let alphabet =
    match alphabet with
    | Some alphabet -> alphabet
    | None -> default_alphabet
  in
  decode_sub ?pad alphabet ?off ?len input
(** Like [decode], but raises instead of returning [Error].
    @raise Invalid_argument with the error message on invalid bounds or
    malformed input. *)
let decode_exn ?pad ?alphabet ?off ?len input =
  match decode ?pad ?alphabet ?off ?len input with
  | Error (`Msg msg) -> invalid_arg msg
  | Ok decoded -> decoded