scale-nucleus 0.1.22__py3-none-any.whl → 0.6.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. cli/client.py +14 -0
  2. cli/datasets.py +77 -0
  3. cli/helpers/__init__.py +0 -0
  4. cli/helpers/nucleus_url.py +10 -0
  5. cli/helpers/web_helper.py +40 -0
  6. cli/install_completion.py +33 -0
  7. cli/jobs.py +42 -0
  8. cli/models.py +35 -0
  9. cli/nu.py +42 -0
  10. cli/reference.py +8 -0
  11. cli/slices.py +62 -0
  12. cli/tests.py +121 -0
  13. nucleus/__init__.py +453 -699
  14. nucleus/annotation.py +435 -80
  15. nucleus/autocurate.py +9 -0
  16. nucleus/connection.py +87 -0
  17. nucleus/constants.py +12 -2
  18. nucleus/data_transfer_object/__init__.py +0 -0
  19. nucleus/data_transfer_object/dataset_details.py +9 -0
  20. nucleus/data_transfer_object/dataset_info.py +26 -0
  21. nucleus/data_transfer_object/dataset_size.py +5 -0
  22. nucleus/data_transfer_object/scenes_list.py +18 -0
  23. nucleus/dataset.py +1139 -215
  24. nucleus/dataset_item.py +130 -26
  25. nucleus/dataset_item_uploader.py +297 -0
  26. nucleus/deprecation_warning.py +32 -0
  27. nucleus/errors.py +21 -1
  28. nucleus/job.py +71 -3
  29. nucleus/logger.py +9 -0
  30. nucleus/metadata_manager.py +45 -0
  31. nucleus/metrics/__init__.py +10 -0
  32. nucleus/metrics/base.py +117 -0
  33. nucleus/metrics/categorization_metrics.py +197 -0
  34. nucleus/metrics/errors.py +7 -0
  35. nucleus/metrics/filters.py +40 -0
  36. nucleus/metrics/geometry.py +198 -0
  37. nucleus/metrics/metric_utils.py +28 -0
  38. nucleus/metrics/polygon_metrics.py +480 -0
  39. nucleus/metrics/polygon_utils.py +299 -0
  40. nucleus/model.py +121 -15
  41. nucleus/model_run.py +34 -57
  42. nucleus/payload_constructor.py +30 -18
  43. nucleus/prediction.py +259 -17
  44. nucleus/pydantic_base.py +26 -0
  45. nucleus/retry_strategy.py +4 -0
  46. nucleus/scene.py +204 -19
  47. nucleus/slice.py +230 -67
  48. nucleus/upload_response.py +20 -9
  49. nucleus/url_utils.py +4 -0
  50. nucleus/utils.py +139 -35
  51. nucleus/validate/__init__.py +24 -0
  52. nucleus/validate/client.py +168 -0
  53. nucleus/validate/constants.py +20 -0
  54. nucleus/validate/data_transfer_objects/__init__.py +0 -0
  55. nucleus/validate/data_transfer_objects/eval_function.py +81 -0
  56. nucleus/validate/data_transfer_objects/scenario_test.py +19 -0
  57. nucleus/validate/data_transfer_objects/scenario_test_evaluations.py +11 -0
  58. nucleus/validate/data_transfer_objects/scenario_test_metric.py +12 -0
  59. nucleus/validate/errors.py +6 -0
  60. nucleus/validate/eval_functions/__init__.py +0 -0
  61. nucleus/validate/eval_functions/available_eval_functions.py +212 -0
  62. nucleus/validate/eval_functions/base_eval_function.py +60 -0
  63. nucleus/validate/scenario_test.py +143 -0
  64. nucleus/validate/scenario_test_evaluation.py +114 -0
  65. nucleus/validate/scenario_test_metric.py +14 -0
  66. nucleus/validate/utils.py +8 -0
  67. {scale_nucleus-0.1.22.dist-info → scale_nucleus-0.6.4.dist-info}/LICENSE +0 -0
  68. scale_nucleus-0.6.4.dist-info/METADATA +213 -0
  69. scale_nucleus-0.6.4.dist-info/RECORD +71 -0
  70. {scale_nucleus-0.1.22.dist-info → scale_nucleus-0.6.4.dist-info}/WHEEL +1 -1
  71. scale_nucleus-0.6.4.dist-info/entry_points.txt +3 -0
  72. scale_nucleus-0.1.22.dist-info/METADATA +0 -85
  73. scale_nucleus-0.1.22.dist-info/RECORD +0 -21
nucleus/validate/eval_functions/available_eval_functions.py
@@ -0,0 +1,212 @@
+ import itertools
+ from typing import Callable, Dict, List, Type, Union
+
+ from nucleus.logger import logger
+ from nucleus.validate.eval_functions.base_eval_function import BaseEvalFunction
+
+ from ..data_transfer_objects.eval_function import EvalFunctionEntry
+ from ..errors import EvalFunctionNotAvailableError
+
+ MEAN_AVG_PRECISION_NAME = "mean_average_precision_boxes"
+
+
+ class BoundingBoxIOU(BaseEvalFunction):
+     @classmethod
+     def expected_name(cls) -> str:
+         return "bbox_iou"
+
+
+ class BoundingBoxMeanAveragePrecision(BaseEvalFunction):
+     @classmethod
+     def expected_name(cls) -> str:
+         return "bbox_map"
+
+
+ class BoundingBoxRecall(BaseEvalFunction):
+     @classmethod
+     def expected_name(cls) -> str:
+         return "bbox_recall"
+
+
+ class BoundingBoxPrecision(BaseEvalFunction):
+     @classmethod
+     def expected_name(cls) -> str:
+         return "bbox_precision"
+
+
+ class CategorizationF1(BaseEvalFunction):
+     @classmethod
+     def expected_name(cls) -> str:
+         return "cat_f1"
+
+
+ class CustomEvalFunction(BaseEvalFunction):
+     @classmethod
+     def expected_name(cls) -> str:
+         raise NotImplementedError(
+             "Custom evaluation functions are coming soon"
+         )  # Placeholder: See super().eval_func_entry for actual name
+
+
+ class StandardEvalFunction(BaseEvalFunction):
+     """Class for standard Model CI eval functions that have not been added as attributes on
+     AvailableEvalFunctions yet.
+     """
+
+     def __init__(self, eval_function_entry: EvalFunctionEntry):
+         logger.warning(
+             "Standard function %s not implemented as an attribute on AvailableEvalFunctions",
+             eval_function_entry.name,
+         )
+         super().__init__(eval_function_entry)
+
+     @classmethod
+     def expected_name(cls) -> str:
+         return "public_function"  # Placeholder: See super().eval_func_entry for actual name
+
+
+ class EvalFunctionNotAvailable(BaseEvalFunction):
+     def __init__(
+         self, not_available_name: str
+     ):  # pylint: disable=super-init-not-called
+         self.not_available_name = not_available_name
+
+     def __call__(self, *args, **kwargs):
+         self._raise_error()
+
+     def _op_to_test_metric(self, *args, **kwargs):
+         self._raise_error()
+
+     def _raise_error(self):
+         raise EvalFunctionNotAvailableError(
+             f"Eval function '{self.not_available_name}' is not available to the current user. "
+             f"Is Model CI enabled for the user?"
+         )
+
+     @classmethod
+     def expected_name(cls) -> str:
+         return "public_function"  # Placeholder: See super().eval_func_entry for actual name
+
+
+ EvalFunction = Union[
+     Type[BoundingBoxIOU],
+     Type[BoundingBoxMeanAveragePrecision],
+     Type[BoundingBoxPrecision],
+     Type[BoundingBoxRecall],
+     Type[CustomEvalFunction],
+     Type[EvalFunctionNotAvailable],
+     Type[StandardEvalFunction],
+ ]
+
+
+ class AvailableEvalFunctions:
+     """Collection class that acts as a common entrypoint to access evaluation functions. Standard evaluation functions
+     provided by Scale are attributes of this class.
+
+     The available evaluation functions are listed in the sample below::
+
+         e = client.validate.eval_functions
+         unit_test_criteria = [
+             e.bbox_iou() > 0.5,
+             e.bbox_map() > 0.95,
+             e.bbox_precision() > 0.8,
+             e.bbox_recall() > 0.5,
+         ]
+     """
+
+     # pylint: disable=too-many-instance-attributes
+
+     def __init__(self, available_functions: List[EvalFunctionEntry]):
+         assert (
+             available_functions
+         ), "Passed no available functions for current user. Is the feature flag enabled?"
+         self._public_func_entries: Dict[str, EvalFunctionEntry] = {
+             f.name: f for f in available_functions if f.is_public
+         }
+         # NOTE: Public functions are assigned as attributes below
+         self._public_to_function: Dict[str, BaseEvalFunction] = {}
+         self._custom_to_function: Dict[str, CustomEvalFunction] = {
+             f.name: CustomEvalFunction(f)
+             for f in available_functions
+             if not f.is_public
+         }
+         self.bbox_iou = self._assign_eval_function_if_defined(BoundingBoxIOU)  # type: ignore
+         self.bbox_precision = self._assign_eval_function_if_defined(
+             BoundingBoxPrecision  # type: ignore
+         )
+         self.bbox_recall = self._assign_eval_function_if_defined(
+             BoundingBoxRecall  # type: ignore
+         )
+         self.bbox_map = self._assign_eval_function_if_defined(
+             BoundingBoxMeanAveragePrecision  # type: ignore
+         )
+         self.cat_f1 = self._assign_eval_function_if_defined(
+             CategorizationF1  # type: ignore
+         )
+
+         # Add public entries that have not been implemented as an attribute on this class
+         for func_entry in self._public_func_entries.values():
+             if func_entry.name not in self._public_to_function:
+                 self._public_to_function[
+                     func_entry.name
+                 ] = StandardEvalFunction(func_entry)
+
+     def __repr__(self):
+         """Standard functions are provided by Scale; custom functions are customer-defined."""
+         # NOTE: setting to lower to be consistent with attribute names
+         functions_lower = [
+             str(name).lower() for name in self._public_func_entries.keys()
+         ]
+         return (
+             f"<AvailableEvaluationFunctions: public: {functions_lower} "
+             f"private: {list(self._custom_to_function.keys())}>"
+         )
+
+     @property
+     def public_functions(self) -> Dict[str, BaseEvalFunction]:
+         """Standard functions provided by Model CI.
+
+         Notes:
+             These functions are also available as attributes on :class:`AvailableEvalFunctions`
+
+         Returns:
+             Dict of function name to :class:`BaseEvalFunction`.
+         """
+         return self._public_to_function
+
+     @property
+     def private_functions(self) -> Dict[str, CustomEvalFunction]:
+         """Custom functions uploaded to Model CI.
+
+         Returns:
+             Dict of function name to :class:`CustomEvalFunction`.
+         """
+         return self._custom_to_function
+
+     def _assign_eval_function_if_defined(
+         self,
+         eval_function_constructor: Callable[[EvalFunctionEntry], EvalFunction],
+     ):
+         """Helper function for book-keeping and assignment of standard Scale-provided functions that are
+         accessible via attribute access.
+         """
+         # TODO(gunnar): Too convoluted .. simplify
+         expected_name = eval_function_constructor.expected_name()  # type: ignore
+         if expected_name in self._public_func_entries:
+             definition = self._public_func_entries[expected_name]
+             eval_function = eval_function_constructor(definition)
+             self._public_to_function[expected_name] = eval_function  # type: ignore
+             return eval_function
+         else:
+             return EvalFunctionNotAvailable(expected_name)
+
+     def from_id(self, eval_function_id: str):
+         for eval_func in itertools.chain(
+             self._public_to_function.values(),
+             self._custom_to_function.values(),
+         ):
+             if eval_func.id == eval_function_id:
+                 return eval_func
+         raise EvalFunctionNotAvailableError(
+             f"Could not find Eval Function with id {eval_function_id}"
+         )
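The collection above is reached through the Validate client, as its own docstring shows. A minimal sketch of how attribute access and the not-available fallback behave, assuming a valid Scale API key and that Model CI is enabled for the account ("YOUR_SCALE_API_KEY" and the thresholds are placeholders):

import nucleus

client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")  # placeholder key
e = client.validate.eval_functions  # an AvailableEvalFunctions instance

# Standard functions are attributes; calling one returns the function itself,
# and comparing it against a threshold produces an EvaluationCriterion.
criteria = [
    e.bbox_iou() > 0.5,
    e.bbox_recall() > 0.6,
]

# Functions not enabled for the account come back as EvalFunctionNotAvailable,
# which raises EvalFunctionNotAvailableError the moment it is called.
print(e.public_functions)  # dict of name -> BaseEvalFunction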
nucleus/validate/eval_functions/base_eval_function.py
@@ -0,0 +1,60 @@
+ import abc
+
+ from ..constants import ThresholdComparison
+ from ..data_transfer_objects.eval_function import (
+     EvalFunctionEntry,
+     EvaluationCriterion,
+ )
+
+
+ class BaseEvalFunction(abc.ABC):
+     """Abstract base class for concrete implementations of EvalFunctions.
+
+     Operating on this class with comparison operators produces an EvaluationCriterion.
+     """
+
+     def __init__(self, eval_func_entry: EvalFunctionEntry):
+         self.eval_func_entry = eval_func_entry
+         self.id = eval_func_entry.id
+         self.name = eval_func_entry.name
+
+     def __repr__(self):
+         return f"<EvalFunction: name={self.name}, id={self.id}>"
+
+     @classmethod
+     @abc.abstractmethod
+     def expected_name(cls) -> str:
+         """Name to look for in the EvalFunctionDefinitions."""
+
+     def __call__(self) -> "BaseEvalFunction":
+         """Adds __call__ in preparation for being able to pass parameters to the function.
+
+         Notes:
+             Technically you could already write eval_function > 0.5, but we want it
+             to read as eval_function() > 0.5 so that eval_function(parameters) > 0.5
+             can be supported in the future.
+         """
+         return self
+
+     def __gt__(self, other) -> EvaluationCriterion:
+         return self._op_to_test_metric(ThresholdComparison.GREATER_THAN, other)
+
+     def __ge__(self, other) -> EvaluationCriterion:
+         return self._op_to_test_metric(
+             ThresholdComparison.GREATER_THAN_EQUAL_TO, other
+         )
+
+     def __lt__(self, other) -> EvaluationCriterion:
+         return self._op_to_test_metric(ThresholdComparison.LESS_THAN, other)
+
+     def __le__(self, other) -> EvaluationCriterion:
+         return self._op_to_test_metric(
+             ThresholdComparison.LESS_THAN_EQUAL_TO, other
+         )
+
+     def _op_to_test_metric(self, comparison: ThresholdComparison, value):
+         return EvaluationCriterion(
+             eval_function_id=self.eval_func_entry.id,
+             threshold_comparison=comparison,
+             threshold=value,
+         )
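Each comparison operator simply packages the function ID, the comparator, and the threshold into an EvaluationCriterion DTO via _op_to_test_metric. A sketch of the resulting object, assuming a valid API key (field names follow _op_to_test_metric above):

import nucleus

client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")  # placeholder key
bbox_recall = client.validate.eval_functions.bbox_recall

# __ge__ maps to ThresholdComparison.GREATER_THAN_EQUAL_TO
criterion = bbox_recall() >= 0.6
print(criterion.eval_function_id)      # ID of the bbox_recall entry
print(criterion.threshold_comparison)  # ThresholdComparison.GREATER_THAN_EQUAL_TO
print(criterion.threshold)             # 0.6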
nucleus/validate/scenario_test.py
@@ -0,0 +1,143 @@
+ """Scenario Tests combine collections of data and evaluation metrics to accelerate model evaluation.
+
+ With Model CI Scenario Tests, an ML engineer can define a Scenario Test from critical
+ edge case scenarios that the model must get right (e.g. pedestrians at night),
+ and have confidence that they’re always shipping the best model.
+ """
+ from dataclasses import dataclass, field
+ from typing import List
+
+ from ..connection import Connection
+ from ..constants import NAME_KEY, SLICE_ID_KEY
+ from ..dataset_item import DatasetItem
+ from .data_transfer_objects.eval_function import EvaluationCriterion
+ from .data_transfer_objects.scenario_test_evaluations import GetEvalHistory
+ from .data_transfer_objects.scenario_test_metric import AddScenarioTestMetric
+ from .scenario_test_evaluation import ScenarioTestEvaluation
+ from .scenario_test_metric import ScenarioTestMetric
+
+ DATASET_ITEMS_KEY = "dataset_items"
+
+
+ @dataclass
+ class ScenarioTest:
+     """A Scenario Test combines a slice and at least one evaluation criterion. A :class:`ScenarioTest` is not
+     created through the default constructor but by following the instructions shown in :class:`Validate`. This
+     :class:`ScenarioTest` class simply simplifies interaction with scenario tests from this SDK.
+
+     Attributes:
+         id (str): The ID of the scenario test.
+         connection (Connection): The connection to the Nucleus API.
+         name (str): The name of the scenario test.
+         slice_id (str): The ID of the associated Nucleus slice.
+     """
+
+     id: str
+     connection: Connection = field(repr=False)
+     name: str = field(init=False)
+     slice_id: str = field(init=False)
+
+     def __post_init__(self):
+         # TODO(gunnar): Remove this pattern. It's too slow. We should get all the info required in one call
+         response = self.connection.get(
+             f"validate/scenario_test/{self.id}/info",
+         )
+         self.name = response[NAME_KEY]
+         self.slice_id = response[SLICE_ID_KEY]
+
+     def add_criterion(
+         self, evaluation_criterion: EvaluationCriterion
+     ) -> ScenarioTestMetric:
+         """Creates and adds a new criterion to the :class:`ScenarioTest`. ::
+
+             import nucleus
+             client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
+             scenario_test = client.validate.create_scenario_test(
+                 "sample_scenario_test", "slc_bx86ea222a6g057x4380"
+             )
+
+             e = client.validate.eval_functions
+             # Assuming a user would like to add all available public evaluation functions as criteria
+             scenario_test.add_criterion(
+                 e.bbox_iou() > 0.5
+             )
+             scenario_test.add_criterion(
+                 e.bbox_map() > 0.85
+             )
+             scenario_test.add_criterion(
+                 e.bbox_precision() > 0.7
+             )
+             scenario_test.add_criterion(
+                 e.bbox_recall() > 0.6
+             )
+
+         Args:
+             evaluation_criterion: :class:`EvaluationCriterion` created by comparison with an :class:`EvalFunction`
+
+         Returns:
+             The created ScenarioTestMetric object.
+         """
+         response = self.connection.post(
+             AddScenarioTestMetric(
+                 scenario_test_name=self.name,
+                 eval_function_id=evaluation_criterion.eval_function_id,
+                 threshold=evaluation_criterion.threshold,
+                 threshold_comparison=evaluation_criterion.threshold_comparison,
+             ).dict(),
+             "validate/scenario_test_metric",
+         )
+         return ScenarioTestMetric(
+             scenario_test_id=response["scenario_test_id"],
+             eval_function_id=response["eval_function_id"],
+             threshold=evaluation_criterion.threshold,
+             threshold_comparison=evaluation_criterion.threshold_comparison,
+         )
+
+     def get_criteria(self) -> List[ScenarioTestMetric]:
+         """Retrieves all criteria of the :class:`ScenarioTest`. ::
+
+             import nucleus
+             client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
+             scenario_test = client.validate.list_scenario_tests()[0]
+
+             scenario_test.get_criteria()
+
+         Returns:
+             A list of ScenarioTestMetric objects.
+         """
+         response = self.connection.get(
+             f"validate/scenario_test/{self.id}/metrics",
+         )
+         return [
+             ScenarioTestMetric(**metric)
+             for metric in response["scenario_test_metrics"]
+         ]
+
+     def get_eval_history(self) -> List[ScenarioTestEvaluation]:
+         """Retrieves the evaluation history of the :class:`ScenarioTest`. ::
+
+             import nucleus
+             client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")
+             scenario_test = client.validate.list_scenario_tests()[0]
+
+             scenario_test.get_eval_history()
+
+         Returns:
+             A list of :class:`ScenarioTestEvaluation` objects.
+         """
+         response = self.connection.get(
+             f"validate/scenario_test/{self.id}/eval_history",
+         )
+         eval_history = GetEvalHistory.parse_obj(response)
+         return [
+             ScenarioTestEvaluation(evaluation.id, self.connection)
+             for evaluation in eval_history.evaluations
+         ]
+
+     def get_items(self) -> List[DatasetItem]:
+         response = self.connection.get(
+             f"validate/scenario_test/{self.id}/items",
+         )
+         return [
+             DatasetItem.from_json(item) for item in response[DATASET_ITEMS_KEY]
+         ]
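Putting the class above together end to end: a hedged sketch of a full round trip, assembled from the docstring samples, assuming a valid API key and an existing slice (the slice ID is the placeholder from the docstrings):

import nucleus

client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")  # placeholder key
e = client.validate.eval_functions

# Create a test against an existing slice, then attach criteria.
scenario_test = client.validate.create_scenario_test(
    "sample_scenario_test", "slc_bx86ea222a6g057x4380"
)
scenario_test.add_criterion(e.bbox_iou() > 0.5)
scenario_test.add_criterion(e.bbox_recall() > 0.6)

# Inspect what the test now contains.
for metric in scenario_test.get_criteria():
    print(metric.eval_function_id, metric.threshold_comparison, metric.threshold)
for item in scenario_test.get_items():
    print(item.reference_id)  # assuming the standard DatasetItem reference_id field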
nucleus/validate/scenario_test_evaluation.py
@@ -0,0 +1,114 @@
+ """Data types for Scenario Test Evaluation results."""
+ from dataclasses import InitVar, dataclass, field
+ from enum import Enum
+ from typing import List, Optional
+
+ import requests
+
+ from nucleus.connection import Connection
+ from nucleus.constants import DATASET_ITEM_ID_KEY, MODEL_ID_KEY, STATUS_KEY
+ from nucleus.validate.constants import (
+     EVAL_FUNCTION_ID_KEY,
+     EVALUATION_ID_KEY,
+     PASS_KEY,
+     RESULT_KEY,
+     SCENARIO_TEST_ID_KEY,
+ )
+
+ from .utils import try_convert_float
+
+ SCENARIO_TEST_EVAL_KEY = "scenario_test_evaluation"
+ ITEM_EVAL_KEY = "scenario_test_item_evaluations"
+
+
+ class ScenarioTestEvaluationStatus(Enum):
+     """The job status of a scenario test evaluation."""
+
+     PENDING = "pending"
+     STARTED = "started"
+     COMPLETED = "completed"
+     ERRORED = "errored"
+
+
+ @dataclass(frozen=True)
+ class ScenarioTestItemEvaluation:
+     """Dataset item-level results of an evaluation of a scenario test.
+     Note that this class is immutable.
+
+     Attributes:
+         evaluation_id (str): The ID of the associated scenario test evaluation.
+         scenario_test_id (str): The ID of the associated scenario test.
+         eval_function_id (str): The ID of the associated evaluation function.
+         dataset_item_id (str): The ID of the dataset item of this evaluation.
+         result (Optional[float]): The numerical result of the evaluation on this item.
+         passed (bool): Whether the result was sufficient to pass the test for this item.
+     """
+
+     evaluation_id: str
+     scenario_test_id: str
+     eval_function_id: str
+     dataset_item_id: str
+     result: Optional[float]
+     passed: bool
+
+
+ @dataclass
+ class ScenarioTestEvaluation:
+     """The results and attributes of an evaluation of a scenario test.
+
+     Attributes:
+         id (str): The ID of this scenario test evaluation.
+         scenario_test_id (str): The ID of the associated scenario test.
+         eval_function_id (str): The ID of the associated evaluation function.
+         model_id (str): The ID of the associated model.
+         status (str): The status of the evaluation job.
+         result (Optional[float]): The float result of the evaluation.
+         passed (bool): Whether the scenario test was passed.
+         item_evals (List[ScenarioTestItemEvaluation]): The individual results for each dataset item.
+         connection (Connection): The connection to the Nucleus API.
+     """
+
+     # pylint: disable=too-many-instance-attributes
+
+     id: str
+     scenario_test_id: str = field(init=False)
+     eval_function_id: str = field(init=False)
+     model_id: str = field(init=False)
+     status: ScenarioTestEvaluationStatus = field(init=False)
+     result: Optional[float] = field(init=False)
+     passed: bool = field(init=False)
+     item_evals: List[ScenarioTestItemEvaluation] = field(init=False)
+     connection: InitVar[Connection]
+
+     def __post_init__(self, connection: Connection):
+         # TODO(gunnar): Calling /info on every construction is too slow. The original
+         #  endpoint should rather return the necessary human-readable information
+         response = connection.make_request(
+             {},
+             f"validate/eval/{self.id}/info",
+             requests_command=requests.get,
+         )
+         eval_response = response[SCENARIO_TEST_EVAL_KEY]
+         items_response = response[ITEM_EVAL_KEY]
+
+         self.scenario_test_id: str = eval_response[SCENARIO_TEST_ID_KEY]
+         self.eval_function_id: str = eval_response[EVAL_FUNCTION_ID_KEY]
+         self.model_id: str = eval_response[MODEL_ID_KEY]
+         self.status: ScenarioTestEvaluationStatus = (
+             ScenarioTestEvaluationStatus(eval_response[STATUS_KEY])
+         )
+         self.result: Optional[float] = try_convert_float(
+             eval_response[RESULT_KEY]
+         )
+         self.passed: bool = bool(eval_response[PASS_KEY])
+         self.item_evals: List[ScenarioTestItemEvaluation] = [
+             ScenarioTestItemEvaluation(
+                 evaluation_id=res[EVALUATION_ID_KEY],
+                 scenario_test_id=res[SCENARIO_TEST_ID_KEY],
+                 eval_function_id=res[EVAL_FUNCTION_ID_KEY],
+                 dataset_item_id=res[DATASET_ITEM_ID_KEY],
+                 result=try_convert_float(res[RESULT_KEY]),
+                 passed=bool(res[PASS_KEY]),
+             )
+             for res in items_response
+         ]
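Since __post_init__ hydrates every field from the /info endpoint, reading results back is plain attribute access. A sketch that walks the eval history of an existing test, assuming a valid API key and at least one evaluation has been run:

import nucleus
from nucleus.validate.scenario_test_evaluation import ScenarioTestEvaluationStatus

client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")  # placeholder key
scenario_test = client.validate.list_scenario_tests()[0]

for evaluation in scenario_test.get_eval_history():
    # __post_init__ has already fetched status, result, and per-item results
    print(evaluation.model_id, evaluation.status, evaluation.passed)
    if evaluation.status == ScenarioTestEvaluationStatus.COMPLETED:
        for item_eval in evaluation.item_evals:
            print(item_eval.dataset_item_id, item_eval.result, item_eval.passed)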
nucleus/validate/scenario_test_metric.py
@@ -0,0 +1,14 @@
+ from nucleus.pydantic_base import ImmutableModel
+
+ from .constants import ThresholdComparison
+
+
+ class ScenarioTestMetric(ImmutableModel):
+     """A Scenario Test Metric is an evaluation function combined with a comparator and associated with a
+     Scenario Test. Scenario Test Metrics serve as the basis for evaluating a Model on a Scenario Test.
+     """
+
+     scenario_test_id: str
+     eval_function_id: str
+     threshold: float
+     threshold_comparison: ThresholdComparison
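ScenarioTestMetric is a plain pydantic model, so it can also be constructed directly, as get_criteria above does with ScenarioTestMetric(**metric). A small sketch with hypothetical IDs, assuming ImmutableModel is a pydantic base with mutation disabled:

from nucleus.validate.constants import ThresholdComparison
from nucleus.validate.scenario_test_metric import ScenarioTestMetric

metric = ScenarioTestMetric(
    scenario_test_id="st_sample_id",    # hypothetical IDs for illustration
    eval_function_id="ef_sample_id",
    threshold=0.5,
    threshold_comparison=ThresholdComparison.GREATER_THAN,
)
print(metric.threshold_comparison)  # ThresholdComparison.GREATER_THAN

# metric.threshold = 0.9  # would raise, assuming the model is immutable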
nucleus/validate/utils.py
@@ -0,0 +1,8 @@
+ from typing import Optional
+
+
+ def try_convert_float(float_str: str) -> Optional[float]:
+     try:
+         return float(float_str)
+     except (TypeError, ValueError):
+         return None
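try_convert_float gives the evaluation parsers a forgiving float conversion, since RESULT_KEY may hold a missing or non-numeric value for pending jobs. A quick illustration of its behavior:

from nucleus.validate.utils import try_convert_float

print(try_convert_float("0.87"))  # 0.87
print(try_convert_float("n/a"))   # None (ValueError is swallowed)
print(try_convert_float(None))    # None (TypeError is swallowed too)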