wandb 0.18.0rc1__py3-none-win_amd64.whl → 0.18.2__py3-none-win_amd64.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +4 -4
- wandb/__init__.pyi +67 -12
- wandb/apis/internal.py +3 -0
- wandb/apis/public/api.py +128 -2
- wandb/apis/public/artifacts.py +11 -7
- wandb/apis/public/jobs.py +8 -0
- wandb/apis/public/runs.py +18 -5
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +0 -5
- wandb/data_types.py +9 -2019
- wandb/env.py +0 -5
- wandb/errors/__init__.py +11 -40
- wandb/errors/errors.py +37 -0
- wandb/errors/warnings.py +2 -0
- wandb/{sklearn → integration/sklearn}/calculate/calibration_curves.py +7 -7
- wandb/{sklearn → integration/sklearn}/calculate/class_proportions.py +1 -1
- wandb/{sklearn → integration/sklearn}/calculate/confusion_matrix.py +3 -2
- wandb/{sklearn → integration/sklearn}/calculate/elbow_curve.py +6 -6
- wandb/{sklearn → integration/sklearn}/calculate/learning_curve.py +2 -2
- wandb/{sklearn → integration/sklearn}/calculate/outlier_candidates.py +2 -2
- wandb/{sklearn → integration/sklearn}/calculate/residuals.py +8 -8
- wandb/{sklearn → integration/sklearn}/calculate/silhouette.py +2 -2
- wandb/{sklearn → integration/sklearn}/calculate/summary_metrics.py +2 -2
- wandb/{sklearn → integration/sklearn}/plot/classifier.py +5 -5
- wandb/{sklearn → integration/sklearn}/plot/clusterer.py +10 -6
- wandb/{sklearn → integration/sklearn}/plot/regressor.py +5 -5
- wandb/{sklearn → integration/sklearn}/plot/shared.py +3 -3
- wandb/{sklearn → integration/sklearn}/utils.py +8 -8
- wandb/integration/tensorboard/log.py +1 -1
- wandb/{wandb_torch.py → integration/torch/wandb_torch.py} +36 -32
- wandb/old/core.py +2 -80
- wandb/plot/bar.py +7 -4
- wandb/plot/confusion_matrix.py +5 -4
- wandb/plot/histogram.py +7 -4
- wandb/plot/line.py +7 -4
- wandb/proto/v3/wandb_base_pb2.py +2 -1
- wandb/proto/v3/wandb_internal_pb2.py +2 -1
- wandb/proto/v3/wandb_server_pb2.py +2 -1
- wandb/proto/v3/wandb_settings_pb2.py +3 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +2 -1
- wandb/proto/v4/wandb_base_pb2.py +2 -1
- wandb/proto/v4/wandb_internal_pb2.py +2 -1
- wandb/proto/v4/wandb_server_pb2.py +2 -1
- wandb/proto/v4/wandb_settings_pb2.py +3 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +2 -1
- wandb/proto/v5/wandb_base_pb2.py +3 -2
- wandb/proto/v5/wandb_internal_pb2.py +3 -2
- wandb/proto/v5/wandb_server_pb2.py +3 -2
- wandb/proto/v5/wandb_settings_pb2.py +4 -3
- wandb/proto/v5/wandb_telemetry_pb2.py +3 -2
- wandb/sdk/artifacts/_validators.py +48 -3
- wandb/sdk/artifacts/artifact.py +157 -183
- wandb/sdk/artifacts/artifact_file_cache.py +13 -11
- wandb/sdk/artifacts/artifact_instance_cache.py +4 -2
- wandb/sdk/artifacts/artifact_manifest.py +13 -11
- wandb/sdk/artifacts/artifact_manifest_entry.py +24 -22
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +9 -7
- wandb/sdk/artifacts/artifact_saver.py +27 -25
- wandb/sdk/artifacts/exceptions.py +26 -25
- wandb/sdk/artifacts/storage_handler.py +11 -9
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +16 -14
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +15 -13
- wandb/sdk/artifacts/storage_handlers/http_handler.py +15 -14
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +10 -8
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +14 -12
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +19 -19
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +10 -8
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +12 -10
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +9 -7
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +31 -29
- wandb/sdk/artifacts/storage_policy.py +20 -20
- wandb/sdk/backend/backend.py +8 -26
- wandb/sdk/data_types/audio.py +165 -0
- wandb/sdk/data_types/base_types/wb_value.py +1 -3
- wandb/sdk/data_types/bokeh.py +70 -0
- wandb/sdk/data_types/graph.py +405 -0
- wandb/sdk/data_types/image.py +156 -0
- wandb/sdk/data_types/table.py +1204 -0
- wandb/sdk/data_types/trace_tree.py +2 -2
- wandb/sdk/data_types/utils.py +49 -0
- wandb/sdk/data_types/video.py +2 -2
- wandb/sdk/interface/interface.py +0 -24
- wandb/sdk/interface/interface_shared.py +0 -12
- wandb/sdk/internal/handler.py +0 -10
- wandb/sdk/internal/internal_api.py +71 -0
- wandb/sdk/internal/sender.py +0 -43
- wandb/sdk/internal/tb_watcher.py +1 -1
- wandb/sdk/lib/_settings_toposort_generated.py +1 -0
- wandb/sdk/lib/hashutil.py +34 -12
- wandb/sdk/lib/service_connection.py +216 -0
- wandb/sdk/lib/service_token.py +94 -0
- wandb/sdk/lib/sock_client.py +7 -3
- wandb/sdk/service/server.py +2 -5
- wandb/sdk/service/service.py +2 -31
- wandb/sdk/service/streams.py +0 -7
- wandb/sdk/wandb_init.py +42 -25
- wandb/sdk/wandb_run.py +18 -159
- wandb/sdk/wandb_settings.py +2 -0
- wandb/sdk/wandb_setup.py +25 -16
- wandb/sdk/wandb_sync.py +9 -3
- wandb/sdk/wandb_watch.py +31 -15
- wandb/sklearn.py +35 -0
- wandb/util.py +14 -3
- {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/METADATA +6 -5
- {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/RECORD +114 -110
- wandb/sdk/internal/update.py +0 -113
- wandb/sdk/lib/console.py +0 -39
- wandb/sdk/service/service_base.py +0 -50
- wandb/sdk/service/service_sock.py +0 -70
- wandb/sdk/wandb_manager.py +0 -232
- /wandb/{sklearn → integration/sklearn}/__init__.py +0 -0
- /wandb/{sklearn → integration/sklearn}/calculate/__init__.py +0 -0
- /wandb/{sklearn → integration/sklearn}/calculate/decision_boundaries.py +0 -0
- /wandb/{sklearn → integration/sklearn}/calculate/feature_importances.py +0 -0
- /wandb/{sklearn → integration/sklearn}/plot/__init__.py +0 -0
- /wandb/{sdk/lib → plot}/viz.py +0 -0
- {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/WHEEL +0 -0
- {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/entry_points.txt +0 -0
- {wandb-0.18.0rc1.dist-info → wandb-0.18.2.dist-info}/licenses/LICENSE +0 -0
@@ -1,10 +1,12 @@
|
|
1
1
|
"""WandB storage policy."""
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
3
5
|
import hashlib
|
4
6
|
import math
|
5
7
|
import os
|
6
8
|
import shutil
|
7
|
-
from typing import TYPE_CHECKING, Any,
|
9
|
+
from typing import TYPE_CHECKING, Any, Sequence
|
8
10
|
from urllib.parse import quote
|
9
11
|
|
10
12
|
import requests
|
@@ -64,15 +66,15 @@ class WandbStoragePolicy(StoragePolicy):
|
|
64
66
|
|
65
67
|
@classmethod
|
66
68
|
def from_config(
|
67
|
-
cls, config:
|
68
|
-
) ->
|
69
|
+
cls, config: dict, api: InternalApi | None = None
|
70
|
+
) -> WandbStoragePolicy:
|
69
71
|
return cls(config=config, api=api)
|
70
72
|
|
71
73
|
def __init__(
|
72
74
|
self,
|
73
|
-
config:
|
74
|
-
cache:
|
75
|
-
api:
|
75
|
+
config: dict | None = None,
|
76
|
+
cache: ArtifactFileCache | None = None,
|
77
|
+
api: InternalApi | None = None,
|
76
78
|
) -> None:
|
77
79
|
self._cache = cache or get_artifact_file_cache()
|
78
80
|
self._config = config or {}
|
@@ -109,14 +111,14 @@ class WandbStoragePolicy(StoragePolicy):
|
|
109
111
|
default_handler=TrackingHandler(),
|
110
112
|
)
|
111
113
|
|
112
|
-
def config(self) ->
|
114
|
+
def config(self) -> dict:
|
113
115
|
return self._config
|
114
116
|
|
115
117
|
def load_file(
|
116
118
|
self,
|
117
|
-
artifact:
|
118
|
-
manifest_entry:
|
119
|
-
dest_path:
|
119
|
+
artifact: Artifact,
|
120
|
+
manifest_entry: ArtifactManifestEntry,
|
121
|
+
dest_path: str | None = None,
|
120
122
|
) -> FilePathStr:
|
121
123
|
if dest_path is not None:
|
122
124
|
self._cache._override_cache_path = dest_path
|
@@ -159,22 +161,22 @@ class WandbStoragePolicy(StoragePolicy):
|
|
159
161
|
|
160
162
|
def store_reference(
|
161
163
|
self,
|
162
|
-
artifact:
|
163
|
-
path:
|
164
|
-
name:
|
164
|
+
artifact: Artifact,
|
165
|
+
path: URIStr | FilePathStr,
|
166
|
+
name: str | None = None,
|
165
167
|
checksum: bool = True,
|
166
|
-
max_objects:
|
167
|
-
) -> Sequence[
|
168
|
+
max_objects: int | None = None,
|
169
|
+
) -> Sequence[ArtifactManifestEntry]:
|
168
170
|
return self._handler.store_path(
|
169
171
|
artifact, path, name=name, checksum=checksum, max_objects=max_objects
|
170
172
|
)
|
171
173
|
|
172
174
|
def load_reference(
|
173
175
|
self,
|
174
|
-
manifest_entry:
|
176
|
+
manifest_entry: ArtifactManifestEntry,
|
175
177
|
local: bool = False,
|
176
|
-
dest_path:
|
177
|
-
) ->
|
178
|
+
dest_path: str | None = None,
|
179
|
+
) -> FilePathStr | URIStr:
|
178
180
|
assert manifest_entry.ref is not None
|
179
181
|
used_handler = self._handler._get_handler(manifest_entry.ref)
|
180
182
|
if hasattr(used_handler, "_cache") and (dest_path is not None):
|
@@ -185,7 +187,7 @@ class WandbStoragePolicy(StoragePolicy):
|
|
185
187
|
self,
|
186
188
|
api: InternalApi,
|
187
189
|
entity_name: str,
|
188
|
-
manifest_entry:
|
190
|
+
manifest_entry: ArtifactManifestEntry,
|
189
191
|
) -> str:
|
190
192
|
storage_layout = self._config.get("storageLayout", StorageLayout.V1)
|
191
193
|
storage_region = self._config.get("storageRegion", "default")
|
@@ -214,10 +216,10 @@ class WandbStoragePolicy(StoragePolicy):
|
|
214
216
|
self,
|
215
217
|
file_path: str,
|
216
218
|
chunk_size: int,
|
217
|
-
hex_digests:
|
218
|
-
multipart_urls:
|
219
|
-
extra_headers:
|
220
|
-
) ->
|
219
|
+
hex_digests: dict[int, str],
|
220
|
+
multipart_urls: dict[int, str],
|
221
|
+
extra_headers: dict[str, str],
|
222
|
+
) -> list[dict[str, Any]]:
|
221
223
|
etags = []
|
222
224
|
part_number = 1
|
223
225
|
|
@@ -247,8 +249,8 @@ class WandbStoragePolicy(StoragePolicy):
|
|
247
249
|
self,
|
248
250
|
upload_url: str,
|
249
251
|
file_path: str,
|
250
|
-
extra_headers:
|
251
|
-
progress_callback:
|
252
|
+
extra_headers: dict[str, Any],
|
253
|
+
progress_callback: progress.ProgressFn | None = None,
|
252
254
|
) -> None:
|
253
255
|
"""Upload a file to the artifact store and write to cache."""
|
254
256
|
with open(file_path, "rb") as file:
|
@@ -272,9 +274,9 @@ class WandbStoragePolicy(StoragePolicy):
|
|
272
274
|
self,
|
273
275
|
artifact_id: str,
|
274
276
|
artifact_manifest_id: str,
|
275
|
-
entry:
|
276
|
-
preparer:
|
277
|
-
progress_callback:
|
277
|
+
entry: ArtifactManifestEntry,
|
278
|
+
preparer: StepPrepare,
|
279
|
+
progress_callback: progress.ProgressFn | None = None,
|
278
280
|
) -> bool:
|
279
281
|
"""Upload a file to the artifact store.
|
280
282
|
|
@@ -352,7 +354,7 @@ class WandbStoragePolicy(StoragePolicy):
|
|
352
354
|
|
353
355
|
return False
|
354
356
|
|
355
|
-
def _write_cache(self, entry:
|
357
|
+
def _write_cache(self, entry: ArtifactManifestEntry) -> None:
|
356
358
|
if entry.local_path is None:
|
357
359
|
return
|
358
360
|
|
@@ -1,6 +1,8 @@
|
|
1
1
|
"""Storage policy."""
|
2
2
|
|
3
|
-
from
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
from typing import TYPE_CHECKING, Sequence
|
4
6
|
|
5
7
|
from wandb.sdk.internal.internal_api import Api as InternalApi
|
6
8
|
from wandb.sdk.lib.paths import FilePathStr, URIStr
|
@@ -14,7 +16,7 @@ if TYPE_CHECKING:
|
|
14
16
|
|
15
17
|
class StoragePolicy:
|
16
18
|
@classmethod
|
17
|
-
def lookup_by_name(cls, name: str) ->
|
19
|
+
def lookup_by_name(cls, name: str) -> type[StoragePolicy]:
|
18
20
|
import wandb.sdk.artifacts.storage_policies # noqa: F401
|
19
21
|
|
20
22
|
for sub in cls.__subclasses__():
|
@@ -27,19 +29,17 @@ class StoragePolicy:
|
|
27
29
|
raise NotImplementedError
|
28
30
|
|
29
31
|
@classmethod
|
30
|
-
def from_config(
|
31
|
-
cls, config: Dict, api: Optional[InternalApi] = None
|
32
|
-
) -> "StoragePolicy":
|
32
|
+
def from_config(cls, config: dict, api: InternalApi | None = None) -> StoragePolicy:
|
33
33
|
raise NotImplementedError
|
34
34
|
|
35
|
-
def config(self) ->
|
35
|
+
def config(self) -> dict:
|
36
36
|
raise NotImplementedError
|
37
37
|
|
38
38
|
def load_file(
|
39
39
|
self,
|
40
|
-
artifact:
|
41
|
-
manifest_entry:
|
42
|
-
dest_path:
|
40
|
+
artifact: Artifact,
|
41
|
+
manifest_entry: ArtifactManifestEntry,
|
42
|
+
dest_path: str | None = None,
|
43
43
|
) -> FilePathStr:
|
44
44
|
raise NotImplementedError
|
45
45
|
|
@@ -47,26 +47,26 @@ class StoragePolicy:
|
|
47
47
|
self,
|
48
48
|
artifact_id: str,
|
49
49
|
artifact_manifest_id: str,
|
50
|
-
entry:
|
51
|
-
preparer:
|
52
|
-
progress_callback:
|
50
|
+
entry: ArtifactManifestEntry,
|
51
|
+
preparer: StepPrepare,
|
52
|
+
progress_callback: ProgressFn | None = None,
|
53
53
|
) -> bool:
|
54
54
|
raise NotImplementedError
|
55
55
|
|
56
56
|
def store_reference(
|
57
57
|
self,
|
58
|
-
artifact:
|
59
|
-
path:
|
60
|
-
name:
|
58
|
+
artifact: Artifact,
|
59
|
+
path: URIStr | FilePathStr,
|
60
|
+
name: str | None = None,
|
61
61
|
checksum: bool = True,
|
62
|
-
max_objects:
|
63
|
-
) -> Sequence[
|
62
|
+
max_objects: int | None = None,
|
63
|
+
) -> Sequence[ArtifactManifestEntry]:
|
64
64
|
raise NotImplementedError
|
65
65
|
|
66
66
|
def load_reference(
|
67
67
|
self,
|
68
|
-
manifest_entry:
|
68
|
+
manifest_entry: ArtifactManifestEntry,
|
69
69
|
local: bool = False,
|
70
|
-
dest_path:
|
71
|
-
) ->
|
70
|
+
dest_path: str | None = None,
|
71
|
+
) -> FilePathStr | URIStr:
|
72
72
|
raise NotImplementedError
|
wandb/sdk/backend/backend.py
CHANGED
@@ -11,7 +11,7 @@ import os
|
|
11
11
|
import queue
|
12
12
|
import sys
|
13
13
|
import threading
|
14
|
-
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
|
14
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
|
15
15
|
|
16
16
|
import wandb
|
17
17
|
from wandb.sdk.interface.interface import InterfaceBase
|
@@ -19,17 +19,16 @@ from wandb.sdk.interface.interface_queue import InterfaceQueue
|
|
19
19
|
from wandb.sdk.internal.internal import wandb_internal
|
20
20
|
from wandb.sdk.internal.settings_static import SettingsStatic
|
21
21
|
from wandb.sdk.lib.mailbox import Mailbox
|
22
|
-
from wandb.sdk.wandb_manager import _Manager
|
23
22
|
from wandb.sdk.wandb_settings import Settings
|
24
23
|
|
25
24
|
if TYPE_CHECKING:
|
26
25
|
from wandb.proto.wandb_internal_pb2 import Record, Result
|
26
|
+
from wandb.sdk.lib import service_connection
|
27
27
|
|
28
|
-
from ..service.service_sock import ServiceSockInterface
|
29
28
|
from ..wandb_run import Run
|
30
29
|
|
31
|
-
RecordQueue = Union[queue.Queue[Record], multiprocessing.Queue[Record]]
|
32
|
-
ResultQueue = Union[queue.Queue[Result], multiprocessing.Queue[Result]]
|
30
|
+
RecordQueue = Union["queue.Queue[Record]", multiprocessing.Queue[Record]]
|
31
|
+
ResultQueue = Union["queue.Queue[Result]", multiprocessing.Queue[Result]]
|
33
32
|
|
34
33
|
logger = logging.getLogger("wandb")
|
35
34
|
|
@@ -65,7 +64,7 @@ class Backend:
|
|
65
64
|
mailbox: Mailbox,
|
66
65
|
settings: Optional[Settings] = None,
|
67
66
|
log_level: Optional[int] = None,
|
68
|
-
|
67
|
+
service: "Optional[service_connection.ServiceConnection]" = None,
|
69
68
|
) -> None:
|
70
69
|
self._done = False
|
71
70
|
self.record_q = None
|
@@ -75,7 +74,7 @@ class Backend:
|
|
75
74
|
self._internal_pid = None
|
76
75
|
self._settings = settings
|
77
76
|
self._log_level = log_level
|
78
|
-
self.
|
77
|
+
self._service = service
|
79
78
|
self._mailbox = mailbox
|
80
79
|
|
81
80
|
self._multiprocessing = multiprocessing # type: ignore
|
@@ -139,27 +138,10 @@ class Backend:
|
|
139
138
|
if self._save_mod_path:
|
140
139
|
main_module.__file__ = self._save_mod_path
|
141
140
|
|
142
|
-
def _ensure_launched_manager(self) -> None:
|
143
|
-
assert self._manager
|
144
|
-
svc = self._manager._get_service()
|
145
|
-
assert svc
|
146
|
-
svc_iface = svc.service_interface
|
147
|
-
|
148
|
-
svc_transport = svc_iface.get_transport()
|
149
|
-
if svc_transport == "tcp":
|
150
|
-
from ..interface.interface_sock import InterfaceSock
|
151
|
-
|
152
|
-
svc_iface_sock = cast("ServiceSockInterface", svc_iface)
|
153
|
-
sock_client = svc_iface_sock._get_sock_client()
|
154
|
-
sock_interface = InterfaceSock(sock_client, mailbox=self._mailbox)
|
155
|
-
self.interface = sock_interface
|
156
|
-
else:
|
157
|
-
raise AssertionError(f"Unsupported service transport: {svc_transport}")
|
158
|
-
|
159
141
|
def ensure_launched(self) -> None:
|
160
142
|
"""Launch backend worker if not running."""
|
161
|
-
if self.
|
162
|
-
self.
|
143
|
+
if self._service:
|
144
|
+
self.interface = self._service.make_interface(self._mailbox)
|
163
145
|
return
|
164
146
|
|
165
147
|
assert self._settings
|
@@ -0,0 +1,165 @@
|
|
1
|
+
import hashlib
|
2
|
+
import os
|
3
|
+
from typing import Optional
|
4
|
+
|
5
|
+
from wandb import util
|
6
|
+
from wandb.sdk.lib import filesystem, runid
|
7
|
+
|
8
|
+
from . import _dtypes
|
9
|
+
from ._private import MEDIA_TMP
|
10
|
+
from .base_types.media import BatchableMedia
|
11
|
+
|
12
|
+
|
13
|
+
class Audio(BatchableMedia):
    """Wandb class for audio clips.

    Arguments:
        data_or_path: (string or numpy array) A path to an audio file
            or a numpy array of audio data.
        sample_rate: (int) Sample rate, required when passing in raw
            numpy array of audio data.
        caption: (string) Caption to display with audio.
    """

    _log_type = "audio-file"

    def __init__(self, data_or_path, sample_rate=None, caption=None):
        """Accept a path to an audio file or a numpy array of audio data.

        Raises:
            ValueError: if raw (non-string) data is given without `sample_rate`.
        """
        super().__init__()
        self._duration = None
        self._sample_rate = sample_rate
        self._caption = caption

        if isinstance(data_or_path, str):
            if self.path_is_reference(data_or_path):
                # Reference paths are stored as-is; nothing is copied, so the
                # digest is computed over the path string, not file contents.
                self._path = data_or_path
                self._sha256 = hashlib.sha256(data_or_path.encode("utf-8")).hexdigest()
                self._is_tmp = False
            else:
                self._set_file(data_or_path, is_tmp=False)
        else:
            if sample_rate is None:
                raise ValueError(
                    'Argument "sample_rate" is required when instantiating wandb.Audio with raw data.'
                )

            soundfile = util.get_module(
                "soundfile",
                required='Raw audio requires the soundfile package. To get it, run "pip install soundfile"',
            )

            # Raw samples are encoded to a temporary .wav under MEDIA_TMP.
            tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".wav")
            soundfile.write(tmp_path, data_or_path, sample_rate)
            # NOTE(review): assumes the first axis of the data is the frame
            # axis, so len(data) / sample_rate is the clip length in seconds.
            self._duration = len(data_or_path) / float(sample_rate)

            self._set_file(tmp_path, is_tmp=True)

    @classmethod
    def get_media_subdir(cls):
        """Return the relative directory that audio media is stored under."""
        return os.path.join("media", "audio")

    @classmethod
    def from_json(cls, json_obj, source_artifact):
        """Rebuild an `Audio` from its JSON form by downloading the file entry.

        Only the caption is restored; `to_json` in this class does not add a
        sample rate or duration, so those are left unset.
        """
        return cls(
            source_artifact.get_entry(json_obj["path"]).download(),
            caption=json_obj["caption"],
        )

    def bind_to_run(
        self, run, key, step, id_=None, ignore_copy_err: Optional[bool] = None
    ):
        """Bind this clip to a run; reference-backed clips are rejected.

        Raises:
            ValueError: if this audio was created from a reference path.
        """
        if self.path_is_reference(self._path):
            raise ValueError(
                "Audio media created by a reference to external storage cannot currently be added to a run"
            )

        return super().bind_to_run(run, key, step, id_, ignore_copy_err)

    def to_json(self, run):
        """Extend the base JSON with this media's `_type` and `caption`."""
        json_dict = super().to_json(run)
        json_dict.update(
            {
                "_type": self._log_type,
                "caption": self._caption,
            }
        )
        return json_dict

    @classmethod
    def seq_to_json(cls, seq, run, key, step):
        """Serialize a sequence of `Audio` objects into one metadata dict.

        Optional keys (`sampleRates`, `durations`, `captions`) are included
        only when the corresponding helper returns a truthy value.
        """
        audio_list = list(seq)

        util.get_module(
            "soundfile",
            required="wandb.Audio requires the soundfile package. To get it, run: pip install soundfile",
        )
        base_path = os.path.join(run.dir, "media", "audio")
        filesystem.mkdir_exists_ok(base_path)
        meta = {
            "_type": "audio",
            "count": len(audio_list),
            "audio": [a.to_json(run) for a in audio_list],
        }
        sample_rates = cls.sample_rates(audio_list)
        if sample_rates:
            meta["sampleRates"] = sample_rates
        durations = cls.durations(audio_list)
        if durations:
            meta["durations"] = durations
        captions = cls.captions(audio_list)
        if captions:
            meta["captions"] = captions

        return meta

    @classmethod
    def durations(cls, audio_list):
        """Return each clip's duration (None for clips not built from raw data)."""
        return [a._duration for a in audio_list]

    @classmethod
    def sample_rates(cls, audio_list):
        """Return each clip's sample rate as passed to the constructor."""
        return [a._sample_rate for a in audio_list]

    @classmethod
    def captions(cls, audio_list):
        """Return captions with None mapped to "", or False if none are set."""
        captions = [a._caption for a in audio_list]
        if all(c is None for c in captions):
            return False
        else:
            return ["" if c is None else c for c in captions]

    def resolve_ref(self):
        """Return the external reference this clip points at, if any.

        Returns the path itself for reference-backed clips; otherwise looks
        the local path up in the source artifact's manifest. Returns None when
        the clip has no source artifact or no matching manifest entry.
        """
        if self.path_is_reference(self._path):
            # this object was already created using a ref:
            return self._path

        # A clip that was not loaded from an artifact has nothing to resolve
        # against; previously this dereferenced None and raised AttributeError.
        artifact_source = getattr(self, "_artifact_source", None)
        if artifact_source is None:
            return None
        source_artifact = artifact_source.artifact

        resolved_name = source_artifact._local_path_to_name(self._path)
        if resolved_name is not None:
            target_entry = source_artifact.manifest.get_entry_by_path(resolved_name)
            if target_entry is not None:
                return target_entry.ref

        return None

    def __eq__(self, other):
        # Non-Audio operands have no `_path`/`_caption`; defer to Python's
        # fallback instead of raising AttributeError.
        if not isinstance(other, Audio):
            return NotImplemented

        if self.path_is_reference(self._path) or self.path_is_reference(other._path):
            # one or more of these objects is an unresolved reference -- we'll compare
            # their reference paths instead of their SHAs:
            return (
                self.resolve_ref() == other.resolve_ref()
                and self._caption == other._caption
            )

        return super().__eq__(other) and self._caption == other._caption

    def __ne__(self, other):
        result = self.__eq__(other)
        # Propagate NotImplemented so Python can try the reflected operation.
        if result is NotImplemented:
            return result
        return not result
|
158
|
+
|
159
|
+
|
160
|
+
class _AudioFileType(_dtypes.Type):
    """Type descriptor named "audio-file" whose instance class is `Audio`."""

    types = [Audio]
    name = "audio-file"


# Make the descriptor discoverable through the shared type registry.
_dtypes.TypeRegistry.add(_AudioFileType)
|
@@ -88,9 +88,7 @@ class WBValue:
|
|
88
88
|
raise NotImplementedError
|
89
89
|
|
90
90
|
@classmethod
|
91
|
-
def from_json(
|
92
|
-
cls: Type["WBValue"], json_obj: dict, source_artifact: "Artifact"
|
93
|
-
) -> "WBValue":
|
91
|
+
def from_json(cls, json_obj: dict, source_artifact: "Artifact") -> "WBValue":
|
94
92
|
"""Deserialize a `json_obj` into it's class representation.
|
95
93
|
|
96
94
|
If additional resources were stored in the `run_or_artifact` artifact during the
|
@@ -0,0 +1,70 @@
|
|
1
|
+
import codecs
|
2
|
+
import json
|
3
|
+
import os
|
4
|
+
|
5
|
+
from wandb import util
|
6
|
+
from wandb.sdk.lib import runid
|
7
|
+
|
8
|
+
from . import _dtypes
|
9
|
+
from ._private import MEDIA_TMP
|
10
|
+
from .base_types.media import Media
|
11
|
+
|
12
|
+
|
13
|
+
class Bokeh(Media):
    """Wandb class for Bokeh plots.

    Arguments:
        data_or_path: a Bokeh `Document`, a Bokeh `Model`, or a path to an
            existing Bokeh JSON file.

    Raises:
        TypeError: if `data_or_path` is none of the accepted inputs.
    """

    _log_type = "bokeh-file"

    def __init__(self, data_or_path):
        super().__init__()
        bokeh = util.get_module("bokeh", required=True)

        if isinstance(data_or_path, str) and os.path.exists(data_or_path):
            # JSON is UTF-8 (and the temp files below are written as UTF-8);
            # relying on the locale default breaks non-ASCII text on Windows.
            with open(data_or_path, encoding="utf-8") as file:
                b_json = json.load(file)
            self.b_obj = bokeh.document.Document.from_json(b_json)
            self._set_file(data_or_path, is_tmp=False, extension=".bokeh.json")
            return

        if isinstance(data_or_path, bokeh.model.Model):
            document = bokeh.document.Document()
            document.add_root(data_or_path)
        elif isinstance(data_or_path, bokeh.document.Document):
            # Documents were previously accepted by the type check but never
            # serialized, leaving the media without `b_obj` or a backing file.
            document = data_or_path
        else:
            raise TypeError(
                "Bokeh constructor accepts Bokeh document/model or path to Bokeh json file"
            )

        # serialize/deserialize pairing followed by sorting attributes ensures
        # that the file's sha's are equivalent in subsequent calls
        self.b_obj = bokeh.document.Document.from_json(document.to_json())
        b_json = self.b_obj.to_json()
        if "references" in b_json["roots"]:
            b_json["roots"]["references"].sort(key=lambda x: x["id"])

        tmp_path = os.path.join(MEDIA_TMP.name, runid.generate_id() + ".bokeh.json")
        with codecs.open(tmp_path, "w", encoding="utf-8") as fp:
            util.json_dump_safer(b_json, fp)
        self._set_file(tmp_path, is_tmp=True, extension=".bokeh.json")

    @classmethod
    def get_media_subdir(cls):
        """Return the relative directory that Bokeh media is stored under.

        A classmethod (like `Audio.get_media_subdir`); instance calls still work.
        """
        return os.path.join("media", "bokeh")

    def to_json(self, run):
        """Extend the base JSON with this media's `_type`."""
        # TODO: (tss) this is getting redundant for all the media objects. We can probably
        # pull this into Media#to_json and remove this type override for all the media types.
        # There are only a few cases where the type is different between artifacts and runs.
        json_dict = super().to_json(run)
        json_dict["_type"] = self._log_type
        return json_dict

    @classmethod
    def from_json(cls, json_obj, source_artifact):
        """Rebuild a `Bokeh` by downloading its file entry from the artifact."""
        return cls(source_artifact.get_entry(json_obj["path"]).download())
|
63
|
+
|
64
|
+
|
65
|
+
class _BokehFileType(_dtypes.Type):
    """Type descriptor named "bokeh-file" whose instance class is `Bokeh`."""

    types = [Bokeh]
    name = "bokeh-file"


# Make the descriptor discoverable through the shared type registry.
_dtypes.TypeRegistry.add(_BokehFileType)
|