Module Kaun.Ptree

Heterogeneous parameter tree structure.

Sourcemodule Path : sig ... end

Path submodule for advanced use

Sourcetype tensor =
  1. | P : ('a, 'layout) Rune.t -> tensor
    (*

    Existential wrapper for a tensor; hides the element type and layout parameters of the underlying Rune.t.

    *)
Sourcetype t =
  1. | Tensor of tensor
  2. | List of t list
  3. | Dict of (string * t) list
    (*

    Parameter tree: tensors or containers.

    *)
Sourcemodule Tensor : sig ... end

Tensor utilities

Sourcemodule List : sig ... end
Sourcemodule Dict : sig ... end

Builders

Sourceval tensor : ('a, 'layout) Rune.t -> t

Create a tensor node from a tensor.

Sourceval list : t list -> t

Create a list container.

Sourceval dict : (string * t) list -> t

Create a dict container from key-value pairs. Keys must be unique.

Sourcetype 'r tensor_handler = {
  1. run : 'a 'layout. ('a, 'layout) Rune.t -> 'r;
}
Sourceval with_tensor : tensor -> 'a tensor_handler -> 'a
Sourceval as_tensor : t -> tensor option

Extract tensor if the tree is a single leaf, else None.

Sourceval as_tensor_exn : ?ctx:string -> t -> tensor

Extract the tensor if the tree is a single leaf, otherwise raise. The optional ctx argument supplies context for the error message.

Walking / zipping

Sourceval map : (('a, 'l) Rune.t -> ('a, 'l) Rune.t) -> t -> t

Typed map over tensors. Result dtype must equal input dtype.

Sourceval map2 : (('a, 'l) Rune.t -> ('a, 'l) Rune.t -> ('a, 'l) Rune.t) -> t -> t -> t

Typed zip-with over tensors. The two trees must have the same structure, and each pair of corresponding tensors must have the same dtype.

Sourceval map_packed : (tensor -> tensor) -> t -> t

Packed map over tensors (escape hatch if types are dynamic).

Sourceval iter : (tensor -> unit) -> t -> unit

Iterate over tensors.

Sourceval fold : ('acc -> tensor -> 'acc) -> 'acc -> t -> 'acc

Fold over tensors.

Flatten & rebuild

Sourceval flatten : t -> tensor list * (tensor list -> t)

Flatten to tensors and a rebuilder function.

Path access

Sourceval get : path:Path.t -> t -> t option

Get subtree at path.

Sourceval get_exn : path:Path.t -> t -> t

Get subtree or raise.

Sourceval set : path:Path.t -> value:t -> t -> t

Set subtree at path.

Sourceval update : path:Path.t -> (t -> t) -> t -> t

Update subtree at path with function.

Sourceval mem : path:Path.t -> t -> bool

Check if path exists.

Typed path access for tensors

Sourceval get_tensor : path:Path.t -> t -> ('a, 'l) Rune.dtype -> ('a, 'l) Rune.t option

Get typed tensor at path, checking dtype.

Sourceval get_tensor_exn : path:Path.t -> t -> ('a, 'l) Rune.dtype -> ('a, 'l) Rune.t

Get typed tensor or raise.

Flatten with paths

Sourceval flatten_with_paths : t -> (Path.t * tensor) list

Flatten to (path, tensor) pairs.

Sourceval filter_tensors : t -> (Path.t -> tensor -> bool) -> (Path.t * tensor) list

Filter tensors by predicate on path and tensor.

Float dtype discovery

Sourcetype float_dtype =
  1. | F : (float, 'l) Rune.dtype -> float_dtype
    (*

    Witness that the dtype is a floating-point dtype. Encodes ('a = float) at the type level to avoid enumerating float constructors at call sites.

    *)
Sourceval first_float_dtype : t -> float_dtype option

Find the first floating-point tensor in the tree and return a float dtype witness, if any. Floating-point dtypes include float32/float64/float16, bfloat16, and float8 variants.

Sourceval first_float_dtype_exn : t -> float_dtype

Like first_float_dtype but raises if no floating-point tensors are present.

Convenience

Sourceval zeros_like : t -> t

Create a tree in which each tensor is replaced by the result of zeros_like on it.

Sourceval copy : t -> t

Deep copy the tree.

Sourceval count_tensors : t -> int

Count number of tensors.

Sourceval count_parameters : t -> int

Total number of elements across all tensors.

Sourceval pp : Format.formatter -> t -> unit

Pretty-print the tree to the given formatter.

Sourceval to_string : t -> string