scale-nucleus 0.1.24__py3-none-any.whl → 0.6.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/client.py +14 -0
- cli/datasets.py +77 -0
- cli/helpers/__init__.py +0 -0
- cli/helpers/nucleus_url.py +10 -0
- cli/helpers/web_helper.py +40 -0
- cli/install_completion.py +33 -0
- cli/jobs.py +42 -0
- cli/models.py +35 -0
- cli/nu.py +42 -0
- cli/reference.py +8 -0
- cli/slices.py +62 -0
- cli/tests.py +121 -0
- nucleus/__init__.py +446 -710
- nucleus/annotation.py +405 -85
- nucleus/autocurate.py +9 -0
- nucleus/connection.py +87 -0
- nucleus/constants.py +5 -1
- nucleus/data_transfer_object/__init__.py +0 -0
- nucleus/data_transfer_object/dataset_details.py +9 -0
- nucleus/data_transfer_object/dataset_info.py +26 -0
- nucleus/data_transfer_object/dataset_size.py +5 -0
- nucleus/data_transfer_object/scenes_list.py +18 -0
- nucleus/dataset.py +1137 -212
- nucleus/dataset_item.py +130 -26
- nucleus/dataset_item_uploader.py +297 -0
- nucleus/deprecation_warning.py +32 -0
- nucleus/errors.py +9 -0
- nucleus/job.py +71 -3
- nucleus/logger.py +9 -0
- nucleus/metadata_manager.py +45 -0
- nucleus/metrics/__init__.py +10 -0
- nucleus/metrics/base.py +117 -0
- nucleus/metrics/categorization_metrics.py +197 -0
- nucleus/metrics/errors.py +7 -0
- nucleus/metrics/filters.py +40 -0
- nucleus/metrics/geometry.py +198 -0
- nucleus/metrics/metric_utils.py +28 -0
- nucleus/metrics/polygon_metrics.py +480 -0
- nucleus/metrics/polygon_utils.py +299 -0
- nucleus/model.py +121 -15
- nucleus/model_run.py +34 -57
- nucleus/payload_constructor.py +29 -19
- nucleus/prediction.py +259 -17
- nucleus/pydantic_base.py +26 -0
- nucleus/retry_strategy.py +4 -0
- nucleus/scene.py +204 -19
- nucleus/slice.py +230 -67
- nucleus/upload_response.py +20 -9
- nucleus/url_utils.py +4 -0
- nucleus/utils.py +134 -37
- nucleus/validate/__init__.py +24 -0
- nucleus/validate/client.py +168 -0
- nucleus/validate/constants.py +20 -0
- nucleus/validate/data_transfer_objects/__init__.py +0 -0
- nucleus/validate/data_transfer_objects/eval_function.py +81 -0
- nucleus/validate/data_transfer_objects/scenario_test.py +19 -0
- nucleus/validate/data_transfer_objects/scenario_test_evaluations.py +11 -0
- nucleus/validate/data_transfer_objects/scenario_test_metric.py +12 -0
- nucleus/validate/errors.py +6 -0
- nucleus/validate/eval_functions/__init__.py +0 -0
- nucleus/validate/eval_functions/available_eval_functions.py +212 -0
- nucleus/validate/eval_functions/base_eval_function.py +60 -0
- nucleus/validate/scenario_test.py +143 -0
- nucleus/validate/scenario_test_evaluation.py +114 -0
- nucleus/validate/scenario_test_metric.py +14 -0
- nucleus/validate/utils.py +8 -0
- {scale_nucleus-0.1.24.dist-info → scale_nucleus-0.6.4.dist-info}/LICENSE +0 -0
- scale_nucleus-0.6.4.dist-info/METADATA +213 -0
- scale_nucleus-0.6.4.dist-info/RECORD +71 -0
- {scale_nucleus-0.1.24.dist-info → scale_nucleus-0.6.4.dist-info}/WHEEL +1 -1
- scale_nucleus-0.6.4.dist-info/entry_points.txt +3 -0
- scale_nucleus-0.1.24.dist-info/METADATA +0 -85
- scale_nucleus-0.1.24.dist-info/RECORD +0 -21
nucleus/utils.py
CHANGED
@@ -1,10 +1,10 @@
 """Shared stateless utility function library"""
 
-from collections import defaultdict
 import io
-import uuid
 import json
-
+import uuid
+from collections import defaultdict
+from typing import IO, Dict, List, Sequence, Type, Union
 
 import requests
 from requests.models import HTTPError
@@ -12,9 +12,10 @@ from requests.models import HTTPError
 from nucleus.annotation import (
     Annotation,
     BoxAnnotation,
+    CategoryAnnotation,
     CuboidAnnotation,
+    MultiCategoryAnnotation,
     PolygonAnnotation,
-    CategoryAnnotation,
     SegmentationAnnotation,
 )
 
@@ -22,55 +23,126 @@ from .constants import (
     ANNOTATION_TYPES,
     ANNOTATIONS_KEY,
     BOX_TYPE,
-    CUBOID_TYPE,
     CATEGORY_TYPE,
+    CUBOID_TYPE,
     ITEM_KEY,
+    MULTICATEGORY_TYPE,
     POLYGON_TYPE,
     REFERENCE_ID_KEY,
     SEGMENTATION_TYPE,
 )
 from .dataset_item import DatasetItem
-from .prediction import
+from .prediction import (
+    BoxPrediction,
+    CategoryPrediction,
+    CuboidPrediction,
+    PolygonPrediction,
+    SegmentationPrediction,
+)
 from .scene import LidarScene
 
+STRING_REPLACEMENTS = {
+    "\\\\n": "\n",
+    "\\\\t": "\t",
+    '\\\\"': '"',
+}
 
-def _get_all_field_values(metadata_list: List[dict], key: str):
-    return {metadata[key] for metadata in metadata_list if key in metadata}
 
+class KeyErrorDict(dict):
+    """Wrapper for response dicts with deprecated keys.
 
+    Parameters:
+        **kwargs: Mapping from the deprecated key to a warning message.
+    """
+
+    def __init__(self, **kwargs):
+        self._deprecated = {}
+
+        for key, msg in kwargs.items():
+            if not isinstance(key, str):
+                raise TypeError(
+                    f"All keys must be strings! Received non-string '{key}'"
+                )
+            if not isinstance(msg, str):
+                raise TypeError(
+                    f"All warning messages must be strings! Received non-string '{msg}'"
+                )
+
+            self._deprecated[key] = msg
 
+        super().__init__()
+
+    def __missing__(self, key):
+        """Raises KeyError for deprecated keys, otherwise uses base dict logic."""
+        if key in self._deprecated:
+            raise KeyError(self._deprecated[key])
+        try:
+            super().__missing__(key)
+        except AttributeError as e:
+            raise KeyError(key) from e
+
+
+def format_prediction_response(
+    response: dict,
+) -> Union[
+    dict,
+    List[
+        Union[
+            BoxPrediction,
+            PolygonPrediction,
+            CuboidPrediction,
+            CategoryPrediction,
+            SegmentationPrediction,
+        ]
+    ],
+]:
+    """Helper function to convert JSON response from endpoints to python objects
+
+    Args:
+        response: JSON dictionary response from REST endpoint.
+    Returns:
+        annotation_response: Dictionary containing a list of annotations for each type,
+            keyed by the type name.
+    """
+    annotation_payload = response.get(ANNOTATIONS_KEY, None)
+    if not annotation_payload:
+        # An error occurred
+        return response
+    annotation_response = {}
+    type_key_to_class: Dict[
+        str,
+        Union[
+            Type[BoxPrediction],
+            Type[PolygonPrediction],
+            Type[CuboidPrediction],
+            Type[CategoryPrediction],
+            Type[SegmentationPrediction],
+        ],
+    ] = {
+        BOX_TYPE: BoxPrediction,
+        POLYGON_TYPE: PolygonPrediction,
+        CUBOID_TYPE: CuboidPrediction,
+        CATEGORY_TYPE: CategoryPrediction,
+        SEGMENTATION_TYPE: SegmentationPrediction,
+    }
+    for type_key in annotation_payload:
+        type_class = type_key_to_class[type_key]
+        annotation_response[type_key] = [
+            type_class.from_json(annotation)
+            for annotation in annotation_payload[type_key]
+        ]
+    return annotation_response
 
 
 def format_dataset_item_response(response: dict) -> dict:
-    """Format the raw client response into api objects.
+    """Format the raw client response into api objects.
+
+    Args:
+        response: JSON dictionary response from REST endpoint
+    Returns:
+        item_dict: A dictionary with two entries, one for the dataset item, and annother
+            for all of the associated annotations.
+    """
     if ANNOTATIONS_KEY not in response:
         raise ValueError(
             f"Server response was missing the annotation key: {response}"
@@ -96,6 +168,15 @@ def format_dataset_item_response(response: dict) -> dict:
 
 
 def convert_export_payload(api_payload):
+    """Helper function to convert raw JSON to API objects
+
+    Args:
+        api_payload: JSON dictionary response from REST endpoint
+    Returns:
+        return_payload: A list of dictionaries for each dataset item. Each dictionary
+            is in the same format as format_dataset_item_response: one key for the
+            dataset item, another for the annotations.
+    """
     return_payload = []
     for row in api_payload:
         return_payload_row = {}
@@ -123,6 +204,11 @@ def convert_export_payload(api_payload):
             annotations[CATEGORY_TYPE].append(
                 CategoryAnnotation.from_json(category)
             )
+        for multicategory in row[MULTICATEGORY_TYPE]:
+            multicategory[REFERENCE_ID_KEY] = row[ITEM_KEY][REFERENCE_ID_KEY]
+            annotations[MULTICATEGORY_TYPE].append(
+                MultiCategoryAnnotation.from_json(multicategory)
+            )
         return_payload_row[ANNOTATIONS_KEY] = annotations
         return_payload.append(return_payload_row)
     return return_payload
@@ -132,6 +218,10 @@ def serialize_and_write(
     upload_units: Sequence[Union[DatasetItem, Annotation, LidarScene]],
     file_pointer,
 ):
+    if len(upload_units) == 0:
+        raise ValueError(
+            "Expecting at least one object when serializing objects to upload, but got zero. Please try again."
+        )
     for unit in upload_units:
         try:
             if isinstance(unit, (DatasetItem, Annotation, LidarScene)):
@@ -168,6 +258,7 @@ def serialize_and_write_to_presigned_url(
     dataset_id: str,
     client,
 ):
+    """This helper function can be used to serialize a list of API objects to NDJSON."""
     request_id = uuid.uuid4().hex
     response = client.make_request(
         payload={},
@@ -180,3 +271,9 @@
     strio.seek(0)
     upload_to_presigned_url(response["signed_url"], strio)
     return request_id
+
+
+def replace_double_slashes(s: str) -> str:
+    for key, val in STRING_REPLACEMENTS.items():
+        s = s.replace(key, val)
+    return s
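A minimal usage sketch for the helpers added above (illustrative only; it assumes the 0.6.4 wheel is installed, and the deprecated key name and message are made-up examples)::

    from nucleus.utils import KeyErrorDict, replace_double_slashes

    # KeyErrorDict behaves like a normal dict for keys that are actually set...
    info = KeyErrorDict(
        model_run_id="model runs are deprecated; query by 'model_id' instead"  # hypothetical deprecated key
    )
    info["model_id"] = "some_model_id"
    print(info["model_id"])

    # ...and raises a KeyError carrying the registered message for deprecated keys.
    try:
        info["model_run_id"]
    except KeyError as err:
        print(err)

    # replace_double_slashes maps doubly escaped sequences (e.g. a literal "\\n")
    # back to the real characters defined in STRING_REPLACEMENTS.
    print(replace_double_slashes("first line\\\\nsecond line"))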
nucleus/validate/__init__.py
ADDED
@@ -0,0 +1,24 @@
+"""Model CI Python Library."""
+
+__all__ = [
+    "Validate",
+    "ScenarioTest",
+    "EvaluationCriterion",
+]
+
+from .client import Validate
+from .constants import ThresholdComparison
+from .data_transfer_objects.eval_function import (
+    EvalFunctionEntry,
+    EvaluationCriterion,
+    GetEvalFunctions,
+)
+from .data_transfer_objects.scenario_test import CreateScenarioTestRequest
+from .errors import CreateScenarioTestError
+from .eval_functions.available_eval_functions import AvailableEvalFunctions
+from .scenario_test import ScenarioTest
+from .scenario_test_evaluation import (
+    ScenarioTestEvaluation,
+    ScenarioTestItemEvaluation,
+)
+from .scenario_test_metric import ScenarioTestMetric
nucleus/validate/client.py
ADDED
@@ -0,0 +1,168 @@
+from typing import List
+
+from nucleus.connection import Connection
+from nucleus.job import AsyncJob
+
+from .constants import SCENARIO_TEST_ID_KEY
+from .data_transfer_objects.eval_function import (
+    EvaluationCriterion,
+    GetEvalFunctions,
+)
+from .data_transfer_objects.scenario_test import CreateScenarioTestRequest
+from .errors import CreateScenarioTestError
+from .eval_functions.available_eval_functions import AvailableEvalFunctions
+from .scenario_test import ScenarioTest
+
+SUCCESS_KEY = "success"
+EVAL_FUNCTIONS_KEY = "eval_functions"
+
+
+class Validate:
+    """Model CI Python Client extension."""
+
+    def __init__(self, api_key: str, endpoint: str):
+        self.connection = Connection(api_key, endpoint)
+
+    def __repr__(self):
+        return f"Validate(connection='{self.connection}')"
+
+    def __eq__(self, other):
+        return self.connection == other.connection
+
+    @property
+    def eval_functions(self) -> AvailableEvalFunctions:
+        """List all available evaluation functions which can be used to set up evaluation criteria.::
+
+            import nucleus
+            client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
+
+            scenario_test_criterion = client.validate.eval_functions.bbox_iou() > 0.5  # Creates an EvaluationCriterion by comparison
+
+        Returns:
+            :class:`AvailableEvalFunctions`: A container for all the available eval functions
+        """
+        response = self.connection.get(
+            "validate/eval_fn",
+        )
+        payload = GetEvalFunctions.parse_obj(response)
+        return AvailableEvalFunctions(payload.eval_functions)
+
+    def create_scenario_test(
+        self,
+        name: str,
+        slice_id: str,
+        evaluation_criteria: List[EvaluationCriterion],
+    ) -> ScenarioTest:
+        """Creates a new Scenario Test from an existing Nucleus :class:`Slice`:. ::
+
+            import nucleus
+            client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
+
+            scenario_test = client.validate.create_scenario_test(
+                name="sample_scenario_test",
+                slice_id="YOUR_SLICE_ID",
+                evaluation_criteria=[client.validate.eval_functions.bbox_iou() > 0.5]
+            )
+
+        Args:
+            name: unique name of test
+            slice_id: id of (pre-defined) slice of items to evaluate test on.
+            evaluation_criteria: :class:`EvaluationCriterion` defines a pass/fail criteria for the test. Created with a
+                comparison with an eval functions. See :class:`eval_functions`.
+
+        Returns:
+            Created ScenarioTest object.
+        """
+        if not evaluation_criteria:
+            raise CreateScenarioTestError(
+                "Must pass an evaluation_criteria to the scenario test! I.e. "
+                "evaluation_criteria = [client.validate.eval_functions.bbox_iou() > 0.5]"
+            )
+        response = self.connection.post(
+            CreateScenarioTestRequest(
+                name=name,
+                slice_id=slice_id,
+                evaluation_criteria=evaluation_criteria,
+            ).dict(),
+            "validate/scenario_test",
+        )
+        return ScenarioTest(response[SCENARIO_TEST_ID_KEY], self.connection)
+
+    def get_scenario_test(self, scenario_test_id: str) -> ScenarioTest:
+        response = self.connection.get(
+            f"validate/scenario_test/{scenario_test_id}",
+        )
+        return ScenarioTest(response["id"], self.connection)
+
+    def list_scenario_tests(self) -> List[ScenarioTest]:
+        """Lists all Scenario Tests of the current user. ::
+
+            import nucleus
+            client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
+            scenario_test = client.validate.create_scenario_test(
+                "sample_scenario_test", "slc_bx86ea222a6g057x4380"
+            )
+
+            client.validate.list_scenario_tests()
+
+        Returns:
+            A list of ScenarioTest objects.
+        """
+        response = self.connection.get(
+            "validate/scenario_test",
+        )
+        return [
+            ScenarioTest(test_id, self.connection)
+            for test_id in response["scenario_test_ids"]
+        ]
+
+    def delete_scenario_test(self, scenario_test_id: str) -> bool:
+        """Deletes a Scenario Test. ::
+
+            import nucleus
+            client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
+            scenario_test = client.validate.list_scenario_tests()[0]
+
+            success = client.validate.delete_scenario_test(scenario_test.id)
+
+        Args:
+            scenario_test_id: unique ID of scenario test
+
+        Returns:
+            Whether deletion was successful.
+        """
+        response = self.connection.delete(
+            f"validate/scenario_test/{scenario_test_id}",
+        )
+        return response[SUCCESS_KEY]
+
+    def evaluate_model_on_scenario_tests(
+        self, model_id: str, scenario_test_names: List[str]
+    ) -> AsyncJob:
+        """Evaluates the given model on the specified Scenario Tests. ::
+
+            import nucleus
+            client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
+            model = client.list_models()[0]
+            scenario_test = client.validate.create_scenario_test(
+                "sample_scenario_test", "slc_bx86ea222a6g057x4380"
+            )
+
+            job = client.validate.evaluate_model_on_scenario_tests(
+                model_id=model.id,
+                scenario_test_names=["sample_scenario_test"],
+            )
+            job.sleep_until_complete()  # Not required. Will block and update on status of the job.
+
+        Args:
+            model_id: ID of model to evaluate
+            scenario_test_names: list of scenario test names of test to evaluate
+
+        Returns:
+            AsyncJob object of evaluation job
+        """
+        response = self.connection.post(
+            {"test_names": scenario_test_names},
+            f"validate/{model_id}/evaluate",
+        )
+        return AsyncJob.from_json(response, self.connection)
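A condensed, end-to-end sketch assembled from the docstrings above (the API key, slice ID, and scenario test name are placeholders, and it assumes at least one model has already been uploaded)::

    import nucleus

    client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")

    # 1. Build pass/fail criteria by comparing an eval function against a threshold.
    criteria = [client.validate.eval_functions.bbox_iou() > 0.5]

    # 2. Create a scenario test over an existing slice (slice IDs start with "slc_").
    scenario_test = client.validate.create_scenario_test(
        name="sample_scenario_test",
        slice_id="slc_YOUR_SLICE_ID",
        evaluation_criteria=criteria,
    )

    # 3. Evaluate a model against the test and optionally block on the async job.
    model = client.list_models()[0]
    job = client.validate.evaluate_model_on_scenario_tests(
        model_id=model.id,
        scenario_test_names=["sample_scenario_test"],
    )
    job.sleep_until_complete()  # optional; polls until the evaluation finishes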
nucleus/validate/constants.py
ADDED
@@ -0,0 +1,20 @@
+from enum import Enum
+
+EVALUATION_ID_KEY = "evaluation_id"
+EVAL_FUNCTION_ID_KEY = "eval_function_id"
+ID_KEY = "id"
+PASS_KEY = "pass"
+RESULT_KEY = "result"
+THRESHOLD_COMPARISON_KEY = "threshold_comparison"
+THRESHOLD_KEY = "threshold"
+SCENARIO_TEST_ID_KEY = "scenario_test_id"
+SCENARIO_TEST_NAME_KEY = "scenario_test_name"
+
+
+class ThresholdComparison(str, Enum):
+    """Comparator between the result and the threshold."""
+
+    GREATER_THAN = "greater_than"
+    GREATER_THAN_EQUAL_TO = "greater_than_equal_to"
+    LESS_THAN = "less_than"
+    LESS_THAN_EQUAL_TO = "less_than_equal_to"
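Because ThresholdComparison mixes in str, its members behave like their plain string values; a small illustration (assumes the 0.6.4 wheel is installed)::

    from nucleus.validate.constants import ThresholdComparison

    # str-mixin enum members compare equal to their underlying string values.
    assert ThresholdComparison.GREATER_THAN == "greater_than"
    assert ThresholdComparison("less_than") is ThresholdComparison.LESS_THAN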
File without changes
nucleus/validate/data_transfer_objects/eval_function.py
ADDED
@@ -0,0 +1,81 @@
+from typing import List, Optional
+
+from pydantic import validator
+
+from ...pydantic_base import ImmutableModel
+from ..constants import ThresholdComparison
+
+
+class EvaluationCriterion(ImmutableModel):
+    """
+    An Evaluation Criterion is defined as an evaluation function, threshold, and comparator.
+    It describes how to apply an evaluation function
+
+    Notes:
+        To define the evaluation criteria for a scenario test we've created some syntactic sugar to make it look closer to an
+        actual function call, and we also hide away implementation details related to our data model that simply are not clear,
+        UX-wise.
+
+        Instead of defining criteria like this::
+
+            from nucleus.validate.data_transfer_objects.eval_function import (
+                EvaluationCriterion,
+                ThresholdComparison,
+            )
+
+            criteria = [
+                EvaluationCriterion(
+                    eval_function_id="ef_c6m1khygqk400918ays0",  # bbox_recall
+                    threshold_comparison=ThresholdComparison.GREATER_THAN,
+                    threshold=0.5,
+                ),
+            ]
+
+        we define it like this::
+
+            bbox_recall = client.validate.eval_functions.bbox_recall
+            criteria = [
+                bbox_recall() > 0.5
+            ]
+
+        The chosen method allows us to document the available evaluation functions in an IDE friendly fashion and hides away
+        details like internal IDs (`"ef_...."`).
+
+        The actual `EvaluationCriterion` is created by overloading the comparison operators for the base class of an evaluation
+        function. Instead of the comparison returning a bool, we've made it create an `EvaluationCriterion` with the correct
+        signature to send over the wire to our API.
+
+
+    Parameters:
+        eval_function_id (str): ID of evaluation function
+        threshold_comparison (:class:`ThresholdComparison`): comparator for evaluation. i.e. threshold=0.5 and threshold_comparator > implies that a test only passes if score > 0.5.
+        threshold (float): numerical threshold that together with threshold comparison, defines success criteria for test evaluation.
+    """
+
+    # TODO: Having only eval_function_id hurts readability -> Add function name
+    eval_function_id: str
+    threshold_comparison: ThresholdComparison
+    threshold: float
+
+    @validator("eval_function_id")
+    def valid_eval_function_id(cls, v):  # pylint: disable=no-self-argument
+        if not v.startswith("ef_"):
+            raise ValueError(f"Expected field to start with 'ef_', got '{v}'")
+        return v
+
+
+class EvalFunctionEntry(ImmutableModel):
+    """Encapsulates information about an evaluation function for Model CI."""
+
+    id: str
+    name: str
+    is_public: bool
+    user_id: str
+    serialized_fn: Optional[str] = None
+    raw_source: Optional[str] = None
+
+
+class GetEvalFunctions(ImmutableModel):
+    """Expected format from GET validate/eval_fn"""
+
+    eval_functions: List[EvalFunctionEntry]
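The operator-overloading shorthand described in the Notes above creates the same kind of EvaluationCriterion that can be spelled out directly; a standalone sketch reusing the example ID quoted in the docstring::

    from nucleus.validate import EvaluationCriterion, ThresholdComparison

    criterion = EvaluationCriterion(
        eval_function_id="ef_c6m1khygqk400918ays0",  # bbox_recall (example ID from the docstring)
        threshold_comparison=ThresholdComparison.GREATER_THAN,
        threshold=0.5,
    )

    # The validator rejects IDs that do not start with "ef_", so a bare function
    # name such as "bbox_recall" would raise a pydantic ValidationError here.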
nucleus/validate/data_transfer_objects/scenario_test.py
ADDED
@@ -0,0 +1,19 @@
+from typing import List
+
+from pydantic import validator
+
+from nucleus.pydantic_base import ImmutableModel
+
+from .eval_function import EvaluationCriterion
+
+
+class CreateScenarioTestRequest(ImmutableModel):
+    name: str
+    slice_id: str
+    evaluation_criteria: List[EvaluationCriterion]
+
+    @validator("slice_id")
+    def startswith_slice_indicator(cls, v):  # pylint: disable=no-self-argument
+        if not v.startswith("slc_"):
+            raise ValueError(f"Expected field to start with 'slc_', got '{v}'")
+        return v
nucleus/validate/data_transfer_objects/scenario_test_metric.py
ADDED
@@ -0,0 +1,12 @@
+from nucleus.pydantic_base import ImmutableModel
+
+from ..constants import ThresholdComparison
+
+
+class AddScenarioTestMetric(ImmutableModel):
+    """Data transfer object to add a scenario test."""
+
+    scenario_test_name: str
+    eval_function_id: str
+    threshold: float
+    threshold_comparison: ThresholdComparison
File without changes