wandb 0.19.7__py3-none-macosx_11_0_arm64.whl → 0.19.9__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. wandb/__init__.py +5 -1
  2. wandb/__init__.pyi +43 -9
  3. wandb/_pydantic/__init__.py +23 -0
  4. wandb/_pydantic/base.py +113 -0
  5. wandb/_pydantic/v1_compat.py +262 -0
  6. wandb/apis/paginator.py +82 -38
  7. wandb/apis/public/api.py +10 -64
  8. wandb/apis/public/artifacts.py +73 -17
  9. wandb/apis/public/files.py +2 -2
  10. wandb/apis/public/projects.py +3 -2
  11. wandb/apis/public/reports.py +2 -2
  12. wandb/apis/public/runs.py +19 -11
  13. wandb/bin/gpu_stats +0 -0
  14. wandb/bin/wandb-core +0 -0
  15. wandb/data_types.py +1 -1
  16. wandb/filesync/dir_watcher.py +2 -1
  17. wandb/integration/metaflow/metaflow.py +19 -17
  18. wandb/integration/sacred/__init__.py +1 -1
  19. wandb/jupyter.py +18 -15
  20. wandb/proto/v3/wandb_internal_pb2.py +7 -3
  21. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  22. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  23. wandb/proto/v4/wandb_internal_pb2.py +3 -3
  24. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  25. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  26. wandb/proto/v5/wandb_internal_pb2.py +3 -3
  27. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  28. wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
  29. wandb/proto/wandb_deprecated.py +2 -0
  30. wandb/sdk/artifacts/_graphql_fragments.py +18 -20
  31. wandb/sdk/artifacts/_validators.py +1 -0
  32. wandb/sdk/artifacts/artifact.py +81 -46
  33. wandb/sdk/artifacts/artifact_saver.py +16 -2
  34. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +23 -2
  35. wandb/sdk/backend/backend.py +16 -5
  36. wandb/sdk/data_types/audio.py +1 -3
  37. wandb/sdk/data_types/base_types/media.py +11 -4
  38. wandb/sdk/data_types/image.py +44 -25
  39. wandb/sdk/data_types/molecule.py +1 -5
  40. wandb/sdk/data_types/object_3d.py +2 -1
  41. wandb/sdk/data_types/saved_model.py +7 -9
  42. wandb/sdk/data_types/video.py +1 -4
  43. wandb/sdk/interface/interface.py +65 -43
  44. wandb/sdk/interface/interface_queue.py +0 -7
  45. wandb/sdk/interface/interface_relay.py +6 -16
  46. wandb/sdk/interface/interface_shared.py +47 -40
  47. wandb/sdk/interface/interface_sock.py +1 -8
  48. wandb/sdk/interface/router.py +22 -54
  49. wandb/sdk/interface/router_queue.py +11 -10
  50. wandb/sdk/interface/router_relay.py +24 -12
  51. wandb/sdk/interface/router_sock.py +6 -11
  52. wandb/{apis/public → sdk/internal}/_generated/__init__.py +0 -6
  53. wandb/sdk/internal/_generated/base.py +226 -0
  54. wandb/{apis/public → sdk/internal}/_generated/server_features_query.py +3 -3
  55. wandb/{apis/public → sdk/internal}/_generated/typing_compat.py +1 -1
  56. wandb/sdk/internal/internal_api.py +138 -47
  57. wandb/sdk/internal/sender.py +5 -1
  58. wandb/sdk/internal/sender_config.py +8 -11
  59. wandb/sdk/internal/settings_static.py +24 -2
  60. wandb/sdk/lib/apikey.py +15 -16
  61. wandb/sdk/lib/console_capture.py +172 -0
  62. wandb/sdk/lib/redirect.py +102 -76
  63. wandb/sdk/lib/run_moment.py +4 -6
  64. wandb/sdk/lib/service_connection.py +37 -17
  65. wandb/sdk/lib/sock_client.py +2 -52
  66. wandb/sdk/lib/wb_logging.py +161 -0
  67. wandb/sdk/mailbox/__init__.py +3 -3
  68. wandb/sdk/mailbox/mailbox.py +31 -17
  69. wandb/sdk/mailbox/mailbox_handle.py +127 -0
  70. wandb/sdk/mailbox/{handles.py → response_handle.py} +34 -66
  71. wandb/sdk/mailbox/wait_with_progress.py +16 -15
  72. wandb/sdk/service/server_sock.py +4 -2
  73. wandb/sdk/service/streams.py +10 -5
  74. wandb/sdk/wandb_config.py +44 -43
  75. wandb/sdk/wandb_init.py +151 -92
  76. wandb/sdk/wandb_metadata.py +107 -91
  77. wandb/sdk/wandb_run.py +160 -54
  78. wandb/sdk/wandb_settings.py +410 -202
  79. wandb/sdk/wandb_setup.py +3 -1
  80. wandb/sdk/wandb_sync.py +1 -7
  81. {wandb-0.19.7.dist-info → wandb-0.19.9.dist-info}/METADATA +3 -3
  82. {wandb-0.19.7.dist-info → wandb-0.19.9.dist-info}/RECORD +88 -84
  83. wandb/apis/public/_generated/base.py +0 -128
  84. wandb/sdk/interface/message_future.py +0 -27
  85. wandb/sdk/interface/message_future_poll.py +0 -50
  86. /wandb/{apis/public → sdk/internal}/_generated/enums.py +0 -0
  87. /wandb/{apis/public → sdk/internal}/_generated/input_types.py +0 -0
  88. /wandb/{apis/public → sdk/internal}/_generated/operations.py +0 -0
  89. {wandb-0.19.7.dist-info → wandb-0.19.9.dist-info}/WHEEL +0 -0
  90. {wandb-0.19.7.dist-info → wandb-0.19.9.dist-info}/entry_points.txt +0 -0
  91. {wandb-0.19.7.dist-info → wandb-0.19.9.dist-info}/licenses/LICENSE +0 -0
@@ -3,7 +3,7 @@ import os
3
3
  import platform
4
4
  import re
5
5
  import shutil
6
- from typing import TYPE_CHECKING, Optional, Sequence, Type, Union, cast
6
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Type, Union, cast
7
7
 
8
8
  import wandb
9
9
  from wandb import util
@@ -222,7 +222,10 @@ class Media(WBValue):
222
222
  from wandb.data_types import Audio
223
223
  from wandb.sdk.wandb_run import Run
224
224
 
225
- json_obj = {}
225
+ json_obj: Dict[str, Any] = {}
226
+
227
+ if self._caption is not None:
228
+ json_obj["caption"] = self._caption
226
229
 
227
230
  if isinstance(run, Run):
228
231
  json_obj.update(
@@ -232,6 +235,7 @@ class Media(WBValue):
232
235
  "size": self._size,
233
236
  }
234
237
  )
238
+
235
239
  artifact_entry_url = self._get_artifact_entry_ref_url()
236
240
  if artifact_entry_url is not None:
237
241
  json_obj["artifact_path"] = artifact_entry_url
@@ -337,8 +341,11 @@ class BatchableMedia(Media):
337
341
  organize files by name in the media directory.
338
342
  """
339
343
 
340
- def __init__(self) -> None:
341
- super().__init__()
344
+ def __init__(
345
+ self,
346
+ caption: Optional[str] = None,
347
+ ) -> None:
348
+ super().__init__(caption=caption)
342
349
 
343
350
  @classmethod
344
351
  def seq_to_json(
@@ -152,12 +152,11 @@ class Image(BatchableMedia):
152
152
  masks: Optional[Union[Dict[str, "ImageMask"], Dict[str, dict]]] = None,
153
153
  file_type: Optional[str] = None,
154
154
  ) -> None:
155
- super().__init__()
155
+ super().__init__(caption=caption)
156
156
  # TODO: We should remove grouping, it's a terrible name and I don't
157
157
  # think anyone uses it.
158
158
 
159
159
  self._grouping = None
160
- self._caption = None
161
160
  self._width = None
162
161
  self._height = None
163
162
  self._image = None
@@ -193,9 +192,6 @@ class Image(BatchableMedia):
193
192
  if grouping is not None:
194
193
  self._grouping = grouping
195
194
 
196
- if caption is not None:
197
- self._caption = caption
198
-
199
195
  total_classes = {}
200
196
 
201
197
  if boxes:
@@ -297,10 +293,19 @@ class Image(BatchableMedia):
297
293
  "PIL.Image",
298
294
  required='wandb.Image needs the PIL package. To get it, run "pip install pillow".',
299
295
  )
296
+
297
+ accepted_formats = ["png", "jpg", "jpeg", "bmp"]
298
+ self.format = file_type or "png"
299
+
300
+ if self.format not in accepted_formats:
301
+ raise ValueError(f"file_type must be one of {accepted_formats}")
302
+
303
+ tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + "." + self.format)
304
+
300
305
  if util.is_matplotlib_typename(util.get_full_typename(data)):
301
306
  buf = BytesIO()
302
- util.ensure_matplotlib_figure(data).savefig(buf, format="png")
303
- self._image = pil_image.open(buf, formats=["PNG"])
307
+ util.ensure_matplotlib_figure(data).savefig(buf, format=self.format)
308
+ self._image = pil_image.open(buf)
304
309
  elif isinstance(data, pil_image.Image):
305
310
  self._image = data
306
311
  elif util.is_pytorch_tensor_typename(util.get_full_typename(data)):
@@ -312,26 +317,23 @@ class Image(BatchableMedia):
312
317
  if hasattr(data, "dtype") and str(data.dtype) == "torch.uint8":
313
318
  data = data.to(float)
314
319
  data = vis_util.make_grid(data, normalize=True)
320
+ mode = mode or self.guess_mode(data, file_type)
315
321
  self._image = pil_image.fromarray(
316
- data.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
322
+ data.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy(),
323
+ mode=mode,
317
324
  )
318
325
  else:
319
326
  if hasattr(data, "numpy"): # TF data eager tensors
320
327
  data = data.numpy()
321
328
  if data.ndim > 2:
322
329
  data = data.squeeze() # get rid of trivial dimensions as a convenience
330
+
331
+ mode = mode or self.guess_mode(data, file_type)
323
332
  self._image = pil_image.fromarray(
324
- self.to_uint8(data), mode=mode or self.guess_mode(data)
333
+ self.to_uint8(data),
334
+ mode=mode,
325
335
  )
326
- accepted_formats = ["png", "jpg", "jpeg", "bmp"]
327
- if file_type is None:
328
- self.format = "png"
329
- else:
330
- self.format = file_type
331
- assert (
332
- self.format in accepted_formats
333
- ), f"file_type must be one of {accepted_formats}"
334
- tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + "." + self.format)
336
+
335
337
  assert self._image is not None
336
338
  self._image.save(tmp_path, transparency=None)
337
339
  self._set_file(tmp_path, is_tmp=True)
@@ -430,8 +432,6 @@ class Image(BatchableMedia):
430
432
  json_dict["height"] = self._height
431
433
  if self._grouping:
432
434
  json_dict["grouping"] = self._grouping
433
- if self._caption:
434
- json_dict["caption"] = self._caption
435
435
 
436
436
  if isinstance(run_or_artifact, wandb.Artifact):
437
437
  artifact = run_or_artifact
@@ -471,15 +471,34 @@ class Image(BatchableMedia):
471
471
  }
472
472
  return json_dict
473
473
 
474
- def guess_mode(self, data: "np.ndarray") -> str:
474
+ def guess_mode(
475
+ self,
476
+ data: Union["np.ndarray", "torch.Tensor"],
477
+ file_type: Optional[str] = None,
478
+ ) -> str:
475
479
  """Guess what type of image the np.array is representing."""
476
480
  # TODO: do we want to support dimensions being at the beginning of the array?
477
- if data.ndim == 2:
481
+ ndims = data.ndim
482
+ if util.is_pytorch_tensor_typename(util.get_full_typename(data)):
483
+ # Torch tenors typically have the channels dimension first
484
+ num_channels = data.shape[0]
485
+ else:
486
+ num_channels = data.shape[-1]
487
+
488
+ if ndims == 2:
478
489
  return "L"
479
- elif data.shape[-1] == 3:
490
+ elif num_channels == 3:
480
491
  return "RGB"
481
- elif data.shape[-1] == 4:
482
- return "RGBA"
492
+ elif num_channels == 4:
493
+ if file_type in ["jpg", "jpeg"]:
494
+ wandb.termwarn(
495
+ "JPEG format does not support transparency. "
496
+ "Ignoring alpha channel.",
497
+ repeat=False,
498
+ )
499
+ return "RGB"
500
+ else:
501
+ return "RGBA"
483
502
  else:
484
503
  raise ValueError(
485
504
  "Un-supported shape for image conversion {}".format(list(data.shape))
@@ -53,9 +53,7 @@ class Molecule(BatchableMedia):
53
53
  caption: Optional[str] = None,
54
54
  **kwargs: str,
55
55
  ) -> None:
56
- super().__init__()
57
-
58
- self._caption = caption
56
+ super().__init__(caption=caption)
59
57
 
60
58
  if hasattr(data_or_path, "name"):
61
59
  # if the file has a path, we just detect the type and copy it from there
@@ -208,8 +206,6 @@ class Molecule(BatchableMedia):
208
206
  def to_json(self, run_or_artifact: Union["LocalRun", "Artifact"]) -> dict:
209
207
  json_dict = super().to_json(run_or_artifact)
210
208
  json_dict["_type"] = self._log_type
211
- if self._caption:
212
- json_dict["caption"] = self._caption
213
209
  return json_dict
214
210
 
215
211
  @classmethod
@@ -215,9 +215,10 @@ class Object3D(BatchableMedia):
215
215
  def __init__(
216
216
  self,
217
217
  data_or_path: Union["np.ndarray", str, "TextIO", dict],
218
+ caption: Optional[str] = None,
218
219
  **kwargs: Optional[Union[str, "FileFormat3D"]],
219
220
  ) -> None:
220
- super().__init__()
221
+ super().__init__(caption=caption)
221
222
 
222
223
  if hasattr(data_or_path, "name"):
223
224
  # if the file has a path, we just detect the type and copy it from there
@@ -3,6 +3,7 @@ from __future__ import annotations
3
3
  import os
4
4
  import shutil
5
5
  import sys
6
+ from types import ModuleType
6
7
  from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, cast
7
8
 
8
9
  import wandb
@@ -15,9 +16,6 @@ from ._private import MEDIA_TMP
15
16
  from .base_types.wb_value import WBValue
16
17
 
17
18
  if TYPE_CHECKING:
18
- from types import ModuleType
19
-
20
- import cloudpickle # type: ignore
21
19
  import sklearn # type: ignore
22
20
  import tensorflow # type: ignore
23
21
  import torch # type: ignore
@@ -264,9 +262,9 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
264
262
  self._serialize(self._model_obj, target_path)
265
263
 
266
264
 
267
- def _get_cloudpickle() -> "cloudpickle":
265
+ def _get_cloudpickle() -> ModuleType:
268
266
  return cast(
269
- "cloudpickle",
267
+ ModuleType,
270
268
  util.get_module("cloudpickle", "ModelAdapter requires `cloudpickle`"),
271
269
  )
272
270
 
@@ -338,9 +336,9 @@ class _PicklingSavedModel(_SavedModel[SavedModelObjType]):
338
336
  return json_obj
339
337
 
340
338
 
341
- def _get_torch() -> "torch":
339
+ def _get_torch() -> ModuleType:
342
340
  return cast(
343
- "torch",
341
+ ModuleType,
344
342
  util.get_module("torch", "ModelAdapter requires `torch`"),
345
343
  )
346
344
 
@@ -366,9 +364,9 @@ class _PytorchSavedModel(_PicklingSavedModel["torch.nn.Module"]):
366
364
  )
367
365
 
368
366
 
369
- def _get_sklearn() -> "sklearn":
367
+ def _get_sklearn() -> ModuleType:
370
368
  return cast(
371
- "sklearn",
369
+ ModuleType,
372
370
  util.get_module("sklearn", "ModelAdapter requires `sklearn`"),
373
371
  )
374
372
 
@@ -90,13 +90,12 @@ class Video(BatchableMedia):
90
90
  fps: Optional[int] = None,
91
91
  format: Optional[str] = None,
92
92
  ):
93
- super().__init__()
93
+ super().__init__(caption=caption)
94
94
 
95
95
  self._format = format or "gif"
96
96
  self._width = None
97
97
  self._height = None
98
98
  self._channels = None
99
- self._caption = caption
100
99
  if self._format not in Video.EXTS:
101
100
  raise ValueError(
102
101
  "wandb.Video accepts {} formats".format(", ".join(Video.EXTS))
@@ -190,8 +189,6 @@ class Video(BatchableMedia):
190
189
  json_dict["width"] = self._width
191
190
  if self._height is not None:
192
191
  json_dict["height"] = self._height
193
- if self._caption:
194
- json_dict["caption"] = self._caption
195
192
 
196
193
  return json_dict
197
194
 
@@ -4,8 +4,6 @@ InterfaceBase: The abstract class
4
4
  InterfaceShared: Common routines for socket and queue based implementations
5
5
  InterfaceQueue: Use multiprocessing queues to send and receive messages
6
6
  InterfaceSock: Use socket to send and receive messages
7
- InterfaceRelay: Responses are routed to a relay queue (not matching uuids)
8
-
9
7
  """
10
8
 
11
9
  import gzip
@@ -102,14 +100,14 @@ class InterfaceBase:
102
100
  def _publish_header(self, header: pb.HeaderRecord) -> None:
103
101
  raise NotImplementedError
104
102
 
105
- def deliver_status(self) -> MailboxHandle:
103
+ def deliver_status(self) -> MailboxHandle[pb.Result]:
106
104
  return self._deliver_status(pb.StatusRequest())
107
105
 
108
106
  @abstractmethod
109
107
  def _deliver_status(
110
108
  self,
111
109
  status: pb.StatusRequest,
112
- ) -> MailboxHandle:
110
+ ) -> MailboxHandle[pb.Result]:
113
111
  raise NotImplementedError
114
112
 
115
113
  def _make_config(
@@ -435,7 +433,7 @@ class InterfaceBase:
435
433
  entity: Optional[str] = None,
436
434
  project: Optional[str] = None,
437
435
  organization: Optional[str] = None,
438
- ) -> MailboxHandle:
436
+ ) -> MailboxHandle[pb.Result]:
439
437
  link_artifact = pb.LinkArtifactRequest()
440
438
  if artifact.is_draft():
441
439
  link_artifact.client_id = artifact._client_id
@@ -452,7 +450,7 @@ class InterfaceBase:
452
450
  @abstractmethod
453
451
  def _deliver_link_artifact(
454
452
  self, link_artifact: pb.LinkArtifactRequest
455
- ) -> MailboxHandle:
453
+ ) -> MailboxHandle[pb.Result]:
456
454
  raise NotImplementedError
457
455
 
458
456
  @staticmethod
@@ -570,7 +568,7 @@ class InterfaceBase:
570
568
  is_user_created: bool = False,
571
569
  use_after_commit: bool = False,
572
570
  finalize: bool = True,
573
- ) -> MailboxHandle:
571
+ ) -> MailboxHandle[pb.Result]:
574
572
  proto_run = self._make_run(run)
575
573
  proto_artifact = self._make_artifact(artifact)
576
574
  proto_artifact.run_id = proto_run.run_id
@@ -595,7 +593,7 @@ class InterfaceBase:
595
593
  def _deliver_artifact(
596
594
  self,
597
595
  log_artifact: pb.LogArtifactRequest,
598
- ) -> MailboxHandle:
596
+ ) -> MailboxHandle[pb.Result]:
599
597
  raise NotImplementedError
600
598
 
601
599
  def deliver_download_artifact(
@@ -605,7 +603,7 @@ class InterfaceBase:
605
603
  allow_missing_references: bool,
606
604
  skip_cache: bool,
607
605
  path_prefix: Optional[str],
608
- ) -> MailboxHandle:
606
+ ) -> MailboxHandle[pb.Result]:
609
607
  download_artifact = pb.DownloadArtifactRequest()
610
608
  download_artifact.artifact_id = artifact_id
611
609
  download_artifact.download_root = download_root
@@ -618,7 +616,7 @@ class InterfaceBase:
618
616
  @abstractmethod
619
617
  def _deliver_download_artifact(
620
618
  self, download_artifact: pb.DownloadArtifactRequest
621
- ) -> MailboxHandle:
619
+ ) -> MailboxHandle[pb.Result]:
622
620
  raise NotImplementedError
623
621
 
624
622
  def publish_artifact(
@@ -870,7 +868,9 @@ class InterfaceBase:
870
868
  return self._publish_job_input(request)
871
869
 
872
870
  @abstractmethod
873
- def _publish_job_input(self, request: pb.JobInputRequest) -> MailboxHandle:
871
+ def _publish_job_input(
872
+ self, request: pb.JobInputRequest
873
+ ) -> MailboxHandle[pb.Result]:
874
874
  raise NotImplementedError
875
875
 
876
876
  def join(self) -> None:
@@ -892,143 +892,165 @@ class InterfaceBase:
892
892
  logger.warning("handle abandoned while communicating shutdown")
893
893
 
894
894
  @abstractmethod
895
- def _deliver_shutdown(self) -> MailboxHandle:
895
+ def _deliver_shutdown(self) -> MailboxHandle[pb.Result]:
896
896
  raise NotImplementedError
897
897
 
898
- def deliver_run(self, run: "Run") -> MailboxHandle:
898
+ def deliver_run(self, run: "Run") -> MailboxHandle[pb.Result]:
899
899
  run_record = self._make_run(run)
900
900
  return self._deliver_run(run_record)
901
901
 
902
902
  def deliver_finish_sync(
903
903
  self,
904
- ) -> MailboxHandle:
904
+ ) -> MailboxHandle[pb.Result]:
905
905
  sync = pb.SyncFinishRequest()
906
906
  return self._deliver_finish_sync(sync)
907
907
 
908
908
  @abstractmethod
909
- def _deliver_finish_sync(self, sync: pb.SyncFinishRequest) -> MailboxHandle:
909
+ def _deliver_finish_sync(
910
+ self, sync: pb.SyncFinishRequest
911
+ ) -> MailboxHandle[pb.Result]:
910
912
  raise NotImplementedError
911
913
 
912
914
  @abstractmethod
913
- def _deliver_run(self, run: pb.RunRecord) -> MailboxHandle:
915
+ def _deliver_run(self, run: pb.RunRecord) -> MailboxHandle[pb.Result]:
914
916
  raise NotImplementedError
915
917
 
916
- def deliver_run_start(self, run: "Run") -> MailboxHandle:
918
+ def deliver_run_start(self, run: "Run") -> MailboxHandle[pb.Result]:
917
919
  run_start = pb.RunStartRequest(run=self._make_run(run))
918
920
  return self._deliver_run_start(run_start)
919
921
 
920
922
  @abstractmethod
921
- def _deliver_run_start(self, run_start: pb.RunStartRequest) -> MailboxHandle:
923
+ def _deliver_run_start(
924
+ self, run_start: pb.RunStartRequest
925
+ ) -> MailboxHandle[pb.Result]:
922
926
  raise NotImplementedError
923
927
 
924
- def deliver_attach(self, attach_id: str) -> MailboxHandle:
928
+ def deliver_attach(self, attach_id: str) -> MailboxHandle[pb.Result]:
925
929
  attach = pb.AttachRequest(attach_id=attach_id)
926
930
  return self._deliver_attach(attach)
927
931
 
928
932
  @abstractmethod
929
- def _deliver_attach(self, status: pb.AttachRequest) -> MailboxHandle:
933
+ def _deliver_attach(
934
+ self,
935
+ status: pb.AttachRequest,
936
+ ) -> MailboxHandle[pb.Result]:
930
937
  raise NotImplementedError
931
938
 
932
- def deliver_stop_status(self) -> MailboxHandle:
939
+ def deliver_stop_status(self) -> MailboxHandle[pb.Result]:
933
940
  status = pb.StopStatusRequest()
934
941
  return self._deliver_stop_status(status)
935
942
 
936
943
  @abstractmethod
937
- def _deliver_stop_status(self, status: pb.StopStatusRequest) -> MailboxHandle:
944
+ def _deliver_stop_status(
945
+ self,
946
+ status: pb.StopStatusRequest,
947
+ ) -> MailboxHandle[pb.Result]:
938
948
  raise NotImplementedError
939
949
 
940
- def deliver_network_status(self) -> MailboxHandle:
950
+ def deliver_network_status(self) -> MailboxHandle[pb.Result]:
941
951
  status = pb.NetworkStatusRequest()
942
952
  return self._deliver_network_status(status)
943
953
 
944
954
  @abstractmethod
945
- def _deliver_network_status(self, status: pb.NetworkStatusRequest) -> MailboxHandle:
955
+ def _deliver_network_status(
956
+ self,
957
+ status: pb.NetworkStatusRequest,
958
+ ) -> MailboxHandle[pb.Result]:
946
959
  raise NotImplementedError
947
960
 
948
- def deliver_internal_messages(self) -> MailboxHandle:
961
+ def deliver_internal_messages(self) -> MailboxHandle[pb.Result]:
949
962
  internal_message = pb.InternalMessagesRequest()
950
963
  return self._deliver_internal_messages(internal_message)
951
964
 
952
965
  @abstractmethod
953
966
  def _deliver_internal_messages(
954
967
  self, internal_message: pb.InternalMessagesRequest
955
- ) -> MailboxHandle:
968
+ ) -> MailboxHandle[pb.Result]:
956
969
  raise NotImplementedError
957
970
 
958
- def deliver_get_summary(self) -> MailboxHandle:
971
+ def deliver_get_summary(self) -> MailboxHandle[pb.Result]:
959
972
  get_summary = pb.GetSummaryRequest()
960
973
  return self._deliver_get_summary(get_summary)
961
974
 
962
975
  @abstractmethod
963
- def _deliver_get_summary(self, get_summary: pb.GetSummaryRequest) -> MailboxHandle:
976
+ def _deliver_get_summary(
977
+ self,
978
+ get_summary: pb.GetSummaryRequest,
979
+ ) -> MailboxHandle[pb.Result]:
964
980
  raise NotImplementedError
965
981
 
966
- def deliver_get_system_metrics(self) -> MailboxHandle:
982
+ def deliver_get_system_metrics(self) -> MailboxHandle[pb.Result]:
967
983
  get_system_metrics = pb.GetSystemMetricsRequest()
968
984
  return self._deliver_get_system_metrics(get_system_metrics)
969
985
 
970
986
  @abstractmethod
971
987
  def _deliver_get_system_metrics(
972
988
  self, get_summary: pb.GetSystemMetricsRequest
973
- ) -> MailboxHandle:
989
+ ) -> MailboxHandle[pb.Result]:
974
990
  raise NotImplementedError
975
991
 
976
- def deliver_get_system_metadata(self) -> MailboxHandle:
992
+ def deliver_get_system_metadata(self) -> MailboxHandle[pb.Result]:
977
993
  get_system_metadata = pb.GetSystemMetadataRequest()
978
994
  return self._deliver_get_system_metadata(get_system_metadata)
979
995
 
980
996
  @abstractmethod
981
997
  def _deliver_get_system_metadata(
982
998
  self, get_system_metadata: pb.GetSystemMetadataRequest
983
- ) -> MailboxHandle:
999
+ ) -> MailboxHandle[pb.Result]:
984
1000
  raise NotImplementedError
985
1001
 
986
- def deliver_exit(self, exit_code: Optional[int]) -> MailboxHandle:
1002
+ def deliver_exit(self, exit_code: Optional[int]) -> MailboxHandle[pb.Result]:
987
1003
  exit_data = self._make_exit(exit_code)
988
1004
  return self._deliver_exit(exit_data)
989
1005
 
990
1006
  @abstractmethod
991
- def _deliver_exit(self, exit_data: pb.RunExitRecord) -> MailboxHandle:
1007
+ def _deliver_exit(
1008
+ self,
1009
+ exit_data: pb.RunExitRecord,
1010
+ ) -> MailboxHandle[pb.Result]:
992
1011
  raise NotImplementedError
993
1012
 
994
1013
  @abstractmethod
995
- def deliver_operation_stats(self) -> MailboxHandle:
1014
+ def deliver_operation_stats(self) -> MailboxHandle[pb.Result]:
996
1015
  raise NotImplementedError
997
1016
 
998
- def deliver_poll_exit(self) -> MailboxHandle:
1017
+ def deliver_poll_exit(self) -> MailboxHandle[pb.Result]:
999
1018
  poll_exit = pb.PollExitRequest()
1000
1019
  return self._deliver_poll_exit(poll_exit)
1001
1020
 
1002
1021
  @abstractmethod
1003
- def _deliver_poll_exit(self, poll_exit: pb.PollExitRequest) -> MailboxHandle:
1022
+ def _deliver_poll_exit(
1023
+ self,
1024
+ poll_exit: pb.PollExitRequest,
1025
+ ) -> MailboxHandle[pb.Result]:
1004
1026
  raise NotImplementedError
1005
1027
 
1006
- def deliver_finish_without_exit(self) -> MailboxHandle:
1028
+ def deliver_finish_without_exit(self) -> MailboxHandle[pb.Result]:
1007
1029
  run_finish_without_exit = pb.RunFinishWithoutExitRequest()
1008
1030
  return self._deliver_finish_without_exit(run_finish_without_exit)
1009
1031
 
1010
1032
  @abstractmethod
1011
1033
  def _deliver_finish_without_exit(
1012
1034
  self, run_finish_without_exit: pb.RunFinishWithoutExitRequest
1013
- ) -> MailboxHandle:
1035
+ ) -> MailboxHandle[pb.Result]:
1014
1036
  raise NotImplementedError
1015
1037
 
1016
- def deliver_request_sampled_history(self) -> MailboxHandle:
1038
+ def deliver_request_sampled_history(self) -> MailboxHandle[pb.Result]:
1017
1039
  sampled_history = pb.SampledHistoryRequest()
1018
1040
  return self._deliver_request_sampled_history(sampled_history)
1019
1041
 
1020
1042
  @abstractmethod
1021
1043
  def _deliver_request_sampled_history(
1022
1044
  self, sampled_history: pb.SampledHistoryRequest
1023
- ) -> MailboxHandle:
1045
+ ) -> MailboxHandle[pb.Result]:
1024
1046
  raise NotImplementedError
1025
1047
 
1026
- def deliver_request_run_status(self) -> MailboxHandle:
1048
+ def deliver_request_run_status(self) -> MailboxHandle[pb.Result]:
1027
1049
  run_status = pb.RunStatusRequest()
1028
1050
  return self._deliver_request_run_status(run_status)
1029
1051
 
1030
1052
  @abstractmethod
1031
1053
  def _deliver_request_run_status(
1032
1054
  self, run_status: pb.RunStatusRequest
1033
- ) -> MailboxHandle:
1055
+ ) -> MailboxHandle[pb.Result]:
1034
1056
  raise NotImplementedError
@@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Optional
11
11
  from wandb.sdk.mailbox import Mailbox
12
12
 
13
13
  from .interface_shared import InterfaceShared
14
- from .router_queue import MessageQueueRouter
15
14
 
16
15
  if TYPE_CHECKING:
17
16
  from queue import Queue
@@ -35,12 +34,6 @@ class InterfaceQueue(InterfaceShared):
35
34
  self._process = process
36
35
  super().__init__(mailbox=mailbox)
37
36
 
38
- def _init_router(self) -> None:
39
- if self.record_q and self.result_q:
40
- self._router = MessageQueueRouter(
41
- self.record_q, self.result_q, mailbox=self._mailbox
42
- )
43
-
44
37
  def _publish(self, record: "pb.Record", local: Optional[bool] = None) -> None:
45
38
  if self._process and not self._process.is_alive():
46
39
  raise Exception("The wandb backend process has shutdown")
@@ -1,17 +1,17 @@
1
1
  """InterfaceRelay - Derived from InterfaceQueue using RelayRouter to preserve uuid req/resp.
2
2
 
3
3
  See interface.py for how interface classes relate to each other.
4
-
5
4
  """
6
5
 
6
+ from __future__ import annotations
7
+
7
8
  import logging
8
- from typing import TYPE_CHECKING, Optional
9
+ from typing import TYPE_CHECKING
9
10
 
10
11
  from wandb.proto import wandb_internal_pb2 as pb
11
12
  from wandb.sdk.mailbox import Mailbox
12
13
 
13
14
  from .interface_queue import InterfaceQueue
14
- from .router_relay import MessageRelayRouter
15
15
 
16
16
  if TYPE_CHECKING:
17
17
  from queue import Queue
@@ -22,14 +22,13 @@ logger = logging.getLogger("wandb")
22
22
 
23
23
  class InterfaceRelay(InterfaceQueue):
24
24
  _mailbox: Mailbox
25
- relay_q: Optional["Queue[pb.Result]"]
26
25
 
27
26
  def __init__(
28
27
  self,
29
28
  mailbox: Mailbox,
30
- record_q: Optional["Queue[pb.Record]"] = None,
31
- result_q: Optional["Queue[pb.Result]"] = None,
32
- relay_q: Optional["Queue[pb.Result]"] = None,
29
+ record_q: Queue[pb.Record],
30
+ result_q: Queue[pb.Result],
31
+ relay_q: Queue[pb.Result],
33
32
  ) -> None:
34
33
  self.relay_q = relay_q
35
34
  super().__init__(
@@ -37,12 +36,3 @@ class InterfaceRelay(InterfaceQueue):
37
36
  result_q=result_q,
38
37
  mailbox=mailbox,
39
38
  )
40
-
41
- def _init_router(self) -> None:
42
- if self.record_q and self.result_q and self.relay_q:
43
- self._router = MessageRelayRouter(
44
- request_queue=self.record_q,
45
- response_queue=self.result_q,
46
- relay_queue=self.relay_q,
47
- mailbox=self._mailbox,
48
- )