wandb 0.20.1rc20250604__py3-none-win_amd64.whl → 0.21.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (125)
  1. wandb/__init__.py +3 -6
  2. wandb/__init__.pyi +24 -23
  3. wandb/analytics/sentry.py +2 -2
  4. wandb/apis/importers/internals/internal.py +0 -3
  5. wandb/apis/internal.py +3 -0
  6. wandb/apis/paginator.py +17 -4
  7. wandb/apis/public/api.py +85 -4
  8. wandb/apis/public/artifacts.py +10 -8
  9. wandb/apis/public/files.py +5 -5
  10. wandb/apis/public/projects.py +44 -3
  11. wandb/apis/public/registries/{utils.py → _utils.py} +12 -12
  12. wandb/apis/public/registries/registries_search.py +2 -2
  13. wandb/apis/public/registries/registry.py +19 -18
  14. wandb/apis/public/reports.py +64 -8
  15. wandb/apis/public/runs.py +16 -23
  16. wandb/automations/__init__.py +10 -10
  17. wandb/automations/_filters/run_metrics.py +0 -2
  18. wandb/automations/_utils.py +0 -2
  19. wandb/automations/actions.py +0 -2
  20. wandb/automations/automations.py +0 -2
  21. wandb/automations/events.py +0 -2
  22. wandb/bin/gpu_stats.exe +0 -0
  23. wandb/bin/wandb-core +0 -0
  24. wandb/cli/beta.py +1 -7
  25. wandb/cli/cli.py +0 -30
  26. wandb/env.py +0 -6
  27. wandb/integration/catboost/catboost.py +6 -2
  28. wandb/integration/kfp/kfp_patch.py +3 -1
  29. wandb/integration/sb3/sb3.py +3 -3
  30. wandb/integration/ultralytics/callback.py +6 -2
  31. wandb/plot/__init__.py +2 -0
  32. wandb/plot/bar.py +30 -29
  33. wandb/plot/confusion_matrix.py +75 -71
  34. wandb/plot/histogram.py +26 -25
  35. wandb/plot/line.py +33 -32
  36. wandb/plot/line_series.py +100 -103
  37. wandb/plot/pr_curve.py +33 -32
  38. wandb/plot/roc_curve.py +38 -38
  39. wandb/plot/scatter.py +27 -27
  40. wandb/proto/v3/wandb_internal_pb2.py +366 -385
  41. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  42. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  43. wandb/proto/v4/wandb_internal_pb2.py +352 -356
  44. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  45. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  46. wandb/proto/v5/wandb_internal_pb2.py +352 -356
  47. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  48. wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
  49. wandb/proto/v6/wandb_internal_pb2.py +352 -356
  50. wandb/proto/v6/wandb_settings_pb2.py +2 -2
  51. wandb/proto/v6/wandb_telemetry_pb2.py +10 -10
  52. wandb/sdk/artifacts/_generated/__init__.py +12 -1
  53. wandb/sdk/artifacts/_generated/input_types.py +20 -2
  54. wandb/sdk/artifacts/_generated/link_artifact.py +21 -0
  55. wandb/sdk/artifacts/_generated/operations.py +9 -0
  56. wandb/sdk/artifacts/_validators.py +40 -2
  57. wandb/sdk/artifacts/artifact.py +163 -21
  58. wandb/sdk/artifacts/storage_handlers/s3_handler.py +42 -1
  59. wandb/sdk/backend/backend.py +1 -1
  60. wandb/sdk/data_types/base_types/media.py +9 -7
  61. wandb/sdk/data_types/base_types/wb_value.py +6 -6
  62. wandb/sdk/data_types/saved_model.py +3 -3
  63. wandb/sdk/data_types/table.py +41 -41
  64. wandb/sdk/data_types/trace_tree.py +12 -12
  65. wandb/sdk/interface/interface.py +8 -19
  66. wandb/sdk/interface/interface_shared.py +7 -16
  67. wandb/sdk/internal/datastore.py +18 -18
  68. wandb/sdk/internal/handler.py +4 -74
  69. wandb/sdk/internal/internal_api.py +54 -0
  70. wandb/sdk/internal/sender.py +23 -3
  71. wandb/sdk/internal/sender_config.py +9 -0
  72. wandb/sdk/launch/_project_spec.py +3 -3
  73. wandb/sdk/launch/agent/agent.py +3 -3
  74. wandb/sdk/launch/agent/job_status_tracker.py +3 -1
  75. wandb/sdk/launch/utils.py +3 -3
  76. wandb/sdk/lib/console_capture.py +66 -19
  77. wandb/sdk/lib/printer.py +6 -7
  78. wandb/sdk/lib/progress.py +1 -3
  79. wandb/sdk/lib/service/ipc_support.py +13 -0
  80. wandb/sdk/lib/{service_connection.py → service/service_connection.py} +20 -56
  81. wandb/sdk/lib/service/service_port_file.py +105 -0
  82. wandb/sdk/lib/service/service_process.py +111 -0
  83. wandb/sdk/lib/service/service_token.py +164 -0
  84. wandb/sdk/lib/sock_client.py +8 -12
  85. wandb/sdk/wandb_init.py +1 -5
  86. wandb/sdk/wandb_require.py +9 -21
  87. wandb/sdk/wandb_run.py +23 -137
  88. wandb/sdk/wandb_settings.py +233 -80
  89. wandb/sdk/wandb_setup.py +2 -13
  90. {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/METADATA +1 -3
  91. {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/RECORD +94 -120
  92. wandb/sdk/internal/flow_control.py +0 -263
  93. wandb/sdk/internal/internal.py +0 -401
  94. wandb/sdk/internal/internal_util.py +0 -97
  95. wandb/sdk/internal/system/__init__.py +0 -0
  96. wandb/sdk/internal/system/assets/__init__.py +0 -25
  97. wandb/sdk/internal/system/assets/aggregators.py +0 -31
  98. wandb/sdk/internal/system/assets/asset_registry.py +0 -20
  99. wandb/sdk/internal/system/assets/cpu.py +0 -163
  100. wandb/sdk/internal/system/assets/disk.py +0 -210
  101. wandb/sdk/internal/system/assets/gpu.py +0 -416
  102. wandb/sdk/internal/system/assets/gpu_amd.py +0 -233
  103. wandb/sdk/internal/system/assets/interfaces.py +0 -205
  104. wandb/sdk/internal/system/assets/ipu.py +0 -177
  105. wandb/sdk/internal/system/assets/memory.py +0 -166
  106. wandb/sdk/internal/system/assets/network.py +0 -125
  107. wandb/sdk/internal/system/assets/open_metrics.py +0 -293
  108. wandb/sdk/internal/system/assets/tpu.py +0 -154
  109. wandb/sdk/internal/system/assets/trainium.py +0 -393
  110. wandb/sdk/internal/system/env_probe_helpers.py +0 -13
  111. wandb/sdk/internal/system/system_info.py +0 -248
  112. wandb/sdk/internal/system/system_monitor.py +0 -224
  113. wandb/sdk/internal/writer.py +0 -204
  114. wandb/sdk/lib/service_token.py +0 -93
  115. wandb/sdk/service/__init__.py +0 -0
  116. wandb/sdk/service/_startup_debug.py +0 -22
  117. wandb/sdk/service/port_file.py +0 -53
  118. wandb/sdk/service/server.py +0 -107
  119. wandb/sdk/service/server_sock.py +0 -286
  120. wandb/sdk/service/service.py +0 -252
  121. wandb/sdk/service/streams.py +0 -425
  122. wandb/sdk/wandb_metadata.py +0 -623
  123. {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/WHEEL +0 -0
  124. {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/entry_points.txt +0 -0
  125. {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/licenses/LICENSE +0 -0
@@ -33,16 +33,12 @@ from wandb.proto.wandb_internal_pb2 import (
33
33
  SummaryItem,
34
34
  SummaryRecord,
35
35
  SummaryRecordRequest,
36
- SystemMetricSample,
37
- SystemMetricsBuffer,
38
36
  )
39
37
 
40
38
  from ..interface.interface_queue import InterfaceQueue
41
39
  from ..lib import handler_util, proto_util
42
- from ..wandb_metadata import Metadata
43
40
  from . import context, sample, tb_watcher
44
41
  from .settings_static import SettingsStatic
45
- from .system.system_monitor import SystemMonitor
46
42
 
47
43
  if TYPE_CHECKING:
48
44
  from wandb.proto.wandb_internal_pb2 import MetricSummary
@@ -89,7 +85,6 @@ class HandleManager:
89
85
  _stopped: Event
90
86
  _writer_q: "Queue[Record]"
91
87
  _interface: InterfaceQueue
92
- _system_monitor: Optional[SystemMonitor]
93
88
  _tb_watcher: Optional[tb_watcher.TBWatcher]
94
89
  _metric_defines: Dict[str, MetricRecord]
95
90
  _metric_globs: Dict[str, MetricRecord]
@@ -119,8 +114,6 @@ class HandleManager:
119
114
  self._context_keeper = context_keeper
120
115
 
121
116
  self._tb_watcher = None
122
- self._system_monitor = None
123
- self._metadata: Optional[Metadata] = None
124
117
  self._step = 0
125
118
 
126
119
  self._track_time = None
@@ -178,21 +171,12 @@ class HandleManager:
178
171
  def handle_request_cancel(self, record: Record) -> None:
179
172
  self._dispatch_record(record)
180
173
 
181
- def handle_request_metadata(self, record: Record) -> None:
182
- logger.warning("Metadata updates are ignored when using the legacy service.")
183
-
184
174
  def handle_request_defer(self, record: Record) -> None:
185
175
  defer = record.request.defer
186
176
  state = defer.state
187
177
 
188
178
  logger.info(f"handle defer: {state}")
189
- # only handle flush tb (sender handles the rest)
190
- if state == defer.FLUSH_STATS:
191
- # TODO(jhr): this could block so we dont really want to call shutdown
192
- # from handler thread
193
- if self._system_monitor is not None:
194
- self._system_monitor.finish()
195
- elif state == defer.FLUSH_TB:
179
+ if state == defer.FLUSH_TB:
196
180
  if self._tb_watcher:
197
181
  # shutdown tensorboard workers so we get all metrics flushed
198
182
  self._tb_watcher.finish()
@@ -660,6 +644,9 @@ class HandleManager:
660
644
  def handle_footer(self, record: Record) -> None:
661
645
  self._dispatch_record(record)
662
646
 
647
+ def handle_metadata(self, record: Record) -> None:
648
+ self._dispatch_record(record)
649
+
663
650
  def handle_request_attach(self, record: Record) -> None:
664
651
  result = proto_util._result_from_record(record)
665
652
  attach_id = record.request.attach.attach_id
@@ -689,24 +676,6 @@ class HandleManager:
689
676
  else:
690
677
  self._accumulate_time = 0
691
678
 
692
- # system monitor
693
- self._system_monitor = SystemMonitor(
694
- self._settings,
695
- self._interface,
696
- )
697
- if not (
698
- self._settings.x_disable_stats or self._settings.x_disable_machine_info
699
- ):
700
- self._system_monitor.start()
701
- if (
702
- not (self._settings.x_disable_meta or self._settings.x_disable_machine_info)
703
- and not run_start.run.resumed
704
- ):
705
- try:
706
- self._metadata = Metadata(**self._system_monitor.probe(publish=True))
707
- except Exception:
708
- logger.exception("Error probing system metadata.")
709
-
710
679
  self._tb_watcher = tb_watcher.TBWatcher(
711
680
  self._settings, interface=self._interface, run_proto=run_start.run
712
681
  )
@@ -717,18 +686,11 @@ class HandleManager:
717
686
  self._respond_result(result)
718
687
 
719
688
  def handle_request_resume(self, record: Record) -> None:
720
- if self._system_monitor is not None:
721
- logger.info("starting system metrics thread")
722
- self._system_monitor.start()
723
-
724
689
  if self._track_time is not None:
725
690
  self._accumulate_time += time.time() - self._track_time
726
691
  self._track_time = time.time()
727
692
 
728
693
  def handle_request_pause(self, record: Record) -> None:
729
- if self._system_monitor is not None:
730
- logger.info("stopping system metrics thread")
731
- self._system_monitor.finish()
732
694
  if self._track_time is not None:
733
695
  self._accumulate_time += time.time() - self._track_time
734
696
  self._track_time = None
@@ -763,36 +725,6 @@ class HandleManager:
763
725
  result.response.get_summary_response.item.append(item)
764
726
  self._respond_result(result)
765
727
 
766
- def handle_request_get_system_metrics(self, record: Record) -> None:
767
- result = proto_util._result_from_record(record)
768
- if self._system_monitor is None:
769
- return
770
-
771
- buffer = self._system_monitor.buffer
772
- for key, samples in buffer.items():
773
- buff = []
774
- for s in samples:
775
- sms = SystemMetricSample()
776
- sms.timestamp.FromMicroseconds(int(s[0] * 1e6))
777
- sms.value = s[1]
778
- buff.append(sms)
779
-
780
- result.response.get_system_metrics_response.system_metrics[key].CopyFrom(
781
- SystemMetricsBuffer(record=buff)
782
- )
783
-
784
- self._respond_result(result)
785
-
786
- def handle_request_get_system_metadata(self, record: Record) -> None:
787
- result = proto_util._result_from_record(record)
788
- if self._system_monitor is None or self._metadata is None:
789
- return
790
-
791
- result.response.get_system_metadata_response.metadata.CopyFrom(
792
- self._metadata.to_proto()
793
- )
794
- self._respond_result(result)
795
-
796
728
  def handle_tbrecord(self, record: Record) -> None:
797
729
  logger.info("handling tbrecord: %s", record)
798
730
  if self._tb_watcher:
@@ -895,8 +827,6 @@ class HandleManager:
895
827
 
896
828
  def finish(self) -> None:
897
829
  logger.info("shutting down handler")
898
- if self._system_monitor is not None:
899
- self._system_monitor.finish()
900
830
  if self._tb_watcher:
901
831
  self._tb_watcher.finish()
902
832
  # self._context_keeper._debug_print_orphans()
@@ -362,6 +362,7 @@ class Api:
362
362
  self.server_create_run_queue_supports_priority: Optional[bool] = None
363
363
  self.server_supports_template_variables: Optional[bool] = None
364
364
  self.server_push_to_run_queue_supports_priority: Optional[bool] = None
365
+
365
366
  self._server_features_cache: Optional[Dict[str, bool]] = None
366
367
 
367
368
  def gql(self, *args: Any, **kwargs: Any) -> Any:
@@ -4661,3 +4662,56 @@ class Api:
4661
4662
  success: bool = response["stopRun"].get("success")
4662
4663
 
4663
4664
  return success
4665
+
4666
+ @normalize_exceptions
4667
+ def create_custom_chart(
4668
+ self,
4669
+ entity: str,
4670
+ name: str,
4671
+ display_name: str,
4672
+ spec_type: str,
4673
+ access: str,
4674
+ spec: Union[str, Mapping[str, Any]],
4675
+ ) -> Optional[Dict[str, Any]]:
4676
+ if not isinstance(spec, str):
4677
+ spec = json.dumps(spec)
4678
+
4679
+ mutation = gql(
4680
+ """
4681
+ mutation CreateCustomChart(
4682
+ $entity: String!
4683
+ $name: String!
4684
+ $displayName: String!
4685
+ $type: String!
4686
+ $access: String!
4687
+ $spec: JSONString!
4688
+ ) {
4689
+ createCustomChart(
4690
+ input: {
4691
+ entity: $entity
4692
+ name: $name
4693
+ displayName: $displayName
4694
+ type: $type
4695
+ access: $access
4696
+ spec: $spec
4697
+ }
4698
+ ) {
4699
+ chart { id }
4700
+ }
4701
+ }
4702
+ """
4703
+ )
4704
+
4705
+ variable_values = {
4706
+ "entity": entity,
4707
+ "name": name,
4708
+ "displayName": display_name,
4709
+ "type": spec_type,
4710
+ "access": access,
4711
+ "spec": spec,
4712
+ }
4713
+
4714
+ result: Optional[Dict[str, Any]] = self.gql(mutation, variable_values)[
4715
+ "createCustomChart"
4716
+ ]
4717
+ return result
@@ -63,6 +63,7 @@ if TYPE_CHECKING:
63
63
  ArtifactManifest,
64
64
  ArtifactManifestEntry,
65
65
  ArtifactRecord,
66
+ EnvironmentRecord,
66
67
  HttpResponse,
67
68
  LocalInfo,
68
69
  Record,
@@ -212,6 +213,7 @@ class SendManager:
212
213
  _context_keeper: context.ContextKeeper
213
214
 
214
215
  _telemetry_obj: telemetry.TelemetryRecord
216
+ _environment_obj: "EnvironmentRecord"
215
217
  _fs: Optional["file_stream.FileStreamApi"]
216
218
  _run: Optional["RunRecord"]
217
219
  _entity: Optional[str]
@@ -268,6 +270,7 @@ class SendManager:
268
270
 
269
271
  self._start_time: int = 0
270
272
  self._telemetry_obj = telemetry.TelemetryRecord()
273
+ self._environment_obj = wandb_internal_pb2.EnvironmentRecord()
271
274
  self._config_metric_pbdict_list: List[Dict[int, Any]] = []
272
275
  self._metadata_summary: Dict[str, Any] = defaultdict()
273
276
  self._cached_summary: Dict[str, Any] = dict()
@@ -790,12 +793,12 @@ class SendManager:
790
793
 
791
794
  def _config_backend_dict(self) -> sender_config.BackendConfigDict:
792
795
  config = self._consolidated_config or sender_config.ConfigState()
793
-
794
796
  return config.to_backend_dict(
795
797
  telemetry_record=self._telemetry_obj,
796
798
  framework=self._telemetry_get_framework(),
797
799
  start_time_millis=self._start_time,
798
800
  metric_pbdicts=self._config_metric_pbdict_list,
801
+ environment_record=self._environment_obj,
799
802
  )
800
803
 
801
804
  def _config_save(
@@ -1379,11 +1382,11 @@ class SendManager:
1379
1382
  next_idx = len(self._config_metric_pbdict_list)
1380
1383
  self._config_metric_pbdict_list.append(md)
1381
1384
  self._config_metric_index_dict[metric.name] = next_idx
1382
- self._update_config()
1385
+ self._debounce_config()
1383
1386
 
1384
1387
  def _update_telemetry_record(self, telemetry: telemetry.TelemetryRecord) -> None:
1385
1388
  self._telemetry_obj.MergeFrom(telemetry)
1386
- self._update_config()
1389
+ self._debounce_config()
1387
1390
 
1388
1391
  def send_telemetry(self, record: "Record") -> None:
1389
1392
  self._update_telemetry_record(record.telemetry)
@@ -1417,6 +1420,23 @@ class SendManager:
1417
1420
  # tbrecord watching threads are handled by handler.py
1418
1421
  pass
1419
1422
 
1423
+ def _update_environment_record(self, environment: "EnvironmentRecord") -> None:
1424
+ self._environment_obj.MergeFrom(environment)
1425
+ self._debounce_config()
1426
+
1427
+ def send_environment(self, record: "Record") -> None:
1428
+ """Inject environment info into config and upload as a JSON file."""
1429
+ self._update_environment_record(record.environment)
1430
+
1431
+ environment_json = json.dumps(proto_util.message_to_dict(self._environment_obj))
1432
+
1433
+ with open(
1434
+ os.path.join(self._settings.files_dir, filenames.METADATA_FNAME), "w"
1435
+ ) as f:
1436
+ f.write(environment_json)
1437
+
1438
+ self._save_file(interface.GlobStr(filenames.METADATA_FNAME), policy="now")
1439
+
1420
1440
  def send_request_link_artifact(self, record: "Record") -> None:
1421
1441
  if not (record.control.req_resp or record.control.mailbox_slot):
1422
1442
  raise ValueError(
@@ -79,6 +79,7 @@ class ConfigState:
79
79
  framework: Optional[str],
80
80
  start_time_millis: int,
81
81
  metric_pbdicts: Sequence[Dict[int, Any]],
82
+ environment_record: wandb_internal_pb2.EnvironmentRecord,
82
83
  ) -> BackendConfigDict:
83
84
  """Returns a dictionary representation expected by the backend.
84
85
 
@@ -125,6 +126,14 @@ class ConfigState:
125
126
  if metric_pbdicts:
126
127
  wandb_internal["m"] = metric_pbdicts
127
128
 
129
+ ###################################################
130
+ # Environment
131
+ ###################################################
132
+ writer_id = environment_record.writer_id
133
+ if writer_id:
134
+ environment_dict = proto_util.message_to_dict(environment_record)
135
+ wandb_internal["e"] = {writer_id: environment_dict}
136
+
128
137
  return BackendConfigDict(
129
138
  {
130
139
  key: {
@@ -370,9 +370,9 @@ class LaunchProject:
370
370
 
371
371
  def set_job_entry_point(self, command: List[str]) -> "EntryPoint":
372
372
  """Set job entrypoint for the project."""
373
- assert (
374
- self._entry_point is None
375
- ), "Cannot set entry point twice. Use LaunchProject.override_entrypoint"
373
+ assert self._entry_point is None, (
374
+ "Cannot set entry point twice. Use LaunchProject.override_entrypoint"
375
+ )
376
376
  new_entrypoint = EntryPoint(name=command[-1], command=command)
377
377
  self._entry_point = new_entrypoint
378
378
  return new_entrypoint
@@ -76,9 +76,9 @@ class JobSpecAndQueue:
76
76
  def _convert_access(access: str) -> str:
77
77
  """Convert access string to a value accepted by wandb."""
78
78
  access = access.upper()
79
- assert (
80
- access == "PROJECT" or access == "USER"
81
- ), "Queue access must be either project or user"
79
+ assert access == "PROJECT" or access == "USER", (
80
+ "Queue access must be either project or user"
81
+ )
82
82
  return access
83
83
 
84
84
 
@@ -44,7 +44,9 @@ class JobAndRunStatusTracker:
44
44
  self.run_id is not None
45
45
  and self.project is not None
46
46
  and self.entity is not None
47
- ), "Job tracker does not contain run info. Update with run info before checking if run stopped"
47
+ ), (
48
+ "Job tracker does not contain run info. Update with run info before checking if run stopped"
49
+ )
48
50
  check_stop = event_loop_thread_exec(api.api.check_stop_requested)
49
51
  try:
50
52
  return bool(await check_stop(self.project, self.entity, self.run_id))
wandb/sdk/launch/utils.py CHANGED
@@ -380,9 +380,9 @@ def diff_pip_requirements(req_1: List[str], req_2: List[str]) -> Dict[str, str]:
380
380
  else:
381
381
  raise ValueError(f"Unable to parse pip requirements file line: {line}")
382
382
  if _name is not None:
383
- assert re.match(
384
- _VALID_PIP_PACKAGE_REGEX, _name
385
- ), f"Invalid pip package name {_name}"
383
+ assert re.match(_VALID_PIP_PACKAGE_REGEX, _name), (
384
+ f"Invalid pip package name {_name}"
385
+ )
386
386
  d[_name] = _version
387
387
  return d
388
388
 
@@ -25,17 +25,38 @@ In particular, it does not work with some combinations of pytest's
25
25
 
26
26
  from __future__ import annotations
27
27
 
28
+ import logging
28
29
  import sys
29
30
  import threading
30
31
  from typing import IO, AnyStr, Callable, Protocol
31
32
 
33
+ from . import wb_logging
34
+
35
+ _logger = logging.getLogger(__name__)
36
+
32
37
 
33
38
  class CannotCaptureConsoleError(Exception):
34
39
  """The module failed to patch stdout or stderr."""
35
40
 
36
41
 
37
42
  class _WriteCallback(Protocol):
38
- """A callback that receives intercepted bytes or string data."""
43
+ """A callback that receives intercepted bytes or string data.
44
+
45
+ This may be called from any thread, but is only called from one thread
46
+ at a time.
47
+
48
+ Note on errors: Any error raised during the callback will clear all
49
+ callbacks. This means that if a user presses Ctrl-C at an unlucky time
50
+ during a run, we will stop uploading console output---but it's not
51
+ likely to be a problem unless something catches the KeyboardInterrupt.
52
+
53
+ Regular Exceptions are caught and logged instead of bubbling up to the
54
+ user's print() statements; other exceptions like KeyboardInterrupt are
55
+ re-raised.
56
+
57
+ Callbacks should handle all exceptions---a callback that raises any
58
+ Exception is considered buggy.
59
+ """
39
60
 
40
61
  def __call__(
41
62
  self,
@@ -45,6 +66,8 @@ class _WriteCallback(Protocol):
45
66
  ) -> None:
46
67
  """Intercept data passed to `write()`.
47
68
 
69
+ See the protocol docstring for information about exceptions.
70
+
48
71
  Args:
49
72
  data: The object passed to stderr's or stdout's `write()`.
50
73
  written: The number of bytes or characters written.
@@ -52,7 +75,9 @@ class _WriteCallback(Protocol):
52
75
  """
53
76
 
54
77
 
55
- _module_lock = threading.Lock()
78
+ # A reentrant lock is used to catch callbacks that write to stderr/stdout.
79
+ _module_rlock = threading.RLock()
80
+ _is_writing = False
56
81
 
57
82
  _patch_exception: CannotCaptureConsoleError | None = None
58
83
 
@@ -67,9 +92,6 @@ def capture_stdout(callback: _WriteCallback) -> Callable[[], None]:
67
92
 
68
93
  Args:
69
94
  callback: A callback to invoke after running `sys.stdout.write`.
70
- This may be called from any thread, so it must be thread-safe.
71
- Exceptions are propagated to the caller of `write`.
72
- See `_WriteCallback` for the exact protocol.
73
95
 
74
96
  Returns:
75
97
  A function to uninstall the callback.
@@ -77,7 +99,7 @@ def capture_stdout(callback: _WriteCallback) -> Callable[[], None]:
77
99
  Raises:
78
100
  CannotCaptureConsoleError: If patching failed on import.
79
101
  """
80
- with _module_lock:
102
+ with _module_rlock:
81
103
  if _patch_exception:
82
104
  raise _patch_exception
83
105
 
@@ -92,9 +114,6 @@ def capture_stderr(callback: _WriteCallback) -> Callable[[], None]:
92
114
 
93
115
  Args:
94
116
  callback: A callback to invoke after running `sys.stderr.write`.
95
- This may be called from any thread, so it must be thread-safe.
96
- Exceptions are propagated to the caller of `write`.
97
- See `_WriteCallback` for the exact protocol.
98
117
 
99
118
  Returns:
100
119
  A function to uninstall the callback.
@@ -102,7 +121,7 @@ def capture_stderr(callback: _WriteCallback) -> Callable[[], None]:
102
121
  Raises:
103
122
  CannotCaptureConsoleError: If patching failed on import.
104
123
  """
105
- with _module_lock:
124
+ with _module_rlock:
106
125
  if _patch_exception:
107
126
  raise _patch_exception
108
127
 
@@ -125,11 +144,11 @@ def _insert_disposably(
125
144
  def dispose() -> None:
126
145
  nonlocal disposed
127
146
 
128
- with _module_lock:
147
+ with _module_rlock:
129
148
  if disposed:
130
149
  return
131
150
 
132
- del callback_dict[id]
151
+ callback_dict.pop(id, None)
133
152
 
134
153
  disposed = True
135
154
 
@@ -143,16 +162,44 @@ def _patch(
143
162
  ) -> None:
144
163
  orig_write: Callable[[AnyStr], int]
145
164
 
165
+ @wb_logging.log_to_all_runs()
146
166
  def write_with_callbacks(s: AnyStr, /) -> int:
167
+ global _is_writing
147
168
  n = orig_write(s)
148
169
 
149
- # We make a copy here because callbacks could, in theory, modify
150
- # the list of callbacks.
151
- with _module_lock:
152
- callbacks_copy = list(callbacks.values())
153
-
154
- for cb in callbacks_copy:
155
- cb(s, n)
170
+ # NOTE: Since _module_rlock is reentrant, this is safe. It will not
171
+ # deadlock if a callback invokes write() again.
172
+ with _module_rlock:
173
+ if _is_writing:
174
+ return n
175
+
176
+ _is_writing = True
177
+ try:
178
+ for cb in callbacks.values():
179
+ cb(s, n)
180
+
181
+ except BaseException as e:
182
+ # Clear all callbacks on any exception to avoid infinite loops:
183
+ #
184
+ # * If we re-raise, an exception handler is likely to print
185
+ # the exception to the console and trigger callbacks again
186
+ # * If we log, we can't guarantee that this doesn't print
187
+ # to console.
188
+ #
189
+ # This is especially important for KeyboardInterrupt.
190
+ _stderr_callbacks.clear()
191
+ _stdout_callbacks.clear()
192
+
193
+ if isinstance(e, Exception):
194
+ # We suppress Exceptions so that bugs in W&B code don't
195
+ # cause the user's print() statements to raise errors.
196
+ _logger.exception("Error in console callback, clearing all!")
197
+ else:
198
+ # Re-raise errors like KeyboardInterrupt.
199
+ raise
200
+
201
+ finally:
202
+ _is_writing = False
156
203
 
157
204
  return n
158
205
 
wandb/sdk/lib/printer.py CHANGED
@@ -349,13 +349,13 @@ class _PrinterTerm(Printer):
349
349
  text = text or " " * 79
350
350
  wandb.termlog(text)
351
351
 
352
- @override
353
352
  @property
353
+ @override
354
354
  def supports_html(self) -> bool:
355
355
  return False
356
356
 
357
- @override
358
357
  @property
358
+ @override
359
359
  def supports_unicode(self) -> bool:
360
360
  return wandb.util.is_unicode_safe(sys.stderr)
361
361
 
@@ -464,11 +464,10 @@ class _PrinterJupyter(Printer):
464
464
 
465
465
  if handle:
466
466
  yield _DynamicJupyterText(handle)
467
+ handle.update(self._ipython_display.HTML(""))
467
468
  else:
468
469
  yield None
469
470
 
470
- handle.update(self._ipython_display.HTML(""))
471
-
472
471
  @override
473
472
  def display(
474
473
  self,
@@ -483,13 +482,13 @@ class _PrinterJupyter(Printer):
483
482
  text = "<br>".join(text.splitlines())
484
483
  self._ipython_display.display(self._ipython_display.HTML(text))
485
484
 
486
- @override
487
485
  @property
486
+ @override
488
487
  def supports_html(self) -> bool:
489
488
  return True
490
489
 
491
- @override
492
490
  @property
491
+ @override
493
492
  def supports_unicode(self) -> bool:
494
493
  return True
495
494
 
@@ -540,7 +539,7 @@ class _PrinterJupyter(Printer):
540
539
  self._progress.update(percent_done, text)
541
540
 
542
541
  @override
543
- def progress_close(self, _: str | None = None) -> None:
542
+ def progress_close(self, text: str | None = None) -> None:
544
543
  if self._progress:
545
544
  self._progress.close()
546
545
 
wandb/sdk/lib/progress.py CHANGED
@@ -7,7 +7,6 @@ import contextlib
7
7
  import time
8
8
  from typing import Iterable, Iterator, NoReturn
9
9
 
10
- from wandb import env
11
10
  from wandb.proto import wandb_internal_pb2 as pb
12
11
  from wandb.sdk.interface import interface
13
12
  from wandb.sdk.lib import asyncio_compat
@@ -107,8 +106,7 @@ class ProgressPrinter:
107
106
  progress_text_area: p.DynamicText | None,
108
107
  default_text: str,
109
108
  ) -> None:
110
- # Not implemented by the legacy service.
111
- self._show_operation_stats = not env.is_require_legacy_service()
109
+ self._show_operation_stats = True
112
110
  self._printer = printer
113
111
  self._progress_text_area = progress_text_area
114
112
  self._default_text = default_text
@@ -0,0 +1,13 @@
1
+ """Constants determining what IPC methods are supported."""
2
+
3
+ import socket
4
+
5
+ SUPPORTS_UNIX = hasattr(socket, "AF_UNIX")
6
+ """Whether Unix sockets are supported.
7
+
8
+ AF_UNIX is not supported on Windows:
9
+ https://github.com/python/cpython/issues/77589
10
+
11
+ Windows has supported Unix sockets since ~2017, but support in Python is
12
+ missing as of 2025.
13
+ """