
calibrax.profiling.hardware

Hardware detection and platform-specific specs for JAX backends. Provides reference peak FLOP/s and memory bandwidth values for common accelerators (TPU v5e, A100, H100, CPU) and utilities for auto-detecting the active JAX backend.

Hardware specifications and detection for profiling.

Provides accelerator specs (TPU v5e, A100, H100, CPU) and utility functions for hardware detection and synchronized execution timing.

detect_hardware_specs()

Detect the current hardware and return the matching specifications.

Uses jax.default_backend() to determine the accelerator type and returns pre-configured specs for that platform.

Returns:

Type Description
dict[str, Any]

Hardware specification dictionary with peak_flops, memory_bandwidth, and critical_intensity keys (among others).
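If critical_intensity follows the usual roofline definition, peak_flops / memory_bandwidth in FLOPs per byte (an assumption; this page does not define it), a kernel is memory-bound when its arithmetic intensity falls below that threshold. The hypothetical helper roofline_bound below classifies a kernel from its FLOP count and bytes moved; the spec numbers used with it are illustrative placeholders.

```python
from typing import Any


def roofline_bound(flops: float, bytes_moved: float, specs: dict[str, Any]) -> tuple[str, float]:
    """Classify a kernel with the roofline model (hypothetical helper).

    Assumes specs["peak_flops"] is in FLOP/s, specs["memory_bandwidth"] in
    bytes/s, and specs["critical_intensity"] == peak_flops / memory_bandwidth.
    """
    if bytes_moved <= 0:
        raise ValueError("bytes_moved must be positive")
    intensity = flops / bytes_moved  # arithmetic intensity, FLOPs per byte
    # Attainable throughput is capped by peak compute or by bandwidth * intensity.
    attainable = min(specs["peak_flops"], intensity * specs["memory_bandwidth"])
    regime = "memory-bound" if intensity < specs["critical_intensity"] else "compute-bound"
    return regime, attainable
```

With placeholder specs of 1e14 FLOP/s and 1e12 B/s (critical intensity 100), an elementwise float32 add over 1M elements (1e6 FLOPs, 12e6 bytes) sits at about 0.08 FLOP/byte and is memory-bound, while a 4096×4096 float32 matmul (2·4096³ FLOPs, about 683 FLOP/byte) is compute-bound.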

measure_execution_time(func, inputs, warmup=3, iterations=10)

Measure the execution time of a JAX function with device synchronization.

JIT-compiles the function, runs it warmup times without timing, then times iterations further executions, using block_until_ready() barriers so that JAX's asynchronous dispatch does not distort the measurement.
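The timing loop described here can be sketched as follows. This is a hypothetical reimplementation, not calibrax's source: it assumes every timed call is individually blocked on, and it exposes compile_fn and sync_fn hooks (defaulting to jax.jit and jax.block_until_ready) only so the loop can be exercised without a JAX install.

```python
import time
from typing import Any, Callable, Optional, Sequence


def measure_execution_time_sketch(
    func: Callable[..., Any],
    inputs: Sequence[Any],
    warmup: int = 3,
    iterations: int = 10,
    compile_fn: Optional[Callable[[Callable[..., Any]], Callable[..., Any]]] = None,
    sync_fn: Optional[Callable[[Any], Any]] = None,
) -> float:
    """Mean wall-clock seconds per call (hypothetical sketch, not calibrax's code)."""
    if iterations <= 0:
        raise ValueError("iterations must be positive")
    if compile_fn is None or sync_fn is None:
        import jax  # only needed when the JAX defaults are used

        compile_fn = compile_fn or jax.jit
        sync_fn = sync_fn or jax.block_until_ready
    compiled = compile_fn(func)
    # Warmup calls trigger tracing/compilation and are excluded from timing.
    for _ in range(warmup):
        sync_fn(compiled(*inputs))
    start = time.perf_counter()
    for _ in range(iterations):
        # Block on each result so queued device work is counted in the timing.
        sync_fn(compiled(*inputs))
    elapsed = time.perf_counter() - start
    return elapsed / iterations
```

With JAX installed, measure_execution_time_sketch(lambda a: a @ a, [jnp.ones((512, 512))]) would return the mean seconds per matmul, with compilation cost absorbed by the warmup calls.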

Parameters:

Name Type Description Default
func Callable[..., Any]

JAX function to benchmark.

required
inputs list[Array]

Input arguments as a list of arrays.

required
warmup int

Number of untimed warmup iterations, which absorb JIT compilation cost.

3
iterations int

Number of timed iterations.

10

Returns:

Type Description
float

Average execution time in seconds.
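An average time can be combined with peak_flops from detect_hardware_specs() to estimate utilization. The helper name achieved_utilization and the numbers below are hypothetical: a 1024×1024 matmul performs about 2·1024³ FLOPs, so a measured 0.01 s against a placeholder peak of 1e12 FLOP/s works out to roughly 21% of peak.

```python
def achieved_utilization(flop_count: float, seconds: float, peak_flops: float) -> tuple[float, float]:
    """Return (achieved FLOP/s, fraction of peak) for one timed call (hypothetical helper)."""
    if seconds <= 0 or peak_flops <= 0:
        raise ValueError("seconds and peak_flops must be positive")
    achieved = flop_count / seconds  # throughput actually observed
    return achieved, achieved / peak_flops


# Placeholder inputs: n x n matmul FLOPs, a made-up timing, a made-up peak.
n = 1024
achieved, fraction = achieved_utilization(2 * n**3, 0.01, 1e12)
```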