wandb-0.17.4-py3-none-any.whl → wandb-0.17.6-py3-none-any.whl

Files changed (56)
  1. wandb/__init__.py +3 -1
  2. wandb/apis/public/api.py +1 -1
  3. wandb/apis/public/jobs.py +5 -0
  4. wandb/bin/nvidia_gpu_stats +0 -0
  5. wandb/data_types.py +2 -1
  6. wandb/env.py +6 -0
  7. wandb/filesync/upload_job.py +1 -1
  8. wandb/integration/lightning/fabric/logger.py +4 -4
  9. wandb/proto/v3/wandb_internal_pb2.py +339 -328
  10. wandb/proto/v3/wandb_settings_pb2.py +1 -1
  11. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  12. wandb/proto/v4/wandb_internal_pb2.py +326 -323
  13. wandb/proto/v4/wandb_settings_pb2.py +1 -1
  14. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  15. wandb/proto/v5/wandb_internal_pb2.py +326 -323
  16. wandb/proto/v5/wandb_settings_pb2.py +1 -1
  17. wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
  18. wandb/proto/wandb_deprecated.py +4 -0
  19. wandb/proto/wandb_internal_pb2.py +6 -0
  20. wandb/sdk/artifacts/artifact.py +16 -24
  21. wandb/sdk/artifacts/artifact_manifest_entry.py +31 -0
  22. wandb/sdk/artifacts/storage_handlers/azure_handler.py +35 -23
  23. wandb/sdk/data_types/object_3d.py +113 -2
  24. wandb/sdk/interface/interface.py +35 -5
  25. wandb/sdk/interface/interface_shared.py +9 -7
  26. wandb/sdk/internal/handler.py +1 -1
  27. wandb/sdk/internal/internal_api.py +4 -4
  28. wandb/sdk/internal/sender.py +40 -17
  29. wandb/sdk/launch/_launch.py +4 -2
  30. wandb/sdk/launch/_project_spec.py +34 -8
  31. wandb/sdk/launch/agent/agent.py +6 -2
  32. wandb/sdk/launch/agent/run_queue_item_file_saver.py +2 -4
  33. wandb/sdk/launch/builder/build.py +4 -2
  34. wandb/sdk/launch/builder/kaniko_builder.py +30 -9
  35. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +2 -1
  36. wandb/sdk/launch/inputs/internal.py +93 -2
  37. wandb/sdk/launch/inputs/manage.py +21 -3
  38. wandb/sdk/launch/inputs/schema.py +39 -0
  39. wandb/sdk/launch/runner/kubernetes_runner.py +72 -0
  40. wandb/sdk/launch/runner/local_container.py +13 -10
  41. wandb/sdk/launch/runner/sagemaker_runner.py +3 -5
  42. wandb/sdk/launch/utils.py +2 -0
  43. wandb/sdk/lib/disabled.py +13 -174
  44. wandb/sdk/lib/tracelog.py +2 -2
  45. wandb/sdk/wandb_init.py +23 -27
  46. wandb/sdk/wandb_login.py +6 -6
  47. wandb/sdk/wandb_manager.py +9 -5
  48. wandb/sdk/wandb_run.py +141 -97
  49. wandb/sdk/wandb_settings.py +3 -2
  50. wandb/util.py +29 -11
  51. wandb/wandb_agent.py +2 -0
  52. {wandb-0.17.4.dist-info → wandb-0.17.6.dist-info}/METADATA +3 -2
  53. {wandb-0.17.4.dist-info → wandb-0.17.6.dist-info}/RECORD +56 -54
  54. {wandb-0.17.4.dist-info → wandb-0.17.6.dist-info}/WHEEL +0 -0
  55. {wandb-0.17.4.dist-info → wandb-0.17.6.dist-info}/entry_points.txt +0 -0
  56. {wandb-0.17.4.dist-info → wandb-0.17.6.dist-info}/licenses/LICENSE +0 -0

wandb/sdk/interface/interface_shared.py
@@ -137,6 +137,7 @@ class InterfaceShared(InterfaceBase):
         check_version: Optional[pb.CheckVersionRequest] = None,
         log_artifact: Optional[pb.LogArtifactRequest] = None,
         download_artifact: Optional[pb.DownloadArtifactRequest] = None,
+        link_artifact: Optional[pb.LinkArtifactRequest] = None,
         defer: Optional[pb.DeferRequest] = None,
         attach: Optional[pb.AttachRequest] = None,
         server_info: Optional[pb.ServerInfoRequest] = None,
@@ -184,6 +185,8 @@ class InterfaceShared(InterfaceBase):
             request.log_artifact.CopyFrom(log_artifact)
         elif download_artifact:
             request.download_artifact.CopyFrom(download_artifact)
+        elif link_artifact:
+            request.link_artifact.CopyFrom(link_artifact)
         elif defer:
             request.defer.CopyFrom(defer)
         elif attach:
@@ -242,7 +245,6 @@ class InterfaceShared(InterfaceBase):
         request: Optional[pb.Request] = None,
         telemetry: Optional[tpb.TelemetryRecord] = None,
         preempting: Optional[pb.RunPreemptingRecord] = None,
-        link_artifact: Optional[pb.LinkArtifactRecord] = None,
         use_artifact: Optional[pb.UseArtifactRecord] = None,
         output: Optional[pb.OutputRecord] = None,
         output_raw: Optional[pb.OutputRawRecord] = None,
@@ -282,8 +284,6 @@ class InterfaceShared(InterfaceBase):
             record.metric.CopyFrom(metric)
         elif preempting:
             record.preempting.CopyFrom(preempting)
-        elif link_artifact:
-            record.link_artifact.CopyFrom(link_artifact)
         elif use_artifact:
             record.use_artifact.CopyFrom(use_artifact)
         elif output:
@@ -393,10 +393,6 @@ class InterfaceShared(InterfaceBase):
         rec = self._make_record(files=files)
         self._publish(rec)

-    def _publish_link_artifact(self, link_artifact: pb.LinkArtifactRecord) -> Any:
-        rec = self._make_record(link_artifact=link_artifact)
-        self._publish(rec)
-
     def _publish_use_artifact(self, use_artifact: pb.UseArtifactRecord) -> Any:
         rec = self._make_record(use_artifact=use_artifact)
         self._publish(rec)
@@ -411,6 +407,12 @@ class InterfaceShared(InterfaceBase):
         rec = self._make_request(download_artifact=download_artifact)
         return self._deliver_record(rec)

+    def _deliver_link_artifact(
+        self, link_artifact: pb.LinkArtifactRequest
+    ) -> MailboxHandle:
+        rec = self._make_request(link_artifact=link_artifact)
+        return self._deliver_record(rec)
+
     def _publish_artifact(self, proto_artifact: pb.ArtifactRecord) -> None:
         rec = self._make_record(artifact=proto_artifact)
         self._publish(rec)
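
Note: taken together, the interface_shared.py hunks above move artifact linking off the fire-and-forget publish path (_make_record) and onto the request/response deliver path (_make_request plus a MailboxHandle). A minimal caller sketch, assuming the MailboxHandle.wait(timeout=...) API used elsewhere in this SDK; the request fields shown are illustrative:

    from wandb.proto import wandb_internal_pb2 as pb

    # Hypothetical caller; `interface` is an InterfaceShared instance.
    link = pb.LinkArtifactRequest()
    link.client_id = "artifact-client-id"     # illustrative id
    link.portfolio_name = "candidate-models"  # illustrative portfolio

    handle = interface._deliver_link_artifact(link)  # MailboxHandle, not fire-and-forget
    result = handle.wait(timeout=60)                 # None if no response arrives in time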

wandb/sdk/internal/handler.py
@@ -230,7 +230,7 @@ class HandleManager:
     def handle_files(self, record: Record) -> None:
         self._dispatch_record(record)

-    def handle_link_artifact(self, record: Record) -> None:
+    def handle_request_link_artifact(self, record: Record) -> None:
         self._dispatch_record(record)

     def handle_use_artifact(self, record: Record) -> None:

wandb/sdk/internal/internal_api.py
@@ -232,14 +232,14 @@ class Api:

         # todo: remove these hacky hacks after settings refactor is complete
         # keeping this code here to limit scope and so that it is easy to remove later
-        extra_http_headers = self.settings("_extra_http_headers") or json.loads(
+        self._extra_http_headers = self.settings("_extra_http_headers") or json.loads(
             self._environ.get("WANDB__EXTRA_HTTP_HEADERS", "{}")
         )
-        extra_http_headers.update(_thread_local_api_settings.headers or {})
+        self._extra_http_headers.update(_thread_local_api_settings.headers or {})

         auth = None
         if self.access_token is not None:
-            extra_http_headers["Authorization"] = f"Bearer {self.access_token}"
+            self._extra_http_headers["Authorization"] = f"Bearer {self.access_token}"
         elif _thread_local_api_settings.cookies is None:
             auth = ("api", self.api_key or "")

@@ -253,7 +253,7 @@ class Api:
                 "User-Agent": self.user_agent,
                 "X-WANDB-USERNAME": env.get_username(env=self._environ),
                 "X-WANDB-USER-EMAIL": env.get_user_email(env=self._environ),
-                **extra_http_headers,
+                **self._extra_http_headers,
             },
             use_json=True,
             # this timeout won't apply when the DNS lookup fails. in that case, it will be 60s

wandb/sdk/internal/sender.py
@@ -1,5 +1,7 @@
 """sender."""

+import contextlib
+import gzip
 import json
 import logging
 import os
@@ -66,6 +68,7 @@ else:
 if TYPE_CHECKING:
     from wandb.proto.wandb_internal_pb2 import (
         ArtifactManifest,
+        ArtifactManifestEntry,
         ArtifactRecord,
         HttpResponse,
         LocalInfo,
@@ -105,22 +108,18 @@ def _framework_priority() -> Generator[Tuple[str, str], None, None]:

 def _manifest_json_from_proto(manifest: "ArtifactManifest") -> Dict:
     if manifest.version == 1:
-        contents = {
-            content.path: {
-                "digest": content.digest,
-                "birthArtifactID": content.birth_artifact_id
-                if content.birth_artifact_id
-                else None,
-                "ref": content.ref if content.ref else None,
-                "size": content.size if content.size is not None else None,
-                "local_path": content.local_path if content.local_path else None,
-                "skip_cache": content.skip_cache,
-                "extra": {
-                    extra.key: json.loads(extra.value_json) for extra in content.extra
-                },
+        if manifest.manifest_file_path:
+            contents = {}
+            with gzip.open(manifest.manifest_file_path, "rt") as f:
+                for line in f:
+                    entry_json = json.loads(line)
+                    path = entry_json.pop("path")
+                    contents[path] = entry_json
+        else:
+            contents = {
+                content.path: _manifest_entry_from_proto(content)
+                for content in manifest.contents
             }
-            for content in manifest.contents
-        }
     else:
         raise ValueError(f"unknown artifact manifest version: {manifest.version}")
@@ -135,6 +134,19 @@ def _manifest_json_from_proto(manifest: "ArtifactManifest") -> Dict:
     }


+def _manifest_entry_from_proto(entry: "ArtifactManifestEntry") -> Dict:
+    birth_artifact_id = entry.birth_artifact_id if entry.birth_artifact_id else None
+    return {
+        "digest": entry.digest,
+        "birthArtifactID": birth_artifact_id,
+        "ref": entry.ref if entry.ref else None,
+        "size": entry.size if entry.size is not None else None,
+        "local_path": entry.local_path if entry.local_path else None,
+        "skip_cache": entry.skip_cache,
+        "extra": {extra.key: json.loads(extra.value_json) for extra in entry.extra},
+    }
+
+
 class ResumeState:
     resumed: bool
     step: int
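
The manifest hunks above add a second path for reading manifests: when manifest.manifest_file_path is set, entries are streamed from a gzipped JSONL file instead of being carried one-by-one in the protobuf message. A minimal sketch of the file format this reader expects (file name, paths, and digests are illustrative):

    import gzip
    import json

    entries = {
        "data/train.csv": {"digest": "abc123==", "size": 1024, "skip_cache": False},
        "model.pt": {"digest": "def456==", "size": 4096, "skip_cache": False},
    }

    # Write: one JSON object per line, with the entry's path folded into the object.
    with gzip.open("manifest.jsonl.gz", "wt") as f:
        for path, meta in entries.items():
            f.write(json.dumps({"path": path, **meta}) + "\n")

    # Read: mirrors the loop in _manifest_json_from_proto above.
    contents = {}
    with gzip.open("manifest.jsonl.gz", "rt") as f:
        for line in f:
            entry = json.loads(line)
            contents[entry.pop("path")] = entry
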
@@ -1473,8 +1485,13 @@ class SendManager:
         # tbrecord watching threads are handled by handler.py
         pass

-    def send_link_artifact(self, record: "Record") -> None:
-        link = record.link_artifact
+    def send_request_link_artifact(self, record: "Record") -> None:
+        if not (record.control.req_resp or record.control.mailbox_slot):
+            raise ValueError(
+                f"Expected either `req_resp` or `mailbox_slot`, got: {record.control!r}"
+            )
+        result = proto_util._result_from_record(record)
+        link = record.request.link_artifact
         client_id = link.client_id
         server_id = link.server_id
         portfolio_name = link.portfolio_name
@@ -1490,7 +1507,9 @@ class SendManager:
                 client_id, server_id, portfolio_name, entity, project, aliases
             )
         except Exception as e:
+            result.response.log_artifact_response.error_message = f'error linking artifact to "{entity}/{project}/{portfolio_name}"; error: {e}'
             logger.warning("Failed to link artifact to portfolio: %s", e)
+        self._respond_result(result)

     def send_use_artifact(self, record: "Record") -> None:
         """Pretend to send a used artifact.
@@ -1579,6 +1598,10 @@ class SendManager:
         )

         self._job_builder._handle_server_artifact(res, artifact)
+
+        if artifact.manifest.manifest_file_path:
+            with contextlib.suppress(FileNotFoundError):
+                os.remove(artifact.manifest.manifest_file_path)
         return res

     def send_alert(self, record: "Record") -> None:
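
With the sender now always responding (note it reuses log_artifact_response as the error channel rather than adding a new response type), the caller sketch from the interface_shared.py hunks can surface link failures:

    # Continuing the earlier sketch: `result` came from handle.wait(timeout=60).
    if result is not None:
        error = result.response.log_artifact_response.error_message
        if error:
            raise RuntimeError(f"artifact link failed: {error}")  # illustrative handling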

wandb/sdk/launch/_launch.py
@@ -211,7 +211,9 @@ async def _launch(
     launch_project = LaunchProject.from_spec(launch_spec, api)
     launch_project.fetch_and_validate_project()
     entrypoint = launch_project.get_job_entry_point()
-    image_uri = launch_project.docker_image  # Either set by user or None.
+    image_uri = (
+        launch_project.docker_image or launch_project.job_base_image
+    )  # Either set by user or None.

     # construct runner config.
     runner_config: Dict[str, Any] = {}
@@ -224,7 +226,7 @@ async def _launch(
     await environment.verify()
     registry = loader.registry_from_config(registry_config, environment)
     builder = loader.builder_from_config(build_config, environment, registry)
-    if not launch_project.docker_image:
+    if not (launch_project.docker_image or launch_project.job_base_image):
         assert entrypoint
         image_uri = await builder.build_image(launch_project, entrypoint, None)
     backend = loader.runner_from_config(

wandb/sdk/launch/_project_spec.py
@@ -7,6 +7,7 @@ import enum
 import json
 import logging
 import os
+import shutil
 import tempfile
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
@@ -112,6 +113,9 @@ class LaunchProject:
         self.sweep_id = sweep_id
         self.author = launch_spec.get("author")
         self.python_version: Optional[str] = launch_spec.get("python_version")
+        self._job_dockerfile: Optional[str] = None
+        self._job_build_context: Optional[str] = None
+        self._job_base_image: Optional[str] = None
         self.accelerator_base_image: Optional[str] = resource_args_build.get(
             "accelerator", {}
         ).get("base_image") or resource_args_build.get("cuda", {}).get("base_image")
@@ -131,8 +135,6 @@ class LaunchProject:
         self._queue_name: Optional[str] = None
         self._queue_entity: Optional[str] = None
         self._run_queue_item_id: Optional[str] = None
-        self._job_dockerfile: Optional[str] = None
-        self._job_build_context: Optional[str] = None

     def init_source(self) -> None:
         if self.docker_image is not None:
@@ -146,6 +148,21 @@
             self.project_dir = os.getcwd()
             self._entry_point = self.override_entrypoint

+    def change_project_dir(self, new_dir: str) -> None:
+        """Change the project directory to a new directory."""
+        # Copy the contents of the old project dir to the new project dir.
+        old_dir = self.project_dir
+        if old_dir is not None:
+            shutil.copytree(
+                old_dir,
+                new_dir,
+                symlinks=True,
+                dirs_exist_ok=True,
+                ignore=shutil.ignore_patterns("fsmonitor--daemon.ipc", ".git"),
+            )
+            shutil.rmtree(old_dir)
+        self.project_dir = new_dir
+
     def init_git(self, git_info: Dict[str, str]) -> None:
         self.git_version = git_info.get("version")
         self.git_repo = git_info.get("repo")
@@ -212,14 +229,23 @@
     def job_build_context(self) -> Optional[str]:
         return self._job_build_context

+    @property
+    def job_base_image(self) -> Optional[str]:
+        return self._job_base_image
+
     def set_job_dockerfile(self, dockerfile: str) -> None:
         self._job_dockerfile = dockerfile

     def set_job_build_context(self, build_context: str) -> None:
         self._job_build_context = build_context

+    def set_job_base_image(self, base_image: str) -> None:
+        self._job_base_image = base_image
+
     @property
     def image_name(self) -> str:
+        if self.job_base_image is not None:
+            return self.job_base_image
         if self.docker_image is not None:
             return self.docker_image
         elif self.uri is not None:
@@ -299,10 +325,8 @@

     def build_required(self) -> bool:
         """Checks the source to see if a build is required."""
-        # since the image tag for images built from jobs
-        # is based on the job version index, which is immutable
-        # we don't need to build the image for a job if that tag
-        # already exists
+        if self.job_base_image is not None:
+            return False
         if self.source != LaunchSource.JOB:
             return True
         return False
@@ -316,7 +340,9 @@
         Returns:
             Optional[str]: The Docker image or None if not specified.
         """
-        return self._docker_image
+        if self._docker_image:
+            return self._docker_image
+        return None

     @docker_image.setter
     def docker_image(self, value: str) -> None:
@@ -336,7 +362,7 @@
         # assuming project only has 1 entry point, pull that out
         # tmp fn until we figure out if we want to support multiple entry points or not
         if not self._entry_point:
-            if not self.docker_image:
+            if not self.docker_image and not self.job_base_image:
                 raise LaunchError(
                     "Project must have at least one entry point unless docker image is specified."
                 )
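
Together with the _launch.py changes above and the agent.py changes below, these _project_spec.py hunks add a prebuilt base-image path: when a job carries a base image, build_required() returns False and the image is run as-is. A hypothetical sketch of the new accessors on a LaunchProject (the image name is illustrative):

    project.set_job_base_image("registry.example.com/team/train:v3")

    assert project.job_base_image == "registry.example.com/team/train:v3"
    assert project.build_required() is False              # base image suppresses the build
    assert project.image_name == project.job_base_image   # checked before docker_image

    # _launch.py and agent.py then run this image directly instead of
    # calling builder.build_image(...).
    image_uri = project.docker_image or project.job_base_image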

wandb/sdk/launch/agent/agent.py
@@ -717,7 +717,7 @@ class LaunchAgent:
         _, build_config, registry_config = construct_agent_configs(
             default_config, override_build_config
         )
-        image_uri = project.docker_image
+        image_uri = project.docker_image or project.job_base_image
         entrypoint = project.get_job_entry_point()
         environment = loader.environment_from_config(
             default_config.get("environment", {})
@@ -727,7 +727,11 @@
         backend = loader.runner_from_config(
             resource, api, backend_config, environment, registry
         )
-        if not (project.docker_image or isinstance(backend, LocalProcessRunner)):
+        if not (
+            project.docker_image
+            or project.job_base_image
+            or isinstance(backend, LocalProcessRunner)
+        ):
             assert entrypoint is not None
             image_uri = await builder.build_image(project, entrypoint, job_tracker)


wandb/sdk/launch/agent/run_queue_item_file_saver.py
@@ -2,7 +2,7 @@

 import os
 import sys
-from typing import List, Optional, Union
+from typing import List, Optional

 import wandb

@@ -17,9 +17,7 @@ FileSubtypes = Literal["warning", "error"]
 class RunQueueItemFileSaver:
     def __init__(
         self,
-        agent_run: Optional[
-            Union["wandb.sdk.wandb_run.Run", "wandb.sdk.lib.RunDisabled"]
-        ],
+        agent_run: Optional["wandb.sdk.wandb_run.Run"],
         run_queue_item_id: str,
     ):
         self.run_queue_item_id = run_queue_item_id

wandb/sdk/launch/builder/build.py
@@ -201,7 +201,7 @@ def get_requirements_section(
     # If there is a requirements.txt at root of build context, use that.
     if (base_path / "src" / "requirements.txt").exists():
         requirements_files += ["src/requirements.txt"]
-        deps_install_line = "pip install -r requirements.txt"
+        deps_install_line = "pip install uv && uv pip install -r requirements.txt"
         with open(base_path / "src" / "requirements.txt") as f:
             requirements = f.readlines()
             if not any(["wandb" in r for r in requirements]):
@@ -237,7 +237,9 @@ def get_requirements_section(
         with open(base_path / "src" / "requirements.txt", "w") as f:
             f.write("\n".join(project_deps))
         requirements_files += ["src/requirements.txt"]
-        deps_install_line = "pip install -r requirements.txt"
+        deps_install_line = (
+            "pip install uv && uv pip install -r requirements.txt"
+        )
     return PIP_TEMPLATE.format(
         buildx_optional_prefix=prefix,
         requirements_files=" ".join(requirements_files),

wandb/sdk/launch/builder/kaniko_builder.py
@@ -263,11 +263,17 @@ class KanikoBuilder(AbstractBuilder):
         repo_uri = await self.registry.get_repo_uri()
         image_uri = repo_uri + ":" + image_tag

-        if (
-            not launch_project.build_required()
-            and await self.registry.check_image_exists(image_uri)
-        ):
-            return image_uri
+        # The DOCKER_CONFIG_SECRET option is mutually exclusive with the
+        # registry classes, so we must skip the check for image existence in
+        # that case.
+        if not launch_project.build_required():
+            if DOCKER_CONFIG_SECRET:
+                wandb.termlog(
+                    f"Skipping check for existing image {image_uri} due to custom dockerconfig."
+                )
+            else:
+                if await self.registry.check_image_exists(image_uri):
+                    return image_uri

         _logger.info(f"Building image {image_uri}...")
         _, api_client = await get_kube_context_and_api_client(
@@ -286,7 +292,12 @@ class KanikoBuilder(AbstractBuilder):
         wandb.termlog(f"{LOG_PREFIX}Created kaniko job {build_job_name}")

         try:
-            if isinstance(self.registry, AzureContainerRegistry):
+            # DOCKER_CONFIG_SECRET is a user provided dockerconfigjson. Skip our
+            # dockerconfig handling if it's set.
+            if (
+                isinstance(self.registry, AzureContainerRegistry)
+                and not DOCKER_CONFIG_SECRET
+            ):
                 dockerfile_config_map = client.V1ConfigMap(
                     metadata=client.V1ObjectMeta(
                         name=f"docker-config-{build_job_name}"
@@ -344,7 +355,10 @@ class KanikoBuilder(AbstractBuilder):
         finally:
             wandb.termlog(f"{LOG_PREFIX}Cleaning up resources")
             try:
-                if isinstance(self.registry, AzureContainerRegistry):
+                if (
+                    isinstance(self.registry, AzureContainerRegistry)
+                    and not DOCKER_CONFIG_SECRET
+                ):
                     await core_v1.delete_namespaced_config_map(
                         f"docker-config-{build_job_name}", "wandb"
                     )
@@ -498,7 +512,10 @@ class KanikoBuilder(AbstractBuilder):
                 "readOnly": True,
             }
         )
-        if isinstance(self.registry, AzureContainerRegistry):
+        if (
+            isinstance(self.registry, AzureContainerRegistry)
+            and not DOCKER_CONFIG_SECRET
+        ):
             # Add the docker config map
             volumes.append(
                 {
@@ -533,7 +550,11 @@ class KanikoBuilder(AbstractBuilder):
         # Apply the rest of our defaults
         pod_labels["wandb"] = "launch"
         # This annotation is required to enable azure workload identity.
-        if isinstance(self.registry, AzureContainerRegistry):
+        # Don't add this label if using a docker config secret for auth.
+        if (
+            isinstance(self.registry, AzureContainerRegistry)
+            and not DOCKER_CONFIG_SECRET
+        ):
             pod_labels["azure.workload.identity/use"] = "true"
         pod_spec["restartPolicy"] = pod_spec.get("restartPolicy", "Never")
         pod_spec["activeDeadlineSeconds"] = pod_spec.get(

wandb/sdk/launch/builder/templates/_wandb_bootstrap.py
@@ -39,12 +39,13 @@ def install_deps(
         deps (str[], None): The dependencies that failed to install
     """
     try:
+        subprocess.check_output(["pip", "install", "uv"], stderr=subprocess.STDOUT)
         # Include only uri if @ is present
         clean_deps = [d.split("@")[-1].strip() if "@" in d else d for d in deps]
         index_args = ["--extra-index-url", extra_index] if extra_index else []
         print("installing {}...".format(", ".join(clean_deps)))
         opts = opts or []
-        args = ["pip", "install"] + opts + clean_deps + index_args
+        args = ["uv", "pip", "install"] + opts + clean_deps + index_args
         sys.stdout.flush()
         subprocess.check_output(args, stderr=subprocess.STDOUT)
         return failed

wandb/sdk/launch/inputs/internal.py
@@ -11,12 +11,14 @@ import os
 import pathlib
 import shutil
 import tempfile
-from typing import List, Optional
+from typing import Any, Dict, List, Optional

 import wandb
 import wandb.data_types
 from wandb.sdk.launch.errors import LaunchError
+from wandb.sdk.launch.inputs.schema import META_SCHEMA
 from wandb.sdk.wandb_run import Run
+from wandb.util import get_module

 from .files import config_path_is_valid, override_file

@@ -62,11 +64,13 @@ class JobInputArguments:
         self,
         include: Optional[List[str]] = None,
         exclude: Optional[List[str]] = None,
+        schema: Optional[dict] = None,
         file_path: Optional[str] = None,
         run_config: Optional[bool] = None,
     ):
         self.include = include
         self.exclude = exclude
+        self.schema = schema
         self.file_path = file_path
         self.run_config = run_config

@@ -121,15 +125,70 @@ def _publish_job_input(
         exclude_paths=[_split_on_unesc_dot(path) for path in input.exclude]
         if input.exclude
         else [],
+        input_schema=input.schema,
         run_config=input.run_config,
         file_path=input.file_path or "",
     )


+def _replace_refs_and_allofs(schema: dict, defs: Optional[dict]) -> dict:
+    """Recursively fix JSON schemas with common issues.
+
+    1. Replaces any instances of $ref with their associated definition in defs
+    2. Removes any "allOf" lists that only have one item, "lifting" the item up
+    See test_internal.py for examples
+    """
+    ret: Dict[str, Any] = {}
+    if "$ref" in schema and defs:
+        # Reference found, replace it with its definition
+        def_key = schema["$ref"].split("#/$defs/")[1]
+        # Also run recursive replacement in case a ref contains more refs
+        return _replace_refs_and_allofs(defs.pop(def_key), defs)
+    for key, val in schema.items():
+        if isinstance(val, dict):
+            # Step into dicts recursively
+            new_val_dict = _replace_refs_and_allofs(val, defs)
+            ret[key] = new_val_dict
+        elif isinstance(val, list):
+            # Step into each item in the list
+            new_val_list = []
+            for item in val:
+                if isinstance(item, dict):
+                    new_val_list.append(_replace_refs_and_allofs(item, defs))
+                else:
+                    new_val_list.append(item)
+            # Lift up allOf blocks with only one item
+            if (
+                key == "allOf"
+                and len(new_val_list) == 1
+                and isinstance(new_val_list[0], dict)
+            ):
+                ret.update(new_val_list[0])
+            else:
+                ret[key] = new_val_list
+        else:
+            # For anything else (str, int, etc) keep it as-is
+            ret[key] = val
+    return ret
+
+
+def _validate_schema(schema: dict) -> None:
+    jsonschema = get_module(
+        "jsonschema",
+        required="Setting job schema requires the jsonschema package. Please install it with `pip install 'wandb[launch]'`.",
+        lazy=False,
+    )
+    validator = jsonschema.Draft202012Validator(META_SCHEMA)
+    errs = sorted(validator.iter_errors(schema), key=str)
+    if errs:
+        wandb.termwarn(f"Schema includes unhandled or invalid configurations:\n{errs}")
+
+
 def handle_config_file_input(
     path: str,
     include: Optional[List[str]] = None,
     exclude: Optional[List[str]] = None,
+    schema: Optional[Any] = None,
 ):
     """Declare an overridable configuration file for a launch job.
@@ -151,9 +210,24 @@
         path,
         dest,
     )
+    if schema:
+        # This supports both an instance of a pydantic BaseModel class (e.g. schema=MySchema(...))
+        # or the BaseModel class itself (e.g. schema=MySchema)
+        if hasattr(schema, "model_json_schema") and callable(
+            schema.model_json_schema  # type: ignore
+        ):
+            schema = schema.model_json_schema()
+        if not isinstance(schema, dict):
+            raise LaunchError(
+                "schema must be a dict, Pydantic model instance, or Pydantic model class."
+            )
+        defs = schema.pop("$defs", None)
+        schema = _replace_refs_and_allofs(schema, defs)
+        _validate_schema(schema)
     arguments = JobInputArguments(
         include=include,
         exclude=exclude,
+        schema=schema,
         file_path=path,
         run_config=False,
     )
@@ -165,7 +239,9 @@


 def handle_run_config_input(
-    include: Optional[List[str]] = None, exclude: Optional[List[str]] = None
+    include: Optional[List[str]] = None,
+    exclude: Optional[List[str]] = None,
+    schema: Optional[Any] = None,
 ):
     """Declare wandb.config as an overridable configuration for a launch job.

@@ -175,9 +251,24 @@
     If there is no active run, the include and exclude paths are staged and sent
     when a run is created.
     """
+    if schema:
+        # This supports both an instance of a pydantic BaseModel class (e.g. schema=MySchema(...))
+        # or the BaseModel class itself (e.g. schema=MySchema)
+        if hasattr(schema, "model_json_schema") and callable(
+            schema.model_json_schema  # type: ignore
+        ):
+            schema = schema.model_json_schema()
+        if not isinstance(schema, dict):
+            raise LaunchError(
+                "schema must be a dict, Pydantic model instance, or Pydantic model class."
+            )
+        defs = schema.pop("$defs", None)
+        schema = _replace_refs_and_allofs(schema, defs)
+        _validate_schema(schema)
     arguments = JobInputArguments(
         include=include,
         exclude=exclude,
+        schema=schema,
         run_config=True,
         file_path=None,
     )
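
A hypothetical end-to-end use of the new schema argument with a Pydantic model (the model and field names are illustrative; both the class and an instance are accepted, since model_json_schema is callable on either):

    from pydantic import BaseModel

    class TrainConfig(BaseModel):
        learning_rate: float = 1e-3
        epochs: int = 10

    # model_json_schema() is called internally, $defs are inlined via
    # _replace_refs_and_allofs, and the result is validated against META_SCHEMA.
    handle_run_config_input(
        include=["learning_rate", "epochs"],
        schema=TrainConfig,
    )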