wandb 0.17.4__py3-none-win32.whl → 0.17.6__py3-none-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. wandb/__init__.py +3 -1
  2. wandb/apis/public/api.py +1 -1
  3. wandb/apis/public/jobs.py +5 -0
  4. wandb/bin/wandb-core +0 -0
  5. wandb/data_types.py +2 -1
  6. wandb/env.py +6 -0
  7. wandb/filesync/upload_job.py +1 -1
  8. wandb/integration/lightning/fabric/logger.py +4 -4
  9. wandb/proto/v3/wandb_internal_pb2.py +339 -328
  10. wandb/proto/v3/wandb_settings_pb2.py +1 -1
  11. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  12. wandb/proto/v4/wandb_internal_pb2.py +326 -323
  13. wandb/proto/v4/wandb_settings_pb2.py +1 -1
  14. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  15. wandb/proto/v5/wandb_internal_pb2.py +326 -323
  16. wandb/proto/v5/wandb_settings_pb2.py +1 -1
  17. wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
  18. wandb/proto/wandb_deprecated.py +4 -0
  19. wandb/proto/wandb_internal_pb2.py +6 -0
  20. wandb/sdk/artifacts/artifact.py +16 -24
  21. wandb/sdk/artifacts/artifact_manifest_entry.py +31 -0
  22. wandb/sdk/artifacts/storage_handlers/azure_handler.py +35 -23
  23. wandb/sdk/data_types/object_3d.py +113 -2
  24. wandb/sdk/interface/interface.py +35 -5
  25. wandb/sdk/interface/interface_shared.py +9 -7
  26. wandb/sdk/internal/handler.py +1 -1
  27. wandb/sdk/internal/internal_api.py +4 -4
  28. wandb/sdk/internal/sender.py +40 -17
  29. wandb/sdk/launch/_launch.py +4 -2
  30. wandb/sdk/launch/_project_spec.py +34 -8
  31. wandb/sdk/launch/agent/agent.py +6 -2
  32. wandb/sdk/launch/agent/run_queue_item_file_saver.py +2 -4
  33. wandb/sdk/launch/builder/build.py +4 -2
  34. wandb/sdk/launch/builder/kaniko_builder.py +30 -9
  35. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +2 -1
  36. wandb/sdk/launch/inputs/internal.py +93 -2
  37. wandb/sdk/launch/inputs/manage.py +21 -3
  38. wandb/sdk/launch/inputs/schema.py +39 -0
  39. wandb/sdk/launch/runner/kubernetes_runner.py +72 -0
  40. wandb/sdk/launch/runner/local_container.py +13 -10
  41. wandb/sdk/launch/runner/sagemaker_runner.py +3 -5
  42. wandb/sdk/launch/utils.py +2 -0
  43. wandb/sdk/lib/disabled.py +13 -174
  44. wandb/sdk/lib/tracelog.py +2 -2
  45. wandb/sdk/wandb_init.py +23 -27
  46. wandb/sdk/wandb_login.py +6 -6
  47. wandb/sdk/wandb_manager.py +9 -5
  48. wandb/sdk/wandb_run.py +141 -97
  49. wandb/sdk/wandb_settings.py +3 -2
  50. wandb/util.py +29 -11
  51. wandb/wandb_agent.py +2 -0
  52. {wandb-0.17.4.dist-info → wandb-0.17.6.dist-info}/METADATA +3 -2
  53. {wandb-0.17.4.dist-info → wandb-0.17.6.dist-info}/RECORD +56 -55
  54. {wandb-0.17.4.dist-info → wandb-0.17.6.dist-info}/WHEEL +0 -0
  55. {wandb-0.17.4.dist-info → wandb-0.17.6.dist-info}/entry_points.txt +0 -0
  56. {wandb-0.17.4.dist-info → wandb-0.17.6.dist-info}/licenses/LICENSE +0 -0
@@ -137,6 +137,7 @@ class InterfaceShared(InterfaceBase):
         check_version: Optional[pb.CheckVersionRequest] = None,
         log_artifact: Optional[pb.LogArtifactRequest] = None,
         download_artifact: Optional[pb.DownloadArtifactRequest] = None,
+        link_artifact: Optional[pb.LinkArtifactRequest] = None,
         defer: Optional[pb.DeferRequest] = None,
         attach: Optional[pb.AttachRequest] = None,
         server_info: Optional[pb.ServerInfoRequest] = None,
@@ -184,6 +185,8 @@ class InterfaceShared(InterfaceBase):
             request.log_artifact.CopyFrom(log_artifact)
         elif download_artifact:
             request.download_artifact.CopyFrom(download_artifact)
+        elif link_artifact:
+            request.link_artifact.CopyFrom(link_artifact)
         elif defer:
             request.defer.CopyFrom(defer)
         elif attach:
@@ -242,7 +245,6 @@ class InterfaceShared(InterfaceBase):
         request: Optional[pb.Request] = None,
         telemetry: Optional[tpb.TelemetryRecord] = None,
         preempting: Optional[pb.RunPreemptingRecord] = None,
-        link_artifact: Optional[pb.LinkArtifactRecord] = None,
         use_artifact: Optional[pb.UseArtifactRecord] = None,
         output: Optional[pb.OutputRecord] = None,
         output_raw: Optional[pb.OutputRawRecord] = None,
@@ -282,8 +284,6 @@ class InterfaceShared(InterfaceBase):
             record.metric.CopyFrom(metric)
         elif preempting:
             record.preempting.CopyFrom(preempting)
-        elif link_artifact:
-            record.link_artifact.CopyFrom(link_artifact)
         elif use_artifact:
             record.use_artifact.CopyFrom(use_artifact)
         elif output:
@@ -393,10 +393,6 @@ class InterfaceShared(InterfaceBase):
         rec = self._make_record(files=files)
         self._publish(rec)

-    def _publish_link_artifact(self, link_artifact: pb.LinkArtifactRecord) -> Any:
-        rec = self._make_record(link_artifact=link_artifact)
-        self._publish(rec)
-
     def _publish_use_artifact(self, use_artifact: pb.UseArtifactRecord) -> Any:
         rec = self._make_record(use_artifact=use_artifact)
         self._publish(rec)
@@ -411,6 +407,12 @@ class InterfaceShared(InterfaceBase):
         rec = self._make_request(download_artifact=download_artifact)
         return self._deliver_record(rec)

+    def _deliver_link_artifact(
+        self, link_artifact: pb.LinkArtifactRequest
+    ) -> MailboxHandle:
+        rec = self._make_request(link_artifact=link_artifact)
+        return self._deliver_record(rec)
+
     def _publish_artifact(self, proto_artifact: pb.ArtifactRecord) -> None:
         rec = self._make_record(artifact=proto_artifact)
         self._publish(rec)
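Taken together, these interface changes move artifact linking from the one-way record path onto the request/response (mailbox) path, so the caller can block until the internal process answers and surface any server-side error. A rough caller-side sketch of the new pattern (the `interface` object and the `wait` timeout semantics are assumptions based on the surrounding code, not part of this diff):

```python
from wandb.proto import wandb_internal_pb2 as pb

# Build the new request type; field names follow the sender-side hunk below.
link = pb.LinkArtifactRequest()
link.client_id = "artifact-client-id"   # id of the artifact to link (assumed value)
link.portfolio_name = "my-portfolio"    # target portfolio/collection (assumed value)

# Deliver via the request path and block on the mailbox for the result.
handle = interface._deliver_link_artifact(link)   # returns a MailboxHandle
result = handle.wait(timeout=60)                  # None if the wait times out
if result is None:
    raise TimeoutError("link_artifact request timed out")
```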
@@ -230,7 +230,7 @@ class HandleManager:
     def handle_files(self, record: Record) -> None:
         self._dispatch_record(record)

-    def handle_link_artifact(self, record: Record) -> None:
+    def handle_request_link_artifact(self, record: Record) -> None:
        self._dispatch_record(record)

     def handle_use_artifact(self, record: Record) -> None:
@@ -232,14 +232,14 @@ class Api:

         # todo: remove these hacky hacks after settings refactor is complete
         # keeping this code here to limit scope and so that it is easy to remove later
-        extra_http_headers = self.settings("_extra_http_headers") or json.loads(
+        self._extra_http_headers = self.settings("_extra_http_headers") or json.loads(
             self._environ.get("WANDB__EXTRA_HTTP_HEADERS", "{}")
         )
-        extra_http_headers.update(_thread_local_api_settings.headers or {})
+        self._extra_http_headers.update(_thread_local_api_settings.headers or {})

         auth = None
         if self.access_token is not None:
-            extra_http_headers["Authorization"] = f"Bearer {self.access_token}"
+            self._extra_http_headers["Authorization"] = f"Bearer {self.access_token}"
         elif _thread_local_api_settings.cookies is None:
             auth = ("api", self.api_key or "")

@@ -253,7 +253,7 @@ class Api:
                 "User-Agent": self.user_agent,
                 "X-WANDB-USERNAME": env.get_username(env=self._environ),
                 "X-WANDB-USER-EMAIL": env.get_user_email(env=self._environ),
-                **extra_http_headers,
+                **self._extra_http_headers,
             },
             use_json=True,
             # this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
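The rename from a local `extra_http_headers` to `self._extra_http_headers` keeps the merged headers available for reuse elsewhere on the instance. The merge itself establishes a fixed precedence; a simplified standalone sketch (not the actual method):

```python
import json
import os

def build_extra_headers(settings_headers, thread_local_headers, access_token):
    # Settings win over the env var; the env var holds a JSON object string.
    headers = settings_headers or json.loads(
        os.environ.get("WANDB__EXTRA_HTTP_HEADERS", "{}")
    )
    # Thread-local headers (e.g. set by a multi-tenant service) override both.
    headers.update(thread_local_headers or {})
    # An access token takes priority as a bearer Authorization header.
    if access_token is not None:
        headers["Authorization"] = f"Bearer {access_token}"
    return headers

assert build_extra_headers(None, {"X-Custom": "1"}, "tok")["Authorization"] == "Bearer tok"
```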
@@ -1,5 +1,7 @@
 """sender."""

+import contextlib
+import gzip
 import json
 import logging
 import os
@@ -66,6 +68,7 @@ else:
 if TYPE_CHECKING:
     from wandb.proto.wandb_internal_pb2 import (
         ArtifactManifest,
+        ArtifactManifestEntry,
         ArtifactRecord,
         HttpResponse,
         LocalInfo,
@@ -105,22 +108,18 @@ def _framework_priority() -> Generator[Tuple[str, str], None, None]:

 def _manifest_json_from_proto(manifest: "ArtifactManifest") -> Dict:
     if manifest.version == 1:
-        contents = {
-            content.path: {
-                "digest": content.digest,
-                "birthArtifactID": content.birth_artifact_id
-                if content.birth_artifact_id
-                else None,
-                "ref": content.ref if content.ref else None,
-                "size": content.size if content.size is not None else None,
-                "local_path": content.local_path if content.local_path else None,
-                "skip_cache": content.skip_cache,
-                "extra": {
-                    extra.key: json.loads(extra.value_json) for extra in content.extra
-                },
+        if manifest.manifest_file_path:
+            contents = {}
+            with gzip.open(manifest.manifest_file_path, "rt") as f:
+                for line in f:
+                    entry_json = json.loads(line)
+                    path = entry_json.pop("path")
+                    contents[path] = entry_json
+        else:
+            contents = {
+                content.path: _manifest_entry_from_proto(content)
+                for content in manifest.contents
             }
-            for content in manifest.contents
-        }
     else:
         raise ValueError(f"unknown artifact manifest version: {manifest.version}")
@@ -135,6 +134,19 @@ def _manifest_json_from_proto(manifest: "ArtifactManifest") -> Dict:
     }


+def _manifest_entry_from_proto(entry: "ArtifactManifestEntry") -> Dict:
+    birth_artifact_id = entry.birth_artifact_id if entry.birth_artifact_id else None
+    return {
+        "digest": entry.digest,
+        "birthArtifactID": birth_artifact_id,
+        "ref": entry.ref if entry.ref else None,
+        "size": entry.size if entry.size is not None else None,
+        "local_path": entry.local_path if entry.local_path else None,
+        "skip_cache": entry.skip_cache,
+        "extra": {extra.key: json.loads(extra.value_json) for extra in entry.extra},
+    }
+
+
 class ResumeState:
     resumed: bool
     step: int
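With `manifest_file_path` set, the core process can hand the sender a gzipped JSON-lines file instead of streaming every manifest entry through protobuf, keeping very large manifests out of proto memory. A round-trip sketch of that file format (the writer half is an assumption; only the reader appears in this diff):

```python
import gzip
import json

entries = {
    "model.pt": {"digest": "abc123", "size": 1024, "skip_cache": False},
    "config.yaml": {"digest": "def456", "size": 64, "skip_cache": False},
}

# Assumed writer side: one JSON object per line, with the path stored inline.
with gzip.open("manifest.jsonl.gz", "wt") as f:
    for path, entry in entries.items():
        f.write(json.dumps({"path": path, **entry}) + "\n")

# Reader side, mirroring _manifest_json_from_proto: pop the path back out.
contents = {}
with gzip.open("manifest.jsonl.gz", "rt") as f:
    for line in f:
        entry_json = json.loads(line)
        contents[entry_json.pop("path")] = entry_json

assert contents == entries
```

Note that the `send_artifact` hunk further down deletes this temporary file once the artifact has been committed.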
@@ -1473,8 +1485,13 @@ class SendManager:
         # tbrecord watching threads are handled by handler.py
         pass

-    def send_link_artifact(self, record: "Record") -> None:
-        link = record.link_artifact
+    def send_request_link_artifact(self, record: "Record") -> None:
+        if not (record.control.req_resp or record.control.mailbox_slot):
+            raise ValueError(
+                f"Expected either `req_resp` or `mailbox_slot`, got: {record.control!r}"
+            )
+        result = proto_util._result_from_record(record)
+        link = record.request.link_artifact
         client_id = link.client_id
         server_id = link.server_id
         portfolio_name = link.portfolio_name
@@ -1490,7 +1507,9 @@ class SendManager:
                 client_id, server_id, portfolio_name, entity, project, aliases
             )
         except Exception as e:
+            result.response.log_artifact_response.error_message = f'error linking artifact to "{entity}/{project}/{portfolio_name}"; error: {e}'
             logger.warning("Failed to link artifact to portfolio: %s", e)
+        self._respond_result(result)

     def send_use_artifact(self, record: "Record") -> None:
         """Pretend to send a used artifact.
@@ -1579,6 +1598,10 @@ class SendManager:
             )

         self._job_builder._handle_server_artifact(res, artifact)
+
+        if artifact.manifest.manifest_file_path:
+            with contextlib.suppress(FileNotFoundError):
+                os.remove(artifact.manifest.manifest_file_path)
         return res

     def send_alert(self, record: "Record") -> None:
@@ -211,7 +211,9 @@ async def _launch(
     launch_project = LaunchProject.from_spec(launch_spec, api)
     launch_project.fetch_and_validate_project()
     entrypoint = launch_project.get_job_entry_point()
-    image_uri = launch_project.docker_image  # Either set by user or None.
+    image_uri = (
+        launch_project.docker_image or launch_project.job_base_image
+    )  # Either set by user or None.

     # construct runner config.
     runner_config: Dict[str, Any] = {}
@@ -224,7 +226,7 @@ async def _launch(
     await environment.verify()
     registry = loader.registry_from_config(registry_config, environment)
     builder = loader.builder_from_config(build_config, environment, registry)
-    if not launch_project.docker_image:
+    if not (launch_project.docker_image or launch_project.job_base_image):
         assert entrypoint
         image_uri = await builder.build_image(launch_project, entrypoint, None)
     backend = loader.runner_from_config(
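These two hunks establish the launch-side precedence: an explicit docker image wins, then a job base image, and only when both are absent does launch build one. Distilled into a standalone helper (names mirror the diff; the helper itself is illustrative):

```python
from typing import Optional

def resolve_image_uri(
    docker_image: Optional[str], job_base_image: Optional[str]
) -> Optional[str]:
    # Explicit image first, then base image; None means a build is required.
    return docker_image or job_base_image

assert resolve_image_uri("my-img:v1", "python:3.10") == "my-img:v1"
assert resolve_image_uri(None, "python:3.10") == "python:3.10"
assert resolve_image_uri(None, None) is None  # falls through to builder.build_image
```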
@@ -7,6 +7,7 @@ import enum
 import json
 import logging
 import os
+import shutil
 import tempfile
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
@@ -112,6 +113,9 @@ class LaunchProject:
         self.sweep_id = sweep_id
         self.author = launch_spec.get("author")
         self.python_version: Optional[str] = launch_spec.get("python_version")
+        self._job_dockerfile: Optional[str] = None
+        self._job_build_context: Optional[str] = None
+        self._job_base_image: Optional[str] = None
         self.accelerator_base_image: Optional[str] = resource_args_build.get(
             "accelerator", {}
         ).get("base_image") or resource_args_build.get("cuda", {}).get("base_image")
@@ -131,8 +135,6 @@ class LaunchProject:
         self._queue_name: Optional[str] = None
         self._queue_entity: Optional[str] = None
         self._run_queue_item_id: Optional[str] = None
-        self._job_dockerfile: Optional[str] = None
-        self._job_build_context: Optional[str] = None

     def init_source(self) -> None:
         if self.docker_image is not None:
  def init_source(self) -> None:
138
140
  if self.docker_image is not None:
@@ -146,6 +148,21 @@ class LaunchProject:
146
148
  self.project_dir = os.getcwd()
147
149
  self._entry_point = self.override_entrypoint
148
150
 
151
+ def change_project_dir(self, new_dir: str) -> None:
152
+ """Change the project directory to a new directory."""
153
+ # Copy the contents of the old project dir to the new project dir.
154
+ old_dir = self.project_dir
155
+ if old_dir is not None:
156
+ shutil.copytree(
157
+ old_dir,
158
+ new_dir,
159
+ symlinks=True,
160
+ dirs_exist_ok=True,
161
+ ignore=shutil.ignore_patterns("fsmonitor--daemon.ipc", ".git"),
162
+ )
163
+ shutil.rmtree(old_dir)
164
+ self.project_dir = new_dir
165
+
149
166
  def init_git(self, git_info: Dict[str, str]) -> None:
150
167
  self.git_version = git_info.get("version")
151
168
  self.git_repo = git_info.get("repo")
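`change_project_dir` relocates a staged project with copy-then-delete rather than `shutil.move`, so the destination may already exist and git internals can be filtered out of the copy. The same pattern in isolation:

```python
import os
import shutil
import tempfile

def relocate(old_dir: str, new_dir: str) -> None:
    # dirs_exist_ok=True (Python 3.8+) merges into an existing directory;
    # the ignore filter drops git metadata that shouldn't ship with the job.
    shutil.copytree(
        old_dir,
        new_dir,
        symlinks=True,
        dirs_exist_ok=True,
        ignore=shutil.ignore_patterns("fsmonitor--daemon.ipc", ".git"),
    )
    shutil.rmtree(old_dir)

src, dst = tempfile.mkdtemp(), tempfile.mkdtemp()  # dst already exists
open(os.path.join(src, "train.py"), "w").close()
relocate(src, dst)
assert os.path.exists(os.path.join(dst, "train.py"))
```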
@@ -212,14 +229,23 @@ class LaunchProject:
     def job_build_context(self) -> Optional[str]:
         return self._job_build_context

+    @property
+    def job_base_image(self) -> Optional[str]:
+        return self._job_base_image
+
     def set_job_dockerfile(self, dockerfile: str) -> None:
         self._job_dockerfile = dockerfile

     def set_job_build_context(self, build_context: str) -> None:
         self._job_build_context = build_context

+    def set_job_base_image(self, base_image: str) -> None:
+        self._job_base_image = base_image
+
     @property
     def image_name(self) -> str:
+        if self.job_base_image is not None:
+            return self.job_base_image
         if self.docker_image is not None:
             return self.docker_image
         elif self.uri is not None:
@@ -299,10 +325,8 @@ class LaunchProject:

     def build_required(self) -> bool:
         """Checks the source to see if a build is required."""
-        # since the image tag for images built from jobs
-        # is based on the job version index, which is immutable
-        # we don't need to build the image for a job if that tag
-        # already exists
+        if self.job_base_image is not None:
+            return False
         if self.source != LaunchSource.JOB:
             return True
         return False
@@ -316,7 +340,9 @@ class LaunchProject:
         Returns:
             Optional[str]: The Docker image or None if not specified.
         """
-        return self._docker_image
+        if self._docker_image:
+            return self._docker_image
+        return None

     @docker_image.setter
     def docker_image(self, value: str) -> None:
@@ -336,7 +362,7 @@ class LaunchProject:
         # assuming project only has 1 entry point, pull that out
         # tmp fn until we figure out if we want to support multiple entry points or not
         if not self._entry_point:
-            if not self.docker_image:
+            if not self.docker_image and not self.job_base_image:
                 raise LaunchError(
                     "Project must have at least one entry point unless docker image is specified."
                 )
@@ -717,7 +717,7 @@ class LaunchAgent:
         _, build_config, registry_config = construct_agent_configs(
             default_config, override_build_config
         )
-        image_uri = project.docker_image
+        image_uri = project.docker_image or project.job_base_image
         entrypoint = project.get_job_entry_point()
         environment = loader.environment_from_config(
             default_config.get("environment", {})
@@ -727,7 +727,11 @@ class LaunchAgent:
         backend = loader.runner_from_config(
             resource, api, backend_config, environment, registry
         )
-        if not (project.docker_image or isinstance(backend, LocalProcessRunner)):
+        if not (
+            project.docker_image
+            or project.job_base_image
+            or isinstance(backend, LocalProcessRunner)
+        ):
             assert entrypoint is not None
             image_uri = await builder.build_image(project, entrypoint, job_tracker)

@@ -2,7 +2,7 @@

 import os
 import sys
-from typing import List, Optional, Union
+from typing import List, Optional

 import wandb

@@ -17,9 +17,7 @@ FileSubtypes = Literal["warning", "error"]
 class RunQueueItemFileSaver:
     def __init__(
         self,
-        agent_run: Optional[
-            Union["wandb.sdk.wandb_run.Run", "wandb.sdk.lib.RunDisabled"]
-        ],
+        agent_run: Optional["wandb.sdk.wandb_run.Run"],
         run_queue_item_id: str,
     ):
         self.run_queue_item_id = run_queue_item_id
@@ -201,7 +201,7 @@ def get_requirements_section(
     # If there is a requirements.txt at root of build context, use that.
     if (base_path / "src" / "requirements.txt").exists():
         requirements_files += ["src/requirements.txt"]
-        deps_install_line = "pip install -r requirements.txt"
+        deps_install_line = "pip install uv && uv pip install -r requirements.txt"
         with open(base_path / "src" / "requirements.txt") as f:
             requirements = f.readlines()
             if not any(["wandb" in r for r in requirements]):
@@ -237,7 +237,9 @@ def get_requirements_section(
         with open(base_path / "src" / "requirements.txt", "w") as f:
             f.write("\n".join(project_deps))
         requirements_files += ["src/requirements.txt"]
-        deps_install_line = "pip install -r requirements.txt"
+        deps_install_line = (
+            "pip install uv && uv pip install -r requirements.txt"
+        )
     return PIP_TEMPLATE.format(
         buildx_optional_prefix=prefix,
         requirements_files=" ".join(requirements_files),
@@ -263,11 +263,17 @@ class KanikoBuilder(AbstractBuilder):
         repo_uri = await self.registry.get_repo_uri()
         image_uri = repo_uri + ":" + image_tag

-        if (
-            not launch_project.build_required()
-            and await self.registry.check_image_exists(image_uri)
-        ):
-            return image_uri
+        # The DOCKER_CONFIG_SECRET option is mutually exclusive with the
+        # registry classes, so we must skip the check for image existence in
+        # that case.
+        if not launch_project.build_required():
+            if DOCKER_CONFIG_SECRET:
+                wandb.termlog(
+                    f"Skipping check for existing image {image_uri} due to custom dockerconfig."
+                )
+            else:
+                if await self.registry.check_image_exists(image_uri):
+                    return image_uri

         _logger.info(f"Building image {image_uri}...")
         _, api_client = await get_kube_context_and_api_client(
@@ -286,7 +292,12 @@ class KanikoBuilder(AbstractBuilder):
         wandb.termlog(f"{LOG_PREFIX}Created kaniko job {build_job_name}")

         try:
-            if isinstance(self.registry, AzureContainerRegistry):
+            # DOCKER_CONFIG_SECRET is a user provided dockerconfigjson. Skip our
+            # dockerconfig handling if it's set.
+            if (
+                isinstance(self.registry, AzureContainerRegistry)
+                and not DOCKER_CONFIG_SECRET
+            ):
                 dockerfile_config_map = client.V1ConfigMap(
                     metadata=client.V1ObjectMeta(
                         name=f"docker-config-{build_job_name}"
@@ -344,7 +355,10 @@ class KanikoBuilder(AbstractBuilder):
         finally:
             wandb.termlog(f"{LOG_PREFIX}Cleaning up resources")
             try:
-                if isinstance(self.registry, AzureContainerRegistry):
+                if (
+                    isinstance(self.registry, AzureContainerRegistry)
+                    and not DOCKER_CONFIG_SECRET
+                ):
                     await core_v1.delete_namespaced_config_map(
                         f"docker-config-{build_job_name}", "wandb"
                     )
@@ -498,7 +512,10 @@ class KanikoBuilder(AbstractBuilder):
                 "readOnly": True,
             }
         )
-        if isinstance(self.registry, AzureContainerRegistry):
+        if (
+            isinstance(self.registry, AzureContainerRegistry)
+            and not DOCKER_CONFIG_SECRET
+        ):
             # Add the docker config map
             volumes.append(
                 {
@@ -533,7 +550,11 @@ class KanikoBuilder(AbstractBuilder):
         # Apply the rest of our defaults
         pod_labels["wandb"] = "launch"
         # This annotation is required to enable azure workload identity.
-        if isinstance(self.registry, AzureContainerRegistry):
+        # Don't add this label if using a docker config secret for auth.
+        if (
+            isinstance(self.registry, AzureContainerRegistry)
+            and not DOCKER_CONFIG_SECRET
+        ):
             pod_labels["azure.workload.identity/use"] = "true"
         pod_spec["restartPolicy"] = pod_spec.get("restartPolicy", "Never")
         pod_spec["activeDeadlineSeconds"] = pod_spec.get(
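The same three-line guard now appears in four places in this file. A hedged refactoring sketch (not in the diff) would hoist it into a small predicate on `KanikoBuilder` so the Azure config map, volume mount, and pod label can never drift apart:

```python
# Sketch only: AzureContainerRegistry and DOCKER_CONFIG_SECRET are the
# module-level names already used in kaniko_builder.py.
def _azure_docker_config_enabled(self) -> bool:
    """True when our Azure dockerconfig handling should run, i.e. the registry
    is ACR and the user has not supplied their own dockerconfigjson via
    DOCKER_CONFIG_SECRET."""
    return isinstance(self.registry, AzureContainerRegistry) and not DOCKER_CONFIG_SECRET
```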
@@ -39,12 +39,13 @@ def install_deps(
         deps (str[], None): The dependencies that failed to install
     """
     try:
+        subprocess.check_output(["pip", "install", "uv"], stderr=subprocess.STDOUT)
         # Include only uri if @ is present
         clean_deps = [d.split("@")[-1].strip() if "@" in d else d for d in deps]
         index_args = ["--extra-index-url", extra_index] if extra_index else []
         print("installing {}...".format(", ".join(clean_deps)))
         opts = opts or []
-        args = ["pip", "install"] + opts + clean_deps + index_args
+        args = ["uv", "pip", "install"] + opts + clean_deps + index_args
         sys.stdout.flush()
         subprocess.check_output(args, stderr=subprocess.STDOUT)
         return failed
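Both the generated Dockerfiles and this bootstrap script now install `uv` first and route dependency installation through `uv pip install`, which accepts pip's requirements syntax while resolving much faster. The bootstrap change reduced to its essentials:

```python
import subprocess
import sys

def install_with_uv(deps):
    # Bootstrap uv itself with pip once, then let uv do the heavy lifting.
    subprocess.check_output(["pip", "install", "uv"], stderr=subprocess.STDOUT)
    sys.stdout.flush()
    # uv is a drop-in for `pip install` here, including extra index args.
    subprocess.check_output(["uv", "pip", "install"] + list(deps), stderr=subprocess.STDOUT)

# install_with_uv(["numpy", "wandb"])
```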
@@ -11,12 +11,14 @@ import os
 import pathlib
 import shutil
 import tempfile
-from typing import List, Optional
+from typing import Any, Dict, List, Optional

 import wandb
 import wandb.data_types
 from wandb.sdk.launch.errors import LaunchError
+from wandb.sdk.launch.inputs.schema import META_SCHEMA
 from wandb.sdk.wandb_run import Run
+from wandb.util import get_module

 from .files import config_path_is_valid, override_file
@@ -62,11 +64,13 @@ class JobInputArguments:
         self,
         include: Optional[List[str]] = None,
         exclude: Optional[List[str]] = None,
+        schema: Optional[dict] = None,
         file_path: Optional[str] = None,
         run_config: Optional[bool] = None,
     ):
         self.include = include
         self.exclude = exclude
+        self.schema = schema
         self.file_path = file_path
         self.run_config = run_config

@@ -121,15 +125,70 @@ def _publish_job_input(
         exclude_paths=[_split_on_unesc_dot(path) for path in input.exclude]
         if input.exclude
         else [],
+        input_schema=input.schema,
         run_config=input.run_config,
         file_path=input.file_path or "",
     )


+def _replace_refs_and_allofs(schema: dict, defs: Optional[dict]) -> dict:
+    """Recursively fix JSON schemas with common issues.
+
+    1. Replaces any instances of $ref with their associated definition in defs
+    2. Removes any "allOf" lists that only have one item, "lifting" the item up
+    See test_internal.py for examples
+    """
+    ret: Dict[str, Any] = {}
+    if "$ref" in schema and defs:
+        # Reference found, replace it with its definition
+        def_key = schema["$ref"].split("#/$defs/")[1]
+        # Also run recursive replacement in case a ref contains more refs
+        return _replace_refs_and_allofs(defs.pop(def_key), defs)
+    for key, val in schema.items():
+        if isinstance(val, dict):
+            # Step into dicts recursively
+            new_val_dict = _replace_refs_and_allofs(val, defs)
+            ret[key] = new_val_dict
+        elif isinstance(val, list):
+            # Step into each item in the list
+            new_val_list = []
+            for item in val:
+                if isinstance(item, dict):
+                    new_val_list.append(_replace_refs_and_allofs(item, defs))
+                else:
+                    new_val_list.append(item)
+            # Lift up allOf blocks with only one item
+            if (
+                key == "allOf"
+                and len(new_val_list) == 1
+                and isinstance(new_val_list[0], dict)
+            ):
+                ret.update(new_val_list[0])
+            else:
+                ret[key] = new_val_list
+        else:
+            # For anything else (str, int, etc) keep it as-is
+            ret[key] = val
+    return ret
+
+
+def _validate_schema(schema: dict) -> None:
+    jsonschema = get_module(
+        "jsonschema",
+        required="Setting job schema requires the jsonschema package. Please install it with `pip install 'wandb[launch]'`.",
+        lazy=False,
+    )
+    validator = jsonschema.Draft202012Validator(META_SCHEMA)
+    errs = sorted(validator.iter_errors(schema), key=str)
+    if errs:
+        wandb.termwarn(f"Schema includes unhandled or invalid configurations:\n{errs}")
+
+
 def handle_config_file_input(
     path: str,
     include: Optional[List[str]] = None,
     exclude: Optional[List[str]] = None,
+    schema: Optional[Any] = None,
 ):
     """Declare an overridable configuration file for a launch job.

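To see what `_replace_refs_and_allofs` buys, consider the schema Pydantic v2 typically emits for a nested model: the nested type lands in `$defs`, is referenced via `$ref`, and is wrapped in a one-element `allOf`. After flattening, the definition is inlined (a small illustrative input, not taken from the wandb test suite):

```python
schema = {
    "type": "object",
    "properties": {
        "trainer": {"allOf": [{"$ref": "#/$defs/Trainer"}]},
    },
}
defs = {
    "Trainer": {"type": "object", "properties": {"lr": {"type": "number"}}},
}

flat = _replace_refs_and_allofs(schema, defs)
# The $ref is resolved and the single-item allOf is lifted away:
assert flat["properties"]["trainer"] == {
    "type": "object",
    "properties": {"lr": {"type": "number"}},
}
```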
@@ -151,9 +210,24 @@ def handle_config_file_input(
         path,
         dest,
     )
+    if schema:
+        # This supports both an instance of a pydantic BaseModel class (e.g. schema=MySchema(...))
+        # or the BaseModel class itself (e.g. schema=MySchema)
+        if hasattr(schema, "model_json_schema") and callable(
+            schema.model_json_schema  # type: ignore
+        ):
+            schema = schema.model_json_schema()
+        if not isinstance(schema, dict):
+            raise LaunchError(
+                "schema must be a dict, Pydantic model instance, or Pydantic model class."
+            )
+        defs = schema.pop("$defs", None)
+        schema = _replace_refs_and_allofs(schema, defs)
+        _validate_schema(schema)
     arguments = JobInputArguments(
         include=include,
         exclude=exclude,
+        schema=schema,
         file_path=path,
         run_config=False,
     )
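This handler is what the public wrappers in `manage.py` (item 37 in the file list) feed into. Assuming the wrapper is named `manage_config_file`, declaring a schema for a config file looks roughly like this; a Pydantic v2 class works because the handler calls `model_json_schema()` on it:

```python
from pydantic import BaseModel  # pydantic v2

from wandb.sdk.launch.inputs.manage import manage_config_file  # assumed wrapper name

class TrainerConfig(BaseModel):
    lr: float = 1e-3
    epochs: int = 10

# Declares config.yaml as overridable for launch jobs; the class is converted
# to JSON schema, $refs/allOfs are flattened, then validated against META_SCHEMA.
manage_config_file("config.yaml", schema=TrainerConfig)
```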
@@ -165,7 +239,9 @@ def handle_config_file_input(


 def handle_run_config_input(
-    include: Optional[List[str]] = None, exclude: Optional[List[str]] = None
+    include: Optional[List[str]] = None,
+    exclude: Optional[List[str]] = None,
+    schema: Optional[Any] = None,
 ):
     """Declare wandb.config as an overridable configuration for a launch job.

@@ -175,9 +251,24 @@ def handle_run_config_input(

     If there is no active run, the include and exclude paths are staged and sent
     when a run is created.
     """
+    if schema:
+        # This supports both an instance of a pydantic BaseModel class (e.g. schema=MySchema(...))
+        # or the BaseModel class itself (e.g. schema=MySchema)
+        if hasattr(schema, "model_json_schema") and callable(
+            schema.model_json_schema  # type: ignore
+        ):
+            schema = schema.model_json_schema()
+        if not isinstance(schema, dict):
+            raise LaunchError(
+                "schema must be a dict, Pydantic model instance, or Pydantic model class."
+            )
+        defs = schema.pop("$defs", None)
+        schema = _replace_refs_and_allofs(schema, defs)
+        _validate_schema(schema)
     arguments = JobInputArguments(
         include=include,
         exclude=exclude,
+        schema=schema,
         run_config=True,
         file_path=None,
     )
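The run-config variant accepts the same three schema forms. A plain dict sidesteps the Pydantic dependency entirely, and invalid schemas only produce a `termwarn`, not a hard failure (again assuming a `manage_wandb_config` wrapper forwards here):

```python
import wandb
from wandb.sdk.launch.inputs.manage import manage_wandb_config  # assumed wrapper name

with wandb.init(config={"lr": 1e-3, "epochs": 10}):
    manage_wandb_config(
        include=["lr", "epochs"],
        schema={
            "type": "object",
            "properties": {
                "lr": {"type": "number"},
                "epochs": {"type": "integer", "minimum": 1},
            },
        },
    )
```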