
calibrax.profiling.timing

Wall-clock timing collector with GPU synchronization support. TimingCollector.measure_iteration() consumes an iterator and records per-batch times, total wall clock, and element counts. Supports warm-up iteration exclusion and JIT compilation time measurement.

Framework-agnostic timing with configurable result synchronization.

Provides TimingSample (a frozen dataclass) and TimingCollector for measuring iteration throughput with a per-batch timing breakdown. All timestamps come from time.perf_counter(), a monotonic, high-resolution clock suited to measuring short intervals.

TimingSample(*, wall_clock_sec, per_batch_times, first_batch_time, num_batches, num_elements, compilation_time_sec=None, warmup_batches_excluded=0) dataclass

Result of timing an iteration through a data pipeline.

Attributes:

Name Type Description
wall_clock_sec float

Total wall-clock time for the iteration.

per_batch_times tuple[float, ...]

Per-batch durations in seconds (warmup batches excluded).

first_batch_time float

Time from iteration start to first batch completion.

num_batches int

Number of batches consumed (including warmup).

num_elements int

Total elements processed (via count_fn).

compilation_time_sec float | None

JIT compilation time, if measured separately.

warmup_batches_excluded int

Number of leading batches excluded from per_batch_times.
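
The attributes above can be combined into the usual summary figures. The following is a minimal sketch that uses a stand-in object with the documented field names rather than a real TimingSample; the `summarize` helper and the sample values are hypothetical and only illustrate how the fields relate to each other.

```python
import statistics
from types import SimpleNamespace

# Stand-in carrying the documented field names; the values are made up.
sample = SimpleNamespace(
    wall_clock_sec=2.0,
    per_batch_times=(0.25, 0.25, 0.5),
    first_batch_time=0.6,
    num_batches=5,
    num_elements=160,
    compilation_time_sec=None,
    warmup_batches_excluded=2,
)


def summarize(s):
    """Derive throughput and steady-state batch statistics (hypothetical helper)."""
    return {
        # wall_clock_sec covers the whole run, so this figure includes warm-up cost.
        "elements_per_sec": s.num_elements / s.wall_clock_sec,
        # per_batch_times already excludes the warm-up batches.
        "mean_batch_sec": statistics.mean(s.per_batch_times),
        "median_batch_sec": statistics.median(s.per_batch_times),
        "timed_batches": len(s.per_batch_times),
    }


summary = summarize(sample)
print(summary)
```

In this made-up sample, the number of timed batches equals num_batches minus warmup_batches_excluded, which is the relationship the attribute descriptions suggest.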

to_dict()

Serialize to a JSON-compatible dictionary.

from_dict(data) classmethod

Deserialize from a dictionary.

Parameters:

Name Type Description Default
data dict[str, Any]

Dictionary with TimingSample fields.

required

Returns:

Type Description
TimingSample

Reconstructed TimingSample instance.
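
To illustrate the round trip that to_dict() and from_dict() describe, here is a small sketch built on a stand-in frozen dataclass. `SampleSketch` is hypothetical and is not the calibrax class; in particular, converting per_batch_times between tuple and list is an assumption about how a JSON-compatible dictionary might be produced.

```python
from __future__ import annotations

import json
from dataclasses import asdict, dataclass
from typing import Any


@dataclass(frozen=True)
class SampleSketch:
    """Hypothetical stand-in mirroring the documented TimingSample fields."""

    wall_clock_sec: float
    per_batch_times: tuple[float, ...]
    first_batch_time: float
    num_batches: int
    num_elements: int
    compilation_time_sec: float | None = None
    warmup_batches_excluded: int = 0

    def to_dict(self) -> dict[str, Any]:
        data = asdict(self)
        # JSON has no tuple type, so expose the per-batch times as a list.
        data["per_batch_times"] = list(self.per_batch_times)
        return data

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> SampleSketch:
        # Restore the tuple so the rebuilt instance compares equal to the original.
        return cls(**{**data, "per_batch_times": tuple(data["per_batch_times"])})


original = SampleSketch(
    wall_clock_sec=1.5,
    per_batch_times=(0.4, 0.5),
    first_batch_time=0.6,
    num_batches=3,
    num_elements=96,
    warmup_batches_excluded=1,
)
# Pass through an actual JSON string to confirm the dictionary is serializable.
restored = SampleSketch.from_dict(json.loads(json.dumps(original.to_dict())))
```

Because the stand-in is frozen and stores a tuple, `restored == original` holds after the round trip.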

TimingCollector(sync_fn=None, warmup_iterations=0)

Framework-agnostic timing with configurable GPU sync support.

Uses time.perf_counter() exclusively for accurate benchmarking. Supports configurable result synchronization via sync_fn and warm-up iteration exclusion for JIT-compiled workloads.

JAX dispatches operations asynchronously -- the host returns immediately while the device is still computing. Without an explicit synchronization barrier, perf_counter measures only host-side dispatch latency, not actual compute time. Pass a sync_fn that calls block_until_ready() on the workload result to force the host to wait for device completion before recording the timestamp.

Example -- JAX GPU timing with warm-up:

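import jax

from calibrax.profiling.timing import TimingCollector

# step_fn (the per-batch computation) and data_iter (the batch iterator)
# are assumed to be defined by the caller.
jitted_step = jax.jit(step_fn)  # wrap once, outside the timed loop

collector = TimingCollector(
    sync_fn=lambda result: result.block_until_ready(),
    warmup_iterations=2,
)
sample = collector.measure_iteration(data_iter, num_batches=50, process_fn=jitted_step)
# sample.per_batch_times excludes the first 2 batches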
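
The effect of the synchronization barrier can be seen without a GPU. The sketch below simulates asynchronous dispatch with a background thread; `FakeDeviceResult` and `timed` are hypothetical names, and the thread is only a stand-in for device work, not how JAX is implemented.

```python
import threading
import time


class FakeDeviceResult:
    """Hypothetical stand-in for an asynchronously computed device array."""

    def __init__(self, work_sec: float) -> None:
        # Start the simulated device work in the background and return at once.
        self._thread = threading.Thread(target=time.sleep, args=(work_sec,))
        self._thread.start()

    def block_until_ready(self) -> "FakeDeviceResult":
        # Wait for the background work to finish, like a device sync barrier.
        self._thread.join()
        return self


def timed(sync: bool, work_sec: float = 0.05) -> float:
    start = time.perf_counter()
    result = FakeDeviceResult(work_sec)  # returns before the work is done
    if sync:
        result.block_until_ready()
    elapsed = time.perf_counter() - start
    result.block_until_ready()  # always drain so no work leaks into later timings
    return elapsed


unsynced = timed(sync=False)
synced = timed(sync=True)
print(f"without sync: {unsynced * 1e3:.2f} ms, with sync: {synced * 1e3:.2f} ms")
```

Without the barrier the measured time reflects only how long it took to launch the work; with it, the measurement covers the full simulated compute time.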

Parameters:

Name Type Description Default
sync_fn Callable[[Any], object] | None

Synchronization function called with each batch result.
For JAX: lambda result: result.block_until_ready()
For PyTorch: lambda _: torch.cuda.synchronize()
For CPU-only: None (default, no-op)

None
warmup_iterations int

Number of initial batches to exclude from per_batch_times statistics. They are still executed (important for JIT warm-up) but omitted from the timing result. Default: 0.

0

Initialize TimingCollector.

Parameters:

Name Type Description Default
sync_fn Callable[[Any], object] | None

Synchronization function called with each batch result.

None
warmup_iterations int

Number of initial batches to exclude from timing stats.

0

measure_iteration(iterator, num_batches=None, process_fn=None, count_fn=None)

Measure timing for batches from an iterator.

Warm-up batches (if configured) are executed but excluded from per_batch_times. wall_clock_sec covers the entire run including warm-up. num_batches reflects total batches consumed.

Parameters:

Name Type Description Default
iterator Iterator[Any]

Data iterator to measure.

required
num_batches int | None

Maximum number of batches to consume (including warmup). None exhausts the iterator.

None
process_fn Callable[[Any], Any] | None

Optional per-batch function whose execution is timed. Defaults to identity (the yielded batch is treated as the result).

None
count_fn Callable[[Any], int] | None

Function to count elements per batch. Default: 1 per batch.

None

Returns:

Type Description
TimingSample

TimingSample with timing measurements.
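
The semantics described for measure_iteration can be sketched as a plain timing loop. This is an illustration, not the calibrax implementation: `measure_sketch` is a hypothetical function, it returns a dictionary instead of a TimingSample, and it assumes that each per-batch time includes fetching the batch from the iterator and that count_fn receives the yielded batch.

```python
import time
from typing import Any, Callable, Iterator, Optional


def measure_sketch(
    iterator: Iterator[Any],
    num_batches: Optional[int] = None,
    process_fn: Optional[Callable[[Any], Any]] = None,
    count_fn: Optional[Callable[[Any], int]] = None,
    sync_fn: Optional[Callable[[Any], object]] = None,
    warmup_iterations: int = 0,
) -> dict:
    """Illustrative timing loop; not the calibrax implementation."""
    per_batch_times = []
    consumed = 0
    elements = 0
    first_batch_time = 0.0
    run_start = time.perf_counter()
    while num_batches is None or consumed < num_batches:
        batch_start = time.perf_counter()
        try:
            batch = next(iterator)
        except StopIteration:
            break
        result = process_fn(batch) if process_fn is not None else batch
        if sync_fn is not None:
            sync_fn(result)  # wait for device work before stopping the clock
        batch_end = time.perf_counter()
        consumed += 1
        elements += count_fn(batch) if count_fn is not None else 1
        if consumed == 1:
            first_batch_time = batch_end - run_start
        if consumed > warmup_iterations:
            # Warm-up batches ran above but are left out of the per-batch list.
            per_batch_times.append(batch_end - batch_start)
    wall_clock_sec = time.perf_counter() - run_start  # includes warm-up
    return {
        "wall_clock_sec": wall_clock_sec,
        "per_batch_times": tuple(per_batch_times),
        "first_batch_time": first_batch_time,
        "num_batches": consumed,
        "num_elements": elements,
        "warmup_batches_excluded": min(warmup_iterations, consumed),
    }


batches = iter([[1, 2, 3], [4, 5], [6], [7, 8, 9, 10]])
stats = measure_sketch(
    batches, num_batches=3, process_fn=sum, count_fn=len, warmup_iterations=1
)
print(stats["num_batches"], stats["num_elements"], len(stats["per_batch_times"]))
```

With num_batches=3 and one warm-up iteration, three batches are consumed, two are timed, and the fourth batch is left in the iterator.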

measure_compilation_time(fn, *args)

Measure JIT compilation time for a JAX function.

Calls jax.jit(fn).lower(*args).compile() and times it. This measures the XLA compilation step only, not execution.

Parameters:

Name Type Description Default
fn Callable[..., Any]

JAX function to compile.

required
*args Any

Example arguments for lowering.

()

Returns:

Type Description
float

Compilation time in seconds.
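
For reference, the call chain named above can be timed directly with JAX's ahead-of-time API. This is a minimal sketch, assuming JAX is installed; `compile_time_sketch` and `scaled_sum` are hypothetical names, and the library's own method may structure its timing differently.

```python
import time

import jax
import jax.numpy as jnp


def compile_time_sketch(fn, *args):
    """Time ahead-of-time lowering and XLA compilation without running fn."""
    start = time.perf_counter()
    jax.jit(fn).lower(*args).compile()
    return time.perf_counter() - start


def scaled_sum(x):  # hypothetical example function
    return jnp.sum(x * 2.0)


elapsed = compile_time_sketch(scaled_sum, jnp.ones((128, 128)))
print(f"compile time: {elapsed:.4f} s")
```

As written, the timed span in this sketch covers tracing and lowering as well as the XLA compile step, since all three happen inside the chained call; the function itself is never executed on the example arguments.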