Source file resp_server.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
open Lwt.Infix
module Value = Hiredis_value
open Unix
exception Invalid_encoding
module type BACKEND = sig
type t
type client
val new_client: t -> client
end
module type AUTH = sig
type t
val check: t -> string array -> bool
end
module type SERVER = sig
module Auth: AUTH
module Backend: BACKEND
type t
val ok: Value.t option Lwt.t
val error: string -> Value.t option Lwt.t
val invalid_arguments: unit -> Value.t option Lwt.t
type command =
Backend.t ->
Backend.client ->
string ->
Value.t array ->
Value.t option Lwt.t
val create:
?auth: Auth.t ->
?default: command ->
?commands: (string * command) list ->
?host: string ->
?tls_config: Conduit_lwt_unix.tls_server_key ->
Conduit_lwt_unix.server ->
Backend.t ->
t Lwt.t
val start:
?backlog: int ->
?timeout: int ->
?stop: unit Lwt.t ->
?on_exn: (exn -> unit) ->
t ->
unit Lwt.t
end
module Auth = struct
module String = struct
type t = string
let check auth args =
Array.length args > 0 && args.(0) = auth
end
module User = struct
type t = (string, string) Hashtbl.t
let check auth args =
if Array.length args < 2 then false
else
match Hashtbl.find_opt auth args.(0) with
| Some p -> p = args.(1)
| None -> false
end
end
let rec read_value ic =
let open Value in
Lwt_io.read_char ic >>= function
| ':' ->
Lwt_io.read_line ic >|= fun i ->
let i = Int64.of_string i in
Value.Integer i
| '-' -> Lwt_io.read_line ic >|= fun line -> Error line
| '+' -> Lwt_io.read_line ic >|= fun line -> Status line
| '*' ->
Lwt_io.read_line ic >>= fun i ->
let i = int_of_string i in
let arr = Array.make i Value.Nil in
if i < 0 then Lwt.return Nil
else
let rec aux n =
match n with
| 0 -> Lwt.return_unit
| n ->
read_value ic >>= fun x ->
arr.(i - n) <- x;
aux (n - 1)
in
aux i >|= fun () ->
Array arr
| '$' ->
Lwt_io.read_line ic >>= fun i ->
let i = int_of_string i in
if i < 0 then Lwt.return Nil
else
Lwt_io.read ~count:i ic >>= fun s ->
Lwt_io.read_char ic >>= fun _ ->
Lwt_io.read_char ic >|= fun _ ->
String s
| _ -> raise Invalid_encoding
let rec write_value oc x =
let open Value in
match x with
| Nil -> Lwt_io.write oc "*-1\r\n"
| Error e ->
Lwt_io.write oc "-" >>= fun () ->
Lwt_io.write oc e >>= fun () ->
Lwt_io.write oc "\r\n"
| Integer i ->
Lwt_io.write oc ":" >>= fun () ->
Lwt_io.write oc (Printf.sprintf "%Ld\r\n" i)
| String s ->
Lwt_io.write oc (Printf.sprintf "$%d\r\n" (String.length s)) >>= fun () ->
Lwt_io.write oc s >>= fun () ->
Lwt_io.write oc "\r\n"
| Array arr ->
Lwt_io.write oc (Printf.sprintf "*%d\r\n" (Array.length arr)) >>= fun () ->
let rec write_all arr i =
if i >= Array.length arr then Lwt.return_unit
else write_value oc arr.(i) >>= fun () -> write_all arr (i + 1)
in write_all arr 0
| Status s ->
Lwt_io.write oc "+" >>= fun () ->
Lwt_io.write oc s >>= fun () ->
Lwt_io.write oc "\r\n"
module Client = struct
open Conduit_lwt_unix
type t = flow * ic * oc
let connect ?(ctx = default_ctx) ?tls_config ?port s =
let client = match port with
| None -> `Unix_domain_socket (`File s)
| Some port ->
(match tls_config with
| Some cfg -> `TLS cfg
| None -> `TCP (`IP (Ipaddr.of_string_exn s), `Port port))
in
Conduit_lwt_unix.connect ~ctx client
let read (_, ic, _) = read_value ic
let write (_, _, oc) x = write_value oc x
let run (_, ic, oc) cmd = write_value oc (Value.Array (Array.map Value.string cmd)) >>= fun () -> read_value ic
let run_v (_, ic, oc) cmd = write_value oc (Value.Array cmd) >>= fun () -> read_value ic
end
module Make(A: AUTH)(B: BACKEND): SERVER with module Backend = B and module Auth = A = struct
module Auth = A
module Backend = B
type command = Backend.t -> Backend.client -> string -> Value.t array -> Value.t option Lwt.t
type t = {
s_ctx: Conduit_lwt_unix.ctx;
s_mode: Conduit_lwt_unix.server;
s_tls_config: Conduit_lwt_unix.tls_server_key option;
s_auth: Auth.t option;
s_cmd: (string, command) Hashtbl.t;
s_data: Backend.t;
s_default: command option;
}
let error msg =
Lwt.return_some (Value.error ("ERR " ^ msg))
let invalid_arguments () =
Lwt.return_some (Value.error "ERR invalid arguments")
let ok = Lwt.return_some (Value.status "OK")
type client = {
c_in: Lwt_io.input_channel;
c_out: Lwt_io.output_channel;
c_data: B.client;
}
let create ?auth ?default ?commands ?host:(host="127.0.0.1") ?tls_config mode data =
Conduit_lwt_unix.init ~src:host ?tls_server_key:tls_config () >|= fun ctx ->
let commands = match commands with
| Some s ->
let ht = Hashtbl.create (List.length s) in
List.iter (fun (k, v) -> Hashtbl.replace ht k v) s;
ht
| None -> Hashtbl.create 8 in
{
s_ctx = ctx;
s_mode = mode;
s_tls_config = tls_config;
s_auth = auth;
s_cmd = commands;
s_data = data;
s_default = default;
}
let check auth args =
match auth with
| Some a -> Auth.check a args
| None -> true
let split_command arr =
try
let cmd = Value.to_string arr.(0)
|> String.lowercase_ascii in
let args = Array.sub arr 1 (Array.length arr - 1) in
cmd, args
with
| Value.Invalid_value -> "", [||]
let determine_command srv cmd =
match Hashtbl.find_opt srv.s_cmd cmd with
| Some f -> Some f
| None ->
begin
match srv.s_default with
| Some f -> Some f
| None -> None
end
let rec aux srv authenticated client =
Lwt.catch (fun () -> read_value client.c_in >>= Lwt.return_some)
(function
| Invalid_encoding | End_of_file -> Lwt.return_none
| x -> raise x) >>= function
| Some (Array a) when Array.length a > 0 ->
let cmd, args = split_command a in
if authenticated then
when_authenticated srv client cmd args
else
when_not_authenticated srv client cmd args
| _ ->
write_value client.c_out (Error "NOCOMMAND Invalid Command")
and when_authenticated srv client cmd args =
match determine_command srv cmd with
| Some f ->
begin
f srv.s_data client.c_data cmd args >>= fun r ->
match r with
| Some res ->
write_value client.c_out res >>= fun () ->
aux srv true client
| None -> Lwt.return_unit
end
| None ->
(if cmd = "command" then
let commands = Hashtbl.fold (fun k v dst -> Value.string k :: dst) srv.s_cmd [] in
write_value client.c_out (Value.array (Array.of_list commands))
else
write_value client.c_out (Error "NOCOMMAND Invalid command")) >>= fun _ ->
aux srv true client
and when_not_authenticated srv client cmd args =
let args =
try
Array.map Value.to_string args
with Value.Invalid_value -> [||] in
if cmd = "auth" && check srv.s_auth args then
write_value client.c_out (Status "OK") >>= fun () ->
aux srv true client
else
write_value client.c_out (Error "NOAUTH Authentication Required") >>= fun _ ->
aux srv false client
let rec handle srv flow ic oc =
let client = {
c_in = ic;
c_out = oc;
c_data = B.new_client srv.s_data;
} in
Lwt.catch
(fun () -> aux srv (srv.s_auth = None) client)
(fun _ -> Lwt_unix.yield ())
let start ?backlog ?timeout ?stop ?on_exn srv =
Conduit_lwt_unix.serve ?backlog ?timeout ?stop ?on_exn
~ctx:srv.s_ctx ~mode:srv.s_mode (handle srv)
end