scale-nucleus 0.1.1__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package as they appear in the public registry, and is provided for informational purposes only.
nucleus/annotation.py CHANGED
@@ -1,25 +1,29 @@
+import json
 from dataclasses import dataclass
 from enum import Enum
-from typing import Dict, Optional, Any, Union, List
+from typing import Dict, List, Optional, Sequence, Union
+from nucleus.dataset_item import is_local_path
+
 from .constants import (
     ANNOTATION_ID_KEY,
+    ANNOTATIONS_KEY,
+    BOX_TYPE,
     DATASET_ITEM_ID_KEY,
-    REFERENCE_ID_KEY,
-    METADATA_KEY,
-    X_KEY,
-    Y_KEY,
-    WIDTH_KEY,
-    HEIGHT_KEY,
     GEOMETRY_KEY,
-    BOX_TYPE,
-    POLYGON_TYPE,
+    HEIGHT_KEY,
+    INDEX_KEY,
+    ITEM_ID_KEY,
     LABEL_KEY,
+    MASK_TYPE,
+    MASK_URL_KEY,
+    METADATA_KEY,
+    POLYGON_TYPE,
+    REFERENCE_ID_KEY,
     TYPE_KEY,
     VERTICES_KEY,
-    ITEM_ID_KEY,
-    MASK_URL_KEY,
-    INDEX_KEY,
-    ANNOTATIONS_KEY,
+    WIDTH_KEY,
+    X_KEY,
+    Y_KEY,
 )
 
 
@@ -42,6 +46,15 @@ class Annotation:
         else:
             return SegmentationAnnotation.from_json(payload)
 
+    def to_payload(self):
+        raise NotImplementedError(
+            "For serialization, use a specific subclass (i.e. SegmentationAnnotation), "
+            "not the base annotation class."
+        )
+
+    def to_json(self) -> str:
+        return json.dumps(self.to_payload(), allow_nan=False)
+
 
 @dataclass
 class Segment:
@@ -97,6 +110,7 @@ class SegmentationAnnotation(Annotation):
 
     def to_payload(self) -> dict:
         payload = {
+            TYPE_KEY: MASK_TYPE,
             MASK_URL_KEY: self.mask_url,
             ANNOTATIONS_KEY: [ann.to_payload() for ann in self.annotations],
             ANNOTATION_ID_KEY: self.annotation_id,
@@ -160,11 +174,23 @@ class BoxAnnotation(Annotation): # pylint: disable=R0902
         }
 
 
-# TODO: Add Generic type for 2D point
+@dataclass
+class Point:
+    x: float
+    y: float
+
+    @classmethod
+    def from_json(cls, payload: Dict[str, float]):
+        return cls(payload[X_KEY], payload[Y_KEY])
+
+    def to_payload(self) -> dict:
+        return {X_KEY: self.x, Y_KEY: self.y}
+
+
 @dataclass
 class PolygonAnnotation(Annotation):
     label: str
-    vertices: List[Any]
+    vertices: List[Point]
     reference_id: Optional[str] = None
     item_id: Optional[str] = None
     annotation_id: Optional[str] = None
@@ -173,13 +199,28 @@ class PolygonAnnotation(Annotation):
     def __post_init__(self):
         self._check_ids()
         self.metadata = self.metadata if self.metadata else {}
+        if len(self.vertices) > 0:
+            if not hasattr(self.vertices[0], X_KEY) or not hasattr(
+                self.vertices[0], "to_payload"
+            ):
+                try:
+                    self.vertices = [
+                        Point(x=vertex[X_KEY], y=vertex[Y_KEY])
+                        for vertex in self.vertices
+                    ]
+                except KeyError as ke:
+                    raise ValueError(
+                        "Use a point object to pass in vertices. For example, vertices=[nucleus.Point(x=1, y=2)]"
+                    ) from ke
 
     @classmethod
     def from_json(cls, payload: dict):
         geometry = payload.get(GEOMETRY_KEY, {})
         return cls(
             label=payload.get(LABEL_KEY, 0),
-            vertices=geometry.get(VERTICES_KEY, []),
+            vertices=[
+                Point.from_json(_) for _ in geometry.get(VERTICES_KEY, [])
+            ],
             reference_id=payload.get(REFERENCE_ID_KEY, None),
             item_id=payload.get(DATASET_ITEM_ID_KEY, None),
             annotation_id=payload.get(ANNOTATION_ID_KEY, None),
@@ -187,11 +228,25 @@ class PolygonAnnotation(Annotation):
         )
 
     def to_payload(self) -> dict:
-        return {
+        payload = {
             LABEL_KEY: self.label,
             TYPE_KEY: POLYGON_TYPE,
-            GEOMETRY_KEY: {VERTICES_KEY: self.vertices},
+            GEOMETRY_KEY: {
+                VERTICES_KEY: [_.to_payload() for _ in self.vertices]
+            },
             REFERENCE_ID_KEY: self.reference_id,
             ANNOTATION_ID_KEY: self.annotation_id,
             METADATA_KEY: self.metadata,
         }
+        return payload
+
+
+def check_all_annotation_paths_remote(
+    annotations: Sequence[Union[Annotation]],
+):
+    for annotation in annotations:
+        if hasattr(annotation, MASK_URL_KEY):
+            if is_local_path(getattr(annotation, MASK_URL_KEY)):
+                raise ValueError(
+                    f"Found an annotation with a local path, which cannot be uploaded asynchronously. Use a remote path instead. {annotation}"
+                )
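Net effect of the annotation.py changes: polygon vertices are now typed Point objects (plain {"x": ..., "y": ...} dicts are still coerced in __post_init__), and annotations can serialize themselves via the new to_json(). A minimal usage sketch; the label and ids below are made up:

from nucleus.annotation import Point, PolygonAnnotation

# Vertices are Point objects; to_payload()/to_json() come from the code added above.
polygon = PolygonAnnotation(
    label="car",  # hypothetical label
    vertices=[Point(x=10, y=20), Point(x=30, y=40), Point(x=20, y=50)],
    reference_id="image_1",  # hypothetical reference id
    annotation_id="ann_1",  # hypothetical annotation id
)

print(polygon.to_payload())  # points serialized under geometry -> vertices
print(polygon.to_json())     # json.dumps(..., allow_nan=False) via the Annotation base class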
nucleus/constants.py CHANGED
@@ -1,63 +1,67 @@
-NUCLEUS_ENDPOINT = "https://api.scale.com/v1/nucleus"
-DEFAULT_NETWORK_TIMEOUT_SEC = 120
-ITEMS_KEY = "items"
-ITEM_KEY = "item"
-REFERENCE_ID_KEY = "reference_id"
-REFERENCE_IDS_KEY = "reference_ids"
-DATASET_ID_KEY = "dataset_id"
-IMAGE_KEY = "image"
-IMAGE_URL_KEY = "image_url"
-NEW_ITEMS = "new_items"
-UPDATED_ITEMS = "updated_items"
-IGNORED_ITEMS = "ignored_items"
-ERROR_ITEMS = "upload_errors"
-ERROR_PAYLOAD = "error_payload"
-ERROR_CODES = "error_codes"
+ANNOTATIONS_IGNORED_KEY = "annotations_ignored"
 ANNOTATIONS_KEY = "annotations"
-ANNOTATION_ID_KEY = "annotation_id"
 ANNOTATIONS_PROCESSED_KEY = "annotations_processed"
-ANNOTATIONS_IGNORED_KEY = "annotations_ignored"
-PREDICTIONS_PROCESSED_KEY = "predictions_processed"
-PREDICTIONS_IGNORED_KEY = "predictions_ignored"
+ANNOTATION_ID_KEY = "annotation_id"
+ANNOTATION_METADATA_SCHEMA_KEY = "annotation_metadata_schema"
+BOX_TYPE = "box"
+POLYGON_TYPE = "polygon"
+MASK_TYPE = "mask"
+SEGMENTATION_TYPE = "segmentation"
+ANNOTATION_TYPES = (BOX_TYPE, POLYGON_TYPE, SEGMENTATION_TYPE)
 ANNOTATION_UPDATE_KEY = "update"
-DEFAULT_ANNOTATION_UPDATE_MODE = False
-STATUS_CODE_KEY = "status_code"
-STATUS_KEY = "status"
-SUCCESS_STATUS_CODES = [200, 201, 202]
-ERRORS_KEY = "errors"
-MODEL_RUN_ID_KEY = "model_run_id"
-MODEL_ID_KEY = "model_id"
-DATASET_ITEM_ID_KEY = "dataset_item_id"
-ITEM_ID_KEY = "item_id"
+AUTOTAGS_KEY = "autotags"
+EXPORTED_ROWS = "exportedRows"
+CLASS_PDF_KEY = "class_pdf"
+CONFIDENCE_KEY = "confidence"
+DATASET_ID_KEY = "dataset_id"
 DATASET_ITEM_IDS_KEY = "dataset_item_ids"
-SLICE_ID_KEY = "slice_id"
-DATASET_NAME_KEY = "name"
+DATASET_ITEM_ID_KEY = "dataset_item_id"
+DATASET_LENGTH_KEY = "length"
 DATASET_MODEL_RUNS_KEY = "model_run_ids"
+DATASET_NAME_KEY = "name"
 DATASET_SLICES_KEY = "slice_ids"
-DATASET_LENGTH_KEY = "length"
-FORCE_KEY = "update"
+DEFAULT_ANNOTATION_UPDATE_MODE = False
+DEFAULT_NETWORK_TIMEOUT_SEC = 120
+EMBEDDINGS_URL_KEY = "embeddings_url"
+ERRORS_KEY = "errors"
+ERROR_CODES = "error_codes"
+ERROR_ITEMS = "upload_errors"
+ERROR_PAYLOAD = "error_payload"
+GEOMETRY_KEY = "geometry"
+HEIGHT_KEY = "height"
+IGNORED_ITEMS = "ignored_items"
+IMAGE_KEY = "image"
+IMAGE_URL_KEY = "image_url"
+INDEX_KEY = "index"
+ITEMS_KEY = "items"
+ITEM_ID_KEY = "item_id"
+ITEM_KEY = "item"
+ITEM_METADATA_SCHEMA_KEY = "item_metadata_schema"
+JOB_ID_KEY = "job_id"
+LABEL_KEY = "label"
+MASK_URL_KEY = "mask_url"
+MESSAGE_KEY = "message"
 METADATA_KEY = "metadata"
+MODEL_ID_KEY = "model_id"
+MODEL_RUN_ID_KEY = "model_run_id"
 NAME_KEY = "name"
-LABEL_KEY = "label"
-CONFIDENCE_KEY = "confidence"
+NEW_ITEMS = "new_items"
+NUCLEUS_ENDPOINT = "https://api.scale.com/v1/nucleus"
 ORIGINAL_IMAGE_URL_KEY = "original_image_url"
-X_KEY = "x"
-Y_KEY = "y"
-WIDTH_KEY = "width"
-HEIGHT_KEY = "height"
+PREDICTIONS_IGNORED_KEY = "predictions_ignored"
+PREDICTIONS_PROCESSED_KEY = "predictions_processed"
+REFERENCE_IDS_KEY = "reference_ids"
+REFERENCE_ID_KEY = "reference_id"
+REQUEST_ID_KEY = "requestId"
+SEGMENTATIONS_KEY = "segmentations"
+SLICE_ID_KEY = "slice_id"
+STATUS_CODE_KEY = "status_code"
+STATUS_KEY = "status"
+SUCCESS_STATUS_CODES = [200, 201, 202]
 TYPE_KEY = "type"
+UPDATED_ITEMS = "updated_items"
+UPDATE_KEY = "update"
 VERTICES_KEY = "vertices"
-BOX_TYPE = "box"
-POLYGON_TYPE = "polygon"
-SEGMENTATION_TYPE = "segmentation"
-ANNOTATION_TYPES = (BOX_TYPE, POLYGON_TYPE, SEGMENTATION_TYPE)
-GEOMETRY_KEY = "geometry"
-AUTOTAGS_KEY = "autotags"
-ANNOTATION_METADATA_SCHEMA_KEY = "annotation_metadata_schema"
-ITEM_METADATA_SCHEMA_KEY = "item_metadata_schema"
-MASK_URL_KEY = "mask_url"
-INDEX_KEY = "index"
-SEGMENTATIONS_KEY = "segmentations"
-EMBEDDINGS_URL_KEY = "embeddings_url"
-JOB_ID_KEY = "job_id"
-MESSAGE_KEY = "message"
+WIDTH_KEY = "width"
+X_KEY = "x"
+Y_KEY = "y"
nucleus/dataset.py CHANGED
@@ -1,23 +1,40 @@
-from typing import List, Dict, Any, Optional
+from typing import Any, Dict, List, Optional, Union
 
-from nucleus.utils import format_dataset_item_response
-from .dataset_item import DatasetItem
-from .annotation import (
-    Annotation,
+import requests
+
+from nucleus.job import AsyncJob
+from nucleus.utils import (
+    convert_export_payload,
+    format_dataset_item_response,
+    serialize_and_write_to_presigned_url,
 )
+
+from .annotation import Annotation, check_all_annotation_paths_remote
 from .constants import (
-    DATASET_NAME_KEY,
+    DATASET_ITEM_IDS_KEY,
+    DATASET_LENGTH_KEY,
     DATASET_MODEL_RUNS_KEY,
+    DATASET_NAME_KEY,
     DATASET_SLICES_KEY,
-    DATASET_LENGTH_KEY,
-    DATASET_ITEM_IDS_KEY,
-    REFERENCE_IDS_KEY,
-    NAME_KEY,
     DEFAULT_ANNOTATION_UPDATE_MODE,
+    EXPORTED_ROWS,
+    JOB_ID_KEY,
+    NAME_KEY,
+    REFERENCE_IDS_KEY,
+    REQUEST_ID_KEY,
+    UPDATE_KEY,
+)
+from .dataset_item import (
+    DatasetItem,
+    check_all_paths_remote,
+    check_for_duplicate_reference_ids,
 )
 from .payload_constructor import construct_model_run_creation_payload
 
 
+WARN_FOR_LARGE_UPLOAD = 50000
+
+
 class Dataset:
     """
     Nucleus Dataset. You can append images with metadata to your dataset,
@@ -25,7 +42,11 @@ class Dataset:
     compare model performance on you data.
     """
 
-    def __init__(self, dataset_id: str, client):
+    def __init__(
+        self,
+        dataset_id: str,
+        client: "NucleusClient",  # type:ignore # noqa: F821
+    ):
        self.id = dataset_id
        self._client = client
 
@@ -58,6 +79,25 @@ class Dataset:
     def items(self) -> List[DatasetItem]:
         return self._client.get_dataset_items(self.id)
 
+    def autotag_scores(self, autotag_name, for_scores_greater_than=0):
+        """Export the autotag scores above a threshold, largest scores first.
+
+        If you have pandas installed, you can create a pandas dataframe using
+
+        pandas.Dataframe(dataset.autotag_scores(autotag_name))
+
+        :return: dictionary of the form
+            {'ref_ids': List[str],
+             'datset_item_ids': List[str],
+             'score': List[float]}
+        """
+        response = self._client.make_request(
+            payload={},
+            route=f"autotag/{self.id}/{autotag_name}/{for_scores_greater_than}",
+            requests_command=requests.get,
+        )
+        return response
+
     def info(self) -> dict:
         """
         Returns information about existing dataset
@@ -109,7 +149,8 @@ class Dataset:
         annotations: List[Annotation],
         update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
         batch_size: int = 5000,
-    ) -> dict:
+        asynchronous: bool = False,
+    ) -> Union[Dict[str, Any], AsyncJob]:
         """
         Uploads ground truth annotations for a given dataset.
         :param annotations: ground truth annotations for a given dataset to upload
@@ -122,6 +163,19 @@ class Dataset:
            "ignored_items": int,
        }
        """
+        if asynchronous:
+            check_all_annotation_paths_remote(annotations)
+
+            request_id = serialize_and_write_to_presigned_url(
+                annotations, self.id, self._client
+            )
+            response = self._client.make_request(
+                payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
+                route=f"dataset/{self.id}/annotate?async=1",
+            )
+
+            return AsyncJob(response[JOB_ID_KEY], self._client)
+
        return self._client.annotate_dataset(
            self.id, annotations, update=update, batch_size=batch_size
        )
@@ -140,16 +194,18 @@ class Dataset:
     def append(
         self,
         dataset_items: List[DatasetItem],
-        force: Optional[bool] = False,
+        update: Optional[bool] = False,
         batch_size: Optional[int] = 20,
-    ) -> dict:
+        asynchronous=False,
+    ) -> Union[dict, AsyncJob]:
        """
        Appends images with metadata (dataset items) to the dataset. Overwrites images on collision if forced.
 
        Parameters:
        :param dataset_items: items to upload
-        :param force: if True overwrites images on collision
+        :param update: if True overwrites images and metadata on collision
        :param batch_size: batch parameter for long uploads
+        :param aynchronous: if True, return a job object representing asynchronous ingestion job.
        :return:
        {
            'dataset_id': str,
@@ -158,10 +214,31 @@ class Dataset:
            'new_items': int,
            'ignored_items': int,
        }
+        check_for_duplicate_reference_ids(dataset_items)
+
+        if len(dataset_items) > WARN_FOR_LARGE_UPLOAD and not asynchronous:
+            print(
+                "Tip: for large uploads, get faster performance by importing your data "
+                "into Nucleus directly from a cloud storage provider. See "
+                "https://dashboard.scale.com/nucleus/docs/api?language=python#guide-for-large-ingestions"
+                " for details."
+            )
+
+        if asynchronous:
+            check_all_paths_remote(dataset_items)
+            request_id = serialize_and_write_to_presigned_url(
+                dataset_items, self.id, self._client
+            )
+            response = self._client.make_request(
+                payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
+                route=f"dataset/{self.id}/append?async=1",
+            )
+            return AsyncJob(response["job_id"], self._client)
+
        return self._client.populate_dataset(
            self.id,
            dataset_items,
-            force=force,
+            update=update,
            batch_size=batch_size,
        )
 
@@ -221,9 +298,9 @@ class Dataset:
 
        :return: new Slice object
        """
-        if dataset_item_ids and reference_ids:
+        if bool(dataset_item_ids) == bool(reference_ids):
            raise Exception(
-                "You cannot specify both dataset_item_ids and reference_ids"
+                "You must specify exactly one of dataset_item_ids or reference_ids."
            )
        payload: Dict[str, Any] = {NAME_KEY: name}
        if dataset_item_ids:
@@ -252,3 +329,23 @@ class Dataset:
 
     def check_index_status(self, job_id: str):
         return self._client.check_index_status(job_id)
+
+    def items_and_annotations(
+        self,
+    ) -> List[Dict[str, Union[DatasetItem, Dict[str, List[Annotation]]]]]:
+        """Returns a list of all DatasetItems and Annotations in this slice.
+
+        Returns:
+            A list, where each item is a dict with two keys representing a row
+            in the dataset.
+            * One value in the dict is the DatasetItem, containing a reference to the
+                item that was annotated.
+            * The other value is a dictionary containing all the annotations for this
+                dataset item, sorted by annotation type.
+        """
+        api_payload = self._client.make_request(
+            payload=None,
+            route=f"dataset/{self.id}/exportForTraining",
+            requests_command=requests.get,
+        )
+        return convert_export_payload(api_payload[EXPORTED_ROWS])
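The asynchronous branches of append() and annotate() both serialize the payload to a presigned URL and return an AsyncJob instead of a blocking response. A rough sketch of the calling pattern, assuming the usual NucleusClient entry point; the API key, dataset id, and item below are placeholders:

import nucleus
from nucleus.dataset_item import DatasetItem

client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")  # placeholder key
dataset = client.get_dataset("ds_placeholder_id")  # placeholder dataset id

# Async ingestion requires remote URLs; local paths fail the
# check_all_paths_remote / check_all_annotation_paths_remote validation.
items = [
    DatasetItem(
        image_location="s3://my-bucket/img_0.jpg",  # placeholder URL
        reference_id="img_0",
    )
]

job = dataset.append(items, update=True, asynchronous=True)
job.sleep_until_complete()  # polls job/{id} until it stops reporting "Running"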
nucleus/dataset_item.py CHANGED
@@ -1,12 +1,16 @@
-from dataclasses import dataclass
+from collections import Counter
+import json
 import os.path
-from typing import Optional
+from dataclasses import dataclass
+from typing import Optional, Sequence
+from urllib.parse import urlparse
+
 from .constants import (
+    DATASET_ITEM_ID_KEY,
     IMAGE_URL_KEY,
     METADATA_KEY,
-    REFERENCE_ID_KEY,
     ORIGINAL_IMAGE_URL_KEY,
-    DATASET_ITEM_ID_KEY,
+    REFERENCE_ID_KEY,
 )
 
 
@@ -19,8 +23,7 @@ class DatasetItem:
     metadata: Optional[dict] = None
 
     def __post_init__(self):
-        self.image_url = self.image_location
-        self.local = self._is_local_path(self.image_location)
+        self.local = is_local_path(self.image_location)
 
     @classmethod
     def from_json(cls, payload: dict):
@@ -34,20 +37,12 @@ class DatasetItem:
             metadata=payload.get(METADATA_KEY, {}),
         )
 
-    def _is_local_path(self, path: str) -> bool:
-        path_components = [comp.lower() for comp in path.split("/")]
-        return not (
-            "https:" in path_components
-            or "http:" in path_components
-            or "s3:" in path_components
-        )
-
     def local_file_exists(self):
-        return os.path.isfile(self.image_url)
+        return os.path.isfile(self.image_location)
 
     def to_payload(self) -> dict:
         payload = {
-            IMAGE_URL_KEY: self.image_url,
+            IMAGE_URL_KEY: self.image_location,
             METADATA_KEY: self.metadata or {},
         }
         if self.reference_id:
@@ -55,3 +50,35 @@ class DatasetItem:
         if self.item_id:
             payload[DATASET_ITEM_ID_KEY] = self.item_id
         return payload
+
+    def to_json(self) -> str:
+        return json.dumps(self.to_payload(), allow_nan=False)
+
+
+def is_local_path(path: str) -> bool:
+    return urlparse(path).scheme not in {"https", "http", "s3", "gs"}
+
+
+def check_all_paths_remote(dataset_items: Sequence[DatasetItem]):
+    for item in dataset_items:
+        if is_local_path(item.image_location):
+            raise ValueError(
+                f"All paths must be remote, but {item.image_location} is either "
+                "local, or a remote URL type that is not supported."
+            )
+
+
+def check_for_duplicate_reference_ids(dataset_items: Sequence[DatasetItem]):
+    ref_ids = []
+    for dataset_item in dataset_items:
+        if dataset_item.reference_id is not None:
+            ref_ids.append(dataset_item.reference_id)
+    if len(ref_ids) != len(set(ref_ids)):
+        duplicates = {
+            f"{key}": f"Count: {value}"
+            for key, value in Counter(ref_ids).items()
+        }
+        raise ValueError(
+            "Duplicate reference ids found among dataset_items: %s"
+            % duplicates
+        )
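is_local_path() now keys off the URL scheme (picking up gs:// support) instead of substring matching, and duplicate reference ids are rejected before upload. A small illustration of both helpers; the paths and ids are invented:

from nucleus.dataset_item import (
    DatasetItem,
    check_for_duplicate_reference_ids,
    is_local_path,
)

print(is_local_path("s3://bucket/image.jpg"))  # False: "s3" is a remote scheme
print(is_local_path("gs://bucket/image.jpg"))  # False: "gs" is now accepted too
print(is_local_path("/tmp/image.jpg"))         # True: no scheme, treated as local

items = [
    DatasetItem(image_location="s3://bucket/a.jpg", reference_id="dup"),
    DatasetItem(image_location="s3://bucket/b.jpg", reference_id="dup"),
]
try:
    check_for_duplicate_reference_ids(items)
except ValueError as err:
    print(err)  # Duplicate reference ids found among dataset_items: {'dup': 'Count: 2'}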
nucleus/errors.py CHANGED
@@ -22,3 +22,25 @@ class DatasetItemRetrievalError(Exception):
     def __init__(self, message="Could not retrieve dataset items"):
         self.message = message
         super().__init__(self.message)
+
+
+class NucleusAPIError(Exception):
+    def __init__(
+        self, endpoint, command, requests_response=None, aiohttp_response=None
+    ):
+
+        if requests_response is not None:
+            message = f"Tried to {command.__name__} {endpoint}, but received {requests_response.status_code}: {requests_response.reason}."
+            if hasattr(requests_response, "text"):
+                if requests_response.text:
+                    message += (
+                        f"\nThe detailed error is:\n{requests_response.text}"
+                    )
+
+        if aiohttp_response is not None:
+            status, reason, data = aiohttp_response
+            message = f"Tried to {command.__name__} {endpoint}, but received {status}: {reason}."
+            if data:
+                message += f"\nThe detailed error is:\n{data}"
+
+        super().__init__(message)
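NucleusAPIError builds its message from either a requests response or a (status, reason, data) tuple from aiohttp. A contrived, offline sketch of the requests path; the response object is hand-built here purely for illustration:

import requests

from nucleus.errors import NucleusAPIError

# Stand-in for a response the Nucleus client would normally receive.
response = requests.Response()
response.status_code = 502
response.reason = "Bad Gateway"
response.encoding = "utf-8"
response._content = b"upstream timeout"  # private field, set only for this sketch

error = NucleusAPIError("dataset/ds_123/append", requests.post, requests_response=response)
print(error)
# Tried to post dataset/ds_123/append, but received 502: Bad Gateway.
# The detailed error is:
# upstream timeout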
nucleus/job.py ADDED
@@ -0,0 +1,56 @@
+from dataclasses import dataclass
+import time
+from typing import Dict, List
+
+import requests
+
+JOB_POLLING_INTERVAL = 5
+
+
+@dataclass
+class AsyncJob:
+    id: str
+    client: "NucleusClient"  # type: ignore # noqa: F821
+
+    def status(self) -> Dict[str, str]:
+        return self.client.make_request(
+            payload={},
+            route=f"job/{self.id}",
+            requests_command=requests.get,
+        )
+
+    def errors(self) -> List[str]:
+        return self.client.make_request(
+            payload={},
+            route=f"job/{self.id}/errors",
+            requests_command=requests.get,
+        )
+
+    def sleep_until_complete(self, verbose_std_out=True):
+        while 1:
+            status = self.status()
+
+            time.sleep(JOB_POLLING_INTERVAL)
+
+            if verbose_std_out:
+                print(f"Status at {time.ctime()}: {status}")
+            if status["status"] == "Running":
+                continue
+            break
+
+        final_status = status
+        if final_status["status"] == "Errored":
+            raise JobError(final_status, self)
+
+
+class JobError(Exception):
+    def __init__(self, job_status: Dict[str, str], job: AsyncJob):
+        final_status_message = job_status["message"]
+        final_status = job_status["status"]
+        message = (
+            f"The job reported a final status of {final_status} "
+            "This could, however, mean a partial success with some successes and some failures. "
+            f"The final status message was: {final_status_message} \n"
+            f"For more detailed error messages you can call {str(job)}.errors()"
+        )
+        super().__init__(message)
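sleep_until_complete() raises JobError when the final status is "Errored", and errors() fetches the per-item error strings from job/{id}/errors. A small wrapper sketch showing how calling code might consume this; the helper name is made up:

from nucleus.job import AsyncJob, JobError


def wait_and_report(job: AsyncJob) -> None:
    """Hypothetical helper: block until the job finishes, then surface any errors."""
    try:
        job.sleep_until_complete(verbose_std_out=False)  # polls every JOB_POLLING_INTERVAL seconds
    except JobError as exc:
        print(exc)  # final status plus the server's status message
        for line in job.errors():  # detailed per-item errors
            print(line)

Because AsyncJob is a plain dataclass holding only the job id and the client, a handle to a previously started job can also be reconstructed later as AsyncJob(job_id, client).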
nucleus/model.py CHANGED
@@ -45,6 +45,7 @@ class Model:
             Union[BoxPrediction, PolygonPrediction, SegmentationPrediction]
         ],
         metadata: Optional[Dict] = None,
+        asynchronous: bool = False,
     ) -> ModelRun:
         payload: dict = {
             NAME_KEY: name,
@@ -56,6 +57,6 @@ class Model:
             dataset.id, payload
         )
 
-        model_run.predict(predictions)
+        model_run.predict(predictions, asynchronous=asynchronous)
 
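With asynchronous forwarded to model_run.predict(), the predictions upload kicked off by create_run() can use the same async path as dataset ingestion. A sketch under the assumption that NucleusClient.get_dataset / add_model and the usual BoxPrediction fields behave as elsewhere in the library; every name and value below is a placeholder:

import nucleus
from nucleus.prediction import BoxPrediction

client = nucleus.NucleusClient("YOUR_SCALE_API_KEY")  # placeholder key
dataset = client.get_dataset("ds_placeholder_id")  # placeholder dataset id
model = client.add_model(name="my-model", reference_id="my-model-ref")  # assumed client helper

predictions = [
    BoxPrediction(
        label="car", x=10, y=10, width=50, height=30,
        reference_id="img_0", confidence=0.9,  # placeholder prediction
    )
]

model_run = model.create_run(
    "run-2021-05",  # placeholder run name
    dataset,
    predictions,
    asynchronous=True,  # forwarded to model_run.predict(...)
)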
  return model_run