ragit 0.7.2__tar.gz → 0.7.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {ragit-0.7.2 → ragit-0.7.4}/PKG-INFO +74 -1
- {ragit-0.7.2 → ragit-0.7.4}/README.md +71 -0
- {ragit-0.7.2 → ragit-0.7.4}/pyproject.toml +2 -0
- ragit-0.7.4/ragit/providers/ollama.py +461 -0
- {ragit-0.7.2 → ragit-0.7.4}/ragit/version.py +1 -1
- {ragit-0.7.2 → ragit-0.7.4}/ragit.egg-info/PKG-INFO +74 -1
- {ragit-0.7.2 → ragit-0.7.4}/ragit.egg-info/requires.txt +2 -0
- ragit-0.7.2/ragit/providers/ollama.py +0 -284
- {ragit-0.7.2 → ragit-0.7.4}/LICENSE +0 -0
- {ragit-0.7.2 → ragit-0.7.4}/ragit/__init__.py +0 -0
- {ragit-0.7.2 → ragit-0.7.4}/ragit/assistant.py +0 -0
- {ragit-0.7.2 → ragit-0.7.4}/ragit/config.py +0 -0
- {ragit-0.7.2 → ragit-0.7.4}/ragit/core/__init__.py +0 -0
- {ragit-0.7.2 → ragit-0.7.4}/ragit/core/experiment/__init__.py +0 -0
- {ragit-0.7.2 → ragit-0.7.4}/ragit/core/experiment/experiment.py +0 -0
- {ragit-0.7.2 → ragit-0.7.4}/ragit/core/experiment/results.py +0 -0
- {ragit-0.7.2 → ragit-0.7.4}/ragit/loaders.py +0 -0
- {ragit-0.7.2 → ragit-0.7.4}/ragit/providers/__init__.py +0 -0
- {ragit-0.7.2 → ragit-0.7.4}/ragit/providers/base.py +0 -0
- {ragit-0.7.2 → ragit-0.7.4}/ragit/utils/__init__.py +0 -0
- {ragit-0.7.2 → ragit-0.7.4}/ragit.egg-info/SOURCES.txt +0 -0
- {ragit-0.7.2 → ragit-0.7.4}/ragit.egg-info/dependency_links.txt +0 -0
- {ragit-0.7.2 → ragit-0.7.4}/ragit.egg-info/top_level.txt +0 -0
- {ragit-0.7.2 → ragit-0.7.4}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ragit
|
|
3
|
-
Version: 0.7.2
|
|
3
|
+
Version: 0.7.4
|
|
4
4
|
Summary: Automatic RAG Pattern Optimization Engine
|
|
5
5
|
Author: RODMENA LIMITED
|
|
6
6
|
Maintainer-email: RODMENA LIMITED <info@rodmena.co.uk>
|
|
@@ -26,6 +26,8 @@ Requires-Dist: pydantic>=2.0.0
|
|
|
26
26
|
Requires-Dist: python-dotenv>=1.0.0
|
|
27
27
|
Requires-Dist: scikit-learn>=1.5.0
|
|
28
28
|
Requires-Dist: tqdm>=4.66.0
|
|
29
|
+
Requires-Dist: trio>=0.24.0
|
|
30
|
+
Requires-Dist: httpx>=0.27.0
|
|
29
31
|
Provides-Extra: dev
|
|
30
32
|
Requires-Dist: ragit[test]; extra == "dev"
|
|
31
33
|
Requires-Dist: pytest; extra == "dev"
|
|
@@ -443,6 +445,77 @@ print(f"Score: {best.score:.3f}")
|
|
|
443
445
|
|
|
444
446
|
The experiment tests different combinations of chunk sizes, overlaps, and retrieval parameters to find what works best for your content.
|
|
445
447
|
|
|
448
|
+
## Performance Features
|
|
449
|
+
|
|
450
|
+
Ragit includes several optimizations for production workloads:
|
|
451
|
+
|
|
452
|
+
### Connection Pooling
|
|
453
|
+
|
|
454
|
+
`OllamaProvider` uses HTTP connection pooling via `requests.Session()` for faster sequential requests:
|
|
455
|
+
|
|
456
|
+
```python
|
|
457
|
+
from ragit.providers import OllamaProvider
|
|
458
|
+
|
|
459
|
+
provider = OllamaProvider()
|
|
460
|
+
|
|
461
|
+
# All requests reuse the same connection pool
|
|
462
|
+
for text in texts:
|
|
463
|
+
provider.embed(text, model="mxbai-embed-large")
|
|
464
|
+
|
|
465
|
+
# Explicitly close when done (optional, auto-closes on garbage collection)
|
|
466
|
+
provider.close()
|
|
467
|
+
```
|
|
468
|
+
|
|
469
|
+
### Async Parallel Embedding
|
|
470
|
+
|
|
471
|
+
For large batches, use `embed_batch_async()` with trio for 5-10x faster embedding:
|
|
472
|
+
|
|
473
|
+
```python
|
|
474
|
+
import trio
|
|
475
|
+
from ragit.providers import OllamaProvider
|
|
476
|
+
|
|
477
|
+
provider = OllamaProvider()
|
|
478
|
+
|
|
479
|
+
async def embed_documents():
|
|
480
|
+
texts = ["doc1...", "doc2...", "doc3...", ...] # hundreds of texts
|
|
481
|
+
embeddings = await provider.embed_batch_async(
|
|
482
|
+
texts,
|
|
483
|
+
model="mxbai-embed-large",
|
|
484
|
+
max_concurrent=10 # Adjust based on server capacity
|
|
485
|
+
)
|
|
486
|
+
return embeddings
|
|
487
|
+
|
|
488
|
+
# Run with trio
|
|
489
|
+
results = trio.run(embed_documents)
|
|
490
|
+
```
|
|
491
|
+
|
|
492
|
+
### Embedding Cache
|
|
493
|
+
|
|
494
|
+
Repeated embedding calls are cached automatically (2048 entries LRU):
|
|
495
|
+
|
|
496
|
+
```python
|
|
497
|
+
from ragit.providers import OllamaProvider
|
|
498
|
+
|
|
499
|
+
provider = OllamaProvider(use_cache=True) # Default
|
|
500
|
+
|
|
501
|
+
# First call hits the API
|
|
502
|
+
provider.embed("Hello world", model="mxbai-embed-large")
|
|
503
|
+
|
|
504
|
+
# Second call returns cached result instantly
|
|
505
|
+
provider.embed("Hello world", model="mxbai-embed-large")
|
|
506
|
+
|
|
507
|
+
# View cache statistics
|
|
508
|
+
print(OllamaProvider.embedding_cache_info())
|
|
509
|
+
# {'hits': 1, 'misses': 1, 'maxsize': 2048, 'currsize': 1}
|
|
510
|
+
|
|
511
|
+
# Clear cache if needed
|
|
512
|
+
OllamaProvider.clear_embedding_cache()
|
|
513
|
+
```
|
|
514
|
+
|
|
515
|
+
### Pre-normalized Embeddings
|
|
516
|
+
|
|
517
|
+
Vector similarity uses pre-normalized embeddings, making cosine similarity a simple dot product (O(d) per comparison for d-dimensional vectors, with no per-comparison norm computation).
|
|
518
|
+
|
|
446
519
|
## API Reference
|
|
447
520
|
|
|
448
521
|
### Document Loading
|
|
@@ -398,6 +398,77 @@ print(f"Score: {best.score:.3f}")
|
|
|
398
398
|
|
|
399
399
|
The experiment tests different combinations of chunk sizes, overlaps, and retrieval parameters to find what works best for your content.
|
|
400
400
|
|
|
401
|
+
## Performance Features
|
|
402
|
+
|
|
403
|
+
Ragit includes several optimizations for production workloads:
|
|
404
|
+
|
|
405
|
+
### Connection Pooling
|
|
406
|
+
|
|
407
|
+
`OllamaProvider` uses HTTP connection pooling via `requests.Session()` for faster sequential requests:
|
|
408
|
+
|
|
409
|
+
```python
|
|
410
|
+
from ragit.providers import OllamaProvider
|
|
411
|
+
|
|
412
|
+
provider = OllamaProvider()
|
|
413
|
+
|
|
414
|
+
# All requests reuse the same connection pool
|
|
415
|
+
for text in texts:
|
|
416
|
+
provider.embed(text, model="mxbai-embed-large")
|
|
417
|
+
|
|
418
|
+
# Explicitly close when done (optional, auto-closes on garbage collection)
|
|
419
|
+
provider.close()
|
|
420
|
+
```
|
|
421
|
+
|
|
422
|
+
### Async Parallel Embedding
|
|
423
|
+
|
|
424
|
+
For large batches, use `embed_batch_async()` with trio for 5-10x faster embedding:
|
|
425
|
+
|
|
426
|
+
```python
|
|
427
|
+
import trio
|
|
428
|
+
from ragit.providers import OllamaProvider
|
|
429
|
+
|
|
430
|
+
provider = OllamaProvider()
|
|
431
|
+
|
|
432
|
+
async def embed_documents():
|
|
433
|
+
texts = ["doc1...", "doc2...", "doc3...", ...] # hundreds of texts
|
|
434
|
+
embeddings = await provider.embed_batch_async(
|
|
435
|
+
texts,
|
|
436
|
+
model="mxbai-embed-large",
|
|
437
|
+
max_concurrent=10 # Adjust based on server capacity
|
|
438
|
+
)
|
|
439
|
+
return embeddings
|
|
440
|
+
|
|
441
|
+
# Run with trio
|
|
442
|
+
results = trio.run(embed_documents)
|
|
443
|
+
```
|
|
444
|
+
|
|
445
|
+
### Embedding Cache
|
|
446
|
+
|
|
447
|
+
Repeated embedding calls are cached automatically (2048 entries LRU):
|
|
448
|
+
|
|
449
|
+
```python
|
|
450
|
+
from ragit.providers import OllamaProvider
|
|
451
|
+
|
|
452
|
+
provider = OllamaProvider(use_cache=True) # Default
|
|
453
|
+
|
|
454
|
+
# First call hits the API
|
|
455
|
+
provider.embed("Hello world", model="mxbai-embed-large")
|
|
456
|
+
|
|
457
|
+
# Second call returns cached result instantly
|
|
458
|
+
provider.embed("Hello world", model="mxbai-embed-large")
|
|
459
|
+
|
|
460
|
+
# View cache statistics
|
|
461
|
+
print(OllamaProvider.embedding_cache_info())
|
|
462
|
+
# {'hits': 1, 'misses': 1, 'maxsize': 2048, 'currsize': 1}
|
|
463
|
+
|
|
464
|
+
# Clear cache if needed
|
|
465
|
+
OllamaProvider.clear_embedding_cache()
|
|
466
|
+
```
|
|
467
|
+
|
|
468
|
+
### Pre-normalized Embeddings
|
|
469
|
+
|
|
470
|
+
Vector similarity uses pre-normalized embeddings, making cosine similarity a simple dot product (O(d) per comparison for d-dimensional vectors, with no per-comparison norm computation).
|
|
471
|
+
|
|
401
472
|
## API Reference
|
|
402
473
|
|
|
403
474
|
### Document Loading
|
|
@@ -0,0 +1,461 @@
|
|
|
1
|
+
#
|
|
2
|
+
# Copyright RODMENA LIMITED 2025
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
#
|
|
5
|
+
"""
|
|
6
|
+
Ollama provider for LLM and Embedding operations.
|
|
7
|
+
|
|
8
|
+
This provider connects to a local or remote Ollama server.
|
|
9
|
+
Configuration is loaded from environment variables.
|
|
10
|
+
|
|
11
|
+
Performance optimizations:
|
|
12
|
+
- Connection pooling via requests.Session()
|
|
13
|
+
- Async parallel embedding via trio + httpx
|
|
14
|
+
- LRU cache for repeated embedding queries
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from functools import lru_cache
|
|
18
|
+
from typing import Any
|
|
19
|
+
|
|
20
|
+
import httpx
|
|
21
|
+
import requests
|
|
22
|
+
import trio
|
|
23
|
+
|
|
24
|
+
from ragit.config import config
|
|
25
|
+
from ragit.providers.base import (
|
|
26
|
+
BaseEmbeddingProvider,
|
|
27
|
+
BaseLLMProvider,
|
|
28
|
+
EmbeddingResponse,
|
|
29
|
+
LLMResponse,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Module-level cache for embeddings (shared across instances).
# Cache key is (text, model, embedding_url, timeout, api_key). Callers truncate
# `text` to OllamaProvider.MAX_EMBED_CHARS before calling, so keys stay bounded
# and texts that truncate to the same prefix share one entry.
@lru_cache(maxsize=2048)
def _cached_embedding(
    text: str,
    model: str,
    embedding_url: str,
    timeout: int,
    api_key: str | None = None,
) -> tuple[float, ...]:
    """Fetch one embedding from Ollama's ``/api/embeddings`` endpoint and memoize it.

    Uses a one-off ``requests.post`` (not an instance session) because the
    cache is shared by all provider instances.

    Parameters
    ----------
    text : str
        Input text; clipped to ``OllamaProvider.MAX_EMBED_CHARS`` characters.
    model : str
        Embedding model name.
    embedding_url : str
        Base URL of the Ollama server serving embeddings (no trailing slash).
    timeout : int
        Request timeout in seconds.
    api_key : str, optional
        Bearer token sent in the ``Authorization`` header when set. It is part
        of the cache key, so results fetched with different credentials are
        never shared.

    Returns
    -------
    tuple[float, ...]
        The embedding vector (a tuple, so the cached value is immutable).

    Raises
    ------
    requests.RequestException
        On transport errors or a non-2xx response.
    ValueError
        If the server returns an empty embedding. Exceptions are not cached by
        ``lru_cache``, so a later call retries the request.
    """
    if len(text) > OllamaProvider.MAX_EMBED_CHARS:
        text = text[: OllamaProvider.MAX_EMBED_CHARS]

    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    response = requests.post(
        f"{embedding_url}/api/embeddings",
        headers=headers,
        json={"model": model, "prompt": text},
        timeout=timeout,
    )
    response.raise_for_status()
    data = response.json()
    embedding = data.get("embedding", [])
    if not embedding:
        raise ValueError("Empty embedding returned from Ollama")
    return tuple(embedding)


class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
    """
    Ollama provider for both LLM and Embedding operations.

    Performance features:
    - Connection pooling via requests.Session() for generate/chat/listing and
      uncached embedding requests
    - Async parallel embedding via embed_batch_async() using trio + httpx
    - LRU cache for repeated embedding queries (2048 entries, shared across
      instances; used by embed() and embed_batch(), not by embed_batch_async())

    Parameters
    ----------
    base_url : str, optional
        Ollama server URL for generate/chat/model listing
        (default: from OLLAMA_BASE_URL env var)
    embedding_url : str, optional
        Ollama server URL for embedding requests
        (default: ``config.OLLAMA_EMBEDDING_URL``)
    api_key : str, optional
        API key for authentication, sent as a Bearer token on every request
        path, including cached and async embeddings
        (default: from OLLAMA_API_KEY env var)
    timeout : int, optional
        Request timeout in seconds (default: from OLLAMA_TIMEOUT env var)
    use_cache : bool, optional
        Enable embedding cache (default: True)

    Examples
    --------
    >>> provider = OllamaProvider()
    >>> response = provider.generate("What is RAG?", model="llama3")
    >>> print(response.text)

    >>> # Async batch embedding (5-10x faster for large batches)
    >>> embeddings = trio.run(provider.embed_batch_async, texts, "mxbai-embed-large")
    """

    # Known embedding model dimensions
    EMBEDDING_DIMENSIONS: dict[str, int] = {
        "nomic-embed-text": 768,
        "nomic-embed-text:latest": 768,
        "mxbai-embed-large": 1024,
        "all-minilm": 384,
        "snowflake-arctic-embed": 1024,
        "qwen3-embedding": 4096,
        "qwen3-embedding:0.6b": 1024,
        "qwen3-embedding:4b": 2560,
        "qwen3-embedding:8b": 4096,
    }

    # Max characters per embedding request (safe limit for 512 token models)
    MAX_EMBED_CHARS = 2000

    def __init__(
        self,
        base_url: str | None = None,
        embedding_url: str | None = None,
        api_key: str | None = None,
        timeout: int | None = None,
        use_cache: bool = True,
    ) -> None:
        # Assigned first so close()/__del__ work even if config lookups below raise.
        self._session: requests.Session | None = None

        self.base_url = (base_url or config.OLLAMA_BASE_URL).rstrip("/")
        self.embedding_url = (embedding_url or config.OLLAMA_EMBEDDING_URL).rstrip("/")
        self.api_key = api_key or config.OLLAMA_API_KEY
        self.timeout = timeout or config.OLLAMA_TIMEOUT
        self.use_cache = use_cache
        self._current_embed_model: str | None = None
        self._current_dimensions: int = 768  # default

    @property
    def session(self) -> requests.Session:
        """Lazy-initialized session for connection pooling (carries auth headers)."""
        if self._session is None:
            session = requests.Session()
            session.headers.update(self._get_headers())
            self._session = session
        return self._session

    def close(self) -> None:
        """Close the pooled session, if one was created, and release resources."""
        # getattr: tolerate a partially-initialized instance.
        session = getattr(self, "_session", None)
        if session is not None:
            session.close()
            self._session = None

    def __del__(self) -> None:
        """Cleanup on garbage collection."""
        self.close()

    def _get_headers(self, include_auth: bool = True) -> dict[str, str]:
        """Get request headers including authentication if API key is set."""
        headers = {"Content-Type": "application/json"}
        if include_auth and self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    @classmethod
    def _truncate(cls, text: str) -> str:
        """Clip *text* to at most MAX_EMBED_CHARS characters."""
        return text[: cls.MAX_EMBED_CHARS]

    def _post_embedding(self, text: str, model: str) -> tuple[float, ...]:
        """Fetch one (already-truncated) embedding over the pooled session, bypassing the cache.

        Raises ``requests.RequestException`` on transport/HTTP errors and
        ``ValueError`` if the server returns an empty embedding.
        """
        response = self.session.post(
            f"{self.embedding_url}/api/embeddings",
            json={"model": model, "prompt": text},
            timeout=self.timeout,
        )
        response.raise_for_status()
        embedding = tuple(response.json().get("embedding", []))
        if not embedding:
            raise ValueError("Empty embedding returned from Ollama")
        return embedding

    def _fetch_embedding(self, text: str, model: str) -> tuple[float, ...]:
        """Return the embedding for *text*, truncating first and using the cache when enabled."""
        truncated = self._truncate(text)
        if self.use_cache:
            return _cached_embedding(truncated, model, self.embedding_url, self.timeout, self.api_key)
        return self._post_embedding(truncated, model)

    def _make_embedding_response(self, embedding: tuple[float, ...], model: str) -> EmbeddingResponse:
        """Wrap a raw embedding vector in an EmbeddingResponse for this provider."""
        return EmbeddingResponse(
            embedding=embedding,
            model=model,
            provider=self.provider_name,
            dimensions=len(embedding),
        )

    @property
    def provider_name(self) -> str:
        return "ollama"

    @property
    def dimensions(self) -> int:
        """Length of the most recent embedding, or the known/default size for the last model used."""
        return self._current_dimensions

    def is_available(self) -> bool:
        """Check if Ollama server is reachable."""
        try:
            response = self.session.get(
                f"{self.base_url}/api/tags",
                timeout=5,
            )
            return response.status_code == 200
        except requests.RequestException:
            return False

    def list_models(self) -> list[dict[str, Any]]:
        """List available models on the Ollama server.

        Raises
        ------
        ConnectionError
            If the request fails or returns a non-2xx status.
        """
        try:
            response = self.session.get(
                f"{self.base_url}/api/tags",
                timeout=10,
            )
            response.raise_for_status()
            data = response.json()
            return list(data.get("models", []))
        except requests.RequestException as e:
            raise ConnectionError(f"Failed to list Ollama models: {e}") from e

    def generate(
        self,
        prompt: str,
        model: str,
        system_prompt: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Generate text using Ollama's ``/api/generate`` endpoint (non-streaming).

        Raises
        ------
        ConnectionError
            If the request fails or returns a non-2xx status.
        """
        options: dict[str, float | int] = {"temperature": temperature}
        if max_tokens:
            options["num_predict"] = max_tokens

        payload: dict[str, str | bool | dict[str, float | int]] = {
            "model": model,
            "prompt": prompt,
            "stream": False,
            "options": options,
        }

        if system_prompt:
            payload["system"] = system_prompt

        try:
            response = self.session.post(
                f"{self.base_url}/api/generate",
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
            data = response.json()

            return LLMResponse(
                text=data.get("response", ""),
                model=model,
                provider=self.provider_name,
                usage={
                    "prompt_tokens": data.get("prompt_eval_count"),
                    "completion_tokens": data.get("eval_count"),
                    "total_duration": data.get("total_duration"),
                },
            )
        except requests.RequestException as e:
            raise ConnectionError(f"Ollama generate failed: {e}") from e

    def embed(self, text: str, model: str) -> EmbeddingResponse:
        """Generate an embedding using Ollama, with optional caching.

        Input longer than MAX_EMBED_CHARS is truncated before the request and
        before the cache lookup.

        Raises
        ------
        ConnectionError
            If the HTTP request fails.
        ValueError
            If the server returns an empty embedding.
        """
        self._current_embed_model = model
        self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)

        try:
            embedding = self._fetch_embedding(text, model)
        except requests.RequestException as e:
            raise ConnectionError(f"Ollama embed failed: {e}") from e

        # Update dimensions from actual response
        self._current_dimensions = len(embedding)
        return self._make_embedding_response(embedding, model)

    def embed_batch(self, texts: list[str], model: str) -> list[EmbeddingResponse]:
        """Generate embeddings for multiple texts sequentially.

        For better performance with large batches, use embed_batch_async().

        Note: Ollama /api/embeddings only supports single prompts, so we loop.

        Raises
        ------
        ConnectionError
            If any HTTP request fails.
        ValueError
            If the server returns an empty embedding for any text (regardless
            of whether caching is enabled).
        """
        self._current_embed_model = model
        self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)

        results: list[EmbeddingResponse] = []
        try:
            for text in texts:
                embedding = self._fetch_embedding(text, model)
                self._current_dimensions = len(embedding)
                results.append(self._make_embedding_response(embedding, model))
        except requests.RequestException as e:
            raise ConnectionError(f"Ollama batch embed failed: {e}") from e
        return results

    async def embed_batch_async(
        self,
        texts: list[str],
        model: str,
        max_concurrent: int = 10,
    ) -> list[EmbeddingResponse]:
        """Generate embeddings for multiple texts in parallel using trio.

        This method is 5-10x faster than embed_batch() for large batches
        by making concurrent HTTP requests. It does not consult the embedding
        cache.

        Parameters
        ----------
        texts : list[str]
            Texts to embed.
        model : str
            Embedding model name.
        max_concurrent : int
            Maximum concurrent requests (default: 10).
            Higher values = faster but more server load.

        Returns
        -------
        list[EmbeddingResponse]
            Embeddings in the same order as input texts.

        Raises
        ------
        ConnectionError
            If any request fails or returns an empty embedding. The first
            recorded error is chained as the cause.

        Examples
        --------
        >>> import trio
        >>> embeddings = trio.run(provider.embed_batch_async, texts, "mxbai-embed-large")
        """
        self._current_embed_model = model
        self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)

        # Results storage (index -> embedding)
        results: dict[int, EmbeddingResponse] = {}
        errors: list[Exception] = []

        # Semaphore to limit concurrency
        limiter = trio.CapacityLimiter(max_concurrent)

        async def fetch_embedding(client: httpx.AsyncClient, index: int, text: str) -> None:
            """Fetch a single embedding and store it (or the error) by index."""
            async with limiter:
                try:
                    response = await client.post(
                        f"{self.embedding_url}/api/embeddings",
                        json={"model": model, "prompt": self._truncate(text)},
                        timeout=self.timeout,
                    )
                    response.raise_for_status()
                    embedding = tuple(response.json().get("embedding", []))
                    if not embedding:
                        raise ValueError("Empty embedding returned from Ollama")
                    self._current_dimensions = len(embedding)
                    results[index] = self._make_embedding_response(embedding, model)
                except Exception as e:
                    # Record instead of raising so sibling tasks are not cancelled;
                    # the first recorded error is re-raised after the nursery closes.
                    errors.append(e)

        # The async client sends the same Content-Type/Authorization headers as the sync session.
        async with httpx.AsyncClient(headers=self._get_headers()) as client, trio.open_nursery() as nursery:
            for i, text in enumerate(texts):
                nursery.start_soon(fetch_embedding, client, i, text)

        if errors:
            raise ConnectionError(f"Ollama async batch embed failed: {errors[0]}") from errors[0]

        # Return results in original order
        return [results[i] for i in range(len(texts))]

    def chat(
        self,
        messages: list[dict[str, str]],
        model: str,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """
        Chat completion using Ollama.

        Parameters
        ----------
        messages : list[dict]
            List of messages with 'role' and 'content' keys.
        model : str
            Model identifier.
        temperature : float
            Sampling temperature.
        max_tokens : int, optional
            Maximum tokens to generate.

        Returns
        -------
        LLMResponse
            The generated response.

        Raises
        ------
        ConnectionError
            If the request fails or returns a non-2xx status.
        """
        options: dict[str, float | int] = {"temperature": temperature}
        if max_tokens:
            options["num_predict"] = max_tokens

        payload: dict[str, str | bool | list[dict[str, str]] | dict[str, float | int]] = {
            "model": model,
            "messages": messages,
            "stream": False,
            "options": options,
        }

        try:
            response = self.session.post(
                f"{self.base_url}/api/chat",
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
            data = response.json()

            return LLMResponse(
                text=data.get("message", {}).get("content", ""),
                model=model,
                provider=self.provider_name,
                usage={
                    "prompt_tokens": data.get("prompt_eval_count"),
                    "completion_tokens": data.get("eval_count"),
                },
            )
        except requests.RequestException as e:
            raise ConnectionError(f"Ollama chat failed: {e}") from e

    @staticmethod
    def clear_embedding_cache() -> None:
        """Clear the embedding cache."""
        _cached_embedding.cache_clear()

    @staticmethod
    def embedding_cache_info() -> dict[str, int]:
        """Get embedding cache statistics."""
        info = _cached_embedding.cache_info()
        return {
            "hits": info.hits,
            "misses": info.misses,
            "maxsize": info.maxsize or 0,
            "currsize": info.currsize,
        }
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
# Module-level alias (same dict object, not a copy) of the class's table of known
# embedding model -> vector size, so callers can import it without the class.
EMBEDDING_DIMENSIONS: dict[str, int] = OllamaProvider.EMBEDDING_DIMENSIONS
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: ragit
|
|
3
|
-
Version: 0.7.2
|
|
3
|
+
Version: 0.7.4
|
|
4
4
|
Summary: Automatic RAG Pattern Optimization Engine
|
|
5
5
|
Author: RODMENA LIMITED
|
|
6
6
|
Maintainer-email: RODMENA LIMITED <info@rodmena.co.uk>
|
|
@@ -26,6 +26,8 @@ Requires-Dist: pydantic>=2.0.0
|
|
|
26
26
|
Requires-Dist: python-dotenv>=1.0.0
|
|
27
27
|
Requires-Dist: scikit-learn>=1.5.0
|
|
28
28
|
Requires-Dist: tqdm>=4.66.0
|
|
29
|
+
Requires-Dist: trio>=0.24.0
|
|
30
|
+
Requires-Dist: httpx>=0.27.0
|
|
29
31
|
Provides-Extra: dev
|
|
30
32
|
Requires-Dist: ragit[test]; extra == "dev"
|
|
31
33
|
Requires-Dist: pytest; extra == "dev"
|
|
@@ -443,6 +445,77 @@ print(f"Score: {best.score:.3f}")
|
|
|
443
445
|
|
|
444
446
|
The experiment tests different combinations of chunk sizes, overlaps, and retrieval parameters to find what works best for your content.
|
|
445
447
|
|
|
448
|
+
## Performance Features
|
|
449
|
+
|
|
450
|
+
Ragit includes several optimizations for production workloads:
|
|
451
|
+
|
|
452
|
+
### Connection Pooling
|
|
453
|
+
|
|
454
|
+
`OllamaProvider` uses HTTP connection pooling via `requests.Session()` for faster sequential requests:
|
|
455
|
+
|
|
456
|
+
```python
|
|
457
|
+
from ragit.providers import OllamaProvider
|
|
458
|
+
|
|
459
|
+
provider = OllamaProvider()
|
|
460
|
+
|
|
461
|
+
# All requests reuse the same connection pool
|
|
462
|
+
for text in texts:
|
|
463
|
+
provider.embed(text, model="mxbai-embed-large")
|
|
464
|
+
|
|
465
|
+
# Explicitly close when done (optional, auto-closes on garbage collection)
|
|
466
|
+
provider.close()
|
|
467
|
+
```
|
|
468
|
+
|
|
469
|
+
### Async Parallel Embedding
|
|
470
|
+
|
|
471
|
+
For large batches, use `embed_batch_async()` with trio for 5-10x faster embedding:
|
|
472
|
+
|
|
473
|
+
```python
|
|
474
|
+
import trio
|
|
475
|
+
from ragit.providers import OllamaProvider
|
|
476
|
+
|
|
477
|
+
provider = OllamaProvider()
|
|
478
|
+
|
|
479
|
+
async def embed_documents():
|
|
480
|
+
texts = ["doc1...", "doc2...", "doc3...", ...] # hundreds of texts
|
|
481
|
+
embeddings = await provider.embed_batch_async(
|
|
482
|
+
texts,
|
|
483
|
+
model="mxbai-embed-large",
|
|
484
|
+
max_concurrent=10 # Adjust based on server capacity
|
|
485
|
+
)
|
|
486
|
+
return embeddings
|
|
487
|
+
|
|
488
|
+
# Run with trio
|
|
489
|
+
results = trio.run(embed_documents)
|
|
490
|
+
```
|
|
491
|
+
|
|
492
|
+
### Embedding Cache
|
|
493
|
+
|
|
494
|
+
Repeated embedding calls are cached automatically (2048 entries LRU):
|
|
495
|
+
|
|
496
|
+
```python
|
|
497
|
+
from ragit.providers import OllamaProvider
|
|
498
|
+
|
|
499
|
+
provider = OllamaProvider(use_cache=True) # Default
|
|
500
|
+
|
|
501
|
+
# First call hits the API
|
|
502
|
+
provider.embed("Hello world", model="mxbai-embed-large")
|
|
503
|
+
|
|
504
|
+
# Second call returns cached result instantly
|
|
505
|
+
provider.embed("Hello world", model="mxbai-embed-large")
|
|
506
|
+
|
|
507
|
+
# View cache statistics
|
|
508
|
+
print(OllamaProvider.embedding_cache_info())
|
|
509
|
+
# {'hits': 1, 'misses': 1, 'maxsize': 2048, 'currsize': 1}
|
|
510
|
+
|
|
511
|
+
# Clear cache if needed
|
|
512
|
+
OllamaProvider.clear_embedding_cache()
|
|
513
|
+
```
|
|
514
|
+
|
|
515
|
+
### Pre-normalized Embeddings
|
|
516
|
+
|
|
517
|
+
Vector similarity uses pre-normalized embeddings, making cosine similarity a simple dot product (O(d) per comparison for d-dimensional vectors, with no per-comparison norm computation).
|
|
518
|
+
|
|
446
519
|
## API Reference
|
|
447
520
|
|
|
448
521
|
### Document Loading
|
|
@@ -1,284 +0,0 @@
|
|
|
1
|
-
#
|
|
2
|
-
# Copyright RODMENA LIMITED 2025
|
|
3
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
-
#
|
|
5
|
-
"""
|
|
6
|
-
Ollama provider for LLM and Embedding operations.
|
|
7
|
-
|
|
8
|
-
This provider connects to a local or remote Ollama server.
|
|
9
|
-
Configuration is loaded from environment variables.
|
|
10
|
-
"""
|
|
11
|
-
|
|
12
|
-
import requests
|
|
13
|
-
|
|
14
|
-
from ragit.config import config
|
|
15
|
-
from ragit.providers.base import (
|
|
16
|
-
BaseEmbeddingProvider,
|
|
17
|
-
BaseLLMProvider,
|
|
18
|
-
EmbeddingResponse,
|
|
19
|
-
LLMResponse,
|
|
20
|
-
)
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
    """
    Ollama provider for both LLM and Embedding operations.

    Parameters
    ----------
    base_url : str, optional
        Ollama server URL used for generation, chat, model listing and the
        availability check (default: from OLLAMA_BASE_URL env var).
        A trailing slash is stripped.
    embedding_url : str, optional
        Ollama server URL used for embedding requests
        (default: ``config.OLLAMA_EMBEDDING_URL``). A trailing slash is stripped.
    api_key : str, optional
        API key for authentication (default: from OLLAMA_API_KEY env var).
        It is sent as a Bearer token on requests to ``base_url`` only;
        embedding requests are sent without an Authorization header.
    timeout : int, optional
        Request timeout in seconds (default: from OLLAMA_TIMEOUT env var)

    Examples
    --------
    >>> provider = OllamaProvider()
    >>> response = provider.generate("What is RAG?", model="llama3")
    >>> print(response.text)

    >>> embedding = provider.embed("Hello world", model="nomic-embed-text")
    >>> print(len(embedding.embedding))
    """

    # Known embedding model dimensions
    EMBEDDING_DIMENSIONS = {
        "nomic-embed-text": 768,
        "nomic-embed-text:latest": 768,
        "mxbai-embed-large": 1024,
        "all-minilm": 384,
        "snowflake-arctic-embed": 1024,
        "qwen3-embedding": 4096,
        "qwen3-embedding:0.6b": 1024,
        "qwen3-embedding:4b": 2560,
        "qwen3-embedding:8b": 4096,
    }

    # Dimension reported before any embedding call, and for models missing
    # from EMBEDDING_DIMENSIONS until a real response reveals the true size.
    _DEFAULT_DIMENSIONS = 768

    def __init__(
        self,
        base_url: str | None = None,
        embedding_url: str | None = None,
        api_key: str | None = None,
        timeout: int | None = None,
    ) -> None:
        self.base_url = (base_url or config.OLLAMA_BASE_URL).rstrip("/")
        self.embedding_url = (embedding_url or config.OLLAMA_EMBEDDING_URL).rstrip("/")
        self.api_key = api_key or config.OLLAMA_API_KEY
        self.timeout = timeout or config.OLLAMA_TIMEOUT
        self._current_embed_model: str | None = None
        self._current_dimensions: int = self._DEFAULT_DIMENSIONS

    def _get_headers(self, include_auth: bool = True) -> dict[str, str]:
        """Get request headers including authentication if API key is set."""
        headers = {"Content-Type": "application/json"}
        if include_auth and self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    @property
    def provider_name(self) -> str:
        """Identifier stamped on every response produced by this provider."""
        return "ollama"

    @property
    def dimensions(self) -> int:
        """Vector size of the most recently used embedding model.

        Seeded from EMBEDDING_DIMENSIONS when an embed call starts and
        overwritten with the actual vector length once a response arrives.
        """
        return self._current_dimensions

    def is_available(self) -> bool:
        """Check if Ollama server at ``base_url`` is reachable (5 s timeout)."""
        try:
            response = requests.get(
                f"{self.base_url}/api/tags",
                headers=self._get_headers(),
                timeout=5,
            )
            return response.status_code == 200
        except requests.RequestException:
            # Any transport failure simply means "not available".
            return False

    def list_models(self) -> list[dict[str, str]]:
        """List available models on the Ollama server.

        Raises
        ------
        ConnectionError
            If the request fails or returns an HTTP error status.
        """
        try:
            response = requests.get(
                f"{self.base_url}/api/tags",
                headers=self._get_headers(),
                timeout=10,
            )
            response.raise_for_status()
            data = response.json()
            return list(data.get("models", []))
        except requests.RequestException as e:
            raise ConnectionError(f"Failed to list Ollama models: {e}") from e

    @staticmethod
    def _build_options(temperature: float, max_tokens: int | None) -> dict[str, float | int]:
        """Build the Ollama ``options`` dict shared by generate() and chat()."""
        options: dict[str, float | int] = {"temperature": temperature}
        # A falsy max_tokens (None or 0) means "no explicit limit": num_predict is omitted.
        if max_tokens:
            options["num_predict"] = max_tokens
        return options

    def generate(
        self,
        prompt: str,
        model: str,
        system_prompt: str | None = None,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Generate text using Ollama's non-streaming ``/api/generate`` endpoint.

        Raises
        ------
        ConnectionError
            If the request fails or returns an HTTP error status.
        """
        payload: dict[str, str | bool | dict[str, float | int]] = {
            "model": model,
            "prompt": prompt,
            "stream": False,
            "options": self._build_options(temperature, max_tokens),
        }

        if system_prompt:
            payload["system"] = system_prompt

        try:
            response = requests.post(
                f"{self.base_url}/api/generate",
                headers=self._get_headers(),
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
            data = response.json()

            return LLMResponse(
                text=data.get("response", ""),
                model=model,
                provider=self.provider_name,
                usage={
                    "prompt_tokens": data.get("prompt_eval_count"),
                    "completion_tokens": data.get("eval_count"),
                    "total_duration": data.get("total_duration"),
                },
            )
        except requests.RequestException as e:
            raise ConnectionError(f"Ollama generate failed: {e}") from e

    def _select_embed_model(self, model: str) -> None:
        """Record the active embedding model and seed its expected dimensions."""
        self._current_embed_model = model
        self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, self._DEFAULT_DIMENSIONS)

    def _request_embedding(self, text: str, model: str) -> list[float]:
        """POST one prompt to ``embedding_url``/api/embeddings and return the raw vector.

        No Authorization header is sent. Transport/HTTP errors propagate as
        ``requests.RequestException`` so each caller can wrap them with its own
        message; an empty vector raises ``ValueError``.
        """
        response = requests.post(
            f"{self.embedding_url}/api/embeddings",
            headers=self._get_headers(include_auth=False),
            json={"model": model, "prompt": text},
            timeout=self.timeout,
        )
        response.raise_for_status()
        data = response.json()

        embedding = data.get("embedding", [])
        if not embedding:
            raise ValueError("Empty embedding returned from Ollama")

        # Update dimensions from actual response (the lookup table is only a guess).
        self._current_dimensions = len(embedding)
        return embedding

    def _embedding_response(self, embedding: list[float], model: str) -> EmbeddingResponse:
        """Wrap a raw vector in an EmbeddingResponse for this provider."""
        return EmbeddingResponse(
            embedding=tuple(embedding),
            model=model,
            provider=self.provider_name,
            dimensions=len(embedding),
        )

    def embed(self, text: str, model: str) -> EmbeddingResponse:
        """Generate embedding using Ollama (uses embedding_url, no auth for local).

        Raises
        ------
        ConnectionError
            If the request fails or returns an HTTP error status.
        ValueError
            If the server returns an empty embedding.
        """
        self._select_embed_model(model)
        try:
            embedding = self._request_embedding(text, model)
        except requests.RequestException as e:
            raise ConnectionError(f"Ollama embed failed: {e}") from e
        return self._embedding_response(embedding, model)

    def embed_batch(self, texts: list[str], model: str) -> list[EmbeddingResponse]:
        """Generate embeddings for multiple texts (uses embedding_url, no auth for local).

        Note: Ollama /api/embeddings only supports single prompts, so we loop,
        issuing one request per text in order.

        Raises
        ------
        ConnectionError
            If any request fails or returns an HTTP error status.
        ValueError
            If the server returns an empty embedding for any text (consistent
            with ``embed``; zero-length vectors are never returned).
        """
        self._select_embed_model(model)
        try:
            embeddings = [self._request_embedding(text, model) for text in texts]
        except requests.RequestException as e:
            raise ConnectionError(f"Ollama batch embed failed: {e}") from e
        return [self._embedding_response(embedding, model) for embedding in embeddings]

    def chat(
        self,
        messages: list[dict[str, str]],
        model: str,
        temperature: float = 0.7,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """
        Chat completion using Ollama.

        Parameters
        ----------
        messages : list[dict]
            List of messages with 'role' and 'content' keys.
        model : str
            Model identifier.
        temperature : float
            Sampling temperature.
        max_tokens : int, optional
            Maximum tokens to generate.

        Returns
        -------
        LLMResponse
            The generated response.

        Raises
        ------
        ConnectionError
            If the request fails or returns an HTTP error status.
        """
        payload: dict[str, str | bool | list[dict[str, str]] | dict[str, float | int]] = {
            "model": model,
            "messages": messages,
            "stream": False,
            "options": self._build_options(temperature, max_tokens),
        }

        try:
            response = requests.post(
                f"{self.base_url}/api/chat",
                headers=self._get_headers(),
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
            data = response.json()

            # Guard against an explicit null "message" as well as a missing one.
            message = data.get("message") or {}
            return LLMResponse(
                text=message.get("content", ""),
                model=model,
                provider=self.provider_name,
                usage={
                    "prompt_tokens": data.get("prompt_eval_count"),
                    "completion_tokens": data.get("eval_count"),
                    "total_duration": data.get("total_duration"),
                },
            )
        except requests.RequestException as e:
            raise ConnectionError(f"Ollama chat failed: {e}") from e
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|