wandb 0.16.5__py3-none-any.whl → 0.17.0rc1__py3-none-any.whl

Files changed (141)
  1. package_readme.md +95 -0
  2. wandb/__init__.py +2 -2
  3. wandb/agents/pyagent.py +0 -1
  4. wandb/analytics/sentry.py +2 -1
  5. wandb/apis/importers/internals/protocols.py +30 -56
  6. wandb/apis/importers/mlflow.py +13 -26
  7. wandb/apis/importers/wandb.py +8 -14
  8. wandb/apis/public/api.py +1 -0
  9. wandb/apis/public/artifacts.py +1 -0
  10. wandb/apis/public/files.py +1 -0
  11. wandb/apis/public/history.py +1 -0
  12. wandb/apis/public/jobs.py +1 -0
  13. wandb/apis/public/projects.py +1 -0
  14. wandb/apis/public/reports.py +1 -0
  15. wandb/apis/public/runs.py +1 -0
  16. wandb/apis/public/sweeps.py +1 -0
  17. wandb/apis/public/teams.py +1 -0
  18. wandb/apis/public/users.py +1 -0
  19. wandb/apis/reports/v1/_blocks.py +2 -6
  20. wandb/apis/reports/v2/gql.py +1 -0
  21. wandb/apis/reports/v2/interface.py +3 -4
  22. wandb/apis/reports/v2/internal.py +5 -8
  23. wandb/cli/cli.py +7 -4
  24. wandb/data_types.py +3 -3
  25. wandb/env.py +35 -5
  26. wandb/errors/__init__.py +5 -0
  27. wandb/integration/catboost/catboost.py +1 -1
  28. wandb/integration/fastai/__init__.py +1 -0
  29. wandb/integration/keras/__init__.py +1 -0
  30. wandb/integration/keras/keras.py +6 -6
  31. wandb/integration/langchain/wandb_tracer.py +1 -0
  32. wandb/integration/lightning/fabric/logger.py +1 -3
  33. wandb/integration/metaflow/metaflow.py +41 -6
  34. wandb/integration/openai/fine_tuning.py +77 -40
  35. wandb/keras/__init__.py +1 -0
  36. wandb/proto/v3/wandb_internal_pb2.py +364 -332
  37. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  38. wandb/proto/v4/wandb_internal_pb2.py +322 -316
  39. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  40. wandb/proto/wandb_internal_codegen.py +0 -25
  41. wandb/sdk/artifacts/artifact.py +41 -13
  42. wandb/sdk/artifacts/artifact_download_logger.py +1 -0
  43. wandb/sdk/artifacts/artifact_file_cache.py +18 -4
  44. wandb/sdk/artifacts/artifact_instance_cache.py +1 -0
  45. wandb/sdk/artifacts/artifact_manifest.py +1 -0
  46. wandb/sdk/artifacts/artifact_manifest_entry.py +1 -0
  47. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
  48. wandb/sdk/artifacts/artifact_saver.py +21 -21
  49. wandb/sdk/artifacts/artifact_state.py +1 -0
  50. wandb/sdk/artifacts/artifact_ttl.py +1 -0
  51. wandb/sdk/artifacts/exceptions.py +1 -0
  52. wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
  53. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +13 -18
  54. wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -0
  55. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -0
  56. wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -0
  57. wandb/sdk/artifacts/storage_handlers/s3_handler.py +5 -3
  58. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +1 -0
  59. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -0
  60. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -0
  61. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +1 -0
  62. wandb/sdk/artifacts/storage_policy.py +1 -0
  63. wandb/sdk/data_types/base_types/media.py +3 -6
  64. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +3 -1
  65. wandb/sdk/integration_utils/auto_logging.py +5 -6
  66. wandb/sdk/integration_utils/data_logging.py +5 -1
  67. wandb/sdk/interface/interface.py +72 -37
  68. wandb/sdk/interface/interface_shared.py +7 -13
  69. wandb/sdk/internal/datastore.py +1 -1
  70. wandb/sdk/internal/handler.py +18 -2
  71. wandb/sdk/internal/internal.py +0 -1
  72. wandb/sdk/internal/internal_util.py +0 -1
  73. wandb/sdk/internal/job_builder.py +4 -3
  74. wandb/sdk/internal/profiler.py +1 -0
  75. wandb/sdk/internal/run.py +1 -0
  76. wandb/sdk/internal/sender.py +1 -1
  77. wandb/sdk/internal/system/assets/gpu_amd.py +44 -44
  78. wandb/sdk/internal/system/assets/gpu_apple.py +56 -11
  79. wandb/sdk/internal/system/assets/interfaces.py +6 -8
  80. wandb/sdk/internal/system/assets/open_metrics.py +2 -2
  81. wandb/sdk/internal/system/assets/trainium.py +1 -3
  82. wandb/sdk/launch/_launch.py +5 -0
  83. wandb/sdk/launch/_project_spec.py +10 -23
  84. wandb/sdk/launch/agent/agent.py +81 -37
  85. wandb/sdk/launch/agent/config.py +80 -11
  86. wandb/sdk/launch/builder/abstract.py +1 -0
  87. wandb/sdk/launch/builder/build.py +28 -1
  88. wandb/sdk/launch/builder/docker_builder.py +1 -0
  89. wandb/sdk/launch/builder/kaniko_builder.py +149 -134
  90. wandb/sdk/launch/builder/noop.py +1 -0
  91. wandb/sdk/launch/create_job.py +61 -48
  92. wandb/sdk/launch/environment/abstract.py +1 -0
  93. wandb/sdk/launch/environment/gcp_environment.py +1 -0
  94. wandb/sdk/launch/environment/local_environment.py +1 -0
  95. wandb/sdk/launch/loader.py +1 -0
  96. wandb/sdk/launch/registry/abstract.py +1 -0
  97. wandb/sdk/launch/registry/azure_container_registry.py +1 -0
  98. wandb/sdk/launch/registry/elastic_container_registry.py +1 -0
  99. wandb/sdk/launch/registry/google_artifact_registry.py +1 -0
  100. wandb/sdk/launch/registry/local_registry.py +1 -0
  101. wandb/sdk/launch/runner/abstract.py +1 -0
  102. wandb/sdk/launch/runner/kubernetes_monitor.py +4 -1
  103. wandb/sdk/launch/runner/kubernetes_runner.py +4 -3
  104. wandb/sdk/launch/runner/sagemaker_runner.py +11 -10
  105. wandb/sdk/launch/sweeps/scheduler.py +4 -1
  106. wandb/sdk/launch/sweeps/scheduler_sweep.py +1 -0
  107. wandb/sdk/launch/sweeps/utils.py +1 -1
  108. wandb/sdk/launch/utils.py +21 -3
  109. wandb/sdk/lib/_settings_toposort_generated.py +1 -0
  110. wandb/sdk/lib/fsm.py +8 -12
  111. wandb/sdk/lib/gitlib.py +4 -4
  112. wandb/sdk/lib/lazyloader.py +0 -1
  113. wandb/sdk/lib/proto_util.py +1 -1
  114. wandb/sdk/lib/retry.py +3 -2
  115. wandb/sdk/lib/run_moment.py +7 -1
  116. wandb/sdk/service/service.py +17 -15
  117. wandb/sdk/verify/verify.py +2 -1
  118. wandb/sdk/wandb_init.py +2 -8
  119. wandb/sdk/wandb_manager.py +2 -2
  120. wandb/sdk/wandb_require.py +5 -0
  121. wandb/sdk/wandb_run.py +64 -46
  122. wandb/sdk/wandb_settings.py +2 -1
  123. wandb/sklearn/__init__.py +1 -0
  124. wandb/sklearn/plot/__init__.py +1 -0
  125. wandb/sklearn/plot/classifier.py +1 -0
  126. wandb/sklearn/plot/clusterer.py +1 -0
  127. wandb/sklearn/plot/regressor.py +1 -0
  128. wandb/sklearn/plot/shared.py +1 -0
  129. wandb/sklearn/utils.py +1 -0
  130. wandb/testing/relay.py +4 -4
  131. wandb/trigger.py +1 -0
  132. wandb/util.py +40 -17
  133. wandb/wandb_controller.py +0 -1
  134. wandb/wandb_torch.py +1 -2
  135. {wandb-0.16.5.dist-info → wandb-0.17.0rc1.dist-info}/METADATA +68 -69
  136. {wandb-0.16.5.dist-info → wandb-0.17.0rc1.dist-info}/RECORD +139 -140
  137. {wandb-0.16.5.dist-info → wandb-0.17.0rc1.dist-info}/WHEEL +1 -2
  138. wandb/bin/apple_gpu_stats +0 -0
  139. wandb-0.16.5.dist-info/top_level.txt +0 -1
  140. {wandb-0.16.5.dist-info → wandb-0.17.0rc1.dist-info}/entry_points.txt +0 -0
  141. {wandb-0.16.5.dist-info → wandb-0.17.0rc1.dist-info/licenses}/LICENSE +0 -0
@@ -291,7 +291,11 @@ def _infer_single_example_keyed_processor(
     ):
         # assume this is a class
         if class_labels_table is not None:
-            processors["class"] = lambda n, d, p: class_labels_table.index_ref(d[0]) if d[0] < len(class_labels_table.data) else d[0]  # type: ignore
+            processors["class"] = (
+                lambda n, d, p: class_labels_table.index_ref(d[0])
+                if d[0] < len(class_labels_table.data)
+                else d[0]
+            )  # type: ignore
         else:
             processors["val"] = lambda n, d, p: d[0]
     elif len(shape) == 1:
@@ -25,6 +25,7 @@ from typing import (
     Union,
 )

+from wandb import termwarn
 from wandb.proto import wandb_internal_pb2 as pb
 from wandb.proto import wandb_telemetry_pb2 as tpb
 from wandb.sdk.artifacts.artifact import Artifact

@@ -423,7 +424,7 @@ class InterfaceBase:
             job_info=job_info,
             metadata=metadata,
         )
-        use_artifact.partial.source_info.source.ParseFromString(src_str)
+        use_artifact.partial.source_info.source.ParseFromString(src_str)  # type: ignore[arg-type]

         return use_artifact

@@ -447,16 +448,27 @@ class InterfaceBase:
             path = artifact.get_entry("wandb-job.json").download()
             with open(path) as f:
                 job_info = json.load(f)
+
         except Exception as e:
             logger.warning(
                 f"Failed to download partial job info from artifact {artifact}, : {e}"
             )
-        use_artifact = self._make_proto_use_artifact(
-            use_artifact=use_artifact,
-            job_name=artifact.name,
-            job_info=job_info,
-            metadata=artifact.metadata,
-        )
+            termwarn(
+                f"Failed to download partial job info from artifact {artifact}, : {e}"
+            )
+            return
+
+        try:
+            use_artifact = self._make_proto_use_artifact(
+                use_artifact=use_artifact,
+                job_name=artifact.name,
+                job_info=job_info,
+                metadata=artifact.metadata,
+            )
+        except Exception as e:
+            logger.warning(f"Failed to construct use artifact proto: {e}")
+            termwarn(f"Failed to construct use artifact proto: {e}")
+            return

         self._publish_use_artifact(use_artifact)

@@ -504,11 +516,15 @@ class InterfaceBase:
         artifact_id: str,
         download_root: str,
         allow_missing_references: bool,
+        skip_cache: bool,
+        path_prefix: Optional[str],
     ) -> MailboxHandle:
         download_artifact = pb.DownloadArtifactRequest()
         download_artifact.artifact_id = artifact_id
         download_artifact.download_root = download_root
         download_artifact.allow_missing_references = allow_missing_references
+        download_artifact.skip_cache = skip_cache
+        download_artifact.path_prefix = path_prefix or ""
         resp = self._deliver_download_artifact(download_artifact)
         return resp

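The two new fields carry download options through the internal interface. A minimal sketch of exercising them end to end, assuming the public Artifact.download() in this release accepts matching skip_cache and path_prefix keywords (the artifact name below is hypothetical):

import wandb

# Hedged sketch: keyword names assumed to mirror the new DownloadArtifactRequest fields.
api = wandb.Api()
artifact = api.artifact("my-entity/my-project/dataset:latest")  # hypothetical artifact
root = artifact.download(
    skip_cache=True,        # bypass the shared artifact file cache
    path_prefix="images/",  # only fetch manifest entries under this prefix
)
print(root)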
@@ -717,6 +733,55 @@ class InterfaceBase:
     def _publish_keepalive(self, keepalive: pb.KeepaliveRequest) -> None:
         raise NotImplementedError

+    def publish_job_input(
+        self,
+        include_paths: List[List[str]],
+        exclude_paths: List[List[str]],
+        run_config: bool = False,
+        file_path: str = "",
+    ):
+        """Publishes a request to add inputs to the job.
+
+        If run_config is True, the wandb.config will be added as a job input.
+        If file_path is provided, the file at file_path will be added as a job
+        input.
+
+        The paths provided as arguments are sequences of dictionary keys that
+        specify a path within the wandb.config. If a path is included, the
+        corresponding field will be treated as a job input. If a path is
+        excluded, the corresponding field will not be treated as a job input.
+
+        Args:
+            include_paths: paths within config to include as job inputs.
+            exclude_paths: paths within config to exclude as job inputs.
+            run_config: bool indicating whether wandb.config is the input source.
+            file_path: path to file to include as a job input.
+        """
+        if run_config and file_path:
+            raise ValueError(
+                "run_config and file_path are mutually exclusive arguments."
+            )
+        request = pb.JobInputRequest()
+        include_records = [pb.JobInputPath(path=path) for path in include_paths]
+        exclude_records = [pb.JobInputPath(path=path) for path in exclude_paths]
+        request.include_paths.extend(include_records)
+        request.exclude_paths.extend(exclude_records)
+        source = pb.JobInputSource(
+            run_config=pb.JobInputSource.RunConfigSource(),
+        )
+        if run_config:
+            source.run_config.CopyFrom(pb.JobInputSource.RunConfigSource())
+        else:
+            source.file.CopyFrom(
+                pb.JobInputSource.ConfigFileSource(path=file_path),
+            )
+
+        return self._publish_job_input(request)
+
+    @abstractmethod
+    def _publish_job_input(self, request: pb.JobInputRequest) -> MailboxHandle:
+        raise NotImplementedError
+
     def join(self) -> None:
         # Drop indicates that the internal process has already been shutdown
         if self._drop:
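A rough usage sketch of the new method, grounded in the docstring above; the config keys ("trainer", "seed") are invented for illustration and `interface` stands for any concrete InterfaceBase implementation:

# Hedged sketch: declare which wandb.config fields count as job inputs.
interface.publish_job_input(
    include_paths=[["trainer"]],          # everything under config["trainer"] is an input
    exclude_paths=[["trainer", "seed"]],  # ...except config["trainer"]["seed"]
    run_config=True,                      # source is wandb.config, so file_path stays empty
)

Note that run_config and file_path are mutually exclusive (the ValueError above), so exactly one input source is chosen per call.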
@@ -767,36 +832,6 @@ class InterfaceBase:
         run_start.run.CopyFrom(run_pb)
         return self._deliver_run_start(run_start)

-    def publish_launch_wandb_config_parameters(
-        self, include_paths: List[List[str]], exclude_paths: List[List[str]]
-    ):
-        """Tells the internal process to treat wandb.config fields as job inputs.
-
-        The paths provided as arguments are sequences of dictionary keys that
-        specify a path within the wandb.config. If a path is included, the
-        corresponding field will be treated as a job input. If a path is
-        excluded, the corresponding field will not be treated as a job input.
-
-        Args:
-            include_paths: paths within config to include as job inputs.
-            exclude_paths: paths within config to exclude as job inputs.
-
-        Returns:
-            None
-        """
-        config_parameters = pb.LaunchWandbConfigParametersRecord()
-        include_records = [pb.ConfigFilterPath(path=path) for path in include_paths]
-        exclude_records = [pb.ConfigFilterPath(path=path) for path in exclude_paths]
-        config_parameters.include_paths.extend(include_records)
-        config_parameters.exclude_paths.extend(exclude_records)
-        return self._publish_launch_wandb_config_parameters(config_parameters)
-
-    @abstractmethod
-    def _publish_launch_wandb_config_parameters(
-        self, config_parameters: pb.LaunchWandbConfigParametersRecord
-    ) -> None:
-        raise NotImplementedError
-
     @abstractmethod
     def _deliver_run_start(self, run_start: pb.RunStartRequest) -> MailboxHandle:
         raise NotImplementedError
@@ -100,6 +100,10 @@ class InterfaceShared(InterfaceBase):
         rec = self._make_record(telemetry=telem)
         self._publish(rec)

+    def _publish_job_input(self, job_input: pb.JobInputRequest) -> MailboxHandle:
+        record = self._make_request(job_input=job_input)
+        return self._deliver_record(record)
+
     def _make_stats(self, stats_dict: dict) -> pb.StatsRecord:
         stats = pb.StatsRecord()
         stats.stats_type = pb.StatsRecord.StatsType.SYSTEM

@@ -147,6 +151,7 @@ class InterfaceShared(InterfaceBase):
         telemetry_record: Optional[pb.TelemetryRecordRequest] = None,
         get_system_metrics: Optional[pb.GetSystemMetricsRequest] = None,
         python_packages: Optional[pb.PythonPackagesRequest] = None,
+        job_input: Optional[pb.JobInputRequest] = None,
     ) -> pb.Record:
         request = pb.Request()
         if login:

@@ -207,6 +212,8 @@ class InterfaceShared(InterfaceBase):
             request.sync.CopyFrom(sync)
         elif python_packages:
             request.python_packages.CopyFrom(python_packages)
+        elif job_input:
+            request.job_input.CopyFrom(job_input)
         else:
             raise Exception("Invalid request")
         record = self._make_record(request=request)

@@ -239,9 +246,6 @@ class InterfaceShared(InterfaceBase):
         use_artifact: Optional[pb.UseArtifactRecord] = None,
         output: Optional[pb.OutputRecord] = None,
         output_raw: Optional[pb.OutputRawRecord] = None,
-        launch_wandb_config_parameters: Optional[
-            pb.LaunchWandbConfigParametersRecord
-        ] = None,
     ) -> pb.Record:
         record = pb.Record()
         if run:

@@ -286,8 +290,6 @@ class InterfaceShared(InterfaceBase):
             record.output.CopyFrom(output)
         elif output_raw:
             record.output_raw.CopyFrom(output_raw)
-        elif launch_wandb_config_parameters:
-            record.wandb_config_parameters.CopyFrom(launch_wandb_config_parameters)
         else:
             raise Exception("Invalid record")
         return record

@@ -417,14 +419,6 @@ class InterfaceShared(InterfaceBase):
         rec = self._make_record(alert=proto_alert)
         self._publish(rec)

-    def _publish_launch_wandb_config_parameters(
-        self, launch_wandb_config_parameters: pb.LaunchWandbConfigParametersRecord
-    ) -> None:
-        rec = self._make_record(
-            launch_wandb_config_parameters=launch_wandb_config_parameters
-        )
-        self._publish(rec)
-
     def _communicate_status(
         self, status: pb.StatusRequest
     ) -> Optional[pb.StatusResponse]:
@@ -52,7 +52,7 @@ try:
     bytes("", "ascii")

     def strtobytes(x):
-        """strtobytes."""
+        """Strtobytes."""
         return bytes(x, "iso8859-1")

     # def bytestostr(x):
@@ -50,6 +50,18 @@ SummaryDict = Dict[str, Any]

 logger = logging.getLogger(__name__)

+# Update (March 5, 2024): Since ~2020/2021, when constructing the summary
+# object, we had replaced the artifact path for media types with the latest
+# artifact path. The primary purpose of this was to support live updating of
+# media objects in the UI (since the default artifact path was fully qualified
+# and would not update). However, in March of 2024, a bug was discovered with
+# this approach which causes this path to be incorrect in cases where the media
+# object is logged to another artifact before being logged to the run. Setting
+# this to `False` disables this copy behavior. The impact is that users will
+# need to refresh to see updates. Ironically, this updating behavior is not
+# currently supported in the UI, so the impact of this change is minimal.
+REPLACE_SUMMARY_ART_PATH_WITH_LATEST = False
+

 def _dict_nested_set(target: Dict[str, Any], key_list: Sequence[str], v: Any) -> None:
     # recurse down the dictionary structure:

@@ -371,7 +383,11 @@ class HandleManager:
             updated = True
             return updated
         # If the dict is a media object, update the pointer to the latest alias
-        elif isinstance(v, dict) and handler_util.metric_is_wandb_dict(v):
+        elif (
+            REPLACE_SUMMARY_ART_PATH_WITH_LATEST
+            and isinstance(v, dict)
+            and handler_util.metric_is_wandb_dict(v)
+        ):
             if "_latest_artifact_path" in v and "artifact_path" in v:
                 # TODO: Make non-destructive?
                 v["artifact_path"] = v["_latest_artifact_path"]

@@ -381,7 +397,7 @@ class HandleManager:
     def _update_summary_media_objects(self, v: Dict[str, Any]) -> Dict[str, Any]:
         # For now, non-recursive - just top level
         for nk, nv in v.items():
-            if (
+            if REPLACE_SUMMARY_ART_PATH_WITH_LATEST and (
                 isinstance(nv, dict)
                 and handler_util.metric_is_wandb_dict(nv)
                 and "_latest_artifact_path" in nv
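For reference, a standalone sketch of the copy step that REPLACE_SUMMARY_ART_PATH_WITH_LATEST now disables; the dict mimics a media object's summary entry and the paths are invented:

# Hedged sketch of the old behavior: point the summary at the ":latest" alias
# instead of the fully qualified artifact path.
REPLACE_SUMMARY_ART_PATH_WITH_LATEST = False  # new default in this release

summary_entry = {
    "_type": "image-file",
    "artifact_path": "wandb-artifact://my-entity/my-project/run-media:v3/img.png",
    "_latest_artifact_path": "wandb-artifact://my-entity/my-project/run-media:latest/img.png",
}

if REPLACE_SUMMARY_ART_PATH_WITH_LATEST and "_latest_artifact_path" in summary_entry:
    summary_entry["artifact_path"] = summary_entry["_latest_artifact_path"]  # no longer runs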
@@ -12,7 +12,6 @@ Threads:

 """

-
 import atexit
 import logging
 import os

@@ -4,7 +4,6 @@ Collection of classes to support the internal process.

 """

-
 import logging
 import queue
 import sys

@@ -1,4 +1,5 @@
 """job builder."""
+
 import json
 import logging
 import os

@@ -105,9 +106,9 @@ class JobBuilder:
         self._disable = settings.disable_job_creation
         self._partial_source = None
         self._aliases = []
-        self._source_type: Optional[
-            Literal["repo", "artifact", "image"]
-        ] = settings.job_source  # type: ignore[assignment]
+        self._source_type: Optional[Literal["repo", "artifact", "image"]] = (
+            settings.job_source  # type: ignore[assignment]
+        )
         self._is_notebook_run = self._get_is_notebook_run()
         self._verbose = verbose

@@ -1,4 +1,5 @@
 """Integration with pytorch profiler."""
+
 import os

 import wandb
wandb/sdk/internal/run.py CHANGED
@@ -4,6 +4,7 @@
 Semi-stubbed run for internal process use.

 """
+
 from wandb._globals import _datatypes_set_callback

 from .. import wandb_run
@@ -910,7 +910,7 @@ class SendManager:
         is_wandb_init = self._run is None

         # save start time of a run
-        self._start_time = run.start_time.ToMicroseconds() // 1e6
+        self._start_time = int(run.start_time.ToMicroseconds() // 1e6)

         # update telemetry
         if run.telemetry:
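The int() wrapper matters because ToMicroseconds() returns an int but floor-dividing by the float literal 1e6 produces a float; a quick worked check with an invented timestamp:

# Before: float start time; after: integral seconds.
micros = 1_713_000_000_500_000            # hypothetical run start, in microseconds
assert micros // 1e6 == 1_713_000_000.0   # float result of the old expression
assert int(micros // 1e6) == 1_713_000_000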
@@ -28,14 +28,6 @@ logger = logging.getLogger(__name__)
 ROCM_SMI_CMD: Final[str] = shutil.which("rocm-smi") or "/usr/bin/rocm-smi"


-def get_rocm_smi_stats() -> Dict[str, Any]:
-    command = [str(ROCM_SMI_CMD), "-a", "--json"]
-    output = subprocess.check_output(command, universal_newlines=True).strip()
-    if "No AMD GPUs specified" in output:
-        return {}
-    return json.loads(output.split("\n")[0])  # type: ignore
-
-
 _StatsKeys = Literal[
     "gpu",
     "memoryAllocated",

@@ -49,6 +41,48 @@ _Stats = Dict[_StatsKeys, float]
 _InfoDict = Dict[str, Union[int, List[Dict[str, Any]]]]


+def get_rocm_smi_stats() -> Dict[str, Any]:
+    command = [str(ROCM_SMI_CMD), "-a", "--json"]
+    output = subprocess.check_output(command, universal_newlines=True).strip()
+    if "No AMD GPUs specified" in output:
+        return {}
+    return json.loads(output.split("\n")[0])  # type: ignore
+
+
+def parse_stats(stats: Dict[str, str]) -> _Stats:
+    """Parse stats from rocm-smi output."""
+    parsed_stats: _Stats = {}
+
+    try:
+        parsed_stats["gpu"] = float(stats.get("GPU use (%)"))  # type: ignore
+    except (TypeError, ValueError):
+        logger.warning("Could not parse GPU usage as float")
+    try:
+        parsed_stats["memoryAllocated"] = float(stats.get("GPU memory use (%)"))  # type: ignore
+    except (TypeError, ValueError):
+        logger.warning("Could not parse GPU memory allocation as float")
+    try:
+        parsed_stats["temp"] = float(stats.get("Temperature (Sensor memory) (C)"))  # type: ignore
+    except (TypeError, ValueError):
+        logger.warning("Could not parse GPU temperature as float")
+    try:
+        parsed_stats["powerWatts"] = float(
+            stats.get("Average Graphics Package Power (W)")  # type: ignore
+        )
+    except (TypeError, ValueError):
+        logger.warning("Could not parse GPU power as float")
+    try:
+        parsed_stats["powerPercent"] = (
+            float(stats.get("Average Graphics Package Power (W)"))  # type: ignore
+            / float(stats.get("Max Graphics Package Power (W)"))  # type: ignore
+            * 100
+        )
+    except (TypeError, ValueError):
+        logger.warning("Could not parse GPU average/max power as float")
+
+    return parsed_stats
+
+
 class GPUAMDStats:
     """Stats for AMD GPU devices."""

@@ -58,40 +92,6 @@ class GPUAMDStats:
     def __init__(self) -> None:
         self.samples = deque()

-    @staticmethod
-    def parse_stats(stats: Dict[str, str]) -> _Stats:
-        """Parse stats from rocm-smi output."""
-        parsed_stats: _Stats = {}
-
-        try:
-            parsed_stats["gpu"] = float(stats.get("GPU use (%)"))  # type: ignore
-        except (TypeError, ValueError):
-            logger.warning("Could not parse GPU usage as float")
-        try:
-            parsed_stats["memoryAllocated"] = float(stats.get("GPU memory use (%)"))  # type: ignore
-        except (TypeError, ValueError):
-            logger.warning("Could not parse GPU memory allocation as float")
-        try:
-            parsed_stats["temp"] = float(stats.get("Temperature (Sensor memory) (C)"))  # type: ignore
-        except (TypeError, ValueError):
-            logger.warning("Could not parse GPU temperature as float")
-        try:
-            parsed_stats["powerWatts"] = float(
-                stats.get("Average Graphics Package Power (W)")  # type: ignore
-            )
-        except (TypeError, ValueError):
-            logger.warning("Could not parse GPU power as float")
-        try:
-            parsed_stats["powerPercent"] = (
-                float(stats.get("Average Graphics Package Power (W)"))  # type: ignore
-                / float(stats.get("Max Graphics Package Power (W)"))  # type: ignore
-                * 100
-            )
-        except (TypeError, ValueError):
-            logger.warning("Could not parse GPU average/max power as float")
-
-        return parsed_stats
-
     def sample(self) -> None:
         try:
             raw_stats = get_rocm_smi_stats()

@@ -103,7 +103,7 @@ class GPUAMDStats:

         for card_key in card_keys:
             card_stats = raw_stats[card_key]
-            stats = self.parse_stats(card_stats)
+            stats = parse_stats(card_stats)
             if stats:
                 cards.append(stats)

@@ -183,7 +183,7 @@ class GPUAMD:

         can_read_rocm_smi = False
         try:
-            if get_rocm_smi_stats():
+            if parse_stats(get_rocm_smi_stats()):
                 can_read_rocm_smi = True
         except Exception:
             pass
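To make the parsing above concrete, a hedged example of feeding the now module-level parse_stats a rocm-smi style card dict; the field names match the keys read above, the values are invented:

# Hedged sketch: missing or non-numeric fields are skipped with a warning.
card_stats = {
    "GPU use (%)": "13",
    "GPU memory use (%)": "42",
    "Temperature (Sensor memory) (C)": "56.0",
    "Average Graphics Package Power (W)": "101.0",
    "Max Graphics Package Power (W)": "300.0",
}
parsed = parse_stats(card_stats)
# parsed -> {"gpu": 13.0, "memoryAllocated": 42.0, "temp": 56.0,
#            "powerWatts": 101.0, "powerPercent": 101 / 300 * 100 (about 33.67)}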
@@ -37,6 +37,12 @@ class _Stats(TypedDict):
     # cpuWaitMs: float


+def get_apple_gpu_path() -> pathlib.Path:
+    return (
+        pathlib.Path(sys.modules["wandb"].__path__[0]) / "bin" / "apple_gpu_stats"
+    ).resolve()
+
+
 class GPUAppleStats:
     """Apple GPU stats available on Arm Macs."""

@@ -49,9 +55,7 @@ class GPUAppleStats:

     def __init__(self) -> None:
         self.samples = deque()
-        self.binary_path = (
-            pathlib.Path(sys.modules["wandb"].__path__[0]) / "bin" / "apple_gpu_stats"
-        ).resolve()
+        self.binary_path = get_apple_gpu_path()

     def sample(self) -> None:
         try:

@@ -63,22 +67,47 @@ class GPUAppleStats:
             )[0]
             raw_stats = json.loads(output)

+            temp_keys = [
+                "m1Gpu1",
+                "m1Gpu2",
+                "m1Gpu3",
+                "m1Gpu4",
+                "m2Gpu1",
+                "m2Gpu2",
+                "m3Gpu1",
+                "m3Gpu2",
+                "m3Gpu3",
+                "m3Gpu4",
+                "m3Gpu5",
+                "m3Gpu6",
+                "m3Gpu7",
+                "m3Gpu8",
+            ]
+            temp, count = 0, 0
+            for k in temp_keys:
+                if raw_stats.get(k, 0) > 0:
+                    temp += raw_stats[k]
+                    count += 1
+
             stats: _Stats = {
                 "gpu": raw_stats["utilization"],
-                "memoryAllocated": raw_stats["mem_used"],
-                "temp": raw_stats["temperature"],
-                "powerWatts": raw_stats["power"],
-                "powerPercent": (raw_stats["power"] / self.MAX_POWER_WATTS) * 100,
+                "memoryAllocated": (
+                    raw_stats["inUseSystemMemory"]
+                    / raw_stats["allocatedSystemMemory"]
+                    * 100
+                ),
+                "powerWatts": raw_stats["systemPower"],
+                "powerPercent": (raw_stats["systemPower"] / self.MAX_POWER_WATTS) * 100,
+                "temp": temp / count if count > 0 else 0,
                 # TODO: this stat could be useful eventually, it was consistently
                 # 0 in my experimentation and requires a frontend change
                 # so leaving it out for now.
                 # "cpuWaitMs": raw_stats["cpu_wait_ms"],
             }
-
             self.samples.append(stats)

         except (OSError, ValueError, TypeError, subprocess.CalledProcessError) as e:
-            logger.exception(f"GPU stats error: {e}")
+            logger.exception("GPU stats error: %s", e)

     def clear(self) -> None:
         self.samples.clear()
@@ -116,6 +145,7 @@ class GPUApple:
         telemetry_record = telemetry.TelemetryRecord()
         telemetry_record.env.m1_gpu = True
         interface._publish_telemetry(telemetry_record)
+        self.binary_path = get_apple_gpu_path()

     @classmethod
     def is_available(cls) -> bool:

@@ -128,5 +158,20 @@ class GPUApple:
         self.metrics_monitor.finish()

     def probe(self) -> dict:
-        # todo: make this actually meaningful
-        return {self.name: {"type": "arm", "vendor": "Apple"}}
+        try:
+            command = [str(self.binary_path), "--json"]
+            output = (
+                subprocess.check_output(command, universal_newlines=True)
+                .strip()
+                .split("\n")
+            )[0]
+            raw_stats = json.loads(output)
+            return {
+                self.name: {
+                    "type": raw_stats["name"],
+                    "vendor": raw_stats["vendor"],
+                }
+            }
+        except (OSError, ValueError, TypeError, subprocess.CalledProcessError) as e:
+            logger.exception("GPU stats error: %s", e)
+            return {self.name: {"type": "arm", "vendor": "Apple"}}
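A hedged walk-through of the two new derived metrics in the sample() change above, using an invented apple_gpu_stats payload whose keys match those read in the diff:

# Hedged sketch: memory is reported as a percentage of allocated system memory,
# and temperature is the mean of whichever per-cluster sensors report > 0.
raw_stats = {
    "utilization": 37.5,
    "inUseSystemMemory": 4 * 1024**3,       # bytes, invented
    "allocatedSystemMemory": 16 * 1024**3,  # bytes, invented
    "systemPower": 8.2,                     # watts, invented
    "m1Gpu1": 48.0,
    "m1Gpu2": 51.0,
}

memory_allocated = (
    raw_stats["inUseSystemMemory"] / raw_stats["allocatedSystemMemory"] * 100
)  # 25.0
sensors = [
    v for k, v in raw_stats.items()
    if k.startswith(("m1Gpu", "m2Gpu", "m3Gpu")) and v > 0
]
temp = sum(sensors) / len(sensors) if sensors else 0  # 49.5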
@@ -68,8 +68,7 @@ class Asset(Protocol):
     metrics: List[Metric]
     metrics_monitor: "MetricsMonitor"

-    def __init__(self, *args: Any, **kwargs: Any) -> None:
-        ...  # pragma: no cover
+    def __init__(self, *args: Any, **kwargs: Any) -> None: ...  # pragma: no cover

     @classmethod
     def is_available(cls) -> bool:

@@ -90,14 +89,13 @@ class Asset(Protocol):


 class Interface(Protocol):
-    def publish_stats(self, stats: dict) -> None:
-        ...  # pragma: no cover
+    def publish_stats(self, stats: dict) -> None: ...  # pragma: no cover

-    def _publish_telemetry(self, telemetry: "TelemetryRecord") -> None:
-        ...  # pragma: no cover
+    def _publish_telemetry(
+        self, telemetry: "TelemetryRecord"
+    ) -> None: ...  # pragma: no cover

-    def publish_files(self, files_dict: "FilesDict") -> None:
-        ...  # pragma: no cover
+    def publish_files(self, files_dict: "FilesDict") -> None: ...  # pragma: no cover


 class MetricsMonitor:
@@ -65,13 +65,13 @@ def _setup_requests_session() -> requests.Session:


 def _nested_dict_to_tuple(
-    nested_dict: Mapping[str, Mapping[str, str]]
+    nested_dict: Mapping[str, Mapping[str, str]],
 ) -> Tuple[Tuple[str, Tuple[str, str]], ...]:
     return tuple((k, *v.items()) for k, v in nested_dict.items())  # type: ignore


 def _tuple_to_nested_dict(
-    nested_tuple: Tuple[Tuple[str, Tuple[str, str]], ...]
+    nested_tuple: Tuple[Tuple[str, Tuple[str, str]], ...],
 ) -> Dict[str, Dict[str, str]]:
     return {k: dict(v) for k, *v in nested_tuple}

@@ -197,9 +197,7 @@ class NeuronCoreStats:
             entry["report"]
             for entry in raw_stats["neuron_runtime_data"]
             if self._is_matching_entry(entry)
-        ][
-            0
-        ]  # there should be only one entry with the pid
+        ][0]  # there should be only one entry with the pid

         neuroncores_in_use = neuron_runtime_data["neuroncore_counters"][
             "neuroncores_in_use"
@@ -62,6 +62,7 @@ def resolve_agent_config(  # noqa: C901
     max_jobs: Optional[int],
     queues: Optional[Tuple[str]],
     config: Optional[str],
+    verbosity: Optional[int],
 ) -> Tuple[Dict[str, Any], Api]:
     """Resolve the agent config.

@@ -72,6 +73,7 @@ def resolve_agent_config(  # noqa: C901
         max_jobs (int): The max number of jobs.
         queues (Tuple[str]): The queues.
         config (str): The config.
+        verbosity (int): How verbose to print, 0 or None = default, 1 = print status every 20 seconds, 2 = also print debugging information

     Returns:
         Tuple[Dict[str, Any], Api]: The resolved config and api.

@@ -83,6 +85,7 @@ def resolve_agent_config(  # noqa: C901
         "queues": [],
         "registry": {},
         "builder": {},
+        "verbosity": 0,
     }
     user_set_project = False
     resolved_config: Dict[str, Any] = defaults

@@ -123,6 +126,8 @@ def resolve_agent_config(  # noqa: C901
         resolved_config.update({"max_jobs": int(max_jobs)})
     if queues:
         resolved_config.update({"queues": list(queues)})
+    if verbosity:
+        resolved_config.update({"verbosity": int(verbosity)})
     # queue -> queues
     if resolved_config.get("queue"):
         if isinstance(resolved_config.get("queue"), str):
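A hedged sketch of the precedence this gives the new setting (built-in default, then the agent config file, then the CLI argument); the concrete values are invented and this is not the actual resolve_agent_config body:

# Hedged sketch of the merge order implied by the hunks above.
defaults = {"verbosity": 0}        # built-in default
file_config = {"verbosity": 1}     # hypothetical value read from the agent config file
cli_verbosity = 2                  # hypothetical value passed on the command line

resolved = dict(defaults)
resolved.update(file_config)       # config file overrides defaults
if cli_verbosity:                  # CLI value wins when truthy
    resolved["verbosity"] = int(cli_verbosity)
assert resolved["verbosity"] == 2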