calibrax.profiling.flops
FLOP counting via jaxpr analysis. FlopsCounter.count() traces a JAX function to its jaxpr intermediate representation and estimates its floating-point operations (FLOPs), broken down by primitive operation type.
FlopsResult(*, total_flops, flops_by_operation, num_operations, function_name)
dataclass
Result of FLOP counting for a function.
Attributes:
| Name | Type | Description |
|---|---|---|
| total_flops | int | Total estimated FLOPs. |
| flops_by_operation | dict[str, int] | Breakdown by primitive operation name. |
| num_operations | int | Number of JAX primitives in the trace. |
| function_name | str | Name of the analyzed function. |
FlopsCounter
Count FLOPs of JAX functions via jaxpr analysis.
Uses jax.make_jaxpr to trace the function and counts FLOPs
for each primitive based on operation-specific rules.
For NNX models that use stochastic operations (dropout, etc.),
use flax.nnx.tabulate(model, *args, compute_flops=True)
instead, because it handles NNX state management internally.
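
The jaxpr-based approach described above can be sketched in a few lines of plain JAX. This is a minimal illustration, not FlopsCounter's actual implementation: the function estimate_flops and its two counting rules (2 FLOPs per multiply-add in dot_general, 1 FLOP per output element otherwise) are assumptions, and the sketch does not recurse into nested jaxprs such as those produced by jit or scan.

```python
import math

import jax
import jax.numpy as jnp
from jax import lax


def estimate_flops(fn, *args):
    """Trace fn and return (total, per-primitive breakdown) under assumed rules."""
    closed = jax.make_jaxpr(fn)(*args)
    breakdown = {}
    for eqn in closed.jaxpr.eqns:
        name = eqn.primitive.name
        out_elems = math.prod(eqn.outvars[0].aval.shape)
        if name == "dot_general":
            # Assumed rule: one multiply and one add per contracted element.
            (lhs_contract, _), _ = eqn.params["dimension_numbers"]
            lhs_shape = eqn.invars[0].aval.shape
            contract = math.prod(lhs_shape[d] for d in lhs_contract)
            flops = 2 * out_elems * contract
        else:
            # Assumed rule: one FLOP per output element for everything else.
            flops = out_elems
        breakdown[name] = breakdown.get(name, 0) + flops
    return sum(breakdown.values()), breakdown


def dense_tanh(x, w):
    # lax primitives are used directly so the trace contains no nested calls.
    return lax.tanh(lax.dot(x, w))


total, breakdown = estimate_flops(dense_tanh, jnp.ones((4, 8)), jnp.ones((8, 16)))
```

Under these assumed rules, the (4, 8) by (8, 16) matrix product contributes 2 × 4 × 16 × 8 = 1024 FLOPs and the tanh contributes 64, one per output element. A real counter would likely use more refined per-primitive costs.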
count(fn, *args, static_argnums=())
Count FLOPs for a function with given example arguments.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable[..., Any] | JAX function to analyze. | required |
| *args | Any | Example arguments for tracing. | () |
| static_argnums | tuple[int, ...] | Argument indices to treat as static. | () |
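
Why an argument might need to be static can be shown with jax.make_jaxpr directly, which the class description names as the tracing mechanism. The sketch below assumes that count() passes static_argnums through to the tracer in the same way; the function repeated_sin is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


def repeated_sin(x, n):
    # n drives a Python loop, so it must be a concrete int at trace time.
    for _ in range(n):
        x = lax.sin(x)
    return x


# Mark argument 1 (n) as static so the loop unrolls during tracing.
closed = jax.make_jaxpr(repeated_sin, static_argnums=(1,))(jnp.ones(3), 3)
primitive_names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
```

With n treated as static, the trace contains one sin equation per loop iteration, so the operation count depends on the value passed for n. Without static_argnums, n would be traced as an abstract value and range(n) would fail.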
Returns:
| Type | Description |
|---|---|
| FlopsResult | FlopsResult with FLOP count and breakdown. |