Profiling Workloads
Calibrax provides a suite of profiling tools that can be used independently or composed into a full benchmark result: wall-clock timing (with warmup exclusion and compilation measurement), resource monitoring (CPU, memory, GPU clock/power), GPU memory analysis, energy measurement, FLOP counting, hardware detection, roofline analysis, compilation profiling, complexity analysis, XLA trace linking, and carbon emissions tracking.
Timing
TimingCollector measures wall-clock time for an iterator-based workload. Pass a
sync_fn to ensure GPU operations complete before each timestamp.
```python
import statistics

import jax
from calibrax.profiling.timing import TimingCollector

collector = TimingCollector(
    sync_fn=lambda batch: jax.block_until_ready(batch),
    warmup_iterations=2,  # exclude the first 2 batches from per_batch_times
)

# data_loader is your own iterable of batches (placeholder)
sample = collector.measure_iteration(
    iterator=iter(data_loader),
    num_batches=100,
    count_fn=lambda batch: batch["image"].shape[0],
)

print(f"Wall clock: {sample.wall_clock_sec:.3f}s")
print(f"Elements processed: {sample.num_elements}")
print(f"First batch: {sample.first_batch_time:.4f}s (includes JIT compilation)")
print(f"Warmup excluded: {sample.warmup_batches_excluded} batches")
print(f"Median batch: {statistics.median(sample.per_batch_times):.4f}s")
```
The returned TimingSample contains:

- wall_clock_sec: total elapsed time, including warmup batches
- per_batch_times: per-iteration timings, with warmup batches excluded
- first_batch_time: time of the first iteration, typically higher because it includes JIT compilation
- num_batches: total iterations consumed, including warmup
- num_elements: total element count, as reported by count_fn
- warmup_batches_excluded: number of leading batches excluded from per_batch_times
- compilation_time_sec: JIT compilation time, if measured separately
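The fields above follow naturally from a simple measurement loop. The following is a minimal stdlib sketch of that idea, assuming a loop that calls sync_fn after fetching each batch, records per-batch durations, and drops the first warmup_iterations entries. The helper name time_batches is hypothetical, and this is not TimingCollector's actual implementation.

```python
import statistics
import time
from typing import Any, Callable, Iterator


def time_batches(
    iterator: Iterator[Any],
    num_batches: int,
    sync_fn: Callable[[Any], Any] = lambda batch: batch,
    count_fn: Callable[[Any], int] = lambda batch: 1,
    warmup_iterations: int = 0,
) -> dict:
    """Hypothetical helper: time up to num_batches items drawn from iterator."""
    batch_times = []
    num_elements = 0
    start = time.perf_counter()
    for _ in range(num_batches):
        t0 = time.perf_counter()
        try:
            batch = next(iterator)
        except StopIteration:
            break
        sync_fn(batch)  # wait until the batch is actually ready
        batch_times.append(time.perf_counter() - t0)
        num_elements += count_fn(batch)
    wall_clock = time.perf_counter() - start
    return {
        "wall_clock_sec": wall_clock,  # includes warmup batches
        "first_batch_time": batch_times[0] if batch_times else None,
        "per_batch_times": batch_times[warmup_iterations:],  # warmup excluded
        "num_batches": len(batch_times),
        "num_elements": num_elements,
        "warmup_batches_excluded": min(warmup_iterations, len(batch_times)),
    }


batches = iter([[1, 2, 3], [4, 5], [6]])
stats = time_batches(batches, num_batches=3, count_fn=len, warmup_iterations=1)
print(stats["num_elements"], len(stats["per_batch_times"]))
print(f"median: {statistics.median(stats['per_batch_times']):.6f}s")
```

Note that wall_clock_sec still spans the warmup batches, which is why end-to-end and steady-state figures can differ noticeably for JIT-compiled workloads.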
Measuring Compilation Time
Measure XLA compilation overhead separately from execution:
```python
import jax.numpy as jnp
from calibrax.profiling.timing import TimingCollector

collector = TimingCollector()
comp_time = collector.measure_compilation_time(
    lambda x: jnp.dot(x, x.T),
    jnp.ones((64, 64)),
)
print(f"Compilation: {comp_time:.3f}s")
```
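measure_compilation_time isolates compilation directly. If you only have a TimingSample, a rough alternative is to compare first_batch_time with the steady-state median of per_batch_times. A minimal sketch with a hypothetical helper name and made-up timings; because the first batch can also pay other one-off costs (such as data-loader startup), this estimate tends to overstate compilation time.

```python
import statistics


def estimate_jit_overhead(first_batch_time: float, per_batch_times: list) -> float:
    """Hypothetical helper: first-call time minus the steady-state median.

    Other one-off costs in the first batch are included too, so treat the
    result as a rough, likely high, estimate of compilation overhead.
    """
    steady = statistics.median(per_batch_times)
    return max(0.0, first_batch_time - steady)


# Made-up timings for illustration only (seconds).
overhead = estimate_jit_overhead(1.25, [0.05, 0.04, 0.06, 0.05])
print(f"Estimated compilation overhead: {overhead:.2f}s")
```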
Serialization
TimingSample supports to_dict() and from_dict() for JSON-compatible
round-trip serialization:
```python
from calibrax.profiling.timing import TimingSample

d = sample.to_dict()  # sample: a TimingSample from measure_iteration
restored = TimingSample.from_dict(d)
```
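"JSON-compatible" means the dictionary can pass through json.dumps and json.loads unchanged. The sketch below shows that round trip with a hypothetical dataclass stand-in (TimingRecord) rather than the real TimingSample, whose full field set and validation may differ.

```python
import json
from dataclasses import asdict, dataclass, field


@dataclass
class TimingRecord:
    """Hypothetical stand-in for a TimingSample-like object."""

    wall_clock_sec: float
    num_elements: int
    per_batch_times: list = field(default_factory=list)

    def to_dict(self) -> dict:
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict) -> "TimingRecord":
        return cls(**d)


record = TimingRecord(wall_clock_sec=2.5, num_elements=128, per_batch_times=[0.1, 0.2])
payload = json.dumps(record.to_dict())  # plain JSON text, safe to write to disk
restored_record = TimingRecord.from_dict(json.loads(payload))
print(restored_record == record)  # True
```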
GPU Synchronization
JAX dispatches GPU work asynchronously, so without sync_fn the timer captures
only dispatch time, not the actual computation. Always pass
sync_fn=lambda batch: jax.block_until_ready(batch) for accurate GPU benchmarks.
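The effect is easy to reproduce without a GPU. This sketch uses a thread pool as a stand-in for an asynchronous device queue (an analogy, not how JAX dispatch is implemented): timing only the submit call measures dispatch, while waiting on the result measures the work itself.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_kernel(duration_sec: float) -> str:
    """Stand-in for device-side work."""
    time.sleep(duration_sec)
    return "done"


with ThreadPoolExecutor(max_workers=1) as pool:
    # Timing only the submit call: it returns as soon as the work is queued.
    t0 = time.perf_counter()
    future = pool.submit(fake_kernel, 0.2)
    dispatch_only = time.perf_counter() - t0
    future.result()  # drain before the next measurement

    # Timing submit plus result(): result() plays the role of block_until_ready.
    t0 = time.perf_counter()
    pool.submit(fake_kernel, 0.2).result()
    synced = time.perf_counter() - t0

print(f"without sync: {dispatch_only:.3f}s, with sync: {synced:.3f}s")
```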
Resource Monitoring
ResourceMonitor runs a background thread that samples CPU and memory usage
at a configurable interval. Use it as a context manager:
```python
from calibrax.profiling.resources import ResourceMonitor

with ResourceMonitor(sample_interval_sec=0.1) as monitor:
    # Run your workload here (train, model, and data are placeholders)
    train(model, data)

summary = monitor.summary
print(f"Peak RSS: {summary.peak_rss_mb:.1f} MB")
print(f"Mean RSS: {summary.mean_rss_mb:.1f} MB")
print(f"Memory growth: {summary.memory_growth_mb:.1f} MB")
print(f"Duration: {summary.duration_sec:.2f}s")
print(f"Samples collected: {summary.num_samples}")
```
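As a rough sketch of the pattern described above (a background thread sampling at a fixed interval, then summarizing peak, mean, and growth), the following stdlib-only class samples Python heap usage via tracemalloc instead of process RSS. The class name SamplingMonitor is hypothetical, and this is not ResourceMonitor's implementation.

```python
import threading
import time
import tracemalloc
from statistics import mean


class SamplingMonitor:
    """Hypothetical background sampler; not calibrax's ResourceMonitor."""

    def __init__(self, sample_fn, sample_interval_sec=0.01):
        self.sample_fn = sample_fn
        self.interval = sample_interval_sec
        self.samples = []
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def _run(self):
        while not self._stop.is_set():
            self.samples.append(self.sample_fn())
            self._stop.wait(self.interval)  # sleep, but wake early on stop

    def __enter__(self):
        self._start = time.perf_counter()
        self.samples.append(self.sample_fn())  # baseline before the workload
        self._thread.start()
        return self

    def __exit__(self, *exc_info):
        self._stop.set()
        self._thread.join()
        self.samples.append(self.sample_fn())  # final sample after the workload
        self.duration_sec = time.perf_counter() - self._start
        return False

    @property
    def summary(self):
        return {
            "peak": max(self.samples),
            "mean": mean(self.samples),
            "growth": self.samples[-1] - self.samples[0],
            "num_samples": len(self.samples),
            "duration_sec": self.duration_sec,
        }


def traced_heap_mb():
    """Current Python heap allocated via tracemalloc, in MB (not process RSS)."""
    return tracemalloc.get_traced_memory()[0] / 1e6


tracemalloc.start()
with SamplingMonitor(traced_heap_mb, sample_interval_sec=0.01) as monitor:
    payload = [bytes(1000) for _ in range(10_000)]  # roughly 10 MB kept alive
    time.sleep(0.05)
tracemalloc.stop()
print(monitor.summary)
```

The sampling interval trades resolution for overhead: a short-lived allocation spike between two samples will not show up in the peak.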
To include GPU utilization, pass a GPUProfilerProtocol-compatible object:
```python
from calibrax.profiling.gpu import GPUMemoryProfiler

gpu_profiler = GPUMemoryProfiler()

with ResourceMonitor(gpu_profiler=gpu_profiler) as monitor:
    train(model, data)

summary = monitor.summary
if summary.mean_gpu_util is not None:
    print(f"Mean GPU utilization: {summary.mean_gpu_util:.1f}%")
if summary.peak_gpu_mem_mb is not None:
    print(f"Peak GPU memory: {summary.peak_gpu_mem_mb:.1f} MB")
```
GPU Memory Analysis
GPUMemoryProfiler takes a snapshot of current GPU memory usage whenever you
call it. MemoryOptimizer analyzes the memory footprint of an entire pipeline,
measuring baseline, peak, and retained memory.
```python
from calibrax.profiling.gpu import GPUMemoryProfiler, MemoryOptimizer

# Quick snapshot
profiler = GPUMemoryProfiler()
usage = profiler.get_memory_usage()
print(f"GPU memory used: {usage.get('gpu_memory_used_mb', 0):.1f} MB")

# Full pipeline analysis (pipeline_fn and sample_data are placeholders)
optimizer = MemoryOptimizer()
analysis = optimizer.analyze_pipeline_memory(pipeline_fn, sample_data)
if analysis is not None:
    print(f"Baseline: {analysis.baseline_memory_mb:.1f} MB")
    print(f"Peak: {analysis.peak_memory_mb:.1f} MB")
    print(f"Efficiency: {analysis.memory_efficiency:.1%}")
    for suggestion in analysis.suggestions:
        print(f" - {suggestion}")
```
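Baseline, peak, and retained memory relate to each other through simple subtraction. A minimal sketch, assuming you already have three snapshots in MB (before the run, the peak during it, and after it returns); the helper and its tolerance_mb threshold are hypothetical, and this does not reproduce how MemoryOptimizer computes memory_efficiency.

```python
def summarize_memory(baseline_mb, peak_mb, after_mb, tolerance_mb=50.0):
    """Hypothetical helper: relate baseline, peak, and retained memory (all in MB)."""
    retained_mb = after_mb - baseline_mb
    return {
        "extra_peak_mb": peak_mb - baseline_mb,  # additional memory needed while running
        "retained_mb": retained_mb,  # still held after the pipeline returned
        "possible_leak": retained_mb > tolerance_mb,
    }


# Made-up snapshot values for illustration.
report = summarize_memory(baseline_mb=500.0, peak_mb=2100.0, after_mb=520.0)
print(report)
```

Retained memory that grows across repeated runs is a stronger leak signal than a single nonzero value, since allocators and caches often keep some memory after the first run.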
AdaptiveOperation auto-detects hardware and optimizes tensor shapes:
```python
from calibrax.profiling.gpu import AdaptiveOperation

adaptive = AdaptiveOperation()
print(f"Platform: {adaptive.config.platform}")
print(f"Precision: {adaptive.config.precision}")
optimized_shapes = adaptive.optimize_shapes((32, 128), (128, 64))
```
Energy Monitoring
EnergyMonitor tracks GPU power draw (via NVML) and CPU energy consumption
(via RAPL) during a workload. Use it as a context manager:
```python
from calibrax.profiling.energy import EnergyMonitor

with EnergyMonitor(sample_interval_sec=0.1) as monitor:
    train(model, data)

energy_summary = monitor.summary
print(f"Duration: {energy_summary.duration_sec:.2f}s")
if energy_summary.total_gpu_energy_joules is not None:
    print(f"GPU energy: {energy_summary.total_gpu_energy_joules:.2f} J")
if energy_summary.total_cpu_energy_joules is not None:
    print(f"CPU energy: {energy_summary.total_cpu_energy_joules:.2f} J")
if energy_summary.mean_gpu_power_watts is not None:
    print(f"Mean GPU power: {energy_summary.mean_gpu_power_watts:.1f} W")
```
Note
Energy monitoring requires hardware support: NVML for GPU power and
Intel RAPL for CPU energy. On unsupported hardware, energy fields will
be None.
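NVML can report instantaneous GPU power draw, so turning periodic power samples into an energy total means integrating over time (RAPL, by contrast, exposes cumulative energy counters). The trapezoidal rule is one common way to do this; the sketch below is illustrative and does not assert that EnergyMonitor uses this exact method.

```python
def integrate_energy_joules(timestamps_sec, power_watts):
    """Trapezoidal integration of sampled power (W) over time (s), giving energy (J)."""
    if len(timestamps_sec) != len(power_watts):
        raise ValueError("timestamps and power samples must have the same length")
    energy = 0.0
    for i in range(1, len(timestamps_sec)):
        dt = timestamps_sec[i] - timestamps_sec[i - 1]
        energy += 0.5 * (power_watts[i] + power_watts[i - 1]) * dt
    return energy


# Made-up samples: 0.1 s interval, power ramping from 200 W to 300 W.
ts = [0.0, 0.1, 0.2, 0.3]
watts = [200.0, 250.0, 300.0, 300.0]
print(f"{integrate_energy_joules(ts, watts):.1f} J")  # 80.0 J
```

A longer sample interval smooths over short power spikes, which is one reason the sample_interval_sec setting affects the reported totals.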
FLOP Counting
FlopsCounter analyzes a JAX function's computation graph (jaxpr) to count
floating-point operations:
```python
import jax.numpy as jnp
from calibrax.profiling.flops import FlopsCounter


def matmul_workload(x, w):
    return jnp.dot(x, w)


counter = FlopsCounter()
result = counter.count(matmul_workload, jnp.ones((64, 128)), jnp.ones((128, 32)))
print(f"Total FLOPs: {result.total_flops:,}")
print(f"Operations: {result.num_operations}")
for op, count in result.flops_by_operation.items():
    print(f" {op}: {count:,}")
```
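As a sanity check on the example above, the conventional count for a dense (M, K) @ (K, N) matmul is 2 * M * K * N FLOPs, treating each multiply-add as two operations. FlopsCounter's own convention is not specified here (some tools count a multiply-add as one), so compare orders of magnitude rather than expecting an exact match.

```python
def matmul_flops(m: int, k: int, n: int) -> int:
    """FLOPs for an (m, k) @ (k, n) matmul, counting a multiply-add as 2 FLOPs."""
    return 2 * m * k * n


print(f"{matmul_flops(64, 128, 32):,}")  # 524,288
```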
NNX Models
FlopsCounter works with pure JAX functions. For Flax NNX models,
use flax.nnx.tabulate(model, *args, compute_flops=True) instead,
which handles the NNX state and Rngs automatically.
Hardware Detection
detect_hardware_specs() auto-detects the active JAX backend and returns
reference performance specs. HARDWARE_SPECS contains peak FLOP/s and memory
bandwidth values for common accelerators.
```python
from calibrax.profiling.hardware import detect_hardware_specs, HARDWARE_SPECS

specs = detect_hardware_specs()
print(f"Peak FLOP/s: {specs.get('peak_flops', 'N/A')}")
print(f"Memory BW (B/s): {specs.get('memory_bandwidth', 'N/A')}")

# Reference specs for specific hardware
a100_specs = HARDWARE_SPECS["a100_80g"]
```
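The two values above combine into the hardware's ridge point: the arithmetic intensity (FLOP/byte) at which a kernel stops being limited by memory bandwidth. A small sketch, assuming a specs dictionary with the 'peak_flops' (FLOP/s) and 'memory_bandwidth' (B/s) keys used above; the numbers in the example are placeholders, not real hardware specs.

```python
def ridge_point(specs: dict):
    """FLOP/byte at which a kernel stops being memory-bound on this hardware.

    Assumes 'peak_flops' (FLOP/s) and 'memory_bandwidth' (B/s) keys;
    returns None when either value is missing.
    """
    peak = specs.get("peak_flops")
    bandwidth = specs.get("memory_bandwidth")
    if not peak or not bandwidth:
        return None
    return peak / bandwidth


# Placeholder numbers, not real hardware specs.
print(ridge_point({"peak_flops": 100e12, "memory_bandwidth": 2e12}))  # 50.0
print(ridge_point({}))  # None
```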
Roofline Analysis
RooflineAnalyzer compares a workload's arithmetic intensity against the
hardware roofline to determine whether it is compute-bound or memory-bound:
```python
import jax.numpy as jnp
from calibrax.profiling.roofline import RooflineAnalyzer


def matmul_fn(x):
    return jnp.dot(x, x.T)


analyzer = RooflineAnalyzer()
result = analyzer.analyze_operation(matmul_fn, [jnp.ones((64, 64))])
print(f"Arithmetic intensity: {result.arithmetic_intensity:.2f} FLOP/byte")
print(f"Bound: {result.bottleneck}")
for rec in result.recommendations:
    print(f" - {rec}")
```
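The classification rests on the standard roofline model: attainable throughput is min(peak FLOP/s, intensity * bandwidth), and a workload is memory-bound when its intensity falls below the ridge point. A back-of-the-envelope sketch for the 64x64 example, assuming float32 data with the input read once and the output written once (RooflineAnalyzer may count bytes differently) and placeholder hardware numbers:

```python
def roofline(arithmetic_intensity: float, peak_flops: float, bandwidth: float) -> dict:
    """Classic roofline model: attainable = min(peak, intensity * bandwidth)."""
    ridge = peak_flops / bandwidth
    attainable = min(peak_flops, arithmetic_intensity * bandwidth)
    bound = "memory" if arithmetic_intensity < ridge else "compute"
    return {"ridge": ridge, "attainable_flops": attainable, "bound": bound}


# 64x64 float32 x @ x.T: 2 * 64**3 FLOPs over one 16 KiB read plus one 16 KiB write.
flops = 2 * 64 ** 3
bytes_moved = 2 * 64 * 64 * 4
intensity = flops / bytes_moved  # 16.0 FLOP/byte under these assumptions
roofline_result = roofline(intensity, peak_flops=100e12, bandwidth=2e12)
print(roofline_result)
```

With these placeholder specs the ridge point is 50 FLOP/byte, so a small matmul at 16 FLOP/byte lands on the memory-bound side; larger matrices raise the intensity roughly in proportion to their side length.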
Compilation Profiling
CompilationProfiler analyzes JIT compilation overhead, shape consistency,
and XLA optimization effectiveness:
```python
from calibrax.profiling.compilation import CompilationProfiler

profiler = CompilationProfiler()

# Instrument a JIT-compiled function (returns a callable wrapper);
# fn and sample_args are placeholders for your function and its inputs
instrumented = profiler.profile_jit_compilation(fn)
result = instrumented(*sample_args)

# Get the compilation report
report = profiler.get_result()
print(f"Cache hit rate: {report.cache_hit_rate:.2%}")

# Analyze XLA optimization effectiveness
xla_result = profiler.estimate_xla_optimization(fn, *sample_args)
print(f"Optimization score: {xla_result.optimization_score:.2f}")
```
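A cache hit rate reflects how often calls reuse an existing compilation. jax.jit retraces and recompiles when it sees a new input shape or dtype, so a rough mental model is "each new argument signature is a miss". The hypothetical SignatureTracker below sketches that model in plain Python; it ignores static arguments and the other factors real JIT caching depends on, and it is not how CompilationProfiler computes cache_hit_rate.

```python
from collections import Counter


class SignatureTracker:
    """Hypothetical sketch: treat each new argument signature as a cache miss."""

    def __init__(self, fn):
        self.fn = fn
        self.seen = Counter()

    @staticmethod
    def signature(args):
        # (type, shape, dtype) per argument; shape/dtype are None/"" for scalars
        return tuple(
            (type(a).__name__, getattr(a, "shape", None), str(getattr(a, "dtype", "")))
            for a in args
        )

    def __call__(self, *args):
        self.seen[self.signature(args)] += 1
        return self.fn(*args)

    @property
    def cache_hit_rate(self):
        calls = sum(self.seen.values())
        misses = len(self.seen)  # one "compilation" per distinct signature
        return (calls - misses) / calls if calls else 0.0


square = SignatureTracker(lambda x: x * x)
for value in (1, 2, 3, 1.5):  # ints share a signature; the float adds a new one
    square(value)
print(f"hit rate: {square.cache_hit_rate:.2%}")  # 50.00%
```

A low hit rate in a training loop usually points to varying input shapes, for example a smaller final batch; padding batches to a fixed shape is a common remedy.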
Complexity Analysis
analyze_complexity() examines parameter counts, memory requirements, and
computational cost for Flax NNX modules:
```python
from calibrax.profiling.complexity import analyze_complexity

result = analyze_complexity(model, (4, 128))  # pass the input shape as a tuple, not data
print(f"Total parameters: {result.total_parameters:,}")
print(f"Memory (MB): {result.parameter_memory_mb:.1f}")
print(f"Estimated operations: {result.total_estimated_operations:,}")
```
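At its simplest, parameter memory is the parameter count multiplied by the bytes per parameter. A hand check for a hypothetical two-layer MLP (128 -> 64 -> 10), assuming float32 weights and 1 MB = 1024**2 bytes; analyze_complexity's exact unit convention is not stated here, so small differences in the reported MB value are possible.

```python
def dense_param_count(in_features: int, out_features: int, bias: bool = True) -> int:
    """Weights plus optional bias for a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)


def param_memory_mb(num_params: int, bytes_per_param: int = 4) -> float:
    """Parameter storage in MB, assuming float32 (4 bytes) and 1 MB = 1024**2 bytes."""
    return num_params * bytes_per_param / 1024 ** 2


total = dense_param_count(128, 64) + dense_param_count(64, 10)
print(f"{total:,} params, {param_memory_mb(total):.3f} MB")  # 8,906 params, 0.034 MB
```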
XLA Trace Linking
TraceLinker wraps jax.profiler.trace() and records the trace directory
alongside benchmark metadata for later analysis in TensorBoard:
```python
from calibrax.profiling.tracing import TraceLinker

linker = TraceLinker()
with linker.trace("temp/doc-examples/traces/run_001") as ref:
    # Run the workload while XLA profiling is active
    train_step(model, batch)

print(f"Trace saved to: {ref.trace_dir}")
# Open with: tensorboard --logdir temp/doc-examples/traces/run_001
```
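Linking matters because a trace directory on its own does not say which run produced it. A minimal stdlib sketch of the idea, writing a metadata.json file next to the trace; the helper name and file layout are assumptions, not TraceLinker's on-disk format.

```python
import json
import tempfile
from datetime import datetime, timezone
from pathlib import Path


def write_trace_metadata(trace_dir: Path, benchmark_name: str, **extra) -> Path:
    """Hypothetical helper: store run metadata next to a profiler trace directory."""
    trace_dir.mkdir(parents=True, exist_ok=True)
    meta = {
        "benchmark": benchmark_name,
        "trace_dir": str(trace_dir),
        "recorded_at": datetime.now(timezone.utc).isoformat(),
        **extra,
    }
    path = trace_dir / "metadata.json"
    path.write_text(json.dumps(meta, indent=2))
    return path


with tempfile.TemporaryDirectory() as tmp:
    meta_path = write_trace_metadata(Path(tmp) / "run_001", "train_step", batch_size=32)
    loaded = json.loads(meta_path.read_text())
    print(loaded["benchmark"])  # train_step
```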
Carbon Emissions Tracking
Import Path
CarbonTracker is not re-exported from calibrax.profiling, which avoids loading
codecarbon at import time. Import it directly from calibrax.profiling.carbon.
Optional Dependency
Requires codecarbon: uv pip install "calibrax[codecarbon]"
CarbonTracker measures energy consumption and CO2 emissions during a
workload via CodeCarbon:
```python
from calibrax.profiling.carbon import CarbonTracker

with CarbonTracker(country_iso_code="USA") as tracker:
    train(model, data)

result = tracker.result()
print(f"Emissions: {result.emissions_kg_co2:.4f} kg CO2")
print(f"Energy: {result.energy_consumed_kwh:.4f} kWh")
print(f"Duration: {result.duration_sec:.1f}s")
```
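Reported emissions come down to energy consumed multiplied by the carbon intensity of the electricity grid, which is why CarbonTracker takes a country_iso_code. A simplified sketch with a placeholder intensity value; CodeCarbon's own intensity data and energy estimation are considerably more involved.

```python
def estimate_emissions_kg(energy_kwh: float, carbon_intensity_kg_per_kwh: float) -> float:
    """Emissions = energy consumed x grid carbon intensity."""
    return energy_kwh * carbon_intensity_kg_per_kwh


# Placeholder intensity; real values vary by region and over time.
print(f"{estimate_emissions_kg(2.0, 0.4):.2f} kg CO2")  # 0.80 kg CO2
```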
Composing a BenchmarkResult
Combine profiling outputs into a single BenchmarkResult for storage and analysis:
```python
from pathlib import Path

from calibrax.core.models import Metric
from calibrax.core.result import BenchmarkResult

result = BenchmarkResult(
    name="forward_pass",
    domain="training",
    timing=sample,  # from TimingCollector
    resources=summary,  # from ResourceMonitor
    metrics={
        # End-to-end throughput: wall_clock_sec includes warmup batches
        "throughput": Metric(value=sample.num_elements / sample.wall_clock_sec),
    },
)
result.save(Path("temp/doc-examples/results/forward_pass.json"))
```
Best Practices
- Always pass sync_fn for GPU timing; without it, measurements reflect dispatch time, not actual compute time.
- Set the ResourceMonitor sample interval to at least 10x shorter than the expected workload duration, so there are enough samples for meaningful statistics.
- Keep JIT compilation out of steady-state numbers: run the workload once before timing, set warmup_iterations, or exclude the first batch time from throughput calculations.
- Use MemoryOptimizer during development to catch memory leaks early.
Next Steps
- Statistical Analysis: apply bootstrap CI and outlier detection to your timing samples
- Storage & Baselines: save profiling results and establish performance baselines