azure-ai-evaluation 1.0.0__py3-none-any.whl → 1.0.0b2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of azure-ai-evaluation might be problematic. Click here for more details.
- azure/ai/evaluation/__init__.py +5 -31
- azure/ai/evaluation/_common/constants.py +2 -9
- azure/ai/evaluation/_common/rai_service.py +120 -300
- azure/ai/evaluation/_common/utils.py +23 -381
- azure/ai/evaluation/_constants.py +6 -19
- azure/ai/evaluation/_evaluate/{_batch_run → _batch_run_client}/__init__.py +2 -3
- azure/ai/evaluation/_evaluate/{_batch_run/eval_run_context.py → _batch_run_client/batch_run_context.py} +7 -23
- azure/ai/evaluation/_evaluate/{_batch_run → _batch_run_client}/code_client.py +17 -33
- azure/ai/evaluation/_evaluate/{_batch_run → _batch_run_client}/proxy_client.py +4 -32
- azure/ai/evaluation/_evaluate/_eval_run.py +24 -81
- azure/ai/evaluation/_evaluate/_evaluate.py +239 -393
- azure/ai/evaluation/_evaluate/_telemetry/__init__.py +17 -17
- azure/ai/evaluation/_evaluate/_utils.py +28 -82
- azure/ai/evaluation/_evaluators/_bleu/_bleu.py +18 -17
- azure/ai/evaluation/_evaluators/{_retrieval → _chat}/__init__.py +2 -2
- azure/ai/evaluation/_evaluators/_chat/_chat.py +357 -0
- azure/ai/evaluation/_evaluators/{_service_groundedness → _chat/retrieval}/__init__.py +2 -2
- azure/ai/evaluation/_evaluators/_chat/retrieval/_retrieval.py +157 -0
- azure/ai/evaluation/_evaluators/_chat/retrieval/retrieval.prompty +48 -0
- azure/ai/evaluation/_evaluators/_coherence/_coherence.py +88 -78
- azure/ai/evaluation/_evaluators/_coherence/coherence.prompty +39 -76
- azure/ai/evaluation/_evaluators/_content_safety/__init__.py +4 -0
- azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py +67 -105
- azure/ai/evaluation/_evaluators/{_multimodal/_content_safety_multimodal_base.py → _content_safety/_content_safety_base.py} +34 -24
- azure/ai/evaluation/_evaluators/_content_safety/_content_safety_chat.py +301 -0
- azure/ai/evaluation/_evaluators/_content_safety/_hate_unfairness.py +54 -105
- azure/ai/evaluation/_evaluators/_content_safety/_self_harm.py +52 -99
- azure/ai/evaluation/_evaluators/_content_safety/_sexual.py +52 -101
- azure/ai/evaluation/_evaluators/_content_safety/_violence.py +51 -101
- azure/ai/evaluation/_evaluators/_eci/_eci.py +54 -44
- azure/ai/evaluation/_evaluators/_f1_score/_f1_score.py +19 -34
- azure/ai/evaluation/_evaluators/_fluency/_fluency.py +89 -76
- azure/ai/evaluation/_evaluators/_fluency/fluency.prompty +41 -66
- azure/ai/evaluation/_evaluators/_gleu/_gleu.py +16 -14
- azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py +87 -113
- azure/ai/evaluation/_evaluators/_groundedness/groundedness.prompty +54 -0
- azure/ai/evaluation/_evaluators/_meteor/_meteor.py +27 -20
- azure/ai/evaluation/_evaluators/_protected_material/_protected_material.py +80 -89
- azure/ai/evaluation/_evaluators/_protected_materials/__init__.py +5 -0
- azure/ai/evaluation/_evaluators/_protected_materials/_protected_materials.py +104 -0
- azure/ai/evaluation/_evaluators/_qa/_qa.py +30 -23
- azure/ai/evaluation/_evaluators/_relevance/_relevance.py +96 -84
- azure/ai/evaluation/_evaluators/_relevance/relevance.prompty +47 -78
- azure/ai/evaluation/_evaluators/_rouge/_rouge.py +27 -26
- azure/ai/evaluation/_evaluators/_similarity/_similarity.py +38 -53
- azure/ai/evaluation/_evaluators/_similarity/similarity.prompty +5 -0
- azure/ai/evaluation/_evaluators/_xpia/xpia.py +105 -91
- azure/ai/evaluation/_exceptions.py +7 -28
- azure/ai/evaluation/_http_utils.py +132 -203
- azure/ai/evaluation/_model_configurations.py +8 -104
- azure/ai/evaluation/_version.py +1 -1
- azure/ai/evaluation/simulator/__init__.py +1 -2
- azure/ai/evaluation/simulator/_adversarial_scenario.py +1 -20
- azure/ai/evaluation/simulator/_adversarial_simulator.py +92 -111
- azure/ai/evaluation/simulator/_constants.py +1 -11
- azure/ai/evaluation/simulator/_conversation/__init__.py +12 -13
- azure/ai/evaluation/simulator/_conversation/_conversation.py +4 -4
- azure/ai/evaluation/simulator/_direct_attack_simulator.py +67 -33
- azure/ai/evaluation/simulator/_helpers/__init__.py +2 -1
- azure/ai/evaluation/{_common → simulator/_helpers}/_experimental.py +9 -24
- azure/ai/evaluation/simulator/_helpers/_simulator_data_classes.py +5 -26
- azure/ai/evaluation/simulator/_indirect_attack_simulator.py +94 -107
- azure/ai/evaluation/simulator/_model_tools/_identity_manager.py +22 -70
- azure/ai/evaluation/simulator/_model_tools/_proxy_completion_model.py +11 -28
- azure/ai/evaluation/simulator/_model_tools/_rai_client.py +4 -8
- azure/ai/evaluation/simulator/_model_tools/_template_handler.py +24 -68
- azure/ai/evaluation/simulator/_model_tools/models.py +10 -10
- azure/ai/evaluation/simulator/_prompty/task_query_response.prompty +10 -6
- azure/ai/evaluation/simulator/_prompty/task_simulate.prompty +5 -6
- azure/ai/evaluation/simulator/_simulator.py +207 -277
- azure/ai/evaluation/simulator/_tracing.py +4 -4
- azure/ai/evaluation/simulator/_utils.py +13 -31
- azure_ai_evaluation-1.0.0b2.dist-info/METADATA +449 -0
- azure_ai_evaluation-1.0.0b2.dist-info/RECORD +99 -0
- {azure_ai_evaluation-1.0.0.dist-info → azure_ai_evaluation-1.0.0b2.dist-info}/WHEEL +1 -1
- azure/ai/evaluation/_common/math.py +0 -89
- azure/ai/evaluation/_evaluate/_batch_run/target_run_context.py +0 -46
- azure/ai/evaluation/_evaluators/_common/__init__.py +0 -13
- azure/ai/evaluation/_evaluators/_common/_base_eval.py +0 -344
- azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py +0 -88
- azure/ai/evaluation/_evaluators/_common/_base_rai_svc_eval.py +0 -133
- azure/ai/evaluation/_evaluators/_groundedness/groundedness_with_query.prompty +0 -113
- azure/ai/evaluation/_evaluators/_groundedness/groundedness_without_query.prompty +0 -99
- azure/ai/evaluation/_evaluators/_multimodal/__init__.py +0 -20
- azure/ai/evaluation/_evaluators/_multimodal/_content_safety_multimodal.py +0 -132
- azure/ai/evaluation/_evaluators/_multimodal/_hate_unfairness.py +0 -100
- azure/ai/evaluation/_evaluators/_multimodal/_protected_material.py +0 -124
- azure/ai/evaluation/_evaluators/_multimodal/_self_harm.py +0 -100
- azure/ai/evaluation/_evaluators/_multimodal/_sexual.py +0 -100
- azure/ai/evaluation/_evaluators/_multimodal/_violence.py +0 -100
- azure/ai/evaluation/_evaluators/_retrieval/_retrieval.py +0 -112
- azure/ai/evaluation/_evaluators/_retrieval/retrieval.prompty +0 -93
- azure/ai/evaluation/_evaluators/_service_groundedness/_service_groundedness.py +0 -148
- azure/ai/evaluation/_vendor/__init__.py +0 -3
- azure/ai/evaluation/_vendor/rouge_score/__init__.py +0 -14
- azure/ai/evaluation/_vendor/rouge_score/rouge_scorer.py +0 -328
- azure/ai/evaluation/_vendor/rouge_score/scoring.py +0 -63
- azure/ai/evaluation/_vendor/rouge_score/tokenize.py +0 -63
- azure/ai/evaluation/_vendor/rouge_score/tokenizers.py +0 -53
- azure/ai/evaluation/simulator/_data_sources/__init__.py +0 -3
- azure/ai/evaluation/simulator/_data_sources/grounding.json +0 -1150
- azure_ai_evaluation-1.0.0.dist-info/METADATA +0 -595
- azure_ai_evaluation-1.0.0.dist-info/NOTICE.txt +0 -70
- azure_ai_evaluation-1.0.0.dist-info/RECORD +0 -119
- {azure_ai_evaluation-1.0.0.dist-info → azure_ai_evaluation-1.0.0b2.dist-info}/top_level.txt +0 -0
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
# flake8: noqa
|
|
2
|
-
# pylint: disable=W0102,W0613,R0914,C0301,E0401,E0611
|
|
2
|
+
# pylint: disable=W0102,W0613,R0914,C0301,E0401,E0611
|
|
3
3
|
# ---------------------------------------------------------
|
|
4
4
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
5
5
|
# ---------------------------------------------------------
|
|
@@ -9,19 +9,17 @@ import json
|
|
|
9
9
|
import os
|
|
10
10
|
import re
|
|
11
11
|
import warnings
|
|
12
|
-
from typing import Any, Callable, Dict, List, Optional, Union
|
|
12
|
+
from typing import Any, Callable, Dict, List, Optional, Union
|
|
13
13
|
|
|
14
|
-
from promptflow.
|
|
14
|
+
from promptflow.client import load_flow
|
|
15
|
+
from promptflow.core import AzureOpenAIModelConfiguration, Flow
|
|
15
16
|
from tqdm import tqdm
|
|
16
17
|
|
|
17
|
-
from azure.ai.evaluation._common._experimental import experimental
|
|
18
|
-
from azure.ai.evaluation._common.utils import construct_prompty_model_config
|
|
19
|
-
from azure.ai.evaluation._model_configurations import AzureOpenAIModelConfiguration, OpenAIModelConfiguration
|
|
20
|
-
|
|
21
|
-
from .._exceptions import ErrorBlame, ErrorCategory, EvaluationException
|
|
22
18
|
from .._user_agent import USER_AGENT
|
|
23
19
|
from ._conversation.constants import ConversationRole
|
|
24
|
-
from ._helpers import ConversationHistory, Turn
|
|
20
|
+
from ._helpers import ConversationHistory, Turn, experimental
|
|
21
|
+
|
|
22
|
+
# from ._tracing import monitor_task_simulator
|
|
25
23
|
from ._utils import JsonLineChatProtocol
|
|
26
24
|
|
|
27
25
|
|
|
@@ -29,77 +27,53 @@ from ._utils import JsonLineChatProtocol
|
|
|
29
27
|
class Simulator:
|
|
30
28
|
"""
|
|
31
29
|
Simulator for generating synthetic conversations.
|
|
32
|
-
|
|
33
|
-
:param model_config: A dictionary defining the configuration for the model. Acceptable types are AzureOpenAIModelConfiguration and OpenAIModelConfiguration.
|
|
34
|
-
:type model_config: Union[~azure.ai.evaluation.AzureOpenAIModelConfiguration, ~azure.ai.evaluation.OpenAIModelConfiguration]
|
|
35
|
-
:raises ValueError: If the model_config does not contain the required keys or any value is None.
|
|
36
|
-
|
|
37
|
-
.. admonition:: Example:
|
|
38
|
-
|
|
39
|
-
.. literalinclude:: ../samples/evaluation_samples_simulate.py
|
|
40
|
-
:start-after: [START nonadversarial_simulator]
|
|
41
|
-
:end-before: [END nonadversarial_simulator]
|
|
42
|
-
:language: python
|
|
43
|
-
:dedent: 8
|
|
44
|
-
:caption: Run a Simulator for 2 queries and 4 conversation turns.
|
|
45
30
|
"""
|
|
46
31
|
|
|
47
|
-
def __init__(self,
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
32
|
+
def __init__(self, azure_ai_project: Dict[str, Any], credential: Optional[Any] = None):
|
|
33
|
+
"""
|
|
34
|
+
Initializes the task simulator with a project scope.
|
|
35
|
+
|
|
36
|
+
:param azure_ai_project: A dictionary defining the scope of the project, including keys such as
|
|
37
|
+
"subscription_id", "resource_group_name", and "project_name".
|
|
38
|
+
:param credential: Azure credentials to authenticate the user. If None, the default credentials are used.
|
|
39
|
+
:paramtype credential: Optional[Any]
|
|
40
|
+
:raises ValueError: If the azure_ai_project does not contain the required keys or any value is None.
|
|
41
|
+
"""
|
|
42
|
+
self._validate_project_config(azure_ai_project)
|
|
43
|
+
self.azure_ai_project = azure_ai_project
|
|
44
|
+
self.azure_ai_project["api_version"] = "2024-02-15-preview"
|
|
45
|
+
self.credential = credential
|
|
52
46
|
|
|
53
47
|
@staticmethod
|
|
54
|
-
def
|
|
48
|
+
def _validate_project_config(azure_ai_project: Dict[str, Any]):
|
|
55
49
|
"""
|
|
56
|
-
Validates the
|
|
57
|
-
If 'type' is not specified, it will attempt to infer the type based on the keys present.
|
|
50
|
+
Validates the azure_ai_project configuration to ensure all required keys are present and have non-None values.
|
|
58
51
|
|
|
59
|
-
:param
|
|
60
|
-
:type
|
|
52
|
+
:param azure_ai_project: The Azure AI project configuration dictionary.
|
|
53
|
+
:type azure_ai_project: Dict[str, Any]
|
|
61
54
|
:raises ValueError: If required keys are missing or any of the values are None.
|
|
62
55
|
"""
|
|
63
|
-
|
|
64
|
-
if
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
model_config["type"] = "openai"
|
|
69
|
-
else:
|
|
70
|
-
raise ValueError(
|
|
71
|
-
"Unable to infer 'type' from model_config. Please specify 'type' as 'azure_openai' or 'openai'."
|
|
72
|
-
)
|
|
73
|
-
|
|
74
|
-
if model_config["type"] == "azure_openai":
|
|
75
|
-
required_keys = ["azure_deployment", "azure_endpoint"]
|
|
76
|
-
elif model_config["type"] == "openai":
|
|
77
|
-
required_keys = ["api_key", "model"]
|
|
78
|
-
else:
|
|
79
|
-
raise ValueError("model_config 'type' must be 'azure_openai' or 'openai'.")
|
|
80
|
-
|
|
81
|
-
missing_keys = [key for key in required_keys if key not in model_config]
|
|
82
|
-
if missing_keys:
|
|
83
|
-
raise ValueError(f"model_config is missing required keys: {', '.join(missing_keys)}")
|
|
84
|
-
none_keys = [key for key in required_keys if model_config.get(key) is None]
|
|
85
|
-
if none_keys:
|
|
86
|
-
raise ValueError(f"The following keys in model_config must not be None: {', '.join(none_keys)}")
|
|
56
|
+
required_keys = ["subscription_id", "resource_group_name", "project_name"]
|
|
57
|
+
if not all(key in azure_ai_project for key in required_keys):
|
|
58
|
+
raise ValueError(f"azure_ai_project must contain keys: {', '.join(required_keys)}")
|
|
59
|
+
if not all(azure_ai_project[key] for key in required_keys):
|
|
60
|
+
raise ValueError("subscription_id, resource_group_name, and project_name must not be None")
|
|
87
61
|
|
|
62
|
+
# @monitor_task_simulator
|
|
88
63
|
async def __call__(
|
|
89
64
|
self,
|
|
90
65
|
*,
|
|
91
66
|
target: Callable,
|
|
92
67
|
max_conversation_turns: int = 5,
|
|
93
|
-
tasks: List[
|
|
68
|
+
tasks: List[Dict] = [],
|
|
94
69
|
text: str = "",
|
|
95
70
|
num_queries: int = 5,
|
|
96
71
|
query_response_generating_prompty: Optional[str] = None,
|
|
97
72
|
user_simulator_prompty: Optional[str] = None,
|
|
98
73
|
api_call_delay_sec: float = 1,
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
conversation_turns: List[List[
|
|
102
|
-
concurrent_async_tasks: int = 5,
|
|
74
|
+
query_response_generating_prompty_kwargs: Dict[str, Any] = {},
|
|
75
|
+
user_simulator_prompty_kwargs: Dict[str, Any] = {},
|
|
76
|
+
conversation_turns: List[List[str]] = [],
|
|
103
77
|
**kwargs,
|
|
104
78
|
) -> List[JsonLineChatProtocol]:
|
|
105
79
|
"""
|
|
@@ -121,15 +95,12 @@ class Simulator:
|
|
|
121
95
|
:paramtype user_simulator_prompty: Optional[str]
|
|
122
96
|
:keyword api_call_delay_sec: Delay in seconds between API calls.
|
|
123
97
|
:paramtype api_call_delay_sec: float
|
|
124
|
-
:keyword
|
|
125
|
-
:paramtype
|
|
126
|
-
:keyword
|
|
127
|
-
:paramtype
|
|
98
|
+
:keyword query_response_generating_prompty_kwargs: Additional keyword arguments for the query response generating prompty.
|
|
99
|
+
:paramtype query_response_generating_prompty_kwargs: Dict[str, Any]
|
|
100
|
+
:keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
|
|
101
|
+
:paramtype user_simulator_prompty_kwargs: Dict[str, Any]
|
|
128
102
|
:keyword conversation_turns: Predefined conversation turns to simulate.
|
|
129
|
-
:paramtype conversation_turns: List[List[
|
|
130
|
-
:keyword concurrent_async_tasks: The number of asynchronous tasks to run concurrently during the simulation.
|
|
131
|
-
Defaults to 5.
|
|
132
|
-
:paramtype concurrent_async_tasks: int
|
|
103
|
+
:paramtype conversation_turns: List[List[str]]
|
|
133
104
|
:return: A list of simulated conversations represented as JsonLineChatProtocol objects.
|
|
134
105
|
:rtype: List[JsonLineChatProtocol]
|
|
135
106
|
|
|
@@ -138,18 +109,18 @@ class Simulator:
|
|
|
138
109
|
|
|
139
110
|
Modes:
|
|
140
111
|
- Task-Free Mode: When only num_queries is specified and tasks is not, the method generates num_queries x max_conversation_turns lines of simulated data grounded in the context of the text.
|
|
141
|
-
- Task-Specific Mode: When both num_queries and tasks are specified, the method generates lines of simulated data based on the tasks. If num_queries > len(tasks), the remaining lines
|
|
112
|
+
- Task-Specific Mode: When both num_queries and tasks are specified, the method generates lines of simulated data based on the tasks. If num_queries > len(tasks), the remaining lines are simulated in task-free mode. If num_queries < len(tasks), only the first num_queries tasks are used.
|
|
142
113
|
- Conversation Starter Mode: When conversation_turns are specified, the method starts each conversation with the user-specified queries and then follows the conversation history for the remaining turns.
|
|
143
114
|
"""
|
|
144
115
|
if conversation_turns and (text or tasks):
|
|
145
116
|
raise ValueError("Cannot specify both conversation_turns and text/tasks")
|
|
146
117
|
|
|
147
|
-
if
|
|
118
|
+
if num_queries > len(tasks):
|
|
148
119
|
warnings.warn(
|
|
149
120
|
f"You have specified 'num_queries' > len('tasks') ({num_queries} > {len(tasks)}). "
|
|
150
121
|
f"All tasks will be used for generation and the remaining {num_queries - len(tasks)} lines will be simulated in task-free mode"
|
|
151
122
|
)
|
|
152
|
-
elif
|
|
123
|
+
elif num_queries < len(tasks):
|
|
153
124
|
warnings.warn(
|
|
154
125
|
f"You have specified 'num_queries' < len('tasks') ({num_queries} < {len(tasks)}). "
|
|
155
126
|
f"Only the first {num_queries} lines of the specified tasks will be simulated."
|
|
@@ -157,49 +128,60 @@ class Simulator:
|
|
|
157
128
|
num_queries = min(num_queries, len(tasks))
|
|
158
129
|
max_conversation_turns *= 2 # account for both user and assistant turns
|
|
159
130
|
|
|
160
|
-
prompty_model_config = self.
|
|
131
|
+
prompty_model_config = self._build_prompty_model_config()
|
|
132
|
+
|
|
161
133
|
if conversation_turns:
|
|
162
134
|
return await self._simulate_with_predefined_turns(
|
|
163
135
|
target=target,
|
|
164
136
|
max_conversation_turns=max_conversation_turns,
|
|
165
137
|
conversation_turns=conversation_turns,
|
|
166
138
|
user_simulator_prompty=user_simulator_prompty,
|
|
167
|
-
|
|
139
|
+
user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
|
|
168
140
|
api_call_delay_sec=api_call_delay_sec,
|
|
169
141
|
prompty_model_config=prompty_model_config,
|
|
170
|
-
concurrent_async_tasks=concurrent_async_tasks,
|
|
171
142
|
)
|
|
172
143
|
|
|
173
144
|
query_responses = await self._generate_query_responses(
|
|
174
145
|
text=text,
|
|
175
146
|
num_queries=num_queries,
|
|
176
147
|
query_response_generating_prompty=query_response_generating_prompty,
|
|
177
|
-
|
|
148
|
+
query_response_generating_prompty_kwargs=query_response_generating_prompty_kwargs,
|
|
178
149
|
prompty_model_config=prompty_model_config,
|
|
179
150
|
**kwargs,
|
|
180
151
|
)
|
|
152
|
+
|
|
181
153
|
return await self._create_conversations_from_query_responses(
|
|
182
154
|
query_responses=query_responses,
|
|
183
155
|
max_conversation_turns=max_conversation_turns,
|
|
184
156
|
tasks=tasks,
|
|
185
157
|
user_simulator_prompty=user_simulator_prompty,
|
|
186
|
-
|
|
158
|
+
user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
|
|
187
159
|
target=target,
|
|
188
160
|
api_call_delay_sec=api_call_delay_sec,
|
|
189
|
-
text=text,
|
|
190
161
|
)
|
|
191
162
|
|
|
163
|
+
def _build_prompty_model_config(self) -> Dict[str, Any]:
|
|
164
|
+
"""
|
|
165
|
+
Constructs the configuration for the prompty model.
|
|
166
|
+
|
|
167
|
+
:return: A dictionary containing the prompty model configuration, including API version and user agent headers if applicable.
|
|
168
|
+
:rtype: Dict[str, Any]
|
|
169
|
+
"""
|
|
170
|
+
config = {"configuration": self.azure_ai_project}
|
|
171
|
+
if USER_AGENT and isinstance(self.azure_ai_project, AzureOpenAIModelConfiguration):
|
|
172
|
+
config.update({"parameters": {"extra_headers": {"x-ms-useragent": USER_AGENT}}})
|
|
173
|
+
return config
|
|
174
|
+
|
|
192
175
|
async def _simulate_with_predefined_turns(
|
|
193
176
|
self,
|
|
194
177
|
*,
|
|
195
178
|
target: Callable,
|
|
196
179
|
max_conversation_turns: int,
|
|
197
|
-
conversation_turns: List[List[
|
|
180
|
+
conversation_turns: List[List[str]],
|
|
198
181
|
user_simulator_prompty: Optional[str],
|
|
199
|
-
|
|
182
|
+
user_simulator_prompty_kwargs: Dict[str, Any],
|
|
200
183
|
api_call_delay_sec: float,
|
|
201
|
-
prompty_model_config: Any,
|
|
202
|
-
concurrent_async_tasks: int,
|
|
184
|
+
prompty_model_config: Dict[str, Any],
|
|
203
185
|
) -> List[JsonLineChatProtocol]:
|
|
204
186
|
"""
|
|
205
187
|
Simulates conversations using predefined conversation turns.
|
|
@@ -209,81 +191,54 @@ class Simulator:
|
|
|
209
191
|
:keyword max_conversation_turns: Maximum number of turns for the simulation.
|
|
210
192
|
:paramtype max_conversation_turns: int
|
|
211
193
|
:keyword conversation_turns: A list of predefined conversation turns.
|
|
212
|
-
:paramtype conversation_turns: List[List[
|
|
194
|
+
:paramtype conversation_turns: List[List[str]]
|
|
213
195
|
:keyword user_simulator_prompty: Path to the user simulator prompty file.
|
|
214
196
|
:paramtype user_simulator_prompty: Optional[str]
|
|
215
|
-
:keyword
|
|
216
|
-
:paramtype
|
|
197
|
+
:keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
|
|
198
|
+
:paramtype user_simulator_prompty_kwargs: Dict[str, Any]
|
|
217
199
|
:keyword api_call_delay_sec: Delay in seconds between API calls.
|
|
218
200
|
:paramtype api_call_delay_sec: float
|
|
219
201
|
:keyword prompty_model_config: The configuration for the prompty model.
|
|
220
|
-
:paramtype prompty_model_config: Any
|
|
221
|
-
:keyword concurrent_async_tasks: The number of asynchronous tasks to run concurrently during the simulation.
|
|
222
|
-
:paramtype concurrent_async_tasks: int
|
|
202
|
+
:paramtype prompty_model_config: Dict[str, Any]
|
|
223
203
|
:return: A list of simulated conversations represented as JsonLineChatProtocol objects.
|
|
224
204
|
:rtype: List[JsonLineChatProtocol]
|
|
225
205
|
"""
|
|
206
|
+
simulated_conversations = []
|
|
226
207
|
progress_bar = tqdm(
|
|
227
208
|
total=int(len(conversation_turns) * (max_conversation_turns / 2)),
|
|
228
209
|
desc="Simulating with predefined conversation turns: ",
|
|
229
210
|
ncols=100,
|
|
230
211
|
unit="messages",
|
|
231
212
|
)
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
current_simulation
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
user_turn = Turn(role=ConversationRole.USER, content=simulated_turn)
|
|
241
|
-
elif isinstance(simulated_turn, dict):
|
|
242
|
-
user_turn = Turn(
|
|
243
|
-
role=ConversationRole.USER,
|
|
244
|
-
content=str(simulated_turn.get("content")),
|
|
245
|
-
context=str(simulated_turn.get("context")),
|
|
246
|
-
)
|
|
247
|
-
else:
|
|
248
|
-
raise ValueError(
|
|
249
|
-
"Each simulated turn must be a string or a dict with 'content' and 'context' keys"
|
|
250
|
-
)
|
|
251
|
-
current_simulation.add_to_history(user_turn)
|
|
252
|
-
assistant_response, assistant_context = await self._get_target_response(
|
|
253
|
-
target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=current_simulation
|
|
254
|
-
)
|
|
255
|
-
assistant_turn = Turn(
|
|
256
|
-
role=ConversationRole.ASSISTANT, content=assistant_response, context=assistant_context
|
|
257
|
-
)
|
|
258
|
-
current_simulation.add_to_history(assistant_turn)
|
|
259
|
-
async with progress_bar_lock:
|
|
260
|
-
progress_bar.update(1)
|
|
261
|
-
|
|
262
|
-
if len(current_simulation) < max_conversation_turns:
|
|
263
|
-
await self._extend_conversation_with_simulator(
|
|
264
|
-
current_simulation=current_simulation,
|
|
265
|
-
max_conversation_turns=max_conversation_turns,
|
|
266
|
-
user_simulator_prompty=user_simulator_prompty,
|
|
267
|
-
user_simulator_prompty_options=user_simulator_prompty_options,
|
|
268
|
-
api_call_delay_sec=api_call_delay_sec,
|
|
269
|
-
prompty_model_config=prompty_model_config,
|
|
270
|
-
target=target,
|
|
271
|
-
progress_bar=progress_bar,
|
|
272
|
-
progress_bar_lock=progress_bar_lock,
|
|
273
|
-
)
|
|
274
|
-
return JsonLineChatProtocol(
|
|
275
|
-
{
|
|
276
|
-
"messages": current_simulation.to_list(),
|
|
277
|
-
"finish_reason": ["stop"],
|
|
278
|
-
"context": {},
|
|
279
|
-
"$schema": "http://azureml/sdk-2-0/ChatConversation.json",
|
|
280
|
-
}
|
|
213
|
+
|
|
214
|
+
for simulation in conversation_turns:
|
|
215
|
+
current_simulation = ConversationHistory()
|
|
216
|
+
for simulated_turn in simulation:
|
|
217
|
+
user_turn = Turn(role=ConversationRole.USER, content=simulated_turn)
|
|
218
|
+
current_simulation.add_to_history(user_turn)
|
|
219
|
+
assistant_response = await self._get_target_response(
|
|
220
|
+
target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=current_simulation
|
|
281
221
|
)
|
|
222
|
+
assistant_turn = Turn(role=ConversationRole.ASSISTANT, content=assistant_response)
|
|
223
|
+
current_simulation.add_to_history(assistant_turn)
|
|
224
|
+
progress_bar.update(1) # Update progress bar for both user and assistant turns
|
|
225
|
+
|
|
226
|
+
if len(current_simulation) < max_conversation_turns:
|
|
227
|
+
await self._extend_conversation_with_simulator(
|
|
228
|
+
current_simulation=current_simulation,
|
|
229
|
+
max_conversation_turns=max_conversation_turns,
|
|
230
|
+
user_simulator_prompty=user_simulator_prompty,
|
|
231
|
+
user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
|
|
232
|
+
api_call_delay_sec=api_call_delay_sec,
|
|
233
|
+
prompty_model_config=prompty_model_config,
|
|
234
|
+
target=target,
|
|
235
|
+
progress_bar=progress_bar,
|
|
236
|
+
)
|
|
237
|
+
|
|
238
|
+
simulated_conversations.append(current_simulation.to_list())
|
|
282
239
|
|
|
283
|
-
tasks = [asyncio.create_task(run_simulation(simulation)) for simulation in conversation_turns]
|
|
284
|
-
results = await asyncio.gather(*tasks)
|
|
285
240
|
progress_bar.close()
|
|
286
|
-
return
|
|
241
|
+
return simulated_conversations
|
|
287
242
|
|
|
288
243
|
async def _extend_conversation_with_simulator(
|
|
289
244
|
self,
|
|
@@ -291,12 +246,11 @@ class Simulator:
|
|
|
291
246
|
current_simulation: ConversationHistory,
|
|
292
247
|
max_conversation_turns: int,
|
|
293
248
|
user_simulator_prompty: Optional[str],
|
|
294
|
-
|
|
249
|
+
user_simulator_prompty_kwargs: Dict[str, Any],
|
|
295
250
|
api_call_delay_sec: float,
|
|
296
251
|
prompty_model_config: Dict[str, Any],
|
|
297
252
|
target: Callable,
|
|
298
253
|
progress_bar: tqdm,
|
|
299
|
-
progress_bar_lock: asyncio.Lock,
|
|
300
254
|
):
|
|
301
255
|
"""
|
|
302
256
|
Extends an ongoing conversation using a user simulator until the maximum number of turns is reached.
|
|
@@ -307,8 +261,8 @@ class Simulator:
|
|
|
307
261
|
:paramtype max_conversation_turns: int,
|
|
308
262
|
:keyword user_simulator_prompty: Path to the user simulator prompty file.
|
|
309
263
|
:paramtype user_simulator_prompty: Optional[str],
|
|
310
|
-
:keyword
|
|
311
|
-
:paramtype
|
|
264
|
+
:keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
|
|
265
|
+
:paramtype user_simulator_prompty_kwargs: Dict[str, Any],
|
|
312
266
|
:keyword api_call_delay_sec: Delay in seconds between API calls.
|
|
313
267
|
:paramtype api_call_delay_sec: float,
|
|
314
268
|
:keyword prompty_model_config: The configuration for the prompty model.
|
|
@@ -317,92 +271,68 @@ class Simulator:
|
|
|
317
271
|
:paramtype target: Callable,
|
|
318
272
|
:keyword progress_bar: Progress bar for tracking simulation progress.
|
|
319
273
|
:paramtype progress_bar: tqdm,
|
|
320
|
-
:keyword progress_bar_lock: Lock for updating the progress bar safely.
|
|
321
|
-
:paramtype progress_bar_lock: asyncio.Lock
|
|
322
274
|
"""
|
|
323
275
|
user_flow = self._load_user_simulation_flow(
|
|
324
|
-
user_simulator_prompty=user_simulator_prompty,
|
|
276
|
+
user_simulator_prompty=user_simulator_prompty,
|
|
325
277
|
prompty_model_config=prompty_model_config,
|
|
326
|
-
|
|
278
|
+
user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
|
|
327
279
|
)
|
|
328
280
|
|
|
329
281
|
while len(current_simulation) < max_conversation_turns:
|
|
330
|
-
user_response_content =
|
|
331
|
-
task="Continue the conversation",
|
|
332
|
-
conversation_history=current_simulation.to_context_free_list(),
|
|
333
|
-
**user_simulator_prompty_options,
|
|
282
|
+
user_response_content = user_flow(
|
|
283
|
+
task="Continue the conversation", conversation_history=current_simulation.to_list()
|
|
334
284
|
)
|
|
335
285
|
user_response = self._parse_prompty_response(response=user_response_content)
|
|
336
286
|
user_turn = Turn(role=ConversationRole.USER, content=user_response["content"])
|
|
337
287
|
current_simulation.add_to_history(user_turn)
|
|
338
288
|
await asyncio.sleep(api_call_delay_sec)
|
|
339
|
-
assistant_response
|
|
289
|
+
assistant_response = await self._get_target_response(
|
|
340
290
|
target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=current_simulation
|
|
341
291
|
)
|
|
342
|
-
assistant_turn = Turn(
|
|
343
|
-
role=ConversationRole.ASSISTANT, content=assistant_response, context=assistant_context
|
|
344
|
-
)
|
|
292
|
+
assistant_turn = Turn(role=ConversationRole.ASSISTANT, content=assistant_response)
|
|
345
293
|
current_simulation.add_to_history(assistant_turn)
|
|
346
|
-
|
|
347
|
-
progress_bar.update(1)
|
|
294
|
+
progress_bar.update(1)
|
|
348
295
|
|
|
349
296
|
def _load_user_simulation_flow(
|
|
350
297
|
self,
|
|
351
298
|
*,
|
|
352
|
-
user_simulator_prompty:
|
|
299
|
+
user_simulator_prompty: Union[str, os.PathLike],
|
|
353
300
|
prompty_model_config: Dict[str, Any],
|
|
354
|
-
|
|
355
|
-
) ->
|
|
301
|
+
user_simulator_prompty_kwargs: Dict[str, Any],
|
|
302
|
+
) -> Flow:
|
|
356
303
|
"""
|
|
357
304
|
Loads the flow for simulating user interactions.
|
|
358
305
|
|
|
359
306
|
:keyword user_simulator_prompty: Path to the user simulator prompty file.
|
|
360
|
-
:paramtype user_simulator_prompty:
|
|
307
|
+
:paramtype user_simulator_prompty: Union[str, os.PathLike]
|
|
361
308
|
:keyword prompty_model_config: The configuration for the prompty model.
|
|
362
309
|
:paramtype prompty_model_config: Dict[str, Any]
|
|
363
|
-
:keyword
|
|
364
|
-
:paramtype
|
|
310
|
+
:keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
|
|
311
|
+
:paramtype user_simulator_prompty_kwargs: Dict[str, Any]
|
|
365
312
|
:return: The loaded flow for simulating user interactions.
|
|
366
|
-
:rtype:
|
|
313
|
+
:rtype: Flow
|
|
367
314
|
"""
|
|
368
315
|
if not user_simulator_prompty:
|
|
369
316
|
package = "azure.ai.evaluation.simulator._prompty"
|
|
370
317
|
resource_name = "task_simulate.prompty"
|
|
371
318
|
try:
|
|
372
319
|
# Access the resource as a file path
|
|
373
|
-
# pylint: disable=deprecated-method
|
|
374
320
|
with pkg_resources.path(package, resource_name) as prompty_path:
|
|
375
|
-
|
|
376
|
-
model_config=prompty_model_config, # type: ignore
|
|
377
|
-
default_api_version="2024-06-01",
|
|
378
|
-
user_agent=USER_AGENT,
|
|
379
|
-
)
|
|
380
|
-
return AsyncPrompty.load(source=prompty_path, model=prompty_model_config) # type: ignore
|
|
321
|
+
return load_flow(source=str(prompty_path), model=prompty_model_config)
|
|
381
322
|
except FileNotFoundError as e:
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
message=msg,
|
|
385
|
-
internal_message=msg,
|
|
386
|
-
error_category=ErrorCategory.FILE_OR_FOLDER_NOT_FOUND,
|
|
387
|
-
blame=ErrorBlame.USER_ERROR,
|
|
388
|
-
) from e
|
|
389
|
-
prompty_model_config = construct_prompty_model_config(
|
|
390
|
-
model_config=prompty_model_config, # type: ignore
|
|
391
|
-
default_api_version="2024-06-01",
|
|
392
|
-
user_agent=USER_AGENT,
|
|
393
|
-
)
|
|
394
|
-
return AsyncPrompty.load(
|
|
323
|
+
raise f"Flow path for {resource_name} does not exist in package {package}." from e
|
|
324
|
+
return load_flow(
|
|
395
325
|
source=user_simulator_prompty,
|
|
396
326
|
model=prompty_model_config,
|
|
397
|
-
**
|
|
398
|
-
)
|
|
327
|
+
**user_simulator_prompty_kwargs,
|
|
328
|
+
)
|
|
399
329
|
|
|
400
330
|
def _parse_prompty_response(self, *, response: str) -> Dict[str, Any]:
|
|
401
331
|
"""
|
|
402
332
|
Parses the response from the prompty execution.
|
|
403
333
|
|
|
404
334
|
:keyword response: The raw response from the prompty.
|
|
405
|
-
:paramtype
|
|
335
|
+
:paramtype str: str
|
|
406
336
|
:return: A dictionary representing the parsed response content.
|
|
407
337
|
:rtype: Dict[str, Any]
|
|
408
338
|
:raises ValueError: If the response cannot be parsed.
|
|
@@ -442,8 +372,8 @@ class Simulator:
|
|
|
442
372
|
text: str,
|
|
443
373
|
num_queries: int,
|
|
444
374
|
query_response_generating_prompty: Optional[str],
|
|
445
|
-
|
|
446
|
-
prompty_model_config: Any,
|
|
375
|
+
query_response_generating_prompty_kwargs: Dict[str, Any],
|
|
376
|
+
prompty_model_config: Dict[str, Any],
|
|
447
377
|
**kwargs,
|
|
448
378
|
) -> List[Dict[str, str]]:
|
|
449
379
|
"""
|
|
@@ -455,32 +385,25 @@ class Simulator:
|
|
|
455
385
|
:paramtype num_queries: int
|
|
456
386
|
:keyword query_response_generating_prompty: Path to the query response generating prompty file.
|
|
457
387
|
:paramtype query_response_generating_prompty: Optional[str]
|
|
458
|
-
:keyword
|
|
459
|
-
:paramtype
|
|
388
|
+
:keyword query_response_generating_prompty_kwargs: Additional keyword arguments for the query response generating prompty.
|
|
389
|
+
:paramtype query_response_generating_prompty_kwargs: Dict[str, Any]
|
|
460
390
|
:keyword prompty_model_config: The configuration for the prompty model.
|
|
461
|
-
:paramtype prompty_model_config: Any
|
|
391
|
+
:paramtype prompty_model_config: Dict[str, Any]
|
|
462
392
|
:return: A list of query-response dictionaries.
|
|
463
393
|
:rtype: List[Dict[str, str]]
|
|
464
394
|
:raises RuntimeError: If an error occurs during query generation.
|
|
465
395
|
"""
|
|
466
396
|
query_flow = self._load_query_generation_flow(
|
|
467
|
-
query_response_generating_prompty=query_response_generating_prompty,
|
|
397
|
+
query_response_generating_prompty=query_response_generating_prompty,
|
|
468
398
|
prompty_model_config=prompty_model_config,
|
|
469
|
-
|
|
399
|
+
query_response_generating_prompty_kwargs=query_response_generating_prompty_kwargs,
|
|
470
400
|
)
|
|
401
|
+
|
|
471
402
|
try:
|
|
472
|
-
query_responses =
|
|
403
|
+
query_responses = query_flow(text=text, num_queries=num_queries)
|
|
473
404
|
if isinstance(query_responses, dict):
|
|
474
405
|
keys = list(query_responses.keys())
|
|
475
406
|
return query_responses[keys[0]]
|
|
476
|
-
if isinstance(query_responses, str):
|
|
477
|
-
query_responses = json.loads(query_responses)
|
|
478
|
-
if isinstance(query_responses, dict):
|
|
479
|
-
if len(query_responses.keys()) == 1:
|
|
480
|
-
return query_responses[list(query_responses.keys())[0]]
|
|
481
|
-
return query_responses # type: ignore
|
|
482
|
-
if isinstance(query_responses, list):
|
|
483
|
-
return query_responses
|
|
484
407
|
return json.loads(query_responses)
|
|
485
408
|
except Exception as e:
|
|
486
409
|
raise RuntimeError("Error generating query responses") from e
|
|
@@ -488,65 +411,47 @@ class Simulator:
|
|
|
488
411
|
def _load_query_generation_flow(
|
|
489
412
|
self,
|
|
490
413
|
*,
|
|
491
|
-
query_response_generating_prompty:
|
|
414
|
+
query_response_generating_prompty: Union[str, os.PathLike],
|
|
492
415
|
prompty_model_config: Dict[str, Any],
|
|
493
|
-
|
|
494
|
-
) ->
|
|
416
|
+
query_response_generating_prompty_kwargs: Dict[str, Any],
|
|
417
|
+
) -> Flow:
|
|
495
418
|
"""
|
|
496
419
|
Loads the flow for generating query responses.
|
|
497
420
|
|
|
498
421
|
:keyword query_response_generating_prompty: Path to the query response generating prompty file.
|
|
499
|
-
:paramtype query_response_generating_prompty:
|
|
422
|
+
:paramtype query_response_generating_prompty: Union[str, os.PathLike]
|
|
500
423
|
:keyword prompty_model_config: The configuration for the prompty model.
|
|
501
424
|
:paramtype prompty_model_config: Dict[str, Any]
|
|
502
|
-
:keyword
|
|
503
|
-
:paramtype
|
|
425
|
+
:keyword query_response_generating_prompty_kwargs: Additional keyword arguments for the flow.
|
|
426
|
+
:paramtype query_response_generating_prompty_kwargs: Dict[str, Any]
|
|
504
427
|
:return: The loaded flow for generating query responses.
|
|
505
|
-
:rtype:
|
|
428
|
+
:rtype: Flow
|
|
506
429
|
"""
|
|
507
430
|
if not query_response_generating_prompty:
|
|
508
431
|
package = "azure.ai.evaluation.simulator._prompty"
|
|
509
432
|
resource_name = "task_query_response.prompty"
|
|
510
433
|
try:
|
|
511
434
|
# Access the resource as a file path
|
|
512
|
-
# pylint: disable=deprecated-method
|
|
513
435
|
with pkg_resources.path(package, resource_name) as prompty_path:
|
|
514
|
-
|
|
515
|
-
model_config=prompty_model_config, # type: ignore
|
|
516
|
-
default_api_version="2024-06-01",
|
|
517
|
-
user_agent=USER_AGENT,
|
|
518
|
-
)
|
|
519
|
-
return AsyncPrompty.load(source=prompty_path, model=prompty_model_config) # type: ignore
|
|
436
|
+
return load_flow(source=str(prompty_path), model=prompty_model_config)
|
|
520
437
|
except FileNotFoundError as e:
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
message=msg,
|
|
524
|
-
internal_message=msg,
|
|
525
|
-
error_category=ErrorCategory.FILE_OR_FOLDER_NOT_FOUND,
|
|
526
|
-
blame=ErrorBlame.USER_ERROR,
|
|
527
|
-
) from e
|
|
528
|
-
prompty_model_config = construct_prompty_model_config(
|
|
529
|
-
model_config=prompty_model_config, # type: ignore
|
|
530
|
-
default_api_version="2024-06-01",
|
|
531
|
-
user_agent=USER_AGENT,
|
|
532
|
-
)
|
|
533
|
-
return AsyncPrompty.load(
|
|
438
|
+
raise f"Flow path for {resource_name} does not exist in package {package}." from e
|
|
439
|
+
return load_flow(
|
|
534
440
|
source=query_response_generating_prompty,
|
|
535
441
|
model=prompty_model_config,
|
|
536
|
-
**
|
|
537
|
-
)
|
|
442
|
+
**query_response_generating_prompty_kwargs,
|
|
443
|
+
)
|
|
538
444
|
|
|
539
445
|
async def _create_conversations_from_query_responses(
|
|
540
446
|
self,
|
|
541
447
|
*,
|
|
542
448
|
query_responses: List[Dict[str, str]],
|
|
543
449
|
max_conversation_turns: int,
|
|
544
|
-
tasks: List[
|
|
450
|
+
tasks: List[Dict],
|
|
545
451
|
user_simulator_prompty: Optional[str],
|
|
546
|
-
|
|
452
|
+
user_simulator_prompty_kwargs: Dict[str, Any],
|
|
547
453
|
target: Callable,
|
|
548
454
|
api_call_delay_sec: float,
|
|
549
|
-
text: str,
|
|
550
455
|
) -> List[JsonLineChatProtocol]:
|
|
551
456
|
"""
|
|
552
457
|
Creates full conversations from query-response pairs.
|
|
@@ -556,17 +461,15 @@ class Simulator:
|
|
|
556
461
|
:keyword max_conversation_turns: The maximum number of conversation turns.
|
|
557
462
|
:paramtype max_conversation_turns: int
|
|
558
463
|
:keyword tasks: A list of tasks for the simulation.
|
|
559
|
-
:paramtype tasks: List[
|
|
464
|
+
:paramtype tasks: List[Dict]
|
|
560
465
|
:keyword user_simulator_prompty: Path to the user simulator prompty file.
|
|
561
466
|
:paramtype user_simulator_prompty: Optional[str]
|
|
562
|
-
:keyword
|
|
563
|
-
:paramtype
|
|
467
|
+
:keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
|
|
468
|
+
:paramtype user_simulator_prompty_kwargs: Dict[str, Any]
|
|
564
469
|
:keyword target: The target function to call for responses.
|
|
565
470
|
:paramtype target: Callable
|
|
566
471
|
:keyword api_call_delay_sec: Delay in seconds between API calls.
|
|
567
472
|
:paramtype api_call_delay_sec: float
|
|
568
|
-
:keyword text: The initial input text for generating query responses.
|
|
569
|
-
:paramtype text: str
|
|
570
473
|
:return: A list of simulated conversations represented as JsonLineChatProtocol objects.
|
|
571
474
|
:rtype: List[JsonLineChatProtocol]
|
|
572
475
|
"""
|
|
@@ -588,9 +491,9 @@ class Simulator:
|
|
|
588
491
|
conversation = await self._complete_conversation(
|
|
589
492
|
conversation_starter=query,
|
|
590
493
|
max_conversation_turns=max_conversation_turns,
|
|
591
|
-
task=task,
|
|
494
|
+
task=task,
|
|
592
495
|
user_simulator_prompty=user_simulator_prompty,
|
|
593
|
-
|
|
496
|
+
user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
|
|
594
497
|
target=target,
|
|
595
498
|
api_call_delay_sec=api_call_delay_sec,
|
|
596
499
|
progress_bar=progress_bar,
|
|
@@ -604,7 +507,6 @@ class Simulator:
|
|
|
604
507
|
"task": task,
|
|
605
508
|
"expected_response": response,
|
|
606
509
|
"query": query,
|
|
607
|
-
"original_text": text,
|
|
608
510
|
},
|
|
609
511
|
"$schema": "http://azureml/sdk-2-0/ChatConversation.json",
|
|
610
512
|
}
|
|
@@ -620,11 +522,11 @@ class Simulator:
|
|
|
620
522
|
max_conversation_turns: int,
|
|
621
523
|
task: str,
|
|
622
524
|
user_simulator_prompty: Optional[str],
|
|
623
|
-
|
|
525
|
+
user_simulator_prompty_kwargs: Dict[str, Any],
|
|
624
526
|
target: Callable,
|
|
625
527
|
api_call_delay_sec: float,
|
|
626
528
|
progress_bar: tqdm,
|
|
627
|
-
) -> List[Dict[str,
|
|
529
|
+
) -> List[Dict[str, str]]:
|
|
628
530
|
"""
|
|
629
531
|
Completes a conversation with the target model based on the conversation starter.
|
|
630
532
|
|
|
@@ -636,8 +538,8 @@ class Simulator:
|
|
|
636
538
|
:paramtype task: str
|
|
637
539
|
:keyword user_simulator_prompty: Path to the user simulator prompty file.
|
|
638
540
|
:paramtype user_simulator_prompty: Optional[str]
|
|
639
|
-
:keyword
|
|
640
|
-
:paramtype
|
|
541
|
+
:keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
|
|
542
|
+
:paramtype user_simulator_prompty_kwargs: Dict[str, Any]
|
|
641
543
|
:keyword target: The target function to call for responses.
|
|
642
544
|
:paramtype target: Callable
|
|
643
545
|
:keyword api_call_delay_sec: Delay in seconds between API calls.
|
|
@@ -645,43 +547,36 @@ class Simulator:
|
|
|
645
547
|
:keyword progress_bar: Progress bar for tracking simulation progress.
|
|
646
548
|
:paramtype progress_bar: tqdm
|
|
647
549
|
:return: A list representing the conversation history with each turn's content.
|
|
648
|
-
:rtype: List[Dict[str,
|
|
550
|
+
:rtype: List[Dict[str, str]]
|
|
649
551
|
"""
|
|
650
552
|
conversation_history = ConversationHistory()
|
|
553
|
+
# user_turn = Turn(role=ConversationRole.USER, content=conversation_starter)
|
|
554
|
+
# conversation_history.add_to_history(user_turn)
|
|
651
555
|
|
|
652
556
|
while len(conversation_history) < max_conversation_turns:
|
|
653
557
|
user_flow = self._load_user_simulation_flow(
|
|
654
|
-
user_simulator_prompty=user_simulator_prompty,
|
|
655
|
-
prompty_model_config=self.
|
|
656
|
-
|
|
558
|
+
user_simulator_prompty=user_simulator_prompty,
|
|
559
|
+
prompty_model_config=self._build_prompty_model_config(),
|
|
560
|
+
user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
|
|
561
|
+
)
|
|
562
|
+
conversation_starter_from_simulated_user = user_flow(
|
|
563
|
+
task=task,
|
|
564
|
+
conversation_history=[
|
|
565
|
+
{
|
|
566
|
+
"role": "assistant",
|
|
567
|
+
"content": conversation_starter,
|
|
568
|
+
"your_task": "Act as the user and translate the content into a user query.",
|
|
569
|
+
}
|
|
570
|
+
],
|
|
657
571
|
)
|
|
658
|
-
if len(conversation_history) == 0:
|
|
659
|
-
conversation_starter_from_simulated_user = await user_flow(
|
|
660
|
-
task=task,
|
|
661
|
-
conversation_history=[
|
|
662
|
-
{
|
|
663
|
-
"role": "assistant",
|
|
664
|
-
"content": conversation_starter,
|
|
665
|
-
}
|
|
666
|
-
],
|
|
667
|
-
action="rewrite the assistant's message as you have to accomplish the task by asking the right questions. Make sure the original question is not lost in your rewrite.",
|
|
668
|
-
)
|
|
669
|
-
else:
|
|
670
|
-
conversation_starter_from_simulated_user = await user_flow(
|
|
671
|
-
task=task,
|
|
672
|
-
conversation_history=conversation_history.to_context_free_list(),
|
|
673
|
-
action="Your goal is to make sure the task is completed by asking the right questions. Do not ask the same questions again.",
|
|
674
|
-
)
|
|
675
572
|
if isinstance(conversation_starter_from_simulated_user, dict):
|
|
676
573
|
conversation_starter_from_simulated_user = conversation_starter_from_simulated_user["content"]
|
|
677
574
|
user_turn = Turn(role=ConversationRole.USER, content=conversation_starter_from_simulated_user)
|
|
678
575
|
conversation_history.add_to_history(user_turn)
|
|
679
|
-
assistant_response
|
|
576
|
+
assistant_response = await self._get_target_response(
|
|
680
577
|
target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=conversation_history
|
|
681
578
|
)
|
|
682
|
-
assistant_turn = Turn(
|
|
683
|
-
role=ConversationRole.ASSISTANT, content=assistant_response, context=assistant_context
|
|
684
|
-
)
|
|
579
|
+
assistant_turn = Turn(role=ConversationRole.ASSISTANT, content=assistant_response)
|
|
685
580
|
conversation_history.add_to_history(assistant_turn)
|
|
686
581
|
progress_bar.update(1)
|
|
687
582
|
|
|
@@ -690,9 +585,44 @@ class Simulator:
|
|
|
690
585
|
|
|
691
586
|
return conversation_history.to_list()
|
|
692
587
|
|
|
588
|
+
async def _build_user_simulation_response(
|
|
589
|
+
self,
|
|
590
|
+
task: str,
|
|
591
|
+
conversation_history: List[Dict[str, Any]],
|
|
592
|
+
user_simulator_prompty: Optional[str],
|
|
593
|
+
user_simulator_prompty_kwargs: Dict[str, Any],
|
|
594
|
+
) -> str:
|
|
595
|
+
"""
|
|
596
|
+
Builds a response from the user simulator based on the current conversation history.
|
|
597
|
+
|
|
598
|
+
:param task: A string representing the task details.
|
|
599
|
+
:type task: str
|
|
600
|
+
:param conversation_history: The current conversation history as a list of dictionaries.
|
|
601
|
+
:type conversation_history: List[Dict[str, Any]]
|
|
602
|
+
:param user_simulator_prompty: Path to the user simulator prompty file.
|
|
603
|
+
:type user_simulator_prompty: Optional[str]
|
|
604
|
+
:param user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
|
|
605
|
+
:type user_simulator_prompty_kwargs: Dict[str, Any]
|
|
606
|
+
:return: The generated response content from the user simulator.
|
|
607
|
+
:rtype: str
|
|
608
|
+
:raises RuntimeError: If an error occurs during response generation.
|
|
609
|
+
"""
|
|
610
|
+
user_flow = self._load_user_simulation_flow(
|
|
611
|
+
user_simulator_prompty=user_simulator_prompty,
|
|
612
|
+
prompty_model_config=self._build_prompty_model_config(),
|
|
613
|
+
user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
|
|
614
|
+
)
|
|
615
|
+
|
|
616
|
+
try:
|
|
617
|
+
response_content = user_flow(task=task, conversation_history=conversation_history)
|
|
618
|
+
user_response = self._parse_prompty_response(response=response_content)
|
|
619
|
+
return user_response["content"]
|
|
620
|
+
except Exception as e:
|
|
621
|
+
raise RuntimeError("Error building user simulation response") from e
|
|
622
|
+
|
|
693
623
|
async def _get_target_response(
|
|
694
624
|
self, *, target: Callable, api_call_delay_sec: float, conversation_history: ConversationHistory
|
|
695
|
-
) ->
|
|
625
|
+
) -> str:
|
|
696
626
|
"""
|
|
697
627
|
Retrieves the response from the target callback based on the current conversation history.
|
|
698
628
|
|
|
@@ -702,8 +632,8 @@ class Simulator:
|
|
|
702
632
|
:paramtype api_call_delay_sec: float
|
|
703
633
|
:keyword conversation_history: The current conversation history.
|
|
704
634
|
:paramtype conversation_history: ConversationHistory
|
|
705
|
-
:return: The content of the response from the target
|
|
706
|
-
:rtype: str
|
|
635
|
+
:return: The content of the response from the target.
|
|
636
|
+
:rtype: str
|
|
707
637
|
"""
|
|
708
638
|
response = await target(
|
|
709
639
|
messages={"messages": conversation_history.to_list()},
|
|
@@ -713,4 +643,4 @@ class Simulator:
|
|
|
713
643
|
)
|
|
714
644
|
await asyncio.sleep(api_call_delay_sec)
|
|
715
645
|
latest_message = response["messages"][-1]
|
|
716
|
-
return latest_message["content"]
|
|
646
|
+
return latest_message["content"]
|