deepeval-3.5.2-py3-none-any.whl → deepeval-3.5.4-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- deepeval/_version.py +1 -1
- deepeval/config/settings.py +94 -2
- deepeval/config/utils.py +54 -1
- deepeval/constants.py +27 -0
- deepeval/integrations/pydantic_ai/__init__.py +3 -1
- deepeval/integrations/pydantic_ai/agent.py +339 -0
- deepeval/integrations/pydantic_ai/patcher.py +479 -406
- deepeval/integrations/pydantic_ai/utils.py +239 -2
- deepeval/metrics/mcp_use_metric/mcp_use_metric.py +2 -1
- deepeval/metrics/non_advice/non_advice.py +2 -2
- deepeval/metrics/pii_leakage/pii_leakage.py +2 -2
- deepeval/models/embedding_models/azure_embedding_model.py +40 -9
- deepeval/models/embedding_models/local_embedding_model.py +52 -9
- deepeval/models/embedding_models/ollama_embedding_model.py +25 -7
- deepeval/models/embedding_models/openai_embedding_model.py +47 -5
- deepeval/models/llms/amazon_bedrock_model.py +31 -4
- deepeval/models/llms/anthropic_model.py +39 -13
- deepeval/models/llms/azure_model.py +37 -38
- deepeval/models/llms/deepseek_model.py +36 -7
- deepeval/models/llms/gemini_model.py +10 -0
- deepeval/models/llms/grok_model.py +50 -3
- deepeval/models/llms/kimi_model.py +37 -7
- deepeval/models/llms/local_model.py +38 -12
- deepeval/models/llms/ollama_model.py +15 -3
- deepeval/models/llms/openai_model.py +37 -44
- deepeval/models/mlllms/gemini_model.py +21 -3
- deepeval/models/mlllms/ollama_model.py +38 -13
- deepeval/models/mlllms/openai_model.py +18 -42
- deepeval/models/retry_policy.py +548 -64
- deepeval/tracing/tracing.py +87 -0
- {deepeval-3.5.2.dist-info → deepeval-3.5.4.dist-info}/METADATA +1 -1
- {deepeval-3.5.2.dist-info → deepeval-3.5.4.dist-info}/RECORD +35 -34
- {deepeval-3.5.2.dist-info → deepeval-3.5.4.dist-info}/LICENSE.md +0 -0
- {deepeval-3.5.2.dist-info → deepeval-3.5.4.dist-info}/WHEEL +0 -0
- {deepeval-3.5.2.dist-info → deepeval-3.5.4.dist-info}/entry_points.txt +0 -0
```diff
--- a/deepeval/integrations/pydantic_ai/utils.py
+++ b/deepeval/integrations/pydantic_ai/utils.py
@@ -1,8 +1,29 @@
-from
-from
+from time import perf_counter
+from contextlib import asynccontextmanager
+import inspect
+import functools
+from typing import Any, Callable, List, Optional
+
+from pydantic_ai.models import Model
 from pydantic_ai.agent import AgentRunResult
 from pydantic_ai._run_context import RunContext
+from pydantic_ai.messages import (
+    ModelRequest,
+    ModelResponse,
+    ModelResponsePart,
+    SystemPromptPart,
+    TextPart,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+)
+
+from deepeval.prompt import Prompt
+from deepeval.tracing.tracing import Observer
+from deepeval.metrics.base_metric import BaseMetric
 from deepeval.test_case.llm_test_case import ToolCall
+from deepeval.tracing.context import current_trace_context, current_span_context
+from deepeval.tracing.types import AgentSpan, LlmOutput, LlmSpan, LlmToolCall
 
 
 # llm tools called
@@ -84,3 +105,219 @@ def sanitize_run_context(value):
         return {sanitize_run_context(v) for v in value}
 
     return value
+
+
+def patch_llm_model(
+    model: Model,
+    llm_metric_collection: Optional[str] = None,
+    llm_metrics: Optional[List[BaseMetric]] = None,
+    llm_prompt: Optional[Prompt] = None,
+):
+    original_func = model.request
+    sig = inspect.signature(original_func)
+
+    try:
+        model_name = model.model_name
+    except Exception:
+        model_name = "unknown"
+
+    @functools.wraps(original_func)
+    async def wrapper(*args, **kwargs):
+        bound = sig.bind_partial(*args, **kwargs)
+        bound.apply_defaults()
+        request = bound.arguments.get("messages", [])
+
+        with Observer(
+            span_type="llm",
+            func_name="LLM",
+            observe_kwargs={"model": model_name},
+            metrics=llm_metrics,
+            metric_collection=llm_metric_collection,
+        ) as observer:
+            result = await original_func(*args, **kwargs)
+            observer.update_span_properties = (
+                lambda llm_span: set_llm_span_attributes(
+                    llm_span, request, result, llm_prompt
+                )
+            )
+            observer.result = result
+            return result
+
+    model.request = wrapper
+
+    stream_original_func = model.request_stream
+    stream_sig = inspect.signature(stream_original_func)
+
+    @asynccontextmanager
+    async def stream_wrapper(*args, **kwargs):
+        bound = stream_sig.bind_partial(*args, **kwargs)
+        bound.apply_defaults()
+        request = bound.arguments.get("messages", [])
+
+        with Observer(
+            span_type="llm",
+            func_name="LLM",
+            observe_kwargs={"model": model_name},
+            metrics=llm_metrics,
+            metric_collection=llm_metric_collection,
+        ) as observer:
+            llm_span: LlmSpan = current_span_context.get()
+            async with stream_original_func(
+                *args, **kwargs
+            ) as streamed_response:
+                try:
+                    yield streamed_response
+                    if not llm_span.token_intervals:
+                        llm_span.token_intervals = {perf_counter(): "NA"}
+                    else:
+                        llm_span.token_intervals[perf_counter()] = "NA"
+                finally:
+                    try:
+                        result = streamed_response.get()
+                        observer.update_span_properties = (
+                            lambda llm_span: set_llm_span_attributes(
+                                llm_span, request, result, llm_prompt
+                            )
+                        )
+                        observer.result = result
+                    except Exception:
+                        pass
+
+    model.request_stream = stream_wrapper
+
+
+def create_patched_tool(
+    func: Callable,
+    metrics: Optional[List[BaseMetric]] = None,
+    metric_collection: Optional[str] = None,
+):
+    import asyncio
+
+    original_func = func
+
+    is_async = asyncio.iscoroutinefunction(original_func)
+
+    if is_async:
+
+        @functools.wraps(original_func)
+        async def async_wrapper(*args, **kwargs):
+            sanitized_args = sanitize_run_context(args)
+            sanitized_kwargs = sanitize_run_context(kwargs)
+            with Observer(
+                span_type="tool",
+                func_name=original_func.__name__,
+                metrics=metrics,
+                metric_collection=metric_collection,
+                function_kwargs={"args": sanitized_args, **sanitized_kwargs},
+            ) as observer:
+                result = await original_func(*args, **kwargs)
+                observer.result = result
+
+            return result
+
+        return async_wrapper
+    else:
+
+        @functools.wraps(original_func)
+        def sync_wrapper(*args, **kwargs):
+            sanitized_args = sanitize_run_context(args)
+            sanitized_kwargs = sanitize_run_context(kwargs)
+            with Observer(
+                span_type="tool",
+                func_name=original_func.__name__,
+                metrics=metrics,
+                metric_collection=metric_collection,
+                function_kwargs={"args": sanitized_args, **sanitized_kwargs},
+            ) as observer:
+                result = original_func(*args, **kwargs)
+                observer.result = result
+
+            return result
+
+        return sync_wrapper
+
+
+def update_trace_context(
+    trace_name: Optional[str] = None,
+    trace_tags: Optional[List[str]] = None,
+    trace_metadata: Optional[dict] = None,
+    trace_thread_id: Optional[str] = None,
+    trace_user_id: Optional[str] = None,
+    trace_metric_collection: Optional[str] = None,
+    trace_metrics: Optional[List[BaseMetric]] = None,
+    trace_input: Optional[Any] = None,
+    trace_output: Optional[Any] = None,
+):
+
+    current_trace = current_trace_context.get()
+
+    if trace_name:
+        current_trace.name = trace_name
+    if trace_tags:
+        current_trace.tags = trace_tags
+    if trace_metadata:
+        current_trace.metadata = trace_metadata
+    if trace_thread_id:
+        current_trace.thread_id = trace_thread_id
+    if trace_user_id:
+        current_trace.user_id = trace_user_id
+    if trace_metric_collection:
+        current_trace.metric_collection = trace_metric_collection
+    if trace_metrics:
+        current_trace.metrics = trace_metrics
+    if trace_input:
+        current_trace.input = trace_input
+    if trace_output:
+        current_trace.output = trace_output
+
+
+def set_llm_span_attributes(
+    llm_span: LlmSpan,
+    requests: List[ModelRequest],
+    result: ModelResponse,
+    llm_prompt: Optional[Prompt] = None,
+):
+    llm_span.prompt = llm_prompt
+
+    input = []
+    for request in requests:
+        for part in request.parts:
+            if isinstance(part, SystemPromptPart):
+                input.append({"role": "System", "content": part.content})
+            elif isinstance(part, UserPromptPart):
+                input.append({"role": "User", "content": part.content})
+            elif isinstance(part, ToolCallPart):
+                input.append(
+                    {
+                        "role": "Tool Call",
+                        "name": part.tool_name,
+                        "content": part.args_as_json_str(),
+                    }
+                )
+            elif isinstance(part, ToolReturnPart):
+                input.append(
+                    {
+                        "role": "Tool Return",
+                        "name": part.tool_name,
+                        "content": part.model_response_str(),
+                    }
+                )
+    llm_span.input = input
+
+    content = ""
+    tool_calls = []
+    for part in result.parts:
+        if isinstance(part, TextPart):
+            content += part.content + "\n"
+        elif isinstance(part, ToolCallPart):
+            tool_calls.append(
+                LlmToolCall(name=part.tool_name, args=part.args_as_dict())
+            )
+    llm_span.output = LlmOutput(
+        role="Assistant", content=content, tool_calls=tool_calls
+    )
+    llm_span.tools_called = extract_tools_called_from_llm_response(result.parts)
+
+
+def set_agent_span_attributes(agent_span: AgentSpan, result: AgentRunResult):
+    agent_span.tools_called = extract_tools_called(result)
```
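The new helpers above are the tracing glue for the pydantic-ai integration: `patch_llm_model` wraps a model's `request`/`request_stream` in a deepeval `Observer` LLM span, `create_patched_tool` wraps sync or async tool functions in tool spans, and `update_trace_context` lets user code annotate the currently open trace. A minimal sketch of wiring them by hand follows; the `OpenAIModel` import, the weather tool, and the assumption that `OPENAI_API_KEY` is configured are illustrative and not part of this diff.

```python
# Sketch only: applying the utils.py helpers manually. In practice the
# integration's patcher/agent modules in this release are expected to do this.
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel  # assumed provider model

from deepeval.integrations.pydantic_ai.utils import (
    patch_llm_model,
    create_patched_tool,
)

model = OpenAIModel("gpt-4o-mini")
agent = Agent(model)

# Every model.request / model.request_stream call now opens an "llm" span.
patch_llm_model(model, llm_metric_collection="llm-metrics")

# Hypothetical tool: wrapping it yields a version that opens a "tool" span.
def get_weather(city: str) -> str:
    return f"Sunny in {city}"

observed_weather = create_patched_tool(get_weather, metric_collection="tools")
```

`update_trace_context` reads the active trace from `current_trace_context`, so it only has an effect when called while a deepeval trace is open.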
```diff
--- a/deepeval/metrics/mcp_use_metric/mcp_use_metric.py
+++ b/deepeval/metrics/mcp_use_metric/mcp_use_metric.py
@@ -283,8 +283,9 @@ class MCPUseMetric(BaseMetric):
         mcp_resources_called: List[MCPResourceCall],
         mcp_prompts_called: List[MCPPromptCall],
     ) -> tuple[str, str]:
+        available_primitives = "MCP Primitives Available: \n"
         for mcp_server in mcp_servers:
-            available_primitives
+            available_primitives += f"MCP Server {mcp_server.server_name}\n"
             available_primitives += (
                 (
                     "\nAvailable Tools:\n[\n"
```
```diff
--- a/deepeval/metrics/non_advice/non_advice.py
+++ b/deepeval/metrics/non_advice/non_advice.py
@@ -43,7 +43,7 @@ class NonAdviceMetric(BaseMetric):
                 "or ['financial', 'medical'] for multiple types."
             )
 
-        self.threshold =
+        self.threshold = 1 if strict_mode else threshold
         self.advice_types = advice_types
         self.model, self.using_native_model = initialize_model(model)
         self.evaluation_model = self.model.get_model_name()
@@ -293,7 +293,7 @@ class NonAdviceMetric(BaseMetric):
                 appropriate_advice_count += 1
 
         score = appropriate_advice_count / number_of_verdicts
-        return
+        return 0 if self.strict_mode and score < self.threshold else score
 
     def is_successful(self) -> bool:
         if self.error is not None:
```
```diff
--- a/deepeval/metrics/pii_leakage/pii_leakage.py
+++ b/deepeval/metrics/pii_leakage/pii_leakage.py
@@ -35,7 +35,7 @@ class PIILeakageMetric(BaseMetric):
         verbose_mode: bool = False,
         evaluation_template: Type[PIILeakageTemplate] = PIILeakageTemplate,
     ):
-        self.threshold =
+        self.threshold = 1 if strict_mode else threshold
         self.model, self.using_native_model = initialize_model(model)
         self.evaluation_model = self.model.get_model_name()
         self.include_reason = include_reason
@@ -284,7 +284,7 @@ class PIILeakageMetric(BaseMetric):
                 no_privacy_count += 1
 
         score = no_privacy_count / number_of_verdicts
-        return
+        return 0 if self.strict_mode and score < self.threshold else score
 
     def is_successful(self) -> bool:
         if self.error is not None:
```
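The two metric diffs above (non-advice and PII leakage) apply the same strict-mode rule used across deepeval metrics: with `strict_mode=True` the threshold is pinned to 1 and any score below it collapses to 0. A standalone sketch of that rule (hypothetical helper, not deepeval API):

```python
def apply_strict_mode(score: float, threshold: float, strict_mode: bool) -> float:
    # Strict mode pins the passing bar at 1 and turns any miss into a hard 0;
    # otherwise the raw ratio of passing verdicts is returned unchanged.
    threshold = 1 if strict_mode else threshold
    return 0 if strict_mode and score < threshold else score
```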
```diff
--- a/deepeval/models/embedding_models/azure_embedding_model.py
+++ b/deepeval/models/embedding_models/azure_embedding_model.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Dict, List
 from openai import AzureOpenAI, AsyncAzureOpenAI
 from deepeval.key_handler import (
     EmbeddingKeyValues,
@@ -6,10 +6,18 @@ from deepeval.key_handler import (
     KEY_FILE_HANDLER,
 )
 from deepeval.models import DeepEvalBaseEmbeddingModel
+from deepeval.models.retry_policy import (
+    create_retry_decorator,
+    sdk_retries_for,
+)
+from deepeval.constants import ProviderSlug as PS
+
+
+retry_azure = create_retry_decorator(PS.AZURE)
 
 
 class AzureOpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
-    def __init__(self):
+    def __init__(self, **kwargs):
         self.azure_openai_api_key = KEY_FILE_HANDLER.fetch_data(
             ModelKeyValues.AZURE_OPENAI_API_KEY
         )
@@ -23,7 +31,9 @@ class AzureOpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
             ModelKeyValues.AZURE_OPENAI_ENDPOINT
         )
         self.model_name = self.azure_embedding_deployment
+        self.kwargs = kwargs
 
+    @retry_azure
     def embed_text(self, text: str) -> List[float]:
         client = self.load_model(async_mode=False)
         response = client.embeddings.create(
@@ -32,6 +42,7 @@ class AzureOpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
         )
         return response.data[0].embedding
 
+    @retry_azure
     def embed_texts(self, texts: List[str]) -> List[List[float]]:
         client = self.load_model(async_mode=False)
         response = client.embeddings.create(
@@ -40,6 +51,7 @@ class AzureOpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
         )
         return [item.embedding for item in response.data]
 
+    @retry_azure
     async def a_embed_text(self, text: str) -> List[float]:
         client = self.load_model(async_mode=True)
         response = await client.embeddings.create(
@@ -48,6 +60,7 @@ class AzureOpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
         )
         return response.data[0].embedding
 
+    @retry_azure
     async def a_embed_texts(self, texts: List[str]) -> List[List[float]]:
         client = self.load_model(async_mode=True)
         response = await client.embeddings.create(
@@ -61,15 +74,33 @@ class AzureOpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
 
     def load_model(self, async_mode: bool = False):
         if not async_mode:
-            return AzureOpenAI
-
-
-
-
-
-
+            return self._build_client(AzureOpenAI)
+        return self._build_client(AsyncAzureOpenAI)
+
+    def _client_kwargs(self) -> Dict:
+        """
+        If Tenacity is managing retries, force OpenAI SDK retries off to avoid double retries.
+        If the user opts into SDK retries for 'azure' via DEEPEVAL_SDK_RETRY_PROVIDERS,
+        leave their retry settings as is.
+        """
+        kwargs = dict(self.kwargs or {})
+        if not sdk_retries_for(PS.AZURE):
+            kwargs["max_retries"] = 0
+        return kwargs
+
+    def _build_client(self, cls):
+        kw = dict(
             api_key=self.azure_openai_api_key,
             api_version=self.openai_api_version,
             azure_endpoint=self.azure_endpoint,
             azure_deployment=self.azure_embedding_deployment,
+            **self._client_kwargs(),
         )
+        try:
+            return cls(**kw)
+        except TypeError as e:
+            # older OpenAI SDKs may not accept max_retries, in that case remove and retry once
+            if "max_retries" in str(e):
+                kw.pop("max_retries", None)
+                return cls(**kw)
+            raise
```
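azure_embedding_model.py is the first embedding backend to adopt the centralized retry policy: a module-level `retry_azure = create_retry_decorator(PS.AZURE)` wraps every embed call with Tenacity, constructor `**kwargs` flow into the client through `_build_client`, and SDK-level retries are zeroed out unless the user opts back in. A hedged usage sketch, assuming Azure credentials are already configured through deepeval's key handler and that `DEEPEVAL_SDK_RETRY_PROVIDERS` is read when the retry policy is consulted:

```python
import os

from deepeval.models.embedding_models.azure_embedding_model import (
    AzureOpenAIEmbeddingModel,
)

# Default: Tenacity owns retries, so the AzureOpenAI client is built with
# max_retries=0 while other kwargs (e.g. timeout) are forwarded untouched.
model = AzureOpenAIEmbeddingModel(timeout=30)

# Opting back into the OpenAI SDK's own retries for this provider: with the
# slug listed in DEEPEVAL_SDK_RETRY_PROVIDERS, sdk_retries_for(PS.AZURE) is
# truthy and a user-supplied max_retries survives into the client.
os.environ["DEEPEVAL_SDK_RETRY_PROVIDERS"] = "azure"
sdk_model = AzureOpenAIEmbeddingModel(max_retries=5, timeout=30)
```

The same `_client_kwargs`/`_build_client` shape repeats in the local, Ollama, and OpenAI embedding models below.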
```diff
--- a/deepeval/models/embedding_models/local_embedding_model.py
+++ b/deepeval/models/embedding_models/local_embedding_model.py
@@ -1,12 +1,21 @@
-from openai import OpenAI
-from typing import List
+from openai import OpenAI, AsyncOpenAI
+from typing import Dict, List
 
 from deepeval.key_handler import EmbeddingKeyValues, KEY_FILE_HANDLER
 from deepeval.models import DeepEvalBaseEmbeddingModel
+from deepeval.models.retry_policy import (
+    create_retry_decorator,
+    sdk_retries_for,
+)
+from deepeval.constants import ProviderSlug as PS
+
+
+# consistent retry rules
+retry_local = create_retry_decorator(PS.LOCAL)
 
 
 class LocalEmbeddingModel(DeepEvalBaseEmbeddingModel):
-    def __init__(self,
+    def __init__(self, **kwargs):
         self.base_url = KEY_FILE_HANDLER.fetch_data(
             EmbeddingKeyValues.LOCAL_EMBEDDING_BASE_URL
         )
@@ -16,13 +25,10 @@ class LocalEmbeddingModel(DeepEvalBaseEmbeddingModel):
         self.api_key = KEY_FILE_HANDLER.fetch_data(
             EmbeddingKeyValues.LOCAL_EMBEDDING_API_KEY
         )
-        self.args = args
         self.kwargs = kwargs
         super().__init__(model_name)
 
-
-        return OpenAI(base_url=self.base_url, api_key=self.api_key)
-
+    @retry_local
     def embed_text(self, text: str) -> List[float]:
         embedding_model = self.load_model()
         response = embedding_model.embeddings.create(
@@ -31,6 +37,7 @@ class LocalEmbeddingModel(DeepEvalBaseEmbeddingModel):
         )
         return response.data[0].embedding
 
+    @retry_local
     def embed_texts(self, texts: List[str]) -> List[List[float]]:
         embedding_model = self.load_model()
         response = embedding_model.embeddings.create(
@@ -39,21 +46,57 @@ class LocalEmbeddingModel(DeepEvalBaseEmbeddingModel):
         )
         return [data.embedding for data in response.data]
 
+    @retry_local
     async def a_embed_text(self, text: str) -> List[float]:
-        embedding_model = self.load_model()
+        embedding_model = self.load_model(async_mode=True)
         response = await embedding_model.embeddings.create(
             model=self.model_name,
             input=[text],
         )
         return response.data[0].embedding
 
+    @retry_local
     async def a_embed_texts(self, texts: List[str]) -> List[List[float]]:
-        embedding_model = self.load_model()
+        embedding_model = self.load_model(async_mode=True)
         response = await embedding_model.embeddings.create(
             model=self.model_name,
             input=texts,
         )
         return [data.embedding for data in response.data]
 
+    ###############################################
+    # Model
+    ###############################################
+
     def get_model_name(self):
         return self.model_name
+
+    def load_model(self, async_mode: bool = False):
+        if not async_mode:
+            return self._build_client(OpenAI)
+        return self._build_client(AsyncOpenAI)
+
+    def _client_kwargs(self) -> Dict:
+        """
+        If Tenacity manages retries, turn off OpenAI SDK retries to avoid double retrying.
+        If users opt into SDK retries via DEEPEVAL_SDK_RETRY_PROVIDERS=local, leave them enabled.
+        """
+        kwargs = dict(self.kwargs or {})
+        if not sdk_retries_for(PS.LOCAL):
+            kwargs["max_retries"] = 0
+        return kwargs
+
+    def _build_client(self, cls):
+        kw = dict(
+            api_key=self.api_key,
+            base_url=self.base_url,
+            **self._client_kwargs(),
+        )
+        try:
+            return cls(**kw)
+        except TypeError as e:
+            # Older OpenAI SDKs may not accept max_retries; drop and retry once.
+            if "max_retries" in str(e):
+                kw.pop("max_retries", None)
+                return cls(**kw)
+            raise
```
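Beyond the shared retry wiring, the local model's async methods previously reused the synchronous client; they now call `load_model(async_mode=True)` and get a real `AsyncOpenAI` client. A short sketch, assuming a local OpenAI-compatible endpoint (base URL, model name, API key) is already configured in deepeval's settings:

```python
import asyncio

from deepeval.models.embedding_models.local_embedding_model import (
    LocalEmbeddingModel,
)

async def main() -> None:
    model = LocalEmbeddingModel(timeout=10)     # kwargs reach the OpenAI client
    vector = await model.a_embed_text("hello")  # AsyncOpenAI under @retry_local
    print(len(vector))

asyncio.run(main())
```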
```diff
--- a/deepeval/models/embedding_models/ollama_embedding_model.py
+++ b/deepeval/models/embedding_models/ollama_embedding_model.py
@@ -3,6 +3,13 @@ from typing import List
 
 from deepeval.key_handler import EmbeddingKeyValues, KEY_FILE_HANDLER
 from deepeval.models import DeepEvalBaseEmbeddingModel
+from deepeval.models.retry_policy import (
+    create_retry_decorator,
+)
+from deepeval.constants import ProviderSlug as PS
+
+
+retry_ollama = create_retry_decorator(PS.OLLAMA)
 
 
 class OllamaEmbeddingModel(DeepEvalBaseEmbeddingModel):
@@ -13,6 +20,7 @@ class OllamaEmbeddingModel(DeepEvalBaseEmbeddingModel):
         model_name = KEY_FILE_HANDLER.fetch_data(
             EmbeddingKeyValues.LOCAL_EMBEDDING_MODEL_NAME
         )
+        # TODO: This is not being used. Clean it up in consistency PR
         self.api_key = KEY_FILE_HANDLER.fetch_data(
             EmbeddingKeyValues.LOCAL_EMBEDDING_API_KEY
         )
@@ -20,12 +28,7 @@ class OllamaEmbeddingModel(DeepEvalBaseEmbeddingModel):
         self.kwargs = kwargs
         super().__init__(model_name)
 
-
-        if not async_mode:
-            return Client(host=self.base_url)
-
-        return AsyncClient(host=self.base_url)
-
+    @retry_ollama
     def embed_text(self, text: str) -> List[float]:
         embedding_model = self.load_model()
         response = embedding_model.embed(
@@ -34,6 +37,7 @@ class OllamaEmbeddingModel(DeepEvalBaseEmbeddingModel):
         )
         return response["embeddings"][0]
 
+    @retry_ollama
     def embed_texts(self, texts: List[str]) -> List[List[float]]:
         embedding_model = self.load_model()
         response = embedding_model.embed(
@@ -42,6 +46,7 @@ class OllamaEmbeddingModel(DeepEvalBaseEmbeddingModel):
         )
         return response["embeddings"]
 
+    @retry_ollama
     async def a_embed_text(self, text: str) -> List[float]:
         embedding_model = self.load_model(async_mode=True)
         response = await embedding_model.embed(
@@ -50,6 +55,7 @@ class OllamaEmbeddingModel(DeepEvalBaseEmbeddingModel):
         )
         return response["embeddings"][0]
 
+    @retry_ollama
     async def a_embed_texts(self, texts: List[str]) -> List[List[float]]:
         embedding_model = self.load_model(async_mode=True)
         response = await embedding_model.embed(
@@ -58,5 +64,17 @@ class OllamaEmbeddingModel(DeepEvalBaseEmbeddingModel):
         )
         return response["embeddings"]
 
+    ###############################################
+    # Model
+    ###############################################
+
+    def load_model(self, async_mode: bool = False):
+        if not async_mode:
+            return self._build_client(Client)
+        return self._build_client(AsyncClient)
+
+    def _build_client(self, cls):
+        return cls(host=self.base_url, **self.kwargs)
+
     def get_model_name(self):
-        return self.model_name
+        return f"{self.model_name} (Ollama)"
```
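Two smaller behavior changes in the Ollama backend: constructor kwargs are now forwarded to the `ollama` `Client`/`AsyncClient` through `_build_client`, and `get_model_name()` labels the provider. A sketch, assuming an Ollama embedding model is configured in deepeval's settings and that the client accepts httpx-style kwargs such as `timeout`:

```python
from deepeval.models.embedding_models.ollama_embedding_model import (
    OllamaEmbeddingModel,
)

model = OllamaEmbeddingModel(timeout=60)  # becomes Client(host=..., timeout=60)
print(model.get_model_name())             # e.g. "nomic-embed-text (Ollama)"
```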
```diff
--- a/deepeval/models/embedding_models/openai_embedding_model.py
+++ b/deepeval/models/embedding_models/openai_embedding_model.py
@@ -1,6 +1,14 @@
-from typing import Optional, List
+from typing import Dict, Optional, List
 from openai import OpenAI, AsyncOpenAI
 from deepeval.models import DeepEvalBaseEmbeddingModel
+from deepeval.models.retry_policy import (
+    create_retry_decorator,
+    sdk_retries_for,
+)
+from deepeval.constants import ProviderSlug as PS
+
+
+retry_openai = create_retry_decorator(PS.OPENAI)
 
 valid_openai_embedding_models = [
     "text-embedding-3-small",
@@ -15,6 +23,7 @@ class OpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
         self,
         model: Optional[str] = None,
         _openai_api_key: Optional[str] = None,
+        **kwargs,
     ):
         model_name = model if model else default_openai_embedding_model
         if model_name not in valid_openai_embedding_models:
@@ -23,7 +32,9 @@ class OpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
             )
         self._openai_api_key = _openai_api_key
         self.model_name = model_name
+        self.kwargs = kwargs
 
+    @retry_openai
     def embed_text(self, text: str) -> List[float]:
         client = self.load_model(async_mode=False)
         response = client.embeddings.create(
@@ -32,6 +43,7 @@ class OpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
         )
         return response.data[0].embedding
 
+    @retry_openai
     def embed_texts(self, texts: List[str]) -> List[List[float]]:
         client = self.load_model(async_mode=False)
         response = client.embeddings.create(
@@ -40,6 +52,7 @@ class OpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
         )
         return [item.embedding for item in response.data]
 
+    @retry_openai
     async def a_embed_text(self, text: str) -> List[float]:
         client = self.load_model(async_mode=True)
         response = await client.embeddings.create(
@@ -48,6 +61,7 @@ class OpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
         )
         return response.data[0].embedding
 
+    @retry_openai
     async def a_embed_texts(self, texts: List[str]) -> List[List[float]]:
         client = self.load_model(async_mode=True)
         response = await client.embeddings.create(
@@ -56,11 +70,39 @@ class OpenAIEmbeddingModel(DeepEvalBaseEmbeddingModel):
         )
         return [item.embedding for item in response.data]
 
-
+    ###############################################
+    # Model
+    ###############################################
+
+    def get_model_name(self):
         return self.model_name
 
-    def load_model(self, async_mode: bool):
+    def load_model(self, async_mode: bool = False):
         if not async_mode:
-            return
+            return self._build_client(OpenAI)
+        return self._build_client(AsyncOpenAI)
 
-
+    def _client_kwargs(self) -> Dict:
+        """
+        If Tenacity is managing retries, force OpenAI SDK retries off to avoid double retries.
+        If the user opts into SDK retries for 'openai' via DEEPEVAL_SDK_RETRY_PROVIDERS,
+        leave their retry settings as is.
+        """
+        kwargs = dict(self.kwargs or {})
+        if not sdk_retries_for(PS.OPENAI):
+            kwargs["max_retries"] = 0
+        return kwargs
+
+    def _build_client(self, cls):
+        kw = dict(
+            api_key=self._openai_api_key,
+            **self._client_kwargs(),
+        )
+        try:
+            return cls(**kw)
+        except TypeError as e:
+            # older OpenAI SDKs may not accept max_retries, in that case remove and retry once
+            if "max_retries" in str(e):
+                kw.pop("max_retries", None)
+                return cls(**kw)
+            raise
```
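The OpenAI embedding model closes out the pattern: `@retry_openai` on every embed method, constructor `**kwargs` forwarded to `OpenAI`/`AsyncOpenAI`, SDK retries disabled unless `DEEPEVAL_SDK_RETRY_PROVIDERS` includes the provider, and a `TypeError` fallback that drops `max_retries` for older OpenAI SDKs. A usage sketch, assuming `OPENAI_API_KEY` is set in the environment:

```python
from deepeval.models.embedding_models.openai_embedding_model import (
    OpenAIEmbeddingModel,
)

# "text-embedding-3-small" is in valid_openai_embedding_models, so this passes
# validation; client kwargs such as timeout flow through _build_client.
model = OpenAIEmbeddingModel(model="text-embedding-3-small", timeout=30)
vector = model.embed_text("retry policies should not stack")  # under @retry_openai
print(len(vector))
```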