wandb 0.17.2__py3-none-any.whl → 0.17.4__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -1
- wandb/apis/internal.py +4 -0
- wandb/apis/public/projects.py +0 -6
- wandb/cli/cli.py +7 -6
- wandb/env.py +16 -0
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/sdk/artifacts/artifact.py +2 -0
- wandb/sdk/artifacts/artifact_file_cache.py +35 -13
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +11 -6
- wandb/sdk/internal/internal_api.py +68 -15
- wandb/sdk/launch/_launch_add.py +1 -1
- wandb/sdk/launch/agent/agent.py +3 -1
- wandb/sdk/launch/sweeps/scheduler.py +2 -0
- wandb/sdk/launch/utils.py +1 -1
- wandb/sdk/lib/_settings_toposort_generated.py +8 -0
- wandb/sdk/lib/apikey.py +16 -3
- wandb/sdk/lib/credentials.py +141 -0
- wandb/sdk/wandb_init.py +12 -2
- wandb/sdk/wandb_login.py +6 -0
- wandb/sdk/wandb_manager.py +25 -16
- wandb/sdk/wandb_run.py +5 -2
- wandb/sdk/wandb_settings.py +27 -3
- wandb/sdk/wandb_setup.py +12 -13
- wandb/sdk/wandb_sweep.py +4 -2
- {wandb-0.17.2.dist-info → wandb-0.17.4.dist-info}/METADATA +1 -1
- {wandb-0.17.2.dist-info → wandb-0.17.4.dist-info}/RECORD +31 -30
- {wandb-0.17.2.dist-info → wandb-0.17.4.dist-info}/WHEEL +1 -1
- {wandb-0.17.2.dist-info → wandb-0.17.4.dist-info}/entry_points.txt +0 -0
- {wandb-0.17.2.dist-info → wandb-0.17.4.dist-info}/licenses/LICENSE +0 -0
@@ -11,6 +11,7 @@ import socket
|
|
11
11
|
import sys
|
12
12
|
import threading
|
13
13
|
from copy import deepcopy
|
14
|
+
from pathlib import Path
|
14
15
|
from typing import (
|
15
16
|
IO,
|
16
17
|
TYPE_CHECKING,
|
@@ -37,14 +38,14 @@ from wandb_gql.client import RetryError
|
|
37
38
|
import wandb
|
38
39
|
from wandb import env, util
|
39
40
|
from wandb.apis.normalize import normalize_exceptions, parse_backend_error_messages
|
40
|
-
from wandb.errors import CommError, UnsupportedError, UsageError
|
41
|
+
from wandb.errors import AuthenticationError, CommError, UnsupportedError, UsageError
|
41
42
|
from wandb.integration.sagemaker import parse_sm_secrets
|
42
43
|
from wandb.old.settings import Settings
|
43
44
|
from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
|
44
45
|
from wandb.sdk.lib.gql_request import GraphQLSession
|
45
46
|
from wandb.sdk.lib.hashutil import B64MD5, md5_file_b64
|
46
47
|
|
47
|
-
from ..lib import retry
|
48
|
+
from ..lib import credentials, retry
|
48
49
|
from ..lib.filenames import DIFF_FNAME, METADATA_FNAME
|
49
50
|
from ..lib.gitlib import GitRepo
|
50
51
|
from . import context
|
@@ -234,14 +235,18 @@ class Api:
|
|
234
235
|
extra_http_headers = self.settings("_extra_http_headers") or json.loads(
|
235
236
|
self._environ.get("WANDB__EXTRA_HTTP_HEADERS", "{}")
|
236
237
|
)
|
238
|
+
extra_http_headers.update(_thread_local_api_settings.headers or {})
|
239
|
+
|
240
|
+
auth = None
|
241
|
+
if self.access_token is not None:
|
242
|
+
extra_http_headers["Authorization"] = f"Bearer {self.access_token}"
|
243
|
+
elif _thread_local_api_settings.cookies is None:
|
244
|
+
auth = ("api", self.api_key or "")
|
245
|
+
|
237
246
|
proxies = self.settings("_proxies") or json.loads(
|
238
247
|
self._environ.get("WANDB__PROXIES", "{}")
|
239
248
|
)
|
240
249
|
|
241
|
-
auth = None
|
242
|
-
if _thread_local_api_settings.cookies is None:
|
243
|
-
auth = ("api", self.api_key or "")
|
244
|
-
extra_http_headers.update(_thread_local_api_settings.headers or {})
|
245
250
|
self.client = Client(
|
246
251
|
transport=GraphQLSession(
|
247
252
|
headers={
|
@@ -376,6 +381,35 @@ class Api:
|
|
376
381
|
default_key: Optional[str] = self.default_settings.get("api_key")
|
377
382
|
return env_key or key or sagemaker_key or default_key
|
378
383
|
|
384
|
+
@property
|
385
|
+
def access_token(self) -> Optional[str]:
|
386
|
+
"""Retrieves an access token for authentication.
|
387
|
+
|
388
|
+
This function attempts to exchange an identity token for a temporary
|
389
|
+
access token from the server, and save it to the credentials file.
|
390
|
+
It uses the path to the identity token as defined in the environment
|
391
|
+
variables. If the environment variable is not set, it returns None.
|
392
|
+
|
393
|
+
Returns:
|
394
|
+
Optional[str]: The access token if available, otherwise None if
|
395
|
+
no identity token is supplied.
|
396
|
+
Raises:
|
397
|
+
AuthenticationError: If the path to the identity token is not found.
|
398
|
+
"""
|
399
|
+
token_file_str = self._environ.get(env.IDENTITY_TOKEN_FILE)
|
400
|
+
if not token_file_str:
|
401
|
+
return None
|
402
|
+
|
403
|
+
token_file = Path(token_file_str)
|
404
|
+
if not token_file.exists():
|
405
|
+
raise AuthenticationError(f"Identity token file not found: {token_file}")
|
406
|
+
|
407
|
+
base_url = self.settings("base_url")
|
408
|
+
credentials_file = env.get_credentials_file(
|
409
|
+
str(credentials.DEFAULT_WANDB_CREDENTIALS_FILE), self._environ
|
410
|
+
)
|
411
|
+
return credentials.access_token(base_url, token_file, credentials_file)
|
412
|
+
|
379
413
|
@property
|
380
414
|
def api_url(self) -> str:
|
381
415
|
return self.settings("base_url") # type: ignore
|
@@ -2674,14 +2708,20 @@ class Api:
|
|
2674
2708
|
A tuple of the content length and the streaming response
|
2675
2709
|
"""
|
2676
2710
|
check_httpclient_logger_handler()
|
2711
|
+
|
2712
|
+
http_headers = _thread_local_api_settings.headers or {}
|
2713
|
+
|
2677
2714
|
auth = None
|
2678
|
-
if
|
2679
|
-
|
2715
|
+
if self.access_token is not None:
|
2716
|
+
http_headers["Authorization"] = f"Bearer {self.access_token}"
|
2717
|
+
elif _thread_local_api_settings.cookies is None:
|
2718
|
+
auth = ("api", self.api_key or "")
|
2719
|
+
|
2680
2720
|
response = requests.get(
|
2681
2721
|
url,
|
2682
2722
|
auth=auth,
|
2683
2723
|
cookies=_thread_local_api_settings.cookies or {},
|
2684
|
-
headers=
|
2724
|
+
headers=http_headers,
|
2685
2725
|
stream=True,
|
2686
2726
|
)
|
2687
2727
|
response.raise_for_status()
|
@@ -2780,9 +2820,9 @@ class Api:
|
|
2780
2820
|
except requests.exceptions.RequestException as e:
|
2781
2821
|
logger.error(f"upload_file exception {url}: {e}")
|
2782
2822
|
request_headers = e.request.headers if e.request is not None else ""
|
2783
|
-
logger.error(f"upload_file request headers: {request_headers}")
|
2823
|
+
logger.error(f"upload_file request headers: {request_headers!r}")
|
2784
2824
|
response_content = e.response.content if e.response is not None else ""
|
2785
|
-
logger.error(f"upload_file response body: {response_content}")
|
2825
|
+
logger.error(f"upload_file response body: {response_content!r}")
|
2786
2826
|
status_code = e.response.status_code if e.response is not None else 0
|
2787
2827
|
# S3 reports retryable request timeouts out-of-band
|
2788
2828
|
is_aws_retryable = status_code == 400 and "RequestTimeout" in str(
|
@@ -2848,7 +2888,7 @@ class Api:
|
|
2848
2888
|
request_headers = e.request.headers if e.request is not None else ""
|
2849
2889
|
logger.error(f"upload_file request headers: {request_headers}")
|
2850
2890
|
response_content = e.response.content if e.response is not None else ""
|
2851
|
-
logger.error(f"upload_file response body: {response_content}")
|
2891
|
+
logger.error(f"upload_file response body: {response_content!r}")
|
2852
2892
|
status_code = e.response.status_code if e.response is not None else 0
|
2853
2893
|
# S3 reports retryable request timeouts out-of-band
|
2854
2894
|
is_aws_retryable = (
|
@@ -3039,6 +3079,7 @@ class Api:
|
|
3039
3079
|
entity: Optional[str] = None,
|
3040
3080
|
state: Optional[str] = None,
|
3041
3081
|
prior_runs: Optional[List[str]] = None,
|
3082
|
+
template_variable_values: Optional[Dict[str, Any]] = None,
|
3042
3083
|
) -> Tuple[str, List[str]]:
|
3043
3084
|
"""Upsert a sweep object.
|
3044
3085
|
|
@@ -3052,6 +3093,7 @@ class Api:
|
|
3052
3093
|
entity (str): entity to use
|
3053
3094
|
state (str): state
|
3054
3095
|
prior_runs (list): IDs of existing runs to add to the sweep
|
3096
|
+
template_variable_values (dict): template variable values
|
3055
3097
|
"""
|
3056
3098
|
project_query = """
|
3057
3099
|
project {
|
@@ -3096,7 +3138,17 @@ class Api:
|
|
3096
3138
|
"""
|
3097
3139
|
# TODO(jhr): we need protocol versioning to know schema is not supported
|
3098
3140
|
# for now we will just try both new and old query
|
3099
|
-
|
3141
|
+
mutation_5 = gql(
|
3142
|
+
mutation_str.replace(
|
3143
|
+
"$controller: JSONString,",
|
3144
|
+
"$controller: JSONString,$launchScheduler: JSONString, $templateVariableValues: JSONString,",
|
3145
|
+
)
|
3146
|
+
.replace(
|
3147
|
+
"controller: $controller,",
|
3148
|
+
"controller: $controller,launchScheduler: $launchScheduler,templateVariableValues: $templateVariableValues,",
|
3149
|
+
)
|
3150
|
+
.replace("_PROJECT_QUERY_", project_query)
|
3151
|
+
)
|
3100
3152
|
# launchScheduler was introduced in core v0.14.0
|
3101
3153
|
mutation_4 = gql(
|
3102
3154
|
mutation_str.replace(
|
@@ -3105,7 +3157,7 @@ class Api:
|
|
3105
3157
|
)
|
3106
3158
|
.replace(
|
3107
3159
|
"controller: $controller,",
|
3108
|
-
"controller: $controller,launchScheduler: $launchScheduler
|
3160
|
+
"controller: $controller,launchScheduler: $launchScheduler",
|
3109
3161
|
)
|
3110
3162
|
.replace("_PROJECT_QUERY_", project_query)
|
3111
3163
|
)
|
@@ -3124,7 +3176,7 @@ class Api:
|
|
3124
3176
|
)
|
3125
3177
|
|
3126
3178
|
# TODO(dag): replace this with a query for protocol versioning
|
3127
|
-
mutations = [mutation_4, mutation_3, mutation_2, mutation_1]
|
3179
|
+
mutations = [mutation_5, mutation_4, mutation_3, mutation_2, mutation_1]
|
3128
3180
|
|
3129
3181
|
config = self._validate_config_and_fill_distribution(config)
|
3130
3182
|
|
@@ -3148,6 +3200,7 @@ class Api:
|
|
3148
3200
|
"projectName": project or self.settings("project"),
|
3149
3201
|
"controller": controller,
|
3150
3202
|
"launchScheduler": launch_scheduler,
|
3203
|
+
"templateVariableValues": json.dumps(template_variable_values),
|
3151
3204
|
"scheduler": scheduler,
|
3152
3205
|
"priorRunsFilters": filters,
|
3153
3206
|
}
|
wandb/sdk/launch/_launch_add.py
CHANGED
@@ -61,7 +61,7 @@ def launch_add(
|
|
61
61
|
config: A dictionary containing the configuration for the run. May also contain
|
62
62
|
resource specific arguments under the key "resource_args"
|
63
63
|
template_variables: A dictionary containing values of template variables for a run queue.
|
64
|
-
Expected format of {"
|
64
|
+
Expected format of {"VAR_NAME": VAR_VALUE}
|
65
65
|
project: Target project to send launched run to
|
66
66
|
entity: Target entity to send launched run to
|
67
67
|
queue: the name of the queue to enqueue the run to
|
wandb/sdk/launch/agent/agent.py
CHANGED
@@ -308,7 +308,9 @@ class LaunchAgent:
|
|
308
308
|
self._wandb_run = None
|
309
309
|
|
310
310
|
if self.gorilla_supports_agents:
|
311
|
-
settings = wandb.Settings(
|
311
|
+
settings = wandb.Settings(
|
312
|
+
silent=True, disable_git=True, disable_job_creation=True
|
313
|
+
)
|
312
314
|
self._wandb_run = wandb.init(
|
313
315
|
project=self._project,
|
314
316
|
entity=self._entity,
|
@@ -259,10 +259,12 @@ class Scheduler(ABC):
|
|
259
259
|
|
260
260
|
def _init_wandb_run(self) -> "SdkRun":
|
261
261
|
"""Controls resume or init logic for a scheduler wandb run."""
|
262
|
+
settings = wandb.Settings(disable_job_creation=True)
|
262
263
|
run: SdkRun = wandb.init( # type: ignore
|
263
264
|
name=f"Scheduler.{self._sweep_id}",
|
264
265
|
resume="allow",
|
265
266
|
config=self._kwargs, # when run as a job, this sets config
|
267
|
+
settings=settings,
|
266
268
|
)
|
267
269
|
return run
|
268
270
|
|
wandb/sdk/launch/utils.py
CHANGED
@@ -60,7 +60,7 @@ AZURE_CONTAINER_REGISTRY_URI_REGEX = re.compile(
|
|
60
60
|
)
|
61
61
|
|
62
62
|
ELASTIC_CONTAINER_REGISTRY_URI_REGEX = re.compile(
|
63
|
-
r"^(?:https://)?(?P<account>[\w-]+)\.dkr\.ecr\.(?P<region>[\w-]+)\.amazonaws\.com/(?P<repository>[
|
63
|
+
r"^(?:https://)?(?P<account>[\w-]+)\.dkr\.ecr\.(?P<region>[\w-]+)\.amazonaws\.com/(?P<repository>[\.\/\w-]+):?(?P<tag>.*)$"
|
64
64
|
)
|
65
65
|
|
66
66
|
GCP_ARTIFACT_REGISTRY_URI_REGEX = re.compile(
|
@@ -26,6 +26,7 @@ _Setting = Literal[
|
|
26
26
|
"_disable_machine_info",
|
27
27
|
"_executable",
|
28
28
|
"_extra_http_headers",
|
29
|
+
"_file_stream_max_bytes",
|
29
30
|
"_file_stream_retry_max",
|
30
31
|
"_file_stream_retry_wait_min_seconds",
|
31
32
|
"_file_stream_retry_wait_max_seconds",
|
@@ -91,6 +92,7 @@ _Setting = Literal[
|
|
91
92
|
"config_paths",
|
92
93
|
"console",
|
93
94
|
"console_multipart",
|
95
|
+
"credentials_file",
|
94
96
|
"deployment",
|
95
97
|
"disable_code",
|
96
98
|
"disable_git",
|
@@ -110,6 +112,9 @@ _Setting = Literal[
|
|
110
112
|
"git_root",
|
111
113
|
"heartbeat_seconds",
|
112
114
|
"host",
|
115
|
+
"http_proxy",
|
116
|
+
"https_proxy",
|
117
|
+
"identity_token_file",
|
113
118
|
"ignore_globs",
|
114
119
|
"init_timeout",
|
115
120
|
"is_local",
|
@@ -224,6 +229,9 @@ SETTINGS_TOPOLOGICALLY_SORTED: Final[Tuple[_Setting, ...]] = (
|
|
224
229
|
"disable_git",
|
225
230
|
"disable_job_creation",
|
226
231
|
"files_dir",
|
232
|
+
"_proxies",
|
233
|
+
"http_proxy",
|
234
|
+
"https_proxy",
|
227
235
|
"log_dir",
|
228
236
|
"log_internal",
|
229
237
|
"log_symlink_internal",
|
wandb/sdk/lib/apikey.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
"""apikey util."""
|
2
2
|
|
3
3
|
import os
|
4
|
+
import platform
|
4
5
|
import stat
|
5
6
|
import sys
|
6
7
|
import textwrap
|
@@ -15,7 +16,7 @@ else:
|
|
15
16
|
from typing_extensions import Literal
|
16
17
|
|
17
18
|
import click
|
18
|
-
|
19
|
+
from requests.utils import NETRC_FILES, get_netrc_auth
|
19
20
|
|
20
21
|
import wandb
|
21
22
|
from wandb.apis import InternalApi
|
@@ -54,10 +55,22 @@ def _fixup_anon_mode(default: Optional[Mode]) -> Optional[Mode]:
|
|
54
55
|
|
55
56
|
|
56
57
|
def get_netrc_file_path() -> str:
|
58
|
+
"""Return the path to the netrc file."""
|
59
|
+
# if the NETRC environment variable is set, use that
|
57
60
|
netrc_file = os.environ.get("NETRC")
|
58
61
|
if netrc_file:
|
59
62
|
return os.path.expanduser(netrc_file)
|
60
|
-
|
63
|
+
|
64
|
+
# if either .netrc or _netrc exists in the home directory, use that
|
65
|
+
for netrc_file in NETRC_FILES:
|
66
|
+
home_dir = os.path.expanduser("~")
|
67
|
+
if os.path.exists(os.path.join(home_dir, netrc_file)):
|
68
|
+
return os.path.join(home_dir, netrc_file)
|
69
|
+
|
70
|
+
# otherwise, use .netrc on non-Windows platforms and _netrc on Windows
|
71
|
+
netrc_file = ".netrc" if platform.system() != "Windows" else "_netrc"
|
72
|
+
|
73
|
+
return os.path.join(os.path.expanduser("~"), netrc_file)
|
61
74
|
|
62
75
|
|
63
76
|
def prompt_api_key( # noqa: C901
|
@@ -254,7 +267,7 @@ def api_key(settings: Optional["Settings"] = None) -> Optional[str]:
|
|
254
267
|
assert settings is not None
|
255
268
|
if settings.api_key:
|
256
269
|
return settings.api_key
|
257
|
-
auth =
|
270
|
+
auth = get_netrc_auth(settings.base_url)
|
258
271
|
if auth:
|
259
272
|
return auth[-1]
|
260
273
|
return None
|
@@ -0,0 +1,141 @@
|
|
1
|
+
import json
|
2
|
+
import os
|
3
|
+
from datetime import datetime, timedelta
|
4
|
+
from pathlib import Path
|
5
|
+
|
6
|
+
import requests.utils
|
7
|
+
|
8
|
+
from wandb.errors import AuthenticationError
|
9
|
+
|
10
|
+
DEFAULT_WANDB_CREDENTIALS_FILE = Path(
|
11
|
+
os.path.expanduser("~/.config/wandb/credentials.json")
|
12
|
+
)
|
13
|
+
|
14
|
+
_expires_at_fmt = "%Y-%m-%d %H:%M:%S"
|
15
|
+
|
16
|
+
|
17
|
+
def access_token(base_url: str, token_file: Path, credentials_file: Path) -> str:
|
18
|
+
"""Retrieve an access token from the credentials file.
|
19
|
+
|
20
|
+
If no access token exists, create a new one by exchanging the identity
|
21
|
+
token from the token file, and save it to the credentials file.
|
22
|
+
|
23
|
+
Args:
|
24
|
+
base_url (str): The base URL of the server
|
25
|
+
token_file (pathlib.Path): The path to the file containing the
|
26
|
+
identity token
|
27
|
+
credentials_file (pathlib.Path): The path to file used to save
|
28
|
+
temporary access tokens
|
29
|
+
|
30
|
+
Returns:
|
31
|
+
str: The access token
|
32
|
+
"""
|
33
|
+
if not credentials_file.exists():
|
34
|
+
_write_credentials_file(base_url, token_file, credentials_file)
|
35
|
+
|
36
|
+
data = _fetch_credentials(base_url, token_file, credentials_file)
|
37
|
+
return data["access_token"]
|
38
|
+
|
39
|
+
|
40
|
+
def _write_credentials_file(base_url: str, token_file: Path, credentials_file: Path):
|
41
|
+
"""Obtain an access token from the server and write it to the credentials file.
|
42
|
+
|
43
|
+
Args:
|
44
|
+
base_url (str): The base URL of the server
|
45
|
+
token_file (pathlib.Path): The path to the file containing the
|
46
|
+
identity token
|
47
|
+
credentials_file (pathlib.Path): The path to file used to save
|
48
|
+
temporary access tokens
|
49
|
+
"""
|
50
|
+
credentials = _create_access_token(base_url, token_file)
|
51
|
+
data = {"credentials": {base_url: credentials}}
|
52
|
+
with open(credentials_file, "w") as file:
|
53
|
+
json.dump(data, file, indent=4)
|
54
|
+
|
55
|
+
# Set file permissions to be read/write by the owner only
|
56
|
+
os.chmod(credentials_file, 0o600)
|
57
|
+
|
58
|
+
|
59
|
+
def _fetch_credentials(base_url: str, token_file: Path, credentials_file: Path) -> dict:
|
60
|
+
"""Fetch the access token from the credentials file.
|
61
|
+
|
62
|
+
If the access token has expired, fetch a new one from the server and save it
|
63
|
+
to the credentials file.
|
64
|
+
|
65
|
+
Args:
|
66
|
+
base_url (str): The base URL of the server
|
67
|
+
token_file (pathlib.Path): The path to the file containing the
|
68
|
+
identity token
|
69
|
+
credentials_file (pathlib.Path): The path to file used to save
|
70
|
+
temporary access tokens
|
71
|
+
|
72
|
+
Returns:
|
73
|
+
dict: The credentials including the access token.
|
74
|
+
"""
|
75
|
+
creds = {}
|
76
|
+
with open(credentials_file) as file:
|
77
|
+
data = json.load(file)
|
78
|
+
if "credentials" not in data:
|
79
|
+
data["credentials"] = {}
|
80
|
+
if base_url in data["credentials"]:
|
81
|
+
creds = data["credentials"][base_url]
|
82
|
+
|
83
|
+
expires_at = datetime.utcnow()
|
84
|
+
if "expires_at" in creds:
|
85
|
+
expires_at = datetime.strptime(creds["expires_at"], _expires_at_fmt)
|
86
|
+
|
87
|
+
if expires_at <= datetime.utcnow():
|
88
|
+
creds = _create_access_token(base_url, token_file)
|
89
|
+
with open(credentials_file, "w") as file:
|
90
|
+
data["credentials"][base_url] = creds
|
91
|
+
json.dump(data, file, indent=4)
|
92
|
+
|
93
|
+
return creds
|
94
|
+
|
95
|
+
|
96
|
+
def _create_access_token(base_url: str, token_file: Path) -> dict:
|
97
|
+
"""Exchange an identity token for an access token from the server.
|
98
|
+
|
99
|
+
Args:
|
100
|
+
base_url (str): The base URL of the server.
|
101
|
+
token_file (pathlib.Path): The path to the file containing the
|
102
|
+
identity token
|
103
|
+
|
104
|
+
Returns:
|
105
|
+
dict: The access token and its expiration.
|
106
|
+
|
107
|
+
Raises:
|
108
|
+
FileNotFoundError: If the token file is not found.
|
109
|
+
OSError: If there is an issue reading the token file.
|
110
|
+
AuthenticationError: If the server fails to provide an access token.
|
111
|
+
"""
|
112
|
+
try:
|
113
|
+
with open(token_file) as file:
|
114
|
+
token = file.read().strip()
|
115
|
+
except FileNotFoundError as e:
|
116
|
+
raise FileNotFoundError(f"Identity token file not found: {token_file}") from e
|
117
|
+
except OSError as e:
|
118
|
+
raise OSError(
|
119
|
+
f"Failed to read the identity token from file: {token_file}"
|
120
|
+
) from e
|
121
|
+
|
122
|
+
url = f"{base_url}/oidc/token"
|
123
|
+
data = {
|
124
|
+
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
|
125
|
+
"assertion": token,
|
126
|
+
}
|
127
|
+
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
128
|
+
|
129
|
+
response = requests.post(url, data=data, headers=headers)
|
130
|
+
|
131
|
+
if response.status_code != 200:
|
132
|
+
raise AuthenticationError(
|
133
|
+
f"Failed to retrieve access token: {response.status_code}, {response.text}"
|
134
|
+
)
|
135
|
+
|
136
|
+
resp_json = response.json()
|
137
|
+
expires_at = datetime.utcnow() + timedelta(seconds=float(resp_json["expires_in"]))
|
138
|
+
resp_json["expires_at"] = expires_at.strftime(_expires_at_fmt)
|
139
|
+
del resp_json["expires_in"]
|
140
|
+
|
141
|
+
return resp_json
|
wandb/sdk/wandb_init.py
CHANGED
@@ -323,6 +323,15 @@ class _WandbInit:
|
|
323
323
|
if save_code_pre_user_settings is False:
|
324
324
|
settings.update({"save_code": False}, source=Source.INIT)
|
325
325
|
|
326
|
+
# TODO: remove this once we refactor the client. This is a temporary
|
327
|
+
# fix to make sure that we use the same project name for wandb-core.
|
328
|
+
# The reason this is not going throught the settings object is to
|
329
|
+
# avoid failure cases in other parts of the code that will be
|
330
|
+
# removed with the switch to wandb-core.
|
331
|
+
if settings.project is None:
|
332
|
+
project = wandb.util.auto_project_name(settings.program)
|
333
|
+
settings.update({"project": project}, source=Source.INIT)
|
334
|
+
|
326
335
|
# TODO(jhr): should this be moved? probably.
|
327
336
|
settings._set_run_start_time(source=Source.INIT)
|
328
337
|
|
@@ -989,8 +998,9 @@ def init(
|
|
989
998
|
|
990
999
|
Arguments:
|
991
1000
|
project: (str, optional) The name of the project where you're sending
|
992
|
-
the new run. If the project is not specified,
|
993
|
-
|
1001
|
+
the new run. If the project is not specified, we will try to infer
|
1002
|
+
the project name from git root or the current program file. If we
|
1003
|
+
can't infer the project name, we will default to `"uncategorized"`.
|
994
1004
|
entity: (str, optional) An entity is a username or team name where
|
995
1005
|
you're sending runs. This entity must exist before you can send runs
|
996
1006
|
there, so make sure to create your account or team in the UI before
|
wandb/sdk/wandb_login.py
CHANGED
@@ -156,6 +156,9 @@ class _WandbLogin:
|
|
156
156
|
"""Returns whether an API key is set or can be inferred."""
|
157
157
|
return apikey.api_key(settings=self._settings) is not None
|
158
158
|
|
159
|
+
def should_use_identity_token(self):
|
160
|
+
return self._settings.identity_token_file is not None
|
161
|
+
|
159
162
|
def set_backend(self, backend):
|
160
163
|
self._backend = backend
|
161
164
|
|
@@ -327,6 +330,9 @@ def _login(
|
|
327
330
|
)
|
328
331
|
return False
|
329
332
|
|
333
|
+
if wlogin.should_use_identity_token():
|
334
|
+
return True
|
335
|
+
|
330
336
|
# perform a login
|
331
337
|
logged_in = wlogin.login()
|
332
338
|
|
wandb/sdk/wandb_manager.py
CHANGED
@@ -127,6 +127,7 @@ class _Manager:
|
|
127
127
|
raise ManagerConnectionError(f"Connection to wandb service failed: {e}")
|
128
128
|
|
129
129
|
def __init__(self, settings: "Settings") -> None:
|
130
|
+
"""Connects to the internal service, starting it if necessary."""
|
130
131
|
from wandb.sdk.service import service
|
131
132
|
|
132
133
|
self._settings = settings
|
@@ -134,6 +135,7 @@ class _Manager:
|
|
134
135
|
self._hooks = None
|
135
136
|
|
136
137
|
self._service = service._Service(settings=self._settings)
|
138
|
+
|
137
139
|
token = _ManagerToken.from_environment()
|
138
140
|
if not token:
|
139
141
|
self._service.start()
|
@@ -144,7 +146,6 @@ class _Manager:
|
|
144
146
|
token = _ManagerToken.from_params(transport=transport, host=host, port=port)
|
145
147
|
token.set_environment()
|
146
148
|
self._atexit_setup()
|
147
|
-
|
148
149
|
self._token = token
|
149
150
|
|
150
151
|
try:
|
@@ -152,6 +153,24 @@ class _Manager:
|
|
152
153
|
except ManagerConnectionError as e:
|
153
154
|
wandb._sentry.reraise(e)
|
154
155
|
|
156
|
+
def _teardown(self, exit_code: int) -> int:
|
157
|
+
"""Shuts down the internal process and returns its exit code.
|
158
|
+
|
159
|
+
This sends a teardown record to the process. An exception is raised if
|
160
|
+
the process has already been shut down.
|
161
|
+
"""
|
162
|
+
unregister_all_post_import_hooks()
|
163
|
+
|
164
|
+
if self._atexit_lambda:
|
165
|
+
atexit.unregister(self._atexit_lambda)
|
166
|
+
self._atexit_lambda = None
|
167
|
+
|
168
|
+
try:
|
169
|
+
self._inform_teardown(exit_code)
|
170
|
+
return self._service.join()
|
171
|
+
finally:
|
172
|
+
self._token.reset_environment()
|
173
|
+
|
155
174
|
def _atexit_setup(self) -> None:
|
156
175
|
self._atexit_lambda = lambda: self._atexit_teardown()
|
157
176
|
|
@@ -161,28 +180,18 @@ class _Manager:
|
|
161
180
|
|
162
181
|
def _atexit_teardown(self) -> None:
|
163
182
|
trigger.call("on_finished")
|
164
|
-
exit_code = self._hooks.exit_code if self._hooks else 0
|
165
|
-
self._teardown(exit_code)
|
166
|
-
|
167
|
-
def _teardown(self, exit_code: int) -> None:
|
168
|
-
unregister_all_post_import_hooks()
|
169
183
|
|
170
|
-
|
171
|
-
|
172
|
-
|
184
|
+
# Clear the atexit hook---we're executing it now, after which the
|
185
|
+
# process will exit.
|
186
|
+
self._atexit_lambda = None
|
173
187
|
|
174
188
|
try:
|
175
|
-
self.
|
176
|
-
result = self._service.join()
|
177
|
-
if result and not self._settings._notebook:
|
178
|
-
os._exit(result)
|
189
|
+
self._teardown(self._hooks.exit_code if self._hooks else 0)
|
179
190
|
except Exception as e:
|
180
191
|
wandb.termlog(
|
181
|
-
f"
|
192
|
+
f"Encountered an error while tearing down the service manager: {e}",
|
182
193
|
repeat=False,
|
183
194
|
)
|
184
|
-
finally:
|
185
|
-
self._token.reset_environment()
|
186
195
|
|
187
196
|
def _get_service(self) -> "service._Service":
|
188
197
|
return self._service
|
wandb/sdk/wandb_run.py
CHANGED
@@ -2751,12 +2751,14 @@ class Run:
|
|
2751
2751
|
summary_items = [s.lower() for s in summary.split(",")]
|
2752
2752
|
summary_ops = []
|
2753
2753
|
valid = {"min", "max", "mean", "best", "last", "copy", "none"}
|
2754
|
+
# TODO: deprecate copy and best
|
2754
2755
|
for i in summary_items:
|
2755
2756
|
if i not in valid:
|
2756
2757
|
raise wandb.Error(f"Unhandled define_metric() arg: summary op: {i}")
|
2757
2758
|
summary_ops.append(i)
|
2758
2759
|
with telemetry.context(run=self) as tel:
|
2759
2760
|
tel.feature.metric_summary = True
|
2761
|
+
# TODO: deprecate goal
|
2760
2762
|
goal_cleaned: Optional[str] = None
|
2761
2763
|
if goal is not None:
|
2762
2764
|
goal_cleaned = goal[:3].lower()
|
@@ -2772,6 +2774,9 @@ class Run:
|
|
2772
2774
|
with telemetry.context(run=self) as tel:
|
2773
2775
|
tel.feature.metric_step_sync = True
|
2774
2776
|
|
2777
|
+
with telemetry.context(run=self) as tel:
|
2778
|
+
tel.feature.metric = True
|
2779
|
+
|
2775
2780
|
m = wandb_metric.Metric(
|
2776
2781
|
name=name,
|
2777
2782
|
step_metric=step_metric,
|
@@ -2783,8 +2788,6 @@ class Run:
|
|
2783
2788
|
)
|
2784
2789
|
m._set_callback(self._metric_callback)
|
2785
2790
|
m._commit()
|
2786
|
-
with telemetry.context(run=self) as tel:
|
2787
|
-
tel.feature.metric = True
|
2788
2791
|
return m
|
2789
2792
|
|
2790
2793
|
# TODO(jhr): annotate this
|