kader-0.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kader/tools/rag.py ADDED
@@ -0,0 +1,555 @@
+ """
+ RAG (Retrieval-Augmented Generation) Tool.
+
+ Provides semantic search capabilities using Ollama embeddings and FAISS indexing.
+ """
+
+ import os
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Any
+
+ try:
+     import ollama
+     from ollama import Client
+ except ImportError:
+     ollama = None
+     Client = None
+
+ import hashlib
+ import pickle
+
+ try:
+     import faiss
+     import numpy as np
+ except ImportError:
+     faiss = None
+     np = None
+
+ from .base import (
+     BaseTool,
+     ParameterSchema,
+     ToolCategory,
+ )
+
+ # Default embedding model
+ DEFAULT_EMBEDDING_MODEL = "all-minilm:22m"
+
+ # File extensions to index by default
+ DEFAULT_CODE_EXTENSIONS = {
+     ".py",
+     ".js",
+     ".ts",
+     ".jsx",
+     ".tsx",
+     ".java",
+     ".cpp",
+     ".c",
+     ".h",
+     ".go",
+     ".rs",
+     ".rb",
+     ".php",
+     ".cs",
+     ".swift",
+     ".kt",
+     ".scala",
+     ".html",
+     ".css",
+     ".scss",
+     ".less",
+     ".json",
+     ".yaml",
+     ".yml",
+     ".md",
+     ".txt",
+     ".rst",
+     ".toml",
+     ".ini",
+     ".cfg",
+     ".sh",
+     ".bash",
+ }
+
+ # Directories to exclude by default
+ DEFAULT_EXCLUDE_DIRS = {
+     ".git",
+     ".svn",
+     ".hg",
+     "node_modules",
+     "__pycache__",
+     ".venv",
+     "venv",
+     "env",
+     ".env",
+     "dist",
+     "build",
+     ".tox",
+     ".pytest_cache",
+     ".mypy_cache",
+     ".ruff_cache",
+     "target",
+     "bin",
+     "obj",
+     ".idea",
+     ".vscode",
+ }
+
+ # Maximum file size to index (1MB)
+ MAX_FILE_SIZE = 1_000_000
+
+ # Chunk size for text splitting
+ CHUNK_SIZE = 500
+ CHUNK_OVERLAP = 50
+
+
+ @dataclass
+ class DocumentChunk:
+     """A chunk of text with metadata for indexing."""
+
+     content: str
+     file_path: str
+     start_line: int
+     end_line: int
+     chunk_index: int
+
+     # Embedding vector (populated after embedding)
+     embedding: list[float] | None = None
+
+     def to_dict(self) -> dict[str, Any]:
+         """Convert to dictionary."""
+         return {
+             "content": self.content,
+             "file_path": self.file_path,
+             "start_line": self.start_line,
+             "end_line": self.end_line,
+             "chunk_index": self.chunk_index,
+         }
+
+
+ @dataclass
+ class SearchResult:
+     """A search result from RAG search."""
+
+     content: str
+     file_path: str
+     start_line: int
+     end_line: int
+     score: float
+
+     def to_dict(self) -> dict[str, Any]:
+         """Convert to dictionary."""
+         return {
+             "content": self.content,
+             "file_path": self.file_path,
+             "start_line": self.start_line,
+             "end_line": self.end_line,
+             "score": self.score,
+         }
+
+
+ @dataclass
+ class RAGIndex:
+     """
+     Manages FAISS index with Ollama embeddings for semantic search.
+
+     Example:
+         index = RAGIndex(base_path=Path.cwd())
+         index.build()
+         results = index.search("function to read file", top_k=5)
+     """
+
+     base_path: Path
+     embedding_model: str = DEFAULT_EMBEDDING_MODEL
+     include_extensions: set[str] = field(
+         default_factory=lambda: DEFAULT_CODE_EXTENSIONS.copy()
+     )
+     exclude_dirs: set[str] = field(default_factory=lambda: DEFAULT_EXCLUDE_DIRS.copy())
+     index_dir: Path | None = None
+
+     # Internal state
+     _chunks: list[DocumentChunk] = field(default_factory=list, repr=False)
+     _index: Any = field(default=None, repr=False)  # FAISS index
+     _is_built: bool = field(default=False, repr=False)
+     _embedding_dim: int = field(default=768, repr=False)  # Placeholder; overwritten on build()/load()
+
+     def _get_ollama_client(self):
+         """Get Ollama client for embeddings."""
+         if Client is None:
+             raise ImportError("ollama is required for embeddings. Install it with: uv add ollama")
+         return Client()
+
+     def _embed_text(self, text: str) -> list[float]:
+         """Generate embedding for text using Ollama."""
+         client = self._get_ollama_client()
+         response = client.embed(model=self.embedding_model, input=text)
+         return response.embeddings[0]
+
+     def _embed_texts(self, texts: list[str]) -> list[list[float]]:
+         """Generate embeddings for multiple texts."""
+         client = self._get_ollama_client()
+         response = client.embed(model=self.embedding_model, input=texts)
+         return response.embeddings
+
+     def _chunk_text(self, content: str, file_path: str) -> list[DocumentChunk]:
+         """Split text into overlapping chunks."""
+         lines = content.split("\n")
+         chunks = []
+
+         current_chunk_lines = []
+         current_start_line = 1
+         current_char_count = 0
+         chunk_index = 0
+
+         for i, line in enumerate(lines, start=1):
+             line_len = len(line) + 1  # +1 for newline
+
+             if current_char_count + line_len > CHUNK_SIZE and current_chunk_lines:
+                 # Save current chunk
+                 chunk_content = "\n".join(current_chunk_lines)
+                 chunks.append(
+                     DocumentChunk(
+                         content=chunk_content,
+                         file_path=file_path,
+                         start_line=current_start_line,
+                         end_line=i - 1,
+                         chunk_index=chunk_index,
+                     )
+                 )
+                 chunk_index += 1
+
+                 # Start new chunk with overlap
+                 overlap_lines = max(1, len(current_chunk_lines) // 4)
+                 current_chunk_lines = current_chunk_lines[-overlap_lines:]
+                 current_start_line = i - overlap_lines
+                 current_char_count = sum(len(line) + 1 for line in current_chunk_lines)
+
+             current_chunk_lines.append(line)
+             current_char_count += line_len
+
+         # Don't forget the last chunk
+         if current_chunk_lines:
+             chunk_content = "\n".join(current_chunk_lines)
+             chunks.append(
+                 DocumentChunk(
+                     content=chunk_content,
+                     file_path=file_path,
+                     start_line=current_start_line,
+                     end_line=len(lines),
+                     chunk_index=chunk_index,
+                 )
+             )
+
+         return chunks
+
+     def _collect_files(self) -> list[Path]:
+         """Collect all files to index."""
+         files = []
+
+         for root, dirs, filenames in os.walk(self.base_path):
+             # Filter out excluded directories
+             dirs[:] = [d for d in dirs if d not in self.exclude_dirs]
+
+             for filename in filenames:
+                 file_path = Path(root) / filename
+
+                 # Check extension
+                 if file_path.suffix.lower() not in self.include_extensions:
+                     continue
+
+                 # Check file size
+                 try:
+                     if file_path.stat().st_size > MAX_FILE_SIZE:
+                         continue
+                 except OSError:
+                     continue
+
+                 files.append(file_path)
+
+         return files
+
+     @property
+     def _index_dir(self) -> Path:
+         """Get the directory where the index is stored."""
+         if self.index_dir:
+             dir_path = self.index_dir
+         else:
+             dir_path = self.base_path / ".kader" / "index"
+
+         dir_path.mkdir(parents=True, exist_ok=True)
+         return dir_path
+
+     @property
+     def _index_path(self) -> Path:
+         """Get path to the FAISS index file."""
+         # Create a unique name based on the base path hash to avoid collisions
+         path_hash = hashlib.md5(str(self.base_path.absolute()).encode()).hexdigest()
+         return self._index_dir / f"faiss_{path_hash}.index"
+
+     @property
+     def _chunks_path(self) -> Path:
+         """Get path to the chunks metadata file."""
+         path_hash = hashlib.md5(str(self.base_path.absolute()).encode()).hexdigest()
+         return self._index_dir / f"chunks_{path_hash}.pkl"
+
+     def save(self) -> None:
+         """Save the index and chunks to disk."""
+         if faiss is None:
+             return
+
+         if not self._is_built or self._index is None:
+             return
+
+         # Save FAISS index
+         faiss.write_index(self._index, str(self._index_path))
+
+         # Save chunks metadata
+         with open(self._chunks_path, "wb") as f:
+             pickle.dump(self._chunks, f)
+
+     def load(self) -> bool:
+         """
+         Load the index and chunks from disk.
+
+         Returns:
+             True if loaded successfully, False otherwise
+         """
+         if faiss is None:
+             return False
+
+         if not self._index_path.exists() or not self._chunks_path.exists():
+             return False
+
+         try:
+             # Load FAISS index
+             self._index = faiss.read_index(str(self._index_path))
+
+             # Load chunks metadata
+             with open(self._chunks_path, "rb") as f:
+                 self._chunks = pickle.load(f)
+
+             self._is_built = True
+             self._embedding_dim = self._index.d
+             return True
+         except Exception:
+             return False
+
+     def build(self) -> int:
+         """
+         Build the index by scanning and embedding all files.
+
+         Returns:
+             Number of chunks indexed
+         """
+         if faiss is None:
+             raise ImportError(
+                 "faiss-cpu is required for RAG search. "
+                 "Install it with: uv add faiss-cpu"
+             )
+
+         self._chunks = []
+         files = self._collect_files()
+
+         # Collect all chunks
+         for file_path in files:
+             try:
+                 content = file_path.read_text(encoding="utf-8", errors="ignore")
+                 rel_path = str(file_path.relative_to(self.base_path))
+                 chunks = self._chunk_text(content, rel_path)
+                 self._chunks.extend(chunks)
+             except Exception:
+                 continue
+
+         if not self._chunks:
+             self._is_built = True
+             return 0
+
+         # Generate embeddings in batches
+         batch_size = 32
+         all_embeddings = []
+
+         for i in range(0, len(self._chunks), batch_size):
+             batch = self._chunks[i : i + batch_size]
+             texts = [chunk.content for chunk in batch]
+             embeddings = self._embed_texts(texts)
+             all_embeddings.extend(embeddings)
+
+             # Store embeddings in chunks
+             for chunk, emb in zip(batch, embeddings):
+                 chunk.embedding = emb
+
+         # Build FAISS index
+         embeddings_array = np.array(all_embeddings, dtype=np.float32)
+         self._embedding_dim = embeddings_array.shape[1]
+
+         self._index = faiss.IndexFlatL2(self._embedding_dim)
+         self._index.add(embeddings_array)
+
+         self._is_built = True
+         self.save()  # Auto-save after build
+         return len(self._chunks)
+
+     def search(self, query: str, top_k: int = 5) -> list[SearchResult]:
+         """
+         Search the index for similar content.
+
+         Args:
+             query: Search query text
+             top_k: Number of results to return
+
+         Returns:
+             List of SearchResult objects
+         """
+         import numpy as np
+
+         if not self._is_built:
+             self.build()
+
+         if not self._chunks or self._index is None:
+             return []
+
+         # Embed the query
+         query_embedding = self._embed_text(query)
+         query_array = np.array([query_embedding], dtype=np.float32)
+
+         # Search
+         k = min(top_k, len(self._chunks))
+         distances, indices = self._index.search(query_array, k)
+
+         # Convert to results
+         results = []
+         for dist, idx in zip(distances[0], indices[0]):
+             if idx < 0 or idx >= len(self._chunks):
+                 continue
+
+             chunk = self._chunks[idx]
+             # Convert L2 distance to similarity score (lower distance = higher score)
+             score = 1.0 / (1.0 + float(dist))
+
+             results.append(
+                 SearchResult(
+                     content=chunk.content,
+                     file_path=chunk.file_path,
+                     start_line=chunk.start_line,
+                     end_line=chunk.end_line,
+                     score=score,
+                 )
+             )
+
+         return results
+
+     def clear(self) -> None:
+         """Clear the index."""
+         self._chunks = []
+         self._index = None
+         self._is_built = False
+
+
+ class RAGSearchTool(BaseTool[list[dict[str, Any]]]):
+     """
+     Tool for semantic search using RAG (Retrieval-Augmented Generation).
+
+     Uses Ollama embeddings and FAISS for fast similarity search across
+     the codebase in the current working directory.
+     """
+
+     def __init__(
+         self,
+         base_path: Path | None = None,
+         embedding_model: str = DEFAULT_EMBEDDING_MODEL,
+     ) -> None:
+         """
+         Initialize the RAG search tool.
+
+         Args:
+             base_path: Base path to search in (defaults to CWD)
+             embedding_model: Ollama embedding model to use
+         """
+         super().__init__(
+             name="rag_search",
+             description=(
+                 "Search for code and text using semantic similarity. "
+                 "Finds relevant files and code snippets based on meaning, not just keywords."
+             ),
+             parameters=[
+                 ParameterSchema(
+                     name="query",
+                     type="string",
+                     description="Natural language search query",
+                 ),
+                 ParameterSchema(
+                     name="top_k",
+                     type="integer",
+                     description="Number of results to return",
+                     required=False,
+                     default=5,
+                     minimum=1,
+                     maximum=20,
+                 ),
+                 ParameterSchema(
+                     name="rebuild",
+                     type="boolean",
+                     description="Force rebuild the index before searching",
+                     required=False,
+                     default=False,
+                 ),
+             ],
+             category=ToolCategory.SEARCH,
+         )
+
+         self._base_path = base_path or Path.cwd()
+         self._embedding_model = embedding_model
+         self._index: RAGIndex | None = None
+
+     def _get_or_build_index(self, rebuild: bool = False) -> RAGIndex:
+         """Get existing index or build a new one."""
+         if self._index is None:
+             self._index = RAGIndex(
+                 base_path=self._base_path,
+                 embedding_model=self._embedding_model,
+             )
+
+         if rebuild:
+             self._index.build()
+         elif not self._index._is_built:
+             # Try loading first, otherwise build
+             if not self._index.load():
+                 self._index.build()
+
+         return self._index
+
+     def execute(
+         self,
+         query: str,
+         top_k: int = 5,
+         rebuild: bool = False,
+     ) -> list[dict[str, Any]]:
+         """
+         Execute semantic search.
+
+         Args:
+             query: Natural language search query
+             top_k: Number of results to return
+             rebuild: Force rebuild the index
+
+         Returns:
+             List of search result dictionaries
+         """
+         index = self._get_or_build_index(rebuild)
+         results = index.search(query, top_k)
+         return [r.to_dict() for r in results]
+
+     async def aexecute(
+         self,
+         query: str,
+         top_k: int = 5,
+         rebuild: bool = False,
+     ) -> list[dict[str, Any]]:
+         """Async version of execute."""
+         import asyncio
+
+         return await asyncio.to_thread(self.execute, query, top_k, rebuild)
+
+     def get_interruption_message(self, query: str, **kwargs) -> str:
+         """Get interruption message for user confirmation."""
+         return f"execute rag_search: {query}"
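
For orientation, a minimal usage sketch of the RAGIndex API added above. It assumes a running local Ollama server with the all-minilm:22m model pulled and faiss-cpu installed; the import path follows this package's layout (kader/tools/rag.py).

from pathlib import Path

from kader.tools.rag import RAGIndex

# First run: chunk, embed, and index the project; build() auto-saves
# under <base_path>/.kader/index/.
index = RAGIndex(base_path=Path.cwd())
num_chunks = index.build()
print(f"Indexed {num_chunks} chunks")

# Subsequent runs: load the persisted index instead of re-embedding.
index = RAGIndex(base_path=Path.cwd())
if not index.load():
    index.build()

for result in index.search("function to read file", top_k=3):
    print(f"{result.file_path}:{result.start_line}-{result.end_line}  score={result.score:.3f}")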
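The tool wrapper follows the same pattern: execute() returns plain dictionaries (via SearchResult.to_dict()), and aexecute() offloads the blocking work with asyncio.to_thread. A sketch under the same assumptions:

import asyncio

from kader.tools.rag import RAGSearchTool

tool = RAGSearchTool()  # defaults: base_path=Path.cwd(), model="all-minilm:22m"

# The first call builds or loads the index, then searches.
for hit in tool.execute(query="where are embeddings generated?", top_k=5):
    print(hit["file_path"], round(hit["score"], 3))

# rebuild=True forces a fresh index, e.g. after large code changes.
hits = asyncio.run(tool.aexecute(query="chunk overlap logic", rebuild=True))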
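_chunk_text packs whole lines into chunks of at most CHUNK_SIZE (500) characters and starts each new chunk with the last quarter of the previous chunk's lines; note the module-level CHUNK_OVERLAP constant is not actually consulted. A small sketch that calls the private helper directly, purely for illustration of the overlapping line ranges:

from pathlib import Path

from kader.tools.rag import RAGIndex

index = RAGIndex(base_path=Path.cwd())

# 40 synthetic lines of 48 characters each -> roughly 10 lines per chunk.
text = "\n".join(f"line {n:02d} " + "x" * 40 for n in range(1, 41))

for chunk in index._chunk_text(text, "example.txt"):
    print(f"chunk {chunk.chunk_index}: lines {chunk.start_line}-{chunk.end_line}")

# With these sizes the printed ranges overlap by a couple of lines,
# e.g. 1-10, 9-18, 17-26, 25-34, 33-40.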
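search() turns raw FAISS L2 distances into bounded relevance scores with score = 1 / (1 + distance), so an exact match scores 1.0 and scores decay toward 0 as distance grows. A self-contained sketch of that conversion, with random vectors standing in for Ollama embeddings (384 dimensions is an assumption here, matching what all-MiniLM models typically produce):

import faiss
import numpy as np

rng = np.random.default_rng(0)
vectors = rng.random((100, 384), dtype=np.float32)

index = faiss.IndexFlatL2(vectors.shape[1])  # exact (brute-force) L2 index
index.add(vectors)

# Query with a stored vector: distance 0.0 -> score exactly 1.0.
distances, indices = index.search(vectors[:1], 5)
scores = 1.0 / (1.0 + distances[0])
print(indices[0][0], scores[0])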
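Both embedding helpers go through ollama.Client.embed, which accepts a single string or a list of strings and returns a response whose .embeddings attribute is a list of vectors; these are the same call shapes used by _embed_text and _embed_texts above. A quick sketch, assuming the model has already been pulled (ollama pull all-minilm:22m):

from ollama import Client

client = Client()  # default endpoint: http://localhost:11434

single = client.embed(model="all-minilm:22m", input="hello world")
print(len(single.embeddings[0]))  # embedding dimension

batch = client.embed(model="all-minilm:22m", input=["first text", "second text"])
print(len(batch.embeddings))  # 2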