wandb 0.15.3__py3-none-any.whl → 0.15.5__py3-none-any.whl
- wandb/__init__.py +1 -1
- wandb/analytics/sentry.py +1 -0
- wandb/apis/importers/base.py +20 -5
- wandb/apis/importers/mlflow.py +7 -1
- wandb/apis/internal.py +12 -0
- wandb/apis/public.py +247 -1387
- wandb/apis/reports/_panels.py +58 -35
- wandb/beta/workflows.py +6 -7
- wandb/cli/cli.py +130 -60
- wandb/data_types.py +3 -1
- wandb/filesync/dir_watcher.py +21 -27
- wandb/filesync/step_checksum.py +8 -8
- wandb/filesync/step_prepare.py +23 -10
- wandb/filesync/step_upload.py +13 -13
- wandb/filesync/upload_job.py +4 -8
- wandb/integration/cohere/__init__.py +3 -0
- wandb/integration/cohere/cohere.py +21 -0
- wandb/integration/cohere/resolver.py +347 -0
- wandb/integration/gym/__init__.py +4 -6
- wandb/integration/huggingface/__init__.py +3 -0
- wandb/integration/huggingface/huggingface.py +18 -0
- wandb/integration/huggingface/resolver.py +213 -0
- wandb/integration/langchain/wandb_tracer.py +16 -179
- wandb/integration/openai/__init__.py +1 -3
- wandb/integration/openai/openai.py +11 -143
- wandb/integration/openai/resolver.py +111 -38
- wandb/integration/sagemaker/config.py +2 -2
- wandb/integration/tensorboard/log.py +4 -4
- wandb/old/settings.py +24 -7
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/wandb_deprecated.py +3 -1
- wandb/sdk/__init__.py +1 -1
- wandb/sdk/artifacts/__init__.py +0 -0
- wandb/sdk/artifacts/artifact.py +2101 -0
- wandb/sdk/artifacts/artifact_download_logger.py +42 -0
- wandb/sdk/artifacts/artifact_manifest.py +67 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
- wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
- wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
- wandb/sdk/artifacts/artifact_state.py +10 -0
- wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
- wandb/sdk/artifacts/exceptions.py +55 -0
- wandb/sdk/artifacts/storage_handler.py +59 -0
- wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
- wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
- wandb/sdk/artifacts/storage_layout.py +6 -0
- wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
- wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
- wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
- wandb/sdk/data_types/_dtypes.py +7 -12
- wandb/sdk/data_types/base_types/json_metadata.py +3 -2
- wandb/sdk/data_types/base_types/media.py +8 -8
- wandb/sdk/data_types/base_types/wb_value.py +12 -13
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
- wandb/sdk/data_types/helper_types/classes.py +6 -8
- wandb/sdk/data_types/helper_types/image_mask.py +5 -6
- wandb/sdk/data_types/histogram.py +4 -3
- wandb/sdk/data_types/html.py +3 -4
- wandb/sdk/data_types/image.py +11 -9
- wandb/sdk/data_types/molecule.py +5 -3
- wandb/sdk/data_types/object_3d.py +7 -5
- wandb/sdk/data_types/plotly.py +3 -2
- wandb/sdk/data_types/saved_model.py +11 -11
- wandb/sdk/data_types/trace_tree.py +5 -4
- wandb/sdk/data_types/utils.py +3 -5
- wandb/sdk/data_types/video.py +5 -4
- wandb/sdk/integration_utils/auto_logging.py +215 -0
- wandb/sdk/interface/interface.py +15 -15
- wandb/sdk/internal/file_pusher.py +8 -16
- wandb/sdk/internal/file_stream.py +5 -11
- wandb/sdk/internal/handler.py +13 -1
- wandb/sdk/internal/internal_api.py +287 -13
- wandb/sdk/internal/job_builder.py +119 -30
- wandb/sdk/internal/sender.py +6 -26
- wandb/sdk/internal/settings_static.py +2 -0
- wandb/sdk/internal/system/assets/__init__.py +2 -0
- wandb/sdk/internal/system/assets/gpu.py +42 -0
- wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
- wandb/sdk/internal/system/env_probe_helpers.py +13 -0
- wandb/sdk/internal/system/system_info.py +3 -3
- wandb/sdk/internal/tb_watcher.py +32 -22
- wandb/sdk/internal/thread_local_settings.py +18 -0
- wandb/sdk/launch/_project_spec.py +57 -11
- wandb/sdk/launch/agent/agent.py +147 -65
- wandb/sdk/launch/agent/job_status_tracker.py +34 -0
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
- wandb/sdk/launch/builder/abstract.py +5 -1
- wandb/sdk/launch/builder/build.py +21 -18
- wandb/sdk/launch/builder/docker_builder.py +10 -4
- wandb/sdk/launch/builder/kaniko_builder.py +113 -23
- wandb/sdk/launch/builder/noop.py +6 -3
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
- wandb/sdk/launch/environment/aws_environment.py +3 -2
- wandb/sdk/launch/environment/azure_environment.py +124 -0
- wandb/sdk/launch/environment/gcp_environment.py +2 -4
- wandb/sdk/launch/environment/local_environment.py +1 -1
- wandb/sdk/launch/errors.py +19 -0
- wandb/sdk/launch/github_reference.py +32 -19
- wandb/sdk/launch/launch.py +3 -8
- wandb/sdk/launch/launch_add.py +6 -2
- wandb/sdk/launch/loader.py +21 -2
- wandb/sdk/launch/registry/azure_container_registry.py +132 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
- wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
- wandb/sdk/launch/registry/local_registry.py +2 -1
- wandb/sdk/launch/runner/abstract.py +24 -3
- wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
- wandb/sdk/launch/runner/local_container.py +103 -51
- wandb/sdk/launch/runner/local_process.py +1 -1
- wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
- wandb/sdk/launch/runner/vertex_runner.py +10 -5
- wandb/sdk/launch/sweeps/__init__.py +7 -9
- wandb/sdk/launch/sweeps/scheduler.py +307 -77
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +82 -35
- wandb/sdk/launch/utils.py +89 -75
- wandb/sdk/lib/_settings_toposort_generated.py +7 -0
- wandb/sdk/lib/capped_dict.py +26 -0
- wandb/sdk/lib/{git.py → gitlib.py} +76 -59
- wandb/sdk/lib/hashutil.py +12 -4
- wandb/sdk/lib/paths.py +96 -8
- wandb/sdk/lib/sock_client.py +2 -2
- wandb/sdk/lib/timer.py +1 -0
- wandb/sdk/service/server.py +22 -9
- wandb/sdk/service/server_sock.py +1 -1
- wandb/sdk/service/service.py +27 -8
- wandb/sdk/verify/verify.py +4 -7
- wandb/sdk/wandb_config.py +2 -6
- wandb/sdk/wandb_init.py +57 -53
- wandb/sdk/wandb_require.py +7 -0
- wandb/sdk/wandb_run.py +61 -223
- wandb/sdk/wandb_settings.py +28 -4
- wandb/testing/relay.py +15 -2
- wandb/util.py +74 -36
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
- wandb/integration/langchain/util.py +0 -191
- wandb/sdk/interface/artifacts/__init__.py +0 -33
- wandb/sdk/interface/artifacts/artifact.py +0 -615
- wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
- wandb/sdk/wandb_artifacts.py +0 -2226
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
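
Most of the churn above comes from one refactor: wandb/sdk/wandb_artifacts.py and wandb/sdk/interface/artifacts/ are removed, replaced by the new wandb/sdk/artifacts/ package. Code that reached into the old internal module path must switch imports; a minimal compatibility sketch (the public wandb.Artifact entry point is assumed to be unaffected):

    # Hedged sketch: support both wheel versions when importing the internal class.
    try:
        # 0.15.4+: Artifact moved into the new artifacts package (see diff below)
        from wandb.sdk.artifacts.artifact import Artifact
    except ImportError:
        # <= 0.15.3: old internal location, removed in this release
        from wandb.sdk.wandb_artifacts import Artifact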
wandb/sdk/internal/job_builder.py CHANGED
@@ -1,20 +1,23 @@
 """job builder."""
 import json
+import logging
 import os
 import sys
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

+from wandb.sdk.artifacts.artifact import Artifact
 from wandb.sdk.data_types._dtypes import TypeRegistry
 from wandb.sdk.lib.filenames import DIFF_FNAME, METADATA_FNAME, REQUIREMENTS_FNAME
-from wandb.sdk.wandb_artifacts import Artifact
 from wandb.util import make_artifact_name_safe

 from .settings_static import SettingsStatic

 if sys.version_info >= (3, 8):
-    from typing import TypedDict
+    from typing import Literal, TypedDict
 else:
-    from typing_extensions import TypedDict
+    from typing_extensions import Literal, TypedDict
+
+_logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
     from wandb.proto.wandb_internal_pb2 import ArtifactRecord
@@ -32,11 +35,13 @@ class GitInfo(TypedDict):
 class GitSourceDict(TypedDict):
     git: GitInfo
     entrypoint: List[str]
+    notebook: bool


 class ArtifactSourceDict(TypedDict):
     artifact: str
     entrypoint: List[str]
+    notebook: bool


 class ImageSourceDict(TypedDict):
@@ -71,6 +76,7 @@ class JobBuilder:
     _summary: Optional[Dict[str, Any]]
     _logged_code_artifact: Optional[ArtifactInfoForJob]
     _disable: bool
+    _aliases: List[str]

     def __init__(self, settings: SettingsStatic):
         self._settings = settings
@@ -80,6 +86,10 @@ class JobBuilder:
         self._summary = None
         self._logged_code_artifact = None
         self._disable = settings.disable_job_creation
+        self._aliases = []
+        self._source_type: Optional[
+            Literal["repo", "artifact", "image"]
+        ] = settings.get("job_source")

     def set_config(self, config: Dict[str, Any]) -> None:
         self._config = config
@@ -107,19 +117,50 @@ class JobBuilder:
         )

     def _build_repo_job(
-        self, metadata: Dict[str, Any], program_relpath: str
-    ) -> Tuple[Artifact, GitSourceDict]:
+        self, metadata: Dict[str, Any], program_relpath: str, root: Optional[str]
+    ) -> Tuple[Optional[Artifact], Optional[GitSourceDict]]:
         git_info: Dict[str, str] = metadata.get("git", {})
         remote = git_info.get("remote")
         commit = git_info.get("commit")
         assert remote is not None
         assert commit is not None
+        if self._is_notebook_run():
+            if not os.path.exists(
+                os.path.join(os.getcwd(), os.path.basename(program_relpath))
+            ):
+                return None, None
+
+            if root is None or self._settings._jupyter_root is None:
+                _logger.info("target path does not exist, exiting")
+                return None, None
+            assert self._settings._jupyter_root is not None
+            # git notebooks set the root to the git root,
+            # jupyter_root contains the path where the jupyter notebook was started
+            # program_relpath contains the path from jupyter_root to the file
+            # full program path here is actually the relpath from the program to the git root
+            full_program_path = os.path.join(
+                os.path.relpath(str(self._settings._jupyter_root), root),
+                program_relpath,
+            )
+            full_program_path = os.path.normpath(full_program_path)
+            # if the notebook server is started above the git repo need to clear all the ..s
+            if full_program_path.startswith(".."):
+                split_path = full_program_path.split("/")
+                count_dots = 0
+                for p in split_path:
+                    if p == "..":
+                        count_dots += 1
+                full_program_path = "/".join(split_path[2 * count_dots :])
+
+        else:
+            full_program_path = program_relpath
         # TODO: update executable to a method that supports pex
         source: GitSourceDict = {
             "entrypoint": [
                 os.path.basename(sys.executable),
-                program_relpath,
+                full_program_path,
             ],
+            "notebook": self._is_notebook_run(),
             "git": {
                 "remote": remote,
                 "commit": commit,
@@ -132,22 +173,40 @@ class JobBuilder:
         if os.path.exists(os.path.join(self._settings.files_dir, DIFF_FNAME)):
             artifact.add_file(
                 os.path.join(self._settings.files_dir, DIFF_FNAME),
-                name=
+                name=DIFF_FNAME,
             )
         return artifact, source

     def _build_artifact_job(
-        self, program_relpath: str
-    ) -> Tuple[Artifact, ArtifactSourceDict]:
+        self, metadata: Dict[str, Any], program_relpath: str
+    ) -> Tuple[Optional[Artifact], Optional[ArtifactSourceDict]]:
         assert isinstance(self._logged_code_artifact, dict)
+        # TODO: should we just always exit early if the path doesn't exist?
+        if self._is_notebook_run() and not self._is_colab_run():
+            full_program_relpath = os.path.relpath(program_relpath, os.getcwd())
+            # if the resolved path doesn't exist, then we shouldn't make a job because it will fail
+            if not os.path.exists(full_program_relpath):
+                # when users call log code in a notebook the code artifact starts
+                # at the directory the notebook is in instead of the jupyter
+                # core
+                if os.path.exists(os.path.basename(program_relpath)):
+                    full_program_relpath = os.path.basename(program_relpath)
+                else:
+                    _logger.info("target path does not exist, exiting")
+                    return None, None
+        else:
+            full_program_relpath = program_relpath
+        entrypoint = [
+            os.path.basename(sys.executable),
+            full_program_relpath,
+        ]
         # TODO: update executable to a method that supports pex
         source: ArtifactSourceDict = {
-            "entrypoint": [
-                os.path.basename(sys.executable),
-                program_relpath,
-            ],
+            "entrypoint": entrypoint,
+            "notebook": self._is_notebook_run(),
             "artifact": f"wandb-artifact://_id/{self._logged_code_artifact['id']}",
         }
+
         name = make_artifact_name_safe(f"job-{self._logged_code_artifact['name']}")

         artifact = JobArtifact(name)
@@ -158,14 +217,27 @@ class JobBuilder:
     ) -> Tuple[Artifact, ImageSourceDict]:
         image_name = metadata.get("docker")
         assert isinstance(image_name, str)
-        name = make_artifact_name_safe(f"job-{image_name}")
+
+        raw_image_name = image_name
+        if ":" in image_name:
+            raw_image_name, tag = image_name.split(":")
+            self._aliases += [tag]
+
+        name = make_artifact_name_safe(f"job-{raw_image_name}")
         artifact = JobArtifact(name)
         source: ImageSourceDict = {
             "image": image_name,
         }
         return artifact, source

+    def _is_notebook_run(self) -> bool:
+        return hasattr(self._settings, "_jupyter") and bool(self._settings._jupyter)
+
+    def _is_colab_run(self) -> bool:
+        return hasattr(self._settings, "_colab") and bool(self._settings._colab)
+
     def build(self) -> Optional[Artifact]:
+        _logger.info("Attempting to build job artifact")
         if not os.path.exists(
             os.path.join(self._settings.files_dir, REQUIREMENTS_FNAME)
         ):
@@ -181,23 +253,40 @@ class JobBuilder:

         program_relpath: Optional[str] = metadata.get("codePath")

+        source_type = self._source_type
+
+        if self._is_notebook_run():
+            _logger.info("run is notebook based run")
+            program_relpath = metadata.get("program")
+
+        if not source_type:
+            if self._has_git_job_ingredients(metadata, program_relpath):
+                _logger.info("is repo sourced job")
+                source_type = "repo"
+            elif self._has_artifact_job_ingredients(program_relpath):
+                _logger.info("is artifact sourced job")
+                source_type = "artifact"
+            elif self._has_image_job_ingredients(metadata):
+                _logger.info("is image sourced job")
+                source_type = "image"
+
+        if not source_type:
+            _logger.info("no source found")
+            return None
+
         artifact = None
-        source_type = None
         source: Optional[
             Union[GitSourceDict, ArtifactSourceDict, ImageSourceDict]
         ] = None
-
-        if self._has_git_job_ingredients(metadata, program_relpath):
+        if source_type == "repo":
             assert program_relpath is not None
-            artifact, source = self._build_repo_job(metadata, program_relpath)
-            source_type = "repo"
-        elif self._has_artifact_job_ingredients(program_relpath):
+            root: Optional[str] = metadata.get("root")
+            artifact, source = self._build_repo_job(metadata, program_relpath, root)
+        elif source_type == "artifact":
             assert program_relpath is not None
-            artifact, source = self._build_artifact_job(program_relpath)
-            source_type = "artifact"
-        elif self._has_image_job_ingredients(metadata):
+            artifact, source = self._build_artifact_job(metadata, program_relpath)
+        elif source_type == "image":
             artifact, source = self._build_image_job(metadata)
-            source_type = "image"

         if artifact is None or source_type is None or source is None:
             return None
@@ -213,7 +302,7 @@ class JobBuilder:
             "output_types": output_types,
             "runtime": runtime,
         }
-
+        _logger.info("adding wandb-job metadata file")
         with artifact.new_file("wandb-job.json") as f:
             f.write(json.dumps(source_info, indent=4))

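For orientation, wandb-job.json is the manifest the job artifact carries. A plausible payload for a repo-sourced job, assembled from the keys visible in this hunk and the GitSourceDict fields above (the concrete values and the input/output type payloads are illustrative only):

    {
        "source_type": "repo",
        "source": {
            "git": {"remote": "https://github.com/org/repo.git", "commit": "abc1234"},
            "entrypoint": ["python3", "train.py"],
            "notebook": false
        },
        "input_types": {"params": {"lr": "number"}},
        "output_types": {"loss": "number"},
        "runtime": "3.10.11"
    }
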
@@ -238,11 +327,11 @@ class JobBuilder:
         self, metadata: Dict[str, Any], program_relpath: Optional[str]
     ) -> bool:
         git_info: Dict[str, str] = metadata.get("git", {})
-        return (
-            git_info.get("remote") is not None
-            and git_info.get("commit") is not None
-            and program_relpath is not None
-        )
+        if program_relpath is None:
+            return False
+        if self._is_notebook_run() and metadata.get("root") is None:
+            return False
+        return git_info.get("remote") is not None and git_info.get("commit") is not None

     def _has_artifact_job_ingredients(self, program_relpath: Optional[str]) -> bool:
         return self._logged_code_artifact is not None and program_relpath is not None
wandb/sdk/internal/sender.py CHANGED
@@ -30,16 +30,10 @@ from wandb.errors import CommError, UsageError
 from wandb.errors.util import ProtobufErrorHandler
 from wandb.filesync.dir_watcher import DirWatcher
 from wandb.proto import wandb_internal_pb2
+from wandb.sdk.artifacts import artifact_saver
 from wandb.sdk.interface import interface
 from wandb.sdk.interface.interface_queue import InterfaceQueue
-from wandb.sdk.internal import (
-    artifact_saver,
-    context,
-    datastore,
-    file_stream,
-    internal_api,
-    update,
-)
+from wandb.sdk.internal import context, datastore, file_stream, internal_api, update
 from wandb.sdk.internal.file_pusher import FilePusher
 from wandb.sdk.internal.job_builder import JobBuilder
 from wandb.sdk.internal.settings_static import SettingsDict, SettingsStatic
@@ -486,24 +480,6 @@ class SendManager:
         result.response.check_version_response.delete_message = delete_message
         self._respond_result(result)

-    def _send_request_attach(
-        self,
-        req: wandb_internal_pb2.AttachRequest,
-        resp: wandb_internal_pb2.AttachResponse,
-    ) -> None:
-        attach_id = req.attach_id
-        assert attach_id
-        assert self._run
-        resp.run.CopyFrom(self._run)
-
-    def send_request_attach(self, record: "Record") -> None:
-        assert record.control.req_resp or record.control.mailbox_slot
-        result = proto_util._result_from_record(record)
-        self._send_request_attach(
-            record.request.attach, result.response.attach_response
-        )
-        self._respond_result(result)
-
     def send_request_stop_status(self, record: "Record") -> None:
         result = proto_util._result_from_record(record)
         status_resp = result.response.stop_status_response
@@ -1632,6 +1608,10 @@ class SendManager:
             # TODO: this should be removed when the latest tag is handled
             # by the backend (WB-12116)
             proto_artifact.aliases.append("latest")
+            # add docker image tag
+            for alias in self._job_builder._aliases:
+                proto_artifact.aliases.append(alias)
+
             proto_artifact.user_created = True
             proto_artifact.use_after_commit = True
             proto_artifact.finalize = True
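Combined with the tag splitting in _build_image_job above, an image-sourced job logged from my-image:prod now gets the alias "prod" alongside the hardcoded "latest". A standalone sketch mirroring that parsing (illustrative only; it also inherits the original code's assumption that the image name contains at most one colon):

    image_name = "my-image:prod"
    aliases = []
    raw_image_name = image_name
    if ":" in image_name:
        raw_image_name, tag = image_name.split(":")
        aliases.append(tag)
    # artifact name becomes "job-my-image"; aliases == ["prod"]
    print(f"job-{raw_image_name}", aliases)
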
wandb/sdk/internal/settings_static.py CHANGED
@@ -8,6 +8,7 @@ class SettingsStatic:
     # TODO(jhr): figure out how to share type defs with sdk/wandb_settings.py
     _offline: Optional[bool]
     _sync: bool
+    _disable_setproctitle: bool
     _disable_stats: Optional[bool]
     _disable_meta: Optional[bool]
     _flow_control: bool
@@ -65,6 +66,7 @@ class SettingsStatic:
     disable_job_creation: bool
     _async_upload_concurrency_limit: Optional[int]
     _extra_http_headers: Optional[Mapping[str, str]]
+    job_source: Optional[str]

     # TODO(jhr): clean this up, it is only in SettingsStatic and not in Settings
     _log_level: int
wandb/sdk/internal/system/assets/__init__.py CHANGED
@@ -3,6 +3,7 @@ __all__ = (
     "CPU",
     "Disk",
     "GPU",
+    "GPUAMD",
     "GPUApple",
     "IPU",
     "Memory",
@@ -16,6 +17,7 @@ from .asset_registry import asset_registry
 from .cpu import CPU
 from .disk import Disk
 from .gpu import GPU
+from .gpu_amd import GPUAMD
 from .gpu_apple import GPUApple
 from .ipu import IPU
 from .memory import Memory
wandb/sdk/internal/system/assets/gpu.py CHANGED
@@ -137,6 +137,47 @@ class GPUMemoryAllocated:
         return stats


+class GPUMemoryAllocatedBytes:
+    """GPU memory allocated in bytes for each GPU."""
+
+    # name = "memory_allocated"
+    name = "gpu.{}.memoryAllocatedBytes"
+    # samples: Deque[Tuple[datetime.datetime, float]]
+    samples: "Deque[List[float]]"
+
+    def __init__(self, pid: int) -> None:
+        self.pid = pid
+        self.samples = deque([])
+
+    def sample(self) -> None:
+        memory_allocated = []
+        device_count = pynvml.nvmlDeviceGetCount()  # type: ignore
+        for i in range(device_count):
+            handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # type: ignore
+            memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)  # type: ignore
+            memory_allocated.append(memory_info.used)
+        self.samples.append(memory_allocated)
+
+    def clear(self) -> None:
+        self.samples.clear()
+
+    def aggregate(self) -> dict:
+        if not self.samples:
+            return {}
+        stats = {}
+        device_count = pynvml.nvmlDeviceGetCount()  # type: ignore
+        for i in range(device_count):
+            samples = [sample[i] for sample in self.samples]
+            aggregate = aggregate_mean(samples)
+            stats[self.name.format(i)] = aggregate
+
+            handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # type: ignore
+            if gpu_in_use_by_this_process(handle, self.pid):
+                stats[self.name.format(f"process.{i}")] = aggregate
+
+        return stats
+
+
 class GPUUtilization:
     """GPU utilization in percent for each GPU."""

@@ -314,6 +355,7 @@ class GPU:
         self.name = self.__class__.__name__.lower()
         self.metrics: List[Metric] = [
             GPUMemoryAllocated(settings._stats_pid),
+            GPUMemoryAllocatedBytes(settings._stats_pid),
             GPUMemoryUtilization(settings._stats_pid),
             GPUUtilization(settings._stats_pid),
             GPUTemperature(settings._stats_pid),
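The new per-GPU series is published with the run's other system metrics. A hedged sketch for reading it back through the public API (the run path is a placeholder; the key follows the gpu.{index}.memoryAllocatedBytes template defined above):

    import wandb

    api = wandb.Api()
    run = api.run("entity/project/run_id")  # placeholder run path
    # system metrics live in the "events" stream rather than the default history
    for row in run.history(stream="events", pandas=False):
        if "gpu.0.memoryAllocatedBytes" in row:
            print(row["_timestamp"], row["gpu.0.memoryAllocatedBytes"])
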
wandb/sdk/internal/system/assets/gpu_amd.py ADDED
@@ -0,0 +1,216 @@
+import json
+import logging
+import shutil
+import subprocess
+import sys
+import threading
+from collections import deque
+from typing import TYPE_CHECKING, Any, Dict, List, Union
+
+if sys.version_info >= (3, 8):
+    from typing import Final, Literal
+else:
+    from typing_extensions import Final, Literal
+
+from wandb.sdk.lib import telemetry
+
+from .aggregators import aggregate_mean
+from .asset_registry import asset_registry
+from .interfaces import Interface, Metric, MetricsMonitor
+
+if TYPE_CHECKING:
+    from typing import Deque
+
+    from wandb.sdk.internal.settings_static import SettingsStatic
+
+
+logger = logging.getLogger(__name__)
+
+
+ROCM_SMI_CMD: Final[str] = shutil.which("rocm-smi") or "/usr/bin/rocm-smi"
+
+
+def get_rocm_smi_stats() -> Dict[str, Any]:
+    command = [str(ROCM_SMI_CMD), "-a", "--json"]
+    output = (
+        subprocess.check_output(command, universal_newlines=True).strip().split("\n")
+    )[0]
+    return json.loads(output)  # type: ignore
+
+
+_StatsKeys = Literal[
+    "gpu",
+    "memoryAllocated",
+    "temp",
+    "powerWatts",
+    "powerPercent",
+]
+_Stats = Dict[_StatsKeys, float]
+
+
+_InfoDict = Dict[str, Union[int, List[Dict[str, Any]]]]
+
+
+class GPUAMDStats:
+    """Stats for AMD GPU devices."""
+
+    name = "gpu.{gpu_id}.{key}"
+    samples: "Deque[List[_Stats]]"
+
+    def __init__(self) -> None:
+        self.samples = deque()
+
+    @staticmethod
+    def parse_stats(stats: Dict[str, str]) -> _Stats:
+        """Parse stats from rocm-smi output."""
+        parsed_stats: _Stats = {}
+
+        try:
+            parsed_stats["gpu"] = float(stats.get("GPU use (%)"))  # type: ignore
+        except (TypeError, ValueError):
+            logger.warning("Could not parse GPU usage as float")
+        try:
+            parsed_stats["memoryAllocated"] = float(stats.get("GPU memory use (%)"))  # type: ignore
+        except (TypeError, ValueError):
+            logger.warning("Could not parse GPU memory allocation as float")
+        try:
+            parsed_stats["temp"] = float(stats.get("Temperature (Sensor memory) (C)"))  # type: ignore
+        except (TypeError, ValueError):
+            logger.warning("Could not parse GPU temperature as float")
+        try:
+            parsed_stats["powerWatts"] = float(
+                stats.get("Average Graphics Package Power (W)")  # type: ignore
+            )
+        except (TypeError, ValueError):
+            logger.warning("Could not parse GPU power as float")
+        try:
+            parsed_stats["powerPercent"] = (
+                float(stats.get("Average Graphics Package Power (W)"))  # type: ignore
+                / float(stats.get("Max Graphics Package Power (W)"))  # type: ignore
+                * 100
+            )
+        except (TypeError, ValueError):
+            logger.warning("Could not parse GPU average/max power as float")
+
+        return parsed_stats
+
+    def sample(self) -> None:
+        try:
+            raw_stats = get_rocm_smi_stats()
+            cards = []
+
+            card_keys = [
+                key for key in sorted(raw_stats.keys()) if key.startswith("card")
+            ]
+
+            for card_key in card_keys:
+                card_stats = raw_stats[card_key]
+                stats = self.parse_stats(card_stats)
+                if stats:
+                    cards.append(stats)
+
+            if cards:
+                self.samples.append(cards)
+
+        except (OSError, ValueError, TypeError, subprocess.CalledProcessError) as e:
+            logger.exception(f"GPU stats error: {e}")
+
+    def clear(self) -> None:
+        self.samples.clear()
+
+    def aggregate(self) -> dict:
+        if not self.samples:
+            return {}
+        stats = {}
+        device_count = len(self.samples[0])
+
+        for i in range(device_count):
+            samples = [sample[i] for sample in self.samples]
+
+            for key in samples[0].keys():
+                samples_key = [s[key] for s in samples]
+                aggregate = aggregate_mean(samples_key)
+                stats[self.name.format(gpu_id=i, key=key)] = aggregate
+
+        return stats
+
+
+@asset_registry.register
+class GPUAMD:
+    """GPUAMD is a class for monitoring AMD GPU devices.
+
+    Uses AMD's rocm_smi tool to get GPU stats.
+    For the list of supported environments and devices, see
+    https://github.com/RadeonOpenCompute/ROCm/blob/develop/docs/deploy/
+    """
+
+    def __init__(
+        self,
+        interface: "Interface",
+        settings: "SettingsStatic",
+        shutdown_event: threading.Event,
+    ) -> None:
+        self.name = self.__class__.__name__.lower()
+        self.metrics: List[Metric] = [
+            GPUAMDStats(),
+        ]
+        self.metrics_monitor = MetricsMonitor(
+            self.name,
+            self.metrics,
+            interface,
+            settings,
+            shutdown_event,
+        )
+        telemetry_record = telemetry.TelemetryRecord()
+        telemetry_record.env.amd_gpu = True
+        interface._publish_telemetry(telemetry_record)
+
+    @classmethod
+    def is_available(cls) -> bool:
+        rocm_smi_available = shutil.which(ROCM_SMI_CMD) is not None
+        if rocm_smi_available:
+            try:
+                _ = get_rocm_smi_stats()
+                return True
+            except Exception:
+                pass
+        return False
+
+    def start(self) -> None:
+        self.metrics_monitor.start()
+
+    def finish(self) -> None:
+        self.metrics_monitor.finish()
+
+    def probe(self) -> dict:
+        info: _InfoDict = {}
+        try:
+            stats = get_rocm_smi_stats()
+
+            info["gpu_count"] = len(
+                [key for key in stats.keys() if key.startswith("card")]
+            )
+            key_mapping = {
+                "id": "GPU ID",
+                "unique_id": "Unique ID",
+                "vbios_version": "VBIOS version",
+                "performance_level": "Performance Level",
+                "gpu_overdrive": "GPU OverDrive value (%)",
+                "gpu_memory_overdrive": "GPU Memory OverDrive value (%)",
+                "max_power": "Max Graphics Package Power (W)",
+                "series": "Card series",
+                "model": "Card model",
+                "vendor": "Card vendor",
+                "sku": "Card SKU",
+                "sclk_range": "Valid sclk range",
+                "mclk_range": "Valid mclk range",
+            }
+
+            info["gpu_devices"] = [
+                {k: stats[key][v] for k, v in key_mapping.items() if stats[key].get(v)}
+                for key in stats.keys()
+                if key.startswith("card")
+            ]
+        except Exception as e:
+            logger.exception(f"GPUAMD probe error: {e}")
+        return info
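GPUAMDStats.parse_stats takes the flat per-card dict that rocm-smi -a --json emits for each cardN key. A self-contained check against a fabricated payload (field names are taken from the code above; the values are made up):

    from wandb.sdk.internal.system.assets.gpu_amd import GPUAMDStats

    fake_card = {  # fabricated rocm-smi output for one card
        "GPU use (%)": "12",
        "GPU memory use (%)": "34",
        "Temperature (Sensor memory) (C)": "56.0",
        "Average Graphics Package Power (W)": "75.0",
        "Max Graphics Package Power (W)": "300.0",
    }
    stats = GPUAMDStats.parse_stats(fake_card)
    assert stats["powerPercent"] == 25.0  # 75 W of a 300 W cap
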
wandb/sdk/internal/system/env_probe_helpers.py ADDED
@@ -0,0 +1,13 @@
+import logging
+
+from sentry_sdk.integrations.aws_lambda import get_lambda_bootstrap  # type: ignore
+
+logger = logging.getLogger(__name__)
+
+
+def is_aws_lambda() -> bool:
+    """Check if we are running in a lambda environment."""
+    lambda_bootstrap = get_lambda_bootstrap()
+    if not lambda_bootstrap or not hasattr(lambda_bootstrap, "handle_event_request"):
+        return False
+    return True
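A brief usage sketch; the wandb/analytics/sentry.py change in this release's file list is presumably the consumer, but that is an inference from the diff alone:

    from wandb.sdk.internal.system.env_probe_helpers import is_aws_lambda

    if is_aws_lambda():
        # e.g. adjust behavior for the Lambda sandbox, where no long-lived
        # background process is available
        print("running under AWS Lambda")
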
wandb/sdk/internal/system/system_info.py CHANGED
@@ -18,7 +18,7 @@ from wandb.sdk.lib.filenames import (
     METADATA_FNAME,
     REQUIREMENTS_FNAME,
 )
-from wandb.sdk.lib.git import GitRepo
+from wandb.sdk.lib.gitlib import GitRepo

 from .assets.interfaces import Interface

@@ -142,8 +142,8 @@ class SystemInfo:
             os.path.relpath(patch_path, start=self.settings.files_dir)
         )

-        upstream_commit = self.git.get_upstream_fork_point()
-        if upstream_commit and upstream_commit != self.git.repo.head.commit:
+        upstream_commit = self.git.get_upstream_fork_point()
+        if upstream_commit and upstream_commit != self.git.repo.head.commit:  # type: ignore
             sha = upstream_commit.hexsha
             upstream_patch_path = os.path.join(
                 self.settings.files_dir, f"upstream_diff_{sha}.patch"