Source file ops_threefry.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
open Nx_core.Dtype
module Shape = Nx_core.Shape
open Internal
(** Threefry-2x32 block function with 20 rounds, operating on [int32] words
    with wrap-around (modulo 2^32) arithmetic. *)
module Threefry_impl = struct
  (** Constant XOR-ed with both key words to derive the third key-schedule
      word (see [threefry2x32_20_rounds]). *)
  let ks_parity_32 = 0x1BD11BDA_l

  (** Left-rotation amounts; round [r] uses entry [r mod 8]. *)
  let r_2x32 = [| 13; 15; 26; 6; 17; 29; 16; 24 |]

  (** [rotl32 x n] rotates the 32-bit word [x] left by [n land 31] bits.

      A rotation amount of zero is returned unchanged: the generic formula
      would shift right by 32, and [Int32] shifts are unspecified for shift
      amounts outside [0, 31]. *)
  let rotl32 x n =
    match n land 31 with
    | 0 -> x
    | n -> Int32.(logor (shift_left x n) (shift_right_logical x (32 - n)))

  (** [threefry2x32_20_rounds c0 c1 k0 k1] mixes the two counter words
      [(c0, c1)] under the two key words [(k0, k1)] and returns the two
      resulting words.

      The key schedule is [k0; k1; ks_parity_32 lxor k0 lxor k1]. A key
      injection happens before rounds 0, 4, 8, 12 and 16 and once more after
      the last round, for six injections in total. *)
  let threefry2x32_20_rounds (c0 : int32) (c1 : int32) (k0 : int32) (k1 : int32)
      : int32 * int32 =
    let x0 = ref c0 in
    let x1 = ref c1 in
    let keys = [| k0; k1; Int32.logxor ks_parity_32 (Int32.logxor k0 k1) |] in
    (* Injection number [s]: add two consecutive schedule words (indices taken
       modulo 3) to the state, then add [s] itself to the second word. *)
    let inject s =
      x0 := Int32.add !x0 keys.(s mod 3);
      x1 := Int32.add !x1 keys.((s + 1) mod 3);
      x1 := Int32.add !x1 (Int32.of_int s)
    in
    for r = 0 to 19 do
      if r mod 4 = 0 then inject (r / 4);
      (* One mix round: add, rotate, xor. *)
      x0 := Int32.add !x0 !x1;
      x1 := Int32.logxor (rotl32 !x1 r_2x32.(r mod 8)) !x0
    done;
    (* Final injection after round 20, i.e. injection number 20 / 4. *)
    inject 5;
    (!x0, !x1)
end
(** [kernel_threefry_int32 data_t seed_t out_t start_idx end_idx] processes
    the logical elements [start_idx] (inclusive) up to [end_idx] (exclusive).

    For each element, the first word returned by
    [Threefry_impl.threefry2x32_20_rounds d 0l s 0xCAFEBABEl] is stored in
    [out_t], where [d] and [s] are the matching elements of [data_t] and
    [seed_t]. The second output word is discarded.

    When all three tensors are C-contiguous the logical index is used directly
    (plus each tensor's offset). Otherwise each logical index is unravelled
    against the shape of [out_t] and mapped through every tensor's own strides
    and offset, including those of [out_t].

    Buffer accesses are unchecked ([unsafe_get] / [unsafe_set]); the caller
    must pass a range that lies within the tensors. *)
let kernel_threefry_int32 (data_t : (int32, int32_elt) t)
    (seed_t : (int32, int32_elt) t) (out_t : (int32, int32_elt) t) start_idx
    end_idx =
  let data_buf = buffer data_t in
  let seed_buf = buffer seed_t in
  let out_buf = buffer out_t in
  let data_offset = offset data_t in
  let seed_offset = offset seed_t in
  let out_offset = offset out_t in
  (* The second counter word and the second key word are the same for every
     element; only the data word and the seed word vary. *)
  let c1_fixed = 0l in
  let k1_fixed = 0xCAFEBABEl in
  let hash d_val s_val =
    let res0, _ =
      Threefry_impl.threefry2x32_20_rounds d_val c1_fixed s_val k1_fixed
    in
    res0
  in
  if is_c_contiguous data_t && is_c_contiguous seed_t && is_c_contiguous out_t
  then (
    (* Fast path: the i-th element of the chunk sits at [base + i] in every
       buffer. The output must be contiguous too, since it is addressed the
       same way. *)
    let data_base = data_offset + start_idx in
    let seed_base = seed_offset + start_idx in
    let out_base = out_offset + start_idx in
    let process i =
      let d_val = Bigarray.Array1.unsafe_get data_buf (data_base + i) in
      let s_val = Bigarray.Array1.unsafe_get seed_buf (seed_base + i) in
      Bigarray.Array1.unsafe_set out_buf (out_base + i) (hash d_val s_val)
    in
    let n = end_idx - start_idx in
    let i = ref 0 in
    (* Main loop, manually unrolled four elements at a time. *)
    while !i + 3 < n do
      process !i;
      process (!i + 1);
      process (!i + 2);
      process (!i + 3);
      i := !i + 4
    done;
    (* Remainder: at most three trailing elements. *)
    while !i < n do
      process !i;
      incr i
    done)
  else
    (* Strided path: translate each logical index into a multi-dimensional
       index, then into a physical position in each buffer. *)
    let out_shape = shape out_t in
    let data_strides = strides data_t in
    let seed_strides = strides seed_t in
    let out_strides = strides out_t in
    (* Scratch index reused across iterations to avoid per-element
       allocation. *)
    let md_index = Array.make (Array.length out_shape) 0 in
    for k = start_idx to end_idx - 1 do
      Shape.unravel_index_into k out_shape md_index;
      let data_lin = Shape.ravel_index md_index data_strides in
      let seed_lin = Shape.ravel_index md_index seed_strides in
      let out_lin = Shape.ravel_index md_index out_strides in
      let d_val =
        Bigarray.Array1.unsafe_get data_buf (data_offset + data_lin)
      in
      let s_val =
        Bigarray.Array1.unsafe_get seed_buf (seed_offset + seed_lin)
      in
      Bigarray.Array1.unsafe_set out_buf (out_offset + out_lin)
        (hash d_val s_val)
    done
(** [threefry context data_t seed_t out_t] fills [out_t] by running
    [kernel_threefry_int32] over all of its elements, with the work split
    across [context.pool] by [Parallel.parallel_for]. Does nothing when
    [out_t] has no elements.

    NOTE(review): the range handed to [Parallel.parallel_for] is
    [0 .. size - 1], while the kernel treats its [end_idx] as exclusive. This
    presumably relies on [parallel_for] taking an inclusive upper bound and
    passing half-open chunks to the callback; confirm against [Parallel]. *)
let threefry (context : context) (data_t : (int32, int32_elt) t)
    (seed_t : (int32, int32_elt) t) (out_t : (int32, int32_elt) t) : unit =
  match size out_t with
  | 0 -> ()
  | total ->
      let run_chunk chunk_start chunk_end =
        kernel_threefry_int32 data_t seed_t out_t chunk_start chunk_end
      in
      Parallel.parallel_for context.pool 0 (total - 1) run_chunk