calibrax.profiling.hardware
Hardware detection and platform-specific specs for JAX backends. Provides reference peak FLOP/s and memory-bandwidth values for common accelerators (TPU v5e, A100, H100, CPU), plus utility functions for auto-detecting the active JAX backend and for timing execution with device synchronization.
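Peak FLOP/s and memory bandwidth are the two inputs of the roofline model. This page names a critical_intensity key but does not give its formula. Assuming it follows the standard roofline definition, with $P_{\text{peak}}$ the peak_flops value (FLOP/s) and $B_{\text{mem}}$ the memory_bandwidth value (bytes/s), the relationships are:

$$
I = \frac{\text{FLOPs performed}}{\text{bytes moved to and from memory}}, \qquad
I_{\text{crit}} = \frac{P_{\text{peak}}}{B_{\text{mem}}} \;\; [\text{FLOP/byte}]
$$

$$
P_{\text{attainable}}(I) = \min\bigl(P_{\text{peak}},\; I \cdot B_{\text{mem}}\bigr)
$$

A workload with $I < I_{\text{crit}}$ is memory-bound; at or above $I_{\text{crit}}$ it can in principle reach peak compute.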
detect_hardware_specs()
Detect current hardware and return appropriate specifications.
Uses jax.default_backend() to determine the accelerator type and returns pre-configured specs for that platform.
Returns:

| Type | Description |
|---|---|
| dict[str, Any] | Hardware specification dictionary with peak_flops, memory_bandwidth, and critical_intensity keys (among others). |
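The backend-to-specs lookup described above can be sketched as follows. This is a minimal illustration, not calibrax's implementation: the table EXAMPLE_SPECS and the helpers make_spec, lookup_specs and classify are hypothetical, the numbers are round placeholders rather than real accelerator figures, and critical_intensity is assumed to be peak_flops / memory_bandwidth. The table is keyed by the platform string that jax.default_backend() returns (for example "cpu", "gpu" or "tpu").

```python
from __future__ import annotations

from typing import Any

import jax


def make_spec(name: str, peak_flops: float, memory_bandwidth: float) -> dict[str, Any]:
    """Build a spec dict whose critical_intensity is the roofline ridge point."""
    return {
        "name": name,
        "peak_flops": peak_flops,              # FLOP/s
        "memory_bandwidth": memory_bandwidth,  # bytes/s
        "critical_intensity": peak_flops / memory_bandwidth,  # FLOP/byte
    }


# Hypothetical table keyed by the platform string from jax.default_backend().
# The numbers are round PLACEHOLDERS for illustration, not real hardware specs.
EXAMPLE_SPECS: dict[str, dict[str, Any]] = {
    "cpu": make_spec("example-cpu", peak_flops=1e12, memory_bandwidth=1e11),
    "gpu": make_spec("example-gpu", peak_flops=3e14, memory_bandwidth=2e12),
    "tpu": make_spec("example-tpu", peak_flops=2e14, memory_bandwidth=8e11),
}


def lookup_specs(
    table: dict[str, dict[str, Any]], backend: str | None = None
) -> dict[str, Any]:
    """Return the entry for `backend`, defaulting to the active JAX backend."""
    if backend is None:
        backend = jax.default_backend()  # e.g. "cpu", "gpu" or "tpu"
    if backend not in table:
        raise ValueError(f"no specs registered for backend {backend!r}")
    return table[backend]


def classify(arithmetic_intensity: float, spec: dict[str, Any]) -> str:
    """Below the ridge point, memory bandwidth (not peak FLOP/s) is the limit."""
    if arithmetic_intensity < spec["critical_intensity"]:
        return "memory-bound"
    return "compute-bound"


if __name__ == "__main__":
    spec = lookup_specs(EXAMPLE_SPECS)
    print(spec["name"], spec["critical_intensity"], "FLOP/byte")
    # float32 elementwise add: 1 FLOP per 12 bytes (two 4-byte reads, one write).
    print(classify(1 / 12, spec))
```

Comparing a kernel's arithmetic intensity with critical_intensity is the typical use of such a spec dict: an elementwise add sits far below the ridge point on any of these entries, so its speed is governed by memory_bandwidth rather than peak_flops.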
measure_execution_time(func, inputs, warmup=3, iterations=10)
Measure execution time of a JAX function with synchronization.
JIT-compiles the function, runs warmup iterations, then times the number of executions given by iterations, using block_until_ready() barriers so that JAX's asynchronous dispatch does not cut the measurement short.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| func | Callable[..., Any] | JAX function to benchmark. | required |
| inputs | list[Array] | Input arguments as a list of arrays. | required |
| warmup | int | Number of warmup iterations (for JIT compilation). | 3 |
| iterations | int | Number of timed iterations. | 10 |
Returns:

| Type | Description |
|---|---|
| float | Average execution time in seconds. |
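The synchronized timing loop described for measure_execution_time can be sketched in plain JAX. This is a minimal sketch, not the library's code: time_jitted is a hypothetical name, and placing a barrier after every timed call (rather than once after the loop) is an assumption. It calls jax.block_until_ready on the result instead of the array method, so functions that return tuples or other pytrees are synchronized too.

```python
from __future__ import annotations

import time
from collections.abc import Callable, Sequence
from typing import Any

import jax
import jax.numpy as jnp


def time_jitted(
    func: Callable[..., Any],
    inputs: Sequence[jax.Array],
    warmup: int = 3,
    iterations: int = 10,
) -> float:
    """Return the mean wall-clock seconds per call of jax.jit(func)(*inputs)."""
    if iterations < 1:
        raise ValueError("iterations must be >= 1")
    jitted = jax.jit(func)
    # Warmup: the first call traces and compiles; blocking ensures compilation
    # and execution have finished before the timer starts.
    for _ in range(warmup):
        jax.block_until_ready(jitted(*inputs))
    start = time.perf_counter()
    for _ in range(iterations):
        # JAX dispatches asynchronously; without this barrier the loop would
        # mostly measure dispatch overhead rather than device execution.
        jax.block_until_ready(jitted(*inputs))
    return (time.perf_counter() - start) / iterations


if __name__ == "__main__":
    n = 512
    key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
    a = jax.random.normal(key_a, (n, n))
    b = jax.random.normal(key_b, (n, n))
    seconds = time_jitted(jnp.matmul, [a, b])
    # An (n, n) @ (n, n) matmul performs roughly 2 * n**3 FLOPs.
    print(f"{seconds * 1e3:.3f} ms/call, {2 * n**3 / seconds / 1e9:.1f} GFLOP/s")
```

Dividing a known FLOP count by the returned time gives achieved throughput, which can then be set against the peak_flops value reported by detect_hardware_specs() to estimate utilization.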