calibrax.core.protocols
Runtime-checkable protocols for benchmarks, datasets, and metrics in the calibrax benchmarking framework. These use Python's structural subtyping: any class that defines the matching methods satisfies a protocol without needing to inherit from it.
All protocols are decorated with runtime_checkable, so they can be used in isinstance() checks. Such checks only confirm that the required methods and attributes exist. They do not verify signatures or return types.
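The following minimal sketch shows what structural subtyping means in practice. SupportsSetupTeardown is a hypothetical two-method stand-in, not one of the calibrax protocols. The snippet only shows how isinstance() behaves against a runtime_checkable Protocol.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class SupportsSetupTeardown(Protocol):
    """Hypothetical stand-in protocol with two lifecycle methods."""

    def setup(self) -> None: ...

    def teardown(self) -> None: ...


class Conforming:
    """Never inherits from the protocol, but defines both methods."""

    def setup(self) -> None:
        pass

    def teardown(self) -> None:
        pass


class WrongSignature:
    """Defines both names, but with signatures the protocol did not ask for."""

    def setup(self, verbose: bool) -> int:
        return 0

    def teardown(self, force: bool) -> int:
        return 0


class Missing:
    """Defines only setup(), so it does not satisfy the protocol."""

    def setup(self) -> None:
        pass


print(isinstance(Conforming(), SupportsSetupTeardown))      # True
print(isinstance(WrongSignature(), SupportsSetupTeardown))  # True: signatures are not checked
print(isinstance(Missing(), SupportsSetupTeardown))         # False: teardown is absent
```

Because the runtime check is shallow, a static type checker such as mypy or pyright is what actually compares method signatures against a protocol.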
BenchmarkProtocol
Bases: Protocol
Standard interface for benchmarks.
Defines the lifecycle of a benchmark: setup, training, evaluation, teardown, and performance target retrieval.
setup()
Set up benchmark resources before execution.
run_training()
Execute the training phase.
Returns:

| Type | Description |
|---|---|
| dict[str, float] | Dictionary of training metric name-value pairs. |

run_evaluation()
Execute the evaluation phase.
Returns:

| Type | Description |
|---|---|
| dict[str, float] | Dictionary of evaluation metric name-value pairs. |

teardown()
Release benchmark resources after execution.
get_performance_targets()
Return expected performance targets.
Returns:

| Type | Description |
|---|---|
| dict[str, float] | Dictionary mapping metric names to target values. |

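The following minimal sketch shows a class that satisfies BenchmarkProtocol without inheriting from it, plus a hypothetical run_benchmark driver for the lifecycle described above. The protocol is re-declared locally so the snippet runs without calibrax installed. ToyBenchmark, run_benchmark, the metric names, and the rule that a target counts as met when the observed value is at least the target are all illustrative assumptions, not calibrax behavior.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class BenchmarkLike(Protocol):
    """Local stand-in mirroring the documented BenchmarkProtocol methods."""

    def setup(self) -> None: ...
    def run_training(self) -> dict[str, float]: ...
    def run_evaluation(self) -> dict[str, float]: ...
    def teardown(self) -> None: ...
    def get_performance_targets(self) -> dict[str, float]: ...


class ToyBenchmark:
    """Hypothetical benchmark with hard-coded metric values."""

    def __init__(self) -> None:
        self.resources = None  # allocated in setup(), released in teardown()

    def setup(self) -> None:
        self.resources = [0.1, 0.2, 0.3]  # stand-in for data, models, devices

    def run_training(self) -> dict[str, float]:
        return {"train_loss": 0.42}

    def run_evaluation(self) -> dict[str, float]:
        return {"accuracy": 0.91}

    def teardown(self) -> None:
        self.resources = None  # release whatever setup() acquired

    def get_performance_targets(self) -> dict[str, float]:
        return {"accuracy": 0.90}


def run_benchmark(bench: BenchmarkLike) -> dict[str, bool]:
    """Drive the lifecycle; teardown() runs even if a phase raises."""
    bench.setup()
    try:
        results = {**bench.run_training(), **bench.run_evaluation()}
    finally:
        bench.teardown()
    # Assumption: every target means "at least this value"; a missing metric fails.
    return {
        name: results.get(name, float("-inf")) >= target
        for name, target in bench.get_performance_targets().items()
    }


bench = ToyBenchmark()
assert isinstance(bench, BenchmarkLike)
print(run_benchmark(bench))  # {'accuracy': True}
```

The try/finally pairs teardown() with setup() so resources are released on failure. A real driver would also need a per-metric direction for targets where lower is better, which this sketch ignores.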
DatasetProtocol
Bases: Protocol
Interface for datasets used in benchmarks.
BatchableDatasetProtocol
Bases: Protocol
Interface for datasets that support batch retrieval.
Extends DatasetProtocol with get_batch capability.
__len__()
Get the number of examples in the dataset.
__getitem__(idx)
Get an example by index.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| idx | int | Index of the example. | required |

Returns:

| Type | Description |
|---|---|
| Any | The example at the given index. |

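The following minimal sketch covers the __len__/__getitem__ pair documented above. IndexableDataset is a hypothetical local protocol with just those two methods, and SquaresDataset is an illustrative in-memory dataset. The dict layout of each example is an assumption, since the protocol only promises Any.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class IndexableDataset(Protocol):
    """Hypothetical protocol: just the two methods documented above."""

    def __len__(self) -> int: ...
    def __getitem__(self, idx: int) -> Any: ...


class SquaresDataset:
    """Toy dataset whose example i is {"x": i, "y": i * i}."""

    def __init__(self, n: int) -> None:
        self._n = n

    def __len__(self) -> int:
        return self._n

    def __getitem__(self, idx: int) -> Any:
        # Reject out-of-range indices (including negatives) with IndexError,
        # which is also what lets a plain for-loop stop at the end.
        if not 0 <= idx < self._n:
            raise IndexError(f"index {idx} out of range for length {self._n}")
        return {"x": idx, "y": idx * idx}


ds = SquaresDataset(4)
print(len(ds), ds[3])                           # 4 {'x': 3, 'y': 9}
print(isinstance(ds, IndexableDataset))         # True
print(isinstance([1, 2, 3], IndexableDataset))  # True: a list has both methods
```

The last line shows the flip side of structural subtyping: any built-in sequence passes this check, so the protocol alone says nothing about what an example contains.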
get_batch(batch_size, start_idx)
Get a batch of data starting at the given index.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch_size | int | Number of examples in the batch. | required |
| start_idx | int | Starting index for the batch. | required |

Returns:

| Type | Description |
|---|---|
| dict[str, Any] | Batch data dictionary. |

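The following minimal sketch implements get_batch on top of the same indexing methods. Several details are assumptions for illustration: the column-oriented batch layout ({"x": [...], "y": [...]}), truncating the final batch instead of padding it, and the iterate_batches helper. The protocol itself only fixes the signature and the dict[str, Any] return type.

```python
from typing import Any, Iterator


class BatchableSquares:
    """Toy dataset whose example i is {"x": i, "y": i * i}."""

    def __init__(self, n: int) -> None:
        self._n = n

    def __len__(self) -> int:
        return self._n

    def __getitem__(self, idx: int) -> dict[str, int]:
        if not 0 <= idx < self._n:
            raise IndexError(f"index {idx} out of range for length {self._n}")
        return {"x": idx, "y": idx * idx}

    def get_batch(self, batch_size: int, start_idx: int) -> dict[str, Any]:
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if not 0 <= start_idx < self._n:
            raise IndexError(f"start_idx {start_idx} out of range")
        stop = min(start_idx + batch_size, self._n)  # final batch may be short
        rows = [self[i] for i in range(start_idx, stop)]
        # Turn a list of example dicts into one dict of lists (column layout).
        return {key: [row[key] for row in rows] for key in rows[0]}


def iterate_batches(ds: BatchableSquares, batch_size: int) -> Iterator[dict[str, Any]]:
    """Hypothetical helper: walk the dataset in consecutive batches."""
    for start in range(0, len(ds), batch_size):
        yield ds.get_batch(batch_size, start)


ds = BatchableSquares(5)
print(ds.get_batch(2, 0))                        # {'x': [0, 1], 'y': [0, 1]}
print([b["x"] for b in iterate_batches(ds, 2)])  # [[0, 1], [2, 3], [4]]
```

Truncation keeps every example exactly once. Padding or dropping the last batch are equally valid choices when downstream code needs a fixed batch shape.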
MetricProtocol
Bases: Protocol
Universal metric interface for evaluation.
Supports computing a metric from predictions and targets, with input validation.
name (property)
Get the metric name.
higher_is_better (property)
Whether higher values indicate better performance.
compute(predictions, targets)
Compute the metric value.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| predictions | Array | Model predictions. | required |
| targets | Array | Ground truth targets. | required |

Returns:

| Type | Description |
|---|---|
| Any | Computed metric value. |

validate_inputs(predictions, targets)
Validate that inputs are compatible for metric computation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| predictions | Array | Model predictions. | required |
| targets | Array | Ground truth targets. | required |

Raises:

| Type | Description |
|---|---|
| ValueError | If inputs are incompatible. |

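The following minimal sketch shows a class that satisfies MetricProtocol structurally. It makes three assumptions: NumPy arrays stand in for the documented Array type, compute() calls validate_inputs() itself, and predictions are already class labels. Accuracy and MetricLike are illustrative names, not part of calibrax.

```python
from typing import Any, Protocol, runtime_checkable

import numpy as np


@runtime_checkable
class MetricLike(Protocol):
    """Local stand-in mirroring the documented MetricProtocol members."""

    @property
    def name(self) -> str: ...
    @property
    def higher_is_better(self) -> bool: ...
    def compute(self, predictions: Any, targets: Any) -> Any: ...
    def validate_inputs(self, predictions: Any, targets: Any) -> None: ...


class Accuracy:
    """Fraction of positions where predicted labels equal target labels."""

    @property
    def name(self) -> str:
        return "accuracy"

    @property
    def higher_is_better(self) -> bool:
        return True

    def validate_inputs(self, predictions: Any, targets: Any) -> None:
        p, t = np.asarray(predictions), np.asarray(targets)
        if p.shape != t.shape:
            raise ValueError(f"shape mismatch: {p.shape} vs {t.shape}")
        if p.size == 0:
            raise ValueError("inputs must be non-empty")

    def compute(self, predictions: Any, targets: Any) -> float:
        # Assumption: validation happens inside compute(), not in the caller.
        self.validate_inputs(predictions, targets)
        return float(np.mean(np.asarray(predictions) == np.asarray(targets)))


metric = Accuracy()
print(isinstance(metric, MetricLike))              # True
print(metric.compute([1, 0, 1, 1], [1, 0, 0, 1]))  # 0.75
```

Exposing higher_is_better lets generic code, such as a target check or model selection, rank results without knowing which metric it is handling.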
StatefulMetricProtocol
Bases: Protocol
Interface for stateful metrics with batch accumulation (Tiers 1 and 2).
Follows the update/compute/reset lifecycle pattern from TorchMetrics and Google Metrax. Metrics accumulate statistics across batches via update(), then produce final results via compute().
Tier 1 (FrozenBackboneMetric): a frozen pretrained backbone extracts features, and the metric accumulates statistics over them (e.g., the feature mean and covariance used by FID).
Tier 2 (LearnedMetric): trainable calibration layers sit on top of the backbone features (e.g., LPIPS).
name (property)
Get the metric name.
update(**kwargs)
Accumulate batch statistics.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| **kwargs | Any | Batch data (e.g., images, features, predictions). | {} |

compute()
Compute final metric values from accumulated statistics.
Returns:

| Type | Description |
|---|---|
| dict[str, float] | Dictionary mapping metric names to computed values. |

reset()
Reset accumulated state for a new evaluation run.
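The following minimal sketch illustrates the update/compute/reset lifecycle using a plain streaming mean-squared-error metric. This is not a Tier 1 or Tier 2 metric, because there is no backbone. The keyword names predictions and targets passed to update(), and raising an error if compute() runs before any update(), are illustrative choices. The protocol itself only specifies **kwargs.

```python
import math
from typing import Any

import numpy as np


class StreamingMSE:
    """Accumulates squared error across batches; compute() never sees raw data."""

    def __init__(self) -> None:
        self.reset()

    @property
    def name(self) -> str:
        return "mse"

    def update(self, **kwargs: Any) -> None:
        preds = np.asarray(kwargs["predictions"], dtype=np.float64)
        targets = np.asarray(kwargs["targets"], dtype=np.float64)
        if preds.shape != targets.shape:
            raise ValueError(f"shape mismatch: {preds.shape} vs {targets.shape}")
        # Only sufficient statistics are kept: a running sum and a count.
        self._sum_sq += float(np.sum((preds - targets) ** 2))
        self._count += preds.size

    def compute(self) -> dict[str, float]:
        if self._count == 0:
            raise RuntimeError("compute() called before any update()")
        mse = self._sum_sq / self._count
        return {"mse": mse, "rmse": math.sqrt(mse)}

    def reset(self) -> None:
        self._sum_sq = 0.0
        self._count = 0


metric = StreamingMSE()
metric.update(predictions=[0.0, 2.0], targets=[0.0, 0.0])  # squared errors 0, 4
metric.update(predictions=[1.0, 1.0], targets=[0.0, 0.0])  # squared errors 1, 1
result = metric.compute()
print(result["mse"])  # 1.5, the same as one pass over all four values
metric.reset()        # ready for the next evaluation run
```

Keeping only a sum and a count is what makes the batched result identical to a single pass. Tier 1 metrics apply the same idea to richer statistics, such as a feature mean and covariance.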
MetricLearningProtocol
Bases: Protocol
Interface for metric learning losses (Tier 3).
Returns a differentiable JAX array (not a Python float) so that gradients can flow when training an embedding space. The loss function is itself the metric: minimizing it by backpropagation learns a distance function over the embeddings.
Examples: ContrastiveLoss, TripletMarginLoss, ArcFaceLoss.
__call__(embeddings, labels)
Compute the metric learning loss.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| embeddings | Array | Batch of embedding vectors, shape (batch_size, embedding_dim). | required |
| labels | Array | Integer class labels, shape (batch_size,). | required |

Returns:

| Type | Description |
|---|---|
| Array | Scalar loss value as a JAX array (differentiable). |

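The following minimal sketch implements MetricLearningProtocol in JAX, assuming a classic pairwise contrastive loss. Same-label pairs are pulled together by their squared distance, and different-label pairs are pushed apart until they are at least margin apart. PairwiseContrastiveLoss, margin, and eps are illustrative names, and this is not calibrax's ContrastiveLoss.

```python
import jax
import jax.numpy as jnp


class PairwiseContrastiveLoss:
    """Hypothetical contrastive loss matching the __call__(embeddings, labels) shape."""

    def __init__(self, margin: float = 1.0, eps: float = 1e-12) -> None:
        self.margin = margin
        self.eps = eps  # keeps sqrt differentiable when two embeddings coincide

    def __call__(self, embeddings: jax.Array, labels: jax.Array) -> jax.Array:
        n = embeddings.shape[0]
        diff = embeddings[:, None, :] - embeddings[None, :, :]
        sq_dist = jnp.sum(diff**2, axis=-1)        # (n, n) squared distances
        dist = jnp.sqrt(sq_dist + self.eps)
        same = labels[:, None] == labels[None, :]
        positive = same & ~jnp.eye(n, dtype=bool)  # same label, i != j
        negative = ~same                           # different label
        pull = jnp.where(positive, sq_dist, 0.0)
        push = jnp.where(negative, jnp.maximum(0.0, self.margin - dist) ** 2, 0.0)
        num_pairs = jnp.maximum(jnp.sum(positive | negative), 1)
        return jnp.sum(pull + push) / num_pairs    # scalar, differentiable


loss_fn = PairwiseContrastiveLoss(margin=1.0)
embeddings = jnp.array([[0.0, 0.0], [0.1, 0.0], [2.0, 0.0], [0.5, 0.0]])
labels = jnp.array([0, 0, 1, 1])
loss = loss_fn(embeddings, labels)
grads = jax.grad(lambda e: loss_fn(e, labels))(embeddings)  # same shape as embeddings
print(loss.shape, grads.shape)  # () (4, 2)
```

Two details keep the gradient finite: the diagonal is masked out, and eps is added under the square root. Without eps, the square root at zero distance on the diagonal would produce NaN gradients even though those entries are masked by jnp.where.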