ragit 0.7.2.tar.gz → 0.7.4.tar.gz

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: ragit
- Version: 0.7.2
+ Version: 0.7.4
  Summary: Automatic RAG Pattern Optimization Engine
  Author: RODMENA LIMITED
  Maintainer-email: RODMENA LIMITED <info@rodmena.co.uk>
@@ -26,6 +26,8 @@ Requires-Dist: pydantic>=2.0.0
  Requires-Dist: python-dotenv>=1.0.0
  Requires-Dist: scikit-learn>=1.5.0
  Requires-Dist: tqdm>=4.66.0
+ Requires-Dist: trio>=0.24.0
+ Requires-Dist: httpx>=0.27.0
  Provides-Extra: dev
  Requires-Dist: ragit[test]; extra == "dev"
  Requires-Dist: pytest; extra == "dev"
@@ -443,6 +445,77 @@ print(f"Score: {best.score:.3f}")

  The experiment tests different combinations of chunk sizes, overlaps, and retrieval parameters to find what works best for your content.

+ ## Performance Features
+
+ Ragit includes several optimizations for production workloads:
+
+ ### Connection Pooling
+
+ `OllamaProvider` uses HTTP connection pooling via `requests.Session()` for faster sequential requests:
+
+ ```python
+ from ragit.providers import OllamaProvider
+
+ provider = OllamaProvider()
+
+ # All requests reuse the same connection pool
+ for text in texts:
+     provider.embed(text, model="mxbai-embed-large")
+
+ # Explicitly close when done (optional, auto-closes on garbage collection)
+ provider.close()
+ ```
+
+ ### Async Parallel Embedding
+
+ For large batches, use `embed_batch_async()` with trio for 5-10x faster embedding:
+
+ ```python
+ import trio
+ from ragit.providers import OllamaProvider
+
+ provider = OllamaProvider()
+
+ async def embed_documents():
+     texts = ["doc1...", "doc2...", "doc3...", ...]  # hundreds of texts
+     embeddings = await provider.embed_batch_async(
+         texts,
+         model="mxbai-embed-large",
+         max_concurrent=10,  # Adjust based on server capacity
+     )
+     return embeddings
+
+ # Run with trio
+ results = trio.run(embed_documents)
+ ```
+
+ ### Embedding Cache
+
+ Repeated embedding calls are cached automatically in a 2048-entry LRU cache:
+
+ ```python
+ from ragit.providers import OllamaProvider
+
+ provider = OllamaProvider(use_cache=True)  # Default
+
+ # First call hits the API
+ provider.embed("Hello world", model="mxbai-embed-large")
+
+ # Second call returns cached result instantly
+ provider.embed("Hello world", model="mxbai-embed-large")
+
+ # View cache statistics
+ print(OllamaProvider.embedding_cache_info())
+ # {'hits': 1, 'misses': 1, 'maxsize': 2048, 'currsize': 1}
+
+ # Clear cache if needed
+ OllamaProvider.clear_embedding_cache()
+ ```
+
+ ### Pre-normalized Embeddings
+
+ Vector similarity uses pre-normalized embeddings, so cosine similarity reduces to a single dot product per comparison (no norms recomputed at query time).
+
  ## API Reference

  ### Document Loading
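
The "Pre-normalized Embeddings" item in the hunk above states the idea without a snippet. Below is a minimal sketch of the technique, assuming NumPy and not taken from ragit's internals: normalize every vector once when it is stored, and ranking by cosine similarity at query time becomes a single matrix-vector product.

```python
# Illustration only (not ragit's internal API): with unit-length vectors,
# cosine similarity is just a dot product, so no norms are computed per query.
import numpy as np

def l2_normalize(vectors: np.ndarray) -> np.ndarray:
    """Normalize each row to unit length; done once, at indexing time."""
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    return vectors / np.clip(norms, 1e-12, None)

rng = np.random.default_rng(0)
index = l2_normalize(rng.random((4, 1024)))     # 4 stored embeddings
query = l2_normalize(rng.random((1, 1024)))[0]  # 1 query embedding

scores = index @ query                          # cosine similarities
best = int(np.argmax(scores))
print(f"best match: {best}, score: {scores[best]:.3f}")
```

The dimension 1024 here simply matches the mxbai-embed-large entry in the EMBEDDING_DIMENSIONS table added later in this diff.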
@@ -398,6 +398,77 @@ print(f"Score: {best.score:.3f}")

  The experiment tests different combinations of chunk sizes, overlaps, and retrieval parameters to find what works best for your content.

+ ## Performance Features
+
+ Ragit includes several optimizations for production workloads:
+
+ ### Connection Pooling
+
+ `OllamaProvider` uses HTTP connection pooling via `requests.Session()` for faster sequential requests:
+
+ ```python
+ from ragit.providers import OllamaProvider
+
+ provider = OllamaProvider()
+
+ # All requests reuse the same connection pool
+ for text in texts:
+     provider.embed(text, model="mxbai-embed-large")
+
+ # Explicitly close when done (optional, auto-closes on garbage collection)
+ provider.close()
+ ```
+
+ ### Async Parallel Embedding
+
+ For large batches, use `embed_batch_async()` with trio for 5-10x faster embedding:
+
+ ```python
+ import trio
+ from ragit.providers import OllamaProvider
+
+ provider = OllamaProvider()
+
+ async def embed_documents():
+     texts = ["doc1...", "doc2...", "doc3...", ...]  # hundreds of texts
+     embeddings = await provider.embed_batch_async(
+         texts,
+         model="mxbai-embed-large",
+         max_concurrent=10,  # Adjust based on server capacity
+     )
+     return embeddings
+
+ # Run with trio
+ results = trio.run(embed_documents)
+ ```
+
+ ### Embedding Cache
+
+ Repeated embedding calls are cached automatically in a 2048-entry LRU cache:
+
+ ```python
+ from ragit.providers import OllamaProvider
+
+ provider = OllamaProvider(use_cache=True)  # Default
+
+ # First call hits the API
+ provider.embed("Hello world", model="mxbai-embed-large")
+
+ # Second call returns cached result instantly
+ provider.embed("Hello world", model="mxbai-embed-large")
+
+ # View cache statistics
+ print(OllamaProvider.embedding_cache_info())
+ # {'hits': 1, 'misses': 1, 'maxsize': 2048, 'currsize': 1}
+
+ # Clear cache if needed
+ OllamaProvider.clear_embedding_cache()
+ ```
+
+ ### Pre-normalized Embeddings
+
+ Vector similarity uses pre-normalized embeddings, so cosine similarity reduces to a single dot product per comparison (no norms recomputed at query time).
+
  ## API Reference

  ### Document Loading
@@ -38,6 +38,8 @@ dependencies = [
  "python-dotenv>=1.0.0",
  "scikit-learn>=1.5.0",
  "tqdm>=4.66.0",
+ "trio>=0.24.0",
+ "httpx>=0.27.0",
  ]

  [project.urls]
@@ -0,0 +1,461 @@
+ #
+ # Copyright RODMENA LIMITED 2025
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ """
+ Ollama provider for LLM and Embedding operations.
+
+ This provider connects to a local or remote Ollama server.
+ Configuration is loaded from environment variables.
+
+ Performance optimizations:
+ - Connection pooling via requests.Session()
+ - Async parallel embedding via trio + httpx
+ - LRU cache for repeated embedding queries
+ """
+
+ from functools import lru_cache
+ from typing import Any
+
+ import httpx
+ import requests
+ import trio
+
+ from ragit.config import config
+ from ragit.providers.base import (
+     BaseEmbeddingProvider,
+     BaseLLMProvider,
+     EmbeddingResponse,
+     LLMResponse,
+ )
+
+
+ # Module-level cache for embeddings (shared across instances)
+ @lru_cache(maxsize=2048)
+ def _cached_embedding(text: str, model: str, embedding_url: str, timeout: int) -> tuple[float, ...]:
+     """Cache embedding results to avoid redundant API calls."""
+     # Truncate oversized inputs
+     if len(text) > OllamaProvider.MAX_EMBED_CHARS:
+         text = text[: OllamaProvider.MAX_EMBED_CHARS]
+
+     response = requests.post(
+         f"{embedding_url}/api/embeddings",
+         headers={"Content-Type": "application/json"},
+         json={"model": model, "prompt": text},
+         timeout=timeout,
+     )
+     response.raise_for_status()
+     data = response.json()
+     embedding = data.get("embedding", [])
+     if not embedding:
+         raise ValueError("Empty embedding returned from Ollama")
+     return tuple(embedding)
+
+
+ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
+     """
+     Ollama provider for both LLM and Embedding operations.
+
+     Performance features:
+     - Connection pooling via requests.Session() for faster sequential requests
+     - Async parallel embedding via embed_batch_async() using trio + httpx
+     - LRU cache for repeated embedding queries (2048 entries)
+
+     Parameters
+     ----------
+     base_url : str, optional
+         Ollama server URL (default: from OLLAMA_BASE_URL env var)
+     api_key : str, optional
+         API key for authentication (default: from OLLAMA_API_KEY env var)
+     timeout : int, optional
+         Request timeout in seconds (default: from OLLAMA_TIMEOUT env var)
+     use_cache : bool, optional
+         Enable embedding cache (default: True)
+
+     Examples
+     --------
+     >>> provider = OllamaProvider()
+     >>> response = provider.generate("What is RAG?", model="llama3")
+     >>> print(response.text)
+
+     >>> # Async batch embedding (5-10x faster for large batches)
+     >>> embeddings = trio.run(provider.embed_batch_async, texts, "mxbai-embed-large")
+     """
+
+     # Known embedding model dimensions
+     EMBEDDING_DIMENSIONS: dict[str, int] = {
+         "nomic-embed-text": 768,
+         "nomic-embed-text:latest": 768,
+         "mxbai-embed-large": 1024,
+         "all-minilm": 384,
+         "snowflake-arctic-embed": 1024,
+         "qwen3-embedding": 4096,
+         "qwen3-embedding:0.6b": 1024,
+         "qwen3-embedding:4b": 2560,
+         "qwen3-embedding:8b": 4096,
+     }
+
+     # Max characters per embedding request (safe limit for 512 token models)
+     MAX_EMBED_CHARS = 2000
+
+     def __init__(
+         self,
+         base_url: str | None = None,
+         embedding_url: str | None = None,
+         api_key: str | None = None,
+         timeout: int | None = None,
+         use_cache: bool = True,
+     ) -> None:
+         self.base_url = (base_url or config.OLLAMA_BASE_URL).rstrip("/")
+         self.embedding_url = (embedding_url or config.OLLAMA_EMBEDDING_URL).rstrip("/")
+         self.api_key = api_key or config.OLLAMA_API_KEY
+         self.timeout = timeout or config.OLLAMA_TIMEOUT
+         self.use_cache = use_cache
+         self._current_embed_model: str | None = None
+         self._current_dimensions: int = 768  # default
+
+         # Connection pooling via session
+         self._session: requests.Session | None = None
+
+     @property
+     def session(self) -> requests.Session:
+         """Lazy-initialized session for connection pooling."""
+         if self._session is None:
+             self._session = requests.Session()
+             self._session.headers.update({"Content-Type": "application/json"})
+             if self.api_key:
+                 self._session.headers.update({"Authorization": f"Bearer {self.api_key}"})
+         return self._session
+
+     def close(self) -> None:
+         """Close the session and release resources."""
+         if self._session is not None:
+             self._session.close()
+             self._session = None
+
+     def __del__(self) -> None:
+         """Cleanup on garbage collection."""
+         self.close()
+
+     def _get_headers(self, include_auth: bool = True) -> dict[str, str]:
+         """Get request headers including authentication if API key is set."""
+         headers = {"Content-Type": "application/json"}
+         if include_auth and self.api_key:
+             headers["Authorization"] = f"Bearer {self.api_key}"
+         return headers
+
+     @property
+     def provider_name(self) -> str:
+         return "ollama"
+
+     @property
+     def dimensions(self) -> int:
+         return self._current_dimensions
+
+     def is_available(self) -> bool:
+         """Check if Ollama server is reachable."""
+         try:
+             response = self.session.get(
+                 f"{self.base_url}/api/tags",
+                 timeout=5,
+             )
+             return response.status_code == 200
+         except requests.RequestException:
+             return False
+
+     def list_models(self) -> list[dict[str, Any]]:
+         """List available models on the Ollama server."""
+         try:
+             response = self.session.get(
+                 f"{self.base_url}/api/tags",
+                 timeout=10,
+             )
+             response.raise_for_status()
+             data = response.json()
+             return list(data.get("models", []))
+         except requests.RequestException as e:
+             raise ConnectionError(f"Failed to list Ollama models: {e}") from e
+
+     def generate(
+         self,
+         prompt: str,
+         model: str,
+         system_prompt: str | None = None,
+         temperature: float = 0.7,
+         max_tokens: int | None = None,
+     ) -> LLMResponse:
+         """Generate text using Ollama."""
+         options: dict[str, float | int] = {"temperature": temperature}
+         if max_tokens:
+             options["num_predict"] = max_tokens
+
+         payload: dict[str, str | bool | dict[str, float | int]] = {
+             "model": model,
+             "prompt": prompt,
+             "stream": False,
+             "options": options,
+         }
+
+         if system_prompt:
+             payload["system"] = system_prompt
+
+         try:
+             response = self.session.post(
+                 f"{self.base_url}/api/generate",
+                 json=payload,
+                 timeout=self.timeout,
+             )
+             response.raise_for_status()
+             data = response.json()
+
+             return LLMResponse(
+                 text=data.get("response", ""),
+                 model=model,
+                 provider=self.provider_name,
+                 usage={
+                     "prompt_tokens": data.get("prompt_eval_count"),
+                     "completion_tokens": data.get("eval_count"),
+                     "total_duration": data.get("total_duration"),
+                 },
+             )
+         except requests.RequestException as e:
+             raise ConnectionError(f"Ollama generate failed: {e}") from e
+
+     def embed(self, text: str, model: str) -> EmbeddingResponse:
+         """Generate embedding using Ollama with optional caching."""
+         self._current_embed_model = model
+         self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
+
+         try:
+             if self.use_cache:
+                 # Use cached version
+                 embedding = _cached_embedding(text, model, self.embedding_url, self.timeout)
+             else:
+                 # Direct call without cache
+                 truncated = text[: self.MAX_EMBED_CHARS] if len(text) > self.MAX_EMBED_CHARS else text
+                 response = self.session.post(
+                     f"{self.embedding_url}/api/embeddings",
+                     json={"model": model, "prompt": truncated},
+                     timeout=self.timeout,
+                 )
+                 response.raise_for_status()
+                 data = response.json()
+                 embedding_list = data.get("embedding", [])
+                 if not embedding_list:
+                     raise ValueError("Empty embedding returned from Ollama")
+                 embedding = tuple(embedding_list)
+
+             # Update dimensions from actual response
+             self._current_dimensions = len(embedding)
+
+             return EmbeddingResponse(
+                 embedding=embedding,
+                 model=model,
+                 provider=self.provider_name,
+                 dimensions=len(embedding),
+             )
+         except requests.RequestException as e:
+             raise ConnectionError(f"Ollama embed failed: {e}") from e
+
+     def embed_batch(self, texts: list[str], model: str) -> list[EmbeddingResponse]:
+         """Generate embeddings for multiple texts sequentially.
+
+         For better performance with large batches, use embed_batch_async().
+
+         Note: Ollama /api/embeddings only supports single prompts, so we loop.
+         """
+         self._current_embed_model = model
+         self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
+
+         results = []
+         try:
+             for text in texts:
+                 # Truncate oversized inputs
+                 truncated = text[: self.MAX_EMBED_CHARS] if len(text) > self.MAX_EMBED_CHARS else text
+
+                 if self.use_cache:
+                     embedding = _cached_embedding(truncated, model, self.embedding_url, self.timeout)
+                 else:
+                     response = self.session.post(
+                         f"{self.embedding_url}/api/embeddings",
+                         json={"model": model, "prompt": truncated},
+                         timeout=self.timeout,
+                     )
+                     response.raise_for_status()
+                     data = response.json()
+                     embedding_list = data.get("embedding", [])
+                     embedding = tuple(embedding_list) if embedding_list else ()
+
+                 if embedding:
+                     self._current_dimensions = len(embedding)
+
+                 results.append(
+                     EmbeddingResponse(
+                         embedding=embedding,
+                         model=model,
+                         provider=self.provider_name,
+                         dimensions=len(embedding),
+                     )
+                 )
+             return results
+         except requests.RequestException as e:
+             raise ConnectionError(f"Ollama batch embed failed: {e}") from e
+
+     async def embed_batch_async(
+         self,
+         texts: list[str],
+         model: str,
+         max_concurrent: int = 10,
+     ) -> list[EmbeddingResponse]:
+         """Generate embeddings for multiple texts in parallel using trio.
+
+         This method is 5-10x faster than embed_batch() for large batches
+         by making concurrent HTTP requests.
+
+         Parameters
+         ----------
+         texts : list[str]
+             Texts to embed.
+         model : str
+             Embedding model name.
+         max_concurrent : int
+             Maximum concurrent requests (default: 10).
+             Higher values = faster but more server load.
+
+         Returns
+         -------
+         list[EmbeddingResponse]
+             Embeddings in the same order as input texts.
+
+         Examples
+         --------
+         >>> import trio
+         >>> embeddings = trio.run(provider.embed_batch_async, texts, "mxbai-embed-large")
+         """
+         self._current_embed_model = model
+         self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
+
+         # Results storage (index -> embedding)
+         results: dict[int, EmbeddingResponse] = {}
+         errors: list[Exception] = []
+
+         # Semaphore to limit concurrency
+         limiter = trio.CapacityLimiter(max_concurrent)
+
+         async def fetch_embedding(client: httpx.AsyncClient, index: int, text: str) -> None:
+             """Fetch a single embedding."""
+             async with limiter:
+                 try:
+                     # Truncate oversized inputs
+                     truncated = text[: self.MAX_EMBED_CHARS] if len(text) > self.MAX_EMBED_CHARS else text
+
+                     response = await client.post(
+                         f"{self.embedding_url}/api/embeddings",
+                         json={"model": model, "prompt": truncated},
+                         timeout=self.timeout,
+                     )
+                     response.raise_for_status()
+                     data = response.json()
+
+                     embedding_list = data.get("embedding", [])
+                     embedding = tuple(embedding_list) if embedding_list else ()
+
+                     if embedding:
+                         self._current_dimensions = len(embedding)
+
+                     results[index] = EmbeddingResponse(
+                         embedding=embedding,
+                         model=model,
+                         provider=self.provider_name,
+                         dimensions=len(embedding),
+                     )
+                 except Exception as e:
+                     errors.append(e)
+
+         async with httpx.AsyncClient() as client, trio.open_nursery() as nursery:
+             for i, text in enumerate(texts):
+                 nursery.start_soon(fetch_embedding, client, i, text)
+
+         if errors:
+             raise ConnectionError(f"Ollama async batch embed failed: {errors[0]}") from errors[0]
+
+         # Return results in original order
+         return [results[i] for i in range(len(texts))]
+
+     def chat(
+         self,
+         messages: list[dict[str, str]],
+         model: str,
+         temperature: float = 0.7,
+         max_tokens: int | None = None,
+     ) -> LLMResponse:
+         """
+         Chat completion using Ollama.
+
+         Parameters
+         ----------
+         messages : list[dict]
+             List of messages with 'role' and 'content' keys.
+         model : str
+             Model identifier.
+         temperature : float
+             Sampling temperature.
+         max_tokens : int, optional
+             Maximum tokens to generate.
+
+         Returns
+         -------
+         LLMResponse
+             The generated response.
+         """
+         options: dict[str, float | int] = {"temperature": temperature}
+         if max_tokens:
+             options["num_predict"] = max_tokens
+
+         payload: dict[str, str | bool | list[dict[str, str]] | dict[str, float | int]] = {
+             "model": model,
+             "messages": messages,
+             "stream": False,
+             "options": options,
+         }
+
+         try:
+             response = self.session.post(
+                 f"{self.base_url}/api/chat",
+                 json=payload,
+                 timeout=self.timeout,
+             )
+             response.raise_for_status()
+             data = response.json()
+
+             return LLMResponse(
+                 text=data.get("message", {}).get("content", ""),
+                 model=model,
+                 provider=self.provider_name,
+                 usage={
+                     "prompt_tokens": data.get("prompt_eval_count"),
+                     "completion_tokens": data.get("eval_count"),
+                 },
+             )
+         except requests.RequestException as e:
+             raise ConnectionError(f"Ollama chat failed: {e}") from e
+
+     @staticmethod
+     def clear_embedding_cache() -> None:
+         """Clear the embedding cache."""
+         _cached_embedding.cache_clear()
+
+     @staticmethod
+     def embedding_cache_info() -> dict[str, int]:
+         """Get embedding cache statistics."""
+         info = _cached_embedding.cache_info()
+         return {
+             "hits": info.hits,
+             "misses": info.misses,
+             "maxsize": info.maxsize or 0,
+             "currsize": info.currsize,
+         }
+
+
+ # Export the EMBEDDING_DIMENSIONS for external use
+ EMBEDDING_DIMENSIONS = OllamaProvider.EMBEDDING_DIMENSIONS
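
The new provider above is self-contained enough to sketch an end-to-end call sequence. The sketch below uses only methods visible in this hunk (`is_available()`, `embed()`, `embed_batch_async()`, `chat()`, `embedding_cache_info()`, `close()`); the model names and prompts are illustrative assumptions, not part of the package.

```python
# Minimal usage sketch based on the methods added in this file; the model
# names ("mxbai-embed-large", "llama3") and prompts are illustrative only.
import trio

from ragit.providers import OllamaProvider

provider = OllamaProvider()  # pooled requests.Session is created lazily

if provider.is_available():
    # Sync, cached embedding (repeated texts hit the module-level LRU cache)
    emb = provider.embed("What is RAG?", model="mxbai-embed-large")
    print(emb.dimensions, OllamaProvider.embedding_cache_info())

    # Parallel embedding of a batch under trio
    texts = ["first document", "second document", "third document"]
    batch = trio.run(provider.embed_batch_async, texts, "mxbai-embed-large")
    print(len(batch))

    # Chat completion through the same pooled session
    reply = provider.chat(
        [{"role": "user", "content": "Summarize RAG in one sentence."}],
        model="llama3",
    )
    print(reply.text)

provider.close()  # optional; the session is also closed on garbage collection
```

Note that the embedding cache is module-level, so `embedding_cache_info()` reflects hits and misses across all `OllamaProvider` instances in the process.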
@@ -2,4 +2,4 @@
  # Copyright RODMENA LIMITED 2025
  # SPDX-License-Identifier: Apache-2.0
  #
- __version__ = "0.7.2"
+ __version__ = "0.7.4"
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: ragit
- Version: 0.7.2
+ Version: 0.7.4
  Summary: Automatic RAG Pattern Optimization Engine
  Author: RODMENA LIMITED
  Maintainer-email: RODMENA LIMITED <info@rodmena.co.uk>
@@ -26,6 +26,8 @@ Requires-Dist: pydantic>=2.0.0
  Requires-Dist: python-dotenv>=1.0.0
  Requires-Dist: scikit-learn>=1.5.0
  Requires-Dist: tqdm>=4.66.0
+ Requires-Dist: trio>=0.24.0
+ Requires-Dist: httpx>=0.27.0
  Provides-Extra: dev
  Requires-Dist: ragit[test]; extra == "dev"
  Requires-Dist: pytest; extra == "dev"
@@ -443,6 +445,77 @@ print(f"Score: {best.score:.3f}")

  The experiment tests different combinations of chunk sizes, overlaps, and retrieval parameters to find what works best for your content.

+ ## Performance Features
+
+ Ragit includes several optimizations for production workloads:
+
+ ### Connection Pooling
+
+ `OllamaProvider` uses HTTP connection pooling via `requests.Session()` for faster sequential requests:
+
+ ```python
+ from ragit.providers import OllamaProvider
+
+ provider = OllamaProvider()
+
+ # All requests reuse the same connection pool
+ for text in texts:
+     provider.embed(text, model="mxbai-embed-large")
+
+ # Explicitly close when done (optional, auto-closes on garbage collection)
+ provider.close()
+ ```
+
+ ### Async Parallel Embedding
+
+ For large batches, use `embed_batch_async()` with trio for 5-10x faster embedding:
+
+ ```python
+ import trio
+ from ragit.providers import OllamaProvider
+
+ provider = OllamaProvider()
+
+ async def embed_documents():
+     texts = ["doc1...", "doc2...", "doc3...", ...]  # hundreds of texts
+     embeddings = await provider.embed_batch_async(
+         texts,
+         model="mxbai-embed-large",
+         max_concurrent=10,  # Adjust based on server capacity
+     )
+     return embeddings
+
+ # Run with trio
+ results = trio.run(embed_documents)
+ ```
+
+ ### Embedding Cache
+
+ Repeated embedding calls are cached automatically in a 2048-entry LRU cache:
+
+ ```python
+ from ragit.providers import OllamaProvider
+
+ provider = OllamaProvider(use_cache=True)  # Default
+
+ # First call hits the API
+ provider.embed("Hello world", model="mxbai-embed-large")
+
+ # Second call returns cached result instantly
+ provider.embed("Hello world", model="mxbai-embed-large")
+
+ # View cache statistics
+ print(OllamaProvider.embedding_cache_info())
+ # {'hits': 1, 'misses': 1, 'maxsize': 2048, 'currsize': 1}
+
+ # Clear cache if needed
+ OllamaProvider.clear_embedding_cache()
+ ```
+
+ ### Pre-normalized Embeddings
+
+ Vector similarity uses pre-normalized embeddings, so cosine similarity reduces to a single dot product per comparison (no norms recomputed at query time).
+
  ## API Reference

  ### Document Loading
@@ -5,6 +5,8 @@ pydantic>=2.0.0
  python-dotenv>=1.0.0
  scikit-learn>=1.5.0
  tqdm>=4.66.0
+ trio>=0.24.0
+ httpx>=0.27.0

  [dev]
  ragit[test]
@@ -1,284 +0,0 @@
- #
- # Copyright RODMENA LIMITED 2025
- # SPDX-License-Identifier: Apache-2.0
- #
- """
- Ollama provider for LLM and Embedding operations.
-
- This provider connects to a local or remote Ollama server.
- Configuration is loaded from environment variables.
- """
-
- import requests
-
- from ragit.config import config
- from ragit.providers.base import (
-     BaseEmbeddingProvider,
-     BaseLLMProvider,
-     EmbeddingResponse,
-     LLMResponse,
- )
-
-
- class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
-     """
-     Ollama provider for both LLM and Embedding operations.
-
-     Parameters
-     ----------
-     base_url : str, optional
-         Ollama server URL (default: from OLLAMA_BASE_URL env var)
-     api_key : str, optional
-         API key for authentication (default: from OLLAMA_API_KEY env var)
-     timeout : int, optional
-         Request timeout in seconds (default: from OLLAMA_TIMEOUT env var)
-
-     Examples
-     --------
-     >>> provider = OllamaProvider()
-     >>> response = provider.generate("What is RAG?", model="llama3")
-     >>> print(response.text)
-
-     >>> embedding = provider.embed("Hello world", model="nomic-embed-text")
-     >>> print(len(embedding.embedding))
-     """
-
-     # Known embedding model dimensions
-     EMBEDDING_DIMENSIONS = {
-         "nomic-embed-text": 768,
-         "nomic-embed-text:latest": 768,
-         "mxbai-embed-large": 1024,
-         "all-minilm": 384,
-         "snowflake-arctic-embed": 1024,
-         "qwen3-embedding": 4096,
-         "qwen3-embedding:0.6b": 1024,
-         "qwen3-embedding:4b": 2560,
-         "qwen3-embedding:8b": 4096,
-     }
-
-     def __init__(
-         self,
-         base_url: str | None = None,
-         embedding_url: str | None = None,
-         api_key: str | None = None,
-         timeout: int | None = None,
-     ) -> None:
-         self.base_url = (base_url or config.OLLAMA_BASE_URL).rstrip("/")
-         self.embedding_url = (embedding_url or config.OLLAMA_EMBEDDING_URL).rstrip("/")
-         self.api_key = api_key or config.OLLAMA_API_KEY
-         self.timeout = timeout or config.OLLAMA_TIMEOUT
-         self._current_embed_model: str | None = None
-         self._current_dimensions: int = 768  # default
-
-     def _get_headers(self, include_auth: bool = True) -> dict[str, str]:
-         """Get request headers including authentication if API key is set."""
-         headers = {"Content-Type": "application/json"}
-         if include_auth and self.api_key:
-             headers["Authorization"] = f"Bearer {self.api_key}"
-         return headers
-
-     @property
-     def provider_name(self) -> str:
-         return "ollama"
-
-     @property
-     def dimensions(self) -> int:
-         return self._current_dimensions
-
-     def is_available(self) -> bool:
-         """Check if Ollama server is reachable."""
-         try:
-             response = requests.get(
-                 f"{self.base_url}/api/tags",
-                 headers=self._get_headers(),
-                 timeout=5,
-             )
-             return response.status_code == 200
-         except requests.RequestException:
-             return False
-
-     def list_models(self) -> list[dict[str, str]]:
-         """List available models on the Ollama server."""
-         try:
-             response = requests.get(
-                 f"{self.base_url}/api/tags",
-                 headers=self._get_headers(),
-                 timeout=10,
-             )
-             response.raise_for_status()
-             data = response.json()
-             return list(data.get("models", []))
-         except requests.RequestException as e:
-             raise ConnectionError(f"Failed to list Ollama models: {e}") from e
-
-     def generate(
-         self,
-         prompt: str,
-         model: str,
-         system_prompt: str | None = None,
-         temperature: float = 0.7,
-         max_tokens: int | None = None,
-     ) -> LLMResponse:
-         """Generate text using Ollama."""
-         options: dict[str, float | int] = {"temperature": temperature}
-         if max_tokens:
-             options["num_predict"] = max_tokens
-
-         payload: dict[str, str | bool | dict[str, float | int]] = {
-             "model": model,
-             "prompt": prompt,
-             "stream": False,
-             "options": options,
-         }
-
-         if system_prompt:
-             payload["system"] = system_prompt
-
-         try:
-             response = requests.post(
-                 f"{self.base_url}/api/generate",
-                 headers=self._get_headers(),
-                 json=payload,
-                 timeout=self.timeout,
-             )
-             response.raise_for_status()
-             data = response.json()
-
-             return LLMResponse(
-                 text=data.get("response", ""),
-                 model=model,
-                 provider=self.provider_name,
-                 usage={
-                     "prompt_tokens": data.get("prompt_eval_count"),
-                     "completion_tokens": data.get("eval_count"),
-                     "total_duration": data.get("total_duration"),
-                 },
-             )
-         except requests.RequestException as e:
-             raise ConnectionError(f"Ollama generate failed: {e}") from e
-
-     def embed(self, text: str, model: str) -> EmbeddingResponse:
-         """Generate embedding using Ollama (uses embedding_url, no auth for local)."""
-         self._current_embed_model = model
-         self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
-
-         try:
-             response = requests.post(
-                 f"{self.embedding_url}/api/embeddings",
-                 headers=self._get_headers(include_auth=False),
-                 json={"model": model, "prompt": text},
-                 timeout=self.timeout,
-             )
-             response.raise_for_status()
-             data = response.json()
-
-             embedding = data.get("embedding", [])
-             if not embedding:
-                 raise ValueError("Empty embedding returned from Ollama")
-
-             # Update dimensions from actual response
-             self._current_dimensions = len(embedding)
-
-             return EmbeddingResponse(
-                 embedding=tuple(embedding),
-                 model=model,
-                 provider=self.provider_name,
-                 dimensions=len(embedding),
-             )
-         except requests.RequestException as e:
-             raise ConnectionError(f"Ollama embed failed: {e}") from e
-
-     def embed_batch(self, texts: list[str], model: str) -> list[EmbeddingResponse]:
-         """Generate embeddings for multiple texts (uses embedding_url, no auth for local).
-
-         Note: Ollama /api/embeddings only supports single prompts, so we loop.
-         """
-         self._current_embed_model = model
-         self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
-
-         results = []
-         try:
-             for text in texts:
-                 response = requests.post(
-                     f"{self.embedding_url}/api/embeddings",
-                     headers=self._get_headers(include_auth=False),
-                     json={"model": model, "prompt": text},
-                     timeout=self.timeout,
-                 )
-                 response.raise_for_status()
-                 data = response.json()
-
-                 embedding = data.get("embedding", [])
-                 if embedding:
-                     self._current_dimensions = len(embedding)
-
-                 results.append(
-                     EmbeddingResponse(
-                         embedding=tuple(embedding),
-                         model=model,
-                         provider=self.provider_name,
-                         dimensions=len(embedding),
-                     )
-                 )
-             return results
-         except requests.RequestException as e:
-             raise ConnectionError(f"Ollama batch embed failed: {e}") from e
-
-     def chat(
-         self,
-         messages: list[dict[str, str]],
-         model: str,
-         temperature: float = 0.7,
-         max_tokens: int | None = None,
-     ) -> LLMResponse:
-         """
-         Chat completion using Ollama.
-
-         Parameters
-         ----------
-         messages : list[dict]
-             List of messages with 'role' and 'content' keys.
-         model : str
-             Model identifier.
-         temperature : float
-             Sampling temperature.
-         max_tokens : int, optional
-             Maximum tokens to generate.
-
-         Returns
-         -------
-         LLMResponse
-             The generated response.
-         """
-         options: dict[str, float | int] = {"temperature": temperature}
-         if max_tokens:
-             options["num_predict"] = max_tokens
-
-         payload: dict[str, str | bool | list[dict[str, str]] | dict[str, float | int]] = {
-             "model": model,
-             "messages": messages,
-             "stream": False,
-             "options": options,
-         }
-
-         try:
-             response = requests.post(
-                 f"{self.base_url}/api/chat",
-                 headers=self._get_headers(),
-                 json=payload,
-                 timeout=self.timeout,
-             )
-             response.raise_for_status()
-             data = response.json()
-
-             return LLMResponse(
-                 text=data.get("message", {}).get("content", ""),
-                 model=model,
-                 provider=self.provider_name,
-                 usage={
-                     "prompt_tokens": data.get("prompt_eval_count"),
-                     "completion_tokens": data.get("eval_count"),
-                 },
-             )
-         except requests.RequestException as e:
-             raise ConnectionError(f"Ollama chat failed: {e}") from e