Fine-tuning: Adapting embeddings to your specific domain

Introduction

In this chapter, we are going to explore fine-tuning of embedding models and briefly mention embedding adapters. Fine-tuning adapts embedding models to domain-specific data, optimizing both query and document representations for better alignment. Embedding adapters, on the other hand, adjust query embeddings dynamically, offering a cost-efficient alternative without requiring re-indexing of the document corpus.

Fine-Tuning Embedding Models

Fine-tuning offers flexible options for adapting an embedding model to domain-specific tasks or retrieval objectives. You can retrain the entire model to deeply integrate domain nuances, or focus on specific layers (often the final ones or newly added layers), similar to transfer-learning techniques used with convolutional neural networks. This allows tailored optimization of the embeddings, ensuring that queries and documents align more closely with the desired relevance criteria while keeping computational cost manageable.
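
As a minimal sketch of the "tune only a few layers" option (not part of the original example), you could freeze everything except the last encoder block. The snippet below assumes the BAAI/bge-small-en checkpoint from Hugging Face, which is a BERT-style model exposing an encoder.layer attribute:

from transformers import AutoModel

# Load a BERT-style embedding model (assumption: BAAI/bge-small-en exposes encoder.layer).
model = AutoModel.from_pretrained("BAAI/bge-small-en")

# Freeze every parameter first...
for param in model.parameters():
    param.requires_grad = False

# ...then unfreeze only the final encoder layer for fine-tuning.
for param in model.encoder.layer[-1].parameters():
    param.requires_grad = True

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable:,} / {total:,}")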

How It Works

  1. Data Preparation
    A labeled dataset of query-document pairs with relevance scores is required. This dataset can be created manually or generated synthetically using models like LLMs.
  2. Training Process
    The model learns to adjust the embedding space so that relevant query-document pairs move closer together and irrelevant pairs move farther apart (a worked sketch follows this list). This involves:
    • Forward propagation to compute similarity scores for query-document pairs.
    • Loss computation, such as contrastive loss or triplet loss, to optimize distances in the embedding space.
    • Backpropagation to update the model's weights across all layers.
  3. Inference
    Once fine-tuned, both query and document embeddings are re-encoded, ensuring that future retrievals reflect the improved alignment. This step requires reprocessing your entire corpus to generate updated document embeddings.
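
To make the training step concrete, here is a minimal PyTorch sketch (an illustration, not the exact code used by any particular library) of forward propagation, in-batch contrastive loss computation, and backpropagation. The random tensors stand in for the embedding model's actual query and document outputs, and the temperature value is an arbitrary example:

import torch
import torch.nn.functional as F

# Toy batch: (query_i, doc_i) are the relevant pairs; in practice these come
# from the embedding model's forward pass over a labeled batch.
batch_size, dim = 8, 384
query_emb = F.normalize(torch.randn(batch_size, dim, requires_grad=True), dim=-1)
doc_emb = F.normalize(torch.randn(batch_size, dim, requires_grad=True), dim=-1)

# Forward propagation: cosine similarity between every query and every document.
sim = query_emb @ doc_emb.T  # shape (batch_size, batch_size)

# Loss computation: in-batch contrastive loss. The diagonal holds the relevant
# pairs; every other document in the batch acts as a negative.
temperature = 0.05
labels = torch.arange(batch_size)
loss = F.cross_entropy(sim / temperature, labels)

# Backpropagation: gradients flow back into the embedding model's weights.
loss.backward()
print(f"Contrastive loss: {loss.item():.4f}")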

Key Considerations

  • Advantages:
    • Provides deep integration with domain-specific semantics.
    • Ideal for retrieval tasks requiring high precision in specialized fields like legal, medical, or technical documentation.
  • Drawbacks:
    • Requires significant labeled data and computational resources.
    • Necessitates re-embedding the entire corpus post-training, which can be costly for large datasets.

Code Example

Code adapted from our friends at llamaindex 💙: Finetuning an Adapter on Top of any Black-Box Embedding Model - LlamaIndex

Getting started

This code sets up the foundation for fine-tuning by preparing a smaller, manageable dataset. It uses the Hugging Face datasets library to load the SQuAD dataset, a popular benchmark for question-answering tasks. The dataset is then reduced to 1,000 training examples and 500 validation examples to speed up the fine-tuning process while still maintaining sufficient data for experimentation. Finally, the sizes of the reduced datasets are printed to confirm successful selection. This step ensures efficient and focused training without overwhelming computational resources.

#!pip install datasets llama-index llama-index-embeddings-huggingface llama-index-finetuning

from datasets import load_dataset

# Load the SQuAD dataset
squad_dataset = load_dataset('squad')

# Reduce the size of the dataset to 1000 examples for both training and validation
train_dataset = squad_dataset['train'].select(range(1000))
validation_dataset = squad_dataset['validation'].select(range(500))

# Print the sizes to verify
print(f"Training Dataset Size: {len(train_dataset)}")
print(f"Validation Dataset Size: {len(validation_dataset)}")

Data preparation

This section transforms the SQuAD dataset into a format suitable for fine-tuning an embedding model using EmbeddingQAFinetuneDataset. The function transform_squad_to_finetune_dataset extracts key components: questions (queries), contexts (documents in the corpus), and their mappings (relevance between queries and documents). Unique IDs are assigned to each query and document for efficient indexing.

The method ensures no duplicate contexts in the corpus and establishes a clear relationship between each query and its relevant document(s). Debugging prints provide insights into the processed dataset, including the number of queries, documents, and their mappings. Finally, both training and validation datasets are prepared in this structured format, ready for fine-tuning the embedding model.

Note: this format is required by the EmbeddingAdapterFinetuneEngine from LlamaIndex used later; different tools may require different formats.

from llama_index.finetuning.embeddings.common import EmbeddingQAFinetuneDataset

# Transform SQuAD into EmbeddingQAFinetuneDataset
def transform_squad_to_finetune_dataset(squad_dataset):
    queries = {}
    relevant_docs = {}
    corpus = {}

    # Use indices as IDs
    query_id_counter = 0
    doc_id_counter = 0

    for entry in squad_dataset:
        context = entry['context']
        question = entry['question']
        answers = entry['answers']['text']  # Answers not used in this example but could be for QA systems

        # Add context to corpus if not already present
        if context not in corpus.values():
            doc_id = str(doc_id_counter)
            corpus[doc_id] = context
            doc_id_counter += 1
        else:
            doc_id = list(corpus.keys())[list(corpus.values()).index(context)]

        # Add question as a query
        query_id = str(query_id_counter)
        queries[query_id] = question
        relevant_docs[query_id] = [doc_id]
        query_id_counter += 1

    # Debugging prints to check the dataset format
    print("\nFormatted Dataset Details:")
    print(f"Number of Queries: {len(queries)}")
    print(f"Sample Query: {list(queries.items())[:8]}")
    print(f"Number of Documents in Corpus: {len(corpus)}")
    print(f"Sample Document: {list(corpus.items())[:8]}")
    print(f"Relevant Documents Mapping (Query to Doc): {list(relevant_docs.items())[:8]}")

    return EmbeddingQAFinetuneDataset(
        queries=queries,
        relevant_docs=relevant_docs,
        corpus=corpus,
    )

# Prepare SQuAD train dataset for fine-tuning
train_dataset_formatted = transform_squad_to_finetune_dataset(train_dataset)
val_dataset_formatted = transform_squad_to_finetune_dataset(validation_dataset)

Fine-tuning

This section demonstrates a basic setup for fine-tuning an embedding model using the EmbeddingAdapterFinetuneEngine. It begins by resolving a base embedding model (BGE-small), which serves as the starting point for training. The finetuning engine is initialized with the formatted training dataset, specifying parameters like the number of epochs (12), an optimizer (Adam), and a learning rate (0.001).

The finetune() method executes the fine-tuning process, and the resulting fine-tuned model is retrieved for use. While this setup is functional, it is simplistic and serves as a starting point. For optimal results, you would perform hyperparameter optimization—experimenting with learning rates, batch sizes, and other training configurations to maximize performance.

from llama_index.finetuning import EmbeddingAdapterFinetuneEngine
from llama_index.core.embeddings import resolve_embed_model
import torch

# Resolve base embedding model (BGE-small)
base_embed_model = resolve_embed_model("local:BAAI/bge-small-en")

# Initialize and configure the finetuning engine
finetune_engine = EmbeddingAdapterFinetuneEngine(
    dataset=train_dataset_formatted,
    embed_model=base_embed_model,
    model_output_path="model_output_test",
    epochs=12,  # Adjust as needed
    verbose=True,
    optimizer_class=torch.optim.Adam,  # Optional customization
    optimizer_params={"lr": 0.001}     # Optional customization
)

# Run fine-tuning
finetune_engine.finetune()

# Retrieve the fine-tuned embedding model
embed_model = finetune_engine.get_finetuned_model()
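
If you want to act on that hyperparameter advice, a small sweep could look like the sketch below. It reuses the EmbeddingAdapterFinetuneEngine arguments shown above and the evaluate helper introduced in the Helper functions section further down; the learning-rate grid is an arbitrary example, not a recommendation:

# Hypothetical learning-rate sweep (values chosen for illustration only).
candidate_lrs = [1e-2, 1e-3, 1e-4]
sweep_results = {}

for lr in candidate_lrs:
    engine = EmbeddingAdapterFinetuneEngine(
        dataset=train_dataset_formatted,
        embed_model=base_embed_model,
        model_output_path=f"model_output_lr_{lr}",
        epochs=12,
        optimizer_class=torch.optim.Adam,
        optimizer_params={"lr": lr},
    )
    engine.finetune()
    candidate_model = engine.get_finetuned_model()

    # evaluate() is the helper defined in the Helper functions section below.
    results = evaluate(val_dataset_formatted, candidate_model)
    sweep_results[lr] = sum(r["is_hit"] for r in results) / len(results)

print(sweep_results)  # hit rate per learning rate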

Helper functions

Below are helper functions provided by LlamaIndex that are used for analyzing the results of the fine-tuned embedding model. The evaluate function tests the retrieval performance of the model by comparing its output to expected results and calculating metrics like hit rate and Mean Reciprocal Rank (MRR). The display_results function aggregates and visualizes these metrics for multiple retrievers, providing a clear comparison of their performance. These functions will help you assess the effectiveness of your fine-tuned model.
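
Since the helper code itself is not shown above, here is a minimal sketch of what evaluate and display_results could look like. It assumes one relevant document per query and a default top_k of 5; LlamaIndex's own version may compute the metrics slightly differently:

import pandas as pd
from llama_index.core import VectorStoreIndex
from llama_index.core.embeddings import resolve_embed_model
from llama_index.core.schema import TextNode


def evaluate(dataset, embed_model, top_k=5):
    """Retrieve top_k documents for every query and record hit / reciprocal rank."""
    if isinstance(embed_model, str):
        embed_model = resolve_embed_model(embed_model)

    # Build an index over the corpus with the embedding model under test.
    nodes = [TextNode(id_=doc_id, text=text) for doc_id, text in dataset.corpus.items()]
    index = VectorStoreIndex(nodes, embed_model=embed_model, show_progress=True)
    retriever = index.as_retriever(similarity_top_k=top_k)

    eval_results = []
    for query_id, query in dataset.queries.items():
        retrieved_ids = [n.node.node_id for n in retriever.retrieve(query)]
        expected_id = dataset.relevant_docs[query_id][0]  # assumes one relevant doc per query
        is_hit = expected_id in retrieved_ids
        mrr = 1.0 / (retrieved_ids.index(expected_id) + 1) if is_hit else 0.0
        eval_results.append({"query": query_id, "is_hit": is_hit, "mrr": mrr})
    return eval_results


def display_results(names, results_arr):
    """Aggregate hit rate and MRR per retriever and show them side by side."""
    rows = []
    for name, results in zip(names, results_arr):
        df = pd.DataFrame(results)
        rows.append({"retriever": name, "hit_rate": df["is_hit"].mean(), "mrr": df["mrr"].mean()})
    print(pd.DataFrame(rows))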


Getting the results

This section evaluates the performance of both the base embedding model (bge) and the fine-tuned model (ft) on the validation dataset. The evaluate function is used to compute retrieval metrics like hit rate and MRR for each model. Finally, the display_results function presents a side-by-side comparison of the two models, allowing you to see the impact of fine-tuning on retrieval performance. This is where you validate whether the fine-tuned model has improved over the base model.

ft_val_results = evaluate(val_dataset_formatted, embed_model)
bge = "local:BAAI/bge-small-en"
bge_val_results = evaluate(val_dataset_formatted, bge)
display_results(
    ["bge", "ft"], [bge_val_results, ft_val_results]
)

Results

Retriever    Hit Rate    MRR
bge          93.2%       0.779233
ft           95.4%       0.776600

The results compare the performance of the base embedding model (bge) with the fine-tuned model (ft) on the validation dataset:

  1. Hit Rate: Measures how often the correct document appears in the top-k results.
    • bge: 93.2%
    • ft: 95.4%
    • Interpretation: The fine-tuned model demonstrates a slight improvement, showing it retrieves relevant documents more consistently.
  2. MRR (Mean Reciprocal Rank): Reflects how high the correct document appears in the ranked results.
    • bge: 0.779233
    • ft: 0.776600
    • Interpretation: While the hit rate improved, the MRR dipped marginally, indicating that the fine-tuned model occasionally ranks the correct document slightly lower even though it retrieves it more consistently.

Observations:

  • Despite using a very basic training setup with limited examples and no hyperparameter optimization, the results show promise, highlighting the potential of fine-tuning even in constrained scenarios.
  • The observed improvements are modest but meaningful, especially for applications requiring high reliability.

Potential Next Steps:

  • Perform statistical tests, such as paired t-tests or bootstrapping, to confirm the significance of improvements (a sketch of the bootstrap follows this list).
  • Optimize the training setup, including dataset size and hyperparameters, to maximize performance gains.
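
As a hedged sketch of the bootstrapping idea, the function below resamples queries with replacement and builds a 95% confidence interval for the hit-rate difference between two retrievers evaluated on the same queries. It assumes the per-query result dictionaries produced by the evaluate sketch above (each with an is_hit field):

import numpy as np

def bootstrap_hit_rate_diff(results_a, results_b, n_boot=10_000, seed=0):
    """Paired bootstrap of the hit-rate difference (retriever B minus retriever A)."""
    hits_a = np.array([r["is_hit"] for r in results_a], dtype=float)
    hits_b = np.array([r["is_hit"] for r in results_b], dtype=float)
    rng = np.random.default_rng(seed)
    n = len(hits_a)
    diffs = []
    for _ in range(n_boot):
        idx = rng.integers(0, n, size=n)  # resample the same queries for both retrievers
        diffs.append(hits_b[idx].mean() - hits_a[idx].mean())
    return np.percentile(diffs, [2.5, 97.5])  # 95% confidence interval

# Example usage with the results computed earlier:
# low, high = bootstrap_hit_rate_diff(bge_val_results, ft_val_results)
# print(f"95% CI for the hit-rate gain of ft over bge: [{low:.3f}, {high:.3f}]")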

Embedding Adapters

Embedding adapters are a promising alternative to fine-tuning, offering a lightweight method to adjust embeddings dynamically during retrieval. Unlike the fine-tuning approach discussed earlier, where trainable layers of the model itself are updated and the entire vector database then has to be re-embedded and re-indexed, these adapters operate exclusively on the query side. This eliminates the need to reprocess the document corpus, resulting in significant cost and computational savings. Instead of modifying the embedding model or document embeddings, query embeddings are transformed through a learned matrix, adapting them to specific relevance criteria on the fly.

The core idea behind embedding adapters is to apply a transformation matrix with the same dimensions as the embeddings themselves. This matrix amplifies relevant dimensions of the embedding space while suppressing less relevant ones, dynamically tailoring the query embedding to the task at hand. While this approach has seen limited exploration in the broader AI community, it holds considerable promise for scenarios requiring adaptive retrieval without the overhead of full re-indexing. By refining query embeddings in real-time, embedding adapters offer a flexible and cost-efficient way to enhance retrieval accuracy in domain-specific RAG pipelines.
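
As a minimal illustration of that idea (not a production implementation), the sketch below applies a learned square matrix to the query embedding only; the 384-dimensional size is an assumption matching bge-small-en, and in practice the matrix would be trained with the same kind of contrastive objective shown earlier:

import torch

dim = 384  # assumed embedding size (e.g., bge-small-en)

# The adapter is a single learned square matrix applied only to query embeddings.
adapter = torch.nn.Linear(dim, dim, bias=False)

def adapt_query(query_emb: torch.Tensor) -> torch.Tensor:
    # Document embeddings, and the index built from them, stay untouched;
    # only the query is re-projected into the adapted space at retrieval time.
    return adapter(query_emb)

query_emb = torch.randn(dim)
print(adapt_query(query_emb).shape)  # torch.Size([384])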

So what?

I highly recommend fine-tuning for your specific use case. In his amazing Stanford lecture, Douwe Kiela described the traditional RAG setup as a "Frankenstein monster" – a collection of components stitched together without much focus on optimization. By fine-tuning embedding models or incorporating embedding adapters, you take a significant step toward building RAG systems that are carefully tailored and optimized for your unique requirements. It is definitely worth exploring!

Conclusion

We demonstrated how fine-tuning embedding models, even with a basic setup, can improve retrieval accuracy by better aligning embeddings to the training data. Additionally, we introduced embedding adapters as a lightweight option for refining query embeddings, which can enhance retrieval without reprocessing the corpus. These techniques are practical tools for increasing the likelihood of retrieving relevant information in domain-specific contexts.

In the next chapter, we will cover Multimodal RAG! This is a new and exciting field: modern databases (cough cough, Deeplake) allow you to embed images the same way you embed text and then query both at the same time! 🤯

Jupyter: Google Colab