scale-nucleus 0.1.22__py3-none-any.whl → 0.6.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. cli/client.py +14 -0
  2. cli/datasets.py +77 -0
  3. cli/helpers/__init__.py +0 -0
  4. cli/helpers/nucleus_url.py +10 -0
  5. cli/helpers/web_helper.py +40 -0
  6. cli/install_completion.py +33 -0
  7. cli/jobs.py +42 -0
  8. cli/models.py +35 -0
  9. cli/nu.py +42 -0
  10. cli/reference.py +8 -0
  11. cli/slices.py +62 -0
  12. cli/tests.py +121 -0
  13. nucleus/__init__.py +453 -699
  14. nucleus/annotation.py +435 -80
  15. nucleus/autocurate.py +9 -0
  16. nucleus/connection.py +87 -0
  17. nucleus/constants.py +12 -2
  18. nucleus/data_transfer_object/__init__.py +0 -0
  19. nucleus/data_transfer_object/dataset_details.py +9 -0
  20. nucleus/data_transfer_object/dataset_info.py +26 -0
  21. nucleus/data_transfer_object/dataset_size.py +5 -0
  22. nucleus/data_transfer_object/scenes_list.py +18 -0
  23. nucleus/dataset.py +1139 -215
  24. nucleus/dataset_item.py +130 -26
  25. nucleus/dataset_item_uploader.py +297 -0
  26. nucleus/deprecation_warning.py +32 -0
  27. nucleus/errors.py +21 -1
  28. nucleus/job.py +71 -3
  29. nucleus/logger.py +9 -0
  30. nucleus/metadata_manager.py +45 -0
  31. nucleus/metrics/__init__.py +10 -0
  32. nucleus/metrics/base.py +117 -0
  33. nucleus/metrics/categorization_metrics.py +197 -0
  34. nucleus/metrics/errors.py +7 -0
  35. nucleus/metrics/filters.py +40 -0
  36. nucleus/metrics/geometry.py +198 -0
  37. nucleus/metrics/metric_utils.py +28 -0
  38. nucleus/metrics/polygon_metrics.py +480 -0
  39. nucleus/metrics/polygon_utils.py +299 -0
  40. nucleus/model.py +121 -15
  41. nucleus/model_run.py +34 -57
  42. nucleus/payload_constructor.py +30 -18
  43. nucleus/prediction.py +259 -17
  44. nucleus/pydantic_base.py +26 -0
  45. nucleus/retry_strategy.py +4 -0
  46. nucleus/scene.py +204 -19
  47. nucleus/slice.py +230 -67
  48. nucleus/upload_response.py +20 -9
  49. nucleus/url_utils.py +4 -0
  50. nucleus/utils.py +139 -35
  51. nucleus/validate/__init__.py +24 -0
  52. nucleus/validate/client.py +168 -0
  53. nucleus/validate/constants.py +20 -0
  54. nucleus/validate/data_transfer_objects/__init__.py +0 -0
  55. nucleus/validate/data_transfer_objects/eval_function.py +81 -0
  56. nucleus/validate/data_transfer_objects/scenario_test.py +19 -0
  57. nucleus/validate/data_transfer_objects/scenario_test_evaluations.py +11 -0
  58. nucleus/validate/data_transfer_objects/scenario_test_metric.py +12 -0
  59. nucleus/validate/errors.py +6 -0
  60. nucleus/validate/eval_functions/__init__.py +0 -0
  61. nucleus/validate/eval_functions/available_eval_functions.py +212 -0
  62. nucleus/validate/eval_functions/base_eval_function.py +60 -0
  63. nucleus/validate/scenario_test.py +143 -0
  64. nucleus/validate/scenario_test_evaluation.py +114 -0
  65. nucleus/validate/scenario_test_metric.py +14 -0
  66. nucleus/validate/utils.py +8 -0
  67. {scale_nucleus-0.1.22.dist-info → scale_nucleus-0.6.4.dist-info}/LICENSE +0 -0
  68. scale_nucleus-0.6.4.dist-info/METADATA +213 -0
  69. scale_nucleus-0.6.4.dist-info/RECORD +71 -0
  70. {scale_nucleus-0.1.22.dist-info → scale_nucleus-0.6.4.dist-info}/WHEEL +1 -1
  71. scale_nucleus-0.6.4.dist-info/entry_points.txt +3 -0
  72. scale_nucleus-0.1.22.dist-info/METADATA +0 -85
  73. scale_nucleus-0.1.22.dist-info/RECORD +0 -21
nucleus/dataset_item.py CHANGED
@@ -1,36 +1,53 @@
- from collections import Counter
  import json
  import os.path
+ from collections import Counter
  from dataclasses import dataclass
- from typing import Optional, Sequence, Dict, Any
  from enum import Enum
+ from typing import Any, Dict, Optional, Sequence

- from .annotation import is_local_path, Point3D
+ from .annotation import Point3D, is_local_path
  from .constants import (
+     CAMERA_PARAMS_KEY,
+     CX_KEY,
+     CY_KEY,
+     FX_KEY,
+     FY_KEY,
+     HEADING_KEY,
      IMAGE_URL_KEY,
      METADATA_KEY,
      ORIGINAL_IMAGE_URL_KEY,
-     UPLOAD_TO_SCALE_KEY,
+     POINTCLOUD_URL_KEY,
+     POSITION_KEY,
      REFERENCE_ID_KEY,
      TYPE_KEY,
+     UPLOAD_TO_SCALE_KEY,
      URL_KEY,
-     CAMERA_PARAMS_KEY,
-     POINTCLOUD_URL_KEY,
+     W_KEY,
      X_KEY,
      Y_KEY,
      Z_KEY,
-     W_KEY,
-     POSITION_KEY,
-     HEADING_KEY,
-     FX_KEY,
-     FY_KEY,
-     CX_KEY,
-     CY_KEY,
  )


  @dataclass
  class Quaternion:
+     """Quaternion objects are used to represent rotation.
+
+     We use the Hamilton/right-handed quaternion convention, where
+     ::
+
+         i^2 = j^2 = k^2 = ijk = -1
+
+     The quaternion represented by the tuple ``(x, y, z, w)`` is equal to
+     ``w + x*i + y*j + z*k``.
+
+     Parameters:
+         x (float): The x value.
+         y (float): The y value.
+         z (float): The z value.
+         w (float): The w value.
+     """
+
      x: float
      y: float
      z: float
@@ -38,11 +55,13 @@ class Quaternion:

      @classmethod
      def from_json(cls, payload: Dict[str, float]):
+         """Instantiates quaternion object from schematized JSON dict payload."""
          return cls(
              payload[X_KEY], payload[Y_KEY], payload[Z_KEY], payload[W_KEY]
          )

      def to_payload(self) -> dict:
+         """Serializes quaternion object to schematized JSON dict."""
          return {
              X_KEY: self.x,
              Y_KEY: self.y,
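Usage sketch (not part of the diff): the round trip the new docstrings describe. It assumes the X_KEY/Y_KEY/Z_KEY/W_KEY constants serialize to the plain keys "x", "y", "z", "w", which this hunk does not show.

    from nucleus.dataset_item import Quaternion

    # Identity rotation in the Hamilton (x, y, z, w) convention: no imaginary part, w = 1.
    identity = Quaternion(x=0.0, y=0.0, z=0.0, w=1.0)

    payload = identity.to_payload()  # assumed: {"x": 0.0, "y": 0.0, "z": 0.0, "w": 1.0}
    assert Quaternion.from_json(payload) == identity  # dataclass equality compares all four fields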
@@ -53,6 +72,20 @@ class Quaternion:

  @dataclass
  class CameraParams:
+     """Camera position/heading used to record the image.
+
+     Args:
+         position (:class:`Point3D`): World-normalized position of the camera
+         heading (:class:`Quaternion`): Vector4 indicating the quaternion of the
+             camera direction; note that the z-axis of the camera frame
+             represents the camera's optical axis. See `Heading Examples
+             <https://docs.scale.com/reference/data-types-and-the-frame-objects#heading-examples>`_.
+         fx (float): Focal length in x direction (in pixels).
+         fy (float): Focal length in y direction (in pixels).
+         cx (float): Principal point x value.
+         cy (float): Principal point y value.
+     """
+
      position: Point3D
      heading: Quaternion
      fx: float
@@ -62,6 +95,7 @@ class CameraParams:

      @classmethod
      def from_json(cls, payload: Dict[str, Any]):
+         """Instantiates camera params object from schematized JSON dict payload."""
          return cls(
              Point3D.from_json(payload[POSITION_KEY]),
              Quaternion.from_json(payload[HEADING_KEY]),
@@ -72,6 +106,7 @@ class CameraParams:
          )

      def to_payload(self) -> dict:
+         """Serializes camera params object to schematized JSON dict."""
          return {
              POSITION_KEY: self.position.to_payload(),
              HEADING_KEY: self.heading.to_payload(),
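Usage sketch (not part of the diff): building CameraParams from the pieces above. The DatasetItem docstring below asks for camera params under a "camera_params" metadata key; whether the raw object or its to_payload() dict is expected is not shown in this file, so the sketch assumes the serialized form. All numeric values are illustrative.

    from nucleus.annotation import Point3D
    from nucleus.dataset_item import CameraParams, Quaternion

    params = CameraParams(
        position=Point3D(x=0.0, y=0.0, z=1.6),           # camera 1.6 m above the ego origin
        heading=Quaternion(x=0.0, y=0.0, z=0.0, w=1.0),   # identity rotation; optical axis along +z
        fx=1000.0,  # focal lengths in pixels
        fy=1000.0,
        cx=960.0,   # principal point, e.g. the center of a 1920x1080 image
        cy=540.0,
    )

    item_metadata = {"camera_params": params.to_payload()}  # assumed serialized form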
@@ -89,14 +124,87 @@ class DatasetItemType(Enum):

  @dataclass  # pylint: disable=R0902
  class DatasetItem:  # pylint: disable=R0902
+     """A dataset item is an image or pointcloud that has associated metadata.
+
+     Note: for 3D data, please include a :class:`CameraParams` object under a key named
+     "camera_params" within the metadata dictionary. This will allow for projecting
+     3D annotations to any image within a scene.
+
+     Args:
+         image_location (Optional[str]): Required if pointcloud_location is not present: the
+             location containing the image for the given row of data. This can be a
+             local path, or a remote URL. Remote formats supported include any URL
+             (``http://`` or ``https://``) or URIs for AWS S3, Azure, or GCS
+             (i.e. ``s3://``, ``gcs://``).
+
+         pointcloud_location (Optional[str]): Required if image_location is not
+             present: the remote URL containing the pointcloud JSON. Remote
+             formats supported include any URL (``http://`` or ``https://``) or
+             URIs for AWS S3, Azure, or GCS (i.e. ``s3://``, ``gcs://``).
+
+         reference_id (Optional[str]): A user-specified identifier to reference the
+             item.
+
+         metadata (Optional[dict]): Extra information about the particular
+             dataset item. ints, floats, and string values will be made searchable in
+             the query bar by the key in this dict. For example, ``{"animal":
+             "dog"}`` will become searchable via ``metadata.animal = "dog"``.
+
+             Categorical data can be passed as a string and will be treated
+             categorically by Nucleus if there are fewer than 250 unique values in the
+             dataset. This means histograms of values in the "Insights" section and
+             autocomplete within the query bar.
+
+             Numerical metadata will generate histograms in the "Insights" section,
+             allow for sorting the results of any query, and can be used with the
+             modulo operator. For example: metadata.frame_number % 5 = 0.
+
+             All other types of metadata will be visible from the dataset item detail
+             view.
+
+             It is important that string and numerical metadata fields are consistent
+             - if a metadata field has a string value, then all metadata fields with
+             the same key should also have string values, and vice versa for numerical
+             metadata. If conflicting types are found, Nucleus will return an error
+             during upload!
+
+             The recommended way of adding or updating existing metadata is to re-run
+             the ingestion (dataset.append) with update=True, which will replace any
+             existing metadata with whatever your new ingestion run uses. This will
+             delete any metadata keys that are not present in the new ingestion run.
+             We have a cache based on image_location that will skip the need for a
+             re-upload of the images, so your second ingestion will be faster than
+             your first.
+
+             For 3D (sensor fusion) data, it is highly recommended to include
+             camera intrinsics in the metadata of your camera image items. Nucleus
+             requires these intrinsics to create visualizations such as cuboid
+             projections. Refer to our `guide to uploading 3D data
+             <https://nucleus.scale.com/docs/uploading-3d-data>`_ for more
+             info.
+
+             .. todo ::
+                 Shorten this once we have a guide migrated for metadata, or maybe link
+                 from other places to here.
+
+         upload_to_scale (Optional[bool]): Set this to false in order to use
+             `privacy mode <https://nucleus.scale.com/docs/privacy-mode>`_.
+
+             Setting this to false means the actual data within the item (i.e. the
+             image or pointcloud) will not be uploaded to Scale, meaning that you can
+             send in links that are only accessible to certain users, and not to Scale.
+     """
+
      image_location: Optional[str] = None
-     reference_id: Optional[str] = None
+     reference_id: str = (
+         "DUMMY_VALUE"  # preserve argument ordering for backwards compatibility
+     )
      metadata: Optional[dict] = None
      pointcloud_location: Optional[str] = None
      upload_to_scale: Optional[bool] = True

      def __post_init__(self):
-         assert self.reference_id is not None, "reference_id is required."
+         assert self.reference_id != "DUMMY_VALUE", "reference_id is required."
          assert bool(self.image_location) != bool(
              self.pointcloud_location
          ), "Must specify exactly one of the image_location, pointcloud_location parameters"
@@ -122,30 +230,25 @@ class DatasetItem:  # pylint: disable=R0902
          )

      @classmethod
-     def from_json(cls, payload: dict, is_scene=False):
+     def from_json(cls, payload: dict):
+         """Instantiates dataset item object from schematized JSON dict payload."""
          image_url = payload.get(IMAGE_URL_KEY, None) or payload.get(
              ORIGINAL_IMAGE_URL_KEY, None
          )
-
-         if is_scene:
-             return cls(
-                 image_location=image_url,
-                 pointcloud_location=payload.get(POINTCLOUD_URL_KEY, None),
-                 reference_id=payload.get(REFERENCE_ID_KEY, None),
-                 metadata=payload.get(METADATA_KEY, {}),
-             )
-
          return cls(
              image_location=image_url,
+             pointcloud_location=payload.get(POINTCLOUD_URL_KEY, None),
              reference_id=payload.get(REFERENCE_ID_KEY, None),
              metadata=payload.get(METADATA_KEY, {}),
-             upload_to_scale=payload.get(UPLOAD_TO_SCALE_KEY, None),
+             upload_to_scale=payload.get(UPLOAD_TO_SCALE_KEY, True),
          )

      def local_file_exists(self):
+         # TODO: make private
          return os.path.isfile(self.image_location)

      def to_payload(self, is_scene=False) -> dict:
+         """Serializes dataset item object to schematized JSON dict."""
          payload: Dict[str, Any] = {
              METADATA_KEY: self.metadata or {},
          }
@@ -170,6 +273,7 @@ class DatasetItem:  # pylint: disable=R0902
          return payload

      def to_json(self) -> str:
+         """Serializes dataset item object to schematized JSON string."""
          return json.dumps(self.to_payload(), allow_nan=False)
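Usage sketch (not part of the diff), based on the dataclass fields and the __post_init__ checks shown above; the bucket path and metadata values are made up.

    from nucleus.dataset_item import DatasetItem

    item = DatasetItem(
        image_location="s3://my-bucket/frames/frame_0001.jpg",  # hypothetical URI
        reference_id="frame_0001",
        metadata={"animal": "dog", "frame_number": 1},
        upload_to_scale=False,  # privacy mode: only the link is sent to Scale
    )

    # Exactly one of image_location / pointcloud_location must be set, and
    # reference_id may not be omitted; both are enforced in __post_init__.
    try:
        DatasetItem(image_location="s3://my-bucket/frames/frame_0002.jpg")
    except AssertionError as err:
        print(err)  # reference_id is required.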
 
nucleus/dataset_item_uploader.py ADDED
@@ -0,0 +1,297 @@
+ import asyncio
+ import json
+ import os
+ import time
+ from typing import TYPE_CHECKING, Any, List
+
+ import aiohttp
+ import nest_asyncio
+
+ from .constants import (
+     DATASET_ID_KEY,
+     DEFAULT_NETWORK_TIMEOUT_SEC,
+     IMAGE_KEY,
+     IMAGE_URL_KEY,
+     ITEMS_KEY,
+     UPDATE_KEY,
+ )
+ from .dataset_item import DatasetItem
+ from .errors import NotFoundError
+ from .logger import logger
+ from .payload_constructor import construct_append_payload
+ from .retry_strategy import RetryStrategy
+ from .upload_response import UploadResponse
+
+ if TYPE_CHECKING:
+     from . import NucleusClient
+
+
+ class DatasetItemUploader:
+     def __init__(self, dataset_id: str, client: "NucleusClient"):  # noqa: F821
+         self.dataset_id = dataset_id
+         self._client = client
+
+     def upload(
+         self,
+         dataset_items: List[DatasetItem],
+         batch_size: int = 20,
+         update: bool = False,
+     ) -> UploadResponse:
+         """Uploads dataset items to the dataset in batches.
+
+         Args:
+             dataset_items: Items to upload
+             batch_size: How many items to pool together for a single request
+             update: Update records instead of overwriting
+
+         Returns:
+             UploadResponse aggregated over all batches
+         """
+         local_items = []
+         remote_items = []
+
+         # Check local files exist before sending requests
+         for item in dataset_items:
+             if item.local:
+                 if not item.local_file_exists():
+                     raise NotFoundError()
+                 local_items.append(item)
+             else:
+                 remote_items.append(item)
+
+         local_batches = [
+             local_items[i : i + batch_size]
+             for i in range(0, len(local_items), batch_size)
+         ]
+
+         remote_batches = [
+             remote_items[i : i + batch_size]
+             for i in range(0, len(remote_items), batch_size)
+         ]
+
+         agg_response = UploadResponse(json={DATASET_ID_KEY: self.dataset_id})
+
+         async_responses: List[Any] = []
+
+         if local_batches:
+             tqdm_local_batches = self._client.tqdm_bar(
+                 local_batches, desc="Local file batches"
+             )
+
+             for batch in tqdm_local_batches:
+                 payload = construct_append_payload(batch, update)
+                 responses = self._process_append_requests_local(
+                     self.dataset_id, payload, update
+                 )
+                 async_responses.extend(responses)
+
+         if remote_batches:
+             tqdm_remote_batches = self._client.tqdm_bar(
+                 remote_batches, desc="Remote file batches"
+             )
+             for batch in tqdm_remote_batches:
+                 payload = construct_append_payload(batch, update)
+                 responses = self._process_append_requests(
+                     dataset_id=self.dataset_id,
+                     payload=payload,
+                     update=update,
+                     batch_size=batch_size,
+                 )
+                 async_responses.extend(responses)
+
+         for response in async_responses:
+             agg_response.update_response(response)
+
+         return agg_response
+
+     def _process_append_requests_local(
+         self,
+         dataset_id: str,
+         payload: dict,
+         update: bool,  # TODO: understand how to pass this in.
+         local_batch_size: int = 10,
+     ):
+         def get_files(batch):
+             for item in batch:
+                 item[UPDATE_KEY] = update
+             request_payload = [
+                 (
+                     ITEMS_KEY,
+                     (
+                         None,
+                         json.dumps(batch, allow_nan=False),
+                         "application/json",
+                     ),
+                 )
+             ]
+             for item in batch:
+                 image = open(  # pylint: disable=R1732
+                     item.get(IMAGE_URL_KEY), "rb"  # pylint: disable=R1732
+                 )  # pylint: disable=R1732
+                 img_name = os.path.basename(image.name)
+                 img_type = (
+                     f"image/{os.path.splitext(image.name)[1].strip('.')}"
+                 )
+                 request_payload.append(
+                     (IMAGE_KEY, (img_name, image, img_type))
+                 )
+             return request_payload
+
+         items = payload[ITEMS_KEY]
+         responses: List[Any] = []
+         files_per_request = []
+         payload_items = []
+         for i in range(0, len(items), local_batch_size):
+             batch = items[i : i + local_batch_size]
+             files_per_request.append(get_files(batch))
+             payload_items.append(batch)
+
+         future = self.make_many_files_requests_asynchronously(
+             files_per_request,
+             f"dataset/{dataset_id}/append",
+         )
+
+         try:
+             loop = asyncio.get_event_loop()
+         except RuntimeError:  # no event loop running:
+             loop = asyncio.new_event_loop()
+             responses = loop.run_until_complete(future)
+         else:
+             nest_asyncio.apply(loop)
+             return loop.run_until_complete(future)
+
+         def close_files(request_items):
+             for item in request_items:
+                 # file buffer in location [1][1]
+                 if item[0] == IMAGE_KEY:
+                     item[1][1].close()
+
+         # don't forget to close all open files
+         for p in files_per_request:
+             close_files(p)
+
+         return responses
+
+     async def make_many_files_requests_asynchronously(
+         self, files_per_request, route
+     ):
+         """
+         Makes an async post request with files to a Nucleus endpoint.
+
+         :param files_per_request: A list of lists of tuples (name, (filename, file_pointer, content_type))
+            name will become the name by which the multer can build an array.
+         :param route: route for the request
+         :return: awaitable list(response)
+         """
+         async with aiohttp.ClientSession() as session:
+             tasks = [
+                 asyncio.ensure_future(
+                     self._make_files_request(
+                         files=files, route=route, session=session
+                     )
+                 )
+                 for files in files_per_request
+             ]
+             return await asyncio.gather(*tasks)
+
+     async def _make_files_request(
+         self,
+         files,
+         route: str,
+         session: aiohttp.ClientSession,
+         retry_attempt=0,
+         max_retries=3,
+         sleep_intervals=(1, 3, 9),
+     ):
+         """
+         Makes an async post request with files to a Nucleus endpoint.
+
+         :param files: A list of tuples (name, (filename, file_pointer, file_type))
+         :param route: route for the request
+         :param session: Session to use for post.
+         :return: response
+         """
+         endpoint = f"{self._client.endpoint}/{route}"
+
+         logger.info("Posting to %s", endpoint)
+
+         form = aiohttp.FormData()
+
+         for file in files:
+             form.add_field(
+                 name=file[0],
+                 filename=file[1][0],
+                 value=file[1][1],
+                 content_type=file[1][2],
+             )
+
+         for sleep_time in RetryStrategy.sleep_times + [-1]:
+
+             async with session.post(
+                 endpoint,
+                 data=form,
+                 auth=aiohttp.BasicAuth(self._client.api_key, ""),
+                 timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
+             ) as response:
+                 logger.info(
+                     "API request has response code %s", response.status
+                 )
+
+                 try:
+                     data = await response.json()
+                 except aiohttp.client_exceptions.ContentTypeError:
+                     # In case of 404, the server returns text
+                     data = await response.text()
+                 if (
+                     response.status in RetryStrategy.statuses
+                     and sleep_time != -1
+                 ):
+                     time.sleep(sleep_time)
+                     continue
+
+                 if not response.ok:
+                     if retry_attempt < max_retries:
+                         time.sleep(sleep_intervals[retry_attempt])
+                         retry_attempt += 1
+                         return self._make_files_request(
+                             files,
+                             route,
+                             session,
+                             retry_attempt,
+                             max_retries,
+                             sleep_intervals,
+                         )
+                     else:
+                         self._client.handle_bad_response(
+                             endpoint,
+                             session.post,
+                             aiohttp_response=(
+                                 response.status,
+                                 response.reason,
+                                 data,
+                             ),
+                         )
+
+                 return data
+
+     def _process_append_requests(
+         self,
+         dataset_id: str,
+         payload: dict,
+         update: bool,
+         batch_size: int = 20,
+     ):
+         items = payload[ITEMS_KEY]
+         payloads = [
+             # batch_size images per request
+             {ITEMS_KEY: items[i : i + batch_size], UPDATE_KEY: update}
+             for i in range(0, len(items), batch_size)
+         ]
+
+         return [
+             self._client.make_request(
+                 payload,
+                 f"dataset/{dataset_id}/append",
+             )
+             for payload in payloads
+         ]
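The upload flow above first separates local files from remote URLs and then chunks each group into batches of batch_size before any request is made. A standalone sketch of that chunking pattern (plain strings stand in for DatasetItem objects; the locality predicate is illustrative):

    from typing import Callable, List, Tuple

    def split_and_batch(
        items: List[str], is_local: Callable[[str], bool], batch_size: int = 20
    ) -> Tuple[List[List[str]], List[List[str]]]:
        """Mirror the local/remote split and fixed-size batching done in DatasetItemUploader.upload."""
        local = [item for item in items if is_local(item)]
        remote = [item for item in items if not is_local(item)]

        def chunk(xs: List[str]) -> List[List[str]]:
            return [xs[i : i + batch_size] for i in range(0, len(xs), batch_size)]

        return chunk(local), chunk(remote)

    paths = ["/tmp/a.jpg", "s3://bucket/b.jpg", "/tmp/c.jpg", "s3://bucket/d.jpg", "s3://bucket/e.jpg"]
    local_batches, remote_batches = split_and_batch(paths, lambda p: not p.startswith("s3://"), batch_size=2)
    print(local_batches)   # [['/tmp/a.jpg', '/tmp/c.jpg']]
    print(remote_batches)  # [['s3://bucket/b.jpg', 's3://bucket/d.jpg'], ['s3://bucket/e.jpg']]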
nucleus/deprecation_warning.py ADDED
@@ -0,0 +1,32 @@
+ import warnings
+ from functools import wraps
+ from typing import Callable
+
+
+ def deprecated(msg: str):
+     """Adds a deprecation warning via the `warnings` lib which can be caught by linters.
+
+     Args:
+         msg: State reason of deprecation and point towards preferred practices
+
+     Returns:
+         Deprecation-wrapped function
+     """
+
+     def decorator(func: Callable):
+         @wraps(func)
+         def wrapper(*args, **kwargs):
+             # NOTE: __qualname__ looks a lot better for method calls
+             name = (
+                 func.__qualname__
+                 if hasattr(func, "__qualname__")
+                 else func.__name__
+             )
+             full_message = f"Calling {name} is deprecated: {msg}"
+             # NOTE: stacklevel=2 attributes the warning to the caller of the decorated function
+             warnings.warn(full_message, DeprecationWarning, stacklevel=2)
+             return func(*args, **kwargs)
+
+         return wrapper
+
+     return decorator
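Usage sketch (not part of the diff): applying the new decorator to a throwaway function; the replacement name in the message is hypothetical.

    import warnings

    from nucleus.deprecation_warning import deprecated

    @deprecated(msg="use new_lookup() instead")  # hypothetical successor function
    def old_lookup(key: str) -> str:
        return key.upper()

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        assert old_lookup("abc") == "ABC"  # the wrapped function still runs normally
        print(caught[0].message)  # Calling old_lookup is deprecated: use new_lookup() instead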
nucleus/errors.py CHANGED
@@ -4,6 +4,11 @@ nucleus_client_version = pkg_resources.get_distribution(
      "scale-nucleus"
  ).version

+ INFRA_FLAKE_MESSAGES = [
+     "downstream duration timeout",
+     "upstream connect error or disconnect/reset before headers. reset reason: local reset",
+ ]
+

  class ModelCreationError(Exception):
      def __init__(self, message="Could not create the model"):
@@ -35,7 +40,7 @@ class NucleusAPIError(Exception):
      def __init__(
          self, endpoint, command, requests_response=None, aiohttp_response=None
      ):
-         message = f"Your client is on version {nucleus_client_version}. Before reporting this error, please make sure you update to the latest version of the client by running pip install --upgrade scale-nucleus\n"
+         message = f"Your client is on version {nucleus_client_version}. If you have not recently done so, please make sure you have updated to the latest version of the client by running pip install --upgrade scale-nucleus\n"
          if requests_response is not None:
              message += f"Tried to {command.__name__} {endpoint}, but received {requests_response.status_code}: {requests_response.reason}."
              if hasattr(requests_response, "text"):
@@ -50,4 +55,19 @@ class NucleusAPIError(Exception):
              if data:
                  message += f"\nThe detailed error is:\n{data}"

+         if any(
+             infra_flake_message in message
+             for infra_flake_message in INFRA_FLAKE_MESSAGES
+         ):
+             message += "\n This likely indicates temporary downtime of the API, please try again in a minute or two"
+
          super().__init__(message)
+
+
+ class NoAPIKey(Exception):
+     def __init__(
+         self,
+         message="You need to pass an API key to the NucleusClient or set the environment variable NUCLEUS_API_KEY",
+     ):
+         self.message = message
+         super().__init__(self.message)
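Since the infra-flake hint is folded into the exception message rather than exposed as a flag, calling code that wants to retry has to inspect the message text. A minimal sketch under that assumption (the request callable, attempt count, and backoff are illustrative):

    import time

    from nucleus.errors import INFRA_FLAKE_MESSAGES, NucleusAPIError

    def call_with_retry(request_fn, attempts: int = 3, backoff_sec: float = 60.0):
        """Retry a Nucleus call when the failure message looks like a transient infra flake."""
        for attempt in range(attempts):
            try:
                return request_fn()
            except NucleusAPIError as err:
                transient = any(flake in str(err) for flake in INFRA_FLAKE_MESSAGES)
                if not transient or attempt == attempts - 1:
                    raise
                time.sleep(backoff_sec)  # "try again in a minute or two", per the appended hint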