include module type of Nx

The type ('a, 'b) t represents a tensor where 'a is the OCaml type of elements and 'b is the bigarray element type. For example, (float, float32_elt) t is a tensor of 32-bit floats.
Operations automatically broadcast compatible shapes: each dimension must be equal or one of them must be 1. Shape [|3; 1; 5|] broadcasts with [|1; 4; 5|] to [|3; 4; 5|].
Tensors can be C-contiguous or strided. Operations return views when possible (O(1)), otherwise copy (O(n)). Use is_c_contiguous to check layout and contiguous to ensure contiguity.
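A transposed tensor illustrates the distinction (an illustrative sketch using functions documented below):

# let t = transpose (ones float32 [| 2; 3 |]) in
  is_c_contiguous t, is_c_contiguous (contiguous t)
- : bool * bool = (false, true)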
('a, 'b) t is a tensor with OCaml type 'a and bigarray type 'b.
type ('a, 'b) dtype = ('a, 'b) Nx_core.Dtype.t =
  | Float16 : (float, float16_elt) dtype
  | Float32 : (float, float32_elt) dtype
  | Float64 : (float, float64_elt) dtype
  | Int8 : (int, int8_elt) dtype
  | UInt8 : (int, uint8_elt) dtype
  | Int16 : (int, int16_elt) dtype
  | UInt16 : (int, uint16_elt) dtype
  | Int32 : (int32, int32_elt) dtype
  | Int64 : (int64, int64_elt) dtype
  | Int : (int, int_elt) dtype
  | NativeInt : (nativeint, nativeint_elt) dtype
  | Complex32 : (Complex.t, complex32_elt) dtype
  | Complex64 : (Complex.t, complex64_elt) dtype
  | BFloat16 : (float, bfloat16_elt) dtype
  | Bool : (bool, bool_elt) dtype
  | Int4 : (int, int4_elt) dtype
  | UInt4 : (int, uint4_elt) dtype
  | Float8_e4m3 : (float, float8_e4m3_elt) dtype
  | Float8_e5m2 : (float, float8_e5m2_elt) dtype
  | Complex16 : (Complex.t, complex16_elt) dtype
  | QInt8 : (int, qint8_elt) dtype
  | QUInt8 : (int, quint8_elt) dtype

Data type specification. Links OCaml types to bigarray element types.
type index =
  | I of int
Single index: I 2 selects index 2
  | L of int list
List of indices: L [0; 2; 5] selects indices 0, 2, 5
  | R of int * int
Range [start, stop): R (1, 4) selects 1, 2, 3
  | Rs of int * int * int
Range with step: Rs (0, 10, 2) selects 0, 2, 4, 6, 8
  | A
All indices: A selects entire axis
  | M of (bool, bool_elt) t
Boolean mask: M mask selects where mask is true
  | N
New axis: N inserts dimension of size 1
Index specification for tensor slicing
Functions to inspect array dimensions, memory layout, and data access.
data t returns underlying bigarray buffer.
Buffer may contain data beyond tensor bounds for strided views. Direct access requires careful index computation using strides and offset.
is_c_contiguous t returns true if elements are contiguous in C order.
to_bigarray t converts to standard bigarray.
Always returns contiguous copy with same shape. Use for interop with libraries expecting standard bigarrays.
# let t = create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |]
val t : (float, float32_elt) t = [[1, 2, 3],
[4, 5, 6]]
# shape (to_bigarray t |> of_bigarray)
- : int array = [|2; 3|]

to_bigarray_ext t converts to extended bigarray.
Always returns contiguous copy with same shape. Use for interop with libraries expecting extended bigarrays (e.g., with bfloat16 support).
# let t = create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |]
val t : (float, float32_elt) t = [[1, 2, 3],
[4, 5, 6]]
# shape (to_bigarray_ext t |> of_bigarray_ext)
- : int array = [|2; 3|]

to_array t converts to OCaml array.
Flattens tensor to 1-D array in row-major (C) order. Always copies.
# let t = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |]
val t : (int32, int32_elt) t = [[1, 2],
[3, 4]]
# to_array t
- : int32 array = [|1l; 2l; 3l; 4l|]

Functions to create and initialize arrays.
create dtype shape data creates tensor from array data.
Length of data must equal product of shape.
# create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |]
- : (float, float32_elt) t = [[1, 2, 3],
[4, 5, 6]]

init dtype shape f creates tensor where element at indices i has value f i.
Function f receives array of indices for each position. Useful for creating position-dependent values.
# init int32 [| 2; 3 |] (fun i -> Int32.of_int (i.(0) + i.(1)))
- : (int32, int32_elt) t = [[0, 1, 2],
[1, 2, 3]]
# init float32 [| 3; 3 |] (fun i -> if i.(0) = i.(1) then 1. else 0.)
- : (float, float32_elt) t = [[1, 0, 0],
[0, 1, 0],
[0, 0, 1]]

empty dtype shape allocates uninitialized tensor.
full dtype shape value creates tensor filled with value.
# full float32 [| 2; 3 |] 3.14
- : (float, float32_elt) t = [[3.14, 3.14, 3.14],
[3.14, 3.14, 3.14]]

ones dtype shape creates tensor filled with ones.
zeros dtype shape creates tensor filled with zeros.
scalar dtype value creates scalar tensor containing value.
empty_like t creates uninitialized tensor with same shape and dtype as t.
full_like t value creates tensor shaped like t filled with value.
ones_like t creates tensor shaped like t filled with ones.
zeros_like t creates tensor shaped like t filled with zeros.
scalar_like t value creates scalar with same dtype as t.
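Deriving shape and dtype from an existing tensor with the *_like constructors (illustrative example):

# let x = zeros float32 [| 2; 2 |] in
  full_like x 7.
- : (float, float32_elt) t = [[7, 7],
[7, 7]]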
eye ?m ?k dtype n creates matrix with ones on k-th diagonal.
Default m = n (square), k = 0 (main diagonal). Positive k shifts diagonal above main, negative below.
# eye int32 3
- : (int32, int32_elt) t = [[1, 0, 0],
[0, 1, 0],
[0, 0, 1]]
# eye ~k:1 int32 3
- : (int32, int32_elt) t = [[0, 1, 0],
[0, 0, 1],
[0, 0, 0]]
# eye ~m:2 ~k:(-1) int32 3
- : (int32, int32_elt) t = [[0, 0, 0],
[1, 0, 0]]

identity dtype n creates an n×n identity matrix.
Equivalent to eye dtype n. Square matrix with ones on main diagonal, zeros elsewhere.
# identity int32 3
- : (int32, int32_elt) t = [[1, 0, 0],
[0, 1, 0],
[0, 0, 1]]

diag ?k v extracts diagonal or constructs diagonal array.
If v is 1D, returns 2D array with v on the k-th diagonal. If v is 2D, returns 1D array containing the k-th diagonal. Use k > 0 for diagonals above the main diagonal, k < 0 for diagonals below.
# let x = arange int32 0 9 1 |> reshape [|3; 3|]
val x : (int32, int32_elt) t = [[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]
# diag x
- : (int32, int32_elt) t = [0, 4, 8]
# diag ~k:1 x
- : (int32, int32_elt) t = [1, 5]
# let v = create int32 [|3|] [|1l; 2l; 3l|]
val v : (int32, int32_elt) t = [1, 2, 3]
# diag v
- : (int32, int32_elt) t = [[1, 0, 0],
[0, 2, 0],
[0, 0, 3]]

arange dtype start stop step generates values in [start, stop).
Step must be non-zero. Result length is (stop - start) / step rounded toward zero.
# arange int32 0 10 2
- : (int32, int32_elt) t = [0, 2, 4, 6, 8]
# arange int32 5 0 (-1)
- : (int32, int32_elt) t = [5, 4, 3, 2, 1]arange_f dtype start stop step generates float values from start to [stop).
Like arange but for floating-point ranges. Handles fractional steps. Due to floating-point precision, final value may differ slightly from expected.
# arange_f float32 0. 1. 0.2
- : (float, float32_elt) t = [0, 0.2, 0.4, 0.6, 0.8]
# arange_f float32 1. 0. (-0.25)
- : (float, float32_elt) t = [1, 0.75, 0.5, 0.25]linspace dtype ?endpoint start stop count generates count evenly spaced values from start to stop.
If endpoint is true (default), stop is included.
# linspace float32 ~endpoint:true 0. 10. 5
- : (float, float32_elt) t = [0, 2.5, 5, 7.5, 10]
# linspace float32 ~endpoint:false 0. 10. 5
- : (float, float32_elt) t = [0, 2, 4, 6, 8]

val logspace :
  (float, 'a) dtype ->
  ?endpoint:bool ->
  ?base:float ->
  float ->
  float ->
  int ->
  (float, 'a) t

logspace dtype ?endpoint ?base start_exp stop_exp count generates values evenly spaced on log scale.
Returns base ** x where x ranges from start_exp to stop_exp. Default base = 10.0.
# logspace float32 0. 2. 3
- : (float, float32_elt) t = [1, 10, 100]
# logspace float32 ~base:2.0 0. 3. 4
- : (float, float32_elt) t = [1, 2, 4, 8]geomspace dtype ?endpoint start stop count generates values evenly spaced on geometric (multiplicative) scale.
# geomspace float32 1. 1000. 4
- : (float, float32_elt) t = [1, 10, 100, 1000]meshgrid ?indexing x y creates coordinate grids from 1D arrays.
Returns (X, Y) where X and Y are 2D arrays representing grid coordinates.
`xy (default): Cartesian indexing - X changes along columns, Y changes along rows
`ij: Matrix indexing - X changes along rows, Y changes along columns

# let x = linspace float32 0. 2. 3 in
let y = linspace float32 0. 1. 2 in
meshgrid x y
- : (float, float32_elt) t * (float, float32_elt) t =
([[0, 1, 2],
[0, 1, 2]], [[0, 0, 0],
[1, 1, 1]])

tril ?k x returns lower triangular part of matrix.
Elements above the k-th diagonal are zeroed.
k = 0 (default): main diagonal
k > 0: include k diagonals above main
k < 0: exclude |k| diagonals below main

triu ?k x returns upper triangular part of matrix.
Elements below the k-th diagonal are zeroed.
k = 0 (default): main diagonal
k > 0: exclude k diagonals above main
k < 0: include |k| diagonals below main
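Keeping or zeroing triangles of an all-ones matrix (illustrative example):

# tril (ones int32 [| 3; 3 |])
- : (int32, int32_elt) t = [[1, 0, 0],
[1, 1, 0],
[1, 1, 1]]
# triu ~k:1 (ones int32 [| 3; 3 |])
- : (int32, int32_elt) t = [[0, 1, 1],
[0, 0, 1],
[0, 0, 0]]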
of_bigarray ba creates tensor from standard bigarray.

Zero-copy when bigarray is contiguous. Creates view sharing same memory. Modifications to either affect both.
# let ba = Bigarray.(Array1.of_array Float32 C_layout [| 1.; 2.; 3. |]) in
  of_bigarray (Bigarray.genarray_of_array1 ba)
- : (float, float32_elt) t = [1, 2, 3]

of_bigarray_ext ba creates tensor from extended bigarray.
Zero-copy when bigarray is contiguous. Creates view sharing same memory. Modifications to either affect both. Supports extended types like bfloat16.
# let ba = Bigarray_ext.Genarray.create Bigarray_ext.bfloat16
Bigarray_ext.c_layout [| 2; 3 |] in
let t = of_bigarray_ext ba in
shape t
- : int array = [|2; 3|]

Functions to generate arrays with random values.
rand dtype ?seed shape generates uniform random values in [0, 1).
Only supports float dtypes. Same seed produces same sequence.
randn dtype ?seed shape generates standard normal random values.
Mean 0, variance 1. Uses Box-Muller transform for efficiency. Only supports float dtypes. Same seed produces same sequence.
randint dtype ?seed ?high shape low generates integers in [low, high).
Uniform distribution over range. Default high = 10. Note: high is exclusive (NumPy convention).
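Since values are random, only the shape is predictable here (illustrative; the seed value is arbitrary):

# randint int32 ~seed:42 ~high:10 [| 2; 3 |] 0 |> shape
- : int array = [|2; 3|]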
dropout ?seed ~rate x randomly zeroes elements with probability rate.
The remaining values are scaled by 1/(1 - rate) keeping the expected value unchanged. Requires floating-point input and rate in [0, 1). When seed is provided, the same mask is reproduced.
Functions to reshape, transpose, and rearrange arrays.
reshape shape t returns view with new shape.
At most one dimension can be -1 (inferred from total elements). Product of dimensions must match total elements. Returns a zero-copy view when the new layout is compatible; raises otherwise.
# let t = create int32 [|2; 3|] [|1l; 2l; 3l; 4l; 5l; 6l|] in
reshape [|6|] t
- : (int32, int32_elt) t = [1, 2, 3, 4, 5, 6]
# let t = create int32 [|6|] [|1l; 2l; 3l; 4l; 5l; 6l|] in
reshape [|3; -1|] t
- : (int32, int32_elt) t = [[1, 2],
[3, 4],
[5, 6]]

broadcast_to shape t broadcasts tensor to target shape.
Shapes must be broadcast-compatible: dimensions align from right, each must be equal or source must be 1. Returns view (no copy) with zero strides for broadcast dimensions.
# let t = create int32 [|1; 3|] [|1l; 2l; 3l|] in
broadcast_to [|3; 3|] t
- : (int32, int32_elt) t = [[1, 2, 3],
[1, 2, 3],
[1, 2, 3]]
# let t = ones float32 [|3; 1|] in
shape (broadcast_to [|2; 3; 4|] t)
- : int array = [|2; 3; 4|]

broadcasted ?reverse t1 t2 broadcasts tensors to common shape.
Returns views of both tensors broadcast to compatible shape. If reverse is true, returns (t2', t1') instead of (t1', t2'). Useful before element-wise operations.
# let t1 = ones float32 [|3;1|] in
let t2 = ones float32 [|1;5|] in
let t1', t2' = broadcasted t1 t2 in
shape t1', shape t2'
- : int array * int array = ([|3; 5|], [|3; 5|])

expand shape t broadcasts tensor where -1 keeps original dimension.
Like broadcast_to but -1 preserves existing dimension size. Adds dimensions on left if needed.
# let t = ones float32 [|1; 4; 1|] in
shape (expand [|3; -1; 5|] t)
- : int array = [|3; 4; 5|]
# let t = ones float32 [|5; 5|] in
shape (expand [|-1; -1|] t)
- : int array = [|5; 5|]

flatten ?start_dim ?end_dim t collapses dimensions into single dimension.
Default start_dim = 0, end_dim = -1 (last). Negative indices count from end. Dimensions start_dim through end_dim inclusive are flattened.
# flatten (zeros float32 [| 2; 3; 4 |]) |> shape
- : int array = [|24|]
# flatten ~start_dim:1 ~end_dim:2 (zeros float32 [| 2; 3; 4; 5 |]) |> shape
- : int array = [|2; 12; 5|]

unflatten dim sizes t expands dimension dim into multiple dimensions.
Product of sizes must equal size of dimension dim. At most one dimension can be -1 (inferred). Inverse of flatten.
# unflatten 1 [| 3; 4 |] (zeros float32 [| 2; 12; 5 |]) |> shape
- : int array = [|2; 3; 4; 5|]
# unflatten 0 [| -1; 2 |] (ones float32 [| 6; 5 |]) |> shape
- : int array = [|3; 2; 5|]

ravel t returns a contiguous 1-D tensor (a view when possible).

Equivalent to flatten t but always returns contiguous result. Use when you need both flattening and contiguity.
# let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
ravel x
- : (int32, int32_elt) t = [1, 2, 3, 4, 5, 6]
# let t = transpose (ones float32 [| 3; 4 |]) in
is_c_contiguous t
- : bool = false
# let t_ravel = ravel t in
is_c_contiguous t_ravel
- : bool = true

squeeze ?axes t removes dimensions of size 1.
If axes specified, only removes those dimensions. Negative indices count from end. Returns view when possible.
# squeeze (ones float32 [| 1; 3; 1; 4 |]) |> shape
- : int array = [|3; 4|]
# squeeze ~axes:[ 0; 2 ] (ones float32 [| 1; 3; 1; 4 |]) |> shape
- : int array = [|3; 4|]
# squeeze ~axes:[ -1 ] (ones float32 [| 3; 4; 1 |]) |> shape
- : int array = [|3; 4|]

unsqueeze ?axes t inserts dimensions of size 1 at specified positions.

Axes refer to positions in result tensor. Must be in range [0, ndim].
# unsqueeze ~axes:[ 0; 2 ] (create float32 [| 3 |] [| 1.; 2.; 3. |]) |> shape
- : int array = [|1; 3; 1|]
# unsqueeze ~axes:[ 1 ] (create float32 [| 2 |] [| 5.; 6. |]) |> shape
- : int array = [|2; 1|]

squeeze_axis axis t removes dimension axis if size is 1.
unsqueeze_axis axis t inserts dimension of size 1 at axis.
expand_dims axes t is synonym for unsqueeze.
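Single-axis variants at work (illustrative example):

# squeeze_axis 0 (ones float32 [| 1; 3 |]) |> shape
- : int array = [|3|]
# unsqueeze_axis 1 (ones float32 [| 3 |]) |> shape
- : int array = [|3; 1|]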
transpose ?axes t permutes dimensions.
Default reverses all dimensions. axes must be permutation of 0..ndim-1. Returns view (no copy) with adjusted strides.
# let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
transpose x
- : (int32, int32_elt) t = [[1, 4],
[2, 5],
[3, 6]]
# transpose ~axes:[ 2; 0; 1 ] (zeros float32 [| 2; 3; 4 |]) |> shape
- : int array = [|4; 2; 3|]
# let id = transpose ~axes:[ 1; 0 ] in
id == transpose
- : bool = false

flip ?axes t reverses order along specified dimensions.
Default flips all dimensions.
# let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
flip x
- : (int32, int32_elt) t = [[6, 5, 4],
[3, 2, 1]]
# let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
flip ~axes:[ 1 ] x
- : (int32, int32_elt) t = [[3, 2, 1],
[6, 5, 4]]

as_strided shape strides ~offset t creates a strided view of the tensor with custom shape, strides (in elements), and offset (in elements).
This is a low-level operation that allows arbitrary memory layouts. Backends that support arbitrary strides (e.g., native) implement this as zero-copy. Other backends may need to materialize the data.
Warning: This function can create views that access out-of-bounds memory if used incorrectly. Use with caution.
# let x = create float32 [| 8 |] [| 0.; 1.; 2.; 3.; 4.; 5.; 6.; 7. |] in
as_strided [| 3; 3 |] [| 2; 1 |] ~offset:0 x
- : (float, float32_elt) t = [[0, 1, 2],
[2, 3, 4],
[4, 5, 6]]

moveaxis src dst t moves dimension from src to dst.
swapaxes axis1 axis2 t exchanges two dimensions.
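Moving and swapping axes, shown at the shape level (illustrative example):

# moveaxis 0 2 (zeros float32 [| 2; 3; 4 |]) |> shape
- : int array = [|3; 4; 2|]
# swapaxes 0 1 (zeros float32 [| 2; 3; 4 |]) |> shape
- : int array = [|3; 2; 4|]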
roll ?axis shift t shifts elements along axis.
Elements shifted beyond last position wrap to beginning. If axis not specified, shifts flattened tensor. Negative shift rolls backward.
# let x = create int32 [| 5 |] [| 1l; 2l; 3l; 4l; 5l |] in
roll 2 x
- : (int32, int32_elt) t = [4, 5, 1, 2, 3]
# let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
roll ~axis:1 1 x
- : (int32, int32_elt) t = [[3, 1, 2],
[6, 4, 5]]
# let x = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
roll ~axis:0 (-1) x
- : (int32, int32_elt) t = [[3, 4],
[1, 2]]

pad padding value t pads tensor with value.
padding specifies (before, after) for each dimension. Length must match tensor dimensions. Negative padding not allowed.
# let x = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
pad [| (1, 1); (1, 1) |] 0. x |> shape
- : int array = [|4; 4|]

shrink ranges t extracts slice from start to stop (exclusive) for each dimension.
# let x = create int32 [| 3; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l |] in
shrink [| (1, 3); (0, 2) |] x
- : (int32, int32_elt) t = [[4, 5],
[7, 8]]

tile reps t constructs tensor by repeating t.
reps specifies repetitions per dimension. If longer than ndim, prepends dimensions. Zero repetitions create empty tensor.
# let x = create int32 [| 1; 2 |] [| 1l; 2l |] in
tile [| 2; 3 |] x
- : (int32, int32_elt) t = [[1, 2, 1, 2, 1, 2],
[1, 2, 1, 2, 1, 2]]
# let x = create int32 [| 2 |] [| 1l; 2l |] in
tile [| 2; 1; 3 |] x |> shape
- : int array = [|2; 1; 6|]

repeat ?axis count t repeats elements count times.
If axis not specified, repeats flattened tensor.
# let x = create int32 [| 3 |] [| 1l; 2l; 3l |] in
repeat 2 x
- : (int32, int32_elt) t = [1, 1, 2, 2, 3, 3]
# let x = create int32 [| 1; 2 |] [| 1l; 2l |] in
repeat ~axis:0 3 x
- : (int32, int32_elt) t = [[1, 2],
[1, 2],
[1, 2]]

Functions to join and split arrays.
concatenate ?axis ts joins tensors along existing axis.
All tensors must have same shape except on concatenation axis. If axis not specified, flattens all tensors then concatenates. Returns contiguous result.
# let x1 = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
let x2 = create int32 [| 1; 2 |] [| 5l; 6l |] in
concatenate ~axis:0 [x1; x2]
- : (int32, int32_elt) t = [[1, 2],
[3, 4],
[5, 6]]
# let x1 = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
let x2 = create int32 [| 1; 2 |] [| 5l; 6l |] in
concatenate [x1; x2]
- : (int32, int32_elt) t = [1, 2, 3, 4, 5, 6]

stack ?axis ts joins tensors along new axis.
All tensors must have identical shape. Result rank is input rank + 1. Default axis=0. Negative axis counts from end of result shape.
# let x1 = create int32 [| 2 |] [| 1l; 2l |] in
let x2 = create int32 [| 2 |] [| 3l; 4l |] in
stack [x1; x2]
- : (int32, int32_elt) t = [[1, 2],
[3, 4]]
# let x1 = create int32 [| 2 |] [| 1l; 2l |] in
let x2 = create int32 [| 2 |] [| 3l; 4l |] in
stack ~axis:1 [x1; x2]
- : (int32, int32_elt) t = [[1, 3],
[2, 4]]
# stack ~axis:(-1) [ones float32 [| 2; 3 |]; zeros float32 [| 2; 3 |]] |> shape
- : int array = [|2; 3; 2|]

vstack ts stacks tensors vertically (row-wise).

1-D tensors are treated as row vectors (shape [1; n]). Higher-D tensors concatenate along axis 0. All tensors must have same shape except possibly first dimension.
# let x1 = create int32 [| 3 |] [| 1l; 2l; 3l |] in
let x2 = create int32 [| 3 |] [| 4l; 5l; 6l |] in
vstack [x1; x2]
- : (int32, int32_elt) t = [[1, 2, 3],
[4, 5, 6]]
# let x1 = create int32 [| 1; 2 |] [| 1l; 2l |] in
let x2 = create int32 [| 2; 2 |] [| 3l; 4l; 5l; 6l |] in
vstack [x1; x2]
- : (int32, int32_elt) t = [[1, 2],
[3, 4],
[5, 6]]

hstack ts stacks tensors horizontally (column-wise).
1-D tensors concatenate directly. Higher-D tensors concatenate along axis 1. For 1-D arrays of different lengths, use vstack to make 2-D first.
# let x1 = create int32 [| 3 |] [| 1l; 2l; 3l |] in
let x2 = create int32 [| 3 |] [| 4l; 5l; 6l |] in
hstack [x1; x2]
- : (int32, int32_elt) t = [1, 2, 3, 4, 5, 6]
# let x1 = create int32 [| 2; 1 |] [| 1l; 2l |] in
let x2 = create int32 [| 2; 1 |] [| 3l; 4l |] in
hstack [x1; x2]
- : (int32, int32_elt) t = [[1, 3],
[2, 4]]
# let x1 = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
let x2 = create int32 [| 2; 1 |] [| 5l; 6l |] in
hstack [x1; x2]
- : (int32, int32_elt) t = [[1, 2, 5],
[3, 4, 6]]

dstack ts stacks tensors depth-wise (along third axis).
Tensors are reshaped to at least 3-D before concatenation:
[n] → [1; n; 1]
[m; n] → [m; n; 1]

# let x1 = create int32 [| 2 |] [| 1l; 2l |] in
let x2 = create int32 [| 2 |] [| 3l; 4l |] in
dstack [x1; x2]
- : (int32, int32_elt) t = [[[1, 3],
[2, 4]]]
# let x1 = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
let x2 = create int32 [| 2; 2 |] [| 5l; 6l; 7l; 8l |] in
dstack [x1; x2]
- : (int32, int32_elt) t = [[[1, 5],
[2, 6]],
[[3, 7],
[4, 8]]]

broadcast_arrays ts broadcasts all tensors to common shape.
Finds the common broadcast shape and returns list of views with that shape. Broadcasting rules: dimensions align right, each must be 1 or equal.
# let x1 = ones float32 [| 3; 1 |] in
let x2 = ones float32 [| 1; 5 |] in
broadcast_arrays [x1; x2] |> List.map shape
- : int array list = [[|3; 5|]; [|3; 5|]]
# let x1 = scalar float32 5. in
let x2 = ones float32 [| 2; 3; 4 |] in
broadcast_arrays [x1; x2] |> List.map shape
- : int array list = [[|3; 5|]; [|3; 5|]]; [[|2; 3; 4|]; [|2; 3; 4|]]

val array_split :
  axis:int ->
  [< `Count of int | `Indices of int list ] ->
  ('a, 'b) t ->
  ('a, 'b) t list

array_split ~axis sections t splits tensor into multiple parts.
`Count n divides into n parts as evenly as possible. Extra elements go to first parts. `Indices [i1;i2;...] splits at indices creating start:i1, i1:i2, i2:end.
# let x = create int32 [| 5 |] [| 1l; 2l; 3l; 4l; 5l |] in
array_split ~axis:0 (`Count 3) x
- : (int32, int32_elt) t list = [[1, 2]; [3, 4]; [5]]
# let x = create int32 [| 6 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
array_split ~axis:0 (`Indices [ 2; 4 ]) x
- : (int32, int32_elt) t list = [[1, 2]; [3, 4]; [5, 6]]

split ~axis sections t splits into equal parts.
# let x = create int32 [| 4; 2 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l |] in
split ~axis:0 2 x
- : (int32, int32_elt) t list = [[[1, 2],
[3, 4]]; [[5, 6],
[7, 8]]]

Functions to convert between types and create copies.
cast dtype t converts elements to new dtype.
Returns copy with same values in new type.
# let x = create float32 [| 3 |] [| 1.5; 2.7; 3.1 |] in
cast int32 x
- : (int32, int32_elt) t = [1, 2, 3]

contiguous t returns C-contiguous tensor.
Returns t unchanged if already contiguous (O(1)), otherwise creates contiguous copy (O(n)). Use before operations requiring direct memory access.
# let t = transpose (ones float32 [| 3; 4 |]) in
is_c_contiguous (contiguous t)
- : bool = true

copy t returns deep copy.
Always allocates new memory and copies data. Result is contiguous.
# let x = create float32 [| 3 |] [| 1.; 2.; 3. |] in
let y = copy x in
set_item [ 0 ] 999. y;
x, y
- : (float, float32_elt) t * (float, float32_elt) t =
([1, 2, 3], [999, 2, 3])

blit src dst copies src into dst.
Shapes must match exactly. Handles broadcasting internally. Modifies dst in-place.
let dst = zeros float32 [| 3; 3 |] in
blit (ones float32 [| 3; 3 |]) dst
(* dst now contains all 1s *)

ifill value t sets all elements of t to value in-place.
fill value t returns a copy of t filled with value, leaving t unchanged. Useful when you want a filled tensor without mutating the source.
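fill leaves its input untouched while ifill mutates it (illustrative example):

# let x = zeros float32 [| 3 |] in
  let y = fill 5. x in
  x, y
- : (float, float32_elt) t * (float, float32_elt) t = ([0, 0, 0], [5, 5, 5])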
Functions to access and modify array elements.
get indices t returns subtensor at indices.
Indexes from outermost dimension. Returns scalar tensor if all dimensions indexed, otherwise returns view of remaining dimensions.
# let x = create int32 [| 2; 2; 2 |] [| 0l; 1l; 2l; 3l; 4l; 5l; 6l; 7l |] in
get [ 1; 1; 1 ] x
- : (int32, int32_elt) t = 7
# let x = create int32 [| 2; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l |] in
get [ 1 ] x
- : (int32, int32_elt) t = [4, 5, 6]

set indices t value assigns value at indices.
slice specs t extracts subtensor using advanced indexing.
Each element in specs corresponds to an axis from left to right:
I i: single index (reduces dimension; negative counts from end)
L [i1; i2; ...]: fancy indexing with list of indices
R (start, stop): range [start, stop) with step 1
Rs (start, stop, step): range with step
A: full axis (default for missing specs)
M mask: boolean mask (shape must match axis)
N: insert new dimension of size 1

Missing specs default to A. Returns view when possible.
# let x = create int32 [| 2; 4 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l |] in
slice [ I 1 ] x
- : (int32, int32_elt) t = [5, 6, 7, 8]
# let x = create int32 [| 5 |] [| 0l; 1l; 2l; 3l; 4l |] in
slice [ R (1, 3) ] x
- : (int32, int32_elt) t = [1, 2]
# let x = create int32 [| 3; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l |] in
slice [ R (0, 2); L [0; 2] ] x
- : (int32, int32_elt) t = [[1, 3],
[4, 6]]
# let x = create int32 [| 3; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l |] in
slice [ A; N ] x (* Add new axis at position 1 *)
- : (int32, int32_elt) t = [[[1, 2, 3]],
[[4, 5, 6]],
[[7, 8, 9]]]

set_slice specs t value assigns value to indexed region.
Like slice but modifies t in-place. Value is broadcast if needed.
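Writing one row through a slice spec (illustrative example):

# let x = zeros int32 [| 2; 3 |] in
  set_slice [ I 0 ] x (ones int32 [| 3 |]);
  x
- : (int32, int32_elt) t = [[1, 1, 1],
[0, 0, 0]]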
item indices t returns scalar value at indices.
Must provide indices for all dimensions.
set_item indices value t sets scalar value at indices.
Must provide indices for all dimensions. Modifies tensor in-place.
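For example (illustrative):

# let x = zeros float32 [| 2; 2 |] in
  set_item [ 0; 1 ] 5. x;
  x
- : (float, float32_elt) t = [[0, 5],
[0, 0]]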
val take :
?axis:int ->
?mode:[ `raise | `wrap | `clip ] ->
(int32, int32_elt) t ->
('a, 'b) t ->
  ('a, 'b) t

take ?axis ?mode indices t takes elements from t using indices.
Equivalent to NumPy's t[indices] along the specified axis. If axis is None, flattens t first. indices is an integer tensor of indices to take. mode handles out-of-bounds indices: `raise (default), `wrap (modulo), `clip (clamp to bounds).
Returns a new tensor with shape based on indices and t's shape.
# let x = create int32 [| 5 |] [| 0l; 1l; 2l; 3l; 4l |] in
take (create int32 [| 3 |] [| 1l; 3l; 0l |]) x
- : (int32, int32_elt) t = [1, 3, 0]
# let x = create int32 [| 3; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l |] in
take ~axis:1 (create int32 [| 3 |] [| 0l; 2l; 1l |]) x
- : (int32, int32_elt) t = [[1, 3, 2],
[4, 6, 5],
[7, 9, 8]]
# let x = create int32 [| 5 |] [| 0l; 1l; 2l; 3l; 4l |] in
take ~mode:`clip (create int32 [| 2 |] [| -1l; 5l |]) x (* Clamps to [0,4] *)
- : (int32, int32_elt) t = [0, 4]

take_along_axis ~axis indices t takes values along the specified axis using indices.
Equivalent to NumPy's take_along_axis. indices must have the same shape as t except along the specified axis, where it matches the output size. Useful for gathering from argmax/argmin results.
# let x = create float32 [| 2; 3 |] [| 4.; 1.; 2.; 3.; 5.; 6. |] in
let indices = create int32 [| 2; 1 |] [| 1l; 0l |] in (* Per row indices *)
take_along_axis ~axis:1 indices x
- : (float, float32_elt) t = [[1],
[3]]
# let x = create float32 [| 3; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6.; 7.; 8.; 9. |] in
let indices = expand_dims [ 0 ] (argmax ~axis:0 x) in (* Shape [1, 3] *)
take_along_axis ~axis:0 indices x (* Max per column *)
- : (float, float32_elt) t = [[7, 8, 9]]

val put :
  ?axis:int ->
  indices:(int32, int32_elt) t ->
  values:('a, 'b) t ->
  ?mode:[ `raise | `wrap | `clip ] ->
  ('a, 'b) t ->
  unit

put ?axis ~indices ~values ?mode t sets elements in t at positions specified by indices to values.
Equivalent to NumPy's put (in-place version of take for setting). If axis is None, flattens t first. indices is an integer tensor of positions to set. values must match the number of indices (broadcasted if needed). mode handles out-of-bounds indices: `raise (default), `wrap (modulo), `clip (clamp).
Modifies t in-place.
# let x = zeros int32 [| 5 |] in
put ~indices:(create int32 [| 3 |] [| 1l; 3l; 0l |])
~values:(create int32 [| 3 |] [| 10l; 20l; 30l |]) x;
x
- : (int32, int32_elt) t = [30, 10, 0, 20, 0]
# let x = create int32 [| 3; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l |] in
put ~axis:1 ~indices:(create int32 [| 3; 1 |] [| 0l; 2l; 1l |])
~values:(create int32 [| 3; 1 |] [| 10l; 20l; 30l |]) x;
x
- : (int32, int32_elt) t = [[10, 2, 3],
[4, 5, 20],
[7, 30, 9]]
# let y = zeros int32 [| 5 |] in
put ~mode:`clip ~indices:(create int32 [| 2 |] [| -1l; 5l |])
~values:(create int32 [| 2 |] [| 99l; 99l |]) y;
y (* Clamps to [0,4] *)
- : (int32, int32_elt) t = [99, 0, 0, 0, 99]

val index_put :
  indices:(int32, int32_elt) t array ->
  values:('a, 'b) t ->
  ?mode:[ `raise | `wrap | `clip ] ->
  ('a, 'b) t ->
  unit

index_put ~indices ~values ?mode t writes values into t at the coordinates specified by indices.
indices is an array that contains one tensor per axis of t. Each tensor provides integer coordinates for its axis; they are broadcast to a common shape that also determines how many updates are performed. values is broadcast to the same shape. Updates follow element-wise order and leave the shape of t unchanged. Duplicate coordinates overwrite previous updates, matching put.
mode controls how out-of-bounds indices are handled per axis: `raise (default) checks bounds, `wrap performs modular indexing, and `clip clamps to the valid range.
# let t = zeros float32 [| 3; 3 |] in
let rows = create int32 [| 4 |] [| 0l; 2l; 1l; 2l |] in
let cols = create int32 [| 4 |] [| 1l; 0l; 2l; 2l |] in
index_put ~indices:[| rows; cols |]
~values:(arange_f float32 0. 4. 1.) t;
t
- : (float, float32_elt) t = [[0, 0, 0],
[0, 0, 2],
[1, 0, 3]]

val put_along_axis :
  axis:int ->
  indices:(int32, int32_elt) t ->
  values:('a, 'b) t ->
  ('a, 'b) t ->
  unit

put_along_axis ~axis ~indices ~values t sets values along the specified axis using indices.
Equivalent to NumPy's put_along_axis. indices must have the same shape as t except along the axis (where it matches values' size along that axis). values is broadcasted to match the selection shape. Useful for scattering to argmax/argmin positions.
Modifies t in-place.
# let x = zeros float32 [| 2; 3 |] in
let indices = create int32 [| 2; 1 |] [| 1l; 0l |] in (* Per row positions *)
put_along_axis ~axis:1 ~indices ~values:(create float32 [| 2; 1 |] [| 10.; 20. |]) x;
x
- : (float, float32_elt) t = [[0, 10, 0],
[20, 0, 0]]
# let x = create float32 [| 3; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6.; 7.; 8.; 9. |] in
let indices = expand_dims [ 0 ] (argmax ~axis:0 x) in (* Shape [1, 3] *)
put_along_axis ~axis:0 ~indices ~values:(ones float32 [| 1; 3 |]) x;
x (* Set max per column to 1 *)
- : (float, float32_elt) t = [[1, 2, 3],
[4, 5, 6],
[1, 1, 1]]

compress ?axis ~condition t selects elements where condition is true.
Equivalent to NumPy's compress. condition is a 1D boolean array. If axis is None, flattens t first. Otherwise, compresses along the specified axis (condition length must match t's dim along axis).
Returns a new tensor with reduced size along the axis/flattened.
# let x = create int32 [| 5 |] [| 1l; 2l; 3l; 4l; 5l |] in
compress ~condition:(create bool [| 5 |] [| true; false; true; false; true |]) x
- : (int32, int32_elt) t = [1, 3, 5]
# let x = create int32 [| 3; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l |] in
compress ~axis:0 ~condition:(create bool [| 3 |] [| false; true; true |]) x
- : (int32, int32_elt) t = [[4, 5, 6],
[7, 8, 9]]

extract ~condition t flattens and selects elements where condition is true.
Equivalent to NumPy's extract (1D compress after flatten). condition must have the same shape and size as t (element-wise).
Returns a 1D tensor with selected elements.
# let x = create int32 [| 3; 3 |] [| 1l; 2l; 3l; 4l; 5l; 6l; 7l; 8l; 9l |] in
extract ~condition:(greater_s x 5l) x
- : (int32, int32_elt) t = [6, 7, 8, 9]

nonzero t returns indices of non-zero elements.
Equivalent to NumPy's nonzero. Treats non-zero as true for bool tensors. Returns an array of 1D tensors, one per dimension, with coordinates of non-zeros.
For example, for a 2D tensor, returns [| rows; cols |] where (rows[i], cols[i]) is the position of the i-th non-zero.
# let x = create int32 [| 3; 3 |] [| 0l; 1l; 0l; 2l; 0l; 3l; 0l; 0l; 4l |] in
let indices = nonzero x in
indices.(0), indices.(1)
- : (int32, int32_elt) t * (int32, int32_elt) t =
([0, 1, 1, 2], [1, 0, 2, 2])

argwhere t returns indices of non-zero elements as a 2D tensor.
Equivalent to NumPy's argwhere. Each row is a coordinate [dim0; dim1; ...] of a non-zero element. Shape is [num_nonzeros; ndim].
# let x = create int32 [| 3; 3 |] [| 0l; 1l; 0l; 2l; 0l; 3l; 0l; 0l; 4l |] in
argwhere x
- : (int32, int32_elt) t = [[0, 1],
[1, 0],
[1, 2],
[2, 2]]

Element-wise arithmetic operations and their variants.
add t1 t2 computes element-wise sum with broadcasting.
iadd target value adds value to target in-place.
Returns modified target.
sub t1 t2 computes element-wise difference with broadcasting.
isub target value subtracts value from target in-place.
mul t1 t2 computes element-wise product with broadcasting.
imul target value multiplies target by value in-place.
div t1 t2 computes element-wise division.
True division for floats (result is float). Integer division for integers (truncates toward zero). Complex division follows standard rules.
# let x = create float32 [| 3 |] [| 7.; 8.; 9. |] in
let y = create float32 [| 3 |] [| 2.; 2.; 2. |] in
div x y
- : (float, float32_elt) t = [3.5, 4, 4.5]
# let x = create int32 [| 3 |] [| 7l; 8l; 9l |] in
let y = create int32 [| 3 |] [| 2l; 2l; 2l |] in
div x y
- : (int32, int32_elt) t = [3, 4, 4]
# let x = create int32 [| 2 |] [| -7l; 8l |] in
let y = create int32 [| 2 |] [| 2l; 2l |] in
div x y
- : (int32, int32_elt) t = [-3, 4]

idiv target value divides target by value in-place.
pow base exponent computes element-wise power.
ipow target exponent raises target to exponent in-place.
mod_s t scalar computes modulo scalar for each element.
imod target divisor computes modulo in-place.
conjugate x computes the complex conjugate.
For complex tensors, negates the imaginary part of each element. For real tensors, returns the input unchanged.
# let x = create complex32 [| 2 |]
[|Complex.{re=1.; im=2.}; Complex.{re=3.; im=4.}|] in
conjugate x |> to_array
- : Complex.t array =
[|{Complex.re = 1.; im = -2.}; {Complex.re = 3.; im = -4.}|]

Unary mathematical operations and special functions.
sign t returns -1, 0, or 1 based on sign.
For unsigned types, returns 1 for all non-zero values, 0 for zero.
# let x = create float32 [| 3 |] [| -2.; 0.; 3.5 |] in
sign x
- : (float, float32_elt) t = [-1, 0, 1]

atan2 y x computes arctangent of y/x using signs to determine quadrant.

Returns angle in radians in range [-π, π]. Handles x = 0 correctly.
# let y = scalar float32 1. in
let x = scalar float32 1. in
atan2 y x |> item [] |> Float.round
- : float = 1.
# let y = scalar float32 1. in
let x = scalar float32 0. in
atan2 y x |> item [] |> Float.round
- : float = 2.
# let y = scalar float32 0. in
let x = scalar float32 0. in
atan2 y x |> item []
- : float = 0.

hypot x y computes sqrt(x² + y²) avoiding overflow.

Uses numerically stable algorithm: max * sqrt(1 + (min/max)²).
# let x = scalar float32 3. in
let y = scalar float32 4. in
hypot x y |> item []
- : float = 5.
# let x = scalar float64 1e200 in
let y = scalar float64 1e200 in
hypot x y |> item [] < Float.infinity
- : bool = true

trunc t rounds toward zero.
Removes fractional part. Positive values round down, negative round up.
# let x = create float32 [| 3 |] [| 2.7; -2.7; 2.0 |] in
trunc x
- : (float, float32_elt) t = [2, -2, 2]

ceil t rounds up to nearest integer.
Smallest integer not less than input.
# let x = create float32 [| 4 |] [| 2.1; 2.9; -2.1; -2.9 |] in
ceil x
- : (float, float32_elt) t = [3, 3, -2, -2]

floor t rounds down to nearest integer.
Largest integer not greater than input.
# let x = create float32 [| 4 |] [| 2.1; 2.9; -2.1; -2.9 |] in
floor x
- : (float, float32_elt) t = [2, 2, -3, -3]

round t rounds to nearest integer (half away from zero).
Ties round away from zero (not banker's rounding).
# let x = create float32 [| 4 |] [| 2.5; 3.5; -2.5; -3.5 |] in
round x
- : (float, float32_elt) t = [3, 4, -3, -4]

lerp start end_ weight computes linear interpolation.

Returns start + weight * (end_ - start). Weight typically in [0, 1].
# let start = scalar float32 0. in
let end_ = scalar float32 10. in
let weight = scalar float32 0.3 in
lerp start end_ weight |> item []
- : float = 3.
# let start = create float32 [| 2 |] [| 1.; 2. |] in
let end_ = create float32 [| 2 |] [| 5.; 8. |] in
let weight = create float32 [| 2 |] [| 0.25; 0.5 |] in
lerp start end_ weight
- : (float, float32_elt) t = [2, 5]

lerp_scalar_weight start end_ weight interpolates with scalar weight.
Element-wise comparisons and logical operations.
cmplt t1 t2 returns true where t1 < t2, false elsewhere.
less_s t scalar checks if each element is less than scalar and returns booleans.
cmpne t1 t2 returns true where t1 ≠ t2, false elsewhere.
not_equal t1 t2 is synonym for cmpne.
not_equal_s t scalar compares each element with scalar for inequality and returns booleans.
cmpeq t1 t2 returns true where t1 = t2, false elsewhere.
equal_s t scalar compares each element with scalar for equality and returns booleans.
cmpgt t1 t2 returns true where t1 > t2, false elsewhere.
greater t1 t2 is synonym for cmpgt.
greater_s t scalar checks if each element is greater than scalar and returns booleans.
cmple t1 t2 returns true where t1 ≤ t2, false elsewhere.
less_equal t1 t2 is synonym for cmple.
less_equal_s t scalar checks if each element is less than or equal to scalar and returns booleans.
cmpge t1 t2 returns true where t1 ≥ t2, false elsewhere.
greater_equal t1 t2 is synonym for cmpge.
greater_equal_s t scalar checks if each element is greater than or equal to scalar and returns booleans.
array_equal t1 t2 returns a scalar true if all elements are equal, false otherwise.

Broadcasts inputs before comparison. Returns false if shapes are incompatible.
# let x = create int32 [| 3 |] [| 1l; 2l; 3l |] in
let y = create int32 [| 3 |] [| 1l; 2l; 3l |] in
array_equal x y |> item []
- : bool = true
# let x = create int32 [| 2 |] [| 1l; 2l |] in
let y = create int32 [| 2 |] [| 1l; 3l |] in
array_equal x y |> item []
- : bool = false

maximum t1 t2 returns element-wise maximum.
maximum_s t scalar returns maximum of each element and scalar.
imaximum target value computes maximum in-place.
imaximum_s t scalar computes maximum with scalar in-place.
minimum t1 t2 returns element-wise minimum.
minimum_s t scalar returns minimum of each element and scalar.
iminimum target value computes minimum in-place.
iminimum_s t scalar computes minimum with scalar in-place.
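Element-wise selection between two tensors (illustrative example):

# let x = create int32 [| 3 |] [| 1l; 5l; 3l |] in
  let y = create int32 [| 3 |] [| 4l; 2l; 6l |] in
  maximum x y, minimum x y
- : (int32, int32_elt) t * (int32, int32_elt) t = ([4, 5, 6], [1, 2, 3])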
logical_and t1 t2 computes element-wise AND.
Non-zero values are true.
logical_or t1 t2 computes element-wise OR.
logical_xor t1 t2 computes element-wise XOR.
logical_not t computes element-wise NOT.
Computes 1 - x element-wise: zero becomes 1 and 1 becomes 0. For other non-zero integers the result is still 1 - x, as the example below shows.
# let x = create int32 [| 3 |] [| 0l; 1l; 5l |] in
logical_not x
- : (int32, int32_elt) t = [1, 0, -4]

isinf t returns true where infinite, false elsewhere.

Detects both positive and negative infinity. Non-float types return all false.
# let x = create float32 [| 4 |] [| 1.; Float.infinity; Float.neg_infinity; Float.nan |] in
isinf x
- : (bool, bool_elt) t = [false, true, true, false]

isnan t returns true where NaN, false elsewhere.

NaN is the only value that doesn't equal itself. Non-float types return all false.
# let x = create float32 [| 3 |] [| 1.; Float.nan; Float.infinity |] in
isnan x
- : (bool, bool_elt) t = [false, true, false]

isfinite t returns true where finite, false elsewhere.

Finite means not inf, -inf, or NaN. Non-float types return all true.
# let x = create float32 [| 4 |] [| 1.; Float.infinity; Float.nan; -0. |] in
isfinite x
- : (bool, bool_elt) t = [true, false, false, true]

where cond if_true if_false selects elements based on condition.
Returns if_true where cond is true, if_false elsewhere. All three inputs broadcast to common shape.
# let cond = create bool [| 3 |] [| true; false; true |] in
let if_true = create int32 [| 3 |] [| 2l; 3l; 4l |] in
let if_false = create int32 [| 3 |] [| 5l; 6l; 7l |] in
where cond if_true if_false
- : (int32, int32_elt) t = [2, 6, 4]
# let x = create float32 [| 4 |] [| -1.; 2.; -3.; 4. |] in
where (cmpgt x (scalar float32 0.)) x (scalar float32 0.)
- : (float, float32_elt) t = [0, 2, 0, 4]

clamp ?min ?max t limits values to range.
Elements below min become min, above max become max.
clip ?min ?max t is synonym for clamp.
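Clipping values into [0, 1] (illustrative, assuming scalar min/max arguments):

# clamp ~min:0. ~max:1. (create float32 [| 4 |] [| -0.5; 0.3; 0.9; 1.5 |])
- : (float, float32_elt) t = [0, 0.3, 0.9, 1]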
Bitwise operations on integer arrays.
bitwise_xor t1 t2 computes element-wise XOR.
bitwise_or t1 t2 computes element-wise OR.
bitwise_and t1 t2 computes element-wise AND.
invert t is synonym for bitwise_not.
lshift t shift left-shifts elements by shift bits.
Equivalent to multiplication by 2^shift. Overflow wraps around.
# let x = create int32 [| 3 |] [| 1l; 2l; 3l |] in
lshift x 2
- : (int32, int32_elt) t = [4, 8, 12]

rshift t shift right-shifts elements by shift bits.
Equivalent to integer division by 2^shift (rounds toward zero).
# let x = create int32 [| 3 |] [| 8l; 9l; 10l |] in
rshift x 2
- : (int32, int32_elt) t = [2, 2, 2]

Functions that reduce array dimensions.
sum ?axes ?keepdims t sums elements along specified axes.
Default sums all axes (returns scalar). If keepdims is true, retains reduced dimensions with size 1. Negative axes count from end.
# let x = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
sum x |> item []
- : float = 10.
# let x = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
sum ~axes:[ 0 ] x
- : (float, float32_elt) t = [4, 6]
# let x = create float32 [| 1; 2 |] [| 1.; 2. |] in
sum ~axes:[ 1 ] ~keepdims:true x
- : (float, float32_elt) t = [[3]]
# let x = create float32 [| 1; 3 |] [| 1.; 2.; 3. |] in
sum ~axes:[ -1 ] x
- : (float, float32_elt) t = [6]

max ?axes ?keepdims t finds maximum along axes.
Default reduces all axes. NaN propagates (any NaN input gives NaN output).
# let x = create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |] in
max x |> item []
- : float = 6.
# let x = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
max ~axes:[ 0 ] x
- : (float, float32_elt) t = [3, 4]
# let x = create float32 [| 1; 2 |] [| 1.; 2. |] in
max ~axes:[ 1 ] ~keepdims:true x
- : (float, float32_elt) t = [[2]]

min ?axes ?keepdims t finds minimum along axes.
Default reduces all axes. NaN propagates (any NaN input gives NaN output).
# let x = create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |] in
min x |> item []
- : float = 1.
# let x = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
min ~axes:[ 0 ] x
- : (float, float32_elt) t = [1, 2]

prod ?axes ?keepdims t computes product along axes.
Default multiplies all elements. Empty axes give 1.
# let x = create int32 [| 3 |] [| 2l; 3l; 4l |] in
prod x |> item []
- : int32 = 24l
# let x = create int32 [| 2; 2 |] [| 1l; 2l; 3l; 4l |] in
prod ~axes:[ 0 ] x
- : (int32, int32_elt) t = [3, 8]

cumsum ?axis t computes the inclusive cumulative sum. Defaults to flattening the tensor (row-major order) when axis is omitted.
cumprod ?axis t computes the inclusive cumulative product. Defaults to flattening the tensor when axis is omitted.
cummax ?axis t computes the inclusive cumulative maximum. NaNs propagate for floating-point dtypes. Defaults to flattening when axis is omitted.
cummin ?axis t computes the inclusive cumulative minimum. NaNs propagate for floating-point dtypes. Defaults to flattening when axis is omitted.
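Running totals and products over a flat tensor (illustrative example):

# let x = create int32 [| 4 |] [| 1l; 2l; 3l; 4l |] in
  cumsum x, cumprod x
- : (int32, int32_elt) t * (int32, int32_elt) t = ([1, 3, 6, 10], [1, 2, 6, 24])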
mean ?axes ?keepdims t computes arithmetic mean along axes.
Sum of elements divided by count. NaN propagates.
# let x = create float32 [| 4 |] [| 1.; 2.; 3.; 4. |] in
mean x |> item []
- : float = 2.5
# let x = create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |] in
mean ~axes:[ 1 ] x
- : (float, float32_elt) t = [2, 5]

var ?axes ?keepdims ?ddof t computes variance along axes.

ddof is delta degrees of freedom. Default 0 (population variance). Use 1 for sample variance. Variance = Σ(x - mean)² / (N - ddof).
# let x = create float32 [| 5 |] [| 1.; 2.; 3.; 4.; 5. |] in
var x |> item []
- : float = 2.
# let x = create float32 [| 5 |] [| 1.; 2.; 3.; 4.; 5. |] in
var ~ddof:1 x |> item []
- : float = 2.5

std ?axes ?keepdims ?ddof t computes standard deviation.
Square root of variance: sqrt(var(t, ddof)). See var for ddof meaning.
# let x = create float32 [| 5 |] [| 1.; 2.; 3.; 4.; 5. |] in
std x |> item [] |> Float.round
- : float = 1.
# let x = create float32 [| 5 |] [| 1.; 2.; 3.; 4.; 5. |] in
std ~ddof:1 x |> item [] |> Float.round
- : float = 2.

all ?axes ?keepdims t tests if all elements are true (non-zero).
Returns true if all elements along axes are non-zero, false otherwise.
# let x = create int32 [| 3 |] [| 1l; 2l; 3l |] in
all x |> item []
- : bool = true
# let x = create int32 [| 3 |] [| 1l; 0l; 3l |] in
all x |> item []
- : bool = false
# let x = create int32 [| 2; 2 |] [| 1l; 0l; 1l; 1l |] in
all ~axes:[ 1 ] x
- : (bool, bool_elt) t = [false, true]

any ?axes ?keepdims t tests if any element is true (non-zero).
Returns true if any element along axes is non-zero, false if all are zero.
# let x = create int32 [| 3 |] [| 0l; 0l; 1l |] in
any x |> item []
- : bool = true
# let x = create int32 [| 3 |] [| 0l; 0l; 0l |] in
any x |> item []
- : bool = false
# let x = create int32 [| 2; 2 |] [| 0l; 0l; 0l; 1l |] in
any ~axes:[ 1 ] x
- : (bool, bool_elt) t = [false, true]

argmax ?axis ?keepdims t finds indices of maximum values.
Returns index of first occurrence for ties. If axis not specified, operates on flattened tensor and returns scalar.
# let x = create int32 [| 5 |] [| 3l; 1l; 4l; 1l; 5l |] in
argmax x |> item []
- : int32 = 4l
# let x = create int32 [| 2; 3 |] [| 1l; 5l; 3l; 2l; 4l; 6l |] in
argmax ~axis:1 x
- : (int32, int32_elt) t = [1, 2]

argmin ?axis ?keepdims t finds indices of minimum values.
Returns index of first occurrence for ties. If axis not specified, operates on flattened tensor and returns scalar.
# let x = create int32 [| 5 |] [| 3l; 1l; 4l; 1l; 5l |] in
argmin x |> item []
- : int32 = 1l
# let x = create int32 [| 2; 3 |] [| 5l; 2l; 3l; 1l; 4l; 0l |] in
argmin ~axis:1 x
- : (int32, int32_elt) t = [1, 2]

Functions for sorting arrays and finding indices.
sort ?descending ?axis t sorts elements along axis.
Returns (sorted_values, indices) where indices map sorted positions to original positions. Default sorts last axis in ascending order.
Algorithm: Bitonic sort (parallel-friendly, stable)
Special values: NaNs sort to the end, as the example below shows.
# let x = create int32 [| 5 |] [| 3l; 1l; 4l; 1l; 5l |] in
sort x
- : (int32, int32_elt) t * (int32, int32_elt) t =
([1, 1, 3, 4, 5], [1, 3, 0, 2, 4])
# let x = create int32 [| 2; 2 |] [| 3l; 1l; 1l; 4l |] in
sort ~descending:true ~axis:0 x
- : (int32, int32_elt) t * (int32, int32_elt) t =
([[3, 4],
[1, 1]], [[0, 1],
[1, 0]])
# let x = create float32 [| 4 |] [| Float.nan; 1.; 2.; Float.nan |] in
let v, _ = sort x in
v
- : (float, float32_elt) t = [1, 2, nan, nan]

argsort ?descending ?axis t returns indices that would sort tensor.
Equivalent to snd (sort ?descending ?axis t). Returns indices such that taking elements at these indices yields sorted array.
For 1-D: result[i] is the index of the i-th smallest element. For N-D: sorts along specified axis independently.
# let x = create int32 [| 5 |] [| 3l; 1l; 4l; 1l; 5l |] in
argsort x
- : (int32, int32_elt) t = [1, 3, 0, 2, 4]
# let x = create int32 [| 2; 3 |] [| 3l; 1l; 4l; 2l; 5l; 0l |] in
argsort ~axis:1 x
- : (int32, int32_elt) t = [[1, 0, 2],
[2, 0, 1]]

Matrix operations and linear algebra functions.
Most linear algebra functions require floating-point or complex tensors. Functions will raise Invalid_argument if given integer tensors.
dot a b computes generalized dot product.
Important: dot has different broadcasting behavior than matmul:
matmul broadcasts batch dimensions
dot does NOT broadcast; it concatenates non-contracted dimensions

For N-D × M-D arrays: dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m]). This can result in much larger output arrays than matmul.

Contracts last axis of a with:
1-D b: the only axis (axis 0)
N-D b: second-to-last axis (axis -2)

Dimension rules: supports broadcasting on batch dimensions. Result shape is the concatenation of:
the shape of a (except the last axis)
the shape of b (except the contracted axis)

# let a = create float32 [| 2 |] [| 1.; 2. |] in
let b = create float32 [| 2 |] [| 3.; 4. |] in
dot a b |> item []
- : float = 11.
# let a = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
let b = create float32 [| 2; 2 |] [| 5.; 6.; 7.; 8. |] in
dot a b
- : (float, float32_elt) t = [[19, 22],
[43, 50]]
# dot (ones float32 [| 3; 4; 5 |]) (ones float32 [| 5; 6 |]) |> shape
- : int array = [|3; 4; 6|]
# dot (ones float32 [| 2; 3; 4; 5 |]) (ones float32 [| 3; 5; 6 |]) |> shape
- : int array = [|2; 3; 4; 6|]

matmul a b computes matrix multiplication with broadcasting.
Follows NumPy's @ operator semantics:
1-D first argument is promoted: [1 × k] @ [... × k × n] → [... × n]
1-D second argument is promoted: [... × m × k] @ [k × 1] → [... × m]

Broadcasting rules:
batch dimensions (all but the last two) broadcast
requires a.shape[-1] == b.shape[-2]

Result shape: broadcast(a.shape[:-2], b.shape[:-2]) followed by a.shape[-2], b.shape[-1]

# let a = create float32 [| 3 |] [| 1.; 2.; 3. |] in
let b = create float32 [| 3 |] [| 4.; 5.; 6. |] in
matmul a b |> item []
- : float = 32.
# let a = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
let b = create float32 [| 2 |] [| 5.; 6. |] in
matmul a b
- : (float, float32_elt) t = [17, 39]
# let a = create float32 [| 2 |] [| 1.; 2. |] in
let b = create float32 [| 2; 3 |] [| 3.; 4.; 5.; 6.; 7.; 8. |] in
matmul a b
- : (float, float32_elt) t = [15, 18, 21]
# matmul (ones float32 [| 10; 3; 4 |]) (ones float32 [| 10; 4; 5 |]) |> shape
- : int array = [|10; 3; 5|]
# matmul (ones float32 [| 1; 3; 4 |]) (ones float32 [| 5; 4; 2 |]) |> shape
- : int array = [|5; 3; 2|]

diagonal ?offset ?axis1 ?axis2 a extracts diagonal from 2-D planes.

offset: diagonal offset (0 = main, positive = above, negative = below)
axis1, axis2: axes of 2-D planes (default: last two axes)

For 2-D array, returns 1-D array of diagonal elements. For N-D array, returns array with diagonals from each 2-D subarray.
matrix_transpose a transposes matrix dimensions.
Swaps last two axes: ..., M, N -> ..., N, M. For 1-D arrays, returns unchanged.
This is specifically for matrix operations, unlike general transpose which can permute any axes.
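Only the last two axes are exchanged (illustrative example):

# matrix_transpose (zeros float32 [| 2; 3; 4 |]) |> shape
- : int array = [|2; 4; 3|]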
vdot a b returns dot product of two vectors.
For complex vectors, conjugates first vector before multiplication. Always returns scalar tensor regardless of input shapes. Flattens inputs before computation.
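Both inputs are flattened before the contraction (illustrative example):

# let a = create float32 [| 2; 2 |] [| 1.; 2.; 3.; 4. |] in
  let b = create float32 [| 4 |] [| 1.; 1.; 1.; 1. |] in
  vdot a b |> item []
- : float = 10.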
vecdot ?axis x1 x2 computes vector dot product along an axis.
axis: axis along which to compute dot product (default: -1)

Unlike vdot, which always flattens, vecdot computes dot products along the specified axis with broadcasting support.
inner a b computes inner product over last axes.
For 1-D arrays, this is ordinary inner product. For higher dimensions, sums products over last axes of a and b.
outer a b computes outer product of two vectors.
Given vectors ai and bj, produces matrix Mi,j = ai * bj. Input tensors are flattened if not already 1-D.
# outer (create float32 [|2|] [|1.; 2.|]) (create float32 [|3|] [|3.; 4.; 5.|])
- : (float, float32_elt) t = [[3, 4, 5],
[6, 8, 10]]

tensordot ?axes a b computes tensor contraction along specified axes.

axes: pair of axis lists to contract (default: last axis of a, first axis of b)

Generalizes matrix multiplication to arbitrary dimensions.
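Contracting the last axis of a with the first axis of b reproduces matrix multiplication (a sketch assuming the axes pair is given as two lists):

# tensordot ~axes:([ 1 ], [ 0 ]) (ones float32 [| 2; 3 |]) (ones float32 [| 3; 4 |]) |> shape
- : int array = [|2; 4|]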
einsum subscripts operands evaluates Einstein summation convention.
Subscripts string specifies contraction, e.g., "ij,jk->ik" for matmul. Repeated indices are summed, free indices form output dimensions.
# let a = create float32 [| 2; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6. |] in
let b = create float32 [| 3; 2 |] [| 1.; 2.; 3.; 4.; 5.; 6. |] in
shape (einsum "ij,jk->ik" [|a; b|]) (* matrix multiplication *)
- : int array = [|2; 2|]
# let a = create float32 [| 3; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6.; 7.; 8.; 9. |] in
shape (einsum "ii->i" [|a|]) (* diagonal *)
- : int array = [|3|]
# let a = create float32 [| 3; 3 |] [| 1.; 2.; 3.; 4.; 5.; 6.; 7.; 8.; 9. |] in
shape (einsum "ij->ji" [|a|]) (* transpose *)
- : int array = [|3; 3|]

kron a b computes Kronecker product.

Result has shape a.shape[i] * b.shape[i] for each dimension i. Each element a[i,j] is replaced by the block a[i,j] * b.
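Each element of the first argument scales a copy of the second (illustrative example):

# kron (create int32 [| 2 |] [| 1l; 2l |]) (create int32 [| 2 |] [| 10l; 20l |])
- : (int32, int32_elt) t = [10, 20, 20, 40]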
multi_dot arrays computes chained matrix multiplication optimally.
Automatically selects the association order that minimizes computational cost. Much more efficient than repeated matmul for chains of 3+ matrices.
matrix_power a n raises square matrix to integer power.
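Cubing a 2×2 upper-triangular matrix (illustrative example):

# matrix_power (create float32 [| 2; 2 |] [| 1.; 1.; 0.; 1. |]) 3
- : (float, float32_elt) t = [[1, 3],
[0, 1]]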
cross ?axis a b returns cross product of 3-element vectors.
axis: axis containing vectors (default: last axis)

cholesky ?upper a computes Cholesky decomposition.
upper: return upper triangular factor if true (default: false)

Returns L (or U) such that a = L @ L.T (or U.T @ U).
qr ?mode a computes QR decomposition.
mode: `Reduced for economy mode (default), `Complete for full

Returns (Q, R) where a = Q @ R, Q is orthogonal, R is upper triangular.
val svd :
  ?full_matrices:bool ->
  ('a, 'b) t ->
  ('a, 'b) t * (float, float64_elt) t * ('a, 'b) t

svd ?full_matrices a computes singular value decomposition.
full_matrices: compute full U, V matrices (default: false)

Returns (U, S, Vh) where a = U @ diag(S) @ Vh. S is 1-D array of singular values in descending order.
svdvals a returns singular values only.
More efficient than svd when only singular values are needed.
eig a computes eigenvalues and right eigenvectors.
Returns (eigenvalues, eigenvectors) for general square matrix. For real float32/float64 inputs, outputs are complex32/complex64 since real matrices can have complex eigenvalues.
eigh ?uplo a computes eigenvalues for symmetric/Hermitian matrix.
uplo: use upper (`U) or lower (`L) triangle (default: `L)

Returns (eigenvalues, eigenvectors) in ascending order. For real symmetric matrices, eigenvalues are guaranteed real. More efficient than eig for symmetric matrices.
eigvals a computes eigenvalues only.
For real inputs, returns complex tensor since eigenvalues may be complex. More efficient than eig when eigenvectors not needed.
eigvalsh ?uplo a computes eigenvalues for symmetric/Hermitian matrix.
For real symmetric inputs, returns real eigenvalues. More efficient than eigvals for symmetric matrices.
val norm :
  ?ord:
    [ `Fro
    | `Nuc
    | `One
    | `Two
    | `Inf
    | `NegOne
    | `NegTwo
    | `NegInf
    | `P of float ] ->
  ?axes:int list ->
  ?keepdims:bool ->
  ('a, 'b) t ->
  ('a, 'b) t

norm ?ord ?axes ?keepdims x computes matrix or vector norm.

ord: norm type (default: Frobenius for matrices, 2-norm for vectors)
`Fro: Frobenius norm
`Nuc: nuclear norm (sum of singular values)
`One: max column sum (for matrices)
`Two: spectral norm (largest singular value)
`Inf: max row sum (for matrices)
`NegOne: min column sum
`NegTwo: smallest singular value
`NegInf: min row sum
`P p: p-norm for vectors

axes: axes to compute norm over; for matrix norms, must be a 2-element list
keepdims: keep reduced dimensions as size 1

val cond :
?p:[ `One | `Two | `Inf | `NegOne | `NegTwo | `NegInf | `Fro ] ->
('a, 'b) t ->
  ('a, 'b) t

cond ?p x computes condition number.

p: norm to use (default: 2-norm)
`One: 1-norm (max column sum)
`Two: 2-norm (max singular value)
`Inf: infinity norm (max row sum)
`NegOne: -1 norm (min column sum)
`NegTwo: -2 norm (min singular value)
`NegInf: -infinity norm (min row sum)
`Fro: Frobenius norm

Returns ratio of largest to smallest norm.
slogdet a computes sign and log of determinant.
Returns (sign, logdet) where det(a) = sign * exp(logdet). More stable than det for matrices with very small/large determinants.
matrix_rank ?tol ?rtol ?hermitian a returns rank of matrix.
tol: absolute tolerance for small singular values
rtol: relative tolerance (default: max(M,N) * eps * largest_singular_value)
hermitian: if true, use more efficient algorithm for Hermitian matrices

Counts singular values greater than tolerance.
trace ?offset a returns sum along diagonal.
offset: diagonal offset (default: 0, positive for upper diagonals)

solve a b solves linear system a @ x = b for x.
Supports batched operations when a, b have compatible batch dimensions.
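Solving a small 2×2 system (illustrative example):

# let a = create float32 [| 2; 2 |] [| 3.; 1.; 1.; 2. |] in
  let b = create float32 [| 2 |] [| 9.; 8. |] in
  solve a b
- : (float, float32_elt) t = [2, 3]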
val lstsq :
?rcond:float ->
('a, 'b) t ->
('a, 'b) t ->
  ('a, 'b) t * ('a, 'b) t * int * (float, float64_elt) t

lstsq ?rcond a b computes least-squares solution to a @ x = b.

rcond: cutoff for small singular values (default: machine precision)

Returns (solution, residuals, rank, singular_values). Handles over/under-determined systems.
pinv ?rtol ?hermitian a computes Moore-Penrose pseudoinverse.
rtol: relative tolerance for small singular values
hermitian: if true, use more efficient algorithm for Hermitian matrices

Handles non-square and singular matrices.
tensorsolve ?axes a b solves tensor equation a x = b for x.
axes: axes in a to reorder to end (default: product of b.ndim rightmost axes)

Solves for x such that tensordot(a, x, axes) = b.
tensorinv ?ind a computes 'inverse' of N-D array.
ind: number of first indices involved in inverse sum (default: 2)

Result is such that tensordot(a, a_inv, ind) = I.
Fast Fourier Transform (FFT) and related signal processing functions.
FFT normalization mode:
`Backward: normalize by 1/n on inverse transform (default)
`Forward: normalize by 1/n on forward transform
`Ortho: normalize by 1/sqrt(n) on both transforms

fft ?axis ?n ?norm x computes discrete Fourier transform over specified axis.

axis: axis to transform (default: last axis)
n: length of the transformed axis of the output
norm: normalization mode (default: `Backward)

Computing 1D FFT of a signal:
# let x = create complex64 [|4|]
[|Complex.{re=0.; im=0.}; {re=1.; im=0.};
{re=2.; im=0.}; {re=3.; im=0.}|] in
let result = fft ~axis:0 x in
shape result
- : int array = [|4|]

ifft ?axis ?n ?norm x computes inverse discrete Fourier transform.

axis: axis to transform (default: last axis)
n: length of the transformed axis of the output
norm: normalization mode (default: `Backward)

val fft2 :
?axes:int list ->
?s:int list ->
?norm:fft_norm ->
(Complex.t, 'a) t ->
  (Complex.t, 'a) t

fft2 ?axes ?s ?norm x computes 2-dimensional FFT.
Transforms last two axes by default. Truncates or pads to shape s if given.
Computing 2D FFT of a 2x2 matrix:
# let x = create complex64 [|2; 2|]
[|Complex.{re=1.; im=0.}; {re=2.; im=0.};
{re=3.; im=0.}; {re=4.; im=0.}|] in
shape (fft2 x)
- : int array = [|2; 2|]

val ifft2 :
?axes:int list ->
?s:int list ->
?norm:fft_norm ->
(Complex.t, 'a) t ->
(Complex.t, 'a) t
ifft2 ?axes ?s ?norm x computes 2-dimensional inverse FFT.
val fftn :
?axes:int list ->
?s:int list ->
?norm:fft_norm ->
(Complex.t, 'a) t ->
(Complex.t, 'a) t
fftn ?axes ?s ?norm x computes N-dimensional FFT.
Transforms all axes by default.
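Shape is preserved when transforming all axes (a minimal sketch):
# let x = ones complex64 [| 2; 3; 4 |] in
shape (fftn x)
- : int array = [|2; 3; 4|]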
val ifftn :
?axes:int list ->
?s:int list ->
?norm:fft_norm ->
(Complex.t, 'a) t ->
(Complex.t, 'a) t
ifftn ?axes ?s ?norm x computes N-dimensional inverse FFT.
val rfft :
?axis:int ->
?n:int ->
?norm:fft_norm ->
(float, 'a) t ->
(Complex.t, complex64_elt) t
rfft ?axis ?n ?norm x computes FFT of real input.
Returns only non-redundant positive frequencies. Output size along last transformed axis is n/2+1 where n is input size.
axis: axis to transform (default: last axis)
n: length to truncate/pad to before transform
norm: normalization mode (default: `Backward)
Computing real FFT:
# let x = create float64 [|4|] [|0.; 1.; 2.; 3.|] in
let result = rfft ~axis:0 x in
shape result
- : int array = [|3|]
val irfft :
?axis:int ->
?n:int ->
?norm:fft_norm ->
(Complex.t, 'a) t ->
(float, float64_elt) t
irfft ?axis ?n ?norm x computes inverse FFT returning real output.
Assumes Hermitian symmetry.
axis: axis to transform (default: last axis)
n: output size along the transformed axis
norm: normalization mode (default: `Backward)
val rfft2 :
?axes:int list ->
?s:int list ->
?norm:fft_norm ->
(float, 'a) t ->
(Complex.t, complex64_elt) t
rfft2 ?axes ?s ?norm x computes 2D FFT of real input.
val irfft2 :
?axes:int list ->
?s:int list ->
?norm:fft_norm ->
(Complex.t, 'a) t ->
(float, float64_elt) t
irfft2 ?axes ?s ?norm x computes 2D inverse FFT returning real output.
val rfftn :
?axes:int list ->
?s:int list ->
?norm:fft_norm ->
(float, 'a) t ->
(Complex.t, complex64_elt) t
rfftn ?axes ?s ?norm x computes N-dimensional FFT of real input.
val irfftn :
?axes:int list ->
?s:int list ->
?norm:fft_norm ->
(Complex.t, 'a) t ->
(float, float64_elt) t
irfftn ?axes ?s ?norm x computes N-dimensional inverse FFT returning real output.
val hfft :
?axis:int ->
?n:int ->
?norm:fft_norm ->
(Complex.t, 'a) t ->
(float, float64_elt) t
hfft ?axis ?n ?norm x computes FFT of Hermitian signal.
Interprets input as positive frequencies of Hermitian signal.
val ihfft :
?axis:int ->
?n:int ->
?norm:fft_norm ->
(float, 'a) t ->
(Complex.t, complex64_elt) t
ihfft ?axis ?n ?norm x computes inverse FFT for Hermitian output.
fftfreq ?d n returns DFT sample frequencies.
For window length n and sample spacing d, returns frequencies [0, 1, ..., n/2-1, -n/2, ..., -1] / (d*n) when n is even.
Getting frequencies for 4-point FFT:
# Nx.fftfreq 4
- : (float, float64_elt) t = [0, 0.25, -0.5, -0.25]
rfftfreq ?d n returns positive DFT frequencies.
Returns [0, 1, ..., n/2] / (d*n).
fftshift ?axes x shifts zero-frequency component to center.
Shifts all axes by default. For visualization of frequency spectra.
Centering frequency spectrum:
# let freqs = fftfreq 5 in
fftshift freqs
- : (float, float64_elt) t = [-0.4, -0.2, 0, 0.2, 0.4]
Neural network activation functions.
relu t applies Rectified Linear Unit: max(0, x).
# let x = create float32 [| 5 |] [| -2.; -1.; 0.; 1.; 2. |] in
relu x
- : (float, float32_elt) t = [0, 0, 0, 1, 2]
relu6 t applies ReLU6: min(max(0, x), 6).
Bounded ReLU used in mobile networks. Clips to the [0, 6] range.
# let x = create float32 [| 3 |] [| -1.; 3.; 8. |] in
relu6 x
- : (float, float32_elt) t = [0, 3, 6]
sigmoid t applies logistic sigmoid: 1 / (1 + exp(-x)).
Output in range (0, 1). Symmetric around x=0 where sigmoid(0) = 0.5.
# sigmoid (scalar float32 0.) |> item []
- : float = 0.5
# sigmoid (scalar float32 10.) |> item [] |> Float.round
- : float = 1.
# sigmoid (scalar float32 (-10.)) |> item [] |> Float.round
- : float = 0.
hard_sigmoid ?alpha ?beta t applies piecewise linear sigmoid approximation.
Default alpha = 1/6, beta = 0.5.
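With the defaults this computes max(0, min(1, x/6 + 1/2)), saturating at x = ±3 (a sketch, assuming that formula):
# let x = create float32 [| 3 |] [| -3.; 0.; 3. |] in
hard_sigmoid x
- : (float, float32_elt) t = [0, 0.5, 1]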
softplus t applies smooth ReLU: log(1 + exp(x)).
Smooth approximation to ReLU. Always positive, differentiable everywhere.
# softplus (scalar float32 0.) |> item [] |> Float.round
- : float = 1.
# softplus (scalar float32 100.) |> item [] |> Float.round
- : float = infinity
silu t applies Sigmoid Linear Unit: x * sigmoid(x).
Also called Swish. Smooth, non-monotonic activation.
# silu (scalar float32 0.) |> item []
- : float = 0.
# silu (scalar float32 1.) |> item [] |> Float.round
- : float = 1.
# silu (scalar float32 (-1.)) |> item [] |> Float.round
- : float = -0.
hard_silu t applies x * hard_sigmoid(x).
Piecewise linear approximation of SiLU. More efficient than SiLU.
# let x = create float32 [| 3 |] [| -3.; 0.; 3. |] in
hard_silu x
- : (float, float32_elt) t = [-0, 0, 3]
prelu alpha x applies Parametric ReLU: max(0,x) + alpha * min(0,x).
alpha must broadcast across the shape of x.
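For example, with a scalar slope (a minimal sketch):
# let alpha = scalar float32 0.5 in
let x = create float32 [| 3 |] [| -2.; 0.; 2. |] in
prelu alpha x
- : (float, float32_elt) t = [-1, 0, 2]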
log_sigmoid t computes log(sigmoid(x)).
Numerically stable version of log(1/(1+exp(-x))). Always negative.
# log_sigmoid (scalar float32 0.) |> item [] |> Float.round
- : float = -1.
# log_sigmoid (scalar float32 100.) |> item [] |> Float.abs |> (fun x -> x < 0.001)
- : bool = true
leaky_relu ?negative_slope t applies Leaky ReLU.
Default negative_slope = 0.01. Returns x if x > 0, else negative_slope * x.
hard_tanh t clips values to [-1, 1].
Linear in [-1, 1], saturates outside. Cheaper than tanh.
# let x = create float32 [| 5 |] [| -2.; -0.5; 0.; 0.5; 2. |] in
hard_tanh x
- : (float, float32_elt) t = [-1, -0.5, 0, 0.5, 1]
elu ?alpha t applies Exponential Linear Unit.
Default alpha = 1.0. Returns x if x > 0, else alpha * (exp(x) - 1). Smooth for x < 0, helps with vanishing gradients.
# elu (scalar float32 1.) |> item []
- : float = 1.
# elu (scalar float32 0.) |> item []
- : float = 0.
# elu (scalar float32 (-1.)) |> item [] |> Float.round
- : float = -1.
selu t applies Scaled ELU with fixed alpha=1.67326, lambda=1.0507.
Self-normalizing activation. Preserves mean 0 and variance 1 in deep networks under certain conditions.
# selu (scalar float32 0.) |> item []
- : float = 0.
# selu (scalar float32 1.) |> item [] |> Float.round
- : float = 1.
celu ?alpha x applies continuously differentiable ELU.
Returns x if x >= 0, else alpha * (exp(x / alpha) - 1). Default alpha = 1.0.
squareplus ?b x applies smooth ReLU: 0.5 * (x + sqrt(x² + b)).
Default b = 4.
glu ?axis x applies the gated linear unit.
Splits x into two equal parts along axis and returns x1 * sigmoid(x2). Default axis is the last dimension.
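The output is half the size of the input along the gated axis (a minimal sketch):
# let x = ones float32 [| 2; 8 |] in
shape (glu x)
- : int array = [|2; 4|]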
sparse_plus x applies the piecewise function: 0 when x <= -1, (x + 1)² / 4 when -1 < x < 1, and x when x >= 1.
sparse_sigmoid x applies the piecewise sparse sigmoid: 0 when x <= -1, (x + 1) / 2 when -1 < x < 1, and 1 when x >= 1.
softmax ?axes ?scale t applies softmax normalization.
Default axis -1. Computes exp(scale * (x - max)) / sum(exp(scale * (x - max))) for numerical stability. Output sums to 1 along specified axes. scale defaults to 1.
# let x = create float32 [| 3 |] [| 1.; 2.; 3. |] in
softmax x |> to_array |> Array.map Float.round
- : float array = [|0.; 0.; 1.|]
# let x = create float32 [| 3 |] [| 1.; 2.; 3. |] in
sum (softmax x) |> item []
- : float = 1.
logsumexp ?axes ?keepdims t computes log(sum(exp(t))) in a numerically stable manner along axes. Defaults to reducing across all axes.
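Unlike the naive formula, large inputs do not overflow; the result is 1000 + log 2 rather than infinity (a minimal sketch):
# let x = create float32 [| 2 |] [| 1000.; 1000. |] in
logsumexp x |> item [] |> Float.round
- : float = 1001.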
logmeanexp ?axes ?keepdims t computes log(mean(exp(t))) in a numerically stable manner along axes. Equivalent to logsumexp minus log of the number of elements.
val standardize :
?axes:int list ->
?mean:(float, 'a) t ->
?variance:(float, 'a) t ->
?epsilon:float ->
(float, 'a) t ->
(float, 'a) t
standardize ?axes ?mean ?variance ?epsilon x normalizes x to zero mean and unit variance.
If mean or variance are not provided they are computed along axes (default: all axes). The result is (x - mean) / sqrt(variance + epsilon).
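With epsilon set to zero the result has exactly zero mean and unit variance (a sketch; assumes the population variance is used):
# let x = create float32 [| 4 |] [| 0.; 0.; 2.; 2. |] in
standardize ~epsilon:0. x
- : (float, float32_elt) t = [-1, -1, 1, 1]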
val batch_norm :
?axes:int list ->
?epsilon:float ->
scale:(float, 'a) t ->
bias:(float, 'a) t ->
(float, 'a) t ->
(float, 'a) t
batch_norm ?axes ?epsilon ~scale ~bias x applies batch normalization.
Normalizes x along axes (defaults to [0] for 2D inputs and [0; 2; 3] for 4D inputs) and applies learnable scale and bias. scale and bias must broadcast across the normalized axes.
val layer_norm :
?axes:int list ->
?epsilon:float ->
?gamma:(float, 'a) t ->
?beta:(float, 'a) t ->
(float, 'a) t ->
(float, 'a) t
layer_norm ?axes ?epsilon ?gamma ?beta x applies layer normalization.
Normalizes x along axes (default last axis), subtracting mean and dividing by sqrt(variance + epsilon). Optional gamma (scale) and beta (offset) are broadcast across the normalized axes.
val rms_norm :
?axes:int list ->
?epsilon:float ->
?gamma:(float, 'a) t ->
(float, 'a) t ->
(float, 'a) t
rms_norm ?axes ?epsilon ?gamma x applies Root Mean Square normalization.
Normalizes x by the root mean square across axes. Optional gamma scales the normalized output.
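A constant vector maps to ones, since each element is divided by the root mean square (a sketch with epsilon disabled):
# let x = create float32 [| 2 |] [| 2.; 2. |] in
rms_norm ~epsilon:0. x
- : (float, float32_elt) t = [1, 1]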
val embedding :
?scale:bool ->
embedding:(float, 'a) t ->
(int32, int32_elt) t ->
(float, 'a) t
embedding ?scale ~embedding indices performs embedding lookup.
embedding must have shape [vocab_size; embed_dim]
indices is an int32 tensor of positions
Result has shape indices_shape ++ [embed_dim]
scale, when true (default), multiplies the result by √embed_dim
val dot_product_attention :
?attention_mask:(bool, bool_elt) t ->
?scale:float ->
?dropout_rate:float ->
?dropout_seed:int ->
?is_causal:bool ->
(float, 'a) t ->
(float, 'a) t ->
(float, 'a) t ->
(float, 'a) t
dot_product_attention ?attention_mask ?scale ?dropout_rate ?dropout_seed ?is_causal q k v computes scaled dot-product attention.
q, k, v must have matching leading dimensions; the last dimension of q and k must match.
scale defaults to 1/√depth of keys.
is_causal, when true, applies a causal (lower triangular) mask to prevent attending to future positions. Requires seq_len_q == seq_len_k.
attention_mask, when provided, should broadcast to the attention score shape; true values keep scores, false values set them to -∞ before softmax.
dropout_rate, when provided, applies dropout to the attention probabilities before multiplying by v. dropout_seed controls the random mask when supplied.
erf t computes the error function.
The error function erf(x) = (2/√π) ∫₀ˣ e^(−t²) dt. Uses the Abramowitz and Stegun approximation for numerical stability.
# erf (scalar float32 0.) |> item []
- : float = 0.
# let result = erf (scalar float32 1.) |> item [] in
Float.round (result *. 10000.) /. 10000. (* Round to 4 decimals *)
- : float = 0.8427
gelu t applies exact Gaussian Error Linear Unit.
GELU(x) = 0.5 * x * (1 + erf(x / √2))
This is the exact GELU using the error function, more numerically stable than the approximation for gradient computation.
# gelu (scalar float32 0.) |> item []
- : float = 0.
# gelu (scalar float32 1.) |> item [] |> Float.round
- : float = 1.
gelu_approx t applies Gaussian Error Linear Unit approximation.
Smooth activation: x * Φ(x) where Φ is the Gaussian CDF. This uses the tanh approximation for efficiency.
# gelu_approx (scalar float32 0.) |> item []
- : float = 0.
# gelu_approx (scalar float32 1.) |> item [] |> Float.round
- : float = 1.
softsign t computes x / (|x| + 1).
Similar to tanh but computationally cheaper. Range (-1, 1).
# let x = create float32 [| 3 |] [| -10.; 0.; 10. |] in
softsign x
- : (float, float32_elt) t = [-0.909091, 0, 0.909091]
mish t applies Mish activation: x * tanh(softplus(x)).
Self-regularizing non-monotonic activation. Smoother than ReLU.
# mish (scalar float32 0.) |> item [] |> Float.abs |> (fun x -> x < 0.001)
- : bool = true
# mish (scalar float32 (-10.)) |> item [] |> Float.round
- : float = -0.
Neural network convolution and pooling operations.
val im2col :
kernel_size:int array ->
stride:int array ->
dilation:int array ->
padding:(int * int) array ->
('a, 'b) t ->
('a, 'b) t
im2col ~kernel_size ~stride ~dilation ~padding t extracts sliding local blocks from tensor.
Extracts patches of size kernel_size from the input tensor at the specified stride and dilation.
kernel_size: size of sliding blocks to extract
stride: step between consecutive blocks
dilation: spacing between kernel elements
padding: (before, after) padding for each spatial dimension
For a 4D input [batch; channels; height; width], produces output shape [batch; channels * kh * kw; num_patches] where kh, kw are kernel dimensions and num_patches depends on stride and padding.
# let x = arange_f float32 0. 16. 1. |> reshape [| 1; 1; 4; 4 |] in
im2col ~kernel_size:[|2; 2|] ~stride:[|1; 1|]
~dilation:[|1; 1|] ~padding:[|(0, 0); (0, 0)|] x |> shape
- : int array = [|1; 4; 9|]
val col2im :
output_size:int array ->
kernel_size:int array ->
stride:int array ->
dilation:int array ->
padding:(int * int) array ->
('a, 'b) t ->
('a, 'b) t
col2im ~output_size ~kernel_size ~stride ~dilation ~padding t combines sliding local blocks into tensor.
This is the inverse of im2col. Accumulates values from the unfolded representation back into spatial dimensions. Overlapping regions are summed.
output_size: target spatial dimensions [height; width]
kernel_size: size of sliding blocks
stride: step between consecutive blocks
dilation: spacing between kernel elements
padding: (before, after) padding for each spatial dimension
For input shape [batch; channels * kh * kw; num_patches], produces output [batch; channels; height; width].
# let unfolded = create float32 [| 1; 4; 9 |] (Array.init 36 Float.of_int) in
col2im ~output_size:[|4; 4|] ~kernel_size:[|2; 2|]
~stride:[|1; 1|] ~dilation:[|1; 1|]
~padding:[|(0, 0); (0, 0)|] unfolded |> shape
- : int array = [|1; 1; 4; 4|]
val correlate1d :
?groups:int ->
?stride:int ->
?padding_mode:[ `Full | `Same | `Valid ] ->
?dilation:int ->
?fillvalue:float ->
?bias:(float, 'a) t ->
(float, 'a) t ->
(float, 'a) t ->
(float, 'a) t
correlate1d ?groups ?stride ?padding_mode ?dilation ?fillvalue ?bias x w computes 1D cross-correlation (no kernel flip).
x: input [batch_size; channels_in; width]
w: weights [channels_out; channels_in/groups; kernel_width]
bias: optional per-channel bias [channels_out]
groups: split input/output channels into groups (default 1)
stride: step between windows (default 1)
padding_mode: `Valid (no pad), `Same (preserve size), `Full (all overlaps)
dilation: spacing between kernel elements (default 1)
fillvalue: padding value (default 0.0)
Output width depends on the padding mode.
# let x = create float32 [| 1; 1; 5 |] [| 1.; 2.; 3.; 4.; 5. |] in
let w = create float32 [| 1; 1; 3 |] [| 1.; 0.; -1. |] in
correlate1d x w |> shape
- : int array = [|1; 1; 3|]
val correlate2d :
?groups:int ->
?stride:(int * int) ->
?padding_mode:[ `Full | `Same | `Valid ] ->
?dilation:(int * int) ->
?fillvalue:float ->
?bias:(float, 'a) t ->
(float, 'a) t ->
(float, 'a) t ->
(float, 'a) t
correlate2d ?groups ?stride ?padding_mode ?dilation ?fillvalue ?bias x w computes 2D cross-correlation (no kernel flip).
x: input [batch; channels_in; height; width]
w: weights [channels_out; channels_in/groups; kernel_h; kernel_w]
bias: optional per-channel bias [channels_out]
stride: (stride_h, stride_w) step between windows (default (1,1))
dilation: (dilation_h, dilation_w) kernel spacing (default (1,1))
padding_mode: `Valid (no pad), `Same (preserve size), `Full (all overlaps)
Uses Winograd F(4,3) for 3×3 kernels with stride 1 when beneficial. For `Same with even kernels, pads more on bottom/right (SciPy convention).
# let image = ones float32 [| 1; 1; 5; 5 |] in
let sobel_x = create float32 [| 1; 1; 3; 3 |] [| 1.; 0.; -1.; 2.; 0.; -2.; 1.; 0.; -1. |] in
correlate2d image sobel_x |> shape
- : int array = [|1; 1; 3; 3|]
val convolve1d :
?groups:int ->
?stride:int ->
?padding_mode:[< `Full | `Same | `Valid > `Valid ] ->
?dilation:int ->
?fillvalue:'a ->
?bias:('a, 'b) t ->
('a, 'b) t ->
('a, 'b) t ->
('a, 'b) t
convolve1d ?groups ?stride ?padding_mode ?dilation ?fillvalue ?bias x w computes 1D convolution (flips kernel before correlation).
Same parameters as correlate1d but flips kernel. For `Same` with even kernels, pads more on left (NumPy convention).
# let x = create float32 [| 1; 1; 3 |] [| 1.; 2.; 3. |] in
let w = create float32 [| 1; 1; 2 |] [| 4.; 5. |] in
convolve1d x w
- : (float, float32_elt) t = [[[13, 22]]]
val convolve2d :
?groups:int ->
?stride:(int * int) ->
?padding_mode:[< `Full | `Same | `Valid > `Valid ] ->
?dilation:(int * int) ->
?fillvalue:'a ->
?bias:('a, 'b) t ->
('a, 'b) t ->
('a, 'b) t ->
('a, 'b) t
convolve2d ?groups ?stride ?padding_mode ?dilation ?fillvalue ?bias x w computes 2D convolution (flips kernel before correlation).
Same parameters as correlate2d but flips kernel horizontally and vertically. For `Same` with even kernels, pads more on top/left.
# let image = ones float32 [| 1; 1; 5; 5 |] in
let gaussian = create float32 [| 1; 1; 3; 3 |] [| 1.; 2.; 1.; 2.; 4.; 2.; 1.; 2.; 1. |] in
convolve2d image (mul_s gaussian (1. /. 16.)) |> shape
- : int array = [|1; 1; 3; 3|]
val avg_pool1d :
kernel_size:int ->
?stride:int ->
?dilation:int ->
?padding_spec:[< `Full | `Same | `Valid > `Valid ] ->
?ceil_mode:bool ->
?count_include_pad:bool ->
(float, 'a) t ->
(float, 'a) t
avg_pool1d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?count_include_pad x applies 1D average pooling.
kernel_size: pooling window size
stride: step between windows (default: kernel_size)
dilation: spacing between kernel elements (default 1)
padding_spec: same as convolution padding modes
ceil_mode: use ceiling for output size calculation (default false)
count_include_pad: include padding in average (default true)
Input shape: [batch; channels; width]
Output width: (width + 2*pad - dilation*(kernel-1) - 1)/stride + 1
# let x = create float32 [| 1; 1; 4 |] [| 1.; 2.; 3.; 4. |] in
avg_pool1d ~kernel_size:2 x
- : (float, float32_elt) t = [[[1.5, 3.5]]]
val avg_pool2d :
kernel_size:(int * int) ->
?stride:(int * int) ->
?dilation:(int * int) ->
?padding_spec:[< `Full | `Same | `Valid > `Valid ] ->
?ceil_mode:bool ->
?count_include_pad:bool ->
(float, 'a) t ->
(float, 'a) t
avg_pool2d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?count_include_pad x applies 2D average pooling.
kernel_size: (height, width) of pooling window
stride: (stride_h, stride_w) (default: kernel_size)
dilation: (dilation_h, dilation_w) (default (1,1))
count_include_pad: whether padding contributes to denominator
Input shape: [batch; channels; height; width]
# let x = create float32 [| 1; 1; 2; 2 |] [| 1.; 2.; 3.; 4. |] in
avg_pool2d ~kernel_size:(2, 2) x
- : (float, float32_elt) t = [[[[2.5]]]]
val max_pool1d :
kernel_size:int ->
?stride:int ->
?dilation:int ->
?padding_spec:[< `Full | `Same | `Valid > `Valid ] ->
?ceil_mode:bool ->
?return_indices:bool ->
('a, 'b) t ->
('a, 'b) t * (int32, int32_elt) t option
max_pool1d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?return_indices x applies 1D max pooling.
return_indices: if true, also returns indices of max values for unpooling
See also avg_pool1d for the shared pooling parameters.
Returns (pooled_values, Some indices) if return_indices=true, otherwise (pooled_values, None). Indices encode the position within each pooling window.
# let x = create float32 [| 1; 1; 4 |] [| 1.; 3.; 2.; 4. |] in
let vals, idx = max_pool1d ~kernel_size:2 ~return_indices:true x in
vals, idx
- : (float, float32_elt) t * (int32, int32_elt) t option =
([[[3, 4]]], Some [[[1, 1]]])
val max_pool2d :
kernel_size:(int * int) ->
?stride:(int * int) ->
?dilation:(int * int) ->
?padding_spec:[< `Full | `Same | `Valid > `Valid ] ->
?ceil_mode:bool ->
?return_indices:bool ->
('a, 'b) t ->
('a, 'b) t * (int32, int32_elt) t option
max_pool2d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?return_indices x applies 2D max pooling.
Parameters same as max_pool1d but for 2D. Indices encode flattened position within each pooling window.
# let x = create float32 [| 1; 1; 4; 4 |]
[| 1.; 2.; 5.; 6.; 3.; 4.; 7.; 8.; 9.; 10.; 13.; 14.; 11.; 12.; 15.; 16. |] in
let vals, _ = max_pool2d ~kernel_size:(2, 2) ~stride:(2, 2) x in
vals
- : (float, float32_elt) t = [[[[4, 8],
[12, 16]]]]
val min_pool1d :
kernel_size:int ->
?stride:int ->
?dilation:int ->
?padding_spec:[< `Full | `Same | `Valid > `Valid ] ->
?ceil_mode:bool ->
?return_indices:bool ->
('a, 'b) t ->
('a, 'b) t * (int32, int32_elt) t option
min_pool1d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?return_indices x applies 1D min pooling.
return_indices: if true, also returns indices of min values (currently returns None)
See also avg_pool1d for the shared pooling parameters.
Returns (pooled_values, None). Index tracking not yet implemented.
# let x = create float32 [| 1; 1; 4 |] [| 4.; 2.; 3.; 1. |] in
let vals, _ = min_pool1d ~kernel_size:2 x in
vals
- : (float, float32_elt) t = [[[2, 1]]]
val min_pool2d :
kernel_size:(int * int) ->
?stride:(int * int) ->
?dilation:(int * int) ->
?padding_spec:[< `Full | `Same | `Valid > `Valid ] ->
?ceil_mode:bool ->
?return_indices:bool ->
('a, 'b) t ->
('a, 'b) t * (int32, int32_elt) t option
min_pool2d ~kernel_size ?stride ?dilation ?padding_spec ?ceil_mode ?return_indices x applies 2D min pooling.
Parameters same as min_pool1d but for 2D. Commonly used for morphological erosion operations in image processing.
# let x = create float32 [| 1; 1; 4; 4 |]
[| 1.; 2.; 5.; 6.; 3.; 4.; 7.; 8.; 9.; 10.; 13.; 14.; 11.; 12.; 15.; 16. |] in
let vals, _ = min_pool2d ~kernel_size:(2, 2) ~stride:(2, 2) x in
vals
- : (float, float32_elt) t = [[[[1, 5],
[9, 13]]]]
val max_unpool1d :
(int32, int32_elt) t ->
('a, 'b) t ->
kernel_size:int ->
?stride:int ->
?dilation:int ->
?padding_spec:[< `Full | `Same | `Valid > `Valid ] ->
?output_size_opt:int array ->
unit ->
('a, 'b) t
max_unpool1d indices values ~kernel_size ?stride ?dilation ?padding_spec ?output_size_opt () reverses max pooling.
indices: indices from max_pool1d with return_indices=true
values: pooled values to place at indexed positions
kernel_size, stride, dilation, padding_spec: must match original pool
output_size_opt: exact output shape (inferred if not provided)
Places values at positions indicated by indices, fills rest with zeros. Output size computed from input unless explicitly specified.
# let x = create float32 [| 1; 1; 4 |] [| 1.; 3.; 2.; 4. |] in
let pooled, _ = max_pool1d ~kernel_size:2 x in
pooled
- : (float, float32_elt) t = [[[3, 4]]]
val max_unpool2d :
(int32, int32_elt) t ->
('a, 'b) t ->
kernel_size:(int * int) ->
?stride:(int * int) ->
?dilation:(int * int) ->
?padding_spec:[< `Full | `Same | `Valid > `Valid ] ->
?output_size_opt:int array ->
unit ->
('a, 'b) t
max_unpool2d indices values ~kernel_size ?stride ?dilation ?padding_spec ?output_size_opt () reverses 2D max pooling.
Same as max_unpool1d but for 2D. Indices encode position within each pooling window. Useful for architectures like segmentation networks that need to "remember" where maxima came from.
# let x = create float32 [| 1; 1; 4; 4 |]
[| 1.; 2.; 3.; 4.; 5.; 6.; 7.; 8.;
9.; 10.; 11.; 12.; 13.; 14.; 15.; 16. |] in
let pooled, _ = max_pool2d ~kernel_size:(2,2) x in
pooled
- : (float, float32_elt) t = [[[[6, 8],
[14, 16]]]]
one_hot ~num_classes indices creates one-hot encoding.
Adds new last dimension of size num_classes. Values must be in [0, num_classes). Out-of-range indices produce zero vectors.
# let indices = create int32 [| 3 |] [| 0l; 1l; 3l |] in
one_hot ~num_classes:4 indices
- : (int, uint8_elt) t = [[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 0, 1]]
# let indices = create int32 [| 2; 2 |] [| 0l; 2l; 1l; 0l |] in
one_hot ~num_classes:3 indices |> shape
- : int array = [|2; 2; 3|]
Functions to iterate over and transform arrays.
map_item f t applies f to each element.
Operates on contiguous data directly. Type-preserving only.
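Doubling each element (a minimal sketch):
# let t = create float32 [| 3 |] [| 1.; 2.; 3. |] in
map_item (fun x -> x *. 2.) t
- : (float, float32_elt) t = [2, 4, 6]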
iter_item f t applies f to each element for side effects.
fold_item f init t folds f over elements.
map f t applies tensor function f to each element as scalar tensor.
iter f t applies tensor function f to each element.
fold f init t folds tensor function over elements.
Functions to display arrays and convert to strings.
pp_data fmt t pretty-prints tensor data.
format_to_string pp x converts using pretty-printer.
print_with_formatter pp x prints using formatter.
pp_dtype fmt dt pretty-prints dtype.
shape_to_string shape formats shape as "2x3x4".
pp_shape fmt shape pretty-prints shape.
pp fmt t pretty-prints tensor info and data.
Functions for automatic differentiation and gradient computation.
grad f t computes the gradient of f with respect to t.
Returns a tensor of the same shape as t containing the gradient values.
# let x = create float32 [| 2 |] [| 3.; 4. |] in
let f t = sum (mul_s t 2.) in
grad f x
- : (float, float32_elt) t = [2, 2]
grads f ts computes gradients of f with respect to each tensor in ts.
Returns a list of gradients, one for each input tensor.
# let xs = [scalar float32 3.; scalar float32 4.] in
let f ts = add (sum (mul_s (List.hd ts) 2.)) (sum (mul_s (List.nth ts 1) 3.)) in
grads f xs |> List.map (item [])
- : float list = [2.; 3.]
value_and_grad f t computes both the value of f and the gradient with respect to t.
Returns a tuple of the function value and the gradient tensor.
# let x = scalar float32 3. in
let f t = sum (mul_s t 2.) in
value_and_grad f x |> (fun (v, g) -> (item [] v, item [] g))
- : float * float = (6., 2.)
val value_and_grads :
(('a, 'b) t list -> ('c, 'd) t) ->
('a, 'b) t list ->
('c, 'd) t * ('a, 'b) t list
value_and_grads f ts computes both the value of f and the gradients with respect to each tensor in ts.
Returns a tuple of the function value and a list of gradient tensors.
# let xs = [scalar float32 3.; scalar float32 4.] in
let f ts = add (sum (mul_s (List.hd ts) 2.)) (sum (mul_s (List.nth ts 1) 3.)) in
value_and_grads f xs |> (fun (v, gs) -> (item [] v, List.map (item []) gs))
- : float * float list = (18., [2.; 3.])
jvp f primals tangents computes a Jacobian-vector product (forward-mode AD).
Returns a tuple (primal_output, tangent_output), where primal_output = f primals and tangent_output is the directional derivative of f at primals along tangents.
# let x = scalar float32 2. in
let v = scalar float32 1. in
let f x = mul x x in
jvp f x v |> (fun (p, t) -> (item [] p, item [] t))
- : float * float = (4., 4.)
val jvp_aux :
(('a, 'b) t -> ('c, 'd) t * 'e) ->
('a, 'b) t ->
('a, 'b) t ->
('c, 'd) t * ('c, 'd) t * 'e
jvp_aux f primals tangents is like jvp but for functions with auxiliary output.
Returns (primal_output, tangent_output, aux) where aux is the auxiliary data.
# let x = scalar float32 2. in
let v = scalar float32 1. in
let f x = (mul x x, shape x) in
jvp_aux f x v |> (fun (p, t, aux) -> (item [] p, item [] t, aux))
- : float * float * int array = (4., 4., [||])
val jvps :
(('a, 'b) t list -> ('c, 'd) t) ->
('a, 'b) t list ->
('a, 'b) t list ->
('c, 'd) t * ('c, 'd) t
jvps f primals tangents computes JVP for functions with multiple inputs.
Returns (primal_output, tangent_output) for the list of inputs.
# let xs = [scalar float32 3.; scalar float32 2.] in
let vs = [scalar float32 1.; scalar float32 0.5] in
let f inputs = mul (List.hd inputs) (List.nth inputs 1) in
jvps f xs vs |> (fun (p, t) -> (item [] p, item [] t))
- : float * float = (6., 3.5)
no_grad f evaluates f () without recording operations for automatic differentiation. This mirrors JAX's lax.stop_gradient semantics when applied to a computation block: all tensors produced within f are treated as constants for subsequent gradient calculations.
detach t returns a tensor with the same value as t but which is treated as a constant with respect to automatic differentiation. Equivalent to JAX's lax.stop_gradient on a single tensor.
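For example, detaching one factor removes it from the gradient; sum (x * detach x) then differentiates like a linear function with coefficients x (a minimal sketch):
# let x = create float32 [| 2 |] [| 3.; 4. |] in
let f t = sum (mul t (detach t)) in
grad f x
- : (float, float32_elt) t = [3, 4]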
Finite difference method to use:
`Central: (f(x+h) - f(x-h)) / 2h (most accurate)
`Forward: (f(x+h) - f(x)) / h
`Backward: (f(x) - f(x-h)) / h
val finite_diff :
?eps:float ->
?method_:method_ ->
(('a, 'b) t -> ('c, 'd) t) ->
('a, 'b) t ->
('a, 'b) tfinite_diff ?eps ?method_ f x computes the gradient of scalar-valued function f with respect to input x using finite differences. The function f must return a scalar tensor.
val finite_diff_jacobian :
?eps:float ->
?method_:method_ ->
(('a, 'b) t -> ('c, 'd) t) ->
('a, 'b) t ->
('c, 'd) tfinite_diff_jacobian ?eps ?method_ f x computes the Jacobian matrix of function f with respect to input x using finite differences.
type gradient_check_result = {
max_abs_error : float;
Maximum absolute error between autodiff and finite difference gradients
max_rel_error : float;
Maximum relative error between autodiff and finite difference gradients
mean_abs_error : float;
Mean absolute error across all checked elements
mean_rel_error : float;
Mean relative error across all checked elements
failed_indices : (int array * float * float * float) list;
List of (index, autodiff_value, finite_diff_value, absolute_error) for failed elements
passed : bool;
Whether all checked elements passed the tolerance tests
num_checked : int;
Total number of elements checked
num_failed : int;
Number of elements that failed the tolerance tests
}
val check_gradient :
?eps:float ->
?rtol:float ->
?atol:float ->
?verbose:bool ->
?check_indices:int list option ->
?method_:[ `Central | `Forward | `Backward ] ->
((float, 'a) t -> ('b, 'c) t) ->
(float, 'a) t ->
[ `Pass of gradient_check_result | `Fail of gradient_check_result ]
check_gradient ?eps ?rtol ?atol ?verbose ?check_indices ?method_ f x compares the gradient of f at x computed via automatic differentiation against finite differences.
The check passes if, for each element, either the absolute error is within atol or the relative error is within rtol.
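Checking a simple quadratic, whose autodiff gradient should agree with finite differences within the default tolerances (a minimal sketch):
# let f x = sum (mul x x) in
let x = create float32 [| 3 |] [| 1.; 2.; 3. |] in
(match check_gradient f x with
 | `Pass _ -> "pass"
 | `Fail _ -> "fail")
- : string = "pass"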
val check_gradients :
?eps:float ->
?rtol:float ->
?atol:float ->
?verbose:bool ->
?method_:[ `Central | `Forward | `Backward ] ->
((float, 'a) t list -> ('b, 'c) t) ->
(float, 'a) t list ->
[ `Pass of gradient_check_result list | `Fail of gradient_check_result list ]
check_gradients ?eps ?rtol ?atol ?verbose ?method_ f xs compares the gradients of f with respect to each input in xs computed via automatic differentiation against finite differences.
Returns a list of results, one for each input tensor.
Functions for mapping computations over batch dimensions.
val vmap :
?in_axes:'a in_axes_spec ->
?out_axes:'b out_axes_spec ->
?axis_name:string ->
?axis_size:int ->
(('c, 'd) t -> ('e, 'f) t) ->
('c, 'd) t ->
('e, 'f) t
vmap ?in_axes ?out_axes ?axis_name ?axis_size f creates a vectorized version of function f.
# let batch_x = create float32 [| 10; 3; 3 |] (Array.init 90 float_of_int) in
let w = create float32 [| 3; 2 |] [| 1.; 2.; 3.; 4.; 5.; 6. |] in
let batched_matmul = vmap (fun x -> matmul x w) in
batched_matmul batch_x |> shape
- : int array = [|10; 3; 2|]
val vmaps :
?in_axes:Rune__.Vmap.axis_spec list ->
?out_axes:'b Rune__.Vmap.out_axes_spec ->
?axis_name:string ->
?axis_size:int ->
(('c, 'd) t list -> ('e, 'f) t) ->
('c, 'd) t list ->
('e, 'f) t
vmaps ?in_axes ?out_axes ?axis_name ?axis_size f creates a vectorized version of function f that takes multiple tensor arguments.
Similar to vmap but for functions taking multiple arguments.
Examples:
# let x = create float32 [| 3; 2 |] [| 1.; 2.; 3.; 4.; 5.; 6. |] in
let y = create float32 [| 3; 2 |] [| 10.; 20.; 30.; 40.; 50.; 60. |] in
let batched_add = vmaps (function [a; b] -> add a b | _ -> assert false) in
batched_add [x; y] |> to_array
- : float array = [|11.; 22.; 33.; 44.; 55.; 66.|]
JAX-style splittable PRNG for reproducible random number generation.
Functions for debugging, JIT compilation, and gradient computation.
debug f x applies f to x and prints debug information.
Useful for inspecting intermediate values during development.
debug_with_context context f runs f with a debug context.
Prints the context name before executing f. Useful for tracing specific computation paths.
debug_push_context context pushes a new debug context.
Use this to mark the start of a specific computation section. The context will be printed in debug messages.
debug_pop_context () pops the last debug context.
Use this to mark the end of a specific computation section. The context will be removed from the debug stack.
Functions for JIT compilation of tensor operations.
jit_device represents devices supported in JIT compilation.
`llvm: CPU device using LLVM for JIT-compiled operations.
`metal: GPU device using Metal for JIT-compiled operations on Apple devices.
is_jit_device_available dev checks if the specified device is available.
Returns true if the device can be used for tensor operations.
jit f t compiles the function f for efficient execution on t.
Returns a compiled version of f that can be called with tensors of the same shape and type as t. This can significantly speed up repeated calls.
# let x = scalar float32 3. in
let f t = sum (mul_s t 2.) in
let compiled_f = jit f x in
compiled_f x |> item []
- : float = 6.