wandb 0.21.3__py3-none-win32.whl → 0.22.0__py3-none-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +1 -1
  3. wandb/_analytics.py +65 -0
  4. wandb/_iterutils.py +8 -0
  5. wandb/_pydantic/__init__.py +10 -11
  6. wandb/_pydantic/base.py +3 -53
  7. wandb/_pydantic/field_types.py +29 -0
  8. wandb/_pydantic/v1_compat.py +47 -30
  9. wandb/_strutils.py +40 -0
  10. wandb/apis/public/__init__.py +42 -0
  11. wandb/apis/public/api.py +17 -4
  12. wandb/apis/public/artifacts.py +5 -4
  13. wandb/apis/public/automations.py +2 -1
  14. wandb/apis/public/registries/_freezable_list.py +6 -6
  15. wandb/apis/public/registries/_utils.py +2 -1
  16. wandb/apis/public/registries/registries_search.py +4 -0
  17. wandb/apis/public/registries/registry.py +7 -0
  18. wandb/apis/public/runs.py +24 -6
  19. wandb/automations/_filters/expressions.py +3 -2
  20. wandb/automations/_filters/operators.py +2 -1
  21. wandb/automations/_validators.py +20 -0
  22. wandb/automations/actions.py +4 -2
  23. wandb/automations/events.py +4 -5
  24. wandb/bin/gpu_stats.exe +0 -0
  25. wandb/bin/wandb-core +0 -0
  26. wandb/cli/beta.py +48 -130
  27. wandb/cli/beta_sync.py +226 -0
  28. wandb/integration/dspy/__init__.py +5 -0
  29. wandb/integration/dspy/dspy.py +422 -0
  30. wandb/integration/weave/weave.py +55 -0
  31. wandb/proto/v3/wandb_internal_pb2.py +234 -224
  32. wandb/proto/v3/wandb_server_pb2.py +38 -57
  33. wandb/proto/v3/wandb_sync_pb2.py +87 -0
  34. wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
  35. wandb/proto/v4/wandb_internal_pb2.py +226 -224
  36. wandb/proto/v4/wandb_server_pb2.py +38 -41
  37. wandb/proto/v4/wandb_sync_pb2.py +38 -0
  38. wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
  39. wandb/proto/v5/wandb_internal_pb2.py +226 -224
  40. wandb/proto/v5/wandb_server_pb2.py +38 -41
  41. wandb/proto/v5/wandb_sync_pb2.py +39 -0
  42. wandb/proto/v5/wandb_telemetry_pb2.py +12 -12
  43. wandb/proto/v6/wandb_base_pb2.py +3 -3
  44. wandb/proto/v6/wandb_internal_pb2.py +229 -227
  45. wandb/proto/v6/wandb_server_pb2.py +41 -44
  46. wandb/proto/v6/wandb_settings_pb2.py +3 -3
  47. wandb/proto/v6/wandb_sync_pb2.py +49 -0
  48. wandb/proto/v6/wandb_telemetry_pb2.py +15 -15
  49. wandb/proto/wandb_generate_proto.py +1 -0
  50. wandb/proto/wandb_sync_pb2.py +12 -0
  51. wandb/sdk/artifacts/_validators.py +50 -49
  52. wandb/sdk/artifacts/artifact.py +7 -7
  53. wandb/sdk/artifacts/exceptions.py +2 -1
  54. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +1 -1
  55. wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -3
  56. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -1
  57. wandb/sdk/artifacts/storage_handlers/s3_handler.py +3 -2
  58. wandb/sdk/artifacts/storage_policies/_factories.py +63 -0
  59. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +59 -124
  60. wandb/sdk/interface/interface.py +10 -0
  61. wandb/sdk/interface/interface_shared.py +9 -0
  62. wandb/sdk/lib/asyncio_compat.py +88 -23
  63. wandb/sdk/lib/gql_request.py +18 -7
  64. wandb/sdk/lib/printer.py +9 -13
  65. wandb/sdk/lib/progress.py +8 -6
  66. wandb/sdk/lib/service/service_connection.py +42 -12
  67. wandb/sdk/mailbox/wait_with_progress.py +1 -1
  68. wandb/sdk/wandb_init.py +9 -9
  69. wandb/sdk/wandb_run.py +13 -1
  70. wandb/sdk/wandb_settings.py +55 -0
  71. wandb/wandb_agent.py +35 -4
  72. {wandb-0.21.3.dist-info → wandb-0.22.0.dist-info}/METADATA +1 -1
  73. {wandb-0.21.3.dist-info → wandb-0.22.0.dist-info}/RECORD +76 -64
  74. {wandb-0.21.3.dist-info → wandb-0.22.0.dist-info}/WHEEL +0 -0
  75. {wandb-0.21.3.dist-info → wandb-0.22.0.dist-info}/entry_points.txt +0 -0
  76. {wandb-0.21.3.dist-info → wandb-0.22.0.dist-info}/licenses/LICENSE +0 -0
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Sequence
10
10
  from urllib.parse import parse_qsl, urlparse
11
11
 
12
12
  from wandb import util
13
+ from wandb._strutils import ensureprefix
13
14
  from wandb.errors import CommError
14
15
  from wandb.errors.term import termlog
15
16
  from wandb.sdk.artifacts.artifact_file_cache import get_artifact_file_cache
@@ -95,7 +96,7 @@ class S3Handler(StorageHandler):
95
96
  path, hit, cache_open = self._cache.check_etag_obj_path(
96
97
  URIStr(manifest_entry.ref),
97
98
  ETag(manifest_entry.digest),
98
- manifest_entry.size if manifest_entry.size is not None else 0,
99
+ manifest_entry.size or 0,
99
100
  )
100
101
  if hit:
101
102
  return path
@@ -328,7 +329,7 @@ class S3Handler(StorageHandler):
328
329
  return True
329
330
 
330
331
  # Enforce HTTPS otherwise
331
- https_url = url if url.startswith("https://") else f"https://{url}"
332
+ https_url = ensureprefix(url, "https://")
332
333
  netloc = urlparse(https_url).netloc
333
334
  return bool(
334
335
  # Match for https://cwobject.com
@@ -0,0 +1,63 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Final
4
+
5
+ from requests import Response, Session
6
+ from requests.adapters import HTTPAdapter
7
+ from urllib3.util.retry import Retry
8
+
9
+ from ..storage_handler import StorageHandler
10
+ from ..storage_handlers.azure_handler import AzureHandler
11
+ from ..storage_handlers.gcs_handler import GCSHandler
12
+ from ..storage_handlers.http_handler import HTTPHandler
13
+ from ..storage_handlers.local_file_handler import LocalFileHandler
14
+ from ..storage_handlers.s3_handler import S3Handler
15
+ from ..storage_handlers.wb_artifact_handler import WBArtifactHandler
16
+ from ..storage_handlers.wb_local_artifact_handler import WBLocalArtifactHandler
17
+
18
+ # Sleep length: 0, 2, 4, 8, 16, 32, 64, 120, 120, 120, 120, 120, 120, 120, 120, 120
19
+ # seconds, i.e. a total of 20min 6s.
20
+ HTTP_RETRY_STRATEGY: Final[Retry] = Retry(
21
+ backoff_factor=1,
22
+ total=16,
23
+ status_forcelist=(308, 408, 409, 429, 500, 502, 503, 504),
24
+ )
25
+ HTTP_POOL_CONNECTIONS: Final[int] = 64
26
+ HTTP_POOL_MAXSIZE: Final[int] = 64
27
+
28
+
29
+ def raise_for_status(response: Response, *_, **__) -> None:
30
+ """A `requests.Session` hook to raise for status on all requests."""
31
+ response.raise_for_status()
32
+
33
+
34
+ def make_http_session() -> Session:
35
+ """A factory that returns a `requests.Session` for use with artifact storage handlers."""
36
+ session = Session()
37
+
38
+ # Explicitly configure the retry strategy for http/https adapters.
39
+ adapter = HTTPAdapter(
40
+ max_retries=HTTP_RETRY_STRATEGY,
41
+ pool_connections=HTTP_POOL_CONNECTIONS,
42
+ pool_maxsize=HTTP_POOL_MAXSIZE,
43
+ )
44
+ session.mount("http://", adapter)
45
+ session.mount("https://", adapter)
46
+
47
+ # Always raise on HTTP status errors.
48
+ session.hooks["response"].append(raise_for_status)
49
+ return session
50
+
51
+
52
+ def make_storage_handlers(session: Session) -> list[StorageHandler]:
53
+ """A factory that returns the default artifact storage handlers."""
54
+ return [
55
+ S3Handler(), # s3
56
+ GCSHandler(), # gcs
57
+ AzureHandler(), # azure
58
+ HTTPHandler(session, scheme="http"), # http
59
+ HTTPHandler(session, scheme="https"), # https
60
+ WBArtifactHandler(), # artifact
61
+ WBLocalArtifactHandler(), # local_artifact
62
+ LocalFileHandler(), # file_handler
63
+ ]
@@ -16,7 +16,6 @@ from typing import IO, TYPE_CHECKING, Any, NamedTuple, Sequence
16
16
  from urllib.parse import quote
17
17
 
18
18
  import requests
19
- import urllib3
20
19
 
21
20
  from wandb import env
22
21
  from wandb.errors.term import termwarn
@@ -27,40 +26,24 @@ from wandb.sdk.artifacts.artifact_file_cache import (
27
26
  get_artifact_file_cache,
28
27
  )
29
28
  from wandb.sdk.artifacts.staging import get_staging_dir
30
- from wandb.sdk.artifacts.storage_handlers.azure_handler import AzureHandler
31
- from wandb.sdk.artifacts.storage_handlers.gcs_handler import GCSHandler
32
- from wandb.sdk.artifacts.storage_handlers.http_handler import HTTPHandler
33
- from wandb.sdk.artifacts.storage_handlers.local_file_handler import LocalFileHandler
34
29
  from wandb.sdk.artifacts.storage_handlers.multi_handler import MultiHandler
35
- from wandb.sdk.artifacts.storage_handlers.s3_handler import S3Handler
36
30
  from wandb.sdk.artifacts.storage_handlers.tracking_handler import TrackingHandler
37
- from wandb.sdk.artifacts.storage_handlers.wb_artifact_handler import WBArtifactHandler
38
- from wandb.sdk.artifacts.storage_handlers.wb_local_artifact_handler import (
39
- WBLocalArtifactHandler,
40
- )
41
31
  from wandb.sdk.artifacts.storage_layout import StorageLayout
42
32
  from wandb.sdk.artifacts.storage_policies.register import WANDB_STORAGE_POLICY
43
33
  from wandb.sdk.artifacts.storage_policy import StoragePolicy
44
34
  from wandb.sdk.internal.internal_api import Api as InternalApi
45
35
  from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
46
- from wandb.sdk.lib.hashutil import B64MD5, b64_to_hex_id, hex_to_b64_id
36
+ from wandb.sdk.lib.hashutil import b64_to_hex_id, hex_to_b64_id
47
37
  from wandb.sdk.lib.paths import FilePathStr, URIStr
48
38
 
39
+ from ._factories import make_http_session, make_storage_handlers
40
+
49
41
  if TYPE_CHECKING:
50
42
  from wandb.filesync.step_prepare import StepPrepare
51
43
  from wandb.sdk.artifacts.artifact import Artifact
52
44
  from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
53
45
  from wandb.sdk.internal import progress
54
46
 
55
- # Sleep length: 0, 2, 4, 8, 16, 32, 64, 120, 120, 120, 120, 120, 120, 120, 120, 120
56
- # seconds, i.e. a total of 20min 6s.
57
- _REQUEST_RETRY_STRATEGY = urllib3.util.retry.Retry(
58
- backoff_factor=1,
59
- total=16,
60
- status_forcelist=(308, 408, 409, 429, 500, 502, 503, 504),
61
- )
62
- _REQUEST_POOL_CONNECTIONS = 64
63
- _REQUEST_POOL_MAXSIZE = 64
64
47
 
65
48
  # AWS S3 max upload parts without having to make additional requests for extra parts
66
49
  S3_MAX_PART_NUMBERS = 1000
@@ -96,48 +79,23 @@ class WandbStoragePolicy(StoragePolicy):
96
79
 
97
80
  @classmethod
98
81
  def from_config(
99
- cls, config: dict, api: InternalApi | None = None
82
+ cls, config: dict[str, Any], api: InternalApi | None = None
100
83
  ) -> WandbStoragePolicy:
101
84
  return cls(config=config, api=api)
102
85
 
103
86
  def __init__(
104
87
  self,
105
- config: dict | None = None,
88
+ config: dict[str, Any] | None = None,
106
89
  cache: ArtifactFileCache | None = None,
107
90
  api: InternalApi | None = None,
91
+ session: requests.Session | None = None,
108
92
  ) -> None:
109
- self._cache = cache or get_artifact_file_cache()
110
93
  self._config = config or {}
111
- self._session = requests.Session()
112
- adapter = requests.adapters.HTTPAdapter(
113
- max_retries=_REQUEST_RETRY_STRATEGY,
114
- pool_connections=_REQUEST_POOL_CONNECTIONS,
115
- pool_maxsize=_REQUEST_POOL_MAXSIZE,
116
- )
117
- self._session.mount("http://", adapter)
118
- self._session.mount("https://", adapter)
119
-
120
- s3 = S3Handler()
121
- gcs = GCSHandler()
122
- azure = AzureHandler()
123
- http = HTTPHandler(self._session)
124
- https = HTTPHandler(self._session, scheme="https")
125
- artifact = WBArtifactHandler()
126
- local_artifact = WBLocalArtifactHandler()
127
- file_handler = LocalFileHandler()
128
-
94
+ self._cache = cache or get_artifact_file_cache()
95
+ self._session = session or make_http_session()
129
96
  self._api = api or InternalApi()
130
97
  self._handler = MultiHandler(
131
- handlers=[
132
- s3,
133
- gcs,
134
- azure,
135
- http,
136
- https,
137
- artifact,
138
- local_artifact,
139
- file_handler,
140
- ],
98
+ handlers=make_storage_handlers(self._session),
141
99
  default_handler=TrackingHandler(),
142
100
  )
143
101
 
@@ -167,54 +125,52 @@ class WandbStoragePolicy(StoragePolicy):
167
125
  self._cache._override_cache_path = dest_path
168
126
 
169
127
  path, hit, cache_open = self._cache.check_md5_obj_path(
170
- B64MD5(manifest_entry.digest),
171
- manifest_entry.size if manifest_entry.size is not None else 0,
128
+ manifest_entry.digest,
129
+ size=manifest_entry.size or 0,
172
130
  )
173
131
  if hit:
174
132
  return path
175
133
 
176
- if manifest_entry._download_url is not None:
134
+ if (url := manifest_entry._download_url) is not None:
177
135
  # Use multipart parallel download for large file
178
136
  if (
179
- executor is not None
180
- and manifest_entry.size is not None
181
- and self._should_multipart_download(manifest_entry.size, multipart)
137
+ executor
138
+ and (size := manifest_entry.size)
139
+ and self._should_multipart_download(size, multipart)
182
140
  ):
183
- self._multipart_file_download(
184
- executor,
185
- manifest_entry._download_url,
186
- manifest_entry.size,
187
- cache_open,
188
- )
141
+ self._multipart_file_download(executor, url, size, cache_open)
189
142
  return path
143
+
190
144
  # Serial download
191
- response = self._session.get(manifest_entry._download_url, stream=True)
192
145
  try:
193
- response.raise_for_status()
194
- except Exception:
146
+ response = self._session.get(url, stream=True)
147
+ except requests.HTTPError:
195
148
  # Signed URL might have expired, fall back to fetching it one by one.
196
149
  manifest_entry._download_url = None
150
+
197
151
  if manifest_entry._download_url is None:
198
152
  auth = None
199
- http_headers = _thread_local_api_settings.headers or {}
200
- if self._api.access_token is not None:
201
- http_headers["Authorization"] = f"Bearer {self._api.access_token}"
202
- elif _thread_local_api_settings.cookies is None:
153
+ headers = _thread_local_api_settings.headers
154
+ cookies = _thread_local_api_settings.cookies
155
+
156
+ # For auth, prefer using (in order): auth header, cookies, HTTP Basic Auth
157
+ if token := self._api.access_token:
158
+ headers = {**(headers or {}), "Authorization": f"Bearer {token}"}
159
+ elif cookies is not None:
160
+ pass
161
+ else:
203
162
  auth = ("api", self._api.api_key or "")
163
+
164
+ file_url = self._file_url(
165
+ self._api,
166
+ artifact.entity,
167
+ artifact.project,
168
+ artifact.name.split(":")[0],
169
+ manifest_entry,
170
+ )
204
171
  response = self._session.get(
205
- self._file_url(
206
- self._api,
207
- artifact.entity,
208
- artifact.project,
209
- artifact.name.split(":")[0],
210
- manifest_entry,
211
- ),
212
- auth=auth,
213
- cookies=_thread_local_api_settings.cookies,
214
- headers=http_headers,
215
- stream=True,
172
+ file_url, auth=auth, cookies=cookies, headers=headers, stream=True
216
173
  )
217
- response.raise_for_status()
218
174
 
219
175
  with cache_open(mode="wb") as file:
220
176
  for data in response.iter_content(chunk_size=16 * 1024):
@@ -269,12 +225,7 @@ class WandbStoragePolicy(StoragePolicy):
269
225
  # Other threads has error, no need to start
270
226
  if download_has_error.is_set():
271
227
  return
272
- response = self._session.get(
273
- url=download_url,
274
- headers=headers,
275
- stream=True,
276
- )
277
- response.raise_for_status()
228
+ response = self._session.get(url=download_url, headers=headers, stream=True)
278
229
 
279
230
  file_offset = start
280
231
  for content in response.iter_content(chunk_size=_HTTP_RES_CHUNK_SIZE_BYTES):
@@ -376,43 +327,27 @@ class WandbStoragePolicy(StoragePolicy):
376
327
  entity_name: str,
377
328
  project_name: str,
378
329
  artifact_name: str,
379
- manifest_entry: ArtifactManifestEntry,
330
+ entry: ArtifactManifestEntry,
380
331
  ) -> str:
381
- storage_layout = self._config.get("storageLayout", StorageLayout.V1)
382
- storage_region = self._config.get("storageRegion", "default")
383
- md5_hex = b64_to_hex_id(B64MD5(manifest_entry.digest))
332
+ layout = self._config.get("storageLayout", StorageLayout.V1)
333
+ region = self._config.get("storageRegion", "default")
334
+ md5_hex = b64_to_hex_id(entry.digest)
384
335
 
385
- if storage_layout == StorageLayout.V1:
386
- return "{}/artifacts/{}/{}".format(
387
- api.settings("base_url"), entity_name, md5_hex
388
- )
389
- elif storage_layout == StorageLayout.V2:
336
+ base_url: str = api.settings("base_url")
337
+
338
+ if layout == StorageLayout.V1:
339
+ return f"{base_url}/artifacts/{entity_name}/{md5_hex}"
340
+
341
+ if layout == StorageLayout.V2:
342
+ birth_artifact_id = entry.birth_artifact_id or ""
390
343
  if api._server_supports(
391
- ServerFeature.ARTIFACT_COLLECTION_MEMBERSHIP_FILE_DOWNLOAD_HANDLER # type: ignore
344
+ ServerFeature.ARTIFACT_COLLECTION_MEMBERSHIP_FILE_DOWNLOAD_HANDLER
392
345
  ):
393
- return "{}/artifactsV2/{}/{}/{}/{}/{}/{}/{}".format(
394
- api.settings("base_url"),
395
- storage_region,
396
- quote(entity_name),
397
- quote(project_name),
398
- quote(artifact_name),
399
- quote(manifest_entry.birth_artifact_id or ""),
400
- md5_hex,
401
- manifest_entry.path.name,
402
- )
403
- return "{}/artifactsV2/{}/{}/{}/{}".format(
404
- api.settings("base_url"),
405
- storage_region,
406
- entity_name,
407
- quote(
408
- manifest_entry.birth_artifact_id
409
- if manifest_entry.birth_artifact_id is not None
410
- else ""
411
- ),
412
- md5_hex,
413
- )
414
- else:
415
- raise Exception(f"unrecognized storage layout: {storage_layout}")
346
+ return f"{base_url}/artifactsV2/{region}/{quote(entity_name)}/{quote(project_name)}/{quote(artifact_name)}/{quote(birth_artifact_id)}/{md5_hex}/{entry.path.name}"
347
+
348
+ return f"{base_url}/artifactsV2/{region}/{entity_name}/{quote(birth_artifact_id)}/{md5_hex}"
349
+
350
+ raise ValueError(f"unrecognized storage layout: {layout!r}")
416
351
 
417
352
  def s3_multipart_file_upload(
418
353
  self,
@@ -486,7 +421,7 @@ class WandbStoragePolicy(StoragePolicy):
486
421
  True if the file was a duplicate (did not need to be uploaded),
487
422
  False if it needed to be uploaded or was a reference (nothing to dedupe).
488
423
  """
489
- file_size = entry.size if entry.size is not None else 0
424
+ file_size = entry.size or 0
490
425
  chunk_size = self.calc_chunk_size(file_size)
491
426
  upload_parts = []
492
427
  hex_digests = {}
@@ -562,8 +497,8 @@ class WandbStoragePolicy(StoragePolicy):
562
497
 
563
498
  # Cache upon successful upload.
564
499
  _, hit, cache_open = self._cache.check_md5_obj_path(
565
- B64MD5(entry.digest),
566
- entry.size if entry.size is not None else 0,
500
+ entry.digest,
501
+ size=entry.size or 0,
567
502
  )
568
503
 
569
504
  staging_dir = get_staging_dir()
@@ -883,6 +883,16 @@ class InterfaceBase:
883
883
  ) -> MailboxHandle[pb.Result]:
884
884
  raise NotImplementedError
885
885
 
886
+ def publish_probe_system_info(self) -> None:
887
+ probe_system_info = pb.ProbeSystemInfoRequest()
888
+ return self._publish_probe_system_info(probe_system_info)
889
+
890
+ @abstractmethod
891
+ def _publish_probe_system_info(
892
+ self, probe_system_info: pb.ProbeSystemInfoRequest
893
+ ) -> None:
894
+ raise NotImplementedError
895
+
886
896
  def join(self) -> None:
887
897
  # Drop indicates that the internal process has already been shutdown
888
898
  if self._drop:
@@ -112,6 +112,7 @@ class InterfaceShared(InterfaceBase):
112
112
  python_packages: Optional[pb.PythonPackagesRequest] = None,
113
113
  job_input: Optional[pb.JobInputRequest] = None,
114
114
  run_finish_without_exit: Optional[pb.RunFinishWithoutExitRequest] = None,
115
+ probe_system_info: Optional[pb.ProbeSystemInfoRequest] = None,
115
116
  ) -> pb.Record:
116
117
  request = pb.Request()
117
118
  if get_summary:
@@ -178,6 +179,8 @@ class InterfaceShared(InterfaceBase):
178
179
  request.job_input.CopyFrom(job_input)
179
180
  elif run_finish_without_exit:
180
181
  request.run_finish_without_exit.CopyFrom(run_finish_without_exit)
182
+ elif probe_system_info:
183
+ request.probe_system_info.CopyFrom(probe_system_info)
181
184
  else:
182
185
  raise Exception("Invalid request")
183
186
  record = self._make_record(request=request)
@@ -330,6 +333,12 @@ class InterfaceShared(InterfaceBase):
330
333
  rec = self._make_record(use_artifact=use_artifact)
331
334
  self._publish(rec)
332
335
 
336
+ def _publish_probe_system_info(
337
+ self, probe_system_info: pb.ProbeSystemInfoRequest
338
+ ) -> None:
339
+ record = self._make_request(probe_system_info=probe_system_info)
340
+ self._publish(record)
341
+
333
342
  def _deliver_artifact(
334
343
  self,
335
344
  log_artifact: pb.LogArtifactRequest,
@@ -7,7 +7,7 @@ import concurrent
7
7
  import concurrent.futures
8
8
  import contextlib
9
9
  import threading
10
- from typing import Any, AsyncIterator, Callable, Coroutine, Iterator, TypeVar
10
+ from typing import Any, AsyncIterator, Callable, Coroutine, TypeVar
11
11
 
12
12
  _T = TypeVar("_T")
13
13
 
@@ -143,34 +143,71 @@ class TaskGroup:
143
143
  """
144
144
  self._tasks.append(asyncio.create_task(coro))
145
145
 
146
- async def _wait_all(self) -> None:
147
- """Block until all tasks complete.
146
+ async def _wait_all(self, *, race: bool, timeout: float | None) -> None:
147
+ """Block until tasks complete.
148
+
149
+ Args:
150
+ race: If true, blocks until the first task completes and then
151
+ cancels the rest. Otherwise, waits for all tasks or until
152
+ the first exception.
153
+ timeout: How long to wait.
148
154
 
149
155
  Raises:
156
+ TimeoutError: If the timeout expires.
150
157
  Exception: If one or more tasks raises an exception, one of these
151
158
  is raised arbitrarily.
152
159
  """
153
- done, _ = await asyncio.wait(
160
+ if not self._tasks:
161
+ return
162
+
163
+ if race:
164
+ return_when = asyncio.FIRST_COMPLETED
165
+ else:
166
+ return_when = asyncio.FIRST_EXCEPTION
167
+
168
+ done, pending = await asyncio.wait(
154
169
  self._tasks,
155
- # NOTE: Cancelling a task counts as a normal exit,
156
- # not an exception.
157
- return_when=concurrent.futures.FIRST_EXCEPTION,
170
+ timeout=timeout,
171
+ return_when=return_when,
158
172
  )
159
173
 
174
+ if not done:
175
+ raise TimeoutError(f"Timed out after {timeout} seconds.")
176
+
177
+ # If any of the finished tasks raised an exception, pick the first one.
160
178
  for task in done:
161
- with contextlib.suppress(asyncio.CancelledError):
162
- if exc := task.exception():
163
- raise exc
179
+ if exc := task.exception():
180
+ raise exc
164
181
 
165
- def _cancel_all(self) -> None:
166
- """Cancel all tasks."""
182
+ # Wait for remaining tasks to clean up, then re-raise any exceptions
183
+ # that arise. Note that pending is only non-empty when race=True.
184
+ for task in pending:
185
+ task.cancel()
186
+ await asyncio.gather(*pending, return_exceptions=True)
187
+ for task in pending:
188
+ if task.cancelled():
189
+ continue
190
+ if exc := task.exception():
191
+ raise exc
192
+
193
+ async def _cancel_all(self) -> None:
194
+ """Cancel all tasks.
195
+
196
+ Blocks until cancelled tasks complete to allow them to clean up.
197
+ Ignores exceptions.
198
+ """
167
199
  for task in self._tasks:
168
200
  # NOTE: It is safe to cancel tasks that have already completed.
169
201
  task.cancel()
202
+ await asyncio.gather(*self._tasks, return_exceptions=True)
170
203
 
171
204
 
172
205
  @contextlib.asynccontextmanager
173
- async def open_task_group() -> AsyncIterator[TaskGroup]:
206
+ async def open_task_group(
207
+ *,
208
+ exit_timeout: float | None = None,
209
+ race: bool = False,
210
+ ) -> AsyncIterator[TaskGroup]:
174
211
  """Create a task group.
175
212
 
176
213
  `asyncio` gained task groups in Python 3.11.
@@ -184,30 +221,58 @@ async def open_task_group() -> AsyncIterator[TaskGroup]:
184
221
  NOTE: Subtask exceptions do not propagate until the context manager exits.
185
222
  This means that the task group cannot cancel code running inside the
186
223
  `async with` block .
224
+
225
+ Args:
226
+ exit_timeout: An optional timeout in seconds. When exiting the
227
+ context manager, if tasks don't complete in this time,
228
+ they are cancelled and a TimeoutError is raised.
229
+ race: If true, all pending tasks are cancelled once any task
230
+ in the group completes. Prefer to use the race() function instead.
231
+
232
+ Raises:
233
+ TimeoutError: if exit_timeout is specified and tasks don't finish
234
+ in time.
187
235
  """
188
236
  task_group = TaskGroup()
189
237
 
190
238
  try:
191
239
  yield task_group
192
- await task_group._wait_all()
240
+ await task_group._wait_all(race=race, timeout=exit_timeout)
193
241
  finally:
194
- task_group._cancel_all()
242
+ await task_group._cancel_all()
195
243
 
196
244
 
197
- @contextlib.contextmanager
198
- def cancel_on_exit(coro: Coroutine[Any, Any, Any]) -> Iterator[None]:
245
+ @contextlib.asynccontextmanager
246
+ async def cancel_on_exit(coro: Coroutine[Any, Any, Any]) -> AsyncIterator[None]:
199
247
  """Schedule a task, cancelling it when exiting the context manager.
200
248
 
201
249
  If the context manager exits successfully but the given coroutine raises
202
250
  an exception, that exception is reraised. The exception is suppressed
203
251
  if the context manager raises an exception.
204
252
  """
205
- task = asyncio.create_task(coro)
206
253
 
207
- try:
254
+ async def stop_immediately():
255
+ pass
256
+
257
+ async with open_task_group(race=True) as group:
258
+ group.start_soon(stop_immediately())
259
+ group.start_soon(coro)
208
260
  yield
209
261
 
210
- if task.done() and (exception := task.exception()):
211
- raise exception
212
- finally:
213
- task.cancel()
262
+
263
+ async def race(*coros: Coroutine[Any, Any, Any]) -> None:
264
+ """Wait until the first completed task.
265
+
266
+ After any coroutine completes, all others are cancelled.
267
+ If the current task is cancelled, all coroutines are cancelled too.
268
+
269
+ If coroutines complete simultaneously and any one of them raises
270
+ an exception, an arbitrary one is propagated. Similarly, if any coroutines
271
+ raise exceptions during cancellation, one of them propagates.
272
+
273
+ Args:
274
+ coros: Coroutines to race.
275
+ """
276
+ async with open_task_group(race=True) as tg:
277
+ for coro in coros:
278
+ tg.start_soon(coro)
@@ -4,7 +4,9 @@ Note: This was originally wandb/vendor/gql-0.2.0/wandb_gql/transport/requests.py
4
4
  The only substantial change is to reuse a requests.Session object.
5
5
  """
6
6
 
7
- from typing import Any, Callable, Dict, Optional, Tuple, Union
7
+ from __future__ import annotations
8
+
9
+ from typing import Any, Callable
8
10
 
9
11
  import requests
10
12
  from wandb_gql.transport.http import HTTPTransport
@@ -12,15 +14,17 @@ from wandb_graphql.execution import ExecutionResult
12
14
  from wandb_graphql.language import ast
13
15
  from wandb_graphql.language.printer import print_ast
14
16
 
17
+ from wandb._analytics import tracked_func
18
+
15
19
 
16
20
  class GraphQLSession(HTTPTransport):
17
21
  def __init__(
18
22
  self,
19
23
  url: str,
20
- auth: Optional[Union[Tuple[str, str], Callable]] = None,
24
+ auth: tuple[str, str] | Callable | None = None,
21
25
  use_json: bool = False,
22
- timeout: Optional[Union[int, float]] = None,
23
- proxies: Optional[Dict[str, str]] = None,
26
+ timeout: int | float | None = None,
27
+ proxies: dict[str, str] | None = None,
24
28
  **kwargs: Any,
25
29
  ) -> None:
26
30
  """Setup a session for sending GraphQL queries and mutations.
@@ -42,15 +46,22 @@ class GraphQLSession(HTTPTransport):
42
46
  def execute(
43
47
  self,
44
48
  document: ast.Node,
45
- variable_values: Optional[Dict] = None,
46
- timeout: Optional[Union[int, float]] = None,
49
+ variable_values: dict[str, Any] | None = None,
50
+ timeout: int | float | None = None,
47
51
  ) -> ExecutionResult:
48
52
  query_str = print_ast(document)
49
53
  payload = {"query": query_str, "variables": variable_values or {}}
50
54
 
51
55
  data_key = "json" if self.use_json else "data"
56
+
57
+ headers = self.headers.copy() if self.headers else {}
58
+
59
+ # If we're tracking a calling python function, include it in the headers
60
+ if func_info := tracked_func():
61
+ headers.update(func_info.to_headers())
62
+
52
63
  post_args = {
53
- "headers": self.headers,
64
+ "headers": headers or None,
54
65
  "cookies": self.cookies,
55
66
  "timeout": timeout or self.default_timeout,
56
67
  data_key: payload,