wandb 0.16.4__py3-none-any.whl → 0.16.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. wandb/__init__.py +2 -2
  2. wandb/agents/pyagent.py +1 -1
  3. wandb/apis/public/api.py +6 -6
  4. wandb/apis/reports/v2/interface.py +4 -8
  5. wandb/apis/reports/v2/internal.py +12 -45
  6. wandb/cli/cli.py +29 -5
  7. wandb/integration/openai/fine_tuning.py +74 -37
  8. wandb/integration/ultralytics/callback.py +0 -1
  9. wandb/proto/v3/wandb_internal_pb2.py +332 -312
  10. wandb/proto/v3/wandb_settings_pb2.py +13 -3
  11. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  12. wandb/proto/v4/wandb_internal_pb2.py +316 -312
  13. wandb/proto/v4/wandb_settings_pb2.py +5 -3
  14. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  15. wandb/sdk/artifacts/artifact.py +92 -26
  16. wandb/sdk/artifacts/artifact_manifest_entry.py +6 -1
  17. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
  18. wandb/sdk/artifacts/artifact_saver.py +16 -36
  19. wandb/sdk/artifacts/storage_handler.py +2 -1
  20. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +13 -5
  21. wandb/sdk/interface/interface.py +60 -15
  22. wandb/sdk/interface/interface_shared.py +13 -7
  23. wandb/sdk/internal/file_stream.py +19 -0
  24. wandb/sdk/internal/handler.py +1 -4
  25. wandb/sdk/internal/internal_api.py +2 -0
  26. wandb/sdk/internal/job_builder.py +45 -17
  27. wandb/sdk/internal/sender.py +53 -28
  28. wandb/sdk/internal/settings_static.py +9 -0
  29. wandb/sdk/internal/system/system_info.py +4 -1
  30. wandb/sdk/launch/_launch.py +5 -0
  31. wandb/sdk/launch/_project_spec.py +5 -20
  32. wandb/sdk/launch/agent/agent.py +80 -37
  33. wandb/sdk/launch/agent/config.py +8 -0
  34. wandb/sdk/launch/builder/kaniko_builder.py +149 -134
  35. wandb/sdk/launch/create_job.py +44 -48
  36. wandb/sdk/launch/runner/kubernetes_monitor.py +3 -1
  37. wandb/sdk/launch/runner/kubernetes_runner.py +20 -2
  38. wandb/sdk/launch/sweeps/scheduler.py +3 -1
  39. wandb/sdk/launch/utils.py +23 -5
  40. wandb/sdk/lib/__init__.py +2 -5
  41. wandb/sdk/lib/_settings_toposort_generated.py +2 -0
  42. wandb/sdk/lib/filesystem.py +11 -1
  43. wandb/sdk/lib/run_moment.py +78 -0
  44. wandb/sdk/service/streams.py +1 -6
  45. wandb/sdk/wandb_init.py +12 -7
  46. wandb/sdk/wandb_login.py +43 -26
  47. wandb/sdk/wandb_run.py +179 -94
  48. wandb/sdk/wandb_settings.py +55 -16
  49. wandb/testing/relay.py +5 -6
  50. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/METADATA +1 -1
  51. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/RECORD +55 -54
  52. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/WHEEL +1 -1
  53. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/LICENSE +0 -0
  54. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/entry_points.txt +0 -0
  55. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/top_level.txt +0 -0
wandb/sdk/launch/agent/agent.py

@@ -45,7 +45,10 @@ MAX_RESUME_COUNT = 5

  RUN_INFO_GRACE_PERIOD = 60

- MAX_WAIT_RUN_STOPPED = 60
+ DEFAULT_STOPPED_RUN_TIMEOUT = 60
+
+ DEFAULT_PRINT_INTERVAL = 5 * 60
+ VERBOSE_PRINT_INTERVAL = 20

  _env_timeout = os.environ.get("WANDB_LAUNCH_START_TIMEOUT")
  if _env_timeout:
@@ -105,30 +108,29 @@ def _max_from_config(
  return max_from_config


- def _is_scheduler_job(run_spec: Dict[str, Any]) -> bool:
- """Determine whether a job/runSpec is a sweep scheduler."""
- if not run_spec:
- _logger.debug("Recieved runSpec in _is_scheduler_job that was empty")
+ class InternalAgentLogger:
+ def __init__(self, verbosity=0):
+ self._print_to_terminal = verbosity >= 2

- if run_spec.get("uri") != Scheduler.PLACEHOLDER_URI:
- return False
+ def error(self, message: str):
+ if self._print_to_terminal:
+ wandb.termerror(f"{LOG_PREFIX}{message}")
+ _logger.error(f"{LOG_PREFIX}{message}")

- if run_spec.get("resource") == "local-process":
- # Any job pushed to a run queue that has a scheduler uri is
- # allowed to use local-process
- if run_spec.get("job"):
- return True
+ def warn(self, message: str):
+ if self._print_to_terminal:
+ wandb.termwarn(f"{LOG_PREFIX}{message}")
+ _logger.warn(f"{LOG_PREFIX}{message}")

- # If a scheduler is local-process and run through CLI, also
- # confirm command is in format: [wandb scheduler <sweep>]
- cmd = run_spec.get("overrides", {}).get("entry_point", [])
- if len(cmd) < 3:
- return False
+ def info(self, message: str):
+ if self._print_to_terminal:
+ wandb.termlog(f"{LOG_PREFIX}{message}")
+ _logger.info(f"{LOG_PREFIX}{message}")

- if cmd[:2] != ["wandb", "scheduler"]:
- return False
-
- return True
+ def debug(self, message: str):
+ if self._print_to_terminal:
+ wandb.termlog(f"{LOG_PREFIX}{message}")
+ _logger.debug(f"{LOG_PREFIX}{message}")


  class LaunchAgent:
@@ -184,7 +186,13 @@ class LaunchAgent:
  self._max_jobs = _max_from_config(config, "max_jobs")
  self._max_schedulers = _max_from_config(config, "max_schedulers")
  self._secure_mode = config.get("secure_mode", False)
+ self._verbosity = config.get("verbosity", 0)
+ self._internal_logger = InternalAgentLogger(verbosity=self._verbosity)
+ self._last_status_print_time = 0.0
  self.default_config: Dict[str, Any] = config
+ self._stopped_run_timeout = config.get(
+ "stopped_run_timeout", DEFAULT_STOPPED_RUN_TIMEOUT
+ )

  # Get agent version from env var if present, otherwise wandb version
  self.version: str = "wandb@" + wandb.__version__
@@ -228,6 +236,33 @@ class LaunchAgent:
  self._name = agent_response["name"]
  self._init_agent_run()

+ def _is_scheduler_job(self, run_spec: Dict[str, Any]) -> bool:
+ """Determine whether a job/runSpec is a sweep scheduler."""
+ if not run_spec:
+ self._internal_logger.debug(
+ "Recieved runSpec in _is_scheduler_job that was empty"
+ )
+
+ if run_spec.get("uri") != Scheduler.PLACEHOLDER_URI:
+ return False
+
+ if run_spec.get("resource") == "local-process":
+ # Any job pushed to a run queue that has a scheduler uri is
+ # allowed to use local-process
+ if run_spec.get("job"):
+ return True
+
+ # If a scheduler is local-process and run through CLI, also
+ # confirm command is in format: [wandb scheduler <sweep>]
+ cmd = run_spec.get("overrides", {}).get("entry_point", [])
+ if len(cmd) < 3:
+ return False
+
+ if cmd[:2] != ["wandb", "scheduler"]:
+ return False
+
+ return True
+
  async def fail_run_queue_item(
  self,
  run_queue_item_id: str,
@@ -298,6 +333,7 @@ class LaunchAgent:

  def print_status(self) -> None:
  """Prints the current status of the agent."""
+ self._last_status_print_time = time.time()
  output_str = "agent "
  if self._name:
  output_str += f"{self._name} "
@@ -344,8 +380,8 @@ class LaunchAgent:
  if run_state.lower() != "pending":
  return True
  except CommError:
- _logger.info(
- f"Run {entity}/{project}/{run_id} with rqi id: {rqi_id} did not have associated run"
+ self._internal_logger.info(
+ f"Run {entity}/{project}/{run_id} with rqi id: {rqi_id} did not have associated run",
  )
  return False

@@ -361,8 +397,8 @@ class LaunchAgent:
  job_and_run_status.entity is not None
  and job_and_run_status.entity != self._entity
  ):
- _logger.info(
- "Skipping check for completed run status because run is on a different entity than agent"
+ self._internal_logger.info(
+ "Skipping check for completed run status because run is on a different entity than agent",
  )
  elif exception is not None:
  tb_str = traceback.format_exception(
@@ -378,8 +414,8 @@ class LaunchAgent:
  fnames,
  )
  elif job_and_run_status.project is None or job_and_run_status.run_id is None:
- _logger.error(
- f"called finish_thread_id on thread whose tracker has no project or run id. RunQueueItemID: {job_and_run_status.run_queue_item_id}"
+ self._internal_logger.info(
+ f"called finish_thread_id on thread whose tracker has no project or run id. RunQueueItemID: {job_and_run_status.run_queue_item_id}",
  )
  wandb.termerror(
  "Missing project or run id on thread called finish thread id"
@@ -430,7 +466,9 @@ class LaunchAgent:
  job_and_run_status.run_queue_item_id, _msg, "run", fnames
  )
  else:
- _logger.info(f"Finish thread id {thread_id} had no exception and no run")
+ self._internal_logger.info(
+ f"Finish thread id {thread_id} had no exception and no run"
+ )
  wandb._sentry.exception(
  "launch agent called finish thread id on thread without run or exception"
  )
@@ -458,7 +496,7 @@ class LaunchAgent:
  await self.update_status(AGENT_RUNNING)

  # parse job
- _logger.info("Parsing launch spec")
+ self._internal_logger.info("Parsing launch spec")
  launch_spec = job["runSpec"]

  # Abort if this job attempts to override secure mode
@@ -511,6 +549,10 @@ class LaunchAgent:
  KeyboardInterrupt: if the agent is requested to stop.
  """
  self.print_status()
+ if self._verbosity == 0:
+ print_interval = DEFAULT_PRINT_INTERVAL
+ else:
+ print_interval = VERBOSE_PRINT_INTERVAL
  try:
  while True:
  job = None
@@ -532,7 +574,7 @@ class LaunchAgent:
  file_saver = RunQueueItemFileSaver(
  self._wandb_run, job["runQueueItemId"]
  )
- if _is_scheduler_job(job.get("runSpec", {})):
+ if self._is_scheduler_job(job.get("runSpec", {})):
  # If job is a scheduler, and we are already at the cap, ignore,
  # don't ack, and it will be pushed back onto the queue in 1 min
  if self.num_running_schedulers >= self._max_schedulers:
@@ -567,6 +609,7 @@ class LaunchAgent:
  await self.update_status(AGENT_POLLING)
  else:
  await self.update_status(AGENT_RUNNING)
+ if time.time() - self._last_status_print_time > print_interval:
  self.print_status()

  if self.num_running_jobs == self._max_jobs or job is None:
@@ -634,14 +677,14 @@ class LaunchAgent:
  await self.check_sweep_state(launch_spec, api)

  job_tracker.update_run_info(project)
- _logger.info("Fetching and validating project...")
+ self._internal_logger.info("Fetching and validating project...")
  project.fetch_and_validate_project()
- _logger.info("Fetching resource...")
+ self._internal_logger.info("Fetching resource...")
  resource = launch_spec.get("resource") or "local-container"
  backend_config: Dict[str, Any] = {
  PROJECT_SYNCHRONOUS: False, # agent always runs async
  }
- _logger.info("Loading backend")
+ self._internal_logger.info("Loading backend")
  override_build_config = launch_spec.get("builder")

  _, build_config, registry_config = construct_agent_configs(
@@ -661,13 +704,13 @@ class LaunchAgent:
  assert entrypoint is not None
  image_uri = await builder.build_image(project, entrypoint, job_tracker)

- _logger.info("Backend loaded...")
+ self._internal_logger.info("Backend loaded...")
  if isinstance(backend, LocalProcessRunner):
  run = await backend.run(project, image_uri)
  else:
  assert image_uri
  run = await backend.run(project, image_uri)
- if _is_scheduler_job(launch_spec):
+ if self._is_scheduler_job(launch_spec):
  with self._jobs_lock:
  self._jobs[thread_id].is_scheduler = True
  wandb.termlog(
@@ -700,7 +743,7 @@ class LaunchAgent:
  if stopped_time is None:
  stopped_time = time.time()
  else:
- if time.time() - stopped_time > MAX_WAIT_RUN_STOPPED:
+ if time.time() - stopped_time > self._stopped_run_timeout:
  await run.cancel()
  await asyncio.sleep(AGENT_POLLING_INTERVAL)

@@ -720,7 +763,7 @@ class LaunchAgent:
  project=launch_spec["project"],
  )
  except Exception as e:
- _logger.debug(f"Fetch sweep state error: {e}")
+ self._internal_logger.debug(f"Fetch sweep state error: {e}")
  state = None

  if state != "RUNNING" and state != "PAUSED":
wandb/sdk/launch/agent/config.py

@@ -225,6 +225,14 @@ class AgentConfig(BaseModel):
  None,
  description="The builder to use.",
  )
+ verbosity: Optional[int] = Field(
+ 0,
+ description="How verbose to print, 0 = default, 1 = verbose, 2 = very verbose",
+ )
+ stopped_run_timeout: Optional[int] = Field(
+ 60,
+ description="How many seconds to wait after receiving the stop command before forcibly cancelling a run.",
+ )

  class Config:
  extra = "forbid"
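
Note: the two new AgentConfig fields above drive the LaunchAgent behavior added earlier in this diff. A minimal sketch of that behavior, using only values visible in this diff (the config dict below is hypothetical, not a documented file format):

# Hypothetical agent config dict; keys match the new AgentConfig fields.
config = {"verbosity": 1, "stopped_run_timeout": 300}

verbosity = config.get("verbosity", 0)  # 0 = default, 1 = verbose, 2 = very verbose
# Verbose agents print status every VERBOSE_PRINT_INTERVAL (20 s) instead of
# DEFAULT_PRINT_INTERVAL (5 min); verbosity >= 2 also echoes internal logs to the terminal.
print_interval = 5 * 60 if verbosity == 0 else 20
# Seconds to wait after a stop command before the run is forcibly cancelled.
stopped_run_timeout = config.get("stopped_run_timeout", 60)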
wandb/sdk/launch/builder/kaniko_builder.py

@@ -1,5 +1,6 @@
  import asyncio
  import base64
+ import copy
  import json
  import logging
  import os
@@ -8,7 +9,7 @@ import tarfile
  import tempfile
  import time
  import traceback
- from typing import Optional
+ from typing import Any, Dict, Optional

  import wandb
  from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker
@@ -105,6 +106,7 @@ class KanikoBuilder(AbstractBuilder):
  secret_name: str = "",
  secret_key: str = "",
  image: str = "gcr.io/kaniko-project/executor:v1.11.0",
+ config: Optional[dict] = None,
  ):
  """Initialize a KanikoBuilder.

@@ -125,6 +127,7 @@ class KanikoBuilder(AbstractBuilder):
  self.secret_name = secret_name
  self.secret_key = secret_key
  self.image = image
+ self.kaniko_config = config or {}

  @classmethod
  def from_config(
@@ -170,6 +173,7 @@
  image_uri = config.get("destination")
  if image_uri is not None:
  registry = registry_from_uri(image_uri)
+ kaniko_config = config.get("kaniko-config", {})

  return cls(
  environment,
@@ -179,6 +183,7 @@
  secret_name=secret_name,
  secret_key=secret_key,
  image=kaniko_image,
+ config=kaniko_config,
  )

  async def verify(self) -> None:
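
Note: the new "kaniko-config" key read in from_config above lets a builder config carry a partial Kubernetes Job spec, which _create_kaniko_job (below) merges with its defaults. A hedged sketch of such a config; the surrounding keys and values are illustrative assumptions, only "kaniko-config" itself comes from this diff:

# Hypothetical builder config dict; only the "kaniko-config" key is new in this release.
builder_config = {
    "type": "kaniko",
    "destination": "registry.example.com/launch-images",  # hypothetical registry URI
    "kaniko-config": {
        "metadata": {"labels": {"team": "ml-infra"}},
        "spec": {
            "template": {
                "spec": {
                    "containers": [
                        # Extra kaniko args are merged over the builder's default args.
                        {"args": ["--cache=false", "--verbosity=debug"]}
                    ]
                }
            }
        },
    },
}

# Mirrors the line added in from_config above.
kaniko_config = builder_config.get("kaniko-config", {})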
@@ -289,7 +294,7 @@ class KanikoBuilder(AbstractBuilder):

  build_context = await self._upload_build_context(run_id, context_path)
  build_job = await self._create_kaniko_job(
- build_job_name, repo_uri, image_uri, build_context, core_v1
+ build_job_name, repo_uri, image_uri, build_context, core_v1, api_client
  )
  wandb.termlog(f"{LOG_PREFIX}Created kaniko job {build_job_name}")

@@ -324,7 +329,9 @@ class KanikoBuilder(AbstractBuilder):
  ):
  if job_tracker:
  job_tracker.set_err_stage("build")
- raise Exception(f"Failed to build image in kaniko for job {run_id}")
+ raise Exception(
+ f"Failed to build image in kaniko for job {run_id}. View logs with `kubectl logs -n {NAMESPACE} {build_job_name}`."
+ )
  try:
  pods_from_job = await core_v1.list_namespaced_pod(
  namespace=NAMESPACE, label_selector=f"job-name={build_job_name}"
@@ -371,23 +378,32 @@ class KanikoBuilder(AbstractBuilder):
  image_tag: str,
  build_context_path: str,
  core_client: client.CoreV1Api,
- ) -> "client.V1Job":
- env = []
- volume_mounts = []
- volumes = []
+ api_client,
+ ) -> Dict[str, Any]:
+ job = copy.deepcopy(self.kaniko_config)
+ job_metadata = job.get("metadata", {})
+ job_labels = job_metadata.get("labels", {})
+ job_spec = job.get("spec", {})
+ pod_template = job_spec.get("template", {})
+ pod_metadata = pod_template.get("metadata", {})
+ pod_labels = pod_metadata.get("labels", {})
+ pod_spec = pod_template.get("spec", {})
+ volumes = pod_spec.get("volumes", [])
+ containers = pod_spec.get("containers") or [{}]
+ if len(containers) > 1:
+ raise LaunchError(
+ "Multiple container configs not supported for kaniko builder."
+ )
+ container = containers[0]
+ volume_mounts = container.get("volumeMounts", [])
+ env = container.get("env", [])
+ custom_args = container.get("args", [])

  if PVC_MOUNT_PATH:
  volumes.append(
- client.V1Volume(
- name="kaniko-pvc",
- persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
- claim_name=PVC_NAME
- ),
- )
- )
- volume_mounts.append(
- client.V1VolumeMount(name="kaniko-pvc", mount_path="/context")
+ {"name": "kaniko-pvc", "persistentVolumeClaim": {"claimName": PVC_NAME}}
  )
+ volume_mounts.append({"name": "kaniko-pvc", "mountPath": "/context"})

  if bool(self.secret_name) != bool(self.secret_key):
  raise LaunchError(
@@ -395,13 +411,13 @@ class KanikoBuilder(AbstractBuilder):
  "for kaniko build. You provided only one of them."
  )
  if isinstance(self.registry, ElasticContainerRegistry):
- env += [
- client.V1EnvVar(
- name="AWS_REGION",
- value=self.registry.region,
- )
- ]
- # TODO: Refactor all of this environment/registry
+ env.append(
+ {
+ "name": "AWS_REGION",
+ "value": self.registry.region,
+ }
+ )
+ # TODO(ben): Refactor all of this environment/registry
  # specific stuff into methods of those classes.
  if isinstance(self.environment, AzureEnvironment):
  # Use the core api to check if the secret exists
@@ -416,52 +432,46 @@ class KanikoBuilder(AbstractBuilder):
  "namespace wandb. Please create it with the key password "
  "set to your azure storage access key."
  ) from e
- env += [
- client.V1EnvVar(
- name="AZURE_STORAGE_ACCESS_KEY",
- value_from=client.V1EnvVarSource(
- secret_key_ref=client.V1SecretKeySelector(
- name="azure-storage-access-key",
- key="password",
- )
- ),
- )
- ]
+ env.append(
+ {
+ "name": "AZURE_STORAGE_ACCESS_KEY",
+ "valueFrom": {
+ "secretKeyRef": {
+ "name": "azure-storage-access-key",
+ "key": "password",
+ }
+ },
+ }
+ )
  if DOCKER_CONFIG_SECRET:
  volumes.append(
- client.V1Volume(
- name="kaniko-docker-config",
- secret=client.V1SecretVolumeSource(
- secret_name=DOCKER_CONFIG_SECRET,
- items=[
- client.V1KeyToPath(
- key=".dockerconfigjson", path="config.json"
- )
+ {
+ "name": "kaniko-docker-config",
+ "secret": {
+ "secretName": DOCKER_CONFIG_SECRET,
+ "items": [
+ {
+ "key": ".dockerconfigjson",
+ "path": "config.json",
+ }
  ],
- ),
- )
+ },
+ }
  )
  volume_mounts.append(
- client.V1VolumeMount(
- name="kaniko-docker-config",
- mount_path="/kaniko/.docker",
- )
+ {"name": "kaniko-docker-config", "mountPath": "/kaniko/.docker"}
  )
  elif self.secret_name and self.secret_key:
- volumes += [
- client.V1Volume(
- name="docker-config",
- config_map=client.V1ConfigMapVolumeSource(
- name=f"docker-config-{job_name}",
- ),
- ),
- ]
- volume_mounts += [
- client.V1VolumeMount(
- name="docker-config", mount_path="/kaniko/.docker/"
- ),
- ]
- # TODO: I don't like conditioning on the registry type here. As a
+ volumes.append(
+ {
+ "name": "docker-config",
+ "configMap": {"name": f"docker-config-{job_name}"},
+ }
+ )
+ volume_mounts.append(
+ {"name": "docker-config", "mountPath": "/kaniko/.docker"}
+ )
+ # TODO(ben): I don't like conditioning on the registry type here. As a
  # future change I want the registry and environment classes to provide
  # a list of environment variables and volume mounts that need to be
  # added to the job. The environment class provides credentials for
@@ -475,90 +485,95 @@ class KanikoBuilder(AbstractBuilder):
  elif isinstance(self.registry, GoogleArtifactRegistry):
  mount_path = "/kaniko/.config/gcloud"
  key = "config.json"
- env += [
- client.V1EnvVar(
- name="GOOGLE_APPLICATION_CREDENTIALS",
- value="/kaniko/.config/gcloud/config.json",
- )
- ]
+ env.append(
+ {
+ "name": "GOOGLE_APPLICATION_CREDENTIALS",
+ "value": "/kaniko/.config/gcloud/config.json",
+ }
+ )
  else:
  raise LaunchError(
  f"Registry type {type(self.registry)} not supported by kaniko"
  )
- volume_mounts += [
- client.V1VolumeMount(
- name=self.secret_name,
- mount_path=mount_path,
- read_only=True,
- )
- ]
- volumes += [
- client.V1Volume(
- name=self.secret_name,
- secret=client.V1SecretVolumeSource(
- secret_name=self.secret_name,
- items=[client.V1KeyToPath(key=self.secret_key, path=key)],
- ),
- )
- ]
+ volumes.append(
+ {
+ "name": self.secret_name,
+ "secret": {
+ "secretName": self.secret_name,
+ "items": [{"key": self.secret_key, "path": key}],
+ },
+ }
+ )
+ volume_mounts.append(
+ {
+ "name": self.secret_name,
+ "mountPath": mount_path,
+ "readOnly": True,
+ }
+ )
  if isinstance(self.registry, AzureContainerRegistry):
- # ADd the docker config map
- volume_mounts += [
- client.V1VolumeMount(
- name="docker-config", mount_path="/kaniko/.docker/"
- ),
- ]
- volumes += [
- client.V1Volume(
- name="docker-config",
- config_map=client.V1ConfigMapVolumeSource(
- name=f"docker-config-{job_name}",
- ),
- ),
- ]
+ # Add the docker config map
+ volumes.append(
+ {
+ "name": "docker-config",
+ "configMap": {"name": f"docker-config-{job_name}"},
+ }
+ )
+ volume_mounts.append(
+ {"name": "docker-config", "mountPath": "/kaniko/.docker/"}
+ )
  # Kaniko doesn't want https:// at the begining of the image tag.
  destination = image_tag
  if destination.startswith("https://"):
  destination = destination.replace("https://", "")
- args = [
- f"--context={build_context_path}",
- f"--dockerfile={_WANDB_DOCKERFILE_NAME}",
- f"--destination={destination}",
- "--cache=true",
- f"--cache-repo={repository.replace('https://', '')}",
- "--snapshotMode=redo",
- "--compressed-caching=false",
+ args = {
+ "--context": build_context_path,
+ "--dockerfile": _WANDB_DOCKERFILE_NAME,
+ "--destination": destination,
+ "--cache": "true",
+ "--cache-repo": repository.replace("https://", ""),
+ "--snapshot-mode": "redo",
+ "--compressed-caching": "false",
+ }
+ for custom_arg in custom_args:
+ arg_name, arg_value = custom_arg.split("=", 1)
+ args[arg_name] = arg_value
+ parsed_args = [
+ f"{arg_name}={arg_value}" for arg_name, arg_value in args.items()
  ]
- container = client.V1Container(
- name="wandb-container-build",
- image=self.image,
- args=args,
- volume_mounts=volume_mounts,
- env=env if env else None,
- )
- # Create and configure a spec section
- labels = {"wandb": "launch"}
+ container["args"] = parsed_args
+
+ # Apply the rest of our defaults
+ pod_labels["wandb"] = "launch"
  # This annotation is required to enable azure workload identity.
  if isinstance(self.registry, AzureContainerRegistry):
- labels["azure.workload.identity/use"] = "true"
- template = client.V1PodTemplateSpec(
- metadata=client.V1ObjectMeta(labels=labels),
- spec=client.V1PodSpec(
- restart_policy="Never",
- active_deadline_seconds=_DEFAULT_BUILD_TIMEOUT_SECS,
- containers=[container],
- volumes=volumes,
- service_account_name=SERVICE_ACCOUNT_NAME,
- ),
+ pod_labels["azure.workload.identity/use"] = "true"
+ pod_spec["restartPolicy"] = pod_spec.get("restartPolicy", "Never")
+ pod_spec["activeDeadlineSeconds"] = pod_spec.get(
+ "activeDeadlineSeconds", _DEFAULT_BUILD_TIMEOUT_SECS
  )
- # Create the specification of job
- spec = client.V1JobSpec(template=template, backoff_limit=0)
- job = client.V1Job(
- api_version="batch/v1",
- kind="Job",
- metadata=client.V1ObjectMeta(
- name=job_name, namespace=NAMESPACE, labels={"wandb": "launch"}
- ),
- spec=spec,
+ pod_spec["serviceAccountName"] = pod_spec.get(
+ "serviceAccountName", SERVICE_ACCOUNT_NAME
  )
+ job_spec["backoffLimit"] = job_spec.get("backoffLimit", 0)
+ job_labels["wandb"] = "launch"
+ job_metadata["namespace"] = job_metadata.get("namespace", NAMESPACE)
+ job_metadata["name"] = job_metadata.get("name", job_name)
+ job["apiVersion"] = "batch/v1"
+ job["kind"] = "Job"
+
+ # Apply all nested configs from the bottom up
+ pod_metadata["labels"] = pod_labels
+ pod_template["metadata"] = pod_metadata
+ container["name"] = container.get("name", "wandb-container-build")
+ container["image"] = container.get("image", self.image)
+ container["volumeMounts"] = volume_mounts
+ container["env"] = env
+ pod_spec["containers"] = [container]
+ pod_spec["volumes"] = volumes
+ pod_template["spec"] = pod_spec
+ job_spec["template"] = pod_template
+ job_metadata["labels"] = job_labels
+ job["metadata"] = job_metadata
+ job["spec"] = job_spec
  return job
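
Note: the rewritten _create_kaniko_job treats the user-supplied kaniko-config as the base document and only fills in a field when the user left it unset. A tiny illustration of that overlay pattern (the timeout value below is a stand-in, not the real _DEFAULT_BUILD_TIMEOUT_SECS):

# Overlay pattern used above: user-supplied values win, defaults fill the gaps.
pod_spec = {"restartPolicy": "OnFailure"}  # hypothetical value from kaniko-config
pod_spec["restartPolicy"] = pod_spec.get("restartPolicy", "Never")  # user value kept
pod_spec["activeDeadlineSeconds"] = pod_spec.get("activeDeadlineSeconds", 3600)  # default applied
assert pod_spec == {"restartPolicy": "OnFailure", "activeDeadlineSeconds": 3600}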