Metric Learning
Tier 3 losses train the embedding space itself. They are differentiable
loss functions that work with jax.grad() for end-to-end training.
All losses follow the same interface:
```python
from calibrax.metrics.learning import ContrastiveLoss

loss_fn = ContrastiveLoss(margin=1.0)
loss = loss_fn(embeddings, labels)  # Returns a differentiable JAX scalar
```
Contrastive Losses
ContrastiveLoss
Pair-based loss with a margin. Pulls same-class pairs together and pushes different-class pairs apart.
```python
from calibrax.metrics.learning import ContrastiveLoss

loss_fn = ContrastiveLoss(margin=1.0)
loss = loss_fn(embeddings, labels)
```
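The following is a minimal pure-JAX sketch of the classic margin-based pairwise contrastive loss. It makes "pull together / push apart" concrete in numbers. It assumes Euclidean distances, a penalty of \(d^2\) for same-class pairs and \(\max(0, m - d)^2\) for different-class pairs, and an average over all unordered pairs. `contrastive_loss_sketch` is a hypothetical name, and calibrax's `ContrastiveLoss` may use a different distance, squaring, or averaging.

```python
import jax
import jax.numpy as jnp


def pairwise_distances(x):
    """Euclidean distance matrix for an (N, D) batch of embeddings."""
    sq = jnp.sum(x**2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * x @ x.T
    # Clamp round-off negatives; the small eps keeps sqrt differentiable at 0
    return jnp.sqrt(jnp.maximum(d2, 0.0) + 1e-12)


def contrastive_loss_sketch(x, labels, margin=1.0):
    d = pairwise_distances(x)
    same = labels[:, None] == labels[None, :]
    # Same-class pairs: penalize distance; different-class: penalize being inside the margin
    per_pair = jnp.where(same, d**2, jnp.maximum(margin - d, 0.0) ** 2)
    # Average over unordered pairs i < j (the diagonal is excluded)
    n = x.shape[0]
    upper = jnp.triu(jnp.ones((n, n), dtype=bool), k=1)
    return jnp.sum(jnp.where(upper, per_pair, 0.0)) / jnp.sum(upper)


emb = jnp.array([[0.0, 0.0], [0.0, 0.1], [3.0, 0.0], [3.0, 0.1]])
lab = jnp.array([0, 0, 1, 1])
loss = contrastive_loss_sketch(emb, lab)
grads = jax.grad(contrastive_loss_sketch)(emb, lab)  # gradient w.r.t. embeddings
```

Here the two different-class clusters are already farther apart than the margin, so only the two tight same-class pairs contribute. Because the sketch is plain `jax.numpy`, `jax.grad` differentiates it directly.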
TripletMarginLoss
Mines all valid (anchor, positive, negative) triplets from the batch.
```python
from calibrax.metrics.learning import TripletMarginLoss

loss_fn = TripletMarginLoss(margin=0.2)
loss = loss_fn(embeddings, labels)
```
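The "all valid triplets" strategy can be sketched in pure JAX. The sketch builds an \(N \times N \times N\) validity mask over (anchor, positive, negative) index triples, then averages the hinge \(\max(0, d(a,p) - d(a,n) + m)\) over the valid ones. That cubic mask is where the \(O(N^3)\) cost of triplet losses comes from. The function name and the choice to average over all valid triplets (rather than only the non-zero ones) are assumptions, not calibrax internals.

```python
import jax.numpy as jnp


def pairwise_distances(x):
    """Euclidean distance matrix for an (N, D) batch."""
    sq = jnp.sum(x**2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * x @ x.T
    return jnp.sqrt(jnp.maximum(d2, 0.0) + 1e-12)


def batch_all_triplet_loss_sketch(x, labels, margin=0.2):
    d = pairwise_distances(x)
    n = x.shape[0]
    same = labels[:, None] == labels[None, :]
    pos = same & ~jnp.eye(n, dtype=bool)  # (a, p): same class, a != p
    neg = ~same                           # (a, n): different class
    # valid[a, p, n] is True for every valid (anchor, positive, negative)
    valid = pos[:, :, None] & neg[:, None, :]
    # hinge[a, p, n] = max(0, d(a, p) - d(a, n) + margin)
    hinge = jnp.maximum(d[:, :, None] - d[:, None, :] + margin, 0.0)
    # Average over valid triplets (guard against a batch with none)
    return jnp.sum(jnp.where(valid, hinge, 0.0)) / jnp.maximum(jnp.sum(valid), 1)


# Anchor 0's positive (index 1) is farther away than its negative (index 2)
x = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.5, 0.0]])
labels = jnp.array([0, 0, 1])
loss = batch_all_triplet_loss_sketch(x, labels)
```

Both valid triplets, (0, 1, 2) and (1, 0, 2), violate the margin by \(1 - 0.5 + 0.2 = 0.7\), so the average is 0.7. Well-separated clusters give exactly zero loss.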
NTXentLoss (InfoNCE)
Temperature-scaled softmax cross-entropy on cosine similarities. It is the standard loss for self-supervised contrastive learning (SimCLR, MoCo).
```python
from calibrax.metrics.learning import NTXentLoss

loss_fn = NTXentLoss(temperature=0.5)
loss = loss_fn(embeddings, labels)
```
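A minimal pure-JAX sketch makes "temperature-scaled softmax cross-entropy on cosine similarities" concrete. It rests on three assumptions:

- Samples sharing a label are positives for each other. In SimCLR, the two augmented views of one image would share a label.
- Each sample's self-similarity is excluded from the softmax.
- The loss is averaged over all (anchor, positive) pairs.

With exactly two views per label this reduces to the usual NT-Xent. The name `ntxent_loss_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def ntxent_loss_sketch(x, labels, temperature=0.5):
    # Cosine similarity = dot product of L2-normalized embeddings
    z = x / jnp.linalg.norm(x, axis=1, keepdims=True)
    logits = (z @ z.T) / temperature
    self_mask = jnp.eye(x.shape[0], dtype=bool)
    # A sample must not count as its own positive or negative
    logits = jnp.where(self_mask, -jnp.inf, logits)
    log_prob = logits - jax.nn.logsumexp(logits, axis=1, keepdims=True)
    pos = (labels[:, None] == labels[None, :]) & ~self_mask
    # Softmax cross-entropy averaged over every (anchor, positive) pair
    return -jnp.sum(jnp.where(pos, log_prob, 0.0)) / jnp.sum(pos)


# Two "views" per instance: rows 0/1 belong together, rows 2/3 belong together
x = jnp.array([[1.0, 0.0], [1.0, 0.01], [0.0, 1.0], [0.01, 1.0]])
loss_matched = ntxent_loss_sketch(x, jnp.array([0, 0, 1, 1]))
loss_mismatched = ntxent_loss_sketch(x, jnp.array([0, 1, 0, 1]))
```

Lowering `temperature` sharpens the softmax, which puts more weight on the hardest negatives.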
Angular Margin Losses
These losses add angular penalties in cosine space and maintain learnable
weight matrices. They are nnx.Module subclasses with trainable parameters.
ArcFaceLoss
Additive angular margin: \(\cos(\theta + m)\) for the target class.
```python
import jax
import jax.numpy as jnp
import flax.nnx as nnx
from calibrax.metrics.learning import ArcFaceLoss

# Angular margin losses need embeddings matching embedding_dim
embeddings = jax.random.normal(jax.random.PRNGKey(0), (4, 128))
labels = jnp.array([0, 0, 1, 1])
loss_fn = ArcFaceLoss(
    num_classes=100,
    embedding_dim=128,
    margin=0.5,
    scale=64.0,
    rngs=nnx.Rngs(0),
)
loss = loss_fn(embeddings, labels)
```
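The margin only touches the target-class logit. The minimal pure-JAX sketch below shows the mechanics, under these assumptions:

- Embeddings and class-weight rows are L2-normalized. The `weights` argument stands in for the learnable weight matrix.
- The loss is plain softmax cross-entropy.

Production implementations usually add extra handling for \(\theta + m > \pi\), which this sketch omits. `arcface_loss_sketch` and its arguments are hypothetical, not calibrax's internals.

```python
import jax
import jax.numpy as jnp


def arcface_loss_sketch(x, labels, weights, margin=0.5, scale=64.0):
    """x: (N, D) embeddings, labels: (N,) class ids, weights: (C, D) class centers."""
    x_n = x / jnp.linalg.norm(x, axis=1, keepdims=True)
    w_n = weights / jnp.linalg.norm(weights, axis=1, keepdims=True)
    # cos(theta) between every embedding and every class center, shape (N, C)
    cos = jnp.clip(x_n @ w_n.T, -1.0 + 1e-7, 1.0 - 1e-7)
    theta = jnp.arccos(cos)
    target = labels[:, None] == jnp.arange(weights.shape[0])[None, :]
    # Additive angular margin on the target class only: cos(theta + m)
    logits = scale * jnp.where(target, jnp.cos(theta + margin), cos)
    log_prob = jax.nn.log_softmax(logits, axis=1)
    return -jnp.mean(jnp.take_along_axis(log_prob, labels[:, None], axis=1))


x = jnp.array([[1.0, 0.1], [0.1, 1.0]])
labels = jnp.array([0, 1])
weights = jnp.eye(2)  # two classes, 2-D embeddings
loss_no_margin = arcface_loss_sketch(x, labels, weights, margin=0.0, scale=4.0)
loss_margin = arcface_loss_sketch(x, labels, weights, margin=0.5, scale=4.0)
```

A small `scale` is used here only so the float32 losses stay visibly non-zero. Even for correctly classified samples, the margin raises the loss, which keeps pulling embeddings toward their class center.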
CosFaceLoss
Additive cosine margin: \(\cos(\theta) - m\) for the target class.
```python
from calibrax.metrics.learning import CosFaceLoss

loss_fn = CosFaceLoss(
    num_classes=100,
    embedding_dim=128,
    margin=0.35,
    scale=64.0,
    rngs=nnx.Rngs(0),
)
loss = loss_fn(embeddings, labels)
```
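CosFace subtracts the margin directly from the cosine, so no `arccos` is needed. Here is a minimal pure-JAX sketch, assuming L2-normalized embeddings and class-weight rows and plain softmax cross-entropy. `cosface_loss_sketch` is a hypothetical name, not calibrax's implementation.

```python
import jax
import jax.numpy as jnp


def cosface_loss_sketch(x, labels, weights, margin=0.35, scale=64.0):
    """x: (N, D) embeddings, labels: (N,) class ids, weights: (C, D) class centers."""
    x_n = x / jnp.linalg.norm(x, axis=1, keepdims=True)
    w_n = weights / jnp.linalg.norm(weights, axis=1, keepdims=True)
    cos = x_n @ w_n.T  # (N, C)
    target = labels[:, None] == jnp.arange(weights.shape[0])[None, :]
    # Subtract the margin from the target-class cosine only: cos(theta) - m
    logits = scale * (cos - margin * target)
    log_prob = jax.nn.log_softmax(logits, axis=1)
    return -jnp.mean(jnp.take_along_axis(log_prob, labels[:, None], axis=1))


x = jnp.array([[1.0, 0.1], [0.1, 1.0]])
labels = jnp.array([0, 1])
weights = jnp.eye(2)
loss_no_margin = cosface_loss_sketch(x, labels, weights, margin=0.0, scale=4.0)
loss_margin = cosface_loss_sketch(x, labels, weights, margin=0.35, scale=4.0)
```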
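Proxy-Based Losses
Proxy methods learn class-representative vectors (proxies) in embedding space. Each sample is compared with the \(M\) proxies rather than with every other sample. A batch of \(N\) samples therefore costs \(O(NM)\) instead of the \(O(N^2)\) of pair-based methods, making them practical for large class counts.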
ProxyNCALoss
Pushes each sample toward its class proxy and away from other proxies:
```python
from calibrax.metrics.learning import ProxyNCALoss

loss_fn = ProxyNCALoss(
    num_classes=100,
    embedding_dim=128,
    rngs=nnx.Rngs(0),
)
loss = loss_fn(embeddings, labels)
```
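A minimal pure-JAX sketch of the idea follows. It assumes one proxy per class and L2-normalized embeddings and proxies. It uses a softmax over negative squared distances to all proxies, which is a common variant; the original Proxy-NCA formulation excludes the positive proxy from the denominator. Only an \(N \times C\) sample-to-proxy distance matrix is computed, never sample-to-sample pairs. `proxy_nca_loss_sketch` is hypothetical, and its `proxies` array stands in for the module's learnable parameters.

```python
import jax
import jax.numpy as jnp


def proxy_nca_loss_sketch(x, labels, proxies):
    """x: (N, D) embeddings, labels: (N,) class ids, proxies: (C, D), one per class."""
    x_n = x / jnp.linalg.norm(x, axis=1, keepdims=True)
    p_n = proxies / jnp.linalg.norm(proxies, axis=1, keepdims=True)
    # Squared Euclidean distance from each sample to each proxy: (N, C)
    d2 = jnp.sum((x_n[:, None, :] - p_n[None, :, :]) ** 2, axis=-1)
    # Nearer proxies get higher probability; maximize it for the true class
    log_prob = jax.nn.log_softmax(-d2, axis=1)
    return -jnp.mean(jnp.take_along_axis(log_prob, labels[:, None], axis=1))


x = jnp.array([[1.0, 0.1], [0.1, 1.0]])
proxies = jnp.eye(2)
loss_right = proxy_nca_loss_sketch(x, jnp.array([0, 1]), proxies)
loss_wrong = proxy_nca_loss_sketch(x, jnp.array([1, 0]), proxies)
```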
ProxyAnchorLoss
Uses LogSumExp for smooth hard mining with stable gradients:
```python
from calibrax.metrics.learning import ProxyAnchorLoss

loss_fn = ProxyAnchorLoss(
    num_classes=100,
    embedding_dim=128,
    margin=0.1,
    scale=32.0,
    rngs=nnx.Rngs(0),
)
loss = loss_fn(embeddings, labels)
```
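In the published Proxy Anchor formulation, each proxy acts as an anchor. A smooth log-sum-exp over its positive samples pulls them in, and another over its negative samples pushes them away. The hardest samples dominate each sum while every sample still receives a gradient.

The pure-JAX sketch below follows that structure under these assumptions:

- One proxy per class.
- Cosine similarity between samples and proxies.
- `scale` plays the role of \(\alpha\) and `margin` the role of \(\delta\).

`proxy_anchor_loss_sketch` is hypothetical and is not calibrax's code.

```python
import jax.numpy as jnp


def proxy_anchor_loss_sketch(x, labels, proxies, margin=0.1, scale=32.0):
    """x: (N, D) embeddings, labels: (N,) class ids, proxies: (C, D), one per class."""
    x_n = x / jnp.linalg.norm(x, axis=1, keepdims=True)
    p_n = proxies / jnp.linalg.norm(proxies, axis=1, keepdims=True)
    sim = x_n @ p_n.T                                                # (N, C) cosine
    pos = labels[:, None] == jnp.arange(proxies.shape[0])[None, :]  # (N, C)
    # Per-proxy sums over positive / negative samples (axis 0 runs over samples)
    pos_sum = jnp.sum(jnp.where(pos, jnp.exp(-scale * (sim - margin)), 0.0), axis=0)
    neg_sum = jnp.sum(jnp.where(pos, 0.0, jnp.exp(scale * (sim + margin))), axis=0)
    # log(1 + sum exp(.)) acts as a smooth maximum over each proxy's samples.
    # Proxies with no positives in the batch add log1p(0) = 0 to the positive term.
    num_with_pos = jnp.maximum(jnp.sum(jnp.any(pos, axis=0)), 1)
    pos_term = jnp.sum(jnp.log1p(pos_sum)) / num_with_pos  # proxies seen in batch
    neg_term = jnp.mean(jnp.log1p(neg_sum))                # all proxies
    return pos_term + neg_term


x = jnp.array([[1.0, 0.1], [0.1, 1.0]])
proxies = jnp.eye(2)
loss_right = proxy_anchor_loss_sketch(x, jnp.array([0, 1]), proxies)
loss_wrong = proxy_anchor_loss_sketch(x, jnp.array([1, 0]), proxies)
```

With `scale=32.0` and cosine similarities in \([-1, 1]\), each exponent stays below about 36, so float32 does not overflow in this sketch.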
Miners
Miners identify informative triplets from a batch before passing them to a loss function. This improves training efficiency by focusing on hard examples.
```python
from calibrax.metrics.learning import HardNegativeMiner, SemiHardMiner

# Hard negatives: closest sample from a different class
miner = HardNegativeMiner()
indices = miner.mine(embeddings, labels)
# indices.anchors, indices.positives, indices.negatives

# Semi-hard: negatives farther than the positive but within the margin
miner = SemiHardMiner(margin=1.0)
indices = miner.mine(embeddings, labels)
```
MinedIndices is a frozen dataclass with anchors, positives, and
negatives arrays of matching length.
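The two mining rules can be sketched with NumPy, since mining returns a variable number of triplets per batch. Both sketches enumerate every (anchor, positive) pair:

- The hard-negative version attaches the single closest different-class sample to each pair.
- The semi-hard version keeps every negative with \(d(a,p) < d(a,n) < d(a,p) + m\).

The function names and the tuple-of-arrays return value are illustrative stand-ins for calibrax's `MinedIndices`, not its implementation. The sketch also assumes every batch contains at least two classes.

```python
import numpy as np


def pairwise_distances(x):
    diff = x[:, None, :] - x[None, :, :]
    return np.sqrt(np.sum(diff**2, axis=-1))


def positive_pairs(labels):
    """All ordered (anchor, positive) index pairs with a != p and equal labels."""
    same = labels[:, None] == labels[None, :]
    a, p = np.nonzero(same & ~np.eye(len(labels), dtype=bool))
    return same, a, p


def mine_hard_negatives_sketch(x, labels):
    d = pairwise_distances(x)
    same, a, p = positive_pairs(labels)
    # Same-class samples can never be negatives; take the nearest remaining one
    n = np.argmin(np.where(same, np.inf, d)[a], axis=1)
    return a, p, n


def mine_semi_hard_sketch(x, labels, margin=1.0):
    d = pairwise_distances(x)
    same, a, p = positive_pairs(labels)
    d_ap = d[a, p][:, None]                                 # (P, 1)
    d_an = d[a]                                             # (P, N)
    ok = ~same[a] & (d_an > d_ap) & (d_an < d_ap + margin)  # (P, N)
    i, n = np.nonzero(ok)
    return a[i], p[i], n


# Points on a line: class 0 at 0.0 and 1.0, class 1 at 0.4 and 3.0
x = np.array([[0.0, 0.0], [1.0, 0.0], [0.4, 0.0], [3.0, 0.0]])
labels = np.array([0, 0, 1, 1])
hard = mine_hard_negatives_sketch(x, labels)           # (anchors, positives, negatives)
semi = mine_semi_hard_sketch(x, labels, margin=1.5)
```

Hard mining always returns one negative per positive pair. Semi-hard mining may return several negatives, or none, for a given pair.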
Training Loop Integration
All Tier 3 losses return differentiable JAX arrays. Use nnx.grad() for
end-to-end training:
```python
import flax.nnx as nnx
import optax
from calibrax.metrics.learning import TripletMarginLoss

# Simple encoder: Linear layer projecting 128-d inputs to a 32-d embedding space
encoder = nnx.Linear(128, 32, rngs=nnx.Rngs(0))
loss_fn = TripletMarginLoss(margin=0.2)
optimizer = nnx.Optimizer(encoder, optax.adam(1e-3), wrt=nnx.Param)

def train_step(encoder, optimizer, batch_x, batch_labels):
    def loss(encoder):
        return loss_fn(encoder(batch_x), batch_labels)

    # nnx.grad differentiates w.r.t. nnx.Param state by default
    grads = nnx.grad(loss)(encoder)
    optimizer.update(grads)
```
For losses with trainable parameters (ArcFace, CosFace, ProxyNCA, ProxyAnchor), include the loss module in the gradient computation so its weights are jointly optimized:
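```python
# embedding_dim must match the encoder's output size (32), not its input size
arcface = ArcFaceLoss(num_classes=100, embedding_dim=32, rngs=nnx.Rngs(1))
arc_optimizer = nnx.Optimizer(arcface, optax.adam(1e-3), wrt=nnx.Param)

def train_step(encoder, arcface, optimizer, arc_optimizer, batch_x, batch_labels):
    def loss(encoder, arcface):
        return arcface(encoder(batch_x), batch_labels)

    # argnums=(0, 1): gradients for the encoder AND the loss module's weights
    enc_grads, arc_grads = nnx.grad(loss, argnums=(0, 1))(encoder, arcface)
    optimizer.update(enc_grads)
    arc_optimizer.update(arc_grads)
```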
Choosing a Loss
| Loss | Best for | Complexity |
|---|---|---|
| ContrastiveLoss | Small batches, pairwise similarity | \(O(N^2)\) |
| TripletMarginLoss | Baseline metric learning | \(O(N^3)\) triplets |
| NTXentLoss | Self-supervised contrastive (SimCLR) | \(O(N^2)\) |
| ArcFaceLoss | Face recognition, fine-grained classification | \(O(NC)\) |
| CosFaceLoss | Similar to ArcFace, simpler margin | \(O(NC)\) |
| ProxyNCALoss | Large class count, fast convergence | \(O(NM)\) |
| ProxyAnchorLoss | Hard-example weighting with smooth gradients | \(O(NM)\) |
Where \(N\) = batch size, \(C\) = number of classes, \(M\) = number of proxies.
Reduction
All MetricLearningLoss subclasses accept a reduction parameter
("mean" or "sum"). Default is "mean". Angular margin and proxy
losses handle reduction internally.
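The choice between the two documented options also matters for optimization. With "sum", the gradient of each per-sample term is 1 instead of \(1/N\), so with plain SGD the effective step size grows with batch size. A tiny helper makes this visible with `jax.grad`. Note that `reduce_loss` is hypothetical, not a calibrax function.

```python
import jax
import jax.numpy as jnp


def reduce_loss(per_sample, reduction="mean"):
    """Collapse per-sample losses to a scalar ("mean" or "sum")."""
    if reduction == "mean":
        return jnp.mean(per_sample)
    if reduction == "sum":
        return jnp.sum(per_sample)
    raise ValueError(f"unknown reduction: {reduction!r}")


per_sample = jnp.array([1.0, 2.0, 3.0, 4.0])
grad_mean = jax.grad(lambda v: reduce_loss(v, "mean"))(per_sample)  # each 1/N
grad_sum = jax.grad(lambda v: reduce_loss(v, "sum"))(per_sample)    # each 1
```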
Next Steps
- Stateful Metrics: frozen backbone (Tier 1) and learned (Tier 2) metrics
- Metric Composition: group, weight, threshold, and track metrics