ragit 0.1__py3-none-any.whl → 0.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragit/__init__.py +128 -2
- ragit/assistant.py +757 -0
- ragit/config.py +204 -0
- ragit/core/__init__.py +5 -0
- ragit/core/experiment/__init__.py +22 -0
- ragit/core/experiment/experiment.py +577 -0
- ragit/core/experiment/results.py +131 -0
- ragit/exceptions.py +271 -0
- ragit/loaders.py +401 -0
- ragit/logging.py +194 -0
- ragit/monitor.py +307 -0
- ragit/providers/__init__.py +35 -0
- ragit/providers/base.py +147 -0
- ragit/providers/function_adapter.py +237 -0
- ragit/providers/ollama.py +670 -0
- ragit/utils/__init__.py +105 -0
- ragit/version.py +5 -0
- ragit-0.10.1.dist-info/METADATA +153 -0
- ragit-0.10.1.dist-info/RECORD +22 -0
- {ragit-0.1.dist-info → ragit-0.10.1.dist-info}/WHEEL +1 -1
- ragit-0.10.1.dist-info/licenses/LICENSE +201 -0
- ragit/main.py +0 -384
- ragit-0.1.dist-info/METADATA +0 -10
- ragit-0.1.dist-info/RECORD +0 -6
- {ragit-0.1.dist-info → ragit-0.10.1.dist-info}/top_level.txt +0 -0
ragit/assistant.py
ADDED
@@ -0,0 +1,757 @@
#
# Copyright RODMENA LIMITED 2025
# SPDX-License-Identifier: Apache-2.0
#
"""
High-level RAG Assistant for document Q&A and code generation.

Provides a simple interface for RAG-based tasks.

Note: This class is NOT thread-safe. Do not share instances across threads.
"""

from collections.abc import Callable
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
from numpy.typing import NDArray

from ragit.core.experiment.experiment import Chunk, Document
from ragit.loaders import chunk_document, chunk_rst_sections, load_directory, load_text
from ragit.logging import log_operation
from ragit.providers.base import BaseEmbeddingProvider, BaseLLMProvider
from ragit.providers.function_adapter import FunctionProvider

if TYPE_CHECKING:
    from numpy.typing import NDArray


class RAGAssistant:
    """
    High-level RAG assistant for document Q&A and generation.

    Handles document indexing, retrieval, and LLM generation in one simple API.

    Parameters
    ----------
    documents : list[Document] or str or Path
        Documents to index. Can be:
        - List of Document objects
        - Path to a single file
        - Path to a directory (will load all .txt, .md, .rst files)
    embed_fn : Callable[[str], list[float]], optional
        Function that takes text and returns an embedding vector.
        If provided, creates a FunctionProvider internally.
    generate_fn : Callable, optional
        Function for text generation. Supports (prompt) or (prompt, system_prompt).
        If provided without embed_fn, must also provide embed_fn.
    provider : BaseEmbeddingProvider, optional
        Provider for embeddings (and optionally LLM). If embed_fn is provided,
        this is ignored for embeddings.
    embedding_model : str, optional
        Embedding model name (used with provider).
    llm_model : str, optional
        LLM model name (used with provider).
    chunk_size : int, optional
        Chunk size for splitting documents (default: 512).
    chunk_overlap : int, optional
        Overlap between chunks (default: 50).

    Raises
    ------
    ValueError
        If neither embed_fn nor provider is provided.

    Note
    ----
    This class is NOT thread-safe. Each thread should have its own instance.

    Examples
    --------
    >>> # With custom embedding function (retrieval-only)
    >>> assistant = RAGAssistant(docs, embed_fn=my_embed)
    >>> results = assistant.retrieve("query")
    >>>
    >>> # With custom embedding and LLM functions (full RAG)
    >>> assistant = RAGAssistant(docs, embed_fn=my_embed, generate_fn=my_llm)
    >>> answer = assistant.ask("What is X?")
    >>>
    >>> # With Ollama provider (supports nomic-embed-text)
    >>> from ragit.providers import OllamaProvider
    >>> assistant = RAGAssistant(docs, provider=OllamaProvider())
    """

    def __init__(
        self,
        documents: list[Document] | str | Path,
        embed_fn: Callable[[str], list[float]] | None = None,
        generate_fn: Callable[..., str] | None = None,
        provider: BaseEmbeddingProvider | BaseLLMProvider | None = None,
        embedding_model: str | None = None,
        llm_model: str | None = None,
        chunk_size: int = 512,
        chunk_overlap: int = 50,
    ):
        # Resolve provider from embed_fn/generate_fn or explicit provider
        self._embedding_provider: BaseEmbeddingProvider
        self._llm_provider: BaseLLMProvider | None = None

        if embed_fn is not None:
            # Create FunctionProvider from provided functions
            function_provider = FunctionProvider(
                embed_fn=embed_fn,
                generate_fn=generate_fn,
            )
            self._embedding_provider = function_provider
            if generate_fn is not None:
                self._llm_provider = function_provider
            elif provider is not None and isinstance(provider, BaseLLMProvider):
                # Use explicit provider for LLM if function_provider doesn't have LLM
                self._llm_provider = provider
        elif provider is not None:
            # Use explicit provider
            if not isinstance(provider, BaseEmbeddingProvider):
                raise ValueError(
                    "Provider must implement BaseEmbeddingProvider for embeddings. Alternatively, provide embed_fn."
                )
            self._embedding_provider = provider
            if isinstance(provider, BaseLLMProvider):
                self._llm_provider = provider
        else:
            raise ValueError(
                "Must provide embed_fn or provider for embeddings. "
                "Examples:\n"
                "  RAGAssistant(docs, embed_fn=my_embed_function)\n"
                "  RAGAssistant(docs, provider=OllamaProvider())"
            )

        self.embedding_model = embedding_model or "default"
        self.llm_model = llm_model or "default"
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

        # Load documents if path provided
        self.documents = self._load_documents(documents)

        # Index chunks - embeddings stored as pre-normalized numpy matrix for fast search
        self._chunks: tuple[Chunk, ...] = ()
        self._embedding_matrix: NDArray[np.float64] | None = None  # Pre-normalized
        self._build_index()

    def _load_documents(self, documents: list[Document] | str | Path) -> list[Document]:
        """Load documents from various sources."""
        if isinstance(documents, list):
            return documents

        path = Path(documents)

        if path.is_file():
            return [load_text(path)]

        if path.is_dir():
            docs: list[Document] = []
            for pattern in (
                "*.txt",
                "*.md",
                "*.rst",
                "*.py",
                "*.js",
                "*.ts",
                "*.go",
                "*.java",
                "*.c",
                "*.cpp",
                "*.h",
                "*.hpp",
            ):
                docs.extend(load_directory(path, pattern))
            return docs

        raise ValueError(f"Invalid documents source: {documents}")

    def _build_index(self) -> None:
        """Build vector index from documents using batch embedding."""
        all_chunks: list[Chunk] = []

        for doc in self.documents:
            # Use RST section chunking for .rst files, otherwise regular chunking
            if doc.metadata.get("filename", "").endswith(".rst"):
                chunks = chunk_rst_sections(doc.content, doc.id)
            else:
                chunks = chunk_document(doc, self.chunk_size, self.chunk_overlap)
            all_chunks.extend(chunks)

        if not all_chunks:
            self._chunks = ()
            self._embedding_matrix = None
            return

        # Batch embed all chunks at once (single API call)
        texts = [chunk.content for chunk in all_chunks]
        responses = self._embedding_provider.embed_batch(texts, self.embedding_model)

        # Build embedding matrix directly (skip storing in chunks to avoid duplication)
        embedding_matrix = np.array([response.embedding for response in responses], dtype=np.float64)

        # Pre-normalize for fast cosine similarity (normalize once, use many times)
        norms = np.linalg.norm(embedding_matrix, axis=1, keepdims=True)
        norms[norms == 0] = 1  # Avoid division by zero

        # Store as immutable tuple and pre-normalized numpy matrix
        self._chunks = tuple(all_chunks)
        self._embedding_matrix = embedding_matrix / norms

    def add_documents(self, documents: list[Document] | str | Path) -> int:
        """Add documents to the existing index incrementally.

        Args:
            documents: Documents to add.

        Returns:
            Number of chunks added.
        """
        new_docs = self._load_documents(documents)
        if not new_docs:
            return 0

        self.documents.extend(new_docs)

        # Chunk new docs
        new_chunks: list[Chunk] = []
        for doc in new_docs:
            if doc.metadata.get("filename", "").endswith(".rst"):
                chunks = chunk_rst_sections(doc.content, doc.id)
            else:
                chunks = chunk_document(doc, self.chunk_size, self.chunk_overlap)
            new_chunks.extend(chunks)

        if not new_chunks:
            return 0

        # Embed new chunks
        texts = [chunk.content for chunk in new_chunks]
        responses = self._embedding_provider.embed_batch(texts, self.embedding_model)

        new_matrix = np.array([response.embedding for response in responses], dtype=np.float64)

        # Normalize
        norms = np.linalg.norm(new_matrix, axis=1, keepdims=True)
        norms[norms == 0] = 1
        new_matrix_norm = new_matrix / norms

        # Update state
        current_chunks = list(self._chunks)
        current_chunks.extend(new_chunks)
        self._chunks = tuple(current_chunks)

        if self._embedding_matrix is None:
            self._embedding_matrix = new_matrix_norm
        else:
            self._embedding_matrix = np.vstack((self._embedding_matrix, new_matrix_norm))

        return len(new_chunks)

    def remove_documents(self, source_path_pattern: str) -> int:
        """Remove documents matching a source path pattern.

        Args:
            source_path_pattern: Glob pattern to match 'source' metadata.

        Returns:
            Number of chunks removed.
        """
        import fnmatch

        if not self._chunks:
            return 0

        indices_to_keep = []
        kept_chunks = []
        removed_count = 0

        for i, chunk in enumerate(self._chunks):
            source = chunk.metadata.get("source", "")
            if not source or not fnmatch.fnmatch(source, source_path_pattern):
                indices_to_keep.append(i)
                kept_chunks.append(chunk)
            else:
                removed_count += 1

        if removed_count == 0:
            return 0

        self._chunks = tuple(kept_chunks)

        if self._embedding_matrix is not None:
            if not kept_chunks:
                self._embedding_matrix = None
            else:
                self._embedding_matrix = self._embedding_matrix[indices_to_keep]

        # Also remove from self.documents
        self.documents = [
            doc for doc in self.documents if not fnmatch.fnmatch(doc.metadata.get("source", ""), source_path_pattern)
        ]

        return removed_count

    def update_documents(self, documents: list[Document] | str | Path) -> int:
        """Update existing documents (remove old, add new).

        Uses document source path to identify what to remove.

        Args:
            documents: New versions of documents.

        Returns:
            Number of chunks added.
        """
        new_docs = self._load_documents(documents)
        if not new_docs:
            return 0

        # Identify sources to remove
        sources_to_remove = set()
        for doc in new_docs:
            source = doc.metadata.get("source")
            if source:
                sources_to_remove.add(source)

        # Remove old versions
        for source in sources_to_remove:
            self.remove_documents(source)

        # Add new versions
        return self.add_documents(new_docs)

    def retrieve(self, query: str, top_k: int = 3) -> list[tuple[Chunk, float]]:
        """
        Retrieve relevant chunks for a query.

        Uses vectorized cosine similarity for fast search over all chunks.

        Parameters
        ----------
        query : str
            Search query.
        top_k : int
            Number of chunks to return (default: 3).

        Returns
        -------
        list[tuple[Chunk, float]]
            List of (chunk, similarity_score) tuples, sorted by relevance.

        Examples
        --------
        >>> results = assistant.retrieve("how to create a route")
        >>> for chunk, score in results:
        ...     print(f"{score:.2f}: {chunk.content[:100]}...")
        """
        if not self._chunks or self._embedding_matrix is None:
            return []

        # Get query embedding and normalize
        query_response = self._embedding_provider.embed(query, self.embedding_model)
        query_vec = np.array(query_response.embedding, dtype=np.float64)
        query_norm = np.linalg.norm(query_vec)
        if query_norm == 0:
            return []
        query_normalized = query_vec / query_norm

        # Fast cosine similarity: matrix is pre-normalized, just dot product
        similarities = self._embedding_matrix @ query_normalized

        # Get top_k indices using argpartition (faster than full sort for large arrays)
        if len(similarities) <= top_k:
            top_indices = np.argsort(similarities)[::-1]
        else:
            # Partial sort - only find top_k elements
            top_indices = np.argpartition(similarities, -top_k)[-top_k:]
            # Sort the top_k by score
            top_indices = top_indices[np.argsort(similarities[top_indices])[::-1]]

        return [(self._chunks[i], float(similarities[i])) for i in top_indices]

    def retrieve_with_context(
        self,
        query: str,
        top_k: int = 3,
        window_size: int = 1,
        min_score: float = 0.0,
    ) -> list[tuple[Chunk, float]]:
        """
        Retrieve chunks with adjacent context expansion (window search).

        For each retrieved chunk, also includes adjacent chunks from the
        same document to provide more context. This is useful when relevant
        information spans multiple chunks.

        Pattern inspired by ai4rag window_search.

        Parameters
        ----------
        query : str
            Search query.
        top_k : int
            Number of initial chunks to retrieve (default: 3).
        window_size : int
            Number of adjacent chunks to include on each side (default: 1).
            Set to 0 to disable window expansion.
        min_score : float
            Minimum similarity score threshold (default: 0.0).

        Returns
        -------
        list[tuple[Chunk, float]]
            List of (chunk, similarity_score) tuples, sorted by relevance.
            Adjacent chunks have slightly lower scores.

        Examples
        --------
        >>> # Get chunks with 1 adjacent chunk on each side
        >>> results = assistant.retrieve_with_context("query", window_size=1)
        >>> for chunk, score in results:
        ...     print(f"{score:.2f}: {chunk.content[:50]}...")
        """
        with log_operation("retrieve_with_context", query_len=len(query), top_k=top_k, window_size=window_size) as ctx:
            # Get initial results (more than top_k to account for filtering)
            results = self.retrieve(query, top_k * 2)

            # Apply minimum score threshold
            if min_score > 0:
                results = [(chunk, score) for chunk, score in results if score >= min_score]

            if window_size == 0 or not results:
                ctx["expanded_chunks"] = len(results)
                return results[:top_k]

            # Build chunk index for fast lookup
            chunk_to_idx = {id(chunk): i for i, chunk in enumerate(self._chunks)}

            expanded_results: list[tuple[Chunk, float]] = []
            seen_indices: set[int] = set()

            for chunk, score in results[:top_k]:
                chunk_idx = chunk_to_idx.get(id(chunk))
                if chunk_idx is None:
                    expanded_results.append((chunk, score))
                    continue

                # Get window of adjacent chunks from same document
                start_idx = max(0, chunk_idx - window_size)
                end_idx = min(len(self._chunks), chunk_idx + window_size + 1)

                for idx in range(start_idx, end_idx):
                    if idx in seen_indices:
                        continue

                    adjacent_chunk = self._chunks[idx]
                    # Only include adjacent chunks from same document
                    if adjacent_chunk.doc_id == chunk.doc_id:
                        seen_indices.add(idx)
                        # Original chunk keeps full score, adjacent get 80%
                        adj_score = score if idx == chunk_idx else score * 0.8
                        expanded_results.append((adjacent_chunk, adj_score))

            # Sort by score (highest first)
            expanded_results.sort(key=lambda x: (-x[1], self._chunks.index(x[0]) if x[0] in self._chunks else 0))
            ctx["expanded_chunks"] = len(expanded_results)

            return expanded_results

    def get_context_with_window(
        self,
        query: str,
        top_k: int = 3,
        window_size: int = 1,
        min_score: float = 0.0,
    ) -> str:
        """
        Get formatted context with adjacent chunk expansion.

        Merges overlapping text from adjacent chunks intelligently.

        Parameters
        ----------
        query : str
            Search query.
        top_k : int
            Number of initial chunks to retrieve.
        window_size : int
            Number of adjacent chunks on each side.
        min_score : float
            Minimum similarity score threshold.

        Returns
        -------
        str
            Formatted context string with merged chunks.
        """
        results = self.retrieve_with_context(query, top_k, window_size, min_score)

        if not results:
            return ""

        # Group chunks by document to merge properly
        doc_chunks: dict[str, list[tuple[Chunk, float]]] = {}
        for chunk, score in results:
            doc_id = chunk.doc_id or "unknown"
            if doc_id not in doc_chunks:
                doc_chunks[doc_id] = []
            doc_chunks[doc_id].append((chunk, score))

        merged_sections: list[str] = []

        for _doc_id, chunks in doc_chunks.items():
            # Sort chunks by their position in the original list
            chunks.sort(key=lambda x: self._chunks.index(x[0]) if x[0] in self._chunks else 0)

            # Merge overlapping text
            merged_content = []
            for chunk, _ in chunks:
                if merged_content:
                    # Check for overlap with previous chunk
                    prev_content = merged_content[-1]
                    non_overlapping = self._get_non_overlapping_text(prev_content, chunk.content)
                    if non_overlapping != chunk.content:
                        # Found overlap, extend previous chunk
                        merged_content[-1] = prev_content + non_overlapping
                    else:
                        # No overlap, add as new section
                        merged_content.append(chunk.content)
                else:
                    merged_content.append(chunk.content)

            merged_sections.append("\n".join(merged_content))

        return "\n\n---\n\n".join(merged_sections)

    def _get_non_overlapping_text(self, str1: str, str2: str) -> str:
        """
        Find non-overlapping portion of str2 when appending after str1.

        Detects overlap where the end of str1 matches the beginning of str2,
        and returns only the non-overlapping portion of str2.

        Pattern from ai4rag vector_store/utils.py.

        Parameters
        ----------
        str1 : str
            First string (previous content).
        str2 : str
            Second string (content to potentially append).

        Returns
        -------
        str
            Non-overlapping portion of str2, or full str2 if no overlap.
        """
        # Limit overlap search to avoid O(n^2) for large strings
        max_overlap = min(len(str1), len(str2), 200)

        for i in range(max_overlap, 0, -1):
            if str1[-i:] == str2[:i]:
                return str2[i:]

        return str2

    def get_context(self, query: str, top_k: int = 3) -> str:
        """
        Get formatted context string from retrieved chunks.

        Parameters
        ----------
        query : str
            Search query.
        top_k : int
            Number of chunks to include.

        Returns
        -------
        str
            Formatted context string.
        """
        results = self.retrieve(query, top_k)
        return "\n\n---\n\n".join(chunk.content for chunk, _ in results)

    def _ensure_llm(self) -> BaseLLMProvider:
        """Ensure LLM provider is available."""
        if self._llm_provider is None:
            raise NotImplementedError(
                "No LLM configured. Provide generate_fn or a provider with LLM support "
                "to use ask(), generate(), or generate_code() methods."
            )
        return self._llm_provider

    def generate(
        self,
        prompt: str,
        system_prompt: str | None = None,
        temperature: float = 0.7,
    ) -> str:
        """
        Generate text using the LLM (without retrieval).

        Parameters
        ----------
        prompt : str
            User prompt.
        system_prompt : str, optional
            System prompt for context.
        temperature : float
            Sampling temperature (default: 0.7).

        Returns
        -------
        str
            Generated text.

        Raises
        ------
        NotImplementedError
            If no LLM is configured.
        """
        llm = self._ensure_llm()
        response = llm.generate(
            prompt=prompt,
            model=self.llm_model,
            system_prompt=system_prompt,
            temperature=temperature,
        )
        return response.text

    def ask(
        self,
        question: str,
        system_prompt: str | None = None,
        top_k: int = 3,
        temperature: float = 0.7,
    ) -> str:
        """
        Ask a question using RAG (retrieve + generate).

        Parameters
        ----------
        question : str
            Question to answer.
        system_prompt : str, optional
            System prompt. Defaults to a helpful assistant prompt.
        top_k : int
            Number of context chunks to retrieve (default: 3).
        temperature : float
            Sampling temperature (default: 0.7).

        Returns
        -------
        str
            Generated answer.

        Raises
        ------
        NotImplementedError
            If no LLM is configured.

        Examples
        --------
        >>> answer = assistant.ask("How do I create a REST API?")
        >>> print(answer)
        """
        # Retrieve context
        context = self.get_context(question, top_k)

        # Default system prompt
        if system_prompt is None:
            system_prompt = """You are a helpful assistant. Answer questions based on the provided context.
If the context doesn't contain enough information, say so. Be concise and accurate."""

        # Build prompt with context
        prompt = f"""Context:
{context}

Question: {question}

Answer:"""

        return self.generate(prompt, system_prompt, temperature)

    def generate_code(
        self,
        request: str,
        language: str = "python",
        top_k: int = 3,
        temperature: float = 0.7,
    ) -> str:
        """
        Generate code based on documentation context.

        Parameters
        ----------
        request : str
            Description of what code to generate.
        language : str
            Programming language (default: "python").
        top_k : int
            Number of context chunks to retrieve.
        temperature : float
            Sampling temperature.

        Returns
        -------
        str
            Generated code (cleaned, without markdown).

        Raises
        ------
        NotImplementedError
            If no LLM is configured.

        Examples
        --------
        >>> code = assistant.generate_code("create a REST API with user endpoints")
        >>> print(code)
        """
        context = self.get_context(request, top_k)

        system_prompt = f"""You are an expert {language} developer. Generate ONLY valid {language} code.

RULES:
1. Output PURE CODE ONLY - no explanations, no markdown code blocks
2. Include necessary imports
3. Write clean, production-ready code
4. Add brief comments for clarity"""

        prompt = f"""Documentation:
{context}

Request: {request}

Generate the {language} code:"""

        response = self.generate(prompt, system_prompt, temperature)

        # Clean up response - remove markdown if present
        code = response
        if f"```{language}" in code:
            code = code.split(f"```{language}")[1].split("```")[0]
        elif "```" in code:
            code = code.split("```")[1].split("```")[0]

        return code.strip()

    @property
    def num_chunks(self) -> int:
        """Return number of indexed chunks."""
        return len(self._chunks)

    @property
    def num_documents(self) -> int:
        """Return number of loaded documents."""
        return len(self.documents)

    @property
    def has_llm(self) -> bool:
        """Check if LLM is configured."""
        return self._llm_provider is not None
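Usage sketch for the API this file introduces, following the class docstring's own examples. It is an illustration only: toy_embed and toy_generate are placeholder callables invented for the sketch (ragit does not ship them), and the calling convention FunctionProvider applies to generate_fn is assumed from the docstring ("(prompt) or (prompt, system_prompt)"). A real setup would plug in an actual embedding model or the OllamaProvider shown above.

# Illustrative only: toy_embed / toy_generate are stand-ins, not part of ragit.
from pathlib import Path
from tempfile import TemporaryDirectory

from ragit.assistant import RAGAssistant


def toy_embed(text: str) -> list[float]:
    # Placeholder embedding: 26-dim character-frequency vector, enough to exercise retrieval.
    vec = [0.0] * 26
    for ch in text.lower():
        if "a" <= ch <= "z":
            vec[ord(ch) - ord("a")] += 1.0
    return vec


def toy_generate(prompt: str, system_prompt: str | None = None) -> str:
    # Placeholder LLM: returns a canned string so ask() has something to hand back.
    return f"[stub answer; prompt was {len(prompt)} characters]"


with TemporaryDirectory() as tmp:
    doc = Path(tmp) / "routing.md"
    doc.write_text("Routes are registered by calling app.route() with a path and a handler.")

    # Single-file path input; a list of Document objects or a directory also works per the docstring.
    assistant = RAGAssistant(doc, embed_fn=toy_embed, generate_fn=toy_generate, chunk_size=128)
    print(assistant.num_documents, assistant.num_chunks, assistant.has_llm)

    # Retrieval-only path: cosine similarity over the pre-normalized embedding matrix.
    for chunk, score in assistant.retrieve("how do I register a route", top_k=2):
        print(f"{score:.2f}: {chunk.content[:60]}")

    # Full RAG path: retrieve context, then delegate to the configured LLM callable.
    print(assistant.ask("How do I register a route?"))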