ragit 0.3__py3-none-any.whl → 0.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ragit/assistant.py ADDED
@@ -0,0 +1,757 @@
+ #
+ # Copyright RODMENA LIMITED 2025
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ """
+ High-level RAG Assistant for document Q&A and code generation.
+
+ Provides a simple interface for RAG-based tasks.
+
+ Note: This class is NOT thread-safe. Do not share instances across threads.
+ """
+
+ from collections.abc import Callable
+ from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from ragit.core.experiment.experiment import Chunk, Document
+ from ragit.loaders import chunk_document, chunk_rst_sections, load_directory, load_text
+ from ragit.logging import log_operation
+ from ragit.providers.base import BaseEmbeddingProvider, BaseLLMProvider
+ from ragit.providers.function_adapter import FunctionProvider
+
+ if TYPE_CHECKING:
+     from numpy.typing import NDArray
+
+
+ class RAGAssistant:
+     """
+     High-level RAG assistant for document Q&A and generation.
+
+     Handles document indexing, retrieval, and LLM generation in one simple API.
+
+     Parameters
+     ----------
+     documents : list[Document] or str or Path
+         Documents to index. Can be:
+         - List of Document objects
+         - Path to a single file
+         - Path to a directory (will load all .txt, .md, .rst files)
+     embed_fn : Callable[[str], list[float]], optional
+         Function that takes text and returns an embedding vector.
+         If provided, creates a FunctionProvider internally.
+     generate_fn : Callable, optional
+         Function for text generation. Supports (prompt) or (prompt, system_prompt).
+         Only used when embed_fn is also provided.
+     provider : BaseEmbeddingProvider, optional
+         Provider for embeddings (and optionally LLM). If embed_fn is provided,
+         this is ignored for embeddings.
+     embedding_model : str, optional
+         Embedding model name (used with provider).
+     llm_model : str, optional
+         LLM model name (used with provider).
+     chunk_size : int, optional
+         Chunk size for splitting documents (default: 512).
+     chunk_overlap : int, optional
+         Overlap between chunks (default: 50).
+
+     Raises
+     ------
+     ValueError
+         If neither embed_fn nor provider is provided.
+
+     Note
+     ----
+     This class is NOT thread-safe. Each thread should have its own instance.
+
+     Examples
+     --------
+     >>> # With custom embedding function (retrieval-only)
+     >>> assistant = RAGAssistant(docs, embed_fn=my_embed)
+     >>> results = assistant.retrieve("query")
+     >>>
+     >>> # With custom embedding and LLM functions (full RAG)
+     >>> assistant = RAGAssistant(docs, embed_fn=my_embed, generate_fn=my_llm)
+     >>> answer = assistant.ask("What is X?")
+     >>>
+     >>> # With Ollama provider (supports nomic-embed-text)
+     >>> from ragit.providers import OllamaProvider
+     >>> assistant = RAGAssistant(docs, provider=OllamaProvider())
+     """
+
+     def __init__(
+         self,
+         documents: list[Document] | str | Path,
+         embed_fn: Callable[[str], list[float]] | None = None,
+         generate_fn: Callable[..., str] | None = None,
+         provider: BaseEmbeddingProvider | BaseLLMProvider | None = None,
+         embedding_model: str | None = None,
+         llm_model: str | None = None,
+         chunk_size: int = 512,
+         chunk_overlap: int = 50,
+     ):
+         # Resolve provider from embed_fn/generate_fn or explicit provider
+         self._embedding_provider: BaseEmbeddingProvider
+         self._llm_provider: BaseLLMProvider | None = None
+
+         if embed_fn is not None:
+             # Create FunctionProvider from provided functions
+             function_provider = FunctionProvider(
+                 embed_fn=embed_fn,
+                 generate_fn=generate_fn,
+             )
+             self._embedding_provider = function_provider
+             if generate_fn is not None:
+                 self._llm_provider = function_provider
+             elif provider is not None and isinstance(provider, BaseLLMProvider):
+                 # Use explicit provider for LLM if function_provider doesn't have LLM
+                 self._llm_provider = provider
+         elif provider is not None:
+             # Use explicit provider
+             if not isinstance(provider, BaseEmbeddingProvider):
+                 raise ValueError(
+                     "Provider must implement BaseEmbeddingProvider for embeddings. Alternatively, provide embed_fn."
+                 )
+             self._embedding_provider = provider
+             if isinstance(provider, BaseLLMProvider):
+                 self._llm_provider = provider
+         else:
+             raise ValueError(
+                 "Must provide embed_fn or provider for embeddings. "
+                 "Examples:\n"
+                 " RAGAssistant(docs, embed_fn=my_embed_function)\n"
+                 " RAGAssistant(docs, provider=OllamaProvider())"
+             )
+
+         self.embedding_model = embedding_model or "default"
+         self.llm_model = llm_model or "default"
+         self.chunk_size = chunk_size
+         self.chunk_overlap = chunk_overlap
+
+         # Load documents if path provided
+         self.documents = self._load_documents(documents)
+
+         # Index chunks - embeddings stored as pre-normalized numpy matrix for fast search
+         self._chunks: tuple[Chunk, ...] = ()
+         self._embedding_matrix: NDArray[np.float64] | None = None  # Pre-normalized
+         self._build_index()
+
+     def _load_documents(self, documents: list[Document] | str | Path) -> list[Document]:
+         """Load documents from various sources."""
+         if isinstance(documents, list):
+             return documents
+
+         path = Path(documents)
+
+         if path.is_file():
+             return [load_text(path)]
+
+         if path.is_dir():
+             docs: list[Document] = []
+             for pattern in (
+                 "*.txt",
+                 "*.md",
+                 "*.rst",
+                 "*.py",
+                 "*.js",
+                 "*.ts",
+                 "*.go",
+                 "*.java",
+                 "*.c",
+                 "*.cpp",
+                 "*.h",
+                 "*.hpp",
+             ):
+                 docs.extend(load_directory(path, pattern))
+             return docs
+
+         raise ValueError(f"Invalid documents source: {documents}")
+
+     def _build_index(self) -> None:
+         """Build vector index from documents using batch embedding."""
+         all_chunks: list[Chunk] = []
+
+         for doc in self.documents:
+             # Use RST section chunking for .rst files, otherwise regular chunking
+             if doc.metadata.get("filename", "").endswith(".rst"):
+                 chunks = chunk_rst_sections(doc.content, doc.id)
+             else:
+                 chunks = chunk_document(doc, self.chunk_size, self.chunk_overlap)
+             all_chunks.extend(chunks)
+
+         if not all_chunks:
+             self._chunks = ()
+             self._embedding_matrix = None
+             return
+
+         # Batch embed all chunks at once (single API call)
+         texts = [chunk.content for chunk in all_chunks]
+         responses = self._embedding_provider.embed_batch(texts, self.embedding_model)
+
+         # Build embedding matrix directly (skip storing in chunks to avoid duplication)
+         embedding_matrix = np.array([response.embedding for response in responses], dtype=np.float64)
+
+         # Pre-normalize for fast cosine similarity (normalize once, use many times)
+         norms = np.linalg.norm(embedding_matrix, axis=1, keepdims=True)
+         norms[norms == 0] = 1  # Avoid division by zero
+
+         # Store as immutable tuple and pre-normalized numpy matrix
+         self._chunks = tuple(all_chunks)
+         self._embedding_matrix = embedding_matrix / norms
+
+     def add_documents(self, documents: list[Document] | str | Path) -> int:
+         """Add documents to the existing index incrementally.
+
+         Args:
+             documents: Documents to add.
+
+         Returns:
+             Number of chunks added.
+         """
+         new_docs = self._load_documents(documents)
+         if not new_docs:
+             return 0
+
+         self.documents.extend(new_docs)
+
+         # Chunk new docs
+         new_chunks: list[Chunk] = []
+         for doc in new_docs:
+             if doc.metadata.get("filename", "").endswith(".rst"):
+                 chunks = chunk_rst_sections(doc.content, doc.id)
+             else:
+                 chunks = chunk_document(doc, self.chunk_size, self.chunk_overlap)
+             new_chunks.extend(chunks)
+
+         if not new_chunks:
+             return 0
+
+         # Embed new chunks
+         texts = [chunk.content for chunk in new_chunks]
+         responses = self._embedding_provider.embed_batch(texts, self.embedding_model)
+
+         new_matrix = np.array([response.embedding for response in responses], dtype=np.float64)
+
+         # Normalize
+         norms = np.linalg.norm(new_matrix, axis=1, keepdims=True)
+         norms[norms == 0] = 1
+         new_matrix_norm = new_matrix / norms
+
+         # Update state
+         current_chunks = list(self._chunks)
+         current_chunks.extend(new_chunks)
+         self._chunks = tuple(current_chunks)
+
+         if self._embedding_matrix is None:
+             self._embedding_matrix = new_matrix_norm
+         else:
+             self._embedding_matrix = np.vstack((self._embedding_matrix, new_matrix_norm))
+
+         return len(new_chunks)
+
+     def remove_documents(self, source_path_pattern: str) -> int:
+         """Remove documents matching a source path pattern.
+
+         Args:
+             source_path_pattern: Glob pattern to match 'source' metadata.
+
+         Returns:
+             Number of chunks removed.
+         """
+         import fnmatch
+
+         if not self._chunks:
+             return 0
+
+         indices_to_keep = []
+         kept_chunks = []
+         removed_count = 0
+
+         for i, chunk in enumerate(self._chunks):
+             source = chunk.metadata.get("source", "")
+             if not source or not fnmatch.fnmatch(source, source_path_pattern):
+                 indices_to_keep.append(i)
+                 kept_chunks.append(chunk)
+             else:
+                 removed_count += 1
+
+         if removed_count == 0:
+             return 0
+
+         self._chunks = tuple(kept_chunks)
+
+         if self._embedding_matrix is not None:
+             if not kept_chunks:
+                 self._embedding_matrix = None
+             else:
+                 self._embedding_matrix = self._embedding_matrix[indices_to_keep]
+
+         # Also remove from self.documents
+         self.documents = [
+             doc for doc in self.documents if not fnmatch.fnmatch(doc.metadata.get("source", ""), source_path_pattern)
+         ]
+
+         return removed_count
+
+     def update_documents(self, documents: list[Document] | str | Path) -> int:
+         """Update existing documents (remove old, add new).
+
+         Uses document source path to identify what to remove.
+
+         Args:
+             documents: New versions of documents.
+
+         Returns:
+             Number of chunks added.
+         """
+         new_docs = self._load_documents(documents)
+         if not new_docs:
+             return 0
+
+         # Identify sources to remove
+         sources_to_remove = set()
+         for doc in new_docs:
+             source = doc.metadata.get("source")
+             if source:
+                 sources_to_remove.add(source)
+
+         # Remove old versions
+         for source in sources_to_remove:
+             self.remove_documents(source)
+
+         # Add new versions
+         return self.add_documents(new_docs)
+
+     def retrieve(self, query: str, top_k: int = 3) -> list[tuple[Chunk, float]]:
+         """
+         Retrieve relevant chunks for a query.
+
+         Uses vectorized cosine similarity for fast search over all chunks.
+
+         Parameters
+         ----------
+         query : str
+             Search query.
+         top_k : int
+             Number of chunks to return (default: 3).
+
+         Returns
+         -------
+         list[tuple[Chunk, float]]
+             List of (chunk, similarity_score) tuples, sorted by relevance.
+
+         Examples
+         --------
+         >>> results = assistant.retrieve("how to create a route")
+         >>> for chunk, score in results:
+         ...     print(f"{score:.2f}: {chunk.content[:100]}...")
+         """
+         if not self._chunks or self._embedding_matrix is None:
+             return []
+
+         # Get query embedding and normalize
+         query_response = self._embedding_provider.embed(query, self.embedding_model)
+         query_vec = np.array(query_response.embedding, dtype=np.float64)
+         query_norm = np.linalg.norm(query_vec)
+         if query_norm == 0:
+             return []
+         query_normalized = query_vec / query_norm
+
+         # Fast cosine similarity: matrix is pre-normalized, just dot product
+         similarities = self._embedding_matrix @ query_normalized
+
+         # Get top_k indices using argpartition (faster than full sort for large arrays)
+         if len(similarities) <= top_k:
+             top_indices = np.argsort(similarities)[::-1]
+         else:
+             # Partial sort - only find top_k elements
+             top_indices = np.argpartition(similarities, -top_k)[-top_k:]
+             # Sort the top_k by score
+             top_indices = top_indices[np.argsort(similarities[top_indices])[::-1]]
+
+         return [(self._chunks[i], float(similarities[i])) for i in top_indices]
+
+     def retrieve_with_context(
+         self,
+         query: str,
+         top_k: int = 3,
+         window_size: int = 1,
+         min_score: float = 0.0,
+     ) -> list[tuple[Chunk, float]]:
+         """
+         Retrieve chunks with adjacent context expansion (window search).
+
+         For each retrieved chunk, also includes adjacent chunks from the
+         same document to provide more context. This is useful when relevant
+         information spans multiple chunks.
+
+         Pattern inspired by ai4rag window_search.
+
+         Parameters
+         ----------
+         query : str
+             Search query.
+         top_k : int
+             Number of initial chunks to retrieve (default: 3).
+         window_size : int
+             Number of adjacent chunks to include on each side (default: 1).
+             Set to 0 to disable window expansion.
+         min_score : float
+             Minimum similarity score threshold (default: 0.0).
+
+         Returns
+         -------
+         list[tuple[Chunk, float]]
+             List of (chunk, similarity_score) tuples, sorted by relevance.
+             Adjacent chunks have slightly lower scores.
+
+         Examples
+         --------
+         >>> # Get chunks with 1 adjacent chunk on each side
+         >>> results = assistant.retrieve_with_context("query", window_size=1)
+         >>> for chunk, score in results:
+         ...     print(f"{score:.2f}: {chunk.content[:50]}...")
+         """
+         with log_operation("retrieve_with_context", query_len=len(query), top_k=top_k, window_size=window_size) as ctx:
+             # Get initial results (more than top_k to account for filtering)
+             results = self.retrieve(query, top_k * 2)
+
+             # Apply minimum score threshold
+             if min_score > 0:
+                 results = [(chunk, score) for chunk, score in results if score >= min_score]
+
+             if window_size == 0 or not results:
+                 ctx["expanded_chunks"] = len(results)
+                 return results[:top_k]
+
+             # Build chunk index for fast lookup
+             chunk_to_idx = {id(chunk): i for i, chunk in enumerate(self._chunks)}
+
+             expanded_results: list[tuple[Chunk, float]] = []
+             seen_indices: set[int] = set()
+
+             for chunk, score in results[:top_k]:
+                 chunk_idx = chunk_to_idx.get(id(chunk))
+                 if chunk_idx is None:
+                     expanded_results.append((chunk, score))
+                     continue
+
+                 # Get window of adjacent chunks from same document
+                 start_idx = max(0, chunk_idx - window_size)
+                 end_idx = min(len(self._chunks), chunk_idx + window_size + 1)
+
+                 for idx in range(start_idx, end_idx):
+                     if idx in seen_indices:
+                         continue
+
+                     adjacent_chunk = self._chunks[idx]
+                     # Only include adjacent chunks from same document
+                     if adjacent_chunk.doc_id == chunk.doc_id:
+                         seen_indices.add(idx)
+                         # Original chunk keeps full score, adjacent get 80%
+                         adj_score = score if idx == chunk_idx else score * 0.8
+                         expanded_results.append((adjacent_chunk, adj_score))
+
+             # Sort by score (highest first)
+             expanded_results.sort(key=lambda x: (-x[1], self._chunks.index(x[0]) if x[0] in self._chunks else 0))
+             ctx["expanded_chunks"] = len(expanded_results)
+
+             return expanded_results
+
+     def get_context_with_window(
+         self,
+         query: str,
+         top_k: int = 3,
+         window_size: int = 1,
+         min_score: float = 0.0,
+     ) -> str:
+         """
+         Get formatted context with adjacent chunk expansion.
+
+         Merges overlapping text from adjacent chunks intelligently.
+
+         Parameters
+         ----------
+         query : str
+             Search query.
+         top_k : int
+             Number of initial chunks to retrieve.
+         window_size : int
+             Number of adjacent chunks on each side.
+         min_score : float
+             Minimum similarity score threshold.
+
+         Returns
+         -------
+         str
+             Formatted context string with merged chunks.
+         """
+         results = self.retrieve_with_context(query, top_k, window_size, min_score)
+
+         if not results:
+             return ""
+
+         # Group chunks by document to merge properly
+         doc_chunks: dict[str, list[tuple[Chunk, float]]] = {}
+         for chunk, score in results:
+             doc_id = chunk.doc_id or "unknown"
+             if doc_id not in doc_chunks:
+                 doc_chunks[doc_id] = []
+             doc_chunks[doc_id].append((chunk, score))
+
+         merged_sections: list[str] = []
+
+         for _doc_id, chunks in doc_chunks.items():
+             # Sort chunks by their position in the original list
+             chunks.sort(key=lambda x: self._chunks.index(x[0]) if x[0] in self._chunks else 0)
+
+             # Merge overlapping text
+             merged_content = []
+             for chunk, _ in chunks:
+                 if merged_content:
+                     # Check for overlap with previous chunk
+                     prev_content = merged_content[-1]
+                     non_overlapping = self._get_non_overlapping_text(prev_content, chunk.content)
+                     if non_overlapping != chunk.content:
+                         # Found overlap, extend previous chunk
+                         merged_content[-1] = prev_content + non_overlapping
+                     else:
+                         # No overlap, add as new section
+                         merged_content.append(chunk.content)
+                 else:
+                     merged_content.append(chunk.content)
+
+             merged_sections.append("\n".join(merged_content))
+
+         return "\n\n---\n\n".join(merged_sections)
+
+     def _get_non_overlapping_text(self, str1: str, str2: str) -> str:
+         """
+         Find non-overlapping portion of str2 when appending after str1.
+
+         Detects overlap where the end of str1 matches the beginning of str2,
+         and returns only the non-overlapping portion of str2.
+
+         Pattern from ai4rag vector_store/utils.py.
+
+         Parameters
+         ----------
+         str1 : str
+             First string (previous content).
+         str2 : str
+             Second string (content to potentially append).
+
+         Returns
+         -------
+         str
+             Non-overlapping portion of str2, or full str2 if no overlap.
+         """
+         # Limit overlap search to avoid O(n^2) for large strings
+         max_overlap = min(len(str1), len(str2), 200)
+
+         for i in range(max_overlap, 0, -1):
+             if str1[-i:] == str2[:i]:
+                 return str2[i:]
+
+         return str2
+
+     def get_context(self, query: str, top_k: int = 3) -> str:
+         """
+         Get formatted context string from retrieved chunks.
+
+         Parameters
+         ----------
+         query : str
+             Search query.
+         top_k : int
+             Number of chunks to include.
+
+         Returns
+         -------
+         str
+             Formatted context string.
+         """
+         results = self.retrieve(query, top_k)
+         return "\n\n---\n\n".join(chunk.content for chunk, _ in results)
+
+     def _ensure_llm(self) -> BaseLLMProvider:
+         """Ensure LLM provider is available."""
+         if self._llm_provider is None:
+             raise NotImplementedError(
+                 "No LLM configured. Provide generate_fn or a provider with LLM support "
+                 "to use ask(), generate(), or generate_code() methods."
+             )
+         return self._llm_provider
+
+     def generate(
+         self,
+         prompt: str,
+         system_prompt: str | None = None,
+         temperature: float = 0.7,
+     ) -> str:
+         """
+         Generate text using the LLM (without retrieval).
+
+         Parameters
+         ----------
+         prompt : str
+             User prompt.
+         system_prompt : str, optional
+             System prompt for context.
+         temperature : float
+             Sampling temperature (default: 0.7).
+
+         Returns
+         -------
+         str
+             Generated text.
+
+         Raises
+         ------
+         NotImplementedError
+             If no LLM is configured.
+         """
+         llm = self._ensure_llm()
+         response = llm.generate(
+             prompt=prompt,
+             model=self.llm_model,
+             system_prompt=system_prompt,
+             temperature=temperature,
+         )
+         return response.text
+
+     def ask(
+         self,
+         question: str,
+         system_prompt: str | None = None,
+         top_k: int = 3,
+         temperature: float = 0.7,
+     ) -> str:
+         """
+         Ask a question using RAG (retrieve + generate).
+
+         Parameters
+         ----------
+         question : str
+             Question to answer.
+         system_prompt : str, optional
+             System prompt. Defaults to a helpful assistant prompt.
+         top_k : int
+             Number of context chunks to retrieve (default: 3).
+         temperature : float
+             Sampling temperature (default: 0.7).
+
+         Returns
+         -------
+         str
+             Generated answer.
+
+         Raises
+         ------
+         NotImplementedError
+             If no LLM is configured.
+
+         Examples
+         --------
+         >>> answer = assistant.ask("How do I create a REST API?")
+         >>> print(answer)
+         """
+         # Retrieve context
+         context = self.get_context(question, top_k)
+
+         # Default system prompt
+         if system_prompt is None:
+             system_prompt = """You are a helpful assistant. Answer questions based on the provided context.
+ If the context doesn't contain enough information, say so. Be concise and accurate."""
+
+         # Build prompt with context
+         prompt = f"""Context:
+ {context}
+
+ Question: {question}
+
+ Answer:"""
+
+         return self.generate(prompt, system_prompt, temperature)
+
+     def generate_code(
+         self,
+         request: str,
+         language: str = "python",
+         top_k: int = 3,
+         temperature: float = 0.7,
+     ) -> str:
+         """
+         Generate code based on documentation context.
+
+         Parameters
+         ----------
+         request : str
+             Description of what code to generate.
+         language : str
+             Programming language (default: "python").
+         top_k : int
+             Number of context chunks to retrieve.
+         temperature : float
+             Sampling temperature.
+
+         Returns
+         -------
+         str
+             Generated code (cleaned, without markdown).
+
+         Raises
+         ------
+         NotImplementedError
+             If no LLM is configured.
+
+         Examples
+         --------
+         >>> code = assistant.generate_code("create a REST API with user endpoints")
+         >>> print(code)
+         """
+         context = self.get_context(request, top_k)
+
+         system_prompt = f"""You are an expert {language} developer. Generate ONLY valid {language} code.
+
+ RULES:
+ 1. Output PURE CODE ONLY - no explanations, no markdown code blocks
+ 2. Include necessary imports
+ 3. Write clean, production-ready code
+ 4. Add brief comments for clarity"""
+
+         prompt = f"""Documentation:
+ {context}
+
+ Request: {request}
+
+ Generate the {language} code:"""
+
+         response = self.generate(prompt, system_prompt, temperature)
+
+         # Clean up response - remove markdown if present
+         code = response
+         if f"```{language}" in code:
+             code = code.split(f"```{language}")[1].split("```")[0]
+         elif "```" in code:
+             code = code.split("```")[1].split("```")[0]
+
+         return code.strip()
+
+     @property
+     def num_chunks(self) -> int:
+         """Return number of indexed chunks."""
+         return len(self._chunks)
+
+     @property
+     def num_documents(self) -> int:
+         """Return number of loaded documents."""
+         return len(self.documents)
+
+     @property
+     def has_llm(self) -> bool:
+         """Check if LLM is configured."""
+         return self._llm_provider is not None
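
For orientation, a minimal retrieval-only sketch of how the new RAGAssistant API can be exercised, based on the docstring examples above. The toy_embed function and the docs/ directory are illustrative stand-ins, not part of the package; a real setup would supply its own embedding function or a provider such as OllamaProvider.

import hashlib

from ragit.assistant import RAGAssistant


def toy_embed(text: str) -> list[float]:
    # Deterministic toy embedding: hash each token into one of 64 buckets.
    # Stand-in for a real embedding model, only to make the sketch runnable.
    vec = [0.0] * 64
    for token in text.lower().split():
        vec[int(hashlib.md5(token.encode()).hexdigest(), 16) % 64] += 1.0
    return vec


# Index a directory of .txt/.md/.rst (and source) files, then retrieve.
assistant = RAGAssistant("docs/", embed_fn=toy_embed)
for chunk, score in assistant.retrieve("how do I add documents?", top_k=3):
    print(f"{score:.2f}  {chunk.content[:80]}")

# Passing generate_fn (or an LLM-capable provider) as well enables
# assistant.ask(), assistant.generate(), and assistant.generate_code().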