wandb 0.22.2__py3-none-macosx_12_0_arm64.whl → 0.22.3__py3-none-macosx_12_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (147)
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +2 -2
  3. wandb/_pydantic/__init__.py +8 -1
  4. wandb/_pydantic/base.py +54 -18
  5. wandb/_pydantic/field_types.py +8 -3
  6. wandb/_pydantic/pagination.py +46 -0
  7. wandb/_pydantic/utils.py +2 -2
  8. wandb/apis/public/api.py +24 -19
  9. wandb/apis/public/artifacts.py +259 -270
  10. wandb/apis/public/registries/_utils.py +40 -54
  11. wandb/apis/public/registries/registries_search.py +70 -85
  12. wandb/apis/public/registries/registry.py +173 -156
  13. wandb/apis/public/runs.py +27 -6
  14. wandb/apis/public/utils.py +43 -20
  15. wandb/automations/_generated/create_automation.py +2 -2
  16. wandb/automations/_generated/create_generic_webhook_integration.py +4 -4
  17. wandb/automations/_generated/delete_automation.py +2 -2
  18. wandb/automations/_generated/fragments.py +31 -52
  19. wandb/automations/_generated/generic_webhook_integrations_by_entity.py +3 -3
  20. wandb/automations/_generated/get_automations.py +3 -3
  21. wandb/automations/_generated/get_automations_by_entity.py +3 -3
  22. wandb/automations/_generated/input_types.py +9 -9
  23. wandb/automations/_generated/integrations_by_entity.py +3 -3
  24. wandb/automations/_generated/operations.py +6 -6
  25. wandb/automations/_generated/slack_integrations_by_entity.py +3 -3
  26. wandb/automations/_generated/update_automation.py +2 -2
  27. wandb/automations/_utils.py +3 -3
  28. wandb/automations/actions.py +3 -3
  29. wandb/automations/automations.py +6 -5
  30. wandb/bin/gpu_stats +0 -0
  31. wandb/bin/wandb-core +0 -0
  32. wandb/cli/beta.py +8 -2
  33. wandb/cli/beta_leet.py +2 -1
  34. wandb/cli/beta_sync.py +1 -1
  35. wandb/errors/term.py +8 -8
  36. wandb/jupyter.py +0 -51
  37. wandb/old/settings.py +6 -6
  38. wandb/proto/v3/wandb_internal_pb2.py +351 -352
  39. wandb/proto/v3/wandb_server_pb2.py +38 -37
  40. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  41. wandb/proto/v3/wandb_sync_pb2.py +19 -6
  42. wandb/proto/v4/wandb_internal_pb2.py +351 -352
  43. wandb/proto/v4/wandb_server_pb2.py +38 -37
  44. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  45. wandb/proto/v4/wandb_sync_pb2.py +10 -6
  46. wandb/proto/v5/wandb_internal_pb2.py +351 -352
  47. wandb/proto/v5/wandb_server_pb2.py +38 -37
  48. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  49. wandb/proto/v5/wandb_sync_pb2.py +10 -6
  50. wandb/proto/v6/wandb_internal_pb2.py +351 -352
  51. wandb/proto/v6/wandb_server_pb2.py +38 -37
  52. wandb/proto/v6/wandb_settings_pb2.py +2 -2
  53. wandb/proto/v6/wandb_sync_pb2.py +10 -6
  54. wandb/sdk/artifacts/_generated/__init__.py +96 -40
  55. wandb/sdk/artifacts/_generated/add_aliases.py +3 -3
  56. wandb/sdk/artifacts/_generated/add_artifact_collection_tags.py +26 -0
  57. wandb/sdk/artifacts/_generated/artifact_by_id.py +2 -2
  58. wandb/sdk/artifacts/_generated/artifact_by_name.py +3 -3
  59. wandb/sdk/artifacts/_generated/artifact_collection_membership_file_urls.py +27 -8
  60. wandb/sdk/artifacts/_generated/artifact_collection_membership_files.py +27 -8
  61. wandb/sdk/artifacts/_generated/artifact_created_by.py +7 -20
  62. wandb/sdk/artifacts/_generated/artifact_file_urls.py +19 -6
  63. wandb/sdk/artifacts/_generated/artifact_membership_by_name.py +26 -0
  64. wandb/sdk/artifacts/_generated/artifact_type.py +5 -5
  65. wandb/sdk/artifacts/_generated/artifact_used_by.py +8 -17
  66. wandb/sdk/artifacts/_generated/artifact_version_files.py +19 -8
  67. wandb/sdk/artifacts/_generated/delete_aliases.py +3 -3
  68. wandb/sdk/artifacts/_generated/delete_artifact.py +4 -4
  69. wandb/sdk/artifacts/_generated/delete_artifact_collection_tags.py +23 -0
  70. wandb/sdk/artifacts/_generated/delete_artifact_portfolio.py +4 -4
  71. wandb/sdk/artifacts/_generated/delete_artifact_sequence.py +4 -4
  72. wandb/sdk/artifacts/_generated/delete_registry.py +21 -0
  73. wandb/sdk/artifacts/_generated/fetch_artifact_manifest.py +8 -20
  74. wandb/sdk/artifacts/_generated/fetch_linked_artifacts.py +13 -35
  75. wandb/sdk/artifacts/_generated/fetch_org_info_from_entity.py +28 -0
  76. wandb/sdk/artifacts/_generated/fetch_registries.py +18 -8
  77. wandb/sdk/{projects → artifacts}/_generated/fetch_registry.py +4 -4
  78. wandb/sdk/artifacts/_generated/fragments.py +183 -333
  79. wandb/sdk/artifacts/_generated/input_types.py +133 -7
  80. wandb/sdk/artifacts/_generated/link_artifact.py +5 -5
  81. wandb/sdk/artifacts/_generated/operations.py +1053 -548
  82. wandb/sdk/artifacts/_generated/project_artifact_collection.py +9 -77
  83. wandb/sdk/artifacts/_generated/project_artifact_collections.py +21 -9
  84. wandb/sdk/artifacts/_generated/project_artifact_type.py +3 -3
  85. wandb/sdk/artifacts/_generated/project_artifact_types.py +19 -6
  86. wandb/sdk/artifacts/_generated/project_artifacts.py +7 -8
  87. wandb/sdk/artifacts/_generated/registry_collections.py +21 -9
  88. wandb/sdk/artifacts/_generated/registry_versions.py +20 -9
  89. wandb/sdk/artifacts/_generated/rename_registry.py +25 -0
  90. wandb/sdk/artifacts/_generated/run_input_artifacts.py +5 -9
  91. wandb/sdk/artifacts/_generated/run_output_artifacts.py +5 -9
  92. wandb/sdk/artifacts/_generated/type_info.py +2 -2
  93. wandb/sdk/artifacts/_generated/unlink_artifact.py +3 -5
  94. wandb/sdk/artifacts/_generated/update_artifact.py +3 -3
  95. wandb/sdk/artifacts/_generated/update_artifact_collection_type.py +28 -0
  96. wandb/sdk/artifacts/_generated/update_artifact_portfolio.py +7 -16
  97. wandb/sdk/artifacts/_generated/update_artifact_sequence.py +7 -16
  98. wandb/sdk/artifacts/_generated/upsert_registry.py +25 -0
  99. wandb/sdk/artifacts/_gqlutils.py +170 -6
  100. wandb/sdk/artifacts/_models/__init__.py +9 -0
  101. wandb/sdk/artifacts/_models/artifact_collection.py +109 -0
  102. wandb/sdk/artifacts/_models/manifest.py +26 -0
  103. wandb/sdk/artifacts/_models/pagination.py +26 -0
  104. wandb/sdk/artifacts/_models/registry.py +100 -0
  105. wandb/sdk/artifacts/_validators.py +45 -27
  106. wandb/sdk/artifacts/artifact.py +220 -215
  107. wandb/sdk/artifacts/artifact_file_cache.py +1 -1
  108. wandb/sdk/artifacts/artifact_manifest.py +37 -32
  109. wandb/sdk/artifacts/artifact_manifest_entry.py +80 -125
  110. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +43 -61
  111. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +8 -6
  112. wandb/sdk/data_types/image.py +2 -2
  113. wandb/sdk/interface/interface.py +72 -64
  114. wandb/sdk/interface/interface_queue.py +27 -18
  115. wandb/sdk/interface/interface_shared.py +61 -23
  116. wandb/sdk/interface/interface_sock.py +9 -5
  117. wandb/sdk/internal/_generated/server_features_query.py +4 -4
  118. wandb/sdk/launch/inputs/schema.py +13 -10
  119. wandb/sdk/lib/apikey.py +8 -12
  120. wandb/sdk/lib/asyncio_compat.py +1 -1
  121. wandb/sdk/lib/asyncio_manager.py +5 -5
  122. wandb/sdk/lib/console_capture.py +38 -30
  123. wandb/sdk/lib/progress.py +159 -64
  124. wandb/sdk/lib/retry.py +3 -2
  125. wandb/sdk/lib/service/service_connection.py +2 -2
  126. wandb/sdk/lib/wb_logging.py +2 -1
  127. wandb/sdk/mailbox/mailbox.py +1 -1
  128. wandb/sdk/wandb_init.py +10 -13
  129. wandb/sdk/wandb_run.py +9 -46
  130. wandb/sdk/wandb_settings.py +102 -19
  131. {wandb-0.22.2.dist-info → wandb-0.22.3.dist-info}/METADATA +2 -1
  132. {wandb-0.22.2.dist-info → wandb-0.22.3.dist-info}/RECORD +135 -134
  133. wandb/sdk/artifacts/_generated/artifact_via_membership_by_name.py +0 -26
  134. wandb/sdk/artifacts/_generated/create_artifact_collection_tag_assignments.py +0 -36
  135. wandb/sdk/artifacts/_generated/delete_artifact_collection_tag_assignments.py +0 -25
  136. wandb/sdk/artifacts/_generated/move_artifact_collection.py +0 -35
  137. wandb/sdk/projects/_generated/__init__.py +0 -26
  138. wandb/sdk/projects/_generated/delete_project.py +0 -22
  139. wandb/sdk/projects/_generated/enums.py +0 -4
  140. wandb/sdk/projects/_generated/fragments.py +0 -41
  141. wandb/sdk/projects/_generated/input_types.py +0 -13
  142. wandb/sdk/projects/_generated/operations.py +0 -88
  143. wandb/sdk/projects/_generated/rename_project.py +0 -27
  144. wandb/sdk/projects/_generated/upsert_registry_project.py +0 -27
  145. {wandb-0.22.2.dist-info → wandb-0.22.3.dist-info}/WHEEL +0 -0
  146. {wandb-0.22.2.dist-info → wandb-0.22.3.dist-info}/entry_points.txt +0 -0
  147. {wandb-0.22.2.dist-info → wandb-0.22.3.dist-info}/licenses/LICENSE +0 -0
@@ -7,6 +7,9 @@ META_SCHEMA = {
7
7
  },
8
8
  "title": {"type": "string"},
9
9
  "description": {"type": "string"},
10
+ "label": {"type": "string"},
11
+ "placeholder": {"type": "string"},
12
+ "required": {"type": "boolean"},
10
13
  "format": {"type": "string"},
11
14
  "enum": {"type": "array", "items": {"type": ["integer", "number", "string"]}},
12
15
  "properties": {"type": "object", "patternProperties": {".*": {"$ref": "#"}}},
@@ -19,24 +22,24 @@ META_SCHEMA = {
19
22
  },
20
23
  "allOf": [
21
24
  {
22
- "if": {"properties": {"type": {"const": "number"}}},
25
+ "if": {"properties": {"type": {"const": "integer"}}},
23
26
  "then": {
24
27
  "properties": {
25
- "minimum": {"type": ["integer", "number"]},
26
- "maximum": {"type": ["integer", "number"]},
27
- "exclusiveMinimum": {"type": ["integer", "number"]},
28
- "exclusiveMaximum": {"type": ["integer", "number"]},
28
+ "minimum": {"type": "integer"},
29
+ "maximum": {"type": "integer"},
30
+ "exclusiveMinimum": {"type": "integer"},
31
+ "exclusiveMaximum": {"type": "integer"},
29
32
  }
30
33
  },
31
34
  },
32
35
  {
33
- "if": {"properties": {"type": {"const": "integer"}}},
36
+ "if": {"properties": {"type": {"const": "number"}}},
34
37
  "then": {
35
38
  "properties": {
36
- "minimum": {"type": "integer"},
37
- "maximum": {"type": "integer"},
38
- "exclusiveMinimum": {"type": "integer"},
39
- "exclusiveMaximum": {"type": "integer"},
39
+ "minimum": {"type": ["integer", "number"]},
40
+ "maximum": {"type": ["integer", "number"]},
41
+ "exclusiveMinimum": {"type": ["integer", "number"]},
42
+ "exclusiveMaximum": {"type": ["integer", "number"]},
40
43
  }
41
44
  },
42
45
  },
wandb/sdk/lib/apikey.py CHANGED
@@ -136,12 +136,6 @@ def prompt_api_key( # noqa: C901
136
136
  if (jupyter and not settings.login_timeout) or no_create:
137
137
  choices.remove(LOGIN_CHOICE_NEW)
138
138
 
139
- if jupyter and "google.colab" in sys.modules:
140
- log_string = term.LOG_STRING_NOCOLOR
141
- key = wandb.jupyter.attempt_colab_login(app_url) # type: ignore
142
- if key is not None:
143
- return key # type: ignore
144
-
145
139
  if anon_mode == "must":
146
140
  result = LOGIN_CHOICE_ANON
147
141
  # If we're not in an interactive environment, default to dry-run.
@@ -236,9 +230,9 @@ def check_netrc_access(
236
230
  def write_netrc(host: str, entity: str, key: str):
237
231
  """Add our host and key to .netrc."""
238
232
  _, key_suffix = key.split("-", 1) if "-" in key else ("", key)
239
- if len(key_suffix) != 40:
233
+ if len(key_suffix) < 40:
240
234
  raise ValueError(
241
- f"API-key must be exactly 40 characters long: {key_suffix} ({len(key_suffix)} chars)"
235
+ f"API-key must be at least 40 characters long: {key_suffix} ({len(key_suffix)} chars)"
242
236
  )
243
237
 
244
238
  normalized_host = urlparse(host).netloc
@@ -305,12 +299,14 @@ def write_key(
305
299
  # TODO(jhr): api shouldn't be optional or it shouldn't be passed, clean up callers
306
300
  api = api or InternalApi()
307
301
 
308
- # Normal API keys are 40-character hex strings. On-prem API keys have a
309
- # variable-length prefix, a dash, then the 40-char string.
302
+ # API keys are strings of at least 40 characters. On-prem API keys have a
303
+ # variable-length prefix, a dash, then the string of at least 40 chars.
310
304
  _, suffix = key.split("-", 1) if "-" in key else ("", key)
311
305
 
312
- if len(suffix) != 40:
313
- raise ValueError(f"API key must be 40 characters long, yours was {len(key)}")
306
+ if len(suffix) < 40:
307
+ raise ValueError(
308
+ f"API key must be at least 40 characters long, yours was {len(key)}"
309
+ )
314
310
 
315
311
  write_netrc(settings.base_url, "user", key)
316
312
 
@@ -133,7 +133,7 @@ class TaskGroup:
133
133
  """Object that `open_task_group()` yields."""
134
134
 
135
135
  def __init__(self) -> None:
136
- self._tasks: list[asyncio.Task] = []
136
+ self._tasks: list[asyncio.Task[None]] = []
137
137
 
138
138
  def start_soon(self, coro: Coroutine[Any, Any, Any]) -> None:
139
139
  """Schedule a task in the group.
@@ -7,7 +7,7 @@ import concurrent.futures
7
7
  import contextlib
8
8
  import logging
9
9
  import threading
10
- from typing import Any, Callable, Coroutine, TypeVar
10
+ from typing import Awaitable, Callable, TypeVar
11
11
 
12
12
  from . import asyncio_compat
13
13
 
@@ -104,7 +104,7 @@ class AsyncioManager:
104
104
  # This only matters if the KeyboardInterrupt is suppressed.
105
105
  self._runner.cancel()
106
106
 
107
- def run(self, fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
107
+ def run(self, fn: Callable[[], Awaitable[_T]]) -> _T:
108
108
  """Run an async function to completion.
109
109
 
110
110
  The function is called in the asyncio thread. Blocks until start()
@@ -148,7 +148,7 @@ class AsyncioManager:
148
148
 
149
149
  def run_soon(
150
150
  self,
151
- fn: Callable[[], Coroutine[Any, Any, None]],
151
+ fn: Callable[[], Awaitable[None]],
152
152
  *,
153
153
  daemon: bool = False,
154
154
  name: str | None = None,
@@ -186,7 +186,7 @@ class AsyncioManager:
186
186
 
187
187
  def _schedule(
188
188
  self,
189
- fn: Callable[[], Coroutine[Any, Any, _T]],
189
+ fn: Callable[[], Awaitable[_T]],
190
190
  daemon: bool,
191
191
  name: str | None = None,
192
192
  ) -> concurrent.futures.Future[_T]:
@@ -207,7 +207,7 @@ class AsyncioManager:
207
207
 
208
208
  async def _wrap(
209
209
  self,
210
- fn: Callable[[], Coroutine[Any, Any, _T]],
210
+ fn: Callable[[], Awaitable[_T]],
211
211
  daemon: bool,
212
212
  name: str | None,
213
213
  ) -> _T:
@@ -75,9 +75,12 @@ class _WriteCallback(Protocol):
75
75
  """
76
76
 
77
77
 
78
- # A reentrant lock is used to catch callbacks that write to stderr/stdout.
79
- _module_rlock = threading.RLock()
78
+ _module_lock = threading.Lock()
80
79
  _is_writing = False
80
+ """Prevents infinite print-capture loops.
81
+
82
+ If a capture callback prints, that output is not captured.
83
+ """
81
84
 
82
85
  _patch_exception: CannotCaptureConsoleError | None = None
83
86
 
@@ -99,7 +102,7 @@ def capture_stdout(callback: _WriteCallback) -> Callable[[], None]:
99
102
  Raises:
100
103
  CannotCaptureConsoleError: If patching failed on import.
101
104
  """
102
- with _module_rlock:
105
+ with _module_lock:
103
106
  if _patch_exception:
104
107
  raise _patch_exception
105
108
 
@@ -121,7 +124,7 @@ def capture_stderr(callback: _WriteCallback) -> Callable[[], None]:
121
124
  Raises:
122
125
  CannotCaptureConsoleError: If patching failed on import.
123
126
  """
124
- with _module_rlock:
127
+ with _module_lock:
125
128
  if _patch_exception:
126
129
  raise _patch_exception
127
130
 
@@ -144,7 +147,7 @@ def _insert_disposably(
144
147
  def dispose() -> None:
145
148
  nonlocal disposed
146
149
 
147
- with _module_rlock:
150
+ with _module_lock:
148
151
  if disposed:
149
152
  return
150
153
 
@@ -167,38 +170,43 @@ def _patch(
167
170
  global _is_writing
168
171
  n = orig_write(s)
169
172
 
170
- # NOTE: Since _module_rlock is reentrant, this is safe. It will not
171
- # deadlock if a callback invokes write() again.
172
- with _module_rlock:
173
+ with _module_lock:
173
174
  if _is_writing:
174
175
  return n
175
-
176
176
  _is_writing = True
177
- try:
178
- for cb in callbacks.values():
179
- cb(s, n)
180
-
181
- except BaseException as e:
182
- # Clear all callbacks on any exception to avoid infinite loops:
183
- #
184
- # * If we re-raise, an exception handler is likely to print
185
- # the exception to the console and trigger callbacks again
186
- # * If we log, we can't guarantee that this doesn't print
187
- # to console.
188
- #
189
- # This is especially important for KeyboardInterrupt.
177
+
178
+ # Invoke callbacks outside of the lock to avoid deadlocks.
179
+ # 1. A callback may print, invoking this again.
180
+ # 2. A callback may block on a different thread which then prints.
181
+ callback_list = list(callbacks.values())
182
+
183
+ try:
184
+ for cb in callback_list:
185
+ cb(s, n)
186
+
187
+ except BaseException as e:
188
+ # Clear all callbacks on any exception to avoid infinite loops:
189
+ #
190
+ # * If we re-raise, an exception handler is likely to print
191
+ # the exception to the console and trigger callbacks again
192
+ # * If we log, we can't guarantee that this doesn't print
193
+ # to console.
194
+ #
195
+ # This is especially important for KeyboardInterrupt.
196
+ with _module_lock:
190
197
  _stderr_callbacks.clear()
191
198
  _stdout_callbacks.clear()
192
199
 
193
- if isinstance(e, Exception):
194
- # We suppress Exceptions so that bugs in W&B code don't
195
- # cause the user's print() statements to raise errors.
196
- _logger.exception("Error in console callback, clearing all!")
197
- else:
198
- # Re-raise errors like KeyboardInterrupt.
199
- raise
200
+ if isinstance(e, Exception):
201
+ # We suppress Exceptions so that bugs in W&B code don't
202
+ # cause the user's print() statements to raise errors.
203
+ _logger.exception("Error in console callback, clearing all!")
204
+ else:
205
+ # Re-raise errors like KeyboardInterrupt.
206
+ raise
200
207
 
201
- finally:
208
+ finally:
209
+ with _module_lock:
202
210
  _is_writing = False
203
211
 
204
212
  return n
wandb/sdk/lib/progress.py CHANGED
@@ -5,7 +5,7 @@ from __future__ import annotations
5
5
  import asyncio
6
6
  import contextlib
7
7
  import time
8
- from typing import Iterable, Iterator, NoReturn
8
+ from typing import Iterator, NoReturn
9
9
 
10
10
  from wandb.proto import wandb_internal_pb2 as pb
11
11
  from wandb.sdk.interface import interface
@@ -13,6 +13,10 @@ from wandb.sdk.lib import asyncio_compat
13
13
 
14
14
  from . import printer as p
15
15
 
16
+ _INDENT = " "
17
+ _MAX_LINES_TO_PRINT = 6
18
+ _MAX_OPS_TO_PRINT = 5
19
+
16
20
 
17
21
  async def loop_printing_operation_stats(
18
22
  progress: ProgressPrinter,
@@ -96,101 +100,189 @@ class ProgressPrinter:
96
100
  self._printer = printer
97
101
  self._progress_text_area = progress_text_area
98
102
  self._default_text = default_text
99
- self._tick = 0
103
+ self._tick = -1
100
104
  self._last_printed_line = ""
101
105
 
102
106
  def update(
103
107
  self,
104
- progress: list[pb.PollExitResponse] | pb.OperationStats,
108
+ stats_or_groups: pb.OperationStats | dict[str, pb.OperationStats],
105
109
  ) -> None:
106
- """Update the displayed information."""
107
- if not progress:
110
+ """Update the displayed information.
111
+
112
+ Args:
113
+ stats_or_groups: A single group of operations, or zero or more
114
+ labeled operation groups.
115
+ """
116
+ self._tick += 1
117
+
118
+ if not self._progress_text_area:
119
+ line = self._to_static_text(stats_or_groups)
120
+ if line and line != self._last_printed_line:
121
+ self._printer.display(line)
122
+ self._last_printed_line = line
108
123
  return
109
124
 
110
- if isinstance(progress, pb.OperationStats):
111
- self._update_operation_stats([progress])
112
- else:
113
- self._update_operation_stats(
114
- list(response.operation_stats for response in progress)
115
- )
125
+ lines = self._to_dynamic_text(stats_or_groups)
126
+ if not lines:
127
+ loading_symbol = self._printer.loading_symbol(self._tick)
128
+ if loading_symbol:
129
+ lines = [f"{loading_symbol} {self._default_text}"]
130
+ else:
131
+ lines = [self._default_text]
116
132
 
117
- self._tick += 1
133
+ self._progress_text_area.set_text("\n".join(lines))
118
134
 
119
- def _update_operation_stats(self, stats_list: list[pb.OperationStats]) -> None:
120
- if self._progress_text_area:
121
- _DynamicOperationStatsPrinter(
135
+ def _to_dynamic_text(
136
+ self,
137
+ stats_or_groups: pb.OperationStats | dict[str, pb.OperationStats],
138
+ ) -> list[str]:
139
+ """Returns text to show in a dynamic text area."""
140
+ loading_symbol = self._printer.loading_symbol(self._tick)
141
+
142
+ if isinstance(stats_or_groups, dict):
143
+ return _GroupedOperationStatsPrinter(
122
144
  self._printer,
123
- self._progress_text_area,
124
- max_lines=6,
125
- loading_symbol=self._printer.loading_symbol(self._tick),
126
- default_text=self._default_text,
127
- ).display(stats_list)
145
+ _MAX_LINES_TO_PRINT,
146
+ loading_symbol,
147
+ ).render(stats_or_groups)
128
148
 
129
149
  else:
130
- top_level_operations: list[str] = []
131
- extra_operations = 0
132
- for stats in stats_list:
133
- for op in stats.operations:
134
- if len(top_level_operations) < 5:
135
- top_level_operations.append(op.desc)
136
- else:
137
- extra_operations += 1
138
-
139
- line = "; ".join(top_level_operations)
140
- if extra_operations > 0:
141
- line += f" (+ {extra_operations} more)"
150
+ return _OperationStatsPrinter(
151
+ self._printer,
152
+ _MAX_LINES_TO_PRINT,
153
+ loading_symbol,
154
+ ).render(stats_or_groups)
142
155
 
143
- if line and line != self._last_printed_line:
144
- self._printer.display(line)
145
- self._last_printed_line = line
156
+ def _to_static_text(
157
+ self,
158
+ stats_or_groups: pb.OperationStats | dict[str, pb.OperationStats],
159
+ ) -> str:
160
+ """Returns a single line of text to print out."""
161
+ if isinstance(stats_or_groups, dict):
162
+ sorted_prefixed_stats = list(
163
+ (f"[{group}] ", stats) #
164
+ for group, stats in sorted(stats_or_groups.items())
165
+ )
166
+ else:
167
+ sorted_prefixed_stats = [("", stats_or_groups)]
168
+
169
+ group_strs: list[str] = []
170
+ total_operations = 0
171
+ total_printed = 0
146
172
 
173
+ for prefix, stats in sorted_prefixed_stats:
174
+ total_operations += stats.total_operations
175
+ if not stats.operations:
176
+ continue
177
+
178
+ group_ops: list[str] = []
179
+ i = 0
180
+ while total_printed < _MAX_OPS_TO_PRINT and i < len(stats.operations):
181
+ group_ops.append(stats.operations[i].desc)
182
+ total_printed += 1
183
+ i += 1
184
+
185
+ if group_ops:
186
+ group_strs.append(prefix + "; ".join(group_ops))
187
+
188
+ line = "; ".join(group_strs)
189
+ remaining = total_operations - total_printed
190
+ if total_printed > 0 and remaining > 0:
191
+ line += f" (+ {remaining} more)"
147
192
 
148
- class _DynamicOperationStatsPrinter:
149
- """Single-use object that writes operation stats into a text area."""
193
+ return line
194
+
195
+
196
+ class _GroupedOperationStatsPrinter:
197
+ """Renders a list of labeled operation stats groups into lines of text."""
150
198
 
151
199
  def __init__(
152
200
  self,
153
201
  printer: p.Printer,
154
- text_area: p.DynamicText,
155
202
  max_lines: int,
156
203
  loading_symbol: str,
157
- default_text: str,
158
204
  ) -> None:
159
205
  self._printer = printer
160
- self._text_area = text_area
161
206
  self._max_lines = max_lines
162
207
  self._loading_symbol = loading_symbol
163
- self._default_text = default_text
164
208
 
165
- self._lines: list[str] = []
166
- self._ops_shown = 0
209
+ def render(self, groups: dict[str, pb.OperationStats]) -> list[str]:
210
+ """Convert labeled operation stats groups into text to display.
211
+
212
+ Args:
213
+ groups: A mapping from group labels to stats for that group.
214
+
215
+ Returns:
216
+ The lines of text to print. The lines do not end with the newline
217
+ character. Returns an empty list if there are no operations.
218
+ """
219
+ lines: list[str] = []
220
+
221
+ for key, stats in sorted(groups.items()):
222
+ # Don't display empty groups.
223
+ if not stats.operations:
224
+ continue
225
+
226
+ # Ensure enough space left for the group header and at least
227
+ # one line of content.
228
+ remaining_lines = self._max_lines - len(lines)
229
+ if remaining_lines < 2:
230
+ break
231
+
232
+ # Group header.
233
+ lines.append(key)
167
234
 
168
- def display(
235
+ # Group content.
236
+ stats_lines = _OperationStatsPrinter(
237
+ printer=self._printer,
238
+ max_lines=remaining_lines - 1, # minus one for the header
239
+ loading_symbol=self._loading_symbol,
240
+ ).render(stats)
241
+ for line in stats_lines:
242
+ lines.append(f"{_INDENT}{line}")
243
+
244
+ return lines
245
+
246
+
247
+ class _OperationStatsPrinter:
248
+ """Renders operation stats into lines of text."""
249
+
250
+ def __init__(
169
251
  self,
170
- stats_list: Iterable[pb.OperationStats],
252
+ printer: p.Printer,
253
+ max_lines: int,
254
+ loading_symbol: str,
171
255
  ) -> None:
172
- """Show the given stats in the text area."""
173
- total_operations = 0
174
- for stats in stats_list:
175
- for op in stats.operations:
176
- self._add_operation(op, is_subtask=False, indent="")
177
- total_operations += stats.total_operations
256
+ self._printer = printer
257
+ self._max_lines = max_lines
258
+ self._loading_symbol = loading_symbol
259
+
260
+ self._lines: list[str] = []
261
+ self._ops_shown = 0
262
+
263
+ def render(self, stats: pb.OperationStats) -> list[str]:
264
+ """Convert the stats into a list of lines to display.
265
+
266
+ Args:
267
+ stats: Collection of operations to display.
178
268
 
179
- if self._ops_shown < total_operations:
269
+ Returns:
270
+ The lines of text to print. The lines do not end with the newline
271
+ character. Returns an empty list if there are no operations.
272
+ """
273
+ for op in stats.operations:
274
+ self._add_operation(op, is_subtask=False, indent="")
275
+
276
+ if self._ops_shown < stats.total_operations:
180
277
  if 1 <= self._max_lines <= len(self._lines):
278
+ self._ops_shown -= 1
181
279
  self._lines.pop()
182
280
 
183
- remaining = total_operations - self._ops_shown
281
+ remaining = stats.total_operations - self._ops_shown
184
282
 
185
283
  self._lines.append(f"+ {remaining} more task(s)")
186
284
 
187
- if len(self._lines) == 0:
188
- if self._loading_symbol:
189
- self._text_area.set_text(f"{self._loading_symbol} {self._default_text}")
190
- else:
191
- self._text_area.set_text(self._default_text)
192
- else:
193
- self._text_area.set_text("\n".join(self._lines))
285
+ return self._lines
194
286
 
195
287
  def _add_operation(self, op: pb.Operation, is_subtask: bool, indent: str) -> None:
196
288
  """Add the operation to `self._lines`."""
@@ -200,14 +292,17 @@ class _DynamicOperationStatsPrinter:
200
292
  if not is_subtask:
201
293
  self._ops_shown += 1
202
294
 
203
- parts = []
295
+ status_indent_level = 0 # alignment for the status message, if any
296
+ parts: list[str] = []
204
297
 
205
298
  # Subtask indicator.
206
299
  if is_subtask and self._printer.supports_unicode:
300
+ status_indent_level += 2 # +1 for space
207
301
  parts.append("↳")
208
302
 
209
303
  # Loading symbol.
210
304
  if self._loading_symbol:
305
+ status_indent_level += 2 # +1 for space
211
306
  parts.append(self._loading_symbol)
212
307
 
213
308
  # Task name.
@@ -225,14 +320,14 @@ class _DynamicOperationStatsPrinter:
225
320
  if op.error_status:
226
321
  error_word = self._printer.error("ERROR")
227
322
  error_desc = self._printer.secondary_text(op.error_status)
228
- subtask_indent = " " if is_subtask else ""
323
+ status_indent = " " * status_indent_level
229
324
  self._lines.append(
230
- f"{indent}{subtask_indent} {error_word} {error_desc}",
325
+ f"{indent}{status_indent}{error_word} {error_desc}",
231
326
  )
232
327
 
233
328
  # Subtasks.
234
329
  if op.subtasks:
235
- subtask_indent = indent + " "
330
+ subtask_indent = indent + _INDENT
236
331
  for task in op.subtasks:
237
332
  self._add_operation(
238
333
  task,
wandb/sdk/lib/retry.py CHANGED
@@ -77,9 +77,10 @@ class Retry(Generic[_R]):
77
77
  self._retryable_exceptions = retryable_exceptions
78
78
  else:
79
79
  self._retryable_exceptions = (TransientError,)
80
- self._index = 0
81
80
  self.retry_callback = retry_callback
82
81
 
82
+ self._num_iter = 0
83
+
83
84
  def _sleep_check_cancelled(
84
85
  self, wait_seconds: float, cancel_event: Optional[threading.Event]
85
86
  ) -> bool:
@@ -194,7 +195,7 @@ class Retry(Generic[_R]):
194
195
  else:
195
196
  wandb.termlog(
196
197
  f"{self._error_prefix}"
197
- f" ({exception.__class__.__name__}), entering retry loop."
198
+ + f" ({exception.__class__.__name__}), entering retry loop."
198
199
  )
199
200
 
200
201
  def _print_recovered(self, start_time: datetime.datetime) -> None:
@@ -190,8 +190,8 @@ class ServiceConnection:
190
190
  except TimeoutError:
191
191
  raise WandbAttachFailedError(
192
192
  "Failed to attach because the run does not belong to"
193
- " the current service process, or because the service"
194
- " process is busy (unlikely)."
193
+ + " the current service process, or because the service"
194
+ + " process is busy (unlikely)."
195
195
  ) from None
196
196
 
197
197
  else:
@@ -136,7 +136,7 @@ def add_file_handler(run_id: str, filepath: pathlib.Path) -> logging.Handler:
136
136
  return handler
137
137
 
138
138
 
139
- class _RunIDFilter(logging.Filter):
139
+ class _RunIDFilter:
140
140
  """Filters out messages logged for a different run."""
141
141
 
142
142
  def __init__(self, run_id: str) -> None:
@@ -148,6 +148,7 @@ class _RunIDFilter(logging.Filter):
148
148
  self._run_id = run_id
149
149
 
150
150
  def filter(self, record: logging.LogRecord) -> bool:
151
+ """Modify a log record and return whether it matches the run."""
151
152
  run_id = _run_id.get()
152
153
 
153
154
  if run_id is None:
@@ -89,7 +89,7 @@ class Mailbox:
89
89
  def generate():
90
90
  return "".join(
91
91
  secrets.choice(string.ascii_lowercase + string.digits)
92
- for i in range(12)
92
+ for _ in range(12)
93
93
  )
94
94
 
95
95
  address = generate()
wandb/sdk/wandb_init.py CHANGED
@@ -12,6 +12,7 @@ from __future__ import annotations
12
12
 
13
13
  import contextlib
14
14
  import dataclasses
15
+ import functools
15
16
  import json
16
17
  import logging
17
18
  import os
@@ -988,25 +989,21 @@ class _WandbInit:
988
989
 
989
990
  run_init_handle = backend.interface.deliver_run(run)
990
991
 
991
- async def display_init_message() -> None:
992
- assert backend.interface
993
-
992
+ try:
994
993
  with progress.progress_printer(
995
994
  run_printer,
996
995
  default_text="Waiting for wandb.init()...",
997
996
  ) as progress_printer:
998
- await progress.loop_printing_operation_stats(
999
- progress_printer,
1000
- backend.interface,
997
+ result = wait_with_progress(
998
+ run_init_handle,
999
+ timeout=timeout,
1000
+ display_progress=functools.partial(
1001
+ progress.loop_printing_operation_stats,
1002
+ progress_printer,
1003
+ backend.interface,
1004
+ ),
1001
1005
  )
1002
1006
 
1003
- try:
1004
- result = wait_with_progress(
1005
- run_init_handle,
1006
- timeout=timeout,
1007
- display_progress=display_init_message,
1008
- )
1009
-
1010
1007
  except TimeoutError:
1011
1008
  run_init_handle.cancel(backend.interface)
1012
1009