rag-python 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rag_python/__init__.py +19 -4
- rag_python/cli.py +198 -24
- rag_python/client.py +22 -1
- rag_python/document_loaders.py +76 -4
- rag_python/generation.py +32 -2
- rag_python/help_text.py +229 -0
- rag_python/hybrid_search.py +51 -0
- rag_python/log.py +21 -0
- rag_python/options.py +3 -2
- rag_python/providers/anthropic_provider.py +23 -0
- rag_python/providers/azure_openai_provider.py +26 -0
- rag_python/providers/base.py +11 -0
- rag_python/providers/ollama_provider.py +37 -0
- rag_python/providers/openai_provider.py +26 -0
- rag_python/providers/streaming.py +35 -0
- rag_python/rag_pipeline.py +265 -46
- rag_python/retrieval.py +63 -23
- rag_python/vector_store.py +13 -0
- rag_python-0.3.1.dist-info/METADATA +205 -0
- rag_python-0.3.1.dist-info/RECORD +36 -0
- rag_python-0.2.0.dist-info/METADATA +0 -162
- rag_python-0.2.0.dist-info/RECORD +0 -32
- {rag_python-0.2.0.dist-info → rag_python-0.3.1.dist-info}/LICENSE +0 -0
- {rag_python-0.2.0.dist-info → rag_python-0.3.1.dist-info}/WHEEL +0 -0
- {rag_python-0.2.0.dist-info → rag_python-0.3.1.dist-info}/entry_points.txt +0 -0
- {rag_python-0.2.0.dist-info → rag_python-0.3.1.dist-info}/top_level.txt +0 -0
rag_python/__init__.py
CHANGED
|
@@ -2,18 +2,29 @@
|
|
|
2
2
|
|
|
3
3
|
Quick start::
|
|
4
4
|
|
|
5
|
-
|
|
5
|
+
pip install rag-python
|
|
6
|
+
export OPENAI_API_KEY=sk-...
|
|
7
|
+
|
|
8
|
+
# CLI
|
|
9
|
+
rag-python ingest ./docs --reindex
|
|
10
|
+
rag-python query "What is our leave policy?"
|
|
11
|
+
rag-python docs quickstart
|
|
6
12
|
|
|
7
|
-
|
|
13
|
+
# Python
|
|
14
|
+
from rag_python import RAG
|
|
15
|
+
rag = RAG()
|
|
8
16
|
rag.ingest(["./docs"], reindex=True)
|
|
9
17
|
print(rag.query("What is our leave policy?").text)
|
|
18
|
+
|
|
19
|
+
Documentation: https://github.com/RaghavOG/rag-python/tree/main/docs
|
|
10
20
|
"""
|
|
11
21
|
|
|
12
|
-
__version__ = "0.
|
|
22
|
+
__version__ = "0.3.1"
|
|
13
23
|
|
|
14
24
|
from .client import RAG, RAGAnswer
|
|
15
|
-
from .rag_pipeline import ingest, query, RAGResponse
|
|
25
|
+
from .rag_pipeline import ingest, query, query_stream, RAGResponse, RAGStream
|
|
16
26
|
from .providers import make_llm_provider, make_embedding_provider
|
|
27
|
+
from .log import configure_logging, get_logger
|
|
17
28
|
from .options import (
|
|
18
29
|
ChunkingConfig,
|
|
19
30
|
DocumentConfig,
|
|
@@ -33,7 +44,11 @@ __all__ = [
|
|
|
33
44
|
"QueryConfig",
|
|
34
45
|
"ingest",
|
|
35
46
|
"query",
|
|
47
|
+
"query_stream",
|
|
48
|
+
"RAGStream",
|
|
36
49
|
"RAGResponse",
|
|
50
|
+
"configure_logging",
|
|
51
|
+
"get_logger",
|
|
37
52
|
"make_llm_provider",
|
|
38
53
|
"make_embedding_provider",
|
|
39
54
|
]
|
rag_python/cli.py
CHANGED
|
@@ -1,13 +1,18 @@
|
|
|
1
1
|
"""rag-python command-line interface."""
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
|
|
2
4
|
import argparse
|
|
5
|
+
import json
|
|
6
|
+
import sys
|
|
3
7
|
from dataclasses import replace
|
|
4
8
|
|
|
5
9
|
from . import __version__
|
|
6
10
|
from .client import RAG
|
|
11
|
+
from .help_text import CLI_EPILOG, list_topics, print_topic, print_topic_list
|
|
7
12
|
|
|
8
13
|
|
|
9
14
|
def _build_rag(args: argparse.Namespace) -> RAG:
|
|
10
|
-
|
|
15
|
+
kwargs: dict = dict(
|
|
11
16
|
llm_provider=args.llm_provider,
|
|
12
17
|
llm_model=args.llm_model,
|
|
13
18
|
embedding_provider=args.embedding_provider,
|
|
@@ -20,6 +25,20 @@ def _build_rag(args: argparse.Namespace) -> RAG:
|
|
|
20
25
|
gemini_api_key=args.gemini_api_key,
|
|
21
26
|
ollama_base_url=args.ollama_base_url,
|
|
22
27
|
)
|
|
28
|
+
if getattr(args, "retriever", None):
|
|
29
|
+
kwargs["retriever"] = args.retriever
|
|
30
|
+
if getattr(args, "metadata_filter", None):
|
|
31
|
+
kwargs["metadata_filter"] = args.metadata_filter
|
|
32
|
+
return RAG(**kwargs)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _parse_metadata_filter(raw: str | None) -> dict | None:
|
|
36
|
+
if not raw:
|
|
37
|
+
return None
|
|
38
|
+
try:
|
|
39
|
+
return json.loads(raw)
|
|
40
|
+
except json.JSONDecodeError as e:
|
|
41
|
+
raise argparse.ArgumentTypeError(f"Invalid JSON for metadata filter: {e}") from e
|
|
23
42
|
|
|
24
43
|
|
|
25
44
|
def _add_provider_args(parser: argparse.ArgumentParser) -> None:
|
|
@@ -27,43 +46,180 @@ def _add_provider_args(parser: argparse.ArgumentParser) -> None:
|
|
|
27
46
|
"--llm-provider",
|
|
28
47
|
default="openai",
|
|
29
48
|
choices=["openai", "azure_openai", "anthropic", "gemini", "ollama"],
|
|
49
|
+
metavar="PROVIDER",
|
|
50
|
+
help="LLM backend (default: openai). See: rag-python docs providers",
|
|
51
|
+
)
|
|
52
|
+
parser.add_argument(
|
|
53
|
+
"--llm-model",
|
|
54
|
+
default=None,
|
|
55
|
+
metavar="MODEL",
|
|
56
|
+
help="LLM model or Azure deployment name (default: from env LLM_MODEL)",
|
|
30
57
|
)
|
|
31
|
-
parser.add_argument("--llm-model", default=None)
|
|
32
58
|
parser.add_argument(
|
|
33
59
|
"--embedding-provider",
|
|
34
60
|
default="openai",
|
|
35
61
|
choices=["openai", "azure_openai", "ollama", "local"],
|
|
62
|
+
metavar="PROVIDER",
|
|
63
|
+
help="Embedding backend (default: openai). Use local for offline embeddings",
|
|
64
|
+
)
|
|
65
|
+
parser.add_argument(
|
|
66
|
+
"--embedding-model",
|
|
67
|
+
default=None,
|
|
68
|
+
metavar="MODEL",
|
|
69
|
+
help="Embedding model name (default: from env EMBEDDING_MODEL)",
|
|
70
|
+
)
|
|
71
|
+
parser.add_argument(
|
|
72
|
+
"--ollama-base-url",
|
|
73
|
+
default=None,
|
|
74
|
+
metavar="URL",
|
|
75
|
+
help="Ollama server URL (default: http://localhost:11434 or OLLAMA_BASE_URL)",
|
|
76
|
+
)
|
|
77
|
+
parser.add_argument("--azure-endpoint", default=None, help="Azure OpenAI endpoint URL")
|
|
78
|
+
parser.add_argument("--azure-api-key", default=None, help="Azure OpenAI API key")
|
|
79
|
+
parser.add_argument(
|
|
80
|
+
"--azure-api-version",
|
|
81
|
+
default=None,
|
|
82
|
+
help="Azure API version (default: 2023-09-01-preview)",
|
|
83
|
+
)
|
|
84
|
+
parser.add_argument("--openai-api-key", default=None, help="OpenAI API key (overrides env)")
|
|
85
|
+
parser.add_argument("--anthropic-api-key", default=None, help="Anthropic API key")
|
|
86
|
+
parser.add_argument("--gemini-api-key", default=None, help="Gemini API key")
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _add_search_args(parser: argparse.ArgumentParser) -> None:
|
|
90
|
+
parser.add_argument(
|
|
91
|
+
"--retriever",
|
|
92
|
+
choices=["vector", "multi_query", "hybrid"],
|
|
93
|
+
default=None,
|
|
94
|
+
metavar="MODE",
|
|
95
|
+
help=(
|
|
96
|
+
"Retrieval mode: vector (single query), multi_query (default, with rewriting), "
|
|
97
|
+
"or hybrid (BM25+vector; requires pip install rag-python[hybrid])"
|
|
98
|
+
),
|
|
99
|
+
)
|
|
100
|
+
parser.add_argument(
|
|
101
|
+
"--metadata-filter",
|
|
102
|
+
type=_parse_metadata_filter,
|
|
103
|
+
default=None,
|
|
104
|
+
metavar="JSON",
|
|
105
|
+
help='Filter chunks by metadata, e.g. \'{"filename": "policy.pdf"}\'',
|
|
36
106
|
)
|
|
37
|
-
parser.add_argument("--embedding-model", default=None)
|
|
38
|
-
parser.add_argument("--ollama-base-url", default=None)
|
|
39
|
-
parser.add_argument("--azure-endpoint", default=None)
|
|
40
|
-
parser.add_argument("--azure-api-key", default=None)
|
|
41
|
-
parser.add_argument("--azure-api-version", default=None)
|
|
42
|
-
parser.add_argument("--openai-api-key", default=None)
|
|
43
|
-
parser.add_argument("--anthropic-api-key", default=None)
|
|
44
|
-
parser.add_argument("--gemini-api-key", default=None)
|
|
45
107
|
|
|
46
108
|
|
|
47
|
-
def
|
|
109
|
+
def _make_parser() -> argparse.ArgumentParser:
|
|
48
110
|
parser = argparse.ArgumentParser(
|
|
49
111
|
prog="rag-python",
|
|
50
|
-
description=
|
|
112
|
+
description=(
|
|
113
|
+
"Production-grade RAG for Python — ingest documents, ask questions, "
|
|
114
|
+
"get grounded answers with multi-LLM support."
|
|
115
|
+
),
|
|
116
|
+
epilog=CLI_EPILOG,
|
|
117
|
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
|
51
118
|
)
|
|
52
|
-
parser.add_argument("--version", action="version", version=f"
|
|
53
|
-
sub = parser.add_subparsers(dest="command", required=True)
|
|
119
|
+
parser.add_argument("--version", action="version", version=f"%(prog)s {__version__}")
|
|
120
|
+
sub = parser.add_subparsers(dest="command", required=True, metavar="COMMAND")
|
|
54
121
|
|
|
55
|
-
ing = sub.add_parser(
|
|
56
|
-
|
|
57
|
-
|
|
122
|
+
ing = sub.add_parser(
|
|
123
|
+
"ingest",
|
|
124
|
+
help="Load files into the vector store (chunk + embed)",
|
|
125
|
+
description=(
|
|
126
|
+
"Ingest one or more files or directories into the ChromaDB vector store.\n"
|
|
127
|
+
"Supported formats: .txt .md .pdf .docx .csv .json .html"
|
|
128
|
+
),
|
|
129
|
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
|
130
|
+
epilog=(
|
|
131
|
+
"examples:\n"
|
|
132
|
+
" rag-python ingest ./data --reindex\n"
|
|
133
|
+
" rag-python ingest policy.pdf handbook/ --embedding-provider local"
|
|
134
|
+
),
|
|
135
|
+
)
|
|
136
|
+
ing.add_argument(
|
|
137
|
+
"paths",
|
|
138
|
+
nargs="+",
|
|
139
|
+
metavar="PATH",
|
|
140
|
+
help="File or directory paths to ingest",
|
|
141
|
+
)
|
|
142
|
+
ing.add_argument(
|
|
143
|
+
"--reindex",
|
|
144
|
+
action="store_true",
|
|
145
|
+
help="Delete existing vectors before ingesting (fresh index)",
|
|
146
|
+
)
|
|
58
147
|
_add_provider_args(ing)
|
|
59
148
|
|
|
60
|
-
q = sub.add_parser(
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
149
|
+
q = sub.add_parser(
|
|
150
|
+
"query",
|
|
151
|
+
help="Ask a question against ingested documents",
|
|
152
|
+
description=(
|
|
153
|
+
"Run the full RAG pipeline: retrieve relevant chunks, generate an answer, "
|
|
154
|
+
"optionally stream tokens and show sources."
|
|
155
|
+
),
|
|
156
|
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
|
157
|
+
epilog=(
|
|
158
|
+
"examples:\n"
|
|
159
|
+
' rag-python query "How many days of annual leave?"\n'
|
|
160
|
+
" rag-python query \"PTO policy\" --stream -v\n"
|
|
161
|
+
' rag-python query "benefits" --retriever hybrid --metadata-filter \'{"filename": "hr.pdf"}\''
|
|
162
|
+
),
|
|
163
|
+
)
|
|
164
|
+
q.add_argument(
|
|
165
|
+
"question",
|
|
166
|
+
nargs="+",
|
|
167
|
+
metavar="QUESTION",
|
|
168
|
+
help="Question text (multiple words are joined)",
|
|
169
|
+
)
|
|
170
|
+
q.add_argument(
|
|
171
|
+
"--no-multi-query",
|
|
172
|
+
action="store_true",
|
|
173
|
+
help="Use single-query vector retrieval (same as --retriever vector)",
|
|
174
|
+
)
|
|
175
|
+
q.add_argument(
|
|
176
|
+
"--stream",
|
|
177
|
+
action="store_true",
|
|
178
|
+
help="Stream answer tokens to stdout as they are generated",
|
|
179
|
+
)
|
|
180
|
+
q.add_argument(
|
|
181
|
+
"-v",
|
|
182
|
+
"--verbose",
|
|
183
|
+
action="store_true",
|
|
184
|
+
help="After the answer, print evaluation scores and top source paths",
|
|
185
|
+
)
|
|
64
186
|
_add_provider_args(q)
|
|
187
|
+
_add_search_args(q)
|
|
188
|
+
|
|
189
|
+
docs = sub.add_parser(
|
|
190
|
+
"docs",
|
|
191
|
+
help="Show user documentation in the terminal",
|
|
192
|
+
description="Print built-in help topics. Full docs: https://github.com/RaghavOG/rag-python/tree/main/docs",
|
|
193
|
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
|
194
|
+
epilog="topics: " + ", ".join(list_topics()),
|
|
195
|
+
)
|
|
196
|
+
docs.add_argument(
|
|
197
|
+
"topic",
|
|
198
|
+
nargs="?",
|
|
199
|
+
default="quickstart",
|
|
200
|
+
choices=list_topics(),
|
|
201
|
+
metavar="TOPIC",
|
|
202
|
+
help="Documentation topic (default: quickstart)",
|
|
203
|
+
)
|
|
204
|
+
docs.add_argument(
|
|
205
|
+
"--list",
|
|
206
|
+
action="store_true",
|
|
207
|
+
help="List all available documentation topics",
|
|
208
|
+
)
|
|
209
|
+
|
|
210
|
+
return parser
|
|
65
211
|
|
|
66
|
-
|
|
212
|
+
|
|
213
|
+
def main(argv: list[str] | None = None) -> None:
|
|
214
|
+
parser = _make_parser()
|
|
215
|
+
args = parser.parse_args(argv)
|
|
216
|
+
|
|
217
|
+
if args.command == "docs":
|
|
218
|
+
if args.list:
|
|
219
|
+
print_topic_list()
|
|
220
|
+
else:
|
|
221
|
+
print_topic(args.topic)
|
|
222
|
+
return
|
|
67
223
|
|
|
68
224
|
if args.command == "ingest":
|
|
69
225
|
rag = _build_rag(args)
|
|
@@ -74,10 +230,28 @@ def main() -> None:
|
|
|
74
230
|
if args.command == "query":
|
|
75
231
|
rag = _build_rag(args)
|
|
76
232
|
question = " ".join(args.question)
|
|
233
|
+
retriever = args.retriever
|
|
234
|
+
if retriever is None and args.no_multi_query:
|
|
235
|
+
retriever = "vector"
|
|
77
236
|
search = replace(
|
|
78
237
|
rag.config.search,
|
|
79
|
-
retriever=
|
|
238
|
+
retriever=retriever or rag.config.search.retriever,
|
|
239
|
+
metadata_filter=args.metadata_filter or rag.config.search.metadata_filter,
|
|
80
240
|
)
|
|
241
|
+
if args.stream:
|
|
242
|
+
stream = rag.query_stream(question, search=search)
|
|
243
|
+
for token in stream:
|
|
244
|
+
print(token, end="", flush=True)
|
|
245
|
+
print()
|
|
246
|
+
result = stream.result
|
|
247
|
+
if args.verbose:
|
|
248
|
+
print("\n--- evaluation ---")
|
|
249
|
+
print(result.evaluation)
|
|
250
|
+
print("\n--- sources ---")
|
|
251
|
+
for s in result.sources[:5]:
|
|
252
|
+
print(s.get("metadata", {}).get("source", ""), "score:", s.get("score"))
|
|
253
|
+
return
|
|
254
|
+
|
|
81
255
|
ans = rag.query(question, search=search)
|
|
82
256
|
print(ans.text)
|
|
83
257
|
if args.verbose:
|
|
@@ -89,4 +263,4 @@ def main() -> None:
|
|
|
89
263
|
|
|
90
264
|
|
|
91
265
|
if __name__ == "__main__":
|
|
92
|
-
main()
|
|
266
|
+
main(sys.argv[1:])
|
rag_python/client.py
CHANGED
|
@@ -30,7 +30,7 @@ from .options import (
|
|
|
30
30
|
SearchConfig,
|
|
31
31
|
)
|
|
32
32
|
from .providers import make_llm_provider, make_embedding_provider
|
|
33
|
-
from .rag_pipeline import ingest as _ingest, query as _query, RAGResponse
|
|
33
|
+
from .rag_pipeline import ingest as _ingest, query as _query, query_stream as _query_stream, RAGResponse, RAGStream
|
|
34
34
|
from .vector_store import set_persist_dir
|
|
35
35
|
|
|
36
36
|
|
|
@@ -60,6 +60,7 @@ class RAG:
|
|
|
60
60
|
chunk_size: int | None = None,
|
|
61
61
|
chunk_overlap: int | None = None,
|
|
62
62
|
retriever: str | None = None,
|
|
63
|
+
metadata_filter: dict | None = None,
|
|
63
64
|
top_k_retrieve: int | None = None,
|
|
64
65
|
top_k_rerank: int | None = None,
|
|
65
66
|
multi_query_n: int | None = None,
|
|
@@ -104,6 +105,8 @@ class RAG:
|
|
|
104
105
|
self.config.search = replace(self.config.search, rerank_enabled=rerank_enabled)
|
|
105
106
|
if document_extensions is not None:
|
|
106
107
|
self.config.documents = replace(self.config.documents, extensions=document_extensions)
|
|
108
|
+
if metadata_filter is not None:
|
|
109
|
+
self.config.search = replace(self.config.search, metadata_filter=metadata_filter)
|
|
107
110
|
|
|
108
111
|
self.llm = make_llm_provider(
|
|
109
112
|
llm_provider, # type: ignore[arg-type]
|
|
@@ -188,3 +191,21 @@ class RAG:
|
|
|
188
191
|
evaluation=resp.evaluation,
|
|
189
192
|
retried=resp.retried,
|
|
190
193
|
)
|
|
194
|
+
|
|
195
|
+
def query_stream(
|
|
196
|
+
self,
|
|
197
|
+
question: str,
|
|
198
|
+
*,
|
|
199
|
+
search: SearchConfig | None = None,
|
|
200
|
+
query_config: QueryConfig | None = None,
|
|
201
|
+
) -> RAGStream:
|
|
202
|
+
"""Stream answer tokens; call ``stream.result`` after iterating."""
|
|
203
|
+
return _query_stream(
|
|
204
|
+
question,
|
|
205
|
+
search=search or self.config.search,
|
|
206
|
+
query_config=query_config or self.config.query,
|
|
207
|
+
llm_model=self.llm_model,
|
|
208
|
+
embedding_model=self.embedding_model,
|
|
209
|
+
llm=self.llm,
|
|
210
|
+
embedder=self.embedder,
|
|
211
|
+
)
|
rag_python/document_loaders.py
CHANGED
|
@@ -1,4 +1,7 @@
|
|
|
1
1
|
"""Document loaders: raw data → structured text + metadata."""
|
|
2
|
+
import csv
|
|
3
|
+
import json
|
|
4
|
+
from html.parser import HTMLParser
|
|
2
5
|
from pathlib import Path
|
|
3
6
|
from dataclasses import dataclass
|
|
4
7
|
from typing import Iterator
|
|
@@ -22,18 +25,85 @@ class LoadedDocument:
|
|
|
22
25
|
metadata: dict
|
|
23
26
|
|
|
24
27
|
|
|
28
|
+
class _HTMLTextExtractor(HTMLParser):
|
|
29
|
+
def __init__(self) -> None:
|
|
30
|
+
super().__init__()
|
|
31
|
+
self.parts: list[str] = []
|
|
32
|
+
|
|
33
|
+
def handle_data(self, data: str) -> None:
|
|
34
|
+
text = data.strip()
|
|
35
|
+
if text:
|
|
36
|
+
self.parts.append(text)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _html_to_text(html: str) -> str:
|
|
40
|
+
parser = _HTMLTextExtractor()
|
|
41
|
+
parser.feed(html)
|
|
42
|
+
return "\n".join(parser.parts)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _load_csv(path: Path, metadata: dict) -> LoadedDocument | None:
|
|
46
|
+
rows: list[str] = []
|
|
47
|
+
with path.open(encoding="utf-8", errors="replace", newline="") as f:
|
|
48
|
+
reader = csv.DictReader(f)
|
|
49
|
+
if reader.fieldnames:
|
|
50
|
+
for row in reader:
|
|
51
|
+
rows.append(", ".join(f"{k}: {v}" for k, v in row.items() if v))
|
|
52
|
+
else:
|
|
53
|
+
f.seek(0)
|
|
54
|
+
for row in csv.reader(f):
|
|
55
|
+
rows.append(", ".join(row))
|
|
56
|
+
content = "\n".join(rows)
|
|
57
|
+
metadata["rows"] = len(rows)
|
|
58
|
+
return LoadedDocument(content=content, source=str(path), metadata=metadata) if content.strip() else None
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _load_json(path: Path, metadata: dict) -> LoadedDocument | None:
|
|
62
|
+
data = json.loads(path.read_text(encoding="utf-8", errors="replace"))
|
|
63
|
+
if isinstance(data, list):
|
|
64
|
+
parts = []
|
|
65
|
+
for item in data:
|
|
66
|
+
if isinstance(item, dict) and "text" in item:
|
|
67
|
+
parts.append(str(item["text"]))
|
|
68
|
+
else:
|
|
69
|
+
parts.append(json.dumps(item, ensure_ascii=False))
|
|
70
|
+
content = "\n\n".join(parts)
|
|
71
|
+
elif isinstance(data, dict):
|
|
72
|
+
if "text" in data:
|
|
73
|
+
content = str(data["text"])
|
|
74
|
+
else:
|
|
75
|
+
content = json.dumps(data, ensure_ascii=False, indent=2)
|
|
76
|
+
else:
|
|
77
|
+
content = str(data)
|
|
78
|
+
return LoadedDocument(content=content, source=str(path), metadata=metadata) if content.strip() else None
|
|
79
|
+
|
|
80
|
+
|
|
25
81
|
def load_file(path: Path) -> LoadedDocument | None:
|
|
26
|
-
"""Load a single file (PDF, TXT, DOCX, MD) into text + metadata."""
|
|
82
|
+
"""Load a single file (PDF, TXT, DOCX, MD, CSV, JSON, HTML) into text + metadata."""
|
|
27
83
|
path = Path(path)
|
|
28
84
|
if not path.exists():
|
|
29
85
|
return None
|
|
30
86
|
suffix = path.suffix.lower()
|
|
31
87
|
metadata = {"source": str(path), "filename": path.name}
|
|
32
88
|
|
|
33
|
-
if suffix
|
|
89
|
+
if suffix in (".txt", ".md"):
|
|
34
90
|
content = path.read_text(encoding="utf-8", errors="replace")
|
|
35
91
|
return LoadedDocument(content=content, source=str(path), metadata=metadata)
|
|
36
92
|
|
|
93
|
+
if suffix == ".html":
|
|
94
|
+
html = path.read_text(encoding="utf-8", errors="replace")
|
|
95
|
+
content = _html_to_text(html)
|
|
96
|
+
return LoadedDocument(content=content, source=str(path), metadata=metadata) if content.strip() else None
|
|
97
|
+
|
|
98
|
+
if suffix == ".csv":
|
|
99
|
+
return _load_csv(path, metadata)
|
|
100
|
+
|
|
101
|
+
if suffix == ".json":
|
|
102
|
+
try:
|
|
103
|
+
return _load_json(path, metadata)
|
|
104
|
+
except json.JSONDecodeError:
|
|
105
|
+
return None
|
|
106
|
+
|
|
37
107
|
if suffix == ".pdf" and PdfReader:
|
|
38
108
|
try:
|
|
39
109
|
reader = PdfReader(path)
|
|
@@ -61,7 +131,10 @@ def load_file(path: Path) -> LoadedDocument | None:
|
|
|
61
131
|
return None
|
|
62
132
|
|
|
63
133
|
|
|
64
|
-
def load_directory(
|
|
134
|
+
def load_directory(
|
|
135
|
+
dir_path: Path,
|
|
136
|
+
extensions: tuple = (".txt", ".md", ".pdf", ".docx", ".csv", ".json", ".html"),
|
|
137
|
+
) -> Iterator[LoadedDocument]:
|
|
65
138
|
"""Yield LoadedDocument for each supported file under dir_path."""
|
|
66
139
|
dir_path = Path(dir_path)
|
|
67
140
|
if not dir_path.is_dir():
|
|
@@ -71,4 +144,3 @@ def load_directory(dir_path: Path, extensions: tuple = (".txt", ".md", ".pdf", "
|
|
|
71
144
|
doc = load_file(f)
|
|
72
145
|
if doc and doc.content.strip():
|
|
73
146
|
yield doc
|
|
74
|
-
|
rag_python/generation.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
"""LLM generation with context (RAG)."""
|
|
2
|
+
from collections.abc import Iterator
|
|
3
|
+
|
|
2
4
|
from .config import LLM_MODEL
|
|
3
5
|
from .providers import LLMProvider, make_llm_provider
|
|
6
|
+
from .providers.streaming import stream_generate
|
|
4
7
|
|
|
5
8
|
|
|
6
9
|
RAG_SYSTEM = (
|
|
@@ -10,6 +13,11 @@ RAG_SYSTEM = (
|
|
|
10
13
|
)
|
|
11
14
|
|
|
12
15
|
|
|
16
|
+
def _build_user_prompt(query: str, context_chunks: list[str]) -> str:
|
|
17
|
+
context = "\n\n---\n\n".join(context_chunks)
|
|
18
|
+
return f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer:"
|
|
19
|
+
|
|
20
|
+
|
|
13
21
|
def generate(
|
|
14
22
|
query: str,
|
|
15
23
|
context_chunks: list[str],
|
|
@@ -20,12 +28,11 @@ def generate(
|
|
|
20
28
|
) -> str:
|
|
21
29
|
"""Generate answer from query and retrieved context."""
|
|
22
30
|
llm = llm or make_llm_provider("openai")
|
|
23
|
-
context = "\n\n---\n\n".join(context_chunks)
|
|
24
31
|
sys = system_prompt or RAG_SYSTEM
|
|
25
32
|
try:
|
|
26
33
|
return llm.generate(
|
|
27
34
|
system=sys,
|
|
28
|
-
user=
|
|
35
|
+
user=_build_user_prompt(query, context_chunks),
|
|
29
36
|
model=model or LLM_MODEL,
|
|
30
37
|
temperature=0.2,
|
|
31
38
|
max_tokens=1024,
|
|
@@ -33,3 +40,26 @@ def generate(
|
|
|
33
40
|
except Exception as e:
|
|
34
41
|
return f"[Generation error: {e}]"
|
|
35
42
|
|
|
43
|
+
|
|
44
|
+
def generate_stream(
|
|
45
|
+
query: str,
|
|
46
|
+
context_chunks: list[str],
|
|
47
|
+
*,
|
|
48
|
+
model: str | None = None,
|
|
49
|
+
system_prompt: str | None = None,
|
|
50
|
+
llm: LLMProvider | None = None,
|
|
51
|
+
) -> Iterator[str]:
|
|
52
|
+
"""Stream answer tokens from query and retrieved context."""
|
|
53
|
+
llm = llm or make_llm_provider("openai")
|
|
54
|
+
sys = system_prompt or RAG_SYSTEM
|
|
55
|
+
try:
|
|
56
|
+
yield from stream_generate(
|
|
57
|
+
llm,
|
|
58
|
+
system=sys,
|
|
59
|
+
user=_build_user_prompt(query, context_chunks),
|
|
60
|
+
model=model or LLM_MODEL,
|
|
61
|
+
temperature=0.2,
|
|
62
|
+
max_tokens=1024,
|
|
63
|
+
)
|
|
64
|
+
except Exception as e:
|
|
65
|
+
yield f"[Generation error: {e}]"
|