rag-python 0.2.0__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. {rag_python-0.2.0/src/rag_python.egg-info → rag_python-0.3.0}/PKG-INFO +22 -4
  2. {rag_python-0.2.0 → rag_python-0.3.0}/README.md +17 -2
  3. {rag_python-0.2.0 → rag_python-0.3.0}/pyproject.toml +4 -3
  4. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/__init__.py +1 -1
  5. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/cli.py +38 -3
  6. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/client.py +3 -0
  7. rag_python-0.3.0/src/rag_python/document_loaders.py +146 -0
  8. rag_python-0.3.0/src/rag_python/hybrid_search.py +51 -0
  9. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/options.py +3 -2
  10. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/rag_pipeline.py +8 -2
  11. rag_python-0.3.0/src/rag_python/retrieval.py +101 -0
  12. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/vector_store.py +13 -0
  13. {rag_python-0.2.0 → rag_python-0.3.0/src/rag_python.egg-info}/PKG-INFO +22 -4
  14. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python.egg-info/SOURCES.txt +4 -1
  15. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python.egg-info/requires.txt +5 -1
  16. rag_python-0.3.0/tests/test_hybrid_search.py +35 -0
  17. {rag_python-0.2.0 → rag_python-0.3.0}/tests/test_loaders.py +26 -0
  18. {rag_python-0.2.0 → rag_python-0.3.0}/tests/test_package.py +1 -1
  19. rag_python-0.3.0/tests/test_retrieval.py +52 -0
  20. rag_python-0.2.0/src/rag_python/document_loaders.py +0 -74
  21. rag_python-0.2.0/src/rag_python/retrieval.py +0 -61
  22. {rag_python-0.2.0 → rag_python-0.3.0}/LICENSE +0 -0
  23. {rag_python-0.2.0 → rag_python-0.3.0}/setup.cfg +0 -0
  24. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/chunking.py +0 -0
  25. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/cleaning.py +0 -0
  26. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/config.py +0 -0
  27. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/evaluation.py +0 -0
  28. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/generation.py +0 -0
  29. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/guardrails.py +0 -0
  30. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/providers/__init__.py +0 -0
  31. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/providers/anthropic_provider.py +0 -0
  32. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/providers/azure_openai_provider.py +0 -0
  33. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/providers/base.py +0 -0
  34. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/providers/factory.py +0 -0
  35. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/providers/gemini_provider.py +0 -0
  36. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/providers/local_provider.py +0 -0
  37. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/providers/ollama_provider.py +0 -0
  38. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/providers/openai_provider.py +0 -0
  39. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/py.typed +0 -0
  40. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/query_rewriting.py +0 -0
  41. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python/reranker.py +0 -0
  42. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python.egg-info/dependency_links.txt +0 -0
  43. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python.egg-info/entry_points.txt +0 -0
  44. {rag_python-0.2.0 → rag_python-0.3.0}/src/rag_python.egg-info/top_level.txt +0 -0
  45. {rag_python-0.2.0 → rag_python-0.3.0}/tests/test_chunking.py +0 -0
  46. {rag_python-0.2.0 → rag_python-0.3.0}/tests/test_config.py +0 -0
  47. {rag_python-0.2.0 → rag_python-0.3.0}/tests/test_import.py +0 -0
  48. {rag_python-0.2.0 → rag_python-0.3.0}/tests/test_pipeline.py +0 -0
  49. {rag_python-0.2.0 → rag_python-0.3.0}/tests/test_providers.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: rag-python
3
- Version: 0.2.0
3
+ Version: 0.3.0
4
4
  Summary: Production-grade RAG for Python: multi-LLM, query rewriting, reranking, guardrails, and evaluation.
5
5
  Author-email: Raghav Singla <04raghavsingla28@gmail.com>
6
6
  License: MIT
@@ -35,6 +35,8 @@ Requires-Dist: sentence-transformers>=2.2.0; extra == "rerank"
35
35
  Requires-Dist: torch>=2.0.0; extra == "rerank"
36
36
  Provides-Extra: local
37
37
  Requires-Dist: sentence-transformers>=2.2.0; extra == "local"
38
+ Provides-Extra: hybrid
39
+ Requires-Dist: rank-bm25>=0.2.2; extra == "hybrid"
38
40
  Provides-Extra: anthropic
39
41
  Requires-Dist: anthropic>=0.20.0; extra == "anthropic"
40
42
  Provides-Extra: gemini
@@ -44,8 +46,9 @@ Requires-Dist: pytest>=7.0; extra == "dev"
44
46
  Requires-Dist: ruff>=0.1.0; extra == "dev"
45
47
  Requires-Dist: build; extra == "dev"
46
48
  Requires-Dist: twine; extra == "dev"
49
+ Requires-Dist: rank-bm25>=0.2.2; extra == "dev"
47
50
  Provides-Extra: all
48
- Requires-Dist: rag-python[anthropic,gemini,local,rerank]; extra == "all"
51
+ Requires-Dist: rag-python[anthropic,gemini,hybrid,local,rerank]; extra == "all"
49
52
 
50
53
  # rag-python
51
54
 
@@ -67,10 +70,11 @@ Ingest your documents, ask questions, get grounded answers — with query rewrit
67
70
  ## Features
68
71
 
69
72
  - Document pipeline: loaders → cleaning → chunking → embeddings → ChromaDB
70
- - Query pipeline: rewriting → multi-query retrieval → reranking
73
+ - Query pipeline: rewriting → multi-query / **hybrid** retrieval → reranking
71
74
  - Generation with guardrails (prompt injection + hallucination checks)
72
75
  - Evaluation scores + self-correction retry loop
73
76
  - **LLM providers:** OpenAI, Azure OpenAI, Anthropic, Gemini, Ollama
77
+ - **Loaders:** TXT, MD, PDF, DOCX, CSV, JSON, HTML
74
78
 
75
79
  ---
76
80
 
@@ -81,7 +85,7 @@ pip install rag-python
81
85
  # or from source
82
86
  pip install -e .
83
87
  # with reranking, hybrid search + extra providers
84
- pip install -e ".[rerank,local,anthropic,gemini,all]"
88
+ pip install -e ".[rerank,local,hybrid,anthropic,gemini,all]"
85
89
  ```
86
90
 
87
91
  ---
@@ -103,12 +107,26 @@ answer = rag.query("How many days of annual leave?")
103
107
  print(answer.text)
104
108
  ```
105
109
 
110
+ ### Hybrid search + metadata filter
111
+
112
+ ```python
113
+ from rag_python import RAG, SearchConfig
114
+
115
+ rag = RAG(
116
+ retriever="hybrid", # pip install rag-python[hybrid]
117
+ metadata_filter={"filename": "leave-policy.pdf"},
118
+ )
119
+ rag.ingest(["./policies/leave-policy.pdf", "./policies/handbook.pdf"])
120
+ answer = rag.query("How many days of annual leave?")
121
+ ```
122
+
106
123
  ### CLI
107
124
 
108
125
  ```bash
109
126
  export OPENAI_API_KEY=sk-...
110
127
  rag-python ingest ./data --reindex
111
128
  rag-python query "How many days of annual leave?" -v
129
+ rag-python query "leave policy" --retriever hybrid --metadata-filter '{"filename": "leave-policy.pdf"}'
112
130
  ```
113
131
 
114
132
  ---
@@ -18,10 +18,11 @@ Ingest your documents, ask questions, get grounded answers — with query rewrit
18
18
  ## Features
19
19
 
20
20
  - Document pipeline: loaders → cleaning → chunking → embeddings → ChromaDB
21
- - Query pipeline: rewriting → multi-query retrieval → reranking
21
+ - Query pipeline: rewriting → multi-query / **hybrid** retrieval → reranking
22
22
  - Generation with guardrails (prompt injection + hallucination checks)
23
23
  - Evaluation scores + self-correction retry loop
24
24
  - **LLM providers:** OpenAI, Azure OpenAI, Anthropic, Gemini, Ollama
25
+ - **Loaders:** TXT, MD, PDF, DOCX, CSV, JSON, HTML
25
26
 
26
27
  ---
27
28
 
@@ -32,7 +33,7 @@ pip install rag-python
32
33
  # or from source
33
34
  pip install -e .
34
35
  # with reranking, hybrid search + extra providers
35
- pip install -e ".[rerank,local,anthropic,gemini,all]"
36
+ pip install -e ".[rerank,local,hybrid,anthropic,gemini,all]"
36
37
  ```
37
38
 
38
39
  ---
@@ -54,12 +55,26 @@ answer = rag.query("How many days of annual leave?")
54
55
  print(answer.text)
55
56
  ```
56
57
 
58
+ ### Hybrid search + metadata filter
59
+
60
+ ```python
61
+ from rag_python import RAG, SearchConfig
62
+
63
+ rag = RAG(
64
+ retriever="hybrid", # pip install rag-python[hybrid]
65
+ metadata_filter={"filename": "leave-policy.pdf"},
66
+ )
67
+ rag.ingest(["./policies/leave-policy.pdf", "./policies/handbook.pdf"])
68
+ answer = rag.query("How many days of annual leave?")
69
+ ```
70
+
57
71
  ### CLI
58
72
 
59
73
  ```bash
60
74
  export OPENAI_API_KEY=sk-...
61
75
  rag-python ingest ./data --reindex
62
76
  rag-python query "How many days of annual leave?" -v
77
+ rag-python query "leave policy" --retriever hybrid --metadata-filter '{"filename": "leave-policy.pdf"}'
63
78
  ```
64
79
 
65
80
  ---
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "rag-python"
7
- version = "0.2.0"
7
+ version = "0.3.0"
8
8
  description = "Production-grade RAG for Python: multi-LLM, query rewriting, reranking, guardrails, and evaluation."
9
9
  readme = "README.md"
10
10
  license = { text = "MIT" }
@@ -39,10 +39,11 @@ dependencies = [
39
39
  [project.optional-dependencies]
40
40
  rerank = ["sentence-transformers>=2.2.0", "torch>=2.0.0"]
41
41
  local = ["sentence-transformers>=2.2.0"]
42
+ hybrid = ["rank-bm25>=0.2.2"]
42
43
  anthropic = ["anthropic>=0.20.0"]
43
44
  gemini = ["google-genai>=0.3.0"]
44
- dev = ["pytest>=7.0", "ruff>=0.1.0", "build", "twine"]
45
- all = ["rag-python[rerank,local,anthropic,gemini]"]
45
+ dev = ["pytest>=7.0", "ruff>=0.1.0", "build", "twine", "rank-bm25>=0.2.2"]
46
+ all = ["rag-python[rerank,local,hybrid,anthropic,gemini]"]
46
47
 
47
48
  [project.scripts]
48
49
  rag-python = "rag_python.cli:main"
@@ -9,7 +9,7 @@ Quick start::
9
9
  print(rag.query("What is our leave policy?").text)
10
10
  """
11
11
 
12
- __version__ = "0.2.0"
12
+ __version__ = "0.3.0"
13
13
 
14
14
  from .client import RAG, RAGAnswer
15
15
  from .rag_pipeline import ingest, query, RAGResponse
@@ -1,5 +1,6 @@
1
1
  """rag-python command-line interface."""
2
2
  import argparse
3
+ import json
3
4
  from dataclasses import replace
4
5
 
5
6
  from . import __version__
@@ -7,7 +8,7 @@ from .client import RAG
7
8
 
8
9
 
9
10
  def _build_rag(args: argparse.Namespace) -> RAG:
10
- return RAG(
11
+ kwargs: dict = dict(
11
12
  llm_provider=args.llm_provider,
12
13
  llm_model=args.llm_model,
13
14
  embedding_provider=args.embedding_provider,
@@ -20,6 +21,20 @@ def _build_rag(args: argparse.Namespace) -> RAG:
20
21
  gemini_api_key=args.gemini_api_key,
21
22
  ollama_base_url=args.ollama_base_url,
22
23
  )
24
+ if getattr(args, "retriever", None):
25
+ kwargs["retriever"] = args.retriever
26
+ if getattr(args, "metadata_filter", None):
27
+ kwargs["metadata_filter"] = args.metadata_filter
28
+ return RAG(**kwargs)
29
+
30
+
31
+ def _parse_metadata_filter(raw: str | None) -> dict | None:
32
+ if not raw:
33
+ return None
34
+ try:
35
+ return json.loads(raw)
36
+ except json.JSONDecodeError as e:
37
+ raise argparse.ArgumentTypeError(f"Invalid JSON for metadata filter: {e}") from e
23
38
 
24
39
 
25
40
  def _add_provider_args(parser: argparse.ArgumentParser) -> None:
@@ -44,6 +59,21 @@ def _add_provider_args(parser: argparse.ArgumentParser) -> None:
44
59
  parser.add_argument("--gemini-api-key", default=None)
45
60
 
46
61
 
62
def _add_search_args(parser: argparse.ArgumentParser) -> None:
    """Register the retrieval options (``--retriever``, ``--metadata-filter``) on *parser*.

    Both default to ``None`` so the caller can fall back to the RAG config
    when the flag is not given; ``--metadata-filter`` is parsed from JSON by
    ``_parse_metadata_filter``.
    """
    retriever_help = "Retrieval strategy (default: multi_query; hybrid needs pip install rag-python[hybrid])"
    filter_help = 'Chroma metadata filter as JSON, e.g. \'{"filename": "policy.pdf"}\''

    parser.add_argument(
        "--retriever",
        default=None,
        choices=["vector", "multi_query", "hybrid"],
        help=retriever_help,
    )
    parser.add_argument(
        "--metadata-filter",
        default=None,
        type=_parse_metadata_filter,
        help=filter_help,
    )
75
+
76
+
47
77
  def main() -> None:
48
78
  parser = argparse.ArgumentParser(
49
79
  prog="rag-python",
@@ -59,9 +89,10 @@ def main() -> None:
59
89
 
60
90
  q = sub.add_parser("query", help="Ask a question against ingested documents")
61
91
  q.add_argument("question", nargs="+", help="Question text")
62
- q.add_argument("--no-multi-query", action="store_true")
92
+ q.add_argument("--no-multi-query", action="store_true", help="Use vector retriever only")
63
93
  q.add_argument("-v", "--verbose", action="store_true")
64
94
  _add_provider_args(q)
95
+ _add_search_args(q)
65
96
 
66
97
  args = parser.parse_args()
67
98
 
@@ -74,9 +105,13 @@ def main() -> None:
74
105
  if args.command == "query":
75
106
  rag = _build_rag(args)
76
107
  question = " ".join(args.question)
108
+ retriever = args.retriever
109
+ if retriever is None and args.no_multi_query:
110
+ retriever = "vector"
77
111
  search = replace(
78
112
  rag.config.search,
79
- retriever="vector" if args.no_multi_query else "multi_query",
113
+ retriever=retriever or rag.config.search.retriever,
114
+ metadata_filter=args.metadata_filter or rag.config.search.metadata_filter,
80
115
  )
81
116
  ans = rag.query(question, search=search)
82
117
  print(ans.text)
@@ -60,6 +60,7 @@ class RAG:
60
60
  chunk_size: int | None = None,
61
61
  chunk_overlap: int | None = None,
62
62
  retriever: str | None = None,
63
+ metadata_filter: dict | None = None,
63
64
  top_k_retrieve: int | None = None,
64
65
  top_k_rerank: int | None = None,
65
66
  multi_query_n: int | None = None,
@@ -104,6 +105,8 @@ class RAG:
104
105
  self.config.search = replace(self.config.search, rerank_enabled=rerank_enabled)
105
106
  if document_extensions is not None:
106
107
  self.config.documents = replace(self.config.documents, extensions=document_extensions)
108
+ if metadata_filter is not None:
109
+ self.config.search = replace(self.config.search, metadata_filter=metadata_filter)
107
110
 
108
111
  self.llm = make_llm_provider(
109
112
  llm_provider, # type: ignore[arg-type]
@@ -0,0 +1,146 @@
1
+ """Document loaders: raw data → structured text + metadata."""
2
+ import csv
3
+ import json
4
+ from html.parser import HTMLParser
5
+ from pathlib import Path
6
+ from dataclasses import dataclass
7
+ from typing import Iterator
8
+
9
+ try:
10
+ from pypdf import PdfReader
11
+ except ImportError:
12
+ PdfReader = None
13
+
14
+ try:
15
+ from docx import Document as DocxDocument
16
+ except ImportError:
17
+ DocxDocument = None
18
+
19
+
20
@dataclass
class LoadedDocument:
    """Text extracted from one file, together with what was learned while loading it.

    The loaders in this module fill ``metadata`` with at least ``source`` and
    ``filename``, plus format-specific extras such as ``rows`` (CSV),
    ``pages`` (PDF) or ``paragraphs`` (DOCX).
    """

    content: str  # extracted plain text
    source: str  # path of the originating file, as a string
    metadata: dict  # free-form key/value details about the document
26
+
27
+
28
+ class _HTMLTextExtractor(HTMLParser):
29
+ def __init__(self) -> None:
30
+ super().__init__()
31
+ self.parts: list[str] = []
32
+
33
+ def handle_data(self, data: str) -> None:
34
+ text = data.strip()
35
+ if text:
36
+ self.parts.append(text)
37
+
38
+
39
+ def _html_to_text(html: str) -> str:
40
+ parser = _HTMLTextExtractor()
41
+ parser.feed(html)
42
+ return "\n".join(parser.parts)
43
+
44
+
45
def _load_csv(path: Path, metadata: dict) -> LoadedDocument | None:
    """Render a CSV file as one text line per non-empty data row.

    With a header row, each cell becomes ``"column: value"`` (empty cells are
    omitted). Surplus cells of ragged rows, which ``csv.DictReader`` groups
    under the ``None`` key, are appended as bare values. Without a header,
    the raw cells are comma-joined. Rows that render to nothing are skipped
    and not counted.

    Side effect: sets ``metadata["rows"]`` to the number of rows kept.
    Returns ``None`` when no text remains.

    The file is decoded as ``utf-8-sig`` so a leading byte-order mark, which
    spreadsheet exports often add, does not end up in the first column name.
    ``csv.Error`` (malformed input) propagates to the caller.
    """
    rows: list[str] = []
    with path.open(encoding="utf-8-sig", errors="replace", newline="") as f:
        reader = csv.DictReader(f)
        if reader.fieldnames:
            for row in reader:
                cells: list[str] = []
                for key, value in row.items():
                    if key is None:
                        # Extra cells beyond the header arrive as a list under the None key.
                        cells.extend(extra for extra in value if extra)
                    elif value:
                        cells.append(f"{key}: {value}")
                if cells:
                    rows.append(", ".join(cells))
        else:
            f.seek(0)
            for raw in csv.reader(f):
                line = ", ".join(raw)
                if line.strip():
                    rows.append(line)
    content = "\n".join(rows)
    metadata["rows"] = len(rows)
    return LoadedDocument(content=content, source=str(path), metadata=metadata) if content.strip() else None
59
+
60
+
61
def _load_json(path: Path, metadata: dict) -> LoadedDocument | None:
    """Extract indexable text from a JSON file.

    - list: each item contributes its ``"text"`` field when it is a dict that
      has one, otherwise its compact JSON encoding; items are separated by a
      blank line.
    - dict: its ``"text"`` field if present, else the pretty-printed object.
    - anything else (string, number, ...): ``str(data)``.

    Returns ``None`` when the resulting text is blank.
    ``json.JSONDecodeError`` propagates to the caller.

    The file is decoded as ``utf-8-sig``: ``json.loads`` rejects text that
    starts with a byte-order mark, so a BOM-prefixed file would otherwise be
    treated as invalid JSON.
    """
    data = json.loads(path.read_text(encoding="utf-8-sig", errors="replace"))
    if isinstance(data, list):
        content = "\n\n".join(
            str(item["text"])
            if isinstance(item, dict) and "text" in item
            else json.dumps(item, ensure_ascii=False)
            for item in data
        )
    elif isinstance(data, dict):
        if "text" in data:
            content = str(data["text"])
        else:
            content = json.dumps(data, ensure_ascii=False, indent=2)
    else:
        content = str(data)
    return LoadedDocument(content=content, source=str(path), metadata=metadata) if content.strip() else None
79
+
80
+
81
def load_file(path: Path) -> LoadedDocument | None:
    """Load a single file (TXT, MD, HTML/HTM, CSV, JSON, PDF, DOCX) into text + metadata.

    Every returned document carries ``metadata["source"]`` (the path) and
    ``metadata["filename"]``; format-specific loaders may add more keys.

    Returns ``None`` in these cases:
    - the path does not exist;
    - the suffix is not supported;
    - the optional parser the format needs (``pypdf`` for PDF,
      ``python-docx`` for DOCX) is not installed;
    - the file cannot be parsed;
    - (HTML/CSV/JSON) no text remains.

    Missing parsers and parse failures are logged as warnings instead of
    being swallowed silently, so they show up in ingest logs.

    Text formats are decoded as ``utf-8-sig``. This is identical to UTF-8,
    except that a leading byte-order mark is dropped instead of becoming part
    of the content.
    """
    import logging  # local import keeps the module's top-level imports unchanged

    log = logging.getLogger(__name__)
    path = Path(path)
    if not path.exists():
        return None
    suffix = path.suffix.lower()
    source = str(path)
    metadata = {"source": source, "filename": path.name}

    if suffix in (".txt", ".md"):
        content = path.read_text(encoding="utf-8-sig", errors="replace")
        return LoadedDocument(content=content, source=source, metadata=metadata)

    if suffix in (".html", ".htm"):
        content = _html_to_text(path.read_text(encoding="utf-8-sig", errors="replace"))
        return LoadedDocument(content=content, source=source, metadata=metadata) if content.strip() else None

    if suffix == ".csv":
        try:
            return _load_csv(path, metadata)
        except csv.Error:
            # A malformed CSV should skip this file, not abort the whole ingest.
            log.warning("Skipping unparseable CSV file %s", path, exc_info=True)
            return None

    if suffix == ".json":
        try:
            return _load_json(path, metadata)
        except json.JSONDecodeError as e:
            log.warning("Skipping invalid JSON file %s: %s", path, e)
            return None

    if suffix == ".pdf":
        if PdfReader is None:
            log.warning("Skipping %s: PDF support requires the 'pypdf' package", path)
            return None
        try:
            reader = PdfReader(path)
            parts = [page.extract_text() or "" for page in reader.pages]
        except Exception:
            # pypdf raises many exception types on damaged files; skip this one, keep ingesting.
            log.warning("Skipping unreadable PDF %s", path, exc_info=True)
            return None
        if parts:
            # NOTE(review): list-valued metadata may be rejected if this dict reaches
            # Chroma unchanged — confirm how chunking propagates it.
            metadata["page_numbers"] = list(range(1, len(parts) + 1))
        metadata["pages"] = len(parts)
        return LoadedDocument(content="\n\n".join(parts), source=source, metadata=metadata)

    if suffix in (".docx", ".doc"):
        if DocxDocument is None:
            log.warning("Skipping %s: Word support requires the 'python-docx' package", path)
            return None
        try:
            parts = [p.text for p in DocxDocument(path).paragraphs]
        except Exception:
            # python-docx targets the .docx (OOXML) format, so legacy binary .doc
            # files are expected to fail here.
            log.warning("Skipping unreadable Word file %s", path, exc_info=True)
            return None
        metadata["paragraphs"] = len(parts)
        return LoadedDocument(content="\n\n".join(parts), source=source, metadata=metadata)

    return None
132
+
133
+
134
def load_directory(
    dir_path: Path,
    extensions: tuple = (".txt", ".md", ".pdf", ".docx", ".csv", ".json", ".html"),
) -> Iterator[LoadedDocument]:
    """Walk *dir_path* recursively and yield every loadable, non-blank document.

    A file is considered when its lowercased suffix appears in *extensions*.
    If *dir_path* is not a directory, nothing is yielded. Files for which
    ``load_file`` returns nothing, or only whitespace, are skipped.
    """
    root = Path(dir_path)
    if not root.is_dir():
        return
    matching = (
        candidate
        for candidate in root.rglob("*")
        if candidate.is_file() and candidate.suffix.lower() in extensions
    )
    for candidate in matching:
        loaded = load_file(candidate)
        if not loaded or not loaded.content.strip():
            continue
        yield loaded
@@ -0,0 +1,51 @@
1
+ """BM25 + vector fusion via reciprocal rank fusion (RRF)."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Any
5
+
6
+
7
def reciprocal_rank_fusion(
    rankings: list[list[tuple[str, dict[str, Any], float]]],
    *,
    rrf_k: int = 60,
) -> list[tuple[str, dict[str, Any], float]]:
    """Merge several ranked result lists with reciprocal rank fusion (RRF).

    Each document earns ``1 / (rrf_k + rank)`` from every ranking it appears
    in, with 1-based rank. The incoming scores are ignored; only positions
    matter.

    Documents are identified by the first 200 characters of their text plus
    their ``source`` metadata. A document repeated inside the same ranking
    is counted once, at its first (best) position, and the repeats do not
    consume a rank. ``None`` metadata is treated as an empty dict.

    Args:
        rankings: ranked lists of ``(text, metadata, score)``, best first.
        rrf_k: smoothing constant; larger values flatten the rank weighting.

    Returns:
        ``(text, metadata, fused_score)`` tuples, highest fused score first.
        For a document present in several rankings, the text/metadata from
        the last ranking containing it are returned.

    Raises:
        ValueError: if *rrf_k* is negative (it could make a denominator zero).
    """
    if rrf_k < 0:
        raise ValueError(f"rrf_k must be >= 0, got {rrf_k}")

    scores: dict[tuple[str, str], float] = {}
    doc_map: dict[tuple[str, str], tuple[str, dict[str, Any]]] = {}

    for ranking in rankings:
        seen_here: set[tuple[str, str]] = set()
        for doc, meta, _score in ranking:
            meta = meta if meta is not None else {}
            key = (doc[:200], str(meta.get("source", "")))
            if key in seen_here:
                continue
            rank = len(seen_here)  # 0-based position among this ranking's unique documents
            seen_here.add(key)
            doc_map[key] = (doc, meta)
            scores[key] = scores.get(key, 0.0) + 1.0 / (rrf_k + rank + 1)

    merged = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    return [(*doc_map[key], score) for key, score in merged]
24
+
25
+
26
def bm25_retrieve(
    query: str,
    documents: list[str],
    metadatas: list[dict[str, Any]],
    *,
    top_k: int = 20,
) -> list[tuple[str, dict[str, Any], float]]:
    """Rank *documents* against *query* with BM25 (Okapi) keyword scoring.

    Requires ``pip install rag-python[hybrid]``. Tokenization is lowercase
    plus whitespace split. ``None`` documents are treated as empty text and
    ``None`` metadata as an empty dict.

    Args:
        query: the search text.
        documents: corpus texts to score.
        metadatas: metadata aligned index-by-index with *documents*.
        top_k: maximum number of results; values <= 0 return nothing.

    Returns:
        Up to *top_k* ``(text, metadata, bm25_score)`` tuples, best first.
        Empty when there is nothing to score (no documents, or every
        document is blank).

    Raises:
        ValueError: if *documents* and *metadatas* differ in length.
        ImportError: if the ``rank_bm25`` package is not installed.
    """
    if len(documents) != len(metadatas):
        raise ValueError(
            f"documents and metadatas must have the same length ({len(documents)} != {len(metadatas)})"
        )
    # A negative top_k would otherwise slice off the tail (ranked[:-1]).
    if not documents or top_k <= 0:
        return []

    texts = [doc or "" for doc in documents]
    tokenized_corpus = [text.lower().split() for text in texts]
    if not any(tokenized_corpus):
        # rank_bm25 averages IDF over the vocabulary size, which is zero for an
        # all-blank corpus (ZeroDivisionError); nothing can match anyway.
        return []

    try:
        from rank_bm25 import BM25Okapi
    except ImportError as e:
        raise ImportError(
            "Hybrid search requires optional dependencies. Install with: pip install rag-python[hybrid]"
        ) from e

    bm25 = BM25Okapi(tokenized_corpus)
    scores = bm25.get_scores(query.lower().split())
    ranked = sorted(
        (
            (text, meta if meta is not None else {}, float(score))
            for text, meta, score in zip(texts, metadatas, scores)
        ),
        key=lambda item: item[2],
        reverse=True,
    )
    return ranked[:top_k]
@@ -16,7 +16,7 @@ from .config import (
16
16
  )
17
17
 
18
18
  ChunkStrategy = Literal["recursive", "structure_aware", "semantic"]
19
- RetrieverStrategy = Literal["vector", "multi_query"]
19
+ RetrieverStrategy = Literal["vector", "multi_query", "hybrid"]
20
20
 
21
21
 
22
22
  @dataclass
@@ -37,13 +37,14 @@ class SearchConfig:
37
37
  top_k_rerank: int = TOP_K_RERANK
38
38
  multi_query_n: int = MULTI_QUERY_N
39
39
  rerank_enabled: bool = RERANK_ENABLED
40
+ metadata_filter: dict | None = None
40
41
 
41
42
 
42
43
  @dataclass
43
44
  class DocumentConfig:
44
45
  """Which files to load and how to preprocess them."""
45
46
 
46
- extensions: tuple[str, ...] = (".txt", ".md", ".pdf", ".docx")
47
+ extensions: tuple[str, ...] = (".txt", ".md", ".pdf", ".docx", ".csv", ".json", ".html")
47
48
  clean: bool = True
48
49
  copy_to_data_dir: bool = True
49
50
 
@@ -1,4 +1,5 @@
1
1
  """Full RAG pipeline: Query → Understanding/Rewrite → Retrieval (multi-query) → Rerank → LLM → Guardrails → Eval/Retry."""
2
+ import logging
2
3
  from dataclasses import dataclass
3
4
  from pathlib import Path
4
5
 
@@ -14,6 +15,8 @@ from .providers import LLMProvider, EmbeddingProvider, make_llm_provider, make_e
14
15
  from .config import DATA_DIR, CHUNK_SIZE, CHUNK_OVERLAP, CHUNK_STRATEGY
15
16
  from .options import QueryConfig, SearchConfig
16
17
 
18
+ logger = logging.getLogger(__name__)
19
+
17
20
 
18
21
  @dataclass
19
22
  class RAGResponse:
@@ -34,7 +37,7 @@ def _load_documents(
34
37
  paths: list[Path] | None = None,
35
38
  data_path: Path | None = None,
36
39
  *,
37
- extensions: tuple[str, ...] = (".txt", ".md", ".pdf", ".docx"),
40
+ extensions: tuple[str, ...] = (".txt", ".md", ".pdf", ".docx", ".csv", ".json", ".html"),
38
41
  ) -> list[LoadedDocument]:
39
42
  """Load documents from explicit paths and/or a data directory."""
40
43
  docs: list[LoadedDocument] = []
@@ -136,12 +139,13 @@ def ingest(
136
139
  strategy = chunk_strategy or CHUNK_STRATEGY
137
140
  size = chunk_size or CHUNK_SIZE
138
141
  overlap = chunk_overlap or CHUNK_OVERLAP
139
- ext = extensions or (".txt", ".md", ".pdf", ".docx")
142
+ ext = extensions or (".txt", ".md", ".pdf", ".docx", ".csv", ".json", ".html")
140
143
  embedder = embedder or make_embedding_provider("openai")
141
144
 
142
145
  path_list = [Path(p) for p in paths] if paths else None
143
146
  root = Path(data_path) if data_path else (None if path_list else Path(DATA_DIR))
144
147
  docs = _load_documents(path_list, root, extensions=ext)
148
+ logger.info("Loaded %s documents for ingest", len(docs))
145
149
  return _ingest_documents(
146
150
  docs,
147
151
  clean=clean,
@@ -202,11 +206,13 @@ def query(
202
206
  top_k_retrieve=search_cfg.top_k_retrieve,
203
207
  top_k_rerank=search_cfg.top_k_rerank,
204
208
  rerank_enabled=search_cfg.rerank_enabled,
209
+ metadata_filter=search_cfg.metadata_filter,
205
210
  embedder=embedder,
206
211
  embedding_model=embedding_model,
207
212
  llm=llm,
208
213
  llm_model=llm_model,
209
214
  )
215
+ logger.info("Retrieved %s chunks (retriever=%s)", len(hits), search_cfg.retriever)
210
216
  context_chunks = [h[0] for h in hits]
211
217
  sources = [{"text": h[0][:200], "metadata": h[1], "score": h[2]} for h in hits]
212
218
  context_str = "\n\n".join(context_chunks)
@@ -0,0 +1,101 @@
1
+ """Retrieval: vector, multi-query, hybrid (BM25+vector), and reranking."""
2
+ from typing import Any
3
+
4
+ from .vector_store import retrieve as chroma_retrieve, list_documents
5
+ from .query_rewriting import rewrite_for_retrieval
6
+ from .reranker import rerank_with_metadata
7
+ from .hybrid_search import bm25_retrieve, reciprocal_rank_fusion
8
+ from .providers import EmbeddingProvider, LLMProvider
9
+ from .options import RetrieverStrategy
10
+ from .config import TOP_K_RETRIEVE, TOP_K_RERANK, MULTI_QUERY_N
11
+
12
+
13
def _dedupe_candidates(candidates: list[tuple[str, dict, float]]) -> list[tuple[str, dict, float]]:
    """Drop repeated candidates, keeping the first occurrence of each, in order.

    Two candidates are the same when their first 200 characters and their
    stringified ``source`` metadata ("" when absent) are equal.
    """
    kept: dict[tuple[str, str], tuple[str, dict, float]] = {}
    for text, meta, score in candidates:
        identity = (text[:200], str(meta.get("source", "")))
        # setdefault keeps the first candidate seen for an identity; dicts keep insertion order.
        kept.setdefault(identity, (text, meta, score))
    return list(kept.values())
23
+
24
+
25
def _vector_candidates(
    queries: list[str],
    *,
    embedder: EmbeddingProvider,
    embedding_model: str | None,
    top_k_retrieve: int,
    where: dict | None,
) -> list[tuple[str, dict, float]]:
    """Dense-search Chroma once per query and pool the unique hits.

    All queries are embedded with a single ``embedder.embed`` call rather
    than one call per query, saving a provider round trip per rewritten query.

    Hits are keyed on (first 200 characters, ``source``). The first
    occurrence wins, so earlier queries take precedence. Chroma distances are
    negated so that a larger score means a closer match. ``None`` metadata is
    replaced by an empty dict.

    Returns an empty list when *queries* is empty.
    """
    if not queries:
        return []
    # assumes embed() returns one vector per input text, in input order (the
    # list-in / list-out shape used by embed([q])[0] in this module) — TODO confirm per provider
    embeddings = embedder.embed(list(queries), model=embedding_model)
    seen_docs: set[tuple[str, str]] = set()
    all_candidates: list[tuple[str, dict, float]] = []
    for emb in embeddings:
        for doc, meta, dist in chroma_retrieve(emb, top_k=top_k_retrieve, where=where):
            meta = meta if meta is not None else {}
            key = (doc[:200], str(meta.get("source", "")))
            if key in seen_docs:
                continue
            seen_docs.add(key)
            all_candidates.append((doc, meta, -dist))
    return all_candidates
45
+
46
+
47
def _normalize_where(metadata_filter: dict | None) -> dict | None:
    """Turn a user-supplied metadata filter into a Chroma ``where`` clause.

    - ``None`` or ``{}`` becomes ``None`` (no filtering). Recent Chroma
      versions reject an empty ``where`` dict.
    - A multi-key equality dict such as ``{"a": 1, "b": 2}`` becomes
      ``{"$and": [{"a": 1}, {"b": 2}]}``, because Chroma expects a single
      field or operator per ``where`` dict.
    - Anything else (one key, or top-level ``$`` operators) passes through unchanged.
    """
    if not metadata_filter:
        return None
    if len(metadata_filter) > 1 and not any(str(key).startswith("$") for key in metadata_filter):
        return {"$and": [{key: value} for key, value in metadata_filter.items()]}
    return metadata_filter


def _hybrid_candidates(
    query: str,
    *,
    embedder: EmbeddingProvider,
    embedding_model: str | None,
    top_k_retrieve: int,
    where: dict | None,
) -> list[tuple[str, dict, float]]:
    """Fuse one dense (Chroma) ranking with one BM25 ranking via RRF.

    The BM25 side scores every stored chunk matching *where*, fetched through
    ``list_documents`` on each call. Its cost therefore grows with the size
    of the filtered collection.
    """
    emb = embedder.embed([query], model=embedding_model)[0]
    vector_hits = chroma_retrieve(emb, top_k=top_k_retrieve, where=where)
    vector_ranked = [(d, m, -dist) for d, m, dist in vector_hits]  # negate: larger = closer

    docs, metas = list_documents(where=where)
    bm25_ranked = bm25_retrieve(query, docs, metas, top_k=top_k_retrieve)
    fused = reciprocal_rank_fusion([vector_ranked, bm25_ranked])[:top_k_retrieve]
    return _dedupe_candidates(fused)


def retrieve(
    query: str,
    *,
    embedder: EmbeddingProvider,
    embedding_model: str | None = None,
    retriever: RetrieverStrategy = "multi_query",
    multi_query: bool | None = None,
    n_queries: int | None = None,
    top_k_retrieve: int | None = None,
    top_k_rerank: int | None = None,
    rerank_enabled: bool | None = None,
    metadata_filter: dict | None = None,
    llm: LLMProvider | None = None,
    llm_model: str | None = None,
) -> list[tuple[str, dict[str, Any], float]]:
    """Retrieve relevant chunks with the chosen strategy, then rerank them.

    Strategies (``retriever``):
    - ``"hybrid"``: one dense Chroma search fused with BM25 over the stored chunks (RRF).
    - ``"multi_query"``: *query* is replaced by the variants returned by
      ``rewrite_for_retrieval`` (when it returns any and ``n_queries > 1``),
      and each variant is searched.
    - ``"vector"``, or any other value: a single dense search for *query*.

    For the non-hybrid strategies, an explicit *multi_query* overrides
    whether query rewriting is used.

    *metadata_filter* is normalized into a Chroma ``where`` clause and
    applied to every search: an empty filter means no filtering, and a
    multi-key equality dict becomes an ``$and``.

    Falsy ``top_k_retrieve`` / ``top_k_rerank`` / ``n_queries`` fall back to
    the package defaults.

    Returns:
        ``(document_text, metadata, rerank_score)`` as produced by
        ``rerank_with_metadata``; empty when nothing is retrieved.
    """
    top_k_retrieve = top_k_retrieve or TOP_K_RETRIEVE
    top_k_rerank = top_k_rerank or TOP_K_RERANK
    n_queries = n_queries or MULTI_QUERY_N
    where = _normalize_where(metadata_filter)

    if retriever == "hybrid":
        all_candidates = _hybrid_candidates(
            query,
            embedder=embedder,
            embedding_model=embedding_model,
            top_k_retrieve=top_k_retrieve,
            where=where,
        )
    else:
        use_multi_query = retriever == "multi_query" if multi_query is None else multi_query
        queries = [query]
        if use_multi_query and n_queries > 1:
            rewritten = rewrite_for_retrieval(query, n_queries=n_queries, llm=llm, llm_model=llm_model)
            if rewritten:
                queries = rewritten
        all_candidates = _vector_candidates(
            queries,
            embedder=embedder,
            embedding_model=embedding_model,
            top_k_retrieve=top_k_retrieve,
            where=where,
        )

    if not all_candidates:
        return []

    pairs = [(doc, meta) for doc, meta, _score in all_candidates]
    return rerank_with_metadata(query, pairs, top_k=top_k_rerank, rerank_enabled=rerank_enabled)
@@ -85,6 +85,19 @@ def retrieve(
85
85
  return list(zip(docs, metas, dists))
86
86
 
87
87
 
88
def list_documents(
    *,
    where: dict | None = None,
    limit: int | None = None,
) -> tuple[list[str], list[dict[str, Any]]]:
    """Return stored chunk texts and metadata (used to build the BM25 index).

    Args:
        where: optional Chroma metadata filter. An empty dict is treated as
            "no filter", because recent Chroma versions reject ``where={}``.
        limit: optional cap on the number of chunks returned.

    Returns:
        ``(documents, metadatas)`` as returned by the collection's ``get``.
        ``None`` entries are normalized to ``""`` and ``{}`` respectively, so
        callers can tokenize and ``.get()`` without None checks.
    """
    coll = get_collection()
    res = coll.get(where=where or None, include=["documents", "metadatas"], limit=limit)
    docs = [doc or "" for doc in (res.get("documents") or [])]
    metas = [meta if meta is not None else {} for meta in (res.get("metadatas") or [])]
    return docs, metas
99
+
100
+
88
101
  def delete_all() -> None:
89
102
  """Remove all documents from the collection (for re-ingestion)."""
90
103
  _get_client().delete_collection(COLLECTION_NAME)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: rag-python
3
- Version: 0.2.0
3
+ Version: 0.3.0
4
4
  Summary: Production-grade RAG for Python: multi-LLM, query rewriting, reranking, guardrails, and evaluation.
5
5
  Author-email: Raghav Singla <04raghavsingla28@gmail.com>
6
6
  License: MIT
@@ -35,6 +35,8 @@ Requires-Dist: sentence-transformers>=2.2.0; extra == "rerank"
35
35
  Requires-Dist: torch>=2.0.0; extra == "rerank"
36
36
  Provides-Extra: local
37
37
  Requires-Dist: sentence-transformers>=2.2.0; extra == "local"
38
+ Provides-Extra: hybrid
39
+ Requires-Dist: rank-bm25>=0.2.2; extra == "hybrid"
38
40
  Provides-Extra: anthropic
39
41
  Requires-Dist: anthropic>=0.20.0; extra == "anthropic"
40
42
  Provides-Extra: gemini
@@ -44,8 +46,9 @@ Requires-Dist: pytest>=7.0; extra == "dev"
44
46
  Requires-Dist: ruff>=0.1.0; extra == "dev"
45
47
  Requires-Dist: build; extra == "dev"
46
48
  Requires-Dist: twine; extra == "dev"
49
+ Requires-Dist: rank-bm25>=0.2.2; extra == "dev"
47
50
  Provides-Extra: all
48
- Requires-Dist: rag-python[anthropic,gemini,local,rerank]; extra == "all"
51
+ Requires-Dist: rag-python[anthropic,gemini,hybrid,local,rerank]; extra == "all"
49
52
 
50
53
  # rag-python
51
54
 
@@ -67,10 +70,11 @@ Ingest your documents, ask questions, get grounded answers — with query rewrit
67
70
  ## Features
68
71
 
69
72
  - Document pipeline: loaders → cleaning → chunking → embeddings → ChromaDB
70
- - Query pipeline: rewriting → multi-query retrieval → reranking
73
+ - Query pipeline: rewriting → multi-query / **hybrid** retrieval → reranking
71
74
  - Generation with guardrails (prompt injection + hallucination checks)
72
75
  - Evaluation scores + self-correction retry loop
73
76
  - **LLM providers:** OpenAI, Azure OpenAI, Anthropic, Gemini, Ollama
77
+ - **Loaders:** TXT, MD, PDF, DOCX, CSV, JSON, HTML
74
78
 
75
79
  ---
76
80
 
@@ -81,7 +85,7 @@ pip install rag-python
81
85
  # or from source
82
86
  pip install -e .
83
87
  # with reranking, hybrid search + extra providers
84
- pip install -e ".[rerank,local,anthropic,gemini,all]"
88
+ pip install -e ".[rerank,local,hybrid,anthropic,gemini,all]"
85
89
  ```
86
90
 
87
91
  ---
@@ -103,12 +107,26 @@ answer = rag.query("How many days of annual leave?")
103
107
  print(answer.text)
104
108
  ```
105
109
 
110
+ ### Hybrid search + metadata filter
111
+
112
+ ```python
113
+ from rag_python import RAG, SearchConfig
114
+
115
+ rag = RAG(
116
+ retriever="hybrid", # pip install rag-python[hybrid]
117
+ metadata_filter={"filename": "leave-policy.pdf"},
118
+ )
119
+ rag.ingest(["./policies/leave-policy.pdf", "./policies/handbook.pdf"])
120
+ answer = rag.query("How many days of annual leave?")
121
+ ```
122
+
106
123
  ### CLI
107
124
 
108
125
  ```bash
109
126
  export OPENAI_API_KEY=sk-...
110
127
  rag-python ingest ./data --reindex
111
128
  rag-python query "How many days of annual leave?" -v
129
+ rag-python query "leave policy" --retriever hybrid --metadata-filter '{"filename": "leave-policy.pdf"}'
112
130
  ```
113
131
 
114
132
  ---
@@ -11,6 +11,7 @@ src/rag_python/document_loaders.py
11
11
  src/rag_python/evaluation.py
12
12
  src/rag_python/generation.py
13
13
  src/rag_python/guardrails.py
14
+ src/rag_python/hybrid_search.py
14
15
  src/rag_python/options.py
15
16
  src/rag_python/py.typed
16
17
  src/rag_python/query_rewriting.py
@@ -35,8 +36,10 @@ src/rag_python/providers/ollama_provider.py
35
36
  src/rag_python/providers/openai_provider.py
36
37
  tests/test_chunking.py
37
38
  tests/test_config.py
39
+ tests/test_hybrid_search.py
38
40
  tests/test_import.py
39
41
  tests/test_loaders.py
40
42
  tests/test_package.py
41
43
  tests/test_pipeline.py
42
- tests/test_providers.py
44
+ tests/test_providers.py
45
+ tests/test_retrieval.py
@@ -9,7 +9,7 @@ python-dotenv>=1.0.0
9
9
  requests>=2.31.0
10
10
 
11
11
  [all]
12
- rag-python[anthropic,gemini,local,rerank]
12
+ rag-python[anthropic,gemini,hybrid,local,rerank]
13
13
 
14
14
  [anthropic]
15
15
  anthropic>=0.20.0
@@ -19,10 +19,14 @@ pytest>=7.0
19
19
  ruff>=0.1.0
20
20
  build
21
21
  twine
22
+ rank-bm25>=0.2.2
22
23
 
23
24
  [gemini]
24
25
  google-genai>=0.3.0
25
26
 
27
+ [hybrid]
28
+ rank-bm25>=0.2.2
29
+
26
30
  [local]
27
31
  sentence-transformers>=2.2.0
28
32
 
@@ -0,0 +1,35 @@
1
+ import pytest
2
+
3
+ from rag_python.hybrid_search import bm25_retrieve, reciprocal_rank_fusion
4
+
5
+
6
+ def test_reciprocal_rank_fusion_prefers_shared_docs():
7
+ vector = [
8
+ ("doc a", {"source": "a"}, 0.9),
9
+ ("doc b", {"source": "b"}, 0.8),
10
+ ]
11
+ bm25 = [
12
+ ("doc b", {"source": "b"}, 0.95),
13
+ ("doc c", {"source": "c"}, 0.7),
14
+ ]
15
+ merged = reciprocal_rank_fusion([vector, bm25])
16
+ assert len(merged) == 3
17
+ assert merged[0][0] == "doc b"
18
+
19
+
20
+ def test_bm25_retrieve_ranks_relevant_doc():
21
+ docs = [
22
+ "annual leave policy grants twenty days per year",
23
+ "office cafeteria menu and lunch hours",
24
+ ]
25
+ metas = [{"source": "policy.txt"}, {"source": "cafe.txt"}]
26
+ try:
27
+ hits = bm25_retrieve("annual leave days", docs, metas, top_k=1)
28
+ except ImportError:
29
+ pytest.skip("rank_bm25 not installed")
30
+ assert hits[0][0] == docs[0]
31
+ assert hits[0][1]["source"] == "policy.txt"
32
+
33
+
34
+ def test_bm25_retrieve_empty_corpus():
35
+ assert bm25_retrieve("query", [], [], top_k=5) == []
@@ -20,6 +20,32 @@ def test_load_markdown_file(tmp_path: Path):
20
20
  assert "Title" in doc.content
21
21
 
22
22
 
23
+ def test_load_csv_file(tmp_path: Path):
24
+ f = tmp_path / "data.csv"
25
+ f.write_text("name,days\nAlice,20\nBob,15\n", encoding="utf-8")
26
+ doc = load_file(f)
27
+ assert doc is not None
28
+ assert "Alice" in doc.content
29
+ assert doc.metadata.get("rows") == 2
30
+
31
+
32
+ def test_load_json_file(tmp_path: Path):
33
+ f = tmp_path / "data.json"
34
+ f.write_text('[{"text": "Annual leave is twenty days."}]', encoding="utf-8")
35
+ doc = load_file(f)
36
+ assert doc is not None
37
+ assert "twenty days" in doc.content
38
+
39
+
40
+ def test_load_html_file(tmp_path: Path):
41
+ f = tmp_path / "page.html"
42
+ f.write_text("<html><body><h1>Policy</h1><p>Twenty days leave.</p></body></html>", encoding="utf-8")
43
+ doc = load_file(f)
44
+ assert doc is not None
45
+ assert "Policy" in doc.content
46
+ assert "Twenty days" in doc.content
47
+
48
+
23
49
  def test_load_directory_skips_empty_files(tmp_path: Path):
24
50
  (tmp_path / "a.txt").write_text("content a", encoding="utf-8")
25
51
  (tmp_path / "empty.txt").write_text(" ", encoding="utf-8")
@@ -4,7 +4,7 @@ import importlib.metadata
4
4
  def test_package_metadata():
5
5
  dist = importlib.metadata.metadata("rag-python")
6
6
  assert dist["Name"] == "rag-python"
7
- assert dist["Version"] == "0.2.0"
7
+ assert dist["Version"] == "0.3.0"
8
8
  author = dist.get("Author") or dist.get("Author-email") or ""
9
9
  assert "Raghav Singla" in author or "RaghavOG" in author
10
10
 
@@ -0,0 +1,52 @@
1
+ from unittest.mock import MagicMock, patch
2
+
3
+ from rag_python.retrieval import retrieve
4
+
5
+
6
+ def test_hybrid_retriever_fuses_vector_and_bm25():
7
+ embedder = MagicMock()
8
+ embedder.embed.return_value = [[0.1, 0.2]]
9
+
10
+ vector_hits = [("vector doc", {"source": "v.txt"}, 0.1)]
11
+ bm25_hits = [("bm25 doc", {"source": "b.txt"}, 1.5)]
12
+ fused = [("vector doc", {"source": "v.txt"}, 0.5), ("bm25 doc", {"source": "b.txt"}, 0.4)]
13
+
14
+ with (
15
+ patch("rag_python.retrieval.chroma_retrieve", return_value=vector_hits) as mock_chroma,
16
+ patch("rag_python.retrieval.list_documents", return_value=(["bm25 doc"], [{"source": "b.txt"}])),
17
+ patch("rag_python.retrieval.bm25_retrieve", return_value=bm25_hits) as mock_bm25,
18
+ patch("rag_python.retrieval.reciprocal_rank_fusion", return_value=fused) as mock_rrf,
19
+ patch("rag_python.retrieval.rerank_with_metadata", side_effect=lambda q, pairs, **kw: pairs),
20
+ ):
21
+ hits = retrieve(
22
+ "leave policy",
23
+ embedder=embedder,
24
+ retriever="hybrid",
25
+ rerank_enabled=False,
26
+ metadata_filter={"filename": "policy.txt"},
27
+ )
28
+
29
+ mock_chroma.assert_called_once()
30
+ assert mock_chroma.call_args.kwargs["where"] == {"filename": "policy.txt"}
31
+ mock_bm25.assert_called_once()
32
+ mock_rrf.assert_called_once()
33
+ assert len(hits) == 2
34
+
35
+
36
+ def test_vector_retriever_passes_metadata_filter():
37
+ embedder = MagicMock()
38
+ embedder.embed.return_value = [[0.5, 0.5]]
39
+
40
+ with (
41
+ patch("rag_python.retrieval.chroma_retrieve", return_value=[]) as mock_chroma,
42
+ patch("rag_python.retrieval.rerank_with_metadata", return_value=[]),
43
+ ):
44
+ retrieve(
45
+ "question",
46
+ embedder=embedder,
47
+ retriever="vector",
48
+ metadata_filter={"source": "/data/policy.txt"},
49
+ )
50
+
51
+ mock_chroma.assert_called_once()
52
+ assert mock_chroma.call_args.kwargs["where"] == {"source": "/data/policy.txt"}
@@ -1,74 +0,0 @@
1
- """Document loaders: raw data → structured text + metadata."""
2
- from pathlib import Path
3
- from dataclasses import dataclass
4
- from typing import Iterator
5
-
6
- try:
7
- from pypdf import PdfReader
8
- except ImportError:
9
- PdfReader = None
10
-
11
- try:
12
- from docx import Document as DocxDocument
13
- except ImportError:
14
- DocxDocument = None
15
-
16
-
17
- @dataclass
18
- class LoadedDocument:
19
- """Single document with content and metadata."""
20
- content: str
21
- source: str
22
- metadata: dict
23
-
24
-
25
- def load_file(path: Path) -> LoadedDocument | None:
26
- """Load a single file (PDF, TXT, DOCX, MD) into text + metadata."""
27
- path = Path(path)
28
- if not path.exists():
29
- return None
30
- suffix = path.suffix.lower()
31
- metadata = {"source": str(path), "filename": path.name}
32
-
33
- if suffix == ".txt" or suffix == ".md":
34
- content = path.read_text(encoding="utf-8", errors="replace")
35
- return LoadedDocument(content=content, source=str(path), metadata=metadata)
36
-
37
- if suffix == ".pdf" and PdfReader:
38
- try:
39
- reader = PdfReader(path)
40
- parts = []
41
- for i, page in enumerate(reader.pages):
42
- text = page.extract_text() or ""
43
- parts.append(text)
44
- metadata.setdefault("page_numbers", []).append(i + 1)
45
- content = "\n\n".join(parts)
46
- metadata["pages"] = len(parts)
47
- return LoadedDocument(content=content, source=str(path), metadata=metadata)
48
- except Exception:
49
- return None
50
-
51
- if suffix in (".docx", ".doc") and DocxDocument:
52
- try:
53
- doc = DocxDocument(path)
54
- parts = [p.text for p in doc.paragraphs]
55
- content = "\n\n".join(parts)
56
- metadata["paragraphs"] = len(parts)
57
- return LoadedDocument(content=content, source=str(path), metadata=metadata)
58
- except Exception:
59
- return None
60
-
61
- return None
62
-
63
-
64
- def load_directory(dir_path: Path, extensions: tuple = (".txt", ".md", ".pdf", ".docx")) -> Iterator[LoadedDocument]:
65
- """Yield LoadedDocument for each supported file under dir_path."""
66
- dir_path = Path(dir_path)
67
- if not dir_path.is_dir():
68
- return
69
- for f in dir_path.rglob("*"):
70
- if f.is_file() and f.suffix.lower() in extensions:
71
- doc = load_file(f)
72
- if doc and doc.content.strip():
73
- yield doc
74
-
@@ -1,61 +0,0 @@
1
- """Retrieval: multi-query retrieval + reranking."""
2
- from typing import Any
3
-
4
- from .vector_store import retrieve as chroma_retrieve
5
- from .query_rewriting import rewrite_for_retrieval
6
- from .reranker import rerank_with_metadata
7
- from .providers import EmbeddingProvider, LLMProvider
8
- from .options import RetrieverStrategy
9
- from .config import TOP_K_RETRIEVE, TOP_K_RERANK, MULTI_QUERY_N
10
-
11
-
12
- def retrieve(
13
- query: str,
14
- *,
15
- embedder: EmbeddingProvider,
16
- embedding_model: str | None = None,
17
- retriever: RetrieverStrategy = "multi_query",
18
- multi_query: bool | None = None,
19
- n_queries: int | None = None,
20
- top_k_retrieve: int | None = None,
21
- top_k_rerank: int | None = None,
22
- rerank_enabled: bool | None = None,
23
- llm: LLMProvider | None = None,
24
- llm_model: str | None = None,
25
- ) -> list[tuple[str, dict[str, Any], float]]:
26
- """
27
- Retrieve relevant chunks using vector or multi-query search, then rerank.
28
- Returns list of (document_text, metadata, rerank_score).
29
- """
30
- top_k_retrieve = top_k_retrieve or TOP_K_RETRIEVE
31
- top_k_rerank = top_k_rerank or TOP_K_RERANK
32
- n_queries = n_queries or MULTI_QUERY_N
33
- use_multi_query = retriever == "multi_query" if multi_query is None else multi_query
34
-
35
- queries = [query]
36
- if use_multi_query and n_queries > 1:
37
- rewritten = rewrite_for_retrieval(query, n_queries=n_queries, llm=llm, llm_model=llm_model)
38
- if rewritten:
39
- queries = rewritten
40
-
41
- seen_docs: set[str] = set()
42
- all_candidates: list[tuple[str, dict, float]] = []
43
- for q in queries:
44
- emb = embedder.embed([q], model=embedding_model)[0]
45
- hits = chroma_retrieve(emb, top_k=top_k_retrieve)
46
- for doc, meta, dist in hits:
47
- key = (doc[:200], meta.get("source", ""))
48
- if key in seen_docs:
49
- continue
50
- seen_docs.add(key)
51
- all_candidates.append((doc, meta, -dist))
52
-
53
- if not all_candidates:
54
- return []
55
- docs = [c[0] for c in all_candidates]
56
- metas = [c[1] for c in all_candidates]
57
- reranked = rerank_with_metadata(
58
- query, list(zip(docs, metas)), top_k=top_k_rerank, rerank_enabled=rerank_enabled
59
- )
60
- return reranked
61
-
File without changes
File without changes