scale-nucleus 0.12b1__py3-none-any.whl → 0.14.14b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/slices.py +14 -28
- nucleus/__init__.py +211 -18
- nucleus/annotation.py +28 -5
- nucleus/connection.py +9 -1
- nucleus/constants.py +9 -3
- nucleus/dataset.py +197 -59
- nucleus/dataset_item.py +11 -1
- nucleus/job.py +1 -1
- nucleus/metrics/__init__.py +2 -1
- nucleus/metrics/base.py +34 -56
- nucleus/metrics/categorization_metrics.py +6 -2
- nucleus/metrics/cuboid_utils.py +4 -6
- nucleus/metrics/errors.py +4 -0
- nucleus/metrics/filtering.py +369 -19
- nucleus/metrics/polygon_utils.py +3 -3
- nucleus/metrics/segmentation_loader.py +30 -0
- nucleus/metrics/segmentation_metrics.py +256 -195
- nucleus/metrics/segmentation_to_poly_metrics.py +229 -105
- nucleus/metrics/segmentation_utils.py +239 -8
- nucleus/model.py +66 -10
- nucleus/model_run.py +1 -1
- nucleus/{shapely_not_installed.py → package_not_installed.py} +3 -3
- nucleus/payload_constructor.py +4 -0
- nucleus/prediction.py +6 -3
- nucleus/scene.py +7 -0
- nucleus/slice.py +160 -16
- nucleus/utils.py +51 -12
- nucleus/validate/__init__.py +1 -0
- nucleus/validate/client.py +57 -8
- nucleus/validate/constants.py +1 -0
- nucleus/validate/data_transfer_objects/eval_function.py +22 -0
- nucleus/validate/data_transfer_objects/scenario_test_evaluations.py +13 -5
- nucleus/validate/eval_functions/available_eval_functions.py +33 -20
- nucleus/validate/eval_functions/config_classes/segmentation.py +2 -46
- nucleus/validate/scenario_test.py +71 -13
- nucleus/validate/scenario_test_evaluation.py +21 -21
- nucleus/validate/utils.py +1 -1
- {scale_nucleus-0.12b1.dist-info → scale_nucleus-0.14.14b0.dist-info}/LICENSE +0 -0
- {scale_nucleus-0.12b1.dist-info → scale_nucleus-0.14.14b0.dist-info}/METADATA +13 -11
- {scale_nucleus-0.12b1.dist-info → scale_nucleus-0.14.14b0.dist-info}/RECORD +42 -41
- {scale_nucleus-0.12b1.dist-info → scale_nucleus-0.14.14b0.dist-info}/WHEEL +1 -1
- {scale_nucleus-0.12b1.dist-info → scale_nucleus-0.14.14b0.dist-info}/entry_points.txt +0 -0
nucleus/slice.py
CHANGED
@@ -1,17 +1,18 @@
|
|
1
|
+
import datetime
|
1
2
|
import warnings
|
2
|
-
from typing import Dict, Iterable, List, Set, Tuple, Union
|
3
|
+
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
|
3
4
|
|
4
5
|
import requests
|
5
6
|
|
6
7
|
from nucleus.annotation import Annotation
|
7
|
-
from nucleus.constants import EXPORTED_ROWS, ITEMS_KEY
|
8
|
+
from nucleus.constants import EXPORT_FOR_TRAINING_KEY, EXPORTED_ROWS, ITEMS_KEY
|
8
9
|
from nucleus.dataset_item import DatasetItem
|
9
10
|
from nucleus.errors import NucleusAPIError
|
10
11
|
from nucleus.job import AsyncJob
|
11
12
|
from nucleus.utils import (
|
12
13
|
KeyErrorDict,
|
13
14
|
convert_export_payload,
|
14
|
-
|
15
|
+
format_scale_task_info_response,
|
15
16
|
paginate_generator,
|
16
17
|
)
|
17
18
|
|
@@ -49,9 +50,11 @@ class Slice:
|
|
49
50
|
self._client = client
|
50
51
|
self._name = None
|
51
52
|
self._dataset_id = None
|
53
|
+
self._created_at = None
|
54
|
+
self._pending_job_count = None
|
52
55
|
|
53
56
|
def __repr__(self):
|
54
|
-
return f"Slice(slice_id='{self.id}',
|
57
|
+
return f"Slice(slice_id='{self.id}', name={self._name}, dataset_id={self._dataset_id})"
|
55
58
|
|
56
59
|
def __eq__(self, other):
|
57
60
|
if self.id == other.id:
|
@@ -59,6 +62,43 @@ class Slice:
|
|
59
62
|
return True
|
60
63
|
return False
|
61
64
|
|
65
|
+
@property
|
66
|
+
def created_at(self) -> Optional[datetime.datetime]:
|
67
|
+
"""Timestamp of creation of the slice
|
68
|
+
|
69
|
+
Returns:
|
70
|
+
datetime of creation or None if not created yet
|
71
|
+
"""
|
72
|
+
if self._created_at is None:
|
73
|
+
self._created_at = self.info().get("created_at", None)
|
74
|
+
return self._created_at
|
75
|
+
|
76
|
+
@property
|
77
|
+
def pending_job_count(self) -> Optional[int]:
|
78
|
+
if self._pending_job_count is None:
|
79
|
+
self._pending_job_count = self.info().get(
|
80
|
+
"pending_job_count", None
|
81
|
+
)
|
82
|
+
return self._pending_job_count
|
83
|
+
|
84
|
+
@classmethod
|
85
|
+
def from_request(cls, request, client):
|
86
|
+
instance = cls(request["id"], client)
|
87
|
+
instance._name = request.get("name", None)
|
88
|
+
instance._dataset_id = request.get("dataset_id", None)
|
89
|
+
created_at_str = request.get("created_at").rstrip("Z")
|
90
|
+
if hasattr(datetime.datetime, "fromisoformat"):
|
91
|
+
instance._created_at = datetime.datetime.fromisoformat(
|
92
|
+
created_at_str
|
93
|
+
)
|
94
|
+
else:
|
95
|
+
fmt_str = r"%Y-%m-%dT%H:%M:%S.%f" # replaces fromisoformat, which is not available in Python 3.6
|
96
|
+
instance._created_at = datetime.datetime.strptime(
|
97
|
+
created_at_str, fmt_str
|
98
|
+
)
|
99
|
+
instance._pending_job_count = request.get("pending_job_count", None)
|
100
|
+
return instance
|
101
|
+
|
62
102
|
@property
|
63
103
|
def slice_id(self):
|
64
104
|
warnings.warn(
|
@@ -85,9 +125,11 @@ class Slice:
|
|
85
125
|
"""Generator yielding all dataset items in the dataset.
|
86
126
|
|
87
127
|
::
|
88
|
-
|
89
|
-
|
90
|
-
|
128
|
+
|
129
|
+
collected_ref_ids = []
|
130
|
+
for item in dataset.items_generator():
|
131
|
+
print(f"Exporting item: {item.reference_id}")
|
132
|
+
collected_ref_ids.append(item.reference_id)
|
91
133
|
|
92
134
|
Args:
|
93
135
|
page_size (int, optional): Number of items to return per page. If you are
|
@@ -110,7 +152,7 @@ class Slice:
|
|
110
152
|
def items(self):
|
111
153
|
"""All DatasetItems contained in the Slice.
|
112
154
|
|
113
|
-
|
155
|
+
We recommend using :meth:`Slice.items_generator` if the Slice has more than 200k items.
|
114
156
|
|
115
157
|
"""
|
116
158
|
try:
|
@@ -184,7 +226,7 @@ class Slice:
|
|
184
226
|
|
185
227
|
Returns:
|
186
228
|
Generator where each element is a dict containing the DatasetItem
|
187
|
-
and all of its associated Annotations, grouped by type.
|
229
|
+
and all of its associated Annotations, grouped by type (e.g. box).
|
188
230
|
::
|
189
231
|
|
190
232
|
Iterable[{
|
@@ -193,18 +235,22 @@ class Slice:
|
|
193
235
|
"box": List[BoxAnnotation],
|
194
236
|
"polygon": List[PolygonAnnotation],
|
195
237
|
"cuboid": List[CuboidAnnotation],
|
238
|
+
"line": List[LineAnnotation],
|
196
239
|
"segmentation": List[SegmentationAnnotation],
|
197
240
|
"category": List[CategoryAnnotation],
|
241
|
+
"keypoints": List[KeypointsAnnotation],
|
198
242
|
}
|
199
243
|
}]
|
200
244
|
"""
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
245
|
+
json_generator = paginate_generator(
|
246
|
+
client=self._client,
|
247
|
+
endpoint=f"slice/{self.id}/exportForTrainingPage",
|
248
|
+
result_key=EXPORT_FOR_TRAINING_KEY,
|
249
|
+
page_size=100000,
|
250
|
+
)
|
251
|
+
for data in json_generator:
|
252
|
+
for ia in convert_export_payload([data], has_predictions=False):
|
253
|
+
yield ia
|
208
254
|
|
209
255
|
def items_and_annotations(
|
210
256
|
self,
|
@@ -222,8 +268,10 @@ class Slice:
|
|
222
268
|
"box": List[BoxAnnotation],
|
223
269
|
"polygon": List[PolygonAnnotation],
|
224
270
|
"cuboid": List[CuboidAnnotation],
|
271
|
+
"line": List[LineAnnotation],
|
225
272
|
"segmentation": List[SegmentationAnnotation],
|
226
273
|
"category": List[CategoryAnnotation],
|
274
|
+
"keypoints": List[KeypointsAnnotation],
|
227
275
|
}
|
228
276
|
}]
|
229
277
|
"""
|
@@ -234,6 +282,102 @@ class Slice:
|
|
234
282
|
)
|
235
283
|
return convert_export_payload(api_payload[EXPORTED_ROWS])
|
236
284
|
|
285
|
+
def export_predictions(
|
286
|
+
self, model
|
287
|
+
) -> List[Dict[str, Union[DatasetItem, Dict[str, List[Annotation]]]]]:
|
288
|
+
"""Provides a list of all DatasetItems and Predictions in the Slice for the given Model.
|
289
|
+
|
290
|
+
Parameters:
|
291
|
+
model (Model): the nucleus model object representing the model for which to export predictions.
|
292
|
+
|
293
|
+
Returns:
|
294
|
+
List where each element is a dict containing the DatasetItem
|
295
|
+
and all of its associated Predictions, grouped by type (e.g. box).
|
296
|
+
::
|
297
|
+
|
298
|
+
List[{
|
299
|
+
"item": DatasetItem,
|
300
|
+
"predictions": {
|
301
|
+
"box": List[BoxAnnotation],
|
302
|
+
"polygon": List[PolygonAnnotation],
|
303
|
+
"cuboid": List[CuboidAnnotation],
|
304
|
+
"segmentation": List[SegmentationAnnotation],
|
305
|
+
"category": List[CategoryAnnotation],
|
306
|
+
}
|
307
|
+
}]
|
308
|
+
"""
|
309
|
+
api_payload = self._client.make_request(
|
310
|
+
payload=None,
|
311
|
+
route=f"slice/{self.id}/{model.id}/exportForTraining",
|
312
|
+
requests_command=requests.get,
|
313
|
+
)
|
314
|
+
return convert_export_payload(api_payload[EXPORTED_ROWS], True)
|
315
|
+
|
316
|
+
def export_predictions_generator(
|
317
|
+
self, model
|
318
|
+
) -> Iterable[Dict[str, Union[DatasetItem, Dict[str, List[Annotation]]]]]:
|
319
|
+
"""Provides a generator over all DatasetItems and Predictions in the Slice for the given Model.
|
320
|
+
|
321
|
+
Parameters:
|
322
|
+
model (Model): the nucleus model object representing the model for which to export predictions.
|
323
|
+
|
324
|
+
Returns:
|
325
|
+
Iterable where each element is a dict containing the DatasetItem
|
326
|
+
and all of its associated Predictions, grouped by type (e.g. box).
|
327
|
+
::
|
328
|
+
|
329
|
+
Iterable[{
|
330
|
+
"item": DatasetItem,
|
331
|
+
"predictions": {
|
332
|
+
"box": List[BoxAnnotation],
|
333
|
+
"polygon": List[PolygonAnnotation],
|
334
|
+
"cuboid": List[CuboidAnnotation],
|
335
|
+
"segmentation": List[SegmentationAnnotation],
|
336
|
+
"category": List[CategoryAnnotation],
|
337
|
+
}
|
338
|
+
}]
|
339
|
+
"""
|
340
|
+
json_generator = paginate_generator(
|
341
|
+
client=self._client,
|
342
|
+
endpoint=f"slice/{self.id}/{model.id}/exportForTrainingPage",
|
343
|
+
result_key=EXPORT_FOR_TRAINING_KEY,
|
344
|
+
page_size=100000,
|
345
|
+
)
|
346
|
+
for data in json_generator:
|
347
|
+
for ip in convert_export_payload([data], has_predictions=True):
|
348
|
+
yield ip
|
349
|
+
|
350
|
+
def export_scale_task_info(self):
|
351
|
+
"""Fetches info for all linked Scale tasks of items/scenes in the slice.
|
352
|
+
|
353
|
+
Returns:
|
354
|
+
A list of dicts, each with two keys, respectively mapping to items/scenes
|
355
|
+
and info on their corresponding Scale tasks within the dataset::
|
356
|
+
|
357
|
+
List[{
|
358
|
+
"item" | "scene": Union[DatasetItem, Scene],
|
359
|
+
"scale_task_info": {
|
360
|
+
"task_id": str,
|
361
|
+
"subtask_id": str,
|
362
|
+
"task_status": str,
|
363
|
+
"task_audit_status": str,
|
364
|
+
"task_audit_review_comment": Optional[str],
|
365
|
+
"project_name": str,
|
366
|
+
"batch": str,
|
367
|
+
"created_at": str,
|
368
|
+
"completed_at": Optional[str]
|
369
|
+
}
|
370
|
+
}]
|
371
|
+
|
372
|
+
"""
|
373
|
+
response = self._client.make_request(
|
374
|
+
payload=None,
|
375
|
+
route=f"slice/{self.id}/exportScaleTaskInfo",
|
376
|
+
requests_command=requests.get,
|
377
|
+
)
|
378
|
+
# TODO: implement format function with nice keying
|
379
|
+
return format_scale_task_info_response(response)
|
380
|
+
|
237
381
|
def send_to_labeling(self, project_id: str):
|
238
382
|
"""Send items in the Slice as tasks to a Scale labeling project.
|
239
383
|
|
nucleus/utils.py
CHANGED
@@ -28,16 +28,20 @@ from .constants import (
|
|
28
28
|
BOX_TYPE,
|
29
29
|
CATEGORY_TYPE,
|
30
30
|
CUBOID_TYPE,
|
31
|
+
EXPORTED_SCALE_TASK_INFO_ROWS,
|
31
32
|
ITEM_KEY,
|
32
33
|
KEYPOINTS_TYPE,
|
33
|
-
LAST_PAGE,
|
34
34
|
LINE_TYPE,
|
35
35
|
MAX_PAYLOAD_SIZE,
|
36
36
|
MULTICATEGORY_TYPE,
|
37
|
-
|
38
|
-
|
37
|
+
NEXT_TOKEN_KEY,
|
38
|
+
PAGE_SIZE_KEY,
|
39
|
+
PAGE_TOKEN_KEY,
|
39
40
|
POLYGON_TYPE,
|
41
|
+
PREDICTIONS_KEY,
|
40
42
|
REFERENCE_ID_KEY,
|
43
|
+
SCALE_TASK_INFO_KEY,
|
44
|
+
SCENE_KEY,
|
41
45
|
SEGMENTATION_TYPE,
|
42
46
|
)
|
43
47
|
from .dataset_item import DatasetItem
|
@@ -160,7 +164,7 @@ def format_dataset_item_response(response: dict) -> dict:
|
|
160
164
|
Args:
|
161
165
|
response: JSON dictionary response from REST endpoint
|
162
166
|
Returns:
|
163
|
-
item_dict: A dictionary with two entries, one for the dataset item, and
|
167
|
+
item_dict: A dictionary with two entries, one for the dataset item, and another
|
164
168
|
for all of the associated annotations.
|
165
169
|
"""
|
166
170
|
if ANNOTATIONS_KEY not in response:
|
@@ -187,7 +191,34 @@ def format_dataset_item_response(response: dict) -> dict:
|
|
187
191
|
}
|
188
192
|
|
189
193
|
|
190
|
-
def
|
194
|
+
def format_scale_task_info_response(response: dict) -> Union[Dict, List[Dict]]:
|
195
|
+
"""Format the raw client response into api objects.
|
196
|
+
|
197
|
+
Args:
|
198
|
+
response: JSON dictionary response from REST endpoint
|
199
|
+
Returns:
|
200
|
+
A list of dictionaries, each with two entries, one for the dataset item, and another
|
201
|
+
for all of the associated Scale tasks.
|
202
|
+
"""
|
203
|
+
if EXPORTED_SCALE_TASK_INFO_ROWS not in response:
|
204
|
+
# Payload is empty so an error occurred
|
205
|
+
return response
|
206
|
+
|
207
|
+
ret = []
|
208
|
+
for row in response[EXPORTED_SCALE_TASK_INFO_ROWS]:
|
209
|
+
if ITEM_KEY in row:
|
210
|
+
ret.append(
|
211
|
+
{
|
212
|
+
ITEM_KEY: DatasetItem.from_json(row[ITEM_KEY]),
|
213
|
+
SCALE_TASK_INFO_KEY: row[SCALE_TASK_INFO_KEY],
|
214
|
+
}
|
215
|
+
)
|
216
|
+
elif SCENE_KEY in row:
|
217
|
+
ret.append(row)
|
218
|
+
return ret
|
219
|
+
|
220
|
+
|
221
|
+
def convert_export_payload(api_payload, has_predictions: bool = False):
|
191
222
|
"""Helper function to convert raw JSON to API objects
|
192
223
|
|
193
224
|
Args:
|
@@ -237,7 +268,9 @@ def convert_export_payload(api_payload):
|
|
237
268
|
annotations[MULTICATEGORY_TYPE].append(
|
238
269
|
MultiCategoryAnnotation.from_json(multicategory)
|
239
270
|
)
|
240
|
-
return_payload_row[
|
271
|
+
return_payload_row[
|
272
|
+
ANNOTATIONS_KEY if not has_predictions else PREDICTIONS_KEY
|
273
|
+
] = annotations
|
241
274
|
return_payload.append(return_payload_row)
|
242
275
|
return return_payload
|
243
276
|
|
@@ -273,7 +306,7 @@ def serialize_and_write(
|
|
273
306
|
f"The following {type_name} could not be serialized: {unit}\n"
|
274
307
|
)
|
275
308
|
message += (
|
276
|
-
"This is
|
309
|
+
"This is usually an issue with a custom python object being "
|
277
310
|
"present in the metadata. Please inspect this error and adjust the "
|
278
311
|
"metadata so it is json-serializable: only python primitives such as "
|
279
312
|
"strings, ints, floats, lists, and dicts. For example, you must "
|
@@ -329,13 +362,17 @@ def paginate_generator(
|
|
329
362
|
endpoint: str,
|
330
363
|
result_key: str,
|
331
364
|
page_size: int = 100000,
|
365
|
+
**kwargs,
|
332
366
|
):
|
333
|
-
|
334
|
-
|
335
|
-
while not last_page:
|
367
|
+
next_token = None
|
368
|
+
while True:
|
336
369
|
try:
|
337
370
|
response = client.make_request(
|
338
|
-
{
|
371
|
+
{
|
372
|
+
PAGE_TOKEN_KEY: next_token,
|
373
|
+
PAGE_SIZE_KEY: page_size,
|
374
|
+
**kwargs,
|
375
|
+
},
|
339
376
|
endpoint,
|
340
377
|
requests.post,
|
341
378
|
)
|
@@ -343,6 +380,8 @@ def paginate_generator(
|
|
343
380
|
if e.status_code == 503:
|
344
381
|
e.message += f"/n Your request timed out while trying to get a page size of {page_size}. Try lowering the page_size."
|
345
382
|
raise e
|
346
|
-
|
383
|
+
next_token = response[NEXT_TOKEN_KEY]
|
347
384
|
for json_value in response[result_key]:
|
348
385
|
yield json_value
|
386
|
+
if not next_token:
|
387
|
+
break
|
nucleus/validate/__init__.py
CHANGED
@@ -14,6 +14,7 @@ from .data_transfer_objects.eval_function import (
|
|
14
14
|
GetEvalFunctions,
|
15
15
|
)
|
16
16
|
from .data_transfer_objects.scenario_test import CreateScenarioTestRequest
|
17
|
+
from .data_transfer_objects.scenario_test_evaluations import EvaluationResult
|
17
18
|
from .errors import CreateScenarioTestError
|
18
19
|
from .eval_functions.available_eval_functions import AvailableEvalFunctions
|
19
20
|
from .scenario_test import ScenarioTest
|
nucleus/validate/client.py
CHANGED
@@ -3,8 +3,12 @@ from typing import List
|
|
3
3
|
from nucleus.connection import Connection
|
4
4
|
from nucleus.job import AsyncJob
|
5
5
|
|
6
|
-
from .constants import SCENARIO_TEST_ID_KEY
|
7
|
-
from .data_transfer_objects.eval_function import
|
6
|
+
from .constants import EVAL_FUNCTION_KEY, SCENARIO_TEST_ID_KEY
|
7
|
+
from .data_transfer_objects.eval_function import (
|
8
|
+
CreateEvalFunction,
|
9
|
+
EvalFunctionEntry,
|
10
|
+
GetEvalFunctions,
|
11
|
+
)
|
8
12
|
from .data_transfer_objects.scenario_test import (
|
9
13
|
CreateScenarioTestRequest,
|
10
14
|
EvalFunctionListEntry,
|
@@ -81,6 +85,15 @@ class Validate:
|
|
81
85
|
"evaluation_functions=[client.validate.eval_functions.bbox_iou()]"
|
82
86
|
)
|
83
87
|
|
88
|
+
external_fns = [
|
89
|
+
f.eval_func_entry.is_external_function
|
90
|
+
for f in evaluation_functions
|
91
|
+
]
|
92
|
+
if any(external_fns):
|
93
|
+
assert all(
|
94
|
+
external_fns
|
95
|
+
), "Cannot create scenario tests with mixed placeholder and non-placeholder evaluation functions"
|
96
|
+
|
84
97
|
response = self.connection.post(
|
85
98
|
CreateScenarioTestRequest(
|
86
99
|
name=name,
|
@@ -94,13 +107,17 @@ class Validate:
|
|
94
107
|
).dict(),
|
95
108
|
"validate/scenario_test",
|
96
109
|
)
|
97
|
-
return ScenarioTest(
|
110
|
+
return ScenarioTest.from_id(
|
111
|
+
response[SCENARIO_TEST_ID_KEY], self.connection
|
112
|
+
)
|
98
113
|
|
99
114
|
def get_scenario_test(self, scenario_test_id: str) -> ScenarioTest:
|
100
115
|
response = self.connection.get(
|
101
116
|
f"validate/scenario_test/{scenario_test_id}",
|
102
117
|
)
|
103
|
-
return ScenarioTest(
|
118
|
+
return ScenarioTest.from_id(
|
119
|
+
response["unit_test"]["id"], self.connection
|
120
|
+
)
|
104
121
|
|
105
122
|
@property
|
106
123
|
def scenario_tests(self) -> List[ScenarioTest]:
|
@@ -118,12 +135,13 @@ class Validate:
|
|
118
135
|
A list of ScenarioTest objects.
|
119
136
|
"""
|
120
137
|
response = self.connection.get(
|
121
|
-
"validate/scenario_test",
|
138
|
+
"validate/scenario_test/details",
|
122
139
|
)
|
123
|
-
|
124
|
-
ScenarioTest(
|
125
|
-
for
|
140
|
+
tests = [
|
141
|
+
ScenarioTest.from_response(payload, self.connection)
|
142
|
+
for payload in response
|
126
143
|
]
|
144
|
+
return tests
|
127
145
|
|
128
146
|
def delete_scenario_test(self, scenario_test_id: str) -> bool:
|
129
147
|
"""Deletes a Scenario Test. ::
|
@@ -175,3 +193,34 @@ class Validate:
|
|
175
193
|
f"validate/{model_id}/evaluate",
|
176
194
|
)
|
177
195
|
return AsyncJob.from_json(response, self.connection)
|
196
|
+
|
197
|
+
def create_external_eval_function(
|
198
|
+
self,
|
199
|
+
name: str,
|
200
|
+
) -> EvalFunctionEntry:
|
201
|
+
"""Creates a new external evaluation function. This external function can be used to upload evaluation
|
202
|
+
results with functions defined and computed by the customer, without having to share the source code of the
|
203
|
+
respective function.
|
204
|
+
|
205
|
+
Args:
|
206
|
+
name: unique name of evaluation function
|
207
|
+
|
208
|
+
Raises:
|
209
|
+
- NucleusAPIError if the creation of the function fails on the server side
|
210
|
+
- ValidationError if the evaluation name is not well defined
|
211
|
+
|
212
|
+
Returns:
|
213
|
+
Created EvalFunctionEntry object.
|
214
|
+
|
215
|
+
"""
|
216
|
+
|
217
|
+
response = self.connection.post(
|
218
|
+
CreateEvalFunction(
|
219
|
+
name=name,
|
220
|
+
is_external_function=True,
|
221
|
+
serialized_fn=None,
|
222
|
+
raw_source=None,
|
223
|
+
).dict(),
|
224
|
+
"validate/eval_fn",
|
225
|
+
)
|
226
|
+
return EvalFunctionEntry.parse_obj(response[EVAL_FUNCTION_KEY])
|
nucleus/validate/constants.py
CHANGED
nucleus/validate/data_transfer_objects/eval_function.py
CHANGED
@@ -72,6 +72,7 @@ class EvalFunctionEntry(ImmutableModel):
|
|
72
72
|
id: str
|
73
73
|
name: str
|
74
74
|
is_public: bool
|
75
|
+
is_external_function: bool = False
|
75
76
|
user_id: str
|
76
77
|
serialized_fn: Optional[str] = None
|
77
78
|
raw_source: Optional[str] = None
|
@@ -81,3 +82,24 @@ class GetEvalFunctions(ImmutableModel):
|
|
81
82
|
"""Expected format from GET validate/eval_fn"""
|
82
83
|
|
83
84
|
eval_functions: List[EvalFunctionEntry]
|
85
|
+
|
86
|
+
|
87
|
+
class CreateEvalFunction(ImmutableModel):
|
88
|
+
"""Expected payload to POST validate/eval_fn"""
|
89
|
+
|
90
|
+
name: str
|
91
|
+
is_external_function: bool
|
92
|
+
serialized_fn: Optional[str] = None
|
93
|
+
raw_source: Optional[str] = None
|
94
|
+
|
95
|
+
@validator("name")
|
96
|
+
def name_is_valid(cls, v): # pylint: disable=no-self-argument
|
97
|
+
if " " in v:
|
98
|
+
raise ValueError(
|
99
|
+
f"No spaces allowed in an evaluation function name, got '{v}'"
|
100
|
+
)
|
101
|
+
if len(v) == 0 or len(v) > 255:
|
102
|
+
raise ValueError(
|
103
|
+
"Name of evaluation function must be between 1-255 characters long"
|
104
|
+
)
|
105
|
+
return v
|
nucleus/validate/data_transfer_objects/scenario_test_evaluations.py
CHANGED
@@ -1,11 +1,19 @@
|
|
1
1
|
from typing import List
|
2
2
|
|
3
|
-
from
|
3
|
+
from pydantic import validator
|
4
4
|
|
5
|
+
from nucleus.pydantic_base import ImmutableModel
|
5
6
|
|
6
|
-
class EvalDetail(ImmutableModel):
|
7
|
-
id: str
|
8
7
|
|
8
|
+
class EvaluationResult(ImmutableModel):
|
9
|
+
item_ref_id: str
|
10
|
+
score: float
|
11
|
+
weight: float = 1
|
9
12
|
|
10
|
-
|
11
|
-
|
13
|
+
@validator("score", "weight")
|
14
|
+
def is_normalized(cls, v): # pylint: disable=no-self-argument
|
15
|
+
if 0 <= v <= 1:
|
16
|
+
return v
|
17
|
+
raise ValueError(
|
18
|
+
f"Expected evaluation score and weights to be normalized between 0 and 1, but got: {v}"
|
19
|
+
)
|