deepeval 3.8.0__py3-none-any.whl → 3.8.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
@@ -1,5 +1,77 @@
-from typing import Any, List, Dict, Optional
+import uuid
+from typing import Any, List, Dict, Optional, Union, Literal, Callable
+from time import perf_counter
 from langchain_core.outputs import ChatGeneration
+from rich.progress import Progress
+
+from deepeval.metrics import BaseMetric
+from deepeval.tracing.context import current_span_context, current_trace_context
+from deepeval.tracing.tracing import trace_manager
+from deepeval.tracing.types import (
+    AgentSpan,
+    BaseSpan,
+    LlmSpan,
+    RetrieverSpan,
+    SpanType,
+    ToolSpan,
+    TraceSpanStatus,
+)
+
+
+def convert_chat_messages_to_input(
+    messages: list[list[Any]], **kwargs
+) -> List[Dict[str, str]]:
+    """
+    Convert LangChain chat messages to our internal format.
+
+    Args:
+        messages: list[list[BaseMessage]] - outer list is batches, inner is messages.
+        **kwargs: May contain invocation_params with tools definitions.
+
+    Returns:
+        List of dicts with 'role' and 'content' keys, matching the schema used
+        by parse_prompts_to_messages for consistency.
+    """
+    # Valid roles matching parse_prompts_to_messages
+    ROLE_MAPPING = {
+        "human": "human",
+        "user": "human",
+        "ai": "ai",
+        "assistant": "ai",
+        "system": "system",
+        "tool": "tool",
+        "function": "function",
+    }
+
+    result: List[Dict[str, str]] = []
+    for batch in messages:
+        for msg in batch:
+            # BaseMessage has .type (role) and .content
+            raw_role = getattr(msg, "type", "unknown")
+            content = getattr(msg, "content", "")
+
+            # Normalize role using same conventions as prompt parsing
+            role = ROLE_MAPPING.get(raw_role.lower(), raw_role)
+
+            # Convert content to string (handles empty content, lists, etc.)
+            if isinstance(content, list):
+                # Some messages have content as a list of content blocks
+                content_str = " ".join(
+                    str(c.get("text", c) if isinstance(c, dict) else c)
+                    for c in content
+                )
+            else:
+                content_str = str(content) if content else ""
+
+            result.append({"role": role, "content": content_str})
+
+    # Append tool definitions if present which matches parse_prompts_to_messages behavior
+    tools = kwargs.get("invocation_params", {}).get("tools", None)
+    if tools and isinstance(tools, list):
+        for tool in tools:
+            result.append({"role": "Tool Input", "content": str(tool)})
+
+    return result


 def parse_prompts_to_messages(
@@ -112,27 +184,6 @@ def safe_extract_model_name(
     return None


-from typing import Any, List, Dict, Optional, Union, Literal, Callable
-from langchain_core.outputs import ChatGeneration
-from time import perf_counter
-import uuid
-from rich.progress import Progress
-from deepeval.tracing.tracing import Observer
-
-from deepeval.metrics import BaseMetric
-from deepeval.tracing.context import current_span_context, current_trace_context
-from deepeval.tracing.tracing import trace_manager
-from deepeval.tracing.types import (
-    AgentSpan,
-    BaseSpan,
-    LlmSpan,
-    RetrieverSpan,
-    SpanType,
-    ToolSpan,
-    TraceSpanStatus,
-)
-
-
 def enter_current_context(
     span_type: Optional[
         Union[Literal["agent", "llm", "retriever", "tool"], str]
@@ -239,8 +290,8 @@ def enter_current_context(

     if (
         parent_span
-        and getattr(parent_span, "progress", None) is not None
-        and getattr(parent_span, "pbar_callback_id", None) is not None
+        and parent_span.progress is not None
+        and parent_span.pbar_callback_id is not None
     ):
         progress = parent_span.progress
         pbar_callback_id = parent_span.pbar_callback_id
@@ -40,6 +40,7 @@ try:
     from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
         OTLPSpanExporter,
     )
+    from opentelemetry.trace import set_tracer_provider
     from pydantic_ai.models.instrumented import (
         InstrumentationSettings as _BaseInstrumentationSettings,
     )
@@ -131,7 +132,12 @@ class ConfidentInstrumentationSettings(InstrumentationSettings):
     ):
         is_dependency_installed()

-        _environment = os.getenv("CONFIDENT_TRACE_ENVIRONMENT", "development")
+        if trace_manager.environment is not None:
+            _environment = trace_manager.environment
+        elif settings.CONFIDENT_TRACE_ENVIRONMENT is not None:
+            _environment = settings.CONFIDENT_TRACE_ENVIRONMENT
+        else:
+            _environment = "development"
         if _environment and _environment in [
             "production",
             "staging",
@@ -176,6 +182,12 @@ class ConfidentInstrumentationSettings(InstrumentationSettings):
                 )
             )
         )
+        try:
+            set_tracer_provider(trace_provider)
+        except Exception as e:
+            # Handle case where provider is already set (optional warning)
+            logger.warning(f"Could not set global tracer provider: {e}")
+
         super().__init__(tracer_provider=trace_provider)


@@ -234,16 +246,14 @@ class SpanInterceptor(SpanProcessor):
         )

         # set agent name and metric collection
-        if span.attributes.get("agent_name"):
-            span.set_attribute("confident.span.type", "agent")
-            span.set_attribute(
-                "confident.span.name", span.attributes.get("agent_name")
-            )
-            if self.settings.agent_metric_collection:
-                span.set_attribute(
-                    "confident.span.metric_collection",
-                    self.settings.agent_metric_collection,
-                )
+        agent_name = (
+            span.attributes.get("gen_ai.agent.name")
+            or span.attributes.get("pydantic_ai.agent.name")
+            or span.attributes.get("agent_name")
+        )
+
+        if agent_name:
+            self._add_agent_span(span, agent_name)

         # set llm metric collection
         if span.attributes.get("gen_ai.operation.name") in [
@@ -270,6 +280,19 @@ class SpanInterceptor(SpanProcessor):
                 )

     def on_end(self, span):
+
+        already_processed = (
+            span.attributes.get("confident.span.type") == "agent"
+        )
+        if not already_processed:
+            agent_name = (
+                span.attributes.get("gen_ai.agent.name")
+                or span.attributes.get("pydantic_ai.agent.name")
+                or span.attributes.get("agent_name")
+            )
+            if agent_name:
+                self._add_agent_span(span, agent_name)
+
         if self.settings.is_test_mode:
             if span.attributes.get("confident.span.type") == "agent":

@@ -323,3 +346,12 @@ class SpanInterceptor(SpanProcessor):
                 trace.end_time = perf_counter()
                 trace_manager.traces_to_evaluate.append(trace)
                 test_exporter.clear_span_json_list()
+
+    def _add_agent_span(self, span, name):
+        span.set_attribute("confident.span.type", "agent")
+        span.set_attribute("confident.span.name", name)
+        if self.settings.agent_metric_collection:
+            span.set_attribute(
+                "confident.span.metric_collection",
+                self.settings.agent_metric_collection,
+            )
@@ -2,6 +2,7 @@ import warnings
 from typing import Optional
 from deepeval.telemetry import capture_tracing_integration
 from deepeval.config.settings import get_settings
+import logging

 try:
     from opentelemetry import trace
@@ -24,6 +25,9 @@ def is_opentelemetry_available():
     return True


+logger = logging.getLogger(__name__)
+settings = get_settings()
+
 settings = get_settings()
 # OTLP_ENDPOINT = "https://otel.confident-ai.com/v1/traces"

@@ -51,6 +55,11 @@ def instrument_pydantic_ai(api_key: Optional[str] = None):
             )
         )
     )
+    try:
+        trace.set_tracer_provider(tracer_provider)
+    except Exception as e:
+        # Handle case where provider is already set (optional warning)
+        logger.warning(f"Could not set global tracer provider: {e}")

     # create an instrumented exporter
     from pydantic_ai.models.instrumented import InstrumentationSettings
@@ -23,6 +23,7 @@ from deepeval.metrics.contextual_recall.schema import (
     ContextualRecallVerdict,
     Verdicts,
     ContextualRecallScoreReason,
+    VerdictWithExpectedOutput,
 )
 from deepeval.metrics.api import metric_data_manager

@@ -93,7 +94,7 @@ class ContextualRecallMetric(BaseMetric):
         expected_output = test_case.expected_output
         retrieval_context = test_case.retrieval_context

-        self.verdicts: List[ContextualRecallVerdict] = (
+        self.verdicts: List[VerdictWithExpectedOutput] = (
             self._generate_verdicts(
                 expected_output, retrieval_context, multimodal
             )
@@ -144,7 +145,7 @@ class ContextualRecallMetric(BaseMetric):
         expected_output = test_case.expected_output
         retrieval_context = test_case.retrieval_context

-        self.verdicts: List[ContextualRecallVerdict] = (
+        self.verdicts: List[VerdictWithExpectedOutput] = (
             await self._a_generate_verdicts(
                 expected_output, retrieval_context, multimodal
             )
@@ -241,13 +242,13 @@ class ContextualRecallMetric(BaseMetric):
         expected_output: str,
         retrieval_context: List[str],
         multimodal: bool,
-    ) -> List[ContextualRecallVerdict]:
+    ) -> List[VerdictWithExpectedOutput]:
         prompt = self.evaluation_template.generate_verdicts(
             expected_output=expected_output,
             retrieval_context=retrieval_context,
             multimodal=multimodal,
         )
-        return await a_generate_with_schema_and_extract(
+        verdicts = await a_generate_with_schema_and_extract(
             metric=self,
             prompt=prompt,
             schema_cls=Verdicts,
@@ -256,19 +257,28 @@ class ContextualRecallMetric(BaseMetric):
                 ContextualRecallVerdict(**item) for item in data["verdicts"]
             ],
         )
+        final_verdicts = []
+        for verdict in verdicts:
+            new_verdict = VerdictWithExpectedOutput(
+                verdict=verdict.verdict,
+                reason=verdict.reason,
+                expected_output=expected_output,
+            )
+            final_verdicts.append(new_verdict)
+        return final_verdicts

     def _generate_verdicts(
         self,
         expected_output: str,
         retrieval_context: List[str],
         multimodal: bool,
-    ) -> List[ContextualRecallVerdict]:
+    ) -> List[VerdictWithExpectedOutput]:
         prompt = self.evaluation_template.generate_verdicts(
             expected_output=expected_output,
             retrieval_context=retrieval_context,
             multimodal=multimodal,
         )
-        return generate_with_schema_and_extract(
+        verdicts = generate_with_schema_and_extract(
             metric=self,
             prompt=prompt,
             schema_cls=Verdicts,
@@ -277,6 +287,15 @@ class ContextualRecallMetric(BaseMetric):
                 ContextualRecallVerdict(**item) for item in data["verdicts"]
             ],
         )
+        final_verdicts = []
+        for verdict in verdicts:
+            new_verdict = VerdictWithExpectedOutput(
+                verdict=verdict.verdict,
+                reason=verdict.reason,
+                expected_output=expected_output,
+            )
+            final_verdicts.append(new_verdict)
+        return final_verdicts

     def is_successful(self) -> bool:
         if self.error is not None:
@@ -7,6 +7,12 @@ class ContextualRecallVerdict(BaseModel):
     reason: str


+class VerdictWithExpectedOutput(BaseModel):
+    verdict: str
+    reason: str
+    expected_output: str
+
+
 class Verdicts(BaseModel):
     verdicts: List[ContextualRecallVerdict]

@@ -85,7 +85,12 @@ class ImageCoherenceMetric(BaseMetric):
         self.contexts_below = []
         self.scores = []
         self.reasons = []
-        for image_index in self.get_image_indices(actual_output):
+        image_indices = self.get_image_indices(actual_output)
+        if not image_indices:
+            raise ValueError(
+                f"The test case must have atleast one image in the `actual_output` to calculate {self.__name__} score"
+            )
+        for image_index in image_indices:
             context_above, context_below = self.get_image_context(
                 image_index, actual_output
             )
@@ -188,6 +193,10 @@ class ImageCoherenceMetric(BaseMetric):

         tasks = []
         image_indices = self.get_image_indices(actual_output)
+        if not image_indices:
+            raise ValueError(
+                f"The test case must have atleast one image in the `actual_output` to calculate {self.__name__} score"
+            )
         for image_index in image_indices:
             context_above, context_below = self.get_image_context(
                 image_index, actual_output
@@ -86,7 +86,12 @@ class ImageHelpfulnessMetric(BaseMetric):
         self.contexts_below = []
         self.scores = []
         self.reasons = []
-        for image_index in self.get_image_indices(actual_output):
+        image_indices = self.get_image_indices(actual_output)
+        if not image_indices:
+            raise ValueError(
+                f"The test case must have atleast one image in the `actual_output` to calculate {self.__name__} score"
+            )
+        for image_index in image_indices:
             context_above, context_below = self.get_image_context(
                 image_index, actual_output
             )
@@ -189,6 +194,10 @@ class ImageHelpfulnessMetric(BaseMetric):

         tasks = []
         image_indices = self.get_image_indices(actual_output)
+        if not image_indices:
+            raise ValueError(
+                f"The test case must have atleast one image in the `actual_output` to calculate {self.__name__} score"
+            )
         for image_index in image_indices:
             context_above, context_below = self.get_image_context(
                 image_index, actual_output
@@ -86,7 +86,12 @@ class ImageReferenceMetric(BaseMetric):
         self.contexts_below = []
         self.scores = []
         self.reasons = []
-        for image_index in self.get_image_indices(actual_output):
+        image_indices = self.get_image_indices(actual_output)
+        if not image_indices:
+            raise ValueError(
+                f"The test case must have atleast one image in the `actual_output` to calculate {self.__name__} score"
+            )
+        for image_index in image_indices:
             context_above, context_below = self.get_image_context(
                 image_index, actual_output
             )
@@ -189,6 +194,10 @@ class ImageReferenceMetric(BaseMetric):

         tasks = []
         image_indices = self.get_image_indices(actual_output)
+        if not image_indices:
+            raise ValueError(
+                f"The test case must have atleast one image in the `actual_output` to calculate {self.__name__} score"
+            )
         for image_index in image_indices:
             context_above, context_below = self.get_image_context(
                 image_index, actual_output
deepeval/metrics/utils.py CHANGED
@@ -312,7 +312,7 @@ def check_llm_test_case_params(
             if isinstance(ele, MLLMImage):
                 count += 1
         if count != actual_output_image_count:
-            error_str = f"Unable to evaluate test cases with '{actual_output_image_count}' output images using the '{metric.__name__}' metric. `{count}` found."
+            error_str = f"Can only evaluate test cases with '{actual_output_image_count}' output images using the '{metric.__name__}' metric. `{count}` found."
             raise ValueError(error_str)

     if isinstance(test_case, LLMTestCase) is False:
@@ -320,6 +320,17 @@ def check_llm_test_case_params(
         metric.error = error_str
         raise ValueError(error_str)

+    # Centralized: if a metric requires actual_output, reject empty/whitespace
+    # (including empty multimodal outputs) as "missing params".
+    if LLMTestCaseParams.ACTUAL_OUTPUT in test_case_params:
+        actual_output = getattr(
+            test_case, LLMTestCaseParams.ACTUAL_OUTPUT.value
+        )
+        if isinstance(actual_output, str) and actual_output == "":
+            error_str = f"'actual_output' cannot be empty for the '{metric.__name__}' metric"
+            metric.error = error_str
+            raise MissingTestCaseParamsError(error_str)
+
     missing_params = []
     for param in test_case_params:
         if getattr(test_case, param.value) is None:
@@ -14,6 +14,7 @@ from deepeval.models.retry_policy import (
     sdk_retries_for,
 )
 from deepeval.test_case import MLLMImage
+from deepeval.errors import DeepEvalError
 from deepeval.utils import check_if_multimodal, convert_to_multi_modal_array
 from deepeval.models import DeepEvalBaseLLM
 from deepeval.models.llms.constants import BEDROCK_MODELS_DATA
@@ -155,27 +156,28 @@ class AmazonBedrockModel(DeepEvalBaseLLM):

     def generate(
         self, prompt: str, schema: Optional[BaseModel] = None
-    ) -> Tuple[Union[str, BaseModel], float]:
+    ) -> Tuple[Union[str, BaseModel], Optional[float]]:
         return safe_asyncio_run(self.a_generate(prompt, schema))

     @retry_bedrock
     async def a_generate(
         self, prompt: str, schema: Optional[BaseModel] = None
-    ) -> Tuple[Union[str, BaseModel], float]:
+    ) -> Tuple[Union[str, BaseModel], Optional[float]]:
         if check_if_multimodal(prompt):
             prompt = convert_to_multi_modal_array(input=prompt)
             payload = self.generate_payload(prompt)
         else:
             payload = self.get_converse_request_body(prompt)

-        payload = self.get_converse_request_body(prompt)
         client = await self._ensure_client()
         response = await client.converse(
             modelId=self.get_model_name(),
             messages=payload["messages"],
             inferenceConfig=payload["inferenceConfig"],
         )
-        message = response["output"]["message"]["content"][0]["text"]
+
+        message = self._extract_text_from_converse_response(response)
+
         cost = self.calculate_cost(
             response["usage"]["inputTokens"],
             response["usage"]["outputTokens"],
@@ -206,7 +208,7 @@ class AmazonBedrockModel(DeepEvalBaseLLM):
             try:
                 image_raw_bytes = base64.b64decode(element.dataBase64)
             except Exception:
-                raise ValueError(
+                raise DeepEvalError(
                     f"Invalid base64 data in MLLMImage: {element._id}"
                 )

@@ -294,6 +296,46 @@ class AmazonBedrockModel(DeepEvalBaseLLM):
     # Helpers
     ###############################################

+    @staticmethod
+    def _extract_text_from_converse_response(response: dict) -> str:
+        try:
+            content = response["output"]["message"]["content"]
+        except Exception as e:
+            raise DeepEvalError(
+                "Missing output.message.content in Bedrock response"
+            ) from e
+
+        # Collect any text blocks (ignore reasoning/tool blocks)
+        text_parts = []
+        for block in content:
+            if isinstance(block, dict) and "text" in block:
+                v = block.get("text")
+                if isinstance(v, str) and v.strip():
+                    text_parts.append(v)
+
+        if text_parts:
+            # join in case there are multiple text blocks
+            return "\n".join(text_parts)
+
+        # No text blocks present; raise an actionable error
+        keys = []
+        for b in content:
+            if isinstance(b, dict):
+                keys.append(list(b.keys()))
+            else:
+                keys.append(type(b).__name__)
+
+        stop_reason = (
+            response.get("stopReason")
+            or response.get("output", {}).get("stopReason")
+            or response.get("output", {}).get("message", {}).get("stopReason")
+        )
+
+        raise DeepEvalError(
+            f"Bedrock response contained no text content blocks. "
+            f"content keys={keys}, stopReason={stop_reason}"
+        )
+
     def get_converse_request_body(self, prompt: str) -> dict:

         return {
@@ -303,11 +345,14 @@ class AmazonBedrockModel(DeepEvalBaseLLM):
             },
         }

-    def calculate_cost(self, input_tokens: int, output_tokens: int) -> float:
+    def calculate_cost(
+        self, input_tokens: int, output_tokens: int
+    ) -> Optional[float]:
         if self.model_data.input_price and self.model_data.output_price:
             input_cost = input_tokens * self.model_data.input_price
             output_cost = output_tokens * self.model_data.output_price
             return input_cost + output_cost
+        return None

     def load_model(self):
         pass
@@ -1,6 +1,6 @@
 from openai.types.chat.chat_completion import ChatCompletion
 from openai import AzureOpenAI, AsyncAzureOpenAI
-from typing import Optional, Tuple, Union, Dict, List
+from typing import Optional, Tuple, Union, Dict, List, Callable, Awaitable
 from pydantic import BaseModel, SecretStr

 from deepeval.errors import DeepEvalError
@@ -42,6 +42,10 @@ class AzureOpenAIModel(DeepEvalBaseLLM):
         model: Optional[str] = None,
         api_key: Optional[str] = None,
         base_url: Optional[str] = None,
+        azure_ad_token_provider: Optional[
+            Callable[[], "str | Awaitable[str]"]
+        ] = None,
+        azure_ad_token: Optional[str] = None,
         temperature: Optional[float] = None,
         cost_per_input_token: Optional[float] = None,
         cost_per_output_token: Optional[float] = None,
@@ -67,12 +71,19 @@ class AzureOpenAIModel(DeepEvalBaseLLM):
         model = model or settings.AZURE_MODEL_NAME
         deployment_name = deployment_name or settings.AZURE_DEPLOYMENT_NAME

+        self.azure_ad_token_provider = azure_ad_token_provider
+
         if api_key is not None:
             # keep it secret, keep it safe from serializings, logging and alike
             self.api_key: Optional[SecretStr] = SecretStr(api_key)
         else:
             self.api_key = settings.AZURE_OPENAI_API_KEY

+        if azure_ad_token is not None:
+            self.azure_ad_token = azure_ad_token
+        else:
+            self.azure_ad_token = settings.AZURE_OPENAI_AD_TOKEN
+
         api_version = api_version or settings.OPENAI_API_VERSION
         if base_url is not None:
             base_url = str(base_url).rstrip("/")
@@ -431,18 +442,33 @@ class AzureOpenAIModel(DeepEvalBaseLLM):
         return kwargs

     def _build_client(self, cls):
-        api_key = require_secret_api_key(
-            self.api_key,
-            provider_label="AzureOpenAI",
-            env_var_name="AZURE_OPENAI_API_KEY",
-            param_hint="`api_key` to AzureOpenAIModel(...)",
-        )
+        # Only require the API key / Azure ad token if no token provider is supplied
+        azure_ad_token = None
+        api_key = None
+
+        if self.azure_ad_token_provider is None:
+            if self.azure_ad_token is not None:
+                azure_ad_token = require_secret_api_key(
+                    self.azure_ad_token,
+                    provider_label="AzureOpenAI",
+                    env_var_name="AZURE_OPENAI_AD_TOKEN",
+                    param_hint="`azure_ad_token` to AzureOpenAIModel(...)",
+                )
+            else:
+                api_key = require_secret_api_key(
+                    self.api_key,
+                    provider_label="AzureOpenAI",
+                    env_var_name="AZURE_OPENAI_API_KEY",
+                    param_hint="`api_key` to AzureOpenAIModel(...)",
+                )

         kw = dict(
             api_key=api_key,
             api_version=self.api_version,
             azure_endpoint=self.base_url,
             azure_deployment=self.deployment_name,
+            azure_ad_token_provider=self.azure_ad_token_provider,
+            azure_ad_token=azure_ad_token,
             **self._client_kwargs(),
         )
         try:
@@ -65,6 +65,7 @@ class GeminiModel(DeepEvalBaseLLM):
         project: Optional[str] = None,
         location: Optional[str] = None,
         service_account_key: Optional[Union[str, Dict[str, str]]] = None,
+        use_vertexai: Optional[bool] = None,
         generation_kwargs: Optional[Dict] = None,
         **kwargs,
     ):
@@ -93,7 +94,11 @@ class GeminiModel(DeepEvalBaseLLM):
             location if location is not None else settings.GOOGLE_CLOUD_LOCATION
         )
         self.location = str(location).strip() if location is not None else None
-        self.use_vertexai = settings.GOOGLE_GENAI_USE_VERTEXAI
+        self.use_vertexai = (
+            use_vertexai
+            if use_vertexai is not None
+            else settings.GOOGLE_GENAI_USE_VERTEXAI
+        )

         self.service_account_key: Optional[SecretStr] = None
         if service_account_key is None: