wandb 0.19.6rc4__py3-none-macosx_11_0_arm64.whl → 0.19.8__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +56 -6
- wandb/apis/public/_generated/__init__.py +21 -0
- wandb/apis/public/_generated/base.py +128 -0
- wandb/apis/public/_generated/enums.py +4 -0
- wandb/apis/public/_generated/input_types.py +4 -0
- wandb/apis/public/_generated/operations.py +15 -0
- wandb/apis/public/_generated/server_features_query.py +27 -0
- wandb/apis/public/_generated/typing_compat.py +14 -0
- wandb/apis/public/api.py +192 -6
- wandb/apis/public/artifacts.py +13 -45
- wandb/apis/public/registries.py +573 -0
- wandb/apis/public/utils.py +36 -0
- wandb/bin/gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +11 -20
- wandb/data_types.py +1 -1
- wandb/env.py +10 -0
- wandb/filesync/dir_watcher.py +2 -1
- wandb/proto/v3/wandb_internal_pb2.py +243 -222
- wandb/proto/v3/wandb_server_pb2.py +4 -4
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +226 -222
- wandb/proto/v4/wandb_server_pb2.py +4 -4
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v5/wandb_internal_pb2.py +226 -222
- wandb/proto/v5/wandb_server_pb2.py +4 -4
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/artifacts/_graphql_fragments.py +126 -0
- wandb/sdk/artifacts/artifact.py +51 -95
- wandb/sdk/backend/backend.py +17 -6
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +14 -6
- wandb/sdk/data_types/helper_types/image_mask.py +12 -6
- wandb/sdk/data_types/saved_model.py +35 -46
- wandb/sdk/data_types/video.py +7 -16
- wandb/sdk/interface/interface.py +87 -49
- wandb/sdk/interface/interface_queue.py +5 -15
- wandb/sdk/interface/interface_relay.py +7 -22
- wandb/sdk/interface/interface_shared.py +65 -136
- wandb/sdk/interface/interface_sock.py +3 -21
- wandb/sdk/interface/router.py +42 -68
- wandb/sdk/interface/router_queue.py +13 -11
- wandb/sdk/interface/router_relay.py +26 -13
- wandb/sdk/interface/router_sock.py +12 -16
- wandb/sdk/internal/handler.py +4 -3
- wandb/sdk/internal/internal_api.py +12 -1
- wandb/sdk/internal/sender.py +3 -19
- wandb/sdk/lib/apikey.py +87 -26
- wandb/sdk/lib/asyncio_compat.py +210 -0
- wandb/sdk/lib/console_capture.py +172 -0
- wandb/sdk/lib/progress.py +78 -16
- wandb/sdk/lib/redirect.py +102 -76
- wandb/sdk/lib/service_connection.py +37 -17
- wandb/sdk/lib/sock_client.py +6 -56
- wandb/sdk/mailbox/__init__.py +23 -0
- wandb/sdk/mailbox/mailbox.py +135 -0
- wandb/sdk/mailbox/mailbox_handle.py +127 -0
- wandb/sdk/mailbox/response_handle.py +167 -0
- wandb/sdk/mailbox/wait_with_progress.py +135 -0
- wandb/sdk/service/server_sock.py +9 -3
- wandb/sdk/service/streams.py +75 -78
- wandb/sdk/verify/verify.py +54 -2
- wandb/sdk/wandb_init.py +72 -75
- wandb/sdk/wandb_login.py +7 -4
- wandb/sdk/wandb_metadata.py +65 -34
- wandb/sdk/wandb_require.py +14 -8
- wandb/sdk/wandb_run.py +90 -97
- wandb/sdk/wandb_settings.py +10 -4
- wandb/sdk/wandb_setup.py +19 -8
- wandb/sdk/wandb_sync.py +2 -10
- wandb/util.py +3 -1
- {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/METADATA +2 -2
- {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/RECORD +79 -66
- wandb/sdk/interface/message_future.py +0 -27
- wandb/sdk/interface/message_future_poll.py +0 -50
- wandb/sdk/lib/mailbox.py +0 -442
- {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/WHEEL +0 -0
- {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/backend/backend.py
CHANGED
@@ -16,9 +16,10 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
|
|
16
16
|
import wandb
|
17
17
|
from wandb.sdk.interface.interface import InterfaceBase
|
18
18
|
from wandb.sdk.interface.interface_queue import InterfaceQueue
|
19
|
+
from wandb.sdk.interface.router_queue import MessageQueueRouter
|
19
20
|
from wandb.sdk.internal.internal import wandb_internal
|
20
21
|
from wandb.sdk.internal.settings_static import SettingsStatic
|
21
|
-
from wandb.sdk.
|
22
|
+
from wandb.sdk.mailbox import Mailbox
|
22
23
|
from wandb.sdk.wandb_settings import Settings
|
23
24
|
|
24
25
|
if TYPE_CHECKING:
|
@@ -49,17 +50,18 @@ class BackendThread(threading.Thread):
|
|
49
50
|
class Backend:
|
50
51
|
# multiprocessing context or module
|
51
52
|
_multiprocessing: multiprocessing.context.BaseContext
|
53
|
+
|
52
54
|
interface: Optional[InterfaceBase]
|
55
|
+
_router: Optional[MessageQueueRouter]
|
56
|
+
|
53
57
|
_internal_pid: Optional[int]
|
54
58
|
wandb_process: Optional[multiprocessing.process.BaseProcess]
|
55
59
|
_settings: Settings
|
56
60
|
record_q: Optional["RecordQueue"]
|
57
61
|
result_q: Optional["ResultQueue"]
|
58
|
-
_mailbox: Mailbox
|
59
62
|
|
60
63
|
def __init__(
|
61
64
|
self,
|
62
|
-
mailbox: Mailbox,
|
63
65
|
settings: Settings,
|
64
66
|
log_level: Optional[int] = None,
|
65
67
|
service: "Optional[service_connection.ServiceConnection]" = None,
|
@@ -68,12 +70,14 @@ class Backend:
|
|
68
70
|
self.record_q = None
|
69
71
|
self.result_q = None
|
70
72
|
self.wandb_process = None
|
73
|
+
|
71
74
|
self.interface = None
|
75
|
+
self._router = None
|
76
|
+
|
72
77
|
self._internal_pid = None
|
73
78
|
self._settings = settings
|
74
79
|
self._log_level = log_level
|
75
80
|
self._service = service
|
76
|
-
self._mailbox = mailbox
|
77
81
|
|
78
82
|
self._multiprocessing = multiprocessing # type: ignore
|
79
83
|
self._multiprocessing_setup()
|
@@ -136,7 +140,6 @@ class Backend:
|
|
136
140
|
if self._service:
|
137
141
|
assert self._settings.run_id
|
138
142
|
self.interface = self._service.make_interface(
|
139
|
-
self._mailbox,
|
140
143
|
stream_id=self._settings.run_id,
|
141
144
|
)
|
142
145
|
return
|
@@ -190,11 +193,17 @@ class Backend:
|
|
190
193
|
|
191
194
|
self._module_main_uninstall()
|
192
195
|
|
196
|
+
mailbox = Mailbox()
|
193
197
|
self.interface = InterfaceQueue(
|
194
198
|
process=self.wandb_process,
|
195
199
|
record_q=self.record_q, # type: ignore
|
196
200
|
result_q=self.result_q, # type: ignore
|
197
|
-
mailbox=
|
201
|
+
mailbox=mailbox,
|
202
|
+
)
|
203
|
+
self._router = MessageQueueRouter(
|
204
|
+
request_queue=self.record_q, # type: ignore
|
205
|
+
response_queue=self.result_q, # type: ignore
|
206
|
+
mailbox=mailbox,
|
198
207
|
)
|
199
208
|
|
200
209
|
def server_status(self) -> None:
|
@@ -207,6 +216,8 @@ class Backend:
|
|
207
216
|
self._done = True
|
208
217
|
if self.interface:
|
209
218
|
self.interface.join()
|
219
|
+
if self._router:
|
220
|
+
self._router.join()
|
210
221
|
if self.wandb_process:
|
211
222
|
self.wandb_process.join()
|
212
223
|
|
@@ -53,7 +53,7 @@ class BoundingBoxes2D(JSONMetadata):
|
|
53
53
|
import numpy as np
|
54
54
|
import wandb
|
55
55
|
|
56
|
-
wandb.init()
|
56
|
+
run = wandb.init()
|
57
57
|
image = np.random.randint(low=0, high=256, size=(200, 300, 3))
|
58
58
|
|
59
59
|
class_labels = {0: "person", 1: "car", 2: "road", 3: "building"}
|
@@ -77,7 +77,11 @@ class BoundingBoxes2D(JSONMetadata):
|
|
77
77
|
},
|
78
78
|
{
|
79
79
|
# another box expressed in the pixel domain
|
80
|
-
"position": {
|
80
|
+
"position": {
|
81
|
+
"middle": [150, 20],
|
82
|
+
"width": 68,
|
83
|
+
"height": 112,
|
84
|
+
},
|
81
85
|
"domain": "pixel",
|
82
86
|
"class_id": 3,
|
83
87
|
"box_caption": "a building",
|
@@ -90,7 +94,7 @@ class BoundingBoxes2D(JSONMetadata):
|
|
90
94
|
},
|
91
95
|
)
|
92
96
|
|
93
|
-
|
97
|
+
run.log({"driving_scene": img})
|
94
98
|
```
|
95
99
|
|
96
100
|
### Log a bounding box overlay to a Table
|
@@ -99,7 +103,7 @@ class BoundingBoxes2D(JSONMetadata):
|
|
99
103
|
import numpy as np
|
100
104
|
import wandb
|
101
105
|
|
102
|
-
wandb.init()
|
106
|
+
run = wandb.init()
|
103
107
|
image = np.random.randint(low=0, high=256, size=(200, 300, 3))
|
104
108
|
|
105
109
|
class_labels = {0: "person", 1: "car", 2: "road", 3: "building"}
|
@@ -132,7 +136,11 @@ class BoundingBoxes2D(JSONMetadata):
|
|
132
136
|
},
|
133
137
|
{
|
134
138
|
# another box expressed in the pixel domain
|
135
|
-
"position": {
|
139
|
+
"position": {
|
140
|
+
"middle": [150, 20],
|
141
|
+
"width": 68,
|
142
|
+
"height": 112,
|
143
|
+
},
|
136
144
|
"domain": "pixel",
|
137
145
|
"class_id": 3,
|
138
146
|
"box_caption": "a building",
|
@@ -148,7 +156,7 @@ class BoundingBoxes2D(JSONMetadata):
|
|
148
156
|
|
149
157
|
table = wandb.Table(columns=["image"])
|
150
158
|
table.add_data(img)
|
151
|
-
|
159
|
+
run.log({"driving_scene": table})
|
152
160
|
```
|
153
161
|
"""
|
154
162
|
|
@@ -38,7 +38,7 @@ class ImageMask(Media):
|
|
38
38
|
import numpy as np
|
39
39
|
import wandb
|
40
40
|
|
41
|
-
wandb.init()
|
41
|
+
run = wandb.init()
|
42
42
|
image = np.random.randint(low=0, high=256, size=(100, 100, 3), dtype=np.uint8)
|
43
43
|
predicted_mask = np.empty((100, 100), dtype=np.uint8)
|
44
44
|
ground_truth_mask = np.empty((100, 100), dtype=np.uint8)
|
@@ -58,14 +58,17 @@ class ImageMask(Media):
|
|
58
58
|
masked_image = wandb.Image(
|
59
59
|
image,
|
60
60
|
masks={
|
61
|
-
"predictions": {
|
61
|
+
"predictions": {
|
62
|
+
"mask_data": predicted_mask,
|
63
|
+
"class_labels": class_labels,
|
64
|
+
},
|
62
65
|
"ground_truth": {
|
63
66
|
"mask_data": ground_truth_mask,
|
64
67
|
"class_labels": class_labels,
|
65
68
|
},
|
66
69
|
},
|
67
70
|
)
|
68
|
-
|
71
|
+
run.log({"img_with_masks": masked_image})
|
69
72
|
```
|
70
73
|
|
71
74
|
### Log a masked image inside a Table
|
@@ -74,7 +77,7 @@ class ImageMask(Media):
|
|
74
77
|
import numpy as np
|
75
78
|
import wandb
|
76
79
|
|
77
|
-
wandb.init()
|
80
|
+
run = wandb.init()
|
78
81
|
image = np.random.randint(low=0, high=256, size=(100, 100, 3), dtype=np.uint8)
|
79
82
|
predicted_mask = np.empty((100, 100), dtype=np.uint8)
|
80
83
|
ground_truth_mask = np.empty((100, 100), dtype=np.uint8)
|
@@ -103,7 +106,10 @@ class ImageMask(Media):
|
|
103
106
|
masked_image = wandb.Image(
|
104
107
|
image,
|
105
108
|
masks={
|
106
|
-
"predictions": {
|
109
|
+
"predictions": {
|
110
|
+
"mask_data": predicted_mask,
|
111
|
+
"class_labels": class_labels,
|
112
|
+
},
|
107
113
|
"ground_truth": {
|
108
114
|
"mask_data": ground_truth_mask,
|
109
115
|
"class_labels": class_labels,
|
@@ -114,7 +120,7 @@ class ImageMask(Media):
|
|
114
120
|
|
115
121
|
table = wandb.Table(columns=["image"])
|
116
122
|
table.add_data(masked_image)
|
117
|
-
|
123
|
+
run.log({"random_field": table})
|
118
124
|
```
|
119
125
|
"""
|
120
126
|
|
@@ -1,18 +1,9 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import os
|
2
4
|
import shutil
|
3
5
|
import sys
|
4
|
-
from typing import
|
5
|
-
TYPE_CHECKING,
|
6
|
-
Any,
|
7
|
-
ClassVar,
|
8
|
-
Generic,
|
9
|
-
List,
|
10
|
-
Optional,
|
11
|
-
Type,
|
12
|
-
TypeVar,
|
13
|
-
Union,
|
14
|
-
cast,
|
15
|
-
)
|
6
|
+
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, cast
|
16
7
|
|
17
8
|
import wandb
|
18
9
|
from wandb import util
|
@@ -23,22 +14,24 @@ from wandb.sdk.lib.paths import LogicalPath
|
|
23
14
|
from ._private import MEDIA_TMP
|
24
15
|
from .base_types.wb_value import WBValue
|
25
16
|
|
26
|
-
if TYPE_CHECKING:
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
from types import ModuleType
|
19
|
+
|
27
20
|
import cloudpickle # type: ignore
|
28
21
|
import sklearn # type: ignore
|
29
22
|
import tensorflow # type: ignore
|
30
23
|
import torch # type: ignore
|
24
|
+
from typing_extensions import Self
|
31
25
|
|
32
26
|
from wandb.sdk.artifacts.artifact import Artifact
|
33
|
-
|
34
|
-
from ..wandb_run import Run as LocalRun
|
27
|
+
from wandb.sdk.wandb_run import Run as LocalRun
|
35
28
|
|
36
29
|
|
37
30
|
DEBUG_MODE = False
|
38
31
|
|
39
32
|
|
40
33
|
def _add_deterministic_dir_to_artifact(
|
41
|
-
artifact:
|
34
|
+
artifact: Artifact, dir_name: str, target_dir_root: str
|
42
35
|
) -> str:
|
43
36
|
file_paths = []
|
44
37
|
for dirpath, _, filenames in os.walk(dir_name, topdown=True):
|
@@ -50,7 +43,7 @@ def _add_deterministic_dir_to_artifact(
|
|
50
43
|
return target_path
|
51
44
|
|
52
45
|
|
53
|
-
def _load_dir_from_artifact(source_artifact:
|
46
|
+
def _load_dir_from_artifact(source_artifact: Artifact, path: str) -> str:
|
54
47
|
dl_path = None
|
55
48
|
|
56
49
|
# Look through the entire manifest to find all of the files in the directory.
|
@@ -79,14 +72,12 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
|
|
79
72
|
_log_type: ClassVar[str]
|
80
73
|
_path_extension: ClassVar[str]
|
81
74
|
|
82
|
-
_model_obj:
|
83
|
-
_path:
|
84
|
-
_input_obj_or_path:
|
75
|
+
_model_obj: SavedModelObjType | None
|
76
|
+
_path: str | None
|
77
|
+
_input_obj_or_path: SavedModelObjType | str
|
85
78
|
|
86
79
|
# Public Methods
|
87
|
-
def __init__(
|
88
|
-
self, obj_or_path: Union[SavedModelObjType, str], **kwargs: Any
|
89
|
-
) -> None:
|
80
|
+
def __init__(self, obj_or_path: SavedModelObjType | str, **kwargs: Any) -> None:
|
90
81
|
super().__init__()
|
91
82
|
if self.__class__ == _SavedModel:
|
92
83
|
raise TypeError(
|
@@ -115,7 +106,7 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
|
|
115
106
|
self._unset_obj()
|
116
107
|
|
117
108
|
@staticmethod
|
118
|
-
def init(obj_or_path: Any, **kwargs: Any) ->
|
109
|
+
def init(obj_or_path: Any, **kwargs: Any) -> _SavedModel:
|
119
110
|
maybe_instance = _SavedModel._maybe_init(obj_or_path, **kwargs)
|
120
111
|
if maybe_instance is None:
|
121
112
|
raise ValueError(
|
@@ -125,8 +116,8 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
|
|
125
116
|
|
126
117
|
@classmethod
|
127
118
|
def from_json(
|
128
|
-
cls:
|
129
|
-
) ->
|
119
|
+
cls: type[_SavedModel], json_obj: dict, source_artifact: Artifact
|
120
|
+
) -> _SavedModel:
|
130
121
|
path = json_obj["path"]
|
131
122
|
|
132
123
|
# First, if the entry is a file, the download it.
|
@@ -143,7 +134,7 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
|
|
143
134
|
# and specified adapter.
|
144
135
|
return cls(dl_path)
|
145
136
|
|
146
|
-
def to_json(self, run_or_artifact:
|
137
|
+
def to_json(self, run_or_artifact: LocalRun | Artifact) -> dict:
|
147
138
|
# Unlike other data types, we do not allow adding to a Run directly. There is a
|
148
139
|
# bit of tech debt in the other data types which requires the input to `to_json`
|
149
140
|
# to accept a Run or Artifact. However, Run additions should be deprecated in the future.
|
@@ -218,8 +209,8 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
|
|
218
209
|
# Private Class Methods
|
219
210
|
@classmethod
|
220
211
|
def _maybe_init(
|
221
|
-
cls:
|
222
|
-
) ->
|
212
|
+
cls: type[_SavedModel], obj_or_path: Any, **kwargs: Any
|
213
|
+
) -> _SavedModel | None:
|
223
214
|
# _maybe_init is an exception-safe method that will return an instance of this class
|
224
215
|
# (or any subclass of this class - recursively) OR None if no subclass constructor is found.
|
225
216
|
# We first try the current class, then recursively call this method on children classes. This pattern
|
@@ -241,7 +232,7 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
|
|
241
232
|
return None
|
242
233
|
|
243
234
|
@classmethod
|
244
|
-
def _tmp_path(cls:
|
235
|
+
def _tmp_path(cls: type[_SavedModel]) -> str:
|
245
236
|
# Generates a tmp path under our MEDIA_TMP directory which confirms to the file
|
246
237
|
# or folder preferences of the class.
|
247
238
|
assert isinstance(cls._path_extension, str), "_path_extension must be a string"
|
@@ -286,13 +277,13 @@ PicklingSavedModelObjType = TypeVar("PicklingSavedModelObjType")
|
|
286
277
|
|
287
278
|
|
288
279
|
class _PicklingSavedModel(_SavedModel[SavedModelObjType]):
|
289
|
-
_dep_py_files:
|
290
|
-
_dep_py_files_path:
|
280
|
+
_dep_py_files: list[str] | None = None
|
281
|
+
_dep_py_files_path: str | None = None
|
291
282
|
|
292
283
|
def __init__(
|
293
284
|
self,
|
294
|
-
obj_or_path:
|
295
|
-
dep_py_files:
|
285
|
+
obj_or_path: SavedModelObjType | str,
|
286
|
+
dep_py_files: list[str] | None = None,
|
296
287
|
):
|
297
288
|
super().__init__(obj_or_path)
|
298
289
|
if self.__class__ == _PicklingSavedModel:
|
@@ -319,9 +310,7 @@ class _PicklingSavedModel(_SavedModel[SavedModelObjType]):
|
|
319
310
|
raise ValueError(f"Invalid dependency file: {extra_file}")
|
320
311
|
|
321
312
|
@classmethod
|
322
|
-
def from_json(
|
323
|
-
cls: Type["_SavedModel"], json_obj: dict, source_artifact: "Artifact"
|
324
|
-
) -> "_PicklingSavedModel":
|
313
|
+
def from_json(cls, json_obj: dict, source_artifact: Artifact) -> Self:
|
325
314
|
backup_path = [p for p in sys.path]
|
326
315
|
if (
|
327
316
|
"dep_py_files_path" in json_obj
|
@@ -337,7 +326,7 @@ class _PicklingSavedModel(_SavedModel[SavedModelObjType]):
|
|
337
326
|
|
338
327
|
return inst # type: ignore
|
339
328
|
|
340
|
-
def to_json(self, run_or_artifact:
|
329
|
+
def to_json(self, run_or_artifact: LocalRun | Artifact) -> dict:
|
341
330
|
json_obj = super().to_json(run_or_artifact)
|
342
331
|
assert isinstance(run_or_artifact, wandb.Artifact)
|
343
332
|
if self._dep_py_files_path is not None:
|
@@ -361,7 +350,7 @@ class _PytorchSavedModel(_PicklingSavedModel["torch.nn.Module"]):
|
|
361
350
|
_path_extension = "pt"
|
362
351
|
|
363
352
|
@staticmethod
|
364
|
-
def _deserialize(dir_or_file_path: str) ->
|
353
|
+
def _deserialize(dir_or_file_path: str) -> torch.nn.Module:
|
365
354
|
return _get_torch().load(dir_or_file_path, weights_only=False)
|
366
355
|
|
367
356
|
@staticmethod
|
@@ -369,7 +358,7 @@ class _PytorchSavedModel(_PicklingSavedModel["torch.nn.Module"]):
|
|
369
358
|
return isinstance(obj, _get_torch().nn.Module)
|
370
359
|
|
371
360
|
@staticmethod
|
372
|
-
def _serialize(model_obj:
|
361
|
+
def _serialize(model_obj: torch.nn.Module, dir_or_file_path: str) -> None:
|
373
362
|
_get_torch().save(
|
374
363
|
model_obj,
|
375
364
|
dir_or_file_path,
|
@@ -391,7 +380,7 @@ class _SklearnSavedModel(_PicklingSavedModel["sklearn.base.BaseEstimator"]):
|
|
391
380
|
@staticmethod
|
392
381
|
def _deserialize(
|
393
382
|
dir_or_file_path: str,
|
394
|
-
) ->
|
383
|
+
) -> sklearn.base.BaseEstimator:
|
395
384
|
with open(dir_or_file_path, "rb") as file:
|
396
385
|
model = _get_cloudpickle().load(file)
|
397
386
|
return model
|
@@ -410,16 +399,16 @@ class _SklearnSavedModel(_PicklingSavedModel["sklearn.base.BaseEstimator"]):
|
|
410
399
|
|
411
400
|
@staticmethod
|
412
401
|
def _serialize(
|
413
|
-
model_obj:
|
402
|
+
model_obj: sklearn.base.BaseEstimator, dir_or_file_path: str
|
414
403
|
) -> None:
|
415
404
|
dynamic_cloudpickle = _get_cloudpickle()
|
416
405
|
with open(dir_or_file_path, "wb") as file:
|
417
406
|
dynamic_cloudpickle.dump(model_obj, file)
|
418
407
|
|
419
408
|
|
420
|
-
def _get_tf_keras() ->
|
409
|
+
def _get_tf_keras() -> ModuleType:
|
421
410
|
return cast(
|
422
|
-
|
411
|
+
ModuleType,
|
423
412
|
util.get_module("tensorflow", "ModelAdapter requires `tensorflow`"),
|
424
413
|
).keras
|
425
414
|
|
@@ -431,7 +420,7 @@ class _TensorflowKerasSavedModel(_SavedModel["tensorflow.keras.Model"]):
|
|
431
420
|
@staticmethod
|
432
421
|
def _deserialize(
|
433
422
|
dir_or_file_path: str,
|
434
|
-
) ->
|
423
|
+
) -> tensorflow.keras.Model:
|
435
424
|
return _get_tf_keras().models.load_model(dir_or_file_path)
|
436
425
|
|
437
426
|
@staticmethod
|
@@ -439,7 +428,7 @@ class _TensorflowKerasSavedModel(_SavedModel["tensorflow.keras.Model"]):
|
|
439
428
|
return isinstance(obj, _get_tf_keras().models.Model)
|
440
429
|
|
441
430
|
@staticmethod
|
442
|
-
def _serialize(model_obj:
|
431
|
+
def _serialize(model_obj: tensorflow.keras.Model, dir_or_file_path: str) -> None:
|
443
432
|
_get_tf_keras().models.save_model(
|
444
433
|
model_obj, dir_or_file_path, include_optimizer=True
|
445
434
|
)
|
wandb/sdk/data_types/video.py
CHANGED
@@ -71,10 +71,10 @@ class Video(BatchableMedia):
|
|
71
71
|
import numpy as np
|
72
72
|
import wandb
|
73
73
|
|
74
|
-
wandb.init()
|
74
|
+
run = wandb.init()
|
75
75
|
# axes are (time, channel, height, width)
|
76
76
|
frames = np.random.randint(low=0, high=256, size=(10, 3, 100, 100), dtype=np.uint8)
|
77
|
-
|
77
|
+
run.log({"video": wandb.Video(frames, fps=4)})
|
78
78
|
```
|
79
79
|
"""
|
80
80
|
|
@@ -138,20 +138,11 @@ class Video(BatchableMedia):
|
|
138
138
|
self.encode(fps=fps)
|
139
139
|
|
140
140
|
def encode(self, fps: int = 4) -> None:
|
141
|
-
#
|
142
|
-
mpy =
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
"moviepy.editor",
|
147
|
-
required='wandb.Video requires moviepy when passing raw data. Install with "pip install wandb[media]"',
|
148
|
-
)
|
149
|
-
except wandb.Error:
|
150
|
-
# Fallback to moviepy for MoviePy >= 2.0
|
151
|
-
mpy = util.get_module(
|
152
|
-
"moviepy",
|
153
|
-
required='wandb.Video requires moviepy when passing raw data. Install with "pip install wandb[media]"',
|
154
|
-
)
|
141
|
+
# import ImageSequenceClip from the appropriate MoviePy module
|
142
|
+
mpy = util.get_module(
|
143
|
+
"moviepy.video.io.ImageSequenceClip",
|
144
|
+
required='wandb.Video requires moviepy when passing raw data. Install with "pip install wandb[media]"',
|
145
|
+
)
|
155
146
|
|
156
147
|
tensor = self._prepare_video(self.data)
|
157
148
|
_, self._height, self._width, self._channels = tensor.shape # type: ignore
|