1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
(** [make_alphabet alphabet] builds the pair of lookup tables for a Base32
    alphabet.

    The result is [(emap, dmap)]:
    - [emap] has 32 entries; [emap.(i)] is the character code of
      [alphabet.[i]] (5-bit value -> character code);
    - [dmap] has 256 entries; [dmap.(code)] is the index of the character
      with that code in [alphabet], or [-1] when the character is not part
      of the alphabet. If a character occurs several times, the index of
      its last occurrence is kept.

    @raise Invalid_argument if [alphabet] is not exactly 32 characters long
    or if it contains ['='], which is reserved for padding. *)
let make_alphabet alphabet =
  let size = String.length alphabet in
  if size <> 32 then invalid_arg "Length of alphabet must be 32";
  (match String.index_opt alphabet '=' with
   | Some _ -> invalid_arg "Alphabet can not contain padding character"
   | None -> ());
  let emap = Array.make size 0 in
  let dmap = Array.make 256 (-1) in
  (* Fill both directions in a single left-to-right pass. *)
  for i = 0 to size - 1 do
    let code = Char.code alphabet.[i] in
    emap.(i) <- code;
    dmap.(code) <- i
  done;
  (emap, dmap)
(** Encoding/decoding tables used by this module: the digits [0]-[9]
    followed by the letters [A]-[V] (the "extended hex" Base32 alphabet). *)
let alphabet = make_alphabet "0123456789ABCDEFGHIJKLMNOPQRSTUV"

(** Character used to fill the last output group up to 8 characters. *)
let pad_char = '='

(** Character code of [pad_char], for comparison against raw byte values. *)
let pad_int = Char.code pad_char
(** [encode ?pad str] returns the Base32 text for [str], using the encoding
    table of [alphabet].

    Every group of 5 input bytes becomes 8 output characters. A final group
    of 1, 2, 3 or 4 bytes produces 2, 4, 5 or 7 characters respectively; the
    bits that would come from the missing bytes are taken as zero.

    @param pad when [true] (the default) the output is completed with
    [pad_char] to a multiple of 8 characters; when [false] those trailing
    pad characters are left out. *)
let encode ?(pad = true) str =
  let table = fst alphabet in
  let src_len = String.length str in
  let dst_len = (src_len + 4) / 5 * 8 in
  (* The buffer starts out full of pad characters, so only the characters
     that carry data have to be written. *)
  let dst = Bytes.make dst_len pad_char in
  (* Input byte at [i]; reads as zero past the end of the input. *)
  let byte_at i = if i < src_len then Char.code (String.get str i) else 0 in
  (* Number of data characters produced by a group holding 0..5 bytes. *)
  let digits_for = [| 0; 2; 4; 5; 7; 8 |] in
  let group_count = dst_len / 8 in
  for g = 0 to group_count - 1 do
    let s = 5 * g and d = 8 * g in
    let b1 = byte_at s
    and b2 = byte_at (s + 1)
    and b3 = byte_at (s + 2)
    and b4 = byte_at (s + 3)
    and b5 = byte_at (s + 4) in
    (* [quintet k] is the k-th 5-bit slice of the 40-bit group b1..b5. *)
    let quintet = function
      | 0 -> b1 lsr 3
      | 1 -> ((b1 land 0x07) lsl 2) lor (b2 lsr 6)
      | 2 -> (b2 land 0x3E) lsr 1
      | 3 -> ((b2 land 0x01) lsl 4) lor (b3 lsr 4)
      | 4 -> ((b3 land 0x0F) lsl 1) lor (b4 lsr 7)
      | 5 -> (b4 land 0x7C) lsr 2
      | 6 -> ((b4 land 0x03) lsl 3) lor (b5 lsr 5)
      | _ -> b5 land 0x1F
    in
    let available = min 5 (src_len - s) in
    for k = 0 to digits_for.(available) - 1 do
      Bytes.set_uint8 dst (d + k) table.(quintet k)
    done
  done;
  (* Pad characters at the end of the output, by size of the last group. *)
  let padding =
    match src_len mod 5 with
    | 0 -> 0
    | 1 -> 6
    | 2 -> 4
    | 3 -> 3
    | _ -> 1
  in
  let encoded = Bytes.unsafe_to_string dst in
  if pad then encoded else String.sub encoded 0 (dst_len - padding)
(** [decode ?unpadded str] decodes the Base32 text [str] using the decoding
    table of [alphabet].

    The input is processed in groups of 8 characters, each full group giving
    5 bytes. The last group may end with 6, 4, 3 or 1 [pad_char] characters,
    in which case it gives 1, 2, 3 or 4 bytes.

    @param unpadded when [true], an input whose length is not a multiple of
    8 is first completed with [pad_char]; when [false] (the default) such an
    input is rejected.

    @return [Ok bytes] on success, or [Error (`Msg reason)] when
    - the length is not a multiple of 8 and [unpadded] is [false];
    - a character is not part of the alphabet;
    - a pad character appears before the last group;
    - the padding of the last group is not a solid run of [pad_char].

    Offsets in error messages refer to the (possibly re-padded) input. *)
let decode ?(unpadded = false) str =
  let ( let* ) = Result.bind in
  (* Bring the input to a whole number of 8-character groups. *)
  let* str =
    let lmod8 = String.length str mod 8 in
    if lmod8 > 0 then
      if unpadded then
        Ok (str ^ String.make (8 - lmod8) pad_char)
      else
        Error (`Msg "invalid input length (not divisible by 8)")
    else
      Ok str
  in
  let len = String.length str in
  (* Only read through this view; the string itself is never modified. *)
  let str = Bytes.unsafe_of_string str in
  let out_len = len / 8 * 5 in
  let out = Bytes.create out_len in
  (* o1..o5 rebuild the five output bytes from the eight 5-bit values. *)
  let o1 b1 b2 = (b1 lsl 3) + (b2 lsr 2)
  and o2 b2 b3 b4 = ((b2 land 0x03) lsl 6) + (b3 lsl 1) + (b4 lsr 4)
  and o3 b4 b5 = ((b4 land 0x0F) lsl 4) + (b5 lsr 1)
  and o4 b5 b6 b7 = ((b5 land 0x01) lsl 7) + (b6 lsl 2) + (b7 lsr 3)
  and o5 b7 b8 = ((b7 land 0x07) lsl 5) + b8
  in
  (* Map a character code to its 5-bit value; [off] is only used for the
     error message. *)
  let c ~off idx =
    let r = (snd alphabet).(idx) in
    if r = -1 then
      Error (`Msg ("bad encoding at " ^ string_of_int off))
    else
      Ok r
  in
  (* A padded group ends decoding, so it is only valid as the last group.
     Without this check, bytes of [out] belonging to later groups would be
     returned without ever having been written. *)
  let final_group_only ~s_off ~pad_off =
    if s_off + 8 = len then
      Ok ()
    else
      Error (`Msg ("unexpected pad character at " ^ string_of_int pad_off))
  in
  (* Decode one full (pad-free) group into [out] at [off]. *)
  let emit s_off v1 v2 v3 v4 v5 v6 v7 v8 off =
    let* b1 = c ~off:s_off v1 in
    let* b2 = c ~off:(s_off + 1) v2 in
    let* b3 = c ~off:(s_off + 2) v3 in
    let* b4 = c ~off:(s_off + 3) v4 in
    let* b5 = c ~off:(s_off + 4) v5 in
    let* b6 = c ~off:(s_off + 5) v6 in
    let* b7 = c ~off:(s_off + 6) v7 in
    let* b8 = c ~off:(s_off + 7) v8 in
    Bytes.set_uint8 out off (o1 b1 b2);
    Bytes.set_uint8 out (off + 1) (o2 b2 b3 b4);
    Bytes.set_uint8 out (off + 2) (o3 b4 b5);
    Bytes.set_uint8 out (off + 3) (o4 b5 b6 b7);
    Bytes.set_uint8 out (off + 4) (o5 b7 b8);
    Ok ()
  in
  (* Returns [(pad_bytes, to_remove)]: the number of pad characters expected
     at the end of the input and the number of unused bytes at the end of
     [out]. *)
  let rec dec s_off d_off =
    if s_off = len then
      Ok (0, 0)
    else begin
      let v1 = Bytes.get_uint8 str s_off
      and v2 = Bytes.get_uint8 str (s_off + 1)
      and v3 = Bytes.get_uint8 str (s_off + 2)
      and v4 = Bytes.get_uint8 str (s_off + 3)
      and v5 = Bytes.get_uint8 str (s_off + 4)
      and v6 = Bytes.get_uint8 str (s_off + 5)
      and v7 = Bytes.get_uint8 str (s_off + 6)
      and v8 = Bytes.get_uint8 str (s_off + 7)
      in
      if v3 = pad_int then begin
        (* 2 data characters -> 1 byte *)
        let* b1 = c ~off:s_off v1 in
        let* b2 = c ~off:(s_off + 1) v2 in
        let* () = final_group_only ~s_off ~pad_off:(s_off + 2) in
        Bytes.set_uint8 out d_off (o1 b1 b2);
        Ok (6, 4)
      end
      else if v5 = pad_int then begin
        (* 4 data characters -> 2 bytes *)
        let* b1 = c ~off:s_off v1 in
        let* b2 = c ~off:(s_off + 1) v2 in
        let* b3 = c ~off:(s_off + 2) v3 in
        let* b4 = c ~off:(s_off + 3) v4 in
        let* () = final_group_only ~s_off ~pad_off:(s_off + 4) in
        Bytes.set_uint8 out d_off (o1 b1 b2);
        Bytes.set_uint8 out (d_off + 1) (o2 b2 b3 b4);
        Ok (4, 3)
      end
      else if v6 = pad_int then begin
        (* 5 data characters -> 3 bytes *)
        let* b1 = c ~off:s_off v1 in
        let* b2 = c ~off:(s_off + 1) v2 in
        let* b3 = c ~off:(s_off + 2) v3 in
        let* b4 = c ~off:(s_off + 3) v4 in
        let* b5 = c ~off:(s_off + 4) v5 in
        let* () = final_group_only ~s_off ~pad_off:(s_off + 5) in
        Bytes.set_uint8 out d_off (o1 b1 b2);
        Bytes.set_uint8 out (d_off + 1) (o2 b2 b3 b4);
        Bytes.set_uint8 out (d_off + 2) (o3 b4 b5);
        Ok (3, 2)
      end
      else if v8 = pad_int then begin
        (* 7 data characters -> 4 bytes *)
        let* b1 = c ~off:s_off v1 in
        let* b2 = c ~off:(s_off + 1) v2 in
        let* b3 = c ~off:(s_off + 2) v3 in
        let* b4 = c ~off:(s_off + 3) v4 in
        let* b5 = c ~off:(s_off + 4) v5 in
        let* b6 = c ~off:(s_off + 5) v6 in
        let* b7 = c ~off:(s_off + 6) v7 in
        let* () = final_group_only ~s_off ~pad_off:(s_off + 7) in
        Bytes.set_uint8 out d_off (o1 b1 b2);
        Bytes.set_uint8 out (d_off + 1) (o2 b2 b3 b4);
        Bytes.set_uint8 out (d_off + 2) (o3 b4 b5);
        Bytes.set_uint8 out (d_off + 3) (o4 b5 b6 b7);
        Ok (1, 1)
      end
      else
        (* Plain [match] rather than [let*] so that the recursive call is a
           tail call and long inputs do not grow the stack. *)
        match emit s_off v1 v2 v3 v4 v5 v6 v7 v8 d_off with
        | Error e -> Error e
        | Ok () -> dec (s_off + 8) (d_off + 5)
    end
  in
  let* (pad_bytes, to_remove) = dec 0 0 in
  (* The last [pad_bytes] characters of the input must all be pad
     characters. *)
  let rec check_pad = function
    | 0 -> Ok ()
    | n ->
      if Bytes.get_uint8 str (len - n) = pad_int then
        check_pad (n - 1)
      else
        Error (`Msg ("expected pad character at " ^ string_of_int (len - n)))
  in
  let* () = check_pad pad_bytes in
  let out_str = Bytes.unsafe_to_string out in
  if to_remove > 0 then
    Ok (String.sub out_str 0 (out_len - to_remove))
  else
    Ok out_str