Systems Atlas
Chapter 7.3: Training Embedding Models

Contrastive Learning & Losses

How contrastive objectives train models to shape vector space geometry. Cosine similarity, bi-encoders, InfoNCE/MNR loss, the temperature trick, batch size as free negatives, and why it all eventually collapses if you ignore the five failure modes.

Architecture

Bi-Encoders

Dominant pattern in production: one shared encoder for queries and documents, scored by cosine similarity.

In-Batch Negatives

1,023

Free negatives per query in a batch of 1,024. Larger batches provide dramatically more training signal at no extra data cost.

Temperature

0.05

Typical production default. Lower values sharpen the learning signal; higher values smooth the difficulty. Often the single most impactful hyperparameter.

Key Insight: Contrastive learning does not define what "relevant" means. It teaches the model a geometry — things that should be close are pulled together, things that should be far are pushed apart. Everything depends on the quality of your positive pairs and the difficulty of your negatives.

1. The Geometric Objective

The central idea of contrastive learning is deceptively simple: encode similar things so that their vectors point in similar directions, and encode dissimilar things so their vectors point in different directions. Over many gradient steps, the encoder learns a map from raw text to a high-dimensional manifold where semantic similarity becomes spatial proximity.

This is not a classification task. There are no output classes. The model is never told what category a document belongs to. Instead, it is given pairs or triplets and trained to arrange them correctly in a continuous vector space. The similarity measure that matters in production is cosine similarity.

Why Contrastive Learning Over Other Approaches?

Classification head (cross-entropy over fixed classes)

Works for fixed taxonomies. Cannot generalize to queries and documents it has never seen labeled. Fails for open-vocabulary search.

Contrastive objective

Learns from structural relationships between pairs. Generalizes to unseen query-document pairs. Scales with data volume and batch size. Produces vectors that plug directly into nearest-neighbor retrieval.


2. Cosine Similarity: The Math That Drives the Geometry

Cosine similarity is the cosine of the angle between two vectors. It is computed as the dot product of the vectors divided by the product of their L2 norms, which cancels out magnitude, so it captures purely directional (semantic) similarity.

Dot product

a · b = Σᵢ aᵢ × bᵢ

L2 Norm

‖a‖ = √(Σᵢ aᵢ²)

Cosine Similarity

cos(a, b) = (a · b) / (‖a‖ × ‖b‖)

The Production Identity — when both vectors are L2-normalized:

If ‖a‖ = ‖b‖ = 1, then cos(a, b) = a · b

This means dot product and cosine similarity are identical for normalized vectors. Production systems typically normalize, so a fast MIPS (maximum inner product search) index can serve cosine-similarity nearest-neighbor retrieval.
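The formulas above translate directly into code. A dependency-free sketch (plain Python lists stand in for embedding vectors; the helper names are illustrative) that computes cosine similarity from its definition and checks the production identity:

```python
import math

def dot(a, b):
    # a · b = sum_i a_i * b_i
    return sum(x * y for x, y in zip(a, b))

def l2_norm(a):
    # ||a|| = sqrt(sum_i a_i^2)
    return math.sqrt(sum(x * x for x in a))

def cosine(a, b):
    # cos(a, b) = (a · b) / (||a|| * ||b||)
    return dot(a, b) / (l2_norm(a) * l2_norm(b))

def normalize(a):
    # Scale a to unit L2 norm so that ||a|| = 1
    n = l2_norm(a)
    return [x / n for x in a]

a = [3.0, 4.0]
b = [4.0, 3.0]
print(cosine(a, b))                     # 0.96
print(dot(normalize(a), normalize(b)))  # same value, up to float rounding
```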

The practical implication: normalize embedding vectors before any similarity computation or index insertion. This is not an optional optimization. It keeps every score on the [−1, 1] scale, which makes interpolation, thresholding, and score comparison across queries meaningful. Without normalization, a raw dot product rewards magnitude: a long vector pointing in a worse direction can outscore a shorter, better-aligned one and appear falsely similar.
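To see why skipping normalization misleads a dot-product index, here is a toy example with made-up 2-D vectors: a long vector pointing roughly 53 degrees away from the query outscores a short, nearly aligned one under raw dot product, but not after L2 normalization.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def normalize(a):
    n = math.sqrt(dot(a, a))
    return [x / n for x in a]

query = [1.0, 0.0]
docs = {
    "aligned_short": [0.9, 0.1],    # nearly the same direction as the query
    "misaligned_long": [6.0, 8.0],  # norm 10, about 53 degrees away
}

# Raw dot product rewards magnitude: the long vector wins
raw = {name: dot(query, v) for name, v in docs.items()}
# After L2 normalization, dot product equals cosine: direction wins
cos = {name: dot(normalize(query), normalize(v)) for name, v in docs.items()}

print(max(raw, key=raw.get))  # misaligned_long
print(max(cos, key=cos.get))  # aligned_short
```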


3. Bi-Encoders vs Cross-Encoders

The two dominant architectures for text matching are bi-encoders and cross-encoders. They differ in how they produce a relevance score — and that difference determines where each can be used in a real search system.

Bi-Encoder

bi_encoder_score.py
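```python
import torch.nn.functional as F

def score(encoder, query, doc):
    # encoder: assumed to map a batch of inputs to pooled [B, D] embeddings
    q = F.normalize(encoder(query), dim=-1)
    d = F.normalize(encoder(doc), dim=-1)
    # Unit vectors, so the row-wise dot product is cosine similarity
    return (q * d).sum(dim=-1)
```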
  • Can pre-compute doc embeddings
  • Scales to millions of docs
  • No token-level cross-attention

Cross-Encoder

cross_encoder_score.py
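```python
import torch

def score(tokenizer, reranker, query, doc):
    # Both texts tokenized together as a pair: [CLS] query [SEP] doc [SEP]
    inp = tokenizer(query, doc, truncation=True, return_tensors="pt")
    with torch.no_grad():
        # Assumes a Hugging Face-style sequence-classification model with one relevance logit
        return reranker(**inp).logits.squeeze(-1)
```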
  • Higher accuracy on close pairs
  • Full cross-attention interaction
  • Cannot pre-compute; too slow to score a full corpus
The Standard Architecture: Bi-encoder at retrieval (ANN index), cross-encoder as a reranker on the top-100 candidates. The bi-encoder provides recall, the cross-encoder provides precision. Training with contrastive objectives applies to the bi-encoder stage.
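A minimal sketch of the two-stage flow, under loud assumptions: `bi_score` and `cross_score` are hypothetical stand-ins (a dot product over toy unit vectors, and a token-overlap score playing the reranker), not real models. The point is the control flow: cheap scoring over every document, expensive scoring only over the top k.

```python
def bi_score(q_vec, d_vec):
    # Stand-in for the bi-encoder: dot product of precomputed unit vectors
    return sum(x * y for x, y in zip(q_vec, d_vec))


def cross_score(query, doc_text):
    # Hypothetical stand-in for a cross-encoder: share of query tokens found in the doc
    q_tokens = set(query.lower().split())
    d_tokens = set(doc_text.lower().split())
    return len(q_tokens & d_tokens) / len(q_tokens)


def retrieve_then_rerank(query, q_vec, corpus, k=100):
    # corpus: list of (doc_id, text, precomputed_unit_vector) tuples
    # Stage 1 (recall): cheap bi-encoder scores, keep the top k
    # (a real system queries an ANN index instead of sorting the whole corpus)
    candidates = sorted(corpus, key=lambda d: bi_score(q_vec, d[2]), reverse=True)[:k]
    # Stage 2 (precision): expensive scoring only on the k candidates
    reranked = sorted(candidates, key=lambda d: cross_score(query, d[1]), reverse=True)
    return [doc_id for doc_id, _, _ in reranked]


corpus = [
    ("d1", "reset your router password", [0.6, 0.8]),
    ("d2", "router setup guide", [0.8, 0.6]),
    ("d3", "banana bread recipe", [0.0, 1.0]),
]
print(retrieve_then_rerank("router password reset", [1.0, 0.0], corpus, k=2))  # ['d1', 'd2']
```

The bi-encoder alone ranks d2 first; the reranker promotes d1, which matches all three query terms.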

4. InfoNCE / Multiple Negatives Ranking Loss

Multiple Negatives Ranking (MNR) loss, also known as InfoNCE in the contrastive learning literature, is the dominant training objective for embedding models. Given a batch of (query, positive) pairs, each query treats every other query's positive in the batch as a negative. The loss therefore gets in-batch negatives with no extra data.
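Written out, with L2-normalized query and positive embeddings q_i and d_i, batch size B, and temperature τ, the loss is the average cross-entropy of selecting the diagonal entry in each row of the B×B score matrix:

```latex
s_{ij} = \frac{q_i \cdot d_j}{\tau}, \qquad
\mathcal{L} = -\frac{1}{B} \sum_{i=1}^{B} \log \frac{\exp(s_{ii})}{\sum_{j=1}^{B} \exp(s_{ij})}
```

Each row is a B-way classification whose B − 1 wrong answers are the other queries' positives.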

mnr_loss.py
```python
import torch
import torch.nn.functional as F

def mnr_loss(q_embs, d_embs, temperature=0.05):
    # Normalize so dot product = cosine similarity
    q = F.normalize(q_embs, dim=-1)
    d = F.normalize(d_embs, dim=-1)
    # BxB similarity matrix, diagonal = positives
    scores = torch.matmul(q, d.T) / temperature
    # Targets: each row, the correct column is the diagonal
    labels = torch.arange(len(q), device=q.device)
    # Cross-entropy treats off-diagonal entries as negatives
    return F.cross_entropy(scores, labels)
```

Temperature Control

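The temperature scalar τ controls how sharply the loss discriminates between positives and negatives. Dividing the similarities by a lower temperature stretches the gaps between them before the softmax, so even small differences in cosine similarity become large differences in probability and the model must work harder to rank positives clearly above negatives. Higher temperature flattens the softmax and softens the signal.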

Temperature | Effect
--- | ---
0.02 | Very sharp: strong learning signal, high collapse risk
0.05 | Standard default: good balance for most tasks
0.1 | Softer: gentler training, weaker fine distinctions
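A small numeric illustration of the table, using made-up cosine scores for one query (positive at 0.80, a hard negative at 0.75, an easy negative at 0.30): the softmax probability of the positive, which is what the cross-entropy in `mnr_loss` rewards, changes sharply with τ.

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability before exponentiating
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

cosines = [0.80, 0.75, 0.30]  # index 0 is the positive (illustrative values)

for tau in (0.02, 0.05, 0.1, 1.0):
    probs = softmax([c / tau for c in cosines])
    loss = -math.log(probs[0])  # this query's InfoNCE term
    print(f"tau={tau}: p(positive)={probs[0]:.3f}, loss={loss:.3f}")
```

With these scores, p(positive) rises as τ falls: at τ = 1 the 0.05 cosine gap to the hard negative barely registers, while at τ = 0.02 it dominates the softmax.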

Batch Size as Free Negatives

The number of in-batch negatives grows linearly with batch size: a batch of B gives B − 1 negatives per query. A batch of 32 gives 31 negatives per query; a batch of 1024 gives 1023. This is a key reason embedding models are trained with large batches: the extra negatives cost no additional labeled data, only memory and compute.

Batch size | Free negatives per query | Memory
--- | --- | ---
32 | 31 | Low
256 | 255 | Medium
1024 | 1023 | High (multi-GPU)
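The arithmetic behind the table, as a tiny helper (the function name is illustrative): each query gets B − 1 negatives, so a batch supplies B(B − 1) negative pairs in total, all drawn from the B² entries of the `scores` matrix in `mnr_loss`.

```python
def in_batch_negatives(batch_size):
    # Each query's negatives are the other B - 1 positives in the batch
    per_query = batch_size - 1
    # Total (query, negative) pairs scored in the B x B similarity matrix
    total_pairs = batch_size * per_query
    return per_query, total_pairs

for b in (32, 256, 1024):
    print(b, in_batch_negatives(b))
# 32 (31, 992)
# 256 (255, 65280)
# 1024 (1023, 1047552)
```

The negative pairs grow quadratically while the labeled data stays the same; in practice most of the memory cost typically comes from the encoder activations for the 2B sequences rather than from the score matrix itself.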

5. Pooling: Getting a Vector from Token Outputs

A transformer encoder produces one vector per token. For a sentence embedding, you need a single vector. Pooling is how you aggregate token representations into a sentence-level embedding.

Method | Description | Best for
--- | --- | ---
CLS token | Use the embedding of the classification token [CLS] | Models pre-trained with a classification objective (BERT)
Mean pooling | Average all non-padding token embeddings | Best default for sentence similarity and retrieval models; usually beats CLS
Max pooling | Take the maximum value across tokens for each dimension | Captures salient features; rarely best for retrieval

Mean pooling generally outperforms CLS pooling for retrieval tasks, except for models specifically fine-tuned with CLS-based contrastive objectives. Modern embedding models differ here: E5, for example, uses mean pooling together with "query: " and "passage: " input prefixes, while BGE uses the [CLS] vector. Use mean pooling as the default unless the model card specifies otherwise.
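Mean pooling is a masked average: padding positions must be excluded, or their outputs leak into the sentence vector. A dependency-free sketch for a single sequence (token vectors as nested lists and an attention mask of 1s and 0s, like the one a Hugging Face-style tokenizer produces; `masked_mean` is an illustrative name):

```python
def masked_mean(token_embs, attention_mask):
    # token_embs: T vectors of length D; attention_mask: T ints (1 = real token, 0 = padding)
    dim = len(token_embs[0])
    summed = [0.0] * dim
    count = 0
    for vec, keep in zip(token_embs, attention_mask):
        if keep:
            summed = [s + v for s, v in zip(summed, vec)]
            count += 1
    # Guard against an all-padding sequence
    count = max(count, 1)
    return [s / count for s in summed]

tokens = [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]  # last position is padding
print(masked_mean(tokens, [1, 1, 0]))  # [2.0, 3.0]
```

Including the padding position would instead give [13/3, 5.0], a different vector for the same text.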


6. Beyond Vanilla InfoNCE: Extended Loss Functions

MNR/InfoNCE is the starting point, not the end. Several extensions address specific problems encountered in production training.

Triplet Loss

Optimizes: score(q, pos) - score(q, neg) > margin. Useful when you have explicit triplets and a margin to enforce. Less sample-efficient than MNR with in-batch negatives.
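As a hinge loss, the constraint above becomes max(0, margin − (score(q, pos) − score(q, neg))): zero once the positive beats the negative by at least the margin, and growing linearly otherwise. A minimal sketch on precomputed scores (the default margin of 0.2 is illustrative):

```python
def triplet_loss(pos_score, neg_score, margin=0.2):
    # Penalize only when the positive does not beat the negative by at least `margin`
    return max(0.0, margin - (pos_score - neg_score))

print(triplet_loss(0.9, 0.5))  # constraint satisfied: 0.0
print(triplet_loss(0.6, 0.5))  # violated by about 0.1
```

Each triplet contributes one negative, whereas MNR compares each query against all B − 1 other in-batch positives, which is the sample-efficiency gap mentioned above.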

MarginMSE

Trains the student so that its margin, score(q, pos) − score(q, neg), matches the margin a teacher reranker assigns to the same triple, using mean squared error. Excellent when you have soft labels from a more powerful model and want to distill ranking preferences rather than just binary relevance flags.
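A sketch of the MarginMSE objective on a small batch of (query, positive, negative) scores, assuming the teacher scores were precomputed offline: the student is penalized by the squared difference between its margin and the teacher's.

```python
def margin_mse(student_pos, student_neg, teacher_pos, teacher_neg):
    # Each argument is a list of scores, one per (query, pos, neg) triple
    errors = []
    for sp, sn, tp, tn in zip(student_pos, student_neg, teacher_pos, teacher_neg):
        student_margin = sp - sn
        teacher_margin = tp - tn
        errors.append((student_margin - teacher_margin) ** 2)
    return sum(errors) / len(errors)

# Student margins (0.3, 0.1) vs teacher margins (0.5, 0.1): only the first triple is penalized
print(margin_mse([0.8, 0.6], [0.5, 0.5], [7.0, 2.0], [6.5, 1.9]))  # about 0.02
```

Because only differences are matched, the teacher's raw scores can carry an arbitrary offset (here, logit-like values around 7); what must agree is the size of the gap.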

KL-Divergence Distillation

Minimize KL divergence between student score distribution and teacher score distribution over candidates. Transfers the full softmax distribution, not just the margin. Useful in knowledge distillation setups.
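A minimal sketch of the KL objective for one query, assuming both models have already scored the same candidate list: each score list becomes a softmax distribution, and the loss is KL(teacher ‖ student). The temperatures are illustrative; the small student temperature sharpens cosine scores so their spread is comparable to the teacher's logits.

```python
import math

def softmax(scores, temperature=1.0):
    scaled = [s / temperature for s in scores]
    m = max(scaled)  # stabilize exp
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_distill_loss(student_scores, teacher_scores, student_temp=0.05, teacher_temp=1.0):
    # KL(P_teacher || P_student) = sum_k p_k * (log p_k - log q_k)
    p = softmax(teacher_scores, teacher_temp)
    q = softmax(student_scores, student_temp)
    return sum(pk * (math.log(pk) - math.log(qk)) for pk, qk in zip(p, q) if pk > 0)

teacher = [8.0, 6.0, 1.0]     # reranker logits for 3 candidates (illustrative)
student = [0.85, 0.80, 0.20]  # student cosine scores for the same candidates
print(kl_distill_loss(student, teacher))
```

Unlike MarginMSE, which matches one margin per triple, this matches the relative weight the teacher places on every candidate in the list.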

Multi-Stage Training

Start with MNR on large weakly labeled pairs → add hard negatives → fine-tune with MarginMSE from a teacher. Staged curricula of this kind are common among the strongest public embedding models.


7. Five Failure Modes to Diagnose Early

Contrastive training can go wrong in a small set of well-understood ways. Knowing these patterns before training begins means you can read metrics and identify what happened in hours rather than days.

Representation Collapse

All embeddings converge to (nearly) the same point in vector space. The loss stops improving: instead of falling, it stalls near log(B), the value produced when every score in a row of the B×B matrix is identical, because the model encodes all inputs the same way. Diagnostic: the standard deviation of each embedding dimension across a batch collapses to nearly zero. Fix: increase temperature, reduce learning rate, check normalization.
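The standard-deviation diagnostic is cheap to compute on a held-out batch. A dependency-free sketch (function names are illustrative) that also returns log(B), a useful reference value: when all embeddings are identical, every row of the score matrix is constant, the softmax is uniform, and the in-batch cross-entropy equals exactly log(B).

```python
import math

def per_dim_std(embeddings):
    # embeddings: B vectors of length D; returns the std of each dimension across the batch
    b = len(embeddings)
    stds = []
    for values in zip(*embeddings):
        mean = sum(values) / b
        stds.append(math.sqrt(sum((v - mean) ** 2 for v in values) / b))
    return stds

def collapsed_loss_reference(batch_size):
    # Uniform softmax over B candidates gives cross-entropy log(B)
    return math.log(batch_size)

healthy = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
collapsed = [[1.0, 0.0]] * 3
print(max(per_dim_std(collapsed)))    # 0.0
print(min(per_dim_std(healthy)) > 0)  # True
```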

False Negative Saturation

In-batch negatives often include documents that are actually relevant to the query. With large batches in dense domains, the rate of these false negatives can dramatically exceed 5%. Diagnostic: nDCG@10 stalls or regresses as batch size increases. Fix: deduplicate the batch, track same-entity pairs, or mask in-batch negatives that duplicate a positive.
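One mitigation, sketched under the assumption that each positive carries a document or entity ID (`doc_ids` below is hypothetical metadata, not something `mnr_loss` receives): before the softmax, mask every off-diagonal entry whose document matches the row's positive, so a duplicate is never pushed away as a negative.

```python
NEG_INF = float("-inf")

def mask_false_negatives(scores, doc_ids):
    # scores: B x B logits (row i = query i, column j = positive of query j)
    # doc_ids: B identifiers for each row's positive document (assumed metadata)
    masked = [row[:] for row in scores]  # copy; leave the input untouched
    for i, row in enumerate(masked):
        for j in range(len(row)):
            # Off-diagonal column holding the same document as row i's positive
            if i != j and doc_ids[i] == doc_ids[j]:
                row[j] = NEG_INF  # exp(-inf) = 0, so it drops out of the softmax
    return masked

scores = [[5.0, 4.9, 1.0],
          [4.9, 5.0, 1.0],
          [1.0, 1.0, 5.0]]
print(mask_false_negatives(scores, ["a", "a", "b"]))
# [[5.0, -inf, 1.0], [-inf, 5.0, 1.0], [1.0, 1.0, 5.0]]
```

The diagonal is never masked, so each row keeps a finite positive. In the PyTorch `mnr_loss`, the same idea is a boolean mask applied with `scores.masked_fill(mask, float("-inf"))` before `F.cross_entropy`.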

Shortcut Learning

The model learns surface features (product ID overlap, date prefix, brand name) rather than semantic meaning. Diagnostic: Strong in-domain performance but failures on paraphrase benchmarks and out-of-distribution queries. Fix: use soft labels from a teacher model, remove exact-match examples from training, add paraphrase augmentation.

Overfitting to Query Patterns

Benchmark nDCG improves but live search metrics plateau or regress. This often happens with small datasets and too many epochs. Diagnostic: Offline metrics diverge from online metrics. Fix: use linear warmup + cosine decay, add weight decay, use fewer unique queries with more negatives per query rather than more query-positive pairs.
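Linear warmup followed by cosine decay is a few lines of arithmetic. A sketch that returns a multiplier in [0, 1] for a given step (the peak rate and step counts are illustrative); PyTorch users can pass such a function to `torch.optim.lr_scheduler.LambdaLR`, with the optimizer's learning rate set to the peak.

```python
import math

def lr_multiplier(step, warmup_steps, total_steps):
    # Linear warmup: ramp from 0 to 1 over the first warmup_steps
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    # Cosine decay: fall from 1 to 0 over the remaining steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

peak_lr = 2e-5
for step in (0, 500, 1000, 5500, 10000):
    print(step, peak_lr * lr_multiplier(step, warmup_steps=1000, total_steps=10000))
```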

Embedding Space Drift

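After fine-tuning, existing document vectors in the ANN index are stale: they were computed by the old model. Serving a new query encoder against an old document index produces incoherent scores. Diagnostic: retrieval quality drops sharply right after a model update and the distribution of query-document scores shifts; before normalization, the norms of new query vectors and old document vectors may also diverge. Fix: always run a full corpus re-embedding pass after model weight updates, before the new query encoder serves traffic.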
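A cheap safeguard, sketched with hypothetical names (`IndexMetadata`, `check_compatible`, and the version strings are illustrative): record which encoder version produced the index, and refuse to serve queries from a different version until re-embedding finishes.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class IndexMetadata:
    # Hypothetical metadata stored alongside the ANN index
    encoder_version: str
    num_docs: int

def check_compatible(query_encoder_version: str, index_meta: IndexMetadata) -> None:
    # Query and document vectors must come from the same weights to share a geometry
    if query_encoder_version != index_meta.encoder_version:
        raise RuntimeError(
            f"query encoder {query_encoder_version!r} does not match index built with "
            f"{index_meta.encoder_version!r}; re-embed the corpus before serving"
        )

meta = IndexMetadata(encoder_version="biencoder-v3", num_docs=1_000_000)
check_compatible("biencoder-v3", meta)  # passes silently
```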


8. Practical Training Recipe

Here is the minimal training loop structure that covers the key implementation choices: normalization before loss, proper device handling, and the hyperparameter surface to tune.

train_biencoder.py
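```python
def mean_pool(token_embs, attention_mask):
    # Average only real tokens: token_embs [B, T, D], attention_mask [B, T]
    mask = attention_mask.unsqueeze(-1).to(token_embs.dtype)
    return (token_embs * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

# Assumes: encoder (a Hugging Face-style model), train_loader yielding tokenized
# (queries, docs) batches, optimizer, scheduler, device, and mnr_loss from mnr_loss.py
encoder.train()
for batch in train_loader:
    queries, docs = batch
    queries, docs = queries.to(device), docs.to(device)
    # Forward pass through shared encoder: one vector per token
    q_tok = encoder(**queries).last_hidden_state  # [B, T, D]
    d_tok = encoder(**docs).last_hidden_state     # [B, T, D]
    # Mean-pool over non-padding tokens
    q_embs = mean_pool(q_tok, queries.attention_mask)  # [B, D]
    d_embs = mean_pool(d_tok, docs.attention_mask)     # [B, D]
    # Compute MNR loss (normalization happens inside mnr_loss)
    loss = mnr_loss(q_embs, d_embs, temperature=0.05)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
```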

Key Hyperparameters

Hyperparameter | Typical range | Notes
--- | --- | ---
Learning rate | 1e-5 to 2e-4 | Lower for full fine-tuning, higher for LoRA adapters
Temperature | 0.02 to 0.1 | 0.05 is the most common default starting point
Batch size | 64 to 8192 | Bigger = more in-batch negatives = better signal
Max sequence length | 128 to 512 tokens | 128 for queries, 256–512 for longer documents
Hard negatives | 1 to 20 per pair | Start with 1–3, increase as model quality improves
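One way to keep these knobs in one place is a small config object that starts from defaults inside the table's ranges and flags values outside them. A sketch; the field names, defaults, and warning messages are illustrative, not a library API:

```python
from dataclasses import dataclass

@dataclass
class BiEncoderTrainConfig:
    learning_rate: float = 2e-5       # full fine-tune end of 1e-5 to 2e-4
    temperature: float = 0.05         # common default within 0.02 to 0.1
    batch_size: int = 256             # within 64 to 8192
    max_query_len: int = 128          # tokens
    max_doc_len: int = 256            # 256 to 512 for longer documents
    hard_negatives_per_pair: int = 1  # start with 1 to 3

    def warnings(self):
        # Flag values outside the typical ranges from the table
        msgs = []
        if not 1e-5 <= self.learning_rate <= 2e-4:
            msgs.append("learning_rate outside 1e-5 to 2e-4")
        if not 0.02 <= self.temperature <= 0.1:
            msgs.append("temperature outside 0.02 to 0.1")
        if not 64 <= self.batch_size <= 8192:
            msgs.append("batch_size outside 64 to 8192")
        if not 1 <= self.hard_negatives_per_pair <= 20:
            msgs.append("hard_negatives_per_pair outside 1 to 20")
        return msgs

print(BiEncoderTrainConfig().warnings())                 # []
print(BiEncoderTrainConfig(temperature=0.5).warnings())  # ['temperature outside 0.02 to 0.1']
```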

Key Takeaways

01

Geometry is the training objective, not labels

Contrastive learning places matching pairs closer and non-matching pairs further apart in high-dimensional space. The model never sees explicit 'relevance' labels — it learns from the structural relationship between embeddings. Cosine similarity after L2 normalization is the standard distance metric.

02

Temperature controls how hard the problem is

Lower temperature (0.02–0.05) sharpens the distribution, forcing the model to correctly rank even very similar pairs — strong signal but risk of collapse. Higher values (0.07–0.1) soften it, easier training but the model may not learn fine distinctions. Temperature is often the single most impactful hyperparameter.

03

Batch size multiplies free negatives

With in-batch negatives, a batch of 1024 gives each query 1023 negatives at no extra data cost. Larger batch sizes are often more sample-efficient than additional epochs, but they require high GPU memory and careful false negative filtering.

04

Recognize the five failure modes before training

Representation collapse, false negative saturation, shortcut learning, overfitting, and embedding space drift are the most common failures. Each has a distinct diagnostic signature. Knowing them before training means you can identify what went wrong in hours rather than days.