wandb 0.19.6rc4__py3-none-macosx_11_0_arm64.whl → 0.19.8__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +56 -6
  3. wandb/apis/public/_generated/__init__.py +21 -0
  4. wandb/apis/public/_generated/base.py +128 -0
  5. wandb/apis/public/_generated/enums.py +4 -0
  6. wandb/apis/public/_generated/input_types.py +4 -0
  7. wandb/apis/public/_generated/operations.py +15 -0
  8. wandb/apis/public/_generated/server_features_query.py +27 -0
  9. wandb/apis/public/_generated/typing_compat.py +14 -0
  10. wandb/apis/public/api.py +192 -6
  11. wandb/apis/public/artifacts.py +13 -45
  12. wandb/apis/public/registries.py +573 -0
  13. wandb/apis/public/utils.py +36 -0
  14. wandb/bin/gpu_stats +0 -0
  15. wandb/bin/wandb-core +0 -0
  16. wandb/cli/cli.py +11 -20
  17. wandb/data_types.py +1 -1
  18. wandb/env.py +10 -0
  19. wandb/filesync/dir_watcher.py +2 -1
  20. wandb/proto/v3/wandb_internal_pb2.py +243 -222
  21. wandb/proto/v3/wandb_server_pb2.py +4 -4
  22. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  23. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  24. wandb/proto/v4/wandb_internal_pb2.py +226 -222
  25. wandb/proto/v4/wandb_server_pb2.py +4 -4
  26. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  27. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  28. wandb/proto/v5/wandb_internal_pb2.py +226 -222
  29. wandb/proto/v5/wandb_server_pb2.py +4 -4
  30. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  31. wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
  32. wandb/sdk/artifacts/_graphql_fragments.py +126 -0
  33. wandb/sdk/artifacts/artifact.py +51 -95
  34. wandb/sdk/backend/backend.py +17 -6
  35. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +14 -6
  36. wandb/sdk/data_types/helper_types/image_mask.py +12 -6
  37. wandb/sdk/data_types/saved_model.py +35 -46
  38. wandb/sdk/data_types/video.py +7 -16
  39. wandb/sdk/interface/interface.py +87 -49
  40. wandb/sdk/interface/interface_queue.py +5 -15
  41. wandb/sdk/interface/interface_relay.py +7 -22
  42. wandb/sdk/interface/interface_shared.py +65 -136
  43. wandb/sdk/interface/interface_sock.py +3 -21
  44. wandb/sdk/interface/router.py +42 -68
  45. wandb/sdk/interface/router_queue.py +13 -11
  46. wandb/sdk/interface/router_relay.py +26 -13
  47. wandb/sdk/interface/router_sock.py +12 -16
  48. wandb/sdk/internal/handler.py +4 -3
  49. wandb/sdk/internal/internal_api.py +12 -1
  50. wandb/sdk/internal/sender.py +3 -19
  51. wandb/sdk/lib/apikey.py +87 -26
  52. wandb/sdk/lib/asyncio_compat.py +210 -0
  53. wandb/sdk/lib/console_capture.py +172 -0
  54. wandb/sdk/lib/progress.py +78 -16
  55. wandb/sdk/lib/redirect.py +102 -76
  56. wandb/sdk/lib/service_connection.py +37 -17
  57. wandb/sdk/lib/sock_client.py +6 -56
  58. wandb/sdk/mailbox/__init__.py +23 -0
  59. wandb/sdk/mailbox/mailbox.py +135 -0
  60. wandb/sdk/mailbox/mailbox_handle.py +127 -0
  61. wandb/sdk/mailbox/response_handle.py +167 -0
  62. wandb/sdk/mailbox/wait_with_progress.py +135 -0
  63. wandb/sdk/service/server_sock.py +9 -3
  64. wandb/sdk/service/streams.py +75 -78
  65. wandb/sdk/verify/verify.py +54 -2
  66. wandb/sdk/wandb_init.py +72 -75
  67. wandb/sdk/wandb_login.py +7 -4
  68. wandb/sdk/wandb_metadata.py +65 -34
  69. wandb/sdk/wandb_require.py +14 -8
  70. wandb/sdk/wandb_run.py +90 -97
  71. wandb/sdk/wandb_settings.py +10 -4
  72. wandb/sdk/wandb_setup.py +19 -8
  73. wandb/sdk/wandb_sync.py +2 -10
  74. wandb/util.py +3 -1
  75. {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/METADATA +2 -2
  76. {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/RECORD +79 -66
  77. wandb/sdk/interface/message_future.py +0 -27
  78. wandb/sdk/interface/message_future_poll.py +0 -50
  79. wandb/sdk/lib/mailbox.py +0 -442
  80. {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/WHEEL +0 -0
  81. {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/entry_points.txt +0 -0
  82. {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/licenses/LICENSE +0 -0
@@ -16,9 +16,10 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
16
16
  import wandb
17
17
  from wandb.sdk.interface.interface import InterfaceBase
18
18
  from wandb.sdk.interface.interface_queue import InterfaceQueue
19
+ from wandb.sdk.interface.router_queue import MessageQueueRouter
19
20
  from wandb.sdk.internal.internal import wandb_internal
20
21
  from wandb.sdk.internal.settings_static import SettingsStatic
21
- from wandb.sdk.lib.mailbox import Mailbox
22
+ from wandb.sdk.mailbox import Mailbox
22
23
  from wandb.sdk.wandb_settings import Settings
23
24
 
24
25
  if TYPE_CHECKING:
@@ -49,17 +50,18 @@ class BackendThread(threading.Thread):
49
50
  class Backend:
50
51
  # multiprocessing context or module
51
52
  _multiprocessing: multiprocessing.context.BaseContext
53
+
52
54
  interface: Optional[InterfaceBase]
55
+ _router: Optional[MessageQueueRouter]
56
+
53
57
  _internal_pid: Optional[int]
54
58
  wandb_process: Optional[multiprocessing.process.BaseProcess]
55
59
  _settings: Settings
56
60
  record_q: Optional["RecordQueue"]
57
61
  result_q: Optional["ResultQueue"]
58
- _mailbox: Mailbox
59
62
 
60
63
  def __init__(
61
64
  self,
62
- mailbox: Mailbox,
63
65
  settings: Settings,
64
66
  log_level: Optional[int] = None,
65
67
  service: "Optional[service_connection.ServiceConnection]" = None,
@@ -68,12 +70,14 @@ class Backend:
68
70
  self.record_q = None
69
71
  self.result_q = None
70
72
  self.wandb_process = None
73
+
71
74
  self.interface = None
75
+ self._router = None
76
+
72
77
  self._internal_pid = None
73
78
  self._settings = settings
74
79
  self._log_level = log_level
75
80
  self._service = service
76
- self._mailbox = mailbox
77
81
 
78
82
  self._multiprocessing = multiprocessing # type: ignore
79
83
  self._multiprocessing_setup()
@@ -136,7 +140,6 @@ class Backend:
136
140
  if self._service:
137
141
  assert self._settings.run_id
138
142
  self.interface = self._service.make_interface(
139
- self._mailbox,
140
143
  stream_id=self._settings.run_id,
141
144
  )
142
145
  return
@@ -190,11 +193,17 @@ class Backend:
190
193
 
191
194
  self._module_main_uninstall()
192
195
 
196
+ mailbox = Mailbox()
193
197
  self.interface = InterfaceQueue(
194
198
  process=self.wandb_process,
195
199
  record_q=self.record_q, # type: ignore
196
200
  result_q=self.result_q, # type: ignore
197
- mailbox=self._mailbox,
201
+ mailbox=mailbox,
202
+ )
203
+ self._router = MessageQueueRouter(
204
+ request_queue=self.record_q, # type: ignore
205
+ response_queue=self.result_q, # type: ignore
206
+ mailbox=mailbox,
198
207
  )
199
208
 
200
209
  def server_status(self) -> None:
@@ -207,6 +216,8 @@ class Backend:
207
216
  self._done = True
208
217
  if self.interface:
209
218
  self.interface.join()
219
+ if self._router:
220
+ self._router.join()
210
221
  if self.wandb_process:
211
222
  self.wandb_process.join()
212
223
 
@@ -53,7 +53,7 @@ class BoundingBoxes2D(JSONMetadata):
53
53
  import numpy as np
54
54
  import wandb
55
55
 
56
- wandb.init()
56
+ run = wandb.init()
57
57
  image = np.random.randint(low=0, high=256, size=(200, 300, 3))
58
58
 
59
59
  class_labels = {0: "person", 1: "car", 2: "road", 3: "building"}
@@ -77,7 +77,11 @@ class BoundingBoxes2D(JSONMetadata):
77
77
  },
78
78
  {
79
79
  # another box expressed in the pixel domain
80
- "position": {"middle": [150, 20], "width": 68, "height": 112},
80
+ "position": {
81
+ "middle": [150, 20],
82
+ "width": 68,
83
+ "height": 112,
84
+ },
81
85
  "domain": "pixel",
82
86
  "class_id": 3,
83
87
  "box_caption": "a building",
@@ -90,7 +94,7 @@ class BoundingBoxes2D(JSONMetadata):
90
94
  },
91
95
  )
92
96
 
93
- wandb.log({"driving_scene": img})
97
+ run.log({"driving_scene": img})
94
98
  ```
95
99
 
96
100
  ### Log a bounding box overlay to a Table
@@ -99,7 +103,7 @@ class BoundingBoxes2D(JSONMetadata):
99
103
  import numpy as np
100
104
  import wandb
101
105
 
102
- wandb.init()
106
+ run = wandb.init()
103
107
  image = np.random.randint(low=0, high=256, size=(200, 300, 3))
104
108
 
105
109
  class_labels = {0: "person", 1: "car", 2: "road", 3: "building"}
@@ -132,7 +136,11 @@ class BoundingBoxes2D(JSONMetadata):
132
136
  },
133
137
  {
134
138
  # another box expressed in the pixel domain
135
- "position": {"middle": [150, 20], "width": 68, "height": 112},
139
+ "position": {
140
+ "middle": [150, 20],
141
+ "width": 68,
142
+ "height": 112,
143
+ },
136
144
  "domain": "pixel",
137
145
  "class_id": 3,
138
146
  "box_caption": "a building",
@@ -148,7 +156,7 @@ class BoundingBoxes2D(JSONMetadata):
148
156
 
149
157
  table = wandb.Table(columns=["image"])
150
158
  table.add_data(img)
151
- wandb.log({"driving_scene": table})
159
+ run.log({"driving_scene": table})
152
160
  ```
153
161
  """
154
162
 
@@ -38,7 +38,7 @@ class ImageMask(Media):
38
38
  import numpy as np
39
39
  import wandb
40
40
 
41
- wandb.init()
41
+ run = wandb.init()
42
42
  image = np.random.randint(low=0, high=256, size=(100, 100, 3), dtype=np.uint8)
43
43
  predicted_mask = np.empty((100, 100), dtype=np.uint8)
44
44
  ground_truth_mask = np.empty((100, 100), dtype=np.uint8)
@@ -58,14 +58,17 @@ class ImageMask(Media):
58
58
  masked_image = wandb.Image(
59
59
  image,
60
60
  masks={
61
- "predictions": {"mask_data": predicted_mask, "class_labels": class_labels},
61
+ "predictions": {
62
+ "mask_data": predicted_mask,
63
+ "class_labels": class_labels,
64
+ },
62
65
  "ground_truth": {
63
66
  "mask_data": ground_truth_mask,
64
67
  "class_labels": class_labels,
65
68
  },
66
69
  },
67
70
  )
68
- wandb.log({"img_with_masks": masked_image})
71
+ run.log({"img_with_masks": masked_image})
69
72
  ```
70
73
 
71
74
  ### Log a masked image inside a Table
@@ -74,7 +77,7 @@ class ImageMask(Media):
74
77
  import numpy as np
75
78
  import wandb
76
79
 
77
- wandb.init()
80
+ run = wandb.init()
78
81
  image = np.random.randint(low=0, high=256, size=(100, 100, 3), dtype=np.uint8)
79
82
  predicted_mask = np.empty((100, 100), dtype=np.uint8)
80
83
  ground_truth_mask = np.empty((100, 100), dtype=np.uint8)
@@ -103,7 +106,10 @@ class ImageMask(Media):
103
106
  masked_image = wandb.Image(
104
107
  image,
105
108
  masks={
106
- "predictions": {"mask_data": predicted_mask, "class_labels": class_labels},
109
+ "predictions": {
110
+ "mask_data": predicted_mask,
111
+ "class_labels": class_labels,
112
+ },
107
113
  "ground_truth": {
108
114
  "mask_data": ground_truth_mask,
109
115
  "class_labels": class_labels,
@@ -114,7 +120,7 @@ class ImageMask(Media):
114
120
 
115
121
  table = wandb.Table(columns=["image"])
116
122
  table.add_data(masked_image)
117
- wandb.log({"random_field": table})
123
+ run.log({"random_field": table})
118
124
  ```
119
125
  """
120
126
 
@@ -1,18 +1,9 @@
1
+ from __future__ import annotations
2
+
1
3
  import os
2
4
  import shutil
3
5
  import sys
4
- from typing import (
5
- TYPE_CHECKING,
6
- Any,
7
- ClassVar,
8
- Generic,
9
- List,
10
- Optional,
11
- Type,
12
- TypeVar,
13
- Union,
14
- cast,
15
- )
6
+ from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, cast
16
7
 
17
8
  import wandb
18
9
  from wandb import util
@@ -23,22 +14,24 @@ from wandb.sdk.lib.paths import LogicalPath
23
14
  from ._private import MEDIA_TMP
24
15
  from .base_types.wb_value import WBValue
25
16
 
26
- if TYPE_CHECKING: # pragma: no cover
17
+ if TYPE_CHECKING:
18
+ from types import ModuleType
19
+
27
20
  import cloudpickle # type: ignore
28
21
  import sklearn # type: ignore
29
22
  import tensorflow # type: ignore
30
23
  import torch # type: ignore
24
+ from typing_extensions import Self
31
25
 
32
26
  from wandb.sdk.artifacts.artifact import Artifact
33
-
34
- from ..wandb_run import Run as LocalRun
27
+ from wandb.sdk.wandb_run import Run as LocalRun
35
28
 
36
29
 
37
30
  DEBUG_MODE = False
38
31
 
39
32
 
40
33
  def _add_deterministic_dir_to_artifact(
41
- artifact: "Artifact", dir_name: str, target_dir_root: str
34
+ artifact: Artifact, dir_name: str, target_dir_root: str
42
35
  ) -> str:
43
36
  file_paths = []
44
37
  for dirpath, _, filenames in os.walk(dir_name, topdown=True):
@@ -50,7 +43,7 @@ def _add_deterministic_dir_to_artifact(
50
43
  return target_path
51
44
 
52
45
 
53
- def _load_dir_from_artifact(source_artifact: "Artifact", path: str) -> str:
46
+ def _load_dir_from_artifact(source_artifact: Artifact, path: str) -> str:
54
47
  dl_path = None
55
48
 
56
49
  # Look through the entire manifest to find all of the files in the directory.
@@ -79,14 +72,12 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
79
72
  _log_type: ClassVar[str]
80
73
  _path_extension: ClassVar[str]
81
74
 
82
- _model_obj: Optional["SavedModelObjType"]
83
- _path: Optional[str]
84
- _input_obj_or_path: Union[SavedModelObjType, str]
75
+ _model_obj: SavedModelObjType | None
76
+ _path: str | None
77
+ _input_obj_or_path: SavedModelObjType | str
85
78
 
86
79
  # Public Methods
87
- def __init__(
88
- self, obj_or_path: Union[SavedModelObjType, str], **kwargs: Any
89
- ) -> None:
80
+ def __init__(self, obj_or_path: SavedModelObjType | str, **kwargs: Any) -> None:
90
81
  super().__init__()
91
82
  if self.__class__ == _SavedModel:
92
83
  raise TypeError(
@@ -115,7 +106,7 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
115
106
  self._unset_obj()
116
107
 
117
108
  @staticmethod
118
- def init(obj_or_path: Any, **kwargs: Any) -> "_SavedModel":
109
+ def init(obj_or_path: Any, **kwargs: Any) -> _SavedModel:
119
110
  maybe_instance = _SavedModel._maybe_init(obj_or_path, **kwargs)
120
111
  if maybe_instance is None:
121
112
  raise ValueError(
@@ -125,8 +116,8 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
125
116
 
126
117
  @classmethod
127
118
  def from_json(
128
- cls: Type["_SavedModel"], json_obj: dict, source_artifact: "Artifact"
129
- ) -> "_SavedModel":
119
+ cls: type[_SavedModel], json_obj: dict, source_artifact: Artifact
120
+ ) -> _SavedModel:
130
121
  path = json_obj["path"]
131
122
 
132
123
  # First, if the entry is a file, the download it.
@@ -143,7 +134,7 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
143
134
  # and specified adapter.
144
135
  return cls(dl_path)
145
136
 
146
- def to_json(self, run_or_artifact: Union["LocalRun", "Artifact"]) -> dict:
137
+ def to_json(self, run_or_artifact: LocalRun | Artifact) -> dict:
147
138
  # Unlike other data types, we do not allow adding to a Run directly. There is a
148
139
  # bit of tech debt in the other data types which requires the input to `to_json`
149
140
  # to accept a Run or Artifact. However, Run additions should be deprecated in the future.
@@ -218,8 +209,8 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
218
209
  # Private Class Methods
219
210
  @classmethod
220
211
  def _maybe_init(
221
- cls: Type["_SavedModel"], obj_or_path: Any, **kwargs: Any
222
- ) -> Optional["_SavedModel"]:
212
+ cls: type[_SavedModel], obj_or_path: Any, **kwargs: Any
213
+ ) -> _SavedModel | None:
223
214
  # _maybe_init is an exception-safe method that will return an instance of this class
224
215
  # (or any subclass of this class - recursively) OR None if no subclass constructor is found.
225
216
  # We first try the current class, then recursively call this method on children classes. This pattern
@@ -241,7 +232,7 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
241
232
  return None
242
233
 
243
234
  @classmethod
244
- def _tmp_path(cls: Type["_SavedModel"]) -> str:
235
+ def _tmp_path(cls: type[_SavedModel]) -> str:
245
236
  # Generates a tmp path under our MEDIA_TMP directory which confirms to the file
246
237
  # or folder preferences of the class.
247
238
  assert isinstance(cls._path_extension, str), "_path_extension must be a string"
@@ -286,13 +277,13 @@ PicklingSavedModelObjType = TypeVar("PicklingSavedModelObjType")
286
277
 
287
278
 
288
279
  class _PicklingSavedModel(_SavedModel[SavedModelObjType]):
289
- _dep_py_files: Optional[List[str]] = None
290
- _dep_py_files_path: Optional[str] = None
280
+ _dep_py_files: list[str] | None = None
281
+ _dep_py_files_path: str | None = None
291
282
 
292
283
  def __init__(
293
284
  self,
294
- obj_or_path: Union[SavedModelObjType, str],
295
- dep_py_files: Optional[List[str]] = None,
285
+ obj_or_path: SavedModelObjType | str,
286
+ dep_py_files: list[str] | None = None,
296
287
  ):
297
288
  super().__init__(obj_or_path)
298
289
  if self.__class__ == _PicklingSavedModel:
@@ -319,9 +310,7 @@ class _PicklingSavedModel(_SavedModel[SavedModelObjType]):
319
310
  raise ValueError(f"Invalid dependency file: {extra_file}")
320
311
 
321
312
  @classmethod
322
- def from_json(
323
- cls: Type["_SavedModel"], json_obj: dict, source_artifact: "Artifact"
324
- ) -> "_PicklingSavedModel":
313
+ def from_json(cls, json_obj: dict, source_artifact: Artifact) -> Self:
325
314
  backup_path = [p for p in sys.path]
326
315
  if (
327
316
  "dep_py_files_path" in json_obj
@@ -337,7 +326,7 @@ class _PicklingSavedModel(_SavedModel[SavedModelObjType]):
337
326
 
338
327
  return inst # type: ignore
339
328
 
340
- def to_json(self, run_or_artifact: Union["LocalRun", "Artifact"]) -> dict:
329
+ def to_json(self, run_or_artifact: LocalRun | Artifact) -> dict:
341
330
  json_obj = super().to_json(run_or_artifact)
342
331
  assert isinstance(run_or_artifact, wandb.Artifact)
343
332
  if self._dep_py_files_path is not None:
@@ -361,7 +350,7 @@ class _PytorchSavedModel(_PicklingSavedModel["torch.nn.Module"]):
361
350
  _path_extension = "pt"
362
351
 
363
352
  @staticmethod
364
- def _deserialize(dir_or_file_path: str) -> "torch.nn.Module":
353
+ def _deserialize(dir_or_file_path: str) -> torch.nn.Module:
365
354
  return _get_torch().load(dir_or_file_path, weights_only=False)
366
355
 
367
356
  @staticmethod
@@ -369,7 +358,7 @@ class _PytorchSavedModel(_PicklingSavedModel["torch.nn.Module"]):
369
358
  return isinstance(obj, _get_torch().nn.Module)
370
359
 
371
360
  @staticmethod
372
- def _serialize(model_obj: "torch.nn.Module", dir_or_file_path: str) -> None:
361
+ def _serialize(model_obj: torch.nn.Module, dir_or_file_path: str) -> None:
373
362
  _get_torch().save(
374
363
  model_obj,
375
364
  dir_or_file_path,
@@ -391,7 +380,7 @@ class _SklearnSavedModel(_PicklingSavedModel["sklearn.base.BaseEstimator"]):
391
380
  @staticmethod
392
381
  def _deserialize(
393
382
  dir_or_file_path: str,
394
- ) -> "sklearn.base.BaseEstimator":
383
+ ) -> sklearn.base.BaseEstimator:
395
384
  with open(dir_or_file_path, "rb") as file:
396
385
  model = _get_cloudpickle().load(file)
397
386
  return model
@@ -410,16 +399,16 @@ class _SklearnSavedModel(_PicklingSavedModel["sklearn.base.BaseEstimator"]):
410
399
 
411
400
  @staticmethod
412
401
  def _serialize(
413
- model_obj: "sklearn.base.BaseEstimator", dir_or_file_path: str
402
+ model_obj: sklearn.base.BaseEstimator, dir_or_file_path: str
414
403
  ) -> None:
415
404
  dynamic_cloudpickle = _get_cloudpickle()
416
405
  with open(dir_or_file_path, "wb") as file:
417
406
  dynamic_cloudpickle.dump(model_obj, file)
418
407
 
419
408
 
420
- def _get_tf_keras() -> "tensorflow.keras":
409
+ def _get_tf_keras() -> ModuleType:
421
410
  return cast(
422
- "tensorflow",
411
+ ModuleType,
423
412
  util.get_module("tensorflow", "ModelAdapter requires `tensorflow`"),
424
413
  ).keras
425
414
 
@@ -431,7 +420,7 @@ class _TensorflowKerasSavedModel(_SavedModel["tensorflow.keras.Model"]):
431
420
  @staticmethod
432
421
  def _deserialize(
433
422
  dir_or_file_path: str,
434
- ) -> "tensorflow.keras.Model":
423
+ ) -> tensorflow.keras.Model:
435
424
  return _get_tf_keras().models.load_model(dir_or_file_path)
436
425
 
437
426
  @staticmethod
@@ -439,7 +428,7 @@ class _TensorflowKerasSavedModel(_SavedModel["tensorflow.keras.Model"]):
439
428
  return isinstance(obj, _get_tf_keras().models.Model)
440
429
 
441
430
  @staticmethod
442
- def _serialize(model_obj: "tensorflow.keras.Model", dir_or_file_path: str) -> None:
431
+ def _serialize(model_obj: tensorflow.keras.Model, dir_or_file_path: str) -> None:
443
432
  _get_tf_keras().models.save_model(
444
433
  model_obj, dir_or_file_path, include_optimizer=True
445
434
  )
@@ -71,10 +71,10 @@ class Video(BatchableMedia):
71
71
  import numpy as np
72
72
  import wandb
73
73
 
74
- wandb.init()
74
+ run = wandb.init()
75
75
  # axes are (time, channel, height, width)
76
76
  frames = np.random.randint(low=0, high=256, size=(10, 3, 100, 100), dtype=np.uint8)
77
- wandb.log({"video": wandb.Video(frames, fps=4)})
77
+ run.log({"video": wandb.Video(frames, fps=4)})
78
78
  ```
79
79
  """
80
80
 
@@ -138,20 +138,11 @@ class Video(BatchableMedia):
138
138
  self.encode(fps=fps)
139
139
 
140
140
  def encode(self, fps: int = 4) -> None:
141
- # Try to import ImageSequenceClip from the appropriate MoviePy module
142
- mpy = None
143
- try:
144
- # Attempt to load moviepy.editor for MoviePy < 2.0
145
- mpy = util.get_module(
146
- "moviepy.editor",
147
- required='wandb.Video requires moviepy when passing raw data. Install with "pip install wandb[media]"',
148
- )
149
- except wandb.Error:
150
- # Fallback to moviepy for MoviePy >= 2.0
151
- mpy = util.get_module(
152
- "moviepy",
153
- required='wandb.Video requires moviepy when passing raw data. Install with "pip install wandb[media]"',
154
- )
141
+ # import ImageSequenceClip from the appropriate MoviePy module
142
+ mpy = util.get_module(
143
+ "moviepy.video.io.ImageSequenceClip",
144
+ required='wandb.Video requires moviepy when passing raw data. Install with "pip install wandb[media]"',
145
+ )
155
146
 
156
147
  tensor = self._prepare_video(self.data)
157
148
  _, self._height, self._width, self._channels = tensor.shape # type: ignore