wandb 0.15.3__py3-none-any.whl → 0.15.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (156)
  1. wandb/__init__.py +1 -1
  2. wandb/analytics/sentry.py +1 -0
  3. wandb/apis/importers/base.py +20 -5
  4. wandb/apis/importers/mlflow.py +7 -1
  5. wandb/apis/internal.py +12 -0
  6. wandb/apis/public.py +247 -1387
  7. wandb/apis/reports/_panels.py +58 -35
  8. wandb/beta/workflows.py +6 -7
  9. wandb/cli/cli.py +130 -60
  10. wandb/data_types.py +3 -1
  11. wandb/filesync/dir_watcher.py +21 -27
  12. wandb/filesync/step_checksum.py +8 -8
  13. wandb/filesync/step_prepare.py +23 -10
  14. wandb/filesync/step_upload.py +13 -13
  15. wandb/filesync/upload_job.py +4 -8
  16. wandb/integration/cohere/__init__.py +3 -0
  17. wandb/integration/cohere/cohere.py +21 -0
  18. wandb/integration/cohere/resolver.py +347 -0
  19. wandb/integration/gym/__init__.py +4 -6
  20. wandb/integration/huggingface/__init__.py +3 -0
  21. wandb/integration/huggingface/huggingface.py +18 -0
  22. wandb/integration/huggingface/resolver.py +213 -0
  23. wandb/integration/langchain/wandb_tracer.py +16 -179
  24. wandb/integration/openai/__init__.py +1 -3
  25. wandb/integration/openai/openai.py +11 -143
  26. wandb/integration/openai/resolver.py +111 -38
  27. wandb/integration/sagemaker/config.py +2 -2
  28. wandb/integration/tensorboard/log.py +4 -4
  29. wandb/old/settings.py +24 -7
  30. wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
  31. wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
  32. wandb/proto/wandb_deprecated.py +3 -1
  33. wandb/sdk/__init__.py +1 -1
  34. wandb/sdk/artifacts/__init__.py +0 -0
  35. wandb/sdk/artifacts/artifact.py +2101 -0
  36. wandb/sdk/artifacts/artifact_download_logger.py +42 -0
  37. wandb/sdk/artifacts/artifact_manifest.py +67 -0
  38. wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
  39. wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
  40. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
  41. wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
  42. wandb/sdk/artifacts/artifact_state.py +10 -0
  43. wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
  44. wandb/sdk/artifacts/exceptions.py +55 -0
  45. wandb/sdk/artifacts/storage_handler.py +59 -0
  46. wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
  47. wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
  48. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
  49. wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
  50. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
  51. wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
  52. wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
  53. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
  54. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
  55. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
  56. wandb/sdk/artifacts/storage_layout.py +6 -0
  57. wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
  58. wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
  59. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
  60. wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
  61. wandb/sdk/data_types/_dtypes.py +7 -12
  62. wandb/sdk/data_types/base_types/json_metadata.py +3 -2
  63. wandb/sdk/data_types/base_types/media.py +8 -8
  64. wandb/sdk/data_types/base_types/wb_value.py +12 -13
  65. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
  66. wandb/sdk/data_types/helper_types/classes.py +6 -8
  67. wandb/sdk/data_types/helper_types/image_mask.py +5 -6
  68. wandb/sdk/data_types/histogram.py +4 -3
  69. wandb/sdk/data_types/html.py +3 -4
  70. wandb/sdk/data_types/image.py +11 -9
  71. wandb/sdk/data_types/molecule.py +5 -3
  72. wandb/sdk/data_types/object_3d.py +7 -5
  73. wandb/sdk/data_types/plotly.py +3 -2
  74. wandb/sdk/data_types/saved_model.py +11 -11
  75. wandb/sdk/data_types/trace_tree.py +5 -4
  76. wandb/sdk/data_types/utils.py +3 -5
  77. wandb/sdk/data_types/video.py +5 -4
  78. wandb/sdk/integration_utils/auto_logging.py +215 -0
  79. wandb/sdk/interface/interface.py +15 -15
  80. wandb/sdk/internal/file_pusher.py +8 -16
  81. wandb/sdk/internal/file_stream.py +5 -11
  82. wandb/sdk/internal/handler.py +13 -1
  83. wandb/sdk/internal/internal_api.py +287 -13
  84. wandb/sdk/internal/job_builder.py +119 -30
  85. wandb/sdk/internal/sender.py +6 -26
  86. wandb/sdk/internal/settings_static.py +2 -0
  87. wandb/sdk/internal/system/assets/__init__.py +2 -0
  88. wandb/sdk/internal/system/assets/gpu.py +42 -0
  89. wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
  90. wandb/sdk/internal/system/env_probe_helpers.py +13 -0
  91. wandb/sdk/internal/system/system_info.py +3 -3
  92. wandb/sdk/internal/tb_watcher.py +32 -22
  93. wandb/sdk/internal/thread_local_settings.py +18 -0
  94. wandb/sdk/launch/_project_spec.py +57 -11
  95. wandb/sdk/launch/agent/agent.py +147 -65
  96. wandb/sdk/launch/agent/job_status_tracker.py +34 -0
  97. wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
  98. wandb/sdk/launch/builder/abstract.py +5 -1
  99. wandb/sdk/launch/builder/build.py +21 -18
  100. wandb/sdk/launch/builder/docker_builder.py +10 -4
  101. wandb/sdk/launch/builder/kaniko_builder.py +113 -23
  102. wandb/sdk/launch/builder/noop.py +6 -3
  103. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
  104. wandb/sdk/launch/environment/aws_environment.py +3 -2
  105. wandb/sdk/launch/environment/azure_environment.py +124 -0
  106. wandb/sdk/launch/environment/gcp_environment.py +2 -4
  107. wandb/sdk/launch/environment/local_environment.py +1 -1
  108. wandb/sdk/launch/errors.py +19 -0
  109. wandb/sdk/launch/github_reference.py +32 -19
  110. wandb/sdk/launch/launch.py +3 -8
  111. wandb/sdk/launch/launch_add.py +6 -2
  112. wandb/sdk/launch/loader.py +21 -2
  113. wandb/sdk/launch/registry/azure_container_registry.py +132 -0
  114. wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
  115. wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
  116. wandb/sdk/launch/registry/local_registry.py +2 -1
  117. wandb/sdk/launch/runner/abstract.py +24 -3
  118. wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
  119. wandb/sdk/launch/runner/local_container.py +103 -51
  120. wandb/sdk/launch/runner/local_process.py +1 -1
  121. wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
  122. wandb/sdk/launch/runner/vertex_runner.py +10 -5
  123. wandb/sdk/launch/sweeps/__init__.py +7 -9
  124. wandb/sdk/launch/sweeps/scheduler.py +307 -77
  125. wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
  126. wandb/sdk/launch/sweeps/utils.py +82 -35
  127. wandb/sdk/launch/utils.py +89 -75
  128. wandb/sdk/lib/_settings_toposort_generated.py +7 -0
  129. wandb/sdk/lib/capped_dict.py +26 -0
  130. wandb/sdk/lib/{git.py → gitlib.py} +76 -59
  131. wandb/sdk/lib/hashutil.py +12 -4
  132. wandb/sdk/lib/paths.py +96 -8
  133. wandb/sdk/lib/sock_client.py +2 -2
  134. wandb/sdk/lib/timer.py +1 -0
  135. wandb/sdk/service/server.py +22 -9
  136. wandb/sdk/service/server_sock.py +1 -1
  137. wandb/sdk/service/service.py +27 -8
  138. wandb/sdk/verify/verify.py +4 -7
  139. wandb/sdk/wandb_config.py +2 -6
  140. wandb/sdk/wandb_init.py +57 -53
  141. wandb/sdk/wandb_require.py +7 -0
  142. wandb/sdk/wandb_run.py +61 -223
  143. wandb/sdk/wandb_settings.py +28 -4
  144. wandb/testing/relay.py +15 -2
  145. wandb/util.py +74 -36
  146. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
  147. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
  148. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
  149. wandb/integration/langchain/util.py +0 -191
  150. wandb/sdk/interface/artifacts/__init__.py +0 -33
  151. wandb/sdk/interface/artifacts/artifact.py +0 -615
  152. wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
  153. wandb/sdk/wandb_artifacts.py +0 -2226
  154. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
  155. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
  156. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -1,20 +1,23 @@
1
1
  """job builder."""
2
2
  import json
3
+ import logging
3
4
  import os
4
5
  import sys
5
6
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
6
7
 
8
+ from wandb.sdk.artifacts.artifact import Artifact
7
9
  from wandb.sdk.data_types._dtypes import TypeRegistry
8
10
  from wandb.sdk.lib.filenames import DIFF_FNAME, METADATA_FNAME, REQUIREMENTS_FNAME
9
- from wandb.sdk.wandb_artifacts import Artifact
10
11
  from wandb.util import make_artifact_name_safe
11
12
 
12
13
  from .settings_static import SettingsStatic
13
14
 
14
15
  if sys.version_info >= (3, 8):
15
- from typing import TypedDict
16
+ from typing import Literal, TypedDict
16
17
  else:
17
- from typing_extensions import TypedDict
18
+ from typing_extensions import Literal, TypedDict
19
+
20
+ _logger = logging.getLogger(__name__)
18
21
 
19
22
  if TYPE_CHECKING:
20
23
  from wandb.proto.wandb_internal_pb2 import ArtifactRecord
@@ -32,11 +35,13 @@ class GitInfo(TypedDict):
32
35
  class GitSourceDict(TypedDict):
33
36
  git: GitInfo
34
37
  entrypoint: List[str]
38
+ notebook: bool
35
39
 
36
40
 
37
41
  class ArtifactSourceDict(TypedDict):
38
42
  artifact: str
39
43
  entrypoint: List[str]
44
+ notebook: bool
40
45
 
41
46
 
42
47
  class ImageSourceDict(TypedDict):
@@ -71,6 +76,7 @@ class JobBuilder:
71
76
  _summary: Optional[Dict[str, Any]]
72
77
  _logged_code_artifact: Optional[ArtifactInfoForJob]
73
78
  _disable: bool
79
+ _aliases: List[str]
74
80
 
75
81
  def __init__(self, settings: SettingsStatic):
76
82
  self._settings = settings
@@ -80,6 +86,10 @@ class JobBuilder:
80
86
  self._summary = None
81
87
  self._logged_code_artifact = None
82
88
  self._disable = settings.disable_job_creation
89
+ self._aliases = []
90
+ self._source_type: Optional[
91
+ Literal["repo", "artifact", "image"]
92
+ ] = settings.get("job_source")
83
93
 
84
94
  def set_config(self, config: Dict[str, Any]) -> None:
85
95
  self._config = config
@@ -107,19 +117,50 @@ class JobBuilder:
107
117
  )
108
118
 
109
119
  def _build_repo_job(
110
- self, metadata: Dict[str, Any], program_relpath: str
111
- ) -> Tuple[Artifact, GitSourceDict]:
120
+ self, metadata: Dict[str, Any], program_relpath: str, root: Optional[str]
121
+ ) -> Tuple[Optional[Artifact], Optional[GitSourceDict]]:
112
122
  git_info: Dict[str, str] = metadata.get("git", {})
113
123
  remote = git_info.get("remote")
114
124
  commit = git_info.get("commit")
115
125
  assert remote is not None
116
126
  assert commit is not None
127
+ if self._is_notebook_run():
128
+ if not os.path.exists(
129
+ os.path.join(os.getcwd(), os.path.basename(program_relpath))
130
+ ):
131
+ return None, None
132
+
133
+ if root is None or self._settings._jupyter_root is None:
134
+ _logger.info("target path does not exist, exiting")
135
+ return None, None
136
+ assert self._settings._jupyter_root is not None
137
+ # git notebooks set the root to the git root,
138
+ # jupyter_root contains the path where the jupyter notebook was started
139
+ # program_relpath contains the path from jupyter_root to the file
140
+ # full program path here is actually the relpath from the program to the git root
141
+ full_program_path = os.path.join(
142
+ os.path.relpath(str(self._settings._jupyter_root), root),
143
+ program_relpath,
144
+ )
145
+ full_program_path = os.path.normpath(full_program_path)
146
+ # if the notebook server is started above the git repo need to clear all the ..s
147
+ if full_program_path.startswith(".."):
148
+ split_path = full_program_path.split("/")
149
+ count_dots = 0
150
+ for p in split_path:
151
+ if p == "..":
152
+ count_dots += 1
153
+ full_program_path = "/".join(split_path[2 * count_dots :])
154
+
155
+ else:
156
+ full_program_path = program_relpath
117
157
  # TODO: update executable to a method that supports pex
118
158
  source: GitSourceDict = {
119
159
  "entrypoint": [
120
160
  os.path.basename(sys.executable),
121
- program_relpath,
161
+ full_program_path,
122
162
  ],
163
+ "notebook": self._is_notebook_run(),
123
164
  "git": {
124
165
  "remote": remote,
125
166
  "commit": commit,
@@ -132,22 +173,40 @@ class JobBuilder:
132
173
  if os.path.exists(os.path.join(self._settings.files_dir, DIFF_FNAME)):
133
174
  artifact.add_file(
134
175
  os.path.join(self._settings.files_dir, DIFF_FNAME),
135
- name="diff.patch",
176
+ name=DIFF_FNAME,
136
177
  )
137
178
  return artifact, source
138
179
 
139
180
  def _build_artifact_job(
140
- self, program_relpath: str
141
- ) -> Tuple[Artifact, ArtifactSourceDict]:
181
+ self, metadata: Dict[str, Any], program_relpath: str
182
+ ) -> Tuple[Optional[Artifact], Optional[ArtifactSourceDict]]:
142
183
  assert isinstance(self._logged_code_artifact, dict)
184
+ # TODO: should we just always exit early if the path doesn't exist?
185
+ if self._is_notebook_run() and not self._is_colab_run():
186
+ full_program_relpath = os.path.relpath(program_relpath, os.getcwd())
187
+ # if the resolved path doesn't exist, then we shouldn't make a job because it will fail
188
+ if not os.path.exists(full_program_relpath):
189
+ # when users call log code in a notebook the code artifact starts
190
+ # at the directory the notebook is in instead of the jupyter
191
+ # core
192
+ if os.path.exists(os.path.basename(program_relpath)):
193
+ full_program_relpath = os.path.basename(program_relpath)
194
+ else:
195
+ _logger.info("target path does not exist, exiting")
196
+ return None, None
197
+ else:
198
+ full_program_relpath = program_relpath
199
+ entrypoint = [
200
+ os.path.basename(sys.executable),
201
+ full_program_relpath,
202
+ ]
143
203
  # TODO: update executable to a method that supports pex
144
204
  source: ArtifactSourceDict = {
145
- "entrypoint": [
146
- os.path.basename(sys.executable),
147
- program_relpath,
148
- ],
205
+ "entrypoint": entrypoint,
206
+ "notebook": self._is_notebook_run(),
149
207
  "artifact": f"wandb-artifact://_id/{self._logged_code_artifact['id']}",
150
208
  }
209
+
151
210
  name = make_artifact_name_safe(f"job-{self._logged_code_artifact['name']}")
152
211
 
153
212
  artifact = JobArtifact(name)
@@ -158,14 +217,27 @@ class JobBuilder:
158
217
  ) -> Tuple[Artifact, ImageSourceDict]:
159
218
  image_name = metadata.get("docker")
160
219
  assert isinstance(image_name, str)
161
- name = make_artifact_name_safe(f"job-{image_name}")
220
+
221
+ raw_image_name = image_name
222
+ if ":" in image_name:
223
+ raw_image_name, tag = image_name.split(":")
224
+ self._aliases += [tag]
225
+
226
+ name = make_artifact_name_safe(f"job-{raw_image_name}")
162
227
  artifact = JobArtifact(name)
163
228
  source: ImageSourceDict = {
164
229
  "image": image_name,
165
230
  }
166
231
  return artifact, source
167
232
 
233
+ def _is_notebook_run(self) -> bool:
234
+ return hasattr(self._settings, "_jupyter") and bool(self._settings._jupyter)
235
+
236
+ def _is_colab_run(self) -> bool:
237
+ return hasattr(self._settings, "_colab") and bool(self._settings._colab)
238
+
168
239
  def build(self) -> Optional[Artifact]:
240
+ _logger.info("Attempting to build job artifact")
169
241
  if not os.path.exists(
170
242
  os.path.join(self._settings.files_dir, REQUIREMENTS_FNAME)
171
243
  ):
@@ -181,23 +253,40 @@ class JobBuilder:
181
253
 
182
254
  program_relpath: Optional[str] = metadata.get("codePath")
183
255
 
256
+ source_type = self._source_type
257
+
258
+ if self._is_notebook_run():
259
+ _logger.info("run is notebook based run")
260
+ program_relpath = metadata.get("program")
261
+
262
+ if not source_type:
263
+ if self._has_git_job_ingredients(metadata, program_relpath):
264
+ _logger.info("is repo sourced job")
265
+ source_type = "repo"
266
+ elif self._has_artifact_job_ingredients(program_relpath):
267
+ _logger.info("is artifact sourced job")
268
+ source_type = "artifact"
269
+ elif self._has_image_job_ingredients(metadata):
270
+ _logger.info("is image sourced job")
271
+ source_type = "image"
272
+
273
+ if not source_type:
274
+ _logger.info("no source found")
275
+ return None
276
+
184
277
  artifact = None
185
- source_type = None
186
278
  source: Optional[
187
279
  Union[GitSourceDict, ArtifactSourceDict, ImageSourceDict]
188
280
  ] = None
189
-
190
- if self._has_git_job_ingredients(metadata, program_relpath):
281
+ if source_type == "repo":
191
282
  assert program_relpath is not None
192
- artifact, source = self._build_repo_job(metadata, program_relpath)
193
- source_type = "repo"
194
- elif self._has_artifact_job_ingredients(program_relpath):
283
+ root: Optional[str] = metadata.get("root")
284
+ artifact, source = self._build_repo_job(metadata, program_relpath, root)
285
+ elif source_type == "artifact":
195
286
  assert program_relpath is not None
196
- artifact, source = self._build_artifact_job(program_relpath)
197
- source_type = "artifact"
198
- elif self._has_image_job_ingredients(metadata):
287
+ artifact, source = self._build_artifact_job(metadata, program_relpath)
288
+ elif source_type == "image":
199
289
  artifact, source = self._build_image_job(metadata)
200
- source_type = "image"
201
290
 
202
291
  if artifact is None or source_type is None or source is None:
203
292
  return None
@@ -213,7 +302,7 @@ class JobBuilder:
213
302
  "output_types": output_types,
214
303
  "runtime": runtime,
215
304
  }
216
-
305
+ _logger.info("adding wandb-job metadata file")
217
306
  with artifact.new_file("wandb-job.json") as f:
218
307
  f.write(json.dumps(source_info, indent=4))
219
308
 
@@ -238,11 +327,11 @@ class JobBuilder:
238
327
  self, metadata: Dict[str, Any], program_relpath: Optional[str]
239
328
  ) -> bool:
240
329
  git_info: Dict[str, str] = metadata.get("git", {})
241
- return (
242
- git_info.get("remote") is not None
243
- and git_info.get("commit") is not None
244
- and program_relpath is not None
245
- )
330
+ if program_relpath is None:
331
+ return False
332
+ if self._is_notebook_run() and metadata.get("root") is None:
333
+ return False
334
+ return git_info.get("remote") is not None and git_info.get("commit") is not None
246
335
 
247
336
  def _has_artifact_job_ingredients(self, program_relpath: Optional[str]) -> bool:
248
337
  return self._logged_code_artifact is not None and program_relpath is not None
@@ -30,16 +30,10 @@ from wandb.errors import CommError, UsageError
30
30
  from wandb.errors.util import ProtobufErrorHandler
31
31
  from wandb.filesync.dir_watcher import DirWatcher
32
32
  from wandb.proto import wandb_internal_pb2
33
+ from wandb.sdk.artifacts import artifact_saver
33
34
  from wandb.sdk.interface import interface
34
35
  from wandb.sdk.interface.interface_queue import InterfaceQueue
35
- from wandb.sdk.internal import (
36
- artifact_saver,
37
- context,
38
- datastore,
39
- file_stream,
40
- internal_api,
41
- update,
42
- )
36
+ from wandb.sdk.internal import context, datastore, file_stream, internal_api, update
43
37
  from wandb.sdk.internal.file_pusher import FilePusher
44
38
  from wandb.sdk.internal.job_builder import JobBuilder
45
39
  from wandb.sdk.internal.settings_static import SettingsDict, SettingsStatic
@@ -486,24 +480,6 @@ class SendManager:
486
480
  result.response.check_version_response.delete_message = delete_message
487
481
  self._respond_result(result)
488
482
 
489
- def _send_request_attach(
490
- self,
491
- req: wandb_internal_pb2.AttachRequest,
492
- resp: wandb_internal_pb2.AttachResponse,
493
- ) -> None:
494
- attach_id = req.attach_id
495
- assert attach_id
496
- assert self._run
497
- resp.run.CopyFrom(self._run)
498
-
499
- def send_request_attach(self, record: "Record") -> None:
500
- assert record.control.req_resp or record.control.mailbox_slot
501
- result = proto_util._result_from_record(record)
502
- self._send_request_attach(
503
- record.request.attach, result.response.attach_response
504
- )
505
- self._respond_result(result)
506
-
507
483
  def send_request_stop_status(self, record: "Record") -> None:
508
484
  result = proto_util._result_from_record(record)
509
485
  status_resp = result.response.stop_status_response
@@ -1632,6 +1608,10 @@ class SendManager:
1632
1608
  # TODO: this should be removed when the latest tag is handled
1633
1609
  # by the backend (WB-12116)
1634
1610
  proto_artifact.aliases.append("latest")
1611
+ # add docker image tag
1612
+ for alias in self._job_builder._aliases:
1613
+ proto_artifact.aliases.append(alias)
1614
+
1635
1615
  proto_artifact.user_created = True
1636
1616
  proto_artifact.use_after_commit = True
1637
1617
  proto_artifact.finalize = True
@@ -8,6 +8,7 @@ class SettingsStatic:
8
8
  # TODO(jhr): figure out how to share type defs with sdk/wandb_settings.py
9
9
  _offline: Optional[bool]
10
10
  _sync: bool
11
+ _disable_setproctitle: bool
11
12
  _disable_stats: Optional[bool]
12
13
  _disable_meta: Optional[bool]
13
14
  _flow_control: bool
@@ -65,6 +66,7 @@ class SettingsStatic:
65
66
  disable_job_creation: bool
66
67
  _async_upload_concurrency_limit: Optional[int]
67
68
  _extra_http_headers: Optional[Mapping[str, str]]
69
+ job_source: Optional[str]
68
70
 
69
71
  # TODO(jhr): clean this up, it is only in SettingsStatic and not in Settings
70
72
  _log_level: int
@@ -3,6 +3,7 @@ __all__ = (
3
3
  "CPU",
4
4
  "Disk",
5
5
  "GPU",
6
+ "GPUAMD",
6
7
  "GPUApple",
7
8
  "IPU",
8
9
  "Memory",
@@ -16,6 +17,7 @@ from .asset_registry import asset_registry
16
17
  from .cpu import CPU
17
18
  from .disk import Disk
18
19
  from .gpu import GPU
20
+ from .gpu_amd import GPUAMD
19
21
  from .gpu_apple import GPUApple
20
22
  from .ipu import IPU
21
23
  from .memory import Memory
@@ -137,6 +137,47 @@ class GPUMemoryAllocated:
137
137
  return stats
138
138
 
139
139
 
140
class GPUMemoryAllocatedBytes:
    """GPU memory allocated in bytes for each GPU.

    Each sample holds NVML's ``used`` memory figure for every visible device,
    in device-index order; ``aggregate`` averages those figures per device.
    """

    name = "gpu.{}.memoryAllocatedBytes"
    samples: "Deque[List[float]]"

    def __init__(self, pid: int) -> None:
        # Process id handed to gpu_in_use_by_this_process() during aggregation.
        self.pid = pid
        self.samples = deque([])

    def sample(self) -> None:
        """Record the current used-memory reading of every device."""
        used_per_device = [
            pynvml.nvmlDeviceGetMemoryInfo(  # type: ignore
                pynvml.nvmlDeviceGetHandleByIndex(index)  # type: ignore
            ).used
            for index in range(pynvml.nvmlDeviceGetCount())  # type: ignore
        ]
        self.samples.append(used_per_device)

    def clear(self) -> None:
        """Drop all collected samples."""
        self.samples.clear()

    def aggregate(self) -> dict:
        """Return the mean used memory per device over the collected samples.

        The same mean is also reported under the
        ``gpu.process.<i>.memoryAllocatedBytes`` key whenever
        ``gpu_in_use_by_this_process(handle, self.pid)`` is true for device i.
        """
        if not self.samples:
            return {}
        result = {}
        for index in range(pynvml.nvmlDeviceGetCount()):  # type: ignore
            mean_used = aggregate_mean([reading[index] for reading in self.samples])
            result[self.name.format(index)] = mean_used

            handle = pynvml.nvmlDeviceGetHandleByIndex(index)  # type: ignore
            if gpu_in_use_by_this_process(handle, self.pid):
                result[self.name.format(f"process.{index}")] = mean_used
        return result
179
+
180
+
140
181
  class GPUUtilization:
141
182
  """GPU utilization in percent for each GPU."""
142
183
 
@@ -314,6 +355,7 @@ class GPU:
314
355
  self.name = self.__class__.__name__.lower()
315
356
  self.metrics: List[Metric] = [
316
357
  GPUMemoryAllocated(settings._stats_pid),
358
+ GPUMemoryAllocatedBytes(settings._stats_pid),
317
359
  GPUMemoryUtilization(settings._stats_pid),
318
360
  GPUUtilization(settings._stats_pid),
319
361
  GPUTemperature(settings._stats_pid),
@@ -0,0 +1,216 @@
1
+ import json
2
+ import logging
3
+ import shutil
4
+ import subprocess
5
+ import sys
6
+ import threading
7
+ from collections import deque
8
+ from typing import TYPE_CHECKING, Any, Dict, List, Union
9
+
10
+ if sys.version_info >= (3, 8):
11
+ from typing import Final, Literal
12
+ else:
13
+ from typing_extensions import Final, Literal
14
+
15
+ from wandb.sdk.lib import telemetry
16
+
17
+ from .aggregators import aggregate_mean
18
+ from .asset_registry import asset_registry
19
+ from .interfaces import Interface, Metric, MetricsMonitor
20
+
21
+ if TYPE_CHECKING:
22
+ from typing import Deque
23
+
24
+ from wandb.sdk.internal.settings_static import SettingsStatic
25
+
26
+
27
# Location of AMD's rocm-smi CLI: prefer the copy found on PATH, otherwise
# fall back to the conventional install location.
ROCM_SMI_CMD: Final[str] = shutil.which("rocm-smi") or "/usr/bin/rocm-smi"

# Module logger; parse/sampling failures below are reported through it.
logger = logging.getLogger(__name__)
31
+
32
+
33
def get_rocm_smi_stats() -> Dict[str, Any]:
    """Run ``rocm-smi -a --json`` and return its decoded JSON output.

    Only the first line of the stripped stdout is decoded.

    Raises:
        OSError: if the rocm-smi executable cannot be launched.
        subprocess.CalledProcessError: if rocm-smi exits with a non-zero status.
        ValueError: if that first line is not valid JSON (json.JSONDecodeError).
    """
    raw_output = subprocess.check_output(
        [str(ROCM_SMI_CMD), "-a", "--json"], universal_newlines=True
    )
    first_line, _, _ = raw_output.strip().partition("\n")
    return json.loads(first_line)  # type: ignore
39
+
40
+
41
+ _StatsKeys = Literal[
42
+ "gpu",
43
+ "memoryAllocated",
44
+ "temp",
45
+ "powerWatts",
46
+ "powerPercent",
47
+ ]
48
+ _Stats = Dict[_StatsKeys, float]
49
+
50
+
51
+ _InfoDict = Dict[str, Union[int, List[Dict[str, Any]]]]
52
+
53
+
54
class GPUAMDStats:
    """Stats for AMD GPU devices, sampled from ``rocm-smi`` JSON output.

    Every entry of ``samples`` is a list holding one parsed ``_Stats`` dict per
    card, ordered by card number. A card's dict may be empty or may lack
    metrics that failed to parse. ``aggregate`` averages each metric per card
    and reports it as ``gpu.<index>.<metric>``.
    """

    name = "gpu.{gpu_id}.{key}"
    samples: "Deque[List[_Stats]]"

    def __init__(self) -> None:
        self.samples = deque()

    @staticmethod
    def parse_stats(stats: Dict[str, str]) -> _Stats:
        """Parse stats from rocm-smi output.

        Args:
            stats: the JSON object rocm-smi reports for a single card.

        Returns:
            The metrics that could be converted to float. A metric that is
            missing or unparsable is logged as a warning and omitted.
        """
        parsed_stats: _Stats = {}

        try:
            parsed_stats["gpu"] = float(stats.get("GPU use (%)"))  # type: ignore
        except (TypeError, ValueError):
            logger.warning("Could not parse GPU usage as float")
        try:
            parsed_stats["memoryAllocated"] = float(stats.get("GPU memory use (%)"))  # type: ignore
        except (TypeError, ValueError):
            logger.warning("Could not parse GPU memory allocation as float")
        try:
            parsed_stats["temp"] = float(stats.get("Temperature (Sensor memory) (C)"))  # type: ignore
        except (TypeError, ValueError):
            logger.warning("Could not parse GPU temperature as float")
        try:
            parsed_stats["powerWatts"] = float(
                stats.get("Average Graphics Package Power (W)")  # type: ignore
            )
        except (TypeError, ValueError):
            logger.warning("Could not parse GPU power as float")
        try:
            parsed_stats["powerPercent"] = (
                float(stats.get("Average Graphics Package Power (W)"))  # type: ignore
                / float(stats.get("Max Graphics Package Power (W)"))  # type: ignore
                * 100
            )
        except (TypeError, ValueError, ZeroDivisionError):
            # A reported max power of 0 raises ZeroDivisionError. sample() does
            # not catch that exception, so it must be handled here.
            logger.warning("Could not parse GPU average/max power as float")

        return parsed_stats

    @staticmethod
    def _card_order(card_key: str) -> tuple:
        """Sort key that puts "card2" before "card10" (numeric, not lexicographic)."""
        suffix = card_key[len("card") :]
        if suffix.isdigit():
            return (0, int(suffix), card_key)
        return (1, 0, card_key)

    def sample(self) -> None:
        """Query rocm-smi once and append the parsed per-card stats."""
        try:
            raw_stats = get_rocm_smi_stats()
            card_keys = sorted(
                (key for key in raw_stats.keys() if key.startswith("card")),
                key=self._card_order,
            )

            cards = []
            for card_key in card_keys:
                card_stats = raw_stats[card_key]
                # Keep one entry per card, even an empty one, so a list
                # position always refers to the same card across samples.
                if isinstance(card_stats, dict):
                    cards.append(self.parse_stats(card_stats))
                else:
                    cards.append({})

            # Only record the sample if at least one card yielded a metric.
            if any(cards):
                self.samples.append(cards)

        except (OSError, ValueError, TypeError, subprocess.CalledProcessError) as e:
            logger.exception(f"GPU stats error: {e}")

    def clear(self) -> None:
        """Drop all collected samples."""
        self.samples.clear()

    def aggregate(self) -> dict:
        """Average every metric per card over the collected samples.

        Samples may differ in how many cards they contain and in which metrics
        parsed successfully. Each metric is therefore averaged only over the
        samples that actually contain a value for it.
        """
        if not self.samples:
            return {}
        stats = {}
        device_count = max(len(sample) for sample in self.samples)

        for i in range(device_count):
            device_samples = [
                sample[i] for sample in self.samples if i < len(sample)
            ]
            # Union of metric names seen for this card, in first-seen order.
            keys = dict.fromkeys(key for s in device_samples for key in s)
            for key in keys:
                values = [s[key] for s in device_samples if key in s]
                stats[self.name.format(gpu_id=i, key=key)] = aggregate_mean(values)

        return stats
136
+
137
+
138
@asset_registry.register
class GPUAMD:
    """Monitor AMD GPU devices through AMD's ``rocm-smi`` command-line tool.

    For the list of supported environments and devices, see
    https://github.com/RadeonOpenCompute/ROCm/blob/develop/docs/deploy/
    """

    def __init__(
        self,
        interface: "Interface",
        settings: "SettingsStatic",
        shutdown_event: threading.Event,
    ) -> None:
        self.name = type(self).__name__.lower()
        self.metrics: List[Metric] = [GPUAMDStats()]
        self.metrics_monitor = MetricsMonitor(
            self.name, self.metrics, interface, settings, shutdown_event
        )

        # Flag in telemetry that an AMD GPU is being monitored.
        record = telemetry.TelemetryRecord()
        record.env.amd_gpu = True
        interface._publish_telemetry(record)

    @classmethod
    def is_available(cls) -> bool:
        """Return True if rocm-smi exists and a stats query succeeds."""
        if shutil.which(ROCM_SMI_CMD) is None:
            return False
        try:
            get_rocm_smi_stats()
        except Exception:
            # Any failure while querying means the tool is not usable here.
            return False
        return True

    def start(self) -> None:
        """Start the metrics monitor."""
        self.metrics_monitor.start()

    def finish(self) -> None:
        """Stop the metrics monitor."""
        self.metrics_monitor.finish()

    def probe(self) -> dict:
        """Describe the AMD GPUs reported by rocm-smi.

        Returns a dict with "gpu_count" and "gpu_devices". Each device entry
        keeps only those mapped rocm-smi fields that have a truthy value. On
        any error the exception is logged and whatever was gathered so far is
        returned.
        """
        info: _InfoDict = {}
        try:
            stats = get_rocm_smi_stats()
            card_keys = [key for key in stats.keys() if key.startswith("card")]
            info["gpu_count"] = len(card_keys)

            # Our field name -> rocm-smi field name.
            field_names = {
                "id": "GPU ID",
                "unique_id": "Unique ID",
                "vbios_version": "VBIOS version",
                "performance_level": "Performance Level",
                "gpu_overdrive": "GPU OverDrive value (%)",
                "gpu_memory_overdrive": "GPU Memory OverDrive value (%)",
                "max_power": "Max Graphics Package Power (W)",
                "series": "Card series",
                "model": "Card model",
                "vendor": "Card vendor",
                "sku": "Card SKU",
                "sclk_range": "Valid sclk range",
                "mclk_range": "Valid mclk range",
            }

            info["gpu_devices"] = [
                {
                    ours: stats[card][theirs]
                    for ours, theirs in field_names.items()
                    if stats[card].get(theirs)
                }
                for card in card_keys
            ]
        except Exception as e:
            logger.exception(f"GPUAMD probe error: {e}")
        return info
@@ -0,0 +1,13 @@
1
+ import logging
2
+
3
+ from sentry_sdk.integrations.aws_lambda import get_lambda_bootstrap # type: ignore
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+
8
def is_aws_lambda() -> bool:
    """Check if we are running in a lambda environment.

    Sentry's ``get_lambda_bootstrap`` helper is used. The answer is True only
    when it returns a truthy object that exposes ``handle_event_request``.
    """
    bootstrap = get_lambda_bootstrap()
    return bool(bootstrap) and hasattr(bootstrap, "handle_event_request")
@@ -18,7 +18,7 @@ from wandb.sdk.lib.filenames import (
18
18
  METADATA_FNAME,
19
19
  REQUIREMENTS_FNAME,
20
20
  )
21
- from wandb.sdk.lib.git import GitRepo
21
+ from wandb.sdk.lib.gitlib import GitRepo
22
22
 
23
23
  from .assets.interfaces import Interface
24
24
 
@@ -142,8 +142,8 @@ class SystemInfo:
142
142
  os.path.relpath(patch_path, start=self.settings.files_dir)
143
143
  )
144
144
 
145
- upstream_commit = self.git.get_upstream_fork_point() # type: ignore
146
- if upstream_commit and upstream_commit != self.git.repo.head.commit:
145
+ upstream_commit = self.git.get_upstream_fork_point()
146
+ if upstream_commit and upstream_commit != self.git.repo.head.commit: # type: ignore
147
147
  sha = upstream_commit.hexsha
148
148
  upstream_patch_path = os.path.join(
149
149
  self.settings.files_dir, f"upstream_diff_{sha}.patch"