scale-nucleus 0.1.1__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
nucleus/__init__.py CHANGED
@@ -50,89 +50,83 @@ confidence | float | The optional confidence level of this annotation
 geometry | dict | Representation of the bounding box in the Box2DGeometry format.\n
 metadata | dict | An arbitrary metadata blob for the annotation.\n
 """
-__version__ = "0.1.0"
-
+import asyncio
 import json
 import logging
-import warnings
 import os
-from typing import List, Union, Dict, Callable, Any, Optional
+from typing import Any, Dict, List, Optional, Union

+import aiohttp
+import pkg_resources
+import requests
 import tqdm
 import tqdm.notebook as tqdm_notebook

-import grequests
-import requests
-from requests.adapters import HTTPAdapter
-
-# pylint: disable=E1101
-# TODO: refactor to reduce this file to under 1000 lines.
-# pylint: disable=C0302
-from requests.packages.urllib3.util.retry import Retry
-
-from .constants import REFERENCE_IDS_KEY, DATASET_ITEM_IDS_KEY
-from .dataset import Dataset
-from .dataset_item import DatasetItem
 from .annotation import (
     BoxAnnotation,
     PolygonAnnotation,
-    SegmentationAnnotation,
     Segment,
-)
-from .prediction import (
-    BoxPrediction,
-    PolygonPrediction,
-    SegmentationPrediction,
-)
-from .model_run import ModelRun
-from .slice import Slice
-from .upload_response import UploadResponse
-from .payload_constructor import (
-    construct_append_payload,
-    construct_annotation_payload,
-    construct_model_creation_payload,
-    construct_box_predictions_payload,
-    construct_segmentation_payload,
+    SegmentationAnnotation,
+    Point,
 )
 from .constants import (
-    NUCLEUS_ENDPOINT,
+    ANNOTATION_METADATA_SCHEMA_KEY,
+    ANNOTATIONS_IGNORED_KEY,
+    ANNOTATIONS_PROCESSED_KEY,
+    AUTOTAGS_KEY,
+    DATASET_ID_KEY,
+    DATASET_ITEM_IDS_KEY,
     DEFAULT_NETWORK_TIMEOUT_SEC,
-    ERRORS_KEY,
+    EMBEDDINGS_URL_KEY,
     ERROR_ITEMS,
     ERROR_PAYLOAD,
-    ITEMS_KEY,
-    ITEM_KEY,
+    ERRORS_KEY,
     IMAGE_KEY,
     IMAGE_URL_KEY,
-    DATASET_ID_KEY,
+    ITEM_METADATA_SCHEMA_KEY,
+    ITEMS_KEY,
     MODEL_RUN_ID_KEY,
-    DATASET_ITEM_ID_KEY,
-    SLICE_ID_KEY,
-    ANNOTATIONS_PROCESSED_KEY,
-    ANNOTATIONS_IGNORED_KEY,
-    PREDICTIONS_PROCESSED_KEY,
+    NAME_KEY,
+    NUCLEUS_ENDPOINT,
     PREDICTIONS_IGNORED_KEY,
+    PREDICTIONS_PROCESSED_KEY,
+    REFERENCE_IDS_KEY,
+    SLICE_ID_KEY,
     STATUS_CODE_KEY,
-    SUCCESS_STATUS_CODES,
-    DATASET_NAME_KEY,
-    DATASET_MODEL_RUNS_KEY,
-    DATASET_SLICES_KEY,
-    DATASET_LENGTH_KEY,
-    NAME_KEY,
-    ANNOTATIONS_KEY,
-    AUTOTAGS_KEY,
-    ANNOTATION_METADATA_SCHEMA_KEY,
-    ITEM_METADATA_SCHEMA_KEY,
-    FORCE_KEY,
-    EMBEDDINGS_URL_KEY,
+    UPDATE_KEY,
 )
-from .model import Model
+from .dataset import Dataset
+from .dataset_item import DatasetItem
 from .errors import (
+    DatasetItemRetrievalError,
     ModelCreationError,
     ModelRunCreationError,
-    DatasetItemRetrievalError,
     NotFoundError,
+    NucleusAPIError,
+)
+from .model import Model
+from .model_run import ModelRun
+from .payload_constructor import (
+    construct_annotation_payload,
+    construct_append_payload,
+    construct_box_predictions_payload,
+    construct_model_creation_payload,
+    construct_segmentation_payload,
 )
+from .prediction import (
+    BoxPrediction,
+    PolygonPrediction,
+    SegmentationPrediction,
+)
+from .slice import Slice
+from .upload_response import UploadResponse
+
+# pylint: disable=E1101
+# TODO: refactor to reduce this file to under 1000 lines.
+# pylint: disable=C0302
+
+
+__version__ = pkg_resources.get_distribution("scale-nucleus").version

 logger = logging.getLogger(__name__)
 logging.basicConfig()
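
Note: __version__ is no longer hard-coded to "0.1.0"; it is read at import time from the installed distribution's metadata via pkg_resources. A quick way to check which version is actually installed (a minimal sketch; nothing beyond the public attribute is assumed):

    import nucleus

    # Resolves to the version recorded in the installed wheel's metadata,
    # e.g. "0.1.10" for this release.
    print(nucleus.__version__)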
@@ -146,15 +140,26 @@ class NucleusClient:
     Nucleus client.
     """

-    def __init__(self, api_key: str, use_notebook: bool = False):
+    def __init__(
+        self,
+        api_key: str,
+        use_notebook: bool = False,
+        endpoint: str = None,
+    ):
         self.api_key = api_key
         self.tqdm_bar = tqdm.tqdm
+        if endpoint is None:
+            self.endpoint = os.environ.get(
+                "NUCLEUS_ENDPOINT", NUCLEUS_ENDPOINT
+            )
+        else:
+            self.endpoint = endpoint
         self._use_notebook = use_notebook
         if use_notebook:
             self.tqdm_bar = tqdm_notebook.tqdm

     def __repr__(self):
-        return f"NucleusClient(api_key='{self.api_key}', use_notebook={self._use_notebook})"
+        return f"NucleusClient(api_key='{self.api_key}', use_notebook={self._use_notebook}, endpoint='{self.endpoint}')"

     def __eq__(self, other):
         if self.api_key == other.api_key:
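
Note: NucleusClient now accepts an optional endpoint override and falls back to the NUCLEUS_ENDPOINT environment variable before the packaged NUCLEUS_ENDPOINT constant. A construction sketch (the key and URL below are placeholders, not real values):

    from nucleus import NucleusClient

    # An explicit argument wins over the NUCLEUS_ENDPOINT env var,
    # which in turn wins over the library's built-in default.
    client = NucleusClient(
        api_key="YOUR_API_KEY",
        endpoint="https://nucleus.example.com/v1",  # placeholder URL
    )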
@@ -167,7 +172,7 @@ class NucleusClient:
         Lists available models in your repo.
         :return: model_ids
         """
-        model_objects = self._make_request({}, "models/", requests.get)
+        model_objects = self.make_request({}, "models/", requests.get)

         return [
             Model(
@@ -185,14 +190,14 @@ class NucleusClient:
         Lists available datasets in your repo.
         :return: { datasets_ids }
         """
-        return self._make_request({}, "dataset/", requests.get)
+        return self.make_request({}, "dataset/", requests.get)

     def get_dataset_items(self, dataset_id) -> List[DatasetItem]:
         """
         Gets all the dataset items inside your repo as a json blob.
         :return [ DatasetItem ]
         """
-        response = self._make_request(
+        response = self.make_request(
             {}, f"dataset/{dataset_id}/datasetItems", requests.get
         )
         dataset_items = response.get("dataset_items", None)
@@ -221,13 +226,15 @@ class NucleusClient:
         """
         return Dataset(dataset_id, self)

-    def get_model_run(self, model_run_id: str) -> ModelRun:
+    def get_model_run(self, model_run_id: str, dataset_id: str) -> ModelRun:
         """
         Fetches a model_run for given id
         :param model_run_id: internally controlled model_run_id
+        :param dataset_id: the dataset id which may determine the prediction schema
+            for this model run if present on the dataset.
         :return: model_run
         """
-        return ModelRun(model_run_id, self)
+        return ModelRun(model_run_id, dataset_id, self)

     def delete_model_run(self, model_run_id: str):
         """
@@ -235,7 +242,7 @@ class NucleusClient:
         :param model_run_id: internally controlled model_run_id
         :return: model_run
         """
-        return self._make_request(
+        return self.make_request(
             {}, f"modelRun/{model_run_id}", requests.delete
         )

@@ -254,7 +261,7 @@ class NucleusClient:
             payload["last_n_tasks"] = str(last_n_tasks)
         if name:
             payload["name"] = name
-        response = self._make_request(payload, "dataset/create_from_project")
+        response = self.make_request(payload, "dataset/create_from_project")
         return Dataset(response[DATASET_ID_KEY], self)

     def create_dataset(
@@ -271,7 +278,7 @@ class NucleusClient:
         :param annotation_metadata_schema -- optional dictionary to define annotation metadata schema
         :return: new Dataset object
         """
-        response = self._make_request(
+        response = self.make_request(
             {
                 NAME_KEY: name,
                 ANNOTATION_METADATA_SCHEMA_KEY: annotation_metadata_schema,
@@ -289,7 +296,7 @@ class NucleusClient:
         :param payload: { "name": str }
         :return: { "dataset_id": str, "name": str }
         """
-        return self._make_request({}, f"dataset/{dataset_id}", requests.delete)
+        return self.make_request({}, f"dataset/{dataset_id}", requests.delete)

     def delete_dataset_item(
         self, dataset_id: str, item_id: str = None, reference_id: str = None
@@ -302,11 +309,11 @@ class NucleusClient:
         :return: { "dataset_id": str, "name": str }
         """
         if item_id:
-            return self._make_request(
+            return self.make_request(
                 {}, f"dataset/{dataset_id}/{item_id}", requests.delete
             )
         else:  # Assume reference_id is provided
-            return self._make_request(
+            return self.make_request(
                 {},
                 f"dataset/{dataset_id}/refloc/{reference_id}",
                 requests.delete,
@@ -317,13 +324,13 @@ class NucleusClient:
         dataset_id: str,
         dataset_items: List[DatasetItem],
         batch_size: int = 100,
-        force: bool = False,
+        update: bool = False,
     ):
         """
         Appends images to a dataset with given dataset_id.
-        Overwrites images on collision if forced.
+        Overwrites images on collision if updated.
         :param dataset_id: id of a dataset
-        :param payload: { "items": List[DatasetItem], "force": bool }
+        :param payload: { "items": List[DatasetItem], "update": bool }
         :param local: flag if images are stored locally
         :param batch_size: size of the batch for long payload
         :return:
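
Note: the force flag is renamed to update across the append path (FORCE_KEY becomes UPDATE_KEY in request payloads). A call sketch, assuming dataset_items was built elsewhere; the enclosing method's def line falls outside this hunk, and populate_dataset below is its name in this package's source:

    # update=True overwrites items whose reference_id already exists;
    # update=False (the default) leaves existing items untouched.
    response = client.populate_dataset(
        "ds_def456",    # placeholder dataset id
        dataset_items,  # List[DatasetItem], prepared elsewhere
        update=True,
    )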
@@ -366,21 +373,24 @@ class NucleusClient:
         async_responses: List[Any] = []

         for batch in tqdm_local_batches:
-            payload = construct_append_payload(batch, force)
+            payload = construct_append_payload(batch, update)
             responses = self._process_append_requests_local(
-                dataset_id, payload, force
+                dataset_id, payload, update
             )
             async_responses.extend(responses)

         for batch in tqdm_remote_batches:
-            payload = construct_append_payload(batch, force)
+            payload = construct_append_payload(batch, update)
             responses = self._process_append_requests(
-                dataset_id, payload, force, batch_size, batch_size
+                dataset_id=dataset_id,
+                payload=payload,
+                update=update,
+                batch_size=batch_size,
             )
             async_responses.extend(responses)

         for response in async_responses:
-            agg_response.update_response(response.json())
+            agg_response.update_response(response)

         return agg_response

@@ -388,28 +398,24 @@ class NucleusClient:
         self,
         dataset_id: str,
         payload: dict,
-        update: bool,
+        update: bool,  # TODO: understand how to pass this in.
         local_batch_size: int = 10,
-        size: int = 10,
     ):
-        def error(batch_items: dict) -> UploadResponse:
-            return UploadResponse(
-                {
-                    DATASET_ID_KEY: dataset_id,
-                    ERROR_ITEMS: len(batch_items),
-                    ERROR_PAYLOAD: batch_items,
-                }
-            )
-
-        def exception_handler(request, exception):
-            logger.error(exception)
-
-        def preprocess_payload(batch):
+        def get_files(batch):
             request_payload = [
-                (ITEMS_KEY, (None, json.dumps(batch), "application/json"))
+                (
+                    ITEMS_KEY,
+                    (
+                        None,
+                        json.dumps(batch, allow_nan=False),
+                        "application/json",
+                    ),
+                )
             ]
             for item in batch:
-                image = open(item.get(IMAGE_URL_KEY), "rb")
+                image = open(  # pylint: disable=R1732
+                    item.get(IMAGE_URL_KEY), "rb"  # pylint: disable=R1732
+                )  # pylint: disable=R1732
                 img_name = os.path.basename(image.name)
                 img_type = (
                     f"image/{os.path.splitext(image.name)[1].strip('.')}"
@@ -421,27 +427,19 @@ class NucleusClient:

         items = payload[ITEMS_KEY]
         responses: List[Any] = []
-        request_payloads = []
+        files_per_request = []
         payload_items = []
         for i in range(0, len(items), local_batch_size):
             batch = items[i : i + local_batch_size]
-            batch_payload = preprocess_payload(batch)
-            request_payloads.append(batch_payload)
+            files_per_request.append(get_files(batch))
             payload_items.append(batch)

-        async_requests = [
-            self._make_grequest(
-                payload,
+        loop = asyncio.get_event_loop()
+        responses = loop.run_until_complete(
+            self.make_many_files_requests_asynchronously(
+                files_per_request,
                 f"dataset/{dataset_id}/append",
-                local=True,
             )
-            for payload in request_payloads
-        ]
-
-        async_responses = grequests.map(
-            async_requests,
-            exception_handler=exception_handler,
-            size=size,
         )

         def close_files(request_items):
@@ -451,69 +449,106 @@ class NucleusClient:
                     item[1][1].close()

         # don't forget to close all open files
-        for p in request_payloads:
+        for p in files_per_request:
             close_files(p)

-        # response object will be None if an error occurred
-        async_responses = [
-            response
-            if (response and response.status_code == 200)
-            else error(request_items)
-            for response, request_items in zip(async_responses, payload_items)
-        ]
-        responses.extend(async_responses)
-
         return responses

+    async def make_many_files_requests_asynchronously(
+        self, files_per_request, route
+    ):
+        """
+        Makes an async post request with files to a Nucleus endpoint.
+
+        :param files_per_request: A list of lists of tuples (name, (filename, file_pointer, content_type))
+            name will become the name by which the multer can build an array.
+        :param route: route for the request
+        :return: awaitable list(response)
+        """
+        async with aiohttp.ClientSession() as session:
+            tasks = [
+                asyncio.ensure_future(
+                    self._make_files_request(
+                        files=files, route=route, session=session
+                    )
+                )
+                for files in files_per_request
+            ]
+            return await asyncio.gather(*tasks)
+
+    async def _make_files_request(
+        self,
+        files,
+        route: str,
+        session: aiohttp.ClientSession,
+    ):
+        """
+        Makes an async post request with files to a Nucleus endpoint.
+
+        :param files: A list of tuples (name, (filename, file_pointer, file_type))
+        :param route: route for the request
+        :param session: Session to use for post.
+        :return: response
+        """
+        endpoint = f"{self.endpoint}/{route}"
+
+        logger.info("Posting to %s", endpoint)
+
+        form = aiohttp.FormData()
+
+        for file in files:
+            form.add_field(
+                name=file[0],
+                filename=file[1][0],
+                value=file[1][1],
+                content_type=file[1][2],
+            )
+
+        async with session.post(
+            endpoint,
+            data=form,
+            auth=aiohttp.BasicAuth(self.api_key, ""),
+            timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
+        ) as response:
+            logger.info("API request has response code %s", response.status)
+
+            try:
+                data = await response.json()
+            except aiohttp.client_exceptions.ContentTypeError:
+                # In case of 404, the server returns text
+                data = await response.text()
+
+            if not response.ok:
+                self.handle_bad_response(
+                    endpoint,
+                    session.post,
+                    aiohttp_response=(response.status, response.reason, data),
+                )
+
+            return data
+
     def _process_append_requests(
         self,
         dataset_id: str,
         payload: dict,
         update: bool,
         batch_size: int = 20,
-        size: int = 10,
     ):
-        def default_error(payload: dict) -> UploadResponse:
-            return UploadResponse(
-                {
-                    DATASET_ID_KEY: dataset_id,
-                    ERROR_ITEMS: len(payload[ITEMS_KEY]),
-                    ERROR_PAYLOAD: payload[ITEMS_KEY],
-                }
-            )
-
-        def exception_handler(request, exception):
-            logger.error(exception)
-
         items = payload[ITEMS_KEY]
         payloads = [
             # batch_size images per request
-            {ITEMS_KEY: items[i : i + batch_size], FORCE_KEY: update}
+            {ITEMS_KEY: items[i : i + batch_size], UPDATE_KEY: update}
             for i in range(0, len(items), batch_size)
         ]

-        async_requests = [
-            self._make_grequest(
+        return [
+            self.make_request(
                 payload,
                 f"dataset/{dataset_id}/append",
-                local=False,
             )
             for payload in payloads
         ]

-        async_responses = grequests.map(
-            async_requests, exception_handler=exception_handler, size=size
-        )
-
-        async_responses = [
-            response
-            if (response and response.status_code == 200)
-            else default_error(payload)
-            for response, payload in zip(async_responses, payloads)
-        ]
-
-        return async_responses
-
     def annotate_dataset(
         self,
         dataset_id: str,
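
Note: local-file uploads switch from grequests to aiohttp. Each batch becomes a coroutine posting multipart form data, the coroutines run concurrently under asyncio.gather, and _process_append_requests_local bridges back to synchronous code with loop.run_until_complete. A self-contained sketch of the same fan-out pattern (the URL is a placeholder):

    import asyncio

    import aiohttp

    async def post_all(payloads, url):
        # One shared session; each post runs concurrently.
        async with aiohttp.ClientSession() as session:

            async def post_one(payload):
                async with session.post(url, json=payload) as resp:
                    return await resp.json()

            return await asyncio.gather(*(post_one(p) for p in payloads))

    # Synchronous call site, mirroring _process_append_requests_local:
    results = asyncio.get_event_loop().run_until_complete(
        post_all([{"n": 1}, {"n": 2}], "https://nucleus.example.com/v1/echo")
    )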
@@ -566,7 +601,7 @@ class NucleusClient:
         with self.tqdm_bar(total=total_batches) as pbar:
             for batch in tqdm_batches:
                 payload = construct_annotation_payload(batch, update)
-                response = self._make_request(
+                response = self.make_request(
                     payload, f"dataset/{dataset_id}/annotate"
                 )
                 pbar.update(1)
@@ -582,7 +617,7 @@ class NucleusClient:

             for s_batch in semseg_batches:
                 payload = construct_segmentation_payload(s_batch, update)
-                response = self._make_request(
+                response = self.make_request(
                     payload, f"dataset/{dataset_id}/annotate_segmentation"
                 )
                 pbar.update(1)
@@ -607,9 +642,7 @@ class NucleusClient:
         :param dataset_id: id of the dataset
         :return: {"ingested_tasks": int, "ignored_tasks": int, "pending_tasks": int}
         """
-        return self._make_request(
-            payload, f"dataset/{dataset_id}/ingest_tasks"
-        )
+        return self.make_request(payload, f"dataset/{dataset_id}/ingest_tasks")

     def add_model(
         self, name: str, reference_id: str, metadata: Optional[Dict] = None
@@ -624,7 +657,7 @@ class NucleusClient:
         :param metadata: An optional arbitrary metadata blob for the model.
         :return: { "model_id": str }
         """
-        response = self._make_request(
+        response = self.make_request(
             construct_model_creation_payload(name, reference_id, metadata),
             "models/add",
         )
@@ -659,13 +692,15 @@ class NucleusClient:
         }
         :return: new ModelRun object
         """
-        response = self._make_request(
+        response = self.make_request(
             payload, f"dataset/{dataset_id}/modelRun/create"
         )
         if response.get(STATUS_CODE_KEY, None):
             raise ModelRunCreationError(response.get("error"))

-        return ModelRun(response[MODEL_RUN_ID_KEY], self)
+        return ModelRun(
+            response[MODEL_RUN_ID_KEY], dataset_id=dataset_id, client=self
+        )

     def predict(
         self,
@@ -723,7 +758,7 @@ class NucleusClient:
                 batch,
                 update,
             )
-            response = self._make_request(
+            response = self.make_request(
                 batch_payload, f"modelRun/{model_run_id}/predict"
             )
             if STATUS_CODE_KEY in response:
@@ -738,7 +773,7 @@ class NucleusClient:

         for s_batch in s_batches:
             payload = construct_segmentation_payload(s_batch, update)
-            response = self._make_request(
+            response = self.make_request(
                 payload, f"modelRun/{model_run_id}/predict_segmentation"
             )
             # pbar.update(1)
@@ -783,7 +818,7 @@ class NucleusClient:
         """
         if payload is None:
             payload = {}
-        return self._make_request(payload, f"modelRun/{model_run_id}/commit")
+        return self.make_request(payload, f"modelRun/{model_run_id}/commit")

     def dataset_info(self, dataset_id: str):
         """
@@ -797,7 +832,7 @@ class NucleusClient:
             'slice_ids': List[str]
         }
         """
-        return self._make_request(
+        return self.make_request(
             {}, f"dataset/{dataset_id}/info", requests.get
         )

@@ -816,7 +851,7 @@ class NucleusClient:
             "metadata": Dict[str, Any],
         }
         """
-        return self._make_request(
+        return self.make_request(
             {}, f"modelRun/{model_run_id}/info", requests.get
         )

@@ -826,7 +861,7 @@ class NucleusClient:
         :param reference_id: reference_id of a dataset_item
         :return:
         """
-        return self._make_request(
+        return self.make_request(
             {}, f"dataset/{dataset_id}/refloc/{reference_id}", requests.get
         )

@@ -840,7 +875,7 @@ class NucleusClient:
             "annotations": List[BoxPrediction],
         }
         """
-        return self._make_request(
+        return self.make_request(
             {}, f"modelRun/{model_run_id}/refloc/{ref_id}", requests.get
         )

@@ -851,7 +886,7 @@ class NucleusClient:
         :param i: absolute number of the dataset_item
         :return:
         """
-        return self._make_request(
+        return self.make_request(
             {}, f"dataset/{dataset_id}/iloc/{i}", requests.get
         )

@@ -865,7 +900,7 @@ class NucleusClient:
             "annotations": List[BoxPrediction],
         }
         """
-        return self._make_request(
+        return self.make_request(
             {}, f"modelRun/{model_run_id}/iloc/{i}", requests.get
         )

@@ -880,7 +915,7 @@ class NucleusClient:
             "annotations": List[Box2DAnnotation],
         }
         """
-        return self._make_request(
+        return self.make_request(
             {}, f"dataset/{dataset_id}/loc/{dataset_item_id}", requests.get
         )

@@ -894,7 +929,7 @@ class NucleusClient:
             "annotations": List[BoxPrediction],
         }
         """
-        return self._make_request(
+        return self.make_request(
             {}, f"modelRun/{model_run_id}/loc/{dataset_item_id}", requests.get
         )

@@ -920,7 +955,7 @@ class NucleusClient:
         }
         :return: new Slice object
         """
-        response = self._make_request(
+        response = self.make_request(
             payload, f"dataset/{dataset_id}/create_slice"
         )
         return Slice(response[SLICE_ID_KEY], self)
@@ -951,7 +986,7 @@ class NucleusClient:
             "dataset_item_ids": List[str],
         }
         """
-        response = self._make_request(
+        response = self.make_request(
             {},
             f"slice/{slice_id}",
             requests_command=requests.get,
@@ -968,7 +1003,7 @@ class NucleusClient:
         :return:
         {}
         """
-        response = self._make_request(
+        response = self.make_request(
             {},
             f"slice/{slice_id}",
             requests_command=requests.delete,
@@ -1006,9 +1041,7 @@ class NucleusClient:
         if reference_ids:
             ids_to_append[REFERENCE_IDS_KEY] = reference_ids

-        response = self._make_request(
-            ids_to_append, f"slice/{slice_id}/append"
-        )
+        response = self.make_request(ids_to_append, f"slice/{slice_id}/append")
         return response

     def list_autotags(self, dataset_id: str) -> List[str]:
@@ -1017,7 +1050,7 @@ class NucleusClient:
         :param dataset_id: internally controlled dataset_id
         :return: List[str] representing autotag_ids
         """
-        response = self._make_request(
+        response = self.make_request(
             {},
             f"{dataset_id}/list_autotags",
             requests_command=requests.get,
@@ -1035,7 +1068,7 @@ class NucleusClient:
         :return:
         {}
         """
-        response = self._make_request(
+        response = self.make_request(
             {},
             f"model/{model_id}",
             requests_command=requests.delete,
@@ -1043,82 +1076,40 @@ class NucleusClient:
         return response

     def create_custom_index(self, dataset_id: str, embeddings_url: str):
-        return self._make_request(
+        return self.make_request(
             {EMBEDDINGS_URL_KEY: embeddings_url},
             f"indexing/{dataset_id}",
             requests_command=requests.post,
         )

     def check_index_status(self, job_id: str):
-        return self._make_request(
+        return self.make_request(
             {},
             f"indexing/{job_id}",
             requests_command=requests.get,
         )

     def delete_custom_index(self, dataset_id: str):
-        return self._make_request(
+        return self.make_request(
             {},
             f"indexing/{dataset_id}",
             requests_command=requests.delete,
         )

-    def _make_grequest(
-        self,
-        payload: dict,
-        route: str,
-        session=None,
-        requests_command: Callable = grequests.post,
-        local=True,
-    ):
-        """
-        makes a grequest to Nucleus endpoint
-        :param payload: file dict for multipart-formdata
-        :param route: route for the request
-        :param session: requests.session
-        :param requests_command: grequests.post, grequests.get, grequests.delete
-        :return: An async grequest object
-        """
-        adapter = HTTPAdapter(max_retries=Retry(total=3))
-        sess = requests.Session()
-        sess.mount("https://", adapter)
-        sess.mount("http://", adapter)
-
-        endpoint = f"{NUCLEUS_ENDPOINT}/{route}"
-        logger.info("Posting to %s", endpoint)
-
-        if local:
-            post = requests_command(
-                endpoint,
-                session=sess,
-                files=payload,
-                auth=(self.api_key, ""),
-                timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
-            )
-        else:
-            post = requests_command(
-                endpoint,
-                session=sess,
-                json=payload,
-                headers={"Content-Type": "application/json"},
-                auth=(self.api_key, ""),
-                timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
-            )
-        return post
-
-    def _make_request_raw(
+    def make_request(
         self, payload: dict, route: str, requests_command=requests.post
-    ):
+    ) -> dict:
         """
-        Makes a request to Nucleus endpoint. This method returns the raw
-        requests.Response object which is useful for unit testing.
+        Makes a request to Nucleus endpoint and logs a warning if not
+        successful.

         :param payload: given payload
         :param route: route for the request
         :param requests_command: requests.post, requests.get, requests.delete
-        :return: response
+        :return: response JSON
         """
-        endpoint = f"{NUCLEUS_ENDPOINT}/{route}"
+        endpoint = f"{self.endpoint}/{route}"
+
         logger.info("Posting to %s", endpoint)

         response = requests_command(
@@ -1130,25 +1121,18 @@ class NucleusClient:
         )
         logger.info("API request has response code %s", response.status_code)

-        return response
+        if not response.ok:
+            self.handle_bad_response(endpoint, requests_command, response)

-    def _make_request(
-        self, payload: dict, route: str, requests_command=requests.post
-    ) -> dict:
-        """
-        Makes a request to Nucleus endpoint and logs a warning if not
-        successful.
+        return response.json()

-        :param payload: given payload
-        :param route: route for the request
-        :param requests_command: requests.post, requests.get, requests.delete
-        :return: response JSON
-        """
-        response = self._make_request_raw(payload, route, requests_command)
-
-        if getattr(response, "status_code") not in SUCCESS_STATUS_CODES:
-            logger.warning(response)
-
-        return (
-            response.json()
-        )  # TODO: this line fails if response has code == 404
+    def handle_bad_response(
+        self,
+        endpoint,
+        requests_command,
+        requests_response=None,
+        aiohttp_response=None,
+    ):
+        raise NucleusAPIError(
+            endpoint, requests_command, requests_response, aiohttp_response
+        )
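
Note: error handling is centralized: any non-2xx response now raises NucleusAPIError through handle_bad_response, replacing the old warning-then-response.json() flow that crashed on 404s (where the server returns text, not JSON). A usage sketch (the route and ids are illustrative):

    import requests

    from nucleus.errors import NucleusAPIError

    try:
        info = client.make_request({}, "dataset/ds_def456/info", requests.get)
    except NucleusAPIError as err:
        # Raised for any failed request, whether it came from the sync
        # requests path or the async aiohttp file-upload path.
        print(f"Nucleus request failed: {err}")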