azure-ai-evaluation 1.0.1__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of azure-ai-evaluation might be problematic.
- azure/ai/evaluation/_azure/__init__.py +3 -0
- azure/ai/evaluation/_azure/_clients.py +204 -0
- azure/ai/evaluation/_azure/_models.py +227 -0
- azure/ai/evaluation/_azure/_token_manager.py +118 -0
- azure/ai/evaluation/_common/rai_service.py +30 -21
- azure/ai/evaluation/_constants.py +19 -0
- azure/ai/evaluation/_evaluate/_batch_run/__init__.py +2 -1
- azure/ai/evaluation/_evaluate/_batch_run/target_run_context.py +1 -1
- azure/ai/evaluation/_evaluate/_eval_run.py +16 -43
- azure/ai/evaluation/_evaluate/_evaluate.py +76 -44
- azure/ai/evaluation/_evaluate/_utils.py +93 -34
- azure/ai/evaluation/_evaluators/_bleu/_bleu.py +46 -25
- azure/ai/evaluation/_evaluators/_common/__init__.py +2 -0
- azure/ai/evaluation/_evaluators/_common/_base_eval.py +140 -5
- azure/ai/evaluation/_evaluators/_common/_base_multi_eval.py +61 -0
- azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py +12 -1
- azure/ai/evaluation/_evaluators/_common/_base_rai_svc_eval.py +40 -2
- azure/ai/evaluation/_evaluators/_common/_conversation_aggregators.py +49 -0
- azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py +6 -43
- azure/ai/evaluation/_evaluators/_content_safety/_hate_unfairness.py +2 -0
- azure/ai/evaluation/_evaluators/_content_safety/_self_harm.py +2 -0
- azure/ai/evaluation/_evaluators/_content_safety/_sexual.py +2 -0
- azure/ai/evaluation/_evaluators/_content_safety/_violence.py +2 -0
- azure/ai/evaluation/_evaluators/_f1_score/_f1_score.py +61 -68
- azure/ai/evaluation/_evaluators/_gleu/_gleu.py +45 -23
- azure/ai/evaluation/_evaluators/_meteor/_meteor.py +55 -34
- azure/ai/evaluation/_evaluators/_qa/_qa.py +32 -27
- azure/ai/evaluation/_evaluators/_rouge/_rouge.py +44 -23
- azure/ai/evaluation/_evaluators/_similarity/_similarity.py +42 -82
- azure/ai/evaluation/_http_utils.py +6 -4
- azure/ai/evaluation/_vendor/rouge_score/rouge_scorer.py +0 -4
- azure/ai/evaluation/_vendor/rouge_score/scoring.py +0 -4
- azure/ai/evaluation/_vendor/rouge_score/tokenize.py +0 -4
- azure/ai/evaluation/_version.py +1 -1
- azure/ai/evaluation/simulator/_adversarial_scenario.py +2 -0
- azure/ai/evaluation/simulator/_adversarial_simulator.py +35 -16
- azure/ai/evaluation/simulator/_conversation/__init__.py +128 -7
- azure/ai/evaluation/simulator/_conversation/_conversation.py +0 -1
- azure/ai/evaluation/simulator/_indirect_attack_simulator.py +1 -0
- azure/ai/evaluation/simulator/_model_tools/_rai_client.py +40 -0
- azure/ai/evaluation/simulator/_model_tools/_template_handler.py +1 -0
- azure/ai/evaluation/simulator/_simulator.py +24 -13
- {azure_ai_evaluation-1.0.1.dist-info → azure_ai_evaluation-1.2.0.dist-info}/METADATA +84 -15
- {azure_ai_evaluation-1.0.1.dist-info → azure_ai_evaluation-1.2.0.dist-info}/RECORD +47 -41
- {azure_ai_evaluation-1.0.1.dist-info → azure_ai_evaluation-1.2.0.dist-info}/NOTICE.txt +0 -0
- {azure_ai_evaluation-1.0.1.dist-info → azure_ai_evaluation-1.2.0.dist-info}/WHEEL +0 -0
- {azure_ai_evaluation-1.0.1.dist-info → azure_ai_evaluation-1.2.0.dist-info}/top_level.txt +0 -0
azure/ai/evaluation/_evaluators/_similarity/_similarity.py
CHANGED

@@ -2,85 +2,17 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # ---------------------------------------------------------
 
-import math
 import os
-import re
+from typing import Dict
 
-from promptflow._utils.async_utils import async_run_allowing_running_loop
-from promptflow.core import AsyncPrompty
+from typing_extensions import overload, override
 
-from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
+from azure.ai.evaluation._evaluators._common import PromptyEvaluatorBase
 
-from ..._common.utils import construct_prompty_model_config, validate_model_config
 
-try:
-    from ..._user_agent import USER_AGENT
-except ImportError:
-    USER_AGENT = "None"
-
-
-class _AsyncSimilarityEvaluator:
-    # Constants must be defined within eval's directory to be save/loadable
-    _PROMPTY_FILE = "similarity.prompty"
-    _LLM_CALL_TIMEOUT = 600
-    _DEFAULT_OPEN_API_VERSION = "2024-02-15-preview"
-
-    def __init__(self, model_config: dict):
-        prompty_model_config = construct_prompty_model_config(
-            validate_model_config(model_config),
-            self._DEFAULT_OPEN_API_VERSION,
-            USER_AGENT,
-        )
-
-        current_dir = os.path.dirname(__file__)
-        prompty_path = os.path.join(current_dir, self._PROMPTY_FILE)
-        self._flow = AsyncPrompty.load(source=prompty_path, model=prompty_model_config)
-
-    async def __call__(self, *, query: str, response: str, ground_truth: str, **kwargs):
-        """
-        Evaluate similarity.
-
-        :keyword query: The query to be evaluated.
-        :paramtype query: str
-        :keyword response: The response to be evaluated.
-        :paramtype response: str
-        :keyword ground_truth: The ground truth to be evaluated.
-        :paramtype ground_truth: str
-        :return: The similarity score.
-        :rtype: Dict[str, float]
-        """
-        # Validate input parameters
-        query = str(query or "")
-        response = str(response or "")
-        ground_truth = str(ground_truth or "")
-
-        if not (query.strip() and response.strip() and ground_truth.strip()):
-            msg = "'query', 'response' and 'ground_truth' must be non-empty strings."
-            raise EvaluationException(
-                message=msg,
-                internal_message=msg,
-                error_category=ErrorCategory.MISSING_FIELD,
-                error_blame=ErrorBlame.USER_ERROR,
-                error_target=ErrorTarget.SIMILARITY_EVALUATOR,
-            )
-
-        # Run the evaluation flow
-        llm_output = await self._flow(
-            query=query, response=response, ground_truth=ground_truth, timeout=self._LLM_CALL_TIMEOUT, **kwargs
-        )
-
-        score = math.nan
-        if llm_output:
-            match = re.search(r"\d", llm_output)
-            if match:
-                score = float(match.group())
-
-        return {"similarity": float(score), "gpt_similarity": float(score)}
-
-
-class SimilarityEvaluator:
+class SimilarityEvaluator(PromptyEvaluatorBase):
     """
-    Evaluates similarity score for a given query, response, and ground truth
+    Evaluates similarity score for a given query, response, and ground truth.
 
     The similarity measure evaluates the likeness between a ground truth sentence (or document) and the
     AI model's generated prediction. This calculation involves creating sentence-level embeddings for both

@@ -113,13 +45,27 @@ class SimilarityEvaluator:
     however, it is recommended to use the new key moving forward as the old key will be deprecated in the future.
     """
 
-    id = "similarity"
+    # Constants must be defined within eval's directory to be save/loadable
+
+    _PROMPTY_FILE = "similarity.prompty"
+    _RESULT_KEY = "similarity"
+
+    id = "similarity"
     """Evaluator identifier, experimental and to be used only with evaluation in cloud."""
 
+    @override
     def __init__(self, model_config):
-        self._async_evaluator = _AsyncSimilarityEvaluator(model_config)
-
-    def __call__(self, *, query: str, response: str, ground_truth: str, **kwargs):
+        current_dir = os.path.dirname(__file__)
+        prompty_path = os.path.join(current_dir, self._PROMPTY_FILE)
+        super().__init__(model_config=model_config, prompty_file=prompty_path, result_key=self._RESULT_KEY)
+
+    # Ignoring a mypy error about having only 1 overload function.
+    # We want to use the overload style for all evals, even single-inputs. This is both to make
+    # refactoring to multi-input styles easier, stylistic consistency consistency across evals,
+    # and due to the fact that non-overloaded syntax now causes various parsing issues that
+    # we don't want to deal with.
+    @overload  # type: ignore
+    def __call__(self, *, query: str, response: str, ground_truth: str) -> Dict[str, float]:
         """
         Evaluate similarity.
 
@@ -132,9 +78,23 @@ class SimilarityEvaluator:
         :return: The similarity score.
         :rtype: Dict[str, float]
         """
-        return async_run_allowing_running_loop(
-            self._async_evaluator, query=query, response=response, ground_truth=ground_truth, **kwargs
-        )
 
-
-
+    @override
+    def __call__(  # pylint: disable=docstring-missing-param
+        self,
+        *args,
+        **kwargs,
+    ):
+        """
+        Evaluate similarity.
+
+        :keyword query: The query to be evaluated.
+        :paramtype query: str
+        :keyword response: The response to be evaluated.
+        :paramtype response: str
+        :keyword ground_truth: The ground truth to be evaluated.
+        :paramtype ground_truth: str
+        :return: The similarity score.
+        :rtype: Dict[str, float]
+        """
+        return super().__call__(*args, **kwargs)
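For orientation on the refactor above: the hand-rolled async wrapper is gone and SimilarityEvaluator now inherits prompty loading and invocation from PromptyEvaluatorBase, while the public call shape stays the same. A minimal usage sketch, with placeholder Azure OpenAI settings (the endpoint, deployment, and key are illustrative, not values from this diff):

from azure.ai.evaluation import SimilarityEvaluator

model_config = {
    "azure_endpoint": "https://<your-resource>.openai.azure.com",  # placeholder
    "azure_deployment": "<your-deployment>",  # placeholder
    "api_key": "<your-api-key>",  # placeholder
}

evaluator = SimilarityEvaluator(model_config)
result = evaluator(
    query="What is the capital of Japan?",
    response="The capital of Japan is Tokyo.",
    ground_truth="Tokyo is the capital of Japan.",
)
# Per the docstring note above, both the new "similarity" key and the
# legacy "gpt_similarity" key are returned until the old key is deprecated.
print(result["similarity"])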
azure/ai/evaluation/_http_utils.py
CHANGED

@@ -448,19 +448,21 @@ class AsyncHttpPipeline(AsyncPipeline):
         return cast(Self, await super().__aenter__())
 
 
-def get_http_client() -> HttpPipeline:
+def get_http_client(**kwargs: Any) -> HttpPipeline:
     """Get an HttpPipeline configured with common policies.
 
     :returns: An HttpPipeline with a set of applied policies:
     :rtype: HttpPipeline
     """
-    return HttpPipeline(user_agent_policy=UserAgentPolicy(base_user_agent=USER_AGENT))
+    kwargs.setdefault("user_agent_policy", UserAgentPolicy(base_user_agent=USER_AGENT))
+    return HttpPipeline(**kwargs)
 
 
-def get_async_http_client() -> AsyncHttpPipeline:
+def get_async_http_client(**kwargs: Any) -> AsyncHttpPipeline:
     """Get an AsyncHttpPipeline configured with common policies.
 
     :returns: An AsyncHttpPipeline with a set of applied policies:
     :rtype: AsyncHttpPipeline
     """
-    return AsyncHttpPipeline(user_agent_policy=UserAgentPolicy(base_user_agent=USER_AGENT))
+    kwargs.setdefault("user_agent_policy", UserAgentPolicy(base_user_agent=USER_AGENT))
+    return AsyncHttpPipeline(**kwargs)
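The kwargs.setdefault call above is the whole behavioral change: the package's UserAgentPolicy is applied only when the caller has not supplied one, and any other pipeline policies now pass straight through. A standalone sketch of that pattern (get_client below is a stand-in for illustration, not the library function):

from azure.core.pipeline.policies import UserAgentPolicy

def get_client(**kwargs):
    # Mirror of the new plumbing: the default is applied only if absent.
    kwargs.setdefault("user_agent_policy", UserAgentPolicy(base_user_agent="azure-ai-evaluation"))
    return kwargs  # stand-in for HttpPipeline(**kwargs)

print(get_client()["user_agent_policy"].user_agent)  # package default applied
custom = UserAgentPolicy(base_user_agent="my-app")
print(get_client(user_agent_policy=custom)["user_agent_policy"].user_agent)  # caller's policy wins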
azure/ai/evaluation/_vendor/rouge_score/rouge_scorer.py
CHANGED

@@ -32,10 +32,6 @@ ROUGE-1.5.5.pl -m -e data -n 2 -a settings.xml
 In these examples settings.xml lists input files and formats.
 """
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
 import collections
 import re
 

azure/ai/evaluation/_vendor/rouge_score/scoring.py
CHANGED

@@ -21,10 +21,6 @@ Aggregation functions use bootstrap resampling to compute confidence intervals
 as per the original ROUGE perl implementation.
 """
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
 import abc
 import collections
 from typing import Dict
azure/ai/evaluation/_version.py
CHANGED

-VERSION = "1.0.1"
+VERSION = "1.2.0"

azure/ai/evaluation/simulator/_adversarial_simulator.py
CHANGED

@@ -6,8 +6,7 @@
 import asyncio
 import logging
 import random
-from typing import Any, Callable, Dict, List,
-from itertools import zip_longest
+from typing import Any, Callable, Dict, List, Optional, Union, cast
 
 from tqdm import tqdm
 
@@ -16,13 +15,19 @@ from azure.ai.evaluation._common.utils import validate_azure_ai_project
 from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
 from azure.ai.evaluation._http_utils import get_async_http_client
 from azure.ai.evaluation._model_configurations import AzureAIProject
-from azure.ai.evaluation.simulator import AdversarialScenario
+from azure.ai.evaluation.simulator import AdversarialScenario, AdversarialScenarioJailbreak
 from azure.ai.evaluation.simulator._adversarial_scenario import _UnstableAdversarialScenario
 from azure.core.credentials import TokenCredential
 from azure.core.pipeline.policies import AsyncRetryPolicy, RetryMode
 
 from ._constants import SupportedLanguages
-from ._conversation import CallbackConversationBot, ConversationBot, ConversationRole, ConversationTurn
+from ._conversation import (
+    CallbackConversationBot,
+    MultiModalConversationBot,
+    ConversationBot,
+    ConversationRole,
+    ConversationTurn,
+)
 from ._conversation._conversation import simulate_conversation
 from ._model_tools import (
     AdversarialTemplateHandler,
@@ -205,7 +210,6 @@ class AdversarialSimulator:
             ncols=100,
             unit="simulations",
         )
-
         if randomize_order:
             # The template parameter lists are persistent across sim runs within a session,
             # So randomize a the selection instead of the parameter list directly,

@@ -214,11 +218,11 @@
             random.seed(randomization_seed)
             random.shuffle(templates)
         parameter_lists = [t.template_parameters for t in templates]
-        zipped_parameters = list(zip_longest(*parameter_lists))
+        zipped_parameters = list(zip(*parameter_lists))
         for param_group in zipped_parameters:
             for template, parameter in zip(templates, param_group):
                 if _jailbreak_type == "upia":
-                    parameter = self._join_conversation_starter(parameter, random.choice(jailbreak_dataset))
+                    parameter = self._add_jailbreak_parameter(parameter, random.choice(jailbreak_dataset))
                 tasks.append(
                     asyncio.create_task(
                         self._simulate_async(
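Assuming the removed line grouped parameters with zip_longest (consistent with the itertools import dropped above), the switch to zip changes how template parameter lists of unequal length are combined: groups are now truncated to the shortest list instead of padded with None. A standalone sketch:

from itertools import zip_longest

parameter_lists = [["p1", "p2", "p3"], ["q1", "q2"]]

# Old grouping: pads the shorter list, so a template could receive None.
print(list(zip_longest(*parameter_lists)))  # [('p1', 'q1'), ('p2', 'q2'), ('p3', None)]

# New grouping: stops at the shortest list, so every group is complete.
print(list(zip(*parameter_lists)))  # [('p1', 'q1'), ('p2', 'q2')]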
@@ -231,6 +235,7 @@
                             api_call_delay_sec=api_call_delay_sec,
                             language=language,
                             semaphore=semaphore,
+                            scenario=scenario,
                         )
                     )
                 )

@@ -292,10 +297,13 @@
         api_call_delay_sec: int,
         language: SupportedLanguages,
         semaphore: asyncio.Semaphore,
+        scenario: Union[AdversarialScenario, AdversarialScenarioJailbreak],
     ) -> List[Dict]:
-        user_bot = self._setup_bot(role=ConversationRole.USER, template=template, parameters=parameters)
+        user_bot = self._setup_bot(
+            role=ConversationRole.USER, template=template, parameters=parameters, scenario=scenario
+        )
         system_bot = self._setup_bot(
-            target=target, role=ConversationRole.ASSISTANT, template=template, parameters=parameters
+            target=target, role=ConversationRole.ASSISTANT, template=template, parameters=parameters, scenario=scenario
         )
         bots = [user_bot, system_bot]
         session = get_async_http_client().with_policies(

@@ -341,6 +349,7 @@
         template: AdversarialTemplate,
         parameters: TemplateParameters,
         target: Optional[Callable] = None,
+        scenario: Union[AdversarialScenario, AdversarialScenarioJailbreak],
     ) -> ConversationBot:
         if role is ConversationRole.USER:
             model = self._get_user_proxy_completion_model(

@@ -372,6 +381,21 @@
             def __call__(self) -> None:
                 pass
 
+        if scenario in [
+            _UnstableAdversarialScenario.ADVERSARIAL_IMAGE_GEN,
+            _UnstableAdversarialScenario.ADVERSARIAL_IMAGE_MULTIMODAL,
+        ]:
+            return MultiModalConversationBot(
+                callback=target,
+                role=role,
+                model=DummyModel(),
+                user_template=str(template),
+                user_template_parameters=parameters,
+                rai_client=self.rai_client,
+                conversation_template="",
+                instantiation_parameters={},
+            )
+
         return CallbackConversationBot(
             callback=target,
             role=role,

@@ -391,13 +415,8 @@ class AdversarialSimulator:
                 blame=ErrorBlame.SYSTEM_ERROR,
             )
 
-    def _join_conversation_starter(self, parameters: TemplateParameters, to_join: str) -> TemplateParameters:
-        key = "conversation_starter"
-        if key in parameters.keys():
-            parameters[key] = f"{to_join} {parameters[key]}"
-        else:
-            parameters[key] = to_join
-
+    def _add_jailbreak_parameter(self, parameters: TemplateParameters, to_join: str) -> TemplateParameters:
+        parameters["jailbreak_string"] = to_join
         return parameters
 
     def call_sync(
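The UPIA injection therefore changes shape: instead of splicing the jailbreak text into the conversation_starter parameter, the simulator now stores it under a dedicated jailbreak_string key, and ConversationBot prepends it to the first turn's sample (see the _conversation/__init__.py hunk below). A side-by-side sketch of the two behaviors, with illustrative inputs:

from typing import Dict

def old_join(parameters: Dict[str, str], to_join: str) -> Dict[str, str]:
    # Pre-1.2.0 behavior: rewrite the conversation starter itself.
    key = "conversation_starter"
    if key in parameters:
        parameters[key] = f"{to_join} {parameters[key]}"
    else:
        parameters[key] = to_join
    return parameters

def new_add(parameters: Dict[str, str], to_join: str) -> Dict[str, str]:
    # 1.2.0 behavior: keep the starter intact, pass the jailbreak separately.
    parameters["jailbreak_string"] = to_join
    return parameters

params = {"conversation_starter": "Tell me about your day."}
print(old_join(dict(params), "IGNORE ALL PREVIOUS INSTRUCTIONS."))
print(new_add(dict(params), "IGNORE ALL PREVIOUS INSTRUCTIONS."))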
azure/ai/evaluation/simulator/_conversation/__init__.py
CHANGED

@@ -9,12 +9,12 @@ import time
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
 
+import re
 import jinja2
 
 from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
 from azure.ai.evaluation._http_utils import AsyncHttpPipeline
-
-from .._model_tools import LLMBase, OpenAIChatCompletionsModel
+from .._model_tools import LLMBase, OpenAIChatCompletionsModel, RAIClient
 from .._model_tools._template_handler import TemplateParameters
 from .constants import ConversationRole
 

@@ -128,15 +128,19 @@ class ConversationBot:
         self.conversation_starter: Optional[Union[str, jinja2.Template, Dict]] = None
         if role == ConversationRole.USER:
             if "conversation_starter" in self.persona_template_args:
+                print(self.persona_template_args)
                 conversation_starter_content = self.persona_template_args["conversation_starter"]
                 if isinstance(conversation_starter_content, dict):
                     self.conversation_starter = conversation_starter_content
+                    print(f"Conversation starter content: {conversation_starter_content}")
                 else:
                     try:
                         self.conversation_starter = jinja2.Template(
                             conversation_starter_content, undefined=jinja2.StrictUndefined
                         )
-                    except jinja2.exceptions.TemplateSyntaxError:
+                        print("Successfully created a Jinja2 template for the conversation starter.")
+                    except jinja2.exceptions.TemplateSyntaxError as e:  # noqa: F841
+                        print(f"Template syntax error: {e}. Using raw content.")
                         self.conversation_starter = conversation_starter_content
         else:
             self.logger.info(

@@ -175,6 +179,9 @@ class ConversationBot:
                 samples = [self.conversation_starter.render(**self.persona_template_args)]
             else:
                 samples = [self.conversation_starter]
+            jailbreak_string = self.persona_template_args.get("jailbreak_string", None)
+            if jailbreak_string:
+                samples = [f"{jailbreak_string} {samples[0]}"]
         time_taken = 0
 
         finish_reason = ["stop"]

@@ -271,8 +278,6 @@ class CallbackConversationBot(ConversationBot):
                 "id": None,
                 "template_parameters": {},
             }
-        self.logger.info("Using user provided callback returning response.")
-
         time_taken = end_time - start_time
         try:
             response = {

@@ -290,8 +295,6 @@ class CallbackConversationBot(ConversationBot):
                 blame=ErrorBlame.USER_ERROR,
             ) from exc
 
-        self.logger.info("Parsed callback response")
-
         return response, {}, time_taken, result
 
     # Bug 3354264: template is unused in the method - is this intentional?
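Both CallbackConversationBot above and the new MultiModalConversationBot below read the callback's reply as result["messages"][-1]["content"] and raise a chat-protocol error otherwise. A minimal conforming callback sketch, with an illustrative echo behavior:

from typing import Any, Dict

async def my_callback(messages: Dict[str, Any]) -> Dict[str, Any]:
    # The bot takes the last message's content as the assistant's sample.
    last_content = messages["messages"][-1]["content"]
    return {
        "messages": messages["messages"]
        + [{"content": f"Echo: {last_content}", "role": "assistant"}],
        "finish_reason": ["stop"],
        "id": None,
        "template_parameters": {},
    }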
@@ -308,9 +311,127 @@
         }
 
 
+class MultiModalConversationBot(ConversationBot):
+    """MultiModal Conversation bot that uses a user provided callback to generate responses.
+
+    :param callback: The callback function to use to generate responses.
+    :type callback: Callable
+    :param user_template: The template to use for the request.
+    :type user_template: str
+    :param user_template_parameters: The template parameters to use for the request.
+    :type user_template_parameters: Dict
+    :param args: Optional arguments to pass to the parent class.
+    :type args: Any
+    :param kwargs: Optional keyword arguments to pass to the parent class.
+    :type kwargs: Any
+    """
+
+    def __init__(
+        self,
+        callback: Callable,
+        user_template: str,
+        user_template_parameters: TemplateParameters,
+        rai_client: RAIClient,
+        *args,
+        **kwargs,
+    ) -> None:
+        self.callback = callback
+        self.user_template = user_template
+        self.user_template_parameters = user_template_parameters
+        self.rai_client = rai_client
+
+        super().__init__(*args, **kwargs)
+
+    async def generate_response(
+        self,
+        session: AsyncHttpPipeline,
+        conversation_history: List[Any],
+        max_history: int,
+        turn_number: int = 0,
+    ) -> Tuple[dict, dict, float, dict]:
+        previous_prompt = conversation_history[-1]
+        chat_protocol_message = await self._to_chat_protocol(conversation_history, self.user_template_parameters)
+
+        # replace prompt with {image.jpg} tags with image content data.
+        conversation_history.pop()
+        conversation_history.append(
+            ConversationTurn(
+                role=previous_prompt.role,
+                name=previous_prompt.name,
+                message=chat_protocol_message["messages"][0]["content"],
+                full_response=previous_prompt.full_response,
+                request=chat_protocol_message,
+            )
+        )
+        msg_copy = copy.deepcopy(chat_protocol_message)
+        result = {}
+        start_time = time.time()
+        result = await self.callback(msg_copy)
+        end_time = time.time()
+        if not result:
+            result = {
+                "messages": [{"content": "Callback did not return a response.", "role": "assistant"}],
+                "finish_reason": ["stop"],
+                "id": None,
+                "template_parameters": {},
+            }
+
+        time_taken = end_time - start_time
+        try:
+            response = {
+                "samples": [result["messages"][-1]["content"]],
+                "finish_reason": ["stop"],
+                "id": None,
+            }
+        except Exception as exc:
+            msg = "User provided callback does not conform to chat protocol standard."
+            raise EvaluationException(
+                message=msg,
+                internal_message=msg,
+                target=ErrorTarget.CALLBACK_CONVERSATION_BOT,
+                category=ErrorCategory.INVALID_VALUE,
+                blame=ErrorBlame.USER_ERROR,
+            ) from exc
+
+        return response, chat_protocol_message, time_taken, result
+
+    async def _to_chat_protocol(self, conversation_history, template_parameters):  # pylint: disable=unused-argument
+        messages = []
+
+        for _, m in enumerate(conversation_history):
+            if "image:" in m.message:
+                content = await self._to_multi_modal_content(m.message)
+                messages.append({"content": content, "role": m.role.value})
+            else:
+                messages.append({"content": m.message, "role": m.role.value})
+
+        return {
+            "template_parameters": template_parameters,
+            "messages": messages,
+            "$schema": "http://azureml/sdk-2-0/ChatConversation.json",
+        }
+
+    async def _to_multi_modal_content(self, text: str) -> list:
+        split_text = re.findall(r"[^{}]+|\{[^{}]*\}", text)
+        messages = [
+            text.strip("{}").replace("image:", "").strip() if text.startswith("{") else text for text in split_text
+        ]
+        contents = []
+        for msg in messages:
+            if msg.startswith("image_understanding/"):
+                encoded_image = await self.rai_client.get_image_data(msg)
+                contents.append(
+                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_image}"}},
+                )
+            else:
+                contents.append({"type": "text", "text": msg})
+        return contents
+
+
 __all__ = [
     "ConversationRole",
     "ConversationBot",
     "CallbackConversationBot",
+    "MultiModalConversationBot",
     "ConversationTurn",
 ]
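The regex in _to_multi_modal_content splits a prompt into plain-text runs and {...} tags, then strips the braces and the image: prefix from each tag. A standalone sketch of just that parsing step, on an illustrative message (the image path is made up):

import re

text = "Please describe {image:image_understanding/cat.png} in one sentence."

split_text = re.findall(r"[^{}]+|\{[^{}]*\}", text)
print(split_text)
# ['Please describe ', '{image:image_understanding/cat.png}', ' in one sentence.']

parts = [
    t.strip("{}").replace("image:", "").strip() if t.startswith("{") else t
    for t in split_text
]
print(parts)
# ['Please describe ', 'image_understanding/cat.png', ' in one sentence.']
# Parts starting with "image_understanding/" are then fetched through
# RAIClient.get_image_data and embedded as base64 image_url content.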
azure/ai/evaluation/simulator/_conversation/_conversation.py
CHANGED

@@ -9,7 +9,6 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
 from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
 from azure.ai.evaluation.simulator._constants import SupportedLanguages
 from azure.ai.evaluation.simulator._helpers._language_suffix_mapping import SUPPORTED_LANGUAGES_MAPPING
-
 from ..._http_utils import AsyncHttpPipeline
 from . import ConversationBot, ConversationTurn
 
azure/ai/evaluation/simulator/_model_tools/_rai_client.py
CHANGED

@@ -4,6 +4,7 @@
 import os
 from typing import Any
 from urllib.parse import urljoin, urlparse
+import base64
 
 from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
 from azure.ai.evaluation._http_utils import AsyncHttpPipeline, get_async_http_client, get_http_client

@@ -57,6 +58,7 @@ class RAIClient:  # pylint: disable=client-accepts-api-version-keyword
         # add a "/" at the end of the url
         self.api_url = self.api_url.rstrip("/") + "/"
         self.parameter_json_endpoint = urljoin(self.api_url, "simulation/template/parameters")
+        self.parameter_image_endpoint = urljoin(self.api_url, "simulation/template/parameters/image")
         self.jailbreaks_json_endpoint = urljoin(self.api_url, "simulation/jailbreak")
         self.simulation_submit_endpoint = urljoin(self.api_url, "simulation/chat/completions/submit")
         self.xpia_jailbreaks_json_endpoint = urljoin(self.api_url, "simulation/jailbreak/xpia")

@@ -166,3 +168,41 @@ class RAIClient:  # pylint: disable=client-accepts-api-version-keyword
             category=ErrorCategory.UNKNOWN,
             blame=ErrorBlame.USER_ERROR,
         )
+
+    async def get_image_data(self, path: str) -> Any:
+        """Make a GET Image request to the given url
+
+        :param path: The url of the image
+        :type path: str
+        :raises EvaluationException: If the Azure safety evaluation service is not available in the current region
+        :return: The response
+        :rtype: Any
+        """
+        token = self.token_manager.get_token()
+        headers = {
+            "Authorization": f"Bearer {token}",
+            "Content-Type": "application/json",
+            "User-Agent": USER_AGENT,
+        }
+
+        session = self._create_async_client()
+        params = {"path": path}
+        async with session:
+            response = await session.get(
+                url=self.parameter_image_endpoint, params=params, headers=headers
+            )  # pylint: disable=unexpected-keyword-arg
+
+        if response.status_code == 200:
+            return base64.b64encode(response.content).decode("utf-8")
+
+        msg = (
+            "Azure safety evaluation service is not available in your current region, "
+            + "please go to https://aka.ms/azureaistudiosafetyeval to see which regions are supported"
+        )
+        raise EvaluationException(
+            message=msg,
+            internal_message=msg,
+            target=ErrorTarget.RAI_CLIENT,
+            category=ErrorCategory.UNKNOWN,
+            blame=ErrorBlame.USER_ERROR,
+        )
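get_image_data returns the image body base64-encoded as text, which MultiModalConversationBot then wraps in a data URL. A small sketch of that encoding round trip, using stand-in bytes rather than a real service response:

import base64

raw_bytes = b"\x89PNG\r\n\x1a\n"  # stand-in for response.content
encoded = base64.b64encode(raw_bytes).decode("utf-8")
data_url = f"data:image/png;base64,{encoded}"
print(data_url)

# Decoding recovers the original bytes exactly.
assert base64.b64decode(encoded) == raw_bytes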