wandb 0.22.1__py3-none-win_arm64.whl → 0.22.2__py3-none-win_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +6 -3
- wandb/bin/gpu_stats.exe +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/beta.py +16 -2
- wandb/cli/beta_leet.py +74 -0
- wandb/cli/cli.py +34 -7
- wandb/proto/v3/wandb_api_pb2.py +86 -0
- wandb/proto/v3/wandb_internal_pb2.py +352 -351
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_api_pb2.py +37 -0
- wandb/proto/v4/wandb_internal_pb2.py +352 -351
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_api_pb2.py +38 -0
- wandb/proto/v5/wandb_internal_pb2.py +352 -351
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v6/wandb_api_pb2.py +48 -0
- wandb/proto/v6/wandb_internal_pb2.py +352 -351
- wandb/proto/v6/wandb_settings_pb2.py +2 -2
- wandb/proto/wandb_api_pb2.py +18 -0
- wandb/proto/wandb_generate_proto.py +1 -0
- wandb/sdk/artifacts/artifact.py +30 -30
- wandb/sdk/artifacts/artifact_manifest_entry.py +6 -12
- wandb/sdk/artifacts/storage_handler.py +18 -12
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +11 -6
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +9 -6
- wandb/sdk/artifacts/storage_handlers/http_handler.py +9 -4
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +10 -6
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +5 -4
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +10 -8
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +6 -4
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +24 -21
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +4 -2
- wandb/sdk/artifacts/storage_policies/_multipart.py +187 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +61 -242
- wandb/sdk/artifacts/storage_policy.py +25 -12
- wandb/sdk/data_types/object_3d.py +67 -2
- wandb/sdk/internal/job_builder.py +27 -10
- wandb/sdk/internal/sender.py +4 -1
- wandb/sdk/launch/create_job.py +2 -1
- wandb/sdk/lib/progress.py +1 -70
- wandb/sdk/wandb_init.py +1 -1
- wandb/sdk/wandb_run.py +5 -2
- wandb/sdk/wandb_settings.py +13 -12
- {wandb-0.22.1.dist-info → wandb-0.22.2.dist-info}/METADATA +1 -1
- {wandb-0.22.1.dist-info → wandb-0.22.2.dist-info}/RECORD +49 -42
- {wandb-0.22.1.dist-info → wandb-0.22.2.dist-info}/WHEEL +0 -0
- {wandb-0.22.1.dist-info → wandb-0.22.2.dist-info}/entry_points.txt +0 -0
- {wandb-0.22.1.dist-info → wandb-0.22.2.dist-info}/licenses/LICENSE +0 -0
@@ -3,32 +3,35 @@
|
|
3
3
|
from __future__ import annotations
|
4
4
|
|
5
5
|
import concurrent.futures
|
6
|
-
import functools
|
7
6
|
import hashlib
|
8
7
|
import logging
|
9
|
-
import math
|
10
8
|
import os
|
11
|
-
import queue
|
12
9
|
import shutil
|
13
|
-
import threading
|
14
10
|
from collections import deque
|
15
|
-
from
|
11
|
+
from operator import itemgetter
|
12
|
+
from typing import TYPE_CHECKING, Any
|
16
13
|
from urllib.parse import quote
|
17
14
|
|
18
15
|
import requests
|
19
16
|
|
20
|
-
from wandb import env
|
21
17
|
from wandb.errors.term import termwarn
|
22
18
|
from wandb.proto.wandb_internal_pb2 import ServerFeature
|
23
19
|
from wandb.sdk.artifacts.artifact_file_cache import (
|
24
20
|
ArtifactFileCache,
|
25
|
-
Opener,
|
26
21
|
get_artifact_file_cache,
|
27
22
|
)
|
28
23
|
from wandb.sdk.artifacts.staging import get_staging_dir
|
29
24
|
from wandb.sdk.artifacts.storage_handlers.multi_handler import MultiHandler
|
30
25
|
from wandb.sdk.artifacts.storage_handlers.tracking_handler import TrackingHandler
|
31
26
|
from wandb.sdk.artifacts.storage_layout import StorageLayout
|
27
|
+
from wandb.sdk.artifacts.storage_policies._multipart import (
|
28
|
+
MAX_MULTI_UPLOAD_SIZE,
|
29
|
+
MIN_MULTI_UPLOAD_SIZE,
|
30
|
+
KiB,
|
31
|
+
calc_part_size,
|
32
|
+
multipart_download,
|
33
|
+
scan_chunks,
|
34
|
+
)
|
32
35
|
from wandb.sdk.artifacts.storage_policies.register import WANDB_STORAGE_POLICY
|
33
36
|
from wandb.sdk.artifacts.storage_policy import StoragePolicy
|
34
37
|
from wandb.sdk.internal.internal_api import Api as InternalApi
|
@@ -44,34 +47,9 @@ if TYPE_CHECKING:
|
|
44
47
|
from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
|
45
48
|
from wandb.sdk.internal import progress
|
46
49
|
|
47
|
-
|
48
|
-
# AWS S3 max upload parts without having to make additional requests for extra parts
|
49
|
-
S3_MAX_PART_NUMBERS = 1000
|
50
|
-
S3_MIN_MULTI_UPLOAD_SIZE = 2 * 1024**3
|
51
|
-
S3_MAX_MULTI_UPLOAD_SIZE = 5 * 1024**4
|
52
|
-
|
53
|
-
|
54
|
-
# Minimum size to switch to multipart download, same as upload, 2GB.
|
55
|
-
_MULTIPART_DOWNLOAD_SIZE = S3_MIN_MULTI_UPLOAD_SIZE
|
56
|
-
# Multipart download part size is same as multpart upload size, which is hard coded to 100MB.
|
57
|
-
# https://github.com/wandb/wandb/blob/7b2a13cb8efcd553317167b823c8e52d8c3f7c4e/core/pkg/artifacts/saver.go#L496
|
58
|
-
# https://docs.aws.amazon.com/AmazonS3/latest/userguide/optimizing-performance-guidelines.html#optimizing-performance-guidelines-get-range
|
59
|
-
_DOWNLOAD_PART_SIZE_BYTES = 100 * 1024 * 1024
|
60
|
-
# Chunk size for reading http response and writing to disk. 1MB.
|
61
|
-
_HTTP_RES_CHUNK_SIZE_BYTES = 1 * 1024 * 1024
|
62
|
-
# Signal end of _ChunkQueue, consumer (file writer) should stop after getting this item.
|
63
|
-
# NOTE: it should only be used for multithread executor, it does notwork for multiprocess executor.
|
64
|
-
# multipart download is using the executor from artifact.download() which is a multithread executor.
|
65
|
-
_CHUNK_QUEUE_SENTINEL = object()
|
66
|
-
|
67
50
|
logger = logging.getLogger(__name__)
|
68
51
|
|
69
52
|
|
70
|
-
class _ChunkContent(NamedTuple):
|
71
|
-
offset: int
|
72
|
-
data: bytes
|
73
|
-
|
74
|
-
|
75
53
|
class WandbStoragePolicy(StoragePolicy):
|
76
54
|
@classmethod
|
77
55
|
def name(cls) -> str:
|
@@ -109,7 +87,7 @@ class WandbStoragePolicy(StoragePolicy):
|
|
109
87
|
if not storage_region.strip():
|
110
88
|
raise ValueError("storageRegion must be a non-empty string")
|
111
89
|
|
112
|
-
def config(self) -> dict:
|
90
|
+
def config(self) -> dict[str, Any]:
|
113
91
|
return self._config
|
114
92
|
|
115
93
|
def load_file(
|
@@ -117,8 +95,9 @@ class WandbStoragePolicy(StoragePolicy):
|
|
117
95
|
artifact: Artifact,
|
118
96
|
manifest_entry: ArtifactManifestEntry,
|
119
97
|
dest_path: str | None = None,
|
98
|
+
# FIXME: We should avoid passing the executor into multiple inner functions,
|
99
|
+
# it leads to confusing code and opaque tracebacks/call stacks.
|
120
100
|
executor: concurrent.futures.Executor | None = None,
|
121
|
-
multipart: bool | None = None,
|
122
101
|
) -> FilePathStr:
|
123
102
|
"""Use cache or download the file using signed url.
|
124
103
|
|
@@ -126,10 +105,8 @@ class WandbStoragePolicy(StoragePolicy):
|
|
126
105
|
executor: Passed from caller, artifact has a thread pool for multi file download.
|
127
106
|
Reuse the thread pool for multi part download. The thread pool is closed when
|
128
107
|
artifact download is done.
|
129
|
-
|
130
|
-
|
131
|
-
2GB. If set to `True` or `False`, the artifact will be downloaded in
|
132
|
-
parallel or serially regardless of the file size.
|
108
|
+
|
109
|
+
If this is None, download the file serially.
|
133
110
|
"""
|
134
111
|
if dest_path is not None:
|
135
112
|
self._cache._override_cache_path = dest_path
|
@@ -141,14 +118,10 @@ class WandbStoragePolicy(StoragePolicy):
|
|
141
118
|
if hit:
|
142
119
|
return path
|
143
120
|
|
144
|
-
if
|
121
|
+
if url := manifest_entry._download_url:
|
145
122
|
# Use multipart parallel download for large file
|
146
|
-
if (
|
147
|
-
executor
|
148
|
-
and (size := manifest_entry.size)
|
149
|
-
and self._should_multipart_download(size, multipart)
|
150
|
-
):
|
151
|
-
self._multipart_file_download(executor, url, size, cache_open)
|
123
|
+
if executor and (size := manifest_entry.size):
|
124
|
+
multipart_download(executor, self._session, url, size, cache_open)
|
152
125
|
return path
|
153
126
|
|
154
127
|
# Serial download
|
@@ -171,142 +144,16 @@ class WandbStoragePolicy(StoragePolicy):
|
|
171
144
|
else:
|
172
145
|
auth = ("api", self._api.api_key or "")
|
173
146
|
|
174
|
-
file_url = self._file_url(
|
175
|
-
self._api,
|
176
|
-
artifact.entity,
|
177
|
-
artifact.project,
|
178
|
-
artifact.name.split(":")[0],
|
179
|
-
manifest_entry,
|
180
|
-
)
|
147
|
+
file_url = self._file_url(self._api, artifact, manifest_entry)
|
181
148
|
response = self._session.get(
|
182
149
|
file_url, auth=auth, cookies=cookies, headers=headers, stream=True
|
183
150
|
)
|
184
151
|
|
185
152
|
with cache_open(mode="wb") as file:
|
186
|
-
for data in response.iter_content(chunk_size=16 *
|
153
|
+
for data in response.iter_content(chunk_size=16 * KiB):
|
187
154
|
file.write(data)
|
188
155
|
return path
|
189
156
|
|
190
|
-
def _should_multipart_download(
|
191
|
-
self,
|
192
|
-
file_size: int,
|
193
|
-
multipart: bool | None,
|
194
|
-
) -> bool:
|
195
|
-
if multipart is not None:
|
196
|
-
return multipart
|
197
|
-
return file_size >= _MULTIPART_DOWNLOAD_SIZE
|
198
|
-
|
199
|
-
def _write_chunks_to_file(
|
200
|
-
self,
|
201
|
-
f: IO,
|
202
|
-
q: queue.Queue,
|
203
|
-
download_has_error: threading.Event,
|
204
|
-
):
|
205
|
-
while not download_has_error.is_set():
|
206
|
-
item = q.get()
|
207
|
-
if item is _CHUNK_QUEUE_SENTINEL:
|
208
|
-
# Normal shutdown, all the chunks are written
|
209
|
-
return
|
210
|
-
elif isinstance(item, _ChunkContent):
|
211
|
-
try:
|
212
|
-
# NOTE: Seek works without pre allocating the file on disk.
|
213
|
-
# It automatically creates a sparse file, e.g. ls -hl would show
|
214
|
-
# a bigger size compared to du -sh * because downloading different
|
215
|
-
# chunks is not a sequential write.
|
216
|
-
# See https://man7.org/linux/man-pages/man2/lseek.2.html
|
217
|
-
f.seek(item.offset)
|
218
|
-
f.write(item.data)
|
219
|
-
except Exception as e:
|
220
|
-
if env.is_debug():
|
221
|
-
logger.debug(f"Error writing chunk to file: {e}")
|
222
|
-
download_has_error.set()
|
223
|
-
raise
|
224
|
-
else:
|
225
|
-
raise ValueError(f"Unknown queue item type: {type(item)}")
|
226
|
-
|
227
|
-
def _download_part(
|
228
|
-
self,
|
229
|
-
download_url: str,
|
230
|
-
headers: dict,
|
231
|
-
start: int,
|
232
|
-
q: queue.Queue,
|
233
|
-
download_has_error: threading.Event,
|
234
|
-
):
|
235
|
-
# Other threads has error, no need to start
|
236
|
-
if download_has_error.is_set():
|
237
|
-
return
|
238
|
-
response = self._session.get(url=download_url, headers=headers, stream=True)
|
239
|
-
|
240
|
-
file_offset = start
|
241
|
-
for content in response.iter_content(chunk_size=_HTTP_RES_CHUNK_SIZE_BYTES):
|
242
|
-
if download_has_error.is_set():
|
243
|
-
return
|
244
|
-
q.put(_ChunkContent(offset=file_offset, data=content))
|
245
|
-
file_offset += len(content)
|
246
|
-
|
247
|
-
def _multipart_file_download(
|
248
|
-
self,
|
249
|
-
executor: concurrent.futures.Executor,
|
250
|
-
download_url: str,
|
251
|
-
file_size_bytes: int,
|
252
|
-
cache_open: Opener,
|
253
|
-
):
|
254
|
-
"""Download file as multiple parts in parallel.
|
255
|
-
|
256
|
-
Only one thread for writing to file. Each part run one http request in one thread.
|
257
|
-
HTTP response chunk of a file part is sent to the writer thread via a queue.
|
258
|
-
"""
|
259
|
-
q: queue.Queue[_ChunkContent | object] = queue.Queue(maxsize=500)
|
260
|
-
download_has_error = threading.Event()
|
261
|
-
|
262
|
-
# Put cache_open at top so we remove the tmp file when there is network error.
|
263
|
-
with cache_open("wb") as f:
|
264
|
-
# Start writer thread first.
|
265
|
-
write_handler = functools.partial(
|
266
|
-
self._write_chunks_to_file, f, q, download_has_error
|
267
|
-
)
|
268
|
-
write_future = executor.submit(write_handler)
|
269
|
-
|
270
|
-
# Start download threads for each part.
|
271
|
-
download_futures: deque[concurrent.futures.Future] = deque()
|
272
|
-
part_size = _DOWNLOAD_PART_SIZE_BYTES
|
273
|
-
num_parts = int(math.ceil(file_size_bytes / float(part_size)))
|
274
|
-
for i in range(num_parts):
|
275
|
-
# https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Range
|
276
|
-
# Start and end are both inclusive, empty end means use the actual end of the file.
|
277
|
-
start = i * part_size
|
278
|
-
bytes_range = f"bytes={start}-"
|
279
|
-
if i != (num_parts - 1):
|
280
|
-
# bytes=0-499
|
281
|
-
bytes_range += f"{start + part_size - 1}"
|
282
|
-
headers = {"Range": bytes_range}
|
283
|
-
download_handler = functools.partial(
|
284
|
-
self._download_part,
|
285
|
-
download_url,
|
286
|
-
headers,
|
287
|
-
start,
|
288
|
-
q,
|
289
|
-
download_has_error,
|
290
|
-
)
|
291
|
-
download_futures.append(executor.submit(download_handler))
|
292
|
-
|
293
|
-
# Wait for download
|
294
|
-
done, not_done = concurrent.futures.wait(
|
295
|
-
download_futures, return_when=concurrent.futures.FIRST_EXCEPTION
|
296
|
-
)
|
297
|
-
try:
|
298
|
-
for fut in done:
|
299
|
-
fut.result()
|
300
|
-
except Exception as e:
|
301
|
-
if env.is_debug():
|
302
|
-
logger.debug(f"Error downloading file: {e}")
|
303
|
-
download_has_error.set()
|
304
|
-
raise
|
305
|
-
finally:
|
306
|
-
# Always signal the writer to stop
|
307
|
-
q.put(_CHUNK_QUEUE_SENTINEL)
|
308
|
-
write_future.result()
|
309
|
-
|
310
157
|
def store_reference(
|
311
158
|
self,
|
312
159
|
artifact: Artifact,
|
@@ -314,7 +161,7 @@ class WandbStoragePolicy(StoragePolicy):
|
|
314
161
|
name: str | None = None,
|
315
162
|
checksum: bool = True,
|
316
163
|
max_objects: int | None = None,
|
317
|
-
) ->
|
164
|
+
) -> list[ArtifactManifestEntry]:
|
318
165
|
return self._handler.store_path(
|
319
166
|
artifact, path, name=name, checksum=checksum, max_objects=max_objects
|
320
167
|
)
|
@@ -334,13 +181,16 @@ class WandbStoragePolicy(StoragePolicy):
|
|
334
181
|
def _file_url(
|
335
182
|
self,
|
336
183
|
api: InternalApi,
|
337
|
-
|
338
|
-
project_name: str,
|
339
|
-
artifact_name: str,
|
184
|
+
artifact: Artifact,
|
340
185
|
entry: ArtifactManifestEntry,
|
341
186
|
) -> str:
|
342
187
|
layout = self._config.get("storageLayout", StorageLayout.V1)
|
343
188
|
region = self._config.get("storageRegion", "default")
|
189
|
+
|
190
|
+
entity_name = artifact.entity
|
191
|
+
project_name = artifact.project
|
192
|
+
artifact_name = artifact.name.split(":")[0]
|
193
|
+
|
344
194
|
md5_hex = b64_to_hex_id(entry.digest)
|
345
195
|
|
346
196
|
base_url: str = api.settings("base_url")
|
@@ -367,30 +217,21 @@ class WandbStoragePolicy(StoragePolicy):
|
|
367
217
|
multipart_urls: dict[int, str],
|
368
218
|
extra_headers: dict[str, str],
|
369
219
|
) -> list[dict[str, Any]]:
|
370
|
-
etags =
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
data
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
"content-type": extra_headers.get("Content-Type", ""),
|
386
|
-
},
|
387
|
-
)
|
388
|
-
assert upload_resp is not None
|
389
|
-
etags.append(
|
390
|
-
{"partNumber": part_number, "hexMD5": upload_resp.headers["ETag"]}
|
391
|
-
)
|
392
|
-
part_number += 1
|
393
|
-
return etags
|
220
|
+
etags: deque[dict[str, Any]] = deque()
|
221
|
+
file_chunks = scan_chunks(file_path, chunk_size)
|
222
|
+
for num, data in enumerate(file_chunks, start=1):
|
223
|
+
rsp = self._api.upload_multipart_file_chunk_retry(
|
224
|
+
multipart_urls[num],
|
225
|
+
data,
|
226
|
+
extra_headers={
|
227
|
+
"content-md5": hex_to_b64_id(hex_digests[num]),
|
228
|
+
"content-length": str(len(data)),
|
229
|
+
"content-type": extra_headers.get("Content-Type") or "",
|
230
|
+
},
|
231
|
+
)
|
232
|
+
assert rsp is not None
|
233
|
+
etags.append({"partNumber": num, "hexMD5": rsp.headers["ETag"]})
|
234
|
+
return list(etags)
|
394
235
|
|
395
236
|
def default_file_upload(
|
396
237
|
self,
|
@@ -403,20 +244,9 @@ class WandbStoragePolicy(StoragePolicy):
|
|
403
244
|
with open(file_path, "rb") as file:
|
404
245
|
# This fails if we don't send the first byte before the signed URL expires.
|
405
246
|
self._api.upload_file_retry(
|
406
|
-
upload_url,
|
407
|
-
file,
|
408
|
-
progress_callback,
|
409
|
-
extra_headers=extra_headers,
|
247
|
+
upload_url, file, progress_callback, extra_headers=extra_headers
|
410
248
|
)
|
411
249
|
|
412
|
-
def calc_chunk_size(self, file_size: int) -> int:
|
413
|
-
# Default to chunk size of 100MiB. S3 has cap of 10,000 upload parts.
|
414
|
-
# If file size exceeds the default chunk size, recalculate chunk size.
|
415
|
-
default_chunk_size = 100 * 1024**2
|
416
|
-
if default_chunk_size * S3_MAX_PART_NUMBERS < file_size:
|
417
|
-
return math.ceil(file_size / S3_MAX_PART_NUMBERS)
|
418
|
-
return default_chunk_size
|
419
|
-
|
420
250
|
def store_file(
|
421
251
|
self,
|
422
252
|
artifact_id: str,
|
@@ -432,28 +262,20 @@ class WandbStoragePolicy(StoragePolicy):
|
|
432
262
|
False if it needed to be uploaded or was a reference (nothing to dedupe).
|
433
263
|
"""
|
434
264
|
file_size = entry.size or 0
|
435
|
-
chunk_size =
|
436
|
-
|
437
|
-
hex_digests = {}
|
438
|
-
file_path = entry.local_path if entry.local_path is not None else ""
|
265
|
+
chunk_size = calc_part_size(file_size)
|
266
|
+
file_path = entry.local_path or ""
|
439
267
|
# Logic for AWS s3 multipart upload.
|
440
268
|
# Only chunk files if larger than 2 GiB. Currently can only support up to 5TiB.
|
441
|
-
if
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
hex_digest = hashlib.md5(data).hexdigest()
|
452
|
-
upload_parts.append(
|
453
|
-
{"hexMD5": hex_digest, "partNumber": part_number}
|
454
|
-
)
|
455
|
-
hex_digests[part_number] = hex_digest
|
456
|
-
part_number += 1
|
269
|
+
if MIN_MULTI_UPLOAD_SIZE <= file_size <= MAX_MULTI_UPLOAD_SIZE:
|
270
|
+
file_chunks = scan_chunks(file_path, chunk_size)
|
271
|
+
upload_parts = [
|
272
|
+
{"partNumber": num, "hexMD5": hashlib.md5(data).hexdigest()}
|
273
|
+
for num, data in enumerate(file_chunks, start=1)
|
274
|
+
]
|
275
|
+
hex_digests = dict(map(itemgetter("partNumber", "hexMD5"), upload_parts))
|
276
|
+
else:
|
277
|
+
upload_parts = []
|
278
|
+
hex_digests = {}
|
457
279
|
|
458
280
|
resp = preparer.prepare(
|
459
281
|
{
|
@@ -467,24 +289,21 @@ class WandbStoragePolicy(StoragePolicy):
|
|
467
289
|
|
468
290
|
entry.birth_artifact_id = resp.birth_artifact_id
|
469
291
|
|
470
|
-
multipart_urls = resp.multipart_upload_urls
|
471
292
|
if resp.upload_url is None:
|
472
293
|
return True
|
473
294
|
if entry.local_path is None:
|
474
295
|
return False
|
475
|
-
|
476
|
-
|
477
|
-
for header in (resp.upload_headers or {})
|
478
|
-
}
|
296
|
+
|
297
|
+
extra_headers = dict(hdr.split(":", 1) for hdr in (resp.upload_headers or []))
|
479
298
|
|
480
299
|
# This multipart upload isn't available, do a regular single url upload
|
481
|
-
if multipart_urls is None and resp.upload_url:
|
300
|
+
if (multipart_urls := resp.multipart_upload_urls) is None and resp.upload_url:
|
482
301
|
self.default_file_upload(
|
483
302
|
resp.upload_url, file_path, extra_headers, progress_callback
|
484
303
|
)
|
304
|
+
elif multipart_urls is None:
|
305
|
+
raise ValueError(f"No multipart urls to upload for file: {file_path}")
|
485
306
|
else:
|
486
|
-
if multipart_urls is None:
|
487
|
-
raise ValueError(f"No multipart urls to upload for file: {file_path}")
|
488
307
|
# Upload files using s3 multipart upload urls
|
489
308
|
etags = self.s3_multipart_file_upload(
|
490
309
|
file_path,
|
@@ -513,7 +332,7 @@ class WandbStoragePolicy(StoragePolicy):
|
|
513
332
|
|
514
333
|
staging_dir = get_staging_dir()
|
515
334
|
try:
|
516
|
-
if not entry.skip_cache
|
335
|
+
if not (entry.skip_cache or hit):
|
517
336
|
with cache_open("wb") as f, open(entry.local_path, "rb") as src:
|
518
337
|
shutil.copyfileobj(src, f)
|
519
338
|
if entry.local_path.startswith(staging_dir):
|
@@ -3,7 +3,8 @@
|
|
3
3
|
from __future__ import annotations
|
4
4
|
|
5
5
|
import concurrent.futures
|
6
|
-
from
|
6
|
+
from abc import ABC, abstractmethod
|
7
|
+
from typing import TYPE_CHECKING, Any
|
7
8
|
|
8
9
|
from wandb.sdk.internal.internal_api import Api as InternalApi
|
9
10
|
from wandb.sdk.lib.paths import FilePathStr, URIStr
|
@@ -15,37 +16,47 @@ if TYPE_CHECKING:
|
|
15
16
|
from wandb.sdk.internal.progress import ProgressFn
|
16
17
|
|
17
18
|
|
18
|
-
|
19
|
+
_POLICY_REGISTRY: dict[str, type[StoragePolicy]] = {}
|
20
|
+
|
21
|
+
|
22
|
+
class StoragePolicy(ABC):
|
23
|
+
def __init_subclass__(cls, **kwargs: Any) -> None:
|
24
|
+
super().__init_subclass__(**kwargs)
|
25
|
+
_POLICY_REGISTRY[cls.name()] = cls
|
26
|
+
|
19
27
|
@classmethod
|
20
28
|
def lookup_by_name(cls, name: str) -> type[StoragePolicy]:
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
if sub.name() == name:
|
25
|
-
return sub
|
26
|
-
raise NotImplementedError(f"Failed to find storage policy '{name}'")
|
29
|
+
if policy := _POLICY_REGISTRY.get(name):
|
30
|
+
return policy
|
31
|
+
raise ValueError(f"Failed to find storage policy {name!r}")
|
27
32
|
|
28
33
|
@classmethod
|
34
|
+
@abstractmethod
|
29
35
|
def name(cls) -> str:
|
30
36
|
raise NotImplementedError
|
31
37
|
|
32
38
|
@classmethod
|
33
|
-
|
39
|
+
@abstractmethod
|
40
|
+
def from_config(
|
41
|
+
cls, config: dict[str, Any], api: InternalApi | None = None
|
42
|
+
) -> StoragePolicy:
|
34
43
|
raise NotImplementedError
|
35
44
|
|
36
|
-
|
45
|
+
@abstractmethod
|
46
|
+
def config(self) -> dict[str, Any]:
|
37
47
|
raise NotImplementedError
|
38
48
|
|
49
|
+
@abstractmethod
|
39
50
|
def load_file(
|
40
51
|
self,
|
41
52
|
artifact: Artifact,
|
42
53
|
manifest_entry: ArtifactManifestEntry,
|
43
54
|
dest_path: str | None = None,
|
44
55
|
executor: concurrent.futures.Executor | None = None,
|
45
|
-
multipart: bool | None = None,
|
46
56
|
) -> FilePathStr:
|
47
57
|
raise NotImplementedError
|
48
58
|
|
59
|
+
@abstractmethod
|
49
60
|
def store_file(
|
50
61
|
self,
|
51
62
|
artifact_id: str,
|
@@ -56,6 +67,7 @@ class StoragePolicy:
|
|
56
67
|
) -> bool:
|
57
68
|
raise NotImplementedError
|
58
69
|
|
70
|
+
@abstractmethod
|
59
71
|
def store_reference(
|
60
72
|
self,
|
61
73
|
artifact: Artifact,
|
@@ -63,9 +75,10 @@ class StoragePolicy:
|
|
63
75
|
name: str | None = None,
|
64
76
|
checksum: bool = True,
|
65
77
|
max_objects: int | None = None,
|
66
|
-
) ->
|
78
|
+
) -> list[ArtifactManifestEntry]:
|
67
79
|
raise NotImplementedError
|
68
80
|
|
81
|
+
@abstractmethod
|
69
82
|
def load_reference(
|
70
83
|
self,
|
71
84
|
manifest_entry: ArtifactManifestEntry,
|
@@ -138,7 +138,7 @@ def box3d(
|
|
138
138
|
label: "Optional[str]" = None,
|
139
139
|
score: "Optional[numeric]" = None,
|
140
140
|
) -> "Box3D":
|
141
|
-
"""
|
141
|
+
"""A 3D bounding box. The box is specified by its center, size and orientation.
|
142
142
|
|
143
143
|
Args:
|
144
144
|
center: The center point of the box as a length-3 ndarray.
|
@@ -149,7 +149,72 @@ def box3d(
|
|
149
149
|
r + xi + yj + zk.
|
150
150
|
color: The box's color as an (r, g, b) tuple with 0 <= r,g,b <= 1.
|
151
151
|
label: An optional label for the box.
|
152
|
-
score: An optional score for the box.
|
152
|
+
score: An optional score for the box. Typically used to indicate
|
153
|
+
the confidence of a detection.
|
154
|
+
|
155
|
+
Returns:
|
156
|
+
A Box3D object.
|
157
|
+
|
158
|
+
Example:
|
159
|
+
The following example creates a point cloud with 60 boxes rotating
|
160
|
+
around the X, Y and Z axes.
|
161
|
+
|
162
|
+
```python
|
163
|
+
import wandb
|
164
|
+
|
165
|
+
import math
|
166
|
+
import numpy as np
|
167
|
+
from scipy.spatial.transform import Rotation
|
168
|
+
|
169
|
+
|
170
|
+
with wandb.init() as run:
|
171
|
+
run.log(
|
172
|
+
{
|
173
|
+
"points": wandb.Object3D.from_point_cloud(
|
174
|
+
points=np.random.uniform(-5, 5, size=(100, 3)),
|
175
|
+
boxes=[
|
176
|
+
wandb.box3d(
|
177
|
+
center=(0.3 * t - 3, 0, 0),
|
178
|
+
size=(0.1, 0.1, 0.1),
|
179
|
+
orientation=Rotation.from_euler(
|
180
|
+
"xyz", [t * math.pi / 10, 0, 0]
|
181
|
+
).as_quat(),
|
182
|
+
color=(0.5 + t / 40, 0.5, 0.5),
|
183
|
+
label=f"box {t}",
|
184
|
+
score=0.9,
|
185
|
+
)
|
186
|
+
for t in range(20)
|
187
|
+
]
|
188
|
+
+ [
|
189
|
+
wandb.box3d(
|
190
|
+
center=(0, 0.3 * t - 3, 0.3),
|
191
|
+
size=(0.1, 0.1, 0.1),
|
192
|
+
orientation=Rotation.from_euler(
|
193
|
+
"xyz", [0, t * math.pi / 10, 0]
|
194
|
+
).as_quat(),
|
195
|
+
color=(0.5, 0.5 + t / 40, 0.5),
|
196
|
+
label=f"box {t}",
|
197
|
+
score=0.9,
|
198
|
+
)
|
199
|
+
for t in range(20)
|
200
|
+
]
|
201
|
+
+ [
|
202
|
+
wandb.box3d(
|
203
|
+
center=(0.3, 0.3, 0.3 * t - 3),
|
204
|
+
size=(0.1, 0.1, 0.1),
|
205
|
+
orientation=Rotation.from_euler(
|
206
|
+
"xyz", [0, 0, t * math.pi / 10]
|
207
|
+
).as_quat(),
|
208
|
+
color=(0.5, 0.5, 0.5 + t / 40),
|
209
|
+
label=f"box {t}",
|
210
|
+
score=0.9,
|
211
|
+
)
|
212
|
+
for t in range(20)
|
213
|
+
],
|
214
|
+
),
|
215
|
+
}
|
216
|
+
)
|
217
|
+
```
|
153
218
|
"""
|
154
219
|
try:
|
155
220
|
import numpy as np
|