wandb 0.17.5__py3-none-any.whl → 0.17.6__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (47)
  1. wandb/__init__.py +3 -1
  2. wandb/apis/public/api.py +1 -1
  3. wandb/apis/public/jobs.py +5 -0
  4. wandb/bin/nvidia_gpu_stats +0 -0
  5. wandb/data_types.py +2 -1
  6. wandb/env.py +6 -0
  7. wandb/integration/lightning/fabric/logger.py +4 -4
  8. wandb/proto/v3/wandb_internal_pb2.py +226 -226
  9. wandb/proto/v3/wandb_settings_pb2.py +1 -1
  10. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  11. wandb/proto/v4/wandb_internal_pb2.py +226 -226
  12. wandb/proto/v4/wandb_settings_pb2.py +1 -1
  13. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  14. wandb/proto/v5/wandb_internal_pb2.py +226 -226
  15. wandb/proto/v5/wandb_settings_pb2.py +1 -1
  16. wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
  17. wandb/proto/wandb_deprecated.py +4 -0
  18. wandb/proto/wandb_internal_pb2.py +6 -0
  19. wandb/sdk/artifacts/artifact.py +5 -0
  20. wandb/sdk/artifacts/artifact_manifest_entry.py +31 -0
  21. wandb/sdk/artifacts/storage_handlers/azure_handler.py +35 -23
  22. wandb/sdk/data_types/object_3d.py +113 -2
  23. wandb/sdk/interface/interface.py +23 -0
  24. wandb/sdk/internal/sender.py +31 -15
  25. wandb/sdk/launch/_launch.py +4 -2
  26. wandb/sdk/launch/_project_spec.py +34 -8
  27. wandb/sdk/launch/agent/agent.py +6 -2
  28. wandb/sdk/launch/agent/run_queue_item_file_saver.py +2 -4
  29. wandb/sdk/launch/builder/build.py +4 -2
  30. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +2 -1
  31. wandb/sdk/launch/inputs/internal.py +42 -28
  32. wandb/sdk/launch/inputs/schema.py +39 -0
  33. wandb/sdk/launch/runner/kubernetes_runner.py +72 -0
  34. wandb/sdk/launch/runner/local_container.py +13 -10
  35. wandb/sdk/launch/runner/sagemaker_runner.py +3 -5
  36. wandb/sdk/launch/utils.py +2 -0
  37. wandb/sdk/lib/disabled.py +13 -174
  38. wandb/sdk/wandb_init.py +23 -27
  39. wandb/sdk/wandb_login.py +6 -6
  40. wandb/sdk/wandb_run.py +41 -22
  41. wandb/sdk/wandb_settings.py +3 -2
  42. wandb/wandb_agent.py +2 -0
  43. {wandb-0.17.5.dist-info → wandb-0.17.6.dist-info}/METADATA +3 -2
  44. {wandb-0.17.5.dist-info → wandb-0.17.6.dist-info}/RECORD +47 -45
  45. {wandb-0.17.5.dist-info → wandb-0.17.6.dist-info}/WHEEL +0 -0
  46. {wandb-0.17.5.dist-info → wandb-0.17.6.dist-info}/entry_points.txt +0 -0
  47. {wandb-0.17.5.dist-info → wandb-0.17.6.dist-info}/licenses/LICENSE +0 -0
@@ -7,6 +7,7 @@ import enum
7
7
  import json
8
8
  import logging
9
9
  import os
10
+ import shutil
10
11
  import tempfile
11
12
  from copy import deepcopy
12
13
  from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
@@ -112,6 +113,9 @@ class LaunchProject:
112
113
  self.sweep_id = sweep_id
113
114
  self.author = launch_spec.get("author")
114
115
  self.python_version: Optional[str] = launch_spec.get("python_version")
116
+ self._job_dockerfile: Optional[str] = None
117
+ self._job_build_context: Optional[str] = None
118
+ self._job_base_image: Optional[str] = None
115
119
  self.accelerator_base_image: Optional[str] = resource_args_build.get(
116
120
  "accelerator", {}
117
121
  ).get("base_image") or resource_args_build.get("cuda", {}).get("base_image")
@@ -131,8 +135,6 @@ class LaunchProject:
131
135
  self._queue_name: Optional[str] = None
132
136
  self._queue_entity: Optional[str] = None
133
137
  self._run_queue_item_id: Optional[str] = None
134
- self._job_dockerfile: Optional[str] = None
135
- self._job_build_context: Optional[str] = None
136
138
 
137
139
  def init_source(self) -> None:
138
140
  if self.docker_image is not None:
@@ -146,6 +148,21 @@ class LaunchProject:
146
148
  self.project_dir = os.getcwd()
147
149
  self._entry_point = self.override_entrypoint
148
150
 
151
def change_project_dir(self, new_dir: str) -> None:
    """Move the project's source tree to ``new_dir`` and point the project at it.

    When a project directory is already set, its contents are copied into
    ``new_dir``. The copy preserves symlinks and skips ``.git`` and
    ``fsmonitor--daemon.ipc`` entries. The old directory is then deleted and
    ``project_dir`` is updated. When no project directory is set, only
    ``project_dir`` is updated.

    Arguments:
        new_dir: Destination directory. It may already exist; files from the
            old project directory are merged into it.

    Raises:
        LaunchError: If ``new_dir`` lies inside the current project directory.
            Copying a tree into its own subdirectory recurses into the copy,
            and deleting the old directory afterwards would also delete the
            new one.
    """
    old_dir = self.project_dir
    if old_dir is not None:
        old_real = os.path.realpath(old_dir)
        new_real = os.path.realpath(new_dir)
        if old_real == new_real:
            # Already in place: copying onto itself is meaningless and the
            # rmtree below would destroy the only copy of the source.
            self.project_dir = new_dir
            return
        # os.path.join(x, "") appends a trailing separator (also correct for "/").
        if new_real.startswith(os.path.join(old_real, "")):
            raise LaunchError(
                f"Cannot move project directory {old_dir} into its own "
                f"subdirectory {new_dir}."
            )
        shutil.copytree(
            old_dir,
            new_dir,
            symlinks=True,
            dirs_exist_ok=True,
            ignore=shutil.ignore_patterns("fsmonitor--daemon.ipc", ".git"),
        )
        shutil.rmtree(old_dir)
    self.project_dir = new_dir
165
+
149
166
  def init_git(self, git_info: Dict[str, str]) -> None:
150
167
  self.git_version = git_info.get("version")
151
168
  self.git_repo = git_info.get("repo")
@@ -212,14 +229,23 @@ class LaunchProject:
212
229
  def job_build_context(self) -> Optional[str]:
213
230
  return self._job_build_context
214
231
 
232
@property
def job_base_image(self) -> Optional[str]:
    """The base image recorded for the job, or None if none was set."""
    base_image = self._job_base_image
    return base_image

def set_job_dockerfile(self, dockerfile: str) -> None:
    """Store ``dockerfile`` as the job's dockerfile value."""
    setattr(self, "_job_dockerfile", dockerfile)

def set_job_build_context(self, build_context: str) -> None:
    """Store ``build_context`` as the job's build context value."""
    setattr(self, "_job_build_context", build_context)

def set_job_base_image(self, base_image: str) -> None:
    """Store ``base_image`` so that ``job_base_image`` returns it."""
    setattr(self, "_job_base_image", base_image)
244
+
221
245
  @property
222
246
  def image_name(self) -> str:
247
+ if self.job_base_image is not None:
248
+ return self.job_base_image
223
249
  if self.docker_image is not None:
224
250
  return self.docker_image
225
251
  elif self.uri is not None:
@@ -299,10 +325,8 @@ class LaunchProject:
299
325
 
300
326
  def build_required(self) -> bool:
301
327
  """Checks the source to see if a build is required."""
302
- # since the image tag for images built from jobs
303
- # is based on the job version index, which is immutable
304
- # we don't need to build the image for a job if that tag
305
- # already exists
328
+ if self.job_base_image is not None:
329
+ return False
306
330
  if self.source != LaunchSource.JOB:
307
331
  return True
308
332
  return False
@@ -316,7 +340,9 @@ class LaunchProject:
316
340
  Returns:
317
341
  Optional[str]: The Docker image or None if not specified.
318
342
  """
319
- return self._docker_image
343
+ if self._docker_image:
344
+ return self._docker_image
345
+ return None
320
346
 
321
347
  @docker_image.setter
322
348
  def docker_image(self, value: str) -> None:
@@ -336,7 +362,7 @@ class LaunchProject:
336
362
  # assuming project only has 1 entry point, pull that out
337
363
  # tmp fn until we figure out if we want to support multiple entry points or not
338
364
  if not self._entry_point:
339
- if not self.docker_image:
365
+ if not self.docker_image and not self.job_base_image:
340
366
  raise LaunchError(
341
367
  "Project must have at least one entry point unless docker image is specified."
342
368
  )
@@ -717,7 +717,7 @@ class LaunchAgent:
717
717
  _, build_config, registry_config = construct_agent_configs(
718
718
  default_config, override_build_config
719
719
  )
720
- image_uri = project.docker_image
720
+ image_uri = project.docker_image or project.job_base_image
721
721
  entrypoint = project.get_job_entry_point()
722
722
  environment = loader.environment_from_config(
723
723
  default_config.get("environment", {})
@@ -727,7 +727,11 @@ class LaunchAgent:
727
727
  backend = loader.runner_from_config(
728
728
  resource, api, backend_config, environment, registry
729
729
  )
730
- if not (project.docker_image or isinstance(backend, LocalProcessRunner)):
730
+ if not (
731
+ project.docker_image
732
+ or project.job_base_image
733
+ or isinstance(backend, LocalProcessRunner)
734
+ ):
731
735
  assert entrypoint is not None
732
736
  image_uri = await builder.build_image(project, entrypoint, job_tracker)
733
737
 
@@ -2,7 +2,7 @@
2
2
 
3
3
  import os
4
4
  import sys
5
- from typing import List, Optional, Union
5
+ from typing import List, Optional
6
6
 
7
7
  import wandb
8
8
 
@@ -17,9 +17,7 @@ FileSubtypes = Literal["warning", "error"]
17
17
  class RunQueueItemFileSaver:
18
18
  def __init__(
19
19
  self,
20
- agent_run: Optional[
21
- Union["wandb.sdk.wandb_run.Run", "wandb.sdk.lib.RunDisabled"]
22
- ],
20
+ agent_run: Optional["wandb.sdk.wandb_run.Run"],
23
21
  run_queue_item_id: str,
24
22
  ):
25
23
  self.run_queue_item_id = run_queue_item_id
@@ -201,7 +201,7 @@ def get_requirements_section(
201
201
  # If there is a requirements.txt at root of build context, use that.
202
202
  if (base_path / "src" / "requirements.txt").exists():
203
203
  requirements_files += ["src/requirements.txt"]
204
- deps_install_line = "pip install -r requirements.txt"
204
+ deps_install_line = "pip install uv && uv pip install -r requirements.txt"
205
205
  with open(base_path / "src" / "requirements.txt") as f:
206
206
  requirements = f.readlines()
207
207
  if not any(["wandb" in r for r in requirements]):
@@ -237,7 +237,9 @@ def get_requirements_section(
237
237
  with open(base_path / "src" / "requirements.txt", "w") as f:
238
238
  f.write("\n".join(project_deps))
239
239
  requirements_files += ["src/requirements.txt"]
240
- deps_install_line = "pip install -r requirements.txt"
240
+ deps_install_line = (
241
+ "pip install uv && uv pip install -r requirements.txt"
242
+ )
241
243
  return PIP_TEMPLATE.format(
242
244
  buildx_optional_prefix=prefix,
243
245
  requirements_files=" ".join(requirements_files),
@@ -39,12 +39,13 @@ def install_deps(
39
39
  deps (str[], None): The dependencies that failed to install
40
40
  """
41
41
  try:
42
+ subprocess.check_output(["pip", "install", "uv"], stderr=subprocess.STDOUT)
42
43
  # Include only uri if @ is present
43
44
  clean_deps = [d.split("@")[-1].strip() if "@" in d else d for d in deps]
44
45
  index_args = ["--extra-index-url", extra_index] if extra_index else []
45
46
  print("installing {}...".format(", ".join(clean_deps)))
46
47
  opts = opts or []
47
- args = ["pip", "install"] + opts + clean_deps + index_args
48
+ args = ["uv", "pip", "install"] + opts + clean_deps + index_args
48
49
  sys.stdout.flush()
49
50
  subprocess.check_output(args, stderr=subprocess.STDOUT)
50
51
  return failed
@@ -16,7 +16,9 @@ from typing import Any, Dict, List, Optional
16
16
  import wandb
17
17
  import wandb.data_types
18
18
  from wandb.sdk.launch.errors import LaunchError
19
+ from wandb.sdk.launch.inputs.schema import META_SCHEMA
19
20
  from wandb.sdk.wandb_run import Run
21
+ from wandb.util import get_module
20
22
 
21
23
  from .files import config_path_is_valid, override_file
22
24
 
@@ -129,7 +131,7 @@ def _publish_job_input(
129
131
  )
130
132
 
131
133
 
132
- def _replace_refs_and_allofs(schema: dict, defs: dict) -> dict:
134
+ def _replace_refs_and_allofs(schema: dict, defs: Optional[dict]) -> dict:
133
135
  """Recursively fix JSON schemas with common issues.
134
136
 
135
137
  1. Replaces any instances of $ref with their associated definition in defs
@@ -137,7 +139,7 @@ def _replace_refs_and_allofs(schema: dict, defs: dict) -> dict:
137
139
  See test_internal.py for examples
138
140
  """
139
141
  ret: Dict[str, Any] = {}
140
- if "$ref" in schema:
142
+ if "$ref" in schema and defs:
141
143
  # Reference found, replace it with its definition
142
144
  def_key = schema["$ref"].split("#/$defs/")[1]
143
145
  # Also run recursive replacement in case a ref contains more refs
@@ -170,12 +172,16 @@ def _replace_refs_and_allofs(schema: dict, defs: dict) -> dict:
170
172
  return ret
171
173
 
172
174
 
173
- def _convert_pydantic_model_to_jsonschema(model: Any) -> dict:
174
- schema = model.model_json_schema()
175
- defs = schema.pop("$defs")
176
- if not defs:
177
- return schema
178
- return _replace_refs_and_allofs(schema, defs)
175
def _validate_schema(schema: dict) -> None:
    """Check a job input schema against ``META_SCHEMA`` and warn about problems.

    Validation is advisory only. Any violations are reported through
    ``wandb.termwarn`` and the function returns normally, so an unsupported
    schema does not abort the caller.

    Arguments:
        schema: The JSON schema to check.

    Raises:
        Whatever ``get_module`` raises when the ``jsonschema`` package is not
        installed (NOTE(review): presumably an error carrying the ``required``
        message below -- confirm in ``wandb.util.get_module``).
    """
    jsonschema = get_module(
        "jsonschema",
        required="Setting job schema requires the jsonschema package. Please install it with `pip install 'wandb[launch]'`.",
        lazy=False,
    )
    validator = jsonschema.Draft202012Validator(META_SCHEMA)
    errs = sorted(validator.iter_errors(schema), key=str)
    if not errs:
        return
    # One readable line per violation (location + message). Interpolating the
    # list itself would print the repr of each ValidationError object.
    lines = []
    for err in errs:
        location = "/".join(str(part) for part in err.absolute_path) or "<root>"
        lines.append(f"  - {location}: {err.message}")
    wandb.termwarn(
        "Schema includes unhandled or invalid configurations:\n" + "\n".join(lines)
    )
179
185
 
180
186
 
181
187
  def handle_config_file_input(
@@ -204,16 +210,20 @@ def handle_config_file_input(
204
210
  path,
205
211
  dest,
206
212
  )
207
- # This supports both an instance of a pydantic BaseModel class (e.g. schema=MySchema(...))
208
- # or the BaseModel class itself (e.g. schema=MySchema)
209
- if hasattr(schema, "model_json_schema") and callable(
210
- schema.model_json_schema # type: ignore
211
- ):
212
- schema = _convert_pydantic_model_to_jsonschema(schema)
213
- if schema and not isinstance(schema, dict):
214
- raise LaunchError(
215
- "schema must be a dict, Pydantic model instance, or Pydantic model class."
216
- )
213
+ if schema:
214
+ # This supports both an instance of a pydantic BaseModel class (e.g. schema=MySchema(...))
215
+ # or the BaseModel class itself (e.g. schema=MySchema)
216
+ if hasattr(schema, "model_json_schema") and callable(
217
+ schema.model_json_schema # type: ignore
218
+ ):
219
+ schema = schema.model_json_schema()
220
+ if not isinstance(schema, dict):
221
+ raise LaunchError(
222
+ "schema must be a dict, Pydantic model instance, or Pydantic model class."
223
+ )
224
+ defs = schema.pop("$defs", None)
225
+ schema = _replace_refs_and_allofs(schema, defs)
226
+ _validate_schema(schema)
217
227
  arguments = JobInputArguments(
218
228
  include=include,
219
229
  exclude=exclude,
@@ -241,16 +251,20 @@ def handle_run_config_input(
241
251
  If there is no active run, the include and exclude paths are staged and sent
242
252
  when a run is created.
243
253
  """
244
- # This supports both an instance of a pydantic BaseModel class (e.g. schema=MySchema(...))
245
- # or the BaseModel class itself (e.g. schema=MySchema)
246
- if hasattr(schema, "model_json_schema") and callable(
247
- schema.model_json_schema # type: ignore
248
- ):
249
- schema = _convert_pydantic_model_to_jsonschema(schema)
250
- if schema and not isinstance(schema, dict):
251
- raise LaunchError(
252
- "schema must be a dict, Pydantic model instance, or Pydantic model class."
253
- )
254
+ if schema:
255
+ # This supports both an instance of a pydantic BaseModel class (e.g. schema=MySchema(...))
256
+ # or the BaseModel class itself (e.g. schema=MySchema)
257
+ if hasattr(schema, "model_json_schema") and callable(
258
+ schema.model_json_schema # type: ignore
259
+ ):
260
+ schema = schema.model_json_schema()
261
+ if not isinstance(schema, dict):
262
+ raise LaunchError(
263
+ "schema must be a dict, Pydantic model instance, or Pydantic model class."
264
+ )
265
+ defs = schema.pop("$defs", None)
266
+ schema = _replace_refs_and_allofs(schema, defs)
267
+ _validate_schema(schema)
254
268
  arguments = JobInputArguments(
255
269
  include=include,
256
270
  exclude=exclude,
@@ -0,0 +1,39 @@
1
# JSON Schema (draft 2020-12) describing which job-input schema features are
# accepted. Any key not evaluated by the rules below is reported, because of
# "unevaluatedProperties": False.
# NOTE(review): keys such as "required" or "default" are not listed here, so
# schemas containing them (e.g. generated from Pydantic models) will be
# reported as having unhandled configurations -- confirm this is intended.
META_SCHEMA = {
    "type": "object",
    "properties": {
        "type": {
            "type": "string",
            "enum": ["boolean", "integer", "number", "string", "object"],
        },
        "title": {"type": "string"},
        "description": {"type": "string"},
        "enum": {"type": "array", "items": {"type": ["integer", "number", "string"]}},
        # Nested property schemas and allOf members must themselves satisfy
        # this meta-schema ("#" refers back to the root).
        "properties": {"type": "object", "patternProperties": {".*": {"$ref": "#"}}},
        "allOf": {"type": "array", "items": {"$ref": "#"}},
    },
    "allOf": [
        # Numeric bounds accept any number for "number" fields...
        {
            "if": {"properties": {"type": {"const": "number"}}},
            "then": {
                "properties": {
                    bound: {"type": ["integer", "number"]}
                    for bound in (
                        "minimum",
                        "maximum",
                        "exclusiveMinimum",
                        "exclusiveMaximum",
                    )
                }
            },
        },
        # ...but only integers for "integer" fields.
        {
            "if": {"properties": {"type": {"const": "integer"}}},
            "then": {
                "properties": {
                    bound: {"type": "integer"}
                    for bound in (
                        "minimum",
                        "maximum",
                        "exclusiveMinimum",
                        "exclusiveMaximum",
                    )
                }
            },
        },
    ],
    "unevaluatedProperties": False,
}
@@ -31,6 +31,7 @@ from wandb.util import get_module
31
31
  from .._project_spec import EntryPoint, LaunchProject
32
32
  from ..errors import LaunchError
33
33
  from ..utils import (
34
+ CODE_MOUNT_DIR,
34
35
  LOG_PREFIX,
35
36
  MAX_ENV_LENGTHS,
36
37
  PROJECT_SYNCHRONOUS,
@@ -66,6 +67,10 @@ API_KEY_SECRET_MAX_RETRIES = 5
66
67
  _logger = logging.getLogger(__name__)
67
68
 
68
69
 
70
# Source-code PVC settings used for base-image jobs. Each is None when the
# corresponding environment variable is not set.
SOURCE_CODE_PVC_MOUNT_PATH = os.getenv("WANDB_LAUNCH_CODE_PVC_MOUNT_PATH")
SOURCE_CODE_PVC_NAME = os.getenv("WANDB_LAUNCH_CODE_PVC_NAME")
72
+
73
+
69
74
  class KubernetesSubmittedRun(AbstractRun):
70
75
  """Wrapper for a launched run on Kubernetes."""
71
76
 
@@ -468,6 +473,12 @@ class KubernetesRunner(AbstractRunner):
468
473
  "true",
469
474
  )
470
475
 
476
+ if launch_project.job_base_image:
477
+ apply_code_mount_configuration(
478
+ job,
479
+ launch_project,
480
+ )
481
+
471
482
  # Add wandb.ai/agent: current agent label on all pods
472
483
  if LaunchAgent.initialized():
473
484
  add_label_to_pods(
@@ -504,6 +515,22 @@ class KubernetesRunner(AbstractRunner):
504
515
  kubernetes_asyncio, resource_args
505
516
  )
506
517
 
518
+ # If using pvc for code mount, move code there.
519
+ if launch_project.job_base_image is not None:
520
+ if SOURCE_CODE_PVC_NAME is None or SOURCE_CODE_PVC_MOUNT_PATH is None:
521
+ raise LaunchError(
522
+ "WANDB_LAUNCH_SOURCE_CODE_PVC_ environment variables not set. "
523
+ "Unable to mount source code PVC into base image. "
524
+ "Use the `codeMountPvcName` variable in the agent helm chart "
525
+ "to enable base image jobs for this agent. See "
526
+ "https://github.com/wandb/helm-charts/tree/main/charts/launch-agent "
527
+ "for more information."
528
+ )
529
+ code_subdir = launch_project.get_image_source_string()
530
+ launch_project.change_project_dir(
531
+ os.path.join(SOURCE_CODE_PVC_MOUNT_PATH, code_subdir)
532
+ )
533
+
507
534
  # If the user specified an alternate api, we need will execute this
508
535
  # run by creating a custom object.
509
536
  api_version = resource_args.get("apiVersion", "batch/v1")
@@ -542,6 +569,9 @@ class KubernetesRunner(AbstractRunner):
542
569
  LaunchAgent.name()
543
570
  )
544
571
 
572
+ if launch_project.job_base_image:
573
+ apply_code_mount_configuration(resource_args, launch_project)
574
+
545
575
  overrides = {}
546
576
  if launch_project.override_args:
547
577
  overrides["args"] = launch_project.override_args
@@ -889,3 +919,45 @@ def add_entrypoint_args_overrides(manifest: Union[dict, list], overrides: dict)
889
919
  container["args"] = overrides["args"]
890
920
  for value in manifest.values():
891
921
  add_entrypoint_args_overrides(value, overrides)
922
+
923
+
924
def apply_code_mount_configuration(
    manifest: Union[Dict, list], project: LaunchProject
) -> None:
    """Mount the job's source-code PVC into every container of a manifest.

    For every pod yielded by ``yield_pods(manifest)`` and every container
    yielded by ``yield_containers(pod)``, this does three things:

    - Mounts the ``SOURCE_CODE_PVC_NAME`` claim at ``CODE_MOUNT_DIR``, using
      the project's image source string as the sub-path.
    - Sets the container's working directory to ``CODE_MOUNT_DIR``.
    - Adds the backing volume to the pod spec.

    The manifest is modified in place. Re-applying the function to the same
    manifest does not add duplicate volume or volumeMount entries.

    Arguments:
        manifest: The manifest (dict or list) to modify in place.
        project: The launch project whose source code is mounted.

    Returns: None.

    Raises:
        LaunchError: If ``SOURCE_CODE_PVC_NAME`` is None, i.e. the
            ``WANDB_LAUNCH_CODE_PVC_NAME`` environment variable was not set.
    """
    if SOURCE_CODE_PVC_NAME is None:
        # Explicit check rather than `assert`, which is stripped under
        # `python -O` and would let a volume with a null claimName through.
        raise LaunchError(
            "WANDB_LAUNCH_CODE_PVC_NAME environment variable not set. "
            "Unable to mount source code PVC into base image."
        )
    volume_name = "wandb-source-code-volume"
    source_dir = project.get_image_source_string()
    for pod in yield_pods(manifest):
        for container in yield_containers(pod):
            mounts = container.setdefault("volumeMounts", [])
            # Skip if this manifest was already processed, so the mount path
            # is not listed twice.
            already_mounted = any(
                mount.get("name") == volume_name
                and mount.get("mountPath") == CODE_MOUNT_DIR
                for mount in mounts
            )
            if not already_mounted:
                mounts.append(
                    {
                        "name": volume_name,
                        "mountPath": CODE_MOUNT_DIR,
                        "subPath": source_dir,
                    }
                )
            container["workingDir"] = CODE_MOUNT_DIR
        volumes = pod["spec"].setdefault("volumes", [])
        if not any(volume.get("name") == volume_name for volume in volumes):
            volumes.append(
                {
                    "name": volume_name,
                    "persistentVolumeClaim": {
                        "claimName": SOURCE_CODE_PVC_NAME,
                    },
                }
            )
@@ -14,6 +14,7 @@ from wandb.sdk.launch.registry.abstract import AbstractRegistry
14
14
  from .._project_spec import LaunchProject
15
15
  from ..errors import LaunchError
16
16
  from ..utils import (
17
+ CODE_MOUNT_DIR,
17
18
  LOG_PREFIX,
18
19
  MAX_ENV_LENGTHS,
19
20
  PROJECT_SYNCHRONOUS,
@@ -121,7 +122,15 @@ class LocalContainerRunner(AbstractRunner):
121
122
  docker_args["network"] = "host"
122
123
  if sys.platform == "linux" or sys.platform == "linux2":
123
124
  docker_args["add-host"] = "host.docker.internal:host-gateway"
124
-
125
+ base_image = launch_project.job_base_image
126
+ if base_image is not None:
127
+ # Mount code into the container and set the working directory.
128
+ if "volume" not in docker_args:
129
+ docker_args["volume"] = []
130
+ docker_args["volume"].append(
131
+ f"{launch_project.project_dir}:{CODE_MOUNT_DIR}"
132
+ )
133
+ docker_args["workdir"] = CODE_MOUNT_DIR
125
134
  return docker_args
126
135
 
127
136
  async def run(
@@ -146,7 +155,7 @@ class LocalContainerRunner(AbstractRunner):
146
155
  elif _is_wandb_dev_uri(self._api.settings("base_url")):
147
156
  env_vars["WANDB_BASE_URL"] = "http://host.docker.internal:9001"
148
157
 
149
- if launch_project.docker_image:
158
+ if launch_project.docker_image or launch_project.job_base_image:
150
159
  try:
151
160
  pull_docker_image(image_uri)
152
161
  except Exception as e:
@@ -156,14 +165,8 @@ class LocalContainerRunner(AbstractRunner):
156
165
  f"Failed to pull docker image {image_uri} with error: {e}"
157
166
  )
158
167
 
159
- assert launch_project.docker_image == image_uri
160
-
161
- entry_cmd = (
162
- launch_project.override_entrypoint.command
163
- if launch_project.override_entrypoint is not None
164
- else None
165
- )
166
-
168
+ entrypoint = launch_project.get_job_entry_point()
169
+ entry_cmd = None if entrypoint is None else entrypoint.command
167
170
  command_str = " ".join(
168
171
  get_docker_command(
169
172
  image_uri,
@@ -221,7 +221,6 @@ class SageMakerRunner(AbstractRunner):
221
221
  await run.wait()
222
222
  return run
223
223
 
224
- launch_project.fill_macros(image_uri)
225
224
  _logger.info("Connecting to sagemaker client")
226
225
  entry_point = (
227
226
  launch_project.override_entrypoint or launch_project.get_job_entry_point()
@@ -296,13 +295,12 @@ def build_sagemaker_args(
296
295
  entry_point: Optional[EntryPoint],
297
296
  args: Optional[List[str]],
298
297
  max_env_length: int,
299
- image_uri: Optional[str] = None,
298
+ image_uri: str,
300
299
  default_output_path: Optional[str] = None,
301
300
  ) -> Dict[str, Any]:
302
301
  sagemaker_args: Dict[str, Any] = {}
303
- given_sagemaker_args: Optional[Dict[str, Any]] = launch_project.resource_args.get(
304
- "sagemaker"
305
- )
302
+ resource_args = launch_project.fill_macros(image_uri)
303
+ given_sagemaker_args: Optional[Dict[str, Any]] = resource_args.get("sagemaker")
306
304
 
307
305
  if given_sagemaker_args is None:
308
306
  raise LaunchError(
wandb/sdk/launch/utils.py CHANGED
@@ -87,6 +87,8 @@ LOG_PREFIX = f"{click.style('launch:', fg='magenta')} "
87
87
  MAX_ENV_LENGTHS: Dict[str, int] = defaultdict(lambda: 32670)
88
88
  MAX_ENV_LENGTHS["SageMakerRunner"] = 512
89
89
 
90
# Path inside launched containers where job source code is mounted; runners
# also use it as the container working directory.
CODE_MOUNT_DIR: str = "/mnt/wandb"
91
+
90
92
 
91
93
  def load_wandb_config() -> Config:
92
94
  """Load wandb config from WANDB_CONFIG environment variable(s).