azure-ai-evaluation: azure_ai_evaluation-1.0.0b4-py3-none-any.whl → azure_ai_evaluation-1.0.0b5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of azure-ai-evaluation might be problematic; see the package's registry page for more details.
- azure/ai/evaluation/__init__.py +22 -0
- azure/ai/evaluation/_common/constants.py +5 -0
- azure/ai/evaluation/_common/math.py +11 -0
- azure/ai/evaluation/_common/rai_service.py +172 -35
- azure/ai/evaluation/_common/utils.py +162 -23
- azure/ai/evaluation/_constants.py +6 -6
- azure/ai/evaluation/_evaluate/{_batch_run_client → _batch_run}/__init__.py +3 -2
- azure/ai/evaluation/_evaluate/{_batch_run_client/batch_run_context.py → _batch_run/eval_run_context.py} +4 -4
- azure/ai/evaluation/_evaluate/{_batch_run_client → _batch_run}/proxy_client.py +6 -3
- azure/ai/evaluation/_evaluate/_batch_run/target_run_context.py +35 -0
- azure/ai/evaluation/_evaluate/_eval_run.py +21 -4
- azure/ai/evaluation/_evaluate/_evaluate.py +267 -139
- azure/ai/evaluation/_evaluate/_telemetry/__init__.py +5 -5
- azure/ai/evaluation/_evaluate/_utils.py +40 -7
- azure/ai/evaluation/_evaluators/_bleu/_bleu.py +1 -1
- azure/ai/evaluation/_evaluators/_coherence/_coherence.py +14 -9
- azure/ai/evaluation/_evaluators/_coherence/coherence.prompty +76 -34
- azure/ai/evaluation/_evaluators/_common/_base_eval.py +20 -19
- azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py +18 -8
- azure/ai/evaluation/_evaluators/_common/_base_rai_svc_eval.py +48 -9
- azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py +56 -19
- azure/ai/evaluation/_evaluators/_content_safety/_content_safety_chat.py +5 -5
- azure/ai/evaluation/_evaluators/_content_safety/_hate_unfairness.py +30 -1
- azure/ai/evaluation/_evaluators/_content_safety/_self_harm.py +30 -1
- azure/ai/evaluation/_evaluators/_content_safety/_sexual.py +30 -1
- azure/ai/evaluation/_evaluators/_content_safety/_violence.py +30 -1
- azure/ai/evaluation/_evaluators/_eci/_eci.py +3 -1
- azure/ai/evaluation/_evaluators/_fluency/_fluency.py +20 -20
- azure/ai/evaluation/_evaluators/_fluency/fluency.prompty +66 -36
- azure/ai/evaluation/_evaluators/_gleu/_gleu.py +1 -1
- azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py +49 -15
- azure/ai/evaluation/_evaluators/_groundedness/groundedness_with_query.prompty +113 -0
- azure/ai/evaluation/_evaluators/_groundedness/groundedness_without_query.prompty +99 -0
- azure/ai/evaluation/_evaluators/_meteor/_meteor.py +3 -7
- azure/ai/evaluation/_evaluators/_multimodal/__init__.py +20 -0
- azure/ai/evaluation/_evaluators/_multimodal/_content_safety_multimodal.py +130 -0
- azure/ai/evaluation/_evaluators/_multimodal/_content_safety_multimodal_base.py +57 -0
- azure/ai/evaluation/_evaluators/_multimodal/_hate_unfairness.py +96 -0
- azure/ai/evaluation/_evaluators/_multimodal/_protected_material.py +120 -0
- azure/ai/evaluation/_evaluators/_multimodal/_self_harm.py +96 -0
- azure/ai/evaluation/_evaluators/_multimodal/_sexual.py +96 -0
- azure/ai/evaluation/_evaluators/_multimodal/_violence.py +96 -0
- azure/ai/evaluation/_evaluators/_protected_material/_protected_material.py +44 -11
- azure/ai/evaluation/_evaluators/_qa/_qa.py +7 -3
- azure/ai/evaluation/_evaluators/_relevance/_relevance.py +21 -19
- azure/ai/evaluation/_evaluators/_relevance/relevance.prompty +78 -42
- azure/ai/evaluation/_evaluators/_retrieval/_retrieval.py +125 -82
- azure/ai/evaluation/_evaluators/_retrieval/retrieval.prompty +74 -24
- azure/ai/evaluation/_evaluators/_rouge/_rouge.py +2 -2
- azure/ai/evaluation/_evaluators/_service_groundedness/__init__.py +9 -0
- azure/ai/evaluation/_evaluators/_service_groundedness/_service_groundedness.py +150 -0
- azure/ai/evaluation/_evaluators/_similarity/_similarity.py +17 -14
- azure/ai/evaluation/_evaluators/_xpia/xpia.py +32 -5
- azure/ai/evaluation/_exceptions.py +17 -0
- azure/ai/evaluation/_model_configurations.py +18 -1
- azure/ai/evaluation/_version.py +1 -1
- azure/ai/evaluation/simulator/__init__.py +2 -1
- azure/ai/evaluation/simulator/_adversarial_scenario.py +5 -0
- azure/ai/evaluation/simulator/_adversarial_simulator.py +4 -1
- azure/ai/evaluation/simulator/_data_sources/__init__.py +3 -0
- azure/ai/evaluation/simulator/_data_sources/grounding.json +1150 -0
- azure/ai/evaluation/simulator/_direct_attack_simulator.py +1 -1
- azure/ai/evaluation/simulator/_helpers/__init__.py +1 -2
- azure/ai/evaluation/simulator/_helpers/_simulator_data_classes.py +22 -1
- azure/ai/evaluation/simulator/_indirect_attack_simulator.py +79 -34
- azure/ai/evaluation/simulator/_model_tools/_identity_manager.py +1 -1
- azure/ai/evaluation/simulator/_prompty/task_query_response.prompty +4 -4
- azure/ai/evaluation/simulator/_prompty/task_simulate.prompty +6 -1
- azure/ai/evaluation/simulator/_simulator.py +115 -61
- azure/ai/evaluation/simulator/_utils.py +6 -6
- {azure_ai_evaluation-1.0.0b4.dist-info → azure_ai_evaluation-1.0.0b5.dist-info}/METADATA +166 -9
- {azure_ai_evaluation-1.0.0b4.dist-info → azure_ai_evaluation-1.0.0b5.dist-info}/NOTICE.txt +20 -0
- azure_ai_evaluation-1.0.0b5.dist-info/RECORD +120 -0
- {azure_ai_evaluation-1.0.0b4.dist-info → azure_ai_evaluation-1.0.0b5.dist-info}/WHEEL +1 -1
- azure/ai/evaluation/_evaluators/_groundedness/groundedness.prompty +0 -49
- azure_ai_evaluation-1.0.0b4.dist-info/RECORD +0 -106
- /azure/ai/evaluation/{simulator/_helpers → _common}/_experimental.py +0 -0
- /azure/ai/evaluation/_evaluate/{_batch_run_client → _batch_run}/code_client.py +0 -0
- {azure_ai_evaluation-1.0.0b4.dist-info → azure_ai_evaluation-1.0.0b5.dist-info}/top_level.txt +0 -0
|
@@ -7,13 +7,13 @@ import logging
|
|
|
7
7
|
from random import randint
|
|
8
8
|
from typing import Callable, Optional, cast
|
|
9
9
|
|
|
10
|
+
from azure.ai.evaluation._common._experimental import experimental
|
|
10
11
|
from azure.ai.evaluation._common.utils import validate_azure_ai_project
|
|
11
12
|
from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
|
|
12
13
|
from azure.ai.evaluation.simulator import AdversarialScenario
|
|
13
14
|
from azure.core.credentials import TokenCredential
|
|
14
15
|
|
|
15
16
|
from ._adversarial_simulator import AdversarialSimulator
|
|
16
|
-
from ._helpers import experimental
|
|
17
17
|
from ._model_tools import AdversarialTemplateHandler, ManagedIdentityAPITokenManager, RAIClient, TokenScope
|
|
18
18
|
|
|
19
19
|
logger = logging.getLogger(__name__)
|
|
@@ -1,5 +1,4 @@
|
|
|
1
|
-
from ._experimental import experimental
|
|
2
1
|
from ._language_suffix_mapping import SUPPORTED_LANGUAGES_MAPPING
|
|
3
2
|
from ._simulator_data_classes import ConversationHistory, Turn
|
|
4
3
|
|
|
5
|
-
__all__ = ["ConversationHistory", "Turn", "SUPPORTED_LANGUAGES_MAPPING"
|
|
4
|
+
__all__ = ["ConversationHistory", "Turn", "SUPPORTED_LANGUAGES_MAPPING"]
|
|
@@ -30,7 +30,19 @@ class Turn:
|
|
|
30
30
|
return {
|
|
31
31
|
"role": self.role.value if isinstance(self.role, ConversationRole) else self.role,
|
|
32
32
|
"content": self.content,
|
|
33
|
-
"context": self.context,
|
|
33
|
+
"context": str(self.context),
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
def to_context_free_dict(self) -> Dict[str, Optional[str]]:
|
|
37
|
+
"""
|
|
38
|
+
Convert the conversation turn to a dictionary without context.
|
|
39
|
+
|
|
40
|
+
:returns: A dictionary representation of the conversation turn without context.
|
|
41
|
+
:rtype: Dict[str, Optional[str]]
|
|
42
|
+
"""
|
|
43
|
+
return {
|
|
44
|
+
"role": self.role.value if isinstance(self.role, ConversationRole) else self.role,
|
|
45
|
+
"content": self.content,
|
|
34
46
|
}
|
|
35
47
|
|
|
36
48
|
def __repr__(self):
|
|
@@ -66,6 +78,15 @@ class ConversationHistory:
|
|
|
66
78
|
"""
|
|
67
79
|
return [turn.to_dict() for turn in self.history]
|
|
68
80
|
|
|
81
|
+
def to_context_free_list(self) -> List[Dict[str, Optional[str]]]:
|
|
82
|
+
"""
|
|
83
|
+
Converts the conversation history to a list of dictionaries without context.
|
|
84
|
+
|
|
85
|
+
:returns: A list of dictionaries representing the conversation turns without context.
|
|
86
|
+
:rtype: List[Dict[str, str]]
|
|
87
|
+
"""
|
|
88
|
+
return [turn.to_context_free_dict() for turn in self.history]
|
|
89
|
+
|
|
69
90
|
def __len__(self) -> int:
|
|
70
91
|
return len(self.history)
|
|
71
92
|
|
|
@@ -3,23 +3,27 @@
|
|
|
3
3
|
# ---------------------------------------------------------
|
|
4
4
|
# pylint: disable=C0301,C0114,R0913,R0903
|
|
5
5
|
# noqa: E501
|
|
6
|
+
import asyncio
|
|
6
7
|
import logging
|
|
7
8
|
from typing import Callable, cast
|
|
8
9
|
|
|
10
|
+
from tqdm import tqdm
|
|
11
|
+
|
|
9
12
|
from azure.ai.evaluation._common.utils import validate_azure_ai_project
|
|
13
|
+
from azure.ai.evaluation._common._experimental import experimental
|
|
10
14
|
from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
|
|
11
|
-
from azure.ai.evaluation.simulator import
|
|
15
|
+
from azure.ai.evaluation.simulator import AdversarialScenarioJailbreak, SupportedLanguages
|
|
12
16
|
from azure.core.credentials import TokenCredential
|
|
13
17
|
|
|
14
|
-
from ._adversarial_simulator import AdversarialSimulator
|
|
15
|
-
|
|
18
|
+
from ._adversarial_simulator import AdversarialSimulator, JsonLineList
|
|
19
|
+
|
|
16
20
|
from ._model_tools import AdversarialTemplateHandler, ManagedIdentityAPITokenManager, RAIClient, TokenScope
|
|
17
21
|
|
|
18
22
|
logger = logging.getLogger(__name__)
|
|
19
23
|
|
|
20
24
|
|
|
21
25
|
@experimental
|
|
22
|
-
class IndirectAttackSimulator:
|
|
26
|
+
class IndirectAttackSimulator(AdversarialSimulator):
|
|
23
27
|
"""
|
|
24
28
|
Initializes the XPIA (cross domain prompt injected attack) jailbreak adversarial simulator with a project scope.
|
|
25
29
|
|
|
@@ -54,6 +58,7 @@ class IndirectAttackSimulator:
|
|
|
54
58
|
self.adversarial_template_handler = AdversarialTemplateHandler(
|
|
55
59
|
azure_ai_project=self.azure_ai_project, rai_client=self.rai_client
|
|
56
60
|
)
|
|
61
|
+
super().__init__(azure_ai_project=azure_ai_project, credential=credential)
|
|
57
62
|
|
|
58
63
|
def _ensure_service_dependencies(self):
|
|
59
64
|
if self.rai_client is None:
|
|
@@ -69,29 +74,22 @@ class IndirectAttackSimulator:
|
|
|
69
74
|
async def __call__(
|
|
70
75
|
self,
|
|
71
76
|
*,
|
|
72
|
-
scenario: AdversarialScenario,
|
|
73
77
|
target: Callable,
|
|
74
|
-
max_conversation_turns: int = 1,
|
|
75
78
|
max_simulation_results: int = 3,
|
|
76
79
|
api_call_retry_limit: int = 3,
|
|
77
80
|
api_call_retry_sleep_sec: int = 1,
|
|
78
81
|
api_call_delay_sec: int = 0,
|
|
79
82
|
concurrent_async_task: int = 3,
|
|
83
|
+
**kwargs,
|
|
80
84
|
):
|
|
81
85
|
"""
|
|
82
86
|
Initializes the XPIA (cross domain prompt injected attack) jailbreak adversarial simulator with a project scope.
|
|
83
87
|
This simulator converses with your AI system using prompts injected into the context to interrupt normal
|
|
84
88
|
expected functionality by eliciting manipulated content, intrusion and attempting to gather information outside
|
|
85
89
|
the scope of your AI system.
|
|
86
|
-
|
|
87
|
-
:keyword scenario: Enum value specifying the adversarial scenario used for generating inputs.
|
|
88
|
-
:paramtype scenario: azure.ai.evaluation.simulator.AdversarialScenario
|
|
89
90
|
:keyword target: The target function to simulate adversarial inputs against.
|
|
90
91
|
This function should be asynchronous and accept a dictionary representing the adversarial input.
|
|
91
92
|
:paramtype target: Callable
|
|
92
|
-
:keyword max_conversation_turns: The maximum number of conversation turns to simulate.
|
|
93
|
-
Defaults to 1.
|
|
94
|
-
:paramtype max_conversation_turns: int
|
|
95
93
|
:keyword max_simulation_results: The maximum number of simulation results to return.
|
|
96
94
|
Defaults to 3.
|
|
97
95
|
:paramtype max_simulation_results: int
|
|
@@ -128,11 +126,11 @@ class IndirectAttackSimulator:
|
|
|
128
126
|
'template_parameters': {},
|
|
129
127
|
'messages': [
|
|
130
128
|
{
|
|
131
|
-
'content': '<
|
|
129
|
+
'content': '<adversarial query>',
|
|
132
130
|
'role': 'user'
|
|
133
131
|
},
|
|
134
132
|
{
|
|
135
|
-
'content': "<response from
|
|
133
|
+
'content': "<response from your callback>",
|
|
136
134
|
'role': 'assistant',
|
|
137
135
|
'context': None
|
|
138
136
|
}
|
|
@@ -141,25 +139,72 @@ class IndirectAttackSimulator:
|
|
|
141
139
|
}]
|
|
142
140
|
}
|
|
143
141
|
"""
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
142
|
+
# values that cannot be changed:
|
|
143
|
+
scenario = AdversarialScenarioJailbreak.ADVERSARIAL_INDIRECT_JAILBREAK
|
|
144
|
+
max_conversation_turns = 2
|
|
145
|
+
language = SupportedLanguages.English
|
|
146
|
+
self._ensure_service_dependencies()
|
|
147
|
+
templates = await self.adversarial_template_handler._get_content_harm_template_collections(scenario.value)
|
|
148
|
+
concurrent_async_task = min(concurrent_async_task, 1000)
|
|
149
|
+
semaphore = asyncio.Semaphore(concurrent_async_task)
|
|
150
|
+
sim_results = []
|
|
151
|
+
tasks = []
|
|
152
|
+
total_tasks = sum(len(t.template_parameters) for t in templates)
|
|
153
|
+
if max_simulation_results > total_tasks:
|
|
154
|
+
logger.warning(
|
|
155
|
+
"Cannot provide %s results due to maximum number of adversarial simulations that can be generated: %s."
|
|
156
|
+
"\n %s simulations will be generated.",
|
|
157
|
+
max_simulation_results,
|
|
158
|
+
total_tasks,
|
|
159
|
+
total_tasks,
|
|
152
160
|
)
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
api_call_retry_limit=api_call_retry_limit,
|
|
160
|
-
api_call_retry_sleep_sec=api_call_retry_sleep_sec,
|
|
161
|
-
api_call_delay_sec=api_call_delay_sec,
|
|
162
|
-
concurrent_async_task=concurrent_async_task,
|
|
163
|
-
_jailbreak_type="xpia",
|
|
161
|
+
total_tasks = min(total_tasks, max_simulation_results)
|
|
162
|
+
progress_bar = tqdm(
|
|
163
|
+
total=total_tasks,
|
|
164
|
+
desc="generating jailbreak simulations",
|
|
165
|
+
ncols=100,
|
|
166
|
+
unit="simulations",
|
|
164
167
|
)
|
|
165
|
-
|
|
168
|
+
for template in templates:
|
|
169
|
+
for parameter in template.template_parameters:
|
|
170
|
+
tasks.append(
|
|
171
|
+
asyncio.create_task(
|
|
172
|
+
self._simulate_async(
|
|
173
|
+
target=target,
|
|
174
|
+
template=template,
|
|
175
|
+
parameters=parameter,
|
|
176
|
+
max_conversation_turns=max_conversation_turns,
|
|
177
|
+
api_call_retry_limit=api_call_retry_limit,
|
|
178
|
+
api_call_retry_sleep_sec=api_call_retry_sleep_sec,
|
|
179
|
+
api_call_delay_sec=api_call_delay_sec,
|
|
180
|
+
language=language,
|
|
181
|
+
semaphore=semaphore,
|
|
182
|
+
)
|
|
183
|
+
)
|
|
184
|
+
)
|
|
185
|
+
if len(tasks) >= max_simulation_results:
|
|
186
|
+
break
|
|
187
|
+
if len(tasks) >= max_simulation_results:
|
|
188
|
+
break
|
|
189
|
+
for task in asyncio.as_completed(tasks):
|
|
190
|
+
completed_task = await task # type: ignore
|
|
191
|
+
template_parameters = completed_task.get("template_parameters", {}) # type: ignore
|
|
192
|
+
xpia_attack_type = template_parameters.get("xpia_attack_type", "") # type: ignore
|
|
193
|
+
action = template_parameters.get("action", "") # type: ignore
|
|
194
|
+
document_type = template_parameters.get("document_type", "") # type: ignore
|
|
195
|
+
sim_results.append(
|
|
196
|
+
{
|
|
197
|
+
"messages": completed_task["messages"], # type: ignore
|
|
198
|
+
"$schema": "http://azureml/sdk-2-0/ChatConversation.json",
|
|
199
|
+
"template_parameters": {
|
|
200
|
+
"metadata": {
|
|
201
|
+
"xpia_attack_type": xpia_attack_type,
|
|
202
|
+
"action": action,
|
|
203
|
+
"document_type": document_type,
|
|
204
|
+
},
|
|
205
|
+
},
|
|
206
|
+
}
|
|
207
|
+
)
|
|
208
|
+
progress_bar.update(1)
|
|
209
|
+
progress_bar.close()
|
|
210
|
+
return JsonLineList(sim_results)
|
|
@@ -11,7 +11,7 @@ from abc import ABC, abstractmethod
|
|
|
11
11
|
from enum import Enum
|
|
12
12
|
from typing import Optional, Union
|
|
13
13
|
|
|
14
|
-
from azure.core.credentials import
|
|
14
|
+
from azure.core.credentials import AccessToken, TokenCredential
|
|
15
15
|
from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
|
|
16
16
|
|
|
17
17
|
AZURE_TOKEN_REFRESH_INTERVAL = 600 # seconds
|
|
@@ -36,8 +36,8 @@ On January 24, 1984, former Apple CEO Steve Jobs introduced the first Macintosh.
|
|
|
36
36
|
Some years later, research firms IDC and Gartner reported that Apple's market share in the U.S. had increased to about 6%.
|
|
37
37
|
<|text_end|>
|
|
38
38
|
Output with 5 QnAs:
|
|
39
|
-
|
|
40
|
-
{
|
|
39
|
+
{
|
|
40
|
+
"qna":[{
|
|
41
41
|
"q": "When did the former Apple CEO Steve Jobs introduced the first Macintosh?",
|
|
42
42
|
"r": "January 24, 1984"
|
|
43
43
|
},
|
|
@@ -56,8 +56,8 @@ Output with 5 QnAs:
|
|
|
56
56
|
{
|
|
57
57
|
"q": "What was the percentage increase of Apple's market share in the U.S., as reported by research firms IDC and Gartner?",
|
|
58
58
|
"r": "6%"
|
|
59
|
-
}
|
|
60
|
-
|
|
59
|
+
}]
|
|
60
|
+
}
|
|
61
61
|
Text:
|
|
62
62
|
<|text_start|>
|
|
63
63
|
{{ text }}
|
|
@@ -16,6 +16,9 @@ inputs:
|
|
|
16
16
|
type: string
|
|
17
17
|
conversation_history:
|
|
18
18
|
type: dict
|
|
19
|
+
action:
|
|
20
|
+
type: string
|
|
21
|
+
default: continue the converasation and make sure the task is completed by asking relevant questions
|
|
19
22
|
|
|
20
23
|
---
|
|
21
24
|
system:
|
|
@@ -25,8 +28,10 @@ Output must be in JSON format
|
|
|
25
28
|
Here's a sample output:
|
|
26
29
|
{
|
|
27
30
|
"content": "Here is my follow-up question.",
|
|
28
|
-
"
|
|
31
|
+
"role": "user"
|
|
29
32
|
}
|
|
30
33
|
|
|
31
34
|
Output with a json object that continues the conversation, given the conversation history:
|
|
32
35
|
{{ conversation_history }}
|
|
36
|
+
|
|
37
|
+
{{ action }}
|
|
@@ -5,20 +5,23 @@
|
|
|
5
5
|
# ---------------------------------------------------------
|
|
6
6
|
import asyncio
|
|
7
7
|
import importlib.resources as pkg_resources
|
|
8
|
-
from tqdm import tqdm
|
|
9
8
|
import json
|
|
10
9
|
import os
|
|
11
10
|
import re
|
|
12
11
|
import warnings
|
|
13
|
-
from typing import Any, Callable, Dict, List, Optional, Union
|
|
12
|
+
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
|
13
|
+
|
|
14
14
|
from promptflow.core import AsyncPrompty
|
|
15
|
-
from
|
|
15
|
+
from tqdm import tqdm
|
|
16
|
+
|
|
17
|
+
from azure.ai.evaluation._common._experimental import experimental
|
|
16
18
|
from azure.ai.evaluation._common.utils import construct_prompty_model_config
|
|
19
|
+
from azure.ai.evaluation._model_configurations import AzureOpenAIModelConfiguration, OpenAIModelConfiguration
|
|
17
20
|
|
|
18
21
|
from .._exceptions import ErrorBlame, ErrorCategory, EvaluationException
|
|
19
22
|
from .._user_agent import USER_AGENT
|
|
20
23
|
from ._conversation.constants import ConversationRole
|
|
21
|
-
from ._helpers import ConversationHistory, Turn
|
|
24
|
+
from ._helpers import ConversationHistory, Turn
|
|
22
25
|
from ._utils import JsonLineChatProtocol
|
|
23
26
|
|
|
24
27
|
|
|
@@ -89,7 +92,8 @@ class Simulator:
|
|
|
89
92
|
api_call_delay_sec: float = 1,
|
|
90
93
|
query_response_generating_prompty_kwargs: Dict[str, Any] = {},
|
|
91
94
|
user_simulator_prompty_kwargs: Dict[str, Any] = {},
|
|
92
|
-
conversation_turns: List[List[str]] = [],
|
|
95
|
+
conversation_turns: List[List[Union[str, Dict[str, Any]]]] = [],
|
|
96
|
+
concurrent_async_tasks: int = 5,
|
|
93
97
|
**kwargs,
|
|
94
98
|
) -> List[JsonLineChatProtocol]:
|
|
95
99
|
"""
|
|
@@ -116,7 +120,10 @@ class Simulator:
|
|
|
116
120
|
:keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
|
|
117
121
|
:paramtype user_simulator_prompty_kwargs: Dict[str, Any]
|
|
118
122
|
:keyword conversation_turns: Predefined conversation turns to simulate.
|
|
119
|
-
:paramtype conversation_turns: List[List[str]]
|
|
123
|
+
:paramtype conversation_turns: List[List[Union[str, Dict[str, Any]]]]
|
|
124
|
+
:keyword concurrent_async_tasks: The number of asynchronous tasks to run concurrently during the simulation.
|
|
125
|
+
Defaults to 5.
|
|
126
|
+
:paramtype concurrent_async_tasks: int
|
|
120
127
|
:return: A list of simulated conversations represented as JsonLineChatProtocol objects.
|
|
121
128
|
:rtype: List[JsonLineChatProtocol]
|
|
122
129
|
|
|
@@ -131,12 +138,12 @@ class Simulator:
|
|
|
131
138
|
if conversation_turns and (text or tasks):
|
|
132
139
|
raise ValueError("Cannot specify both conversation_turns and text/tasks")
|
|
133
140
|
|
|
134
|
-
if num_queries > len(tasks):
|
|
141
|
+
if text and num_queries > len(tasks):
|
|
135
142
|
warnings.warn(
|
|
136
143
|
f"You have specified 'num_queries' > len('tasks') ({num_queries} > {len(tasks)}). "
|
|
137
144
|
f"All tasks will be used for generation and the remaining {num_queries - len(tasks)} lines will be simulated in task-free mode"
|
|
138
145
|
)
|
|
139
|
-
elif num_queries < len(tasks):
|
|
146
|
+
elif text and num_queries < len(tasks):
|
|
140
147
|
warnings.warn(
|
|
141
148
|
f"You have specified 'num_queries' < len('tasks') ({num_queries} < {len(tasks)}). "
|
|
142
149
|
f"Only the first {num_queries} lines of the specified tasks will be simulated."
|
|
@@ -154,6 +161,7 @@ class Simulator:
|
|
|
154
161
|
user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
|
|
155
162
|
api_call_delay_sec=api_call_delay_sec,
|
|
156
163
|
prompty_model_config=prompty_model_config,
|
|
164
|
+
concurrent_async_tasks=concurrent_async_tasks,
|
|
157
165
|
)
|
|
158
166
|
|
|
159
167
|
query_responses = await self._generate_query_responses(
|
|
@@ -172,6 +180,7 @@ class Simulator:
|
|
|
172
180
|
user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
|
|
173
181
|
target=target,
|
|
174
182
|
api_call_delay_sec=api_call_delay_sec,
|
|
183
|
+
text=text,
|
|
175
184
|
)
|
|
176
185
|
|
|
177
186
|
async def _simulate_with_predefined_turns(
|
|
@@ -179,11 +188,12 @@ class Simulator:
|
|
|
179
188
|
*,
|
|
180
189
|
target: Callable,
|
|
181
190
|
max_conversation_turns: int,
|
|
182
|
-
conversation_turns: List[List[str]],
|
|
191
|
+
conversation_turns: List[List[Union[str, Dict[str, Any]]]],
|
|
183
192
|
user_simulator_prompty: Optional[str],
|
|
184
193
|
user_simulator_prompty_kwargs: Dict[str, Any],
|
|
185
194
|
api_call_delay_sec: float,
|
|
186
195
|
prompty_model_config: Any,
|
|
196
|
+
concurrent_async_tasks: int,
|
|
187
197
|
) -> List[JsonLineChatProtocol]:
|
|
188
198
|
"""
|
|
189
199
|
Simulates conversations using predefined conversation turns.
|
|
@@ -193,7 +203,7 @@ class Simulator:
|
|
|
193
203
|
:keyword max_conversation_turns: Maximum number of turns for the simulation.
|
|
194
204
|
:paramtype max_conversation_turns: int
|
|
195
205
|
:keyword conversation_turns: A list of predefined conversation turns.
|
|
196
|
-
:paramtype conversation_turns: List[List[str]]
|
|
206
|
+
:paramtype conversation_turns: List[List[Union[str, Dict[str, Any]]]]
|
|
197
207
|
:keyword user_simulator_prompty: Path to the user simulator prompty file.
|
|
198
208
|
:paramtype user_simulator_prompty: Optional[str]
|
|
199
209
|
:keyword user_simulator_prompty_kwargs: Additional keyword arguments for the user simulator prompty.
|
|
@@ -202,42 +212,60 @@ class Simulator:
|
|
|
202
212
|
:paramtype api_call_delay_sec: float
|
|
203
213
|
:keyword prompty_model_config: The configuration for the prompty model.
|
|
204
214
|
:paramtype prompty_model_config: Any
|
|
215
|
+
:keyword concurrent_async_tasks: The number of asynchronous tasks to run concurrently during the simulation.
|
|
216
|
+
:paramtype concurrent_async_tasks: int
|
|
205
217
|
:return: A list of simulated conversations represented as JsonLineChatProtocol objects.
|
|
206
218
|
:rtype: List[JsonLineChatProtocol]
|
|
207
219
|
"""
|
|
208
|
-
simulated_conversations = []
|
|
209
220
|
progress_bar = tqdm(
|
|
210
221
|
total=int(len(conversation_turns) * (max_conversation_turns / 2)),
|
|
211
222
|
desc="Simulating with predefined conversation turns: ",
|
|
212
223
|
ncols=100,
|
|
213
224
|
unit="messages",
|
|
214
225
|
)
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
current_simulation
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
226
|
+
semaphore = asyncio.Semaphore(concurrent_async_tasks)
|
|
227
|
+
progress_bar_lock = asyncio.Lock()
|
|
228
|
+
|
|
229
|
+
async def run_simulation(simulation: List[Union[str, Dict[str, Any]]]) -> JsonLineChatProtocol:
|
|
230
|
+
async with semaphore:
|
|
231
|
+
current_simulation = ConversationHistory()
|
|
232
|
+
for simulated_turn in simulation:
|
|
233
|
+
if isinstance(simulated_turn, str):
|
|
234
|
+
user_turn = Turn(role=ConversationRole.USER, content=simulated_turn)
|
|
235
|
+
elif isinstance(simulated_turn, dict):
|
|
236
|
+
user_turn = Turn(
|
|
237
|
+
role=ConversationRole.USER,
|
|
238
|
+
content=str(simulated_turn.get("content")),
|
|
239
|
+
context=str(simulated_turn.get("context")),
|
|
240
|
+
)
|
|
241
|
+
else:
|
|
242
|
+
raise ValueError(
|
|
243
|
+
"Each simulated turn must be a string or a dict with 'content' and 'context' keys"
|
|
244
|
+
)
|
|
245
|
+
current_simulation.add_to_history(user_turn)
|
|
246
|
+
assistant_response, assistant_context = await self._get_target_response(
|
|
247
|
+
target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=current_simulation
|
|
248
|
+
)
|
|
249
|
+
assistant_turn = Turn(
|
|
250
|
+
role=ConversationRole.ASSISTANT, content=assistant_response, context=assistant_context
|
|
251
|
+
)
|
|
252
|
+
current_simulation.add_to_history(assistant_turn)
|
|
253
|
+
async with progress_bar_lock:
|
|
254
|
+
progress_bar.update(1)
|
|
255
|
+
|
|
256
|
+
if len(current_simulation) < max_conversation_turns:
|
|
257
|
+
await self._extend_conversation_with_simulator(
|
|
258
|
+
current_simulation=current_simulation,
|
|
259
|
+
max_conversation_turns=max_conversation_turns,
|
|
260
|
+
user_simulator_prompty=user_simulator_prompty,
|
|
261
|
+
user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
|
|
262
|
+
api_call_delay_sec=api_call_delay_sec,
|
|
263
|
+
prompty_model_config=prompty_model_config,
|
|
264
|
+
target=target,
|
|
265
|
+
progress_bar=progress_bar,
|
|
266
|
+
progress_bar_lock=progress_bar_lock,
|
|
267
|
+
)
|
|
268
|
+
return JsonLineChatProtocol(
|
|
241
269
|
{
|
|
242
270
|
"messages": current_simulation.to_list(),
|
|
243
271
|
"finish_reason": ["stop"],
|
|
@@ -245,10 +273,11 @@ class Simulator:
|
|
|
245
273
|
"$schema": "http://azureml/sdk-2-0/ChatConversation.json",
|
|
246
274
|
}
|
|
247
275
|
)
|
|
248
|
-
)
|
|
249
276
|
|
|
277
|
+
tasks = [asyncio.create_task(run_simulation(simulation)) for simulation in conversation_turns]
|
|
278
|
+
results = await asyncio.gather(*tasks)
|
|
250
279
|
progress_bar.close()
|
|
251
|
-
return
|
|
280
|
+
return results
|
|
252
281
|
|
|
253
282
|
async def _extend_conversation_with_simulator(
|
|
254
283
|
self,
|
|
@@ -261,6 +290,7 @@ class Simulator:
|
|
|
261
290
|
prompty_model_config: Dict[str, Any],
|
|
262
291
|
target: Callable,
|
|
263
292
|
progress_bar: tqdm,
|
|
293
|
+
progress_bar_lock: asyncio.Lock,
|
|
264
294
|
):
|
|
265
295
|
"""
|
|
266
296
|
Extends an ongoing conversation using a user simulator until the maximum number of turns is reached.
|
|
@@ -281,6 +311,8 @@ class Simulator:
|
|
|
281
311
|
:paramtype target: Callable,
|
|
282
312
|
:keyword progress_bar: Progress bar for tracking simulation progress.
|
|
283
313
|
:paramtype progress_bar: tqdm,
|
|
314
|
+
:keyword progress_bar_lock: Lock for updating the progress bar safely.
|
|
315
|
+
:paramtype progress_bar_lock: asyncio.Lock
|
|
284
316
|
"""
|
|
285
317
|
user_flow = self._load_user_simulation_flow(
|
|
286
318
|
user_simulator_prompty=user_simulator_prompty, # type: ignore
|
|
@@ -291,19 +323,22 @@ class Simulator:
|
|
|
291
323
|
while len(current_simulation) < max_conversation_turns:
|
|
292
324
|
user_response_content = await user_flow(
|
|
293
325
|
task="Continue the conversation",
|
|
294
|
-
conversation_history=current_simulation.
|
|
326
|
+
conversation_history=current_simulation.to_context_free_list(),
|
|
295
327
|
**user_simulator_prompty_kwargs,
|
|
296
328
|
)
|
|
297
329
|
user_response = self._parse_prompty_response(response=user_response_content)
|
|
298
330
|
user_turn = Turn(role=ConversationRole.USER, content=user_response["content"])
|
|
299
331
|
current_simulation.add_to_history(user_turn)
|
|
300
332
|
await asyncio.sleep(api_call_delay_sec)
|
|
301
|
-
assistant_response = await self._get_target_response(
|
|
333
|
+
assistant_response, assistant_context = await self._get_target_response(
|
|
302
334
|
target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=current_simulation
|
|
303
335
|
)
|
|
304
|
-
assistant_turn = Turn(
|
|
336
|
+
assistant_turn = Turn(
|
|
337
|
+
role=ConversationRole.ASSISTANT, content=assistant_response, context=assistant_context
|
|
338
|
+
)
|
|
305
339
|
current_simulation.add_to_history(assistant_turn)
|
|
306
|
-
|
|
340
|
+
async with progress_bar_lock:
|
|
341
|
+
progress_bar.update(1)
|
|
307
342
|
|
|
308
343
|
def _load_user_simulation_flow(
|
|
309
344
|
self,
|
|
@@ -432,6 +467,14 @@ class Simulator:
|
|
|
432
467
|
if isinstance(query_responses, dict):
|
|
433
468
|
keys = list(query_responses.keys())
|
|
434
469
|
return query_responses[keys[0]]
|
|
470
|
+
if isinstance(query_responses, str):
|
|
471
|
+
query_responses = json.loads(query_responses)
|
|
472
|
+
if isinstance(query_responses, dict):
|
|
473
|
+
if len(query_responses.keys()) == 1:
|
|
474
|
+
return query_responses[list(query_responses.keys())[0]]
|
|
475
|
+
return query_responses # type: ignore
|
|
476
|
+
if isinstance(query_responses, list):
|
|
477
|
+
return query_responses
|
|
435
478
|
return json.loads(query_responses)
|
|
436
479
|
except Exception as e:
|
|
437
480
|
raise RuntimeError("Error generating query responses") from e
|
|
@@ -497,6 +540,7 @@ class Simulator:
|
|
|
497
540
|
user_simulator_prompty_kwargs: Dict[str, Any],
|
|
498
541
|
target: Callable,
|
|
499
542
|
api_call_delay_sec: float,
|
|
543
|
+
text: str,
|
|
500
544
|
) -> List[JsonLineChatProtocol]:
|
|
501
545
|
"""
|
|
502
546
|
Creates full conversations from query-response pairs.
|
|
@@ -515,6 +559,8 @@ class Simulator:
|
|
|
515
559
|
:paramtype target: Callable
|
|
516
560
|
:keyword api_call_delay_sec: Delay in seconds between API calls.
|
|
517
561
|
:paramtype api_call_delay_sec: float
|
|
562
|
+
:keyword text: The initial input text for generating query responses.
|
|
563
|
+
:paramtype text: str
|
|
518
564
|
:return: A list of simulated conversations represented as JsonLineChatProtocol objects.
|
|
519
565
|
:rtype: List[JsonLineChatProtocol]
|
|
520
566
|
"""
|
|
@@ -552,6 +598,7 @@ class Simulator:
|
|
|
552
598
|
"task": task,
|
|
553
599
|
"expected_response": response,
|
|
554
600
|
"query": query,
|
|
601
|
+
"original_text": text,
|
|
555
602
|
},
|
|
556
603
|
"$schema": "http://azureml/sdk-2-0/ChatConversation.json",
|
|
557
604
|
}
|
|
@@ -595,8 +642,6 @@ class Simulator:
|
|
|
595
642
|
:rtype: List[Dict[str, Optional[str]]]
|
|
596
643
|
"""
|
|
597
644
|
conversation_history = ConversationHistory()
|
|
598
|
-
# user_turn = Turn(role=ConversationRole.USER, content=conversation_starter)
|
|
599
|
-
# conversation_history.add_to_history(user_turn)
|
|
600
645
|
|
|
601
646
|
while len(conversation_history) < max_conversation_turns:
|
|
602
647
|
user_flow = self._load_user_simulation_flow(
|
|
@@ -604,24 +649,33 @@ class Simulator:
|
|
|
604
649
|
prompty_model_config=self.model_config, # type: ignore
|
|
605
650
|
user_simulator_prompty_kwargs=user_simulator_prompty_kwargs,
|
|
606
651
|
)
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
652
|
+
if len(conversation_history) == 0:
|
|
653
|
+
conversation_starter_from_simulated_user = await user_flow(
|
|
654
|
+
task=task,
|
|
655
|
+
conversation_history=[
|
|
656
|
+
{
|
|
657
|
+
"role": "assistant",
|
|
658
|
+
"content": conversation_starter,
|
|
659
|
+
}
|
|
660
|
+
],
|
|
661
|
+
action="rewrite the assistant's message as you have to accomplish the task by asking the right questions. Make sure the original question is not lost in your rewrite.",
|
|
662
|
+
)
|
|
663
|
+
else:
|
|
664
|
+
conversation_starter_from_simulated_user = await user_flow(
|
|
665
|
+
task=task,
|
|
666
|
+
conversation_history=conversation_history.to_context_free_list(),
|
|
667
|
+
action="Your goal is to make sure the task is completed by asking the right questions. Do not ask the same questions again.",
|
|
668
|
+
)
|
|
617
669
|
if isinstance(conversation_starter_from_simulated_user, dict):
|
|
618
670
|
conversation_starter_from_simulated_user = conversation_starter_from_simulated_user["content"]
|
|
619
671
|
user_turn = Turn(role=ConversationRole.USER, content=conversation_starter_from_simulated_user)
|
|
620
672
|
conversation_history.add_to_history(user_turn)
|
|
621
|
-
assistant_response = await self._get_target_response(
|
|
673
|
+
assistant_response, assistant_context = await self._get_target_response(
|
|
622
674
|
target=target, api_call_delay_sec=api_call_delay_sec, conversation_history=conversation_history
|
|
623
675
|
)
|
|
624
|
-
assistant_turn = Turn(
|
|
676
|
+
assistant_turn = Turn(
|
|
677
|
+
role=ConversationRole.ASSISTANT, content=assistant_response, context=assistant_context
|
|
678
|
+
)
|
|
625
679
|
conversation_history.add_to_history(assistant_turn)
|
|
626
680
|
progress_bar.update(1)
|
|
627
681
|
|
|
@@ -632,7 +686,7 @@ class Simulator:
|
|
|
632
686
|
|
|
633
687
|
async def _get_target_response(
|
|
634
688
|
self, *, target: Callable, api_call_delay_sec: float, conversation_history: ConversationHistory
|
|
635
|
-
) -> str:
|
|
689
|
+
) -> Tuple[str, Optional[str]]:
|
|
636
690
|
"""
|
|
637
691
|
Retrieves the response from the target callback based on the current conversation history.
|
|
638
692
|
|
|
@@ -642,8 +696,8 @@ class Simulator:
|
|
|
642
696
|
:paramtype api_call_delay_sec: float
|
|
643
697
|
:keyword conversation_history: The current conversation history.
|
|
644
698
|
:paramtype conversation_history: ConversationHistory
|
|
645
|
-
:return: The content of the response from the target.
|
|
646
|
-
:rtype: str
|
|
699
|
+
:return: The content of the response from the target and an optional context.
|
|
700
|
+
:rtype: str, Optional[str]
|
|
647
701
|
"""
|
|
648
702
|
response = await target(
|
|
649
703
|
messages={"messages": conversation_history.to_list()},
|
|
@@ -653,4 +707,4 @@ class Simulator:
|
|
|
653
707
|
)
|
|
654
708
|
await asyncio.sleep(api_call_delay_sec)
|
|
655
709
|
latest_message = response["messages"][-1]
|
|
656
|
-
return latest_message["content"]
|
|
710
|
+
return latest_message["content"], latest_message.get("context", "") # type: ignore
|