scale-nucleus 0.1.3__py3-none-any.whl → 0.1.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nucleus/__init__.py +186 -209
- nucleus/annotation.py +51 -7
- nucleus/constants.py +56 -52
- nucleus/dataset.py +87 -10
- nucleus/dataset_item.py +36 -10
- nucleus/errors.py +18 -5
- nucleus/job.py +56 -0
- nucleus/model.py +2 -1
- nucleus/model_run.py +29 -6
- nucleus/payload_constructor.py +2 -2
- nucleus/prediction.py +17 -3
- nucleus/slice.py +18 -39
- nucleus/utils.py +75 -8
- {scale_nucleus-0.1.3.dist-info → scale_nucleus-0.1.10.dist-info}/LICENSE +0 -0
- {scale_nucleus-0.1.3.dist-info → scale_nucleus-0.1.10.dist-info}/METADATA +49 -12
- scale_nucleus-0.1.10.dist-info/RECORD +18 -0
- {scale_nucleus-0.1.3.dist-info → scale_nucleus-0.1.10.dist-info}/WHEEL +1 -1
- scale_nucleus-0.1.3.dist-info/RECORD +0 -17
nucleus/constants.py
CHANGED
@@ -1,63 +1,67 @@

The module was rewritten for 0.1.10: the constants are now listed roughly alphabetically, and the file grew from 63 to 67 lines. Its 0.1.10 contents:

```python
ANNOTATIONS_IGNORED_KEY = "annotations_ignored"
ANNOTATIONS_KEY = "annotations"
ANNOTATIONS_PROCESSED_KEY = "annotations_processed"
ANNOTATION_ID_KEY = "annotation_id"
ANNOTATION_METADATA_SCHEMA_KEY = "annotation_metadata_schema"
BOX_TYPE = "box"
POLYGON_TYPE = "polygon"
MASK_TYPE = "mask"
SEGMENTATION_TYPE = "segmentation"
ANNOTATION_TYPES = (BOX_TYPE, POLYGON_TYPE, SEGMENTATION_TYPE)
ANNOTATION_UPDATE_KEY = "update"
AUTOTAGS_KEY = "autotags"
EXPORTED_ROWS = "exportedRows"
CLASS_PDF_KEY = "class_pdf"
CONFIDENCE_KEY = "confidence"
DATASET_ID_KEY = "dataset_id"
DATASET_ITEM_IDS_KEY = "dataset_item_ids"
DATASET_ITEM_ID_KEY = "dataset_item_id"
DATASET_LENGTH_KEY = "length"
DATASET_MODEL_RUNS_KEY = "model_run_ids"
DATASET_NAME_KEY = "name"
DATASET_SLICES_KEY = "slice_ids"
DEFAULT_ANNOTATION_UPDATE_MODE = False
DEFAULT_NETWORK_TIMEOUT_SEC = 120
EMBEDDINGS_URL_KEY = "embeddings_url"
ERRORS_KEY = "errors"
ERROR_CODES = "error_codes"
ERROR_ITEMS = "upload_errors"
ERROR_PAYLOAD = "error_payload"
GEOMETRY_KEY = "geometry"
HEIGHT_KEY = "height"
IGNORED_ITEMS = "ignored_items"
IMAGE_KEY = "image"
IMAGE_URL_KEY = "image_url"
INDEX_KEY = "index"
ITEMS_KEY = "items"
ITEM_ID_KEY = "item_id"
ITEM_KEY = "item"
ITEM_METADATA_SCHEMA_KEY = "item_metadata_schema"
JOB_ID_KEY = "job_id"
LABEL_KEY = "label"
MASK_URL_KEY = "mask_url"
MESSAGE_KEY = "message"
METADATA_KEY = "metadata"
MODEL_ID_KEY = "model_id"
MODEL_RUN_ID_KEY = "model_run_id"
NAME_KEY = "name"
NEW_ITEMS = "new_items"
NUCLEUS_ENDPOINT = "https://api.scale.com/v1/nucleus"
ORIGINAL_IMAGE_URL_KEY = "original_image_url"
PREDICTIONS_IGNORED_KEY = "predictions_ignored"
PREDICTIONS_PROCESSED_KEY = "predictions_processed"
REFERENCE_IDS_KEY = "reference_ids"
REFERENCE_ID_KEY = "reference_id"
REQUEST_ID_KEY = "requestId"
SEGMENTATIONS_KEY = "segmentations"
SLICE_ID_KEY = "slice_id"
STATUS_CODE_KEY = "status_code"
STATUS_KEY = "status"
SUCCESS_STATUS_CODES = [200, 201, 202]
TYPE_KEY = "type"
UPDATED_ITEMS = "updated_items"
UPDATE_KEY = "update"
VERTICES_KEY = "vertices"
WIDTH_KEY = "width"
X_KEY = "x"
Y_KEY = "y"
```
nucleus/dataset.py
CHANGED
```diff
@@ -1,10 +1,15 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 import requests
 
-from nucleus.
+from nucleus.job import AsyncJob
+from nucleus.utils import (
+    convert_export_payload,
+    format_dataset_item_response,
+    serialize_and_write_to_presigned_url,
+)
 
-from .annotation import Annotation
+from .annotation import Annotation, check_all_annotation_paths_remote
 from .constants import (
     DATASET_ITEM_IDS_KEY,
     DATASET_LENGTH_KEY,
@@ -12,13 +17,24 @@ from .constants import (
     DATASET_NAME_KEY,
     DATASET_SLICES_KEY,
     DEFAULT_ANNOTATION_UPDATE_MODE,
+    EXPORTED_ROWS,
+    JOB_ID_KEY,
     NAME_KEY,
     REFERENCE_IDS_KEY,
+    REQUEST_ID_KEY,
+    UPDATE_KEY,
+)
+from .dataset_item import (
+    DatasetItem,
+    check_all_paths_remote,
+    check_for_duplicate_reference_ids,
 )
-from .dataset_item import DatasetItem
 from .payload_constructor import construct_model_run_creation_payload
 
 
+WARN_FOR_LARGE_UPLOAD = 50000
+
+
 class Dataset:
     """
     Nucleus Dataset. You can append images with metadata to your dataset,
@@ -26,7 +42,11 @@ class Dataset:
     compare model performance on you data.
     """
 
-    def __init__(
+    def __init__(
+        self,
+        dataset_id: str,
+        client: "NucleusClient",  # type:ignore # noqa: F821
+    ):
         self.id = dataset_id
         self._client = client
 
@@ -129,7 +149,8 @@
         annotations: List[Annotation],
         update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
         batch_size: int = 5000,
+        asynchronous: bool = False,
+    ) -> Union[Dict[str, Any], AsyncJob]:
         """
         Uploads ground truth annotations for a given dataset.
         :param annotations: ground truth annotations for a given dataset to upload
@@ -142,6 +163,19 @@
             "ignored_items": int,
         }
         """
+        if asynchronous:
+            check_all_annotation_paths_remote(annotations)
+
+            request_id = serialize_and_write_to_presigned_url(
+                annotations, self.id, self._client
+            )
+            response = self._client.make_request(
+                payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
+                route=f"dataset/{self.id}/annotate?async=1",
+            )
+
+            return AsyncJob(response[JOB_ID_KEY], self._client)
+
         return self._client.annotate_dataset(
             self.id, annotations, update=update, batch_size=batch_size
         )
@@ -160,16 +194,18 @@
     def append(
         self,
         dataset_items: List[DatasetItem],
+        update: Optional[bool] = False,
         batch_size: Optional[int] = 20,
+        asynchronous=False,
+    ) -> Union[dict, AsyncJob]:
         """
         Appends images with metadata (dataset items) to the dataset. Overwrites images on collision if forced.
 
         Parameters:
         :param dataset_items: items to upload
-        :param
+        :param update: if True overwrites images and metadata on collision
         :param batch_size: batch parameter for long uploads
+        :param aynchronous: if True, return a job object representing asynchronous ingestion job.
         :return:
         {
             'dataset_id': str,
@@ -178,10 +214,31 @@
             'ignored_items': int,
         }
         """
+        check_for_duplicate_reference_ids(dataset_items)
+
+        if len(dataset_items) > WARN_FOR_LARGE_UPLOAD and not asynchronous:
+            print(
+                "Tip: for large uploads, get faster performance by importing your data "
+                "into Nucleus directly from a cloud storage provider. See "
+                "https://dashboard.scale.com/nucleus/docs/api?language=python#guide-for-large-ingestions"
+                " for details."
+            )
+
+        if asynchronous:
+            check_all_paths_remote(dataset_items)
+            request_id = serialize_and_write_to_presigned_url(
+                dataset_items, self.id, self._client
+            )
+            response = self._client.make_request(
+                payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
+                route=f"dataset/{self.id}/append?async=1",
+            )
+            return AsyncJob(response["job_id"], self._client)
+
         return self._client.populate_dataset(
             self.id,
             dataset_items,
+            update=update,
             batch_size=batch_size,
         )
 
@@ -272,3 +329,23 @@
 
     def check_index_status(self, job_id: str):
         return self._client.check_index_status(job_id)
+
+    def items_and_annotations(
+        self,
+    ) -> List[Dict[str, Union[DatasetItem, Dict[str, List[Annotation]]]]]:
+        """Returns a list of all DatasetItems and Annotations in this slice.
+
+        Returns:
+            A list, where each item is a dict with two keys representing a row
+            in the dataset.
+            * One value in the dict is the DatasetItem, containing a reference to the
+                item that was annotated.
+            * The other value is a dictionary containing all the annotations for this
+                dataset item, sorted by annotation type.
+        """
+        api_payload = self._client.make_request(
+            payload=None,
+            route=f"dataset/{self.id}/exportForTraining",
+            requests_command=requests.get,
+        )
+        return convert_export_payload(api_payload[EXPORTED_ROWS])
```
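With these changes, `Dataset.annotate` and `Dataset.append` can return an `AsyncJob` instead of a blocking response dict. A minimal usage sketch (the client construction, `get_dataset` call, API key, ids, and the S3 path are illustrative assumptions, not part of this diff):

```python
from nucleus import NucleusClient
from nucleus.dataset_item import DatasetItem

client = NucleusClient("YOUR_SCALE_API_KEY")   # placeholder key
dataset = client.get_dataset("ds_sample_id")   # placeholder dataset id

# Asynchronous ingestion requires every image_location to be remote;
# check_all_paths_remote() raises ValueError for local paths.
items = [
    DatasetItem(
        image_location="s3://my-bucket/img_0.jpg",  # placeholder remote path
        reference_id="img_0",
    )
]

job = dataset.append(items, update=True, asynchronous=True)  # -> AsyncJob
job.sleep_until_complete()  # poll job/<id> until it leaves the "Running" state
print(job.status())
```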
nucleus/dataset_item.py
CHANGED
```diff
@@ -1,7 +1,9 @@
+from collections import Counter
 import json
 import os.path
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Sequence
+from urllib.parse import urlparse
 
 from .constants import (
     DATASET_ITEM_ID_KEY,
@@ -21,8 +23,7 @@ class DatasetItem:
     metadata: Optional[dict] = None
 
     def __post_init__(self):
-        self.
-        self.local = self._is_local_path(self.image_location)
+        self.local = is_local_path(self.image_location)
 
     @classmethod
     def from_json(cls, payload: dict):
@@ -36,16 +37,12 @@ class DatasetItem:
             metadata=payload.get(METADATA_KEY, {}),
         )
 
-    def _is_local_path(self, path: str) -> bool:
-        path_components = [comp.lower() for comp in path.split("/")]
-        return path_components[0] not in {"https:", "http:", "s3:", "gs:"}
-
     def local_file_exists(self):
-        return os.path.isfile(self.
+        return os.path.isfile(self.image_location)
 
     def to_payload(self) -> dict:
         payload = {
-            IMAGE_URL_KEY: self.
+            IMAGE_URL_KEY: self.image_location,
             METADATA_KEY: self.metadata or {},
         }
         if self.reference_id:
@@ -55,4 +52,33 @@
         return payload
 
     def to_json(self) -> str:
-        return json.dumps(self.to_payload())
+        return json.dumps(self.to_payload(), allow_nan=False)
+
+
+def is_local_path(path: str) -> bool:
+    return urlparse(path).scheme not in {"https", "http", "s3", "gs"}
+
+
+def check_all_paths_remote(dataset_items: Sequence[DatasetItem]):
+    for item in dataset_items:
+        if is_local_path(item.image_location):
+            raise ValueError(
+                f"All paths must be remote, but {item.image_location} is either "
+                "local, or a remote URL type that is not supported."
+            )
+
+
+def check_for_duplicate_reference_ids(dataset_items: Sequence[DatasetItem]):
+    ref_ids = []
+    for dataset_item in dataset_items:
+        if dataset_item.reference_id is not None:
+            ref_ids.append(dataset_item.reference_id)
+    if len(ref_ids) != len(set(ref_ids)):
+        duplicates = {
+            f"{key}": f"Count: {value}"
+            for key, value in Counter(ref_ids).items()
+        }
+        raise ValueError(
+            "Duplicate reference ids found among dataset_items: %s"
+            % duplicates
+        )
```
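The new module-level helpers back the asynchronous upload path. A small sketch of their behavior (field names follow the `DatasetItem` attributes referenced above; the paths and ids are placeholders):

```python
from nucleus.dataset_item import (
    DatasetItem,
    check_for_duplicate_reference_ids,
    is_local_path,
)

# urlparse-based check: only https/http/s3/gs schemes count as remote.
print(is_local_path("s3://bucket/img.jpg"))  # False
print(is_local_path("/tmp/img.jpg"))         # True

items = [
    DatasetItem(image_location="s3://bucket/a.jpg", reference_id="dup"),
    DatasetItem(image_location="s3://bucket/b.jpg", reference_id="dup"),
]
# Raises ValueError, reporting {'dup': 'Count: 2'} for the repeated id.
check_for_duplicate_reference_ids(items)
```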
nucleus/errors.py
CHANGED
```diff
@@ -25,9 +25,22 @@ class DatasetItemRetrievalError(Exception):
 
 
 class NucleusAPIError(Exception):
-    def __init__(
+    def __init__(
+        self, endpoint, command, requests_response=None, aiohttp_response=None
+    ):
+
+        if requests_response is not None:
+            message = f"Tried to {command.__name__} {endpoint}, but received {requests_response.status_code}: {requests_response.reason}."
+            if hasattr(requests_response, "text"):
+                if requests_response.text:
+                    message += (
+                        f"\nThe detailed error is:\n{requests_response.text}"
+                    )
+
+        if aiohttp_response is not None:
+            status, reason, data = aiohttp_response
+            message = f"Tried to {command.__name__} {endpoint}, but received {status}: {reason}."
+            if data:
+                message += f"\nThe detailed error is:\n{data}"
+
         super().__init__(message)
```
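The richer `NucleusAPIError` message can be seen by feeding it a failed `requests` response. A contrived sketch (the response object is built by hand here purely for illustration):

```python
import requests
from nucleus.errors import NucleusAPIError

resp = requests.models.Response()       # stand-in for a real failed response
resp.status_code = 404
resp.reason = "Not Found"
resp._content = b"dataset not found"    # .text will decode this body

err = NucleusAPIError("dataset/ds_missing", requests.get, requests_response=resp)
print(err)
# Tried to get dataset/ds_missing, but received 404: Not Found.
# The detailed error is:
# dataset not found
```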
nucleus/job.py
ADDED
```python
from dataclasses import dataclass
import time
from typing import Dict, List

import requests

JOB_POLLING_INTERVAL = 5


@dataclass
class AsyncJob:
    id: str
    client: "NucleusClient"  # type: ignore # noqa: F821

    def status(self) -> Dict[str, str]:
        return self.client.make_request(
            payload={},
            route=f"job/{self.id}",
            requests_command=requests.get,
        )

    def errors(self) -> List[str]:
        return self.client.make_request(
            payload={},
            route=f"job/{self.id}/errors",
            requests_command=requests.get,
        )

    def sleep_until_complete(self, verbose_std_out=True):
        while 1:
            status = self.status()

            time.sleep(JOB_POLLING_INTERVAL)

            if verbose_std_out:
                print(f"Status at {time.ctime()}: {status}")
            if status["status"] == "Running":
                continue
            break

        final_status = status
        if final_status["status"] == "Errored":
            raise JobError(final_status, self)


class JobError(Exception):
    def __init__(self, job_status: Dict[str, str], job: AsyncJob):
        final_status_message = job_status["message"]
        final_status = job_status["status"]
        message = (
            f"The job reported a final status of {final_status} "
            "This could, however, mean a partial success with some successes and some failures. "
            f"The final status message was: {final_status_message} \n"
            f"For more detailed error messages you can call {str(job)}.errors()"
        )
        super().__init__(message)
```
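A sketch of how the new `AsyncJob` handle is meant to be used once one of the asynchronous endpoints returns it:

```python
from nucleus.job import AsyncJob, JobError

def wait_for(job: AsyncJob) -> None:
    """Block until `job` finishes, then print any per-item errors."""
    try:
        # Polls job/<id> every JOB_POLLING_INTERVAL (5 s), printing each status.
        job.sleep_until_complete(verbose_std_out=True)
    except JobError:
        # Final status was "Errored"; job/<id>/errors has the detailed messages.
        for error in job.errors():
            print(error)
```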
nucleus/model.py
CHANGED
```diff
@@ -45,6 +45,7 @@ class Model:
             Union[BoxPrediction, PolygonPrediction, SegmentationPrediction]
         ],
         metadata: Optional[Dict] = None,
+        asynchronous: bool = False,
     ) -> ModelRun:
         payload: dict = {
             NAME_KEY: name,
@@ -56,6 +57,6 @@ class Model:
             dataset.id, payload
         )
 
-        model_run.predict(predictions)
+        model_run.predict(predictions, asynchronous=asynchronous)
 
         return model_run
```
nucleus/model_run.py
CHANGED
```diff
@@ -1,10 +1,18 @@
-from typing import Dict, Optional,
+from typing import Dict, List, Optional, Type, Union
+
+from nucleus.annotation import check_all_annotation_paths_remote
+from nucleus.job import AsyncJob
+from nucleus.utils import serialize_and_write_to_presigned_url
+
 from .constants import (
     ANNOTATIONS_KEY,
-    DEFAULT_ANNOTATION_UPDATE_MODE,
     BOX_TYPE,
+    DEFAULT_ANNOTATION_UPDATE_MODE,
+    JOB_ID_KEY,
     POLYGON_TYPE,
+    REQUEST_ID_KEY,
     SEGMENTATION_TYPE,
+    UPDATE_KEY,
 )
 from .prediction import (
     BoxPrediction,
@@ -19,12 +27,13 @@ class ModelRun:
     Having an open model run is a prerequisite for uploading predictions to your dataset.
     """
 
-    def __init__(self, model_run_id: str, client):
+    def __init__(self, model_run_id: str, dataset_id: str, client):
         self.model_run_id = model_run_id
         self._client = client
+        self._dataset_id = dataset_id
 
     def __repr__(self):
-        return f"ModelRun(model_run_id='{self.model_run_id}', client={self._client})"
+        return f"ModelRun(model_run_id='{self.model_run_id}', dataset_id='{self._dataset_id}', client={self._client})"
 
     def __eq__(self, other):
         if self.model_run_id == other.model_run_id:
@@ -84,7 +93,8 @@
             Union[BoxPrediction, PolygonPrediction, SegmentationPrediction]
         ],
         update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
+        asynchronous: bool = False,
+    ) -> Union[dict, AsyncJob]:
         """
         Uploads model outputs as predictions for a model_run. Returns info about the upload.
         :param annotations: List[Union[BoxPrediction, PolygonPrediction]],
@@ -95,7 +105,20 @@
             "predictions_ignored": int,
         }
         """
+        if asynchronous:
+            check_all_annotation_paths_remote(annotations)
+
+            request_id = serialize_and_write_to_presigned_url(
+                annotations, self._dataset_id, self._client
+            )
+            response = self._client.make_request(
+                payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
+                route=f"modelRun/{self.model_run_id}/predict?async=1",
+            )
+
+            return AsyncJob(response[JOB_ID_KEY], self._client)
+        else:
+            return self._client.predict(self.model_run_id, annotations, update)
 
     def iloc(self, i: int):
         """
```
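A usage sketch for the asynchronous prediction upload (the `ModelRun` instance is assumed to come from the client, the box values and reference id are placeholders, and the `BoxPrediction` arguments follow the fields shown in nucleus/prediction.py below):

```python
from nucleus.job import AsyncJob
from nucleus.model_run import ModelRun
from nucleus.prediction import BoxPrediction

def upload_predictions_async(model_run: ModelRun) -> AsyncJob:
    predictions = [
        BoxPrediction(
            label="car",
            x=10, y=20, width=50, height=30,
            reference_id="img_0",   # must match an already-uploaded item
            confidence=0.8,
        )
    ]
    # asynchronous=True serializes the payload to a presigned URL and hits
    # modelRun/<id>/predict?async=1, returning a job handle immediately.
    job = model_run.predict(predictions, asynchronous=True)
    return job
```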
nucleus/payload_constructor.py
CHANGED
```diff
@@ -17,7 +17,7 @@ from .constants import (
     REFERENCE_ID_KEY,
     ANNOTATIONS_KEY,
     ITEMS_KEY,
+    UPDATE_KEY,
     MODEL_ID_KEY,
     ANNOTATION_METADATA_SCHEMA_KEY,
     SEGMENTATIONS_KEY,
@@ -34,7 +34,7 @@ def construct_append_payload(
     return (
         {ITEMS_KEY: items}
         if not force
-        else {ITEMS_KEY: items,
+        else {ITEMS_KEY: items, UPDATE_KEY: True}
     )
 
 
```
nucleus/prediction.py
CHANGED
```diff
@@ -1,6 +1,7 @@
-from typing import Dict, Optional, List
+from typing import Dict, Optional, List
 from .annotation import (
     BoxAnnotation,
+    Point,
     PolygonAnnotation,
     Segment,
     SegmentationAnnotation,
@@ -16,6 +17,7 @@ from .constants import (
     Y_KEY,
     WIDTH_KEY,
     HEIGHT_KEY,
+    CLASS_PDF_KEY,
     CONFIDENCE_KEY,
     VERTICES_KEY,
     ANNOTATIONS_KEY,
@@ -54,6 +56,7 @@ class BoxPrediction(BoxAnnotation):
         confidence: Optional[float] = None,
         annotation_id: Optional[str] = None,
         metadata: Optional[Dict] = None,
+        class_pdf: Optional[Dict] = None,
     ):
         super().__init__(
             label=label,
@@ -67,11 +70,14 @@ class BoxPrediction(BoxAnnotation):
             metadata=metadata,
         )
         self.confidence = confidence
+        self.class_pdf = class_pdf
 
     def to_payload(self) -> dict:
         payload = super().to_payload()
         if self.confidence is not None:
             payload[CONFIDENCE_KEY] = self.confidence
+        if self.class_pdf is not None:
+            payload[CLASS_PDF_KEY] = self.class_pdf
 
         return payload
 
@@ -89,6 +95,7 @@ class BoxPrediction(BoxAnnotation):
             confidence=payload.get(CONFIDENCE_KEY, None),
             annotation_id=payload.get(ANNOTATION_ID_KEY, None),
             metadata=payload.get(METADATA_KEY, {}),
+            class_pdf=payload.get(CLASS_PDF_KEY, None),
         )
 
 
@@ -96,12 +103,13 @@ class PolygonPrediction(PolygonAnnotation):
     def __init__(
         self,
         label: str,
-        vertices: List[
+        vertices: List[Point],
         reference_id: Optional[str] = None,
         item_id: Optional[str] = None,
         confidence: Optional[float] = None,
         annotation_id: Optional[str] = None,
         metadata: Optional[Dict] = None,
+        class_pdf: Optional[Dict] = None,
     ):
         super().__init__(
             label=label,
@@ -112,11 +120,14 @@ class PolygonPrediction(PolygonAnnotation):
             metadata=metadata,
         )
         self.confidence = confidence
+        self.class_pdf = class_pdf
 
     def to_payload(self) -> dict:
         payload = super().to_payload()
         if self.confidence is not None:
             payload[CONFIDENCE_KEY] = self.confidence
+        if self.class_pdf is not None:
+            payload[CLASS_PDF_KEY] = self.class_pdf
 
         return payload
 
@@ -125,10 +136,13 @@ class PolygonPrediction(PolygonAnnotation):
         geometry = payload.get(GEOMETRY_KEY, {})
         return cls(
             label=payload.get(LABEL_KEY, 0),
-            vertices=
+            vertices=[
+                Point.from_json(_) for _ in geometry.get(VERTICES_KEY, [])
+            ],
             reference_id=payload.get(REFERENCE_ID_KEY, None),
             item_id=payload.get(DATASET_ITEM_ID_KEY, None),
             confidence=payload.get(CONFIDENCE_KEY, None),
             annotation_id=payload.get(ANNOTATION_ID_KEY, None),
             metadata=payload.get(METADATA_KEY, {}),
+            class_pdf=payload.get(CLASS_PDF_KEY, None),
         )
```
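A short sketch of the new `class_pdf` field, which carries a per-class score distribution alongside the top-label confidence (box parameter names are assumed from the `X_KEY`/`Y_KEY`/`WIDTH_KEY`/`HEIGHT_KEY` constants; all values are placeholders):

```python
from nucleus.prediction import BoxPrediction

pred = BoxPrediction(
    label="car",
    x=10, y=20, width=50, height=30,
    reference_id="img_0",
    confidence=0.7,
    class_pdf={"car": 0.7, "truck": 0.2, "bus": 0.1},
)

# class_pdf is emitted under the "class_pdf" payload key only when it is set.
print(pred.to_payload()["class_pdf"])
```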