calibrax.profiling.timing
Wall-clock timing collector with GPU synchronization support.
TimingCollector.measure_iteration() consumes an iterator and records per-batch times, total wall-clock time, and element counts. Result synchronization is configurable, which keeps the timing framework-agnostic.
Provides TimingSample (a frozen dataclass) and TimingCollector for measuring iteration throughput with a per-batch timing breakdown. All timestamps come from time.perf_counter(), the clock with the highest available resolution for short durations. Supports warm-up iteration exclusion and JIT compilation time measurement.
TimingSample(*, wall_clock_sec, per_batch_times, first_batch_time, num_batches, num_elements, compilation_time_sec=None, warmup_batches_excluded=0)
dataclass
Result of timing an iteration through a data pipeline.
Attributes:
| Name | Type | Description |
|---|---|---|
| wall_clock_sec | float | Total wall-clock time for the iteration. |
| per_batch_times | tuple[float, ...] | Per-batch durations in seconds (warmup batches excluded). |
| first_batch_time | float | Time from iteration start to first batch completion. |
| num_batches | int | Number of batches consumed (including warmup). |
| num_elements | int | Total elements processed (via count_fn). |
| compilation_time_sec | float \| None | JIT compilation time, if measured separately. |
| warmup_batches_excluded | int | Number of leading batches excluded from per_batch_times. |
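The fields above are enough to derive the usual summary figures. The sketch below uses a hypothetical helper, summarize_timing (not part of calibrax), that takes raw field values rather than a TimingSample so it runs without the library installed. It assumes num_elements, like num_batches and wall_clock_sec, covers warm-up batches, so the end-to-end throughput includes warm-up while the latency statistics (from per_batch_times) do not.

```python
import statistics


def summarize_timing(
    wall_clock_sec: float,
    per_batch_times: tuple[float, ...],
    num_elements: int,
) -> dict[str, float]:
    """Derive summary statistics from TimingSample-style fields (hypothetical helper)."""
    if wall_clock_sec <= 0:
        raise ValueError("wall_clock_sec must be positive")
    summary = {
        # End-to-end throughput over the whole run (assumed to include warm-up).
        "elements_per_sec": num_elements / wall_clock_sec,
    }
    # per_batch_times is empty when every batch was a warm-up batch.
    if per_batch_times:
        # Steady-state latency statistics (warm-up batches already excluded).
        summary["mean_batch_sec"] = statistics.fmean(per_batch_times)
        summary["median_batch_sec"] = statistics.median(per_batch_times)
    return summary


print(summarize_timing(2.0, (0.1, 0.2, 0.3), num_elements=400))
```

The median is often a better steady-state figure than the mean when a few batches are slowed by garbage collection or I/O stalls.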
to_dict()
Serialize to a JSON-compatible dictionary.
from_dict(data)
classmethod
Deserialize from a dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Dictionary with TimingSample fields. | required |
Returns:
| Type | Description |
|---|---|
| TimingSample | Reconstructed TimingSample instance. |
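to_dict() and from_dict() exist mainly to round-trip results through JSON. The actual encoding calibrax uses is not documented here, so the sketch below relies on a hypothetical stand-in, SampleStandIn, that mirrors TimingSample's documented fields and assumes the natural encoding: per_batch_times becomes a JSON list on the way out and is turned back into a tuple on the way in. Without that conversion the restored object would not compare equal, because a tuple never equals a list.

```python
import json
from dataclasses import asdict, dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class SampleStandIn:
    """Hypothetical stand-in mirroring TimingSample's documented fields."""

    wall_clock_sec: float
    per_batch_times: tuple[float, ...]
    first_batch_time: float
    num_batches: int
    num_elements: int
    compilation_time_sec: Optional[float] = None
    warmup_batches_excluded: int = 0

    def to_dict(self) -> dict[str, Any]:
        data = asdict(self)
        # JSON has no tuple type, so emit a list explicitly.
        data["per_batch_times"] = list(self.per_batch_times)
        return data

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SampleStandIn":
        kwargs = dict(data)
        # Restore the tuple so the result compares equal to the original.
        kwargs["per_batch_times"] = tuple(kwargs["per_batch_times"])
        return cls(**kwargs)


original = SampleStandIn(
    wall_clock_sec=1.5,
    per_batch_times=(0.2, 0.25),
    first_batch_time=0.9,
    num_batches=4,
    num_elements=128,
    warmup_batches_excluded=2,
)
restored = SampleStandIn.from_dict(json.loads(json.dumps(original.to_dict())))
assert restored == original
```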
TimingCollector(sync_fn=None, warmup_iterations=0)
Framework-agnostic timing with configurable GPU sync support.
Uses time.perf_counter() for all timestamps. Supports configurable result synchronization via sync_fn and warm-up iteration exclusion for JIT-compiled workloads.
JAX dispatches operations asynchronously -- the host returns
immediately while the device is still computing. Without an explicit
synchronization barrier, perf_counter measures only host-side
dispatch latency, not actual compute time. Pass a sync_fn that
calls block_until_ready() on the workload result to force the host
to wait for device completion before recording the timestamp.
Example -- JAX GPU timing with warm-up:
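```python
import jax
import jax.numpy as jnp
from calibrax.profiling.timing import TimingCollector
def step_fn(batch):
    # Stand-in workload; substitute your own step function.
    return jnp.tanh(batch @ batch.T)

run_step = jax.jit(step_fn)  # jit once so the wrapper is not rebuilt per batch
# Stand-in data source: 50 batches of shape (256, 256).
data_iter = (jnp.full((256, 256), float(i)) for i in range(50))
collector = TimingCollector(
    sync_fn=lambda result: result.block_until_ready(),
    warmup_iterations=2,
)
sample = collector.measure_iteration(data_iter, num_batches=50, process_fn=run_step)
# sample.per_batch_times excludes the first 2 batches
```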
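The effect of a missing barrier can be reproduced without a GPU. In the sketch below a thread pool plays the role of the asynchronous device: submit() returns a Future immediately, much as a JAX call returns before the device finishes, and Future.result() plays the role of block_until_ready(). The time_call helper and its sync_fn argument are illustrative only, not calibrax's implementation.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Optional


def time_call(
    fn: Callable[[], Any],
    sync_fn: Optional[Callable[[Any], object]] = None,
) -> float:
    """Time one call with perf_counter, optionally waiting on its result first."""
    start = time.perf_counter()
    result = fn()
    if sync_fn is not None:
        sync_fn(result)  # barrier: wait until the "device" work is done
    return time.perf_counter() - start


# Two workers so the second measurement does not queue behind the first.
with ThreadPoolExecutor(max_workers=2) as pool:

    def work() -> Future:
        # Dispatch 50 ms of work and return immediately, like an async device call.
        return pool.submit(time.sleep, 0.05)

    unsynced = time_call(work)  # measures dispatch latency only
    synced = time_call(work, sync_fn=lambda fut: fut.result())  # includes the work

print(f"without sync: {unsynced * 1e3:.2f} ms, with sync: {synced * 1e3:.2f} ms")
```

Only the synchronized measurement reflects the 50 ms of work; the unsynchronized one reports how long it took to hand the work off.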
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sync_fn | Callable[[Any], object] \| None | Synchronization function called with each batch result. For JAX, pass a function that calls block_until_ready() on the result, such as lambda result: result.block_until_ready(). | None |
| warmup_iterations | int | Number of initial batches to exclude from per_batch_times statistics. They are still executed (important for JIT warm-up) but omitted from the timing result. Default: 0. | 0 |
Initialize TimingCollector.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sync_fn | Callable[[Any], object] \| None | Synchronization function called with each batch result. | None |
| warmup_iterations | int | Number of initial batches to exclude from timing stats. | 0 |
measure_iteration(iterator, num_batches=None, process_fn=None, count_fn=None)
Measure timing for batches from an iterator.
Warm-up batches (if configured) are executed but excluded from
per_batch_times. wall_clock_sec covers the entire run
including warm-up. num_batches reflects total batches consumed.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| iterator | Iterator[Any] | Data iterator to measure. | required |
| num_batches | int \| None | Maximum number of batches to consume (including warmup). None exhausts the iterator. | None |
| process_fn | Callable[[Any], Any] \| None | Optional per-batch function whose execution is timed. Defaults to identity (the yielded batch is treated as the result). | None |
| count_fn | Callable[[Any], int] \| None | Function returning the number of elements in a batch. Default: each batch counts as 1 element. | None |
Returns:
| Type | Description |
|---|---|
| TimingSample | TimingSample with timing measurements. |
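The semantics described above (warm-up batches executed but dropped from per_batch_times, while wall_clock_sec and num_batches cover the whole run) can be sketched as a plain loop. This is a hypothetical reimplementation for illustration, not calibrax's code: it assumes each per-batch duration spans fetching the batch, counting it, running process_fn, and calling sync_fn, and it returns a plain dict instead of a TimingSample.

```python
import itertools
import time
from typing import Any, Callable, Iterator, Optional


def measure_iteration_sketch(
    iterator: Iterator[Any],
    num_batches: Optional[int] = None,
    process_fn: Optional[Callable[[Any], Any]] = None,
    count_fn: Optional[Callable[[Any], int]] = None,
    sync_fn: Optional[Callable[[Any], object]] = None,
    warmup_iterations: int = 0,
) -> dict[str, Any]:
    """Hypothetical sketch of per-batch timing with warm-up exclusion."""
    process = process_fn if process_fn is not None else (lambda batch: batch)
    count = count_fn if count_fn is not None else (lambda batch: 1)
    # islice(it, None) consumes the whole iterator.
    source = itertools.islice(iterator, num_batches)

    batch_times: list[float] = []
    num_elements = 0
    first_batch_time = 0.0
    start = time.perf_counter()
    prev = start
    for batch in source:  # fetching the batch falls inside its interval
        num_elements += count(batch)
        result = process(batch)
        if sync_fn is not None:
            sync_fn(result)  # wait for asynchronous work before timestamping
        now = time.perf_counter()
        if not batch_times:
            first_batch_time = now - start
        batch_times.append(now - prev)
        prev = now
    wall_clock_sec = time.perf_counter() - start

    return {
        "wall_clock_sec": wall_clock_sec,  # includes warm-up batches
        "per_batch_times": tuple(batch_times[warmup_iterations:]),
        "first_batch_time": first_batch_time,
        "num_batches": len(batch_times),  # includes warm-up batches
        "num_elements": num_elements,
        "warmup_batches_excluded": min(warmup_iterations, len(batch_times)),
    }


sample = measure_iteration_sketch(iter(range(10)), num_batches=6, warmup_iterations=2)
print(sample["num_batches"], len(sample["per_batch_times"]))  # 6 4
```

Timestamping only after sync_fn returns is what makes each interval reflect completed work rather than dispatch.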
measure_compilation_time(fn, *args)
Measure JIT compilation time for a JAX function.
Calls jax.jit(fn).lower(*args).compile() and times it.
This measures the XLA compilation step only, not execution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable[..., Any] | JAX function to compile. | required |
| *args | Any | Example arguments for lowering. | () |
Returns:
| Type | Description |
|---|---|
| float | Compilation time in seconds. |
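measure_compilation_time() follows a general pattern: time one call whose only job is to build something, and keep that time separate from execution. The library-agnostic sketch below illustrates the pattern with a hypothetical helper, timed; Python's built-in compile() merely stands in for an expensive build step. With JAX installed, the same helper could wrap lambda: jax.jit(fn).lower(*args).compile(), the call described above, returning the compiled executable together with its build time.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def timed(build: Callable[[], T]) -> tuple[T, float]:
    """Run a zero-argument step once; return its result and duration in seconds."""
    start = time.perf_counter()
    result = build()
    elapsed = time.perf_counter() - start
    return result, elapsed


# Stand-in build step: compile a Python expression, then execute it separately.
code, compile_sec = timed(
    lambda: compile("sum(i * i for i in range(1000))", "<expr>", "eval")
)
value, run_sec = timed(lambda: eval(code))
print(f"compile: {compile_sec * 1e6:.1f} us, run: {run_sec * 1e6:.1f} us, value={value}")
```

Returning the built object alongside the duration lets the caller reuse it for execution instead of paying the build cost a second time.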