ragit 0.8.2__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff compares the content of publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- ragit/__init__.py +27 -15
- ragit/assistant.py +431 -40
- ragit/config.py +165 -22
- ragit/core/experiment/experiment.py +7 -1
- ragit/exceptions.py +271 -0
- ragit/loaders.py +200 -44
- ragit/logging.py +194 -0
- ragit/monitor.py +307 -0
- ragit/providers/__init__.py +1 -13
- ragit/providers/ollama.py +379 -121
- ragit/utils/__init__.py +0 -22
- ragit/version.py +1 -1
- {ragit-0.8.2.dist-info → ragit-0.11.0.dist-info}/METADATA +48 -25
- ragit-0.11.0.dist-info/RECORD +22 -0
- {ragit-0.8.2.dist-info → ragit-0.11.0.dist-info}/WHEEL +1 -1
- ragit/providers/sentence_transformers.py +0 -225
- ragit-0.8.2.dist-info/RECORD +0 -20
- {ragit-0.8.2.dist-info → ragit-0.11.0.dist-info}/licenses/LICENSE +0 -0
- {ragit-0.8.2.dist-info → ragit-0.11.0.dist-info}/top_level.txt +0 -0
ragit/providers/ollama.py
CHANGED
@@ -10,17 +10,32 @@ Configuration is loaded from environment variables.
 
 Performance optimizations:
 - Connection pooling via requests.Session()
-- Async parallel embedding via
+- Async parallel embedding via httpx
 - LRU cache for repeated embedding queries
+
+Resilience features (via resilient-circuit):
+- Retry with exponential backoff
+- Circuit breaker pattern for fault tolerance
 """
 
+from datetime import timedelta
+from fractions import Fraction
 from functools import lru_cache
 from typing import Any
 
 import httpx
 import requests
+from resilient_circuit import (
+    CircuitProtectorPolicy,
+    ExponentialDelay,
+    RetryWithBackoffPolicy,
+    SafetyNet,
+)
+from resilient_circuit.exceptions import ProtectedCallError, RetryLimitReached
 
 from ragit.config import config
+from ragit.exceptions import IndexingError, ProviderError
+from ragit.logging import log_operation, logger
 from ragit.providers.base import (
     BaseEmbeddingProvider,
     BaseLLMProvider,
@@ -29,14 +44,70 @@ from ragit.providers.base import (
 )
 
 
+def _create_generate_policy() -> SafetyNet:
+    """Create resilience policy for LLM generation (longer timeouts, more tolerant)."""
+    return SafetyNet(
+        policies=(
+            RetryWithBackoffPolicy(
+                max_retries=3,
+                backoff=ExponentialDelay(
+                    min_delay=timedelta(seconds=1),
+                    max_delay=timedelta(seconds=30),
+                    factor=2,
+                    jitter=0.1,
+                ),
+                should_handle=lambda e: isinstance(e, (ConnectionError, TimeoutError, requests.RequestException)),
+            ),
+            CircuitProtectorPolicy(
+                resource_key="ollama_generate",
+                cooldown=timedelta(seconds=60),
+                failure_limit=Fraction(3, 10),  # 30% failure rate trips circuit
+                success_limit=Fraction(4, 5),  # 80% success to close
+                should_handle=lambda e: isinstance(e, (ConnectionError, requests.RequestException)),
+            ),
+        )
+    )
+
+
+def _create_embed_policy() -> SafetyNet:
+    """Create resilience policy for embeddings (faster, stricter)."""
+    return SafetyNet(
+        policies=(
+            RetryWithBackoffPolicy(
+                max_retries=2,
+                backoff=ExponentialDelay(
+                    min_delay=timedelta(milliseconds=500),
+                    max_delay=timedelta(seconds=5),
+                    factor=2,
+                    jitter=0.1,
+                ),
+                should_handle=lambda e: isinstance(e, (ConnectionError, TimeoutError, requests.RequestException)),
+            ),
+            CircuitProtectorPolicy(
+                resource_key="ollama_embed",
+                cooldown=timedelta(seconds=30),
+                failure_limit=Fraction(2, 5),  # 40% failure rate trips circuit
+                success_limit=Fraction(3, 3),  # All 3 tests must succeed to close
+                should_handle=lambda e: isinstance(e, (ConnectionError, requests.RequestException)),
+            ),
+        )
+    )
+
+
+def _truncate_text(text: str, max_chars: int = 2000) -> str:
+    """Truncate text to max_chars. Used BEFORE cache lookup to fix cache key bug."""
+    return text[:max_chars] if len(text) > max_chars else text
+
+
 # Module-level cache for embeddings (shared across instances)
+# NOTE: Text must be truncated BEFORE calling this function to ensure correct cache keys
 @lru_cache(maxsize=2048)
 def _cached_embedding(text: str, model: str, embedding_url: str, timeout: int) -> tuple[float, ...]:
-    """Cache embedding results to avoid redundant API calls.
-    # Truncate oversized inputs
-    if len(text) > OllamaProvider.MAX_EMBED_CHARS:
-        text = text[: OllamaProvider.MAX_EMBED_CHARS]
+    """Cache embedding results to avoid redundant API calls.
 
+    IMPORTANT: Caller must truncate text BEFORE calling this function.
+    This ensures cache keys are consistent for truncated inputs.
+    """
     response = requests.post(
         f"{embedding_url}/api/embed",
         headers={"Content-Type": "application/json"},
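The two policy factories above are consumed later in this diff by applying the SafetyNet as a decorator around the raw provider calls. A minimal sketch of that pattern, reusing only names introduced in the diff (the flaky_call body is a hypothetical stand-in, not ragit code):

from ragit.providers.ollama import _create_generate_policy
from resilient_circuit.exceptions import ProtectedCallError, RetryLimitReached

policy = _create_generate_policy()  # SafetyNet: retry-with-backoff + circuit breaker

@policy
def flaky_call() -> str:
    # A ConnectionError / TimeoutError / requests.RequestException raised here is
    # retried up to 3 times; sustained failures open the "ollama_generate" circuit.
    return "ok"

try:
    flaky_call()
except ProtectedCallError:
    print("circuit open - Ollama treated as unavailable for the cooldown period")
except RetryLimitReached as e:
    print(f"gave up after retries: {e.__cause__}")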
@@ -97,43 +168,90 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
     # Max characters per embedding request (safe limit for 512 token models)
     MAX_EMBED_CHARS = 2000
 
+    # Default timeouts per operation type (in seconds)
+    DEFAULT_TIMEOUTS: dict[str, int] = {
+        "generate": 300,  # 5 minutes for LLM generation
+        "chat": 300,  # 5 minutes for chat
+        "embed": 30,  # 30 seconds for single embedding
+        "embed_batch": 120,  # 2 minutes for batch embedding
+        "health": 5,  # 5 seconds for health check
+        "list_models": 10,  # 10 seconds for listing models
+    }
+
     def __init__(
         self,
         base_url: str | None = None,
         embedding_url: str | None = None,
         api_key: str | None = None,
         timeout: int | None = None,
+        timeouts: dict[str, int] | None = None,
         use_cache: bool = True,
+        use_resilience: bool = True,
     ) -> None:
         self.base_url = (base_url or config.OLLAMA_BASE_URL).rstrip("/")
         self.embedding_url = (embedding_url or config.OLLAMA_EMBEDDING_URL).rstrip("/")
         self.api_key = api_key or config.OLLAMA_API_KEY
-        self.timeout = timeout or config.OLLAMA_TIMEOUT
         self.use_cache = use_cache
+        self.use_resilience = use_resilience
         self._current_embed_model: str | None = None
         self._current_dimensions: int = 768  # default
 
+        # Per-operation timeouts (merge user overrides with defaults)
+        self._timeouts = {**self.DEFAULT_TIMEOUTS, **(timeouts or {})}
+        # Legacy single timeout parameter overrides all operations
+        if timeout is not None:
+            self._timeouts = {k: timeout for k in self._timeouts}
+        # Keep legacy timeout property for backwards compatibility
+        self.timeout = timeout or config.OLLAMA_TIMEOUT
+
         # Connection pooling via session
         self._session: requests.Session | None = None
 
+        # Resilience policies (retry + circuit breaker)
+        self._generate_policy: SafetyNet | None = None
+        self._embed_policy: SafetyNet | None = None
+        if use_resilience:
+            self._generate_policy = _create_generate_policy()
+            self._embed_policy = _create_embed_policy()
+
     @property
     def session(self) -> requests.Session:
-        """Lazy-initialized session for connection pooling.
+        """Lazy-initialized session for connection pooling.
+
+        Note: API key is NOT stored in session headers to prevent
+        potential exposure in logs or error messages. Authentication
+        is handled per-request via _get_headers().
+        """
         if self._session is None:
             self._session = requests.Session()
             self._session.headers.update({"Content-Type": "application/json"})
-
-
+            # Security: API key is injected per-request via _get_headers()
+            # rather than stored in session headers to prevent log exposure
         return self._session
 
     def close(self) -> None:
         """Close the session and release resources."""
-
-
+        session = getattr(self, "_session", None)
+        if session is not None:
+            session.close()
         self._session = None
 
+    def __enter__(self) -> "OllamaProvider":
+        """Context manager entry - returns self for use in 'with' statements.
+
+        Example:
+            with OllamaProvider() as provider:
+                result = provider.generate("Hello", model="llama3")
+            # Session automatically closed here
+        """
+        return self
+
+    def __exit__(self, exc_type: type | None, exc_val: Exception | None, exc_tb: object) -> None:
+        """Context manager exit - ensures cleanup regardless of exceptions."""
+        self.close()
+
     def __del__(self) -> None:
-        """Cleanup on garbage collection."""
+        """Cleanup on garbage collection (fallback, prefer context manager)."""
        self.close()
 
     def _get_headers(self, include_auth: bool = True) -> dict[str, str]:
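For callers, the practical additions in this hunk are the per-operation timeouts dict and context-manager support. A short usage sketch based only on the constructor parameters and the docstring example shown above; the timeout value is illustrative:

from ragit.providers.ollama import OllamaProvider

# Override one operation's timeout; the rest fall back to DEFAULT_TIMEOUTS.
with OllamaProvider(timeouts={"generate": 600}, use_resilience=True) as provider:
    result = provider.generate("Hello", model="llama3")
    print(result.text)
# __exit__ calls close(), so the pooled requests.Session is released here.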
@@ -156,7 +274,8 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
         try:
             response = self.session.get(
                 f"{self.base_url}/api/tags",
-
+                headers=self._get_headers(),
+                timeout=self._timeouts["health"],
             )
             return bool(response.status_code == 200)
         except requests.RequestException:
@@ -167,13 +286,14 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
         try:
             response = self.session.get(
                 f"{self.base_url}/api/tags",
-
+                headers=self._get_headers(),
+                timeout=self._timeouts["list_models"],
             )
             response.raise_for_status()
             data = response.json()
             return list(data.get("models", []))
         except requests.RequestException as e:
-            raise
+            raise ProviderError("Failed to list Ollama models", e) from e
 
     def generate(
         self,
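Where the old code re-raised the bare requests.RequestException, callers now receive a ProviderError. A hedged sketch, assuming the surrounding method is the model-listing call that the "list_models" timeout key refers to (the method name is not shown in this hunk):

from ragit.exceptions import ProviderError
from ragit.providers.ollama import OllamaProvider

provider = OllamaProvider()
try:
    models = provider.list_models()  # assumed method name for the hunk above
except ProviderError as exc:
    models = []
    print(f"could not reach Ollama: {exc}")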
@@ -183,7 +303,33 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
         temperature: float = 0.7,
         max_tokens: int | None = None,
     ) -> LLMResponse:
-        """Generate text using Ollama."""
+        """Generate text using Ollama with optional resilience (retry + circuit breaker)."""
+        if self.use_resilience and self._generate_policy is not None:
+
+            @self._generate_policy
+            def _protected_generate() -> LLMResponse:
+                return self._do_generate(prompt, model, system_prompt, temperature, max_tokens)
+
+            try:
+                return _protected_generate()
+            except ProtectedCallError as e:
+                logger.warning(f"Circuit breaker OPEN for ollama.generate (model={model})")
+                raise ProviderError("Ollama service unavailable - circuit breaker open", e) from e
+            except RetryLimitReached as e:
+                logger.error(f"Retry limit reached for ollama.generate (model={model}): {e.__cause__}")
+                raise ProviderError("Ollama generate failed after retries", e.__cause__) from e
+        else:
+            return self._do_generate(prompt, model, system_prompt, temperature, max_tokens)
+
+    def _do_generate(
+        self,
+        prompt: str,
+        model: str,
+        system_prompt: str | None = None,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+    ) -> LLMResponse:
+        """Internal generate implementation (unprotected)."""
         options: dict[str, float | int] = {"temperature": temperature}
         if max_tokens:
             options["num_predict"] = max_tokens
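With the wrapper above, every failure mode of generate() (open circuit, exhausted retries, plain request error) surfaces as ProviderError, so callers need a single except clause. A minimal sketch using only names from this diff; the prompt text is illustrative:

from ragit.exceptions import ProviderError
from ragit.providers.ollama import OllamaProvider

provider = OllamaProvider()
try:
    reply = provider.generate("Summarize this repo in one line.", model="llama3")
    print(reply.text)
except ProviderError as exc:
    # Covers circuit-open, retry exhaustion, and direct request failures alike.
    print(f"generation unavailable: {exc}")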
@@ -198,105 +344,162 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
         if system_prompt:
             payload["system"] = system_prompt
 
-
-
-
-
-
-
-
-
+        with log_operation("ollama.generate", model=model, prompt_len=len(prompt)) as ctx:
+            try:
+                response = self.session.post(
+                    f"{self.base_url}/api/generate",
+                    headers=self._get_headers(),
+                    json=payload,
+                    timeout=self._timeouts["generate"],
+                )
+                response.raise_for_status()
+                data = response.json()
 
-
-
-
-
-
-
-
-
-
-
-
-
+                ctx["completion_tokens"] = data.get("eval_count")
+
+                return LLMResponse(
+                    text=data.get("response", ""),
+                    model=model,
+                    provider=self.provider_name,
+                    usage={
+                        "prompt_tokens": data.get("prompt_eval_count"),
+                        "completion_tokens": data.get("eval_count"),
+                        "total_duration": data.get("total_duration"),
+                    },
+                )
+            except requests.RequestException as e:
+                raise ProviderError("Ollama generate failed", e) from e
 
     def embed(self, text: str, model: str) -> EmbeddingResponse:
-        """Generate embedding using Ollama with optional caching."""
+        """Generate embedding using Ollama with optional caching and resilience."""
+        if self.use_resilience and self._embed_policy is not None:
+
+            @self._embed_policy
+            def _protected_embed() -> EmbeddingResponse:
+                return self._do_embed(text, model)
+
+            try:
+                return _protected_embed()
+            except ProtectedCallError as e:
+                logger.warning(f"Circuit breaker OPEN for ollama.embed (model={model})")
+                raise ProviderError("Ollama embedding service unavailable - circuit breaker open", e) from e
+            except RetryLimitReached as e:
+                logger.error(f"Retry limit reached for ollama.embed (model={model}): {e.__cause__}")
+                raise IndexingError("Ollama embed failed after retries", e.__cause__) from e
+        else:
+            return self._do_embed(text, model)
+
+    def _do_embed(self, text: str, model: str) -> EmbeddingResponse:
+        """Internal embed implementation (unprotected)."""
         self._current_embed_model = model
         self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
 
-
-
-
-
-
-
-
-
-
-
-
+        # Truncate BEFORE cache lookup (fixes cache key bug)
+        truncated_text = _truncate_text(text, self.MAX_EMBED_CHARS)
+        was_truncated = len(text) > self.MAX_EMBED_CHARS
+
+        with log_operation("ollama.embed", model=model, text_len=len(text), truncated=was_truncated) as ctx:
+            try:
+                if self.use_cache:
+                    # Use cached version with truncated text
+                    embedding = _cached_embedding(truncated_text, model, self.embedding_url, self._timeouts["embed"])
+                    ctx["cache"] = "hit_or_miss"  # Can't tell from here
+                else:
+                    # Direct call without cache
+                    response = self.session.post(
+                        f"{self.embedding_url}/api/embed",
+                        headers=self._get_headers(),
+                        json={"model": model, "input": truncated_text},
+                        timeout=self._timeouts["embed"],
+                    )
+                    response.raise_for_status()
+                    data = response.json()
+                    embeddings = data.get("embeddings", [])
+                    if not embeddings or not embeddings[0]:
+                        raise ValueError("Empty embedding returned from Ollama")
+                    embedding = tuple(embeddings[0])
+                    ctx["cache"] = "disabled"
+
+                # Update dimensions from actual response
+                self._current_dimensions = len(embedding)
+                ctx["dimensions"] = len(embedding)
+
+                return EmbeddingResponse(
+                    embedding=embedding,
+                    model=model,
+                    provider=self.provider_name,
+                    dimensions=len(embedding),
                 )
-
-
-            embeddings = data.get("embeddings", [])
-            if not embeddings or not embeddings[0]:
-                raise ValueError("Empty embedding returned from Ollama")
-            embedding = tuple(embeddings[0])
-
-            # Update dimensions from actual response
-            self._current_dimensions = len(embedding)
-
-            return EmbeddingResponse(
-                embedding=embedding,
-                model=model,
-                provider=self.provider_name,
-                dimensions=len(embedding),
-            )
-        except requests.RequestException as e:
-            raise ConnectionError(f"Ollama embed failed: {e}") from e
+            except requests.RequestException as e:
+                raise IndexingError("Ollama embed failed", e) from e
 
     def embed_batch(self, texts: list[str], model: str) -> list[EmbeddingResponse]:
-        """Generate embeddings for multiple texts in a single API call.
+        """Generate embeddings for multiple texts in a single API call with resilience.
 
         The /api/embed endpoint supports batch inputs natively.
         """
+        if self.use_resilience and self._embed_policy is not None:
+
+            @self._embed_policy
+            def _protected_embed_batch() -> list[EmbeddingResponse]:
+                return self._do_embed_batch(texts, model)
+
+            try:
+                return _protected_embed_batch()
+            except ProtectedCallError as e:
+                logger.warning(f"Circuit breaker OPEN for ollama.embed_batch (model={model}, batch_size={len(texts)})")
+                raise ProviderError("Ollama embedding service unavailable - circuit breaker open", e) from e
+            except RetryLimitReached as e:
+                logger.error(f"Retry limit reached for ollama.embed_batch (model={model}): {e.__cause__}")
+                raise IndexingError("Ollama batch embed failed after retries", e.__cause__) from e
+        else:
+            return self._do_embed_batch(texts, model)
+
+    def _do_embed_batch(self, texts: list[str], model: str) -> list[EmbeddingResponse]:
+        """Internal batch embed implementation (unprotected)."""
         self._current_embed_model = model
         self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
 
         # Truncate oversized inputs
-        truncated_texts = [text
-
-        try:
-            response = self.session.post(
-                f"{self.embedding_url}/api/embed",
-                json={"model": model, "input": truncated_texts},
-                timeout=self.timeout,
-            )
-            response.raise_for_status()
-            data = response.json()
-            embeddings_list = data.get("embeddings", [])
-
-            if not embeddings_list:
-                raise ValueError("Empty embeddings returned from Ollama")
-
-            results = []
-            for embedding_data in embeddings_list:
-                embedding = tuple(embedding_data) if embedding_data else ()
-                if embedding:
-                    self._current_dimensions = len(embedding)
+        truncated_texts = [_truncate_text(text, self.MAX_EMBED_CHARS) for text in texts]
+        truncated_count = sum(1 for t, tt in zip(texts, truncated_texts, strict=True) if len(t) != len(tt))
 
-
-
-
-
-
-
-                )
+        with log_operation(
+            "ollama.embed_batch", model=model, batch_size=len(texts), truncated_count=truncated_count
+        ) as ctx:
+            try:
+                response = self.session.post(
+                    f"{self.embedding_url}/api/embed",
+                    headers=self._get_headers(),
+                    json={"model": model, "input": truncated_texts},
+                    timeout=self._timeouts["embed_batch"],
                 )
-
-
-
+                response.raise_for_status()
+                data = response.json()
+                embeddings_list = data.get("embeddings", [])
+
+                if not embeddings_list:
+                    raise ValueError("Empty embeddings returned from Ollama")
+
+                results = []
+                for embedding_data in embeddings_list:
+                    embedding = tuple(embedding_data) if embedding_data else ()
+                    if embedding:
+                        self._current_dimensions = len(embedding)
+
+                    results.append(
+                        EmbeddingResponse(
+                            embedding=embedding,
+                            model=model,
+                            provider=self.provider_name,
+                            dimensions=len(embedding),
+                        )
+                    )
+
+                ctx["dimensions"] = self._current_dimensions
+                return results
+            except requests.RequestException as e:
+                raise IndexingError("Ollama batch embed failed", e) from e
 
     async def embed_batch_async(
         self,
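The cache-key fix in _do_embed() is easy to miss: previously the text was truncated inside _cached_embedding(), after the lru_cache key had already been computed from the full string, so two inputs differing only beyond MAX_EMBED_CHARS missed the cache and triggered separate API calls. A self-contained sketch of the corrected ordering, using pure stdlib and a stand-in for the HTTP call:

from functools import lru_cache

MAX_EMBED_CHARS = 2000

@lru_cache(maxsize=2048)
def fake_embed(text: str) -> int:
    return len(text)  # stand-in for the real /api/embed request

a = "x" * 2500
b = "x" * 3000  # differs from `a` only past the truncation limit

fake_embed(a[:MAX_EMBED_CHARS])  # miss: performs the "request"
fake_embed(b[:MAX_EMBED_CHARS])  # hit: same truncated key, no second "request"
print(fake_embed.cache_info().hits)  # 1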
@@ -326,8 +529,8 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
 
         Examples
         --------
-        >>> import
-        >>> embeddings =
+        >>> import asyncio
+        >>> embeddings = asyncio.run(provider.embed_batch_async(texts, "mxbai-embed-large"))
         """
         self._current_embed_model = model
         self._current_dimensions = self.EMBEDDING_DIMENSIONS.get(model, 768)
@@ -340,7 +543,7 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
             response = await client.post(
                 f"{self.embedding_url}/api/embed",
                 json={"model": model, "input": truncated_texts},
-                timeout=self.
+                timeout=self._timeouts["embed_batch"],
             )
             response.raise_for_status()
             data = response.json()
@@ -365,7 +568,7 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
                 )
             return results
         except httpx.HTTPError as e:
-            raise
+            raise IndexingError("Ollama async batch embed failed", e) from e
 
     def chat(
         self,
@@ -375,7 +578,7 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
         max_tokens: int | None = None,
     ) -> LLMResponse:
         """
-        Chat completion using Ollama.
+        Chat completion using Ollama with optional resilience.
 
         Parameters
         ----------
@@ -393,6 +596,31 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
         LLMResponse
             The generated response.
         """
+        if self.use_resilience and self._generate_policy is not None:
+
+            @self._generate_policy
+            def _protected_chat() -> LLMResponse:
+                return self._do_chat(messages, model, temperature, max_tokens)
+
+            try:
+                return _protected_chat()
+            except ProtectedCallError as e:
+                logger.warning(f"Circuit breaker OPEN for ollama.chat (model={model})")
+                raise ProviderError("Ollama service unavailable - circuit breaker open", e) from e
+            except RetryLimitReached as e:
+                logger.error(f"Retry limit reached for ollama.chat (model={model}): {e.__cause__}")
+                raise ProviderError("Ollama chat failed after retries", e.__cause__) from e
+        else:
+            return self._do_chat(messages, model, temperature, max_tokens)
+
+    def _do_chat(
+        self,
+        messages: list[dict[str, str]],
+        model: str,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+    ) -> LLMResponse:
+        """Internal chat implementation (unprotected)."""
         options: dict[str, float | int] = {"temperature": temperature}
         if max_tokens:
             options["num_predict"] = max_tokens
@@ -404,26 +632,56 @@ class OllamaProvider(BaseLLMProvider, BaseEmbeddingProvider):
             "options": options,
         }
 
-
-
-
-
-
-
-
-
+        with log_operation("ollama.chat", model=model, message_count=len(messages)) as ctx:
+            try:
+                response = self.session.post(
+                    f"{self.base_url}/api/chat",
+                    headers=self._get_headers(),
+                    json=payload,
+                    timeout=self._timeouts["chat"],
+                )
+                response.raise_for_status()
+                data = response.json()
 
-
-
-
-
-
-
-
-
-
-
-
+                ctx["completion_tokens"] = data.get("eval_count")
+
+                return LLMResponse(
+                    text=data.get("message", {}).get("content", ""),
+                    model=model,
+                    provider=self.provider_name,
+                    usage={
+                        "prompt_tokens": data.get("prompt_eval_count"),
+                        "completion_tokens": data.get("eval_count"),
+                    },
+                )
+            except requests.RequestException as e:
+                raise ProviderError("Ollama chat failed", e) from e
+
+    # Circuit breaker status monitoring
+    @property
+    def generate_circuit_status(self) -> str:
+        """Get generate circuit breaker status (CLOSED, OPEN, HALF_OPEN, or 'disabled')."""
+        if not self.use_resilience or self._generate_policy is None:
+            return "disabled"
+        # Access the circuit protector (second policy in SafetyNet)
+        policies = getattr(self._generate_policy, "policies", None)
+        if policies is None or len(policies) < 2:
+            return "unknown"
+        circuit = policies[1]
+        status = getattr(circuit, "status", None)
+        return str(getattr(status, "name", "unknown"))
+
+    @property
+    def embed_circuit_status(self) -> str:
+        """Get embed circuit breaker status (CLOSED, OPEN, HALF_OPEN, or 'disabled')."""
+        if not self.use_resilience or self._embed_policy is None:
+            return "disabled"
+        policies = getattr(self._embed_policy, "policies", None)
+        if policies is None or len(policies) < 2:
+            return "unknown"
+        circuit = policies[1]
+        status = getattr(circuit, "status", None)
+        return str(getattr(status, "name", "unknown"))
 
     @staticmethod
     def clear_embedding_cache() -> None: