ragit 0.8-py3-none-any.whl → 0.8.1-py3-none-any.whl

This diff shows the content of two publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
ragit/__init__.py CHANGED
@@ -1,2 +1,116 @@
- # __init__.py
- from .main import VectorDBManager
+ #
+ # Copyright RODMENA LIMITED 2025
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ """
+ Ragit - RAG toolkit for document Q&A and hyperparameter optimization.
+
+ Quick Start
+ -----------
+ >>> from ragit import RAGAssistant
+ >>>
+ >>> # With custom embedding function (retrieval-only)
+ >>> def my_embed(text: str) -> list[float]:
+ ...     # Your embedding implementation
+ ...     pass
+ >>> assistant = RAGAssistant("docs/", embed_fn=my_embed)
+ >>> results = assistant.retrieve("How do I create a REST API?")
+ >>>
+ >>> # With SentenceTransformers (offline, requires ragit[transformers])
+ >>> from ragit.providers import SentenceTransformersProvider
+ >>> assistant = RAGAssistant("docs/", provider=SentenceTransformersProvider())
+ >>>
+ >>> # With Ollama (explicit)
+ >>> from ragit.providers import OllamaProvider
+ >>> assistant = RAGAssistant("docs/", provider=OllamaProvider())
+ >>> answer = assistant.ask("How do I create a REST API?")
+
+ Optimization
+ ------------
+ >>> from ragit import RagitExperiment, Document, BenchmarkQuestion
+ >>>
+ >>> docs = [Document(id="doc1", content="...")]
+ >>> benchmark = [BenchmarkQuestion(question="What is X?", ground_truth="...")]
+ >>>
+ >>> # With explicit provider
+ >>> experiment = RagitExperiment(docs, benchmark, provider=OllamaProvider())
+ >>> results = experiment.run()
+ >>> print(results[0])  # Best configuration
+ """
+
+ import logging
+ import os
+
+ from ragit.version import __version__
+
+ # Set up logging
+ logger = logging.getLogger("ragit")
+ logger.setLevel(os.getenv("RAGIT_LOG_LEVEL", "INFO"))
+
+ if not logger.handlers:
+     formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+     handler = logging.StreamHandler()
+     handler.setFormatter(formatter)
+     logger.addHandler(handler)
+
+ # Public API (imports after logging setup)
+ from ragit.assistant import RAGAssistant  # noqa: E402
+ from ragit.core.experiment.experiment import (  # noqa: E402
+     BenchmarkQuestion,
+     Chunk,
+     Document,
+     RAGConfig,
+     RagitExperiment,
+ )
+ from ragit.core.experiment.results import EvaluationResult, ExperimentResults  # noqa: E402
+ from ragit.loaders import (  # noqa: E402
+     chunk_by_separator,
+     chunk_document,
+     chunk_rst_sections,
+     chunk_text,
+     load_directory,
+     load_text,
+ )
+ from ragit.providers import (  # noqa: E402
+     BaseEmbeddingProvider,
+     BaseLLMProvider,
+     FunctionProvider,
+     OllamaProvider,
+ )
+
+ __all__ = [
+     "__version__",
+     # High-level API
+     "RAGAssistant",
+     # Document loading
+     "load_text",
+     "load_directory",
+     "chunk_text",
+     "chunk_document",
+     "chunk_by_separator",
+     "chunk_rst_sections",
+     # Core classes
+     "Document",
+     "Chunk",
+     # Providers
+     "OllamaProvider",
+     "FunctionProvider",
+     "BaseLLMProvider",
+     "BaseEmbeddingProvider",
+     # Optimization
+     "RagitExperiment",
+     "BenchmarkQuestion",
+     "RAGConfig",
+     "EvaluationResult",
+     "ExperimentResults",
+ ]
+
+ # Conditionally add SentenceTransformersProvider if available
+ try:
+     from ragit.providers import (  # noqa: E402
+         SentenceTransformersProvider as SentenceTransformersProvider,
+     )
+
+     __all__ += ["SentenceTransformersProvider"]
+ except ImportError:
+     pass
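Since the `try`/`except` above makes `SentenceTransformersProvider` an optional export, downstream code can feature-detect it via `__all__`. A minimal sketch (not part of the package) assuming only the public API shown in this diff:

```python
# Hypothetical consumer code; feature-detects the optional provider.
import ragit

if "SentenceTransformersProvider" in ragit.__all__:
    # Optional extra is installed; embeddings run fully offline.
    assistant = ragit.RAGAssistant("docs/", provider=ragit.SentenceTransformersProvider())
else:
    # Per the module docstring, the extra is named ragit[transformers].
    print("For offline embeddings: pip install 'ragit[transformers]'")
```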
ragit/assistant.py ADDED
@@ -0,0 +1,442 @@
+ #
+ # Copyright RODMENA LIMITED 2025
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ """
+ High-level RAG Assistant for document Q&A and code generation.
+
+ Provides a simple interface for RAG-based tasks.
+
+ Note: This class is NOT thread-safe. Do not share instances across threads.
+ """
+
+ from collections.abc import Callable
+ from pathlib import Path
+ from typing import TYPE_CHECKING
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from ragit.core.experiment.experiment import Chunk, Document
+ from ragit.loaders import chunk_document, chunk_rst_sections, load_directory, load_text
+ from ragit.providers.base import BaseEmbeddingProvider, BaseLLMProvider
+ from ragit.providers.function_adapter import FunctionProvider
+
+ if TYPE_CHECKING:
+     from numpy.typing import NDArray
+
+
+ class RAGAssistant:
+     """
+     High-level RAG assistant for document Q&A and generation.
+
+     Handles document indexing, retrieval, and LLM generation in one simple API.
+
+     Parameters
+     ----------
+     documents : list[Document] or str or Path
+         Documents to index. Can be:
+         - List of Document objects
+         - Path to a single file
+         - Path to a directory (will load all .txt, .md, .rst files)
+     embed_fn : Callable[[str], list[float]], optional
+         Function that takes text and returns an embedding vector.
+         If provided, creates a FunctionProvider internally.
+     generate_fn : Callable, optional
+         Function for text generation. Supports (prompt) or (prompt, system_prompt).
+         Requires embed_fn to be provided as well.
+     provider : BaseEmbeddingProvider, optional
+         Provider for embeddings (and optionally LLM). If embed_fn is provided,
+         this is ignored for embeddings.
+     embedding_model : str, optional
+         Embedding model name (used with provider).
+     llm_model : str, optional
+         LLM model name (used with provider).
+     chunk_size : int, optional
+         Chunk size for splitting documents (default: 512).
+     chunk_overlap : int, optional
+         Overlap between chunks (default: 50).
+
+     Raises
+     ------
+     ValueError
+         If neither embed_fn nor provider is provided.
+
+     Note
+     ----
+     This class is NOT thread-safe. Each thread should have its own instance.
+
+     Examples
+     --------
+     >>> # With custom embedding function (retrieval-only)
+     >>> assistant = RAGAssistant(docs, embed_fn=my_embed)
+     >>> results = assistant.retrieve("query")
+     >>>
+     >>> # With custom embedding and LLM functions (full RAG)
+     >>> assistant = RAGAssistant(docs, embed_fn=my_embed, generate_fn=my_llm)
+     >>> answer = assistant.ask("What is X?")
+     >>>
+     >>> # With explicit provider
+     >>> from ragit.providers import OllamaProvider
+     >>> assistant = RAGAssistant(docs, provider=OllamaProvider())
+     >>>
+     >>> # With SentenceTransformers (offline)
+     >>> from ragit.providers import SentenceTransformersProvider
+     >>> assistant = RAGAssistant(docs, provider=SentenceTransformersProvider())
+     """
+
+     def __init__(
+         self,
+         documents: list[Document] | str | Path,
+         embed_fn: Callable[[str], list[float]] | None = None,
+         generate_fn: Callable[..., str] | None = None,
+         provider: BaseEmbeddingProvider | BaseLLMProvider | None = None,
+         embedding_model: str | None = None,
+         llm_model: str | None = None,
+         chunk_size: int = 512,
+         chunk_overlap: int = 50,
+     ):
+         # Resolve provider from embed_fn/generate_fn or explicit provider
+         self._embedding_provider: BaseEmbeddingProvider
+         self._llm_provider: BaseLLMProvider | None = None
+
+         if embed_fn is not None:
+             # Create FunctionProvider from provided functions
+             function_provider = FunctionProvider(
+                 embed_fn=embed_fn,
+                 generate_fn=generate_fn,
+             )
+             self._embedding_provider = function_provider
+             if generate_fn is not None:
+                 self._llm_provider = function_provider
+             elif provider is not None and isinstance(provider, BaseLLMProvider):
+                 # Use explicit provider for LLM if function_provider doesn't have LLM
+                 self._llm_provider = provider
+         elif provider is not None:
+             # Use explicit provider
+             if not isinstance(provider, BaseEmbeddingProvider):
+                 raise ValueError(
+                     "Provider must implement BaseEmbeddingProvider for embeddings. "
+                     "Alternatively, provide embed_fn."
+                 )
+             self._embedding_provider = provider
+             if isinstance(provider, BaseLLMProvider):
+                 self._llm_provider = provider
+         else:
+             raise ValueError(
+                 "Must provide embed_fn or provider for embeddings. "
+                 "Examples:\n"
+                 "  RAGAssistant(docs, embed_fn=my_embed_function)\n"
+                 "  RAGAssistant(docs, provider=OllamaProvider())\n"
+                 "  RAGAssistant(docs, provider=SentenceTransformersProvider())"
+             )
+
+         self.embedding_model = embedding_model or "default"
+         self.llm_model = llm_model or "default"
+         self.chunk_size = chunk_size
+         self.chunk_overlap = chunk_overlap
+
+         # Load documents if path provided
+         self.documents = self._load_documents(documents)
+
+         # Index chunks - embeddings stored as pre-normalized numpy matrix for fast search
+         self._chunks: tuple[Chunk, ...] = ()
+         self._embedding_matrix: NDArray[np.float64] | None = None  # Pre-normalized
+         self._build_index()
+
+     def _load_documents(self, documents: list[Document] | str | Path) -> list[Document]:
+         """Load documents from various sources."""
+         if isinstance(documents, list):
+             return documents
+
+         path = Path(documents)
+
+         if path.is_file():
+             return [load_text(path)]
+
+         if path.is_dir():
+             docs: list[Document] = []
+             for pattern in ("*.txt", "*.md", "*.rst"):
+                 docs.extend(load_directory(path, pattern))
+             return docs
+
+         raise ValueError(f"Invalid documents source: {documents}")
+
+     def _build_index(self) -> None:
+         """Build vector index from documents using batch embedding."""
+         all_chunks: list[Chunk] = []
+
+         for doc in self.documents:
+             # Use RST section chunking for .rst files, otherwise regular chunking
+             if doc.metadata.get("filename", "").endswith(".rst"):
+                 chunks = chunk_rst_sections(doc.content, doc.id)
+             else:
+                 chunks = chunk_document(doc, self.chunk_size, self.chunk_overlap)
+             all_chunks.extend(chunks)
+
+         if not all_chunks:
+             self._chunks = ()
+             self._embedding_matrix = None
+             return
+
+         # Batch embed all chunks at once (single API call)
+         texts = [chunk.content for chunk in all_chunks]
+         responses = self._embedding_provider.embed_batch(texts, self.embedding_model)
+
+         # Build embedding matrix directly (skip storing in chunks to avoid duplication)
+         embedding_matrix = np.array([response.embedding for response in responses], dtype=np.float64)
+
+         # Pre-normalize for fast cosine similarity (normalize once, use many times)
+         norms = np.linalg.norm(embedding_matrix, axis=1, keepdims=True)
+         norms[norms == 0] = 1  # Avoid division by zero
+
+         # Store as immutable tuple and pre-normalized numpy matrix
+         self._chunks = tuple(all_chunks)
+         self._embedding_matrix = embedding_matrix / norms
+
+     def retrieve(self, query: str, top_k: int = 3) -> list[tuple[Chunk, float]]:
+         """
+         Retrieve relevant chunks for a query.
+
+         Uses vectorized cosine similarity for fast search over all chunks.
+
+         Parameters
+         ----------
+         query : str
+             Search query.
+         top_k : int
+             Number of chunks to return (default: 3).
+
+         Returns
+         -------
+         list[tuple[Chunk, float]]
+             List of (chunk, similarity_score) tuples, sorted by relevance.
+
+         Examples
+         --------
+         >>> results = assistant.retrieve("how to create a route")
+         >>> for chunk, score in results:
+         ...     print(f"{score:.2f}: {chunk.content[:100]}...")
+         """
+         if not self._chunks or self._embedding_matrix is None:
+             return []
+
+         # Get query embedding and normalize
+         query_response = self._embedding_provider.embed(query, self.embedding_model)
+         query_vec = np.array(query_response.embedding, dtype=np.float64)
+         query_norm = np.linalg.norm(query_vec)
+         if query_norm == 0:
+             return []
+         query_normalized = query_vec / query_norm
+
+         # Fast cosine similarity: matrix is pre-normalized, just dot product
+         similarities = self._embedding_matrix @ query_normalized
+
+         # Get top_k indices using argpartition (faster than full sort for large arrays)
+         if len(similarities) <= top_k:
+             top_indices = np.argsort(similarities)[::-1]
+         else:
+             # Partial sort - only find top_k elements
+             top_indices = np.argpartition(similarities, -top_k)[-top_k:]
+             # Sort the top_k by score
+             top_indices = top_indices[np.argsort(similarities[top_indices])[::-1]]
+
+         return [(self._chunks[i], float(similarities[i])) for i in top_indices]
+
+     def get_context(self, query: str, top_k: int = 3) -> str:
+         """
+         Get formatted context string from retrieved chunks.
+
+         Parameters
+         ----------
+         query : str
+             Search query.
+         top_k : int
+             Number of chunks to include.
+
+         Returns
+         -------
+         str
+             Formatted context string.
+         """
+         results = self.retrieve(query, top_k)
+         return "\n\n---\n\n".join(chunk.content for chunk, _ in results)
+
+     def _ensure_llm(self) -> BaseLLMProvider:
+         """Ensure LLM provider is available."""
+         if self._llm_provider is None:
+             raise NotImplementedError(
+                 "No LLM configured. Provide generate_fn or a provider with LLM support "
+                 "to use ask(), generate(), or generate_code() methods."
+             )
+         return self._llm_provider
+
+     def generate(
+         self,
+         prompt: str,
+         system_prompt: str | None = None,
+         temperature: float = 0.7,
+     ) -> str:
+         """
+         Generate text using the LLM (without retrieval).
+
+         Parameters
+         ----------
+         prompt : str
+             User prompt.
+         system_prompt : str, optional
+             System prompt for context.
+         temperature : float
+             Sampling temperature (default: 0.7).
+
+         Returns
+         -------
+         str
+             Generated text.
+
+         Raises
+         ------
+         NotImplementedError
+             If no LLM is configured.
+         """
+         llm = self._ensure_llm()
+         response = llm.generate(
+             prompt=prompt,
+             model=self.llm_model,
+             system_prompt=system_prompt,
+             temperature=temperature,
+         )
+         return response.text
+
+     def ask(
+         self,
+         question: str,
+         system_prompt: str | None = None,
+         top_k: int = 3,
+         temperature: float = 0.7,
+     ) -> str:
+         """
+         Ask a question using RAG (retrieve + generate).
+
+         Parameters
+         ----------
+         question : str
+             Question to answer.
+         system_prompt : str, optional
+             System prompt. Defaults to a helpful assistant prompt.
+         top_k : int
+             Number of context chunks to retrieve (default: 3).
+         temperature : float
+             Sampling temperature (default: 0.7).
+
+         Returns
+         -------
+         str
+             Generated answer.
+
+         Raises
+         ------
+         NotImplementedError
+             If no LLM is configured.
+
+         Examples
+         --------
+         >>> answer = assistant.ask("How do I create a REST API?")
+         >>> print(answer)
+         """
+         # Retrieve context
+         context = self.get_context(question, top_k)
+
+         # Default system prompt
+         if system_prompt is None:
+             system_prompt = """You are a helpful assistant. Answer questions based on the provided context.
+ If the context doesn't contain enough information, say so. Be concise and accurate."""
+
+         # Build prompt with context
+         prompt = f"""Context:
+ {context}
+
+ Question: {question}
+
+ Answer:"""
+
+         return self.generate(prompt, system_prompt, temperature)
+
+     def generate_code(
+         self,
+         request: str,
+         language: str = "python",
+         top_k: int = 3,
+         temperature: float = 0.7,
+     ) -> str:
+         """
+         Generate code based on documentation context.
+
+         Parameters
+         ----------
+         request : str
+             Description of what code to generate.
+         language : str
+             Programming language (default: "python").
+         top_k : int
+             Number of context chunks to retrieve.
+         temperature : float
+             Sampling temperature.
+
+         Returns
+         -------
+         str
+             Generated code (cleaned, without markdown).
+
+         Raises
+         ------
+         NotImplementedError
+             If no LLM is configured.
+
+         Examples
+         --------
+         >>> code = assistant.generate_code("create a REST API with user endpoints")
+         >>> print(code)
+         """
+         context = self.get_context(request, top_k)
+
+         system_prompt = f"""You are an expert {language} developer. Generate ONLY valid {language} code.
+
+ RULES:
+ 1. Output PURE CODE ONLY - no explanations, no markdown code blocks
+ 2. Include necessary imports
+ 3. Write clean, production-ready code
+ 4. Add brief comments for clarity"""
+
+         prompt = f"""Documentation:
+ {context}
+
+ Request: {request}
+
+ Generate the {language} code:"""
+
+         response = self.generate(prompt, system_prompt, temperature)
+
+         # Clean up response - remove markdown if present
+         code = response
+         if f"```{language}" in code:
+             code = code.split(f"```{language}")[1].split("```")[0]
+         elif "```" in code:
+             code = code.split("```")[1].split("```")[0]
+
+         return code.strip()
+
+     @property
+     def num_chunks(self) -> int:
+         """Return number of indexed chunks."""
+         return len(self._chunks)
+
+     @property
+     def num_documents(self) -> int:
+         """Return number of loaded documents."""
+         return len(self.documents)
+
+     @property
+     def has_llm(self) -> bool:
+         """Check if LLM is configured."""
+         return self._llm_provider is not None
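To see the retrieval path end to end without a model server, here is a minimal sketch using a toy hashed bag-of-words `embed_fn`; the embedding scheme is an illustrative assumption, not something ragit ships:

```python
# Toy embed_fn: hashed bag-of-words. Real use would supply a proper model.
import hashlib

from ragit import Document, RAGAssistant


def toy_embed(text: str, dim: int = 256) -> list[float]:
    vec = [0.0] * dim
    for token in text.lower().split():
        idx = int(hashlib.md5(token.encode()).hexdigest(), 16) % dim
        vec[idx] += 1.0
    return vec  # _build_index()/retrieve() normalize, so raw counts are fine


docs = [
    Document(id="routing", content="Routes are registered with app.route()."),
    Document(id="storage", content="Use migrations to evolve the database schema."),
]
assistant = RAGAssistant(docs, embed_fn=toy_embed)  # retrieval-only, no LLM
for chunk, score in assistant.retrieve("how do I register a route?", top_k=1):
    print(f"{score:.2f}: {chunk.content}")
```

Because no `generate_fn` is given, `ask()` would raise `NotImplementedError` via `_ensure_llm()`; only `retrieve()` and `get_context()` are usable here.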
ragit/config.py ADDED
@@ -0,0 +1,60 @@
+ #
+ # Copyright RODMENA LIMITED 2025
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ """
+ Ragit configuration management.
+
+ Loads configuration from environment variables and .env files.
+
+ Note: As of v0.8.0, ragit no longer has default LLM or embedding models.
+ Users must explicitly configure providers.
+ """
+
+ import os
+ from pathlib import Path
+
+ from dotenv import load_dotenv
+
+ # Load .env file from current working directory or project root
+ _env_path = Path.cwd() / ".env"
+ if _env_path.exists():
+     load_dotenv(_env_path)
+ else:
+     # Try to find .env in parent directories
+     for parent in Path.cwd().parents:
+         _env_path = parent / ".env"
+         if _env_path.exists():
+             load_dotenv(_env_path)
+             break
+
+
+ class Config:
+     """Ragit configuration loaded from environment variables.
+
+     Note: As of v0.8.0, DEFAULT_LLM_MODEL and DEFAULT_EMBEDDING_MODEL are
+     no longer used as defaults. They are only read from environment variables
+     for backwards compatibility with user configurations.
+     """
+
+     # Ollama LLM API Configuration (used when explicitly using OllamaProvider)
+     OLLAMA_BASE_URL: str = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
+     OLLAMA_API_KEY: str | None = os.getenv("OLLAMA_API_KEY")
+     OLLAMA_TIMEOUT: int = int(os.getenv("OLLAMA_TIMEOUT", "120"))
+
+     # Ollama Embedding API Configuration
+     OLLAMA_EMBEDDING_URL: str = os.getenv(
+         "OLLAMA_EMBEDDING_URL", os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
+     )
+
+     # Model settings (only used if explicitly requested, no defaults)
+     # These can still be set via environment variables for convenience
+     DEFAULT_LLM_MODEL: str | None = os.getenv("RAGIT_DEFAULT_LLM_MODEL")
+     DEFAULT_EMBEDDING_MODEL: str | None = os.getenv("RAGIT_DEFAULT_EMBEDDING_MODEL")
+
+     # Logging
+     LOG_LEVEL: str = os.getenv("RAGIT_LOG_LEVEL", "INFO")
+
+
+ # Singleton instance
+ config = Config()
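Because `Config` reads `os.getenv` at class-definition time, overrides must be in place before `ragit.config` is first imported, either via a `.env` file or the process environment. A small sketch with made-up values:

```python
# Hypothetical override; "gpu-box" is a placeholder hostname.
import os

os.environ["OLLAMA_BASE_URL"] = "http://gpu-box:11434"
os.environ["OLLAMA_TIMEOUT"] = "300"

from ragit.config import config  # first import snapshots the environment

assert config.OLLAMA_BASE_URL == "http://gpu-box:11434"
assert config.OLLAMA_TIMEOUT == 300  # parsed to int by Config
assert config.OLLAMA_EMBEDDING_URL == "http://gpu-box:11434"  # falls back to OLLAMA_BASE_URL
```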
ragit/core/__init__.py ADDED
@@ -0,0 +1,5 @@
+ #
+ # Copyright RODMENA LIMITED 2025
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ """Ragit core module."""
ragit/core/experiment/__init__.py ADDED
@@ -0,0 +1,22 @@
+ #
+ # Copyright RODMENA LIMITED 2025
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ """Ragit experiment module."""
+
+ from ragit.core.experiment.experiment import (
+     BenchmarkQuestion,
+     Document,
+     RAGConfig,
+     RagitExperiment,
+ )
+ from ragit.core.experiment.results import EvaluationResult, ExperimentResults
+
+ __all__ = [
+     "RagitExperiment",
+     "Document",
+     "BenchmarkQuestion",
+     "RAGConfig",
+     "EvaluationResult",
+     "ExperimentResults",
+ ]
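These re-exports mirror the top-level package, so the experiment API is reachable either way. A minimal sketch of the optimization flow, following the Quick Start in ragit/__init__.py (documents, benchmark, and provider choice are placeholders):

```python
# Placeholder data; mirrors the Optimization example in ragit/__init__.py.
from ragit import BenchmarkQuestion, Document, RagitExperiment
from ragit.providers import OllamaProvider

docs = [Document(id="doc1", content="X is a widget that frobnicates inputs.")]
benchmark = [BenchmarkQuestion(question="What is X?", ground_truth="a widget")]

experiment = RagitExperiment(docs, benchmark, provider=OllamaProvider())
results = experiment.run()
print(results[0])  # best-scoring configuration
```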