Testing RAG Systems

Unit tests, integration tests, and regression testing for RAG pipelines.

Overview

RAG systems require testing at every layer: chunking, retrieval quality, generation accuracy, and end-to-end functionality.
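
The test examples below reference a `retriever`, `chunker`, and `rag_pipeline` without showing where they come from. One minimal way to provide them at the top of the test module is sketched here; the `myapp` module paths, class names, and constructor arguments are placeholders for your own project, not a specific library's API.

# test_rag.py -- shared objects used by the tests on this page.
# The myapp modules and constructor arguments are placeholders.
from myapp.retrieval import Retriever
from myapp.chunking import Chunker
from myapp.pipeline import rag_pipeline  # used directly by the tests below

retriever = Retriever(index_path="tests/fixtures/index")
chunker = Chunker()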

Unit Tests

Test Retrieval

import pytest

def test_retrieval_returns_relevant_docs():
    query = "What is the refund policy?"
    docs = retriever.retrieve(query, k=5)
    
    assert len(docs) == 5
    assert any('refund' in doc.text.lower() for doc in docs)
    assert docs[0].score > 0.7  # High relevance

def test_retrieval_handles_empty_query():
    docs = retriever.retrieve("", k=5)
    assert len(docs) == 0

Test Chunking

def test_chunking_respects_max_size():
    text = "word " * 1000
    chunks = chunker.chunk(text, max_tokens=512)
    
    for chunk in chunks:
        assert count_tokens(chunk) <= 512

def test_chunking_preserves_content():
    text = "Important content here"
    chunks = chunker.chunk(text)
    
    reconstructed = " ".join(chunks)
    assert "Important content" in reconstructed

Integration Tests

def test_end_to_end_rag():
    # Setup
    db = VectorDB()
    db.add_documents(["The sky is blue", "Grass is green"])
    
    # Query
    answer = rag_pipeline("What color is the sky?", db)
    
    # Assert
    assert "blue" in answer.lower()
    assert answer.metadata['sources'] is not None

Regression Tests

class RegressionTestSuite:
    def __init__(self):
        self.test_cases = self.load_test_cases()
    
    def run(self):
        results = []
        for case in self.test_cases:
            answer = rag_pipeline(case['query'])
            
            # Compare with expected
            score = self.evaluate(answer, case['expected'])
            results.append({
                'query': case['query'],
                'score': score,
                'passed': score > 0.8
            })
        
        return results
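
The suite above calls `load_test_cases` and `evaluate` without defining them. One possible sketch, loading cases from a JSON file and scoring with naive token overlap; the file path and scoring method are illustrative, and embedding similarity or an LLM judge could be substituted.

import json

# Sketch of the two helpers the suite relies on; add them as methods of
# RegressionTestSuite. The file path and scoring are illustrative only.
def load_test_cases(self, path="tests/regression_cases.json"):
    # Each case: {"query": "...", "expected": "..."}
    with open(path) as f:
        return json.load(f)

def evaluate(self, answer, expected):
    # Naive token-overlap score in [0, 1]; swap in embedding similarity
    # or an LLM-as-judge for anything beyond smoke testing.
    text = getattr(answer, "text", answer)  # accept response objects or plain strings
    answer_tokens = set(text.lower().split())
    expected_tokens = set(expected.lower().split())
    if not expected_tokens:
        return 0.0
    return len(answer_tokens & expected_tokens) / len(expected_tokens)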

Evaluation Metrics

def test_retrieval_metrics():
    queries = load_test_queries()
    
    recall_at_5 = []
    for query, relevant_docs in queries:
        retrieved = retriever.retrieve(query, k=5)
        retrieved_ids = {d.id for d in retrieved}
        relevant_ids = {d.id for d in relevant_docs}
        
        if not relevant_ids:
            continue  # skip queries with no labelled relevant docs
        recall = len(retrieved_ids & relevant_ids) / len(relevant_ids)
        recall_at_5.append(recall)
    
    avg_recall = sum(recall_at_5) / len(recall_at_5)
    assert avg_recall > 0.8, f"Recall@5 too low: {avg_recall}"
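
`load_test_queries` is assumed to return (query, relevant_docs) pairs where each relevant doc exposes an `id`. A sketch of loading such a labelled set from JSON; the file format and the SimpleNamespace wrapper are illustrative.

import json
from types import SimpleNamespace

def load_test_queries(path="tests/retrieval_ground_truth.json"):
    # Each entry: {"query": "...", "relevant_ids": ["doc-1", "doc-2"]}
    with open(path) as f:
        entries = json.load(f)
    return [
        (e["query"], [SimpleNamespace(id=doc_id) for doc_id in e["relevant_ids"]])
        for e in entries
    ]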

Mock Testing

from unittest.mock import Mock, patch

def test_rag_with_mock_llm():
    # The patch target must match how the pipeline calls the LLM; the path
    # below assumes the legacy (pre-1.0) OpenAI SDK.
    with patch('openai.ChatCompletion.create') as mock_llm:
        mock_llm.return_value = Mock(
            choices=[Mock(message=Mock(content="Mocked answer"))]
        )
        
        answer = rag_pipeline("test query")
        assert answer.text == "Mocked answer"
        assert mock_llm.called
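
For a fully hermetic test, the retriever can be mocked alongside the LLM, assuming the pipeline looks it up through a patchable module attribute; the `myapp.pipeline.retriever` path below is a placeholder.

def test_rag_with_mock_retriever_and_llm():
    fake_doc = Mock(text="Refunds are accepted within 30 days.", score=0.9)
    
    with patch('myapp.pipeline.retriever') as mock_retriever, \
         patch('openai.ChatCompletion.create') as mock_llm:
        mock_retriever.retrieve.return_value = [fake_doc]
        mock_llm.return_value = Mock(
            choices=[Mock(message=Mock(content="Refunds are accepted within 30 days."))]
        )
        
        answer = rag_pipeline("What is the refund policy?")
        mock_retriever.retrieve.assert_called_once()
        assert "30 days" in answer.text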

Performance Tests

import time

def test_retrieval_latency():
    start = time.perf_counter()  # monotonic and higher-resolution than time.time()
    docs = retriever.retrieve("test query", k=10)
    latency = (time.perf_counter() - start) * 1000
    
    assert latency < 200, f"Retrieval too slow: {latency:.1f}ms"

def test_concurrent_queries():
    from concurrent.futures import ThreadPoolExecutor
    
    with ThreadPoolExecutor(max_workers=10) as executor:
        futures = [executor.submit(rag_pipeline, f"query {i}") for i in range(100)]
        results = [f.result() for f in futures]
    
    assert len(results) == 100
    assert all(r is not None for r in results)
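
Single-run latency assertions are noisy; measuring a percentile over repeated runs is usually more stable. A sketch follows, with an illustrative run count and threshold.

import statistics

def test_retrieval_latency_p95():
    latencies = []
    for _ in range(20):
        start = time.perf_counter()
        retriever.retrieve("test query", k=10)
        latencies.append((time.perf_counter() - start) * 1000)
    
    p95 = statistics.quantiles(latencies, n=20)[18]  # 19 cut points; index 18 ≈ 95th percentile
    assert p95 < 300, f"p95 retrieval latency too high: {p95:.1f}ms"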

Next Steps