scale-nucleus 0.1.3__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nucleus/constants.py CHANGED
@@ -1,63 +1,67 @@
- NUCLEUS_ENDPOINT = "https://api.scale.com/v1/nucleus"
- DEFAULT_NETWORK_TIMEOUT_SEC = 120
- ITEMS_KEY = "items"
- ITEM_KEY = "item"
- REFERENCE_ID_KEY = "reference_id"
- REFERENCE_IDS_KEY = "reference_ids"
- DATASET_ID_KEY = "dataset_id"
- IMAGE_KEY = "image"
- IMAGE_URL_KEY = "image_url"
- NEW_ITEMS = "new_items"
- UPDATED_ITEMS = "updated_items"
- IGNORED_ITEMS = "ignored_items"
- ERROR_ITEMS = "upload_errors"
- ERROR_PAYLOAD = "error_payload"
- ERROR_CODES = "error_codes"
+ ANNOTATIONS_IGNORED_KEY = "annotations_ignored"
  ANNOTATIONS_KEY = "annotations"
- ANNOTATION_ID_KEY = "annotation_id"
  ANNOTATIONS_PROCESSED_KEY = "annotations_processed"
- ANNOTATIONS_IGNORED_KEY = "annotations_ignored"
- PREDICTIONS_PROCESSED_KEY = "predictions_processed"
- PREDICTIONS_IGNORED_KEY = "predictions_ignored"
+ ANNOTATION_ID_KEY = "annotation_id"
+ ANNOTATION_METADATA_SCHEMA_KEY = "annotation_metadata_schema"
+ BOX_TYPE = "box"
+ POLYGON_TYPE = "polygon"
+ MASK_TYPE = "mask"
+ SEGMENTATION_TYPE = "segmentation"
+ ANNOTATION_TYPES = (BOX_TYPE, POLYGON_TYPE, SEGMENTATION_TYPE)
  ANNOTATION_UPDATE_KEY = "update"
- DEFAULT_ANNOTATION_UPDATE_MODE = False
- STATUS_CODE_KEY = "status_code"
- STATUS_KEY = "status"
- SUCCESS_STATUS_CODES = [200, 201, 202]
- ERRORS_KEY = "errors"
- MODEL_RUN_ID_KEY = "model_run_id"
- MODEL_ID_KEY = "model_id"
- DATASET_ITEM_ID_KEY = "dataset_item_id"
- ITEM_ID_KEY = "item_id"
+ AUTOTAGS_KEY = "autotags"
+ EXPORTED_ROWS = "exportedRows"
+ CLASS_PDF_KEY = "class_pdf"
+ CONFIDENCE_KEY = "confidence"
+ DATASET_ID_KEY = "dataset_id"
  DATASET_ITEM_IDS_KEY = "dataset_item_ids"
- SLICE_ID_KEY = "slice_id"
- DATASET_NAME_KEY = "name"
+ DATASET_ITEM_ID_KEY = "dataset_item_id"
+ DATASET_LENGTH_KEY = "length"
  DATASET_MODEL_RUNS_KEY = "model_run_ids"
+ DATASET_NAME_KEY = "name"
  DATASET_SLICES_KEY = "slice_ids"
- DATASET_LENGTH_KEY = "length"
- FORCE_KEY = "update"
+ DEFAULT_ANNOTATION_UPDATE_MODE = False
+ DEFAULT_NETWORK_TIMEOUT_SEC = 120
+ EMBEDDINGS_URL_KEY = "embeddings_url"
+ ERRORS_KEY = "errors"
+ ERROR_CODES = "error_codes"
+ ERROR_ITEMS = "upload_errors"
+ ERROR_PAYLOAD = "error_payload"
+ GEOMETRY_KEY = "geometry"
+ HEIGHT_KEY = "height"
+ IGNORED_ITEMS = "ignored_items"
+ IMAGE_KEY = "image"
+ IMAGE_URL_KEY = "image_url"
+ INDEX_KEY = "index"
+ ITEMS_KEY = "items"
+ ITEM_ID_KEY = "item_id"
+ ITEM_KEY = "item"
+ ITEM_METADATA_SCHEMA_KEY = "item_metadata_schema"
+ JOB_ID_KEY = "job_id"
+ LABEL_KEY = "label"
+ MASK_URL_KEY = "mask_url"
+ MESSAGE_KEY = "message"
  METADATA_KEY = "metadata"
+ MODEL_ID_KEY = "model_id"
+ MODEL_RUN_ID_KEY = "model_run_id"
  NAME_KEY = "name"
- LABEL_KEY = "label"
- CONFIDENCE_KEY = "confidence"
+ NEW_ITEMS = "new_items"
+ NUCLEUS_ENDPOINT = "https://api.scale.com/v1/nucleus"
  ORIGINAL_IMAGE_URL_KEY = "original_image_url"
- X_KEY = "x"
- Y_KEY = "y"
- WIDTH_KEY = "width"
- HEIGHT_KEY = "height"
+ PREDICTIONS_IGNORED_KEY = "predictions_ignored"
+ PREDICTIONS_PROCESSED_KEY = "predictions_processed"
+ REFERENCE_IDS_KEY = "reference_ids"
+ REFERENCE_ID_KEY = "reference_id"
+ REQUEST_ID_KEY = "requestId"
+ SEGMENTATIONS_KEY = "segmentations"
+ SLICE_ID_KEY = "slice_id"
+ STATUS_CODE_KEY = "status_code"
+ STATUS_KEY = "status"
+ SUCCESS_STATUS_CODES = [200, 201, 202]
  TYPE_KEY = "type"
+ UPDATED_ITEMS = "updated_items"
+ UPDATE_KEY = "update"
  VERTICES_KEY = "vertices"
- BOX_TYPE = "box"
- POLYGON_TYPE = "polygon"
- SEGMENTATION_TYPE = "segmentation"
- ANNOTATION_TYPES = (BOX_TYPE, POLYGON_TYPE, SEGMENTATION_TYPE)
- GEOMETRY_KEY = "geometry"
- AUTOTAGS_KEY = "autotags"
- ANNOTATION_METADATA_SCHEMA_KEY = "annotation_metadata_schema"
- ITEM_METADATA_SCHEMA_KEY = "item_metadata_schema"
- MASK_URL_KEY = "mask_url"
- INDEX_KEY = "index"
- SEGMENTATIONS_KEY = "segmentations"
- EMBEDDINGS_URL_KEY = "embeddings_url"
- JOB_ID_KEY = "job_id"
- MESSAGE_KEY = "message"
+ WIDTH_KEY = "width"
+ X_KEY = "x"
+ Y_KEY = "y"
nucleus/dataset.py CHANGED
@@ -1,10 +1,15 @@
- from typing import Any, Dict, List, Optional
+ from typing import Any, Dict, List, Optional, Union

  import requests

- from nucleus.utils import format_dataset_item_response
+ from nucleus.job import AsyncJob
+ from nucleus.utils import (
+ convert_export_payload,
+ format_dataset_item_response,
+ serialize_and_write_to_presigned_url,
+ )

- from .annotation import Annotation
+ from .annotation import Annotation, check_all_annotation_paths_remote
  from .constants import (
  DATASET_ITEM_IDS_KEY,
  DATASET_LENGTH_KEY,
@@ -12,13 +17,24 @@ from .constants import (
  DATASET_NAME_KEY,
  DATASET_SLICES_KEY,
  DEFAULT_ANNOTATION_UPDATE_MODE,
+ EXPORTED_ROWS,
+ JOB_ID_KEY,
  NAME_KEY,
  REFERENCE_IDS_KEY,
+ REQUEST_ID_KEY,
+ UPDATE_KEY,
+ )
+ from .dataset_item import (
+ DatasetItem,
+ check_all_paths_remote,
+ check_for_duplicate_reference_ids,
  )
- from .dataset_item import DatasetItem
  from .payload_constructor import construct_model_run_creation_payload


+ WARN_FOR_LARGE_UPLOAD = 50000
+
+
  class Dataset:
  """
  Nucleus Dataset. You can append images with metadata to your dataset,
@@ -26,7 +42,11 @@ class Dataset:
  compare model performance on you data.
  """

- def __init__(self, dataset_id: str, client):
+ def __init__(
+ self,
+ dataset_id: str,
+ client: "NucleusClient", # type:ignore # noqa: F821
+ ):
  self.id = dataset_id
  self._client = client

@@ -129,7 +149,8 @@ class Dataset:
  annotations: List[Annotation],
  update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
  batch_size: int = 5000,
- ) -> dict:
+ asynchronous: bool = False,
+ ) -> Union[Dict[str, Any], AsyncJob]:
  """
  Uploads ground truth annotations for a given dataset.
  :param annotations: ground truth annotations for a given dataset to upload
@@ -142,6 +163,19 @@ class Dataset:
  "ignored_items": int,
  }
  """
+ if asynchronous:
+ check_all_annotation_paths_remote(annotations)
+
+ request_id = serialize_and_write_to_presigned_url(
+ annotations, self.id, self._client
+ )
+ response = self._client.make_request(
+ payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
+ route=f"dataset/{self.id}/annotate?async=1",
+ )
+
+ return AsyncJob(response[JOB_ID_KEY], self._client)
+
  return self._client.annotate_dataset(
  self.id, annotations, update=update, batch_size=batch_size
  )
@@ -160,16 +194,18 @@ class Dataset:
  def append(
  self,
  dataset_items: List[DatasetItem],
- force: Optional[bool] = False,
+ update: Optional[bool] = False,
  batch_size: Optional[int] = 20,
- ) -> dict:
+ asynchronous=False,
+ ) -> Union[dict, AsyncJob]:
  """
  Appends images with metadata (dataset items) to the dataset. Overwrites images on collision if forced.

  Parameters:
  :param dataset_items: items to upload
- :param force: if True overwrites images on collision
+ :param update: if True overwrites images and metadata on collision
  :param batch_size: batch parameter for long uploads
+ :param aynchronous: if True, return a job object representing asynchronous ingestion job.
  :return:
  {
  'dataset_id': str,
@@ -178,10 +214,31 @@ class Dataset:
  'ignored_items': int,
  }
  """
+ check_for_duplicate_reference_ids(dataset_items)
+
+ if len(dataset_items) > WARN_FOR_LARGE_UPLOAD and not asynchronous:
+ print(
+ "Tip: for large uploads, get faster performance by importing your data "
+ "into Nucleus directly from a cloud storage provider. See "
+ "https://dashboard.scale.com/nucleus/docs/api?language=python#guide-for-large-ingestions"
+ " for details."
+ )
+
+ if asynchronous:
+ check_all_paths_remote(dataset_items)
+ request_id = serialize_and_write_to_presigned_url(
+ dataset_items, self.id, self._client
+ )
+ response = self._client.make_request(
+ payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
+ route=f"dataset/{self.id}/append?async=1",
+ )
+ return AsyncJob(response["job_id"], self._client)
+
  return self._client.populate_dataset(
  self.id,
  dataset_items,
- force=force,
+ update=update,
  batch_size=batch_size,
  )

@@ -272,3 +329,23 @@ class Dataset:

  def check_index_status(self, job_id: str):
  return self._client.check_index_status(job_id)
+
+ def items_and_annotations(
+ self,
+ ) -> List[Dict[str, Union[DatasetItem, Dict[str, List[Annotation]]]]]:
+ """Returns a list of all DatasetItems and Annotations in this slice.
+
+ Returns:
+ A list, where each item is a dict with two keys representing a row
+ in the dataset.
+ * One value in the dict is the DatasetItem, containing a reference to the
+ item that was annotated.
+ * The other value is a dictionary containing all the annotations for this
+ dataset item, sorted by annotation type.
+ """
+ api_payload = self._client.make_request(
+ payload=None,
+ route=f"dataset/{self.id}/exportForTraining",
+ requests_command=requests.get,
+ )
+ return convert_export_payload(api_payload[EXPORTED_ROWS])
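For reference, a minimal usage sketch of the asynchronous ingestion and export paths added in this file. The client handle, dataset ID, and image paths below are illustrative, not taken from this diff:

```python
from nucleus import NucleusClient
from nucleus.dataset_item import DatasetItem

client = NucleusClient("YOUR_SCALE_API_KEY")        # hypothetical credentials
dataset = client.get_dataset("ds_sample_dataset")   # hypothetical dataset ID

items = [
    DatasetItem(
        image_location="s3://my-bucket/images/0001.jpg",  # async uploads require remote paths
        reference_id="img_0001",
    ),
]

# With asynchronous=True, append returns an AsyncJob instead of a response dict.
job = dataset.append(items, update=True, asynchronous=True)
job.sleep_until_complete()

# Export every item together with its annotations, grouped by annotation type.
rows = dataset.items_and_annotations()
```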
nucleus/dataset_item.py CHANGED
@@ -1,7 +1,9 @@
+ from collections import Counter
  import json
  import os.path
  from dataclasses import dataclass
- from typing import Optional
+ from typing import Optional, Sequence
+ from urllib.parse import urlparse

  from .constants import (
  DATASET_ITEM_ID_KEY,
@@ -21,8 +23,7 @@ class DatasetItem:
  metadata: Optional[dict] = None

  def __post_init__(self):
- self.image_url = self.image_location
- self.local = self._is_local_path(self.image_location)
+ self.local = is_local_path(self.image_location)

  @classmethod
  def from_json(cls, payload: dict):
@@ -36,16 +37,12 @@ class DatasetItem:
  metadata=payload.get(METADATA_KEY, {}),
  )

- def _is_local_path(self, path: str) -> bool:
- path_components = [comp.lower() for comp in path.split("/")]
- return path_components[0] not in {"https:", "http:", "s3:", "gs:"}
-
  def local_file_exists(self):
- return os.path.isfile(self.image_url)
+ return os.path.isfile(self.image_location)

  def to_payload(self) -> dict:
  payload = {
- IMAGE_URL_KEY: self.image_url,
+ IMAGE_URL_KEY: self.image_location,
  METADATA_KEY: self.metadata or {},
  }
  if self.reference_id:
@@ -55,4 +52,33 @@ class DatasetItem:
  return payload

  def to_json(self) -> str:
- return json.dumps(self.to_payload())
+ return json.dumps(self.to_payload(), allow_nan=False)
+
+
+ def is_local_path(path: str) -> bool:
+ return urlparse(path).scheme not in {"https", "http", "s3", "gs"}
+
+
+ def check_all_paths_remote(dataset_items: Sequence[DatasetItem]):
+ for item in dataset_items:
+ if is_local_path(item.image_location):
+ raise ValueError(
+ f"All paths must be remote, but {item.image_location} is either "
+ "local, or a remote URL type that is not supported."
+ )
+
+
+ def check_for_duplicate_reference_ids(dataset_items: Sequence[DatasetItem]):
+ ref_ids = []
+ for dataset_item in dataset_items:
+ if dataset_item.reference_id is not None:
+ ref_ids.append(dataset_item.reference_id)
+ if len(ref_ids) != len(set(ref_ids)):
+ duplicates = {
+ f"{key}": f"Count: {value}"
+ for key, value in Counter(ref_ids).items()
+ }
+ raise ValueError(
+ "Duplicate reference ids found among dataset_items: %s"
+ % duplicates
+ )
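A small sketch of what the new client-side checks enforce; the paths and reference IDs below are made up:

```python
from nucleus.dataset_item import (
    DatasetItem,
    check_all_paths_remote,
    check_for_duplicate_reference_ids,
    is_local_path,
)

assert is_local_path("/tmp/cat.jpg")                 # no recognized remote scheme
assert not is_local_path("s3://my-bucket/cat.jpg")   # http/https/s3/gs count as remote

items = [
    DatasetItem(image_location="s3://my-bucket/a.jpg", reference_id="frame_1"),
    DatasetItem(image_location="s3://my-bucket/b.jpg", reference_id="frame_1"),
]

check_all_paths_remote(items)             # passes: both locations are remote
check_for_duplicate_reference_ids(items)  # raises ValueError: "frame_1" is used twice
```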
nucleus/errors.py CHANGED
@@ -25,9 +25,22 @@ class DatasetItemRetrievalError(Exception):


  class NucleusAPIError(Exception):
- def __init__(self, endpoint, command, response):
- message = f"Tried to {command.__name__} {endpoint}, but received {response.status_code}: {response.reason}."
- if hasattr(response, "text"):
- if response.text:
- message += f"\nThe detailed error is:\n{response.text}"
+ def __init__(
+ self, endpoint, command, requests_response=None, aiohttp_response=None
+ ):
+
+ if requests_response is not None:
+ message = f"Tried to {command.__name__} {endpoint}, but received {requests_response.status_code}: {requests_response.reason}."
+ if hasattr(requests_response, "text"):
+ if requests_response.text:
+ message += (
+ f"\nThe detailed error is:\n{requests_response.text}"
+ )
+
+ if aiohttp_response is not None:
+ status, reason, data = aiohttp_response
+ message = f"Tried to {command.__name__} {endpoint}, but received {status}: {reason}."
+ if data:
+ message += f"\nThe detailed error is:\n{data}"
+
  super().__init__(message)
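The reworked constructor accepts either a requests response or an already unpacked aiohttp result; a sketch of the aiohttp path, with made-up endpoint and payload values:

```python
import requests
from nucleus.errors import NucleusAPIError

# aiohttp_response is expected as a (status, reason, data) tuple.
error = NucleusAPIError(
    endpoint="dataset/ds_sample_dataset/append?async=1",  # hypothetical endpoint
    command=requests.post,
    aiohttp_response=(400, "Bad Request", '{"error": "malformed payload"}'),
)
print(error)  # "Tried to post ..., but received 400: Bad Request" plus the detail text
```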
nucleus/job.py ADDED
@@ -0,0 +1,56 @@
+ from dataclasses import dataclass
+ import time
+ from typing import Dict, List
+
+ import requests
+
+ JOB_POLLING_INTERVAL = 5
+
+
+ @dataclass
+ class AsyncJob:
+ id: str
+ client: "NucleusClient" # type: ignore # noqa: F821
+
+ def status(self) -> Dict[str, str]:
+ return self.client.make_request(
+ payload={},
+ route=f"job/{self.id}",
+ requests_command=requests.get,
+ )
+
+ def errors(self) -> List[str]:
+ return self.client.make_request(
+ payload={},
+ route=f"job/{self.id}/errors",
+ requests_command=requests.get,
+ )
+
+ def sleep_until_complete(self, verbose_std_out=True):
+ while 1:
+ status = self.status()
+
+ time.sleep(JOB_POLLING_INTERVAL)
+
+ if verbose_std_out:
+ print(f"Status at {time.ctime()}: {status}")
+ if status["status"] == "Running":
+ continue
+ break
+
+ final_status = status
+ if final_status["status"] == "Errored":
+ raise JobError(final_status, self)
+
+
+ class JobError(Exception):
+ def __init__(self, job_status: Dict[str, str], job: AsyncJob):
+ final_status_message = job_status["message"]
+ final_status = job_status["status"]
+ message = (
+ f"The job reported a final status of {final_status} "
+ "This could, however, mean a partial success with some successes and some failures. "
+ f"The final status message was: {final_status_message} \n"
+ f"For more detailed error messages you can call {str(job)}.errors()"
+ )
+ super().__init__(message)
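A short sketch of how callers are expected to use the new AsyncJob and JobError pair; the job value is assumed to come from one of the asynchronous=True calls elsewhere in this release:

```python
from nucleus.job import AsyncJob, JobError

def wait_for(job: AsyncJob) -> None:
    try:
        # Polls job/{id} every JOB_POLLING_INTERVAL seconds until the status leaves "Running".
        job.sleep_until_complete(verbose_std_out=False)
    except JobError:
        # A final status of "Errored" can still mean a partial success;
        # job.errors() returns the individual error messages.
        for message in job.errors():
            print(message)
```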
nucleus/model.py CHANGED
@@ -45,6 +45,7 @@ class Model:
  Union[BoxPrediction, PolygonPrediction, SegmentationPrediction]
  ],
  metadata: Optional[Dict] = None,
+ asynchronous: bool = False,
  ) -> ModelRun:
  payload: dict = {
  NAME_KEY: name,
@@ -56,6 +57,6 @@ class Model:
  dataset.id, payload
  )

- model_run.predict(predictions)
+ model_run.predict(predictions, asynchronous=asynchronous)

  return model_run
nucleus/model_run.py CHANGED
@@ -1,10 +1,18 @@
- from typing import Dict, Optional, List, Union, Type
+ from typing import Dict, List, Optional, Type, Union
+
+ from nucleus.annotation import check_all_annotation_paths_remote
+ from nucleus.job import AsyncJob
+ from nucleus.utils import serialize_and_write_to_presigned_url
+
  from .constants import (
  ANNOTATIONS_KEY,
- DEFAULT_ANNOTATION_UPDATE_MODE,
  BOX_TYPE,
+ DEFAULT_ANNOTATION_UPDATE_MODE,
+ JOB_ID_KEY,
  POLYGON_TYPE,
+ REQUEST_ID_KEY,
  SEGMENTATION_TYPE,
+ UPDATE_KEY,
  )
  from .prediction import (
  BoxPrediction,
@@ -19,12 +27,13 @@ class ModelRun:
  Having an open model run is a prerequisite for uploading predictions to your dataset.
  """

- def __init__(self, model_run_id: str, client):
+ def __init__(self, model_run_id: str, dataset_id: str, client):
  self.model_run_id = model_run_id
  self._client = client
+ self._dataset_id = dataset_id

  def __repr__(self):
- return f"ModelRun(model_run_id='{self.model_run_id}', client={self._client})"
+ return f"ModelRun(model_run_id='{self.model_run_id}', dataset_id='{self._dataset_id}', client={self._client})"

  def __eq__(self, other):
  if self.model_run_id == other.model_run_id:
@@ -84,7 +93,8 @@ class ModelRun:
  Union[BoxPrediction, PolygonPrediction, SegmentationPrediction]
  ],
  update: Optional[bool] = DEFAULT_ANNOTATION_UPDATE_MODE,
- ) -> dict:
+ asynchronous: bool = False,
+ ) -> Union[dict, AsyncJob]:
  """
  Uploads model outputs as predictions for a model_run. Returns info about the upload.
  :param annotations: List[Union[BoxPrediction, PolygonPrediction]],
@@ -95,7 +105,20 @@ class ModelRun:
  "predictions_ignored": int,
  }
  """
- return self._client.predict(self.model_run_id, annotations, update)
+ if asynchronous:
+ check_all_annotation_paths_remote(annotations)
+
+ request_id = serialize_and_write_to_presigned_url(
+ annotations, self._dataset_id, self._client
+ )
+ response = self._client.make_request(
+ payload={REQUEST_ID_KEY: request_id, UPDATE_KEY: update},
+ route=f"modelRun/{self.model_run_id}/predict?async=1",
+ )
+
+ return AsyncJob(response[JOB_ID_KEY], self._client)
+ else:
+ return self._client.predict(self.model_run_id, annotations, update)

  def iloc(self, i: int):
  """
nucleus/payload_constructor.py CHANGED
@@ -17,7 +17,7 @@ from .constants import (
  REFERENCE_ID_KEY,
  ANNOTATIONS_KEY,
  ITEMS_KEY,
- FORCE_KEY,
+ UPDATE_KEY,
  MODEL_ID_KEY,
  ANNOTATION_METADATA_SCHEMA_KEY,
  SEGMENTATIONS_KEY,
@@ -34,7 +34,7 @@ def construct_append_payload(
  return (
  {ITEMS_KEY: items}
  if not force
- else {ITEMS_KEY: items, FORCE_KEY: True}
+ else {ITEMS_KEY: items, UPDATE_KEY: True}
  )

nucleus/prediction.py CHANGED
@@ -1,6 +1,7 @@
- from typing import Dict, Optional, List, Any
+ from typing import Dict, Optional, List
  from .annotation import (
  BoxAnnotation,
+ Point,
  PolygonAnnotation,
  Segment,
  SegmentationAnnotation,
@@ -16,6 +17,7 @@ from .constants import (
  Y_KEY,
  WIDTH_KEY,
  HEIGHT_KEY,
+ CLASS_PDF_KEY,
  CONFIDENCE_KEY,
  VERTICES_KEY,
  ANNOTATIONS_KEY,
@@ -54,6 +56,7 @@ class BoxPrediction(BoxAnnotation):
  confidence: Optional[float] = None,
  annotation_id: Optional[str] = None,
  metadata: Optional[Dict] = None,
+ class_pdf: Optional[Dict] = None,
  ):
  super().__init__(
  label=label,
@@ -67,11 +70,14 @@ class BoxPrediction(BoxAnnotation):
  metadata=metadata,
  )
  self.confidence = confidence
+ self.class_pdf = class_pdf

  def to_payload(self) -> dict:
  payload = super().to_payload()
  if self.confidence is not None:
  payload[CONFIDENCE_KEY] = self.confidence
+ if self.class_pdf is not None:
+ payload[CLASS_PDF_KEY] = self.class_pdf

  return payload

@@ -89,6 +95,7 @@ class BoxPrediction(BoxAnnotation):
  confidence=payload.get(CONFIDENCE_KEY, None),
  annotation_id=payload.get(ANNOTATION_ID_KEY, None),
  metadata=payload.get(METADATA_KEY, {}),
+ class_pdf=payload.get(CLASS_PDF_KEY, None),
  )


@@ -96,12 +103,13 @@ class PolygonPrediction(PolygonAnnotation):
  def __init__(
  self,
  label: str,
- vertices: List[Any],
+ vertices: List[Point],
  reference_id: Optional[str] = None,
  item_id: Optional[str] = None,
  confidence: Optional[float] = None,
  annotation_id: Optional[str] = None,
  metadata: Optional[Dict] = None,
+ class_pdf: Optional[Dict] = None,
  ):
  super().__init__(
  label=label,
@@ -112,11 +120,14 @@ class PolygonPrediction(PolygonAnnotation):
  metadata=metadata,
  )
  self.confidence = confidence
+ self.class_pdf = class_pdf

  def to_payload(self) -> dict:
  payload = super().to_payload()
  if self.confidence is not None:
  payload[CONFIDENCE_KEY] = self.confidence
+ if self.class_pdf is not None:
+ payload[CLASS_PDF_KEY] = self.class_pdf

  return payload

@@ -125,10 +136,13 @@ class PolygonPrediction(PolygonAnnotation):
  geometry = payload.get(GEOMETRY_KEY, {})
  return cls(
  label=payload.get(LABEL_KEY, 0),
- vertices=geometry.get(VERTICES_KEY, []),
+ vertices=[
+ Point.from_json(_) for _ in geometry.get(VERTICES_KEY, [])
+ ],
  reference_id=payload.get(REFERENCE_ID_KEY, None),
  item_id=payload.get(DATASET_ITEM_ID_KEY, None),
  confidence=payload.get(CONFIDENCE_KEY, None),
  annotation_id=payload.get(ANNOTATION_ID_KEY, None),
  metadata=payload.get(METADATA_KEY, {}),
+ class_pdf=payload.get(CLASS_PDF_KEY, None),
  )
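The two changes above, typed Point vertices and the optional class_pdf field, shown in a small construction sketch; labels and probabilities are illustrative:

```python
from nucleus.annotation import Point
from nucleus.prediction import BoxPrediction, PolygonPrediction

box = BoxPrediction(
    label="car",
    x=10, y=20, width=30, height=40,
    reference_id="img_0001",
    confidence=0.8,
    class_pdf={"car": 0.8, "truck": 0.2},  # serialized under the new class_pdf key
)

polygon = PolygonPrediction(
    label="lane",
    vertices=[Point(x=0, y=0), Point(x=5, y=0), Point(x=5, y=5)],  # Point objects, not raw dicts
    reference_id="img_0001",
    confidence=0.7,
    class_pdf={"lane": 0.7, "shoulder": 0.3},
)

payloads = [box.to_payload(), polygon.to_payload()]
```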