wandb 0.19.10__py3-none-any.whl → 0.19.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +3 -3
- wandb/_pydantic/__init__.py +2 -3
- wandb/_pydantic/base.py +11 -31
- wandb/_pydantic/utils.py +8 -1
- wandb/_pydantic/v1_compat.py +3 -3
- wandb/apis/public/api.py +590 -22
- wandb/apis/public/artifacts.py +13 -5
- wandb/apis/public/automations.py +1 -1
- wandb/apis/public/integrations.py +22 -10
- wandb/apis/public/registries/__init__.py +0 -0
- wandb/apis/public/registries/_freezable_list.py +179 -0
- wandb/apis/public/{registries.py → registries/registries_search.py} +22 -129
- wandb/apis/public/registries/registry.py +357 -0
- wandb/apis/public/registries/utils.py +140 -0
- wandb/apis/public/runs.py +58 -56
- wandb/automations/__init__.py +16 -24
- wandb/automations/_filters/expressions.py +12 -10
- wandb/automations/_filters/operators.py +10 -19
- wandb/automations/_filters/run_metrics.py +231 -82
- wandb/automations/_generated/__init__.py +27 -34
- wandb/automations/_generated/create_automation.py +17 -0
- wandb/automations/_generated/delete_automation.py +17 -0
- wandb/automations/_generated/fragments.py +40 -25
- wandb/automations/_generated/{get_triggers.py → get_automations.py} +5 -5
- wandb/automations/_generated/{get_triggers_by_entity.py → get_automations_by_entity.py} +7 -5
- wandb/automations/_generated/operations.py +35 -98
- wandb/automations/_generated/update_automation.py +17 -0
- wandb/automations/_utils.py +178 -64
- wandb/automations/_validators.py +94 -2
- wandb/automations/actions.py +113 -98
- wandb/automations/automations.py +47 -69
- wandb/automations/events.py +139 -87
- wandb/automations/integrations.py +23 -4
- wandb/automations/scopes.py +22 -20
- wandb/bin/gpu_stats +0 -0
- wandb/env.py +11 -0
- wandb/old/settings.py +4 -1
- wandb/proto/v3/wandb_internal_pb2.py +240 -236
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +236 -236
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v5/wandb_internal_pb2.py +236 -236
- wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v6/wandb_internal_pb2.py +236 -236
- wandb/proto/v6/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/artifacts/_generated/__init__.py +42 -1
- wandb/sdk/artifacts/_generated/add_aliases.py +21 -0
- wandb/sdk/artifacts/_generated/delete_aliases.py +21 -0
- wandb/sdk/artifacts/_generated/fetch_linked_artifacts.py +67 -0
- wandb/sdk/artifacts/_generated/fragments.py +35 -0
- wandb/sdk/artifacts/_generated/input_types.py +12 -0
- wandb/sdk/artifacts/_generated/operations.py +101 -0
- wandb/sdk/artifacts/_generated/update_artifact.py +26 -0
- wandb/sdk/artifacts/_graphql_fragments.py +1 -0
- wandb/sdk/artifacts/_validators.py +120 -1
- wandb/sdk/artifacts/artifact.py +380 -203
- wandb/sdk/artifacts/artifact_file_cache.py +4 -6
- wandb/sdk/artifacts/artifact_manifest_entry.py +11 -2
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +182 -1
- wandb/sdk/artifacts/storage_policy.py +3 -0
- wandb/sdk/data_types/video.py +46 -32
- wandb/sdk/interface/interface.py +2 -3
- wandb/sdk/internal/internal_api.py +21 -31
- wandb/sdk/internal/sender.py +5 -2
- wandb/sdk/launch/sweeps/utils.py +8 -0
- wandb/sdk/projects/_generated/__init__.py +47 -0
- wandb/sdk/projects/_generated/delete_project.py +22 -0
- wandb/sdk/projects/_generated/enums.py +4 -0
- wandb/sdk/projects/_generated/fetch_registry.py +22 -0
- wandb/sdk/projects/_generated/fragments.py +41 -0
- wandb/sdk/projects/_generated/input_types.py +13 -0
- wandb/sdk/projects/_generated/operations.py +88 -0
- wandb/sdk/projects/_generated/rename_project.py +27 -0
- wandb/sdk/projects/_generated/upsert_registry_project.py +27 -0
- wandb/sdk/service/service.py +9 -1
- wandb/sdk/wandb_init.py +32 -5
- wandb/sdk/wandb_run.py +37 -9
- wandb/sdk/wandb_settings.py +6 -7
- wandb/sdk/wandb_setup.py +12 -0
- wandb/util.py +7 -3
- {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/METADATA +1 -1
- {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/RECORD +86 -69
- wandb/automations/_generated/create_filter_trigger.py +0 -21
- wandb/automations/_generated/delete_trigger.py +0 -19
- wandb/automations/_generated/update_filter_trigger.py +0 -21
- {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/WHEEL +0 -0
- {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/licenses/LICENSE +0 -0
@@ -11,7 +11,7 @@ import subprocess
|
|
11
11
|
import sys
|
12
12
|
from pathlib import Path
|
13
13
|
from tempfile import NamedTemporaryFile
|
14
|
-
from typing import IO,
|
14
|
+
from typing import IO, ContextManager, Iterator, Protocol
|
15
15
|
|
16
16
|
import wandb
|
17
17
|
from wandb import env, util
|
@@ -19,12 +19,10 @@ from wandb.sdk.lib.filesystem import files_in
|
|
19
19
|
from wandb.sdk.lib.hashutil import B64MD5, ETag, b64_to_hex_id
|
20
20
|
from wandb.sdk.lib.paths import FilePathStr, StrPath, URIStr
|
21
21
|
|
22
|
-
if TYPE_CHECKING:
|
23
|
-
from typing import Protocol
|
24
22
|
|
25
|
-
|
26
|
-
|
27
|
-
|
23
|
+
class Opener(Protocol):
    """Structural type for a callable that opens a cache file.

    Calling it with a file mode (for example ``"wb"``) returns a context
    manager that yields an ``IO`` handle for the target file.
    """

    def __call__(self, mode: str = ...) -> ContextManager[IO]:
        """Open the target file with ``mode`` and return a context manager for it."""
|
28
26
|
|
29
27
|
|
30
28
|
def _get_sys_umask_threadsafe() -> int:
|
@@ -2,6 +2,7 @@
|
|
2
2
|
|
3
3
|
from __future__ import annotations
|
4
4
|
|
5
|
+
import concurrent.futures
|
5
6
|
import json
|
6
7
|
import logging
|
7
8
|
import os
|
@@ -130,7 +131,11 @@ class ArtifactManifestEntry:
|
|
130
131
|
return self._parent_artifact
|
131
132
|
|
132
133
|
def download(
|
133
|
-
self,
|
134
|
+
self,
|
135
|
+
root: str | None = None,
|
136
|
+
skip_cache: bool | None = None,
|
137
|
+
executor: concurrent.futures.Executor | None = None,
|
138
|
+
multipart: bool | None = None,
|
134
139
|
) -> FilePathStr:
|
135
140
|
"""Download this artifact entry to the specified root path.
|
136
141
|
|
@@ -170,7 +175,11 @@ class ArtifactManifestEntry:
|
|
170
175
|
)
|
171
176
|
else:
|
172
177
|
cache_path = self._parent_artifact.manifest.storage_policy.load_file(
|
173
|
-
self._parent_artifact,
|
178
|
+
self._parent_artifact,
|
179
|
+
self,
|
180
|
+
dest_path=override_cache_path,
|
181
|
+
executor=executor,
|
182
|
+
multipart=multipart,
|
174
183
|
)
|
175
184
|
|
176
185
|
if skip_cache:
|
@@ -2,20 +2,28 @@
|
|
2
2
|
|
3
3
|
from __future__ import annotations
|
4
4
|
|
5
|
+
import concurrent.futures
|
6
|
+
import functools
|
5
7
|
import hashlib
|
8
|
+
import logging
|
6
9
|
import math
|
7
10
|
import os
|
11
|
+
import queue
|
8
12
|
import shutil
|
9
|
-
|
13
|
+
import threading
|
14
|
+
from collections import deque
|
15
|
+
from typing import IO, TYPE_CHECKING, Any, NamedTuple, Sequence
|
10
16
|
from urllib.parse import quote
|
11
17
|
|
12
18
|
import requests
|
13
19
|
import urllib3
|
14
20
|
|
21
|
+
from wandb import env
|
15
22
|
from wandb.errors.term import termwarn
|
16
23
|
from wandb.proto.wandb_internal_pb2 import ServerFeature
|
17
24
|
from wandb.sdk.artifacts.artifact_file_cache import (
|
18
25
|
ArtifactFileCache,
|
26
|
+
Opener,
|
19
27
|
get_artifact_file_cache,
|
20
28
|
)
|
21
29
|
from wandb.sdk.artifacts.staging import get_staging_dir
|
@@ -60,6 +68,27 @@ S3_MIN_MULTI_UPLOAD_SIZE = 2 * 1024**3
|
|
60
68
|
S3_MAX_MULTI_UPLOAD_SIZE = 5 * 1024**4
|
61
69
|
|
62
70
|
|
71
|
+
# Minimum file size at which downloads switch to multipart; this is the same
# threshold as multipart upload (2 GiB).
_MULTIPART_DOWNLOAD_SIZE = S3_MIN_MULTI_UPLOAD_SIZE
# Part size for multipart download; matches the multipart upload part size,
# which is hard coded to 100 MiB.
# https://github.com/wandb/wandb/blob/7b2a13cb8efcd553317167b823c8e52d8c3f7c4e/core/pkg/artifacts/saver.go#L496
# https://docs.aws.amazon.com/AmazonS3/latest/userguide/optimizing-performance-guidelines.html#optimizing-performance-guidelines-get-range
_DOWNLOAD_PART_SIZE_BYTES = 100 * 1024 * 1024
# Chunk size for reading the HTTP response and writing to disk (1 MiB).
_HTTP_RES_CHUNK_SIZE_BYTES = 1 * 1024 * 1024
# Marks the end of the chunk queue: the consumer (file writer) stops once it
# receives this item.
# NOTE: the writer recognizes this object by identity (`is`), so it only works
# with a multithreaded executor. An object pickled to another process would no
# longer be the same object, so it does not work with a multiprocess executor.
# Multipart download uses the executor from artifact.download(), which is a
# multithreaded executor.
_CHUNK_QUEUE_SENTINEL = object()

logger = logging.getLogger(__name__)
|
85
|
+
|
86
|
+
|
87
|
+
class _ChunkContent(NamedTuple):
|
88
|
+
offset: int
|
89
|
+
data: bytes
|
90
|
+
|
91
|
+
|
63
92
|
class WandbStoragePolicy(StoragePolicy):
|
64
93
|
@classmethod
|
65
94
|
def name(cls) -> str:
|
@@ -120,7 +149,20 @@ class WandbStoragePolicy(StoragePolicy):
|
|
120
149
|
artifact: Artifact,
|
121
150
|
manifest_entry: ArtifactManifestEntry,
|
122
151
|
dest_path: str | None = None,
|
152
|
+
executor: concurrent.futures.Executor | None = None,
|
153
|
+
multipart: bool | None = None,
|
123
154
|
) -> FilePathStr:
|
155
|
+
"""Use cache or download the file using signed url.
|
156
|
+
|
157
|
+
Args:
|
158
|
+
executor: Passed from caller, artifact has a thread pool for multi file download.
|
159
|
+
Reuse the thread pool for multi part download. The thread pool is closed when
|
160
|
+
artifact download is done.
|
161
|
+
multipart: If set to `None` (default), the artifact will be downloaded
|
162
|
+
in parallel using multipart download if individual file size is greater than
|
163
|
+
2GB. If set to `True` or `False`, the artifact will be downloaded in
|
164
|
+
parallel or serially regardless of the file size.
|
165
|
+
"""
|
124
166
|
if dest_path is not None:
|
125
167
|
self._cache._override_cache_path = dest_path
|
126
168
|
|
@@ -132,6 +174,20 @@ class WandbStoragePolicy(StoragePolicy):
|
|
132
174
|
return path
|
133
175
|
|
134
176
|
if manifest_entry._download_url is not None:
|
177
|
+
# Use multipart parallel download for large file
|
178
|
+
if (
|
179
|
+
executor is not None
|
180
|
+
and manifest_entry.size is not None
|
181
|
+
and self._should_multipart_download(manifest_entry.size, multipart)
|
182
|
+
):
|
183
|
+
self._multipart_file_download(
|
184
|
+
executor,
|
185
|
+
manifest_entry._download_url,
|
186
|
+
manifest_entry.size,
|
187
|
+
cache_open,
|
188
|
+
)
|
189
|
+
return path
|
190
|
+
# Serial download
|
135
191
|
response = self._session.get(manifest_entry._download_url, stream=True)
|
136
192
|
try:
|
137
193
|
response.raise_for_status()
|
@@ -165,6 +221,131 @@ class WandbStoragePolicy(StoragePolicy):
|
|
165
221
|
file.write(data)
|
166
222
|
return path
|
167
223
|
|
224
|
+
def _should_multipart_download(
|
225
|
+
self,
|
226
|
+
file_size: int,
|
227
|
+
multipart: bool | None,
|
228
|
+
) -> bool:
|
229
|
+
if multipart is not None:
|
230
|
+
return multipart
|
231
|
+
return file_size >= _MULTIPART_DOWNLOAD_SIZE
|
232
|
+
|
233
|
+
def _write_chunks_to_file(
    self,
    f: IO,
    q: queue.Queue,
    download_has_error: threading.Event,
):
    """Consume the chunk queue and write each chunk into ``f`` at its offset.

    This is the single writer of a multipart download. It keeps reading from
    ``q`` until ``_CHUNK_QUEUE_SENTINEL`` arrives, even after a failure. ``q``
    is bounded, so if the writer stopped early, part downloaders blocked in
    ``q.put`` would wait forever, and so could the code that enqueues the
    sentinel. Once an error has been flagged, remaining chunks are drained
    and discarded instead of written.

    Args:
        f: File object opened for binary writing.
        q: Queue of ``_ChunkContent`` items terminated by
            ``_CHUNK_QUEUE_SENTINEL``.
        download_has_error: Event shared with the part downloaders. It is set
            here when writing fails so they stop producing chunks.

    Raises:
        Exception: The first error raised while writing a chunk, or a
            ``ValueError`` for an unexpected queue item. It is raised only
            after the sentinel has been received.
    """
    first_error: Exception | None = None
    while True:
        item = q.get()
        if item is _CHUNK_QUEUE_SENTINEL:
            # No more chunks will be produced (normal end or abort).
            break
        if download_has_error.is_set():
            # The download already failed (here or in a part downloader):
            # keep draining without writing so blocked producers can make
            # progress, notice the error flag and exit.
            continue
        if not isinstance(item, _ChunkContent):
            first_error = ValueError(f"Unknown queue item type: {type(item)}")
            download_has_error.set()
            continue
        try:
            # NOTE: Seek works without pre allocating the file on disk.
            # It automatically creates a sparse file, e.g. ls -hl would show
            # a bigger size compared to du -sh * because downloading different
            # chunks is not a sequential write.
            # See https://man7.org/linux/man-pages/man2/lseek.2.html
            f.seek(item.offset)
            f.write(item.data)
        except Exception as e:
            if env.is_debug():
                logger.debug(f"Error writing chunk to file: {e}")
            first_error = e
            download_has_error.set()

    if first_error is not None:
        raise first_error
|
260
|
+
|
261
|
+
def _download_part(
|
262
|
+
self,
|
263
|
+
download_url: str,
|
264
|
+
headers: dict,
|
265
|
+
start: int,
|
266
|
+
q: queue.Queue,
|
267
|
+
download_has_error: threading.Event,
|
268
|
+
):
|
269
|
+
# Other threads has error, no need to start
|
270
|
+
if download_has_error.is_set():
|
271
|
+
return
|
272
|
+
response = self._session.get(
|
273
|
+
url=download_url,
|
274
|
+
headers=headers,
|
275
|
+
stream=True,
|
276
|
+
)
|
277
|
+
response.raise_for_status()
|
278
|
+
|
279
|
+
file_offset = start
|
280
|
+
for content in response.iter_content(chunk_size=_HTTP_RES_CHUNK_SIZE_BYTES):
|
281
|
+
if download_has_error.is_set():
|
282
|
+
return
|
283
|
+
q.put(_ChunkContent(offset=file_offset, data=content))
|
284
|
+
file_offset += len(content)
|
285
|
+
|
286
|
+
def _multipart_file_download(
    self,
    executor: concurrent.futures.Executor,
    download_url: str,
    file_size_bytes: int,
    cache_open: Opener,
):
    """Download file as multiple parts in parallel.

    Only one thread writes to the file. Each part is fetched by its own
    ranged HTTP request on ``executor``, and the chunks of each response are
    sent to the writer thread through a bounded queue.

    Args:
        executor: Executor that runs the writer and the part downloads.
        download_url: URL of the file.
        file_size_bytes: Total file size; determines how many parts are made.
        cache_open: Opens the destination cache file.

    Raises:
        Exception: The first error from a part download, or an error from
            the writer.
    """
    # Bounded so that the number of buffered chunks (and thus memory) stays
    # limited while the writer catches up.
    q: queue.Queue[_ChunkContent | object] = queue.Queue(maxsize=500)
    download_has_error = threading.Event()

    # Put cache_open at top so we remove the tmp file when there is network error.
    with cache_open("wb") as f:
        # Start writer thread first.
        # NOTE(review): the writer and every part run on the caller's executor.
        # If this method itself runs inside a task on that same executor (for
        # example several large files downloaded concurrently), the pool can be
        # exhausted by tasks waiting on work queued behind them -- confirm the
        # pool is sized for that.
        write_future = executor.submit(
            functools.partial(self._write_chunks_to_file, f, q, download_has_error)
        )

        part_size = _DOWNLOAD_PART_SIZE_BYTES
        # Integer ceiling division: exact for any size, unlike float math.
        num_parts = -(-file_size_bytes // part_size)
        download_futures: deque[concurrent.futures.Future] = deque()

        succeeded = False
        try:
            # Submission and waiting both live inside the try so the writer
            # always receives its sentinel, whatever goes wrong below.
            for i in range(num_parts):
                # https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Range
                # Start and end are both inclusive, empty end means use the actual end of the file.
                start = i * part_size
                bytes_range = f"bytes={start}-"
                if i != num_parts - 1:
                    # bytes=0-499
                    bytes_range += f"{start + part_size - 1}"
                download_handler = functools.partial(
                    self._download_part,
                    download_url,
                    {"Range": bytes_range},
                    start,
                    q,
                    download_has_error,
                )
                download_futures.append(executor.submit(download_handler))

            # Wait for download
            done, _ = concurrent.futures.wait(
                download_futures, return_when=concurrent.futures.FIRST_EXCEPTION
            )
            for fut in done:
                fut.result()
            succeeded = True
        except Exception as e:
            if env.is_debug():
                logger.debug(f"Error downloading file: {e}")
            raise
        finally:
            if not succeeded:
                # Make in-flight parts stop at their next chunk and skip parts
                # that have not started yet (cancel() is a no-op for running
                # or finished futures).
                download_has_error.set()
                for fut in download_futures:
                    fut.cancel()
            # Always signal the writer to stop, then surface any write error.
            q.put(_CHUNK_QUEUE_SENTINEL)
            write_future.result()
|
348
|
+
|
168
349
|
def store_reference(
|
169
350
|
self,
|
170
351
|
artifact: Artifact,
|
@@ -2,6 +2,7 @@
|
|
2
2
|
|
3
3
|
from __future__ import annotations
|
4
4
|
|
5
|
+
import concurrent.futures
|
5
6
|
from typing import TYPE_CHECKING, Sequence
|
6
7
|
|
7
8
|
from wandb.sdk.internal.internal_api import Api as InternalApi
|
@@ -40,6 +41,8 @@ class StoragePolicy:
|
|
40
41
|
artifact: Artifact,
|
41
42
|
manifest_entry: ArtifactManifestEntry,
|
42
43
|
dest_path: str | None = None,
|
44
|
+
executor: concurrent.futures.Executor | None = None,
|
45
|
+
multipart: bool | None = None,
|
43
46
|
) -> FilePathStr:
|
44
47
|
raise NotImplementedError
|
45
48
|
|
wandb/sdk/data_types/video.py
CHANGED
@@ -2,7 +2,7 @@ import functools
|
|
2
2
|
import logging
|
3
3
|
import os
|
4
4
|
from io import BytesIO
|
5
|
-
from typing import TYPE_CHECKING, Any, Optional, Sequence, Type, Union
|
5
|
+
from typing import TYPE_CHECKING, Any, Literal, Optional, Sequence, Type, Union
|
6
6
|
|
7
7
|
import wandb
|
8
8
|
from wandb import util
|
@@ -48,36 +48,7 @@ def write_gif_with_image_io(
|
|
48
48
|
|
49
49
|
|
50
50
|
class Video(BatchableMedia):
|
51
|
-
"""
|
52
|
-
|
53
|
-
Args:
|
54
|
-
data_or_path: (numpy array, string, io)
|
55
|
-
Video can be initialized with a path to a file or an io object.
|
56
|
-
The format must be "gif", "mp4", "webm" or "ogg".
|
57
|
-
The format must be specified with the format argument.
|
58
|
-
Video can be initialized with a numpy tensor.
|
59
|
-
The numpy tensor must be either 4 dimensional or 5 dimensional.
|
60
|
-
Channels should be (time, channel, height, width) or
|
61
|
-
(batch, time, channel, height width)
|
62
|
-
caption: (string) caption associated with the video for display
|
63
|
-
fps: (int)
|
64
|
-
The frame rate to use when encoding raw video frames. Default value is 4.
|
65
|
-
This parameter has no effect when data_or_path is a string, or bytes.
|
66
|
-
format: (string) format of video, necessary if initializing with path or io object.
|
67
|
-
|
68
|
-
Examples:
|
69
|
-
### Log a numpy array as a video
|
70
|
-
<!--yeadoc-test:log-video-numpy-->
|
71
|
-
```python
|
72
|
-
import numpy as np
|
73
|
-
import wandb
|
74
|
-
|
75
|
-
run = wandb.init()
|
76
|
-
# axes are (time, channel, height, width)
|
77
|
-
frames = np.random.randint(low=0, high=256, size=(10, 3, 100, 100), dtype=np.uint8)
|
78
|
-
run.log({"video": wandb.Video(frames, fps=4)})
|
79
|
-
```
|
80
|
-
"""
|
51
|
+
"""A class for logging videos to W&B."""
|
81
52
|
|
82
53
|
_log_type = "video-file"
|
83
54
|
EXTS = ("gif", "mp4", "webm", "ogg")
|
@@ -89,10 +60,53 @@ class Video(BatchableMedia):
|
|
89
60
|
data_or_path: Union["np.ndarray", str, "TextIO", "BytesIO"],
|
90
61
|
caption: Optional[str] = None,
|
91
62
|
fps: Optional[int] = None,
|
92
|
-
format: Optional[
|
63
|
+
format: Optional[Literal["gif", "mp4", "webm", "ogg"]] = None,
|
93
64
|
):
|
65
|
+
"""Initialize a W&B Video object.
|
66
|
+
|
67
|
+
Args:
|
68
|
+
data_or_path:
|
69
|
+
Video can be initialized with a path to a file or an io object.
|
70
|
+
Video can be initialized with a numpy tensor.
|
71
|
+
The numpy tensor must be either 4 dimensional or 5 dimensional.
|
72
|
+
The dimensions should be (number of frames, channel, height, width) or
|
73
|
+
(batch, number of frames, channel, height, width)
|
74
|
+
The format parameter must be specified with the format argument
|
75
|
+
when initializing with a numpy array
|
76
|
+
or io object.
|
77
|
+
caption: Caption associated with the video for display.
|
78
|
+
fps:
|
79
|
+
The frame rate to use when encoding raw video frames.
|
80
|
+
Default value is 4.
|
81
|
+
This parameter has no effect when data_or_path is a string, or bytes.
|
82
|
+
format:
|
83
|
+
Format of video, necessary if initializing with a numpy array
|
84
|
+
or io object. This parameter will be used to determine the format
|
85
|
+
to use when encoding the video data. Accepted values are "gif",
|
86
|
+
"mp4", "webm", or "ogg".
|
87
|
+
|
88
|
+
Examples:
|
89
|
+
### Log a numpy array as a video
|
90
|
+
```python
|
91
|
+
import numpy as np
|
92
|
+
import wandb
|
93
|
+
|
94
|
+
with wandb.init() as run:
|
95
|
+
# axes are (number of frames, channel, height, width)
|
96
|
+
frames = np.random.randint(
|
97
|
+
low=0, high=256, size=(10, 3, 100, 100), dtype=np.uint8
|
98
|
+
)
|
99
|
+
run.log({"video": wandb.Video(frames, format="mp4", fps=4)})
|
100
|
+
```
|
101
|
+
"""
|
94
102
|
super().__init__(caption=caption)
|
95
103
|
|
104
|
+
if format is None:
|
105
|
+
wandb.termwarn(
|
106
|
+
"`format` argument was not provided, defaulting to `gif`. "
|
107
|
+
"This parameter will be required in v0.20.0, "
|
108
|
+
"please specify the format explicitly."
|
109
|
+
)
|
96
110
|
self._format = format or "gif"
|
97
111
|
self._width = None
|
98
112
|
self._height = None
|
wandb/sdk/interface/interface.py
CHANGED
@@ -428,7 +428,6 @@ class InterfaceBase:
|
|
428
428
|
|
429
429
|
def deliver_link_artifact(
|
430
430
|
self,
|
431
|
-
run: "Run",
|
432
431
|
artifact: "Artifact",
|
433
432
|
portfolio_name: str,
|
434
433
|
aliases: Iterable[str],
|
@@ -442,9 +441,9 @@ class InterfaceBase:
|
|
442
441
|
else:
|
443
442
|
link_artifact.server_id = artifact.id if artifact.id else ""
|
444
443
|
link_artifact.portfolio_name = portfolio_name
|
445
|
-
link_artifact.portfolio_entity = entity or
|
444
|
+
link_artifact.portfolio_entity = entity or ""
|
446
445
|
link_artifact.portfolio_organization = organization or ""
|
447
|
-
link_artifact.portfolio_project = project or
|
446
|
+
link_artifact.portfolio_project = project or ""
|
448
447
|
link_artifact.portfolio_aliases.extend(aliases)
|
449
448
|
|
450
449
|
return self._deliver_link_artifact(link_artifact)
|
@@ -12,6 +12,7 @@ import sys
|
|
12
12
|
import threading
|
13
13
|
from copy import deepcopy
|
14
14
|
from pathlib import Path
|
15
|
+
from types import MappingProxyType
|
15
16
|
from typing import (
|
16
17
|
IO,
|
17
18
|
TYPE_CHECKING,
|
@@ -189,11 +190,6 @@ def _match_org_with_fetched_org_entities(
|
|
189
190
|
"""
|
190
191
|
for org_names in orgs:
|
191
192
|
if organization in org_names:
|
192
|
-
wandb.termwarn(
|
193
|
-
"Registries can be linked/fetched using a shorthand form without specifying the organization name. "
|
194
|
-
"Try using shorthand path format: <my_registry_name>/<artifact_name> or "
|
195
|
-
"just <my_registry_name> if fetching just the project."
|
196
|
-
)
|
197
193
|
return org_names.entity_name
|
198
194
|
|
199
195
|
if len(orgs) == 1:
|
@@ -873,30 +869,29 @@ class Api:
|
|
873
869
|
_, _, mutations = self.server_info_introspection()
|
874
870
|
return "updateRunQueueItemWarning" in mutations
|
875
871
|
|
876
|
-
def
|
877
|
-
"""
|
878
|
-
|
879
|
-
Args:
|
880
|
-
feature_value (ServerFeature): The enum value of the feature to check.
|
881
|
-
|
882
|
-
Returns:
|
883
|
-
bool: True if the feature is enabled, False otherwise.
|
884
|
-
|
885
|
-
Raises:
|
886
|
-
Exception: If server doesn't support feature queries or other errors occur
|
887
|
-
"""
|
872
|
+
def _server_features(self) -> Mapping[str, bool]:
|
873
|
+
"""Returns a cached, read-only lookup of current server feature flags."""
|
888
874
|
if self._server_features_cache is None:
|
889
875
|
query = gql(SERVER_FEATURES_QUERY_GQL)
|
890
|
-
|
891
|
-
|
892
|
-
|
893
|
-
|
894
|
-
|
895
|
-
|
876
|
+
|
877
|
+
try:
|
878
|
+
response = self.gql(query)
|
879
|
+
except Exception as e:
|
880
|
+
# Unfortunately we currently have to match on the text of the error message
|
881
|
+
if 'Cannot query field "features" on type "ServerInfo".' in str(e):
|
882
|
+
self._server_features_cache = {}
|
883
|
+
else:
|
884
|
+
raise
|
896
885
|
else:
|
897
|
-
|
886
|
+
info = ServerFeaturesQuery.model_validate(response).server_info
|
887
|
+
if info and (feats := info.features):
|
888
|
+
self._server_features_cache = {
|
889
|
+
f.name: f.is_enabled for f in feats if f
|
890
|
+
}
|
891
|
+
else:
|
892
|
+
self._server_features_cache = {}
|
898
893
|
|
899
|
-
return self._server_features_cache
|
894
|
+
return MappingProxyType(self._server_features_cache)
|
900
895
|
|
901
896
|
def _check_server_feature_with_fallback(self, feature_value: ServerFeature) -> bool:
|
902
897
|
"""Wrapper around check_server_feature that warns and returns False for older unsupported servers.
|
@@ -912,12 +907,7 @@ class Api:
|
|
912
907
|
Exceptions:
|
913
908
|
Exception: If an error other than the server not supporting feature queries occurs.
|
914
909
|
"""
|
915
|
-
|
916
|
-
return self._check_server_feature(feature_value)
|
917
|
-
except Exception as e:
|
918
|
-
if 'Cannot query field "features" on type "ServerInfo".' in str(e):
|
919
|
-
return False
|
920
|
-
raise e
|
910
|
+
return self._server_features().get(ServerFeature.Name(feature_value), False)
|
921
911
|
|
922
912
|
@normalize_exceptions
|
923
913
|
def update_run_queue_item_warning(
|
wandb/sdk/internal/sender.py
CHANGED
@@ -1444,7 +1444,7 @@ class SendManager:
|
|
1444
1444
|
)
|
1445
1445
|
if (client_id or server_id) and portfolio_name and entity and project:
|
1446
1446
|
try:
|
1447
|
-
self._api.link_artifact(
|
1447
|
+
response = self._api.link_artifact(
|
1448
1448
|
client_id,
|
1449
1449
|
server_id,
|
1450
1450
|
portfolio_name,
|
@@ -1453,9 +1453,12 @@ class SendManager:
|
|
1453
1453
|
aliases,
|
1454
1454
|
organization,
|
1455
1455
|
)
|
1456
|
+
result.response.link_artifact_response.version_index = response[
|
1457
|
+
"versionIndex"
|
1458
|
+
]
|
1456
1459
|
except Exception as e:
|
1457
1460
|
org_or_entity = organization or entity
|
1458
|
-
result.response.
|
1461
|
+
result.response.link_artifact_response.error_message = (
|
1459
1462
|
f"error linking artifact to "
|
1460
1463
|
f'"{org_or_entity}/{project}/{portfolio_name}"; error: {e}'
|
1461
1464
|
)
|
wandb/sdk/launch/sweeps/utils.py
CHANGED
@@ -223,6 +223,10 @@ def create_sweep_command_args(command: Dict) -> Dict[str, Any]:
|
|
223
223
|
flags_dict: Dict[str, Any] = {}
|
224
224
|
# (5) flags without equals (e.g. --foo bar)
|
225
225
|
args_no_equals: List[str] = []
|
226
|
+
# (6) flags for hydra append config value (e.g. +foo=bar)
|
227
|
+
flags_append_hydra: List[str] = []
|
228
|
+
# (7) flags for hydra override config value (e.g. ++foo=bar)
|
229
|
+
flags_override_hydra: List[str] = []
|
226
230
|
for param, config in command["args"].items():
|
227
231
|
# allow 'None' as a valid value, but error if no value is found
|
228
232
|
try:
|
@@ -234,6 +238,8 @@ def create_sweep_command_args(command: Dict) -> Dict[str, Any]:
|
|
234
238
|
flags.append("--" + _flag)
|
235
239
|
flags_no_hyphens.append(_flag)
|
236
240
|
args_no_equals += [f"--{param}", str(_value)]
|
241
|
+
flags_append_hydra.append("+" + _flag)
|
242
|
+
flags_override_hydra.append("++" + _flag)
|
237
243
|
if isinstance(_value, bool):
|
238
244
|
# omit flags if they are boolean and false
|
239
245
|
if _value:
|
@@ -248,6 +254,8 @@ def create_sweep_command_args(command: Dict) -> Dict[str, Any]:
|
|
248
254
|
"args_no_boolean_flags": flags_no_booleans,
|
249
255
|
"args_json": [json.dumps(flags_dict)],
|
250
256
|
"args_dict": flags_dict,
|
257
|
+
"args_append_hydra": flags_append_hydra,
|
258
|
+
"args_override_hydra": flags_override_hydra,
|
251
259
|
}
|
252
260
|
|
253
261
|
|
@@ -0,0 +1,47 @@
|
|
1
|
+
# Generated by ariadne-codegen
|
2
|
+
|
3
|
+
from .delete_project import DeleteProject, DeleteProjectDeleteModel
|
4
|
+
from .fetch_registry import FetchRegistry, FetchRegistryEntity
|
5
|
+
from .fragments import (
|
6
|
+
RegistryFragment,
|
7
|
+
RegistryFragmentArtifactTypes,
|
8
|
+
RegistryFragmentArtifactTypesEdges,
|
9
|
+
RegistryFragmentArtifactTypesEdgesNode,
|
10
|
+
)
|
11
|
+
from .input_types import ArtifactTypeInput
|
12
|
+
from .operations import (
|
13
|
+
DELETE_PROJECT_GQL,
|
14
|
+
FETCH_REGISTRY_GQL,
|
15
|
+
RENAME_PROJECT_GQL,
|
16
|
+
UPSERT_REGISTRY_PROJECT_GQL,
|
17
|
+
)
|
18
|
+
from .rename_project import (
|
19
|
+
RenameProject,
|
20
|
+
RenameProjectRenameProject,
|
21
|
+
RenameProjectRenameProjectProject,
|
22
|
+
)
|
23
|
+
from .upsert_registry_project import (
|
24
|
+
UpsertRegistryProject,
|
25
|
+
UpsertRegistryProjectUpsertModel,
|
26
|
+
)
|
27
|
+
|
28
|
+
# Public names re-exported by this generated package (ariadne-codegen output;
# regenerate rather than editing by hand).
__all__ = [
    # GraphQL operation documents (from .operations)
    "DELETE_PROJECT_GQL",
    "FETCH_REGISTRY_GQL",
    "RENAME_PROJECT_GQL",
    "UPSERT_REGISTRY_PROJECT_GQL",
    # Operation result models
    "FetchRegistry",
    "FetchRegistryEntity",
    "RenameProject",
    "RenameProjectRenameProject",
    "RenameProjectRenameProjectProject",
    "UpsertRegistryProject",
    "UpsertRegistryProjectUpsertModel",
    "DeleteProject",
    "DeleteProjectDeleteModel",
    # Input types (from .input_types)
    "ArtifactTypeInput",
    # Fragments (from .fragments)
    "RegistryFragment",
    "RegistryFragmentArtifactTypes",
    "RegistryFragmentArtifactTypesEdges",
    "RegistryFragmentArtifactTypesEdgesNode",
]
|
@@ -0,0 +1,22 @@
|
|
1
|
+
# Generated by ariadne-codegen
|
2
|
+
# Source: tools/graphql_codegen/projects/
|
3
|
+
|
4
|
+
from __future__ import annotations
|
5
|
+
|
6
|
+
from typing import Literal, Optional
|
7
|
+
|
8
|
+
from pydantic import Field
|
9
|
+
|
10
|
+
from wandb._pydantic import GQLBase, Typename
|
11
|
+
|
12
|
+
|
13
|
+
class DeleteProject(GQLBase):
    # Payload of the ``deleteModel`` mutation; nullable per its Optional type.
    delete_model: Optional[DeleteProjectDeleteModel] = Field(alias="deleteModel")


class DeleteProjectDeleteModel(GQLBase):
    # Success flag returned by the server (nullable).
    success: Optional[bool]
    # Type-name field pinned to the literal "DeleteModelPayload"; presumably
    # mirrors GraphQL ``__typename`` -- see ``Typename`` to confirm.
    typename__: Typename[Literal["DeleteModelPayload"]]


# ``DeleteProject`` refers to ``DeleteProjectDeleteModel`` before that class is
# defined (annotations are deferred by ``from __future__ import annotations``),
# so the model is rebuilt once both classes exist.
DeleteProject.model_rebuild()
|
@@ -0,0 +1,22 @@
|
|
1
|
+
# Generated by ariadne-codegen
|
2
|
+
# Source: tools/graphql_codegen/projects/
|
3
|
+
|
4
|
+
from __future__ import annotations
|
5
|
+
|
6
|
+
from typing import Optional
|
7
|
+
|
8
|
+
from wandb._pydantic import GQLBase
|
9
|
+
|
10
|
+
from .fragments import RegistryFragment
|
11
|
+
|
12
|
+
|
13
|
+
class FetchRegistry(GQLBase):
    # Entity returned by the registry query (nullable per its Optional type).
    entity: Optional[FetchRegistryEntity]


class FetchRegistryEntity(GQLBase):
    # Registry project data, parsed with the shared ``RegistryFragment`` model.
    project: Optional[RegistryFragment]


# Resolve the deferred (string) annotations now that both models and the
# imported fragment are defined.
FetchRegistry.model_rebuild()
FetchRegistryEntity.model_rebuild()
|