calibrax.metrics

JAX-native evaluation metrics organized in a 4-tier system. All Tier 0 functions accept JAX arrays and return JAX scalar arrays. Higher tiers provide stateful and learned metric implementations built on Flax NNX.

Tier System

Tier  Name             Description                                 Example
0     Pure Functions   Stateless f(y_pred, y_true) -> scalar       MSE, BLEU, IoU
1     Frozen Backbone  Pre-trained feature extractor, no gradient  FID, BERTScore
2     Learned          Backbone with learned calibration weights   LPIPS
3     Metric Learning  Differentiable loss for embedding spaces    TripletMarginLoss, ArcFace
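To make the Tier 0 contract concrete, here is a minimal sketch of a stateless metric written in plain JAX. The function name mse_sketch is a hypothetical stand-in and is not calibrax's own implementation; it only illustrates the f(y_pred, y_true) -> scalar shape from the table, and why such functions can be passed straight to jax.jit.

```python
import jax
import jax.numpy as jnp


def mse_sketch(y_pred: jax.Array, y_true: jax.Array) -> jax.Array:
    """Hypothetical Tier 0 style metric: stateless, arrays in, scalar array out."""
    return jnp.mean(jnp.square(y_pred - y_true))


# A pure function with no hidden state can be wrapped in jax.jit directly.
mse_jit = jax.jit(mse_sketch)

# Illustrative inputs, not taken from the calibrax documentation.
y_pred = jnp.array([1.0, 2.0, 3.0])
y_true = jnp.array([1.0, 2.0, 5.0])
score = mse_jit(y_pred, y_true)  # 0-dimensional JAX array
```

Higher tiers hold state (a backbone or learned weights), so they would not fit this two-argument form without also passing that state in.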

Registry Usage

Registered metrics can be discovered through the MetricRegistry singleton, and Tier 0 metrics can be computed in batch with calculate_all:

import jax.numpy as jnp

from calibrax.metrics import MetricRegistry, calculate_all

# Example inputs (illustrative values only)
predictions = jnp.array([2.5, 0.0, 2.1])
targets = jnp.array([3.0, -0.5, 2.0])

# Discover metrics by domain or by JIT compatibility
registry = MetricRegistry()
regression_metrics = registry.list_by_domain("general")
jit_safe = registry.list_jit_compatible()

# Batch computation of Tier 0 metrics
results = calculate_all(predictions, targets, metrics=["mse", "mae", "r_squared"])
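For readers unfamiliar with the registry pattern, the following is a rough sketch of how a singleton registry with domain and JIT-compatibility lookups might be structured. Every name ending in Sketch is hypothetical, and the internals are assumptions rather than calibrax's actual implementation; only the method names list_by_domain and list_jit_compatible are borrowed from the snippet above. Plain Python floats are used instead of JAX arrays so the example has no dependencies.

```python
from dataclasses import dataclass
from enum import IntEnum
from typing import Callable, Dict, List, Sequence


class TierSketch(IntEnum):
    """Hypothetical tier enum mirroring the 4-tier table."""
    PURE_FUNCTION = 0
    FROZEN_BACKBONE = 1
    LEARNED = 2
    METRIC_LEARNING = 3


@dataclass(frozen=True)
class EntrySketch:
    """Hypothetical registry record: a metric function plus its metadata."""
    name: str
    fn: Callable[[Sequence[float], Sequence[float]], float]
    tier: TierSketch
    domain: str
    jit_compatible: bool


class RegistrySketch:
    """Hypothetical singleton: every construction returns the same instance."""
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._entries = {}
        return cls._instance

    def register(self, entry: EntrySketch) -> None:
        self._entries[entry.name] = entry

    def get(self, name: str) -> EntrySketch:
        return self._entries[name]

    def list_by_domain(self, domain: str) -> List[str]:
        return sorted(n for n, e in self._entries.items() if e.domain == domain)

    def list_jit_compatible(self) -> List[str]:
        return sorted(n for n, e in self._entries.items() if e.jit_compatible)


def calculate_all_sketch(predictions, targets, metrics) -> Dict[str, float]:
    """Look up each requested metric by name and apply it to the same inputs."""
    registry = RegistrySketch()
    return {name: registry.get(name).fn(predictions, targets) for name in metrics}


def _mse(p, t):
    return sum((a - b) ** 2 for a, b in zip(p, t)) / len(p)


def _mae(p, t):
    return sum(abs(a - b) for a, b in zip(p, t)) / len(p)


registry = RegistrySketch()
registry.register(EntrySketch("mse", _mse, TierSketch.PURE_FUNCTION, "general", True))
registry.register(EntrySketch("mae", _mae, TierSketch.PURE_FUNCTION, "general", True))

results = calculate_all_sketch([1.0, 2.0, 3.0], [1.0, 2.0, 5.0], metrics=["mse", "mae"])
```

Because the registry is a singleton in this sketch, calculate_all_sketch can construct it again internally and still see the metrics registered earlier.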

Sub-modules

  • Registry -- MetricRegistry, MetricEntry, MetricTier, MetricSignature
  • Regression -- MSE, MAE, RMSE, R-squared, Huber, quantile loss
  • Classification -- accuracy, precision, recall, F1, ROC-AUC
  • Calibration -- Brier score, ECE, MCE, adaptive ECE
  • Segmentation -- IoU, Dice, pixel accuracy
  • Distance -- Euclidean, cosine, Mahalanobis, Poincare, Lorentz
  • Divergence -- KL, JS, Wasserstein, Sinkhorn, MMD
  • Information -- entropy, cross-entropy, mutual information
  • Ranking -- NDCG, MAP, MRR, precision/recall at k
  • Statistical -- Pearson, Spearman, Kendall, concordance
  • Clustering -- ARI, NMI, silhouette, Davies-Bouldin
  • Fairness -- demographic parity, equalized odds, disparate impact
  • Image -- PSNR, SSIM, MS-SSIM, Vendi Score
  • Video -- VMAF via FFmpeg/libvmaf
  • Text -- BLEU, ROUGE, perplexity, distinct-N
  • Audio -- SNR, spectral convergence, mel cepstral distortion
  • Geometric -- Chamfer, Hausdorff, Earth Mover's distance
  • Graph -- spectral distance, graph edit distance, resistance distance
  • Manifold -- SPD, Grassmann, Stiefel, ultrahyperbolic distances
  • Composition -- MetricCollection, WeightedMetric, MetricSuite
  • Wrappers -- BootstrapMetric, ClasswiseWrapper, MetricTracker
  • Stateful -- FrozenBackboneMetric, LearnedMetric base classes
  • Learning -- contrastive, triplet, ArcFace, proxy losses
  • Scientific -- chemical validity, binding affinity, conformational

Metrics: JAX-native evaluation metrics for the calibrax ecosystem.

Provides a 4-tier metric system:

  • Tier 0: Pure functions via calibrax.metrics.functional
  • Tier 1: Frozen backbone metrics via calibrax.metrics.stateful
  • Tier 2: Learned calibration metrics via calibrax.metrics.stateful
  • Tier 3: Metric learning losses via calibrax.metrics.learning

Also provides MetricRegistry for metric discovery and calculate_all for batch computation.

Individual metric functions live in their domain modules (e.g., calibrax.metrics.functional.regression). This top-level package exports registry infrastructure, types, composition, and wrapper classes.