calibrax.profiling.complexity
Model complexity analysis for Flax NNX modules. Examines parameter counts, memory requirements, computational cost, and scaling properties of neural network architectures.
Provides parameter counts, memory usage estimates, computational complexity analysis, and scaling characteristics for any NNX module.
ComplexityResult(*, total_parameters, parameter_memory_mb, largest_layer_name, largest_layer_params, input_shape, estimated_memory_mb, total_estimated_operations, dominant_complexity, scaling_characteristics=dict())
dataclass
Result of model complexity analysis.
Attributes:
| Name | Type | Description |
|---|---|---|
| `total_parameters` | `int` | Total number of trainable parameters. |
| `parameter_memory_mb` | `float` | Memory consumed by parameters, in MB, assuming float32 storage. |
| `largest_layer_name` | `str` | Name of the layer with the most parameters. |
| `largest_layer_params` | `int` | Parameter count of the largest layer. |
| `input_shape` | `tuple[int, ...]` | Shape of the analyzed input. |
| `estimated_memory_mb` | `float` | Estimated total memory in MB (parameters plus activations). |
| `total_estimated_operations` | `int` | Estimated total operation count. |
| `dominant_complexity` | `str` | Name of the dominant operation type. |
| `scaling_characteristics` | `dict[str, str]` | Mapping of operation type to complexity class. |
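Since `parameter_memory_mb` is documented as float32 memory, it should follow directly from `total_parameters` at 4 bytes per value. The sketch below shows that arithmetic for a single 784-to-128 dense layer. The helper name `float32_param_memory_mb` is hypothetical, and the divisor is an assumption: the field name says MB but does not state whether that means 1024**2 bytes (MiB) or 10**6 bytes, so the sketch exposes it as a parameter.

```python
# Hypothetical helper, not part of calibrax: reproduces the float32
# parameter-memory arithmetic described for ``parameter_memory_mb``.
FLOAT32_BYTES = 4  # one float32 value occupies 4 bytes


def float32_param_memory_mb(total_parameters: int, bytes_per_mb: int = 1024**2) -> float:
    """Return parameter memory in MB, assuming float32 storage.

    ``bytes_per_mb`` defaults to 1024**2 (MiB). This is an assumption,
    since the documented field does not state which convention is used.
    """
    return total_parameters * FLOAT32_BYTES / bytes_per_mb


# A dense layer mapping 784 -> 128 features has
# 784 * 128 weights + 128 biases = 100,480 parameters.
dense_params = 784 * 128 + 128
print(dense_params)  # 100480
print(round(float32_param_memory_mb(dense_params), 3))  # 0.383
```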
to_dict()
Serialize to a JSON-compatible dictionary.
from_dict(data)
classmethod
Deserialize from a dictionary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `dict[str, Any]` | Dictionary with complexity result fields. | *required* |
Returns:
| Type | Description |
|---|---|
| `ComplexityResult` | Reconstructed `ComplexityResult` instance. |
analyze_complexity(model, input_shape)
Analyze complexity of a Flax NNX module.
Performs parameter counting, memory estimation, computational complexity analysis, and scaling characterization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | Flax NNX model to analyze. | *required* |
| `input_shape` | `tuple[int, ...]` | Shape of input data (including batch dimension). | *required* |
Returns:
| Type | Description |
|---|---|
| `ComplexityResult` | `ComplexityResult` with detailed complexity metrics. |
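The parameter-counting step of `analyze_complexity` can be illustrated without Flax. The sketch below walks a nested mapping of layer names to parameter shapes (a hypothetical stand-in for an NNX module's parameter state), totals the element counts, and picks the layer with the most parameters, which correspond to the quantities reported as `total_parameters`, `largest_layer_name`, and `largest_layer_params`. The helper `count_parameters`, the dotted layer naming, and the `"<root>"` label are assumptions, not calibrax's implementation. With a real `nnx.Module`, parameters would more typically be gathered with `nnx.state(model, nnx.Param)`.

```python
from math import prod
from typing import Any


def count_parameters(tree: dict[str, Any], prefix: str = "") -> dict[str, int]:
    """Return a {layer_name: parameter_count} mapping for a nested shape tree.

    Shape tuples directly under a dict are counted toward that dict's layer;
    nested dicts are walked recursively and named with dotted paths
    (a hypothetical naming scheme).
    """
    counts: dict[str, int] = {}
    leaf_total = 0
    for name, value in tree.items():
        if isinstance(value, dict):
            path = f"{prefix}.{name}" if prefix else name
            counts.update(count_parameters(value, path))
        else:
            leaf_total += prod(value)  # number of elements in this parameter
    if leaf_total:
        counts[prefix or "<root>"] = leaf_total
    return counts


# Hypothetical two-layer MLP: 784 -> 128 -> 10.
mlp_shapes = {
    "linear1": {"kernel": (784, 128), "bias": (128,)},
    "linear2": {"kernel": (128, 10), "bias": (10,)},
}
per_layer = count_parameters(mlp_shapes)
total_parameters = sum(per_layer.values())
largest_layer_name, largest_layer_params = max(per_layer.items(), key=lambda kv: kv[1])
print(per_layer)  # {'linear1': 100480, 'linear2': 1290}
print(total_parameters)  # 101770
print(largest_layer_name, largest_layer_params)  # linear1 100480
```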