Fine-Tune Embeddings for Your Domain

Domain adaptation through fine-tuning embedding models with sentence-transformers

Overview

Generic embedding models are trained on broad internet data and may not capture the nuances of your particular use case. By fine-tuning these models with domain-specific data, you can achieve 15-30% improvements in key retrieval metrics, making fine-tuning one of the highest-ROI optimizations for RAG systems.

The Golden Rule: It's probably a bad idea to train your own language model, but it's a very good idea to train your own embedding model.

Why This Matters

Fine-tuning embedding models is fundamentally different from fine-tuning LLMs:

  • Cost: ~$1.50 and 40 minutes on a laptop (vs. thousands of dollars for LLM fine-tuning)
  • Maintenance: No inference servers to maintain, no CUDA drivers to manage
  • Performance: With just 6,000 examples, you can outperform closed-source models on your specific tasks
  • Future-proof: When a better base embedding model is released, you can re-run your fine-tuning on top of it. When a provider releases a new LLM, your fine-tuned LLM becomes obsolete.

From Production: "Companies like OpenAI or Anthropic aren't primarily focused on making retrieval better—they're not launching new embedding models daily. If you spend effort fine-tuning an LLM, you need to consider whether your fine-tuned model will be competitive when the original provider releases a new version in a few months."

Why Fine-Tune?

  • Domain Specificity: Capture terminology and concepts specific to your use case
  • Improved Retrieval: Better semantic understanding leads to higher recall and MRR
  • Cost Efficiency: Open-source fine-tuning gives you control and reduces inference costs
  • Customization: Tailor the model to your exact requirements

Key Concepts

Hard Negatives

Training examples that are similar but incorrect, forcing the model to learn fine distinctions. For example, in a transaction classification system:

  • Query: "Monthly subscription payment"
  • Positive: "Subscription fees category"
  • Hard Negative: "Recurring payments category" (similar but different)
  • Easy Negative: "Grocery purchases category" (obviously different)

Semi-Hard Negatives

Moderately challenging negative examples that provide the most useful learning signal. These are negatives that are:

  • Further from the anchor than the positive (the ranking is already correct)
  • Still within the margin of the positive's distance to the anchor (the triplet loss is non-zero, so they still provide a learning signal)
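The three kinds of negative can be told apart from distances alone. The sketch below is a minimal illustration, assuming a distance where smaller means more similar (such as cosine distance) and the same margin used by the triplet loss; classify_negative is a hypothetical helper, and the thresholds follow the common triplet-mining convention rather than a rule specific to this guide.

```python
def classify_negative(d_anchor_pos: float, d_anchor_neg: float, margin: float = 0.5) -> str:
    """Classify a negative by its distance to the anchor relative to the positive.

    Assumes "smaller distance = more similar" (e.g. cosine distance).
    """
    if d_anchor_neg < d_anchor_pos:
        # Closer to the anchor than the positive: the ranking is currently wrong
        return "hard"
    if d_anchor_neg < d_anchor_pos + margin:
        # Correct order, but inside the margin: still produces a non-zero loss
        return "semi-hard"
    # At least `margin` farther than the positive: zero loss, no gradient
    return "easy"


print(classify_negative(0.2, 0.1))  # hard
print(classify_negative(0.2, 0.4))  # semi-hard
print(classify_negative(0.2, 0.9))  # easy
```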

Triplet Loss

Training objective that pulls the positive closer to the anchor while pushing the negative away, until the negative is farther from the anchor than the positive by at least the margin:

Loss = max(0, distance(anchor, positive) - distance(anchor, negative) + margin)
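A small worked example makes the formula concrete. This dependency-free sketch assumes cosine distance (1 - cosine similarity), which is what TripletDistanceMetric.COSINE computes in step 3, and the margin of 0.5 used there; the toy 2-D vectors are made up for illustration.

```python
import math
from typing import Sequence


def cosine_distance(u: Sequence[float], v: Sequence[float]) -> float:
    """1 - cosine similarity."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v)


def triplet_loss(anchor, positive, negative, margin: float = 0.5) -> float:
    """max(0, d(anchor, positive) - d(anchor, negative) + margin)"""
    d_pos = cosine_distance(anchor, positive)
    d_neg = cosine_distance(anchor, negative)
    return max(0.0, d_pos - d_neg + margin)


anchor = [1.0, 0.0]
positive = [1.0, 0.0]  # same direction as the anchor: distance 0
print(triplet_loss(anchor, positive, [0.0, 1.0]))  # negative at distance 1.0: loss 0.0
print(triplet_loss(anchor, positive, [1.0, 1.0]))  # negative at distance ~0.293: loss ~0.207
```

The second negative is still inside the margin, so it contributes a gradient; the first is already far enough away and contributes nothing.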

Implementation Guide

1. Generate Training Data

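import instructor
from openai import OpenAI
from pydantic import BaseModel, Field
from typing import List

class TrainingExample(BaseModel):
    query: str = Field(description="A realistic user query")
    positive_doc: str = Field(description="The document or category description that answers the query")
    category: str

class TrainingExamples(BaseModel):
    # Wrapper model so the LLM returns one JSON object containing a list of examples
    examples: List[TrainingExample]

# Patch the OpenAI client so responses are parsed and validated into Pydantic models
client = instructor.from_openai(OpenAI())

def generate_synthetic_data(categories: List[str], num_per_category: int = 50) -> List[TrainingExample]:
    training_data: List[TrainingExample] = []
    
    for category in categories:
        result = client.chat.completions.create(
            model="gpt-4",
            response_model=TrainingExamples,
            messages=[
                {"role": "system", "content": "Generate realistic, diverse training examples for a retrieval embedding model."},
                {"role": "user", "content": f"Generate {num_per_category} examples for category: {category}"}
            ]
        )
        # Force the category label to match the one that was requested
        for example in result.examples:
            example.category = category
        training_data.extend(result.examples)
    
    return training_data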

2. Create Hard Negatives

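from typing import List
from sentence_transformers import SentenceTransformer, util

def mine_hard_negatives(query: str, positive: str, candidates: List[str],
                        model: SentenceTransformer, k: int = 3) -> List[str]:
    """Return the k candidates most similar to the query, excluding the known positive"""
    
    # Remove the positive explicitly; it is not guaranteed to be the top-ranked candidate
    negatives_pool = [c for c in candidates if c != positive]
    if not negatives_pool:
        return []
    
    # Encode query and remaining candidates
    query_emb = model.encode(query, convert_to_tensor=True)
    candidate_embs = model.encode(negatives_pool, convert_to_tensor=True)
    
    # Cosine similarity between the query and every candidate
    similarities = util.cos_sim(query_emb, candidate_embs)[0]
    
    # Hard negatives are the most similar incorrect candidates.
    # Review them: near-duplicates of the positive may actually be relevant (false negatives).
    sorted_indices = similarities.argsort(descending=True)
    return [negatives_pool[i] for i in sorted_indices[:k].tolist()]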

3. Fine-Tune with Sentence-Transformers

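from sentence_transformers import SentenceTransformer, InputExample, losses
from sentence_transformers.evaluation import TripletEvaluator
from torch.utils.data import DataLoader

# Load base model
model = SentenceTransformer('BAAI/bge-base-en')

# Prepare training data: one (anchor, positive, negative) triplet per hard negative.
# Each item is assumed to carry query, positive_doc, and a hard_negatives list
# produced with mine_hard_negatives() in step 2.
train_examples = []
for item in training_data:
    for hard_neg in item.hard_negatives:
        train_examples.append(InputExample(
            texts=[item.query, item.positive_doc, hard_neg]
        ))

# Create DataLoader
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

# Define loss function (COSINE distance = 1 - cosine similarity)
train_loss = losses.TripletLoss(
    model=model,
    distance_metric=losses.TripletDistanceMetric.COSINE,
    triplet_margin=0.5
)

# Validation triplets; eval_data is a hypothetical held-out set shaped like training_data
eval_triplets = [(item.query, item.positive_doc, neg) for item in eval_data for neg in item.hard_negatives]
evaluator = TripletEvaluator(
    anchors=[a for a, _, _ in eval_triplets],
    positives=[p for _, p, _ in eval_triplets],
    negatives=[n for _, _, n in eval_triplets],
    name='eval'
)

# Warm up the learning rate over ~10% of all training steps
epochs = 3
warmup_steps = int(0.1 * len(train_dataloader) * epochs)

# Training configuration
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=evaluator,
    epochs=epochs,
    warmup_steps=warmup_steps,
    output_path='./fine-tuned-embeddings'
)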

4. Evaluate Performance

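import numpy as np

def evaluate_fine_tuned_model(base_model, finetuned_model, test_data, retrieve, k=5):
    """Compare base vs fine-tuned model on recall@k and MRR@k.

    test_data is a list of (query, expected_doc) pairs; retrieve(query, model, k)
    must return a ranked list of the top-k documents.
    """
    
    results = {
        'base': {'recall': [], 'mrr': []},
        'finetuned': {'recall': [], 'mrr': []}
    }
    
    for query, expected_doc in test_data:
        for name, model in (('base', base_model), ('finetuned', finetuned_model)):
            ranked = list(retrieve(query, model, k=k))
            hit = expected_doc in ranked
            results[name]['recall'].append(float(hit))
            # Reciprocal rank: 1 / (1-based position of the expected doc), 0 if not in the top k
            results[name]['mrr'].append(1.0 / (ranked.index(expected_doc) + 1) if hit else 0.0)
    
    base_recall = np.mean(results['base']['recall'])
    ft_recall = np.mean(results['finetuned']['recall'])
    
    # Calculate improvements (absolute differences, plus relative gain in recall)
    improvement = {
        'recall': ft_recall - base_recall,
        'mrr': np.mean(results['finetuned']['mrr']) - np.mean(results['base']['mrr']),
        # Relative gain is undefined if the base model never retrieves the expected doc
        'relative_gain': (ft_recall / base_recall - 1) * 100 if base_recall > 0 else float('nan')
    }
    
    return improvement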
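The evaluation function relies on a retrieve(query, model, k) helper that this guide does not define. One minimal way to provide it, assuming a small in-memory corpus of document texts, is brute-force cosine search; make_retriever is a hypothetical helper name, and for large corpora a vector index would replace the linear scan.

```python
import math
from typing import Callable, Dict, List, Sequence


def _cosine(u: Sequence[float], v: Sequence[float]) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def make_retriever(corpus: List[str]) -> Callable:
    """Build a brute-force retrieve(query, model, k) over a fixed corpus.

    `model` only needs an encode(list_of_texts) method returning one vector per text.
    Corpus embeddings are computed once per model and cached.
    """
    cache: Dict[object, list] = {}

    def retrieve(query: str, model, k: int = 5) -> List[str]:
        if model not in cache:
            cache[model] = list(model.encode(corpus))
        doc_embs = cache[model]
        query_emb = model.encode([query])[0]
        scores = [_cosine(query_emb, d) for d in doc_embs]
        # Rank documents by similarity, highest first, and return the top-k texts
        ranked = sorted(range(len(corpus)), key=lambda i: scores[i], reverse=True)
        return [corpus[i] for i in ranked[:k]]

    return retrieve
```

Any object with an encode method works as model, including a SentenceTransformer, so the same retriever can score both the base and the fine-tuned model.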

Training Best Practices

Data Quality

  • Manual Review: Use tools like Streamlit to manually review and filter synthetic data
  • Diversity: Ensure examples cover different writing styles and phrasings
  • Balance: Maintain balance across categories to prevent bias
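Balance can be checked automatically before any manual review. A minimal sketch, assuming each example carries a category label as in step 1 (for example, category_balance(ex.category for ex in training_data)); category_balance is a hypothetical helper and the 50% threshold is an arbitrary illustrative choice, not a recommendation from this guide.

```python
from collections import Counter
from typing import Dict, Iterable


def category_balance(categories: Iterable[str], min_ratio: float = 0.5) -> Dict[str, int]:
    """Return categories whose example count is below min_ratio of the largest category."""
    counts = Counter(categories)
    if not counts:
        return {}
    largest = max(counts.values())
    return {cat: n for cat, n in counts.items() if n < min_ratio * largest}


labels = ["fees"] * 100 + ["travel"] * 80 + ["groceries"] * 20
print(category_balance(labels))  # {'groceries': 20}
```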

Hyperparameter Tuning

training_config = {
    'epochs': 3,  # 3-5 epochs typically sufficient
    'batch_size': 16,  # Adjust based on GPU memory
    'warmup_steps': 100,  # ~10% of total steps for one dataset size; recompute for yours
    'learning_rate': 2e-5,  # Default for sentence-transformers
    'triplet_margin': 0.5,  # Distance margin for triplet loss
}
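A fixed warmup_steps value only equals 10% of training for one particular dataset size, so it is worth deriving it from the data. A sketch assuming one optimizer step per batch, which is how model.fit counts steps (len(train_dataloader) per epoch); the helper names are hypothetical. Note that with model.fit the learning rate is passed as optimizer_params={'lr': 2e-5} rather than as a standalone argument.

```python
import math


def total_training_steps(num_examples: int, batch_size: int, epochs: int) -> int:
    # One optimizer step per batch; a final partial batch still counts as a step
    return math.ceil(num_examples / batch_size) * epochs


def warmup_steps_for(num_examples: int, batch_size: int, epochs: int, warmup_pct: int = 10) -> int:
    # Integer arithmetic avoids floating-point surprises when taking a percentage
    return total_training_steps(num_examples, batch_size, epochs) * warmup_pct // 100


# e.g. 6,000 training triplets, batch size 16, 3 epochs
print(total_training_steps(6000, 16, 3))  # 1125
print(warmup_steps_for(6000, 16, 3))      # 112
```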

Negative Sampling Strategies

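  • Semi-Hard Negatives: Best learning signal; sentence-transformers can mine them online within each batch with losses.BatchSemiHardTripletLoss, which requires a class label (e.g., the category) for every example
  • Hard Negatives: More challenging, requires pre-computed similarity search
  • In-Batch Negatives: Efficient variant that uses the other positives in the batch as negatives (losses.MultipleNegativesRankingLoss)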
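In sentence-transformers the in-batch strategy corresponds to losses.MultipleNegativesRankingLoss. The dependency-free sketch below illustrates the idea only and is not the library's implementation: each anchor's own positive is the target class in a softmax over all positives in the batch, with cosine similarity multiplied by a scale of 20 (the library's default).

```python
import math
from typing import List, Sequence


def _cosine(u: Sequence[float], v: Sequence[float]) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def in_batch_negatives_loss(anchors: List[Sequence[float]],
                            positives: List[Sequence[float]],
                            scale: float = 20.0) -> float:
    """Mean cross-entropy where positives[i] is the correct match for anchors[i]
    and every other positives[j] in the batch serves as a negative."""
    total = 0.0
    for i, anchor in enumerate(anchors):
        # Scaled cosine similarity of this anchor to every positive in the batch
        logits = [scale * _cosine(anchor, p) for p in positives]
        # Numerically stable log-sum-exp; loss = -log softmax(logits)[i]
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_sum_exp - logits[i]
    return total / len(anchors)


# Each anchor matches its own positive and not the other one, so the loss is near 0
print(in_batch_negatives_loss([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]]))
```

Because negatives come for free from the rest of the batch, larger batch sizes give each query more negatives to learn from.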

Deployment

Save and Load Fine-Tuned Model

# Save model
model.save('./fine-tuned-embeddings')

# Load for inference
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('./fine-tuned-embeddings')

Upload to Hugging Face Hub

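from huggingface_hub import login

# Authenticate; prefer an environment variable or `huggingface-cli login` over a hard-coded token
login(token="your_hf_token")

# In sentence-transformers v3+, push_to_hub replaces the older save_to_hub method
model.push_to_hub(
    "your-org/fine-tuned-rag-embeddings",
    private=False,
    commit_message="Fine-tuned embedding model for domain X"
)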

Expected Outcomes

After fine-tuning, you should see:

  • 15-30% improvement in retrieval metrics (recall@k, MRR@k)
  • Better handling of domain-specific terminology
  • Improved semantic understanding of specialized concepts
  • Reduced mis-retrieval of similar but incorrect documents

Common Issues

Low Quality Synthetic Data

Solution: Implement a manual review process. Quality matters more than quantity: 100 high-quality examples beat 1,000 poor ones.

Overfitting on Small Datasets

Solution:

  • Use proper train/eval splits
  • Monitor validation metrics during training
  • Apply early stopping if validation performance degrades
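Because step 3 expands each query into one triplet per hard negative, splitting at the triplet level would put the same query in both the training and validation sets and make validation metrics look better than they are. A minimal sketch of a query-level split, assuming (query, positive, negative) tuples; split_by_query and the 20% eval fraction are illustrative choices.

```python
import random
from typing import List, Sequence, Tuple

Triplet = Tuple[str, str, str]


def split_by_query(triplets: Sequence[Triplet], eval_fraction: float = 0.2,
                   seed: int = 42) -> Tuple[List[Triplet], List[Triplet]]:
    """Split (query, positive, negative) triplets so no query appears in both sets."""
    # Sort before shuffling so the split is reproducible for a given seed
    queries = sorted({q for q, _, _ in triplets})
    rng = random.Random(seed)
    rng.shuffle(queries)
    n_eval = max(1, int(len(queries) * eval_fraction))
    eval_queries = set(queries[:n_eval])
    train = [t for t in triplets if t[0] not in eval_queries]
    held_out = [t for t in triplets if t[0] in eval_queries]
    return train, held_out
```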

GPU Memory Errors

Solution:

  • Reduce batch size
  • Use gradient accumulation
  • Consider smaller base models (e.g., all-MiniLM-L6-v2)

Next Steps

  • Generate your synthetic training dataset
  • Implement a manual review process
  • Start with 100-200 high-quality examples per category
  • Fine-tune and compare against baseline
  • Iterate based on error analysis

Common Questions

"Should I fine-tune my LLM or my embedding model?"

Fine-tune your embedding model, not your LLM. Here's why:

For embedding models:

  • Costs $1.50 and takes 40 minutes on a laptop
  • With 6,000-10,000 examples, you can outperform closed-source models
  • No infrastructure to maintain
  • Easy to update when base models improve

For LLMs:

  • Very costly to fine-tune and maintain
  • Requires inference servers, CUDA drivers, infrastructure
  • Becomes obsolete when providers release new versions
  • For a team of 4-5 people, the maintenance cost is too high

From Production: "Bloomberg spent millions on their own model, and within 5-6 months, GPT-4 was better. Instead of fine-tuning, consider using RAG to retrieve relationship information first, put that in the context, and then add the actual question."

When LLM fine-tuning might make sense:

  • Specific tonality or personalization requirements
  • Access to proprietary data that can't be in prompts
  • Very large scale where inference costs justify the investment

"How much data do I need to fine-tune embeddings?"

6,000-10,000 examples is the sweet spot. With this amount, you can create embedding models that outperform general-purpose models on your specific tasks.

If you're embedding massive datasets at scale, fine-tuning and self-hosting an open-source model becomes even more valuable:

  • Spinning up your own GPUs lets you process much more text per second
  • Example: Embedding 20GB of text takes 15 minutes and costs ~$20
  • Using OpenAI APIs would be more expensive and much slower

"Should I fine-tune my embedding model or my re-ranker?"

Use the same dataset for both, but prioritize based on your bottleneck:

  1. Is Recall@100 already good (e.g., 95%) but Recall@10 poor (e.g., 50%)? → Focus on the re-ranker
  2. Are you missing relevant documents even in top 100? → Focus on the embedding model

The key insight: By having metrics on both stages, you can identify where to focus improvement efforts.
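Both questions can be answered from the first-stage (embedding) rankings alone. A minimal sketch, assuming you have the top-100 ranked document IDs for each test query and one relevant ID per query; the 95% and 50% thresholds are the example figures from the text rather than universal cut-offs, and diagnose is a hypothetical helper.

```python
from typing import Sequence


def recall_at_k(ranked_lists: Sequence[Sequence[str]], relevant: Sequence[str], k: int) -> float:
    """Fraction of queries whose relevant document appears in the top k results."""
    hits = sum(1 for ranked, rel in zip(ranked_lists, relevant) if rel in ranked[:k])
    return hits / len(relevant)


def diagnose(ranked_lists: Sequence[Sequence[str]], relevant: Sequence[str],
             good_recall: float = 0.95, poor_recall: float = 0.5) -> str:
    """Suggest which retrieval stage to improve first."""
    r100 = recall_at_k(ranked_lists, relevant, 100)
    r10 = recall_at_k(ranked_lists, relevant, 10)
    if r100 < good_recall:
        # Relevant documents are missing even from the top 100: a candidate-generation problem
        return "embedding model"
    if r10 < poor_recall:
        # Relevant documents are retrieved but ranked too low: an ordering problem
        return "re-ranker"
    return "no obvious bottleneck"
```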

"How do I avoid overfitting with synthetic data?"

The biggest issue is a mismatch between real user questions and synthetic data. Check:

  1. Character count mismatch: If customer questions average 30 characters but synthetic questions average 90, the LLM is generating overly verbose queries
  2. Embedding variance: Are synthetic embeddings too similar to each other?
  3. Question patterns: Do synthetic questions match real user query patterns?

Practical Tip: "Intelligently incorporate real-world examples from users into the few-shot examples for synthetic data generation to make it more diverse."
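The first two checks are easy to automate. A sketch using only the standard library, assuming you have comparable samples of real and synthetic queries and embeddings for each; length_ratio and mean_pairwise_cosine are hypothetical helpers, and what counts as too high is a judgment call for your data, so the code reports numbers rather than verdicts.

```python
import math
import statistics
from itertools import combinations
from typing import Sequence


def length_ratio(real_queries: Sequence[str], synthetic_queries: Sequence[str]) -> float:
    """Mean synthetic query length divided by mean real query length, in characters."""
    real_mean = statistics.mean(len(q) for q in real_queries)
    synthetic_mean = statistics.mean(len(q) for q in synthetic_queries)
    return synthetic_mean / real_mean


def mean_pairwise_cosine(embeddings: Sequence[Sequence[float]]) -> float:
    """Average cosine similarity over all pairs; values close to 1.0 indicate low diversity."""
    if len(embeddings) < 2:
        raise ValueError("need at least two embeddings")

    def cos(u, v):
        dot = sum(a * b for a, b in zip(u, v))
        return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

    pairs = list(combinations(embeddings, 2))
    return sum(cos(u, v) for u, v in pairs) / len(pairs)


# Real questions average 30 characters, synthetic ones 90: the generator is too verbose
print(length_ratio(["x" * 30, "y" * 30], ["z" * 90, "w" * 90]))  # 3.0
```

Run mean_pairwise_cosine on both the real and the synthetic embeddings: if the synthetic set scores noticeably higher, the generated questions are less diverse than real traffic.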

"Should I use different embedding models for different content types?"

No, unless you're at massive scale. Having multiple embedding models for different content types (like product descriptions vs. comments) probably won't yield enough performance improvement to justify the maintenance cost.

There's evidence that a single unified model trained on all your data performs better than specialized models. In machine translation, researchers found that a single model trained to translate all languages performed better than individual models for each language pair.

The unified model learns something about the underlying system that allows it to handle even rare cases better than specialized models would.

Additional Resources