azure-ai-evaluation 1.0.0b4__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. azure/ai/evaluation/__init__.py +22 -0
  2. azure/ai/evaluation/{simulator/_helpers → _common}/_experimental.py +4 -0
  3. azure/ai/evaluation/_common/constants.py +5 -0
  4. azure/ai/evaluation/_common/math.py +73 -2
  5. azure/ai/evaluation/_common/rai_service.py +250 -62
  6. azure/ai/evaluation/_common/utils.py +196 -23
  7. azure/ai/evaluation/_constants.py +7 -6
  8. azure/ai/evaluation/_evaluate/{_batch_run_client → _batch_run}/__init__.py +3 -2
  9. azure/ai/evaluation/_evaluate/{_batch_run_client/batch_run_context.py → _batch_run/eval_run_context.py} +13 -4
  10. azure/ai/evaluation/_evaluate/{_batch_run_client → _batch_run}/proxy_client.py +19 -6
  11. azure/ai/evaluation/_evaluate/_batch_run/target_run_context.py +46 -0
  12. azure/ai/evaluation/_evaluate/_eval_run.py +55 -14
  13. azure/ai/evaluation/_evaluate/_evaluate.py +312 -228
  14. azure/ai/evaluation/_evaluate/_telemetry/__init__.py +7 -6
  15. azure/ai/evaluation/_evaluate/_utils.py +46 -11
  16. azure/ai/evaluation/_evaluators/_bleu/_bleu.py +17 -18
  17. azure/ai/evaluation/_evaluators/_coherence/_coherence.py +67 -31
  18. azure/ai/evaluation/_evaluators/_coherence/coherence.prompty +76 -34
  19. azure/ai/evaluation/_evaluators/_common/_base_eval.py +37 -24
  20. azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py +21 -9
  21. azure/ai/evaluation/_evaluators/_common/_base_rai_svc_eval.py +52 -16
  22. azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py +91 -48
  23. azure/ai/evaluation/_evaluators/_content_safety/_hate_unfairness.py +100 -26
  24. azure/ai/evaluation/_evaluators/_content_safety/_self_harm.py +94 -26
  25. azure/ai/evaluation/_evaluators/_content_safety/_sexual.py +96 -26
  26. azure/ai/evaluation/_evaluators/_content_safety/_violence.py +97 -26
  27. azure/ai/evaluation/_evaluators/_eci/_eci.py +31 -4
  28. azure/ai/evaluation/_evaluators/_f1_score/_f1_score.py +20 -13
  29. azure/ai/evaluation/_evaluators/_fluency/_fluency.py +67 -36
  30. azure/ai/evaluation/_evaluators/_fluency/fluency.prompty +66 -36
  31. azure/ai/evaluation/_evaluators/_gleu/_gleu.py +14 -16
  32. azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py +106 -34
  33. azure/ai/evaluation/_evaluators/_groundedness/groundedness_with_query.prompty +113 -0
  34. azure/ai/evaluation/_evaluators/_groundedness/groundedness_without_query.prompty +99 -0
  35. azure/ai/evaluation/_evaluators/_meteor/_meteor.py +20 -27
  36. azure/ai/evaluation/_evaluators/_multimodal/__init__.py +20 -0
  37. azure/ai/evaluation/_evaluators/_multimodal/_content_safety_multimodal.py +132 -0
  38. azure/ai/evaluation/_evaluators/_multimodal/_content_safety_multimodal_base.py +55 -0
  39. azure/ai/evaluation/_evaluators/_multimodal/_hate_unfairness.py +100 -0
  40. azure/ai/evaluation/_evaluators/_multimodal/_protected_material.py +124 -0
  41. azure/ai/evaluation/_evaluators/_multimodal/_self_harm.py +100 -0
  42. azure/ai/evaluation/_evaluators/_multimodal/_sexual.py +100 -0
  43. azure/ai/evaluation/_evaluators/_multimodal/_violence.py +100 -0
  44. azure/ai/evaluation/_evaluators/_protected_material/_protected_material.py +87 -31
  45. azure/ai/evaluation/_evaluators/_qa/_qa.py +23 -31
  46. azure/ai/evaluation/_evaluators/_relevance/_relevance.py +72 -36
  47. azure/ai/evaluation/_evaluators/_relevance/relevance.prompty +78 -42
  48. azure/ai/evaluation/_evaluators/_retrieval/_retrieval.py +83 -125
  49. azure/ai/evaluation/_evaluators/_retrieval/retrieval.prompty +74 -24
  50. azure/ai/evaluation/_evaluators/_rouge/_rouge.py +26 -27
  51. azure/ai/evaluation/_evaluators/_service_groundedness/__init__.py +9 -0
  52. azure/ai/evaluation/_evaluators/_service_groundedness/_service_groundedness.py +148 -0
  53. azure/ai/evaluation/_evaluators/_similarity/_similarity.py +37 -28
  54. azure/ai/evaluation/_evaluators/_xpia/xpia.py +94 -33
  55. azure/ai/evaluation/_exceptions.py +19 -0
  56. azure/ai/evaluation/_model_configurations.py +83 -15
  57. azure/ai/evaluation/_version.py +1 -1
  58. azure/ai/evaluation/simulator/__init__.py +2 -1
  59. azure/ai/evaluation/simulator/_adversarial_scenario.py +20 -1
  60. azure/ai/evaluation/simulator/_adversarial_simulator.py +29 -35
  61. azure/ai/evaluation/simulator/_constants.py +11 -1
  62. azure/ai/evaluation/simulator/_data_sources/__init__.py +3 -0
  63. azure/ai/evaluation/simulator/_data_sources/grounding.json +1150 -0
  64. azure/ai/evaluation/simulator/_direct_attack_simulator.py +17 -9
  65. azure/ai/evaluation/simulator/_helpers/__init__.py +1 -2
  66. azure/ai/evaluation/simulator/_helpers/_simulator_data_classes.py +22 -1
  67. azure/ai/evaluation/simulator/_indirect_attack_simulator.py +90 -35
  68. azure/ai/evaluation/simulator/_model_tools/_identity_manager.py +4 -2
  69. azure/ai/evaluation/simulator/_model_tools/_rai_client.py +8 -4
  70. azure/ai/evaluation/simulator/_prompty/task_query_response.prompty +4 -4
  71. azure/ai/evaluation/simulator/_prompty/task_simulate.prompty +6 -1
  72. azure/ai/evaluation/simulator/_simulator.py +165 -105
  73. azure/ai/evaluation/simulator/_utils.py +31 -13
  74. azure_ai_evaluation-1.0.1.dist-info/METADATA +600 -0
  75. {azure_ai_evaluation-1.0.0b4.dist-info → azure_ai_evaluation-1.0.1.dist-info}/NOTICE.txt +20 -0
  76. azure_ai_evaluation-1.0.1.dist-info/RECORD +119 -0
  77. {azure_ai_evaluation-1.0.0b4.dist-info → azure_ai_evaluation-1.0.1.dist-info}/WHEEL +1 -1
  78. azure/ai/evaluation/_evaluators/_content_safety/_content_safety_chat.py +0 -322
  79. azure/ai/evaluation/_evaluators/_groundedness/groundedness.prompty +0 -49
  80. azure_ai_evaluation-1.0.0b4.dist-info/METADATA +0 -535
  81. azure_ai_evaluation-1.0.0b4.dist-info/RECORD +0 -106
  82. azure/ai/evaluation/_evaluate/{_batch_run_client → _batch_run}/code_client.py +0 -0
  83. {azure_ai_evaluation-1.0.0b4.dist-info → azure_ai_evaluation-1.0.1.dist-info}/top_level.txt +0 -0
@@ -5,20 +5,23 @@
5
5
  # ---------------------------------------------------------
6
6
  import asyncio
7
7
  import importlib.resources as pkg_resources
8
- from tqdm import tqdm
9
8
  import json
10
9
  import os
11
10
  import re
12
11
  import warnings
13
- from typing import Any, Callable, Dict, List, Optional, Union
12
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
13
+
14
14
  from promptflow.core import AsyncPrompty
15
- from azure.ai.evaluation._model_configurations import AzureOpenAIModelConfiguration, OpenAIModelConfiguration
15
+ from tqdm import tqdm
16
+
17
+ from azure.ai.evaluation._common._experimental import experimental
16
18
  from azure.ai.evaluation._common.utils import construct_prompty_model_config
19
+ from azure.ai.evaluation._model_configurations import AzureOpenAIModelConfiguration, OpenAIModelConfiguration
17
20
 
18
21
  from .._exceptions import ErrorBlame, ErrorCategory, EvaluationException
19
22
  from .._user_agent import USER_AGENT
20
23
  from ._conversation.constants import ConversationRole
21
- from ._helpers import ConversationHistory, Turn, experimental
24
+ from ._helpers import ConversationHistory, Turn
22
25
  from ._utils import JsonLineChatProtocol
23
26
 
24
27
 
@@ -26,16 +29,22 @@ from ._utils import JsonLineChatProtocol
26
29
  class Simulator:
27
30
  """
28
31
  Simulator for generating synthetic conversations.
32
+
33
+ :param model_config: A dictionary defining the configuration for the model. Acceptable types are AzureOpenAIModelConfiguration and OpenAIModelConfiguration.
34
+ :type model_config: Union[~azure.ai.evaluation.AzureOpenAIModelConfiguration, ~azure.ai.evaluation.OpenAIModelConfiguration]
35
+ :raises ValueError: If the model_config does not contain the required keys or any value is None.
36
+
37
+ .. admonition:: Example:
38
+
39
+ .. literalinclude:: ../samples/evaluation_samples_simulate.py
40
+ :start-after: [START nonadversarial_simulator]
41
+ :end-before: [END nonadversarial_simulator]
42
+ :language: python
43
+ :dedent: 8
44
+ :caption: Run a Simulator for 2 queries and 4 conversation turns.
29
45
  """
30
46
 
31
47
  def __init__(self, model_config: Union[AzureOpenAIModelConfiguration, OpenAIModelConfiguration]):
32
- """
33
- Initializes the task simulator with the model configuration.
34
-
35
- :param model_config: A dictionary defining the configuration for the model. Acceptable types are AzureOpenAIModelConfiguration and OpenAIModelConfiguration.
36
- :type model_config: Union[~azure.ai.evaluation.AzureOpenAIModelConfiguration, ~azure.ai.evaluation.OpenAIModelConfiguration]
37
- :raises ValueError: If the model_config does not contain the required keys or any value is None.
38
- """
39
48
  self._validate_model_config(model_config)
40
49
  self.model_config = model_config
41
50
  if "api_version" not in self.model_config:
@@ -87,9 +96,10 @@ class Simulator:
87
96
  query_response_generating_prompty: Optional[str] = None,
88
97
  user_simulator_prompty: Optional[str] = None,
89
98
  api_call_delay_sec: float = 1,
90
- query_response_generating_prompty_kwargs: Dict[str, Any] = {},
91
- user_simulator_prompty_kwargs: Dict[str, Any] = {},
92
- conversation_turns: List[List[str]] = [],
99
+ query_response_generating_prompty_options: Dict[str, Any] = {},
100
+ user_simulator_prompty_options: Dict[str, Any] = {},
101
+ conversation_turns: List[List[Union[str, Dict[str, Any]]]] = [],
102
+ concurrent_async_tasks: int = 5,
93
103
  **kwargs,
94
104
  ) -> List[JsonLineChatProtocol]:
95
105
  """
@@ -111,12 +121,15 @@ class Simulator:
111
121
  :paramtype user_simulator_prompty: Optional[str]
112
122
  :keyword api_call_delay_sec: Delay in seconds between API calls.
113
123
  :paramtype api_call_delay_sec: float
114
- :keyword query_response_generating_prompty_kwargs: Additional keyword arguments for the query response generating prompty.
115
- :paramtype query_response_generating_prompty_kwargs: Dict[str, Any]
116
- :keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
117
- :paramtype user_simulator_prompty_kwargs: Dict[str, Any]
124
+ :keyword query_response_generating_prompty_options: Additional keyword arguments for the query response generating prompty.
125
+ :paramtype query_response_generating_prompty_options: Dict[str, Any]
126
+ :keyword user_simulator_prompty_options: Additional keyword arguments for the user simulator prompty.
127
+ :paramtype user_simulator_prompty_options: Dict[str, Any]
118
128
  :keyword conversation_turns: Predefined conversation turns to simulate.
119
- :paramtype conversation_turns: List[List[str]]
129
+ :paramtype conversation_turns: List[List[Union[str, Dict[str, Any]]]]
130
+ :keyword concurrent_async_tasks: The number of asynchronous tasks to run concurrently during the simulation.
131
+ Defaults to 5.
132
+ :paramtype concurrent_async_tasks: int
120
133
  :return: A list of simulated conversations represented as JsonLineChatProtocol objects.
121
134
  :rtype: List[JsonLineChatProtocol]
122
135
 
@@ -131,12 +144,12 @@ class Simulator:
131
144
  if conversation_turns and (text or tasks):
132
145
  raise ValueError("Cannot specify both conversation_turns and text/tasks")
133
146
 
134
- if num_queries > len(tasks):
147
+ if text and num_queries > len(tasks):
135
148
  warnings.warn(
136
149
  f"You have specified 'num_queries' > len('tasks') ({num_queries} > {len(tasks)}). "
137
150
  f"All tasks will be used for generation and the remaining {num_queries - len(tasks)} lines will be simulated in task-free mode"
138
151
  )
139
- elif num_queries < len(tasks):
152
+ elif text and num_queries < len(tasks):
140
153
  warnings.warn(
141
154
  f"You have specified 'num_queries' < len('tasks') ({num_queries} < {len(tasks)}). "
142
155
  f"Only the first {num_queries} lines of the specified tasks will be simulated."
@@ -151,16 +164,17 @@ class Simulator:
151
164
  max_conversation_turns=max_conversation_turns,
152
165
  conversation_turns=conversation_turns,
153
166
  user_simulator_prompty=user_simulator_prompty,
154
- user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
167
+ user_simulator_prompty_options=user_simulator_prompty_options,
155
168
  api_call_delay_sec=api_call_delay_sec,
156
169
  prompty_model_config=prompty_model_config,
170
+ concurrent_async_tasks=concurrent_async_tasks,
157
171
  )
158
172
 
159
173
  query_responses = await self._generate_query_responses(
160
174
  text=text,
161
175
  num_queries=num_queries,
162
176
  query_response_generating_prompty=query_response_generating_prompty,
163
- query_response_generating_prompty_kwargs=query_response_generating_prompty_kwargs,
177
+ query_response_generating_prompty_options=query_response_generating_prompty_options,
164
178
  prompty_model_config=prompty_model_config,
165
179
  **kwargs,
166
180
  )
@@ -169,9 +183,10 @@ class Simulator:
169
183
  max_conversation_turns=max_conversation_turns,
170
184
  tasks=tasks,
171
185
  user_simulator_prompty=user_simulator_prompty,
172
- user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
186
+ user_simulator_prompty_options=user_simulator_prompty_options,
173
187
  target=target,
174
188
  api_call_delay_sec=api_call_delay_sec,
189
+ text=text,
175
190
  )
176
191
 
177
192
  async def _simulate_with_predefined_turns(
@@ -179,11 +194,12 @@ class Simulator:
179
194
  *,
180
195
  target: Callable,
181
196
  max_conversation_turns: int,
182
- conversation_turns: List[List[str]],
197
+ conversation_turns: List[List[Union[str, Dict[str, Any]]]],
183
198
  user_simulator_prompty: Optional[str],
184
- user_simulator_prompty_kwargs: Dict[str, Any],
199
+ user_simulator_prompty_options: Dict[str, Any],
185
200
  api_call_delay_sec: float,
186
201
  prompty_model_config: Any,
202
+ concurrent_async_tasks: int,
187
203
  ) -> List[JsonLineChatProtocol]:
188
204
  """
189
205
  Simulates conversations using predefined conversation turns.
@@ -193,51 +209,69 @@ class Simulator:
193
209
  :keyword max_conversation_turns: Maximum number of turns for the simulation.
194
210
  :paramtype max_conversation_turns: int
195
211
  :keyword conversation_turns: A list of predefined conversation turns.
196
- :paramtype conversation_turns: List[List[str]]
212
+ :paramtype conversation_turns: List[List[Union[str, Dict[str, Any]]]]
197
213
  :keyword user_simulator_prompty: Path to the user simulator prompty file.
198
214
  :paramtype user_simulator_prompty: Optional[str]
199
- :keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
200
- :paramtype user_simulator_prompty_kwargs: Dict[str, Any]
215
+ :keyword user_simulator_prompty_options: Additional keyword arguments for the user simulator prompty.
216
+ :paramtype user_simulator_prompty_options: Dict[str, Any]
201
217
  :keyword api_call_delay_sec: Delay in seconds between API calls.
202
218
  :paramtype api_call_delay_sec: float
203
219
  :keyword prompty_model_config: The configuration for the prompty model.
204
220
  :paramtype prompty_model_config: Any
221
+ :keyword concurrent_async_tasks: The number of asynchronous tasks to run concurrently during the simulation.
222
+ :paramtype concurrent_async_tasks: int
205
223
  :return: A list of simulated conversations represented as JsonLineChatProtocol objects.
206
224
  :rtype: List[JsonLineChatProtocol]
207
225
  """
208
- simulated_conversations = []
209
226
  progress_bar = tqdm(
210
227
  total=int(len(conversation_turns) * (max_conversation_turns / 2)),
211
228
  desc="Simulating with predefined conversation turns: ",
212
229
  ncols=100,
213
230
  unit="messages",
214
231
  )
215
-
216
- for simulation in conversation_turns:
217
- current_simulation = ConversationHistory()
218
- for simulated_turn in simulation:
219
- user_turn = Turn(role=ConversationRole.USER, content=simulated_turn)
220
- current_simulation.add_to_history(user_turn)
221
- assistant_response = await self._get_target_response(
222
- target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=current_simulation
223
- )
224
- assistant_turn = Turn(role=ConversationRole.ASSISTANT, content=assistant_response)
225
- current_simulation.add_to_history(assistant_turn)
226
- progress_bar.update(1) # Update progress bar for both user and assistant turns
227
-
228
- if len(current_simulation) < max_conversation_turns:
229
- await self._extend_conversation_with_simulator(
230
- current_simulation=current_simulation,
231
- max_conversation_turns=max_conversation_turns,
232
- user_simulator_prompty=user_simulator_prompty,
233
- user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
234
- api_call_delay_sec=api_call_delay_sec,
235
- prompty_model_config=prompty_model_config,
236
- target=target,
237
- progress_bar=progress_bar,
238
- )
239
- simulated_conversations.append(
240
- JsonLineChatProtocol(
232
+ semaphore = asyncio.Semaphore(concurrent_async_tasks)
233
+ progress_bar_lock = asyncio.Lock()
234
+
235
+ async def run_simulation(simulation: List[Union[str, Dict[str, Any]]]) -> JsonLineChatProtocol:
236
+ async with semaphore:
237
+ current_simulation = ConversationHistory()
238
+ for simulated_turn in simulation:
239
+ if isinstance(simulated_turn, str):
240
+ user_turn = Turn(role=ConversationRole.USER, content=simulated_turn)
241
+ elif isinstance(simulated_turn, dict):
242
+ user_turn = Turn(
243
+ role=ConversationRole.USER,
244
+ content=str(simulated_turn.get("content")),
245
+ context=str(simulated_turn.get("context")),
246
+ )
247
+ else:
248
+ raise ValueError(
249
+ "Each simulated turn must be a string or a dict with 'content' and 'context' keys"
250
+ )
251
+ current_simulation.add_to_history(user_turn)
252
+ assistant_response, assistant_context = await self._get_target_response(
253
+ target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=current_simulation
254
+ )
255
+ assistant_turn = Turn(
256
+ role=ConversationRole.ASSISTANT, content=assistant_response, context=assistant_context
257
+ )
258
+ current_simulation.add_to_history(assistant_turn)
259
+ async with progress_bar_lock:
260
+ progress_bar.update(1)
261
+
262
+ if len(current_simulation) < max_conversation_turns:
263
+ await self._extend_conversation_with_simulator(
264
+ current_simulation=current_simulation,
265
+ max_conversation_turns=max_conversation_turns,
266
+ user_simulator_prompty=user_simulator_prompty,
267
+ user_simulator_prompty_options=user_simulator_prompty_options,
268
+ api_call_delay_sec=api_call_delay_sec,
269
+ prompty_model_config=prompty_model_config,
270
+ target=target,
271
+ progress_bar=progress_bar,
272
+ progress_bar_lock=progress_bar_lock,
273
+ )
274
+ return JsonLineChatProtocol(
241
275
  {
242
276
  "messages": current_simulation.to_list(),
243
277
  "finish_reason": ["stop"],
@@ -245,10 +279,11 @@ class Simulator:
245
279
  "$schema": "http://azureml/sdk-2-0/ChatConversation.json",
246
280
  }
247
281
  )
248
- )
249
282
 
283
+ tasks = [asyncio.create_task(run_simulation(simulation)) for simulation in conversation_turns]
284
+ results = await asyncio.gather(*tasks)
250
285
  progress_bar.close()
251
- return simulated_conversations
286
+ return results
252
287
 
253
288
  async def _extend_conversation_with_simulator(
254
289
  self,
@@ -256,11 +291,12 @@ class Simulator:
256
291
  current_simulation: ConversationHistory,
257
292
  max_conversation_turns: int,
258
293
  user_simulator_prompty: Optional[str],
259
- user_simulator_prompty_kwargs: Dict[str, Any],
294
+ user_simulator_prompty_options: Dict[str, Any],
260
295
  api_call_delay_sec: float,
261
296
  prompty_model_config: Dict[str, Any],
262
297
  target: Callable,
263
298
  progress_bar: tqdm,
299
+ progress_bar_lock: asyncio.Lock,
264
300
  ):
265
301
  """
266
302
  Extends an ongoing conversation using a user simulator until the maximum number of turns is reached.
@@ -271,8 +307,8 @@ class Simulator:
271
307
  :paramtype max_conversation_turns: int,
272
308
  :keyword user_simulator_prompty: Path to the user simulator prompty file.
273
309
  :paramtype user_simulator_prompty: Optional[str],
274
- :keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
275
- :paramtype user_simulator_prompty_kwargs: Dict[str, Any],
310
+ :keyword user_simulator_prompty_options: Additional keyword arguments for the user simulator prompty.
311
+ :paramtype user_simulator_prompty_options: Dict[str, Any],
276
312
  :keyword api_call_delay_sec: Delay in seconds between API calls.
277
313
  :paramtype api_call_delay_sec: float,
278
314
  :keyword prompty_model_config: The configuration for the prompty model.
@@ -281,36 +317,41 @@ class Simulator:
281
317
  :paramtype target: Callable,
282
318
  :keyword progress_bar: Progress bar for tracking simulation progress.
283
319
  :paramtype progress_bar: tqdm,
320
+ :keyword progress_bar_lock: Lock for updating the progress bar safely.
321
+ :paramtype progress_bar_lock: asyncio.Lock
284
322
  """
285
323
  user_flow = self._load_user_simulation_flow(
286
324
  user_simulator_prompty=user_simulator_prompty, # type: ignore
287
325
  prompty_model_config=prompty_model_config,
288
- user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
326
+ user_simulator_prompty_options=user_simulator_prompty_options,
289
327
  )
290
328
 
291
329
  while len(current_simulation) < max_conversation_turns:
292
330
  user_response_content = await user_flow(
293
331
  task="Continue the conversation",
294
- conversation_history=current_simulation.to_list(),
295
- **user_simulator_prompty_kwargs,
332
+ conversation_history=current_simulation.to_context_free_list(),
333
+ **user_simulator_prompty_options,
296
334
  )
297
335
  user_response = self._parse_prompty_response(response=user_response_content)
298
336
  user_turn = Turn(role=ConversationRole.USER, content=user_response["content"])
299
337
  current_simulation.add_to_history(user_turn)
300
338
  await asyncio.sleep(api_call_delay_sec)
301
- assistant_response = await self._get_target_response(
339
+ assistant_response, assistant_context = await self._get_target_response(
302
340
  target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=current_simulation
303
341
  )
304
- assistant_turn = Turn(role=ConversationRole.ASSISTANT, content=assistant_response)
342
+ assistant_turn = Turn(
343
+ role=ConversationRole.ASSISTANT, content=assistant_response, context=assistant_context
344
+ )
305
345
  current_simulation.add_to_history(assistant_turn)
306
- progress_bar.update(1)
346
+ async with progress_bar_lock:
347
+ progress_bar.update(1)
307
348
 
308
349
  def _load_user_simulation_flow(
309
350
  self,
310
351
  *,
311
352
  user_simulator_prompty: Optional[Union[str, os.PathLike]],
312
353
  prompty_model_config: Dict[str, Any],
313
- user_simulator_prompty_kwargs: Dict[str, Any],
354
+ user_simulator_prompty_options: Dict[str, Any],
314
355
  ) -> "AsyncPrompty": # type: ignore
315
356
  """
316
357
  Loads the flow for simulating user interactions.
@@ -319,8 +360,8 @@ class Simulator:
319
360
  :paramtype user_simulator_prompty: Optional[Union[str, os.PathLike]]
320
361
  :keyword prompty_model_config: The configuration for the prompty model.
321
362
  :paramtype prompty_model_config: Dict[str, Any]
322
- :keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
323
- :paramtype user_simulator_prompty_kwargs: Dict[str, Any]
363
+ :keyword user_simulator_prompty_options: Additional keyword arguments for the user simulator prompty.
364
+ :paramtype user_simulator_prompty_options: Dict[str, Any]
324
365
  :return: The loaded flow for simulating user interactions.
325
366
  :rtype: AsyncPrompty
326
367
  """
@@ -353,7 +394,7 @@ class Simulator:
353
394
  return AsyncPrompty.load(
354
395
  source=user_simulator_prompty,
355
396
  model=prompty_model_config,
356
- **user_simulator_prompty_kwargs,
397
+ **user_simulator_prompty_options,
357
398
  ) # type: ignore
358
399
 
359
400
  def _parse_prompty_response(self, *, response: str) -> Dict[str, Any]:
@@ -401,7 +442,7 @@ class Simulator:
401
442
  text: str,
402
443
  num_queries: int,
403
444
  query_response_generating_prompty: Optional[str],
404
- query_response_generating_prompty_kwargs: Dict[str, Any],
445
+ query_response_generating_prompty_options: Dict[str, Any],
405
446
  prompty_model_config: Any,
406
447
  **kwargs,
407
448
  ) -> List[Dict[str, str]]:
@@ -414,8 +455,8 @@ class Simulator:
414
455
  :paramtype num_queries: int
415
456
  :keyword query_response_generating_prompty: Path to the query response generating prompty file.
416
457
  :paramtype query_response_generating_prompty: Optional[str]
417
- :keyword query_response_generating_prompty_kwargs: Additional keyword arguments for the query response generating prompty.
418
- :paramtype query_response_generating_prompty_kwargs: Dict[str, Any]
458
+ :keyword query_response_generating_prompty_options: Additional keyword arguments for the query response generating prompty.
459
+ :paramtype query_response_generating_prompty_options: Dict[str, Any]
419
460
  :keyword prompty_model_config: The configuration for the prompty model.
420
461
  :paramtype prompty_model_config: Any
421
462
  :return: A list of query-response dictionaries.
@@ -425,13 +466,21 @@ class Simulator:
425
466
  query_flow = self._load_query_generation_flow(
426
467
  query_response_generating_prompty=query_response_generating_prompty, # type: ignore
427
468
  prompty_model_config=prompty_model_config,
428
- query_response_generating_prompty_kwargs=query_response_generating_prompty_kwargs,
469
+ query_response_generating_prompty_options=query_response_generating_prompty_options,
429
470
  )
430
471
  try:
431
472
  query_responses = await query_flow(text=text, num_queries=num_queries)
432
473
  if isinstance(query_responses, dict):
433
474
  keys = list(query_responses.keys())
434
475
  return query_responses[keys[0]]
476
+ if isinstance(query_responses, str):
477
+ query_responses = json.loads(query_responses)
478
+ if isinstance(query_responses, dict):
479
+ if len(query_responses.keys()) == 1:
480
+ return query_responses[list(query_responses.keys())[0]]
481
+ return query_responses # type: ignore
482
+ if isinstance(query_responses, list):
483
+ return query_responses
435
484
  return json.loads(query_responses)
436
485
  except Exception as e:
437
486
  raise RuntimeError("Error generating query responses") from e
@@ -441,7 +490,7 @@ class Simulator:
441
490
  *,
442
491
  query_response_generating_prompty: Optional[Union[str, os.PathLike]],
443
492
  prompty_model_config: Dict[str, Any],
444
- query_response_generating_prompty_kwargs: Dict[str, Any],
493
+ query_response_generating_prompty_options: Dict[str, Any],
445
494
  ) -> "AsyncPrompty":
446
495
  """
447
496
  Loads the flow for generating query responses.
@@ -450,8 +499,8 @@ class Simulator:
450
499
  :paramtype query_response_generating_prompty: Optional[Union[str, os.PathLike]]
451
500
  :keyword prompty_model_config: The configuration for the prompty model.
452
501
  :paramtype prompty_model_config: Dict[str, Any]
453
- :keyword query_response_generating_prompty_kwargs: Additional keyword arguments for the flow.
454
- :paramtype query_response_generating_prompty_kwargs: Dict[str, Any]
502
+ :keyword query_response_generating_prompty_options: Additional keyword arguments for the flow.
503
+ :paramtype query_response_generating_prompty_options: Dict[str, Any]
455
504
  :return: The loaded flow for generating query responses.
456
505
  :rtype: AsyncPrompty
457
506
  """
@@ -484,7 +533,7 @@ class Simulator:
484
533
  return AsyncPrompty.load(
485
534
  source=query_response_generating_prompty,
486
535
  model=prompty_model_config,
487
- **query_response_generating_prompty_kwargs,
536
+ **query_response_generating_prompty_options,
488
537
  ) # type: ignore
489
538
 
490
539
  async def _create_conversations_from_query_responses(
@@ -494,9 +543,10 @@ class Simulator:
494
543
  max_conversation_turns: int,
495
544
  tasks: List[str],
496
545
  user_simulator_prompty: Optional[str],
497
- user_simulator_prompty_kwargs: Dict[str, Any],
546
+ user_simulator_prompty_options: Dict[str, Any],
498
547
  target: Callable,
499
548
  api_call_delay_sec: float,
549
+ text: str,
500
550
  ) -> List[JsonLineChatProtocol]:
501
551
  """
502
552
  Creates full conversations from query-response pairs.
@@ -509,12 +559,14 @@ class Simulator:
509
559
  :paramtype tasks: List[str]
510
560
  :keyword user_simulator_prompty: Path to the user simulator prompty file.
511
561
  :paramtype user_simulator_prompty: Optional[str]
512
- :keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
513
- :paramtype user_simulator_prompty_kwargs: Dict[str, Any]
562
+ :keyword user_simulator_prompty_options: Additional keyword arguments for the user simulator prompty.
563
+ :paramtype user_simulator_prompty_options: Dict[str, Any]
514
564
  :keyword target: The target function to call for responses.
515
565
  :paramtype target: Callable
516
566
  :keyword api_call_delay_sec: Delay in seconds between API calls.
517
567
  :paramtype api_call_delay_sec: float
568
+ :keyword text: The initial input text for generating query responses.
569
+ :paramtype text: str
518
570
  :return: A list of simulated conversations represented as JsonLineChatProtocol objects.
519
571
  :rtype: List[JsonLineChatProtocol]
520
572
  """
@@ -538,7 +590,7 @@ class Simulator:
538
590
  max_conversation_turns=max_conversation_turns,
539
591
  task=task, # type: ignore
540
592
  user_simulator_prompty=user_simulator_prompty,
541
- user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
593
+ user_simulator_prompty_options=user_simulator_prompty_options,
542
594
  target=target,
543
595
  api_call_delay_sec=api_call_delay_sec,
544
596
  progress_bar=progress_bar,
@@ -552,6 +604,7 @@ class Simulator:
552
604
  "task": task,
553
605
  "expected_response": response,
554
606
  "query": query,
607
+ "original_text": text,
555
608
  },
556
609
  "$schema": "http://azureml/sdk-2-0/ChatConversation.json",
557
610
  }
@@ -567,7 +620,7 @@ class Simulator:
567
620
  max_conversation_turns: int,
568
621
  task: str,
569
622
  user_simulator_prompty: Optional[str],
570
- user_simulator_prompty_kwargs: Dict[str, Any],
623
+ user_simulator_prompty_options: Dict[str, Any],
571
624
  target: Callable,
572
625
  api_call_delay_sec: float,
573
626
  progress_bar: tqdm,
@@ -583,8 +636,8 @@ class Simulator:
583
636
  :paramtype task: str
584
637
  :keyword user_simulator_prompty: Path to the user simulator prompty file.
585
638
  :paramtype user_simulator_prompty: Optional[str]
586
- :keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
587
- :paramtype user_simulator_prompty_kwargs: Dict[str, Any]
639
+ :keyword user_simulator_prompty_options: Additional keyword arguments for the user simulator prompty.
640
+ :paramtype user_simulator_prompty_options: Dict[str, Any]
588
641
  :keyword target: The target function to call for responses.
589
642
  :paramtype target: Callable
590
643
  :keyword api_call_delay_sec: Delay in seconds between API calls.
@@ -595,33 +648,40 @@ class Simulator:
595
648
  :rtype: List[Dict[str, Optional[str]]]
596
649
  """
597
650
  conversation_history = ConversationHistory()
598
- # user_turn = Turn(role=ConversationRole.USER, content=conversation_starter)
599
- # conversation_history.add_to_history(user_turn)
600
651
 
601
652
  while len(conversation_history) < max_conversation_turns:
602
653
  user_flow = self._load_user_simulation_flow(
603
654
  user_simulator_prompty=user_simulator_prompty, # type: ignore
604
655
  prompty_model_config=self.model_config, # type: ignore
605
- user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
606
- )
607
- conversation_starter_from_simulated_user = await user_flow(
608
- task=task,
609
- conversation_history=[
610
- {
611
- "role": "assistant",
612
- "content": conversation_starter,
613
- "your_task": "Act as the user and translate the content into a user query.",
614
- }
615
- ],
656
+ user_simulator_prompty_options=user_simulator_prompty_options,
616
657
  )
658
+ if len(conversation_history) == 0:
659
+ conversation_starter_from_simulated_user = await user_flow(
660
+ task=task,
661
+ conversation_history=[
662
+ {
663
+ "role": "assistant",
664
+ "content": conversation_starter,
665
+ }
666
+ ],
667
+ action="rewrite the assistant's message as you have to accomplish the task by asking the right questions. Make sure the original question is not lost in your rewrite.",
668
+ )
669
+ else:
670
+ conversation_starter_from_simulated_user = await user_flow(
671
+ task=task,
672
+ conversation_history=conversation_history.to_context_free_list(),
673
+ action="Your goal is to make sure the task is completed by asking the right questions. Do not ask the same questions again.",
674
+ )
617
675
  if isinstance(conversation_starter_from_simulated_user, dict):
618
676
  conversation_starter_from_simulated_user = conversation_starter_from_simulated_user["content"]
619
677
  user_turn = Turn(role=ConversationRole.USER, content=conversation_starter_from_simulated_user)
620
678
  conversation_history.add_to_history(user_turn)
621
- assistant_response = await self._get_target_response(
679
+ assistant_response, assistant_context = await self._get_target_response(
622
680
  target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=conversation_history
623
681
  )
624
- assistant_turn = Turn(role=ConversationRole.ASSISTANT, content=assistant_response)
682
+ assistant_turn = Turn(
683
+ role=ConversationRole.ASSISTANT, content=assistant_response, context=assistant_context
684
+ )
625
685
  conversation_history.add_to_history(assistant_turn)
626
686
  progress_bar.update(1)
627
687
 
@@ -632,7 +692,7 @@ class Simulator:
632
692
 
633
693
  async def _get_target_response(
634
694
  self, *, target: Callable, api_call_delay_sec: float, conversation_history: ConversationHistory
635
- ) -> str:
695
+ ) -> Tuple[str, Optional[str]]:
636
696
  """
637
697
  Retrieves the response from the target callback based on the current conversation history.
638
698
 
@@ -642,8 +702,8 @@ class Simulator:
642
702
  :paramtype api_call_delay_sec: float
643
703
  :keyword conversation_history: The current conversation history.
644
704
  :paramtype conversation_history: ConversationHistory
645
- :return: The content of the response from the target.
646
- :rtype: str
705
+ :return: The content of the response from the target and an optional context.
706
+ :rtype: str, Optional[str]
647
707
  """
648
708
  response = await target(
649
709
  messages={"messages": conversation_history.to_list()},
@@ -653,4 +713,4 @@ class Simulator:
653
713
  )
654
714
  await asyncio.sleep(api_call_delay_sec)
655
715
  latest_message = response["messages"][-1]
656
- return latest_message["content"]
716
+ return latest_message["content"], latest_message.get("context", "") # type: ignore
@@ -26,9 +26,9 @@ class JsonLineList(list):
26
26
  json_lines += json.dumps(item) + "\n"
27
27
  return json_lines
28
28
 
29
- def to_eval_qa_json_lines(self):
29
+ def to_eval_qr_json_lines(self):
30
30
  """
31
- Converts the list to a string of JSON lines suitable for evaluation in a Q&A format.
31
+ Converts the list to a string of JSON lines suitable for evaluation in a query & response format.
32
32
  Each item in the list is expected to be a dictionary with
33
33
  'messages' key. The 'messages' value is a list of
34
34
  dictionaries, each with a 'role' key and a 'content' key.
@@ -44,23 +44,41 @@ class JsonLineList(list):
44
44
  for item in self:
45
45
  user_message = None
46
46
  assistant_message = None
47
- context = None
47
+ user_context = None
48
+ assistant_context = None
49
+ template_parameters = item.get("template_parameters", {})
50
+ category = template_parameters.get("category", None)
48
51
  for message in item["messages"]:
49
52
  if message["role"] == "user":
50
53
  user_message = message["content"]
54
+ user_context = message.get("context", "")
51
55
  elif message["role"] == "assistant":
52
56
  assistant_message = message["content"]
53
- if "context" in message:
54
- context = message.get("context", None)
57
+ assistant_context = message.get("context", "")
55
58
  if user_message and assistant_message:
56
- if context:
59
+ if user_context or assistant_context:
57
60
  json_lines += (
58
- json.dumps({"query": user_message, "response": assistant_message, "context": context})
61
+ json.dumps(
62
+ {
63
+ "query": user_message,
64
+ "response": assistant_message,
65
+ "context": str(
66
+ {
67
+ "user_context": user_context,
68
+ "assistant_context": assistant_context,
69
+ }
70
+ ),
71
+ "category": category,
72
+ }
73
+ )
59
74
  + "\n"
60
75
  )
61
- user_message = assistant_message = context = None
76
+ user_message = assistant_message = None
62
77
  else:
63
- json_lines += json.dumps({"query": user_message, "response": assistant_message}) + "\n"
78
+ json_lines += (
79
+ json.dumps({"query": user_message, "response": assistant_message, "category": category})
80
+ + "\n"
81
+ )
64
82
  user_message = assistant_message = None
65
83
 
66
84
  return json_lines
@@ -80,9 +98,9 @@ class JsonLineChatProtocol(dict):
80
98
  """
81
99
  return json.dumps(self)
82
100
 
83
- def to_eval_qa_json_lines(self) -> str:
101
+ def to_eval_qr_json_lines(self) -> str:
84
102
  """
85
- Converts the object to a string of JSON lines suitable for evaluation in a Q&A format.
103
+ Converts the object to a string of JSON lines suitable for evaluation in a query and response format.
86
104
  The object is expected to be a dictionary with 'messages' key.
87
105
 
88
106
  :returns: A json lines document
@@ -105,10 +123,10 @@ class JsonLineChatProtocol(dict):
105
123
  if user_message and assistant_message:
106
124
  if context:
107
125
  json_lines += (
108
- json.dumps({"question": user_message, "answer": assistant_message, "context": context}) + "\n"
126
+ json.dumps({"query": user_message, "response": assistant_message, "context": context}) + "\n"
109
127
  )
110
128
  user_message = assistant_message = None
111
129
  else:
112
- json_lines += json.dumps({"question": user_message, "answer": assistant_message}) + "\n"
130
+ json_lines += json.dumps({"query": user_message, "response": assistant_message}) + "\n"
113
131
  user_message = assistant_message = None
114
132
  return json_lines