azure-ai-evaluation 1.5.0__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of azure-ai-evaluation might be problematic. Click here for more details.
- azure/ai/evaluation/__init__.py +9 -0
- azure/ai/evaluation/_aoai/__init__.py +10 -0
- azure/ai/evaluation/_aoai/aoai_grader.py +89 -0
- azure/ai/evaluation/_aoai/label_grader.py +66 -0
- azure/ai/evaluation/_aoai/string_check_grader.py +65 -0
- azure/ai/evaluation/_aoai/text_similarity_grader.py +88 -0
- azure/ai/evaluation/_azure/_clients.py +4 -4
- azure/ai/evaluation/_azure/_envs.py +208 -0
- azure/ai/evaluation/_azure/_token_manager.py +12 -7
- azure/ai/evaluation/_common/__init__.py +5 -0
- azure/ai/evaluation/_common/evaluation_onedp_client.py +118 -0
- azure/ai/evaluation/_common/onedp/__init__.py +32 -0
- azure/ai/evaluation/_common/onedp/_client.py +139 -0
- azure/ai/evaluation/_common/onedp/_configuration.py +73 -0
- azure/ai/evaluation/_common/onedp/_model_base.py +1232 -0
- azure/ai/evaluation/_common/onedp/_patch.py +21 -0
- azure/ai/evaluation/_common/onedp/_serialization.py +2032 -0
- azure/ai/evaluation/_common/onedp/_types.py +21 -0
- azure/ai/evaluation/_common/onedp/_validation.py +50 -0
- azure/ai/evaluation/_common/onedp/_vendor.py +50 -0
- azure/ai/evaluation/_common/onedp/_version.py +9 -0
- azure/ai/evaluation/_common/onedp/aio/__init__.py +29 -0
- azure/ai/evaluation/_common/onedp/aio/_client.py +143 -0
- azure/ai/evaluation/_common/onedp/aio/_configuration.py +75 -0
- azure/ai/evaluation/_common/onedp/aio/_patch.py +21 -0
- azure/ai/evaluation/_common/onedp/aio/_vendor.py +40 -0
- azure/ai/evaluation/_common/onedp/aio/operations/__init__.py +39 -0
- azure/ai/evaluation/_common/onedp/aio/operations/_operations.py +4494 -0
- azure/ai/evaluation/_common/onedp/aio/operations/_patch.py +21 -0
- azure/ai/evaluation/_common/onedp/models/__init__.py +142 -0
- azure/ai/evaluation/_common/onedp/models/_enums.py +162 -0
- azure/ai/evaluation/_common/onedp/models/_models.py +2228 -0
- azure/ai/evaluation/_common/onedp/models/_patch.py +21 -0
- azure/ai/evaluation/_common/onedp/operations/__init__.py +39 -0
- azure/ai/evaluation/_common/onedp/operations/_operations.py +5655 -0
- azure/ai/evaluation/_common/onedp/operations/_patch.py +21 -0
- azure/ai/evaluation/_common/onedp/py.typed +1 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/__init__.py +1 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/aio/__init__.py +1 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/aio/operations/__init__.py +25 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/aio/operations/_operations.py +34 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/aio/operations/_patch.py +20 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/__init__.py +1 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/aio/__init__.py +1 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/aio/operations/__init__.py +22 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/aio/operations/_operations.py +29 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/aio/operations/_patch.py +20 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/operations/__init__.py +22 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/operations/_operations.py +29 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/buildingblocks/operations/_patch.py +20 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/operations/__init__.py +25 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/operations/_operations.py +34 -0
- azure/ai/evaluation/_common/onedp/servicepatterns/operations/_patch.py +20 -0
- azure/ai/evaluation/_common/rai_service.py +158 -28
- azure/ai/evaluation/_common/raiclient/_version.py +1 -1
- azure/ai/evaluation/_common/utils.py +79 -1
- azure/ai/evaluation/_constants.py +16 -0
- azure/ai/evaluation/_eval_mapping.py +71 -0
- azure/ai/evaluation/_evaluate/_batch_run/_run_submitter_client.py +30 -16
- azure/ai/evaluation/_evaluate/_batch_run/eval_run_context.py +8 -0
- azure/ai/evaluation/_evaluate/_batch_run/proxy_client.py +5 -0
- azure/ai/evaluation/_evaluate/_batch_run/target_run_context.py +17 -1
- azure/ai/evaluation/_evaluate/_eval_run.py +1 -1
- azure/ai/evaluation/_evaluate/_evaluate.py +325 -74
- azure/ai/evaluation/_evaluate/_evaluate_aoai.py +534 -0
- azure/ai/evaluation/_evaluate/_utils.py +117 -4
- azure/ai/evaluation/_evaluators/_common/_base_eval.py +8 -3
- azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py +12 -3
- azure/ai/evaluation/_evaluators/_common/_base_rai_svc_eval.py +2 -2
- azure/ai/evaluation/_evaluators/_document_retrieval/__init__.py +11 -0
- azure/ai/evaluation/_evaluators/_document_retrieval/_document_retrieval.py +467 -0
- azure/ai/evaluation/_evaluators/_fluency/_fluency.py +1 -1
- azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py +1 -1
- azure/ai/evaluation/_evaluators/_intent_resolution/_intent_resolution.py +6 -2
- azure/ai/evaluation/_evaluators/_relevance/_relevance.py +1 -1
- azure/ai/evaluation/_evaluators/_response_completeness/_response_completeness.py +7 -2
- azure/ai/evaluation/_evaluators/_response_completeness/response_completeness.prompty +31 -46
- azure/ai/evaluation/_evaluators/_similarity/_similarity.py +1 -1
- azure/ai/evaluation/_evaluators/_task_adherence/_task_adherence.py +5 -2
- azure/ai/evaluation/_evaluators/_tool_call_accuracy/_tool_call_accuracy.py +6 -2
- azure/ai/evaluation/_exceptions.py +2 -0
- azure/ai/evaluation/_legacy/_adapters/__init__.py +0 -14
- azure/ai/evaluation/_legacy/_adapters/_check.py +17 -0
- azure/ai/evaluation/_legacy/_adapters/_flows.py +1 -1
- azure/ai/evaluation/_legacy/_batch_engine/_engine.py +51 -32
- azure/ai/evaluation/_legacy/_batch_engine/_openai_injector.py +114 -8
- azure/ai/evaluation/_legacy/_batch_engine/_result.py +6 -0
- azure/ai/evaluation/_legacy/_batch_engine/_run.py +6 -0
- azure/ai/evaluation/_legacy/_batch_engine/_run_submitter.py +69 -29
- azure/ai/evaluation/_legacy/_batch_engine/_trace.py +54 -62
- azure/ai/evaluation/_legacy/_batch_engine/_utils.py +19 -1
- azure/ai/evaluation/_legacy/_common/__init__.py +3 -0
- azure/ai/evaluation/_legacy/_common/_async_token_provider.py +124 -0
- azure/ai/evaluation/_legacy/_common/_thread_pool_executor_with_context.py +15 -0
- azure/ai/evaluation/_legacy/prompty/_connection.py +11 -74
- azure/ai/evaluation/_legacy/prompty/_exceptions.py +80 -0
- azure/ai/evaluation/_legacy/prompty/_prompty.py +119 -9
- azure/ai/evaluation/_legacy/prompty/_utils.py +72 -2
- azure/ai/evaluation/_safety_evaluation/_safety_evaluation.py +90 -17
- azure/ai/evaluation/_version.py +1 -1
- azure/ai/evaluation/red_team/_attack_strategy.py +1 -1
- azure/ai/evaluation/red_team/_red_team.py +825 -450
- azure/ai/evaluation/red_team/_utils/metric_mapping.py +23 -0
- azure/ai/evaluation/red_team/_utils/strategy_utils.py +1 -1
- azure/ai/evaluation/simulator/_adversarial_simulator.py +63 -39
- azure/ai/evaluation/simulator/_constants.py +1 -0
- azure/ai/evaluation/simulator/_conversation/__init__.py +13 -6
- azure/ai/evaluation/simulator/_conversation/_conversation.py +2 -1
- azure/ai/evaluation/simulator/_direct_attack_simulator.py +35 -22
- azure/ai/evaluation/simulator/_helpers/_language_suffix_mapping.py +1 -0
- azure/ai/evaluation/simulator/_indirect_attack_simulator.py +40 -25
- azure/ai/evaluation/simulator/_model_tools/__init__.py +2 -1
- azure/ai/evaluation/simulator/_model_tools/_generated_rai_client.py +24 -18
- azure/ai/evaluation/simulator/_model_tools/_identity_manager.py +5 -10
- azure/ai/evaluation/simulator/_model_tools/_proxy_completion_model.py +65 -41
- azure/ai/evaluation/simulator/_model_tools/_template_handler.py +9 -5
- azure/ai/evaluation/simulator/_model_tools/models.py +20 -17
- {azure_ai_evaluation-1.5.0.dist-info → azure_ai_evaluation-1.6.0.dist-info}/METADATA +25 -2
- {azure_ai_evaluation-1.5.0.dist-info → azure_ai_evaluation-1.6.0.dist-info}/RECORD +123 -65
- /azure/ai/evaluation/_legacy/{_batch_engine → _common}/_logging.py +0 -0
- {azure_ai_evaluation-1.5.0.dist-info → azure_ai_evaluation-1.6.0.dist-info}/NOTICE.txt +0 -0
- {azure_ai_evaluation-1.5.0.dist-info → azure_ai_evaluation-1.6.0.dist-info}/WHEEL +0 -0
- {azure_ai_evaluation-1.5.0.dist-info → azure_ai_evaluation-1.6.0.dist-info}/top_level.txt +0 -0
azure/ai/evaluation/__init__.py
CHANGED
|
@@ -40,6 +40,11 @@ from ._model_configurations import (
|
|
|
40
40
|
Message,
|
|
41
41
|
OpenAIModelConfiguration,
|
|
42
42
|
)
|
|
43
|
+
from ._aoai.aoai_grader import AzureOpenAIGrader
|
|
44
|
+
from ._aoai.label_grader import AzureOpenAILabelGrader
|
|
45
|
+
from ._aoai.string_check_grader import AzureOpenAIStringCheckGrader
|
|
46
|
+
from ._aoai.text_similarity_grader import AzureOpenAITextSimilarityGrader
|
|
47
|
+
|
|
43
48
|
|
|
44
49
|
_patch_all = []
|
|
45
50
|
|
|
@@ -89,6 +94,10 @@ __all__ = [
|
|
|
89
94
|
"CodeVulnerabilityEvaluator",
|
|
90
95
|
"UngroundedAttributesEvaluator",
|
|
91
96
|
"ToolCallAccuracyEvaluator",
|
|
97
|
+
"AzureOpenAIGrader",
|
|
98
|
+
"AzureOpenAILabelGrader",
|
|
99
|
+
"AzureOpenAIStringCheckGrader",
|
|
100
|
+
"AzureOpenAITextSimilarityGrader",
|
|
92
101
|
]
|
|
93
102
|
|
|
94
103
|
__all__.extend([p for p in _patch_all if p not in __all__])
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
# ---------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# ---------------------------------------------------------
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
from .aoai_grader import AzureOpenAIGrader
|
|
7
|
+
|
|
8
|
+
__all__ = [
|
|
9
|
+
"AzureOpenAIGrader",
|
|
10
|
+
]
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
# ---------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# ---------------------------------------------------------
|
|
4
|
+
from azure.ai.evaluation._model_configurations import AzureOpenAIModelConfiguration, OpenAIModelConfiguration
|
|
5
|
+
|
|
6
|
+
from azure.ai.evaluation._constants import DEFAULT_AOAI_API_VERSION
|
|
7
|
+
from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
|
|
8
|
+
from typing import Any, Dict, Union
|
|
9
|
+
from azure.ai.evaluation._common._experimental import experimental
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@experimental
class AzureOpenAIGrader:
    """
    Base class for Azure OpenAI grader wrappers, recommended only for use by experienced OpenAI API users.
    Combines a model configuration and any grader configuration
    into a singular object that can be used in evaluations.

    Supplying an AzureOpenAIGrader to the `evaluate` method will cause an asynchronous request to evaluate
    the grader via the OpenAI API. The results of the evaluation will then be merged into the standard
    evaluation results.

    :param model_config: The model configuration to use for the grader.
    :type model_config: Union[
        ~azure.ai.evaluation.AzureOpenAIModelConfiguration,
        ~azure.ai.evaluation.OpenAIModelConfiguration
    ]
    :param grader_config: The grader configuration to use for the grader. This is expected
        to be formatted as a dictionary that matches the specifications of the sub-types of
        the TestingCriterion alias specified in (OpenAI's SDK)[https://github.com/openai/openai-python/blob/ed53107e10e6c86754866b48f8bd862659134ca8/src/openai/types/eval_create_params.py#L151].
    :type grader_config: Dict[str, Any]
    :param kwargs: Additional keyword arguments to pass to the grader. The ``validate`` keyword
        (default ``True``) controls whether the model and grader configurations are validated
        during construction.
    :type kwargs: Any
    :raises EvaluationException: If validation is enabled and ``model_config`` has no non-empty ``api_key``.
    """

    id = "aoai://general"

    def __init__(
        self,
        *,
        model_config: Union[AzureOpenAIModelConfiguration, OpenAIModelConfiguration],
        grader_config: Dict[str, Any],
        **kwargs: Any
    ):
        self._model_config = model_config
        self._grader_config = grader_config

        # Validation is on by default; callers can opt out with ``validate=False``.
        if kwargs.get("validate", True):
            self._validate_model_config()
            self._validate_grader_config()

    def _validate_model_config(self) -> None:
        """Validate the model configuration that this grader wrapper is using.

        :raises EvaluationException: If the model configuration has no non-empty ``api_key``.
        """
        # One truthiness check covers both a missing key and an empty/None value.
        if not self._model_config.get("api_key"):
            msg = f"{type(self).__name__}: Requires an api_key in the supplied model_config."
            raise EvaluationException(
                message=msg,
                blame=ErrorBlame.USER_ERROR,
                category=ErrorCategory.INVALID_VALUE,
                target=ErrorTarget.AOAI_GRADER,
            )

    def _validate_grader_config(self) -> None:
        """Validate the grader configuration that this grader wrapper is using.

        Currently a no-op hook; subclasses may override it.
        """
        return

    def get_client(self) -> Any:
        """Construct an appropriate OpenAI client using this grader's model configuration.

        Returns an ``AzureOpenAI`` client when the model configuration contains an
        ``azure_endpoint`` key, and a plain ``OpenAI`` client otherwise.

        Optional settings that are absent (or empty) are passed as ``None`` rather than ``""``.
        The OpenAI clients treat ``None`` as "use the environment variable / built-in default".
        An empty string is used verbatim and yields an empty base URL, an empty organization
        header, or a malformed ``.../openai/deployments/`` URL.

        :return: The OpenAI client.
        :rtype: Union[~openai.OpenAI, ~openai.AzureOpenAI]
        """
        if "azure_endpoint" in self._model_config:
            from openai import AzureOpenAI

            return AzureOpenAI(
                azure_endpoint=self._model_config["azure_endpoint"],
                api_key=self._model_config.get("api_key", None),  # Default-style access to appease linters.
                api_version=self._model_config.get("api_version", DEFAULT_AOAI_API_VERSION),
                azure_deployment=self._model_config.get("azure_deployment") or None,
            )

        from openai import OpenAI

        return OpenAI(
            api_key=self._model_config["api_key"],
            base_url=self._model_config.get("base_url") or None,
            organization=self._model_config.get("organization") or None,
        )
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
# ---------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# ---------------------------------------------------------
|
|
4
|
+
from typing import Any, Dict, Union, List
|
|
5
|
+
|
|
6
|
+
from azure.ai.evaluation._model_configurations import AzureOpenAIModelConfiguration, OpenAIModelConfiguration
|
|
7
|
+
from openai.types.eval_create_params import TestingCriterionLabelModel
|
|
8
|
+
from azure.ai.evaluation._common._experimental import experimental
|
|
9
|
+
|
|
10
|
+
from .aoai_grader import AzureOpenAIGrader
|
|
11
|
+
|
|
12
|
+
@experimental
class AzureOpenAILabelGrader(AzureOpenAIGrader):
    """
    Wrapper around OpenAI's label model grader (``type="label_model"``).

    Passing an instance of this class to the `evaluate` method causes an asynchronous request
    that runs the grader via the OpenAI API; its results are then merged into the standard
    evaluation results.

    :param model_config: The model configuration to use for the grader.
    :type model_config: Union[
        ~azure.ai.evaluation.AzureOpenAIModelConfiguration,
        ~azure.ai.evaluation.OpenAIModelConfiguration
    ]
    :param input: The list of label-based testing criterion for this grader. Each entry is expected
        to be a dictionary matching one of the valid
        (TestingCriterionLabelModelInput)[https://github.com/openai/openai-python/blob/ed53107e10e6c86754866b48f8bd862659134ca8/src/openai/types/eval_create_params.py#L125C1-L125C32]
        subtypes.
    :type input: List[Dict[str, str]]
    :param labels: The classification labels this grader may assign.
    :type labels: List[str]
    :param model: The model to use for the evaluation. Must support structured outputs.
    :type model: str
    :param name: The name of the grader.
    :type name: str
    :param passing_labels: The labels that count as a passing result. Expected to be a subset of
        ``labels``; this is not checked locally.
    :type passing_labels: List[str]
    :param kwargs: Additional keyword arguments forwarded to :class:`AzureOpenAIGrader`.
    :type kwargs: Any
    """

    id = "aoai://label_model"

    def __init__(
        self,
        *,
        model_config: Union[AzureOpenAIModelConfiguration, OpenAIModelConfiguration],
        input: List[Dict[str, str]],
        labels: List[str],
        model: str,
        name: str,
        passing_labels: List[str],
        **kwargs: Any
    ):
        # Package the settings in the OpenAI "label_model" testing-criterion shape and hand
        # the result to the base class, which stores it and optionally validates it.
        super().__init__(
            model_config=model_config,
            grader_config=TestingCriterionLabelModel(
                input=input,
                labels=labels,
                model=model,
                name=name,
                passing_labels=passing_labels,
                type="label_model",
            ),
            **kwargs,
        )
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
# ---------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# ---------------------------------------------------------
|
|
4
|
+
from typing import Any, Dict, Union
|
|
5
|
+
from typing_extensions import Literal
|
|
6
|
+
|
|
7
|
+
from azure.ai.evaluation._model_configurations import AzureOpenAIModelConfiguration, OpenAIModelConfiguration
|
|
8
|
+
from openai.types.eval_string_check_grader import EvalStringCheckGrader
|
|
9
|
+
from azure.ai.evaluation._common._experimental import experimental
|
|
10
|
+
|
|
11
|
+
from .aoai_grader import AzureOpenAIGrader
|
|
12
|
+
|
|
13
|
+
@experimental
class AzureOpenAIStringCheckGrader(AzureOpenAIGrader):
    """
    Wrapper around OpenAI's string check grader (``type="string_check"``).

    Passing an instance of this class to the `evaluate` method causes an asynchronous request
    that runs the grader via the OpenAI API; its results are then merged into the standard
    evaluation results.

    :param model_config: The model configuration to use for the grader.
    :type model_config: Union[
        ~azure.ai.evaluation.AzureOpenAIModelConfiguration,
        ~azure.ai.evaluation.OpenAIModelConfiguration
    ]
    :param input: The input text. This may include template strings.
    :type input: str
    :param name: The name of the grader.
    :type name: str
    :param operation: The string check operation to perform. One of `eq`, `ne`, `like`, or `ilike`.
    :type operation: Literal["eq", "ne", "like", "ilike"]
    :param reference: The reference text. This may include template strings.
    :type reference: str
    :param kwargs: Additional keyword arguments forwarded to :class:`AzureOpenAIGrader`.
    :type kwargs: Any
    """

    id = "aoai://string_check"

    def __init__(
        self,
        *,
        model_config: Union[AzureOpenAIModelConfiguration, OpenAIModelConfiguration],
        input: str,
        name: str,
        operation: Literal[
            "eq",
            "ne",
            "like",
            "ilike",
        ],
        reference: str,
        **kwargs: Any
    ):
        # Package the settings in the OpenAI "string_check" grader shape and hand the result
        # to the base class, which stores it and optionally validates it.
        super().__init__(
            model_config=model_config,
            grader_config=EvalStringCheckGrader(
                input=input,
                name=name,
                operation=operation,
                reference=reference,
                type="string_check",
            ),
            **kwargs,
        )
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
# ---------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# ---------------------------------------------------------
|
|
4
|
+
from typing import Any, Dict, Union
|
|
5
|
+
from typing_extensions import Literal
|
|
6
|
+
|
|
7
|
+
from azure.ai.evaluation._model_configurations import AzureOpenAIModelConfiguration, OpenAIModelConfiguration
|
|
8
|
+
from openai.types.eval_text_similarity_grader import EvalTextSimilarityGrader
|
|
9
|
+
from azure.ai.evaluation._common._experimental import experimental
|
|
10
|
+
|
|
11
|
+
from .aoai_grader import AzureOpenAIGrader
|
|
12
|
+
|
|
13
|
+
@experimental
class AzureOpenAITextSimilarityGrader(AzureOpenAIGrader):
    """
    Wrapper class for OpenAI's text similarity graders (``type="text_similarity"``).

    Supplying an AzureOpenAITextSimilarityGrader to the `evaluate` method will cause an asynchronous
    request to evaluate the grader via the OpenAI API. The results of the evaluation will then be
    merged into the standard evaluation results.

    :param model_config: The model configuration to use for the grader.
    :type model_config: Union[
        ~azure.ai.evaluation.AzureOpenAIModelConfiguration,
        ~azure.ai.evaluation.OpenAIModelConfiguration
    ]
    :param evaluation_metric: The similarity metric used to compare ``input`` against ``reference``.
    :type evaluation_metric: Literal[
        "fuzzy_match",
        "bleu",
        "gleu",
        "meteor",
        "rouge_1",
        "rouge_2",
        "rouge_3",
        "rouge_4",
        "rouge_5",
        "rouge_l",
        "cosine",
    ]
    :param input: The text being graded.
    :type input: str
    :param pass_threshold: A float score where a value greater than or equal indicates a passing grade.
    :type pass_threshold: float
    :param reference: The text being graded against.
    :type reference: str
    :param name: The name of the grader.
    :type name: str
    :param kwargs: Additional keyword arguments forwarded to :class:`AzureOpenAIGrader`.
    :type kwargs: Any
    """

    id = "aoai://text_similarity"

    def __init__(
        self,
        *,
        model_config: Union[AzureOpenAIModelConfiguration, OpenAIModelConfiguration],
        evaluation_metric: Literal[
            "fuzzy_match",
            "bleu",
            "gleu",
            "meteor",
            "rouge_1",
            "rouge_2",
            "rouge_3",
            "rouge_4",
            "rouge_5",
            "rouge_l",
            "cosine",
        ],
        input: str,
        pass_threshold: float,
        reference: str,
        name: str,
        **kwargs: Any
    ):
        # Build the OpenAI "text_similarity" grader config; the base class stores it and
        # optionally validates the model configuration.
        grader = EvalTextSimilarityGrader(
            evaluation_metric=evaluation_metric,
            input=input,
            pass_threshold=pass_threshold,
            name=name,
            reference=reference,
            type="text_similarity",
        )
        super().__init__(model_config=model_config, grader_config=grader, **kwargs)
|
|
@@ -8,12 +8,12 @@ from threading import Lock
|
|
|
8
8
|
from urllib.parse import quote
|
|
9
9
|
from json.decoder import JSONDecodeError
|
|
10
10
|
|
|
11
|
-
from azure.core.credentials import TokenCredential, AzureSasCredential
|
|
11
|
+
from azure.core.credentials import TokenCredential, AzureSasCredential, AccessToken
|
|
12
12
|
from azure.core.rest import HttpResponse
|
|
13
13
|
from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
|
|
14
14
|
from azure.ai.evaluation._http_utils import HttpPipeline, get_http_client
|
|
15
15
|
from azure.ai.evaluation._azure._token_manager import AzureMLTokenManager
|
|
16
|
-
from azure.ai.evaluation.
|
|
16
|
+
from azure.ai.evaluation._constants import TokenScope
|
|
17
17
|
from ._models import BlobStoreInfo, Workspace
|
|
18
18
|
|
|
19
19
|
|
|
@@ -61,7 +61,7 @@ class LiteMLClient:
|
|
|
61
61
|
self._token_manager: Optional[AzureMLTokenManager] = None
|
|
62
62
|
self._credential: Optional[TokenCredential] = credential
|
|
63
63
|
|
|
64
|
-
def get_token(self) ->
|
|
64
|
+
def get_token(self) -> AccessToken:
|
|
65
65
|
return self._get_token_manager().get_token()
|
|
66
66
|
|
|
67
67
|
def get_credential(self) -> TokenCredential:
|
|
@@ -201,4 +201,4 @@ class LiteMLClient:
|
|
|
201
201
|
return url
|
|
202
202
|
|
|
203
203
|
def _get_headers(self) -> Dict[str, str]:
|
|
204
|
-
return {"Authorization": f"Bearer {self.get_token()}", "Content-Type": "application/json"}
|
|
204
|
+
return {"Authorization": f"Bearer {self.get_token().token}", "Content-Type": "application/json"}
|
|
@@ -0,0 +1,208 @@
|
|
|
1
|
+
# ---------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# ---------------------------------------------------------
|
|
4
|
+
|
|
5
|
+
# NOTE:
|
|
6
|
+
# This is a simplified version of the original code from azure-ai-ml:
|
|
7
|
+
# sdk\ml\azure-ai-ml\azure\ai\ml\_azure_environments.py
|
|
8
|
+
|
|
9
|
+
import asyncio
|
|
10
|
+
import os
|
|
11
|
+
|
|
12
|
+
from typing import Any, Dict, Final, Mapping, Optional, Sequence, TypedDict
|
|
13
|
+
|
|
14
|
+
from azure.core import AsyncPipelineClient
|
|
15
|
+
from azure.core.configuration import Configuration
|
|
16
|
+
from azure.core.rest import HttpRequest
|
|
17
|
+
from azure.core.pipeline.policies import ProxyPolicy, AsyncRetryPolicy
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class AzureEnvironmentMetadata(TypedDict):
    """Configuration for various Azure environments. All endpoints include a trailing slash."""
    portal_endpoint: str
    """The management portal for the Azure environment (e.g. https://portal.azure.com/)"""
    resource_manager_endpoint: str
    """The API endpoint for the Azure control plane (e.g. https://management.azure.com/)"""
    active_directory_endpoint: str
    """The active directory endpoint used for authentication (e.g. https://login.microsoftonline.com/)"""
    aml_resource_endpoint: str
    """The endpoint for Azure Machine Learning resources (e.g. https://ml.azure.com/)"""
    storage_suffix: str
    """The suffix to use for storage endpoint URLs (e.g. core.windows.net)"""
    registry_discovery_endpoint: str
    """The endpoint used to discover Azure ML registries (e.g. https://eastus.api.azureml.ms/)"""
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Fallback values used when nothing more specific is configured.
_DEFAULT_AZURE_ENV_NAME: Final[str] = "AzureCloud"
_DEFAULT_REGISTRY_DISCOVERY_REGION: Final[str] = "west"

# Names of the environment variables read by AzureEnvironmentClient.
_ENV_ARM_CLOUD_METADATA_URL: Final[str] = "ARM_CLOUD_METADATA_URL"
_ENV_DEFAULT_CLOUD_NAME: Final[str] = "AZUREML_CURRENT_CLOUD"
_ENV_REGISTRY_DISCOVERY_URL: Final[str] = "REGISTRY_DISCOVERY_ENDPOINT_URL"
_ENV_REGISTRY_DISCOVERY_REGION: Final[str] = "REGISTRY_DISCOVERY_ENDPOINT_REGION"

# Built-in metadata for the well-known Azure clouds. Entries fetched from the ARM cloud
# metadata endpoint may be merged into this table at runtime (see get_clouds_async).
_KNOWN_AZURE_ENVIRONMENTS: Dict[str, AzureEnvironmentMetadata] = {
    _DEFAULT_AZURE_ENV_NAME: {
        "portal_endpoint": "https://portal.azure.com/",
        "resource_manager_endpoint": "https://management.azure.com/",
        "active_directory_endpoint": "https://login.microsoftonline.com/",
        "aml_resource_endpoint": "https://ml.azure.com/",
        "storage_suffix": "core.windows.net",
        "registry_discovery_endpoint": "https://eastus.api.azureml.ms/",
    },
    "AzureChinaCloud": {
        "portal_endpoint": "https://portal.azure.cn/",
        "resource_manager_endpoint": "https://management.chinacloudapi.cn/",
        "active_directory_endpoint": "https://login.chinacloudapi.cn/",
        "aml_resource_endpoint": "https://ml.azure.cn/",
        "storage_suffix": "core.chinacloudapi.cn",
        "registry_discovery_endpoint": "https://chinaeast2.api.ml.azure.cn/",
    },
    "AzureUSGovernment": {
        "portal_endpoint": "https://portal.azure.us/",
        "resource_manager_endpoint": "https://management.usgovcloudapi.net/",
        "active_directory_endpoint": "https://login.microsoftonline.us/",
        "aml_resource_endpoint": "https://ml.azure.us/",
        "storage_suffix": "core.usgovcloudapi.net",
        "registry_discovery_endpoint": "https://usgovarizona.api.ml.azure.us/",
    },
}

# Held by AzureEnvironmentClient's async methods while they read or update _KNOWN_AZURE_ENVIRONMENTS.
_ASYNC_LOCK = asyncio.Lock()
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class AzureEnvironmentClient:
|
|
73
|
+
DEFAULT_API_VERSION: Final[str] = "2019-05-01"
|
|
74
|
+
DEFAULT_AZURE_CLOUD_NAME: Final[str] = _DEFAULT_AZURE_ENV_NAME
|
|
75
|
+
|
|
76
|
+
def __init__(self, *, base_url: Optional[str] = None, **kwargs: Any) -> None:
    """Create a client for the ARM cloud metadata endpoint.

    :param base_url: The metadata URL to query. Defaults to ``get_default_metadata_url()``.
    :param kwargs: Optional ``config`` (an azure-core ``Configuration``) and ``proxy``. All
        remaining keywords are forwarded to ``Configuration``, ``AsyncRetryPolicy`` and
        ``AsyncPipelineClient``.
    """
    if base_url is None:
        base_url = AzureEnvironmentClient.get_default_metadata_url()

    # The fallback Configuration is built before "config" is popped, exactly as a
    # default argument to pop() would be.
    fallback_config = Configuration(**kwargs)
    config: Configuration = kwargs.pop("config", fallback_config)

    if config.retry_policy is None:
        config.retry_policy = AsyncRetryPolicy(**kwargs)

    if config.proxy_policy is None and "proxy" in kwargs:
        proxy = kwargs["proxy"]
        config.proxy_policy = ProxyPolicy(proxies={"http": proxy, "https": proxy})

    self._async_client = AsyncPipelineClient(base_url, config=config, **kwargs)
|
|
86
|
+
|
|
87
|
+
async def get_default_cloud_name_async(self, *, update_cached: bool = True) -> str:
    """Determine the name of the current Azure cloud.

    The name is resolved in this order:

    1. The ``AZUREML_CURRENT_CLOUD`` environment variable, if set.
    2. ``"AzureCloud"``, if ``ARM_CLOUD_METADATA_URL`` is not set.
    3. The first cloud from that metadata URL whose resource manager endpoint appears in the URL.
       The result is cached in ``AZUREML_CURRENT_CLOUD``.
    4. ``"AzureCloud"``, if no cloud matches.

    :param update_cached: Forwarded to ``get_clouds_async``.
    :return: The cloud name.
    """
    configured_name = os.getenv(_ENV_DEFAULT_CLOUD_NAME)
    if configured_name is not None:
        return configured_name

    metadata_url = os.getenv(_ENV_ARM_CLOUD_METADATA_URL)
    if metadata_url is None:
        return _DEFAULT_AZURE_ENV_NAME

    clouds = await self.get_clouds_async(metadata_url=metadata_url, update_cached=update_cached)
    for cloud_name, metadata in clouds.items():
        if metadata["resource_manager_endpoint"] in metadata_url:
            # Remember the match so later calls short-circuit on the environment variable.
            os.environ[_ENV_DEFAULT_CLOUD_NAME] = cloud_name
            return cloud_name

    return _DEFAULT_AZURE_ENV_NAME
|
|
104
|
+
|
|
105
|
+
async def get_cloud_async(self, name: str, *, update_cached: bool = True) -> Optional[AzureEnvironmentMetadata]:
    """Return the metadata for the cloud called ``name``.

    The built-in/cached table is consulted first: an exact key match, then a case- and
    whitespace-insensitive match. Otherwise clouds are fetched from the default ARM metadata
    URL and the same two-step match is applied to the result.

    :param name: The cloud name to look up.
    :param update_cached: Forwarded to ``get_clouds_async``.
    :return: The matching metadata, or ``None`` if no cloud matches.
    """
    wanted = name.strip().lower()

    def lookup(table: Mapping[str, Any]) -> Optional[Any]:
        # Exact key first; fall back to the first normalized-name match.
        exact = table.get(name)
        if exact:
            return exact
        for cloud_name, metadata in table.items():
            if cloud_name.strip().lower() == wanted:
                return metadata
        return None

    default_endpoint: Optional[str]
    async with _ASYNC_LOCK:
        known = lookup(_KNOWN_AZURE_ENVIRONMENTS)
        if known:
            return known
        default_endpoint = _KNOWN_AZURE_ENVIRONMENTS.get(_DEFAULT_AZURE_ENV_NAME, {}).get("resource_manager_endpoint")

    # The lock is released here: get_clouds_async takes it again to update the cache.
    clouds = await self.get_clouds_async(
        metadata_url=self.get_default_metadata_url(default_endpoint),
        update_cached=update_cached,
    )
    return lookup(clouds)
|
|
125
|
+
|
|
126
|
+
async def get_clouds_async(
    self,
    *,
    metadata_url: Optional[str] = None,
    update_cached: bool = True
) -> Mapping[str, AzureEnvironmentMetadata]:
    """Download and parse the list of clouds from an ARM metadata endpoint.

    :keyword metadata_url: URL to query; when falsy, :meth:`get_default_metadata_url`
        supplies it.
    :keyword update_cached: When True, the parsed clouds are merged into the
        module-level ``_KNOWN_AZURE_ENVIRONMENTS`` under ``_ASYNC_LOCK``.
    :return: Mapping of cloud name to its metadata.
    :raises: Whatever ``response.raise_for_status()`` raises on an error status.
    """
    url = metadata_url if metadata_url else self.get_default_metadata_url()

    request = HttpRequest("GET", url)
    async with self._async_client.send_request(request) as response:  # type: ignore
        response.raise_for_status()
        parsed: Mapping[str, AzureEnvironmentMetadata] = await self._parse_cloud_endpoints_async(response.json())

    if not update_cached:
        return parsed

    async with _ASYNC_LOCK:
        recursive_update(_KNOWN_AZURE_ENVIRONMENTS, parsed)
    return parsed
|
|
143
|
+
|
|
144
|
+
async def close(self) -> None:
    """Close the asynchronous pipeline client held by this instance."""
    client = self._async_client
    await client.close()
|
|
146
|
+
|
|
147
|
+
@staticmethod
def get_default_metadata_url(default_endpoint: Optional[str] = None) -> str:
    """Return the ARM cloud-metadata URL to query.

    When the ``_ENV_ARM_CLOUD_METADATA_URL`` environment variable is set, its
    value is returned verbatim. Otherwise the URL is built from
    ``default_endpoint`` (defaulting to ``https://management.azure.com/``),
    the ``metadata/endpoints`` path and ``AzureEnvironmentClient.DEFAULT_API_VERSION``.

    :param default_endpoint: Resource Manager base URL to build the URL from.
        It may be given with or without a trailing slash.
    :return: The metadata URL.
    """
    base = default_endpoint or "https://management.azure.com/"
    # Without this, an endpoint configured as "https://host.com" would produce
    # "https://host.commetadata/endpoints?..." (host and path glued together).
    if not base.endswith("/"):
        base = f"{base}/"
    return os.getenv(
        _ENV_ARM_CLOUD_METADATA_URL,
        f"{base}metadata/endpoints?api-version={AzureEnvironmentClient.DEFAULT_API_VERSION}",
    )
|
|
154
|
+
|
|
155
|
+
@staticmethod
async def _get_registry_discovery_url_async(cloud_name: str, cloud_suffix: str) -> str:
    """Return the registry discovery URL for ``cloud_name``.

    Sources are consulted in order: the ``registry_discovery_endpoint`` of the
    cloud in the module-level cache (read under ``_ASYNC_LOCK``), the
    ``_ENV_REGISTRY_DISCOVERY_URL`` environment variable, and finally a URL
    synthesized from the lower-cased cloud name, the discovery region
    (``_ENV_REGISTRY_DISCOVERY_REGION`` or ``_DEFAULT_REGISTRY_DISCOVERY_REGION``)
    and ``cloud_suffix``.

    :param cloud_name: Name of the cloud.
    :param cloud_suffix: Domain suffix of the cloud (e.g. ``com``).
    :return: The discovery URL.
    """
    async with _ASYNC_LOCK:
        known_env = _KNOWN_AZURE_ENVIRONMENTS.get(cloud_name, {})
        cached_url = known_env.get("registry_discovery_endpoint")
    if cached_url:
        return cached_url

    override_url = os.getenv(_ENV_REGISTRY_DISCOVERY_URL)
    if override_url is not None:
        return override_url

    region = os.getenv(_ENV_REGISTRY_DISCOVERY_REGION, _DEFAULT_REGISTRY_DISCOVERY_REGION)
    host = f"{cloud_name.lower()}{region}.api.ml.azure.{cloud_suffix}"
    return f"https://{host}/"
|
|
168
|
+
|
|
169
|
+
@staticmethod
async def _parse_cloud_endpoints_async(data: Any) -> Mapping[str, AzureEnvironmentMetadata]:
    """Convert an ARM ``metadata/endpoints`` payload into cloud metadata keyed by cloud name.

    :param data: Decoded JSON from the metadata endpoint. A single cloud arrives
        as a dict; several clouds arrive as a list of dicts.
    :return: Mapping of cloud name to its metadata. Every URL-valued field ends
        with ``/``. Entries with missing fields, or with fields of the wrong
        shape (e.g. ``null`` where an object or string is expected), are skipped
        instead of aborting the whole parse.
    """
    # If there is only one cloud, you will get a dict, otherwise a list of dicts
    cloud_data: Sequence[Mapping[str, Any]] = data if not isinstance(data, dict) else [data]
    clouds: Dict[str, AzureEnvironmentMetadata] = {}

    def append_trailing_slash(url: str) -> str:
        return url if url.endswith("/") else f"{url}/"

    for cloud in cloud_data:
        try:
            name: str = cloud["name"]
            portal_endpoint: str = cloud["portal"]
            # e.g. "https://portal.azure.com/" -> "com"; used to build the AML hostnames below.
            cloud_suffix = ".".join(portal_endpoint.split(".")[2:]).replace("/", "")
            discovery_url = await AzureEnvironmentClient._get_registry_discovery_url_async(name, cloud_suffix)
            clouds[name] = {
                "portal_endpoint": append_trailing_slash(portal_endpoint),
                "resource_manager_endpoint": append_trailing_slash(cloud["resourceManager"]),
                "active_directory_endpoint": append_trailing_slash(cloud["authentication"]["loginEndpoint"]),
                "aml_resource_endpoint": append_trailing_slash(f"https://ml.azure.{cloud_suffix}/"),
                "storage_suffix": cloud["suffixes"]["storage"],
                "registry_discovery_endpoint": append_trailing_slash(discovery_url),
            }
        except (KeyError, TypeError, AttributeError):
            # Malformed entry: a missing key (KeyError), a null/non-mapping value
            # being indexed (TypeError), or a non-string URL (AttributeError).
            # Skip only this cloud so the remaining ones are still usable.
            continue

    return clouds
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def recursive_update(d: Dict, u: Mapping) -> None:
    """Recursively update a dictionary.

    Nested mappings in ``u`` are merged into the corresponding nested dicts of
    ``d``. Any other value, or a nested mapping whose counterpart in ``d`` is
    missing or is not a dict, replaces the entry in ``d`` outright.

    :param Dict d: The dictionary to update (mutated in place).
    :param Mapping u: The mapping to update from.
    """
    for k, v in u.items():
        existing = d.get(k)
        # Only recurse when both sides are containers; recursing into a
        # non-dict (e.g. None or a str) would raise instead of updating.
        if isinstance(v, Mapping) and isinstance(existing, dict):
            recursive_update(existing, v)
        else:
            d[k] = v
|
|
@@ -5,11 +5,12 @@ import os
|
|
|
5
5
|
import logging
|
|
6
6
|
import time
|
|
7
7
|
import inspect
|
|
8
|
-
from typing import cast, Optional, Union
|
|
8
|
+
from typing import cast, Optional, Union, Any
|
|
9
9
|
|
|
10
10
|
from azure.core.credentials import TokenCredential, AccessToken
|
|
11
11
|
from azure.identity import AzureCliCredential, DefaultAzureCredential, ManagedIdentityCredential
|
|
12
12
|
from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException
|
|
13
|
+
|
|
13
14
|
from ..simulator._model_tools._identity_manager import APITokenManager, AZURE_TOKEN_REFRESH_INTERVAL
|
|
14
15
|
|
|
15
16
|
|
|
@@ -71,7 +72,8 @@ class AzureMLTokenManager(APITokenManager):
|
|
|
71
72
|
# Fall back to using the parent implementation
|
|
72
73
|
return super().get_aad_credential()
|
|
73
74
|
|
|
74
|
-
def get_token(
|
|
75
|
+
def get_token(
|
|
76
|
+
self, scopes = None, claims: Union[str, None] = None, tenant_id: Union[str, None] = None, enable_cae: bool = False, **kwargs: Any) -> AccessToken:
|
|
75
77
|
"""Get the API token. If the token is not available or has expired, refresh the token.
|
|
76
78
|
|
|
77
79
|
:return: API token
|
|
@@ -79,12 +81,15 @@ class AzureMLTokenManager(APITokenManager):
|
|
|
79
81
|
"""
|
|
80
82
|
if self._token_needs_update():
|
|
81
83
|
credential = cast(TokenCredential, self.credential)
|
|
82
|
-
|
|
84
|
+
token_scope = self.token_scope
|
|
85
|
+
if scopes:
|
|
86
|
+
token_scope = scopes
|
|
87
|
+
access_token = credential.get_token(token_scope)
|
|
83
88
|
self._update_token(access_token)
|
|
84
89
|
|
|
85
|
-
return cast(
|
|
90
|
+
return cast(AccessToken, self.token) # check for none is hidden in the _token_needs_update method
|
|
86
91
|
|
|
87
|
-
async def get_token_async(self) ->
|
|
92
|
+
async def get_token_async(self) -> AccessToken:
|
|
88
93
|
"""Get the API token asynchronously. If the token is not available or has expired, refresh it.
|
|
89
94
|
|
|
90
95
|
:return: API token
|
|
@@ -99,7 +104,7 @@ class AzureMLTokenManager(APITokenManager):
|
|
|
99
104
|
access_token = get_token_method
|
|
100
105
|
self._update_token(access_token)
|
|
101
106
|
|
|
102
|
-
return cast(
|
|
107
|
+
return cast(AccessToken, self.token) # check for none is hidden in the _token_needs_update method
|
|
103
108
|
|
|
104
109
|
def _token_needs_update(self) -> bool:
|
|
105
110
|
current_time = time.time()
|
|
@@ -112,7 +117,7 @@ class AzureMLTokenManager(APITokenManager):
|
|
|
112
117
|
)
|
|
113
118
|
|
|
114
119
|
def _update_token(self, access_token: AccessToken) -> None:
|
|
115
|
-
self.token =
|
|
120
|
+
self.token = access_token
|
|
116
121
|
self.token_expiry_time = access_token.expires_on
|
|
117
122
|
self.last_refresh_time = time.time()
|
|
118
123
|
self.logger.info("Refreshed Azure management token.")
|
|
@@ -8,9 +8,14 @@
|
|
|
8
8
|
from . import constants
|
|
9
9
|
from .rai_service import evaluate_with_rai_service
|
|
10
10
|
from .utils import get_harm_severity_level
|
|
11
|
+
from .evaluation_onedp_client import EvaluationServiceOneDPClient
|
|
12
|
+
from .onedp.models import EvaluationUpload, EvaluationResult
|
|
11
13
|
|
|
12
14
|
__all__ = [
|
|
13
15
|
"get_harm_severity_level",
|
|
14
16
|
"evaluate_with_rai_service",
|
|
15
17
|
"constants",
|
|
18
|
+
"EvaluationServiceOneDPClient",
|
|
19
|
+
"EvaluationResult",
|
|
20
|
+
"EvaluationUpload",
|
|
16
21
|
]
|