(* Source file: rng.ml *)
(* JAX-style splittable random number generation *)

open Tensor_with_debug

(** A random-number-generator key. It is a plain [int]; the constructors in
    this file ([key], [split], [fold_in]) all mask their result to the low 31
    bits. *)
type key = int

(** [key seed] builds a key from an integer seed: the absolute value of [seed]
    is taken and then masked down to its low 31 bits. Seeds that differ only
    in sign therefore map to the same key. *)
let key seed =
  let low_31_bits = 0x7FFFFFFF in
  let magnitude = Stdlib.abs seed in
  magnitude land low_31_bits

(* 32-bit integer mixer with the structure of the MurmurHash3 finalizer:
   xor-shift by 16, multiply, xor-shift by 13, multiply, xor-shift by 16.
   The input is truncated to 32 bits by [Int32.of_int] and the output is
   masked to its low 31 bits, so the result is a non-negative [int].
   Shared by [split] and [fold_in] so both derive keys the same way. *)
let mix32 x =
  let open Int32 in
  let xor_shift bits v = logxor v (shift_right_logical v bits) in
  of_int x
  |> xor_shift 16
  |> mul 0x85ebca6bl
  |> xor_shift 13
  |> mul 0xc2b2ae35l
  |> xor_shift 16
  |> logand 0x7FFFFFFFl
  |> to_int

(** [split ?n key] derives [n] new keys (default 2) from [key]. The key at
    index [i] is the mixed value of [key * (n + 1) + i + 1], so the result
    depends on both [key] and the requested count. An [n] of 0 yields an
    empty array.
    @raise Invalid_argument if [n] is negative. *)
let split ?(n = 2) key =
  if n < 0 then
    invalid_arg (Printf.sprintf "split: n (%d) must be non-negative" n);
  Array.init n (fun i -> mix32 ((key * (n + 1)) + i + 1))

(** [fold_in key data] derives a new key by mixing [key lxor data]. Because
    xor is symmetric, [fold_in a b] equals [fold_in b a]. *)
let fold_in key data = mix32 (key lxor data)

(** [to_int key] returns [key] unchanged; keys are already plain integers. *)
let to_int = fun key -> key

(* Random sampling functions *)

(** [uniform key dtype shape] produces a random tensor of the given [dtype]
    and [shape] by delegating to [rand], with [key] supplied as its seed.
    The exact distribution is whatever [rand] provides (presumably uniform
    over the unit interval; confirm against [rand]). *)
let uniform key dtype shape = rand ~seed:key dtype shape

(** [normal key dtype shape] produces a random tensor of the given [dtype]
    and [shape] by delegating to [randn], with [key] supplied as its seed.
    The exact distribution is whatever [randn] provides (presumably standard
    normal; confirm against [randn]). *)
let normal key dtype shape = randn ~seed:key dtype shape

(** [randint key ~min ~max shape] draws an Int32 tensor of the given [shape]
    with integer values intended to lie between [min] (inclusive) and [max]
    (exclusive).

    The draw is built from a Float32 [uniform] sample scaled by [max - min].
    @raise Invalid_argument if [min >= max]. *)
let randint key ~min ~max shape =
  if min >= max then
    invalid_arg
      (Printf.sprintf "randint: min (%d) must be less than max (%d)" min max);
  let range = max - min in
  let uniform_vals = uniform key Tensor.Float32 shape in
  let scaled = mul uniform_vals (scalar Tensor.Float32 (float_of_int range)) in
  (* Convert to integers while the values are still non-negative offsets.
     Assuming [astype] truncates toward zero (TODO confirm), casting after
     adding a negative [min] would fold both sides of zero onto 0 and almost
     never produce [min]; on non-negative values truncation is a floor. *)
  let offsets = astype Tensor.Int32 scaled in
  (* Shift in the integer domain so [min] is added exactly, with no Float32
     rounding of large bounds. *)
  add offsets (scalar Tensor.Int32 (Int32.of_int min))

(** [bernoulli key ~p shape] returns the element-wise comparison, via
    [cmplt], of a Float32 [uniform] sample of the given [shape] against the
    threshold [p]. Entries where the sample is below [p] are the successes.
    @raise Invalid_argument if [p] is outside the closed interval from 0 to 1,
    or is NaN. *)
let bernoulli key ~p shape =
  (* Written as a negated range test so that NaN, for which every comparison
     is false, is rejected instead of slipping through. *)
  if not (p >= 0. && p <= 1.) then
    invalid_arg (Printf.sprintf "bernoulli: p (%.2f) must be in [0, 1]" p);
  let uniform_vals = uniform key Tensor.Float32 shape in
  let threshold = scalar Tensor.Float32 p in
  cmplt uniform_vals threshold

(** [permutation key n] builds a permutation by drawing [n] Float32 values
    with [uniform] and returning the ascending [argsort] of those values
    along axis 0.
    @raise Invalid_argument if [n] is not strictly positive. *)
let permutation key n =
  if n < 1 then
    invalid_arg (Printf.sprintf "permutation: n (%d) must be positive" n);
  (* One random sort key per index; sorting the keys orders the indices. *)
  let sort_keys = uniform key Tensor.Float32 [| n |] in
  argsort ~axis:0 ~descending:false sort_keys

(** [shuffle key x] reorders [x] along its first axis using the index order
    produced by [permutation key n], where [n] is the size of that axis.
    A rank-0 tensor, or a tensor whose first axis is empty, is returned
    unchanged. *)
let shuffle key x =
  let shape_x = Tensor.shape x in
  (* Nothing to reorder for a rank-0 tensor or an empty leading axis. The
     empty case has to be caught here because [permutation] rejects n = 0. *)
  if Array.length shape_x = 0 || shape_x.(0) = 0 then x
  else
    (* NOTE(review): if [get [ i ] x] returns the slice without its leading
       axis, concatenating along axis 0 would not rebuild the input shape.
       Confirm the behavior of [get], or stack the slices instead. *)
    let picked =
      permutation key shape_x.(0)
      |> to_array
      |> Array.map (fun i -> get [ Int32.to_int i ] x)
    in
    concatenate ~axis:0 (Array.to_list picked)

(* TODO: Implement categorical once cumsum is available. Sketch:

   let categorical key ctx ?(axis = -1) logits =
     let shape_array = shape logits in
     let ndim = Array.length shape_array in
     let axis = if axis < 0 then ndim + axis else axis in

     (* Generate uniform random values with same shape *)
     let uniform_vals = uniform key Tensor.Float32 shape_array in

     (* Apply softmax to get probabilities *)
     let probs = softmax logits ~axes:[| axis |] in

     (* Compute cumulative sum along axis *)
     let cumsum = cumsum probs ~axis in

     (* Find where uniform_vals < cumsum for the first time *)
     let comparison = cmplt uniform_vals cumsum in

     (* argmax along axis gives us the first True index *)
     argmax comparison ~axis ~keepdims:false *)

(** Placeholder for categorical sampling. Every argument is ignored (the
    optional [axis] defaults to -1) and the call always fails.
    @raise Failure unconditionally, until a cumulative-sum primitive exists. *)
let categorical _key ?axis:(_ = -1) _logits =
  raise (Failure "categorical: not implemented yet (requires cumsum)")

(** [truncated_normal key dtype ~lower ~upper shape] draws a [normal] sample
    of the given [dtype] and [shape] and passes it through [clip] with
    [lower] and [upper] as the bounds.

    This is a simple clipping approach, not rejection sampling: out-of-range
    draws are handed to [clip] rather than redrawn.
    @raise Invalid_argument unless [lower] is strictly less than [upper]
    (this also rejects NaN bounds). *)
let truncated_normal key dtype ~lower ~upper shape =
  (* Negated strict comparison: a NaN bound makes every comparison false, so
     testing [lower >= upper] would let it through. *)
  if not (lower < upper) then
    invalid_arg "truncated_normal: lower must be less than upper";
  let vals = normal key dtype shape in
  clip vals ~min:lower ~max:upper