calibrax.profiling.tracing

XLA trace linking for JAX workloads. Wraps jax.profiler.trace() and records trace file paths alongside benchmark run metadata, enabling post-hoc analysis in TensorBoard.

XLA trace linking for connecting JAX profiler output to benchmark runs.

Provides a simple context manager that wraps jax.profiler.trace() and records the trace output directory so it can be associated with Store run metadata. It does not parse trace files; it only links file paths to benchmark results.

TraceReference(*, trace_dir, run_id=None) dataclass

Reference to a JAX profiler trace output.

Attributes:

| Name | Type | Description |
|---|---|---|
| trace_dir | str | Directory where the trace files were written. |
| run_id | str \| None | Optional benchmark run ID to link the trace to. |

to_dict()

Serialize to a JSON-compatible dictionary.

from_dict(data) classmethod

Deserialize from a dictionary.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Dictionary with trace reference fields. | required |

Returns:

| Type | Description |
|---|---|
| TraceReference | Reconstructed TraceReference instance. |
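A JSON round trip through to_dict() and from_dict() might look like the following sketch. `TraceRefSketch` is a hypothetical mirror, not the calibrax class; it assumes the serialized keys match the field names (trace_dir, run_id) and that a missing run_id deserializes to None.

```python
import json
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class TraceRefSketch:
    """Hypothetical mirror of TraceReference, for illustration only."""

    trace_dir: str
    run_id: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        # Both fields are str or None, so asdict() output is JSON-serializable as-is.
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TraceRefSketch":
        # Assumption: "run_id" may be absent, in which case it defaults to None.
        return cls(trace_dir=data["trace_dir"], run_id=data.get("run_id"))


ref = TraceRefSketch(trace_dir="/tmp/my_trace", run_id="run-001")
payload = json.dumps(ref.to_dict())
print(payload)  # {"trace_dir": "/tmp/my_trace", "run_id": "run-001"}
restored = TraceRefSketch.from_dict(json.loads(payload))
print(restored == ref)  # True
```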

TraceLinker

Links JAX profiler traces to benchmark runs.

Usage:

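import jax.numpy as jnp
from calibrax.profiling.tracing import TraceLinker

linker = TraceLinker()
with linker.trace("/tmp/my_trace") as ref:
    # Run the workload to profile (a small placeholder computation here).
    jnp.ones((1024, 1024)).sum().block_until_ready()
print(ref.trace_dir)  # "/tmp/my_trace"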

trace(log_dir, *, run_id=None, create_perfetto_link=False, create_perfetto_trace=False)

Start an XLA profiling session and record output metadata.

Wraps jax.profiler.trace() and records the output directory path as a TraceReference for downstream Store linkage.
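The wrapping described above can be sketched as a generator-based context manager. This is a hypothetical illustration, not calibrax's implementation: `traced`, `SketchTraceRef`, and the injectable `profiler` parameter are invented names, and the sketch assumes the reference is built on entry and yielded while the profiler is active. By default it forwards to jax.profiler.trace(log_dir, create_perfetto_link=..., create_perfetto_trace=...).

```python
import contextlib
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Iterator, Optional


@dataclass
class SketchTraceRef:
    """Hypothetical stand-in for TraceReference."""

    trace_dir: str
    run_id: Optional[str] = None


@contextlib.contextmanager
def traced(
    log_dir: "str | Path",
    *,
    run_id: Optional[str] = None,
    create_perfetto_link: bool = False,
    create_perfetto_trace: bool = False,
    profiler: Optional[Callable[..., Any]] = None,  # hypothetical hook for testing
) -> Iterator[SketchTraceRef]:
    """Run a profiler context and yield a reference to its output directory."""
    if profiler is None:
        # Import lazily so the sketch can be loaded without JAX installed.
        import jax.profiler

        profiler = jax.profiler.trace
    trace_dir = str(log_dir)  # normalize Path to str for serialization
    ref = SketchTraceRef(trace_dir=trace_dir, run_id=run_id)
    with profiler(
        trace_dir,
        create_perfetto_link=create_perfetto_link,
        create_perfetto_trace=create_perfetto_trace,
    ):
        yield ref
```

In normal use `profiler` is omitted so the sketch falls back to JAX's profiler; passing a stub context manager lets the linking logic be exercised without JAX.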

Parameters:

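| Name | Type | Description | Default |
|---|---|---|---|
| log_dir | str \| Path | Directory for trace output files. | required |
| run_id | str \| None | Optional benchmark run ID to associate with the trace. | None |
| create_perfetto_link | bool | Forwarded to jax.profiler.trace(); if True, JAX prints a link for opening the trace in the Perfetto UI. | False |
| create_perfetto_trace | bool | Forwarded to jax.profiler.trace(); if True, JAX also writes a Perfetto-compatible trace file to the log directory. | False |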

Yields:

| Type | Description |
|---|---|
| Any | TraceReference with the trace directory and optional run ID. |