wandb 0.21.1__py3-none-win_amd64.whl → 0.21.2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +1 -1
  3. wandb/apis/public/api.py +1 -2
  4. wandb/apis/public/artifacts.py +3 -5
  5. wandb/apis/public/registries/_utils.py +14 -16
  6. wandb/apis/public/registries/registries_search.py +176 -289
  7. wandb/apis/public/reports.py +13 -10
  8. wandb/automations/_generated/delete_automation.py +1 -3
  9. wandb/automations/_generated/enums.py +13 -11
  10. wandb/bin/gpu_stats.exe +0 -0
  11. wandb/bin/wandb-core +0 -0
  12. wandb/cli/cli.py +47 -2
  13. wandb/integration/metaflow/data_pandas.py +2 -2
  14. wandb/integration/metaflow/data_pytorch.py +75 -0
  15. wandb/integration/metaflow/data_sklearn.py +76 -0
  16. wandb/integration/metaflow/metaflow.py +16 -87
  17. wandb/integration/weave/__init__.py +6 -0
  18. wandb/integration/weave/interface.py +49 -0
  19. wandb/integration/weave/weave.py +63 -0
  20. wandb/proto/v3/wandb_internal_pb2.py +3 -2
  21. wandb/proto/v4/wandb_internal_pb2.py +2 -2
  22. wandb/proto/v5/wandb_internal_pb2.py +2 -2
  23. wandb/proto/v6/wandb_internal_pb2.py +2 -2
  24. wandb/sdk/artifacts/_factories.py +17 -0
  25. wandb/sdk/artifacts/_generated/__init__.py +221 -13
  26. wandb/sdk/artifacts/_generated/artifact_by_id.py +17 -0
  27. wandb/sdk/artifacts/_generated/artifact_by_name.py +22 -0
  28. wandb/sdk/artifacts/_generated/artifact_collection_membership_file_urls.py +43 -0
  29. wandb/sdk/artifacts/_generated/artifact_created_by.py +47 -0
  30. wandb/sdk/artifacts/_generated/artifact_file_urls.py +22 -0
  31. wandb/sdk/artifacts/_generated/artifact_type.py +31 -0
  32. wandb/sdk/artifacts/_generated/artifact_used_by.py +43 -0
  33. wandb/sdk/artifacts/_generated/artifact_via_membership_by_name.py +26 -0
  34. wandb/sdk/artifacts/_generated/delete_artifact.py +28 -0
  35. wandb/sdk/artifacts/_generated/enums.py +5 -0
  36. wandb/sdk/artifacts/_generated/fetch_artifact_manifest.py +38 -0
  37. wandb/sdk/artifacts/_generated/fetch_registries.py +32 -0
  38. wandb/sdk/artifacts/_generated/fragments.py +279 -41
  39. wandb/sdk/artifacts/_generated/link_artifact.py +6 -0
  40. wandb/sdk/artifacts/_generated/operations.py +654 -51
  41. wandb/sdk/artifacts/_generated/registry_collections.py +34 -0
  42. wandb/sdk/artifacts/_generated/registry_versions.py +34 -0
  43. wandb/sdk/artifacts/_generated/unlink_artifact.py +25 -0
  44. wandb/sdk/artifacts/_graphql_fragments.py +3 -86
  45. wandb/sdk/artifacts/_validators.py +6 -4
  46. wandb/sdk/artifacts/artifact.py +406 -543
  47. wandb/sdk/artifacts/artifact_file_cache.py +10 -6
  48. wandb/sdk/artifacts/artifact_manifest.py +10 -9
  49. wandb/sdk/artifacts/artifact_manifest_entry.py +9 -10
  50. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +5 -3
  51. wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -1
  52. wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
  53. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +1 -1
  54. wandb/sdk/data_types/video.py +2 -2
  55. wandb/sdk/interface/interface_queue.py +1 -4
  56. wandb/sdk/interface/interface_shared.py +26 -37
  57. wandb/sdk/interface/interface_sock.py +24 -14
  58. wandb/sdk/internal/settings_static.py +2 -3
  59. wandb/sdk/launch/create_job.py +12 -1
  60. wandb/sdk/launch/runner/kubernetes_runner.py +24 -29
  61. wandb/sdk/lib/asyncio_compat.py +16 -16
  62. wandb/sdk/lib/asyncio_manager.py +252 -0
  63. wandb/sdk/lib/hashutil.py +13 -4
  64. wandb/sdk/lib/printer.py +2 -2
  65. wandb/sdk/lib/printer_asyncio.py +3 -1
  66. wandb/sdk/lib/retry.py +185 -78
  67. wandb/sdk/lib/service/service_client.py +106 -0
  68. wandb/sdk/lib/service/service_connection.py +20 -26
  69. wandb/sdk/lib/service/service_token.py +30 -13
  70. wandb/sdk/mailbox/mailbox.py +13 -5
  71. wandb/sdk/mailbox/mailbox_handle.py +22 -13
  72. wandb/sdk/mailbox/response_handle.py +42 -106
  73. wandb/sdk/mailbox/wait_with_progress.py +7 -42
  74. wandb/sdk/wandb_init.py +11 -25
  75. wandb/sdk/wandb_login.py +1 -1
  76. wandb/sdk/wandb_run.py +91 -55
  77. wandb/sdk/wandb_settings.py +45 -32
  78. wandb/sdk/wandb_setup.py +176 -96
  79. wandb/util.py +1 -1
  80. {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/METADATA +1 -1
  81. {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/RECORD +84 -68
  82. wandb/sdk/interface/interface_relay.py +0 -38
  83. wandb/sdk/interface/router.py +0 -89
  84. wandb/sdk/interface/router_queue.py +0 -43
  85. wandb/sdk/interface/router_relay.py +0 -50
  86. wandb/sdk/interface/router_sock.py +0 -32
  87. wandb/sdk/lib/sock_client.py +0 -232
  88. {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/WHEEL +0 -0
  89. {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/entry_points.txt +0 -0
  90. {wandb-0.21.1.dist-info → wandb-0.21.2.dist-info}/licenses/LICENSE +0 -0
@@ -9,6 +9,7 @@ import os
9
9
  import shutil
10
10
  import subprocess
11
11
  import sys
12
+ from functools import lru_cache
12
13
  from pathlib import Path
13
14
  from tempfile import NamedTemporaryFile
14
15
  from typing import IO, ContextManager, Iterator, Protocol
@@ -236,12 +237,15 @@ class ArtifactFileCache:
236
237
  ) from e
237
238
 
238
239
 
239
- _artifact_file_cache: ArtifactFileCache | None = None
240
+ # Memo `ArtifactFileCache` instances while avoiding reliance on global
241
+ # variable(s). Notes:
242
+ # - @lru_cache should be thread-safe.
243
+ # - We don't memoize `get_artifact_file_cache` directly, as the cache_dir
244
+ # may change at runtime. This is likely rare in practice, though.
245
+ @lru_cache(maxsize=1)
246
+ def _build_artifact_file_cache(cache_dir: StrPath) -> ArtifactFileCache:
247
+ return ArtifactFileCache(cache_dir)
240
248
 
241
249
 
242
250
  def get_artifact_file_cache() -> ArtifactFileCache:
243
- global _artifact_file_cache
244
- cache_dir = env.get_cache_dir() / "artifacts"
245
- if _artifact_file_cache is None or _artifact_file_cache._cache_dir != cache_dir:
246
- _artifact_file_cache = ArtifactFileCache(cache_dir)
247
- return _artifact_file_cache
251
+ return _build_artifact_file_cache(env.get_cache_dir() / "artifacts")
@@ -50,10 +50,12 @@ class ArtifactManifest:
50
50
 
51
51
  def add_entry(self, entry: ArtifactManifestEntry, overwrite: bool = False) -> None:
52
52
  path = entry.path
53
- if not overwrite:
54
- prev_entry = self.entries.get(path)
55
- if prev_entry and (entry.digest != prev_entry.digest):
56
- raise ValueError(f"Cannot add the same path twice: {path!r}")
53
+ if (
54
+ (not overwrite)
55
+ and (old_entry := self.entries.get(path))
56
+ and (entry.digest != old_entry.digest)
57
+ ):
58
+ raise ValueError(f"Cannot add the same path twice: {path!r}")
57
59
  self.entries[path] = entry
58
60
 
59
61
  def remove_entry(self, entry: ArtifactManifestEntry) -> None:
@@ -67,9 +69,8 @@ class ArtifactManifest:
67
69
 
68
70
  def get_entries_in_directory(self, directory: str) -> list[ArtifactManifestEntry]:
69
71
  return [
70
- self.entries[entry_key]
71
- for entry_key in self.entries
72
- if entry_key.startswith(
73
- directory + "/"
74
- ) # entries use forward slash even for windows
72
+ entry
73
+ for key, entry in self.entries.items()
74
+ # entry keys (paths) use forward slash even for windows
75
+ if key.startswith(f"{directory}/")
75
76
  ]
@@ -40,6 +40,9 @@ if TYPE_CHECKING:
40
40
  local_path: str
41
41
 
42
42
 
43
+ _WB_ARTIFACT_SCHEME = "wandb-artifact"
44
+
45
+
43
46
  class ArtifactManifestEntry:
44
47
  """A single entry in an artifact manifest."""
45
48
 
@@ -221,15 +224,11 @@ class ArtifactManifestEntry:
221
224
  derived_artifact.add_reference(ref_url)
222
225
  ```
223
226
  """
224
- if self._parent_artifact is None:
225
- raise NotImplementedError
226
- assert self._parent_artifact.id is not None
227
- return (
228
- "wandb-artifact://"
229
- + b64_to_hex_id(B64MD5(self._parent_artifact.id))
230
- + "/"
231
- + self.path
232
- )
227
+ if (parent_artifact := self.parent_artifact()) is None:
228
+ raise ValueError("Parent artifact is not set")
229
+ elif (parent_id := parent_artifact.id) is None:
230
+ raise ValueError("Parent artifact ID is not set")
231
+ return f"{_WB_ARTIFACT_SCHEME}://{b64_to_hex_id(B64MD5(parent_id))}/{self.path}"
233
232
 
234
233
  def to_json(self) -> ArtifactManifestEntryDict:
235
234
  contents: ArtifactManifestEntryDict = {
@@ -251,7 +250,7 @@ class ArtifactManifestEntry:
251
250
  return contents
252
251
 
253
252
  def _is_artifact_reference(self) -> bool:
254
- return self.ref is not None and urlparse(self.ref).scheme == "wandb-artifact"
253
+ return self.ref is not None and urlparse(self.ref).scheme == _WB_ARTIFACT_SCHEME
255
254
 
256
255
  def _referenced_artifact_id(self) -> str | None:
257
256
  if not self._is_artifact_reference():
@@ -2,6 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ from operator import itemgetter
5
6
  from typing import Any, Mapping
6
7
 
7
8
  from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
@@ -64,7 +65,7 @@ class ArtifactManifestV1(ArtifactManifest):
64
65
  contents.
65
66
  """
66
67
  contents = {}
67
- for entry in sorted(self.entries.values(), key=lambda k: k.path):
68
+ for name, entry in sorted(self.entries.items(), key=itemgetter(0)):
68
69
  json_entry: dict[str, Any] = {
69
70
  "digest": entry.digest,
70
71
  }
@@ -76,7 +77,7 @@ class ArtifactManifestV1(ArtifactManifest):
76
77
  json_entry["extra"] = entry.extra
77
78
  if entry.size is not None:
78
79
  json_entry["size"] = entry.size
79
- contents[entry.path] = json_entry
80
+ contents[name] = json_entry
80
81
  return {
81
82
  "version": self.__class__.version(),
82
83
  "storagePolicy": self.storage_policy.name(),
@@ -87,6 +88,7 @@ class ArtifactManifestV1(ArtifactManifest):
87
88
  def digest(self) -> HexMD5:
88
89
  hasher = _md5()
89
90
  hasher.update(b"wandb-artifact-manifest-v1\n")
90
- for name, entry in sorted(self.entries.items(), key=lambda kv: kv[0]):
91
+ # sort by key (path)
92
+ for name, entry in sorted(self.entries.items(), key=itemgetter(0)):
91
93
  hasher.update(f"{name}:{entry.digest}\n".encode())
92
94
  return HexMD5(hasher.hexdigest())
@@ -42,7 +42,7 @@ class HTTPHandler(StorageHandler):
42
42
 
43
43
  path, hit, cache_open = self._cache.check_etag_obj_path(
44
44
  URIStr(manifest_entry.ref),
45
- ETag(manifest_entry.digest), # TODO(spencerpearson): unsafe cast
45
+ ETag(manifest_entry.digest),
46
46
  manifest_entry.size if manifest_entry.size is not None else 0,
47
47
  )
48
48
  if hit:
@@ -94,7 +94,7 @@ class S3Handler(StorageHandler):
94
94
 
95
95
  path, hit, cache_open = self._cache.check_etag_obj_path(
96
96
  URIStr(manifest_entry.ref),
97
- ETag(manifest_entry.digest), # TODO(spencerpearson): unsafe cast
97
+ ETag(manifest_entry.digest),
98
98
  manifest_entry.size if manifest_entry.size is not None else 0,
99
99
  )
100
100
  if hit:
@@ -167,7 +167,7 @@ class WandbStoragePolicy(StoragePolicy):
167
167
  self._cache._override_cache_path = dest_path
168
168
 
169
169
  path, hit, cache_open = self._cache.check_md5_obj_path(
170
- B64MD5(manifest_entry.digest), # TODO(spencerpearson): unsafe cast
170
+ B64MD5(manifest_entry.digest),
171
171
  manifest_entry.size if manifest_entry.size is not None else 0,
172
172
  )
173
173
  if hit:
@@ -25,8 +25,8 @@ if TYPE_CHECKING: # pragma: no cover
25
25
 
26
26
 
27
27
  def _should_print_spinner() -> bool:
28
- singleton = wandb_setup.singleton_if_setup()
29
- if singleton and (singleton.settings.quiet or singleton.settings.silent):
28
+ settings = wandb_setup.singleton().settings_if_loaded
29
+ if settings and (settings.quiet or settings.silent):
30
30
  return False
31
31
 
32
32
  return not env.is_quiet() and not env.is_silent()
@@ -8,8 +8,6 @@ import logging
8
8
  from multiprocessing.process import BaseProcess
9
9
  from typing import TYPE_CHECKING, Optional
10
10
 
11
- from wandb.sdk.mailbox import Mailbox
12
-
13
11
  from .interface_shared import InterfaceShared
14
12
 
15
13
  if TYPE_CHECKING:
@@ -27,12 +25,11 @@ class InterfaceQueue(InterfaceShared):
27
25
  record_q: Optional["Queue[pb.Record]"] = None,
28
26
  result_q: Optional["Queue[pb.Result]"] = None,
29
27
  process: Optional[BaseProcess] = None,
30
- mailbox: Optional[Mailbox] = None,
31
28
  ) -> None:
32
29
  self.record_q = record_q
33
30
  self.result_q = result_q
34
31
  self._process = process
35
- super().__init__(mailbox=mailbox)
32
+ super().__init__()
36
33
 
37
34
  def _publish(self, record: "pb.Record", local: Optional[bool] = None) -> None:
38
35
  if self._process and not self._process.is_alive():
@@ -10,7 +10,7 @@ from typing import Any, Optional, cast
10
10
 
11
11
  from wandb.proto import wandb_internal_pb2 as pb
12
12
  from wandb.proto import wandb_telemetry_pb2 as tpb
13
- from wandb.sdk.mailbox import Mailbox, MailboxHandle
13
+ from wandb.sdk.mailbox import MailboxHandle
14
14
  from wandb.util import json_dumps_safer, json_friendly
15
15
 
16
16
  from .interface import InterfaceBase
@@ -19,9 +19,8 @@ logger = logging.getLogger("wandb")
19
19
 
20
20
 
21
21
  class InterfaceShared(InterfaceBase):
22
- def __init__(self, mailbox: Optional[Mailbox] = None) -> None:
22
+ def __init__(self) -> None:
23
23
  super().__init__()
24
- self._mailbox = mailbox
25
24
 
26
25
  def _publish_output(self, outdata: pb.OutputRecord) -> None:
27
26
  rec = pb.Record()
@@ -67,7 +66,7 @@ class InterfaceShared(InterfaceBase):
67
66
  self, job_input: pb.JobInputRequest
68
67
  ) -> MailboxHandle[pb.Result]:
69
68
  record = self._make_request(job_input=job_input)
70
- return self._deliver_record(record)
69
+ return self._deliver(record)
71
70
 
72
71
  def _make_stats(self, stats_dict: dict) -> pb.StatsRecord:
73
72
  stats = pb.StatsRecord()
@@ -263,6 +262,9 @@ class InterfaceShared(InterfaceBase):
263
262
  def _publish(self, record: pb.Record, local: Optional[bool] = None) -> None:
264
263
  raise NotImplementedError
265
264
 
265
+ def _deliver(self, record: pb.Record) -> "MailboxHandle[pb.Result]":
266
+ raise NotImplementedError
267
+
266
268
  def _publish_defer(self, state: "pb.DeferRequest.DeferState.V") -> None:
267
269
  defer = pb.DeferRequest(state=state)
268
270
  rec = self._make_request(defer=defer)
@@ -333,19 +335,19 @@ class InterfaceShared(InterfaceBase):
333
335
  log_artifact: pb.LogArtifactRequest,
334
336
  ) -> MailboxHandle[pb.Result]:
335
337
  rec = self._make_request(log_artifact=log_artifact)
336
- return self._deliver_record(rec)
338
+ return self._deliver(rec)
337
339
 
338
340
  def _deliver_download_artifact(
339
341
  self, download_artifact: pb.DownloadArtifactRequest
340
342
  ) -> MailboxHandle[pb.Result]:
341
343
  rec = self._make_request(download_artifact=download_artifact)
342
- return self._deliver_record(rec)
344
+ return self._deliver(rec)
343
345
 
344
346
  def _deliver_link_artifact(
345
347
  self, link_artifact: pb.LinkArtifactRequest
346
348
  ) -> MailboxHandle[pb.Result]:
347
349
  rec = self._make_request(link_artifact=link_artifact)
348
- return self._deliver_record(rec)
350
+ return self._deliver(rec)
349
351
 
350
352
  def _publish_artifact(self, proto_artifact: pb.ArtifactRecord) -> None:
351
353
  rec = self._make_record(artifact=proto_artifact)
@@ -360,7 +362,7 @@ class InterfaceShared(InterfaceBase):
360
362
  status: pb.StatusRequest,
361
363
  ) -> MailboxHandle[pb.Result]:
362
364
  req = self._make_request(status=status)
363
- return self._deliver_record(req)
365
+ return self._deliver(req)
364
366
 
365
367
  def _publish_exit(self, exit_data: pb.RunExitRecord) -> None:
366
368
  rec = self._make_record(exit=exit_data)
@@ -373,110 +375,97 @@ class InterfaceShared(InterfaceBase):
373
375
  def _deliver_shutdown(self) -> MailboxHandle[pb.Result]:
374
376
  request = pb.Request(shutdown=pb.ShutdownRequest())
375
377
  record = self._make_record(request=request)
376
- return self._deliver_record(record)
377
-
378
- def _get_mailbox(self) -> Mailbox:
379
- mailbox = self._mailbox
380
- assert mailbox
381
- return mailbox
382
-
383
- def _deliver_record(self, record: pb.Record) -> MailboxHandle[pb.Result]:
384
- mailbox = self._get_mailbox()
385
-
386
- handle = mailbox.require_response(record)
387
- self._publish(record)
388
-
389
- return handle.map(lambda resp: resp.result_communicate)
378
+ return self._deliver(record)
390
379
 
391
380
  def _deliver_run(self, run: pb.RunRecord) -> MailboxHandle[pb.Result]:
392
381
  record = self._make_record(run=run)
393
- return self._deliver_record(record)
382
+ return self._deliver(record)
394
383
 
395
384
  def _deliver_finish_sync(
396
385
  self,
397
386
  sync_finish: pb.SyncFinishRequest,
398
387
  ) -> MailboxHandle[pb.Result]:
399
388
  record = self._make_request(sync_finish=sync_finish)
400
- return self._deliver_record(record)
389
+ return self._deliver(record)
401
390
 
402
391
  def _deliver_run_start(
403
392
  self,
404
393
  run_start: pb.RunStartRequest,
405
394
  ) -> MailboxHandle[pb.Result]:
406
395
  record = self._make_request(run_start=run_start)
407
- return self._deliver_record(record)
396
+ return self._deliver(record)
408
397
 
409
398
  def _deliver_get_summary(
410
399
  self,
411
400
  get_summary: pb.GetSummaryRequest,
412
401
  ) -> MailboxHandle[pb.Result]:
413
402
  record = self._make_request(get_summary=get_summary)
414
- return self._deliver_record(record)
403
+ return self._deliver(record)
415
404
 
416
405
  def _deliver_get_system_metrics(
417
406
  self, get_system_metrics: pb.GetSystemMetricsRequest
418
407
  ) -> MailboxHandle[pb.Result]:
419
408
  record = self._make_request(get_system_metrics=get_system_metrics)
420
- return self._deliver_record(record)
409
+ return self._deliver(record)
421
410
 
422
411
  def _deliver_exit(
423
412
  self,
424
413
  exit_data: pb.RunExitRecord,
425
414
  ) -> MailboxHandle[pb.Result]:
426
415
  record = self._make_record(exit=exit_data)
427
- return self._deliver_record(record)
416
+ return self._deliver(record)
428
417
 
429
418
  def deliver_operation_stats(self):
430
419
  record = self._make_request(operation_stats=pb.OperationStatsRequest())
431
- return self._deliver_record(record)
420
+ return self._deliver(record)
432
421
 
433
422
  def _deliver_poll_exit(
434
423
  self,
435
424
  poll_exit: pb.PollExitRequest,
436
425
  ) -> MailboxHandle[pb.Result]:
437
426
  record = self._make_request(poll_exit=poll_exit)
438
- return self._deliver_record(record)
427
+ return self._deliver(record)
439
428
 
440
429
  def _deliver_finish_without_exit(
441
430
  self, run_finish_without_exit: pb.RunFinishWithoutExitRequest
442
431
  ) -> MailboxHandle[pb.Result]:
443
432
  record = self._make_request(run_finish_without_exit=run_finish_without_exit)
444
- return self._deliver_record(record)
433
+ return self._deliver(record)
445
434
 
446
435
  def _deliver_stop_status(
447
436
  self,
448
437
  stop_status: pb.StopStatusRequest,
449
438
  ) -> MailboxHandle[pb.Result]:
450
439
  record = self._make_request(stop_status=stop_status)
451
- return self._deliver_record(record)
440
+ return self._deliver(record)
452
441
 
453
442
  def _deliver_attach(
454
443
  self,
455
444
  attach: pb.AttachRequest,
456
445
  ) -> MailboxHandle[pb.Result]:
457
446
  record = self._make_request(attach=attach)
458
- return self._deliver_record(record)
447
+ return self._deliver(record)
459
448
 
460
449
  def _deliver_network_status(
461
450
  self, network_status: pb.NetworkStatusRequest
462
451
  ) -> MailboxHandle[pb.Result]:
463
452
  record = self._make_request(network_status=network_status)
464
- return self._deliver_record(record)
453
+ return self._deliver(record)
465
454
 
466
455
  def _deliver_internal_messages(
467
456
  self, internal_message: pb.InternalMessagesRequest
468
457
  ) -> MailboxHandle[pb.Result]:
469
458
  record = self._make_request(internal_messages=internal_message)
470
- return self._deliver_record(record)
459
+ return self._deliver(record)
471
460
 
472
461
  def _deliver_request_sampled_history(
473
462
  self, sampled_history: pb.SampledHistoryRequest
474
463
  ) -> MailboxHandle[pb.Result]:
475
464
  record = self._make_request(sampled_history=sampled_history)
476
- return self._deliver_record(record)
465
+ return self._deliver(record)
477
466
 
478
467
  def _deliver_request_run_status(
479
468
  self, run_status: pb.RunStatusRequest
480
469
  ) -> MailboxHandle[pb.Result]:
481
470
  record = self._make_request(run_status=run_status)
482
- return self._deliver_record(record)
471
+ return self._deliver(record)
@@ -1,19 +1,18 @@
1
- """InterfaceSock - Derived from InterfaceShared using a socket to send to internal thread.
2
-
3
- See interface.py for how interface classes relate to each other.
4
-
5
- """
1
+ from __future__ import annotations
6
2
 
7
3
  import logging
8
- from typing import TYPE_CHECKING, Any, Optional
4
+ from typing import TYPE_CHECKING, Any
5
+
6
+ from typing_extensions import override
9
7
 
10
- from wandb.sdk.mailbox import Mailbox
8
+ from wandb.proto import wandb_server_pb2 as spb
11
9
 
12
- from ..lib.sock_client import SockClient
13
10
  from .interface_shared import InterfaceShared
14
11
 
15
12
  if TYPE_CHECKING:
16
13
  from wandb.proto import wandb_internal_pb2 as pb
14
+ from wandb.sdk.lib.service.service_client import ServiceClient
15
+ from wandb.sdk.mailbox import MailboxHandle
17
16
 
18
17
 
19
18
  logger = logging.getLogger("wandb")
@@ -22,18 +21,29 @@ logger = logging.getLogger("wandb")
22
21
  class InterfaceSock(InterfaceShared):
23
22
  def __init__(
24
23
  self,
25
- sock_client: SockClient,
26
- mailbox: Mailbox,
24
+ client: ServiceClient,
27
25
  stream_id: str,
28
26
  ) -> None:
29
- super().__init__(mailbox=mailbox)
30
- self._sock_client = sock_client
27
+ super().__init__()
28
+ self._client = client
31
29
  self._stream_id = stream_id
32
30
 
33
31
  def _assign(self, record: Any) -> None:
34
32
  assert self._stream_id
35
33
  record._info.stream_id = self._stream_id
36
34
 
37
- def _publish(self, record: "pb.Record", local: Optional[bool] = None) -> None:
35
+ @override
36
+ def _publish(self, record: pb.Record, local: bool | None = None) -> None:
38
37
  self._assign(record)
39
- self._sock_client.send_record_publish(record)
38
+ request = spb.ServerRequest()
39
+ request.record_publish.CopyFrom(record)
40
+ self._client.publish(request)
41
+
42
+ @override
43
+ def _deliver(self, record: pb.Record) -> MailboxHandle[pb.Result]:
44
+ self._assign(record)
45
+ request = spb.ServerRequest()
46
+ request.record_publish.CopyFrom(record)
47
+
48
+ handle = self._client.deliver(request)
49
+ return handle.map(lambda response: response.result_communicate)
@@ -4,7 +4,7 @@ from typing import Any, Iterable
4
4
 
5
5
  from wandb.proto import wandb_settings_pb2
6
6
  from wandb.sdk.lib import RunMoment
7
- from wandb.sdk.wandb_settings import Settings
7
+ from wandb.sdk.wandb_settings import CLIENT_ONLY_SETTINGS, Settings
8
8
 
9
9
 
10
10
  class SettingsStatic(Settings):
@@ -41,8 +41,7 @@ class SettingsStatic(Settings):
41
41
 
42
42
  forks_specified: list[str] = []
43
43
  for key in fields:
44
- # Skip Python-only keys that do not exist on the proto.
45
- if key in ("reinit",):
44
+ if key in CLIENT_ONLY_SETTINGS:
46
45
  continue
47
46
 
48
47
  value: Any = None
@@ -11,6 +11,7 @@ from wandb.apis.internal import Api
11
11
  from wandb.sdk.artifacts.artifact import Artifact
12
12
  from wandb.sdk.internal.job_builder import JobBuilder
13
13
  from wandb.sdk.launch.git_reference import GitReference
14
+ from wandb.sdk.launch.inputs.internal import _validate_schema
14
15
  from wandb.sdk.launch.utils import (
15
16
  _is_git_uri,
16
17
  get_current_python_version,
@@ -116,6 +117,7 @@ def _create_job(
116
117
  dockerfile: Optional[str] = None,
117
118
  base_image: Optional[str] = None,
118
119
  services: Optional[Dict[str, str]] = None,
120
+ schema: Optional[Dict[str, Any]] = None,
119
121
  ) -> Tuple[Optional[Artifact], str, List[str]]:
120
122
  wandb.termlog(f"Creating launch job of type: {job_type}...")
121
123
 
@@ -206,6 +208,15 @@ def _create_job(
206
208
  if "latest" not in aliases:
207
209
  aliases += ["latest"]
208
210
 
211
+ metadata = {"_partial": True}
212
+ if schema:
213
+ _validate_schema(schema)
214
+ metadata = {
215
+ "input_schemas": {
216
+ "@wandb.config": schema,
217
+ }
218
+ }
219
+
209
220
  res, _ = api.create_artifact(
210
221
  artifact_type_name="job",
211
222
  artifact_collection_name=name,
@@ -216,7 +227,7 @@ def _create_job(
216
227
  project_name=project,
217
228
  run_name=run.id, # type: ignore # run will be deleted after creation
218
229
  description=description,
219
- metadata={"_partial": True},
230
+ metadata=metadata,
220
231
  is_user_created=True,
221
232
  aliases=[{"artifactCollectionName": name, "alias": a} for a in aliases],
222
233
  )