Module Ocannl.Tensor

type diff = {
  grad : tn;
  zero_grads : asgns;
      (** Prepares for backpropagation. Always compile as: [Seq (zero_grads, backprop)]. *)
  backprop : asgns;
      (** Backpropagates for the tensor and its descendants; typically this means adding partial gradients to the gradient tensors of the subtensors, then of the sub-subtensors, etc. *)
}

type t = {
  forward : asgns;
  diff : diff Base.option;
  id : Base.int;  (** Same as [value.id]. *)
  value : tn;
  shape : Shape.t;
      (** The eventual shape of [t.value] and [t.diff.grad], incorporating the current state of shape inference. *)
  children : subtensor Base.list;
}
(** Information needed for compositional code generation. *)
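A minimal sketch of the compilation discipline stated in the [diff] documentation: whenever a tensor carries a [diff] record, its gradient code should be scheduled as [Seq (zero_grads, backprop)]. The [Asgns.Seq] constructor name and module alias are assumptions here; only the sequencing requirement and the record fields come from the signatures above.

```ocaml
(* Hypothetical helper: extract the full backprop pass for a tensor.
   [Asgns.Seq] is assumed to be the sequencing constructor of the
   assignments language; field accesses follow [type t] above. *)
let backprop_pass (t : t) : asgns option =
  match t.diff with
  | Some d ->
      (* Zero the gradients first, then accumulate partial gradients. *)
      Some (Asgns.Seq (d.zero_grads, d.backprop))
  | None -> None
```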
The default precision for the value node of terminal (i.e. non-composite) tensors. For composite tensors, the default precision of the value node is the maximum of the precisions of the value nodes of the sub-tensors.

The default precision for the gradient node of terminal (i.e. non-composite) tensors. For composite tensors, the default precision of the gradient node is the maximum of the precisions of the gradient nodes of the sub-tensors.

Note: the precision of any node can be set arbitrarily via [Arrayjit.Tnode.update_prec].
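Since the note above names [Arrayjit.Tnode.update_prec], here is a hedged sketch of overriding a default precision. The precision value [Arrayjit.Ops.single] and the exact argument order of [update_prec] are assumptions, not taken from this page:

```ocaml
(* Hypothetical: force single precision on a tensor's value node,
   overriding the terminal/composite defaults described above. *)
let force_single (t : t) =
  Arrayjit.Tnode.update_prec t.value Arrayjit.Ops.single
```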
val raw_binop :
initialize_neutral:Base.bool ->
accum:Arrayjit.Ops.binop ->
t:t ->
lhs_is_grad:Base.bool ->
op:Arrayjit.Ops.binop ->
t1:t ->
rhs1_is_grad:Base.bool ->
rhs1_is_merge:Base.bool ->
t2:t ->
rhs2_is_grad:Base.bool ->
rhs2_is_merge:Base.bool ->
logic:Shape.compose_type ->
asgns

val raw_unop :
initialize_neutral:Base.bool ->
accum:Arrayjit.Ops.binop ->
t:t ->
lhs_is_grad:Base.bool ->
op:Arrayjit.Ops.unop ->
t1:t ->
rhs_is_grad:Base.bool ->
rhs_is_merge:Base.bool ->
logic:Shape.transpose_type ->
asgns

val op :
label:Base.string Base.list ->
?compose_op:Shape.compose_type ->
?transpose_op:Shape.transpose_type ->
?init_op:init_op ->
op_asn:(v:tn -> projections:projections Base.Lazy.t -> asgns) ->
grad_asn:(v:tn -> g:tn -> projections:projections Base.Lazy.t -> asgns) ->
?grad_spec:grad_spec ->
(debug_name:Base.string -> id:Base.int -> Shape.t) ->
t Base.list ->
t

val binop :
label:Base.string Base.list ->
?compose_op:Shape.compose_type ->
op_asn:(v:tn -> t1:t -> t2:t -> projections:projections Base.Lazy.t -> asgns) ->
grad_asn:
(v:tn ->
g:tn ->
t1:t ->
t2:t ->
projections:projections Base.Lazy.t ->
asgns) ->
?grad_spec:grad_spec ->
t ->
t ->
t

val unop :
label:Base.string Base.list ->
?transpose_op:Shape.transpose_type ->
op_asn:(v:tn -> t1:t -> projections:projections Base.Lazy.t -> asgns) ->
grad_asn:
(v:tn -> g:tn -> t1:t -> projections:projections Base.Lazy.t -> asgns) ->
?grad_spec:grad_spec ->
t ->
t

val term :
label:Base.string Base.list ->
grad_spec:grad_spec ->
?batch_dims:Base.int Base.list ->
?input_dims:Base.int Base.list ->
?output_dims:Base.int Base.list ->
?batch_axes:(Base.string * Base.int) Base.list ->
?input_axes:(Base.string * Base.int) Base.list ->
?output_axes:(Base.string * Base.int) Base.list ->
?deduced:Shape.deduce_within_shape ->
?init_op:init_op ->
?fetch_op:(v:tn -> fetch_op) ->
Base.unit ->
t
(** A terminal: a constant, a parameter, or an input of the model. The semantics of shape specification are the same as in [Shape.make]; by default the shape will be inferred. *)
val number :
?label:Base.string Base.list ->
?axis_label:Base.string ->
?grad_spec:grad_spec ->
Base.float ->
t
(** A number: a tensor with a single axis of one dimension, initialized to the given value. [grad_spec] is by default [Prohibit_grad]. *)
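For instance, following the [number] signature above (a sketch; the label argument is optional and the concrete value is illustrative):

```ocaml
(* A scalar constant tensor. grad_spec defaults to Prohibit_grad,
   so no gradient node is allocated for it. *)
let half = number ~label:[ "half" ] 0.5
```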
val ndarray :
?label:Base.string Base.list ->
?grad_spec:grad_spec ->
?batch_dims:Base.int Base.list ->
?input_dims:Base.int Base.list ->
?output_dims:Base.int Base.list ->
?batch_axes:(Base.string * Base.int) Base.list ->
?input_axes:(Base.string * Base.int) Base.list ->
?output_axes:(Base.string * Base.int) Base.list ->
?strict:Base.bool ->
Base.float Base.array ->
t
(** A tensor with an explicit shape, initialized to the given values. Omitted shape rows default to no axes. [grad_spec] is by default [Prohibit_grad]. If [strict] is true (the default), the given values must fill the tensor's value node precisely; otherwise, the values will be looped over to populate the value node. *)
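A sketch using the [ndarray] signature above; the shape arguments are as documented, while the concrete numbers are illustrative:

```ocaml
(* Two output axes of dimensions 2 and 3; strict is true by default,
   so exactly 6 values are required to fill the value node. *)
let m = ndarray ~output_dims:[ 2; 3 ] [| 1.; 2.; 3.; 4.; 5.; 6. |]

(* With strict:false, a shorter array is looped over to populate
   the value node (here: all ones). *)
let ones = ndarray ~strict:false ~output_dims:[ 2; 3 ] [| 1. |]
```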
val param :
?more_label:Base.string Base.list ->
?input_dims:Base.int Base.list ->
?output_dims:Base.int Base.list ->
?input_axes:(Base.string * Base.int) Base.list ->
?output_axes:(Base.string * Base.int) Base.list ->
?deduced:Shape.deduce_within_shape ->
?strict:Base.bool ->
?values:Base.float Base.array ->
Base.string ->
t

val non_and_embedded_nodes :
t ->
(t, comparator_witness) Base.Set.t * (t, comparator_witness) Base.Set.t

A forward root is a tensor that is not (currently) used to compute another tensor. [consume_forward_code t] ensures that [t] is a forward root, removes it from the forward roots, and checks that there are no other forward roots for tensors with children.
A backprop root is a tensor with a gradient that is not (currently) receiving gradients from another tensor, i.e. it is not currently used to compute a tensor with a gradient. [consume_backprop_code t] ensures that [t] is a backprop root, removes it from the backprop roots, and checks that there are no other backprop roots for tensors with children.
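The consumption discipline described above can be sketched as follows. This is an assumption-laden sketch: the return types of [consume_forward_code] and [consume_backprop_code] are taken to be [asgns] (matching the [forward] and [backprop] fields), [Asgns.Seq] is assumed as the sequencing constructor, and whether the consumed backprop code already includes gradient zeroing is not specified on this page:

```ocaml
(* Hypothetical: take ownership of a root tensor's code exactly once,
   then sequence forward before backprop for compilation. *)
let training_pass (t : t) : asgns =
  let fwd = consume_forward_code t in
  let bwd = consume_backprop_code t in
  Asgns.Seq (fwd, bwd)
```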
Brings the global state back to its initialization values. This invalidates any previously defined tensors and tensor nodes. It also reinitializes the modules: Shape, Arrayjit.Tnode, Arrayjit.Rand.Random_for_tests.
Converts the ID, label, and dimensions of a node to a string.
Logs debug information about the tensor on the default ppx_minidebug runtime.
type array_print_style =
  [ `Default
      (** The inner rectangles comprise both an input and an output axis, if available. Similarly, the outer rectangle comprises a second-from-end input axis and a second-from-end output axis, if available. At least one batch axis is output, when available. The axes that couldn't be output are printed at position/dimension 0. *)
  | `N5_layout of Base.string
      (** The string should provide exclusively non-negative integer pseudo-labels. The numbers 0-4 represent the priorities of the axes to be printed out, where the priorities correspond to, from highest: horizontal, vertical direction of the inner rectangle; horizontal, vertical direction of the outer rectangle; repetition (see also Node.pp_print). The numbers n >= 5 stand for the actual positions n - 5 within the corresponding axes. *)
  | `Label_layout of (Base.string * Base.int) Base.list
      (** The association from axis labels to integers. The negative numbers -5 to -1 represent the priorities of the axes to be printed out, where the priorities correspond to, from highest: horizontal, vertical direction of the inner rectangle; horizontal, vertical direction of the outer rectangle; repetition (as above). The numbers n >= 0 stand for the actual positions within the corresponding axes. Unspecified axes are printed at position 0. *)
  | `Inline
      (** The tensors are printed linearly, in a bracketed manner, optionally prefixed with the labels specification. Note that the syntax causes ambiguity for 1-dimensional input axes (underscores are used for axes without explicit labels); when there is a 1-dimensional input axis, we output the labels specification even if there are no axis labels, as a way to display the number of axes. The axis nesting is right-to-left (rightmost is innermost). The input axes are innermost and the batch axes outermost. The input axes use , as a separator and () as axis delimiters, but the delimiter for the outermost (i.e. leftmost) axis is omitted. The output axes use ; as a separator and [] as axis delimiters (obligatory). The batch axes use ; as a separator and [||] as axis delimiters (obligatory). *)
  ]
(** We print out up to 5 axes when printing a tensor, as a grid (outer rectangle) of (inner) rectangles, possibly repeated (screens). *)
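As a usage sketch tying [array_print_style] to a concrete call, the [print_forward_roots] function whose signature appears below accepts a style directly; the tensors being printed are whatever forward roots currently exist:

```ocaml
(* Print all current forward roots with the default 5-axis grid
   layout, omitting gradients and generated code. *)
let () = print_forward_roots ~with_grad:false ~with_code:false `Default
```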
val print_forward_roots :
with_grad:Base.bool ->
with_code:Base.bool ->
array_print_style ->
Base.unit

val value_2d_points :
?from_axis:Base.int ->
xdim:Base.int ->
ydim:Base.int ->
t ->
(Base.float * Base.float) Base.array

val grad_2d_points :
?from_axis:Base.int ->
xdim:Base.int ->
ydim:Base.int ->
t ->
(Base.float * Base.float) Base.array
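Following the two signatures above, a sketch of extracting plot data from a tensor. The interpretation of [xdim]/[ydim] as selecting which dimensions supply the x and y coordinates is an assumption; [my_tensor] is a placeholder:

```ocaml
(* Collect (x, y) pairs from a tensor's value node for plotting,
   taking dimension 0 as x and dimension 1 as y. *)
let points : (float * float) array =
  value_2d_points ~xdim:0 ~ydim:1 my_tensor
```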