scale-nucleus 0.1.24__py3-none-any.whl → 0.6.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. cli/client.py +14 -0
  2. cli/datasets.py +77 -0
  3. cli/helpers/__init__.py +0 -0
  4. cli/helpers/nucleus_url.py +10 -0
  5. cli/helpers/web_helper.py +40 -0
  6. cli/install_completion.py +33 -0
  7. cli/jobs.py +42 -0
  8. cli/models.py +35 -0
  9. cli/nu.py +42 -0
  10. cli/reference.py +8 -0
  11. cli/slices.py +62 -0
  12. cli/tests.py +121 -0
  13. nucleus/__init__.py +446 -710
  14. nucleus/annotation.py +405 -85
  15. nucleus/autocurate.py +9 -0
  16. nucleus/connection.py +87 -0
  17. nucleus/constants.py +5 -1
  18. nucleus/data_transfer_object/__init__.py +0 -0
  19. nucleus/data_transfer_object/dataset_details.py +9 -0
  20. nucleus/data_transfer_object/dataset_info.py +26 -0
  21. nucleus/data_transfer_object/dataset_size.py +5 -0
  22. nucleus/data_transfer_object/scenes_list.py +18 -0
  23. nucleus/dataset.py +1137 -212
  24. nucleus/dataset_item.py +130 -26
  25. nucleus/dataset_item_uploader.py +297 -0
  26. nucleus/deprecation_warning.py +32 -0
  27. nucleus/errors.py +9 -0
  28. nucleus/job.py +71 -3
  29. nucleus/logger.py +9 -0
  30. nucleus/metadata_manager.py +45 -0
  31. nucleus/metrics/__init__.py +10 -0
  32. nucleus/metrics/base.py +117 -0
  33. nucleus/metrics/categorization_metrics.py +197 -0
  34. nucleus/metrics/errors.py +7 -0
  35. nucleus/metrics/filters.py +40 -0
  36. nucleus/metrics/geometry.py +198 -0
  37. nucleus/metrics/metric_utils.py +28 -0
  38. nucleus/metrics/polygon_metrics.py +480 -0
  39. nucleus/metrics/polygon_utils.py +299 -0
  40. nucleus/model.py +121 -15
  41. nucleus/model_run.py +34 -57
  42. nucleus/payload_constructor.py +29 -19
  43. nucleus/prediction.py +259 -17
  44. nucleus/pydantic_base.py +26 -0
  45. nucleus/retry_strategy.py +4 -0
  46. nucleus/scene.py +204 -19
  47. nucleus/slice.py +230 -67
  48. nucleus/upload_response.py +20 -9
  49. nucleus/url_utils.py +4 -0
  50. nucleus/utils.py +134 -37
  51. nucleus/validate/__init__.py +24 -0
  52. nucleus/validate/client.py +168 -0
  53. nucleus/validate/constants.py +20 -0
  54. nucleus/validate/data_transfer_objects/__init__.py +0 -0
  55. nucleus/validate/data_transfer_objects/eval_function.py +81 -0
  56. nucleus/validate/data_transfer_objects/scenario_test.py +19 -0
  57. nucleus/validate/data_transfer_objects/scenario_test_evaluations.py +11 -0
  58. nucleus/validate/data_transfer_objects/scenario_test_metric.py +12 -0
  59. nucleus/validate/errors.py +6 -0
  60. nucleus/validate/eval_functions/__init__.py +0 -0
  61. nucleus/validate/eval_functions/available_eval_functions.py +212 -0
  62. nucleus/validate/eval_functions/base_eval_function.py +60 -0
  63. nucleus/validate/scenario_test.py +143 -0
  64. nucleus/validate/scenario_test_evaluation.py +114 -0
  65. nucleus/validate/scenario_test_metric.py +14 -0
  66. nucleus/validate/utils.py +8 -0
  67. {scale_nucleus-0.1.24.dist-info → scale_nucleus-0.6.4.dist-info}/LICENSE +0 -0
  68. scale_nucleus-0.6.4.dist-info/METADATA +213 -0
  69. scale_nucleus-0.6.4.dist-info/RECORD +71 -0
  70. {scale_nucleus-0.1.24.dist-info → scale_nucleus-0.6.4.dist-info}/WHEEL +1 -1
  71. scale_nucleus-0.6.4.dist-info/entry_points.txt +3 -0
  72. scale_nucleus-0.1.24.dist-info/METADATA +0 -85
  73. scale_nucleus-0.1.24.dist-info/RECORD +0 -21
nucleus/__init__.py CHANGED
@@ -1,18 +1,45 @@
1
- """
2
- Nucleus Python Library.
3
-
4
- For full documentation see: https://dashboard.scale.com/nucleus/docs/api?language=python
5
- """
6
- import asyncio
7
- import json
8
- import logging
1
+ """Nucleus Python SDK. """
2
+
3
+ __all__ = [
4
+ "AsyncJob",
5
+ "BoxAnnotation",
6
+ "BoxPrediction",
7
+ "CameraParams",
8
+ "CategoryAnnotation",
9
+ "CategoryPrediction",
10
+ "CuboidAnnotation",
11
+ "CuboidPrediction",
12
+ "Dataset",
13
+ "DatasetInfo",
14
+ "DatasetItem",
15
+ "DatasetItemRetrievalError",
16
+ "Frame",
17
+ "Frame",
18
+ "LidarScene",
19
+ "LidarScene",
20
+ "Model",
21
+ "ModelCreationError",
22
+ # "MultiCategoryAnnotation", # coming soon!
23
+ "NotFoundError",
24
+ "NucleusAPIError",
25
+ "NucleusClient",
26
+ "Point",
27
+ "Point3D",
28
+ "PolygonAnnotation",
29
+ "PolygonPrediction",
30
+ "Quaternion",
31
+ "Segment",
32
+ "SegmentationAnnotation",
33
+ "SegmentationPrediction",
34
+ "Slice",
35
+ ]
36
+
9
37
  import os
10
- import time
11
- from typing import Any, Dict, List, Optional, Union
38
+ import warnings
39
+ from typing import Dict, List, Optional, Sequence, Union
12
40
 
13
- import aiohttp
14
- import nest_asyncio
15
41
  import pkg_resources
42
+ import pydantic
16
43
  import requests
17
44
  import tqdm
18
45
  import tqdm.notebook as tqdm_notebook
@@ -21,20 +48,23 @@ from nucleus.url_utils import sanitize_string_args
21
48
 
22
49
  from .annotation import (
23
50
  BoxAnnotation,
51
+ CategoryAnnotation,
24
52
  CuboidAnnotation,
53
+ MultiCategoryAnnotation,
25
54
  Point,
26
55
  Point3D,
27
56
  PolygonAnnotation,
28
- CategoryAnnotation,
29
57
  Segment,
30
58
  SegmentationAnnotation,
31
59
  )
60
+ from .connection import Connection
32
61
  from .constants import (
33
62
  ANNOTATION_METADATA_SCHEMA_KEY,
34
63
  ANNOTATIONS_IGNORED_KEY,
35
64
  ANNOTATIONS_PROCESSED_KEY,
36
65
  AUTOTAGS_KEY,
37
66
  DATASET_ID_KEY,
67
+ DATASET_IS_SCENE_KEY,
38
68
  DEFAULT_NETWORK_TIMEOUT_SEC,
39
69
  EMBEDDING_DIMENSION_KEY,
40
70
  EMBEDDINGS_URL_KEY,
@@ -62,16 +92,21 @@ from .constants import (
62
92
  STATUS_CODE_KEY,
63
93
  UPDATE_KEY,
64
94
  )
95
+ from .data_transfer_object.dataset_details import DatasetDetails
96
+ from .data_transfer_object.dataset_info import DatasetInfo
65
97
  from .dataset import Dataset
66
98
  from .dataset_item import CameraParams, DatasetItem, Quaternion
99
+ from .deprecation_warning import deprecated
67
100
  from .errors import (
68
101
  DatasetItemRetrievalError,
69
102
  ModelCreationError,
70
103
  ModelRunCreationError,
104
+ NoAPIKey,
71
105
  NotFoundError,
72
106
  NucleusAPIError,
73
107
  )
74
108
  from .job import AsyncJob
109
+ from .logger import logger
75
110
  from .model import Model
76
111
  from .model_run import ModelRun
77
112
  from .payload_constructor import (
@@ -83,13 +118,16 @@ from .payload_constructor import (
83
118
  )
84
119
  from .prediction import (
85
120
  BoxPrediction,
121
+ CategoryPrediction,
86
122
  CuboidPrediction,
87
123
  PolygonPrediction,
88
124
  SegmentationPrediction,
89
125
  )
126
+ from .retry_strategy import RetryStrategy
90
127
  from .scene import Frame, LidarScene
91
128
  from .slice import Slice
92
129
  from .upload_response import UploadResponse
130
+ from .validate import Validate
93
131
 
94
132
  # pylint: disable=E1101
95
133
  # TODO: refactor to reduce this file to under 1000 lines.
@@ -98,30 +136,25 @@ from .upload_response import UploadResponse
98
136
 
99
137
  __version__ = pkg_resources.get_distribution("scale-nucleus").version
100
138
 
101
- logger = logging.getLogger(__name__)
102
- logging.basicConfig()
103
- logging.getLogger(requests.packages.urllib3.__package__).setLevel(
104
- logging.ERROR
105
- )
106
-
107
-
108
- class RetryStrategy:
109
- statuses = {503, 504}
110
- sleep_times = [1, 3, 9]
111
-
112
139
 
113
140
  class NucleusClient:
114
- """
115
- Nucleus client.
141
+ """Client to interact with the Nucleus API via Python SDK.
142
+
143
+ Parameters:
144
+ api_key: Follow `this guide <https://scale.com/docs/account#section-api-keys>`_
145
+ to retrieve your API keys.
146
+ use_notebook: Whether the client is being used in a notebook (toggles tqdm
147
+ style). Default is ``False``.
148
+ endpoint: Base URL of the API. Default is Nucleus's current production API.
116
149
  """
117
150
 
118
151
  def __init__(
119
152
  self,
120
- api_key: str,
153
+ api_key: Optional[str] = None,
121
154
  use_notebook: bool = False,
122
155
  endpoint: str = None,
123
156
  ):
124
- self.api_key = api_key
157
+ self.api_key = self._set_api_key(api_key)
125
158
  self.tqdm_bar = tqdm.tqdm
126
159
  if endpoint is None:
127
160
  self.endpoint = os.environ.get(
@@ -132,6 +165,8 @@ class NucleusClient:
132
165
  self._use_notebook = use_notebook
133
166
  if use_notebook:
134
167
  self.tqdm_bar = tqdm_notebook.tqdm
168
+ self._connection = Connection(self.api_key, self.endpoint)
169
+ self.validate = Validate(self.api_key, self.endpoint)
135
170
 
136
171
  def __repr__(self):
137
172
  return f"NucleusClient(api_key='{self.api_key}', use_notebook={self._use_notebook}, endpoint='{self.endpoint}')"
@@ -142,10 +177,26 @@ class NucleusClient:
142
177
  return True
143
178
  return False
144
179
 
145
- def list_models(self) -> List[Model]:
180
+ @property
181
+ def datasets(self) -> List[Dataset]:
182
+ """List all Datasets
183
+
184
+ Returns:
185
+ List of all datasets accessible to the user
146
186
  """
147
- Lists available models in your repo.
148
- :return: model_ids
187
+ response = self.make_request({}, "dataset/details", requests.get)
188
+ dataset_details = pydantic.parse_obj_as(List[DatasetDetails], response)
189
+ return [
190
+ Dataset(d.id, client=self, name=d.name) for d in dataset_details
191
+ ]
192
+
193
+ @property
194
+ def models(self) -> List[Model]:
195
+ # TODO: implement for Dataset, scoped just to associated models
196
+ """Fetches all of your Nucleus models.
197
+
198
+ Returns:
199
+ List[:class:`Model`]: List of models associated with the client API key.
149
200
  """
150
201
  model_objects = self.make_request({}, "models/", requests.get)
151
202
 
@@ -160,20 +211,41 @@ class NucleusClient:
160
211
  for model in model_objects["models"]
161
212
  ]
162
213
 
163
- def list_datasets(self) -> Dict[str, Union[str, List[str]]]:
164
- """
165
- Lists available datasets in your repo.
166
- :return: { datasets_ids }
214
+ @property
215
+ def jobs(
216
+ self,
217
+ ) -> List[AsyncJob]:
218
+ """Lists all jobs, see NucleusClinet.list_jobs(...) for advanced options
219
+
220
+ Returns:
221
+ List of all AsyncJobs
167
222
  """
223
+ return self.list_jobs()
224
+
225
+ @deprecated(msg="Use the NucleusClient.models property in the future.")
226
+ def list_models(self) -> List[Model]:
227
+ return self.models
228
+
229
+ @deprecated(msg="Use the NucleusClient.datasets property in the future.")
230
+ def list_datasets(self) -> Dict[str, Union[str, List[str]]]:
168
231
  return self.make_request({}, "dataset/", requests.get)
169
232
 
170
233
  def list_jobs(
171
234
  self, show_completed=None, date_limit=None
172
235
  ) -> List[AsyncJob]:
236
+ """Fetches all of your running jobs in Nucleus.
237
+
238
+ Parameters:
239
+ show_completed: Whether to fetch completed and errored jobs or just
240
+ running jobs. Default behavior is False.
241
+ date_limit: Only fetch jobs that were started after this date. Default
242
+ behavior is 2 weeks prior to the current date.
243
+
244
+ Returns:
245
+ List[:class:`AsyncJob`]: List of running asynchronous jobs
246
+ associated with the client API key.
173
247
  """
174
- Lists jobs for user.
175
- :return: jobs
176
- """
248
+ # TODO: What type is date_limit? Use pydantic ...
177
249
  payload = {show_completed: show_completed, date_limit: date_limit}
178
250
  job_objects = self.make_request(payload, "jobs/", requests.get)
179
251
  return [
@@ -187,42 +259,47 @@ class NucleusClient:
187
259
  for job in job_objects
188
260
  ]
189
261
 
262
+ @deprecated(msg="Prefer using Dataset.items")
190
263
  def get_dataset_items(self, dataset_id) -> List[DatasetItem]:
191
- """
192
- Gets all the dataset items inside your repo as a json blob.
193
- :return [ DatasetItem ]
194
- """
195
- response = self.make_request(
196
- {}, f"dataset/{dataset_id}/datasetItems", requests.get
197
- )
198
- dataset_items = response.get("dataset_items", None)
199
- error = response.get("error", None)
200
- constructed_dataset_items = []
201
- if dataset_items:
202
- for item in dataset_items:
203
- image_url = item.get("original_image_url")
204
- metadata = item.get("metadata", None)
205
- ref_id = item.get("ref_id", None)
206
- dataset_item = DatasetItem(image_url, ref_id, metadata)
207
- constructed_dataset_items.append(dataset_item)
208
- elif error:
209
- raise DatasetItemRetrievalError(message=error)
210
-
211
- return constructed_dataset_items
264
+ dataset = self.get_dataset(dataset_id)
265
+ return dataset.items
212
266
 
213
267
  def get_dataset(self, dataset_id: str) -> Dataset:
214
- """
215
- Fetches a dataset for given id
216
- :param dataset_id: internally controlled dataset_id
217
- :return: dataset
268
+ """Fetches a dataset by its ID.
269
+
270
+ Parameters:
271
+ dataset_id: The ID of the dataset to fetch.
272
+
273
+ Returns:
274
+ :class:`Dataset`: The Nucleus dataset as an object.
218
275
  """
219
276
  return Dataset(dataset_id, self)
220
277
 
221
- def get_model(self, model_id: str) -> Model:
278
+ def get_job(self, job_id: str) -> AsyncJob:
279
+ """Fetches a dataset by its ID.
280
+
281
+ Parameters:
282
+ job_id: The ID of the job to fetch.
283
+
284
+ Returns:
285
+ :class:`AsyncJob`: The Nucleus async job as an object.
222
286
  """
223
- Fetched a model for a given id
224
- :param model_id: internally controlled dataset_id
225
- :return: model
287
+ payload = self.make_request(
288
+ payload={},
289
+ route=f"job/{job_id}/info",
290
+ requests_command=requests.get,
291
+ )
292
+ return AsyncJob.from_json(payload=payload, client=self)
293
+
294
+ def get_model(self, model_id: str) -> Model:
295
+ """Fetches a model by its ID.
296
+
297
+ Parameters:
298
+ model_id: Nucleus-generated model ID (starts with ``prj_``). This can
299
+ be retrieved via :meth:`list_models` or a Nucleus dashboard URL.
300
+
301
+ Returns:
302
+ :class:`Model`: The Nucleus model as an object.
226
303
  """
227
304
  payload = self.make_request(
228
305
  payload={},
@@ -231,22 +308,16 @@ class NucleusClient:
231
308
  )
232
309
  return Model.from_json(payload=payload, client=self)
233
310
 
311
+ @deprecated(
312
+ "Model runs have been deprecated and will be removed. Use a Model instead"
313
+ )
234
314
  def get_model_run(self, model_run_id: str, dataset_id: str) -> ModelRun:
235
- """
236
- Fetches a model_run for given id
237
- :param model_run_id: internally controlled model_run_id
238
- :param dataset_id: the dataset id which may determine the prediction schema
239
- for this model run if present on the dataset.
240
- :return: model_run
241
- """
242
315
  return ModelRun(model_run_id, dataset_id, self)
243
316
 
317
+ @deprecated(
318
+ "Model runs have been deprecated and will be removed. Use a Model instead"
319
+ )
244
320
  def delete_model_run(self, model_run_id: str):
245
- """
246
- Fetches a model_run for given id
247
- :param model_run_id: internally controlled model_run_id
248
- :return: model_run
249
- """
250
321
  return self.make_request(
251
322
  {}, f"modelRun/{model_run_id}", requests.delete
252
323
  )
@@ -254,12 +325,26 @@ class NucleusClient:
254
325
  def create_dataset_from_project(
255
326
  self, project_id: str, last_n_tasks: int = None, name: str = None
256
327
  ) -> Dataset:
257
- """
258
- Creates a new dataset based on payload params:
259
- name -- A human-readable name of the dataset.
260
- Returns a response with internal id and name for a new dataset.
261
- :param payload: { "name": str }
262
- :return: new Dataset object
328
+ """Create a new dataset from an existing Scale or Rapid project.
329
+
330
+ If you already have Annotation, SegmentAnnotation, VideoAnnotation,
331
+ Categorization, PolygonAnnotation, ImageAnnotation, DocumentTranscription,
332
+ LidarLinking, LidarAnnotation, or VideoboxAnnotation projects with Scale,
333
+ use this endpoint to import your project directly into Nucleus.
334
+
335
+ This endpoint is asynchronous because there can be delays when the
336
+ number of tasks is larger than 1000. As a result, the endpoint returns
337
+ an instance of :class:`AsyncJob`.
338
+
339
+ Parameters:
340
+ project_id: The ID of the Scale/Rapid project (retrievable from URL).
341
+ last_n_tasks: If supplied, only pull in this number of the most recent
342
+ tasks. By default the endpoint will pull in all eligible tasks.
343
+ name: The name for your new Nucleus dataset. By default the endpoint
344
+ will use the project's name.
345
+
346
+ Returns:
347
+ :class:`Dataset`: The newly created Nucleus dataset as an object.
263
348
  """
264
349
  payload = {"project_id": project_id}
265
350
  if last_n_tasks:
@@ -272,20 +357,51 @@ class NucleusClient:
272
357
  def create_dataset(
273
358
  self,
274
359
  name: str,
360
+ is_scene: Optional[bool] = None,
275
361
  item_metadata_schema: Optional[Dict] = None,
276
362
  annotation_metadata_schema: Optional[Dict] = None,
277
363
  ) -> Dataset:
278
364
  """
279
- Creates a new dataset:
280
- Returns a response with internal id and name for a new dataset.
281
- :param name -- A human-readable name of the dataset.
282
- :param item_metadata_schema -- optional dictionary to define item metadata schema
283
- :param annotation_metadata_schema -- optional dictionary to define annotation metadata schema
284
- :return: new Dataset object
285
- """
365
+ Creates a new, empty dataset.
366
+
367
+ Make sure that the dataset is created for the data type you would like to support.
368
+ Be sure to set the ``is_scene`` parameter correctly.
369
+
370
+ Parameters:
371
+ name: A human-readable name for the dataset.
372
+ is_scene: Whether the dataset contains strictly :class:`scenes
373
+ <LidarScene>` or :class:`items <DatasetItem>`. This value is immutable.
374
+ Default is False (dataset of items).
375
+ item_metadata_schema: Dict defining item-level metadata schema. See below.
376
+ annotation_metadata_schema: Dict defining annotation-level metadata schema.
377
+
378
+ Metadata schemas must be structured as follows::
379
+
380
+ {
381
+ "field_name": {
382
+ "type": "category" | "number" | "text"
383
+ "choices": List[str] | None
384
+ "description": str | None
385
+ },
386
+ ...
387
+ }
388
+
389
+ Returns:
390
+ :class:`Dataset`: The newly created Nucleus dataset as an object.
391
+ """
392
+ if is_scene is None:
393
+ warnings.warn(
394
+ "The default create_dataset('dataset_name', ...) method without the is_scene parameter will be "
395
+ "deprecated soon in favor of providing the is_scene parameter explicitly. "
396
+ "Please make sure to create a dataset with either create_dataset('dataset_name', is_scene=False, ...) "
397
+ "to upload DatasetItems or create_dataset('dataset_name', is_scene=True, ...) to upload LidarScenes.",
398
+ DeprecationWarning,
399
+ )
400
+ is_scene = False
286
401
  response = self.make_request(
287
402
  {
288
403
  NAME_KEY: name,
404
+ DATASET_IS_SCENE_KEY: is_scene,
289
405
  ANNOTATION_METADATA_SCHEMA_KEY: annotation_metadata_schema,
290
406
  ITEM_METADATA_SCHEMA_KEY: item_metadata_schema,
291
407
  },
@@ -295,307 +411,55 @@ class NucleusClient:
295
411
 
296
412
  def delete_dataset(self, dataset_id: str) -> dict:
297
413
  """
298
- Deletes a private dataset based on datasetId.
299
- Returns an empty payload where response status `200` indicates
300
- the dataset has been successfully deleted.
301
- :param payload: { "name": str }
302
- :return: { "dataset_id": str, "name": str }
414
+ Deletes a dataset by ID.
415
+
416
+ All items, annotations, and predictions associated with the dataset will
417
+ be deleted as well.
418
+
419
+ Parameters:
420
+ dataset_id: The ID of the dataset to delete.
421
+
422
+ Returns:
423
+ Payload to indicate deletion invocation.
303
424
  """
304
425
  return self.make_request({}, f"dataset/{dataset_id}", requests.delete)
305
426
 
306
- @sanitize_string_args
427
+ @deprecated("Use Dataset.delete_item instead.")
307
428
  def delete_dataset_item(self, dataset_id: str, reference_id) -> dict:
308
- """
309
- Deletes a private dataset based on datasetId.
310
- Returns an empty payload where response status `200` indicates
311
- the dataset has been successfully deleted.
312
- :param payload: { "name": str }
313
- :return: { "dataset_id": str, "name": str }
314
- """
315
- return self.make_request(
316
- {},
317
- f"dataset/{dataset_id}/refloc/{reference_id}",
318
- requests.delete,
319
- )
429
+ dataset = self.get_dataset(dataset_id)
430
+ return dataset.delete_item(reference_id)
320
431
 
432
+ @deprecated("Use Dataset.append instead.")
321
433
  def populate_dataset(
322
434
  self,
323
435
  dataset_id: str,
324
436
  dataset_items: List[DatasetItem],
325
- batch_size: int = 100,
437
+ batch_size: int = 20,
326
438
  update: bool = False,
327
439
  ):
328
- """
329
- Appends images to a dataset with given dataset_id.
330
- Overwrites images on collision if updated.
331
- :param dataset_id: id of a dataset
332
- :param payload: { "items": List[DatasetItem], "update": bool }
333
- :param local: flag if images are stored locally
334
- :param batch_size: size of the batch for long payload
335
- :return:
336
- {
337
- "dataset_id: str,
338
- "new_items": int,
339
- "updated_items": int,
340
- "ignored_items": int,
341
- "upload_errors": int
342
- }
343
- """
344
- local_items = []
345
- remote_items = []
346
-
347
- # Check local files exist before sending requests
348
- for item in dataset_items:
349
- if item.local:
350
- if not item.local_file_exists():
351
- raise NotFoundError()
352
- local_items.append(item)
353
- else:
354
- remote_items.append(item)
355
-
356
- local_batches = [
357
- local_items[i : i + batch_size]
358
- for i in range(0, len(local_items), batch_size)
359
- ]
360
-
361
- remote_batches = [
362
- remote_items[i : i + batch_size]
363
- for i in range(0, len(remote_items), batch_size)
364
- ]
365
-
366
- agg_response = UploadResponse(json={DATASET_ID_KEY: dataset_id})
367
-
368
- async_responses: List[Any] = []
369
-
370
- if local_batches:
371
- tqdm_local_batches = self.tqdm_bar(
372
- local_batches, desc="Local file batches"
373
- )
374
-
375
- for batch in tqdm_local_batches:
376
- payload = construct_append_payload(batch, update)
377
- responses = self._process_append_requests_local(
378
- dataset_id, payload, update
379
- )
380
- async_responses.extend(responses)
381
-
382
- if remote_batches:
383
- tqdm_remote_batches = self.tqdm_bar(
384
- remote_batches, desc="Remote file batches"
385
- )
386
- for batch in tqdm_remote_batches:
387
- payload = construct_append_payload(batch, update)
388
- responses = self._process_append_requests(
389
- dataset_id=dataset_id,
390
- payload=payload,
391
- update=update,
392
- batch_size=batch_size,
393
- )
394
- async_responses.extend(responses)
395
-
396
- for response in async_responses:
397
- agg_response.update_response(response)
398
-
399
- return agg_response
400
-
401
- def _process_append_requests_local(
402
- self,
403
- dataset_id: str,
404
- payload: dict,
405
- update: bool, # TODO: understand how to pass this in.
406
- local_batch_size: int = 10,
407
- ):
408
- def get_files(batch):
409
- for item in batch:
410
- item[UPDATE_KEY] = update
411
- request_payload = [
412
- (
413
- ITEMS_KEY,
414
- (
415
- None,
416
- json.dumps(batch, allow_nan=False),
417
- "application/json",
418
- ),
419
- )
420
- ]
421
- for item in batch:
422
- image = open( # pylint: disable=R1732
423
- item.get(IMAGE_URL_KEY), "rb" # pylint: disable=R1732
424
- ) # pylint: disable=R1732
425
- img_name = os.path.basename(image.name)
426
- img_type = (
427
- f"image/{os.path.splitext(image.name)[1].strip('.')}"
428
- )
429
- request_payload.append(
430
- (IMAGE_KEY, (img_name, image, img_type))
431
- )
432
- return request_payload
433
-
434
- items = payload[ITEMS_KEY]
435
- responses: List[Any] = []
436
- files_per_request = []
437
- payload_items = []
438
- for i in range(0, len(items), local_batch_size):
439
- batch = items[i : i + local_batch_size]
440
- files_per_request.append(get_files(batch))
441
- payload_items.append(batch)
442
-
443
- future = self.make_many_files_requests_asynchronously(
444
- files_per_request,
445
- f"dataset/{dataset_id}/append",
440
+ dataset = self.get_dataset(dataset_id)
441
+ return dataset.append(
442
+ dataset_items, batch_size=batch_size, update=update
446
443
  )
447
444
 
448
- try:
449
- loop = asyncio.get_event_loop()
450
- except RuntimeError: # no event loop running:
451
- loop = asyncio.new_event_loop()
452
- responses = loop.run_until_complete(future)
453
- else:
454
- nest_asyncio.apply(loop)
455
- return loop.run_until_complete(future)
456
-
457
- def close_files(request_items):
458
- for item in request_items:
459
- # file buffer in location [1][1]
460
- if item[0] == IMAGE_KEY:
461
- item[1][1].close()
462
-
463
- # don't forget to close all open files
464
- for p in files_per_request:
465
- close_files(p)
466
-
467
- return responses
468
-
469
- async def make_many_files_requests_asynchronously(
470
- self, files_per_request, route
471
- ):
472
- """
473
- Makes an async post request with files to a Nucleus endpoint.
474
-
475
- :param files_per_request: A list of lists of tuples (name, (filename, file_pointer, content_type))
476
- name will become the name by which the multer can build an array.
477
- :param route: route for the request
478
- :return: awaitable list(response)
479
- """
480
- async with aiohttp.ClientSession() as session:
481
- tasks = [
482
- asyncio.ensure_future(
483
- self._make_files_request(
484
- files=files, route=route, session=session
485
- )
486
- )
487
- for files in files_per_request
488
- ]
489
- return await asyncio.gather(*tasks)
490
-
491
- async def _make_files_request(
492
- self,
493
- files,
494
- route: str,
495
- session: aiohttp.ClientSession,
496
- ):
497
- """
498
- Makes an async post request with files to a Nucleus endpoint.
499
-
500
- :param files: A list of tuples (name, (filename, file_pointer, file_type))
501
- :param route: route for the request
502
- :param session: Session to use for post.
503
- :return: response
504
- """
505
- endpoint = f"{self.endpoint}/{route}"
506
-
507
- logger.info("Posting to %s", endpoint)
508
-
509
- form = aiohttp.FormData()
510
-
511
- for file in files:
512
- form.add_field(
513
- name=file[0],
514
- filename=file[1][0],
515
- value=file[1][1],
516
- content_type=file[1][2],
517
- )
518
-
519
- for sleep_time in RetryStrategy.sleep_times + [-1]:
520
- async with session.post(
521
- endpoint,
522
- data=form,
523
- auth=aiohttp.BasicAuth(self.api_key, ""),
524
- timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
525
- ) as response:
526
- logger.info(
527
- "API request has response code %s", response.status
528
- )
529
-
530
- try:
531
- data = await response.json()
532
- except aiohttp.client_exceptions.ContentTypeError:
533
- # In case of 404, the server returns text
534
- data = await response.text()
535
- if (
536
- response.status in RetryStrategy.statuses
537
- and sleep_time != -1
538
- ):
539
- time.sleep(sleep_time)
540
- continue
541
-
542
- if not response.ok:
543
- self.handle_bad_response(
544
- endpoint,
545
- session.post,
546
- aiohttp_response=(
547
- response.status,
548
- response.reason,
549
- data,
550
- ),
551
- )
552
-
553
- return data
554
-
555
- def _process_append_requests(
556
- self,
557
- dataset_id: str,
558
- payload: dict,
559
- update: bool,
560
- batch_size: int = 20,
561
- ):
562
- items = payload[ITEMS_KEY]
563
- payloads = [
564
- # batch_size images per request
565
- {ITEMS_KEY: items[i : i + batch_size], UPDATE_KEY: update}
566
- for i in range(0, len(items), batch_size)
567
- ]
568
-
569
- return [
570
- self.make_request(
571
- payload,
572
- f"dataset/{dataset_id}/append",
573
- )
574
- for payload in payloads
575
- ]
576
-
577
445
  def annotate_dataset(
578
446
  self,
579
447
  dataset_id: str,
580
- annotations: List[
448
+ annotations: Sequence[
581
449
  Union[
582
450
  BoxAnnotation,
583
451
  PolygonAnnotation,
584
452
  CuboidAnnotation,
585
453
  CategoryAnnotation,
454
+ MultiCategoryAnnotation,
586
455
  SegmentationAnnotation,
587
456
  ]
588
457
  ],
589
458
  update: bool,
590
459
  batch_size: int = 5000,
591
- ):
592
- """
593
- Uploads ground truth annotations for a given dataset.
594
- :param dataset_id: id of the dataset
595
- :param annotations: List[Union[BoxAnnotation, PolygonAnnotation, CuboidAnnotation, SegmentationAnnotation]]
596
- :param update: whether to update or ignore conflicting annotations
597
- :return: {"dataset_id: str, "annotations_processed": int}
598
- """
460
+ ) -> Dict[str, object]:
461
+ # TODO: deprecate in favor of Dataset.annotate invocation
462
+
599
463
  # Split payload into segmentations and Box/Polygon
600
464
  segmentations = [
601
465
  ann
@@ -622,6 +486,7 @@ class NucleusClient:
622
486
  DATASET_ID_KEY: dataset_id,
623
487
  ANNOTATIONS_PROCESSED_KEY: 0,
624
488
  ANNOTATIONS_IGNORED_KEY: 0,
489
+ ERRORS_KEY: [],
625
490
  }
626
491
 
627
492
  total_batches = len(batches) + len(semseg_batches)
@@ -644,6 +509,7 @@ class NucleusClient:
644
509
  agg_response[ANNOTATIONS_IGNORED_KEY] += response[
645
510
  ANNOTATIONS_IGNORED_KEY
646
511
  ]
512
+ agg_response[ERRORS_KEY] += response[ERRORS_KEY]
647
513
 
648
514
  for s_batch in semseg_batches:
649
515
  payload = construct_segmentation_payload(s_batch, update)
@@ -663,29 +529,33 @@ class NucleusClient:
663
529
 
664
530
  return agg_response
665
531
 
532
+ @deprecated(msg="Use Dataset.ingest_tasks instead")
666
533
  def ingest_tasks(self, dataset_id: str, payload: dict):
667
- """
668
- If you already submitted tasks to Scale for annotation this endpoint ingests your completed tasks
669
- annotated by Scale into your Nucleus Dataset.
670
- Right now we support ingestion from Videobox Annotation and 2D Box Annotation projects.
671
- :param payload: {"tasks" : List[task_ids]}
672
- :param dataset_id: id of the dataset
673
- :return: {"ingested_tasks": int, "ignored_tasks": int, "pending_tasks": int}
674
- """
675
- return self.make_request(payload, f"dataset/{dataset_id}/ingest_tasks")
534
+ dataset = self.get_dataset(dataset_id)
535
+ return dataset.ingest_tasks(payload["tasks"])
676
536
 
537
+ @deprecated(msg="Use client.create_model instead.")
677
538
  def add_model(
678
539
  self, name: str, reference_id: str, metadata: Optional[Dict] = None
679
540
  ) -> Model:
680
- """
681
- Adds a model info to your repo based on payload params:
682
- name -- A human-readable name of the model project.
683
- reference_id -- An optional user-specified identifier to reference this given model.
684
- metadata -- An arbitrary metadata blob for the model.
685
- :param name: A human-readable name of the model project.
686
- :param reference_id: An user-specified identifier to reference this given model.
687
- :param metadata: An optional arbitrary metadata blob for the model.
688
- :return: { "model_id": str }
541
+ return self.create_model(name, reference_id, metadata)
542
+
543
+ def create_model(
544
+ self, name: str, reference_id: str, metadata: Optional[Dict] = None
545
+ ) -> Model:
546
+ """Adds a :class:`Model` to Nucleus.
547
+
548
+ Parameters:
549
+ name: A human-readable name for the model.
550
+ reference_id: Unique, user-controlled ID for the model. This can be
551
+ used, for example, to link to an external storage of models which
552
+ may have its own id scheme.
553
+ metadata: An arbitrary dictionary of additional data about this model
554
+ that can be stored and retrieved. For example, you can store information
555
+ about the hyperparameters used in training this model.
556
+
557
+ Returns:
558
+ :class:`Model`: The newly created model as an object.
689
559
  """
690
560
  response = self.make_request(
691
561
  construct_model_creation_payload(name, reference_id, metadata),
@@ -697,31 +567,10 @@ class NucleusClient:
697
567
 
698
568
  return Model(model_id, name, reference_id, metadata, self)
699
569
 
570
+ @deprecated(
571
+ "Model runs have been deprecated and will be removed. Use a Model instead"
572
+ )
700
573
  def create_model_run(self, dataset_id: str, payload: dict) -> ModelRun:
701
- """
702
- Creates model run for dataset_id based on the given parameters specified in the payload:
703
-
704
- 'reference_id' -- The user-specified reference identifier to associate with the model.
705
- The 'model_id' field should be empty if this field is populated.
706
-
707
- 'model_id' -- The internally-controlled identifier of the model.
708
- The 'reference_id' field should be empty if this field is populated.
709
-
710
- 'name' -- An optional name for the model run.
711
-
712
- 'metadata' -- An arbitrary metadata blob for the current run.
713
-
714
- :param
715
- dataset_id: id of the dataset
716
- payload:
717
- {
718
- "reference_id": str,
719
- "model_id": str,
720
- "name": Optional[str],
721
- "metadata": Optional[Dict[str, Any]],
722
- }
723
- :return: new ModelRun object
724
- """
725
574
  response = self.make_request(
726
575
  payload, f"dataset/{dataset_id}/modelRun/create"
727
576
  )
@@ -732,32 +581,34 @@ class NucleusClient:
732
581
  response[MODEL_RUN_ID_KEY], dataset_id=dataset_id, client=self
733
582
  )
734
583
 
584
+ @deprecated("Use Dataset.upload_predictions instead.")
735
585
  def predict(
736
586
  self,
737
- model_run_id: str,
738
587
  annotations: List[
739
588
  Union[
740
589
  BoxPrediction,
741
590
  PolygonPrediction,
742
591
  CuboidPrediction,
743
592
  SegmentationPrediction,
593
+ CategoryPrediction,
744
594
  ]
745
595
  ],
746
- update: bool,
596
+ model_run_id: Optional[str] = None,
597
+ model_id: Optional[str] = None,
598
+ dataset_id: Optional[str] = None,
599
+ update: bool = False,
747
600
  batch_size: int = 5000,
748
601
  ):
749
- """
750
- Uploads model outputs as predictions for a model_run. Returns info about the upload.
751
- :param annotations: List[Union[BoxPrediction, PolygonPrediction, CuboidPrediction, SegmentationPrediction]],
752
- :param update: bool
753
- :return:
754
- {
755
- "dataset_id": str,
756
- "model_run_id": str,
757
- "predictions_processed": int,
758
- "predictions_ignored": int,
759
- }
760
- """
602
+ if model_run_id is not None:
603
+ assert model_id is None and dataset_id is None
604
+ endpoint = f"modelRun/{model_run_id}/predict"
605
+ else:
606
+ assert (
607
+ model_id is not None and dataset_id is not None
608
+ ), "Model ID and dataset ID are required if not using model run id."
609
+ endpoint = (
610
+ f"dataset/{dataset_id}/model/{model_id}/uploadPredictions"
611
+ )
761
612
  segmentations = [
762
613
  ann
763
614
  for ann in annotations
@@ -780,11 +631,9 @@ class NucleusClient:
780
631
  for i in range(0, len(other_predictions), batch_size)
781
632
  ]
782
633
 
783
- agg_response = {
784
- MODEL_RUN_ID_KEY: model_run_id,
785
- PREDICTIONS_PROCESSED_KEY: 0,
786
- PREDICTIONS_IGNORED_KEY: 0,
787
- }
634
+ errors = []
635
+ predictions_processed = 0
636
+ predictions_ignored = 0
788
637
 
789
638
  tqdm_batches = self.tqdm_bar(batches)
790
639
 
@@ -793,230 +642,129 @@ class NucleusClient:
793
642
  batch,
794
643
  update,
795
644
  )
796
- response = self.make_request(
797
- batch_payload, f"modelRun/{model_run_id}/predict"
798
- )
645
+ response = self.make_request(batch_payload, endpoint)
799
646
  if STATUS_CODE_KEY in response:
800
- agg_response[ERRORS_KEY] = response
647
+ errors.append(response)
801
648
  else:
802
- agg_response[PREDICTIONS_PROCESSED_KEY] += response[
803
- PREDICTIONS_PROCESSED_KEY
804
- ]
805
- agg_response[PREDICTIONS_IGNORED_KEY] += response[
806
- PREDICTIONS_IGNORED_KEY
807
- ]
649
+ predictions_processed += response[PREDICTIONS_PROCESSED_KEY]
650
+ predictions_ignored += response[PREDICTIONS_IGNORED_KEY]
651
+ if ERRORS_KEY in response:
652
+ errors += response[ERRORS_KEY]
808
653
 
809
654
  for s_batch in s_batches:
810
655
  payload = construct_segmentation_payload(s_batch, update)
811
- response = self.make_request(
812
- payload, f"modelRun/{model_run_id}/predict_segmentation"
813
- )
656
+ response = self.make_request(payload, endpoint)
814
657
  # pbar.update(1)
815
658
  if STATUS_CODE_KEY in response:
816
- agg_response[ERRORS_KEY] = response
659
+ errors.append(response)
817
660
  else:
818
- agg_response[PREDICTIONS_PROCESSED_KEY] += response[
819
- PREDICTIONS_PROCESSED_KEY
820
- ]
821
- agg_response[PREDICTIONS_IGNORED_KEY] += response[
822
- PREDICTIONS_IGNORED_KEY
823
- ]
661
+ predictions_processed += response[PREDICTIONS_PROCESSED_KEY]
662
+ predictions_ignored += response[PREDICTIONS_IGNORED_KEY]
824
663
 
825
- return agg_response
664
+ return {
665
+ MODEL_RUN_ID_KEY: model_run_id,
666
+ PREDICTIONS_PROCESSED_KEY: predictions_processed,
667
+ PREDICTIONS_IGNORED_KEY: predictions_ignored,
668
+ ERRORS_KEY: errors,
669
+ }
826
670
 
671
+ @deprecated(
672
+ "Model runs have been deprecated and will be removed. Use a Model instead."
673
+ )
827
674
  def commit_model_run(
828
675
  self, model_run_id: str, payload: Optional[dict] = None
829
676
  ):
830
- """
831
- Commits the model run. Starts matching algorithm defined by payload.
832
- class_agnostic -- A flag to specify if matching algorithm should be class-agnostic or not.
833
- Default value: True
834
-
835
- allowed_label_matches -- An optional list of AllowedMatch objects to specify allowed matches
836
- for ground truth and model predictions.
837
- If specified, 'class_agnostic' flag is assumed to be False
838
-
839
- Type 'AllowedMatch':
840
- {
841
- ground_truth_label: string, # A label for ground truth annotation.
842
- model_prediction_label: string, # A label for model prediction that can be matched with
843
- # corresponding ground truth label.
844
- }
845
-
846
- payload:
847
- {
848
- "class_agnostic": boolean,
849
- "allowed_label_matches": List[AllowedMatch],
850
- }
851
-
852
- :return: {"model_run_id": str}
853
- """
677
+ # TODO: deprecate ModelRun. this should be renamed to calculate_evaluation_metrics
678
+ # or completely removed in favor of Model class methods
854
679
  if payload is None:
855
680
  payload = {}
856
681
  return self.make_request(payload, f"modelRun/{model_run_id}/commit")
857
682
 
683
+ @deprecated(msg="Prefer calling Dataset.info() directly.")
858
684
  def dataset_info(self, dataset_id: str):
859
- """
860
- Returns information about existing dataset
861
- :param dataset_id: dataset id
862
- :return: dictionary of the form
863
- {
864
- 'name': str,
865
- 'length': int,
866
- 'model_run_ids': List[str],
867
- 'slice_ids': List[str]
868
- }
869
- """
870
- return self.make_request(
871
- {}, f"dataset/{dataset_id}/info", requests.get
872
- )
685
+ dataset = self.get_dataset(dataset_id)
686
+ return dataset.info()
873
687
 
688
+ @deprecated(
689
+ "Model runs have been deprecated and will be removed. Use a Model instead."
690
+ )
874
691
  def model_run_info(self, model_run_id: str):
875
- """
876
- provides information about a Model Run with given model_run_id:
877
- model_id -- Model Id corresponding to the run
878
- name -- A human-readable name of the model project.
879
- status -- Status of the Model Run.
880
- metadata -- An arbitrary metadata blob specified for the run.
881
- :return:
882
- {
883
- "model_id": str,
884
- "name": str,
885
- "status": str,
886
- "metadata": Dict[str, Any],
887
- }
888
- """
692
+ # TODO: deprecate ModelRun
889
693
  return self.make_request(
890
694
  {}, f"modelRun/{model_run_id}/info", requests.get
891
695
  )
892
696
 
697
+ @deprecated("Prefer calling Dataset.refloc instead.")
893
698
  @sanitize_string_args
894
699
  def dataitem_ref_id(self, dataset_id: str, reference_id: str):
895
- """
896
- :param dataset_id: internally controlled dataset id
897
- :param reference_id: reference_id of a dataset_item
898
- :return:
899
- """
700
+ # TODO: deprecate in favor of Dataset.refloc invocation
900
701
  return self.make_request(
901
702
  {}, f"dataset/{dataset_id}/refloc/{reference_id}", requests.get
902
703
  )
903
704
 
705
+ @deprecated("Prefer calling Dataset.predictions_refloc instead.")
904
706
  @sanitize_string_args
905
- def predictions_ref_id(self, model_run_id: str, ref_id: str):
906
- """
907
- Returns Model Run info For Dataset Item by model_run_id and item reference_id.
908
- :param model_run_id: id of the model run.
909
- :param reference_id: reference_id of a dataset item.
910
- :return:
911
- {
912
- "annotations": List[Union[BoxPrediction, PolygonPrediction, CuboidPrediction, SegmentationPrediction]],
913
- }
914
- """
915
- return self.make_request(
916
- {}, f"modelRun/{model_run_id}/refloc/{ref_id}", requests.get
917
- )
707
+ def predictions_ref_id(
708
+ self, model_run_id: str, ref_id: str, dataset_id: Optional[str] = None
709
+ ):
710
+ if not dataset_id:
711
+ raise RuntimeError(
712
+ "Need to pass a dataset id. Or use Dataset.predictions_refloc."
713
+ )
714
+ # TODO: deprecate ModelRun
715
+ m_run = self.get_model_run(model_run_id, dataset_id)
716
+ return m_run.refloc(ref_id)
918
717
 
718
+ @deprecated("Prefer calling Dataset.iloc instead.")
919
719
  def dataitem_iloc(self, dataset_id: str, i: int):
920
- """
921
- Returns Dataset Item info by dataset_id and absolute number of the dataset item.
922
- :param dataset_id: internally controlled dataset id
923
- :param i: absolute number of the dataset_item
924
- :return:
925
- """
720
+ # TODO: deprecate in favor of Dataset.iloc invocation
926
721
  return self.make_request(
927
722
  {}, f"dataset/{dataset_id}/iloc/{i}", requests.get
928
723
  )
929
724
 
725
+ @deprecated("Prefer calling Dataset.predictions_iloc instead.")
930
726
  def predictions_iloc(self, model_run_id: str, i: int):
931
- """
932
- Returns Model Run Info For Dataset Item by model_run_id and absolute number of an item.
933
- :param model_run_id: id of the model run.
934
- :param i: absolute number of Dataset Item for a dataset corresponding to the model run.
935
- :return:
936
- {
937
- "annotations": List[Union[BoxPrediction, PolygonPrediction, CuboidPrediction, SegmentationPrediction]],
938
- }
939
- """
727
+ # TODO: deprecate ModelRun
940
728
  return self.make_request(
941
729
  {}, f"modelRun/{model_run_id}/iloc/{i}", requests.get
942
730
  )
943
731
 
732
+ @deprecated("Prefer calling Dataset.loc instead.")
944
733
  def dataitem_loc(self, dataset_id: str, dataset_item_id: str):
945
- """
946
- Returns Dataset Item Info By dataset_item_id and dataset_id
947
- :param dataset_id: internally controlled id for the dataset.
948
- :param dataset_item_id: internally controlled id for the dataset item.
949
- :return:
950
- {
951
- "item": DatasetItem,
952
- "annotations": List[Box2DAnnotation],
953
- }
954
- """
734
+ # TODO: deprecate in favor of Dataset.loc invocation
955
735
  return self.make_request(
956
736
  {}, f"dataset/{dataset_id}/loc/{dataset_item_id}", requests.get
957
737
  )
958
738
 
739
+ @deprecated("Prefer calling Dataset.predictions_loc instead.")
959
740
  def predictions_loc(self, model_run_id: str, dataset_item_id: str):
960
- """
961
- Returns Model Run Info For Dataset Item by its id.
962
- :param model_run_id: id of the model run.
963
- :param dataset_item_id: dataset_item_id of a dataset item.
964
- :return:
965
- {
966
- "annotations": List[Union[BoxPrediction, PolygonPrediction, CuboidPrediction, SegmentationPrediction]],
967
- }
968
- """
741
+ # TODO: deprecate ModelRun
969
742
  return self.make_request(
970
743
  {}, f"modelRun/{model_run_id}/loc/{dataset_item_id}", requests.get
971
744
  )
972
745
 
746
+ @deprecated("Prefer calling Dataset.create_slice instead.")
973
747
  def create_slice(self, dataset_id: str, payload: dict) -> Slice:
974
- """
975
- Creates a slice from items already present in a dataset.
976
- The caller must exclusively use either datasetItemIds or reference_ids
977
- as a means of identifying items in the dataset.
978
-
979
- "name" -- The human-readable name of the slice.
980
- "reference_ids" -- An optional list of user-specified identifier for the items in the slice
981
-
982
- :param
983
- dataset_id: id of the dataset
984
- payload:
985
- {
986
- "name": str,
987
- "reference_ids": List[str],
988
- }
989
- :return: new Slice object
990
- """
991
- response = self.make_request(
992
- payload, f"dataset/{dataset_id}/create_slice"
993
- )
994
- return Slice(response[SLICE_ID_KEY], self)
748
+ # TODO: deprecate in favor of Dataset.create_slice
749
+ dataset = self.get_dataset(dataset_id)
750
+ return dataset.create_slice(payload["name"], payload["reference_ids"])
995
751
 
996
752
  def get_slice(self, slice_id: str) -> Slice:
997
- """
998
- Returns a slice object by specified id.
753
+ # TODO: migrate to Dataset method and deprecate
754
+ """Returns a slice object by Nucleus-generated ID.
999
755
 
1000
- :param
1001
- slice_id: id of the slice
1002
- :return: a Slice object
756
+ Parameters:
757
+ slice_id: Nucleus-generated slice ID (starts with ``slc_``). This can
758
+ be retrieved via :meth:`Dataset.slices` or a Nucleus dashboard URL.
759
+
760
+ Returns:
761
+ :class:`Slice`: The Nucleus slice as an object.
1003
762
  """
1004
763
  return Slice(slice_id, self)
1005
764
 
765
+ @deprecated("Prefer calling Slice.info instead.")
1006
766
  def slice_info(self, slice_id: str) -> dict:
1007
- """
1008
- This endpoint provides information about specified slice.
1009
-
1010
- :param
1011
- slice_id: id of the slice
1012
-
1013
- :return:
1014
- {
1015
- "name": str,
1016
- "dataset_id": str,
1017
- "reference_ids": List[str],
1018
- }
1019
- """
767
+ # TODO: deprecate in favor of Slice.info
1020
768
  response = self.make_request(
1021
769
  {},
1022
770
  f"slice/{slice_id}",
@@ -1025,14 +773,15 @@ class NucleusClient:
1025
773
  return response
1026
774
 
1027
775
  def delete_slice(self, slice_id: str) -> dict:
1028
- """
1029
- This endpoint deletes specified slice.
776
+ # TODO: migrate to Dataset method and deprecate
777
+ """Deletes slice from Nucleus.
1030
778
 
1031
- :param
1032
- slice_id: id of the slice
779
+ Parameters:
780
+ slice_id: Nucleus-generated slice ID (starts with ``slc_``). This can
781
+ be retrieved via :meth:`Dataset.slices` or a Nucleus dashboard URL.
1033
782
 
1034
- :return:
1035
- {}
783
+ Returns:
784
+ Empty payload response.
1036
785
  """
1037
786
  response = self.make_request(
1038
787
  {},
@@ -1041,45 +790,29 @@ class NucleusClient:
1041
790
  )
1042
791
  return response
1043
792
 
793
+ @deprecated("Prefer calling Dataset.delete_annotations instead.")
1044
794
  def delete_annotations(
1045
795
  self, dataset_id: str, reference_ids: list = None, keep_history=False
1046
- ) -> dict:
1047
- """
1048
- This endpoint deletes annotations.
1049
-
1050
- :param
1051
- slice_id: id of the slice
1052
-
1053
- :return:
1054
- {}
1055
- """
1056
- payload = {KEEP_HISTORY_KEY: keep_history}
1057
- if reference_ids:
1058
- payload[REFERENCE_IDS_KEY] = reference_ids
1059
- response = self.make_request(
1060
- payload,
1061
- f"annotation/{dataset_id}",
1062
- requests_command=requests.delete,
1063
- )
1064
- return response
796
+ ) -> AsyncJob:
797
+ dataset = self.get_dataset(dataset_id)
798
+ return dataset.delete_annotations(reference_ids, keep_history)
1065
799
 
1066
800
  def append_to_slice(
1067
801
  self,
1068
802
  slice_id: str,
1069
803
  reference_ids: List[str],
1070
804
  ) -> dict:
1071
- """
1072
- Appends to a slice from items already present in a dataset.
1073
- The caller must exclusively use either datasetItemIds or reference_ids
1074
- as a means of identifying items in the dataset.
805
+ # TODO: migrate to Slice method and deprecate
806
+ """Appends dataset items to an existing slice.
1075
807
 
1076
- :param
1077
- reference_ids: List[str],
808
+ Parameters:
809
+ slice_id: Nucleus-generated slice ID (starts with ``slc_``). This can
810
+ be retrieved via :meth:`Dataset.slices` or a Nucleus dashboard URL.
811
+ reference_ids: List of user-defined reference IDs of the dataset items
812
+ to append to the slice.
1078
813
 
1079
- :return:
1080
- {
1081
- "slice_id": str,
1082
- }
814
+ Returns:
815
+ Empty payload response.
1083
816
  """
1084
817
 
1085
818
  response = self.make_request(
@@ -1087,12 +820,8 @@ class NucleusClient:
1087
820
  )
1088
821
  return response
1089
822
 
1090
- def list_autotags(self, dataset_id: str) -> List[str]:
1091
- """
1092
- Fetches a list of autotags for a given dataset id
1093
- :param dataset_id: internally controlled dataset_id
1094
- :return: List[str] representing autotag_ids
1095
- """
823
+ def list_autotags(self, dataset_id: str) -> List[dict]:
824
+ # TODO: deprecate in favor of Dataset.list_autotags invocation
1096
825
  response = self.make_request(
1097
826
  {},
1098
827
  f"{dataset_id}/list_autotags",
@@ -1101,25 +830,27 @@ class NucleusClient:
1101
830
  return response[AUTOTAGS_KEY] if AUTOTAGS_KEY in response else response
1102
831
 
1103
832
  def delete_autotag(self, autotag_id: str) -> dict:
1104
- """
1105
- Deletes an autotag based on autotagId.
1106
- Returns an empty payload where response status `200` indicates
1107
- the autotag has been successfully deleted.
1108
- :param autotag_id: id of the autotag to delete.
1109
- :return: {}
833
+ # TODO: migrate to Dataset method (use autotag name, not id) and deprecate
834
+ """Deletes an autotag by ID.
835
+
836
+ Parameters:
837
+ autotag_id: Nucleus-generated autotag ID (starts with ``tag_``). This can
838
+ be retrieved via :meth:`list_autotags` or a Nucleus dashboard URL.
839
+
840
+ Returns:
841
+ Empty payload response.
1110
842
  """
1111
843
  return self.make_request({}, f"autotag/{autotag_id}", requests.delete)
1112
844
 
1113
845
  def delete_model(self, model_id: str) -> dict:
1114
- """
1115
- This endpoint deletes the specified model, along with all
1116
- associated model_runs.
846
+ """Deletes a model by ID.
1117
847
 
1118
- :param
1119
- model_id: id of the model_run to delete.
848
+ Parameters:
849
+ model_id: Nucleus-generated model ID (starts with ``prj_``). This can
850
+ be retrieved via :meth:`list_models` or a Nucleus dashboard URL.
1120
851
 
1121
- :return:
1122
- {}
852
+ Returns:
853
+ Empty payload response.
1123
854
  """
1124
855
  response = self.make_request(
1125
856
  {},
@@ -1128,100 +859,95 @@ class NucleusClient:
1128
859
  )
1129
860
  return response
1130
861
 
862
+ @deprecated("Prefer calling Dataset.create_custom_index instead.")
1131
863
  def create_custom_index(
1132
864
  self, dataset_id: str, embeddings_urls: list, embedding_dim: int
1133
865
  ):
1134
- """
1135
- Creates a custom index for a given dataset, which will then be used
1136
- for autotag and similarity search.
1137
-
1138
- :param
1139
- dataset_id: id of dataset that the custom index is being added to.
1140
- embeddings_urls: list of urls, each of which being a json mapping reference_id -> embedding vector
1141
- embedding_dim: the dimension of the embedding vectors, must be consistent for all embedding vectors in the index.
1142
- """
1143
- return self.make_request(
1144
- {
1145
- EMBEDDINGS_URL_KEY: embeddings_urls,
1146
- EMBEDDING_DIMENSION_KEY: embedding_dim,
1147
- },
1148
- f"indexing/{dataset_id}",
1149
- requests_command=requests.post,
866
+ # TODO: deprecate in favor of Dataset.create_custom_index invocation
867
+ dataset = self.get_dataset(dataset_id)
868
+ return dataset.create_custom_index(
869
+ embeddings_urls=embeddings_urls, embedding_dim=embedding_dim
1150
870
  )
1151
871
 
872
+ @deprecated("Prefer calling Dataset.delete_custom_index instead.")
1152
873
  def delete_custom_index(self, dataset_id: str):
874
+ # TODO: deprecate in favor of Dataset.delete_custom_index invocation
1153
875
  return self.make_request(
1154
876
  {},
1155
877
  f"indexing/{dataset_id}",
1156
878
  requests_command=requests.delete,
1157
879
  )
1158
880
 
881
+ @deprecated("Prefer calling Dataset.set_continuous_indexing instead.")
1159
882
  def set_continuous_indexing(self, dataset_id: str, enable: bool = True):
1160
- """
1161
- Sets continuous indexing for a given dataset, which will automatically generate embeddings whenever
1162
- new images are uploaded. This endpoint is currently only enabled for enterprise customers.
1163
- Please reach out to nucleus@scale.com if you wish to learn more.
1164
-
1165
- :param
1166
- dataset_id: id of dataset that continuous indexing is being toggled for
1167
- enable: boolean, sets whether we are enabling or disabling continuous indexing. The default behavior is to enable.
1168
- """
883
+ # TODO: deprecate in favor of Dataset.set_continuous_indexing invocation
1169
884
  return self.make_request(
1170
885
  {INDEX_CONTINUOUS_ENABLE_KEY: enable},
1171
886
  f"indexing/{dataset_id}/setContinuous",
1172
887
  requests_command=requests.post,
1173
888
  )
1174
889
 
890
+ @deprecated("Prefer calling Dataset.create_image_index instead.")
1175
891
  def create_image_index(self, dataset_id: str):
1176
- """
1177
- Starts generating embeddings for images that don't have embeddings in a given dataset. These embeddings will
1178
- be used for autotag and similarity search. This endpoint is currently only enabled for enterprise customers.
1179
- Please reach out to nucleus@scale.com if you wish to learn more.
1180
-
1181
- :param
1182
- dataset_id: id of dataset for generating embeddings on.
1183
- """
892
+ # TODO: deprecate in favor of Dataset.create_image_index invocation
1184
893
  return self.make_request(
1185
894
  {},
1186
895
  f"indexing/{dataset_id}/internal/image",
1187
896
  requests_command=requests.post,
1188
897
  )
1189
898
 
1190
- def make_request(
1191
- self, payload: dict, route: str, requests_command=requests.post
1192
- ) -> dict:
1193
- """
1194
- Makes a request to Nucleus endpoint and logs a warning if not
1195
- successful.
899
+ @deprecated("Prefer calling Dataset.create_object_index instead.")
900
+ def create_object_index(
901
+ self, dataset_id: str, model_run_id: str, gt_only: bool
902
+ ):
903
+ # TODO: deprecate in favor of Dataset.create_object_index invocation
904
+ payload: Dict[str, Union[str, bool]] = {}
905
+ if model_run_id:
906
+ payload["model_run_id"] = model_run_id
907
+ elif gt_only:
908
+ payload["ingest_gt_only"] = True
909
+ return self.make_request(
910
+ payload,
911
+ f"indexing/{dataset_id}/internal/object",
912
+ requests_command=requests.post,
913
+ )
1196
914
 
1197
- :param payload: given payload
1198
- :param route: route for the request
1199
- :param requests_command: requests.post, requests.get, requests.delete
1200
- :return: response JSON
1201
- """
1202
- endpoint = f"{self.endpoint}/{route}"
915
+ def delete(self, route: str):
916
+ return self._connection.delete(route)
1203
917
 
1204
- logger.info("Posting to %s", endpoint)
918
+ def get(self, route: str):
919
+ return self._connection.get(route)
1205
920
 
1206
- for retry_wait_time in RetryStrategy.sleep_times:
1207
- response = requests_command(
1208
- endpoint,
1209
- json=payload,
1210
- headers={"Content-Type": "application/json"},
1211
- auth=(self.api_key, ""),
1212
- timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
1213
- )
1214
- logger.info(
1215
- "API request has response code %s", response.status_code
1216
- )
1217
- if response.status_code not in RetryStrategy.statuses:
1218
- break
1219
- time.sleep(retry_wait_time)
921
+ def post(self, payload: dict, route: str):
922
+ return self._connection.post(payload, route)
923
+
924
+ def put(self, payload: dict, route: str):
925
+ return self._connection.put(payload, route)
926
+
927
+ # TODO: Fix return type, can be a list as well. Brings on a lot of mypy errors ...
928
+ def make_request(
929
+ self,
930
+ payload: Optional[dict],
931
+ route: str,
932
+ requests_command=requests.post,
933
+ ) -> dict:
934
+ """Makes a request to a Nucleus API endpoint.
935
+
936
+ Logs a warning if not successful.
1220
937
 
1221
- if not response.ok:
1222
- self.handle_bad_response(endpoint, requests_command, response)
938
+ Parameters:
939
+ payload: Given request payload.
940
+ route: Route for the request.
941
+ requests_command: ``requests.post``, ``requests.get``, or ``requests.delete``.
1223
942
 
1224
- return response.json()
943
+ Returns:
944
+ Response payload as JSON dict.
945
+ """
946
+ if payload is None:
947
+ payload = {}
948
+ if requests_command is requests.get:
949
+ payload = None
950
+ return self._connection.make_request(payload, route, requests_command) # type: ignore
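make_request now simply delegates to the new Connection helper; a hedged sketch of calling it directly against a route that already appears in this file, assuming `client` is a NucleusClient (the dataset ID is a placeholder):

import requests

# For GET requests the payload is discarded (forced to None) before the
# request is handed to the Connection object.
info = client.make_request({}, "dataset/ds_sample123/info", requests.get)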
1225
951
 
1226
952
  def handle_bad_response(
1227
953
  self,
@@ -1230,6 +956,16 @@ class NucleusClient:
1230
956
  requests_response=None,
1231
957
  aiohttp_response=None,
1232
958
  ):
1233
- raise NucleusAPIError(
959
+ self._connection.handle_bad_response(
1234
960
  endpoint, requests_command, requests_response, aiohttp_response
1235
961
  )
962
+
963
+ def _set_api_key(self, api_key):
964
+ """Fetch API key from environment variable NUCLEUS_API_KEY if not set"""
965
+ api_key = (
966
+ api_key if api_key else os.environ.get("NUCLEUS_API_KEY", None)
967
+ )
968
+ if api_key is None:
969
+ raise NoAPIKey()
970
+
971
+ return api_key
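With _set_api_key in place, the key may come from the environment instead of the constructor; a small sketch (the key value is a placeholder):

import os
import nucleus

os.environ["NUCLEUS_API_KEY"] = "test_key_123"  # placeholder key

client = nucleus.NucleusClient()  # api_key=None falls back to NUCLEUS_API_KEY
# If neither the argument nor the environment variable is set, NoAPIKey is raised.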