
calibrax.profiling.flops

FLOP counting via jaxpr analysis. FlopsCounter.count() traces a JAX function and counts its floating-point operations, broken down by operation type.

The module provides FlopsCounter, which estimates the FLOPs of a JAX function by analyzing its jaxpr, JAX's intermediate representation of the traced program.
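The sketch below makes "analyzing the jaxpr" concrete. It traces a small function with `jax.make_jaxpr` and lists every primitive equation together with its output shapes. These equations are the units that a per-primitive FLOP rule works on. The sketch uses only standard JAX tracing structures. `dense_tanh` and `list_primitives` are hypothetical names for illustration and are not part of calibrax.

```python
import jax
import jax.numpy as jnp


def dense_tanh(x, w):
    # Illustrative function: a matrix multiply followed by an elementwise tanh.
    return jnp.tanh(x @ w)


def list_primitives(jaxpr):
    """Return (primitive name, output shapes) for every equation in a Jaxpr,
    recursing into nested jaxprs (e.g. the body of an inner jit call)."""
    rows = []
    for eqn in jaxpr.eqns:
        rows.append((eqn.primitive.name, [v.aval.shape for v in eqn.outvars]))
        for param in eqn.params.values():
            sub = getattr(param, "jaxpr", param)  # unwrap ClosedJaxpr -> Jaxpr
            if hasattr(sub, "eqns"):
                rows.extend(list_primitives(sub))
    return rows


x = jnp.ones((2, 3), dtype=jnp.float32)
w = jnp.ones((3, 4), dtype=jnp.float32)
closed = jax.make_jaxpr(dense_tanh)(x, w)  # ClosedJaxpr: the traced program
print(closed)
for name, shapes in list_primitives(closed.jaxpr):
    print(name, shapes)
```

The printed list should include a `dot_general` equation for the matmul and a `tanh` equation, both producing a (2, 4) output.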

FlopsResult(*, total_flops, flops_by_operation, num_operations, function_name) dataclass

Result of FLOP counting for a function.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| total_flops | int | Total estimated FLOPs. |
| flops_by_operation | dict[str, int] | Breakdown by primitive operation name. |
| num_operations | int | Number of JAX primitives in the trace. |
| function_name | str | Name of the analyzed function. |

FlopsCounter

Count FLOPs of JAX functions via jaxpr analysis.

Uses jax.make_jaxpr to trace the function and counts FLOPs for each primitive based on operation-specific rules.
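The approach described above is to trace with `jax.make_jaxpr` and then apply a rule for each primitive. It can be sketched in a few dozen lines. The rules here are assumptions chosen for illustration, and calibrax's actual rule table is not documented on this page and may differ:

- `dot_general` costs 2·M·N·K.
- Common elementwise primitives cost one FLOP per output element.
- Everything else costs zero, including data movement such as `broadcast_in_dim`.

`count_flops_sketch`, `eqn_flops`, and `walk` are hypothetical names.

```python
import math
from collections import Counter

import jax
import jax.numpy as jnp

# Assumed rule: one FLOP per output element for these elementwise primitives.
ELEMENTWISE = {"add", "sub", "mul", "div", "max", "min", "neg", "exp", "log", "tanh", "logistic"}


def eqn_flops(eqn) -> int:
    """Estimate FLOPs for a single jaxpr equation under the assumed rules."""
    name = eqn.primitive.name
    out_size = sum(math.prod(v.aval.shape) for v in eqn.outvars)
    if name == "dot_general":
        # dimension_numbers = ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
        (lhs_contract, _), _ = eqn.params["dimension_numbers"]
        lhs_shape = eqn.invars[0].aval.shape
        k = math.prod(lhs_shape[d] for d in lhs_contract)
        return 2 * out_size * k  # one multiply and one add per contracted element
    if name in ELEMENTWISE:
        return out_size
    return 0  # broadcasts, reshapes, conversions, call wrappers: counted as free


def walk(jaxpr, by_op: Counter) -> int:
    """Accumulate FLOPs per primitive into by_op; return the number of equations seen.
    Nested jaxprs (e.g. inner jit calls) are visited; loop bodies are counted once,
    and tuple-valued params such as cond branches are not visited."""
    n = 0
    for eqn in jaxpr.eqns:
        n += 1
        by_op[eqn.primitive.name] += eqn_flops(eqn)
        for param in eqn.params.values():
            sub = getattr(param, "jaxpr", param)  # unwrap ClosedJaxpr -> Jaxpr
            if hasattr(sub, "eqns"):
                n += walk(sub, by_op)
    return n


def count_flops_sketch(fn, *args):
    """Hypothetical helper: returns (total FLOPs, per-primitive FLOPs, equation count)."""
    closed = jax.make_jaxpr(fn)(*args)
    by_op = Counter()
    n = walk(closed.jaxpr, by_op)
    by_op = {op: f for op, f in by_op.items() if f > 0}  # drop zero-cost primitives
    return sum(by_op.values()), by_op, n


x = jnp.ones((2, 3))
w = jnp.ones((3, 4))
b = jnp.ones((4,))
total, by_op, n = count_flops_sketch(lambda x, w, b: jnp.tanh(x @ w + b), x, w, b)
print(total, by_op, n)
```

Counting a matmul as 2·M·N·K follows the common convention that a multiply and an add are separate FLOPs. Tools that count a fused multiply-add as one operation report half that figure.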

For NNX models that use stochastic operations such as dropout, use flax.nnx.tabulate(model, *args, compute_flops=True) instead; it handles NNX state management internally.

count(fn, *args, static_argnums=())

Count FLOPs for a function with given example arguments.

Parameters:

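| Name | Type | Description | Default |
| --- | --- | --- | --- |
| fn | Callable[..., Any] | JAX function to analyze. | required |
| *args | Any | Example arguments used to trace fn; their shapes and dtypes determine the traced program. | () |
| static_argnums | tuple[int, ...] | Indices of arguments to treat as static (not traced), such as Python values that control loops or shapes. | () |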

Returns:

| Type | Description |
| --- | --- |
| FlopsResult | FlopsResult with FLOP count and breakdown. |
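Some arguments are Python values that shape the computation rather than arrays. Such an argument cannot be traced abstractly, so it has to be marked static. `jax.make_jaxpr` accepts a `static_argnums` parameter of the same name, and count() presumably forwards its own to it, though that forwarding is an assumption here. The sketch shows the effect directly with `jax.make_jaxpr`. `mlp` and `count_primitive` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def mlp(x, w, num_layers):
    # num_layers is a plain Python int driving a Python loop, so it must be static.
    for _ in range(num_layers):
        x = jnp.tanh(x @ w)
    return x


def count_primitive(jaxpr, name):
    """Count equations with the given primitive name, including inside nested jaxprs."""
    total = 0
    for eqn in jaxpr.eqns:
        total += int(eqn.primitive.name == name)
        for param in eqn.params.values():
            sub = getattr(param, "jaxpr", param)  # unwrap ClosedJaxpr -> Jaxpr
            if hasattr(sub, "eqns"):
                total += count_primitive(sub, name)
    return total


x = jnp.ones((2, 3))
w = jnp.ones((3, 3))

# Without static_argnums, num_layers becomes an abstract tracer and range() fails.
try:
    jax.make_jaxpr(mlp)(x, w, 3)
except Exception as err:
    print("tracing failed:", type(err).__name__)

# Marking argument 2 as static bakes its value into the trace.
for layers in (1, 3):
    closed = jax.make_jaxpr(mlp, static_argnums=(2,))(x, w, layers)
    print(layers, "layers ->", count_primitive(closed.jaxpr, "dot_general"), "dot_general ops")
```

Each static value yields a different trace, so the FLOP count reflects the specific value passed in the example arguments.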