ragit 0.8-py3-none-any.whl → 0.8.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ragit/__init__.py +116 -2
- ragit/assistant.py +442 -0
- ragit/config.py +60 -0
- ragit/core/__init__.py +5 -0
- ragit/core/experiment/__init__.py +22 -0
- ragit/core/experiment/experiment.py +572 -0
- ragit/core/experiment/results.py +131 -0
- ragit/loaders.py +219 -0
- ragit/providers/__init__.py +47 -0
- ragit/providers/base.py +147 -0
- ragit/providers/function_adapter.py +237 -0
- ragit/providers/ollama.py +446 -0
- ragit/providers/sentence_transformers.py +225 -0
- ragit/utils/__init__.py +105 -0
- ragit/version.py +5 -0
- ragit-0.8.1.dist-info/METADATA +166 -0
- ragit-0.8.1.dist-info/RECORD +20 -0
- {ragit-0.8.dist-info → ragit-0.8.1.dist-info}/WHEEL +1 -1
- ragit-0.8.1.dist-info/licenses/LICENSE +201 -0
- {ragit-0.8.dist-info → ragit-0.8.1.dist-info}/top_level.txt +0 -0
- ragit/main.py +0 -354
- ragit-0.8.dist-info/LICENSE +0 -21
- ragit-0.8.dist-info/METADATA +0 -176
- ragit-0.8.dist-info/RECORD +0 -7
ragit/providers/ollama.py
@@ -0,0 +1,446 @@
+#
+# Copyright RODMENA LIMITED 2025
+# SPDX-License-Identifier: Apache-2.0
+#
+"""
+Ollama provider for LLM and Embedding operations.
+
+This provider connects to a local or remote Ollama server.
+Configuration is loaded from environment variables.
+
+Performance optimizations:
+- Connection pooling via requests.Session()
+- Async parallel embedding via trio + httpx
+- LRU cache for repeated embedding queries
+"""
+
+from functools import lru_cache
+from typing import Any
+
+import httpx
+import requests
+
+from ragit.config import config
+from ragit.providers.base import (
+    BaseEmbeddingProvider,
+    BaseLLMProvider,
+    EmbeddingResponse,
+    LLMResponse,
+)
+
+
+# Module-level cache for embeddings (shared across instances)
+@lru_cache(maxsize=2048)
+def _cached_embedding(text: str, model: str, embedding_url: str, timeout: int) -> tuple[float, ...]:
+    """Cache embedding results to avoid redundant API calls."""
+    # Truncate oversized inputs
+    if len(text) > OllamaProvider.MAX_EMBED_CHARS:
+        text = text[: OllamaProvider.MAX_EMBED_CHARS]
+
+    response = requests.post(
+        f"{embedding_url}/api/embed",
+        headers={"Content-Type": "application/json"},
+        json={"model": model, "input": text},
+        timeout=timeout,
+    )
+    response.raise_for_status()
+    data = response.json()
+    embeddings = data.get("embeddings", [])
+    if not embeddings or not embeddings[0]:
+        raise ValueError("Empty embedding returned from Ollama")
+    return tuple(embeddings[0])
+
+
+class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
+    """
+    Ollama provider for both LLM and Embedding operations.
+
+    Performance features:
+    - Connection pooling via requests.Session() for faster sequential requests
+    - Native batch embedding via /api/embed endpoint (single API call)
+    - LRU cache for repeated embedding queries (2048 entries)
+
+    Parameters
+    ----------
+    base_url : str, optional
+        Ollama server URL (default: from OLLAMA_BASE_URL env var)
+    api_key : str, optional
+        API key for authentication (default: from OLLAMA_API_KEY env var)
+    timeout : int, optional
+        Request timeout in seconds (default: from OLLAMA_TIMEOUT env var)
+    use_cache : bool, optional
+        Enable embedding cache (default: True)
+
+    Examples
+    --------
+    >>> provider = OllamaProvider()
+    >>> response = provider.generate("What is RAG?", model="llama3")
+    >>> print(response.text)
+
+    >>> # Batch embedding (single API call)
+    >>> embeddings = provider.embed_batch(texts, "mxbai-embed-large")
+    """
+
+    # Known embedding model dimensions
+    EMBEDDING_DIMENSIONS: dict[str, int] = {
+        "nomic-embed-text": 768,
+        "nomic-embed-text:latest": 768,
+        "mxbai-embed-large": 1024,
+        "all-minilm": 384,
+        "snowflake-arctic-embed": 1024,
+        "qwen3-embedding": 4096,
+        "qwen3-embedding:0.6b": 1024,
+        "qwen3-embedding:4b": 2560,
+        "qwen3-embedding:8b": 4096,
+    }
+
+    # Max characters per embedding request (safe limit for 512 token models)
+    MAX_EMBED_CHARS = 2000
+
+    def __init__(
+        self,
+        base_url: str | None = None,
+        embedding_url: str | None = None,
+        api_key: str | None = None,
+        timeout: int | None = None,
+        use_cache: bool = True,
+    ) -> None:
+        self.base_url = (base_url or config.OLLAMA_BASE_URL).rstrip("/")
+        self.embedding_url = (embedding_url or config.OLLAMA_EMBEDDING_URL).rstrip("/")
+        self.api_key = api_key or config.OLLAMA_API_KEY
+        self.timeout = timeout or config.OLLAMA_TIMEOUT
+        self.use_cache = use_cache
+        self._current_embed_model: str | None = None
+        self._current_dimensions: int = 768  # default
+
+        # Connection pooling via session
+        self._session: requests.Session | None = None
+
+    @property
+    def session(self) -> requests.Session:
+        """Lazy-initialized session for connection pooling."""
+        if self._session is None:
+            self._session = requests.Session()
+            self._session.headers.update({"Content-Type": "application/json"})
+            if self.api_key:
+                self._session.headers.update({"Authorization": f"Bearer {self.api_key}"})
+        return self._session
+
+    def close(self) -> None:
+        """Close the session and release resources."""
+        if self._session is not None:
+            self._session.close()
+            self._session = None
+
+    def __del__(self) -> None:
+        """Cleanup on garbage collection."""
+        self.close()
+
+    def _get_headers(self, include_auth: bool = True) -> dict[str, str]:
+        """Get request headers including authentication if API key is set."""
+        headers = {"Content-Type": "application/json"}
+        if include_auth and self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+        return headers
+
+    @property
+    def provider_name(self) -> str:
+        return "ollama"
+
+    @property
+    def dimensions(self) -> int:
+        return self._current_dimensions
+
+    def is_available(self) -> bool:
+        """Check if Ollama server is reachable."""
+        try:
+            response = self.session.get(
+                f"{self.base_url}/api/tags",
+                timeout=5,
+            )
+            return bool(response.status_code == 200)
+        except requests.RequestException:
+            return False
+
+    def list_models(self) -> list[dict[str, Any]]:
+        """List available models on the Ollama server."""
+        try:
+            response = self.session.get(
+                f"{self.base_url}/api/tags",
+                timeout=10,
+            )
+            response.raise_for_status()
+            data = response.json()
+            return list(data.get("models", []))
+        except requests.RequestException as e:
+            raise ConnectionError(f"Failed to list Ollama models: {e}") from e
+
+    def generate(
+        self,
+        prompt: str,
+        model: str,
+        system_prompt: str | None = None,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+    ) -> LLMResponse:
+        """Generate text using Ollama."""
+        options: dict[str, float | int] = {"temperature": temperature}
+        if max_tokens:
+            options["num_predict"] = max_tokens
+
+        payload: dict[str, str | bool | dict[str, float | int]] = {
+            "model": model,
+            "prompt": prompt,
+            "stream": False,
+            "options": options,
+        }
+
+        if system_prompt:
+            payload["system"] = system_prompt
+
+        try:
+            response = self.session.post(
+                f"{self.base_url}/api/generate",
+                json=payload,
+                timeout=self.timeout,
+            )
+            response.raise_for_status()
+            data = response.json()
+
+            return LLMResponse(
+                text=data.get("response", ""),
+                model=model,
+                provider=self.provider_name,
+                usage={
+                    "prompt_tokens": data.get("prompt_eval_count"),
+                    "completion_tokens": data.get("eval_count"),
+                    "total_duration": data.get("total_duration"),
+                },
+            )
+        except requests.RequestException as e:
+            raise ConnectionError(f"Ollama generate failed: {e}") from e
+
+    def embed(self, text: str, model: str) -> EmbeddingResponse:
+        """Generate embedding using Ollama with optional caching."""
+        self._current_embed_model = model
+        self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
+
+        try:
+            if self.use_cache:
+                # Use cached version
+                embedding = _cached_embedding(text, model, self.embedding_url, self.timeout)
+            else:
+                # Direct call without cache
+                truncated = text[: self.MAX_EMBED_CHARS] if len(text) > self.MAX_EMBED_CHARS else text
+                response = self.session.post(
+                    f"{self.embedding_url}/api/embed",
+                    json={"model": model, "input": truncated},
+                    timeout=self.timeout,
+                )
+                response.raise_for_status()
+                data = response.json()
+                embeddings = data.get("embeddings", [])
+                if not embeddings or not embeddings[0]:
+                    raise ValueError("Empty embedding returned from Ollama")
+                embedding = tuple(embeddings[0])
+
+            # Update dimensions from actual response
+            self._current_dimensions = len(embedding)
+
+            return EmbeddingResponse(
+                embedding=embedding,
+                model=model,
+                provider=self.provider_name,
+                dimensions=len(embedding),
+            )
+        except requests.RequestException as e:
+            raise ConnectionError(f"Ollama embed failed: {e}") from e
+
+    def embed_batch(self, texts: list[str], model: str) -> list[EmbeddingResponse]:
+        """Generate embeddings for multiple texts in a single API call.
+
+        The /api/embed endpoint supports batch inputs natively.
+        """
+        self._current_embed_model = model
+        self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
+
+        # Truncate oversized inputs
+        truncated_texts = [text[: self.MAX_EMBED_CHARS] if len(text) > self.MAX_EMBED_CHARS else text for text in texts]
+
+        try:
+            response = self.session.post(
+                f"{self.embedding_url}/api/embed",
+                json={"model": model, "input": truncated_texts},
+                timeout=self.timeout,
+            )
+            response.raise_for_status()
+            data = response.json()
+            embeddings_list = data.get("embeddings", [])
+
+            if not embeddings_list:
+                raise ValueError("Empty embeddings returned from Ollama")
+
+            results = []
+            for embedding_data in embeddings_list:
+                embedding = tuple(embedding_data) if embedding_data else ()
+                if embedding:
+                    self._current_dimensions = len(embedding)
+
+                results.append(
+                    EmbeddingResponse(
+                        embedding=embedding,
+                        model=model,
+                        provider=self.provider_name,
+                        dimensions=len(embedding),
+                    )
+                )
+            return results
+        except requests.RequestException as e:
+            raise ConnectionError(f"Ollama batch embed failed: {e}") from e
+
+    async def embed_batch_async(
+        self,
+        texts: list[str],
+        model: str,
+        max_concurrent: int = 10,  # kept for API compatibility, no longer used
+    ) -> list[EmbeddingResponse]:
+        """Generate embeddings for multiple texts asynchronously.
+
+        The /api/embed endpoint supports batch inputs natively, so this
+        makes a single async HTTP request for all texts.
+
+        Parameters
+        ----------
+        texts : list[str]
+            Texts to embed.
+        model : str
+            Embedding model name.
+        max_concurrent : int
+            Deprecated, kept for API compatibility. No longer used since
+            the API now supports native batching.
+
+        Returns
+        -------
+        list[EmbeddingResponse]
+            Embeddings in the same order as input texts.
+
+        Examples
+        --------
+        >>> import trio
+        >>> embeddings = trio.run(provider.embed_batch_async, texts, "mxbai-embed-large")
+        """
+        self._current_embed_model = model
+        self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
+
+        # Truncate oversized inputs
+        truncated_texts = [text[: self.MAX_EMBED_CHARS] if len(text) > self.MAX_EMBED_CHARS else text for text in texts]
+
+        try:
+            async with httpx.AsyncClient() as client:
+                response = await client.post(
+                    f"{self.embedding_url}/api/embed",
+                    json={"model": model, "input": truncated_texts},
+                    timeout=self.timeout,
+                )
+            response.raise_for_status()
+            data = response.json()
+
+            embeddings_list = data.get("embeddings", [])
+            if not embeddings_list:
+                raise ValueError("Empty embeddings returned from Ollama")
+
+            results = []
+            for embedding_data in embeddings_list:
+                embedding = tuple(embedding_data) if embedding_data else ()
+                if embedding:
+                    self._current_dimensions = len(embedding)
+
+                results.append(
+                    EmbeddingResponse(
+                        embedding=embedding,
+                        model=model,
+                        provider=self.provider_name,
+                        dimensions=len(embedding),
+                    )
+                )
+            return results
+        except httpx.HTTPError as e:
+            raise ConnectionError(f"Ollama async batch embed failed: {e}") from e
+
+    def chat(
+        self,
+        messages: list[dict[str, str]],
+        model: str,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+    ) -> LLMResponse:
+        """
+        Chat completion using Ollama.
+
+        Parameters
+        ----------
+        messages : list[dict]
+            List of messages with 'role' and 'content' keys.
+        model : str
+            Model identifier.
+        temperature : float
+            Sampling temperature.
+        max_tokens : int, optional
+            Maximum tokens to generate.
+
+        Returns
+        -------
+        LLMResponse
+            The generated response.
+        """
+        options: dict[str, float | int] = {"temperature": temperature}
+        if max_tokens:
+            options["num_predict"] = max_tokens
+
+        payload: dict[str, str | bool | list[dict[str, str]] | dict[str, float | int]] = {
+            "model": model,
+            "messages": messages,
+            "stream": False,
+            "options": options,
+        }
+
+        try:
+            response = self.session.post(
+                f"{self.base_url}/api/chat",
+                json=payload,
+                timeout=self.timeout,
+            )
+            response.raise_for_status()
+            data = response.json()
+
+            return LLMResponse(
+                text=data.get("message", {}).get("content", ""),
+                model=model,
+                provider=self.provider_name,
+                usage={
+                    "prompt_tokens": data.get("prompt_eval_count"),
+                    "completion_tokens": data.get("eval_count"),
+                },
+            )
+        except requests.RequestException as e:
+            raise ConnectionError(f"Ollama chat failed: {e}") from e
+
+    @staticmethod
+    def clear_embedding_cache() -> None:
+        """Clear the embedding cache."""
+        _cached_embedding.cache_clear()
+
+    @staticmethod
+    def embedding_cache_info() -> dict[str, int]:
+        """Get embedding cache statistics."""
+        info = _cached_embedding.cache_info()
+        return {
+            "hits": info.hits,
+            "misses": info.misses,
+            "maxsize": info.maxsize or 0,
+            "currsize": info.currsize,
+        }
+
+
+# Export the EMBEDDING_DIMENSIONS for external use
+EMBEDDING_DIMENSIONS = OllamaProvider.EMBEDDING_DIMENSIONS
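
Taken together, the new ollama.py covers generation, chat, and single/batch/async-batch embedding over one pooled session. A minimal usage sketch, not part of the diff, assuming a reachable Ollama server and that the "llama3" and "mxbai-embed-large" models are pulled locally:

    from ragit.providers.ollama import OllamaProvider

    provider = OllamaProvider()  # URLs, API key, and timeout come from ragit.config
    if provider.is_available():
        # One-shot generation (model name is an assumption)
        reply = provider.generate("What is RAG?", model="llama3")
        print(reply.text)

        # Batch embedding: a single /api/embed call for all chunks
        chunks = ["first chunk", "second chunk"]
        vectors = provider.embed_batch(chunks, "mxbai-embed-large")
        print(len(vectors), vectors[0].dimensions)

        # Statistics from the module-level LRU cache used by embed()
        print(OllamaProvider.embedding_cache_info())
    provider.close()
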
ragit/providers/sentence_transformers.py
@@ -0,0 +1,225 @@
+#
+# Copyright RODMENA LIMITED 2025
+# SPDX-License-Identifier: Apache-2.0
+#
+"""
+SentenceTransformers provider for offline embedding.
+
+This module provides embedding capabilities using the sentence-transformers
+library, enabling fully offline RAG pipelines without API dependencies.
+
+Requires: pip install ragit[transformers]
+"""
+
+from typing import TYPE_CHECKING
+
+from ragit.providers.base import (
+    BaseEmbeddingProvider,
+    EmbeddingResponse,
+)
+
+if TYPE_CHECKING:
+    from sentence_transformers import SentenceTransformer
+
+# Lazy import flag
+_sentence_transformers_available: bool | None = None
+_model_cache: dict[str, "SentenceTransformer"] = {}
+
+
+def _check_sentence_transformers() -> bool:
+    """Check if sentence-transformers is available."""
+    global _sentence_transformers_available
+    if _sentence_transformers_available is None:
+        try:
+            from sentence_transformers import SentenceTransformer  # noqa: F401
+
+            _sentence_transformers_available = True
+        except ImportError:
+            _sentence_transformers_available = False
+    return _sentence_transformers_available
+
+
+def _get_model(model_name: str, device: str | None = None) -> "SentenceTransformer":
+    """Get or create a cached SentenceTransformer model."""
+    cache_key = f"{model_name}:{device or 'auto'}"
+    if cache_key not in _model_cache:
+        from sentence_transformers import SentenceTransformer
+
+        _model_cache[cache_key] = SentenceTransformer(model_name, device=device)
+    return _model_cache[cache_key]
+
+
+class SentenceTransformersProvider(BaseEmbeddingProvider):
+    """
+    Embedding provider using sentence-transformers for offline operation.
+
+    This provider uses the sentence-transformers library to generate embeddings
+    locally without requiring any API calls. It's ideal for:
+    - Offline/air-gapped environments
+    - Development and testing
+    - Cost-sensitive applications
+    - Privacy-sensitive use cases
+
+    Parameters
+    ----------
+    model_name : str
+        HuggingFace model name. Default: "all-MiniLM-L6-v2" (fast, 384 dims).
+        Other popular options:
+        - "all-mpnet-base-v2" (768 dims, higher quality)
+        - "paraphrase-MiniLM-L6-v2" (384 dims)
+        - "multi-qa-MiniLM-L6-cos-v1" (384 dims, optimized for QA)
+    device : str, optional
+        Device to run on ("cpu", "cuda", "mps"). Auto-detected if None.
+
+    Examples
+    --------
+    >>> # Basic usage
+    >>> from ragit.providers import SentenceTransformersProvider
+    >>> provider = SentenceTransformersProvider()
+    >>>
+    >>> # With RAGAssistant (retrieval-only)
+    >>> assistant = RAGAssistant(docs, provider=provider)
+    >>> results = assistant.retrieve("query")
+    >>>
+    >>> # Custom model
+    >>> provider = SentenceTransformersProvider(model_name="all-mpnet-base-v2")
+
+    Raises
+    ------
+    ImportError
+        If sentence-transformers is not installed.
+
+    Note
+    ----
+    Install with: pip install ragit[transformers]
+    """
+
+    # Known model dimensions for common models
+    MODEL_DIMENSIONS: dict[str, int] = {
+        "all-MiniLM-L6-v2": 384,
+        "all-mpnet-base-v2": 768,
+        "paraphrase-MiniLM-L6-v2": 384,
+        "multi-qa-MiniLM-L6-cos-v1": 384,
+        "all-distilroberta-v1": 768,
+        "paraphrase-multilingual-MiniLM-L12-v2": 384,
+    }
+
+    def __init__(
+        self,
+        model_name: str = "all-MiniLM-L6-v2",
+        device: str | None = None,
+    ) -> None:
+        if not _check_sentence_transformers():
+            raise ImportError(
+                "sentence-transformers is required for SentenceTransformersProvider. "
+                "Install with: pip install ragit[transformers]"
+            )
+
+        self._model_name = model_name
+        self._device = device
+        self._model: SentenceTransformer | None = None  # Lazy loaded
+        self._dimensions: int | None = self.MODEL_DIMENSIONS.get(model_name)
+
+    def _ensure_model(self) -> "SentenceTransformer":
+        """Ensure model is loaded (lazy loading)."""
+        if self._model is None:
+            model = _get_model(self._model_name, self._device)
+            self._model = model
+            # Update dimensions from actual model
+            self._dimensions = model.get_sentence_embedding_dimension()
+        return self._model
+
+    @property
+    def provider_name(self) -> str:
+        return "sentence_transformers"
+
+    @property
+    def dimensions(self) -> int:
+        if self._dimensions is None:
+            # Load model to get dimensions
+            self._ensure_model()
+        return self._dimensions or 384  # Fallback
+
+    @property
+    def model_name(self) -> str:
+        """Return the model name being used."""
+        return self._model_name
+
+    def is_available(self) -> bool:
+        """Check if sentence-transformers is installed and model can be loaded."""
+        if not _check_sentence_transformers():
+            return False
+        try:
+            self._ensure_model()
+            return True
+        except Exception:
+            return False
+
+    def embed(self, text: str, model: str = "") -> EmbeddingResponse:
+        """
+        Generate embedding for text.
+
+        Parameters
+        ----------
+        text : str
+            Text to embed.
+        model : str
+            Model identifier (ignored, uses model from constructor).
+
+        Returns
+        -------
+        EmbeddingResponse
+            The embedding response.
+        """
+        model_instance = self._ensure_model()
+        embedding = model_instance.encode(text, convert_to_numpy=True)
+
+        # Convert to tuple
+        embedding_tuple = tuple(float(x) for x in embedding)
+
+        return EmbeddingResponse(
+            embedding=embedding_tuple,
+            model=self._model_name,
+            provider=self.provider_name,
+            dimensions=len(embedding_tuple),
+        )
+
+    def embed_batch(self, texts: list[str], model: str = "") -> list[EmbeddingResponse]:
+        """
+        Generate embeddings for multiple texts efficiently.
+
+        Uses batch encoding for better performance.
+
+        Parameters
+        ----------
+        texts : list[str]
+            Texts to embed.
+        model : str
+            Model identifier (ignored).
+
+        Returns
+        -------
+        list[EmbeddingResponse]
+            List of embedding responses.
+        """
+        if not texts:
+            return []
+
+        model_instance = self._ensure_model()
+
+        # Batch encode for efficiency
+        embeddings = model_instance.encode(texts, convert_to_numpy=True, show_progress_bar=False)
+
+        results = []
+        for embedding in embeddings:
+            embedding_tuple = tuple(float(x) for x in embedding)
+            results.append(
+                EmbeddingResponse(
+                    embedding=embedding_tuple,
+                    model=self._model_name,
+                    provider=self.provider_name,
+                    dimensions=len(embedding_tuple),
+                )
+            )
+
+        return results
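
For air-gapped setups, the new sentence_transformers.py offers the same embed/embed_batch surface with no server at all. A minimal sketch, not part of the diff, assuming the optional dependency is installed via pip install ragit[transformers]:

    from ragit.providers import SentenceTransformersProvider

    provider = SentenceTransformersProvider(model_name="all-MiniLM-L6-v2")
    if provider.is_available():
        # The model argument of embed()/embed_batch() is ignored; the
        # constructor's model_name is used and loaded lazily on first call.
        single = provider.embed("retrieval-augmented generation")
        print(single.dimensions)  # 384 for all-MiniLM-L6-v2

        batch = provider.embed_batch(["doc one", "doc two"])
        print(len(batch), batch[0].provider)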