wandb 0.21.4__py3-none-macosx_12_0_arm64.whl → 0.22.1__py3-none-macosx_12_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +3 -3
- wandb/_pydantic/__init__.py +12 -11
- wandb/_pydantic/base.py +49 -19
- wandb/apis/__init__.py +2 -0
- wandb/apis/attrs.py +2 -0
- wandb/apis/importers/internals/internal.py +16 -23
- wandb/apis/internal.py +2 -0
- wandb/apis/normalize.py +2 -0
- wandb/apis/public/__init__.py +44 -1
- wandb/apis/public/api.py +215 -164
- wandb/apis/public/artifacts.py +23 -20
- wandb/apis/public/const.py +2 -0
- wandb/apis/public/files.py +33 -24
- wandb/apis/public/history.py +2 -0
- wandb/apis/public/jobs.py +20 -18
- wandb/apis/public/projects.py +4 -2
- wandb/apis/public/query_generator.py +3 -0
- wandb/apis/public/registries/__init__.py +7 -0
- wandb/apis/public/registries/_freezable_list.py +9 -12
- wandb/apis/public/registries/registries_search.py +8 -6
- wandb/apis/public/registries/registry.py +22 -17
- wandb/apis/public/reports.py +2 -0
- wandb/apis/public/runs.py +282 -60
- wandb/apis/public/sweeps.py +10 -9
- wandb/apis/public/teams.py +2 -0
- wandb/apis/public/users.py +2 -0
- wandb/apis/public/utils.py +16 -15
- wandb/automations/_generated/__init__.py +54 -127
- wandb/automations/_generated/create_generic_webhook_integration.py +1 -7
- wandb/automations/_generated/fragments.py +26 -91
- wandb/bin/gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/beta_sync.py +9 -11
- wandb/errors/errors.py +3 -3
- wandb/proto/v3/wandb_internal_pb2.py +234 -224
- wandb/proto/v3/wandb_sync_pb2.py +19 -6
- wandb/proto/v4/wandb_internal_pb2.py +226 -224
- wandb/proto/v4/wandb_sync_pb2.py +10 -6
- wandb/proto/v5/wandb_internal_pb2.py +226 -224
- wandb/proto/v5/wandb_sync_pb2.py +10 -6
- wandb/proto/v6/wandb_base_pb2.py +3 -3
- wandb/proto/v6/wandb_internal_pb2.py +229 -227
- wandb/proto/v6/wandb_server_pb2.py +3 -3
- wandb/proto/v6/wandb_settings_pb2.py +3 -3
- wandb/proto/v6/wandb_sync_pb2.py +13 -9
- wandb/proto/v6/wandb_telemetry_pb2.py +3 -3
- wandb/sdk/artifacts/_factories.py +7 -2
- wandb/sdk/artifacts/_generated/__init__.py +112 -412
- wandb/sdk/artifacts/_generated/fragments.py +65 -0
- wandb/sdk/artifacts/_generated/operations.py +52 -22
- wandb/sdk/artifacts/_generated/run_input_artifacts.py +3 -23
- wandb/sdk/artifacts/_generated/run_output_artifacts.py +3 -23
- wandb/sdk/artifacts/_generated/type_info.py +19 -0
- wandb/sdk/artifacts/_gqlutils.py +47 -0
- wandb/sdk/artifacts/_models/__init__.py +4 -0
- wandb/sdk/artifacts/_models/base_model.py +20 -0
- wandb/sdk/artifacts/_validators.py +40 -12
- wandb/sdk/artifacts/artifact.py +69 -88
- wandb/sdk/artifacts/artifact_file_cache.py +6 -1
- wandb/sdk/artifacts/artifact_manifest_entry.py +61 -2
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -3
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
- wandb/sdk/artifacts/storage_policies/_factories.py +63 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +69 -124
- wandb/sdk/data_types/bokeh.py +5 -1
- wandb/sdk/data_types/image.py +17 -6
- wandb/sdk/interface/interface.py +41 -4
- wandb/sdk/interface/interface_queue.py +10 -0
- wandb/sdk/interface/interface_shared.py +9 -7
- wandb/sdk/interface/interface_sock.py +9 -3
- wandb/sdk/internal/_generated/__init__.py +2 -12
- wandb/sdk/internal/sender.py +1 -1
- wandb/sdk/internal/settings_static.py +2 -82
- wandb/sdk/launch/runner/kubernetes_runner.py +25 -20
- wandb/sdk/launch/utils.py +82 -1
- wandb/sdk/lib/progress.py +7 -4
- wandb/sdk/lib/service/service_client.py +5 -9
- wandb/sdk/lib/service/service_connection.py +39 -23
- wandb/sdk/mailbox/mailbox_handle.py +2 -0
- wandb/sdk/projects/_generated/__init__.py +12 -33
- wandb/sdk/wandb_init.py +31 -3
- wandb/sdk/wandb_login.py +53 -27
- wandb/sdk/wandb_run.py +5 -3
- wandb/sdk/wandb_settings.py +50 -13
- wandb/sync/sync.py +7 -2
- wandb/util.py +1 -1
- wandb/wandb_agent.py +35 -4
- {wandb-0.21.4.dist-info → wandb-0.22.1.dist-info}/METADATA +1 -1
- {wandb-0.21.4.dist-info → wandb-0.22.1.dist-info}/RECORD +95 -91
- wandb/sdk/artifacts/_graphql_fragments.py +0 -19
- {wandb-0.21.4.dist-info → wandb-0.22.1.dist-info}/WHEEL +0 -0
- {wandb-0.21.4.dist-info → wandb-0.22.1.dist-info}/entry_points.txt +0 -0
- {wandb-0.21.4.dist-info → wandb-0.22.1.dist-info}/licenses/LICENSE +0 -0
@@ -16,7 +16,6 @@ from typing import IO, TYPE_CHECKING, Any, NamedTuple, Sequence
|
|
16
16
|
from urllib.parse import quote
|
17
17
|
|
18
18
|
import requests
|
19
|
-
import urllib3
|
20
19
|
|
21
20
|
from wandb import env
|
22
21
|
from wandb.errors.term import termwarn
|
@@ -27,40 +26,24 @@ from wandb.sdk.artifacts.artifact_file_cache import (
|
|
27
26
|
get_artifact_file_cache,
|
28
27
|
)
|
29
28
|
from wandb.sdk.artifacts.staging import get_staging_dir
|
30
|
-
from wandb.sdk.artifacts.storage_handlers.azure_handler import AzureHandler
|
31
|
-
from wandb.sdk.artifacts.storage_handlers.gcs_handler import GCSHandler
|
32
|
-
from wandb.sdk.artifacts.storage_handlers.http_handler import HTTPHandler
|
33
|
-
from wandb.sdk.artifacts.storage_handlers.local_file_handler import LocalFileHandler
|
34
29
|
from wandb.sdk.artifacts.storage_handlers.multi_handler import MultiHandler
|
35
|
-
from wandb.sdk.artifacts.storage_handlers.s3_handler import S3Handler
|
36
30
|
from wandb.sdk.artifacts.storage_handlers.tracking_handler import TrackingHandler
|
37
|
-
from wandb.sdk.artifacts.storage_handlers.wb_artifact_handler import WBArtifactHandler
|
38
|
-
from wandb.sdk.artifacts.storage_handlers.wb_local_artifact_handler import (
|
39
|
-
WBLocalArtifactHandler,
|
40
|
-
)
|
41
31
|
from wandb.sdk.artifacts.storage_layout import StorageLayout
|
42
32
|
from wandb.sdk.artifacts.storage_policies.register import WANDB_STORAGE_POLICY
|
43
33
|
from wandb.sdk.artifacts.storage_policy import StoragePolicy
|
44
34
|
from wandb.sdk.internal.internal_api import Api as InternalApi
|
45
35
|
from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
|
46
|
-
from wandb.sdk.lib.hashutil import
|
36
|
+
from wandb.sdk.lib.hashutil import b64_to_hex_id, hex_to_b64_id
|
47
37
|
from wandb.sdk.lib.paths import FilePathStr, URIStr
|
48
38
|
|
39
|
+
from ._factories import make_http_session, make_storage_handlers
|
40
|
+
|
49
41
|
if TYPE_CHECKING:
|
50
42
|
from wandb.filesync.step_prepare import StepPrepare
|
51
43
|
from wandb.sdk.artifacts.artifact import Artifact
|
52
44
|
from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
|
53
45
|
from wandb.sdk.internal import progress
|
54
46
|
|
55
|
-
# Sleep length: 0, 2, 4, 8, 16, 32, 64, 120, 120, 120, 120, 120, 120, 120, 120, 120
|
56
|
-
# seconds, i.e. a total of 20min 6s.
|
57
|
-
_REQUEST_RETRY_STRATEGY = urllib3.util.retry.Retry(
|
58
|
-
backoff_factor=1,
|
59
|
-
total=16,
|
60
|
-
status_forcelist=(308, 408, 409, 429, 500, 502, 503, 504),
|
61
|
-
)
|
62
|
-
_REQUEST_POOL_CONNECTIONS = 64
|
63
|
-
_REQUEST_POOL_MAXSIZE = 64
|
64
47
|
|
65
48
|
# AWS S3 max upload parts without having to make additional requests for extra parts
|
66
49
|
S3_MAX_PART_NUMBERS = 1000
|
@@ -96,51 +79,36 @@ class WandbStoragePolicy(StoragePolicy):
|
|
96
79
|
|
97
80
|
@classmethod
|
98
81
|
def from_config(
|
99
|
-
cls, config: dict, api: InternalApi | None = None
|
82
|
+
cls, config: dict[str, Any], api: InternalApi | None = None
|
100
83
|
) -> WandbStoragePolicy:
|
101
84
|
return cls(config=config, api=api)
|
102
85
|
|
103
86
|
def __init__(
|
104
87
|
self,
|
105
|
-
config: dict | None = None,
|
88
|
+
config: dict[str, Any] | None = None,
|
106
89
|
cache: ArtifactFileCache | None = None,
|
107
90
|
api: InternalApi | None = None,
|
91
|
+
session: requests.Session | None = None,
|
108
92
|
) -> None:
|
109
|
-
self._cache = cache or get_artifact_file_cache()
|
110
93
|
self._config = config or {}
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
pool_maxsize=_REQUEST_POOL_MAXSIZE,
|
116
|
-
)
|
117
|
-
self._session.mount("http://", adapter)
|
118
|
-
self._session.mount("https://", adapter)
|
119
|
-
|
120
|
-
s3 = S3Handler()
|
121
|
-
gcs = GCSHandler()
|
122
|
-
azure = AzureHandler()
|
123
|
-
http = HTTPHandler(self._session)
|
124
|
-
https = HTTPHandler(self._session, scheme="https")
|
125
|
-
artifact = WBArtifactHandler()
|
126
|
-
local_artifact = WBLocalArtifactHandler()
|
127
|
-
file_handler = LocalFileHandler()
|
128
|
-
|
94
|
+
if (storage_region := self._config.get("storageRegion")) is not None:
|
95
|
+
self._validate_storage_region(storage_region)
|
96
|
+
self._cache = cache or get_artifact_file_cache()
|
97
|
+
self._session = session or make_http_session()
|
129
98
|
self._api = api or InternalApi()
|
130
99
|
self._handler = MultiHandler(
|
131
|
-
handlers=
|
132
|
-
s3,
|
133
|
-
gcs,
|
134
|
-
azure,
|
135
|
-
http,
|
136
|
-
https,
|
137
|
-
artifact,
|
138
|
-
local_artifact,
|
139
|
-
file_handler,
|
140
|
-
],
|
100
|
+
handlers=make_storage_handlers(self._session),
|
141
101
|
default_handler=TrackingHandler(),
|
142
102
|
)
|
143
103
|
|
104
|
+
def _validate_storage_region(self, storage_region: Any) -> None:
|
105
|
+
if not isinstance(storage_region, str):
|
106
|
+
raise TypeError(
|
107
|
+
f"storageRegion must be a string, got {type(storage_region).__name__}: {storage_region!r}"
|
108
|
+
)
|
109
|
+
if not storage_region.strip():
|
110
|
+
raise ValueError("storageRegion must be a non-empty string")
|
111
|
+
|
144
112
|
def config(self) -> dict:
|
145
113
|
return self._config
|
146
114
|
|
@@ -167,54 +135,52 @@ class WandbStoragePolicy(StoragePolicy):
|
|
167
135
|
self._cache._override_cache_path = dest_path
|
168
136
|
|
169
137
|
path, hit, cache_open = self._cache.check_md5_obj_path(
|
170
|
-
|
171
|
-
|
138
|
+
manifest_entry.digest,
|
139
|
+
size=manifest_entry.size or 0,
|
172
140
|
)
|
173
141
|
if hit:
|
174
142
|
return path
|
175
143
|
|
176
|
-
if manifest_entry._download_url is not None:
|
144
|
+
if (url := manifest_entry._download_url) is not None:
|
177
145
|
# Use multipart parallel download for large file
|
178
146
|
if (
|
179
|
-
executor
|
180
|
-
and manifest_entry.size
|
181
|
-
and self._should_multipart_download(
|
147
|
+
executor
|
148
|
+
and (size := manifest_entry.size)
|
149
|
+
and self._should_multipart_download(size, multipart)
|
182
150
|
):
|
183
|
-
self._multipart_file_download(
|
184
|
-
executor,
|
185
|
-
manifest_entry._download_url,
|
186
|
-
manifest_entry.size,
|
187
|
-
cache_open,
|
188
|
-
)
|
151
|
+
self._multipart_file_download(executor, url, size, cache_open)
|
189
152
|
return path
|
153
|
+
|
190
154
|
# Serial download
|
191
|
-
response = self._session.get(manifest_entry._download_url, stream=True)
|
192
155
|
try:
|
193
|
-
response.
|
194
|
-
except
|
156
|
+
response = self._session.get(url, stream=True)
|
157
|
+
except requests.HTTPError:
|
195
158
|
# Signed URL might have expired, fall back to fetching it one by one.
|
196
159
|
manifest_entry._download_url = None
|
160
|
+
|
197
161
|
if manifest_entry._download_url is None:
|
198
162
|
auth = None
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
163
|
+
headers = _thread_local_api_settings.headers
|
164
|
+
cookies = _thread_local_api_settings.cookies
|
165
|
+
|
166
|
+
# For auth, prefer using (in order): auth header, cookies, HTTP Basic Auth
|
167
|
+
if token := self._api.access_token:
|
168
|
+
headers = {**(headers or {}), "Authorization": f"Bearer {token}"}
|
169
|
+
elif cookies is not None:
|
170
|
+
pass
|
171
|
+
else:
|
203
172
|
auth = ("api", self._api.api_key or "")
|
173
|
+
|
174
|
+
file_url = self._file_url(
|
175
|
+
self._api,
|
176
|
+
artifact.entity,
|
177
|
+
artifact.project,
|
178
|
+
artifact.name.split(":")[0],
|
179
|
+
manifest_entry,
|
180
|
+
)
|
204
181
|
response = self._session.get(
|
205
|
-
|
206
|
-
self._api,
|
207
|
-
artifact.entity,
|
208
|
-
artifact.project,
|
209
|
-
artifact.name.split(":")[0],
|
210
|
-
manifest_entry,
|
211
|
-
),
|
212
|
-
auth=auth,
|
213
|
-
cookies=_thread_local_api_settings.cookies,
|
214
|
-
headers=http_headers,
|
215
|
-
stream=True,
|
182
|
+
file_url, auth=auth, cookies=cookies, headers=headers, stream=True
|
216
183
|
)
|
217
|
-
response.raise_for_status()
|
218
184
|
|
219
185
|
with cache_open(mode="wb") as file:
|
220
186
|
for data in response.iter_content(chunk_size=16 * 1024):
|
@@ -269,12 +235,7 @@ class WandbStoragePolicy(StoragePolicy):
|
|
269
235
|
# Other threads has error, no need to start
|
270
236
|
if download_has_error.is_set():
|
271
237
|
return
|
272
|
-
response = self._session.get(
|
273
|
-
url=download_url,
|
274
|
-
headers=headers,
|
275
|
-
stream=True,
|
276
|
-
)
|
277
|
-
response.raise_for_status()
|
238
|
+
response = self._session.get(url=download_url, headers=headers, stream=True)
|
278
239
|
|
279
240
|
file_offset = start
|
280
241
|
for content in response.iter_content(chunk_size=_HTTP_RES_CHUNK_SIZE_BYTES):
|
@@ -376,43 +337,27 @@ class WandbStoragePolicy(StoragePolicy):
|
|
376
337
|
entity_name: str,
|
377
338
|
project_name: str,
|
378
339
|
artifact_name: str,
|
379
|
-
|
340
|
+
entry: ArtifactManifestEntry,
|
380
341
|
) -> str:
|
381
|
-
|
382
|
-
|
383
|
-
md5_hex = b64_to_hex_id(
|
342
|
+
layout = self._config.get("storageLayout", StorageLayout.V1)
|
343
|
+
region = self._config.get("storageRegion", "default")
|
344
|
+
md5_hex = b64_to_hex_id(entry.digest)
|
384
345
|
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
346
|
+
base_url: str = api.settings("base_url")
|
347
|
+
|
348
|
+
if layout == StorageLayout.V1:
|
349
|
+
return f"{base_url}/artifacts/{entity_name}/{md5_hex}"
|
350
|
+
|
351
|
+
if layout == StorageLayout.V2:
|
352
|
+
birth_artifact_id = entry.birth_artifact_id or ""
|
390
353
|
if api._server_supports(
|
391
|
-
ServerFeature.ARTIFACT_COLLECTION_MEMBERSHIP_FILE_DOWNLOAD_HANDLER
|
354
|
+
ServerFeature.ARTIFACT_COLLECTION_MEMBERSHIP_FILE_DOWNLOAD_HANDLER
|
392
355
|
):
|
393
|
-
return "{}/artifactsV2/{}/{}/{}/{}/{}/{}/{}"
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
quote(artifact_name),
|
399
|
-
quote(manifest_entry.birth_artifact_id or ""),
|
400
|
-
md5_hex,
|
401
|
-
manifest_entry.path.name,
|
402
|
-
)
|
403
|
-
return "{}/artifactsV2/{}/{}/{}/{}".format(
|
404
|
-
api.settings("base_url"),
|
405
|
-
storage_region,
|
406
|
-
entity_name,
|
407
|
-
quote(
|
408
|
-
manifest_entry.birth_artifact_id
|
409
|
-
if manifest_entry.birth_artifact_id is not None
|
410
|
-
else ""
|
411
|
-
),
|
412
|
-
md5_hex,
|
413
|
-
)
|
414
|
-
else:
|
415
|
-
raise Exception(f"unrecognized storage layout: {storage_layout}")
|
356
|
+
return f"{base_url}/artifactsV2/{region}/{quote(entity_name)}/{quote(project_name)}/{quote(artifact_name)}/{quote(birth_artifact_id)}/{md5_hex}/{entry.path.name}"
|
357
|
+
|
358
|
+
return f"{base_url}/artifactsV2/{region}/{entity_name}/{quote(birth_artifact_id)}/{md5_hex}"
|
359
|
+
|
360
|
+
raise ValueError(f"unrecognized storage layout: {layout!r}")
|
416
361
|
|
417
362
|
def s3_multipart_file_upload(
|
418
363
|
self,
|
@@ -486,7 +431,7 @@ class WandbStoragePolicy(StoragePolicy):
|
|
486
431
|
True if the file was a duplicate (did not need to be uploaded),
|
487
432
|
False if it needed to be uploaded or was a reference (nothing to dedupe).
|
488
433
|
"""
|
489
|
-
file_size = entry.size
|
434
|
+
file_size = entry.size or 0
|
490
435
|
chunk_size = self.calc_chunk_size(file_size)
|
491
436
|
upload_parts = []
|
492
437
|
hex_digests = {}
|
@@ -562,8 +507,8 @@ class WandbStoragePolicy(StoragePolicy):
|
|
562
507
|
|
563
508
|
# Cache upon successful upload.
|
564
509
|
_, hit, cache_open = self._cache.check_md5_obj_path(
|
565
|
-
|
566
|
-
|
510
|
+
entry.digest,
|
511
|
+
size=entry.size or 0,
|
567
512
|
)
|
568
513
|
|
569
514
|
staging_dir = get_staging_dir()
|
wandb/sdk/data_types/bokeh.py
CHANGED
@@ -5,6 +5,7 @@ import pathlib
|
|
5
5
|
from typing import TYPE_CHECKING, Union
|
6
6
|
|
7
7
|
from wandb import util
|
8
|
+
from wandb._strutils import nameof
|
8
9
|
from wandb.sdk.lib import runid
|
9
10
|
|
10
11
|
from . import _dtypes
|
@@ -34,7 +35,10 @@ class Bokeh(Media):
|
|
34
35
|
],
|
35
36
|
):
|
36
37
|
super().__init__()
|
37
|
-
bokeh = util.get_module(
|
38
|
+
bokeh = util.get_module(
|
39
|
+
"bokeh",
|
40
|
+
required=f"{nameof(Bokeh)!r} requires the bokeh package. Please install it with `pip install bokeh`.",
|
41
|
+
)
|
38
42
|
if isinstance(data_or_path, (str, pathlib.Path)) and os.path.exists(
|
39
43
|
data_or_path
|
40
44
|
):
|
wandb/sdk/data_types/image.py
CHANGED
@@ -161,6 +161,17 @@ class Image(BatchableMedia):
|
|
161
161
|
) -> None:
|
162
162
|
"""Initialize a `wandb.Image` object.
|
163
163
|
|
164
|
+
This class handles various image data formats and automatically normalizes
|
165
|
+
pixel values to the range [0, 255] when needed, ensuring compatibility
|
166
|
+
with the W&B backend.
|
167
|
+
|
168
|
+
* Data in range [0, 1] is multiplied by 255 and converted to uint8
|
169
|
+
* Data in range [-1, 1] is rescaled from [-1, 1] to [0, 255] by mapping
|
170
|
+
-1 to 0 and 1 to 255, then converted to uint8
|
171
|
+
* Data outside [-1, 1] but not in [0, 255] is clipped to [0, 255] and
|
172
|
+
converted to uint8 (with a warning if values fall outside [0, 255])
|
173
|
+
* Data already in [0, 255] is converted to uint8 without modification
|
174
|
+
|
164
175
|
Args:
|
165
176
|
data_or_path: Accepts NumPy array/pytorch tensor of image data,
|
166
177
|
a PIL image object, or a path to an image file. If a NumPy
|
@@ -168,7 +179,7 @@ class Image(BatchableMedia):
|
|
168
179
|
the image data will be saved to the given file type.
|
169
180
|
If the values are not in the range [0, 255] or all values are in the range [0, 1],
|
170
181
|
the image pixel values will be normalized to the range [0, 255]
|
171
|
-
unless `normalize` is set to False
|
182
|
+
unless `normalize` is set to `False`.
|
172
183
|
- pytorch tensor should be in the format (channel, height, width)
|
173
184
|
- NumPy array should be in the format (height, width, channel)
|
174
185
|
mode: The PIL mode for an image. Most common are "L", "RGB",
|
@@ -178,13 +189,13 @@ class Image(BatchableMedia):
|
|
178
189
|
classes: A list of class information for the image,
|
179
190
|
used for labeling bounding boxes, and image masks.
|
180
191
|
boxes: A dictionary containing bounding box information for the image.
|
181
|
-
see
|
192
|
+
see https://docs.wandb.ai/ref/python/data-types/boundingboxes2d/
|
182
193
|
masks: A dictionary containing mask information for the image.
|
183
|
-
see
|
194
|
+
see https://docs.wandb.ai/ref/python/data-types/imagemask/
|
184
195
|
file_type: The file type to save the image as.
|
185
|
-
This parameter has no effect if data_or_path is a path to an image file.
|
186
|
-
normalize: If True
|
187
|
-
Normalize is only applied if data_or_path is a numpy array or pytorch tensor.
|
196
|
+
This parameter has no effect if `data_or_path` is a path to an image file.
|
197
|
+
normalize: If `True`, normalize the image pixel values to fall within the range of [0, 255].
|
198
|
+
Normalize is only applied if `data_or_path` is a numpy array or pytorch tensor.
|
188
199
|
|
189
200
|
Examples:
|
190
201
|
Create a wandb.Image from a numpy array
|
wandb/sdk/interface/interface.py
CHANGED
@@ -87,11 +87,38 @@ def file_enum_to_policy(enum: "pb.FilesItem.PolicyType.V") -> "PolicyName":
|
|
87
87
|
|
88
88
|
|
89
89
|
class InterfaceBase:
|
90
|
+
"""Methods for sending different types of Records to the service.
|
91
|
+
|
92
|
+
None of the methods may be called from an asyncio context other than
|
93
|
+
deliver_async().
|
94
|
+
"""
|
95
|
+
|
90
96
|
_drop: bool
|
91
97
|
|
92
98
|
def __init__(self) -> None:
|
93
99
|
self._drop = False
|
94
100
|
|
101
|
+
@abstractmethod
|
102
|
+
async def deliver_async(
|
103
|
+
self,
|
104
|
+
record: pb.Record,
|
105
|
+
) -> MailboxHandle[pb.Result]:
|
106
|
+
"""Send a record and create a handle to wait for the response.
|
107
|
+
|
108
|
+
The synchronous publish and deliver methods on this class cannot be
|
109
|
+
called in the asyncio thread because they block. Instead of having
|
110
|
+
an async copy of every method, this is a general method for sending
|
111
|
+
any kind of record in the asyncio thread.
|
112
|
+
|
113
|
+
Args:
|
114
|
+
record: The record to send. This method takes ownership of the
|
115
|
+
record and it must not be used afterward.
|
116
|
+
|
117
|
+
Returns:
|
118
|
+
A handle to wait for a response to the record.
|
119
|
+
"""
|
120
|
+
raise NotImplementedError
|
121
|
+
|
95
122
|
def publish_header(self) -> None:
|
96
123
|
header = pb.HeaderRecord()
|
97
124
|
self._publish_header(header)
|
@@ -392,9 +419,13 @@ class InterfaceBase:
|
|
392
419
|
proto_manifest.manifest_file_path = path
|
393
420
|
return proto_manifest
|
394
421
|
|
422
|
+
# Set storage policy on storageLayout (always V2) and storageRegion, only allow coreweave-us on wandb.ai for now.
|
423
|
+
# NOTE: the decode logic is NewManifestFromProto in core/pkg/artifacts/manifest.go
|
424
|
+
# The creation logic is in artifacts/_factories.py make_storage_policy
|
395
425
|
for k, v in artifact_manifest.storage_policy.config().items() or {}.items():
|
396
426
|
cfg = proto_manifest.storage_policy_config.add()
|
397
427
|
cfg.key = k
|
428
|
+
# TODO: Why json.dumps when existing values are plain string? We want to send complex structure without defining the proto?
|
398
429
|
cfg.value_json = json.dumps(v)
|
399
430
|
|
400
431
|
for entry in sorted(artifact_manifest.entries.values(), key=lambda k: k.path):
|
@@ -883,6 +914,16 @@ class InterfaceBase:
|
|
883
914
|
) -> MailboxHandle[pb.Result]:
|
884
915
|
raise NotImplementedError
|
885
916
|
|
917
|
+
def publish_probe_system_info(self) -> None:
|
918
|
+
probe_system_info = pb.ProbeSystemInfoRequest()
|
919
|
+
return self._publish_probe_system_info(probe_system_info)
|
920
|
+
|
921
|
+
@abstractmethod
|
922
|
+
def _publish_probe_system_info(
|
923
|
+
self, probe_system_info: pb.ProbeSystemInfoRequest
|
924
|
+
) -> None:
|
925
|
+
raise NotImplementedError
|
926
|
+
|
886
927
|
def join(self) -> None:
|
887
928
|
# Drop indicates that the internal process has already been shutdown
|
888
929
|
if self._drop:
|
@@ -1010,10 +1051,6 @@ class InterfaceBase:
|
|
1010
1051
|
) -> MailboxHandle[pb.Result]:
|
1011
1052
|
raise NotImplementedError
|
1012
1053
|
|
1013
|
-
@abstractmethod
|
1014
|
-
def deliver_operation_stats(self) -> MailboxHandle[pb.Result]:
|
1015
|
-
raise NotImplementedError
|
1016
|
-
|
1017
1054
|
def deliver_poll_exit(self) -> MailboxHandle[pb.Result]:
|
1018
1055
|
poll_exit = pb.PollExitRequest()
|
1019
1056
|
return self._deliver_poll_exit(poll_exit)
|
@@ -8,12 +8,15 @@ import logging
|
|
8
8
|
from multiprocessing.process import BaseProcess
|
9
9
|
from typing import TYPE_CHECKING, Optional
|
10
10
|
|
11
|
+
from typing_extensions import override
|
12
|
+
|
11
13
|
from .interface_shared import InterfaceShared
|
12
14
|
|
13
15
|
if TYPE_CHECKING:
|
14
16
|
from queue import Queue
|
15
17
|
|
16
18
|
from wandb.proto import wandb_internal_pb2 as pb
|
19
|
+
from wandb.sdk.mailbox.mailbox_handle import MailboxHandle
|
17
20
|
|
18
21
|
|
19
22
|
logger = logging.getLogger("wandb")
|
@@ -31,6 +34,13 @@ class InterfaceQueue(InterfaceShared):
|
|
31
34
|
self._process = process
|
32
35
|
super().__init__()
|
33
36
|
|
37
|
+
@override
|
38
|
+
async def deliver_async(
|
39
|
+
self,
|
40
|
+
record: "pb.Record",
|
41
|
+
) -> "MailboxHandle[pb.Result]":
|
42
|
+
raise NotImplementedError
|
43
|
+
|
34
44
|
def _publish(self, record: "pb.Record", local: Optional[bool] = None) -> None:
|
35
45
|
if self._process and not self._process.is_alive():
|
36
46
|
raise Exception("The wandb backend process has shutdown")
|
@@ -87,7 +87,6 @@ class InterfaceShared(InterfaceBase):
|
|
87
87
|
stop_status: Optional[pb.StopStatusRequest] = None,
|
88
88
|
internal_messages: Optional[pb.InternalMessagesRequest] = None,
|
89
89
|
network_status: Optional[pb.NetworkStatusRequest] = None,
|
90
|
-
operation_stats: Optional[pb.OperationStatsRequest] = None,
|
91
90
|
poll_exit: Optional[pb.PollExitRequest] = None,
|
92
91
|
partial_history: Optional[pb.PartialHistoryRequest] = None,
|
93
92
|
sampled_history: Optional[pb.SampledHistoryRequest] = None,
|
@@ -112,6 +111,7 @@ class InterfaceShared(InterfaceBase):
|
|
112
111
|
python_packages: Optional[pb.PythonPackagesRequest] = None,
|
113
112
|
job_input: Optional[pb.JobInputRequest] = None,
|
114
113
|
run_finish_without_exit: Optional[pb.RunFinishWithoutExitRequest] = None,
|
114
|
+
probe_system_info: Optional[pb.ProbeSystemInfoRequest] = None,
|
115
115
|
) -> pb.Record:
|
116
116
|
request = pb.Request()
|
117
117
|
if get_summary:
|
@@ -128,8 +128,6 @@ class InterfaceShared(InterfaceBase):
|
|
128
128
|
request.internal_messages.CopyFrom(internal_messages)
|
129
129
|
elif network_status:
|
130
130
|
request.network_status.CopyFrom(network_status)
|
131
|
-
elif operation_stats:
|
132
|
-
request.operations.CopyFrom(operation_stats)
|
133
131
|
elif poll_exit:
|
134
132
|
request.poll_exit.CopyFrom(poll_exit)
|
135
133
|
elif partial_history:
|
@@ -178,6 +176,8 @@ class InterfaceShared(InterfaceBase):
|
|
178
176
|
request.job_input.CopyFrom(job_input)
|
179
177
|
elif run_finish_without_exit:
|
180
178
|
request.run_finish_without_exit.CopyFrom(run_finish_without_exit)
|
179
|
+
elif probe_system_info:
|
180
|
+
request.probe_system_info.CopyFrom(probe_system_info)
|
181
181
|
else:
|
182
182
|
raise Exception("Invalid request")
|
183
183
|
record = self._make_record(request=request)
|
@@ -330,6 +330,12 @@ class InterfaceShared(InterfaceBase):
|
|
330
330
|
rec = self._make_record(use_artifact=use_artifact)
|
331
331
|
self._publish(rec)
|
332
332
|
|
333
|
+
def _publish_probe_system_info(
|
334
|
+
self, probe_system_info: pb.ProbeSystemInfoRequest
|
335
|
+
) -> None:
|
336
|
+
record = self._make_request(probe_system_info=probe_system_info)
|
337
|
+
self._publish(record)
|
338
|
+
|
333
339
|
def _deliver_artifact(
|
334
340
|
self,
|
335
341
|
log_artifact: pb.LogArtifactRequest,
|
@@ -415,10 +421,6 @@ class InterfaceShared(InterfaceBase):
|
|
415
421
|
record = self._make_record(exit=exit_data)
|
416
422
|
return self._deliver(record)
|
417
423
|
|
418
|
-
def deliver_operation_stats(self):
|
419
|
-
record = self._make_request(operation_stats=pb.OperationStatsRequest())
|
420
|
-
return self._deliver(record)
|
421
|
-
|
422
424
|
def _deliver_poll_exit(
|
423
425
|
self,
|
424
426
|
poll_exit: pb.PollExitRequest,
|
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any
|
|
6
6
|
from typing_extensions import override
|
7
7
|
|
8
8
|
from wandb.proto import wandb_server_pb2 as spb
|
9
|
+
from wandb.sdk.lib import asyncio_manager
|
9
10
|
|
10
11
|
from .interface_shared import InterfaceShared
|
11
12
|
|
@@ -21,10 +22,12 @@ logger = logging.getLogger("wandb")
|
|
21
22
|
class InterfaceSock(InterfaceShared):
|
22
23
|
def __init__(
|
23
24
|
self,
|
25
|
+
asyncer: asyncio_manager.AsyncioManager,
|
24
26
|
client: ServiceClient,
|
25
27
|
stream_id: str,
|
26
28
|
) -> None:
|
27
29
|
super().__init__()
|
30
|
+
self._asyncer = asyncer
|
28
31
|
self._client = client
|
29
32
|
self._stream_id = stream_id
|
30
33
|
|
@@ -37,13 +40,16 @@ class InterfaceSock(InterfaceShared):
|
|
37
40
|
self._assign(record)
|
38
41
|
request = spb.ServerRequest()
|
39
42
|
request.record_publish.CopyFrom(record)
|
40
|
-
self._client.publish(request)
|
43
|
+
self._asyncer.run(lambda: self._client.publish(request))
|
41
44
|
|
42
|
-
@override
|
43
45
|
def _deliver(self, record: pb.Record) -> MailboxHandle[pb.Result]:
|
46
|
+
return self._asyncer.run(lambda: self.deliver_async(record))
|
47
|
+
|
48
|
+
@override
|
49
|
+
async def deliver_async(self, record: pb.Record) -> MailboxHandle[pb.Result]:
|
44
50
|
self._assign(record)
|
45
51
|
request = spb.ServerRequest()
|
46
52
|
request.record_publish.CopyFrom(record)
|
47
53
|
|
48
|
-
handle = self._client.deliver(request)
|
54
|
+
handle = await self._client.deliver(request)
|
49
55
|
return handle.map(lambda response: response.result_communicate)
|
@@ -1,15 +1,5 @@
|
|
1
1
|
# Generated by ariadne-codegen
|
2
2
|
|
3
|
+
__all__ = ["SERVER_FEATURES_QUERY_GQL", "ServerFeaturesQuery"]
|
3
4
|
from .operations import SERVER_FEATURES_QUERY_GQL
|
4
|
-
from .server_features_query import
|
5
|
-
ServerFeaturesQuery,
|
6
|
-
ServerFeaturesQueryServerInfo,
|
7
|
-
ServerFeaturesQueryServerInfoFeatures,
|
8
|
-
)
|
9
|
-
|
10
|
-
__all__ = [
|
11
|
-
"SERVER_FEATURES_QUERY_GQL",
|
12
|
-
"ServerFeaturesQuery",
|
13
|
-
"ServerFeaturesQueryServerInfo",
|
14
|
-
"ServerFeaturesQueryServerInfoFeatures",
|
15
|
-
]
|
5
|
+
from .server_features_query import ServerFeaturesQuery
|
wandb/sdk/internal/sender.py
CHANGED
@@ -343,7 +343,7 @@ class SendManager:
|
|
343
343
|
publish_interface = InterfaceQueue(record_q=record_q)
|
344
344
|
context_keeper = context.ContextKeeper()
|
345
345
|
return SendManager(
|
346
|
-
settings=SettingsStatic(settings
|
346
|
+
settings=SettingsStatic(dict(settings)),
|
347
347
|
record_q=record_q,
|
348
348
|
result_q=result_q,
|
349
349
|
interface=publish_interface,
|