wandb 0.17.2__py3-none-win32.whl → 0.17.4__py3-none-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/apis/internal.py +4 -0
- wandb/apis/public/projects.py +0 -6
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +7 -6
- wandb/env.py +16 -0
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/sdk/artifacts/artifact.py +2 -0
- wandb/sdk/artifacts/artifact_file_cache.py +35 -13
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +11 -6
- wandb/sdk/internal/internal_api.py +68 -15
- wandb/sdk/launch/_launch_add.py +1 -1
- wandb/sdk/launch/agent/agent.py +3 -1
- wandb/sdk/launch/sweeps/scheduler.py +2 -0
- wandb/sdk/launch/utils.py +1 -1
- wandb/sdk/lib/_settings_toposort_generated.py +8 -0
- wandb/sdk/lib/apikey.py +16 -3
- wandb/sdk/lib/credentials.py +141 -0
- wandb/sdk/wandb_init.py +12 -2
- wandb/sdk/wandb_login.py +6 -0
- wandb/sdk/wandb_manager.py +25 -16
- wandb/sdk/wandb_run.py +5 -2
- wandb/sdk/wandb_settings.py +27 -3
- wandb/sdk/wandb_setup.py +12 -13
- wandb/sdk/wandb_sweep.py +4 -2
- {wandb-0.17.2.dist-info → wandb-0.17.4.dist-info}/METADATA +1 -1
- {wandb-0.17.2.dist-info → wandb-0.17.4.dist-info}/RECORD +32 -31
- {wandb-0.17.2.dist-info → wandb-0.17.4.dist-info}/WHEEL +0 -0
- {wandb-0.17.2.dist-info → wandb-0.17.4.dist-info}/entry_points.txt +0 -0
- {wandb-0.17.2.dist-info → wandb-0.17.4.dist-info}/licenses/LICENSE +0 -0
@@ -11,6 +11,7 @@ import socket
|
|
11
11
|
import sys
|
12
12
|
import threading
|
13
13
|
from copy import deepcopy
|
14
|
+
from pathlib import Path
|
14
15
|
from typing import (
|
15
16
|
IO,
|
16
17
|
TYPE_CHECKING,
|
@@ -37,14 +38,14 @@ from wandb_gql.client import RetryError
|
|
37
38
|
import wandb
|
38
39
|
from wandb import env, util
|
39
40
|
from wandb.apis.normalize import normalize_exceptions, parse_backend_error_messages
|
40
|
-
from wandb.errors import CommError, UnsupportedError, UsageError
|
41
|
+
from wandb.errors import AuthenticationError, CommError, UnsupportedError, UsageError
|
41
42
|
from wandb.integration.sagemaker import parse_sm_secrets
|
42
43
|
from wandb.old.settings import Settings
|
43
44
|
from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
|
44
45
|
from wandb.sdk.lib.gql_request import GraphQLSession
|
45
46
|
from wandb.sdk.lib.hashutil import B64MD5, md5_file_b64
|
46
47
|
|
47
|
-
from ..lib import retry
|
48
|
+
from ..lib import credentials, retry
|
48
49
|
from ..lib.filenames import DIFF_FNAME, METADATA_FNAME
|
49
50
|
from ..lib.gitlib import GitRepo
|
50
51
|
from . import context
|
@@ -234,14 +235,18 @@ class Api:
|
|
234
235
|
extra_http_headers = self.settings("_extra_http_headers") or json.loads(
|
235
236
|
self._environ.get("WANDB__EXTRA_HTTP_HEADERS", "{}")
|
236
237
|
)
|
238
|
+
extra_http_headers.update(_thread_local_api_settings.headers or {})
|
239
|
+
|
240
|
+
auth = None
|
241
|
+
if self.access_token is not None:
|
242
|
+
extra_http_headers["Authorization"] = f"Bearer {self.access_token}"
|
243
|
+
elif _thread_local_api_settings.cookies is None:
|
244
|
+
auth = ("api", self.api_key or "")
|
245
|
+
|
237
246
|
proxies = self.settings("_proxies") or json.loads(
|
238
247
|
self._environ.get("WANDB__PROXIES", "{}")
|
239
248
|
)
|
240
249
|
|
241
|
-
auth = None
|
242
|
-
if _thread_local_api_settings.cookies is None:
|
243
|
-
auth = ("api", self.api_key or "")
|
244
|
-
extra_http_headers.update(_thread_local_api_settings.headers or {})
|
245
250
|
self.client = Client(
|
246
251
|
transport=GraphQLSession(
|
247
252
|
headers={
|
@@ -376,6 +381,35 @@ class Api:
|
|
376
381
|
default_key: Optional[str] = self.default_settings.get("api_key")
|
377
382
|
return env_key or key or sagemaker_key or default_key
|
378
383
|
|
384
|
+
@property
|
385
|
+
def access_token(self) -> Optional[str]:
|
386
|
+
"""Retrieves an access token for authentication.
|
387
|
+
|
388
|
+
This function attempts to exchange an identity token for a temporary
|
389
|
+
access token from the server, and save it to the credentials file.
|
390
|
+
It uses the path to the identity token as defined in the environment
|
391
|
+
variables. If the environment variable is not set, it returns None.
|
392
|
+
|
393
|
+
Returns:
|
394
|
+
Optional[str]: The access token if available, otherwise None if
|
395
|
+
no identity token is supplied.
|
396
|
+
Raises:
|
397
|
+
AuthenticationError: If the path to the identity token is not found.
|
398
|
+
"""
|
399
|
+
token_file_str = self._environ.get(env.IDENTITY_TOKEN_FILE)
|
400
|
+
if not token_file_str:
|
401
|
+
return None
|
402
|
+
|
403
|
+
token_file = Path(token_file_str)
|
404
|
+
if not token_file.exists():
|
405
|
+
raise AuthenticationError(f"Identity token file not found: {token_file}")
|
406
|
+
|
407
|
+
base_url = self.settings("base_url")
|
408
|
+
credentials_file = env.get_credentials_file(
|
409
|
+
str(credentials.DEFAULT_WANDB_CREDENTIALS_FILE), self._environ
|
410
|
+
)
|
411
|
+
return credentials.access_token(base_url, token_file, credentials_file)
|
412
|
+
|
379
413
|
@property
|
380
414
|
def api_url(self) -> str:
|
381
415
|
return self.settings("base_url") # type: ignore
|
@@ -2674,14 +2708,20 @@ class Api:
|
|
2674
2708
|
A tuple of the content length and the streaming response
|
2675
2709
|
"""
|
2676
2710
|
check_httpclient_logger_handler()
|
2711
|
+
|
2712
|
+
http_headers = _thread_local_api_settings.headers or {}
|
2713
|
+
|
2677
2714
|
auth = None
|
2678
|
-
if _thread_local_api_settings.cookies is None:
|
2679
|
-
|
2715
|
+
if self.access_token is not None:
|
2716
|
+
http_headers["Authorization"] = f"Bearer {self.access_token}"
|
2717
|
+
elif _thread_local_api_settings.cookies is None:
|
2718
|
+
auth = ("api", self.api_key or "")
|
2719
|
+
|
2680
2720
|
response = requests.get(
|
2681
2721
|
url,
|
2682
2722
|
auth=auth,
|
2683
2723
|
cookies=_thread_local_api_settings.cookies or {},
|
2684
|
-
headers=_thread_local_api_settings.headers or {},
|
2724
|
+
headers=http_headers,
|
2685
2725
|
stream=True,
|
2686
2726
|
)
|
2687
2727
|
response.raise_for_status()
|
@@ -2780,9 +2820,9 @@ class Api:
|
|
2780
2820
|
except requests.exceptions.RequestException as e:
|
2781
2821
|
logger.error(f"upload_file exception {url}: {e}")
|
2782
2822
|
request_headers = e.request.headers if e.request is not None else ""
|
2783
|
-
logger.error(f"upload_file request headers: {request_headers}")
|
2823
|
+
logger.error(f"upload_file request headers: {request_headers!r}")
|
2784
2824
|
response_content = e.response.content if e.response is not None else ""
|
2785
|
-
logger.error(f"upload_file response body: {response_content}")
|
2825
|
+
logger.error(f"upload_file response body: {response_content!r}")
|
2786
2826
|
status_code = e.response.status_code if e.response is not None else 0
|
2787
2827
|
# S3 reports retryable request timeouts out-of-band
|
2788
2828
|
is_aws_retryable = status_code == 400 and "RequestTimeout" in str(
|
@@ -2848,7 +2888,7 @@ class Api:
|
|
2848
2888
|
request_headers = e.request.headers if e.request is not None else ""
|
2849
2889
|
logger.error(f"upload_file request headers: {request_headers}")
|
2850
2890
|
response_content = e.response.content if e.response is not None else ""
|
2851
|
-
logger.error(f"upload_file response body: {response_content}")
|
2891
|
+
logger.error(f"upload_file response body: {response_content!r}")
|
2852
2892
|
status_code = e.response.status_code if e.response is not None else 0
|
2853
2893
|
# S3 reports retryable request timeouts out-of-band
|
2854
2894
|
is_aws_retryable = (
|
@@ -3039,6 +3079,7 @@ class Api:
|
|
3039
3079
|
entity: Optional[str] = None,
|
3040
3080
|
state: Optional[str] = None,
|
3041
3081
|
prior_runs: Optional[List[str]] = None,
|
3082
|
+
template_variable_values: Optional[Dict[str, Any]] = None,
|
3042
3083
|
) -> Tuple[str, List[str]]:
|
3043
3084
|
"""Upsert a sweep object.
|
3044
3085
|
|
@@ -3052,6 +3093,7 @@ class Api:
|
|
3052
3093
|
entity (str): entity to use
|
3053
3094
|
state (str): state
|
3054
3095
|
prior_runs (list): IDs of existing runs to add to the sweep
|
3096
|
+
template_variable_values (dict): template variable values
|
3055
3097
|
"""
|
3056
3098
|
project_query = """
|
3057
3099
|
project {
|
@@ -3096,7 +3138,17 @@ class Api:
|
|
3096
3138
|
"""
|
3097
3139
|
# TODO(jhr): we need protocol versioning to know schema is not supported
|
3098
3140
|
# for now we will just try both new and old query
|
3099
|
-
|
3141
|
+
mutation_5 = gql(
|
3142
|
+
mutation_str.replace(
|
3143
|
+
"$controller: JSONString,",
|
3144
|
+
"$controller: JSONString,$launchScheduler: JSONString, $templateVariableValues: JSONString,",
|
3145
|
+
)
|
3146
|
+
.replace(
|
3147
|
+
"controller: $controller,",
|
3148
|
+
"controller: $controller,launchScheduler: $launchScheduler,templateVariableValues: $templateVariableValues,",
|
3149
|
+
)
|
3150
|
+
.replace("_PROJECT_QUERY_", project_query)
|
3151
|
+
)
|
3100
3152
|
# launchScheduler was introduced in core v0.14.0
|
3101
3153
|
mutation_4 = gql(
|
3102
3154
|
mutation_str.replace(
|
@@ -3105,7 +3157,7 @@ class Api:
|
|
3105
3157
|
)
|
3106
3158
|
.replace(
|
3107
3159
|
"controller: $controller,",
|
3108
|
-
"controller: $controller,launchScheduler: $launchScheduler
|
3160
|
+
"controller: $controller,launchScheduler: $launchScheduler",
|
3109
3161
|
)
|
3110
3162
|
.replace("_PROJECT_QUERY_", project_query)
|
3111
3163
|
)
|
@@ -3124,7 +3176,7 @@ class Api:
|
|
3124
3176
|
)
|
3125
3177
|
|
3126
3178
|
# TODO(dag): replace this with a query for protocol versioning
|
3127
|
-
mutations = [mutation_4, mutation_3, mutation_2, mutation_1]
|
3179
|
+
mutations = [mutation_5, mutation_4, mutation_3, mutation_2, mutation_1]
|
3128
3180
|
|
3129
3181
|
config = self._validate_config_and_fill_distribution(config)
|
3130
3182
|
|
@@ -3148,6 +3200,7 @@ class Api:
|
|
3148
3200
|
"projectName": project or self.settings("project"),
|
3149
3201
|
"controller": controller,
|
3150
3202
|
"launchScheduler": launch_scheduler,
|
3203
|
+
"templateVariableValues": json.dumps(template_variable_values),
|
3151
3204
|
"scheduler": scheduler,
|
3152
3205
|
"priorRunsFilters": filters,
|
3153
3206
|
}
|
wandb/sdk/launch/_launch_add.py
CHANGED
@@ -61,7 +61,7 @@ def launch_add(
|
|
61
61
|
config: A dictionary containing the configuration for the run. May also contain
|
62
62
|
resource specific arguments under the key "resource_args"
|
63
63
|
template_variables: A dictionary containing values of template variables for a run queue.
|
64
|
-
Expected format of {"
|
64
|
+
Expected format of {"VAR_NAME": VAR_VALUE}
|
65
65
|
project: Target project to send launched run to
|
66
66
|
entity: Target entity to send launched run to
|
67
67
|
queue: the name of the queue to enqueue the run to
|
wandb/sdk/launch/agent/agent.py
CHANGED
@@ -308,7 +308,9 @@ class LaunchAgent:
|
|
308
308
|
self._wandb_run = None
|
309
309
|
|
310
310
|
if self.gorilla_supports_agents:
|
311
|
-
settings = wandb.Settings(
|
311
|
+
settings = wandb.Settings(
|
312
|
+
silent=True, disable_git=True, disable_job_creation=True
|
313
|
+
)
|
312
314
|
self._wandb_run = wandb.init(
|
313
315
|
project=self._project,
|
314
316
|
entity=self._entity,
|
@@ -259,10 +259,12 @@ class Scheduler(ABC):
|
|
259
259
|
|
260
260
|
def _init_wandb_run(self) -> "SdkRun":
|
261
261
|
"""Controls resume or init logic for a scheduler wandb run."""
|
262
|
+
settings = wandb.Settings(disable_job_creation=True)
|
262
263
|
run: SdkRun = wandb.init( # type: ignore
|
263
264
|
name=f"Scheduler.{self._sweep_id}",
|
264
265
|
resume="allow",
|
265
266
|
config=self._kwargs, # when run as a job, this sets config
|
267
|
+
settings=settings,
|
266
268
|
)
|
267
269
|
return run
|
268
270
|
|
wandb/sdk/launch/utils.py
CHANGED
@@ -60,7 +60,7 @@ AZURE_CONTAINER_REGISTRY_URI_REGEX = re.compile(
|
|
60
60
|
)
|
61
61
|
|
62
62
|
ELASTIC_CONTAINER_REGISTRY_URI_REGEX = re.compile(
|
63
|
-
r"^(?:https://)?(?P<account>[\w-]+)\.dkr\.ecr\.(?P<region>[\w-]+)\.amazonaws\.com/(?P<repository>[
|
63
|
+
r"^(?:https://)?(?P<account>[\w-]+)\.dkr\.ecr\.(?P<region>[\w-]+)\.amazonaws\.com/(?P<repository>[\.\/\w-]+):?(?P<tag>.*)$"
|
64
64
|
)
|
65
65
|
|
66
66
|
GCP_ARTIFACT_REGISTRY_URI_REGEX = re.compile(
|
@@ -26,6 +26,7 @@ _Setting = Literal[
|
|
26
26
|
"_disable_machine_info",
|
27
27
|
"_executable",
|
28
28
|
"_extra_http_headers",
|
29
|
+
"_file_stream_max_bytes",
|
29
30
|
"_file_stream_retry_max",
|
30
31
|
"_file_stream_retry_wait_min_seconds",
|
31
32
|
"_file_stream_retry_wait_max_seconds",
|
@@ -91,6 +92,7 @@ _Setting = Literal[
|
|
91
92
|
"config_paths",
|
92
93
|
"console",
|
93
94
|
"console_multipart",
|
95
|
+
"credentials_file",
|
94
96
|
"deployment",
|
95
97
|
"disable_code",
|
96
98
|
"disable_git",
|
@@ -110,6 +112,9 @@ _Setting = Literal[
|
|
110
112
|
"git_root",
|
111
113
|
"heartbeat_seconds",
|
112
114
|
"host",
|
115
|
+
"http_proxy",
|
116
|
+
"https_proxy",
|
117
|
+
"identity_token_file",
|
113
118
|
"ignore_globs",
|
114
119
|
"init_timeout",
|
115
120
|
"is_local",
|
@@ -224,6 +229,9 @@ SETTINGS_TOPOLOGICALLY_SORTED: Final[Tuple[_Setting, ...]] = (
|
|
224
229
|
"disable_git",
|
225
230
|
"disable_job_creation",
|
226
231
|
"files_dir",
|
232
|
+
"_proxies",
|
233
|
+
"http_proxy",
|
234
|
+
"https_proxy",
|
227
235
|
"log_dir",
|
228
236
|
"log_internal",
|
229
237
|
"log_symlink_internal",
|
wandb/sdk/lib/apikey.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
"""apikey util."""
|
2
2
|
|
3
3
|
import os
|
4
|
+
import platform
|
4
5
|
import stat
|
5
6
|
import sys
|
6
7
|
import textwrap
|
@@ -15,7 +16,7 @@ else:
|
|
15
16
|
from typing_extensions import Literal
|
16
17
|
|
17
18
|
import click
|
18
|
-
|
19
|
+
from requests.utils import NETRC_FILES, get_netrc_auth
|
19
20
|
|
20
21
|
import wandb
|
21
22
|
from wandb.apis import InternalApi
|
@@ -54,10 +55,22 @@ def _fixup_anon_mode(default: Optional[Mode]) -> Optional[Mode]:
|
|
54
55
|
|
55
56
|
|
56
57
|
def get_netrc_file_path() -> str:
|
58
|
+
"""Return the path to the netrc file."""
|
59
|
+
# if the NETRC environment variable is set, use that
|
57
60
|
netrc_file = os.environ.get("NETRC")
|
58
61
|
if netrc_file:
|
59
62
|
return os.path.expanduser(netrc_file)
|
60
|
-
|
63
|
+
|
64
|
+
# if either .netrc or _netrc exists in the home directory, use that
|
65
|
+
for netrc_file in NETRC_FILES:
|
66
|
+
home_dir = os.path.expanduser("~")
|
67
|
+
if os.path.exists(os.path.join(home_dir, netrc_file)):
|
68
|
+
return os.path.join(home_dir, netrc_file)
|
69
|
+
|
70
|
+
# otherwise, use .netrc on non-Windows platforms and _netrc on Windows
|
71
|
+
netrc_file = ".netrc" if platform.system() != "Windows" else "_netrc"
|
72
|
+
|
73
|
+
return os.path.join(os.path.expanduser("~"), netrc_file)
|
61
74
|
|
62
75
|
|
63
76
|
def prompt_api_key( # noqa: C901
|
@@ -254,7 +267,7 @@ def api_key(settings: Optional["Settings"] = None) -> Optional[str]:
|
|
254
267
|
assert settings is not None
|
255
268
|
if settings.api_key:
|
256
269
|
return settings.api_key
|
257
|
-
auth = requests.utils.get_netrc_auth(settings.base_url)
|
270
|
+
auth = get_netrc_auth(settings.base_url)
|
258
271
|
if auth:
|
259
272
|
return auth[-1]
|
260
273
|
return None
|
@@ -0,0 +1,141 @@
|
|
1
|
+
import json
|
2
|
+
import os
|
3
|
+
from datetime import datetime, timedelta
|
4
|
+
from pathlib import Path
|
5
|
+
|
6
|
+
import requests.utils
|
7
|
+
|
8
|
+
from wandb.errors import AuthenticationError
|
9
|
+
|
10
|
+
DEFAULT_WANDB_CREDENTIALS_FILE = Path(
|
11
|
+
os.path.expanduser("~/.config/wandb/credentials.json")
|
12
|
+
)
|
13
|
+
|
14
|
+
_expires_at_fmt = "%Y-%m-%d %H:%M:%S"
|
15
|
+
|
16
|
+
|
17
|
+
def access_token(base_url: str, token_file: Path, credentials_file: Path) -> str:
|
18
|
+
"""Retrieve an access token from the credentials file.
|
19
|
+
|
20
|
+
If no access token exists, create a new one by exchanging the identity
|
21
|
+
token from the token file, and save it to the credentials file.
|
22
|
+
|
23
|
+
Args:
|
24
|
+
base_url (str): The base URL of the server
|
25
|
+
token_file (pathlib.Path): The path to the file containing the
|
26
|
+
identity token
|
27
|
+
credentials_file (pathlib.Path): The path to file used to save
|
28
|
+
temporary access tokens
|
29
|
+
|
30
|
+
Returns:
|
31
|
+
str: The access token
|
32
|
+
"""
|
33
|
+
if not credentials_file.exists():
|
34
|
+
_write_credentials_file(base_url, token_file, credentials_file)
|
35
|
+
|
36
|
+
data = _fetch_credentials(base_url, token_file, credentials_file)
|
37
|
+
return data["access_token"]
|
38
|
+
|
39
|
+
|
40
|
+
def _write_credentials_file(base_url: str, token_file: Path, credentials_file: Path):
|
41
|
+
"""Obtain an access token from the server and write it to the credentials file.
|
42
|
+
|
43
|
+
Args:
|
44
|
+
base_url (str): The base URL of the server
|
45
|
+
token_file (pathlib.Path): The path to the file containing the
|
46
|
+
identity token
|
47
|
+
credentials_file (pathlib.Path): The path to file used to save
|
48
|
+
temporary access tokens
|
49
|
+
"""
|
50
|
+
credentials = _create_access_token(base_url, token_file)
|
51
|
+
data = {"credentials": {base_url: credentials}}
|
52
|
+
with open(credentials_file, "w") as file:
|
53
|
+
json.dump(data, file, indent=4)
|
54
|
+
|
55
|
+
# Set file permissions to be read/write by the owner only
|
56
|
+
os.chmod(credentials_file, 0o600)
|
57
|
+
|
58
|
+
|
59
|
+
def _fetch_credentials(base_url: str, token_file: Path, credentials_file: Path) -> dict:
|
60
|
+
"""Fetch the access token from the credentials file.
|
61
|
+
|
62
|
+
If the access token has expired, fetch a new one from the server and save it
|
63
|
+
to the credentials file.
|
64
|
+
|
65
|
+
Args:
|
66
|
+
base_url (str): The base URL of the server
|
67
|
+
token_file (pathlib.Path): The path to the file containing the
|
68
|
+
identity token
|
69
|
+
credentials_file (pathlib.Path): The path to file used to save
|
70
|
+
temporary access tokens
|
71
|
+
|
72
|
+
Returns:
|
73
|
+
dict: The credentials including the access token.
|
74
|
+
"""
|
75
|
+
creds = {}
|
76
|
+
with open(credentials_file) as file:
|
77
|
+
data = json.load(file)
|
78
|
+
if "credentials" not in data:
|
79
|
+
data["credentials"] = {}
|
80
|
+
if base_url in data["credentials"]:
|
81
|
+
creds = data["credentials"][base_url]
|
82
|
+
|
83
|
+
expires_at = datetime.utcnow()
|
84
|
+
if "expires_at" in creds:
|
85
|
+
expires_at = datetime.strptime(creds["expires_at"], _expires_at_fmt)
|
86
|
+
|
87
|
+
if expires_at <= datetime.utcnow():
|
88
|
+
creds = _create_access_token(base_url, token_file)
|
89
|
+
with open(credentials_file, "w") as file:
|
90
|
+
data["credentials"][base_url] = creds
|
91
|
+
json.dump(data, file, indent=4)
|
92
|
+
|
93
|
+
return creds
|
94
|
+
|
95
|
+
|
96
|
+
def _create_access_token(base_url: str, token_file: Path) -> dict:
|
97
|
+
"""Exchange an identity token for an access token from the server.
|
98
|
+
|
99
|
+
Args:
|
100
|
+
base_url (str): The base URL of the server.
|
101
|
+
token_file (pathlib.Path): The path to the file containing the
|
102
|
+
identity token
|
103
|
+
|
104
|
+
Returns:
|
105
|
+
dict: The access token and its expiration.
|
106
|
+
|
107
|
+
Raises:
|
108
|
+
FileNotFoundError: If the token file is not found.
|
109
|
+
OSError: If there is an issue reading the token file.
|
110
|
+
AuthenticationError: If the server fails to provide an access token.
|
111
|
+
"""
|
112
|
+
try:
|
113
|
+
with open(token_file) as file:
|
114
|
+
token = file.read().strip()
|
115
|
+
except FileNotFoundError as e:
|
116
|
+
raise FileNotFoundError(f"Identity token file not found: {token_file}") from e
|
117
|
+
except OSError as e:
|
118
|
+
raise OSError(
|
119
|
+
f"Failed to read the identity token from file: {token_file}"
|
120
|
+
) from e
|
121
|
+
|
122
|
+
url = f"{base_url}/oidc/token"
|
123
|
+
data = {
|
124
|
+
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
|
125
|
+
"assertion": token,
|
126
|
+
}
|
127
|
+
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
128
|
+
|
129
|
+
response = requests.post(url, data=data, headers=headers)
|
130
|
+
|
131
|
+
if response.status_code != 200:
|
132
|
+
raise AuthenticationError(
|
133
|
+
f"Failed to retrieve access token: {response.status_code}, {response.text}"
|
134
|
+
)
|
135
|
+
|
136
|
+
resp_json = response.json()
|
137
|
+
expires_at = datetime.utcnow() + timedelta(seconds=float(resp_json["expires_in"]))
|
138
|
+
resp_json["expires_at"] = expires_at.strftime(_expires_at_fmt)
|
139
|
+
del resp_json["expires_in"]
|
140
|
+
|
141
|
+
return resp_json
|
wandb/sdk/wandb_init.py
CHANGED
@@ -323,6 +323,15 @@ class _WandbInit:
|
|
323
323
|
if save_code_pre_user_settings is False:
|
324
324
|
settings.update({"save_code": False}, source=Source.INIT)
|
325
325
|
|
326
|
+
# TODO: remove this once we refactor the client. This is a temporary
|
327
|
+
# fix to make sure that we use the same project name for wandb-core.
|
328
|
+
# The reason this is not going through the settings object is to
|
329
|
+
# avoid failure cases in other parts of the code that will be
|
330
|
+
# removed with the switch to wandb-core.
|
331
|
+
if settings.project is None:
|
332
|
+
project = wandb.util.auto_project_name(settings.program)
|
333
|
+
settings.update({"project": project}, source=Source.INIT)
|
334
|
+
|
326
335
|
# TODO(jhr): should this be moved? probably.
|
327
336
|
settings._set_run_start_time(source=Source.INIT)
|
328
337
|
|
@@ -989,8 +998,9 @@ def init(
|
|
989
998
|
|
990
999
|
Arguments:
|
991
1000
|
project: (str, optional) The name of the project where you're sending
|
992
|
-
the new run. If the project is not specified,
|
993
|
-
|
1001
|
+
the new run. If the project is not specified, we will try to infer
|
1002
|
+
the project name from git root or the current program file. If we
|
1003
|
+
can't infer the project name, we will default to `"uncategorized"`.
|
994
1004
|
entity: (str, optional) An entity is a username or team name where
|
995
1005
|
you're sending runs. This entity must exist before you can send runs
|
996
1006
|
there, so make sure to create your account or team in the UI before
|
wandb/sdk/wandb_login.py
CHANGED
@@ -156,6 +156,9 @@ class _WandbLogin:
|
|
156
156
|
"""Returns whether an API key is set or can be inferred."""
|
157
157
|
return apikey.api_key(settings=self._settings) is not None
|
158
158
|
|
159
|
+
def should_use_identity_token(self):
|
160
|
+
return self._settings.identity_token_file is not None
|
161
|
+
|
159
162
|
def set_backend(self, backend):
|
160
163
|
self._backend = backend
|
161
164
|
|
@@ -327,6 +330,9 @@ def _login(
|
|
327
330
|
)
|
328
331
|
return False
|
329
332
|
|
333
|
+
if wlogin.should_use_identity_token():
|
334
|
+
return True
|
335
|
+
|
330
336
|
# perform a login
|
331
337
|
logged_in = wlogin.login()
|
332
338
|
|
wandb/sdk/wandb_manager.py
CHANGED
@@ -127,6 +127,7 @@ class _Manager:
|
|
127
127
|
raise ManagerConnectionError(f"Connection to wandb service failed: {e}")
|
128
128
|
|
129
129
|
def __init__(self, settings: "Settings") -> None:
|
130
|
+
"""Connects to the internal service, starting it if necessary."""
|
130
131
|
from wandb.sdk.service import service
|
131
132
|
|
132
133
|
self._settings = settings
|
@@ -134,6 +135,7 @@ class _Manager:
|
|
134
135
|
self._hooks = None
|
135
136
|
|
136
137
|
self._service = service._Service(settings=self._settings)
|
138
|
+
|
137
139
|
token = _ManagerToken.from_environment()
|
138
140
|
if not token:
|
139
141
|
self._service.start()
|
@@ -144,7 +146,6 @@ class _Manager:
|
|
144
146
|
token = _ManagerToken.from_params(transport=transport, host=host, port=port)
|
145
147
|
token.set_environment()
|
146
148
|
self._atexit_setup()
|
147
|
-
|
148
149
|
self._token = token
|
149
150
|
|
150
151
|
try:
|
@@ -152,6 +153,24 @@ class _Manager:
|
|
152
153
|
except ManagerConnectionError as e:
|
153
154
|
wandb._sentry.reraise(e)
|
154
155
|
|
156
|
+
def _teardown(self, exit_code: int) -> int:
|
157
|
+
"""Shuts down the internal process and returns its exit code.
|
158
|
+
|
159
|
+
This sends a teardown record to the process. An exception is raised if
|
160
|
+
the process has already been shut down.
|
161
|
+
"""
|
162
|
+
unregister_all_post_import_hooks()
|
163
|
+
|
164
|
+
if self._atexit_lambda:
|
165
|
+
atexit.unregister(self._atexit_lambda)
|
166
|
+
self._atexit_lambda = None
|
167
|
+
|
168
|
+
try:
|
169
|
+
self._inform_teardown(exit_code)
|
170
|
+
return self._service.join()
|
171
|
+
finally:
|
172
|
+
self._token.reset_environment()
|
173
|
+
|
155
174
|
def _atexit_setup(self) -> None:
|
156
175
|
self._atexit_lambda = lambda: self._atexit_teardown()
|
157
176
|
|
@@ -161,28 +180,18 @@ class _Manager:
|
|
161
180
|
|
162
181
|
def _atexit_teardown(self) -> None:
|
163
182
|
trigger.call("on_finished")
|
164
|
-
exit_code = self._hooks.exit_code if self._hooks else 0
|
165
|
-
self._teardown(exit_code)
|
166
|
-
|
167
|
-
def _teardown(self, exit_code: int) -> None:
|
168
|
-
unregister_all_post_import_hooks()
|
169
183
|
|
170
|
-
|
171
|
-
|
172
|
-
|
184
|
+
# Clear the atexit hook---we're executing it now, after which the
|
185
|
+
# process will exit.
|
186
|
+
self._atexit_lambda = None
|
173
187
|
|
174
188
|
try:
|
175
|
-
self._inform_teardown(exit_code)
|
176
|
-
result = self._service.join()
|
177
|
-
if result and not self._settings._notebook:
|
178
|
-
os._exit(result)
|
189
|
+
self._teardown(self._hooks.exit_code if self._hooks else 0)
|
179
190
|
except Exception as e:
|
180
191
|
wandb.termlog(
|
181
|
-
f"
|
192
|
+
f"Encountered an error while tearing down the service manager: {e}",
|
182
193
|
repeat=False,
|
183
194
|
)
|
184
|
-
finally:
|
185
|
-
self._token.reset_environment()
|
186
195
|
|
187
196
|
def _get_service(self) -> "service._Service":
|
188
197
|
return self._service
|
wandb/sdk/wandb_run.py
CHANGED
@@ -2751,12 +2751,14 @@ class Run:
|
|
2751
2751
|
summary_items = [s.lower() for s in summary.split(",")]
|
2752
2752
|
summary_ops = []
|
2753
2753
|
valid = {"min", "max", "mean", "best", "last", "copy", "none"}
|
2754
|
+
# TODO: deprecate copy and best
|
2754
2755
|
for i in summary_items:
|
2755
2756
|
if i not in valid:
|
2756
2757
|
raise wandb.Error(f"Unhandled define_metric() arg: summary op: {i}")
|
2757
2758
|
summary_ops.append(i)
|
2758
2759
|
with telemetry.context(run=self) as tel:
|
2759
2760
|
tel.feature.metric_summary = True
|
2761
|
+
# TODO: deprecate goal
|
2760
2762
|
goal_cleaned: Optional[str] = None
|
2761
2763
|
if goal is not None:
|
2762
2764
|
goal_cleaned = goal[:3].lower()
|
@@ -2772,6 +2774,9 @@ class Run:
|
|
2772
2774
|
with telemetry.context(run=self) as tel:
|
2773
2775
|
tel.feature.metric_step_sync = True
|
2774
2776
|
|
2777
|
+
with telemetry.context(run=self) as tel:
|
2778
|
+
tel.feature.metric = True
|
2779
|
+
|
2775
2780
|
m = wandb_metric.Metric(
|
2776
2781
|
name=name,
|
2777
2782
|
step_metric=step_metric,
|
@@ -2783,8 +2788,6 @@ class Run:
|
|
2783
2788
|
)
|
2784
2789
|
m._set_callback(self._metric_callback)
|
2785
2790
|
m._commit()
|
2786
|
-
with telemetry.context(run=self) as tel:
|
2787
|
-
tel.feature.metric = True
|
2788
2791
|
return m
|
2789
2792
|
|
2790
2793
|
# TODO(jhr): annotate this
|