azure-ai-evaluation 1.0.0b3__py3-none-any.whl → 1.0.0b5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of azure-ai-evaluation might be problematic. Click here for more details.
- azure/ai/evaluation/__init__.py +23 -1
- azure/ai/evaluation/{simulator/_helpers → _common}/_experimental.py +20 -9
- azure/ai/evaluation/_common/constants.py +9 -2
- azure/ai/evaluation/_common/math.py +29 -0
- azure/ai/evaluation/_common/rai_service.py +222 -93
- azure/ai/evaluation/_common/utils.py +328 -19
- azure/ai/evaluation/_constants.py +16 -8
- azure/ai/evaluation/_evaluate/{_batch_run_client → _batch_run}/__init__.py +3 -2
- azure/ai/evaluation/_evaluate/{_batch_run_client → _batch_run}/code_client.py +33 -17
- azure/ai/evaluation/_evaluate/{_batch_run_client/batch_run_context.py → _batch_run/eval_run_context.py} +14 -7
- azure/ai/evaluation/_evaluate/{_batch_run_client → _batch_run}/proxy_client.py +22 -4
- azure/ai/evaluation/_evaluate/_batch_run/target_run_context.py +35 -0
- azure/ai/evaluation/_evaluate/_eval_run.py +47 -14
- azure/ai/evaluation/_evaluate/_evaluate.py +370 -188
- azure/ai/evaluation/_evaluate/_telemetry/__init__.py +15 -16
- azure/ai/evaluation/_evaluate/_utils.py +77 -25
- azure/ai/evaluation/_evaluators/_bleu/_bleu.py +1 -1
- azure/ai/evaluation/_evaluators/_coherence/_coherence.py +16 -10
- azure/ai/evaluation/_evaluators/_coherence/coherence.prompty +76 -34
- azure/ai/evaluation/_evaluators/_common/_base_eval.py +76 -46
- azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py +26 -19
- azure/ai/evaluation/_evaluators/_common/_base_rai_svc_eval.py +62 -25
- azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py +68 -36
- azure/ai/evaluation/_evaluators/_content_safety/_content_safety_chat.py +67 -46
- azure/ai/evaluation/_evaluators/_content_safety/_hate_unfairness.py +33 -4
- azure/ai/evaluation/_evaluators/_content_safety/_self_harm.py +33 -4
- azure/ai/evaluation/_evaluators/_content_safety/_sexual.py +33 -4
- azure/ai/evaluation/_evaluators/_content_safety/_violence.py +33 -4
- azure/ai/evaluation/_evaluators/_eci/_eci.py +7 -5
- azure/ai/evaluation/_evaluators/_f1_score/_f1_score.py +14 -6
- azure/ai/evaluation/_evaluators/_fluency/_fluency.py +22 -21
- azure/ai/evaluation/_evaluators/_fluency/fluency.prompty +66 -36
- azure/ai/evaluation/_evaluators/_gleu/_gleu.py +1 -1
- azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py +51 -16
- azure/ai/evaluation/_evaluators/_groundedness/groundedness_with_query.prompty +113 -0
- azure/ai/evaluation/_evaluators/_groundedness/groundedness_without_query.prompty +99 -0
- azure/ai/evaluation/_evaluators/_meteor/_meteor.py +3 -7
- azure/ai/evaluation/_evaluators/_multimodal/__init__.py +20 -0
- azure/ai/evaluation/_evaluators/_multimodal/_content_safety_multimodal.py +130 -0
- azure/ai/evaluation/_evaluators/_multimodal/_content_safety_multimodal_base.py +57 -0
- azure/ai/evaluation/_evaluators/_multimodal/_hate_unfairness.py +96 -0
- azure/ai/evaluation/_evaluators/_multimodal/_protected_material.py +120 -0
- azure/ai/evaluation/_evaluators/_multimodal/_self_harm.py +96 -0
- azure/ai/evaluation/_evaluators/_multimodal/_sexual.py +96 -0
- azure/ai/evaluation/_evaluators/_multimodal/_violence.py +96 -0
- azure/ai/evaluation/_evaluators/_protected_material/_protected_material.py +46 -13
- azure/ai/evaluation/_evaluators/_qa/_qa.py +11 -6
- azure/ai/evaluation/_evaluators/_relevance/_relevance.py +23 -20
- azure/ai/evaluation/_evaluators/_relevance/relevance.prompty +78 -42
- azure/ai/evaluation/_evaluators/_retrieval/_retrieval.py +126 -80
- azure/ai/evaluation/_evaluators/_retrieval/retrieval.prompty +74 -24
- azure/ai/evaluation/_evaluators/_rouge/_rouge.py +2 -2
- azure/ai/evaluation/_evaluators/_service_groundedness/__init__.py +9 -0
- azure/ai/evaluation/_evaluators/_service_groundedness/_service_groundedness.py +150 -0
- azure/ai/evaluation/_evaluators/_similarity/_similarity.py +32 -15
- azure/ai/evaluation/_evaluators/_xpia/xpia.py +36 -10
- azure/ai/evaluation/_exceptions.py +26 -6
- azure/ai/evaluation/_http_utils.py +203 -132
- azure/ai/evaluation/_model_configurations.py +23 -6
- azure/ai/evaluation/_vendor/__init__.py +3 -0
- azure/ai/evaluation/_vendor/rouge_score/__init__.py +14 -0
- azure/ai/evaluation/_vendor/rouge_score/rouge_scorer.py +328 -0
- azure/ai/evaluation/_vendor/rouge_score/scoring.py +63 -0
- azure/ai/evaluation/_vendor/rouge_score/tokenize.py +63 -0
- azure/ai/evaluation/_vendor/rouge_score/tokenizers.py +53 -0
- azure/ai/evaluation/_version.py +1 -1
- azure/ai/evaluation/simulator/__init__.py +2 -1
- azure/ai/evaluation/simulator/_adversarial_scenario.py +5 -0
- azure/ai/evaluation/simulator/_adversarial_simulator.py +88 -60
- azure/ai/evaluation/simulator/_conversation/__init__.py +13 -12
- azure/ai/evaluation/simulator/_conversation/_conversation.py +4 -4
- azure/ai/evaluation/simulator/_data_sources/__init__.py +3 -0
- azure/ai/evaluation/simulator/_data_sources/grounding.json +1150 -0
- azure/ai/evaluation/simulator/_direct_attack_simulator.py +24 -66
- azure/ai/evaluation/simulator/_helpers/__init__.py +1 -2
- azure/ai/evaluation/simulator/_helpers/_simulator_data_classes.py +26 -5
- azure/ai/evaluation/simulator/_indirect_attack_simulator.py +98 -95
- azure/ai/evaluation/simulator/_model_tools/_identity_manager.py +67 -21
- azure/ai/evaluation/simulator/_model_tools/_proxy_completion_model.py +28 -11
- azure/ai/evaluation/simulator/_model_tools/_template_handler.py +68 -24
- azure/ai/evaluation/simulator/_model_tools/models.py +10 -10
- azure/ai/evaluation/simulator/_prompty/task_query_response.prompty +4 -9
- azure/ai/evaluation/simulator/_prompty/task_simulate.prompty +6 -5
- azure/ai/evaluation/simulator/_simulator.py +222 -169
- azure/ai/evaluation/simulator/_tracing.py +4 -4
- azure/ai/evaluation/simulator/_utils.py +6 -6
- {azure_ai_evaluation-1.0.0b3.dist-info → azure_ai_evaluation-1.0.0b5.dist-info}/METADATA +237 -52
- azure_ai_evaluation-1.0.0b5.dist-info/NOTICE.txt +70 -0
- azure_ai_evaluation-1.0.0b5.dist-info/RECORD +120 -0
- {azure_ai_evaluation-1.0.0b3.dist-info → azure_ai_evaluation-1.0.0b5.dist-info}/WHEEL +1 -1
- azure/ai/evaluation/_evaluators/_groundedness/groundedness.prompty +0 -49
- azure_ai_evaluation-1.0.0b3.dist-info/RECORD +0 -98
- {azure_ai_evaluation-1.0.0b3.dist-info → azure_ai_evaluation-1.0.0b5.dist-info}/top_level.txt +0 -0
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
# flake8: noqa
|
|
2
|
-
# pylint: disable=W0102,W0613,R0914,C0301,E0401,E0611
|
|
2
|
+
# pylint: disable=W0102,W0613,R0914,C0301,E0401,E0611,C0114,R0913,E0702,R0903,C0411
|
|
3
3
|
# ---------------------------------------------------------
|
|
4
4
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
5
5
|
# ---------------------------------------------------------
|
|
@@ -9,17 +9,19 @@ import json
|
|
|
9
9
|
import os
|
|
10
10
|
import re
|
|
11
11
|
import warnings
|
|
12
|
-
from typing import Any, Callable, Dict, List, Optional, Union
|
|
12
|
+
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
|
13
13
|
|
|
14
|
-
from promptflow.
|
|
15
|
-
from promptflow.core import AzureOpenAIModelConfiguration, Flow
|
|
14
|
+
from promptflow.core import AsyncPrompty
|
|
16
15
|
from tqdm import tqdm
|
|
17
16
|
|
|
17
|
+
from azure.ai.evaluation._common._experimental import experimental
|
|
18
|
+
from azure.ai.evaluation._common.utils import construct_prompty_model_config
|
|
19
|
+
from azure.ai.evaluation._model_configurations import AzureOpenAIModelConfiguration, OpenAIModelConfiguration
|
|
20
|
+
|
|
21
|
+
from .._exceptions import ErrorBlame, ErrorCategory, EvaluationException
|
|
18
22
|
from .._user_agent import USER_AGENT
|
|
19
23
|
from ._conversation.constants import ConversationRole
|
|
20
|
-
from ._helpers import ConversationHistory, Turn
|
|
21
|
-
|
|
22
|
-
# from ._tracing import monitor_task_simulator
|
|
24
|
+
from ._helpers import ConversationHistory, Turn
|
|
23
25
|
from ._utils import JsonLineChatProtocol
|
|
24
26
|
|
|
25
27
|
|
|
@@ -29,43 +31,60 @@ class Simulator:
|
|
|
29
31
|
Simulator for generating synthetic conversations.
|
|
30
32
|
"""
|
|
31
33
|
|
|
32
|
-
def __init__(self,
|
|
34
|
+
def __init__(self, model_config: Union[AzureOpenAIModelConfiguration, OpenAIModelConfiguration]):
|
|
33
35
|
"""
|
|
34
|
-
Initializes the task simulator with
|
|
36
|
+
Initializes the task simulator with the model configuration.
|
|
35
37
|
|
|
36
|
-
:param
|
|
37
|
-
|
|
38
|
-
:
|
|
39
|
-
:paramtype credential: Optional[Any]
|
|
40
|
-
:raises ValueError: If the azure_ai_project does not contain the required keys or any value is None.
|
|
38
|
+
:param model_config: A dictionary defining the configuration for the model. Acceptable types are AzureOpenAIModelConfiguration and OpenAIModelConfiguration.
|
|
39
|
+
:type model_config: Union[~azure.ai.evaluation.AzureOpenAIModelConfiguration, ~azure.ai.evaluation.OpenAIModelConfiguration]
|
|
40
|
+
:raises ValueError: If the model_config does not contain the required keys or any value is None.
|
|
41
41
|
"""
|
|
42
|
-
self.
|
|
43
|
-
self.
|
|
44
|
-
|
|
45
|
-
|
|
42
|
+
self._validate_model_config(model_config)
|
|
43
|
+
self.model_config = model_config
|
|
44
|
+
if "api_version" not in self.model_config:
|
|
45
|
+
self.model_config["api_version"] = "2024-06-01" # type: ignore
|
|
46
46
|
|
|
47
47
|
@staticmethod
|
|
48
|
-
def
|
|
48
|
+
def _validate_model_config(model_config: Any):
|
|
49
49
|
"""
|
|
50
|
-
Validates the
|
|
50
|
+
Validates the model_config to ensure all required keys are present and have non-None values.
|
|
51
|
+
If 'type' is not specified, it will attempt to infer the type based on the keys present.
|
|
51
52
|
|
|
52
|
-
:param
|
|
53
|
-
:type
|
|
53
|
+
:param model_config: The model configuration dictionary.
|
|
54
|
+
:type model_config: Dict[str, Any]
|
|
54
55
|
:raises ValueError: If required keys are missing or any of the values are None.
|
|
55
56
|
"""
|
|
56
|
-
|
|
57
|
-
if not
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
57
|
+
# Attempt to infer 'type' if not provided
|
|
58
|
+
if "type" not in model_config:
|
|
59
|
+
if "azure_deployment" in model_config and "azure_endpoint" in model_config:
|
|
60
|
+
model_config["type"] = "azure_openai"
|
|
61
|
+
elif "model" in model_config:
|
|
62
|
+
model_config["type"] = "openai"
|
|
63
|
+
else:
|
|
64
|
+
raise ValueError(
|
|
65
|
+
"Unable to infer 'type' from model_config. Please specify 'type' as 'azure_openai' or 'openai'."
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
if model_config["type"] == "azure_openai":
|
|
69
|
+
required_keys = ["azure_deployment", "azure_endpoint"]
|
|
70
|
+
elif model_config["type"] == "openai":
|
|
71
|
+
required_keys = ["api_key", "model"]
|
|
72
|
+
else:
|
|
73
|
+
raise ValueError("model_config 'type' must be 'azure_openai' or 'openai'.")
|
|
74
|
+
|
|
75
|
+
missing_keys = [key for key in required_keys if key not in model_config]
|
|
76
|
+
if missing_keys:
|
|
77
|
+
raise ValueError(f"model_config is missing required keys: {', '.join(missing_keys)}")
|
|
78
|
+
none_keys = [key for key in required_keys if model_config.get(key) is None]
|
|
79
|
+
if none_keys:
|
|
80
|
+
raise ValueError(f"The following keys in model_config must not be None: {', '.join(none_keys)}")
|
|
61
81
|
|
|
62
|
-
# @monitor_task_simulator
|
|
63
82
|
async def __call__(
|
|
64
83
|
self,
|
|
65
84
|
*,
|
|
66
85
|
target: Callable,
|
|
67
86
|
max_conversation_turns: int = 5,
|
|
68
|
-
tasks: List[
|
|
87
|
+
tasks: List[str] = [],
|
|
69
88
|
text: str = "",
|
|
70
89
|
num_queries: int = 5,
|
|
71
90
|
query_response_generating_prompty: Optional[str] = None,
|
|
@@ -73,7 +92,8 @@ class Simulator:
|
|
|
73
92
|
api_call_delay_sec: float = 1,
|
|
74
93
|
query_response_generating_prompty_kwargs: Dict[str, Any] = {},
|
|
75
94
|
user_simulator_prompty_kwargs: Dict[str, Any] = {},
|
|
76
|
-
conversation_turns: List[List[str]] = [],
|
|
95
|
+
conversation_turns: List[List[Union[str, Dict[str, Any]]]] = [],
|
|
96
|
+
concurrent_async_tasks: int = 5,
|
|
77
97
|
**kwargs,
|
|
78
98
|
) -> List[JsonLineChatProtocol]:
|
|
79
99
|
"""
|
|
@@ -100,7 +120,10 @@ class Simulator:
|
|
|
100
120
|
:keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
|
|
101
121
|
:paramtype user_simulator_prompty_kwargs: Dict[str, Any]
|
|
102
122
|
:keyword conversation_turns: Predefined conversation turns to simulate.
|
|
103
|
-
:paramtype conversation_turns: List[List[str]]
|
|
123
|
+
:paramtype conversation_turns: List[List[Union[str, Dict[str, Any]]]]
|
|
124
|
+
:keyword concurrent_async_tasks: The number of asynchronous tasks to run concurrently during the simulation.
|
|
125
|
+
Defaults to 5.
|
|
126
|
+
:paramtype concurrent_async_tasks: int
|
|
104
127
|
:return: A list of simulated conversations represented as JsonLineChatProtocol objects.
|
|
105
128
|
:rtype: List[JsonLineChatProtocol]
|
|
106
129
|
|
|
@@ -109,18 +132,18 @@ class Simulator:
|
|
|
109
132
|
|
|
110
133
|
Modes:
|
|
111
134
|
- Task-Free Mode: When only num_queries is specified and tasks is not, the method generates num_queries x max_conversation_turns lines of simulated data grounded in the context of the text.
|
|
112
|
-
- Task-Specific Mode: When both num_queries and tasks are specified, the method generates lines of simulated data based on the tasks. If num_queries > len(tasks), the remaining lines
|
|
135
|
+
- Task-Specific Mode: When both num_queries and tasks are specified, the method generates lines of simulated data based on the tasks. If num_queries > len(tasks), the remaining lines will be simulated in task-free mode. If num_queries < len(tasks), only the first num_queries tasks are used.
|
|
113
136
|
- Conversation Starter Mode: When conversation_turns are specified, the method starts each conversation with the user-specified queries and then follows the conversation history for the remaining turns.
|
|
114
137
|
"""
|
|
115
138
|
if conversation_turns and (text or tasks):
|
|
116
139
|
raise ValueError("Cannot specify both conversation_turns and text/tasks")
|
|
117
140
|
|
|
118
|
-
if num_queries > len(tasks):
|
|
141
|
+
if text and num_queries > len(tasks):
|
|
119
142
|
warnings.warn(
|
|
120
143
|
f"You have specified 'num_queries' > len('tasks') ({num_queries} > {len(tasks)}). "
|
|
121
144
|
f"All tasks will be used for generation and the remaining {num_queries - len(tasks)} lines will be simulated in task-free mode"
|
|
122
145
|
)
|
|
123
|
-
elif num_queries < len(tasks):
|
|
146
|
+
elif text and num_queries < len(tasks):
|
|
124
147
|
warnings.warn(
|
|
125
148
|
f"You have specified 'num_queries' < len('tasks') ({num_queries} < {len(tasks)}). "
|
|
126
149
|
f"Only the first {num_queries} lines of the specified tasks will be simulated."
|
|
@@ -128,7 +151,7 @@ class Simulator:
|
|
|
128
151
|
num_queries = min(num_queries, len(tasks))
|
|
129
152
|
max_conversation_turns *= 2 # account for both user and assistant turns
|
|
130
153
|
|
|
131
|
-
prompty_model_config = self.
|
|
154
|
+
prompty_model_config = self.model_config
|
|
132
155
|
if conversation_turns:
|
|
133
156
|
return await self._simulate_with_predefined_turns(
|
|
134
157
|
target=target,
|
|
@@ -138,6 +161,7 @@ class Simulator:
|
|
|
138
161
|
user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
|
|
139
162
|
api_call_delay_sec=api_call_delay_sec,
|
|
140
163
|
prompty_model_config=prompty_model_config,
|
|
164
|
+
concurrent_async_tasks=concurrent_async_tasks,
|
|
141
165
|
)
|
|
142
166
|
|
|
143
167
|
query_responses = await self._generate_query_responses(
|
|
@@ -148,7 +172,6 @@ class Simulator:
|
|
|
148
172
|
prompty_model_config=prompty_model_config,
|
|
149
173
|
**kwargs,
|
|
150
174
|
)
|
|
151
|
-
|
|
152
175
|
return await self._create_conversations_from_query_responses(
|
|
153
176
|
query_responses=query_responses,
|
|
154
177
|
max_conversation_turns=max_conversation_turns,
|
|
@@ -157,30 +180,20 @@ class Simulator:
|
|
|
157
180
|
user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
|
|
158
181
|
target=target,
|
|
159
182
|
api_call_delay_sec=api_call_delay_sec,
|
|
183
|
+
text=text,
|
|
160
184
|
)
|
|
161
185
|
|
|
162
|
-
def _build_prompty_model_config(self) -> Dict[str, Any]:
|
|
163
|
-
"""
|
|
164
|
-
Constructs the configuration for the prompty model.
|
|
165
|
-
|
|
166
|
-
:return: A dictionary containing the prompty model configuration, including API version and user agent headers if applicable.
|
|
167
|
-
:rtype: Dict[str, Any]
|
|
168
|
-
"""
|
|
169
|
-
config = {"configuration": self.azure_ai_project}
|
|
170
|
-
if USER_AGENT and isinstance(self.azure_ai_project, AzureOpenAIModelConfiguration):
|
|
171
|
-
config.update({"parameters": {"extra_headers": {"x-ms-useragent": USER_AGENT}}})
|
|
172
|
-
return config
|
|
173
|
-
|
|
174
186
|
async def _simulate_with_predefined_turns(
|
|
175
187
|
self,
|
|
176
188
|
*,
|
|
177
189
|
target: Callable,
|
|
178
190
|
max_conversation_turns: int,
|
|
179
|
-
conversation_turns: List[List[str]],
|
|
191
|
+
conversation_turns: List[List[Union[str, Dict[str, Any]]]],
|
|
180
192
|
user_simulator_prompty: Optional[str],
|
|
181
193
|
user_simulator_prompty_kwargs: Dict[str, Any],
|
|
182
194
|
api_call_delay_sec: float,
|
|
183
|
-
prompty_model_config:
|
|
195
|
+
prompty_model_config: Any,
|
|
196
|
+
concurrent_async_tasks: int,
|
|
184
197
|
) -> List[JsonLineChatProtocol]:
|
|
185
198
|
"""
|
|
186
199
|
Simulates conversations using predefined conversation turns.
|
|
@@ -190,7 +203,7 @@ class Simulator:
|
|
|
190
203
|
:keyword max_conversation_turns: Maximum number of turns for the simulation.
|
|
191
204
|
:paramtype max_conversation_turns: int
|
|
192
205
|
:keyword conversation_turns: A list of predefined conversation turns.
|
|
193
|
-
:paramtype conversation_turns: List[List[str]]
|
|
206
|
+
:paramtype conversation_turns: List[List[Union[str, Dict[str, Any]]]]
|
|
194
207
|
:keyword user_simulator_prompty: Path to the user simulator prompty file.
|
|
195
208
|
:paramtype user_simulator_prompty: Optional[str]
|
|
196
209
|
:keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
|
|
@@ -198,43 +211,61 @@ class Simulator:
|
|
|
198
211
|
:keyword api_call_delay_sec: Delay in seconds between API calls.
|
|
199
212
|
:paramtype api_call_delay_sec: float
|
|
200
213
|
:keyword prompty_model_config: The configuration for the prompty model.
|
|
201
|
-
:paramtype prompty_model_config:
|
|
214
|
+
:paramtype prompty_model_config: Any
|
|
215
|
+
:keyword concurrent_async_tasks: The number of asynchronous tasks to run concurrently during the simulation.
|
|
216
|
+
:paramtype concurrent_async_tasks: int
|
|
202
217
|
:return: A list of simulated conversations represented as JsonLineChatProtocol objects.
|
|
203
218
|
:rtype: List[JsonLineChatProtocol]
|
|
204
219
|
"""
|
|
205
|
-
simulated_conversations = []
|
|
206
220
|
progress_bar = tqdm(
|
|
207
221
|
total=int(len(conversation_turns) * (max_conversation_turns / 2)),
|
|
208
222
|
desc="Simulating with predefined conversation turns: ",
|
|
209
223
|
ncols=100,
|
|
210
224
|
unit="messages",
|
|
211
225
|
)
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
current_simulation
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
226
|
+
semaphore = asyncio.Semaphore(concurrent_async_tasks)
|
|
227
|
+
progress_bar_lock = asyncio.Lock()
|
|
228
|
+
|
|
229
|
+
async def run_simulation(simulation: List[Union[str, Dict[str, Any]]]) -> JsonLineChatProtocol:
|
|
230
|
+
async with semaphore:
|
|
231
|
+
current_simulation = ConversationHistory()
|
|
232
|
+
for simulated_turn in simulation:
|
|
233
|
+
if isinstance(simulated_turn, str):
|
|
234
|
+
user_turn = Turn(role=ConversationRole.USER, content=simulated_turn)
|
|
235
|
+
elif isinstance(simulated_turn, dict):
|
|
236
|
+
user_turn = Turn(
|
|
237
|
+
role=ConversationRole.USER,
|
|
238
|
+
content=str(simulated_turn.get("content")),
|
|
239
|
+
context=str(simulated_turn.get("context")),
|
|
240
|
+
)
|
|
241
|
+
else:
|
|
242
|
+
raise ValueError(
|
|
243
|
+
"Each simulated turn must be a string or a dict with 'content' and 'context' keys"
|
|
244
|
+
)
|
|
245
|
+
current_simulation.add_to_history(user_turn)
|
|
246
|
+
assistant_response, assistant_context = await self._get_target_response(
|
|
247
|
+
target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=current_simulation
|
|
248
|
+
)
|
|
249
|
+
assistant_turn = Turn(
|
|
250
|
+
role=ConversationRole.ASSISTANT, content=assistant_response, context=assistant_context
|
|
251
|
+
)
|
|
252
|
+
current_simulation.add_to_history(assistant_turn)
|
|
253
|
+
async with progress_bar_lock:
|
|
254
|
+
progress_bar.update(1)
|
|
255
|
+
|
|
256
|
+
if len(current_simulation) < max_conversation_turns:
|
|
257
|
+
await self._extend_conversation_with_simulator(
|
|
258
|
+
current_simulation=current_simulation,
|
|
259
|
+
max_conversation_turns=max_conversation_turns,
|
|
260
|
+
user_simulator_prompty=user_simulator_prompty,
|
|
261
|
+
user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
|
|
262
|
+
api_call_delay_sec=api_call_delay_sec,
|
|
263
|
+
prompty_model_config=prompty_model_config,
|
|
264
|
+
target=target,
|
|
265
|
+
progress_bar=progress_bar,
|
|
266
|
+
progress_bar_lock=progress_bar_lock,
|
|
267
|
+
)
|
|
268
|
+
return JsonLineChatProtocol(
|
|
238
269
|
{
|
|
239
270
|
"messages": current_simulation.to_list(),
|
|
240
271
|
"finish_reason": ["stop"],
|
|
@@ -242,10 +273,11 @@ class Simulator:
|
|
|
242
273
|
"$schema": "http://azureml/sdk-2-0/ChatConversation.json",
|
|
243
274
|
}
|
|
244
275
|
)
|
|
245
|
-
)
|
|
246
276
|
|
|
277
|
+
tasks = [asyncio.create_task(run_simulation(simulation)) for simulation in conversation_turns]
|
|
278
|
+
results = await asyncio.gather(*tasks)
|
|
247
279
|
progress_bar.close()
|
|
248
|
-
return
|
|
280
|
+
return results
|
|
249
281
|
|
|
250
282
|
async def _extend_conversation_with_simulator(
|
|
251
283
|
self,
|
|
@@ -258,6 +290,7 @@ class Simulator:
|
|
|
258
290
|
prompty_model_config: Dict[str, Any],
|
|
259
291
|
target: Callable,
|
|
260
292
|
progress_bar: tqdm,
|
|
293
|
+
progress_bar_lock: asyncio.Lock,
|
|
261
294
|
):
|
|
262
295
|
"""
|
|
263
296
|
Extends an ongoing conversation using a user simulator until the maximum number of turns is reached.
|
|
@@ -278,48 +311,53 @@ class Simulator:
|
|
|
278
311
|
:paramtype target: Callable,
|
|
279
312
|
:keyword progress_bar: Progress bar for tracking simulation progress.
|
|
280
313
|
:paramtype progress_bar: tqdm,
|
|
314
|
+
:keyword progress_bar_lock: Lock for updating the progress bar safely.
|
|
315
|
+
:paramtype progress_bar_lock: asyncio.Lock
|
|
281
316
|
"""
|
|
282
317
|
user_flow = self._load_user_simulation_flow(
|
|
283
|
-
user_simulator_prompty=user_simulator_prompty,
|
|
318
|
+
user_simulator_prompty=user_simulator_prompty, # type: ignore
|
|
284
319
|
prompty_model_config=prompty_model_config,
|
|
285
320
|
user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
|
|
286
321
|
)
|
|
287
322
|
|
|
288
323
|
while len(current_simulation) < max_conversation_turns:
|
|
289
|
-
user_response_content = user_flow(
|
|
324
|
+
user_response_content = await user_flow(
|
|
290
325
|
task="Continue the conversation",
|
|
291
|
-
conversation_history=current_simulation.
|
|
326
|
+
conversation_history=current_simulation.to_context_free_list(),
|
|
292
327
|
**user_simulator_prompty_kwargs,
|
|
293
328
|
)
|
|
294
329
|
user_response = self._parse_prompty_response(response=user_response_content)
|
|
295
330
|
user_turn = Turn(role=ConversationRole.USER, content=user_response["content"])
|
|
296
331
|
current_simulation.add_to_history(user_turn)
|
|
297
332
|
await asyncio.sleep(api_call_delay_sec)
|
|
298
|
-
assistant_response = await self._get_target_response(
|
|
333
|
+
assistant_response, assistant_context = await self._get_target_response(
|
|
299
334
|
target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=current_simulation
|
|
300
335
|
)
|
|
301
|
-
assistant_turn = Turn(
|
|
336
|
+
assistant_turn = Turn(
|
|
337
|
+
role=ConversationRole.ASSISTANT, content=assistant_response, context=assistant_context
|
|
338
|
+
)
|
|
302
339
|
current_simulation.add_to_history(assistant_turn)
|
|
303
|
-
|
|
340
|
+
async with progress_bar_lock:
|
|
341
|
+
progress_bar.update(1)
|
|
304
342
|
|
|
305
343
|
def _load_user_simulation_flow(
|
|
306
344
|
self,
|
|
307
345
|
*,
|
|
308
|
-
user_simulator_prompty: Union[str, os.PathLike],
|
|
346
|
+
user_simulator_prompty: Optional[Union[str, os.PathLike]],
|
|
309
347
|
prompty_model_config: Dict[str, Any],
|
|
310
348
|
user_simulator_prompty_kwargs: Dict[str, Any],
|
|
311
|
-
) ->
|
|
349
|
+
) -> "AsyncPrompty": # type: ignore
|
|
312
350
|
"""
|
|
313
351
|
Loads the flow for simulating user interactions.
|
|
314
352
|
|
|
315
353
|
:keyword user_simulator_prompty: Path to the user simulator prompty file.
|
|
316
|
-
:paramtype user_simulator_prompty: Union[str, os.PathLike]
|
|
354
|
+
:paramtype user_simulator_prompty: Optional[Union[str, os.PathLike]]
|
|
317
355
|
:keyword prompty_model_config: The configuration for the prompty model.
|
|
318
356
|
:paramtype prompty_model_config: Dict[str, Any]
|
|
319
357
|
:keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
|
|
320
358
|
:paramtype user_simulator_prompty_kwargs: Dict[str, Any]
|
|
321
359
|
:return: The loaded flow for simulating user interactions.
|
|
322
|
-
:rtype:
|
|
360
|
+
:rtype: AsyncPrompty
|
|
323
361
|
"""
|
|
324
362
|
if not user_simulator_prompty:
|
|
325
363
|
package = "azure.ai.evaluation.simulator._prompty"
|
|
@@ -328,21 +366,37 @@ class Simulator:
|
|
|
328
366
|
# Access the resource as a file path
|
|
329
367
|
# pylint: disable=deprecated-method
|
|
330
368
|
with pkg_resources.path(package, resource_name) as prompty_path:
|
|
331
|
-
|
|
369
|
+
prompty_model_config = construct_prompty_model_config(
|
|
370
|
+
model_config=prompty_model_config, # type: ignore
|
|
371
|
+
default_api_version="2024-06-01",
|
|
372
|
+
user_agent=USER_AGENT,
|
|
373
|
+
)
|
|
374
|
+
return AsyncPrompty.load(source=prompty_path, model=prompty_model_config) # type: ignore
|
|
332
375
|
except FileNotFoundError as e:
|
|
333
|
-
|
|
334
|
-
|
|
376
|
+
msg = f"Flow path for {resource_name} does not exist in package {package}."
|
|
377
|
+
raise EvaluationException(
|
|
378
|
+
message=msg,
|
|
379
|
+
internal_message=msg,
|
|
380
|
+
error_category=ErrorCategory.FILE_OR_FOLDER_NOT_FOUND,
|
|
381
|
+
blame=ErrorBlame.USER_ERROR,
|
|
382
|
+
) from e
|
|
383
|
+
prompty_model_config = construct_prompty_model_config(
|
|
384
|
+
model_config=prompty_model_config, # type: ignore
|
|
385
|
+
default_api_version="2024-06-01",
|
|
386
|
+
user_agent=USER_AGENT,
|
|
387
|
+
)
|
|
388
|
+
return AsyncPrompty.load(
|
|
335
389
|
source=user_simulator_prompty,
|
|
336
390
|
model=prompty_model_config,
|
|
337
391
|
**user_simulator_prompty_kwargs,
|
|
338
|
-
)
|
|
392
|
+
) # type: ignore
|
|
339
393
|
|
|
340
394
|
def _parse_prompty_response(self, *, response: str) -> Dict[str, Any]:
|
|
341
395
|
"""
|
|
342
396
|
Parses the response from the prompty execution.
|
|
343
397
|
|
|
344
398
|
:keyword response: The raw response from the prompty.
|
|
345
|
-
:paramtype
|
|
399
|
+
:paramtype response: str
|
|
346
400
|
:return: A dictionary representing the parsed response content.
|
|
347
401
|
:rtype: Dict[str, Any]
|
|
348
402
|
:raises ValueError: If the response cannot be parsed.
|
|
@@ -383,7 +437,7 @@ class Simulator:
|
|
|
383
437
|
num_queries: int,
|
|
384
438
|
query_response_generating_prompty: Optional[str],
|
|
385
439
|
query_response_generating_prompty_kwargs: Dict[str, Any],
|
|
386
|
-
prompty_model_config:
|
|
440
|
+
prompty_model_config: Any,
|
|
387
441
|
**kwargs,
|
|
388
442
|
) -> List[Dict[str, str]]:
|
|
389
443
|
"""
|
|
@@ -398,21 +452,29 @@ class Simulator:
|
|
|
398
452
|
:keyword query_response_generating_prompty_kwargs: Additional keyword arguments for the query response generating prompty.
|
|
399
453
|
:paramtype query_response_generating_prompty_kwargs: Dict[str, Any]
|
|
400
454
|
:keyword prompty_model_config: The configuration for the prompty model.
|
|
401
|
-
:paramtype prompty_model_config:
|
|
455
|
+
:paramtype prompty_model_config: Any
|
|
402
456
|
:return: A list of query-response dictionaries.
|
|
403
457
|
:rtype: List[Dict[str, str]]
|
|
404
458
|
:raises RuntimeError: If an error occurs during query generation.
|
|
405
459
|
"""
|
|
406
460
|
query_flow = self._load_query_generation_flow(
|
|
407
|
-
query_response_generating_prompty=query_response_generating_prompty,
|
|
461
|
+
query_response_generating_prompty=query_response_generating_prompty, # type: ignore
|
|
408
462
|
prompty_model_config=prompty_model_config,
|
|
409
463
|
query_response_generating_prompty_kwargs=query_response_generating_prompty_kwargs,
|
|
410
464
|
)
|
|
411
465
|
try:
|
|
412
|
-
query_responses = query_flow(text=text, num_queries=num_queries)
|
|
466
|
+
query_responses = await query_flow(text=text, num_queries=num_queries)
|
|
413
467
|
if isinstance(query_responses, dict):
|
|
414
468
|
keys = list(query_responses.keys())
|
|
415
469
|
return query_responses[keys[0]]
|
|
470
|
+
if isinstance(query_responses, str):
|
|
471
|
+
query_responses = json.loads(query_responses)
|
|
472
|
+
if isinstance(query_responses, dict):
|
|
473
|
+
if len(query_responses.keys()) == 1:
|
|
474
|
+
return query_responses[list(query_responses.keys())[0]]
|
|
475
|
+
return query_responses # type: ignore
|
|
476
|
+
if isinstance(query_responses, list):
|
|
477
|
+
return query_responses
|
|
416
478
|
return json.loads(query_responses)
|
|
417
479
|
except Exception as e:
|
|
418
480
|
raise RuntimeError("Error generating query responses") from e
|
|
@@ -420,21 +482,21 @@ class Simulator:
|
|
|
420
482
|
def _load_query_generation_flow(
|
|
421
483
|
self,
|
|
422
484
|
*,
|
|
423
|
-
query_response_generating_prompty: Union[str, os.PathLike],
|
|
485
|
+
query_response_generating_prompty: Optional[Union[str, os.PathLike]],
|
|
424
486
|
prompty_model_config: Dict[str, Any],
|
|
425
487
|
query_response_generating_prompty_kwargs: Dict[str, Any],
|
|
426
|
-
) ->
|
|
488
|
+
) -> "AsyncPrompty":
|
|
427
489
|
"""
|
|
428
490
|
Loads the flow for generating query responses.
|
|
429
491
|
|
|
430
492
|
:keyword query_response_generating_prompty: Path to the query response generating prompty file.
|
|
431
|
-
:paramtype query_response_generating_prompty: Union[str, os.PathLike]
|
|
493
|
+
:paramtype query_response_generating_prompty: Optional[Union[str, os.PathLike]]
|
|
432
494
|
:keyword prompty_model_config: The configuration for the prompty model.
|
|
433
495
|
:paramtype prompty_model_config: Dict[str, Any]
|
|
434
496
|
:keyword query_response_generating_prompty_kwargs: Additional keyword arguments for the flow.
|
|
435
497
|
:paramtype query_response_generating_prompty_kwargs: Dict[str, Any]
|
|
436
498
|
:return: The loaded flow for generating query responses.
|
|
437
|
-
:rtype:
|
|
499
|
+
:rtype: AsyncPrompty
|
|
438
500
|
"""
|
|
439
501
|
if not query_response_generating_prompty:
|
|
440
502
|
package = "azure.ai.evaluation.simulator._prompty"
|
|
@@ -443,25 +505,42 @@ class Simulator:
|
|
|
443
505
|
# Access the resource as a file path
|
|
444
506
|
# pylint: disable=deprecated-method
|
|
445
507
|
with pkg_resources.path(package, resource_name) as prompty_path:
|
|
446
|
-
|
|
508
|
+
prompty_model_config = construct_prompty_model_config(
|
|
509
|
+
model_config=prompty_model_config, # type: ignore
|
|
510
|
+
default_api_version="2024-06-01",
|
|
511
|
+
user_agent=USER_AGENT,
|
|
512
|
+
)
|
|
513
|
+
return AsyncPrompty.load(source=prompty_path, model=prompty_model_config) # type: ignore
|
|
447
514
|
except FileNotFoundError as e:
|
|
448
|
-
|
|
449
|
-
|
|
515
|
+
msg = f"Flow path for {resource_name} does not exist in package {package}."
|
|
516
|
+
raise EvaluationException(
|
|
517
|
+
message=msg,
|
|
518
|
+
internal_message=msg,
|
|
519
|
+
error_category=ErrorCategory.FILE_OR_FOLDER_NOT_FOUND,
|
|
520
|
+
blame=ErrorBlame.USER_ERROR,
|
|
521
|
+
) from e
|
|
522
|
+
prompty_model_config = construct_prompty_model_config(
|
|
523
|
+
model_config=prompty_model_config, # type: ignore
|
|
524
|
+
default_api_version="2024-06-01",
|
|
525
|
+
user_agent=USER_AGENT,
|
|
526
|
+
)
|
|
527
|
+
return AsyncPrompty.load(
|
|
450
528
|
source=query_response_generating_prompty,
|
|
451
529
|
model=prompty_model_config,
|
|
452
530
|
**query_response_generating_prompty_kwargs,
|
|
453
|
-
)
|
|
531
|
+
) # type: ignore
|
|
454
532
|
|
|
455
533
|
async def _create_conversations_from_query_responses(
|
|
456
534
|
self,
|
|
457
535
|
*,
|
|
458
536
|
query_responses: List[Dict[str, str]],
|
|
459
537
|
max_conversation_turns: int,
|
|
460
|
-
tasks: List[
|
|
538
|
+
tasks: List[str],
|
|
461
539
|
user_simulator_prompty: Optional[str],
|
|
462
540
|
user_simulator_prompty_kwargs: Dict[str, Any],
|
|
463
541
|
target: Callable,
|
|
464
542
|
api_call_delay_sec: float,
|
|
543
|
+
text: str,
|
|
465
544
|
) -> List[JsonLineChatProtocol]:
|
|
466
545
|
"""
|
|
467
546
|
Creates full conversations from query-response pairs.
|
|
@@ -471,7 +550,7 @@ class Simulator:
|
|
|
471
550
|
:keyword max_conversation_turns: The maximum number of conversation turns.
|
|
472
551
|
:paramtype max_conversation_turns: int
|
|
473
552
|
:keyword tasks: A list of tasks for the simulation.
|
|
474
|
-
:paramtype tasks: List[
|
|
553
|
+
:paramtype tasks: List[str]
|
|
475
554
|
:keyword user_simulator_prompty: Path to the user simulator prompty file.
|
|
476
555
|
:paramtype user_simulator_prompty: Optional[str]
|
|
477
556
|
:keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
|
|
@@ -480,6 +559,8 @@ class Simulator:
|
|
|
480
559
|
:paramtype target: Callable
|
|
481
560
|
:keyword api_call_delay_sec: Delay in seconds between API calls.
|
|
482
561
|
:paramtype api_call_delay_sec: float
|
|
562
|
+
:keyword text: The initial input text for generating query responses.
|
|
563
|
+
:paramtype text: str
|
|
483
564
|
:return: A list of simulated conversations represented as JsonLineChatProtocol objects.
|
|
484
565
|
:rtype: List[JsonLineChatProtocol]
|
|
485
566
|
"""
|
|
@@ -501,7 +582,7 @@ class Simulator:
|
|
|
501
582
|
conversation = await self._complete_conversation(
|
|
502
583
|
conversation_starter=query,
|
|
503
584
|
max_conversation_turns=max_conversation_turns,
|
|
504
|
-
task=task,
|
|
585
|
+
task=task, # type: ignore
|
|
505
586
|
user_simulator_prompty=user_simulator_prompty,
|
|
506
587
|
user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
|
|
507
588
|
target=target,
|
|
@@ -517,6 +598,7 @@ class Simulator:
|
|
|
517
598
|
"task": task,
|
|
518
599
|
"expected_response": response,
|
|
519
600
|
"query": query,
|
|
601
|
+
"original_text": text,
|
|
520
602
|
},
|
|
521
603
|
"$schema": "http://azureml/sdk-2-0/ChatConversation.json",
|
|
522
604
|
}
|
|
@@ -536,7 +618,7 @@ class Simulator:
|
|
|
536
618
|
target: Callable,
|
|
537
619
|
api_call_delay_sec: float,
|
|
538
620
|
progress_bar: tqdm,
|
|
539
|
-
) -> List[Dict[str, str]]:
|
|
621
|
+
) -> List[Dict[str, Optional[str]]]:
|
|
540
622
|
"""
|
|
541
623
|
Completes a conversation with the target model based on the conversation starter.
|
|
542
624
|
|
|
@@ -557,36 +639,43 @@ class Simulator:
|
|
|
557
639
|
:keyword progress_bar: Progress bar for tracking simulation progress.
|
|
558
640
|
:paramtype progress_bar: tqdm
|
|
559
641
|
:return: A list representing the conversation history with each turn's content.
|
|
560
|
-
:rtype: List[Dict[str, str]]
|
|
642
|
+
:rtype: List[Dict[str, Optional[str]]]
|
|
561
643
|
"""
|
|
562
644
|
conversation_history = ConversationHistory()
|
|
563
|
-
# user_turn = Turn(role=ConversationRole.USER, content=conversation_starter)
|
|
564
|
-
# conversation_history.add_to_history(user_turn)
|
|
565
645
|
|
|
566
646
|
while len(conversation_history) < max_conversation_turns:
|
|
567
647
|
user_flow = self._load_user_simulation_flow(
|
|
568
|
-
user_simulator_prompty=user_simulator_prompty,
|
|
569
|
-
prompty_model_config=self.
|
|
648
|
+
user_simulator_prompty=user_simulator_prompty, # type: ignore
|
|
649
|
+
prompty_model_config=self.model_config, # type: ignore
|
|
570
650
|
user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
|
|
571
651
|
)
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
652
|
+
if len(conversation_history) == 0:
|
|
653
|
+
conversation_starter_from_simulated_user = await user_flow(
|
|
654
|
+
task=task,
|
|
655
|
+
conversation_history=[
|
|
656
|
+
{
|
|
657
|
+
"role": "assistant",
|
|
658
|
+
"content": conversation_starter,
|
|
659
|
+
}
|
|
660
|
+
],
|
|
661
|
+
action="rewrite the assistant's message as you have to accomplish the task by asking the right questions. Make sure the original question is not lost in your rewrite.",
|
|
662
|
+
)
|
|
663
|
+
else:
|
|
664
|
+
conversation_starter_from_simulated_user = await user_flow(
|
|
665
|
+
task=task,
|
|
666
|
+
conversation_history=conversation_history.to_context_free_list(),
|
|
667
|
+
action="Your goal is to make sure the task is completed by asking the right questions. Do not ask the same questions again.",
|
|
668
|
+
)
|
|
582
669
|
if isinstance(conversation_starter_from_simulated_user, dict):
|
|
583
670
|
conversation_starter_from_simulated_user = conversation_starter_from_simulated_user["content"]
|
|
584
671
|
user_turn = Turn(role=ConversationRole.USER, content=conversation_starter_from_simulated_user)
|
|
585
672
|
conversation_history.add_to_history(user_turn)
|
|
586
|
-
assistant_response = await self._get_target_response(
|
|
673
|
+
assistant_response, assistant_context = await self._get_target_response(
|
|
587
674
|
target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=conversation_history
|
|
588
675
|
)
|
|
589
|
-
assistant_turn = Turn(
|
|
676
|
+
assistant_turn = Turn(
|
|
677
|
+
role=ConversationRole.ASSISTANT, content=assistant_response, context=assistant_context
|
|
678
|
+
)
|
|
590
679
|
conversation_history.add_to_history(assistant_turn)
|
|
591
680
|
progress_bar.update(1)
|
|
592
681
|
|
|
@@ -595,45 +684,9 @@ class Simulator:
|
|
|
595
684
|
|
|
596
685
|
return conversation_history.to_list()
|
|
597
686
|
|
|
598
|
-
async def _build_user_simulation_response(
|
|
599
|
-
self,
|
|
600
|
-
task: str,
|
|
601
|
-
conversation_history: List[Dict[str, Any]],
|
|
602
|
-
user_simulator_prompty: Optional[str],
|
|
603
|
-
user_simulator_prompty_kwargs: Dict[str, Any],
|
|
604
|
-
) -> str:
|
|
605
|
-
"""
|
|
606
|
-
Builds a response from the user simulator based on the current conversation history.
|
|
607
|
-
|
|
608
|
-
:param task: A string representing the task details.
|
|
609
|
-
:type task: str
|
|
610
|
-
:param conversation_history: The current conversation history as a list of dictionaries.
|
|
611
|
-
:type conversation_history: List[Dict[str, Any]]
|
|
612
|
-
:param user_simulator_prompty: Path to the user simulator prompty file.
|
|
613
|
-
:type user_simulator_prompty: Optional[str]
|
|
614
|
-
:param user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
|
|
615
|
-
:type user_simulator_prompty_kwargs: Dict[str, Any]
|
|
616
|
-
:return: The generated response content from the user simulator.
|
|
617
|
-
:rtype: str
|
|
618
|
-
:raises RuntimeError: If an error occurs during response generation.
|
|
619
|
-
"""
|
|
620
|
-
user_flow = self._load_user_simulation_flow(
|
|
621
|
-
user_simulator_prompty=user_simulator_prompty,
|
|
622
|
-
prompty_model_config=self._build_prompty_model_config(),
|
|
623
|
-
user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
|
|
624
|
-
)
|
|
625
|
-
try:
|
|
626
|
-
response_content = user_flow(
|
|
627
|
-
task=task, conversation_history=conversation_history, **user_simulator_prompty_kwargs
|
|
628
|
-
)
|
|
629
|
-
user_response = self._parse_prompty_response(response=response_content)
|
|
630
|
-
return user_response["content"]
|
|
631
|
-
except Exception as e:
|
|
632
|
-
raise RuntimeError("Error building user simulation response") from e
|
|
633
|
-
|
|
634
687
|
async def _get_target_response(
|
|
635
688
|
self, *, target: Callable, api_call_delay_sec: float, conversation_history: ConversationHistory
|
|
636
|
-
) -> str:
|
|
689
|
+
) -> Tuple[str, Optional[str]]:
|
|
637
690
|
"""
|
|
638
691
|
Retrieves the response from the target callback based on the current conversation history.
|
|
639
692
|
|
|
@@ -643,8 +696,8 @@ class Simulator:
|
|
|
643
696
|
:paramtype api_call_delay_sec: float
|
|
644
697
|
:keyword conversation_history: The current conversation history.
|
|
645
698
|
:paramtype conversation_history: ConversationHistory
|
|
646
|
-
:return: The content of the response from the target.
|
|
647
|
-
:rtype: str
|
|
699
|
+
:return: The content of the response from the target and an optional context.
|
|
700
|
+
:rtype: str, Optional[str]
|
|
648
701
|
"""
|
|
649
702
|
response = await target(
|
|
650
703
|
messages={"messages": conversation_history.to_list()},
|
|
@@ -654,4 +707,4 @@ class Simulator:
|
|
|
654
707
|
)
|
|
655
708
|
await asyncio.sleep(api_call_delay_sec)
|
|
656
709
|
latest_message = response["messages"][-1]
|
|
657
|
-
return latest_message["content"]
|
|
710
|
+
return latest_message["content"], latest_message.get("context", "") # type: ignore
|