azure-ai-evaluation 1.0.0b3__py3-none-any.whl → 1.0.0b5__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of azure-ai-evaluation might be problematic. Click here for more details.

Files changed (93)
  1. azure/ai/evaluation/__init__.py +23 -1
  2. azure/ai/evaluation/{simulator/_helpers → _common}/_experimental.py +20 -9
  3. azure/ai/evaluation/_common/constants.py +9 -2
  4. azure/ai/evaluation/_common/math.py +29 -0
  5. azure/ai/evaluation/_common/rai_service.py +222 -93
  6. azure/ai/evaluation/_common/utils.py +328 -19
  7. azure/ai/evaluation/_constants.py +16 -8
  8. azure/ai/evaluation/_evaluate/{_batch_run_client → _batch_run}/__init__.py +3 -2
  9. azure/ai/evaluation/_evaluate/{_batch_run_client → _batch_run}/code_client.py +33 -17
  10. azure/ai/evaluation/_evaluate/{_batch_run_client/batch_run_context.py → _batch_run/eval_run_context.py} +14 -7
  11. azure/ai/evaluation/_evaluate/{_batch_run_client → _batch_run}/proxy_client.py +22 -4
  12. azure/ai/evaluation/_evaluate/_batch_run/target_run_context.py +35 -0
  13. azure/ai/evaluation/_evaluate/_eval_run.py +47 -14
  14. azure/ai/evaluation/_evaluate/_evaluate.py +370 -188
  15. azure/ai/evaluation/_evaluate/_telemetry/__init__.py +15 -16
  16. azure/ai/evaluation/_evaluate/_utils.py +77 -25
  17. azure/ai/evaluation/_evaluators/_bleu/_bleu.py +1 -1
  18. azure/ai/evaluation/_evaluators/_coherence/_coherence.py +16 -10
  19. azure/ai/evaluation/_evaluators/_coherence/coherence.prompty +76 -34
  20. azure/ai/evaluation/_evaluators/_common/_base_eval.py +76 -46
  21. azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py +26 -19
  22. azure/ai/evaluation/_evaluators/_common/_base_rai_svc_eval.py +62 -25
  23. azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py +68 -36
  24. azure/ai/evaluation/_evaluators/_content_safety/_content_safety_chat.py +67 -46
  25. azure/ai/evaluation/_evaluators/_content_safety/_hate_unfairness.py +33 -4
  26. azure/ai/evaluation/_evaluators/_content_safety/_self_harm.py +33 -4
  27. azure/ai/evaluation/_evaluators/_content_safety/_sexual.py +33 -4
  28. azure/ai/evaluation/_evaluators/_content_safety/_violence.py +33 -4
  29. azure/ai/evaluation/_evaluators/_eci/_eci.py +7 -5
  30. azure/ai/evaluation/_evaluators/_f1_score/_f1_score.py +14 -6
  31. azure/ai/evaluation/_evaluators/_fluency/_fluency.py +22 -21
  32. azure/ai/evaluation/_evaluators/_fluency/fluency.prompty +66 -36
  33. azure/ai/evaluation/_evaluators/_gleu/_gleu.py +1 -1
  34. azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py +51 -16
  35. azure/ai/evaluation/_evaluators/_groundedness/groundedness_with_query.prompty +113 -0
  36. azure/ai/evaluation/_evaluators/_groundedness/groundedness_without_query.prompty +99 -0
  37. azure/ai/evaluation/_evaluators/_meteor/_meteor.py +3 -7
  38. azure/ai/evaluation/_evaluators/_multimodal/__init__.py +20 -0
  39. azure/ai/evaluation/_evaluators/_multimodal/_content_safety_multimodal.py +130 -0
  40. azure/ai/evaluation/_evaluators/_multimodal/_content_safety_multimodal_base.py +57 -0
  41. azure/ai/evaluation/_evaluators/_multimodal/_hate_unfairness.py +96 -0
  42. azure/ai/evaluation/_evaluators/_multimodal/_protected_material.py +120 -0
  43. azure/ai/evaluation/_evaluators/_multimodal/_self_harm.py +96 -0
  44. azure/ai/evaluation/_evaluators/_multimodal/_sexual.py +96 -0
  45. azure/ai/evaluation/_evaluators/_multimodal/_violence.py +96 -0
  46. azure/ai/evaluation/_evaluators/_protected_material/_protected_material.py +46 -13
  47. azure/ai/evaluation/_evaluators/_qa/_qa.py +11 -6
  48. azure/ai/evaluation/_evaluators/_relevance/_relevance.py +23 -20
  49. azure/ai/evaluation/_evaluators/_relevance/relevance.prompty +78 -42
  50. azure/ai/evaluation/_evaluators/_retrieval/_retrieval.py +126 -80
  51. azure/ai/evaluation/_evaluators/_retrieval/retrieval.prompty +74 -24
  52. azure/ai/evaluation/_evaluators/_rouge/_rouge.py +2 -2
  53. azure/ai/evaluation/_evaluators/_service_groundedness/__init__.py +9 -0
  54. azure/ai/evaluation/_evaluators/_service_groundedness/_service_groundedness.py +150 -0
  55. azure/ai/evaluation/_evaluators/_similarity/_similarity.py +32 -15
  56. azure/ai/evaluation/_evaluators/_xpia/xpia.py +36 -10
  57. azure/ai/evaluation/_exceptions.py +26 -6
  58. azure/ai/evaluation/_http_utils.py +203 -132
  59. azure/ai/evaluation/_model_configurations.py +23 -6
  60. azure/ai/evaluation/_vendor/__init__.py +3 -0
  61. azure/ai/evaluation/_vendor/rouge_score/__init__.py +14 -0
  62. azure/ai/evaluation/_vendor/rouge_score/rouge_scorer.py +328 -0
  63. azure/ai/evaluation/_vendor/rouge_score/scoring.py +63 -0
  64. azure/ai/evaluation/_vendor/rouge_score/tokenize.py +63 -0
  65. azure/ai/evaluation/_vendor/rouge_score/tokenizers.py +53 -0
  66. azure/ai/evaluation/_version.py +1 -1
  67. azure/ai/evaluation/simulator/__init__.py +2 -1
  68. azure/ai/evaluation/simulator/_adversarial_scenario.py +5 -0
  69. azure/ai/evaluation/simulator/_adversarial_simulator.py +88 -60
  70. azure/ai/evaluation/simulator/_conversation/__init__.py +13 -12
  71. azure/ai/evaluation/simulator/_conversation/_conversation.py +4 -4
  72. azure/ai/evaluation/simulator/_data_sources/__init__.py +3 -0
  73. azure/ai/evaluation/simulator/_data_sources/grounding.json +1150 -0
  74. azure/ai/evaluation/simulator/_direct_attack_simulator.py +24 -66
  75. azure/ai/evaluation/simulator/_helpers/__init__.py +1 -2
  76. azure/ai/evaluation/simulator/_helpers/_simulator_data_classes.py +26 -5
  77. azure/ai/evaluation/simulator/_indirect_attack_simulator.py +98 -95
  78. azure/ai/evaluation/simulator/_model_tools/_identity_manager.py +67 -21
  79. azure/ai/evaluation/simulator/_model_tools/_proxy_completion_model.py +28 -11
  80. azure/ai/evaluation/simulator/_model_tools/_template_handler.py +68 -24
  81. azure/ai/evaluation/simulator/_model_tools/models.py +10 -10
  82. azure/ai/evaluation/simulator/_prompty/task_query_response.prompty +4 -9
  83. azure/ai/evaluation/simulator/_prompty/task_simulate.prompty +6 -5
  84. azure/ai/evaluation/simulator/_simulator.py +222 -169
  85. azure/ai/evaluation/simulator/_tracing.py +4 -4
  86. azure/ai/evaluation/simulator/_utils.py +6 -6
  87. {azure_ai_evaluation-1.0.0b3.dist-info → azure_ai_evaluation-1.0.0b5.dist-info}/METADATA +237 -52
  88. azure_ai_evaluation-1.0.0b5.dist-info/NOTICE.txt +70 -0
  89. azure_ai_evaluation-1.0.0b5.dist-info/RECORD +120 -0
  90. {azure_ai_evaluation-1.0.0b3.dist-info → azure_ai_evaluation-1.0.0b5.dist-info}/WHEEL +1 -1
  91. azure/ai/evaluation/_evaluators/_groundedness/groundedness.prompty +0 -49
  92. azure_ai_evaluation-1.0.0b3.dist-info/RECORD +0 -98
  93. {azure_ai_evaluation-1.0.0b3.dist-info → azure_ai_evaluation-1.0.0b5.dist-info}/top_level.txt +0 -0
@@ -6,20 +6,21 @@
6
6
  import asyncio
7
7
  import logging
8
8
  import random
9
- from typing import Any, Callable, Dict, List, Optional
9
+ from typing import Any, Callable, Dict, List, Literal, Optional, Union, cast
10
10
 
11
11
  from tqdm import tqdm
12
12
 
13
+ from azure.ai.evaluation._common._experimental import experimental
14
+ from azure.ai.evaluation._common.utils import validate_azure_ai_project
13
15
  from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
14
16
  from azure.ai.evaluation._http_utils import get_async_http_client
15
- from azure.ai.evaluation._model_configurations import AzureAIProject
16
17
  from azure.ai.evaluation.simulator import AdversarialScenario
17
18
  from azure.ai.evaluation.simulator._adversarial_scenario import _UnstableAdversarialScenario
19
+ from azure.core.credentials import TokenCredential
18
20
  from azure.core.pipeline.policies import AsyncRetryPolicy, RetryMode
19
- from azure.identity import DefaultAzureCredential
20
21
 
21
22
  from ._constants import SupportedLanguages
22
- from ._conversation import CallbackConversationBot, ConversationBot, ConversationRole
23
+ from ._conversation import CallbackConversationBot, ConversationBot, ConversationRole, ConversationTurn
23
24
  from ._conversation._conversation import simulate_conversation
24
25
  from ._model_tools import (
25
26
  AdversarialTemplateHandler,
@@ -28,11 +29,13 @@ from ._model_tools import (
28
29
  RAIClient,
29
30
  TokenScope,
30
31
  )
32
+ from ._model_tools._template_handler import AdversarialTemplate, TemplateParameters
31
33
  from ._utils import JsonLineList
32
34
 
33
35
  logger = logging.getLogger(__name__)
34
36
 
35
37
 
38
+ @experimental
36
39
  class AdversarialSimulator:
37
40
  """
38
41
  Initializes the adversarial simulator with a project scope.
@@ -44,41 +47,28 @@ class AdversarialSimulator:
44
47
  :type credential: ~azure.core.credentials.TokenCredential
45
48
  """
46
49
 
47
- def __init__(self, *, azure_ai_project: AzureAIProject, credential=None):
50
+ def __init__(self, *, azure_ai_project: dict, credential):
48
51
  """Constructor."""
49
- # check if azure_ai_project has the keys: subscription_id, resource_group_name and project_name
50
- if not all(key in azure_ai_project for key in ["subscription_id", "resource_group_name", "project_name"]):
51
- msg = "azure_ai_project must contain keys: subscription_id, resource_group_name, project_name"
52
+
53
+ try:
54
+ self.azure_ai_project = validate_azure_ai_project(azure_ai_project)
55
+ except EvaluationException as e:
52
56
  raise EvaluationException(
53
- message=msg,
54
- internal_message=msg,
57
+ message=e.message,
58
+ internal_message=e.internal_message,
55
59
  target=ErrorTarget.ADVERSARIAL_SIMULATOR,
56
- category=ErrorCategory.MISSING_FIELD,
57
- blame=ErrorBlame.USER_ERROR,
58
- )
59
- # check the value of the keys in azure_ai_project is not none
60
- if not all(azure_ai_project[key] for key in ["subscription_id", "resource_group_name", "project_name"]):
61
- msg = "subscription_id, resource_group_name and project_name cannot be None"
62
- raise EvaluationException(
63
- message=msg,
64
- internal_message=msg,
65
- target=ErrorTarget.ADVERSARIAL_SIMULATOR,
66
- category=ErrorCategory.MISSING_FIELD,
67
- blame=ErrorBlame.USER_ERROR,
68
- )
69
- if "credential" not in azure_ai_project and not credential:
70
- credential = DefaultAzureCredential()
71
- elif "credential" in azure_ai_project:
72
- credential = azure_ai_project["credential"]
73
- self.azure_ai_project = azure_ai_project
60
+ category=e.category,
61
+ blame=e.blame,
62
+ ) from e
63
+
74
64
  self.token_manager = ManagedIdentityAPITokenManager(
75
65
  token_scope=TokenScope.DEFAULT_AZURE_MANAGEMENT,
76
66
  logger=logging.getLogger("AdversarialSimulator"),
77
- credential=credential,
67
+ credential=cast(TokenCredential, credential),
78
68
  )
79
- self.rai_client = RAIClient(azure_ai_project=azure_ai_project, token_manager=self.token_manager)
69
+ self.rai_client = RAIClient(azure_ai_project=self.azure_ai_project, token_manager=self.token_manager)
80
70
  self.adversarial_template_handler = AdversarialTemplateHandler(
81
- azure_ai_project=azure_ai_project, rai_client=self.rai_client
71
+ azure_ai_project=self.azure_ai_project, rai_client=self.rai_client
82
72
  )
83
73
 
84
74
  def _ensure_service_dependencies(self):
@@ -92,7 +82,7 @@ class AdversarialSimulator:
92
82
  blame=ErrorBlame.USER_ERROR,
93
83
  )
94
84
 
95
- # @monitor_adversarial_scenario
85
+ # pylint: disable=too-many-locals
96
86
  async def __call__(
97
87
  self,
98
88
  *,
@@ -106,10 +96,10 @@ class AdversarialSimulator:
106
96
  api_call_retry_sleep_sec: int = 1,
107
97
  api_call_delay_sec: int = 0,
108
98
  concurrent_async_task: int = 3,
109
- _jailbreak_type: Optional[str] = None,
110
99
  language: SupportedLanguages = SupportedLanguages.English,
111
100
  randomize_order: bool = True,
112
101
  randomization_seed: Optional[int] = None,
102
+ **kwargs,
113
103
  ):
114
104
  """
115
105
  Executes the adversarial simulation against a specified target function asynchronously.
@@ -216,6 +206,7 @@ class AdversarialSimulator:
216
206
  total_tasks,
217
207
  )
218
208
  total_tasks = min(total_tasks, max_simulation_results)
209
+ _jailbreak_type = kwargs.get("_jailbreak_type", None)
219
210
  if _jailbreak_type:
220
211
  jailbreak_dataset = await self.rai_client.get_jailbreaks_dataset(type=_jailbreak_type)
221
212
  progress_bar = tqdm(
@@ -263,16 +254,21 @@ class AdversarialSimulator:
263
254
 
264
255
  return JsonLineList(sim_results)
265
256
 
266
- def _to_chat_protocol(self, *, conversation_history, template_parameters: Dict = None):
257
+ def _to_chat_protocol(
258
+ self,
259
+ *,
260
+ conversation_history: List[ConversationTurn],
261
+ template_parameters: Optional[Dict[str, Union[str, Dict[str, str]]]] = None,
262
+ ):
267
263
  if template_parameters is None:
268
264
  template_parameters = {}
269
265
  messages = []
270
266
  for _, m in enumerate(conversation_history):
271
267
  message = {"content": m.message, "role": m.role.value}
272
- if "context" in m.full_response:
268
+ if m.full_response is not None and "context" in m.full_response:
273
269
  message["context"] = m.full_response["context"]
274
270
  messages.append(message)
275
- conversation_category = template_parameters.pop("metadata", {}).get("Category")
271
+ conversation_category = cast(Dict[str, str], template_parameters.pop("metadata", {})).get("Category")
276
272
  template_parameters["metadata"] = {}
277
273
  for key in (
278
274
  "conversation_starter",
@@ -280,6 +276,9 @@ class AdversarialSimulator:
280
276
  "target_population",
281
277
  "topic",
282
278
  "ch_template_placeholder",
279
+ "chatbot_name",
280
+ "name",
281
+ "group",
283
282
  ):
284
283
  template_parameters.pop(key, None)
285
284
  if conversation_category:
@@ -294,14 +293,14 @@ class AdversarialSimulator:
294
293
  self,
295
294
  *,
296
295
  target: Callable,
297
- template,
298
- parameters,
299
- max_conversation_turns,
300
- api_call_retry_limit,
301
- api_call_retry_sleep_sec,
302
- api_call_delay_sec,
303
- language,
304
- semaphore,
296
+ template: AdversarialTemplate,
297
+ parameters: TemplateParameters,
298
+ max_conversation_turns: int,
299
+ api_call_retry_limit: int,
300
+ api_call_retry_sleep_sec: int,
301
+ api_call_delay_sec: int,
302
+ language: SupportedLanguages,
303
+ semaphore: asyncio.Semaphore,
305
304
  ) -> List[Dict]:
306
305
  user_bot = self._setup_bot(role=ConversationRole.USER, template=template, parameters=parameters)
307
306
  system_bot = self._setup_bot(
@@ -324,9 +323,15 @@ class AdversarialSimulator:
324
323
  api_call_delay_sec=api_call_delay_sec,
325
324
  language=language,
326
325
  )
327
- return self._to_chat_protocol(conversation_history=conversation_history, template_parameters=parameters)
328
326
 
329
- def _get_user_proxy_completion_model(self, template_key, template_parameters):
327
+ return self._to_chat_protocol(
328
+ conversation_history=conversation_history,
329
+ template_parameters=cast(Dict[str, Union[str, Dict[str, str]]], parameters),
330
+ )
331
+
332
+ def _get_user_proxy_completion_model(
333
+ self, template_key: str, template_parameters: TemplateParameters
334
+ ) -> ProxyChatCompletionsModel:
330
335
  return ProxyChatCompletionsModel(
331
336
  name="raisvc_proxy_model",
332
337
  template_key=template_key,
@@ -338,8 +343,15 @@ class AdversarialSimulator:
338
343
  temperature=0.0,
339
344
  )
340
345
 
341
- def _setup_bot(self, *, role, template, parameters, target: Callable = None):
342
- if role == ConversationRole.USER:
346
+ def _setup_bot(
347
+ self,
348
+ *,
349
+ role: ConversationRole,
350
+ template: AdversarialTemplate,
351
+ parameters: TemplateParameters,
352
+ target: Optional[Callable] = None,
353
+ ) -> ConversationBot:
354
+ if role is ConversationRole.USER:
343
355
  model = self._get_user_proxy_completion_model(
344
356
  template_key=template.template_name, template_parameters=parameters
345
357
  )
@@ -350,30 +362,46 @@ class AdversarialSimulator:
350
362
  instantiation_parameters=parameters,
351
363
  )
352
364
 
353
- if role == ConversationRole.ASSISTANT:
365
+ if role is ConversationRole.ASSISTANT:
366
+ if target is None:
367
+ msg = "Cannot setup system bot. Target is None"
354
368
 
355
- def dummy_model() -> None:
356
- return None
369
+ raise EvaluationException(
370
+ message=msg,
371
+ internal_message=msg,
372
+ target=ErrorTarget.ADVERSARIAL_SIMULATOR,
373
+ error_category=ErrorCategory.INVALID_VALUE,
374
+ blame=ErrorBlame.SYSTEM_ERROR,
375
+ )
376
+
377
+ class DummyModel:
378
+ def __init__(self):
379
+ self.name = "dummy_model"
380
+
381
+ def __call__(self) -> None:
382
+ pass
357
383
 
358
- dummy_model.name = "dummy_model"
359
384
  return CallbackConversationBot(
360
385
  callback=target,
361
386
  role=role,
362
- model=dummy_model,
387
+ model=DummyModel(),
363
388
  user_template=str(template),
364
389
  user_template_parameters=parameters,
365
390
  conversation_template="",
366
391
  instantiation_parameters={},
367
392
  )
368
- return ConversationBot(
369
- role=role,
370
- model=model,
371
- conversation_template=template,
372
- instantiation_parameters=parameters,
393
+
394
+ msg = "Invalid value for enum ConversationRole. This should never happen."
395
+ raise EvaluationException(
396
+ message=msg,
397
+ internal_message=msg,
398
+ target=ErrorTarget.ADVERSARIAL_SIMULATOR,
399
+ category=ErrorCategory.INVALID_VALUE,
400
+ blame=ErrorBlame.SYSTEM_ERROR,
373
401
  )
374
402
 
375
- def _join_conversation_starter(self, parameters, to_join):
376
- key = "conversation_starter"
403
+ def _join_conversation_starter(self, parameters: TemplateParameters, to_join: str) -> TemplateParameters:
404
+ key: Literal["conversation_starter"] = "conversation_starter"
377
405
  if key in parameters.keys():
378
406
  parameters[key] = f"{to_join} {parameters[key]}"
379
407
  else:
@@ -7,7 +7,7 @@ import copy
7
7
  import logging
8
8
  import time
9
9
  from dataclasses import dataclass
10
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
10
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
11
11
 
12
12
  import jinja2
13
13
 
@@ -15,6 +15,7 @@ from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarg
15
15
  from azure.ai.evaluation._http_utils import AsyncHttpPipeline
16
16
 
17
17
  from .._model_tools import LLMBase, OpenAIChatCompletionsModel
18
+ from .._model_tools._template_handler import TemplateParameters
18
19
  from .constants import ConversationRole
19
20
 
20
21
 
@@ -40,7 +41,7 @@ class ConversationTurn:
40
41
  role: "ConversationRole"
41
42
  name: Optional[str] = None
42
43
  message: str = ""
43
- full_response: Optional[Any] = None
44
+ full_response: Optional[Dict[str, Any]] = None
44
45
  request: Optional[Any] = None
45
46
 
46
47
  def to_openai_chat_format(self, reverse: bool = False) -> Dict[str, str]:
@@ -109,7 +110,7 @@ class ConversationBot:
109
110
  role: ConversationRole,
110
111
  model: Union[LLMBase, OpenAIChatCompletionsModel],
111
112
  conversation_template: str,
112
- instantiation_parameters: Dict[str, str],
113
+ instantiation_parameters: TemplateParameters,
113
114
  ) -> None:
114
115
  self.role = role
115
116
  self.conversation_template_orig = conversation_template
@@ -118,13 +119,13 @@ class ConversationBot:
118
119
  )
119
120
  self.persona_template_args = instantiation_parameters
120
121
  if self.role == ConversationRole.USER:
121
- self.name = self.persona_template_args.get("name", role.value)
122
+ self.name: str = cast(str, self.persona_template_args.get("name", role.value))
122
123
  else:
123
- self.name = self.persona_template_args.get("chatbot_name", role.value) or model.name
124
+ self.name = cast(str, self.persona_template_args.get("chatbot_name", role.value)) or model.name
124
125
  self.model = model
125
126
 
126
127
  self.logger = logging.getLogger(repr(self))
127
- self.conversation_starter = None # can either be a dictionary or jinja template
128
+ self.conversation_starter: Optional[Union[str, jinja2.Template, Dict]] = None
128
129
  if role == ConversationRole.USER:
129
130
  if "conversation_starter" in self.persona_template_args:
130
131
  conversation_starter_content = self.persona_template_args["conversation_starter"]
@@ -148,7 +149,7 @@ class ConversationBot:
148
149
  conversation_history: List[ConversationTurn],
149
150
  max_history: int,
150
151
  turn_number: int = 0,
151
- ) -> Tuple[dict, dict, int, dict]:
152
+ ) -> Tuple[dict, dict, float, dict]:
152
153
  """
153
154
  Prompt the ConversationBot for a response.
154
155
 
@@ -161,7 +162,7 @@ class ConversationBot:
161
162
  :param turn_number: Parameters used to query GPT-4 model.
162
163
  :type turn_number: int
163
164
  :return: The response from the ConversationBot.
164
- :rtype: Tuple[dict, dict, int, dict]
165
+ :rtype: Tuple[dict, dict, float, dict]
165
166
  """
166
167
 
167
168
  # check if this is the first turn and the conversation_starter is not None,
@@ -169,11 +170,11 @@ class ConversationBot:
169
170
  if turn_number == 0 and self.conversation_starter is not None:
170
171
  # if conversation_starter is a dictionary, pass it into samples as is
171
172
  if isinstance(self.conversation_starter, dict):
172
- samples = [self.conversation_starter]
173
+ samples: List[Union[str, jinja2.Template, Dict]] = [self.conversation_starter]
173
174
  if isinstance(self.conversation_starter, jinja2.Template):
174
175
  samples = [self.conversation_starter.render(**self.persona_template_args)]
175
176
  else:
176
- samples = [self.conversation_starter] # type: ignore[attr-defined]
177
+ samples = [self.conversation_starter]
177
178
  time_taken = 0
178
179
 
179
180
  finish_reason = ["stop"]
@@ -238,7 +239,7 @@ class CallbackConversationBot(ConversationBot):
238
239
  self,
239
240
  callback: Callable,
240
241
  user_template: str,
241
- user_template_parameters: Dict,
242
+ user_template_parameters: TemplateParameters,
242
243
  *args,
243
244
  **kwargs,
244
245
  ) -> None:
@@ -254,7 +255,7 @@ class CallbackConversationBot(ConversationBot):
254
255
  conversation_history: List[Any],
255
256
  max_history: int,
256
257
  turn_number: int = 0,
257
- ) -> Tuple[dict, dict, int, dict]:
258
+ ) -> Tuple[dict, dict, float, dict]:
258
259
  chat_protocol_message = self._to_chat_protocol(
259
260
  self.user_template, conversation_history, self.user_template_parameters
260
261
  )
@@ -4,7 +4,7 @@
4
4
 
5
5
  import asyncio
6
6
  import logging
7
- from typing import Callable, Dict, List, Tuple, Union
7
+ from typing import Callable, Dict, List, Optional, Tuple, Union
8
8
 
9
9
  from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
10
10
  from azure.ai.evaluation.simulator._constants import SupportedLanguages
@@ -80,7 +80,7 @@ async def simulate_conversation(
80
80
  history_limit: int = 5,
81
81
  api_call_delay_sec: float = 0,
82
82
  logger: logging.Logger = logging.getLogger(__name__),
83
- ) -> Tuple:
83
+ ) -> Tuple[Optional[str], List[ConversationTurn]]:
84
84
  """
85
85
  Simulate a conversation between the given bots.
86
86
 
@@ -99,7 +99,7 @@ async def simulate_conversation(
99
99
  :keyword logger: The logger to use for logging. Defaults to the logger named after the current module.
100
100
  :paramtype logger: logging.Logger
101
101
  :return: Simulation a conversation between the given bots.
102
- :rtype: Tuple
102
+ :rtype: Tuple[Optional[str], List[ConversationTurn]]
103
103
  """
104
104
 
105
105
  # Read the first prompt.
@@ -110,7 +110,7 @@ async def simulate_conversation(
110
110
  turn_number=0,
111
111
  )
112
112
  if "id" in first_response:
113
- conversation_id = first_response["id"]
113
+ conversation_id: Optional[str] = first_response["id"]
114
114
  else:
115
115
  conversation_id = None
116
116
  first_prompt = first_response["samples"][0]
@@ -0,0 +1,3 @@
1
+ # ---------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # ---------------------------------------------------------