scale-nucleus 0.1.22__py3-none-any.whl → 0.6.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. cli/client.py +14 -0
  2. cli/datasets.py +77 -0
  3. cli/helpers/__init__.py +0 -0
  4. cli/helpers/nucleus_url.py +10 -0
  5. cli/helpers/web_helper.py +40 -0
  6. cli/install_completion.py +33 -0
  7. cli/jobs.py +42 -0
  8. cli/models.py +35 -0
  9. cli/nu.py +42 -0
  10. cli/reference.py +8 -0
  11. cli/slices.py +62 -0
  12. cli/tests.py +121 -0
  13. nucleus/__init__.py +453 -699
  14. nucleus/annotation.py +435 -80
  15. nucleus/autocurate.py +9 -0
  16. nucleus/connection.py +87 -0
  17. nucleus/constants.py +12 -2
  18. nucleus/data_transfer_object/__init__.py +0 -0
  19. nucleus/data_transfer_object/dataset_details.py +9 -0
  20. nucleus/data_transfer_object/dataset_info.py +26 -0
  21. nucleus/data_transfer_object/dataset_size.py +5 -0
  22. nucleus/data_transfer_object/scenes_list.py +18 -0
  23. nucleus/dataset.py +1139 -215
  24. nucleus/dataset_item.py +130 -26
  25. nucleus/dataset_item_uploader.py +297 -0
  26. nucleus/deprecation_warning.py +32 -0
  27. nucleus/errors.py +21 -1
  28. nucleus/job.py +71 -3
  29. nucleus/logger.py +9 -0
  30. nucleus/metadata_manager.py +45 -0
  31. nucleus/metrics/__init__.py +10 -0
  32. nucleus/metrics/base.py +117 -0
  33. nucleus/metrics/categorization_metrics.py +197 -0
  34. nucleus/metrics/errors.py +7 -0
  35. nucleus/metrics/filters.py +40 -0
  36. nucleus/metrics/geometry.py +198 -0
  37. nucleus/metrics/metric_utils.py +28 -0
  38. nucleus/metrics/polygon_metrics.py +480 -0
  39. nucleus/metrics/polygon_utils.py +299 -0
  40. nucleus/model.py +121 -15
  41. nucleus/model_run.py +34 -57
  42. nucleus/payload_constructor.py +30 -18
  43. nucleus/prediction.py +259 -17
  44. nucleus/pydantic_base.py +26 -0
  45. nucleus/retry_strategy.py +4 -0
  46. nucleus/scene.py +204 -19
  47. nucleus/slice.py +230 -67
  48. nucleus/upload_response.py +20 -9
  49. nucleus/url_utils.py +4 -0
  50. nucleus/utils.py +139 -35
  51. nucleus/validate/__init__.py +24 -0
  52. nucleus/validate/client.py +168 -0
  53. nucleus/validate/constants.py +20 -0
  54. nucleus/validate/data_transfer_objects/__init__.py +0 -0
  55. nucleus/validate/data_transfer_objects/eval_function.py +81 -0
  56. nucleus/validate/data_transfer_objects/scenario_test.py +19 -0
  57. nucleus/validate/data_transfer_objects/scenario_test_evaluations.py +11 -0
  58. nucleus/validate/data_transfer_objects/scenario_test_metric.py +12 -0
  59. nucleus/validate/errors.py +6 -0
  60. nucleus/validate/eval_functions/__init__.py +0 -0
  61. nucleus/validate/eval_functions/available_eval_functions.py +212 -0
  62. nucleus/validate/eval_functions/base_eval_function.py +60 -0
  63. nucleus/validate/scenario_test.py +143 -0
  64. nucleus/validate/scenario_test_evaluation.py +114 -0
  65. nucleus/validate/scenario_test_metric.py +14 -0
  66. nucleus/validate/utils.py +8 -0
  67. {scale_nucleus-0.1.22.dist-info → scale_nucleus-0.6.4.dist-info}/LICENSE +0 -0
  68. scale_nucleus-0.6.4.dist-info/METADATA +213 -0
  69. scale_nucleus-0.6.4.dist-info/RECORD +71 -0
  70. {scale_nucleus-0.1.22.dist-info → scale_nucleus-0.6.4.dist-info}/WHEEL +1 -1
  71. scale_nucleus-0.6.4.dist-info/entry_points.txt +3 -0
  72. scale_nucleus-0.1.22.dist-info/METADATA +0 -85
  73. scale_nucleus-0.1.22.dist-info/RECORD +0 -21
nucleus/utils.py CHANGED
@@ -1,10 +1,10 @@
 """Shared stateless utility function library"""

-from collections import defaultdict
 import io
-import uuid
 import json
-from typing import IO, Dict, List, Sequence, Union
+import uuid
+from collections import defaultdict
+from typing import IO, Dict, List, Sequence, Type, Union

 import requests
 from requests.models import HTTPError
@@ -12,7 +12,9 @@ from requests.models import HTTPError
 from nucleus.annotation import (
     Annotation,
     BoxAnnotation,
+    CategoryAnnotation,
     CuboidAnnotation,
+    MultiCategoryAnnotation,
     PolygonAnnotation,
     SegmentationAnnotation,
 )
@@ -21,54 +23,126 @@ from .constants import (
     ANNOTATION_TYPES,
     ANNOTATIONS_KEY,
     BOX_TYPE,
+    CATEGORY_TYPE,
     CUBOID_TYPE,
     ITEM_KEY,
+    MULTICATEGORY_TYPE,
     POLYGON_TYPE,
     REFERENCE_ID_KEY,
     SEGMENTATION_TYPE,
 )
 from .dataset_item import DatasetItem
-from .prediction import BoxPrediction, CuboidPrediction, PolygonPrediction
+from .prediction import (
+    BoxPrediction,
+    CategoryPrediction,
+    CuboidPrediction,
+    PolygonPrediction,
+    SegmentationPrediction,
+)
 from .scene import LidarScene

+STRING_REPLACEMENTS = {
+    "\\\\n": "\n",
+    "\\\\t": "\t",
+    '\\\\"': '"',
+}

-def _get_all_field_values(metadata_list: List[dict], key: str):
-    return {metadata[key] for metadata in metadata_list if key in metadata}

+class KeyErrorDict(dict):
+    """Wrapper for response dicts with deprecated keys.

-def suggest_metadata_schema(
-    data: Union[
-        List[DatasetItem],
-        List[BoxPrediction],
-        List[PolygonPrediction],
-        List[CuboidPrediction],
-    ]
-):
-    metadata_list: List[dict] = [
-        d.metadata for d in data if d.metadata is not None
-    ]
-    schema = {}
-    all_keys = {k for metadata in metadata_list for k in metadata.keys()}
-
-    all_key_values: Dict[str, set] = {
-        k: _get_all_field_values(metadata_list, k) for k in all_keys
-    }
+    Parameters:
+        **kwargs: Mapping from the deprecated key to a warning message.
+    """
+
+    def __init__(self, **kwargs):
+        self._deprecated = {}
+
+        for key, msg in kwargs.items():
+            if not isinstance(key, str):
+                raise TypeError(
+                    f"All keys must be strings! Received non-string '{key}'"
+                )
+            if not isinstance(msg, str):
+                raise TypeError(
+                    f"All warning messages must be strings! Received non-string '{msg}'"
+                )
+
+            self._deprecated[key] = msg
+
+        super().__init__()
+
+    def __missing__(self, key):
+        """Raises KeyError for deprecated keys, otherwise uses base dict logic."""
+        if key in self._deprecated:
+            raise KeyError(self._deprecated[key])
+        try:
+            super().__missing__(key)
+        except AttributeError as e:
+            raise KeyError(key) from e

-    for key, values in all_key_values.items():
-        entry: dict = {}
-        if all(isinstance(x, (float, int)) for x in values):
-            entry["type"] = "number"
-        elif len(values) <= 50:
-            entry["type"] = "category"
-            entry["choices"] = list(values)
-        else:
-            entry["type"] = "text"
-        schema[key] = entry
-    return schema
+
+def format_prediction_response(
+    response: dict,
+) -> Union[
+    dict,
+    List[
+        Union[
+            BoxPrediction,
+            PolygonPrediction,
+            CuboidPrediction,
+            CategoryPrediction,
+            SegmentationPrediction,
+        ]
+    ],
+]:
+    """Helper function to convert JSON response from endpoints to python objects
+
+    Args:
+        response: JSON dictionary response from REST endpoint.
+    Returns:
+        annotation_response: Dictionary containing a list of annotations for each type,
+        keyed by the type name.
+    """
+    annotation_payload = response.get(ANNOTATIONS_KEY, None)
+    if not annotation_payload:
+        # An error occurred
+        return response
+    annotation_response = {}
+    type_key_to_class: Dict[
+        str,
+        Union[
+            Type[BoxPrediction],
+            Type[PolygonPrediction],
+            Type[CuboidPrediction],
+            Type[CategoryPrediction],
+            Type[SegmentationPrediction],
+        ],
+    ] = {
+        BOX_TYPE: BoxPrediction,
+        POLYGON_TYPE: PolygonPrediction,
+        CUBOID_TYPE: CuboidPrediction,
+        CATEGORY_TYPE: CategoryPrediction,
+        SEGMENTATION_TYPE: SegmentationPrediction,
+    }
+    for type_key in annotation_payload:
+        type_class = type_key_to_class[type_key]
+        annotation_response[type_key] = [
+            type_class.from_json(annotation)
+            for annotation in annotation_payload[type_key]
+        ]
+    return annotation_response


 def format_dataset_item_response(response: dict) -> dict:
-    """Format the raw client response into api objects."""
+    """Format the raw client response into api objects.
+
+    Args:
+        response: JSON dictionary response from REST endpoint
+    Returns:
+        item_dict: A dictionary with two entries, one for the dataset item, and annother
+        for all of the associated annotations.
+    """
     if ANNOTATIONS_KEY not in response:
         raise ValueError(
             f"Server response was missing the annotation key: {response}"
@@ -94,6 +168,15 @@ def format_dataset_item_response(response: dict) -> dict:


 def convert_export_payload(api_payload):
+    """Helper function to convert raw JSON to API objects
+
+    Args:
+        api_payload: JSON dictionary response from REST endpoint
+    Returns:
+        return_payload: A list of dictionaries for each dataset item. Each dictionary
+        is in the same format as format_dataset_item_response: one key for the
+        dataset item, another for the annotations.
+    """
     return_payload = []
     for row in api_payload:
         return_payload_row = {}
@@ -116,6 +199,16 @@ def convert_export_payload(api_payload):
         for cuboid in row[CUBOID_TYPE]:
             cuboid[REFERENCE_ID_KEY] = row[ITEM_KEY][REFERENCE_ID_KEY]
             annotations[CUBOID_TYPE].append(CuboidAnnotation.from_json(cuboid))
+        for category in row[CATEGORY_TYPE]:
+            category[REFERENCE_ID_KEY] = row[ITEM_KEY][REFERENCE_ID_KEY]
+            annotations[CATEGORY_TYPE].append(
+                CategoryAnnotation.from_json(category)
+            )
+        for multicategory in row[MULTICATEGORY_TYPE]:
+            multicategory[REFERENCE_ID_KEY] = row[ITEM_KEY][REFERENCE_ID_KEY]
+            annotations[MULTICATEGORY_TYPE].append(
+                MultiCategoryAnnotation.from_json(multicategory)
+            )
         return_payload_row[ANNOTATIONS_KEY] = annotations
         return_payload.append(return_payload_row)
     return return_payload
@@ -125,6 +218,10 @@ def serialize_and_write(
     upload_units: Sequence[Union[DatasetItem, Annotation, LidarScene]],
     file_pointer,
 ):
+    if len(upload_units) == 0:
+        raise ValueError(
+            "Expecting at least one object when serializing objects to upload, but got zero. Please try again."
+        )
     for unit in upload_units:
         try:
             if isinstance(unit, (DatasetItem, Annotation, LidarScene)):
@@ -161,6 +258,7 @@ def serialize_and_write_to_presigned_url(
     dataset_id: str,
     client,
 ):
+    """This helper function can be used to serialize a list of API objects to NDJSON."""
    request_id = uuid.uuid4().hex
     response = client.make_request(
         payload={},
@@ -173,3 +271,9 @@
     strio.seek(0)
     upload_to_presigned_url(response["signed_url"], strio)
     return request_id
+
+
+def replace_double_slashes(s: str) -> str:
+    for key, val in STRING_REPLACEMENTS.items():
+        s = s.replace(key, val)
+    return s
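
As a quick orientation, here is a minimal sketch of how the two new helpers above behave; the deprecated key, message, and sample strings are made-up examples, not values used by the package:

    from nucleus.utils import KeyErrorDict, replace_double_slashes

    # KeyErrorDict behaves like a normal dict, except that looking up a key that was
    # registered as deprecated raises a KeyError carrying the supplied message.
    responses = KeyErrorDict(model_run_id="model_run_id is deprecated, use model_id instead")
    responses["model_id"] = "prj_123"   # hypothetical key/value, stored like any dict entry
    responses["model_id"]               # -> "prj_123"
    # responses["model_run_id"]         # -> KeyError("model_run_id is deprecated, use model_id instead")

    # replace_double_slashes applies the STRING_REPLACEMENTS table above, undoing the
    # double-escaped "\\\\n", "\\\\t", and '\\\\"' sequences in a string.
    cleaned = replace_double_slashes("col1\\\\tcol2")  # the escaped tab becomes a real tab character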
nucleus/validate/__init__.py ADDED
@@ -0,0 +1,24 @@
+"""Model CI Python Library."""
+
+__all__ = [
+    "Validate",
+    "ScenarioTest",
+    "EvaluationCriterion",
+]
+
+from .client import Validate
+from .constants import ThresholdComparison
+from .data_transfer_objects.eval_function import (
+    EvalFunctionEntry,
+    EvaluationCriterion,
+    GetEvalFunctions,
+)
+from .data_transfer_objects.scenario_test import CreateScenarioTestRequest
+from .errors import CreateScenarioTestError
+from .eval_functions.available_eval_functions import AvailableEvalFunctions
+from .scenario_test import ScenarioTest
+from .scenario_test_evaluation import (
+    ScenarioTestEvaluation,
+    ScenarioTestItemEvaluation,
+)
+from .scenario_test_metric import ScenarioTestMetric
nucleus/validate/client.py ADDED
@@ -0,0 +1,168 @@
+from typing import List
+
+from nucleus.connection import Connection
+from nucleus.job import AsyncJob
+
+from .constants import SCENARIO_TEST_ID_KEY
+from .data_transfer_objects.eval_function import (
+    EvaluationCriterion,
+    GetEvalFunctions,
+)
+from .data_transfer_objects.scenario_test import CreateScenarioTestRequest
+from .errors import CreateScenarioTestError
+from .eval_functions.available_eval_functions import AvailableEvalFunctions
+from .scenario_test import ScenarioTest
+
+SUCCESS_KEY = "success"
+EVAL_FUNCTIONS_KEY = "eval_functions"
+
+
+class Validate:
+    """Model CI Python Client extension."""
+
+    def __init__(self, api_key: str, endpoint: str):
+        self.connection = Connection(api_key, endpoint)
+
+    def __repr__(self):
+        return f"Validate(connection='{self.connection}')"
+
+    def __eq__(self, other):
+        return self.connection == other.connection
+
+    @property
+    def eval_functions(self) -> AvailableEvalFunctions:
+        """List all available evaluation functions which can be used to set up evaluation criteria.::
+
+            import nucleus
+            client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
+
+            scenario_test_criterion = client.validate.eval_functions.bbox_iou() > 0.5  # Creates an EvaluationCriterion by comparison
+
+        Returns:
+            :class:`AvailableEvalFunctions`: A container for all the available eval functions
+        """
+        response = self.connection.get(
+            "validate/eval_fn",
+        )
+        payload = GetEvalFunctions.parse_obj(response)
+        return AvailableEvalFunctions(payload.eval_functions)
+
+    def create_scenario_test(
+        self,
+        name: str,
+        slice_id: str,
+        evaluation_criteria: List[EvaluationCriterion],
+    ) -> ScenarioTest:
+        """Creates a new Scenario Test from an existing Nucleus :class:`Slice`:. ::
+
+            import nucleus
+            client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
+
+            scenario_test = client.validate.create_scenario_test(
+                name="sample_scenario_test",
+                slice_id="YOUR_SLICE_ID",
+                evaluation_criteria=[client.validate.eval_functions.bbox_iou() > 0.5]
+            )
+
+        Args:
+            name: unique name of test
+            slice_id: id of (pre-defined) slice of items to evaluate test on.
+            evaluation_criteria: :class:`EvaluationCriterion` defines a pass/fail criteria for the test. Created with a
+                comparison with an eval functions. See :class:`eval_functions`.
+
+        Returns:
+            Created ScenarioTest object.
+        """
+        if not evaluation_criteria:
+            raise CreateScenarioTestError(
+                "Must pass an evaluation_criteria to the scenario test! I.e. "
+                "evaluation_criteria = [client.validate.eval_functions.bbox_iou() > 0.5]"
+            )
+        response = self.connection.post(
+            CreateScenarioTestRequest(
+                name=name,
+                slice_id=slice_id,
+                evaluation_criteria=evaluation_criteria,
+            ).dict(),
+            "validate/scenario_test",
+        )
+        return ScenarioTest(response[SCENARIO_TEST_ID_KEY], self.connection)
+
+    def get_scenario_test(self, scenario_test_id: str) -> ScenarioTest:
+        response = self.connection.get(
+            f"validate/scenario_test/{scenario_test_id}",
+        )
+        return ScenarioTest(response["id"], self.connection)
+
+    def list_scenario_tests(self) -> List[ScenarioTest]:
+        """Lists all Scenario Tests of the current user. ::
+
+            import nucleus
+            client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
+            scenario_test = client.validate.create_scenario_test(
+                "sample_scenario_test", "slc_bx86ea222a6g057x4380"
+            )
+
+            client.validate.list_scenario_tests()
+
+        Returns:
+            A list of ScenarioTest objects.
+        """
+        response = self.connection.get(
+            "validate/scenario_test",
+        )
+        return [
+            ScenarioTest(test_id, self.connection)
+            for test_id in response["scenario_test_ids"]
+        ]
+
+    def delete_scenario_test(self, scenario_test_id: str) -> bool:
+        """Deletes a Scenario Test. ::
+
+            import nucleus
+            client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
+            scenario_test = client.validate.list_scenario_tests()[0]
+
+            success = client.validate.delete_scenario_test(scenario_test.id)
+
+        Args:
+            scenario_test_id: unique ID of scenario test
+
+        Returns:
+            Whether deletion was successful.
+        """
+        response = self.connection.delete(
+            f"validate/scenario_test/{scenario_test_id}",
+        )
+        return response[SUCCESS_KEY]
+
+    def evaluate_model_on_scenario_tests(
+        self, model_id: str, scenario_test_names: List[str]
+    ) -> AsyncJob:
+        """Evaluates the given model on the specified Scenario Tests. ::
+
+            import nucleus
+            client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
+            model = client.list_models()[0]
+            scenario_test = client.validate.create_scenario_test(
+                "sample_scenario_test", "slc_bx86ea222a6g057x4380"
+            )
+
+            job = client.validate.evaluate_model_on_scenario_tests(
+                model_id=model.id,
+                scenario_test_names=["sample_scenario_test"],
+            )
+            job.sleep_until_complete()  # Not required. Will block and update on status of the job.
+
+        Args:
+            model_id: ID of model to evaluate
+            scenario_test_names: list of scenario test names of test to evaluate
+
+        Returns:
+            AsyncJob object of evaluation job
+        """
+        response = self.connection.post(
+            {"test_names": scenario_test_names},
+            f"validate/{model_id}/evaluate",
+        )
+        return AsyncJob.from_json(response, self.connection)
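
Putting the docstrings above together, a typical end-to-end flow with the new Validate client looks roughly like the following sketch; the slice ID is a placeholder (it must carry the `slc_` prefix), and the model is assumed to already exist in the account:

    import nucleus

    client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")

    # Turn an evaluation function into a pass/fail criterion via an overloaded comparison.
    criterion = client.validate.eval_functions.bbox_iou() > 0.5

    # Create a scenario test over an existing slice, then kick off an async evaluation.
    scenario_test = client.validate.create_scenario_test(
        name="sample_scenario_test",
        slice_id="slc_YOUR_SLICE_ID",
        evaluation_criteria=[criterion],
    )
    model = client.list_models()[0]
    job = client.validate.evaluate_model_on_scenario_tests(
        model_id=model.id,
        scenario_test_names=["sample_scenario_test"],
    )
    job.sleep_until_complete()  # optional; blocks until the evaluation job finishes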
nucleus/validate/constants.py ADDED
@@ -0,0 +1,20 @@
+from enum import Enum
+
+EVALUATION_ID_KEY = "evaluation_id"
+EVAL_FUNCTION_ID_KEY = "eval_function_id"
+ID_KEY = "id"
+PASS_KEY = "pass"
+RESULT_KEY = "result"
+THRESHOLD_COMPARISON_KEY = "threshold_comparison"
+THRESHOLD_KEY = "threshold"
+SCENARIO_TEST_ID_KEY = "scenario_test_id"
+SCENARIO_TEST_NAME_KEY = "scenario_test_name"
+
+
+class ThresholdComparison(str, Enum):
+    """Comparator between the result and the threshold."""
+
+    GREATER_THAN = "greater_than"
+    GREATER_THAN_EQUAL_TO = "greater_than_equal_to"
+    LESS_THAN = "less_than"
+    LESS_THAN_EQUAL_TO = "less_than_equal_to"
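
Since ThresholdComparison mixes `str` into the `Enum`, its members behave as plain strings when compared or serialized, which lets them pass through JSON payloads as their string values; a small illustrative check:

    from nucleus.validate.constants import ThresholdComparison

    comparator = ThresholdComparison.GREATER_THAN
    assert comparator == "greater_than"        # str mix-in: members compare equal to their values
    assert comparator.value == "greater_than"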
nucleus/validate/data_transfer_objects/__init__.py ADDED
File without changes
nucleus/validate/data_transfer_objects/eval_function.py ADDED
@@ -0,0 +1,81 @@
+from typing import List, Optional
+
+from pydantic import validator
+
+from ...pydantic_base import ImmutableModel
+from ..constants import ThresholdComparison
+
+
+class EvaluationCriterion(ImmutableModel):
+    """
+    An Evaluation Criterion is defined as an evaluation function, threshold, and comparator.
+    It describes how to apply an evaluation function
+
+    Notes:
+        To define the evaluation criteria for a scenario test we've created some syntactic sugar to make it look closer to an
+        actual function call, and we also hide away implementation details related to our data model that simply are not clear,
+        UX-wise.
+
+        Instead of defining criteria like this::
+
+            from nucleus.validate.data_transfer_objects.eval_function import (
+                EvaluationCriterion,
+                ThresholdComparison,
+            )
+
+            criteria = [
+                EvaluationCriterion(
+                    eval_function_id="ef_c6m1khygqk400918ays0",  # bbox_recall
+                    threshold_comparison=ThresholdComparison.GREATER_THAN,
+                    threshold=0.5,
+                ),
+            ]
+
+        we define it like this::
+
+            bbox_recall = client.validate.eval_functions.bbox_recall
+            criteria = [
+                bbox_recall() > 0.5
+            ]
+
+        The chosen method allows us to document the available evaluation functions in an IDE friendly fashion and hides away
+        details like internal IDs (`"ef_...."`).
+
+        The actual `EvaluationCriterion` is created by overloading the comparison operators for the base class of an evaluation
+        function. Instead of the comparison returning a bool, we've made it create an `EvaluationCriterion` with the correct
+        signature to send over the wire to our API.
+
+
+    Parameters:
+        eval_function_id (str): ID of evaluation function
+        threshold_comparison (:class:`ThresholdComparison`): comparator for evaluation. i.e. threshold=0.5 and threshold_comparator > implies that a test only passes if score > 0.5.
+        threshold (float): numerical threshold that together with threshold comparison, defines success criteria for test evaluation.
+    """
+
+    # TODO: Having only eval_function_id hurts readability -> Add function name
+    eval_function_id: str
+    threshold_comparison: ThresholdComparison
+    threshold: float
+
+    @validator("eval_function_id")
+    def valid_eval_function_id(cls, v):  # pylint: disable=no-self-argument
+        if not v.startswith("ef_"):
+            raise ValueError(f"Expected field to start with 'ef_', got '{v}'")
+        return v
+
+
+class EvalFunctionEntry(ImmutableModel):
+    """Encapsulates information about an evaluation function for Model CI."""
+
+    id: str
+    name: str
+    is_public: bool
+    user_id: str
+    serialized_fn: Optional[str] = None
+    raw_source: Optional[str] = None
+
+
+class GetEvalFunctions(ImmutableModel):
+    """Expected format from GET validate/eval_fn"""
+
+    eval_functions: List[EvalFunctionEntry]
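
For comparison with the operator-overloading sugar described in the docstring above, the explicit construction looks like this; the eval function ID is the placeholder from the docstring and must start with `ef_`, otherwise the validator rejects it:

    from nucleus.validate.data_transfer_objects.eval_function import (
        EvaluationCriterion,
        ThresholdComparison,
    )

    criterion = EvaluationCriterion(
        eval_function_id="ef_c6m1khygqk400918ays0",  # placeholder ID taken from the docstring
        threshold_comparison=ThresholdComparison.GREATER_THAN,
        threshold=0.5,
    )
    # An ID without the "ef_" prefix fails valid_eval_function_id and surfaces as a
    # pydantic ValidationError when the model is constructed.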
nucleus/validate/data_transfer_objects/scenario_test.py ADDED
@@ -0,0 +1,19 @@
+from typing import List
+
+from pydantic import validator
+
+from nucleus.pydantic_base import ImmutableModel
+
+from .eval_function import EvaluationCriterion
+
+
+class CreateScenarioTestRequest(ImmutableModel):
+    name: str
+    slice_id: str
+    evaluation_criteria: List[EvaluationCriterion]
+
+    @validator("slice_id")
+    def startswith_slice_indicator(cls, v):  # pylint: disable=no-self-argument
+        if not v.startswith("slc_"):
+            raise ValueError(f"Expected field to start with 'slc_', got '{v}'")
+        return v
nucleus/validate/data_transfer_objects/scenario_test_evaluations.py ADDED
@@ -0,0 +1,11 @@
+from typing import List
+
+from nucleus.pydantic_base import ImmutableModel
+
+
+class EvalDetail(ImmutableModel):
+    id: str
+
+
+class GetEvalHistory(ImmutableModel):
+    evaluations: List[EvalDetail]
nucleus/validate/data_transfer_objects/scenario_test_metric.py ADDED
@@ -0,0 +1,12 @@
+from nucleus.pydantic_base import ImmutableModel
+
+from ..constants import ThresholdComparison
+
+
+class AddScenarioTestMetric(ImmutableModel):
+    """Data transfer object to add a scenario test."""
+
+    scenario_test_name: str
+    eval_function_id: str
+    threshold: float
+    threshold_comparison: ThresholdComparison
nucleus/validate/errors.py ADDED
@@ -0,0 +1,6 @@
+class CreateScenarioTestError(Exception):
+    pass
+
+
+class EvalFunctionNotAvailableError(Exception):
+    pass
nucleus/validate/eval_functions/__init__.py ADDED
File without changes