
calibrax.metrics.learning

Tier 3 metric learning losses: differentiable loss functions for learning embedding spaces via backpropagation. All losses return JAX arrays compatible with jax.grad, so gradients flow back to the embeddings. Includes contrastive, triplet margin, NT-Xent, ArcFace, CosFace, ProxyNCA, and Proxy Anchor losses, plus hard-negative and semi-hard mining strategies.


MetricLearningLoss(*, reduction: str = 'mean')

Abstract base for metric learning losses.

Subclasses implement _compute_loss to return per-element losses, which are then reduced to a scalar via the Reducer.
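The split between per-element _compute_loss and a final reduction can be sketched without calibrax. This is a hypothetical stand-in (SketchMetricLoss and SquaredNormLoss are illustrative names, not library classes) that assumes the base __call__ simply applies a mean or sum to whatever _compute_loss returns:

```python
import abc

import jax
import jax.numpy as jnp


class SketchMetricLoss(abc.ABC):
    """Hypothetical re-creation of the base-class pattern, not calibrax's code."""

    def __init__(self, *, reduction: str = "mean") -> None:
        if reduction not in ("mean", "sum"):
            raise ValueError(f"unsupported reduction: {reduction!r}")
        self.reduction = reduction

    @abc.abstractmethod
    def _compute_loss(self, embeddings: jax.Array, labels: jax.Array) -> jax.Array:
        """Return one loss value per element."""

    def __call__(self, embeddings: jax.Array, labels: jax.Array) -> jax.Array:
        per_element = self._compute_loss(embeddings, labels)
        # Collapse to a scalar; mirrors the assumed Reducer behaviour.
        if self.reduction == "mean":
            return jnp.mean(per_element)
        return jnp.sum(per_element)


class SquaredNormLoss(SketchMetricLoss):
    """Toy subclass: penalize the squared L2 norm of each embedding."""

    def _compute_loss(self, embeddings, labels):
        return jnp.sum(embeddings**2, axis=-1)


emb = jnp.array([[1.0, 0.0], [0.0, 2.0]])
lbl = jnp.array([0, 1])
loss_fn = SquaredNormLoss(reduction="mean")
print(loss_fn(emb, lbl))                          # mean of per-element losses [1.0, 4.0]
print(jax.grad(lambda e: loss_fn(e, lbl))(emb))   # gradient flows through the reduction
```

Keeping the reduction in the base class means each subclass only has to describe the per-element loss, and jax.grad sees one scalar regardless of the reduction chosen.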

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| reduction | str | Loss reduction strategy ("mean" or "sum"). | 'mean' |

Examples:

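>>> import jax.numpy as jnp
>>> from calibrax.metrics.learning import MetricLearningLoss
>>> class MyLoss(MetricLearningLoss):
...     def _compute_loss(self, embeddings, labels):
...         return jnp.zeros(embeddings.shape[0])
>>> loss_fn = MyLoss()
>>> embeddings, labels = jnp.ones((4, 8)), jnp.array([0, 0, 1, 1])
>>> loss = loss_fn(embeddings, labels)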

Initialize metric learning loss.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| reduction | str | Reduction method ("mean" or "sum"). | 'mean' |

__call__(embeddings: jax.Array, labels: jax.Array, **kwargs: Any) -> jax.Array

Compute the metric learning loss.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| embeddings | Array | Batch of embedding vectors (batch_size, embedding_dim). | required |
| labels | Array | Integer class labels (batch_size,). | required |
| **kwargs | Any | Additional arguments for subclass losses. | {} |

Returns:

| Type | Description |
|---|---|
| Array | Scalar loss value as a differentiable JAX array. |

Reducer(reduction: str = 'mean')

Reduction strategy for per-element losses.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| reduction | str | One of "mean" or "sum". | 'mean' |

Initialize reducer.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| reduction | str | Reduction method ("mean" or "sum"). | 'mean' |

__call__(losses: jax.Array) -> jax.Array

Apply reduction to per-element losses.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| losses | Array | Array of per-element loss values. | required |

Returns:

| Type | Description |
|---|---|
| Array | Reduced scalar loss. |

ArcFaceLoss(num_classes: int, embedding_dim: int, *, margin: float = 0.5, scale: float = 64.0, rngs: nnx.Rngs)

Bases: Module

ArcFace loss with additive angular margin.

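Adds an angular margin m to the angle theta between an embedding and the learned weight vector of its target class: the target-class logit becomes cos(theta + m), every other class keeps cos(theta), and all logits are multiplied by scale before the softmax cross-entropy.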
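The margin rule above can be sketched as a standalone function. This is not calibrax's implementation: arcface_loss is a hypothetical name, weights stands in for the module's learned (num_classes, embedding_dim) parameters, and the sketch omits the fallback for theta + m > pi that many published implementations add.

```python
import jax
import jax.numpy as jnp


def arcface_loss(embeddings, weights, labels, margin=0.5, scale=64.0):
    """Mean ArcFace loss; `weights` is a hypothetical (num_classes, dim) class matrix."""
    # L2-normalize so that x @ w.T holds cos(theta) for every (sample, class) pair.
    x = embeddings / jnp.linalg.norm(embeddings, axis=-1, keepdims=True)
    w = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)
    # Clip away from +/-1 so the arccos gradient stays finite.
    cos = jnp.clip(x @ w.T, -1.0 + 1e-7, 1.0 - 1e-7)
    theta = jnp.arccos(cos)                             # (batch, num_classes)
    one_hot = jax.nn.one_hot(labels, weights.shape[0])
    # cos(theta + m) for the target class, cos(theta) for every other class.
    logits = scale * jnp.cos(theta + margin * one_hot)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))


weights = jnp.eye(3)
embeddings = jnp.array([[1.0, 0.1, 0.0], [0.0, 1.0, 0.2]])
labels = jnp.array([0, 1])
print(arcface_loss(embeddings, weights, labels, scale=1.0))
```

With margin=0.0 the function reduces to plain softmax cross-entropy on scaled cosine logits; a positive margin lowers the target logit, so the loss keeps pushing embeddings closer to their class weight even after they are classified correctly.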

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| num_classes | int | Number of target classes. | required |
| embedding_dim | int | Dimensionality of input embeddings. | required |
| margin | float | Angular margin in radians. | 0.5 |
| scale | float | Logit scaling factor. | 64.0 |
| rngs | Rngs | RNG streams for parameter initialization. | required |

Examples:

>>> loss_fn = ArcFaceLoss(num_classes=10, embedding_dim=128, rngs=nnx.Rngs(0))
>>> loss = loss_fn(embeddings, labels)

Initialize ArcFace loss.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| num_classes | int | Number of classes. | required |
| embedding_dim | int | Embedding dimensionality. | required |
| margin | float | Angular margin in radians. | 0.5 |
| scale | float | Logit scale factor. | 64.0 |
| rngs | Rngs | RNG streams. | required |

__call__(embeddings: jax.Array, labels: jax.Array) -> jax.Array

Compute ArcFace loss.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| embeddings | Array | (batch_size, embedding_dim) normalized embeddings. | required |
| labels | Array | (batch_size,) integer class labels. | required |

Returns:

| Type | Description |
|---|---|
| Array | Scalar loss value. |

CosFaceLoss(num_classes: int, embedding_dim: int, *, margin: float = 0.35, scale: float = 64.0, rngs: nnx.Rngs)

Bases: Module

CosFace loss with additive cosine margin.

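Subtracts a cosine margin m from the target-class logit: the target class gets cos(theta) - m, every other class keeps cos(theta), and all logits are multiplied by scale before the softmax cross-entropy.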
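A minimal sketch of the cosine-margin rule, assuming the module holds a learned (num_classes, embedding_dim) weight matrix (weights and cosface_loss are hypothetical names, not calibrax API):

```python
import jax
import jax.numpy as jnp


def cosface_loss(embeddings, weights, labels, margin=0.35, scale=64.0):
    """Mean CosFace loss; `weights` is a hypothetical (num_classes, dim) class matrix."""
    x = embeddings / jnp.linalg.norm(embeddings, axis=-1, keepdims=True)
    w = weights / jnp.linalg.norm(weights, axis=-1, keepdims=True)
    cos = x @ w.T                                       # (batch, num_classes)
    one_hot = jax.nn.one_hot(labels, weights.shape[0])
    # cos(theta) - m for the target class only; no arccos is needed, unlike ArcFace.
    logits = scale * (cos - margin * one_hot)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))


emb = jnp.array([[1.0, 0.0]])
print(cosface_loss(emb, jnp.eye(2), jnp.array([0]), scale=1.0))
```

For this one-sample case with scale=1.0, the logits are [1 - 0.35, 0], so the loss equals log(1 + exp(-0.65)).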

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| num_classes | int | Number of target classes. | required |
| embedding_dim | int | Dimensionality of input embeddings. | required |
| margin | float | Cosine margin. | 0.35 |
| scale | float | Logit scaling factor. | 64.0 |
| rngs | Rngs | RNG streams for parameter initialization. | required |

Examples:

>>> loss_fn = CosFaceLoss(num_classes=10, embedding_dim=128, rngs=nnx.Rngs(0))
>>> loss = loss_fn(embeddings, labels)

Initialize CosFace loss.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| num_classes | int | Number of classes. | required |
| embedding_dim | int | Embedding dimensionality. | required |
| margin | float | Cosine margin. | 0.35 |
| scale | float | Logit scale factor. | 64.0 |
| rngs | Rngs | RNG streams. | required |

__call__(embeddings: jax.Array, labels: jax.Array) -> jax.Array

Compute CosFace loss.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| embeddings | Array | (batch_size, embedding_dim) normalized embeddings. | required |
| labels | Array | (batch_size,) integer class labels. | required |

Returns:

| Type | Description |
|---|---|
| Array | Scalar loss value. |

ContrastiveLoss(margin: float = 1.0, *, reduction: str = 'mean')

Bases: MetricLearningLoss

Pair-based contrastive loss with margin.

L = y * d^2 + (1-y) * max(0, margin - d)^2

where d is the distance between the two embeddings of a pair, y = 1 for positive pairs (same class), and y = 0 for negative pairs (different classes).
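The formula can be sketched over every unordered pair (i < j) in a batch. This assumes Euclidean distance and a plain mean over pairs; contrastive_loss is a hypothetical helper, not calibrax's ContrastiveLoss:

```python
import jax
import jax.numpy as jnp


def contrastive_loss(embeddings, labels, margin=1.0):
    """Mean contrastive loss over all unordered pairs (i < j) in the batch."""
    i, j = jnp.triu_indices(embeddings.shape[0], k=1)
    sq_dist = jnp.sum((embeddings[i] - embeddings[j]) ** 2, axis=-1)
    # Clamp before sqrt so the gradient stays finite for coincident points.
    dist = jnp.sqrt(jnp.maximum(sq_dist, 1e-12))
    same = (labels[i] == labels[j]).astype(embeddings.dtype)  # y = 1 for positive pairs
    pos_term = same * dist**2                                  # pull positives together
    neg_term = (1.0 - same) * jnp.maximum(0.0, margin - dist) ** 2  # push negatives apart
    return jnp.mean(pos_term + neg_term)


emb = jnp.array([[0.0, 0.0], [0.5, 0.0], [3.0, 0.0]])
lbl = jnp.array([0, 0, 1])
print(contrastive_loss(emb, lbl))  # only the positive pair contributes: 0.5**2 over 3 pairs
```

Negative pairs already farther apart than margin contribute zero, which is why the two pairs involving the third point add nothing here.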

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| margin | float | Margin for negative pairs. | 1.0 |
| reduction | str | Loss reduction ("mean" or "sum"). | 'mean' |

Examples:

>>> loss_fn = ContrastiveLoss(margin=1.0)
>>> loss = loss_fn(embeddings, labels)

Initialize contrastive loss.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| margin | float | Distance margin for negative pairs. | 1.0 |
| reduction | str | Reduction method. | 'mean' |

NTXentLoss(temperature: float = 0.5, *, reduction: str = 'mean')

Bases: MetricLearningLoss

Normalized Temperature-scaled Cross Entropy (InfoNCE) loss.

L = -log(exp(sim(a, p) / t) / sum(exp(sim(a, k) / t)))

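Temperature-scaled softmax cross-entropy on cosine similarities: sim is cosine similarity, t is the temperature, p is a positive for anchor a, and the sum runs over every other sample k ≠ a in the batch.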
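A sketch of the formula, assuming that positives are the other samples sharing the anchor's label and that the loss is averaged over all (anchor, positive) pairs. How calibrax pairs samples is not stated here, and nt_xent_loss is a hypothetical name:

```python
import jax
import jax.numpy as jnp


def nt_xent_loss(embeddings, labels, temperature=0.5):
    """Mean NT-Xent over all (anchor, positive) pairs; positives share a label."""
    z = embeddings / jnp.linalg.norm(embeddings, axis=-1, keepdims=True)
    n = z.shape[0]
    self_mask = jnp.eye(n, dtype=bool)
    sim = (z @ z.T) / temperature                       # cosine similarity / t
    # Drop k == a from the denominator with a very negative logit.
    sim = jnp.where(self_mask, -1e9, sim)
    log_prob = jax.nn.log_softmax(sim, axis=-1)         # log of the fraction inside L
    pos_mask = (labels[:, None] == labels[None, :]) & ~self_mask
    num_pos = jnp.maximum(jnp.sum(pos_mask), 1)         # guard batches with no positives
    return -jnp.sum(jnp.where(pos_mask, log_prob, 0.0)) / num_pos


z = jnp.array([[1.0, 0.0], [1.0, 0.1], [0.0, 1.0], [0.1, 1.0]])
print(nt_xent_loss(z, jnp.array([0, 0, 1, 1])))
```

Lower temperatures sharpen the softmax, so the hardest negatives dominate the denominator.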

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| temperature | float | Temperature scaling factor. | 0.5 |
| reduction | str | Loss reduction ("mean" or "sum"). | 'mean' |

Examples:

>>> loss_fn = NTXentLoss(temperature=0.5)
>>> loss = loss_fn(embeddings, labels)

Initialize NT-Xent loss.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| temperature | float | Temperature parameter. | 0.5 |
| reduction | str | Reduction method. | 'mean' |

TripletMarginLoss(margin: float = 0.2, *, reduction: str = 'mean')

Bases: MetricLearningLoss

Triplet loss with margin.

L = max(0, d(anchor, positive) - d(anchor, negative) + margin)

Uses every valid (anchor, positive, negative) triplet in the batch.
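A batch-all sketch, assuming Euclidean distance and a mean over valid triplets (some implementations average only over triplets with a positive hinge instead). batch_all_triplet_loss is a hypothetical name; note that it materializes an (n, n, n) tensor, so memory grows cubically with batch size:

```python
import jax
import jax.numpy as jnp


def batch_all_triplet_loss(embeddings, labels, margin=0.2):
    """Mean hinge max(0, d(a, p) - d(a, n) + margin) over every valid triplet."""
    n = embeddings.shape[0]
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    # Euclidean distances; the clamp keeps sqrt's gradient finite at zero.
    dist = jnp.sqrt(jnp.maximum(jnp.sum(diff**2, axis=-1), 1e-12))
    same = labels[:, None] == labels[None, :]
    is_pos = same & ~jnp.eye(n, dtype=bool)             # valid (a, p): same label, p != a
    is_neg = ~same                                      # valid (a, n): different label
    valid = is_pos[:, :, None] & is_neg[:, None, :]     # valid[a, p, n]
    hinge = jnp.maximum(dist[:, :, None] - dist[:, None, :] + margin, 0.0)
    return jnp.sum(jnp.where(valid, hinge, 0.0)) / jnp.maximum(jnp.sum(valid), 1)


emb = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.5, 0.0]])
lbl = jnp.array([0, 0, 1])
print(batch_all_triplet_loss(emb, lbl))  # both valid triplets give 1.0 - 0.5 + 0.2
```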

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| margin | float | Triplet margin. | 0.2 |
| reduction | str | Loss reduction ("mean" or "sum"). | 'mean' |

Examples:

>>> loss_fn = TripletMarginLoss(margin=0.2)
>>> loss = loss_fn(embeddings, labels)

Initialize triplet margin loss.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| margin | float | Distance margin. | 0.2 |
| reduction | str | Reduction method. | 'mean' |

HardNegativeMiner

Mines hardest negatives: closest embedding from a different class.

For each anchor-positive pair, selects the negative with minimum distance to the anchor.
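The selection rule can be sketched as follows, assuming Euclidean distance. mine_hard_negatives is a hypothetical helper, not the HardNegativeMiner.mine implementation; it returns three parallel index arrays and, because the number of anchor-positive pairs depends on the labels, runs eagerly rather than under jax.jit:

```python
import jax.numpy as jnp


def mine_hard_negatives(embeddings, labels):
    """Return (anchors, positives, negatives) index arrays; eager-mode sketch."""
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    dist = jnp.sqrt(jnp.sum(diff**2, axis=-1))          # (n, n) Euclidean distances
    same = labels[:, None] == labels[None, :]
    n = labels.shape[0]
    # Every (anchor, positive) pair; data-dependent size, so not jit-compatible.
    anchors, positives = jnp.nonzero(same & ~jnp.eye(n, dtype=bool))
    # Hide same-class entries so argmin picks the closest different-class sample.
    neg_dist = jnp.where(same, jnp.inf, dist)
    hardest = jnp.argmin(neg_dist, axis=1)              # one hardest negative per anchor
    return anchors, positives, hardest[anchors]


emb = jnp.array([[0.0, 0.0], [0.1, 0.0], [1.0, 0.0], [3.0, 0.0]])
lbl = jnp.array([0, 0, 1, 1])
print(mine_hard_negatives(emb, lbl))
```

If an anchor has no different-class sample in the batch, every entry of its row is inf and argmin falls back to index 0, so real code should filter that case.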

Examples:

>>> miner = HardNegativeMiner()
>>> indices = miner.mine(embeddings, labels)

mine(embeddings: jnp.ndarray, labels: jnp.ndarray) -> MinedIndices

Mine hard negative triplets.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| embeddings | ndarray | (batch_size, dim) embedding vectors. | required |
| labels | ndarray | (batch_size,) integer class labels. | required |

Returns:

| Type | Description |
|---|---|
| MinedIndices | MinedIndices with anchor, positive, and hard negative indices. |

MinedIndices(anchors: jnp.ndarray, positives: jnp.ndarray, negatives: jnp.ndarray) (dataclass)

Indices of mined triplets.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| anchors | ndarray | Anchor sample indices. | required |
| positives | ndarray | Positive sample indices (same class as anchor). | required |
| negatives | ndarray | Negative sample indices (different class from anchor). | required |
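Assuming the three fields are parallel arrays (entry i of each forms one triplet), a loss can consume them by gathering rows. triplet_loss_from_indices is a hypothetical helper, not a calibrax function, and it assumes Euclidean distance; an empty mining result would make the mean NaN, so real code should guard that case:

```python
import jax
import jax.numpy as jnp


def triplet_loss_from_indices(embeddings, anchors, positives, negatives, margin=0.2):
    """Mean triplet hinge over pre-mined index arrays (one triplet per position)."""
    def dist(u, v):
        # Euclidean distance with a clamp so gradients stay finite when u == v.
        return jnp.sqrt(jnp.maximum(jnp.sum((u - v) ** 2, axis=-1), 1e-12))

    a, p, n = embeddings[anchors], embeddings[positives], embeddings[negatives]
    return jnp.mean(jnp.maximum(dist(a, p) - dist(a, n) + margin, 0.0))


emb = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.5, 0.0]])
anchors, positives, negatives = jnp.array([0, 1]), jnp.array([1, 0]), jnp.array([2, 2])
print(triplet_loss_from_indices(emb, anchors, positives, negatives))
```

Gradients flow only through the gathered embeddings; the index arrays themselves are treated as constants.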

SemiHardMiner(margin: float = 1.0)

Mines semi-hard negatives: farther than positive but within margin.

Semi-hard negatives satisfy: d(a, p) < d(a, n) < d(a, p) + margin.
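A sketch of the semi-hard criterion, assuming Euclidean distance, picking the closest semi-hard negative for each anchor-positive pair, and dropping pairs that have none. These tie-breaking and dropping choices are assumptions, and mine_semi_hard is a hypothetical name; like the hard-negative sketch, it uses data-dependent shapes and runs eagerly:

```python
import jax.numpy as jnp


def mine_semi_hard(embeddings, labels, margin=1.0):
    """Return (anchors, positives, negatives) for pairs with a semi-hard negative."""
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    dist = jnp.sqrt(jnp.sum(diff**2, axis=-1))          # (n, n) Euclidean distances
    same = labels[:, None] == labels[None, :]
    n = labels.shape[0]
    anchors, positives = jnp.nonzero(same & ~jnp.eye(n, dtype=bool))
    d_ap = dist[anchors, positives][:, None]            # (pairs, 1)
    d_an = dist[anchors]                                # (pairs, n)
    is_neg = ~same[anchors]                             # (pairs, n)
    # d(a, p) < d(a, n) < d(a, p) + margin
    semi_hard = is_neg & (d_an > d_ap) & (d_an < d_ap + margin)
    # Closest semi-hard negative per pair; inf where the condition fails.
    negatives = jnp.argmin(jnp.where(semi_hard, d_an, jnp.inf), axis=1)
    keep = jnp.any(semi_hard, axis=1)                   # drop pairs with no candidate
    return anchors[keep], positives[keep], negatives[keep]


emb = jnp.array([[0.0, 0.0], [1.0, 0.0], [1.5, 0.0], [5.0, 0.0]])
lbl = jnp.array([0, 0, 1, 1])
print(mine_semi_hard(emb, lbl, margin=1.0))
```

Semi-hard negatives still violate the margin, so they produce a nonzero triplet loss, but they are never closer than the positive. This tends to avoid the collapsed solutions that the very hardest negatives can cause early in training.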

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| margin | float | Margin for semi-hard selection. | 1.0 |

Examples:

>>> miner = SemiHardMiner(margin=1.0)
>>> indices = miner.mine(embeddings, labels)

Initialize semi-hard miner.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| margin | float | Distance margin for semi-hard criterion. | 1.0 |

mine(embeddings: jnp.ndarray, labels: jnp.ndarray) -> MinedIndices

Mine semi-hard negative triplets.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| embeddings | ndarray | (batch_size, dim) embedding vectors. | required |
| labels | ndarray | (batch_size,) integer class labels. | required |

Returns:

| Type | Description |
|---|---|
| MinedIndices | MinedIndices with anchor, positive, and semi-hard negative indices. |

ProxyAnchorLoss(num_classes: int, embedding_dim: int, *, margin: float = 0.1, scale: float = 32.0, rngs: nnx.Rngs)

Bases: Module

Proxy Anchor loss with smooth hard mining.

Each class proxy acts as an anchor and is compared against the samples in the batch. Positive and negative similarities are aggregated with LogSumExp, which acts as a smooth form of hard mining and keeps gradients stable.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| num_classes | int | Number of classes. | required |
| embedding_dim | int | Dimensionality of embedding space. | required |
| margin | float | Margin for positive/negative separation. | 0.1 |
| scale | float | Logit scale factor. | 32.0 |
| rngs | Rngs | RNG streams for parameter initialization. | required |

Examples:

>>> loss_fn = ProxyAnchorLoss(num_classes=10, embedding_dim=128, rngs=nnx.Rngs(0))
>>> loss = loss_fn(embeddings, labels)

Initialize Proxy Anchor loss.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| num_classes | int | Number of proxy vectors. | required |
| embedding_dim | int | Embedding dimensionality. | required |
| margin | float | Margin for positive/negative separation. | 0.1 |
| scale | float | Scale factor. | 32.0 |
| rngs | Rngs | RNG streams. | required |

__call__(embeddings: jax.Array, labels: jax.Array) -> jax.Array

Compute Proxy Anchor loss.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| embeddings | Array | (batch_size, embedding_dim) embedding vectors. | required |
| labels | Array | (batch_size,) integer class labels. | required |

Returns:

| Type | Description |
|---|---|
| Array | Scalar loss value. |

ProxyNCALoss(num_classes: int, embedding_dim: int, *, rngs: nnx.Rngs)

Bases: Module

Proxy Neighborhood Component Analysis loss.

Learns proxy vectors (one per class) and pushes each sample toward its class proxy and away from other proxies.

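L = -log(exp(-d(x, p+)) / sum(exp(-d(x, p-))))

where p+ is the proxy of the sample's own class and the sum runs over the proxies p- of all other classes.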
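Taking the log of the formula gives L = d(x, p+) + log Σ exp(-d(x, p-)), which can be sketched with a log-sum-exp. This assumes squared Euclidean distance between L2-normalized embeddings and proxies (a common choice; the distance calibrax uses is not stated here) and at least two classes. proxy_nca_loss is a hypothetical name, and proxies stands in for the learned parameters:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def proxy_nca_loss(embeddings, proxies, labels):
    """ProxyNCA with squared Euclidean distance between L2-normalized vectors."""
    x = embeddings / jnp.linalg.norm(embeddings, axis=-1, keepdims=True)
    p = proxies / jnp.linalg.norm(proxies, axis=-1, keepdims=True)
    d = jnp.sum((x[:, None, :] - p[None, :, :]) ** 2, axis=-1)   # (batch, num_classes)
    one_hot = jax.nn.one_hot(labels, proxies.shape[0])
    d_pos = jnp.sum(one_hot * d, axis=-1)                        # d(x, p+)
    # Denominator over negative proxies only: mask the positive column with -inf.
    neg_logits = jnp.where(one_hot > 0, -jnp.inf, -d)
    return jnp.mean(d_pos + logsumexp(neg_logits, axis=-1))


emb = jnp.array([[1.0, 0.0]])
prox = jnp.eye(2)
print(proxy_nca_loss(emb, prox, jnp.array([0])))  # d+ = 0, d- = 2, so L = 0 + log(exp(-2))
```

Because the positive proxy is excluded from the denominator, the loss is not bounded below by zero; here it evaluates to -2.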

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| num_classes | int | Number of classes. | required |
| embedding_dim | int | Dimensionality of embedding space. | required |
| rngs | Rngs | RNG streams for parameter initialization. | required |

Examples:

>>> loss_fn = ProxyNCALoss(num_classes=10, embedding_dim=128, rngs=nnx.Rngs(0))
>>> loss = loss_fn(embeddings, labels)

Initialize ProxyNCA loss.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| num_classes | int | Number of proxy vectors to learn. | required |
| embedding_dim | int | Embedding dimensionality. | required |
| rngs | Rngs | RNG streams. | required |

__call__(embeddings: jax.Array, labels: jax.Array) -> jax.Array

Compute ProxyNCA loss.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| embeddings | Array | (batch_size, embedding_dim) embedding vectors. | required |
| labels | Array | (batch_size,) integer class labels. | required |

Returns:

| Type | Description |
|---|---|
| Array | Scalar loss value. |