(* Source file resp_server.ml *)
(** Pluggable authentication check used by servers built with [Make]. *)
module type AUTH = sig
(** Credential store the server is configured with (see the [?auth]
    argument of [create] in [S]). *)
type t
(** [check auth args] decides whether [args] are acceptable credentials for
    [auth]. [Make] calls it with the string arguments that follow an
    ["auth"] command name sent by a client that is not yet authenticated. *)
val check : t -> string array -> bool
end
(** Transport abstraction the RESP server is built on. *)
module type SERVER = sig
(** Input side of a client connection; RESP values are read from it. *)
type ic
(** Output side of a client connection; RESP values are written to it. *)
type oc
(** Handle on the underlying server, passed to [run]. *)
type server
(** Application state; [Make] stores one value of this type and passes it to
    every command. *)
type data
(** [run server handler] is called by [Make.start] with a handler that takes
    the [(ic, oc)] pair of one client.
    NOTE(review): presumably [handler] is invoked once per accepted
    connection; confirm against the concrete implementation. *)
val run : server -> (ic * oc -> unit Lwt.t) -> unit Lwt.t
end
(** Ready-made credential checkers. Both only inspect the leading arguments;
    any extra trailing arguments are ignored. *)
module Auth = struct
  (** Single shared password: the first argument must equal the password. *)
  module String = struct
    type t = string

    (* No argument at all can never match. *)
    let check auth args =
      match args with
      | [||] -> false
      | _ -> args.(0) = auth
  end

  (** User table: the first argument is the user name, the second the
      password that must equal the one stored for that user. *)
  module User = struct
    type t = (string, string) Hashtbl.t

    (* Needs both a user name and a password; unknown users are rejected. *)
    let check auth args =
      Array.length args >= 2
      && ( match Hashtbl.find_opt auth args.(0) with
         | None -> false
         | Some expected -> expected = args.(1) )
  end
end
(** Signature of a RESP command server as produced by [Make]. *)
module type S = sig
include SERVER
(** RESP reader/writer operating on this server's channels. *)
module Value : Resp.S with type Reader.ic = ic and type Writer.oc = oc
(** Authentication scheme used for the ["auth"] command. *)
module Auth : AUTH
(** A connected client: its input and output channels. *)
type client = ic * oc
(** A command implementation. [Make] invokes it as [f data client name nargs],
    where [name] is the lowercased command name and [nargs] is the request
    array length minus one. The arguments have not been read from [client]
    when the command is called. *)
type command = data -> client -> string -> int -> unit Lwt.t
(** Server state built by [create]. *)
type t
(** [discard_n client n] reads [n] values from [client] and drops them; it
    does nothing when [n <= 0]. *)
val discard_n : client -> int -> unit Lwt.t
(** [finish client ~nargs used] discards the [nargs - used] values that
    remain. *)
val finish : client -> nargs:int -> int -> unit Lwt.t
(** Writes the string reply ["OK"] to the client. *)
val ok : client -> unit Lwt.t
(** [error client msg] writes an error reply whose text is ["ERR "] followed
    by [msg]. *)
val error : client -> string -> unit Lwt.t
(** Writes the error reply for ["Invalid arguments"]. *)
val invalid_arguments : client -> unit Lwt.t
(** [send client v] writes the RESP value [v] to the client. *)
val send : client -> Resp.t -> unit Lwt.t
(** [recv client] reads the next RESP value from the client. *)
val recv : client -> Resp.t Lwt.t
(** [create ?auth ?commands ?default server data] builds a server.
    In [Make], command names are lowercased when registered and a later
    duplicate replaces an earlier one; [default] (["default"] when omitted)
    is the name of the command looked up when the requested one is not
    registered. *)
val create :
?auth:Auth.t
-> ?commands:(string * command) list
-> ?default:string
-> server
-> data
-> t
(** [start t] runs the server. In [Make], clients start out authenticated
    only when no [?auth] was given to [create]. *)
val start : t -> unit Lwt.t
end
(** [Make (Server) (Auth) (Value)] builds a RESP command server on top of the
    transport [Server], using [Auth] to validate ["auth"] requests and
    [Value] to read and write RESP values on the client channels. *)
module Make
    (Server : SERVER)
    (Auth : AUTH)
    (Value : Resp.S
               with type Reader.ic = Server.ic
                and type Writer.oc = Server.oc) =
struct
  include Server
  module Value = Value
  module Auth = Auth

  let ( >>= ) = Lwt.( >>= )

  (** A connected client: its input and output channels. *)
  type client = ic * oc

  (** [f data client name nargs]: [name] is the lowercased command name and
      [nargs] the number of arguments still unread on [client]. *)
  type command = data -> client -> string -> int -> unit Lwt.t

  type t =
    { server : server
    ; data : data
    ; auth : Auth.t option  (** [None] disables authentication. *)
    ; commands : (string, command) Hashtbl.t  (** Keyed by lowercased name. *)
    ; default : string  (** Name of the fallback command. *) }

  (** Writes the string reply ["OK"]. *)
  let ok (_, oc) = Value.write oc (`String "OK")

  (** [error client msg] writes an error reply reading ["ERR "] then [msg]. *)
  let error (_, oc) msg = Value.write oc (`Error (Printf.sprintf "ERR %s" msg))

  (** Writes the error reply for ["Invalid arguments"]. *)
  let invalid_arguments client = error client "Invalid arguments"

  (** [send client v] writes the RESP value [v]. *)
  let send (_, oc) x = Value.write oc x

  (** [recv client] reads the next RESP value. *)
  let recv (ic, _) = Value.read ic

  (** Builds the command table; names are lowercased and a later duplicate
      replaces an earlier one. *)
  let hashtbl_of_list l =
    let ht = Hashtbl.create (List.length l) in
    List.iter (fun (k, v) -> Hashtbl.replace ht (String.lowercase_ascii k) v) l;
    ht

  (** [create ?auth ?commands ?default server data] builds the server state.
      [default] is the name of the command used when the requested one is not
      registered. *)
  let create ?auth ?(commands = []) ?(default = "default") server data =
    let commands = hashtbl_of_list commands in
    {server; data; auth; commands; default}

  (** [check_auth auth args] accepts everything when no authentication is
      configured, and defers to [Auth.check] otherwise. *)
  let check_auth auth args =
    match auth with
    | Some auth -> Auth.check auth args
    | None -> true

  (** Splits a fully read request array into its lowercased command name and
      its arguments as strings.
      @raise Invalid_argument if [arr] is empty. Conversion failures of
      [Resp.to_string_exn] propagate as well. *)
  let split_command_s arr : string * string array =
    ( String.lowercase_ascii @@ Resp.to_string_exn arr.(0)
    , Array.map Resp.to_string_exn (Array.sub arr 1 (Array.length arr - 1)) )

  (** [discard_n client n] reads and drops [n] values; no-op when [n <= 0]. *)
  let rec discard_n client n =
    if n > 0 then Value.read (fst client) >>= fun _ -> discard_n client (n - 1)
    else Lwt.return ()

  (** [finish client ~nargs used] drops the [nargs - used] values a command
      left unread. *)
  let finish client ~nargs used = discard_n client (nargs - used)

  (** [handle t client authenticated] serves requests from [client] until the
      channel reaches end of file or an unrecoverable error occurs.

      Each request is an array whose first element names the command. The
      command is called with the number of arguments that are still unread,
      and is responsible for consuming them (see [finish]). *)
  let rec handle t client authenticated =
    (* Arguments of the current request not yet consumed by this function;
       used to resynchronise the stream when the request is rejected. *)
    let argc = ref 0 in
    (* Recovery must never change the authentication state: a failure while
       the client is still unauthenticated keeps it unauthenticated. *)
    let continue () = handle t client authenticated in
    Lwt.catch
      (fun () ->
        if not authenticated then handle_not_authenticated t client
        else
          Value.Reader.read_lexeme (fst client)
          >>= function
          (* An empty or negative-length array carries no command name;
             reading one anyway would consume part of the next request. *)
          | Ok (`As n) when n > 0 -> (
              argc := n - 1;
              Value.read (fst client)
              >>= function
              | `String s | `Bulk s ->
                  let s = String.lowercase_ascii s in
                  let f =
                    match Hashtbl.find_opt t.commands s with
                    | Some f -> f
                    | None ->
                        (* Raises [Not_found] when no fallback is registered;
                           reported below as "command not found". *)
                        Hashtbl.find t.commands t.default
                  in
                  f t.data client s !argc >>= continue
              | _ ->
                  discard_n client !argc
                  >>= fun () ->
                  error client "invalid commands name" >>= continue )
          | Error e -> error client (Resp.string_of_error e) >>= continue
          | _ -> error client "invalid command format" >>= continue )
      (function
        | Resp.Exc exc -> error client (Resp.string_of_error exc) >>= continue
        | Not_found ->
            discard_n client !argc
            >>= fun () -> error client "command not found" >>= continue
        | Failure msg | Invalid_argument msg ->
            (* Reported to the client, after which this client is no longer
               served: the loop is deliberately not resumed here. *)
            error client msg
        | End_of_file -> Lwt.return ()
        | exc ->
            (* Always surface as a rejected promise, never as a synchronous
               exception escaping [Lwt.catch]. *)
            Lwt.fail exc )

  (** Waits for one request from an unauthenticated client. Only a valid
      ["auth"] request switches the client to the authenticated state; any
      other input is answered with ["authentication required"]. *)
  and handle_not_authenticated t client =
    let deny () =
      error client "authentication required"
      >>= fun () -> handle t client false
    in
    Value.read (fst client)
    >>= function
    (* Guard against the empty array, for which [split_command_s] raises. *)
    | `Array arr when Array.length arr > 0 -> (
        match split_command_s arr with
        | "auth", args when check_auth t.auth args ->
            ok client >>= fun () -> handle t client true
        | _ -> deny () )
    | _ -> deny ()

  (** [start t] runs the server; clients start out authenticated only when
      no authentication was configured. *)
  let start t =
    let open_access =
      match t.auth with
      | None -> true
      | Some _ -> false
    in
    run t.server (fun client -> handle t client open_access)
end