wandb 0.19.10__py3-none-musllinux_1_2_aarch64.whl → 0.19.11__py3-none-musllinux_1_2_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +3 -3
  3. wandb/_pydantic/__init__.py +2 -3
  4. wandb/_pydantic/base.py +11 -31
  5. wandb/_pydantic/utils.py +8 -1
  6. wandb/_pydantic/v1_compat.py +3 -3
  7. wandb/apis/public/api.py +590 -22
  8. wandb/apis/public/artifacts.py +13 -5
  9. wandb/apis/public/automations.py +1 -1
  10. wandb/apis/public/integrations.py +22 -10
  11. wandb/apis/public/registries/__init__.py +0 -0
  12. wandb/apis/public/registries/_freezable_list.py +179 -0
  13. wandb/apis/public/{registries.py → registries/registries_search.py} +22 -129
  14. wandb/apis/public/registries/registry.py +357 -0
  15. wandb/apis/public/registries/utils.py +140 -0
  16. wandb/apis/public/runs.py +58 -56
  17. wandb/automations/__init__.py +16 -24
  18. wandb/automations/_filters/expressions.py +12 -10
  19. wandb/automations/_filters/operators.py +10 -19
  20. wandb/automations/_filters/run_metrics.py +231 -82
  21. wandb/automations/_generated/__init__.py +27 -34
  22. wandb/automations/_generated/create_automation.py +17 -0
  23. wandb/automations/_generated/delete_automation.py +17 -0
  24. wandb/automations/_generated/fragments.py +40 -25
  25. wandb/automations/_generated/{get_triggers.py → get_automations.py} +5 -5
  26. wandb/automations/_generated/{get_triggers_by_entity.py → get_automations_by_entity.py} +7 -5
  27. wandb/automations/_generated/operations.py +35 -98
  28. wandb/automations/_generated/update_automation.py +17 -0
  29. wandb/automations/_utils.py +178 -64
  30. wandb/automations/_validators.py +94 -2
  31. wandb/automations/actions.py +113 -98
  32. wandb/automations/automations.py +47 -69
  33. wandb/automations/events.py +139 -87
  34. wandb/automations/integrations.py +23 -4
  35. wandb/automations/scopes.py +22 -20
  36. wandb/bin/gpu_stats +0 -0
  37. wandb/bin/wandb-core +0 -0
  38. wandb/env.py +11 -0
  39. wandb/old/settings.py +4 -1
  40. wandb/proto/v3/wandb_internal_pb2.py +240 -236
  41. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  42. wandb/proto/v4/wandb_internal_pb2.py +236 -236
  43. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  44. wandb/proto/v5/wandb_internal_pb2.py +236 -236
  45. wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
  46. wandb/proto/v6/wandb_internal_pb2.py +236 -236
  47. wandb/proto/v6/wandb_telemetry_pb2.py +10 -10
  48. wandb/sdk/artifacts/_generated/__init__.py +42 -1
  49. wandb/sdk/artifacts/_generated/add_aliases.py +21 -0
  50. wandb/sdk/artifacts/_generated/delete_aliases.py +21 -0
  51. wandb/sdk/artifacts/_generated/fetch_linked_artifacts.py +67 -0
  52. wandb/sdk/artifacts/_generated/fragments.py +35 -0
  53. wandb/sdk/artifacts/_generated/input_types.py +12 -0
  54. wandb/sdk/artifacts/_generated/operations.py +101 -0
  55. wandb/sdk/artifacts/_generated/update_artifact.py +26 -0
  56. wandb/sdk/artifacts/_graphql_fragments.py +1 -0
  57. wandb/sdk/artifacts/_validators.py +120 -1
  58. wandb/sdk/artifacts/artifact.py +380 -203
  59. wandb/sdk/artifacts/artifact_file_cache.py +4 -6
  60. wandb/sdk/artifacts/artifact_manifest_entry.py +11 -2
  61. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +182 -1
  62. wandb/sdk/artifacts/storage_policy.py +3 -0
  63. wandb/sdk/data_types/video.py +46 -32
  64. wandb/sdk/interface/interface.py +2 -3
  65. wandb/sdk/internal/internal_api.py +21 -31
  66. wandb/sdk/internal/sender.py +5 -2
  67. wandb/sdk/launch/sweeps/utils.py +8 -0
  68. wandb/sdk/projects/_generated/__init__.py +47 -0
  69. wandb/sdk/projects/_generated/delete_project.py +22 -0
  70. wandb/sdk/projects/_generated/enums.py +4 -0
  71. wandb/sdk/projects/_generated/fetch_registry.py +22 -0
  72. wandb/sdk/projects/_generated/fragments.py +41 -0
  73. wandb/sdk/projects/_generated/input_types.py +13 -0
  74. wandb/sdk/projects/_generated/operations.py +88 -0
  75. wandb/sdk/projects/_generated/rename_project.py +27 -0
  76. wandb/sdk/projects/_generated/upsert_registry_project.py +27 -0
  77. wandb/sdk/service/service.py +9 -1
  78. wandb/sdk/wandb_init.py +32 -5
  79. wandb/sdk/wandb_run.py +37 -9
  80. wandb/sdk/wandb_settings.py +6 -7
  81. wandb/sdk/wandb_setup.py +12 -0
  82. wandb/util.py +7 -3
  83. {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/METADATA +1 -1
  84. {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/RECORD +87 -70
  85. wandb/automations/_generated/create_filter_trigger.py +0 -21
  86. wandb/automations/_generated/delete_trigger.py +0 -19
  87. wandb/automations/_generated/update_filter_trigger.py +0 -21
  88. {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/WHEEL +0 -0
  89. {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/entry_points.txt +0 -0
  90. {wandb-0.19.10.dist-info → wandb-0.19.11.dist-info}/licenses/LICENSE +0 -0
@@ -11,7 +11,7 @@ import subprocess
11
11
  import sys
12
12
  from pathlib import Path
13
13
  from tempfile import NamedTemporaryFile
14
- from typing import IO, TYPE_CHECKING, ContextManager, Iterator
14
+ from typing import IO, ContextManager, Iterator, Protocol
15
15
 
16
16
  import wandb
17
17
  from wandb import env, util
@@ -19,12 +19,10 @@ from wandb.sdk.lib.filesystem import files_in
19
19
  from wandb.sdk.lib.hashutil import B64MD5, ETag, b64_to_hex_id
20
20
  from wandb.sdk.lib.paths import FilePathStr, StrPath, URIStr
21
21
 
22
- if TYPE_CHECKING:
23
- from typing import Protocol
24
22
 
25
- class Opener(Protocol):
26
- def __call__(self, mode: str = ...) -> ContextManager[IO]:
27
- pass
23
+ class Opener(Protocol):
24
+ def __call__(self, mode: str = ...) -> ContextManager[IO]:
25
+ pass
28
26
 
29
27
 
30
28
  def _get_sys_umask_threadsafe() -> int:
@@ -2,6 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ import concurrent.futures
5
6
  import json
6
7
  import logging
7
8
  import os
@@ -130,7 +131,11 @@ class ArtifactManifestEntry:
130
131
  return self._parent_artifact
131
132
 
132
133
  def download(
133
- self, root: str | None = None, skip_cache: bool | None = None
134
+ self,
135
+ root: str | None = None,
136
+ skip_cache: bool | None = None,
137
+ executor: concurrent.futures.Executor | None = None,
138
+ multipart: bool | None = None,
134
139
  ) -> FilePathStr:
135
140
  """Download this artifact entry to the specified root path.
136
141
 
@@ -170,7 +175,11 @@ class ArtifactManifestEntry:
170
175
  )
171
176
  else:
172
177
  cache_path = self._parent_artifact.manifest.storage_policy.load_file(
173
- self._parent_artifact, self, dest_path=override_cache_path
178
+ self._parent_artifact,
179
+ self,
180
+ dest_path=override_cache_path,
181
+ executor=executor,
182
+ multipart=multipart,
174
183
  )
175
184
 
176
185
  if skip_cache:
@@ -2,20 +2,28 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ import concurrent.futures
6
+ import functools
5
7
  import hashlib
8
+ import logging
6
9
  import math
7
10
  import os
11
+ import queue
8
12
  import shutil
9
- from typing import TYPE_CHECKING, Any, Sequence
13
+ import threading
14
+ from collections import deque
15
+ from typing import IO, TYPE_CHECKING, Any, NamedTuple, Sequence
10
16
  from urllib.parse import quote
11
17
 
12
18
  import requests
13
19
  import urllib3
14
20
 
21
+ from wandb import env
15
22
  from wandb.errors.term import termwarn
16
23
  from wandb.proto.wandb_internal_pb2 import ServerFeature
17
24
  from wandb.sdk.artifacts.artifact_file_cache import (
18
25
  ArtifactFileCache,
26
+ Opener,
19
27
  get_artifact_file_cache,
20
28
  )
21
29
  from wandb.sdk.artifacts.staging import get_staging_dir
@@ -60,6 +68,27 @@ S3_MIN_MULTI_UPLOAD_SIZE = 2 * 1024**3
60
68
  S3_MAX_MULTI_UPLOAD_SIZE = 5 * 1024**4
61
69
 
62
70
 
71
+ # Minimum size to switch to multipart download, same as upload, 2GB.
72
+ _MULTIPART_DOWNLOAD_SIZE = S3_MIN_MULTI_UPLOAD_SIZE
73
+ # Multipart download part size is same as multipart upload size, which is hard coded to 100MB.
74
+ # https://github.com/wandb/wandb/blob/7b2a13cb8efcd553317167b823c8e52d8c3f7c4e/core/pkg/artifacts/saver.go#L496
75
+ # https://docs.aws.amazon.com/AmazonS3/latest/userguide/optimizing-performance-guidelines.html#optimizing-performance-guidelines-get-range
76
+ _DOWNLOAD_PART_SIZE_BYTES = 100 * 1024 * 1024
77
+ # Chunk size for reading http response and writing to disk. 1MB.
78
+ _HTTP_RES_CHUNK_SIZE_BYTES = 1 * 1024 * 1024
79
+ # Signal end of _ChunkQueue, consumer (file writer) should stop after getting this item.
80
+ # NOTE: it should only be used for multithread executor, it does not work for multiprocess executor.
81
+ # multipart download is using the executor from artifact.download() which is a multithread executor.
82
+ _CHUNK_QUEUE_SENTINEL = object()
83
+
84
+ logger = logging.getLogger(__name__)
85
+
86
+
87
+ class _ChunkContent(NamedTuple):
88
+ offset: int
89
+ data: bytes
90
+
91
+
63
92
  class WandbStoragePolicy(StoragePolicy):
64
93
  @classmethod
65
94
  def name(cls) -> str:
@@ -120,7 +149,20 @@ class WandbStoragePolicy(StoragePolicy):
120
149
  artifact: Artifact,
121
150
  manifest_entry: ArtifactManifestEntry,
122
151
  dest_path: str | None = None,
152
+ executor: concurrent.futures.Executor | None = None,
153
+ multipart: bool | None = None,
123
154
  ) -> FilePathStr:
155
+ """Use cache or download the file using signed url.
156
+
157
+ Args:
158
+ executor: Passed from caller, artifact has a thread pool for multi file download.
159
+ Reuse the thread pool for multi part download. The thread pool is closed when
160
+ artifact download is done.
161
+ multipart: If set to `None` (default), the artifact will be downloaded
162
+ in parallel using multipart download if individual file size is greater than
163
+ 2GB. If set to `True` or `False`, the artifact will be downloaded in
164
+ parallel or serially regardless of the file size.
165
+ """
124
166
  if dest_path is not None:
125
167
  self._cache._override_cache_path = dest_path
126
168
 
@@ -132,6 +174,20 @@ class WandbStoragePolicy(StoragePolicy):
132
174
  return path
133
175
 
134
176
  if manifest_entry._download_url is not None:
177
+ # Use multipart parallel download for large file
178
+ if (
179
+ executor is not None
180
+ and manifest_entry.size is not None
181
+ and self._should_multipart_download(manifest_entry.size, multipart)
182
+ ):
183
+ self._multipart_file_download(
184
+ executor,
185
+ manifest_entry._download_url,
186
+ manifest_entry.size,
187
+ cache_open,
188
+ )
189
+ return path
190
+ # Serial download
135
191
  response = self._session.get(manifest_entry._download_url, stream=True)
136
192
  try:
137
193
  response.raise_for_status()
@@ -165,6 +221,131 @@ class WandbStoragePolicy(StoragePolicy):
165
221
  file.write(data)
166
222
  return path
167
223
 
224
+ def _should_multipart_download(
225
+ self,
226
+ file_size: int,
227
+ multipart: bool | None,
228
+ ) -> bool:
229
+ if multipart is not None:
230
+ return multipart
231
+ return file_size >= _MULTIPART_DOWNLOAD_SIZE
232
+
233
+ def _write_chunks_to_file(
234
+ self,
235
+ f: IO,
236
+ q: queue.Queue,
237
+ download_has_error: threading.Event,
238
+ ):
239
+ while not download_has_error.is_set():
240
+ item = q.get()
241
+ if item is _CHUNK_QUEUE_SENTINEL:
242
+ # Normal shutdown, all the chunks are written
243
+ return
244
+ elif isinstance(item, _ChunkContent):
245
+ try:
246
+ # NOTE: Seek works without pre allocating the file on disk.
247
+ # It automatically creates a sparse file, e.g. ls -hl would show
248
+ # a bigger size compared to du -sh * because downloading different
249
+ # chunks is not a sequential write.
250
+ # See https://man7.org/linux/man-pages/man2/lseek.2.html
251
+ f.seek(item.offset)
252
+ f.write(item.data)
253
+ except Exception as e:
254
+ if env.is_debug():
255
+ logger.debug(f"Error writing chunk to file: {e}")
256
+ download_has_error.set()
257
+ raise e
258
+ else:
259
+ raise ValueError(f"Unknown queue item type: {type(item)}")
260
+
261
+ def _download_part(
262
+ self,
263
+ download_url: str,
264
+ headers: dict,
265
+ start: int,
266
+ q: queue.Queue,
267
+ download_has_error: threading.Event,
268
+ ):
269
+ # Other threads have errors, no need to start
270
+ if download_has_error.is_set():
271
+ return
272
+ response = self._session.get(
273
+ url=download_url,
274
+ headers=headers,
275
+ stream=True,
276
+ )
277
+ response.raise_for_status()
278
+
279
+ file_offset = start
280
+ for content in response.iter_content(chunk_size=_HTTP_RES_CHUNK_SIZE_BYTES):
281
+ if download_has_error.is_set():
282
+ return
283
+ q.put(_ChunkContent(offset=file_offset, data=content))
284
+ file_offset += len(content)
285
+
286
+ def _multipart_file_download(
287
+ self,
288
+ executor: concurrent.futures.Executor,
289
+ download_url: str,
290
+ file_size_bytes: int,
291
+ cache_open: Opener,
292
+ ):
293
+ """Download file as multiple parts in parallel.
294
+
295
+ Only one thread for writing to file. Each part runs one HTTP request in one thread.
296
+ HTTP response chunk of a file part is sent to the writer thread via a queue.
297
+ """
298
+ q: queue.Queue[_ChunkContent | object] = queue.Queue(maxsize=500)
299
+ download_has_error = threading.Event()
300
+
301
+ # Put cache_open at top so we remove the tmp file when there is network error.
302
+ with cache_open("wb") as f:
303
+ # Start writer thread first.
304
+ write_handler = functools.partial(
305
+ self._write_chunks_to_file, f, q, download_has_error
306
+ )
307
+ write_future = executor.submit(write_handler)
308
+
309
+ # Start download threads for each part.
310
+ download_futures: deque[concurrent.futures.Future] = deque()
311
+ part_size = _DOWNLOAD_PART_SIZE_BYTES
312
+ num_parts = int(math.ceil(file_size_bytes / float(part_size)))
313
+ for i in range(num_parts):
314
+ # https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Range
315
+ # Start and end are both inclusive, empty end means use the actual end of the file.
316
+ start = i * part_size
317
+ bytes_range = f"bytes={start}-"
318
+ if i != (num_parts - 1):
319
+ # bytes=0-499
320
+ bytes_range += f"{start + part_size - 1}"
321
+ headers = {"Range": bytes_range}
322
+ download_handler = functools.partial(
323
+ self._download_part,
324
+ download_url,
325
+ headers,
326
+ start,
327
+ q,
328
+ download_has_error,
329
+ )
330
+ download_futures.append(executor.submit(download_handler))
331
+
332
+ # Wait for download
333
+ done, not_done = concurrent.futures.wait(
334
+ download_futures, return_when=concurrent.futures.FIRST_EXCEPTION
335
+ )
336
+ try:
337
+ for fut in done:
338
+ fut.result()
339
+ except Exception as e:
340
+ if env.is_debug():
341
+ logger.debug(f"Error downloading file: {e}")
342
+ download_has_error.set()
343
+ raise e
344
+ finally:
345
+ # Always signal the writer to stop
346
+ q.put(_CHUNK_QUEUE_SENTINEL)
347
+ write_future.result()
348
+
168
349
  def store_reference(
169
350
  self,
170
351
  artifact: Artifact,
@@ -2,6 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ import concurrent.futures
5
6
  from typing import TYPE_CHECKING, Sequence
6
7
 
7
8
  from wandb.sdk.internal.internal_api import Api as InternalApi
@@ -40,6 +41,8 @@ class StoragePolicy:
40
41
  artifact: Artifact,
41
42
  manifest_entry: ArtifactManifestEntry,
42
43
  dest_path: str | None = None,
44
+ executor: concurrent.futures.Executor | None = None,
45
+ multipart: bool | None = None,
43
46
  ) -> FilePathStr:
44
47
  raise NotImplementedError
45
48
 
@@ -2,7 +2,7 @@ import functools
2
2
  import logging
3
3
  import os
4
4
  from io import BytesIO
5
- from typing import TYPE_CHECKING, Any, Optional, Sequence, Type, Union
5
+ from typing import TYPE_CHECKING, Any, Literal, Optional, Sequence, Type, Union
6
6
 
7
7
  import wandb
8
8
  from wandb import util
@@ -48,36 +48,7 @@ def write_gif_with_image_io(
48
48
 
49
49
 
50
50
  class Video(BatchableMedia):
51
- """Format a video for logging to W&B.
52
-
53
- Args:
54
- data_or_path: (numpy array, string, io)
55
- Video can be initialized with a path to a file or an io object.
56
- The format must be "gif", "mp4", "webm" or "ogg".
57
- The format must be specified with the format argument.
58
- Video can be initialized with a numpy tensor.
59
- The numpy tensor must be either 4 dimensional or 5 dimensional.
60
- Channels should be (time, channel, height, width) or
61
- (batch, time, channel, height width)
62
- caption: (string) caption associated with the video for display
63
- fps: (int)
64
- The frame rate to use when encoding raw video frames. Default value is 4.
65
- This parameter has no effect when data_or_path is a string, or bytes.
66
- format: (string) format of video, necessary if initializing with path or io object.
67
-
68
- Examples:
69
- ### Log a numpy array as a video
70
- <!--yeadoc-test:log-video-numpy-->
71
- ```python
72
- import numpy as np
73
- import wandb
74
-
75
- run = wandb.init()
76
- # axes are (time, channel, height, width)
77
- frames = np.random.randint(low=0, high=256, size=(10, 3, 100, 100), dtype=np.uint8)
78
- run.log({"video": wandb.Video(frames, fps=4)})
79
- ```
80
- """
51
+ """A class for logging videos to W&B."""
81
52
 
82
53
  _log_type = "video-file"
83
54
  EXTS = ("gif", "mp4", "webm", "ogg")
@@ -89,10 +60,53 @@ class Video(BatchableMedia):
89
60
  data_or_path: Union["np.ndarray", str, "TextIO", "BytesIO"],
90
61
  caption: Optional[str] = None,
91
62
  fps: Optional[int] = None,
92
- format: Optional[str] = None,
63
+ format: Optional[Literal["gif", "mp4", "webm", "ogg"]] = None,
93
64
  ):
65
+ """Initialize a W&B Video object.
66
+
67
+ Args:
68
+ data_or_path:
69
+ Video can be initialized with a path to a file or an io object.
70
+ Video can be initialized with a numpy tensor.
71
+ The numpy tensor must be either 4 dimensional or 5 dimensional.
72
+ The dimensions should be (number of frames, channel, height, width) or
73
+ (batch, number of frames, channel, height, width)
74
+ The format parameter must be specified with the format argument
75
+ when initializing with a numpy array
76
+ or io object.
77
+ caption: Caption associated with the video for display.
78
+ fps:
79
+ The frame rate to use when encoding raw video frames.
80
+ Default value is 4.
81
+ This parameter has no effect when data_or_path is a string, or bytes.
82
+ format:
83
+ Format of video, necessary if initializing with a numpy array
84
+ or io object. This parameter will be used to determine the format
85
+ to use when encoding the video data. Accepted values are "gif",
86
+ "mp4", "webm", or "ogg".
87
+
88
+ Examples:
89
+ ### Log a numpy array as a video
90
+ ```python
91
+ import numpy as np
92
+ import wandb
93
+
94
+ with wandb.init() as run:
95
+ # axes are (number of frames, channel, height, width)
96
+ frames = np.random.randint(
97
+ low=0, high=256, size=(10, 3, 100, 100), dtype=np.uint8
98
+ )
99
+ run.log({"video": wandb.Video(frames, format="mp4", fps=4)})
100
+ ```
101
+ """
94
102
  super().__init__(caption=caption)
95
103
 
104
+ if format is None:
105
+ wandb.termwarn(
106
+ "`format` argument was not provided, defaulting to `gif`. "
107
+ "This parameter will be required in v0.20.0, "
108
+ "please specify the format explicitly."
109
+ )
96
110
  self._format = format or "gif"
97
111
  self._width = None
98
112
  self._height = None
@@ -428,7 +428,6 @@ class InterfaceBase:
428
428
 
429
429
  def deliver_link_artifact(
430
430
  self,
431
- run: "Run",
432
431
  artifact: "Artifact",
433
432
  portfolio_name: str,
434
433
  aliases: Iterable[str],
@@ -442,9 +441,9 @@ class InterfaceBase:
442
441
  else:
443
442
  link_artifact.server_id = artifact.id if artifact.id else ""
444
443
  link_artifact.portfolio_name = portfolio_name
445
- link_artifact.portfolio_entity = entity or run.entity
444
+ link_artifact.portfolio_entity = entity or ""
446
445
  link_artifact.portfolio_organization = organization or ""
447
- link_artifact.portfolio_project = project or run.project
446
+ link_artifact.portfolio_project = project or ""
448
447
  link_artifact.portfolio_aliases.extend(aliases)
449
448
 
450
449
  return self._deliver_link_artifact(link_artifact)
@@ -12,6 +12,7 @@ import sys
12
12
  import threading
13
13
  from copy import deepcopy
14
14
  from pathlib import Path
15
+ from types import MappingProxyType
15
16
  from typing import (
16
17
  IO,
17
18
  TYPE_CHECKING,
@@ -189,11 +190,6 @@ def _match_org_with_fetched_org_entities(
189
190
  """
190
191
  for org_names in orgs:
191
192
  if organization in org_names:
192
- wandb.termwarn(
193
- "Registries can be linked/fetched using a shorthand form without specifying the organization name. "
194
- "Try using shorthand path format: <my_registry_name>/<artifact_name> or "
195
- "just <my_registry_name> if fetching just the project."
196
- )
197
193
  return org_names.entity_name
198
194
 
199
195
  if len(orgs) == 1:
@@ -873,30 +869,29 @@ class Api:
873
869
  _, _, mutations = self.server_info_introspection()
874
870
  return "updateRunQueueItemWarning" in mutations
875
871
 
876
- def _check_server_feature(self, feature_value: ServerFeature) -> bool:
877
- """Check if a server feature is enabled.
878
-
879
- Args:
880
- feature_value (ServerFeature): The enum value of the feature to check.
881
-
882
- Returns:
883
- bool: True if the feature is enabled, False otherwise.
884
-
885
- Raises:
886
- Exception: If server doesn't support feature queries or other errors occur
887
- """
872
+ def _server_features(self) -> Mapping[str, bool]:
873
+ """Returns a cached, read-only lookup of current server feature flags."""
888
874
  if self._server_features_cache is None:
889
875
  query = gql(SERVER_FEATURES_QUERY_GQL)
890
- response = self.gql(query)
891
- server_info = ServerFeaturesQuery.model_validate(response).server_info
892
- if server_info and (features := server_info.features):
893
- self._server_features_cache = {
894
- f.name: f.is_enabled for f in features if f
895
- }
876
+
877
+ try:
878
+ response = self.gql(query)
879
+ except Exception as e:
880
+ # Unfortunately we currently have to match on the text of the error message
881
+ if 'Cannot query field "features" on type "ServerInfo".' in str(e):
882
+ self._server_features_cache = {}
883
+ else:
884
+ raise
896
885
  else:
897
- self._server_features_cache = {}
886
+ info = ServerFeaturesQuery.model_validate(response).server_info
887
+ if info and (feats := info.features):
888
+ self._server_features_cache = {
889
+ f.name: f.is_enabled for f in feats if f
890
+ }
891
+ else:
892
+ self._server_features_cache = {}
898
893
 
899
- return self._server_features_cache.get(ServerFeature.Name(feature_value), False)
894
+ return MappingProxyType(self._server_features_cache)
900
895
 
901
896
  def _check_server_feature_with_fallback(self, feature_value: ServerFeature) -> bool:
902
897
  """Wrapper around check_server_feature that warns and returns False for older unsupported servers.
@@ -912,12 +907,7 @@ class Api:
912
907
  Exceptions:
913
908
  Exception: If an error other than the server not supporting feature queries occurs.
914
909
  """
915
- try:
916
- return self._check_server_feature(feature_value)
917
- except Exception as e:
918
- if 'Cannot query field "features" on type "ServerInfo".' in str(e):
919
- return False
920
- raise e
910
+ return self._server_features().get(ServerFeature.Name(feature_value), False)
921
911
 
922
912
  @normalize_exceptions
923
913
  def update_run_queue_item_warning(
@@ -1444,7 +1444,7 @@ class SendManager:
1444
1444
  )
1445
1445
  if (client_id or server_id) and portfolio_name and entity and project:
1446
1446
  try:
1447
- self._api.link_artifact(
1447
+ response = self._api.link_artifact(
1448
1448
  client_id,
1449
1449
  server_id,
1450
1450
  portfolio_name,
@@ -1453,9 +1453,12 @@ class SendManager:
1453
1453
  aliases,
1454
1454
  organization,
1455
1455
  )
1456
+ result.response.link_artifact_response.version_index = response[
1457
+ "versionIndex"
1458
+ ]
1456
1459
  except Exception as e:
1457
1460
  org_or_entity = organization or entity
1458
- result.response.log_artifact_response.error_message = (
1461
+ result.response.link_artifact_response.error_message = (
1459
1462
  f"error linking artifact to "
1460
1463
  f'"{org_or_entity}/{project}/{portfolio_name}"; error: {e}'
1461
1464
  )
@@ -223,6 +223,10 @@ def create_sweep_command_args(command: Dict) -> Dict[str, Any]:
223
223
  flags_dict: Dict[str, Any] = {}
224
224
  # (5) flags without equals (e.g. --foo bar)
225
225
  args_no_equals: List[str] = []
226
+ # (6) flags for hydra append config value (e.g. +foo=bar)
227
+ flags_append_hydra: List[str] = []
228
+ # (7) flags for hydra override config value (e.g. ++foo=bar)
229
+ flags_override_hydra: List[str] = []
226
230
  for param, config in command["args"].items():
227
231
  # allow 'None' as a valid value, but error if no value is found
228
232
  try:
@@ -234,6 +238,8 @@ def create_sweep_command_args(command: Dict) -> Dict[str, Any]:
234
238
  flags.append("--" + _flag)
235
239
  flags_no_hyphens.append(_flag)
236
240
  args_no_equals += [f"--{param}", str(_value)]
241
+ flags_append_hydra.append("+" + _flag)
242
+ flags_override_hydra.append("++" + _flag)
237
243
  if isinstance(_value, bool):
238
244
  # omit flags if they are boolean and false
239
245
  if _value:
@@ -248,6 +254,8 @@ def create_sweep_command_args(command: Dict) -> Dict[str, Any]:
248
254
  "args_no_boolean_flags": flags_no_booleans,
249
255
  "args_json": [json.dumps(flags_dict)],
250
256
  "args_dict": flags_dict,
257
+ "args_append_hydra": flags_append_hydra,
258
+ "args_override_hydra": flags_override_hydra,
251
259
  }
252
260
 
253
261
 
@@ -0,0 +1,47 @@
1
+ # Generated by ariadne-codegen
2
+
3
+ from .delete_project import DeleteProject, DeleteProjectDeleteModel
4
+ from .fetch_registry import FetchRegistry, FetchRegistryEntity
5
+ from .fragments import (
6
+ RegistryFragment,
7
+ RegistryFragmentArtifactTypes,
8
+ RegistryFragmentArtifactTypesEdges,
9
+ RegistryFragmentArtifactTypesEdgesNode,
10
+ )
11
+ from .input_types import ArtifactTypeInput
12
+ from .operations import (
13
+ DELETE_PROJECT_GQL,
14
+ FETCH_REGISTRY_GQL,
15
+ RENAME_PROJECT_GQL,
16
+ UPSERT_REGISTRY_PROJECT_GQL,
17
+ )
18
+ from .rename_project import (
19
+ RenameProject,
20
+ RenameProjectRenameProject,
21
+ RenameProjectRenameProjectProject,
22
+ )
23
+ from .upsert_registry_project import (
24
+ UpsertRegistryProject,
25
+ UpsertRegistryProjectUpsertModel,
26
+ )
27
+
28
+ __all__ = [
29
+ "DELETE_PROJECT_GQL",
30
+ "FETCH_REGISTRY_GQL",
31
+ "RENAME_PROJECT_GQL",
32
+ "UPSERT_REGISTRY_PROJECT_GQL",
33
+ "FetchRegistry",
34
+ "FetchRegistryEntity",
35
+ "RenameProject",
36
+ "RenameProjectRenameProject",
37
+ "RenameProjectRenameProjectProject",
38
+ "UpsertRegistryProject",
39
+ "UpsertRegistryProjectUpsertModel",
40
+ "DeleteProject",
41
+ "DeleteProjectDeleteModel",
42
+ "ArtifactTypeInput",
43
+ "RegistryFragment",
44
+ "RegistryFragmentArtifactTypes",
45
+ "RegistryFragmentArtifactTypesEdges",
46
+ "RegistryFragmentArtifactTypesEdgesNode",
47
+ ]
@@ -0,0 +1,22 @@
1
+ # Generated by ariadne-codegen
2
+ # Source: tools/graphql_codegen/projects/
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import Literal, Optional
7
+
8
+ from pydantic import Field
9
+
10
+ from wandb._pydantic import GQLBase, Typename
11
+
12
+
13
+ class DeleteProject(GQLBase):
14
+ delete_model: Optional[DeleteProjectDeleteModel] = Field(alias="deleteModel")
15
+
16
+
17
+ class DeleteProjectDeleteModel(GQLBase):
18
+ success: Optional[bool]
19
+ typename__: Typename[Literal["DeleteModelPayload"]]
20
+
21
+
22
+ DeleteProject.model_rebuild()
@@ -0,0 +1,4 @@
1
+ # Generated by ariadne-codegen
2
+ # Source: core/api/graphql/schemas/schema-latest.graphql
3
+
4
+ from __future__ import annotations
@@ -0,0 +1,22 @@
1
+ # Generated by ariadne-codegen
2
+ # Source: tools/graphql_codegen/projects/
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import Optional
7
+
8
+ from wandb._pydantic import GQLBase
9
+
10
+ from .fragments import RegistryFragment
11
+
12
+
13
+ class FetchRegistry(GQLBase):
14
+ entity: Optional[FetchRegistryEntity]
15
+
16
+
17
+ class FetchRegistryEntity(GQLBase):
18
+ project: Optional[RegistryFragment]
19
+
20
+
21
+ FetchRegistry.model_rebuild()
22
+ FetchRegistryEntity.model_rebuild()