wandb 0.18.0rc1__py3-none-win32.whl → 0.18.2__py3-none-win32.whl

Sign up to get free protection for your applications and access to all features.
Files changed (119)
  1. wandb/__init__.py +4 -4
  2. wandb/__init__.pyi +67 -12
  3. wandb/apis/internal.py +3 -0
  4. wandb/apis/public/api.py +128 -2
  5. wandb/apis/public/artifacts.py +11 -7
  6. wandb/apis/public/jobs.py +8 -0
  7. wandb/apis/public/runs.py +18 -5
  8. wandb/bin/wandb-core +0 -0
  9. wandb/cli/cli.py +0 -5
  10. wandb/data_types.py +9 -2019
  11. wandb/env.py +0 -5
  12. wandb/errors/__init__.py +11 -40
  13. wandb/errors/errors.py +37 -0
  14. wandb/errors/warnings.py +2 -0
  15. wandb/{sklearn → integration/sklearn}/calculate/calibration_curves.py +7 -7
  16. wandb/{sklearn → integration/sklearn}/calculate/class_proportions.py +1 -1
  17. wandb/{sklearn → integration/sklearn}/calculate/confusion_matrix.py +3 -2
  18. wandb/{sklearn → integration/sklearn}/calculate/elbow_curve.py +6 -6
  19. wandb/{sklearn → integration/sklearn}/calculate/learning_curve.py +2 -2
  20. wandb/{sklearn → integration/sklearn}/calculate/outlier_candidates.py +2 -2
  21. wandb/{sklearn → integration/sklearn}/calculate/residuals.py +8 -8
  22. wandb/{sklearn → integration/sklearn}/calculate/silhouette.py +2 -2
  23. wandb/{sklearn → integration/sklearn}/calculate/summary_metrics.py +2 -2
  24. wandb/{sklearn → integration/sklearn}/plot/classifier.py +5 -5
  25. wandb/{sklearn → integration/sklearn}/plot/clusterer.py +10 -6
  26. wandb/{sklearn → integration/sklearn}/plot/regressor.py +5 -5
  27. wandb/{sklearn → integration/sklearn}/plot/shared.py +3 -3
  28. wandb/{sklearn → integration/sklearn}/utils.py +8 -8
  29. wandb/integration/tensorboard/log.py +1 -1
  30. wandb/{wandb_torch.py → integration/torch/wandb_torch.py} +36 -32
  31. wandb/old/core.py +2 -80
  32. wandb/plot/bar.py +7 -4
  33. wandb/plot/confusion_matrix.py +5 -4
  34. wandb/plot/histogram.py +7 -4
  35. wandb/plot/line.py +7 -4
  36. wandb/proto/v3/wandb_base_pb2.py +2 -1
  37. wandb/proto/v3/wandb_internal_pb2.py +2 -1
  38. wandb/proto/v3/wandb_server_pb2.py +2 -1
  39. wandb/proto/v3/wandb_settings_pb2.py +3 -2
  40. wandb/proto/v3/wandb_telemetry_pb2.py +2 -1
  41. wandb/proto/v4/wandb_base_pb2.py +2 -1
  42. wandb/proto/v4/wandb_internal_pb2.py +2 -1
  43. wandb/proto/v4/wandb_server_pb2.py +2 -1
  44. wandb/proto/v4/wandb_settings_pb2.py +3 -2
  45. wandb/proto/v4/wandb_telemetry_pb2.py +2 -1
  46. wandb/proto/v5/wandb_base_pb2.py +3 -2
  47. wandb/proto/v5/wandb_internal_pb2.py +3 -2
  48. wandb/proto/v5/wandb_server_pb2.py +3 -2
  49. wandb/proto/v5/wandb_settings_pb2.py +4 -3
  50. wandb/proto/v5/wandb_telemetry_pb2.py +3 -2
  51. wandb/sdk/artifacts/_validators.py +48 -3
  52. wandb/sdk/artifacts/artifact.py +157 -183
  53. wandb/sdk/artifacts/artifact_file_cache.py +13 -11
  54. wandb/sdk/artifacts/artifact_instance_cache.py +4 -2
  55. wandb/sdk/artifacts/artifact_manifest.py +13 -11
  56. wandb/sdk/artifacts/artifact_manifest_entry.py +24 -22
  57. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +9 -7
  58. wandb/sdk/artifacts/artifact_saver.py +27 -25
  59. wandb/sdk/artifacts/exceptions.py +26 -25
  60. wandb/sdk/artifacts/storage_handler.py +11 -9
  61. wandb/sdk/artifacts/storage_handlers/azure_handler.py +16 -14
  62. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +15 -13
  63. wandb/sdk/artifacts/storage_handlers/http_handler.py +15 -14
  64. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +10 -8
  65. wandb/sdk/artifacts/storage_handlers/multi_handler.py +14 -12
  66. wandb/sdk/artifacts/storage_handlers/s3_handler.py +19 -19
  67. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +10 -8
  68. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +12 -10
  69. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +9 -7
  70. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +31 -29
  71. wandb/sdk/artifacts/storage_policy.py +20 -20
  72. wandb/sdk/backend/backend.py +8 -26
  73. wandb/sdk/data_types/audio.py +165 -0
  74. wandb/sdk/data_types/base_types/wb_value.py +1 -3
  75. wandb/sdk/data_types/bokeh.py +70 -0
  76. wandb/sdk/data_types/graph.py +405 -0
  77. wandb/sdk/data_types/image.py +156 -0
  78. wandb/sdk/data_types/table.py +1204 -0
  79. wandb/sdk/data_types/trace_tree.py +2 -2
  80. wandb/sdk/data_types/utils.py +49 -0
  81. wandb/sdk/data_types/video.py +2 -2
  82. wandb/sdk/interface/interface.py +0 -24
  83. wandb/sdk/interface/interface_shared.py +0 -12
  84. wandb/sdk/internal/handler.py +0 -10
  85. wandb/sdk/internal/internal_api.py +71 -0
  86. wandb/sdk/internal/sender.py +0 -43
  87. wandb/sdk/internal/tb_watcher.py +1 -1
  88. wandb/sdk/lib/_settings_toposort_generated.py +1 -0
  89. wandb/sdk/lib/hashutil.py +34 -12
  90. wandb/sdk/lib/service_connection.py +216 -0
  91. wandb/sdk/lib/service_token.py +94 -0
  92. wandb/sdk/lib/sock_client.py +7 -3
  93. wandb/sdk/service/server.py +2 -5
  94. wandb/sdk/service/service.py +2 -31
  95. wandb/sdk/service/streams.py +0 -7
  96. wandb/sdk/wandb_init.py +42 -25
  97. wandb/sdk/wandb_run.py +18 -159
  98. wandb/sdk/wandb_settings.py +2 -0
  99. wandb/sdk/wandb_setup.py +25 -16
  100. wandb/sdk/wandb_sync.py +9 -3
  101. wandb/sdk/wandb_watch.py +31 -15
  102. wandb/sklearn.py +35 -0
  103. wandb/util.py +14 -3
  104. {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/METADATA +6 -5
  105. {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/RECORD +114 -110
  106. wandb/sdk/internal/update.py +0 -113
  107. wandb/sdk/lib/console.py +0 -39
  108. wandb/sdk/service/service_base.py +0 -50
  109. wandb/sdk/service/service_sock.py +0 -70
  110. wandb/sdk/wandb_manager.py +0 -232
  111. /wandb/{sklearn → integration/sklearn}/__init__.py +0 -0
  112. /wandb/{sklearn → integration/sklearn}/calculate/__init__.py +0 -0
  113. /wandb/{sklearn → integration/sklearn}/calculate/decision_boundaries.py +0 -0
  114. /wandb/{sklearn → integration/sklearn}/calculate/feature_importances.py +0 -0
  115. /wandb/{sklearn → integration/sklearn}/plot/__init__.py +0 -0
  116. /wandb/{sdk/lib → plot}/viz.py +0 -0
  117. {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/WHEEL +0 -0
  118. {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/entry_points.txt +0 -0
  119. {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/licenses/LICENSE +0 -0
@@ -1,10 +1,12 @@
1
1
  """WandB storage policy."""
2
2
 
3
+ from __future__ import annotations
4
+
3
5
  import hashlib
4
6
  import math
5
7
  import os
6
8
  import shutil
7
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
9
+ from typing import TYPE_CHECKING, Any, Sequence
8
10
  from urllib.parse import quote
9
11
 
10
12
  import requests
@@ -64,15 +66,15 @@ class WandbStoragePolicy(StoragePolicy):
64
66
 
65
67
  @classmethod
66
68
  def from_config(
67
- cls, config: Dict, api: Optional[InternalApi] = None
68
- ) -> "WandbStoragePolicy":
69
+ cls, config: dict, api: InternalApi | None = None
70
+ ) -> WandbStoragePolicy:
69
71
  return cls(config=config, api=api)
70
72
 
71
73
  def __init__(
72
74
  self,
73
- config: Optional[Dict] = None,
74
- cache: Optional[ArtifactFileCache] = None,
75
- api: Optional[InternalApi] = None,
75
+ config: dict | None = None,
76
+ cache: ArtifactFileCache | None = None,
77
+ api: InternalApi | None = None,
76
78
  ) -> None:
77
79
  self._cache = cache or get_artifact_file_cache()
78
80
  self._config = config or {}
@@ -109,14 +111,14 @@ class WandbStoragePolicy(StoragePolicy):
109
111
  default_handler=TrackingHandler(),
110
112
  )
111
113
 
112
- def config(self) -> Dict:
114
+ def config(self) -> dict:
113
115
  return self._config
114
116
 
115
117
  def load_file(
116
118
  self,
117
- artifact: "Artifact",
118
- manifest_entry: "ArtifactManifestEntry",
119
- dest_path: Optional[str] = None,
119
+ artifact: Artifact,
120
+ manifest_entry: ArtifactManifestEntry,
121
+ dest_path: str | None = None,
120
122
  ) -> FilePathStr:
121
123
  if dest_path is not None:
122
124
  self._cache._override_cache_path = dest_path
@@ -159,22 +161,22 @@ class WandbStoragePolicy(StoragePolicy):
159
161
 
160
162
  def store_reference(
161
163
  self,
162
- artifact: "Artifact",
163
- path: Union[URIStr, FilePathStr],
164
- name: Optional[str] = None,
164
+ artifact: Artifact,
165
+ path: URIStr | FilePathStr,
166
+ name: str | None = None,
165
167
  checksum: bool = True,
166
- max_objects: Optional[int] = None,
167
- ) -> Sequence["ArtifactManifestEntry"]:
168
+ max_objects: int | None = None,
169
+ ) -> Sequence[ArtifactManifestEntry]:
168
170
  return self._handler.store_path(
169
171
  artifact, path, name=name, checksum=checksum, max_objects=max_objects
170
172
  )
171
173
 
172
174
  def load_reference(
173
175
  self,
174
- manifest_entry: "ArtifactManifestEntry",
176
+ manifest_entry: ArtifactManifestEntry,
175
177
  local: bool = False,
176
- dest_path: Optional[str] = None,
177
- ) -> Union[FilePathStr, URIStr]:
178
+ dest_path: str | None = None,
179
+ ) -> FilePathStr | URIStr:
178
180
  assert manifest_entry.ref is not None
179
181
  used_handler = self._handler._get_handler(manifest_entry.ref)
180
182
  if hasattr(used_handler, "_cache") and (dest_path is not None):
@@ -185,7 +187,7 @@ class WandbStoragePolicy(StoragePolicy):
185
187
  self,
186
188
  api: InternalApi,
187
189
  entity_name: str,
188
- manifest_entry: "ArtifactManifestEntry",
190
+ manifest_entry: ArtifactManifestEntry,
189
191
  ) -> str:
190
192
  storage_layout = self._config.get("storageLayout", StorageLayout.V1)
191
193
  storage_region = self._config.get("storageRegion", "default")
@@ -214,10 +216,10 @@ class WandbStoragePolicy(StoragePolicy):
214
216
  self,
215
217
  file_path: str,
216
218
  chunk_size: int,
217
- hex_digests: Dict[int, str],
218
- multipart_urls: Dict[int, str],
219
- extra_headers: Dict[str, str],
220
- ) -> List[Dict[str, Any]]:
219
+ hex_digests: dict[int, str],
220
+ multipart_urls: dict[int, str],
221
+ extra_headers: dict[str, str],
222
+ ) -> list[dict[str, Any]]:
221
223
  etags = []
222
224
  part_number = 1
223
225
 
@@ -247,8 +249,8 @@ class WandbStoragePolicy(StoragePolicy):
247
249
  self,
248
250
  upload_url: str,
249
251
  file_path: str,
250
- extra_headers: Dict[str, Any],
251
- progress_callback: Optional["progress.ProgressFn"] = None,
252
+ extra_headers: dict[str, Any],
253
+ progress_callback: progress.ProgressFn | None = None,
252
254
  ) -> None:
253
255
  """Upload a file to the artifact store and write to cache."""
254
256
  with open(file_path, "rb") as file:
@@ -272,9 +274,9 @@ class WandbStoragePolicy(StoragePolicy):
272
274
  self,
273
275
  artifact_id: str,
274
276
  artifact_manifest_id: str,
275
- entry: "ArtifactManifestEntry",
276
- preparer: "StepPrepare",
277
- progress_callback: Optional["progress.ProgressFn"] = None,
277
+ entry: ArtifactManifestEntry,
278
+ preparer: StepPrepare,
279
+ progress_callback: progress.ProgressFn | None = None,
278
280
  ) -> bool:
279
281
  """Upload a file to the artifact store.
280
282
 
@@ -352,7 +354,7 @@ class WandbStoragePolicy(StoragePolicy):
352
354
 
353
355
  return False
354
356
 
355
- def _write_cache(self, entry: "ArtifactManifestEntry") -> None:
357
+ def _write_cache(self, entry: ArtifactManifestEntry) -> None:
356
358
  if entry.local_path is None:
357
359
  return
358
360
 
@@ -1,6 +1,8 @@
1
1
  """Storage policy."""
2
2
 
3
- from typing import TYPE_CHECKING, Dict, Optional, Sequence, Type, Union
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING, Sequence
4
6
 
5
7
  from wandb.sdk.internal.internal_api import Api as InternalApi
6
8
  from wandb.sdk.lib.paths import FilePathStr, URIStr
@@ -14,7 +16,7 @@ if TYPE_CHECKING:
14
16
 
15
17
  class StoragePolicy:
16
18
  @classmethod
17
- def lookup_by_name(cls, name: str) -> Type["StoragePolicy"]:
19
+ def lookup_by_name(cls, name: str) -> type[StoragePolicy]:
18
20
  import wandb.sdk.artifacts.storage_policies # noqa: F401
19
21
 
20
22
  for sub in cls.__subclasses__():
@@ -27,19 +29,17 @@ class StoragePolicy:
27
29
  raise NotImplementedError
28
30
 
29
31
  @classmethod
30
- def from_config(
31
- cls, config: Dict, api: Optional[InternalApi] = None
32
- ) -> "StoragePolicy":
32
+ def from_config(cls, config: dict, api: InternalApi | None = None) -> StoragePolicy:
33
33
  raise NotImplementedError
34
34
 
35
- def config(self) -> Dict:
35
+ def config(self) -> dict:
36
36
  raise NotImplementedError
37
37
 
38
38
  def load_file(
39
39
  self,
40
- artifact: "Artifact",
41
- manifest_entry: "ArtifactManifestEntry",
42
- dest_path: Optional[str] = None,
40
+ artifact: Artifact,
41
+ manifest_entry: ArtifactManifestEntry,
42
+ dest_path: str | None = None,
43
43
  ) -> FilePathStr:
44
44
  raise NotImplementedError
45
45
 
@@ -47,26 +47,26 @@ class StoragePolicy:
47
47
  self,
48
48
  artifact_id: str,
49
49
  artifact_manifest_id: str,
50
- entry: "ArtifactManifestEntry",
51
- preparer: "StepPrepare",
52
- progress_callback: Optional["ProgressFn"] = None,
50
+ entry: ArtifactManifestEntry,
51
+ preparer: StepPrepare,
52
+ progress_callback: ProgressFn | None = None,
53
53
  ) -> bool:
54
54
  raise NotImplementedError
55
55
 
56
56
  def store_reference(
57
57
  self,
58
- artifact: "Artifact",
59
- path: Union[URIStr, FilePathStr],
60
- name: Optional[str] = None,
58
+ artifact: Artifact,
59
+ path: URIStr | FilePathStr,
60
+ name: str | None = None,
61
61
  checksum: bool = True,
62
- max_objects: Optional[int] = None,
63
- ) -> Sequence["ArtifactManifestEntry"]:
62
+ max_objects: int | None = None,
63
+ ) -> Sequence[ArtifactManifestEntry]:
64
64
  raise NotImplementedError
65
65
 
66
66
  def load_reference(
67
67
  self,
68
- manifest_entry: "ArtifactManifestEntry",
68
+ manifest_entry: ArtifactManifestEntry,
69
69
  local: bool = False,
70
- dest_path: Optional[str] = None,
71
- ) -> Union[FilePathStr, URIStr]:
70
+ dest_path: str | None = None,
71
+ ) -> FilePathStr | URIStr:
72
72
  raise NotImplementedError
@@ -11,7 +11,7 @@ import os
11
11
  import queue
12
12
  import sys
13
13
  import threading
14
- from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, cast
14
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
15
15
 
16
16
  import wandb
17
17
  from wandb.sdk.interface.interface import InterfaceBase
@@ -19,17 +19,16 @@ from wandb.sdk.interface.interface_queue import InterfaceQueue
19
19
  from wandb.sdk.internal.internal import wandb_internal
20
20
  from wandb.sdk.internal.settings_static import SettingsStatic
21
21
  from wandb.sdk.lib.mailbox import Mailbox
22
- from wandb.sdk.wandb_manager import _Manager
23
22
  from wandb.sdk.wandb_settings import Settings
24
23
 
25
24
  if TYPE_CHECKING:
26
25
  from wandb.proto.wandb_internal_pb2 import Record, Result
26
+ from wandb.sdk.lib import service_connection
27
27
 
28
- from ..service.service_sock import ServiceSockInterface
29
28
  from ..wandb_run import Run
30
29
 
31
- RecordQueue = Union[queue.Queue[Record], multiprocessing.Queue[Record]]
32
- ResultQueue = Union[queue.Queue[Result], multiprocessing.Queue[Result]]
30
+ RecordQueue = Union["queue.Queue[Record]", multiprocessing.Queue[Record]]
31
+ ResultQueue = Union["queue.Queue[Result]", multiprocessing.Queue[Result]]
33
32
 
34
33
  logger = logging.getLogger("wandb")
35
34
 
@@ -65,7 +64,7 @@ class Backend:
65
64
  mailbox: Mailbox,
66
65
  settings: Optional[Settings] = None,
67
66
  log_level: Optional[int] = None,
68
- manager: Optional[_Manager] = None,
67
+ service: "Optional[service_connection.ServiceConnection]" = None,
69
68
  ) -> None:
70
69
  self._done = False
71
70
  self.record_q = None
@@ -75,7 +74,7 @@ class Backend:
75
74
  self._internal_pid = None
76
75
  self._settings = settings
77
76
  self._log_level = log_level
78
- self._manager = manager
77
+ self._service = service
79
78
  self._mailbox = mailbox
80
79
 
81
80
  self._multiprocessing = multiprocessing # type: ignore
@@ -139,27 +138,10 @@ class Backend:
139
138
  if self._save_mod_path:
140
139
  main_module.__file__ = self._save_mod_path
141
140
 
142
- def _ensure_launched_manager(self) -> None:
143
- assert self._manager
144
- svc = self._manager._get_service()
145
- assert svc
146
- svc_iface = svc.service_interface
147
-
148
- svc_transport = svc_iface.get_transport()
149
- if svc_transport == "tcp":
150
- from ..interface.interface_sock import InterfaceSock
151
-
152
- svc_iface_sock = cast("ServiceSockInterface", svc_iface)
153
- sock_client = svc_iface_sock._get_sock_client()
154
- sock_interface = InterfaceSock(sock_client, mailbox=self._mailbox)
155
- self.interface = sock_interface
156
- else:
157
- raise AssertionError(f"Unsupported service transport: {svc_transport}")
158
-
159
141
  def ensure_launched(self) -> None:
160
142
  """Launch backend worker if not running."""
161
- if self._manager:
162
- self._ensure_launched_manager()
143
+ if self._service:
144
+ self.interface = self._service.make_interface(self._mailbox)
163
145
  return
164
146
 
165
147
  assert self._settings
@@ -0,0 +1,165 @@
1
+ import hashlib
2
+ import os
3
+ from typing import Optional
4
+
5
+ from wandb import util
6
+ from wandb.sdk.lib import filesystem, runid
7
+
8
+ from . import _dtypes
9
+ from ._private import MEDIA_TMP
10
+ from .base_types.media import BatchableMedia
11
+
12
+
13
+ class Audio(BatchableMedia):
14
+ """Wandb class for audio clips.
15
+
16
+ Arguments:
17
+ data_or_path: (string or numpy array) A path to an audio file
18
+ or a numpy array of audio data.
19
+ sample_rate: (int) Sample rate, required when passing in raw
20
+ numpy array of audio data.
21
+ caption: (string) Caption to display with audio.
22
+ """
23
+
24
+ _log_type = "audio-file"
25
+
26
+ def __init__(self, data_or_path, sample_rate=None, caption=None):
27
+ """Accept a path to an audio file or a numpy array of audio data."""
28
+ super().__init__()
29
+ self._duration = None
30
+ self._sample_rate = sample_rate
31
+ self._caption = caption
32
+
33
+ if isinstance(data_or_path, str):
34
+ if self.path_is_reference(data_or_path):
35
+ self._path = data_or_path
36
+ self._sha256 = hashlib.sha256(data_or_path.encode("utf-8")).hexdigest()
37
+ self._is_tmp = False
38
+ else:
39
+ self._set_file(data_or_path, is_tmp=False)
40
+ else:
41
+ if sample_rate is None:
42
+ raise ValueError(
43
+ 'Argument "sample_rate" is required when instantiating wandb.Audio with raw data.'
44
+ )
45
+
46
+ soundfile = util.get_module(
47
+ "soundfile",
48
+ required='Raw audio requires the soundfile package. To get it, run "pip install soundfile"',
49
+ )
50
+
51
+ tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".wav")
52
+ soundfile.write(tmp_path, data_or_path, sample_rate)
53
+ self._duration = len(data_or_path) / float(sample_rate)
54
+
55
+ self._set_file(tmp_path, is_tmp=True)
56
+
57
+ @classmethod
58
+ def get_media_subdir(cls):
59
+ return os.path.join("media", "audio")
60
+
61
+ @classmethod
62
+ def from_json(cls, json_obj, source_artifact):
63
+ return cls(
64
+ source_artifact.get_entry(json_obj["path"]).download(),
65
+ caption=json_obj["caption"],
66
+ )
67
+
68
+ def bind_to_run(
69
+ self, run, key, step, id_=None, ignore_copy_err: Optional[bool] = None
70
+ ):
71
+ if self.path_is_reference(self._path):
72
+ raise ValueError(
73
+ "Audio media created by a reference to external storage cannot currently be added to a run"
74
+ )
75
+
76
+ return super().bind_to_run(run, key, step, id_, ignore_copy_err)
77
+
78
+ def to_json(self, run):
79
+ json_dict = super().to_json(run)
80
+ json_dict.update(
81
+ {
82
+ "_type": self._log_type,
83
+ "caption": self._caption,
84
+ }
85
+ )
86
+ return json_dict
87
+
88
+ @classmethod
89
+ def seq_to_json(cls, seq, run, key, step):
90
+ audio_list = list(seq)
91
+
92
+ util.get_module(
93
+ "soundfile",
94
+ required="wandb.Audio requires the soundfile package. To get it, run: pip install soundfile",
95
+ )
96
+ base_path = os.path.join(run.dir, "media", "audio")
97
+ filesystem.mkdir_exists_ok(base_path)
98
+ meta = {
99
+ "_type": "audio",
100
+ "count": len(audio_list),
101
+ "audio": [a.to_json(run) for a in audio_list],
102
+ }
103
+ sample_rates = cls.sample_rates(audio_list)
104
+ if sample_rates:
105
+ meta["sampleRates"] = sample_rates
106
+ durations = cls.durations(audio_list)
107
+ if durations:
108
+ meta["durations"] = durations
109
+ captions = cls.captions(audio_list)
110
+ if captions:
111
+ meta["captions"] = captions
112
+
113
+ return meta
114
+
115
+ @classmethod
116
+ def durations(cls, audio_list):
117
+ return [a._duration for a in audio_list]
118
+
119
+ @classmethod
120
+ def sample_rates(cls, audio_list):
121
+ return [a._sample_rate for a in audio_list]
122
+
123
+ @classmethod
124
+ def captions(cls, audio_list):
125
+ captions = [a._caption for a in audio_list]
126
+ if all(c is None for c in captions):
127
+ return False
128
+ else:
129
+ return ["" if c is None else c for c in captions]
130
+
131
+ def resolve_ref(self):
132
+ if self.path_is_reference(self._path):
133
+ # this object was already created using a ref:
134
+ return self._path
135
+ source_artifact = self._artifact_source.artifact
136
+
137
+ resolved_name = source_artifact._local_path_to_name(self._path)
138
+ if resolved_name is not None:
139
+ target_entry = source_artifact.manifest.get_entry_by_path(resolved_name)
140
+ if target_entry is not None:
141
+ return target_entry.ref
142
+
143
+ return None
144
+
145
+ def __eq__(self, other):
146
+ if self.path_is_reference(self._path) or self.path_is_reference(other._path):
147
+ # one or more of these objects is an unresolved reference -- we'll compare
148
+ # their reference paths instead of their SHAs:
149
+ return (
150
+ self.resolve_ref() == other.resolve_ref()
151
+ and self._caption == other._caption
152
+ )
153
+
154
+ return super().__eq__(other) and self._caption == other._caption
155
+
156
+ def __ne__(self, other):
157
+ return not self.__eq__(other)
158
+
159
+
160
+ class _AudioFileType(_dtypes.Type):
161
+ name = "audio-file"
162
+ types = [Audio]
163
+
164
+
165
+ _dtypes.TypeRegistry.add(_AudioFileType)
@@ -88,9 +88,7 @@ class WBValue:
88
88
  raise NotImplementedError
89
89
 
90
90
  @classmethod
91
- def from_json(
92
- cls: Type["WBValue"], json_obj: dict, source_artifact: "Artifact"
93
- ) -> "WBValue":
91
+ def from_json(cls, json_obj: dict, source_artifact: "Artifact") -> "WBValue":
94
92
  """Deserialize a `json_obj` into it's class representation.
95
93
 
96
94
  If additional resources were stored in the `run_or_artifact` artifact during the
@@ -0,0 +1,70 @@
1
+ import codecs
2
+ import json
3
+ import os
4
+
5
+ from wandb import util
6
+ from wandb.sdk.lib import runid
7
+
8
+ from . import _dtypes
9
+ from ._private import MEDIA_TMP
10
+ from .base_types.media import Media
11
+
12
+
13
+ class Bokeh(Media):
14
+ """Wandb class for Bokeh plots.
15
+
16
+ Arguments:
17
+ val: Bokeh plot
18
+ """
19
+
20
+ _log_type = "bokeh-file"
21
+
22
+ def __init__(self, data_or_path):
23
+ super().__init__()
24
+ bokeh = util.get_module("bokeh", required=True)
25
+ if isinstance(data_or_path, str) and os.path.exists(data_or_path):
26
+ with open(data_or_path) as file:
27
+ b_json = json.load(file)
28
+ self.b_obj = bokeh.document.Document.from_json(b_json)
29
+ self._set_file(data_or_path, is_tmp=False, extension=".bokeh.json")
30
+ elif isinstance(data_or_path, bokeh.model.Model):
31
+ _data = bokeh.document.Document()
32
+ _data.add_root(data_or_path)
33
+ # serialize/deserialize pairing followed by sorting attributes ensures
34
+ # that the file's sha's are equivalent in subsequent calls
35
+ self.b_obj = bokeh.document.Document.from_json(_data.to_json())
36
+ b_json = self.b_obj.to_json()
37
+ if "references" in b_json["roots"]:
38
+ b_json["roots"]["references"].sort(key=lambda x: x["id"])
39
+
40
+ tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".bokeh.json")
41
+ with codecs.open(tmp_path, "w", encoding="utf-8") as fp:
42
+ util.json_dump_safer(b_json, fp)
43
+ self._set_file(tmp_path, is_tmp=True, extension=".bokeh.json")
44
+ elif not isinstance(data_or_path, bokeh.document.Document):
45
+ raise TypeError(
46
+ "Bokeh constructor accepts Bokeh document/model or path to Bokeh json file"
47
+ )
48
+
49
+ def get_media_subdir(self):
50
+ return os.path.join("media", "bokeh")
51
+
52
+ def to_json(self, run):
53
+ # TODO: (tss) this is getting redundant for all the media objects. We can probably
54
+ # pull this into Media#to_json and remove this type override for all the media types.
55
+ # There are only a few cases where the type is different between artifacts and runs.
56
+ json_dict = super().to_json(run)
57
+ json_dict["_type"] = self._log_type
58
+ return json_dict
59
+
60
+ @classmethod
61
+ def from_json(cls, json_obj, source_artifact):
62
+ return cls(source_artifact.get_entry(json_obj["path"]).download())
63
+
64
+
65
+ class _BokehFileType(_dtypes.Type):
66
+ name = "bokeh-file"
67
+ types = [Bokeh]
68
+
69
+
70
+ _dtypes.TypeRegistry.add(_BokehFileType)