wandb 0.17.3__py3-none-any.whl → 0.17.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (39) hide show
  1. wandb/__init__.py +1 -1
  2. wandb/apis/internal.py +4 -0
  3. wandb/cli/cli.py +7 -6
  4. wandb/env.py +16 -0
  5. wandb/filesync/upload_job.py +1 -1
  6. wandb/proto/v3/wandb_internal_pb2.py +339 -328
  7. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  8. wandb/proto/v4/wandb_internal_pb2.py +326 -323
  9. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  10. wandb/proto/v5/wandb_internal_pb2.py +326 -323
  11. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  12. wandb/sdk/artifacts/artifact.py +13 -24
  13. wandb/sdk/artifacts/artifact_file_cache.py +35 -13
  14. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +11 -6
  15. wandb/sdk/interface/interface.py +12 -5
  16. wandb/sdk/interface/interface_shared.py +9 -7
  17. wandb/sdk/internal/handler.py +1 -1
  18. wandb/sdk/internal/internal_api.py +67 -14
  19. wandb/sdk/internal/sender.py +9 -2
  20. wandb/sdk/launch/agent/agent.py +3 -1
  21. wandb/sdk/launch/builder/kaniko_builder.py +30 -9
  22. wandb/sdk/launch/inputs/internal.py +79 -2
  23. wandb/sdk/launch/inputs/manage.py +21 -3
  24. wandb/sdk/launch/sweeps/scheduler.py +2 -0
  25. wandb/sdk/lib/_settings_toposort_generated.py +3 -0
  26. wandb/sdk/lib/credentials.py +141 -0
  27. wandb/sdk/lib/tracelog.py +2 -2
  28. wandb/sdk/wandb_init.py +12 -2
  29. wandb/sdk/wandb_login.py +6 -0
  30. wandb/sdk/wandb_manager.py +34 -21
  31. wandb/sdk/wandb_run.py +100 -75
  32. wandb/sdk/wandb_settings.py +13 -2
  33. wandb/sdk/wandb_setup.py +12 -13
  34. wandb/util.py +29 -11
  35. {wandb-0.17.3.dist-info → wandb-0.17.5.dist-info}/METADATA +1 -1
  36. {wandb-0.17.3.dist-info → wandb-0.17.5.dist-info}/RECORD +39 -38
  37. {wandb-0.17.3.dist-info → wandb-0.17.5.dist-info}/WHEEL +0 -0
  38. {wandb-0.17.3.dist-info → wandb-0.17.5.dist-info}/entry_points.txt +0 -0
  39. {wandb-0.17.3.dist-info → wandb-0.17.5.dist-info}/licenses/LICENSE +0 -0
@@ -263,11 +263,17 @@ class KanikoBuilder(AbstractBuilder):
263
263
  repo_uri = await self.registry.get_repo_uri()
264
264
  image_uri = repo_uri + ":" + image_tag
265
265
 
266
- if (
267
- not launch_project.build_required()
268
- and await self.registry.check_image_exists(image_uri)
269
- ):
270
- return image_uri
266
+ # The DOCKER_CONFIG_SECRET option is mutually exclusive with the
267
+ # registry classes, so we must skip the check for image existence in
268
+ # that case.
269
+ if not launch_project.build_required():
270
+ if DOCKER_CONFIG_SECRET:
271
+ wandb.termlog(
272
+ f"Skipping check for existing image {image_uri} due to custom dockerconfig."
273
+ )
274
+ else:
275
+ if await self.registry.check_image_exists(image_uri):
276
+ return image_uri
271
277
 
272
278
  _logger.info(f"Building image {image_uri}...")
273
279
  _, api_client = await get_kube_context_and_api_client(
@@ -286,7 +292,12 @@ class KanikoBuilder(AbstractBuilder):
286
292
  wandb.termlog(f"{LOG_PREFIX}Created kaniko job {build_job_name}")
287
293
 
288
294
  try:
289
- if isinstance(self.registry, AzureContainerRegistry):
295
+ # DOCKER_CONFIG_SECRET is a user provided dockerconfigjson. Skip our
296
+ # dockerconfig handling if it's set.
297
+ if (
298
+ isinstance(self.registry, AzureContainerRegistry)
299
+ and not DOCKER_CONFIG_SECRET
300
+ ):
290
301
  dockerfile_config_map = client.V1ConfigMap(
291
302
  metadata=client.V1ObjectMeta(
292
303
  name=f"docker-config-{build_job_name}"
@@ -344,7 +355,10 @@ class KanikoBuilder(AbstractBuilder):
344
355
  finally:
345
356
  wandb.termlog(f"{LOG_PREFIX}Cleaning up resources")
346
357
  try:
347
- if isinstance(self.registry, AzureContainerRegistry):
358
+ if (
359
+ isinstance(self.registry, AzureContainerRegistry)
360
+ and not DOCKER_CONFIG_SECRET
361
+ ):
348
362
  await core_v1.delete_namespaced_config_map(
349
363
  f"docker-config-{build_job_name}", "wandb"
350
364
  )
@@ -498,7 +512,10 @@ class KanikoBuilder(AbstractBuilder):
498
512
  "readOnly": True,
499
513
  }
500
514
  )
501
- if isinstance(self.registry, AzureContainerRegistry):
515
+ if (
516
+ isinstance(self.registry, AzureContainerRegistry)
517
+ and not DOCKER_CONFIG_SECRET
518
+ ):
502
519
  # Add the docker config map
503
520
  volumes.append(
504
521
  {
@@ -533,7 +550,11 @@ class KanikoBuilder(AbstractBuilder):
533
550
  # Apply the rest of our defaults
534
551
  pod_labels["wandb"] = "launch"
535
552
  # This annotation is required to enable azure workload identity.
536
- if isinstance(self.registry, AzureContainerRegistry):
553
+ # Don't add this label if using a docker config secret for auth.
554
+ if (
555
+ isinstance(self.registry, AzureContainerRegistry)
556
+ and not DOCKER_CONFIG_SECRET
557
+ ):
537
558
  pod_labels["azure.workload.identity/use"] = "true"
538
559
  pod_spec["restartPolicy"] = pod_spec.get("restartPolicy", "Never")
539
560
  pod_spec["activeDeadlineSeconds"] = pod_spec.get(
@@ -11,7 +11,7 @@ import os
11
11
  import pathlib
12
12
  import shutil
13
13
  import tempfile
14
- from typing import List, Optional
14
+ from typing import Any, Dict, List, Optional
15
15
 
16
16
  import wandb
17
17
  import wandb.data_types
@@ -62,11 +62,13 @@ class JobInputArguments:
62
62
  self,
63
63
  include: Optional[List[str]] = None,
64
64
  exclude: Optional[List[str]] = None,
65
+ schema: Optional[dict] = None,
65
66
  file_path: Optional[str] = None,
66
67
  run_config: Optional[bool] = None,
67
68
  ):
68
69
  self.include = include
69
70
  self.exclude = exclude
71
+ self.schema = schema
70
72
  self.file_path = file_path
71
73
  self.run_config = run_config
72
74
 
@@ -121,15 +123,66 @@ def _publish_job_input(
121
123
  exclude_paths=[_split_on_unesc_dot(path) for path in input.exclude]
122
124
  if input.exclude
123
125
  else [],
126
+ input_schema=input.schema,
124
127
  run_config=input.run_config,
125
128
  file_path=input.file_path or "",
126
129
  )
127
130
 
128
131
 
132
def _replace_refs_and_allofs(schema: dict, defs: dict) -> dict:
    """Recursively fix JSON schemas with common issues.

    1. Replaces any instances of $ref with their associated definition in defs
    2. Removes any "allOf" lists that only have one item, "lifting" the item up
    See test_internal.py for examples

    ``defs`` is only read, never mutated. A definition that is referenced from
    several places (e.g. two fields sharing one nested model type) is
    therefore inlined at every reference. Keys that sit next to a ``$ref``
    (such as a field's ``description``) are preserved and take precedence
    over the keys of the inlined definition.

    Args:
        schema: The (sub)schema to rewrite. It is not modified.
        defs: Mapping of definition name to schema, i.e. the contents of the
            top-level ``$defs`` section.

    Returns:
        A new schema with references inlined and single-item ``allOf``
        lists lifted.

    Raises:
        KeyError: If a ``$ref`` names a definition that is not in ``defs``.
        ValueError: If a definition refers back to itself (directly or
            indirectly); recursive schemas cannot be inlined.
    """
    return _inline_schema_refs(schema, defs, ())


def _inline_schema_refs(schema: dict, defs: dict, expanding: tuple) -> dict:
    """Worker for _replace_refs_and_allofs; `expanding` is the chain of refs being inlined."""
    ret: Dict[str, Any] = {}
    if "$ref" in schema:
        # Reference found, replace it with its definition.
        def_key = schema["$ref"].split("#/$defs/")[1]
        if def_key in expanding:
            # Without this guard a self-referencing model would recurse forever.
            raise ValueError(
                f"Cannot inline recursive schema definition {def_key!r}."
            )
        # Also resolve refs nested inside the definition itself.
        ret.update(
            _inline_schema_refs(defs[def_key], defs, expanding + (def_key,))
        )
    for key, val in schema.items():
        if key == "$ref":
            # Already replaced by its definition above.
            continue
        if isinstance(val, dict):
            # Step into dicts recursively.
            ret[key] = _inline_schema_refs(val, defs, expanding)
        elif isinstance(val, list):
            # Step into each dict item in the list; keep other items as-is.
            new_val_list = [
                _inline_schema_refs(item, defs, expanding)
                if isinstance(item, dict)
                else item
                for item in val
            ]
            # Lift up allOf blocks with only one item.
            if (
                key == "allOf"
                and len(new_val_list) == 1
                and isinstance(new_val_list[0], dict)
            ):
                ret.update(new_val_list[0])
            else:
                ret[key] = new_val_list
        else:
            # For anything else (str, int, etc) keep it as-is.
            ret[key] = val
    return ret


def _convert_pydantic_model_to_jsonschema(model: Any) -> dict:
    """Return the JSON schema of a Pydantic model class or instance, refs inlined.

    Pydantic's ``model_json_schema()`` only emits a ``$defs`` section when the
    model has nested models. A missing section is treated like an empty one.
    """
    schema = model.model_json_schema()
    defs = schema.pop("$defs", None)
    if not defs:
        return schema
    return _replace_refs_and_allofs(schema, defs)
179
+
180
+
129
181
  def handle_config_file_input(
130
182
  path: str,
131
183
  include: Optional[List[str]] = None,
132
184
  exclude: Optional[List[str]] = None,
185
+ schema: Optional[Any] = None,
133
186
  ):
134
187
  """Declare an overridable configuration file for a launch job.
135
188
 
@@ -151,9 +204,20 @@ def handle_config_file_input(
151
204
  path,
152
205
  dest,
153
206
  )
207
+ # This supports both an instance of a pydantic BaseModel class (e.g. schema=MySchema(...))
208
+ # or the BaseModel class itself (e.g. schema=MySchema)
209
+ if hasattr(schema, "model_json_schema") and callable(
210
+ schema.model_json_schema # type: ignore
211
+ ):
212
+ schema = _convert_pydantic_model_to_jsonschema(schema)
213
+ if schema and not isinstance(schema, dict):
214
+ raise LaunchError(
215
+ "schema must be a dict, Pydantic model instance, or Pydantic model class."
216
+ )
154
217
  arguments = JobInputArguments(
155
218
  include=include,
156
219
  exclude=exclude,
220
+ schema=schema,
157
221
  file_path=path,
158
222
  run_config=False,
159
223
  )
@@ -165,7 +229,9 @@ def handle_config_file_input(
165
229
 
166
230
 
167
231
  def handle_run_config_input(
168
- include: Optional[List[str]] = None, exclude: Optional[List[str]] = None
232
+ include: Optional[List[str]] = None,
233
+ exclude: Optional[List[str]] = None,
234
+ schema: Optional[Any] = None,
169
235
  ):
170
236
  """Declare wandb.config as an overridable configuration for a launch job.
171
237
 
@@ -175,9 +241,20 @@ def handle_run_config_input(
175
241
  If there is no active run, the include and exclude paths are staged and sent
176
242
  when a run is created.
177
243
  """
244
+ # This supports both an instance of a pydantic BaseModel class (e.g. schema=MySchema(...))
245
+ # or the BaseModel class itself (e.g. schema=MySchema)
246
+ if hasattr(schema, "model_json_schema") and callable(
247
+ schema.model_json_schema # type: ignore
248
+ ):
249
+ schema = _convert_pydantic_model_to_jsonschema(schema)
250
+ if schema and not isinstance(schema, dict):
251
+ raise LaunchError(
252
+ "schema must be a dict, Pydantic model instance, or Pydantic model class."
253
+ )
178
254
  arguments = JobInputArguments(
179
255
  include=include,
180
256
  exclude=exclude,
257
+ schema=schema,
181
258
  run_config=True,
182
259
  file_path=None,
183
260
  )
@@ -1,12 +1,13 @@
1
1
  """Functions for declaring overridable configuration for launch jobs."""
2
2
 
3
- from typing import List, Optional
3
+ from typing import Any, List, Optional
4
4
 
5
5
 
6
6
  def manage_config_file(
7
7
  path: str,
8
8
  include: Optional[List[str]] = None,
9
9
  exclude: Optional[List[str]] = None,
10
+ schema: Optional[Any] = None,
10
11
  ):
11
12
  r"""Declare an overridable configuration file for a launch job.
12
13
 
@@ -43,18 +44,27 @@ def manage_config_file(
43
44
  relative and must not contain backwards traversal, i.e. `..`.
44
45
  include (List[str]): A list of keys to include in the configuration file.
45
46
  exclude (List[str]): A list of keys to exclude from the configuration file.
47
+ schema (dict | Pydantic model): A JSON Schema or Pydantic model describing
48
+ which attributes will be editable from the Launch drawer.
49
+ Accepts either an instance of a Pydantic BaseModel class or the BaseModel
50
+ class itself.
46
51
 
47
52
  Raises:
48
53
  LaunchError: If the path is not valid, or if there is no active run.
49
54
  """
55
+ # note: schema's Any type is because in the case where a BaseModel class is
56
+ # provided, its type is a pydantic internal type that we don't want our typing
57
+ # to depend on. schema's type should be considered
58
+ # "Optional[dict | <something with a .model_json_schema() method>]"
50
59
  from .internal import handle_config_file_input
51
60
 
52
- return handle_config_file_input(path, include, exclude)
61
+ return handle_config_file_input(path, include, exclude, schema)
53
62
 
54
63
 
55
64
  def manage_wandb_config(
56
65
  include: Optional[List[str]] = None,
57
66
  exclude: Optional[List[str]] = None,
67
+ schema: Optional[Any] = None,
58
68
  ):
59
69
  r"""Declare wandb.config as an overridable configuration for a launch job.
60
70
 
@@ -86,10 +96,18 @@ def manage_wandb_config(
86
96
  Args:
87
97
  include (List[str]): A list of subtrees to include in the configuration.
88
98
  exclude (List[str]): A list of subtrees to exclude from the configuration.
99
+ schema (dict | Pydantic model): A JSON Schema or Pydantic model describing
100
+ which attributes will be editable from the Launch drawer.
101
+ Accepts either an instance of a Pydantic BaseModel class or the BaseModel
102
+ class itself.
89
103
 
90
104
  Raises:
91
105
  LaunchError: If there is no active run.
92
106
  """
107
+ # note: schema's Any type is because in the case where a BaseModel class is
108
+ # provided, its type is a pydantic internal type that we don't want our typing
109
+ # to depend on. schema's type should be considered
110
+ # "Optional[dict | <something with a .model_json_schema() method>]"
93
111
  from .internal import handle_run_config_input
94
112
 
95
- handle_run_config_input(include, exclude)
113
+ handle_run_config_input(include, exclude, schema)
@@ -259,10 +259,12 @@ class Scheduler(ABC):
259
259
 
260
260
  def _init_wandb_run(self) -> "SdkRun":
261
261
  """Controls resume or init logic for a scheduler wandb run."""
262
+ settings = wandb.Settings(disable_job_creation=True)
262
263
  run: SdkRun = wandb.init( # type: ignore
263
264
  name=f"Scheduler.{self._sweep_id}",
264
265
  resume="allow",
265
266
  config=self._kwargs, # when run as a job, this sets config
267
+ settings=settings,
266
268
  )
267
269
  return run
268
270
 
@@ -26,6 +26,7 @@ _Setting = Literal[
26
26
  "_disable_machine_info",
27
27
  "_executable",
28
28
  "_extra_http_headers",
29
+ "_file_stream_max_bytes",
29
30
  "_file_stream_retry_max",
30
31
  "_file_stream_retry_wait_min_seconds",
31
32
  "_file_stream_retry_wait_max_seconds",
@@ -91,6 +92,7 @@ _Setting = Literal[
91
92
  "config_paths",
92
93
  "console",
93
94
  "console_multipart",
95
+ "credentials_file",
94
96
  "deployment",
95
97
  "disable_code",
96
98
  "disable_git",
@@ -112,6 +114,7 @@ _Setting = Literal[
112
114
  "host",
113
115
  "http_proxy",
114
116
  "https_proxy",
117
+ "identity_token_file",
115
118
  "ignore_globs",
116
119
  "init_timeout",
117
120
  "is_local",
@@ -0,0 +1,141 @@
1
+ import json
2
+ import os
3
+ from datetime import datetime, timedelta
4
+ from pathlib import Path
5
+
6
+ import requests.utils
7
+
8
+ from wandb.errors import AuthenticationError
9
+
10
# Default location of the file caching short-lived access tokens, under the
# user's home directory.
DEFAULT_WANDB_CREDENTIALS_FILE = Path(
    os.path.expanduser(os.path.join("~", ".config", "wandb", "credentials.json"))
)

# strftime/strptime format of the "expires_at" timestamps stored in the
# credentials file.
_expires_at_fmt = "%Y-%m-%d %H:%M:%S"
15
+
16
+
17
def access_token(base_url: str, token_file: Path, credentials_file: Path) -> str:
    """Return an access token for ``base_url``, creating one if needed.

    On first use the credentials file does not exist yet. It is then created
    by exchanging the identity token stored in ``token_file`` with the server.
    The token is then read back from the credentials file, which is where an
    expired token gets refreshed.

    Args:
        base_url (str): The base URL of the server
        token_file (pathlib.Path): The path to the file containing the
            identity token
        credentials_file (pathlib.Path): The path to file used to save
            temporary access tokens

    Returns:
        str: The access token
    """
    needs_bootstrap = not credentials_file.exists()
    if needs_bootstrap:
        _write_credentials_file(base_url, token_file, credentials_file)

    credentials = _fetch_credentials(base_url, token_file, credentials_file)
    return credentials["access_token"]
38
+
39
+
40
def _write_credentials_file(base_url: str, token_file: Path, credentials_file: Path):
    """Obtain an access token from the server and write it to the credentials file.

    The file is created with owner-only (0o600) permissions from the start, so
    the token is never readable by other users, not even briefly. Missing
    parent directories are created first.

    Args:
        base_url (str): The base URL of the server
        token_file (pathlib.Path): The path to the file containing the
            identity token
        credentials_file (pathlib.Path): The path to file used to save
            temporary access tokens
    """
    credentials = _create_access_token(base_url, token_file)
    data = {"credentials": {base_url: credentials}}

    credentials_file = Path(credentials_file)
    # The default location (~/.config/wandb) may not exist on a fresh machine.
    credentials_file.parent.mkdir(parents=True, exist_ok=True)

    # Open with an explicit creation mode instead of open(..., "w") followed by
    # chmod, which leaves a window where the token has default permissions.
    fd = os.open(credentials_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as file:
        json.dump(data, file, indent=4)

    # O_CREAT's mode does not apply to a pre-existing file, so still enforce
    # read/write by the owner only.
    os.chmod(credentials_file, 0o600)
57
+
58
+
59
def _fetch_credentials(base_url: str, token_file: Path, credentials_file: Path) -> dict:
    """Read the cached credentials for ``base_url``, refreshing them if expired.

    Credentials with no recorded expiry, or no entry at all for ``base_url``,
    are treated as expired. In that case a new token is requested and the
    whole credentials file is rewritten with the updated entry.

    Args:
        base_url (str): The base URL of the server
        token_file (pathlib.Path): The path to the file containing the
            identity token
        credentials_file (pathlib.Path): The path to file used to save
            temporary access tokens

    Returns:
        dict: The credentials including the access token.
    """
    with open(credentials_file) as file:
        data = json.load(file)

    if "credentials" not in data:
        data["credentials"] = {}
    known = data["credentials"]
    creds = known[base_url] if base_url in known else {}

    if "expires_at" in creds:
        stored_expiry = datetime.strptime(creds["expires_at"], _expires_at_fmt)
        expired = stored_expiry <= datetime.utcnow()
    else:
        # No expiry recorded: always refresh.
        expired = True

    if expired:
        creds = _create_access_token(base_url, token_file)
        data["credentials"][base_url] = creds
        with open(credentials_file, "w") as file:
            json.dump(data, file, indent=4)

    return creds
94
+
95
+
96
def _create_access_token(base_url: str, token_file: Path) -> dict:
    """Exchange an identity token for an access token from the server.

    The identity token is read from ``token_file`` and posted to
    ``{base_url}/oidc/token`` using the JWT-bearer grant type.

    Args:
        base_url (str): The base URL of the server.
        token_file (pathlib.Path): The path to the file containing the
            identity token

    Returns:
        dict: The server's JSON response. Its relative ``expires_in``
            (seconds) is replaced by an absolute ``expires_at`` timestamp
            (UTC, formatted with ``_expires_at_fmt``).

    Raises:
        FileNotFoundError: If the token file is not found.
        OSError: If there is an issue reading the token file.
        AuthenticationError: If the server fails to provide an access token.
            This covers a non-200 status and a response body without a usable
            ``access_token`` and ``expires_in``.
    """
    try:
        with open(token_file) as file:
            token = file.read().strip()
    except FileNotFoundError as e:
        raise FileNotFoundError(f"Identity token file not found: {token_file}") from e
    except OSError as e:
        raise OSError(
            f"Failed to read the identity token from file: {token_file}"
        ) from e

    url = f"{base_url}/oidc/token"
    data = {
        "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
        "assertion": token,
    }
    headers = {"Content-Type": "application/x-www-form-urlencoded"}

    response = requests.post(url, data=data, headers=headers)

    if response.status_code != 200:
        raise AuthenticationError(
            f"Failed to retrieve access token: {response.status_code}, {response.text}"
        )

    # A 200 status alone does not guarantee a usable payload. Validate it so
    # callers get the documented AuthenticationError instead of a bare
    # KeyError/ValueError. The body is deliberately not echoed: it may contain
    # token material.
    invalid_response = (
        "Failed to retrieve access token: the server response did not include "
        "a valid access_token and expires_in."
    )
    try:
        resp_json = response.json()
    except ValueError as e:
        raise AuthenticationError(invalid_response) from e
    if (
        not isinstance(resp_json, dict)
        or "access_token" not in resp_json
        or "expires_in" not in resp_json
    ):
        raise AuthenticationError(invalid_response)
    try:
        expires_in = float(resp_json.pop("expires_in"))
        expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
    except (TypeError, ValueError, OverflowError) as e:
        raise AuthenticationError(invalid_response) from e

    resp_json["expires_at"] = expires_at.strftime(_expires_at_fmt)
    return resp_json
wandb/sdk/lib/tracelog.py CHANGED
@@ -45,8 +45,8 @@ logger = logging.getLogger(__name__)
45
45
  ANNOTATE_QUEUE_NAME = "_DEBUGLOG_QUEUE_NAME"
46
46
 
47
47
  # capture stdout and stderr before anyone messes with them
48
- stdout_write = sys.__stdout__.write
49
- stderr_write = sys.__stderr__.write
48
+ stdout_write = sys.__stdout__.write # type: ignore
49
+ stderr_write = sys.__stderr__.write # type: ignore
50
50
 
51
51
 
52
52
  def _log(
wandb/sdk/wandb_init.py CHANGED
@@ -323,6 +323,15 @@ class _WandbInit:
323
323
  if save_code_pre_user_settings is False:
324
324
  settings.update({"save_code": False}, source=Source.INIT)
325
325
 
326
+ # TODO: remove this once we refactor the client. This is a temporary
327
+ # fix to make sure that we use the same project name for wandb-core.
328
+ # The reason this is not going through the settings object is to
329
+ # avoid failure cases in other parts of the code that will be
330
+ # removed with the switch to wandb-core.
331
+ if settings.project is None:
332
+ project = wandb.util.auto_project_name(settings.program)
333
+ settings.update({"project": project}, source=Source.INIT)
334
+
326
335
  # TODO(jhr): should this be moved? probably.
327
336
  settings._set_run_start_time(source=Source.INIT)
328
337
 
@@ -989,8 +998,9 @@ def init(
989
998
 
990
999
  Arguments:
991
1000
  project: (str, optional) The name of the project where you're sending
992
- the new run. If the project is not specified, the run is put in an
993
- "Uncategorized" project.
1001
+ the new run. If the project is not specified, we will try to infer
1002
+ the project name from git root or the current program file. If we
1003
+ can't infer the project name, we will default to `"uncategorized"`.
994
1004
  entity: (str, optional) An entity is a username or team name where
995
1005
  you're sending runs. This entity must exist before you can send runs
996
1006
  there, so make sure to create your account or team in the UI before
wandb/sdk/wandb_login.py CHANGED
@@ -156,6 +156,9 @@ class _WandbLogin:
156
156
  """Returns whether an API key is set or can be inferred."""
157
157
  return apikey.api_key(settings=self._settings) is not None
158
158
 
159
def should_use_identity_token(self) -> bool:
    """Return True when the settings name an identity token file."""
    token_file = self._settings.identity_token_file
    return token_file is not None
161
+
159
162
  def set_backend(self, backend):
160
163
  self._backend = backend
161
164
 
@@ -327,6 +330,9 @@ def _login(
327
330
  )
328
331
  return False
329
332
 
333
+ if wlogin.should_use_identity_token():
334
+ return True
335
+
330
336
  # perform a login
331
337
  logged_in = wlogin.login()
332
338
 
@@ -114,19 +114,24 @@ class _Manager:
114
114
 
115
115
  try:
116
116
  svc_iface._svc_connect(port=port)
117
+
117
118
  except ConnectionRefusedError as e:
118
119
  if not psutil.pid_exists(self._token.pid):
119
120
  message = (
120
- "Connection to wandb service failed "
121
- "since the process is not available. "
121
+ "Connection to wandb service failed"
122
+ " because the process is not available."
122
123
  )
123
124
  else:
124
- message = f"Connection to wandb service failed: {e}. "
125
- raise ManagerConnectionRefusedError(message)
125
+ message = "Connection to wandb service failed."
126
+ raise ManagerConnectionRefusedError(message) from e
127
+
126
128
  except Exception as e:
127
- raise ManagerConnectionError(f"Connection to wandb service failed: {e}")
129
+ raise ManagerConnectionError(
130
+ "Connection to wandb service failed.",
131
+ ) from e
128
132
 
129
133
  def __init__(self, settings: "Settings") -> None:
134
+ """Connects to the internal service, starting it if necessary."""
130
135
  from wandb.sdk.service import service
131
136
 
132
137
  self._settings = settings
@@ -134,6 +139,7 @@ class _Manager:
134
139
  self._hooks = None
135
140
 
136
141
  self._service = service._Service(settings=self._settings)
142
+
137
143
  token = _ManagerToken.from_environment()
138
144
  if not token:
139
145
  self._service.start()
@@ -144,7 +150,6 @@ class _Manager:
144
150
  token = _ManagerToken.from_params(transport=transport, host=host, port=port)
145
151
  token.set_environment()
146
152
  self._atexit_setup()
147
-
148
153
  self._token = token
149
154
 
150
155
  try:
@@ -152,6 +157,24 @@ class _Manager:
152
157
  except ManagerConnectionError as e:
153
158
  wandb._sentry.reraise(e)
154
159
 
160
+ def _teardown(self, exit_code: int) -> int:
161
+ """Shuts down the internal process and returns its exit code.
162
+
163
+ This sends a teardown record to the process. An exception is raised if
164
+ the process has already been shut down.
165
+ """
166
+ unregister_all_post_import_hooks()
167
+
168
+ if self._atexit_lambda:
169
+ atexit.unregister(self._atexit_lambda)
170
+ self._atexit_lambda = None
171
+
172
+ try:
173
+ self._inform_teardown(exit_code)
174
+ return self._service.join()
175
+ finally:
176
+ self._token.reset_environment()
177
+
155
178
  def _atexit_setup(self) -> None:
156
179
  self._atexit_lambda = lambda: self._atexit_teardown()
157
180
 
@@ -161,28 +184,18 @@ class _Manager:
161
184
 
162
185
  def _atexit_teardown(self) -> None:
163
186
  trigger.call("on_finished")
164
- exit_code = self._hooks.exit_code if self._hooks else 0
165
- self._teardown(exit_code)
166
187
 
167
- def _teardown(self, exit_code: int) -> None:
168
- unregister_all_post_import_hooks()
169
-
170
- if self._atexit_lambda:
171
- atexit.unregister(self._atexit_lambda)
172
- self._atexit_lambda = None
188
+ # Clear the atexit hook---we're executing it now, after which the
189
+ # process will exit.
190
+ self._atexit_lambda = None
173
191
 
174
192
  try:
175
- self._inform_teardown(exit_code)
176
- result = self._service.join()
177
- if result and not self._settings._notebook:
178
- os._exit(result)
193
+ self._teardown(self._hooks.exit_code if self._hooks else 0)
179
194
  except Exception as e:
180
195
  wandb.termlog(
181
- f"While tearing down the service manager. The following error has occurred: {e}",
196
+ f"Encountered an error while tearing down the service manager: {e}",
182
197
  repeat=False,
183
198
  )
184
- finally:
185
- self._token.reset_environment()
186
199
 
187
200
  def _get_service(self) -> "service._Service":
188
201
  return self._service