
Calibrax: Unified Benchmarking for JAX Scientific ML


Calibrax is an extensible benchmarking framework for the JAX scientific ML ecosystem. It provides profiling, statistical analysis, regression detection, and reporting in a single library, purpose-built for evaluating JAX-based models and training pipelines.

Why Calibrax?

Benchmarking scientific ML workloads involves more than timing a function call. Metrics have directions (throughput should increase, latency should decrease), results need statistical rigor (confidence intervals, significance tests), and regressions must be caught automatically in CI. General-purpose tools lack these domain-specific concepts. Calibrax fills this gap with a unified data model, direction-aware analysis, and a full pipeline from measurement to publication.
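
To make "direction-aware" concrete, here is a minimal sketch in plain Python that does not use Calibrax at all. The `Direction` enum and `is_regression` helper are hypothetical names chosen for illustration, and the relative-change formula is an assumption about how such a check could work, not a description of the library's internals.

```python
from enum import Enum


class Direction(Enum):
    HIGHER = "higher"  # larger values are better (e.g. throughput)
    LOWER = "lower"    # smaller values are better (e.g. latency)


def is_regression(
    baseline: float,
    current: float,
    direction: Direction,
    threshold: float = 0.05,
) -> bool:
    """Return True if `current` is worse than `baseline` by more than `threshold`.

    The change is measured relative to the baseline, then flipped in sign for
    HIGHER metrics so that a positive number always means "got worse".
    """
    if baseline == 0:
        raise ValueError("relative change is undefined for a zero baseline")
    change = (current - baseline) / abs(baseline)
    worsening = -change if direction is Direction.HIGHER else change
    return worsening > threshold


# The same 10% drop is a regression for throughput but an improvement for latency.
print(is_regression(100.0, 90.0, Direction.HIGHER))   # True
print(is_regression(100.0, 90.0, Direction.LOWER))    # False
# A 3% latency increase is worse, but stays inside the 5% threshold.
print(is_regression(100.0, 103.0, Direction.LOWER))   # False
```

Without the direction flag, a tool would have to treat every decrease (or every increase) the same way, which is wrong for one of the two metrics.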

Features

  • Profiling


    Wall-clock timing with GPU sync, resource monitoring, GPU memory analysis, energy measurement, and FLOP counting

    Profiling guide

  • Statistical Analysis


    Bootstrap confidence intervals, outlier detection, hypothesis testing, and effect size estimation

    Statistics guide

  • Regression Detection


    Direction-aware regression detection with configurable thresholds and automatic CI gating

    Regression guide

  • Comparison & Ranking


    Cross-configuration comparison, Pareto front analysis, aggregate scoring, and ranking tables

    Comparison guide

  • Storage & Export


    JSON-per-run file store with baselines, W&B integration, and publication-ready LaTeX/HTML tables and matplotlib plots

    Storage guide

  • Validation


    Convergence analysis, accuracy assessment, and scientific validation reporting against reference implementations

    API reference
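
As an illustration of the bootstrap confidence intervals mentioned under Statistical Analysis, the sketch below computes a percentile bootstrap interval for a mean using only the Python standard library. It is not Calibrax's implementation: `bootstrap_ci` is a hypothetical helper, the percentile method is one of several bootstrap variants, and the latency samples are made up for the example.

```python
import random
import statistics


def bootstrap_ci(samples, n_resamples=2000, confidence=0.95, seed=0):
    """Percentile bootstrap confidence interval for the mean of `samples`."""
    if not samples:
        raise ValueError("need at least one sample")
    rng = random.Random(seed)  # fixed seed so repeated runs give the same interval
    n = len(samples)
    # Resample with replacement, record the mean of each resample, then sort.
    means = sorted(
        statistics.fmean(rng.choices(samples, k=n)) for _ in range(n_resamples)
    )
    # Cut off (1 - confidence) / 2 of the resampled means at each tail.
    alpha = (1.0 - confidence) / 2.0
    lo_idx = int(alpha * n_resamples)
    hi_idx = min(n_resamples - 1, int((1.0 - alpha) * n_resamples))
    return means[lo_idx], means[hi_idx]


# Made-up per-iteration latencies in milliseconds, for illustration only.
latencies_ms = [0.78, 0.81, 0.80, 0.79, 0.83, 0.82, 0.80, 0.77, 0.84, 0.80]
lower, upper = bootstrap_ci(latencies_ms)
mean = statistics.fmean(latencies_ms)
print(f"mean={mean:.3f} ms, 95% CI=[{lower:.3f}, {upper:.3f}]")
```

A point estimate with bounds like these is the kind of triple that a `value`/`lower`/`upper` metric record can carry.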

Quick Example

from calibrax.core.models import MetricDef, MetricDirection, Metric, Point, Run
from calibrax.storage.store import Store
from calibrax.analysis.regression import detect_regressions

# Define what "better" means for each metric
throughput_def = MetricDef(
    name="throughput", unit="samples/sec", direction=MetricDirection.HIGHER
)
latency_def = MetricDef(
    name="latency", unit="ms", direction=MetricDirection.LOWER
)

# Record benchmark results
point = Point(
    name="forward_pass",
    scenario="training",
    metrics={
        "throughput": Metric(value=1250.0, lower=1200.0, upper=1300.0),
        "latency": Metric(value=0.8, lower=0.75, upper=0.85),
    },
)
run = Run(
    points=(point,),
    metric_defs={"throughput": throughput_def, "latency": latency_def},
)

# Save and detect regressions against a baseline
store = Store("/tmp/benchmarks")
store.save(run)
store.set_baseline(run.id)

# Later, compare a new run against the baseline
new_run = Run(
    points=(Point(
        name="forward_pass", scenario="training",
        metrics={
            "throughput": Metric(value=1100.0),
            "latency": Metric(value=0.9),
        },
    ),),
    metric_defs={"throughput": throughput_def, "latency": latency_def},
)
baseline = store.get_baseline()
if baseline is not None:
    regressions = detect_regressions(new_run, baseline, threshold=0.05)
    for r in regressions:
        print(f"{r.metric}: {r.delta_pct:+.1f}% regression")
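
In a CI job, a loop like the one above would usually end by failing the build when anything is reported. The sketch below shows one way to do that without depending on Calibrax: `Finding` is a hypothetical stand-in for whatever `detect_regressions` returns (only the `metric` and `delta_pct` fields used above are mirrored), `gate` is an illustrative helper, and the percent changes are computed by hand from the example values on the assumption that `delta_pct` is the change relative to the baseline.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Finding:
    """Stand-in for one detected regression (hypothetical, not a Calibrax type)."""
    metric: str
    delta_pct: float


def gate(findings: list[Finding]) -> int:
    """Print each finding and return an exit code: 0 if clean, 1 otherwise."""
    for f in findings:
        print(f"{f.metric}: {f.delta_pct:+.1f}% regression")
    return 1 if findings else 0


# Relative change of the new run against the baseline values used above:
# throughput 1250.0 -> 1100.0 and latency 0.8 -> 0.9.
findings = [
    Finding("throughput", (1100.0 - 1250.0) / 1250.0 * 100),
    Finding("latency", (0.9 - 0.8) / 0.8 * 100),
]
exit_code = gate(findings)
```

Both changes are about 12% in the worse direction, well past a 5% threshold, so the gate returns 1. A CI script would pass that value to `sys.exit` so the job fails.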

Architecture

flowchart LR
    A[Define Metrics] --> B[Profile Workload]
    B --> C[Collect Results]
    C --> D[Store & Baseline]
    D --> E[Analyze]
    E --> F[Export & Report]

    style A fill:#e3f2fd
    style B fill:#e3f2fd
    style C fill:#fff3e0
    style D fill:#fff3e0
    style E fill:#fff3e0
    style F fill:#c8e6c9

Next Steps

  • Installation


    Set up Calibrax with optional extras for statistics, GPU, and W&B support

    Install

  • Quick Start


    Working examples covering profiling, storage, regression detection, and CLI

    Quick start

  • Core Concepts


    Data model hierarchy, direction-aware metrics, protocols, and the benchmark lifecycle

    Concepts