rag-python 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rag_python/__init__.py +39 -0
- rag_python/chunking.py +181 -0
- rag_python/cleaning.py +102 -0
- rag_python/cli.py +77 -0
- rag_python/client.py +190 -0
- rag_python/config.py +37 -0
- rag_python/document_loaders.py +74 -0
- rag_python/evaluation.py +105 -0
- rag_python/generation.py +35 -0
- rag_python/guardrails.py +66 -0
- rag_python/options.py +68 -0
- rag_python/providers/__init__.py +5 -0
- rag_python/providers/anthropic_provider.py +41 -0
- rag_python/providers/azure_openai_provider.py +62 -0
- rag_python/providers/base.py +24 -0
- rag_python/providers/factory.py +53 -0
- rag_python/providers/gemini_provider.py +45 -0
- rag_python/providers/ollama_provider.py +56 -0
- rag_python/providers/openai_provider.py +46 -0
- rag_python/py.typed +0 -0
- rag_python/query_rewriting.py +65 -0
- rag_python/rag_pipeline.py +241 -0
- rag_python/reranker.py +64 -0
- rag_python/retrieval.py +61 -0
- rag_python/vector_store.py +91 -0
- rag_python-0.1.0.dist-info/LICENSE +22 -0
- rag_python-0.1.0.dist-info/METADATA +158 -0
- rag_python-0.1.0.dist-info/RECORD +31 -0
- rag_python-0.1.0.dist-info/WHEEL +5 -0
- rag_python-0.1.0.dist-info/entry_points.txt +2 -0
- rag_python-0.1.0.dist-info/top_level.txt +1 -0
rag_python/__init__.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""rag-python — production-grade RAG for Python.
|
|
2
|
+
|
|
3
|
+
Quick start::
|
|
4
|
+
|
|
5
|
+
from rag_python import RAG
|
|
6
|
+
|
|
7
|
+
rag = RAG(llm_model="gpt-4o-mini")
|
|
8
|
+
rag.ingest(["./docs"], reindex=True)
|
|
9
|
+
print(rag.query("What is our leave policy?").text)
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
# Package version string (this release is published as rag-python 0.1.0).
__version__ = "0.1.0"
|
|
13
|
+
|
|
14
|
+
from .client import RAG, RAGAnswer
|
|
15
|
+
from .rag_pipeline import ingest, query, RAGResponse
|
|
16
|
+
from .providers import make_llm_provider, make_embedding_provider
|
|
17
|
+
from .options import (
|
|
18
|
+
ChunkingConfig,
|
|
19
|
+
DocumentConfig,
|
|
20
|
+
QueryConfig,
|
|
21
|
+
RAGConfig,
|
|
22
|
+
SearchConfig,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
# Names re-exported at package level; order is part of the public listing.
__all__ = [
    # package metadata
    "__version__",
    # high-level client
    "RAG",
    "RAGAnswer",
    # configuration objects
    "RAGConfig",
    "ChunkingConfig",
    "SearchConfig",
    "DocumentConfig",
    "QueryConfig",
    # functional pipeline API
    "ingest",
    "query",
    "RAGResponse",
    # provider factories
    "make_llm_provider",
    "make_embedding_provider",
]
|
rag_python/chunking.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
1
|
+
"""Chunking: recursive, structure-aware (headings/sections), and semantic (embedding-based)."""
|
|
2
|
+
import re
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Callable
|
|
5
|
+
|
|
6
|
+
try:
|
|
7
|
+
import tiktoken
|
|
8
|
+
except ImportError:
|
|
9
|
+
tiktoken = None
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class Chunk:
    """A single piece of chunked text plus the metadata describing it.

    Attributes:
        text: The chunk's text content.
        metadata: Free-form key/value pairs; the chunkers in this module set
            keys such as ``chunk_strategy``, ``section`` and ``section_part``.
    """

    text: str
    metadata: dict
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Separator hierarchy for recursive chunking, coarsest first:
# section break, paragraph break, line break, sentence end, word gap,
# and finally "" which means "fall back to fixed-size token windows".
RECURSIVE_SEPARATORS = [
    "\n\n\n",
    "\n\n",
    "\n",
    ". ",
    " ",
    "",
]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _split_by_tokens(text: str, chunk_size: int, overlap: int, encoding_name: str = "cl100k_base") -> list[str]:
    """Split ``text`` into fixed-size, overlapping windows.

    Windows are measured in tiktoken tokens when ``tiktoken`` is importable;
    otherwise a rough 4-characters-per-token approximation is used.

    Args:
        text: Text to split.
        chunk_size: Window size in tokens (values below 1 are treated as 1).
        overlap: Tokens shared by consecutive windows. Clamped to
            ``[0, chunk_size - 1]`` so every step moves the window forward.
        encoding_name: tiktoken encoding name, used only when tiktoken is
            available.

    Returns:
        The window strings in order; an empty list for empty ``text``.
    """
    chunk_size = max(1, chunk_size)
    # overlap >= chunk_size would leave the window start unchanged (infinite
    # loop); a negative overlap would silently skip text.
    overlap = min(max(0, overlap), chunk_size - 1)

    if not tiktoken:
        # Character fallback: ~4 characters per token.
        size = chunk_size * 4
        overlap_chars = overlap * 4
        out = []
        start = 0
        while start < len(text):
            end = min(start + size, len(text))
            out.append(text[start:end])
            if end >= len(text):
                break
            start = end - overlap_chars
        return out

    enc = tiktoken.get_encoding(encoding_name)
    tokens = enc.encode(text)
    out = []
    start = 0
    while start < len(tokens):
        end = min(start + chunk_size, len(tokens))
        out.append(enc.decode(tokens[start:end]))
        if end >= len(tokens):
            break
        start = end - overlap
    return out
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _recursive_split(text: str, separators: list[str], chunk_size: int, overlap: int) -> list[str]:
    """Recursively split ``text`` into chunks of roughly ``chunk_size`` tokens.

    ``text`` is split on the first separator. The pieces are packed greedily
    into chunks of at most ``chunk_size * 4`` characters (a rough
    4-characters-per-token estimate). Any single piece that is still too
    large is split again with the remaining, finer separators. An empty
    separator, or running out of separators, falls back to fixed-size token
    windows.

    Args:
        text: Text to split.
        separators: Separators to try, coarsest first.
        chunk_size: Target chunk size in tokens.
        overlap: Approximate overlap in tokens. The last ``overlap * 4``
            characters of a finished chunk are carried into the next one.

    Returns:
        Stripped, non-empty chunk strings in document order.
    """
    if not text.strip():
        return []
    sep = separators[0] if separators else ""
    if sep == "":
        return _split_by_tokens(text, chunk_size, overlap)
    parts = text.split(sep)
    if len(parts) == 1:
        return _recursive_split(text, separators[1:], chunk_size, overlap)

    limit = chunk_size * 4
    overlap_chars = max(0, overlap) * 4
    last = len(parts) - 1
    chunks: list[str] = []
    current = ""
    for i, part in enumerate(parts):
        # Re-attach the separator between pieces (but not after the last one)
        # so the packed text reads exactly like the original.
        piece = part if i == last else part + sep
        if len(piece) > limit:
            # The piece alone is too big: emit what we have, then split it finer.
            if current.strip():
                chunks.append(current.strip())
            chunks.extend(_recursive_split(piece, separators[1:], chunk_size, overlap))
            current = ""
        elif len(current) + len(piece) <= limit:
            current += piece
        else:
            if current.strip():
                chunks.append(current.strip())
            # Overlap: carry the tail of the chunk just emitted into the next one.
            tail = current[-overlap_chars:] if overlap_chars else ""
            current = tail + piece
    if current.strip():
        chunks.append(current.strip())
    return chunks
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def chunk_recursive(
    text: str,
    chunk_size: int = 512,
    overlap: int = 64,
    metadata: dict | None = None,
) -> list[Chunk]:
    """Chunk ``text`` by walking the section/paragraph/sentence/token hierarchy.

    Every returned chunk gets its own copy of ``metadata`` with
    ``chunk_strategy`` set to ``"recursive"``; blank pieces are dropped.
    """
    pieces = _recursive_split(text, RECURSIVE_SEPARATORS, chunk_size, overlap)
    base_meta = dict(metadata or {})
    base_meta["chunk_strategy"] = "recursive"
    result = []
    for piece in pieces:
        if not piece.strip():
            continue
        result.append(Chunk(text=piece, metadata=dict(base_meta)))
    return result
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
# Markdown ATX heading: 1-6 leading '#', whitespace, then the title text.
# Group 1 is the run of '#' characters, group 2 the title.
HEADING_PATTERN = re.compile(r"^(#{1,6})\s+(.+)$", flags=re.MULTILINE)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _structure_sections(text: str) -> list[tuple[str, str]]:
    """Split markdown-style text into ``(title, content)`` sections.

    A new section starts at every heading line matched by ``HEADING_PATTERN``.
    Text before the first heading is filed under the title ``"Document"``.
    Lines inside fenced code blocks (opened by ``` or ~~~) are never treated
    as headings, so e.g. a ``# install deps`` comment in a shell snippet does
    not split the code block apart.

    Returns:
        ``(title, content)`` pairs in document order. A heading with no lines
        before the next heading produces no entry.
    """
    sections: list[tuple[str, str]] = []
    current_title = "Document"
    current_content: list[str] = []
    fence: str | None = None  # marker of the currently open code fence
    for line in text.splitlines():
        stripped = line.lstrip()
        if fence is None:
            if stripped.startswith(("```", "~~~")):
                fence = stripped[:3]
            else:
                m = HEADING_PATTERN.match(line)
                if m:
                    if current_content:
                        sections.append((current_title, "\n".join(current_content)))
                    current_title = m.group(2).strip()
                    current_content = []
                    continue
        elif stripped.startswith(fence):
            fence = None
        current_content.append(line)
    if current_content:
        sections.append((current_title, "\n".join(current_content)))
    return sections
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def chunk_structure_aware(
    text: str,
    chunk_size: int = 512,
    overlap: int = 64,
    metadata: dict | None = None,
) -> list[Chunk]:
    """Chunk ``text`` along its markdown headings.

    A section whose stripped content fits in ``chunk_size * 4`` characters
    becomes a single chunk. Longer sections are split with
    ``_recursive_split`` below the section-break level. Every chunk text is
    prefixed with ``## <title>``. Each chunk is tagged with
    ``chunk_strategy="structure_aware"`` and ``section``. Pieces of a split
    section also carry their ``section_part`` index.
    """
    sections = _structure_sections(text)
    base_meta = dict(metadata or {})
    base_meta["chunk_strategy"] = "structure_aware"
    max_chars = chunk_size * 4
    result = []
    for title, raw in sections:
        body = raw.strip()
        if not body:
            continue
        heading = f"## {title}\n\n"
        if len(body) <= max_chars:
            result.append(Chunk(text=heading + body, metadata={**base_meta, "section": title}))
            continue
        pieces = _recursive_split(body, RECURSIVE_SEPARATORS[1:], chunk_size, overlap)
        for part_index, piece in enumerate(pieces):
            cleaned = piece.strip()
            if not cleaned:
                continue
            result.append(
                Chunk(
                    text=heading + cleaned,
                    metadata={**base_meta, "section": title, "section_part": part_index},
                )
            )
    return result
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def chunk_semantic(
    text: str,
    embed_fn: Callable[[list[str]], list[list[float]]],
    chunk_size: int = 512,
    overlap: int = 64,
    metadata: dict | None = None,
    similarity_threshold: float = 0.7,
) -> list[Chunk]:
    """Group sentences into chunks of roughly ``chunk_size`` tokens.

    The text is split into sentences after ``.``, ``!`` or ``?`` followed by
    whitespace. Sentences accumulate until they total at least
    ``chunk_size * 3`` characters, and then a chunk is emitted. The next chunk
    starts with the last quarter (at least one) of the emitted chunk's
    sentences as overlap. Text with no sentence break falls back to
    ``chunk_recursive``. Chunks are tagged ``chunk_strategy="semantic"``.

    NOTE(review): ``embed_fn`` and ``similarity_threshold`` are accepted but not
    used; no embeddings are computed. ``overlap`` is only used by the
    ``chunk_recursive`` fallback. TODO: implement embedding-based topic-shift
    detection or deprecate these parameters.
    """
    segments = re.split(r"(?<=[.!?])\s+", text)
    if len(segments) <= 1:
        return chunk_recursive(text, chunk_size, overlap, metadata)

    meta = dict(metadata or {})
    meta["chunk_strategy"] = "semantic"
    limit = chunk_size * 3
    chunks: list[Chunk] = []
    current: list[str] = []
    current_len = 0
    has_new = False  # does `current` contain sentences not yet emitted?
    for seg in segments:
        seg = seg.strip()
        if not seg:
            continue
        current.append(seg)
        current_len += len(seg)
        has_new = True
        if current_len >= limit:
            chunks.append(Chunk(text=" ".join(current), metadata={**meta}))
            overlap_segs = max(1, len(current) // 4)
            current = current[-overlap_segs:]
            current_len = sum(len(s) for s in current)
            has_new = False
    # Emit the remainder only if it adds new sentences. Otherwise it is just
    # the overlap tail of the previous chunk and would duplicate it.
    if current and has_new:
        chunks.append(Chunk(text=" ".join(current), metadata={**meta}))
    return chunks
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def chunk_text(
    text: str,
    strategy: str = "recursive",
    chunk_size: int = 512,
    overlap: int = 64,
    metadata: dict | None = None,
    embed_fn: Callable[[list[str]], list[list[float]]] | None = None,
) -> list[Chunk]:
    """Dispatch to a chunking strategy by name.

    ``"structure_aware"`` uses ``chunk_structure_aware``. ``"semantic"`` uses
    ``chunk_semantic``, but only when ``embed_fn`` is provided. Any other
    value, including ``"semantic"`` without ``embed_fn``, falls back to
    ``chunk_recursive``.
    """
    if strategy == "structure_aware":
        return chunk_structure_aware(text, chunk_size, overlap, metadata)
    if strategy != "semantic" or not embed_fn:
        return chunk_recursive(text, chunk_size, overlap, metadata)
    return chunk_semantic(
        text,
        embed_fn,
        chunk_size,
        overlap,
        metadata,
        similarity_threshold=0.7,
    )
|
|
181
|
+
|
rag_python/cleaning.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
"""Text cleaning & normalization. Garbage in → hallucination out."""
|
|
2
|
+
import re
|
|
3
|
+
|
|
4
|
+
try:
|
|
5
|
+
from langdetect import detect, LangDetectException
|
|
6
|
+
except ImportError:
|
|
7
|
+
detect = None
|
|
8
|
+
LangDetectException = Exception
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def normalize_whitespace(text: str) -> str:
    """Return ``text`` as single-space-separated words.

    Every whitespace run, including newlines, becomes one space. Leading and
    trailing whitespace is removed.
    """
    words = text.split()
    return " ".join(words)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def remove_header_footer_candidates(text: str, min_line_len: int = 10) -> str:
    """Trim header/footer-looking lines from the start and end of ``text``.

    Leading lines are dropped while they look like headers or footers, and
    trailing lines likewise. A line qualifies if, once stripped, it is shorter
    than ``min_line_len`` characters. It also qualifies if it consists only of
    digits, whitespace and ``-./`` (page numbers, dates). Trimming stops at the
    first line that does not qualify. Texts with fewer than 5 lines are
    returned unchanged.

    Structural markdown lines are never treated as headers/footers, so short
    titles, tables and the closing fence of a trailing code block survive:
    - headings (``# Title``)
    - code fences (``` / ~~~)
    - table rows (``|``)

    NOTE(review): lines are not checked for repetition across pages; any short
    line at either edge is a candidate.
    """
    lines = text.splitlines()
    if len(lines) < 5:
        return text

    def is_likely_header_footer(line: str) -> bool:
        s = line.strip()
        if s.startswith(("```", "~~~", "|")) or re.match(r"#{1,6}\s", s):
            return False  # structural markdown, keep it
        if len(s) < min_line_len:
            return True
        return bool(re.match(r"^[\d\s\-\.\/]+$", s))  # page numbers, dates

    start = 0
    while start < len(lines) and is_likely_header_footer(lines[start]):
        start += 1
    end = len(lines)
    while end > start and is_likely_header_footer(lines[end - 1]):
        end -= 1
    return "\n".join(lines[start:end])
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def deduplicate_sentences(text: str) -> str:
    """Drop repeated lines while keeping paragraph breaks and code intact.

    Outside fenced code blocks, each non-blank line is whitespace-normalized.
    It is kept only the first time it appears anywhere in the text. Lines are
    compared case-insensitively on their first 200 characters, which removes
    repeated boilerplate such as running headers.

    Runs of blank lines collapse to a single blank line, so paragraph
    boundaries survive for the chunkers. Leading and trailing blank lines are
    removed.

    Lines inside ``` / ~~~ fences, including the fence lines, are kept
    verbatim: repeated lines and indentation are meaningful in code.
    """
    seen: set[str] = set()
    out: list[str] = []
    fence: str | None = None
    for raw in text.splitlines():
        stripped = raw.strip()
        if fence is not None:
            out.append(raw)
            if stripped.startswith(fence):
                fence = None
            continue
        if stripped.startswith(("```", "~~~")):
            fence = stripped[:3]
            out.append(raw)
            continue
        if not stripped:
            if out and out[-1] != "":
                out.append("")
            continue
        line = normalize_whitespace(raw)
        key = line.lower()[:200]
        if key in seen:
            continue
        seen.add(key)
        out.append(line)
    while out and out[-1] == "":
        out.pop()
    return "\n".join(out)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def preserve_blocks(text: str) -> str:
    """Normalize whitespace outside fenced code blocks, keeping line structure.

    Outside ``` / ~~~ fences, each line is stripped, whitespace runs within
    the line become a single space, and runs of blank lines collapse to one
    blank line. Newlines are kept, so the following survive for later stages:
    - markdown headings, list items and table rows
    - paragraph breaks
    - the fences themselves, which stay at line start

    Everything inside a fence, including the fence lines, is left verbatim.
    Leading and trailing blank lines are removed.
    """
    out: list[str] = []
    fence: str | None = None
    for raw in text.splitlines():
        stripped = raw.strip()
        if fence is not None:
            out.append(raw)
            if stripped.startswith(fence):
                fence = None
            continue
        if stripped.startswith(("```", "~~~")):
            fence = stripped[:3]
            out.append(raw)
            continue
        line = " ".join(raw.split())
        # Keep non-empty lines; keep a blank line only as a single separator.
        if line or (out and out[-1] != ""):
            out.append(line)
    while out and out[-1] == "":
        out.pop()
    return "\n".join(out)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def detect_language(text: str) -> str | None:
|
|
70
|
+
"""Return ISO language code or None if detection fails."""
|
|
71
|
+
if not detect:
|
|
72
|
+
return None
|
|
73
|
+
try:
|
|
74
|
+
sample = text[:2000] if len(text) > 2000 else text
|
|
75
|
+
return detect(sample)
|
|
76
|
+
except LangDetectException:
|
|
77
|
+
return None
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _reflow_prose(text: str) -> str:
    """Join soft-wrapped prose lines while keeping structural lines intact.

    Consecutive plain prose lines are merged into one line separated by a
    single space, and runs of spaces inside prose are collapsed. The following
    always stay on their own lines:
    - blank lines (emitted as empty lines)
    - markdown headings
    - table rows (lines starting with ``|``)
    - fenced code blocks (``` or ~~~): fence lines included, content verbatim
    """
    out: list[str] = []
    fence: str | None = None
    joinable = False  # True when out[-1] is a prose line we may append to
    for line in text.split("\n"):
        stripped = line.strip()
        if fence is not None:
            out.append(line)
            if stripped.startswith(fence):
                fence = None
            continue
        if stripped.startswith(("```", "~~~")):
            fence = stripped[:3]
            out.append(line)
            joinable = False
            continue
        if not stripped:
            out.append("")
            joinable = False
            continue
        squeezed = re.sub(r" +", " ", stripped)
        if stripped.startswith("|") or re.match(r"#{1,6}\s", stripped):
            out.append(squeezed)
            joinable = False
        elif joinable:
            out[-1] = f"{out[-1]} {squeezed}"
        else:
            out.append(squeezed)
            joinable = True
    return "\n".join(out)


def clean_document(
    text: str,
    *,
    normalize_ws: bool = True,
    remove_headers_footers: bool = True,
    dedupe: bool = True,
    preserve_code_tables: bool = True,
    min_lang_length: int = 50,
) -> str:
    """Run the full text-cleaning pipeline and return the stripped result.

    Steps, in order:

    1. If ``preserve_code_tables``, the text goes through ``preserve_blocks``.
       Otherwise, if ``normalize_ws``, all whitespace is collapsed with
       ``normalize_whitespace``, which flattens the text to a single line.
    2. ``remove_header_footer_candidates``, if ``remove_headers_footers``.
    3. ``deduplicate_sentences``, if ``dedupe``.
    4. If both ``normalize_ws`` and ``preserve_code_tables``: soft-wrapped
       prose lines are joined into paragraphs. Fenced code, table rows,
       headings and paragraph breaks keep their line structure.

    ``min_lang_length`` is currently unused; no language filtering happens.
    """
    if normalize_ws and not preserve_code_tables:
        text = normalize_whitespace(text)
    elif preserve_code_tables:
        text = preserve_blocks(text)
    if remove_headers_footers:
        text = remove_header_footer_candidates(text)
    if dedupe:
        text = deduplicate_sentences(text)
    if normalize_ws and preserve_code_tables:
        # Only reflow prose; a global newline-to-space pass would flatten the
        # code blocks and tables this option promises to preserve.
        text = _reflow_prose(text)
    return text.strip()
|
|
102
|
+
|
rag_python/cli.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
"""rag-python command-line interface."""
|
|
2
|
+
import argparse
|
|
3
|
+
|
|
4
|
+
from .client import RAG
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def _build_rag(args: argparse.Namespace) -> RAG:
    """Create a ``RAG`` client from the provider options shared by all subcommands."""
    # Every option's argparse dest has the same name as the RAG keyword it feeds.
    option_names = (
        "llm_provider",
        "llm_model",
        "embedding_provider",
        "embedding_model",
        "openai_api_key",
        "azure_endpoint",
        "azure_api_key",
        "azure_api_version",
        "anthropic_api_key",
        "gemini_api_key",
        "ollama_base_url",
    )
    return RAG(**{name: getattr(args, name) for name in option_names})
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _add_provider_args(parser: argparse.ArgumentParser) -> None:
    """Register the LLM/embedding provider options on ``parser``."""
    parser.add_argument(
        "--llm-provider",
        default="openai",
        choices=["openai", "azure_openai", "anthropic", "gemini", "ollama"],
    )
    parser.add_argument("--llm-model", default=None)
    parser.add_argument(
        "--embedding-provider",
        default="openai",
        choices=["openai", "azure_openai", "ollama"],
    )
    # Free-form string options, all defaulting to None (registration order kept).
    for flag in (
        "--embedding-model",
        "--ollama-base-url",
        "--azure-endpoint",
        "--azure-api-key",
        "--azure-api-version",
        "--openai-api-key",
        "--anthropic-api-key",
        "--gemini-api-key",
    ):
        parser.add_argument(flag, default=None)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def main() -> None:
    """Command-line entry point (``prog="rag-python"``).

    Subcommands:

    * ``ingest PATH [PATH ...] [--reindex]``: ingest files/folders and print
      the count returned by ``RAG.ingest``.
    * ``query WORD [WORD ...] [--no-multi-query] [-v]``: answer a question.
      ``-v`` also prints the evaluation and up to five sources.

    Both subcommands accept the provider options added by ``_add_provider_args``.
    """
    parser = argparse.ArgumentParser(
        prog="rag-python",
        description="rag-python — modular RAG with query rewriting, reranking, guardrails, and multi-LLM support.",
    )
    sub = parser.add_subparsers(dest="command", required=True)

    ing = sub.add_parser("ingest", help="Ingest files/folders into the vector store")
    ing.add_argument("paths", nargs="+", help="Files or folders to ingest")
    ing.add_argument("--reindex", action="store_true", help="Clear vector store and re-ingest")
    _add_provider_args(ing)

    q = sub.add_parser("query", help="Ask a question against ingested documents")
    q.add_argument("question", nargs="+", help="Question text")
    q.add_argument("--no-multi-query", action="store_true")
    q.add_argument("-v", "--verbose", action="store_true")
    _add_provider_args(q)

    args = parser.parse_args()

    if args.command == "ingest":
        rag = _build_rag(args)
        n = rag.ingest(args.paths, reindex=args.reindex)
        print(f"Ingested {n} chunks.")
        return

    if args.command == "query":
        rag = _build_rag(args)
        question = " ".join(args.question)
        # RAG.query() accepts only ``search``/``query_config`` keywords, with no
        # ``multi_query``. Express the flag as a per-call SearchConfig override.
        search = None
        if args.no_multi_query:
            from dataclasses import replace

            # NOTE(review): assumes multi_query_n=1 means a single, unexpanded
            # query; confirm against options.SearchConfig and the retriever.
            search = replace(rag.config.search, multi_query_n=1)
        ans = rag.query(question, search=search)
        print(ans.text)
        if args.verbose:
            print("\n--- evaluation ---")
            print(ans.evaluation)
            print("\n--- sources ---")
            for s in ans.sources[:5]:
                print(s.get("metadata", {}).get("source", ""), "score:", s.get("score"))


if __name__ == "__main__":
    main()
|
rag_python/client.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
"""High-level RAG client API.
|
|
2
|
+
|
|
3
|
+
This wraps the full RAG pipeline behind a simple interface:
|
|
4
|
+
|
|
5
|
+
from rag_python import RAG, RAGConfig, ChunkingConfig, SearchConfig
|
|
6
|
+
|
|
7
|
+
rag = RAG(
|
|
8
|
+
llm_model="gpt-4o-mini",
|
|
9
|
+
embedding_provider="openai",
|
|
10
|
+
embedding_model="text-embedding-3-small",
|
|
11
|
+
config=RAGConfig(
|
|
12
|
+
chunking=ChunkingConfig(strategy="recursive", chunk_size=512),
|
|
13
|
+
search=SearchConfig(retriever="multi_query", top_k_retrieve=20),
|
|
14
|
+
),
|
|
15
|
+
)
|
|
16
|
+
rag.ingest(["./docs", "./policies.pdf", "README.md"])
|
|
17
|
+
answer = rag.query("What is our leave policy?")
|
|
18
|
+
print(answer.text)
|
|
19
|
+
"""
|
|
20
|
+
from dataclasses import dataclass, replace
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
from typing import Iterable
|
|
23
|
+
|
|
24
|
+
from .config import CHROMA_PERSIST_DIR, DATA_DIR, EMBEDDING_MODEL, LLM_MODEL
|
|
25
|
+
from .options import (
|
|
26
|
+
ChunkingConfig,
|
|
27
|
+
DocumentConfig,
|
|
28
|
+
QueryConfig,
|
|
29
|
+
RAGConfig,
|
|
30
|
+
SearchConfig,
|
|
31
|
+
)
|
|
32
|
+
from .providers import make_llm_provider, make_embedding_provider
|
|
33
|
+
from .rag_pipeline import ingest as _ingest, query as _query, RAGResponse
|
|
34
|
+
from .vector_store import set_persist_dir
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
class RAGAnswer:
    """Result object returned by ``RAG.query``.

    Attributes:
        text: The answer text (``RAGResponse.answer``).
        sources: Source records from the pipeline. The CLI reads their
            ``metadata.source`` and ``score`` entries.
        evaluation: Evaluation data reported by the pipeline.
        retried: The pipeline's ``retried`` flag (presumably whether
            generation was retried; confirm in rag_pipeline).
    """

    text: str
    sources: list[dict]
    evaluation: dict
    retried: bool
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class RAG:
    """User-facing RAG client with configurable chunking, retrieval, and embeddings.

    Construction does three things:
    - resolves model names and the data directory;
    - builds the effective ``RAGConfig``, a private copy of ``config`` with any
      shorthand overrides applied;
    - instantiates the LLM and embedding providers.
    """

    def __init__(
        self,
        *,
        llm_provider: str = "openai",
        llm_model: str | None = None,
        embedding_provider: str = "openai",
        embedding_model: str | None = None,
        data_dir: str | Path | None = None,
        chroma_dir: str | Path | None = None,
        config: RAGConfig | None = None,
        # Shorthand overrides (merged into ``config`` when provided)
        chunk_strategy: str | None = None,
        chunk_size: int | None = None,
        chunk_overlap: int | None = None,
        retriever: str | None = None,
        top_k_retrieve: int | None = None,
        top_k_rerank: int | None = None,
        multi_query_n: int | None = None,
        rerank_enabled: bool | None = None,
        document_extensions: tuple[str, ...] | None = None,
        # Provider kwargs (optional)
        openai_api_key: str | None = None,
        azure_endpoint: str | None = None,
        azure_api_key: str | None = None,
        azure_api_version: str | None = None,
        anthropic_api_key: str | None = None,
        gemini_api_key: str | None = None,
        ollama_base_url: str | None = None,
    ) -> None:
        """Configure paths, the effective config, and the providers.

        ``llm_model``/``embedding_model`` default to ``LLM_MODEL``/
        ``EMBEDDING_MODEL`` from the environment config. ``data_dir`` defaults
        to ``DATA_DIR``. The vector-store directory is set from ``chroma_dir``,
        else ``CHROMA_PERSIST_DIR``.

        ``config`` is never modified. Shorthand overrides are applied to a
        shallow copy.

        Each provider receives only its own API key: OpenAI key for
        ``openai``, Azure key for ``azure_openai``, Anthropic key for
        ``anthropic``, Gemini key for ``gemini``, and none for ``ollama``.
        Unknown provider names fall back to the first non-empty key.
        """
        import copy

        self.llm_provider_name = llm_provider
        self.embedding_provider_name = embedding_provider
        self.llm_model = llm_model or LLM_MODEL
        self.embedding_model = embedding_model or EMBEDDING_MODEL
        self.data_dir = Path(data_dir) if data_dir else Path(DATA_DIR)

        if chroma_dir:
            set_persist_dir(chroma_dir)
        elif CHROMA_PERSIST_DIR:
            set_persist_dir(CHROMA_PERSIST_DIR)

        # Shallow-copy so reassigning sub-configs below never mutates a
        # RAGConfig owned (and possibly shared) by the caller.
        self.config = copy.copy(config) if config is not None else RAGConfig()
        chunking_changes = self._given(strategy=chunk_strategy, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        if chunking_changes:
            self.config.chunking = replace(self.config.chunking, **chunking_changes)
        search_changes = self._given(
            retriever=retriever,
            top_k_retrieve=top_k_retrieve,
            top_k_rerank=top_k_rerank,
            multi_query_n=multi_query_n,
            rerank_enabled=rerank_enabled,
        )
        if search_changes:
            self.config.search = replace(self.config.search, **search_changes)
        if document_extensions is not None:
            self.config.documents = replace(self.config.documents, extensions=document_extensions)

        # Route each credential only to the provider it belongs to. A
        # first-non-empty fallback across providers would, e.g., send an OpenAI
        # key (supplied for embeddings) to Anthropic.
        provider_keys = {
            "openai": openai_api_key,
            "azure_openai": azure_api_key,
            "anthropic": anthropic_api_key,
            "gemini": gemini_api_key,
            "ollama": None,
        }
        llm_key = provider_keys.get(
            llm_provider,
            openai_api_key or anthropic_api_key or gemini_api_key or azure_api_key,
        )
        embed_key = provider_keys.get(embedding_provider, openai_api_key or azure_api_key)

        self.llm = make_llm_provider(
            llm_provider,  # type: ignore[arg-type]
            api_key=llm_key,
            azure_endpoint=azure_endpoint,
            api_version=azure_api_version,
            base_url=ollama_base_url,
        )
        self.embedder = make_embedding_provider(
            embedding_provider,  # type: ignore[arg-type]
            api_key=embed_key,
            azure_endpoint=azure_endpoint,
            api_version=azure_api_version,
            base_url=ollama_base_url,
        )

    @staticmethod
    def _given(**values: object) -> dict:
        """Return only the keyword arguments whose value is not ``None``."""
        return {name: value for name, value in values.items() if value is not None}

    def ingest(self, paths: Iterable[str | Path], *, reindex: bool = False) -> int:
        """Ingest one or more files/directories into the vector store.

        If ``config.documents.copy_to_data_dir`` is set, the inputs are first
        copied into ``self.data_dir`` and the whole data directory is ingested.
        Otherwise the given paths are ingested in place. Chunking and document
        options come from ``self.config``.

        Returns:
            The value returned by ``rag_pipeline.ingest``. The CLI reports it as
            the number of ingested chunks.
        """
        path_list = [Path(p) for p in paths]
        doc_cfg: DocumentConfig = self.config.documents
        chunk_cfg: ChunkingConfig = self.config.chunking
        common = dict(
            clean=doc_cfg.clean,
            chunk_strategy=chunk_cfg.strategy,
            chunk_size=chunk_cfg.chunk_size,
            chunk_overlap=chunk_cfg.chunk_overlap,
            extensions=doc_cfg.extensions,
            reindex=reindex,
            embedding_model=self.embedding_model,
            embedder=self.embedder,
        )
        if doc_cfg.copy_to_data_dir:
            self._copy_into_data_dir(path_list)
            return _ingest(data_path=self.data_dir, **common)
        return _ingest(paths=path_list, **common)

    def _copy_into_data_dir(self, path_list: list[Path]) -> None:
        """Copy files into ``self.data_dir``.

        A file is copied under its own name. A directory's files are copied
        with their layout relative to that directory. Sources that already
        resolve to their target are skipped; other existing targets are
        overwritten.
        """
        self.data_dir.mkdir(parents=True, exist_ok=True)
        for p in path_list:
            if p.is_file():
                self._copy_file(p, self.data_dir / p.name)
            elif p.is_dir():
                for f in p.rglob("*"):
                    if f.is_file():
                        target = self.data_dir / f.relative_to(p)
                        target.parent.mkdir(parents=True, exist_ok=True)
                        self._copy_file(f, target)

    @staticmethod
    def _copy_file(src: Path, target: Path) -> None:
        """Copy ``src`` to ``target`` unless both resolve to the same path."""
        if str(src.resolve()) != str(target.resolve()):
            target.write_bytes(src.read_bytes())

    def query(
        self,
        question: str,
        *,
        search: SearchConfig | None = None,
        query_config: QueryConfig | None = None,
    ) -> RAGAnswer:
        """Run a full RAG query and wrap the pipeline response in a ``RAGAnswer``.

        ``search``/``query_config`` override ``self.config.search``/
        ``self.config.query`` for this call only.
        """
        resp: RAGResponse = _query(
            question,
            search=search or self.config.search,
            query_config=query_config or self.config.query,
            llm_model=self.llm_model,
            embedding_model=self.embedding_model,
            llm=self.llm,
            embedder=self.embedder,
        )
        return RAGAnswer(
            text=resp.answer,
            sources=resp.sources,
            evaluation=resp.evaluation,
            retried=resp.retried,
        )
|
rag_python/config.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""Configuration loaded from environment variables."""
|
|
2
|
+
import os
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
from dotenv import load_dotenv
|
|
6
|
+
|
|
7
|
+
load_dotenv()
|
|
8
|
+
|
|
9
|
+
def _env_bool(name: str, default: str) -> bool:
    """Return True if env var ``name`` (or ``default`` when unset) is a truthy flag.

    ``1``, ``true``, ``yes`` and ``on`` count as True, case-insensitively and
    ignoring surrounding whitespace. Every other value is False.
    """
    # A strict ``== "true"`` check would silently turn ``1``/``yes`` into False.
    return os.getenv(name, default).strip().lower() in {"1", "true", "yes", "on"}


# API keys (provider-specific)
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Models
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
LLM_MODEL = os.getenv("LLM_MODEL", "gpt-4o-mini")

# Paths — default to current working directory (works when installed from PyPI)
PROJECT_ROOT = Path.cwd()
DATA_DIR = Path(os.getenv("RAG_PYTHON_DATA_DIR", PROJECT_ROOT / "data"))
CHROMA_PERSIST_DIR = Path(os.getenv("RAG_PYTHON_CHROMA_DIR", PROJECT_ROOT / "chroma_db"))

# Chunking
CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "512"))
CHUNK_OVERLAP = int(os.getenv("CHUNK_OVERLAP", "64"))
CHUNK_STRATEGY = os.getenv("CHUNK_STRATEGY", "recursive")  # recursive | structure_aware | semantic

# Retrieval
TOP_K_RETRIEVE = int(os.getenv("TOP_K_RETRIEVE", "20"))
TOP_K_RERANK = int(os.getenv("TOP_K_RERANK", "5"))
MULTI_QUERY_N = int(os.getenv("MULTI_QUERY_N", "3"))

# Guardrails
GUARDRAILS_ENABLED = _env_bool("GUARDRAILS_ENABLED", "true")
MAX_RETRIES = int(os.getenv("MAX_RETRIES", "2"))

# Reranker (optional extra: pip install rag-python[rerank])
RERANKER_MODEL = os.getenv("RERANKER_MODEL", "BAAI/bge-reranker-base")
RERANK_ENABLED = _env_bool("RERANK_ENABLED", "true")
|