deepeval 3.5.2__py3-none-any.whl → 3.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,30 +1,34 @@
  """Generic retry policy helpers for provider SDKs.

  This module lets models define *what is transient* vs *non-retryable* (permanent) failure
- without coupling to a specific SDK. You provide an `ErrorPolicy` describing
- exception classes and special “non-retryable” error codes, such as quota-exhausted from OpenAI,
- and get back a Tenacity predicate suitable for `retry_if_exception`.
-
- Typical use:
-
-     # Import dependencies
-     from tenacity import retry, before_sleep_log
-     from deepeval.models.retry_policy import (
-         OPENAI_ERROR_POLICY, default_wait, default_stop, retry_predicate
-     )
-
-     # Define retry rule keywords
-     _retry_kw = dict(
-         wait=default_wait(),
-         stop=default_stop(),
-         retry=retry_predicate(OPENAI_ERROR_POLICY),
-         before_sleep=before_sleep_log(logger, logging.INFO),  # <- Optional: logs only on retries
-     )
-
-     # Apply retry rule keywords where desired
-     @retry(**_retry_kw)
-     def call_openai(...):
-         ...
+ without coupling to a specific SDK. You provide an `ErrorPolicy` describing exception classes
+ and special “non-retryable” error codes (quota-exhausted), and get back Tenacity components:
+ a predicate suitable for `retry_if_exception`, plus convenience helpers for wait/stop/backoff.
+ You can also use `create_retry_decorator(slug)` to wire Tenacity with dynamic policy + logging.
+
+ Notes:
+ - `extract_error_code` best-effort parses codes from response JSON, `e.body`, botocore-style maps,
+   gRPC `e.code().name`, or message markers.
+ - `dynamic_retry(slug)` consults settings at call time: if SDK retries are enabled for the slug,
+   Tenacity will not retry.
+ - Logging callbacks (`before_sleep`, `after`) read log levels dynamically and log to
+   the `deepeval.retry.<slug>` logger.
+
+ Configuration
+ -------------
+ Retry backoff (env):
+     DEEPEVAL_RETRY_MAX_ATTEMPTS     int   (default 2, >=1)
+     DEEPEVAL_RETRY_INITIAL_SECONDS  float (default 1.0, >=0)
+     DEEPEVAL_RETRY_EXP_BASE         float (default 2.0, >=1)
+     DEEPEVAL_RETRY_JITTER           float (default 2.0, >=0)
+     DEEPEVAL_RETRY_CAP_SECONDS      float (default 5.0, >=0)
+
+ SDK-managed retries (settings):
+     settings.DEEPEVAL_SDK_RETRY_PROVIDERS  list[str]  # e.g. ["azure"] or ["*"] for all
+
+ Retry logging (settings; read at call time):
+     settings.DEEPEVAL_RETRY_BEFORE_LOG_LEVEL  int/name (default INFO)
+     settings.DEEPEVAL_RETRY_AFTER_LOG_LEVEL   int/name (default ERROR)
  """

  from __future__ import annotations
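
For orientation, a minimal usage sketch of the decorator described in the new docstring (not part of the diff itself), assuming the module path `deepeval.models.retry_policy` from the removed example still applies; `generate_answer`, the "openai" slug value, and the env values are illustrative:

    import os
    from deepeval.models.retry_policy import create_retry_decorator

    os.environ["DEEPEVAL_RETRY_MAX_ATTEMPTS"] = "3"   # read at call time by the dynamic stop
    os.environ["DEEPEVAL_RETRY_CAP_SECONDS"] = "10"   # caps the exponential backoff

    retry_openai = create_retry_decorator("openai")   # slug-based wiring per the diff

    @retry_openai
    def generate_answer(prompt: str) -> str:
        ...  # hypothetical provider call; transient failures retry, non-retryable codes do not
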
@@ -33,16 +37,27 @@ import logging

  from deepeval.utils import read_env_int, read_env_float
  from dataclasses import dataclass, field
- from typing import Iterable, Mapping, Callable, Sequence, Tuple
+ from typing import Callable, Iterable, Mapping, Optional, Sequence, Tuple, Union
  from collections.abc import Mapping as ABCMapping
  from tenacity import (
+     RetryCallState,
+     retry,
      wait_exponential_jitter,
      stop_after_attempt,
      retry_if_exception,
  )
+ from tenacity.stop import stop_base
+ from tenacity.wait import wait_base
+
+ from deepeval.constants import (
+     ProviderSlug as PS,
+     slugify,
+ )
+ from deepeval.config.settings import get_settings


  logger = logging.getLogger(__name__)
+ Provider = Union[str, PS]

  # --------------------------
  # Policy description
@@ -89,9 +104,9 @@ def extract_error_code(
      """Best effort extraction of an error 'code' for SDK compatibility.

      Order of attempts:
-     1) Structured JSON via `e.response.json()` (typical HTTP error payload).
-     2) A dict stored on `e.body` (some gateways/proxies use this).
-     3) Message sniffing fallback, using `message_markers`.
+     1. Structured JSON via `e.response.json()` (typical HTTP error payload).
+     2. A dict stored on `e.body` (some gateways/proxies use this).
+     3. Message sniffing fallback, using `message_markers`.

      Args:
          e: The exception raised by the SDK/provider client.
@@ -103,23 +118,47 @@ def extract_error_code(
      Returns:
          The code string if found, else "".
      """
-     # 1) Structured JSON in e.response.json()
+     # 0. gRPC: use e.code() -> grpc.StatusCode
+     code_fn = getattr(e, "code", None)
+     if callable(code_fn):
+         try:
+             sc = code_fn()
+             name = getattr(sc, "name", None) or str(sc)
+             if isinstance(name, str):
+                 return name.lower()
+         except Exception:
+             pass
+
+     # 1. Structured JSON in e.response.json()
      resp = getattr(e, response_attr, None)
      if resp is not None:
-         try:
-             cur = resp.json()
-             for k in code_path:
+
+         if isinstance(resp, ABCMapping):
+             # Structured mapping directly on response
+             cur = resp
+             for k in ("Error", "Code"):  # <- AWS boto style Error / Code
                  if not isinstance(cur, ABCMapping):
                      cur = {}
                      break
                  cur = cur.get(k, {})
              if isinstance(cur, (str, int)):
                  return str(cur)
-         except Exception:
-             # response.json() can raise; ignore and fall through
-             pass

-     # 2) SDK provided dict body
+         else:
+             try:
+                 cur = resp.json()
+                 for k in code_path:
+                     if not isinstance(cur, ABCMapping):
+                         cur = {}
+                         break
+                     cur = cur.get(k, {})
+                 if isinstance(cur, (str, int)):
+                     return str(cur)
+             except Exception:
+                 # if response.json() raises, ignore and fall through
+                 pass
+
+     # 2. SDK provided dict body
      body = getattr(e, body_attr, None)
      if isinstance(body, ABCMapping):
          cur = body
@@ -131,7 +170,7 @@ def extract_error_code(
          if isinstance(cur, (str, int)):
              return str(cur)

-     # 3) Message sniff (hopefully this helps catch message codes that slip past the previous 2 parsers)
+     # 3. Message sniff (hopefully this helps catch message codes that slip past the previous 2 parsers)
      msg = str(e).lower()
      markers = message_markers or {}
      for code_key, needles in markers.items():
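
As a sketch of the extraction order above: a botocore-style exception that carries a mapping on `.response` short-circuits at step 1, before any message sniffing. The exception class below is a made-up stand-in, not an SDK type, and assumes the default attribute names documented in the docstring:

    from deepeval.models.retry_policy import extract_error_code

    class FakeBotoError(Exception):  # hypothetical stand-in for a botocore ClientError
        def __init__(self) -> None:
            super().__init__("An error occurred (ThrottlingException)")
            self.response = {"Error": {"Code": "ThrottlingException"}}  # mapping on .response

    print(extract_error_code(FakeBotoError()))  # -> "ThrottlingException"
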
@@ -181,6 +220,7 @@ def make_is_transient(
              code = extract_error_code(
                  e, message_markers=(message_markers or policy.message_markers)
              )
+             code = (code or "").lower()
              return code not in non_retryable

          if isinstance(e, policy.network_excs):
@@ -203,32 +243,31 @@ def make_is_transient(
  # --------------------------


- def default_wait():
-     """Default backoff: exponential with jitter, capped.
-     Overridable via env:
-       - DEEPEVAL_RETRY_INITIAL_SECONDS (>=0)
-       - DEEPEVAL_RETRY_EXP_BASE (>=1)
-       - DEEPEVAL_RETRY_JITTER (>=0)
-       - DEEPEVAL_RETRY_CAP_SECONDS (>=0)
-     """
-     initial = read_env_float(
-         "DEEPEVAL_RETRY_INITIAL_SECONDS", 1.0, min_value=0.0
-     )
-     exp_base = read_env_float("DEEPEVAL_RETRY_EXP_BASE", 2.0, min_value=1.0)
-     jitter = read_env_float("DEEPEVAL_RETRY_JITTER", 2.0, min_value=0.0)
-     cap = read_env_float("DEEPEVAL_RETRY_CAP_SECONDS", 5.0, min_value=0.0)
-     return wait_exponential_jitter(
-         initial=initial, exp_base=exp_base, jitter=jitter, max=cap
-     )
+ class StopFromEnv(stop_base):
+     def __call__(self, retry_state):
+         attempts = read_env_int("DEEPEVAL_RETRY_MAX_ATTEMPTS", 2, min_value=1)
+         return stop_after_attempt(attempts)(retry_state)


- def default_stop():
-     """Default stop condition: at most N attempts (N-1 retries).
-     Overridable via env:
-       - DEEPEVAL_RETRY_MAX_ATTEMPTS (>=1)
-     """
-     attempts = read_env_int("DEEPEVAL_RETRY_MAX_ATTEMPTS", 2, min_value=1)
-     return stop_after_attempt(attempts)
+ class WaitFromEnv(wait_base):
+     def __call__(self, retry_state):
+         initial = read_env_float(
+             "DEEPEVAL_RETRY_INITIAL_SECONDS", 1.0, min_value=0.0
+         )
+         exp_base = read_env_float("DEEPEVAL_RETRY_EXP_BASE", 2.0, min_value=1.0)
+         jitter = read_env_float("DEEPEVAL_RETRY_JITTER", 2.0, min_value=0.0)
+         cap = read_env_float("DEEPEVAL_RETRY_CAP_SECONDS", 5.0, min_value=0.0)
+         return wait_exponential_jitter(
+             initial=initial, exp_base=exp_base, jitter=jitter, max=cap
+         )(retry_state)
+
+
+ def dynamic_stop():
+     return StopFromEnv()
+
+
+ def dynamic_wait():
+     return WaitFromEnv()


  def retry_predicate(policy: ErrorPolicy, **kw):
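
Because `WaitFromEnv` and `StopFromEnv` re-read the environment on every attempt, backoff can be tuned without rebuilding the decorator. A sketch using the helpers above (the decorated function and the "openai" slug are illustrative):

    import os
    from tenacity import retry
    from deepeval.models.retry_policy import dynamic_retry, dynamic_stop, dynamic_wait

    @retry(wait=dynamic_wait(), stop=dynamic_stop(), retry=dynamic_retry("openai"))
    def flaky_call():
        ...  # hypothetical provider call

    os.environ["DEEPEVAL_RETRY_MAX_ATTEMPTS"] = "5"   # picked up on the next invocation,
    os.environ["DEEPEVAL_RETRY_JITTER"] = "0.5"       # no decorator rebuild needed
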
@@ -240,11 +279,189 @@ def retry_predicate(policy: ErrorPolicy, **kw):
      return retry_if_exception(make_is_transient(policy, **kw))


+ ###########
+ # Helpers #
+ ###########
+ # Convenience helpers
+
+
+ def sdk_retries_for(provider: Provider) -> bool:
+     """True if this provider should delegate retries to the SDK (per settings)."""
+     chosen = get_settings().DEEPEVAL_SDK_RETRY_PROVIDERS or []
+     slug = slugify(provider)
+     return "*" in chosen or slug in chosen
+
+
+ def get_retry_policy_for(provider: Provider) -> Optional[ErrorPolicy]:
+     """
+     Return the ErrorPolicy for a given provider slug, or None when:
+       - the user requested SDK-managed retries for this provider, OR
+       - we have no usable policy (optional dependency missing).
+     """
+     if sdk_retries_for(provider):
+         return None
+     slug = slugify(provider)
+     return _POLICY_BY_SLUG.get(slug) or None
+
+
+ def dynamic_retry(provider: Provider):
+     """
+     Tenacity retry= argument that checks settings at *call time*.
+     If SDK retries are chosen (or no policy available), it never retries.
+     """
+     slug = slugify(provider)
+     static_pred = _STATIC_PRED_BY_SLUG.get(slug)
+
+     def _pred(e: Exception) -> bool:
+         if sdk_retries_for(slug):
+             return False  # hand off to SDK
+         if static_pred is None:
+             return False  # no policy -> no Tenacity retries
+         return static_pred(e)  # use prebuilt predicate
+
+     return retry_if_exception(_pred)
+
+
+ def _retry_log_levels():
+     s = get_settings()
+     before_level = s.DEEPEVAL_RETRY_BEFORE_LOG_LEVEL
+     after_level = s.DEEPEVAL_RETRY_AFTER_LOG_LEVEL
+     return (
+         before_level if before_level is not None else logging.INFO,
+         after_level if after_level is not None else logging.ERROR,
+     )
+
+
+ def make_before_sleep_log(slug: str):
+     """
+     Tenacity 'before_sleep' callback: runs before Tenacity sleeps for the next retry.
+     Read the level dynamically each time.
+     """
+     _logger = logging.getLogger(f"deepeval.retry.{slug}")
+
+     def _before_sleep(retry_state: RetryCallState) -> None:
+         before_level, _ = _retry_log_levels()
+         if not _logger.isEnabledFor(before_level):
+             return
+
+         exc = retry_state.outcome.exception()
+         sleep = getattr(
+             getattr(retry_state, "next_action", None), "sleep", None
+         )
+
+         _logger.log(
+             before_level,
+             "Retrying in %s s (attempt %s) after %r",
+             sleep,
+             retry_state.attempt_number,
+             exc,
+         )
+
+     return _before_sleep
+
+
+ def make_after_log(slug: str):
+     """
+     Tenacity 'after' callback: runs after each attempt. We log only when the
+     attempt raised, and we look up the level dynamically so changes to settings
+     take effect immediately.
+     """
+     _logger = logging.getLogger(f"deepeval.retry.{slug}")
+
+     def _after(retry_state: RetryCallState) -> None:
+         exc = retry_state.outcome.exception()
+         if exc is None:
+             return
+
+         _, after_level = _retry_log_levels()
+         if not _logger.isEnabledFor(after_level):
+             return
+
+         exc_info = (
+             (type(exc), exc, getattr(exc, "__traceback__", None))
+             if after_level >= logging.ERROR
+             else None
+         )
+
+         _logger.log(
+             after_level,
+             "%s Retrying: %s time(s)...",
+             exc,
+             retry_state.attempt_number,
+             exc_info=exc_info,
+         )
+
+     return _after
+
+
+ def create_retry_decorator(provider: Provider):
+     """
+     Build a Tenacity @retry decorator wired to our dynamic retry policy
+     for the given provider slug.
+     """
+     slug = slugify(provider)
+
+     return retry(
+         wait=dynamic_wait(),
+         stop=dynamic_stop(),
+         retry=dynamic_retry(slug),
+         before_sleep=make_before_sleep_log(slug),
+         after=make_after_log(slug),
+     )
+
+
+ def _httpx_net_excs() -> tuple[type, ...]:
+     try:
+         import httpx
+     except Exception:
+         return ()
+     names = (
+         "RequestError",  # base for transport errors
+         "TimeoutException",  # base for timeouts
+         "ConnectError",
+         "ConnectTimeout",
+         "ReadTimeout",
+         "WriteTimeout",
+         "PoolTimeout",
+     )
+     return tuple(getattr(httpx, n) for n in names if hasattr(httpx, n))
+
+
+ def _requests_net_excs() -> tuple[type, ...]:
+     try:
+         import requests
+     except Exception:
+         return ()
+     names = (
+         "RequestException",
+         "Timeout",
+         "ConnectionError",
+         "ReadTimeout",
+         "SSLError",
+         "ChunkedEncodingError",
+     )
+     return tuple(
+         getattr(requests.exceptions, n)
+         for n in names
+         if hasattr(requests.exceptions, n)
+     )
+
+
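
The callbacks above log to per-provider loggers named `deepeval.retry.<slug>`, so they can be surfaced with stdlib logging alone; a sketch, with the "openai" slug as an illustrative choice:

    import logging

    logging.basicConfig(level=logging.WARNING)
    # Surface the INFO-level before-sleep messages for one provider only.
    logging.getLogger("deepeval.retry.openai").setLevel(logging.INFO)
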
  # --------------------------
  # Built-in policies
  # --------------------------
+
+ ##################
+ # Open AI Policy #
+ ##################
+
  OPENAI_MESSAGE_MARKERS: dict[str, tuple[str, ...]] = {
-     "insufficient_quota": ("insufficient_quota", "exceeded your current quota"),
+     "insufficient_quota": (
+         "insufficient_quota",
+         "insufficient quota",
+         "exceeded your current quota",
+         "requestquotaexceeded",
+     ),
  }

  try:
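
The widened markers let the message-sniffing fallback catch quota errors whose payloads resist structured parsing; roughly (the exception class here is a hypothetical stand-in, not an OpenAI SDK type):

    from deepeval.models.retry_policy import OPENAI_MESSAGE_MARKERS, extract_error_code

    class FakeQuotaError(Exception):  # hypothetical stand-in
        pass

    err = FakeQuotaError("You exceeded your current quota, please check your plan and billing details.")
    print(extract_error_code(err, message_markers=OPENAI_MESSAGE_MARKERS))  # -> "insufficient_quota"
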
@@ -268,13 +485,280 @@ except Exception: # pragma: no cover - OpenAI may not be installed in some envs
      OPENAI_ERROR_POLICY = None


+ ##########################
+ # Models that use OpenAI #
+ ##########################
+ AZURE_OPENAI_ERROR_POLICY = OPENAI_ERROR_POLICY
+ DEEPSEEK_ERROR_POLICY = OPENAI_ERROR_POLICY
+ KIMI_ERROR_POLICY = OPENAI_ERROR_POLICY
+ LOCAL_ERROR_POLICY = OPENAI_ERROR_POLICY
+
+ ######################
+ # AWS Bedrock Policy #
+ ######################
+
+ try:
+     from botocore.exceptions import (
+         ClientError,
+         EndpointConnectionError,
+         ConnectTimeoutError,
+         ReadTimeoutError,
+         ConnectionClosedError,
+     )
+
+     # Map common AWS error messages to keys via substring match (lowercased)
+     # Update as we encounter new error messages from the sdk
+     # These messages are heuristics, we don't have a list of exact error messages
+     BEDROCK_MESSAGE_MARKERS = {
+         # retryable throttling / transient
+         "throttlingexception": (
+             "throttlingexception",
+             "too many requests",
+             "rate exceeded",
+         ),
+         "serviceunavailableexception": (
+             "serviceunavailableexception",
+             "service unavailable",
+         ),
+         "internalserverexception": (
+             "internalserverexception",
+             "internal server error",
+         ),
+         "modeltimeoutexception": ("modeltimeoutexception", "model timeout"),
+         # clear non-retryables
+         "accessdeniedexception": ("accessdeniedexception",),
+         "validationexception": ("validationexception",),
+         "resourcenotfoundexception": ("resourcenotfoundexception",),
+     }
+
+     BEDROCK_ERROR_POLICY = ErrorPolicy(
+         auth_excs=(),
+         rate_limit_excs=(
+             ClientError,
+         ),  # classify by code extracted from message
+         network_excs=(
+             EndpointConnectionError,
+             ConnectTimeoutError,
+             ReadTimeoutError,
+             ConnectionClosedError,
+         ),
+         http_excs=(),  # no status_code attributes. We will rely on ClientError + markers
+         non_retryable_codes=frozenset(
+             {
+                 "accessdeniedexception",
+                 "validationexception",
+                 "resourcenotfoundexception",
+             }
+         ),
+         message_markers=BEDROCK_MESSAGE_MARKERS,
+     )
+ except Exception:  # botocore not present (aiobotocore optional)
+     BEDROCK_ERROR_POLICY = None
+
+
+ ####################
+ # Anthropic Policy #
+ ####################
+
+ try:
+     from anthropic import (
+         AuthenticationError,
+         RateLimitError,
+         APIConnectionError,
+         APITimeoutError,
+         APIStatusError,
+     )
+
+     ANTHROPIC_ERROR_POLICY = ErrorPolicy(
+         auth_excs=(AuthenticationError,),
+         rate_limit_excs=(RateLimitError,),
+         network_excs=(APIConnectionError, APITimeoutError),
+         http_excs=(APIStatusError,),
+         non_retryable_codes=frozenset(),  # update if we learn of hard quota codes
+         message_markers={},
+     )
+ except Exception:  # Anthropic optional
+     ANTHROPIC_ERROR_POLICY = None
+
+
+ #####################
+ # Google/Gemini Policy
+ #####################
+ # The google genai SDK raises google.genai.errors.*. Public docs and issues show:
+ # - errors.ClientError for 4xx like 400/401/403/404/422/429
+ # - errors.ServerError for 5xx
+ # - errors.APIError is a common base that exposes `.code` and message text
+ # The SDK doesn’t guarantee a `.status_code` attribute, but it commonly exposes `.code`,
+ # so we treat ServerError as transient (network-like) to get 5xx retries.
+ # For rate limiting (429 Resource Exhausted), we treat *ClientError* as rate limit class
+ # and gate retries using message markers (code sniffing).
+ # See: https://github.com/googleapis/python-genai?tab=readme-ov-file#error-handling
+ try:
+     from google.genai import errors as gerrors
+
+     _HTTPX_NET_EXCS = _httpx_net_excs()
+     _REQUESTS_EXCS = _requests_net_excs()
+
+     GOOGLE_MESSAGE_MARKERS = {
+         # retryable rate limit
+         "429": ("429", "resource_exhausted", "rate limit"),
+         # clearly non-retryable client codes
+         "401": ("401", "unauthorized", "api key"),
+         "403": ("403", "permission denied", "forbidden"),
+         "404": ("404", "not found"),
+         "400": ("400", "invalid argument", "bad request"),
+         "422": ("422", "failed_precondition", "unprocessable"),
+     }
+
+     GOOGLE_ERROR_POLICY = ErrorPolicy(
+         auth_excs=(),  # we will classify 401/403 via markers below (see non-retryable codes)
+         rate_limit_excs=(
+             gerrors.ClientError,
+         ),  # includes 429; markers decide retry vs not
+         network_excs=(gerrors.ServerError,)
+         + _HTTPX_NET_EXCS
+         + _REQUESTS_EXCS,  # treat 5xx as transient
+         http_excs=(),  # no reliable .status_code on exceptions; handled above
+         # Non-retryable codes for *ClientError*. Anything else is retried.
+         non_retryable_codes=frozenset({"400", "401", "403", "404", "422"}),
+         message_markers=GOOGLE_MESSAGE_MARKERS,
+     )
+ except Exception:
+     GOOGLE_ERROR_POLICY = None
+
+ #################
+ # Grok Policy #
+ #################
+ # The xAI Python SDK (xai-sdk) uses gRPC. Errors raised are grpc.RpcError (sync)
+ # and grpc.aio.AioRpcError (async). The SDK retries UNAVAILABLE by default with
+ # backoff; you can disable via channel option ("grpc.enable_retries", 0) or
+ # customize via "grpc.service_config". See xai-sdk docs.
+ # Refs:
+ # - https://github.com/xai-org/xai-sdk-python/blob/main/README.md#retries
+ # - https://github.com/xai-org/xai-sdk-python/blob/main/README.md#error-codes
+ try:
+     import grpc
+
+     try:
+         from grpc import aio as grpc_aio
+
+         _AioRpcError = getattr(grpc_aio, "AioRpcError", None)
+     except Exception:
+         _AioRpcError = None
+
+     _GRPC_EXCS = tuple(
+         c for c in (getattr(grpc, "RpcError", None), _AioRpcError) if c
+     )
+
+     # rely on extract_error_code reading e.code().name (lowercased).
+     GROK_ERROR_POLICY = ErrorPolicy(
+         auth_excs=(),  # handled via code() mapping below
+         rate_limit_excs=_GRPC_EXCS,  # gated by code() value
+         network_excs=(),  # gRPC code handles transience
+         http_excs=(),  # no .status_code on gRPC errors
+         non_retryable_codes=frozenset(
+             {
+                 "invalid_argument",
+                 "unauthenticated",
+                 "permission_denied",
+                 "not_found",
+                 "resource_exhausted",
+                 "failed_precondition",
+                 "out_of_range",
+                 "unimplemented",
+                 "data_loss",
+             }
+         ),
+         message_markers={},
+     )
+ except Exception:  # xai-sdk/grpc not present
+     GROK_ERROR_POLICY = None
+
+
+ ############
+ # Lite LLM #
+ ############
+ LITELLM_ERROR_POLICY = None  # TODO: LiteLLM is going to take some extra care. I will return to this task last
+
+
+ #########################
+ # Ollama (local server) #
+ #########################
+
+ try:
+     # Catch transport + timeout issues via base classes
+     _HTTPX_NET_EXCS = _httpx_net_excs()
+     _REQUESTS_EXCS = _requests_net_excs()
+
+     OLLAMA_ERROR_POLICY = ErrorPolicy(
+         auth_excs=(),
+         rate_limit_excs=(),  # no rate limiting semantics locally
+         network_excs=_HTTPX_NET_EXCS + _REQUESTS_EXCS,  # retry network/timeouts
+         http_excs=(),  # optionally add httpx.HTTPStatusError if you call raise_for_status()
+         non_retryable_codes=frozenset(),
+         message_markers={},
+     )
+ except Exception:
+     OLLAMA_ERROR_POLICY = None
+
+
+ # Map provider slugs to their policy objects.
+ # It is OK if some are None, we'll treat that as no Error Policy / Tenacity
+ _POLICY_BY_SLUG: dict[str, Optional[ErrorPolicy]] = {
+     PS.OPENAI.value: OPENAI_ERROR_POLICY,
+     PS.AZURE.value: AZURE_OPENAI_ERROR_POLICY,
+     PS.BEDROCK.value: BEDROCK_ERROR_POLICY,
+     PS.ANTHROPIC.value: ANTHROPIC_ERROR_POLICY,
+     PS.DEEPSEEK.value: DEEPSEEK_ERROR_POLICY,
+     PS.GOOGLE.value: GOOGLE_ERROR_POLICY,
+     PS.GROK.value: GROK_ERROR_POLICY,
+     PS.KIMI.value: KIMI_ERROR_POLICY,
+     PS.LITELLM.value: LITELLM_ERROR_POLICY,
+     PS.LOCAL.value: LOCAL_ERROR_POLICY,
+     PS.OLLAMA.value: OLLAMA_ERROR_POLICY,
+ }
+
+
+ def _opt_pred(
+     policy: Optional[ErrorPolicy],
+ ) -> Optional[Callable[[Exception], bool]]:
+     return make_is_transient(policy) if policy else None
+
+
+ _STATIC_PRED_BY_SLUG: dict[str, Optional[Callable[[Exception], bool]]] = {
+     PS.OPENAI.value: _opt_pred(OPENAI_ERROR_POLICY),
+     PS.AZURE.value: _opt_pred(AZURE_OPENAI_ERROR_POLICY),
+     PS.BEDROCK.value: _opt_pred(BEDROCK_ERROR_POLICY),
+     PS.ANTHROPIC.value: _opt_pred(ANTHROPIC_ERROR_POLICY),
+     PS.DEEPSEEK.value: _opt_pred(DEEPSEEK_ERROR_POLICY),
+     PS.GOOGLE.value: _opt_pred(GOOGLE_ERROR_POLICY),
+     PS.GROK.value: _opt_pred(GROK_ERROR_POLICY),
+     PS.KIMI.value: _opt_pred(KIMI_ERROR_POLICY),
+     PS.LITELLM.value: _opt_pred(LITELLM_ERROR_POLICY),
+     PS.LOCAL.value: _opt_pred(LOCAL_ERROR_POLICY),
+     PS.OLLAMA.value: _opt_pred(OLLAMA_ERROR_POLICY),
+ }
+
+
  __all__ = [
      "ErrorPolicy",
+     "get_retry_policy_for",
+     "create_retry_decorator",
+     "dynamic_retry",
      "extract_error_code",
      "make_is_transient",
-     "default_wait",
-     "default_stop",
+     "dynamic_stop",
+     "dynamic_wait",
      "retry_predicate",
+     "sdk_retries_for",
      "OPENAI_MESSAGE_MARKERS",
      "OPENAI_ERROR_POLICY",
+     "AZURE_OPENAI_ERROR_POLICY",
+     "BEDROCK_ERROR_POLICY",
+     "BEDROCK_MESSAGE_MARKERS",
+     "ANTHROPIC_ERROR_POLICY",
+     "DEEPSEEK_ERROR_POLICY",
+     "GOOGLE_ERROR_POLICY",
+     "GROK_ERROR_POLICY",
+     "LOCAL_ERROR_POLICY",
  ]
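
Finally, a sketch of the new slug-to-policy lookup introduced above: which policy (if any) comes back depends on whether the optional SDK is importable and on the SDK-retry setting; the choice of provider below is illustrative.

    from deepeval.constants import ProviderSlug as PS
    from deepeval.models.retry_policy import get_retry_policy_for, sdk_retries_for

    policy = get_retry_policy_for(PS.ANTHROPIC)  # ErrorPolicy if anthropic is importable, else None
    if policy is None and not sdk_retries_for(PS.ANTHROPIC):
        print("no Tenacity policy; calls run without library-level retries")
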