ragit-0.8-py3-none-any.whl → ragit-0.8.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,446 @@
+ #
+ # Copyright RODMENA LIMITED 2025
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ """
+ Ollama provider for LLM and Embedding operations.
+
+ This provider connects to a local or remote Ollama server.
+ Configuration is loaded from environment variables.
+
+ Performance optimizations:
+ - Connection pooling via requests.Session()
+ - Async batch embedding via httpx (single request, trio-compatible)
+ - LRU cache for repeated embedding queries
+ """
+
+ from functools import lru_cache
+ from typing import Any
+
+ import httpx
+ import requests
+
+ from ragit.config import config
+ from ragit.providers.base import (
+     BaseEmbeddingProvider,
+     BaseLLMProvider,
+     EmbeddingResponse,
+     LLMResponse,
+ )
+
+
+ # Module-level cache for embeddings (shared across instances)
+ @lru_cache(maxsize=2048)
+ def _cached_embedding(text: str, model: str, embedding_url: str, timeout: int) -> tuple[float, ...]:
+     """Cache embedding results to avoid redundant API calls."""
+     # Truncate oversized inputs
+     if len(text) > OllamaProvider.MAX_EMBED_CHARS:
+         text = text[: OllamaProvider.MAX_EMBED_CHARS]
+
+     response = requests.post(
+         f"{embedding_url}/api/embed",
+         headers={"Content-Type": "application/json"},
+         json={"model": model, "input": text},
+         timeout=timeout,
+     )
+     response.raise_for_status()
+     data = response.json()
+     embeddings = data.get("embeddings", [])
+     if not embeddings or not embeddings[0]:
+         raise ValueError("Empty embedding returned from Ollama")
+     return tuple(embeddings[0])
+
+
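Because _cached_embedding is wrapped in functools.lru_cache, the cache key is the full argument tuple (text, model, embedding_url, timeout); only a byte-identical repeat call is served from memory, and changing any argument, including the timeout, is a cache miss. A minimal sketch of that behavior, assuming a reachable embedding server at Ollama's conventional default URL (the URL and model name are illustrative, not fixed by the package):

# Illustrative only -- not part of the packaged module.
# Identical (text, model, embedding_url, timeout) arguments hit the LRU cache.
vec_a = _cached_embedding("hello world", "nomic-embed-text", "http://localhost:11434", 30)
vec_b = _cached_embedding("hello world", "nomic-embed-text", "http://localhost:11434", 30)
assert vec_a is vec_b                   # second call served from cache, no HTTP request
print(_cached_embedding.cache_info())   # e.g. CacheInfo(hits=1, misses=1, maxsize=2048, currsize=1)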
+ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
+     """
+     Ollama provider for both LLM and Embedding operations.
+
+     Performance features:
+     - Connection pooling via requests.Session() for faster sequential requests
+     - Native batch embedding via /api/embed endpoint (single API call)
+     - LRU cache for repeated embedding queries (2048 entries)
+
+     Parameters
+     ----------
+     base_url : str, optional
+         Ollama server URL (default: from OLLAMA_BASE_URL env var)
+     embedding_url : str, optional
+         Embedding server URL (default: from OLLAMA_EMBEDDING_URL env var)
+     api_key : str, optional
+         API key for authentication (default: from OLLAMA_API_KEY env var)
+     timeout : int, optional
+         Request timeout in seconds (default: from OLLAMA_TIMEOUT env var)
+     use_cache : bool, optional
+         Enable embedding cache (default: True)
+
+     Examples
+     --------
+     >>> provider = OllamaProvider()
+     >>> response = provider.generate("What is RAG?", model="llama3")
+     >>> print(response.text)
+
+     >>> # Batch embedding (single API call)
+     >>> embeddings = provider.embed_batch(texts, "mxbai-embed-large")
+     """
+
+     # Known embedding model dimensions
+     EMBEDDING_DIMENSIONS: dict[str, int] = {
+         "nomic-embed-text": 768,
+         "nomic-embed-text:latest": 768,
+         "mxbai-embed-large": 1024,
+         "all-minilm": 384,
+         "snowflake-arctic-embed": 1024,
+         "qwen3-embedding": 4096,
+         "qwen3-embedding:0.6b": 1024,
+         "qwen3-embedding:4b": 2560,
+         "qwen3-embedding:8b": 4096,
+     }
+
+     # Max characters per embedding request (safe limit for 512 token models)
+     MAX_EMBED_CHARS = 2000
+
+     def __init__(
+         self,
+         base_url: str | None = None,
+         embedding_url: str | None = None,
+         api_key: str | None = None,
+         timeout: int | None = None,
+         use_cache: bool = True,
+     ) -> None:
+         self.base_url = (base_url or config.OLLAMA_BASE_URL).rstrip("/")
+         self.embedding_url = (embedding_url or config.OLLAMA_EMBEDDING_URL).rstrip("/")
+         self.api_key = api_key or config.OLLAMA_API_KEY
+         self.timeout = timeout or config.OLLAMA_TIMEOUT
+         self.use_cache = use_cache
+         self._current_embed_model: str | None = None
+         self._current_dimensions: int = 768  # default
+
+         # Connection pooling via session
+         self._session: requests.Session | None = None
+
+     @property
+     def session(self) -> requests.Session:
+         """Lazy-initialized session for connection pooling."""
+         if self._session is None:
+             self._session = requests.Session()
+             self._session.headers.update({"Content-Type": "application/json"})
+             if self.api_key:
+                 self._session.headers.update({"Authorization": f"Bearer {self.api_key}"})
+         return self._session
+
+     def close(self) -> None:
+         """Close the session and release resources."""
+         if self._session is not None:
+             self._session.close()
+             self._session = None
+
+     def __del__(self) -> None:
+         """Cleanup on garbage collection."""
+         self.close()
+
+     def _get_headers(self, include_auth: bool = True) -> dict[str, str]:
+         """Get request headers including authentication if API key is set."""
+         headers = {"Content-Type": "application/json"}
+         if include_auth and self.api_key:
+             headers["Authorization"] = f"Bearer {self.api_key}"
+         return headers
+
+     @property
+     def provider_name(self) -> str:
+         return "ollama"
+
+     @property
+     def dimensions(self) -> int:
+         return self._current_dimensions
+
+     def is_available(self) -> bool:
+         """Check if Ollama server is reachable."""
+         try:
+             response = self.session.get(
+                 f"{self.base_url}/api/tags",
+                 timeout=5,
+             )
+             return bool(response.status_code == 200)
+         except requests.RequestException:
+             return False
+
+     def list_models(self) -> list[dict[str, Any]]:
+         """List available models on the Ollama server."""
+         try:
+             response = self.session.get(
+                 f"{self.base_url}/api/tags",
+                 timeout=10,
+             )
+             response.raise_for_status()
+             data = response.json()
+             return list(data.get("models", []))
+         except requests.RequestException as e:
+             raise ConnectionError(f"Failed to list Ollama models: {e}") from e
+
+     def generate(
+         self,
+         prompt: str,
+         model: str,
+         system_prompt: str | None = None,
+         temperature: float = 0.7,
+         max_tokens: int | None = None,
+     ) -> LLMResponse:
+         """Generate text using Ollama."""
+         options: dict[str, float | int] = {"temperature": temperature}
+         if max_tokens:
+             options["num_predict"] = max_tokens
+
+         payload: dict[str, str | bool | dict[str, float | int]] = {
+             "model": model,
+             "prompt": prompt,
+             "stream": False,
+             "options": options,
+         }
+
+         if system_prompt:
+             payload["system"] = system_prompt
+
+         try:
+             response = self.session.post(
+                 f"{self.base_url}/api/generate",
+                 json=payload,
+                 timeout=self.timeout,
+             )
+             response.raise_for_status()
+             data = response.json()
+
+             return LLMResponse(
+                 text=data.get("response", ""),
+                 model=model,
+                 provider=self.provider_name,
+                 usage={
+                     "prompt_tokens": data.get("prompt_eval_count"),
+                     "completion_tokens": data.get("eval_count"),
+                     "total_duration": data.get("total_duration"),
+                 },
+             )
+         except requests.RequestException as e:
+             raise ConnectionError(f"Ollama generate failed: {e}") from e
+
+     def embed(self, text: str, model: str) -> EmbeddingResponse:
+         """Generate embedding using Ollama with optional caching."""
+         self._current_embed_model = model
+         self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
+
+         try:
+             if self.use_cache:
+                 # Use cached version
+                 embedding = _cached_embedding(text, model, self.embedding_url, self.timeout)
+             else:
+                 # Direct call without cache
+                 truncated = text[: self.MAX_EMBED_CHARS] if len(text) > self.MAX_EMBED_CHARS else text
+                 response = self.session.post(
+                     f"{self.embedding_url}/api/embed",
+                     json={"model": model, "input": truncated},
+                     timeout=self.timeout,
+                 )
+                 response.raise_for_status()
+                 data = response.json()
+                 embeddings = data.get("embeddings", [])
+                 if not embeddings or not embeddings[0]:
+                     raise ValueError("Empty embedding returned from Ollama")
+                 embedding = tuple(embeddings[0])
+
+             # Update dimensions from actual response
+             self._current_dimensions = len(embedding)
+
+             return EmbeddingResponse(
+                 embedding=embedding,
+                 model=model,
+                 provider=self.provider_name,
+                 dimensions=len(embedding),
+             )
+         except requests.RequestException as e:
+             raise ConnectionError(f"Ollama embed failed: {e}") from e
+
+     def embed_batch(self, texts: list[str], model: str) -> list[EmbeddingResponse]:
+         """Generate embeddings for multiple texts in a single API call.
+
+         The /api/embed endpoint supports batch inputs natively.
+         """
+         self._current_embed_model = model
+         self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
+
+         # Truncate oversized inputs
+         truncated_texts = [text[: self.MAX_EMBED_CHARS] if len(text) > self.MAX_EMBED_CHARS else text for text in texts]
+
+         try:
+             response = self.session.post(
+                 f"{self.embedding_url}/api/embed",
+                 json={"model": model, "input": truncated_texts},
+                 timeout=self.timeout,
+             )
+             response.raise_for_status()
+             data = response.json()
+             embeddings_list = data.get("embeddings", [])
+
+             if not embeddings_list:
+                 raise ValueError("Empty embeddings returned from Ollama")
+
+             results = []
+             for embedding_data in embeddings_list:
+                 embedding = tuple(embedding_data) if embedding_data else ()
+                 if embedding:
+                     self._current_dimensions = len(embedding)
+
+                 results.append(
+                     EmbeddingResponse(
+                         embedding=embedding,
+                         model=model,
+                         provider=self.provider_name,
+                         dimensions=len(embedding),
+                     )
+                 )
+             return results
+         except requests.RequestException as e:
+             raise ConnectionError(f"Ollama batch embed failed: {e}") from e
+
+     async def embed_batch_async(
+         self,
+         texts: list[str],
+         model: str,
+         max_concurrent: int = 10,  # kept for API compatibility, no longer used
+     ) -> list[EmbeddingResponse]:
+         """Generate embeddings for multiple texts asynchronously.
+
+         The /api/embed endpoint supports batch inputs natively, so this
+         makes a single async HTTP request for all texts.
+
+         Parameters
+         ----------
+         texts : list[str]
+             Texts to embed.
+         model : str
+             Embedding model name.
+         max_concurrent : int
+             Deprecated, kept for API compatibility. No longer used since
+             the API now supports native batching.
+
+         Returns
+         -------
+         list[EmbeddingResponse]
+             Embeddings in the same order as input texts.
+
+         Examples
+         --------
+         >>> import trio
+         >>> embeddings = trio.run(provider.embed_batch_async, texts, "mxbai-embed-large")
+         """
+         self._current_embed_model = model
+         self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
+
+         # Truncate oversized inputs
+         truncated_texts = [text[: self.MAX_EMBED_CHARS] if len(text) > self.MAX_EMBED_CHARS else text for text in texts]
+
+         try:
+             async with httpx.AsyncClient() as client:
+                 response = await client.post(
+                     f"{self.embedding_url}/api/embed",
+                     json={"model": model, "input": truncated_texts},
+                     timeout=self.timeout,
+                 )
+                 response.raise_for_status()
+                 data = response.json()
+
+                 embeddings_list = data.get("embeddings", [])
+                 if not embeddings_list:
+                     raise ValueError("Empty embeddings returned from Ollama")
+
+                 results = []
+                 for embedding_data in embeddings_list:
+                     embedding = tuple(embedding_data) if embedding_data else ()
+                     if embedding:
+                         self._current_dimensions = len(embedding)
+
+                     results.append(
+                         EmbeddingResponse(
+                             embedding=embedding,
+                             model=model,
+                             provider=self.provider_name,
+                             dimensions=len(embedding),
+                         )
+                     )
+                 return results
+         except httpx.HTTPError as e:
+             raise ConnectionError(f"Ollama async batch embed failed: {e}") from e
+
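The trio one-liner in the docstring above expands to roughly the following sketch; it assumes a reachable Ollama server with the mxbai-embed-large model already pulled (both assumptions, not guaranteed by the package):

# Illustrative only -- not part of the packaged module.
import trio

provider = OllamaProvider()  # defaults come from the OLLAMA_* environment variables
texts = ["first chunk", "second chunk", "third chunk"]

# One async POST to /api/embed covers the whole batch; results keep input order.
responses = trio.run(provider.embed_batch_async, texts, "mxbai-embed-large")
for text, resp in zip(texts, responses):
    print(text, "->", resp.dimensions, "dims")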
+     def chat(
+         self,
+         messages: list[dict[str, str]],
+         model: str,
+         temperature: float = 0.7,
+         max_tokens: int | None = None,
+     ) -> LLMResponse:
+         """
+         Chat completion using Ollama.
+
+         Parameters
+         ----------
+         messages : list[dict]
+             List of messages with 'role' and 'content' keys.
+         model : str
+             Model identifier.
+         temperature : float
+             Sampling temperature.
+         max_tokens : int, optional
+             Maximum tokens to generate.
+
+         Returns
+         -------
+         LLMResponse
+             The generated response.
+         """
+         options: dict[str, float | int] = {"temperature": temperature}
+         if max_tokens:
+             options["num_predict"] = max_tokens
+
+         payload: dict[str, str | bool | list[dict[str, str]] | dict[str, float | int]] = {
+             "model": model,
+             "messages": messages,
+             "stream": False,
+             "options": options,
+         }
+
+         try:
+             response = self.session.post(
+                 f"{self.base_url}/api/chat",
+                 json=payload,
+                 timeout=self.timeout,
+             )
+             response.raise_for_status()
+             data = response.json()
+
+             return LLMResponse(
+                 text=data.get("message", {}).get("content", ""),
+                 model=model,
+                 provider=self.provider_name,
+                 usage={
+                     "prompt_tokens": data.get("prompt_eval_count"),
+                     "completion_tokens": data.get("eval_count"),
+                 },
+             )
+         except requests.RequestException as e:
+             raise ConnectionError(f"Ollama chat failed: {e}") from e
+
+     @staticmethod
+     def clear_embedding_cache() -> None:
+         """Clear the embedding cache."""
+         _cached_embedding.cache_clear()
+
+     @staticmethod
+     def embedding_cache_info() -> dict[str, int]:
+         """Get embedding cache statistics."""
+         info = _cached_embedding.cache_info()
+         return {
+             "hits": info.hits,
+             "misses": info.misses,
+             "maxsize": info.maxsize or 0,
+             "currsize": info.currsize,
+         }
+
+
+ # Export the EMBEDDING_DIMENSIONS for external use
+ EMBEDDING_DIMENSIONS = OllamaProvider.EMBEDDING_DIMENSIONS
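Taken together, the class supports a plain synchronous workflow like the sketch below. It assumes a reachable Ollama server and that the named models are already pulled; the model names are the ones used in the docstrings, not requirements of the package.

# Illustrative usage of the module above -- not part of the diff.
provider = OllamaProvider()

if provider.is_available():
    reply = provider.generate("What is RAG?", model="llama3", temperature=0.2)
    print(reply.text)

    # Cached single embedding, then a native batch call.
    single = provider.embed("retrieval-augmented generation", "mxbai-embed-large")
    batch = provider.embed_batch(["doc one", "doc two"], "mxbai-embed-large")
    print(single.dimensions, len(batch))

    # Inspect and reset the shared LRU cache.
    print(OllamaProvider.embedding_cache_info())
    OllamaProvider.clear_embedding_cache()

provider.close()  # release the pooled requests.Session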
@@ -0,0 +1,225 @@
+ #
+ # Copyright RODMENA LIMITED 2025
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ """
+ SentenceTransformers provider for offline embedding.
+
+ This module provides embedding capabilities using the sentence-transformers
+ library, enabling fully offline RAG pipelines without API dependencies.
+
+ Requires: pip install ragit[transformers]
+ """
+
+ from typing import TYPE_CHECKING
+
+ from ragit.providers.base import (
+     BaseEmbeddingProvider,
+     EmbeddingResponse,
+ )
+
+ if TYPE_CHECKING:
+     from sentence_transformers import SentenceTransformer
+
+ # Lazy import flag
+ _sentence_transformers_available: bool | None = None
+ _model_cache: dict[str, "SentenceTransformer"] = {}
+
+
+ def _check_sentence_transformers() -> bool:
+     """Check if sentence-transformers is available."""
+     global _sentence_transformers_available
+     if _sentence_transformers_available is None:
+         try:
+             from sentence_transformers import SentenceTransformer  # noqa: F401
+
+             _sentence_transformers_available = True
+         except ImportError:
+             _sentence_transformers_available = False
+     return _sentence_transformers_available
+
+
+ def _get_model(model_name: str, device: str | None = None) -> "SentenceTransformer":
+     """Get or create a cached SentenceTransformer model."""
+     cache_key = f"{model_name}:{device or 'auto'}"
+     if cache_key not in _model_cache:
+         from sentence_transformers import SentenceTransformer
+
+         _model_cache[cache_key] = SentenceTransformer(model_name, device=device)
+     return _model_cache[cache_key]
+
+
+ class SentenceTransformersProvider(BaseEmbeddingProvider):
+     """
+     Embedding provider using sentence-transformers for offline operation.
+
+     This provider uses the sentence-transformers library to generate embeddings
+     locally without requiring any API calls. It's ideal for:
+     - Offline/air-gapped environments
+     - Development and testing
+     - Cost-sensitive applications
+     - Privacy-sensitive use cases
+
+     Parameters
+     ----------
+     model_name : str
+         HuggingFace model name. Default: "all-MiniLM-L6-v2" (fast, 384 dims).
+         Other popular options:
+         - "all-mpnet-base-v2" (768 dims, higher quality)
+         - "paraphrase-MiniLM-L6-v2" (384 dims)
+         - "multi-qa-MiniLM-L6-cos-v1" (384 dims, optimized for QA)
+     device : str, optional
+         Device to run on ("cpu", "cuda", "mps"). Auto-detected if None.
+
+     Examples
+     --------
+     >>> # Basic usage
+     >>> from ragit.providers import SentenceTransformersProvider
+     >>> provider = SentenceTransformersProvider()
+     >>>
+     >>> # With RAGAssistant (retrieval-only)
+     >>> assistant = RAGAssistant(docs, provider=provider)
+     >>> results = assistant.retrieve("query")
+     >>>
+     >>> # Custom model
+     >>> provider = SentenceTransformersProvider(model_name="all-mpnet-base-v2")
+
+     Raises
+     ------
+     ImportError
+         If sentence-transformers is not installed.
+
+     Note
+     ----
+     Install with: pip install ragit[transformers]
+     """
+
+     # Known model dimensions for common models
+     MODEL_DIMENSIONS: dict[str, int] = {
+         "all-MiniLM-L6-v2": 384,
+         "all-mpnet-base-v2": 768,
+         "paraphrase-MiniLM-L6-v2": 384,
+         "multi-qa-MiniLM-L6-cos-v1": 384,
+         "all-distilroberta-v1": 768,
+         "paraphrase-multilingual-MiniLM-L12-v2": 384,
+     }
+
+     def __init__(
+         self,
+         model_name: str = "all-MiniLM-L6-v2",
+         device: str | None = None,
+     ) -> None:
+         if not _check_sentence_transformers():
+             raise ImportError(
+                 "sentence-transformers is required for SentenceTransformersProvider. "
+                 "Install with: pip install ragit[transformers]"
+             )
+
+         self._model_name = model_name
+         self._device = device
+         self._model: SentenceTransformer | None = None  # Lazy loaded
+         self._dimensions: int | None = self.MODEL_DIMENSIONS.get(model_name)
+
+     def _ensure_model(self) -> "SentenceTransformer":
+         """Ensure model is loaded (lazy loading)."""
+         if self._model is None:
+             model = _get_model(self._model_name, self._device)
+             self._model = model
+             # Update dimensions from actual model
+             self._dimensions = model.get_sentence_embedding_dimension()
+         return self._model
+
+     @property
+     def provider_name(self) -> str:
+         return "sentence_transformers"
+
+     @property
+     def dimensions(self) -> int:
+         if self._dimensions is None:
+             # Load model to get dimensions
+             self._ensure_model()
+         return self._dimensions or 384  # Fallback
+
+     @property
+     def model_name(self) -> str:
+         """Return the model name being used."""
+         return self._model_name
+
+     def is_available(self) -> bool:
+         """Check if sentence-transformers is installed and model can be loaded."""
+         if not _check_sentence_transformers():
+             return False
+         try:
+             self._ensure_model()
+             return True
+         except Exception:
+             return False
+
+     def embed(self, text: str, model: str = "") -> EmbeddingResponse:
+         """
+         Generate embedding for text.
+
+         Parameters
+         ----------
+         text : str
+             Text to embed.
+         model : str
+             Model identifier (ignored, uses model from constructor).
+
+         Returns
+         -------
+         EmbeddingResponse
+             The embedding response.
+         """
+         model_instance = self._ensure_model()
+         embedding = model_instance.encode(text, convert_to_numpy=True)
+
+         # Convert to tuple
+         embedding_tuple = tuple(float(x) for x in embedding)
+
+         return EmbeddingResponse(
+             embedding=embedding_tuple,
+             model=self._model_name,
+             provider=self.provider_name,
+             dimensions=len(embedding_tuple),
+         )
+
+     def embed_batch(self, texts: list[str], model: str = "") -> list[EmbeddingResponse]:
+         """
+         Generate embeddings for multiple texts efficiently.
+
+         Uses batch encoding for better performance.
+
+         Parameters
+         ----------
+         texts : list[str]
+             Texts to embed.
+         model : str
+             Model identifier (ignored).
+
+         Returns
+         -------
+         list[EmbeddingResponse]
+             List of embedding responses.
+         """
+         if not texts:
+             return []
+
+         model_instance = self._ensure_model()
+
+         # Batch encode for efficiency
+         embeddings = model_instance.encode(texts, convert_to_numpy=True, show_progress_bar=False)
+
+         results = []
+         for embedding in embeddings:
+             embedding_tuple = tuple(float(x) for x in embedding)
+             results.append(
+                 EmbeddingResponse(
+                     embedding=embedding_tuple,
+                     model=self._model_name,
+                     provider=self.provider_name,
+                     dimensions=len(embedding_tuple),
+                 )
+             )
+
+         return results
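As with the Ollama provider, a short hedged sketch of the offline provider in use; it assumes sentence-transformers is installed (pip install ragit[transformers]) and that the default all-MiniLM-L6-v2 weights can be downloaded or are already cached locally.

# Illustrative usage of the module above -- not part of the diff.
provider = SentenceTransformersProvider()  # all-MiniLM-L6-v2 on an auto-selected device

if provider.is_available():
    single = provider.embed("retrieval-augmented generation")
    print(provider.model_name, provider.dimensions)  # e.g. all-MiniLM-L6-v2 384

    batch = provider.embed_batch(["doc one", "doc two", "doc three"])
    assert all(r.dimensions == provider.dimensions for r in batch)

# A larger model trades speed for quality (768-dim vectors).
better = SentenceTransformersProvider(model_name="all-mpnet-base-v2")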