(* Primitive tensor operations that every Nx backend must implement. *)

(** Backend interface.

    The [op_*] functions mirror tinygrad's UOps. A backend can execute them
    eagerly, raise effects for a JIT, build a computation graph, etc.

    The frontend handles broadcasting and shape validation, so each operation
    simply produces a fresh tensor. *)

module type S = sig
  type ('a, 'b) t
  (** Opaque tensor handle.

      ['a] is the OCaml element type; ['b] tags the dtype. *)

  type context
  (** Backend execution context. Carries any state required by the
      implementation (memory pools, command queues, ...). *)

  (* lenses *)

  val view : ('a, 'b) t -> Lazy_view.t
  (** Return the view tracker for [t]. *)

  val dtype : ('a, 'b) t -> ('a, 'b) Dtype.t
  (** Element type of [t]. *)

  val context : ('a, 'b) t -> context
  (** Execution context of [t]. *)

  val data : ('a, 'b) t -> ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t
  (** Return the raw buffer of [t]. *)

  (* ops: mirrors tinygrad UOps *)

  val op_buffer :
    context -> ('a, 'b) Dtype.t -> int (* size_in_elements *) -> ('a, 'b) t
  (** Allocate a buffer of [size_in_elements] elements of [dtype]. *)

  val op_const_scalar : context -> 'a -> ('a, 'b) Dtype.t -> ('a, 'b) t
  (** Tensor containing a single scalar [value]. *)

  val op_const_array :
    context -> ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t -> ('a, 'b) t
  (** Tensor containing the elements of [array]. The array must be
      contiguous. *)

  (* Element-wise Binary Ops *)

  (* These ops assume inputs have been broadcast to the same shape and cast
     to the same compatible dtype by the frontend. The output tensor will
     also have this common dtype. *)

  val op_add : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Element-wise addition. *)

  val op_mul : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Element-wise multiplication. *)

  val op_idiv : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Integer division, truncating. *)

  val op_fdiv : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Floating-point division. *)

  val op_max : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Element-wise maximum. *)

  val op_mod : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Integer modulus. *)

  val op_pow : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Raise [base] to [exponent]. *)

  val op_cmplt : ('a, 'b) t -> ('a, 'b) t -> (int, Dtype.uint8_elt) t
  (** Compare [<]. Returns 0 or 1 as uint8. *)

  val op_cmpne : ('a, 'b) t -> ('a, 'b) t -> (int, Dtype.uint8_elt) t
  (** Compare [<>]. Returns 0 or 1 as uint8. *)

  val op_xor : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Bitwise XOR. *)

  val op_or : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Bitwise OR. *)

  val op_and : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Bitwise AND. *)

  (* Element-wise Unary Ops *)

  val op_neg : ('a, 'b) t -> ('a, 'b) t
  (** Negation (logical not for bools). *)

  val op_log2 : ('a, 'b) t -> ('a, 'b) t
  (** Base-2 logarithm. *)

  val op_exp2 : ('a, 'b) t -> ('a, 'b) t
  (** Exponential base 2. *)

  val op_sin : ('a, 'b) t -> ('a, 'b) t
  (** Sine. *)

  val op_sqrt : ('a, 'b) t -> ('a, 'b) t
  (** Square root. *)

  val op_recip : ('a, 'b) t -> ('a, 'b) t
  (** Reciprocal. *)

  (* Ternary Op *)

  val op_where :
    (int, Dtype.uint8_elt) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Select from [if_true] or [if_false] based on a boolean tensor. *)

  (* Reduction Ops *)

  val op_reduce_sum :
    axes:int array -> keepdims:bool -> ('a, 'b) t -> ('a, 'b) t
  (** Sum over [axes]. Keeps reduced dimensions if [keepdims] is true. *)

  val op_reduce_max :
    axes:int array -> keepdims:bool -> ('a, 'b) t -> ('a, 'b) t
  (** Maximum over [axes]. Keeps reduced dimensions if [keepdims] is true. *)

  val op_reduce_prod :
    axes:int array -> keepdims:bool -> ('a, 'b) t -> ('a, 'b) t
  (** Product over [axes]. Keeps reduced dimensions if [keepdims] is true. *)

  val op_associative_scan :
    axis:int -> op:[ `Sum | `Prod | `Max | `Min ] -> ('a, 'b) t -> ('a, 'b) t
  (** Inclusive scan along [axis] using the associative operation [op]. *)

  (* Movement Ops - manipulate view metadata *)

  val op_expand : ('a, 'b) t -> Symbolic_shape.t -> ('a, 'b) t
  (** Broadcast dimensions of size 1 to a new shape. *)

  val op_reshape : ('a, 'b) t -> Symbolic_shape.t -> ('a, 'b) t
  (** Change the logical shape without moving data. *)

  val op_permute : ('a, 'b) t -> int array -> ('a, 'b) t
  (** Reorder dimensions according to [axes]. *)

  val op_shrink : ('a, 'b) t -> (int * int) array -> ('a, 'b) t
  (** Slice according to the given start/stop pairs. *)

  val op_flip : ('a, 'b) t -> bool array -> ('a, 'b) t
  (** Flip dimensions where the boolean array is [true]. *)

  val op_pad : ('a, 'b) t -> (int * int) array -> 'a -> ('a, 'b) t
  (** Pad with [fill_value] using the given configuration. *)

  val op_cat : ('a, 'b) t list -> int -> ('a, 'b) t
  (** Concatenate tensors along [axis]. *)

  (* Other Ops *)

  val op_cast : ('a, 'b) t -> ('c, 'd) Dtype.t -> ('c, 'd) t
  (** Cast elements to [target_dtype]. *)

  val op_contiguous : ('a, 'b) t -> ('a, 'b) t
  (** Return a C-contiguous tensor. May copy. *)

  val op_copy : ('a, 'b) t -> ('a, 'b) t
  (** Duplicate [t]. Result has its own buffer. *)

  val op_assign : ('a, 'b) t -> ('a, 'b) t -> unit
  (** Store [src] into [dst] at the corresponding logical indices. *)

  val op_threefry :
    (int32, Dtype.int32_elt) t ->
    (int32, Dtype.int32_elt) t ->
    (int32, Dtype.int32_elt) t
  (** Threefry random number generator. *)

  (* Element Access Ops *)

  (* These operations enable lazy, graph-based element access (get/set). This
     differs from tinygrad's eager realization for __getitem__/__setitem__.
     We opt for lazy ops to avoid premature realization and CPU<->device
     transfers. These are primarily for internal Nx slice/put_slice
     implementations, and their direct backend exposure might be refined
     later. *)

  val op_gather :
    ('a, 'b) t (* data *) ->
    (int32, Dtype.int32_elt) t (* indices *) ->
    int (* axis *) ->
    ('a, 'b) t
  (** Gather elements from [data] along [axis] using [indices]. Output shape
      matches [indices]. Ranks of [data] and [indices] must match. Sizes of
      [indices] dims other than [axis] must not exceed the corresponding
      [data] dims. *)

  val op_scatter :
    ?mode:[ `Set | `Add ] ->
    ?unique_indices:bool ->
    ('a, 'b) t (* data_template *) ->
    (int32, Dtype.int32_elt) t (* indices *) ->
    ('a, 'b) t (* updates *) ->
    int (* axis *) ->
    ('a, 'b) t
  (** Scatter [updates] into a new tensor shaped like [data_template] along
      [axis] using [indices]. Returns a new tensor.

      - [mode] specifies how to handle duplicate indices:
        - [`Set] (default): last update wins
        - [`Add]: accumulate updates at duplicate indices
      - [unique_indices]: hint that indices are unique (optimization) *)

  val op_unfold :
    ('a, 'b) t ->
    kernel_size:int array ->
    stride:int array ->
    dilation:int array ->
    padding:(int * int) array ->
    ('a, 'b) t
  (** Unfold (im2col) operation. Extracts sliding local blocks from a batched
      input tensor. For an input of shape [(N, C, *spatial_dims)], produces
      output of shape [(N, C * prod(kernel_size), L)] where [L] is the number
      of blocks. Works for any number of spatial dimensions. *)

  val op_fold :
    ('a, 'b) t ->
    output_size:int array ->
    kernel_size:int array ->
    stride:int array ->
    dilation:int array ->
    padding:(int * int) array ->
    ('a, 'b) t
  (** Fold (col2im) operation. Combines an array of sliding local blocks into
      a tensor. For an input of shape [(N, C * prod(kernel_size), L)],
      produces output of shape [(N, C, *output_size)]. Inverse of unfold.
      Overlapping values are summed. Works for any number of spatial
      dimensions. *)

  val op_matmul : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Matrix multiplication. For 2D tensors, computes standard matrix
      multiplication. For higher dimensions, performs batched matrix
      multiplication on the last two dimensions, broadcasting batch
      dimensions as needed. The last dimension of the first tensor must match
      the second-to-last dimension of the second tensor. *)

  (* Fourier transforms *)

  val op_fft : (Complex.t, 'b) t -> axes:int array -> (Complex.t, 'b) t
  (** Compute the discrete Fourier transform (DFT) of the input tensor. *)

  val op_ifft : (Complex.t, 'b) t -> axes:int array -> (Complex.t, 'b) t
  (** Compute the inverse discrete Fourier transform (IDFT) of the input
      tensor. *)

  val op_rfft :
    (float, 'a) t ->
    dtype:(Complex.t, 'b) Dtype.t ->
    axes:int array ->
    (Complex.t, 'b) t
  (** Compute the real-valued discrete Fourier transform (RDFT) of the input
      tensor. *)

  val op_irfft :
    (Complex.t, 'a) t ->
    dtype:(float, 'b) Dtype.t ->
    axes:int array ->
    s:int array option ->
    (float, 'b) t
  (** Compute the inverse real-valued discrete Fourier transform (IRDFT) of
      the input tensor. *)

  (* Linear algebra operations *)

  val op_cholesky : upper:bool -> ('a, 'b) t -> ('a, 'b) t
  (** Cholesky decomposition of a positive-definite matrix.

      - [upper]: if true, returns the upper triangular factor; else lower
        (default).
      - Input: square matrix [A] (batched).
      - Output: triangular factor [L] or [U] such that [A = L*L^T] or
        [A = U^T*U].
      - Raises if the input is not positive-definite. *)

  val op_qr : reduced:bool -> ('a, 'b) t -> ('a, 'b) t * ('a, 'b) t
  (** QR decomposition.

      - [reduced]: if true (default), returns the economy/reduced QR; else
        the full QR.
      - Input: [m x n] matrix [A] (batched).
      - Output: [(Q, R)] where [A = Q*R], [Q] orthogonal, [R] upper
        triangular. *)

  val op_svd :
    full_matrices:bool ->
    ('a, 'b) t ->
    ('a, 'b) t * (float, Dtype.float64_elt) t * ('a, 'b) t
  (** Singular value decomposition.

      - [full_matrices]: if false (default), returns the thin SVD; else the
        full SVD.
      - Input: [m x n] matrix [A] (batched).
      - Output: [(U, S, V^H)] where [A = U*S*V^H].
      - [S] is a 1D vector of singular values in descending order, always
        float64. *)

  val op_eig :
    vectors:bool ->
    ('a, 'b) t ->
    (Complex.t, Dtype.complex64_elt) t
    * (Complex.t, Dtype.complex64_elt) t option
  (** General eigenvalue decomposition.

      - [vectors]: if true (default), computes eigenvectors.
      - Input: square matrix [A] (batched).
      - Output: (eigenvalues, optional eigenvectors), always as complex64. *)

  val op_eigh :
    vectors:bool ->
    ('a, 'b) t ->
    (float, Dtype.float64_elt) t * ('a, 'b) t option
  (** Symmetric/Hermitian eigenvalue decomposition.

      - [vectors]: if true (default), computes eigenvectors.
      - Input: symmetric (real) or Hermitian (complex) matrix [A] (batched).
      - Output: (eigenvalues as float64, eigenvectors same type as input). *)

  val op_triangular_solve :
    upper:bool ->
    transpose:bool ->
    unit_diag:bool ->
    ('a, 'b) t ->
    ('a, 'b) t ->
    ('a, 'b) t
  (** Solve the triangular system [A*x = b] or [A^T*x = b].

      - [upper]: if true, [A] is upper triangular; else lower.
      - [transpose]: if true, solve [A^T*x = b]; else [A*x = b].
      - [unit_diag]: if true, assume the diagonal of [A] is all 1s.
      - Input: triangular matrix [A], right-hand side [b] (batched).
      - Output: solution [x]. *)

  val op_as_strided :
    ('a, 'b) t -> Symbolic_shape.t -> int array -> int -> ('a, 'b) t
  (** Create a strided view of the input tensor with the given shape, strides
      (in elements), and offset (in elements). Backends that support
      arbitrary strided views (e.g., native with Bigarray) can implement this
      as zero-copy. Other backends may fall back to copying data if
      necessary. Raises if the view would access out-of-bounds memory. *)
end
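(* Example (illustrative; not part of the interface): the result shape of the
   [op_reduce_*] ops can be modelled on plain int arrays. Reduced axes become
   1 when [keepdims] is true and are dropped otherwise. [reduced_shape] is a
   hypothetical helper, not a value of [S].

   {[
     let reduced_shape ~axes ~keepdims shape =
       Array.to_list shape
       |> List.mapi (fun d s ->
              if Array.exists (( = ) d) axes then
                if keepdims then Some 1 else None
              else Some s)
       |> List.filter_map Fun.id
       |> Array.of_list

     (* reduced_shape ~axes:[| 1 |] ~keepdims:false [| 2; 3; 4 |]
        = [| 2; 4 |]
        reduced_shape ~axes:[| 1 |] ~keepdims:true [| 2; 3; 4 |]
        = [| 2; 1; 4 |] *)
   ]}
*)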
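(* Example (illustrative; not part of the interface): a reference model of
   [op_gather] on plain 2-D arrays for [axis = 0]. The rule documented above
   (output shape matches [indices]) corresponds to
   [out.(i).(j) = data.(indices.(i).(j)).(j)]. [gather_axis0] is a
   hypothetical helper; indices are assumed in range.

   {[
     let gather_axis0 data indices =
       Array.map
         (fun row -> Array.mapi (fun j k -> data.(k).(j)) row)
         indices

     (* gather_axis0 [| [| 1; 2 |]; [| 3; 4 |] |] [| [| 1; 0 |] |]
        = [| [| 3; 2 |] |] *)
   ]}
*)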
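(* Example (illustrative; not part of the interface): a 1-D reference model of
   [op_scatter]'s duplicate-index semantics. [scatter_1d] is a hypothetical
   helper; it assumes [`Add] accumulates onto the template's existing values,
   which a backend is free to refine.

   {[
     let scatter_1d ~mode template indices updates =
       let out = Array.copy template in
       Array.iteri
         (fun i k ->
           match mode with
           | `Set -> out.(k) <- updates.(i)
           | `Add -> out.(k) <- out.(k) + updates.(i))
         indices;
       out

     (* With duplicate index 1:
        scatter_1d ~mode:`Set [| 0; 0; 0 |] [| 1; 1 |] [| 5; 7 |]
        = [| 0; 7; 0 |]   (last update wins)
        scatter_1d ~mode:`Add [| 0; 0; 0 |] [| 1; 1 |] [| 5; 7 |]
        = [| 0; 12; 0 |]  (updates accumulate) *)
   ]}
*)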
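(* Example (illustrative; not part of the interface): the block count [L] in
   [op_unfold]'s output shape follows the usual convolution arithmetic,
   multiplied over spatial dimensions. [num_blocks] and [l] are hypothetical
   helpers sketching that arithmetic; a backend may differ in edge cases.

   {[
     let num_blocks ~size ~kernel ~stride ~dilation ~padding =
       let pl, ph = padding in
       ((size + pl + ph - (dilation * (kernel - 1)) - 1) / stride) + 1

     let l ~spatial ~kernel_size ~stride ~dilation ~padding =
       let acc = ref 1 in
       Array.iteri
         (fun d size ->
           acc :=
             !acc
             * num_blocks ~size ~kernel:kernel_size.(d) ~stride:stride.(d)
                 ~dilation:dilation.(d) ~padding:padding.(d))
         spatial;
       !acc

     (* A 2x2 kernel sliding over a 4x4 input with stride 1, dilation 1 and
        no padding yields 3 positions per dimension:
        l ~spatial:[| 4; 4 |] ~kernel_size:[| 2; 2 |] ~stride:[| 1; 1 |]
          ~dilation:[| 1; 1 |] ~padding:[| (0, 0); (0, 0) |]
        = 9 *)
   ]}
*)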
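(* Example (illustrative; not part of the interface): the result shape of
   [op_matmul] for batched inputs, as described above — batch dimensions are
   broadcast pairwise and the matrix dims combine as (m, k) x (k, n) ->
   (m, n). [broadcast_dims] and [matmul_shape] are hypothetical helpers.

   {[
     let broadcast_dims a b =
       let la, lb = (Array.length a, Array.length b) in
       let l = max la lb in
       Array.init l (fun i ->
           let da = if i < l - la then 1 else a.(i - (l - la)) in
           let db = if i < l - lb then 1 else b.(i - (l - lb)) in
           if da = db then da
           else if da = 1 then db
           else if db = 1 then da
           else invalid_arg "incompatible batch dims")

     let matmul_shape a b =
       let ra, rb = (Array.length a, Array.length b) in
       if ra < 2 || rb < 2 then invalid_arg "rank >= 2 expected";
       let m, k = (a.(ra - 2), a.(ra - 1)) in
       let k', n = (b.(rb - 2), b.(rb - 1)) in
       if k <> k' then invalid_arg "inner dims must match";
       let batch =
         broadcast_dims (Array.sub a 0 (ra - 2)) (Array.sub b 0 (rb - 2))
       in
       Array.append batch [| m; n |]

     (* matmul_shape [| 2; 1; 4; 5 |] [| 3; 5; 6 |] = [| 2; 3; 4; 6 |] *)
   ]}
*)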
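(* Example (illustrative; not part of the interface): the addressing rule
   implied by [op_as_strided] — the flat buffer index of a logical position
   is the offset plus the stride-weighted index sum, all in elements.
   [flat_index] is a hypothetical helper.

   {[
     let flat_index ~offset ~strides idx =
       Array.fold_left ( + ) offset
         (Array.mapi (fun d i -> i * strides.(d)) idx)

     (* Overlapping strides create sliding windows: viewing a length-4
        buffer with shape [3; 2], strides [1; 1] and offset 0 makes row i
        the window (i, i + 1), e.g.
        flat_index ~offset:0 ~strides:[| 1; 1 |] [| 2; 1 |] = 3 *)
   ]}
*)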