wandb 0.22.0__py3-none-win32.whl → 0.22.2__py3-none-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114)
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +8 -5
  3. wandb/_pydantic/__init__.py +12 -11
  4. wandb/_pydantic/base.py +49 -19
  5. wandb/apis/__init__.py +2 -0
  6. wandb/apis/attrs.py +2 -0
  7. wandb/apis/importers/internals/internal.py +16 -23
  8. wandb/apis/internal.py +2 -0
  9. wandb/apis/normalize.py +2 -0
  10. wandb/apis/public/__init__.py +3 -2
  11. wandb/apis/public/api.py +215 -164
  12. wandb/apis/public/artifacts.py +23 -20
  13. wandb/apis/public/const.py +2 -0
  14. wandb/apis/public/files.py +33 -24
  15. wandb/apis/public/history.py +2 -0
  16. wandb/apis/public/jobs.py +20 -18
  17. wandb/apis/public/projects.py +4 -2
  18. wandb/apis/public/query_generator.py +3 -0
  19. wandb/apis/public/registries/__init__.py +7 -0
  20. wandb/apis/public/registries/_freezable_list.py +9 -12
  21. wandb/apis/public/registries/registries_search.py +8 -6
  22. wandb/apis/public/registries/registry.py +22 -17
  23. wandb/apis/public/reports.py +2 -0
  24. wandb/apis/public/runs.py +261 -57
  25. wandb/apis/public/sweeps.py +10 -9
  26. wandb/apis/public/teams.py +2 -0
  27. wandb/apis/public/users.py +2 -0
  28. wandb/apis/public/utils.py +16 -15
  29. wandb/automations/_generated/__init__.py +54 -127
  30. wandb/automations/_generated/create_generic_webhook_integration.py +1 -7
  31. wandb/automations/_generated/fragments.py +26 -91
  32. wandb/bin/gpu_stats.exe +0 -0
  33. wandb/bin/wandb-core +0 -0
  34. wandb/cli/beta.py +16 -2
  35. wandb/cli/beta_leet.py +74 -0
  36. wandb/cli/beta_sync.py +9 -11
  37. wandb/cli/cli.py +34 -7
  38. wandb/errors/errors.py +3 -3
  39. wandb/proto/v3/wandb_api_pb2.py +86 -0
  40. wandb/proto/v3/wandb_internal_pb2.py +352 -351
  41. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  42. wandb/proto/v3/wandb_sync_pb2.py +19 -6
  43. wandb/proto/v4/wandb_api_pb2.py +37 -0
  44. wandb/proto/v4/wandb_internal_pb2.py +352 -351
  45. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  46. wandb/proto/v4/wandb_sync_pb2.py +10 -6
  47. wandb/proto/v5/wandb_api_pb2.py +38 -0
  48. wandb/proto/v5/wandb_internal_pb2.py +352 -351
  49. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  50. wandb/proto/v5/wandb_sync_pb2.py +10 -6
  51. wandb/proto/v6/wandb_api_pb2.py +48 -0
  52. wandb/proto/v6/wandb_internal_pb2.py +352 -351
  53. wandb/proto/v6/wandb_settings_pb2.py +2 -2
  54. wandb/proto/v6/wandb_sync_pb2.py +10 -6
  55. wandb/proto/wandb_api_pb2.py +18 -0
  56. wandb/proto/wandb_generate_proto.py +1 -0
  57. wandb/sdk/artifacts/_factories.py +7 -2
  58. wandb/sdk/artifacts/_generated/__init__.py +112 -412
  59. wandb/sdk/artifacts/_generated/fragments.py +65 -0
  60. wandb/sdk/artifacts/_generated/operations.py +52 -22
  61. wandb/sdk/artifacts/_generated/run_input_artifacts.py +3 -23
  62. wandb/sdk/artifacts/_generated/run_output_artifacts.py +3 -23
  63. wandb/sdk/artifacts/_generated/type_info.py +19 -0
  64. wandb/sdk/artifacts/_gqlutils.py +47 -0
  65. wandb/sdk/artifacts/_models/__init__.py +4 -0
  66. wandb/sdk/artifacts/_models/base_model.py +20 -0
  67. wandb/sdk/artifacts/_validators.py +40 -12
  68. wandb/sdk/artifacts/artifact.py +99 -118
  69. wandb/sdk/artifacts/artifact_file_cache.py +6 -1
  70. wandb/sdk/artifacts/artifact_manifest_entry.py +67 -14
  71. wandb/sdk/artifacts/storage_handler.py +18 -12
  72. wandb/sdk/artifacts/storage_handlers/azure_handler.py +11 -6
  73. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +9 -6
  74. wandb/sdk/artifacts/storage_handlers/http_handler.py +9 -4
  75. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +10 -6
  76. wandb/sdk/artifacts/storage_handlers/multi_handler.py +5 -4
  77. wandb/sdk/artifacts/storage_handlers/s3_handler.py +10 -8
  78. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +6 -4
  79. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +24 -21
  80. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +4 -2
  81. wandb/sdk/artifacts/storage_policies/_multipart.py +187 -0
  82. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +71 -242
  83. wandb/sdk/artifacts/storage_policy.py +25 -12
  84. wandb/sdk/data_types/bokeh.py +5 -1
  85. wandb/sdk/data_types/image.py +17 -6
  86. wandb/sdk/data_types/object_3d.py +67 -2
  87. wandb/sdk/interface/interface.py +31 -4
  88. wandb/sdk/interface/interface_queue.py +10 -0
  89. wandb/sdk/interface/interface_shared.py +0 -7
  90. wandb/sdk/interface/interface_sock.py +9 -3
  91. wandb/sdk/internal/_generated/__init__.py +2 -12
  92. wandb/sdk/internal/job_builder.py +27 -10
  93. wandb/sdk/internal/sender.py +5 -2
  94. wandb/sdk/internal/settings_static.py +2 -82
  95. wandb/sdk/launch/create_job.py +2 -1
  96. wandb/sdk/launch/runner/kubernetes_runner.py +25 -20
  97. wandb/sdk/launch/utils.py +82 -1
  98. wandb/sdk/lib/progress.py +8 -74
  99. wandb/sdk/lib/service/service_client.py +5 -9
  100. wandb/sdk/lib/service/service_connection.py +39 -23
  101. wandb/sdk/mailbox/mailbox_handle.py +2 -0
  102. wandb/sdk/projects/_generated/__init__.py +12 -33
  103. wandb/sdk/wandb_init.py +23 -3
  104. wandb/sdk/wandb_login.py +53 -27
  105. wandb/sdk/wandb_run.py +10 -5
  106. wandb/sdk/wandb_settings.py +63 -25
  107. wandb/sync/sync.py +7 -2
  108. wandb/util.py +1 -1
  109. {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/METADATA +1 -1
  110. {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/RECORD +113 -103
  111. wandb/sdk/artifacts/_graphql_fragments.py +0 -19
  112. {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/WHEEL +0 -0
  113. {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/entry_points.txt +0 -0
  114. {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/licenses/LICENSE +0 -0
@@ -4,7 +4,7 @@ from __future__ import annotations
4
4
 
5
5
  from pathlib import PurePosixPath
6
6
  from types import ModuleType
7
- from typing import TYPE_CHECKING, Sequence
7
+ from typing import TYPE_CHECKING
8
8
  from urllib.parse import ParseResult, parse_qsl, urlparse
9
9
 
10
10
  import wandb
@@ -20,17 +20,22 @@ if TYPE_CHECKING:
20
20
  import azure.storage.blob # type: ignore
21
21
 
22
22
  from wandb.sdk.artifacts.artifact import Artifact
23
+ from wandb.sdk.artifacts.artifact_file_cache import ArtifactFileCache
23
24
 
24
25
 
25
26
  class AzureHandler(StorageHandler):
27
+ _scheme: str
28
+ _cache: ArtifactFileCache
29
+
30
+ def __init__(self, scheme: str = "https") -> None:
31
+ self._scheme = scheme
32
+ self._cache = get_artifact_file_cache()
33
+
26
34
  def can_handle(self, parsed_url: ParseResult) -> bool:
27
- return parsed_url.scheme == "https" and parsed_url.netloc.endswith(
35
+ return parsed_url.scheme == self._scheme and parsed_url.netloc.endswith(
28
36
  ".blob.core.windows.net"
29
37
  )
30
38
 
31
- def __init__(self, scheme: str | None = None) -> None:
32
- self._cache = get_artifact_file_cache()
33
-
34
39
  def load_path(
35
40
  self,
36
41
  manifest_entry: ArtifactManifestEntry,
@@ -101,7 +106,7 @@ class AzureHandler(StorageHandler):
101
106
  name: StrPath | None = None,
102
107
  checksum: bool = True,
103
108
  max_objects: int | None = None,
104
- ) -> Sequence[ArtifactManifestEntry]:
109
+ ) -> list[ArtifactManifestEntry]:
105
110
  account_url, container_name, blob_name, query = self._parse_uri(path)
106
111
  path = URIStr(f"{account_url}/{container_name}/{blob_name}")
107
112
 
@@ -4,7 +4,7 @@ from __future__ import annotations
4
4
 
5
5
  import time
6
6
  from pathlib import PurePosixPath
7
- from typing import TYPE_CHECKING, Sequence
7
+ from typing import TYPE_CHECKING
8
8
  from urllib.parse import ParseResult, urlparse
9
9
 
10
10
  from wandb import util
@@ -19,6 +19,7 @@ if TYPE_CHECKING:
19
19
  import google.cloud.storage as gcs_module # type: ignore
20
20
 
21
21
  from wandb.sdk.artifacts.artifact import Artifact
22
+ from wandb.sdk.artifacts.artifact_file_cache import ArtifactFileCache
22
23
 
23
24
 
24
25
  class _GCSIsADirectoryError(Exception):
@@ -26,10 +27,12 @@ class _GCSIsADirectoryError(Exception):
26
27
 
27
28
 
28
29
  class GCSHandler(StorageHandler):
30
+ _scheme: str
29
31
  _client: gcs_module.client.Client | None
32
+ _cache: ArtifactFileCache
30
33
 
31
- def __init__(self, scheme: str | None = None) -> None:
32
- self._scheme = scheme or "gs"
34
+ def __init__(self, scheme: str = "gs") -> None:
35
+ self._scheme = scheme
33
36
  self._client = None
34
37
  self._cache = get_artifact_file_cache()
35
38
 
@@ -111,7 +114,7 @@ class GCSHandler(StorageHandler):
111
114
  name: StrPath | None = None,
112
115
  checksum: bool = True,
113
116
  max_objects: int | None = None,
114
- ) -> Sequence[ArtifactManifestEntry]:
117
+ ) -> list[ArtifactManifestEntry]:
115
118
  self.init_gcs()
116
119
  assert self._client is not None # mypy: unwraps optionality
117
120
 
@@ -131,7 +134,7 @@ class GCSHandler(StorageHandler):
131
134
  raise ValueError(f"Object does not exist: {path}#{version}")
132
135
  multi = obj is None
133
136
  if multi:
134
- start_time = time.time()
137
+ start_time = time.monotonic()
135
138
  termlog(
136
139
  f'Generating checksum for up to {max_objects} objects with prefix "{key}"... ',
137
140
  newline=False,
@@ -148,7 +151,7 @@ class GCSHandler(StorageHandler):
148
151
  if not obj.name.endswith("/")
149
152
  ]
150
153
  if start_time is not None:
151
- termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)
154
+ termlog("Done. %.1fs" % (time.monotonic() - start_time), prefix=False)
152
155
  if len(entries) > max_objects:
153
156
  raise ValueError(
154
157
  f"Exceeded {max_objects} objects tracked, pass max_objects to add_reference"
@@ -3,7 +3,7 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import os
6
- from typing import TYPE_CHECKING, Sequence
6
+ from typing import TYPE_CHECKING
7
7
  from urllib.parse import ParseResult
8
8
 
9
9
  from wandb.sdk.artifacts.artifact_file_cache import get_artifact_file_cache
@@ -18,11 +18,16 @@ if TYPE_CHECKING:
18
18
  from requests.structures import CaseInsensitiveDict
19
19
 
20
20
  from wandb.sdk.artifacts.artifact import Artifact
21
+ from wandb.sdk.artifacts.artifact_file_cache import ArtifactFileCache
21
22
 
22
23
 
23
24
  class HTTPHandler(StorageHandler):
24
- def __init__(self, session: requests.Session, scheme: str | None = None) -> None:
25
- self._scheme = scheme or "http"
25
+ _scheme: str
26
+ _cache: ArtifactFileCache
27
+ _session: requests.Session
28
+
29
+ def __init__(self, session: requests.Session, scheme: str = "http") -> None:
30
+ self._scheme = scheme
26
31
  self._cache = get_artifact_file_cache()
27
32
  self._session = session
28
33
 
@@ -75,7 +80,7 @@ class HTTPHandler(StorageHandler):
75
80
  name: StrPath | None = None,
76
81
  checksum: bool = True,
77
82
  max_objects: int | None = None,
78
- ) -> Sequence[ArtifactManifestEntry]:
83
+ ) -> list[ArtifactManifestEntry]:
79
84
  name = name or os.path.basename(path)
80
85
  if not checksum:
81
86
  return [ArtifactManifestEntry(path=name, ref=path, digest=path)]
@@ -6,7 +6,7 @@ import os
6
6
  import shutil
7
7
  import time
8
8
  from pathlib import Path
9
- from typing import TYPE_CHECKING, Sequence
9
+ from typing import TYPE_CHECKING
10
10
  from urllib.parse import ParseResult
11
11
 
12
12
  from wandb import util
@@ -20,17 +20,21 @@ from wandb.sdk.lib.paths import FilePathStr, StrPath, URIStr
20
20
 
21
21
  if TYPE_CHECKING:
22
22
  from wandb.sdk.artifacts.artifact import Artifact
23
+ from wandb.sdk.artifacts.artifact_file_cache import ArtifactFileCache
23
24
 
24
25
 
25
26
  class LocalFileHandler(StorageHandler):
26
27
  """Handles file:// references."""
27
28
 
28
- def __init__(self, scheme: str | None = None) -> None:
29
+ _scheme: str
30
+ _cache: ArtifactFileCache
31
+
32
+ def __init__(self, scheme: str = "file") -> None:
29
33
  """Track files or directories on a local filesystem.
30
34
 
31
35
  Expand directories to create an entry for each file contained.
32
36
  """
33
- self._scheme = scheme or "file"
37
+ self._scheme = scheme
34
38
  self._cache = get_artifact_file_cache()
35
39
 
36
40
  def can_handle(self, parsed_url: ParseResult) -> bool:
@@ -75,7 +79,7 @@ class LocalFileHandler(StorageHandler):
75
79
  name: StrPath | None = None,
76
80
  checksum: bool = True,
77
81
  max_objects: int | None = None,
78
- ) -> Sequence[ArtifactManifestEntry]:
82
+ ) -> list[ArtifactManifestEntry]:
79
83
  local_path = util.local_file_uri_to_path(path)
80
84
  max_objects = max_objects or DEFAULT_MAX_OBJECTS
81
85
  # We have a single file or directory
@@ -95,7 +99,7 @@ class LocalFileHandler(StorageHandler):
95
99
 
96
100
  if os.path.isdir(local_path):
97
101
  i = 0
98
- start_time = time.time()
102
+ start_time = time.monotonic()
99
103
  if checksum:
100
104
  termlog(
101
105
  f'Generating checksum for up to {max_objects} files in "{local_path}"... ',
@@ -126,7 +130,7 @@ class LocalFileHandler(StorageHandler):
126
130
  )
127
131
  entries.append(entry)
128
132
  if checksum:
129
- termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)
133
+ termlog("Done. %.1fs" % (time.monotonic() - start_time), prefix=False)
130
134
  elif os.path.isfile(local_path):
131
135
  name = name or os.path.basename(local_path)
132
136
  entry = ArtifactManifestEntry(
@@ -2,10 +2,10 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from typing import TYPE_CHECKING, Sequence
5
+ from typing import TYPE_CHECKING
6
6
  from urllib.parse import urlparse
7
7
 
8
- from wandb.sdk.artifacts.storage_handler import StorageHandler
8
+ from wandb.sdk.artifacts.storage_handler import StorageHandler, _BaseStorageHandler
9
9
  from wandb.sdk.lib.paths import FilePathStr, URIStr
10
10
 
11
11
  if TYPE_CHECKING:
@@ -13,8 +13,9 @@ if TYPE_CHECKING:
13
13
  from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
14
14
 
15
15
 
16
- class MultiHandler(StorageHandler):
16
+ class MultiHandler(_BaseStorageHandler):
17
17
  _handlers: list[StorageHandler]
18
+ _default_handler: StorageHandler | None
18
19
 
19
20
  def __init__(
20
21
  self,
@@ -49,7 +50,7 @@ class MultiHandler(StorageHandler):
49
50
  name: str | None = None,
50
51
  checksum: bool = True,
51
52
  max_objects: int | None = None,
52
- ) -> Sequence[ArtifactManifestEntry]:
53
+ ) -> list[ArtifactManifestEntry]:
53
54
  handler = self._get_handler(path)
54
55
  return handler.store_path(
55
56
  artifact, path, name=name, checksum=checksum, max_objects=max_objects
@@ -6,7 +6,7 @@ import os
6
6
  import re
7
7
  import time
8
8
  from pathlib import PurePosixPath
9
- from typing import TYPE_CHECKING, Sequence
9
+ from typing import TYPE_CHECKING
10
10
  from urllib.parse import parse_qsl, urlparse
11
11
 
12
12
  from wandb import util
@@ -32,16 +32,18 @@ if TYPE_CHECKING:
32
32
  import boto3.session # type: ignore
33
33
 
34
34
  from wandb.sdk.artifacts.artifact import Artifact
35
+ from wandb.sdk.artifacts.artifact_file_cache import ArtifactFileCache
35
36
 
36
37
 
37
38
  class S3Handler(StorageHandler):
38
- _s3: boto3.resources.base.ServiceResource | None
39
39
  _scheme: str
40
+ _cache: ArtifactFileCache
41
+ _s3: boto3.resources.base.ServiceResource | None
40
42
 
41
- def __init__(self, scheme: str | None = None) -> None:
42
- self._scheme = scheme or "s3"
43
- self._s3 = None
43
+ def __init__(self, scheme: str = "s3") -> None:
44
+ self._scheme = scheme
44
45
  self._cache = get_artifact_file_cache()
46
+ self._s3 = None
45
47
 
46
48
  def can_handle(self, parsed_url: ParseResult) -> bool:
47
49
  return parsed_url.scheme == self._scheme
@@ -160,7 +162,7 @@ class S3Handler(StorageHandler):
160
162
  name: StrPath | None = None,
161
163
  checksum: bool = True,
162
164
  max_objects: int | None = None,
163
- ) -> Sequence[ArtifactManifestEntry]:
165
+ ) -> list[ArtifactManifestEntry]:
164
166
  self.init_boto()
165
167
  assert self._s3 is not None # mypy: unwraps optionality
166
168
 
@@ -206,7 +208,7 @@ class S3Handler(StorageHandler):
206
208
  multi = True
207
209
 
208
210
  if multi:
209
- start_time = time.time()
211
+ start_time = time.monotonic()
210
212
  termlog(
211
213
  f'Generating checksum for up to {max_objects} objects in "{bucket}/{key}"... ',
212
214
  newline=False,
@@ -227,7 +229,7 @@ class S3Handler(StorageHandler):
227
229
  if size(obj) > 0
228
230
  ]
229
231
  if start_time is not None:
230
- termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)
232
+ termlog("Done. %.1fs" % (time.monotonic() - start_time), prefix=False)
231
233
  if len(entries) > max_objects:
232
234
  raise ValueError(
233
235
  f"Exceeded {max_objects} objects tracked, pass max_objects to add_reference"
@@ -2,7 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from typing import TYPE_CHECKING, Sequence
5
+ from typing import TYPE_CHECKING
6
6
  from urllib.parse import urlparse
7
7
 
8
8
  from wandb.errors.term import termwarn
@@ -17,7 +17,9 @@ if TYPE_CHECKING:
17
17
 
18
18
 
19
19
  class TrackingHandler(StorageHandler):
20
- def __init__(self, scheme: str | None = None) -> None:
20
+ _scheme: str
21
+
22
+ def __init__(self, scheme: str = "") -> None:
21
23
  """Track paths with no modification or special processing.
22
24
 
23
25
  Useful when paths being tracked are on file systems mounted at a standardized
@@ -26,7 +28,7 @@ class TrackingHandler(StorageHandler):
26
28
  For example, if the data to track is located on an NFS share mounted on
27
29
  `/data`, then it is sufficient to just track the paths.
28
30
  """
29
- self._scheme = scheme or ""
31
+ self._scheme = scheme
30
32
 
31
33
  def can_handle(self, parsed_url: ParseResult) -> bool:
32
34
  return parsed_url.scheme == self._scheme
@@ -55,7 +57,7 @@ class TrackingHandler(StorageHandler):
55
57
  name: StrPath | None = None,
56
58
  checksum: bool = True,
57
59
  max_objects: int | None = None,
58
- ) -> Sequence[ArtifactManifestEntry]:
60
+ ) -> list[ArtifactManifestEntry]:
59
61
  url = urlparse(path)
60
62
  if name is None:
61
63
  raise ValueError(
@@ -3,27 +3,29 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import os
6
- from typing import TYPE_CHECKING, Sequence
6
+ from typing import TYPE_CHECKING, Literal
7
7
  from urllib.parse import urlparse
8
8
 
9
- import wandb
10
- from wandb import util
9
+ from wandb._strutils import removeprefix
11
10
  from wandb.apis import PublicApi
12
11
  from wandb.sdk.artifacts.artifact_file_cache import get_artifact_file_cache
13
12
  from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
14
13
  from wandb.sdk.artifacts.storage_handler import StorageHandler
15
- from wandb.sdk.lib.hashutil import B64MD5, b64_to_hex_id, hex_to_b64_id
14
+ from wandb.sdk.lib.hashutil import b64_to_hex_id, hex_to_b64_id
16
15
  from wandb.sdk.lib.paths import FilePathStr, StrPath, URIStr
17
16
 
18
17
  if TYPE_CHECKING:
19
18
  from urllib.parse import ParseResult
20
19
 
21
20
  from wandb.sdk.artifacts.artifact import Artifact
21
+ from wandb.sdk.artifacts.artifact_file_cache import ArtifactFileCache
22
22
 
23
23
 
24
24
  class WBArtifactHandler(StorageHandler):
25
25
  """Handles loading and storing Artifact reference-type files."""
26
26
 
27
+ _scheme: Literal["wandb-artifact"]
28
+ _cache: ArtifactFileCache
27
29
  _client: PublicApi | None
28
30
 
29
31
  def __init__(self) -> None:
@@ -55,6 +57,8 @@ class WBArtifactHandler(StorageHandler):
55
57
  Returns:
56
58
  (os.PathLike): A path to the file represented by `index_entry`
57
59
  """
60
+ from wandb.sdk.artifacts.artifact import Artifact # avoids circular import
61
+
58
62
  # We don't check for cache hits here. Since we have 0 for size (since this
59
63
  # is a cross-artifact reference which and we've made the choice to store 0
60
64
  # in the size field), we can't confirm if the file is complete. So we just
@@ -62,19 +66,17 @@ class WBArtifactHandler(StorageHandler):
62
66
  # check.
63
67
 
64
68
  # Parse the reference path and download the artifact if needed
65
- artifact_id = util.host_from_path(manifest_entry.ref)
66
- artifact_file_path = util.uri_from_path(manifest_entry.ref)
69
+ parsed = urlparse(manifest_entry.ref)
70
+ artifact_id = hex_to_b64_id(parsed.netloc)
71
+ artifact_file_path = removeprefix(str(parsed.path), "/")
67
72
 
68
- dep_artifact = wandb.Artifact._from_id(
69
- hex_to_b64_id(artifact_id), self.client.client
70
- )
73
+ dep_artifact = Artifact._from_id(artifact_id, self.client.client)
71
74
  assert dep_artifact is not None
72
75
  link_target_path: URIStr | FilePathStr
73
76
  if local:
74
77
  link_target_path = dep_artifact.get_entry(artifact_file_path).download()
75
78
  else:
76
79
  link_target_path = dep_artifact.get_entry(artifact_file_path).ref_target()
77
-
78
80
  return link_target_path
79
81
 
80
82
  def store_path(
@@ -84,7 +86,7 @@ class WBArtifactHandler(StorageHandler):
84
86
  name: StrPath | None = None,
85
87
  checksum: bool = True,
86
88
  max_objects: int | None = None,
87
- ) -> Sequence[ArtifactManifestEntry]:
89
+ ) -> list[ArtifactManifestEntry]:
88
90
  """Store the file or directory at the given path into the specified artifact.
89
91
 
90
92
  Recursively resolves the reference until the result is a concrete asset.
@@ -97,26 +99,27 @@ class WBArtifactHandler(StorageHandler):
97
99
  (list[ArtifactManifestEntry]): A list of manifest entries to store within
98
100
  the artifact
99
101
  """
102
+ from wandb.sdk.artifacts.artifact import Artifact # avoids circular import
103
+
100
104
  # Recursively resolve the reference until a concrete asset is found
101
105
  # TODO: Consider resolving server-side for performance improvements.
102
- iter_path: URIStr | FilePathStr | None = path
103
- while iter_path is not None and urlparse(iter_path).scheme == self._scheme:
104
- artifact_id = util.host_from_path(iter_path)
105
- artifact_file_path = util.uri_from_path(iter_path)
106
- target_artifact = wandb.Artifact._from_id(
107
- hex_to_b64_id(artifact_id), self.client.client
108
- )
106
+ curr_path: URIStr | FilePathStr | None = path
107
+ while curr_path and (parsed := urlparse(curr_path)).scheme == self._scheme:
108
+ artifact_id = hex_to_b64_id(parsed.netloc)
109
+ artifact_file_path = removeprefix(parsed.path, "/")
110
+
111
+ target_artifact = Artifact._from_id(artifact_id, self.client.client)
109
112
  assert target_artifact is not None
110
113
 
111
114
  entry = target_artifact.manifest.get_entry_by_path(artifact_file_path)
112
115
  assert entry is not None
113
- iter_path = entry.ref
116
+ curr_path = entry.ref
114
117
 
115
118
  # Create the path reference
116
119
  assert target_artifact is not None
117
120
  assert target_artifact.id is not None
118
- path = URIStr(
119
- f"{self._scheme}://{b64_to_hex_id(B64MD5(target_artifact.id))}/{artifact_file_path}"
121
+ path = (
122
+ f"{self._scheme}://{b64_to_hex_id(target_artifact.id)}/{artifact_file_path}"
120
123
  )
121
124
 
122
125
  # Return the new entry
@@ -3,7 +3,7 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import os
6
- from typing import TYPE_CHECKING, Sequence
6
+ from typing import TYPE_CHECKING, Literal
7
7
 
8
8
  import wandb
9
9
  from wandb import util
@@ -21,6 +21,8 @@ if TYPE_CHECKING:
21
21
  class WBLocalArtifactHandler(StorageHandler):
22
22
  """Handles loading and storing Artifact reference-type files."""
23
23
 
24
+ _scheme: Literal["wandb-client-artifact"]
25
+
24
26
  def __init__(self) -> None:
25
27
  self._scheme = "wandb-client-artifact"
26
28
 
@@ -43,7 +45,7 @@ class WBLocalArtifactHandler(StorageHandler):
43
45
  name: StrPath | None = None,
44
46
  checksum: bool = True,
45
47
  max_objects: int | None = None,
46
- ) -> Sequence[ArtifactManifestEntry]:
48
+ ) -> list[ArtifactManifestEntry]:
47
49
  """Store the file or directory at the given path within the specified artifact.
48
50
 
49
51
  Args:
@@ -0,0 +1,187 @@
1
+ """Helpers and constants for multipart upload and download."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import math
7
+ import threading
8
+ from concurrent.futures import FIRST_EXCEPTION, Executor, wait
9
+ from dataclasses import dataclass, field
10
+ from queue import Queue
11
+ from typing import Any, Final, Iterator, Union
12
+
13
+ from requests import Session
14
+ from typing_extensions import TypeAlias, TypeIs, final
15
+
16
+ from wandb import env
17
+ from wandb.sdk.artifacts.artifact_file_cache import Opener
18
+
19
logger = logging.getLogger(__name__)

# Binary (IEC) size units, in bytes.
KiB: Final[int] = 1024
MiB: Final[int] = 1024**2
GiB: Final[int] = 1024**3
TiB: Final[int] = 1024**4

# Upper bound on the number of parts in one multipart upload. S3 itself accepts up
# to 10,000 parts, but staying at or below 1,000 avoids having to make additional
# requests for the extra parts.
MAX_PARTS = 1_000
# Minimum file size for multipart upload.
MIN_MULTI_UPLOAD_SIZE = 2 * GiB
# Upper bound for multipart upload (5 TiB matches S3's maximum object size).
MAX_MULTI_UPLOAD_SIZE = 5 * TiB

# Minimum size to switch to multipart download, same threshold as upload.
MIN_MULTI_DOWNLOAD_SIZE = MIN_MULTI_UPLOAD_SIZE

# The multipart download part size matches the multipart upload part size, which
# is hard-coded to 100 MiB in the Go core (see the first link).
# https://github.com/wandb/wandb/blob/7b2a13cb8efcd553317167b823c8e52d8c3f7c4e/core/pkg/artifacts/saver.go#L496
# https://docs.aws.amazon.com/AmazonS3/latest/userguide/optimizing-performance-guidelines.html#optimizing-performance-guidelines-get-range
MULTI_DEFAULT_PART_SIZE = 100 * MiB

# Chunk size for reading HTTP responses and writing them to disk.
RSP_CHUNK_SIZE = 1 * MiB
41
+
42
+
43
@final
class _ChunkSentinel:
    """Marker item that terminates the multipart chunk queue.

    The file-writer consumer stops as soon as it dequeues this marker. Always use
    the module-level ``END_CHUNK`` instance instead of constructing new ones,
    because ``is_end_chunk`` compares by identity.

    NOTE: Identity only holds inside one process, so this marker is meant for
    multi-threaded (not multi-process) use; it is not guaranteed process-safe.
    """

    def __repr__(self) -> str:
        return "ChunkSentinel"


# The single sentinel instance put on the queue to signal "no more chunks".
END_CHUNK: Final[_ChunkSentinel] = _ChunkSentinel()


def is_end_chunk(obj: Any) -> TypeIs[_ChunkSentinel]:
    """Report whether ``obj`` is the ``END_CHUNK`` queue terminator.

    Written as a ``TypeIs`` guard so type checkers can narrow queue items.
    ``_ChunkSentinel`` is not a formal singleton, so identity is checked here.
    """
    found_terminator = obj is END_CHUNK
    return found_terminator
65
+
66
+
67
@dataclass(frozen=True)
class ChunkContent:
    """An immutable piece of downloaded data plus the file offset it belongs at."""

    # Declared by hand because `@dataclass(slots=True)` requires Python 3.10+.
    __slots__ = "offset", "data"

    offset: int  # byte position in the destination file where `data` starts
    data: bytes  # raw bytes to write at `offset`
72
+
73
+
74
# Any item the writer can dequeue: a data chunk, or the end-of-queue sentinel.
QueuedChunk: TypeAlias = Union[ChunkContent, _ChunkSentinel]
75
+
76
+
77
def should_multipart_download(size: int | None, override: bool | None = None) -> bool:
    """Decide whether a download of ``size`` bytes should be split into parts.

    An explicit ``override`` is returned as-is whenever it is not ``None``.
    Otherwise the file qualifies when its size is at least
    ``MIN_MULTI_DOWNLOAD_SIZE``; a missing size (``None``) is treated as 0.
    """
    if override is not None:
        return override
    effective_size = size if size else 0
    return effective_size >= MIN_MULTI_DOWNLOAD_SIZE
79
+
80
+
81
def calc_part_size(file_size: int, min_part_size: int = MULTI_DEFAULT_PART_SIZE) -> int:
    """Return the part size to use for a multipart transfer of ``file_size`` bytes.

    The result is the larger of two values:

    - ``min_part_size`` (100 MiB by default);
    - the smallest size that splits the file into at most ``MAX_PARTS`` parts.

    S3 allows up to 10,000 parts per upload. ``MAX_PARTS`` deliberately caps
    this at 1,000 so that no additional requests are needed for extra parts.

    Args:
        file_size: Total size of the file in bytes.
        min_part_size: Lower bound for the returned part size, in bytes.

    Returns:
        The part size in bytes.
    """
    # Size needed so that the whole file fits into MAX_PARTS parts.
    parts_limited_size = math.ceil(file_size / MAX_PARTS)
    return max(parts_limited_size, min_part_size)
84
+
85
+
86
def scan_chunks(path: str, chunk_size: int) -> Iterator[bytes]:
    """Yield the binary contents of ``path`` in pieces of ``chunk_size`` bytes.

    The final piece may be shorter. Iteration stops at end-of-file, and the file
    is closed once the generator finishes or is closed.
    """
    with open(path, "rb") as stream:
        while True:
            piece = stream.read(chunk_size)
            if not piece:
                return
            yield piece
90
+
91
+
92
@dataclass
class MultipartDownloadContext:
    """State shared between the writer task and the download tasks.

    Attributes:
        q: Hand-off queue carrying data chunks (and finally the end marker) from
            the download tasks to the single file writer.
        cancel: Event set by a failing task so the remaining tasks stop early.
    """

    q: Queue[QueuedChunk]
    cancel: threading.Event = field(default_factory=lambda: threading.Event())
96
+
97
+
98
def multipart_download(
    executor: Executor,
    session: Session,
    url: str,
    size: int,
    cached_open: Opener,
    part_size: int = MULTI_DEFAULT_PART_SIZE,
):
    """Download ``url`` as several byte-range parts fetched in parallel.

    Exactly one task writes to the destination file. Each part is fetched by its
    own task with a single ranged HTTP GET, and the response body is handed to
    the writer in ``RSP_CHUNK_SIZE`` pieces through a queue (max 500 items).

    Args:
        executor: Runs the writer task plus one task per part.
            NOTE(review): the writer occupies a worker for the whole download, so
            this presumably needs at least two workers -- confirm with callers.
        session: HTTP session used for the ranged GET requests.
        url: Location of the file to download.
        size: Total file size in bytes; determines which ranges are requested.
        cached_open: Opener for the destination, entered in ``"wb"`` mode around
            the whole download.
        part_size: Bytes requested per part (the last part may be shorter).

    Raises:
        requests.HTTPError: If any ranged GET returns an HTTP error status.
        Exception: Any other error raised by a download task or by the writer.
    """
    from queue import Full  # only needed for the cancellable puts below

    # Seconds to block on a full queue before re-checking for cancellation.
    put_poll_interval = 0.1

    # ------------------------------------------------------------------------------
    # Shared between threads
    ctx = MultipartDownloadContext(q=Queue(maxsize=500))

    def put_unless_cancelled(item: QueuedChunk) -> bool:
        """Enqueue ``item``, giving up once ``ctx.cancel`` is set.

        A plain blocking ``put`` on the bounded queue would hang forever if the
        writer already stopped consuming (e.g. after a disk error), which would
        also stall ``wait()`` below. Returns True if the item was enqueued.
        """
        while not ctx.cancel.is_set():
            try:
                ctx.q.put(item, timeout=put_poll_interval)
            except Full:
                continue
            return True
        return False

    # Open the destination first so its context manager can remove the temporary
    # file if a network or disk error escapes from the code below.
    with cached_open("wb") as f:

        def download_chunk(start: int, end: int | None = None) -> None:
            # Another task already failed; don't start a new request.
            if ctx.cancel.is_set():
                return

            # https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Range
            # Start and end are both inclusive; an empty end means "through the
            # end of the file", e.g. "bytes=0-499" or "bytes=500-".
            bytes_range = f"{start}-" if (end is None) else f"{start}-{end}"
            headers = {"Range": f"bytes={bytes_range}"}
            with session.get(url=url, headers=headers, stream=True) as rsp:
                # Without this check, the body of an HTTP error response would be
                # written into the file as if it were file data.
                rsp.raise_for_status()
                offset = start
                for chunk in rsp.iter_content(chunk_size=RSP_CHUNK_SIZE):
                    if not put_unless_cancelled(ChunkContent(offset=offset, data=chunk)):
                        return
                    offset += len(chunk)

        def write_chunks() -> None:
            # Stop once END_CHUNK arrives (all parts written) or another task failed.
            while not (ctx.cancel.is_set() or is_end_chunk(chunk := ctx.q.get())):
                try:
                    # NOTE: Seek works without pre-allocating the file on disk.
                    # It creates a sparse file (e.g. `ls -hl` can show a bigger
                    # size than `du -sh`) because parts arrive out of order rather
                    # than as one sequential write.
                    # See https://man7.org/linux/man-pages/man2/lseek.2.html
                    f.seek(chunk.offset)
                    f.write(chunk.data)
                except Exception as e:
                    if env.is_debug():
                        logger.debug("Error writing chunk to file: %s", e)
                    # Tell the download tasks to stop producing chunks.
                    ctx.cancel.set()
                    raise

        # Start the writer first so it drains the queue as parts arrive.
        write_future = executor.submit(write_chunks)

        # Submit one download task per part.
        download_futures = set()
        for start in range(0, size, part_size):
            # Inclusive last byte of this part; None for the final part so the
            # request reads through to the end of the file.
            last_byte = start + part_size - 1
            end = last_byte if last_byte < size else None
            download_futures.add(executor.submit(download_chunk, start=start, end=end))

        # Wait for every part, or stop at the first failure.
        done, not_done = wait(download_futures, return_when=FIRST_EXCEPTION)
        try:
            for fut in done:
                fut.result()
        except Exception as e:
            if env.is_debug():
                logger.debug("Error downloading file: %s", e)
            ctx.cancel.set()

            # Cancel any pending futures. Note:
            # - `Future.cancel()` does NOT stop a future that is already running,
            #   which is why `ctx.cancel` provides cooperative cancellation.
            # - Once Python 3.8 support is dropped, replace these `fut.cancel()`
            #   calls with `Executor.shutdown(cancel_futures=True)`.
            for fut in not_done:
                fut.cancel()
            raise
        finally:
            # Always signal the writer to stop; it may be blocked in `q.get()`.
            # If it has already exited, nothing will drain a full queue, so stop
            # trying instead of blocking forever.
            while not write_future.done():
                try:
                    ctx.q.put(END_CHUNK, timeout=put_poll_interval)
                except Full:
                    continue
                break
            write_future.result()