(* Primitive tensor operations that every Nx backend must implement. *)

(** Backend interface.

    The [op_*] functions mirror tinygrad's UOps. A backend can execute them
    eagerly, raise effects for a JIT, build a computation graph, etc.

    The frontend handles broadcasting and shape validation, so each operation
    simply produces a fresh tensor. *)
module type S = sig
  type ('a, 'b) t
  (** Opaque tensor handle. ['a] is the OCaml element type; ['b] tags the
      dtype. *)

  type context
  (** Backend execution context. Carries any state required by the
      implementation (memory pools, command queues, ...). *)

  (* lenses *)

  val view : ('a, 'b) t -> View.t
  (** Return the logical view metadata of [t]. *)

  val dtype : ('a, 'b) t -> ('a, 'b) Dtype.t
  (** Element type of [t]. *)

  val context : ('a, 'b) t -> context
  (** Execution context of [t]. *)

  val data : ('a, 'b) t -> ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t
  (** Return the raw buffer of [t]. *)

  (* ops: mirrors tinygrad UOps *)

  val op_buffer :
    context -> ('a, 'b) Dtype.t -> int (* size_in_elements *) -> ('a, 'b) t
  (** Allocate a buffer of [size_in_elements] elements of [dtype]. *)

  val op_const_scalar : context -> 'a -> ('a, 'b) Dtype.t -> ('a, 'b) t
  (** Tensor containing the single scalar [value]. *)

  val op_const_array :
    context -> ('a, 'b, Bigarray.c_layout) Bigarray.Array1.t -> ('a, 'b) t
  (** Tensor containing the elements of [array]. The array must be
      contiguous. *)
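
  (* A minimal usage sketch (illustrative, not part of the interface):
     assuming a backend [B : S] and some way to obtain a context, here a
     hypothetical [B.create_context], a frontend builds constants like this:

       let ctx = B.create_context () in
       (* scalar 2.0 tagged as float32 *)
       let two = B.op_const_scalar ctx 2.0 Dtype.float32 in
       (* rank-1 tensor from an existing contiguous bigarray *)
       let ba =
         Bigarray.Array1.of_array Bigarray.Float32 Bigarray.C_layout
           [| 1.; 2.; 3. |]
       in
       let xs = B.op_const_array ctx ba in
       ignore (two, xs)

     [Dtype.float32] is assumed here to name the float32 element type. *)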

  (* Element-wise Binary Ops *)

  (* These ops assume inputs have been broadcast to the same shape and cast
     to the same compatible dtype by the frontend. The output tensor also has
     this common dtype. *)

  val op_add : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Element-wise addition. *)

  val op_mul : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Element-wise multiplication. *)

  val op_idiv : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Integer division, truncating. *)

  val op_fdiv : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Floating-point division. *)

  val op_max : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Element-wise maximum. *)

  val op_mod : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Integer modulus. *)

  val op_pow : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Raise [base] to [exponent], element-wise. *)

  val op_cmplt : ('a, 'b) t -> ('a, 'b) t -> (int, Dtype.uint8_elt) t
  (** Compare [<]. Returns 0 or 1 as uint8. *)

  val op_cmpne : ('a, 'b) t -> ('a, 'b) t -> (int, Dtype.uint8_elt) t
  (** Compare [<>]. Returns 0 or 1 as uint8. *)

  val op_xor : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Bitwise XOR. *)

  val op_or : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Bitwise OR. *)

  val op_and : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Bitwise AND. *)

  (* Element-wise Unary Ops *)

  val op_neg : ('a, 'b) t -> ('a, 'b) t
  (** Negation (logical NOT for booleans). *)

  val op_log2 : ('a, 'b) t -> ('a, 'b) t
  (** Base-2 logarithm. *)

  val op_exp2 : ('a, 'b) t -> ('a, 'b) t
  (** Base-2 exponential. *)

  val op_sin : ('a, 'b) t -> ('a, 'b) t
  (** Sine. *)

  val op_sqrt : ('a, 'b) t -> ('a, 'b) t
  (** Square root. *)

  val op_recip : ('a, 'b) t -> ('a, 'b) t
  (** Reciprocal. *)

  (* Ternary Op *)

  val op_where :
    (int, Dtype.uint8_elt) t -> ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Select from [if_true] or [if_false] based on a boolean condition
      tensor. *)

  (* Reduction Ops *)

  val op_reduce_sum :
    axes:int array -> keepdims:bool -> ('a, 'b) t -> ('a, 'b) t
  (** Sum over [axes]. Keeps reduced dimensions if [keepdims] is true. *)

  val op_reduce_max :
    axes:int array -> keepdims:bool -> ('a, 'b) t -> ('a, 'b) t
  (** Maximum over [axes]. Keeps reduced dimensions if [keepdims] is true. *)

  val op_reduce_prod :
    axes:int array -> keepdims:bool -> ('a, 'b) t -> ('a, 'b) t
  (** Product over [axes]. Keeps reduced dimensions if [keepdims] is true. *)

  (* Movement Ops - manipulate view metadata *)

  val op_expand : ('a, 'b) t -> int array -> ('a, 'b) t
  (** Broadcast dimensions of size 1 to a new shape. *)

  val op_reshape : ('a, 'b) t -> int array -> ('a, 'b) t
  (** Change the logical shape without moving data. *)

  val op_permute : ('a, 'b) t -> int array -> ('a, 'b) t
  (** Reorder dimensions according to [axes]. *)

  val op_pad : ('a, 'b) t -> (int * int) array -> 'a -> ('a, 'b) t
  (** Pad with [fill_value] using the given per-dimension (before, after)
      configuration. *)

  val op_shrink : ('a, 'b) t -> (int * int) array -> ('a, 'b) t
  (** Slice according to the given per-dimension (start, stop) pairs. *)

  val op_flip : ('a, 'b) t -> bool array -> ('a, 'b) t
  (** Flip the dimensions where the boolean array is [true]. *)

  val op_cat : ('a, 'b) t list -> int -> ('a, 'b) t
  (** Concatenate tensors along [axis]. *)

  (* Other Ops *)

  val op_cast : ('a, 'b) t -> ('c, 'd) Dtype.t -> ('c, 'd) t
  (** Cast elements to [target_dtype]. *)

  val op_contiguous : ('a, 'b) t -> ('a, 'b) t
  (** Return a C-contiguous tensor. May copy. *)

  val op_copy : ('a, 'b) t -> ('a, 'b) t
  (** Duplicate [t]. The result has its own buffer. *)

  val op_assign : ('a, 'b) t -> ('a, 'b) t -> unit
  (** Store [src] into [dst] at the corresponding logical indices. *)

  val op_threefry :
    (int32, Dtype.int32_elt) t ->
    (int32, Dtype.int32_elt) t ->
    (int32, Dtype.int32_elt) t
  (** Threefry counter-based random number generator. *)
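
  (* Illustrative sketch: the frontend lowers composite operations onto
     these primitives. For example, the mean of a length-4 float32 vector
     [x] over axis 0 could be built as (assuming [B : S] and
     [Dtype.float32], as above):

       let sum = B.op_reduce_sum ~axes:[| 0 |] ~keepdims:false x in
       let n = B.op_const_scalar (B.context x) 4.0 Dtype.float32 in
       let mean = B.op_fdiv sum n in
       ignore mean

     Both operands of [op_fdiv] are rank-0 here; in general the frontend,
     not the backend, must broadcast them to a common shape first. *)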

  (* Element Access Ops *)

  (* These operations enable lazy, graph-based element access (get/set). This
     differs from tinygrad's eager realization for __getitem__/__setitem__.
     We opt for lazy ops to avoid premature realization and CPU<->device
     transfers. These are primarily for the internal Nx slice/put_slice
     implementations; their direct backend exposure might be refined later. *)

  val op_gather :
    ('a, 'b) t (* data *) ->
    (int32, Dtype.int32_elt) t (* indices *) ->
    int (* axis *) ->
    ('a, 'b) t
  (** Gather elements from [data] along [axis] using [indices]. The output
      shape matches [indices]. [data] and [indices] must have the same rank,
      and every dimension of [indices] other than [axis] must not exceed the
      corresponding dimension of [data]. *)
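
  (* Illustrative example of the gather contract: with [data] of shape
     [|2; 3|] and [indices] of shape [|2; 2|], [op_gather data indices 1]
     returns a [|2; 2|] tensor where output.(i).(j) = data.(i).(idx) for
     idx = indices.(i).(j). *)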

  val op_scatter :
    ?mode:[ `Set | `Add ] ->
    ?unique_indices:bool ->
    ('a, 'b) t (* data_template *) ->
    (int32, Dtype.int32_elt) t (* indices *) ->
    ('a, 'b) t (* updates *) ->
    int (* axis *) ->
    ('a, 'b) t
  (** Scatter [updates] into a new tensor shaped like [data_template] along
      [axis] using [indices]. Returns a new tensor.

      - [mode] specifies how to handle duplicate indices:
        - [`Set] (default): last update wins
        - [`Add]: accumulate updates at duplicate indices
      - [unique_indices]: hint that indices are unique (optimization) *)
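
  (* Illustrative example of the scatter modes: with a zero [data_template]
     of shape [|3|], [indices] = [|0; 0; 2|], [updates] = [|1.; 2.; 3.|] and
     [axis] = 0:
     - [`Set] yields [|2.; 0.; 3.|] (the last update to index 0 wins);
     - [`Add] yields [|3.; 0.; 3.|] (updates to index 0 accumulate). *)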

  val op_unfold :
    ('a, 'b) t ->
    kernel_size:int array ->
    stride:int array ->
    dilation:int array ->
    padding:(int * int) array ->
    ('a, 'b) t
  (** Unfold (im2col) operation. Extracts sliding local blocks from a batched
      input tensor. For an input of shape (N, C, *spatial_dims), produces
      output of shape (N, C * prod(kernel_size), L) where L is the number of
      blocks. Works for any number of spatial dimensions. *)
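
  (* Illustrative shape example: a (1, 1, 4, 4) input with
     [kernel_size = [|2; 2|]], unit [stride] and [dilation], and zero
     [padding] unfolds to (1, 1 * 2 * 2, 9) = (1, 4, 9), since a 2x2 window
     fits at 3 x 3 = 9 positions. *)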

  val op_fold :
    ('a, 'b) t ->
    output_size:int array ->
    kernel_size:int array ->
    stride:int array ->
    dilation:int array ->
    padding:(int * int) array ->
    ('a, 'b) t
  (** Fold (col2im) operation. Combines an array of sliding local blocks into
      a tensor. For an input of shape (N, C * prod(kernel_size), L), produces
      output of shape (N, C, *output_size). Inverse of unfold; overlapping
      values are summed. Works for any number of spatial dimensions. *)
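
  (* Illustrative shape example: folding the (1, 4, 9) blocks above with
     [output_size = [|4; 4|]] and the same kernel/stride/dilation/padding
     yields a (1, 1, 4, 4) tensor; interior positions covered by several
     2x2 windows receive the sum of the overlapping values. *)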

  val op_matmul : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
  (** Matrix multiplication. For 2D tensors, computes standard matrix
      multiplication. For higher dimensions, performs batched matrix
      multiplication on the last two dimensions, broadcasting batch
      dimensions as needed. The last dimension of the first tensor must
      match the second-to-last dimension of the second tensor. *)
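
  (* Illustrative shape examples: (2, 3) x (3, 4) gives (2, 4); with batch
     dimensions, (5, 1, 2, 3) x (1, 7, 3, 4) broadcasts to (5, 7, 2, 4). *)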
end