wandb 0.22.0__py3-none-macosx_12_0_arm64.whl → 0.22.2__py3-none-macosx_12_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114) hide show
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +8 -5
  3. wandb/_pydantic/__init__.py +12 -11
  4. wandb/_pydantic/base.py +49 -19
  5. wandb/apis/__init__.py +2 -0
  6. wandb/apis/attrs.py +2 -0
  7. wandb/apis/importers/internals/internal.py +16 -23
  8. wandb/apis/internal.py +2 -0
  9. wandb/apis/normalize.py +2 -0
  10. wandb/apis/public/__init__.py +3 -2
  11. wandb/apis/public/api.py +215 -164
  12. wandb/apis/public/artifacts.py +23 -20
  13. wandb/apis/public/const.py +2 -0
  14. wandb/apis/public/files.py +33 -24
  15. wandb/apis/public/history.py +2 -0
  16. wandb/apis/public/jobs.py +20 -18
  17. wandb/apis/public/projects.py +4 -2
  18. wandb/apis/public/query_generator.py +3 -0
  19. wandb/apis/public/registries/__init__.py +7 -0
  20. wandb/apis/public/registries/_freezable_list.py +9 -12
  21. wandb/apis/public/registries/registries_search.py +8 -6
  22. wandb/apis/public/registries/registry.py +22 -17
  23. wandb/apis/public/reports.py +2 -0
  24. wandb/apis/public/runs.py +261 -57
  25. wandb/apis/public/sweeps.py +10 -9
  26. wandb/apis/public/teams.py +2 -0
  27. wandb/apis/public/users.py +2 -0
  28. wandb/apis/public/utils.py +16 -15
  29. wandb/automations/_generated/__init__.py +54 -127
  30. wandb/automations/_generated/create_generic_webhook_integration.py +1 -7
  31. wandb/automations/_generated/fragments.py +26 -91
  32. wandb/bin/gpu_stats +0 -0
  33. wandb/bin/wandb-core +0 -0
  34. wandb/cli/beta.py +16 -2
  35. wandb/cli/beta_leet.py +74 -0
  36. wandb/cli/beta_sync.py +9 -11
  37. wandb/cli/cli.py +34 -7
  38. wandb/errors/errors.py +3 -3
  39. wandb/proto/v3/wandb_api_pb2.py +86 -0
  40. wandb/proto/v3/wandb_internal_pb2.py +352 -351
  41. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  42. wandb/proto/v3/wandb_sync_pb2.py +19 -6
  43. wandb/proto/v4/wandb_api_pb2.py +37 -0
  44. wandb/proto/v4/wandb_internal_pb2.py +352 -351
  45. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  46. wandb/proto/v4/wandb_sync_pb2.py +10 -6
  47. wandb/proto/v5/wandb_api_pb2.py +38 -0
  48. wandb/proto/v5/wandb_internal_pb2.py +352 -351
  49. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  50. wandb/proto/v5/wandb_sync_pb2.py +10 -6
  51. wandb/proto/v6/wandb_api_pb2.py +48 -0
  52. wandb/proto/v6/wandb_internal_pb2.py +352 -351
  53. wandb/proto/v6/wandb_settings_pb2.py +2 -2
  54. wandb/proto/v6/wandb_sync_pb2.py +10 -6
  55. wandb/proto/wandb_api_pb2.py +18 -0
  56. wandb/proto/wandb_generate_proto.py +1 -0
  57. wandb/sdk/artifacts/_factories.py +7 -2
  58. wandb/sdk/artifacts/_generated/__init__.py +112 -412
  59. wandb/sdk/artifacts/_generated/fragments.py +65 -0
  60. wandb/sdk/artifacts/_generated/operations.py +52 -22
  61. wandb/sdk/artifacts/_generated/run_input_artifacts.py +3 -23
  62. wandb/sdk/artifacts/_generated/run_output_artifacts.py +3 -23
  63. wandb/sdk/artifacts/_generated/type_info.py +19 -0
  64. wandb/sdk/artifacts/_gqlutils.py +47 -0
  65. wandb/sdk/artifacts/_models/__init__.py +4 -0
  66. wandb/sdk/artifacts/_models/base_model.py +20 -0
  67. wandb/sdk/artifacts/_validators.py +40 -12
  68. wandb/sdk/artifacts/artifact.py +99 -118
  69. wandb/sdk/artifacts/artifact_file_cache.py +6 -1
  70. wandb/sdk/artifacts/artifact_manifest_entry.py +67 -14
  71. wandb/sdk/artifacts/storage_handler.py +18 -12
  72. wandb/sdk/artifacts/storage_handlers/azure_handler.py +11 -6
  73. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +9 -6
  74. wandb/sdk/artifacts/storage_handlers/http_handler.py +9 -4
  75. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +10 -6
  76. wandb/sdk/artifacts/storage_handlers/multi_handler.py +5 -4
  77. wandb/sdk/artifacts/storage_handlers/s3_handler.py +10 -8
  78. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +6 -4
  79. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +24 -21
  80. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +4 -2
  81. wandb/sdk/artifacts/storage_policies/_multipart.py +187 -0
  82. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +71 -242
  83. wandb/sdk/artifacts/storage_policy.py +25 -12
  84. wandb/sdk/data_types/bokeh.py +5 -1
  85. wandb/sdk/data_types/image.py +17 -6
  86. wandb/sdk/data_types/object_3d.py +67 -2
  87. wandb/sdk/interface/interface.py +31 -4
  88. wandb/sdk/interface/interface_queue.py +10 -0
  89. wandb/sdk/interface/interface_shared.py +0 -7
  90. wandb/sdk/interface/interface_sock.py +9 -3
  91. wandb/sdk/internal/_generated/__init__.py +2 -12
  92. wandb/sdk/internal/job_builder.py +27 -10
  93. wandb/sdk/internal/sender.py +5 -2
  94. wandb/sdk/internal/settings_static.py +2 -82
  95. wandb/sdk/launch/create_job.py +2 -1
  96. wandb/sdk/launch/runner/kubernetes_runner.py +25 -20
  97. wandb/sdk/launch/utils.py +82 -1
  98. wandb/sdk/lib/progress.py +8 -74
  99. wandb/sdk/lib/service/service_client.py +5 -9
  100. wandb/sdk/lib/service/service_connection.py +39 -23
  101. wandb/sdk/mailbox/mailbox_handle.py +2 -0
  102. wandb/sdk/projects/_generated/__init__.py +12 -33
  103. wandb/sdk/wandb_init.py +23 -3
  104. wandb/sdk/wandb_login.py +53 -27
  105. wandb/sdk/wandb_run.py +10 -5
  106. wandb/sdk/wandb_settings.py +63 -25
  107. wandb/sync/sync.py +7 -2
  108. wandb/util.py +1 -1
  109. {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/METADATA +1 -1
  110. {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/RECORD +113 -103
  111. wandb/sdk/artifacts/_graphql_fragments.py +0 -19
  112. {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/WHEEL +0 -0
  113. {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/entry_points.txt +0 -0
  114. {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/licenses/LICENSE +0 -0
@@ -3,32 +3,35 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import concurrent.futures
6
- import functools
7
6
  import hashlib
8
7
  import logging
9
- import math
10
8
  import os
11
- import queue
12
9
  import shutil
13
- import threading
14
10
  from collections import deque
15
- from typing import IO, TYPE_CHECKING, Any, NamedTuple, Sequence
11
+ from operator import itemgetter
12
+ from typing import TYPE_CHECKING, Any
16
13
  from urllib.parse import quote
17
14
 
18
15
  import requests
19
16
 
20
- from wandb import env
21
17
  from wandb.errors.term import termwarn
22
18
  from wandb.proto.wandb_internal_pb2 import ServerFeature
23
19
  from wandb.sdk.artifacts.artifact_file_cache import (
24
20
  ArtifactFileCache,
25
- Opener,
26
21
  get_artifact_file_cache,
27
22
  )
28
23
  from wandb.sdk.artifacts.staging import get_staging_dir
29
24
  from wandb.sdk.artifacts.storage_handlers.multi_handler import MultiHandler
30
25
  from wandb.sdk.artifacts.storage_handlers.tracking_handler import TrackingHandler
31
26
  from wandb.sdk.artifacts.storage_layout import StorageLayout
27
+ from wandb.sdk.artifacts.storage_policies._multipart import (
28
+ MAX_MULTI_UPLOAD_SIZE,
29
+ MIN_MULTI_UPLOAD_SIZE,
30
+ KiB,
31
+ calc_part_size,
32
+ multipart_download,
33
+ scan_chunks,
34
+ )
32
35
  from wandb.sdk.artifacts.storage_policies.register import WANDB_STORAGE_POLICY
33
36
  from wandb.sdk.artifacts.storage_policy import StoragePolicy
34
37
  from wandb.sdk.internal.internal_api import Api as InternalApi
@@ -44,34 +47,9 @@ if TYPE_CHECKING:
44
47
  from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
45
48
  from wandb.sdk.internal import progress
46
49
 
47
-
48
- # AWS S3 max upload parts without having to make additional requests for extra parts
49
- S3_MAX_PART_NUMBERS = 1000
50
- S3_MIN_MULTI_UPLOAD_SIZE = 2 * 1024**3
51
- S3_MAX_MULTI_UPLOAD_SIZE = 5 * 1024**4
52
-
53
-
54
- # Minimum size to switch to multipart download, same as upload, 2GB.
55
- _MULTIPART_DOWNLOAD_SIZE = S3_MIN_MULTI_UPLOAD_SIZE
56
- # Multipart download part size is same as multpart upload size, which is hard coded to 100MB.
57
- # https://github.com/wandb/wandb/blob/7b2a13cb8efcd553317167b823c8e52d8c3f7c4e/core/pkg/artifacts/saver.go#L496
58
- # https://docs.aws.amazon.com/AmazonS3/latest/userguide/optimizing-performance-guidelines.html#optimizing-performance-guidelines-get-range
59
- _DOWNLOAD_PART_SIZE_BYTES = 100 * 1024 * 1024
60
- # Chunk size for reading http response and writing to disk. 1MB.
61
- _HTTP_RES_CHUNK_SIZE_BYTES = 1 * 1024 * 1024
62
- # Signal end of _ChunkQueue, consumer (file writer) should stop after getting this item.
63
- # NOTE: it should only be used for multithread executor, it does notwork for multiprocess executor.
64
- # multipart download is using the executor from artifact.download() which is a multithread executor.
65
- _CHUNK_QUEUE_SENTINEL = object()
66
-
67
50
  logger = logging.getLogger(__name__)
68
51
 
69
52
 
70
- class _ChunkContent(NamedTuple):
71
- offset: int
72
- data: bytes
73
-
74
-
75
53
  class WandbStoragePolicy(StoragePolicy):
76
54
  @classmethod
77
55
  def name(cls) -> str:
@@ -91,6 +69,8 @@ class WandbStoragePolicy(StoragePolicy):
91
69
  session: requests.Session | None = None,
92
70
  ) -> None:
93
71
  self._config = config or {}
72
+ if (storage_region := self._config.get("storageRegion")) is not None:
73
+ self._validate_storage_region(storage_region)
94
74
  self._cache = cache or get_artifact_file_cache()
95
75
  self._session = session or make_http_session()
96
76
  self._api = api or InternalApi()
@@ -99,7 +79,15 @@ class WandbStoragePolicy(StoragePolicy):
99
79
  default_handler=TrackingHandler(),
100
80
  )
101
81
 
102
- def config(self) -> dict:
82
+ def _validate_storage_region(self, storage_region: Any) -> None:
83
+ if not isinstance(storage_region, str):
84
+ raise TypeError(
85
+ f"storageRegion must be a string, got {type(storage_region).__name__}: {storage_region!r}"
86
+ )
87
+ if not storage_region.strip():
88
+ raise ValueError("storageRegion must be a non-empty string")
89
+
90
+ def config(self) -> dict[str, Any]:
103
91
  return self._config
104
92
 
105
93
  def load_file(
@@ -107,8 +95,9 @@ class WandbStoragePolicy(StoragePolicy):
107
95
  artifact: Artifact,
108
96
  manifest_entry: ArtifactManifestEntry,
109
97
  dest_path: str | None = None,
98
+ # FIXME: We should avoid passing the executor into multiple inner functions,
99
+ # it leads to confusing code and opaque tracebacks/call stacks.
110
100
  executor: concurrent.futures.Executor | None = None,
111
- multipart: bool | None = None,
112
101
  ) -> FilePathStr:
113
102
  """Use cache or download the file using signed url.
114
103
 
@@ -116,10 +105,8 @@ class WandbStoragePolicy(StoragePolicy):
116
105
  executor: Passed from caller, artifact has a thread pool for multi file download.
117
106
  Reuse the thread pool for multi part download. The thread pool is closed when
118
107
  artifact download is done.
119
- multipart: If set to `None` (default), the artifact will be downloaded
120
- in parallel using multipart download if individual file size is greater than
121
- 2GB. If set to `True` or `False`, the artifact will be downloaded in
122
- parallel or serially regardless of the file size.
108
+
109
+ If this is None, download the file serially.
123
110
  """
124
111
  if dest_path is not None:
125
112
  self._cache._override_cache_path = dest_path
@@ -131,14 +118,10 @@ class WandbStoragePolicy(StoragePolicy):
131
118
  if hit:
132
119
  return path
133
120
 
134
- if (url := manifest_entry._download_url) is not None:
121
+ if url := manifest_entry._download_url:
135
122
  # Use multipart parallel download for large file
136
- if (
137
- executor
138
- and (size := manifest_entry.size)
139
- and self._should_multipart_download(size, multipart)
140
- ):
141
- self._multipart_file_download(executor, url, size, cache_open)
123
+ if executor and (size := manifest_entry.size):
124
+ multipart_download(executor, self._session, url, size, cache_open)
142
125
  return path
143
126
 
144
127
  # Serial download
@@ -161,142 +144,16 @@ class WandbStoragePolicy(StoragePolicy):
161
144
  else:
162
145
  auth = ("api", self._api.api_key or "")
163
146
 
164
- file_url = self._file_url(
165
- self._api,
166
- artifact.entity,
167
- artifact.project,
168
- artifact.name.split(":")[0],
169
- manifest_entry,
170
- )
147
+ file_url = self._file_url(self._api, artifact, manifest_entry)
171
148
  response = self._session.get(
172
149
  file_url, auth=auth, cookies=cookies, headers=headers, stream=True
173
150
  )
174
151
 
175
152
  with cache_open(mode="wb") as file:
176
- for data in response.iter_content(chunk_size=16 * 1024):
153
+ for data in response.iter_content(chunk_size=16 * KiB):
177
154
  file.write(data)
178
155
  return path
179
156
 
180
- def _should_multipart_download(
181
- self,
182
- file_size: int,
183
- multipart: bool | None,
184
- ) -> bool:
185
- if multipart is not None:
186
- return multipart
187
- return file_size >= _MULTIPART_DOWNLOAD_SIZE
188
-
189
- def _write_chunks_to_file(
190
- self,
191
- f: IO,
192
- q: queue.Queue,
193
- download_has_error: threading.Event,
194
- ):
195
- while not download_has_error.is_set():
196
- item = q.get()
197
- if item is _CHUNK_QUEUE_SENTINEL:
198
- # Normal shutdown, all the chunks are written
199
- return
200
- elif isinstance(item, _ChunkContent):
201
- try:
202
- # NOTE: Seek works without pre allocating the file on disk.
203
- # It automatically creates a sparse file, e.g. ls -hl would show
204
- # a bigger size compared to du -sh * because downloading different
205
- # chunks is not a sequential write.
206
- # See https://man7.org/linux/man-pages/man2/lseek.2.html
207
- f.seek(item.offset)
208
- f.write(item.data)
209
- except Exception as e:
210
- if env.is_debug():
211
- logger.debug(f"Error writing chunk to file: {e}")
212
- download_has_error.set()
213
- raise
214
- else:
215
- raise ValueError(f"Unknown queue item type: {type(item)}")
216
-
217
- def _download_part(
218
- self,
219
- download_url: str,
220
- headers: dict,
221
- start: int,
222
- q: queue.Queue,
223
- download_has_error: threading.Event,
224
- ):
225
- # Other threads has error, no need to start
226
- if download_has_error.is_set():
227
- return
228
- response = self._session.get(url=download_url, headers=headers, stream=True)
229
-
230
- file_offset = start
231
- for content in response.iter_content(chunk_size=_HTTP_RES_CHUNK_SIZE_BYTES):
232
- if download_has_error.is_set():
233
- return
234
- q.put(_ChunkContent(offset=file_offset, data=content))
235
- file_offset += len(content)
236
-
237
- def _multipart_file_download(
238
- self,
239
- executor: concurrent.futures.Executor,
240
- download_url: str,
241
- file_size_bytes: int,
242
- cache_open: Opener,
243
- ):
244
- """Download file as multiple parts in parallel.
245
-
246
- Only one thread for writing to file. Each part run one http request in one thread.
247
- HTTP response chunk of a file part is sent to the writer thread via a queue.
248
- """
249
- q: queue.Queue[_ChunkContent | object] = queue.Queue(maxsize=500)
250
- download_has_error = threading.Event()
251
-
252
- # Put cache_open at top so we remove the tmp file when there is network error.
253
- with cache_open("wb") as f:
254
- # Start writer thread first.
255
- write_handler = functools.partial(
256
- self._write_chunks_to_file, f, q, download_has_error
257
- )
258
- write_future = executor.submit(write_handler)
259
-
260
- # Start download threads for each part.
261
- download_futures: deque[concurrent.futures.Future] = deque()
262
- part_size = _DOWNLOAD_PART_SIZE_BYTES
263
- num_parts = int(math.ceil(file_size_bytes / float(part_size)))
264
- for i in range(num_parts):
265
- # https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Range
266
- # Start and end are both inclusive, empty end means use the actual end of the file.
267
- start = i * part_size
268
- bytes_range = f"bytes={start}-"
269
- if i != (num_parts - 1):
270
- # bytes=0-499
271
- bytes_range += f"{start + part_size - 1}"
272
- headers = {"Range": bytes_range}
273
- download_handler = functools.partial(
274
- self._download_part,
275
- download_url,
276
- headers,
277
- start,
278
- q,
279
- download_has_error,
280
- )
281
- download_futures.append(executor.submit(download_handler))
282
-
283
- # Wait for download
284
- done, not_done = concurrent.futures.wait(
285
- download_futures, return_when=concurrent.futures.FIRST_EXCEPTION
286
- )
287
- try:
288
- for fut in done:
289
- fut.result()
290
- except Exception as e:
291
- if env.is_debug():
292
- logger.debug(f"Error downloading file: {e}")
293
- download_has_error.set()
294
- raise
295
- finally:
296
- # Always signal the writer to stop
297
- q.put(_CHUNK_QUEUE_SENTINEL)
298
- write_future.result()
299
-
300
157
  def store_reference(
301
158
  self,
302
159
  artifact: Artifact,
@@ -304,7 +161,7 @@ class WandbStoragePolicy(StoragePolicy):
304
161
  name: str | None = None,
305
162
  checksum: bool = True,
306
163
  max_objects: int | None = None,
307
- ) -> Sequence[ArtifactManifestEntry]:
164
+ ) -> list[ArtifactManifestEntry]:
308
165
  return self._handler.store_path(
309
166
  artifact, path, name=name, checksum=checksum, max_objects=max_objects
310
167
  )
@@ -324,13 +181,16 @@ class WandbStoragePolicy(StoragePolicy):
324
181
  def _file_url(
325
182
  self,
326
183
  api: InternalApi,
327
- entity_name: str,
328
- project_name: str,
329
- artifact_name: str,
184
+ artifact: Artifact,
330
185
  entry: ArtifactManifestEntry,
331
186
  ) -> str:
332
187
  layout = self._config.get("storageLayout", StorageLayout.V1)
333
188
  region = self._config.get("storageRegion", "default")
189
+
190
+ entity_name = artifact.entity
191
+ project_name = artifact.project
192
+ artifact_name = artifact.name.split(":")[0]
193
+
334
194
  md5_hex = b64_to_hex_id(entry.digest)
335
195
 
336
196
  base_url: str = api.settings("base_url")
@@ -357,30 +217,21 @@ class WandbStoragePolicy(StoragePolicy):
357
217
  multipart_urls: dict[int, str],
358
218
  extra_headers: dict[str, str],
359
219
  ) -> list[dict[str, Any]]:
360
- etags = []
361
- part_number = 1
362
-
363
- with open(file_path, "rb") as f:
364
- while True:
365
- data = f.read(chunk_size)
366
- if not data:
367
- break
368
- md5_b64_str = str(hex_to_b64_id(hex_digests[part_number]))
369
- upload_resp = self._api.upload_multipart_file_chunk_retry(
370
- multipart_urls[part_number],
371
- data,
372
- extra_headers={
373
- "content-md5": md5_b64_str,
374
- "content-length": str(len(data)),
375
- "content-type": extra_headers.get("Content-Type", ""),
376
- },
377
- )
378
- assert upload_resp is not None
379
- etags.append(
380
- {"partNumber": part_number, "hexMD5": upload_resp.headers["ETag"]}
381
- )
382
- part_number += 1
383
- return etags
220
+ etags: deque[dict[str, Any]] = deque()
221
+ file_chunks = scan_chunks(file_path, chunk_size)
222
+ for num, data in enumerate(file_chunks, start=1):
223
+ rsp = self._api.upload_multipart_file_chunk_retry(
224
+ multipart_urls[num],
225
+ data,
226
+ extra_headers={
227
+ "content-md5": hex_to_b64_id(hex_digests[num]),
228
+ "content-length": str(len(data)),
229
+ "content-type": extra_headers.get("Content-Type") or "",
230
+ },
231
+ )
232
+ assert rsp is not None
233
+ etags.append({"partNumber": num, "hexMD5": rsp.headers["ETag"]})
234
+ return list(etags)
384
235
 
385
236
  def default_file_upload(
386
237
  self,
@@ -393,20 +244,9 @@ class WandbStoragePolicy(StoragePolicy):
393
244
  with open(file_path, "rb") as file:
394
245
  # This fails if we don't send the first byte before the signed URL expires.
395
246
  self._api.upload_file_retry(
396
- upload_url,
397
- file,
398
- progress_callback,
399
- extra_headers=extra_headers,
247
+ upload_url, file, progress_callback, extra_headers=extra_headers
400
248
  )
401
249
 
402
- def calc_chunk_size(self, file_size: int) -> int:
403
- # Default to chunk size of 100MiB. S3 has cap of 10,000 upload parts.
404
- # If file size exceeds the default chunk size, recalculate chunk size.
405
- default_chunk_size = 100 * 1024**2
406
- if default_chunk_size * S3_MAX_PART_NUMBERS < file_size:
407
- return math.ceil(file_size / S3_MAX_PART_NUMBERS)
408
- return default_chunk_size
409
-
410
250
  def store_file(
411
251
  self,
412
252
  artifact_id: str,
@@ -422,28 +262,20 @@ class WandbStoragePolicy(StoragePolicy):
422
262
  False if it needed to be uploaded or was a reference (nothing to dedupe).
423
263
  """
424
264
  file_size = entry.size or 0
425
- chunk_size = self.calc_chunk_size(file_size)
426
- upload_parts = []
427
- hex_digests = {}
428
- file_path = entry.local_path if entry.local_path is not None else ""
265
+ chunk_size = calc_part_size(file_size)
266
+ file_path = entry.local_path or ""
429
267
  # Logic for AWS s3 multipart upload.
430
268
  # Only chunk files if larger than 2 GiB. Currently can only support up to 5TiB.
431
- if (
432
- file_size >= S3_MIN_MULTI_UPLOAD_SIZE
433
- and file_size <= S3_MAX_MULTI_UPLOAD_SIZE
434
- ):
435
- part_number = 1
436
- with open(file_path, "rb") as f:
437
- while True:
438
- data = f.read(chunk_size)
439
- if not data:
440
- break
441
- hex_digest = hashlib.md5(data).hexdigest()
442
- upload_parts.append(
443
- {"hexMD5": hex_digest, "partNumber": part_number}
444
- )
445
- hex_digests[part_number] = hex_digest
446
- part_number += 1
269
+ if MIN_MULTI_UPLOAD_SIZE <= file_size <= MAX_MULTI_UPLOAD_SIZE:
270
+ file_chunks = scan_chunks(file_path, chunk_size)
271
+ upload_parts = [
272
+ {"partNumber": num, "hexMD5": hashlib.md5(data).hexdigest()}
273
+ for num, data in enumerate(file_chunks, start=1)
274
+ ]
275
+ hex_digests = dict(map(itemgetter("partNumber", "hexMD5"), upload_parts))
276
+ else:
277
+ upload_parts = []
278
+ hex_digests = {}
447
279
 
448
280
  resp = preparer.prepare(
449
281
  {
@@ -457,24 +289,21 @@ class WandbStoragePolicy(StoragePolicy):
457
289
 
458
290
  entry.birth_artifact_id = resp.birth_artifact_id
459
291
 
460
- multipart_urls = resp.multipart_upload_urls
461
292
  if resp.upload_url is None:
462
293
  return True
463
294
  if entry.local_path is None:
464
295
  return False
465
- extra_headers = {
466
- header.split(":", 1)[0]: header.split(":", 1)[1]
467
- for header in (resp.upload_headers or {})
468
- }
296
+
297
+ extra_headers = dict(hdr.split(":", 1) for hdr in (resp.upload_headers or []))
469
298
 
470
299
  # This multipart upload isn't available, do a regular single url upload
471
- if multipart_urls is None and resp.upload_url:
300
+ if (multipart_urls := resp.multipart_upload_urls) is None and resp.upload_url:
472
301
  self.default_file_upload(
473
302
  resp.upload_url, file_path, extra_headers, progress_callback
474
303
  )
304
+ elif multipart_urls is None:
305
+ raise ValueError(f"No multipart urls to upload for file: {file_path}")
475
306
  else:
476
- if multipart_urls is None:
477
- raise ValueError(f"No multipart urls to upload for file: {file_path}")
478
307
  # Upload files using s3 multipart upload urls
479
308
  etags = self.s3_multipart_file_upload(
480
309
  file_path,
@@ -503,7 +332,7 @@ class WandbStoragePolicy(StoragePolicy):
503
332
 
504
333
  staging_dir = get_staging_dir()
505
334
  try:
506
- if not entry.skip_cache and not hit:
335
+ if not (entry.skip_cache or hit):
507
336
  with cache_open("wb") as f, open(entry.local_path, "rb") as src:
508
337
  shutil.copyfileobj(src, f)
509
338
  if entry.local_path.startswith(staging_dir):
@@ -3,7 +3,8 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import concurrent.futures
6
- from typing import TYPE_CHECKING, Sequence
6
+ from abc import ABC, abstractmethod
7
+ from typing import TYPE_CHECKING, Any
7
8
 
8
9
  from wandb.sdk.internal.internal_api import Api as InternalApi
9
10
  from wandb.sdk.lib.paths import FilePathStr, URIStr
@@ -15,37 +16,47 @@ if TYPE_CHECKING:
15
16
  from wandb.sdk.internal.progress import ProgressFn
16
17
 
17
18
 
18
- class StoragePolicy:
19
+ _POLICY_REGISTRY: dict[str, type[StoragePolicy]] = {}
20
+
21
+
22
+ class StoragePolicy(ABC):
23
+ def __init_subclass__(cls, **kwargs: Any) -> None:
24
+ super().__init_subclass__(**kwargs)
25
+ _POLICY_REGISTRY[cls.name()] = cls
26
+
19
27
  @classmethod
20
28
  def lookup_by_name(cls, name: str) -> type[StoragePolicy]:
21
- import wandb.sdk.artifacts.storage_policies # noqa: F401
22
-
23
- for sub in cls.__subclasses__():
24
- if sub.name() == name:
25
- return sub
26
- raise NotImplementedError(f"Failed to find storage policy '{name}'")
29
+ if policy := _POLICY_REGISTRY.get(name):
30
+ return policy
31
+ raise ValueError(f"Failed to find storage policy {name!r}")
27
32
 
28
33
  @classmethod
34
+ @abstractmethod
29
35
  def name(cls) -> str:
30
36
  raise NotImplementedError
31
37
 
32
38
  @classmethod
33
- def from_config(cls, config: dict, api: InternalApi | None = None) -> StoragePolicy:
39
+ @abstractmethod
40
+ def from_config(
41
+ cls, config: dict[str, Any], api: InternalApi | None = None
42
+ ) -> StoragePolicy:
34
43
  raise NotImplementedError
35
44
 
36
- def config(self) -> dict:
45
+ @abstractmethod
46
+ def config(self) -> dict[str, Any]:
37
47
  raise NotImplementedError
38
48
 
49
+ @abstractmethod
39
50
  def load_file(
40
51
  self,
41
52
  artifact: Artifact,
42
53
  manifest_entry: ArtifactManifestEntry,
43
54
  dest_path: str | None = None,
44
55
  executor: concurrent.futures.Executor | None = None,
45
- multipart: bool | None = None,
46
56
  ) -> FilePathStr:
47
57
  raise NotImplementedError
48
58
 
59
+ @abstractmethod
49
60
  def store_file(
50
61
  self,
51
62
  artifact_id: str,
@@ -56,6 +67,7 @@ class StoragePolicy:
56
67
  ) -> bool:
57
68
  raise NotImplementedError
58
69
 
70
+ @abstractmethod
59
71
  def store_reference(
60
72
  self,
61
73
  artifact: Artifact,
@@ -63,9 +75,10 @@ class StoragePolicy:
63
75
  name: str | None = None,
64
76
  checksum: bool = True,
65
77
  max_objects: int | None = None,
66
- ) -> Sequence[ArtifactManifestEntry]:
78
+ ) -> list[ArtifactManifestEntry]:
67
79
  raise NotImplementedError
68
80
 
81
+ @abstractmethod
69
82
  def load_reference(
70
83
  self,
71
84
  manifest_entry: ArtifactManifestEntry,
@@ -5,6 +5,7 @@ import pathlib
5
5
  from typing import TYPE_CHECKING, Union
6
6
 
7
7
  from wandb import util
8
+ from wandb._strutils import nameof
8
9
  from wandb.sdk.lib import runid
9
10
 
10
11
  from . import _dtypes
@@ -34,7 +35,10 @@ class Bokeh(Media):
34
35
  ],
35
36
  ):
36
37
  super().__init__()
37
- bokeh = util.get_module("bokeh", required=True)
38
+ bokeh = util.get_module(
39
+ "bokeh",
40
+ required=f"{nameof(Bokeh)!r} requires the bokeh package. Please install it with `pip install bokeh`.",
41
+ )
38
42
  if isinstance(data_or_path, (str, pathlib.Path)) and os.path.exists(
39
43
  data_or_path
40
44
  ):
@@ -161,6 +161,17 @@ class Image(BatchableMedia):
161
161
  ) -> None:
162
162
  """Initialize a `wandb.Image` object.
163
163
 
164
+ This class handles various image data formats and automatically normalizes
165
+ pixel values to the range [0, 255] when needed, ensuring compatibility
166
+ with the W&B backend.
167
+
168
+ * Data in range [0, 1] is multiplied by 255 and converted to uint8
169
+ * Data in range [-1, 1] is rescaled from [-1, 1] to [0, 255] by mapping
170
+ -1 to 0 and 1 to 255, then converted to uint8
171
+ * Data outside [-1, 1] but not in [0, 255] is clipped to [0, 255] and
172
+ converted to uint8 (with a warning if values fall outside [0, 255])
173
+ * Data already in [0, 255] is converted to uint8 without modification
174
+
164
175
  Args:
165
176
  data_or_path: Accepts NumPy array/pytorch tensor of image data,
166
177
  a PIL image object, or a path to an image file. If a NumPy
@@ -168,7 +179,7 @@ class Image(BatchableMedia):
168
179
  the image data will be saved to the given file type.
169
180
  If the values are not in the range [0, 255] or all values are in the range [0, 1],
170
181
  the image pixel values will be normalized to the range [0, 255]
171
- unless `normalize` is set to False.
182
+ unless `normalize` is set to `False`.
172
183
  - pytorch tensor should be in the format (channel, height, width)
173
184
  - NumPy array should be in the format (height, width, channel)
174
185
  mode: The PIL mode for an image. Most common are "L", "RGB",
@@ -178,13 +189,13 @@ class Image(BatchableMedia):
178
189
  classes: A list of class information for the image,
179
190
  used for labeling bounding boxes, and image masks.
180
191
  boxes: A dictionary containing bounding box information for the image.
181
- see: https://docs.wandb.ai/ref/python/data-types/boundingboxes2d/
192
+ see https://docs.wandb.ai/ref/python/data-types/boundingboxes2d/
182
193
  masks: A dictionary containing mask information for the image.
183
- see: https://docs.wandb.ai/ref/python/data-types/imagemask/
194
+ see https://docs.wandb.ai/ref/python/data-types/imagemask/
184
195
  file_type: The file type to save the image as.
185
- This parameter has no effect if data_or_path is a path to an image file.
186
- normalize: If True, normalize the image pixel values to fall within the range of [0, 255].
187
- Normalize is only applied if data_or_path is a numpy array or pytorch tensor.
196
+ This parameter has no effect if `data_or_path` is a path to an image file.
197
+ normalize: If `True`, normalize the image pixel values to fall within the range of [0, 255].
198
+ Normalize is only applied if `data_or_path` is a numpy array or pytorch tensor.
188
199
 
189
200
  Examples:
190
201
  Create a wandb.Image from a numpy array