wandb 0.22.0__py3-none-win32.whl → 0.22.2__py3-none-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +8 -5
- wandb/_pydantic/__init__.py +12 -11
- wandb/_pydantic/base.py +49 -19
- wandb/apis/__init__.py +2 -0
- wandb/apis/attrs.py +2 -0
- wandb/apis/importers/internals/internal.py +16 -23
- wandb/apis/internal.py +2 -0
- wandb/apis/normalize.py +2 -0
- wandb/apis/public/__init__.py +3 -2
- wandb/apis/public/api.py +215 -164
- wandb/apis/public/artifacts.py +23 -20
- wandb/apis/public/const.py +2 -0
- wandb/apis/public/files.py +33 -24
- wandb/apis/public/history.py +2 -0
- wandb/apis/public/jobs.py +20 -18
- wandb/apis/public/projects.py +4 -2
- wandb/apis/public/query_generator.py +3 -0
- wandb/apis/public/registries/__init__.py +7 -0
- wandb/apis/public/registries/_freezable_list.py +9 -12
- wandb/apis/public/registries/registries_search.py +8 -6
- wandb/apis/public/registries/registry.py +22 -17
- wandb/apis/public/reports.py +2 -0
- wandb/apis/public/runs.py +261 -57
- wandb/apis/public/sweeps.py +10 -9
- wandb/apis/public/teams.py +2 -0
- wandb/apis/public/users.py +2 -0
- wandb/apis/public/utils.py +16 -15
- wandb/automations/_generated/__init__.py +54 -127
- wandb/automations/_generated/create_generic_webhook_integration.py +1 -7
- wandb/automations/_generated/fragments.py +26 -91
- wandb/bin/gpu_stats.exe +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/beta.py +16 -2
- wandb/cli/beta_leet.py +74 -0
- wandb/cli/beta_sync.py +9 -11
- wandb/cli/cli.py +34 -7
- wandb/errors/errors.py +3 -3
- wandb/proto/v3/wandb_api_pb2.py +86 -0
- wandb/proto/v3/wandb_internal_pb2.py +352 -351
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_sync_pb2.py +19 -6
- wandb/proto/v4/wandb_api_pb2.py +37 -0
- wandb/proto/v4/wandb_internal_pb2.py +352 -351
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_sync_pb2.py +10 -6
- wandb/proto/v5/wandb_api_pb2.py +38 -0
- wandb/proto/v5/wandb_internal_pb2.py +352 -351
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_sync_pb2.py +10 -6
- wandb/proto/v6/wandb_api_pb2.py +48 -0
- wandb/proto/v6/wandb_internal_pb2.py +352 -351
- wandb/proto/v6/wandb_settings_pb2.py +2 -2
- wandb/proto/v6/wandb_sync_pb2.py +10 -6
- wandb/proto/wandb_api_pb2.py +18 -0
- wandb/proto/wandb_generate_proto.py +1 -0
- wandb/sdk/artifacts/_factories.py +7 -2
- wandb/sdk/artifacts/_generated/__init__.py +112 -412
- wandb/sdk/artifacts/_generated/fragments.py +65 -0
- wandb/sdk/artifacts/_generated/operations.py +52 -22
- wandb/sdk/artifacts/_generated/run_input_artifacts.py +3 -23
- wandb/sdk/artifacts/_generated/run_output_artifacts.py +3 -23
- wandb/sdk/artifacts/_generated/type_info.py +19 -0
- wandb/sdk/artifacts/_gqlutils.py +47 -0
- wandb/sdk/artifacts/_models/__init__.py +4 -0
- wandb/sdk/artifacts/_models/base_model.py +20 -0
- wandb/sdk/artifacts/_validators.py +40 -12
- wandb/sdk/artifacts/artifact.py +99 -118
- wandb/sdk/artifacts/artifact_file_cache.py +6 -1
- wandb/sdk/artifacts/artifact_manifest_entry.py +67 -14
- wandb/sdk/artifacts/storage_handler.py +18 -12
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +11 -6
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +9 -6
- wandb/sdk/artifacts/storage_handlers/http_handler.py +9 -4
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +10 -6
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +5 -4
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +10 -8
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +6 -4
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +24 -21
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +4 -2
- wandb/sdk/artifacts/storage_policies/_multipart.py +187 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +71 -242
- wandb/sdk/artifacts/storage_policy.py +25 -12
- wandb/sdk/data_types/bokeh.py +5 -1
- wandb/sdk/data_types/image.py +17 -6
- wandb/sdk/data_types/object_3d.py +67 -2
- wandb/sdk/interface/interface.py +31 -4
- wandb/sdk/interface/interface_queue.py +10 -0
- wandb/sdk/interface/interface_shared.py +0 -7
- wandb/sdk/interface/interface_sock.py +9 -3
- wandb/sdk/internal/_generated/__init__.py +2 -12
- wandb/sdk/internal/job_builder.py +27 -10
- wandb/sdk/internal/sender.py +5 -2
- wandb/sdk/internal/settings_static.py +2 -82
- wandb/sdk/launch/create_job.py +2 -1
- wandb/sdk/launch/runner/kubernetes_runner.py +25 -20
- wandb/sdk/launch/utils.py +82 -1
- wandb/sdk/lib/progress.py +8 -74
- wandb/sdk/lib/service/service_client.py +5 -9
- wandb/sdk/lib/service/service_connection.py +39 -23
- wandb/sdk/mailbox/mailbox_handle.py +2 -0
- wandb/sdk/projects/_generated/__init__.py +12 -33
- wandb/sdk/wandb_init.py +23 -3
- wandb/sdk/wandb_login.py +53 -27
- wandb/sdk/wandb_run.py +10 -5
- wandb/sdk/wandb_settings.py +63 -25
- wandb/sync/sync.py +7 -2
- wandb/util.py +1 -1
- {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/METADATA +1 -1
- {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/RECORD +113 -103
- wandb/sdk/artifacts/_graphql_fragments.py +0 -19
- {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/WHEEL +0 -0
- {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/entry_points.txt +0 -0
- {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py
CHANGED
@@ -3,32 +3,35 @@
 from __future__ import annotations
 
 import concurrent.futures
-import functools
 import hashlib
 import logging
-import math
 import os
-import queue
 import shutil
-import threading
 from collections import deque
-from
+from operator import itemgetter
+from typing import TYPE_CHECKING, Any
 from urllib.parse import quote
 
 import requests
 
-from wandb import env
 from wandb.errors.term import termwarn
 from wandb.proto.wandb_internal_pb2 import ServerFeature
 from wandb.sdk.artifacts.artifact_file_cache import (
     ArtifactFileCache,
-    Opener,
     get_artifact_file_cache,
 )
 from wandb.sdk.artifacts.staging import get_staging_dir
 from wandb.sdk.artifacts.storage_handlers.multi_handler import MultiHandler
 from wandb.sdk.artifacts.storage_handlers.tracking_handler import TrackingHandler
 from wandb.sdk.artifacts.storage_layout import StorageLayout
+from wandb.sdk.artifacts.storage_policies._multipart import (
+    MAX_MULTI_UPLOAD_SIZE,
+    MIN_MULTI_UPLOAD_SIZE,
+    KiB,
+    calc_part_size,
+    multipart_download,
+    scan_chunks,
+)
 from wandb.sdk.artifacts.storage_policies.register import WANDB_STORAGE_POLICY
 from wandb.sdk.artifacts.storage_policy import StoragePolicy
 from wandb.sdk.internal.internal_api import Api as InternalApi
@@ -44,34 +47,9 @@ if TYPE_CHECKING:
     from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
     from wandb.sdk.internal import progress
 
-
-# AWS S3 max upload parts without having to make additional requests for extra parts
-S3_MAX_PART_NUMBERS = 1000
-S3_MIN_MULTI_UPLOAD_SIZE = 2 * 1024**3
-S3_MAX_MULTI_UPLOAD_SIZE = 5 * 1024**4
-
-
-# Minimum size to switch to multipart download, same as upload, 2GB.
-_MULTIPART_DOWNLOAD_SIZE = S3_MIN_MULTI_UPLOAD_SIZE
-# Multipart download part size is same as multpart upload size, which is hard coded to 100MB.
-# https://github.com/wandb/wandb/blob/7b2a13cb8efcd553317167b823c8e52d8c3f7c4e/core/pkg/artifacts/saver.go#L496
-# https://docs.aws.amazon.com/AmazonS3/latest/userguide/optimizing-performance-guidelines.html#optimizing-performance-guidelines-get-range
-_DOWNLOAD_PART_SIZE_BYTES = 100 * 1024 * 1024
-# Chunk size for reading http response and writing to disk. 1MB.
-_HTTP_RES_CHUNK_SIZE_BYTES = 1 * 1024 * 1024
-# Signal end of _ChunkQueue, consumer (file writer) should stop after getting this item.
-# NOTE: it should only be used for multithread executor, it does notwork for multiprocess executor.
-# multipart download is using the executor from artifact.download() which is a multithread executor.
-_CHUNK_QUEUE_SENTINEL = object()
-
 logger = logging.getLogger(__name__)
 
 
-class _ChunkContent(NamedTuple):
-    offset: int
-    data: bytes
-
-
 class WandbStoragePolicy(StoragePolicy):
     @classmethod
     def name(cls) -> str:
@@ -91,6 +69,8 @@ class WandbStoragePolicy(StoragePolicy):
         session: requests.Session | None = None,
     ) -> None:
         self._config = config or {}
+        if (storage_region := self._config.get("storageRegion")) is not None:
+            self._validate_storage_region(storage_region)
         self._cache = cache or get_artifact_file_cache()
         self._session = session or make_http_session()
         self._api = api or InternalApi()
@@ -99,7 +79,15 @@ class WandbStoragePolicy(StoragePolicy):
             default_handler=TrackingHandler(),
         )
 
-    def config(self) -> dict[str, Any]:
+    def _validate_storage_region(self, storage_region: Any) -> None:
+        if not isinstance(storage_region, str):
+            raise TypeError(
+                f"storageRegion must be a string, got {type(storage_region).__name__}: {storage_region!r}"
+            )
+        if not storage_region.strip():
+            raise ValueError("storageRegion must be a non-empty string")
+
+    def config(self) -> dict[str, Any]:
         return self._config
 
     def load_file(
@@ -107,8 +95,9 @@ class WandbStoragePolicy(StoragePolicy):
         artifact: Artifact,
         manifest_entry: ArtifactManifestEntry,
         dest_path: str | None = None,
+        # FIXME: We should avoid passing the executor into multiple inner functions,
+        # it leads to confusing code and opaque tracebacks/call stacks.
         executor: concurrent.futures.Executor | None = None,
-        multipart: bool | None = None,
     ) -> FilePathStr:
         """Use cache or download the file using signed url.
 
@@ -116,10 +105,8 @@ class WandbStoragePolicy(StoragePolicy):
             executor: Passed from caller, artifact has a thread pool for multi file download.
                 Reuse the thread pool for multi part download. The thread pool is closed when
                 artifact download is done.
-            multipart: If set to `None` (default), the artifact will be downloaded
-                in parallel using multipart download if individual file size is greater than
-                2GB. If set to `True` or `False`, the artifact will be downloaded in
-                parallel or serially regardless of the file size.
+
+                If this is None, download the file serially.
         """
         if dest_path is not None:
             self._cache._override_cache_path = dest_path
@@ -131,14 +118,10 @@ class WandbStoragePolicy(StoragePolicy):
         if hit:
             return path
 
-        if
+        if url := manifest_entry._download_url:
             # Use multipart parallel download for large file
-            if (
-                executor
-                and (size := manifest_entry.size)
-                and self._should_multipart_download(size, multipart)
-            ):
-                self._multipart_file_download(executor, url, size, cache_open)
+            if executor and (size := manifest_entry.size):
+                multipart_download(executor, self._session, url, size, cache_open)
             return path
 
         # Serial download
@@ -161,142 +144,16 @@ class WandbStoragePolicy(StoragePolicy):
         else:
             auth = ("api", self._api.api_key or "")
 
-        file_url = self._file_url(
-            self._api,
-            artifact.entity,
-            artifact.project,
-            artifact.name.split(":")[0],
-            manifest_entry,
-        )
+        file_url = self._file_url(self._api, artifact, manifest_entry)
         response = self._session.get(
             file_url, auth=auth, cookies=cookies, headers=headers, stream=True
         )
 
         with cache_open(mode="wb") as file:
-            for data in response.iter_content(chunk_size=16 * 1024):
+            for data in response.iter_content(chunk_size=16 * KiB):
                 file.write(data)
         return path
 
-    def _should_multipart_download(
-        self,
-        file_size: int,
-        multipart: bool | None,
-    ) -> bool:
-        if multipart is not None:
-            return multipart
-        return file_size >= _MULTIPART_DOWNLOAD_SIZE
-
-    def _write_chunks_to_file(
-        self,
-        f: IO,
-        q: queue.Queue,
-        download_has_error: threading.Event,
-    ):
-        while not download_has_error.is_set():
-            item = q.get()
-            if item is _CHUNK_QUEUE_SENTINEL:
-                # Normal shutdown, all the chunks are written
-                return
-            elif isinstance(item, _ChunkContent):
-                try:
-                    # NOTE: Seek works without pre allocating the file on disk.
-                    # It automatically creates a sparse file, e.g. ls -hl would show
-                    # a bigger size compared to du -sh * because downloading different
-                    # chunks is not a sequential write.
-                    # See https://man7.org/linux/man-pages/man2/lseek.2.html
-                    f.seek(item.offset)
-                    f.write(item.data)
-                except Exception as e:
-                    if env.is_debug():
-                        logger.debug(f"Error writing chunk to file: {e}")
-                    download_has_error.set()
-                    raise
-            else:
-                raise ValueError(f"Unknown queue item type: {type(item)}")
-
-    def _download_part(
-        self,
-        download_url: str,
-        headers: dict,
-        start: int,
-        q: queue.Queue,
-        download_has_error: threading.Event,
-    ):
-        # Other threads has error, no need to start
-        if download_has_error.is_set():
-            return
-        response = self._session.get(url=download_url, headers=headers, stream=True)
-
-        file_offset = start
-        for content in response.iter_content(chunk_size=_HTTP_RES_CHUNK_SIZE_BYTES):
-            if download_has_error.is_set():
-                return
-            q.put(_ChunkContent(offset=file_offset, data=content))
-            file_offset += len(content)
-
-    def _multipart_file_download(
-        self,
-        executor: concurrent.futures.Executor,
-        download_url: str,
-        file_size_bytes: int,
-        cache_open: Opener,
-    ):
-        """Download file as multiple parts in parallel.
-
-        Only one thread for writing to file. Each part run one http request in one thread.
-        HTTP response chunk of a file part is sent to the writer thread via a queue.
-        """
-        q: queue.Queue[_ChunkContent | object] = queue.Queue(maxsize=500)
-        download_has_error = threading.Event()
-
-        # Put cache_open at top so we remove the tmp file when there is network error.
-        with cache_open("wb") as f:
-            # Start writer thread first.
-            write_handler = functools.partial(
-                self._write_chunks_to_file, f, q, download_has_error
-            )
-            write_future = executor.submit(write_handler)
-
-            # Start download threads for each part.
-            download_futures: deque[concurrent.futures.Future] = deque()
-            part_size = _DOWNLOAD_PART_SIZE_BYTES
-            num_parts = int(math.ceil(file_size_bytes / float(part_size)))
-            for i in range(num_parts):
-                # https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Range
-                # Start and end are both inclusive, empty end means use the actual end of the file.
-                start = i * part_size
-                bytes_range = f"bytes={start}-"
-                if i != (num_parts - 1):
-                    # bytes=0-499
-                    bytes_range += f"{start + part_size - 1}"
-                headers = {"Range": bytes_range}
-                download_handler = functools.partial(
-                    self._download_part,
-                    download_url,
-                    headers,
-                    start,
-                    q,
-                    download_has_error,
-                )
-                download_futures.append(executor.submit(download_handler))
-
-            # Wait for download
-            done, not_done = concurrent.futures.wait(
-                download_futures, return_when=concurrent.futures.FIRST_EXCEPTION
-            )
-            try:
-                for fut in done:
-                    fut.result()
-            except Exception as e:
-                if env.is_debug():
-                    logger.debug(f"Error downloading file: {e}")
-                download_has_error.set()
-                raise
-            finally:
-                # Always signal the writer to stop
-                q.put(_CHUNK_QUEUE_SENTINEL)
-                write_future.result()
-
     def store_reference(
         self,
         artifact: Artifact,
@@ -304,7 +161,7 @@ class WandbStoragePolicy(StoragePolicy):
         name: str | None = None,
         checksum: bool = True,
         max_objects: int | None = None,
-    ) ->
+    ) -> list[ArtifactManifestEntry]:
         return self._handler.store_path(
             artifact, path, name=name, checksum=checksum, max_objects=max_objects
        )
@@ -324,13 +181,16 @@ class WandbStoragePolicy(StoragePolicy):
     def _file_url(
         self,
         api: InternalApi,
-        entity_name: str,
-        project_name: str,
-        artifact_name: str,
+        artifact: Artifact,
         entry: ArtifactManifestEntry,
     ) -> str:
         layout = self._config.get("storageLayout", StorageLayout.V1)
         region = self._config.get("storageRegion", "default")
+
+        entity_name = artifact.entity
+        project_name = artifact.project
+        artifact_name = artifact.name.split(":")[0]
+
         md5_hex = b64_to_hex_id(entry.digest)
 
         base_url: str = api.settings("base_url")
@@ -357,30 +217,21 @@ class WandbStoragePolicy(StoragePolicy):
         multipart_urls: dict[int, str],
         extra_headers: dict[str, str],
     ) -> list[dict[str, Any]]:
-        etags =
-
-
-
-
-            data
-
-
-
-
-
-
-
-
-                    "content-type": extra_headers.get("Content-Type", ""),
-                },
-            )
-            assert upload_resp is not None
-            etags.append(
-                {"partNumber": part_number, "hexMD5": upload_resp.headers["ETag"]}
-            )
-            part_number += 1
-        return etags
+        etags: deque[dict[str, Any]] = deque()
+        file_chunks = scan_chunks(file_path, chunk_size)
+        for num, data in enumerate(file_chunks, start=1):
+            rsp = self._api.upload_multipart_file_chunk_retry(
+                multipart_urls[num],
+                data,
+                extra_headers={
+                    "content-md5": hex_to_b64_id(hex_digests[num]),
+                    "content-length": str(len(data)),
+                    "content-type": extra_headers.get("Content-Type") or "",
+                },
+            )
+            assert rsp is not None
+            etags.append({"partNumber": num, "hexMD5": rsp.headers["ETag"]})
+        return list(etags)
 
     def default_file_upload(
         self,
@@ -393,20 +244,9 @@ class WandbStoragePolicy(StoragePolicy):
         with open(file_path, "rb") as file:
             # This fails if we don't send the first byte before the signed URL expires.
             self._api.upload_file_retry(
-                upload_url,
-                file,
-                progress_callback,
-                extra_headers=extra_headers,
+                upload_url, file, progress_callback, extra_headers=extra_headers
             )
 
-    def calc_chunk_size(self, file_size: int) -> int:
-        # Default to chunk size of 100MiB. S3 has cap of 10,000 upload parts.
-        # If file size exceeds the default chunk size, recalculate chunk size.
-        default_chunk_size = 100 * 1024**2
-        if default_chunk_size * S3_MAX_PART_NUMBERS < file_size:
-            return math.ceil(file_size / S3_MAX_PART_NUMBERS)
-        return default_chunk_size
-
     def store_file(
         self,
         artifact_id: str,
@@ -422,28 +262,20 @@ class WandbStoragePolicy(StoragePolicy):
             False if it needed to be uploaded or was a reference (nothing to dedupe).
         """
         file_size = entry.size or 0
-        chunk_size = self.calc_chunk_size(file_size)
-        upload_parts = []
-        hex_digests = {}
-        file_path = entry.local_path if entry.local_path is not None else ""
+        chunk_size = calc_part_size(file_size)
+        file_path = entry.local_path or ""
         # Logic for AWS s3 multipart upload.
         # Only chunk files if larger than 2 GiB. Currently can only support up to 5TiB.
-        if
-
-
-
-
-
-
-
-
-
-                hex_digest = hashlib.md5(data).hexdigest()
-                upload_parts.append(
-                    {"hexMD5": hex_digest, "partNumber": part_number}
-                )
-                hex_digests[part_number] = hex_digest
-                part_number += 1
+        if MIN_MULTI_UPLOAD_SIZE <= file_size <= MAX_MULTI_UPLOAD_SIZE:
+            file_chunks = scan_chunks(file_path, chunk_size)
+            upload_parts = [
+                {"partNumber": num, "hexMD5": hashlib.md5(data).hexdigest()}
+                for num, data in enumerate(file_chunks, start=1)
+            ]
+            hex_digests = dict(map(itemgetter("partNumber", "hexMD5"), upload_parts))
+        else:
+            upload_parts = []
+            hex_digests = {}
 
         resp = preparer.prepare(
             {
@@ -457,24 +289,21 @@ class WandbStoragePolicy(StoragePolicy):
 
         entry.birth_artifact_id = resp.birth_artifact_id
 
-        multipart_urls = resp.multipart_upload_urls
         if resp.upload_url is None:
             return True
         if entry.local_path is None:
             return False
-        extra_headers = {
-            header.split(":", 1)[0]: header.split(":", 1)[1]
-            for header in (resp.upload_headers or {})
-        }
+
+        extra_headers = dict(hdr.split(":", 1) for hdr in (resp.upload_headers or []))
 
         # This multipart upload isn't available, do a regular single url upload
-        if multipart_urls is None and resp.upload_url:
+        if (multipart_urls := resp.multipart_upload_urls) is None and resp.upload_url:
            self.default_file_upload(
                resp.upload_url, file_path, extra_headers, progress_callback
            )
+        elif multipart_urls is None:
+            raise ValueError(f"No multipart urls to upload for file: {file_path}")
         else:
-            if multipart_urls is None:
-                raise ValueError(f"No multipart urls to upload for file: {file_path}")
             # Upload files using s3 multipart upload urls
             etags = self.s3_multipart_file_upload(
                 file_path,
@@ -503,7 +332,7 @@ class WandbStoragePolicy(StoragePolicy):
 
         staging_dir = get_staging_dir()
         try:
-            if not entry.skip_cache and not hit:
+            if not (entry.skip_cache or hit):
                 with cache_open("wb") as f, open(entry.local_path, "rb") as src:
                     shutil.copyfileobj(src, f)
                 if entry.local_path.startswith(staging_dir):
wandb/sdk/artifacts/storage_policy.py
CHANGED
@@ -3,7 +3,8 @@
 from __future__ import annotations
 
 import concurrent.futures
-from
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any
 
 from wandb.sdk.internal.internal_api import Api as InternalApi
 from wandb.sdk.lib.paths import FilePathStr, URIStr
@@ -15,37 +16,47 @@ if TYPE_CHECKING:
     from wandb.sdk.internal.progress import ProgressFn
 
 
-class StoragePolicy:
+_POLICY_REGISTRY: dict[str, type[StoragePolicy]] = {}
+
+
+class StoragePolicy(ABC):
+    def __init_subclass__(cls, **kwargs: Any) -> None:
+        super().__init_subclass__(**kwargs)
+        _POLICY_REGISTRY[cls.name()] = cls
+
     @classmethod
     def lookup_by_name(cls, name: str) -> type[StoragePolicy]:
-
-
-        for sub in cls.__subclasses__():
-            if sub.name() == name:
-                return sub
-        raise NotImplementedError(f"Failed to find storage policy '{name}'")
+        if policy := _POLICY_REGISTRY.get(name):
+            return policy
+        raise ValueError(f"Failed to find storage policy {name!r}")
 
     @classmethod
+    @abstractmethod
     def name(cls) -> str:
         raise NotImplementedError
 
     @classmethod
-
+    @abstractmethod
+    def from_config(
+        cls, config: dict[str, Any], api: InternalApi | None = None
+    ) -> StoragePolicy:
         raise NotImplementedError
 
-
+    @abstractmethod
+    def config(self) -> dict[str, Any]:
         raise NotImplementedError
 
+    @abstractmethod
     def load_file(
         self,
         artifact: Artifact,
         manifest_entry: ArtifactManifestEntry,
         dest_path: str | None = None,
         executor: concurrent.futures.Executor | None = None,
-        multipart: bool | None = None,
     ) -> FilePathStr:
         raise NotImplementedError
 
+    @abstractmethod
     def store_file(
         self,
         artifact_id: str,
@@ -56,6 +67,7 @@ class StoragePolicy:
     ) -> bool:
         raise NotImplementedError
 
+    @abstractmethod
     def store_reference(
         self,
         artifact: Artifact,
@@ -63,9 +75,10 @@ class StoragePolicy:
         name: str | None = None,
         checksum: bool = True,
         max_objects: int | None = None,
-    ) ->
+    ) -> list[ArtifactManifestEntry]:
         raise NotImplementedError
 
+    @abstractmethod
     def load_reference(
         self,
         manifest_entry: ArtifactManifestEntry,
wandb/sdk/data_types/bokeh.py
CHANGED
@@ -5,6 +5,7 @@ import pathlib
 from typing import TYPE_CHECKING, Union
 
 from wandb import util
+from wandb._strutils import nameof
 from wandb.sdk.lib import runid
 
 from . import _dtypes
@@ -34,7 +35,10 @@ class Bokeh(Media):
         ],
     ):
         super().__init__()
-        bokeh = util.get_module(
+        bokeh = util.get_module(
+            "bokeh",
+            required=f"{nameof(Bokeh)!r} requires the bokeh package. Please install it with `pip install bokeh`.",
+        )
         if isinstance(data_or_path, (str, pathlib.Path)) and os.path.exists(
             data_or_path
         ):
wandb/sdk/data_types/image.py
CHANGED
@@ -161,6 +161,17 @@ class Image(BatchableMedia):
     ) -> None:
         """Initialize a `wandb.Image` object.
 
+        This class handles various image data formats and automatically normalizes
+        pixel values to the range [0, 255] when needed, ensuring compatibility
+        with the W&B backend.
+
+        * Data in range [0, 1] is multiplied by 255 and converted to uint8
+        * Data in range [-1, 1] is rescaled from [-1, 1] to [0, 255] by mapping
+          -1 to 0 and 1 to 255, then converted to uint8
+        * Data outside [-1, 1] but not in [0, 255] is clipped to [0, 255] and
+          converted to uint8 (with a warning if values fall outside [0, 255])
+        * Data already in [0, 255] is converted to uint8 without modification
+
         Args:
             data_or_path: Accepts NumPy array/pytorch tensor of image data,
                 a PIL image object, or a path to an image file. If a NumPy
@@ -168,7 +179,7 @@ class Image(BatchableMedia):
                 the image data will be saved to the given file type.
                 If the values are not in the range [0, 255] or all values are in the range [0, 1],
                 the image pixel values will be normalized to the range [0, 255]
-                unless `normalize` is set to False
+                unless `normalize` is set to `False`.
                 - pytorch tensor should be in the format (channel, height, width)
                 - NumPy array should be in the format (height, width, channel)
             mode: The PIL mode for an image. Most common are "L", "RGB",
@@ -178,13 +189,13 @@ class Image(BatchableMedia):
             classes: A list of class information for the image,
                 used for labeling bounding boxes, and image masks.
             boxes: A dictionary containing bounding box information for the image.
-                see
+                see https://docs.wandb.ai/ref/python/data-types/boundingboxes2d/
             masks: A dictionary containing mask information for the image.
-                see
+                see https://docs.wandb.ai/ref/python/data-types/imagemask/
             file_type: The file type to save the image as.
-                This parameter has no effect if data_or_path is a path to an image file.
-            normalize: If True
-                Normalize is only applied if data_or_path is a numpy array or pytorch tensor.
+                This parameter has no effect if `data_or_path` is a path to an image file.
+            normalize: If `True`, normalize the image pixel values to fall within the range of [0, 255].
+                Normalize is only applied if `data_or_path` is a numpy array or pytorch tensor.
 
         Examples:
             Create a wandb.Image from a numpy array
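
The normalization rules spelled out in the new `wandb.Image` docstring above map directly to array code; a short numpy sketch of the documented behavior (an illustration of the four cases, not wandb's internal conversion routine):

import warnings

import numpy as np


def normalize_to_uint8(data: np.ndarray) -> np.ndarray:
    lo, hi = float(data.min()), float(data.max())
    if 0 <= lo and hi <= 1:
        data = data * 255             # [0, 1]: scale up
    elif -1 <= lo and hi <= 1:
        data = (data + 1) * 127.5     # [-1, 1]: map -1 -> 0 and 1 -> 255
    elif lo < 0 or hi > 255:
        warnings.warn("Pixel values outside [0, 255] were clipped.")
        data = np.clip(data, 0, 255)  # out of range: clip, with a warning
    return data.astype(np.uint8)      # already [0, 255]: plain uint8 cast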
|