kader 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/README.md +169 -0
- cli/__init__.py +5 -0
- cli/__main__.py +6 -0
- cli/app.py +547 -0
- cli/app.tcss +648 -0
- cli/utils.py +62 -0
- cli/widgets/__init__.py +13 -0
- cli/widgets/confirmation.py +309 -0
- cli/widgets/conversation.py +55 -0
- cli/widgets/loading.py +59 -0
- kader/__init__.py +22 -0
- kader/agent/__init__.py +8 -0
- kader/agent/agents.py +126 -0
- kader/agent/base.py +920 -0
- kader/agent/logger.py +188 -0
- kader/config.py +139 -0
- kader/memory/__init__.py +66 -0
- kader/memory/conversation.py +409 -0
- kader/memory/session.py +385 -0
- kader/memory/state.py +211 -0
- kader/memory/types.py +116 -0
- kader/prompts/__init__.py +9 -0
- kader/prompts/agent_prompts.py +27 -0
- kader/prompts/base.py +81 -0
- kader/prompts/templates/planning_agent.j2 +26 -0
- kader/prompts/templates/react_agent.j2 +18 -0
- kader/providers/__init__.py +9 -0
- kader/providers/base.py +581 -0
- kader/providers/mock.py +96 -0
- kader/providers/ollama.py +447 -0
- kader/tools/README.md +483 -0
- kader/tools/__init__.py +130 -0
- kader/tools/base.py +955 -0
- kader/tools/exec_commands.py +249 -0
- kader/tools/filesys.py +650 -0
- kader/tools/filesystem.py +607 -0
- kader/tools/protocol.py +456 -0
- kader/tools/rag.py +555 -0
- kader/tools/todo.py +210 -0
- kader/tools/utils.py +456 -0
- kader/tools/web.py +246 -0
- kader-0.1.0.dist-info/METADATA +319 -0
- kader-0.1.0.dist-info/RECORD +45 -0
- kader-0.1.0.dist-info/WHEEL +4 -0
- kader-0.1.0.dist-info/entry_points.txt +2 -0
kader/tools/rag.py
ADDED
@@ -0,0 +1,555 @@

```python
"""
RAG (Retrieval Augmented Generation) Tool.

Provides semantic search capabilities using Ollama embeddings and FAISS indexing.
"""

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

try:
    import ollama
    from ollama import Client
except ImportError:
    ollama = None
    Client = None

import hashlib
import pickle

try:
    import faiss
    import numpy as np
except ImportError:
    faiss = None
    np = None

from .base import (
    BaseTool,
    ParameterSchema,
    ToolCategory,
)

# Default embedding model
DEFAULT_EMBEDDING_MODEL = "all-minilm:22m"

# File extensions to index by default
DEFAULT_CODE_EXTENSIONS = {
    ".py",
    ".js",
    ".ts",
    ".jsx",
    ".tsx",
    ".java",
    ".cpp",
    ".c",
    ".h",
    ".go",
    ".rs",
    ".rb",
    ".php",
    ".cs",
    ".swift",
    ".kt",
    ".scala",
    ".html",
    ".css",
    ".scss",
    ".less",
    ".json",
    ".yaml",
    ".yml",
    ".md",
    ".txt",
    ".rst",
    ".toml",
    ".ini",
    ".cfg",
    ".sh",
    ".bash",
}

# Directories to exclude by default
DEFAULT_EXCLUDE_DIRS = {
    ".git",
    ".svn",
    ".hg",
    "node_modules",
    "__pycache__",
    ".venv",
    "venv",
    "env",
    ".env",
    "dist",
    "build",
    ".tox",
    ".pytest_cache",
    ".mypy_cache",
    ".ruff_cache",
    "target",
    "bin",
    "obj",
    ".idea",
    ".vscode",
}

# Maximum file size to index (1MB)
MAX_FILE_SIZE = 1_000_000

# Chunk size for text splitting
CHUNK_SIZE = 500
CHUNK_OVERLAP = 50


@dataclass
class DocumentChunk:
    """A chunk of text with metadata for indexing."""

    content: str
    file_path: str
    start_line: int
    end_line: int
    chunk_index: int

    # Embedding vector (populated after embedding)
    embedding: list[float] | None = None

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return {
            "content": self.content,
            "file_path": self.file_path,
            "start_line": self.start_line,
            "end_line": self.end_line,
            "chunk_index": self.chunk_index,
        }


@dataclass
class SearchResult:
    """A search result from RAG search."""

    content: str
    file_path: str
    start_line: int
    end_line: int
    score: float

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary."""
        return {
            "content": self.content,
            "file_path": self.file_path,
            "start_line": self.start_line,
            "end_line": self.end_line,
            "score": self.score,
        }


@dataclass
class RAGIndex:
    """
    Manages FAISS index with Ollama embeddings for semantic search.

    Example:
        index = RAGIndex(base_path=Path.cwd())
        index.build()
        results = index.search("function to read file", top_k=5)
    """

    base_path: Path
    embedding_model: str = DEFAULT_EMBEDDING_MODEL
    include_extensions: set[str] = field(
        default_factory=lambda: DEFAULT_CODE_EXTENSIONS.copy()
    )
    exclude_dirs: set[str] = field(default_factory=lambda: DEFAULT_EXCLUDE_DIRS.copy())
    index_dir: Path | None = None

    # Internal state
    _chunks: list[DocumentChunk] = field(default_factory=list, repr=False)
    _index: Any = field(default=None, repr=False)  # FAISS index
    _is_built: bool = field(default=False, repr=False)
    _embedding_dim: int = field(default=768, repr=False)  # Default for embeddinggemma

    def _get_ollama_client(self):
        """Get Ollama client for embeddings."""
        if Client is None:
            raise ImportError("ollama is required for embeddings.")
        return Client()

    def _embed_text(self, text: str) -> list[float]:
        """Generate embedding for text using Ollama."""
        client = self._get_ollama_client()
        response = client.embed(model=self.embedding_model, input=text)
        return response.embeddings[0]

    def _embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Generate embeddings for multiple texts."""
        client = self._get_ollama_client()
        response = client.embed(model=self.embedding_model, input=texts)
        return response.embeddings

    def _chunk_text(self, content: str, file_path: str) -> list[DocumentChunk]:
        """Split text into overlapping chunks."""
        lines = content.split("\n")
        chunks = []

        current_chunk_lines = []
        current_start_line = 1
        current_char_count = 0
        chunk_index = 0

        for i, line in enumerate(lines, start=1):
            line_len = len(line) + 1  # +1 for newline

            if current_char_count + line_len > CHUNK_SIZE and current_chunk_lines:
                # Save current chunk
                chunk_content = "\n".join(current_chunk_lines)
                chunks.append(
                    DocumentChunk(
                        content=chunk_content,
                        file_path=file_path,
                        start_line=current_start_line,
                        end_line=i - 1,
                        chunk_index=chunk_index,
                    )
                )
                chunk_index += 1

                # Start new chunk with overlap
                overlap_lines = max(1, len(current_chunk_lines) // 4)
                current_chunk_lines = current_chunk_lines[-overlap_lines:]
                current_start_line = i - overlap_lines
                current_char_count = sum(len(line) + 1 for line in current_chunk_lines)

            current_chunk_lines.append(line)
            current_char_count += line_len

        # Don't forget the last chunk
        if current_chunk_lines:
            chunk_content = "\n".join(current_chunk_lines)
            chunks.append(
                DocumentChunk(
                    content=chunk_content,
                    file_path=file_path,
                    start_line=current_start_line,
                    end_line=len(lines),
                    chunk_index=chunk_index,
                )
            )

        return chunks

    def _collect_files(self) -> list[Path]:
        """Collect all files to index."""
        files = []

        for root, dirs, filenames in os.walk(self.base_path):
            # Filter out excluded directories
            dirs[:] = [d for d in dirs if d not in self.exclude_dirs]

            for filename in filenames:
                file_path = Path(root) / filename

                # Check extension
                if file_path.suffix.lower() not in self.include_extensions:
                    continue

                # Check file size
                try:
                    if file_path.stat().st_size > MAX_FILE_SIZE:
                        continue
                except OSError:
                    continue

                files.append(file_path)

        return files

    @property
    def _index_dir(self) -> Path:
        """Get the directory where the index is stored."""
        if self.index_dir:
            dir_path = self.index_dir
        else:
            dir_path = self.base_path / ".kader" / "index"

        dir_path.mkdir(parents=True, exist_ok=True)
        return dir_path

    @property
    def _index_path(self) -> Path:
        """Get path to the FAISS index file."""
        # Create a unique name based on the base path hash to avoid collisions
        path_hash = hashlib.md5(str(self.base_path.absolute()).encode()).hexdigest()
        return self._index_dir / f"faiss_{path_hash}.index"

    @property
    def _chunks_path(self) -> Path:
        """Get path to the chunks metadata file."""
        path_hash = hashlib.md5(str(self.base_path.absolute()).encode()).hexdigest()
        return self._index_dir / f"chunks_{path_hash}.pkl"

    def save(self) -> None:
        """Save the index and chunks to disk."""
        if faiss is None:
            return

        if not self._is_built or self._index is None:
            return

        # Save FAISS index
        faiss.write_index(self._index, str(self._index_path))

        # Save chunks metadata
        with open(self._chunks_path, "wb") as f:
            pickle.dump(self._chunks, f)

    def load(self) -> bool:
        """
        Load the index and chunks from disk.

        Returns:
            True if loaded successfully, False otherwise
        """
        if faiss is None:
            return False

        if not self._index_path.exists() or not self._chunks_path.exists():
            return False

        try:
            # Load FAISS index
            self._index = faiss.read_index(str(self._index_path))

            # Load chunks metadata
            with open(self._chunks_path, "rb") as f:
                self._chunks = pickle.load(f)

            self._is_built = True
            self._embedding_dim = self._index.d
            return True
        except Exception:
            return False

    def build(self) -> int:
        """
        Build the index by scanning and embedding all files.

        Returns:
            Number of chunks indexed
        """
        if faiss is None:
            raise ImportError(
                "faiss-cpu is required for RAG search. "
                "Install it with: uv add faiss-cpu"
            )

        self._chunks = []
        files = self._collect_files()

        # Collect all chunks
        for file_path in files:
            try:
                content = file_path.read_text(encoding="utf-8", errors="ignore")
                rel_path = str(file_path.relative_to(self.base_path))
                chunks = self._chunk_text(content, rel_path)
                self._chunks.extend(chunks)
            except Exception:
                continue

        if not self._chunks:
            self._is_built = True
            return 0

        # Generate embeddings in batches
        batch_size = 32
        all_embeddings = []

        for i in range(0, len(self._chunks), batch_size):
            batch = self._chunks[i : i + batch_size]
            texts = [chunk.content for chunk in batch]
            embeddings = self._embed_texts(texts)
            all_embeddings.extend(embeddings)

            # Store embeddings in chunks
            for chunk, emb in zip(batch, embeddings):
                chunk.embedding = emb

        # Build FAISS index
        embeddings_array = np.array(all_embeddings, dtype=np.float32)
        self._embedding_dim = embeddings_array.shape[1]

        self._index = faiss.IndexFlatL2(self._embedding_dim)
        self._index.add(embeddings_array)

        self._is_built = True
        self.save()  # Auto-save after build
        return len(self._chunks)

    def search(self, query: str, top_k: int = 5) -> list[SearchResult]:
        """
        Search the index for similar content.

        Args:
            query: Search query text
            top_k: Number of results to return

        Returns:
            List of SearchResult objects
        """
        import numpy as np

        if not self._is_built:
            self.build()

        if not self._chunks or self._index is None:
            return []

        # Embed the query
        query_embedding = self._embed_text(query)
        query_array = np.array([query_embedding], dtype=np.float32)

        # Search
        k = min(top_k, len(self._chunks))
        distances, indices = self._index.search(query_array, k)

        # Convert to results
        results = []
        for dist, idx in zip(distances[0], indices[0]):
            if idx < 0 or idx >= len(self._chunks):
                continue

            chunk = self._chunks[idx]
            # Convert L2 distance to similarity score (lower distance = higher score)
            score = 1.0 / (1.0 + float(dist))

            results.append(
                SearchResult(
                    content=chunk.content,
                    file_path=chunk.file_path,
                    start_line=chunk.start_line,
                    end_line=chunk.end_line,
                    score=score,
                )
            )

        return results

    def clear(self) -> None:
        """Clear the index."""
        self._chunks = []
        self._index = None
        self._is_built = False


class RAGSearchTool(BaseTool[list[dict[str, Any]]]):
    """
    Tool for semantic search using RAG (Retrieval Augmented Generation).

    Uses Ollama embeddings and FAISS for fast similarity search across
    the codebase in the current working directory.
    """

    def __init__(
        self,
        base_path: Path | None = None,
        embedding_model: str = DEFAULT_EMBEDDING_MODEL,
    ) -> None:
        """
        Initialize the RAG search tool.

        Args:
            base_path: Base path to search in (defaults to CWD)
            embedding_model: Ollama embedding model to use
        """
        super().__init__(
            name="rag_search",
            description=(
                "Search for code and text using semantic similarity. "
                "Finds relevant files and code snippets based on meaning, not just keywords."
            ),
            parameters=[
                ParameterSchema(
                    name="query",
                    type="string",
                    description="Natural language search query",
                ),
                ParameterSchema(
                    name="top_k",
                    type="integer",
                    description="Number of results to return",
                    required=False,
                    default=5,
                    minimum=1,
                    maximum=20,
                ),
                ParameterSchema(
                    name="rebuild",
                    type="boolean",
                    description="Force rebuild the index before searching",
                    required=False,
                    default=False,
                ),
            ],
            category=ToolCategory.SEARCH,
        )

        self._base_path = base_path or Path.cwd()
        self._embedding_model = embedding_model
        self._index: RAGIndex | None = None

    def _get_or_build_index(self, rebuild: bool = False) -> RAGIndex:
        """Get existing index or build a new one."""
        if self._index is None:
            self._index = RAGIndex(
                base_path=self._base_path,
                embedding_model=self._embedding_model,
            )

        if rebuild:
            self._index.build()
        elif not self._index._is_built:
            # Try loading first, otherwise build
            if not self._index.load():
                self._index.build()

        return self._index

    def execute(
        self,
        query: str,
        top_k: int = 5,
        rebuild: bool = False,
    ) -> list[dict[str, Any]]:
        """
        Execute semantic search.

        Args:
            query: Natural language search query
            top_k: Number of results to return
            rebuild: Force rebuild the index

        Returns:
            List of search result dictionaries
        """
        index = self._get_or_build_index(rebuild)
        results = index.search(query, top_k)
        return [r.to_dict() for r in results]

    async def aexecute(
        self,
        query: str,
        top_k: int = 5,
        rebuild: bool = False,
    ) -> list[dict[str, Any]]:
        """Async version of execute."""
        import asyncio

        return await asyncio.to_thread(self.execute, query, top_k, rebuild)

    def get_interruption_message(self, query: str, **kwargs) -> str:
        """Get interruption message for user confirmation."""
        return f"execute rag_search: {query}"
```
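
For orientation, here is a minimal usage sketch of the tool above, derived only from the signatures visible in this diff. It assumes a local Ollama server with the `all-minilm:22m` embedding model pulled, and `faiss-cpu` plus `numpy` installed; on first use the tool builds the index and persists it under `<base_path>/.kader/index/`, and later calls load it from disk.

```python
from pathlib import Path

from kader.tools.rag import RAGSearchTool

# Point the tool at a project root (defaults to Path.cwd() when omitted).
tool = RAGSearchTool(base_path=Path("."))

# Builds (and auto-saves) the FAISS index on first use; pass rebuild=True
# to force a fresh scan after the codebase changes.
results = tool.execute(query="function to read file", top_k=3)

for r in results:
    # Each result dict carries content, file_path, start_line, end_line, score.
    print(f"{r['file_path']}:{r['start_line']}-{r['end_line']} score={r['score']:.3f}")
```

Scores come from the `1.0 / (1.0 + L2_distance)` conversion in `RAGIndex.search`, so they fall in (0, 1] with higher values meaning closer matches.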
|