ragit 0.8__py3-none-any.whl → 0.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,237 @@
+ #
+ # Copyright RODMENA LIMITED 2025
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ """
+ Function-based provider adapter for pluggable embedding and LLM functions.
+
+ This module provides a simple adapter that wraps user-provided functions
+ into the provider interface, enabling easy integration with custom
+ embedding and LLM implementations.
+ """
+
+ import inspect
+ from collections.abc import Callable
+
+ from ragit.providers.base import (
+     BaseEmbeddingProvider,
+     BaseLLMProvider,
+     EmbeddingResponse,
+     LLMResponse,
+ )
+
+
+ class FunctionProvider(BaseLLMProvider, BaseEmbeddingProvider):
+     """
+     Adapter that wraps user-provided embedding and generation functions.
+
+     This provider allows users to bring their own embedding and/or LLM functions
+     without implementing the full provider interface.
+
+     Parameters
+     ----------
+     embed_fn : Callable[[str], list[float]], optional
+         Function that takes text and returns an embedding vector.
+         Example: `lambda text: openai.embeddings.create(input=text).data[0].embedding`
+     generate_fn : Callable, optional
+         Function for text generation. Supports two signatures:
+         - (prompt: str) -> str
+         - (prompt: str, system_prompt: str) -> str
+     embedding_dimensions : int, optional
+         Embedding dimensions. Auto-detected on first call if not provided.
+
+     Examples
+     --------
+     >>> # Simple embedding function
+     >>> def my_embed(text: str) -> list[float]:
+     ...     return openai.embeddings.create(input=text).data[0].embedding
+     >>>
+     >>> # Use with RAGAssistant (retrieval-only)
+     >>> assistant = RAGAssistant(docs, embed_fn=my_embed)
+     >>> results = assistant.retrieve("query")
+     >>>
+     >>> # With LLM for full RAG
+     >>> def my_llm(prompt: str, system_prompt: str | None = None) -> str:
+     ...     return openai.chat.completions.create(
+     ...         messages=[{"role": "user", "content": prompt}]
+     ...     ).choices[0].message.content
+     >>>
+     >>> assistant = RAGAssistant(docs, embed_fn=my_embed, generate_fn=my_llm)
+     >>> answer = assistant.ask("What is X?")
+     """
+
+     def __init__(
+         self,
+         embed_fn: Callable[[str], list[float]] | None = None,
+         generate_fn: Callable[..., str] | None = None,
+         embedding_dimensions: int | None = None,
+     ) -> None:
+         self._embed_fn = embed_fn
+         self._generate_fn = generate_fn
+         self._embedding_dimensions = embedding_dimensions
+         self._generate_fn_signature: int | None = None  # Number of args (1 or 2)
+
+         # Detect generate_fn signature if provided
+         if generate_fn is not None:
+             self._detect_generate_signature()
+
+     def _detect_generate_signature(self) -> None:
+         """Detect whether generate_fn accepts 1 or 2 arguments."""
+         if self._generate_fn is None:
+             return
+
+         sig = inspect.signature(self._generate_fn)
+         params = [
+             p
+             for p in sig.parameters.values()
+             if p.default is inspect.Parameter.empty and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
+         ]
+         # Count required parameters
+         required_count = len(params)
+
+         if required_count == 1:
+             self._generate_fn_signature = 1
+         else:
+             # Assume 2 args if more than 1 required or if has optional args
+             self._generate_fn_signature = 2
+
+     @property
+     def provider_name(self) -> str:
+         return "function"
+
+     @property
+     def dimensions(self) -> int:
+         if self._embedding_dimensions is None:
+             raise ValueError("Embedding dimensions not yet determined. Call embed() first or provide dimensions.")
+         return self._embedding_dimensions
+
+     @property
+     def has_embedding(self) -> bool:
+         """Check if embedding function is configured."""
+         return self._embed_fn is not None
+
+     @property
+     def has_llm(self) -> bool:
+         """Check if LLM generation function is configured."""
+         return self._generate_fn is not None
+
+     def is_available(self) -> bool:
+         """Check if the provider has at least one function configured."""
+         return self._embed_fn is not None or self._generate_fn is not None
+
+     def embed(self, text: str, model: str = "") -> EmbeddingResponse:
+         """
+         Generate embedding using the provided function.
+
+         Parameters
+         ----------
+         text : str
+             Text to embed.
+         model : str
+             Model identifier (ignored, kept for interface compatibility).
+
+         Returns
+         -------
+         EmbeddingResponse
+             The embedding response.
+
+         Raises
+         ------
+         ValueError
+             If no embedding function was provided.
+         """
+         if self._embed_fn is None:
+             raise ValueError("No embedding function configured. Provide embed_fn to use embeddings.")
+
+         raw_embedding = self._embed_fn(text)
+
+         # Convert to tuple for immutability
+         embedding_tuple: tuple[float, ...] = tuple(raw_embedding)
+
+         # Auto-detect dimensions on first call
+         if self._embedding_dimensions is None:
+             self._embedding_dimensions = len(embedding_tuple)
+
+         return EmbeddingResponse(
+             embedding=embedding_tuple,
+             model=model or "function",
+             provider=self.provider_name,
+             dimensions=len(embedding_tuple),
+         )
+
+     def embed_batch(self, texts: list[str], model: str = "") -> list[EmbeddingResponse]:
+         """
+         Generate embeddings for multiple texts.
+
+         Iterates over embed_fn for each text. For providers with native batch
+         support, users should implement their own BatchEmbeddingProvider.
+
+         Parameters
+         ----------
+         texts : list[str]
+             Texts to embed.
+         model : str
+             Model identifier (ignored).
+
+         Returns
+         -------
+         list[EmbeddingResponse]
+             List of embedding responses.
+         """
+         return [self.embed(text, model) for text in texts]
+
+     def generate(
+         self,
+         prompt: str,
+         model: str = "",
+         system_prompt: str | None = None,
+         temperature: float = 0.7,
+         max_tokens: int | None = None,
+     ) -> LLMResponse:
+         """
+         Generate text using the provided function.
+
+         Parameters
+         ----------
+         prompt : str
+             The user prompt.
+         model : str
+             Model identifier (ignored, kept for interface compatibility).
+         system_prompt : str, optional
+             System prompt for context.
+         temperature : float
+             Sampling temperature (ignored if function doesn't support it).
+         max_tokens : int, optional
+             Maximum tokens (ignored if function doesn't support it).
+
+         Returns
+         -------
+         LLMResponse
+             The generated response.
+
+         Raises
+         ------
+         NotImplementedError
+             If no generation function was provided.
+         """
+         if self._generate_fn is None:
+             raise NotImplementedError(
+                 "No LLM configured. Provide generate_fn or a provider with LLM support "
+                 "to use ask(), generate(), or generate_code() methods."
+             )
+
+         # Call with appropriate signature
+         if self._generate_fn_signature == 1:
+             # Single argument - prepend system prompt to prompt if provided
+             full_prompt = f"{system_prompt}\n\n{prompt}" if system_prompt else prompt
+             text = self._generate_fn(full_prompt)
+         else:
+             # Two arguments - pass separately
+             text = self._generate_fn(prompt, system_prompt)
+
+         return LLMResponse(
+             text=text,
+             model=model or "function",
+             provider=self.provider_name,
+             usage=None,
+         )
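
Usage note: a minimal, self-contained sketch of how this adapter behaves. The module path, the toy embedding function, and both generation callables below are hypothetical stand-ins (not part of the package), and the response objects are assumed to expose their constructor fields as attributes.

    from ragit.providers.function_provider import FunctionProvider  # module path assumed

    def toy_embed(text: str) -> list[float]:
        # Hypothetical stand-in: a fixed-size character-count vector, not a real embedding model.
        vec = [0.0] * 64
        for ch in text.lower():
            vec[ord(ch) % 64] += 1.0
        return vec

    def one_arg_llm(prompt: str) -> str:
        # One required parameter -> detected as the single-argument form;
        # generate() folds any system_prompt into the prompt string.
        return f"echo: {prompt}"

    def two_arg_llm(prompt: str, system_prompt: str) -> str:
        # Two required parameters -> detected as the two-argument form;
        # generate() passes system_prompt separately (it may be None).
        return f"[{system_prompt}] {prompt}" if system_prompt else prompt

    provider = FunctionProvider(embed_fn=toy_embed, generate_fn=two_arg_llm)
    print(provider.embed("hello").dimensions)  # 64, auto-detected on the first call
    print(provider.generate("What is X?", system_prompt="Be brief.").text)

    # With a single-argument callable, the system prompt is prepended to the prompt text.
    provider_single = FunctionProvider(generate_fn=one_arg_llm)
    print(provider_single.generate("What is X?", system_prompt="Be brief.").text)
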
@@ -0,0 +1,446 @@
+ #
+ # Copyright RODMENA LIMITED 2025
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ """
+ Ollama provider for LLM and Embedding operations.
+
+ This provider connects to a local or remote Ollama server.
+ Configuration is loaded from environment variables.
+
+ Performance optimizations:
+ - Connection pooling via requests.Session()
+ - Async batch embedding via httpx (single request per batch)
+ - LRU cache for repeated embedding queries
+ """
+
+ from functools import lru_cache
+ from typing import Any
+
+ import httpx
+ import requests
+
+ from ragit.config import config
+ from ragit.providers.base import (
+     BaseEmbeddingProvider,
+     BaseLLMProvider,
+     EmbeddingResponse,
+     LLMResponse,
+ )
+
+
+ # Module-level cache for embeddings (shared across instances)
+ @lru_cache(maxsize=2048)
+ def _cached_embedding(text: str, model: str, embedding_url: str, timeout: int) -> tuple[float, ...]:
+     """Cache embedding results to avoid redundant API calls."""
+     # Truncate oversized inputs
+     if len(text) > OllamaProvider.MAX_EMBED_CHARS:
+         text = text[: OllamaProvider.MAX_EMBED_CHARS]
+
+     response = requests.post(
+         f"{embedding_url}/api/embed",
+         headers={"Content-Type": "application/json"},
+         json={"model": model, "input": text},
+         timeout=timeout,
+     )
+     response.raise_for_status()
+     data = response.json()
+     embeddings = data.get("embeddings", [])
+     if not embeddings or not embeddings[0]:
+         raise ValueError("Empty embedding returned from Ollama")
+     return tuple(embeddings[0])
+
+
+ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
+     """
+     Ollama provider for both LLM and Embedding operations.
+
+     Performance features:
+     - Connection pooling via requests.Session() for faster sequential requests
+     - Native batch embedding via /api/embed endpoint (single API call)
+     - LRU cache for repeated embedding queries (2048 entries)
+
+     Parameters
+     ----------
+     base_url : str, optional
+         Ollama server URL (default: from OLLAMA_BASE_URL env var)
+     embedding_url : str, optional
+         Embedding endpoint URL (default: from OLLAMA_EMBEDDING_URL env var)
+     api_key : str, optional
+         API key for authentication (default: from OLLAMA_API_KEY env var)
+     timeout : int, optional
+         Request timeout in seconds (default: from OLLAMA_TIMEOUT env var)
+     use_cache : bool, optional
+         Enable embedding cache (default: True)
+
+     Examples
+     --------
+     >>> provider = OllamaProvider()
+     >>> response = provider.generate("What is RAG?", model="llama3")
+     >>> print(response.text)
+
+     >>> # Batch embedding (single API call)
+     >>> embeddings = provider.embed_batch(texts, "mxbai-embed-large")
+     """
+
+     # Known embedding model dimensions
+     EMBEDDING_DIMENSIONS: dict[str, int] = {
+         "nomic-embed-text": 768,
+         "nomic-embed-text:latest": 768,
+         "mxbai-embed-large": 1024,
+         "all-minilm": 384,
+         "snowflake-arctic-embed": 1024,
+         "qwen3-embedding": 4096,
+         "qwen3-embedding:0.6b": 1024,
+         "qwen3-embedding:4b": 2560,
+         "qwen3-embedding:8b": 4096,
+     }
+
+     # Max characters per embedding request (safe limit for 512 token models)
+     MAX_EMBED_CHARS = 2000
+
+     def __init__(
+         self,
+         base_url: str | None = None,
+         embedding_url: str | None = None,
+         api_key: str | None = None,
+         timeout: int | None = None,
+         use_cache: bool = True,
+     ) -> None:
+         self.base_url = (base_url or config.OLLAMA_BASE_URL).rstrip("/")
+         self.embedding_url = (embedding_url or config.OLLAMA_EMBEDDING_URL).rstrip("/")
+         self.api_key = api_key or config.OLLAMA_API_KEY
+         self.timeout = timeout or config.OLLAMA_TIMEOUT
+         self.use_cache = use_cache
+         self._current_embed_model: str | None = None
+         self._current_dimensions: int = 768  # default
+
+         # Connection pooling via session
+         self._session: requests.Session | None = None
+
+     @property
+     def session(self) -> requests.Session:
+         """Lazy-initialized session for connection pooling."""
+         if self._session is None:
+             self._session = requests.Session()
+             self._session.headers.update({"Content-Type": "application/json"})
+             if self.api_key:
+                 self._session.headers.update({"Authorization": f"Bearer {self.api_key}"})
+         return self._session
+
+     def close(self) -> None:
+         """Close the session and release resources."""
+         if self._session is not None:
+             self._session.close()
+             self._session = None
+
+     def __del__(self) -> None:
+         """Cleanup on garbage collection."""
+         self.close()
+
+     def _get_headers(self, include_auth: bool = True) -> dict[str, str]:
+         """Get request headers including authentication if API key is set."""
+         headers = {"Content-Type": "application/json"}
+         if include_auth and self.api_key:
+             headers["Authorization"] = f"Bearer {self.api_key}"
+         return headers
+
+     @property
+     def provider_name(self) -> str:
+         return "ollama"
+
+     @property
+     def dimensions(self) -> int:
+         return self._current_dimensions
+
+     def is_available(self) -> bool:
+         """Check if Ollama server is reachable."""
+         try:
+             response = self.session.get(
+                 f"{self.base_url}/api/tags",
+                 timeout=5,
+             )
+             return bool(response.status_code == 200)
+         except requests.RequestException:
+             return False
+
+     def list_models(self) -> list[dict[str, Any]]:
+         """List available models on the Ollama server."""
+         try:
+             response = self.session.get(
+                 f"{self.base_url}/api/tags",
+                 timeout=10,
+             )
+             response.raise_for_status()
+             data = response.json()
+             return list(data.get("models", []))
+         except requests.RequestException as e:
+             raise ConnectionError(f"Failed to list Ollama models: {e}") from e
+
+     def generate(
+         self,
+         prompt: str,
+         model: str,
+         system_prompt: str | None = None,
+         temperature: float = 0.7,
+         max_tokens: int | None = None,
+     ) -> LLMResponse:
+         """Generate text using Ollama."""
+         options: dict[str, float | int] = {"temperature": temperature}
+         if max_tokens:
+             options["num_predict"] = max_tokens
+
+         payload: dict[str, str | bool | dict[str, float | int]] = {
+             "model": model,
+             "prompt": prompt,
+             "stream": False,
+             "options": options,
+         }
+
+         if system_prompt:
+             payload["system"] = system_prompt
+
+         try:
+             response = self.session.post(
+                 f"{self.base_url}/api/generate",
+                 json=payload,
+                 timeout=self.timeout,
+             )
+             response.raise_for_status()
+             data = response.json()
+
+             return LLMResponse(
+                 text=data.get("response", ""),
+                 model=model,
+                 provider=self.provider_name,
+                 usage={
+                     "prompt_tokens": data.get("prompt_eval_count"),
+                     "completion_tokens": data.get("eval_count"),
+                     "total_duration": data.get("total_duration"),
+                 },
+             )
+         except requests.RequestException as e:
+             raise ConnectionError(f"Ollama generate failed: {e}") from e
+
+     def embed(self, text: str, model: str) -> EmbeddingResponse:
+         """Generate embedding using Ollama with optional caching."""
+         self._current_embed_model = model
+         self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
+
+         try:
+             if self.use_cache:
+                 # Use cached version
+                 embedding = _cached_embedding(text, model, self.embedding_url, self.timeout)
+             else:
+                 # Direct call without cache
+                 truncated = text[: self.MAX_EMBED_CHARS] if len(text) > self.MAX_EMBED_CHARS else text
+                 response = self.session.post(
+                     f"{self.embedding_url}/api/embed",
+                     json={"model": model, "input": truncated},
+                     timeout=self.timeout,
+                 )
+                 response.raise_for_status()
+                 data = response.json()
+                 embeddings = data.get("embeddings", [])
+                 if not embeddings or not embeddings[0]:
+                     raise ValueError("Empty embedding returned from Ollama")
+                 embedding = tuple(embeddings[0])
+
+             # Update dimensions from actual response
+             self._current_dimensions = len(embedding)
+
+             return EmbeddingResponse(
+                 embedding=embedding,
+                 model=model,
+                 provider=self.provider_name,
+                 dimensions=len(embedding),
+             )
+         except requests.RequestException as e:
+             raise ConnectionError(f"Ollama embed failed: {e}") from e
+
+     def embed_batch(self, texts: list[str], model: str) -> list[EmbeddingResponse]:
+         """Generate embeddings for multiple texts in a single API call.
+
+         The /api/embed endpoint supports batch inputs natively.
+         """
+         self._current_embed_model = model
+         self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
+
+         # Truncate oversized inputs
+         truncated_texts = [text[: self.MAX_EMBED_CHARS] if len(text) > self.MAX_EMBED_CHARS else text for text in texts]
+
+         try:
+             response = self.session.post(
+                 f"{self.embedding_url}/api/embed",
+                 json={"model": model, "input": truncated_texts},
+                 timeout=self.timeout,
+             )
+             response.raise_for_status()
+             data = response.json()
+             embeddings_list = data.get("embeddings", [])
+
+             if not embeddings_list:
+                 raise ValueError("Empty embeddings returned from Ollama")
+
+             results = []
+             for embedding_data in embeddings_list:
+                 embedding = tuple(embedding_data) if embedding_data else ()
+                 if embedding:
+                     self._current_dimensions = len(embedding)
+
+                 results.append(
+                     EmbeddingResponse(
+                         embedding=embedding,
+                         model=model,
+                         provider=self.provider_name,
+                         dimensions=len(embedding),
+                     )
+                 )
+             return results
+         except requests.RequestException as e:
+             raise ConnectionError(f"Ollama batch embed failed: {e}") from e
+
+     async def embed_batch_async(
+         self,
+         texts: list[str],
+         model: str,
+         max_concurrent: int = 10,  # kept for API compatibility, no longer used
+     ) -> list[EmbeddingResponse]:
+         """Generate embeddings for multiple texts asynchronously.
+
+         The /api/embed endpoint supports batch inputs natively, so this
+         makes a single async HTTP request for all texts.
+
+         Parameters
+         ----------
+         texts : list[str]
+             Texts to embed.
+         model : str
+             Embedding model name.
+         max_concurrent : int
+             Deprecated, kept for API compatibility. No longer used since
+             the API now supports native batching.
+
+         Returns
+         -------
+         list[EmbeddingResponse]
+             Embeddings in the same order as input texts.
+
+         Examples
+         --------
+         >>> import trio
+         >>> embeddings = trio.run(provider.embed_batch_async, texts, "mxbai-embed-large")
+         """
+         self._current_embed_model = model
+         self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
+
+         # Truncate oversized inputs
+         truncated_texts = [text[: self.MAX_EMBED_CHARS] if len(text) > self.MAX_EMBED_CHARS else text for text in texts]
+
+         try:
+             async with httpx.AsyncClient() as client:
+                 response = await client.post(
+                     f"{self.embedding_url}/api/embed",
+                     json={"model": model, "input": truncated_texts},
+                     timeout=self.timeout,
+                 )
+                 response.raise_for_status()
+                 data = response.json()
+
+             embeddings_list = data.get("embeddings", [])
+             if not embeddings_list:
+                 raise ValueError("Empty embeddings returned from Ollama")
+
+             results = []
+             for embedding_data in embeddings_list:
+                 embedding = tuple(embedding_data) if embedding_data else ()
+                 if embedding:
+                     self._current_dimensions = len(embedding)
+
+                 results.append(
+                     EmbeddingResponse(
+                         embedding=embedding,
+                         model=model,
+                         provider=self.provider_name,
+                         dimensions=len(embedding),
+                     )
+                 )
+             return results
+         except httpx.HTTPError as e:
+             raise ConnectionError(f"Ollama async batch embed failed: {e}") from e
+
+     def chat(
+         self,
+         messages: list[dict[str, str]],
+         model: str,
+         temperature: float = 0.7,
+         max_tokens: int | None = None,
+     ) -> LLMResponse:
+         """
+         Chat completion using Ollama.
+
+         Parameters
+         ----------
+         messages : list[dict]
+             List of messages with 'role' and 'content' keys.
+         model : str
+             Model identifier.
+         temperature : float
+             Sampling temperature.
+         max_tokens : int, optional
+             Maximum tokens to generate.
+
+         Returns
+         -------
+         LLMResponse
+             The generated response.
+         """
+         options: dict[str, float | int] = {"temperature": temperature}
+         if max_tokens:
+             options["num_predict"] = max_tokens
+
+         payload: dict[str, str | bool | list[dict[str, str]] | dict[str, float | int]] = {
+             "model": model,
+             "messages": messages,
+             "stream": False,
+             "options": options,
+         }
+
+         try:
+             response = self.session.post(
+                 f"{self.base_url}/api/chat",
+                 json=payload,
+                 timeout=self.timeout,
+             )
+             response.raise_for_status()
+             data = response.json()
+
+             return LLMResponse(
+                 text=data.get("message", {}).get("content", ""),
+                 model=model,
+                 provider=self.provider_name,
+                 usage={
+                     "prompt_tokens": data.get("prompt_eval_count"),
+                     "completion_tokens": data.get("eval_count"),
+                 },
+             )
+         except requests.RequestException as e:
+             raise ConnectionError(f"Ollama chat failed: {e}") from e
+
+     @staticmethod
+     def clear_embedding_cache() -> None:
+         """Clear the embedding cache."""
+         _cached_embedding.cache_clear()
+
+     @staticmethod
+     def embedding_cache_info() -> dict[str, int]:
+         """Get embedding cache statistics."""
+         info = _cached_embedding.cache_info()
+         return {
+             "hits": info.hits,
+             "misses": info.misses,
+             "maxsize": info.maxsize or 0,
+             "currsize": info.currsize,
+         }
+
+
+ # Export the EMBEDDING_DIMENSIONS for external use
+ EMBEDDING_DIMENSIONS = OllamaProvider.EMBEDDING_DIMENSIONS
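
Usage note: a minimal sketch of driving this provider against a local Ollama server. The module path, server URL, and model names below are assumptions, and the response objects are assumed to expose their constructor fields as attributes.

    from ragit.providers.ollama import OllamaProvider  # module path assumed

    # Standard Ollama port assumed; defaults otherwise come from environment config.
    provider = OllamaProvider(
        base_url="http://localhost:11434",
        embedding_url="http://localhost:11434",
        use_cache=True,
    )

    if provider.is_available():
        # embed_batch() sends one /api/embed request for all inputs and preserves order.
        chunks = ["RAG retrieves context before generation.", "Ollama serves local models."]
        embeddings = provider.embed_batch(chunks, model="nomic-embed-text")
        print(len(embeddings), embeddings[0].dimensions)

        # generate() posts to /api/generate with stream=False.
        answer = provider.generate("What is RAG?", model="llama3", temperature=0.2)
        print(answer.text)

        # Repeated embed() calls for identical text are served from the module-level LRU cache.
        provider.embed(chunks[0], model="nomic-embed-text")
        provider.embed(chunks[0], model="nomic-embed-text")
        print(OllamaProvider.embedding_cache_info())

    provider.close()
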