calibrax.profiling.compilation

XLA compilation profiling for JIT-compiled JAX functions. Measures compilation overhead, analyzes shape consistency, and estimates XLA optimization effectiveness.

JIT compilation profiler for JAX.

Analyzes JIT compilation efficiency, cache hit rates, XLA optimization effectiveness, and provides recommendations for compilation optimization.

CompilationResult(*, cache_hit_rate, total_calls, cache_hits, cache_misses, avg_compilation_time_ms, max_compilation_time_ms, unique_signatures, health_score, health_level, recommendations=()) dataclass

Result of compilation profiling analysis.

Attributes:

| Name | Type | Description |
|---|---|---|
| cache_hit_rate | float | Fraction of calls that hit the compilation cache. |
| total_calls | int | Total number of profiled function calls. |
| cache_hits | int | Number of cache hits. |
| cache_misses | int | Number of cache misses (each one triggers a compilation). |
| avg_compilation_time_ms | float | Average compilation time, in milliseconds. |
| max_compilation_time_ms | float | Maximum compilation time, in milliseconds. |
| unique_signatures | int | Number of unique function signatures compiled. |
| health_score | float | Overall compilation health score, from 0 to 1. |
| health_level | str | Human-readable health level. |
| recommendations | tuple[str, ...] | Optimization recommendations. |

to_dict()

Serialize to a JSON-compatible dictionary.

from_dict(data) classmethod

Deserialize from a dictionary.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Dictionary with compilation result fields. | required |

Returns:

| Type | Description |
|---|---|
| CompilationResult | Reconstructed CompilationResult instance. |
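Since `to_dict()` produces a JSON-compatible dictionary and `from_dict()` reconstructs an instance, the pair is presumably meant to round-trip through JSON. One detail any such implementation has to handle is the `recommendations` tuple: `json.dumps` accepts tuples, but `json.loads` returns lists, so the tuple has to be restored on the way back in. The sketch below shows that round trip on a hypothetical stand-in, `ResultSketch`, trimmed to three of the documented fields for brevity; calibrax's actual dictionary layout is not specified on this page.

```python
import json
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class ResultSketch:
    """Hypothetical stand-in using three of the CompilationResult fields."""

    cache_hit_rate: float
    total_calls: int
    recommendations: tuple[str, ...] = ()

    def to_dict(self) -> dict:
        # asdict() leaves the tuple as a tuple; emit a list so the output equals
        # what json.loads(json.dumps(...)) gives back.
        data = asdict(self)
        data["recommendations"] = list(self.recommendations)
        return data

    @classmethod
    def from_dict(cls, data: dict) -> "ResultSketch":
        # Restore the tuple type so the result compares equal to the original.
        fields = dict(data)
        fields["recommendations"] = tuple(fields.get("recommendations", ()))
        return cls(**fields)


original = ResultSketch(
    cache_hit_rate=0.75,
    total_calls=4,
    recommendations=("illustrative recommendation",),
)
restored = ResultSketch.from_dict(json.loads(json.dumps(original.to_dict())))
```

Without the tuple conversion in `from_dict`, `restored` would hold a list and would no longer compare equal to `original`.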

XLAOptimizationResult(*, optimization_score, fusion_ratio, arithmetic_ratio, memory_ratio, total_kernels, recommendations=()) dataclass

Result of XLA optimization effectiveness analysis.

Attributes:

| Name | Type | Description |
|---|---|---|
| optimization_score | float | Overall optimization score, from 0 to 1. |
| fusion_ratio | float | Fraction of fused kernels. |
| arithmetic_ratio | float | Fraction of arithmetic operations. |
| memory_ratio | float | Fraction of memory operations. |
| total_kernels | int | Total number of HLO kernels. |
| recommendations | tuple[str, ...] | Optimization recommendations. |
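The page does not say what the ratio denominators are or how a kernel is classified as fused, arithmetic, or memory-bound, and `optimization_score` combines the ratios in an undocumented way. The sketch below therefore covers only the ratio arithmetic, assuming each ratio is taken over `total_kernels` and that every kernel has already been given a category label. `kernel_ratios`, `KernelRatios`, and the label strings are hypothetical, not calibrax API.

```python
from collections import Counter
from typing import NamedTuple


class KernelRatios(NamedTuple):
    """Hypothetical container for the ratio fields of XLAOptimizationResult."""

    fusion_ratio: float
    arithmetic_ratio: float
    memory_ratio: float
    total_kernels: int


def kernel_ratios(kernel_kinds: list[str]) -> KernelRatios:
    """Compute category fractions from one (assumed) label per kernel.

    Labels other than "fusion", "arithmetic", and "memory" still count toward
    total_kernels but toward none of the ratios.
    """
    total = len(kernel_kinds)
    if total == 0:
        # Avoid division by zero when there are no kernels.
        return KernelRatios(0.0, 0.0, 0.0, 0)
    counts = Counter(kernel_kinds)
    return KernelRatios(
        fusion_ratio=counts["fusion"] / total,
        arithmetic_ratio=counts["arithmetic"] / total,
        memory_ratio=counts["memory"] / total,
        total_kernels=total,
    )


ratios = kernel_ratios(["fusion", "fusion", "arithmetic", "memory", "copy"])
# fusion_ratio=0.4, arithmetic_ratio=0.2, memory_ratio=0.2, total_kernels=5
```

Under this assumption the three ratios need not sum to 1, because kernels outside the three categories (the `"copy"` label here) only enlarge the denominator.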

to_dict()

Serialize to a JSON-compatible dictionary.

from_dict(data) classmethod

Deserialize from a dictionary.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data | dict[str, Any] | Dictionary with XLA optimization result fields. | required |

Returns:

| Type | Description |
|---|---|
| XLAOptimizationResult | Reconstructed XLAOptimizationResult instance. |

CompilationProfiler()

Analyzes JAX JIT compilation performance and optimization.

Instruments JIT-compiled functions to track compilation cache hits/misses, compilation times, and input shape consistency. Use profile_jit_compilation to wrap a function, then call get_result() for aggregated analysis.

The constructor takes no arguments and initializes the profiler with empty tracking state.

profile_jit_compilation(func)

Create an instrumented wrapper that profiles JIT compilation.

The returned callable tracks cache hits/misses, compilation times, and input shape patterns. Results accumulate in this profiler instance.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| func | Callable[..., Any] | JAX function to instrument. | required |

Returns:

| Type | Description |
|---|---|
| Callable[..., Any] | Instrumented function with identical signature. |
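One way to approximate the cache behaviour described above: `jax.jit` reuses a compiled executable when a call's abstract signature (argument shapes and dtypes, plus any static arguments) matches an earlier call, and traces and compiles again otherwise. The sketch below imitates that by keying calls on `(shape, dtype)` for array-like arguments and on the Python type for everything else, counting a new key as a miss and timing that call. `SignatureProfiler` and `_signature` are hypothetical names; this is not calibrax's documented mechanism, it ignores pytrees and static arguments, and the timed call includes the first execution. With real JAX outputs you would also call `jax.block_until_ready` on the result before stopping the timer, because dispatch is asynchronous.

```python
import functools
import time
from typing import Any, Callable


def _signature(args: tuple, kwargs: dict) -> tuple:
    """Hypothetical cache key: (shape, dtype) for array-likes, type name otherwise."""

    def describe(value: Any) -> tuple:
        shape = getattr(value, "shape", None)
        if shape is not None:
            return ("array", tuple(shape), str(getattr(value, "dtype", None)))
        return ("python", type(value).__name__)

    positional = tuple(describe(a) for a in args)
    keyword = tuple(sorted((name, describe(v)) for name, v in kwargs.items()))
    return positional, keyword


class SignatureProfiler:
    """Hypothetical sketch: a previously unseen signature counts as a cache miss."""

    def __init__(self) -> None:
        self.hits = 0
        self.misses = 0
        self.compile_times_ms: list[float] = []
        self.seen: set[tuple] = set()

    def wrap(self, func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)  # copies __name__/__doc__ and sets __wrapped__
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            key = _signature(args, kwargs)
            if key in self.seen:
                self.hits += 1
                return func(*args, **kwargs)
            # First call with this signature: under jax.jit this is where tracing
            # and compilation happen, so time it (execution time is included too).
            self.misses += 1
            self.seen.add(key)
            start = time.perf_counter()
            result = func(*args, **kwargs)
            self.compile_times_ms.append((time.perf_counter() - start) * 1e3)
            return result

        return wrapper


profiler = SignatureProfiler()
square = profiler.wrap(lambda x: x * x)
square(2.0)  # new signature ("python", "float"): miss
square(3.0)  # same signature: hit
square(4)    # int is a different signature: miss
```

`functools.wraps` sets `__wrapped__`, which `inspect.signature` follows, so introspection of the wrapper reports the original function's signature.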

get_result()

Get aggregated compilation profiling results.

Returns:

| Type | Description |
|---|---|
| CompilationResult | CompilationResult with cache statistics, timing, and recommendations. |
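The count and timing fields of `CompilationResult` can be derived from per-call records. The sketch below assumes a hypothetical `CallEvent` record (signature, hit or miss, and compile time for misses) and shows one plausible aggregation; calibrax's internal bookkeeping is not documented here. `health_score`, `health_level`, and `recommendations` depend on thresholds this page does not give, so the sketch leaves them out.

```python
from dataclasses import dataclass
from statistics import fmean


@dataclass(frozen=True)
class CallEvent:
    """Hypothetical per-call record; not a calibrax type."""

    signature: tuple
    cache_hit: bool
    compile_time_ms: float = 0.0  # only meaningful when cache_hit is False


def aggregate(events: list[CallEvent]) -> dict[str, float]:
    """Derive CompilationResult-style counts and timings from call records."""
    total = len(events)
    hits = sum(e.cache_hit for e in events)
    # Compilation time is only recorded for misses, so average over those alone.
    compile_times = [e.compile_time_ms for e in events if not e.cache_hit]
    return {
        "total_calls": total,
        "cache_hits": hits,
        "cache_misses": total - hits,
        "cache_hit_rate": hits / total if total else 0.0,
        "avg_compilation_time_ms": fmean(compile_times) if compile_times else 0.0,
        "max_compilation_time_ms": max(compile_times, default=0.0),
        "unique_signatures": len({e.signature for e in events}),
    }


events = [
    CallEvent(signature=("f32", (8,)), cache_hit=False, compile_time_ms=100.0),
    CallEvent(signature=("f32", (8,)), cache_hit=True),
    CallEvent(signature=("f32", (16,)), cache_hit=False, compile_time_ms=50.0),
    CallEvent(signature=("f32", (8,)), cache_hit=True),
]
summary = aggregate(events)
# total_calls=4, cache_hits=2, cache_misses=2, cache_hit_rate=0.5,
# avg_compilation_time_ms=75.0, max_compilation_time_ms=100.0, unique_signatures=2
```

Averaging only over misses keeps cache hits, which do not compile, from dragging the average compilation time toward zero.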

estimate_xla_optimization(func, *sample_args)

Estimate XLA optimization effectiveness by analyzing HLO text.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| func | Callable[..., Any] | JAX function to analyze. | required |
| *sample_args | Any | Example arguments used to lower and compile func. | () |

Returns:

| Type | Description |
|---|---|
| XLAOptimizationResult | XLAOptimizationResult with HLO analysis metrics. |
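With JAX's ahead-of-time API, the optimized HLO text for a function and its example arguments is available from `jax.jit(func).lower(*sample_args).compile().as_text()`. The page does not say which opcodes calibrax counts or how it weights them, so the following sketch only shows a parsing step: it counts instruction opcodes in HLO text, optionally restricted to the `ENTRY` computation, where each `fusion` instruction stands for one fused kernel. To stay self-contained it uses a hand-written, abbreviated HLO snippet rather than real compiler output; `count_opcodes` is a hypothetical helper, and its regular expression handles single-array and flat-tuple result shapes only.

```python
import re
from collections import Counter

# Matches "[ROOT] %name = <shape> opcode(" at the start of an HLO instruction line.
_INSTRUCTION = re.compile(
    r"^\s*(?:ROOT\s+)?%?[\w.\-]+\s*=\s*"  # optional ROOT marker, name, '='
    r"(?:\([^()]*\)|\S+)\s+"              # result shape: flat tuple or single array
    r"([a-z][\w\-]*)\("                   # opcode, immediately followed by '('
)


def count_opcodes(hlo_text: str, entry_only: bool = False) -> Counter:
    """Count HLO opcodes; with entry_only, skip fused sub-computations."""
    counts: Counter = Counter()
    in_entry = False
    for line in hlo_text.splitlines():
        stripped = line.strip()
        if stripped.startswith("ENTRY"):
            in_entry = True
            continue
        if stripped == "}":
            in_entry = False
            continue
        if entry_only and not in_entry:
            continue
        match = _INSTRUCTION.match(line)
        if match:
            counts[match.group(1)] += 1
    return counts


SAMPLE_HLO = """\
HloModule jit_f, entry_computation_layout={(f32[3]{0})->f32[3]{0}}

%fused_computation (param_0: f32[3]) -> f32[3] {
  %param_0 = f32[3]{0} parameter(0)
  %multiply.0 = f32[3]{0} multiply(f32[3]{0} %param_0, f32[3]{0} %param_0)
  ROOT %add.0 = f32[3]{0} add(f32[3]{0} %multiply.0, f32[3]{0} %param_0)
}

ENTRY %main (Arg_0.1: f32[3]) -> f32[3] {
  %Arg_0.1 = f32[3]{0} parameter(0)
  ROOT %fusion = f32[3]{0} fusion(f32[3]{0} %Arg_0.1), kind=kLoop, calls=%fused_computation
}
"""

all_ops = count_opcodes(SAMPLE_HLO)
entry_ops = count_opcodes(SAMPLE_HLO, entry_only=True)
```

Restricting the count to `ENTRY` roughly approximates a kernel count: the `multiply` and `add` inside `%fused_computation` run as part of the single `fusion` kernel, so counting them separately would inflate the total.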

reset()

Reset all profiling state.