ragit 0.3__py3-none-any.whl → 0.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
- ragit/__init__.py +128 -2
- ragit/assistant.py +757 -0
- ragit/config.py +204 -0
- ragit/core/__init__.py +5 -0
- ragit/core/experiment/__init__.py +22 -0
- ragit/core/experiment/experiment.py +577 -0
- ragit/core/experiment/results.py +131 -0
- ragit/exceptions.py +271 -0
- ragit/loaders.py +401 -0
- ragit/logging.py +194 -0
- ragit/monitor.py +307 -0
- ragit/providers/__init__.py +35 -0
- ragit/providers/base.py +147 -0
- ragit/providers/function_adapter.py +237 -0
- ragit/providers/ollama.py +670 -0
- ragit/utils/__init__.py +105 -0
- ragit/version.py +5 -0
- ragit-0.10.1.dist-info/METADATA +153 -0
- ragit-0.10.1.dist-info/RECORD +22 -0
- {ragit-0.3.dist-info → ragit-0.10.1.dist-info}/WHEEL +1 -1
- ragit-0.10.1.dist-info/licenses/LICENSE +201 -0
- ragit/main.py +0 -384
- ragit-0.3.dist-info/METADATA +0 -163
- ragit-0.3.dist-info/RECORD +0 -6
- {ragit-0.3.dist-info → ragit-0.10.1.dist-info}/top_level.txt +0 -0
ragit/providers/ollama.py
@@ -0,0 +1,670 @@
+#
+# Copyright RODMENA LIMITED 2025
+# SPDX-License-Identifier: Apache-2.0
+#
+"""
+Ollama provider for LLM and Embedding operations.
+
+This provider connects to a local or remote Ollama server.
+Configuration is loaded from environment variables.
+
+Performance optimizations:
+- Connection pooling via requests.Session()
+- Async parallel embedding via trio + httpx
+- LRU cache for repeated embedding queries
+
+Resilience features (via resilient-circuit):
+- Retry with exponential backoff
+- Circuit breaker pattern for fault tolerance
+"""
+
+from datetime import timedelta
+from fractions import Fraction
+from functools import lru_cache
+from typing import Any
+
+import httpx
+import requests
+from resilient_circuit import (
+    CircuitProtectorPolicy,
+    ExponentialDelay,
+    RetryWithBackoffPolicy,
+    SafetyNet,
+)
+from resilient_circuit.exceptions import ProtectedCallError, RetryLimitReached
+
+from ragit.config import config
+from ragit.exceptions import IndexingError, ProviderError
+from ragit.logging import log_operation, logger
+from ragit.providers.base import (
+    BaseEmbeddingProvider,
+    BaseLLMProvider,
+    EmbeddingResponse,
+    LLMResponse,
+)
+
+
+def _create_generate_policy() -> SafetyNet:
+    """Create resilience policy for LLM generation (longer timeouts, more tolerant)."""
+    return SafetyNet(
+        policies=(
+            RetryWithBackoffPolicy(
+                max_retries=3,
+                backoff=ExponentialDelay(
+                    min_delay=timedelta(seconds=1),
+                    max_delay=timedelta(seconds=30),
+                    factor=2,
+                    jitter=0.1,
+                ),
+                should_handle=lambda e: isinstance(e, (ConnectionError, TimeoutError, requests.RequestException)),
+            ),
+            CircuitProtectorPolicy(
+                resource_key="ollama_generate",
+                cooldown=timedelta(seconds=60),
+                failure_limit=Fraction(3, 10),  # 30% failure rate trips circuit
+                success_limit=Fraction(4, 5),  # 80% success to close
+                should_handle=lambda e: isinstance(e, (ConnectionError, requests.RequestException)),
+            ),
+        )
+    )
+
+
+def _create_embed_policy() -> SafetyNet:
+    """Create resilience policy for embeddings (faster, stricter)."""
+    return SafetyNet(
+        policies=(
+            RetryWithBackoffPolicy(
+                max_retries=2,
+                backoff=ExponentialDelay(
+                    min_delay=timedelta(milliseconds=500),
+                    max_delay=timedelta(seconds=5),
+                    factor=2,
+                    jitter=0.1,
+                ),
+                should_handle=lambda e: isinstance(e, (ConnectionError, TimeoutError, requests.RequestException)),
+            ),
+            CircuitProtectorPolicy(
+                resource_key="ollama_embed",
+                cooldown=timedelta(seconds=30),
+                failure_limit=Fraction(2, 5),  # 40% failure rate trips circuit
+                success_limit=Fraction(3, 3),  # All 3 tests must succeed to close
+                should_handle=lambda e: isinstance(e, (ConnectionError, requests.RequestException)),
+            ),
+        )
+    )
+
+
+def _truncate_text(text: str, max_chars: int = 2000) -> str:
+    """Truncate text to max_chars. Used BEFORE cache lookup to fix cache key bug."""
+    return text[:max_chars] if len(text) > max_chars else text
+
+
+# Module-level cache for embeddings (shared across instances)
+# NOTE: Text must be truncated BEFORE calling this function to ensure correct cache keys
+@lru_cache(maxsize=2048)
+def _cached_embedding(text: str, model: str, embedding_url: str, timeout: int) -> tuple[float, ...]:
+    """Cache embedding results to avoid redundant API calls.
+
+    IMPORTANT: Caller must truncate text BEFORE calling this function.
+    This ensures cache keys are consistent for truncated inputs.
+    """
+    response = requests.post(
+        f"{embedding_url}/api/embed",
+        headers={"Content-Type": "application/json"},
+        json={"model": model, "input": text},
+        timeout=timeout,
+    )
+    response.raise_for_status()
+    data = response.json()
+    embeddings = data.get("embeddings", [])
+    if not embeddings or not embeddings[0]:
+        raise ValueError("Empty embedding returned from Ollama")
+    return tuple(embeddings[0])
+
+
+class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
+    """
+    Ollama provider for both LLM and Embedding operations.
+
+    Performance features:
+    - Connection pooling via requests.Session() for faster sequential requests
+    - Native batch embedding via /api/embed endpoint (single API call)
+    - LRU cache for repeated embedding queries (2048 entries)
+
+    Parameters
+    ----------
+    base_url : str, optional
+        Ollama server URL (default: from OLLAMA_BASE_URL env var)
+    api_key : str, optional
+        API key for authentication (default: from OLLAMA_API_KEY env var)
+    timeout : int, optional
+        Request timeout in seconds (default: from OLLAMA_TIMEOUT env var)
+    use_cache : bool, optional
+        Enable embedding cache (default: True)
+
+    Examples
+    --------
+    >>> provider = OllamaProvider()
+    >>> response = provider.generate("What is RAG?", model="llama3")
+    >>> print(response.text)
+
+    >>> # Batch embedding (single API call)
+    >>> embeddings = provider.embed_batch(texts, "mxbai-embed-large")
+    """
+
+    # Known embedding model dimensions
+    EMBEDDING_DIMENSIONS: dict[str, int] = {
+        "nomic-embed-text": 768,
+        "nomic-embed-text:latest": 768,
+        "mxbai-embed-large": 1024,
+        "all-minilm": 384,
+        "snowflake-arctic-embed": 1024,
+        "qwen3-embedding": 4096,
+        "qwen3-embedding:0.6b": 1024,
+        "qwen3-embedding:4b": 2560,
+        "qwen3-embedding:8b": 4096,
+    }
+
+    # Max characters per embedding request (safe limit for 512 token models)
+    MAX_EMBED_CHARS = 2000
+
+    # Default timeouts per operation type (in seconds)
+    DEFAULT_TIMEOUTS: dict[str, int] = {
+        "generate": 300,  # 5 minutes for LLM generation
+        "chat": 300,  # 5 minutes for chat
+        "embed": 30,  # 30 seconds for single embedding
+        "embed_batch": 120,  # 2 minutes for batch embedding
+        "health": 5,  # 5 seconds for health check
+        "list_models": 10,  # 10 seconds for listing models
+    }
+
+    def __init__(
+        self,
+        base_url: str | None = None,
+        embedding_url: str | None = None,
+        api_key: str | None = None,
+        timeout: int | None = None,
+        timeouts: dict[str, int] | None = None,
+        use_cache: bool = True,
+        use_resilience: bool = True,
+    ) -> None:
+        self.base_url = (base_url or config.OLLAMA_BASE_URL).rstrip("/")
+        self.embedding_url = (embedding_url or config.OLLAMA_EMBEDDING_URL).rstrip("/")
+        self.api_key = api_key or config.OLLAMA_API_KEY
+        self.use_cache = use_cache
+        self.use_resilience = use_resilience
+        self._current_embed_model: str | None = None
+        self._current_dimensions: int = 768  # default
+
+        # Per-operation timeouts (merge user overrides with defaults)
+        self._timeouts = {**self.DEFAULT_TIMEOUTS, **(timeouts or {})}
+        # Legacy single timeout parameter overrides all operations
+        if timeout is not None:
+            self._timeouts = {k: timeout for k in self._timeouts}
+        # Keep legacy timeout property for backwards compatibility
+        self.timeout = timeout or config.OLLAMA_TIMEOUT
+
+        # Connection pooling via session
+        self._session: requests.Session | None = None
+
+        # Resilience policies (retry + circuit breaker)
+        self._generate_policy: SafetyNet | None = None
+        self._embed_policy: SafetyNet | None = None
+        if use_resilience:
+            self._generate_policy = _create_generate_policy()
+            self._embed_policy = _create_embed_policy()
+
+    @property
+    def session(self) -> requests.Session:
+        """Lazy-initialized session for connection pooling."""
+        if self._session is None:
+            self._session = requests.Session()
+            self._session.headers.update({"Content-Type": "application/json"})
+            if self.api_key:
+                self._session.headers.update({"Authorization": f"Bearer {self.api_key}"})
+        return self._session
+
+    def close(self) -> None:
+        """Close the session and release resources."""
+        if self._session is not None:
+            self._session.close()
+            self._session = None
+
+    def __del__(self) -> None:
+        """Cleanup on garbage collection."""
+        self.close()
+
+    def _get_headers(self, include_auth: bool = True) -> dict[str, str]:
+        """Get request headers including authentication if API key is set."""
+        headers = {"Content-Type": "application/json"}
+        if include_auth and self.api_key:
+            headers["Authorization"] = f"Bearer {self.api_key}"
+        return headers
+
+    @property
+    def provider_name(self) -> str:
+        return "ollama"
+
+    @property
+    def dimensions(self) -> int:
+        return self._current_dimensions
+
+    def is_available(self) -> bool:
+        """Check if Ollama server is reachable."""
+        try:
+            response = self.session.get(
+                f"{self.base_url}/api/tags",
+                timeout=self._timeouts["health"],
+            )
+            return bool(response.status_code == 200)
+        except requests.RequestException:
+            return False
+
+    def list_models(self) -> list[dict[str, Any]]:
+        """List available models on the Ollama server."""
+        try:
+            response = self.session.get(
+                f"{self.base_url}/api/tags",
+                timeout=self._timeouts["list_models"],
+            )
+            response.raise_for_status()
+            data = response.json()
+            return list(data.get("models", []))
+        except requests.RequestException as e:
+            raise ProviderError("Failed to list Ollama models", e) from e
+
+    def generate(
+        self,
+        prompt: str,
+        model: str,
+        system_prompt: str | None = None,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+    ) -> LLMResponse:
+        """Generate text using Ollama with optional resilience (retry + circuit breaker)."""
+        if self.use_resilience and self._generate_policy is not None:
+
+            @self._generate_policy
+            def _protected_generate() -> LLMResponse:
+                return self._do_generate(prompt, model, system_prompt, temperature, max_tokens)
+
+            try:
+                return _protected_generate()
+            except ProtectedCallError as e:
+                logger.warning(f"Circuit breaker OPEN for ollama.generate (model={model})")
+                raise ProviderError("Ollama service unavailable - circuit breaker open", e) from e
+            except RetryLimitReached as e:
+                logger.error(f"Retry limit reached for ollama.generate (model={model}): {e.__cause__}")
+                raise ProviderError("Ollama generate failed after retries", e.__cause__) from e
+        else:
+            return self._do_generate(prompt, model, system_prompt, temperature, max_tokens)
+
+    def _do_generate(
+        self,
+        prompt: str,
+        model: str,
+        system_prompt: str | None = None,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+    ) -> LLMResponse:
+        """Internal generate implementation (unprotected)."""
+        options: dict[str, float | int] = {"temperature": temperature}
+        if max_tokens:
+            options["num_predict"] = max_tokens
+
+        payload: dict[str, str | bool | dict[str, float | int]] = {
+            "model": model,
+            "prompt": prompt,
+            "stream": False,
+            "options": options,
+        }
+
+        if system_prompt:
+            payload["system"] = system_prompt
+
+        with log_operation("ollama.generate", model=model, prompt_len=len(prompt)) as ctx:
+            try:
+                response = self.session.post(
+                    f"{self.base_url}/api/generate",
+                    json=payload,
+                    timeout=self._timeouts["generate"],
+                )
+                response.raise_for_status()
+                data = response.json()
+
+                ctx["completion_tokens"] = data.get("eval_count")
+
+                return LLMResponse(
+                    text=data.get("response", ""),
+                    model=model,
+                    provider=self.provider_name,
+                    usage={
+                        "prompt_tokens": data.get("prompt_eval_count"),
+                        "completion_tokens": data.get("eval_count"),
+                        "total_duration": data.get("total_duration"),
+                    },
+                )
+            except requests.RequestException as e:
+                raise ProviderError("Ollama generate failed", e) from e
+
+    def embed(self, text: str, model: str) -> EmbeddingResponse:
+        """Generate embedding using Ollama with optional caching and resilience."""
+        if self.use_resilience and self._embed_policy is not None:
+
+            @self._embed_policy
+            def _protected_embed() -> EmbeddingResponse:
+                return self._do_embed(text, model)
+
+            try:
+                return _protected_embed()
+            except ProtectedCallError as e:
+                logger.warning(f"Circuit breaker OPEN for ollama.embed (model={model})")
+                raise ProviderError("Ollama embedding service unavailable - circuit breaker open", e) from e
+            except RetryLimitReached as e:
+                logger.error(f"Retry limit reached for ollama.embed (model={model}): {e.__cause__}")
+                raise IndexingError("Ollama embed failed after retries", e.__cause__) from e
+        else:
+            return self._do_embed(text, model)
+
+    def _do_embed(self, text: str, model: str) -> EmbeddingResponse:
+        """Internal embed implementation (unprotected)."""
+        self._current_embed_model = model
+        self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
+
+        # Truncate BEFORE cache lookup (fixes cache key bug)
+        truncated_text = _truncate_text(text, self.MAX_EMBED_CHARS)
+        was_truncated = len(text) > self.MAX_EMBED_CHARS
+
+        with log_operation("ollama.embed", model=model, text_len=len(text), truncated=was_truncated) as ctx:
+            try:
+                if self.use_cache:
+                    # Use cached version with truncated text
+                    embedding = _cached_embedding(truncated_text, model, self.embedding_url, self._timeouts["embed"])
+                    ctx["cache"] = "hit_or_miss"  # Can't tell from here
+                else:
+                    # Direct call without cache
+                    response = self.session.post(
+                        f"{self.embedding_url}/api/embed",
+                        json={"model": model, "input": truncated_text},
+                        timeout=self._timeouts["embed"],
+                    )
+                    response.raise_for_status()
+                    data = response.json()
+                    embeddings = data.get("embeddings", [])
+                    if not embeddings or not embeddings[0]:
+                        raise ValueError("Empty embedding returned from Ollama")
+                    embedding = tuple(embeddings[0])
+                    ctx["cache"] = "disabled"
+
+                # Update dimensions from actual response
+                self._current_dimensions = len(embedding)
+                ctx["dimensions"] = len(embedding)
+
+                return EmbeddingResponse(
+                    embedding=embedding,
+                    model=model,
+                    provider=self.provider_name,
+                    dimensions=len(embedding),
+                )
+            except requests.RequestException as e:
+                raise IndexingError("Ollama embed failed", e) from e
+
+    def embed_batch(self, texts: list[str], model: str) -> list[EmbeddingResponse]:
+        """Generate embeddings for multiple texts in a single API call with resilience.
+
+        The /api/embed endpoint supports batch inputs natively.
+        """
+        if self.use_resilience and self._embed_policy is not None:
+
+            @self._embed_policy
+            def _protected_embed_batch() -> list[EmbeddingResponse]:
+                return self._do_embed_batch(texts, model)
+
+            try:
+                return _protected_embed_batch()
+            except ProtectedCallError as e:
+                logger.warning(f"Circuit breaker OPEN for ollama.embed_batch (model={model}, batch_size={len(texts)})")
+                raise ProviderError("Ollama embedding service unavailable - circuit breaker open", e) from e
+            except RetryLimitReached as e:
+                logger.error(f"Retry limit reached for ollama.embed_batch (model={model}): {e.__cause__}")
+                raise IndexingError("Ollama batch embed failed after retries", e.__cause__) from e
+        else:
+            return self._do_embed_batch(texts, model)
+
+    def _do_embed_batch(self, texts: list[str], model: str) -> list[EmbeddingResponse]:
+        """Internal batch embed implementation (unprotected)."""
+        self._current_embed_model = model
+        self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
+
+        # Truncate oversized inputs
+        truncated_texts = [_truncate_text(text, self.MAX_EMBED_CHARS) for text in texts]
+        truncated_count = sum(1 for t, tt in zip(texts, truncated_texts, strict=True) if len(t) != len(tt))
+
+        with log_operation(
+            "ollama.embed_batch", model=model, batch_size=len(texts), truncated_count=truncated_count
+        ) as ctx:
+            try:
+                response = self.session.post(
+                    f"{self.embedding_url}/api/embed",
+                    json={"model": model, "input": truncated_texts},
+                    timeout=self._timeouts["embed_batch"],
+                )
+                response.raise_for_status()
+                data = response.json()
+                embeddings_list = data.get("embeddings", [])
+
+                if not embeddings_list:
+                    raise ValueError("Empty embeddings returned from Ollama")
+
+                results = []
+                for embedding_data in embeddings_list:
+                    embedding = tuple(embedding_data) if embedding_data else ()
+                    if embedding:
+                        self._current_dimensions = len(embedding)
+
+                    results.append(
+                        EmbeddingResponse(
+                            embedding=embedding,
+                            model=model,
+                            provider=self.provider_name,
+                            dimensions=len(embedding),
+                        )
+                    )
+
+                ctx["dimensions"] = self._current_dimensions
+                return results
+            except requests.RequestException as e:
+                raise IndexingError("Ollama batch embed failed", e) from e
+
+    async def embed_batch_async(
+        self,
+        texts: list[str],
+        model: str,
+        max_concurrent: int = 10,  # kept for API compatibility, no longer used
+    ) -> list[EmbeddingResponse]:
+        """Generate embeddings for multiple texts asynchronously.
+
+        The /api/embed endpoint supports batch inputs natively, so this
+        makes a single async HTTP request for all texts.
+
+        Parameters
+        ----------
+        texts : list[str]
+            Texts to embed.
+        model : str
+            Embedding model name.
+        max_concurrent : int
+            Deprecated, kept for API compatibility. No longer used since
+            the API now supports native batching.
+
+        Returns
+        -------
+        list[EmbeddingResponse]
+            Embeddings in the same order as input texts.
+
+        Examples
+        --------
+        >>> import trio
+        >>> embeddings = trio.run(provider.embed_batch_async, texts, "mxbai-embed-large")
+        """
+        self._current_embed_model = model
+        self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
+
+        # Truncate oversized inputs
+        truncated_texts = [text[: self.MAX_EMBED_CHARS] if len(text) > self.MAX_EMBED_CHARS else text for text in texts]
+
+        try:
+            async with httpx.AsyncClient() as client:
+                response = await client.post(
+                    f"{self.embedding_url}/api/embed",
+                    json={"model": model, "input": truncated_texts},
+                    timeout=self._timeouts["embed_batch"],
+                )
+                response.raise_for_status()
+                data = response.json()
+
+                embeddings_list = data.get("embeddings", [])
+                if not embeddings_list:
+                    raise ValueError("Empty embeddings returned from Ollama")
+
+                results = []
+                for embedding_data in embeddings_list:
+                    embedding = tuple(embedding_data) if embedding_data else ()
+                    if embedding:
+                        self._current_dimensions = len(embedding)
+
+                    results.append(
+                        EmbeddingResponse(
+                            embedding=embedding,
+                            model=model,
+                            provider=self.provider_name,
+                            dimensions=len(embedding),
+                        )
+                    )
+                return results
+        except httpx.HTTPError as e:
+            raise IndexingError("Ollama async batch embed failed", e) from e
+
+    def chat(
+        self,
+        messages: list[dict[str, str]],
+        model: str,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+    ) -> LLMResponse:
+        """
+        Chat completion using Ollama with optional resilience.
+
+        Parameters
+        ----------
+        messages : list[dict]
+            List of messages with 'role' and 'content' keys.
+        model : str
+            Model identifier.
+        temperature : float
+            Sampling temperature.
+        max_tokens : int, optional
+            Maximum tokens to generate.
+
+        Returns
+        -------
+        LLMResponse
+            The generated response.
+        """
+        if self.use_resilience and self._generate_policy is not None:
+
+            @self._generate_policy
+            def _protected_chat() -> LLMResponse:
+                return self._do_chat(messages, model, temperature, max_tokens)
+
+            try:
+                return _protected_chat()
+            except ProtectedCallError as e:
+                logger.warning(f"Circuit breaker OPEN for ollama.chat (model={model})")
+                raise ProviderError("Ollama service unavailable - circuit breaker open", e) from e
+            except RetryLimitReached as e:
+                logger.error(f"Retry limit reached for ollama.chat (model={model}): {e.__cause__}")
+                raise ProviderError("Ollama chat failed after retries", e.__cause__) from e
+        else:
+            return self._do_chat(messages, model, temperature, max_tokens)
+
+    def _do_chat(
+        self,
+        messages: list[dict[str, str]],
+        model: str,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+    ) -> LLMResponse:
+        """Internal chat implementation (unprotected)."""
+        options: dict[str, float | int] = {"temperature": temperature}
+        if max_tokens:
+            options["num_predict"] = max_tokens
+
+        payload: dict[str, str | bool | list[dict[str, str]] | dict[str, float | int]] = {
+            "model": model,
+            "messages": messages,
+            "stream": False,
+            "options": options,
+        }
+
+        with log_operation("ollama.chat", model=model, message_count=len(messages)) as ctx:
+            try:
+                response = self.session.post(
+                    f"{self.base_url}/api/chat",
+                    json=payload,
+                    timeout=self._timeouts["chat"],
+                )
+                response.raise_for_status()
+                data = response.json()
+
+                ctx["completion_tokens"] = data.get("eval_count")
+
+                return LLMResponse(
+                    text=data.get("message", {}).get("content", ""),
+                    model=model,
+                    provider=self.provider_name,
+                    usage={
+                        "prompt_tokens": data.get("prompt_eval_count"),
+                        "completion_tokens": data.get("eval_count"),
+                    },
+                )
+            except requests.RequestException as e:
+                raise ProviderError("Ollama chat failed", e) from e
+
+    # Circuit breaker status monitoring
+    @property
+    def generate_circuit_status(self) -> str:
+        """Get generate circuit breaker status (CLOSED, OPEN, HALF_OPEN, or 'disabled')."""
+        if not self.use_resilience or self._generate_policy is None:
+            return "disabled"
+        # Access the circuit protector (second policy in SafetyNet)
+        circuit = self._generate_policy._policies[1]
+        return circuit.status.name
+
+    @property
+    def embed_circuit_status(self) -> str:
+        """Get embed circuit breaker status (CLOSED, OPEN, HALF_OPEN, or 'disabled')."""
+        if not self.use_resilience or self._embed_policy is None:
+            return "disabled"
+        circuit = self._embed_policy._policies[1]
+        return circuit.status.name
+
+    @staticmethod
+    def clear_embedding_cache() -> None:
+        """Clear the embedding cache."""
+        _cached_embedding.cache_clear()
+
+    @staticmethod
+    def embedding_cache_info() -> dict[str, int]:
+        """Get embedding cache statistics."""
+        info = _cached_embedding.cache_info()
+        return {
+            "hits": info.hits,
+            "misses": info.misses,
+            "maxsize": info.maxsize or 0,
+            "currsize": info.currsize,
+        }
+
+
+# Export the EMBEDDING_DIMENSIONS for external use
+EMBEDDING_DIMENSIONS = OllamaProvider.EMBEDDING_DIMENSIONS