Module Kaun.Metrics

Performance metrics for neural network training and evaluation.

This module provides metrics for monitoring model performance during training and evaluation. Metrics are composable and stateful: each one accumulates results across batches, and the metric type itself does not depend on the tensor layout.

Core Types

type metric

Layout-independent metric accumulator that produces host float values when computed.

type reduction =
  | Mean
  | Sum

Specifies how metric values are reduced across batch dimensions.

Metric Creation

Classification Metrics

val accuracy : ?threshold:float -> ?top_k:int -> unit -> metric

accuracy ?threshold ?top_k () creates an accuracy metric.

  • parameter threshold

    Threshold for binary classification (default: 0.5)

  • parameter top_k

    For multi-class, count a prediction as correct if the true label is among the top-k predictions

Note: for per-class or aggregated variants, combine Metrics.confusion_matrix with custom post-processing.

Example

  let acc = Metrics.accuracy ()
  let top5_acc = Metrics.accuracy ~top_k:5 ()

val precision : ?threshold:float -> ?zero_division:float -> unit -> metric

precision ?threshold ?zero_division () creates a precision metric.

Precision = True Positives / (True Positives + False Positives)

  • parameter threshold

    Binary classification threshold (default: 0.5)

  • parameter zero_division

    Value to return when there are no positive predictions (default: 0.0)

Example

  let prec = Metrics.precision ()
val recall : ?threshold:float -> ?zero_division:float -> unit -> metric

recall ?threshold ?zero_division () creates a recall metric.

Recall = True Positives / (True Positives + False Negatives)

  • parameter threshold

    Binary classification threshold (default: 0.5)

  • parameter zero_division

    Value to return when there are no actual positives (default: 0.0)

val f1_score : ?threshold:float -> ?beta:float -> unit -> metric

f1_score ?threshold ?beta () creates an F-score metric.

F-score = (1 + β²) * (Precision * Recall) / (β² * Precision + Recall)

  • parameter threshold

    Binary classification threshold (default: 0.5)

  • parameter beta

    Weight of recall vs precision (default: 1.0 for F1)
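The precision, recall, and F-score formulas above can be checked by hand with a small standalone sketch. It uses only the OCaml standard library and does not call Kaun; the helper names (safe_div, precision_of_counts, recall_of_counts, f_beta) are hypothetical, and the 0.0 fallback for an undefined F-score is an assumption rather than documented behaviour.

```ocaml
(* Standalone sketch: precision, recall and F-beta from raw counts.
   All names here are hypothetical helpers, not Kaun functions. *)

(* Divide, falling back to [zero_division] when the denominator is zero. *)
let safe_div ?(zero_division = 0.0) num den =
  if den = 0.0 then zero_division else num /. den

(* Precision = TP / (TP + FP) *)
let precision_of_counts ?zero_division ~tp ~fp () =
  safe_div ?zero_division tp (tp +. fp)

(* Recall = TP / (TP + FN) *)
let recall_of_counts ?zero_division ~tp ~fn () =
  safe_div ?zero_division tp (tp +. fn)

(* F-beta = (1 + b^2) * P * R / (b^2 * P + R).
   Assumption: returns 0.0 when both precision and recall are zero. *)
let f_beta ?(beta = 1.0) ~tp ~fp ~fn () =
  let p = precision_of_counts ~tp ~fp ()
  and r = recall_of_counts ~tp ~fn () in
  let b2 = beta *. beta in
  safe_div ((1.0 +. b2) *. p *. r) ((b2 *. p) +. r)

let () =
  (* 8 true positives, 2 false positives, 4 false negatives *)
  Printf.printf "precision = %.4f\n" (precision_of_counts ~tp:8.0 ~fp:2.0 ());
  Printf.printf "recall    = %.4f\n" (recall_of_counts ~tp:8.0 ~fn:4.0 ());
  Printf.printf "F1        = %.4f\n" (f_beta ~tp:8.0 ~fp:2.0 ~fn:4.0 ());
  Printf.printf "F2        = %.4f\n"
    (f_beta ~beta:2.0 ~tp:8.0 ~fp:2.0 ~fn:4.0 ())
```

With beta greater than 1.0 the score moves toward recall; here recall (0.6667) is lower than precision (0.8), so F2 comes out below F1.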

val auc_roc : unit -> metric

auc_roc () creates an AUC-ROC (Area Under the Receiver Operating Characteristic) metric that integrates true/false positive rates observed across batches.

val auc_pr : unit -> metric

auc_pr () creates an AUC-PR (Area Under the Precision–Recall) metric. Computes the exact precision–recall integral by sorting predictions and accumulating precision/recall scores across all seen batches.

val confusion_matrix : num_classes:int -> ?normalize:[ `None | `True | `Pred | `All ] -> unit -> metric

confusion_matrix ~num_classes ?normalize () accumulates a confusion matrix for classification tasks.

  • parameter num_classes

    Number of classes

  • parameter normalize

    Normalisation mode (default: `None)

Regression Metrics

val mse : ?reduction:reduction -> unit -> metric

mse ?reduction () creates a Mean Squared Error metric.

MSE = mean((predictions - targets)²)

val rmse : ?reduction:reduction -> unit -> metric

rmse ?reduction () creates a Root Mean Squared Error metric.

RMSE = sqrt(mean((predictions - targets)²))

val mae : ?reduction:reduction -> unit -> metric

mae ?reduction () creates a Mean Absolute Error metric.

MAE = mean(|predictions - targets|)

val loss : unit -> metric

loss () tracks the running mean of loss values. Pass batch losses through update ~loss to accumulate them.
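To illustrate what accumulating a running mean across batches involves, here is a minimal standalone sketch of such an accumulator. It is not Kaun's implementation: the running_mean record and its functions are hypothetical, and it assumes an unweighted mean over per-batch values and a result of 0.0 before any update.

```ocaml
(* Standalone sketch of a running-mean accumulator; not Kaun's code. *)
type running_mean = { mutable total : float; mutable count : int }

let create () = { total = 0.0; count = 0 }

(* Fold one batch-level loss value into the accumulator. *)
let update acc loss =
  acc.total <- acc.total +. loss;
  acc.count <- acc.count + 1

(* Mean of all values seen so far.
   Assumption: 0.0 when nothing has been accumulated yet. *)
let compute acc =
  if acc.count = 0 then 0.0 else acc.total /. float_of_int acc.count

(* Clear the state so the accumulator can be reused for a new epoch. *)
let reset acc =
  acc.total <- 0.0;
  acc.count <- 0

let () =
  let acc = create () in
  List.iter (update acc) [ 0.9; 0.7; 0.5 ];
  Printf.printf "mean loss   = %.2f\n" (compute acc);
  reset acc;
  Printf.printf "after reset = %.2f\n" (compute acc)
```

The same update / compute / reset cycle is what the Metric Operations section exposes for every metric in this module.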

val mape : ?eps:float -> unit -> metric

mape ?eps () creates a Mean Absolute Percentage Error metric.

MAPE = mean(|predictions - targets| / (|targets| + eps)) * 100

  • parameter eps

    Small value to avoid division by zero (default: 1e-7)
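The MAPE formula, including the role of eps, can be sketched on plain float arrays. This is a standalone illustration using only the standard library; the function name mape_of_arrays is hypothetical and the input validation is an assumption, not documented Kaun behaviour.

```ocaml
(* Standalone sketch of MAPE on float arrays; not Kaun's implementation.
   MAPE = mean(|p - t| / (|t| + eps)) * 100 *)
let mape_of_arrays ?(eps = 1e-7) predictions targets =
  let n = Array.length targets in
  if n = 0 || n <> Array.length predictions then
    invalid_arg "mape_of_arrays: arrays must be non-empty and of equal length";
  let sum = ref 0.0 in
  Array.iteri
    (fun i t ->
      let p = predictions.(i) in
      (* eps keeps the denominator non-zero when the target is 0 *)
      sum := !sum +. (abs_float (p -. t) /. (abs_float t +. eps)))
    targets;
  100.0 *. !sum /. float_of_int n

let () =
  (* Relative errors of 10% and 5% average to roughly 7.5% *)
  Printf.printf "MAPE = %.2f%%\n"
    (mape_of_arrays [| 110.0; 190.0 |] [| 100.0; 200.0 |])
```

Because eps only guards the division, a target of exactly zero with a non-zero prediction still yields a very large percentage, which is worth keeping in mind when targets can be zero.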

val r2_score : ?adjusted:bool -> ?num_features:int -> unit -> metric

r2_score ?adjusted ?num_features () creates an R² coefficient of determination metric.

R² = 1 - (SS_res / SS_tot)

  • parameter adjusted

    If true, compute adjusted R² (requires num_features)

  • parameter num_features

    Number of features (needed for adjusted R²)
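As a worked illustration of why the adjusted variant needs num_features, the sketch below computes R² and the textbook adjusted R², 1 - (1 - R²)(n - 1)/(n - p - 1) for n samples and p features. It is standalone standard-library code with hypothetical function names; whether Kaun uses exactly this adjustment formula is an assumption.

```ocaml
(* Standalone sketch of R² and adjusted R²; not Kaun's implementation. *)
let mean xs = Array.fold_left ( +. ) 0.0 xs /. float_of_int (Array.length xs)

(* R² = 1 - SS_res / SS_tot *)
let r2 predictions targets =
  let m = mean targets in
  let ss_res = ref 0.0 and ss_tot = ref 0.0 in
  Array.iteri
    (fun i t ->
      ss_res := !ss_res +. ((t -. predictions.(i)) ** 2.0);
      ss_tot := !ss_tot +. ((t -. m) ** 2.0))
    targets;
  1.0 -. (!ss_res /. !ss_tot)

(* Textbook adjusted R² = 1 - (1 - R²)(n - 1)/(n - p - 1),
   with n samples and p = num_features. *)
let adjusted_r2 ~num_features predictions targets =
  let n = float_of_int (Array.length targets) in
  let p = float_of_int num_features in
  1.0 -. ((1.0 -. r2 predictions targets) *. (n -. 1.0) /. (n -. p -. 1.0))

let () =
  let targets = [| 1.0; 2.0; 3.0; 4.0; 5.0 |] in
  let predictions = [| 1.1; 1.9; 3.2; 3.8; 5.0 |] in
  Printf.printf "R2          = %.4f\n" (r2 predictions targets);
  Printf.printf "adjusted R2 = %.4f\n"
    (adjusted_r2 ~num_features:1 predictions targets)
```

In this example SS_res is 0.10 and SS_tot is 10, so R² is 0.99, and the adjustment for one feature over five samples lowers it slightly.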

val explained_variance : unit -> metric

explained_variance () creates an explained variance metric.

EV = 1 - Var(targets - predictions) / Var(targets)

Probabilistic Metrics

val cross_entropy : ?from_logits:bool -> unit -> metric

cross_entropy ?from_logits () creates a cross-entropy metric.

  • parameter from_logits

    If true, apply softmax to predictions (default: true)

val binary_cross_entropy : ?from_logits:bool -> unit -> metric

binary_cross_entropy ?from_logits () creates a binary cross-entropy metric.

  • parameter from_logits

    If true, apply sigmoid to predictions (default: true)

val kl_divergence : ?eps:float -> unit -> metric

kl_divergence ?eps () creates a Kullback–Leibler divergence metric.

KL(P||Q) = Σ P log(P / Q)

  • parameter eps

    Small value for numerical stability (default: 1e-7)
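The KL formula can be sketched on two discrete distributions stored as float arrays. This standalone example uses only the standard library; the function name kl_of_arrays is hypothetical, and clamping both probabilities to at least eps is one plausible way to apply the stability term, assumed here rather than taken from Kaun.

```ocaml
(* Standalone sketch of KL(P||Q) = sum_i P_i * log(P_i / Q_i).
   Assumption: eps is applied by clamping each probability from below. *)
let kl_of_arrays ?(eps = 1e-7) p q =
  if Array.length p <> Array.length q then
    invalid_arg "kl_of_arrays: length mismatch";
  let clamp x = max x eps in
  let acc = ref 0.0 in
  Array.iteri
    (fun i pi ->
      let pi = clamp pi and qi = clamp q.(i) in
      acc := !acc +. (pi *. log (pi /. qi)))
    p;
  !acc

let () =
  let p = [| 0.5; 0.5 |] and q = [| 0.25; 0.75 |] in
  Printf.printf "KL(P||Q) = %.4f nats\n" (kl_of_arrays p q);
  Printf.printf "KL(P||P) = %.4f nats\n" (kl_of_arrays p p)
```

The divergence is zero when the two distributions coincide and is not symmetric in general, so the argument order matters.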

val perplexity : ?base:float -> unit -> metric

perplexity ?base () creates a perplexity metric for language models.

Perplexity = base^(cross_entropy)

  • parameter base

    Base for exponentiation (default: e)
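The relation between perplexity and cross-entropy can be sketched as follows, assuming the cross-entropy is measured with logarithms in the same base that is later used for exponentiation. The code is standalone and does not call Kaun; mean_nll and perplexity_of_probs are hypothetical helpers that take the probability the model assigned to each true token.

```ocaml
(* Standalone sketch: perplexity = base ^ cross_entropy.
   [probs] holds the model probability of each true token. *)

(* Mean negative log-likelihood, with logarithms taken in [base]. *)
let mean_nll ?(base = exp 1.0) probs =
  let n = float_of_int (Array.length probs) in
  let nll =
    Array.fold_left (fun acc p -> acc -. (log p /. log base)) 0.0 probs
  in
  nll /. n

let perplexity_of_probs ?(base = exp 1.0) probs =
  base ** mean_nll ~base probs

let () =
  (* A model that is uniformly unsure among 4 tokens has perplexity 4. *)
  let probs = [| 0.25; 0.25; 0.25 |] in
  Printf.printf "base e: %.4f\n" (perplexity_of_probs probs);
  Printf.printf "base 2: %.4f\n" (perplexity_of_probs ~base:2.0 probs)
```

When the logarithm base and the exponentiation base agree, the result does not depend on the base; a mismatch between the two would change the value.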

Ranking Metrics

val ndcg : ?k:int -> unit -> metric

ndcg ?k () creates a Normalised Discounted Cumulative Gain metric.

  • parameter k

    Consider only the top-k ranked items (default: all)

val map : ?k:int -> unit -> metric

map ?k () creates a Mean Average Precision metric for ranking.

val mrr : ?k:int -> unit -> metric

mrr ?k () creates a Mean Reciprocal Rank metric.

MRR = mean(1 / rank_of_first_relevant_item)

  • parameter k

    Consider only top-k items when computing the reciprocal rank (default: all)
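The MRR formula and the effect of the k cutoff can be sketched on plain lists. This is a standalone illustration with hypothetical names (reciprocal_rank, mrr_of_queries); it represents each query as a list of relevance flags ordered from best-ranked to worst, and assumes a query with no relevant item inside the cutoff contributes 0.

```ocaml
(* Standalone sketch of Mean Reciprocal Rank; not Kaun's implementation.
   Each query is a list of relevance flags, best-ranked item first. *)

(* 1 / rank of the first relevant item, or 0.0 if none is found
   within the first [k] positions (assumption for the cutoff case). *)
let reciprocal_rank ?k relevance =
  let limit = match k with Some k -> k | None -> List.length relevance in
  let rec go rank = function
    | [] -> 0.0
    | _ when rank > limit -> 0.0
    | true :: _ -> 1.0 /. float_of_int rank
    | false :: rest -> go (rank + 1) rest
  in
  go 1 relevance

let mrr_of_queries ?k queries =
  let total =
    List.fold_left (fun acc q -> acc +. reciprocal_rank ?k q) 0.0 queries
  in
  total /. float_of_int (List.length queries)

let () =
  let queries =
    [ [ false; true; false ]; [ true; false ]; [ false; false; false; true ] ]
  in
  (* Reciprocal ranks are 1/2, 1 and 1/4 *)
  Printf.printf "MRR     = %.4f\n" (mrr_of_queries queries);
  (* With k = 3 the third query has no relevant item in range *)
  Printf.printf "MRR@3   = %.4f\n" (mrr_of_queries ~k:3 queries)
```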

Natural Language Metrics

val bleu : ?max_n:int -> ?weights:float array -> ?smoothing:bool -> unit -> metric

bleu ?max_n ?weights ?smoothing () creates a BLEU score metric for pre-tokenized integer sequences.

  • parameter max_n

    Maximum n-gram order (default: 4)

  • parameter weights

    Weights for each n-gram order (default: uniform)

  • parameter smoothing

    Apply smoothing for zero counts (default: true)

Predictions and targets must be shaped [batch, seq_len] with integer token identifiers. Zero values are treated as padding and ignored.

val rouge : variant:[ `Rouge1 | `Rouge2 | `RougeL ] -> ?use_stemmer:bool -> unit -> metric

rouge ~variant ?use_stemmer () creates a ROUGE score metric for pre-tokenized integer sequences.

  • parameter variant

    Which ROUGE variant to compute

  • parameter use_stemmer

    Enable stemming (currently unsupported; raises when set)

Predictions and targets must be shaped [batch, seq_len] with integer token identifiers. Zero values are treated as padding and ignored.

val meteor : ?alpha:float -> ?beta:float -> ?gamma:float -> unit -> metric

meteor ?alpha ?beta ?gamma () creates a METEOR score metric for pre-tokenized integer sequences.

  • parameter alpha

    Parameter controlling precision vs recall balance (default: 0.9)

  • parameter beta

    Exponent for the chunk penalty (default: 3.0)

  • parameter gamma

    Weight of the chunk penalty (default: 0.5)

Predictions and targets must be shaped [batch, seq_len] with integer token identifiers. Zero values are treated as padding and ignored.

Image Metrics

val psnr : ?max_val:float -> unit -> metric

psnr ?max_val () creates a Peak Signal-to-Noise Ratio metric.

PSNR = 10 * log10(max_val² / MSE)

  • parameter max_val

    Maximum possible pixel value (default: 1.0)
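A short standalone sketch shows how max_val enters the PSNR formula. It works on flat float arrays with the standard library only; mse_of_arrays and psnr_of_arrays are hypothetical names, and returning infinity for identical inputs is simply what IEEE floating-point division produces here, not a statement about Kaun's behaviour.

```ocaml
(* Standalone sketch of PSNR = 10 * log10(max_val^2 / MSE). *)
let mse_of_arrays a b =
  let n = Array.length a in
  if n = 0 || n <> Array.length b then
    invalid_arg "mse_of_arrays: arrays must be non-empty and of equal length";
  let sum = ref 0.0 in
  Array.iteri (fun i x -> sum := !sum +. ((x -. b.(i)) ** 2.0)) a;
  !sum /. float_of_int n

let psnr_of_arrays ?(max_val = 1.0) predictions targets =
  (* Identical inputs give MSE = 0 and therefore infinity *)
  10.0 *. log10 ((max_val ** 2.0) /. mse_of_arrays predictions targets)

let () =
  (* Pixels in [0, 1]: MSE = 0.01, so PSNR is about 20 dB *)
  Printf.printf "PSNR = %.2f dB\n"
    (psnr_of_arrays [| 0.5; 0.5 |] [| 0.4; 0.6 |]);
  (* The same relative error on a 0..255 scale gives the same PSNR *)
  Printf.printf "PSNR = %.2f dB\n"
    (psnr_of_arrays ~max_val:255.0 [| 127.5; 127.5 |] [| 102.0; 153.0 |])
```

Setting max_val to match the pixel range (1.0 for normalised images, 255.0 for 8-bit data) keeps scores comparable across representations.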

val ssim : ?window_size:int -> ?k1:float -> ?k2:float -> unit -> metric

ssim ?window_size ?k1 ?k2 () creates a Structural Similarity Index metric.

The implementation evaluates the global SSIM across the full prediction and target tensors using scalar statistics derived from window_size, k1, and k2.

val iou : ?threshold:float -> ?per_class:bool -> num_classes:int -> unit -> metric

iou ?threshold ?per_class ~num_classes () creates an Intersection over Union metric.

Inputs must contain integer class indices in [0, num_classes). When num_classes = 2, threshold binarises predictions before computing IoU. When per_class = true, the metric reports one IoU per class; otherwise it returns the mean over classes with non-zero support.

val dice : ?threshold:float -> ?per_class:bool -> num_classes:int -> unit -> metric

dice ?threshold ?per_class ~num_classes () creates a Sørensen–Dice coefficient metric with the same input conventions as iou.

Metric Operations

val update : metric -> predictions:(float, 'layout) Rune.t -> targets:(_, 'layout) Rune.t -> ?loss:(float, 'layout) Rune.t -> ?weights:(float, 'layout) Rune.t -> unit -> unit

update metric ~predictions ~targets ?loss ?weights () updates the metric state. All tensors must share the same (hidden) layout. When supplied, the loss tensor is treated as an auxiliary scalar for metrics that track losses.

val compute : metric -> float

compute metric returns the aggregated metric value as a host float.

val compute_tensor : metric -> Ptree.tensor

compute_tensor metric returns the aggregated metric value as a device tensor.

val reset : metric -> unit

reset metric clears internal accumulators for a fresh run.

val clone : metric -> metric

clone metric creates a new metric with the same configuration but fresh state.

val name : metric -> string

name metric returns the metric's descriptive name.

val create_custom :
  dtype:(float, 'layout) Rune.dtype ->
  name:string ->
  init:(unit -> (float, 'layout) Rune.t list) ->
  update:
    ((float, 'layout) Rune.t list ->
    predictions:(float, 'layout) Rune.t ->
    targets:(float, 'layout) Rune.t ->
    ?weights:(float, 'layout) Rune.t ->
    unit ->
    (float, 'layout) Rune.t list) ->
  compute:((float, 'layout) Rune.t list -> (float, 'layout) Rune.t) ->
  reset:((float, 'layout) Rune.t list -> (float, 'layout) Rune.t list) ->
  metric

create_custom ~dtype ~name ~init ~update ~compute ~reset constructs a custom metric from user-provided accumulator functions.

val is_better : metric -> higher_better:bool -> old_val:float -> new_val:float -> bool

is_better metric ~higher_better ~old_val ~new_val determines whether the new metric value improves upon the previous one.

val format : metric -> float -> string

format metric value pretty-prints a metric value for logging.

Metric Collections

module Collection : sig ... end