azure-ai-evaluation 1.0.0__py3-none-any.whl → 1.0.0b2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of azure-ai-evaluation might be problematic. Click here for more details.
- azure/ai/evaluation/__init__.py +5 -31
- azure/ai/evaluation/_common/constants.py +2 -9
- azure/ai/evaluation/_common/rai_service.py +120 -300
- azure/ai/evaluation/_common/utils.py +23 -381
- azure/ai/evaluation/_constants.py +6 -19
- azure/ai/evaluation/_evaluate/{_batch_run → _batch_run_client}/__init__.py +2 -3
- azure/ai/evaluation/_evaluate/{_batch_run/eval_run_context.py → _batch_run_client/batch_run_context.py} +7 -23
- azure/ai/evaluation/_evaluate/{_batch_run → _batch_run_client}/code_client.py +17 -33
- azure/ai/evaluation/_evaluate/{_batch_run → _batch_run_client}/proxy_client.py +4 -32
- azure/ai/evaluation/_evaluate/_eval_run.py +24 -81
- azure/ai/evaluation/_evaluate/_evaluate.py +239 -393
- azure/ai/evaluation/_evaluate/_telemetry/__init__.py +17 -17
- azure/ai/evaluation/_evaluate/_utils.py +28 -82
- azure/ai/evaluation/_evaluators/_bleu/_bleu.py +18 -17
- azure/ai/evaluation/_evaluators/{_retrieval → _chat}/__init__.py +2 -2
- azure/ai/evaluation/_evaluators/_chat/_chat.py +357 -0
- azure/ai/evaluation/_evaluators/{_service_groundedness → _chat/retrieval}/__init__.py +2 -2
- azure/ai/evaluation/_evaluators/_chat/retrieval/_retrieval.py +157 -0
- azure/ai/evaluation/_evaluators/_chat/retrieval/retrieval.prompty +48 -0
- azure/ai/evaluation/_evaluators/_coherence/_coherence.py +88 -78
- azure/ai/evaluation/_evaluators/_coherence/coherence.prompty +39 -76
- azure/ai/evaluation/_evaluators/_content_safety/__init__.py +4 -0
- azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py +67 -105
- azure/ai/evaluation/_evaluators/{_multimodal/_content_safety_multimodal_base.py → _content_safety/_content_safety_base.py} +34 -24
- azure/ai/evaluation/_evaluators/_content_safety/_content_safety_chat.py +301 -0
- azure/ai/evaluation/_evaluators/_content_safety/_hate_unfairness.py +54 -105
- azure/ai/evaluation/_evaluators/_content_safety/_self_harm.py +52 -99
- azure/ai/evaluation/_evaluators/_content_safety/_sexual.py +52 -101
- azure/ai/evaluation/_evaluators/_content_safety/_violence.py +51 -101
- azure/ai/evaluation/_evaluators/_eci/_eci.py +54 -44
- azure/ai/evaluation/_evaluators/_f1_score/_f1_score.py +19 -34
- azure/ai/evaluation/_evaluators/_fluency/_fluency.py +89 -76
- azure/ai/evaluation/_evaluators/_fluency/fluency.prompty +41 -66
- azure/ai/evaluation/_evaluators/_gleu/_gleu.py +16 -14
- azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py +87 -113
- azure/ai/evaluation/_evaluators/_groundedness/groundedness.prompty +54 -0
- azure/ai/evaluation/_evaluators/_meteor/_meteor.py +27 -20
- azure/ai/evaluation/_evaluators/_protected_material/_protected_material.py +80 -89
- azure/ai/evaluation/_evaluators/_protected_materials/__init__.py +5 -0
- azure/ai/evaluation/_evaluators/_protected_materials/_protected_materials.py +104 -0
- azure/ai/evaluation/_evaluators/_qa/_qa.py +30 -23
- azure/ai/evaluation/_evaluators/_relevance/_relevance.py +96 -84
- azure/ai/evaluation/_evaluators/_relevance/relevance.prompty +47 -78
- azure/ai/evaluation/_evaluators/_rouge/_rouge.py +27 -26
- azure/ai/evaluation/_evaluators/_similarity/_similarity.py +38 -53
- azure/ai/evaluation/_evaluators/_similarity/similarity.prompty +5 -0
- azure/ai/evaluation/_evaluators/_xpia/xpia.py +105 -91
- azure/ai/evaluation/_exceptions.py +7 -28
- azure/ai/evaluation/_http_utils.py +132 -203
- azure/ai/evaluation/_model_configurations.py +8 -104
- azure/ai/evaluation/_version.py +1 -1
- azure/ai/evaluation/simulator/__init__.py +1 -2
- azure/ai/evaluation/simulator/_adversarial_scenario.py +1 -20
- azure/ai/evaluation/simulator/_adversarial_simulator.py +92 -111
- azure/ai/evaluation/simulator/_constants.py +1 -11
- azure/ai/evaluation/simulator/_conversation/__init__.py +12 -13
- azure/ai/evaluation/simulator/_conversation/_conversation.py +4 -4
- azure/ai/evaluation/simulator/_direct_attack_simulator.py +67 -33
- azure/ai/evaluation/simulator/_helpers/__init__.py +2 -1
- azure/ai/evaluation/{_common → simulator/_helpers}/_experimental.py +9 -24
- azure/ai/evaluation/simulator/_helpers/_simulator_data_classes.py +5 -26
- azure/ai/evaluation/simulator/_indirect_attack_simulator.py +94 -107
- azure/ai/evaluation/simulator/_model_tools/_identity_manager.py +22 -70
- azure/ai/evaluation/simulator/_model_tools/_proxy_completion_model.py +11 -28
- azure/ai/evaluation/simulator/_model_tools/_rai_client.py +4 -8
- azure/ai/evaluation/simulator/_model_tools/_template_handler.py +24 -68
- azure/ai/evaluation/simulator/_model_tools/models.py +10 -10
- azure/ai/evaluation/simulator/_prompty/task_query_response.prompty +10 -6
- azure/ai/evaluation/simulator/_prompty/task_simulate.prompty +5 -6
- azure/ai/evaluation/simulator/_simulator.py +207 -277
- azure/ai/evaluation/simulator/_tracing.py +4 -4
- azure/ai/evaluation/simulator/_utils.py +13 -31
- azure_ai_evaluation-1.0.0b2.dist-info/METADATA +449 -0
- azure_ai_evaluation-1.0.0b2.dist-info/RECORD +99 -0
- {azure_ai_evaluation-1.0.0.dist-info → azure_ai_evaluation-1.0.0b2.dist-info}/WHEEL +1 -1
- azure/ai/evaluation/_common/math.py +0 -89
- azure/ai/evaluation/_evaluate/_batch_run/target_run_context.py +0 -46
- azure/ai/evaluation/_evaluators/_common/__init__.py +0 -13
- azure/ai/evaluation/_evaluators/_common/_base_eval.py +0 -344
- azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py +0 -88
- azure/ai/evaluation/_evaluators/_common/_base_rai_svc_eval.py +0 -133
- azure/ai/evaluation/_evaluators/_groundedness/groundedness_with_query.prompty +0 -113
- azure/ai/evaluation/_evaluators/_groundedness/groundedness_without_query.prompty +0 -99
- azure/ai/evaluation/_evaluators/_multimodal/__init__.py +0 -20
- azure/ai/evaluation/_evaluators/_multimodal/_content_safety_multimodal.py +0 -132
- azure/ai/evaluation/_evaluators/_multimodal/_hate_unfairness.py +0 -100
- azure/ai/evaluation/_evaluators/_multimodal/_protected_material.py +0 -124
- azure/ai/evaluation/_evaluators/_multimodal/_self_harm.py +0 -100
- azure/ai/evaluation/_evaluators/_multimodal/_sexual.py +0 -100
- azure/ai/evaluation/_evaluators/_multimodal/_violence.py +0 -100
- azure/ai/evaluation/_evaluators/_retrieval/_retrieval.py +0 -112
- azure/ai/evaluation/_evaluators/_retrieval/retrieval.prompty +0 -93
- azure/ai/evaluation/_evaluators/_service_groundedness/_service_groundedness.py +0 -148
- azure/ai/evaluation/_vendor/__init__.py +0 -3
- azure/ai/evaluation/_vendor/rouge_score/__init__.py +0 -14
- azure/ai/evaluation/_vendor/rouge_score/rouge_scorer.py +0 -328
- azure/ai/evaluation/_vendor/rouge_score/scoring.py +0 -63
- azure/ai/evaluation/_vendor/rouge_score/tokenize.py +0 -63
- azure/ai/evaluation/_vendor/rouge_score/tokenizers.py +0 -53
- azure/ai/evaluation/simulator/_data_sources/__init__.py +0 -3
- azure/ai/evaluation/simulator/_data_sources/grounding.json +0 -1150
- azure_ai_evaluation-1.0.0.dist-info/METADATA +0 -595
- azure_ai_evaluation-1.0.0.dist-info/NOTICE.txt +0 -70
- azure_ai_evaluation-1.0.0.dist-info/RECORD +0 -119
- {azure_ai_evaluation-1.0.0.dist-info → azure_ai_evaluation-1.0.0b2.dist-info}/top_level.txt +0 -0
|
@@ -2,14 +2,13 @@
|
|
|
2
2
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
3
|
# ---------------------------------------------------------
|
|
4
4
|
|
|
5
|
-
import os
|
|
6
5
|
import functools
|
|
7
6
|
import inspect
|
|
8
7
|
import logging
|
|
9
8
|
import sys
|
|
10
|
-
from typing import Callable, Type, TypeVar, Union
|
|
9
|
+
from typing import Callable, Type, TypeVar, Union
|
|
11
10
|
|
|
12
|
-
from typing_extensions import ParamSpec
|
|
11
|
+
from typing_extensions import ParamSpec
|
|
13
12
|
|
|
14
13
|
DOCSTRING_TEMPLATE = ".. note:: {0} {1}\n\n"
|
|
15
14
|
DOCSTRING_DEFAULT_INDENTATION = 8
|
|
@@ -23,31 +22,20 @@ EXPERIMENTAL_LINK_MESSAGE = (
|
|
|
23
22
|
_warning_cache = set()
|
|
24
23
|
module_logger = logging.getLogger(__name__)
|
|
25
24
|
|
|
25
|
+
TExperimental = TypeVar("TExperimental", bound=Union[Type, Callable])
|
|
26
26
|
P = ParamSpec("P")
|
|
27
27
|
T = TypeVar("T")
|
|
28
28
|
|
|
29
29
|
|
|
30
|
-
|
|
31
|
-
def experimental(wrapped: Type[T]) -> Type[T]: ...
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
@overload
|
|
35
|
-
def experimental(wrapped: Callable[P, T]) -> Callable[P, T]: ...
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
def experimental(wrapped: Union[Type[T], Callable[P, T]]) -> Union[Type[T], Callable[P, T]]:
|
|
30
|
+
def experimental(wrapped: TExperimental) -> TExperimental:
|
|
39
31
|
"""Add experimental tag to a class or a method.
|
|
40
32
|
|
|
41
33
|
:param wrapped: Either a Class or Function to mark as experimental
|
|
42
|
-
:type wrapped:
|
|
34
|
+
:type wrapped: TExperimental
|
|
43
35
|
:return: The wrapped class or method
|
|
44
|
-
:rtype:
|
|
36
|
+
:rtype: TExperimental
|
|
45
37
|
"""
|
|
46
|
-
|
|
47
|
-
def is_class(t: Union[Type[T], Callable[P, T]]) -> TypeGuard[Type[T]]:
|
|
48
|
-
return isinstance(t, type)
|
|
49
|
-
|
|
50
|
-
if is_class(wrapped):
|
|
38
|
+
if inspect.isclass(wrapped):
|
|
51
39
|
return _add_class_docstring(wrapped)
|
|
52
40
|
if inspect.isfunction(wrapped):
|
|
53
41
|
return _add_method_docstring(wrapped)
|
|
@@ -86,11 +74,11 @@ def _add_class_docstring(cls: Type[T]) -> Type[T]:
|
|
|
86
74
|
cls.__doc__ = _add_note_to_docstring(cls.__doc__, doc_string)
|
|
87
75
|
else:
|
|
88
76
|
cls.__doc__ = doc_string + ">"
|
|
89
|
-
cls.__init__ = _add_class_warning(cls.__init__)
|
|
77
|
+
cls.__init__ = _add_class_warning(cls.__init__)
|
|
90
78
|
return cls
|
|
91
79
|
|
|
92
80
|
|
|
93
|
-
def _add_method_docstring(func: Callable[P, T]) -> Callable[P, T]:
|
|
81
|
+
def _add_method_docstring(func: Callable[P, T] = None) -> Callable[P, T]:
|
|
94
82
|
"""Add experimental tag to the method doc string.
|
|
95
83
|
|
|
96
84
|
:param func: The function to update
|
|
@@ -150,9 +138,6 @@ def _get_indentation_size(doc_string: str) -> int:
|
|
|
150
138
|
def _should_skip_warning():
|
|
151
139
|
skip_warning_msg = False
|
|
152
140
|
|
|
153
|
-
if os.getenv("AI_EVALS_DISABLE_EXPERIMENTAL_WARNING", "false").lower() == "true":
|
|
154
|
-
skip_warning_msg = True
|
|
155
|
-
|
|
156
141
|
# Cases where we want to suppress the warning:
|
|
157
142
|
# 1. When converting from REST object to SDK object
|
|
158
143
|
for frame in inspect.stack():
|
|
@@ -18,7 +18,7 @@ class Turn:
|
|
|
18
18
|
|
|
19
19
|
role: Union[str, ConversationRole]
|
|
20
20
|
content: str
|
|
21
|
-
context:
|
|
21
|
+
context: str = None
|
|
22
22
|
|
|
23
23
|
def to_dict(self) -> Dict[str, Optional[str]]:
|
|
24
24
|
"""
|
|
@@ -30,19 +30,7 @@ class Turn:
|
|
|
30
30
|
return {
|
|
31
31
|
"role": self.role.value if isinstance(self.role, ConversationRole) else self.role,
|
|
32
32
|
"content": self.content,
|
|
33
|
-
"context":
|
|
34
|
-
}
|
|
35
|
-
|
|
36
|
-
def to_context_free_dict(self) -> Dict[str, Optional[str]]:
|
|
37
|
-
"""
|
|
38
|
-
Convert the conversation turn to a dictionary without context.
|
|
39
|
-
|
|
40
|
-
:returns: A dictionary representation of the conversation turn without context.
|
|
41
|
-
:rtype: Dict[str, Optional[str]]
|
|
42
|
-
"""
|
|
43
|
-
return {
|
|
44
|
-
"role": self.role.value if isinstance(self.role, ConversationRole) else self.role,
|
|
45
|
-
"content": self.content,
|
|
33
|
+
"context": self.context,
|
|
46
34
|
}
|
|
47
35
|
|
|
48
36
|
def __repr__(self):
|
|
@@ -54,13 +42,13 @@ class ConversationHistory:
|
|
|
54
42
|
Conversation history class to keep track of the conversation turns in a conversation.
|
|
55
43
|
"""
|
|
56
44
|
|
|
57
|
-
def __init__(self)
|
|
45
|
+
def __init__(self):
|
|
58
46
|
"""
|
|
59
47
|
Initializes the conversation history with an empty list of turns.
|
|
60
48
|
"""
|
|
61
49
|
self.history: List[Turn] = []
|
|
62
50
|
|
|
63
|
-
def add_to_history(self, turn: Turn)
|
|
51
|
+
def add_to_history(self, turn: Turn):
|
|
64
52
|
"""
|
|
65
53
|
Adds a turn to the conversation history.
|
|
66
54
|
|
|
@@ -69,7 +57,7 @@ class ConversationHistory:
|
|
|
69
57
|
"""
|
|
70
58
|
self.history.append(turn)
|
|
71
59
|
|
|
72
|
-
def to_list(self) -> List[Dict[str,
|
|
60
|
+
def to_list(self) -> List[Dict[str, str]]:
|
|
73
61
|
"""
|
|
74
62
|
Converts the conversation history to a list of dictionaries.
|
|
75
63
|
|
|
@@ -78,15 +66,6 @@ class ConversationHistory:
|
|
|
78
66
|
"""
|
|
79
67
|
return [turn.to_dict() for turn in self.history]
|
|
80
68
|
|
|
81
|
-
def to_context_free_list(self) -> List[Dict[str, Optional[str]]]:
|
|
82
|
-
"""
|
|
83
|
-
Converts the conversation history to a list of dictionaries without context.
|
|
84
|
-
|
|
85
|
-
:returns: A list of dictionaries representing the conversation turns without context.
|
|
86
|
-
:rtype: List[Dict[str, str]]
|
|
87
|
-
"""
|
|
88
|
-
return [turn.to_context_free_dict() for turn in self.history]
|
|
89
|
-
|
|
90
69
|
def __len__(self) -> int:
|
|
91
70
|
return len(self.history)
|
|
92
71
|
|
|
@@ -1,30 +1,54 @@
|
|
|
1
1
|
# ---------------------------------------------------------
|
|
2
2
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
3
|
# ---------------------------------------------------------
|
|
4
|
-
# pylint: disable=C0301,C0114,R0913,R0903
|
|
5
4
|
# noqa: E501
|
|
6
|
-
import
|
|
5
|
+
import functools
|
|
7
6
|
import logging
|
|
8
|
-
from typing import Callable
|
|
7
|
+
from typing import Callable
|
|
9
8
|
|
|
10
|
-
from
|
|
9
|
+
from promptflow._sdk._telemetry import ActivityType, monitor_operation
|
|
11
10
|
|
|
12
|
-
from azure.ai.evaluation._common.utils import validate_azure_ai_project
|
|
13
|
-
from azure.ai.evaluation._common._experimental import experimental
|
|
14
11
|
from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
|
|
15
|
-
from azure.ai.evaluation.simulator import AdversarialScenarioJailbreak, SupportedLanguages
|
|
16
12
|
from azure.ai.evaluation._model_configurations import AzureAIProject
|
|
17
|
-
from azure.
|
|
18
|
-
|
|
19
|
-
from ._adversarial_simulator import AdversarialSimulator, JsonLineList
|
|
13
|
+
from azure.ai.evaluation.simulator import AdversarialScenario
|
|
14
|
+
from azure.identity import DefaultAzureCredential
|
|
20
15
|
|
|
16
|
+
from ._adversarial_simulator import AdversarialSimulator
|
|
21
17
|
from ._model_tools import AdversarialTemplateHandler, ManagedIdentityAPITokenManager, RAIClient, TokenScope
|
|
22
18
|
|
|
23
19
|
logger = logging.getLogger(__name__)
|
|
24
20
|
|
|
25
21
|
|
|
26
|
-
|
|
27
|
-
|
|
22
|
+
def monitor_adversarial_scenario(func) -> Callable:
|
|
23
|
+
"""Decorator to monitor adversarial scenario.
|
|
24
|
+
|
|
25
|
+
:param func: The function to be decorated.
|
|
26
|
+
:type func: Callable
|
|
27
|
+
:return: The decorated function.
|
|
28
|
+
:rtype: Callable
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
@functools.wraps(func)
|
|
32
|
+
def wrapper(*args, **kwargs):
|
|
33
|
+
scenario = str(kwargs.get("scenario", None))
|
|
34
|
+
max_conversation_turns = kwargs.get("max_conversation_turns", None)
|
|
35
|
+
max_simulation_results = kwargs.get("max_simulation_results", None)
|
|
36
|
+
decorated_func = monitor_operation(
|
|
37
|
+
activity_name="xpia.adversarial.simulator.call",
|
|
38
|
+
activity_type=ActivityType.PUBLICAPI,
|
|
39
|
+
custom_dimensions={
|
|
40
|
+
"scenario": scenario,
|
|
41
|
+
"max_conversation_turns": max_conversation_turns,
|
|
42
|
+
"max_simulation_results": max_simulation_results,
|
|
43
|
+
},
|
|
44
|
+
)(func)
|
|
45
|
+
|
|
46
|
+
return decorated_func(*args, **kwargs)
|
|
47
|
+
|
|
48
|
+
return wrapper
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class IndirectAttackSimulator:
|
|
28
52
|
"""
|
|
29
53
|
Initializes the XPIA (cross domain prompt injected attack) jailbreak adversarial simulator with a project scope.
|
|
30
54
|
|
|
@@ -33,42 +57,44 @@ class IndirectAttackSimulator(AdversarialSimulator):
|
|
|
33
57
|
:type azure_ai_project: ~azure.ai.evaluation.AzureAIProject
|
|
34
58
|
:param credential: The credential for connecting to Azure AI project.
|
|
35
59
|
:type credential: ~azure.core.credentials.TokenCredential
|
|
36
|
-
|
|
37
|
-
.. admonition:: Example:
|
|
38
|
-
|
|
39
|
-
.. literalinclude:: ../samples/evaluation_samples_simulate.py
|
|
40
|
-
:start-after: [START indirect_attack_simulator]
|
|
41
|
-
:end-before: [END indirect_attack_simulator]
|
|
42
|
-
:language: python
|
|
43
|
-
:dedent: 8
|
|
44
|
-
:caption: Run the IndirectAttackSimulator to produce 1 result with 1 conversation turn (2 messages in the result).
|
|
45
60
|
"""
|
|
46
61
|
|
|
47
|
-
def __init__(self, *, azure_ai_project: AzureAIProject, credential
|
|
62
|
+
def __init__(self, *, azure_ai_project: AzureAIProject, credential=None):
|
|
48
63
|
"""Constructor."""
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
except EvaluationException as e:
|
|
64
|
+
# check if azure_ai_project has the keys: subscription_id, resource_group_name, project_name, credential
|
|
65
|
+
if not all(key in azure_ai_project for key in ["subscription_id", "resource_group_name", "project_name"]):
|
|
66
|
+
msg = "azure_ai_project must contain keys: subscription_id, resource_group_name and project_name"
|
|
53
67
|
raise EvaluationException(
|
|
54
|
-
message=
|
|
55
|
-
internal_message=
|
|
68
|
+
message=msg,
|
|
69
|
+
internal_message=msg,
|
|
56
70
|
target=ErrorTarget.DIRECT_ATTACK_SIMULATOR,
|
|
57
|
-
category=
|
|
58
|
-
blame=
|
|
59
|
-
)
|
|
60
|
-
|
|
61
|
-
|
|
71
|
+
category=ErrorCategory.MISSING_FIELD,
|
|
72
|
+
blame=ErrorBlame.USER_ERROR,
|
|
73
|
+
)
|
|
74
|
+
if not all(azure_ai_project[key] for key in ["subscription_id", "resource_group_name", "project_name"]):
|
|
75
|
+
msg = "subscription_id, resource_group_name and project_name keys cannot be None"
|
|
76
|
+
raise EvaluationException(
|
|
77
|
+
message=msg,
|
|
78
|
+
internal_message=msg,
|
|
79
|
+
target=ErrorTarget.DIRECT_ATTACK_SIMULATOR,
|
|
80
|
+
category=ErrorCategory.MISSING_FIELD,
|
|
81
|
+
blame=ErrorBlame.USER_ERROR,
|
|
82
|
+
)
|
|
83
|
+
if "credential" not in azure_ai_project and not credential:
|
|
84
|
+
credential = DefaultAzureCredential()
|
|
85
|
+
elif "credential" in azure_ai_project:
|
|
86
|
+
credential = azure_ai_project["credential"]
|
|
87
|
+
self.credential = credential
|
|
88
|
+
self.azure_ai_project = azure_ai_project
|
|
62
89
|
self.token_manager = ManagedIdentityAPITokenManager(
|
|
63
90
|
token_scope=TokenScope.DEFAULT_AZURE_MANAGEMENT,
|
|
64
91
|
logger=logging.getLogger("AdversarialSimulator"),
|
|
65
|
-
credential=
|
|
92
|
+
credential=credential,
|
|
66
93
|
)
|
|
67
|
-
self.rai_client = RAIClient(azure_ai_project=
|
|
94
|
+
self.rai_client = RAIClient(azure_ai_project=azure_ai_project, token_manager=self.token_manager)
|
|
68
95
|
self.adversarial_template_handler = AdversarialTemplateHandler(
|
|
69
|
-
azure_ai_project=
|
|
96
|
+
azure_ai_project=azure_ai_project, rai_client=self.rai_client
|
|
70
97
|
)
|
|
71
|
-
super().__init__(azure_ai_project=azure_ai_project, credential=credential)
|
|
72
98
|
|
|
73
99
|
def _ensure_service_dependencies(self):
|
|
74
100
|
if self.rai_client is None:
|
|
@@ -81,25 +107,33 @@ class IndirectAttackSimulator(AdversarialSimulator):
|
|
|
81
107
|
blame=ErrorBlame.USER_ERROR,
|
|
82
108
|
)
|
|
83
109
|
|
|
110
|
+
# @monitor_adversarial_scenario
|
|
84
111
|
async def __call__(
|
|
85
112
|
self,
|
|
86
113
|
*,
|
|
114
|
+
scenario: AdversarialScenario,
|
|
87
115
|
target: Callable,
|
|
116
|
+
max_conversation_turns: int = 1,
|
|
88
117
|
max_simulation_results: int = 3,
|
|
89
118
|
api_call_retry_limit: int = 3,
|
|
90
119
|
api_call_retry_sleep_sec: int = 1,
|
|
91
120
|
api_call_delay_sec: int = 0,
|
|
92
121
|
concurrent_async_task: int = 3,
|
|
93
|
-
**kwargs,
|
|
94
122
|
):
|
|
95
123
|
"""
|
|
96
124
|
Initializes the XPIA (cross domain prompt injected attack) jailbreak adversarial simulator with a project scope.
|
|
97
125
|
This simulator converses with your AI system using prompts injected into the context to interrupt normal
|
|
98
126
|
expected functionality by eliciting manipulated content, intrusion and attempting to gather information outside
|
|
99
127
|
the scope of your AI system.
|
|
128
|
+
|
|
129
|
+
:keyword scenario: Enum value specifying the adversarial scenario used for generating inputs.
|
|
130
|
+
:paramtype scenario: azure.ai.evaluation.simulator.AdversarialScenario
|
|
100
131
|
:keyword target: The target function to simulate adversarial inputs against.
|
|
101
132
|
This function should be asynchronous and accept a dictionary representing the adversarial input.
|
|
102
133
|
:paramtype target: Callable
|
|
134
|
+
:keyword max_conversation_turns: The maximum number of conversation turns to simulate.
|
|
135
|
+
Defaults to 1.
|
|
136
|
+
:paramtype max_conversation_turns: int
|
|
103
137
|
:keyword max_simulation_results: The maximum number of simulation results to return.
|
|
104
138
|
Defaults to 3.
|
|
105
139
|
:paramtype max_simulation_results: int
|
|
@@ -136,11 +170,11 @@ class IndirectAttackSimulator(AdversarialSimulator):
|
|
|
136
170
|
'template_parameters': {},
|
|
137
171
|
'messages': [
|
|
138
172
|
{
|
|
139
|
-
'content': '<adversarial query>',
|
|
173
|
+
'content': '<jailbreak prompt> <adversarial query>',
|
|
140
174
|
'role': 'user'
|
|
141
175
|
},
|
|
142
176
|
{
|
|
143
|
-
'content': "<response from
|
|
177
|
+
'content': "<response from endpoint>",
|
|
144
178
|
'role': 'assistant',
|
|
145
179
|
'context': None
|
|
146
180
|
}
|
|
@@ -149,72 +183,25 @@ class IndirectAttackSimulator(AdversarialSimulator):
|
|
|
149
183
|
}]
|
|
150
184
|
}
|
|
151
185
|
"""
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
sim_results = []
|
|
161
|
-
tasks = []
|
|
162
|
-
total_tasks = sum(len(t.template_parameters) for t in templates)
|
|
163
|
-
if max_simulation_results > total_tasks:
|
|
164
|
-
logger.warning(
|
|
165
|
-
"Cannot provide %s results due to maximum number of adversarial simulations that can be generated: %s."
|
|
166
|
-
"\n %s simulations will be generated.",
|
|
167
|
-
max_simulation_results,
|
|
168
|
-
total_tasks,
|
|
169
|
-
total_tasks,
|
|
186
|
+
if scenario not in AdversarialScenario.__members__.values():
|
|
187
|
+
msg = f"Invalid scenario: {scenario}. Supported scenarios: {AdversarialScenario.__members__.values()}"
|
|
188
|
+
raise EvaluationException(
|
|
189
|
+
message=msg,
|
|
190
|
+
internal_message=msg,
|
|
191
|
+
target=ErrorTarget.DIRECT_ATTACK_SIMULATOR,
|
|
192
|
+
category=ErrorCategory.INVALID_VALUE,
|
|
193
|
+
blame=ErrorBlame.USER_ERROR,
|
|
170
194
|
)
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
195
|
+
jb_sim = AdversarialSimulator(azure_ai_project=self.azure_ai_project, credential=self.credential)
|
|
196
|
+
jb_sim_results = await jb_sim(
|
|
197
|
+
scenario=scenario,
|
|
198
|
+
target=target,
|
|
199
|
+
max_conversation_turns=max_conversation_turns,
|
|
200
|
+
max_simulation_results=max_simulation_results,
|
|
201
|
+
api_call_retry_limit=api_call_retry_limit,
|
|
202
|
+
api_call_retry_sleep_sec=api_call_retry_sleep_sec,
|
|
203
|
+
api_call_delay_sec=api_call_delay_sec,
|
|
204
|
+
concurrent_async_task=concurrent_async_task,
|
|
205
|
+
_jailbreak_type="xpia",
|
|
177
206
|
)
|
|
178
|
-
|
|
179
|
-
for parameter in template.template_parameters:
|
|
180
|
-
tasks.append(
|
|
181
|
-
asyncio.create_task(
|
|
182
|
-
self._simulate_async(
|
|
183
|
-
target=target,
|
|
184
|
-
template=template,
|
|
185
|
-
parameters=parameter,
|
|
186
|
-
max_conversation_turns=max_conversation_turns,
|
|
187
|
-
api_call_retry_limit=api_call_retry_limit,
|
|
188
|
-
api_call_retry_sleep_sec=api_call_retry_sleep_sec,
|
|
189
|
-
api_call_delay_sec=api_call_delay_sec,
|
|
190
|
-
language=language,
|
|
191
|
-
semaphore=semaphore,
|
|
192
|
-
)
|
|
193
|
-
)
|
|
194
|
-
)
|
|
195
|
-
if len(tasks) >= max_simulation_results:
|
|
196
|
-
break
|
|
197
|
-
if len(tasks) >= max_simulation_results:
|
|
198
|
-
break
|
|
199
|
-
for task in asyncio.as_completed(tasks):
|
|
200
|
-
completed_task = await task # type: ignore
|
|
201
|
-
template_parameters = completed_task.get("template_parameters", {}) # type: ignore
|
|
202
|
-
xpia_attack_type = template_parameters.get("xpia_attack_type", "") # type: ignore
|
|
203
|
-
action = template_parameters.get("action", "") # type: ignore
|
|
204
|
-
document_type = template_parameters.get("document_type", "") # type: ignore
|
|
205
|
-
sim_results.append(
|
|
206
|
-
{
|
|
207
|
-
"messages": completed_task["messages"], # type: ignore
|
|
208
|
-
"$schema": "http://azureml/sdk-2-0/ChatConversation.json",
|
|
209
|
-
"template_parameters": {
|
|
210
|
-
"metadata": {
|
|
211
|
-
"xpia_attack_type": xpia_attack_type,
|
|
212
|
-
"action": action,
|
|
213
|
-
"document_type": document_type,
|
|
214
|
-
},
|
|
215
|
-
},
|
|
216
|
-
}
|
|
217
|
-
)
|
|
218
|
-
progress_bar.update(1)
|
|
219
|
-
progress_bar.close()
|
|
220
|
-
return JsonLineList(sim_results)
|
|
207
|
+
return jb_sim_results
|
|
@@ -3,20 +3,16 @@
|
|
|
3
3
|
# ---------------------------------------------------------
|
|
4
4
|
|
|
5
5
|
import asyncio
|
|
6
|
-
import inspect
|
|
7
6
|
import logging
|
|
8
7
|
import os
|
|
9
8
|
import time
|
|
10
9
|
from abc import ABC, abstractmethod
|
|
11
10
|
from enum import Enum
|
|
12
|
-
from typing import Optional, Union
|
|
11
|
+
from typing import Dict, Optional, Union
|
|
13
12
|
|
|
14
|
-
from azure.core.credentials import AccessToken, TokenCredential
|
|
15
13
|
from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
|
|
16
14
|
|
|
17
|
-
AZURE_TOKEN_REFRESH_INTERVAL =
|
|
18
|
-
os.getenv("AZURE_TOKEN_REFRESH_INTERVAL", "600")
|
|
19
|
-
) # token refresh interval in seconds
|
|
15
|
+
AZURE_TOKEN_REFRESH_INTERVAL = 600 # seconds
|
|
20
16
|
|
|
21
17
|
|
|
22
18
|
class TokenScope(Enum):
|
|
@@ -33,24 +29,24 @@ class APITokenManager(ABC):
|
|
|
33
29
|
:param auth_header: Authorization header prefix. Defaults to "Bearer"
|
|
34
30
|
:type auth_header: str
|
|
35
31
|
:param credential: Azure credential object
|
|
36
|
-
:type credential: Optional[
|
|
32
|
+
:type credential: Optional[Union[azure.identity.DefaultAzureCredential, azure.identity.ManagedIdentityCredential]
|
|
37
33
|
"""
|
|
38
34
|
|
|
39
35
|
def __init__(
|
|
40
36
|
self,
|
|
41
37
|
logger: logging.Logger,
|
|
42
38
|
auth_header: str = "Bearer",
|
|
43
|
-
credential: Optional[
|
|
39
|
+
credential: Optional[Union[DefaultAzureCredential, ManagedIdentityCredential]] = None,
|
|
44
40
|
) -> None:
|
|
45
41
|
self.logger = logger
|
|
46
42
|
self.auth_header = auth_header
|
|
47
|
-
self._lock
|
|
43
|
+
self._lock = None
|
|
48
44
|
if credential is not None:
|
|
49
45
|
self.credential = credential
|
|
50
46
|
else:
|
|
51
47
|
self.credential = self.get_aad_credential()
|
|
52
|
-
self.token
|
|
53
|
-
self.last_refresh_time
|
|
48
|
+
self.token = None
|
|
49
|
+
self.last_refresh_time = None
|
|
54
50
|
|
|
55
51
|
@property
|
|
56
52
|
def lock(self) -> asyncio.Lock:
|
|
@@ -77,26 +73,20 @@ class APITokenManager(ABC):
|
|
|
77
73
|
identity_client_id = os.environ.get("DEFAULT_IDENTITY_CLIENT_ID", None)
|
|
78
74
|
if identity_client_id is not None:
|
|
79
75
|
self.logger.info(f"Using DEFAULT_IDENTITY_CLIENT_ID: {identity_client_id}")
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
@abstractmethod
|
|
86
|
-
def get_token(self) -> str:
|
|
87
|
-
"""Async method to get the API token. Subclasses should implement this method.
|
|
88
|
-
|
|
89
|
-
:return: API token
|
|
90
|
-
:rtype: str
|
|
91
|
-
"""
|
|
76
|
+
credential = ManagedIdentityCredential(client_id=identity_client_id)
|
|
77
|
+
else:
|
|
78
|
+
self.logger.info("Environment variable DEFAULT_IDENTITY_CLIENT_ID is not set, using DefaultAzureCredential")
|
|
79
|
+
credential = DefaultAzureCredential()
|
|
80
|
+
return credential
|
|
92
81
|
|
|
93
82
|
@abstractmethod
|
|
94
|
-
async def
|
|
83
|
+
async def get_token(self) -> str:
|
|
95
84
|
"""Async method to get the API token. Subclasses should implement this method.
|
|
96
85
|
|
|
97
86
|
:return: API token
|
|
98
87
|
:rtype: str
|
|
99
88
|
"""
|
|
89
|
+
pass # pylint: disable=unnecessary-pass
|
|
100
90
|
|
|
101
91
|
|
|
102
92
|
class ManagedIdentityAPITokenManager(APITokenManager):
|
|
@@ -110,18 +100,12 @@ class ManagedIdentityAPITokenManager(APITokenManager):
|
|
|
110
100
|
:paramtype kwargs: Dict
|
|
111
101
|
"""
|
|
112
102
|
|
|
113
|
-
def __init__(
|
|
114
|
-
|
|
115
|
-
token_scope: TokenScope,
|
|
116
|
-
logger: logging.Logger,
|
|
117
|
-
*,
|
|
118
|
-
auth_header: str = "Bearer",
|
|
119
|
-
credential: Optional[TokenCredential] = None,
|
|
120
|
-
):
|
|
121
|
-
super().__init__(logger, auth_header=auth_header, credential=credential)
|
|
103
|
+
def __init__(self, token_scope: TokenScope, logger: logging.Logger, **kwargs: Dict):
|
|
104
|
+
super().__init__(logger, **kwargs)
|
|
122
105
|
self.token_scope = token_scope
|
|
123
106
|
|
|
124
|
-
|
|
107
|
+
# Bug 3353724: This get_token is sync method, but it is defined as async method in the base class
|
|
108
|
+
def get_token(self) -> str: # pylint: disable=invalid-overridden-method
|
|
125
109
|
"""Get the API token. If the token is not available or has expired, refresh the token.
|
|
126
110
|
|
|
127
111
|
:return: API token
|
|
@@ -138,31 +122,6 @@ class ManagedIdentityAPITokenManager(APITokenManager):
|
|
|
138
122
|
|
|
139
123
|
return self.token
|
|
140
124
|
|
|
141
|
-
async def get_token_async(self) -> str:
|
|
142
|
-
"""Get the API token synchronously. If the token is not available or has expired, refresh it.
|
|
143
|
-
|
|
144
|
-
:return: API token
|
|
145
|
-
:rtype: str
|
|
146
|
-
"""
|
|
147
|
-
if (
|
|
148
|
-
self.token is None
|
|
149
|
-
or self.last_refresh_time is None
|
|
150
|
-
or time.time() - self.last_refresh_time > AZURE_TOKEN_REFRESH_INTERVAL
|
|
151
|
-
):
|
|
152
|
-
self.last_refresh_time = time.time()
|
|
153
|
-
get_token_method = self.credential.get_token(self.token_scope.value)
|
|
154
|
-
if inspect.isawaitable(get_token_method):
|
|
155
|
-
# If it's awaitable, await it
|
|
156
|
-
token_response: AccessToken = await get_token_method
|
|
157
|
-
else:
|
|
158
|
-
# Otherwise, call it synchronously
|
|
159
|
-
token_response = get_token_method
|
|
160
|
-
|
|
161
|
-
self.token = token_response.token
|
|
162
|
-
self.logger.info("Refreshed Azure endpoint token.")
|
|
163
|
-
|
|
164
|
-
return self.token
|
|
165
|
-
|
|
166
125
|
|
|
167
126
|
class PlainTokenManager(APITokenManager):
|
|
168
127
|
"""Plain API Token Manager
|
|
@@ -175,18 +134,11 @@ class PlainTokenManager(APITokenManager):
|
|
|
175
134
|
:paramtype kwargs: Dict
|
|
176
135
|
"""
|
|
177
136
|
|
|
178
|
-
def __init__(
|
|
179
|
-
|
|
180
|
-
openapi_key
|
|
181
|
-
logger: logging.Logger,
|
|
182
|
-
*,
|
|
183
|
-
auth_header: str = "Bearer",
|
|
184
|
-
credential: Optional[TokenCredential] = None,
|
|
185
|
-
) -> None:
|
|
186
|
-
super().__init__(logger, auth_header=auth_header, credential=credential)
|
|
187
|
-
self.token: str = openapi_key
|
|
137
|
+
def __init__(self, openapi_key: str, logger: logging.Logger, **kwargs: Dict):
|
|
138
|
+
super().__init__(logger, **kwargs)
|
|
139
|
+
self.token = openapi_key
|
|
188
140
|
|
|
189
|
-
def get_token(self) -> str:
|
|
141
|
+
async def get_token(self) -> str:
|
|
190
142
|
"""Get the API token
|
|
191
143
|
|
|
192
144
|
:return: API token
|