azure-ai-evaluation 0.0.0b0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
- azure/ai/evaluation/__init__.py +82 -0
- azure/ai/evaluation/_common/__init__.py +16 -0
- azure/ai/evaluation/_common/_experimental.py +172 -0
- azure/ai/evaluation/_common/constants.py +72 -0
- azure/ai/evaluation/_common/math.py +89 -0
- azure/ai/evaluation/_common/rai_service.py +632 -0
- azure/ai/evaluation/_common/utils.py +445 -0
- azure/ai/evaluation/_constants.py +72 -0
- azure/ai/evaluation/_evaluate/__init__.py +3 -0
- azure/ai/evaluation/_evaluate/_batch_run/__init__.py +9 -0
- azure/ai/evaluation/_evaluate/_batch_run/code_client.py +188 -0
- azure/ai/evaluation/_evaluate/_batch_run/eval_run_context.py +89 -0
- azure/ai/evaluation/_evaluate/_batch_run/proxy_client.py +99 -0
- azure/ai/evaluation/_evaluate/_batch_run/target_run_context.py +46 -0
- azure/ai/evaluation/_evaluate/_eval_run.py +571 -0
- azure/ai/evaluation/_evaluate/_evaluate.py +850 -0
- azure/ai/evaluation/_evaluate/_telemetry/__init__.py +179 -0
- azure/ai/evaluation/_evaluate/_utils.py +298 -0
- azure/ai/evaluation/_evaluators/__init__.py +3 -0
- azure/ai/evaluation/_evaluators/_bleu/__init__.py +9 -0
- azure/ai/evaluation/_evaluators/_bleu/_bleu.py +72 -0
- azure/ai/evaluation/_evaluators/_coherence/__init__.py +7 -0
- azure/ai/evaluation/_evaluators/_coherence/_coherence.py +107 -0
- azure/ai/evaluation/_evaluators/_coherence/coherence.prompty +99 -0
- azure/ai/evaluation/_evaluators/_common/__init__.py +13 -0
- azure/ai/evaluation/_evaluators/_common/_base_eval.py +344 -0
- azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py +88 -0
- azure/ai/evaluation/_evaluators/_common/_base_rai_svc_eval.py +133 -0
- azure/ai/evaluation/_evaluators/_content_safety/__init__.py +17 -0
- azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py +144 -0
- azure/ai/evaluation/_evaluators/_content_safety/_hate_unfairness.py +129 -0
- azure/ai/evaluation/_evaluators/_content_safety/_self_harm.py +123 -0
- azure/ai/evaluation/_evaluators/_content_safety/_sexual.py +125 -0
- azure/ai/evaluation/_evaluators/_content_safety/_violence.py +126 -0
- azure/ai/evaluation/_evaluators/_eci/__init__.py +0 -0
- azure/ai/evaluation/_evaluators/_eci/_eci.py +89 -0
- azure/ai/evaluation/_evaluators/_f1_score/__init__.py +9 -0
- azure/ai/evaluation/_evaluators/_f1_score/_f1_score.py +157 -0
- azure/ai/evaluation/_evaluators/_fluency/__init__.py +9 -0
- azure/ai/evaluation/_evaluators/_fluency/_fluency.py +104 -0
- azure/ai/evaluation/_evaluators/_fluency/fluency.prompty +86 -0
- azure/ai/evaluation/_evaluators/_gleu/__init__.py +9 -0
- azure/ai/evaluation/_evaluators/_gleu/_gleu.py +69 -0
- azure/ai/evaluation/_evaluators/_groundedness/__init__.py +9 -0
- azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py +144 -0
- azure/ai/evaluation/_evaluators/_groundedness/groundedness_with_query.prompty +113 -0
- azure/ai/evaluation/_evaluators/_groundedness/groundedness_without_query.prompty +99 -0
- azure/ai/evaluation/_evaluators/_meteor/__init__.py +9 -0
- azure/ai/evaluation/_evaluators/_meteor/_meteor.py +90 -0
- azure/ai/evaluation/_evaluators/_multimodal/__init__.py +20 -0
- azure/ai/evaluation/_evaluators/_multimodal/_content_safety_multimodal.py +132 -0
- azure/ai/evaluation/_evaluators/_multimodal/_content_safety_multimodal_base.py +55 -0
- azure/ai/evaluation/_evaluators/_multimodal/_hate_unfairness.py +100 -0
- azure/ai/evaluation/_evaluators/_multimodal/_protected_material.py +124 -0
- azure/ai/evaluation/_evaluators/_multimodal/_self_harm.py +100 -0
- azure/ai/evaluation/_evaluators/_multimodal/_sexual.py +100 -0
- azure/ai/evaluation/_evaluators/_multimodal/_violence.py +100 -0
- azure/ai/evaluation/_evaluators/_protected_material/__init__.py +5 -0
- azure/ai/evaluation/_evaluators/_protected_material/_protected_material.py +113 -0
- azure/ai/evaluation/_evaluators/_qa/__init__.py +9 -0
- azure/ai/evaluation/_evaluators/_qa/_qa.py +93 -0
- azure/ai/evaluation/_evaluators/_relevance/__init__.py +9 -0
- azure/ai/evaluation/_evaluators/_relevance/_relevance.py +114 -0
- azure/ai/evaluation/_evaluators/_relevance/relevance.prompty +100 -0
- azure/ai/evaluation/_evaluators/_retrieval/__init__.py +9 -0
- azure/ai/evaluation/_evaluators/_retrieval/_retrieval.py +112 -0
- azure/ai/evaluation/_evaluators/_retrieval/retrieval.prompty +93 -0
- azure/ai/evaluation/_evaluators/_rouge/__init__.py +10 -0
- azure/ai/evaluation/_evaluators/_rouge/_rouge.py +98 -0
- azure/ai/evaluation/_evaluators/_service_groundedness/__init__.py +9 -0
- azure/ai/evaluation/_evaluators/_service_groundedness/_service_groundedness.py +148 -0
- azure/ai/evaluation/_evaluators/_similarity/__init__.py +9 -0
- azure/ai/evaluation/_evaluators/_similarity/_similarity.py +140 -0
- azure/ai/evaluation/_evaluators/_similarity/similarity.prompty +66 -0
- azure/ai/evaluation/_evaluators/_xpia/__init__.py +5 -0
- azure/ai/evaluation/_evaluators/_xpia/xpia.py +125 -0
- azure/ai/evaluation/_exceptions.py +128 -0
- azure/ai/evaluation/_http_utils.py +466 -0
- azure/ai/evaluation/_model_configurations.py +123 -0
- azure/ai/evaluation/_user_agent.py +6 -0
- azure/ai/evaluation/_vendor/__init__.py +3 -0
- azure/ai/evaluation/_vendor/rouge_score/__init__.py +14 -0
- azure/ai/evaluation/_vendor/rouge_score/rouge_scorer.py +328 -0
- azure/ai/evaluation/_vendor/rouge_score/scoring.py +63 -0
- azure/ai/evaluation/_vendor/rouge_score/tokenize.py +63 -0
- azure/ai/evaluation/_vendor/rouge_score/tokenizers.py +53 -0
- azure/ai/evaluation/_version.py +5 -0
- azure/ai/evaluation/py.typed +0 -0
- azure/ai/evaluation/simulator/__init__.py +16 -0
- azure/ai/evaluation/simulator/_adversarial_scenario.py +46 -0
- azure/ai/evaluation/simulator/_adversarial_simulator.py +471 -0
- azure/ai/evaluation/simulator/_constants.py +27 -0
- azure/ai/evaluation/simulator/_conversation/__init__.py +316 -0
- azure/ai/evaluation/simulator/_conversation/_conversation.py +178 -0
- azure/ai/evaluation/simulator/_conversation/constants.py +30 -0
- azure/ai/evaluation/simulator/_data_sources/__init__.py +3 -0
- azure/ai/evaluation/simulator/_data_sources/grounding.json +1150 -0
- azure/ai/evaluation/simulator/_direct_attack_simulator.py +218 -0
- azure/ai/evaluation/simulator/_helpers/__init__.py +4 -0
- azure/ai/evaluation/simulator/_helpers/_language_suffix_mapping.py +17 -0
- azure/ai/evaluation/simulator/_helpers/_simulator_data_classes.py +96 -0
- azure/ai/evaluation/simulator/_indirect_attack_simulator.py +220 -0
- azure/ai/evaluation/simulator/_model_tools/__init__.py +23 -0
- azure/ai/evaluation/simulator/_model_tools/_identity_manager.py +195 -0
- azure/ai/evaluation/simulator/_model_tools/_proxy_completion_model.py +244 -0
- azure/ai/evaluation/simulator/_model_tools/_rai_client.py +168 -0
- azure/ai/evaluation/simulator/_model_tools/_template_handler.py +201 -0
- azure/ai/evaluation/simulator/_model_tools/models.py +614 -0
- azure/ai/evaluation/simulator/_prompty/__init__.py +0 -0
- azure/ai/evaluation/simulator/_prompty/task_query_response.prompty +65 -0
- azure/ai/evaluation/simulator/_prompty/task_simulate.prompty +37 -0
- azure/ai/evaluation/simulator/_simulator.py +716 -0
- azure/ai/evaluation/simulator/_tracing.py +89 -0
- azure/ai/evaluation/simulator/_utils.py +132 -0
- azure_ai_evaluation-1.0.0.dist-info/METADATA +595 -0
- azure_ai_evaluation-1.0.0.dist-info/NOTICE.txt +70 -0
- azure_ai_evaluation-1.0.0.dist-info/RECORD +119 -0
- {azure_ai_evaluation-0.0.0b0.dist-info → azure_ai_evaluation-1.0.0.dist-info}/WHEEL +1 -1
- azure_ai_evaluation-1.0.0.dist-info/top_level.txt +1 -0
- azure_ai_evaluation-0.0.0b0.dist-info/METADATA +0 -7
- azure_ai_evaluation-0.0.0b0.dist-info/RECORD +0 -4
- azure_ai_evaluation-0.0.0b0.dist-info/top_level.txt +0 -1
azure/ai/evaluation/simulator/_simulator.py (new file, +716 lines)

@@ -0,0 +1,716 @@
# flake8: noqa
# pylint: disable=W0102,W0613,R0914,C0301,E0401,E0611,C0114,R0913,E0702,R0903,C0411
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
import asyncio
import importlib.resources as pkg_resources
import json
import os
import re
import warnings
from typing import Any, Callable, Dict, List, Optional, Union, Tuple

from promptflow.core import AsyncPrompty
from tqdm import tqdm

from azure.ai.evaluation._common._experimental import experimental
from azure.ai.evaluation._common.utils import construct_prompty_model_config
from azure.ai.evaluation._model_configurations import AzureOpenAIModelConfiguration, OpenAIModelConfiguration

from .._exceptions import ErrorBlame, ErrorCategory, EvaluationException
from .._user_agent import USER_AGENT
from ._conversation.constants import ConversationRole
from ._helpers import ConversationHistory, Turn
from ._utils import JsonLineChatProtocol

@experimental
class Simulator:
    """
    Simulator for generating synthetic conversations.

    :param model_config: A dictionary defining the configuration for the model. Acceptable types are AzureOpenAIModelConfiguration and OpenAIModelConfiguration.
    :type model_config: Union[~azure.ai.evaluation.AzureOpenAIModelConfiguration, ~azure.ai.evaluation.OpenAIModelConfiguration]
    :raises ValueError: If the model_config does not contain the required keys or any value is None.

    .. admonition:: Example:

        .. literalinclude:: ../samples/evaluation_samples_simulate.py
            :start-after: [START nonadversarial_simulator]
            :end-before: [END nonadversarial_simulator]
            :language: python
            :dedent: 8
            :caption: Run a Simulator for 2 queries and 4 conversation turns.
    """

    def __init__(self, model_config: Union[AzureOpenAIModelConfiguration, OpenAIModelConfiguration]):
        self._validate_model_config(model_config)
        self.model_config = model_config
        if "api_version" not in self.model_config:
            self.model_config["api_version"] = "2024-06-01"  # type: ignore
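
For reference, a minimal construction sketch; the endpoint and deployment values are placeholders, not real resources:

```python
from azure.ai.evaluation.simulator import Simulator

# Placeholder Azure OpenAI settings; substitute your own resource values.
model_config = {
    "azure_endpoint": "https://<your-resource>.openai.azure.com",
    "azure_deployment": "<your-deployment-name>",
    # "type" and "api_version" may be omitted: _validate_model_config infers
    # "azure_openai" from these keys, and __init__ defaults api_version to "2024-06-01".
}

simulator = Simulator(model_config=model_config)
```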

    @staticmethod
    def _validate_model_config(model_config: Any):
        """
        Validates the model_config to ensure all required keys are present and have non-None values.
        If 'type' is not specified, it will attempt to infer the type based on the keys present.

        :param model_config: The model configuration dictionary.
        :type model_config: Dict[str, Any]
        :raises ValueError: If required keys are missing or any of the values are None.
        """
        # Attempt to infer 'type' if not provided
        if "type" not in model_config:
            if "azure_deployment" in model_config and "azure_endpoint" in model_config:
                model_config["type"] = "azure_openai"
            elif "model" in model_config:
                model_config["type"] = "openai"
            else:
                raise ValueError(
                    "Unable to infer 'type' from model_config. Please specify 'type' as 'azure_openai' or 'openai'."
                )

        if model_config["type"] == "azure_openai":
            required_keys = ["azure_deployment", "azure_endpoint"]
        elif model_config["type"] == "openai":
            required_keys = ["api_key", "model"]
        else:
            raise ValueError("model_config 'type' must be 'azure_openai' or 'openai'.")

        missing_keys = [key for key in required_keys if key not in model_config]
        if missing_keys:
            raise ValueError(f"model_config is missing required keys: {', '.join(missing_keys)}")
        none_keys = [key for key in required_keys if model_config.get(key) is None]
        if none_keys:
            raise ValueError(f"The following keys in model_config must not be None: {', '.join(none_keys)}")
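
A short sketch of the inference branch above, with illustrative values:

```python
from azure.ai.evaluation.simulator import Simulator

# OpenAI-style config: "type" is inferred as "openai" because "model" is present.
openai_config = {"api_key": "<key>", "model": "gpt-4o"}  # placeholder values
Simulator._validate_model_config(openai_config)
assert openai_config["type"] == "openai"

# Neither the Azure keys nor "model" is present: inference fails with a ValueError.
try:
    Simulator._validate_model_config({"api_key": "<key>"})
except ValueError as exc:
    print(exc)  # "Unable to infer 'type' from model_config. ..."
```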

    async def __call__(
        self,
        *,
        target: Callable,
        max_conversation_turns: int = 5,
        tasks: List[str] = [],
        text: str = "",
        num_queries: int = 5,
        query_response_generating_prompty: Optional[str] = None,
        user_simulator_prompty: Optional[str] = None,
        api_call_delay_sec: float = 1,
        query_response_generating_prompty_options: Dict[str, Any] = {},
        user_simulator_prompty_options: Dict[str, Any] = {},
        conversation_turns: List[List[Union[str, Dict[str, Any]]]] = [],
        concurrent_async_tasks: int = 5,
        **kwargs,
    ) -> List[JsonLineChatProtocol]:
        """
        Generates synthetic conversations based on provided parameters.

        :keyword target: The target function to call during the simulation.
        :paramtype target: Callable
        :keyword max_conversation_turns: Maximum number of conversation turns for the simulation. Each turn consists of a user and an assistant message.
        :paramtype max_conversation_turns: int
        :keyword tasks: A list of user tasks, each represented as a string. Text should be relevant for the tasks and facilitate the simulation. One example is to use text to provide context for the tasks.
        :paramtype tasks: List[str]
        :keyword text: The initial input text for generating query responses. Given that the same 'text' is provided for a list of tasks, one example use is to break down a user task into sub-tasks that can share the 'text' variable for context.
        :paramtype text: str
        :keyword num_queries: The number of queries to generate.
        :paramtype num_queries: int
        :keyword query_response_generating_prompty: Path to the query response generating prompty file.
        :paramtype query_response_generating_prompty: Optional[str]
        :keyword user_simulator_prompty: Path to the user simulator prompty file.
        :paramtype user_simulator_prompty: Optional[str]
        :keyword api_call_delay_sec: Delay in seconds between API calls.
        :paramtype api_call_delay_sec: float
        :keyword query_response_generating_prompty_options: Additional keyword arguments for the query response generating prompty.
        :paramtype query_response_generating_prompty_options: Dict[str, Any]
        :keyword user_simulator_prompty_options: Additional keyword arguments for the user simulator prompty.
        :paramtype user_simulator_prompty_options: Dict[str, Any]
        :keyword conversation_turns: Predefined conversation turns to simulate.
        :paramtype conversation_turns: List[List[Union[str, Dict[str, Any]]]]
        :keyword concurrent_async_tasks: The number of asynchronous tasks to run concurrently during the simulation.
            Defaults to 5.
        :paramtype concurrent_async_tasks: int
        :return: A list of simulated conversations represented as JsonLineChatProtocol objects.
        :rtype: List[JsonLineChatProtocol]

        Return Value:
        The method returns a list of JsonLineChatProtocol objects, which are essentially a list of dictionaries where the dictionary contains the messages and context. Context includes all the metadata related to the conversation, such as the task, expected response, and query. The messages contain the conversation history, including the user and assistant messages.

        Modes:
        - Task-Free Mode: When only num_queries is specified and tasks is not, the method generates num_queries x max_conversation_turns lines of simulated data grounded in the context of the text.
        - Task-Specific Mode: When both num_queries and tasks are specified, the method generates lines of simulated data based on the tasks. If num_queries > len(tasks), the remaining lines will be simulated in task-free mode. If num_queries < len(tasks), only the first num_queries tasks are used.
        - Conversation Starter Mode: When conversation_turns are specified, the method starts each conversation with the user-specified queries and then follows the conversation history for the remaining turns.
        """
        if conversation_turns and (text or tasks):
            raise ValueError("Cannot specify both conversation_turns and text/tasks")

        if text and num_queries > len(tasks):
            warnings.warn(
                f"You have specified 'num_queries' > len('tasks') ({num_queries} > {len(tasks)}). "
                f"All tasks will be used for generation and the remaining {num_queries - len(tasks)} lines will be simulated in task-free mode"
            )
        elif text and num_queries < len(tasks):
            warnings.warn(
                f"You have specified 'num_queries' < len('tasks') ({num_queries} < {len(tasks)}). "
                f"Only the first {num_queries} lines of the specified tasks will be simulated."
            )
            num_queries = min(num_queries, len(tasks))
        max_conversation_turns *= 2  # account for both user and assistant turns

        prompty_model_config = self.model_config
        if conversation_turns:
            return await self._simulate_with_predefined_turns(
                target=target,
                max_conversation_turns=max_conversation_turns,
                conversation_turns=conversation_turns,
                user_simulator_prompty=user_simulator_prompty,
                user_simulator_prompty_options=user_simulator_prompty_options,
                api_call_delay_sec=api_call_delay_sec,
                prompty_model_config=prompty_model_config,
                concurrent_async_tasks=concurrent_async_tasks,
            )

        query_responses = await self._generate_query_responses(
            text=text,
            num_queries=num_queries,
            query_response_generating_prompty=query_response_generating_prompty,
            query_response_generating_prompty_options=query_response_generating_prompty_options,
            prompty_model_config=prompty_model_config,
            **kwargs,
        )
        return await self._create_conversations_from_query_responses(
            query_responses=query_responses,
            max_conversation_turns=max_conversation_turns,
            tasks=tasks,
            user_simulator_prompty=user_simulator_prompty,
            user_simulator_prompty_options=user_simulator_prompty_options,
            target=target,
            api_call_delay_sec=api_call_delay_sec,
            text=text,
        )
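
A hedged usage sketch of task-specific mode; `simulator` is the instance built earlier, and `my_app_target` is a hypothetical callback (a compatible sketch follows _get_target_response at the end of the file):

```python
import asyncio

async def main():
    outputs = await simulator(
        target=my_app_target,                                 # hypothetical async callback
        text="Leonardo da Vinci was an Italian polymath...",  # grounding text
        tasks=["Ask about his inventions", "Ask about his paintings"],
        num_queries=2,             # one query per task here (task-specific mode)
        max_conversation_turns=4,  # user/assistant pairs; doubled internally
    )
    for line in outputs:
        print(line)  # each line is a JsonLineChatProtocol dict

asyncio.run(main())
```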

    async def _simulate_with_predefined_turns(
        self,
        *,
        target: Callable,
        max_conversation_turns: int,
        conversation_turns: List[List[Union[str, Dict[str, Any]]]],
        user_simulator_prompty: Optional[str],
        user_simulator_prompty_options: Dict[str, Any],
        api_call_delay_sec: float,
        prompty_model_config: Any,
        concurrent_async_tasks: int,
    ) -> List[JsonLineChatProtocol]:
        """
        Simulates conversations using predefined conversation turns.

        :keyword target: The target function to call during each turn of the simulation.
        :paramtype target: Callable
        :keyword max_conversation_turns: Maximum number of turns for the simulation.
        :paramtype max_conversation_turns: int
        :keyword conversation_turns: A list of predefined conversation turns.
        :paramtype conversation_turns: List[List[Union[str, Dict[str, Any]]]]
        :keyword user_simulator_prompty: Path to the user simulator prompty file.
        :paramtype user_simulator_prompty: Optional[str]
        :keyword user_simulator_prompty_options: Additional keyword arguments for the user simulator prompty.
        :paramtype user_simulator_prompty_options: Dict[str, Any]
        :keyword api_call_delay_sec: Delay in seconds between API calls.
        :paramtype api_call_delay_sec: float
        :keyword prompty_model_config: The configuration for the prompty model.
        :paramtype prompty_model_config: Any
        :keyword concurrent_async_tasks: The number of asynchronous tasks to run concurrently during the simulation.
        :paramtype concurrent_async_tasks: int
        :return: A list of simulated conversations represented as JsonLineChatProtocol objects.
        :rtype: List[JsonLineChatProtocol]
        """
        progress_bar = tqdm(
            total=int(len(conversation_turns) * (max_conversation_turns / 2)),
            desc="Simulating with predefined conversation turns: ",
            ncols=100,
            unit="messages",
        )
        semaphore = asyncio.Semaphore(concurrent_async_tasks)
        progress_bar_lock = asyncio.Lock()

        async def run_simulation(simulation: List[Union[str, Dict[str, Any]]]) -> JsonLineChatProtocol:
            async with semaphore:
                current_simulation = ConversationHistory()
                for simulated_turn in simulation:
                    if isinstance(simulated_turn, str):
                        user_turn = Turn(role=ConversationRole.USER, content=simulated_turn)
                    elif isinstance(simulated_turn, dict):
                        user_turn = Turn(
                            role=ConversationRole.USER,
                            content=str(simulated_turn.get("content")),
                            context=str(simulated_turn.get("context")),
                        )
                    else:
                        raise ValueError(
                            "Each simulated turn must be a string or a dict with 'content' and 'context' keys"
                        )
                    current_simulation.add_to_history(user_turn)
                    assistant_response, assistant_context = await self._get_target_response(
                        target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=current_simulation
                    )
                    assistant_turn = Turn(
                        role=ConversationRole.ASSISTANT, content=assistant_response, context=assistant_context
                    )
                    current_simulation.add_to_history(assistant_turn)
                    async with progress_bar_lock:
                        progress_bar.update(1)

                if len(current_simulation) < max_conversation_turns:
                    await self._extend_conversation_with_simulator(
                        current_simulation=current_simulation,
                        max_conversation_turns=max_conversation_turns,
                        user_simulator_prompty=user_simulator_prompty,
                        user_simulator_prompty_options=user_simulator_prompty_options,
                        api_call_delay_sec=api_call_delay_sec,
                        prompty_model_config=prompty_model_config,
                        target=target,
                        progress_bar=progress_bar,
                        progress_bar_lock=progress_bar_lock,
                    )
                return JsonLineChatProtocol(
                    {
                        "messages": current_simulation.to_list(),
                        "finish_reason": ["stop"],
                        "context": {},
                        "$schema": "http://azureml/sdk-2-0/ChatConversation.json",
                    }
                )

        tasks = [asyncio.create_task(run_simulation(simulation)) for simulation in conversation_turns]
        results = await asyncio.gather(*tasks)
        progress_bar.close()
        return results
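
Per the type checks in run_simulation, each predefined turn is either a plain string or a dict with content and context keys; a sketch of the payload:

```python
conversation_turns = [
    # One conversation: a plain-string opener, then a dict turn carrying context.
    [
        "What is the return policy?",
        {"content": "Does it cover opened items?", "context": "Customer bought headphones."},
    ],
    # A second conversation with a single scripted opener.
    ["How do I reset my password?"],
]

# Hypothetical invocation; conversation-starter mode excludes text/tasks.
# outputs = await simulator(target=my_app_target, conversation_turns=conversation_turns)
```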

    async def _extend_conversation_with_simulator(
        self,
        *,
        current_simulation: ConversationHistory,
        max_conversation_turns: int,
        user_simulator_prompty: Optional[str],
        user_simulator_prompty_options: Dict[str, Any],
        api_call_delay_sec: float,
        prompty_model_config: Dict[str, Any],
        target: Callable,
        progress_bar: tqdm,
        progress_bar_lock: asyncio.Lock,
    ):
        """
        Extends an ongoing conversation using a user simulator until the maximum number of turns is reached.

        :keyword current_simulation: The current state of the conversation history.
        :paramtype current_simulation: ConversationHistory
        :keyword max_conversation_turns: The maximum number of conversation turns.
        :paramtype max_conversation_turns: int
        :keyword user_simulator_prompty: Path to the user simulator prompty file.
        :paramtype user_simulator_prompty: Optional[str]
        :keyword user_simulator_prompty_options: Additional keyword arguments for the user simulator prompty.
        :paramtype user_simulator_prompty_options: Dict[str, Any]
        :keyword api_call_delay_sec: Delay in seconds between API calls.
        :paramtype api_call_delay_sec: float
        :keyword prompty_model_config: The configuration for the prompty model.
        :paramtype prompty_model_config: Dict[str, Any]
        :keyword target: The target function to call for responses.
        :paramtype target: Callable
        :keyword progress_bar: Progress bar for tracking simulation progress.
        :paramtype progress_bar: tqdm
        :keyword progress_bar_lock: Lock for updating the progress bar safely.
        :paramtype progress_bar_lock: asyncio.Lock
        """
        user_flow = self._load_user_simulation_flow(
            user_simulator_prompty=user_simulator_prompty,  # type: ignore
            prompty_model_config=prompty_model_config,
            user_simulator_prompty_options=user_simulator_prompty_options,
        )

        while len(current_simulation) < max_conversation_turns:
            user_response_content = await user_flow(
                task="Continue the conversation",
                conversation_history=current_simulation.to_context_free_list(),
                **user_simulator_prompty_options,
            )
            user_response = self._parse_prompty_response(response=user_response_content)
            user_turn = Turn(role=ConversationRole.USER, content=user_response["content"])
            current_simulation.add_to_history(user_turn)
            await asyncio.sleep(api_call_delay_sec)
            assistant_response, assistant_context = await self._get_target_response(
                target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=current_simulation
            )
            assistant_turn = Turn(
                role=ConversationRole.ASSISTANT, content=assistant_response, context=assistant_context
            )
            current_simulation.add_to_history(assistant_turn)
            async with progress_bar_lock:
                progress_bar.update(1)
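
ConversationHistory.to_context_free_list is defined in ._helpers and not shown in this diff; judging from the first-turn branch of _complete_conversation below, which hand-builds the same argument, the user flow receives plain role/content messages:

```python
# Assumed shape only; the helper itself lives in azure/ai/evaluation/simulator/_helpers.
conversation_history = [
    {"role": "user", "content": "What is the return policy?"},
    {"role": "assistant", "content": "Items can be returned within 30 days."},
]
```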

    def _load_user_simulation_flow(
        self,
        *,
        user_simulator_prompty: Optional[Union[str, os.PathLike]],
        prompty_model_config: Dict[str, Any],
        user_simulator_prompty_options: Dict[str, Any],
    ) -> "AsyncPrompty":  # type: ignore
        """
        Loads the flow for simulating user interactions.

        :keyword user_simulator_prompty: Path to the user simulator prompty file.
        :paramtype user_simulator_prompty: Optional[Union[str, os.PathLike]]
        :keyword prompty_model_config: The configuration for the prompty model.
        :paramtype prompty_model_config: Dict[str, Any]
        :keyword user_simulator_prompty_options: Additional keyword arguments for the user simulator prompty.
        :paramtype user_simulator_prompty_options: Dict[str, Any]
        :return: The loaded flow for simulating user interactions.
        :rtype: AsyncPrompty
        """
        if not user_simulator_prompty:
            package = "azure.ai.evaluation.simulator._prompty"
            resource_name = "task_simulate.prompty"
            try:
                # Access the resource as a file path
                # pylint: disable=deprecated-method
                with pkg_resources.path(package, resource_name) as prompty_path:
                    prompty_model_config = construct_prompty_model_config(
                        model_config=prompty_model_config,  # type: ignore
                        default_api_version="2024-06-01",
                        user_agent=USER_AGENT,
                    )
                    return AsyncPrompty.load(source=prompty_path, model=prompty_model_config)  # type: ignore
            except FileNotFoundError as e:
                msg = f"Flow path for {resource_name} does not exist in package {package}."
                raise EvaluationException(
                    message=msg,
                    internal_message=msg,
                    error_category=ErrorCategory.FILE_OR_FOLDER_NOT_FOUND,
                    blame=ErrorBlame.USER_ERROR,
                ) from e
        prompty_model_config = construct_prompty_model_config(
            model_config=prompty_model_config,  # type: ignore
            default_api_version="2024-06-01",
            user_agent=USER_AGENT,
        )
        return AsyncPrompty.load(
            source=user_simulator_prompty,
            model=prompty_model_config,
            **user_simulator_prompty_options,
        )  # type: ignore
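
To swap in a custom user-simulator prompty, pass a path and, optionally, extra load options; the path and the temperature option below are illustrative assumptions, not a documented contract:

```python
outputs = await simulator(
    target=my_app_target,                                  # hypothetical async callback
    text="...",
    num_queries=1,
    user_simulator_prompty="prompts/my_user_sim.prompty",  # placeholder path
    user_simulator_prompty_options={"temperature": 0.7},   # assumed AsyncPrompty.load option
)
```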

    def _parse_prompty_response(self, *, response: str) -> Dict[str, Any]:
        """
        Parses the response from the prompty execution.

        :keyword response: The raw response from the prompty.
        :paramtype response: str
        :return: A dictionary representing the parsed response content.
        :rtype: Dict[str, Any]
        :raises ValueError: If the response cannot be parsed.
        """
        try:
            if isinstance(response, str):
                response = response.replace("\u2019", "'").replace("\u2018", "'")
                response = response.replace("\u201C", '"').replace("\u201D", '"')

                # Replace None with null
                response = response.replace("None", "null")

                # Escape unescaped single quotes inside string values
                def escape_single_quotes(match):
                    s = match.group(0)
                    # Remove the outer single quotes
                    s_content = s[1:-1]
                    # Escape single quotes within the content
                    s_content_escaped = s_content.replace("'", "\\'")
                    return f"'{s_content_escaped}'"

                # Pattern to match single-quoted strings
                pattern = r"'(.*?)'"
                response = re.sub(pattern, escape_single_quotes, response)

                # Now replace single quotes around keys and values with double quotes
                response = re.sub(r"'([^']+)'", r'"\1"', response)
                parsed_data = json.loads(response)
                return parsed_data
            return response
        except Exception as e:
            raise ValueError("Error parsing response content") from e
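
A minimal sketch of the normalization this method performs (illustrative input; the method is private and called here only to demonstrate behavior):

```python
# A Python-style repr with single quotes is coerced into JSON, then parsed.
raw = "{'content': 'Could you summarize the text?'}"
parsed = simulator._parse_prompty_response(response=raw)
assert parsed == {"content": "Could you summarize the text?"}
```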

    async def _generate_query_responses(
        self,
        *,
        text: str,
        num_queries: int,
        query_response_generating_prompty: Optional[str],
        query_response_generating_prompty_options: Dict[str, Any],
        prompty_model_config: Any,
        **kwargs,
    ) -> List[Dict[str, str]]:
        """
        Generates query responses using the specified prompty configuration.

        :keyword text: The input text for generating queries.
        :paramtype text: str
        :keyword num_queries: The number of queries to generate.
        :paramtype num_queries: int
        :keyword query_response_generating_prompty: Path to the query response generating prompty file.
        :paramtype query_response_generating_prompty: Optional[str]
        :keyword query_response_generating_prompty_options: Additional keyword arguments for the query response generating prompty.
        :paramtype query_response_generating_prompty_options: Dict[str, Any]
        :keyword prompty_model_config: The configuration for the prompty model.
        :paramtype prompty_model_config: Any
        :return: A list of query-response dictionaries.
        :rtype: List[Dict[str, str]]
        :raises RuntimeError: If an error occurs during query generation.
        """
        query_flow = self._load_query_generation_flow(
            query_response_generating_prompty=query_response_generating_prompty,  # type: ignore
            prompty_model_config=prompty_model_config,
            query_response_generating_prompty_options=query_response_generating_prompty_options,
        )
        try:
            query_responses = await query_flow(text=text, num_queries=num_queries)
            if isinstance(query_responses, dict):
                keys = list(query_responses.keys())
                return query_responses[keys[0]]
            if isinstance(query_responses, str):
                query_responses = json.loads(query_responses)
                if isinstance(query_responses, dict):
                    if len(query_responses.keys()) == 1:
                        return query_responses[list(query_responses.keys())[0]]
                    return query_responses  # type: ignore
                if isinstance(query_responses, list):
                    return query_responses
                return json.loads(query_responses)
        except Exception as e:
            raise RuntimeError("Error generating query responses") from e
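
Whatever shape the flow returns, it is reduced to a list of q/r pairs, the keys consumed later by _create_conversations_from_query_responses; an illustrative value:

```python
query_responses = [
    {"q": "What inventions is he known for?", "r": "Flying machines and war devices."},
    {"q": "Which paintings are most famous?", "r": "The Mona Lisa and The Last Supper."},
]
```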

    def _load_query_generation_flow(
        self,
        *,
        query_response_generating_prompty: Optional[Union[str, os.PathLike]],
        prompty_model_config: Dict[str, Any],
        query_response_generating_prompty_options: Dict[str, Any],
    ) -> "AsyncPrompty":
        """
        Loads the flow for generating query responses.

        :keyword query_response_generating_prompty: Path to the query response generating prompty file.
        :paramtype query_response_generating_prompty: Optional[Union[str, os.PathLike]]
        :keyword prompty_model_config: The configuration for the prompty model.
        :paramtype prompty_model_config: Dict[str, Any]
        :keyword query_response_generating_prompty_options: Additional keyword arguments for the flow.
        :paramtype query_response_generating_prompty_options: Dict[str, Any]
        :return: The loaded flow for generating query responses.
        :rtype: AsyncPrompty
        """
        if not query_response_generating_prompty:
            package = "azure.ai.evaluation.simulator._prompty"
            resource_name = "task_query_response.prompty"
            try:
                # Access the resource as a file path
                # pylint: disable=deprecated-method
                with pkg_resources.path(package, resource_name) as prompty_path:
                    prompty_model_config = construct_prompty_model_config(
                        model_config=prompty_model_config,  # type: ignore
                        default_api_version="2024-06-01",
                        user_agent=USER_AGENT,
                    )
                    return AsyncPrompty.load(source=prompty_path, model=prompty_model_config)  # type: ignore
            except FileNotFoundError as e:
                msg = f"Flow path for {resource_name} does not exist in package {package}."
                raise EvaluationException(
                    message=msg,
                    internal_message=msg,
                    error_category=ErrorCategory.FILE_OR_FOLDER_NOT_FOUND,
                    blame=ErrorBlame.USER_ERROR,
                ) from e
        prompty_model_config = construct_prompty_model_config(
            model_config=prompty_model_config,  # type: ignore
            default_api_version="2024-06-01",
            user_agent=USER_AGENT,
        )
        return AsyncPrompty.load(
            source=query_response_generating_prompty,
            model=prompty_model_config,
            **query_response_generating_prompty_options,
        )  # type: ignore

    async def _create_conversations_from_query_responses(
        self,
        *,
        query_responses: List[Dict[str, str]],
        max_conversation_turns: int,
        tasks: List[str],
        user_simulator_prompty: Optional[str],
        user_simulator_prompty_options: Dict[str, Any],
        target: Callable,
        api_call_delay_sec: float,
        text: str,
    ) -> List[JsonLineChatProtocol]:
        """
        Creates full conversations from query-response pairs.

        :keyword query_responses: A list of query-response pairs.
        :paramtype query_responses: List[Dict[str, str]]
        :keyword max_conversation_turns: The maximum number of conversation turns.
        :paramtype max_conversation_turns: int
        :keyword tasks: A list of tasks for the simulation.
        :paramtype tasks: List[str]
        :keyword user_simulator_prompty: Path to the user simulator prompty file.
        :paramtype user_simulator_prompty: Optional[str]
        :keyword user_simulator_prompty_options: Additional keyword arguments for the user simulator prompty.
        :paramtype user_simulator_prompty_options: Dict[str, Any]
        :keyword target: The target function to call for responses.
        :paramtype target: Callable
        :keyword api_call_delay_sec: Delay in seconds between API calls.
        :paramtype api_call_delay_sec: float
        :keyword text: The initial input text for generating query responses.
        :paramtype text: str
        :return: A list of simulated conversations represented as JsonLineChatProtocol objects.
        :rtype: List[JsonLineChatProtocol]
        """
        total_turns = len(query_responses) * max_conversation_turns

        progress_bar = tqdm(
            total=int(total_turns / 2),
            desc="Generating: ",
            ncols=100,
            unit="message",
        )
        all_conversations = []

        for i, query_response_pair in enumerate(query_responses):
            query = query_response_pair["q"]
            response = query_response_pair["r"]
            task = tasks[i]

            conversation = await self._complete_conversation(
                conversation_starter=query,
                max_conversation_turns=max_conversation_turns,
                task=task,  # type: ignore
                user_simulator_prompty=user_simulator_prompty,
                user_simulator_prompty_options=user_simulator_prompty_options,
                target=target,
                api_call_delay_sec=api_call_delay_sec,
                progress_bar=progress_bar,
            )
            all_conversations.append(
                JsonLineChatProtocol(
                    {
                        "messages": conversation,
                        "finish_reason": ["stop"],
                        "context": {
                            "task": task,
                            "expected_response": response,
                            "query": query,
                            "original_text": text,
                        },
                        "$schema": "http://azureml/sdk-2-0/ChatConversation.json",
                    }
                )
            )
        progress_bar.close()
        return all_conversations
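
Each JsonLineChatProtocol line assembled above therefore looks roughly like this (values illustrative, content abbreviated):

```python
{
    "messages": [
        {"role": "user", "content": "What inventions is he known for?"},
        {"role": "assistant", "content": "...", "context": "..."},
    ],
    "finish_reason": ["stop"],
    "context": {
        "task": "Ask about his inventions",
        "expected_response": "Flying machines and war devices.",
        "query": "What inventions is he known for?",
        "original_text": "Leonardo da Vinci was an Italian polymath...",
    },
    "$schema": "http://azureml/sdk-2-0/ChatConversation.json",
}
```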

    async def _complete_conversation(
        self,
        *,
        conversation_starter: str,
        max_conversation_turns: int,
        task: str,
        user_simulator_prompty: Optional[str],
        user_simulator_prompty_options: Dict[str, Any],
        target: Callable,
        api_call_delay_sec: float,
        progress_bar: tqdm,
    ) -> List[Dict[str, Optional[str]]]:
        """
        Completes a conversation with the target model based on the conversation starter.

        :keyword conversation_starter: The initial message to start the conversation.
        :paramtype conversation_starter: str
        :keyword max_conversation_turns: The maximum number of turns in the conversation.
        :paramtype max_conversation_turns: int
        :keyword task: A string representing the task details.
        :paramtype task: str
        :keyword user_simulator_prompty: Path to the user simulator prompty file.
        :paramtype user_simulator_prompty: Optional[str]
        :keyword user_simulator_prompty_options: Additional keyword arguments for the user simulator prompty.
        :paramtype user_simulator_prompty_options: Dict[str, Any]
        :keyword target: The target function to call for responses.
        :paramtype target: Callable
        :keyword api_call_delay_sec: Delay in seconds between API calls.
        :paramtype api_call_delay_sec: float
        :keyword progress_bar: Progress bar for tracking simulation progress.
        :paramtype progress_bar: tqdm
        :return: A list representing the conversation history with each turn's content.
        :rtype: List[Dict[str, Optional[str]]]
        """
        conversation_history = ConversationHistory()

        while len(conversation_history) < max_conversation_turns:
            user_flow = self._load_user_simulation_flow(
                user_simulator_prompty=user_simulator_prompty,  # type: ignore
                prompty_model_config=self.model_config,  # type: ignore
                user_simulator_prompty_options=user_simulator_prompty_options,
            )
            if len(conversation_history) == 0:
                conversation_starter_from_simulated_user = await user_flow(
                    task=task,
                    conversation_history=[
                        {
                            "role": "assistant",
                            "content": conversation_starter,
                        }
                    ],
                    action="rewrite the assistant's message as you have to accomplish the task by asking the right questions. Make sure the original question is not lost in your rewrite.",
                )
            else:
                conversation_starter_from_simulated_user = await user_flow(
                    task=task,
                    conversation_history=conversation_history.to_context_free_list(),
                    action="Your goal is to make sure the task is completed by asking the right questions. Do not ask the same questions again.",
                )
            if isinstance(conversation_starter_from_simulated_user, dict):
                conversation_starter_from_simulated_user = conversation_starter_from_simulated_user["content"]
            user_turn = Turn(role=ConversationRole.USER, content=conversation_starter_from_simulated_user)
            conversation_history.add_to_history(user_turn)
            assistant_response, assistant_context = await self._get_target_response(
                target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=conversation_history
            )
            assistant_turn = Turn(
                role=ConversationRole.ASSISTANT, content=assistant_response, context=assistant_context
            )
            conversation_history.add_to_history(assistant_turn)
            progress_bar.update(1)

            if len(conversation_history) >= max_conversation_turns:
                break

        return conversation_history.to_list()

    async def _get_target_response(
        self, *, target: Callable, api_call_delay_sec: float, conversation_history: ConversationHistory
    ) -> Tuple[str, Optional[str]]:
        """
        Retrieves the response from the target callback based on the current conversation history.

        :keyword target: The target function to call for a response.
        :paramtype target: Callable
        :keyword api_call_delay_sec: Delay in seconds before retrieving the response.
        :paramtype api_call_delay_sec: float
        :keyword conversation_history: The current conversation history.
        :paramtype conversation_history: ConversationHistory
        :return: The content of the response from the target and an optional context.
        :rtype: Tuple[str, Optional[str]]
        """
        response = await target(
            messages={"messages": conversation_history.to_list()},
            stream=False,
            session_state=None,
            context=None,
        )
        await asyncio.sleep(api_call_delay_sec)
        latest_message = response["messages"][-1]
        return latest_message["content"], latest_message.get("context", "")  # type: ignore
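
Finally, a hedged sketch of a target callback compatible with the keyword arguments and return shape used above; the echo logic is a stand-in for a real application:

```python
from typing import Any, Dict, List, Optional

async def my_app_target(
    messages: Dict[str, List[Dict[str, Any]]],
    stream: bool = False,
    session_state: Optional[Any] = None,
    context: Optional[Dict[str, Any]] = None,
) -> Dict[str, List[Dict[str, Any]]]:
    # `messages` arrives as {"messages": [...]}; the newest user turn is last.
    history = messages["messages"]
    last_user = history[-1]["content"]
    # Stand-in reply; call your chat application or model here instead.
    history.append({"role": "assistant", "content": f"You asked: {last_user}", "context": ""})
    return {"messages": history}
```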