wandb 0.19.1rc1__py3-none-macosx_11_0_x86_64.whl → 0.19.2__py3-none-macosx_11_0_x86_64.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (80)
  1. wandb/__init__.py +1 -7
  2. wandb/__init__.pyi +3 -5
  3. wandb/agents/pyagent.py +1 -1
  4. wandb/apis/importers/wandb.py +1 -1
  5. wandb/apis/public/files.py +1 -1
  6. wandb/apis/public/jobs.py +1 -1
  7. wandb/apis/public/runs.py +2 -7
  8. wandb/apis/reports/v1/__init__.py +1 -1
  9. wandb/apis/reports/v2/__init__.py +1 -1
  10. wandb/apis/workspaces/__init__.py +1 -1
  11. wandb/bin/gpu_stats +0 -0
  12. wandb/bin/wandb-core +0 -0
  13. wandb/cli/beta.py +7 -4
  14. wandb/cli/cli.py +5 -7
  15. wandb/docker/__init__.py +4 -4
  16. wandb/integration/fastai/__init__.py +4 -6
  17. wandb/integration/keras/keras.py +5 -3
  18. wandb/integration/metaflow/metaflow.py +7 -7
  19. wandb/integration/prodigy/prodigy.py +3 -11
  20. wandb/integration/sagemaker/__init__.py +5 -3
  21. wandb/integration/sagemaker/config.py +17 -8
  22. wandb/integration/sagemaker/files.py +0 -1
  23. wandb/integration/sagemaker/resources.py +47 -18
  24. wandb/integration/torch/wandb_torch.py +1 -1
  25. wandb/proto/v3/wandb_internal_pb2.py +273 -235
  26. wandb/proto/v4/wandb_internal_pb2.py +222 -214
  27. wandb/proto/v5/wandb_internal_pb2.py +222 -214
  28. wandb/sdk/artifacts/artifact.py +3 -9
  29. wandb/sdk/backend/backend.py +1 -1
  30. wandb/sdk/data_types/base_types/wb_value.py +1 -1
  31. wandb/sdk/data_types/graph.py +2 -2
  32. wandb/sdk/data_types/saved_model.py +1 -1
  33. wandb/sdk/data_types/video.py +1 -1
  34. wandb/sdk/interface/interface.py +25 -25
  35. wandb/sdk/interface/interface_shared.py +21 -5
  36. wandb/sdk/internal/handler.py +19 -1
  37. wandb/sdk/internal/internal.py +1 -1
  38. wandb/sdk/internal/internal_api.py +4 -5
  39. wandb/sdk/internal/sample.py +2 -2
  40. wandb/sdk/internal/sender.py +1 -2
  41. wandb/sdk/internal/settings_static.py +3 -1
  42. wandb/sdk/internal/system/assets/disk.py +4 -4
  43. wandb/sdk/internal/system/assets/gpu.py +1 -1
  44. wandb/sdk/internal/system/assets/memory.py +1 -1
  45. wandb/sdk/internal/system/system_info.py +1 -1
  46. wandb/sdk/internal/system/system_monitor.py +3 -1
  47. wandb/sdk/internal/tb_watcher.py +1 -1
  48. wandb/sdk/launch/_project_spec.py +3 -3
  49. wandb/sdk/launch/builder/abstract.py +1 -1
  50. wandb/sdk/lib/apikey.py +2 -3
  51. wandb/sdk/lib/fsm.py +1 -1
  52. wandb/sdk/lib/gitlib.py +1 -1
  53. wandb/sdk/lib/gql_request.py +1 -1
  54. wandb/sdk/lib/interrupt.py +37 -0
  55. wandb/sdk/lib/lazyloader.py +1 -1
  56. wandb/sdk/lib/progress.py +7 -1
  57. wandb/sdk/lib/service_connection.py +1 -1
  58. wandb/sdk/lib/telemetry.py +1 -1
  59. wandb/sdk/service/_startup_debug.py +1 -1
  60. wandb/sdk/service/server_sock.py +3 -2
  61. wandb/sdk/service/service.py +1 -1
  62. wandb/sdk/service/streams.py +19 -17
  63. wandb/sdk/verify/verify.py +13 -13
  64. wandb/sdk/wandb_init.py +95 -104
  65. wandb/sdk/wandb_login.py +1 -1
  66. wandb/sdk/wandb_metadata.py +547 -0
  67. wandb/sdk/wandb_run.py +127 -35
  68. wandb/sdk/wandb_settings.py +6 -37
  69. wandb/sdk/wandb_setup.py +83 -82
  70. wandb/sdk/wandb_sweep.py +2 -2
  71. wandb/sdk/wandb_sync.py +15 -18
  72. wandb/sync/sync.py +10 -10
  73. wandb/util.py +11 -3
  74. wandb/wandb_agent.py +11 -16
  75. wandb/wandb_controller.py +7 -7
  76. {wandb-0.19.1rc1.dist-info → wandb-0.19.2.dist-info}/METADATA +3 -2
  77. {wandb-0.19.1rc1.dist-info → wandb-0.19.2.dist-info}/RECORD +80 -78
  78. {wandb-0.19.1rc1.dist-info → wandb-0.19.2.dist-info}/WHEEL +0 -0
  79. {wandb-0.19.1rc1.dist-info → wandb-0.19.2.dist-info}/entry_points.txt +0 -0
  80. {wandb-0.19.1rc1.dist-info → wandb-0.19.2.dist-info}/licenses/LICENSE +0 -0
wandb/__init__.py CHANGED
@@ -10,7 +10,7 @@ For reference documentation, see https://docs.wandb.com/ref/python.
10
10
  """
11
11
  from __future__ import annotations
12
12
 
13
- __version__ = "0.19.1rc1"
13
+ __version__ = "0.19.2"
14
14
 
15
15
 
16
16
  from wandb.errors import Error
@@ -204,12 +204,6 @@ if "dev" in __version__:
204
204
  "false",
205
205
  )
206
206
 
207
- # Enable new features in dev versions.
208
- os.environ["WANDB_X_SHOW_OPERATION_STATS"] = os.environ.get(
209
- "WANDB_X_SHOW_OPERATION_STATS",
210
- "true",
211
- )
212
-
213
207
  _sentry = _Sentry()
214
208
  _sentry.setup()
215
209
 
wandb/__init__.pyi CHANGED
@@ -103,7 +103,7 @@ if TYPE_CHECKING:
103
103
  import wandb
104
104
  from wandb.plot import CustomChart
105
105
 
106
- __version__: str = "0.19.1rc1"
106
+ __version__: str = "0.19.2"
107
107
 
108
108
  run: Run | None
109
109
  config: wandb_config.Config
@@ -114,9 +114,7 @@ _sentry: Sentry
114
114
  api: InternalApi
115
115
  patched: Dict[str, List[Callable]]
116
116
 
117
- def setup(
118
- settings: Settings | None = None,
119
- ) -> Optional[_WandbSetup]:
117
+ def setup(settings: Settings | None = None) -> _WandbSetup:
120
118
  """Prepares W&B for use in the current process and its children.
121
119
 
122
120
  You can usually ignore this as it is implicitly called by `wandb.init()`.
@@ -265,7 +263,7 @@ def init(
265
263
  entity: The username or team name under which the runs will be logged.
266
264
  The entity must already exist, so ensure you’ve created your account
267
265
  or team in the UI before starting to log runs. If not specified, the
268
- run will default your defualt entity. To change the default entity,
266
+ run will default your default entity. To change the default entity,
269
267
  go to [your settings](https://wandb.ai/settings) and update the
270
268
  "Default location to create new projects" under "Default team".
271
269
  project: The name of the project under which this run will be logged.
wandb/agents/pyagent.py CHANGED
@@ -297,7 +297,7 @@ class Agent:
297
297
  sweep_param_path, job.config
298
298
  )
299
299
  os.environ[wandb.env.SWEEP_ID] = self._sweep_id
300
- wandb.sdk.wandb_setup._setup(_reset=True)
300
+ wandb.teardown()
301
301
 
302
302
  wandb.termlog(f"Agent Starting Run: {run_id} with config:")
303
303
  for k, v in job.config.items():
wandb/apis/importers/wandb.py CHANGED
@@ -1483,7 +1483,7 @@ def _get_run_or_dummy_from_art(art: Artifact, api=None):
1483
1483
  run = art.logged_by()
1484
1484
  except ValueError as e:
1485
1485
  logger.warn(
1486
- f"Can't log artifact because run does't exist, {art=}, {run=}, {e=}"
1486
+ f"Can't log artifact because run doesn't exist, {art=}, {run=}, {e=}"
1487
1487
  )
1488
1488
 
1489
1489
  if run is not None:
wandb/apis/public/files.py CHANGED
@@ -233,7 +233,7 @@ class File(Attrs):
233
233
  def _server_accepts_project_id_for_delete_file(self) -> bool:
234
234
  """Returns True if the server supports deleting files with a projectId.
235
235
 
236
- This check is done by utilizing GraphQL introspection in the avaiable fields on the DeleteFiles API.
236
+ This check is done by utilizing GraphQL introspection in the available fields on the DeleteFiles API.
237
237
  """
238
238
  query_string = """
239
239
  query ProbeDeleteFilesProjectIdInput {
wandb/apis/public/jobs.py CHANGED
@@ -407,7 +407,7 @@ class QueuedRun:
407
407
  self._run_id = item["associatedRunId"]
408
408
  return self._run
409
409
  except ValueError as e:
410
- print(e)
410
+ wandb.termwarn(e)
411
411
  elif item:
412
412
  wandb.termlog("Waiting for run to start")
413
413
 
wandb/apis/public/runs.py CHANGED
@@ -704,7 +704,7 @@ class Run(Attrs):
704
704
  if pd:
705
705
  lines = pd.DataFrame.from_records(lines)
706
706
  else:
707
- print("Unable to load pandas, call history with pandas=False")
707
+ wandb.termwarn("Unable to load pandas, call history with pandas=False")
708
708
  return lines
709
709
 
710
710
  @normalize_exceptions
@@ -908,7 +908,7 @@ class Run(Attrs):
908
908
  def _server_provides_internal_id_for_project(self) -> bool:
909
909
  """Returns True if the server allows us to query the internalId field for a project.
910
910
 
911
- This check is done by utilizing GraphQL introspection in the avaiable fields on the Project type.
911
+ This check is done by utilizing GraphQL introspection in the available fields on the Project type.
912
912
  """
913
913
  query_string = """
914
914
  query ProbeProjectInput {
@@ -924,11 +924,6 @@ class Run(Attrs):
924
924
  if self.server_provides_internal_id_field is None:
925
925
  query = gql(query_string)
926
926
  res = self.client.execute(query)
927
- print(
928
- "internalId"
929
- in [x["name"] for x in (res.get("ProjectType", {}).get("fields", [{}]))]
930
- )
931
-
932
927
  self.server_provides_internal_id_field = "internalId" in [
933
928
  x["name"] for x in (res.get("ProjectType", {}).get("fields", [{}]))
934
929
  ]
wandb/apis/reports/v1/__init__.py CHANGED
@@ -4,5 +4,5 @@ try:
4
4
  from wandb_workspaces.reports.v1 import * # noqa: F403
5
5
  except ImportError:
6
6
  wandb.termerror(
7
- "Failed to import wandb_workspaces. To edit reports programatically, please install it using `pip install wandb[workspaces]`."
7
+ "Failed to import wandb_workspaces. To edit reports programmatically, please install it using `pip install wandb[workspaces]`."
8
8
  )
wandb/apis/reports/v2/__init__.py CHANGED
@@ -4,5 +4,5 @@ try:
4
4
  from wandb_workspaces.reports.v2 import * # noqa: F403
5
5
  except ImportError:
6
6
  wandb.termerror(
7
- "Failed to import wandb_workspaces. To edit reports programatically, please install it using `pip install wandb[workspaces]`."
7
+ "Failed to import wandb_workspaces. To edit reports programmatically, please install it using `pip install wandb[workspaces]`."
8
8
  )
wandb/apis/workspaces/__init__.py CHANGED
@@ -4,5 +4,5 @@ try:
4
4
  from wandb_workspaces.workspaces import * # noqa: F403
5
5
  except ImportError:
6
6
  wandb.termerror(
7
- "Failed to import wandb_workspaces. To edit workspaces programatically, please install it using `pip install wandb[workspaces]`."
7
+ "Failed to import wandb_workspaces. To edit workspaces programmatically, please install it using `pip install wandb[workspaces]`."
8
8
  )
wandb/bin/gpu_stats CHANGED
Binary file
wandb/bin/wandb-core CHANGED
Binary file
wandb/cli/beta.py CHANGED
@@ -12,6 +12,7 @@ import click
12
12
 
13
13
  import wandb
14
14
  from wandb.errors import UsageError, WandbCoreNotAvailableError
15
+ from wandb.sdk.wandb_sync import _sync
15
16
  from wandb.util import get_core_path
16
17
 
17
18
 
@@ -108,7 +109,9 @@ def sync_beta( # noqa: C901
108
109
  continue
109
110
  wandb_files = [p for p in d.glob("*.wandb") if p.is_file()]
110
111
  if len(wandb_files) > 1:
111
- print(f"Multiple wandb files found in directory {d}, skipping")
112
+ wandb.termwarn(
113
+ f"Multiple wandb files found in directory {d}, skipping"
114
+ )
112
115
  elif len(wandb_files) == 1:
113
116
  paths.add(d)
114
117
  else:
@@ -128,7 +131,7 @@ def sync_beta( # noqa: C901
128
131
  for path in paths:
129
132
  wandb_synced_files = [p for p in path.glob("*.wandb.synced") if p.is_file()]
130
133
  if len(wandb_synced_files) > 1:
131
- print(
134
+ wandb.termwarn(
132
135
  f"Multiple wandb.synced files found in directory {path}, skipping"
133
136
  )
134
137
  elif len(wandb_synced_files) == 1:
@@ -151,7 +154,7 @@ def sync_beta( # noqa: C901
151
154
  if dry_run:
152
155
  return
153
156
 
154
- wandb.sdk.wandb_setup.setup()
157
+ wandb.setup()
155
158
 
156
159
  # TODO: make it thread-safe in the Rust code
157
160
  with concurrent.futures.ProcessPoolExecutor(
@@ -162,7 +165,7 @@ def sync_beta( # noqa: C901
162
165
  # we already know there is only one wandb file in the directory
163
166
  wandb_file = [p for p in path.glob("*.wandb") if p.is_file()][0]
164
167
  future = executor.submit(
165
- wandb._sync,
168
+ _sync,
166
169
  wandb_file,
167
170
  run_id=run_id,
168
171
  project=project,
wandb/cli/cli.py CHANGED
@@ -125,7 +125,7 @@ def _get_cling_api(reset=None):
125
125
  global _api
126
126
  if reset:
127
127
  _api = None
128
- wandb_sdk.wandb_setup._setup(_reset=True)
128
+ wandb.teardown()
129
129
  if _api is None:
130
130
  # TODO(jhr): make a settings object that is better for non runs.
131
131
  # only override the necessary setting
@@ -2437,7 +2437,7 @@ def ls(path, type):
2437
2437
  per_page=1,
2438
2438
  )
2439
2439
  latest = next(versions)
2440
- print(
2440
+ wandb.termlog(
2441
2441
  "{:<15s}{:<15s}{:>15s} {:<20s}".format(
2442
2442
  kind.type,
2443
2443
  latest.updated_at,
@@ -2463,7 +2463,7 @@ def cleanup(target_size, remove_temp):
2463
2463
  target_size = util.from_human_size(target_size)
2464
2464
  cache = get_artifact_file_cache()
2465
2465
  reclaimed_bytes = cache.cleanup(target_size, remove_temp)
2466
- print(f"Reclaimed {util.to_human_size(reclaimed_bytes)} of space")
2466
+ wandb.termlog(f"Reclaimed {util.to_human_size(reclaimed_bytes)} of space")
2467
2467
 
2468
2468
 
2469
2469
  @cli.command(context_settings=CONTEXT, help="Pull files from Weights & Biases")
@@ -2664,7 +2664,6 @@ Run `git clone {}` and restore from there or pass the --no-git flag.""".format(r
2664
2664
  def online():
2665
2665
  api = InternalApi()
2666
2666
  try:
2667
- api.clear_setting("disabled", persist=True)
2668
2667
  api.clear_setting("mode", persist=True)
2669
2668
  except configparser.Error:
2670
2669
  pass
@@ -2678,7 +2677,6 @@ def online():
2678
2677
  def offline():
2679
2678
  api = InternalApi()
2680
2679
  try:
2681
- api.set_setting("disabled", "true", persist=True)
2682
2680
  api.set_setting("mode", "offline", persist=True)
2683
2681
  click.echo(
2684
2682
  "W&B offline. Running your script from this directory will only write metadata locally. Use wandb disabled to completely turn off W&B."
@@ -2765,13 +2763,13 @@ def verify(host):
2765
2763
  reinit = False
2766
2764
  if host is None:
2767
2765
  host = api.settings("base_url")
2768
- print(f"Default host selected: {host}")
2766
+ wandb.termlog(f"Default host selected: {host}")
2769
2767
  # if the given host does not match the default host, re-run init
2770
2768
  elif host != api.settings("base_url"):
2771
2769
  reinit = True
2772
2770
 
2773
2771
  tmp_dir = tempfile.mkdtemp()
2774
- print(
2772
+ wandb.termlog(
2775
2773
  "Find detailed logs for this test at: {}".format(os.path.join(tmp_dir, "wandb"))
2776
2774
  )
2777
2775
  os.chdir(tmp_dir)
wandb/docker/__init__.py CHANGED
@@ -62,7 +62,7 @@ def shell(cmd: List[str]) -> Optional[str]:
62
62
  .strip()
63
63
  )
64
64
  except subprocess.CalledProcessError as e:
65
- print(e)
65
+ print(e) # noqa: T201
66
66
  return None
67
67
 
68
68
 
@@ -140,12 +140,12 @@ def run_command_live_output(args: List[Any]) -> str:
140
140
  break
141
141
  index = chunk.find(b"\r")
142
142
  if index != -1:
143
- print(chunk.decode(), end="")
143
+ print(chunk.decode(), end="") # noqa: T201
144
144
  else:
145
145
  stdout += chunk.decode()
146
- print(chunk.decode(), end="\r")
146
+ print(chunk.decode(), end="\r") # noqa: T201
147
147
 
148
- print(stdout)
148
+ print(stdout) # noqa: T201
149
149
 
150
150
  return_code = process.wait()
151
151
  if return_code != 0:
wandb/integration/fastai/__init__.py CHANGED
@@ -54,7 +54,7 @@ try:
54
54
  matplotlib.use("Agg") # non-interactive backend (avoid tkinter issues)
55
55
  import matplotlib.pyplot as plt
56
56
  except ImportError:
57
- print("Warning: matplotlib required if logging sample image predictions")
57
+ wandb.termwarn("matplotlib required if logging sample image predictions")
58
58
 
59
59
 
60
60
  class WandbCallback(TrackerCallback):
@@ -134,10 +134,8 @@ class WandbCallback(TrackerCallback):
134
134
  # Adapted from fast.ai "SaveModelCallback"
135
135
  current = self.get_monitor_value()
136
136
  if current is not None and self.operator(current, self.best):
137
- print(
138
- "Better model found at epoch {} with {} value: {}.".format(
139
- epoch, self.monitor, current
140
- )
137
+ wandb.termlog(
138
+ f"Better model found at epoch {epoch} with {self.monitor} value: {current}."
141
139
  )
142
140
  self.best = current
143
141
 
@@ -173,7 +171,7 @@ class WandbCallback(TrackerCallback):
173
171
  if self.model_path.is_file():
174
172
  with self.model_path.open("rb") as model_file:
175
173
  self.learn.load(model_file, purge=False)
176
- print(f"Loaded best saved model from {self.model_path}")
174
+ wandb.termlog(f"Loaded best saved model from {self.model_path}")
177
175
 
178
176
  def _wandb_log_predictions(self) -> None:
179
177
  """Log prediction samples."""
wandb/integration/keras/keras.py CHANGED
@@ -509,7 +509,9 @@ class WandbCallback(tf.keras.callbacks.Callback):
509
509
 
510
510
  # From Keras
511
511
  if mode not in ["auto", "min", "max"]:
512
- print(f"WandbCallback mode {mode} is unknown, fallback to auto mode.")
512
+ wandb.termwarn(
513
+ f"WandbCallback mode {mode} is unknown, fallback to auto mode."
514
+ )
513
515
  mode = "auto"
514
516
 
515
517
  if mode == "min":
@@ -632,7 +634,7 @@ class WandbCallback(tf.keras.callbacks.Callback):
632
634
  )
633
635
  wandb.run.summary["{}{}".format(self.log_best_prefix, "epoch")] = epoch
634
636
  if self.verbose and not self.save_model:
635
- print(
637
+ wandb.termlog(
636
638
  f"Epoch {epoch:05d}: {self.monitor} improved from {self.best:.5f} to {self.current:.5f}"
637
639
  )
638
640
  if self.save_model:
@@ -1003,7 +1005,7 @@ class WandbCallback(tf.keras.callbacks.Callback):
1003
1005
  if wandb.run.disabled:
1004
1006
  return
1005
1007
  if self.verbose > 0:
1006
- print(
1008
+ wandb.termlog(
1007
1009
  f"Epoch {epoch:05d}: {self.monitor} improved from {self.best:.5f} to {self.current:.5f}, "
1008
1010
  f"saving model to {self.filepath}"
1009
1011
  )
wandb/integration/metaflow/metaflow.py CHANGED
@@ -73,8 +73,8 @@ try:
73
73
  wandb.termlog(f"Logging artifact: {name} ({type(data)})")
74
74
 
75
75
  except ImportError:
76
- print(
77
- "Warning: `pandas` not installed >> @wandb_log(datasets=True) may not auto log your dataset!"
76
+ wandb.termwarn(
77
+ "`pandas` not installed >> @wandb_log(datasets=True) may not auto log your dataset!"
78
78
  )
79
79
 
80
80
  try:
@@ -119,8 +119,8 @@ try:
119
119
  wandb.termlog(f"Logging artifact: {name} ({type(data)})")
120
120
 
121
121
  except ImportError:
122
- print(
123
- "Warning: `pytorch` not installed >> @wandb_log(models=True) may not auto log your model!"
122
+ wandb.termwarn(
123
+ "`pytorch` not installed >> @wandb_log(models=True) may not auto log your model!"
124
124
  )
125
125
 
126
126
  try:
@@ -164,8 +164,8 @@ try:
164
164
  wandb.termlog(f"Logging artifact: {name} ({type(data)})")
165
165
 
166
166
  except ImportError:
167
- print(
168
- "Warning: `sklearn` not installed >> @wandb_log(models=True) may not auto log your model!"
167
+ wandb.termwarn(
168
+ "`sklearn` not installed >> @wandb_log(models=True) may not auto log your model!"
169
169
  )
170
170
 
171
171
 
@@ -245,7 +245,7 @@ def wandb_use(name: str, data, *args, **kwargs):
245
245
  try:
246
246
  return _wandb_use(name, data, *args, **kwargs)
247
247
  except wandb.CommError:
248
- print(
248
+ wandb.termwarn(
249
249
  f"This artifact ({name}, {type(data)}) does not exist in the wandb datastore!"
250
250
  f"If you created an instance inline (e.g. sklearn.ensemble.RandomForestClassifier), then you can safely ignore this"
251
251
  f"Otherwise you may want to check your internet connection!"
wandb/integration/prodigy/prodigy.py CHANGED
@@ -237,11 +237,7 @@ def create_table(data):
237
237
  im = Image.open(urllib.request.urlopen(document["image"]))
238
238
  document["image_visual"] = wandb.Image(im)
239
239
  except urllib.error.URLError:
240
- print(
241
- "Warning: Image URL "
242
- + str(document["image"])
243
- + " is invalid."
244
- )
240
+ wandb.termwarn(f"Image URL {document['image']} is invalid.")
245
241
  document["image_visual"] = None
246
242
  elif isbase64:
247
243
  # is base64 uri
@@ -252,11 +248,7 @@ def create_table(data):
252
248
  im = Image.open(buf)
253
249
  document["image_visual"] = wandb.Image(im)
254
250
  except base64.binascii.Error:
255
- print(
256
- "Warning: Base64 string "
257
- + str(document["image"])
258
- + " is invalid."
259
- )
251
+ wandb.termwarn(f"Base64 string {document['image']} is invalid.")
260
252
  document["image_visual"] = None
261
253
  else:
262
254
  # is data path
@@ -296,4 +288,4 @@ def upload_dataset(dataset_name):
296
288
  standardize(data[i], schema, array_dict_types)
297
289
  table = create_table(data)
298
290
  wandb.log({dataset_name: table})
299
- print("Prodigy dataset `" + dataset_name + "` uploaded.")
291
+ wandb.termlog(f"Prodigy dataset `{dataset_name}` uploaded.")
wandb/integration/sagemaker/__init__.py CHANGED
@@ -1,12 +1,14 @@
1
1
  """wandb integration sagemaker module."""
2
2
 
3
3
  from .auth import sagemaker_auth
4
- from .config import parse_sm_config
5
- from .resources import parse_sm_resources, parse_sm_secrets
4
+ from .config import is_using_sagemaker, parse_sm_config
5
+ from .resources import parse_sm_secrets, set_global_settings, set_run_id
6
6
 
7
7
  __all__ = [
8
8
  "sagemaker_auth",
9
+ "is_using_sagemaker",
9
10
  "parse_sm_config",
10
11
  "parse_sm_secrets",
11
- "parse_sm_resources",
12
+ "set_global_settings",
13
+ "set_run_id",
12
14
  ]
wandb/integration/sagemaker/config.py CHANGED
@@ -1,13 +1,23 @@
1
+ from __future__ import annotations
2
+
1
3
  import json
2
4
  import os
3
5
  import re
4
6
  import warnings
5
- from typing import Any, Dict
7
+ from typing import Any
6
8
 
7
9
  from . import files as sm_files
8
10
 
9
11
 
10
- def parse_sm_config() -> Dict[str, Any]:
12
+ def is_using_sagemaker() -> bool:
13
+ """Returns whether we're in a SageMaker environment."""
14
+ return (
15
+ os.path.exists(sm_files.SM_PARAM_CONFIG) #
16
+ or "SM_TRAINING_ENV" in os.environ
17
+ )
18
+
19
+
20
+ def parse_sm_config() -> dict[str, Any]:
11
21
  """Parses SageMaker configuration.
12
22
 
13
23
  Returns:
@@ -23,9 +33,7 @@ def parse_sm_config() -> Dict[str, Any]:
23
33
  """
24
34
  conf = {}
25
35
 
26
- if os.path.exists(sm_files.SM_PARAM_CONFIG) and os.path.exists(
27
- sm_files.SM_RESOURCE_CONFIG
28
- ):
36
+ if os.path.exists(sm_files.SM_PARAM_CONFIG):
29
37
  conf["sagemaker_training_job_name"] = os.getenv("TRAINING_JOB_NAME")
30
38
 
31
39
  # Hyperparameter searches quote configs...
@@ -38,12 +46,13 @@ def parse_sm_config() -> Dict[str, Any]:
38
46
  cast = float(cast)
39
47
  conf[key] = cast
40
48
 
41
- if "SM_TRAINING_ENV" in os.environ:
49
+ if env := os.environ.get("SM_TRAINING_ENV"):
42
50
  try:
43
- conf = {**conf, **json.loads(os.environ["SM_TRAINING_ENV"])}
51
+ conf.update(json.loads(env))
44
52
  except json.JSONDecodeError:
45
53
  warnings.warn(
46
- "Failed to parse SM_TRAINING_ENV not valid JSON string", stacklevel=2
54
+ "Failed to parse SM_TRAINING_ENV not valid JSON string",
55
+ stacklevel=2,
47
56
  )
48
57
 
49
58
  return conf
wandb/integration/sagemaker/files.py CHANGED
@@ -1,3 +1,2 @@
1
1
  SM_PARAM_CONFIG = "/opt/ml/input/config/hyperparameters.json"
2
- SM_RESOURCE_CONFIG = "/opt/ml/input/config/resourceconfig.json"
3
2
  SM_SECRETS = "secrets.env"
wandb/integration/sagemaker/resources.py CHANGED
@@ -1,13 +1,58 @@
1
+ from __future__ import annotations
2
+
1
3
  import os
2
4
  import secrets
3
5
  import socket
4
6
  import string
5
- from typing import Dict, Tuple
6
7
 
8
+ import wandb
9
+
10
+ from . import config
7
11
  from . import files as sm_files
8
12
 
9
13
 
10
- def parse_sm_secrets() -> Dict[str, str]:
14
+ def set_run_id(run_settings: wandb.Settings) -> bool:
15
+ """Set a run ID and group when using SageMaker.
16
+
17
+ Returns whether the ID and group were updated.
18
+ """
19
+ # Added in https://github.com/wandb/wandb/pull/3290.
20
+ #
21
+ # Prevents SageMaker from overriding the run ID configured
22
+ # in environment variables. Note, however, that it will still
23
+ # override a run ID passed explicitly to `wandb.init()`.
24
+ if os.getenv("WANDB_RUN_ID"):
25
+ return False
26
+
27
+ run_group = os.getenv("TRAINING_JOB_NAME")
28
+ if not run_group:
29
+ return False
30
+
31
+ alphanumeric = string.ascii_lowercase + string.digits
32
+ random = "".join(secrets.choice(alphanumeric) for _ in range(6))
33
+
34
+ host = os.getenv("CURRENT_HOST", socket.gethostname())
35
+
36
+ run_settings.run_id = f"{run_group}-{random}-{host}"
37
+ run_settings.run_group = run_group
38
+ return True
39
+
40
+
41
+ def set_global_settings(settings: wandb.Settings) -> None:
42
+ """Set global W&B settings based on the SageMaker environment."""
43
+ if env := parse_sm_secrets():
44
+ settings.update_from_env_vars(env)
45
+
46
+ # The SageMaker config may contain an API key, in which case it
47
+ # takes precedence over the value in the secrets. It's unclear
48
+ # whether this is by design, or by accident; we keep it for
49
+ # backward compatibility for now.
50
+ sm_config = config.parse_sm_config()
51
+ if api_key := sm_config.get("wandb_api_key"):
52
+ settings.api_key = api_key
53
+
54
+
55
+ def parse_sm_secrets() -> dict[str, str]:
11
56
  """We read our api_key from secrets.env in SageMaker."""
12
57
  env_dict = dict()
13
58
  # Set secret variables
@@ -16,19 +61,3 @@ def parse_sm_secrets() -> Dict[str, str]:
16
61
  key, val = line.strip().split("=", 1)
17
62
  env_dict[key] = val
18
63
  return env_dict
19
-
20
-
21
- def parse_sm_resources() -> Tuple[Dict[str, str], Dict[str, str]]:
22
- run_dict = dict()
23
- run_id = os.getenv("TRAINING_JOB_NAME")
24
-
25
- if run_id and os.getenv("WANDB_RUN_ID") is None:
26
- suffix = "".join(
27
- secrets.choice(string.ascii_lowercase + string.digits) for _ in range(6)
28
- )
29
- run_dict["run_id"] = "-".join(
30
- [run_id, suffix, os.getenv("CURRENT_HOST", socket.gethostname())]
31
- )
32
- run_dict["run_group"] = os.getenv("TRAINING_JOB_NAME")
33
- env_dict = parse_sm_secrets()
34
- return run_dict, env_dict
wandb/integration/torch/wandb_torch.py CHANGED
@@ -441,7 +441,7 @@ class TorchGraph(wandb.data_types.Graph):
441
441
  decoder.weight encoder
442
442
  decoder.bias decoder
443
443
  """
444
- # TODO: We're currently not using this, but I left it here incase we want to resurrect! - CVP
444
+ # TODO: We're currently not using this, but I left it here in case we want to resurrect! - CVP
445
445
  torch = util.get_module("torch", "Could not import torch")
446
446
 
447
447
  module_nodes_by_hash = {id(n): n for n in module_graph.nodes}