azure-ai-evaluation 1.0.0b4__py3-none-any.whl → 1.0.0b5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of azure-ai-evaluation might be problematic. Click here for more details.

Files changed (79)
  1. azure/ai/evaluation/__init__.py +22 -0
  2. azure/ai/evaluation/_common/constants.py +5 -0
  3. azure/ai/evaluation/_common/math.py +11 -0
  4. azure/ai/evaluation/_common/rai_service.py +172 -35
  5. azure/ai/evaluation/_common/utils.py +162 -23
  6. azure/ai/evaluation/_constants.py +6 -6
  7. azure/ai/evaluation/_evaluate/{_batch_run_client → _batch_run}/__init__.py +3 -2
  8. azure/ai/evaluation/_evaluate/{_batch_run_client/batch_run_context.py → _batch_run/eval_run_context.py} +4 -4
  9. azure/ai/evaluation/_evaluate/{_batch_run_client → _batch_run}/proxy_client.py +6 -3
  10. azure/ai/evaluation/_evaluate/_batch_run/target_run_context.py +35 -0
  11. azure/ai/evaluation/_evaluate/_eval_run.py +21 -4
  12. azure/ai/evaluation/_evaluate/_evaluate.py +267 -139
  13. azure/ai/evaluation/_evaluate/_telemetry/__init__.py +5 -5
  14. azure/ai/evaluation/_evaluate/_utils.py +40 -7
  15. azure/ai/evaluation/_evaluators/_bleu/_bleu.py +1 -1
  16. azure/ai/evaluation/_evaluators/_coherence/_coherence.py +14 -9
  17. azure/ai/evaluation/_evaluators/_coherence/coherence.prompty +76 -34
  18. azure/ai/evaluation/_evaluators/_common/_base_eval.py +20 -19
  19. azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py +18 -8
  20. azure/ai/evaluation/_evaluators/_common/_base_rai_svc_eval.py +48 -9
  21. azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py +56 -19
  22. azure/ai/evaluation/_evaluators/_content_safety/_content_safety_chat.py +5 -5
  23. azure/ai/evaluation/_evaluators/_content_safety/_hate_unfairness.py +30 -1
  24. azure/ai/evaluation/_evaluators/_content_safety/_self_harm.py +30 -1
  25. azure/ai/evaluation/_evaluators/_content_safety/_sexual.py +30 -1
  26. azure/ai/evaluation/_evaluators/_content_safety/_violence.py +30 -1
  27. azure/ai/evaluation/_evaluators/_eci/_eci.py +3 -1
  28. azure/ai/evaluation/_evaluators/_fluency/_fluency.py +20 -20
  29. azure/ai/evaluation/_evaluators/_fluency/fluency.prompty +66 -36
  30. azure/ai/evaluation/_evaluators/_gleu/_gleu.py +1 -1
  31. azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py +49 -15
  32. azure/ai/evaluation/_evaluators/_groundedness/groundedness_with_query.prompty +113 -0
  33. azure/ai/evaluation/_evaluators/_groundedness/groundedness_without_query.prompty +99 -0
  34. azure/ai/evaluation/_evaluators/_meteor/_meteor.py +3 -7
  35. azure/ai/evaluation/_evaluators/_multimodal/__init__.py +20 -0
  36. azure/ai/evaluation/_evaluators/_multimodal/_content_safety_multimodal.py +130 -0
  37. azure/ai/evaluation/_evaluators/_multimodal/_content_safety_multimodal_base.py +57 -0
  38. azure/ai/evaluation/_evaluators/_multimodal/_hate_unfairness.py +96 -0
  39. azure/ai/evaluation/_evaluators/_multimodal/_protected_material.py +120 -0
  40. azure/ai/evaluation/_evaluators/_multimodal/_self_harm.py +96 -0
  41. azure/ai/evaluation/_evaluators/_multimodal/_sexual.py +96 -0
  42. azure/ai/evaluation/_evaluators/_multimodal/_violence.py +96 -0
  43. azure/ai/evaluation/_evaluators/_protected_material/_protected_material.py +44 -11
  44. azure/ai/evaluation/_evaluators/_qa/_qa.py +7 -3
  45. azure/ai/evaluation/_evaluators/_relevance/_relevance.py +21 -19
  46. azure/ai/evaluation/_evaluators/_relevance/relevance.prompty +78 -42
  47. azure/ai/evaluation/_evaluators/_retrieval/_retrieval.py +125 -82
  48. azure/ai/evaluation/_evaluators/_retrieval/retrieval.prompty +74 -24
  49. azure/ai/evaluation/_evaluators/_rouge/_rouge.py +2 -2
  50. azure/ai/evaluation/_evaluators/_service_groundedness/__init__.py +9 -0
  51. azure/ai/evaluation/_evaluators/_service_groundedness/_service_groundedness.py +150 -0
  52. azure/ai/evaluation/_evaluators/_similarity/_similarity.py +17 -14
  53. azure/ai/evaluation/_evaluators/_xpia/xpia.py +32 -5
  54. azure/ai/evaluation/_exceptions.py +17 -0
  55. azure/ai/evaluation/_model_configurations.py +18 -1
  56. azure/ai/evaluation/_version.py +1 -1
  57. azure/ai/evaluation/simulator/__init__.py +2 -1
  58. azure/ai/evaluation/simulator/_adversarial_scenario.py +5 -0
  59. azure/ai/evaluation/simulator/_adversarial_simulator.py +4 -1
  60. azure/ai/evaluation/simulator/_data_sources/__init__.py +3 -0
  61. azure/ai/evaluation/simulator/_data_sources/grounding.json +1150 -0
  62. azure/ai/evaluation/simulator/_direct_attack_simulator.py +1 -1
  63. azure/ai/evaluation/simulator/_helpers/__init__.py +1 -2
  64. azure/ai/evaluation/simulator/_helpers/_simulator_data_classes.py +22 -1
  65. azure/ai/evaluation/simulator/_indirect_attack_simulator.py +79 -34
  66. azure/ai/evaluation/simulator/_model_tools/_identity_manager.py +1 -1
  67. azure/ai/evaluation/simulator/_prompty/task_query_response.prompty +4 -4
  68. azure/ai/evaluation/simulator/_prompty/task_simulate.prompty +6 -1
  69. azure/ai/evaluation/simulator/_simulator.py +115 -61
  70. azure/ai/evaluation/simulator/_utils.py +6 -6
  71. {azure_ai_evaluation-1.0.0b4.dist-info → azure_ai_evaluation-1.0.0b5.dist-info}/METADATA +166 -9
  72. {azure_ai_evaluation-1.0.0b4.dist-info → azure_ai_evaluation-1.0.0b5.dist-info}/NOTICE.txt +20 -0
  73. azure_ai_evaluation-1.0.0b5.dist-info/RECORD +120 -0
  74. {azure_ai_evaluation-1.0.0b4.dist-info → azure_ai_evaluation-1.0.0b5.dist-info}/WHEEL +1 -1
  75. azure/ai/evaluation/_evaluators/_groundedness/groundedness.prompty +0 -49
  76. azure_ai_evaluation-1.0.0b4.dist-info/RECORD +0 -106
  77. /azure/ai/evaluation/{simulator/_helpers → _common}/_experimental.py +0 -0
  78. /azure/ai/evaluation/_evaluate/{_batch_run_client → _batch_run}/code_client.py +0 -0
  79. {azure_ai_evaluation-1.0.0b4.dist-info → azure_ai_evaluation-1.0.0b5.dist-info}/top_level.txt +0 -0
@@ -7,13 +7,13 @@ import logging
7
7
  from random import randint
8
8
  from typing import Callable, Optional, cast
9
9
 
10
+ from azure.ai.evaluation._common._experimental import experimental
10
11
  from azure.ai.evaluation._common.utils import validate_azure_ai_project
11
12
  from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
12
13
  from azure.ai.evaluation.simulator import AdversarialScenario
13
14
  from azure.core.credentials import TokenCredential
14
15
 
15
16
  from ._adversarial_simulator import AdversarialSimulator
16
- from ._helpers import experimental
17
17
  from ._model_tools import AdversarialTemplateHandler, ManagedIdentityAPITokenManager, RAIClient, TokenScope
18
18
 
19
19
  logger = logging.getLogger(__name__)
@@ -1,5 +1,4 @@
1
- from ._experimental import experimental
2
1
  from ._language_suffix_mapping import SUPPORTED_LANGUAGES_MAPPING
3
2
  from ._simulator_data_classes import ConversationHistory, Turn
4
3
 
5
- __all__ = ["ConversationHistory", "Turn", "SUPPORTED_LANGUAGES_MAPPING", "experimental"]
4
+ __all__ = ["ConversationHistory", "Turn", "SUPPORTED_LANGUAGES_MAPPING"]
@@ -30,7 +30,19 @@ class Turn:
30
30
  return {
31
31
  "role": self.role.value if isinstance(self.role, ConversationRole) else self.role,
32
32
  "content": self.content,
33
- "context": self.context,
33
+ "context": str(self.context),
34
+ }
35
+
36
+ def to_context_free_dict(self) -> Dict[str, Optional[str]]:
37
+ """
38
+ Convert the conversation turn to a dictionary without context.
39
+
40
+ :returns: A dictionary representation of the conversation turn without context.
41
+ :rtype: Dict[str, Optional[str]]
42
+ """
43
+ return {
44
+ "role": self.role.value if isinstance(self.role, ConversationRole) else self.role,
45
+ "content": self.content,
34
46
  }
35
47
 
36
48
  def __repr__(self):
@@ -66,6 +78,15 @@ class ConversationHistory:
66
78
  """
67
79
  return [turn.to_dict() for turn in self.history]
68
80
 
81
+ def to_context_free_list(self) -> List[Dict[str, Optional[str]]]:
82
+ """
83
+ Converts the conversation history to a list of dictionaries without context.
84
+
85
+ :returns: A list of dictionaries representing the conversation turns without context.
86
+ :rtype: List[Dict[str, str]]
87
+ """
88
+ return [turn.to_context_free_dict() for turn in self.history]
89
+
69
90
  def __len__(self) -> int:
70
91
  return len(self.history)
71
92
 
@@ -3,23 +3,27 @@
3
3
  # ---------------------------------------------------------
4
4
  # pylint: disable=C0301,C0114,R0913,R0903
5
5
  # noqa: E501
6
+ import asyncio
6
7
  import logging
7
8
  from typing import Callable, cast
8
9
 
10
+ from tqdm import tqdm
11
+
9
12
  from azure.ai.evaluation._common.utils import validate_azure_ai_project
13
+ from azure.ai.evaluation._common._experimental import experimental
10
14
  from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
11
- from azure.ai.evaluation.simulator import AdversarialScenario
15
+ from azure.ai.evaluation.simulator import AdversarialScenarioJailbreak, SupportedLanguages
12
16
  from azure.core.credentials import TokenCredential
13
17
 
14
- from ._adversarial_simulator import AdversarialSimulator
15
- from ._helpers import experimental
18
+ from ._adversarial_simulator import AdversarialSimulator, JsonLineList
19
+
16
20
  from ._model_tools import AdversarialTemplateHandler, ManagedIdentityAPITokenManager, RAIClient, TokenScope
17
21
 
18
22
  logger = logging.getLogger(__name__)
19
23
 
20
24
 
21
25
  @experimental
22
- class IndirectAttackSimulator:
26
+ class IndirectAttackSimulator(AdversarialSimulator):
23
27
  """
24
28
  Initializes the XPIA (cross domain prompt injected attack) jailbreak adversarial simulator with a project scope.
25
29
 
@@ -54,6 +58,7 @@ class IndirectAttackSimulator:
54
58
  self.adversarial_template_handler = AdversarialTemplateHandler(
55
59
  azure_ai_project=self.azure_ai_project, rai_client=self.rai_client
56
60
  )
61
+ super().__init__(azure_ai_project=azure_ai_project, credential=credential)
57
62
 
58
63
  def _ensure_service_dependencies(self):
59
64
  if self.rai_client is None:
@@ -69,29 +74,22 @@ class IndirectAttackSimulator:
69
74
  async def __call__(
70
75
  self,
71
76
  *,
72
- scenario: AdversarialScenario,
73
77
  target: Callable,
74
- max_conversation_turns: int = 1,
75
78
  max_simulation_results: int = 3,
76
79
  api_call_retry_limit: int = 3,
77
80
  api_call_retry_sleep_sec: int = 1,
78
81
  api_call_delay_sec: int = 0,
79
82
  concurrent_async_task: int = 3,
83
+ **kwargs,
80
84
  ):
81
85
  """
82
86
  Initializes the XPIA (cross domain prompt injected attack) jailbreak adversarial simulator with a project scope.
83
87
  This simulator converses with your AI system using prompts injected into the context to interrupt normal
84
88
  expected functionality by eliciting manipulated content, intrusion and attempting to gather information outside
85
89
  the scope of your AI system.
86
-
87
- :keyword scenario: Enum value specifying the adversarial scenario used for generating inputs.
88
- :paramtype scenario: azure.ai.evaluation.simulator.AdversarialScenario
89
90
  :keyword target: The target function to simulate adversarial inputs against.
90
91
  This function should be asynchronous and accept a dictionary representing the adversarial input.
91
92
  :paramtype target: Callable
92
- :keyword max_conversation_turns: The maximum number of conversation turns to simulate.
93
- Defaults to 1.
94
- :paramtype max_conversation_turns: int
95
93
  :keyword max_simulation_results: The maximum number of simulation results to return.
96
94
  Defaults to 3.
97
95
  :paramtype max_simulation_results: int
@@ -128,11 +126,11 @@ class IndirectAttackSimulator:
128
126
  'template_parameters': {},
129
127
  'messages': [
130
128
  {
131
- 'content': '<jailbreak prompt> <adversarial query>',
129
+ 'content': '<adversarial query>',
132
130
  'role': 'user'
133
131
  },
134
132
  {
135
- 'content': "<response from endpoint>",
133
+ 'content': "<response from your callback>",
136
134
  'role': 'assistant',
137
135
  'context': None
138
136
  }
@@ -141,25 +139,72 @@ class IndirectAttackSimulator:
141
139
  }]
142
140
  }
143
141
  """
144
- if scenario not in AdversarialScenario.__members__.values():
145
- msg = f"Invalid scenario: {scenario}. Supported scenarios: {AdversarialScenario.__members__.values()}"
146
- raise EvaluationException(
147
- message=msg,
148
- internal_message=msg,
149
- target=ErrorTarget.DIRECT_ATTACK_SIMULATOR,
150
- category=ErrorCategory.INVALID_VALUE,
151
- blame=ErrorBlame.USER_ERROR,
142
+ # values that cannot be changed:
143
+ scenario = AdversarialScenarioJailbreak.ADVERSARIAL_INDIRECT_JAILBREAK
144
+ max_conversation_turns = 2
145
+ language = SupportedLanguages.English
146
+ self._ensure_service_dependencies()
147
+ templates = await self.adversarial_template_handler._get_content_harm_template_collections(scenario.value)
148
+ concurrent_async_task = min(concurrent_async_task, 1000)
149
+ semaphore = asyncio.Semaphore(concurrent_async_task)
150
+ sim_results = []
151
+ tasks = []
152
+ total_tasks = sum(len(t.template_parameters) for t in templates)
153
+ if max_simulation_results > total_tasks:
154
+ logger.warning(
155
+ "Cannot provide %s results due to maximum number of adversarial simulations that can be generated: %s."
156
+ "\n %s simulations will be generated.",
157
+ max_simulation_results,
158
+ total_tasks,
159
+ total_tasks,
152
160
  )
153
- jb_sim = AdversarialSimulator(azure_ai_project=cast(dict, self.azure_ai_project), credential=self.credential)
154
- jb_sim_results = await jb_sim(
155
- scenario=scenario,
156
- target=target,
157
- max_conversation_turns=max_conversation_turns,
158
- max_simulation_results=max_simulation_results,
159
- api_call_retry_limit=api_call_retry_limit,
160
- api_call_retry_sleep_sec=api_call_retry_sleep_sec,
161
- api_call_delay_sec=api_call_delay_sec,
162
- concurrent_async_task=concurrent_async_task,
163
- _jailbreak_type="xpia",
161
+ total_tasks = min(total_tasks, max_simulation_results)
162
+ progress_bar = tqdm(
163
+ total=total_tasks,
164
+ desc="generating jailbreak simulations",
165
+ ncols=100,
166
+ unit="simulations",
164
167
  )
165
- return jb_sim_results
168
+ for template in templates:
169
+ for parameter in template.template_parameters:
170
+ tasks.append(
171
+ asyncio.create_task(
172
+ self._simulate_async(
173
+ target=target,
174
+ template=template,
175
+ parameters=parameter,
176
+ max_conversation_turns=max_conversation_turns,
177
+ api_call_retry_limit=api_call_retry_limit,
178
+ api_call_retry_sleep_sec=api_call_retry_sleep_sec,
179
+ api_call_delay_sec=api_call_delay_sec,
180
+ language=language,
181
+ semaphore=semaphore,
182
+ )
183
+ )
184
+ )
185
+ if len(tasks) >= max_simulation_results:
186
+ break
187
+ if len(tasks) >= max_simulation_results:
188
+ break
189
+ for task in asyncio.as_completed(tasks):
190
+ completed_task = await task # type: ignore
191
+ template_parameters = completed_task.get("template_parameters", {}) # type: ignore
192
+ xpia_attack_type = template_parameters.get("xpia_attack_type", "") # type: ignore
193
+ action = template_parameters.get("action", "") # type: ignore
194
+ document_type = template_parameters.get("document_type", "") # type: ignore
195
+ sim_results.append(
196
+ {
197
+ "messages": completed_task["messages"], # type: ignore
198
+ "$schema": "http://azureml/sdk-2-0/ChatConversation.json",
199
+ "template_parameters": {
200
+ "metadata": {
201
+ "xpia_attack_type": xpia_attack_type,
202
+ "action": action,
203
+ "document_type": document_type,
204
+ },
205
+ },
206
+ }
207
+ )
208
+ progress_bar.update(1)
209
+ progress_bar.close()
210
+ return JsonLineList(sim_results)
@@ -11,7 +11,7 @@ from abc import ABC, abstractmethod
11
11
  from enum import Enum
12
12
  from typing import Optional, Union
13
13
 
14
- from azure.core.credentials import TokenCredential, AccessToken
14
+ from azure.core.credentials import AccessToken, TokenCredential
15
15
  from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
16
16
 
17
17
  AZURE_TOKEN_REFRESH_INTERVAL = 600 # seconds
@@ -36,8 +36,8 @@ On January 24, 1984, former Apple CEO Steve Jobs introduced the first Macintosh.
36
36
  Some years later, research firms IDC and Gartner reported that Apple's market share in the U.S. had increased to about 6%.
37
37
  <|text_end|>
38
38
  Output with 5 QnAs:
39
- [
40
- {
39
+ {
40
+ "qna":[{
41
41
  "q": "When did the former Apple CEO Steve Jobs introduced the first Macintosh?",
42
42
  "r": "January 24, 1984"
43
43
  },
@@ -56,8 +56,8 @@ Output with 5 QnAs:
56
56
  {
57
57
  "q": "What was the percentage increase of Apple's market share in the U.S., as reported by research firms IDC and Gartner?",
58
58
  "r": "6%"
59
- }
60
- ]
59
+ }]
60
+ }
61
61
  Text:
62
62
  <|text_start|>
63
63
  {{ text }}
@@ -16,6 +16,9 @@ inputs:
16
16
  type: string
17
17
  conversation_history:
18
18
  type: dict
19
+ action:
20
+ type: string
21
+ default: continue the converasation and make sure the task is completed by asking relevant questions
19
22
 
20
23
  ---
21
24
  system:
@@ -25,8 +28,10 @@ Output must be in JSON format
25
28
  Here's a sample output:
26
29
  {
27
30
  "content": "Here is my follow-up question.",
28
- "user": "user"
31
+ "role": "user"
29
32
  }
30
33
 
31
34
  Output with a json object that continues the conversation, given the conversation history:
32
35
  {{ conversation_history }}
36
+
37
+ {{ action }}
@@ -5,20 +5,23 @@
5
5
  # ---------------------------------------------------------
6
6
  import asyncio
7
7
  import importlib.resources as pkg_resources
8
- from tqdm import tqdm
9
8
  import json
10
9
  import os
11
10
  import re
12
11
  import warnings
13
- from typing import Any, Callable, Dict, List, Optional, Union
12
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
13
+
14
14
  from promptflow.core import AsyncPrompty
15
- from azure.ai.evaluation._model_configurations import AzureOpenAIModelConfiguration, OpenAIModelConfiguration
15
+ from tqdm import tqdm
16
+
17
+ from azure.ai.evaluation._common._experimental import experimental
16
18
  from azure.ai.evaluation._common.utils import construct_prompty_model_config
19
+ from azure.ai.evaluation._model_configurations import AzureOpenAIModelConfiguration, OpenAIModelConfiguration
17
20
 
18
21
  from .._exceptions import ErrorBlame, ErrorCategory, EvaluationException
19
22
  from .._user_agent import USER_AGENT
20
23
  from ._conversation.constants import ConversationRole
21
- from ._helpers import ConversationHistory, Turn, experimental
24
+ from ._helpers import ConversationHistory, Turn
22
25
  from ._utils import JsonLineChatProtocol
23
26
 
24
27
 
@@ -89,7 +92,8 @@ class Simulator:
89
92
  api_call_delay_sec: float = 1,
90
93
  query_response_generating_prompty_kwargs: Dict[str, Any] = {},
91
94
  user_simulator_prompty_kwargs: Dict[str, Any] = {},
92
- conversation_turns: List[List[str]] = [],
95
+ conversation_turns: List[List[Union[str, Dict[str, Any]]]] = [],
96
+ concurrent_async_tasks: int = 5,
93
97
  **kwargs,
94
98
  ) -> List[JsonLineChatProtocol]:
95
99
  """
@@ -116,7 +120,10 @@ class Simulator:
116
120
  :keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
117
121
  :paramtype user_simulator_prompty_kwargs: Dict[str, Any]
118
122
  :keyword conversation_turns: Predefined conversation turns to simulate.
119
- :paramtype conversation_turns: List[List[str]]
123
+ :paramtype conversation_turns: List[List[Union[str, Dict[str, Any]]]]
124
+ :keyword concurrent_async_tasks: The number of asynchronous tasks to run concurrently during the simulation.
125
+ Defaults to 5.
126
+ :paramtype concurrent_async_tasks: int
120
127
  :return: A list of simulated conversations represented as JsonLineChatProtocol objects.
121
128
  :rtype: List[JsonLineChatProtocol]
122
129
 
@@ -131,12 +138,12 @@ class Simulator:
131
138
  if conversation_turns and (text or tasks):
132
139
  raise ValueError("Cannot specify both conversation_turns and text/tasks")
133
140
 
134
- if num_queries > len(tasks):
141
+ if text and num_queries > len(tasks):
135
142
  warnings.warn(
136
143
  f"You have specified 'num_queries' > len('tasks') ({num_queries} > {len(tasks)}). "
137
144
  f"All tasks will be used for generation and the remaining {num_queries - len(tasks)} lines will be simulated in task-free mode"
138
145
  )
139
- elif num_queries < len(tasks):
146
+ elif text and num_queries < len(tasks):
140
147
  warnings.warn(
141
148
  f"You have specified 'num_queries' < len('tasks') ({num_queries} < {len(tasks)}). "
142
149
  f"Only the first {num_queries} lines of the specified tasks will be simulated."
@@ -154,6 +161,7 @@ class Simulator:
154
161
  user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
155
162
  api_call_delay_sec=api_call_delay_sec,
156
163
  prompty_model_config=prompty_model_config,
164
+ concurrent_async_tasks=concurrent_async_tasks,
157
165
  )
158
166
 
159
167
  query_responses = await self._generate_query_responses(
@@ -172,6 +180,7 @@ class Simulator:
172
180
  user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
173
181
  target=target,
174
182
  api_call_delay_sec=api_call_delay_sec,
183
+ text=text,
175
184
  )
176
185
 
177
186
  async def _simulate_with_predefined_turns(
@@ -179,11 +188,12 @@ class Simulator:
179
188
  *,
180
189
  target: Callable,
181
190
  max_conversation_turns: int,
182
- conversation_turns: List[List[str]],
191
+ conversation_turns: List[List[Union[str, Dict[str, Any]]]],
183
192
  user_simulator_prompty: Optional[str],
184
193
  user_simulator_prompty_kwargs: Dict[str, Any],
185
194
  api_call_delay_sec: float,
186
195
  prompty_model_config: Any,
196
+ concurrent_async_tasks: int,
187
197
  ) -> List[JsonLineChatProtocol]:
188
198
  """
189
199
  Simulates conversations using predefined conversation turns.
@@ -193,7 +203,7 @@ class Simulator:
193
203
  :keyword max_conversation_turns: Maximum number of turns for the simulation.
194
204
  :paramtype max_conversation_turns: int
195
205
  :keyword conversation_turns: A list of predefined conversation turns.
196
- :paramtype conversation_turns: List[List[str]]
206
+ :paramtype conversation_turns: List[List[Union[str, Dict[str, Any]]]]
197
207
  :keyword user_simulator_prompty: Path to the user simulator prompty file.
198
208
  :paramtype user_simulator_prompty: Optional[str]
199
209
  :keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
@@ -202,42 +212,60 @@ class Simulator:
202
212
  :paramtype api_call_delay_sec: float
203
213
  :keyword prompty_model_config: The configuration for the prompty model.
204
214
  :paramtype prompty_model_config: Any
215
+ :keyword concurrent_async_tasks: The number of asynchronous tasks to run concurrently during the simulation.
216
+ :paramtype concurrent_async_tasks: int
205
217
  :return: A list of simulated conversations represented as JsonLineChatProtocol objects.
206
218
  :rtype: List[JsonLineChatProtocol]
207
219
  """
208
- simulated_conversations = []
209
220
  progress_bar = tqdm(
210
221
  total=int(len(conversation_turns) * (max_conversation_turns / 2)),
211
222
  desc="Simulating with predefined conversation turns: ",
212
223
  ncols=100,
213
224
  unit="messages",
214
225
  )
215
-
216
- for simulation in conversation_turns:
217
- current_simulation = ConversationHistory()
218
- for simulated_turn in simulation:
219
- user_turn = Turn(role=ConversationRole.USER, content=simulated_turn)
220
- current_simulation.add_to_history(user_turn)
221
- assistant_response = await self._get_target_response(
222
- target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=current_simulation
223
- )
224
- assistant_turn = Turn(role=ConversationRole.ASSISTANT, content=assistant_response)
225
- current_simulation.add_to_history(assistant_turn)
226
- progress_bar.update(1) # Update progress bar for both user and assistant turns
227
-
228
- if len(current_simulation) < max_conversation_turns:
229
- await self._extend_conversation_with_simulator(
230
- current_simulation=current_simulation,
231
- max_conversation_turns=max_conversation_turns,
232
- user_simulator_prompty=user_simulator_prompty,
233
- user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
234
- api_call_delay_sec=api_call_delay_sec,
235
- prompty_model_config=prompty_model_config,
236
- target=target,
237
- progress_bar=progress_bar,
238
- )
239
- simulated_conversations.append(
240
- JsonLineChatProtocol(
226
+ semaphore = asyncio.Semaphore(concurrent_async_tasks)
227
+ progress_bar_lock = asyncio.Lock()
228
+
229
+ async def run_simulation(simulation: List[Union[str, Dict[str, Any]]]) -> JsonLineChatProtocol:
230
+ async with semaphore:
231
+ current_simulation = ConversationHistory()
232
+ for simulated_turn in simulation:
233
+ if isinstance(simulated_turn, str):
234
+ user_turn = Turn(role=ConversationRole.USER, content=simulated_turn)
235
+ elif isinstance(simulated_turn, dict):
236
+ user_turn = Turn(
237
+ role=ConversationRole.USER,
238
+ content=str(simulated_turn.get("content")),
239
+ context=str(simulated_turn.get("context")),
240
+ )
241
+ else:
242
+ raise ValueError(
243
+ "Each simulated turn must be a string or a dict with 'content' and 'context' keys"
244
+ )
245
+ current_simulation.add_to_history(user_turn)
246
+ assistant_response, assistant_context = await self._get_target_response(
247
+ target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=current_simulation
248
+ )
249
+ assistant_turn = Turn(
250
+ role=ConversationRole.ASSISTANT, content=assistant_response, context=assistant_context
251
+ )
252
+ current_simulation.add_to_history(assistant_turn)
253
+ async with progress_bar_lock:
254
+ progress_bar.update(1)
255
+
256
+ if len(current_simulation) < max_conversation_turns:
257
+ await self._extend_conversation_with_simulator(
258
+ current_simulation=current_simulation,
259
+ max_conversation_turns=max_conversation_turns,
260
+ user_simulator_prompty=user_simulator_prompty,
261
+ user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
262
+ api_call_delay_sec=api_call_delay_sec,
263
+ prompty_model_config=prompty_model_config,
264
+ target=target,
265
+ progress_bar=progress_bar,
266
+ progress_bar_lock=progress_bar_lock,
267
+ )
268
+ return JsonLineChatProtocol(
241
269
  {
242
270
  "messages": current_simulation.to_list(),
243
271
  "finish_reason": ["stop"],
@@ -245,10 +273,11 @@ class Simulator:
245
273
  "$schema": "http://azureml/sdk-2-0/ChatConversation.json",
246
274
  }
247
275
  )
248
- )
249
276
 
277
+ tasks = [asyncio.create_task(run_simulation(simulation)) for simulation in conversation_turns]
278
+ results = await asyncio.gather(*tasks)
250
279
  progress_bar.close()
251
- return simulated_conversations
280
+ return results
252
281
 
253
282
  async def _extend_conversation_with_simulator(
254
283
  self,
@@ -261,6 +290,7 @@ class Simulator:
261
290
  prompty_model_config: Dict[str, Any],
262
291
  target: Callable,
263
292
  progress_bar: tqdm,
293
+ progress_bar_lock: asyncio.Lock,
264
294
  ):
265
295
  """
266
296
  Extends an ongoing conversation using a user simulator until the maximum number of turns is reached.
@@ -281,6 +311,8 @@ class Simulator:
281
311
  :paramtype target: Callable,
282
312
  :keyword progress_bar: Progress bar for tracking simulation progress.
283
313
  :paramtype progress_bar: tqdm,
314
+ :keyword progress_bar_lock: Lock for updating the progress bar safely.
315
+ :paramtype progress_bar_lock: asyncio.Lock
284
316
  """
285
317
  user_flow = self._load_user_simulation_flow(
286
318
  user_simulator_prompty=user_simulator_prompty, # type: ignore
@@ -291,19 +323,22 @@ class Simulator:
291
323
  while len(current_simulation) < max_conversation_turns:
292
324
  user_response_content = await user_flow(
293
325
  task="Continue the conversation",
294
- conversation_history=current_simulation.to_list(),
326
+ conversation_history=current_simulation.to_context_free_list(),
295
327
  **user_simulator_prompty_kwargs,
296
328
  )
297
329
  user_response = self._parse_prompty_response(response=user_response_content)
298
330
  user_turn = Turn(role=ConversationRole.USER, content=user_response["content"])
299
331
  current_simulation.add_to_history(user_turn)
300
332
  await asyncio.sleep(api_call_delay_sec)
301
- assistant_response = await self._get_target_response(
333
+ assistant_response, assistant_context = await self._get_target_response(
302
334
  target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=current_simulation
303
335
  )
304
- assistant_turn = Turn(role=ConversationRole.ASSISTANT, content=assistant_response)
336
+ assistant_turn = Turn(
337
+ role=ConversationRole.ASSISTANT, content=assistant_response, context=assistant_context
338
+ )
305
339
  current_simulation.add_to_history(assistant_turn)
306
- progress_bar.update(1)
340
+ async with progress_bar_lock:
341
+ progress_bar.update(1)
307
342
 
308
343
  def _load_user_simulation_flow(
309
344
  self,
@@ -432,6 +467,14 @@ class Simulator:
432
467
  if isinstance(query_responses, dict):
433
468
  keys = list(query_responses.keys())
434
469
  return query_responses[keys[0]]
470
+ if isinstance(query_responses, str):
471
+ query_responses = json.loads(query_responses)
472
+ if isinstance(query_responses, dict):
473
+ if len(query_responses.keys()) == 1:
474
+ return query_responses[list(query_responses.keys())[0]]
475
+ return query_responses # type: ignore
476
+ if isinstance(query_responses, list):
477
+ return query_responses
435
478
  return json.loads(query_responses)
436
479
  except Exception as e:
437
480
  raise RuntimeError("Error generating query responses") from e
@@ -497,6 +540,7 @@ class Simulator:
497
540
  user_simulator_prompty_kwargs: Dict[str, Any],
498
541
  target: Callable,
499
542
  api_call_delay_sec: float,
543
+ text: str,
500
544
  ) -> List[JsonLineChatProtocol]:
501
545
  """
502
546
  Creates full conversations from query-response pairs.
@@ -515,6 +559,8 @@ class Simulator:
515
559
  :paramtype target: Callable
516
560
  :keyword api_call_delay_sec: Delay in seconds between API calls.
517
561
  :paramtype api_call_delay_sec: float
562
+ :keyword text: The initial input text for generating query responses.
563
+ :paramtype text: str
518
564
  :return: A list of simulated conversations represented as JsonLineChatProtocol objects.
519
565
  :rtype: List[JsonLineChatProtocol]
520
566
  """
@@ -552,6 +598,7 @@ class Simulator:
552
598
  "task": task,
553
599
  "expected_response": response,
554
600
  "query": query,
601
+ "original_text": text,
555
602
  },
556
603
  "$schema": "http://azureml/sdk-2-0/ChatConversation.json",
557
604
  }
@@ -595,8 +642,6 @@ class Simulator:
595
642
  :rtype: List[Dict[str, Optional[str]]]
596
643
  """
597
644
  conversation_history = ConversationHistory()
598
- # user_turn = Turn(role=ConversationRole.USER, content=conversation_starter)
599
- # conversation_history.add_to_history(user_turn)
600
645
 
601
646
  while len(conversation_history) < max_conversation_turns:
602
647
  user_flow = self._load_user_simulation_flow(
@@ -604,24 +649,33 @@ class Simulator:
604
649
  prompty_model_config=self.model_config, # type: ignore
605
650
  user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
606
651
  )
607
- conversation_starter_from_simulated_user = await user_flow(
608
- task=task,
609
- conversation_history=[
610
- {
611
- "role": "assistant",
612
- "content": conversation_starter,
613
- "your_task": "Act as the user and translate the content into a user query.",
614
- }
615
- ],
616
- )
652
+ if len(conversation_history) == 0:
653
+ conversation_starter_from_simulated_user = await user_flow(
654
+ task=task,
655
+ conversation_history=[
656
+ {
657
+ "role": "assistant",
658
+ "content": conversation_starter,
659
+ }
660
+ ],
661
+ action="rewrite the assistant's message as you have to accomplish the task by asking the right questions. Make sure the original question is not lost in your rewrite.",
662
+ )
663
+ else:
664
+ conversation_starter_from_simulated_user = await user_flow(
665
+ task=task,
666
+ conversation_history=conversation_history.to_context_free_list(),
667
+ action="Your goal is to make sure the task is completed by asking the right questions. Do not ask the same questions again.",
668
+ )
617
669
  if isinstance(conversation_starter_from_simulated_user, dict):
618
670
  conversation_starter_from_simulated_user = conversation_starter_from_simulated_user["content"]
619
671
  user_turn = Turn(role=ConversationRole.USER, content=conversation_starter_from_simulated_user)
620
672
  conversation_history.add_to_history(user_turn)
621
- assistant_response = await self._get_target_response(
673
+ assistant_response, assistant_context = await self._get_target_response(
622
674
  target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=conversation_history
623
675
  )
624
- assistant_turn = Turn(role=ConversationRole.ASSISTANT, content=assistant_response)
676
+ assistant_turn = Turn(
677
+ role=ConversationRole.ASSISTANT, content=assistant_response, context=assistant_context
678
+ )
625
679
  conversation_history.add_to_history(assistant_turn)
626
680
  progress_bar.update(1)
627
681
 
@@ -632,7 +686,7 @@ class Simulator:
632
686
 
633
687
  async def _get_target_response(
634
688
  self, *, target: Callable, api_call_delay_sec: float, conversation_history: ConversationHistory
635
- ) -> str:
689
+ ) -> Tuple[str, Optional[str]]:
636
690
  """
637
691
  Retrieves the response from the target callback based on the current conversation history.
638
692
 
@@ -642,8 +696,8 @@ class Simulator:
642
696
  :paramtype api_call_delay_sec: float
643
697
  :keyword conversation_history: The current conversation history.
644
698
  :paramtype conversation_history: ConversationHistory
645
- :return: The content of the response from the target.
646
- :rtype: str
699
+ :return: The content of the response from the target and an optional context.
700
+ :rtype: str, Optional[str]
647
701
  """
648
702
  response = await target(
649
703
  messages={"messages": conversation_history.to_list()},
@@ -653,4 +707,4 @@ class Simulator:
653
707
  )
654
708
  await asyncio.sleep(api_call_delay_sec)
655
709
  latest_message = response["messages"][-1]
656
- return latest_message["content"]
710
+ return latest_message["content"], latest_message.get("context", "") # type: ignore