pytest Integration Patterns
Calibrax can be used alongside pytest-benchmark for a layered
benchmarking strategy: pytest-benchmark for quick CI checks, Calibrax
for full profiling pipelines.
Basic pytest Fixture
Wrap Calibrax's TimingCollector in a pytest fixture:
import pytest
from calibrax.profiling.timing import TimingCollector
@pytest.fixture
def timing_collector():
    """Return a fresh ``TimingCollector`` configured with two warmup iterations."""
    collector = TimingCollector(warmup_iterations=2)
    return collector
def test_matmul_throughput(timing_collector):
    """Measure 20 batches of 64x64 arrays and sanity-check the timing sample.

    Verifies the batch count and the warmup exclusion reported on the sample,
    then asserts that the mean of the recorded per-batch times is below 1.0.
    """
    import jax.numpy as jnp

    batch_count = 20
    batches = [jnp.ones((64, 64)) for _ in range(batch_count)]

    sample = timing_collector.measure_iteration(
        iter(batches),
        num_batches=batch_count,
    )

    assert sample.num_batches == batch_count
    assert sample.warmup_batches_excluded == 2

    batch_times = sample.per_batch_times
    average = sum(batch_times) / len(batch_times)
    assert average < 1.0  # Sanity check
Store-Backed Benchmarking Fixture
Save benchmark results to a Calibrax Store for trend tracking:
import pytest
from pathlib import Path
from calibrax.storage.store import Store
from calibrax.core.models import Run, Point, Metric
@pytest.fixture
def benchmark_store(tmp_path):
    """Return a ``Store`` rooted at ``benchmarks`` under the ``tmp_path`` fixture."""
    store_root = tmp_path / "benchmarks"
    return Store(store_root)
def test_training_step_speed(benchmark_store):
    """Time a placeholder training step and save it to the store as a one-point run."""
    import time

    started = time.perf_counter()
    # ... run training step ...
    elapsed = time.perf_counter() - started

    # Build the run from the inside out: metric -> point -> run.
    wall_time = Metric(value=elapsed)
    point = Point(
        name="training_step",
        scenario="speed",
        tags={"framework": "jax"},
        metrics={"wall_time_sec": wall_time},
    )
    benchmark_store.save(Run(points=(point,)))
Combining with pytest-benchmark
Use pytest-benchmark for quick timing and Calibrax for detailed analysis:
def test_forward_pass(benchmark):
    """Benchmark a small matrix product with pytest-benchmark, then time it with Calibrax.

    ``forward`` multiplies a (32, 128) array of ones by its transpose and blocks
    until the result is ready. It is first handed to the ``benchmark`` fixture,
    then invoked from the ``count_fn`` callback given to a ``TimingCollector``.
    """
    import jax.numpy as jnp

    x = jnp.ones((32, 128))

    def forward():
        return jnp.dot(x, x.T).block_until_ready()

    def run_forward(_batch):
        """Run one forward pass and report a count of 1, ignoring the batch value."""
        forward()
        return 1

    # Quick benchmark via pytest-benchmark; its return value is not needed here.
    benchmark(forward)

    # Detailed profiling via Calibrax (optional, for CI)
    from calibrax.profiling.timing import TimingCollector

    collector = TimingCollector(
        sync_fn=lambda result: jnp.array(0.0).block_until_ready(),
        warmup_iterations=3,
    )
    sample = collector.measure_iteration(
        iter(range(50)),
        num_batches=50,
        count_fn=run_forward,
    )
    assert len(sample.per_batch_times) == 47  # 50 - 3 warmup
CI Regression Checking
Use Calibrax's CIGuard in pytest for automated regression detection:
import pytest
from calibrax.ci.guard import CIGuard
from calibrax.storage.store import Store
@pytest.fixture
def ci_guard(tmp_path):
    """Return a ``CIGuard`` with ``threshold=0.05`` over a ``Store`` under ``tmp_path``."""
    return CIGuard(Store(tmp_path / "benchmarks"), threshold=0.05)
def test_no_regressions(ci_guard):
    """Fail with one line per regression when the guard's check does not pass."""
    result = ci_guard.check()
    if result.passed:
        return

    failures = [
        f"{r.metric} on {r.point_name}: {r.delta_pct:+.1f}%"
        for r in result.regressions
    ]
    # The header has no placeholders, so it is a plain string rather than an f-string.
    pytest.fail("Performance regressions detected:\n" + "\n".join(failures))
Resource Monitoring in Tests
Monitor resource usage during test execution:
from calibrax.profiling.resources import ResourceMonitor
def test_memory_bounded(benchmark):
    """Check the peak RSS and memory growth reported by a ``ResourceMonitor`` run.

    NOTE(review): the ``benchmark`` fixture is requested but never used in the body.
    """
    with ResourceMonitor(sample_interval_sec=0.05) as monitor:
        # ... run workload ...
        pass

    # The summary is read after the monitoring context has exited.
    summary = monitor.summary

    assert summary.peak_rss_mb < 2048, f"Peak RSS too high: {summary.peak_rss_mb:.0f} MB"
    assert summary.memory_growth_mb < 100, "Memory leak detected"