wandb-0.21.3-py3-none-musllinux_1_2_aarch64.whl → wandb-0.21.4-py3-none-musllinux_1_2_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +1 -1
  3. wandb/_analytics.py +65 -0
  4. wandb/_iterutils.py +8 -0
  5. wandb/_pydantic/__init__.py +10 -11
  6. wandb/_pydantic/base.py +3 -53
  7. wandb/_pydantic/field_types.py +29 -0
  8. wandb/_pydantic/v1_compat.py +47 -30
  9. wandb/_strutils.py +40 -0
  10. wandb/apis/public/api.py +17 -4
  11. wandb/apis/public/artifacts.py +5 -4
  12. wandb/apis/public/automations.py +2 -1
  13. wandb/apis/public/registries/_freezable_list.py +6 -6
  14. wandb/apis/public/registries/_utils.py +2 -1
  15. wandb/apis/public/registries/registries_search.py +4 -0
  16. wandb/apis/public/registries/registry.py +7 -0
  17. wandb/automations/_filters/expressions.py +3 -2
  18. wandb/automations/_filters/operators.py +2 -1
  19. wandb/automations/_validators.py +20 -0
  20. wandb/automations/actions.py +4 -2
  21. wandb/automations/events.py +4 -5
  22. wandb/bin/gpu_stats +0 -0
  23. wandb/bin/wandb-core +0 -0
  24. wandb/cli/beta.py +48 -130
  25. wandb/cli/beta_sync.py +226 -0
  26. wandb/integration/dspy/__init__.py +5 -0
  27. wandb/integration/dspy/dspy.py +422 -0
  28. wandb/integration/weave/weave.py +55 -0
  29. wandb/proto/v3/wandb_server_pb2.py +38 -57
  30. wandb/proto/v3/wandb_sync_pb2.py +87 -0
  31. wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
  32. wandb/proto/v4/wandb_server_pb2.py +38 -41
  33. wandb/proto/v4/wandb_sync_pb2.py +38 -0
  34. wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
  35. wandb/proto/v5/wandb_server_pb2.py +38 -41
  36. wandb/proto/v5/wandb_sync_pb2.py +39 -0
  37. wandb/proto/v5/wandb_telemetry_pb2.py +12 -12
  38. wandb/proto/v6/wandb_server_pb2.py +38 -41
  39. wandb/proto/v6/wandb_sync_pb2.py +49 -0
  40. wandb/proto/v6/wandb_telemetry_pb2.py +12 -12
  41. wandb/proto/wandb_generate_proto.py +1 -0
  42. wandb/proto/wandb_sync_pb2.py +12 -0
  43. wandb/sdk/artifacts/_validators.py +50 -49
  44. wandb/sdk/artifacts/artifact.py +7 -7
  45. wandb/sdk/artifacts/exceptions.py +2 -1
  46. wandb/sdk/artifacts/storage_handlers/s3_handler.py +2 -1
  47. wandb/sdk/lib/asyncio_compat.py +88 -23
  48. wandb/sdk/lib/gql_request.py +18 -7
  49. wandb/sdk/lib/printer.py +9 -13
  50. wandb/sdk/lib/progress.py +8 -6
  51. wandb/sdk/lib/service/service_connection.py +42 -12
  52. wandb/sdk/mailbox/wait_with_progress.py +1 -1
  53. wandb/sdk/wandb_init.py +0 -8
  54. wandb/sdk/wandb_run.py +13 -1
  55. wandb/sdk/wandb_settings.py +55 -0
  56. {wandb-0.21.3.dist-info → wandb-0.21.4.dist-info}/METADATA +1 -1
  57. {wandb-0.21.3.dist-info → wandb-0.21.4.dist-info}/RECORD +60 -49
  58. {wandb-0.21.3.dist-info → wandb-0.21.4.dist-info}/WHEEL +0 -0
  59. {wandb-0.21.3.dist-info → wandb-0.21.4.dist-info}/entry_points.txt +0 -0
  60. {wandb-0.21.3.dist-info → wandb-0.21.4.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/artifacts/storage_handlers/s3_handler.py CHANGED
@@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Sequence
10
10
  from urllib.parse import parse_qsl, urlparse
11
11
 
12
12
  from wandb import util
13
+ from wandb._strutils import ensureprefix
13
14
  from wandb.errors import CommError
14
15
  from wandb.errors.term import termlog
15
16
  from wandb.sdk.artifacts.artifact_file_cache import get_artifact_file_cache
@@ -328,7 +329,7 @@ class S3Handler(StorageHandler):
328
329
  return True
329
330
 
330
331
  # Enforce HTTPS otherwise
331
- https_url = url if url.startswith("https://") else f"https://{url}"
332
+ https_url = ensureprefix(url, "https://")
332
333
  netloc = urlparse(https_url).netloc
333
334
  return bool(
334
335
  # Match for https://cwobject.com
wandb/sdk/lib/asyncio_compat.py CHANGED
@@ -7,7 +7,7 @@ import concurrent
7
7
  import concurrent.futures
8
8
  import contextlib
9
9
  import threading
10
- from typing import Any, AsyncIterator, Callable, Coroutine, Iterator, TypeVar
10
+ from typing import Any, AsyncIterator, Callable, Coroutine, TypeVar
11
11
 
12
12
  _T = TypeVar("_T")
13
13
 
@@ -143,34 +143,71 @@ class TaskGroup:
143
143
  """
144
144
  self._tasks.append(asyncio.create_task(coro))
145
145
 
146
- async def _wait_all(self) -> None:
147
- """Block until all tasks complete.
146
+ async def _wait_all(self, *, race: bool, timeout: float | None) -> None:
147
+ """Block until tasks complete.
148
+
149
+ Args:
150
+ race: If true, blocks until the first task completes and then
151
+ cancels the rest. Otherwise, waits for all tasks or until
152
+ the first exception.
153
+ timeout: How long to wait.
148
154
 
149
155
  Raises:
156
+ TimeoutError: If the timeout expires.
150
157
  Exception: If one or more tasks raises an exception, one of these
151
158
  is raised arbitrarily.
152
159
  """
153
- done, _ = await asyncio.wait(
160
+ if not self._tasks:
161
+ return
162
+
163
+ if race:
164
+ return_when = asyncio.FIRST_COMPLETED
165
+ else:
166
+ return_when = asyncio.FIRST_EXCEPTION
167
+
168
+ done, pending = await asyncio.wait(
154
169
  self._tasks,
155
- # NOTE: Cancelling a task counts as a normal exit,
156
- # not an exception.
157
- return_when=concurrent.futures.FIRST_EXCEPTION,
170
+ timeout=timeout,
171
+ return_when=return_when,
158
172
  )
159
173
 
174
+ if not done:
175
+ raise TimeoutError(f"Timed out after {timeout} seconds.")
176
+
177
+ # If any of the finished tasks raised an exception, pick the first one.
160
178
  for task in done:
161
- with contextlib.suppress(asyncio.CancelledError):
162
- if exc := task.exception():
163
- raise exc
179
+ if exc := task.exception():
180
+ raise exc
164
181
 
165
- def _cancel_all(self) -> None:
166
- """Cancel all tasks."""
182
+ # Wait for remaining tasks to clean up, then re-raise any exceptions
183
+ # that arise. Note that pending is only non-empty when race=True.
184
+ for task in pending:
185
+ task.cancel()
186
+ await asyncio.gather(*pending, return_exceptions=True)
187
+ for task in pending:
188
+ if task.cancelled():
189
+ continue
190
+ if exc := task.exception():
191
+ raise exc
192
+
193
+ async def _cancel_all(self) -> None:
194
+ """Cancel all tasks.
195
+
196
+ Blocks until cancelled tasks complete to allow them to clean up.
197
+ Ignores exceptions.
198
+ """
167
199
  for task in self._tasks:
168
200
  # NOTE: It is safe to cancel tasks that have already completed.
169
201
  task.cancel()
202
+ await asyncio.gather(*self._tasks, return_exceptions=True)
170
203
 
171
204
 
172
205
  @contextlib.asynccontextmanager
173
- async def open_task_group() -> AsyncIterator[TaskGroup]:
206
+ async def open_task_group(
207
+ *,
208
+ exit_timeout: float | None = None,
209
+ race: bool = False,
210
+ ) -> AsyncIterator[TaskGroup]:
174
211
  """Create a task group.
175
212
 
176
213
  `asyncio` gained task groups in Python 3.11.
@@ -184,30 +221,58 @@ async def open_task_group() -> AsyncIterator[TaskGroup]:
184
221
  NOTE: Subtask exceptions do not propagate until the context manager exits.
185
222
  This means that the task group cannot cancel code running inside the
186
223
  `async with` block .
224
+
225
+ Args:
226
+ exit_timeout: An optional timeout in seconds. When exiting the
227
+ context manager, if tasks don't complete in this time,
228
+ they are cancelled and a TimeoutError is raised.
229
+ race: If true, all pending tasks are cancelled once any task
230
+ in the group completes. Prefer to use the race() function instead.
231
+
232
+ Raises:
233
+ TimeoutError: if exit_timeout is specified and tasks don't finish
234
+ in time.
187
235
  """
188
236
  task_group = TaskGroup()
189
237
 
190
238
  try:
191
239
  yield task_group
192
- await task_group._wait_all()
240
+ await task_group._wait_all(race=race, timeout=exit_timeout)
193
241
  finally:
194
- task_group._cancel_all()
242
+ await task_group._cancel_all()
195
243
 
196
244
 
197
- @contextlib.contextmanager
198
- def cancel_on_exit(coro: Coroutine[Any, Any, Any]) -> Iterator[None]:
245
+ @contextlib.asynccontextmanager
246
+ async def cancel_on_exit(coro: Coroutine[Any, Any, Any]) -> AsyncIterator[None]:
199
247
  """Schedule a task, cancelling it when exiting the context manager.
200
248
 
201
249
  If the context manager exits successfully but the given coroutine raises
202
250
  an exception, that exception is reraised. The exception is suppressed
203
251
  if the context manager raises an exception.
204
252
  """
205
- task = asyncio.create_task(coro)
206
253
 
207
- try:
254
+ async def stop_immediately():
255
+ pass
256
+
257
+ async with open_task_group(race=True) as group:
258
+ group.start_soon(stop_immediately())
259
+ group.start_soon(coro)
208
260
  yield
209
261
 
210
- if task.done() and (exception := task.exception()):
211
- raise exception
212
- finally:
213
- task.cancel()
262
+
263
+ async def race(*coros: Coroutine[Any, Any, Any]) -> None:
264
+ """Wait until the first completed task.
265
+
266
+ After any coroutine completes, all others are cancelled.
267
+ If the current task is cancelled, all coroutines are cancelled too.
268
+
269
+ If coroutines complete simultaneously and any one of them raises
270
+ an exception, an arbitrary one is propagated. Similarly, if any coroutines
271
+ raise exceptions during cancellation, one of them propagates.
272
+
273
+ Args:
274
+ coros: Coroutines to race.
275
+ """
276
+ async with open_task_group(race=True) as tg:
277
+ for coro in coros:
278
+ tg.start_soon(coro)
wandb/sdk/lib/gql_request.py CHANGED
@@ -4,7 +4,9 @@ Note: This was originally wandb/vendor/gql-0.2.0/wandb_gql/transport/requests.py
4
4
  The only substantial change is to reuse a requests.Session object.
5
5
  """
6
6
 
7
- from typing import Any, Callable, Dict, Optional, Tuple, Union
7
+ from __future__ import annotations
8
+
9
+ from typing import Any, Callable
8
10
 
9
11
  import requests
10
12
  from wandb_gql.transport.http import HTTPTransport
@@ -12,15 +14,17 @@ from wandb_graphql.execution import ExecutionResult
12
14
  from wandb_graphql.language import ast
13
15
  from wandb_graphql.language.printer import print_ast
14
16
 
17
+ from wandb._analytics import tracked_func
18
+
15
19
 
16
20
  class GraphQLSession(HTTPTransport):
17
21
  def __init__(
18
22
  self,
19
23
  url: str,
20
- auth: Optional[Union[Tuple[str, str], Callable]] = None,
24
+ auth: tuple[str, str] | Callable | None = None,
21
25
  use_json: bool = False,
22
- timeout: Optional[Union[int, float]] = None,
23
- proxies: Optional[Dict[str, str]] = None,
26
+ timeout: int | float | None = None,
27
+ proxies: dict[str, str] | None = None,
24
28
  **kwargs: Any,
25
29
  ) -> None:
26
30
  """Setup a session for sending GraphQL queries and mutations.
@@ -42,15 +46,22 @@ class GraphQLSession(HTTPTransport):
42
46
  def execute(
43
47
  self,
44
48
  document: ast.Node,
45
- variable_values: Optional[Dict] = None,
46
- timeout: Optional[Union[int, float]] = None,
49
+ variable_values: dict[str, Any] | None = None,
50
+ timeout: int | float | None = None,
47
51
  ) -> ExecutionResult:
48
52
  query_str = print_ast(document)
49
53
  payload = {"query": query_str, "variables": variable_values or {}}
50
54
 
51
55
  data_key = "json" if self.use_json else "data"
56
+
57
+ headers = self.headers.copy() if self.headers else {}
58
+
59
+ # If we're tracking a calling python function, include it in the headers
60
+ if func_info := tracked_func():
61
+ headers.update(func_info.to_headers())
62
+
52
63
  post_args = {
53
- "headers": self.headers,
64
+ "headers": headers or None,
54
65
  "cookies": self.cookies,
55
66
  "timeout": timeout or self.default_timeout,
56
67
  data_key: payload,
wandb/sdk/lib/printer.py CHANGED
@@ -156,14 +156,10 @@ class Printer(abc.ABC):
156
156
  """
157
157
 
158
158
  @abc.abstractmethod
159
- def progress_close(self, text: str | None = None) -> None:
159
+ def progress_close(self) -> None:
160
160
  """Close the progress indicator.
161
161
 
162
162
  After this, `progress_update` should not be used.
163
-
164
- Args:
165
- text: The final text to set on the progress indicator.
166
- Ignored in Jupyter notebooks.
167
163
  """
168
164
 
169
165
  @staticmethod
@@ -342,13 +338,10 @@ class _PrinterTerm(Printer):
342
338
  wandb.termlog(f"{next(self._progress)} {text}", newline=False)
343
339
 
344
340
  @override
345
- def progress_close(self, text: str | None = None) -> None:
341
+ def progress_close(self) -> None:
346
342
  if self._settings and self._settings.silent:
347
343
  return
348
344
 
349
- text = text or " " * 79
350
- wandb.termlog(text)
351
-
352
345
  @property
353
346
  @override
354
347
  def supports_html(self) -> bool:
@@ -462,11 +455,14 @@ class _PrinterJupyter(Printer):
462
455
  display_id=True,
463
456
  )
464
457
 
465
- if handle:
458
+ if not handle:
459
+ yield None
460
+ return
461
+
462
+ try:
466
463
  yield _DynamicJupyterText(handle)
464
+ finally:
467
465
  handle.update(self._ipython_display.HTML(""))
468
- else:
469
- yield None
470
466
 
471
467
  @override
472
468
  def display(
@@ -539,7 +535,7 @@ class _PrinterJupyter(Printer):
539
535
  self._progress.update(percent_done, text)
540
536
 
541
537
  @override
542
- def progress_close(self, text: str | None = None) -> None:
538
+ def progress_close(self) -> None:
543
539
  if self._progress:
544
540
  self._progress.close()
545
541
 
wandb/sdk/lib/progress.py CHANGED
@@ -70,12 +70,14 @@ def progress_printer(
70
70
  default_text: The text to show if no information is available.
71
71
  """
72
72
  with printer.dynamic_text() as text_area:
73
- yield ProgressPrinter(
74
- printer,
75
- text_area,
76
- default_text=default_text,
77
- )
78
- printer.progress_close()
73
+ try:
74
+ yield ProgressPrinter(
75
+ printer,
76
+ text_area,
77
+ default_text=default_text,
78
+ )
79
+ finally:
80
+ printer.progress_close()
79
81
 
80
82
 
81
83
  class ProgressPrinter:
wandb/sdk/lib/service/service_connection.py CHANGED
@@ -1,10 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import atexit
4
+ import pathlib
4
5
  from typing import Callable
5
6
 
6
7
  from wandb.proto import wandb_server_pb2 as spb
7
- from wandb.proto import wandb_settings_pb2
8
+ from wandb.proto import wandb_settings_pb2, wandb_sync_pb2
8
9
  from wandb.sdk import wandb_settings
9
10
  from wandb.sdk.interface.interface import InterfaceBase
10
11
  from wandb.sdk.interface.interface_sock import InterfaceSock
@@ -12,6 +13,7 @@ from wandb.sdk.lib import asyncio_manager
12
13
  from wandb.sdk.lib.exit_hooks import ExitHooks
13
14
  from wandb.sdk.lib.service.service_client import ServiceClient
14
15
  from wandb.sdk.mailbox import HandleAbandonedError, MailboxClosedError
16
+ from wandb.sdk.mailbox.mailbox_handle import MailboxHandle
15
17
 
16
18
  from . import service_process, service_token
17
19
 
@@ -96,6 +98,45 @@ class ServiceConnection:
96
98
  """Returns an interface for communicating with the service."""
97
99
  return InterfaceSock(self._client, stream_id=stream_id)
98
100
 
101
+ def init_sync(
102
+ self,
103
+ paths: set[pathlib.Path],
104
+ settings: wandb_settings.Settings,
105
+ ) -> MailboxHandle[wandb_sync_pb2.ServerInitSyncResponse]:
106
+ """Send a ServerInitSyncRequest."""
107
+ init_sync = wandb_sync_pb2.ServerInitSyncRequest(
108
+ path=(str(path) for path in paths),
109
+ settings=settings.to_proto(),
110
+ )
111
+ request = spb.ServerRequest(init_sync=init_sync)
112
+
113
+ handle = self._client.deliver(request)
114
+ return handle.map(lambda r: r.init_sync_response)
115
+
116
+ def sync(
117
+ self,
118
+ id: str,
119
+ *,
120
+ parallelism: int,
121
+ ) -> MailboxHandle[wandb_sync_pb2.ServerSyncResponse]:
122
+ """Send a ServerSyncRequest."""
123
+ sync = wandb_sync_pb2.ServerSyncRequest(id=id, parallelism=parallelism)
124
+ request = spb.ServerRequest(sync=sync)
125
+
126
+ handle = self._client.deliver(request)
127
+ return handle.map(lambda r: r.sync_response)
128
+
129
+ def sync_status(
130
+ self,
131
+ id: str,
132
+ ) -> MailboxHandle[wandb_sync_pb2.ServerSyncStatusResponse]:
133
+ """Send a ServerSyncStatusRequest."""
134
+ sync_status = wandb_sync_pb2.ServerSyncStatusRequest(id=id)
135
+ request = spb.ServerRequest(sync_status=sync_status)
136
+
137
+ handle = self._client.deliver(request)
138
+ return handle.map(lambda r: r.sync_status_response)
139
+
99
140
  def inform_init(
100
141
  self,
101
142
  settings: wandb_settings_pb2.Settings,
@@ -143,17 +184,6 @@ class ServiceConnection:
143
184
  else:
144
185
  return response.inform_attach_response.settings
145
186
 
146
- def inform_start(
147
- self,
148
- settings: wandb_settings_pb2.Settings,
149
- run_id: str,
150
- ) -> None:
151
- """Send a start request to the service."""
152
- request = spb.ServerInformStartRequest()
153
- request.settings.CopyFrom(settings)
154
- request._info.stream_id = run_id
155
- self._client.publish(spb.ServerRequest(inform_start=request))
156
-
157
187
  def teardown(self, exit_code: int) -> int | None:
158
188
  """Close the connection.
159
189
 
wandb/sdk/mailbox/wait_with_progress.py CHANGED
@@ -62,7 +62,7 @@ def wait_all_with_progress(
62
62
  start_time = time.monotonic()
63
63
 
64
64
  async def progress_loop_with_timeout() -> list[_T]:
65
- with asyncio_compat.cancel_on_exit(display_progress()):
65
+ async with asyncio_compat.cancel_on_exit(display_progress()):
66
66
  if timeout is not None:
67
67
  elapsed_time = time.monotonic() - start_time
68
68
  remaining_timeout = timeout - elapsed_time
wandb/sdk/wandb_init.py CHANGED
@@ -1009,14 +1009,6 @@ class _WandbInit:
1009
1009
  run._set_run_obj(result.run_result.run)
1010
1010
 
1011
1011
  self._logger.info("starting run threads in backend")
1012
- # initiate run (stats and metadata probing)
1013
-
1014
- if service:
1015
- assert settings.run_id
1016
- service.inform_start(
1017
- settings=settings.to_proto(),
1018
- run_id=settings.run_id,
1019
- )
1020
1012
 
1021
1013
  assert backend.interface
1022
1014
 
wandb/sdk/wandb_run.py CHANGED
@@ -892,7 +892,19 @@ class Run:
892
892
  def tags(self, tags: Sequence) -> None:
893
893
  with telemetry.context(run=self) as tel:
894
894
  tel.feature.set_run_tags = True
895
- self._settings.run_tags = tuple(tags)
895
+
896
+ try:
897
+ self._settings.run_tags = tuple(tags)
898
+ except ValueError as e:
899
+ # For runtime tag setting, warn instead of crash
900
+ # Extract the core error message without the pydantic wrapper
901
+ error_msg = str(e)
902
+ if "Value error," in error_msg:
903
+ # Extract the actual error message after "Value error, "
904
+ error_msg = error_msg.split("Value error, ")[1].split(" [type=")[0]
905
+ wandb.termwarn(f"Invalid tag detected: {error_msg} Tags not updated.")
906
+ return
907
+
896
908
  if self._backend and self._backend.interface:
897
909
  self._backend.interface.publish_run(self)
898
910
 
wandb/sdk/wandb_settings.py CHANGED
@@ -1435,6 +1435,61 @@ class Settings(BaseModel, validate_assignment=True):
1435
1435
  raise UsageError("Sweep ID cannot contain only whitespace")
1436
1436
  return value
1437
1437
 
1438
+ @field_validator("run_tags", mode="before")
1439
+ @classmethod
1440
+ def validate_run_tags(cls, value):
1441
+ """Validate run tags.
1442
+
1443
+ Validates that each tag:
1444
+ - Is between 1 and 64 characters in length (inclusive)
1445
+ - Converts single string values to tuple format
1446
+ - Preserves None values
1447
+
1448
+ Args:
1449
+ value: A string, list, tuple, or None representing tags
1450
+
1451
+ Returns:
1452
+ tuple: A tuple of validated tags, or None
1453
+
1454
+ Raises:
1455
+ ValueError: If any tag is empty or exceeds 64 characters
1456
+
1457
+ <!-- lazydoc-ignore-classmethod: internal -->
1458
+ """
1459
+ if value is None:
1460
+ return None
1461
+
1462
+ # Convert to tuple if needed
1463
+ if isinstance(value, str):
1464
+ tags = (value,)
1465
+ else:
1466
+ tags = tuple(value)
1467
+
1468
+ # Validate each tag and accumulate errors
1469
+ errors = []
1470
+ for i, tag in enumerate(tags):
1471
+ tag_str = str(tag)
1472
+ if len(tag_str) == 0:
1473
+ errors.append(
1474
+ f"Tag at index {i} is empty. Tags must be between 1 and 64 characters"
1475
+ )
1476
+ elif len(tag_str) > 64:
1477
+ # Truncate long tags for display
1478
+ display_tag = (
1479
+ f"{tag_str[:20]}...{tag_str[-20:]}"
1480
+ if len(tag_str) > 43
1481
+ else tag_str
1482
+ )
1483
+ errors.append(
1484
+ f"Tag '{display_tag}' is {len(tag_str)} characters. Tags must be between 1 and 64 characters"
1485
+ )
1486
+
1487
+ # Raise combined error if any validation issues were found
1488
+ if errors:
1489
+ raise ValueError("; ".join(errors))
1490
+
1491
+ return tags
1492
+
1438
1493
  @field_validator("sweep_param_path", mode="before")
1439
1494
  @classmethod
1440
1495
  def validate_sweep_param_path(cls, value):
{wandb-0.21.3.dist-info → wandb-0.21.4.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: wandb
3
- Version: 0.21.3
3
+ Version: 0.21.4
4
4
  Summary: A CLI and library for interacting with the Weights & Biases API.
5
5
  Project-URL: Source, https://github.com/wandb/wandb
6
6
  Project-URL: Bug Reports, https://github.com/wandb/wandb/issues