wandb 0.22.0__py3-none-win_amd64.whl → 0.22.1__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +3 -3
  3. wandb/_pydantic/__init__.py +12 -11
  4. wandb/_pydantic/base.py +49 -19
  5. wandb/apis/__init__.py +2 -0
  6. wandb/apis/attrs.py +2 -0
  7. wandb/apis/importers/internals/internal.py +16 -23
  8. wandb/apis/internal.py +2 -0
  9. wandb/apis/normalize.py +2 -0
  10. wandb/apis/public/__init__.py +3 -2
  11. wandb/apis/public/api.py +215 -164
  12. wandb/apis/public/artifacts.py +23 -20
  13. wandb/apis/public/const.py +2 -0
  14. wandb/apis/public/files.py +33 -24
  15. wandb/apis/public/history.py +2 -0
  16. wandb/apis/public/jobs.py +20 -18
  17. wandb/apis/public/projects.py +4 -2
  18. wandb/apis/public/query_generator.py +3 -0
  19. wandb/apis/public/registries/__init__.py +7 -0
  20. wandb/apis/public/registries/_freezable_list.py +9 -12
  21. wandb/apis/public/registries/registries_search.py +8 -6
  22. wandb/apis/public/registries/registry.py +22 -17
  23. wandb/apis/public/reports.py +2 -0
  24. wandb/apis/public/runs.py +261 -57
  25. wandb/apis/public/sweeps.py +10 -9
  26. wandb/apis/public/teams.py +2 -0
  27. wandb/apis/public/users.py +2 -0
  28. wandb/apis/public/utils.py +16 -15
  29. wandb/automations/_generated/__init__.py +54 -127
  30. wandb/automations/_generated/create_generic_webhook_integration.py +1 -7
  31. wandb/automations/_generated/fragments.py +26 -91
  32. wandb/bin/gpu_stats.exe +0 -0
  33. wandb/bin/wandb-core +0 -0
  34. wandb/cli/beta_sync.py +9 -11
  35. wandb/errors/errors.py +3 -3
  36. wandb/proto/v3/wandb_sync_pb2.py +19 -6
  37. wandb/proto/v4/wandb_sync_pb2.py +10 -6
  38. wandb/proto/v5/wandb_sync_pb2.py +10 -6
  39. wandb/proto/v6/wandb_sync_pb2.py +10 -6
  40. wandb/sdk/artifacts/_factories.py +7 -2
  41. wandb/sdk/artifacts/_generated/__init__.py +112 -412
  42. wandb/sdk/artifacts/_generated/fragments.py +65 -0
  43. wandb/sdk/artifacts/_generated/operations.py +52 -22
  44. wandb/sdk/artifacts/_generated/run_input_artifacts.py +3 -23
  45. wandb/sdk/artifacts/_generated/run_output_artifacts.py +3 -23
  46. wandb/sdk/artifacts/_generated/type_info.py +19 -0
  47. wandb/sdk/artifacts/_gqlutils.py +47 -0
  48. wandb/sdk/artifacts/_models/__init__.py +4 -0
  49. wandb/sdk/artifacts/_models/base_model.py +20 -0
  50. wandb/sdk/artifacts/_validators.py +40 -12
  51. wandb/sdk/artifacts/artifact.py +69 -88
  52. wandb/sdk/artifacts/artifact_file_cache.py +6 -1
  53. wandb/sdk/artifacts/artifact_manifest_entry.py +61 -2
  54. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +10 -0
  55. wandb/sdk/data_types/bokeh.py +5 -1
  56. wandb/sdk/data_types/image.py +17 -6
  57. wandb/sdk/interface/interface.py +31 -4
  58. wandb/sdk/interface/interface_queue.py +10 -0
  59. wandb/sdk/interface/interface_shared.py +0 -7
  60. wandb/sdk/interface/interface_sock.py +9 -3
  61. wandb/sdk/internal/_generated/__init__.py +2 -12
  62. wandb/sdk/internal/sender.py +1 -1
  63. wandb/sdk/internal/settings_static.py +2 -82
  64. wandb/sdk/launch/runner/kubernetes_runner.py +25 -20
  65. wandb/sdk/launch/utils.py +82 -1
  66. wandb/sdk/lib/progress.py +7 -4
  67. wandb/sdk/lib/service/service_client.py +5 -9
  68. wandb/sdk/lib/service/service_connection.py +39 -23
  69. wandb/sdk/mailbox/mailbox_handle.py +2 -0
  70. wandb/sdk/projects/_generated/__init__.py +12 -33
  71. wandb/sdk/wandb_init.py +22 -2
  72. wandb/sdk/wandb_login.py +53 -27
  73. wandb/sdk/wandb_run.py +5 -3
  74. wandb/sdk/wandb_settings.py +50 -13
  75. wandb/sync/sync.py +7 -2
  76. wandb/util.py +1 -1
  77. {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/METADATA +1 -1
  78. {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/RECORD +81 -78
  79. wandb/sdk/artifacts/_graphql_fragments.py +0 -19
  80. {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/WHEEL +0 -0
  81. {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/entry_points.txt +0 -0
  82. {wandb-0.22.0.dist-info → wandb-0.22.1.dist-info}/licenses/LICENSE +0 -0
@@ -16,7 +16,7 @@ import tempfile
16
16
  import time
17
17
  from collections import deque
18
18
  from copy import copy
19
- from dataclasses import dataclass
19
+ from dataclasses import asdict, dataclass, replace
20
20
  from datetime import timedelta
21
21
  from functools import partial
22
22
  from itertools import filterfalse
@@ -30,7 +30,6 @@ from typing import (
30
30
  Literal,
31
31
  Sequence,
32
32
  Type,
33
- cast,
34
33
  final,
35
34
  )
36
35
  from urllib.parse import quote, urljoin, urlparse
@@ -109,10 +108,11 @@ from ._generated import (
109
108
  TagInput,
110
109
  UpdateArtifact,
111
110
  )
112
- from ._graphql_fragments import omit_artifact_fields
111
+ from ._gqlutils import omit_artifact_fields, supports_enable_tracking_var, type_info
113
112
  from ._validators import (
114
113
  LINKED_ARTIFACT_COLLECTION_TYPE,
115
114
  ArtifactPath,
115
+ FullArtifactPath,
116
116
  _LinkArtifactFields,
117
117
  ensure_logged,
118
118
  ensure_not_finalized,
@@ -210,6 +210,7 @@ class Artifact:
210
210
  metadata: dict[str, Any] | None = None,
211
211
  incremental: bool = False,
212
212
  use_as: str | None = None,
213
+ storage_region: str | None = None,
213
214
  ) -> None:
214
215
  if not re.match(r"^[a-zA-Z0-9_\-.]+$", name):
215
216
  raise ValueError(
@@ -268,7 +269,7 @@ class Artifact:
268
269
  self._use_as: str | None = None
269
270
  self._state: ArtifactState = ArtifactState.PENDING
270
271
  self._manifest: ArtifactManifest | _DeferredArtifactManifest | None = (
271
- ArtifactManifestV1(storage_policy=make_storage_policy())
272
+ ArtifactManifestV1(storage_policy=make_storage_policy(storage_region))
272
273
  )
273
274
  self._commit_hash: str | None = None
274
275
  self._file_count: int | None = None
@@ -286,36 +287,33 @@ class Artifact:
286
287
 
287
288
  @classmethod
288
289
  def _from_id(cls, artifact_id: str, client: RetryingClient) -> Artifact | None:
289
- if (artifact := artifact_instance_cache.get(artifact_id)) is not None:
290
- return artifact
290
+ if cached_artifact := artifact_instance_cache.get(artifact_id):
291
+ return cached_artifact
291
292
 
292
- query = gql_compat(ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields())
293
+ query = gql_compat(ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields(client))
293
294
 
294
295
  data = client.execute(query, variable_values={"id": artifact_id})
295
296
  result = ArtifactByID.model_validate(data)
296
297
 
297
- if (art := result.artifact) is None:
298
+ if (artifact := result.artifact) is None:
298
299
  return None
299
300
 
300
- src_collection = art.artifact_sequence
301
+ src_collection = artifact.artifact_sequence
301
302
  src_project = src_collection.project
302
303
 
303
304
  entity_name = src_project.entity_name if src_project else ""
304
305
  project_name = src_project.name if src_project else ""
305
306
 
306
- name = f"{src_collection.name}:v{art.version_index}"
307
- return cls._from_attrs(entity_name, project_name, name, art, client)
307
+ name = f"{src_collection.name}:v{artifact.version_index}"
308
+
309
+ path = FullArtifactPath(prefix=entity_name, project=project_name, name=name)
310
+ return cls._from_attrs(path, artifact, client)
308
311
 
309
312
  @classmethod
310
313
  def _membership_from_name(
311
- cls,
312
- *,
313
- entity: str,
314
- project: str,
315
- name: str,
316
- client: RetryingClient,
314
+ cls, *, path: FullArtifactPath, client: RetryingClient
317
315
  ) -> Artifact:
318
- if not (api := InternalApi())._server_supports(
316
+ if not InternalApi()._server_supports(
319
317
  pb.ServerFeature.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP
320
318
  ):
321
319
  raise UnsupportedError(
@@ -325,74 +323,70 @@ class Artifact:
325
323
 
326
324
  query = gql_compat(
327
325
  ARTIFACT_VIA_MEMBERSHIP_BY_NAME_GQL,
328
- omit_fields=omit_artifact_fields(api=api),
326
+ omit_fields=omit_artifact_fields(client),
329
327
  )
330
-
331
- gql_vars = {"entityName": entity, "projectName": project, "name": name}
328
+ gql_vars = {
329
+ "entityName": path.prefix,
330
+ "projectName": path.project,
331
+ "name": path.name,
332
+ }
332
333
  data = client.execute(query, variable_values=gql_vars)
333
334
  result = ArtifactViaMembershipByName.model_validate(data)
334
335
 
335
- if not (project_attrs := result.project):
336
- raise ValueError(f"project {project!r} not found under entity {entity!r}")
337
-
338
- if not (acm_attrs := project_attrs.artifact_collection_membership):
339
- entity_project = f"{entity}/{project}"
336
+ if not (project := result.project):
340
337
  raise ValueError(
341
- f"artifact membership {name!r} not found in {entity_project!r}"
338
+ f"project {path.project!r} not found under entity {path.prefix!r}"
342
339
  )
343
-
344
- target_path = ArtifactPath(prefix=entity, project=project, name=name)
345
- return cls._from_membership(acm_attrs, target=target_path, client=client)
340
+ if not (membership := project.artifact_collection_membership):
341
+ entity_project = f"{path.prefix}/{path.project}"
342
+ raise ValueError(
343
+ f"artifact membership {path.name!r} not found in {entity_project!r}"
344
+ )
345
+ return cls._from_membership(membership, target=path, client=client)
346
346
 
347
347
  @classmethod
348
348
  def _from_name(
349
349
  cls,
350
350
  *,
351
- entity: str,
352
- project: str,
353
- name: str,
351
+ path: FullArtifactPath,
354
352
  client: RetryingClient,
355
353
  enable_tracking: bool = False,
356
354
  ) -> Artifact:
357
- if (api := InternalApi())._server_supports(
355
+ if InternalApi()._server_supports(
358
356
  pb.ServerFeature.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP
359
357
  ):
360
- return cls._membership_from_name(
361
- entity=entity, project=project, name=name, client=client
362
- )
363
-
364
- supports_enable_tracking_gql_var = api.server_project_type_introspection()
365
- omit_vars = None if supports_enable_tracking_gql_var else {"enableTracking"}
358
+ return cls._membership_from_name(path=path, client=client)
366
359
 
360
+ omit_vars = None if supports_enable_tracking_var(client) else {"enableTracking"}
367
361
  gql_vars = {
368
- "entityName": entity,
369
- "projectName": project,
370
- "name": name,
362
+ "entityName": path.prefix,
363
+ "projectName": path.project,
364
+ "name": path.name,
371
365
  "enableTracking": enable_tracking,
372
366
  }
373
367
  query = gql_compat(
374
368
  ARTIFACT_BY_NAME_GQL,
375
369
  omit_variables=omit_vars,
376
- omit_fields=omit_artifact_fields(api=api),
370
+ omit_fields=omit_artifact_fields(client),
377
371
  )
378
-
379
372
  data = client.execute(query, variable_values=gql_vars)
380
373
  result = ArtifactByName.model_validate(data)
381
374
 
382
- if not (proj_attrs := result.project):
383
- raise ValueError(f"project {project!r} not found under entity {entity!r}")
384
-
385
- if not (art_attrs := proj_attrs.artifact):
386
- entity_project = f"{entity}/{project}"
387
- raise ValueError(f"artifact {name!r} not found in {entity_project!r}")
375
+ if not (project := result.project):
376
+ raise ValueError(
377
+ f"project {path.project!r} not found under entity {path.prefix!r}"
378
+ )
379
+ if not (artifact := project.artifact):
380
+ entity_project = f"{path.prefix}/{path.project}"
381
+ raise ValueError(f"artifact {path.name!r} not found in {entity_project!r}")
388
382
 
389
- return cls._from_attrs(entity, project, name, art_attrs, client)
383
+ return cls._from_attrs(path, artifact, client)
390
384
 
391
385
  @classmethod
392
386
  def _from_membership(
393
387
  cls,
394
388
  membership: MembershipWithArtifact,
395
- target: ArtifactPath,
389
+ target: FullArtifactPath,
396
390
  client: RetryingClient,
397
391
  ) -> Artifact:
398
392
  if not (
@@ -410,35 +404,31 @@ class Artifact:
410
404
  f"Your request was redirected to the corresponding artifact {name!r} in the new registry. "
411
405
  f"Please update your paths to point to the migrated registry directly, '{proj.name}/{name}'."
412
406
  )
413
- new_entity, new_project = proj.entity_name, proj.name
407
+ new_target = replace(target, prefix=proj.entity_name, project=proj.name)
414
408
  else:
415
- new_entity = cast(str, target.prefix)
416
- new_project = cast(str, target.project)
409
+ new_target = copy(target)
417
410
 
418
411
  if not (artifact := membership.artifact):
419
412
  raise ValueError(f"Artifact {target.to_str()!r} not found in response")
420
413
 
421
- return cls._from_attrs(new_entity, new_project, target.name, artifact, client)
414
+ return cls._from_attrs(new_target, artifact, client)
422
415
 
423
416
  @classmethod
424
417
  def _from_attrs(
425
418
  cls,
426
- entity: str,
427
- project: str,
428
- name: str,
429
- attrs: dict[str, Any] | ArtifactFragment,
419
+ path: FullArtifactPath,
420
+ attrs: ArtifactFragment,
430
421
  client: RetryingClient,
431
422
  aliases: list[str] | None = None,
432
423
  ) -> Artifact:
433
424
  # Placeholder is required to skip validation.
434
425
  artifact = cls("placeholder", type="placeholder")
435
426
  artifact._client = client
436
- artifact._entity = entity
437
- artifact._project = project
438
- artifact._name = name
427
+ artifact._entity = path.prefix
428
+ artifact._project = path.project
429
+ artifact._name = path.name
439
430
 
440
- validated_attrs = ArtifactFragment.model_validate(attrs)
441
- artifact._assign_attrs(validated_attrs, aliases)
431
+ artifact._assign_attrs(attrs, aliases)
442
432
 
443
433
  artifact.finalize()
444
434
 
@@ -743,13 +733,12 @@ class Artifact:
743
733
  raise ValueError("Client is not initialized")
744
734
 
745
735
  try:
746
- artifact = self._from_name(
747
- entity=self.source_entity,
736
+ path = FullArtifactPath(
737
+ prefix=self.source_entity,
748
738
  project=self.source_project,
749
739
  name=self.source_name,
750
- client=self._client,
751
740
  )
752
- self._source_artifact = artifact
741
+ self._source_artifact = self._from_name(path=path, client=self._client)
753
742
  except Exception as e:
754
743
  raise ValueError(
755
744
  f"Unable to fetch source artifact for linked artifact {self.name}"
@@ -1234,8 +1223,9 @@ class Artifact:
1234
1223
  def _populate_after_save(self, artifact_id: str) -> None:
1235
1224
  assert self._client is not None
1236
1225
 
1237
- query = gql_compat(ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields())
1238
-
1226
+ query = gql_compat(
1227
+ ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields(self._client)
1228
+ )
1239
1229
  data = self._client.execute(query, variable_values={"id": artifact_id})
1240
1230
  result = ArtifactByID.model_validate(data)
1241
1231
 
@@ -1258,21 +1248,9 @@ class Artifact:
1258
1248
  collection = self.name.split(":")[0]
1259
1249
 
1260
1250
  aliases = None
1261
- introspect_query = gql(
1262
- """
1263
- query ProbeServerAddAliasesInput {
1264
- AddAliasesInputInfoType: __type(name: "AddAliasesInput") {
1265
- name
1266
- inputFields {
1267
- name
1268
- }
1269
- }
1270
- }
1271
- """
1272
- )
1273
1251
 
1274
- data = self._client.execute(introspect_query)
1275
- if data.get("AddAliasesInputInfoType"): # wandb backend version >= 0.13.0
1252
+ if type_info(self._client, "AddAliasesInput") is not None:
1253
+ # wandb backend version >= 0.13.0
1276
1254
  alias_props = {
1277
1255
  "entity_name": entity,
1278
1256
  "project_name": project,
@@ -1330,7 +1308,7 @@ class Artifact:
1330
1308
  for alias in self.aliases
1331
1309
  ]
1332
1310
 
1333
- omit_fields = omit_artifact_fields()
1311
+ omit_fields = omit_artifact_fields(self._client)
1334
1312
  omit_variables = set()
1335
1313
 
1336
1314
  if {"ttlIsInherited", "ttlDurationSeconds"} & omit_fields:
@@ -2427,7 +2405,7 @@ class Artifact:
2427
2405
 
2428
2406
  # Parse the entity (first part of the path) appropriately,
2429
2407
  # depending on whether we're linking to a registry
2430
- if target.project and is_artifact_registry_project(target.project):
2408
+ if target.is_registry_path():
2431
2409
  # In a Registry linking, the entity is used to fetch the organization of the artifact
2432
2410
  # therefore the source artifact's entity is passed to the backend
2433
2411
  org = target.prefix or settings.get("organization") or ""
@@ -2435,6 +2413,9 @@ class Artifact:
2435
2413
  else:
2436
2414
  target = target.with_defaults(prefix=self.source_entity)
2437
2415
 
2416
+ # Explicitly convert to FullArtifactPath to ensure all fields are present
2417
+ target = FullArtifactPath(**asdict(target))
2418
+
2438
2419
  # Prepare the validated GQL input, send it
2439
2420
  alias_inputs = [
2440
2421
  ArtifactAliasInput(artifact_collection_name=target.name, alias=a)
@@ -26,6 +26,11 @@ class Opener(Protocol):
26
26
  pass
27
27
 
28
28
 
29
+ def artifacts_cache_dir() -> Path:
30
+ """Get the artifacts cache directory."""
31
+ return env.get_cache_dir() / "artifacts"
32
+
33
+
29
34
  def _get_sys_umask_threadsafe() -> int:
30
35
  # Workaround to get the current system umask, since
31
36
  # - `os.umask()` isn't thread-safe
@@ -248,4 +253,4 @@ def _build_artifact_file_cache(cache_dir: StrPath) -> ArtifactFileCache:
248
253
 
249
254
 
250
255
  def get_artifact_file_cache() -> ArtifactFileCache:
251
- return _build_artifact_file_cache(env.get_cache_dir() / "artifacts")
256
+ return _build_artifact_file_cache(artifacts_cache_dir())
@@ -3,9 +3,11 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import concurrent.futures
6
+ import hashlib
6
7
  import json
7
8
  import logging
8
9
  import os
10
+ from contextlib import suppress
9
11
  from pathlib import Path
10
12
  from typing import TYPE_CHECKING
11
13
  from urllib.parse import urlparse
@@ -43,6 +45,48 @@ if TYPE_CHECKING:
43
45
  _WB_ARTIFACT_SCHEME = "wandb-artifact"
44
46
 
45
47
 
48
+ def _checksum_cache_path(file_path: str) -> str:
49
+ """Get path for checksum in central cache directory."""
50
+ from wandb.sdk.artifacts.artifact_file_cache import artifacts_cache_dir
51
+
52
+ # Create a unique cache key based on the file's absolute path
53
+ abs_path = os.path.abspath(file_path)
54
+ path_hash = hashlib.sha256(abs_path.encode()).hexdigest()
55
+
56
+ # Store in wandb cache directory under checksums subdirectory
57
+ cache_dir = artifacts_cache_dir() / "checksums"
58
+ cache_dir.mkdir(parents=True, exist_ok=True)
59
+
60
+ return str(cache_dir / f"{path_hash}.checksum")
61
+
62
+
63
+ def _read_cached_checksum(file_path: str) -> str | None:
64
+ """Read checksum from cache if it exists and is valid."""
65
+ checksum_path = _checksum_cache_path(file_path)
66
+
67
+ try:
68
+ with open(file_path) as f, open(checksum_path) as f_checksum:
69
+ if os.path.getmtime(f_checksum.name) < os.path.getmtime(f.name):
70
+ # File was modified after checksum was written
71
+ return None
72
+ # Read and return the cached checksum
73
+ return f_checksum.read().strip()
74
+ except OSError:
75
+ # File doesn't exist or couldn't be opened
76
+ return None
77
+
78
+
79
+ def _write_cached_checksum(file_path: str, checksum: str) -> None:
80
+ """Write checksum to cache directory."""
81
+ checksum_path = _checksum_cache_path(file_path)
82
+ try:
83
+ with open(checksum_path, "w") as f:
84
+ f.write(checksum)
85
+ except OSError:
86
+ # Non-critical failure, just log it
87
+ logger.debug(f"Failed to write checksum cache for {file_path!r}")
88
+
89
+
46
90
  class ArtifactManifestEntry:
47
91
  """A single entry in an artifact manifest."""
48
92
 
@@ -164,11 +208,19 @@ class ArtifactManifestEntry:
164
208
 
165
209
  # Skip checking the cache (and possibly downloading) if the file already exists
166
210
  # and has the digest we're expecting.
211
+
212
+ # Fast integrity check using cached checksum from persistent cache
213
+ with suppress(OSError):
214
+ if self.digest == _read_cached_checksum(dest_path):
215
+ return FilePathStr(dest_path)
216
+
217
+ # Fallback to computing/caching the checksum hash
167
218
  try:
168
219
  md5_hash = md5_file_b64(dest_path)
169
220
  except (FileNotFoundError, IsADirectoryError):
170
- logger.debug(f"unable to find {dest_path}, skip searching for file")
221
+ logger.debug(f"unable to find {dest_path!r}, skip searching for file")
171
222
  else:
223
+ _write_cached_checksum(dest_path, md5_hash)
172
224
  if self.digest == md5_hash:
173
225
  return FilePathStr(dest_path)
174
226
 
@@ -184,12 +236,19 @@ class ArtifactManifestEntry:
184
236
  executor=executor,
185
237
  multipart=multipart,
186
238
  )
187
- return FilePathStr(
239
+
240
+ # Determine the final path
241
+ final_path = (
188
242
  dest_path
189
243
  if skip_cache
190
244
  else copy_or_overwrite_changed(cache_path, dest_path)
191
245
  )
192
246
 
247
+ # Cache the checksum for future downloads
248
+ _write_cached_checksum(str(final_path), self.digest)
249
+
250
+ return FilePathStr(final_path)
251
+
193
252
  def ref_target(self) -> FilePathStr | URIStr:
194
253
  """Get the reference URL that is targeted by this artifact entry.
195
254
 
@@ -91,6 +91,8 @@ class WandbStoragePolicy(StoragePolicy):
91
91
  session: requests.Session | None = None,
92
92
  ) -> None:
93
93
  self._config = config or {}
94
+ if (storage_region := self._config.get("storageRegion")) is not None:
95
+ self._validate_storage_region(storage_region)
94
96
  self._cache = cache or get_artifact_file_cache()
95
97
  self._session = session or make_http_session()
96
98
  self._api = api or InternalApi()
@@ -99,6 +101,14 @@ class WandbStoragePolicy(StoragePolicy):
99
101
  default_handler=TrackingHandler(),
100
102
  )
101
103
 
104
+ def _validate_storage_region(self, storage_region: Any) -> None:
105
+ if not isinstance(storage_region, str):
106
+ raise TypeError(
107
+ f"storageRegion must be a string, got {type(storage_region).__name__}: {storage_region!r}"
108
+ )
109
+ if not storage_region.strip():
110
+ raise ValueError("storageRegion must be a non-empty string")
111
+
102
112
  def config(self) -> dict:
103
113
  return self._config
104
114
 
@@ -5,6 +5,7 @@ import pathlib
5
5
  from typing import TYPE_CHECKING, Union
6
6
 
7
7
  from wandb import util
8
+ from wandb._strutils import nameof
8
9
  from wandb.sdk.lib import runid
9
10
 
10
11
  from . import _dtypes
@@ -34,7 +35,10 @@ class Bokeh(Media):
34
35
  ],
35
36
  ):
36
37
  super().__init__()
37
- bokeh = util.get_module("bokeh", required=True)
38
+ bokeh = util.get_module(
39
+ "bokeh",
40
+ required=f"{nameof(Bokeh)!r} requires the bokeh package. Please install it with `pip install bokeh`.",
41
+ )
38
42
  if isinstance(data_or_path, (str, pathlib.Path)) and os.path.exists(
39
43
  data_or_path
40
44
  ):
@@ -161,6 +161,17 @@ class Image(BatchableMedia):
161
161
  ) -> None:
162
162
  """Initialize a `wandb.Image` object.
163
163
 
164
+ This class handles various image data formats and automatically normalizes
165
+ pixel values to the range [0, 255] when needed, ensuring compatibility
166
+ with the W&B backend.
167
+
168
+ * Data in range [0, 1] is multiplied by 255 and converted to uint8
169
+ * Data in range [-1, 1] is rescaled from [-1, 1] to [0, 255] by mapping
170
+ -1 to 0 and 1 to 255, then converted to uint8
171
+ * Data outside [-1, 1] but not in [0, 255] is clipped to [0, 255] and
172
+ converted to uint8 (with a warning if values fall outside [0, 255])
173
+ * Data already in [0, 255] is converted to uint8 without modification
174
+
164
175
  Args:
165
176
  data_or_path: Accepts NumPy array/pytorch tensor of image data,
166
177
  a PIL image object, or a path to an image file. If a NumPy
@@ -168,7 +179,7 @@ class Image(BatchableMedia):
168
179
  the image data will be saved to the given file type.
169
180
  If the values are not in the range [0, 255] or all values are in the range [0, 1],
170
181
  the image pixel values will be normalized to the range [0, 255]
171
- unless `normalize` is set to False.
182
+ unless `normalize` is set to `False`.
172
183
  - pytorch tensor should be in the format (channel, height, width)
173
184
  - NumPy array should be in the format (height, width, channel)
174
185
  mode: The PIL mode for an image. Most common are "L", "RGB",
@@ -178,13 +189,13 @@ class Image(BatchableMedia):
178
189
  classes: A list of class information for the image,
179
190
  used for labeling bounding boxes, and image masks.
180
191
  boxes: A dictionary containing bounding box information for the image.
181
- see: https://docs.wandb.ai/ref/python/data-types/boundingboxes2d/
192
+ see https://docs.wandb.ai/ref/python/data-types/boundingboxes2d/
182
193
  masks: A dictionary containing mask information for the image.
183
- see: https://docs.wandb.ai/ref/python/data-types/imagemask/
194
+ see https://docs.wandb.ai/ref/python/data-types/imagemask/
184
195
  file_type: The file type to save the image as.
185
- This parameter has no effect if data_or_path is a path to an image file.
186
- normalize: If True, normalize the image pixel values to fall within the range of [0, 255].
187
- Normalize is only applied if data_or_path is a numpy array or pytorch tensor.
196
+ This parameter has no effect if `data_or_path` is a path to an image file.
197
+ normalize: If `True`, normalize the image pixel values to fall within the range of [0, 255].
198
+ Normalize is only applied if `data_or_path` is a numpy array or pytorch tensor.
188
199
 
189
200
  Examples:
190
201
  Create a wandb.Image from a numpy array
@@ -87,11 +87,38 @@ def file_enum_to_policy(enum: "pb.FilesItem.PolicyType.V") -> "PolicyName":
87
87
 
88
88
 
89
89
  class InterfaceBase:
90
+ """Methods for sending different types of Records to the service.
91
+
92
+ None of the methods may be called from an asyncio context other than
93
+ deliver_async().
94
+ """
95
+
90
96
  _drop: bool
91
97
 
92
98
  def __init__(self) -> None:
93
99
  self._drop = False
94
100
 
101
+ @abstractmethod
102
+ async def deliver_async(
103
+ self,
104
+ record: pb.Record,
105
+ ) -> MailboxHandle[pb.Result]:
106
+ """Send a record and create a handle to wait for the response.
107
+
108
+ The synchronous publish and deliver methods on this class cannot be
109
+ called in the asyncio thread because they block. Instead of having
110
+ an async copy of every method, this is a general method for sending
111
+ any kind of record in the asyncio thread.
112
+
113
+ Args:
114
+ record: The record to send. This method takes ownership of the
115
+ record and it must not be used afterward.
116
+
117
+ Returns:
118
+ A handle to wait for a response to the record.
119
+ """
120
+ raise NotImplementedError
121
+
95
122
  def publish_header(self) -> None:
96
123
  header = pb.HeaderRecord()
97
124
  self._publish_header(header)
@@ -392,9 +419,13 @@ class InterfaceBase:
392
419
  proto_manifest.manifest_file_path = path
393
420
  return proto_manifest
394
421
 
422
+ # Set storage policy on storageLayout (always V2) and storageRegion, only allow coreweave-us on wandb.ai for now.
423
+ # NOTE: the decode logic is NewManifestFromProto in core/pkg/artifacts/manifest.go
424
+ # The creation logic is in artifacts/_factories.py make_storage_policy
395
425
  for k, v in artifact_manifest.storage_policy.config().items() or {}.items():
396
426
  cfg = proto_manifest.storage_policy_config.add()
397
427
  cfg.key = k
428
+ # TODO: Why json.dumps when existing values are plain string? We want to send complex structure without defining the proto?
398
429
  cfg.value_json = json.dumps(v)
399
430
 
400
431
  for entry in sorted(artifact_manifest.entries.values(), key=lambda k: k.path):
@@ -1020,10 +1051,6 @@ class InterfaceBase:
1020
1051
  ) -> MailboxHandle[pb.Result]:
1021
1052
  raise NotImplementedError
1022
1053
 
1023
- @abstractmethod
1024
- def deliver_operation_stats(self) -> MailboxHandle[pb.Result]:
1025
- raise NotImplementedError
1026
-
1027
1054
  def deliver_poll_exit(self) -> MailboxHandle[pb.Result]:
1028
1055
  poll_exit = pb.PollExitRequest()
1029
1056
  return self._deliver_poll_exit(poll_exit)
@@ -8,12 +8,15 @@ import logging
8
8
  from multiprocessing.process import BaseProcess
9
9
  from typing import TYPE_CHECKING, Optional
10
10
 
11
+ from typing_extensions import override
12
+
11
13
  from .interface_shared import InterfaceShared
12
14
 
13
15
  if TYPE_CHECKING:
14
16
  from queue import Queue
15
17
 
16
18
  from wandb.proto import wandb_internal_pb2 as pb
19
+ from wandb.sdk.mailbox.mailbox_handle import MailboxHandle
17
20
 
18
21
 
19
22
  logger = logging.getLogger("wandb")
@@ -31,6 +34,13 @@ class InterfaceQueue(InterfaceShared):
31
34
  self._process = process
32
35
  super().__init__()
33
36
 
37
+ @override
38
+ async def deliver_async(
39
+ self,
40
+ record: "pb.Record",
41
+ ) -> "MailboxHandle[pb.Result]":
42
+ raise NotImplementedError
43
+
34
44
  def _publish(self, record: "pb.Record", local: Optional[bool] = None) -> None:
35
45
  if self._process and not self._process.is_alive():
36
46
  raise Exception("The wandb backend process has shutdown")
@@ -87,7 +87,6 @@ class InterfaceShared(InterfaceBase):
87
87
  stop_status: Optional[pb.StopStatusRequest] = None,
88
88
  internal_messages: Optional[pb.InternalMessagesRequest] = None,
89
89
  network_status: Optional[pb.NetworkStatusRequest] = None,
90
- operation_stats: Optional[pb.OperationStatsRequest] = None,
91
90
  poll_exit: Optional[pb.PollExitRequest] = None,
92
91
  partial_history: Optional[pb.PartialHistoryRequest] = None,
93
92
  sampled_history: Optional[pb.SampledHistoryRequest] = None,
@@ -129,8 +128,6 @@ class InterfaceShared(InterfaceBase):
129
128
  request.internal_messages.CopyFrom(internal_messages)
130
129
  elif network_status:
131
130
  request.network_status.CopyFrom(network_status)
132
- elif operation_stats:
133
- request.operations.CopyFrom(operation_stats)
134
131
  elif poll_exit:
135
132
  request.poll_exit.CopyFrom(poll_exit)
136
133
  elif partial_history:
@@ -424,10 +421,6 @@ class InterfaceShared(InterfaceBase):
424
421
  record = self._make_record(exit=exit_data)
425
422
  return self._deliver(record)
426
423
 
427
- def deliver_operation_stats(self):
428
- record = self._make_request(operation_stats=pb.OperationStatsRequest())
429
- return self._deliver(record)
430
-
431
424
  def _deliver_poll_exit(
432
425
  self,
433
426
  poll_exit: pb.PollExitRequest,