Module Kaun.Ptree

Parameter tree data structures and operations

module Record : sig ... end
type 'layout t =
  | Tensor of (float, 'layout) Rune.t
  | List of 'layout t list
  | Record of 'layout t Record.t

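Parameter tree type: a recursive structure for model parameters. Leaves are float tensors; interior nodes are either lists of subtrees or records of named subtrees.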
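As an illustration, the sketch below lays out a small model as such a tree. To stay runnable without Rune, it uses a hypothetical stand-in type in which float arrays replace Rune tensors and an association list replaces Record.t; the layer names are made up.

```ocaml
(* Hypothetical stand-in for Ptree.t: float arrays replace Rune tensors
   and an association list replaces Record.t. *)
type t =
  | Tensor of float array
  | List of t list
  | Record of (string * t) list

(* A small model: records name the layers and their weights,
   and a list holds a variable number of similar blocks. *)
let model =
  Record
    [ ("embed", Record [ ("weight", Tensor [| 0.1; 0.2; 0.3; 0.4 |]) ]);
      ( "blocks",
        List
          [ Record [ ("weight", Tensor [| 1.0; 0.0 |]); ("bias", Tensor [| 0.0 |]) ];
            Record [ ("weight", Tensor [| 0.5; 0.5 |]); ("bias", Tensor [| 0.1 |]) ] ] ) ]

(* Count the leaves by walking the recursive structure. *)
let rec leaves = function
  | Tensor _ -> 1
  | List l -> List.fold_left (fun n x -> n + leaves x) 0 l
  | Record r -> List.fold_left (fun n (_, x) -> n + leaves x) 0 r
```

Every operation below follows the same three-way recursion over Tensor, List and Record.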

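type mask_tree =
  | Mask_tensor of bool
  | Mask_list of mask_tree list
  | Mask_record of mask_tree Record.t

Boolean mask mirroring the structure of a parameter tree. Used by apply_mask, which zeros out every tensor whose mask leaf is false.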

Builders

val tensor : (float, 'layout) Rune.t -> 'layout t

Create a leaf tensor node

val list_of : 'layout t list -> 'layout t

Create a list node

val record_of : (string * 'layout t) list -> 'layout t

Create a record node from (key, subtree) bindings. Raises if any key appears more than once.
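A minimal sketch of the duplicate-key check, using the same hypothetical stand-in type (float arrays for tensors, association lists for records). The documentation does not name the exception, so raising Invalid_argument here is an assumption.

```ocaml
type t =
  | Tensor of float array
  | List of t list
  | Record of (string * t) list

(* Build a record node, rejecting duplicate keys.
   Invalid_argument is an assumption; the docs only say "Raises". *)
let record_of (bindings : (string * t) list) : t =
  let keys = List.map fst bindings in
  let unique = List.sort_uniq String.compare keys in
  if List.length unique <> List.length keys then
    invalid_arg "record_of: duplicate key";
  Record bindings
```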

Accessors

val get_tensor : 'layout t -> (float, 'layout) Rune.t option

get_tensor tree returns Some tensor if tree is a Tensor, None otherwise

val get_list : 'layout t -> 'layout t list option

get_list tree returns Some list if tree is a List, None otherwise

val get_record : 'layout t -> 'layout t Record.t option

get_record tree returns Some record if tree is a Record, None otherwise

val find_in_record : string -> 'layout t -> 'layout t option

find_in_record key tree returns Some value if tree is a Record containing key, None otherwise

Tree Operations

val map : ((float, 'layout) Rune.t -> (float, 'layout) Rune.t) -> 'layout t -> 'layout t

map f tree applies function f to all tensors in the tree
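The recursion behind map can be sketched on the hypothetical stand-in type (float arrays for tensors, association lists for records): leaves are transformed, and every List and Record node is rebuilt with the same shape.

```ocaml
type t =
  | Tensor of float array
  | List of t list
  | Record of (string * t) list

(* Rebuild the tree with f applied to every leaf; structure is preserved. *)
let rec map f = function
  | Tensor x -> Tensor (f x)
  | List l -> List (List.map (map f) l)
  | Record r -> Record (List.map (fun (k, v) -> (k, map f v)) r)
```

For example, map (Array.map (fun v -> 2.0 *. v)) doubles every parameter while keeping keys and list positions intact.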

val map2 : ((float, 'layout) Rune.t -> (float, 'layout) Rune.t -> (float, 'layout) Rune.t) -> 'layout t -> 'layout t -> 'layout t

map2 f tree1 tree2 applies binary function f to corresponding tensors in both trees. Raises Invalid_argument if trees have different structures.
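A sketch of map2 on the hypothetical stand-in type. It walks both trees in lockstep and raises Invalid_argument on any structural mismatch, as documented. Comparing record entries by position assumes both records list their keys in the same order; the real Record.t may be an ordered map where that holds automatically.

```ocaml
type t =
  | Tensor of float array
  | List of t list
  | Record of (string * t) list

(* Combine two trees leaf by leaf; mismatches raise Invalid_argument. *)
let rec map2 f a b =
  match (a, b) with
  | Tensor x, Tensor y -> Tensor (f x y)
  | List xs, List ys ->
      if List.length xs <> List.length ys then invalid_arg "map2: list lengths differ";
      List (List.map2 (map2 f) xs ys)
  | Record xs, Record ys ->
      if List.length xs <> List.length ys then invalid_arg "map2: record sizes differ";
      Record
        (List.map2
           (fun (k1, v1) (k2, v2) ->
             (* keys are compared positionally (assumption) *)
             if not (String.equal k1 k2) then invalid_arg "map2: record keys differ";
             (k1, map2 f v1 v2))
           xs ys)
  | _ -> invalid_arg "map2: node kinds differ"
```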

val zip : ((float, 'layout) Rune.t -> (float, 'layout) Rune.t -> 'a) -> 'layout t -> 'layout t -> 'a list

zip f tree1 tree2 applies f to pairs of corresponding tensors, returning a flat list of results. Useful for pairing without building a new tree.
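A sketch of zip on the hypothetical stand-in type, assuming a depth-first, left-to-right traversal order and assuming mismatched structures raise Invalid_argument (the documentation does not specify either). The helper dot is illustrative only.

```ocaml
type t =
  | Tensor of float array
  | List of t list
  | Record of (string * t) list

(* Walk both trees in the same depth-first, left-to-right order and
   collect f's result for each pair of leaves into a flat list. *)
let zip f a b =
  let rec go acc a b =
    match (a, b) with
    | Tensor x, Tensor y -> f x y :: acc
    | List xs, List ys -> List.fold_left2 go acc xs ys
    | Record xs, Record ys ->
        List.fold_left2 (fun acc (_, x) (_, y) -> go acc x y) acc xs ys
    | _ -> invalid_arg "zip: structures differ"
  in
  List.rev (go [] a b)

(* Hypothetical helper: dot product of two equally sized tensors. *)
let dot x y = Array.fold_left ( +. ) 0.0 (Array.map2 ( *. ) x y)
```

For example, zip dot params grads yields one dot product per tensor without building an intermediate tree.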

val iter : ((float, 'layout) Rune.t -> unit) -> 'layout t -> unit

iter f tree applies function f to all tensors in the tree for side effects

val fold : ('a -> (float, 'layout) Rune.t -> 'a) -> 'a -> 'layout t -> 'a

fold f init tree folds function f over all tensors in the tree
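A sketch of fold on the hypothetical stand-in type, together with one common use: the global L2 norm of all parameters, as computed for gradient clipping. The traversal order and the global_norm helper are assumptions for illustration.

```ocaml
type t =
  | Tensor of float array
  | List of t list
  | Record of (string * t) list

(* Thread an accumulator through every leaf, depth-first, left to right. *)
let rec fold f acc = function
  | Tensor x -> f acc x
  | List l -> List.fold_left (fold f) acc l
  | Record r -> List.fold_left (fun acc (_, v) -> fold f acc v) acc r

(* Hypothetical helper: square root of the sum of all squared entries. *)
let global_norm tree =
  sqrt (fold (fun acc x -> Array.fold_left (fun a v -> a +. (v *. v)) acc x) 0.0 tree)
```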

val equal_structure : 'layout t -> 'layout t -> bool

equal_structure tree1 tree2 returns true if both trees have the same structure (ignoring tensor values)

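val filter : ((float, 'layout) Rune.t -> bool) -> 'layout t -> 'layout t

filter pred tree returns a tree with the same structure in which every tensor for which pred returns false is replaced by a zero-filled tensor of the same shape; the other tensors are kept unchanged.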
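A sketch of filter on the hypothetical stand-in type, where a zero tensor "of the same shape" becomes a zero array of the same length.

```ocaml
type t =
  | Tensor of float array
  | List of t list
  | Record of (string * t) list

(* Keep tensors that satisfy pred; replace the rest with zeros of the same size. *)
let rec filter pred = function
  | Tensor x -> if pred x then Tensor x else Tensor (Array.make (Array.length x) 0.0)
  | List l -> List (List.map (filter pred) l)
  | Record r -> Record (List.map (fun (k, v) -> (k, filter pred v)) r)
```

Note that the structure never shrinks: rejected leaves stay in place as zeros, so the result still lines up with the input for map2 and friends.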

val apply_mask : mask_tree -> 'layout t -> 'layout t

apply_mask mask tree zeros out tensors where mask is false. Raises if structures differ.
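A sketch of apply_mask on the hypothetical stand-in types, walking the mask and the tree together. The exception is not named in the documentation, so Invalid_argument is an assumption. Masking gradients this way is one way to freeze part of a model.

```ocaml
type t =
  | Tensor of float array
  | List of t list
  | Record of (string * t) list

type mask_tree =
  | Mask_tensor of bool
  | Mask_list of mask_tree list
  | Mask_record of (string * mask_tree) list

let zeros x = Array.make (Array.length x) 0.0

(* false leaves zero the matching tensor; mismatches raise (assumed Invalid_argument). *)
let rec apply_mask mask tree =
  match (mask, tree) with
  | Mask_tensor keep, Tensor x -> Tensor (if keep then x else zeros x)
  | Mask_list ms, List ts ->
      if List.length ms <> List.length ts then invalid_arg "apply_mask: list lengths differ";
      List (List.map2 apply_mask ms ts)
  | Mask_record ms, Record ts ->
      if List.length ms <> List.length ts then invalid_arg "apply_mask: record sizes differ";
      Record
        (List.map2
           (fun (k1, m) (k2, v) ->
             if not (String.equal k1 k2) then invalid_arg "apply_mask: record keys differ";
             (k2, apply_mask m v))
           ms ts)
  | _ -> invalid_arg "apply_mask: structures differ"
```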

Tree Construction

val zeros_like : 'layout t -> 'layout t

zeros_like tree creates a new tree with same structure but all tensors filled with zeros

val ones_like : 'layout t -> 'layout t

ones_like tree creates a new tree with same structure but all tensors filled with ones

val copy : 'layout t -> 'layout t

copy tree creates a deep copy of the tree

Tree Inspection

val count_tensors : 'layout t -> int

count_tensors tree returns the number of tensors in the tree

val count_parameters : 'layout t -> int

count_parameters tree returns the total number of scalar parameters across all tensors
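Both counters can be sketched as folds on the hypothetical stand-in type. With real Rune tensors the per-tensor count would be the product of the shape's dimensions; here it is simply the array length.

```ocaml
type t =
  | Tensor of float array
  | List of t list
  | Record of (string * t) list

let rec fold f acc = function
  | Tensor x -> f acc x
  | List l -> List.fold_left (fold f) acc l
  | Record r -> List.fold_left (fun acc (_, v) -> fold f acc v) acc r

(* One per leaf. *)
let count_tensors tree = fold (fun n _ -> n + 1) 0 tree

(* Sum of leaf sizes (array length stands in for the product of a shape). *)
let count_parameters tree = fold (fun n x -> n + Array.length x) 0 tree
```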

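val flatten : 'layout t -> (float, 'layout) Rune.t list * ((float, 'layout) Rune.t list -> 'layout t)

flatten tree returns the tree's tensors as a flat list, together with a rebuild function that takes a list of tensors in the same order and reconstructs a tree with the original structure.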
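A sketch of flatten on the hypothetical stand-in type. The rebuild closure captures the original tree as a template and consumes tensors from the list in the same depth-first order used to collect them; raising Invalid_argument when the list is too short or too long is an assumption.

```ocaml
type t =
  | Tensor of float array
  | List of t list
  | Record of (string * t) list

let flatten (tree : t) : float array list * (float array list -> t) =
  let rec collect acc = function
    | Tensor x -> x :: acc
    | List l -> List.fold_left collect acc l
    | Record r -> List.fold_left (fun acc (_, v) -> collect acc v) acc r
  in
  (* build consumes tensors from rest and returns the leftover list *)
  let rec build rest = function
    | Tensor _ -> (
        match rest with
        | x :: tl -> (Tensor x, tl)
        | [] -> invalid_arg "rebuild: too few tensors")
    | List l ->
        let items, rest =
          List.fold_left
            (fun (acc, rest) v -> let v', rest = build rest v in (v' :: acc, rest))
            ([], rest) l
        in
        (List (List.rev items), rest)
    | Record r ->
        let items, rest =
          List.fold_left
            (fun (acc, rest) (k, v) -> let v', rest = build rest v in ((k, v') :: acc, rest))
            ([], rest) r
        in
        (Record (List.rev items), rest)
  in
  let rebuild tensors =
    match build tensors tree with
    | result, [] -> result
    | _, _ :: _ -> invalid_arg "rebuild: too many tensors"
  in
  (List.rev (collect [] tree), rebuild)
```

This split is convenient when an optimizer or serializer wants a plain list: transform the list, then call rebuild to get a tree back.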

Arithmetic Operations

val add : 'layout t -> 'layout t -> 'layout t

add tree1 tree2 performs element-wise addition of corresponding tensors

val sub : 'layout t -> 'layout t -> 'layout t

sub tree1 tree2 performs element-wise subtraction of corresponding tensors

val mul : 'layout t -> 'layout t -> 'layout t

mul tree1 tree2 performs element-wise multiplication of corresponding tensors

val div : 'layout t -> 'layout t -> 'layout t

div tree1 tree2 performs element-wise division of corresponding tensors

val scale : float -> 'layout t -> 'layout t

scale alpha tree multiplies all tensors in the tree by scalar alpha

val neg : 'layout t -> 'layout t

neg tree negates all tensors in the tree
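The arithmetic operations compose naturally into an optimizer update. The sketch below expresses sub and scale through map and map2 on the hypothetical stand-in type and uses them for one plain SGD step; the real functions may be implemented differently, and sgd_step is an illustrative name, not part of this module.

```ocaml
type t =
  | Tensor of float array
  | List of t list
  | Record of (string * t) list

let rec map f = function
  | Tensor x -> Tensor (f x)
  | List l -> List (List.map (map f) l)
  | Record r -> Record (List.map (fun (k, v) -> (k, map f v)) r)

(* Lockstep traversal; record keys are assumed to line up positionally. *)
let rec map2 f a b =
  match (a, b) with
  | Tensor x, Tensor y -> Tensor (f x y)
  | List xs, List ys -> List (List.map2 (map2 f) xs ys)
  | Record xs, Record ys -> Record (List.map2 (fun (k, x) (_, y) -> (k, map2 f x y)) xs ys)
  | _ -> invalid_arg "map2: structures differ"

let sub = map2 (Array.map2 ( -. ))
let scale alpha = map (Array.map (fun v -> alpha *. v))

(* One plain SGD step: params - lr * grads. *)
let sgd_step ~lr params grads = sub params (scale lr grads)
```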

Utility Functions

val pp : Format.formatter -> 'layout t -> unit

Pretty printer for parameter trees

val to_string : 'layout t -> string

to_string tree returns a string representation of the tree structure

Path-based Flattening

val flatten_with_paths : 'layout t -> (string * (float, 'layout) Rune.t) list

flatten_with_paths tree returns a list of (path, tensor) pairs, where paths use dot notation for records (e.g., "layer1.weight") and bracket notation for lists (e.g., "layers[0]").
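A sketch of how such paths can be generated on the hypothetical stand-in type. How the real function names a tensor at the root or a list at the root is not documented; this sketch uses "" and "[0]" respectively as an assumption.

```ocaml
type t =
  | Tensor of float array
  | List of t list
  | Record of (string * t) list

let flatten_with_paths (tree : t) : (string * float array) list =
  (* record keys are joined with "." except at the root *)
  let join prefix key = if prefix = "" then key else prefix ^ "." ^ key in
  let rec go prefix acc = function
    | Tensor x -> (prefix, x) :: acc
    | List l ->
        (* list elements get a "[i]" suffix *)
        snd
          (List.fold_left
             (fun (i, acc) v -> (i + 1, go (Printf.sprintf "%s[%d]" prefix i) acc v))
             (0, acc) l)
    | Record r -> List.fold_left (fun acc (k, v) -> go (join prefix k) acc v) acc r
  in
  List.rev (go "" [] tree)
```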

val unflatten_from_paths : (string * (float, 'layout) Rune.t) list -> 'layout t

unflatten_from_paths pairs reconstructs a parameter tree from path-tensor pairs. Raises Invalid_argument if paths are malformed or inconsistent.

Path-based Access

val get_by_path : string -> 'layout t -> 'layout t

get_by_path path tree retrieves the subtree at the given path. Paths use dot notation for records and bracket notation for lists, for example "encoder.weight" or "layers[0].attention.q_proj".
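A sketch of path parsing and lookup on the hypothetical stand-in type. The segment type and parse_path are illustrative names, and raising Invalid_argument for a missing key or bad index is an assumption, since the documentation does not say how get_by_path fails.

```ocaml
type t =
  | Tensor of float array
  | List of t list
  | Record of (string * t) list

type segment = Key of string | Index of int

(* Split "layers[0].attention.q_proj" into
   [Key "layers"; Index 0; Key "attention"; Key "q_proj"]. *)
let parse_path (path : string) : segment list =
  let n = String.length path in
  let rec go i acc =
    if i >= n then List.rev acc
    else
      match path.[i] with
      | '.' -> go (i + 1) acc
      | '[' ->
          let j = try String.index_from path i ']' with Not_found -> invalid_arg "unclosed '['" in
          let idx =
            try int_of_string (String.sub path (i + 1) (j - i - 1))
            with Failure _ -> invalid_arg "bad list index"
          in
          go (j + 1) (Index idx :: acc)
      | _ ->
          (* a key runs until the next '.' or '[' *)
          let j = ref i in
          while !j < n && path.[!j] <> '.' && path.[!j] <> '[' do incr j done;
          go !j (Key (String.sub path i (!j - i)) :: acc)
  in
  go 0 []

let get_by_path path tree =
  List.fold_left
    (fun node seg ->
      match (seg, node) with
      | Key k, Record r -> (
          match List.assoc_opt k r with
          | Some v -> v
          | None -> invalid_arg ("get_by_path: missing key " ^ k))
      | Index i, List l -> (
          match List.nth_opt l i with
          | Some v -> v
          | None -> invalid_arg "get_by_path: index out of range")
      | _ -> invalid_arg "get_by_path: path does not match tree")
    tree (parse_path path)
```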

val set_by_path : string -> 'layout t -> 'layout t -> 'layout t

set_by_path path value tree returns a new tree in which the subtree at path is replaced by value. Intermediate records are created if they do not exist. Raises Invalid_argument if the path is malformed or incompatible with the tree structure.
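A sketch of the update on the hypothetical stand-in type. For brevity it takes an already parsed path (a list of segments) instead of a string, and set_by_segments is an illustrative name. Missing record keys get a fresh empty record, following the documented behaviour; that list indices must already exist is an assumption.

```ocaml
type t =
  | Tensor of float array
  | List of t list
  | Record of (string * t) list

type segment = Key of string | Index of int

(* Return a copy of tree with the subtree at segs replaced by value;
   the input tree is not modified. *)
let rec set_by_segments segs value tree =
  match (segs, tree) with
  | [], _ -> value
  | Key k :: rest, Record r ->
      (* create an empty intermediate record when the key is missing *)
      let child = match List.assoc_opt k r with Some v -> v | None -> Record [] in
      let updated = set_by_segments rest value child in
      if List.mem_assoc k r then
        Record (List.map (fun (k', v) -> if String.equal k k' then (k', updated) else (k', v)) r)
      else Record (r @ [ (k, updated) ])
  | Index i :: rest, List l ->
      if i < 0 || i >= List.length l then invalid_arg "set_by_segments: index out of range";
      List (List.mapi (fun j v -> if j = i then set_by_segments rest value v else v) l)
  | _ -> invalid_arg "set_by_segments: path incompatible with tree"
```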

val validate_tree : ?path:string -> 'layout t -> unit

validate_tree ?path tree checks tree for structural issues:

  • Empty keys
  • Duplicate keys within records
  • Invalid characters in keys (such as ".", which paths use as the record separator)
  • Empty lists or records (reported as warnings)

The optional path argument sets the starting path used in error messages (default: "root").
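A sketch of these checks on the hypothetical stand-in type. Unlike validate_tree, which returns unit, this illustrative collect_problems returns every problem as a message string so that all of them can be inspected at once; the message wording is made up.

```ocaml
type t =
  | Tensor of float array
  | List of t list
  | Record of (string * t) list

(* Walk the tree, recording problems with their path (starting at "root"). *)
let collect_problems ?(path = "root") tree =
  let problems = ref [] in
  let report msg = problems := msg :: !problems in
  let rec go path = function
    | Tensor _ -> ()
    | List [] -> report (path ^ ": warning: empty list")
    | Record [] -> report (path ^ ": warning: empty record")
    | List l -> List.iteri (fun i v -> go (Printf.sprintf "%s[%d]" path i) v) l
    | Record r ->
        let seen = Hashtbl.create 8 in
        List.iter
          (fun (k, v) ->
            if k = "" then report (path ^ ": empty key");
            if String.contains k '.' then report (Printf.sprintf "%s: key %S contains '.'" path k);
            if Hashtbl.mem seen k then report (Printf.sprintf "%s: duplicate key %S" path k);
            Hashtbl.replace seen k ();
            go (path ^ "." ^ k) v)
          r
  in
  go path tree;
  List.rev !problems
```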

Enhanced Introspection

val list_named_params : 'layout t -> (string * string * int) list

list_named_params tree returns a list of (path, shape_string, num_elements). Shape strings are formatted as "2×3×4" for tensors or "scalar" for 0-d tensors. Useful for inspecting model architecture.
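The shape string and element count in each entry can be sketched from a shape given as an int array. describe_shape is a hypothetical helper, not part of this module; a 0-d tensor holds exactly one element.

```ocaml
(* Format a shape as "2×3×4" (or "scalar" for a 0-d tensor) and
   return its element count alongside. *)
let describe_shape (shape : int array) : string * int =
  if Array.length shape = 0 then ("scalar", 1)
  else
    ( String.concat "×" (Array.to_list (Array.map string_of_int shape)),
      Array.fold_left ( * ) 1 shape )
```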

val find_params_by_pattern : string -> 'layout t -> (string * (float, 'layout) Rune.t) list

find_params_by_pattern pattern tree returns the (path, tensor) pairs whose path matches the regular expression pattern. Example: find_params_by_pattern ".*weight$" tree finds all tensors whose path ends in "weight".

val get_param_stats : 'layout t -> int * (string * int) list

get_param_stats tree returns (total_params, [(group_name, count); ...]), grouping parameter counts by top-level key for a summary view. Example: (1000000, [("encoder", 800000); ("decoder", 200000)])
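A sketch of the grouping on the hypothetical stand-in type, where a tensor's parameter count is its array length. What the real function reports for a root that is not a record is not documented; returning no groups in that case is an assumption.

```ocaml
type t =
  | Tensor of float array
  | List of t list
  | Record of (string * t) list

let rec count = function
  | Tensor x -> Array.length x
  | List l -> List.fold_left (fun n v -> n + count v) 0 l
  | Record r -> List.fold_left (fun n (_, v) -> n + count v) 0 r

(* Total count plus one (key, count) group per top-level record key. *)
let get_param_stats tree =
  let groups =
    match tree with
    | Record r -> List.map (fun (k, v) -> (k, count v)) r
    | Tensor _ | List _ -> []
  in
  (count tree, groups)
```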