wandb 0.21.4__py3-none-win32.whl → 0.22.1__py3-none-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96) hide show
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +3 -3
  3. wandb/_pydantic/__init__.py +12 -11
  4. wandb/_pydantic/base.py +49 -19
  5. wandb/apis/__init__.py +2 -0
  6. wandb/apis/attrs.py +2 -0
  7. wandb/apis/importers/internals/internal.py +16 -23
  8. wandb/apis/internal.py +2 -0
  9. wandb/apis/normalize.py +2 -0
  10. wandb/apis/public/__init__.py +44 -1
  11. wandb/apis/public/api.py +215 -164
  12. wandb/apis/public/artifacts.py +23 -20
  13. wandb/apis/public/const.py +2 -0
  14. wandb/apis/public/files.py +33 -24
  15. wandb/apis/public/history.py +2 -0
  16. wandb/apis/public/jobs.py +20 -18
  17. wandb/apis/public/projects.py +4 -2
  18. wandb/apis/public/query_generator.py +3 -0
  19. wandb/apis/public/registries/__init__.py +7 -0
  20. wandb/apis/public/registries/_freezable_list.py +9 -12
  21. wandb/apis/public/registries/registries_search.py +8 -6
  22. wandb/apis/public/registries/registry.py +22 -17
  23. wandb/apis/public/reports.py +2 -0
  24. wandb/apis/public/runs.py +282 -60
  25. wandb/apis/public/sweeps.py +10 -9
  26. wandb/apis/public/teams.py +2 -0
  27. wandb/apis/public/users.py +2 -0
  28. wandb/apis/public/utils.py +16 -15
  29. wandb/automations/_generated/__init__.py +54 -127
  30. wandb/automations/_generated/create_generic_webhook_integration.py +1 -7
  31. wandb/automations/_generated/fragments.py +26 -91
  32. wandb/bin/gpu_stats.exe +0 -0
  33. wandb/bin/wandb-core +0 -0
  34. wandb/cli/beta_sync.py +9 -11
  35. wandb/errors/errors.py +3 -3
  36. wandb/proto/v3/wandb_internal_pb2.py +234 -224
  37. wandb/proto/v3/wandb_sync_pb2.py +19 -6
  38. wandb/proto/v4/wandb_internal_pb2.py +226 -224
  39. wandb/proto/v4/wandb_sync_pb2.py +10 -6
  40. wandb/proto/v5/wandb_internal_pb2.py +226 -224
  41. wandb/proto/v5/wandb_sync_pb2.py +10 -6
  42. wandb/proto/v6/wandb_base_pb2.py +3 -3
  43. wandb/proto/v6/wandb_internal_pb2.py +229 -227
  44. wandb/proto/v6/wandb_server_pb2.py +3 -3
  45. wandb/proto/v6/wandb_settings_pb2.py +3 -3
  46. wandb/proto/v6/wandb_sync_pb2.py +13 -9
  47. wandb/proto/v6/wandb_telemetry_pb2.py +3 -3
  48. wandb/sdk/artifacts/_factories.py +7 -2
  49. wandb/sdk/artifacts/_generated/__init__.py +112 -412
  50. wandb/sdk/artifacts/_generated/fragments.py +65 -0
  51. wandb/sdk/artifacts/_generated/operations.py +52 -22
  52. wandb/sdk/artifacts/_generated/run_input_artifacts.py +3 -23
  53. wandb/sdk/artifacts/_generated/run_output_artifacts.py +3 -23
  54. wandb/sdk/artifacts/_generated/type_info.py +19 -0
  55. wandb/sdk/artifacts/_gqlutils.py +47 -0
  56. wandb/sdk/artifacts/_models/__init__.py +4 -0
  57. wandb/sdk/artifacts/_models/base_model.py +20 -0
  58. wandb/sdk/artifacts/_validators.py +40 -12
  59. wandb/sdk/artifacts/artifact.py +69 -88
  60. wandb/sdk/artifacts/artifact_file_cache.py +6 -1
  61. wandb/sdk/artifacts/artifact_manifest_entry.py +61 -2
  62. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +1 -1
  63. wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -3
  64. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -1
  65. wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
  66. wandb/sdk/artifacts/storage_policies/_factories.py +63 -0
  67. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +69 -124
  68. wandb/sdk/data_types/bokeh.py +5 -1
  69. wandb/sdk/data_types/image.py +17 -6
  70. wandb/sdk/interface/interface.py +41 -4
  71. wandb/sdk/interface/interface_queue.py +10 -0
  72. wandb/sdk/interface/interface_shared.py +9 -7
  73. wandb/sdk/interface/interface_sock.py +9 -3
  74. wandb/sdk/internal/_generated/__init__.py +2 -12
  75. wandb/sdk/internal/sender.py +1 -1
  76. wandb/sdk/internal/settings_static.py +2 -82
  77. wandb/sdk/launch/runner/kubernetes_runner.py +25 -20
  78. wandb/sdk/launch/utils.py +82 -1
  79. wandb/sdk/lib/progress.py +7 -4
  80. wandb/sdk/lib/service/service_client.py +5 -9
  81. wandb/sdk/lib/service/service_connection.py +39 -23
  82. wandb/sdk/mailbox/mailbox_handle.py +2 -0
  83. wandb/sdk/projects/_generated/__init__.py +12 -33
  84. wandb/sdk/wandb_init.py +31 -3
  85. wandb/sdk/wandb_login.py +53 -27
  86. wandb/sdk/wandb_run.py +5 -3
  87. wandb/sdk/wandb_settings.py +50 -13
  88. wandb/sync/sync.py +7 -2
  89. wandb/util.py +1 -1
  90. wandb/wandb_agent.py +35 -4
  91. {wandb-0.21.4.dist-info → wandb-0.22.1.dist-info}/METADATA +1 -1
  92. {wandb-0.21.4.dist-info → wandb-0.22.1.dist-info}/RECORD +95 -91
  93. wandb/sdk/artifacts/_graphql_fragments.py +0 -19
  94. {wandb-0.21.4.dist-info → wandb-0.22.1.dist-info}/WHEEL +0 -0
  95. {wandb-0.21.4.dist-info → wandb-0.22.1.dist-info}/entry_points.txt +0 -0
  96. {wandb-0.21.4.dist-info → wandb-0.22.1.dist-info}/licenses/LICENSE +0 -0
@@ -16,7 +16,6 @@ from typing import IO, TYPE_CHECKING, Any, NamedTuple, Sequence
16
16
  from urllib.parse import quote
17
17
 
18
18
  import requests
19
- import urllib3
20
19
 
21
20
  from wandb import env
22
21
  from wandb.errors.term import termwarn
@@ -27,40 +26,24 @@ from wandb.sdk.artifacts.artifact_file_cache import (
27
26
  get_artifact_file_cache,
28
27
  )
29
28
  from wandb.sdk.artifacts.staging import get_staging_dir
30
- from wandb.sdk.artifacts.storage_handlers.azure_handler import AzureHandler
31
- from wandb.sdk.artifacts.storage_handlers.gcs_handler import GCSHandler
32
- from wandb.sdk.artifacts.storage_handlers.http_handler import HTTPHandler
33
- from wandb.sdk.artifacts.storage_handlers.local_file_handler import LocalFileHandler
34
29
  from wandb.sdk.artifacts.storage_handlers.multi_handler import MultiHandler
35
- from wandb.sdk.artifacts.storage_handlers.s3_handler import S3Handler
36
30
  from wandb.sdk.artifacts.storage_handlers.tracking_handler import TrackingHandler
37
- from wandb.sdk.artifacts.storage_handlers.wb_artifact_handler import WBArtifactHandler
38
- from wandb.sdk.artifacts.storage_handlers.wb_local_artifact_handler import (
39
- WBLocalArtifactHandler,
40
- )
41
31
  from wandb.sdk.artifacts.storage_layout import StorageLayout
42
32
  from wandb.sdk.artifacts.storage_policies.register import WANDB_STORAGE_POLICY
43
33
  from wandb.sdk.artifacts.storage_policy import StoragePolicy
44
34
  from wandb.sdk.internal.internal_api import Api as InternalApi
45
35
  from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
46
- from wandb.sdk.lib.hashutil import B64MD5, b64_to_hex_id, hex_to_b64_id
36
+ from wandb.sdk.lib.hashutil import b64_to_hex_id, hex_to_b64_id
47
37
  from wandb.sdk.lib.paths import FilePathStr, URIStr
48
38
 
39
+ from ._factories import make_http_session, make_storage_handlers
40
+
49
41
  if TYPE_CHECKING:
50
42
  from wandb.filesync.step_prepare import StepPrepare
51
43
  from wandb.sdk.artifacts.artifact import Artifact
52
44
  from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
53
45
  from wandb.sdk.internal import progress
54
46
 
55
- # Sleep length: 0, 2, 4, 8, 16, 32, 64, 120, 120, 120, 120, 120, 120, 120, 120, 120
56
- # seconds, i.e. a total of 20min 6s.
57
- _REQUEST_RETRY_STRATEGY = urllib3.util.retry.Retry(
58
- backoff_factor=1,
59
- total=16,
60
- status_forcelist=(308, 408, 409, 429, 500, 502, 503, 504),
61
- )
62
- _REQUEST_POOL_CONNECTIONS = 64
63
- _REQUEST_POOL_MAXSIZE = 64
64
47
 
65
48
  # AWS S3 max upload parts without having to make additional requests for extra parts
66
49
  S3_MAX_PART_NUMBERS = 1000
@@ -96,51 +79,36 @@ class WandbStoragePolicy(StoragePolicy):
96
79
 
97
80
  @classmethod
98
81
  def from_config(
99
- cls, config: dict, api: InternalApi | None = None
82
+ cls, config: dict[str, Any], api: InternalApi | None = None
100
83
  ) -> WandbStoragePolicy:
101
84
  return cls(config=config, api=api)
102
85
 
103
86
  def __init__(
104
87
  self,
105
- config: dict | None = None,
88
+ config: dict[str, Any] | None = None,
106
89
  cache: ArtifactFileCache | None = None,
107
90
  api: InternalApi | None = None,
91
+ session: requests.Session | None = None,
108
92
  ) -> None:
109
- self._cache = cache or get_artifact_file_cache()
110
93
  self._config = config or {}
111
- self._session = requests.Session()
112
- adapter = requests.adapters.HTTPAdapter(
113
- max_retries=_REQUEST_RETRY_STRATEGY,
114
- pool_connections=_REQUEST_POOL_CONNECTIONS,
115
- pool_maxsize=_REQUEST_POOL_MAXSIZE,
116
- )
117
- self._session.mount("http://", adapter)
118
- self._session.mount("https://", adapter)
119
-
120
- s3 = S3Handler()
121
- gcs = GCSHandler()
122
- azure = AzureHandler()
123
- http = HTTPHandler(self._session)
124
- https = HTTPHandler(self._session, scheme="https")
125
- artifact = WBArtifactHandler()
126
- local_artifact = WBLocalArtifactHandler()
127
- file_handler = LocalFileHandler()
128
-
94
+ if (storage_region := self._config.get("storageRegion")) is not None:
95
+ self._validate_storage_region(storage_region)
96
+ self._cache = cache or get_artifact_file_cache()
97
+ self._session = session or make_http_session()
129
98
  self._api = api or InternalApi()
130
99
  self._handler = MultiHandler(
131
- handlers=[
132
- s3,
133
- gcs,
134
- azure,
135
- http,
136
- https,
137
- artifact,
138
- local_artifact,
139
- file_handler,
140
- ],
100
+ handlers=make_storage_handlers(self._session),
141
101
  default_handler=TrackingHandler(),
142
102
  )
143
103
 
104
+ def _validate_storage_region(self, storage_region: Any) -> None:
105
+ if not isinstance(storage_region, str):
106
+ raise TypeError(
107
+ f"storageRegion must be a string, got {type(storage_region).__name__}: {storage_region!r}"
108
+ )
109
+ if not storage_region.strip():
110
+ raise ValueError("storageRegion must be a non-empty string")
111
+
144
112
  def config(self) -> dict:
145
113
  return self._config
146
114
 
@@ -167,54 +135,52 @@ class WandbStoragePolicy(StoragePolicy):
167
135
  self._cache._override_cache_path = dest_path
168
136
 
169
137
  path, hit, cache_open = self._cache.check_md5_obj_path(
170
- B64MD5(manifest_entry.digest),
171
- manifest_entry.size if manifest_entry.size is not None else 0,
138
+ manifest_entry.digest,
139
+ size=manifest_entry.size or 0,
172
140
  )
173
141
  if hit:
174
142
  return path
175
143
 
176
- if manifest_entry._download_url is not None:
144
+ if (url := manifest_entry._download_url) is not None:
177
145
  # Use multipart parallel download for large file
178
146
  if (
179
- executor is not None
180
- and manifest_entry.size is not None
181
- and self._should_multipart_download(manifest_entry.size, multipart)
147
+ executor
148
+ and (size := manifest_entry.size)
149
+ and self._should_multipart_download(size, multipart)
182
150
  ):
183
- self._multipart_file_download(
184
- executor,
185
- manifest_entry._download_url,
186
- manifest_entry.size,
187
- cache_open,
188
- )
151
+ self._multipart_file_download(executor, url, size, cache_open)
189
152
  return path
153
+
190
154
  # Serial download
191
- response = self._session.get(manifest_entry._download_url, stream=True)
192
155
  try:
193
- response.raise_for_status()
194
- except Exception:
156
+ response = self._session.get(url, stream=True)
157
+ except requests.HTTPError:
195
158
  # Signed URL might have expired, fall back to fetching it one by one.
196
159
  manifest_entry._download_url = None
160
+
197
161
  if manifest_entry._download_url is None:
198
162
  auth = None
199
- http_headers = _thread_local_api_settings.headers or {}
200
- if self._api.access_token is not None:
201
- http_headers["Authorization"] = f"Bearer {self._api.access_token}"
202
- elif _thread_local_api_settings.cookies is None:
163
+ headers = _thread_local_api_settings.headers
164
+ cookies = _thread_local_api_settings.cookies
165
+
166
+ # For auth, prefer using (in order): auth header, cookies, HTTP Basic Auth
167
+ if token := self._api.access_token:
168
+ headers = {**(headers or {}), "Authorization": f"Bearer {token}"}
169
+ elif cookies is not None:
170
+ pass
171
+ else:
203
172
  auth = ("api", self._api.api_key or "")
173
+
174
+ file_url = self._file_url(
175
+ self._api,
176
+ artifact.entity,
177
+ artifact.project,
178
+ artifact.name.split(":")[0],
179
+ manifest_entry,
180
+ )
204
181
  response = self._session.get(
205
- self._file_url(
206
- self._api,
207
- artifact.entity,
208
- artifact.project,
209
- artifact.name.split(":")[0],
210
- manifest_entry,
211
- ),
212
- auth=auth,
213
- cookies=_thread_local_api_settings.cookies,
214
- headers=http_headers,
215
- stream=True,
182
+ file_url, auth=auth, cookies=cookies, headers=headers, stream=True
216
183
  )
217
- response.raise_for_status()
218
184
 
219
185
  with cache_open(mode="wb") as file:
220
186
  for data in response.iter_content(chunk_size=16 * 1024):
@@ -269,12 +235,7 @@ class WandbStoragePolicy(StoragePolicy):
269
235
  # Other threads has error, no need to start
270
236
  if download_has_error.is_set():
271
237
  return
272
- response = self._session.get(
273
- url=download_url,
274
- headers=headers,
275
- stream=True,
276
- )
277
- response.raise_for_status()
238
+ response = self._session.get(url=download_url, headers=headers, stream=True)
278
239
 
279
240
  file_offset = start
280
241
  for content in response.iter_content(chunk_size=_HTTP_RES_CHUNK_SIZE_BYTES):
@@ -376,43 +337,27 @@ class WandbStoragePolicy(StoragePolicy):
376
337
  entity_name: str,
377
338
  project_name: str,
378
339
  artifact_name: str,
379
- manifest_entry: ArtifactManifestEntry,
340
+ entry: ArtifactManifestEntry,
380
341
  ) -> str:
381
- storage_layout = self._config.get("storageLayout", StorageLayout.V1)
382
- storage_region = self._config.get("storageRegion", "default")
383
- md5_hex = b64_to_hex_id(B64MD5(manifest_entry.digest))
342
+ layout = self._config.get("storageLayout", StorageLayout.V1)
343
+ region = self._config.get("storageRegion", "default")
344
+ md5_hex = b64_to_hex_id(entry.digest)
384
345
 
385
- if storage_layout == StorageLayout.V1:
386
- return "{}/artifacts/{}/{}".format(
387
- api.settings("base_url"), entity_name, md5_hex
388
- )
389
- elif storage_layout == StorageLayout.V2:
346
+ base_url: str = api.settings("base_url")
347
+
348
+ if layout == StorageLayout.V1:
349
+ return f"{base_url}/artifacts/{entity_name}/{md5_hex}"
350
+
351
+ if layout == StorageLayout.V2:
352
+ birth_artifact_id = entry.birth_artifact_id or ""
390
353
  if api._server_supports(
391
- ServerFeature.ARTIFACT_COLLECTION_MEMBERSHIP_FILE_DOWNLOAD_HANDLER # type: ignore
354
+ ServerFeature.ARTIFACT_COLLECTION_MEMBERSHIP_FILE_DOWNLOAD_HANDLER
392
355
  ):
393
- return "{}/artifactsV2/{}/{}/{}/{}/{}/{}/{}".format(
394
- api.settings("base_url"),
395
- storage_region,
396
- quote(entity_name),
397
- quote(project_name),
398
- quote(artifact_name),
399
- quote(manifest_entry.birth_artifact_id or ""),
400
- md5_hex,
401
- manifest_entry.path.name,
402
- )
403
- return "{}/artifactsV2/{}/{}/{}/{}".format(
404
- api.settings("base_url"),
405
- storage_region,
406
- entity_name,
407
- quote(
408
- manifest_entry.birth_artifact_id
409
- if manifest_entry.birth_artifact_id is not None
410
- else ""
411
- ),
412
- md5_hex,
413
- )
414
- else:
415
- raise Exception(f"unrecognized storage layout: {storage_layout}")
356
+ return f"{base_url}/artifactsV2/{region}/{quote(entity_name)}/{quote(project_name)}/{quote(artifact_name)}/{quote(birth_artifact_id)}/{md5_hex}/{entry.path.name}"
357
+
358
+ return f"{base_url}/artifactsV2/{region}/{entity_name}/{quote(birth_artifact_id)}/{md5_hex}"
359
+
360
+ raise ValueError(f"unrecognized storage layout: {layout!r}")
416
361
 
417
362
  def s3_multipart_file_upload(
418
363
  self,
@@ -486,7 +431,7 @@ class WandbStoragePolicy(StoragePolicy):
486
431
  True if the file was a duplicate (did not need to be uploaded),
487
432
  False if it needed to be uploaded or was a reference (nothing to dedupe).
488
433
  """
489
- file_size = entry.size if entry.size is not None else 0
434
+ file_size = entry.size or 0
490
435
  chunk_size = self.calc_chunk_size(file_size)
491
436
  upload_parts = []
492
437
  hex_digests = {}
@@ -562,8 +507,8 @@ class WandbStoragePolicy(StoragePolicy):
562
507
 
563
508
  # Cache upon successful upload.
564
509
  _, hit, cache_open = self._cache.check_md5_obj_path(
565
- B64MD5(entry.digest),
566
- entry.size if entry.size is not None else 0,
510
+ entry.digest,
511
+ size=entry.size or 0,
567
512
  )
568
513
 
569
514
  staging_dir = get_staging_dir()
@@ -5,6 +5,7 @@ import pathlib
5
5
  from typing import TYPE_CHECKING, Union
6
6
 
7
7
  from wandb import util
8
+ from wandb._strutils import nameof
8
9
  from wandb.sdk.lib import runid
9
10
 
10
11
  from . import _dtypes
@@ -34,7 +35,10 @@ class Bokeh(Media):
34
35
  ],
35
36
  ):
36
37
  super().__init__()
37
- bokeh = util.get_module("bokeh", required=True)
38
+ bokeh = util.get_module(
39
+ "bokeh",
40
+ required=f"{nameof(Bokeh)!r} requires the bokeh package. Please install it with `pip install bokeh`.",
41
+ )
38
42
  if isinstance(data_or_path, (str, pathlib.Path)) and os.path.exists(
39
43
  data_or_path
40
44
  ):
@@ -161,6 +161,17 @@ class Image(BatchableMedia):
161
161
  ) -> None:
162
162
  """Initialize a `wandb.Image` object.
163
163
 
164
+ This class handles various image data formats and automatically normalizes
165
+ pixel values to the range [0, 255] when needed, ensuring compatibility
166
+ with the W&B backend.
167
+
168
+ * Data in range [0, 1] is multiplied by 255 and converted to uint8
169
+ * Data in range [-1, 1] is rescaled from [-1, 1] to [0, 255] by mapping
170
+ -1 to 0 and 1 to 255, then converted to uint8
171
+ * Data outside [-1, 1] but not in [0, 255] is clipped to [0, 255] and
172
+ converted to uint8 (with a warning if values fall outside [0, 255])
173
+ * Data already in [0, 255] is converted to uint8 without modification
174
+
164
175
  Args:
165
176
  data_or_path: Accepts NumPy array/pytorch tensor of image data,
166
177
  a PIL image object, or a path to an image file. If a NumPy
@@ -168,7 +179,7 @@ class Image(BatchableMedia):
168
179
  the image data will be saved to the given file type.
169
180
  If the values are not in the range [0, 255] or all values are in the range [0, 1],
170
181
  the image pixel values will be normalized to the range [0, 255]
171
- unless `normalize` is set to False.
182
+ unless `normalize` is set to `False`.
172
183
  - pytorch tensor should be in the format (channel, height, width)
173
184
  - NumPy array should be in the format (height, width, channel)
174
185
  mode: The PIL mode for an image. Most common are "L", "RGB",
@@ -178,13 +189,13 @@ class Image(BatchableMedia):
178
189
  classes: A list of class information for the image,
179
190
  used for labeling bounding boxes, and image masks.
180
191
  boxes: A dictionary containing bounding box information for the image.
181
- see: https://docs.wandb.ai/ref/python/data-types/boundingboxes2d/
192
+ see https://docs.wandb.ai/ref/python/data-types/boundingboxes2d/
182
193
  masks: A dictionary containing mask information for the image.
183
- see: https://docs.wandb.ai/ref/python/data-types/imagemask/
194
+ see https://docs.wandb.ai/ref/python/data-types/imagemask/
184
195
  file_type: The file type to save the image as.
185
- This parameter has no effect if data_or_path is a path to an image file.
186
- normalize: If True, normalize the image pixel values to fall within the range of [0, 255].
187
- Normalize is only applied if data_or_path is a numpy array or pytorch tensor.
196
+ This parameter has no effect if `data_or_path` is a path to an image file.
197
+ normalize: If `True`, normalize the image pixel values to fall within the range of [0, 255].
198
+ Normalize is only applied if `data_or_path` is a numpy array or pytorch tensor.
188
199
 
189
200
  Examples:
190
201
  Create a wandb.Image from a numpy array
@@ -87,11 +87,38 @@ def file_enum_to_policy(enum: "pb.FilesItem.PolicyType.V") -> "PolicyName":
87
87
 
88
88
 
89
89
  class InterfaceBase:
90
+ """Methods for sending different types of Records to the service.
91
+
92
+ None of the methods may be called from an asyncio context other than
93
+ deliver_async().
94
+ """
95
+
90
96
  _drop: bool
91
97
 
92
98
  def __init__(self) -> None:
93
99
  self._drop = False
94
100
 
101
+ @abstractmethod
102
+ async def deliver_async(
103
+ self,
104
+ record: pb.Record,
105
+ ) -> MailboxHandle[pb.Result]:
106
+ """Send a record and create a handle to wait for the response.
107
+
108
+ The synchronous publish and deliver methods on this class cannot be
109
+ called in the asyncio thread because they block. Instead of having
110
+ an async copy of every method, this is a general method for sending
111
+ any kind of record in the asyncio thread.
112
+
113
+ Args:
114
+ record: The record to send. This method takes ownership of the
115
+ record and it must not be used afterward.
116
+
117
+ Returns:
118
+ A handle to wait for a response to the record.
119
+ """
120
+ raise NotImplementedError
121
+
95
122
  def publish_header(self) -> None:
96
123
  header = pb.HeaderRecord()
97
124
  self._publish_header(header)
@@ -392,9 +419,13 @@ class InterfaceBase:
392
419
  proto_manifest.manifest_file_path = path
393
420
  return proto_manifest
394
421
 
422
+ # Set storage policy on storageLayout (always V2) and storageRegion, only allow coreweave-us on wandb.ai for now.
423
+ # NOTE: the decode logic is NewManifestFromProto in core/pkg/artifacts/manifest.go
424
+ # The creation logic is in artifacts/_factories.py make_storage_policy
395
425
  for k, v in artifact_manifest.storage_policy.config().items() or {}.items():
396
426
  cfg = proto_manifest.storage_policy_config.add()
397
427
  cfg.key = k
428
+ # TODO: Why json.dumps when existing values are plain string? We want to send complex structure without defining the proto?
398
429
  cfg.value_json = json.dumps(v)
399
430
 
400
431
  for entry in sorted(artifact_manifest.entries.values(), key=lambda k: k.path):
@@ -883,6 +914,16 @@ class InterfaceBase:
883
914
  ) -> MailboxHandle[pb.Result]:
884
915
  raise NotImplementedError
885
916
 
917
+ def publish_probe_system_info(self) -> None:
918
+ probe_system_info = pb.ProbeSystemInfoRequest()
919
+ return self._publish_probe_system_info(probe_system_info)
920
+
921
+ @abstractmethod
922
+ def _publish_probe_system_info(
923
+ self, probe_system_info: pb.ProbeSystemInfoRequest
924
+ ) -> None:
925
+ raise NotImplementedError
926
+
886
927
  def join(self) -> None:
887
928
  # Drop indicates that the internal process has already been shutdown
888
929
  if self._drop:
@@ -1010,10 +1051,6 @@ class InterfaceBase:
1010
1051
  ) -> MailboxHandle[pb.Result]:
1011
1052
  raise NotImplementedError
1012
1053
 
1013
- @abstractmethod
1014
- def deliver_operation_stats(self) -> MailboxHandle[pb.Result]:
1015
- raise NotImplementedError
1016
-
1017
1054
  def deliver_poll_exit(self) -> MailboxHandle[pb.Result]:
1018
1055
  poll_exit = pb.PollExitRequest()
1019
1056
  return self._deliver_poll_exit(poll_exit)
@@ -8,12 +8,15 @@ import logging
8
8
  from multiprocessing.process import BaseProcess
9
9
  from typing import TYPE_CHECKING, Optional
10
10
 
11
+ from typing_extensions import override
12
+
11
13
  from .interface_shared import InterfaceShared
12
14
 
13
15
  if TYPE_CHECKING:
14
16
  from queue import Queue
15
17
 
16
18
  from wandb.proto import wandb_internal_pb2 as pb
19
+ from wandb.sdk.mailbox.mailbox_handle import MailboxHandle
17
20
 
18
21
 
19
22
  logger = logging.getLogger("wandb")
@@ -31,6 +34,13 @@ class InterfaceQueue(InterfaceShared):
31
34
  self._process = process
32
35
  super().__init__()
33
36
 
37
+ @override
38
+ async def deliver_async(
39
+ self,
40
+ record: "pb.Record",
41
+ ) -> "MailboxHandle[pb.Result]":
42
+ raise NotImplementedError
43
+
34
44
  def _publish(self, record: "pb.Record", local: Optional[bool] = None) -> None:
35
45
  if self._process and not self._process.is_alive():
36
46
  raise Exception("The wandb backend process has shutdown")
@@ -87,7 +87,6 @@ class InterfaceShared(InterfaceBase):
87
87
  stop_status: Optional[pb.StopStatusRequest] = None,
88
88
  internal_messages: Optional[pb.InternalMessagesRequest] = None,
89
89
  network_status: Optional[pb.NetworkStatusRequest] = None,
90
- operation_stats: Optional[pb.OperationStatsRequest] = None,
91
90
  poll_exit: Optional[pb.PollExitRequest] = None,
92
91
  partial_history: Optional[pb.PartialHistoryRequest] = None,
93
92
  sampled_history: Optional[pb.SampledHistoryRequest] = None,
@@ -112,6 +111,7 @@ class InterfaceShared(InterfaceBase):
112
111
  python_packages: Optional[pb.PythonPackagesRequest] = None,
113
112
  job_input: Optional[pb.JobInputRequest] = None,
114
113
  run_finish_without_exit: Optional[pb.RunFinishWithoutExitRequest] = None,
114
+ probe_system_info: Optional[pb.ProbeSystemInfoRequest] = None,
115
115
  ) -> pb.Record:
116
116
  request = pb.Request()
117
117
  if get_summary:
@@ -128,8 +128,6 @@ class InterfaceShared(InterfaceBase):
128
128
  request.internal_messages.CopyFrom(internal_messages)
129
129
  elif network_status:
130
130
  request.network_status.CopyFrom(network_status)
131
- elif operation_stats:
132
- request.operations.CopyFrom(operation_stats)
133
131
  elif poll_exit:
134
132
  request.poll_exit.CopyFrom(poll_exit)
135
133
  elif partial_history:
@@ -178,6 +176,8 @@ class InterfaceShared(InterfaceBase):
178
176
  request.job_input.CopyFrom(job_input)
179
177
  elif run_finish_without_exit:
180
178
  request.run_finish_without_exit.CopyFrom(run_finish_without_exit)
179
+ elif probe_system_info:
180
+ request.probe_system_info.CopyFrom(probe_system_info)
181
181
  else:
182
182
  raise Exception("Invalid request")
183
183
  record = self._make_record(request=request)
@@ -330,6 +330,12 @@ class InterfaceShared(InterfaceBase):
330
330
  rec = self._make_record(use_artifact=use_artifact)
331
331
  self._publish(rec)
332
332
 
333
+ def _publish_probe_system_info(
334
+ self, probe_system_info: pb.ProbeSystemInfoRequest
335
+ ) -> None:
336
+ record = self._make_request(probe_system_info=probe_system_info)
337
+ self._publish(record)
338
+
333
339
  def _deliver_artifact(
334
340
  self,
335
341
  log_artifact: pb.LogArtifactRequest,
@@ -415,10 +421,6 @@ class InterfaceShared(InterfaceBase):
415
421
  record = self._make_record(exit=exit_data)
416
422
  return self._deliver(record)
417
423
 
418
- def deliver_operation_stats(self):
419
- record = self._make_request(operation_stats=pb.OperationStatsRequest())
420
- return self._deliver(record)
421
-
422
424
  def _deliver_poll_exit(
423
425
  self,
424
426
  poll_exit: pb.PollExitRequest,
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any
6
6
  from typing_extensions import override
7
7
 
8
8
  from wandb.proto import wandb_server_pb2 as spb
9
+ from wandb.sdk.lib import asyncio_manager
9
10
 
10
11
  from .interface_shared import InterfaceShared
11
12
 
@@ -21,10 +22,12 @@ logger = logging.getLogger("wandb")
21
22
  class InterfaceSock(InterfaceShared):
22
23
  def __init__(
23
24
  self,
25
+ asyncer: asyncio_manager.AsyncioManager,
24
26
  client: ServiceClient,
25
27
  stream_id: str,
26
28
  ) -> None:
27
29
  super().__init__()
30
+ self._asyncer = asyncer
28
31
  self._client = client
29
32
  self._stream_id = stream_id
30
33
 
@@ -37,13 +40,16 @@ class InterfaceSock(InterfaceShared):
37
40
  self._assign(record)
38
41
  request = spb.ServerRequest()
39
42
  request.record_publish.CopyFrom(record)
40
- self._client.publish(request)
43
+ self._asyncer.run(lambda: self._client.publish(request))
41
44
 
42
- @override
43
45
  def _deliver(self, record: pb.Record) -> MailboxHandle[pb.Result]:
46
+ return self._asyncer.run(lambda: self.deliver_async(record))
47
+
48
+ @override
49
+ async def deliver_async(self, record: pb.Record) -> MailboxHandle[pb.Result]:
44
50
  self._assign(record)
45
51
  request = spb.ServerRequest()
46
52
  request.record_publish.CopyFrom(record)
47
53
 
48
- handle = self._client.deliver(request)
54
+ handle = await self._client.deliver(request)
49
55
  return handle.map(lambda response: response.result_communicate)
@@ -1,15 +1,5 @@
1
1
  # Generated by ariadne-codegen
2
2
 
3
+ __all__ = ["SERVER_FEATURES_QUERY_GQL", "ServerFeaturesQuery"]
3
4
  from .operations import SERVER_FEATURES_QUERY_GQL
4
- from .server_features_query import (
5
- ServerFeaturesQuery,
6
- ServerFeaturesQueryServerInfo,
7
- ServerFeaturesQueryServerInfoFeatures,
8
- )
9
-
10
- __all__ = [
11
- "SERVER_FEATURES_QUERY_GQL",
12
- "ServerFeaturesQuery",
13
- "ServerFeaturesQueryServerInfo",
14
- "ServerFeaturesQueryServerInfoFeatures",
15
- ]
5
+ from .server_features_query import ServerFeaturesQuery
@@ -343,7 +343,7 @@ class SendManager:
343
343
  publish_interface = InterfaceQueue(record_q=record_q)
344
344
  context_keeper = context.ContextKeeper()
345
345
  return SendManager(
346
- settings=SettingsStatic(settings.to_proto()),
346
+ settings=SettingsStatic(dict(settings)),
347
347
  record_q=record_q,
348
348
  result_q=result_q,
349
349
  interface=publish_interface,