calibrax.metrics.learning
Tier 3 metric learning losses -- differentiable distance functions that
learn embedding spaces via backpropagation. All losses return JAX arrays
compatible with jax.grad for gradient flow. Includes contrastive,
triplet margin, NT-Xent, ArcFace, CosFace, proxy-NCA, and proxy-anchor
losses, plus hard-negative and semi-hard mining strategies.
MetricLearningLoss(*, reduction: str = 'mean')
Abstract base for metric learning losses.
Subclasses implement _compute_loss to return per-element losses, which are then reduced to a scalar via the Reducer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| reduction | str | Loss reduction strategy ("mean" or "sum"). | 'mean' |
Examples:
>>> class MyLoss(MetricLearningLoss):
...     def _compute_loss(self, embeddings, labels):
...         return jnp.zeros(embeddings.shape[0])
>>> loss_fn = MyLoss()
>>> loss_fn(embeddings, labels)
Initialize metric learning loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| reduction | str | Reduction method ("mean" or "sum"). | 'mean' |
__call__(embeddings: jax.Array, labels: jax.Array, **kwargs: Any) -> jax.Array
Compute the metric learning loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| embeddings | Array | Batch of embedding vectors (batch_size, embedding_dim). | required |
| labels | Array | Integer class labels (batch_size,). | required |
| **kwargs | Any | Additional arguments for subclass losses. | {} |

Returns:
| Type | Description |
|---|---|
| Array | Scalar loss value as a differentiable JAX array. |
Reducer(reduction: str = 'mean')
Reduction strategy for per-element losses.
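As a rough illustration of what such a reduction does, the sketch below collapses a vector of per-element losses to a scalar. The helper name `reduce_losses` and the `ValueError` on unknown strings are assumptions made for this example and are not taken from the library.

```python
import jax.numpy as jnp


def reduce_losses(losses, reduction="mean"):
    """Collapse per-element losses to a scalar (hypothetical helper)."""
    if reduction == "mean":
        return jnp.mean(losses)
    if reduction == "sum":
        return jnp.sum(losses)
    # Assumption: unknown strategies are rejected rather than ignored.
    raise ValueError(f"unknown reduction: {reduction!r}")


per_element = jnp.array([0.5, 1.5, 1.0])
mean_loss = reduce_losses(per_element, "mean")
sum_loss = reduce_losses(per_element, "sum")
```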
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| reduction | str | One of "mean" or "sum". | 'mean' |
Initialize reducer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| reduction | str | Reduction method ("mean" or "sum"). | 'mean' |
__call__(losses: jax.Array) -> jax.Array
Apply reduction to per-element losses.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| losses | Array | Array of per-element loss values. | required |

Returns:
| Type | Description |
|---|---|
| Array | Reduced scalar loss. |
ArcFaceLoss(num_classes: int, embedding_dim: int, *, margin: float = 0.5, scale: float = 64.0, rngs: nnx.Rngs)
Bases: Module
ArcFace loss with additive angular margin.
Adds an angular margin penalty to the target class logit in cosine space: cos(theta + m) for the target class, cos(theta) for other classes.
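The margin rule above can be sketched in plain JAX. This is an illustration under stated assumptions, not the library's implementation: `arcface_logits` and `arcface_loss` are hypothetical names, the class weight matrix is passed in explicitly instead of being held as a module parameter, and the final loss is assumed to be softmax cross-entropy over the scaled logits.

```python
import jax
import jax.numpy as jnp


def arcface_logits(embeddings, weights, labels, margin=0.5, scale=64.0):
    # L2-normalise both sides so the dot product equals cos(theta).
    e = embeddings / jnp.linalg.norm(embeddings, axis=-1, keepdims=True)
    w = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)
    # Clip away from +/-1 so arccos stays differentiable.
    cos = jnp.clip(e @ w.T, -1.0 + 1e-7, 1.0 - 1e-7)
    theta = jnp.arccos(cos)
    target = jax.nn.one_hot(labels, weights.shape[0], dtype=cos.dtype)
    # cos(theta + m) on the target column, cos(theta) everywhere else.
    return scale * jnp.where(target > 0, jnp.cos(theta + margin), cos)


def arcface_loss(embeddings, weights, labels, margin=0.5, scale=64.0):
    logits = arcface_logits(embeddings, weights, labels, margin, scale)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    picked = jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]
    return -jnp.mean(picked)


key_e, key_w = jax.random.split(jax.random.PRNGKey(0))
embeddings = jax.random.normal(key_e, (8, 16))
weights = jax.random.normal(key_w, (4, 16))  # stand-in for learned class weights
labels = jnp.array([0, 1, 2, 3, 0, 1, 2, 3])
loss = arcface_loss(embeddings, weights, labels)
```

Only the target-class column is altered by the margin; the other columns are plain scaled cosines.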
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_classes | int | Number of target classes. | required |
| embedding_dim | int | Dimensionality of input embeddings. | required |
| margin | float | Angular margin in radians. Defaults to 0.5. | 0.5 |
| scale | float | Logit scaling factor. Defaults to 64.0. | 64.0 |
| rngs | Rngs | RNG streams for parameter initialization. | required |
Examples:
>>> loss_fn = ArcFaceLoss(num_classes=10, embedding_dim=128, rngs=nnx.Rngs(0))
>>> loss = loss_fn(embeddings, labels)
Initialize ArcFace loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_classes | int | Number of classes. | required |
| embedding_dim | int | Embedding dimensionality. | required |
| margin | float | Angular margin in radians. | 0.5 |
| scale | float | Logit scale factor. | 64.0 |
| rngs | Rngs | RNG streams. | required |
__call__(embeddings: jax.Array, labels: jax.Array) -> jax.Array
Compute ArcFace loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| embeddings | Array | (batch_size, embedding_dim) normalized embeddings. | required |
| labels | Array | (batch_size,) integer class labels. | required |

Returns:
| Type | Description |
|---|---|
| Array | Scalar loss value. |
CosFaceLoss(num_classes: int, embedding_dim: int, *, margin: float = 0.35, scale: float = 64.0, rngs: nnx.Rngs)
Bases: Module
CosFace loss with additive cosine margin.
Subtracts a margin from the target class cosine similarity: cos(theta) - m for the target class.
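A minimal sketch of the cosine-margin rule in plain JAX follows. It assumes the loss is softmax cross-entropy over scaled logits and takes the class weight matrix as an explicit argument; `cosface_loss` is a hypothetical name for this example, not the library's API.

```python
import jax
import jax.numpy as jnp


def cosface_loss(embeddings, weights, labels, margin=0.35, scale=64.0):
    # L2-normalise so that e @ w.T is the cosine similarity cos(theta).
    e = embeddings / jnp.linalg.norm(embeddings, axis=-1, keepdims=True)
    w = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)
    cos = e @ w.T
    target = jax.nn.one_hot(labels, weights.shape[0], dtype=cos.dtype)
    # cos(theta) - m on the target column, cos(theta) elsewhere.
    logits = scale * (cos - margin * target)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(jnp.sum(log_probs * target, axis=-1))


key_e, key_w = jax.random.split(jax.random.PRNGKey(0))
embeddings = jax.random.normal(key_e, (8, 16))
weights = jax.random.normal(key_w, (4, 16))  # stand-in for learned class weights
labels = jnp.array([0, 1, 2, 3, 0, 1, 2, 3])
with_margin = cosface_loss(embeddings, weights, labels)
no_margin = cosface_loss(embeddings, weights, labels, margin=0.0)
```

Because the margin only lowers the target logit, the loss with a margin is never smaller than the loss without one on the same inputs.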
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_classes | int | Number of target classes. | required |
| embedding_dim | int | Dimensionality of input embeddings. | required |
| margin | float | Cosine margin. Defaults to 0.35. | 0.35 |
| scale | float | Logit scaling factor. Defaults to 64.0. | 64.0 |
| rngs | Rngs | RNG streams for parameter initialization. | required |
Examples:
>>> loss_fn = CosFaceLoss(num_classes=10, embedding_dim=128, rngs=nnx.Rngs(0))
>>> loss = loss_fn(embeddings, labels)
Initialize CosFace loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_classes | int | Number of classes. | required |
| embedding_dim | int | Embedding dimensionality. | required |
| margin | float | Cosine margin. | 0.35 |
| scale | float | Logit scale factor. | 64.0 |
| rngs | Rngs | RNG streams. | required |
__call__(embeddings: jax.Array, labels: jax.Array) -> jax.Array
Compute CosFace loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| embeddings | Array | (batch_size, embedding_dim) normalized embeddings. | required |
| labels | Array | (batch_size,) integer class labels. | required |

Returns:
| Type | Description |
|---|---|
| Array | Scalar loss value. |
ContrastiveLoss(margin: float = 1.0, *, reduction: str = 'mean')
Bases: MetricLearningLoss
Pair-based contrastive loss with margin.
L = y * d^2 + (1-y) * max(0, margin - d)^2
where y=1 for positive pairs (same class) and y=0 for negative pairs.
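To make the formula concrete, here is a sketch that evaluates it over every ordered pair in a batch and then differentiates it with `jax.grad`. The pairing scheme (all off-diagonal pairs, Euclidean distance, mean over pairs) and the name `contrastive_loss` are assumptions for illustration; the library may form pairs differently.

```python
import jax
import jax.numpy as jnp


def contrastive_loss(embeddings, labels, margin=1.0):
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    sq_dist = jnp.sum(diff ** 2, axis=-1)
    # Small epsilon keeps the sqrt gradient finite where the distance is zero.
    dist = jnp.sqrt(sq_dist + 1e-12)
    same = (labels[:, None] == labels[None, :]).astype(embeddings.dtype)
    off_diag = 1.0 - jnp.eye(labels.shape[0], dtype=embeddings.dtype)
    # y * d^2 + (1 - y) * max(0, margin - d)^2
    per_pair = same * sq_dist + (1.0 - same) * jnp.maximum(0.0, margin - dist) ** 2
    # Average over ordered pairs, excluding each sample paired with itself.
    return jnp.sum(per_pair * off_diag) / jnp.sum(off_diag)


embeddings = jnp.array([[0.0, 0.0], [0.0, 1.0], [3.0, 0.0], [3.0, 1.0]])
labels = jnp.array([0, 0, 1, 1])
loss = contrastive_loss(embeddings, labels)
grads = jax.grad(contrastive_loss)(embeddings, labels)
```

In this toy batch every negative pair is already farther apart than the margin, so only the four positive pairs (squared distance 1 each) contribute, giving 4 / 12.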
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| margin | float | Margin for negative pairs. Defaults to 1.0. | 1.0 |
| reduction | str | Loss reduction ("mean" or "sum"). | 'mean' |
Initialize contrastive loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| margin | float | Distance margin for negative pairs. | 1.0 |
| reduction | str | Reduction method. | 'mean' |
NTXentLoss(temperature: float = 0.5, *, reduction: str = 'mean')
Bases: MetricLearningLoss
Normalized Temperature-scaled Cross Entropy (InfoNCE) loss.
L = -log(exp(sim(a, p) / t) / sum(exp(sim(a, k) / t)))
Temperature-scaled softmax cross-entropy on cosine similarities.
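The following sketch evaluates this formula for label-defined positives. It assumes that positives are all other samples sharing the anchor's label, that the denominator runs over every sample except the anchor itself, and that the result is averaged over positive pairs; `nt_xent_loss` is a hypothetical name and the library's pairing may differ.

```python
import jax
import jax.numpy as jnp


def nt_xent_loss(embeddings, labels, temperature=0.5):
    # Cosine similarity via L2-normalised embeddings, scaled by 1 / t.
    z = embeddings / jnp.linalg.norm(embeddings, axis=-1, keepdims=True)
    sim = (z @ z.T) / temperature
    self_mask = jnp.eye(labels.shape[0], dtype=bool)
    # A large negative value removes k == a from the softmax denominator.
    sim = jnp.where(self_mask, -1e9, sim)
    log_prob = sim - jax.nn.logsumexp(sim, axis=-1, keepdims=True)
    pos = (labels[:, None] == labels[None, :]) & ~self_mask
    per_pair = jnp.where(pos, -log_prob, 0.0)
    return jnp.sum(per_pair) / jnp.maximum(jnp.sum(pos), 1)


embeddings = jnp.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
aligned = nt_xent_loss(embeddings, jnp.array([0, 0, 1, 1]))
misaligned = nt_xent_loss(embeddings, jnp.array([0, 1, 0, 1]))
```

When the labels agree with the geometry (`aligned`), each positive dominates its softmax row and the loss is small; pairing orthogonal vectors as positives (`misaligned`) gives a much larger value.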
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| temperature | float | Temperature scaling factor. Defaults to 0.5. | 0.5 |
| reduction | str | Loss reduction ("mean" or "sum"). | 'mean' |
Initialize NT-Xent loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| temperature | float | Temperature parameter. | 0.5 |
| reduction | str | Reduction method. | 'mean' |
TripletMarginLoss(margin: float = 0.2, *, reduction: str = 'mean')
Bases: MetricLearningLoss
Triplet loss with margin.
L = max(0, d(anchor, positive) - d(anchor, negative) + margin)
Mines all valid triplets from the batch.
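One way to enumerate all valid triplets without Python loops is a three-dimensional mask over (anchor, positive, negative) indices. The sketch below assumes Euclidean distance and a mean over every valid triplet (including those with zero loss); `triplet_margin_loss` is a hypothetical stand-alone function, not the class documented here.

```python
import jax.numpy as jnp


def triplet_margin_loss(embeddings, labels, margin=0.2):
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    dist = jnp.sqrt(jnp.sum(diff ** 2, axis=-1) + 1e-12)
    same = labels[:, None] == labels[None, :]
    not_self = ~jnp.eye(labels.shape[0], dtype=bool)
    # valid[a, p, n]: p shares a's label (p != a) and n has a different label.
    valid = (same & not_self)[:, :, None] & (~same)[:, None, :]
    # per_triplet[a, p, n] = max(0, d(a, p) - d(a, n) + margin)
    per_triplet = jnp.maximum(0.0, dist[:, :, None] - dist[:, None, :] + margin)
    total = jnp.sum(jnp.where(valid, per_triplet, 0.0))
    return total / jnp.maximum(jnp.sum(valid), 1)


embeddings = jnp.array([[0.0, 0.0], [0.0, 1.0], [3.0, 0.0], [3.0, 1.0]])
labels = jnp.array([0, 0, 1, 1])
easy = triplet_margin_loss(embeddings, labels, margin=0.2)
hard = triplet_margin_loss(embeddings, labels, margin=3.0)
```

With the default margin the two clusters are already well separated and the loss is zero; raising the margin to 3.0 makes every one of the eight valid triplets active.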
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| margin | float | Triplet margin. Defaults to 0.2. | 0.2 |
| reduction | str | Loss reduction ("mean" or "sum"). | 'mean' |
Initialize triplet margin loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| margin | float | Distance margin. | 0.2 |
| reduction | str | Reduction method. | 'mean' |
HardNegativeMiner
Mines hardest negatives: closest embedding from a different class.
For each anchor-positive pair, selects the negative with minimum distance to the anchor.
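The selection rule can be sketched as follows. This is an illustrative stand-alone function (`mine_hard_negatives` is a hypothetical name) that returns three index arrays instead of a `MinedIndices` object; it uses `jnp.nonzero`, so it is meant to run outside `jax.jit`, and it assumes Euclidean distance.

```python
import jax.numpy as jnp


def mine_hard_negatives(embeddings, labels):
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    dist = jnp.linalg.norm(diff, axis=-1)
    same = labels[:, None] == labels[None, :]
    not_self = ~jnp.eye(labels.shape[0], dtype=bool)
    # Mask same-class entries to +inf, then take the closest remaining sample.
    hardest = jnp.argmin(jnp.where(same, jnp.inf, dist), axis=-1)
    # Every (anchor, positive) pair with matching labels, excluding self-pairs.
    anchors, positives = jnp.nonzero(same & not_self)
    return anchors, positives, hardest[anchors]


embeddings = jnp.array([[0.0], [1.0], [2.0], [5.0]])
labels = jnp.array([0, 0, 1, 1])
anchors, positives, negatives = mine_hard_negatives(embeddings, labels)
```

For this one-dimensional batch the pairs are (0, 1), (1, 0), (2, 3), (3, 2), and the hardest negatives are samples 2, 2, 1, 1 respectively.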
mine(embeddings: jnp.ndarray, labels: jnp.ndarray) -> MinedIndices
Mine hard negative triplets.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| embeddings | ndarray | (batch_size, dim) embedding vectors. | required |
| labels | ndarray | (batch_size,) integer class labels. | required |

Returns:
| Type | Description |
|---|---|
| MinedIndices | MinedIndices with anchor, positive, and hard negative indices. |
MinedIndices(anchors: jnp.ndarray, positives: jnp.ndarray, negatives: jnp.ndarray)
dataclass
Indices of mined triplets.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| anchors | ndarray | Anchor sample indices. | required |
| positives | ndarray | Positive sample indices (same class as anchor). | required |
| negatives | ndarray | Negative sample indices (different class from anchor). | required |
SemiHardMiner(margin: float = 1.0)
Mines semi-hard negatives: farther than positive but within margin.
Semi-hard negatives satisfy: d(a, p) < d(a, n) < d(a, p) + margin.
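The inequality above translates directly into a boolean mask over (anchor, positive, negative) triples. The sketch below returns every triple that satisfies it; whether the library keeps all such triples or one per anchor-positive pair is not stated here, so treat that as an assumption, along with the hypothetical name `mine_semi_hard`. It relies on `jnp.nonzero` and is meant to run outside `jax.jit`.

```python
import jax.numpy as jnp


def mine_semi_hard(embeddings, labels, margin=1.0):
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    dist = jnp.linalg.norm(diff, axis=-1)
    same = labels[:, None] == labels[None, :]
    not_self = ~jnp.eye(labels.shape[0], dtype=bool)
    d_ap = dist[:, :, None]  # indexed [a, p, :]
    d_an = dist[:, None, :]  # indexed [a, :, n]
    valid = (same & not_self)[:, :, None] & (~same)[:, None, :]
    # d(a, p) < d(a, n) < d(a, p) + margin
    semi_hard = valid & (d_ap < d_an) & (d_an < d_ap + margin)
    return jnp.nonzero(semi_hard)


embeddings = jnp.array([[0.0], [1.0], [2.0], [5.0]])
labels = jnp.array([0, 0, 1, 1])
anchors, positives, negatives = mine_semi_hard(embeddings, labels, margin=1.5)
```

In this batch only two triples qualify: (0, 1, 2), where 1 < 2 < 2.5, and (3, 2, 1), where 3 < 4 < 4.5.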
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| margin | float | Margin for semi-hard selection. | 1.0 |
Initialize semi-hard miner.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| margin | float | Distance margin for semi-hard criterion. | 1.0 |
mine(embeddings: jnp.ndarray, labels: jnp.ndarray) -> MinedIndices
Mine semi-hard negative triplets.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| embeddings | ndarray | (batch_size, dim) embedding vectors. | required |
| labels | ndarray | (batch_size,) integer class labels. | required |

Returns:
| Type | Description |
|---|---|
| MinedIndices | MinedIndices with anchor, positive, and semi-hard negative indices. |
ProxyAnchorLoss(num_classes: int, embedding_dim: int, *, margin: float = 0.1, scale: float = 32.0, rngs: nnx.Rngs)
Bases: Module
Proxy Anchor loss with smooth hard mining.
Each proxy acts as an anchor. Aggregates positive/negative samples via LogSumExp for smooth hard mining with stable gradients.
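The LogSumExp aggregation can be sketched as below, following the commonly published Proxy Anchor formulation with cosine similarity. This is an assumption about the formula, not a copy of the library's code: `proxy_anchor_loss` is a hypothetical name and the proxies are passed in explicitly instead of being stored as module parameters.

```python
import jax
import jax.numpy as jnp


def proxy_anchor_loss(embeddings, proxies, labels, margin=0.1, scale=32.0):
    e = embeddings / jnp.linalg.norm(embeddings, axis=-1, keepdims=True)
    p = proxies / jnp.linalg.norm(proxies, axis=-1, keepdims=True)
    sim = e @ p.T  # (batch, num_classes) cosine similarities
    num_classes = proxies.shape[0]
    pos_mask = labels[:, None] == jnp.arange(num_classes)[None, :]
    # Positives are pulled above the margin, negatives pushed below -margin.
    pos_terms = jnp.where(pos_mask, -scale * (sim - margin), -jnp.inf)
    neg_terms = jnp.where(pos_mask, -jnp.inf, scale * (sim + margin))
    # log(1 + sum(exp(t))) equals logsumexp over t with one extra zero entry.
    zeros = jnp.zeros((1, num_classes), dtype=sim.dtype)
    pos_lse = jax.nn.logsumexp(jnp.concatenate([zeros, pos_terms]), axis=0)
    neg_lse = jax.nn.logsumexp(jnp.concatenate([zeros, neg_terms]), axis=0)
    # Proxies with no positive in the batch contribute exactly 0 to pos_lse.
    num_with_pos = jnp.maximum(jnp.sum(jnp.any(pos_mask, axis=0)), 1)
    return jnp.sum(pos_lse) / num_with_pos + jnp.mean(neg_lse)


proxies = jnp.eye(3)
labels = jnp.array([0, 1, 2, 0])
aligned = proxy_anchor_loss(proxies[labels], proxies, labels)
misaligned = proxy_anchor_loss(proxies[(labels + 1) % 3], proxies, labels)
```

Embeddings that sit on their own class proxy (`aligned`) score far lower than embeddings that sit on a different class's proxy (`misaligned`).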
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_classes | int | Number of classes. | required |
| embedding_dim | int | Dimensionality of embedding space. | required |
| margin | float | Similarity margin for positive/negative separation. Defaults to 0.1. | 0.1 |
| scale | float | Logit scale factor. Defaults to 32.0. | 32.0 |
| rngs | Rngs | RNG streams for parameter initialization. | required |
Examples:
>>> loss_fn = ProxyAnchorLoss(num_classes=10, embedding_dim=128, rngs=nnx.Rngs(0))
>>> loss = loss_fn(embeddings, labels)
Initialize Proxy Anchor loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_classes | int | Number of proxy vectors. | required |
| embedding_dim | int | Embedding dimensionality. | required |
| margin | float | Margin for positive/negative separation. | 0.1 |
| scale | float | Scale factor. | 32.0 |
| rngs | Rngs | RNG streams. | required |
__call__(embeddings: jax.Array, labels: jax.Array) -> jax.Array
Compute Proxy Anchor loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| embeddings | Array | (batch_size, embedding_dim) embedding vectors. | required |
| labels | Array | (batch_size,) integer class labels. | required |

Returns:
| Type | Description |
|---|---|
| Array | Scalar loss value. |
ProxyNCALoss(num_classes: int, embedding_dim: int, *, rngs: nnx.Rngs)
Bases: Module
Proxy Neighborhood Component Analysis loss.
Learns proxy vectors (one per class) and pushes each sample toward its class proxy and away from other proxies.
L = -log(exp(-d(x, p+)) / sum(exp(-d(x, p-))))
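Taking the logarithm of the formula above gives d(x, p+) plus the log-sum-exp of -d(x, p-) over the negative proxies, which is what the sketch below computes. It assumes d is the squared Euclidean distance and applies no normalisation; `proxy_nca_loss` is a hypothetical name and the proxies are passed in explicitly for illustration.

```python
import jax
import jax.numpy as jnp


def proxy_nca_loss(embeddings, proxies, labels):
    diff = embeddings[:, None, :] - proxies[None, :, :]
    sq_dist = jnp.sum(diff ** 2, axis=-1)  # (batch, num_classes)
    pos_mask = labels[:, None] == jnp.arange(proxies.shape[0])[None, :]
    # Distance from each sample to its own class proxy.
    d_pos = jnp.sum(jnp.where(pos_mask, sq_dist, 0.0), axis=-1)
    # log of sum over negative proxies of exp(-d); the positive is masked out.
    neg_lse = jax.nn.logsumexp(jnp.where(pos_mask, -jnp.inf, -sq_dist), axis=-1)
    # -log(exp(-d_pos) / sum(exp(-d_neg))) = d_pos + logsumexp(-d_neg)
    return jnp.mean(d_pos + neg_lse)


proxies = jnp.array([[1.0, 0.0], [0.0, 1.0]])
labels = jnp.array([0, 1, 0])
aligned = proxy_nca_loss(proxies[labels], proxies, labels)
misaligned = proxy_nca_loss(proxies[1 - labels], proxies, labels)
```

Because the denominator excludes the positive proxy, this quantity is not bounded below by zero: samples sitting exactly on their proxy give -2 here, and samples sitting on the wrong proxy give +2.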
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_classes | int | Number of classes. | required |
| embedding_dim | int | Dimensionality of embedding space. | required |
| rngs | Rngs | RNG streams for parameter initialization. | required |
Examples:
>>> loss_fn = ProxyNCALoss(num_classes=10, embedding_dim=128, rngs=nnx.Rngs(0))
>>> loss = loss_fn(embeddings, labels)
Initialize ProxyNCA loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_classes | int | Number of proxy vectors to learn. | required |
| embedding_dim | int | Embedding dimensionality. | required |
| rngs | Rngs | RNG streams. | required |
__call__(embeddings: jax.Array, labels: jax.Array) -> jax.Array
Compute ProxyNCA loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| embeddings | Array | (batch_size, embedding_dim) embedding vectors. | required |
| labels | Array | (batch_size,) integer class labels. | required |

Returns:
| Type | Description |
|---|---|
| Array | Scalar loss value. |