calibrax.metrics.stateful

Base classes for stateful metrics (Tier 1 and Tier 2). FrozenBackboneMetric provides the accumulate-then-compute pattern for metrics requiring a pre-trained feature extractor (e.g., FID, BERTScore). LearnedMetric extends it with trainable calibration weights (e.g., LPIPS).

Stateful metrics: Tier 1 (frozen backbone) and Tier 2 (learned) base classes.

FrozenBackboneMetric(name: str)

Bases: MetricPlotMixin, ABC

Base class for Tier 1 metrics with frozen pretrained backbones.

Implements the StatefulMetricProtocol lifecycle:

  • update(**kwargs): Extract features from the backbone and accumulate statistics
  • compute(): Produce the final metric from accumulated statistics
  • reset(): Clear accumulated state

Subclasses must implement:

  • _extract_features(**kwargs): Run the backbone on an input batch
  • _accumulate(features): Update running statistics with new features
  • _compute_from_accumulated(): Produce the final result from the statistics

The backbone model is frozen (no gradient updates). Subclasses load pretrained weights in __init__.
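The lifecycle and the three subclass hooks described above can be sketched without the library. The block below is a minimal illustration, not calibrax's implementation: `FrozenBackboneSketch` is a hypothetical stand-in base class, and `MeanLength` is a toy metric whose "backbone" is simply `len()`. It assumes that `update()` delegates to `_extract_features()` followed by `_accumulate()`, and that `compute()` delegates to `_compute_from_accumulated()`.

```python
from abc import ABC, abstractmethod
from typing import Any


class FrozenBackboneSketch(ABC):
    """Stand-in for illustration only; not calibrax's implementation."""

    def __init__(self, name: str) -> None:
        self._name = name

    @property
    def name(self) -> str:
        return self._name

    def update(self, **kwargs: Any) -> None:
        # Extract features for this batch, then fold them into running state.
        self._accumulate(self._extract_features(**kwargs))

    def compute(self) -> dict[str, float]:
        return self._compute_from_accumulated()

    @abstractmethod
    def reset(self) -> None: ...

    @abstractmethod
    def _extract_features(self, **kwargs: Any) -> Any: ...

    @abstractmethod
    def _accumulate(self, features: Any) -> None: ...

    @abstractmethod
    def _compute_from_accumulated(self) -> dict[str, float]: ...


class MeanLength(FrozenBackboneSketch):
    """Toy metric: the "backbone" is len() applied to each string."""

    def __init__(self) -> None:
        super().__init__(name="mean_length")
        self.reset()

    def reset(self) -> None:
        self._total = 0
        self._count = 0

    def _extract_features(self, **kwargs: Any) -> list[int]:
        return [len(text) for text in kwargs["texts"]]

    def _accumulate(self, features: list[int]) -> None:
        self._total += sum(features)
        self._count += len(features)

    def _compute_from_accumulated(self) -> dict[str, float]:
        return {self.name: self._total / self._count}


metric = MeanLength()
metric.update(texts=["ab", "abcd"])   # first batch
metric.update(texts=["abcdef"])       # second batch
result = metric.compute()             # mean over all three strings
metric.reset()                        # ready for a new evaluation
```

Keeping only a running total and count, rather than the raw features, means memory use stays constant no matter how many batches are passed to `update()`.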

Parameters:

Name  Type  Description                Default
name  str   Unique metric identifier.  required

Examples:

>>> from calibrax.metrics.stateful import FrozenBackboneMetric
>>> class MyMetric(FrozenBackboneMetric):
...     def __init__(self):
...         super().__init__(name="my_metric")
...         self._values = []  # one feature mean per batch
...     def reset(self):
...         self._values = []
...     def _extract_features(self, **kwargs):
...         return kwargs["data"]  # identity "backbone" for illustration
...     def _accumulate(self, features):
...         self._values.append(float(features.mean()))
...     def _compute_from_accumulated(self):
...         # Requires at least one prior update() call.
...         return {"my_metric": sum(self._values) / len(self._values)}

Initialize the frozen backbone metric.

Parameters:

Name  Type  Description          Default
name  str   Unique metric name.  required

name: str property

Get the metric name.

update(**kwargs: Any) -> None

Extract features and accumulate statistics.

Parameters:

Name      Type  Description                       Default
**kwargs  Any   Batch data (images, text, etc.).  {}
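One way to read "accumulate statistics" is to keep sufficient statistics that are independent of how the data is split into batches. The sketch below is an assumption about how a subclass might do this, not calibrax code: `RunningMoments` is a hypothetical helper that tracks a count, a sum, and a sum of squares, so the mean and population variance come out the same for any batching.

```python
import statistics


class RunningMoments:
    """Hypothetical accumulator: keeps count, sum, and sum of squares."""

    def __init__(self) -> None:
        self.count = 0
        self.total = 0.0
        self.total_sq = 0.0

    def accumulate(self, features: list[float]) -> None:
        self.count += len(features)
        self.total += sum(features)
        self.total_sq += sum(x * x for x in features)

    def mean(self) -> float:
        return self.total / self.count

    def variance(self) -> float:
        # Population variance via E[x^2] - (E[x])^2.
        m = self.mean()
        return self.total_sq / self.count - m * m


# Batches of unequal size, as often happens with a final partial batch.
batches = [[1.0, 2.0, 3.0], [4.0], [5.0, 6.0]]
moments = RunningMoments()
for batch in batches:
    moments.accumulate(batch)

# Reference values computed over the flattened data.
flat = [x for batch in batches for x in batch]
reference_mean = statistics.fmean(flat)
reference_var = statistics.pvariance(flat)

# Averaging per-batch means weights every batch equally, so it can
# disagree with the true mean when batch sizes differ.
mean_of_batch_means = statistics.fmean(statistics.fmean(b) for b in batches)
```

The sum-of-squares form is short but can lose precision when values are large relative to their spread; a Welford-style update is a common alternative in that case.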

compute() -> dict[str, float]

Compute final metric from accumulated statistics.

Returns:

Type              Description
dict[str, float]  Dictionary mapping metric names to values.

reset() -> None abstractmethod

Reset accumulated state for a new evaluation.
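Because `update`, `compute`, and `reset` form a fixed lifecycle, evaluation code can be written against that shape alone. The sketch below is illustrative only: `StatefulMetricLike` is a hypothetical structural type that mirrors the three method signatures documented here (it is not calibrax's `StatefulMetricProtocol`), and `CountSeen` and `evaluate` are made-up names.

```python
from typing import Any, Iterable, Protocol, runtime_checkable


@runtime_checkable
class StatefulMetricLike(Protocol):
    """Hypothetical structural type mirroring update/compute/reset."""

    def update(self, **kwargs: Any) -> None: ...
    def compute(self) -> dict[str, float]: ...
    def reset(self) -> None: ...


class CountSeen:
    """Toy metric that counts how many items it has been shown."""

    def __init__(self) -> None:
        self._seen = 0

    def update(self, **kwargs: Any) -> None:
        self._seen += len(kwargs["data"])

    def compute(self) -> dict[str, float]:
        return {"seen": float(self._seen)}

    def reset(self) -> None:
        self._seen = 0


def evaluate(metric: StatefulMetricLike, batches: Iterable[list]) -> dict[str, float]:
    # Reset first so state from an earlier evaluation cannot leak in.
    metric.reset()
    for batch in batches:
        metric.update(data=batch)
    return metric.compute()


metric = CountSeen()
first = evaluate(metric, [[1, 2], [3]])
second = evaluate(metric, [[4]])
```

Calling `reset()` at the start of each run is what lets one metric instance be reused across evaluations without the second result including items from the first.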

LearnedMetric(name: str, *, rngs: nnx.Rngs)

Bases: MetricPlotMixin, Module

Base class for Tier 2 metrics with trainable calibration layers.

Extends nnx.Module for JAX transform compatibility (jit, grad, vmap). Inherits train()/eval() mode switching from nnx.Module.

Subclasses should implement update/compute/reset following the StatefulMetricProtocol pattern, but with trainable parameters that can be optimized.

Parameters:

Name  Type  Description                                Default
name  str   Metric name identifier.                    required
rngs  Rngs  RNG streams for parameter initialization.  required

Examples:

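>>> from flax import nnx
>>> from calibrax.metrics.stateful import LearnedMetric
>>> class MyLearnedMetric(LearnedMetric):
...     def __init__(self, *, rngs):
...         super().__init__(name="my_metric", rngs=rngs)
...         # Trainable calibration layer: 4 input features -> 1 score.
...         self._linear = nnx.Linear(4, 1, rngs=rngs)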

Initialize learned metric.

Parameters:

Name  Type  Description                                Default
name  str   Metric name.                               required
rngs  Rngs  RNG streams for parameter initialization.  required

name: str property

Get the metric name.

Plugin Implementations

Concrete implementations live in calibrax.metrics.plugins:

  • Image (FID, InceptionScore, LPIPS): uv pip install "calibrax[image]"
  • Text (BERTScore): uv pip install "calibrax[text]"

See Image Metrics and Text Metrics for details.