ragit 0.7.2__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ragit/providers/ollama.py CHANGED
@@ -7,9 +7,19 @@ Ollama provider for LLM and Embedding operations.
7
7
 
8
8
  This provider connects to a local or remote Ollama server.
9
9
  Configuration is loaded from environment variables.
10
+
11
+ Performance optimizations:
12
+ - Connection pooling via requests.Session()
13
+ - Async parallel embedding via trio + httpx
14
+ - LRU cache for repeated embedding queries
10
15
  """
11
16
 
17
+ from functools import lru_cache
18
+ from typing import Any
19
+
20
+ import httpx
12
21
  import requests
22
+ import trio
13
23
 
14
24
  from ragit.config import config
15
25
  from ragit.providers.base import (
@@ -20,10 +30,37 @@ from ragit.providers.base import (
20
30
  )
21
31
 
22
32
 
33
+ # Module-level cache for embeddings (shared across instances)
34
+ @lru_cache(maxsize=2048)
35
+ def _cached_embedding(text: str, model: str, embedding_url: str, timeout: int) -> tuple[float, ...]:
36
+ """Cache embedding results to avoid redundant API calls."""
37
+ # Truncate oversized inputs
38
+ if len(text) > OllamaProvider.MAX_EMBED_CHARS:
39
+ text = text[: OllamaProvider.MAX_EMBED_CHARS]
40
+
41
+ response = requests.post(
42
+ f"{embedding_url}/api/embeddings",
43
+ headers={"Content-Type": "application/json"},
44
+ json={"model": model, "prompt": text},
45
+ timeout=timeout,
46
+ )
47
+ response.raise_for_status()
48
+ data = response.json()
49
+ embedding = data.get("embedding", [])
50
+ if not embedding:
51
+ raise ValueError("Empty embedding returned from Ollama")
52
+ return tuple(embedding)
53
+
54
+
23
55
  class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
24
56
  """
25
57
  Ollama provider for both LLM and Embedding operations.
26
58
 
59
+ Performance features:
60
+ - Connection pooling via requests.Session() for faster sequential requests
61
+ - Async parallel embedding via embed_batch_async() using trio + httpx
62
+ - LRU cache for repeated embedding queries (2048 entries)
63
+
27
64
  Parameters
28
65
  ----------
29
66
  base_url : str, optional
@@ -32,6 +69,8 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
32
69
  API key for authentication (default: from OLLAMA_API_KEY env var)
33
70
  timeout : int, optional
34
71
  Request timeout in seconds (default: from OLLAMA_TIMEOUT env var)
72
+ use_cache : bool, optional
73
+ Enable embedding cache (default: True)
35
74
 
36
75
  Examples
37
76
  --------
@@ -39,12 +78,12 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
39
78
  >>> response = provider.generate("What is RAG?", model="llama3")
40
79
  >>> print(response.text)
41
80
 
42
- >>> embedding = provider.embed("Hello world", model="nomic-embed-text")
43
- >>> print(len(embedding.embedding))
81
+ >>> # Async batch embedding (5-10x faster for large batches)
82
+ >>> embeddings = trio.run(provider.embed_batch_async, texts, "mxbai-embed-large")
44
83
  """
45
84
 
46
85
  # Known embedding model dimensions
47
- EMBEDDING_DIMENSIONS = {
86
+ EMBEDDING_DIMENSIONS: dict[str, int] = {
48
87
  "nomic-embed-text": 768,
49
88
  "nomic-embed-text:latest": 768,
50
89
  "mxbai-embed-large": 1024,
@@ -56,20 +95,48 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
56
95
  "qwen3-embedding:8b": 4096,
57
96
  }
58
97
 
98
+ # Max characters per embedding request (safe limit for 512 token models)
99
+ MAX_EMBED_CHARS = 2000
100
+
59
101
  def __init__(
60
102
  self,
61
103
  base_url: str | None = None,
62
104
  embedding_url: str | None = None,
63
105
  api_key: str | None = None,
64
106
  timeout: int | None = None,
107
+ use_cache: bool = True,
65
108
  ) -> None:
66
109
  self.base_url = (base_url or config.OLLAMA_BASE_URL).rstrip("/")
67
110
  self.embedding_url = (embedding_url or config.OLLAMA_EMBEDDING_URL).rstrip("/")
68
111
  self.api_key = api_key or config.OLLAMA_API_KEY
69
112
  self.timeout = timeout or config.OLLAMA_TIMEOUT
113
+ self.use_cache = use_cache
70
114
  self._current_embed_model: str | None = None
71
115
  self._current_dimensions: int = 768 # default
72
116
 
117
+ # Connection pooling via session
118
+ self._session: requests.Session | None = None
119
+
120
+ @property
121
+ def session(self) -> requests.Session:
122
+ """Lazy-initialized session for connection pooling."""
123
+ if self._session is None:
124
+ self._session = requests.Session()
125
+ self._session.headers.update({"Content-Type": "application/json"})
126
+ if self.api_key:
127
+ self._session.headers.update({"Authorization": f"Bearer {self.api_key}"})
128
+ return self._session
129
+
130
+ def close(self) -> None:
131
+ """Close the session and release resources."""
132
+ if self._session is not None:
133
+ self._session.close()
134
+ self._session = None
135
+
136
+ def __del__(self) -> None:
137
+ """Cleanup on garbage collection."""
138
+ self.close()
139
+
73
140
  def _get_headers(self, include_auth: bool = True) -> dict[str, str]:
74
141
  """Get request headers including authentication if API key is set."""
75
142
  headers = {"Content-Type": "application/json"}
@@ -88,21 +155,19 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
88
155
  def is_available(self) -> bool:
89
156
  """Check if Ollama server is reachable."""
90
157
  try:
91
- response = requests.get(
158
+ response = self.session.get(
92
159
  f"{self.base_url}/api/tags",
93
- headers=self._get_headers(),
94
160
  timeout=5,
95
161
  )
96
162
  return response.status_code == 200
97
163
  except requests.RequestException:
98
164
  return False
99
165
 
100
- def list_models(self) -> list[dict[str, str]]:
166
+ def list_models(self) -> list[dict[str, Any]]:
101
167
  """List available models on the Ollama server."""
102
168
  try:
103
- response = requests.get(
169
+ response = self.session.get(
104
170
  f"{self.base_url}/api/tags",
105
- headers=self._get_headers(),
106
171
  timeout=10,
107
172
  )
108
173
  response.raise_for_status()
@@ -135,9 +200,8 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
135
200
  payload["system"] = system_prompt
136
201
 
137
202
  try:
138
- response = requests.post(
203
+ response = self.session.post(
139
204
  f"{self.base_url}/api/generate",
140
- headers=self._get_headers(),
141
205
  json=payload,
142
206
  timeout=self.timeout,
143
207
  )
@@ -158,29 +222,34 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
158
222
  raise ConnectionError(f"Ollama generate failed: {e}") from e
159
223
 
160
224
  def embed(self, text: str, model: str) -> EmbeddingResponse:
161
- """Generate embedding using Ollama (uses embedding_url, no auth for local)."""
225
+ """Generate embedding using Ollama with optional caching."""
162
226
  self._current_embed_model = model
163
227
  self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
164
228
 
165
229
  try:
166
- response = requests.post(
167
- f"{self.embedding_url}/api/embeddings",
168
- headers=self._get_headers(include_auth=False),
169
- json={"model": model, "prompt": text},
170
- timeout=self.timeout,
171
- )
172
- response.raise_for_status()
173
- data = response.json()
174
-
175
- embedding = data.get("embedding", [])
176
- if not embedding:
177
- raise ValueError("Empty embedding returned from Ollama")
230
+ if self.use_cache:
231
+ # Use cached version
232
+ embedding = _cached_embedding(text, model, self.embedding_url, self.timeout)
233
+ else:
234
+ # Direct call without cache
235
+ truncated = text[: self.MAX_EMBED_CHARS] if len(text) > self.MAX_EMBED_CHARS else text
236
+ response = self.session.post(
237
+ f"{self.embedding_url}/api/embeddings",
238
+ json={"model": model, "prompt": truncated},
239
+ timeout=self.timeout,
240
+ )
241
+ response.raise_for_status()
242
+ data = response.json()
243
+ embedding_list = data.get("embedding", [])
244
+ if not embedding_list:
245
+ raise ValueError("Empty embedding returned from Ollama")
246
+ embedding = tuple(embedding_list)
178
247
 
179
248
  # Update dimensions from actual response
180
249
  self._current_dimensions = len(embedding)
181
250
 
182
251
  return EmbeddingResponse(
183
- embedding=tuple(embedding),
252
+ embedding=embedding,
184
253
  model=model,
185
254
  provider=self.provider_name,
186
255
  dimensions=len(embedding),
@@ -189,7 +258,9 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
189
258
  raise ConnectionError(f"Ollama embed failed: {e}") from e
190
259
 
191
260
  def embed_batch(self, texts: list[str], model: str) -> list[EmbeddingResponse]:
192
- """Generate embeddings for multiple texts (uses embedding_url, no auth for local).
261
+ """Generate embeddings for multiple texts sequentially.
262
+
263
+ For better performance with large batches, use embed_batch_async().
193
264
 
194
265
  Note: Ollama /api/embeddings only supports single prompts, so we loop.
195
266
  """
@@ -199,22 +270,28 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
199
270
  results = []
200
271
  try:
201
272
  for text in texts:
202
- response = requests.post(
203
- f"{self.embedding_url}/api/embeddings",
204
- headers=self._get_headers(include_auth=False),
205
- json={"model": model, "prompt": text},
206
- timeout=self.timeout,
207
- )
208
- response.raise_for_status()
209
- data = response.json()
273
+ # Truncate oversized inputs
274
+ truncated = text[: self.MAX_EMBED_CHARS] if len(text) > self.MAX_EMBED_CHARS else text
275
+
276
+ if self.use_cache:
277
+ embedding = _cached_embedding(truncated, model, self.embedding_url, self.timeout)
278
+ else:
279
+ response = self.session.post(
280
+ f"{self.embedding_url}/api/embeddings",
281
+ json={"model": model, "prompt": truncated},
282
+ timeout=self.timeout,
283
+ )
284
+ response.raise_for_status()
285
+ data = response.json()
286
+ embedding_list = data.get("embedding", [])
287
+ embedding = tuple(embedding_list) if embedding_list else ()
210
288
 
211
- embedding = data.get("embedding", [])
212
289
  if embedding:
213
290
  self._current_dimensions = len(embedding)
214
291
 
215
292
  results.append(
216
293
  EmbeddingResponse(
217
- embedding=tuple(embedding),
294
+ embedding=embedding,
218
295
  model=model,
219
296
  provider=self.provider_name,
220
297
  dimensions=len(embedding),
@@ -224,6 +301,87 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
224
301
  except requests.RequestException as e:
225
302
  raise ConnectionError(f"Ollama batch embed failed: {e}") from e
226
303
 
304
+ async def embed_batch_async(
305
+ self,
306
+ texts: list[str],
307
+ model: str,
308
+ max_concurrent: int = 10,
309
+ ) -> list[EmbeddingResponse]:
310
+ """Generate embeddings for multiple texts in parallel using trio.
311
+
312
+ This method is 5-10x faster than embed_batch() for large batches
313
+ by making concurrent HTTP requests.
314
+
315
+ Parameters
316
+ ----------
317
+ texts : list[str]
318
+ Texts to embed.
319
+ model : str
320
+ Embedding model name.
321
+ max_concurrent : int
322
+ Maximum concurrent requests (default: 10).
323
+ Higher values = faster but more server load.
324
+
325
+ Returns
326
+ -------
327
+ list[EmbeddingResponse]
328
+ Embeddings in the same order as input texts.
329
+
330
+ Examples
331
+ --------
332
+ >>> import trio
333
+ >>> embeddings = trio.run(provider.embed_batch_async, texts, "mxbai-embed-large")
334
+ """
335
+ self._current_embed_model = model
336
+ self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
337
+
338
+ # Results storage (index -> embedding)
339
+ results: dict[int, EmbeddingResponse] = {}
340
+ errors: list[Exception] = []
341
+
342
+ # Semaphore to limit concurrency
343
+ limiter = trio.CapacityLimiter(max_concurrent)
344
+
345
+ async def fetch_embedding(client: httpx.AsyncClient, index: int, text: str) -> None:
346
+ """Fetch a single embedding."""
347
+ async with limiter:
348
+ try:
349
+ # Truncate oversized inputs
350
+ truncated = text[: self.MAX_EMBED_CHARS] if len(text) > self.MAX_EMBED_CHARS else text
351
+
352
+ response = await client.post(
353
+ f"{self.embedding_url}/api/embeddings",
354
+ json={"model": model, "prompt": truncated},
355
+ timeout=self.timeout,
356
+ )
357
+ response.raise_for_status()
358
+ data = response.json()
359
+
360
+ embedding_list = data.get("embedding", [])
361
+ embedding = tuple(embedding_list) if embedding_list else ()
362
+
363
+ if embedding:
364
+ self._current_dimensions = len(embedding)
365
+
366
+ results[index] = EmbeddingResponse(
367
+ embedding=embedding,
368
+ model=model,
369
+ provider=self.provider_name,
370
+ dimensions=len(embedding),
371
+ )
372
+ except Exception as e:
373
+ errors.append(e)
374
+
375
+ async with httpx.AsyncClient() as client, trio.open_nursery() as nursery:
376
+ for i, text in enumerate(texts):
377
+ nursery.start_soon(fetch_embedding, client, i, text)
378
+
379
+ if errors:
380
+ raise ConnectionError(f"Ollama async batch embed failed: {errors[0]}") from errors[0]
381
+
382
+ # Return results in original order
383
+ return [results[i] for i in range(len(texts))]
384
+
227
385
  def chat(
228
386
  self,
229
387
  messages: list[dict[str, str]],
@@ -262,9 +420,8 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
262
420
  }
263
421
 
264
422
  try:
265
- response = requests.post(
423
+ response = self.session.post(
266
424
  f"{self.base_url}/api/chat",
267
- headers=self._get_headers(),
268
425
  json=payload,
269
426
  timeout=self.timeout,
270
427
  )
@@ -282,3 +439,23 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
282
439
  )
283
440
  except requests.RequestException as e:
284
441
  raise ConnectionError(f"Ollama chat failed: {e}") from e
442
+
443
+ @staticmethod
444
+ def clear_embedding_cache() -> None:
445
+ """Clear the embedding cache."""
446
+ _cached_embedding.cache_clear()
447
+
448
+ @staticmethod
449
+ def embedding_cache_info() -> dict[str, int]:
450
+ """Get embedding cache statistics."""
451
+ info = _cached_embedding.cache_info()
452
+ return {
453
+ "hits": info.hits,
454
+ "misses": info.misses,
455
+ "maxsize": info.maxsize or 0,
456
+ "currsize": info.currsize,
457
+ }
458
+
459
+
460
+ # Export the EMBEDDING_DIMENSIONS for external use
461
+ EMBEDDING_DIMENSIONS = OllamaProvider.EMBEDDING_DIMENSIONS
ragit/version.py CHANGED
@@ -2,4 +2,4 @@
2
2
  # Copyright RODMENA LIMITED 2025
3
3
  # SPDX-License-Identifier: Apache-2.0
4
4
  #
5
- __version__ = "0.7.2"
5
+ __version__ = "0.7.4"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ragit
3
- Version: 0.7.2
3
+ Version: 0.7.4
4
4
  Summary: Automatic RAG Pattern Optimization Engine
5
5
  Author: RODMENA LIMITED
6
6
  Maintainer-email: RODMENA LIMITED <info@rodmena.co.uk>
@@ -26,6 +26,8 @@ Requires-Dist: pydantic>=2.0.0
26
26
  Requires-Dist: python-dotenv>=1.0.0
27
27
  Requires-Dist: scikit-learn>=1.5.0
28
28
  Requires-Dist: tqdm>=4.66.0
29
+ Requires-Dist: trio>=0.24.0
30
+ Requires-Dist: httpx>=0.27.0
29
31
  Provides-Extra: dev
30
32
  Requires-Dist: ragit[test]; extra == "dev"
31
33
  Requires-Dist: pytest; extra == "dev"
@@ -443,6 +445,77 @@ print(f"Score: {best.score:.3f}")
443
445
 
444
446
  The experiment tests different combinations of chunk sizes, overlaps, and retrieval parameters to find what works best for your content.
445
447
 
448
+ ## Performance Features
449
+
450
+ Ragit includes several optimizations for production workloads:
451
+
452
+ ### Connection Pooling
453
+
454
+ `OllamaProvider` uses HTTP connection pooling via `requests.Session()` for faster sequential requests:
455
+
456
+ ```python
457
+ from ragit.providers import OllamaProvider
458
+
459
+ provider = OllamaProvider()
460
+
461
+ # All requests reuse the same connection pool
462
+ for text in texts:
463
+ provider.embed(text, model="mxbai-embed-large")
464
+
465
+ # Explicitly close when done (optional, auto-closes on garbage collection)
466
+ provider.close()
467
+ ```
468
+
469
+ ### Async Parallel Embedding
470
+
471
+ For large batches, use `embed_batch_async()` with trio for 5-10x faster embedding:
472
+
473
+ ```python
474
+ import trio
475
+ from ragit.providers import OllamaProvider
476
+
477
+ provider = OllamaProvider()
478
+
479
+ async def embed_documents():
480
+ texts = ["doc1...", "doc2...", "doc3...", ...] # hundreds of texts
481
+ embeddings = await provider.embed_batch_async(
482
+ texts,
483
+ model="mxbai-embed-large",
484
+ max_concurrent=10 # Adjust based on server capacity
485
+ )
486
+ return embeddings
487
+
488
+ # Run with trio
489
+ results = trio.run(embed_documents)
490
+ ```
491
+
492
+ ### Embedding Cache
493
+
494
+ Repeated embedding calls are cached automatically (2048 entries LRU):
495
+
496
+ ```python
497
+ from ragit.providers import OllamaProvider
498
+
499
+ provider = OllamaProvider(use_cache=True) # Default
500
+
501
+ # First call hits the API
502
+ provider.embed("Hello world", model="mxbai-embed-large")
503
+
504
+ # Second call returns cached result instantly
505
+ provider.embed("Hello world", model="mxbai-embed-large")
506
+
507
+ # View cache statistics
508
+ print(OllamaProvider.embedding_cache_info())
509
+ # {'hits': 1, 'misses': 1, 'maxsize': 2048, 'currsize': 1}
510
+
511
+ # Clear cache if needed
512
+ OllamaProvider.clear_embedding_cache()
513
+ ```
514
+
515
+ ### Pre-normalized Embeddings
516
+
517
+ Vector similarity uses pre-normalized embeddings, making cosine similarity a simple dot product (O(1) per comparison).
518
+
446
519
  ## API Reference
447
520
 
448
521
  ### Document Loading
@@ -2,17 +2,17 @@ ragit/__init__.py,sha256=PjQogIWMlydZFWVECqhmxw-X9i7lEXdUTe2XlT6qYUQ,2213
2
2
  ragit/assistant.py,sha256=lXjZRUr_WsYLP3XLOktabgfPVyKOZPdREzyL7cSRufk,11251
3
3
  ragit/config.py,sha256=uKLchJQHjH8MImZ2OJahDjSzyasFqgrFb9Z4aHxJ7og,1495
4
4
  ragit/loaders.py,sha256=keusuPzXPBiLDVj4hKfPCcge-rm-cnzNRk50fGXvTJs,5571
5
- ragit/version.py,sha256=5FNjmLKNB4z5E3tbMqwASQafUUFSWnaxBRb0EsQPVK8,97
5
+ ragit/version.py,sha256=8-YGrxlAluU3va125prug6u_bleRtoTX3c4m7WfDYNM,97
6
6
  ragit/core/__init__.py,sha256=j53PFfoSMXwSbK1rRHpMbo8mX2i4R1LJ5kvTxBd7-0w,100
7
7
  ragit/core/experiment/__init__.py,sha256=4vAPOOYlY5Dcr2gOolyhBSPGIUxZKwEkgQffxS9BodA,452
8
8
  ragit/core/experiment/experiment.py,sha256=Qh1NJkY9LbKaidRfiI8GOwBZqopjK-MSVBuD_JEgO-k,16582
9
9
  ragit/core/experiment/results.py,sha256=KHpN3YSLJ83_JUfIMccRPS-q7LEt0S9p8ehDRawk_4k,3487
10
10
  ragit/providers/__init__.py,sha256=iliJt74Lt3mFUlKGfSFW-D0cMonUygY6sRZ6lLjeU7M,435
11
11
  ragit/providers/base.py,sha256=MJ8mVeXuGWhkX2XGTbkWIY3cVoTOPr4h5XBXw8rAX2Q,3434
12
- ragit/providers/ollama.py,sha256=epmXiJ06jCWFERWFMdCNr8FP6WubxdHXj8dUn71WJuQ,9502
12
+ ragit/providers/ollama.py,sha256=fFUziRmpu_L6rnlxi53PnzQ5b34-fEbaU4u1irK1nG0,16055
13
13
  ragit/utils/__init__.py,sha256=-UsE5oJSnmEnBDswl-ph0A09Iu8yKNbPhd1-_7Lcb8Y,3051
14
- ragit-0.7.2.dist-info/licenses/LICENSE,sha256=tAkwu8-AdEyGxGoSvJ2gVmQdcicWw3j1ZZueVV74M-E,11357
15
- ragit-0.7.2.dist-info/METADATA,sha256=TZ_VuZbe4GKkb8Jo8tC_NgT3hcu7e-uy7S9w6kpURlE,13662
16
- ragit-0.7.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
17
- ragit-0.7.2.dist-info/top_level.txt,sha256=pkPbG7yrw61wt9_y_xcLE2vq2a55fzockASD0yq0g4s,6
18
- ragit-0.7.2.dist-info/RECORD,,
14
+ ragit-0.7.4.dist-info/licenses/LICENSE,sha256=tAkwu8-AdEyGxGoSvJ2gVmQdcicWw3j1ZZueVV74M-E,11357
15
+ ragit-0.7.4.dist-info/METADATA,sha256=5nEjliSCc-F7X6IgG7OdgkadNjANPw4env04EEBI5J8,15528
16
+ ragit-0.7.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
17
+ ragit-0.7.4.dist-info/top_level.txt,sha256=pkPbG7yrw61wt9_y_xcLE2vq2a55fzockASD0yq0g4s,6
18
+ ragit-0.7.4.dist-info/RECORD,,
File without changes