calibrax.metrics
JAX-native evaluation metrics organized in a 4-tier system. All Tier 0 functions accept JAX arrays and return JAX scalar arrays. Higher tiers provide stateful and learned metric implementations built on Flax NNX.
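The Tier 0 contract is easiest to see in code. The sketch below writes two such pure functions directly in jax.numpy. The names `mse` and `binary_iou` are illustrative only and are not calibrax's implementations. The point is that a stateless f(y_pred, y_true) -> scalar with no Python-side state can be composed with `jax.jit` and, where the math allows, `jax.grad`.

```python
import jax
import jax.numpy as jnp


def mse(y_pred: jax.Array, y_true: jax.Array) -> jax.Array:
    """Mean squared error as a stateless f(y_pred, y_true) -> scalar."""
    return jnp.mean((y_pred - y_true) ** 2)


def binary_iou(y_pred: jax.Array, y_true: jax.Array, eps: float = 1e-7) -> jax.Array:
    """IoU of two binary masks with values in {0, 1}; eps guards empty unions."""
    intersection = jnp.sum(y_pred * y_true)
    union = jnp.sum(y_pred) + jnp.sum(y_true) - intersection
    return intersection / (union + eps)


# Neither function touches Python-side state, so both can be traced by jax.jit,
# and mse is differentiable with respect to its first argument.
mse_jit = jax.jit(mse)
iou_jit = jax.jit(binary_iou)
grad_mse = jax.grad(mse)

y_pred = jnp.array([1.0, 2.0, 3.0])
y_true = jnp.array([1.0, 2.0, 5.0])
print(mse_jit(y_pred, y_true))   # ~1.3333 (= 4/3)
print(grad_mse(y_pred, y_true))  # ~[0, 0, -1.3333]
```

The result is a 0-d JAX array rather than a Python float, which is what lets it stay on device inside larger jitted computations.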
Tier System
| Tier | Name | Description | Examples |
|---|---|---|---|
| 0 | Pure Functions | Stateless f(y_pred, y_true) -> scalar | MSE, BLEU, IoU |
| 1 | Frozen Backbone | Pre-trained feature extractor, no gradient | FID, BERTScore |
| 2 | Learned | Backbone with learned calibration weights | LPIPS |
| 3 | Metric Learning | Differentiable loss for embedding spaces | TripletMarginLoss, ArcFace |
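Tier 3 differs from the other tiers in that the metric is itself a training objective. As a hedged illustration of the technique only (this is not calibrax's `TripletMarginLoss`, whose signature and defaults may differ), here is the standard triplet margin loss with Euclidean distance and a check that its gradient flows back to the embeddings.

```python
import jax
import jax.numpy as jnp


def _euclidean(a: jax.Array, b: jax.Array, eps: float = 1e-12) -> jax.Array:
    # eps keeps the gradient finite when a == b (the derivative of sqrt at 0 is infinite).
    return jnp.sqrt(jnp.sum((a - b) ** 2, axis=-1) + eps)


def triplet_margin_loss(anchor, positive, negative, margin: float = 1.0):
    """Mean over the batch of max(d(a, p) - d(a, n) + margin, 0).

    All inputs have shape (batch, dim); d is Euclidean distance.
    """
    d_pos = _euclidean(anchor, positive)
    d_neg = _euclidean(anchor, negative)
    return jnp.mean(jnp.maximum(d_pos - d_neg + margin, 0.0))


anchor = jnp.array([[0.0, 0.0], [0.0, 0.0]])
positive = jnp.array([[3.0, 4.0], [0.0, 1.0]])
negative = jnp.array([[0.0, 1.0], [3.0, 4.0]])

loss = triplet_margin_loss(anchor, positive, negative)  # (5 - 1 + 1 + 0) / 2 = 2.5
# Gradient w.r.t. the anchor embeddings. The second triplet already satisfies
# the margin, so its hinge is inactive and its gradient row is zero.
anchor_grad = jax.grad(triplet_margin_loss)(anchor, positive, negative)
```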
Registry Usage
All registered metrics can be discovered through the MetricRegistry singleton, and Tier 0 metrics can be computed in batch with calculate_all:
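```python
import jax.numpy as jnp
from calibrax.metrics import MetricRegistry, calculate_all
# Example inputs: predictions and targets with matching shapes
predictions = jnp.array([2.5, 0.0, 2.0, 8.0])
targets = jnp.array([3.0, -0.5, 2.0, 7.0])
# Discover metrics by domain or by JIT compatibility
registry = MetricRegistry()
regression_metrics = registry.list_by_domain("general")
jit_safe = registry.list_jit_compatible()
# Batch computation of Tier 0 metrics
results = calculate_all(predictions, targets, metrics=["mse", "mae", "r_squared"])
```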
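To show the discovery-plus-batch pattern in isolation, here is a self-contained toy version. `ToyRegistry`, `ToyMetricEntry`, and `toy_calculate_all` are hypothetical names invented for this sketch; they borrow the method names from the example above, but the internals (a plain dict keyed by name, a boolean JIT flag, no singleton behaviour) are assumptions and not how MetricRegistry is necessarily built.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List

import jax.numpy as jnp


@dataclass(frozen=True)
class ToyMetricEntry:
    """Hypothetical registry record (not calibrax's MetricEntry)."""
    name: str
    fn: Callable
    domain: str
    jit_compatible: bool = True


class ToyRegistry:
    """Hypothetical name -> entry map with attribute-based discovery."""

    def __init__(self) -> None:
        self._entries: Dict[str, ToyMetricEntry] = {}

    def register(self, entry: ToyMetricEntry) -> None:
        self._entries[entry.name] = entry

    def list_by_domain(self, domain: str) -> List[str]:
        return sorted(n for n, e in self._entries.items() if e.domain == domain)

    def list_jit_compatible(self) -> List[str]:
        return sorted(n for n, e in self._entries.items() if e.jit_compatible)

    def get(self, name: str) -> ToyMetricEntry:
        return self._entries[name]


def toy_calculate_all(registry: ToyRegistry, y_pred, y_true, metrics: List[str]) -> Dict:
    """Look up each requested metric by name and evaluate it on the same inputs."""
    return {name: registry.get(name).fn(y_pred, y_true) for name in metrics}


registry = ToyRegistry()
registry.register(ToyMetricEntry("mse", lambda p, t: jnp.mean((p - t) ** 2), "general"))
registry.register(ToyMetricEntry("mae", lambda p, t: jnp.mean(jnp.abs(p - t)), "general"))
# A string-based metric runs as plain Python, so it is flagged as not jittable.
registry.register(
    ToyMetricEntry("exact_match", lambda p, t: float(p == t), "text", jit_compatible=False)
)

results = toy_calculate_all(
    registry, jnp.array([1.0, 2.0, 3.0]), jnp.array([1.0, 2.0, 5.0]), ["mse", "mae"]
)
```

Keeping the JIT flag on each entry lets a caller filter out Python-only metrics before wrapping a batch evaluation in `jax.jit`.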
Sub-modules
- Registry -- MetricRegistry, MetricEntry, MetricTier, MetricSignature
- Regression -- MSE, MAE, RMSE, R-squared, Huber, quantile loss
- Classification -- accuracy, precision, recall, F1, ROC-AUC
- Calibration -- Brier score, ECE, MCE, adaptive ECE
- Segmentation -- IoU, Dice, pixel accuracy
- Distance -- Euclidean, cosine, Mahalanobis, Poincare, Lorentz
- Divergence -- KL, JS, Wasserstein, Sinkhorn, MMD
- Information -- entropy, cross-entropy, mutual information
- Ranking -- NDCG, MAP, MRR, precision/recall at k
- Statistical -- Pearson, Spearman, Kendall, concordance
- Clustering -- ARI, NMI, silhouette, Davies-Bouldin
- Fairness -- demographic parity, equalized odds, disparate impact
- Image -- PSNR, SSIM, MS-SSIM, Vendi Score
- Video -- VMAF via FFmpeg/libvmaf
- Text -- BLEU, ROUGE, perplexity, distinct-N
- Audio -- SNR, spectral convergence, mel cepstral distortion
- Geometric -- Chamfer, Hausdorff, Earth Mover's distance
- Graph -- spectral distance, graph edit distance, resistance distance
- Manifold -- SPD, Grassmann, Stiefel, ultrahyperbolic distances
- Composition -- MetricCollection, WeightedMetric, MetricSuite
- Wrappers -- BootstrapMetric, ClasswiseWrapper, MetricTracker
- Stateful -- FrozenBackboneMetric, LearnedMetric base classes
- Learning -- contrastive, triplet, ArcFace, proxy losses
- Scientific -- chemical validity, binding affinity, conformational
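The Wrappers module lists a BootstrapMetric, but its interface is not documented on this page. The following is therefore only a sketch of the general technique that name suggests: a percentile bootstrap confidence interval around any Tier 0 function. `bootstrap_metric` and its parameters (`num_samples`, `alpha`) are hypothetical.

```python
import jax
import jax.numpy as jnp


def bootstrap_metric(metric_fn, y_pred, y_true, key, num_samples=1000, alpha=0.05):
    """Percentile bootstrap for a Tier 0 metric (hypothetical helper).

    Resamples example indices with replacement, recomputes the metric on each
    resample, and returns (full-sample estimate, lower bound, upper bound).
    """
    n = y_pred.shape[0]
    # (num_samples, n) matrix of indices drawn uniformly from [0, n).
    idx = jax.random.randint(key, (num_samples, n), 0, n)
    # vmap evaluates the metric once per resampled index row.
    samples = jax.vmap(lambda i: metric_fn(y_pred[i], y_true[i]))(idx)
    lower = jnp.quantile(samples, alpha / 2)
    upper = jnp.quantile(samples, 1 - alpha / 2)
    return metric_fn(y_pred, y_true), lower, upper


def mse(y_pred, y_true):
    return jnp.mean((y_pred - y_true) ** 2)


y_true = jnp.linspace(0.0, 1.0, 50)
y_pred = y_true + 0.1 * jax.random.normal(jax.random.PRNGKey(1), (50,))
point, lo, hi = bootstrap_metric(mse, y_pred, y_true, jax.random.PRNGKey(0), num_samples=500)
```

Because the metric is a pure function, the resampling loop becomes a single `jax.vmap` instead of a Python loop.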
Metrics: JAX-native evaluation metrics for the calibrax ecosystem.
Provides a 4-tier metric system:
- Tier 0: Pure functions via calibrax.metrics.functional
- Tier 1: Frozen backbone metrics via calibrax.metrics.stateful
- Tier 2: Learned calibration metrics via calibrax.metrics.stateful
- Tier 3: Metric learning losses via calibrax.metrics.learning
It also provides MetricRegistry for metric discovery and calculate_all for batch computation.
Individual metric functions live in their domain modules (e.g., calibrax.metrics.functional.regression). This top-level package exports the registry infrastructure, types, composition classes, and wrapper classes.
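The difference between Tier 1 and Tier 2 comes down to which parameters receive gradients. calibrax builds these tiers on Flax NNX, but to avoid guessing at the FrozenBackboneMetric and LearnedMetric interfaces, this sketch uses plain JAX functions. The random "backbone", the `frozen_features` and `learned_distance` names, and the softplus weighting are all assumptions; the structure (unit-normalized features, learned per-channel weights) is only loosely modelled on LPIPS, which applies the same idea per spatial location across several network layers.

```python
import jax
import jax.numpy as jnp


def frozen_features(backbone_w: jax.Array, x: jax.Array) -> jax.Array:
    """Stand-in for a pre-trained extractor (Tier 1): no gradient reaches backbone_w."""
    w = jax.lax.stop_gradient(backbone_w)
    return jnp.tanh(x @ w)  # (batch, channels)


def learned_distance(calib_w, backbone_w, x, y):
    """LPIPS-style distance (Tier 2): learned per-channel weights on frozen features."""
    fx = frozen_features(backbone_w, x)
    fy = frozen_features(backbone_w, y)
    # Unit-normalize each feature vector along the channel axis.
    fx = fx / (jnp.linalg.norm(fx, axis=-1, keepdims=True) + 1e-8)
    fy = fy / (jnp.linalg.norm(fy, axis=-1, keepdims=True) + 1e-8)
    # softplus keeps the learned channel weights non-negative.
    return jnp.sum(jax.nn.softplus(calib_w) * (fx - fy) ** 2, axis=-1)


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
backbone_w = jax.random.normal(k1, (4, 8))  # "pre-trained" weights, kept fixed
x = jax.random.normal(k2, (3, 4))
y = jax.random.normal(k3, (3, 4))
calib_w = jnp.zeros(8)  # trainable calibration weights


def loss_fn(c, b):
    return jnp.mean(learned_distance(c, b, x, y))


# Only the calibration weights get a non-zero gradient; the backbone is frozen.
g_calib, g_backbone = jax.grad(loss_fn, argnums=(0, 1))(calib_w, backbone_w)
```

A Tier 1 metric would stop at `frozen_features` plus a fixed comparison (for FID, a Fréchet distance between feature statistics); Tier 2 adds the small trainable head on top.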