wandb-0.21.3-py3-none-win_amd64.whl → wandb-0.22.0-py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +1 -1
- wandb/_analytics.py +65 -0
- wandb/_iterutils.py +8 -0
- wandb/_pydantic/__init__.py +10 -11
- wandb/_pydantic/base.py +3 -53
- wandb/_pydantic/field_types.py +29 -0
- wandb/_pydantic/v1_compat.py +47 -30
- wandb/_strutils.py +40 -0
- wandb/apis/public/__init__.py +42 -0
- wandb/apis/public/api.py +17 -4
- wandb/apis/public/artifacts.py +5 -4
- wandb/apis/public/automations.py +2 -1
- wandb/apis/public/registries/_freezable_list.py +6 -6
- wandb/apis/public/registries/_utils.py +2 -1
- wandb/apis/public/registries/registries_search.py +4 -0
- wandb/apis/public/registries/registry.py +7 -0
- wandb/apis/public/runs.py +24 -6
- wandb/automations/_filters/expressions.py +3 -2
- wandb/automations/_filters/operators.py +2 -1
- wandb/automations/_validators.py +20 -0
- wandb/automations/actions.py +4 -2
- wandb/automations/events.py +4 -5
- wandb/bin/gpu_stats.exe +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/beta.py +48 -130
- wandb/cli/beta_sync.py +226 -0
- wandb/integration/dspy/__init__.py +5 -0
- wandb/integration/dspy/dspy.py +422 -0
- wandb/integration/weave/weave.py +55 -0
- wandb/proto/v3/wandb_internal_pb2.py +234 -224
- wandb/proto/v3/wandb_server_pb2.py +38 -57
- wandb/proto/v3/wandb_sync_pb2.py +87 -0
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_internal_pb2.py +226 -224
- wandb/proto/v4/wandb_server_pb2.py +38 -41
- wandb/proto/v4/wandb_sync_pb2.py +38 -0
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v5/wandb_internal_pb2.py +226 -224
- wandb/proto/v5/wandb_server_pb2.py +38 -41
- wandb/proto/v5/wandb_sync_pb2.py +39 -0
- wandb/proto/v5/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v6/wandb_base_pb2.py +3 -3
- wandb/proto/v6/wandb_internal_pb2.py +229 -227
- wandb/proto/v6/wandb_server_pb2.py +41 -44
- wandb/proto/v6/wandb_settings_pb2.py +3 -3
- wandb/proto/v6/wandb_sync_pb2.py +49 -0
- wandb/proto/v6/wandb_telemetry_pb2.py +15 -15
- wandb/proto/wandb_generate_proto.py +1 -0
- wandb/proto/wandb_sync_pb2.py +12 -0
- wandb/sdk/artifacts/_validators.py +50 -49
- wandb/sdk/artifacts/artifact.py +7 -7
- wandb/sdk/artifacts/exceptions.py +2 -1
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -3
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +3 -2
- wandb/sdk/artifacts/storage_policies/_factories.py +63 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +59 -124
- wandb/sdk/interface/interface.py +10 -0
- wandb/sdk/interface/interface_shared.py +9 -0
- wandb/sdk/lib/asyncio_compat.py +88 -23
- wandb/sdk/lib/gql_request.py +18 -7
- wandb/sdk/lib/printer.py +9 -13
- wandb/sdk/lib/progress.py +8 -6
- wandb/sdk/lib/service/service_connection.py +42 -12
- wandb/sdk/mailbox/wait_with_progress.py +1 -1
- wandb/sdk/wandb_init.py +9 -9
- wandb/sdk/wandb_run.py +13 -1
- wandb/sdk/wandb_settings.py +55 -0
- wandb/wandb_agent.py +35 -4
- {wandb-0.21.3.dist-info → wandb-0.22.0.dist-info}/METADATA +1 -1
- {wandb-0.21.3.dist-info → wandb-0.22.0.dist-info}/RECORD +76 -64
- {wandb-0.21.3.dist-info → wandb-0.22.0.dist-info}/WHEEL +0 -0
- {wandb-0.21.3.dist-info → wandb-0.22.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.21.3.dist-info → wandb-0.22.0.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/artifacts/storage_handlers/s3_handler.py
CHANGED

@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Sequence
 from urllib.parse import parse_qsl, urlparse
 
 from wandb import util
+from wandb._strutils import ensureprefix
 from wandb.errors import CommError
 from wandb.errors.term import termlog
 from wandb.sdk.artifacts.artifact_file_cache import get_artifact_file_cache
@@ -95,7 +96,7 @@ class S3Handler(StorageHandler):
         path, hit, cache_open = self._cache.check_etag_obj_path(
             URIStr(manifest_entry.ref),
             ETag(manifest_entry.digest),
-            manifest_entry.size
+            manifest_entry.size or 0,
         )
         if hit:
             return path
@@ -328,7 +329,7 @@ class S3Handler(StorageHandler):
             return True
 
         # Enforce HTTPS otherwise
-        https_url = url
+        https_url = ensureprefix(url, "https://")
         netloc = urlparse(https_url).netloc
         return bool(
             # Match for https://cwobject.com
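The `ensureprefix` helper comes from `wandb/_strutils.py`, a module added in this release (+40 lines) whose body is not shown in this diff. A minimal sketch consistent with the call site above (the implementation is an assumption, not the actual source):

```python
# Hypothetical sketch of wandb._strutils.ensureprefix; only its call signature
# is visible in this diff.
def ensureprefix(s: str, prefix: str) -> str:
    """Return `s`, prepending `prefix` if it is not already present."""
    return s if s.startswith(prefix) else prefix + s


assert ensureprefix("example.com", "https://") == "https://example.com"
assert ensureprefix("https://example.com", "https://") == "https://example.com"
```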
wandb/sdk/artifacts/storage_policies/_factories.py
ADDED

@@ -0,0 +1,63 @@
+from __future__ import annotations
+
+from typing import Final
+
+from requests import Response, Session
+from requests.adapters import HTTPAdapter
+from urllib3.util.retry import Retry
+
+from ..storage_handler import StorageHandler
+from ..storage_handlers.azure_handler import AzureHandler
+from ..storage_handlers.gcs_handler import GCSHandler
+from ..storage_handlers.http_handler import HTTPHandler
+from ..storage_handlers.local_file_handler import LocalFileHandler
+from ..storage_handlers.s3_handler import S3Handler
+from ..storage_handlers.wb_artifact_handler import WBArtifactHandler
+from ..storage_handlers.wb_local_artifact_handler import WBLocalArtifactHandler
+
+# Sleep length: 0, 2, 4, 8, 16, 32, 64, 120, 120, 120, 120, 120, 120, 120, 120, 120
+# seconds, i.e. a total of 20min 6s.
+HTTP_RETRY_STRATEGY: Final[Retry] = Retry(
+    backoff_factor=1,
+    total=16,
+    status_forcelist=(308, 408, 409, 429, 500, 502, 503, 504),
+)
+HTTP_POOL_CONNECTIONS: Final[int] = 64
+HTTP_POOL_MAXSIZE: Final[int] = 64
+
+
+def raise_for_status(response: Response, *_, **__) -> None:
+    """A `requests.Session` hook to raise for status on all requests."""
+    response.raise_for_status()
+
+
+def make_http_session() -> Session:
+    """A factory that returns a `requests.Session` for use with artifact storage handlers."""
+    session = Session()
+
+    # Explicitly configure the retry strategy for http/https adapters.
+    adapter = HTTPAdapter(
+        max_retries=HTTP_RETRY_STRATEGY,
+        pool_connections=HTTP_POOL_CONNECTIONS,
+        pool_maxsize=HTTP_POOL_MAXSIZE,
+    )
+    session.mount("http://", adapter)
+    session.mount("https://", adapter)
+
+    # Always raise on HTTP status errors.
+    session.hooks["response"].append(raise_for_status)
+    return session
+
+
+def make_storage_handlers(session: Session) -> list[StorageHandler]:
+    """A factory that returns the default artifact storage handlers."""
+    return [
+        S3Handler(),  # s3
+        GCSHandler(),  # gcs
+        AzureHandler(),  # azure
+        HTTPHandler(session, scheme="http"),  # http
+        HTTPHandler(session, scheme="https"),  # https
+        WBArtifactHandler(),  # artifact
+        WBLocalArtifactHandler(),  # local_artifact
+        LocalFileHandler(),  # file_handler
+    ]
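The sleep-length comment checks out against urllib3's exponential backoff (assuming its default 120-second backoff cap), and the two factories compose directly. A short sketch, using the module path as added in this diff:

```python
# Sanity check of the retry schedule in the comment above:
# 0 + 2 + 4 + 8 + 16 + 32 + 64 plus nine sleeps capped at 120 s.
delays = [0, 2, 4, 8, 16, 32, 64] + [120] * 9
assert len(delays) == 16          # matches total=16
assert sum(delays) == 1206        # 20 min 6 s

# Composing the two factories:
from wandb.sdk.artifacts.storage_policies._factories import (
    make_http_session,
    make_storage_handlers,
)

session = make_http_session()              # retrying session that raises on HTTP errors
handlers = make_storage_handlers(session)  # default handler chain, in lookup order
print([type(h).__name__ for h in handlers])
```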
wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py
CHANGED

@@ -16,7 +16,6 @@ from typing import IO, TYPE_CHECKING, Any, NamedTuple, Sequence
 from urllib.parse import quote
 
 import requests
-import urllib3
 
 from wandb import env
 from wandb.errors.term import termwarn
@@ -27,40 +26,24 @@ from wandb.sdk.artifacts.artifact_file_cache import (
     get_artifact_file_cache,
 )
 from wandb.sdk.artifacts.staging import get_staging_dir
-from wandb.sdk.artifacts.storage_handlers.azure_handler import AzureHandler
-from wandb.sdk.artifacts.storage_handlers.gcs_handler import GCSHandler
-from wandb.sdk.artifacts.storage_handlers.http_handler import HTTPHandler
-from wandb.sdk.artifacts.storage_handlers.local_file_handler import LocalFileHandler
 from wandb.sdk.artifacts.storage_handlers.multi_handler import MultiHandler
-from wandb.sdk.artifacts.storage_handlers.s3_handler import S3Handler
 from wandb.sdk.artifacts.storage_handlers.tracking_handler import TrackingHandler
-from wandb.sdk.artifacts.storage_handlers.wb_artifact_handler import WBArtifactHandler
-from wandb.sdk.artifacts.storage_handlers.wb_local_artifact_handler import (
-    WBLocalArtifactHandler,
-)
 from wandb.sdk.artifacts.storage_layout import StorageLayout
 from wandb.sdk.artifacts.storage_policies.register import WANDB_STORAGE_POLICY
 from wandb.sdk.artifacts.storage_policy import StoragePolicy
 from wandb.sdk.internal.internal_api import Api as InternalApi
 from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
-from wandb.sdk.lib.hashutil import B64MD5, b64_to_hex_id, hex_to_b64_id
+from wandb.sdk.lib.hashutil import b64_to_hex_id, hex_to_b64_id
 from wandb.sdk.lib.paths import FilePathStr, URIStr
 
+from ._factories import make_http_session, make_storage_handlers
+
 if TYPE_CHECKING:
     from wandb.filesync.step_prepare import StepPrepare
     from wandb.sdk.artifacts.artifact import Artifact
     from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
     from wandb.sdk.internal import progress
 
-# Sleep length: 0, 2, 4, 8, 16, 32, 64, 120, 120, 120, 120, 120, 120, 120, 120, 120
-# seconds, i.e. a total of 20min 6s.
-_REQUEST_RETRY_STRATEGY = urllib3.util.retry.Retry(
-    backoff_factor=1,
-    total=16,
-    status_forcelist=(308, 408, 409, 429, 500, 502, 503, 504),
-)
-_REQUEST_POOL_CONNECTIONS = 64
-_REQUEST_POOL_MAXSIZE = 64
 
 # AWS S3 max upload parts without having to make additional requests for extra parts
 S3_MAX_PART_NUMBERS = 1000
@@ -96,48 +79,23 @@ class WandbStoragePolicy(StoragePolicy):
 
     @classmethod
     def from_config(
-        cls, config: dict, api: InternalApi | None = None
+        cls, config: dict[str, Any], api: InternalApi | None = None
     ) -> WandbStoragePolicy:
         return cls(config=config, api=api)
 
     def __init__(
         self,
-        config: dict | None = None,
+        config: dict[str, Any] | None = None,
         cache: ArtifactFileCache | None = None,
         api: InternalApi | None = None,
+        session: requests.Session | None = None,
     ) -> None:
-        self._cache = cache or get_artifact_file_cache()
         self._config = config or {}
-        self._session = requests.Session()
-        adapter = requests.adapters.HTTPAdapter(
-            max_retries=_REQUEST_RETRY_STRATEGY,
-            pool_connections=_REQUEST_POOL_CONNECTIONS,
-            pool_maxsize=_REQUEST_POOL_MAXSIZE,
-        )
-        self._session.mount("http://", adapter)
-        self._session.mount("https://", adapter)
-
-        s3 = S3Handler()
-        gcs = GCSHandler()
-        azure = AzureHandler()
-        http = HTTPHandler(self._session)
-        https = HTTPHandler(self._session, scheme="https")
-        artifact = WBArtifactHandler()
-        local_artifact = WBLocalArtifactHandler()
-        file_handler = LocalFileHandler()
-
+        self._cache = cache or get_artifact_file_cache()
+        self._session = session or make_http_session()
         self._api = api or InternalApi()
         self._handler = MultiHandler(
-            handlers=[
-                s3,
-                gcs,
-                azure,
-                http,
-                https,
-                artifact,
-                local_artifact,
-                file_handler,
-            ],
+            handlers=make_storage_handlers(self._session),
             default_handler=TrackingHandler(),
         )
@@ -167,54 +125,52 @@ class WandbStoragePolicy(StoragePolicy):
         self._cache._override_cache_path = dest_path
 
         path, hit, cache_open = self._cache.check_md5_obj_path(
-            B64MD5(manifest_entry.digest),
-            manifest_entry.size if manifest_entry.size is not None else 0,
+            manifest_entry.digest,
+            size=manifest_entry.size or 0,
         )
         if hit:
             return path
 
-        if manifest_entry._download_url is not None:
+        if (url := manifest_entry._download_url) is not None:
             # Use multipart parallel download for large file
             if (
-                executor
-                and manifest_entry.size
-                and self._should_multipart_download(manifest_entry.size, multipart)
+                executor
+                and (size := manifest_entry.size)
+                and self._should_multipart_download(size, multipart)
             ):
-                self._multipart_file_download(
-                    executor,
-                    manifest_entry._download_url,
-                    manifest_entry.size,
-                    cache_open,
-                )
+                self._multipart_file_download(executor, url, size, cache_open)
                 return path
+
             # Serial download
-            response = self._session.get(manifest_entry._download_url, stream=True)
             try:
-                response.raise_for_status()
-            except Exception:
+                response = self._session.get(url, stream=True)
+            except requests.HTTPError:
                 # Signed URL might have expired, fall back to fetching it one by one.
                 manifest_entry._download_url = None
+
         if manifest_entry._download_url is None:
             auth = None
-            http_headers = _thread_local_api_settings.headers or {}
-            if self._api.access_token is not None:
-                http_headers["Authorization"] = f"Bearer {self._api.access_token}"
-            elif _thread_local_api_settings.cookies is None:
+            headers = _thread_local_api_settings.headers
+            cookies = _thread_local_api_settings.cookies
+
+            # For auth, prefer using (in order): auth header, cookies, HTTP Basic Auth
+            if token := self._api.access_token:
+                headers = {**(headers or {}), "Authorization": f"Bearer {token}"}
+            elif cookies is not None:
+                pass
+            else:
                 auth = ("api", self._api.api_key or "")
+
+            file_url = self._file_url(
+                self._api,
+                artifact.entity,
+                artifact.project,
+                artifact.name.split(":")[0],
+                manifest_entry,
+            )
             response = self._session.get(
-                self._file_url(
-                    self._api,
-                    artifact.entity,
-                    artifact.project,
-                    artifact.name.split(":")[0],
-                    manifest_entry,
-                ),
-                auth=auth,
-                cookies=_thread_local_api_settings.cookies,
-                headers=http_headers,
-                stream=True,
+                file_url, auth=auth, cookies=cookies, headers=headers, stream=True
             )
-            response.raise_for_status()
 
         with cache_open(mode="wb") as file:
             for data in response.iter_content(chunk_size=16 * 1024):
@@ -269,12 +225,7 @@ class WandbStoragePolicy(StoragePolicy):
             # Other threads has error, no need to start
             if download_has_error.is_set():
                 return
-            response = self._session.get(
-                url=download_url,
-                headers=headers,
-                stream=True,
-            )
-            response.raise_for_status()
+            response = self._session.get(url=download_url, headers=headers, stream=True)
 
             file_offset = start
             for content in response.iter_content(chunk_size=_HTTP_RES_CHUNK_SIZE_BYTES):
@@ -376,43 +327,27 @@ class WandbStoragePolicy(StoragePolicy):
         entity_name: str,
         project_name: str,
         artifact_name: str,
-        manifest_entry: ArtifactManifestEntry,
+        entry: ArtifactManifestEntry,
     ) -> str:
-        storage_layout = self._config.get("storageLayout", StorageLayout.V1)
-        storage_region = self._config.get("storageRegion", "default")
-        md5_hex = b64_to_hex_id(B64MD5(manifest_entry.digest))
+        layout = self._config.get("storageLayout", StorageLayout.V1)
+        region = self._config.get("storageRegion", "default")
+        md5_hex = b64_to_hex_id(entry.digest)
 
-        if storage_layout == StorageLayout.V1:
-            return "{}/artifacts/{}/{}".format(
-                api.settings("base_url"), entity_name, md5_hex
-            )
-        elif storage_layout == StorageLayout.V2:
+        base_url: str = api.settings("base_url")
+
+        if layout == StorageLayout.V1:
+            return f"{base_url}/artifacts/{entity_name}/{md5_hex}"
+
+        if layout == StorageLayout.V2:
+            birth_artifact_id = entry.birth_artifact_id or ""
             if api._server_supports(
-                ServerFeature.ARTIFACT_COLLECTION_MEMBERSHIP_FILE_DOWNLOAD_HANDLER
+                ServerFeature.ARTIFACT_COLLECTION_MEMBERSHIP_FILE_DOWNLOAD_HANDLER
             ):
-                return "{}/artifactsV2/{}/{}/{}/{}/{}/{}/{}".format(
-                    api.settings("base_url"),
-                    storage_region,
-                    quote(entity_name),
-                    quote(project_name),
-                    quote(artifact_name),
-                    quote(manifest_entry.birth_artifact_id or ""),
-                    md5_hex,
-                    manifest_entry.path.name,
-                )
-            return "{}/artifactsV2/{}/{}/{}/{}".format(
-                api.settings("base_url"),
-                storage_region,
-                entity_name,
-                quote(
-                    manifest_entry.birth_artifact_id
-                    if manifest_entry.birth_artifact_id is not None
-                    else ""
-                ),
-                md5_hex,
-            )
-        else:
-            raise Exception(f"unrecognized storage layout: {storage_layout}")
+                return f"{base_url}/artifactsV2/{region}/{quote(entity_name)}/{quote(project_name)}/{quote(artifact_name)}/{quote(birth_artifact_id)}/{md5_hex}/{entry.path.name}"
+
+            return f"{base_url}/artifactsV2/{region}/{entity_name}/{quote(birth_artifact_id)}/{md5_hex}"
+
+        raise ValueError(f"unrecognized storage layout: {layout!r}")
@@ -486,7 +421,7 @@ class WandbStoragePolicy(StoragePolicy):
         True if the file was a duplicate (did not need to be uploaded),
         False if it needed to be uploaded or was a reference (nothing to dedupe).
         """
-        file_size = entry.size
+        file_size = entry.size or 0
         chunk_size = self.calc_chunk_size(file_size)
         upload_parts = []
         hex_digests = {}
@@ -562,8 +497,8 @@ class WandbStoragePolicy(StoragePolicy):
 
         # Cache upon successful upload.
         _, hit, cache_open = self._cache.check_md5_obj_path(
-            B64MD5(entry.digest),
-            entry.size if entry.size is not None else 0,
+            entry.digest,
+            size=entry.size or 0,
         )
 
         staging_dir = get_staging_dir()
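One practical effect of this refactor is that `WandbStoragePolicy` now accepts an injected `requests.Session` instead of always building its own, which makes the retry and hook behavior replaceable in tests. A minimal sketch of the new parameter (this assumes constructing the policy offline is safe; `_session` is a private attribute, shown only to illustrate the wiring in the diff above):

```python
from requests import Session

from wandb.sdk.artifacts.storage_policies.wandb_storage_policy import WandbStoragePolicy

# Inject a custom session, e.g. one instrumented for tests, in place of the
# default session built by make_http_session().
session = Session()
policy = WandbStoragePolicy(session=session)
assert policy._session is session
```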
wandb/sdk/interface/interface.py
CHANGED

@@ -883,6 +883,16 @@ class InterfaceBase:
     ) -> MailboxHandle[pb.Result]:
         raise NotImplementedError
 
+    def publish_probe_system_info(self) -> None:
+        probe_system_info = pb.ProbeSystemInfoRequest()
+        return self._publish_probe_system_info(probe_system_info)
+
+    @abstractmethod
+    def _publish_probe_system_info(
+        self, probe_system_info: pb.ProbeSystemInfoRequest
+    ) -> None:
+        raise NotImplementedError
+
     def join(self) -> None:
         # Drop indicates that the internal process has already been shutdown
         if self._drop:

wandb/sdk/interface/interface_shared.py
CHANGED

@@ -112,6 +112,7 @@ class InterfaceShared(InterfaceBase):
         python_packages: Optional[pb.PythonPackagesRequest] = None,
         job_input: Optional[pb.JobInputRequest] = None,
         run_finish_without_exit: Optional[pb.RunFinishWithoutExitRequest] = None,
+        probe_system_info: Optional[pb.ProbeSystemInfoRequest] = None,
     ) -> pb.Record:
         request = pb.Request()
         if get_summary:
@@ -178,6 +179,8 @@ class InterfaceShared(InterfaceBase):
             request.job_input.CopyFrom(job_input)
         elif run_finish_without_exit:
             request.run_finish_without_exit.CopyFrom(run_finish_without_exit)
+        elif probe_system_info:
+            request.probe_system_info.CopyFrom(probe_system_info)
         else:
             raise Exception("Invalid request")
         record = self._make_record(request=request)
@@ -330,6 +333,12 @@ class InterfaceShared(InterfaceBase):
         rec = self._make_record(use_artifact=use_artifact)
         self._publish(rec)
 
+    def _publish_probe_system_info(
+        self, probe_system_info: pb.ProbeSystemInfoRequest
+    ) -> None:
+        record = self._make_request(probe_system_info=probe_system_info)
+        self._publish(record)
+
     def _deliver_artifact(
         self,
         log_artifact: pb.LogArtifactRequest,
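The new `publish_probe_system_info` follows the template-method pattern used throughout this interface: the base class builds the protobuf request, an abstract `_publish_*` hook supplies delivery, and `InterfaceShared` routes it through `_make_request`'s `elif` chain. A self-contained sketch of the same pattern (illustrative only; `ProbeRequest` is a stand-in for `pb.ProbeSystemInfoRequest`):

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class ProbeRequest:  # stand-in for pb.ProbeSystemInfoRequest
    pass


class Interface(ABC):
    def publish_probe_system_info(self) -> None:
        # Public API: construct the request, delegate delivery to the subclass.
        self._publish_probe_system_info(ProbeRequest())

    @abstractmethod
    def _publish_probe_system_info(self, req: ProbeRequest) -> None: ...


class SharedInterface(Interface):
    def _publish_probe_system_info(self, req: ProbeRequest) -> None:
        # The real class wraps the request in a pb.Record and publishes it.
        print(f"wrapping {req!r} in a record and publishing it")


SharedInterface().publish_probe_system_info()
```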
wandb/sdk/lib/asyncio_compat.py
CHANGED

@@ -7,7 +7,7 @@ import concurrent
 import concurrent.futures
 import contextlib
 import threading
-from typing import Any, AsyncIterator, Callable, Coroutine, Iterator, TypeVar
+from typing import Any, AsyncIterator, Callable, Coroutine, TypeVar
 
 _T = TypeVar("_T")
 
@@ -143,34 +143,71 @@ class TaskGroup:
         """
         self._tasks.append(asyncio.create_task(coro))
 
-    async def _wait_all(self) -> None:
-        """Block until all tasks complete.
+    async def _wait_all(self, *, race: bool, timeout: float | None) -> None:
+        """Block until tasks complete.
+
+        Args:
+            race: If true, blocks until the first task completes and then
+                cancels the rest. Otherwise, waits for all tasks or until
+                the first exception.
+            timeout: How long to wait.
 
         Raises:
+            TimeoutError: If the timeout expires.
             Exception: If one or more tasks raises an exception, one of these
                 is raised arbitrarily.
         """
-        done, _ = await asyncio.wait(
+        if not self._tasks:
+            return
+
+        if race:
+            return_when = asyncio.FIRST_COMPLETED
+        else:
+            return_when = asyncio.FIRST_EXCEPTION
+
+        done, pending = await asyncio.wait(
             self._tasks,
-            return_when=concurrent.futures.FIRST_EXCEPTION,
+            timeout=timeout,
+            return_when=return_when,
         )
 
+        if not done:
+            raise TimeoutError(f"Timed out after {timeout} seconds.")
+
+        # If any of the finished tasks raised an exception, pick the first one.
         for task in done:
-            if (exc := task.exception()) is not None:
-                raise exc
+            if exc := task.exception():
+                raise exc
 
-    def _cancel_all(self) -> None:
-        """Cancel all tasks."""
+        # Wait for remaining tasks to clean up, then re-raise any exceptions
+        # that arise. Note that pending is only non-empty when race=True.
+        for task in pending:
+            task.cancel()
+        await asyncio.gather(*pending, return_exceptions=True)
+        for task in pending:
+            if task.cancelled():
+                continue
+            if exc := task.exception():
+                raise exc
+
+    async def _cancel_all(self) -> None:
+        """Cancel all tasks.
+
+        Blocks until cancelled tasks complete to allow them to clean up.
+        Ignores exceptions.
+        """
         for task in self._tasks:
             # NOTE: It is safe to cancel tasks that have already completed.
             task.cancel()
+        await asyncio.gather(*self._tasks, return_exceptions=True)
 
 
 @contextlib.asynccontextmanager
-async def open_task_group() -> AsyncIterator[TaskGroup]:
+async def open_task_group(
+    *,
+    exit_timeout: float | None = None,
+    race: bool = False,
+) -> AsyncIterator[TaskGroup]:
     """Create a task group.
 
     `asyncio` gained task groups in Python 3.11.
@@ -184,30 +221,58 @@ async def open_task_group() -> AsyncIterator[TaskGroup]:
     NOTE: Subtask exceptions do not propagate until the context manager exits.
     This means that the task group cannot cancel code running inside the
     `async with` block .
+
+    Args:
+        exit_timeout: An optional timeout in seconds. When exiting the
+            context manager, if tasks don't complete in this time,
+            they are cancelled and a TimeoutError is raised.
+        race: If true, all pending tasks are cancelled once any task
+            in the group completes. Prefer to use the race() function instead.
+
+    Raises:
+        TimeoutError: if exit_timeout is specified and tasks don't finish
+            in time.
     """
     task_group = TaskGroup()
 
     try:
         yield task_group
-        await task_group._wait_all()
+        await task_group._wait_all(race=race, timeout=exit_timeout)
     finally:
-        task_group._cancel_all()
+        await task_group._cancel_all()
 
 
-@contextlib.contextmanager
-def cancel_on_exit(coro: Coroutine[Any, Any, Any]) -> Iterator[None]:
+@contextlib.asynccontextmanager
+async def cancel_on_exit(coro: Coroutine[Any, Any, Any]) -> AsyncIterator[None]:
     """Schedule a task, cancelling it when exiting the context manager.
 
     If the context manager exits successfully but the given coroutine raises
     an exception, that exception is reraised. The exception is suppressed
     if the context manager raises an exception.
     """
-    task = asyncio.create_task(coro)
 
-    try:
+    async def stop_immediately():
+        pass
+
+    async with open_task_group(race=True) as group:
+        group.start_soon(stop_immediately())
+        group.start_soon(coro)
         yield
 
-        if task.done() and (exc := task.exception()):
-            raise exc
-    finally:
-        task.cancel()
+
+async def race(*coros: Coroutine[Any, Any, Any]) -> None:
+    """Wait until the first completed task.
+
+    After any coroutine completes, all others are cancelled.
+    If the current task is cancelled, all coroutines are cancelled too.
+
+    If coroutines complete simultaneously and any one of them raises
+    an exception, an arbitrary one is propagated. Similarly, if any coroutines
+    raise exceptions during cancellation, one of them propagates.
+
+    Args:
+        coros: Coroutines to race.
+    """
+    async with open_task_group(race=True) as tg:
+        for coro in coros:
+            tg.start_soon(coro)
wandb/sdk/lib/gql_request.py
CHANGED

@@ -4,7 +4,9 @@ Note: This was originally wandb/vendor/gql-0.2.0/wandb_gql/transport/requests.py
 The only substantial change is to reuse a requests.Session object.
 """
 
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from __future__ import annotations
+
+from typing import Any, Callable
 
 import requests
 from wandb_gql.transport.http import HTTPTransport
@@ -12,15 +14,17 @@ from wandb_graphql.execution import ExecutionResult
 from wandb_graphql.language import ast
 from wandb_graphql.language.printer import print_ast
 
+from wandb._analytics import tracked_func
+
 
 class GraphQLSession(HTTPTransport):
     def __init__(
         self,
         url: str,
-        auth: Optional[Union[Tuple[str, str], Callable]] = None,
+        auth: tuple[str, str] | Callable | None = None,
         use_json: bool = False,
-        timeout: Optional[Union[int, float]] = None,
-        proxies: Optional[Dict[str, str]] = None,
+        timeout: int | float | None = None,
+        proxies: dict[str, str] | None = None,
         **kwargs: Any,
     ) -> None:
         """Setup a session for sending GraphQL queries and mutations.
@@ -42,15 +46,22 @@ class GraphQLSession(HTTPTransport):
     def execute(
         self,
        document: ast.Node,
-        variable_values: Optional[Dict[str, Any]] = None,
-        timeout: Optional[Union[int, float]] = None,
+        variable_values: dict[str, Any] | None = None,
+        timeout: int | float | None = None,
     ) -> ExecutionResult:
         query_str = print_ast(document)
         payload = {"query": query_str, "variables": variable_values or {}}
 
         data_key = "json" if self.use_json else "data"
+
+        headers = self.headers.copy() if self.headers else {}
+
+        # If we're tracking a calling python function, include it in the headers
+        if func_info := tracked_func():
+            headers.update(func_info.to_headers())
+
         post_args = {
-            "headers": self.headers,
+            "headers": headers or None,
             "cookies": self.cookies,
             "timeout": timeout or self.default_timeout,
             data_key: payload,