rag-python 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rag_python/__init__.py +1 -1
- rag_python/cli.py +55 -5
- rag_python/client.py +3 -0
- rag_python/document_loaders.py +76 -4
- rag_python/hybrid_search.py +51 -0
- rag_python/options.py +3 -2
- rag_python/providers/factory.py +4 -1
- rag_python/providers/local_provider.py +34 -0
- rag_python/rag_pipeline.py +8 -2
- rag_python/retrieval.py +63 -23
- rag_python/vector_store.py +13 -0
- {rag_python-0.1.0.dist-info → rag_python-0.3.0.dist-info}/METADATA +26 -4
- {rag_python-0.1.0.dist-info → rag_python-0.3.0.dist-info}/RECORD +17 -15
- {rag_python-0.1.0.dist-info → rag_python-0.3.0.dist-info}/LICENSE +0 -0
- {rag_python-0.1.0.dist-info → rag_python-0.3.0.dist-info}/WHEEL +0 -0
- {rag_python-0.1.0.dist-info → rag_python-0.3.0.dist-info}/entry_points.txt +0 -0
- {rag_python-0.1.0.dist-info → rag_python-0.3.0.dist-info}/top_level.txt +0 -0
rag_python/__init__.py
CHANGED
rag_python/cli.py
CHANGED
|
@@ -1,11 +1,14 @@
|
|
|
1
1
|
"""rag-python command-line interface."""
|
|
2
2
|
import argparse
|
|
3
|
+
import json
|
|
4
|
+
from dataclasses import replace
|
|
3
5
|
|
|
6
|
+
from . import __version__
|
|
4
7
|
from .client import RAG
|
|
5
8
|
|
|
6
9
|
|
|
7
10
|
def _build_rag(args: argparse.Namespace) -> RAG:
|
|
8
|
-
|
|
11
|
+
kwargs: dict = dict(
|
|
9
12
|
llm_provider=args.llm_provider,
|
|
10
13
|
llm_model=args.llm_model,
|
|
11
14
|
embedding_provider=args.embedding_provider,
|
|
@@ -18,12 +21,34 @@ def _build_rag(args: argparse.Namespace) -> RAG:
|
|
|
18
21
|
gemini_api_key=args.gemini_api_key,
|
|
19
22
|
ollama_base_url=args.ollama_base_url,
|
|
20
23
|
)
|
|
24
|
+
if getattr(args, "retriever", None):
|
|
25
|
+
kwargs["retriever"] = args.retriever
|
|
26
|
+
if getattr(args, "metadata_filter", None):
|
|
27
|
+
kwargs["metadata_filter"] = args.metadata_filter
|
|
28
|
+
return RAG(**kwargs)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _parse_metadata_filter(raw: str | None) -> dict | None:
|
|
32
|
+
if not raw:
|
|
33
|
+
return None
|
|
34
|
+
try:
|
|
35
|
+
return json.loads(raw)
|
|
36
|
+
except json.JSONDecodeError as e:
|
|
37
|
+
raise argparse.ArgumentTypeError(f"Invalid JSON for metadata filter: {e}") from e
|
|
21
38
|
|
|
22
39
|
|
|
23
40
|
def _add_provider_args(parser: argparse.ArgumentParser) -> None:
|
|
24
|
-
parser.add_argument(
|
|
41
|
+
parser.add_argument(
|
|
42
|
+
"--llm-provider",
|
|
43
|
+
default="openai",
|
|
44
|
+
choices=["openai", "azure_openai", "anthropic", "gemini", "ollama"],
|
|
45
|
+
)
|
|
25
46
|
parser.add_argument("--llm-model", default=None)
|
|
26
|
-
parser.add_argument(
|
|
47
|
+
parser.add_argument(
|
|
48
|
+
"--embedding-provider",
|
|
49
|
+
default="openai",
|
|
50
|
+
choices=["openai", "azure_openai", "ollama", "local"],
|
|
51
|
+
)
|
|
27
52
|
parser.add_argument("--embedding-model", default=None)
|
|
28
53
|
parser.add_argument("--ollama-base-url", default=None)
|
|
29
54
|
parser.add_argument("--azure-endpoint", default=None)
|
|
@@ -34,11 +59,27 @@ def _add_provider_args(parser: argparse.ArgumentParser) -> None:
|
|
|
34
59
|
parser.add_argument("--gemini-api-key", default=None)
|
|
35
60
|
|
|
36
61
|
|
|
62
|
+
def _add_search_args(parser: argparse.ArgumentParser) -> None:
    """Register the retrieval-related options on ``parser``.

    Adds ``--retriever`` (one of ``vector``, ``multi_query``, ``hybrid``) and
    ``--metadata-filter`` (JSON text decoded by ``_parse_metadata_filter``).
    Both default to ``None`` so callers can tell "not given" from a value.
    """
    retriever_choices = ["vector", "multi_query", "hybrid"]
    retriever_help = "Retrieval strategy (default: multi_query; hybrid needs pip install rag-python[hybrid])"
    parser.add_argument(
        "--retriever",
        default=None,
        choices=retriever_choices,
        help=retriever_help,
    )

    filter_help = 'Chroma metadata filter as JSON, e.g. \'{"filename": "policy.pdf"}\''
    parser.add_argument(
        "--metadata-filter",
        default=None,
        type=_parse_metadata_filter,
        help=filter_help,
    )
|
|
75
|
+
|
|
76
|
+
|
|
37
77
|
def main() -> None:
|
|
38
78
|
parser = argparse.ArgumentParser(
|
|
39
79
|
prog="rag-python",
|
|
40
80
|
description="rag-python — modular RAG with query rewriting, reranking, guardrails, and multi-LLM support.",
|
|
41
81
|
)
|
|
82
|
+
parser.add_argument("--version", action="version", version=f"rag-python {__version__}")
|
|
42
83
|
sub = parser.add_subparsers(dest="command", required=True)
|
|
43
84
|
|
|
44
85
|
ing = sub.add_parser("ingest", help="Ingest files/folders into the vector store")
|
|
@@ -48,9 +89,10 @@ def main() -> None:
|
|
|
48
89
|
|
|
49
90
|
q = sub.add_parser("query", help="Ask a question against ingested documents")
|
|
50
91
|
q.add_argument("question", nargs="+", help="Question text")
|
|
51
|
-
q.add_argument("--no-multi-query", action="store_true")
|
|
92
|
+
q.add_argument("--no-multi-query", action="store_true", help="Use vector retriever only")
|
|
52
93
|
q.add_argument("-v", "--verbose", action="store_true")
|
|
53
94
|
_add_provider_args(q)
|
|
95
|
+
_add_search_args(q)
|
|
54
96
|
|
|
55
97
|
args = parser.parse_args()
|
|
56
98
|
|
|
@@ -63,7 +105,15 @@ def main() -> None:
|
|
|
63
105
|
if args.command == "query":
|
|
64
106
|
rag = _build_rag(args)
|
|
65
107
|
question = " ".join(args.question)
|
|
66
|
-
|
|
108
|
+
retriever = args.retriever
|
|
109
|
+
if retriever is None and args.no_multi_query:
|
|
110
|
+
retriever = "vector"
|
|
111
|
+
search = replace(
|
|
112
|
+
rag.config.search,
|
|
113
|
+
retriever=retriever or rag.config.search.retriever,
|
|
114
|
+
metadata_filter=args.metadata_filter or rag.config.search.metadata_filter,
|
|
115
|
+
)
|
|
116
|
+
ans = rag.query(question, search=search)
|
|
67
117
|
print(ans.text)
|
|
68
118
|
if args.verbose:
|
|
69
119
|
print("\n--- evaluation ---")
|
rag_python/client.py
CHANGED
|
@@ -60,6 +60,7 @@ class RAG:
|
|
|
60
60
|
chunk_size: int | None = None,
|
|
61
61
|
chunk_overlap: int | None = None,
|
|
62
62
|
retriever: str | None = None,
|
|
63
|
+
metadata_filter: dict | None = None,
|
|
63
64
|
top_k_retrieve: int | None = None,
|
|
64
65
|
top_k_rerank: int | None = None,
|
|
65
66
|
multi_query_n: int | None = None,
|
|
@@ -104,6 +105,8 @@ class RAG:
|
|
|
104
105
|
self.config.search = replace(self.config.search, rerank_enabled=rerank_enabled)
|
|
105
106
|
if document_extensions is not None:
|
|
106
107
|
self.config.documents = replace(self.config.documents, extensions=document_extensions)
|
|
108
|
+
if metadata_filter is not None:
|
|
109
|
+
self.config.search = replace(self.config.search, metadata_filter=metadata_filter)
|
|
107
110
|
|
|
108
111
|
self.llm = make_llm_provider(
|
|
109
112
|
llm_provider, # type: ignore[arg-type]
|
rag_python/document_loaders.py
CHANGED
|
@@ -1,4 +1,7 @@
|
|
|
1
1
|
"""Document loaders: raw data → structured text + metadata."""
|
|
2
|
+
import csv
|
|
3
|
+
import json
|
|
4
|
+
from html.parser import HTMLParser
|
|
2
5
|
from pathlib import Path
|
|
3
6
|
from dataclasses import dataclass
|
|
4
7
|
from typing import Iterator
|
|
@@ -22,18 +25,85 @@ class LoadedDocument:
|
|
|
22
25
|
metadata: dict
|
|
23
26
|
|
|
24
27
|
|
|
28
|
+
class _HTMLTextExtractor(HTMLParser):
    """Collect the visible text nodes of an HTML document.

    Each non-empty text node is stripped and appended to ``parts`` in
    document order. Text inside ``<script>`` and ``<style>`` elements is
    skipped, because it is code/styling rather than document content.
    """

    # Elements whose bodies must not end up in the extracted text.
    _SKIPPED_TAGS = frozenset({"script", "style"})

    def __init__(self) -> None:
        super().__init__()
        self.parts: list[str] = []
        # Nesting depth of currently open skipped elements.
        self._skip_depth = 0

    def handle_starttag(self, tag: str, attrs: list) -> None:
        if tag in self._SKIPPED_TAGS:
            self._skip_depth += 1

    def handle_endtag(self, tag: str) -> None:
        # Guard against stray closing tags so the depth never goes negative.
        if tag in self._SKIPPED_TAGS and self._skip_depth:
            self._skip_depth -= 1

    def handle_data(self, data: str) -> None:
        if self._skip_depth:
            return
        text = data.strip()
        if text:
            self.parts.append(text)


def _html_to_text(html: str) -> str:
    """Return the text content of ``html``, one text node per line.

    Args:
        html: Raw HTML markup.

    Returns:
        The stripped text nodes joined with newlines; an empty string when
        the markup contains no text.
    """
    parser = _HTMLTextExtractor()
    parser.feed(html)
    # feed() may keep trailing data buffered (e.g. text ending in a bare
    # "&"); close() forces it through handle_data so nothing is lost.
    parser.close()
    return "\n".join(parser.parts)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _load_csv(path: Path, metadata: dict) -> LoadedDocument | None:
    """Load a CSV file as one text line per data row.

    When the file has a header row, each data row is rendered as
    ``"column: value, column: value"`` with empty cells omitted. Cells beyond
    the header width (which ``csv.DictReader`` stores under the ``None`` key)
    are appended as bare values rather than as ``"None: [...]"``. When no
    header can be read, rows are rendered as comma-joined cell values. Rows
    that render to nothing are skipped.

    Args:
        path: Location of the CSV file.
        metadata: Metadata mapping for the document; updated in place with a
            ``"rows"`` entry holding the number of rendered rows.

    Returns:
        A ``LoadedDocument`` with the rendered rows, or ``None`` when the file
        yields no text.
    """
    rows: list[str] = []
    with path.open(encoding="utf-8", errors="replace", newline="") as f:
        reader = csv.DictReader(f)
        if reader.fieldnames:
            for row in reader:
                cells: list[str] = []
                for key, value in row.items():
                    if not value:
                        continue
                    if key is None:
                        # Overflow cells: a list of values with no column name.
                        cells.extend(extra for extra in value if extra)
                    else:
                        cells.append(f"{key}: {value}")
                if cells:
                    rows.append(", ".join(cells))
        else:
            # No usable header: rewind and treat every line as plain values.
            f.seek(0)
            for raw_row in csv.reader(f):
                if any(raw_row):
                    rows.append(", ".join(raw_row))
    content = "\n".join(rows)
    metadata["rows"] = len(rows)
    return LoadedDocument(content=content, source=str(path), metadata=metadata) if content.strip() else None
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _load_json(path: Path, metadata: dict) -> LoadedDocument | None:
    """Load a JSON file and flatten it to text.

    A top-level list becomes one paragraph per item: the item's ``"text"``
    value when the item is a dict that has one, otherwise the item dumped as
    JSON. A top-level dict becomes its ``"text"`` value when present,
    otherwise the whole dict pretty-printed. Any other value is converted
    with ``str``.

    Returns:
        A ``LoadedDocument`` for the flattened text, or ``None`` when that
        text is blank. ``json.JSONDecodeError`` from malformed input is not
        caught here.
    """
    payload = json.loads(path.read_text(encoding="utf-8", errors="replace"))

    if isinstance(payload, list):
        content = "\n\n".join(
            str(entry["text"])
            if isinstance(entry, dict) and "text" in entry
            else json.dumps(entry, ensure_ascii=False)
            for entry in payload
        )
    elif isinstance(payload, dict):
        content = (
            str(payload["text"])
            if "text" in payload
            else json.dumps(payload, ensure_ascii=False, indent=2)
        )
    else:
        content = str(payload)

    if not content.strip():
        return None
    return LoadedDocument(content=content, source=str(path), metadata=metadata)
|
|
79
|
+
|
|
80
|
+
|
|
25
81
|
def load_file(path: Path) -> LoadedDocument | None:
|
|
26
|
-
"""Load a single file (PDF, TXT, DOCX, MD) into text + metadata."""
|
|
82
|
+
"""Load a single file (PDF, TXT, DOCX, MD, CSV, JSON, HTML) into text + metadata."""
|
|
27
83
|
path = Path(path)
|
|
28
84
|
if not path.exists():
|
|
29
85
|
return None
|
|
30
86
|
suffix = path.suffix.lower()
|
|
31
87
|
metadata = {"source": str(path), "filename": path.name}
|
|
32
88
|
|
|
33
|
-
if suffix
|
|
89
|
+
if suffix in (".txt", ".md"):
|
|
34
90
|
content = path.read_text(encoding="utf-8", errors="replace")
|
|
35
91
|
return LoadedDocument(content=content, source=str(path), metadata=metadata)
|
|
36
92
|
|
|
93
|
+
if suffix == ".html":
|
|
94
|
+
html = path.read_text(encoding="utf-8", errors="replace")
|
|
95
|
+
content = _html_to_text(html)
|
|
96
|
+
return LoadedDocument(content=content, source=str(path), metadata=metadata) if content.strip() else None
|
|
97
|
+
|
|
98
|
+
if suffix == ".csv":
|
|
99
|
+
return _load_csv(path, metadata)
|
|
100
|
+
|
|
101
|
+
if suffix == ".json":
|
|
102
|
+
try:
|
|
103
|
+
return _load_json(path, metadata)
|
|
104
|
+
except json.JSONDecodeError:
|
|
105
|
+
return None
|
|
106
|
+
|
|
37
107
|
if suffix == ".pdf" and PdfReader:
|
|
38
108
|
try:
|
|
39
109
|
reader = PdfReader(path)
|
|
@@ -61,7 +131,10 @@ def load_file(path: Path) -> LoadedDocument | None:
|
|
|
61
131
|
return None
|
|
62
132
|
|
|
63
133
|
|
|
64
|
-
def load_directory(
|
|
134
|
+
def load_directory(
|
|
135
|
+
dir_path: Path,
|
|
136
|
+
extensions: tuple = (".txt", ".md", ".pdf", ".docx", ".csv", ".json", ".html"),
|
|
137
|
+
) -> Iterator[LoadedDocument]:
|
|
65
138
|
"""Yield LoadedDocument for each supported file under dir_path."""
|
|
66
139
|
dir_path = Path(dir_path)
|
|
67
140
|
if not dir_path.is_dir():
|
|
@@ -71,4 +144,3 @@ def load_directory(dir_path: Path, extensions: tuple = (".txt", ".md", ".pdf", "
|
|
|
71
144
|
doc = load_file(f)
|
|
72
145
|
if doc and doc.content.strip():
|
|
73
146
|
yield doc
|
|
74
|
-
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
"""BM25 + vector fusion via reciprocal rank fusion (RRF)."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def reciprocal_rank_fusion(
    rankings: list[list[tuple[str, dict[str, Any], float]]],
    *,
    rrf_k: int = 60,
) -> list[tuple[str, dict[str, Any], float]]:
    """Fuse several ranked result lists with reciprocal rank fusion.

    Every appearance of an entry at 1-based position ``p`` in a ranking adds
    ``1 / (rrf_k + p)`` to that entry's fused score; the incoming per-list
    scores are ignored. Entries are identified by the first 200 characters of
    their text together with their ``"source"`` metadata value, and the most
    recently seen text/metadata is kept for each identity.

    Returns:
        ``(text, metadata, fused_score)`` tuples, highest fused score first.
    """
    fused_scores: dict[tuple[str, str], float] = {}
    latest_entry: dict[tuple[str, str], tuple[str, dict[str, Any]]] = {}

    for ranked_list in rankings:
        for position, (text, metadata, _unused) in enumerate(ranked_list, start=1):
            fingerprint = (text[:200], str(metadata.get("source", "")))
            latest_entry[fingerprint] = (text, metadata)
            fused_scores[fingerprint] = fused_scores.get(fingerprint, 0.0) + 1.0 / (rrf_k + position)

    # Stable sort over insertion order keeps first-seen order among ties.
    ordered = sorted(fused_scores, key=fused_scores.__getitem__, reverse=True)
    return [(*latest_entry[fingerprint], fused_scores[fingerprint]) for fingerprint in ordered]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def bm25_retrieve(
    query: str,
    documents: list[str],
    metadatas: list[dict[str, Any]],
    *,
    top_k: int = 20,
) -> list[tuple[str, dict[str, Any], float]]:
    """Rank ``documents`` against ``query`` with BM25 keyword scoring.

    Texts are lower-cased and split on whitespace before scoring. The
    ``rank_bm25`` package is imported only when there is something to score
    (``pip install rag-python[hybrid]``).

    Returns:
        Up to ``top_k`` ``(text, metadata, score)`` tuples, best score first;
        an empty list when ``documents`` is empty.

    Raises:
        ImportError: If ``rank_bm25`` is not installed.
    """
    if not documents:
        return []
    try:
        from rank_bm25 import BM25Okapi
    except ImportError as e:
        raise ImportError(
            "Hybrid search requires optional dependencies. Install with: pip install rag-python[hybrid]"
        ) from e

    corpus_tokens = [text.lower().split() for text in documents]
    index = BM25Okapi(corpus_tokens)
    query_scores = index.get_scores(query.lower().split())

    scored = [
        (text, metadatas[position], float(query_scores[position]))
        for position, text in enumerate(documents)
    ]
    scored.sort(key=lambda entry: entry[2], reverse=True)
    return scored[:top_k]
|
rag_python/options.py
CHANGED
|
@@ -16,7 +16,7 @@ from .config import (
|
|
|
16
16
|
)
|
|
17
17
|
|
|
18
18
|
ChunkStrategy = Literal["recursive", "structure_aware", "semantic"]
|
|
19
|
-
RetrieverStrategy = Literal["vector", "multi_query"]
|
|
19
|
+
RetrieverStrategy = Literal["vector", "multi_query", "hybrid"]
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
@dataclass
|
|
@@ -37,13 +37,14 @@ class SearchConfig:
|
|
|
37
37
|
top_k_rerank: int = TOP_K_RERANK
|
|
38
38
|
multi_query_n: int = MULTI_QUERY_N
|
|
39
39
|
rerank_enabled: bool = RERANK_ENABLED
|
|
40
|
+
metadata_filter: dict | None = None
|
|
40
41
|
|
|
41
42
|
|
|
42
43
|
@dataclass
|
|
43
44
|
class DocumentConfig:
|
|
44
45
|
"""Which files to load and how to preprocess them."""
|
|
45
46
|
|
|
46
|
-
extensions: tuple[str, ...] = (".txt", ".md", ".pdf", ".docx")
|
|
47
|
+
extensions: tuple[str, ...] = (".txt", ".md", ".pdf", ".docx", ".csv", ".json", ".html")
|
|
47
48
|
clean: bool = True
|
|
48
49
|
copy_to_data_dir: bool = True
|
|
49
50
|
|
rag_python/providers/factory.py
CHANGED
|
@@ -9,10 +9,11 @@ from .azure_openai_provider import AzureOpenAIProvider
|
|
|
9
9
|
from .anthropic_provider import AnthropicProvider
|
|
10
10
|
from .gemini_provider import GeminiProvider
|
|
11
11
|
from .ollama_provider import OllamaProvider
|
|
12
|
+
from .local_provider import LocalEmbeddingProvider
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
LLMProviderName = Literal["openai", "azure_openai", "anthropic", "gemini", "ollama"]
|
|
15
|
-
EmbeddingProviderName = Literal["openai", "azure_openai", "ollama"]
|
|
16
|
+
EmbeddingProviderName = Literal["openai", "azure_openai", "ollama", "local"]
|
|
16
17
|
|
|
17
18
|
|
|
18
19
|
def make_llm_provider(name: LLMProviderName, **kwargs) -> LLMProvider:
|
|
@@ -49,5 +50,7 @@ def make_embedding_provider(name: EmbeddingProviderName, **kwargs) -> EmbeddingP
|
|
|
49
50
|
)
|
|
50
51
|
if name == "ollama":
|
|
51
52
|
return OllamaProvider(base_url=kwargs.get("base_url") or os.getenv("OLLAMA_BASE_URL", "http://localhost:11434"))
|
|
53
|
+
if name == "local":
|
|
54
|
+
return LocalEmbeddingProvider(model_name=kwargs.get("model") or os.getenv("LOCAL_EMBEDDING_MODEL"))
|
|
52
55
|
raise ValueError(f"Unknown embedding provider: {name}")
|
|
53
56
|
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Local sentence-transformers embeddings (no API key required)."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
4
|
+
import os
|
|
5
|
+
|
|
6
|
+
_DEFAULT_MODEL = "all-MiniLM-L6-v2"
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class LocalEmbeddingProvider:
|
|
10
|
+
"""Offline embeddings via sentence-transformers."""
|
|
11
|
+
|
|
12
|
+
def __init__(self, model_name: str | None = None) -> None:
|
|
13
|
+
self.default_model = model_name or os.getenv("LOCAL_EMBEDDING_MODEL", _DEFAULT_MODEL)
|
|
14
|
+
self._models: dict[str, object] = {}
|
|
15
|
+
|
|
16
|
+
def _get_model(self, model_name: str):
|
|
17
|
+
if model_name not in self._models:
|
|
18
|
+
try:
|
|
19
|
+
from sentence_transformers import SentenceTransformer
|
|
20
|
+
except ImportError as e:
|
|
21
|
+
raise ImportError(
|
|
22
|
+
"Local embeddings require optional dependencies. "
|
|
23
|
+
"Install with: pip install rag-python[local]"
|
|
24
|
+
) from e
|
|
25
|
+
self._models[model_name] = SentenceTransformer(model_name)
|
|
26
|
+
return self._models[model_name]
|
|
27
|
+
|
|
28
|
+
def embed(self, texts: list[str], *, model: str | None = None) -> list[list[float]]:
|
|
29
|
+
if not texts:
|
|
30
|
+
return []
|
|
31
|
+
model_name = model or self.default_model
|
|
32
|
+
encoder = self._get_model(model_name)
|
|
33
|
+
vectors = encoder.encode(texts, convert_to_numpy=True)
|
|
34
|
+
return [v.tolist() for v in vectors]
|
rag_python/rag_pipeline.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
"""Full RAG pipeline: Query → Understanding/Rewrite → Retrieval (multi-query) → Rerank → LLM → Guardrails → Eval/Retry."""
|
|
2
|
+
import logging
|
|
2
3
|
from dataclasses import dataclass
|
|
3
4
|
from pathlib import Path
|
|
4
5
|
|
|
@@ -14,6 +15,8 @@ from .providers import LLMProvider, EmbeddingProvider, make_llm_provider, make_e
|
|
|
14
15
|
from .config import DATA_DIR, CHUNK_SIZE, CHUNK_OVERLAP, CHUNK_STRATEGY
|
|
15
16
|
from .options import QueryConfig, SearchConfig
|
|
16
17
|
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
17
20
|
|
|
18
21
|
@dataclass
|
|
19
22
|
class RAGResponse:
|
|
@@ -34,7 +37,7 @@ def _load_documents(
|
|
|
34
37
|
paths: list[Path] | None = None,
|
|
35
38
|
data_path: Path | None = None,
|
|
36
39
|
*,
|
|
37
|
-
extensions: tuple[str, ...] = (".txt", ".md", ".pdf", ".docx"),
|
|
40
|
+
extensions: tuple[str, ...] = (".txt", ".md", ".pdf", ".docx", ".csv", ".json", ".html"),
|
|
38
41
|
) -> list[LoadedDocument]:
|
|
39
42
|
"""Load documents from explicit paths and/or a data directory."""
|
|
40
43
|
docs: list[LoadedDocument] = []
|
|
@@ -136,12 +139,13 @@ def ingest(
|
|
|
136
139
|
strategy = chunk_strategy or CHUNK_STRATEGY
|
|
137
140
|
size = chunk_size or CHUNK_SIZE
|
|
138
141
|
overlap = chunk_overlap or CHUNK_OVERLAP
|
|
139
|
-
ext = extensions or (".txt", ".md", ".pdf", ".docx")
|
|
142
|
+
ext = extensions or (".txt", ".md", ".pdf", ".docx", ".csv", ".json", ".html")
|
|
140
143
|
embedder = embedder or make_embedding_provider("openai")
|
|
141
144
|
|
|
142
145
|
path_list = [Path(p) for p in paths] if paths else None
|
|
143
146
|
root = Path(data_path) if data_path else (None if path_list else Path(DATA_DIR))
|
|
144
147
|
docs = _load_documents(path_list, root, extensions=ext)
|
|
148
|
+
logger.info("Loaded %s documents for ingest", len(docs))
|
|
145
149
|
return _ingest_documents(
|
|
146
150
|
docs,
|
|
147
151
|
clean=clean,
|
|
@@ -202,11 +206,13 @@ def query(
|
|
|
202
206
|
top_k_retrieve=search_cfg.top_k_retrieve,
|
|
203
207
|
top_k_rerank=search_cfg.top_k_rerank,
|
|
204
208
|
rerank_enabled=search_cfg.rerank_enabled,
|
|
209
|
+
metadata_filter=search_cfg.metadata_filter,
|
|
205
210
|
embedder=embedder,
|
|
206
211
|
embedding_model=embedding_model,
|
|
207
212
|
llm=llm,
|
|
208
213
|
llm_model=llm_model,
|
|
209
214
|
)
|
|
215
|
+
logger.info("Retrieved %s chunks (retriever=%s)", len(hits), search_cfg.retriever)
|
|
210
216
|
context_chunks = [h[0] for h in hits]
|
|
211
217
|
sources = [{"text": h[0][:200], "metadata": h[1], "score": h[2]} for h in hits]
|
|
212
218
|
context_str = "\n\n".join(context_chunks)
|
rag_python/retrieval.py
CHANGED
|
@@ -1,14 +1,49 @@
|
|
|
1
|
-
"""Retrieval: multi-query
|
|
1
|
+
"""Retrieval: vector, multi-query, hybrid (BM25+vector), and reranking."""
|
|
2
2
|
from typing import Any
|
|
3
3
|
|
|
4
|
-
from .vector_store import retrieve as chroma_retrieve
|
|
4
|
+
from .vector_store import retrieve as chroma_retrieve, list_documents
|
|
5
5
|
from .query_rewriting import rewrite_for_retrieval
|
|
6
6
|
from .reranker import rerank_with_metadata
|
|
7
|
+
from .hybrid_search import bm25_retrieve, reciprocal_rank_fusion
|
|
7
8
|
from .providers import EmbeddingProvider, LLMProvider
|
|
8
9
|
from .options import RetrieverStrategy
|
|
9
10
|
from .config import TOP_K_RETRIEVE, TOP_K_RERANK, MULTI_QUERY_N
|
|
10
11
|
|
|
11
12
|
|
|
13
|
+
def _dedupe_candidates(candidates: list[tuple[str, dict, float]]) -> list[tuple[str, dict, float]]:
    """Drop repeated candidates, keeping the first occurrence of each.

    Two candidates count as the same when the first 200 characters of their
    text and their ``"source"`` metadata value both match. The order of the
    surviving candidates is unchanged.
    """
    first_seen: dict[tuple[str, str], tuple[str, dict, float]] = {}
    for text, metadata, score in candidates:
        fingerprint = (text[:200], str(metadata.get("source", "")))
        if fingerprint not in first_seen:
            first_seen[fingerprint] = (text, metadata, score)
    # Dicts preserve insertion order, so this is first-occurrence order.
    return list(first_seen.values())
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _vector_candidates(
    queries: list[str],
    *,
    embedder: EmbeddingProvider,
    embedding_model: str | None,
    top_k_retrieve: int,
    where: dict | None,
) -> list[tuple[str, dict, float]]:
    """Run a vector search for every query and merge the hits.

    Each query is embedded on its own and looked up with ``chroma_retrieve``
    (``top_k_retrieve`` hits, filtered by ``where``). Distances are negated so
    a larger score means a closer match. Duplicate hits across queries are
    dropped, keeping the first occurrence in query order.
    """
    gathered: list[tuple[str, dict, float]] = []
    for text_query in queries:
        vector = embedder.embed([text_query], model=embedding_model)[0]
        gathered.extend(
            (doc, meta, -dist)
            for doc, meta, dist in chroma_retrieve(vector, top_k=top_k_retrieve, where=where)
        )
    # Same (text prefix, source) identity rule as the rest of this module.
    return _dedupe_candidates(gathered)
|
|
45
|
+
|
|
46
|
+
|
|
12
47
|
def retrieve(
|
|
13
48
|
query: str,
|
|
14
49
|
*,
|
|
@@ -20,42 +55,47 @@ def retrieve(
|
|
|
20
55
|
top_k_retrieve: int | None = None,
|
|
21
56
|
top_k_rerank: int | None = None,
|
|
22
57
|
rerank_enabled: bool | None = None,
|
|
58
|
+
metadata_filter: dict | None = None,
|
|
23
59
|
llm: LLMProvider | None = None,
|
|
24
60
|
llm_model: str | None = None,
|
|
25
61
|
) -> list[tuple[str, dict[str, Any], float]]:
|
|
26
62
|
"""
|
|
27
|
-
Retrieve relevant chunks using vector
|
|
63
|
+
Retrieve relevant chunks using vector, multi-query, or hybrid search, then rerank.
|
|
28
64
|
Returns list of (document_text, metadata, rerank_score).
|
|
29
65
|
"""
|
|
30
66
|
top_k_retrieve = top_k_retrieve or TOP_K_RETRIEVE
|
|
31
67
|
top_k_rerank = top_k_rerank or TOP_K_RERANK
|
|
32
68
|
n_queries = n_queries or MULTI_QUERY_N
|
|
33
|
-
use_multi_query = retriever == "multi_query" if multi_query is None else multi_query
|
|
34
69
|
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
queries = rewritten
|
|
70
|
+
if retriever == "hybrid":
|
|
71
|
+
emb = embedder.embed([query], model=embedding_model)[0]
|
|
72
|
+
vector_hits = chroma_retrieve(emb, top_k=top_k_retrieve, where=metadata_filter)
|
|
73
|
+
vector_ranked = [(d, m, -dist) for d, m, dist in vector_hits]
|
|
40
74
|
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
75
|
+
docs, metas = list_documents(where=metadata_filter)
|
|
76
|
+
bm25_ranked = bm25_retrieve(query, docs, metas, top_k=top_k_retrieve)
|
|
77
|
+
fused = reciprocal_rank_fusion([vector_ranked, bm25_ranked])[:top_k_retrieve]
|
|
78
|
+
all_candidates = _dedupe_candidates(fused)
|
|
79
|
+
else:
|
|
80
|
+
use_multi_query = retriever == "multi_query" if multi_query is None else multi_query
|
|
81
|
+
queries = [query]
|
|
82
|
+
if use_multi_query and n_queries > 1:
|
|
83
|
+
rewritten = rewrite_for_retrieval(query, n_queries=n_queries, llm=llm, llm_model=llm_model)
|
|
84
|
+
if rewritten:
|
|
85
|
+
queries = rewritten
|
|
86
|
+
all_candidates = _vector_candidates(
|
|
87
|
+
queries,
|
|
88
|
+
embedder=embedder,
|
|
89
|
+
embedding_model=embedding_model,
|
|
90
|
+
top_k_retrieve=top_k_retrieve,
|
|
91
|
+
where=metadata_filter,
|
|
92
|
+
)
|
|
52
93
|
|
|
53
94
|
if not all_candidates:
|
|
54
95
|
return []
|
|
96
|
+
|
|
55
97
|
docs = [c[0] for c in all_candidates]
|
|
56
98
|
metas = [c[1] for c in all_candidates]
|
|
57
|
-
|
|
99
|
+
return rerank_with_metadata(
|
|
58
100
|
query, list(zip(docs, metas)), top_k=top_k_rerank, rerank_enabled=rerank_enabled
|
|
59
101
|
)
|
|
60
|
-
return reranked
|
|
61
|
-
|
rag_python/vector_store.py
CHANGED
|
@@ -85,6 +85,19 @@ def retrieve(
|
|
|
85
85
|
return list(zip(docs, metas, dists))
|
|
86
86
|
|
|
87
87
|
|
|
88
|
+
def list_documents(
    *,
    where: dict | None = None,
    limit: int | None = None,
) -> tuple[list[str], list[dict[str, Any]]]:
    """Fetch stored chunk texts and their metadata (used for BM25 indexing).

    Args:
        where: Optional metadata filter forwarded to the collection's ``get``.
        limit: Optional maximum number of entries, forwarded to ``get``.

    Returns:
        ``(documents, metadatas)``; each is an empty list when the collection
        returns nothing for that field.
    """
    result = get_collection().get(where=where, include=["documents", "metadatas"], limit=limit)
    texts = result.get("documents") or []
    metadata_rows = result.get("metadatas") or []
    return texts, metadata_rows
|
|
99
|
+
|
|
100
|
+
|
|
88
101
|
def delete_all() -> None:
|
|
89
102
|
"""Remove all documents from the collection (for re-ingestion)."""
|
|
90
103
|
_get_client().delete_collection(COLLECTION_NAME)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.2
|
|
2
2
|
Name: rag-python
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: Production-grade RAG for Python: multi-LLM, query rewriting, reranking, guardrails, and evaluation.
|
|
5
5
|
Author-email: Raghav Singla <04raghavsingla28@gmail.com>
|
|
6
6
|
License: MIT
|
|
@@ -33,6 +33,10 @@ Requires-Dist: requests>=2.31.0
|
|
|
33
33
|
Provides-Extra: rerank
|
|
34
34
|
Requires-Dist: sentence-transformers>=2.2.0; extra == "rerank"
|
|
35
35
|
Requires-Dist: torch>=2.0.0; extra == "rerank"
|
|
36
|
+
Provides-Extra: local
|
|
37
|
+
Requires-Dist: sentence-transformers>=2.2.0; extra == "local"
|
|
38
|
+
Provides-Extra: hybrid
|
|
39
|
+
Requires-Dist: rank-bm25>=0.2.2; extra == "hybrid"
|
|
36
40
|
Provides-Extra: anthropic
|
|
37
41
|
Requires-Dist: anthropic>=0.20.0; extra == "anthropic"
|
|
38
42
|
Provides-Extra: gemini
|
|
@@ -42,11 +46,14 @@ Requires-Dist: pytest>=7.0; extra == "dev"
|
|
|
42
46
|
Requires-Dist: ruff>=0.1.0; extra == "dev"
|
|
43
47
|
Requires-Dist: build; extra == "dev"
|
|
44
48
|
Requires-Dist: twine; extra == "dev"
|
|
49
|
+
Requires-Dist: rank-bm25>=0.2.2; extra == "dev"
|
|
45
50
|
Provides-Extra: all
|
|
46
|
-
Requires-Dist: rag-python[anthropic,gemini,rerank]; extra == "all"
|
|
51
|
+
Requires-Dist: rag-python[anthropic,gemini,hybrid,local,rerank]; extra == "all"
|
|
47
52
|
|
|
48
53
|
# rag-python
|
|
49
54
|
|
|
55
|
+
[](https://pypi.org/project/rag-python/)
|
|
56
|
+
[](https://pypi.org/project/rag-python/)
|
|
50
57
|
[](https://www.python.org/downloads/)
|
|
51
58
|
[](LICENSE)
|
|
52
59
|
[](https://github.com/RaghavOG/rag-python)
|
|
@@ -63,10 +70,11 @@ Ingest your documents, ask questions, get grounded answers — with query rewrit
|
|
|
63
70
|
## Features
|
|
64
71
|
|
|
65
72
|
- Document pipeline: loaders → cleaning → chunking → embeddings → ChromaDB
|
|
66
|
-
- Query pipeline: rewriting → multi-query retrieval → reranking
|
|
73
|
+
- Query pipeline: rewriting → multi-query / **hybrid** retrieval → reranking
|
|
67
74
|
- Generation with guardrails (prompt injection + hallucination checks)
|
|
68
75
|
- Evaluation scores + self-correction retry loop
|
|
69
76
|
- **LLM providers:** OpenAI, Azure OpenAI, Anthropic, Gemini, Ollama
|
|
77
|
+
- **Loaders:** TXT, MD, PDF, DOCX, CSV, JSON, HTML
|
|
70
78
|
|
|
71
79
|
---
|
|
72
80
|
|
|
@@ -77,7 +85,7 @@ pip install rag-python
|
|
|
77
85
|
# or from source
|
|
78
86
|
pip install -e .
|
|
79
87
|
# with reranking + extra providers
|
|
80
|
-
pip install -e ".[rerank,anthropic,gemini,all]"
|
|
88
|
+
pip install -e ".[rerank,local,hybrid,anthropic,gemini,all]"
|
|
81
89
|
```
|
|
82
90
|
|
|
83
91
|
---
|
|
@@ -99,12 +107,26 @@ answer = rag.query("How many days of annual leave?")
|
|
|
99
107
|
print(answer.text)
|
|
100
108
|
```
|
|
101
109
|
|
|
110
|
+
### Hybrid search + metadata filter
|
|
111
|
+
|
|
112
|
+
```python
|
|
113
|
+
from rag_python import RAG, SearchConfig
|
|
114
|
+
|
|
115
|
+
rag = RAG(
|
|
116
|
+
retriever="hybrid", # pip install rag-python[hybrid]
|
|
117
|
+
metadata_filter={"filename": "leave-policy.pdf"},
|
|
118
|
+
)
|
|
119
|
+
rag.ingest(["./policies/leave-policy.pdf", "./policies/handbook.pdf"])
|
|
120
|
+
answer = rag.query("How many days of annual leave?")
|
|
121
|
+
```
|
|
122
|
+
|
|
102
123
|
### CLI
|
|
103
124
|
|
|
104
125
|
```bash
|
|
105
126
|
export OPENAI_API_KEY=sk-...
|
|
106
127
|
rag-python ingest ./data --reindex
|
|
107
128
|
rag-python query "How many days of annual leave?" -v
|
|
129
|
+
rag-python query "leave policy" --retriever hybrid --metadata-filter '{"filename": "leave-policy.pdf"}'
|
|
108
130
|
```
|
|
109
131
|
|
|
110
132
|
---
|
|
@@ -1,31 +1,33 @@
|
|
|
1
|
-
rag_python/__init__.py,sha256=
|
|
1
|
+
rag_python/__init__.py,sha256=TzZxXzRdKszqqbq7KynrO-Cc0JMzZc1UcIxtNSLhvqQ,834
|
|
2
2
|
rag_python/chunking.py,sha256=P1dbZ8ZY7487MxrWe2cypCiKhzIJ8zBPCTVz20vt8fo,6204
|
|
3
3
|
rag_python/cleaning.py,sha256=fSux4T0pg7Xe_8NUP2pgzuForyRk1i2VPYIXSzRajzs,3193
|
|
4
|
-
rag_python/cli.py,sha256=
|
|
5
|
-
rag_python/client.py,sha256=
|
|
4
|
+
rag_python/cli.py,sha256=z22LLX6dWnMlaI9yIU2tf4HpcLbG2zRz66RQWsFxGNY,4775
|
|
5
|
+
rag_python/client.py,sha256=RyWLBvj4bAJW1Vb529me7Eo608e9Wwq-OeImAAKjyIY,7838
|
|
6
6
|
rag_python/config.py,sha256=Zw8TjQFKRvOUHpIb7kjEb7DtPFoYPzdQyOPzSXTqDcc,1389
|
|
7
|
-
rag_python/document_loaders.py,sha256=
|
|
7
|
+
rag_python/document_loaders.py,sha256=blI-rMqzmHSHzcX9RmFBQZ_MYiM_uKLvesCDTPyoQbo,4866
|
|
8
8
|
rag_python/evaluation.py,sha256=gTiXMaAtTUIsV6Ffhywz829BhfR8YhfJFkYZYrD9WYI,3561
|
|
9
9
|
rag_python/generation.py,sha256=t6aSct2vZELIf20JDwRVt8UTwPnTXx0bU3TKoliiwVg,1108
|
|
10
10
|
rag_python/guardrails.py,sha256=hJLXvpPNI9o8emyipSy5PpePofGzktlDLyMAXfAxUXs,2520
|
|
11
|
-
rag_python/
|
|
11
|
+
rag_python/hybrid_search.py,sha256=71kZyJ9obZBZGzhrl1DQjK32X4AtFppk_wvmpkUVzwo,1814
|
|
12
|
+
rag_python/options.py,sha256=P_nLMk7vQdRM11HCoR9AMUk2D0NmEVA5B5_ufhRiAmE,1935
|
|
12
13
|
rag_python/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
13
14
|
rag_python/query_rewriting.py,sha256=og_XWai2-08C7W67mFndA3k-aTxMdqGnu70qHi1Ohgc,2293
|
|
14
|
-
rag_python/rag_pipeline.py,sha256=
|
|
15
|
+
rag_python/rag_pipeline.py,sha256=qth2LDVi6QxpqJVskjLxaxnpwpV5dKwh515334fc8DY,9058
|
|
15
16
|
rag_python/reranker.py,sha256=8RxCPfgp80c-KSKojllGzbpZ7iSku-i7VLgPHa1a3rk,2181
|
|
16
|
-
rag_python/retrieval.py,sha256=
|
|
17
|
-
rag_python/vector_store.py,sha256=
|
|
17
|
+
rag_python/retrieval.py,sha256=iTlkaCs79iDDa_K9gktjJC9bAE0bHzy302CFGwmwEk0,3887
|
|
18
|
+
rag_python/vector_store.py,sha256=iAjGRXtzvh9F3aQJVRZ7abUfvwR5YM-qQ0N52qwJGmw,3340
|
|
18
19
|
rag_python/providers/__init__.py,sha256=SjhMvYoA30EY5VUYVXhEGwcmQnIU2tUomcNE0_0NFho,215
|
|
19
20
|
rag_python/providers/anthropic_provider.py,sha256=dSiCdM4F90jI9w7z_wS10XuVsX-pR733-cAgJHtVV2Q,1493
|
|
20
21
|
rag_python/providers/azure_openai_provider.py,sha256=8SbI7rDzQgvC4ZXP89Q8kjfqeWuBfX1KKgExGLFkmx0,1940
|
|
21
22
|
rag_python/providers/base.py,sha256=M9DYowQvNvRuATaM6944CWovK0awJ0buBmbnQfroJos,593
|
|
22
|
-
rag_python/providers/factory.py,sha256=
|
|
23
|
+
rag_python/providers/factory.py,sha256=O7nYikPOh_LnVgTVIreLQKL-ehIMayr3KXES1wpKpjw,2717
|
|
23
24
|
rag_python/providers/gemini_provider.py,sha256=OZzs1YJQSZituoxS5Gk8yv3jYNIFY1SVovWUu7lz5Z4,1842
|
|
25
|
+
rag_python/providers/local_provider.py,sha256=tgYBNUrs7pKpPebA0tpNhJmtZLwwINuZFqKMyHlymTQ,1332
|
|
24
26
|
rag_python/providers/ollama_provider.py,sha256=DDhDriB6-Ob0r2-M-P3SvIFG37ruDAErtU7LWDK8xh0,1958
|
|
25
27
|
rag_python/providers/openai_provider.py,sha256=oR7rCCaxCtirAVetJrR4oC3UrWySuqLc9kbosydoQAQ,1585
|
|
26
|
-
rag_python-0.
|
|
27
|
-
rag_python-0.
|
|
28
|
-
rag_python-0.
|
|
29
|
-
rag_python-0.
|
|
30
|
-
rag_python-0.
|
|
31
|
-
rag_python-0.
|
|
28
|
+
rag_python-0.3.0.dist-info/LICENSE,sha256=PZ61Z6ve0hBHgztaC1rPgnxQTRXRkeHKASlnKkX2pvc,1079
|
|
29
|
+
rag_python-0.3.0.dist-info/METADATA,sha256=iIp2OG2jfo7xVYYQCQf264ZAFBeIhecfs5lIy-XTLZo,6171
|
|
30
|
+
rag_python-0.3.0.dist-info/WHEEL,sha256=beeZ86-EfXScwlR_HKu4SllMC9wUEj_8Z_4FJ3egI2w,91
|
|
31
|
+
rag_python-0.3.0.dist-info/entry_points.txt,sha256=558Rd4GWV_6mIyqdRSVNE4ZZi0-KdblTZhcMbIn3ryY,51
|
|
32
|
+
rag_python-0.3.0.dist-info/top_level.txt,sha256=SrgudPwkJWfJ3gUn2n-dhrt9vN2XbQcaZ3wLQZed4Z4,11
|
|
33
|
+
rag_python-0.3.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|