calibrax.metrics.stateful
Base classes for stateful metrics (Tier 1 and Tier 2). FrozenBackboneMetric
provides the accumulate-then-compute pattern for metrics requiring a
pre-trained feature extractor (e.g., FID, BERTScore). LearnedMetric
extends it with trainable calibration weights (e.g., LPIPS).
Stateful metrics: Tier 1 (frozen backbone) and Tier 2 (learned) base classes.
FrozenBackboneMetric(name: str)
Bases: MetricPlotMixin, ABC
Base class for Tier 1 metrics with frozen pretrained backbones.
Implements the StatefulMetricProtocol lifecycle:
- update(**kwargs): Extract features from backbone, accumulate statistics
- compute(): Produce final metric from accumulated statistics
- reset(): Clear accumulated state
Subclasses must implement:
- _extract_features(**kwargs): Run backbone on input batch
- _accumulate(features): Update running statistics with new features
- _compute_from_accumulated(): Produce final result from statistics
The backbone model is frozen (no gradient updates). Subclasses load pretrained weights in __init__.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `name` | `str` | Unique metric identifier. | required |
Examples:
>>> class MyMetric(FrozenBackboneMetric):
...     def __init__(self):
...         super().__init__(name="my_metric")
...         self._values = []
...     def reset(self):
...         self._values = []
...     def _extract_features(self, **kwargs):
...         return kwargs["data"]
...     def _accumulate(self, features):
...         self._values.append(float(features.mean()))
...     def _compute_from_accumulated(self):
...         return {"my_metric": sum(self._values) / len(self._values)}
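The lifecycle described above can be tried without calibrax installed. The following is a minimal sketch, not the library's implementation: `SketchFrozenBackboneMetric` is a hypothetical stand-in, and the way `update()` and `compute()` delegate to the three hooks is an assumption based on the method descriptions on this page. The "backbone" is the identity function and batches are plain Python lists.

```python
from abc import ABC, abstractmethod
from typing import Any


class SketchFrozenBackboneMetric(ABC):
    """Hypothetical stand-in for FrozenBackboneMetric (not the calibrax class)."""

    def __init__(self, name: str) -> None:
        self._name = name

    @property
    def name(self) -> str:
        return self._name

    def update(self, **kwargs: Any) -> None:
        # Assumed wiring: extract features, then fold them into running state.
        self._accumulate(self._extract_features(**kwargs))

    def compute(self) -> dict[str, float]:
        # Assumed wiring: delegate to the subclass hook.
        return self._compute_from_accumulated()

    @abstractmethod
    def reset(self) -> None: ...

    @abstractmethod
    def _extract_features(self, **kwargs: Any) -> Any: ...

    @abstractmethod
    def _accumulate(self, features: Any) -> None: ...

    @abstractmethod
    def _compute_from_accumulated(self) -> dict[str, float]: ...


class MeanFeatureMetric(SketchFrozenBackboneMetric):
    """Averages per-batch feature means; the 'backbone' is the identity."""

    def __init__(self) -> None:
        super().__init__(name="mean_feature")
        self._values: list[float] = []

    def reset(self) -> None:
        self._values = []

    def _extract_features(self, **kwargs: Any) -> list[float]:
        return list(kwargs["data"])

    def _accumulate(self, features: list[float]) -> None:
        self._values.append(sum(features) / len(features))

    def _compute_from_accumulated(self) -> dict[str, float]:
        return {self.name: sum(self._values) / len(self._values)}


metric = MeanFeatureMetric()
for batch in ([1.0, 3.0], [5.0, 7.0]):
    metric.update(data=batch)
result = metric.compute()  # batch means 2.0 and 6.0 average to 4.0
metric.reset()             # clear state before the next evaluation run
```

Keeping only the per-batch means, rather than the raw features, is what lets the accumulate-then-compute pattern run over datasets that do not fit in memory. A real Tier 1 metric such as FID would accumulate richer statistics in `_accumulate`.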
Initialize the frozen backbone metric.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `name` | `str` | Unique metric name. | required |
name: str
property
Get the metric name.
update(**kwargs: Any) -> None
Extract features and accumulate statistics.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `**kwargs` | `Any` | Batch data (images, text, etc.). | `{}` |
compute() -> dict[str, float]
Compute final metric from accumulated statistics.
Returns:
| Type | Description |
|---|---|
| `dict[str, float]` | Dictionary mapping metric names to values. |
reset() -> None
abstractmethod
Reset accumulated state for a new evaluation.
LearnedMetric(name: str, *, rngs: nnx.Rngs)
Bases: MetricPlotMixin, Module
Base class for Tier 2 metrics with trainable calibration layers.
Extends nnx.Module for JAX transform compatibility (jit, grad, vmap). Inherits train()/eval() mode switching from nnx.Module.
Subclasses should implement update/compute/reset following the StatefulMetricProtocol pattern, but with trainable parameters that can be optimized.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `name` | `str` | Metric name identifier. | required |
| `rngs` | `Rngs` | RNG streams for parameter initialization. | required |
Examples:
>>> class MyLearnedMetric(LearnedMetric):
...     def __init__(self, *, rngs):
...         super().__init__(name="my_metric", rngs=rngs)
...         self._linear = nnx.Linear(4, 1, rngs=rngs)
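To illustrate what "trainable parameters that can be optimized" means for a Tier 2 metric, here is a dependency-free analogue. It is a sketch under loud assumptions: `SketchLearnedDistance` is hypothetical, plain Python floats stand in for nnx parameters, and the gradient is written by hand rather than obtained through JAX. The score is a weighted sum of per-layer feature distances, loosely in the spirit of LPIPS-style calibration.

```python
class SketchLearnedDistance:
    """Hypothetical analogue of a Tier 2 metric (not a calibrax or nnx class)."""

    def __init__(self, num_layers: int) -> None:
        self.weights = [1.0] * num_layers  # trainable calibration weights
        self._scores: list[float] = []     # accumulated evaluation state

    def score(self, distances: list[float]) -> float:
        # Weighted sum of per-layer feature distances.
        return sum(w * d for w, d in zip(self.weights, distances))

    def train_step(self, distances: list[float], target: float, lr: float = 0.1) -> float:
        # One gradient-descent step on the squared error (score - target) ** 2.
        # d(loss)/d(w_i) = 2 * error * distances[i], derived by hand.
        error = self.score(distances) - target
        self.weights = [
            w - lr * 2.0 * error * d for w, d in zip(self.weights, distances)
        ]
        return error * error  # loss measured before the weights were updated

    def update(self, distances: list[float]) -> None:
        self._scores.append(self.score(distances))

    def compute(self) -> dict[str, float]:
        return {"learned_distance": sum(self._scores) / len(self._scores)}

    def reset(self) -> None:
        self._scores = []


learned = SketchLearnedDistance(num_layers=2)
sample, target = [1.0, 0.5], 0.5  # made-up distances and reference score

# Calibration phase: fit the weights to the reference score.
losses = [learned.train_step(sample, target) for _ in range(20)]

# Evaluation phase: the usual update/compute/reset lifecycle.
learned.update(sample)
calibrated = learned.compute()
learned.reset()
```

The split between a calibration phase that changes the weights and an evaluation phase that only accumulates scores mirrors the train()/eval() distinction mentioned above. In an actual LearnedMetric subclass the weights would live in nnx layers and be updated by an optimizer.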
Initialize learned metric.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `name` | `str` | Metric name. | required |
| `rngs` | `Rngs` | RNG streams for parameter initialization. | required |
name: str
property
Get the metric name.
Plugin Implementations
Concrete implementations live in calibrax.metrics.plugins:
- Image (FID, InceptionScore, LPIPS): uv pip install "calibrax[image]"
- Text (BERTScore): uv pip install "calibrax[text]"
See Image Metrics and Text Metrics for details.