wandb 0.19.1__py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl → 0.19.2__py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +3 -5
- wandb/agents/pyagent.py +1 -1
- wandb/apis/importers/wandb.py +1 -1
- wandb/apis/public/files.py +1 -1
- wandb/apis/public/jobs.py +1 -1
- wandb/apis/public/runs.py +2 -7
- wandb/apis/reports/v1/__init__.py +1 -1
- wandb/apis/reports/v2/__init__.py +1 -1
- wandb/apis/workspaces/__init__.py +1 -1
- wandb/bin/gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/beta.py +7 -4
- wandb/cli/cli.py +5 -7
- wandb/docker/__init__.py +4 -4
- wandb/integration/fastai/__init__.py +4 -6
- wandb/integration/keras/keras.py +5 -3
- wandb/integration/metaflow/metaflow.py +7 -7
- wandb/integration/prodigy/prodigy.py +3 -11
- wandb/integration/sagemaker/__init__.py +5 -3
- wandb/integration/sagemaker/config.py +17 -8
- wandb/integration/sagemaker/files.py +0 -1
- wandb/integration/sagemaker/resources.py +47 -18
- wandb/integration/torch/wandb_torch.py +1 -1
- wandb/proto/v3/wandb_internal_pb2.py +273 -235
- wandb/proto/v4/wandb_internal_pb2.py +222 -214
- wandb/proto/v5/wandb_internal_pb2.py +222 -214
- wandb/sdk/artifacts/artifact.py +3 -9
- wandb/sdk/backend/backend.py +1 -1
- wandb/sdk/data_types/base_types/wb_value.py +1 -1
- wandb/sdk/data_types/graph.py +2 -2
- wandb/sdk/data_types/saved_model.py +1 -1
- wandb/sdk/data_types/video.py +1 -1
- wandb/sdk/interface/interface.py +25 -25
- wandb/sdk/interface/interface_shared.py +21 -5
- wandb/sdk/internal/handler.py +19 -1
- wandb/sdk/internal/internal.py +1 -1
- wandb/sdk/internal/internal_api.py +4 -5
- wandb/sdk/internal/sample.py +2 -2
- wandb/sdk/internal/sender.py +1 -2
- wandb/sdk/internal/settings_static.py +3 -1
- wandb/sdk/internal/system/assets/disk.py +4 -4
- wandb/sdk/internal/system/assets/gpu.py +1 -1
- wandb/sdk/internal/system/assets/memory.py +1 -1
- wandb/sdk/internal/system/system_info.py +1 -1
- wandb/sdk/internal/system/system_monitor.py +3 -1
- wandb/sdk/internal/tb_watcher.py +1 -1
- wandb/sdk/launch/_project_spec.py +3 -3
- wandb/sdk/launch/builder/abstract.py +1 -1
- wandb/sdk/lib/apikey.py +2 -3
- wandb/sdk/lib/fsm.py +1 -1
- wandb/sdk/lib/gitlib.py +1 -1
- wandb/sdk/lib/gql_request.py +1 -1
- wandb/sdk/lib/interrupt.py +37 -0
- wandb/sdk/lib/lazyloader.py +1 -1
- wandb/sdk/lib/service_connection.py +1 -1
- wandb/sdk/lib/telemetry.py +1 -1
- wandb/sdk/service/_startup_debug.py +1 -1
- wandb/sdk/service/server_sock.py +3 -2
- wandb/sdk/service/service.py +1 -1
- wandb/sdk/service/streams.py +19 -17
- wandb/sdk/verify/verify.py +13 -13
- wandb/sdk/wandb_init.py +95 -104
- wandb/sdk/wandb_login.py +1 -1
- wandb/sdk/wandb_metadata.py +547 -0
- wandb/sdk/wandb_run.py +127 -35
- wandb/sdk/wandb_settings.py +5 -36
- wandb/sdk/wandb_setup.py +83 -82
- wandb/sdk/wandb_sweep.py +2 -2
- wandb/sdk/wandb_sync.py +15 -18
- wandb/sync/sync.py +10 -10
- wandb/util.py +11 -3
- wandb/wandb_agent.py +11 -16
- wandb/wandb_controller.py +7 -7
- {wandb-0.19.1.dist-info → wandb-0.19.2.dist-info}/METADATA +5 -3
- {wandb-0.19.1.dist-info → wandb-0.19.2.dist-info}/RECORD +79 -77
- {wandb-0.19.1.dist-info → wandb-0.19.2.dist-info}/WHEEL +1 -1
- {wandb-0.19.1.dist-info → wandb-0.19.2.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.1.dist-info → wandb-0.19.2.dist-info}/licenses/LICENSE +0 -0
wandb/__init__.py
CHANGED
wandb/__init__.pyi
CHANGED
@@ -103,7 +103,7 @@ if TYPE_CHECKING:
|
|
103
103
|
import wandb
|
104
104
|
from wandb.plot import CustomChart
|
105
105
|
|
106
|
-
__version__: str = "0.19.
|
106
|
+
__version__: str = "0.19.2"
|
107
107
|
|
108
108
|
run: Run | None
|
109
109
|
config: wandb_config.Config
|
@@ -114,9 +114,7 @@ _sentry: Sentry
|
|
114
114
|
api: InternalApi
|
115
115
|
patched: Dict[str, List[Callable]]
|
116
116
|
|
117
|
-
def setup(
|
118
|
-
settings: Settings | None = None,
|
119
|
-
) -> Optional[_WandbSetup]:
|
117
|
+
def setup(settings: Settings | None = None) -> _WandbSetup:
|
120
118
|
"""Prepares W&B for use in the current process and its children.
|
121
119
|
|
122
120
|
You can usually ignore this as it is implicitly called by `wandb.init()`.
|
@@ -265,7 +263,7 @@ def init(
|
|
265
263
|
entity: The username or team name under which the runs will be logged.
|
266
264
|
The entity must already exist, so ensure you’ve created your account
|
267
265
|
or team in the UI before starting to log runs. If not specified, the
|
268
|
-
run will default your
|
266
|
+
run will default your default entity. To change the default entity,
|
269
267
|
go to [your settings](https://wandb.ai/settings) and update the
|
270
268
|
"Default location to create new projects" under "Default team".
|
271
269
|
project: The name of the project under which this run will be logged.
|
wandb/agents/pyagent.py
CHANGED
@@ -297,7 +297,7 @@ class Agent:
|
|
297
297
|
sweep_param_path, job.config
|
298
298
|
)
|
299
299
|
os.environ[wandb.env.SWEEP_ID] = self._sweep_id
|
300
|
-
wandb.
|
300
|
+
wandb.teardown()
|
301
301
|
|
302
302
|
wandb.termlog(f"Agent Starting Run: {run_id} with config:")
|
303
303
|
for k, v in job.config.items():
|
wandb/apis/importers/wandb.py
CHANGED
@@ -1483,7 +1483,7 @@ def _get_run_or_dummy_from_art(art: Artifact, api=None):
|
|
1483
1483
|
run = art.logged_by()
|
1484
1484
|
except ValueError as e:
|
1485
1485
|
logger.warn(
|
1486
|
-
f"Can't log artifact because run
|
1486
|
+
f"Can't log artifact because run doesn't exist, {art=}, {run=}, {e=}"
|
1487
1487
|
)
|
1488
1488
|
|
1489
1489
|
if run is not None:
|
wandb/apis/public/files.py
CHANGED
@@ -233,7 +233,7 @@ class File(Attrs):
|
|
233
233
|
def _server_accepts_project_id_for_delete_file(self) -> bool:
|
234
234
|
"""Returns True if the server supports deleting files with a projectId.
|
235
235
|
|
236
|
-
This check is done by utilizing GraphQL introspection in the
|
236
|
+
This check is done by utilizing GraphQL introspection in the available fields on the DeleteFiles API.
|
237
237
|
"""
|
238
238
|
query_string = """
|
239
239
|
query ProbeDeleteFilesProjectIdInput {
|
wandb/apis/public/jobs.py
CHANGED
wandb/apis/public/runs.py
CHANGED
@@ -704,7 +704,7 @@ class Run(Attrs):
|
|
704
704
|
if pd:
|
705
705
|
lines = pd.DataFrame.from_records(lines)
|
706
706
|
else:
|
707
|
-
|
707
|
+
wandb.termwarn("Unable to load pandas, call history with pandas=False")
|
708
708
|
return lines
|
709
709
|
|
710
710
|
@normalize_exceptions
|
@@ -908,7 +908,7 @@ class Run(Attrs):
|
|
908
908
|
def _server_provides_internal_id_for_project(self) -> bool:
|
909
909
|
"""Returns True if the server allows us to query the internalId field for a project.
|
910
910
|
|
911
|
-
This check is done by utilizing GraphQL introspection in the
|
911
|
+
This check is done by utilizing GraphQL introspection in the available fields on the Project type.
|
912
912
|
"""
|
913
913
|
query_string = """
|
914
914
|
query ProbeProjectInput {
|
@@ -924,11 +924,6 @@ class Run(Attrs):
|
|
924
924
|
if self.server_provides_internal_id_field is None:
|
925
925
|
query = gql(query_string)
|
926
926
|
res = self.client.execute(query)
|
927
|
-
print(
|
928
|
-
"internalId"
|
929
|
-
in [x["name"] for x in (res.get("ProjectType", {}).get("fields", [{}]))]
|
930
|
-
)
|
931
|
-
|
932
927
|
self.server_provides_internal_id_field = "internalId" in [
|
933
928
|
x["name"] for x in (res.get("ProjectType", {}).get("fields", [{}]))
|
934
929
|
]
|
@@ -4,5 +4,5 @@ try:
|
|
4
4
|
from wandb_workspaces.reports.v1 import * # noqa: F403
|
5
5
|
except ImportError:
|
6
6
|
wandb.termerror(
|
7
|
-
"Failed to import wandb_workspaces. To edit reports
|
7
|
+
"Failed to import wandb_workspaces. To edit reports programmatically, please install it using `pip install wandb[workspaces]`."
|
8
8
|
)
|
@@ -4,5 +4,5 @@ try:
|
|
4
4
|
from wandb_workspaces.reports.v2 import * # noqa: F403
|
5
5
|
except ImportError:
|
6
6
|
wandb.termerror(
|
7
|
-
"Failed to import wandb_workspaces. To edit reports
|
7
|
+
"Failed to import wandb_workspaces. To edit reports programmatically, please install it using `pip install wandb[workspaces]`."
|
8
8
|
)
|
@@ -4,5 +4,5 @@ try:
|
|
4
4
|
from wandb_workspaces.workspaces import * # noqa: F403
|
5
5
|
except ImportError:
|
6
6
|
wandb.termerror(
|
7
|
-
"Failed to import wandb_workspaces. To edit workspaces
|
7
|
+
"Failed to import wandb_workspaces. To edit workspaces programmatically, please install it using `pip install wandb[workspaces]`."
|
8
8
|
)
|
wandb/bin/gpu_stats
CHANGED
Binary file
|
wandb/bin/wandb-core
CHANGED
Binary file
|
wandb/cli/beta.py
CHANGED
@@ -12,6 +12,7 @@ import click
|
|
12
12
|
|
13
13
|
import wandb
|
14
14
|
from wandb.errors import UsageError, WandbCoreNotAvailableError
|
15
|
+
from wandb.sdk.wandb_sync import _sync
|
15
16
|
from wandb.util import get_core_path
|
16
17
|
|
17
18
|
|
@@ -108,7 +109,9 @@ def sync_beta( # noqa: C901
|
|
108
109
|
continue
|
109
110
|
wandb_files = [p for p in d.glob("*.wandb") if p.is_file()]
|
110
111
|
if len(wandb_files) > 1:
|
111
|
-
|
112
|
+
wandb.termwarn(
|
113
|
+
f"Multiple wandb files found in directory {d}, skipping"
|
114
|
+
)
|
112
115
|
elif len(wandb_files) == 1:
|
113
116
|
paths.add(d)
|
114
117
|
else:
|
@@ -128,7 +131,7 @@ def sync_beta( # noqa: C901
|
|
128
131
|
for path in paths:
|
129
132
|
wandb_synced_files = [p for p in path.glob("*.wandb.synced") if p.is_file()]
|
130
133
|
if len(wandb_synced_files) > 1:
|
131
|
-
|
134
|
+
wandb.termwarn(
|
132
135
|
f"Multiple wandb.synced files found in directory {path}, skipping"
|
133
136
|
)
|
134
137
|
elif len(wandb_synced_files) == 1:
|
@@ -151,7 +154,7 @@ def sync_beta( # noqa: C901
|
|
151
154
|
if dry_run:
|
152
155
|
return
|
153
156
|
|
154
|
-
wandb.
|
157
|
+
wandb.setup()
|
155
158
|
|
156
159
|
# TODO: make it thread-safe in the Rust code
|
157
160
|
with concurrent.futures.ProcessPoolExecutor(
|
@@ -162,7 +165,7 @@ def sync_beta( # noqa: C901
|
|
162
165
|
# we already know there is only one wandb file in the directory
|
163
166
|
wandb_file = [p for p in path.glob("*.wandb") if p.is_file()][0]
|
164
167
|
future = executor.submit(
|
165
|
-
|
168
|
+
_sync,
|
166
169
|
wandb_file,
|
167
170
|
run_id=run_id,
|
168
171
|
project=project,
|
wandb/cli/cli.py
CHANGED
@@ -125,7 +125,7 @@ def _get_cling_api(reset=None):
|
|
125
125
|
global _api
|
126
126
|
if reset:
|
127
127
|
_api = None
|
128
|
-
|
128
|
+
wandb.teardown()
|
129
129
|
if _api is None:
|
130
130
|
# TODO(jhr): make a settings object that is better for non runs.
|
131
131
|
# only override the necessary setting
|
@@ -2437,7 +2437,7 @@ def ls(path, type):
|
|
2437
2437
|
per_page=1,
|
2438
2438
|
)
|
2439
2439
|
latest = next(versions)
|
2440
|
-
|
2440
|
+
wandb.termlog(
|
2441
2441
|
"{:<15s}{:<15s}{:>15s} {:<20s}".format(
|
2442
2442
|
kind.type,
|
2443
2443
|
latest.updated_at,
|
@@ -2463,7 +2463,7 @@ def cleanup(target_size, remove_temp):
|
|
2463
2463
|
target_size = util.from_human_size(target_size)
|
2464
2464
|
cache = get_artifact_file_cache()
|
2465
2465
|
reclaimed_bytes = cache.cleanup(target_size, remove_temp)
|
2466
|
-
|
2466
|
+
wandb.termlog(f"Reclaimed {util.to_human_size(reclaimed_bytes)} of space")
|
2467
2467
|
|
2468
2468
|
|
2469
2469
|
@cli.command(context_settings=CONTEXT, help="Pull files from Weights & Biases")
|
@@ -2664,7 +2664,6 @@ Run `git clone {}` and restore from there or pass the --no-git flag.""".format(r
|
|
2664
2664
|
def online():
|
2665
2665
|
api = InternalApi()
|
2666
2666
|
try:
|
2667
|
-
api.clear_setting("disabled", persist=True)
|
2668
2667
|
api.clear_setting("mode", persist=True)
|
2669
2668
|
except configparser.Error:
|
2670
2669
|
pass
|
@@ -2678,7 +2677,6 @@ def online():
|
|
2678
2677
|
def offline():
|
2679
2678
|
api = InternalApi()
|
2680
2679
|
try:
|
2681
|
-
api.set_setting("disabled", "true", persist=True)
|
2682
2680
|
api.set_setting("mode", "offline", persist=True)
|
2683
2681
|
click.echo(
|
2684
2682
|
"W&B offline. Running your script from this directory will only write metadata locally. Use wandb disabled to completely turn off W&B."
|
@@ -2765,13 +2763,13 @@ def verify(host):
|
|
2765
2763
|
reinit = False
|
2766
2764
|
if host is None:
|
2767
2765
|
host = api.settings("base_url")
|
2768
|
-
|
2766
|
+
wandb.termlog(f"Default host selected: {host}")
|
2769
2767
|
# if the given host does not match the default host, re-run init
|
2770
2768
|
elif host != api.settings("base_url"):
|
2771
2769
|
reinit = True
|
2772
2770
|
|
2773
2771
|
tmp_dir = tempfile.mkdtemp()
|
2774
|
-
|
2772
|
+
wandb.termlog(
|
2775
2773
|
"Find detailed logs for this test at: {}".format(os.path.join(tmp_dir, "wandb"))
|
2776
2774
|
)
|
2777
2775
|
os.chdir(tmp_dir)
|
wandb/docker/__init__.py
CHANGED
@@ -62,7 +62,7 @@ def shell(cmd: List[str]) -> Optional[str]:
|
|
62
62
|
.strip()
|
63
63
|
)
|
64
64
|
except subprocess.CalledProcessError as e:
|
65
|
-
print(e)
|
65
|
+
print(e) # noqa: T201
|
66
66
|
return None
|
67
67
|
|
68
68
|
|
@@ -140,12 +140,12 @@ def run_command_live_output(args: List[Any]) -> str:
|
|
140
140
|
break
|
141
141
|
index = chunk.find(b"\r")
|
142
142
|
if index != -1:
|
143
|
-
print(chunk.decode(), end="")
|
143
|
+
print(chunk.decode(), end="") # noqa: T201
|
144
144
|
else:
|
145
145
|
stdout += chunk.decode()
|
146
|
-
print(chunk.decode(), end="\r")
|
146
|
+
print(chunk.decode(), end="\r") # noqa: T201
|
147
147
|
|
148
|
-
print(stdout)
|
148
|
+
print(stdout) # noqa: T201
|
149
149
|
|
150
150
|
return_code = process.wait()
|
151
151
|
if return_code != 0:
|
@@ -54,7 +54,7 @@ try:
|
|
54
54
|
matplotlib.use("Agg") # non-interactive backend (avoid tkinter issues)
|
55
55
|
import matplotlib.pyplot as plt
|
56
56
|
except ImportError:
|
57
|
-
|
57
|
+
wandb.termwarn("matplotlib required if logging sample image predictions")
|
58
58
|
|
59
59
|
|
60
60
|
class WandbCallback(TrackerCallback):
|
@@ -134,10 +134,8 @@ class WandbCallback(TrackerCallback):
|
|
134
134
|
# Adapted from fast.ai "SaveModelCallback"
|
135
135
|
current = self.get_monitor_value()
|
136
136
|
if current is not None and self.operator(current, self.best):
|
137
|
-
|
138
|
-
"Better model found at epoch {} with {} value: {}."
|
139
|
-
epoch, self.monitor, current
|
140
|
-
)
|
137
|
+
wandb.termlog(
|
138
|
+
f"Better model found at epoch {epoch} with {self.monitor} value: {current}."
|
141
139
|
)
|
142
140
|
self.best = current
|
143
141
|
|
@@ -173,7 +171,7 @@ class WandbCallback(TrackerCallback):
|
|
173
171
|
if self.model_path.is_file():
|
174
172
|
with self.model_path.open("rb") as model_file:
|
175
173
|
self.learn.load(model_file, purge=False)
|
176
|
-
|
174
|
+
wandb.termlog(f"Loaded best saved model from {self.model_path}")
|
177
175
|
|
178
176
|
def _wandb_log_predictions(self) -> None:
|
179
177
|
"""Log prediction samples."""
|
wandb/integration/keras/keras.py
CHANGED
@@ -509,7 +509,9 @@ class WandbCallback(tf.keras.callbacks.Callback):
|
|
509
509
|
|
510
510
|
# From Keras
|
511
511
|
if mode not in ["auto", "min", "max"]:
|
512
|
-
|
512
|
+
wandb.termwarn(
|
513
|
+
f"WandbCallback mode {mode} is unknown, fallback to auto mode."
|
514
|
+
)
|
513
515
|
mode = "auto"
|
514
516
|
|
515
517
|
if mode == "min":
|
@@ -632,7 +634,7 @@ class WandbCallback(tf.keras.callbacks.Callback):
|
|
632
634
|
)
|
633
635
|
wandb.run.summary["{}{}".format(self.log_best_prefix, "epoch")] = epoch
|
634
636
|
if self.verbose and not self.save_model:
|
635
|
-
|
637
|
+
wandb.termlog(
|
636
638
|
f"Epoch {epoch:05d}: {self.monitor} improved from {self.best:.5f} to {self.current:.5f}"
|
637
639
|
)
|
638
640
|
if self.save_model:
|
@@ -1003,7 +1005,7 @@ class WandbCallback(tf.keras.callbacks.Callback):
|
|
1003
1005
|
if wandb.run.disabled:
|
1004
1006
|
return
|
1005
1007
|
if self.verbose > 0:
|
1006
|
-
|
1008
|
+
wandb.termlog(
|
1007
1009
|
f"Epoch {epoch:05d}: {self.monitor} improved from {self.best:.5f} to {self.current:.5f}, "
|
1008
1010
|
f"saving model to {self.filepath}"
|
1009
1011
|
)
|
@@ -73,8 +73,8 @@ try:
|
|
73
73
|
wandb.termlog(f"Logging artifact: {name} ({type(data)})")
|
74
74
|
|
75
75
|
except ImportError:
|
76
|
-
|
77
|
-
"
|
76
|
+
wandb.termwarn(
|
77
|
+
"`pandas` not installed >> @wandb_log(datasets=True) may not auto log your dataset!"
|
78
78
|
)
|
79
79
|
|
80
80
|
try:
|
@@ -119,8 +119,8 @@ try:
|
|
119
119
|
wandb.termlog(f"Logging artifact: {name} ({type(data)})")
|
120
120
|
|
121
121
|
except ImportError:
|
122
|
-
|
123
|
-
"
|
122
|
+
wandb.termwarn(
|
123
|
+
"`pytorch` not installed >> @wandb_log(models=True) may not auto log your model!"
|
124
124
|
)
|
125
125
|
|
126
126
|
try:
|
@@ -164,8 +164,8 @@ try:
|
|
164
164
|
wandb.termlog(f"Logging artifact: {name} ({type(data)})")
|
165
165
|
|
166
166
|
except ImportError:
|
167
|
-
|
168
|
-
"
|
167
|
+
wandb.termwarn(
|
168
|
+
"`sklearn` not installed >> @wandb_log(models=True) may not auto log your model!"
|
169
169
|
)
|
170
170
|
|
171
171
|
|
@@ -245,7 +245,7 @@ def wandb_use(name: str, data, *args, **kwargs):
|
|
245
245
|
try:
|
246
246
|
return _wandb_use(name, data, *args, **kwargs)
|
247
247
|
except wandb.CommError:
|
248
|
-
|
248
|
+
wandb.termwarn(
|
249
249
|
f"This artifact ({name}, {type(data)}) does not exist in the wandb datastore!"
|
250
250
|
f"If you created an instance inline (e.g. sklearn.ensemble.RandomForestClassifier), then you can safely ignore this"
|
251
251
|
f"Otherwise you may want to check your internet connection!"
|
@@ -237,11 +237,7 @@ def create_table(data):
|
|
237
237
|
im = Image.open(urllib.request.urlopen(document["image"]))
|
238
238
|
document["image_visual"] = wandb.Image(im)
|
239
239
|
except urllib.error.URLError:
|
240
|
-
|
241
|
-
"Warning: Image URL "
|
242
|
-
+ str(document["image"])
|
243
|
-
+ " is invalid."
|
244
|
-
)
|
240
|
+
wandb.termwarn(f"Image URL {document['image']} is invalid.")
|
245
241
|
document["image_visual"] = None
|
246
242
|
elif isbase64:
|
247
243
|
# is base64 uri
|
@@ -252,11 +248,7 @@ def create_table(data):
|
|
252
248
|
im = Image.open(buf)
|
253
249
|
document["image_visual"] = wandb.Image(im)
|
254
250
|
except base64.binascii.Error:
|
255
|
-
|
256
|
-
"Warning: Base64 string "
|
257
|
-
+ str(document["image"])
|
258
|
-
+ " is invalid."
|
259
|
-
)
|
251
|
+
wandb.termwarn(f"Base64 string {document['image']} is invalid.")
|
260
252
|
document["image_visual"] = None
|
261
253
|
else:
|
262
254
|
# is data path
|
@@ -296,4 +288,4 @@ def upload_dataset(dataset_name):
|
|
296
288
|
standardize(data[i], schema, array_dict_types)
|
297
289
|
table = create_table(data)
|
298
290
|
wandb.log({dataset_name: table})
|
299
|
-
|
291
|
+
wandb.termlog(f"Prodigy dataset `{dataset_name}` uploaded.")
|
@@ -1,12 +1,14 @@
|
|
1
1
|
"""wandb integration sagemaker module."""
|
2
2
|
|
3
3
|
from .auth import sagemaker_auth
|
4
|
-
from .config import parse_sm_config
|
5
|
-
from .resources import
|
4
|
+
from .config import is_using_sagemaker, parse_sm_config
|
5
|
+
from .resources import parse_sm_secrets, set_global_settings, set_run_id
|
6
6
|
|
7
7
|
__all__ = [
|
8
8
|
"sagemaker_auth",
|
9
|
+
"is_using_sagemaker",
|
9
10
|
"parse_sm_config",
|
10
11
|
"parse_sm_secrets",
|
11
|
-
"
|
12
|
+
"set_global_settings",
|
13
|
+
"set_run_id",
|
12
14
|
]
|
@@ -1,13 +1,23 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import json
|
2
4
|
import os
|
3
5
|
import re
|
4
6
|
import warnings
|
5
|
-
from typing import Any
|
7
|
+
from typing import Any
|
6
8
|
|
7
9
|
from . import files as sm_files
|
8
10
|
|
9
11
|
|
10
|
-
def
|
12
|
+
def is_using_sagemaker() -> bool:
|
13
|
+
"""Returns whether we're in a SageMaker environment."""
|
14
|
+
return (
|
15
|
+
os.path.exists(sm_files.SM_PARAM_CONFIG) #
|
16
|
+
or "SM_TRAINING_ENV" in os.environ
|
17
|
+
)
|
18
|
+
|
19
|
+
|
20
|
+
def parse_sm_config() -> dict[str, Any]:
|
11
21
|
"""Parses SageMaker configuration.
|
12
22
|
|
13
23
|
Returns:
|
@@ -23,9 +33,7 @@ def parse_sm_config() -> Dict[str, Any]:
|
|
23
33
|
"""
|
24
34
|
conf = {}
|
25
35
|
|
26
|
-
if os.path.exists(sm_files.SM_PARAM_CONFIG)
|
27
|
-
sm_files.SM_RESOURCE_CONFIG
|
28
|
-
):
|
36
|
+
if os.path.exists(sm_files.SM_PARAM_CONFIG):
|
29
37
|
conf["sagemaker_training_job_name"] = os.getenv("TRAINING_JOB_NAME")
|
30
38
|
|
31
39
|
# Hyperparameter searches quote configs...
|
@@ -38,12 +46,13 @@ def parse_sm_config() -> Dict[str, Any]:
|
|
38
46
|
cast = float(cast)
|
39
47
|
conf[key] = cast
|
40
48
|
|
41
|
-
if
|
49
|
+
if env := os.environ.get("SM_TRAINING_ENV"):
|
42
50
|
try:
|
43
|
-
conf
|
51
|
+
conf.update(json.loads(env))
|
44
52
|
except json.JSONDecodeError:
|
45
53
|
warnings.warn(
|
46
|
-
"Failed to parse SM_TRAINING_ENV not valid JSON string",
|
54
|
+
"Failed to parse SM_TRAINING_ENV not valid JSON string",
|
55
|
+
stacklevel=2,
|
47
56
|
)
|
48
57
|
|
49
58
|
return conf
|
@@ -1,13 +1,58 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
import os
|
2
4
|
import secrets
|
3
5
|
import socket
|
4
6
|
import string
|
5
|
-
from typing import Dict, Tuple
|
6
7
|
|
8
|
+
import wandb
|
9
|
+
|
10
|
+
from . import config
|
7
11
|
from . import files as sm_files
|
8
12
|
|
9
13
|
|
10
|
-
def
|
14
|
+
def set_run_id(run_settings: wandb.Settings) -> bool:
|
15
|
+
"""Set a run ID and group when using SageMaker.
|
16
|
+
|
17
|
+
Returns whether the ID and group were updated.
|
18
|
+
"""
|
19
|
+
# Added in https://github.com/wandb/wandb/pull/3290.
|
20
|
+
#
|
21
|
+
# Prevents SageMaker from overriding the run ID configured
|
22
|
+
# in environment variables. Note, however, that it will still
|
23
|
+
# override a run ID passed explicitly to `wandb.init()`.
|
24
|
+
if os.getenv("WANDB_RUN_ID"):
|
25
|
+
return False
|
26
|
+
|
27
|
+
run_group = os.getenv("TRAINING_JOB_NAME")
|
28
|
+
if not run_group:
|
29
|
+
return False
|
30
|
+
|
31
|
+
alphanumeric = string.ascii_lowercase + string.digits
|
32
|
+
random = "".join(secrets.choice(alphanumeric) for _ in range(6))
|
33
|
+
|
34
|
+
host = os.getenv("CURRENT_HOST", socket.gethostname())
|
35
|
+
|
36
|
+
run_settings.run_id = f"{run_group}-{random}-{host}"
|
37
|
+
run_settings.run_group = run_group
|
38
|
+
return True
|
39
|
+
|
40
|
+
|
41
|
+
def set_global_settings(settings: wandb.Settings) -> None:
|
42
|
+
"""Set global W&B settings based on the SageMaker environment."""
|
43
|
+
if env := parse_sm_secrets():
|
44
|
+
settings.update_from_env_vars(env)
|
45
|
+
|
46
|
+
# The SageMaker config may contain an API key, in which case it
|
47
|
+
# takes precedence over the value in the secrets. It's unclear
|
48
|
+
# whether this is by design, or by accident; we keep it for
|
49
|
+
# backward compatibility for now.
|
50
|
+
sm_config = config.parse_sm_config()
|
51
|
+
if api_key := sm_config.get("wandb_api_key"):
|
52
|
+
settings.api_key = api_key
|
53
|
+
|
54
|
+
|
55
|
+
def parse_sm_secrets() -> dict[str, str]:
|
11
56
|
"""We read our api_key from secrets.env in SageMaker."""
|
12
57
|
env_dict = dict()
|
13
58
|
# Set secret variables
|
@@ -16,19 +61,3 @@ def parse_sm_secrets() -> Dict[str, str]:
|
|
16
61
|
key, val = line.strip().split("=", 1)
|
17
62
|
env_dict[key] = val
|
18
63
|
return env_dict
|
19
|
-
|
20
|
-
|
21
|
-
def parse_sm_resources() -> Tuple[Dict[str, str], Dict[str, str]]:
|
22
|
-
run_dict = dict()
|
23
|
-
run_id = os.getenv("TRAINING_JOB_NAME")
|
24
|
-
|
25
|
-
if run_id and os.getenv("WANDB_RUN_ID") is None:
|
26
|
-
suffix = "".join(
|
27
|
-
secrets.choice(string.ascii_lowercase + string.digits) for _ in range(6)
|
28
|
-
)
|
29
|
-
run_dict["run_id"] = "-".join(
|
30
|
-
[run_id, suffix, os.getenv("CURRENT_HOST", socket.gethostname())]
|
31
|
-
)
|
32
|
-
run_dict["run_group"] = os.getenv("TRAINING_JOB_NAME")
|
33
|
-
env_dict = parse_sm_secrets()
|
34
|
-
return run_dict, env_dict
|
@@ -441,7 +441,7 @@ class TorchGraph(wandb.data_types.Graph):
|
|
441
441
|
decoder.weight encoder
|
442
442
|
decoder.bias decoder
|
443
443
|
"""
|
444
|
-
# TODO: We're currently not using this, but I left it here
|
444
|
+
# TODO: We're currently not using this, but I left it here in case we want to resurrect! - CVP
|
445
445
|
torch = util.get_module("torch", "Could not import torch")
|
446
446
|
|
447
447
|
module_nodes_by_hash = {id(n): n for n in module_graph.nodes}
|