wandb 0.19.12rc1__py3-none-win32.whl → 0.20.1__py3-none-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172) hide show
  1. wandb/__init__.py +1 -2
  2. wandb/__init__.pyi +3 -6
  3. wandb/_iterutils.py +26 -7
  4. wandb/_pydantic/__init__.py +2 -1
  5. wandb/_pydantic/utils.py +7 -0
  6. wandb/agents/pyagent.py +9 -15
  7. wandb/analytics/sentry.py +1 -2
  8. wandb/apis/attrs.py +3 -4
  9. wandb/apis/importers/internals/util.py +1 -1
  10. wandb/apis/importers/validation.py +2 -2
  11. wandb/apis/importers/wandb.py +30 -25
  12. wandb/apis/normalize.py +2 -2
  13. wandb/apis/public/__init__.py +1 -0
  14. wandb/apis/public/api.py +37 -33
  15. wandb/apis/public/artifacts.py +103 -72
  16. wandb/apis/public/jobs.py +3 -2
  17. wandb/apis/public/registries/registries_search.py +4 -2
  18. wandb/apis/public/registries/registry.py +1 -1
  19. wandb/apis/public/registries/utils.py +9 -9
  20. wandb/apis/public/runs.py +18 -6
  21. wandb/automations/_filters/expressions.py +1 -1
  22. wandb/automations/_filters/operators.py +1 -1
  23. wandb/automations/_filters/run_metrics.py +1 -1
  24. wandb/beta/workflows.py +6 -5
  25. wandb/bin/gpu_stats.exe +0 -0
  26. wandb/bin/wandb-core +0 -0
  27. wandb/cli/cli.py +54 -73
  28. wandb/docker/__init__.py +21 -74
  29. wandb/docker/names.py +40 -0
  30. wandb/env.py +0 -1
  31. wandb/errors/util.py +1 -1
  32. wandb/filesync/step_checksum.py +1 -1
  33. wandb/filesync/step_upload.py +1 -1
  34. wandb/integration/diffusers/resolvers/multimodal.py +1 -2
  35. wandb/integration/gym/__init__.py +5 -6
  36. wandb/integration/keras/callbacks/model_checkpoint.py +2 -2
  37. wandb/integration/keras/keras.py +13 -19
  38. wandb/integration/kfp/kfp_patch.py +2 -3
  39. wandb/integration/langchain/wandb_tracer.py +1 -1
  40. wandb/integration/metaflow/metaflow.py +13 -13
  41. wandb/integration/openai/fine_tuning.py +3 -2
  42. wandb/integration/sagemaker/auth.py +2 -1
  43. wandb/integration/sklearn/utils.py +2 -1
  44. wandb/integration/tensorboard/__init__.py +1 -1
  45. wandb/integration/tensorboard/log.py +2 -5
  46. wandb/integration/tensorflow/__init__.py +2 -2
  47. wandb/jupyter.py +20 -17
  48. wandb/plot/confusion_matrix.py +1 -1
  49. wandb/plot/utils.py +8 -7
  50. wandb/proto/v3/wandb_internal_pb2.py +355 -335
  51. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  52. wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
  53. wandb/proto/v4/wandb_internal_pb2.py +339 -335
  54. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  55. wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
  56. wandb/proto/v5/wandb_internal_pb2.py +339 -335
  57. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  58. wandb/proto/v5/wandb_telemetry_pb2.py +12 -12
  59. wandb/proto/v6/wandb_internal_pb2.py +339 -335
  60. wandb/proto/v6/wandb_settings_pb2.py +2 -2
  61. wandb/proto/v6/wandb_telemetry_pb2.py +12 -12
  62. wandb/proto/wandb_deprecated.py +6 -8
  63. wandb/sdk/artifacts/_internal_artifact.py +43 -0
  64. wandb/sdk/artifacts/_validators.py +55 -35
  65. wandb/sdk/artifacts/artifact.py +117 -115
  66. wandb/sdk/artifacts/artifact_download_logger.py +2 -0
  67. wandb/sdk/artifacts/artifact_saver.py +1 -3
  68. wandb/sdk/artifacts/artifact_state.py +2 -0
  69. wandb/sdk/artifacts/artifact_ttl.py +2 -0
  70. wandb/sdk/artifacts/exceptions.py +14 -0
  71. wandb/sdk/artifacts/staging.py +2 -0
  72. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +2 -6
  73. wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -1
  74. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +2 -6
  75. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -5
  76. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -1
  77. wandb/sdk/artifacts/storage_layout.py +2 -0
  78. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +3 -3
  79. wandb/sdk/backend/backend.py +11 -182
  80. wandb/sdk/data_types/_dtypes.py +2 -6
  81. wandb/sdk/data_types/audio.py +20 -3
  82. wandb/sdk/data_types/base_types/media.py +12 -7
  83. wandb/sdk/data_types/base_types/wb_value.py +8 -18
  84. wandb/sdk/data_types/bokeh.py +19 -2
  85. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +17 -1
  86. wandb/sdk/data_types/helper_types/image_mask.py +7 -1
  87. wandb/sdk/data_types/html.py +4 -4
  88. wandb/sdk/data_types/image.py +178 -103
  89. wandb/sdk/data_types/molecule.py +6 -6
  90. wandb/sdk/data_types/object_3d.py +10 -5
  91. wandb/sdk/data_types/saved_model.py +11 -6
  92. wandb/sdk/data_types/table.py +313 -83
  93. wandb/sdk/data_types/table_decorators.py +108 -0
  94. wandb/sdk/data_types/utils.py +43 -7
  95. wandb/sdk/data_types/video.py +21 -3
  96. wandb/sdk/interface/interface.py +10 -0
  97. wandb/sdk/internal/datastore.py +2 -6
  98. wandb/sdk/internal/file_pusher.py +1 -5
  99. wandb/sdk/internal/file_stream.py +8 -17
  100. wandb/sdk/internal/handler.py +2 -2
  101. wandb/sdk/internal/incremental_table_util.py +53 -0
  102. wandb/sdk/internal/internal.py +3 -5
  103. wandb/sdk/internal/internal_api.py +66 -89
  104. wandb/sdk/internal/job_builder.py +2 -7
  105. wandb/sdk/internal/profiler.py +2 -2
  106. wandb/sdk/internal/progress.py +1 -3
  107. wandb/sdk/internal/run.py +1 -6
  108. wandb/sdk/internal/sender.py +24 -36
  109. wandb/sdk/internal/system/assets/aggregators.py +1 -7
  110. wandb/sdk/internal/system/assets/disk.py +3 -3
  111. wandb/sdk/internal/system/assets/gpu.py +4 -4
  112. wandb/sdk/internal/system/assets/gpu_amd.py +4 -4
  113. wandb/sdk/internal/system/assets/interfaces.py +6 -6
  114. wandb/sdk/internal/system/assets/tpu.py +1 -1
  115. wandb/sdk/internal/system/assets/trainium.py +6 -6
  116. wandb/sdk/internal/system/system_info.py +5 -7
  117. wandb/sdk/internal/system/system_monitor.py +4 -4
  118. wandb/sdk/internal/tb_watcher.py +5 -7
  119. wandb/sdk/launch/_launch.py +1 -1
  120. wandb/sdk/launch/_project_spec.py +19 -20
  121. wandb/sdk/launch/agent/agent.py +3 -3
  122. wandb/sdk/launch/agent/config.py +1 -1
  123. wandb/sdk/launch/agent/job_status_tracker.py +2 -2
  124. wandb/sdk/launch/builder/build.py +2 -3
  125. wandb/sdk/launch/builder/kaniko_builder.py +5 -4
  126. wandb/sdk/launch/environment/gcp_environment.py +1 -2
  127. wandb/sdk/launch/registry/azure_container_registry.py +2 -2
  128. wandb/sdk/launch/registry/elastic_container_registry.py +2 -2
  129. wandb/sdk/launch/registry/google_artifact_registry.py +3 -3
  130. wandb/sdk/launch/runner/abstract.py +5 -5
  131. wandb/sdk/launch/runner/kubernetes_monitor.py +2 -2
  132. wandb/sdk/launch/runner/kubernetes_runner.py +1 -1
  133. wandb/sdk/launch/runner/sagemaker_runner.py +2 -4
  134. wandb/sdk/launch/runner/vertex_runner.py +2 -7
  135. wandb/sdk/launch/sweeps/__init__.py +1 -1
  136. wandb/sdk/launch/sweeps/scheduler.py +2 -2
  137. wandb/sdk/launch/sweeps/utils.py +3 -3
  138. wandb/sdk/launch/utils.py +3 -4
  139. wandb/sdk/lib/apikey.py +5 -8
  140. wandb/sdk/lib/config_util.py +3 -3
  141. wandb/sdk/lib/fsm.py +3 -18
  142. wandb/sdk/lib/gitlib.py +6 -5
  143. wandb/sdk/lib/ipython.py +2 -2
  144. wandb/sdk/lib/json_util.py +9 -14
  145. wandb/sdk/lib/printer.py +3 -8
  146. wandb/sdk/lib/redirect.py +1 -1
  147. wandb/sdk/lib/retry.py +3 -7
  148. wandb/sdk/lib/run_moment.py +2 -2
  149. wandb/sdk/lib/service_connection.py +3 -1
  150. wandb/sdk/lib/service_token.py +1 -2
  151. wandb/sdk/mailbox/mailbox_handle.py +3 -7
  152. wandb/sdk/mailbox/response_handle.py +2 -6
  153. wandb/sdk/service/streams.py +3 -7
  154. wandb/sdk/verify/verify.py +5 -6
  155. wandb/sdk/wandb_config.py +1 -1
  156. wandb/sdk/wandb_init.py +38 -106
  157. wandb/sdk/wandb_login.py +7 -6
  158. wandb/sdk/wandb_run.py +52 -240
  159. wandb/sdk/wandb_settings.py +71 -60
  160. wandb/sdk/wandb_setup.py +40 -14
  161. wandb/sdk/wandb_watch.py +5 -7
  162. wandb/sync/__init__.py +1 -1
  163. wandb/sync/sync.py +13 -13
  164. wandb/util.py +17 -35
  165. wandb/wandb_agent.py +8 -11
  166. {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/METADATA +5 -5
  167. {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/RECORD +170 -168
  168. wandb/docker/auth.py +0 -435
  169. wandb/docker/www_authenticate.py +0 -94
  170. {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/WHEEL +0 -0
  171. {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/entry_points.txt +0 -0
  172. {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/licenses/LICENSE +0 -0
@@ -4,206 +4,45 @@ Manage backend.
4
4
 
5
5
  """
6
6
 
7
- import importlib.machinery
8
7
  import logging
9
- import multiprocessing
10
- import os
11
- import queue
12
- import sys
13
- import threading
14
- from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
15
-
16
- import wandb
8
+ from typing import TYPE_CHECKING, Optional
9
+
17
10
  from wandb.sdk.interface.interface import InterfaceBase
18
- from wandb.sdk.interface.interface_queue import InterfaceQueue
19
- from wandb.sdk.interface.router_queue import MessageQueueRouter
20
- from wandb.sdk.internal.internal import wandb_internal
21
- from wandb.sdk.internal.settings_static import SettingsStatic
22
- from wandb.sdk.mailbox import Mailbox
23
11
  from wandb.sdk.wandb_settings import Settings
24
12
 
25
13
  if TYPE_CHECKING:
26
- from wandb.proto.wandb_internal_pb2 import Record, Result
27
14
  from wandb.sdk.lib import service_connection
28
15
 
29
- RecordQueue = Union["queue.Queue[Record]", multiprocessing.Queue[Record]]
30
- ResultQueue = Union["queue.Queue[Result]", multiprocessing.Queue[Result]]
31
-
32
16
  logger = logging.getLogger("wandb")
33
17
 
34
18
 
35
- class BackendThread(threading.Thread):
36
- """Class to running internal process as a thread."""
37
-
38
- def __init__(self, target: Callable, kwargs: Dict[str, Any]) -> None:
39
- threading.Thread.__init__(self)
40
- self.name = "BackendThr"
41
- self._target = target
42
- self._kwargs = kwargs
43
- self.daemon = True
44
- self.pid = 0
45
-
46
- def run(self) -> None:
47
- self._target(**self._kwargs)
48
-
49
-
50
19
  class Backend:
51
- # multiprocessing context or module
52
- _multiprocessing: multiprocessing.context.BaseContext
53
-
54
20
  interface: Optional[InterfaceBase]
55
- _router: Optional[MessageQueueRouter]
56
21
 
57
- _internal_pid: Optional[int]
58
- wandb_process: Optional[multiprocessing.process.BaseProcess]
59
22
  _settings: Settings
60
- record_q: Optional["RecordQueue"]
61
- result_q: Optional["ResultQueue"]
23
+
24
+ _done: bool
25
+
26
+ _service: Optional["service_connection.ServiceConnection"]
62
27
 
63
28
  def __init__(
64
29
  self,
65
30
  settings: Settings,
66
- log_level: Optional[int] = None,
67
- service: "Optional[service_connection.ServiceConnection]" = None,
31
+ service: Optional["service_connection.ServiceConnection"] = None,
68
32
  ) -> None:
69
33
  self._done = False
70
- self.record_q = None
71
- self.result_q = None
72
- self.wandb_process = None
73
34
 
74
35
  self.interface = None
75
- self._router = None
76
36
 
77
- self._internal_pid = None
78
37
  self._settings = settings
79
- self._log_level = log_level
80
38
  self._service = service
81
39
 
82
- self._multiprocessing = multiprocessing # type: ignore
83
- self._multiprocessing_setup()
84
-
85
- # for _module_main_* methods
86
- self._save_mod_path: Optional[str] = None
87
- self._save_mod_spec = None
88
-
89
- def _multiprocessing_setup(self) -> None:
90
- if self._settings.start_method == "thread":
91
- return
92
-
93
- # defaulting to spawn for now, fork needs more testing
94
- start_method = self._settings.start_method or "spawn"
95
-
96
- # TODO: use fork context if unix and frozen?
97
- # if py34+, else fall back
98
- if not hasattr(multiprocessing, "get_context"):
99
- return
100
- all_methods = multiprocessing.get_all_start_methods()
101
- logger.info(
102
- "multiprocessing start_methods={}, using: {}".format(
103
- ",".join(all_methods), start_method
104
- )
105
- )
106
- ctx = multiprocessing.get_context(start_method)
107
- self._multiprocessing = ctx
108
-
109
- def _module_main_install(self) -> None:
110
- # Support running code without a: __name__ == "__main__"
111
- main_module = sys.modules["__main__"]
112
- main_mod_spec = getattr(main_module, "__spec__", None)
113
- main_mod_path = getattr(main_module, "__file__", None)
114
- if main_mod_spec is None: # hack for pdb
115
- # Note: typing has trouble with BuiltinImporter
116
- loader: Loader = importlib.machinery.BuiltinImporter # type: ignore # noqa: F821
117
- main_mod_spec = importlib.machinery.ModuleSpec(
118
- name="wandb.mpmain", loader=loader
119
- )
120
- main_module.__spec__ = main_mod_spec
121
- else:
122
- self._save_mod_spec = main_mod_spec
123
-
124
- if main_mod_path is not None:
125
- self._save_mod_path = main_module.__file__
126
- fname = os.path.join(
127
- os.path.dirname(wandb.__file__), "mpmain", "__main__.py"
128
- )
129
- main_module.__file__ = fname
130
-
131
- def _module_main_uninstall(self) -> None:
132
- main_module = sys.modules["__main__"]
133
- # Undo temporary changes from: __name__ == "__main__"
134
- main_module.__spec__ = self._save_mod_spec
135
- if self._save_mod_path:
136
- main_module.__file__ = self._save_mod_path
137
-
138
40
  def ensure_launched(self) -> None:
139
41
  """Launch backend worker if not running."""
140
- if self._service:
141
- assert self._settings.run_id
142
- self.interface = self._service.make_interface(
143
- stream_id=self._settings.run_id,
144
- )
145
- return
146
-
147
- settings = self._settings.model_copy()
148
- settings.x_log_level = self._log_level or logging.DEBUG
149
-
150
- start_method = settings.start_method
151
-
152
- settings_static = SettingsStatic(settings.to_proto())
153
- user_pid = os.getpid()
154
-
155
- if start_method == "thread":
156
- self.record_q = queue.Queue()
157
- self.result_q = queue.Queue()
158
- wandb._set_internal_process(disable=True) # type: ignore
159
- wandb_thread = BackendThread(
160
- target=wandb_internal,
161
- kwargs=dict(
162
- settings=settings_static,
163
- record_q=self.record_q,
164
- result_q=self.result_q,
165
- user_pid=user_pid,
166
- ),
167
- )
168
- # TODO: risky cast, assumes BackendThread Process duck typing
169
- self.wandb_process = wandb_thread # type: ignore
170
- else:
171
- self.record_q = self._multiprocessing.Queue()
172
- self.result_q = self._multiprocessing.Queue()
173
- self.wandb_process = self._multiprocessing.Process( # type: ignore
174
- target=wandb_internal,
175
- kwargs=dict(
176
- settings=settings_static,
177
- record_q=self.record_q,
178
- result_q=self.result_q,
179
- user_pid=user_pid,
180
- ),
181
- )
182
- assert self.wandb_process
183
- self.wandb_process.name = "wandb_internal"
184
-
185
- self._module_main_install()
186
-
187
- logger.info("starting backend process...")
188
- # Start the process with __name__ == "__main__" workarounds
189
- assert self.wandb_process
190
- self.wandb_process.start()
191
- self._internal_pid = self.wandb_process.pid
192
- logger.info(f"started backend process with pid: {self.wandb_process.pid}")
193
-
194
- self._module_main_uninstall()
195
-
196
- mailbox = Mailbox()
197
- self.interface = InterfaceQueue(
198
- process=self.wandb_process,
199
- record_q=self.record_q, # type: ignore
200
- result_q=self.result_q, # type: ignore
201
- mailbox=mailbox,
202
- )
203
- self._router = MessageQueueRouter(
204
- request_queue=self.record_q, # type: ignore
205
- response_queue=self.result_q, # type: ignore
206
- mailbox=mailbox,
42
+ assert self._settings.run_id
43
+ assert self._service
44
+ self.interface = self._service.make_interface(
45
+ stream_id=self._settings.run_id,
207
46
  )
208
47
 
209
48
  def server_status(self) -> None:
@@ -216,13 +55,3 @@ class Backend:
216
55
  self._done = True
217
56
  if self.interface:
218
57
  self.interface.join()
219
- if self._router:
220
- self._router.join()
221
- if self.wandb_process:
222
- self.wandb_process.join()
223
-
224
- if self.record_q and hasattr(self.record_q, "close"):
225
- self.record_q.close()
226
- if self.result_q and hasattr(self.result_q, "close"):
227
- self.result_q.close()
228
- # No printing allowed from here until redirect restore!!!
@@ -686,9 +686,7 @@ class ListType(Type):
686
686
  for ndx, obj in enumerate(list(other)):
687
687
  _new_element_type = new_element_type.assign(obj)
688
688
  if isinstance(_new_element_type, InvalidType):
689
- exp += "\n{}Index {}:\n{}".format(
690
- gap, ndx, new_element_type.explain(obj, depth + 1)
691
- )
689
+ exp += f"\n{gap}Index {ndx}:\n{new_element_type.explain(obj, depth + 1)}"
692
690
  break
693
691
  new_element_type = _new_element_type
694
692
  return exp
@@ -727,9 +725,7 @@ class NDArrayType(Type):
727
725
  return cls(shape)
728
726
  else:
729
727
  raise TypeError(
730
- "NDArrayType.from_obj expects py_obj to be ndarray or list, found {}".format(
731
- py_obj.__class__
732
- )
728
+ f"NDArrayType.from_obj expects py_obj to be ndarray or list, found {py_obj.__class__}"
733
729
  )
734
730
 
735
731
  def assign_type(self, wb_type: "Type") -> t.Union["NDArrayType", InvalidType]:
@@ -1,6 +1,7 @@
1
1
  import hashlib
2
2
  import os
3
- from typing import Optional
3
+ import pathlib
4
+ from typing import TYPE_CHECKING, Optional, Union
4
5
 
5
6
  from wandb import util
6
7
  from wandb.sdk.lib import filesystem, runid
@@ -9,6 +10,9 @@ from . import _dtypes
9
10
  from ._private import MEDIA_TMP
10
11
  from .base_types.media import BatchableMedia
11
12
 
13
+ if TYPE_CHECKING:
14
+ import numpy as np
15
+
12
16
 
13
17
  class Audio(BatchableMedia):
14
18
  """Wandb class for audio clips.
@@ -23,13 +27,25 @@ class Audio(BatchableMedia):
23
27
 
24
28
  _log_type = "audio-file"
25
29
 
26
- def __init__(self, data_or_path, sample_rate=None, caption=None):
30
+ def __init__(
31
+ self,
32
+ data_or_path: Union[
33
+ str,
34
+ pathlib.Path,
35
+ list,
36
+ "np.ndarray",
37
+ ],
38
+ sample_rate: Optional[int] = None,
39
+ caption: Optional[str] = None,
40
+ ):
27
41
  """Accept a path to an audio file or a numpy array of audio data."""
28
42
  super().__init__(caption=caption)
29
43
  self._duration = None
30
44
  self._sample_rate = sample_rate
31
45
 
32
- if isinstance(data_or_path, str):
46
+ if isinstance(data_or_path, (str, pathlib.Path)):
47
+ data_or_path = str(data_or_path)
48
+
33
49
  if self.path_is_reference(data_or_path):
34
50
  self._path = data_or_path
35
51
  self._sha256 = hashlib.sha256(data_or_path.encode("utf-8")).hexdigest()
@@ -48,6 +64,7 @@ class Audio(BatchableMedia):
48
64
  )
49
65
 
50
66
  tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".wav")
67
+
51
68
  soundfile.write(tmp_path, data_or_path, sample_rate)
52
69
  self._duration = len(data_or_path) / float(sample_rate)
53
70
 
@@ -1,5 +1,6 @@
1
1
  import hashlib
2
2
  import os
3
+ import pathlib
3
4
  import platform
4
5
  import re
5
6
  import shutil
@@ -118,7 +119,10 @@ class Media(WBValue):
118
119
  self._caption = caption
119
120
 
120
121
  def _set_file(
121
- self, path: str, is_tmp: bool = False, extension: Optional[str] = None
122
+ self,
123
+ path: str,
124
+ is_tmp: bool = False,
125
+ extension: Optional[str] = None,
122
126
  ) -> None:
123
127
  self._path = path
124
128
  self._is_tmp = is_tmp
@@ -195,9 +199,9 @@ class Media(WBValue):
195
199
  else:
196
200
  try:
197
201
  shutil.copy(self._path, new_path)
198
- except shutil.SameFileError as e:
202
+ except shutil.SameFileError:
199
203
  if not ignore_copy_err:
200
- raise e
204
+ raise
201
205
  self._path = new_path
202
206
  run._publish_file(media_path)
203
207
 
@@ -243,9 +247,7 @@ class Media(WBValue):
243
247
  json_obj["_latest_artifact_path"] = artifact_entry_latest_url
244
248
 
245
249
  if artifact_entry_url is None or self.is_bound():
246
- assert self.is_bound(), "Value of type {} must be bound to a run with bind_to_run() before being serialized to JSON.".format(
247
- type(self).__name__
248
- )
250
+ assert self.is_bound(), f"Value of type {type(self).__name__} must be bound to a run with bind_to_run() before being serialized to JSON."
249
251
 
250
252
  assert (
251
253
  self._run is run
@@ -329,7 +331,10 @@ class Media(WBValue):
329
331
  )
330
332
 
331
333
  @staticmethod
332
- def path_is_reference(path: Optional[str]) -> bool:
334
+ def path_is_reference(path: Optional[Union[str, pathlib.Path]]) -> bool:
335
+ if path is None or isinstance(path, pathlib.Path):
336
+ return False
337
+
333
338
  return bool(path and re.match(r"^(gs|s3|https?)://", path))
334
339
 
335
340
 
@@ -22,7 +22,7 @@ def _is_maybe_offline() -> bool:
22
22
  Returns:
23
23
  Whether the user likely configured wandb to be offline.
24
24
  """
25
- singleton = wandb_setup._setup(start_service=False)
25
+ singleton = wandb_setup.singleton()
26
26
 
27
27
  # First check: if there's a run, check if it is offline.
28
28
  #
@@ -37,7 +37,7 @@ def _is_maybe_offline() -> bool:
37
37
 
38
38
 
39
39
  def _server_accepts_client_ids() -> bool:
40
- from wandb.util import parse_version
40
+ from packaging.version import parse
41
41
 
42
42
  # There are versions of W&B Server that cannot accept client IDs. Those versions of
43
43
  # the backend have a max_cli_version of less than "0.11.0." If the backend cannot
@@ -51,7 +51,7 @@ def _server_accepts_client_ids() -> bool:
51
51
  # client IDs.
52
52
 
53
53
  if _is_maybe_offline():
54
- singleton = wandb_setup._setup(start_service=False)
54
+ singleton = wandb_setup.singleton()
55
55
 
56
56
  if run := singleton.most_recent_active_run:
57
57
  return run._settings.allow_offline_artifacts
@@ -63,7 +63,7 @@ def _server_accepts_client_ids() -> bool:
63
63
  max_cli_version = util._get_max_cli_version()
64
64
  if max_cli_version is None:
65
65
  return False
66
- accepts_client_ids: bool = parse_version("0.11.0") <= parse_version(max_cli_version)
66
+ accepts_client_ids: bool = parse(max_cli_version) >= parse("0.11.0")
67
67
  return accepts_client_ids
68
68
 
69
69
 
@@ -220,9 +220,7 @@ class WBValue:
220
220
  ) -> None:
221
221
  assert (
222
222
  self._artifact_source is None
223
- ), "Cannot update artifact_source. Existing source: {}/{}".format(
224
- self._artifact_source.artifact, self._artifact_source.name
225
- )
223
+ ), f"Cannot update artifact_source. Existing source: {self._artifact_source.artifact}/{self._artifact_source.name}"
226
224
  self._artifact_source = _WBValueArtifactSource(artifact, name)
227
225
 
228
226
  def _set_artifact_target(
@@ -230,9 +228,7 @@ class WBValue:
230
228
  ) -> None:
231
229
  assert (
232
230
  self._artifact_target is None
233
- ), "Cannot update artifact_target. Existing target: {}/{}".format(
234
- self._artifact_target.artifact, self._artifact_target.name
235
- )
231
+ ), f"Cannot update artifact_target. Existing target: {self._artifact_target.artifact}/{self._artifact_target.name}"
236
232
  self._artifact_target = _WBValueArtifactTarget(artifact, name)
237
233
 
238
234
  def _get_artifact_entry_ref_url(self) -> Optional[str]:
@@ -250,10 +246,7 @@ class WBValue:
250
246
  and self._artifact_target.artifact._final
251
247
  and _server_accepts_client_ids()
252
248
  ):
253
- return "wandb-client-artifact://{}/{}".format(
254
- self._artifact_target.artifact._client_id,
255
- type(self).with_suffix(self._artifact_target.name),
256
- )
249
+ return f"wandb-client-artifact://{self._artifact_target.artifact._client_id}/{type(self).with_suffix(self._artifact_target.name)}"
257
250
  # Else if we do not support client IDs, but online, then block on upload
258
251
  # Note: this is old behavior just to stay backwards compatible
259
252
  # with older server versions. This code path should be removed
@@ -281,10 +274,7 @@ class WBValue:
281
274
  and self._artifact_target.artifact._final
282
275
  and _server_accepts_client_ids()
283
276
  ):
284
- return "wandb-client-artifact://{}:latest/{}".format(
285
- self._artifact_target.artifact._sequence_client_id,
286
- type(self).with_suffix(self._artifact_target.name),
287
- )
277
+ return f"wandb-client-artifact://{self._artifact_target.artifact._sequence_client_id}:latest/{type(self).with_suffix(self._artifact_target.name)}"
288
278
  # Else if we do not support client IDs, then block on upload
289
279
  # Note: this is old behavior just to stay backwards compatible
290
280
  # with older server versions. This code path should be removed
@@ -1,6 +1,8 @@
1
1
  import codecs
2
2
  import json
3
3
  import os
4
+ import pathlib
5
+ from typing import TYPE_CHECKING, Union
4
6
 
5
7
  from wandb import util
6
8
  from wandb.sdk.lib import runid
@@ -9,6 +11,9 @@ from . import _dtypes
9
11
  from ._private import MEDIA_TMP
10
12
  from .base_types.media import Media
11
13
 
14
+ if TYPE_CHECKING:
15
+ from bokeh import document, model
16
+
12
17
 
13
18
  class Bokeh(Media):
14
19
  """Wandb class for Bokeh plots.
@@ -19,10 +24,22 @@ class Bokeh(Media):
19
24
 
20
25
  _log_type = "bokeh-file"
21
26
 
22
- def __init__(self, data_or_path):
27
+ def __init__(
28
+ self,
29
+ data_or_path: Union[
30
+ str,
31
+ pathlib.Path,
32
+ "document.Document",
33
+ "model.Model",
34
+ ],
35
+ ):
23
36
  super().__init__()
24
37
  bokeh = util.get_module("bokeh", required=True)
25
- if isinstance(data_or_path, str) and os.path.exists(data_or_path):
38
+ if isinstance(data_or_path, (str, pathlib.Path)) and os.path.exists(
39
+ data_or_path
40
+ ):
41
+ data_or_path = str(data_or_path)
42
+
26
43
  with open(data_or_path) as file:
27
44
  b_json = json.load(file)
28
45
  self.b_obj = bokeh.document.Document.from_json(b_json)
@@ -13,6 +13,18 @@ if TYPE_CHECKING: # pragma: no cover
13
13
  from ...wandb_run import Run as LocalRun
14
14
 
15
15
 
16
+ def _convert_pytorch_tensor_to_list(box_data):
17
+ for box in box_data:
18
+ if (
19
+ "position" in box
20
+ and "middle" in box["position"]
21
+ and util.is_pytorch_tensor_typename(
22
+ util.get_full_typename(box["position"]["middle"])
23
+ )
24
+ ):
25
+ box["position"]["middle"] = box["position"]["middle"].tolist()
26
+
27
+
16
28
  class BoundingBoxes2D(JSONMetadata):
17
29
  """Format images with 2D bounding box overlays for logging to W&B.
18
30
 
@@ -195,7 +207,11 @@ class BoundingBoxes2D(JSONMetadata):
195
207
  key: (string) The readable name or id for this set of bounding boxes (e.g.
196
208
  predictions, ground_truth)
197
209
  """
210
+ # Pytorch tensors are not serializable to json,
211
+ # so we convert them to lists to avoid errors later on.
212
+ _convert_pytorch_tensor_to_list(val.get("box_data", []))
198
213
  super().__init__(val)
214
+
199
215
  self._val = val["box_data"]
200
216
  self._key = key
201
217
  # Add default class mapping
@@ -304,7 +320,7 @@ class BoundingBoxes2D(JSONMetadata):
304
320
  # an object with a _type key. Will need to push this change to the UI first to ensure backwards compat
305
321
  return self._val
306
322
  else:
307
- raise ValueError("to_json accepts wandb_run.Run or wandb.Artifact")
323
+ raise TypeError("to_json accepts wandb_run.Run or wandb.Artifact")
308
324
 
309
325
  @classmethod
310
326
  def from_json(
@@ -147,6 +147,12 @@ class ImageMask(Media):
147
147
  self._set_file(val["path"])
148
148
  else:
149
149
  np = util.get_module("numpy", required="Image mask support requires numpy")
150
+
151
+ if util.is_pytorch_tensor_typename(
152
+ util.get_full_typename(val["mask_data"])
153
+ ):
154
+ val["mask_data"] = val["mask_data"].cpu().numpy()
155
+
150
156
  # Add default class mapping
151
157
  if "class_labels" not in val:
152
158
  classes = np.unique(val["mask_data"]).astype(np.int32).tolist()
@@ -214,7 +220,7 @@ class ImageMask(Media):
214
220
  # Nothing special to add (used to add "digest", but no longer used.)
215
221
  return json_dict
216
222
  else:
217
- raise ValueError("to_json accepts wandb_run.Run or wandb.Artifact")
223
+ raise TypeError("to_json accepts wandb_run.Run or wandb.Artifact")
218
224
 
219
225
  @classmethod
220
226
  def type_name(cls: Type["ImageMask"]) -> str:
@@ -1,4 +1,5 @@
1
1
  import os
2
+ import pathlib
2
3
  from typing import TYPE_CHECKING, Sequence, Type, Union
3
4
 
4
5
  from wandb.sdk.lib import filesystem, runid
@@ -22,7 +23,7 @@ class Html(BatchableMedia):
22
23
 
23
24
  def __init__(
24
25
  self,
25
- data: Union[str, "TextIO"],
26
+ data: Union[str, pathlib.Path, "TextIO"],
26
27
  inject: bool = True,
27
28
  data_is_not_path: bool = False,
28
29
  ) -> None:
@@ -52,14 +53,13 @@ class Html(BatchableMedia):
52
53
  """
53
54
  super().__init__()
54
55
  data_is_path = (
55
- isinstance(data, str)
56
+ isinstance(data, (str, pathlib.Path))
56
57
  and os.path.isfile(data)
57
58
  and os.path.splitext(data)[1] == ".html"
58
59
  ) and not data_is_not_path
59
60
  data_path = ""
60
61
  if data_is_path:
61
- assert isinstance(data, str)
62
- data_path = data
62
+ data_path = str(data)
63
63
  with open(data_path, encoding="utf-8") as file:
64
64
  self.html = file.read()
65
65
  elif isinstance(data, str):