scale-nucleus 0.1.22__py3-none-any.whl → 0.6.4__py3-none-any.whl

This diff shows the content changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (73)
  1. cli/client.py +14 -0
  2. cli/datasets.py +77 -0
  3. cli/helpers/__init__.py +0 -0
  4. cli/helpers/nucleus_url.py +10 -0
  5. cli/helpers/web_helper.py +40 -0
  6. cli/install_completion.py +33 -0
  7. cli/jobs.py +42 -0
  8. cli/models.py +35 -0
  9. cli/nu.py +42 -0
  10. cli/reference.py +8 -0
  11. cli/slices.py +62 -0
  12. cli/tests.py +121 -0
  13. nucleus/__init__.py +453 -699
  14. nucleus/annotation.py +435 -80
  15. nucleus/autocurate.py +9 -0
  16. nucleus/connection.py +87 -0
  17. nucleus/constants.py +12 -2
  18. nucleus/data_transfer_object/__init__.py +0 -0
  19. nucleus/data_transfer_object/dataset_details.py +9 -0
  20. nucleus/data_transfer_object/dataset_info.py +26 -0
  21. nucleus/data_transfer_object/dataset_size.py +5 -0
  22. nucleus/data_transfer_object/scenes_list.py +18 -0
  23. nucleus/dataset.py +1139 -215
  24. nucleus/dataset_item.py +130 -26
  25. nucleus/dataset_item_uploader.py +297 -0
  26. nucleus/deprecation_warning.py +32 -0
  27. nucleus/errors.py +21 -1
  28. nucleus/job.py +71 -3
  29. nucleus/logger.py +9 -0
  30. nucleus/metadata_manager.py +45 -0
  31. nucleus/metrics/__init__.py +10 -0
  32. nucleus/metrics/base.py +117 -0
  33. nucleus/metrics/categorization_metrics.py +197 -0
  34. nucleus/metrics/errors.py +7 -0
  35. nucleus/metrics/filters.py +40 -0
  36. nucleus/metrics/geometry.py +198 -0
  37. nucleus/metrics/metric_utils.py +28 -0
  38. nucleus/metrics/polygon_metrics.py +480 -0
  39. nucleus/metrics/polygon_utils.py +299 -0
  40. nucleus/model.py +121 -15
  41. nucleus/model_run.py +34 -57
  42. nucleus/payload_constructor.py +30 -18
  43. nucleus/prediction.py +259 -17
  44. nucleus/pydantic_base.py +26 -0
  45. nucleus/retry_strategy.py +4 -0
  46. nucleus/scene.py +204 -19
  47. nucleus/slice.py +230 -67
  48. nucleus/upload_response.py +20 -9
  49. nucleus/url_utils.py +4 -0
  50. nucleus/utils.py +139 -35
  51. nucleus/validate/__init__.py +24 -0
  52. nucleus/validate/client.py +168 -0
  53. nucleus/validate/constants.py +20 -0
  54. nucleus/validate/data_transfer_objects/__init__.py +0 -0
  55. nucleus/validate/data_transfer_objects/eval_function.py +81 -0
  56. nucleus/validate/data_transfer_objects/scenario_test.py +19 -0
  57. nucleus/validate/data_transfer_objects/scenario_test_evaluations.py +11 -0
  58. nucleus/validate/data_transfer_objects/scenario_test_metric.py +12 -0
  59. nucleus/validate/errors.py +6 -0
  60. nucleus/validate/eval_functions/__init__.py +0 -0
  61. nucleus/validate/eval_functions/available_eval_functions.py +212 -0
  62. nucleus/validate/eval_functions/base_eval_function.py +60 -0
  63. nucleus/validate/scenario_test.py +143 -0
  64. nucleus/validate/scenario_test_evaluation.py +114 -0
  65. nucleus/validate/scenario_test_metric.py +14 -0
  66. nucleus/validate/utils.py +8 -0
  67. {scale_nucleus-0.1.22.dist-info → scale_nucleus-0.6.4.dist-info}/LICENSE +0 -0
  68. scale_nucleus-0.6.4.dist-info/METADATA +213 -0
  69. scale_nucleus-0.6.4.dist-info/RECORD +71 -0
  70. {scale_nucleus-0.1.22.dist-info → scale_nucleus-0.6.4.dist-info}/WHEEL +1 -1
  71. scale_nucleus-0.6.4.dist-info/entry_points.txt +3 -0
  72. scale_nucleus-0.1.22.dist-info/METADATA +0 -85
  73. scale_nucleus-0.1.22.dist-info/RECORD +0 -21
nucleus/__init__.py CHANGED
@@ -1,19 +1,45 @@
1
- """
2
- Nucleus Python Library.
3
-
4
- For full documentation see: https://dashboard.scale.com/nucleus/docs/api?language=python
5
- """
6
- import asyncio
7
- import json
8
- import logging
1
+ """Nucleus Python SDK. """
2
+
3
+ __all__ = [
4
+ "AsyncJob",
5
+ "BoxAnnotation",
6
+ "BoxPrediction",
7
+ "CameraParams",
8
+ "CategoryAnnotation",
9
+ "CategoryPrediction",
10
+ "CuboidAnnotation",
11
+ "CuboidPrediction",
12
+ "Dataset",
13
+ "DatasetInfo",
14
+ "DatasetItem",
15
+ "DatasetItemRetrievalError",
16
+ "Frame",
17
+ "Frame",
18
+ "LidarScene",
19
+ "LidarScene",
20
+ "Model",
21
+ "ModelCreationError",
22
+ # "MultiCategoryAnnotation", # coming soon!
23
+ "NotFoundError",
24
+ "NucleusAPIError",
25
+ "NucleusClient",
26
+ "Point",
27
+ "Point3D",
28
+ "PolygonAnnotation",
29
+ "PolygonPrediction",
30
+ "Quaternion",
31
+ "Segment",
32
+ "SegmentationAnnotation",
33
+ "SegmentationPrediction",
34
+ "Slice",
35
+ ]
36
+
9
37
  import os
10
- import urllib.request
11
- from asyncio.tasks import Task
12
- from typing import Any, Dict, List, Optional, Union
38
+ import warnings
39
+ from typing import Dict, List, Optional, Sequence, Union
13
40
 
14
- import aiohttp
15
- import nest_asyncio
16
41
  import pkg_resources
42
+ import pydantic
17
43
  import requests
18
44
  import tqdm
19
45
  import tqdm.notebook as tqdm_notebook
@@ -22,34 +48,38 @@ from nucleus.url_utils import sanitize_string_args
22
48
 
23
49
  from .annotation import (
24
50
  BoxAnnotation,
51
+ CategoryAnnotation,
25
52
  CuboidAnnotation,
53
+ MultiCategoryAnnotation,
26
54
  Point,
27
55
  Point3D,
28
56
  PolygonAnnotation,
29
57
  Segment,
30
58
  SegmentationAnnotation,
31
59
  )
60
+ from .connection import Connection
32
61
  from .constants import (
33
62
  ANNOTATION_METADATA_SCHEMA_KEY,
34
63
  ANNOTATIONS_IGNORED_KEY,
35
64
  ANNOTATIONS_PROCESSED_KEY,
36
65
  AUTOTAGS_KEY,
37
66
  DATASET_ID_KEY,
67
+ DATASET_IS_SCENE_KEY,
38
68
  DEFAULT_NETWORK_TIMEOUT_SEC,
39
69
  EMBEDDING_DIMENSION_KEY,
40
70
  EMBEDDINGS_URL_KEY,
41
71
  ERROR_ITEMS,
42
72
  ERROR_PAYLOAD,
43
73
  ERRORS_KEY,
44
- JOB_ID_KEY,
45
- JOB_LAST_KNOWN_STATUS_KEY,
46
- JOB_TYPE_KEY,
47
- JOB_CREATION_TIME_KEY,
48
74
  IMAGE_KEY,
49
75
  IMAGE_URL_KEY,
50
76
  INDEX_CONTINUOUS_ENABLE_KEY,
51
77
  ITEM_METADATA_SCHEMA_KEY,
52
78
  ITEMS_KEY,
79
+ JOB_CREATION_TIME_KEY,
80
+ JOB_ID_KEY,
81
+ JOB_LAST_KNOWN_STATUS_KEY,
82
+ JOB_TYPE_KEY,
53
83
  KEEP_HISTORY_KEY,
54
84
  MESSAGE_KEY,
55
85
  MODEL_RUN_ID_KEY,
@@ -62,16 +92,21 @@ from .constants import (
62
92
  STATUS_CODE_KEY,
63
93
  UPDATE_KEY,
64
94
  )
95
+ from .data_transfer_object.dataset_details import DatasetDetails
96
+ from .data_transfer_object.dataset_info import DatasetInfo
65
97
  from .dataset import Dataset
66
- from .dataset_item import DatasetItem, CameraParams, Quaternion
98
+ from .dataset_item import CameraParams, DatasetItem, Quaternion
99
+ from .deprecation_warning import deprecated
67
100
  from .errors import (
68
101
  DatasetItemRetrievalError,
69
102
  ModelCreationError,
70
103
  ModelRunCreationError,
104
+ NoAPIKey,
71
105
  NotFoundError,
72
106
  NucleusAPIError,
73
107
  )
74
108
  from .job import AsyncJob
109
+ from .logger import logger
75
110
  from .model import Model
76
111
  from .model_run import ModelRun
77
112
  from .payload_constructor import (
@@ -83,13 +118,16 @@ from .payload_constructor import (
83
118
  )
84
119
  from .prediction import (
85
120
  BoxPrediction,
121
+ CategoryPrediction,
86
122
  CuboidPrediction,
87
123
  PolygonPrediction,
88
124
  SegmentationPrediction,
89
125
  )
126
+ from .retry_strategy import RetryStrategy
127
+ from .scene import Frame, LidarScene
90
128
  from .slice import Slice
91
129
  from .upload_response import UploadResponse
92
- from .scene import Frame, LidarScene
130
+ from .validate import Validate
93
131
 
94
132
  # pylint: disable=E1101
95
133
  # TODO: refactor to reduce this file to under 1000 lines.
@@ -98,25 +136,25 @@ from .scene import Frame, LidarScene
98
136
 
99
137
  __version__ = pkg_resources.get_distribution("scale-nucleus").version
100
138
 
101
- logger = logging.getLogger(__name__)
102
- logging.basicConfig()
103
- logging.getLogger(requests.packages.urllib3.__package__).setLevel(
104
- logging.ERROR
105
- )
106
-
107
139
 
108
140
  class NucleusClient:
109
- """
110
- Nucleus client.
141
+ """Client to interact with the Nucleus API via Python SDK.
142
+
143
+ Parameters:
144
+ api_key: Follow `this guide <https://scale.com/docs/account#section-api-keys>`_
145
+ to retrieve your API keys.
146
+ use_notebook: Whether the client is being used in a notebook (toggles tqdm
147
+ style). Default is ``False``.
148
+ endpoint: Base URL of the API. Default is Nucleus's current production API.
111
149
  """
112
150
 
113
151
  def __init__(
114
152
  self,
115
- api_key: str,
153
+ api_key: Optional[str] = None,
116
154
  use_notebook: bool = False,
117
155
  endpoint: str = None,
118
156
  ):
119
- self.api_key = api_key
157
+ self.api_key = self._set_api_key(api_key)
120
158
  self.tqdm_bar = tqdm.tqdm
121
159
  if endpoint is None:
122
160
  self.endpoint = os.environ.get(
@@ -127,6 +165,8 @@ class NucleusClient:
127
165
  self._use_notebook = use_notebook
128
166
  if use_notebook:
129
167
  self.tqdm_bar = tqdm_notebook.tqdm
168
+ self._connection = Connection(self.api_key, self.endpoint)
169
+ self.validate = Validate(self.api_key, self.endpoint)
130
170
 
131
171
  def __repr__(self):
132
172
  return f"NucleusClient(api_key='{self.api_key}', use_notebook={self._use_notebook}, endpoint='{self.endpoint}')"
@@ -137,10 +177,26 @@ class NucleusClient:
137
177
  return True
138
178
  return False
139
179
 
140
- def list_models(self) -> List[Model]:
180
+ @property
181
+ def datasets(self) -> List[Dataset]:
182
+ """List all Datasets
183
+
184
+ Returns:
185
+ List of all datasets accessible to the user.
141
186
  """
142
- Lists available models in your repo.
143
- :return: model_ids
187
+ response = self.make_request({}, "dataset/details", requests.get)
188
+ dataset_details = pydantic.parse_obj_as(List[DatasetDetails], response)
189
+ return [
190
+ Dataset(d.id, client=self, name=d.name) for d in dataset_details
191
+ ]
192
+
193
+ @property
194
+ def models(self) -> List[Model]:
195
+ # TODO: implement for Dataset, scoped just to associated models
196
+ """Fetches all of your Nucleus models.
197
+
198
+ Returns:
199
+ List[:class:`Model`]: List of models associated with the client API key.
144
200
  """
145
201
  model_objects = self.make_request({}, "models/", requests.get)
146
202
 
@@ -155,20 +211,41 @@ class NucleusClient:
155
211
  for model in model_objects["models"]
156
212
  ]
157
213
 
158
- def list_datasets(self) -> Dict[str, Union[str, List[str]]]:
159
- """
160
- Lists available datasets in your repo.
161
- :return: { datasets_ids }
214
+ @property
215
+ def jobs(
216
+ self,
217
+ ) -> List[AsyncJob]:
218
+ """Lists all jobs, see NucleusClinet.list_jobs(...) for advanced options
219
+
220
+ Returns:
221
+ List of all AsyncJobs
162
222
  """
223
+ return self.list_jobs()
224
+
225
+ @deprecated(msg="Use the NucleusClient.models property in the future.")
226
+ def list_models(self) -> List[Model]:
227
+ return self.models
228
+
229
+ @deprecated(msg="Use the NucleusClient.datasets property in the future.")
230
+ def list_datasets(self) -> Dict[str, Union[str, List[str]]]:
163
231
  return self.make_request({}, "dataset/", requests.get)
164
232
 
165
233
  def list_jobs(
166
234
  self, show_completed=None, date_limit=None
167
235
  ) -> List[AsyncJob]:
236
+ """Fetches all of your running jobs in Nucleus.
237
+
238
+ Parameters:
239
+ show_completed: Whether to fetch completed and errored jobs or just
240
+ running jobs. Default behavior is False.
241
+ date_limit: Only fetch jobs that were started after this date. Default
242
+ behavior is 2 weeks prior to the current date.
243
+
244
+ Returns:
245
+ List[:class:`AsyncJob`]: List of running asynchronous jobs
246
+ associated with the client API key.
168
247
  """
169
- Lists jobs for user.
170
- :return: jobs
171
- """
248
+ # TODO: What type is date_limit? Use pydantic ...
172
249
  payload = {show_completed: show_completed, date_limit: date_limit}
173
250
  job_objects = self.make_request(payload, "jobs/", requests.get)
174
251
  return [
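The hunk above replaces the bare listing methods with `datasets`, `models`, and `jobs` properties, keeping the old `list_models()`/`list_datasets()` calls only as deprecated wrappers. A short sketch of the new access pattern, continuing from the `client` built in the earlier example:

    all_datasets = client.datasets   # List[Dataset]
    all_models = client.models       # List[Model]
    all_jobs = client.jobs           # List[AsyncJob]; equivalent to client.list_jobs()

    # list_jobs still accepts the filters documented above.
    jobs_including_completed = client.list_jobs(show_completed=True)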
@@ -182,42 +259,47 @@ class NucleusClient:
182
259
  for job in job_objects
183
260
  ]
184
261
 
262
+ @deprecated(msg="Prefer using Dataset.items")
185
263
  def get_dataset_items(self, dataset_id) -> List[DatasetItem]:
186
- """
187
- Gets all the dataset items inside your repo as a json blob.
188
- :return [ DatasetItem ]
189
- """
190
- response = self.make_request(
191
- {}, f"dataset/{dataset_id}/datasetItems", requests.get
192
- )
193
- dataset_items = response.get("dataset_items", None)
194
- error = response.get("error", None)
195
- constructed_dataset_items = []
196
- if dataset_items:
197
- for item in dataset_items:
198
- image_url = item.get("original_image_url")
199
- metadata = item.get("metadata", None)
200
- ref_id = item.get("ref_id", None)
201
- dataset_item = DatasetItem(image_url, ref_id, metadata)
202
- constructed_dataset_items.append(dataset_item)
203
- elif error:
204
- raise DatasetItemRetrievalError(message=error)
205
-
206
- return constructed_dataset_items
264
+ dataset = self.get_dataset(dataset_id)
265
+ return dataset.items
207
266
 
208
267
  def get_dataset(self, dataset_id: str) -> Dataset:
209
- """
210
- Fetches a dataset for given id
211
- :param dataset_id: internally controlled dataset_id
212
- :return: dataset
268
+ """Fetches a dataset by its ID.
269
+
270
+ Parameters:
271
+ dataset_id: The ID of the dataset to fetch.
272
+
273
+ Returns:
274
+ :class:`Dataset`: The Nucleus dataset as an object.
213
275
  """
214
276
  return Dataset(dataset_id, self)
215
277
 
216
- def get_model(self, model_id: str) -> Model:
278
+ def get_job(self, job_id: str) -> AsyncJob:
279
+ """Fetches a dataset by its ID.
280
+
281
+ Parameters:
282
+ job_id: The ID of the job to fetch.
283
+
284
+ Returns:
285
+ :class:`AsyncJob`: The Nucleus async job as an object.
217
286
  """
218
- Fetched a model for a given id
219
- :param model_id: internally controlled dataset_id
220
- :return: model
287
+ payload = self.make_request(
288
+ payload={},
289
+ route=f"job/{job_id}/info",
290
+ requests_command=requests.get,
291
+ )
292
+ return AsyncJob.from_json(payload=payload, client=self)
293
+
294
+ def get_model(self, model_id: str) -> Model:
295
+ """Fetches a model by its ID.
296
+
297
+ Parameters:
298
+ model_id: Nucleus-generated model ID (starts with ``prj_``). This can
299
+ be retrieved via :meth:`list_models` or a Nucleus dashboard URL.
300
+
301
+ Returns:
302
+ :class:`Model`: The Nucleus model as an object.
221
303
  """
222
304
  payload = self.make_request(
223
305
  payload={},
@@ -226,22 +308,16 @@ class NucleusClient:
226
308
  )
227
309
  return Model.from_json(payload=payload, client=self)
228
310
 
311
+ @deprecated(
312
+ "Model runs have been deprecated and will be removed. Use a Model instead"
313
+ )
229
314
  def get_model_run(self, model_run_id: str, dataset_id: str) -> ModelRun:
230
- """
231
- Fetches a model_run for given id
232
- :param model_run_id: internally controlled model_run_id
233
- :param dataset_id: the dataset id which may determine the prediction schema
234
- for this model run if present on the dataset.
235
- :return: model_run
236
- """
237
315
  return ModelRun(model_run_id, dataset_id, self)
238
316
 
317
+ @deprecated(
318
+ "Model runs have been deprecated and will be removed. Use a Model instead"
319
+ )
239
320
  def delete_model_run(self, model_run_id: str):
240
- """
241
- Fetches a model_run for given id
242
- :param model_run_id: internally controlled model_run_id
243
- :return: model_run
244
- """
245
321
  return self.make_request(
246
322
  {}, f"modelRun/{model_run_id}", requests.delete
247
323
  )
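Both ModelRun helpers above are now kept only as deprecated shims; the Model object (fetched via `get_model` in the previous hunk, or created with `create_model` further down) is the supported path. A hedged sketch, with a placeholder ID and the `client` from the first example:

    # Model IDs are Nucleus-generated and start with "prj_" (see the get_model docstring).
    model = client.get_model("prj_placeholder_model_id")
    # get_model_run()/delete_model_run() still exist, but only as deprecated wrappers.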
@@ -249,12 +325,26 @@ class NucleusClient:
249
325
  def create_dataset_from_project(
250
326
  self, project_id: str, last_n_tasks: int = None, name: str = None
251
327
  ) -> Dataset:
252
- """
253
- Creates a new dataset based on payload params:
254
- name -- A human-readable name of the dataset.
255
- Returns a response with internal id and name for a new dataset.
256
- :param payload: { "name": str }
257
- :return: new Dataset object
328
+ """Create a new dataset from an existing Scale or Rapid project.
329
+
330
+ If you already have Annotation, SegmentAnnotation, VideoAnnotation,
331
+ Categorization, PolygonAnnotation, ImageAnnotation, DocumentTranscription,
332
+ LidarLinking, LidarAnnotation, or VideoboxAnnotation projects with Scale,
333
+ use this endpoint to import your project directly into Nucleus.
334
+
335
+ This endpoint is asynchronous because there can be delays when the
336
+ number of tasks is larger than 1000. As a result, the endpoint returns
337
+ an instance of :class:`AsyncJob`.
338
+
339
+ Parameters:
340
+ project_id: The ID of the Scale/Rapid project (retrievable from URL).
341
+ last_n_tasks: If supplied, only pull in this number of the most recent
342
+ tasks. By default the endpoint will pull in all eligible tasks.
343
+ name: The name for your new Nucleus dataset. By default the endpoint
344
+ will use the project's name.
345
+
346
+ Returns:
347
+ :class:`Dataset`: The newly created Nucleus dataset as an object.
258
348
  """
259
349
  payload = {"project_id": project_id}
260
350
  if last_n_tasks:
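A sketch of the project-import call documented above; the project ID is a placeholder, and note the docstring's caveat that ingestion runs asynchronously for large projects:

    dataset = client.create_dataset_from_project(
        project_id="placeholder_scale_project_id",  # taken from the Scale/Rapid project URL
        last_n_tasks=500,                           # optional: only the most recent tasks
        name="imported-from-scale",                 # optional: defaults to the project name
    )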
@@ -267,20 +357,51 @@ class NucleusClient:
267
357
  def create_dataset(
268
358
  self,
269
359
  name: str,
360
+ is_scene: Optional[bool] = None,
270
361
  item_metadata_schema: Optional[Dict] = None,
271
362
  annotation_metadata_schema: Optional[Dict] = None,
272
363
  ) -> Dataset:
273
364
  """
274
- Creates a new dataset:
275
- Returns a response with internal id and name for a new dataset.
276
- :param name -- A human-readable name of the dataset.
277
- :param item_metadata_schema -- optional dictionary to define item metadata schema
278
- :param annotation_metadata_schema -- optional dictionary to define annotation metadata schema
279
- :return: new Dataset object
280
- """
365
+ Creates a new, empty dataset.
366
+
367
+ Make sure that the dataset is created for the data type you would like to support.
368
+ Be sure to set the ``is_scene`` parameter correctly.
369
+
370
+ Parameters:
371
+ name: A human-readable name for the dataset.
372
+ is_scene: Whether the dataset contains strictly :class:`scenes
373
+ <LidarScene>` or :class:`items <DatasetItem>`. This value is immutable.
374
+ Default is False (dataset of items).
375
+ item_metadata_schema: Dict defining item-level metadata schema. See below.
376
+ annotation_metadata_schema: Dict defining annotation-level metadata schema.
377
+
378
+ Metadata schemas must be structured as follows::
379
+
380
+ {
381
+ "field_name": {
382
+ "type": "category" | "number" | "text"
383
+ "choices": List[str] | None
384
+ "description": str | None
385
+ },
386
+ ...
387
+ }
388
+
389
+ Returns:
390
+ :class:`Dataset`: The newly created Nucleus dataset as an object.
391
+ """
392
+ if is_scene is None:
393
+ warnings.warn(
394
+ "The default create_dataset('dataset_name', ...) method without the is_scene parameter will be "
395
+ "deprecated soon in favor of providing the is_scene parameter explicitly. "
396
+ "Please make sure to create a dataset with either create_dataset('dataset_name', is_scene=False, ...) "
397
+ "to upload DatasetItems or create_dataset('dataset_name', is_scene=True, ...) to upload LidarScenes.",
398
+ DeprecationWarning,
399
+ )
400
+ is_scene = False
281
401
  response = self.make_request(
282
402
  {
283
403
  NAME_KEY: name,
404
+ DATASET_IS_SCENE_KEY: is_scene,
284
405
  ANNOTATION_METADATA_SCHEMA_KEY: annotation_metadata_schema,
285
406
  ITEM_METADATA_SCHEMA_KEY: item_metadata_schema,
286
407
  },
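Following the docstring above, a sketch that passes `is_scene` explicitly (avoiding the new DeprecationWarning) and uses the documented metadata-schema shape; the field names are illustrative:

    dataset = client.create_dataset(
        "example-items-dataset",   # placeholder name
        is_scene=False,            # dataset of DatasetItems; use True for LidarScenes
        item_metadata_schema={
            "camera_id": {"type": "text", "choices": None, "description": "Capturing camera"},
            "weather": {
                "type": "category",
                "choices": ["sunny", "rain", "snow"],
                "description": None,
            },
        },
    )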
@@ -290,293 +411,55 @@ class NucleusClient:
290
411
 
291
412
  def delete_dataset(self, dataset_id: str) -> dict:
292
413
  """
293
- Deletes a private dataset based on datasetId.
294
- Returns an empty payload where response status `200` indicates
295
- the dataset has been successfully deleted.
296
- :param payload: { "name": str }
297
- :return: { "dataset_id": str, "name": str }
414
+ Deletes a dataset by ID.
415
+
416
+ All items, annotations, and predictions associated with the dataset will
417
+ be deleted as well.
418
+
419
+ Parameters:
420
+ dataset_id: The ID of the dataset to delete.
421
+
422
+ Returns:
423
+ Payload to indicate deletion invocation.
298
424
  """
299
425
  return self.make_request({}, f"dataset/{dataset_id}", requests.delete)
300
426
 
301
- @sanitize_string_args
427
+ @deprecated("Use Dataset.delete_item instead.")
302
428
  def delete_dataset_item(self, dataset_id: str, reference_id) -> dict:
303
- """
304
- Deletes a private dataset based on datasetId.
305
- Returns an empty payload where response status `200` indicates
306
- the dataset has been successfully deleted.
307
- :param payload: { "name": str }
308
- :return: { "dataset_id": str, "name": str }
309
- """
310
- return self.make_request(
311
- {},
312
- f"dataset/{dataset_id}/refloc/{reference_id}",
313
- requests.delete,
314
- )
429
+ dataset = self.get_dataset(dataset_id)
430
+ return dataset.delete_item(reference_id)
315
431
 
432
+ @deprecated("Use Dataset.append instead.")
316
433
  def populate_dataset(
317
434
  self,
318
435
  dataset_id: str,
319
436
  dataset_items: List[DatasetItem],
320
- batch_size: int = 100,
437
+ batch_size: int = 20,
321
438
  update: bool = False,
322
439
  ):
323
- """
324
- Appends images to a dataset with given dataset_id.
325
- Overwrites images on collision if updated.
326
- :param dataset_id: id of a dataset
327
- :param payload: { "items": List[DatasetItem], "update": bool }
328
- :param local: flag if images are stored locally
329
- :param batch_size: size of the batch for long payload
330
- :return:
331
- {
332
- "dataset_id: str,
333
- "new_items": int,
334
- "updated_items": int,
335
- "ignored_items": int,
336
- "upload_errors": int
337
- }
338
- """
339
- local_items = []
340
- remote_items = []
341
-
342
- # Check local files exist before sending requests
343
- for item in dataset_items:
344
- if item.local:
345
- if not item.local_file_exists():
346
- raise NotFoundError()
347
- local_items.append(item)
348
- else:
349
- remote_items.append(item)
350
-
351
- local_batches = [
352
- local_items[i : i + batch_size]
353
- for i in range(0, len(local_items), batch_size)
354
- ]
355
-
356
- remote_batches = [
357
- remote_items[i : i + batch_size]
358
- for i in range(0, len(remote_items), batch_size)
359
- ]
360
-
361
- agg_response = UploadResponse(json={DATASET_ID_KEY: dataset_id})
362
-
363
- async_responses: List[Any] = []
364
-
365
- if local_batches:
366
- tqdm_local_batches = self.tqdm_bar(
367
- local_batches, desc="Local file batches"
368
- )
369
-
370
- for batch in tqdm_local_batches:
371
- payload = construct_append_payload(batch, update)
372
- responses = self._process_append_requests_local(
373
- dataset_id, payload, update
374
- )
375
- async_responses.extend(responses)
376
-
377
- if remote_batches:
378
- tqdm_remote_batches = self.tqdm_bar(
379
- remote_batches, desc="Remote file batches"
380
- )
381
- for batch in tqdm_remote_batches:
382
- payload = construct_append_payload(batch, update)
383
- responses = self._process_append_requests(
384
- dataset_id=dataset_id,
385
- payload=payload,
386
- update=update,
387
- batch_size=batch_size,
388
- )
389
- async_responses.extend(responses)
390
-
391
- for response in async_responses:
392
- agg_response.update_response(response)
393
-
394
- return agg_response
395
-
396
- def _process_append_requests_local(
397
- self,
398
- dataset_id: str,
399
- payload: dict,
400
- update: bool, # TODO: understand how to pass this in.
401
- local_batch_size: int = 10,
402
- ):
403
- def get_files(batch):
404
- for item in batch:
405
- item[UPDATE_KEY] = update
406
- request_payload = [
407
- (
408
- ITEMS_KEY,
409
- (
410
- None,
411
- json.dumps(batch, allow_nan=False),
412
- "application/json",
413
- ),
414
- )
415
- ]
416
- for item in batch:
417
- image = open( # pylint: disable=R1732
418
- item.get(IMAGE_URL_KEY), "rb" # pylint: disable=R1732
419
- ) # pylint: disable=R1732
420
- img_name = os.path.basename(image.name)
421
- img_type = (
422
- f"image/{os.path.splitext(image.name)[1].strip('.')}"
423
- )
424
- request_payload.append(
425
- (IMAGE_KEY, (img_name, image, img_type))
426
- )
427
- return request_payload
428
-
429
- items = payload[ITEMS_KEY]
430
- responses: List[Any] = []
431
- files_per_request = []
432
- payload_items = []
433
- for i in range(0, len(items), local_batch_size):
434
- batch = items[i : i + local_batch_size]
435
- files_per_request.append(get_files(batch))
436
- payload_items.append(batch)
437
-
438
- future = self.make_many_files_requests_asynchronously(
439
- files_per_request,
440
- f"dataset/{dataset_id}/append",
440
+ dataset = self.get_dataset(dataset_id)
441
+ return dataset.append(
442
+ dataset_items, batch_size=batch_size, update=update
441
443
  )
442
444
 
443
- try:
444
- loop = asyncio.get_event_loop()
445
- except RuntimeError: # no event loop running:
446
- loop = asyncio.new_event_loop()
447
- responses = loop.run_until_complete(future)
448
- else:
449
- nest_asyncio.apply(loop)
450
- return loop.run_until_complete(future)
451
-
452
- def close_files(request_items):
453
- for item in request_items:
454
- # file buffer in location [1][1]
455
- if item[0] == IMAGE_KEY:
456
- item[1][1].close()
457
-
458
- # don't forget to close all open files
459
- for p in files_per_request:
460
- close_files(p)
461
-
462
- return responses
463
-
464
- async def make_many_files_requests_asynchronously(
465
- self, files_per_request, route
466
- ):
467
- """
468
- Makes an async post request with files to a Nucleus endpoint.
469
-
470
- :param files_per_request: A list of lists of tuples (name, (filename, file_pointer, content_type))
471
- name will become the name by which the multer can build an array.
472
- :param route: route for the request
473
- :return: awaitable list(response)
474
- """
475
- async with aiohttp.ClientSession() as session:
476
- tasks = [
477
- asyncio.ensure_future(
478
- self._make_files_request(
479
- files=files, route=route, session=session
480
- )
481
- )
482
- for files in files_per_request
483
- ]
484
- return await asyncio.gather(*tasks)
485
-
486
- async def _make_files_request(
487
- self,
488
- files,
489
- route: str,
490
- session: aiohttp.ClientSession,
491
- ):
492
- """
493
- Makes an async post request with files to a Nucleus endpoint.
494
-
495
- :param files: A list of tuples (name, (filename, file_pointer, file_type))
496
- :param route: route for the request
497
- :param session: Session to use for post.
498
- :return: response
499
- """
500
- endpoint = f"{self.endpoint}/{route}"
501
-
502
- logger.info("Posting to %s", endpoint)
503
-
504
- form = aiohttp.FormData()
505
-
506
- for file in files:
507
- form.add_field(
508
- name=file[0],
509
- filename=file[1][0],
510
- value=file[1][1],
511
- content_type=file[1][2],
512
- )
513
-
514
- async with session.post(
515
- endpoint,
516
- data=form,
517
- auth=aiohttp.BasicAuth(self.api_key, ""),
518
- timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
519
- ) as response:
520
- logger.info("API request has response code %s", response.status)
521
-
522
- try:
523
- data = await response.json()
524
- except aiohttp.client_exceptions.ContentTypeError:
525
- # In case of 404, the server returns text
526
- data = await response.text()
527
-
528
- if not response.ok:
529
- self.handle_bad_response(
530
- endpoint,
531
- session.post,
532
- aiohttp_response=(response.status, response.reason, data),
533
- )
534
-
535
- return data
536
-
537
- def _process_append_requests(
538
- self,
539
- dataset_id: str,
540
- payload: dict,
541
- update: bool,
542
- batch_size: int = 20,
543
- ):
544
- items = payload[ITEMS_KEY]
545
- payloads = [
546
- # batch_size images per request
547
- {ITEMS_KEY: items[i : i + batch_size], UPDATE_KEY: update}
548
- for i in range(0, len(items), batch_size)
549
- ]
550
-
551
- return [
552
- self.make_request(
553
- payload,
554
- f"dataset/{dataset_id}/append",
555
- )
556
- for payload in payloads
557
- ]
558
-
559
445
  def annotate_dataset(
560
446
  self,
561
447
  dataset_id: str,
562
- annotations: List[
448
+ annotations: Sequence[
563
449
  Union[
564
450
  BoxAnnotation,
565
451
  PolygonAnnotation,
566
452
  CuboidAnnotation,
453
+ CategoryAnnotation,
454
+ MultiCategoryAnnotation,
567
455
  SegmentationAnnotation,
568
456
  ]
569
457
  ],
570
458
  update: bool,
571
459
  batch_size: int = 5000,
572
- ):
573
- """
574
- Uploads ground truth annotations for a given dataset.
575
- :param dataset_id: id of the dataset
576
- :param annotations: List[Union[BoxAnnotation, PolygonAnnotation, CuboidAnnotation, SegmentationAnnotation]]
577
- :param update: whether to update or ignore conflicting annotations
578
- :return: {"dataset_id: str, "annotations_processed": int}
579
- """
460
+ ) -> Dict[str, object]:
461
+ # TODO: deprecate in favor of Dataset.annotate invocation
462
+
580
463
  # Split payload into segmentations and Box/Polygon
581
464
  segmentations = [
582
465
  ann
@@ -603,6 +486,7 @@ class NucleusClient:
603
486
  DATASET_ID_KEY: dataset_id,
604
487
  ANNOTATIONS_PROCESSED_KEY: 0,
605
488
  ANNOTATIONS_IGNORED_KEY: 0,
489
+ ERRORS_KEY: [],
606
490
  }
607
491
 
608
492
  total_batches = len(batches) + len(semseg_batches)
@@ -625,6 +509,7 @@ class NucleusClient:
625
509
  agg_response[ANNOTATIONS_IGNORED_KEY] += response[
626
510
  ANNOTATIONS_IGNORED_KEY
627
511
  ]
512
+ agg_response[ERRORS_KEY] += response[ERRORS_KEY]
628
513
 
629
514
  for s_batch in semseg_batches:
630
515
  payload = construct_segmentation_payload(s_batch, update)
@@ -644,29 +529,33 @@ class NucleusClient:
644
529
 
645
530
  return agg_response
646
531
 
532
+ @deprecated(msg="Use Dataset.ingest_tasks instead")
647
533
  def ingest_tasks(self, dataset_id: str, payload: dict):
648
- """
649
- If you already submitted tasks to Scale for annotation this endpoint ingests your completed tasks
650
- annotated by Scale into your Nucleus Dataset.
651
- Right now we support ingestion from Videobox Annotation and 2D Box Annotation projects.
652
- :param payload: {"tasks" : List[task_ids]}
653
- :param dataset_id: id of the dataset
654
- :return: {"ingested_tasks": int, "ignored_tasks": int, "pending_tasks": int}
655
- """
656
- return self.make_request(payload, f"dataset/{dataset_id}/ingest_tasks")
534
+ dataset = self.get_dataset(dataset_id)
535
+ return dataset.ingest_tasks(payload["tasks"])
657
536
 
537
+ @deprecated(msg="Use client.create_model instead.")
658
538
  def add_model(
659
539
  self, name: str, reference_id: str, metadata: Optional[Dict] = None
660
540
  ) -> Model:
661
- """
662
- Adds a model info to your repo based on payload params:
663
- name -- A human-readable name of the model project.
664
- reference_id -- An optional user-specified identifier to reference this given model.
665
- metadata -- An arbitrary metadata blob for the model.
666
- :param name: A human-readable name of the model project.
667
- :param reference_id: An user-specified identifier to reference this given model.
668
- :param metadata: An optional arbitrary metadata blob for the model.
669
- :return: { "model_id": str }
541
+ return self.create_model(name, reference_id, metadata)
542
+
543
+ def create_model(
544
+ self, name: str, reference_id: str, metadata: Optional[Dict] = None
545
+ ) -> Model:
546
+ """Adds a :class:`Model` to Nucleus.
547
+
548
+ Parameters:
549
+ name: A human-readable name for the model.
550
+ reference_id: Unique, user-controlled ID for the model. This can be
551
+ used, for example, to link to an external storage of models which
552
+ may have its own id scheme.
553
+ metadata: An arbitrary dictionary of additional data about this model
554
+ that can be stored and retrieved. For example, you can store information
555
+ about the hyperparameters used in training this model.
556
+
557
+ Returns:
558
+ :class:`Model`: The newly created model as an object.
670
559
  """
671
560
  response = self.make_request(
672
561
  construct_model_creation_payload(name, reference_id, metadata),
@@ -678,31 +567,10 @@ class NucleusClient:
678
567
 
679
568
  return Model(model_id, name, reference_id, metadata, self)
680
569
 
570
+ @deprecated(
571
+ "Model runs have been deprecated and will be removed. Use a Model instead"
572
+ )
681
573
  def create_model_run(self, dataset_id: str, payload: dict) -> ModelRun:
682
- """
683
- Creates model run for dataset_id based on the given parameters specified in the payload:
684
-
685
- 'reference_id' -- The user-specified reference identifier to associate with the model.
686
- The 'model_id' field should be empty if this field is populated.
687
-
688
- 'model_id' -- The internally-controlled identifier of the model.
689
- The 'reference_id' field should be empty if this field is populated.
690
-
691
- 'name' -- An optional name for the model run.
692
-
693
- 'metadata' -- An arbitrary metadata blob for the current run.
694
-
695
- :param
696
- dataset_id: id of the dataset
697
- payload:
698
- {
699
- "reference_id": str,
700
- "model_id": str,
701
- "name": Optional[str],
702
- "metadata": Optional[Dict[str, Any]],
703
- }
704
- :return: new ModelRun object
705
- """
706
574
  response = self.make_request(
707
575
  payload, f"dataset/{dataset_id}/modelRun/create"
708
576
  )
@@ -713,32 +581,34 @@ class NucleusClient:
713
581
  response[MODEL_RUN_ID_KEY], dataset_id=dataset_id, client=self
714
582
  )
715
583
 
584
+ @deprecated("Use Dataset.upload_predictions instead.")
716
585
  def predict(
717
586
  self,
718
- model_run_id: str,
719
587
  annotations: List[
720
588
  Union[
721
589
  BoxPrediction,
722
590
  PolygonPrediction,
723
591
  CuboidPrediction,
724
592
  SegmentationPrediction,
593
+ CategoryPrediction,
725
594
  ]
726
595
  ],
727
- update: bool,
596
+ model_run_id: Optional[str] = None,
597
+ model_id: Optional[str] = None,
598
+ dataset_id: Optional[str] = None,
599
+ update: bool = False,
728
600
  batch_size: int = 5000,
729
601
  ):
730
- """
731
- Uploads model outputs as predictions for a model_run. Returns info about the upload.
732
- :param annotations: List[Union[BoxPrediction, PolygonPrediction, CuboidPrediction, SegmentationPrediction]],
733
- :param update: bool
734
- :return:
735
- {
736
- "dataset_id": str,
737
- "model_run_id": str,
738
- "predictions_processed": int,
739
- "predictions_ignored": int,
740
- }
741
- """
602
+ if model_run_id is not None:
603
+ assert model_id is None and dataset_id is None
604
+ endpoint = f"modelRun/{model_run_id}/predict"
605
+ else:
606
+ assert (
607
+ model_id is not None and dataset_id is not None
608
+ ), "Model ID and dataset ID are required if not using model run id."
609
+ endpoint = (
610
+ f"dataset/{dataset_id}/model/{model_id}/uploadPredictions"
611
+ )
742
612
  segmentations = [
743
613
  ann
744
614
  for ann in annotations
@@ -761,11 +631,9 @@ class NucleusClient:
761
631
  for i in range(0, len(other_predictions), batch_size)
762
632
  ]
763
633
 
764
- agg_response = {
765
- MODEL_RUN_ID_KEY: model_run_id,
766
- PREDICTIONS_PROCESSED_KEY: 0,
767
- PREDICTIONS_IGNORED_KEY: 0,
768
- }
634
+ errors = []
635
+ predictions_processed = 0
636
+ predictions_ignored = 0
769
637
 
770
638
  tqdm_batches = self.tqdm_bar(batches)
771
639
 
@@ -774,230 +642,129 @@ class NucleusClient:
774
642
  batch,
775
643
  update,
776
644
  )
777
- response = self.make_request(
778
- batch_payload, f"modelRun/{model_run_id}/predict"
779
- )
645
+ response = self.make_request(batch_payload, endpoint)
780
646
  if STATUS_CODE_KEY in response:
781
- agg_response[ERRORS_KEY] = response
647
+ errors.append(response)
782
648
  else:
783
- agg_response[PREDICTIONS_PROCESSED_KEY] += response[
784
- PREDICTIONS_PROCESSED_KEY
785
- ]
786
- agg_response[PREDICTIONS_IGNORED_KEY] += response[
787
- PREDICTIONS_IGNORED_KEY
788
- ]
649
+ predictions_processed += response[PREDICTIONS_PROCESSED_KEY]
650
+ predictions_ignored += response[PREDICTIONS_IGNORED_KEY]
651
+ if ERRORS_KEY in response:
652
+ errors += response[ERRORS_KEY]
789
653
 
790
654
  for s_batch in s_batches:
791
655
  payload = construct_segmentation_payload(s_batch, update)
792
- response = self.make_request(
793
- payload, f"modelRun/{model_run_id}/predict_segmentation"
794
- )
656
+ response = self.make_request(payload, endpoint)
795
657
  # pbar.update(1)
796
658
  if STATUS_CODE_KEY in response:
797
- agg_response[ERRORS_KEY] = response
659
+ errors.append(response)
798
660
  else:
799
- agg_response[PREDICTIONS_PROCESSED_KEY] += response[
800
- PREDICTIONS_PROCESSED_KEY
801
- ]
802
- agg_response[PREDICTIONS_IGNORED_KEY] += response[
803
- PREDICTIONS_IGNORED_KEY
804
- ]
661
+ predictions_processed += response[PREDICTIONS_PROCESSED_KEY]
662
+ predictions_ignored += response[PREDICTIONS_IGNORED_KEY]
805
663
 
806
- return agg_response
664
+ return {
665
+ MODEL_RUN_ID_KEY: model_run_id,
666
+ PREDICTIONS_PROCESSED_KEY: predictions_processed,
667
+ PREDICTIONS_IGNORED_KEY: predictions_ignored,
668
+ ERRORS_KEY: errors,
669
+ }
807
670
 
671
+ @deprecated(
672
+ "Model runs have been deprecated and will be removed. Use a Model instead."
673
+ )
808
674
  def commit_model_run(
809
675
  self, model_run_id: str, payload: Optional[dict] = None
810
676
  ):
811
- """
812
- Commits the model run. Starts matching algorithm defined by payload.
813
- class_agnostic -- A flag to specify if matching algorithm should be class-agnostic or not.
814
- Default value: True
815
-
816
- allowed_label_matches -- An optional list of AllowedMatch objects to specify allowed matches
817
- for ground truth and model predictions.
818
- If specified, 'class_agnostic' flag is assumed to be False
819
-
820
- Type 'AllowedMatch':
821
- {
822
- ground_truth_label: string, # A label for ground truth annotation.
823
- model_prediction_label: string, # A label for model prediction that can be matched with
824
- # corresponding ground truth label.
825
- }
826
-
827
- payload:
828
- {
829
- "class_agnostic": boolean,
830
- "allowed_label_matches": List[AllowedMatch],
831
- }
832
-
833
- :return: {"model_run_id": str}
834
- """
677
+ # TODO: deprecate ModelRun. this should be renamed to calculate_evaluation_metrics
678
+ # or completely removed in favor of Model class methods
835
679
  if payload is None:
836
680
  payload = {}
837
681
  return self.make_request(payload, f"modelRun/{model_run_id}/commit")
838
682
 
683
+ @deprecated(msg="Prefer calling Dataset.info() directly.")
839
684
  def dataset_info(self, dataset_id: str):
840
- """
841
- Returns information about existing dataset
842
- :param dataset_id: dataset id
843
- :return: dictionary of the form
844
- {
845
- 'name': str,
846
- 'length': int,
847
- 'model_run_ids': List[str],
848
- 'slice_ids': List[str]
849
- }
850
- """
851
- return self.make_request(
852
- {}, f"dataset/{dataset_id}/info", requests.get
853
- )
685
+ dataset = self.get_dataset(dataset_id)
686
+ return dataset.info()
854
687
 
688
+ @deprecated(
689
+ "Model runs have been deprecated and will be removed. Use a Model instead."
690
+ )
855
691
  def model_run_info(self, model_run_id: str):
856
- """
857
- provides information about a Model Run with given model_run_id:
858
- model_id -- Model Id corresponding to the run
859
- name -- A human-readable name of the model project.
860
- status -- Status of the Model Run.
861
- metadata -- An arbitrary metadata blob specified for the run.
862
- :return:
863
- {
864
- "model_id": str,
865
- "name": str,
866
- "status": str,
867
- "metadata": Dict[str, Any],
868
- }
869
- """
692
+ # TODO: deprecate ModelRun
870
693
  return self.make_request(
871
694
  {}, f"modelRun/{model_run_id}/info", requests.get
872
695
  )
873
696
 
697
+ @deprecated("Prefer calling Dataset.refloc instead.")
874
698
  @sanitize_string_args
875
699
  def dataitem_ref_id(self, dataset_id: str, reference_id: str):
876
- """
877
- :param dataset_id: internally controlled dataset id
878
- :param reference_id: reference_id of a dataset_item
879
- :return:
880
- """
700
+ # TODO: deprecate in favor of Dataset.refloc invocation
881
701
  return self.make_request(
882
702
  {}, f"dataset/{dataset_id}/refloc/{reference_id}", requests.get
883
703
  )
884
704
 
705
+ @deprecated("Prefer calling Dataset.predictions_refloc instead.")
885
706
  @sanitize_string_args
886
- def predictions_ref_id(self, model_run_id: str, ref_id: str):
887
- """
888
- Returns Model Run info For Dataset Item by model_run_id and item reference_id.
889
- :param model_run_id: id of the model run.
890
- :param reference_id: reference_id of a dataset item.
891
- :return:
892
- {
893
- "annotations": List[Union[BoxPrediction, PolygonPrediction, CuboidPrediction, SegmentationPrediction]],
894
- }
895
- """
896
- return self.make_request(
897
- {}, f"modelRun/{model_run_id}/refloc/{ref_id}", requests.get
898
- )
707
+ def predictions_ref_id(
708
+ self, model_run_id: str, ref_id: str, dataset_id: Optional[str] = None
709
+ ):
710
+ if dataset_id:
711
+ raise RuntimeError(
712
+ "Need to pass a dataset id. Or use Dataset.predictions_refloc."
713
+ )
714
+ # TODO: deprecate ModelRun
715
+ m_run = self.get_model_run(model_run_id, dataset_id)
716
+ return m_run.refloc(ref_id)
899
717
 
718
+ @deprecated("Prefer calling Dataset.iloc instead.")
900
719
  def dataitem_iloc(self, dataset_id: str, i: int):
901
- """
902
- Returns Dataset Item info by dataset_id and absolute number of the dataset item.
903
- :param dataset_id: internally controlled dataset id
904
- :param i: absolute number of the dataset_item
905
- :return:
906
- """
720
+ # TODO: deprecate in favor of Dataset.iloc invocation
907
721
  return self.make_request(
908
722
  {}, f"dataset/{dataset_id}/iloc/{i}", requests.get
909
723
  )
910
724
 
725
+ @deprecated("Prefer calling Dataset.predictions_iloc instead.")
911
726
  def predictions_iloc(self, model_run_id: str, i: int):
912
- """
913
- Returns Model Run Info For Dataset Item by model_run_id and absolute number of an item.
914
- :param model_run_id: id of the model run.
915
- :param i: absolute number of Dataset Item for a dataset corresponding to the model run.
916
- :return:
917
- {
918
- "annotations": List[Union[BoxPrediction, PolygonPrediction, CuboidPrediction, SegmentationPrediction]],
919
- }
920
- """
727
+ # TODO: deprecate ModelRun
921
728
  return self.make_request(
922
729
  {}, f"modelRun/{model_run_id}/iloc/{i}", requests.get
923
730
  )
924
731
 
732
+ @deprecated("Prefer calling Dataset.loc instead.")
925
733
  def dataitem_loc(self, dataset_id: str, dataset_item_id: str):
926
- """
927
- Returns Dataset Item Info By dataset_item_id and dataset_id
928
- :param dataset_id: internally controlled id for the dataset.
929
- :param dataset_item_id: internally controlled id for the dataset item.
930
- :return:
931
- {
932
- "item": DatasetItem,
933
- "annotations": List[Box2DAnnotation],
934
- }
935
- """
734
+ # TODO: deprecate in favor of Dataset.loc invocation
936
735
  return self.make_request(
937
736
  {}, f"dataset/{dataset_id}/loc/{dataset_item_id}", requests.get
938
737
  )
939
738
 
739
+ @deprecated("Prefer calling Dataset.predictions_loc instead.")
940
740
  def predictions_loc(self, model_run_id: str, dataset_item_id: str):
941
- """
942
- Returns Model Run Info For Dataset Item by its id.
943
- :param model_run_id: id of the model run.
944
- :param dataset_item_id: dataset_item_id of a dataset item.
945
- :return:
946
- {
947
- "annotations": List[Union[BoxPrediction, PolygonPrediction, CuboidPrediction, SegmentationPrediction]],
948
- }
949
- """
741
+ # TODO: deprecate ModelRun
950
742
  return self.make_request(
951
743
  {}, f"modelRun/{model_run_id}/loc/{dataset_item_id}", requests.get
952
744
  )
953
745
 
746
+ @deprecated("Prefer calling Dataset.create_slice instead.")
954
747
  def create_slice(self, dataset_id: str, payload: dict) -> Slice:
955
- """
956
- Creates a slice from items already present in a dataset.
957
- The caller must exclusively use either datasetItemIds or reference_ids
958
- as a means of identifying items in the dataset.
959
-
960
- "name" -- The human-readable name of the slice.
961
- "reference_ids" -- An optional list of user-specified identifier for the items in the slice
962
-
963
- :param
964
- dataset_id: id of the dataset
965
- payload:
966
- {
967
- "name": str,
968
- "reference_ids": List[str],
969
- }
970
- :return: new Slice object
971
- """
972
- response = self.make_request(
973
- payload, f"dataset/{dataset_id}/create_slice"
974
- )
975
- return Slice(response[SLICE_ID_KEY], self)
748
+ # TODO: deprecate in favor of Dataset.create_slice
749
+ dataset = self.get_dataset(dataset_id)
750
+ return dataset.create_slice(payload["name"], payload["reference_ids"])
976
751
 
977
752
  def get_slice(self, slice_id: str) -> Slice:
978
- """
979
- Returns a slice object by specified id.
753
+ # TODO: migrate to Dataset method and deprecate
754
+ """Returns a slice object by Nucleus-generated ID.
755
+
756
+ Parameters:
757
+ slice_id: Nucleus-generated dataset ID (starts with ``slc_``). This can
758
+ be retrieved via :meth:`Dataset.slices` or a Nucleus dashboard URL.
980
759
 
981
- :param
982
- slice_id: id of the slice
983
- :return: a Slice object
760
+ Returns:
761
+ :class:`Slice`: The Nucleus slice as an object.
984
762
  """
985
763
  return Slice(slice_id, self)
986
764
 
765
+ @deprecated("Prefer calling Slice.info instead.")
987
766
  def slice_info(self, slice_id: str) -> dict:
988
- """
989
- This endpoint provides information about specified slice.
990
-
991
- :param
992
- slice_id: id of the slice
993
-
994
- :return:
995
- {
996
- "name": str,
997
- "dataset_id": str,
998
- "reference_ids": List[str],
999
- }
1000
- """
767
+ # TODO: deprecate in favor of Slice.info
1001
768
  response = self.make_request(
1002
769
  {},
1003
770
  f"slice/{slice_id}",
@@ -1006,14 +773,15 @@ class NucleusClient:
1006
773
  return response
1007
774
 
1008
775
  def delete_slice(self, slice_id: str) -> dict:
1009
- """
1010
- This endpoint deletes specified slice.
776
+ # TODO: migrate to Dataset method and deprecate
777
+ """Deletes slice from Nucleus.
1011
778
 
1012
- :param
1013
- slice_id: id of the slice
779
+ Parameters:
780
+ slice_id: Nucleus-generated dataset ID (starts with ``slc_``). This can
781
+ be retrieved via :meth:`Dataset.slices` or a Nucleus dashboard URL.
1014
782
 
1015
- :return:
1016
- {}
783
+ Returns:
784
+ Empty payload response.
1017
785
  """
1018
786
  response = self.make_request(
1019
787
  {},
@@ -1022,45 +790,29 @@ class NucleusClient:
1022
790
  )
1023
791
  return response
1024
792
 
793
+ @deprecated("Prefer calling Dataset.delete_annotations instead.")
1025
794
  def delete_annotations(
1026
795
  self, dataset_id: str, reference_ids: list = None, keep_history=False
1027
- ) -> dict:
1028
- """
1029
- This endpoint deletes annotations.
1030
-
1031
- :param
1032
- slice_id: id of the slice
1033
-
1034
- :return:
1035
- {}
1036
- """
1037
- payload = {KEEP_HISTORY_KEY: keep_history}
1038
- if reference_ids:
1039
- payload[REFERENCE_IDS_KEY] = reference_ids
1040
- response = self.make_request(
1041
- payload,
1042
- f"annotation/{dataset_id}",
1043
- requests_command=requests.delete,
1044
- )
1045
- return response
796
+ ) -> AsyncJob:
797
+ dataset = self.get_dataset(dataset_id)
798
+ return dataset.delete_annotations(reference_ids, keep_history)
1046
799
 
1047
800
  def append_to_slice(
1048
801
  self,
1049
802
  slice_id: str,
1050
803
  reference_ids: List[str],
1051
804
  ) -> dict:
1052
- """
1053
- Appends to a slice from items already present in a dataset.
1054
- The caller must exclusively use either datasetItemIds or reference_ids
1055
- as a means of identifying items in the dataset.
805
+ # TODO: migrate to Slice method and deprecate
806
+ """Appends dataset items to an existing slice.
1056
807
 
1057
- :param
1058
- reference_ids: List[str],
808
+ Parameters:
809
+ slice_id: Nucleus-generated dataset ID (starts with ``slc_``). This can
810
+ be retrieved via :meth:`Dataset.slices` or a Nucleus dashboard URL.
811
+ reference_ids: List of user-defined reference IDs of the dataset items
812
+ to append to the slice.
1059
813
 
1060
- :return:
1061
- {
1062
- "slice_id": str,
1063
- }
814
+ Returns:
815
+ Empty payload response.
1064
816
  """
1065
817
 
1066
818
  response = self.make_request(
@@ -1068,12 +820,8 @@ class NucleusClient:
1068
820
  )
1069
821
  return response
1070
822
 
1071
- def list_autotags(self, dataset_id: str) -> List[str]:
1072
- """
1073
- Fetches a list of autotags for a given dataset id
1074
- :param dataset_id: internally controlled dataset_id
1075
- :return: List[str] representing autotag_ids
1076
- """
823
+ def list_autotags(self, dataset_id: str) -> List[dict]:
824
+ # TODO: deprecate in favor of Dataset.list_autotags invocation
1077
825
  response = self.make_request(
1078
826
  {},
1079
827
  f"{dataset_id}/list_autotags",
@@ -1082,25 +830,27 @@ class NucleusClient:
1082
830
  return response[AUTOTAGS_KEY] if AUTOTAGS_KEY in response else response
1083
831
 
1084
832
  def delete_autotag(self, autotag_id: str) -> dict:
1085
- """
1086
- Deletes an autotag based on autotagId.
1087
- Returns an empty payload where response status `200` indicates
1088
- the autotag has been successfully deleted.
1089
- :param autotag_id: id of the autotag to delete.
1090
- :return: {}
833
+ # TODO: migrate to Dataset method (use autotag name, not id) and deprecate
834
+ """Deletes an autotag by ID.
835
+
836
+ Parameters:
837
+ autotag_id: Nucleus-generated autotag ID (starts with ``tag_``). This can
838
+ be retrieved via :meth:`list_autotags` or a Nucleus dashboard URL.
839
+
840
+ Returns:
841
+ Empty payload response.
1091
842
  """
1092
843
  return self.make_request({}, f"autotag/{autotag_id}", requests.delete)
1093
844
 
1094
845
  def delete_model(self, model_id: str) -> dict:
1095
- """
1096
- This endpoint deletes the specified model, along with all
1097
- associated model_runs.
846
+ """Deletes a model by ID.
1098
847
 
1099
- :param
1100
- model_id: id of the model_run to delete.
848
+ Parameters:
849
+ model_id: Nucleus-generated model ID (starts with ``prj_``). This can
850
+ be retrieved via :meth:`list_models` or a Nucleus dashboard URL.
1101
851
 
1102
- :return:
1103
- {}
852
+ Returns:
853
+ Empty payload response.
1104
854
  """
1105
855
  response = self.make_request(
1106
856
  {},
@@ -1109,101 +859,95 @@ class NucleusClient:
1109
859
  )
1110
860
  return response
1111
861
 
862
+ @deprecated("Prefer calling Dataset.create_custom_index instead.")
1112
863
  def create_custom_index(
1113
864
  self, dataset_id: str, embeddings_urls: list, embedding_dim: int
1114
865
  ):
1115
- """
1116
- Creates a custom index for a given dataset, which will then be used
1117
- for autotag and similarity search.
1118
-
1119
- :param
1120
- dataset_id: id of dataset that the custom index is being added to.
1121
- embeddings_urls: list of urls, each of which being a json mapping reference_id -> embedding vector
1122
- embedding_dim: the dimension of the embedding vectors, must be consistent for all embedding vectors in the index.
1123
- """
1124
- return self.make_request(
1125
- {
1126
- EMBEDDINGS_URL_KEY: embeddings_urls,
1127
- EMBEDDING_DIMENSION_KEY: embedding_dim,
1128
- },
1129
- f"indexing/{dataset_id}",
1130
- requests_command=requests.post,
1131
- )
1132
-
1133
- def check_index_status(self, job_id: str):
1134
- return self.make_request(
1135
- {},
1136
- f"indexing/{job_id}",
1137
- requests_command=requests.get,
866
+ # TODO: deprecate in favor of Dataset.create_custom_index invocation
867
+ dataset = self.get_dataset(dataset_id)
868
+ return dataset.create_custom_index(
869
+ embeddings_urls=embeddings_urls, embedding_dim=embedding_dim
1138
870
  )
1139
871
 
872
+ @deprecated("Prefer calling Dataset.delete_custom_index instead.")
1140
873
  def delete_custom_index(self, dataset_id: str):
874
+ # TODO: deprecate in favor of Dataset.delete_custom_index invocation
1141
875
  return self.make_request(
1142
876
  {},
1143
877
  f"indexing/{dataset_id}",
1144
878
  requests_command=requests.delete,
1145
879
  )
1146
880
 
881
+ @deprecated("Prefer calling Dataset.set_continuous_indexing instead.")
1147
882
  def set_continuous_indexing(self, dataset_id: str, enable: bool = True):
1148
- """
1149
- Sets continuous indexing for a given dataset, which will automatically generate embeddings whenever
1150
- new images are uploaded. This endpoint is currently only enabled for enterprise customers.
1151
- Please reach out to nucleus@scale.com if you wish to learn more.
1152
-
1153
- :param
1154
- dataset_id: id of dataset that continuous indexing is being toggled for
1155
- enable: boolean, sets whether we are enabling or disabling continuous indexing. The default behavior is to enable.
1156
- """
883
+ # TODO: deprecate in favor of Dataset.set_continuous_indexing invocation
1157
884
  return self.make_request(
1158
885
  {INDEX_CONTINUOUS_ENABLE_KEY: enable},
1159
886
  f"indexing/{dataset_id}/setContinuous",
1160
887
  requests_command=requests.post,
1161
888
  )
1162
889
 
890
+ @deprecated("Prefer calling Dataset.create_image_index instead.")
1163
891
  def create_image_index(self, dataset_id: str):
1164
- """
1165
- Starts generating embeddings for images that don't have embeddings in a given dataset. These embeddings will
1166
- be used for autotag and similarity search. This endpoint is currently only enabled for enterprise customers.
1167
- Please reach out to nucleus@scale.com if you wish to learn more.
1168
-
1169
- :param
1170
- dataset_id: id of dataset for generating embeddings on.
1171
- """
892
+ # TODO: deprecate in favor of Dataset.create_image_index invocation
1172
893
  return self.make_request(
1173
894
  {},
1174
895
  f"indexing/{dataset_id}/internal/image",
1175
896
  requests_command=requests.post,
1176
897
  )
1177
898
 
1178
- def make_request(
1179
- self, payload: dict, route: str, requests_command=requests.post
1180
- ) -> dict:
1181
- """
1182
- Makes a request to Nucleus endpoint and logs a warning if not
1183
- successful.
899
+ @deprecated("Prefer calling Dataset.create_object_index instead.")
900
+ def create_object_index(
901
+ self, dataset_id: str, model_run_id: str, gt_only: bool
902
+ ):
903
+ # TODO: deprecate in favor of Dataset.create_object_index invocation
904
+ payload: Dict[str, Union[str, bool]] = {}
905
+ if model_run_id:
906
+ payload["model_run_id"] = model_run_id
907
+ elif gt_only:
908
+ payload["ingest_gt_only"] = True
909
+ return self.make_request(
910
+ payload,
911
+ f"indexing/{dataset_id}/internal/object",
912
+ requests_command=requests.post,
913
+ )
1184
914
 
1185
- :param payload: given payload
1186
- :param route: route for the request
1187
- :param requests_command: requests.post, requests.get, requests.delete
1188
- :return: response JSON
1189
- """
1190
- endpoint = f"{self.endpoint}/{route}"
915
+ def delete(self, route: str):
916
+ return self._connection.delete(route)
1191
917
 
1192
- logger.info("Posting to %s", endpoint)
918
+ def get(self, route: str):
919
+ return self._connection.get(route)
1193
920
 
1194
- response = requests_command(
1195
- endpoint,
1196
- json=payload,
1197
- headers={"Content-Type": "application/json"},
1198
- auth=(self.api_key, ""),
1199
- timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
1200
- )
1201
- logger.info("API request has response code %s", response.status_code)
921
+ def post(self, payload: dict, route: str):
922
+ return self._connection.post(payload, route)
923
+
924
+ def put(self, payload: dict, route: str):
925
+ return self._connection.put(payload, route)
1202
926
 
1203
- if not response.ok:
1204
- self.handle_bad_response(endpoint, requests_command, response)
927
+ # TODO: Fix return type, can be a list as well. Brings on a lot of mypy errors ...
928
+ def make_request(
929
+ self,
930
+ payload: Optional[dict],
931
+ route: str,
932
+ requests_command=requests.post,
933
+ ) -> dict:
934
+ """Makes a request to a Nucleus API endpoint.
1205
935
 
1206
- return response.json()
936
+ Logs a warning if not successful.
937
+
938
+ Parameters:
939
+ payload: Given request payload.
940
+ route: Route for the request.
941
+ requests_command: ``requests.post``, ``requests.get``, or ``requests.delete``.
942
+
943
+ Returns:
944
+ Response payload as JSON dict.
945
+ """
946
+ if payload is None:
947
+ payload = {}
948
+ if requests_command is requests.get:
949
+ payload = None
950
+ return self._connection.make_request(payload, route, requests_command) # type: ignore
1207
951
 
1208
952
  def handle_bad_response(
1209
953
  self,
@@ -1212,6 +956,16 @@ class NucleusClient:
1212
956
  requests_response=None,
1213
957
  aiohttp_response=None,
1214
958
  ):
1215
- raise NucleusAPIError(
959
+ self._connection.handle_bad_response(
1216
960
  endpoint, requests_command, requests_response, aiohttp_response
1217
961
  )
962
+
963
+ def _set_api_key(self, api_key):
964
+ """Fetch API key from environment variable NUCLEUS_API_KEY if not set"""
965
+ api_key = (
966
+ api_key if api_key else os.environ.get("NUCLEUS_API_KEY", None)
967
+ )
968
+ if api_key is None:
969
+ raise NoAPIKey()
970
+
971
+ return api_key
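Finally, `delete`, `get`, `post`, `put`, and `make_request` now delegate to the shared Connection object. A hedged sketch of calling the passthroughs directly, reusing routes that appear elsewhere in this diff (the IDs are placeholders):

    # Thin passthroughs over Connection; each returns the decoded response payload.
    dataset_info = client.get("dataset/ds_placeholder_id/info")   # route used by dataset_info()
    client.delete("slice/slc_placeholder_id")                     # route used by delete_slice()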