calibrax.profiling.compilation
XLA compilation profiling for JIT-compiled JAX functions. Measures compilation overhead, analyzes shape consistency, and estimates XLA optimization effectiveness.
JIT compilation profiler for JAX.
Analyzes JIT compilation efficiency, cache hit rates, and XLA optimization effectiveness, and provides recommendations for compilation optimization.
CompilationResult(*, cache_hit_rate, total_calls, cache_hits, cache_misses, avg_compilation_time_ms, max_compilation_time_ms, unique_signatures, health_score, health_level, recommendations=())
dataclass
Result of compilation profiling analysis.
Attributes:
| Name | Type | Description |
|---|---|---|
| `cache_hit_rate` | `float` | Fraction of calls that hit the compilation cache. |
| `total_calls` | `int` | Total number of profiled function calls. |
| `cache_hits` | `int` | Number of cache hits. |
| `cache_misses` | `int` | Number of cache misses (triggering compilation). |
| `avg_compilation_time_ms` | `float` | Average compilation time in milliseconds. |
| `max_compilation_time_ms` | `float` | Maximum compilation time in milliseconds. |
| `unique_signatures` | `int` | Number of unique function signatures compiled. |
| `health_score` | `float` | Overall compilation health score (0-1). |
| `health_level` | `str` | Human-readable health level. |
| `recommendations` | `tuple[str, ...]` | Optimization recommendations. |
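As a rough illustration of how the counting fields relate, the sketch below derives a hit rate from hit and miss counts. It assumes that `total_calls` equals `cache_hits + cache_misses` and that an empty profile reports a rate of `0.0`; neither is confirmed by this page, and the helper name `hit_rate` is hypothetical.

```python
def hit_rate(cache_hits: int, cache_misses: int) -> float:
    """Return the fraction of calls that hit the cache.

    Assumption: every profiled call is either a hit or a miss, so the
    total is their sum. Returning 0.0 for zero calls is also an assumption.
    """
    total_calls = cache_hits + cache_misses
    if total_calls == 0:
        return 0.0
    return cache_hits / total_calls


# Nine cached calls and one compilation out of ten profiled calls.
rate = hit_rate(cache_hits=9, cache_misses=1)
empty_rate = hit_rate(cache_hits=0, cache_misses=0)
```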
to_dict()
Serialize to a JSON-compatible dictionary.
from_dict(data)
classmethod
Deserialize from a dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `dict[str, Any]` | Dictionary with compilation result fields. | required |
Returns:
| Type | Description |
|---|---|
| `CompilationResult` | Reconstructed CompilationResult instance. |
XLAOptimizationResult(*, optimization_score, fusion_ratio, arithmetic_ratio, memory_ratio, total_kernels, recommendations=())
dataclass
Result of XLA optimization effectiveness analysis.
Attributes:
| Name | Type | Description |
|---|---|---|
| `optimization_score` | `float` | Overall optimization score (0-1). |
| `fusion_ratio` | `float` | Fraction of fused kernels. |
| `arithmetic_ratio` | `float` | Fraction of arithmetic operations. |
| `memory_ratio` | `float` | Fraction of memory operations. |
| `total_kernels` | `int` | Total number of HLO kernels. |
| `recommendations` | `tuple[str, ...]` | Optimization recommendations. |
to_dict()
Serialize to a JSON-compatible dictionary.
from_dict(data)
classmethod
Deserialize from a dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `dict[str, Any]` | Dictionary with XLA optimization result fields. | required |
Returns:
| Type | Description |
|---|---|
| `XLAOptimizationResult` | Reconstructed XLAOptimizationResult instance. |
CompilationProfiler()
Analyzes JAX JIT compilation performance and optimization.
Instruments JIT-compiled functions to track compilation cache hits/misses,
compilation times, and input shape consistency. Use profile_jit_compilation
to wrap a function, then call get_result() for aggregated analysis.
Initialize the compilation profiler with empty tracking state.
profile_jit_compilation(func)
Create an instrumented wrapper that profiles JIT compilation.
The returned callable tracks cache hits/misses, compilation times, and input shape patterns. Results accumulate in this profiler instance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `func` | `Callable[..., Any]` | JAX function to instrument. | required |
Returns:
| Type | Description |
|---|---|
| `Callable[..., Any]` | Instrumented function with identical signature. |
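To make the idea of tracking cache hits and misses by input shape concrete, here is a minimal conceptual sketch of an instrumented wrapper. It is not calibrax's implementation: `SignatureTracker` and `signature_of` are hypothetical names, and using each argument's shape and dtype as the cache key is an assumption, chosen because `jax.jit` retraces when those change. Plain objects with `shape` and `dtype` attributes stand in for arrays so the example runs without JAX.

```python
import time
from types import SimpleNamespace
from typing import Any, Callable


def signature_of(args: tuple[Any, ...]) -> tuple:
    """Build a hashable key from each argument's shape and dtype (assumed key)."""
    return tuple(
        (getattr(a, "shape", None), str(getattr(a, "dtype", type(a).__name__)))
        for a in args
    )


class SignatureTracker:
    """Hypothetical stand-in that counts first-seen vs. repeated signatures."""

    def __init__(self) -> None:
        self.seen: set[tuple] = set()
        self.hits = 0
        self.misses = 0
        self.first_call_ms: list[float] = []

    def wrap(self, func: Callable[..., Any]) -> Callable[..., Any]:
        def wrapper(*args: Any) -> Any:
            key = signature_of(args)
            start = time.perf_counter()
            out = func(*args)
            elapsed_ms = (time.perf_counter() - start) * 1000.0
            if key in self.seen:
                self.hits += 1
            else:
                # A new shape/dtype combination would make jax.jit recompile,
                # so the first call per signature is treated as a miss.
                self.seen.add(key)
                self.misses += 1
                self.first_call_ms.append(elapsed_ms)
            return out

        return wrapper


tracker = SignatureTracker()
identity = tracker.wrap(lambda x: x)

small = SimpleNamespace(shape=(2, 3), dtype="float32")
large = SimpleNamespace(shape=(4, 3), dtype="float32")

identity(small)  # first (2, 3) call: miss
identity(small)  # same signature: hit
identity(large)  # new shape: miss
```

Timing only the first call per signature is a simple proxy for compilation cost, since that call includes tracing and compilation while later calls reuse the cached executable.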
get_result()
Get aggregated compilation profiling results.
Returns:
| Type | Description |
|---|---|
| `CompilationResult` | CompilationResult with cache statistics, timing, and recommendations. |
estimate_xla_optimization(func, *sample_args)
Estimate XLA optimization effectiveness by analyzing HLO text.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `func` | `Callable[..., Any]` | JAX function to analyze. | required |
| `*sample_args` | `Any` | Example arguments for lowering/compiling. | `()` |
Returns:
| Type | Description |
|---|---|
| `XLAOptimizationResult` | XLAOptimizationResult with HLO analysis metrics. |
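The sketch below illustrates what analyzing HLO text could look like: it extracts opcodes with a regular expression and reports fusion, arithmetic, and memory ratios. Everything here is an assumption for illustration, including the function name `summarize_hlo`, the opcode groupings, the regex, and counting instructions rather than kernels; calibrax's actual heuristics may differ. The input is a hand-written HLO-style snippet, not real compiler output. With JAX installed, comparable text can be obtained from `jax.jit(func).lower(*sample_args).compile().as_text()`.

```python
import re
from collections import Counter

# Assumed opcode groupings for illustration only.
ARITHMETIC_OPS = {"add", "subtract", "multiply", "divide", "dot", "convolution"}
MEMORY_OPS = {"copy", "transpose", "reshape", "broadcast", "slice", "concatenate"}

# Matches "= <result type> <opcode>(" in an HLO instruction line.
OPCODE_RE = re.compile(r"=\s+\S+\s+([a-z\-]+)\(")


def summarize_hlo(hlo_text: str) -> dict[str, float]:
    """Count opcodes in HLO text and return ratios over all instructions."""
    counts = Counter(OPCODE_RE.findall(hlo_text))
    total = sum(counts.values())
    if total == 0:
        return {
            "total_instructions": 0,
            "fusion_ratio": 0.0,
            "arithmetic_ratio": 0.0,
            "memory_ratio": 0.0,
        }
    arithmetic = sum(n for op, n in counts.items() if op in ARITHMETIC_OPS)
    memory = sum(n for op, n in counts.items() if op in MEMORY_OPS)
    return {
        "total_instructions": total,
        "fusion_ratio": counts["fusion"] / total,
        "arithmetic_ratio": arithmetic / total,
        "memory_ratio": memory / total,
    }


# Hand-written, HLO-style sample for illustration (not real compiler output).
SAMPLE_HLO = """
  %fusion.1 = f32[8]{0} fusion(f32[8]{0} %p0), kind=kLoop, calls=%fused_computation
  %dot.2 = f32[8,8]{1,0} dot(f32[8,8]{1,0} %p1, f32[8,8]{1,0} %p2)
  %copy.3 = f32[8]{0} copy(f32[8]{0} %fusion.1)
  %add.4 = f32[8]{0} add(f32[8]{0} %copy.3, f32[8]{0} %p0)
"""

summary = summarize_hlo(SAMPLE_HLO)
```

A text-based pass like this is cheap and needs no device, but it only sees opcode names, so it cannot tell how much work each fused kernel actually does.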
reset()
Reset all profiling state.