wandb 0.19.1__py3-none-macosx_11_0_arm64.whl → 0.19.2__py3-none-macosx_11_0_arm64.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +3 -5
- wandb/agents/pyagent.py +1 -1
- wandb/apis/importers/wandb.py +1 -1
- wandb/apis/public/files.py +1 -1
- wandb/apis/public/jobs.py +1 -1
- wandb/apis/public/runs.py +2 -7
- wandb/apis/reports/v1/__init__.py +1 -1
- wandb/apis/reports/v2/__init__.py +1 -1
- wandb/apis/workspaces/__init__.py +1 -1
- wandb/bin/gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/beta.py +7 -4
- wandb/cli/cli.py +5 -7
- wandb/docker/__init__.py +4 -4
- wandb/integration/fastai/__init__.py +4 -6
- wandb/integration/keras/keras.py +5 -3
- wandb/integration/metaflow/metaflow.py +7 -7
- wandb/integration/prodigy/prodigy.py +3 -11
- wandb/integration/sagemaker/__init__.py +5 -3
- wandb/integration/sagemaker/config.py +17 -8
- wandb/integration/sagemaker/files.py +0 -1
- wandb/integration/sagemaker/resources.py +47 -18
- wandb/integration/torch/wandb_torch.py +1 -1
- wandb/proto/v3/wandb_internal_pb2.py +273 -235
- wandb/proto/v4/wandb_internal_pb2.py +222 -214
- wandb/proto/v5/wandb_internal_pb2.py +222 -214
- wandb/sdk/artifacts/artifact.py +3 -9
- wandb/sdk/backend/backend.py +1 -1
- wandb/sdk/data_types/base_types/wb_value.py +1 -1
- wandb/sdk/data_types/graph.py +2 -2
- wandb/sdk/data_types/saved_model.py +1 -1
- wandb/sdk/data_types/video.py +1 -1
- wandb/sdk/interface/interface.py +25 -25
- wandb/sdk/interface/interface_shared.py +21 -5
- wandb/sdk/internal/handler.py +19 -1
- wandb/sdk/internal/internal.py +1 -1
- wandb/sdk/internal/internal_api.py +4 -5
- wandb/sdk/internal/sample.py +2 -2
- wandb/sdk/internal/sender.py +1 -2
- wandb/sdk/internal/settings_static.py +3 -1
- wandb/sdk/internal/system/assets/disk.py +4 -4
- wandb/sdk/internal/system/assets/gpu.py +1 -1
- wandb/sdk/internal/system/assets/memory.py +1 -1
- wandb/sdk/internal/system/system_info.py +1 -1
- wandb/sdk/internal/system/system_monitor.py +3 -1
- wandb/sdk/internal/tb_watcher.py +1 -1
- wandb/sdk/launch/_project_spec.py +3 -3
- wandb/sdk/launch/builder/abstract.py +1 -1
- wandb/sdk/lib/apikey.py +2 -3
- wandb/sdk/lib/fsm.py +1 -1
- wandb/sdk/lib/gitlib.py +1 -1
- wandb/sdk/lib/gql_request.py +1 -1
- wandb/sdk/lib/interrupt.py +37 -0
- wandb/sdk/lib/lazyloader.py +1 -1
- wandb/sdk/lib/service_connection.py +1 -1
- wandb/sdk/lib/telemetry.py +1 -1
- wandb/sdk/service/_startup_debug.py +1 -1
- wandb/sdk/service/server_sock.py +3 -2
- wandb/sdk/service/service.py +1 -1
- wandb/sdk/service/streams.py +19 -17
- wandb/sdk/verify/verify.py +13 -13
- wandb/sdk/wandb_init.py +95 -104
- wandb/sdk/wandb_login.py +1 -1
- wandb/sdk/wandb_metadata.py +547 -0
- wandb/sdk/wandb_run.py +127 -35
- wandb/sdk/wandb_settings.py +5 -36
- wandb/sdk/wandb_setup.py +83 -82
- wandb/sdk/wandb_sweep.py +2 -2
- wandb/sdk/wandb_sync.py +15 -18
- wandb/sync/sync.py +10 -10
- wandb/util.py +11 -3
- wandb/wandb_agent.py +11 -16
- wandb/wandb_controller.py +7 -7
- {wandb-0.19.1.dist-info → wandb-0.19.2.dist-info}/METADATA +3 -2
- {wandb-0.19.1.dist-info → wandb-0.19.2.dist-info}/RECORD +79 -77
- {wandb-0.19.1.dist-info → wandb-0.19.2.dist-info}/WHEEL +0 -0
- {wandb-0.19.1.dist-info → wandb-0.19.2.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.1.dist-info → wandb-0.19.2.dist-info}/licenses/LICENSE +0 -0
wandb/__init__.py
CHANGED
wandb/__init__.pyi
CHANGED
@@ -103,7 +103,7 @@ if TYPE_CHECKING:
|
|
103
103
|
import wandb
|
104
104
|
from wandb.plot import CustomChart
|
105
105
|
|
106
|
-
__version__: str = "0.19.
|
106
|
+
__version__: str = "0.19.2"
|
107
107
|
|
108
108
|
run: Run | None
|
109
109
|
config: wandb_config.Config
|
@@ -114,9 +114,7 @@ _sentry: Sentry
|
|
114
114
|
api: InternalApi
|
115
115
|
patched: Dict[str, List[Callable]]
|
116
116
|
|
117
|
-
def setup(
|
118
|
-
settings: Settings | None = None,
|
119
|
-
) -> Optional[_WandbSetup]:
|
117
|
+
def setup(settings: Settings | None = None) -> _WandbSetup:
|
120
118
|
"""Prepares W&B for use in the current process and its children.
|
121
119
|
|
122
120
|
You can usually ignore this as it is implicitly called by `wandb.init()`.
|
@@ -265,7 +263,7 @@ def init(
|
|
265
263
|
entity: The username or team name under which the runs will be logged.
|
266
264
|
The entity must already exist, so ensure you’ve created your account
|
267
265
|
or team in the UI before starting to log runs. If not specified, the
|
268
|
-
run will default your
|
266
|
+
run will default your default entity. To change the default entity,
|
269
267
|
go to [your settings](https://wandb.ai/settings) and update the
|
270
268
|
"Default location to create new projects" under "Default team".
|
271
269
|
project: The name of the project under which this run will be logged.
|
wandb/agents/pyagent.py
CHANGED
@@ -297,7 +297,7 @@ class Agent:
|
|
297
297
|
sweep_param_path, job.config
|
298
298
|
)
|
299
299
|
os.environ[wandb.env.SWEEP_ID] = self._sweep_id
|
300
|
-
wandb.
|
300
|
+
wandb.teardown()
|
301
301
|
|
302
302
|
wandb.termlog(f"Agent Starting Run: {run_id} with config:")
|
303
303
|
for k, v in job.config.items():
|
wandb/apis/importers/wandb.py
CHANGED
@@ -1483,7 +1483,7 @@ def _get_run_or_dummy_from_art(art: Artifact, api=None):
|
|
1483
1483
|
run = art.logged_by()
|
1484
1484
|
except ValueError as e:
|
1485
1485
|
logger.warn(
|
1486
|
-
f"Can't log artifact because run
|
1486
|
+
f"Can't log artifact because run doesn't exist, {art=}, {run=}, {e=}"
|
1487
1487
|
)
|
1488
1488
|
|
1489
1489
|
if run is not None:
|
wandb/apis/public/files.py
CHANGED
@@ -233,7 +233,7 @@ class File(Attrs):
|
|
233
233
|
def _server_accepts_project_id_for_delete_file(self) -> bool:
|
234
234
|
"""Returns True if the server supports deleting files with a projectId.
|
235
235
|
|
236
|
-
This check is done by utilizing GraphQL introspection in the
|
236
|
+
This check is done by utilizing GraphQL introspection in the available fields on the DeleteFiles API.
|
237
237
|
"""
|
238
238
|
query_string = """
|
239
239
|
query ProbeDeleteFilesProjectIdInput {
|
wandb/apis/public/jobs.py
CHANGED
wandb/apis/public/runs.py
CHANGED
@@ -704,7 +704,7 @@ class Run(Attrs):
|
|
704
704
|
if pd:
|
705
705
|
lines = pd.DataFrame.from_records(lines)
|
706
706
|
else:
|
707
|
-
|
707
|
+
wandb.termwarn("Unable to load pandas, call history with pandas=False")
|
708
708
|
return lines
|
709
709
|
|
710
710
|
@normalize_exceptions
|
@@ -908,7 +908,7 @@ class Run(Attrs):
|
|
908
908
|
def _server_provides_internal_id_for_project(self) -> bool:
|
909
909
|
"""Returns True if the server allows us to query the internalId field for a project.
|
910
910
|
|
911
|
-
This check is done by utilizing GraphQL introspection in the
|
911
|
+
This check is done by utilizing GraphQL introspection in the available fields on the Project type.
|
912
912
|
"""
|
913
913
|
query_string = """
|
914
914
|
query ProbeProjectInput {
|
@@ -924,11 +924,6 @@ class Run(Attrs):
|
|
924
924
|
if self.server_provides_internal_id_field is None:
|
925
925
|
query = gql(query_string)
|
926
926
|
res = self.client.execute(query)
|
927
|
-
print(
|
928
|
-
"internalId"
|
929
|
-
in [x["name"] for x in (res.get("ProjectType", {}).get("fields", [{}]))]
|
930
|
-
)
|
931
|
-
|
932
927
|
self.server_provides_internal_id_field = "internalId" in [
|
933
928
|
x["name"] for x in (res.get("ProjectType", {}).get("fields", [{}]))
|
934
929
|
]
|
@@ -4,5 +4,5 @@ try:
|
|
4
4
|
from wandb_workspaces.reports.v1 import * # noqa: F403
|
5
5
|
except ImportError:
|
6
6
|
wandb.termerror(
|
7
|
-
"Failed to import wandb_workspaces. To edit reports
|
7
|
+
"Failed to import wandb_workspaces. To edit reports programmatically, please install it using `pip install wandb[workspaces]`."
|
8
8
|
)
|
@@ -4,5 +4,5 @@ try:
|
|
4
4
|
from wandb_workspaces.reports.v2 import * # noqa: F403
|
5
5
|
except ImportError:
|
6
6
|
wandb.termerror(
|
7
|
-
"Failed to import wandb_workspaces. To edit reports
|
7
|
+
"Failed to import wandb_workspaces. To edit reports programmatically, please install it using `pip install wandb[workspaces]`."
|
8
8
|
)
|
@@ -4,5 +4,5 @@ try:
|
|
4
4
|
from wandb_workspaces.workspaces import * # noqa: F403
|
5
5
|
except ImportError:
|
6
6
|
wandb.termerror(
|
7
|
-
"Failed to import wandb_workspaces. To edit workspaces
|
7
|
+
"Failed to import wandb_workspaces. To edit workspaces programmatically, please install it using `pip install wandb[workspaces]`."
|
8
8
|
)
|
wandb/bin/gpu_stats
CHANGED
Binary file
|
wandb/bin/wandb-core
CHANGED
Binary file
|
wandb/cli/beta.py
CHANGED
@@ -12,6 +12,7 @@ import click
|
|
12
12
|
|
13
13
|
import wandb
|
14
14
|
from wandb.errors import UsageError, WandbCoreNotAvailableError
|
15
|
+
from wandb.sdk.wandb_sync import _sync
|
15
16
|
from wandb.util import get_core_path
|
16
17
|
|
17
18
|
|
@@ -108,7 +109,9 @@ def sync_beta( # noqa: C901
|
|
108
109
|
continue
|
109
110
|
wandb_files = [p for p in d.glob("*.wandb") if p.is_file()]
|
110
111
|
if len(wandb_files) > 1:
|
111
|
-
|
112
|
+
wandb.termwarn(
|
113
|
+
f"Multiple wandb files found in directory {d}, skipping"
|
114
|
+
)
|
112
115
|
elif len(wandb_files) == 1:
|
113
116
|
paths.add(d)
|
114
117
|
else:
|
@@ -128,7 +131,7 @@ def sync_beta( # noqa: C901
|
|
128
131
|
for path in paths:
|
129
132
|
wandb_synced_files = [p for p in path.glob("*.wandb.synced") if p.is_file()]
|
130
133
|
if len(wandb_synced_files) > 1:
|
131
|
-
|
134
|
+
wandb.termwarn(
|
132
135
|
f"Multiple wandb.synced files found in directory {path}, skipping"
|
133
136
|
)
|
134
137
|
elif len(wandb_synced_files) == 1:
|
@@ -151,7 +154,7 @@ def sync_beta( # noqa: C901
|
|
151
154
|
if dry_run:
|
152
155
|
return
|
153
156
|
|
154
|
-
wandb.
|
157
|
+
wandb.setup()
|
155
158
|
|
156
159
|
# TODO: make it thread-safe in the Rust code
|
157
160
|
with concurrent.futures.ProcessPoolExecutor(
|
@@ -162,7 +165,7 @@ def sync_beta( # noqa: C901
|
|
162
165
|
# we already know there is only one wandb file in the directory
|
163
166
|
wandb_file = [p for p in path.glob("*.wandb") if p.is_file()][0]
|
164
167
|
future = executor.submit(
|
165
|
-
|
168
|
+
_sync,
|
166
169
|
wandb_file,
|
167
170
|
run_id=run_id,
|
168
171
|
project=project,
|
wandb/cli/cli.py
CHANGED
@@ -125,7 +125,7 @@ def _get_cling_api(reset=None):
|
|
125
125
|
global _api
|
126
126
|
if reset:
|
127
127
|
_api = None
|
128
|
-
|
128
|
+
wandb.teardown()
|
129
129
|
if _api is None:
|
130
130
|
# TODO(jhr): make a settings object that is better for non runs.
|
131
131
|
# only override the necessary setting
|
@@ -2437,7 +2437,7 @@ def ls(path, type):
|
|
2437
2437
|
per_page=1,
|
2438
2438
|
)
|
2439
2439
|
latest = next(versions)
|
2440
|
-
|
2440
|
+
wandb.termlog(
|
2441
2441
|
"{:<15s}{:<15s}{:>15s} {:<20s}".format(
|
2442
2442
|
kind.type,
|
2443
2443
|
latest.updated_at,
|
@@ -2463,7 +2463,7 @@ def cleanup(target_size, remove_temp):
|
|
2463
2463
|
target_size = util.from_human_size(target_size)
|
2464
2464
|
cache = get_artifact_file_cache()
|
2465
2465
|
reclaimed_bytes = cache.cleanup(target_size, remove_temp)
|
2466
|
-
|
2466
|
+
wandb.termlog(f"Reclaimed {util.to_human_size(reclaimed_bytes)} of space")
|
2467
2467
|
|
2468
2468
|
|
2469
2469
|
@cli.command(context_settings=CONTEXT, help="Pull files from Weights & Biases")
|
@@ -2664,7 +2664,6 @@ Run `git clone {}` and restore from there or pass the --no-git flag.""".format(r
|
|
2664
2664
|
def online():
|
2665
2665
|
api = InternalApi()
|
2666
2666
|
try:
|
2667
|
-
api.clear_setting("disabled", persist=True)
|
2668
2667
|
api.clear_setting("mode", persist=True)
|
2669
2668
|
except configparser.Error:
|
2670
2669
|
pass
|
@@ -2678,7 +2677,6 @@ def online():
|
|
2678
2677
|
def offline():
|
2679
2678
|
api = InternalApi()
|
2680
2679
|
try:
|
2681
|
-
api.set_setting("disabled", "true", persist=True)
|
2682
2680
|
api.set_setting("mode", "offline", persist=True)
|
2683
2681
|
click.echo(
|
2684
2682
|
"W&B offline. Running your script from this directory will only write metadata locally. Use wandb disabled to completely turn off W&B."
|
@@ -2765,13 +2763,13 @@ def verify(host):
|
|
2765
2763
|
reinit = False
|
2766
2764
|
if host is None:
|
2767
2765
|
host = api.settings("base_url")
|
2768
|
-
|
2766
|
+
wandb.termlog(f"Default host selected: {host}")
|
2769
2767
|
# if the given host does not match the default host, re-run init
|
2770
2768
|
elif host != api.settings("base_url"):
|
2771
2769
|
reinit = True
|
2772
2770
|
|
2773
2771
|
tmp_dir = tempfile.mkdtemp()
|
2774
|
-
|
2772
|
+
wandb.termlog(
|
2775
2773
|
"Find detailed logs for this test at: {}".format(os.path.join(tmp_dir, "wandb"))
|
2776
2774
|
)
|
2777
2775
|
os.chdir(tmp_dir)
|
wandb/docker/__init__.py
CHANGED
@@ -62,7 +62,7 @@ def shell(cmd: List[str]) -> Optional[str]:
|
|
62
62
|
.strip()
|
63
63
|
)
|
64
64
|
except subprocess.CalledProcessError as e:
|
65
|
-
print(e)
|
65
|
+
print(e) # noqa: T201
|
66
66
|
return None
|
67
67
|
|
68
68
|
|
@@ -140,12 +140,12 @@ def run_command_live_output(args: List[Any]) -> str:
|
|
140
140
|
break
|
141
141
|
index = chunk.find(b"\r")
|
142
142
|
if index != -1:
|
143
|
-
print(chunk.decode(), end="")
|
143
|
+
print(chunk.decode(), end="") # noqa: T201
|
144
144
|
else:
|
145
145
|
stdout += chunk.decode()
|
146
|
-
print(chunk.decode(), end="\r")
|
146
|
+
print(chunk.decode(), end="\r") # noqa: T201
|
147
147
|
|
148
|
-
print(stdout)
|
148
|
+
print(stdout) # noqa: T201
|
149
149
|
|
150
150
|
return_code = process.wait()
|
151
151
|
if return_code != 0:
|
@@ -54,7 +54,7 @@ try:
|
|
54
54
|
matplotlib.use("Agg") # non-interactive backend (avoid tkinter issues)
|
55
55
|
import matplotlib.pyplot as plt
|
56
56
|
except ImportError:
|
57
|
-
|
57
|
+
wandb.termwarn("matplotlib required if logging sample image predictions")
|
58
58
|
|
59
59
|
|
60
60
|
class WandbCallback(TrackerCallback):
|
@@ -134,10 +134,8 @@ class WandbCallback(TrackerCallback):
|
|
134
134
|
# Adapted from fast.ai "SaveModelCallback"
|
135
135
|
current = self.get_monitor_value()
|
136
136
|
if current is not None and self.operator(current, self.best):
|
137
|
-
|
138
|
-
"Better model found at epoch {} with {} value: {}."
|
139
|
-
epoch, self.monitor, current
|
140
|
-
)
|
137
|
+
wandb.termlog(
|
138
|
+
f"Better model found at epoch {epoch} with {self.monitor} value: {current}."
|
141
139
|
)
|
142
140
|
self.best = current
|
143
141
|
|
@@ -173,7 +171,7 @@ class WandbCallback(TrackerCallback):
|
|
173
171
|
if self.model_path.is_file():
|
174
172
|
with self.model_path.open("rb") as model_file:
|
175
173
|
self.learn.load(model_file, purge=False)
|
176
|
-
|
174
|
+
wandb.termlog(f"Loaded best saved model from {self.model_path}")
|
177
175
|
|
178
176
|
def _wandb_log_predictions(self) -> None:
|
179
177
|
"""Log prediction samples."""
|
wandb/integration/keras/keras.py
CHANGED
@@ -509,7 +509,9 @@ class WandbCallback(tf.keras.callbacks.Callback):
|
|
509
509
|
|
510
510
|
# From Keras
|
511
511
|
if mode not in ["auto", "min", "max"]:
|
512
|
-
|
512
|
+
wandb.termwarn(
|
513
|
+
f"WandbCallback mode {mode} is unknown, fallback to auto mode."
|
514
|
+
)
|
513
515
|
mode = "auto"
|
514
516
|
|
515
517
|
if mode == "min":
|
@@ -632,7 +634,7 @@ class WandbCallback(tf.keras.callbacks.Callback):
|
|
632
634
|
)
|
633
635
|
wandb.run.summary["{}{}".format(self.log_best_prefix, "epoch")] = epoch
|
634
636
|
if self.verbose and not self.save_model:
|
635
|
-
|
637
|
+
wandb.termlog(
|
636
638
|
f"Epoch {epoch:05d}: {self.monitor} improved from {self.best:.5f} to {self.current:.5f}"
|
637
639
|
)
|
638
640
|
if self.save_model:
|
@@ -1003,7 +1005,7 @@ class WandbCallback(tf.keras.callbacks.Callback):
|
|
1003
1005
|
if wandb.run.disabled:
|
1004
1006
|
return
|
1005
1007
|
if self.verbose > 0:
|
1006
|
-
|
1008
|
+
wandb.termlog(
|
1007
1009
|
f"Epoch {epoch:05d}: {self.monitor} improved from {self.best:.5f} to {self.current:.5f}, "
|
1008
1010
|
f"saving model to {self.filepath}"
|
1009
1011
|
)
|
@@ -73,8 +73,8 @@ try:
|
|
73
73
|
wandb.termlog(f"Logging artifact: {name} ({type(data)})")
|
74
74
|
|
75
75
|
except ImportError:
|
76
|
-
|
77
|
-
"
|
76
|
+
wandb.termwarn(
|
77
|
+
"`pandas` not installed >> @wandb_log(datasets=True) may not auto log your dataset!"
|
78
78
|
)
|
79
79
|
|
80
80
|
try:
|
@@ -119,8 +119,8 @@ try:
|
|
119
119
|
wandb.termlog(f"Logging artifact: {name} ({type(data)})")
|
120
120
|
|
121
121
|
except ImportError:
|
122
|
-
|
123
|
-
"
|
122
|
+
wandb.termwarn(
|
123
|
+
"`pytorch` not installed >> @wandb_log(models=True) may not auto log your model!"
|
124
124
|
)
|
125
125
|
|
126
126
|
try:
|
@@ -164,8 +164,8 @@ try:
|
|
164
164
|
wandb.termlog(f"Logging artifact: {name} ({type(data)})")
|
165
165
|
|
166
166
|
except ImportError:
|
167
|
-
|
168
|
-
"
|
167
|
+
wandb.termwarn(
|
168
|
+
"`sklearn` not installed >> @wandb_log(models=True) may not auto log your model!"
|
169
169
|
)
|
170
170
|
|
171
171
|
|
@@ -245,7 +245,7 @@ def wandb_use(name: str, data, *args, **kwargs):
|
|
245
245
|
try:
|
246
246
|
return _wandb_use(name, data, *args, **kwargs)
|
247
247
|
except wandb.CommError:
|
248
|
-
|
248
|
+
wandb.termwarn(
|
249
249
|
f"This artifact ({name}, {type(data)}) does not exist in the wandb datastore!"
|
250
250
|
f"If you created an instance inline (e.g. sklearn.ensemble.RandomForestClassifier), then you can safely ignore this"
|
251
251
|
f"Otherwise you may want to check your internet connection!"
|
@@ -237,11 +237,7 @@ def create_table(data):
|
|
237
237
|
im = Image.open(urllib.request.urlopen(document["image"]))
|
238
238
|
document["image_visual"] = wandb.Image(im)
|
239
239
|
except urllib.error.URLError:
|
240
|
-
|
241
|
-
"Warning: Image URL "
|
242
|
-
+ str(document["image"])
|
243
|
-
+ " is invalid."
|
244
|
-
)
|
240
|
+
wandb.termwarn(f"Image URL {document['image']} is invalid.")
|
245
241
|
document["image_visual"] = None
|
246
242
|
elif isbase64:
|
247
243
|
# is base64 uri
|
@@ -252,11 +248,7 @@ def create_table(data):
|
|
252
248
|
im = Image.open(buf)
|
253
249
|
document["image_visual"] = wandb.Image(im)
|
254
250
|
except base64.binascii.Error:
|
255
|
-
|
256
|
-
"Warning: Base64 string "
|
257
|
-
+ str(document["image"])
|
258
|
-
+ " is invalid."
|
259
|
-
)
|
251
|
+
wandb.termwarn(f"Base64 string {document['image']} is invalid.")
|
260
252
|
document["image_visual"] = None
|
261
253
|
else:
|
262
254
|
# is data path
|
@@ -296,4 +288,4 @@ def upload_dataset(dataset_name):
|
|
296
288
|
standardize(data[i], schema, array_dict_types)
|
297
289
|
table = create_table(data)
|
298
290
|
wandb.log({dataset_name: table})
|
299
|
-
|
291
|
+
wandb.termlog(f"Prodigy dataset `{dataset_name}` uploaded.")
|
@@ -1,12 +1,14 @@
|
|
1
1
|
"""wandb integration sagemaker module."""
|
2
2
|
|
3
3
|
from .auth import sagemaker_auth
|
4
|
-
from .config import parse_sm_config
|
5
|
-
from .resources import
|
4
|
+
from .config import is_using_sagemaker, parse_sm_config
|
5
|
+
from .resources import parse_sm_secrets, set_global_settings, set_run_id
|
6
6
|
|
7
7
|
__all__ = [
|
8
8
|
"sagemaker_auth",
|
9
|
+
"is_using_sagemaker",
|
9
10
|
"parse_sm_config",
|
10
11
|
"parse_sm_secrets",
|
11
|
-
"
|
12
|
+
"set_global_settings",
|
13
|
+
"set_run_id",
|
12
14
|
]
|
@@ -1,13 +1,23 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import json
|
2
4
|
import os
|
3
5
|
import re
|
4
6
|
import warnings
|
5
|
-
from typing import Any
|
7
|
+
from typing import Any
|
6
8
|
|
7
9
|
from . import files as sm_files
|
8
10
|
|
9
11
|
|
10
|
-
def
|
12
|
+
def is_using_sagemaker() -> bool:
|
13
|
+
"""Returns whether we're in a SageMaker environment."""
|
14
|
+
return (
|
15
|
+
os.path.exists(sm_files.SM_PARAM_CONFIG) #
|
16
|
+
or "SM_TRAINING_ENV" in os.environ
|
17
|
+
)
|
18
|
+
|
19
|
+
|
20
|
+
def parse_sm_config() -> dict[str, Any]:
|
11
21
|
"""Parses SageMaker configuration.
|
12
22
|
|
13
23
|
Returns:
|
@@ -23,9 +33,7 @@ def parse_sm_config() -> Dict[str, Any]:
|
|
23
33
|
"""
|
24
34
|
conf = {}
|
25
35
|
|
26
|
-
if os.path.exists(sm_files.SM_PARAM_CONFIG)
|
27
|
-
sm_files.SM_RESOURCE_CONFIG
|
28
|
-
):
|
36
|
+
if os.path.exists(sm_files.SM_PARAM_CONFIG):
|
29
37
|
conf["sagemaker_training_job_name"] = os.getenv("TRAINING_JOB_NAME")
|
30
38
|
|
31
39
|
# Hyperparameter searches quote configs...
|
@@ -38,12 +46,13 @@ def parse_sm_config() -> Dict[str, Any]:
|
|
38
46
|
cast = float(cast)
|
39
47
|
conf[key] = cast
|
40
48
|
|
41
|
-
if
|
49
|
+
if env := os.environ.get("SM_TRAINING_ENV"):
|
42
50
|
try:
|
43
|
-
conf
|
51
|
+
conf.update(json.loads(env))
|
44
52
|
except json.JSONDecodeError:
|
45
53
|
warnings.warn(
|
46
|
-
"Failed to parse SM_TRAINING_ENV not valid JSON string",
|
54
|
+
"Failed to parse SM_TRAINING_ENV not valid JSON string",
|
55
|
+
stacklevel=2,
|
47
56
|
)
|
48
57
|
|
49
58
|
return conf
|
@@ -1,13 +1,58 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import os
|
2
4
|
import secrets
|
3
5
|
import socket
|
4
6
|
import string
|
5
|
-
from typing import Dict, Tuple
|
6
7
|
|
8
|
+
import wandb
|
9
|
+
|
10
|
+
from . import config
|
7
11
|
from . import files as sm_files
|
8
12
|
|
9
13
|
|
10
|
-
def
|
14
|
+
def set_run_id(run_settings: wandb.Settings) -> bool:
|
15
|
+
"""Set a run ID and group when using SageMaker.
|
16
|
+
|
17
|
+
Returns whether the ID and group were updated.
|
18
|
+
"""
|
19
|
+
# Added in https://github.com/wandb/wandb/pull/3290.
|
20
|
+
#
|
21
|
+
# Prevents SageMaker from overriding the run ID configured
|
22
|
+
# in environment variables. Note, however, that it will still
|
23
|
+
# override a run ID passed explicitly to `wandb.init()`.
|
24
|
+
if os.getenv("WANDB_RUN_ID"):
|
25
|
+
return False
|
26
|
+
|
27
|
+
run_group = os.getenv("TRAINING_JOB_NAME")
|
28
|
+
if not run_group:
|
29
|
+
return False
|
30
|
+
|
31
|
+
alphanumeric = string.ascii_lowercase + string.digits
|
32
|
+
random = "".join(secrets.choice(alphanumeric) for _ in range(6))
|
33
|
+
|
34
|
+
host = os.getenv("CURRENT_HOST", socket.gethostname())
|
35
|
+
|
36
|
+
run_settings.run_id = f"{run_group}-{random}-{host}"
|
37
|
+
run_settings.run_group = run_group
|
38
|
+
return True
|
39
|
+
|
40
|
+
|
41
|
+
def set_global_settings(settings: wandb.Settings) -> None:
|
42
|
+
"""Set global W&B settings based on the SageMaker environment."""
|
43
|
+
if env := parse_sm_secrets():
|
44
|
+
settings.update_from_env_vars(env)
|
45
|
+
|
46
|
+
# The SageMaker config may contain an API key, in which case it
|
47
|
+
# takes precedence over the value in the secrets. It's unclear
|
48
|
+
# whether this is by design, or by accident; we keep it for
|
49
|
+
# backward compatibility for now.
|
50
|
+
sm_config = config.parse_sm_config()
|
51
|
+
if api_key := sm_config.get("wandb_api_key"):
|
52
|
+
settings.api_key = api_key
|
53
|
+
|
54
|
+
|
55
|
+
def parse_sm_secrets() -> dict[str, str]:
|
11
56
|
"""We read our api_key from secrets.env in SageMaker."""
|
12
57
|
env_dict = dict()
|
13
58
|
# Set secret variables
|
@@ -16,19 +61,3 @@ def parse_sm_secrets() -> Dict[str, str]:
|
|
16
61
|
key, val = line.strip().split("=", 1)
|
17
62
|
env_dict[key] = val
|
18
63
|
return env_dict
|
19
|
-
|
20
|
-
|
21
|
-
def parse_sm_resources() -> Tuple[Dict[str, str], Dict[str, str]]:
|
22
|
-
run_dict = dict()
|
23
|
-
run_id = os.getenv("TRAINING_JOB_NAME")
|
24
|
-
|
25
|
-
if run_id and os.getenv("WANDB_RUN_ID") is None:
|
26
|
-
suffix = "".join(
|
27
|
-
secrets.choice(string.ascii_lowercase + string.digits) for _ in range(6)
|
28
|
-
)
|
29
|
-
run_dict["run_id"] = "-".join(
|
30
|
-
[run_id, suffix, os.getenv("CURRENT_HOST", socket.gethostname())]
|
31
|
-
)
|
32
|
-
run_dict["run_group"] = os.getenv("TRAINING_JOB_NAME")
|
33
|
-
env_dict = parse_sm_secrets()
|
34
|
-
return run_dict, env_dict
|
@@ -441,7 +441,7 @@ class TorchGraph(wandb.data_types.Graph):
|
|
441
441
|
decoder.weight encoder
|
442
442
|
decoder.bias decoder
|
443
443
|
"""
|
444
|
-
# TODO: We're currently not using this, but I left it here
|
444
|
+
# TODO: We're currently not using this, but I left it here in case we want to resurrect! - CVP
|
445
445
|
torch = util.get_module("torch", "Could not import torch")
|
446
446
|
|
447
447
|
module_nodes_by_hash = {id(n): n for n in module_graph.nodes}
|