ragit 0.7-py3-none-any.whl → 0.7.2-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.
ragit/loaders.py ADDED
@@ -0,0 +1,219 @@
+ #
+ # Copyright RODMENA LIMITED 2025
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ """
+ Document loading and chunking utilities.
+
+ Provides simple functions to load documents from files and chunk text.
+ """
+
+ import re
+ from pathlib import Path
+
+ from ragit.core.experiment.experiment import Chunk, Document
+
+
+ def load_text(path: str | Path) -> Document:
+     """
+     Load a single text file as a Document.
+
+     Parameters
+     ----------
+     path : str or Path
+         Path to the text file (.txt, .md, .rst, etc.)
+
+     Returns
+     -------
+     Document
+         Document with file content and metadata.
+
+     Examples
+     --------
+     >>> doc = load_text("docs/tutorial.rst")
+     >>> print(doc.id, len(doc.content))
+     """
+     path = Path(path)
+     content = path.read_text(encoding="utf-8")
+     return Document(id=path.stem, content=content, metadata={"source": str(path), "filename": path.name})
+
+
+ def load_directory(path: str | Path, pattern: str = "*.txt", recursive: bool = False) -> list[Document]:
+     """
+     Load all matching files from a directory as Documents.
+
+     Parameters
+     ----------
+     path : str or Path
+         Directory path.
+     pattern : str
+         Glob pattern for files (default: "*.txt").
+     recursive : bool
+         If True, search recursively (default: False).
+
+     Returns
+     -------
+     list[Document]
+         List of loaded documents.
+
+     Examples
+     --------
+     >>> docs = load_directory("docs/", "*.rst")
+     >>> docs = load_directory("docs/", "**/*.md", recursive=True)
+     """
+     path = Path(path)
+     glob_method = path.rglob if recursive else path.glob
+     documents = []
+
+     for file_path in sorted(glob_method(pattern)):
+         if file_path.is_file():
+             documents.append(load_text(file_path))
+
+     return documents
+
+
+ def chunk_text(text: str, chunk_size: int = 512, chunk_overlap: int = 50, doc_id: str = "doc") -> list[Chunk]:
+     """
+     Split text into overlapping chunks.
+
+     Parameters
+     ----------
+     text : str
+         Text to chunk.
+     chunk_size : int
+         Maximum characters per chunk (default: 512).
+     chunk_overlap : int
+         Overlap between chunks (default: 50).
+     doc_id : str
+         Document ID for the chunks (default: "doc").
+
+     Returns
+     -------
+     list[Chunk]
+         List of text chunks.
+
+     Examples
+     --------
+     >>> chunks = chunk_text("Long document...", chunk_size=256, chunk_overlap=50)
+     """
+     if chunk_overlap >= chunk_size:
+         raise ValueError("chunk_overlap must be less than chunk_size")
+
+     chunks = []
+     start = 0
+     chunk_idx = 0
+
+     while start < len(text):
+         end = start + chunk_size
+         chunk_text = text[start:end].strip()
+
+         if chunk_text:
+             chunks.append(Chunk(content=chunk_text, doc_id=doc_id, chunk_index=chunk_idx))
+             chunk_idx += 1
+
+         start = end - chunk_overlap
+         if start >= len(text) - chunk_overlap:
+             break
+
+     return chunks
+
+
+ def chunk_document(doc: Document, chunk_size: int = 512, chunk_overlap: int = 50) -> list[Chunk]:
+     """
+     Split a Document into overlapping chunks.
+
+     Parameters
+     ----------
+     doc : Document
+         Document to chunk.
+     chunk_size : int
+         Maximum characters per chunk.
+     chunk_overlap : int
+         Overlap between chunks.
+
+     Returns
+     -------
+     list[Chunk]
+         List of chunks from the document.
+     """
+     return chunk_text(doc.content, chunk_size, chunk_overlap, doc.id)
+
+
+ def chunk_by_separator(text: str, separator: str = "\n\n", doc_id: str = "doc") -> list[Chunk]:
+     """
+     Split text by a separator (e.g., paragraphs, sections).
+
+     Parameters
+     ----------
+     text : str
+         Text to split.
+     separator : str
+         Separator string (default: double newline for paragraphs).
+     doc_id : str
+         Document ID for the chunks.
+
+     Returns
+     -------
+     list[Chunk]
+         List of chunks.
+
+     Examples
+     --------
+     >>> chunks = chunk_by_separator(text, separator="\\n---\\n")
+     """
+     parts = text.split(separator)
+     chunks = []
+
+     for idx, part in enumerate(parts):
+         content = part.strip()
+         if content:
+             chunks.append(Chunk(content=content, doc_id=doc_id, chunk_index=idx))
+
+     return chunks
+
+
+ def chunk_rst_sections(text: str, doc_id: str = "doc") -> list[Chunk]:
+     """
+     Split RST document by section headers.
+
+     Parameters
+     ----------
+     text : str
+         RST document text.
+     doc_id : str
+         Document ID for the chunks.
+
+     Returns
+     -------
+     list[Chunk]
+         List of section chunks.
+     """
+     # Match RST section headers (title followed by underline of =, -, ~, etc.)
+     pattern = r"\n([^\n]+)\n([=\-~`\'\"^_*+#]+)\n"
+
+     # Find all section positions
+     matches = list(re.finditer(pattern, text))
+
+     if not matches:
+         # No sections found, return whole text as one chunk
+         return [Chunk(content=text.strip(), doc_id=doc_id, chunk_index=0)] if text.strip() else []
+
+     chunks = []
+
+     # Handle content before first section
+     first_pos = matches[0].start()
+     if first_pos > 0:
+         pre_content = text[:first_pos].strip()
+         if pre_content:
+             chunks.append(Chunk(content=pre_content, doc_id=doc_id, chunk_index=0))
+
+     # Extract each section
+     for i, match in enumerate(matches):
+         start = match.start()
+         end = matches[i + 1].start() if i + 1 < len(matches) else len(text)
+
+         section_content = text[start:end].strip()
+         if section_content:
+             chunks.append(Chunk(content=section_content, doc_id=doc_id, chunk_index=len(chunks)))
+
+     return chunks
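
For orientation, a minimal usage sketch of the new loaders module (not part of the diff; it uses only the functions added above, and the docs/ path and chunk sizes are illustrative):

    # Illustrative only: load a folder of .rst files and chunk each document.
    from ragit.loaders import load_directory, chunk_document

    docs = load_directory("docs/", "*.rst")  # one Document per file, sorted by path
    chunks = []
    for doc in docs:
        chunks.extend(chunk_document(doc, chunk_size=512, chunk_overlap=50))
    print(f"{len(docs)} documents -> {len(chunks)} chunks")
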
ragit/providers/__init__.py ADDED
@@ -0,0 +1,20 @@
+ #
+ # Copyright RODMENA LIMITED 2025
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ """
+ Ragit Providers - LLM and Embedding providers for RAG optimization.
+
+ Supported providers:
+ - Ollama (local)
+ - Future: Gemini, Claude, OpenAI
+ """
+
+ from ragit.providers.base import BaseEmbeddingProvider, BaseLLMProvider
+ from ragit.providers.ollama import OllamaProvider
+
+ __all__ = [
+     "BaseLLMProvider",
+     "BaseEmbeddingProvider",
+     "OllamaProvider",
+ ]
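
A quick sketch of the public surface declared in __all__ above (nothing beyond the listed exports is assumed):

    # Illustrative only: the package re-exports the base interfaces and the Ollama provider.
    from ragit.providers import BaseEmbeddingProvider, BaseLLMProvider, OllamaProvider

    provider = OllamaProvider()  # reads OLLAMA_* settings via ragit.config
    assert isinstance(provider, BaseLLMProvider)
    assert isinstance(provider, BaseEmbeddingProvider)
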
ragit/providers/base.py ADDED
@@ -0,0 +1,147 @@
+ #
+ # Copyright RODMENA LIMITED 2025
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ """
+ Base provider interfaces for LLM and Embedding providers.
+
+ These abstract classes define the interface that all providers must implement,
+ making it easy to add new providers (Gemini, Claude, OpenAI, etc.)
+ """
+
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class LLMResponse:
+     """Response from an LLM call."""
+
+     text: str
+     model: str
+     provider: str
+     usage: dict[str, int] | None = None
+
+
+ @dataclass(frozen=True)
+ class EmbeddingResponse:
+     """Response from an embedding call (immutable)."""
+
+     embedding: tuple[float, ...]
+     model: str
+     provider: str
+     dimensions: int
+
+
+ class BaseLLMProvider(ABC):
+     """
+     Abstract base class for LLM providers.
+
+     Implement this to add support for new LLM providers like Gemini, Claude, etc.
+     """
+
+     @property
+     @abstractmethod
+     def provider_name(self) -> str:
+         """Return the provider name (e.g., 'ollama', 'gemini', 'claude')."""
+         pass
+
+     @abstractmethod
+     def generate(
+         self,
+         prompt: str,
+         model: str,
+         system_prompt: str | None = None,
+         temperature: float = 0.7,
+         max_tokens: int | None = None,
+     ) -> LLMResponse:
+         """
+         Generate text from the LLM.
+
+         Parameters
+         ----------
+         prompt : str
+             The user prompt/query.
+         model : str
+             Model identifier (e.g., 'llama3', 'qwen3-vl:235b-instruct-cloud').
+         system_prompt : str, optional
+             System prompt for context/instructions.
+         temperature : float
+             Sampling temperature (0.0 to 1.0).
+         max_tokens : int, optional
+             Maximum tokens to generate.
+
+         Returns
+         -------
+         LLMResponse
+             The generated response.
+         """
+         pass
+
+     @abstractmethod
+     def is_available(self) -> bool:
+         """Check if the provider is available and configured."""
+         pass
+
+
+ class BaseEmbeddingProvider(ABC):
+     """
+     Abstract base class for embedding providers.
+
+     Implement this to add support for new embedding providers.
+     """
+
+     @property
+     @abstractmethod
+     def provider_name(self) -> str:
+         """Return the provider name."""
+         pass
+
+     @property
+     @abstractmethod
+     def dimensions(self) -> int:
+         """Return the embedding dimensions for the current model."""
+         pass
+
+     @abstractmethod
+     def embed(self, text: str, model: str) -> EmbeddingResponse:
+         """
+         Generate embedding for text.
+
+         Parameters
+         ----------
+         text : str
+             Text to embed.
+         model : str
+             Model identifier (e.g., 'nomic-embed-text').
+
+         Returns
+         -------
+         EmbeddingResponse
+             The embedding response.
+         """
+         pass
+
+     @abstractmethod
+     def embed_batch(self, texts: list[str], model: str) -> list[EmbeddingResponse]:
+         """
+         Generate embeddings for multiple texts.
+
+         Parameters
+         ----------
+         texts : list[str]
+             Texts to embed.
+         model : str
+             Model identifier.
+
+         Returns
+         -------
+         list[EmbeddingResponse]
+             List of embedding responses.
+         """
+         pass
+
+     @abstractmethod
+     def is_available(self) -> bool:
+         """Check if the provider is available and configured."""
+         pass
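
These ABCs are the extension point for the future providers mentioned in the module docstring. A hypothetical minimal subclass (an EchoProvider for tests; not part of the package) only needs the three abstract members of BaseLLMProvider:

    # Hypothetical example, not shipped in ragit: a trivial LLM provider for tests.
    from ragit.providers.base import BaseLLMProvider, LLMResponse

    class EchoProvider(BaseLLMProvider):
        """Toy provider that returns the prompt unchanged."""

        @property
        def provider_name(self) -> str:
            return "echo"

        def generate(self, prompt, model, system_prompt=None, temperature=0.7, max_tokens=None) -> LLMResponse:
            return LLMResponse(text=prompt, model=model, provider=self.provider_name)

        def is_available(self) -> bool:
            return True
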
ragit/providers/ollama.py ADDED
@@ -0,0 +1,284 @@
+ #
+ # Copyright RODMENA LIMITED 2025
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ """
+ Ollama provider for LLM and Embedding operations.
+
+ This provider connects to a local or remote Ollama server.
+ Configuration is loaded from environment variables.
+ """
+
+ import requests
+
+ from ragit.config import config
+ from ragit.providers.base import (
+     BaseEmbeddingProvider,
+     BaseLLMProvider,
+     EmbeddingResponse,
+     LLMResponse,
+ )
+
+
+ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
+     """
+     Ollama provider for both LLM and Embedding operations.
+
+     Parameters
+     ----------
+     base_url : str, optional
+         Ollama server URL (default: from OLLAMA_BASE_URL env var)
+     api_key : str, optional
+         API key for authentication (default: from OLLAMA_API_KEY env var)
+     timeout : int, optional
+         Request timeout in seconds (default: from OLLAMA_TIMEOUT env var)
+
+     Examples
+     --------
+     >>> provider = OllamaProvider()
+     >>> response = provider.generate("What is RAG?", model="llama3")
+     >>> print(response.text)
+
+     >>> embedding = provider.embed("Hello world", model="nomic-embed-text")
+     >>> print(len(embedding.embedding))
+     """
+
+     # Known embedding model dimensions
+     EMBEDDING_DIMENSIONS = {
+         "nomic-embed-text": 768,
+         "nomic-embed-text:latest": 768,
+         "mxbai-embed-large": 1024,
+         "all-minilm": 384,
+         "snowflake-arctic-embed": 1024,
+         "qwen3-embedding": 4096,
+         "qwen3-embedding:0.6b": 1024,
+         "qwen3-embedding:4b": 2560,
+         "qwen3-embedding:8b": 4096,
+     }
+
+     def __init__(
+         self,
+         base_url: str | None = None,
+         embedding_url: str | None = None,
+         api_key: str | None = None,
+         timeout: int | None = None,
+     ) -> None:
+         self.base_url = (base_url or config.OLLAMA_BASE_URL).rstrip("/")
+         self.embedding_url = (embedding_url or config.OLLAMA_EMBEDDING_URL).rstrip("/")
+         self.api_key = api_key or config.OLLAMA_API_KEY
+         self.timeout = timeout or config.OLLAMA_TIMEOUT
+         self._current_embed_model: str | None = None
+         self._current_dimensions: int = 768  # default
+
+     def _get_headers(self, include_auth: bool = True) -> dict[str, str]:
+         """Get request headers including authentication if API key is set."""
+         headers = {"Content-Type": "application/json"}
+         if include_auth and self.api_key:
+             headers["Authorization"] = f"Bearer {self.api_key}"
+         return headers
+
+     @property
+     def provider_name(self) -> str:
+         return "ollama"
+
+     @property
+     def dimensions(self) -> int:
+         return self._current_dimensions
+
+     def is_available(self) -> bool:
+         """Check if Ollama server is reachable."""
+         try:
+             response = requests.get(
+                 f"{self.base_url}/api/tags",
+                 headers=self._get_headers(),
+                 timeout=5,
+             )
+             return response.status_code == 200
+         except requests.RequestException:
+             return False
+
+     def list_models(self) -> list[dict[str, str]]:
+         """List available models on the Ollama server."""
+         try:
+             response = requests.get(
+                 f"{self.base_url}/api/tags",
+                 headers=self._get_headers(),
+                 timeout=10,
+             )
+             response.raise_for_status()
+             data = response.json()
+             return list(data.get("models", []))
+         except requests.RequestException as e:
+             raise ConnectionError(f"Failed to list Ollama models: {e}") from e
+
+     def generate(
+         self,
+         prompt: str,
+         model: str,
+         system_prompt: str | None = None,
+         temperature: float = 0.7,
+         max_tokens: int | None = None,
+     ) -> LLMResponse:
+         """Generate text using Ollama."""
+         options: dict[str, float | int] = {"temperature": temperature}
+         if max_tokens:
+             options["num_predict"] = max_tokens
+
+         payload: dict[str, str | bool | dict[str, float | int]] = {
+             "model": model,
+             "prompt": prompt,
+             "stream": False,
+             "options": options,
+         }
+
+         if system_prompt:
+             payload["system"] = system_prompt
+
+         try:
+             response = requests.post(
+                 f"{self.base_url}/api/generate",
+                 headers=self._get_headers(),
+                 json=payload,
+                 timeout=self.timeout,
+             )
+             response.raise_for_status()
+             data = response.json()
+
+             return LLMResponse(
+                 text=data.get("response", ""),
+                 model=model,
+                 provider=self.provider_name,
+                 usage={
+                     "prompt_tokens": data.get("prompt_eval_count"),
+                     "completion_tokens": data.get("eval_count"),
+                     "total_duration": data.get("total_duration"),
+                 },
+             )
+         except requests.RequestException as e:
+             raise ConnectionError(f"Ollama generate failed: {e}") from e
+
+     def embed(self, text: str, model: str) -> EmbeddingResponse:
+         """Generate embedding using Ollama (uses embedding_url, no auth for local)."""
+         self._current_embed_model = model
+         self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
+
+         try:
+             response = requests.post(
+                 f"{self.embedding_url}/api/embeddings",
+                 headers=self._get_headers(include_auth=False),
+                 json={"model": model, "prompt": text},
+                 timeout=self.timeout,
+             )
+             response.raise_for_status()
+             data = response.json()
+
+             embedding = data.get("embedding", [])
+             if not embedding:
+                 raise ValueError("Empty embedding returned from Ollama")
+
+             # Update dimensions from actual response
+             self._current_dimensions = len(embedding)
+
+             return EmbeddingResponse(
+                 embedding=tuple(embedding),
+                 model=model,
+                 provider=self.provider_name,
+                 dimensions=len(embedding),
+             )
+         except requests.RequestException as e:
+             raise ConnectionError(f"Ollama embed failed: {e}") from e
+
+     def embed_batch(self, texts: list[str], model: str) -> list[EmbeddingResponse]:
+         """Generate embeddings for multiple texts (uses embedding_url, no auth for local).
+
+         Note: Ollama /api/embeddings only supports single prompts, so we loop.
+         """
+         self._current_embed_model = model
+         self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
+
+         results = []
+         try:
+             for text in texts:
+                 response = requests.post(
+                     f"{self.embedding_url}/api/embeddings",
+                     headers=self._get_headers(include_auth=False),
+                     json={"model": model, "prompt": text},
+                     timeout=self.timeout,
+                 )
+                 response.raise_for_status()
+                 data = response.json()
+
+                 embedding = data.get("embedding", [])
+                 if embedding:
+                     self._current_dimensions = len(embedding)
+
+                 results.append(
+                     EmbeddingResponse(
+                         embedding=tuple(embedding),
+                         model=model,
+                         provider=self.provider_name,
+                         dimensions=len(embedding),
+                     )
+                 )
+             return results
+         except requests.RequestException as e:
+             raise ConnectionError(f"Ollama batch embed failed: {e}") from e
+
+     def chat(
+         self,
+         messages: list[dict[str, str]],
+         model: str,
+         temperature: float = 0.7,
+         max_tokens: int | None = None,
+     ) -> LLMResponse:
+         """
+         Chat completion using Ollama.
+
+         Parameters
+         ----------
+         messages : list[dict]
+             List of messages with 'role' and 'content' keys.
+         model : str
+             Model identifier.
+         temperature : float
+             Sampling temperature.
+         max_tokens : int, optional
+             Maximum tokens to generate.
+
+         Returns
+         -------
+         LLMResponse
+             The generated response.
+         """
+         options: dict[str, float | int] = {"temperature": temperature}
+         if max_tokens:
+             options["num_predict"] = max_tokens
+
+         payload: dict[str, str | bool | list[dict[str, str]] | dict[str, float | int]] = {
+             "model": model,
+             "messages": messages,
+             "stream": False,
+             "options": options,
+         }
+
+         try:
+             response = requests.post(
+                 f"{self.base_url}/api/chat",
+                 headers=self._get_headers(),
+                 json=payload,
+                 timeout=self.timeout,
+             )
+             response.raise_for_status()
+             data = response.json()
+
+             return LLMResponse(
+                 text=data.get("message", {}).get("content", ""),
+                 model=model,
+                 provider=self.provider_name,
+                 usage={
+                     "prompt_tokens": data.get("prompt_eval_count"),
+                     "completion_tokens": data.get("eval_count"),
+                 },
+             )
+         except requests.RequestException as e:
+             raise ConnectionError(f"Ollama chat failed: {e}") from e