wandb 0.19.9__py3-none-any.whl → 0.19.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +6 -3
- wandb/_pydantic/__init__.py +14 -8
- wandb/_pydantic/base.py +51 -36
- wandb/_pydantic/utils.py +73 -0
- wandb/_pydantic/v1_compat.py +79 -57
- wandb/apis/public/__init__.py +2 -2
- wandb/apis/public/api.py +684 -4
- wandb/apis/public/artifacts.py +377 -677
- wandb/apis/public/automations.py +69 -0
- wandb/apis/public/integrations.py +180 -0
- wandb/apis/public/projects.py +29 -0
- wandb/apis/public/registries/__init__.py +0 -0
- wandb/apis/public/registries/_freezable_list.py +179 -0
- wandb/apis/public/{registries.py → registries/registries_search.py} +22 -129
- wandb/apis/public/registries/registry.py +357 -0
- wandb/apis/public/registries/utils.py +140 -0
- wandb/apis/public/runs.py +58 -56
- wandb/apis/public/utils.py +107 -1
- wandb/automations/__init__.py +73 -0
- wandb/automations/_filters/__init__.py +40 -0
- wandb/automations/_filters/expressions.py +181 -0
- wandb/automations/_filters/operators.py +258 -0
- wandb/automations/_filters/run_metrics.py +332 -0
- wandb/automations/_generated/__init__.py +177 -0
- wandb/automations/_generated/create_automation.py +17 -0
- wandb/automations/_generated/create_generic_webhook_integration.py +43 -0
- wandb/automations/_generated/delete_automation.py +17 -0
- wandb/automations/_generated/enums.py +33 -0
- wandb/automations/_generated/fragments.py +358 -0
- wandb/automations/_generated/generic_webhook_integrations_by_entity.py +22 -0
- wandb/automations/_generated/get_automations.py +24 -0
- wandb/automations/_generated/get_automations_by_entity.py +26 -0
- wandb/automations/_generated/input_types.py +104 -0
- wandb/automations/_generated/integrations_by_entity.py +22 -0
- wandb/automations/_generated/operations.py +647 -0
- wandb/automations/_generated/slack_integrations_by_entity.py +22 -0
- wandb/automations/_generated/update_automation.py +17 -0
- wandb/automations/_utils.py +237 -0
- wandb/automations/_validators.py +165 -0
- wandb/automations/actions.py +220 -0
- wandb/automations/automations.py +87 -0
- wandb/automations/events.py +287 -0
- wandb/automations/integrations.py +45 -0
- wandb/automations/scopes.py +78 -0
- wandb/beta/workflows.py +9 -10
- wandb/bin/gpu_stats +0 -0
- wandb/cli/cli.py +3 -3
- wandb/env.py +11 -0
- wandb/integration/keras/keras.py +2 -1
- wandb/integration/langchain/wandb_tracer.py +2 -1
- wandb/jupyter.py +137 -118
- wandb/old/settings.py +4 -1
- wandb/old/summary.py +0 -2
- wandb/proto/v3/wandb_internal_pb2.py +297 -292
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +292 -292
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v5/wandb_internal_pb2.py +292 -292
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v6/wandb_base_pb2.py +41 -0
- wandb/proto/v6/wandb_internal_pb2.py +393 -0
- wandb/proto/v6/wandb_server_pb2.py +78 -0
- wandb/proto/v6/wandb_settings_pb2.py +58 -0
- wandb/proto/v6/wandb_telemetry_pb2.py +52 -0
- wandb/proto/wandb_base_pb2.py +2 -0
- wandb/proto/wandb_deprecated.py +8 -0
- wandb/proto/wandb_internal_pb2.py +3 -1
- wandb/proto/wandb_server_pb2.py +2 -0
- wandb/proto/wandb_settings_pb2.py +2 -0
- wandb/proto/wandb_telemetry_pb2.py +2 -0
- wandb/sdk/artifacts/_generated/__init__.py +289 -0
- wandb/sdk/artifacts/_generated/add_aliases.py +21 -0
- wandb/sdk/artifacts/_generated/artifact_collection_membership_files.py +43 -0
- wandb/sdk/artifacts/_generated/artifact_version_files.py +36 -0
- wandb/sdk/artifacts/_generated/create_artifact_collection_tag_assignments.py +36 -0
- wandb/sdk/artifacts/_generated/delete_aliases.py +21 -0
- wandb/sdk/artifacts/_generated/delete_artifact_collection_tag_assignments.py +25 -0
- wandb/sdk/artifacts/_generated/delete_artifact_portfolio.py +35 -0
- wandb/sdk/artifacts/_generated/delete_artifact_sequence.py +35 -0
- wandb/sdk/artifacts/_generated/enums.py +17 -0
- wandb/sdk/artifacts/_generated/fetch_linked_artifacts.py +67 -0
- wandb/sdk/artifacts/_generated/fragments.py +221 -0
- wandb/sdk/artifacts/_generated/input_types.py +28 -0
- wandb/sdk/artifacts/_generated/move_artifact_collection.py +35 -0
- wandb/sdk/artifacts/_generated/operations.py +611 -0
- wandb/sdk/artifacts/_generated/project_artifact_collection.py +101 -0
- wandb/sdk/artifacts/_generated/project_artifact_collections.py +33 -0
- wandb/sdk/artifacts/_generated/project_artifact_type.py +24 -0
- wandb/sdk/artifacts/_generated/project_artifact_types.py +24 -0
- wandb/sdk/artifacts/_generated/project_artifacts.py +42 -0
- wandb/sdk/artifacts/_generated/run_input_artifacts.py +51 -0
- wandb/sdk/artifacts/_generated/run_output_artifacts.py +51 -0
- wandb/sdk/artifacts/_generated/update_artifact.py +26 -0
- wandb/sdk/artifacts/_generated/update_artifact_portfolio.py +35 -0
- wandb/sdk/artifacts/_generated/update_artifact_sequence.py +35 -0
- wandb/sdk/artifacts/_graphql_fragments.py +57 -79
- wandb/sdk/artifacts/_validators.py +120 -1
- wandb/sdk/artifacts/artifact.py +419 -215
- wandb/sdk/artifacts/artifact_file_cache.py +4 -6
- wandb/sdk/artifacts/artifact_manifest_entry.py +13 -3
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +182 -1
- wandb/sdk/artifacts/storage_policy.py +3 -0
- wandb/sdk/data_types/base_types/media.py +2 -3
- wandb/sdk/data_types/base_types/wb_value.py +34 -11
- wandb/sdk/data_types/html.py +36 -9
- wandb/sdk/data_types/image.py +12 -12
- wandb/sdk/data_types/table.py +5 -0
- wandb/sdk/data_types/trace_tree.py +2 -0
- wandb/sdk/data_types/utils.py +1 -1
- wandb/sdk/data_types/video.py +59 -57
- wandb/sdk/interface/interface.py +4 -3
- wandb/sdk/internal/internal_api.py +21 -31
- wandb/sdk/internal/profiler.py +6 -5
- wandb/sdk/internal/run.py +13 -6
- wandb/sdk/internal/sender.py +5 -2
- wandb/sdk/launch/sweeps/utils.py +8 -0
- wandb/sdk/lib/apikey.py +25 -4
- wandb/sdk/lib/asyncio_compat.py +1 -1
- wandb/sdk/lib/deprecate.py +13 -22
- wandb/sdk/lib/disabled.py +2 -1
- wandb/sdk/lib/printer.py +37 -8
- wandb/sdk/lib/printer_asyncio.py +46 -0
- wandb/sdk/lib/redirect.py +10 -5
- wandb/sdk/projects/_generated/__init__.py +47 -0
- wandb/sdk/projects/_generated/delete_project.py +22 -0
- wandb/sdk/projects/_generated/enums.py +4 -0
- wandb/sdk/projects/_generated/fetch_registry.py +22 -0
- wandb/sdk/projects/_generated/fragments.py +41 -0
- wandb/sdk/projects/_generated/input_types.py +13 -0
- wandb/sdk/projects/_generated/operations.py +88 -0
- wandb/sdk/projects/_generated/rename_project.py +27 -0
- wandb/sdk/projects/_generated/upsert_registry_project.py +27 -0
- wandb/sdk/service/server_sock.py +19 -14
- wandb/sdk/service/service.py +18 -8
- wandb/sdk/service/streams.py +5 -0
- wandb/sdk/verify/verify.py +6 -3
- wandb/sdk/wandb_init.py +217 -70
- wandb/sdk/wandb_login.py +13 -4
- wandb/sdk/wandb_run.py +419 -295
- wandb/sdk/wandb_settings.py +27 -10
- wandb/sdk/wandb_setup.py +61 -0
- wandb/util.py +33 -29
- {wandb-0.19.9.dist-info → wandb-0.19.11.dist-info}/METADATA +5 -5
- {wandb-0.19.9.dist-info → wandb-0.19.11.dist-info}/RECORD +152 -82
- wandb/_globals.py +0 -19
- wandb/sdk/internal/_generated/base.py +0 -226
- wandb/sdk/internal/_generated/typing_compat.py +0 -14
- {wandb-0.19.9.dist-info → wandb-0.19.11.dist-info}/WHEEL +0 -0
- {wandb-0.19.9.dist-info → wandb-0.19.11.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.9.dist-info → wandb-0.19.11.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/data_types/video.py
CHANGED
@@ -1,11 +1,12 @@
|
|
1
|
+
import functools
|
1
2
|
import logging
|
2
3
|
import os
|
3
4
|
from io import BytesIO
|
4
|
-
from typing import TYPE_CHECKING, Any,
|
5
|
+
from typing import TYPE_CHECKING, Any, Literal, Optional, Sequence, Type, Union
|
5
6
|
|
6
7
|
import wandb
|
7
8
|
from wandb import util
|
8
|
-
from wandb.sdk.lib import filesystem, runid
|
9
|
+
from wandb.sdk.lib import filesystem, printer, printer_asyncio, runid
|
9
10
|
|
10
11
|
from . import _dtypes
|
11
12
|
from ._private import MEDIA_TMP
|
@@ -47,36 +48,7 @@ def write_gif_with_image_io(
|
|
47
48
|
|
48
49
|
|
49
50
|
class Video(BatchableMedia):
|
50
|
-
"""
|
51
|
-
|
52
|
-
Args:
|
53
|
-
data_or_path: (numpy array, string, io)
|
54
|
-
Video can be initialized with a path to a file or an io object.
|
55
|
-
The format must be "gif", "mp4", "webm" or "ogg".
|
56
|
-
The format must be specified with the format argument.
|
57
|
-
Video can be initialized with a numpy tensor.
|
58
|
-
The numpy tensor must be either 4 dimensional or 5 dimensional.
|
59
|
-
Channels should be (time, channel, height, width) or
|
60
|
-
(batch, time, channel, height width)
|
61
|
-
caption: (string) caption associated with the video for display
|
62
|
-
fps: (int)
|
63
|
-
The frame rate to use when encoding raw video frames. Default value is 4.
|
64
|
-
This parameter has no effect when data_or_path is a string, or bytes.
|
65
|
-
format: (string) format of video, necessary if initializing with path or io object.
|
66
|
-
|
67
|
-
Examples:
|
68
|
-
### Log a numpy array as a video
|
69
|
-
<!--yeadoc-test:log-video-numpy-->
|
70
|
-
```python
|
71
|
-
import numpy as np
|
72
|
-
import wandb
|
73
|
-
|
74
|
-
run = wandb.init()
|
75
|
-
# axes are (time, channel, height, width)
|
76
|
-
frames = np.random.randint(low=0, high=256, size=(10, 3, 100, 100), dtype=np.uint8)
|
77
|
-
run.log({"video": wandb.Video(frames, fps=4)})
|
78
|
-
```
|
79
|
-
"""
|
51
|
+
"""A class for logging videos to W&B."""
|
80
52
|
|
81
53
|
_log_type = "video-file"
|
82
54
|
EXTS = ("gif", "mp4", "webm", "ogg")
|
@@ -88,10 +60,53 @@ class Video(BatchableMedia):
|
|
88
60
|
data_or_path: Union["np.ndarray", str, "TextIO", "BytesIO"],
|
89
61
|
caption: Optional[str] = None,
|
90
62
|
fps: Optional[int] = None,
|
91
|
-
format: Optional[
|
63
|
+
format: Optional[Literal["gif", "mp4", "webm", "ogg"]] = None,
|
92
64
|
):
|
65
|
+
"""Initialize a W&B Video object.
|
66
|
+
|
67
|
+
Args:
|
68
|
+
data_or_path:
|
69
|
+
Video can be initialized with a path to a file or an io object.
|
70
|
+
Video can be initialized with a numpy tensor.
|
71
|
+
The numpy tensor must be either 4 dimensional or 5 dimensional.
|
72
|
+
The dimensions should be (number of frames, channel, height, width) or
|
73
|
+
(batch, number of frames, channel, height, width)
|
74
|
+
The format parameter must be specified with the format argument
|
75
|
+
when initializing with a numpy array
|
76
|
+
or io object.
|
77
|
+
caption: Caption associated with the video for display.
|
78
|
+
fps:
|
79
|
+
The frame rate to use when encoding raw video frames.
|
80
|
+
Default value is 4.
|
81
|
+
This parameter has no effect when data_or_path is a string, or bytes.
|
82
|
+
format:
|
83
|
+
Format of video, necessary if initializing with a numpy array
|
84
|
+
or io object. This parameter will be used to determine the format
|
85
|
+
to use when encoding the video data. Accepted values are "gif",
|
86
|
+
"mp4", "webm", or "ogg".
|
87
|
+
|
88
|
+
Examples:
|
89
|
+
### Log a numpy array as a video
|
90
|
+
```python
|
91
|
+
import numpy as np
|
92
|
+
import wandb
|
93
|
+
|
94
|
+
with wandb.init() as run:
|
95
|
+
# axes are (number of frames, channel, height, width)
|
96
|
+
frames = np.random.randint(
|
97
|
+
low=0, high=256, size=(10, 3, 100, 100), dtype=np.uint8
|
98
|
+
)
|
99
|
+
run.log({"video": wandb.Video(frames, format="mp4", fps=4)})
|
100
|
+
```
|
101
|
+
"""
|
93
102
|
super().__init__(caption=caption)
|
94
103
|
|
104
|
+
if format is None:
|
105
|
+
wandb.termwarn(
|
106
|
+
"`format` argument was not provided, defaulting to `gif`. "
|
107
|
+
"This parameter will be required in v0.20.0, "
|
108
|
+
"please specify the format explicitly."
|
109
|
+
)
|
95
110
|
self._format = format or "gif"
|
96
111
|
self._width = None
|
97
112
|
self._height = None
|
@@ -134,7 +149,11 @@ class Video(BatchableMedia):
|
|
134
149
|
"wandb.Video accepts a file path or numpy like data as input"
|
135
150
|
)
|
136
151
|
fps = fps or 4
|
137
|
-
|
152
|
+
printer_asyncio.run_async_with_spinner(
|
153
|
+
printer.new_printer(),
|
154
|
+
"Encoding video...",
|
155
|
+
functools.partial(self.encode, fps=fps),
|
156
|
+
)
|
138
157
|
|
139
158
|
def encode(self, fps: int = 4) -> None:
|
140
159
|
# import ImageSequenceClip from the appropriate MoviePy module
|
@@ -152,29 +171,12 @@ class Video(BatchableMedia):
|
|
152
171
|
filename = os.path.join(
|
153
172
|
MEDIA_TMP.name, runid.generate_id() + "." + self._format
|
154
173
|
)
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
else:
|
162
|
-
clip.write_videofile(filename, **kwargs)
|
163
|
-
except TypeError:
|
164
|
-
try: # even older versions of moviepy do not support progress_bar argument
|
165
|
-
kwargs = {"verbose": False, "progress_bar": False}
|
166
|
-
if self._format == "gif":
|
167
|
-
clip.write_gif(filename, **kwargs)
|
168
|
-
else:
|
169
|
-
clip.write_videofile(filename, **kwargs)
|
170
|
-
except TypeError:
|
171
|
-
kwargs = {
|
172
|
-
"verbose": False,
|
173
|
-
}
|
174
|
-
if self._format == "gif":
|
175
|
-
clip.write_gif(filename, **kwargs)
|
176
|
-
else:
|
177
|
-
clip.write_videofile(filename, **kwargs)
|
174
|
+
|
175
|
+
if self._format == "gif":
|
176
|
+
write_gif_with_image_io(clip, filename)
|
177
|
+
else:
|
178
|
+
clip.write_videofile(filename, logger=None)
|
179
|
+
|
178
180
|
self._set_file(filename, is_tmp=True)
|
179
181
|
|
180
182
|
@classmethod
|
wandb/sdk/interface/interface.py
CHANGED
@@ -173,6 +173,8 @@ class InterfaceBase:
|
|
173
173
|
self._make_config(data=config_dict, obj=proto_run.config)
|
174
174
|
if run._telemetry_obj:
|
175
175
|
proto_run.telemetry.MergeFrom(run._telemetry_obj)
|
176
|
+
if run._start_runtime:
|
177
|
+
proto_run.runtime = run._start_runtime
|
176
178
|
return proto_run
|
177
179
|
|
178
180
|
def publish_run(self, run: "Run") -> None:
|
@@ -426,7 +428,6 @@ class InterfaceBase:
|
|
426
428
|
|
427
429
|
def deliver_link_artifact(
|
428
430
|
self,
|
429
|
-
run: "Run",
|
430
431
|
artifact: "Artifact",
|
431
432
|
portfolio_name: str,
|
432
433
|
aliases: Iterable[str],
|
@@ -440,9 +441,9 @@ class InterfaceBase:
|
|
440
441
|
else:
|
441
442
|
link_artifact.server_id = artifact.id if artifact.id else ""
|
442
443
|
link_artifact.portfolio_name = portfolio_name
|
443
|
-
link_artifact.portfolio_entity = entity or
|
444
|
+
link_artifact.portfolio_entity = entity or ""
|
444
445
|
link_artifact.portfolio_organization = organization or ""
|
445
|
-
link_artifact.portfolio_project = project or
|
446
|
+
link_artifact.portfolio_project = project or ""
|
446
447
|
link_artifact.portfolio_aliases.extend(aliases)
|
447
448
|
|
448
449
|
return self._deliver_link_artifact(link_artifact)
|
@@ -12,6 +12,7 @@ import sys
|
|
12
12
|
import threading
|
13
13
|
from copy import deepcopy
|
14
14
|
from pathlib import Path
|
15
|
+
from types import MappingProxyType
|
15
16
|
from typing import (
|
16
17
|
IO,
|
17
18
|
TYPE_CHECKING,
|
@@ -189,11 +190,6 @@ def _match_org_with_fetched_org_entities(
|
|
189
190
|
"""
|
190
191
|
for org_names in orgs:
|
191
192
|
if organization in org_names:
|
192
|
-
wandb.termwarn(
|
193
|
-
"Registries can be linked/fetched using a shorthand form without specifying the organization name. "
|
194
|
-
"Try using shorthand path format: <my_registry_name>/<artifact_name> or "
|
195
|
-
"just <my_registry_name> if fetching just the project."
|
196
|
-
)
|
197
193
|
return org_names.entity_name
|
198
194
|
|
199
195
|
if len(orgs) == 1:
|
@@ -873,30 +869,29 @@ class Api:
|
|
873
869
|
_, _, mutations = self.server_info_introspection()
|
874
870
|
return "updateRunQueueItemWarning" in mutations
|
875
871
|
|
876
|
-
def
|
877
|
-
"""
|
878
|
-
|
879
|
-
Args:
|
880
|
-
feature_value (ServerFeature): The enum value of the feature to check.
|
881
|
-
|
882
|
-
Returns:
|
883
|
-
bool: True if the feature is enabled, False otherwise.
|
884
|
-
|
885
|
-
Raises:
|
886
|
-
Exception: If server doesn't support feature queries or other errors occur
|
887
|
-
"""
|
872
|
+
def _server_features(self) -> Mapping[str, bool]:
|
873
|
+
"""Returns a cached, read-only lookup of current server feature flags."""
|
888
874
|
if self._server_features_cache is None:
|
889
875
|
query = gql(SERVER_FEATURES_QUERY_GQL)
|
890
|
-
|
891
|
-
|
892
|
-
|
893
|
-
|
894
|
-
|
895
|
-
|
876
|
+
|
877
|
+
try:
|
878
|
+
response = self.gql(query)
|
879
|
+
except Exception as e:
|
880
|
+
# Unfortunately we currently have to match on the text of the error message
|
881
|
+
if 'Cannot query field "features" on type "ServerInfo".' in str(e):
|
882
|
+
self._server_features_cache = {}
|
883
|
+
else:
|
884
|
+
raise
|
896
885
|
else:
|
897
|
-
|
886
|
+
info = ServerFeaturesQuery.model_validate(response).server_info
|
887
|
+
if info and (feats := info.features):
|
888
|
+
self._server_features_cache = {
|
889
|
+
f.name: f.is_enabled for f in feats if f
|
890
|
+
}
|
891
|
+
else:
|
892
|
+
self._server_features_cache = {}
|
898
893
|
|
899
|
-
return self._server_features_cache
|
894
|
+
return MappingProxyType(self._server_features_cache)
|
900
895
|
|
901
896
|
def _check_server_feature_with_fallback(self, feature_value: ServerFeature) -> bool:
|
902
897
|
"""Wrapper around check_server_feature that warns and returns False for older unsupported servers.
|
@@ -912,12 +907,7 @@ class Api:
|
|
912
907
|
Exceptions:
|
913
908
|
Exception: If an error other than the server not supporting feature queries occurs.
|
914
909
|
"""
|
915
|
-
|
916
|
-
return self._check_server_feature(feature_value)
|
917
|
-
except Exception as e:
|
918
|
-
if 'Cannot query field "features" on type "ServerInfo".' in str(e):
|
919
|
-
return False
|
920
|
-
raise e
|
910
|
+
return self._server_features().get(ServerFeature.Name(feature_value), False)
|
921
911
|
|
922
912
|
@normalize_exceptions
|
923
913
|
def update_run_queue_item_warning(
|
wandb/sdk/internal/profiler.py
CHANGED
@@ -18,12 +18,13 @@ def torch_trace_handler():
|
|
18
18
|
torch.profiler.profile(..., on_trace_ready=wandb.profiler.torch_trace_handler())
|
19
19
|
```
|
20
20
|
|
21
|
-
Calling this function ensures that profiler charts & tables can be viewed in
|
22
|
-
on wandb.ai.
|
21
|
+
Calling this function ensures that profiler charts & tables can be viewed in
|
22
|
+
your run dashboard on wandb.ai.
|
23
23
|
|
24
|
-
Please note that `wandb.init()` must be called before this function is
|
25
|
-
|
26
|
-
|
24
|
+
Please note that `wandb.init()` must be called before this function is
|
25
|
+
invoked, and the reinit setting must not be set to "create_new". The PyTorch
|
26
|
+
(torch) version must also be at least 1.9, in order to ensure stability of
|
27
|
+
their Profiler API.
|
27
28
|
|
28
29
|
Args:
|
29
30
|
None
|
wandb/sdk/internal/run.py
CHANGED
@@ -5,21 +5,28 @@ Semi-stubbed run for internal process use.
|
|
5
5
|
|
6
6
|
"""
|
7
7
|
|
8
|
-
|
8
|
+
import sys
|
9
9
|
|
10
|
-
|
10
|
+
if sys.version_info >= (3, 12):
|
11
|
+
from typing import override
|
12
|
+
else:
|
13
|
+
from typing_extensions import override
|
14
|
+
|
15
|
+
from wandb.sdk import wandb_run
|
11
16
|
|
12
17
|
|
13
18
|
class InternalRun(wandb_run.Run):
|
14
19
|
def __init__(self, run_obj, settings, datatypes_cb):
|
15
20
|
super().__init__(settings=settings)
|
16
21
|
self._run_obj = run_obj
|
22
|
+
self._datatypes_cb = datatypes_cb
|
17
23
|
|
18
|
-
|
19
|
-
# We really want a common interface for wandb_run.Run and InternalRun.
|
20
|
-
_datatypes_set_callback(datatypes_cb)
|
21
|
-
|
24
|
+
@override
|
22
25
|
def _set_backend(self, backend):
|
23
26
|
# This type of run object can't have a backend
|
24
27
|
# or do any writes.
|
25
28
|
pass
|
29
|
+
|
30
|
+
@override
|
31
|
+
def _publish_file(self, fname: str) -> None:
|
32
|
+
self._datatypes_cb(fname)
|
wandb/sdk/internal/sender.py
CHANGED
@@ -1444,7 +1444,7 @@ class SendManager:
|
|
1444
1444
|
)
|
1445
1445
|
if (client_id or server_id) and portfolio_name and entity and project:
|
1446
1446
|
try:
|
1447
|
-
self._api.link_artifact(
|
1447
|
+
response = self._api.link_artifact(
|
1448
1448
|
client_id,
|
1449
1449
|
server_id,
|
1450
1450
|
portfolio_name,
|
@@ -1453,9 +1453,12 @@ class SendManager:
|
|
1453
1453
|
aliases,
|
1454
1454
|
organization,
|
1455
1455
|
)
|
1456
|
+
result.response.link_artifact_response.version_index = response[
|
1457
|
+
"versionIndex"
|
1458
|
+
]
|
1456
1459
|
except Exception as e:
|
1457
1460
|
org_or_entity = organization or entity
|
1458
|
-
result.response.
|
1461
|
+
result.response.link_artifact_response.error_message = (
|
1459
1462
|
f"error linking artifact to "
|
1460
1463
|
f'"{org_or_entity}/{project}/{portfolio_name}"; error: {e}'
|
1461
1464
|
)
|
wandb/sdk/launch/sweeps/utils.py
CHANGED
@@ -223,6 +223,10 @@ def create_sweep_command_args(command: Dict) -> Dict[str, Any]:
|
|
223
223
|
flags_dict: Dict[str, Any] = {}
|
224
224
|
# (5) flags without equals (e.g. --foo bar)
|
225
225
|
args_no_equals: List[str] = []
|
226
|
+
# (6) flags for hydra append config value (e.g. +foo=bar)
|
227
|
+
flags_append_hydra: List[str] = []
|
228
|
+
# (7) flags for hydra override config value (e.g. ++foo=bar)
|
229
|
+
flags_override_hydra: List[str] = []
|
226
230
|
for param, config in command["args"].items():
|
227
231
|
# allow 'None' as a valid value, but error if no value is found
|
228
232
|
try:
|
@@ -234,6 +238,8 @@ def create_sweep_command_args(command: Dict) -> Dict[str, Any]:
|
|
234
238
|
flags.append("--" + _flag)
|
235
239
|
flags_no_hyphens.append(_flag)
|
236
240
|
args_no_equals += [f"--{param}", str(_value)]
|
241
|
+
flags_append_hydra.append("+" + _flag)
|
242
|
+
flags_override_hydra.append("++" + _flag)
|
237
243
|
if isinstance(_value, bool):
|
238
244
|
# omit flags if they are boolean and false
|
239
245
|
if _value:
|
@@ -248,6 +254,8 @@ def create_sweep_command_args(command: Dict) -> Dict[str, Any]:
|
|
248
254
|
"args_no_boolean_flags": flags_no_booleans,
|
249
255
|
"args_json": [json.dumps(flags_dict)],
|
250
256
|
"args_dict": flags_dict,
|
257
|
+
"args_append_hydra": flags_append_hydra,
|
258
|
+
"args_override_hydra": flags_override_hydra,
|
251
259
|
}
|
252
260
|
|
253
261
|
|
wandb/sdk/lib/apikey.py
CHANGED
@@ -83,6 +83,25 @@ def get_netrc_file_path() -> str:
|
|
83
83
|
return os.path.join(os.path.expanduser("~"), netrc_file)
|
84
84
|
|
85
85
|
|
86
|
+
def _api_key_prompt_str(app_url: str, referrer: str | None = None) -> str:
|
87
|
+
"""Generate a prompt string for API key authorization.
|
88
|
+
|
89
|
+
Creates a URL string that directs users to the authorization page where they
|
90
|
+
can find their API key.
|
91
|
+
|
92
|
+
Args:
|
93
|
+
app_url: The base URL of the W&B application.
|
94
|
+
referrer: Optional referrer parameter to include in the URL.
|
95
|
+
|
96
|
+
Returns:
|
97
|
+
A formatted string with instructions and the authorization URL.
|
98
|
+
"""
|
99
|
+
ref = ""
|
100
|
+
if referrer:
|
101
|
+
ref = f"?ref={referrer}"
|
102
|
+
return f"You can find your API key in your browser here: {app_url}/authorize{ref}"
|
103
|
+
|
104
|
+
|
86
105
|
def prompt_api_key( # noqa: C901
|
87
106
|
settings: Settings,
|
88
107
|
api: InternalApi | None = None,
|
@@ -91,6 +110,7 @@ def prompt_api_key( # noqa: C901
|
|
91
110
|
no_offline: bool = False,
|
92
111
|
no_create: bool = False,
|
93
112
|
local: bool = False,
|
113
|
+
referrer: str | None = None,
|
94
114
|
) -> str | bool | None:
|
95
115
|
"""Prompt for api key.
|
96
116
|
|
@@ -149,7 +169,10 @@ def prompt_api_key( # noqa: C901
|
|
149
169
|
key = browser_callback(signup=True) if browser_callback else None
|
150
170
|
|
151
171
|
if not key:
|
152
|
-
|
172
|
+
ref = f"&ref={referrer}" if referrer else ""
|
173
|
+
wandb.termlog(
|
174
|
+
f"Create an account here: {app_url}/authorize?signup=true{ref}"
|
175
|
+
)
|
153
176
|
key = input_callback(api_ask).strip()
|
154
177
|
elif result == LOGIN_CHOICE_EXISTS:
|
155
178
|
key = browser_callback() if browser_callback else None
|
@@ -164,9 +187,7 @@ def prompt_api_key( # noqa: C901
|
|
164
187
|
f"Logging into {host}. (Learn how to deploy a W&B server "
|
165
188
|
f"locally: {url_registry.url('wandb-server')})"
|
166
189
|
)
|
167
|
-
wandb.termlog(
|
168
|
-
f"You can find your API key in your browser here: {app_url}/authorize"
|
169
|
-
)
|
190
|
+
wandb.termlog(_api_key_prompt_str(app_url, referrer))
|
170
191
|
key = input_callback(api_ask).strip()
|
171
192
|
elif result == LOGIN_CHOICE_NOTTY:
|
172
193
|
# TODO: Needs refactor as this needs to be handled by caller
|
wandb/sdk/lib/asyncio_compat.py
CHANGED
wandb/sdk/lib/deprecate.py
CHANGED
@@ -1,42 +1,33 @@
|
|
1
|
-
|
1
|
+
from __future__ import annotations
|
2
2
|
|
3
|
-
from typing import TYPE_CHECKING
|
3
|
+
from typing import TYPE_CHECKING
|
4
4
|
|
5
5
|
import wandb
|
6
|
-
from wandb.proto.wandb_deprecated import DEPRECATED_FEATURES
|
7
|
-
from wandb.
|
6
|
+
from wandb.proto.wandb_deprecated import DEPRECATED_FEATURES
|
7
|
+
from wandb.sdk.lib import telemetry
|
8
8
|
|
9
|
-
#
|
9
|
+
# Necessary to break import cycle.
|
10
10
|
if TYPE_CHECKING:
|
11
|
-
from
|
12
|
-
|
13
|
-
|
14
|
-
deprecated_field_names: Tuple[str, ...] = tuple(
|
15
|
-
str(v) for k, v in Deprecated.__dict__.items() if not k.startswith("_")
|
16
|
-
)
|
11
|
+
from wandb import wandb_run
|
17
12
|
|
18
13
|
|
19
14
|
def deprecate(
|
20
15
|
field_name: DEPRECATED_FEATURES,
|
21
16
|
warning_message: str,
|
22
|
-
run:
|
17
|
+
run: wandb_run.Run | None = None,
|
23
18
|
) -> None:
|
24
19
|
"""Warn the user that a feature has been deprecated.
|
25
20
|
|
26
|
-
|
21
|
+
If a run is provided, the given field on its telemetry is updated.
|
22
|
+
Otherwise, the global run is used.
|
27
23
|
|
28
24
|
Args:
|
29
|
-
field_name: The
|
30
|
-
Defined in wandb/proto/wandb_telemetry.proto::Deprecated
|
25
|
+
field_name: The field on the Deprecated proto for this deprecation.
|
31
26
|
warning_message: The message to display to the user.
|
32
|
-
run: The run
|
27
|
+
run: The run whose telemetry to update.
|
33
28
|
"""
|
34
|
-
known_fields = TelemetryDeprecated.DESCRIPTOR.fields_by_name.keys()
|
35
|
-
if field_name not in known_fields:
|
36
|
-
raise ValueError(
|
37
|
-
f"Unknown field name: {field_name}. Known fields: {known_fields}"
|
38
|
-
)
|
39
29
|
_run = run or wandb.run
|
40
|
-
with
|
30
|
+
with telemetry.context(run=_run) as tel:
|
41
31
|
setattr(tel.deprecated, field_name, True)
|
32
|
+
|
42
33
|
wandb.termwarn(warning_message, repeat=False)
|
wandb/sdk/lib/disabled.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
from typing import Any
|
2
2
|
|
3
|
+
from wandb.proto.wandb_deprecated import Deprecated
|
3
4
|
from wandb.sdk.lib import deprecate
|
4
5
|
|
5
6
|
|
@@ -23,7 +24,7 @@ class RunDisabled:
|
|
23
24
|
|
24
25
|
def __getattr__(self, name: str) -> Any:
|
25
26
|
deprecate.deprecate(
|
26
|
-
field_name=
|
27
|
+
field_name=Deprecated.run_disabled,
|
27
28
|
warning_message="RunDisabled is deprecated and is a no-op. "
|
28
29
|
'`wandb.init(mode="disabled")` now returns and instance of `wandb.sdk.wandb_run.Run`.',
|
29
30
|
)
|
wandb/sdk/lib/printer.py
CHANGED
@@ -20,6 +20,7 @@ import click
|
|
20
20
|
|
21
21
|
import wandb
|
22
22
|
from wandb.errors import term
|
23
|
+
from wandb.sdk import wandb_setup
|
23
24
|
|
24
25
|
from . import ipython, sparkline
|
25
26
|
|
@@ -98,12 +99,21 @@ _JUPYTER_PANEL_STYLES = """
|
|
98
99
|
"""
|
99
100
|
|
100
101
|
|
101
|
-
def new_printer() -> Printer:
|
102
|
-
"""Returns a
|
102
|
+
def new_printer(settings: wandb.Settings | None = None) -> Printer:
|
103
|
+
"""Returns a printer appropriate for the environment we're in.
|
104
|
+
|
105
|
+
Args:
|
106
|
+
settings: The settings of a run. If not provided and `wandb.setup()`
|
107
|
+
has been called, then global settings are used. Otherwise,
|
108
|
+
settings (such as silent mode) are ignored.
|
109
|
+
"""
|
110
|
+
if not settings and (singleton := wandb_setup.singleton()):
|
111
|
+
settings = singleton.settings
|
112
|
+
|
103
113
|
if ipython.in_jupyter():
|
104
|
-
return _PrinterJupyter()
|
114
|
+
return _PrinterJupyter(settings=settings)
|
105
115
|
else:
|
106
|
-
return _PrinterTerm()
|
116
|
+
return _PrinterTerm(settings=settings)
|
107
117
|
|
108
118
|
|
109
119
|
class Printer(abc.ABC):
|
@@ -281,13 +291,18 @@ class DynamicText(abc.ABC):
|
|
281
291
|
|
282
292
|
|
283
293
|
class _PrinterTerm(Printer):
|
284
|
-
def __init__(self) -> None:
|
294
|
+
def __init__(self, *, settings: wandb.Settings | None) -> None:
|
285
295
|
super().__init__()
|
296
|
+
self._settings = settings
|
286
297
|
self._progress = itertools.cycle(["-", "\\", "|", "/"])
|
287
298
|
|
288
299
|
@override
|
289
300
|
@contextlib.contextmanager
|
290
301
|
def dynamic_text(self) -> Iterator[DynamicText | None]:
|
302
|
+
if self._settings and self._settings.silent:
|
303
|
+
yield None
|
304
|
+
return
|
305
|
+
|
291
306
|
with term.dynamic_text() as handle:
|
292
307
|
if not handle:
|
293
308
|
yield None
|
@@ -301,6 +316,9 @@ class _PrinterTerm(Printer):
|
|
301
316
|
*,
|
302
317
|
level: str | int | None = None,
|
303
318
|
) -> None:
|
319
|
+
if self._settings and self._settings.silent:
|
320
|
+
return
|
321
|
+
|
304
322
|
text = "\n".join(text) if isinstance(text, (list, tuple)) else text
|
305
323
|
self._display_fn_mapping(level)(text)
|
306
324
|
|
@@ -323,10 +341,16 @@ class _PrinterTerm(Printer):
|
|
323
341
|
|
324
342
|
@override
|
325
343
|
def progress_update(self, text: str, percent_done: float | None = None) -> None:
|
344
|
+
if self._settings and self._settings.silent:
|
345
|
+
return
|
346
|
+
|
326
347
|
wandb.termlog(f"{next(self._progress)} {text}", newline=False)
|
327
348
|
|
328
349
|
@override
|
329
350
|
def progress_close(self, text: str | None = None) -> None:
|
351
|
+
if self._settings and self._settings.silent:
|
352
|
+
return
|
353
|
+
|
330
354
|
text = text or " " * 79
|
331
355
|
wandb.termlog(text)
|
332
356
|
|
@@ -422,8 +446,9 @@ class _DynamicTermText(DynamicText):
|
|
422
446
|
|
423
447
|
|
424
448
|
class _PrinterJupyter(Printer):
|
425
|
-
def __init__(self) -> None:
|
449
|
+
def __init__(self, *, settings: wandb.Settings | None) -> None:
|
426
450
|
super().__init__()
|
451
|
+
self._settings = settings
|
427
452
|
self._progress = ipython.jupyter_progress_bar()
|
428
453
|
|
429
454
|
from IPython import display
|
@@ -433,6 +458,10 @@ class _PrinterJupyter(Printer):
|
|
433
458
|
@override
|
434
459
|
@contextlib.contextmanager
|
435
460
|
def dynamic_text(self) -> Iterator[DynamicText | None]:
|
461
|
+
if self._settings and self._settings.silent:
|
462
|
+
yield None
|
463
|
+
return
|
464
|
+
|
436
465
|
handle = self._ipython_display.display(
|
437
466
|
self._ipython_display.HTML(""),
|
438
467
|
display_id=True,
|
@@ -452,7 +481,7 @@ class _PrinterJupyter(Printer):
|
|
452
481
|
*,
|
453
482
|
level: str | int | None = None,
|
454
483
|
) -> None:
|
455
|
-
if
|
484
|
+
if self._settings and self._settings.silent:
|
456
485
|
return
|
457
486
|
|
458
487
|
text = "<br>".join(text) if isinstance(text, (list, tuple)) else text
|
@@ -507,7 +536,7 @@ class _PrinterJupyter(Printer):
|
|
507
536
|
text: str,
|
508
537
|
percent_done: float | None = None,
|
509
538
|
) -> None:
|
510
|
-
if not self._progress:
|
539
|
+
if (self._settings and self._settings.silent) or not self._progress:
|
511
540
|
return
|
512
541
|
|
513
542
|
if percent_done is None:
|