
Metric Learning And Retrieval

Classification answers "which of K classes?". Metric learning answers "how similar are these two inputs?" — and once you can answer that, you can do retrieval, verification, re-identification, few-shot classification, and nearest-prototype inference on classes you have never seen during training.

What This Is

Train an encoder f such that the distance between embeddings of matched pairs is smaller than the distance between mismatched pairs:

d(f(x_anchor), f(x_positive))  <  d(f(x_anchor), f(x_negative))

At inference you do not classify — you embed the query and retrieve the nearest gallery items. The same trained encoder handles any new class as long as at least one gallery example exists.

Three objective families you should know by name:

  • Contrastive loss (Hadsell 2006). Pull matched pairs together; push unmatched pairs apart only while they are closer than a margin m.
  • Triplet loss (FaceNet). L = max(0, d(a, p) - d(a, n) + m) over anchor-positive-negative triplets; works well only when you can sample informative (hard or semi-hard) negatives.
  • Angular / margin softmax (ArcFace, CosFace). Multi-class classification over training identities, but with an additive angular margin so class clusters are tight and well separated. Often easier to train than triplets.
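A minimal NumPy sketch of the three objectives for single examples, assuming plain Euclidean distance for the first two and Hadsell-style squared penalties for the contrastive loss. The function names are hypothetical. The ArcFace-style function only builds the margin-adjusted logits (s = 64 and m = 0.5 follow the ArcFace paper's defaults) and omits the fallback real implementations use when θ + m exceeds π. A training loop would compute these with an autograd framework rather than NumPy.

```python
import numpy as np

def euclidean(u, v):
    """Plain Euclidean distance between two embedding vectors."""
    return float(np.linalg.norm(np.asarray(u, float) - np.asarray(v, float)))

def contrastive_loss(x1, x2, same, m=1.0):
    """Contrastive loss for one pair: matched pairs pay d^2 / 2,
    mismatched pairs pay only while they are closer than margin m."""
    d = euclidean(x1, x2)
    if same:
        return 0.5 * d ** 2
    return 0.5 * max(0.0, m - d) ** 2

def triplet_loss(a, p, n, m=0.2):
    """Triplet loss as written in the text: max(0, d(a,p) - d(a,n) + m)."""
    return max(0.0, euclidean(a, p) - euclidean(a, n) + m)

def additive_angular_margin_logits(emb, weights, label, m=0.5, s=64.0):
    """ArcFace-style logits: add angle m to the true class only, then scale by s.
    Feed the result to an ordinary softmax cross-entropy."""
    e = np.asarray(emb, float)
    w = np.asarray(weights, float)
    e = e / np.linalg.norm(e)                          # unit-length embedding
    w = w / np.linalg.norm(w, axis=1, keepdims=True)   # unit-length class centres
    cos = np.clip(w @ e, -1.0, 1.0)                    # cosine to every class centre
    theta = np.arccos(cos)
    theta[label] += m                                  # penalise the true class by m radians
    return s * np.cos(theta)
```

With a = (0, 0), p = (0.1, 0), n = (1, 0), the negative is already farther than d(a, p) + m, so `triplet_loss` returns 0 and the triplet contributes no gradient.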

A close cousin, usually taught alongside: prototype classifiers. Compute the mean embedding of each class's support examples; classify by nearest prototype (cosine or Euclidean). This is the move behind few-shot classification and behind extending a frozen classifier to new classes without training.

When You Use It

  • the set of classes at inference time is not the set of classes at training time (open-set recognition, re-identification, retrieval)
  • you need to compare two inputs for "same or different", not label them
  • you want a single encoder that serves many downstream tasks by swapping the gallery
  • you have few labeled examples per class but many identities — the triplet/pair structure leverages the combinatorial relationships instead of the absolute count

Do Not Use It When

  • the class set is fixed and large-sample — a classifier is simpler and usually better
  • you cannot define "matched" vs "mismatched" cleanly (if your ground truth is fuzzy, so is the loss)
  • your downstream consumer wants calibrated probabilities — embedding distances are not probabilities without extra work

Triplet Loss In Practice

The loss itself is three lines. The hard part is triplet sampling.

  • random triplets converge slowly and hit a plateau — most random negatives are already far from the anchor, so the loss is near zero.
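  • semi-hard triplets (FaceNet): among negatives that are farther from the anchor than the positive but still inside the margin (d(a, p) < d(a, n) < d(a, p) + m), pick the closest. These still have nonzero loss, so gradients stay steady.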
  • hard triplets: pick the negative closest to the anchor, period. Faster progress but can collapse if the model is too weak to separate them.
  • batch-hard mining: within a minibatch of P identities × K images each, for each anchor take the hardest positive (the farthest image of the same identity) and the hardest negative (the closest image of any other identity). This is the "in-batch" trick that modern implementations use.
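Batch-hard mining reduces to masking a pairwise distance matrix. A minimal NumPy sketch, assuming the batch holds at least two identities and every identity has K ≥ 2 images (otherwise a row has no positive or no negative); `batch_hard_triplet_loss` is a hypothetical name and only computes the loss value, not gradients.

```python
import numpy as np

def pairwise_distances(emb):
    """Euclidean distance matrix between all rows of emb (N x D)."""
    sq = np.sum(emb ** 2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * emb @ emb.T
    return np.sqrt(np.maximum(d2, 0.0))  # clamp tiny negatives from round-off

def batch_hard_triplet_loss(emb, labels, m=0.2):
    """Batch-hard triplet loss over a P x K minibatch."""
    labels = np.asarray(labels)
    d = pairwise_distances(np.asarray(emb, dtype=float))
    same = labels[:, None] == labels[None, :]
    pos_mask = same & ~np.eye(len(labels), dtype=bool)  # same identity, not itself
    # hardest positive: farthest same-identity sample
    hardest_pos = np.where(pos_mask, d, -np.inf).max(axis=1)
    # hardest negative: closest sample from any other identity
    hardest_neg = np.where(~same, d, np.inf).min(axis=1)
    losses = np.maximum(0.0, hardest_pos - hardest_neg + m)
    return float(losses.mean())
```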

The margin m is a hyperparameter. Too small → easy triplets → no gradient. Too large → everything is a hard triplet → training is noisy.
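To see the margin trade-off numerically, count how many triplets still have nonzero loss at each margin. A small sketch; the distance distributions are made-up Gaussians, not measurements from any model, and `active_fraction` is a hypothetical helper.

```python
import numpy as np

def active_fraction(d_ap, d_an, m):
    """Fraction of triplets with nonzero loss, i.e. d(a,p) - d(a,n) + m > 0."""
    d_ap, d_an = np.asarray(d_ap, float), np.asarray(d_an, float)
    return float(np.mean(d_ap - d_an + m > 0))

rng = np.random.default_rng(0)
d_ap = rng.normal(0.5, 0.1, 10_000)  # hypothetical anchor-positive distances
d_an = rng.normal(1.0, 0.2, 10_000)  # hypothetical random-negative distances
for m in (0.05, 0.2, 1.0, 2.0):
    print(f"m={m}: {active_fraction(d_ap, d_an, m):.2f} of triplets give gradient")
```

A tiny margin leaves only a small fraction of triplets active; a very large one makes nearly all of them active.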

Prototype Classifiers In Practice

For a frozen encoder f and a labeled support set:

prototype_c = mean over {f(x) : class(x) == c}
prediction = argmax_c  cos( f(query), prototype_c )

Normalize embeddings to unit length before averaging (otherwise magnitude imbalance skews the prototype).
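The two formulas above, with the normalization step, as a minimal NumPy sketch. It assumes the frozen encoder has already been applied, so the inputs are embedding arrays; the helper names are hypothetical.

```python
import numpy as np

def l2_normalize(x, axis=-1, eps=1e-12):
    """Scale vectors to unit length (eps guards against zero vectors)."""
    x = np.asarray(x, float)
    return x / np.maximum(np.linalg.norm(x, axis=axis, keepdims=True), eps)

def build_prototypes(support_emb, support_labels):
    """Per-class mean of unit-normalized support embeddings."""
    z = l2_normalize(support_emb)          # normalize BEFORE averaging
    labels = np.asarray(support_labels)
    classes = np.unique(labels)
    protos = np.stack([z[labels == c].mean(axis=0) for c in classes])
    return classes, l2_normalize(protos)

def predict(query_emb, classes, protos):
    """argmax_c cos(f(query), prototype_c)."""
    sims = l2_normalize(query_emb) @ protos.T
    return classes[np.argmax(sims, axis=-1)]
```

Re-normalizing the prototypes does not change the cosine argmax; it only makes the dot product equal to the cosine.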

This is the "extend a frozen classifier to new classes without training" move — and it composes with the base classifier: you can interpolate between the frozen classifier's logits and the prototype-distance scores, tuning the mix on a val set.
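One possible way to do the interpolation, sketched under stated assumptions: the classifier logits and prototype scores cover the same class columns, the classifier side is turned into probabilities with a softmax, and the mixing weight alpha is picked by grid search on validation accuracy. The names and the softmax choice are illustrative, not a prescribed recipe.

```python
import numpy as np

def softmax(x):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    x = np.asarray(x, float)
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def blend_scores(logits, proto_sims, alpha):
    """alpha * classifier probabilities + (1 - alpha) * prototype cosine scores."""
    return alpha * softmax(logits) + (1.0 - alpha) * np.asarray(proto_sims, float)

def tune_alpha(val_logits, val_proto_sims, val_labels, grid=np.linspace(0.0, 1.0, 11)):
    """Return the (alpha, accuracy) pair with the best validation accuracy."""
    accs = [np.mean(blend_scores(val_logits, val_proto_sims, a).argmax(axis=1) == val_labels)
            for a in grid]
    best = int(np.argmax(accs))
    return float(grid[best]), float(accs[best])
```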

Re-Identification And Retrieval

Once you have a trained encoder, retrieval is:

  1. embed the entire gallery offline
  2. at query time, embed the query and compute cosine similarity to every gallery row
  3. return top-k
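The three steps as brute-force NumPy code, a sketch assuming the gallery fits in memory; the function names are hypothetical. For large galleries an approximate nearest-neighbour index usually replaces the full matrix product.

```python
import numpy as np

def normalize_rows(x):
    """Unit-normalize each row so dot products are cosine similarities."""
    x = np.asarray(x, float)
    return x / np.linalg.norm(x, axis=1, keepdims=True)

def build_index(gallery_emb):
    """Step 1: embed the gallery offline and store unit-normalized rows."""
    return normalize_rows(gallery_emb)

def search(index, query_emb, k=5):
    """Steps 2-3: cosine similarity of each query to every gallery row, then top-k."""
    sims = normalize_rows(query_emb) @ index.T     # (Q, G) cosine matrix
    top = np.argsort(-sims, axis=1)[:, :k]         # gallery indices, best first
    return top, np.take_along_axis(sims, top, axis=1)
```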

The metrics that matter:

  • Recall@k — fraction of queries whose correct match is in the top-k retrieved
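  • mAP (mean average precision) — for each query, average the precision at the rank of every correct match, then average over queries; unlike Recall@k, it rewards ranking all correct matches high, not just the first one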
  • CMC curve — cumulative match characteristic, Recall@k as a function of k
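A sketch of the three metrics, assuming `ranked_labels[q]` holds the gallery labels for query q sorted by similarity over the whole gallery. Function names are hypothetical; queries with no match in the gallery get AP = 0 here, whereas re-ID benchmarks often skip them.

```python
import numpy as np

def recall_at_k(ranked_labels, query_labels, k):
    """Fraction of queries with at least one correct match in the top-k."""
    hits = (ranked_labels[:, :k] == query_labels[:, None]).any(axis=1)
    return float(hits.mean())

def average_precision(is_match):
    """AP for one query from a boolean array over the full ranked gallery."""
    is_match = np.asarray(is_match, bool)
    if not is_match.any():
        return 0.0
    ranks = np.flatnonzero(is_match) + 1                      # 1-based ranks of hits
    precision_at_hits = np.arange(1, len(ranks) + 1) / ranks  # precision at each hit
    return float(precision_at_hits.mean())

def evaluate(ranked_labels, query_labels, max_k=10):
    """Return (CMC curve as Recall@1..max_k, mAP)."""
    ranked_labels = np.asarray(ranked_labels)
    query_labels = np.asarray(query_labels)
    cmc = [recall_at_k(ranked_labels, query_labels, k) for k in range(1, max_k + 1)]
    matches = ranked_labels == query_labels[:, None]
    m_ap = float(np.mean([average_precision(row) for row in matches]))
    return cmc, m_ap
```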

What To Inspect

  • embedding norm distribution — if norms vary wildly, cosine-vs-Euclidean distinction matters; typically normalize to unit sphere
  • positive-pair vs negative-pair distance histograms — they should separate as training progresses; if they overlap heavily, your triplets are not informative
  • hard-negative examples — the ones with smallest d(a, n). Spend time looking at a dozen of these. They will tell you whether the encoder is confused by a meaningful feature or a spurious one.
  • per-identity recall — any identity with recall zero means the encoder cannot distinguish it at all; those rows often reveal labeling errors or near-duplicate identities
  • dimension usage (via PCA on embeddings) — if 90% of variance is in 4 of 128 dimensions, you are under-using capacity
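Two of these checks in NumPy: the spread of embedding norms, and how many principal components explain 90% of the variance (via the SVD of the centred embeddings). A minimal sketch with hypothetical helper names.

```python
import numpy as np

def norm_stats(emb):
    """Min / median / max of embedding L2 norms; a wide spread suggests normalizing."""
    n = np.linalg.norm(np.asarray(emb, float), axis=1)
    return {"min": float(n.min()), "median": float(np.median(n)), "max": float(n.max())}

def dims_for_variance(emb, frac=0.9):
    """Number of principal components needed to explain `frac` of the variance."""
    x = np.asarray(emb, float)
    x = x - x.mean(axis=0)                      # PCA needs centred data
    s = np.linalg.svd(x, compute_uv=False)      # singular values, descending
    var = s ** 2 / np.sum(s ** 2)               # explained-variance ratios
    k = int(np.searchsorted(np.cumsum(var), frac)) + 1
    return min(k, len(s))
```

A 128-dimensional embedding for which `dims_for_variance` returns 4 is the under-used case described above.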

Failure Pattern

  • embedding collapse. The model learns f(x) = constant: every distance is zero, every triplet's loss equals the margin m, and the gradient through the distance terms vanishes (squared distance) or is undefined (plain Euclidean distance at zero), so training stalls. Usually a symptom of too-hard triplets early or missing batch normalization.
  • shortcut features. The encoder latches onto a trivial feature (background color, pose) that correlates with identity in the training set but does not generalize. Inspect hard negatives to catch this.
  • sampling bias. You sampled P identities but some identities have 1 image and others have 100. The common ones dominate gradients. Fix with balanced P×K sampling.
  • margin tuning. Starting with too large a margin produces noise; too small a margin leaves easy triplets. Start with m = 0.2 for cosine, m = 1.0 for Euclidean and sweep.
  • evaluation on training identities. Metric learning is about generalization to unseen identities. Evaluate on held-out identities, not held-out images of training identities.
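Two fixes from this list as a stdlib-only sketch: a split that holds out whole identities, and a balanced P×K batch sampler. The helper names are hypothetical, and sampling with replacement for identities with fewer than K images is one common choice, not the only one (dropping single-image identities is another, since they offer no real positive).

```python
import random
from collections import defaultdict

def split_by_identity(labels, val_frac=0.2, seed=0):
    """Hold out whole identities (not images) so validation measures generalization."""
    ids = sorted(set(labels))
    rng = random.Random(seed)
    rng.shuffle(ids)
    val_ids = set(ids[:max(1, int(len(ids) * val_frac))])
    train_idx = [i for i, y in enumerate(labels) if y not in val_ids]
    val_idx = [i for i, y in enumerate(labels) if y in val_ids]
    return train_idx, val_idx

def pk_batches(labels, P, K, n_batches, seed=0):
    """Yield index batches of P identities x K images each."""
    rng = random.Random(seed)
    by_id = defaultdict(list)
    for i, y in enumerate(labels):
        by_id[y].append(i)
    ids = list(by_id)
    for _ in range(n_batches):
        batch = []
        for y in rng.sample(ids, P):              # P distinct identities
            pool = by_id[y]
            # K distinct images if available, otherwise sample with replacement
            picks = rng.sample(pool, K) if len(pool) >= K else rng.choices(pool, k=K)
            batch.extend(picks)
        yield batch
```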

Quick Checks

  • are positive and negative pair distance distributions separating? (plot them every epoch)
  • are embeddings L2-normalized before the loss? (they should be for cosine-based variants)
  • is the gallery embedded with the same preprocessing as the query?
  • for prototype classifiers: is the support set per class large enough for a stable mean? Below ~5 per class, prototypes are noisy.
  • does your val set use unseen identities?
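A sketch for the first check, assuming cosine distance on L2-normalized embeddings: split all pair distances into positive and negative sets, then log their medians and the fraction of (positive, negative) pairs ordered correctly (1.0 means the histograms do not overlap). Helper names are hypothetical; comparing every positive against every negative grows quickly, so subsample on large validation sets.

```python
import numpy as np

def pair_distance_split(emb, labels):
    """Cosine distances for all positive and negative pairs (each pair counted once)."""
    z = np.asarray(emb, float)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    d = 1.0 - z @ z.T
    labels = np.asarray(labels)
    iu = np.triu_indices(len(labels), k=1)             # upper triangle, no diagonal
    same = (labels[:, None] == labels[None, :])[iu]
    return d[iu][same], d[iu][~same]

def separation_report(pos, neg):
    """Summaries worth logging every epoch next to the two histograms."""
    return {
        "pos_median": float(np.median(pos)),
        "neg_median": float(np.median(neg)),
        # fraction of (pos, neg) pairs where the positive is closer
        "pos_below_neg": float(np.mean(pos[:, None] < neg[None, :])),
    }
```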

Practice

Run academy/labs/metric-learning-retrieval/src/metric_learning_workflow.py:

  • synthetic shape dataset with per-identity variants (rotation, scale, noise)
  • trains a small CNN with three heads: classification baseline, triplet loss with random sampling, triplet loss with batch-hard mining
  • evaluates Recall@1 / @5 on held-out identities not seen at training
  • adds a prototype-classifier experiment: freeze the classification-baseline encoder, compute per-class prototypes, measure accuracy on a held-out class not in training

You should leave the lab able to explain why batch-hard beats random triplets, what a good positive-vs-negative distance histogram looks like, and why prototype classifiers work on classes the model has never been trained to classify.

Longer Connection

Metric learning is the "classifier that works on new classes" building block. It shows up under different names across subfields:

  • face verification (FaceNet, ArcFace) — identity re-identification in vision
  • person re-ID / vehicle re-ID — same person or vehicle across cameras
  • dense retrieval for text (DPR, ColBERT) — query/passage encoders trained with contrastive or triplet losses
  • multimodal retrieval (CLIP) — symmetric InfoNCE is a metric-learning objective over image-text pairs
  • few-shot classification (Prototypical Networks) — prototype classifier on top of a meta-learned encoder
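The CLIP bullet in code: a minimal NumPy sketch of a symmetric InfoNCE loss over a batch of matched image/text embedding pairs, where row i of each matrix is a matched pair and every other row is a negative. It assumes a fixed temperature of 0.07 (CLIP learns this value and starts it at 0.07) and computes only the loss value; the function names are hypothetical.

```python
import numpy as np

def log_softmax(x, axis):
    """Numerically stable log-softmax along one axis."""
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def symmetric_info_nce(img_emb, txt_emb, temperature=0.07):
    """Average of image->text and text->image cross-entropy over in-batch pairs."""
    i = img_emb / np.linalg.norm(img_emb, axis=1, keepdims=True)
    t = txt_emb / np.linalg.norm(txt_emb, axis=1, keepdims=True)
    logits = i @ t.T / temperature                              # (N, N) scaled cosines
    diag = np.arange(len(logits))                               # matched pairs on diagonal
    loss_i2t = -log_softmax(logits, axis=1)[diag, diag].mean()  # each image picks its text
    loss_t2i = -log_softmax(logits, axis=0)[diag, diag].mean()  # each text picks its image
    return float((loss_i2t + loss_t2i) / 2)
```

When all embeddings are identical the loss is exactly log N, the chance level for a batch of N pairs.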

The unifying idea is: learn a space in which distance means similarity, then use distance for everything else. This is also why you pair this topic with Self-Supervised and Representation Learning — SSL and metric learning produce the same output (an encoder) for slightly different reasons.