1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
open Util
(** [hmac mac seq buf] computes the MAC of [buf] under the algorithm and key
    held in [mac]. The 32-bit sequence number [seq] is prepended to [buf] as
    four big-endian bytes before hashing. *)
let hmac mac seq buf =
  let { Hmac.hmac = algo; key; _ } = mac in
  let seq_be = Cstruct.create 4 in
  Cstruct.BE.set_uint32 seq_be 0 seq;
  Hmac.hmacv algo ~key [ seq_be; buf ]
(** [peek_len cipher seq block_len buf] decrypts only the first [block_len]
    bytes of [buf] (with [~len:true]) and reads the packet-length field from
    the resulting header. The advanced cipher state is discarded, so [cipher]
    can be reused to decrypt the whole packet afterwards. Asserts that [buf]
    holds at least [block_len] bytes. *)
let peek_len cipher seq block_len buf =
  assert (Cstruct.length buf >= block_len);
  let first_block = Cstruct.sub buf 0 block_len in
  match Cipher.decrypt ~len:true seq cipher first_block with
  | Ok (hdr, _discarded_state) ->
    Ok (Int32.to_int (Ssh.get_pkt_hdr_pkt_len hdr))
  | Error e -> Error e
(** [partial buf] is the result for a buffer that does not yet hold a
    complete packet. It returns [Ok None] ("need more input") while [buf] is
    shorter than [Ssh.max_pkt_len], and an error once the buffer has grown to
    that limit without yielding a packet. *)
let partial buf =
  if Cstruct.length buf >= Ssh.max_pkt_len then Error "Buffer is too big"
  else Ok None
(** [to_msg pkt] extracts the payload of the decrypted packet [pkt] and
    parses it into a message. Errors from either step are returned unchanged. *)
let to_msg pkt =
  match Wire.get_payload pkt with
  | Ok payload -> Wire.get_message payload
  | Error e -> Error e
(** [decrypt keys buf] tries to decode one SSH binary packet from the front
    of [buf] using the receive-direction [keys].

    Returns:
    - [Ok None] when [buf] does not yet hold a complete packet and is shorter
      than [Ssh.max_pkt_len], meaning more input is needed;
    - [Error "Buffer is too big"] when a complete packet is still missing but
      [buf] has already reached [Ssh.max_pkt_len] bytes;
    - [Ok (Some (pkt_dec, rest, keys'))] on success. [pkt_dec] is the
      decrypted packet (header, payload and padding). [rest] is what follows
      the packet and its MAC in [buf]. [keys'] holds the advanced cipher
      state, the incremented sequence number and the updated byte counter;
    - [Error _] for a bogus or misaligned packet length, a MAC mismatch, a
      bogus padding length, or an error returned by [Cipher.decrypt]. *)
let decrypt keys buf =
  let open Ssh in
  (* Compare MAC tags without an early exit. Every byte is visited, so the
     time taken does not reveal where the first mismatch is. Same result as
     [Cstruct.equal]: equal lengths and equal contents. *)
  let digest_equal a b =
    let len = Cstruct.length a in
    if len <> Cstruct.length b then false
    else begin
      let diff = ref 0 in
      for i = 0 to len - 1 do
        diff := !diff lor (Cstruct.get_uint8 a i lxor Cstruct.get_uint8 b i)
      done;
      !diff = 0
    end
  in
  let cipher = keys.Kex.cipher in
  let mac = keys.Kex.mac in
  let seq = keys.Kex.seq in
  (* RFC 4253 section 6: the padded region is a multiple of the cipher block
     size or 8, whichever is larger. *)
  let block_len = max 8 (Cipher.block_len cipher.Cipher.cipher) in
  let aead = Cipher.aead cipher.Cipher.cipher in
  let digest_len = Hmac.(digest_len mac.hmac)
  and mac_len = Cipher.(mac_len cipher.Cipher.cipher)
  in
  (* Before anything can be decoded we need one full block, to recover the
     length field, plus room for the trailing digest and tag bytes. *)
  if Cstruct.length buf < max sizeof_pkt_hdr (digest_len + mac_len + block_len) then
    partial buf
  else
    let* pkt_len = peek_len cipher seq block_len buf in
    let* () =
      guard (pkt_len > 0 && pkt_len < max_pkt_len) "decrypt: Bogus pkt len"
    in
    (* The encrypted region must be a whole number of blocks (RFC 4253
       section 6). For AEAD ciphers the 4-byte length field lies outside that
       region, matching how [encrypt] pads. Rejecting misaligned lengths here
       means the cipher is never handed a partial block. *)
    let* () =
      let padded_region = if aead then pkt_len else pkt_len + 4 in
      guard (padded_region mod block_len = 0) "decrypt: Bogus pkt len"
    in
    (* The extra 4 bytes are the packet-length field itself, which [pkt_len]
       does not count. *)
    if Cstruct.length buf < pkt_len + 4 + digest_len + mac_len then
      partial buf
    else
      let pkt_enc, digest1 = Cstruct.split buf (pkt_len + 4 + mac_len) in
      let tx_rx = Int64.(add keys.Kex.tx_rx (Cstruct.length pkt_enc - mac_len |> of_int)) in
      let* pkt_dec, cipher = Cipher.decrypt ~len:false seq cipher pkt_enc in
      let digest1 = Cstruct.sub digest1 0 digest_len in
      (* The MAC covers the sequence number and the decrypted packet. *)
      let digest2 = hmac mac seq pkt_dec in
      let* () =
        guard (digest_equal digest1 digest2) "decrypt: Bad digest"
      in
      let pad_len = get_pkt_hdr_pad_len pkt_dec in
      (* At least 4 bytes of padding are required. Because pkt_len counts
         the padding-length byte, payload and padding, pad_len < pkt_len
         guarantees the payload length is not negative. *)
      let* () =
        guard (pad_len >= 4 && pad_len <= 255 && pad_len < pkt_len)
          "decrypt: Bogus pad len"
      in
      let buf = Cstruct.shift buf (4 + pkt_len + mac_len + digest_len) in
      let keys = Kex.{ cipher; mac; seq = Int32.succ keys.Kex.seq; tx_rx } in
      Ok (Some (pkt_dec, buf, keys))
(** [encrypt keys msg] frames [msg] as an SSH binary packet and seals it with
    the transmit-direction [keys].

    The plaintext has the following layout:
    - the packet header (packet-length field and padding-length byte),
    - the serialized message,
    - random padding.

    A digest is computed with [hmac] over the sequence number and this
    plaintext, then appended after the ciphertext. Returns the bytes to send
    together with updated keys, which hold:
    - the cipher state returned by [Cipher.encrypt],
    - the sequence number incremented by one,
    - the byte counter increased by the length of the returned bytes. *)
let encrypt keys msg =
  let { Kex.cipher; mac; seq; _ } = keys in
  let bs = max 8 (Cipher.block_len cipher.Cipher.cipher) in
  let frame =
    Dbuf.create ()
    |> Dbuf.reserve Ssh.sizeof_pkt_hdr
    |> Wire.put_message msg
  in
  (* For AEAD ciphers the 4-byte length field is left out of the padded
     region. *)
  let unpadded =
    let used = Dbuf.used frame in
    if Cipher.aead cipher.Cipher.cipher then used - 4 else used
  in
  (* Pad up to the next block boundary. Add one more whole block whenever
     that boundary would leave fewer than 4 padding bytes. *)
  let padlen =
    let gap = bs - (unpadded mod bs) in
    if gap >= 4 then gap else gap + bs
  in
  assert (padlen >= 4 && padlen <= 255);
  let pkt = Dbuf.to_cstruct (Wire.put_random padlen frame) in
  Ssh.set_pkt_hdr_pkt_len pkt (Int32.of_int (Cstruct.length pkt - 4));
  Ssh.set_pkt_hdr_pad_len pkt padlen;
  let digest = hmac mac seq pkt in
  let ciphertext, cipher = Cipher.encrypt ~len:false seq cipher pkt in
  let packet = Cstruct.append ciphertext digest in
  let tx_rx =
    Int64.add keys.Kex.tx_rx (Int64.of_int (Cstruct.length packet))
  in
  packet, Kex.{ cipher; mac; seq = Int32.succ keys.Kex.seq; tx_rx }