wandb 0.17.5__py3-none-any.whl → 0.17.6__py3-none-any.whl
- wandb/__init__.py +3 -1
- wandb/apis/public/api.py +1 -1
- wandb/apis/public/jobs.py +5 -0
- wandb/bin/nvidia_gpu_stats +0 -0
- wandb/data_types.py +2 -1
- wandb/env.py +6 -0
- wandb/integration/lightning/fabric/logger.py +4 -4
- wandb/proto/v3/wandb_internal_pb2.py +226 -226
- wandb/proto/v3/wandb_settings_pb2.py +1 -1
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +226 -226
- wandb/proto/v4/wandb_settings_pb2.py +1 -1
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v5/wandb_internal_pb2.py +226 -226
- wandb/proto/v5/wandb_settings_pb2.py +1 -1
- wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
- wandb/proto/wandb_deprecated.py +4 -0
- wandb/proto/wandb_internal_pb2.py +6 -0
- wandb/sdk/artifacts/artifact.py +5 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +31 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +35 -23
- wandb/sdk/data_types/object_3d.py +113 -2
- wandb/sdk/interface/interface.py +23 -0
- wandb/sdk/internal/sender.py +31 -15
- wandb/sdk/launch/_launch.py +4 -2
- wandb/sdk/launch/_project_spec.py +34 -8
- wandb/sdk/launch/agent/agent.py +6 -2
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +2 -4
- wandb/sdk/launch/builder/build.py +4 -2
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +2 -1
- wandb/sdk/launch/inputs/internal.py +42 -28
- wandb/sdk/launch/inputs/schema.py +39 -0
- wandb/sdk/launch/runner/kubernetes_runner.py +72 -0
- wandb/sdk/launch/runner/local_container.py +13 -10
- wandb/sdk/launch/runner/sagemaker_runner.py +3 -5
- wandb/sdk/launch/utils.py +2 -0
- wandb/sdk/lib/disabled.py +13 -174
- wandb/sdk/wandb_init.py +23 -27
- wandb/sdk/wandb_login.py +6 -6
- wandb/sdk/wandb_run.py +41 -22
- wandb/sdk/wandb_settings.py +3 -2
- wandb/wandb_agent.py +2 -0
- {wandb-0.17.5.dist-info → wandb-0.17.6.dist-info}/METADATA +3 -2
- {wandb-0.17.5.dist-info → wandb-0.17.6.dist-info}/RECORD +47 -45
- {wandb-0.17.5.dist-info → wandb-0.17.6.dist-info}/WHEEL +0 -0
- {wandb-0.17.5.dist-info → wandb-0.17.6.dist-info}/entry_points.txt +0 -0
- {wandb-0.17.5.dist-info → wandb-0.17.6.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/launch/_project_spec.py
CHANGED
@@ -7,6 +7,7 @@ import enum
 import json
 import logging
 import os
+import shutil
 import tempfile
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
@@ -112,6 +113,9 @@ class LaunchProject:
         self.sweep_id = sweep_id
         self.author = launch_spec.get("author")
         self.python_version: Optional[str] = launch_spec.get("python_version")
+        self._job_dockerfile: Optional[str] = None
+        self._job_build_context: Optional[str] = None
+        self._job_base_image: Optional[str] = None
         self.accelerator_base_image: Optional[str] = resource_args_build.get(
             "accelerator", {}
         ).get("base_image") or resource_args_build.get("cuda", {}).get("base_image")
@@ -131,8 +135,6 @@ class LaunchProject:
         self._queue_name: Optional[str] = None
         self._queue_entity: Optional[str] = None
         self._run_queue_item_id: Optional[str] = None
-        self._job_dockerfile: Optional[str] = None
-        self._job_build_context: Optional[str] = None

     def init_source(self) -> None:
         if self.docker_image is not None:
@@ -146,6 +148,21 @@ class LaunchProject:
             self.project_dir = os.getcwd()
             self._entry_point = self.override_entrypoint

+    def change_project_dir(self, new_dir: str) -> None:
+        """Change the project directory to a new directory."""
+        # Copy the contents of the old project dir to the new project dir.
+        old_dir = self.project_dir
+        if old_dir is not None:
+            shutil.copytree(
+                old_dir,
+                new_dir,
+                symlinks=True,
+                dirs_exist_ok=True,
+                ignore=shutil.ignore_patterns("fsmonitor--daemon.ipc", ".git"),
+            )
+            shutil.rmtree(old_dir)
+        self.project_dir = new_dir
+
     def init_git(self, git_info: Dict[str, str]) -> None:
         self.git_version = git_info.get("version")
         self.git_repo = git_info.get("repo")
@@ -212,14 +229,23 @@ class LaunchProject:
     def job_build_context(self) -> Optional[str]:
         return self._job_build_context

+    @property
+    def job_base_image(self) -> Optional[str]:
+        return self._job_base_image
+
     def set_job_dockerfile(self, dockerfile: str) -> None:
         self._job_dockerfile = dockerfile

     def set_job_build_context(self, build_context: str) -> None:
         self._job_build_context = build_context

+    def set_job_base_image(self, base_image: str) -> None:
+        self._job_base_image = base_image
+
     @property
     def image_name(self) -> str:
+        if self.job_base_image is not None:
+            return self.job_base_image
         if self.docker_image is not None:
             return self.docker_image
         elif self.uri is not None:
@@ -299,10 +325,8 @@ class LaunchProject:

     def build_required(self) -> bool:
         """Checks the source to see if a build is required."""
-
-
-        # we don't need to build the image for a job if that tag
-        # already exists
+        if self.job_base_image is not None:
+            return False
         if self.source != LaunchSource.JOB:
             return True
         return False
@@ -316,7 +340,9 @@ class LaunchProject:
         Returns:
            Optional[str]: The Docker image or None if not specified.
        """
-
+        if self._docker_image:
+            return self._docker_image
+        return None

     @docker_image.setter
     def docker_image(self, value: str) -> None:
@@ -336,7 +362,7 @@ class LaunchProject:
         # assuming project only has 1 entry point, pull that out
         # tmp fn until we figure out if we want to support multiple entry points or not
         if not self._entry_point:
-            if not self.docker_image:
+            if not self.docker_image and not self.job_base_image:
                 raise LaunchError(
                     "Project must have at least one entry point unless docker image is specified."
                 )
wandb/sdk/launch/agent/agent.py
CHANGED
@@ -717,7 +717,7 @@ class LaunchAgent:
         _, build_config, registry_config = construct_agent_configs(
             default_config, override_build_config
         )
-        image_uri = project.docker_image
+        image_uri = project.docker_image or project.job_base_image
         entrypoint = project.get_job_entry_point()
         environment = loader.environment_from_config(
             default_config.get("environment", {})
@@ -727,7 +727,11 @@ class LaunchAgent:
         backend = loader.runner_from_config(
             resource, api, backend_config, environment, registry
         )
-        if not (project.docker_image or isinstance(backend, LocalProcessRunner)):
+        if not (
+            project.docker_image
+            or project.job_base_image
+            or isinstance(backend, LocalProcessRunner)
+        ):
             assert entrypoint is not None
             image_uri = await builder.build_image(project, entrypoint, job_tracker)

wandb/sdk/launch/agent/run_queue_item_file_saver.py
CHANGED
@@ -2,7 +2,7 @@

 import os
 import sys
-from typing import List, Optional, Union
+from typing import List, Optional

 import wandb

@@ -17,9 +17,7 @@ FileSubtypes = Literal["warning", "error"]
 class RunQueueItemFileSaver:
     def __init__(
         self,
-        agent_run: Optional[
-            Union["wandb.sdk.wandb_run.Run", "wandb.sdk.lib.RunDisabled"]
-        ],
+        agent_run: Optional["wandb.sdk.wandb_run.Run"],
         run_queue_item_id: str,
     ):
         self.run_queue_item_id = run_queue_item_id
wandb/sdk/launch/builder/build.py
CHANGED
@@ -201,7 +201,7 @@ def get_requirements_section(
     # If there is a requirements.txt at root of build context, use that.
     if (base_path / "src" / "requirements.txt").exists():
         requirements_files += ["src/requirements.txt"]
-        deps_install_line = "pip install -r requirements.txt"
+        deps_install_line = "pip install uv && uv pip install -r requirements.txt"
         with open(base_path / "src" / "requirements.txt") as f:
             requirements = f.readlines()
             if not any(["wandb" in r for r in requirements]):
@@ -237,7 +237,9 @@ def get_requirements_section(
         with open(base_path / "src" / "requirements.txt", "w") as f:
             f.write("\n".join(project_deps))
         requirements_files += ["src/requirements.txt"]
-        deps_install_line = "pip install -r requirements.txt"
+        deps_install_line = (
+            "pip install uv && uv pip install -r requirements.txt"
+        )
     return PIP_TEMPLATE.format(
         buildx_optional_prefix=prefix,
         requirements_files=" ".join(requirements_files),
wandb/sdk/launch/builder/templates/_wandb_bootstrap.py
CHANGED
@@ -39,12 +39,13 @@ def install_deps(
         deps (str[], None): The dependencies that failed to install
     """
     try:
+        subprocess.check_output(["pip", "install", "uv"], stderr=subprocess.STDOUT)
         # Include only uri if @ is present
         clean_deps = [d.split("@")[-1].strip() if "@" in d else d for d in deps]
         index_args = ["--extra-index-url", extra_index] if extra_index else []
         print("installing {}...".format(", ".join(clean_deps)))
         opts = opts or []
-        args = ["pip", "install"] + opts + clean_deps + index_args
+        args = ["uv", "pip", "install"] + opts + clean_deps + index_args
         sys.stdout.flush()
         subprocess.check_output(args, stderr=subprocess.STDOUT)
         return failed
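Both the generated requirements section and the in-container bootstrap now route dependency installs through uv rather than calling pip directly. A rough sketch of the equivalent bootstrap sequence, intended to run inside the launch-built container (the package list is hypothetical):

```python
import subprocess
import sys

# uv itself is installed with pip, then used for the actual dependency install.
failed_deps = ["requests==2.32.3"]  # hypothetical packages being retried
subprocess.check_output(["pip", "install", "uv"], stderr=subprocess.STDOUT)
args = ["uv", "pip", "install"] + failed_deps
sys.stdout.flush()
subprocess.check_output(args, stderr=subprocess.STDOUT)
```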
wandb/sdk/launch/inputs/internal.py
CHANGED
@@ -16,7 +16,9 @@ from typing import Any, Dict, List, Optional
 import wandb
 import wandb.data_types
 from wandb.sdk.launch.errors import LaunchError
+from wandb.sdk.launch.inputs.schema import META_SCHEMA
 from wandb.sdk.wandb_run import Run
+from wandb.util import get_module

 from .files import config_path_is_valid, override_file

@@ -129,7 +131,7 @@ def _publish_job_input(
     )


-def _replace_refs_and_allofs(schema: dict, defs: dict) -> dict:
+def _replace_refs_and_allofs(schema: dict, defs: Optional[dict]) -> dict:
     """Recursively fix JSON schemas with common issues.

     1. Replaces any instances of $ref with their associated definition in defs
@@ -137,7 +139,7 @@ def _replace_refs_and_allofs(schema: dict, defs: dict) -> dict:
     See test_internal.py for examples
     """
     ret: Dict[str, Any] = {}
-    if "$ref" in schema:
+    if "$ref" in schema and defs:
         # Reference found, replace it with its definition
         def_key = schema["$ref"].split("#/$defs/")[1]
         # Also run recursive replacement in case a ref contains more refs
@@ -170,12 +172,16 @@ def _replace_refs_and_allofs(schema: dict, defs: dict) -> dict:
     return ret


-def
-
-
-
-
-
+def _validate_schema(schema: dict) -> None:
+    jsonschema = get_module(
+        "jsonschema",
+        required="Setting job schema requires the jsonschema package. Please install it with `pip install 'wandb[launch]'`.",
+        lazy=False,
+    )
+    validator = jsonschema.Draft202012Validator(META_SCHEMA)
+    errs = sorted(validator.iter_errors(schema), key=str)
+    if errs:
+        wandb.termwarn(f"Schema includes unhandled or invalid configurations:\n{errs}")


 def handle_config_file_input(
@@ -204,16 +210,20 @@ def handle_config_file_input(
         path,
         dest,
     )
-
-
-
-    schema
-
-
-
-
-
-
+    if schema:
+        # This supports both an instance of a pydantic BaseModel class (e.g. schema=MySchema(...))
+        # or the BaseModel class itself (e.g. schema=MySchema)
+        if hasattr(schema, "model_json_schema") and callable(
+            schema.model_json_schema  # type: ignore
+        ):
+            schema = schema.model_json_schema()
+        if not isinstance(schema, dict):
+            raise LaunchError(
+                "schema must be a dict, Pydantic model instance, or Pydantic model class."
+            )
+        defs = schema.pop("$defs", None)
+        schema = _replace_refs_and_allofs(schema, defs)
+        _validate_schema(schema)
     arguments = JobInputArguments(
         include=include,
         exclude=exclude,
@@ -241,16 +251,20 @@ def handle_run_config_input(
     If there is no active run, the include and exclude paths are staged and sent
     when a run is created.
     """
-
-
-
-    schema
-
-
-
-
-
-
+    if schema:
+        # This supports both an instance of a pydantic BaseModel class (e.g. schema=MySchema(...))
+        # or the BaseModel class itself (e.g. schema=MySchema)
+        if hasattr(schema, "model_json_schema") and callable(
+            schema.model_json_schema  # type: ignore
+        ):
+            schema = schema.model_json_schema()
+        if not isinstance(schema, dict):
+            raise LaunchError(
+                "schema must be a dict, Pydantic model instance, or Pydantic model class."
+            )
+        defs = schema.pop("$defs", None)
+        schema = _replace_refs_and_allofs(schema, defs)
+        _validate_schema(schema)
     arguments = JobInputArguments(
         include=include,
         exclude=exclude,
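The new `schema` handling in `handle_config_file_input` and `handle_run_config_input` accepts either a plain dict or anything exposing Pydantic v2's `model_json_schema()` (a model instance or the class itself). A hedged sketch of that conversion path; `TrainingConfig` is a hypothetical user model, not part of wandb, and pydantic v2 is assumed to be installed:

```python
from pydantic import BaseModel


class TrainingConfig(BaseModel):
    learning_rate: float = 0.001
    batch_size: int = 32


schema = TrainingConfig  # a model class; an instance would also be accepted
if hasattr(schema, "model_json_schema") and callable(schema.model_json_schema):
    schema = schema.model_json_schema()

assert isinstance(schema, dict)
defs = schema.pop("$defs", None)  # inlined by _replace_refs_and_allofs above
print(schema["properties"]["learning_rate"])
```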
wandb/sdk/launch/inputs/schema.py
ADDED
@@ -0,0 +1,39 @@
+META_SCHEMA = {
+    "type": "object",
+    "properties": {
+        "type": {
+            "type": "string",
+            "enum": ["boolean", "integer", "number", "string", "object"],
+        },
+        "title": {"type": "string"},
+        "description": {"type": "string"},
+        "enum": {"type": "array", "items": {"type": ["integer", "number", "string"]}},
+        "properties": {"type": "object", "patternProperties": {".*": {"$ref": "#"}}},
+        "allOf": {"type": "array", "items": {"$ref": "#"}},
+    },
+    "allOf": [
+        {
+            "if": {"properties": {"type": {"const": "number"}}},
+            "then": {
+                "properties": {
+                    "minimum": {"type": ["integer", "number"]},
+                    "maximum": {"type": ["integer", "number"]},
+                    "exclusiveMinimum": {"type": ["integer", "number"]},
+                    "exclusiveMaximum": {"type": ["integer", "number"]},
+                }
+            },
+        },
+        {
+            "if": {"properties": {"type": {"const": "integer"}}},
+            "then": {
+                "properties": {
+                    "minimum": {"type": "integer"},
+                    "maximum": {"type": "integer"},
+                    "exclusiveMinimum": {"type": "integer"},
+                    "exclusiveMaximum": {"type": "integer"},
+                }
+            },
+        },
+    ],
+    "unevaluatedProperties": False,
+}
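`META_SCHEMA` constrains what launch accepts in a job input schema; `_validate_schema` above applies it with jsonschema's Draft 2020-12 validator and only warns on violations. A small sketch of the same check against a hypothetical user schema (requires the optional `jsonschema` dependency):

```python
import jsonschema  # optional dependency, pulled in by `pip install 'wandb[launch]'`

from wandb.sdk.launch.inputs.schema import META_SCHEMA

user_schema = {
    "type": "object",
    "properties": {
        "learning_rate": {"type": "number", "minimum": 0},
        "optimizer": {"type": "string", "enum": ["adam", "sgd"]},
    },
}

validator = jsonschema.Draft202012Validator(META_SCHEMA)
errors = sorted(validator.iter_errors(user_schema), key=str)
if errors:
    print(f"Schema includes unhandled or invalid configurations:\n{errors}")
```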
wandb/sdk/launch/runner/kubernetes_runner.py
CHANGED
@@ -31,6 +31,7 @@ from wandb.util import get_module
 from .._project_spec import EntryPoint, LaunchProject
 from ..errors import LaunchError
 from ..utils import (
+    CODE_MOUNT_DIR,
     LOG_PREFIX,
     MAX_ENV_LENGTHS,
     PROJECT_SYNCHRONOUS,
@@ -66,6 +67,10 @@ API_KEY_SECRET_MAX_RETRIES = 5
 _logger = logging.getLogger(__name__)


+SOURCE_CODE_PVC_MOUNT_PATH = os.environ.get("WANDB_LAUNCH_CODE_PVC_MOUNT_PATH")
+SOURCE_CODE_PVC_NAME = os.environ.get("WANDB_LAUNCH_CODE_PVC_NAME")
+
+
 class KubernetesSubmittedRun(AbstractRun):
     """Wrapper for a launched run on Kubernetes."""

@@ -468,6 +473,12 @@ class KubernetesRunner(AbstractRunner):
             "true",
         )

+        if launch_project.job_base_image:
+            apply_code_mount_configuration(
+                job,
+                launch_project,
+            )
+
         # Add wandb.ai/agent: current agent label on all pods
         if LaunchAgent.initialized():
             add_label_to_pods(
@@ -504,6 +515,22 @@ class KubernetesRunner(AbstractRunner):
             kubernetes_asyncio, resource_args
         )

+        # If using pvc for code mount, move code there.
+        if launch_project.job_base_image is not None:
+            if SOURCE_CODE_PVC_NAME is None or SOURCE_CODE_PVC_MOUNT_PATH is None:
+                raise LaunchError(
+                    "WANDB_LAUNCH_SOURCE_CODE_PVC_ environment variables not set. "
+                    "Unable to mount source code PVC into base image. "
+                    "Use the `codeMountPvcName` variable in the agent helm chart "
+                    "to enable base image jobs for this agent. See "
+                    "https://github.com/wandb/helm-charts/tree/main/charts/launch-agent "
+                    "for more information."
+                )
+            code_subdir = launch_project.get_image_source_string()
+            launch_project.change_project_dir(
+                os.path.join(SOURCE_CODE_PVC_MOUNT_PATH, code_subdir)
+            )
+
         # If the user specified an alternate api, we need will execute this
         # run by creating a custom object.
         api_version = resource_args.get("apiVersion", "batch/v1")
@@ -542,6 +569,9 @@ class KubernetesRunner(AbstractRunner):
                 LaunchAgent.name()
             )

+        if launch_project.job_base_image:
+            apply_code_mount_configuration(resource_args, launch_project)
+
         overrides = {}
         if launch_project.override_args:
             overrides["args"] = launch_project.override_args
@@ -889,3 +919,45 @@ def add_entrypoint_args_overrides(manifest: Union[dict, list], overrides: dict)
                     container["args"] = overrides["args"]
         for value in manifest.values():
             add_entrypoint_args_overrides(value, overrides)
+
+
+def apply_code_mount_configuration(
+    manifest: Union[Dict, list], project: LaunchProject
+) -> None:
+    """Apply code mount configuration to all containers in a manifest.
+
+    Recursively traverses the manifest and adds the code mount configuration to
+    all containers. Containers are identified by the presence of a "spec" key
+    with a "containers" key in the value.
+
+    Arguments:
+        manifest: The manifest to modify.
+        project: The launch project.
+
+    Returns: None.
+    """
+    assert SOURCE_CODE_PVC_NAME is not None
+    source_dir = project.get_image_source_string()
+    for pod in yield_pods(manifest):
+        for container in yield_containers(pod):
+            if "volumeMounts" not in container:
+                container["volumeMounts"] = []
+            container["volumeMounts"].append(
+                {
+                    "name": "wandb-source-code-volume",
+                    "mountPath": CODE_MOUNT_DIR,
+                    "subPath": source_dir,
+                }
+            )
+            container["workingDir"] = CODE_MOUNT_DIR
+        spec = pod["spec"]
+        if "volumes" not in spec:
+            spec["volumes"] = []
+        spec["volumes"].append(
+            {
+                "name": "wandb-source-code-volume",
+                "persistentVolumeClaim": {
+                    "claimName": SOURCE_CODE_PVC_NAME,
+                },
+            }
+        )
wandb/sdk/launch/runner/local_container.py
CHANGED
@@ -14,6 +14,7 @@ from wandb.sdk.launch.registry.abstract import AbstractRegistry
 from .._project_spec import LaunchProject
 from ..errors import LaunchError
 from ..utils import (
+    CODE_MOUNT_DIR,
     LOG_PREFIX,
     MAX_ENV_LENGTHS,
     PROJECT_SYNCHRONOUS,
@@ -121,7 +122,15 @@ class LocalContainerRunner(AbstractRunner):
             docker_args["network"] = "host"
             if sys.platform == "linux" or sys.platform == "linux2":
                 docker_args["add-host"] = "host.docker.internal:host-gateway"
-
+        base_image = launch_project.job_base_image
+        if base_image is not None:
+            # Mount code into the container and set the working directory.
+            if "volume" not in docker_args:
+                docker_args["volume"] = []
+            docker_args["volume"].append(
+                f"{launch_project.project_dir}:{CODE_MOUNT_DIR}"
+            )
+            docker_args["workdir"] = CODE_MOUNT_DIR
         return docker_args

     async def run(
@@ -146,7 +155,7 @@ class LocalContainerRunner(AbstractRunner):
         elif _is_wandb_dev_uri(self._api.settings("base_url")):
             env_vars["WANDB_BASE_URL"] = "http://host.docker.internal:9001"

-        if launch_project.docker_image:
+        if launch_project.docker_image or launch_project.job_base_image:
             try:
                 pull_docker_image(image_uri)
             except Exception as e:
@@ -156,14 +165,8 @@ class LocalContainerRunner(AbstractRunner):
                     f"Failed to pull docker image {image_uri} with error: {e}"
                 )

-
-
-        entry_cmd = (
-            launch_project.override_entrypoint.command
-            if launch_project.override_entrypoint is not None
-            else None
-        )
-
+        entrypoint = launch_project.get_job_entry_point()
+        entry_cmd = None if entrypoint is None else entrypoint.command
         command_str = " ".join(
             get_docker_command(
                 image_uri,
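The local container runner takes the simpler route for base-image jobs: it bind-mounts the staged project directory into the container and sets the working directory. A hedged sketch of the docker arguments that produces (the paths are hypothetical):

```python
CODE_MOUNT_DIR = "/mnt/wandb"  # value from wandb/sdk/launch/utils.py
project_dir = "/tmp/wandb-launch-project"  # hypothetical staged source directory

docker_args: dict = {}
docker_args.setdefault("volume", []).append(f"{project_dir}:{CODE_MOUNT_DIR}")
docker_args["workdir"] = CODE_MOUNT_DIR

# Roughly equivalent CLI flags: -v /tmp/wandb-launch-project:/mnt/wandb -w /mnt/wandb
print(docker_args)
```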
wandb/sdk/launch/runner/sagemaker_runner.py
CHANGED
@@ -221,7 +221,6 @@
             await run.wait()
             return run

-        launch_project.fill_macros(image_uri)
         _logger.info("Connecting to sagemaker client")
         entry_point = (
             launch_project.override_entrypoint or launch_project.get_job_entry_point()
@@ -296,13 +295,12 @@ def build_sagemaker_args(
     entry_point: Optional[EntryPoint],
     args: Optional[List[str]],
     max_env_length: int,
-    image_uri:
+    image_uri: str,
     default_output_path: Optional[str] = None,
 ) -> Dict[str, Any]:
     sagemaker_args: Dict[str, Any] = {}
-
-
-    )
+    resource_args = launch_project.fill_macros(image_uri)
+    given_sagemaker_args: Optional[Dict[str, Any]] = resource_args.get("sagemaker")

     if given_sagemaker_args is None:
         raise LaunchError(
wandb/sdk/launch/utils.py
CHANGED
@@ -87,6 +87,8 @@ LOG_PREFIX = f"{click.style('launch:', fg='magenta')} "
 MAX_ENV_LENGTHS: Dict[str, int] = defaultdict(lambda: 32670)
 MAX_ENV_LENGTHS["SageMakerRunner"] = 512

+CODE_MOUNT_DIR = "/mnt/wandb"
+

 def load_wandb_config() -> Config:
     """Load wandb config from WANDB_CONFIG environment variable(s).