wandb 0.19.6__py3-none-musllinux_1_2_aarch64.whl → 0.19.7__py3-none-musllinux_1_2_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +25 -5
  3. wandb/apis/public/_generated/__init__.py +21 -0
  4. wandb/apis/public/_generated/base.py +128 -0
  5. wandb/apis/public/_generated/enums.py +4 -0
  6. wandb/apis/public/_generated/input_types.py +4 -0
  7. wandb/apis/public/_generated/operations.py +15 -0
  8. wandb/apis/public/_generated/server_features_query.py +27 -0
  9. wandb/apis/public/_generated/typing_compat.py +14 -0
  10. wandb/apis/public/api.py +192 -6
  11. wandb/apis/public/artifacts.py +13 -45
  12. wandb/apis/public/registries.py +573 -0
  13. wandb/apis/public/utils.py +36 -0
  14. wandb/bin/gpu_stats +0 -0
  15. wandb/bin/wandb-core +0 -0
  16. wandb/cli/cli.py +11 -20
  17. wandb/env.py +10 -0
  18. wandb/proto/v3/wandb_internal_pb2.py +243 -222
  19. wandb/proto/v3/wandb_server_pb2.py +4 -4
  20. wandb/proto/v3/wandb_settings_pb2.py +1 -1
  21. wandb/proto/v4/wandb_internal_pb2.py +226 -222
  22. wandb/proto/v4/wandb_server_pb2.py +4 -4
  23. wandb/proto/v4/wandb_settings_pb2.py +1 -1
  24. wandb/proto/v5/wandb_internal_pb2.py +226 -222
  25. wandb/proto/v5/wandb_server_pb2.py +4 -4
  26. wandb/proto/v5/wandb_settings_pb2.py +1 -1
  27. wandb/sdk/artifacts/_graphql_fragments.py +126 -0
  28. wandb/sdk/artifacts/artifact.py +43 -88
  29. wandb/sdk/backend/backend.py +1 -1
  30. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +14 -6
  31. wandb/sdk/data_types/helper_types/image_mask.py +12 -6
  32. wandb/sdk/data_types/saved_model.py +35 -46
  33. wandb/sdk/data_types/video.py +7 -16
  34. wandb/sdk/interface/interface.py +26 -10
  35. wandb/sdk/interface/interface_queue.py +5 -8
  36. wandb/sdk/interface/interface_relay.py +1 -6
  37. wandb/sdk/interface/interface_shared.py +21 -99
  38. wandb/sdk/interface/interface_sock.py +2 -13
  39. wandb/sdk/interface/router.py +21 -15
  40. wandb/sdk/interface/router_queue.py +2 -1
  41. wandb/sdk/interface/router_relay.py +2 -1
  42. wandb/sdk/interface/router_sock.py +5 -4
  43. wandb/sdk/internal/handler.py +4 -3
  44. wandb/sdk/internal/internal_api.py +12 -1
  45. wandb/sdk/internal/sender.py +0 -18
  46. wandb/sdk/lib/apikey.py +87 -26
  47. wandb/sdk/lib/asyncio_compat.py +210 -0
  48. wandb/sdk/lib/progress.py +78 -16
  49. wandb/sdk/lib/service_connection.py +1 -1
  50. wandb/sdk/lib/sock_client.py +7 -7
  51. wandb/sdk/mailbox/__init__.py +23 -0
  52. wandb/sdk/mailbox/handles.py +199 -0
  53. wandb/sdk/mailbox/mailbox.py +121 -0
  54. wandb/sdk/mailbox/wait_with_progress.py +134 -0
  55. wandb/sdk/service/server_sock.py +5 -1
  56. wandb/sdk/service/streams.py +66 -74
  57. wandb/sdk/verify/verify.py +54 -2
  58. wandb/sdk/wandb_init.py +61 -61
  59. wandb/sdk/wandb_login.py +7 -4
  60. wandb/sdk/wandb_metadata.py +65 -34
  61. wandb/sdk/wandb_require.py +14 -8
  62. wandb/sdk/wandb_run.py +82 -87
  63. wandb/sdk/wandb_settings.py +3 -3
  64. wandb/sdk/wandb_setup.py +19 -8
  65. wandb/sdk/wandb_sync.py +2 -4
  66. wandb/util.py +3 -1
  67. {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/METADATA +2 -2
  68. {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/RECORD +71 -58
  69. wandb/sdk/lib/mailbox.py +0 -442
  70. {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/WHEEL +0 -0
  71. {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/entry_points.txt +0 -0
  72. {wandb-0.19.6.dist-info → wandb-0.19.7.dist-info}/licenses/LICENSE +0 -0
@@ -1,18 +1,9 @@
1
+ from __future__ import annotations
2
+
1
3
  import os
2
4
  import shutil
3
5
  import sys
4
- from typing import (
5
- TYPE_CHECKING,
6
- Any,
7
- ClassVar,
8
- Generic,
9
- List,
10
- Optional,
11
- Type,
12
- TypeVar,
13
- Union,
14
- cast,
15
- )
6
+ from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, cast
16
7
 
17
8
  import wandb
18
9
  from wandb import util
@@ -23,22 +14,24 @@ from wandb.sdk.lib.paths import LogicalPath
23
14
  from ._private import MEDIA_TMP
24
15
  from .base_types.wb_value import WBValue
25
16
 
26
- if TYPE_CHECKING: # pragma: no cover
17
+ if TYPE_CHECKING:
18
+ from types import ModuleType
19
+
27
20
  import cloudpickle # type: ignore
28
21
  import sklearn # type: ignore
29
22
  import tensorflow # type: ignore
30
23
  import torch # type: ignore
24
+ from typing_extensions import Self
31
25
 
32
26
  from wandb.sdk.artifacts.artifact import Artifact
33
-
34
- from ..wandb_run import Run as LocalRun
27
+ from wandb.sdk.wandb_run import Run as LocalRun
35
28
 
36
29
 
37
30
  DEBUG_MODE = False
38
31
 
39
32
 
40
33
  def _add_deterministic_dir_to_artifact(
41
- artifact: "Artifact", dir_name: str, target_dir_root: str
34
+ artifact: Artifact, dir_name: str, target_dir_root: str
42
35
  ) -> str:
43
36
  file_paths = []
44
37
  for dirpath, _, filenames in os.walk(dir_name, topdown=True):
@@ -50,7 +43,7 @@ def _add_deterministic_dir_to_artifact(
50
43
  return target_path
51
44
 
52
45
 
53
- def _load_dir_from_artifact(source_artifact: "Artifact", path: str) -> str:
46
+ def _load_dir_from_artifact(source_artifact: Artifact, path: str) -> str:
54
47
  dl_path = None
55
48
 
56
49
  # Look through the entire manifest to find all of the files in the directory.
@@ -79,14 +72,12 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
79
72
  _log_type: ClassVar[str]
80
73
  _path_extension: ClassVar[str]
81
74
 
82
- _model_obj: Optional["SavedModelObjType"]
83
- _path: Optional[str]
84
- _input_obj_or_path: Union[SavedModelObjType, str]
75
+ _model_obj: SavedModelObjType | None
76
+ _path: str | None
77
+ _input_obj_or_path: SavedModelObjType | str
85
78
 
86
79
  # Public Methods
87
- def __init__(
88
- self, obj_or_path: Union[SavedModelObjType, str], **kwargs: Any
89
- ) -> None:
80
+ def __init__(self, obj_or_path: SavedModelObjType | str, **kwargs: Any) -> None:
90
81
  super().__init__()
91
82
  if self.__class__ == _SavedModel:
92
83
  raise TypeError(
@@ -115,7 +106,7 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
115
106
  self._unset_obj()
116
107
 
117
108
  @staticmethod
118
- def init(obj_or_path: Any, **kwargs: Any) -> "_SavedModel":
109
+ def init(obj_or_path: Any, **kwargs: Any) -> _SavedModel:
119
110
  maybe_instance = _SavedModel._maybe_init(obj_or_path, **kwargs)
120
111
  if maybe_instance is None:
121
112
  raise ValueError(
@@ -125,8 +116,8 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
125
116
 
126
117
  @classmethod
127
118
  def from_json(
128
- cls: Type["_SavedModel"], json_obj: dict, source_artifact: "Artifact"
129
- ) -> "_SavedModel":
119
+ cls: type[_SavedModel], json_obj: dict, source_artifact: Artifact
120
+ ) -> _SavedModel:
130
121
  path = json_obj["path"]
131
122
 
132
123
  # First, if the entry is a file, the download it.
@@ -143,7 +134,7 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
143
134
  # and specified adapter.
144
135
  return cls(dl_path)
145
136
 
146
- def to_json(self, run_or_artifact: Union["LocalRun", "Artifact"]) -> dict:
137
+ def to_json(self, run_or_artifact: LocalRun | Artifact) -> dict:
147
138
  # Unlike other data types, we do not allow adding to a Run directly. There is a
148
139
  # bit of tech debt in the other data types which requires the input to `to_json`
149
140
  # to accept a Run or Artifact. However, Run additions should be deprecated in the future.
@@ -218,8 +209,8 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
218
209
  # Private Class Methods
219
210
  @classmethod
220
211
  def _maybe_init(
221
- cls: Type["_SavedModel"], obj_or_path: Any, **kwargs: Any
222
- ) -> Optional["_SavedModel"]:
212
+ cls: type[_SavedModel], obj_or_path: Any, **kwargs: Any
213
+ ) -> _SavedModel | None:
223
214
  # _maybe_init is an exception-safe method that will return an instance of this class
224
215
  # (or any subclass of this class - recursively) OR None if no subclass constructor is found.
225
216
  # We first try the current class, then recursively call this method on children classes. This pattern
@@ -241,7 +232,7 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
241
232
  return None
242
233
 
243
234
  @classmethod
244
- def _tmp_path(cls: Type["_SavedModel"]) -> str:
235
+ def _tmp_path(cls: type[_SavedModel]) -> str:
245
236
  # Generates a tmp path under our MEDIA_TMP directory which confirms to the file
246
237
  # or folder preferences of the class.
247
238
  assert isinstance(cls._path_extension, str), "_path_extension must be a string"
@@ -286,13 +277,13 @@ PicklingSavedModelObjType = TypeVar("PicklingSavedModelObjType")
286
277
 
287
278
 
288
279
  class _PicklingSavedModel(_SavedModel[SavedModelObjType]):
289
- _dep_py_files: Optional[List[str]] = None
290
- _dep_py_files_path: Optional[str] = None
280
+ _dep_py_files: list[str] | None = None
281
+ _dep_py_files_path: str | None = None
291
282
 
292
283
  def __init__(
293
284
  self,
294
- obj_or_path: Union[SavedModelObjType, str],
295
- dep_py_files: Optional[List[str]] = None,
285
+ obj_or_path: SavedModelObjType | str,
286
+ dep_py_files: list[str] | None = None,
296
287
  ):
297
288
  super().__init__(obj_or_path)
298
289
  if self.__class__ == _PicklingSavedModel:
@@ -319,9 +310,7 @@ class _PicklingSavedModel(_SavedModel[SavedModelObjType]):
319
310
  raise ValueError(f"Invalid dependency file: {extra_file}")
320
311
 
321
312
  @classmethod
322
- def from_json(
323
- cls: Type["_SavedModel"], json_obj: dict, source_artifact: "Artifact"
324
- ) -> "_PicklingSavedModel":
313
+ def from_json(cls, json_obj: dict, source_artifact: Artifact) -> Self:
325
314
  backup_path = [p for p in sys.path]
326
315
  if (
327
316
  "dep_py_files_path" in json_obj
@@ -337,7 +326,7 @@ class _PicklingSavedModel(_SavedModel[SavedModelObjType]):
337
326
 
338
327
  return inst # type: ignore
339
328
 
340
- def to_json(self, run_or_artifact: Union["LocalRun", "Artifact"]) -> dict:
329
+ def to_json(self, run_or_artifact: LocalRun | Artifact) -> dict:
341
330
  json_obj = super().to_json(run_or_artifact)
342
331
  assert isinstance(run_or_artifact, wandb.Artifact)
343
332
  if self._dep_py_files_path is not None:
@@ -361,7 +350,7 @@ class _PytorchSavedModel(_PicklingSavedModel["torch.nn.Module"]):
361
350
  _path_extension = "pt"
362
351
 
363
352
  @staticmethod
364
- def _deserialize(dir_or_file_path: str) -> "torch.nn.Module":
353
+ def _deserialize(dir_or_file_path: str) -> torch.nn.Module:
365
354
  return _get_torch().load(dir_or_file_path, weights_only=False)
366
355
 
367
356
  @staticmethod
@@ -369,7 +358,7 @@ class _PytorchSavedModel(_PicklingSavedModel["torch.nn.Module"]):
369
358
  return isinstance(obj, _get_torch().nn.Module)
370
359
 
371
360
  @staticmethod
372
- def _serialize(model_obj: "torch.nn.Module", dir_or_file_path: str) -> None:
361
+ def _serialize(model_obj: torch.nn.Module, dir_or_file_path: str) -> None:
373
362
  _get_torch().save(
374
363
  model_obj,
375
364
  dir_or_file_path,
@@ -391,7 +380,7 @@ class _SklearnSavedModel(_PicklingSavedModel["sklearn.base.BaseEstimator"]):
391
380
  @staticmethod
392
381
  def _deserialize(
393
382
  dir_or_file_path: str,
394
- ) -> "sklearn.base.BaseEstimator":
383
+ ) -> sklearn.base.BaseEstimator:
395
384
  with open(dir_or_file_path, "rb") as file:
396
385
  model = _get_cloudpickle().load(file)
397
386
  return model
@@ -410,16 +399,16 @@ class _SklearnSavedModel(_PicklingSavedModel["sklearn.base.BaseEstimator"]):
410
399
 
411
400
  @staticmethod
412
401
  def _serialize(
413
- model_obj: "sklearn.base.BaseEstimator", dir_or_file_path: str
402
+ model_obj: sklearn.base.BaseEstimator, dir_or_file_path: str
414
403
  ) -> None:
415
404
  dynamic_cloudpickle = _get_cloudpickle()
416
405
  with open(dir_or_file_path, "wb") as file:
417
406
  dynamic_cloudpickle.dump(model_obj, file)
418
407
 
419
408
 
420
- def _get_tf_keras() -> "tensorflow.keras":
409
+ def _get_tf_keras() -> ModuleType:
421
410
  return cast(
422
- "tensorflow",
411
+ ModuleType,
423
412
  util.get_module("tensorflow", "ModelAdapter requires `tensorflow`"),
424
413
  ).keras
425
414
 
@@ -431,7 +420,7 @@ class _TensorflowKerasSavedModel(_SavedModel["tensorflow.keras.Model"]):
431
420
  @staticmethod
432
421
  def _deserialize(
433
422
  dir_or_file_path: str,
434
- ) -> "tensorflow.keras.Model":
423
+ ) -> tensorflow.keras.Model:
435
424
  return _get_tf_keras().models.load_model(dir_or_file_path)
436
425
 
437
426
  @staticmethod
@@ -439,7 +428,7 @@ class _TensorflowKerasSavedModel(_SavedModel["tensorflow.keras.Model"]):
439
428
  return isinstance(obj, _get_tf_keras().models.Model)
440
429
 
441
430
  @staticmethod
442
- def _serialize(model_obj: "tensorflow.keras.Model", dir_or_file_path: str) -> None:
431
+ def _serialize(model_obj: tensorflow.keras.Model, dir_or_file_path: str) -> None:
443
432
  _get_tf_keras().models.save_model(
444
433
  model_obj, dir_or_file_path, include_optimizer=True
445
434
  )
@@ -71,10 +71,10 @@ class Video(BatchableMedia):
71
71
  import numpy as np
72
72
  import wandb
73
73
 
74
- wandb.init()
74
+ run = wandb.init()
75
75
  # axes are (time, channel, height, width)
76
76
  frames = np.random.randint(low=0, high=256, size=(10, 3, 100, 100), dtype=np.uint8)
77
- wandb.log({"video": wandb.Video(frames, fps=4)})
77
+ run.log({"video": wandb.Video(frames, fps=4)})
78
78
  ```
79
79
  """
80
80
 
@@ -138,20 +138,11 @@ class Video(BatchableMedia):
138
138
  self.encode(fps=fps)
139
139
 
140
140
  def encode(self, fps: int = 4) -> None:
141
- # Try to import ImageSequenceClip from the appropriate MoviePy module
142
- mpy = None
143
- try:
144
- # Attempt to load moviepy.editor for MoviePy < 2.0
145
- mpy = util.get_module(
146
- "moviepy.editor",
147
- required='wandb.Video requires moviepy when passing raw data. Install with "pip install wandb[media]"',
148
- )
149
- except wandb.Error:
150
- # Fallback to moviepy for MoviePy >= 2.0
151
- mpy = util.get_module(
152
- "moviepy",
153
- required='wandb.Video requires moviepy when passing raw data. Install with "pip install wandb[media]"',
154
- )
141
+ # import ImageSequenceClip from the appropriate MoviePy module
142
+ mpy = util.get_module(
143
+ "moviepy.video.io.ImageSequenceClip",
144
+ required='wandb.Video requires moviepy when passing raw data. Install with "pip install wandb[media]"',
145
+ )
155
146
 
156
147
  tensor = self._prepare_video(self.data)
157
148
  _, self._height, self._width, self._channels = tensor.shape # type: ignore
@@ -35,6 +35,7 @@ from wandb.sdk.artifacts.artifact import Artifact
35
35
  from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
36
36
  from wandb.sdk.artifacts.staging import get_staging_dir
37
37
  from wandb.sdk.lib import json_util as json
38
+ from wandb.sdk.mailbox import HandleAbandonedError, MailboxHandle
38
39
  from wandb.util import (
39
40
  WandBJSONEncoderOld,
40
41
  get_h5_typename,
@@ -46,9 +47,7 @@ from wandb.util import (
46
47
  )
47
48
 
48
49
  from ..data_types.utils import history_dict_to_json, val_to_json
49
- from ..lib.mailbox import MailboxHandle
50
50
  from . import summary_record as sr
51
- from .message_future import MessageFuture
52
51
 
53
52
  MANIFEST_FILE_SIZE_THRESHOLD = 100_000
54
53
 
@@ -561,7 +560,7 @@ class InterfaceBase:
561
560
  def _publish_use_artifact(self, proto_artifact: pb.UseArtifactRecord) -> None:
562
561
  raise NotImplementedError
563
562
 
564
- def communicate_artifact(
563
+ def deliver_artifact(
565
564
  self,
566
565
  run: "Run",
567
566
  artifact: "Artifact",
@@ -571,7 +570,7 @@ class InterfaceBase:
571
570
  is_user_created: bool = False,
572
571
  use_after_commit: bool = False,
573
572
  finalize: bool = True,
574
- ) -> MessageFuture:
573
+ ) -> MailboxHandle:
575
574
  proto_run = self._make_run(run)
576
575
  proto_artifact = self._make_artifact(artifact)
577
576
  proto_artifact.run_id = proto_run.run_id
@@ -589,13 +588,14 @@ class InterfaceBase:
589
588
  if history_step is not None:
590
589
  log_artifact.history_step = history_step
591
590
  log_artifact.staging_dir = get_staging_dir()
592
- resp = self._communicate_artifact(log_artifact)
591
+ resp = self._deliver_artifact(log_artifact)
593
592
  return resp
594
593
 
595
594
  @abstractmethod
596
- def _communicate_artifact(
597
- self, log_artifact: pb.LogArtifactRequest
598
- ) -> MessageFuture:
595
+ def _deliver_artifact(
596
+ self,
597
+ log_artifact: pb.LogArtifactRequest,
598
+ ) -> MailboxHandle:
599
599
  raise NotImplementedError
600
600
 
601
601
  def deliver_download_artifact(
@@ -877,10 +877,22 @@ class InterfaceBase:
877
877
  # Drop indicates that the internal process has already been shutdown
878
878
  if self._drop:
879
879
  return
880
- _ = self._communicate_shutdown()
880
+
881
+ handle = self._deliver_shutdown()
882
+
883
+ try:
884
+ handle.wait_or(timeout=30)
885
+ except TimeoutError:
886
+ # This can happen if the server fails to respond due to a bug
887
+ # or due to being very busy.
888
+ logger.warning("timed out communicating shutdown")
889
+ except HandleAbandonedError:
890
+ # This can happen if the connection to the server is closed
891
+ # before a response is read.
892
+ logger.warning("handle abandoned while communicating shutdown")
881
893
 
882
894
  @abstractmethod
883
- def _communicate_shutdown(self) -> None:
895
+ def _deliver_shutdown(self) -> MailboxHandle:
884
896
  raise NotImplementedError
885
897
 
886
898
  def deliver_run(self, run: "Run") -> MailboxHandle:
@@ -979,6 +991,10 @@ class InterfaceBase:
979
991
  def _deliver_exit(self, exit_data: pb.RunExitRecord) -> MailboxHandle:
980
992
  raise NotImplementedError
981
993
 
994
+ @abstractmethod
995
+ def deliver_operation_stats(self) -> MailboxHandle:
996
+ raise NotImplementedError
997
+
982
998
  def deliver_poll_exit(self) -> MailboxHandle:
983
999
  poll_exit = pb.PollExitRequest()
984
1000
  return self._deliver_poll_exit(poll_exit)
@@ -8,7 +8,8 @@ import logging
8
8
  from multiprocessing.process import BaseProcess
9
9
  from typing import TYPE_CHECKING, Optional
10
10
 
11
- from ..lib.mailbox import Mailbox
11
+ from wandb.sdk.mailbox import Mailbox
12
+
12
13
  from .interface_shared import InterfaceShared
13
14
  from .router_queue import MessageQueueRouter
14
15
 
@@ -22,21 +23,17 @@ logger = logging.getLogger("wandb")
22
23
 
23
24
 
24
25
  class InterfaceQueue(InterfaceShared):
25
- record_q: Optional["Queue[pb.Record]"]
26
- result_q: Optional["Queue[pb.Result]"]
27
- _mailbox: Optional[Mailbox]
28
-
29
26
  def __init__(
30
27
  self,
31
28
  record_q: Optional["Queue[pb.Record]"] = None,
32
29
  result_q: Optional["Queue[pb.Result]"] = None,
33
30
  process: Optional[BaseProcess] = None,
34
- process_check: bool = True,
35
31
  mailbox: Optional[Mailbox] = None,
36
32
  ) -> None:
37
33
  self.record_q = record_q
38
34
  self.result_q = result_q
39
- super().__init__(process=process, process_check=process_check, mailbox=mailbox)
35
+ self._process = process
36
+ super().__init__(mailbox=mailbox)
40
37
 
41
38
  def _init_router(self) -> None:
42
39
  if self.record_q and self.result_q:
@@ -45,7 +42,7 @@ class InterfaceQueue(InterfaceShared):
45
42
  )
46
43
 
47
44
  def _publish(self, record: "pb.Record", local: Optional[bool] = None) -> None:
48
- if self._process_check and self._process and not self._process.is_alive():
45
+ if self._process and not self._process.is_alive():
49
46
  raise Exception("The wandb backend process has shutdown")
50
47
  if local:
51
48
  record.control.local = local
@@ -5,12 +5,11 @@ See interface.py for how interface classes relate to each other.
5
5
  """
6
6
 
7
7
  import logging
8
- from multiprocessing.process import BaseProcess
9
8
  from typing import TYPE_CHECKING, Optional
10
9
 
11
10
  from wandb.proto import wandb_internal_pb2 as pb
11
+ from wandb.sdk.mailbox import Mailbox
12
12
 
13
- from ..lib.mailbox import Mailbox
14
13
  from .interface_queue import InterfaceQueue
15
14
  from .router_relay import MessageRelayRouter
16
15
 
@@ -31,15 +30,11 @@ class InterfaceRelay(InterfaceQueue):
31
30
  record_q: Optional["Queue[pb.Record]"] = None,
32
31
  result_q: Optional["Queue[pb.Result]"] = None,
33
32
  relay_q: Optional["Queue[pb.Result]"] = None,
34
- process: Optional[BaseProcess] = None,
35
- process_check: bool = True,
36
33
  ) -> None:
37
34
  self.relay_q = relay_q
38
35
  super().__init__(
39
36
  record_q=record_q,
40
37
  result_q=result_q,
41
- process=process,
42
- process_check=process_check,
43
38
  mailbox=mailbox,
44
39
  )
45
40
 
@@ -5,44 +5,27 @@ See interface.py for how interface classes relate to each other.
5
5
  """
6
6
 
7
7
  import logging
8
- import time
9
8
  from abc import abstractmethod
10
- from multiprocessing.process import BaseProcess
11
9
  from typing import Any, Optional, cast
12
10
 
13
- import wandb
14
11
  from wandb.proto import wandb_internal_pb2 as pb
15
12
  from wandb.proto import wandb_telemetry_pb2 as tpb
13
+ from wandb.sdk.mailbox import Mailbox, MailboxHandle
16
14
  from wandb.util import json_dumps_safer, json_friendly
17
15
 
18
- from ..lib.mailbox import Mailbox, MailboxHandle
19
16
  from .interface import InterfaceBase
20
- from .message_future import MessageFuture
21
17
  from .router import MessageRouter
22
18
 
23
19
  logger = logging.getLogger("wandb")
24
20
 
25
21
 
26
22
  class InterfaceShared(InterfaceBase):
27
- process: Optional[BaseProcess]
28
- _process_check: bool
29
23
  _router: Optional[MessageRouter]
30
24
  _mailbox: Optional[Mailbox]
31
- _transport_success_timestamp: float
32
- _transport_failed: bool
33
25
 
34
- def __init__(
35
- self,
36
- process: Optional[BaseProcess] = None,
37
- process_check: bool = True,
38
- mailbox: Optional[Any] = None,
39
- ) -> None:
26
+ def __init__(self, mailbox: Optional[Any] = None) -> None:
40
27
  super().__init__()
41
- self._transport_success_timestamp = time.monotonic()
42
- self._transport_failed = False
43
- self._process = process
44
28
  self._router = None
45
- self._process_check = process_check
46
29
  self._mailbox = mailbox
47
30
  self._init_router()
48
31
 
@@ -50,20 +33,6 @@ class InterfaceShared(InterfaceBase):
50
33
  def _init_router(self) -> None:
51
34
  raise NotImplementedError
52
35
 
53
- @property
54
- def transport_failed(self) -> bool:
55
- return self._transport_failed
56
-
57
- @property
58
- def transport_success_timestamp(self) -> float:
59
- return self._transport_success_timestamp
60
-
61
- def _transport_mark_failed(self) -> None:
62
- self._transport_failed = True
63
-
64
- def _transport_mark_success(self) -> None:
65
- self._transport_success_timestamp = time.monotonic()
66
-
67
36
  def _publish_output(self, outdata: pb.OutputRecord) -> None:
68
37
  rec = pb.Record()
69
38
  rec.output.CopyFrom(outdata)
@@ -114,15 +83,8 @@ class InterfaceShared(InterfaceBase):
114
83
  item.value_json = json_dumps_safer(json_friendly(v)[0])
115
84
  return stats
116
85
 
117
- def _make_login(self, api_key: Optional[str] = None) -> pb.LoginRequest:
118
- login = pb.LoginRequest()
119
- if api_key:
120
- login.api_key = api_key
121
- return login
122
-
123
86
  def _make_request( # noqa: C901
124
87
  self,
125
- login: Optional[pb.LoginRequest] = None,
126
88
  get_summary: Optional[pb.GetSummaryRequest] = None,
127
89
  pause: Optional[pb.PauseRequest] = None,
128
90
  resume: Optional[pb.ResumeRequest] = None,
@@ -130,6 +92,7 @@ class InterfaceShared(InterfaceBase):
130
92
  stop_status: Optional[pb.StopStatusRequest] = None,
131
93
  internal_messages: Optional[pb.InternalMessagesRequest] = None,
132
94
  network_status: Optional[pb.NetworkStatusRequest] = None,
95
+ operation_stats: Optional[pb.OperationStatsRequest] = None,
133
96
  poll_exit: Optional[pb.PollExitRequest] = None,
134
97
  partial_history: Optional[pb.PartialHistoryRequest] = None,
135
98
  sampled_history: Optional[pb.SampledHistoryRequest] = None,
@@ -158,9 +121,7 @@ class InterfaceShared(InterfaceBase):
158
121
  metadata: Optional[pb.MetadataRequest] = None,
159
122
  ) -> pb.Record:
160
123
  request = pb.Request()
161
- if login:
162
- request.login.CopyFrom(login)
163
- elif get_summary:
124
+ if get_summary:
164
125
  request.get_summary.CopyFrom(get_summary)
165
126
  elif pause:
166
127
  request.pause.CopyFrom(pause)
@@ -174,6 +135,8 @@ class InterfaceShared(InterfaceBase):
174
135
  request.internal_messages.CopyFrom(internal_messages)
175
136
  elif network_status:
176
137
  request.network_status.CopyFrom(network_status)
138
+ elif operation_stats:
139
+ request.operations.CopyFrom(operation_stats)
177
140
  elif poll_exit:
178
141
  request.poll_exit.CopyFrom(poll_exit)
179
142
  elif partial_history:
@@ -307,35 +270,6 @@ class InterfaceShared(InterfaceBase):
307
270
  def _publish(self, record: pb.Record, local: Optional[bool] = None) -> None:
308
271
  raise NotImplementedError
309
272
 
310
- def _communicate(
311
- self, rec: pb.Record, timeout: Optional[int] = 30, local: Optional[bool] = None
312
- ) -> Optional[pb.Result]:
313
- return self._communicate_async(rec, local=local).get(timeout=timeout)
314
-
315
- def _communicate_async(
316
- self, rec: pb.Record, local: Optional[bool] = None
317
- ) -> MessageFuture:
318
- assert self._router
319
- if self._process_check and self._process and not self._process.is_alive():
320
- raise Exception("The wandb backend process has shutdown")
321
- future = self._router.send_and_receive(rec, local=local)
322
- return future
323
-
324
- def communicate_login(
325
- self, api_key: Optional[str] = None, timeout: Optional[int] = 15
326
- ) -> pb.LoginResponse:
327
- login = self._make_login(api_key)
328
- rec = self._make_request(login=login)
329
- result = self._communicate(rec, timeout=timeout)
330
- if result is None:
331
- # TODO: friendlier error message here
332
- raise wandb.Error(
333
- "Couldn't communicate with backend after {} seconds".format(timeout)
334
- )
335
- login_response = result.response.login_response
336
- assert login_response
337
- return login_response
338
-
339
273
  def _publish_defer(self, state: "pb.DeferRequest.DeferState.V") -> None:
340
274
  defer = pb.DeferRequest(state=state)
341
275
  rec = self._make_request(defer=defer)
@@ -358,11 +292,6 @@ class InterfaceShared(InterfaceBase):
358
292
  rec = self._make_record(final=final)
359
293
  self._publish(rec)
360
294
 
361
- def publish_login(self, api_key: Optional[str] = None) -> None:
362
- login = self._make_login(api_key)
363
- rec = self._make_request(login=login)
364
- self._publish(rec)
365
-
366
295
  def _publish_pause(self, pause: pb.PauseRequest) -> None:
367
296
  rec = self._make_request(pause=pause)
368
297
  self._publish(rec)
@@ -410,9 +339,12 @@ class InterfaceShared(InterfaceBase):
410
339
  rec = self._make_record(use_artifact=use_artifact)
411
340
  self._publish(rec)
412
341
 
413
- def _communicate_artifact(self, log_artifact: pb.LogArtifactRequest) -> Any:
342
+ def _deliver_artifact(
343
+ self,
344
+ log_artifact: pb.LogArtifactRequest,
345
+ ) -> MailboxHandle:
414
346
  rec = self._make_request(log_artifact=log_artifact)
415
- return self._communicate_async(rec)
347
+ return self._deliver_record(rec)
416
348
 
417
349
  def _deliver_download_artifact(
418
350
  self, download_artifact: pb.DownloadArtifactRequest
@@ -449,11 +381,10 @@ class InterfaceShared(InterfaceBase):
449
381
  record = self._make_request(keepalive=keepalive)
450
382
  self._publish(record)
451
383
 
452
- def _communicate_shutdown(self) -> None:
453
- # shutdown
384
+ def _deliver_shutdown(self) -> MailboxHandle:
454
385
  request = pb.Request(shutdown=pb.ShutdownRequest())
455
386
  record = self._make_record(request=request)
456
- _ = self._communicate(record)
387
+ return self._deliver_record(record)
457
388
 
458
389
  def _get_mailbox(self) -> Mailbox:
459
390
  mailbox = self._mailbox
@@ -462,7 +393,10 @@ class InterfaceShared(InterfaceBase):
462
393
 
463
394
  def _deliver_record(self, record: pb.Record) -> MailboxHandle:
464
395
  mailbox = self._get_mailbox()
465
- handle = mailbox._deliver_record(record, interface=self)
396
+
397
+ handle = mailbox.require_response(record)
398
+ self._publish(record)
399
+
466
400
  return handle
467
401
 
468
402
  def _deliver_run(self, run: pb.RunRecord) -> MailboxHandle:
@@ -497,6 +431,10 @@ class InterfaceShared(InterfaceBase):
497
431
  record = self._make_record(exit=exit_data)
498
432
  return self._deliver_record(record)
499
433
 
434
+ def deliver_operation_stats(self):
435
+ record = self._make_request(operation_stats=pb.OperationStatsRequest())
436
+ return self._deliver_record(record)
437
+
500
438
  def _deliver_poll_exit(self, poll_exit: pb.PollExitRequest) -> MailboxHandle:
501
439
  record = self._make_request(poll_exit=poll_exit)
502
440
  return self._deliver_record(record)
@@ -539,22 +477,6 @@ class InterfaceShared(InterfaceBase):
539
477
  record = self._make_request(run_status=run_status)
540
478
  return self._deliver_record(record)
541
479
 
542
- def _transport_keepalive_failed(self, keepalive_interval: int = 5) -> bool:
543
- if self._transport_failed:
544
- return True
545
-
546
- now = time.monotonic()
547
- if now < self._transport_success_timestamp + keepalive_interval:
548
- return False
549
-
550
- try:
551
- self.publish_keepalive()
552
- except Exception:
553
- self._transport_mark_failed()
554
- else:
555
- self._transport_mark_success()
556
- return self._transport_failed
557
-
558
480
  def join(self) -> None:
559
481
  super().join()
560
482