Overview
FlashRAG is a Python toolkit for reproducing and developing retrieval-augmented generation (RAG) research. It provides a modular framework with unified interfaces across 32+ RAG methods, standardized evaluation on multiple benchmarks, and pre-built pipelines covering naive, advanced, and modular RAG architectures. It is aimed at researchers who want to compare RAG approaches or develop new methods.
FlashRAG includes pre-processed benchmark datasets (Natural Questions, TriviaQA, HotpotQA, etc.), a component library with swappable retrievers, generators, rerankers, and refiners, and evaluation metrics such as EM, F1, recall, and precision. It supports both API-based and local models, so it can be used for academic research and prototyping.
Installation
# The PyPI package is published as flashrag-dev (pre-release)
pip install flashrag-dev --pre
# From source with all dependencies
git clone https://github.com/RUC-NLPIR/FlashRAG.git
cd FlashRAG
pip install -e ".[all]"
# Download pre-processed datasets
python -m flashrag.download --dataset nq --output_dir ./data
python -m flashrag.download --dataset triviaqa --output_dir ./data
python -m flashrag.download --dataset hotpotqa --output_dir ./data
Core Components
Pipeline Types
| Pipeline | Description | Key Feature |
|---|---|---|
| SequentialPipeline | Standard retrieve-then-read | Baseline RAG |
| ConditionalPipeline | Adaptive retrieval decision | Skips retrieval when not needed |
| BranchingPipeline | Multiple retrieval paths | Ensemble strategies |
| LoopPipeline | Iterative retrieval | Multi-hop reasoning |
| SelfRAGPipeline | Self-reflective generation | Critique and refine |
| IRCoTPipeline | Chain-of-thought retrieval | Interleaved reasoning |
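To make the difference between the first two rows concrete, here is a minimal sketch of the control flow of a sequential and a conditional pipeline. It does not use FlashRAG classes: `retrieve`, `generate`, and `needs_retrieval` are hypothetical callables standing in for a retriever, a generator, and a retrieval judge.

```python
from typing import Callable, List

# Hypothetical stand-ins for real components (not FlashRAG APIs).
Retriever = Callable[[str, int], List[str]]
Generator = Callable[[str], str]
Judge = Callable[[str], bool]


def build_prompt(question: str, docs: List[str]) -> str:
    """Prepend retrieved passages to the question when there are any."""
    context = "\n".join(docs)
    if context:
        return f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"
    return f"Question: {question}\nAnswer:"


def sequential_rag(question: str, retrieve: Retriever,
                   generate: Generator, top_k: int = 5) -> str:
    """Retrieve-then-read: always retrieve, then generate once."""
    docs = retrieve(question, top_k)
    return generate(build_prompt(question, docs))


def conditional_rag(question: str, needs_retrieval: Judge, retrieve: Retriever,
                    generate: Generator, top_k: int = 5) -> str:
    """Retrieve only when the judge says the question needs external context."""
    docs = retrieve(question, top_k) if needs_retrieval(question) else []
    return generate(build_prompt(question, docs))
```

The other pipeline types vary the same skeleton: branching runs several retrieval paths and merges them, and loop-style pipelines repeat the retrieve and generate steps.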
Basic Usage
from flashrag.config import Config
from flashrag.pipeline import SequentialPipeline
from flashrag.dataset import Dataset
# Load configuration
config = Config("config.yaml")
# Load dataset
dataset = Dataset(config, "nq")
# Create pipeline
pipeline = SequentialPipeline(config)
# Run evaluation
results = pipeline.run(dataset)
print(f"EM: {results['em']:.4f}")
print(f"F1: {results['f1']:.4f}")
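For readers unfamiliar with the two metrics printed above, the following sketch shows how exact match (EM) and token-level F1 are commonly computed for open-domain QA, using SQuAD-style answer normalization. It is an illustration under that assumption, not FlashRAG's own metric code; the function names are hypothetical.

```python
import re
import string
from collections import Counter


def normalize_answer(text: str) -> str:
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    punctuation = set(string.punctuation)
    text = "".join(ch for ch in text if ch not in punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, golds: list) -> float:
    """1.0 if the normalized prediction equals any normalized gold answer."""
    pred = normalize_answer(prediction)
    return float(any(pred == normalize_answer(gold) for gold in golds))


def token_f1(prediction: str, golds: list) -> float:
    """Best token-overlap F1 between the prediction and any gold answer."""
    pred_tokens = normalize_answer(prediction).split()
    best = 0.0
    for gold in golds:
        gold_tokens = normalize_answer(gold).split()
        overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
        if overlap == 0:
            continue
        precision = overlap / len(pred_tokens)
        recall = overlap / len(gold_tokens)
        best = max(best, 2 * precision * recall / (precision + recall))
    return best
```

Dataset-level scores are the mean of these per-question values, which is why they are reported as fractions between 0 and 1.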
Configuration
# config.yaml
data:
  dataset_name: nq
  dataset_path: ./data/nq
  split: test
retriever:
  retriever_type: dense
  model_path: facebook/contriever-msmarco
  corpus_path: ./corpus/wiki_2018.jsonl
  index_path: ./index/contriever_wiki
  top_k: 5
  batch_size: 64
reranker:
  reranker_type: cross_encoder
  model_path: cross-encoder/ms-marco-MiniLM-L-12-v2
  top_k: 3
generator:
  generator_type: api
  model: gpt-4o
  api_key: ${OPENAI_API_KEY}
  max_tokens: 256
  temperature: 0.0
  # Or local model
  # generator_type: hf
  # model_path: meta-llama/Llama-3.1-8B-Instruct
  # device: cuda
  # max_tokens: 256
evaluation:
  metrics: [em, f1, recall, precision]
  save_results: true
  output_dir: ./results
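The `${OPENAI_API_KEY}` placeholder above only works if something substitutes the environment variable when the file is loaded. Whether the toolkit does this itself is not stated here, so the sketch below shows one way to expand such placeholders on an already-parsed config dictionary. `expand_env` is a hypothetical helper, not part of FlashRAG.

```python
import os
import re

# Matches ${NAME} where NAME is a valid environment variable identifier.
_PLACEHOLDER = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")


def expand_env(value, env=None):
    """Recursively replace ${NAME} placeholders in strings inside dicts/lists.

    Raises KeyError for unset variables so a missing API key fails early
    instead of being sent to the API as the literal placeholder text.
    """
    env = os.environ if env is None else env
    if isinstance(value, dict):
        return {key: expand_env(item, env) for key, item in value.items()}
    if isinstance(value, list):
        return [expand_env(item, env) for item in value]
    if isinstance(value, str):
        def substitute(match):
            name = match.group(1)
            if name not in env:
                raise KeyError(f"environment variable {name} is not set")
            return env[name]
        return _PLACEHOLDER.sub(substitute, value)
    return value  # numbers, booleans, None pass through unchanged
```

Calling `expand_env(parsed_config)` right after parsing the YAML keeps secrets out of the file while leaving every other value untouched.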
Retriever Configuration
Supported Retrievers
| Retriever | Type | Model Examples |
|---|---|---|
| Dense (Bi-encoder) | Neural | Contriever, DPR, BGE, E5 |
| Sparse | Lexical | BM25, SPLADE |
| ColBERT | Late-interaction | ColBERTv2 |
| Multi-vector | Neural | ColBERT, ME-BERT |
Building Index
from flashrag.retriever import DenseRetriever
# Build dense index
retriever = DenseRetriever(
    model_path="facebook/contriever-msmarco",
    corpus_path="./corpus/wiki_2018.jsonl",
    batch_size=256,
    device="cuda"
)
# Build FAISS index
retriever.build_index(
    index_path="./index/contriever_wiki",
    index_type="flat",  # flat, ivf, hnsw
    gpu=True
)
# Search
results = retriever.search(
    queries=["What is machine learning?"],
    top_k=10
)
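As background for the `index_type` choice, a flat index in the FAISS sense performs exhaustive search: the query embedding is scored against every document embedding. The sketch below reproduces that behavior in plain Python with inner-product scoring. It is a conceptual illustration with a hypothetical function name, not how FlashRAG or FAISS is implemented.

```python
import heapq


def flat_search(query_vec, doc_vecs, top_k=10):
    """Exhaustive inner-product search; returns (doc_id, score) pairs, best first."""
    scored = []
    for doc_id, vec in enumerate(doc_vecs):
        if len(vec) != len(query_vec):
            # Same failure mode as querying an index built with another model.
            raise ValueError(
                f"dimension mismatch: query has {len(query_vec)}, "
                f"document {doc_id} has {len(vec)}"
            )
        score = sum(q * d for q, d in zip(query_vec, vec))
        scored.append((doc_id, score))
    return heapq.nlargest(top_k, scored, key=lambda pair: pair[1])
```

Because every vector is visited, cost and memory grow linearly with corpus size. IVF and HNSW indexes trade some recall for speed by searching only part of the collection.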
BM25 Retriever
from flashrag.retriever import BM25Retriever
bm25 = BM25Retriever(
    corpus_path="./corpus/wiki_2018.jsonl",
    index_path="./index/bm25_wiki"
)
bm25.build_index()
results = bm25.search(["What is Python?"], top_k=10)
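To show what a BM25 retriever scores, here is a small self-contained sketch of the Okapi BM25 formula with whitespace tokenization. The parameter values `k1=1.2` and `b=0.75` and the Lucene-style IDF are illustrative assumptions; the backend behind `BM25Retriever` may use different tokenization and defaults.

```python
import math
from collections import Counter


def bm25_scores(query, docs, k1=1.2, b=0.75):
    """Return one BM25 score per document for a whitespace-tokenized query."""
    if not docs:
        return []
    tokenized = [doc.lower().split() for doc in docs]
    n_docs = len(tokenized)
    avg_len = sum(len(tokens) for tokens in tokenized) / n_docs
    # Document frequency: number of documents containing each term.
    doc_freq = Counter(term for tokens in tokenized for term in set(tokens))
    query_terms = query.lower().split()

    scores = []
    for tokens in tokenized:
        term_freq = Counter(tokens)
        score = 0.0
        for term in query_terms:
            if term not in term_freq:
                continue
            idf = math.log(
                1 + (n_docs - doc_freq[term] + 0.5) / (doc_freq[term] + 0.5)
            )
            tf = term_freq[term]
            # Length normalization: longer-than-average documents are penalized.
            norm = tf + k1 * (1 - b + b * len(tokens) / avg_len)
            score += idf * tf * (k1 + 1) / norm
        scores.append(score)
    return scores
```

Unlike dense retrieval, the score depends only on term overlap, so BM25 needs no GPU and no embedding model, but it cannot match paraphrases that share no words with the query.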
RAG Methods
Self-RAG
from flashrag.config import Config
from flashrag.pipeline import SelfRAGPipeline
config = Config("selfrag_config.yaml")
pipeline = SelfRAGPipeline(config)
# Self-RAG adds reflection tokens:
# [Retrieve] - decides whether to retrieve
# [IsREL] - judges relevance of retrieved passages
# [IsSUP] - checks if generation is supported
# [IsUSE] - evaluates overall utility
results = pipeline.run(dataset)
Iterative RAG (IRCoT)
from flashrag.config import Config
from flashrag.pipeline import IRCoTPipeline
config = Config("ircot_config.yaml")
pipeline = IRCoTPipeline(config)
# IRCoT interleaves chain-of-thought with retrieval
# Each reasoning step generates a new query
results = pipeline.run(dataset, max_iterations=4)
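The interleaving described in the comments above can be sketched as a simple loop. This is a conceptual outline rather than the `IRCoTPipeline` implementation: `retrieve` and `reason_step` are hypothetical callables, and stopping on the phrase "answer is" is an assumption modeled on the termination rule described for IRCoT.

```python
def ircot(question, retrieve, reason_step, max_iterations=4, top_k=2):
    """Alternate between one reasoning step and one retrieval step.

    retrieve(query, top_k) -> list of passages
    reason_step(question, passages, thoughts) -> next chain-of-thought sentence
    Returns the list of thoughts and the accumulated passages.
    """
    passages = list(retrieve(question, top_k))  # initial retrieval on the question
    thoughts = []
    for _ in range(max_iterations):
        thought = reason_step(question, passages, thoughts)
        thoughts.append(thought)
        if "answer is" in thought.lower():
            break  # the chain of thought has reached a final answer
        # The newest reasoning sentence becomes the next retrieval query.
        for passage in retrieve(thought, top_k):
            if passage not in passages:
                passages.append(passage)
    return thoughts, passages
```

The `max_iterations` cap bounds cost on questions where the model never commits to an answer. Each extra iteration adds one generator call and one retrieval call.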
REPLUG
from flashrag.config import Config
from flashrag.pipeline import REPLUGPipeline
config = Config("replug_config.yaml")
pipeline = REPLUGPipeline(config)
# REPLUG ensembles over retrieved documents
# Each document generates independently, then merges
results = pipeline.run(dataset)
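The merge step mentioned in the comments above can be illustrated as a weighted average of per-document next-token distributions, with weights given by a softmax over retrieval scores. The sketch below assumes that formulation and uses plain dictionaries for the distributions; `replug_ensemble` is a hypothetical name, not a FlashRAG function.

```python
import math


def replug_ensemble(retrieval_scores, token_distributions):
    """Combine per-document next-token distributions into one distribution.

    retrieval_scores: one similarity score per retrieved document
    token_distributions: one {token: probability} dict per document, produced
        by running the generator with that document prepended to the input
    """
    if len(retrieval_scores) != len(token_distributions):
        raise ValueError("need exactly one distribution per retrieved document")
    if not retrieval_scores:
        return {}
    # Softmax over retrieval scores (shifted by the max for numerical stability).
    max_score = max(retrieval_scores)
    exp_scores = [math.exp(score - max_score) for score in retrieval_scores]
    total = sum(exp_scores)
    weights = [value / total for value in exp_scores]

    combined = {}
    for weight, distribution in zip(weights, token_distributions):
        for token, prob in distribution.items():
            combined[token] = combined.get(token, 0.0) + weight * prob
    return combined
```

Since each document is processed in a separate forward pass, generation cost scales with the number of retrieved documents, which is worth keeping in mind when raising `top_k`.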
Advanced Usage
Custom Pipeline
from flashrag.pipeline import BasicPipeline

class CustomRAGPipeline(BasicPipeline):
    def __init__(self, config):
        super().__init__(config)
        self.retriever = self.build_retriever()
        self.generator = self.build_generator()

    def run_item(self, item):
        query = item["question"]
        # Custom retrieval logic: search() takes a batch, so wrap the query in a list
        docs = self.retriever.search([query], top_k=10)
        # Custom filtering: keep only passages above a score threshold
        filtered = [d for d in docs[0] if d["score"] > 0.5]
        # Custom prompt built from at most five filtered passages
        context = "\n".join([d["text"] for d in filtered[:5]])
        prompt = f"Answer based on context.\nContext: {context}\nQ: {query}\nA:"
        # Generate: generate() also takes a batch and returns one answer per prompt
        answer = self.generator.generate([prompt])[0]
        return {"pred": answer, "docs": filtered}
Batch Evaluation
from flashrag.config import Config
from flashrag.pipeline import SequentialPipeline
from flashrag.dataset import Dataset
datasets = ["nq", "triviaqa", "hotpotqa", "musique"]
results = {}
for ds_name in datasets:
    config = Config("config.yaml", overrides={"data.dataset_name": ds_name})
    dataset = Dataset(config, ds_name)
    pipeline = SequentialPipeline(config)
    results[ds_name] = pipeline.run(dataset)
    print(f"{ds_name}: EM={results[ds_name]['em']:.4f}, F1={results[ds_name]['f1']:.4f}")
Ablation Studies
# Compare retriever impact
for top_k in [1, 3, 5, 10, 20]:
    config = Config("config.yaml", overrides={"retriever.top_k": top_k})
    pipeline = SequentialPipeline(config)
    result = pipeline.run(dataset)
    print(f"top_k={top_k}: EM={result['em']:.4f}")
# Compare with/without reranking
for use_reranker in [True, False]:
    config = Config("config.yaml", overrides={"reranker.enabled": use_reranker})
    pipeline = SequentialPipeline(config)
    result = pipeline.run(dataset)
    print(f"Reranker={use_reranker}: EM={result['em']:.4f}")
Troubleshooting
| Issue | Solution |
|---|---|
| FAISS index OOM | Use IVF index type instead of flat for large corpora |
| Slow dense retrieval | Build GPU index, increase batch size |
| Dataset download fails | Download manually from HuggingFace datasets |
| Generator timeout | Increase max_retries and timeout in config |
| Corpus format error | Ensure JSONL with id, text, title fields |
| CUDA out of memory | Reduce generator batch size or use quantized model |
| Metric computation error | Check prediction format matches ground truth format |
| Index dimension mismatch | Rebuild index when changing embedding model |
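The corpus format requirement in the table above can be checked across a whole file rather than a single record. The sketch below is a hypothetical helper that reports malformed lines and missing fields. The default field names are taken from the table; adjust `required` if your installed version expects different keys.

```python
import json

# Field names as listed in the troubleshooting table (an assumption to verify
# against the corpus format your FlashRAG version expects).
REQUIRED_FIELDS = ("id", "text", "title")


def validate_corpus(path, required=REQUIRED_FIELDS, max_errors=10):
    """Return a list of human-readable problems found in a JSONL corpus."""
    errors = []
    with open(path, encoding="utf-8") as handle:
        for line_no, raw in enumerate(handle, start=1):
            if not raw.strip():
                continue  # tolerate blank lines
            try:
                record = json.loads(raw)
            except json.JSONDecodeError as exc:
                errors.append(f"line {line_no}: invalid JSON ({exc.msg})")
            else:
                if not isinstance(record, dict):
                    errors.append(f"line {line_no}: expected a JSON object")
                else:
                    missing = [name for name in required if name not in record]
                    if missing:
                        errors.append(f"line {line_no}: missing {', '.join(missing)}")
            if len(errors) >= max_errors:
                break  # stop early on badly broken files
    return errors
```

An empty return value means every non-blank line parsed as a JSON object with the required keys. Running this before building an index avoids discovering a format problem only after a long indexing job.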
# Validate config
python -m flashrag.validate_config config.yaml
# Check corpus format
python -c "
import json
with open('corpus.jsonl') as f:
    line = json.loads(f.readline())
    print(line.keys())  # Should have: id, text, title
"