scale-nucleus 0.1.24__py3-none-any.whl → 0.6.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/client.py +14 -0
- cli/datasets.py +77 -0
- cli/helpers/__init__.py +0 -0
- cli/helpers/nucleus_url.py +10 -0
- cli/helpers/web_helper.py +40 -0
- cli/install_completion.py +33 -0
- cli/jobs.py +42 -0
- cli/models.py +35 -0
- cli/nu.py +42 -0
- cli/reference.py +8 -0
- cli/slices.py +62 -0
- cli/tests.py +121 -0
- nucleus/__init__.py +446 -710
- nucleus/annotation.py +405 -85
- nucleus/autocurate.py +9 -0
- nucleus/connection.py +87 -0
- nucleus/constants.py +5 -1
- nucleus/data_transfer_object/__init__.py +0 -0
- nucleus/data_transfer_object/dataset_details.py +9 -0
- nucleus/data_transfer_object/dataset_info.py +26 -0
- nucleus/data_transfer_object/dataset_size.py +5 -0
- nucleus/data_transfer_object/scenes_list.py +18 -0
- nucleus/dataset.py +1137 -212
- nucleus/dataset_item.py +130 -26
- nucleus/dataset_item_uploader.py +297 -0
- nucleus/deprecation_warning.py +32 -0
- nucleus/errors.py +9 -0
- nucleus/job.py +71 -3
- nucleus/logger.py +9 -0
- nucleus/metadata_manager.py +45 -0
- nucleus/metrics/__init__.py +10 -0
- nucleus/metrics/base.py +117 -0
- nucleus/metrics/categorization_metrics.py +197 -0
- nucleus/metrics/errors.py +7 -0
- nucleus/metrics/filters.py +40 -0
- nucleus/metrics/geometry.py +198 -0
- nucleus/metrics/metric_utils.py +28 -0
- nucleus/metrics/polygon_metrics.py +480 -0
- nucleus/metrics/polygon_utils.py +299 -0
- nucleus/model.py +121 -15
- nucleus/model_run.py +34 -57
- nucleus/payload_constructor.py +29 -19
- nucleus/prediction.py +259 -17
- nucleus/pydantic_base.py +26 -0
- nucleus/retry_strategy.py +4 -0
- nucleus/scene.py +204 -19
- nucleus/slice.py +230 -67
- nucleus/upload_response.py +20 -9
- nucleus/url_utils.py +4 -0
- nucleus/utils.py +134 -37
- nucleus/validate/__init__.py +24 -0
- nucleus/validate/client.py +168 -0
- nucleus/validate/constants.py +20 -0
- nucleus/validate/data_transfer_objects/__init__.py +0 -0
- nucleus/validate/data_transfer_objects/eval_function.py +81 -0
- nucleus/validate/data_transfer_objects/scenario_test.py +19 -0
- nucleus/validate/data_transfer_objects/scenario_test_evaluations.py +11 -0
- nucleus/validate/data_transfer_objects/scenario_test_metric.py +12 -0
- nucleus/validate/errors.py +6 -0
- nucleus/validate/eval_functions/__init__.py +0 -0
- nucleus/validate/eval_functions/available_eval_functions.py +212 -0
- nucleus/validate/eval_functions/base_eval_function.py +60 -0
- nucleus/validate/scenario_test.py +143 -0
- nucleus/validate/scenario_test_evaluation.py +114 -0
- nucleus/validate/scenario_test_metric.py +14 -0
- nucleus/validate/utils.py +8 -0
- {scale_nucleus-0.1.24.dist-info → scale_nucleus-0.6.4.dist-info}/LICENSE +0 -0
- scale_nucleus-0.6.4.dist-info/METADATA +213 -0
- scale_nucleus-0.6.4.dist-info/RECORD +71 -0
- {scale_nucleus-0.1.24.dist-info → scale_nucleus-0.6.4.dist-info}/WHEEL +1 -1
- scale_nucleus-0.6.4.dist-info/entry_points.txt +3 -0
- scale_nucleus-0.1.24.dist-info/METADATA +0 -85
- scale_nucleus-0.1.24.dist-info/RECORD +0 -21
@@ -0,0 +1,212 @@ nucleus/validate/eval_functions/available_eval_functions.py (new file)

import itertools
from typing import Callable, Dict, List, Type, Union

from nucleus.logger import logger
from nucleus.validate.eval_functions.base_eval_function import BaseEvalFunction

from ..data_transfer_objects.eval_function import EvalFunctionEntry
from ..errors import EvalFunctionNotAvailableError

MEAN_AVG_PRECISION_NAME = "mean_average_precision_boxes"


class BoundingBoxIOU(BaseEvalFunction):
    @classmethod
    def expected_name(cls) -> str:
        return "bbox_iou"


class BoundingBoxMeanAveragePrecision(BaseEvalFunction):
    @classmethod
    def expected_name(cls) -> str:
        return "bbox_map"


class BoundingBoxRecall(BaseEvalFunction):
    @classmethod
    def expected_name(cls) -> str:
        return "bbox_recall"


class BoundingBoxPrecision(BaseEvalFunction):
    @classmethod
    def expected_name(cls) -> str:
        return "bbox_precision"


class CategorizationF1(BaseEvalFunction):
    @classmethod
    def expected_name(cls) -> str:
        return "cat_f1"


class CustomEvalFunction(BaseEvalFunction):
    @classmethod
    def expected_name(cls) -> str:
        raise NotImplementedError(
            "Custom evaluation functions are coming soon"
        )  # Placeholder: See super().eval_func_entry for actual name


class StandardEvalFunction(BaseEvalFunction):
    """Class for standard Model CI eval functions that have not been added as attributes on
    AvailableEvalFunctions yet.
    """

    def __init__(self, eval_function_entry: EvalFunctionEntry):
        logger.warning(
            "Standard function %s not implemented as an attribute on AvailableEvalFunctions",
            eval_function_entry.name,
        )
        super().__init__(eval_function_entry)

    @classmethod
    def expected_name(cls) -> str:
        return "public_function"  # Placeholder: See super().eval_func_entry for actual name


class EvalFunctionNotAvailable(BaseEvalFunction):
    def __init__(
        self, not_available_name: str
    ):  # pylint: disable=super-init-not-called
        self.not_available_name = not_available_name

    def __call__(self, *args, **kwargs):
        self._raise_error()

    def _op_to_test_metric(self, *args, **kwargs):
        self._raise_error()

    def _raise_error(self):
        raise EvalFunctionNotAvailableError(
            f"Eval function '{self.not_available_name}' is not available to the current user. "
            f"Is Model CI enabled for the user?"
        )

    @classmethod
    def expected_name(cls) -> str:
        return "public_function"  # Placeholder: See super().eval_func_entry for actual name


EvalFunction = Union[
    Type[BoundingBoxIOU],
    Type[BoundingBoxMeanAveragePrecision],
    Type[BoundingBoxPrecision],
    Type[BoundingBoxRecall],
    Type[CustomEvalFunction],
    Type[EvalFunctionNotAvailable],
    Type[StandardEvalFunction],
]


class AvailableEvalFunctions:
    """Collection class that acts as a common entrypoint to access evaluation functions. Standard evaluation functions
    provided by Scale are attributes of this class.

    The available evaluation functions are listed in the sample below::

        e = client.validate.eval_functions
        unit_test_criteria = [
            e.bbox_iou() > 5,
            e.bbox_map() > 0.95,
            e.bbox_precision() > 0.8,
            e.bbox_recall() > 0.5,
        ]
    """

    # pylint: disable=too-many-instance-attributes

    def __init__(self, available_functions: List[EvalFunctionEntry]):
        assert (
            available_functions
        ), "Passed no available functions for current user. Is the feature flag enabled?"
        self._public_func_entries: Dict[str, EvalFunctionEntry] = {
            f.name: f for f in available_functions if f.is_public
        }
        # NOTE: Public functions are assigned below
        self._public_to_function: Dict[str, BaseEvalFunction] = {}
        self._custom_to_function: Dict[str, CustomEvalFunction] = {
            f.name: CustomEvalFunction(f)
            for f in available_functions
            if not f.is_public
        }
        self.bbox_iou = self._assign_eval_function_if_defined(BoundingBoxIOU)  # type: ignore
        self.bbox_precision = self._assign_eval_function_if_defined(
            BoundingBoxPrecision  # type: ignore
        )
        self.bbox_recall = self._assign_eval_function_if_defined(
            BoundingBoxRecall  # type: ignore
        )
        self.bbox_map = self._assign_eval_function_if_defined(
            BoundingBoxMeanAveragePrecision  # type: ignore
        )
        self.cat_f1 = self._assign_eval_function_if_defined(
            CategorizationF1  # type: ignore
        )

        # Add public entries that have not been implemented as an attribute on this class
        for func_entry in self._public_func_entries.values():
            if func_entry.name not in self._public_to_function:
                self._public_to_function[
                    func_entry.name
                ] = StandardEvalFunction(func_entry)

    def __repr__(self):
        """Standard functions are ones Scale provides and custom ones are customer defined"""
        # NOTE: setting to lower to be consistent with attribute names
        functions_lower = [
            str(name).lower() for name in self._public_func_entries.keys()
        ]
        return (
            f"<AvailableEvaluationFunctions: public: {functions_lower} "
            f"private: {list(self._custom_to_function.keys())}>"
        )

    @property
    def public_functions(self) -> Dict[str, BaseEvalFunction]:
        """Standard functions provided by Model CI.

        Notes:
            These functions are also available as attributes on :class:`AvailableEvalFunctions`

        Returns:
            Dict of function name to :class:`BaseEvalFunction`.
        """
        return self._public_to_function

    @property
    def private_functions(self) -> Dict[str, CustomEvalFunction]:
        """Custom functions uploaded to Model CI

        Returns:
            Dict of function name to :class:`CustomEvalFunction`.
        """
        return self._custom_to_function

    def _assign_eval_function_if_defined(
        self,
        eval_function_constructor: Callable[[EvalFunctionEntry], EvalFunction],
    ):
        """Helper function for book-keeping and assignment of standard Scale-provided functions that are accessible
        via attribute access
        """
        # TODO(gunnar): Too convoluted .. simplify
        expected_name = eval_function_constructor.expected_name()  # type: ignore
        if expected_name in self._public_func_entries:
            definition = self._public_func_entries[expected_name]
            eval_function = eval_function_constructor(definition)
            self._public_to_function[expected_name] = eval_function  # type: ignore
            return eval_function
        else:
            return EvalFunctionNotAvailable(expected_name)

    def from_id(self, eval_function_id: str):
        for eval_func in itertools.chain(
            self._public_to_function.values(),
            self._custom_to_function.values(),
        ):
            if eval_func.id == eval_function_id:
                return eval_func
        raise EvalFunctionNotAvailableError(
            f"Could not find Eval Function with id {eval_function_id}"
        )
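In practice this collection is reached through the Validate client rather than constructed directly. A minimal usage sketch, assuming scale-nucleus 0.6.4 is installed, the API key placeholder is replaced, and the listed public functions are enabled for the account:

```python
import nucleus

client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
e = client.validate.eval_functions  # an AvailableEvalFunctions instance

# Attribute access returns the standard Scale-provided functions; comparing a
# called function against a number produces an EvaluationCriterion (see
# base_eval_function.py below).
criteria = [
    e.bbox_iou() > 0.5,
    e.bbox_recall() > 0.6,
]

# Functions can also be inspected by visibility.
print(e.public_functions)   # name -> BaseEvalFunction
print(e.private_functions)  # name -> CustomEvalFunction
```

Any standard function the backend reports but that has no attribute yet falls back to StandardEvalFunction, and a function the user is not entitled to becomes an EvalFunctionNotAvailable that raises only when actually used.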
@@ -0,0 +1,60 @@ nucleus/validate/eval_functions/base_eval_function.py (new file)

import abc

from ..constants import ThresholdComparison
from ..data_transfer_objects.eval_function import (
    EvalFunctionEntry,
    EvaluationCriterion,
)


class BaseEvalFunction(abc.ABC):
    """Abstract base class for concrete implementations of EvalFunctions

    Operating on this class with comparison operators produces an EvaluationCriterion
    """

    def __init__(self, eval_func_entry: EvalFunctionEntry):
        self.eval_func_entry = eval_func_entry
        self.id = eval_func_entry.id
        self.name = eval_func_entry.name

    def __repr__(self):
        return f"<EvalFunction: name={self.name}, id={self.id}>"

    @classmethod
    @abc.abstractmethod
    def expected_name(cls) -> str:
        """Name to look for in the EvalFunctionDefinitions"""

    def __call__(self) -> "BaseEvalFunction":
        """Adding call to prepare for being able to pass parameters to function

        Notes:
            Technically now you could do something like eval_function > 0.5 but we want it
            to look like eval_function() > 0.5 to support eval_function(parameters) > 0.5
            in the future
        """
        return self

    def __gt__(self, other) -> EvaluationCriterion:
        return self._op_to_test_metric(ThresholdComparison.GREATER_THAN, other)

    def __ge__(self, other) -> EvaluationCriterion:
        return self._op_to_test_metric(
            ThresholdComparison.GREATER_THAN_EQUAL_TO, other
        )

    def __lt__(self, other) -> EvaluationCriterion:
        return self._op_to_test_metric(ThresholdComparison.LESS_THAN, other)

    def __le__(self, other) -> EvaluationCriterion:
        return self._op_to_test_metric(
            ThresholdComparison.LESS_THAN_EQUAL_TO, other
        )

    def _op_to_test_metric(self, comparison: ThresholdComparison, value):
        return EvaluationCriterion(
            eval_function_id=self.eval_func_entry.id,
            threshold_comparison=comparison,
            threshold=value,
        )
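The operator overloads are the whole trick here: `fn() > 0.5` builds an EvaluationCriterion rather than a boolean. The following self-contained sketch shows the pattern with simplified stand-ins for the real pydantic DTOs; the `ef_hypothetical_id` value and the enum member value strings are assumptions, only the member names come from the code above.

```python
from dataclasses import dataclass
from enum import Enum


class ThresholdComparison(Enum):
    # Member value strings here are assumptions; only the member names are
    # confirmed by base_eval_function.py above.
    GREATER_THAN = "greater_than"
    GREATER_THAN_EQUAL_TO = "greater_than_equal_to"
    LESS_THAN = "less_than"
    LESS_THAN_EQUAL_TO = "less_than_equal_to"


@dataclass
class EvaluationCriterion:
    # Simplified stand-in for the pydantic DTO of the same name.
    eval_function_id: str
    threshold_comparison: ThresholdComparison
    threshold: float


@dataclass
class EvalFunctionStub:
    # Stand-in for a concrete BaseEvalFunction subclass such as BoundingBoxIOU.
    id: str
    name: str

    def __call__(self) -> "EvalFunctionStub":
        # Mirrors BaseEvalFunction.__call__: returning self lets criteria read
        # as fn() > 0.5 while leaving room for fn(parameters) later.
        return self

    def __gt__(self, other: float) -> EvaluationCriterion:
        return EvaluationCriterion(self.id, ThresholdComparison.GREATER_THAN, other)


bbox_iou = EvalFunctionStub(id="ef_hypothetical_id", name="bbox_iou")
criterion = bbox_iou() > 0.5
print(criterion.threshold_comparison, criterion.threshold)
# ThresholdComparison.GREATER_THAN 0.5
```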
@@ -0,0 +1,143 @@ nucleus/validate/scenario_test.py (new file)

"""Scenario Tests combine collections of data and evaluation metrics to accelerate model evaluation.

With Model CI Scenario Tests, an ML engineer can define a Scenario Test from critical
edge case scenarios that the model must get right (e.g. pedestrians at night),
and have confidence that they're always shipping the best model.
"""
from dataclasses import dataclass, field
from typing import List

from ..connection import Connection
from ..constants import NAME_KEY, SLICE_ID_KEY
from ..dataset_item import DatasetItem
from .data_transfer_objects.eval_function import EvaluationCriterion
from .data_transfer_objects.scenario_test_evaluations import GetEvalHistory
from .data_transfer_objects.scenario_test_metric import AddScenarioTestMetric
from .scenario_test_evaluation import ScenarioTestEvaluation
from .scenario_test_metric import ScenarioTestMetric

DATASET_ITEMS_KEY = "dataset_items"


@dataclass
class ScenarioTest:
    """A Scenario Test combines a slice and at least one evaluation criterion. A :class:`ScenarioTest` is not created through
    the default constructor but using the instructions shown in :class:`Validate`. This :class:`ScenarioTest` class only
    simplifies the interaction with the scenario tests from this SDK.

    Attributes:
        id (str): The ID of the scenario test.
        connection (Connection): The connection to the Nucleus API.
        name (str): The name of the scenario test.
        slice_id (str): The ID of the associated Nucleus slice.
    """

    id: str
    connection: Connection = field(repr=False)
    name: str = field(init=False)
    slice_id: str = field(init=False)

    def __post_init__(self):
        # TODO(gunnar): Remove this pattern. It's too slow. We should get all the info required in one call
        response = self.connection.get(
            f"validate/scenario_test/{self.id}/info",
        )
        self.name = response[NAME_KEY]
        self.slice_id = response[SLICE_ID_KEY]

    def add_criterion(
        self, evaluation_criterion: EvaluationCriterion
    ) -> ScenarioTestMetric:
        """Creates and adds a new criterion to the :class:`ScenarioTest`. ::

            import nucleus
            client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
            scenario_test = client.validate.create_scenario_test(
                "sample_scenario_test", "slc_bx86ea222a6g057x4380"
            )

            e = client.validate.eval_functions
            # Assuming a user would like to add all available public evaluation functions as criteria
            scenario_test.add_criterion(
                e.bbox_iou() > 0.5
            )
            scenario_test.add_criterion(
                e.bbox_map() > 0.85
            )
            scenario_test.add_criterion(
                e.bbox_precision() > 0.7
            )
            scenario_test.add_criterion(
                e.bbox_recall() > 0.6
            )

        Args:
            evaluation_criterion: :class:`EvaluationCriterion` created by comparison with an :class:`EvalFunction`

        Returns:
            The created ScenarioTestMetric object.
        """
        response = self.connection.post(
            AddScenarioTestMetric(
                scenario_test_name=self.name,
                eval_function_id=evaluation_criterion.eval_function_id,
                threshold=evaluation_criterion.threshold,
                threshold_comparison=evaluation_criterion.threshold_comparison,
            ).dict(),
            "validate/scenario_test_metric",
        )
        return ScenarioTestMetric(
            scenario_test_id=response["scenario_test_id"],
            eval_function_id=response["eval_function_id"],
            threshold=evaluation_criterion.threshold,
            threshold_comparison=evaluation_criterion.threshold_comparison,
        )

    def get_criteria(self) -> List[ScenarioTestMetric]:
        """Retrieves all criteria of the :class:`ScenarioTest`. ::

            import nucleus
            client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
            scenario_test = client.validate.list_scenario_tests()[0]

            scenario_test.get_criteria()

        Returns:
            A list of ScenarioTestMetric objects.
        """
        response = self.connection.get(
            f"validate/scenario_test/{self.id}/metrics",
        )
        return [
            ScenarioTestMetric(**metric)
            for metric in response["scenario_test_metrics"]
        ]

    def get_eval_history(self) -> List[ScenarioTestEvaluation]:
        """Retrieves evaluation history for :class:`ScenarioTest`. ::

            import nucleus
            client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
            scenario_test = client.validate.list_scenario_tests()[0]

            scenario_test.get_eval_history()

        Returns:
            A list of :class:`ScenarioTestEvaluation` objects.
        """
        response = self.connection.get(
            f"validate/scenario_test/{self.id}/eval_history",
        )
        eval_history = GetEvalHistory.parse_obj(response)
        return [
            ScenarioTestEvaluation(evaluation.id, self.connection)
            for evaluation in eval_history.evaluations
        ]

    def get_items(self) -> List[DatasetItem]:
        response = self.connection.get(
            f"validate/scenario_test/{self.id}/items",
        )
        return [
            DatasetItem.from_json(item) for item in response[DATASET_ITEMS_KEY]
        ]
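A usage sketch stitched together from the docstrings above; the API key is a placeholder and the account is assumed to already have at least one scenario test.

```python
import nucleus

client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
e = client.validate.eval_functions

# Pick an existing scenario test and attach another pass/fail criterion to it.
scenario_test = client.validate.list_scenario_tests()[0]
scenario_test.add_criterion(e.bbox_recall() > 0.6)

# Inspect the configured criteria, the evaluation history, and the slice contents.
for metric in scenario_test.get_criteria():
    print(metric.eval_function_id, metric.threshold_comparison, metric.threshold)
for evaluation in scenario_test.get_eval_history():
    print(evaluation.model_id, evaluation.status, evaluation.passed)
print(len(scenario_test.get_items()), "items in slice", scenario_test.slice_id)
```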
@@ -0,0 +1,114 @@ nucleus/validate/scenario_test_evaluation.py (new file)

"""Data types for Scenario Test Evaluation results."""
from dataclasses import InitVar, dataclass, field
from enum import Enum
from typing import List, Optional

import requests

from nucleus.connection import Connection
from nucleus.constants import DATASET_ITEM_ID_KEY, MODEL_ID_KEY, STATUS_KEY
from nucleus.validate.constants import (
    EVAL_FUNCTION_ID_KEY,
    EVALUATION_ID_KEY,
    PASS_KEY,
    RESULT_KEY,
    SCENARIO_TEST_ID_KEY,
)

from .utils import try_convert_float

SCENARIO_TEST_EVAL_KEY = "scenario_test_evaluation"
ITEM_EVAL_KEY = "scenario_test_item_evaluations"


class ScenarioTestEvaluationStatus(Enum):
    """The job status of a scenario test evaluation."""

    PENDING = "pending"
    STARTED = "started"
    COMPLETED = "completed"
    ERRORED = "errored"


@dataclass(frozen=True)
class ScenarioTestItemEvaluation:
    """Dataset item-level results of an evaluation of a scenario test.
    Note that this class is immutable.

    Attributes:
        evaluation_id (str): The ID of the associated scenario test evaluation.
        scenario_test_id (str): The ID of the associated scenario test.
        eval_function_id (str): The ID of the associated evaluation function.
        dataset_item_id (str): The ID of the dataset item of this evaluation.
        result (Optional[float]): The numerical result of the evaluation on this item.
        passed (bool): Whether the result was sufficient to pass the test for this item.
    """

    evaluation_id: str
    scenario_test_id: str
    eval_function_id: str
    dataset_item_id: str
    result: Optional[float]
    passed: bool


@dataclass
class ScenarioTestEvaluation:
    """The results and attributes of an evaluation of a scenario test.

    Attributes:
        id (str): The ID of this scenario test evaluation.
        scenario_test_id (str): The ID of the associated scenario test.
        eval_function_id (str): The ID of the associated evaluation function.
        model_id (str): The ID of the associated model.
        status (str): The status of the evaluation job.
        result (Optional[float]): The float result of the evaluation.
        passed (bool): Whether the scenario test was passed.
        item_evals (List[ScenarioTestItemEvaluation]): The individual results for each dataset item.
        connection (Connection): The connection to the Nucleus API.
    """

    # pylint: disable=too-many-instance-attributes

    id: str
    scenario_test_id: str = field(init=False)
    eval_function_id: str = field(init=False)
    model_id: str = field(init=False)
    status: ScenarioTestEvaluationStatus = field(init=False)
    result: Optional[float] = field(init=False)
    passed: bool = field(init=False)
    item_evals: List[ScenarioTestItemEvaluation] = field(init=False)
    connection: InitVar[Connection]

    def __post_init__(self, connection: Connection):
        # TODO(gunnar): Having the function call /info on every construction is too slow. The original
        # endpoint should rather return the necessary human-readable information
        response = connection.make_request(
            {},
            f"validate/eval/{self.id}/info",
            requests_command=requests.get,
        )
        eval_response = response[SCENARIO_TEST_EVAL_KEY]
        items_response = response[ITEM_EVAL_KEY]

        self.scenario_test_id: str = eval_response[SCENARIO_TEST_ID_KEY]
        self.eval_function_id: str = eval_response[EVAL_FUNCTION_ID_KEY]
        self.model_id: str = eval_response[MODEL_ID_KEY]
        self.status: ScenarioTestEvaluationStatus = (
            ScenarioTestEvaluationStatus(eval_response[STATUS_KEY])
        )
        self.result: Optional[float] = try_convert_float(
            eval_response[RESULT_KEY]
        )
        self.passed: bool = bool(eval_response[PASS_KEY])
        self.item_evals: List[ScenarioTestItemEvaluation] = [
            ScenarioTestItemEvaluation(
                evaluation_id=res[EVALUATION_ID_KEY],
                scenario_test_id=res[SCENARIO_TEST_ID_KEY],
                eval_function_id=res[EVAL_FUNCTION_ID_KEY],
                dataset_item_id=res[DATASET_ITEM_ID_KEY],
                result=try_convert_float(res[RESULT_KEY]),
                passed=bool(res[PASS_KEY]),
            )
            for res in items_response
        ]
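Since each evaluation carries its per-item results, the natural follow-up is to drill into the failing items. A sketch under the same assumptions as before (placeholder API key, a scenario test with at least one completed evaluation):

```python
import nucleus
from nucleus.validate.scenario_test_evaluation import ScenarioTestEvaluationStatus

client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
scenario_test = client.validate.list_scenario_tests()[0]

# Keep only finished evaluations and look at the most recent one.
completed = [
    ev
    for ev in scenario_test.get_eval_history()
    if ev.status == ScenarioTestEvaluationStatus.COMPLETED
]
latest = completed[-1]

# Collect the dataset items whose per-item result fell below the threshold.
failed_item_ids = [
    item_eval.dataset_item_id
    for item_eval in latest.item_evals
    if not item_eval.passed
]
print(f"{latest.model_id}: {len(failed_item_ids)} items below threshold")
```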
@@ -0,0 +1,14 @@ nucleus/validate/scenario_test_metric.py (new file)

from nucleus.pydantic_base import ImmutableModel

from .constants import ThresholdComparison


class ScenarioTestMetric(ImmutableModel):
    """A Scenario Test Metric is an evaluation function combined with a comparator and associated with a Scenario Test.
    Scenario Test Metrics serve as the basis when evaluating a Model on a Scenario Test.
    """

    scenario_test_id: str
    eval_function_id: str
    threshold: float
    threshold_comparison: ThresholdComparison
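ScenarioTestMetric is a plain data container, normally returned by add_criterion() and get_criteria() rather than built by hand. A small sketch of constructing one directly, with placeholder IDs and assuming ImmutableModel is a pydantic BaseModel subclass (so fields are validated on construction and mutation is rejected afterwards):

```python
from nucleus.validate.constants import ThresholdComparison
from nucleus.validate.scenario_test_metric import ScenarioTestMetric

metric = ScenarioTestMetric(
    scenario_test_id="scenario_test_placeholder_id",  # hypothetical ID
    eval_function_id="ef_placeholder_id",             # hypothetical ID
    threshold=0.5,
    threshold_comparison=ThresholdComparison.GREATER_THAN,
)
print(metric)
# Assigning to metric.threshold afterwards is expected to raise, since the model is immutable.
```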