wandb 0.19.6rc4__py3-none-any.whl → 0.19.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +25 -5
- wandb/apis/public/_generated/__init__.py +21 -0
- wandb/apis/public/_generated/base.py +128 -0
- wandb/apis/public/_generated/enums.py +4 -0
- wandb/apis/public/_generated/input_types.py +4 -0
- wandb/apis/public/_generated/operations.py +15 -0
- wandb/apis/public/_generated/server_features_query.py +27 -0
- wandb/apis/public/_generated/typing_compat.py +14 -0
- wandb/apis/public/api.py +192 -6
- wandb/apis/public/artifacts.py +13 -45
- wandb/apis/public/registries.py +573 -0
- wandb/apis/public/utils.py +36 -0
- wandb/bin/gpu_stats +0 -0
- wandb/cli/cli.py +11 -20
- wandb/env.py +10 -0
- wandb/proto/v3/wandb_internal_pb2.py +243 -222
- wandb/proto/v3/wandb_server_pb2.py +4 -4
- wandb/proto/v3/wandb_settings_pb2.py +1 -1
- wandb/proto/v4/wandb_internal_pb2.py +226 -222
- wandb/proto/v4/wandb_server_pb2.py +4 -4
- wandb/proto/v4/wandb_settings_pb2.py +1 -1
- wandb/proto/v5/wandb_internal_pb2.py +226 -222
- wandb/proto/v5/wandb_server_pb2.py +4 -4
- wandb/proto/v5/wandb_settings_pb2.py +1 -1
- wandb/sdk/artifacts/_graphql_fragments.py +126 -0
- wandb/sdk/artifacts/artifact.py +43 -88
- wandb/sdk/backend/backend.py +1 -1
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +14 -6
- wandb/sdk/data_types/helper_types/image_mask.py +12 -6
- wandb/sdk/data_types/saved_model.py +35 -46
- wandb/sdk/data_types/video.py +7 -16
- wandb/sdk/interface/interface.py +26 -10
- wandb/sdk/interface/interface_queue.py +5 -8
- wandb/sdk/interface/interface_relay.py +1 -6
- wandb/sdk/interface/interface_shared.py +21 -99
- wandb/sdk/interface/interface_sock.py +2 -13
- wandb/sdk/interface/router.py +21 -15
- wandb/sdk/interface/router_queue.py +2 -1
- wandb/sdk/interface/router_relay.py +2 -1
- wandb/sdk/interface/router_sock.py +5 -4
- wandb/sdk/internal/handler.py +4 -3
- wandb/sdk/internal/internal_api.py +12 -1
- wandb/sdk/internal/sender.py +0 -18
- wandb/sdk/lib/apikey.py +87 -26
- wandb/sdk/lib/asyncio_compat.py +210 -0
- wandb/sdk/lib/progress.py +78 -16
- wandb/sdk/lib/service_connection.py +1 -1
- wandb/sdk/lib/sock_client.py +7 -7
- wandb/sdk/mailbox/__init__.py +23 -0
- wandb/sdk/mailbox/handles.py +199 -0
- wandb/sdk/mailbox/mailbox.py +121 -0
- wandb/sdk/mailbox/wait_with_progress.py +134 -0
- wandb/sdk/service/server_sock.py +5 -1
- wandb/sdk/service/streams.py +66 -74
- wandb/sdk/verify/verify.py +54 -2
- wandb/sdk/wandb_init.py +61 -61
- wandb/sdk/wandb_login.py +7 -4
- wandb/sdk/wandb_metadata.py +65 -34
- wandb/sdk/wandb_require.py +14 -8
- wandb/sdk/wandb_run.py +82 -87
- wandb/sdk/wandb_settings.py +3 -3
- wandb/sdk/wandb_setup.py +19 -8
- wandb/sdk/wandb_sync.py +2 -4
- wandb/util.py +3 -1
- {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/METADATA +2 -2
- {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/RECORD +70 -57
- wandb/sdk/lib/mailbox.py +0 -442
- {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/WHEEL +0 -0
- {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/data_types/saved_model.py
CHANGED
@@ -1,18 +1,9 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import os
|
2
4
|
import shutil
|
3
5
|
import sys
|
4
|
-
from typing import
|
5
|
-
TYPE_CHECKING,
|
6
|
-
Any,
|
7
|
-
ClassVar,
|
8
|
-
Generic,
|
9
|
-
List,
|
10
|
-
Optional,
|
11
|
-
Type,
|
12
|
-
TypeVar,
|
13
|
-
Union,
|
14
|
-
cast,
|
15
|
-
)
|
6
|
+
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, cast
|
16
7
|
|
17
8
|
import wandb
|
18
9
|
from wandb import util
|
@@ -23,22 +14,24 @@ from wandb.sdk.lib.paths import LogicalPath
|
|
23
14
|
from ._private import MEDIA_TMP
|
24
15
|
from .base_types.wb_value import WBValue
|
25
16
|
|
26
|
-
if TYPE_CHECKING:
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
from types import ModuleType
|
19
|
+
|
27
20
|
import cloudpickle # type: ignore
|
28
21
|
import sklearn # type: ignore
|
29
22
|
import tensorflow # type: ignore
|
30
23
|
import torch # type: ignore
|
24
|
+
from typing_extensions import Self
|
31
25
|
|
32
26
|
from wandb.sdk.artifacts.artifact import Artifact
|
33
|
-
|
34
|
-
from ..wandb_run import Run as LocalRun
|
27
|
+
from wandb.sdk.wandb_run import Run as LocalRun
|
35
28
|
|
36
29
|
|
37
30
|
DEBUG_MODE = False
|
38
31
|
|
39
32
|
|
40
33
|
def _add_deterministic_dir_to_artifact(
|
41
|
-
artifact:
|
34
|
+
artifact: Artifact, dir_name: str, target_dir_root: str
|
42
35
|
) -> str:
|
43
36
|
file_paths = []
|
44
37
|
for dirpath, _, filenames in os.walk(dir_name, topdown=True):
|
@@ -50,7 +43,7 @@ def _add_deterministic_dir_to_artifact(
|
|
50
43
|
return target_path
|
51
44
|
|
52
45
|
|
53
|
-
def _load_dir_from_artifact(source_artifact:
|
46
|
+
def _load_dir_from_artifact(source_artifact: Artifact, path: str) -> str:
|
54
47
|
dl_path = None
|
55
48
|
|
56
49
|
# Look through the entire manifest to find all of the files in the directory.
|
@@ -79,14 +72,12 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
|
|
79
72
|
_log_type: ClassVar[str]
|
80
73
|
_path_extension: ClassVar[str]
|
81
74
|
|
82
|
-
_model_obj:
|
83
|
-
_path:
|
84
|
-
_input_obj_or_path:
|
75
|
+
_model_obj: SavedModelObjType | None
|
76
|
+
_path: str | None
|
77
|
+
_input_obj_or_path: SavedModelObjType | str
|
85
78
|
|
86
79
|
# Public Methods
|
87
|
-
def __init__(
|
88
|
-
self, obj_or_path: Union[SavedModelObjType, str], **kwargs: Any
|
89
|
-
) -> None:
|
80
|
+
def __init__(self, obj_or_path: SavedModelObjType | str, **kwargs: Any) -> None:
|
90
81
|
super().__init__()
|
91
82
|
if self.__class__ == _SavedModel:
|
92
83
|
raise TypeError(
|
@@ -115,7 +106,7 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
|
|
115
106
|
self._unset_obj()
|
116
107
|
|
117
108
|
@staticmethod
|
118
|
-
def init(obj_or_path: Any, **kwargs: Any) ->
|
109
|
+
def init(obj_or_path: Any, **kwargs: Any) -> _SavedModel:
|
119
110
|
maybe_instance = _SavedModel._maybe_init(obj_or_path, **kwargs)
|
120
111
|
if maybe_instance is None:
|
121
112
|
raise ValueError(
|
@@ -125,8 +116,8 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
|
|
125
116
|
|
126
117
|
@classmethod
|
127
118
|
def from_json(
|
128
|
-
cls:
|
129
|
-
) ->
|
119
|
+
cls: type[_SavedModel], json_obj: dict, source_artifact: Artifact
|
120
|
+
) -> _SavedModel:
|
130
121
|
path = json_obj["path"]
|
131
122
|
|
132
123
|
# First, if the entry is a file, the download it.
|
@@ -143,7 +134,7 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
|
|
143
134
|
# and specified adapter.
|
144
135
|
return cls(dl_path)
|
145
136
|
|
146
|
-
def to_json(self, run_or_artifact:
|
137
|
+
def to_json(self, run_or_artifact: LocalRun | Artifact) -> dict:
|
147
138
|
# Unlike other data types, we do not allow adding to a Run directly. There is a
|
148
139
|
# bit of tech debt in the other data types which requires the input to `to_json`
|
149
140
|
# to accept a Run or Artifact. However, Run additions should be deprecated in the future.
|
@@ -218,8 +209,8 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
|
|
218
209
|
# Private Class Methods
|
219
210
|
@classmethod
|
220
211
|
def _maybe_init(
|
221
|
-
cls:
|
222
|
-
) ->
|
212
|
+
cls: type[_SavedModel], obj_or_path: Any, **kwargs: Any
|
213
|
+
) -> _SavedModel | None:
|
223
214
|
# _maybe_init is an exception-safe method that will return an instance of this class
|
224
215
|
# (or any subclass of this class - recursively) OR None if no subclass constructor is found.
|
225
216
|
# We first try the current class, then recursively call this method on children classes. This pattern
|
@@ -241,7 +232,7 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
|
|
241
232
|
return None
|
242
233
|
|
243
234
|
@classmethod
|
244
|
-
def _tmp_path(cls:
|
235
|
+
def _tmp_path(cls: type[_SavedModel]) -> str:
|
245
236
|
# Generates a tmp path under our MEDIA_TMP directory which confirms to the file
|
246
237
|
# or folder preferences of the class.
|
247
238
|
assert isinstance(cls._path_extension, str), "_path_extension must be a string"
|
@@ -286,13 +277,13 @@ PicklingSavedModelObjType = TypeVar("PicklingSavedModelObjType")
|
|
286
277
|
|
287
278
|
|
288
279
|
class _PicklingSavedModel(_SavedModel[SavedModelObjType]):
|
289
|
-
_dep_py_files:
|
290
|
-
_dep_py_files_path:
|
280
|
+
_dep_py_files: list[str] | None = None
|
281
|
+
_dep_py_files_path: str | None = None
|
291
282
|
|
292
283
|
def __init__(
|
293
284
|
self,
|
294
|
-
obj_or_path:
|
295
|
-
dep_py_files:
|
285
|
+
obj_or_path: SavedModelObjType | str,
|
286
|
+
dep_py_files: list[str] | None = None,
|
296
287
|
):
|
297
288
|
super().__init__(obj_or_path)
|
298
289
|
if self.__class__ == _PicklingSavedModel:
|
@@ -319,9 +310,7 @@ class _PicklingSavedModel(_SavedModel[SavedModelObjType]):
|
|
319
310
|
raise ValueError(f"Invalid dependency file: {extra_file}")
|
320
311
|
|
321
312
|
@classmethod
|
322
|
-
def from_json(
|
323
|
-
cls: Type["_SavedModel"], json_obj: dict, source_artifact: "Artifact"
|
324
|
-
) -> "_PicklingSavedModel":
|
313
|
+
def from_json(cls, json_obj: dict, source_artifact: Artifact) -> Self:
|
325
314
|
backup_path = [p for p in sys.path]
|
326
315
|
if (
|
327
316
|
"dep_py_files_path" in json_obj
|
@@ -337,7 +326,7 @@ class _PicklingSavedModel(_SavedModel[SavedModelObjType]):
|
|
337
326
|
|
338
327
|
return inst # type: ignore
|
339
328
|
|
340
|
-
def to_json(self, run_or_artifact:
|
329
|
+
def to_json(self, run_or_artifact: LocalRun | Artifact) -> dict:
|
341
330
|
json_obj = super().to_json(run_or_artifact)
|
342
331
|
assert isinstance(run_or_artifact, wandb.Artifact)
|
343
332
|
if self._dep_py_files_path is not None:
|
@@ -361,7 +350,7 @@ class _PytorchSavedModel(_PicklingSavedModel["torch.nn.Module"]):
|
|
361
350
|
_path_extension = "pt"
|
362
351
|
|
363
352
|
@staticmethod
|
364
|
-
def _deserialize(dir_or_file_path: str) ->
|
353
|
+
def _deserialize(dir_or_file_path: str) -> torch.nn.Module:
|
365
354
|
return _get_torch().load(dir_or_file_path, weights_only=False)
|
366
355
|
|
367
356
|
@staticmethod
|
@@ -369,7 +358,7 @@ class _PytorchSavedModel(_PicklingSavedModel["torch.nn.Module"]):
|
|
369
358
|
return isinstance(obj, _get_torch().nn.Module)
|
370
359
|
|
371
360
|
@staticmethod
|
372
|
-
def _serialize(model_obj:
|
361
|
+
def _serialize(model_obj: torch.nn.Module, dir_or_file_path: str) -> None:
|
373
362
|
_get_torch().save(
|
374
363
|
model_obj,
|
375
364
|
dir_or_file_path,
|
@@ -391,7 +380,7 @@ class _SklearnSavedModel(_PicklingSavedModel["sklearn.base.BaseEstimator"]):
|
|
391
380
|
@staticmethod
|
392
381
|
def _deserialize(
|
393
382
|
dir_or_file_path: str,
|
394
|
-
) ->
|
383
|
+
) -> sklearn.base.BaseEstimator:
|
395
384
|
with open(dir_or_file_path, "rb") as file:
|
396
385
|
model = _get_cloudpickle().load(file)
|
397
386
|
return model
|
@@ -410,16 +399,16 @@ class _SklearnSavedModel(_PicklingSavedModel["sklearn.base.BaseEstimator"]):
|
|
410
399
|
|
411
400
|
@staticmethod
|
412
401
|
def _serialize(
|
413
|
-
model_obj:
|
402
|
+
model_obj: sklearn.base.BaseEstimator, dir_or_file_path: str
|
414
403
|
) -> None:
|
415
404
|
dynamic_cloudpickle = _get_cloudpickle()
|
416
405
|
with open(dir_or_file_path, "wb") as file:
|
417
406
|
dynamic_cloudpickle.dump(model_obj, file)
|
418
407
|
|
419
408
|
|
420
|
-
def _get_tf_keras() ->
|
409
|
+
def _get_tf_keras() -> ModuleType:
|
421
410
|
return cast(
|
422
|
-
|
411
|
+
ModuleType,
|
423
412
|
util.get_module("tensorflow", "ModelAdapter requires `tensorflow`"),
|
424
413
|
).keras
|
425
414
|
|
@@ -431,7 +420,7 @@ class _TensorflowKerasSavedModel(_SavedModel["tensorflow.keras.Model"]):
|
|
431
420
|
@staticmethod
|
432
421
|
def _deserialize(
|
433
422
|
dir_or_file_path: str,
|
434
|
-
) ->
|
423
|
+
) -> tensorflow.keras.Model:
|
435
424
|
return _get_tf_keras().models.load_model(dir_or_file_path)
|
436
425
|
|
437
426
|
@staticmethod
|
@@ -439,7 +428,7 @@ class _TensorflowKerasSavedModel(_SavedModel["tensorflow.keras.Model"]):
|
|
439
428
|
return isinstance(obj, _get_tf_keras().models.Model)
|
440
429
|
|
441
430
|
@staticmethod
|
442
|
-
def _serialize(model_obj:
|
431
|
+
def _serialize(model_obj: tensorflow.keras.Model, dir_or_file_path: str) -> None:
|
443
432
|
_get_tf_keras().models.save_model(
|
444
433
|
model_obj, dir_or_file_path, include_optimizer=True
|
445
434
|
)
|
wandb/sdk/data_types/video.py
CHANGED
@@ -71,10 +71,10 @@ class Video(BatchableMedia):
|
|
71
71
|
import numpy as np
|
72
72
|
import wandb
|
73
73
|
|
74
|
-
wandb.init()
|
74
|
+
run = wandb.init()
|
75
75
|
# axes are (time, channel, height, width)
|
76
76
|
frames = np.random.randint(low=0, high=256, size=(10, 3, 100, 100), dtype=np.uint8)
|
77
|
-
|
77
|
+
run.log({"video": wandb.Video(frames, fps=4)})
|
78
78
|
```
|
79
79
|
"""
|
80
80
|
|
@@ -138,20 +138,11 @@ class Video(BatchableMedia):
|
|
138
138
|
self.encode(fps=fps)
|
139
139
|
|
140
140
|
def encode(self, fps: int = 4) -> None:
|
141
|
-
#
|
142
|
-
mpy =
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
"moviepy.editor",
|
147
|
-
required='wandb.Video requires moviepy when passing raw data. Install with "pip install wandb[media]"',
|
148
|
-
)
|
149
|
-
except wandb.Error:
|
150
|
-
# Fallback to moviepy for MoviePy >= 2.0
|
151
|
-
mpy = util.get_module(
|
152
|
-
"moviepy",
|
153
|
-
required='wandb.Video requires moviepy when passing raw data. Install with "pip install wandb[media]"',
|
154
|
-
)
|
141
|
+
# import ImageSequenceClip from the appropriate MoviePy module
|
142
|
+
mpy = util.get_module(
|
143
|
+
"moviepy.video.io.ImageSequenceClip",
|
144
|
+
required='wandb.Video requires moviepy when passing raw data. Install with "pip install wandb[media]"',
|
145
|
+
)
|
155
146
|
|
156
147
|
tensor = self._prepare_video(self.data)
|
157
148
|
_, self._height, self._width, self._channels = tensor.shape # type: ignore
|
wandb/sdk/interface/interface.py
CHANGED
@@ -35,6 +35,7 @@ from wandb.sdk.artifacts.artifact import Artifact
|
|
35
35
|
from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
|
36
36
|
from wandb.sdk.artifacts.staging import get_staging_dir
|
37
37
|
from wandb.sdk.lib import json_util as json
|
38
|
+
from wandb.sdk.mailbox import HandleAbandonedError, MailboxHandle
|
38
39
|
from wandb.util import (
|
39
40
|
WandBJSONEncoderOld,
|
40
41
|
get_h5_typename,
|
@@ -46,9 +47,7 @@ from wandb.util import (
|
|
46
47
|
)
|
47
48
|
|
48
49
|
from ..data_types.utils import history_dict_to_json, val_to_json
|
49
|
-
from ..lib.mailbox import MailboxHandle
|
50
50
|
from . import summary_record as sr
|
51
|
-
from .message_future import MessageFuture
|
52
51
|
|
53
52
|
MANIFEST_FILE_SIZE_THRESHOLD = 100_000
|
54
53
|
|
@@ -561,7 +560,7 @@ class InterfaceBase:
|
|
561
560
|
def _publish_use_artifact(self, proto_artifact: pb.UseArtifactRecord) -> None:
|
562
561
|
raise NotImplementedError
|
563
562
|
|
564
|
-
def
|
563
|
+
def deliver_artifact(
|
565
564
|
self,
|
566
565
|
run: "Run",
|
567
566
|
artifact: "Artifact",
|
@@ -571,7 +570,7 @@ class InterfaceBase:
|
|
571
570
|
is_user_created: bool = False,
|
572
571
|
use_after_commit: bool = False,
|
573
572
|
finalize: bool = True,
|
574
|
-
) ->
|
573
|
+
) -> MailboxHandle:
|
575
574
|
proto_run = self._make_run(run)
|
576
575
|
proto_artifact = self._make_artifact(artifact)
|
577
576
|
proto_artifact.run_id = proto_run.run_id
|
@@ -589,13 +588,14 @@ class InterfaceBase:
|
|
589
588
|
if history_step is not None:
|
590
589
|
log_artifact.history_step = history_step
|
591
590
|
log_artifact.staging_dir = get_staging_dir()
|
592
|
-
resp = self.
|
591
|
+
resp = self._deliver_artifact(log_artifact)
|
593
592
|
return resp
|
594
593
|
|
595
594
|
@abstractmethod
|
596
|
-
def
|
597
|
-
self,
|
598
|
-
|
595
|
+
def _deliver_artifact(
|
596
|
+
self,
|
597
|
+
log_artifact: pb.LogArtifactRequest,
|
598
|
+
) -> MailboxHandle:
|
599
599
|
raise NotImplementedError
|
600
600
|
|
601
601
|
def deliver_download_artifact(
|
@@ -877,10 +877,22 @@ class InterfaceBase:
|
|
877
877
|
# Drop indicates that the internal process has already been shutdown
|
878
878
|
if self._drop:
|
879
879
|
return
|
880
|
-
|
880
|
+
|
881
|
+
handle = self._deliver_shutdown()
|
882
|
+
|
883
|
+
try:
|
884
|
+
handle.wait_or(timeout=30)
|
885
|
+
except TimeoutError:
|
886
|
+
# This can happen if the server fails to respond due to a bug
|
887
|
+
# or due to being very busy.
|
888
|
+
logger.warning("timed out communicating shutdown")
|
889
|
+
except HandleAbandonedError:
|
890
|
+
# This can happen if the connection to the server is closed
|
891
|
+
# before a response is read.
|
892
|
+
logger.warning("handle abandoned while communicating shutdown")
|
881
893
|
|
882
894
|
@abstractmethod
|
883
|
-
def
|
895
|
+
def _deliver_shutdown(self) -> MailboxHandle:
|
884
896
|
raise NotImplementedError
|
885
897
|
|
886
898
|
def deliver_run(self, run: "Run") -> MailboxHandle:
|
@@ -979,6 +991,10 @@ class InterfaceBase:
|
|
979
991
|
def _deliver_exit(self, exit_data: pb.RunExitRecord) -> MailboxHandle:
|
980
992
|
raise NotImplementedError
|
981
993
|
|
994
|
+
@abstractmethod
|
995
|
+
def deliver_operation_stats(self) -> MailboxHandle:
|
996
|
+
raise NotImplementedError
|
997
|
+
|
982
998
|
def deliver_poll_exit(self) -> MailboxHandle:
|
983
999
|
poll_exit = pb.PollExitRequest()
|
984
1000
|
return self._deliver_poll_exit(poll_exit)
|
wandb/sdk/interface/interface_queue.py
CHANGED
@@ -8,7 +8,8 @@ import logging
|
|
8
8
|
from multiprocessing.process import BaseProcess
|
9
9
|
from typing import TYPE_CHECKING, Optional
|
10
10
|
|
11
|
-
from
|
11
|
+
from wandb.sdk.mailbox import Mailbox
|
12
|
+
|
12
13
|
from .interface_shared import InterfaceShared
|
13
14
|
from .router_queue import MessageQueueRouter
|
14
15
|
|
@@ -22,21 +23,17 @@ logger = logging.getLogger("wandb")
|
|
22
23
|
|
23
24
|
|
24
25
|
class InterfaceQueue(InterfaceShared):
|
25
|
-
record_q: Optional["Queue[pb.Record]"]
|
26
|
-
result_q: Optional["Queue[pb.Result]"]
|
27
|
-
_mailbox: Optional[Mailbox]
|
28
|
-
|
29
26
|
def __init__(
|
30
27
|
self,
|
31
28
|
record_q: Optional["Queue[pb.Record]"] = None,
|
32
29
|
result_q: Optional["Queue[pb.Result]"] = None,
|
33
30
|
process: Optional[BaseProcess] = None,
|
34
|
-
process_check: bool = True,
|
35
31
|
mailbox: Optional[Mailbox] = None,
|
36
32
|
) -> None:
|
37
33
|
self.record_q = record_q
|
38
34
|
self.result_q = result_q
|
39
|
-
|
35
|
+
self._process = process
|
36
|
+
super().__init__(mailbox=mailbox)
|
40
37
|
|
41
38
|
def _init_router(self) -> None:
|
42
39
|
if self.record_q and self.result_q:
|
@@ -45,7 +42,7 @@ class InterfaceQueue(InterfaceShared):
|
|
45
42
|
)
|
46
43
|
|
47
44
|
def _publish(self, record: "pb.Record", local: Optional[bool] = None) -> None:
|
48
|
-
if self.
|
45
|
+
if self._process and not self._process.is_alive():
|
49
46
|
raise Exception("The wandb backend process has shutdown")
|
50
47
|
if local:
|
51
48
|
record.control.local = local
|
wandb/sdk/interface/interface_relay.py
CHANGED
@@ -5,12 +5,11 @@ See interface.py for how interface classes relate to each other.
|
|
5
5
|
"""
|
6
6
|
|
7
7
|
import logging
|
8
|
-
from multiprocessing.process import BaseProcess
|
9
8
|
from typing import TYPE_CHECKING, Optional
|
10
9
|
|
11
10
|
from wandb.proto import wandb_internal_pb2 as pb
|
11
|
+
from wandb.sdk.mailbox import Mailbox
|
12
12
|
|
13
|
-
from ..lib.mailbox import Mailbox
|
14
13
|
from .interface_queue import InterfaceQueue
|
15
14
|
from .router_relay import MessageRelayRouter
|
16
15
|
|
@@ -31,15 +30,11 @@ class InterfaceRelay(InterfaceQueue):
|
|
31
30
|
record_q: Optional["Queue[pb.Record]"] = None,
|
32
31
|
result_q: Optional["Queue[pb.Result]"] = None,
|
33
32
|
relay_q: Optional["Queue[pb.Result]"] = None,
|
34
|
-
process: Optional[BaseProcess] = None,
|
35
|
-
process_check: bool = True,
|
36
33
|
) -> None:
|
37
34
|
self.relay_q = relay_q
|
38
35
|
super().__init__(
|
39
36
|
record_q=record_q,
|
40
37
|
result_q=result_q,
|
41
|
-
process=process,
|
42
|
-
process_check=process_check,
|
43
38
|
mailbox=mailbox,
|
44
39
|
)
|
45
40
|
|
wandb/sdk/interface/interface_shared.py
CHANGED
@@ -5,44 +5,27 @@ See interface.py for how interface classes relate to each other.
|
|
5
5
|
"""
|
6
6
|
|
7
7
|
import logging
|
8
|
-
import time
|
9
8
|
from abc import abstractmethod
|
10
|
-
from multiprocessing.process import BaseProcess
|
11
9
|
from typing import Any, Optional, cast
|
12
10
|
|
13
|
-
import wandb
|
14
11
|
from wandb.proto import wandb_internal_pb2 as pb
|
15
12
|
from wandb.proto import wandb_telemetry_pb2 as tpb
|
13
|
+
from wandb.sdk.mailbox import Mailbox, MailboxHandle
|
16
14
|
from wandb.util import json_dumps_safer, json_friendly
|
17
15
|
|
18
|
-
from ..lib.mailbox import Mailbox, MailboxHandle
|
19
16
|
from .interface import InterfaceBase
|
20
|
-
from .message_future import MessageFuture
|
21
17
|
from .router import MessageRouter
|
22
18
|
|
23
19
|
logger = logging.getLogger("wandb")
|
24
20
|
|
25
21
|
|
26
22
|
class InterfaceShared(InterfaceBase):
|
27
|
-
process: Optional[BaseProcess]
|
28
|
-
_process_check: bool
|
29
23
|
_router: Optional[MessageRouter]
|
30
24
|
_mailbox: Optional[Mailbox]
|
31
|
-
_transport_success_timestamp: float
|
32
|
-
_transport_failed: bool
|
33
25
|
|
34
|
-
def __init__(
|
35
|
-
self,
|
36
|
-
process: Optional[BaseProcess] = None,
|
37
|
-
process_check: bool = True,
|
38
|
-
mailbox: Optional[Any] = None,
|
39
|
-
) -> None:
|
26
|
+
def __init__(self, mailbox: Optional[Any] = None) -> None:
|
40
27
|
super().__init__()
|
41
|
-
self._transport_success_timestamp = time.monotonic()
|
42
|
-
self._transport_failed = False
|
43
|
-
self._process = process
|
44
28
|
self._router = None
|
45
|
-
self._process_check = process_check
|
46
29
|
self._mailbox = mailbox
|
47
30
|
self._init_router()
|
48
31
|
|
@@ -50,20 +33,6 @@ class InterfaceShared(InterfaceBase):
|
|
50
33
|
def _init_router(self) -> None:
|
51
34
|
raise NotImplementedError
|
52
35
|
|
53
|
-
@property
|
54
|
-
def transport_failed(self) -> bool:
|
55
|
-
return self._transport_failed
|
56
|
-
|
57
|
-
@property
|
58
|
-
def transport_success_timestamp(self) -> float:
|
59
|
-
return self._transport_success_timestamp
|
60
|
-
|
61
|
-
def _transport_mark_failed(self) -> None:
|
62
|
-
self._transport_failed = True
|
63
|
-
|
64
|
-
def _transport_mark_success(self) -> None:
|
65
|
-
self._transport_success_timestamp = time.monotonic()
|
66
|
-
|
67
36
|
def _publish_output(self, outdata: pb.OutputRecord) -> None:
|
68
37
|
rec = pb.Record()
|
69
38
|
rec.output.CopyFrom(outdata)
|
@@ -114,15 +83,8 @@ class InterfaceShared(InterfaceBase):
|
|
114
83
|
item.value_json = json_dumps_safer(json_friendly(v)[0])
|
115
84
|
return stats
|
116
85
|
|
117
|
-
def _make_login(self, api_key: Optional[str] = None) -> pb.LoginRequest:
|
118
|
-
login = pb.LoginRequest()
|
119
|
-
if api_key:
|
120
|
-
login.api_key = api_key
|
121
|
-
return login
|
122
|
-
|
123
86
|
def _make_request( # noqa: C901
|
124
87
|
self,
|
125
|
-
login: Optional[pb.LoginRequest] = None,
|
126
88
|
get_summary: Optional[pb.GetSummaryRequest] = None,
|
127
89
|
pause: Optional[pb.PauseRequest] = None,
|
128
90
|
resume: Optional[pb.ResumeRequest] = None,
|
@@ -130,6 +92,7 @@ class InterfaceShared(InterfaceBase):
|
|
130
92
|
stop_status: Optional[pb.StopStatusRequest] = None,
|
131
93
|
internal_messages: Optional[pb.InternalMessagesRequest] = None,
|
132
94
|
network_status: Optional[pb.NetworkStatusRequest] = None,
|
95
|
+
operation_stats: Optional[pb.OperationStatsRequest] = None,
|
133
96
|
poll_exit: Optional[pb.PollExitRequest] = None,
|
134
97
|
partial_history: Optional[pb.PartialHistoryRequest] = None,
|
135
98
|
sampled_history: Optional[pb.SampledHistoryRequest] = None,
|
@@ -158,9 +121,7 @@ class InterfaceShared(InterfaceBase):
|
|
158
121
|
metadata: Optional[pb.MetadataRequest] = None,
|
159
122
|
) -> pb.Record:
|
160
123
|
request = pb.Request()
|
161
|
-
if
|
162
|
-
request.login.CopyFrom(login)
|
163
|
-
elif get_summary:
|
124
|
+
if get_summary:
|
164
125
|
request.get_summary.CopyFrom(get_summary)
|
165
126
|
elif pause:
|
166
127
|
request.pause.CopyFrom(pause)
|
@@ -174,6 +135,8 @@ class InterfaceShared(InterfaceBase):
|
|
174
135
|
request.internal_messages.CopyFrom(internal_messages)
|
175
136
|
elif network_status:
|
176
137
|
request.network_status.CopyFrom(network_status)
|
138
|
+
elif operation_stats:
|
139
|
+
request.operations.CopyFrom(operation_stats)
|
177
140
|
elif poll_exit:
|
178
141
|
request.poll_exit.CopyFrom(poll_exit)
|
179
142
|
elif partial_history:
|
@@ -307,35 +270,6 @@ class InterfaceShared(InterfaceBase):
|
|
307
270
|
def _publish(self, record: pb.Record, local: Optional[bool] = None) -> None:
|
308
271
|
raise NotImplementedError
|
309
272
|
|
310
|
-
def _communicate(
|
311
|
-
self, rec: pb.Record, timeout: Optional[int] = 30, local: Optional[bool] = None
|
312
|
-
) -> Optional[pb.Result]:
|
313
|
-
return self._communicate_async(rec, local=local).get(timeout=timeout)
|
314
|
-
|
315
|
-
def _communicate_async(
|
316
|
-
self, rec: pb.Record, local: Optional[bool] = None
|
317
|
-
) -> MessageFuture:
|
318
|
-
assert self._router
|
319
|
-
if self._process_check and self._process and not self._process.is_alive():
|
320
|
-
raise Exception("The wandb backend process has shutdown")
|
321
|
-
future = self._router.send_and_receive(rec, local=local)
|
322
|
-
return future
|
323
|
-
|
324
|
-
def communicate_login(
|
325
|
-
self, api_key: Optional[str] = None, timeout: Optional[int] = 15
|
326
|
-
) -> pb.LoginResponse:
|
327
|
-
login = self._make_login(api_key)
|
328
|
-
rec = self._make_request(login=login)
|
329
|
-
result = self._communicate(rec, timeout=timeout)
|
330
|
-
if result is None:
|
331
|
-
# TODO: friendlier error message here
|
332
|
-
raise wandb.Error(
|
333
|
-
"Couldn't communicate with backend after {} seconds".format(timeout)
|
334
|
-
)
|
335
|
-
login_response = result.response.login_response
|
336
|
-
assert login_response
|
337
|
-
return login_response
|
338
|
-
|
339
273
|
def _publish_defer(self, state: "pb.DeferRequest.DeferState.V") -> None:
|
340
274
|
defer = pb.DeferRequest(state=state)
|
341
275
|
rec = self._make_request(defer=defer)
|
@@ -358,11 +292,6 @@ class InterfaceShared(InterfaceBase):
|
|
358
292
|
rec = self._make_record(final=final)
|
359
293
|
self._publish(rec)
|
360
294
|
|
361
|
-
def publish_login(self, api_key: Optional[str] = None) -> None:
|
362
|
-
login = self._make_login(api_key)
|
363
|
-
rec = self._make_request(login=login)
|
364
|
-
self._publish(rec)
|
365
|
-
|
366
295
|
def _publish_pause(self, pause: pb.PauseRequest) -> None:
|
367
296
|
rec = self._make_request(pause=pause)
|
368
297
|
self._publish(rec)
|
@@ -410,9 +339,12 @@ class InterfaceShared(InterfaceBase):
|
|
410
339
|
rec = self._make_record(use_artifact=use_artifact)
|
411
340
|
self._publish(rec)
|
412
341
|
|
413
|
-
def
|
342
|
+
def _deliver_artifact(
|
343
|
+
self,
|
344
|
+
log_artifact: pb.LogArtifactRequest,
|
345
|
+
) -> MailboxHandle:
|
414
346
|
rec = self._make_request(log_artifact=log_artifact)
|
415
|
-
return self.
|
347
|
+
return self._deliver_record(rec)
|
416
348
|
|
417
349
|
def _deliver_download_artifact(
|
418
350
|
self, download_artifact: pb.DownloadArtifactRequest
|
@@ -449,11 +381,10 @@ class InterfaceShared(InterfaceBase):
|
|
449
381
|
record = self._make_request(keepalive=keepalive)
|
450
382
|
self._publish(record)
|
451
383
|
|
452
|
-
def
|
453
|
-
# shutdown
|
384
|
+
def _deliver_shutdown(self) -> MailboxHandle:
|
454
385
|
request = pb.Request(shutdown=pb.ShutdownRequest())
|
455
386
|
record = self._make_record(request=request)
|
456
|
-
|
387
|
+
return self._deliver_record(record)
|
457
388
|
|
458
389
|
def _get_mailbox(self) -> Mailbox:
|
459
390
|
mailbox = self._mailbox
|
@@ -462,7 +393,10 @@ class InterfaceShared(InterfaceBase):
|
|
462
393
|
|
463
394
|
def _deliver_record(self, record: pb.Record) -> MailboxHandle:
|
464
395
|
mailbox = self._get_mailbox()
|
465
|
-
|
396
|
+
|
397
|
+
handle = mailbox.require_response(record)
|
398
|
+
self._publish(record)
|
399
|
+
|
466
400
|
return handle
|
467
401
|
|
468
402
|
def _deliver_run(self, run: pb.RunRecord) -> MailboxHandle:
|
@@ -497,6 +431,10 @@ class InterfaceShared(InterfaceBase):
|
|
497
431
|
record = self._make_record(exit=exit_data)
|
498
432
|
return self._deliver_record(record)
|
499
433
|
|
434
|
+
def deliver_operation_stats(self):
|
435
|
+
record = self._make_request(operation_stats=pb.OperationStatsRequest())
|
436
|
+
return self._deliver_record(record)
|
437
|
+
|
500
438
|
def _deliver_poll_exit(self, poll_exit: pb.PollExitRequest) -> MailboxHandle:
|
501
439
|
record = self._make_request(poll_exit=poll_exit)
|
502
440
|
return self._deliver_record(record)
|
@@ -539,22 +477,6 @@ class InterfaceShared(InterfaceBase):
|
|
539
477
|
record = self._make_request(run_status=run_status)
|
540
478
|
return self._deliver_record(record)
|
541
479
|
|
542
|
-
def _transport_keepalive_failed(self, keepalive_interval: int = 5) -> bool:
|
543
|
-
if self._transport_failed:
|
544
|
-
return True
|
545
|
-
|
546
|
-
now = time.monotonic()
|
547
|
-
if now < self._transport_success_timestamp + keepalive_interval:
|
548
|
-
return False
|
549
|
-
|
550
|
-
try:
|
551
|
-
self.publish_keepalive()
|
552
|
-
except Exception:
|
553
|
-
self._transport_mark_failed()
|
554
|
-
else:
|
555
|
-
self._transport_mark_success()
|
556
|
-
return self._transport_failed
|
557
|
-
|
558
480
|
def join(self) -> None:
|
559
481
|
super().join()
|
560
482
|
|