calibrax.core.protocols

Runtime-checkable protocol definitions for the benchmarks, datasets, and metrics of the calibrax benchmarking framework.

The protocols rely on Python's structural subtyping: any class with matching methods satisfies a protocol without inheriting from it. All protocols are runtime_checkable, so isinstance() can test conformance at runtime. That check confirms only that the required members exist, not that their signatures match.
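The structural check described above can be illustrated with a small sketch. HasLifecycle and the three classes below are hypothetical and are not part of calibrax; they only show how isinstance() behaves with a runtime_checkable protocol from the standard typing module.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class HasLifecycle(Protocol):
    """Hypothetical two-method protocol, not part of calibrax."""

    def setup(self) -> None: ...
    def teardown(self) -> None: ...


class Conforming:
    """Has both methods but does not inherit from HasLifecycle."""

    def setup(self) -> None:
        pass

    def teardown(self) -> None:
        pass


class WrongSignature:
    """Same method names, different signature; isinstance() still passes."""

    def setup(self, extra: int) -> None:
        pass

    def teardown(self) -> None:
        pass


class MissingMethod:
    """Lacks teardown(), so the structural check fails."""

    def setup(self) -> None:
        pass


checks = {
    "Conforming": isinstance(Conforming(), HasLifecycle),
    "WrongSignature": isinstance(WrongSignature(), HasLifecycle),
    "MissingMethod": isinstance(MissingMethod(), HasLifecycle),
}
```

Because the runtime check looks only at member names, a static type checker is still needed to catch a mismatch such as the one in WrongSignature.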

BenchmarkProtocol

Bases: Protocol

Standard interface for benchmarks.

Defines the lifecycle of a benchmark: setup, training, evaluation, teardown, and performance target retrieval.

setup()

Set up benchmark resources before execution.

run_training()

Execute the training phase.

Returns:

| Type | Description |
| --- | --- |
| dict[str, float] | Dictionary of training metric name-value pairs. |

run_evaluation()

Execute the evaluation phase.

Returns:

| Type | Description |
| --- | --- |
| dict[str, float] | Dictionary of evaluation metric name-value pairs. |

teardown()

Release benchmark resources after execution.

get_performance_targets()

Return expected performance targets.

Returns:

| Type | Description |
| --- | --- |
| dict[str, float] | Dictionary mapping metric names to target values. |
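As a minimal sketch of the lifecycle, the following defines a local stand-in protocol that mirrors the five documented methods, a toy benchmark that satisfies it, and a driver function. BenchmarkLike, ToyBenchmark, and run_benchmark are hypothetical names. The None return type on setup() and teardown(), the call order, and the "higher is better" comparison against targets are all assumptions, not documented calibrax behavior.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class BenchmarkLike(Protocol):
    """Local stand-in mirroring the documented BenchmarkProtocol methods."""

    def setup(self) -> None: ...
    def run_training(self) -> dict[str, float]: ...
    def run_evaluation(self) -> dict[str, float]: ...
    def teardown(self) -> None: ...
    def get_performance_targets(self) -> dict[str, float]: ...


class ToyBenchmark:
    """Hypothetical benchmark that returns fixed numbers instead of training."""

    def __init__(self) -> None:
        self.ready = False

    def setup(self) -> None:
        self.ready = True

    def run_training(self) -> dict[str, float]:
        return {"train_loss": 0.25}

    def run_evaluation(self) -> dict[str, float]:
        return {"accuracy": 0.91}

    def teardown(self) -> None:
        self.ready = False

    def get_performance_targets(self) -> dict[str, float]:
        return {"accuracy": 0.90}


def run_benchmark(benchmark: BenchmarkLike) -> dict[str, bool]:
    """Drive the lifecycle and report which targets were met."""
    benchmark.setup()
    try:
        benchmark.run_training()
        results = benchmark.run_evaluation()
    finally:
        # Release resources even if training or evaluation raises.
        benchmark.teardown()
    targets = benchmark.get_performance_targets()
    # Assumption: higher is better, and a metric missing from the results fails.
    return {
        name: name in results and results[name] >= target
        for name, target in targets.items()
    }


outcome = run_benchmark(ToyBenchmark())
```

Placing teardown() in a finally block is one reasonable way to guarantee cleanup. The protocol itself does not prescribe how a caller sequences the methods.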

DatasetProtocol

Bases: Protocol

Interface for datasets used in benchmarks.

__len__()

Get the number of examples in the dataset.

__getitem__(idx)

Get an example by index.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| idx | int | Index of the example. | required |

Returns:

| Type | Description |
| --- | --- |
| Any | The example at the given index. |
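A minimal sketch of a class that satisfies this interface, assuming a map-style dataset whose examples are computed on demand. DatasetLike is a local stand-in for the documented protocol, and SquaresDataset and collect are hypothetical names. Raising IndexError for an out-of-range index is an assumption borrowed from Python's sequence convention.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class DatasetLike(Protocol):
    """Local stand-in mirroring the documented DatasetProtocol methods."""

    def __len__(self) -> int: ...
    def __getitem__(self, idx: int) -> Any: ...


class SquaresDataset:
    """Hypothetical dataset: example i is the pair (i, i * i)."""

    def __init__(self, size: int) -> None:
        self._size = size

    def __len__(self) -> int:
        return self._size

    def __getitem__(self, idx: int) -> tuple[int, int]:
        if not 0 <= idx < self._size:
            raise IndexError(f"index {idx} out of range for size {self._size}")
        return idx, idx * idx


def collect(dataset: DatasetLike) -> list[Any]:
    """Read every example using only the two protocol methods."""
    return [dataset[i] for i in range(len(dataset))]


examples = collect(SquaresDataset(3))
```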

BatchableDatasetProtocol

Bases: Protocol

Interface for datasets that support batch retrieval.

Declares the same __len__ and __getitem__ members as DatasetProtocol and adds get_batch for batch retrieval.

__len__()

Get the number of examples in the dataset.

__getitem__(idx)

Get an example by index.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| idx | int | Index of the example. | required |

Returns:

| Type | Description |
| --- | --- |
| Any | The example at the given index. |

get_batch(batch_size, start_idx)

Get a batch of data starting at the given index.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| batch_size | int | Number of examples in the batch. | required |
| start_idx | int | Starting index for the batch. | required |

Returns:

| Type | Description |
| --- | --- |
| dict[str, Any] | Batch data dictionary. |
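The sketch below shows one way get_batch could be implemented and consumed. RangeDataset and iter_batches are hypothetical names. The documentation does not say what happens when a batch runs past the end of the dataset, so truncating the final batch is an assumption, as are the "x" and "y" keys of the batch dictionary.

```python
from typing import Any, Iterator


class RangeDataset:
    """Hypothetical dataset: example i is {"x": i, "y": 2 * i}."""

    def __init__(self, size: int) -> None:
        self._size = size

    def __len__(self) -> int:
        return self._size

    def __getitem__(self, idx: int) -> dict[str, int]:
        if not 0 <= idx < self._size:
            raise IndexError(f"index {idx} out of range for size {self._size}")
        return {"x": idx, "y": 2 * idx}

    def get_batch(self, batch_size: int, start_idx: int) -> dict[str, Any]:
        # Assumption: a batch that would run past the end is truncated.
        stop = min(start_idx + batch_size, self._size)
        rows = [self[i] for i in range(start_idx, stop)]
        # Collate per-example dicts into one dict of lists.
        return {"x": [r["x"] for r in rows], "y": [r["y"] for r in rows]}


def iter_batches(dataset: RangeDataset, batch_size: int) -> Iterator[dict[str, Any]]:
    """Walk the dataset in consecutive batches using get_batch."""
    for start in range(0, len(dataset), batch_size):
        yield dataset.get_batch(batch_size, start)


batches = list(iter_batches(RangeDataset(5), batch_size=2))
```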

MetricProtocol

Bases: Protocol

Universal metric interface for evaluation.

Supports computing a metric from predictions and targets, with input validation.

name property

Get the metric name.

higher_is_better property

Whether higher values indicate better performance.

compute(predictions, targets)

Compute the metric value.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| predictions | Array | Model predictions. | required |
| targets | Array | Ground truth targets. | required |

Returns:

| Type | Description |
| --- | --- |
| Any | Computed metric value. |

validate_inputs(predictions, targets)

Validate that inputs are compatible for metric computation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| predictions | Array | Model predictions. | required |
| targets | Array | Ground truth targets. | required |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If inputs are incompatible. |
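A minimal sketch of a metric with this shape, using NumPy arrays as a stand-in for the Array type. MeanAbsoluteError and is_improvement are hypothetical names. Treating a shape mismatch as the incompatibility, and having compute() call validate_inputs() first, are assumptions about how an implementation might behave.

```python
import numpy as np


class MeanAbsoluteError:
    """Hypothetical metric exposing the four documented members."""

    @property
    def name(self) -> str:
        return "mae"

    @property
    def higher_is_better(self) -> bool:
        # Lower error is better.
        return False

    def validate_inputs(self, predictions: np.ndarray, targets: np.ndarray) -> None:
        # Assumption: "incompatible" means the shapes differ.
        if predictions.shape != targets.shape:
            raise ValueError(
                f"shape mismatch: predictions {predictions.shape} "
                f"vs targets {targets.shape}"
            )

    def compute(self, predictions: np.ndarray, targets: np.ndarray) -> float:
        self.validate_inputs(predictions, targets)
        return float(np.mean(np.abs(predictions - targets)))


def is_improvement(metric: MeanAbsoluteError, new: float, old: float) -> bool:
    """Compare two values in the direction given by higher_is_better."""
    return new > old if metric.higher_is_better else new < old


mae = MeanAbsoluteError()
value = mae.compute(np.array([1.0, 2.0, 4.0]), np.array([1.0, 2.0, 2.0]))
```

The higher_is_better property lets generic code such as is_improvement rank results without knowing which metric it is handling.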

StatefulMetricProtocol

Bases: Protocol

Interface for stateful metrics with batch accumulation (Tier 1-2).

Follows the update/compute/reset lifecycle pattern from TorchMetrics and Google Metrax. Metrics accumulate statistics across batches via update(), then produce final results via compute().

Tier 1 (FrozenBackboneMetric): frozen pretrained backbone extracts features, accumulates statistics (e.g., FID mean/covariance).

Tier 2 (LearnedMetric): trainable calibration layers on top of backbone features (e.g., LPIPS).

name property

Get the metric name.

update(**kwargs)

Accumulate batch statistics.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| **kwargs | Any | Batch data (e.g., images, features, predictions). | {} |

compute()

Compute final metric values from accumulated statistics.

Returns:

| Type | Description |
| --- | --- |
| dict[str, float] | Dictionary mapping metric names to computed values. |

reset()

Reset accumulated state for a new evaluation run.
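The update/compute/reset lifecycle can be sketched with a running accuracy, which needs no backbone and therefore does not correspond to Tier 1 or Tier 2. RunningAccuracy is a hypothetical name. The predictions and targets keyword names, and raising ValueError when compute() is called before any update(), are assumptions.

```python
from typing import Any


class RunningAccuracy:
    """Hypothetical stateful metric that accumulates counts across batches."""

    def __init__(self) -> None:
        self._correct = 0
        self._total = 0

    @property
    def name(self) -> str:
        return "accuracy"

    def update(self, **kwargs: Any) -> None:
        # Assumption: batches arrive as 'predictions' and 'targets' sequences.
        predictions = kwargs["predictions"]
        targets = kwargs["targets"]
        if len(predictions) != len(targets):
            raise ValueError("predictions and targets differ in length")
        self._correct += sum(int(p == t) for p, t in zip(predictions, targets))
        self._total += len(targets)

    def compute(self) -> dict[str, float]:
        if self._total == 0:
            raise ValueError("compute() called before any update()")
        return {self.name: self._correct / self._total}

    def reset(self) -> None:
        self._correct = 0
        self._total = 0


metric = RunningAccuracy()
metric.update(predictions=[1, 0, 1], targets=[1, 1, 1])  # 2 of 3 correct
metric.update(predictions=[0, 0], targets=[0, 1])  # 1 of 2 correct
final = metric.compute()  # 3 of 5 correct overall
metric.reset()
```

Accumulating counts rather than averaging per-batch accuracies keeps the result correct when batches differ in size.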

MetricLearningProtocol

Bases: Protocol

Interface for metric learning losses (Tier 3).

Calling the loss returns a differentiable JAX array rather than a Python float, so gradients can flow when training embedding spaces. Here the loss function is itself the metric: it learns a distance function via backpropagation.

Examples: ContrastiveLoss, TripletMarginLoss, ArcFaceLoss.

__call__(embeddings, labels)

Compute the metric learning loss.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| embeddings | Array | Batch of embedding vectors, shape (batch_size, embedding_dim). | required |
| labels | Array | Integer class labels, shape (batch_size,). | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Scalar loss value as a JAX array (differentiable). |
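To show why the return type matters, here is a minimal sketch of a callable with this signature, differentiated with jax.value_and_grad. PairwiseContrastiveLoss is a hypothetical, simplified all-pairs contrastive loss written for illustration. It is not the ContrastiveLoss named above, and its margin parameter and averaging scheme are assumptions. It expects a batch of at least two embeddings.

```python
import jax
import jax.numpy as jnp


class PairwiseContrastiveLoss:
    """Hypothetical all-pairs contrastive loss with a distance margin."""

    def __init__(self, margin: float = 1.0) -> None:
        self.margin = margin

    def __call__(self, embeddings: jax.Array, labels: jax.Array) -> jax.Array:
        # Pairwise squared Euclidean distances, shape (batch, batch).
        diffs = embeddings[:, None, :] - embeddings[None, :, :]
        sq_dists = jnp.sum(diffs**2, axis=-1)
        # The epsilon keeps the sqrt gradient finite at zero distance.
        dists = jnp.sqrt(sq_dists + 1e-12)
        same = (labels[:, None] == labels[None, :]).astype(embeddings.dtype)
        # Exclude each embedding's pairing with itself.
        off_diag = 1.0 - jnp.eye(labels.shape[0], dtype=embeddings.dtype)
        # Pull same-label pairs together; push others apart up to the margin.
        positive = same * sq_dists
        negative = (1.0 - same) * jnp.maximum(self.margin - dists, 0.0) ** 2
        return jnp.sum((positive + negative) * off_diag) / jnp.sum(off_diag)


loss_fn = PairwiseContrastiveLoss(margin=1.0)
embeddings = jnp.array([[0.0, 0.0], [0.1, 0.0], [0.5, 0.5], [2.0, 2.0]])
labels = jnp.array([0, 0, 1, 1])

# Differentiate with respect to the embeddings; labels are held fixed.
loss, grads = jax.value_and_grad(lambda e: loss_fn(e, labels))(embeddings)
```

Here loss is a scalar JAX array and grads has the same shape as embeddings. Converting the loss to a Python float inside __call__ would break this, since JAX could no longer trace the computation.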