Writing Adapters
Adapters wrap external objects — typically ML models — so they can be used with Calibrax's benchmarking and profiling tools. Calibrax provides two adapter base classes for different use cases.
When to Use Adapters
Use an adapter when you need to benchmark an object that does not conform to
Calibrax's protocols. For example, a Flax NNX model has a __call__ method
but not the setup() / run_training() / run_evaluation() interface that
BenchmarkProtocol expects.
BenchmarkAdapter (ABC)
BenchmarkAdapter is an abstract base class for wrapping non-JAX targets such
as PyTorch models or plain Python objects:
from calibrax.core.adapters import BenchmarkAdapter


class PyTorchAdapter(BenchmarkAdapter):
    @classmethod
    def can_adapt(cls, target: object) -> bool:
        # Return True if this adapter can wrap the given target
        return hasattr(target, "forward")

    @property
    def name(self) -> str:
        return type(self.target).__name__


# Usage
adapter = PyTorchAdapter(my_pytorch_model)
print(adapter.name)    # "MyModel"
print(adapter.target)  # the original model
Key points:
- Subclass BenchmarkAdapter and override can_adapt() to declare which objects your adapter supports
- Access the wrapped object via the target property
- The name property defaults to target.name, then target.model_name, falling back to "unknown"
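The name fallback described in the last point can be sketched in plain Python. This is only an illustration of the lookup order stated above, not Calibrax's implementation: `NameFallbackWrapper` is a hypothetical stand-in class, and the check that the attribute is a non-empty string is an assumption.

```python
from types import SimpleNamespace


class NameFallbackWrapper:
    """Hypothetical stand-in that mimics the documented name lookup order."""

    def __init__(self, target: object) -> None:
        self.target = target

    @property
    def name(self) -> str:
        # Try target.name first, then target.model_name.
        for attr in ("name", "model_name"):
            value = getattr(self.target, attr, None)
            # Assumption: only non-empty strings count as a usable name.
            if isinstance(value, str) and value:
                return value
        # Neither attribute is present, so fall back to the default.
        return "unknown"


with_name = NameFallbackWrapper(SimpleNamespace(name="resnet")).name
with_model_name = NameFallbackWrapper(SimpleNamespace(model_name="bert")).name
without_either = NameFallbackWrapper(object()).name
```

Overriding `name`, as `PyTorchAdapter` does above, replaces this lookup entirely, which is useful when the wrapped object carries no name attribute of its own.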
NNXBenchmarkAdapter
NNXBenchmarkAdapter inherits from nnx.Module (not BenchmarkAdapter),
making it compatible with Flax NNX's lifted JAX transformations such as
nnx.jit, nnx.vmap, and nnx.grad:
from flax import nnx
from calibrax.core.adapters import NNXBenchmarkAdapter
model = nnx.Linear(128, 64, rngs=nnx.Rngs(0))
adapter = NNXBenchmarkAdapter(model)
print(adapter.name) # "unknown" (nnx.Linear has no .name attribute)
Using nnx.jit
NNXBenchmarkAdapter is intentionally minimal — subclasses add
domain-specific methods like predict(). Because nnx.jit does not
support bound methods, use the unbound method pattern:
NNXBenchmarkAdapter.can_adapt() returns True for any nnx.Module instance.
AdapterRegistry
The AdapterRegistry manages adapter resolution. When you call adapt(),
it tries each registered adapter in priority order (most recently registered
first) and returns the first one whose can_adapt() returns True:
from calibrax.core.adapters import adapt, register_adapter
# Register a custom adapter
register_adapter(PyTorchAdapter)
# adapt() tries adapters in reverse registration order
wrapped = adapt(my_model)
The default registry pre-registers NNXBenchmarkAdapter, so NNX models are
adapted automatically:
from calibrax.core.adapters import adapt
wrapped = adapt(my_nnx_model) # returns NNXBenchmarkAdapter
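The resolution order described in this section (most recently registered adapter first, first match wins) can be illustrated with a small stand-in registry. Everything here is hypothetical: `MiniRegistry`, `AnyObjectAdapter`, and `ForwardAdapter` are not Calibrax classes, and raising `TypeError` when nothing matches is an assumption, since the behavior for unmatched targets is not stated above.

```python
class MiniRegistry:
    """Hypothetical stand-in for AdapterRegistry, not the Calibrax code."""

    def __init__(self) -> None:
        self._adapters: list[type] = []

    def register(self, adapter_cls: type) -> type:
        # Newest registrations go to the front so they are tried first.
        self._adapters.insert(0, adapter_cls)
        return adapter_cls

    def adapt(self, target: object):
        for adapter_cls in self._adapters:
            if adapter_cls.can_adapt(target):
                return adapter_cls(target)
        # Assumption: an unmatched target is reported as an error.
        raise TypeError(f"no registered adapter for {type(target).__name__}")


class AnyObjectAdapter:
    def __init__(self, target: object) -> None:
        self.target = target

    @classmethod
    def can_adapt(cls, target: object) -> bool:
        return True  # deliberately broad catch-all


class ForwardAdapter(AnyObjectAdapter):
    @classmethod
    def can_adapt(cls, target: object) -> bool:
        return hasattr(target, "forward")


class Net:
    def forward(self, x):
        return x


registry = MiniRegistry()
registry.register(AnyObjectAdapter)  # registered first, tried last
registry.register(ForwardAdapter)    # registered last, tried first

picked_for_net = type(registry.adapt(Net())).__name__
picked_for_int = type(registry.adapt(42)).__name__
```

Because later registrations take priority, a specific adapter registered after a catch-all still wins for the targets it recognizes, while everything else falls through to the catch-all.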
Manual Registry
For isolated testing or custom resolution logic:
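from calibrax.core.adapters import AdapterRegistry, NNXBenchmarkAdapter

# PyTorchAdapter is the adapter class defined in the BenchmarkAdapter section
registry = AdapterRegistry()
registry.register(PyTorchAdapter)
registry.register(NNXBenchmarkAdapter)

wrapped = registry.adapt(model)  # model is any object one of the adapters accepts
registry.reset()  # clear all registrations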
Custom Adapter Example
A complete adapter for scikit-learn estimators:
from calibrax.core.adapters import BenchmarkAdapter, register_adapter


class SklearnAdapter(BenchmarkAdapter):
    """Adapter for scikit-learn estimators."""

    @classmethod
    def can_adapt(cls, target: object) -> bool:
        return hasattr(target, "fit") and hasattr(target, "predict")

    @property
    def name(self) -> str:
        return type(self.target).__name__

    def predict(self, x):
        return self.target.predict(x)

    def fit(self, x, y):
        return self.target.fit(x, y)


# Register globally
register_adapter(SklearnAdapter)
Best Practices
- Use NNXBenchmarkAdapter for Flax NNX models, since it preserves JIT compatibility
- Use BenchmarkAdapter for everything else: PyTorch, scikit-learn, custom objects
- Register adapters early (e.g., at module load time) so adapt() can resolve them when needed
- Override can_adapt() with precise checks to avoid false matches
- Keep adapters thin: delegate to the wrapped object rather than reimplementing logic
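To make the advice about precise can_adapt() checks concrete, the sketch below compares a duck-typed check with a type-based one. Both functions and the `Robot` class are hypothetical illustrations, not part of Calibrax; the strict variant assumes the intended targets are PyTorch modules and imports torch lazily so it also runs where PyTorch is absent.

```python
def loose_can_adapt(target: object) -> bool:
    # Matches anything with a "forward" attribute, including unrelated objects.
    return hasattr(target, "forward")


def strict_can_adapt(target: object) -> bool:
    # Import lazily so the check degrades to False when torch is not installed.
    try:
        import torch
    except ImportError:
        return False
    return isinstance(target, torch.nn.Module)


class Robot:
    """Unrelated class that happens to define forward()."""

    def forward(self, steps: int) -> int:
        return steps


robot = Robot()
loose_match = loose_can_adapt(robot)    # True: a false match
strict_match = strict_can_adapt(robot)  # False: Robot is not a torch module
```

The body of either function could serve as the body of a can_adapt() classmethod. The type-based version matters most for adapters registered late, since those are tried first and a loose check would intercept targets meant for other adapters.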