1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
open Tensor_with_debug
(** A PRNG key. Keys are plain integers; the constructors in this file
    ([key], [split], [fold_in]) all mask their result to the low 31 bits. *)
type key = int
(** [key seed] builds a PRNG key from an integer seed: the absolute value of
    [seed], keeping only its low 31 bits, so the result is never negative. *)
let key seed =
  let magnitude = Stdlib.abs seed in
  magnitude land 0x7FFFFFFF
(* 32-bit avalanche mixer shared by [split] and [fold_in]; the shift amounts
   (16, 13, 16) and multipliers are those of the MurmurHash3 32-bit finalizer.
   [Int32.of_int] keeps only the low 32 bits of the argument, and the final
   mask clears the sign bit so the result is always in [0, 0x7FFFFFFF]. *)
let mix32 x =
  let open Int32 in
  let xor_shift bits h = logxor h (shift_right_logical h bits) in
  of_int x
  |> xor_shift 16
  |> mul 0x85ebca6bl
  |> xor_shift 13
  |> mul 0xc2b2ae35l
  |> xor_shift 16
  |> logand 0x7FFFFFFFl
  |> to_int

(** [split ?n key] derives [n] child keys (default 2) from [key]. Child [i] is
    the mix of [key * (n + 1) + i + 1], so the result depends on both [key] and
    the requested [n].

    NOTE(review): the mixer input is truncated to 32 bits, so on 64-bit
    platforms two keys whose [key * (n + 1)] differ by a multiple of 2^32 give
    identical children. Left as is to keep existing random streams stable.

    @raise Invalid_argument if [n] is negative. *)
let split ?(n = 2) key =
  if n < 0 then
    invalid_arg (Printf.sprintf "split: n (%d) must be non-negative" n);
  Array.init n (fun i -> mix32 ((key * (n + 1)) + i + 1))

(** [fold_in key data] derives a new key by mixing [key lxor data]. Because
    the two inputs are combined with xor, the result is symmetric in [key]
    and [data]. *)
let fold_in key data = mix32 (key lxor data)
(** [to_int key] returns the raw value behind [key]; it is the identity. *)
let to_int = fun key -> key
(** [uniform key dtype shape] draws a tensor of element type [dtype] and the
    given [shape] by calling [rand] with [key] as its seed.

    NOTE(review): presumably samples lie in [0, 1) — confirm against [rand]. *)
let uniform key dtype shape =
  let seed = key in
  rand dtype ~seed shape
(** [normal key dtype shape] draws a tensor of element type [dtype] and the
    given [shape] by calling [randn] with [key] as its seed.

    NOTE(review): presumably standard-normal samples — confirm against
    [randn]. *)
let normal key dtype shape =
  let seed = key in
  randn dtype ~seed shape
(** [randint key ~min ~max shape] draws an Int32 tensor of the given [shape]
    with values intended to cover [min] (inclusive) to [max] (exclusive).

    A Float32 uniform draw is scaled by [max - min], converted to Int32, and
    only then offset by [min]. The conversion is done while the values are
    still non-negative: assuming [astype] truncates toward zero (TODO confirm),
    truncation then coincides with floor. Offsetting in float space first
    would instead collapse the whole interval (-1, 1) onto 0 for ranges that
    straddle zero.

    NOTE(review): the draw goes through Float32, so ranges wider than about
    2^24 cannot reach every integer.

    @raise Invalid_argument if [min >= max]. *)
let randint key ~min ~max shape =
  if min >= max then
    invalid_arg
      (Printf.sprintf "randint: min (%d) must be less than max (%d)" min max);
  let range = max - min in
  let uniform_vals = uniform key Tensor.Float32 shape in
  let scaled = mul uniform_vals (scalar Tensor.Float32 (float_of_int range)) in
  (* Cast before shifting: [scaled] is >= 0 here, so the cast cannot be
     biased by the sign of [min]. *)
  let offsets = astype Tensor.Int32 scaled in
  add offsets (scalar Tensor.Int32 (Int32.of_int min))
(** [bernoulli key ~p shape] draws a tensor of the given [shape] whose
    elements are the result of comparing a Float32 uniform draw against [p]
    with [cmplt] (element is "true" where the draw is strictly below [p]).

    @raise Invalid_argument if [p] is NaN or outside [0, 1]. *)
let bernoulli key ~p shape =
  (* NaN compares false against everything, so it must be rejected
     explicitly or it would slip past the range check. *)
  if Stdlib.Float.is_nan p || p < 0. || p > 1. then
    invalid_arg (Printf.sprintf "bernoulli: p (%.2f) must be in [0, 1]" p);
  let uniform_vals = uniform key Tensor.Float32 shape in
  let threshold = scalar Float32 p in
  cmplt uniform_vals threshold
(** [permutation key n] returns the indices that sort [n] Float32 draws from
    [uniform] in ascending order, i.e. an ordering of the indices of a
    length-[n] vector driven by [key].

    @raise Invalid_argument if [n <= 0]. *)
let permutation key n =
  if n <= 0 then
    invalid_arg (Printf.sprintf "permutation: n (%d) must be positive" n)
  else
    let sort_keys = uniform key Tensor.Float32 [| n |] in
    argsort sort_keys ~axis:0 ~descending:false
(** [shuffle key x] reorders [x] along its first axis using the ordering
    produced by [permutation key n], where [n] is the length of that axis.

    A rank-0 tensor, or one whose first axis is empty, is returned unchanged:
    there is nothing to reorder, and [permutation] rejects [n = 0].

    NOTE(review): the result is rebuilt with [concatenate ~axis:0] over
    [get [ i ] x]; this assumes [get] yields slices that concatenate back to
    the original rank — TODO confirm against the tensor API. *)
let shuffle key x =
  let shape_x = Tensor.shape x in
  if Array.length shape_x = 0 || shape_x.(0) = 0 then x
  else
    let n = shape_x.(0) in
    let perm = permutation key n in
    (* Materialize the permutation on the host so each index can drive a
       separate [get]. *)
    let perm_array = to_array perm |> Array.map Int32.to_int in
    let results = Array.map (fun i -> get [ i ] x) perm_array in
    concatenate ~axis:0 (Array.to_list results)
(** [categorical key ?axis logits] is a placeholder: every argument is
    ignored ([axis] defaults to [-1]) and the call always fails.

    @raise Failure always; the operation is not implemented. *)
let categorical _key ?axis:(_axis = -1) _logits =
  failwith "categorical: not implemented yet (requires cumsum)"
(** [truncated_normal key dtype ~lower ~upper shape] draws from [normal] and
    passes the result through [clip] with [lower] and [upper] as bounds.

    NOTE(review): this bounds the samples via [clip] rather than redrawing
    them; assuming [clip] saturates out-of-range values, probability mass
    accumulates at the two bounds — confirm this is the intended
    distribution.

    @raise Invalid_argument if [lower >= upper]. *)
let truncated_normal key dtype ~lower ~upper shape =
  if lower >= upper then
    invalid_arg "truncated_normal: lower must be less than upper"
  else
    let samples = normal key dtype shape in
    clip samples ~min:lower ~max:upper