deepeval 3.5.2__py3-none-any.whl → 3.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. deepeval/_version.py +1 -1
  2. deepeval/config/settings.py +94 -2
  3. deepeval/config/utils.py +54 -1
  4. deepeval/constants.py +27 -0
  5. deepeval/integrations/pydantic_ai/__init__.py +3 -1
  6. deepeval/integrations/pydantic_ai/agent.py +339 -0
  7. deepeval/integrations/pydantic_ai/patcher.py +479 -406
  8. deepeval/integrations/pydantic_ai/utils.py +239 -2
  9. deepeval/metrics/mcp_use_metric/mcp_use_metric.py +2 -1
  10. deepeval/metrics/non_advice/non_advice.py +2 -2
  11. deepeval/metrics/pii_leakage/pii_leakage.py +2 -2
  12. deepeval/models/embedding_models/azure_embedding_model.py +40 -9
  13. deepeval/models/embedding_models/local_embedding_model.py +52 -9
  14. deepeval/models/embedding_models/ollama_embedding_model.py +25 -7
  15. deepeval/models/embedding_models/openai_embedding_model.py +47 -5
  16. deepeval/models/llms/amazon_bedrock_model.py +31 -4
  17. deepeval/models/llms/anthropic_model.py +39 -13
  18. deepeval/models/llms/azure_model.py +37 -38
  19. deepeval/models/llms/deepseek_model.py +36 -7
  20. deepeval/models/llms/gemini_model.py +10 -0
  21. deepeval/models/llms/grok_model.py +50 -3
  22. deepeval/models/llms/kimi_model.py +37 -7
  23. deepeval/models/llms/local_model.py +38 -12
  24. deepeval/models/llms/ollama_model.py +15 -3
  25. deepeval/models/llms/openai_model.py +37 -44
  26. deepeval/models/mlllms/gemini_model.py +21 -3
  27. deepeval/models/mlllms/ollama_model.py +38 -13
  28. deepeval/models/mlllms/openai_model.py +18 -42
  29. deepeval/models/retry_policy.py +548 -64
  30. deepeval/tracing/tracing.py +87 -0
  31. {deepeval-3.5.2.dist-info → deepeval-3.5.4.dist-info}/METADATA +1 -1
  32. {deepeval-3.5.2.dist-info → deepeval-3.5.4.dist-info}/RECORD +35 -34
  33. {deepeval-3.5.2.dist-info → deepeval-3.5.4.dist-info}/LICENSE.md +0 -0
  34. {deepeval-3.5.2.dist-info → deepeval-3.5.4.dist-info}/WHEEL +0 -0
  35. {deepeval-3.5.2.dist-info → deepeval-3.5.4.dist-info}/entry_points.txt +0 -0
@@ -1,8 +1,29 @@
1
- from typing import List
2
- from pydantic_ai.messages import ModelResponsePart
1
+ from time import perf_counter
2
+ from contextlib import asynccontextmanager
3
+ import inspect
4
+ import functools
5
+ from typing import Any, Callable, List, Optional
6
+
7
+ from pydantic_ai.models import Model
3
8
  from pydantic_ai.agent import AgentRunResult
4
9
  from pydantic_ai._run_context import RunContext
10
+ from pydantic_ai.messages import (
11
+ ModelRequest,
12
+ ModelResponse,
13
+ ModelResponsePart,
14
+ SystemPromptPart,
15
+ TextPart,
16
+ ToolCallPart,
17
+ ToolReturnPart,
18
+ UserPromptPart,
19
+ )
20
+
21
+ from deepeval.prompt import Prompt
22
+ from deepeval.tracing.tracing import Observer
23
+ from deepeval.metrics.base_metric import BaseMetric
5
24
  from deepeval.test_case.llm_test_case import ToolCall
25
+ from deepeval.tracing.context import current_trace_context, current_span_context
26
+ from deepeval.tracing.types import AgentSpan, LlmOutput, LlmSpan, LlmToolCall
6
27
 
7
28
 
8
29
  # llm tools called
@@ -84,3 +105,219 @@ def sanitize_run_context(value):
84
105
  return {sanitize_run_context(v) for v in value}
85
106
 
86
107
  return value
108
+
109
+
110
def patch_llm_model(
    model: Model,
    llm_metric_collection: Optional[str] = None,
    llm_metrics: Optional[List[BaseMetric]] = None,
    llm_prompt: Optional[Prompt] = None,
):
    """Instrument ``model`` in place so each model call is traced as an ``llm`` span.

    Both ``model.request`` and ``model.request_stream`` are replaced on the
    instance. Every call opens an ``Observer`` span (``span_type="llm"``,
    ``func_name="LLM"``) tagged with the model name. Once the response is
    available, the span is filled in by ``set_llm_span_attributes`` from the
    request messages, the response and ``llm_prompt``.

    Args:
        model: The pydantic-ai model instance to patch.
        llm_metric_collection: Metric collection attached to every LLM span.
        llm_metrics: Metrics attached to every LLM span.
        llm_prompt: Optional prompt recorded on every LLM span.
    """
    request_func = model.request
    request_sig = inspect.signature(request_func)

    try:
        model_name = model.model_name
    except Exception:
        # Any failure reading the name falls back to a placeholder label.
        model_name = "unknown"

    def _open_span():
        # Shared span configuration for the plain and the streaming paths.
        return Observer(
            span_type="llm",
            func_name="LLM",
            observe_kwargs={"model": model_name},
            metrics=llm_metrics,
            metric_collection=llm_metric_collection,
        )

    def _messages_from(signature, call_args, call_kwargs):
        # Recover the "messages" argument however the caller passed it.
        bound_args = signature.bind_partial(*call_args, **call_kwargs)
        bound_args.apply_defaults()
        return bound_args.arguments.get("messages", [])

    def _record(observer, messages, response):
        # Span attributes are filled lazily through the observer callback.
        observer.update_span_properties = lambda span: set_llm_span_attributes(
            span, messages, response, llm_prompt
        )
        observer.result = response

    @functools.wraps(request_func)
    async def wrapper(*args, **kwargs):
        messages = _messages_from(request_sig, args, kwargs)
        with _open_span() as observer:
            response = await request_func(*args, **kwargs)
            _record(observer, messages, response)
            return response

    model.request = wrapper

    stream_func = model.request_stream
    stream_sig = inspect.signature(stream_func)

    @asynccontextmanager
    async def stream_wrapper(*args, **kwargs):
        messages = _messages_from(stream_sig, args, kwargs)
        with _open_span() as observer:
            llm_span: LlmSpan = current_span_context.get()
            async with stream_func(*args, **kwargs) as streamed_response:
                try:
                    yield streamed_response
                    # Timestamp the point where the caller's block finished.
                    intervals = llm_span.token_intervals
                    if intervals:
                        intervals[perf_counter()] = "NA"
                    else:
                        llm_span.token_intervals = {perf_counter(): "NA"}
                finally:
                    try:
                        _record(observer, messages, streamed_response.get())
                    except Exception:
                        # Best effort: if the final response cannot be built,
                        # the span is left without output.
                        pass

    model.request_stream = stream_wrapper
187
+
188
+
189
def create_patched_tool(
    func: Callable,
    metrics: Optional[List[BaseMetric]] = None,
    metric_collection: Optional[str] = None,
):
    """Wrap a tool callable so each invocation is recorded as a ``tool`` span.

    Coroutine functions get an ``async`` wrapper and plain functions get a sync
    one. Both use ``functools.wraps`` so the tool's name and signature metadata
    are preserved. The span stores the arguments after they pass through
    ``sanitize_run_context``. The real call always receives the original,
    unsanitized arguments.

    Args:
        func: The tool callable to wrap.
        metrics: Metrics attached to each tool span.
        metric_collection: Metric collection name attached to each tool span.

    Returns:
        The wrapped callable (sync or async, matching ``func``).
    """
    original_func = func

    def _observer(args, kwargs):
        # Single source of truth for the span configuration of both wrappers.
        sanitized_args = sanitize_run_context(args)
        sanitized_kwargs = sanitize_run_context(kwargs)
        return Observer(
            span_type="tool",
            func_name=original_func.__name__,
            metrics=metrics,
            metric_collection=metric_collection,
            function_kwargs={"args": sanitized_args, **sanitized_kwargs},
        )

    # inspect.iscoroutinefunction instead of asyncio.iscoroutinefunction, which
    # is deprecated as of Python 3.14.
    if inspect.iscoroutinefunction(original_func):

        @functools.wraps(original_func)
        async def async_wrapper(*args, **kwargs):
            with _observer(args, kwargs) as observer:
                result = await original_func(*args, **kwargs)
                observer.result = result
            return result

        return async_wrapper

    @functools.wraps(original_func)
    def sync_wrapper(*args, **kwargs):
        with _observer(args, kwargs) as observer:
            result = original_func(*args, **kwargs)
            observer.result = result
        return result

    return sync_wrapper
238
+
239
+
240
def update_trace_context(
    trace_name: Optional[str] = None,
    trace_tags: Optional[List[str]] = None,
    trace_metadata: Optional[dict] = None,
    trace_thread_id: Optional[str] = None,
    trace_user_id: Optional[str] = None,
    trace_metric_collection: Optional[str] = None,
    trace_metrics: Optional[List[BaseMetric]] = None,
    trace_input: Optional[Any] = None,
    trace_output: Optional[Any] = None,
):
    """Copy the supplied values onto the trace held in ``current_trace_context``.

    ``trace_input`` and ``trace_output`` are applied whenever they are not
    ``None``, so falsy but legitimate payloads (``0``, ``""``, ``[]``,
    ``False``) are recorded. Previously these were silently skipped. All other
    fields are applied only when truthy, as before.

    Does nothing when the context holds no trace. NOTE(review): presumably
    ``current_trace_context.get()`` yields ``None`` outside an observed run;
    confirm.
    """
    current_trace = current_trace_context.get()
    if current_trace is None:
        return

    # Applied in the same order as before; only truthy values overwrite.
    truthy_fields = (
        ("name", trace_name),
        ("tags", trace_tags),
        ("metadata", trace_metadata),
        ("thread_id", trace_thread_id),
        ("user_id", trace_user_id),
        ("metric_collection", trace_metric_collection),
        ("metrics", trace_metrics),
    )
    for attr, value in truthy_fields:
        if value:
            setattr(current_trace, attr, value)

    # Payloads may legitimately be falsy; only None means "not provided".
    if trace_input is not None:
        current_trace.input = trace_input
    if trace_output is not None:
        current_trace.output = trace_output
272
+
273
+
274
def set_llm_span_attributes(
    llm_span: LlmSpan,
    requests: List[ModelRequest],
    result: ModelResponse,
    llm_prompt: Optional[Prompt] = None,
):
    """Populate ``llm_span`` with the prompt, input messages and output of one model call.

    Args:
        llm_span: Span to mutate.
        requests: Message history passed to the model. pydantic-ai passes its
            full message list here, which can include earlier ``ModelResponse``
            items alongside ``ModelRequest`` ones. NOTE(review): confirm for the
            pinned pydantic-ai version. Assistant text parts from earlier
            responses are recorded with role "Assistant" instead of being dropped.
        result: The model's response. Its text parts are joined with newlines
            into the output content; its tool-call parts become ``LlmToolCall``s.
        llm_prompt: Optional prompt recorded on the span.
    """
    llm_span.prompt = llm_prompt

    messages = []
    for message in requests:
        for part in message.parts:
            if isinstance(part, SystemPromptPart):
                messages.append({"role": "System", "content": part.content})
            elif isinstance(part, UserPromptPart):
                messages.append({"role": "User", "content": part.content})
            elif isinstance(part, TextPart):
                # Prior assistant turns in the history.
                messages.append({"role": "Assistant", "content": part.content})
            elif isinstance(part, ToolCallPart):
                messages.append(
                    {
                        "role": "Tool Call",
                        "name": part.tool_name,
                        "content": part.args_as_json_str(),
                    }
                )
            elif isinstance(part, ToolReturnPart):
                messages.append(
                    {
                        "role": "Tool Return",
                        "name": part.tool_name,
                        "content": part.model_response_str(),
                    }
                )
    llm_span.input = messages

    text_chunks = []
    tool_calls = []
    for part in result.parts:
        if isinstance(part, TextPart):
            text_chunks.append(part.content)
        elif isinstance(part, ToolCallPart):
            tool_calls.append(
                LlmToolCall(name=part.tool_name, args=part.args_as_dict())
            )
    # join() avoids the trailing newline the old "+=" loop left behind.
    llm_span.output = LlmOutput(
        role="Assistant", content="\n".join(text_chunks), tool_calls=tool_calls
    )
    llm_span.tools_called = extract_tools_called_from_llm_response(result.parts)
320
+
321
+
322
def set_agent_span_attributes(agent_span: AgentSpan, result: AgentRunResult):
    """Store on ``agent_span`` the tool calls extracted from an agent run result."""
    called = extract_tools_called(result)
    agent_span.tools_called = called
@@ -283,8 +283,9 @@ class MCPUseMetric(BaseMetric):
283
283
  mcp_resources_called: List[MCPResourceCall],
284
284
  mcp_prompts_called: List[MCPPromptCall],
285
285
  ) -> tuple[str, str]:
286
+ available_primitives = "MCP Primitives Available: \n"
286
287
  for mcp_server in mcp_servers:
287
- available_primitives = f"MCP Server {mcp_server.server_name}\n"
288
+ available_primitives += f"MCP Server {mcp_server.server_name}\n"
288
289
  available_primitives += (
289
290
  (
290
291
  "\nAvailable Tools:\n[\n"
@@ -43,7 +43,7 @@ class NonAdviceMetric(BaseMetric):
43
43
  "or ['financial', 'medical'] for multiple types."
44
44
  )
45
45
 
46
- self.threshold = 0 if strict_mode else threshold
46
+ self.threshold = 1 if strict_mode else threshold
47
47
  self.advice_types = advice_types
48
48
  self.model, self.using_native_model = initialize_model(model)
49
49
  self.evaluation_model = self.model.get_model_name()
@@ -293,7 +293,7 @@ class NonAdviceMetric(BaseMetric):
293
293
  appropriate_advice_count += 1
294
294
 
295
295
  score = appropriate_advice_count / number_of_verdicts
296
- return 1 if self.strict_mode and score < 1 else score
296
+ return 0 if self.strict_mode and score < self.threshold else score
297
297
 
298
298
  def is_successful(self) -> bool:
299
299
  if self.error is not None:
@@ -35,7 +35,7 @@ class PIILeakageMetric(BaseMetric):
35
35
  verbose_mode: bool = False,
36
36
  evaluation_template: Type[PIILeakageTemplate] = PIILeakageTemplate,
37
37
  ):
38
- self.threshold = 0 if strict_mode else threshold
38
+ self.threshold = 1 if strict_mode else threshold
39
39
  self.model, self.using_native_model = initialize_model(model)
40
40
  self.evaluation_model = self.model.get_model_name()
41
41
  self.include_reason = include_reason
@@ -284,7 +284,7 @@ class PIILeakageMetric(BaseMetric):
284
284
  no_privacy_count += 1
285
285
 
286
286
  score = no_privacy_count / number_of_verdicts
287
- return 1 if self.strict_mode and score < 1 else score
287
+ return 0 if self.strict_mode and score < self.threshold else score
288
288
 
289
289
  def is_successful(self) -> bool:
290
290
  if self.error is not None:
@@ -1,4 +1,4 @@
1
- from typing import List
1
+ from typing import Dict, List
2
2
  from openai import AzureOpenAI, AsyncAzureOpenAI
3
3
  from deepeval.key_handler import (
4
4
  EmbeddingKeyValues,
@@ -6,10 +6,18 @@ from deepeval.key_handler import (
6
6
  KEY_FILE_HANDLER,
7
7
  )
8
8
  from deepeval.models import DeepEvalBaseEmbeddingModel
9
+ from deepeval.models.retry_policy import (
10
+ create_retry_decorator,
11
+ sdk_retries_for,
12
+ )
13
+ from deepeval.constants import ProviderSlug as PS
14
+
15
+
16
# Retry decorator applied to this module's embed methods, built by deepeval's
# retry policy for the Azure provider slug.
retry_azure = create_retry_decorator(PS.AZURE)
9
17
 
10
18
 
11
19
  class AzureOpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
12
- def __init__(self):
20
+ def __init__(self, **kwargs):
13
21
  self.azure_openai_api_key = KEY_FILE_HANDLER.fetch_data(
14
22
  ModelKeyValues.AZURE_OPENAI_API_KEY
15
23
  )
@@ -23,7 +31,9 @@ class AzureOpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
23
31
  ModelKeyValues.AZURE_OPENAI_ENDPOINT
24
32
  )
25
33
  self.model_name = self.azure_embedding_deployment
34
+ self.kwargs = kwargs
26
35
 
36
+ @retry_azure
27
37
  def embed_text(self, text: str) -> List[float]:
28
38
  client = self.load_model(async_mode=False)
29
39
  response = client.embeddings.create(
@@ -32,6 +42,7 @@ class AzureOpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
32
42
  )
33
43
  return response.data[0].embedding
34
44
 
45
+ @retry_azure
35
46
  def embed_texts(self, texts: List[str]) -> List[List[float]]:
36
47
  client = self.load_model(async_mode=False)
37
48
  response = client.embeddings.create(
@@ -40,6 +51,7 @@ class AzureOpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
40
51
  )
41
52
  return [item.embedding for item in response.data]
42
53
 
54
+ @retry_azure
43
55
  async def a_embed_text(self, text: str) -> List[float]:
44
56
  client = self.load_model(async_mode=True)
45
57
  response = await client.embeddings.create(
@@ -48,6 +60,7 @@ class AzureOpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
48
60
  )
49
61
  return response.data[0].embedding
50
62
 
63
+ @retry_azure
51
64
  async def a_embed_texts(self, texts: List[str]) -> List[List[float]]:
52
65
  client = self.load_model(async_mode=True)
53
66
  response = await client.embeddings.create(
@@ -61,15 +74,33 @@ class AzureOpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
61
74
 
62
75
  def load_model(self, async_mode: bool = False):
63
76
  if not async_mode:
64
- return AzureOpenAI(
65
- api_key=self.azure_openai_api_key,
66
- api_version=self.openai_api_version,
67
- azure_endpoint=self.azure_endpoint,
68
- azure_deployment=self.azure_embedding_deployment,
69
- )
70
- return AsyncAzureOpenAI(
77
+ return self._build_client(AzureOpenAI)
78
+ return self._build_client(AsyncAzureOpenAI)
79
+
80
+ def _client_kwargs(self) -> Dict:
81
+ """
82
+ If Tenacity is managing retries, force OpenAI SDK retries off to avoid double retries.
83
+ If the user opts into SDK retries for 'azure' via DEEPEVAL_SDK_RETRY_PROVIDERS,
84
+ leave their retry settings as is.
85
+ """
86
+ kwargs = dict(self.kwargs or {})
87
+ if not sdk_retries_for(PS.AZURE):
88
+ kwargs["max_retries"] = 0
89
+ return kwargs
90
+
91
+ def _build_client(self, cls):
92
+ kw = dict(
71
93
  api_key=self.azure_openai_api_key,
72
94
  api_version=self.openai_api_version,
73
95
  azure_endpoint=self.azure_endpoint,
74
96
  azure_deployment=self.azure_embedding_deployment,
97
+ **self._client_kwargs(),
75
98
  )
99
+ try:
100
+ return cls(**kw)
101
+ except TypeError as e:
102
+ # older OpenAI SDKs may not accept max_retries, in that case remove and retry once
103
+ if "max_retries" in str(e):
104
+ kw.pop("max_retries", None)
105
+ return cls(**kw)
106
+ raise
@@ -1,12 +1,21 @@
1
- from openai import OpenAI
2
- from typing import List
1
+ from openai import OpenAI, AsyncOpenAI
2
+ from typing import Dict, List
3
3
 
4
4
  from deepeval.key_handler import EmbeddingKeyValues, KEY_FILE_HANDLER
5
5
  from deepeval.models import DeepEvalBaseEmbeddingModel
6
+ from deepeval.models.retry_policy import (
7
+ create_retry_decorator,
8
+ sdk_retries_for,
9
+ )
10
+ from deepeval.constants import ProviderSlug as PS
11
+
12
+
13
# Retry decorator applied to this module's embed methods. It is built from
# deepeval's shared retry policy for the "local" provider slug, so retry rules
# stay consistent with the other providers.
retry_local = create_retry_decorator(PS.LOCAL)
6
15
 
7
16
 
8
17
  class LocalEmbeddingModel(DeepEvalBaseEmbeddingModel):
9
- def __init__(self, *args, **kwargs):
18
+ def __init__(self, **kwargs):
10
19
  self.base_url = KEY_FILE_HANDLER.fetch_data(
11
20
  EmbeddingKeyValues.LOCAL_EMBEDDING_BASE_URL
12
21
  )
@@ -16,13 +25,10 @@ class LocalEmbeddingModel(DeepEvalBaseEmbeddingModel):
16
25
  self.api_key = KEY_FILE_HANDLER.fetch_data(
17
26
  EmbeddingKeyValues.LOCAL_EMBEDDING_API_KEY
18
27
  )
19
- self.args = args
20
28
  self.kwargs = kwargs
21
29
  super().__init__(model_name)
22
30
 
23
- def load_model(self):
24
- return OpenAI(base_url=self.base_url, api_key=self.api_key)
25
-
31
+ @retry_local
26
32
  def embed_text(self, text: str) -> List[float]:
27
33
  embedding_model = self.load_model()
28
34
  response = embedding_model.embeddings.create(
@@ -31,6 +37,7 @@ class LocalEmbeddingModel(DeepEvalBaseEmbeddingModel):
31
37
  )
32
38
  return response.data[0].embedding
33
39
 
40
+ @retry_local
34
41
  def embed_texts(self, texts: List[str]) -> List[List[float]]:
35
42
  embedding_model = self.load_model()
36
43
  response = embedding_model.embeddings.create(
@@ -39,21 +46,57 @@ class LocalEmbeddingModel(DeepEvalBaseEmbeddingModel):
39
46
  )
40
47
  return [data.embedding for data in response.data]
41
48
 
49
+ @retry_local
42
50
  async def a_embed_text(self, text: str) -> List[float]:
43
- embedding_model = self.load_model()
51
+ embedding_model = self.load_model(async_mode=True)
44
52
  response = await embedding_model.embeddings.create(
45
53
  model=self.model_name,
46
54
  input=[text],
47
55
  )
48
56
  return response.data[0].embedding
49
57
 
58
+ @retry_local
50
59
  async def a_embed_texts(self, texts: List[str]) -> List[List[float]]:
51
- embedding_model = self.load_model()
60
+ embedding_model = self.load_model(async_mode=True)
52
61
  response = await embedding_model.embeddings.create(
53
62
  model=self.model_name,
54
63
  input=texts,
55
64
  )
56
65
  return [data.embedding for data in response.data]
57
66
 
67
+ ###############################################
68
+ # Model
69
+ ###############################################
70
+
58
71
  def get_model_name(self):
59
72
  return self.model_name
73
+
74
+ def load_model(self, async_mode: bool = False):
75
+ if not async_mode:
76
+ return self._build_client(OpenAI)
77
+ return self._build_client(AsyncOpenAI)
78
+
79
+ def _client_kwargs(self) -> Dict:
80
+ """
81
+ If Tenacity manages retries, turn off OpenAI SDK retries to avoid double retrying.
82
+ If users opt into SDK retries via DEEPEVAL_SDK_RETRY_PROVIDERS=local, leave them enabled.
83
+ """
84
+ kwargs = dict(self.kwargs or {})
85
+ if not sdk_retries_for(PS.LOCAL):
86
+ kwargs["max_retries"] = 0
87
+ return kwargs
88
+
89
+ def _build_client(self, cls):
90
+ kw = dict(
91
+ api_key=self.api_key,
92
+ base_url=self.base_url,
93
+ **self._client_kwargs(),
94
+ )
95
+ try:
96
+ return cls(**kw)
97
+ except TypeError as e:
98
+ # Older OpenAI SDKs may not accept max_retries; drop and retry once.
99
+ if "max_retries" in str(e):
100
+ kw.pop("max_retries", None)
101
+ return cls(**kw)
102
+ raise
@@ -3,6 +3,13 @@ from typing import List
3
3
 
4
4
  from deepeval.key_handler import EmbeddingKeyValues, KEY_FILE_HANDLER
5
5
  from deepeval.models import DeepEvalBaseEmbeddingModel
6
+ from deepeval.models.retry_policy import (
7
+ create_retry_decorator,
8
+ )
9
+ from deepeval.constants import ProviderSlug as PS
10
+
11
+
12
# Retry decorator applied to this module's embed methods, built by deepeval's
# retry policy for the Ollama provider slug.
retry_ollama = create_retry_decorator(PS.OLLAMA)
6
13
 
7
14
 
8
15
  class OllamaEmbeddingModel(DeepEvalBaseEmbeddingModel):
@@ -13,6 +20,7 @@ class OllamaEmbeddingModel(DeepEvalBaseEmbeddingModel):
13
20
  model_name = KEY_FILE_HANDLER.fetch_data(
14
21
  EmbeddingKeyValues.LOCAL_EMBEDDING_MODEL_NAME
15
22
  )
23
+ # TODO: This is not being used. Clean it up in consistency PR
16
24
  self.api_key = KEY_FILE_HANDLER.fetch_data(
17
25
  EmbeddingKeyValues.LOCAL_EMBEDDING_API_KEY
18
26
  )
@@ -20,12 +28,7 @@ class OllamaEmbeddingModel(DeepEvalBaseEmbeddingModel):
20
28
  self.kwargs = kwargs
21
29
  super().__init__(model_name)
22
30
 
23
- def load_model(self, async_mode: bool = False):
24
- if not async_mode:
25
- return Client(host=self.base_url)
26
-
27
- return AsyncClient(host=self.base_url)
28
-
31
+ @retry_ollama
29
32
  def embed_text(self, text: str) -> List[float]:
30
33
  embedding_model = self.load_model()
31
34
  response = embedding_model.embed(
@@ -34,6 +37,7 @@ class OllamaEmbeddingModel(DeepEvalBaseEmbeddingModel):
34
37
  )
35
38
  return response["embeddings"][0]
36
39
 
40
+ @retry_ollama
37
41
  def embed_texts(self, texts: List[str]) -> List[List[float]]:
38
42
  embedding_model = self.load_model()
39
43
  response = embedding_model.embed(
@@ -42,6 +46,7 @@ class OllamaEmbeddingModel(DeepEvalBaseEmbeddingModel):
42
46
  )
43
47
  return response["embeddings"]
44
48
 
49
+ @retry_ollama
45
50
  async def a_embed_text(self, text: str) -> List[float]:
46
51
  embedding_model = self.load_model(async_mode=True)
47
52
  response = await embedding_model.embed(
@@ -50,6 +55,7 @@ class OllamaEmbeddingModel(DeepEvalBaseEmbeddingModel):
50
55
  )
51
56
  return response["embeddings"][0]
52
57
 
58
+ @retry_ollama
53
59
  async def a_embed_texts(self, texts: List[str]) -> List[List[float]]:
54
60
  embedding_model = self.load_model(async_mode=True)
55
61
  response = await embedding_model.embed(
@@ -58,5 +64,17 @@ class OllamaEmbeddingModel(DeepEvalBaseEmbeddingModel):
58
64
  )
59
65
  return response["embeddings"]
60
66
 
67
+ ###############################################
68
+ # Model
69
+ ###############################################
70
+
71
+ def load_model(self, async_mode: bool = False):
72
+ if not async_mode:
73
+ return self._build_client(Client)
74
+ return self._build_client(AsyncClient)
75
+
76
+ def _build_client(self, cls):
77
+ return cls(host=self.base_url, **self.kwargs)
78
+
61
79
  def get_model_name(self):
62
- return self.model_name
80
+ return f"{self.model_name} (Ollama)"
@@ -1,6 +1,14 @@
1
- from typing import Optional, List
1
+ from typing import Dict, Optional, List
2
2
  from openai import OpenAI, AsyncOpenAI
3
3
  from deepeval.models import DeepEvalBaseEmbeddingModel
4
+ from deepeval.models.retry_policy import (
5
+ create_retry_decorator,
6
+ sdk_retries_for,
7
+ )
8
+ from deepeval.constants import ProviderSlug as PS
9
+
10
+
11
# Retry decorator applied to this module's embed methods, built by deepeval's
# retry policy for the OpenAI provider slug.
retry_openai = create_retry_decorator(PS.OPENAI)
4
12
 
5
13
  valid_openai_embedding_models = [
6
14
  "text-embedding-3-small",
@@ -15,6 +23,7 @@ class OpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
15
23
  self,
16
24
  model: Optional[str] = None,
17
25
  _openai_api_key: Optional[str] = None,
26
+ **kwargs,
18
27
  ):
19
28
  model_name = model if model else default_openai_embedding_model
20
29
  if model_name not in valid_openai_embedding_models:
@@ -23,7 +32,9 @@ class OpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
23
32
  )
24
33
  self._openai_api_key = _openai_api_key
25
34
  self.model_name = model_name
35
+ self.kwargs = kwargs
26
36
 
37
+ @retry_openai
27
38
  def embed_text(self, text: str) -> List[float]:
28
39
  client = self.load_model(async_mode=False)
29
40
  response = client.embeddings.create(
@@ -32,6 +43,7 @@ class OpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
32
43
  )
33
44
  return response.data[0].embedding
34
45
 
46
+ @retry_openai
35
47
  def embed_texts(self, texts: List[str]) -> List[List[float]]:
36
48
  client = self.load_model(async_mode=False)
37
49
  response = client.embeddings.create(
@@ -40,6 +52,7 @@ class OpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
40
52
  )
41
53
  return [item.embedding for item in response.data]
42
54
 
55
+ @retry_openai
43
56
  async def a_embed_text(self, text: str) -> List[float]:
44
57
  client = self.load_model(async_mode=True)
45
58
  response = await client.embeddings.create(
@@ -48,6 +61,7 @@ class OpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
48
61
  )
49
62
  return response.data[0].embedding
50
63
 
64
+ @retry_openai
51
65
  async def a_embed_texts(self, texts: List[str]) -> List[List[float]]:
52
66
  client = self.load_model(async_mode=True)
53
67
  response = await client.embeddings.create(
@@ -56,11 +70,39 @@ class OpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
56
70
  )
57
71
  return [item.embedding for item in response.data]
58
72
 
59
- def get_model_name(self) -> str:
73
+ ###############################################
74
+ # Model
75
+ ###############################################
76
+
77
+ def get_model_name(self):
60
78
  return self.model_name
61
79
 
62
- def load_model(self, async_mode: bool):
80
+ def load_model(self, async_mode: bool = False):
63
81
  if not async_mode:
64
- return OpenAI(api_key=self._openai_api_key)
82
+ return self._build_client(OpenAI)
83
+ return self._build_client(AsyncOpenAI)
65
84
 
66
- return AsyncOpenAI(api_key=self._openai_api_key)
85
+ def _client_kwargs(self) -> Dict:
86
+ """
87
+ If Tenacity is managing retries, force OpenAI SDK retries off to avoid double retries.
88
+ If the user opts into SDK retries for 'openai' via DEEPEVAL_SDK_RETRY_PROVIDERS,
89
+ leave their retry settings as is.
90
+ """
91
+ kwargs = dict(self.kwargs or {})
92
+ if not sdk_retries_for(PS.OPENAI):
93
+ kwargs["max_retries"] = 0
94
+ return kwargs
95
+
96
+ def _build_client(self, cls):
97
+ kw = dict(
98
+ api_key=self._openai_api_key,
99
+ **self._client_kwargs(),
100
+ )
101
+ try:
102
+ return cls(**kw)
103
+ except TypeError as e:
104
+ # older OpenAI SDKs may not accept max_retries, in that case remove and retry once
105
+ if "max_retries" in str(e):
106
+ kw.pop("max_retries", None)
107
+ return cls(**kw)
108
+ raise