scale-nucleus 0.1.22__py3-none-any.whl → 0.6.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli/client.py +14 -0
- cli/datasets.py +77 -0
- cli/helpers/__init__.py +0 -0
- cli/helpers/nucleus_url.py +10 -0
- cli/helpers/web_helper.py +40 -0
- cli/install_completion.py +33 -0
- cli/jobs.py +42 -0
- cli/models.py +35 -0
- cli/nu.py +42 -0
- cli/reference.py +8 -0
- cli/slices.py +62 -0
- cli/tests.py +121 -0
- nucleus/__init__.py +453 -699
- nucleus/annotation.py +435 -80
- nucleus/autocurate.py +9 -0
- nucleus/connection.py +87 -0
- nucleus/constants.py +12 -2
- nucleus/data_transfer_object/__init__.py +0 -0
- nucleus/data_transfer_object/dataset_details.py +9 -0
- nucleus/data_transfer_object/dataset_info.py +26 -0
- nucleus/data_transfer_object/dataset_size.py +5 -0
- nucleus/data_transfer_object/scenes_list.py +18 -0
- nucleus/dataset.py +1139 -215
- nucleus/dataset_item.py +130 -26
- nucleus/dataset_item_uploader.py +297 -0
- nucleus/deprecation_warning.py +32 -0
- nucleus/errors.py +21 -1
- nucleus/job.py +71 -3
- nucleus/logger.py +9 -0
- nucleus/metadata_manager.py +45 -0
- nucleus/metrics/__init__.py +10 -0
- nucleus/metrics/base.py +117 -0
- nucleus/metrics/categorization_metrics.py +197 -0
- nucleus/metrics/errors.py +7 -0
- nucleus/metrics/filters.py +40 -0
- nucleus/metrics/geometry.py +198 -0
- nucleus/metrics/metric_utils.py +28 -0
- nucleus/metrics/polygon_metrics.py +480 -0
- nucleus/metrics/polygon_utils.py +299 -0
- nucleus/model.py +121 -15
- nucleus/model_run.py +34 -57
- nucleus/payload_constructor.py +30 -18
- nucleus/prediction.py +259 -17
- nucleus/pydantic_base.py +26 -0
- nucleus/retry_strategy.py +4 -0
- nucleus/scene.py +204 -19
- nucleus/slice.py +230 -67
- nucleus/upload_response.py +20 -9
- nucleus/url_utils.py +4 -0
- nucleus/utils.py +139 -35
- nucleus/validate/__init__.py +24 -0
- nucleus/validate/client.py +168 -0
- nucleus/validate/constants.py +20 -0
- nucleus/validate/data_transfer_objects/__init__.py +0 -0
- nucleus/validate/data_transfer_objects/eval_function.py +81 -0
- nucleus/validate/data_transfer_objects/scenario_test.py +19 -0
- nucleus/validate/data_transfer_objects/scenario_test_evaluations.py +11 -0
- nucleus/validate/data_transfer_objects/scenario_test_metric.py +12 -0
- nucleus/validate/errors.py +6 -0
- nucleus/validate/eval_functions/__init__.py +0 -0
- nucleus/validate/eval_functions/available_eval_functions.py +212 -0
- nucleus/validate/eval_functions/base_eval_function.py +60 -0
- nucleus/validate/scenario_test.py +143 -0
- nucleus/validate/scenario_test_evaluation.py +114 -0
- nucleus/validate/scenario_test_metric.py +14 -0
- nucleus/validate/utils.py +8 -0
- {scale_nucleus-0.1.22.dist-info → scale_nucleus-0.6.4.dist-info}/LICENSE +0 -0
- scale_nucleus-0.6.4.dist-info/METADATA +213 -0
- scale_nucleus-0.6.4.dist-info/RECORD +71 -0
- {scale_nucleus-0.1.22.dist-info → scale_nucleus-0.6.4.dist-info}/WHEEL +1 -1
- scale_nucleus-0.6.4.dist-info/entry_points.txt +3 -0
- scale_nucleus-0.1.22.dist-info/METADATA +0 -85
- scale_nucleus-0.1.22.dist-info/RECORD +0 -21
nucleus/dataset_item.py
CHANGED
@@ -1,36 +1,53 @@
|
|
1
|
-
from collections import Counter
|
2
1
|
import json
|
3
2
|
import os.path
|
3
|
+
from collections import Counter
|
4
4
|
from dataclasses import dataclass
|
5
|
-
from typing import Optional, Sequence, Dict, Any
|
6
5
|
from enum import Enum
|
6
|
+
from typing import Any, Dict, Optional, Sequence
|
7
7
|
|
8
|
-
from .annotation import
|
8
|
+
from .annotation import Point3D, is_local_path
|
9
9
|
from .constants import (
|
10
|
+
CAMERA_PARAMS_KEY,
|
11
|
+
CX_KEY,
|
12
|
+
CY_KEY,
|
13
|
+
FX_KEY,
|
14
|
+
FY_KEY,
|
15
|
+
HEADING_KEY,
|
10
16
|
IMAGE_URL_KEY,
|
11
17
|
METADATA_KEY,
|
12
18
|
ORIGINAL_IMAGE_URL_KEY,
|
13
|
-
|
19
|
+
POINTCLOUD_URL_KEY,
|
20
|
+
POSITION_KEY,
|
14
21
|
REFERENCE_ID_KEY,
|
15
22
|
TYPE_KEY,
|
23
|
+
UPLOAD_TO_SCALE_KEY,
|
16
24
|
URL_KEY,
|
17
|
-
|
18
|
-
POINTCLOUD_URL_KEY,
|
25
|
+
W_KEY,
|
19
26
|
X_KEY,
|
20
27
|
Y_KEY,
|
21
28
|
Z_KEY,
|
22
|
-
W_KEY,
|
23
|
-
POSITION_KEY,
|
24
|
-
HEADING_KEY,
|
25
|
-
FX_KEY,
|
26
|
-
FY_KEY,
|
27
|
-
CX_KEY,
|
28
|
-
CY_KEY,
|
29
29
|
)
|
30
30
|
|
31
31
|
|
32
32
|
@dataclass
|
33
33
|
class Quaternion:
|
34
|
+
"""Quaternion objects are used to represent rotation.
|
35
|
+
|
36
|
+
We use the Hamilton/right-handed quaternion convention, where
|
37
|
+
::
|
38
|
+
|
39
|
+
i^2 = j^2 = k^2 = ijk = -1
|
40
|
+
|
41
|
+
The quaternion represented by the tuple ``(x, y, z, w)`` is equal to
|
42
|
+
``w + x*i + y*j + z*k``.
|
43
|
+
|
44
|
+
Parameters:
|
45
|
+
x (float): The x value.
|
46
|
+
y (float): The y value.
|
47
|
+
z (float): The z value.
|
48
|
+
w (float): The w value.
|
49
|
+
"""
|
50
|
+
|
34
51
|
x: float
|
35
52
|
y: float
|
36
53
|
z: float
|
@@ -38,11 +55,13 @@ class Quaternion:
|
|
38
55
|
|
39
56
|
@classmethod
|
40
57
|
def from_json(cls, payload: Dict[str, float]):
|
58
|
+
"""Instantiates quaternion object from schematized JSON dict payload."""
|
41
59
|
return cls(
|
42
60
|
payload[X_KEY], payload[Y_KEY], payload[Z_KEY], payload[W_KEY]
|
43
61
|
)
|
44
62
|
|
45
63
|
def to_payload(self) -> dict:
|
64
|
+
"""Serializes quaternion object to schematized JSON dict."""
|
46
65
|
return {
|
47
66
|
X_KEY: self.x,
|
48
67
|
Y_KEY: self.y,
|
@@ -53,6 +72,20 @@ class Quaternion:
|
|
53
72
|
|
54
73
|
@dataclass
|
55
74
|
class CameraParams:
|
75
|
+
"""Camera position/heading used to record the image.
|
76
|
+
|
77
|
+
Args:
|
78
|
+
position (:class:`Point3D`): World-normalized position of the camera
|
79
|
+
heading (:class:`Quaternion`): Vector4 indicating the quaternion of the
|
80
|
+
camera direction; note that the z-axis of the camera frame
|
81
|
+
represents the camera's optical axis. See `Heading Examples
|
82
|
+
<https://docs.scale.com/reference/data-types-and-the-frame-objects#heading-examples>`_.
|
83
|
+
fx (float): Focal length in x direction (in pixels).
|
84
|
+
fy (float): Focal length in y direction (in pixels).
|
85
|
+
cx (float): Principal point x value.
|
86
|
+
cy (float): Principal point y value.
|
87
|
+
"""
|
88
|
+
|
56
89
|
position: Point3D
|
57
90
|
heading: Quaternion
|
58
91
|
fx: float
|
@@ -62,6 +95,7 @@ class CameraParams:
|
|
62
95
|
|
63
96
|
@classmethod
|
64
97
|
def from_json(cls, payload: Dict[str, Any]):
|
98
|
+
"""Instantiates camera params object from schematized JSON dict payload."""
|
65
99
|
return cls(
|
66
100
|
Point3D.from_json(payload[POSITION_KEY]),
|
67
101
|
Quaternion.from_json(payload[HEADING_KEY]),
|
@@ -72,6 +106,7 @@ class CameraParams:
|
|
72
106
|
)
|
73
107
|
|
74
108
|
def to_payload(self) -> dict:
|
109
|
+
"""Serializes camera params object to schematized JSON dict."""
|
75
110
|
return {
|
76
111
|
POSITION_KEY: self.position.to_payload(),
|
77
112
|
HEADING_KEY: self.heading.to_payload(),
|
@@ -89,14 +124,87 @@ class DatasetItemType(Enum):
|
|
89
124
|
|
90
125
|
@dataclass # pylint: disable=R0902
|
91
126
|
class DatasetItem: # pylint: disable=R0902
|
127
|
+
"""A dataset item is an image or pointcloud that has associated metadata.
|
128
|
+
|
129
|
+
Note: for 3D data, please include a :class:`CameraParams` object under a key named
|
130
|
+
"camera_params" within the metadata dictionary. This will allow for projecting
|
131
|
+
3D annotations to any image within a scene.
|
132
|
+
|
133
|
+
Args:
|
134
|
+
image_location (Optional[str]): Required if pointcloud_location not present: The
|
135
|
+
location containing the image for the given row of data. This can be a
|
136
|
+
local path, or a remote URL. Remote formats supported include any URL
|
137
|
+
(``http://`` or ``https://``) or URIs for AWS S3, Azure, or GCS
|
138
|
+
(e.g. ``s3://``, ``gcs://``).
|
139
|
+
|
140
|
+
pointcloud_location (Optional[str]): Required if image_location not
|
141
|
+
present: The remote URL containing the pointcloud JSON. Remote
|
142
|
+
formats supported include any URL (``http://`` or ``https://``) or
|
143
|
+
URIs for AWS S3, Azure, or GCS (e.g. ``s3://``, ``gcs://``).
|
144
|
+
|
145
|
+
reference_id (Optional[str]): A user-specified identifier to reference the
|
146
|
+
item.
|
147
|
+
|
148
|
+
metadata (Optional[dict]): Extra information about the particular
|
149
|
+
dataset item. ints, floats, string values will be made searchable in
|
150
|
+
the query bar by the key in this dict. For example, ``{"animal":
|
151
|
+
"dog"}`` will become searchable via ``metadata.animal = "dog"``.
|
152
|
+
|
153
|
+
Categorical data can be passed as a string and will be treated
|
154
|
+
categorically by Nucleus if there are fewer than 250 unique values in the
|
155
|
+
dataset. This means histograms of values in the "Insights" section and
|
156
|
+
autocomplete within the query bar.
|
157
|
+
|
158
|
+
Numerical metadata will generate histograms in the "Insights" section,
|
159
|
+
allow for sorting the results of any query, and can be used with the
|
160
|
+
modulo operator. For example: ``metadata.frame_number % 5 = 0``
|
161
|
+
|
162
|
+
All other types of metadata will be visible from the dataset item detail
|
163
|
+
view.
|
164
|
+
|
165
|
+
It is important that string and numerical metadata fields are consistent
|
166
|
+
- if a metadata field has a string value, then all metadata fields with
|
167
|
+
the same key should also have string values, and vice versa for numerical
|
168
|
+
metadata. If conflicting types are found, Nucleus will return an error
|
169
|
+
during upload!
|
170
|
+
|
171
|
+
The recommended way of adding or updating existing metadata is to re-run
|
172
|
+
the ingestion (dataset.append) with update=True, which will replace any
|
173
|
+
existing metadata with whatever your new ingestion run uses. This will
|
174
|
+
delete any metadata keys that are not present in the new ingestion run.
|
175
|
+
We have a cache based on image_location that will skip the need for a
|
176
|
+
re-upload of the images, so your second ingestion will be faster than
|
177
|
+
your first.
|
178
|
+
|
179
|
+
For 3D (sensor fusion) data, it is highly recommended to include
|
180
|
+
camera intrinsics in the metadata of your camera image items. Nucleus
|
181
|
+
requires these intrinsics to create visualizations such as cuboid
|
182
|
+
projections. Refer to our `guide to uploading 3D data
|
183
|
+
<https://nucleus.scale.com/docs/uploading-3d-data>`_ for more
|
184
|
+
info.
|
185
|
+
|
186
|
+
.. todo ::
|
187
|
+
Shorten this once we have a guide migrated for metadata, or maybe link
|
188
|
+
from other places to here.
|
189
|
+
|
190
|
+
upload_to_scale (Optional[bool]): Set this to false in order to use
|
191
|
+
`privacy mode <https://nucleus.scale.com/docs/privacy-mode>`_.
|
192
|
+
|
193
|
+
Setting this to false means the actual data within the item (i.e. the
|
194
|
+
image or pointcloud) will not be uploaded to Scale, meaning that you can
|
195
|
+
send in links that are only accessible to certain users, and not to Scale.
|
196
|
+
"""
|
197
|
+
|
92
198
|
image_location: Optional[str] = None
|
93
|
-
reference_id:
|
199
|
+
reference_id: str = (
|
200
|
+
"DUMMY_VALUE" # preserve argument ordering for backwards compatibility
|
201
|
+
)
|
94
202
|
metadata: Optional[dict] = None
|
95
203
|
pointcloud_location: Optional[str] = None
|
96
204
|
upload_to_scale: Optional[bool] = True
|
97
205
|
|
98
206
|
def __post_init__(self):
|
99
|
-
assert self.reference_id
|
207
|
+
assert self.reference_id != "DUMMY_VALUE", "reference_id is required."
|
100
208
|
assert bool(self.image_location) != bool(
|
101
209
|
self.pointcloud_location
|
102
210
|
), "Must specify exactly one of the image_location, pointcloud_location parameters"
|
@@ -122,30 +230,25 @@ class DatasetItem: # pylint: disable=R0902
|
|
122
230
|
)
|
123
231
|
|
124
232
|
@classmethod
|
125
|
-
def from_json(cls, payload: dict
|
233
|
+
def from_json(cls, payload: dict):
|
234
|
+
"""Instantiates dataset item object from schematized JSON dict payload."""
|
126
235
|
image_url = payload.get(IMAGE_URL_KEY, None) or payload.get(
|
127
236
|
ORIGINAL_IMAGE_URL_KEY, None
|
128
237
|
)
|
129
|
-
|
130
|
-
if is_scene:
|
131
|
-
return cls(
|
132
|
-
image_location=image_url,
|
133
|
-
pointcloud_location=payload.get(POINTCLOUD_URL_KEY, None),
|
134
|
-
reference_id=payload.get(REFERENCE_ID_KEY, None),
|
135
|
-
metadata=payload.get(METADATA_KEY, {}),
|
136
|
-
)
|
137
|
-
|
138
238
|
return cls(
|
139
239
|
image_location=image_url,
|
240
|
+
pointcloud_location=payload.get(POINTCLOUD_URL_KEY, None),
|
140
241
|
reference_id=payload.get(REFERENCE_ID_KEY, None),
|
141
242
|
metadata=payload.get(METADATA_KEY, {}),
|
142
|
-
upload_to_scale=payload.get(UPLOAD_TO_SCALE_KEY,
|
243
|
+
upload_to_scale=payload.get(UPLOAD_TO_SCALE_KEY, True),
|
143
244
|
)
|
144
245
|
|
145
246
|
def local_file_exists(self):
|
247
|
+
# TODO: make private
|
146
248
|
return os.path.isfile(self.image_location)
|
147
249
|
|
148
250
|
def to_payload(self, is_scene=False) -> dict:
|
251
|
+
"""Serializes dataset item object to schematized JSON dict."""
|
149
252
|
payload: Dict[str, Any] = {
|
150
253
|
METADATA_KEY: self.metadata or {},
|
151
254
|
}
|
@@ -170,6 +273,7 @@ class DatasetItem: # pylint: disable=R0902
|
|
170
273
|
return payload
|
171
274
|
|
172
275
|
def to_json(self) -> str:
|
276
|
+
"""Serializes dataset item object to schematized JSON string."""
|
173
277
|
return json.dumps(self.to_payload(), allow_nan=False)
|
174
278
|
|
175
279
|
|
@@ -0,0 +1,297 @@
|
|
1
|
+
import asyncio
|
2
|
+
import json
|
3
|
+
import os
|
4
|
+
import time
|
5
|
+
from typing import TYPE_CHECKING, Any, List
|
6
|
+
|
7
|
+
import aiohttp
|
8
|
+
import nest_asyncio
|
9
|
+
|
10
|
+
from .constants import (
|
11
|
+
DATASET_ID_KEY,
|
12
|
+
DEFAULT_NETWORK_TIMEOUT_SEC,
|
13
|
+
IMAGE_KEY,
|
14
|
+
IMAGE_URL_KEY,
|
15
|
+
ITEMS_KEY,
|
16
|
+
UPDATE_KEY,
|
17
|
+
)
|
18
|
+
from .dataset_item import DatasetItem
|
19
|
+
from .errors import NotFoundError
|
20
|
+
from .logger import logger
|
21
|
+
from .payload_constructor import construct_append_payload
|
22
|
+
from .retry_strategy import RetryStrategy
|
23
|
+
from .upload_response import UploadResponse
|
24
|
+
|
25
|
+
if TYPE_CHECKING:
|
26
|
+
from . import NucleusClient
|
27
|
+
|
28
|
+
|
29
|
+
class DatasetItemUploader:
|
30
|
+
def __init__(self, dataset_id: str, client: "NucleusClient"): # noqa: F821
|
31
|
+
self.dataset_id = dataset_id
|
32
|
+
self._client = client
|
33
|
+
|
34
|
+
def upload(
|
35
|
+
self,
|
36
|
+
dataset_items: List[DatasetItem],
|
37
|
+
batch_size: int = 20,
|
38
|
+
update: bool = False,
|
39
|
+
) -> UploadResponse:
|
40
|
+
"""
|
41
|
+
|
42
|
+
Args:
|
43
|
+
dataset_items: Items to Upload
|
44
|
+
batch_size: How many items to pool together for a single request
|
45
|
+
update: Update records instead of overwriting
|
46
|
+
|
47
|
+
Returns:
|
48
|
+
|
49
|
+
"""
|
50
|
+
local_items = []
|
51
|
+
remote_items = []
|
52
|
+
|
53
|
+
# Check local files exist before sending requests
|
54
|
+
for item in dataset_items:
|
55
|
+
if item.local:
|
56
|
+
if not item.local_file_exists():
|
57
|
+
raise NotFoundError()
|
58
|
+
local_items.append(item)
|
59
|
+
else:
|
60
|
+
remote_items.append(item)
|
61
|
+
|
62
|
+
local_batches = [
|
63
|
+
local_items[i : i + batch_size]
|
64
|
+
for i in range(0, len(local_items), batch_size)
|
65
|
+
]
|
66
|
+
|
67
|
+
remote_batches = [
|
68
|
+
remote_items[i : i + batch_size]
|
69
|
+
for i in range(0, len(remote_items), batch_size)
|
70
|
+
]
|
71
|
+
|
72
|
+
agg_response = UploadResponse(json={DATASET_ID_KEY: self.dataset_id})
|
73
|
+
|
74
|
+
async_responses: List[Any] = []
|
75
|
+
|
76
|
+
if local_batches:
|
77
|
+
tqdm_local_batches = self._client.tqdm_bar(
|
78
|
+
local_batches, desc="Local file batches"
|
79
|
+
)
|
80
|
+
|
81
|
+
for batch in tqdm_local_batches:
|
82
|
+
payload = construct_append_payload(batch, update)
|
83
|
+
responses = self._process_append_requests_local(
|
84
|
+
self.dataset_id, payload, update
|
85
|
+
)
|
86
|
+
async_responses.extend(responses)
|
87
|
+
|
88
|
+
if remote_batches:
|
89
|
+
tqdm_remote_batches = self._client.tqdm_bar(
|
90
|
+
remote_batches, desc="Remote file batches"
|
91
|
+
)
|
92
|
+
for batch in tqdm_remote_batches:
|
93
|
+
payload = construct_append_payload(batch, update)
|
94
|
+
responses = self._process_append_requests(
|
95
|
+
dataset_id=self.dataset_id,
|
96
|
+
payload=payload,
|
97
|
+
update=update,
|
98
|
+
batch_size=batch_size,
|
99
|
+
)
|
100
|
+
async_responses.extend(responses)
|
101
|
+
|
102
|
+
for response in async_responses:
|
103
|
+
agg_response.update_response(response)
|
104
|
+
|
105
|
+
return agg_response
|
106
|
+
|
107
|
+
def _process_append_requests_local(
|
108
|
+
self,
|
109
|
+
dataset_id: str,
|
110
|
+
payload: dict,
|
111
|
+
update: bool, # TODO: understand how to pass this in.
|
112
|
+
local_batch_size: int = 10,
|
113
|
+
):
|
114
|
+
def get_files(batch):
|
115
|
+
for item in batch:
|
116
|
+
item[UPDATE_KEY] = update
|
117
|
+
request_payload = [
|
118
|
+
(
|
119
|
+
ITEMS_KEY,
|
120
|
+
(
|
121
|
+
None,
|
122
|
+
json.dumps(batch, allow_nan=False),
|
123
|
+
"application/json",
|
124
|
+
),
|
125
|
+
)
|
126
|
+
]
|
127
|
+
for item in batch:
|
128
|
+
image = open( # pylint: disable=R1732
|
129
|
+
item.get(IMAGE_URL_KEY), "rb" # pylint: disable=R1732
|
130
|
+
) # pylint: disable=R1732
|
131
|
+
img_name = os.path.basename(image.name)
|
132
|
+
img_type = (
|
133
|
+
f"image/{os.path.splitext(image.name)[1].strip('.')}"
|
134
|
+
)
|
135
|
+
request_payload.append(
|
136
|
+
(IMAGE_KEY, (img_name, image, img_type))
|
137
|
+
)
|
138
|
+
return request_payload
|
139
|
+
|
140
|
+
items = payload[ITEMS_KEY]
|
141
|
+
responses: List[Any] = []
|
142
|
+
files_per_request = []
|
143
|
+
payload_items = []
|
144
|
+
for i in range(0, len(items), local_batch_size):
|
145
|
+
batch = items[i : i + local_batch_size]
|
146
|
+
files_per_request.append(get_files(batch))
|
147
|
+
payload_items.append(batch)
|
148
|
+
|
149
|
+
future = self.make_many_files_requests_asynchronously(
|
150
|
+
files_per_request,
|
151
|
+
f"dataset/{dataset_id}/append",
|
152
|
+
)
|
153
|
+
|
154
|
+
try:
|
155
|
+
loop = asyncio.get_event_loop()
|
156
|
+
except RuntimeError: # no event loop running:
|
157
|
+
loop = asyncio.new_event_loop()
|
158
|
+
responses = loop.run_until_complete(future)
|
159
|
+
else:
|
160
|
+
nest_asyncio.apply(loop)
|
161
|
+
return loop.run_until_complete(future)
|
162
|
+
|
163
|
+
def close_files(request_items):
|
164
|
+
for item in request_items:
|
165
|
+
# file buffer in location [1][1]
|
166
|
+
if item[0] == IMAGE_KEY:
|
167
|
+
item[1][1].close()
|
168
|
+
|
169
|
+
# don't forget to close all open files
|
170
|
+
for p in files_per_request:
|
171
|
+
close_files(p)
|
172
|
+
|
173
|
+
return responses
|
174
|
+
|
175
|
+
async def make_many_files_requests_asynchronously(
|
176
|
+
self, files_per_request, route
|
177
|
+
):
|
178
|
+
"""
|
179
|
+
Makes an async post request with files to a Nucleus endpoint.
|
180
|
+
|
181
|
+
:param files_per_request: A list of lists of tuples (name, (filename, file_pointer, content_type))
|
182
|
+
name will become the name by which the multer can build an array.
|
183
|
+
:param route: route for the request
|
184
|
+
:return: awaitable list(response)
|
185
|
+
"""
|
186
|
+
async with aiohttp.ClientSession() as session:
|
187
|
+
tasks = [
|
188
|
+
asyncio.ensure_future(
|
189
|
+
self._make_files_request(
|
190
|
+
files=files, route=route, session=session
|
191
|
+
)
|
192
|
+
)
|
193
|
+
for files in files_per_request
|
194
|
+
]
|
195
|
+
return await asyncio.gather(*tasks)
|
196
|
+
|
197
|
+
async def _make_files_request(
|
198
|
+
self,
|
199
|
+
files,
|
200
|
+
route: str,
|
201
|
+
session: aiohttp.ClientSession,
|
202
|
+
retry_attempt=0,
|
203
|
+
max_retries=3,
|
204
|
+
sleep_intervals=(1, 3, 9),
|
205
|
+
):
|
206
|
+
"""
|
207
|
+
Makes an async post request with files to a Nucleus endpoint.
|
208
|
+
|
209
|
+
:param files: A list of tuples (name, (filename, file_pointer, file_type))
|
210
|
+
:param route: route for the request
|
211
|
+
:param session: Session to use for post.
|
212
|
+
:return: response
|
213
|
+
"""
|
214
|
+
endpoint = f"{self._client.endpoint}/{route}"
|
215
|
+
|
216
|
+
logger.info("Posting to %s", endpoint)
|
217
|
+
|
218
|
+
form = aiohttp.FormData()
|
219
|
+
|
220
|
+
for file in files:
|
221
|
+
form.add_field(
|
222
|
+
name=file[0],
|
223
|
+
filename=file[1][0],
|
224
|
+
value=file[1][1],
|
225
|
+
content_type=file[1][2],
|
226
|
+
)
|
227
|
+
|
228
|
+
for sleep_time in RetryStrategy.sleep_times + [-1]:
|
229
|
+
|
230
|
+
async with session.post(
|
231
|
+
endpoint,
|
232
|
+
data=form,
|
233
|
+
auth=aiohttp.BasicAuth(self._client.api_key, ""),
|
234
|
+
timeout=DEFAULT_NETWORK_TIMEOUT_SEC,
|
235
|
+
) as response:
|
236
|
+
logger.info(
|
237
|
+
"API request has response code %s", response.status
|
238
|
+
)
|
239
|
+
|
240
|
+
try:
|
241
|
+
data = await response.json()
|
242
|
+
except aiohttp.client_exceptions.ContentTypeError:
|
243
|
+
# In case of 404, the server returns text
|
244
|
+
data = await response.text()
|
245
|
+
if (
|
246
|
+
response.status in RetryStrategy.statuses
|
247
|
+
and sleep_time != -1
|
248
|
+
):
|
249
|
+
time.sleep(sleep_time)
|
250
|
+
continue
|
251
|
+
|
252
|
+
if not response.ok:
|
253
|
+
if retry_attempt < max_retries:
|
254
|
+
time.sleep(sleep_intervals[retry_attempt])
|
255
|
+
retry_attempt += 1
|
256
|
+
return self._make_files_request(
|
257
|
+
files,
|
258
|
+
route,
|
259
|
+
session,
|
260
|
+
retry_attempt,
|
261
|
+
max_retries,
|
262
|
+
sleep_intervals,
|
263
|
+
)
|
264
|
+
else:
|
265
|
+
self._client.handle_bad_response(
|
266
|
+
endpoint,
|
267
|
+
session.post,
|
268
|
+
aiohttp_response=(
|
269
|
+
response.status,
|
270
|
+
response.reason,
|
271
|
+
data,
|
272
|
+
),
|
273
|
+
)
|
274
|
+
|
275
|
+
return data
|
276
|
+
|
277
|
+
def _process_append_requests(
|
278
|
+
self,
|
279
|
+
dataset_id: str,
|
280
|
+
payload: dict,
|
281
|
+
update: bool,
|
282
|
+
batch_size: int = 20,
|
283
|
+
):
|
284
|
+
items = payload[ITEMS_KEY]
|
285
|
+
payloads = [
|
286
|
+
# batch_size images per request
|
287
|
+
{ITEMS_KEY: items[i : i + batch_size], UPDATE_KEY: update}
|
288
|
+
for i in range(0, len(items), batch_size)
|
289
|
+
]
|
290
|
+
|
291
|
+
return [
|
292
|
+
self._client.make_request(
|
293
|
+
payload,
|
294
|
+
f"dataset/{dataset_id}/append",
|
295
|
+
)
|
296
|
+
for payload in payloads
|
297
|
+
]
|
@@ -0,0 +1,32 @@
|
|
1
|
+
import warnings
|
2
|
+
from functools import wraps
|
3
|
+
from typing import Callable
|
4
|
+
|
5
|
+
|
6
|
+
def deprecated(msg: str):
|
7
|
+
"""Adds a deprecation warning via the `warnings` lib which can be caught by linters.
|
8
|
+
|
9
|
+
Args:
|
10
|
+
msg: State reason of deprecation and point towards preferred practices
|
11
|
+
|
12
|
+
Returns:
|
13
|
+
Deprecation wrapped function
|
14
|
+
"""
|
15
|
+
|
16
|
+
def decorator(func: Callable):
|
17
|
+
@wraps(func)
|
18
|
+
def wrapper(*args, **kwargs):
|
19
|
+
# NOTE: __qualname__ looks a lot better for method calls
|
20
|
+
name = (
|
21
|
+
func.__qualname__
|
22
|
+
if hasattr(func, "__qualname__")
|
23
|
+
else func.__name__
|
24
|
+
)
|
25
|
+
full_message = f"Calling {name} is deprecated: {msg}"
|
26
|
+
# NOTE: stacklevel=2 attributes the warning to the caller of the decorated function
|
27
|
+
warnings.warn(full_message, DeprecationWarning, stacklevel=2)
|
28
|
+
return func(*args, **kwargs)
|
29
|
+
|
30
|
+
return wrapper
|
31
|
+
|
32
|
+
return decorator
|
nucleus/errors.py
CHANGED
@@ -4,6 +4,11 @@ nucleus_client_version = pkg_resources.get_distribution(
|
|
4
4
|
"scale-nucleus"
|
5
5
|
).version
|
6
6
|
|
7
|
+
INFRA_FLAKE_MESSAGES = [
|
8
|
+
"downstream duration timeout",
|
9
|
+
"upstream connect error or disconnect/reset before headers. reset reason: local reset",
|
10
|
+
]
|
11
|
+
|
7
12
|
|
8
13
|
class ModelCreationError(Exception):
|
9
14
|
def __init__(self, message="Could not create the model"):
|
@@ -35,7 +40,7 @@ class NucleusAPIError(Exception):
|
|
35
40
|
def __init__(
|
36
41
|
self, endpoint, command, requests_response=None, aiohttp_response=None
|
37
42
|
):
|
38
|
-
message = f"Your client is on version {nucleus_client_version}.
|
43
|
+
message = f"Your client is on version {nucleus_client_version}. If you have not recently done so, please make sure you have updated to the latest version of the client by running pip install --upgrade scale-nucleus\n"
|
39
44
|
if requests_response is not None:
|
40
45
|
message += f"Tried to {command.__name__} {endpoint}, but received {requests_response.status_code}: {requests_response.reason}."
|
41
46
|
if hasattr(requests_response, "text"):
|
@@ -50,4 +55,19 @@ class NucleusAPIError(Exception):
|
|
50
55
|
if data:
|
51
56
|
message += f"\nThe detailed error is:\n{data}"
|
52
57
|
|
58
|
+
if any(
|
59
|
+
infra_flake_message in message
|
60
|
+
for infra_flake_message in INFRA_FLAKE_MESSAGES
|
61
|
+
):
|
62
|
+
message += "\n This likely indicates temporary downtime of the API, please try again in a minute or two"
|
63
|
+
|
53
64
|
super().__init__(message)
|
65
|
+
|
66
|
+
|
67
|
+
class NoAPIKey(Exception):
|
68
|
+
def __init__(
|
69
|
+
self,
|
70
|
+
message="You need to pass an API key to the NucleusClient or set the environment variable NUCLEUS_API_KEY",
|
71
|
+
):
|
72
|
+
self.message = message
|
73
|
+
super().__init__(self.message)
|