Kaun.Ptree

Parameter tree data structures and operations
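Parameter tree type: a recursive structure for model parameters. Each node is a tensor (Tensor), a list of subtrees (List), or a record of named subtrees (Record).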
Creates a record node from key-value bindings. Raises if the bindings contain duplicate keys.
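A minimal sketch of the duplicate-key check, using a hypothetical simplified stand-in for the tree type (tensors as float arrays, records as association lists) rather than Kaun's Rune-backed type. The choice of Invalid_argument here is an assumption; the description only says the constructor raises.

```ocaml
(* Hypothetical stand-in for Ptree.t: tensors are plain float arrays and
   records are association lists. Kaun's real type uses Rune tensors. *)
type tree =
  | Tensor of float array
  | List of tree list
  | Record of (string * tree) list

(* Build a record node, rejecting bindings that repeat a key.
   Raising Invalid_argument is an assumption of this sketch. *)
let record (bindings : (string * tree) list) : tree =
  let keys = List.map fst bindings in
  let unique = List.sort_uniq String.compare keys in
  if List.length unique <> List.length keys then
    invalid_arg "record: duplicate keys";
  Record bindings
```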
get_tensor tree returns Some tensor if tree is a Tensor, None otherwise
get_list tree returns Some list if tree is a List, None otherwise
get_record tree returns Some record if tree is a Record, None otherwise
find_in_record key tree returns Some value if tree is a Record containing key, None otherwise
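The option-returning accessors can be sketched as one pattern match each. This uses the same hypothetical stand-in type as above (not Kaun's API); in particular, looking keys up with List.assoc_opt assumes records are association lists.

```ocaml
(* Hypothetical stand-in for Ptree.t, for illustration only. *)
type tree =
  | Tensor of float array
  | List of tree list
  | Record of (string * tree) list

let get_tensor = function Tensor t -> Some t | _ -> None
let get_list = function List l -> Some l | _ -> None
let get_record = function Record r -> Some r | _ -> None

(* Some value only when the node is a record that binds [key];
   any other node (or a missing key) yields None. *)
let find_in_record key = function
  | Record r -> List.assoc_opt key r
  | _ -> None
```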
map f tree applies function f to every tensor in the tree and returns a new tree with the same structure.
val map2 :
  ((float, 'layout) Rune.t ->
   (float, 'layout) Rune.t ->
   (float, 'layout) Rune.t) ->
  'layout t ->
  'layout t ->
  'layout t
map2 f tree1 tree2 applies the binary function f to corresponding tensors in both trees. Raises Invalid_argument if the trees have different structures.
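A sketch of the structural walk behind map2, over a hypothetical float-array stand-in for the tree type. Two details are assumptions of this sketch, not documented behavior: lists must have equal lengths, and records must bind the same keys in the same order.

```ocaml
(* Hypothetical stand-in for Ptree.t, for illustration only. *)
type tree =
  | Tensor of float array
  | List of tree list
  | Record of (string * tree) list

(* Apply [f] to tensors at matching positions; any structural mismatch
   raises Invalid_argument. *)
let rec map2 f t1 t2 =
  match t1, t2 with
  | Tensor a, Tensor b -> Tensor (f a b)
  | List l1, List l2 when List.length l1 = List.length l2 ->
      List (List.map2 (map2 f) l1 l2)
  | Record r1, Record r2 when List.map fst r1 = List.map fst r2 ->
      Record (List.map2 (fun (k, v1) (_, v2) -> (k, map2 f v1 v2)) r1 r2)
  | _ -> invalid_arg "map2: trees have different structures"

(* Element-wise addition of two equal-length float arrays. *)
let add_arrays a b = Array.map2 ( +. ) a b
```

For example, map2 add_arrays params grads adds two trees leaf by leaf, which is how the element-wise add further down can be expressed.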
val zip :
  ((float, 'layout) Rune.t -> (float, 'layout) Rune.t -> 'a) ->
  'layout t ->
  'layout t ->
  'a list
zip f tree1 tree2 applies f to each pair of corresponding tensors and returns the results as a flat list. Useful for pairing tensors without building a new tree.
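One way to picture zip is as an accumulator walk that collects f's results in left-to-right order. This is a sketch over a hypothetical float-array stand-in; the traversal order and the Invalid_argument on mismatched structures are assumptions, since the description does not state either.

```ocaml
(* Hypothetical stand-in for Ptree.t, for illustration only. *)
type tree =
  | Tensor of float array
  | List of tree list
  | Record of (string * tree) list

(* Collect [f a b] for every pair of corresponding tensors. *)
let zip f t1 t2 =
  let rec go acc t1 t2 =
    match t1, t2 with
    | Tensor a, Tensor b -> f a b :: acc
    | List l1, List l2 when List.length l1 = List.length l2 ->
        List.fold_left2 go acc l1 l2
    | Record r1, Record r2 when List.map fst r1 = List.map fst r2 ->
        List.fold_left2 (fun acc (_, v1) (_, v2) -> go acc v1 v2) acc r1 r2
    | _ -> invalid_arg "zip: trees have different structures"
  in
  List.rev (go [] t1 t2)

(* Dot product of two equal-length float arrays. *)
let dot a b =
  let s = ref 0. in
  Array.iteri (fun i x -> s := !s +. x *. b.(i)) a;
  !s
```

Here zip dot params grads yields one scalar per tensor without allocating an intermediate tree.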
iter f tree applies function f to all tensors in the tree for side effects
fold f init tree folds function f over all tensors in the tree
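As an illustration of fold, the sketch below computes a global L2 norm over every tensor, the quantity typically used for gradient clipping. It relies on a hypothetical float-array stand-in for the tree type, and its visiting order (list elements in order, record fields in binding order) is an assumption.

```ocaml
(* Hypothetical stand-in for Ptree.t, for illustration only. *)
type tree =
  | Tensor of float array
  | List of tree list
  | Record of (string * tree) list

(* Fold [f] over every tensor, threading the accumulator left to right. *)
let rec fold f acc = function
  | Tensor t -> f acc t
  | List l -> List.fold_left (fold f) acc l
  | Record r -> List.fold_left (fun acc (_, v) -> fold f acc v) acc r

(* Square root of the sum of squares of every scalar in the tree. *)
let global_norm tree =
  sqrt (fold (fun acc t -> Array.fold_left (fun s x -> s +. x *. x) acc t) 0. tree)
```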
equal_structure tree1 tree2 returns true if both trees have the same structure (ignoring tensor values)
filter pred tree keeps each tensor for which pred returns true and replaces every other tensor with zeros of the same shape (as zeros_like would).
apply_mask mask tree zeros out each tensor whose corresponding mask entry is false. Raises if mask and tree have different structures.
zeros_like tree creates a new tree with same structure but all tensors filled with zeros
ones_like tree creates a new tree with same structure but all tensors filled with ones
count_tensors tree returns the number of tensors in the tree
count_parameters tree returns the total number of scalar parameters across all tensors
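Both counts reduce to a fold. In the sketch below, a hypothetical stand-in represents each tensor only by its shape, and a tensor's scalar count is the product of its dimensions, so a 0-d tensor contributes 1.

```ocaml
(* Hypothetical stand-in for Ptree.t: each tensor is represented only by
   its shape, which is all these two counts need. *)
type tree =
  | Tensor of int array
  | List of tree list
  | Record of (string * tree) list

let rec fold f acc = function
  | Tensor shape -> f acc shape
  | List l -> List.fold_left (fold f) acc l
  | Record r -> List.fold_left (fun acc (_, v) -> fold f acc v) acc r

(* Scalars in one tensor: product of its dimensions (1 for a 0-d tensor,
   since the empty product is 1). *)
let numel shape = Array.fold_left ( * ) 1 shape

let count_tensors tree = fold (fun n _ -> n + 1) 0 tree
let count_parameters tree = fold (fun n shape -> n + numel shape) 0 tree
```

For a 3×4 weight, a length-4 bias, and a scalar, this gives 3 tensors and 12 + 4 + 1 = 17 parameters.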
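val flatten :
  'layout t ->
  (float, 'layout) Rune.t list * ((float, 'layout) Rune.t list -> 'layout t)
flatten tree returns the tree's tensors as a flat list, together with a rebuild function that reconstructs a tree of the same structure from a list of tensors given in that order.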
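A sketch of the flatten/rebuild pair over a hypothetical float-array stand-in. Rebuild consumes replacement tensors in the same order the leaves were collected. Raising Invalid_argument when too few or too many tensors are supplied is an assumption of this sketch, not documented behavior.

```ocaml
(* Hypothetical stand-in for Ptree.t, for illustration only. *)
type tree =
  | Tensor of float array
  | List of tree list
  | Record of (string * tree) list

(* Leaves in left-to-right order (accumulated in reverse). *)
let rec leaves acc = function
  | Tensor t -> t :: acc
  | List l -> List.fold_left leaves acc l
  | Record r -> List.fold_left (fun acc (_, v) -> leaves acc v) acc r

(* Rebuild [shape] from [supply] in the same order [leaves] visits it;
   returns the new tree and the tensors left unused. *)
let rec rebuild_from shape supply =
  match shape with
  | Tensor _ ->
      (match supply with
       | x :: rest -> (Tensor x, rest)
       | [] -> invalid_arg "rebuild: too few tensors")
  | List l ->
      let items, rest =
        List.fold_left
          (fun (acc, s) t -> let t', s' = rebuild_from t s in (t' :: acc, s'))
          ([], supply) l
      in
      (List (List.rev items), rest)
  | Record r ->
      let fields, rest =
        List.fold_left
          (fun (acc, s) (k, t) ->
            let t', s' = rebuild_from t s in ((k, t') :: acc, s'))
          ([], supply) r
      in
      (Record (List.rev fields), rest)

let flatten tree =
  let rebuild tensors =
    match rebuild_from tree tensors with
    | t, [] -> t
    | _, _ :: _ -> invalid_arg "rebuild: too many tensors"
  in
  (List.rev (leaves [] tree), rebuild)
```

Threading the unused tensors through the traversal, rather than mutating a shared cursor, keeps the consumption order explicit.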
add tree1 tree2 performs element-wise addition of corresponding tensors
sub tree1 tree2 performs element-wise subtraction of corresponding tensors
mul tree1 tree2 performs element-wise multiplication of corresponding tensors
div tree1 tree2 performs element-wise division of corresponding tensors
scale alpha tree multiplies all tensors in the tree by scalar alpha
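The tree arithmetic composes into optimizer updates. The sketch below writes a plain SGD step, params - lr * grads, using hypothetical map/map2-based versions of scale and sub over a float-array stand-in; the helper sgd_step is illustrative and not part of Kaun.Ptree.

```ocaml
(* Hypothetical stand-in for Ptree.t, for illustration only. *)
type tree =
  | Tensor of float array
  | List of tree list
  | Record of (string * tree) list

let rec map f = function
  | Tensor a -> Tensor (f a)
  | List l -> List (List.map (map f) l)
  | Record r -> Record (List.map (fun (k, v) -> (k, map f v)) r)

let rec map2 f t1 t2 =
  match t1, t2 with
  | Tensor a, Tensor b -> Tensor (f a b)
  | List l1, List l2 when List.length l1 = List.length l2 ->
      List (List.map2 (map2 f) l1 l2)
  | Record r1, Record r2 when List.map fst r1 = List.map fst r2 ->
      Record (List.map2 (fun (k, v1) (_, v2) -> (k, map2 f v1 v2)) r1 r2)
  | _ -> invalid_arg "map2: trees have different structures"

(* scale and sub as described: scalar multiply, element-wise subtract. *)
let scale alpha = map (Array.map (fun x -> alpha *. x))
let sub = map2 (Array.map2 ( -. ))

(* Hypothetical helper: one plain SGD step, params - lr * grads. *)
let sgd_step ~lr params grads = sub params (scale lr grads)
```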
Pretty printer for parameter trees
to_string tree returns a string representation of the tree structure
flatten_with_paths tree returns a list of (path, tensor) pairs, where paths use dot notation for records (e.g., "layer1.weight") and bracket notation for lists (e.g., "layers[0]").
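A sketch of how such paths can be built during traversal: record fields append ".key" (no leading dot at the root) and list elements append "[i]". It uses a hypothetical float-array stand-in; how Kaun names a root-level tensor or a root-level list is not stated, so those cases here are assumptions.

```ocaml
(* Hypothetical stand-in for Ptree.t, for illustration only. *)
type tree =
  | Tensor of float array
  | List of tree list
  | Record of (string * tree) list

let flatten_with_paths tree =
  let rec go prefix acc = function
    | Tensor t -> (prefix, t) :: acc
    | List l ->
        (* Thread the element index alongside the accumulator. *)
        let _, acc =
          List.fold_left
            (fun (i, acc) v -> (i + 1, go (Printf.sprintf "%s[%d]" prefix i) acc v))
            (0, acc) l
        in
        acc
    | Record r ->
        List.fold_left
          (fun acc (k, v) ->
            let p = if prefix = "" then k else prefix ^ "." ^ k in
            go p acc v)
          acc r
  in
  List.rev (go "" [] tree)
```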
unflatten_from_paths pairs reconstructs a parameter tree from path-tensor pairs. Raises Invalid_argument if paths are malformed or inconsistent.
get_by_path path tree retrieves the subtree at the given path. The path uses dot notation for records and bracket notation for lists. Examples: "encoder.weight", "layers[0].attention.q_proj"
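Path lookup splits into parsing and walking. The sketch below parses a path like "layers[0].attention.q_proj" into key and index segments, then descends a hypothetical float-array stand-in. The segment type and helper names are illustrative, and raising Invalid_argument on a malformed or non-matching path is an assumption of the sketch.

```ocaml
(* Hypothetical stand-in for Ptree.t, for illustration only. *)
type tree =
  | Tensor of float array
  | List of tree list
  | Record of (string * tree) list

type segment = Key of string | Index of int

(* "layers[0].attention" -> [Key "layers"; Index 0; Key "attention"] *)
let parse_path path =
  let n = String.length path in
  let segs = ref [] in
  let buf = Buffer.create 16 in
  (* Emit the pending record key, if any. *)
  let flush () =
    if Buffer.length buf > 0 then begin
      segs := Key (Buffer.contents buf) :: !segs;
      Buffer.clear buf
    end
  in
  let i = ref 0 in
  while !i < n do
    (match path.[!i] with
     | '.' -> flush ()
     | '[' ->
         flush ();
         let close =
           match String.index_from_opt path !i ']' with
           | Some c -> c
           | None -> invalid_arg "parse_path: unclosed ["
         in
         (match int_of_string_opt (String.sub path (!i + 1) (close - !i - 1)) with
          | Some idx -> segs := Index idx :: !segs
          | None -> invalid_arg "parse_path: bad list index");
         i := close
     | c -> Buffer.add_char buf c);
    incr i
  done;
  flush ();
  List.rev !segs

(* Descend one segment at a time. *)
let rec walk segs tree =
  match segs, tree with
  | [], t -> t
  | Key k :: rest, Record r ->
      (match List.assoc_opt k r with
       | Some v -> walk rest v
       | None -> invalid_arg ("get_by_path: missing key " ^ k))
  | Index i :: rest, List l when i >= 0 && i < List.length l ->
      walk rest (List.nth l i)
  | _ -> invalid_arg "get_by_path: path does not match tree"

let get_by_path path tree = walk (parse_path path) tree
```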
set_by_path path value tree returns a new tree with the value at path replaced. Creates intermediate records if they don't exist.
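A simplified sketch of the non-destructive update, including creation of missing intermediate records. To stay short it handles dot-separated record keys only (no "[i]" list indices), and it uses a hypothetical float-array stand-in. Appending new keys at the end of a record and raising Invalid_argument when the path passes through a non-record node are both assumptions.

```ocaml
(* Hypothetical stand-in for Ptree.t, for illustration only. *)
type tree =
  | Tensor of float array
  | List of tree list
  | Record of (string * tree) list

(* Return a new tree with [value] at [keys]; the input is not mutated. *)
let rec set_in keys value tree =
  match keys, tree with
  | [], _ -> value
  | k :: rest, Record r ->
      (* A missing key becomes an empty intermediate record. *)
      let child = match List.assoc_opt k r with Some c -> c | None -> Record [] in
      let updated = set_in rest value child in
      if List.mem_assoc k r then
        Record (List.map (fun (k', v) -> if k' = k then (k', updated) else (k', v)) r)
      else Record (r @ [ (k, updated) ])
  | _ :: _, _ -> invalid_arg "set_by_path: path goes through a non-record node"

(* Simplification: record keys only, split on '.'. *)
let set_by_path path value tree =
  set_in (String.split_on_char '.' path) value tree
```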
validate_tree ?path tree checks for structural issues:
list_named_params tree returns a list of (path, shape_string, num_elements) triples. Shape strings are formatted as "2×3×4" for tensors, or "scalar" for 0-d tensors. Useful for inspecting model architecture.
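The shape formatting and element counts can be sketched over a hypothetical stand-in in which each tensor is represented only by its shape. Path construction mirrors the dot and bracket conventions described for flatten_with_paths; the helper names shape_string and num_elements are illustrative.

```ocaml
(* Hypothetical stand-in for Ptree.t: each tensor is represented only by
   its shape. *)
type tree =
  | Tensor of int array
  | List of tree list
  | Record of (string * tree) list

(* "2×3×4" for tensors, "scalar" for 0-d tensors. *)
let shape_string shape =
  if Array.length shape = 0 then "scalar"
  else String.concat "×" (Array.to_list (Array.map string_of_int shape))

(* Product of the dimensions; the empty product gives 1 for a scalar. *)
let num_elements shape = Array.fold_left ( * ) 1 shape

let list_named_params tree =
  let rec go prefix acc = function
    | Tensor s -> (prefix, shape_string s, num_elements s) :: acc
    | List l ->
        snd
          (List.fold_left
             (fun (i, acc) v -> (i + 1, go (Printf.sprintf "%s[%d]" prefix i) acc v))
             (0, acc) l)
    | Record r ->
        List.fold_left
          (fun acc (k, v) -> go (if prefix = "" then k else prefix ^ "." ^ k) acc v)
          acc r
  in
  List.rev (go "" [] tree)
```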
find_params_by_pattern pattern tree returns all params whose paths match the regex pattern. Example: find_params_by_pattern ".*weight$" finds all weight tensors.