wandb 0.19.1rc1__py3-none-any.whl → 0.19.3__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (79):
  1. wandb/__init__.py +1 -7
  2. wandb/__init__.pyi +15 -7
  3. wandb/agents/pyagent.py +1 -1
  4. wandb/apis/importers/wandb.py +1 -1
  5. wandb/apis/public/files.py +1 -1
  6. wandb/apis/public/jobs.py +1 -1
  7. wandb/apis/public/runs.py +2 -7
  8. wandb/apis/reports/v1/__init__.py +1 -1
  9. wandb/apis/reports/v2/__init__.py +1 -1
  10. wandb/apis/workspaces/__init__.py +1 -1
  11. wandb/bin/gpu_stats +0 -0
  12. wandb/cli/beta.py +7 -4
  13. wandb/cli/cli.py +5 -7
  14. wandb/docker/__init__.py +4 -4
  15. wandb/integration/fastai/__init__.py +4 -6
  16. wandb/integration/keras/keras.py +5 -3
  17. wandb/integration/metaflow/metaflow.py +14 -16
  18. wandb/integration/prodigy/prodigy.py +3 -11
  19. wandb/integration/sagemaker/__init__.py +5 -3
  20. wandb/integration/sagemaker/config.py +17 -8
  21. wandb/integration/sagemaker/files.py +0 -1
  22. wandb/integration/sagemaker/resources.py +47 -18
  23. wandb/integration/torch/wandb_torch.py +1 -1
  24. wandb/proto/v3/wandb_internal_pb2.py +273 -235
  25. wandb/proto/v4/wandb_internal_pb2.py +222 -214
  26. wandb/proto/v5/wandb_internal_pb2.py +222 -214
  27. wandb/sdk/artifacts/artifact.py +3 -9
  28. wandb/sdk/backend/backend.py +1 -1
  29. wandb/sdk/data_types/base_types/wb_value.py +1 -1
  30. wandb/sdk/data_types/graph.py +2 -2
  31. wandb/sdk/data_types/saved_model.py +1 -1
  32. wandb/sdk/data_types/video.py +1 -1
  33. wandb/sdk/interface/interface.py +25 -25
  34. wandb/sdk/interface/interface_shared.py +21 -5
  35. wandb/sdk/internal/handler.py +19 -1
  36. wandb/sdk/internal/internal.py +1 -1
  37. wandb/sdk/internal/internal_api.py +4 -5
  38. wandb/sdk/internal/sample.py +2 -2
  39. wandb/sdk/internal/sender.py +1 -2
  40. wandb/sdk/internal/settings_static.py +3 -1
  41. wandb/sdk/internal/system/assets/disk.py +4 -4
  42. wandb/sdk/internal/system/assets/gpu.py +1 -1
  43. wandb/sdk/internal/system/assets/memory.py +1 -1
  44. wandb/sdk/internal/system/system_info.py +1 -1
  45. wandb/sdk/internal/system/system_monitor.py +3 -1
  46. wandb/sdk/internal/tb_watcher.py +1 -1
  47. wandb/sdk/launch/_project_spec.py +3 -3
  48. wandb/sdk/launch/builder/abstract.py +1 -1
  49. wandb/sdk/lib/apikey.py +2 -3
  50. wandb/sdk/lib/fsm.py +1 -1
  51. wandb/sdk/lib/gitlib.py +1 -1
  52. wandb/sdk/lib/gql_request.py +1 -1
  53. wandb/sdk/lib/interrupt.py +37 -0
  54. wandb/sdk/lib/lazyloader.py +1 -1
  55. wandb/sdk/lib/progress.py +7 -1
  56. wandb/sdk/lib/service_connection.py +1 -1
  57. wandb/sdk/lib/telemetry.py +1 -1
  58. wandb/sdk/service/_startup_debug.py +1 -1
  59. wandb/sdk/service/server_sock.py +3 -2
  60. wandb/sdk/service/service.py +1 -1
  61. wandb/sdk/service/streams.py +19 -17
  62. wandb/sdk/verify/verify.py +13 -13
  63. wandb/sdk/wandb_init.py +316 -246
  64. wandb/sdk/wandb_login.py +1 -1
  65. wandb/sdk/wandb_metadata.py +547 -0
  66. wandb/sdk/wandb_run.py +134 -39
  67. wandb/sdk/wandb_settings.py +7 -63
  68. wandb/sdk/wandb_setup.py +83 -82
  69. wandb/sdk/wandb_sweep.py +2 -2
  70. wandb/sdk/wandb_sync.py +15 -18
  71. wandb/sync/sync.py +10 -10
  72. wandb/util.py +11 -3
  73. wandb/wandb_agent.py +11 -16
  74. wandb/wandb_controller.py +7 -7
  75. {wandb-0.19.1rc1.dist-info → wandb-0.19.3.dist-info}/METADATA +5 -3
  76. {wandb-0.19.1rc1.dist-info → wandb-0.19.3.dist-info}/RECORD +79 -77
  77. {wandb-0.19.1rc1.dist-info → wandb-0.19.3.dist-info}/WHEEL +1 -1
  78. {wandb-0.19.1rc1.dist-info → wandb-0.19.3.dist-info}/entry_points.txt +0 -0
  79. {wandb-0.19.1rc1.dist-info → wandb-0.19.3.dist-info}/licenses/LICENSE +0 -0
@@ -1685,11 +1685,7 @@ class Artifact:
1685
1685
  from wandb.sdk.backend.backend import Backend
1686
1686
 
1687
1687
  if wandb.run is None:
1688
- # ensure wandb-core is up and running
1689
- from wandb.sdk import wandb_setup
1690
-
1691
- wl = wandb_setup.setup()
1692
- assert wl is not None
1688
+ wl = wandb.setup()
1693
1689
 
1694
1690
  stream_id = generate_id()
1695
1691
 
@@ -1702,9 +1698,7 @@ class Artifact:
1702
1698
  settings.files_dir.value = str(tmp_dir / "files")
1703
1699
  settings.run_id.value = stream_id
1704
1700
 
1705
- service = wl.service
1706
- assert service
1707
-
1701
+ service = wl.ensure_service()
1708
1702
  service.inform_init(settings=settings, run_id=stream_id)
1709
1703
 
1710
1704
  mailbox = Mailbox()
@@ -1941,7 +1935,7 @@ class Artifact:
1941
1935
  else:
1942
1936
  ref_count += 1
1943
1937
  if ref_count > 0:
1944
- print("Warning: skipped verification of {} refs".format(ref_count))
1938
+ termwarn(f"skipped verification of {ref_count} refs")
1945
1939
 
1946
1940
  @ensure_logged
1947
1941
  def file(self, root: str | None = None) -> StrPath:
@@ -145,7 +145,7 @@ class Backend:
145
145
  return
146
146
 
147
147
  assert self._settings
148
- settings = self._settings.copy()
148
+ settings = self._settings.model_copy()
149
149
  settings.x_log_level = self._log_level or logging.DEBUG
150
150
 
151
151
  start_method = settings.start_method
@@ -135,7 +135,7 @@ class WBValue:
135
135
  def init_from_json(
136
136
  json_obj: dict, source_artifact: "Artifact"
137
137
  ) -> Optional["WBValue"]:
138
- """Initialize a `WBValue` from a JSON blob based on the class that creatd it.
138
+ """Initialize a `WBValue` from a JSON blob based on the class that created it.
139
139
 
140
140
  Looks through all subclasses and tries to match the json obj with the class
141
141
  which created it. It will then call that subclass' `from_json` method.
@@ -311,9 +311,9 @@ class Graph(Media):
311
311
 
312
312
  def pprint(self):
313
313
  for edge in self.edges:
314
- pprint.pprint(edge.attributes)
314
+ pprint.pprint(edge.attributes) # noqa: T203
315
315
  for node in self.nodes:
316
- pprint.pprint(node.attributes)
316
+ pprint.pprint(node.attributes) # noqa: T203
317
317
 
318
318
  def add_node(self, node=None, **node_kwargs):
319
319
  if node is None:
@@ -231,7 +231,7 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
231
231
  return cls(obj_or_path, **kwargs)
232
232
  except Exception as e:
233
233
  if DEBUG_MODE:
234
- print(f"{cls}._maybe_init({obj_or_path}) failed: {e}")
234
+ print(f"{cls}._maybe_init({obj_or_path}) failed: {e}") # noqa: T201
235
235
 
236
236
  for child_cls in cls.__subclasses__():
237
237
  maybe_instance = child_cls._maybe_init(obj_or_path, **kwargs)
@@ -212,7 +212,7 @@ class Video(BatchableMedia):
212
212
  )
213
213
  if video.ndim < 4:
214
214
  raise ValueError(
215
- "Video must be atleast 4 dimensions: time, channels, height, width"
215
+ "Video must be at least 4 dimensions: time, channels, height, width"
216
216
  )
217
217
  if video.ndim == 4:
218
218
  video = video.reshape(1, *video.shape)
@@ -216,6 +216,13 @@ class InterfaceBase:
216
216
  def _publish_config(self, cfg: pb.ConfigRecord) -> None:
217
217
  raise NotImplementedError
218
218
 
219
+ def publish_metadata(self, metadata: pb.MetadataRequest) -> None:
220
+ self._publish_metadata(metadata)
221
+
222
+ @abstractmethod
223
+ def _publish_metadata(self, metadata: pb.MetadataRequest) -> None:
224
+ raise NotImplementedError
225
+
219
226
  @abstractmethod
220
227
  def _publish_metric(self, metric: pb.MetricRecord) -> None:
221
228
  raise NotImplementedError
@@ -722,7 +729,7 @@ class InterfaceBase:
722
729
  otype = pb.OutputRecord.OutputType.STDERR
723
730
  else:
724
731
  # TODO(jhr): throw error?
725
- print("unknown type")
732
+ termwarn("unknown type")
726
733
  o = pb.OutputRecord(output_type=otype, line=data)
727
734
  o.timestamp.GetCurrentTime()
728
735
  self._publish_output(o)
@@ -742,7 +749,7 @@ class InterfaceBase:
742
749
  otype = pb.OutputRawRecord.OutputType.STDERR
743
750
  else:
744
751
  # TODO(jhr): throw error?
745
- print("unknown type")
752
+ termwarn("unknown type")
746
753
  o = pb.OutputRawRecord(output_type=otype, line=data)
747
754
  o.timestamp.GetCurrentTime()
748
755
  self._publish_output_raw(o)
@@ -872,31 +879,14 @@ class InterfaceBase:
872
879
  run_record = self._make_run(run)
873
880
  return self._deliver_run(run_record)
874
881
 
875
- def deliver_sync(
882
+ def deliver_finish_sync(
876
883
  self,
877
- start_offset: int,
878
- final_offset: int,
879
- entity: Optional[str] = None,
880
- project: Optional[str] = None,
881
- run_id: Optional[str] = None,
882
- skip_output_raw: Optional[bool] = None,
883
884
  ) -> MailboxHandle:
884
- sync = pb.SyncRequest(
885
- start_offset=start_offset,
886
- final_offset=final_offset,
887
- )
888
- if entity:
889
- sync.overwrite.entity = entity
890
- if project:
891
- sync.overwrite.project = project
892
- if run_id:
893
- sync.overwrite.run_id = run_id
894
- if skip_output_raw:
895
- sync.skip.output_raw = skip_output_raw
896
- return self._deliver_sync(sync)
885
+ sync = pb.SyncFinishRequest()
886
+ return self._deliver_finish_sync(sync)
897
887
 
898
888
  @abstractmethod
899
- def _deliver_sync(self, sync: pb.SyncRequest) -> MailboxHandle:
889
+ def _deliver_finish_sync(self, sync: pb.SyncFinishRequest) -> MailboxHandle:
900
890
  raise NotImplementedError
901
891
 
902
892
  @abstractmethod
@@ -954,8 +944,8 @@ class InterfaceBase:
954
944
  raise NotImplementedError
955
945
 
956
946
  def deliver_get_system_metrics(self) -> MailboxHandle:
957
- get_summary = pb.GetSystemMetricsRequest()
958
- return self._deliver_get_system_metrics(get_summary)
947
+ get_system_metrics = pb.GetSystemMetricsRequest()
948
+ return self._deliver_get_system_metrics(get_system_metrics)
959
949
 
960
950
  @abstractmethod
961
951
  def _deliver_get_system_metrics(
@@ -963,6 +953,16 @@ class InterfaceBase:
963
953
  ) -> MailboxHandle:
964
954
  raise NotImplementedError
965
955
 
956
+ def deliver_get_system_metadata(self) -> MailboxHandle:
957
+ get_system_metadata = pb.GetSystemMetadataRequest()
958
+ return self._deliver_get_system_metadata(get_system_metadata)
959
+
960
+ @abstractmethod
961
+ def _deliver_get_system_metadata(
962
+ self, get_system_metadata: pb.GetSystemMetadataRequest
963
+ ) -> MailboxHandle:
964
+ raise NotImplementedError
965
+
966
966
  def deliver_exit(self, exit_code: Optional[int]) -> MailboxHandle:
967
967
  exit_data = self._make_exit(exit_code)
968
968
  return self._deliver_exit(exit_data)
@@ -145,15 +145,17 @@ class InterfaceShared(InterfaceBase):
145
145
  run_status: Optional[pb.RunStatusRequest] = None,
146
146
  sender_mark: Optional[pb.SenderMarkRequest] = None,
147
147
  sender_read: Optional[pb.SenderReadRequest] = None,
148
- sync: Optional[pb.SyncRequest] = None,
148
+ sync_finish: Optional[pb.SyncFinishRequest] = None,
149
149
  status_report: Optional[pb.StatusReportRequest] = None,
150
150
  cancel: Optional[pb.CancelRequest] = None,
151
151
  summary_record: Optional[pb.SummaryRecordRequest] = None,
152
152
  telemetry_record: Optional[pb.TelemetryRecordRequest] = None,
153
153
  get_system_metrics: Optional[pb.GetSystemMetricsRequest] = None,
154
+ get_system_metadata: Optional[pb.GetSystemMetadataRequest] = None,
154
155
  python_packages: Optional[pb.PythonPackagesRequest] = None,
155
156
  job_input: Optional[pb.JobInputRequest] = None,
156
157
  run_finish_without_exit: Optional[pb.RunFinishWithoutExitRequest] = None,
158
+ metadata: Optional[pb.MetadataRequest] = None,
157
159
  ) -> pb.Record:
158
160
  request = pb.Request()
159
161
  if login:
@@ -212,14 +214,18 @@ class InterfaceShared(InterfaceBase):
212
214
  request.telemetry_record.CopyFrom(telemetry_record)
213
215
  elif get_system_metrics:
214
216
  request.get_system_metrics.CopyFrom(get_system_metrics)
215
- elif sync:
216
- request.sync.CopyFrom(sync)
217
+ elif get_system_metadata:
218
+ request.get_system_metadata.CopyFrom(get_system_metadata)
219
+ elif sync_finish:
220
+ request.sync_finish.CopyFrom(sync_finish)
217
221
  elif python_packages:
218
222
  request.python_packages.CopyFrom(python_packages)
219
223
  elif job_input:
220
224
  request.job_input.CopyFrom(job_input)
221
225
  elif run_finish_without_exit:
222
226
  request.run_finish_without_exit.CopyFrom(run_finish_without_exit)
227
+ elif metadata:
228
+ request.metadata.CopyFrom(metadata)
223
229
  else:
224
230
  raise Exception("Invalid request")
225
231
  record = self._make_record(request=request)
@@ -377,6 +383,10 @@ class InterfaceShared(InterfaceBase):
377
383
  rec = self._make_record(summary=summary)
378
384
  self._publish(rec)
379
385
 
386
+ def _publish_metadata(self, metadata: pb.MetadataRequest) -> None:
387
+ rec = self._make_request(metadata=metadata)
388
+ self._publish(rec)
389
+
380
390
  def _publish_metric(self, metric: pb.MetricRecord) -> None:
381
391
  rec = self._make_record(metric=metric)
382
392
  self._publish(rec)
@@ -459,8 +469,8 @@ class InterfaceShared(InterfaceBase):
459
469
  record = self._make_record(run=run)
460
470
  return self._deliver_record(record)
461
471
 
462
- def _deliver_sync(self, sync: pb.SyncRequest) -> MailboxHandle:
463
- record = self._make_request(sync=sync)
472
+ def _deliver_finish_sync(self, sync_finish: pb.SyncFinishRequest) -> MailboxHandle:
473
+ record = self._make_request(sync_finish=sync_finish)
464
474
  return self._deliver_record(record)
465
475
 
466
476
  def _deliver_run_start(self, run_start: pb.RunStartRequest) -> MailboxHandle:
@@ -477,6 +487,12 @@ class InterfaceShared(InterfaceBase):
477
487
  record = self._make_request(get_system_metrics=get_system_metrics)
478
488
  return self._deliver_record(record)
479
489
 
490
+ def _deliver_get_system_metadata(
491
+ self, get_system_metadata: pb.GetSystemMetadataRequest
492
+ ) -> MailboxHandle:
493
+ record = self._make_request(get_system_metadata=get_system_metadata)
494
+ return self._deliver_record(record)
495
+
480
496
  def _deliver_exit(self, exit_data: pb.RunExitRecord) -> MailboxHandle:
481
497
  record = self._make_record(exit=exit_data)
482
498
  return self._deliver_record(record)
@@ -39,6 +39,7 @@ from wandb.proto.wandb_internal_pb2 import (
39
39
 
40
40
  from ..interface.interface_queue import InterfaceQueue
41
41
  from ..lib import handler_util, proto_util
42
+ from ..wandb_metadata import Metadata
42
43
  from . import context, sample, tb_watcher
43
44
  from .settings_static import SettingsStatic
44
45
  from .system.system_monitor import SystemMonitor
@@ -119,6 +120,7 @@ class HandleManager:
119
120
 
120
121
  self._tb_watcher = None
121
122
  self._system_monitor = None
123
+ self._metadata: Optional[Metadata] = None
122
124
  self._step = 0
123
125
 
124
126
  self._track_time = None
@@ -176,6 +178,9 @@ class HandleManager:
176
178
  def handle_request_cancel(self, record: Record) -> None:
177
179
  self._dispatch_record(record)
178
180
 
181
+ def handle_request_metadata(self, record: Record) -> None:
182
+ logger.warning("Metadata updates are ignored when using the legacy service.")
183
+
179
184
  def handle_request_defer(self, record: Record) -> None:
180
185
  defer = record.request.defer
181
186
  state = defer.state
@@ -700,7 +705,10 @@ class HandleManager:
700
705
  not (self._settings.x_disable_meta or self._settings.x_disable_machine_info)
701
706
  and not run_start.run.resumed
702
707
  ):
703
- self._system_monitor.probe(publish=True)
708
+ try:
709
+ self._metadata = Metadata(**self._system_monitor.probe(publish=True))
710
+ except Exception as e:
711
+ logger.error("Error probing system metadata: %s", e)
704
712
 
705
713
  self._tb_watcher = tb_watcher.TBWatcher(
706
714
  self._settings, interface=self._interface, run_proto=run_start.run
@@ -778,6 +786,16 @@ class HandleManager:
778
786
 
779
787
  self._respond_result(result)
780
788
 
789
+ def handle_request_get_system_metadata(self, record: Record) -> None:
790
+ result = proto_util._result_from_record(record)
791
+ if self._system_monitor is None or self._metadata is None:
792
+ return
793
+
794
+ result.response.get_system_metadata_response.metadata.CopyFrom(
795
+ self._metadata.to_proto()
796
+ )
797
+ self._respond_result(result)
798
+
781
799
  def handle_tbrecord(self, record: Record) -> None:
782
800
  logger.info("handling tbrecord: %s", record)
783
801
  if self._tb_watcher:
@@ -165,7 +165,7 @@ def wandb_internal(
165
165
  exc_info = thread.get_exception()
166
166
  if exc_info:
167
167
  logger.error(f"Thread {thread.name}:", exc_info=exc_info)
168
- print(f"Thread {thread.name}:", file=sys.stderr)
168
+ print(f"Thread {thread.name}:", file=sys.stderr) # noqa: T201
169
169
  traceback.print_exception(*exc_info)
170
170
  wandb._sentry.exception(exc_info)
171
171
  wandb.termerror("Internal wandb error: file data was not synced")
@@ -400,9 +400,6 @@ class Api:
400
400
  wandb.termerror(f"Error while calling W&B API: {error} ({response})")
401
401
  raise
402
402
 
403
- def disabled(self) -> Union[str, bool]:
404
- return self._settings.get(Settings.DEFAULT_SECTION, "disabled", fallback=False) # type: ignore
405
-
406
403
  def set_current_run_id(self, run_id: str) -> None:
407
404
  self._current_run_id = run_id
408
405
 
@@ -2321,7 +2318,9 @@ class Api:
2321
2318
  "commit": commit,
2322
2319
  "displayName": display_name,
2323
2320
  "notes": notes,
2324
- "host": None if self.settings().get("anonymous") == "true" else host,
2321
+ "host": None
2322
+ if self.settings().get("anonymous") in ["allow", "must"]
2323
+ else host,
2325
2324
  "debug": env.is_debug(env=self._environ),
2326
2325
  "repo": repo,
2327
2326
  "program": program_path,
@@ -3468,7 +3467,7 @@ class Api:
3468
3467
  else open(normal_name, "rb")
3469
3468
  )
3470
3469
  except OSError:
3471
- print(f"{file_name} does not exist")
3470
+ print(f"{file_name} does not exist") # noqa: T201
3472
3471
  continue
3473
3472
  if progress is False:
3474
3473
  responses.append(
@@ -30,11 +30,11 @@ class UniformSampleAccumulator:
30
30
  self._log2 += [int(math.log(i, 2)) for i in range(1, 2**self._buckets + 1)]
31
31
 
32
32
  def _show(self):
33
- print("=" * 20)
33
+ print("=" * 20) # noqa: T201
34
34
  for b in range(self._buckets):
35
35
  b = (b + self._buckets_index) % self._buckets
36
36
  vals = [self._bucket[b][i] for i in range(self._index[b])]
37
- print(f"{b}: {vals}")
37
+ print(f"{b}: {vals}") # noqa: T201
38
38
 
39
39
  def add(self, val):
40
40
  self._count += 1
@@ -323,7 +323,6 @@ class SendManager:
323
323
 
324
324
  Exclusively used in `sync.py`.
325
325
  """
326
- print(root_dir)
327
326
  files_dir = os.path.join(root_dir, "files")
328
327
  settings = wandb.Settings(
329
328
  x_files_dir=files_dir,
@@ -1339,7 +1338,7 @@ class SendManager:
1339
1338
  if not line.endswith("\n"):
1340
1339
  self._partial_output.setdefault(stream, "")
1341
1340
  if line.startswith("\r"):
1342
- # TODO: maybe we shouldnt just drop this, what if there was some \ns in the partial
1341
+ # TODO: maybe we shouldn't just drop this, what if there was some \ns in the partial
1343
1342
  # that should probably be the check instead of not line.endswith(\n")
1344
1343
  # logger.info(f"Dropping data {self._partial_output[stream]}")
1345
1344
  self._partial_output[stream] = ""
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  from typing import Any, Iterable
2
4
 
3
5
  from wandb.proto import wandb_settings_pb2
@@ -78,7 +80,7 @@ class SettingsStatic(Settings):
78
80
  def __setitem__(self, key: str, val: object) -> None:
79
81
  raise AttributeError("Error: SettingsStatic is a readonly object")
80
82
 
81
- def keys(self) -> "Iterable[str]":
83
+ def keys(self) -> Iterable[str]:
82
84
  return self.__dict__.keys()
83
85
 
84
86
  def __getitem__(self, key: str) -> Any:
@@ -190,10 +190,10 @@ class Disk:
190
190
  disk_metrics = {}
191
191
  for disk_path in disk_paths:
192
192
  try:
193
- # total disk space in GB:
194
- total = psutil.disk_usage(disk_path).total / 1024 / 1024 / 1024
195
- # total disk space used in GB:
196
- used = psutil.disk_usage(disk_path).used / 1024 / 1024 / 1024
193
+ # total disk space in Bytes:
194
+ total = psutil.disk_usage(disk_path).total
195
+ # total disk space used in Bytes:
196
+ used = psutil.disk_usage(disk_path).used
197
197
  disk_metrics[disk_path] = {
198
198
  "total": total,
199
199
  "used": used,
@@ -32,7 +32,7 @@ def gpu_in_use_by_this_process(gpu_handle: "GPUHandle", pid: int) -> bool:
32
32
  try:
33
33
  base_process = psutil.Process(pid=pid)
34
34
  except psutil.NoSuchProcess:
35
- # do not report any gpu metrics if the base process cant be found
35
+ # do not report any gpu metrics if the base process can't be found
36
36
  return False
37
37
 
38
38
  our_processes = base_process.children(recursive=True)
@@ -161,6 +161,6 @@ class Memory:
161
161
  # total available memory in gigabytes
162
162
  return {
163
163
  "memory": {
164
- "total": psutil.virtual_memory().total / 1024 / 1024 / 1024,
164
+ "total": psutil.virtual_memory().total,
165
165
  }
166
166
  }
@@ -190,7 +190,7 @@ class SystemInfo:
190
190
  # get the git repo info
191
191
  data = self._probe_git(data)
192
192
 
193
- if self.settings.anonymous != "true":
193
+ if self.settings.anonymous not in ["allow", "must"]:
194
194
  data["host"] = self.settings.host
195
195
  data["username"] = self.settings.username
196
196
  data["executable"] = sys.executable
@@ -203,7 +203,7 @@ class SystemMonitor:
203
203
  logger.error(f"Error joining system monitor process: {e}")
204
204
  self._process = None
205
205
 
206
- def probe(self, publish: bool = True) -> None:
206
+ def probe(self, publish: bool = True) -> dict:
207
207
  logger.info("Collecting system info")
208
208
  # collect static info about the hardware from registered assets
209
209
  hardware_info: dict = {
@@ -220,3 +220,5 @@ class SystemMonitor:
220
220
  logger.info("Publishing system info")
221
221
  self.system_info.publish(system_info)
222
222
  logger.info("Finished publishing system info")
223
+
224
+ return system_info
@@ -492,7 +492,7 @@ class TBHistory:
492
492
  self._step, len(dropped_keys)
493
493
  )
494
494
  )
495
- print("\t" + ("\n\t".join(dropped_keys)))
495
+ print("\t" + ("\n\t".join(dropped_keys))) # noqa: T201
496
496
  self._data["_step"] = self._step
497
497
  self._added.append(self._data)
498
498
  self._step += 1
@@ -45,9 +45,9 @@ class LaunchSource(enum.IntEnum):
45
45
  SCHEDULER: Source is a wandb sweep scheduler command.
46
46
  """
47
47
 
48
- DOCKER: int = 1
49
- JOB: int = 2
50
- SCHEDULER: int = 3
48
+ DOCKER = 1
49
+ JOB = 2
50
+ SCHEDULER = 3
51
51
 
52
52
 
53
53
  class LaunchProject:
@@ -116,7 +116,7 @@ def registry_from_uri(uri: str) -> AbstractRegistry:
116
116
  it as an AWS Elastic Container Registry. If the uri contains
117
117
  `-docker.pkg.dev`, we classify it as a Google Artifact Registry.
118
118
 
119
- This function will attempt to load the approriate cloud helpers for the
119
+ This function will attempt to load the appropriate cloud helpers for the
120
120
 
121
121
  `https://` prefix is optional for all of the above.
122
122
 
wandb/sdk/lib/apikey.py CHANGED
@@ -250,7 +250,7 @@ def write_key(
250
250
  )
251
251
 
252
252
  if anonymous:
253
- api.set_setting("anonymous", "true", globally=True, persist=True)
253
+ api.set_setting("anonymous", "must", globally=True, persist=True)
254
254
  else:
255
255
  api.clear_setting("anonymous", globally=True, persist=True)
256
256
 
@@ -259,8 +259,7 @@ def write_key(
259
259
 
260
260
  def api_key(settings: Optional["Settings"] = None) -> Optional[str]:
261
261
  if settings is None:
262
- settings = wandb.setup().settings # type: ignore
263
- assert settings is not None
262
+ settings = wandb.setup().settings
264
263
  if settings.api_key:
265
264
  return settings.api_key
266
265
  auth = get_netrc_auth(settings.base_url)
wandb/sdk/lib/fsm.py CHANGED
@@ -93,7 +93,7 @@ class FsmStateExit(Protocol[T_FsmInputs, T_FsmContext_cov]):
93
93
  def on_exit(self, inputs: T_FsmInputs) -> T_FsmContext_cov: ... # pragma: no cover
94
94
 
95
95
 
96
- # It would be nice if python provided optional protocol members, but it doesnt as described here:
96
+ # It would be nice if python provided optional protocol members, but it does not as described here:
97
97
  # https://peps.python.org/pep-0544/#support-optional-protocol-members
98
98
  # Until then, we can only enforce that a state at least supports one protocol interface. This
99
99
  # unfortunately will not check the signature of other potential protocols.
wandb/sdk/lib/gitlib.py CHANGED
@@ -226,7 +226,7 @@ class GitRepo:
226
226
  try:
227
227
  return self.repo.create_tag(f"wandb/{name}", message=message, force=True)
228
228
  except GitCommandError:
229
- print("Failed to tag repository.")
229
+ logger.debug("Failed to tag repository.")
230
230
  return None
231
231
 
232
232
  def push(self, name: str) -> Any:
@@ -1,7 +1,7 @@
1
1
  """A simple GraphQL client for sending queries and mutations.
2
2
 
3
3
  Note: This was originally wandb/vendor/gql-0.2.0/wandb_gql/transport/requests.py
4
- The only substantial change is to re-use a requests.Session object.
4
+ The only substantial change is to reuse a requests.Session object.
5
5
  """
6
6
 
7
7
  from typing import Any, Callable, Dict, Optional, Tuple, Union
@@ -0,0 +1,37 @@
1
+ """Utility to send an interrupt (Ctrl+C) signal to the main thread.
2
+
3
+ This is necessary because Windows and POSIX use different models for Ctrl+C
4
+ interrupts.
5
+ """
6
+
7
+ import platform
8
+ import signal
9
+ import threading
10
+
11
+
12
+ def interrupt_main():
13
+ """Interrupt the main Python thread with a SIGINT signal.
14
+
15
+ In POSIX, signal.pthread_kill() is the most reliable way to send a signal
16
+ to the main thread.
17
+
18
+ os.kill() is often recommended, but it isn't guaranteed to deliver the
19
+ signal to the main OS thread. Likewise, signal.raise_signal() delivers
20
+ the signal to the current thread in POSIX. The issue is that if any other
21
+ thread receives the signal, Python will set an internal flag and process it
22
+ on the main thread at the next opportunity. If the main thread is executing
23
+ C code or is blocked on a syscall (e.g. time.sleep(999999)) the signal
24
+ handler won't execute until that's done---i.e. Python won't preempt the OS
25
+ thread on its own.
26
+
27
+ On Windows, pthread_kill is not available and os.kill() ignores its
28
+ second argument and always kills the process. However,
29
+ signal.raise_signal() does the right thing.
30
+ """
31
+ if platform.system() == "Windows":
32
+ signal.raise_signal(signal.SIGINT)
33
+ else:
34
+ signal.pthread_kill(
35
+ threading.main_thread().ident,
36
+ signal.SIGINT,
37
+ )
@@ -38,7 +38,7 @@ class LazyLoader(types.ModuleType):
38
38
 
39
39
  # Emit a warning if one was specified
40
40
  if self._warning:
41
- print(self._warning)
41
+ print(self._warning) # noqa: T201
42
42
  # Make sure to only warn once.
43
43
  self._warning = None
44
44
 
wandb/sdk/lib/progress.py CHANGED
@@ -6,6 +6,7 @@ import contextlib
6
6
  from typing import Iterable, Iterator
7
7
 
8
8
  import wandb
9
+ from wandb import env
9
10
  from wandb.proto import wandb_internal_pb2 as pb
10
11
 
11
12
  from . import printer as p
@@ -50,7 +51,12 @@ class ProgressPrinter:
50
51
  progress_text_area: p.DynamicText | None,
51
52
  settings: wandb.Settings | None,
52
53
  ) -> None:
53
- self._show_operation_stats = settings and settings.x_show_operation_stats
54
+ self._show_operation_stats = (
55
+ settings
56
+ and settings.x_show_operation_stats
57
+ # Not implemented by the legacy service.
58
+ and not env.is_require_legacy_service()
59
+ )
54
60
  self._printer = printer
55
61
  self._progress_text_area = progress_text_area
56
62
  self._tick = 0
@@ -41,7 +41,7 @@ def connect_to_service(
41
41
 
42
42
 
43
43
  def _try_connect_to_existing_service() -> ServiceConnection | None:
44
- """Attemps to connect to an existing service process."""
44
+ """Attempts to connect to an existing service process."""
45
45
  token = service_token.get_service_token()
46
46
  if not token:
47
47
  return None
@@ -65,7 +65,7 @@ def _parse_label_lines(lines: List[str]) -> Dict[str, str]:
65
65
  label_str = line[idx + len(_LABEL_TOKEN) :]
66
66
 
67
67
  # match identifier (first token without key=value syntax (optional)
68
- # Note: Parse is fairly permissive as it doesnt enforce strict syntax
68
+ # Note: Parse is fairly permissive as it does not enforce strict syntax
69
69
  r = MATCH_RE.match(label_str)
70
70
  if r:
71
71
  ret["code"] = r.group("code").replace("-", "_")
@@ -19,4 +19,4 @@ def is_enabled() -> bool:
19
19
 
20
20
  def print_message(message: str) -> None:
21
21
  time_now = time.time()
22
- print("WANDB_STARTUP_DEBUG", time_now, message)
22
+ print("WANDB_STARTUP_DEBUG", time_now, message) # noqa: T201