wandb 0.16.4__py3-none-any.whl → 0.16.6__py3-none-any.whl

Files changed (55)
  1. wandb/__init__.py +2 -2
  2. wandb/agents/pyagent.py +1 -1
  3. wandb/apis/public/api.py +6 -6
  4. wandb/apis/reports/v2/interface.py +4 -8
  5. wandb/apis/reports/v2/internal.py +12 -45
  6. wandb/cli/cli.py +29 -5
  7. wandb/integration/openai/fine_tuning.py +74 -37
  8. wandb/integration/ultralytics/callback.py +0 -1
  9. wandb/proto/v3/wandb_internal_pb2.py +332 -312
  10. wandb/proto/v3/wandb_settings_pb2.py +13 -3
  11. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  12. wandb/proto/v4/wandb_internal_pb2.py +316 -312
  13. wandb/proto/v4/wandb_settings_pb2.py +5 -3
  14. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  15. wandb/sdk/artifacts/artifact.py +92 -26
  16. wandb/sdk/artifacts/artifact_manifest_entry.py +6 -1
  17. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
  18. wandb/sdk/artifacts/artifact_saver.py +16 -36
  19. wandb/sdk/artifacts/storage_handler.py +2 -1
  20. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +13 -5
  21. wandb/sdk/interface/interface.py +60 -15
  22. wandb/sdk/interface/interface_shared.py +13 -7
  23. wandb/sdk/internal/file_stream.py +19 -0
  24. wandb/sdk/internal/handler.py +1 -4
  25. wandb/sdk/internal/internal_api.py +2 -0
  26. wandb/sdk/internal/job_builder.py +45 -17
  27. wandb/sdk/internal/sender.py +53 -28
  28. wandb/sdk/internal/settings_static.py +9 -0
  29. wandb/sdk/internal/system/system_info.py +4 -1
  30. wandb/sdk/launch/_launch.py +5 -0
  31. wandb/sdk/launch/_project_spec.py +5 -20
  32. wandb/sdk/launch/agent/agent.py +80 -37
  33. wandb/sdk/launch/agent/config.py +8 -0
  34. wandb/sdk/launch/builder/kaniko_builder.py +149 -134
  35. wandb/sdk/launch/create_job.py +44 -48
  36. wandb/sdk/launch/runner/kubernetes_monitor.py +3 -1
  37. wandb/sdk/launch/runner/kubernetes_runner.py +20 -2
  38. wandb/sdk/launch/sweeps/scheduler.py +3 -1
  39. wandb/sdk/launch/utils.py +23 -5
  40. wandb/sdk/lib/__init__.py +2 -5
  41. wandb/sdk/lib/_settings_toposort_generated.py +2 -0
  42. wandb/sdk/lib/filesystem.py +11 -1
  43. wandb/sdk/lib/run_moment.py +78 -0
  44. wandb/sdk/service/streams.py +1 -6
  45. wandb/sdk/wandb_init.py +12 -7
  46. wandb/sdk/wandb_login.py +43 -26
  47. wandb/sdk/wandb_run.py +179 -94
  48. wandb/sdk/wandb_settings.py +55 -16
  49. wandb/testing/relay.py +5 -6
  50. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/METADATA +1 -1
  51. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/RECORD +55 -54
  52. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/WHEEL +1 -1
  53. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/LICENSE +0 -0
  54. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/entry_points.txt +0 -0
  55. {wandb-0.16.4.dist-info → wandb-0.16.6.dist-info}/top_level.txt +0 -0
wandb/sdk/launch/agent/agent.py

@@ -45,7 +45,10 @@ MAX_RESUME_COUNT = 5

  RUN_INFO_GRACE_PERIOD = 60

- MAX_WAIT_RUN_STOPPED = 60
+ DEFAULT_STOPPED_RUN_TIMEOUT = 60
+
+ DEFAULT_PRINT_INTERVAL = 5 * 60
+ VERBOSE_PRINT_INTERVAL = 20

  _env_timeout = os.environ.get("WANDB_LAUNCH_START_TIMEOUT")
  if _env_timeout:
@@ -105,30 +108,29 @@ def _max_from_config(
  return max_from_config


- def _is_scheduler_job(run_spec: Dict[str, Any]) -> bool:
- """Determine whether a job/runSpec is a sweep scheduler."""
- if not run_spec:
- _logger.debug("Recieved runSpec in _is_scheduler_job that was empty")
+ class InternalAgentLogger:
+ def __init__(self, verbosity=0):
+ self._print_to_terminal = verbosity >= 2

- if run_spec.get("uri") != Scheduler.PLACEHOLDER_URI:
- return False
+ def error(self, message: str):
+ if self._print_to_terminal:
+ wandb.termerror(f"{LOG_PREFIX}{message}")
+ _logger.error(f"{LOG_PREFIX}{message}")

- if run_spec.get("resource") == "local-process":
- # Any job pushed to a run queue that has a scheduler uri is
- # allowed to use local-process
- if run_spec.get("job"):
- return True
+ def warn(self, message: str):
+ if self._print_to_terminal:
+ wandb.termwarn(f"{LOG_PREFIX}{message}")
+ _logger.warn(f"{LOG_PREFIX}{message}")

- # If a scheduler is local-process and run through CLI, also
- # confirm command is in format: [wandb scheduler <sweep>]
- cmd = run_spec.get("overrides", {}).get("entry_point", [])
- if len(cmd) < 3:
- return False
+ def info(self, message: str):
+ if self._print_to_terminal:
+ wandb.termlog(f"{LOG_PREFIX}{message}")
+ _logger.info(f"{LOG_PREFIX}{message}")

- if cmd[:2] != ["wandb", "scheduler"]:
- return False
-
- return True
+ def debug(self, message: str):
+ if self._print_to_terminal:
+ wandb.termlog(f"{LOG_PREFIX}{message}")
+ _logger.debug(f"{LOG_PREFIX}{message}")


  class LaunchAgent:
@@ -184,7 +186,13 @@ class LaunchAgent:
  self._max_jobs = _max_from_config(config, "max_jobs")
  self._max_schedulers = _max_from_config(config, "max_schedulers")
  self._secure_mode = config.get("secure_mode", False)
+ self._verbosity = config.get("verbosity", 0)
+ self._internal_logger = InternalAgentLogger(verbosity=self._verbosity)
+ self._last_status_print_time = 0.0
  self.default_config: Dict[str, Any] = config
+ self._stopped_run_timeout = config.get(
+ "stopped_run_timeout", DEFAULT_STOPPED_RUN_TIMEOUT
+ )

  # Get agent version from env var if present, otherwise wandb version
  self.version: str = "wandb@" + wandb.__version__
@@ -228,6 +236,33 @@ class LaunchAgent:
  self._name = agent_response["name"]
  self._init_agent_run()

+ def _is_scheduler_job(self, run_spec: Dict[str, Any]) -> bool:
+ """Determine whether a job/runSpec is a sweep scheduler."""
+ if not run_spec:
+ self._internal_logger.debug(
+ "Recieved runSpec in _is_scheduler_job that was empty"
+ )
+
+ if run_spec.get("uri") != Scheduler.PLACEHOLDER_URI:
+ return False
+
+ if run_spec.get("resource") == "local-process":
+ # Any job pushed to a run queue that has a scheduler uri is
+ # allowed to use local-process
+ if run_spec.get("job"):
+ return True
+
+ # If a scheduler is local-process and run through CLI, also
+ # confirm command is in format: [wandb scheduler <sweep>]
+ cmd = run_spec.get("overrides", {}).get("entry_point", [])
+ if len(cmd) < 3:
+ return False
+
+ if cmd[:2] != ["wandb", "scheduler"]:
+ return False
+
+ return True
+
  async def fail_run_queue_item(
  self,
  run_queue_item_id: str,
@@ -298,6 +333,7 @@

  def print_status(self) -> None:
  """Prints the current status of the agent."""
+ self._last_status_print_time = time.time()
  output_str = "agent "
  if self._name:
  output_str += f"{self._name} "
@@ -344,8 +380,8 @@
  if run_state.lower() != "pending":
  return True
  except CommError:
- _logger.info(
- f"Run {entity}/{project}/{run_id} with rqi id: {rqi_id} did not have associated run"
+ self._internal_logger.info(
+ f"Run {entity}/{project}/{run_id} with rqi id: {rqi_id} did not have associated run",
  )
  return False

@@ -361,8 +397,8 @@
  job_and_run_status.entity is not None
  and job_and_run_status.entity != self._entity
  ):
- _logger.info(
- "Skipping check for completed run status because run is on a different entity than agent"
+ self._internal_logger.info(
+ "Skipping check for completed run status because run is on a different entity than agent",
  )
  elif exception is not None:
  tb_str = traceback.format_exception(
@@ -378,8 +414,8 @@
  fnames,
  )
  elif job_and_run_status.project is None or job_and_run_status.run_id is None:
- _logger.error(
- f"called finish_thread_id on thread whose tracker has no project or run id. RunQueueItemID: {job_and_run_status.run_queue_item_id}"
+ self._internal_logger.info(
+ f"called finish_thread_id on thread whose tracker has no project or run id. RunQueueItemID: {job_and_run_status.run_queue_item_id}",
  )
  wandb.termerror(
  "Missing project or run id on thread called finish thread id"
@@ -430,7 +466,9 @@
  job_and_run_status.run_queue_item_id, _msg, "run", fnames
  )
  else:
- _logger.info(f"Finish thread id {thread_id} had no exception and no run")
+ self._internal_logger.info(
+ f"Finish thread id {thread_id} had no exception and no run"
+ )
  wandb._sentry.exception(
  "launch agent called finish thread id on thread without run or exception"
  )
@@ -458,7 +496,7 @@
  await self.update_status(AGENT_RUNNING)

  # parse job
- _logger.info("Parsing launch spec")
+ self._internal_logger.info("Parsing launch spec")
  launch_spec = job["runSpec"]

  # Abort if this job attempts to override secure mode
@@ -511,6 +549,10 @@
  KeyboardInterrupt: if the agent is requested to stop.
  """
  self.print_status()
+ if self._verbosity == 0:
+ print_interval = DEFAULT_PRINT_INTERVAL
+ else:
+ print_interval = VERBOSE_PRINT_INTERVAL
  try:
  while True:
  job = None
@@ -532,7 +574,7 @@
  file_saver = RunQueueItemFileSaver(
  self._wandb_run, job["runQueueItemId"]
  )
- if _is_scheduler_job(job.get("runSpec", {})):
+ if self._is_scheduler_job(job.get("runSpec", {})):
  # If job is a scheduler, and we are already at the cap, ignore,
  # don't ack, and it will be pushed back onto the queue in 1 min
  if self.num_running_schedulers >= self._max_schedulers:
@@ -567,6 +609,7 @@
  await self.update_status(AGENT_POLLING)
  else:
  await self.update_status(AGENT_RUNNING)
+ if time.time() - self._last_status_print_time > print_interval:
  self.print_status()

  if self.num_running_jobs == self._max_jobs or job is None:
@@ -634,14 +677,14 @@
  await self.check_sweep_state(launch_spec, api)

  job_tracker.update_run_info(project)
- _logger.info("Fetching and validating project...")
+ self._internal_logger.info("Fetching and validating project...")
  project.fetch_and_validate_project()
- _logger.info("Fetching resource...")
+ self._internal_logger.info("Fetching resource...")
  resource = launch_spec.get("resource") or "local-container"
  backend_config: Dict[str, Any] = {
  PROJECT_SYNCHRONOUS: False, # agent always runs async
  }
- _logger.info("Loading backend")
+ self._internal_logger.info("Loading backend")
  override_build_config = launch_spec.get("builder")

  _, build_config, registry_config = construct_agent_configs(
@@ -661,13 +704,13 @@
  assert entrypoint is not None
  image_uri = await builder.build_image(project, entrypoint, job_tracker)

- _logger.info("Backend loaded...")
+ self._internal_logger.info("Backend loaded...")
  if isinstance(backend, LocalProcessRunner):
  run = await backend.run(project, image_uri)
  else:
  assert image_uri
  run = await backend.run(project, image_uri)
- if _is_scheduler_job(launch_spec):
+ if self._is_scheduler_job(launch_spec):
  with self._jobs_lock:
  self._jobs[thread_id].is_scheduler = True
  wandb.termlog(
@@ -700,7 +743,7 @@
  if stopped_time is None:
  stopped_time = time.time()
  else:
- if time.time() - stopped_time > MAX_WAIT_RUN_STOPPED:
+ if time.time() - stopped_time > self._stopped_run_timeout:
  await run.cancel()
  await asyncio.sleep(AGENT_POLLING_INTERVAL)

@@ -720,7 +763,7 @@
  project=launch_spec["project"],
  )
  except Exception as e:
- _logger.debug(f"Fetch sweep state error: {e}")
+ self._internal_logger.debug(f"Fetch sweep state error: {e}")
  state = None

  if state != "RUNNING" and state != "PAUSED":
wandb/sdk/launch/agent/config.py

@@ -225,6 +225,14 @@ class AgentConfig(BaseModel):
  None,
  description="The builder to use.",
  )
+ verbosity: Optional[int] = Field(
+ 0,
+ description="How verbose to print, 0 = default, 1 = verbose, 2 = very verbose",
+ )
+ stopped_run_timeout: Optional[int] = Field(
+ 60,
+ description="How many seconds to wait after receiving the stop command before forcibly cancelling a run.",
+ )

  class Config:
  extra = "forbid"
wandb/sdk/launch/builder/kaniko_builder.py

@@ -1,5 +1,6 @@
  import asyncio
  import base64
+ import copy
  import json
  import logging
  import os
@@ -8,7 +9,7 @@ import tarfile
  import tempfile
  import time
  import traceback
- from typing import Optional
+ from typing import Any, Dict, Optional

  import wandb
  from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker
@@ -105,6 +106,7 @@ class KanikoBuilder(AbstractBuilder):
  secret_name: str = "",
  secret_key: str = "",
  image: str = "gcr.io/kaniko-project/executor:v1.11.0",
+ config: Optional[dict] = None,
  ):
  """Initialize a KanikoBuilder.

@@ -125,6 +127,7 @@ class KanikoBuilder(AbstractBuilder):
  self.secret_name = secret_name
  self.secret_key = secret_key
  self.image = image
+ self.kaniko_config = config or {}

  @classmethod
  def from_config(
@@ -170,6 +173,7 @@
  image_uri = config.get("destination")
  if image_uri is not None:
  registry = registry_from_uri(image_uri)
+ kaniko_config = config.get("kaniko-config", {})

  return cls(
  environment,
@@ -179,6 +183,7 @@
  secret_name=secret_name,
  secret_key=secret_key,
  image=kaniko_image,
+ config=kaniko_config,
  )

  async def verify(self) -> None:
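
`from_config` now reads an optional `kaniko-config` key from the builder config and stores it as the base Job spec that `_create_kaniko_job` deep-copies and fills in (see the later hunks). A minimal sketch of what that key might hold, with illustrative values; the surrounding `type` and `destination` keys are shown only for context, and the nested structure follows the fields `_create_kaniko_job` merges (metadata labels, pod template, container args):

```python
# Illustrative sketch of a builder config carrying a partial Kubernetes Job
# spec under "kaniko-config". Anything not set here falls back to the
# builder's defaults when the job is assembled.
builder_config = {
    "type": "kaniko",                                     # assumed surrounding key, for context
    "destination": "registry.example.com/launch-images",  # read via config.get("destination") above
    "kaniko-config": {
        "metadata": {"labels": {"team": "ml-infra"}},
        "spec": {
            "template": {
                "spec": {
                    "containers": [
                        {"args": ["--cache-ttl=24h"]}  # extra kaniko flag, merged with the defaults
                    ]
                }
            }
        },
    },
}

kaniko_config = builder_config.get("kaniko-config", {})  # same lookup as from_config
```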
@@ -289,7 +294,7 @@

  build_context = await self._upload_build_context(run_id, context_path)
  build_job = await self._create_kaniko_job(
- build_job_name, repo_uri, image_uri, build_context, core_v1
+ build_job_name, repo_uri, image_uri, build_context, core_v1, api_client
  )
  wandb.termlog(f"{LOG_PREFIX}Created kaniko job {build_job_name}")

@@ -324,7 +329,9 @@
  ):
  if job_tracker:
  job_tracker.set_err_stage("build")
- raise Exception(f"Failed to build image in kaniko for job {run_id}")
+ raise Exception(
+ f"Failed to build image in kaniko for job {run_id}. View logs with `kubectl logs -n {NAMESPACE} {build_job_name}`."
+ )
  try:
  pods_from_job = await core_v1.list_namespaced_pod(
  namespace=NAMESPACE, label_selector=f"job-name={build_job_name}"
@@ -371,23 +378,32 @@
  image_tag: str,
  build_context_path: str,
  core_client: client.CoreV1Api,
- ) -> "client.V1Job":
- env = []
- volume_mounts = []
- volumes = []
+ api_client,
+ ) -> Dict[str, Any]:
+ job = copy.deepcopy(self.kaniko_config)
+ job_metadata = job.get("metadata", {})
+ job_labels = job_metadata.get("labels", {})
+ job_spec = job.get("spec", {})
+ pod_template = job_spec.get("template", {})
+ pod_metadata = pod_template.get("metadata", {})
+ pod_labels = pod_metadata.get("labels", {})
+ pod_spec = pod_template.get("spec", {})
+ volumes = pod_spec.get("volumes", [])
+ containers = pod_spec.get("containers") or [{}]
+ if len(containers) > 1:
+ raise LaunchError(
+ "Multiple container configs not supported for kaniko builder."
+ )
+ container = containers[0]
+ volume_mounts = container.get("volumeMounts", [])
+ env = container.get("env", [])
+ custom_args = container.get("args", [])

  if PVC_MOUNT_PATH:
  volumes.append(
- client.V1Volume(
- name="kaniko-pvc",
- persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource(
- claim_name=PVC_NAME
- ),
- )
- )
- volume_mounts.append(
- client.V1VolumeMount(name="kaniko-pvc", mount_path="/context")
+ {"name": "kaniko-pvc", "persistentVolumeClaim": {"claimName": PVC_NAME}}
  )
+ volume_mounts.append({"name": "kaniko-pvc", "mountPath": "/context"})

  if bool(self.secret_name) != bool(self.secret_key):
  raise LaunchError(
@@ -395,13 +411,13 @@
  "for kaniko build. You provided only one of them."
  )
  if isinstance(self.registry, ElasticContainerRegistry):
- env += [
- client.V1EnvVar(
- name="AWS_REGION",
- value=self.registry.region,
- )
- ]
- # TODO: Refactor all of this environment/registry
+ env.append(
+ {
+ "name": "AWS_REGION",
+ "value": self.registry.region,
+ }
+ )
+ # TODO(ben): Refactor all of this environment/registry
  # specific stuff into methods of those classes.
  if isinstance(self.environment, AzureEnvironment):
  # Use the core api to check if the secret exists
@@ -416,52 +432,46 @@
  "namespace wandb. Please create it with the key password "
  "set to your azure storage access key."
  ) from e
- env += [
- client.V1EnvVar(
- name="AZURE_STORAGE_ACCESS_KEY",
- value_from=client.V1EnvVarSource(
- secret_key_ref=client.V1SecretKeySelector(
- name="azure-storage-access-key",
- key="password",
- )
- ),
- )
- ]
+ env.append(
+ {
+ "name": "AZURE_STORAGE_ACCESS_KEY",
+ "valueFrom": {
+ "secretKeyRef": {
+ "name": "azure-storage-access-key",
+ "key": "password",
+ }
+ },
+ }
+ )
  if DOCKER_CONFIG_SECRET:
  volumes.append(
- client.V1Volume(
- name="kaniko-docker-config",
- secret=client.V1SecretVolumeSource(
- secret_name=DOCKER_CONFIG_SECRET,
- items=[
- client.V1KeyToPath(
- key=".dockerconfigjson", path="config.json"
- )
+ {
+ "name": "kaniko-docker-config",
+ "secret": {
+ "secretName": DOCKER_CONFIG_SECRET,
+ "items": [
+ {
+ "key": ".dockerconfigjson",
+ "path": "config.json",
+ }
  ],
- ),
- )
+ },
+ }
  )
  volume_mounts.append(
- client.V1VolumeMount(
- name="kaniko-docker-config",
- mount_path="/kaniko/.docker",
- )
+ {"name": "kaniko-docker-config", "mountPath": "/kaniko/.docker"}
  )
  elif self.secret_name and self.secret_key:
- volumes += [
- client.V1Volume(
- name="docker-config",
- config_map=client.V1ConfigMapVolumeSource(
- name=f"docker-config-{job_name}",
- ),
- ),
- ]
- volume_mounts += [
- client.V1VolumeMount(
- name="docker-config", mount_path="/kaniko/.docker/"
- ),
- ]
- # TODO: I don't like conditioning on the registry type here. As a
+ volumes.append(
+ {
+ "name": "docker-config",
+ "configMap": {"name": f"docker-config-{job_name}"},
+ }
+ )
+ volume_mounts.append(
+ {"name": "docker-config", "mountPath": "/kaniko/.docker"}
+ )
+ # TODO(ben): I don't like conditioning on the registry type here. As a
  # future change I want the registry and environment classes to provide
  # a list of environment variables and volume mounts that need to be
  # added to the job. The environment class provides credentials for
@@ -475,90 +485,95 @@
  elif isinstance(self.registry, GoogleArtifactRegistry):
  mount_path = "/kaniko/.config/gcloud"
  key = "config.json"
- env += [
- client.V1EnvVar(
- name="GOOGLE_APPLICATION_CREDENTIALS",
- value="/kaniko/.config/gcloud/config.json",
- )
- ]
+ env.append(
+ {
+ "name": "GOOGLE_APPLICATION_CREDENTIALS",
+ "value": "/kaniko/.config/gcloud/config.json",
+ }
+ )
  else:
  raise LaunchError(
  f"Registry type {type(self.registry)} not supported by kaniko"
  )
- volume_mounts += [
- client.V1VolumeMount(
- name=self.secret_name,
- mount_path=mount_path,
- read_only=True,
- )
- ]
- volumes += [
- client.V1Volume(
- name=self.secret_name,
- secret=client.V1SecretVolumeSource(
- secret_name=self.secret_name,
- items=[client.V1KeyToPath(key=self.secret_key, path=key)],
- ),
- )
- ]
+ volumes.append(
+ {
+ "name": self.secret_name,
+ "secret": {
+ "secretName": self.secret_name,
+ "items": [{"key": self.secret_key, "path": key}],
+ },
+ }
+ )
+ volume_mounts.append(
+ {
+ "name": self.secret_name,
+ "mountPath": mount_path,
+ "readOnly": True,
+ }
+ )
  if isinstance(self.registry, AzureContainerRegistry):
- # ADd the docker config map
- volume_mounts += [
- client.V1VolumeMount(
- name="docker-config", mount_path="/kaniko/.docker/"
- ),
- ]
- volumes += [
- client.V1Volume(
- name="docker-config",
- config_map=client.V1ConfigMapVolumeSource(
- name=f"docker-config-{job_name}",
- ),
- ),
- ]
+ # Add the docker config map
+ volumes.append(
+ {
+ "name": "docker-config",
+ "configMap": {"name": f"docker-config-{job_name}"},
+ }
+ )
+ volume_mounts.append(
+ {"name": "docker-config", "mountPath": "/kaniko/.docker/"}
+ )
  # Kaniko doesn't want https:// at the begining of the image tag.
  destination = image_tag
  if destination.startswith("https://"):
  destination = destination.replace("https://", "")
- args = [
- f"--context={build_context_path}",
- f"--dockerfile={_WANDB_DOCKERFILE_NAME}",
- f"--destination={destination}",
- "--cache=true",
- f"--cache-repo={repository.replace('https://', '')}",
- "--snapshotMode=redo",
- "--compressed-caching=false",
+ args = {
+ "--context": build_context_path,
+ "--dockerfile": _WANDB_DOCKERFILE_NAME,
+ "--destination": destination,
+ "--cache": "true",
+ "--cache-repo": repository.replace("https://", ""),
+ "--snapshot-mode": "redo",
+ "--compressed-caching": "false",
+ }
+ for custom_arg in custom_args:
+ arg_name, arg_value = custom_arg.split("=", 1)
+ args[arg_name] = arg_value
+ parsed_args = [
+ f"{arg_name}={arg_value}" for arg_name, arg_value in args.items()
  ]
- container = client.V1Container(
- name="wandb-container-build",
- image=self.image,
- args=args,
- volume_mounts=volume_mounts,
- env=env if env else None,
- )
- # Create and configure a spec section
- labels = {"wandb": "launch"}
+ container["args"] = parsed_args
+
+ # Apply the rest of our defaults
+ pod_labels["wandb"] = "launch"
  # This annotation is required to enable azure workload identity.
  if isinstance(self.registry, AzureContainerRegistry):
- labels["azure.workload.identity/use"] = "true"
- template = client.V1PodTemplateSpec(
- metadata=client.V1ObjectMeta(labels=labels),
- spec=client.V1PodSpec(
- restart_policy="Never",
- active_deadline_seconds=_DEFAULT_BUILD_TIMEOUT_SECS,
- containers=[container],
- volumes=volumes,
- service_account_name=SERVICE_ACCOUNT_NAME,
- ),
+ pod_labels["azure.workload.identity/use"] = "true"
+ pod_spec["restartPolicy"] = pod_spec.get("restartPolicy", "Never")
+ pod_spec["activeDeadlineSeconds"] = pod_spec.get(
+ "activeDeadlineSeconds", _DEFAULT_BUILD_TIMEOUT_SECS
  )
- # Create the specification of job
- spec = client.V1JobSpec(template=template, backoff_limit=0)
- job = client.V1Job(
- api_version="batch/v1",
- kind="Job",
- metadata=client.V1ObjectMeta(
- name=job_name, namespace=NAMESPACE, labels={"wandb": "launch"}
- ),
- spec=spec,
+ pod_spec["serviceAccountName"] = pod_spec.get(
+ "serviceAccountName", SERVICE_ACCOUNT_NAME
  )
+ job_spec["backoffLimit"] = job_spec.get("backoffLimit", 0)
+ job_labels["wandb"] = "launch"
+ job_metadata["namespace"] = job_metadata.get("namespace", NAMESPACE)
+ job_metadata["name"] = job_metadata.get("name", job_name)
+ job["apiVersion"] = "batch/v1"
+ job["kind"] = "Job"
+
+ # Apply all nested configs from the bottom up
+ pod_metadata["labels"] = pod_labels
+ pod_template["metadata"] = pod_metadata
+ container["name"] = container.get("name", "wandb-container-build")
+ container["image"] = container.get("image", self.image)
+ container["volumeMounts"] = volume_mounts
+ container["env"] = env
+ pod_spec["containers"] = [container]
+ pod_spec["volumes"] = volumes
+ pod_template["spec"] = pod_spec
+ job_spec["template"] = pod_template
+ job_metadata["labels"] = job_labels
+ job["metadata"] = job_metadata
+ job["spec"] = job_spec
  return job
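
The builder's default kaniko flags are now collected in a dict, and any `args` supplied on the container in `kaniko-config` override them by flag name before being flattened back into `--name=value` strings. A standalone sketch of that merge, with illustrative paths and flags rather than values from a real config:

```python
# Defaults roughly as assembled in _create_kaniko_job above; the concrete
# context path and dockerfile name are placeholders here.
default_args = {
    "--context": "s3://my-bucket/context.tgz",
    "--dockerfile": "Dockerfile.wandb",
    "--cache": "true",
    "--snapshot-mode": "redo",
}
# Hypothetical user-supplied container args from kaniko-config:
custom_args = ["--cache=false", "--cache-ttl=24h"]

args = dict(default_args)
for custom_arg in custom_args:
    name, value = custom_arg.split("=", 1)  # same split("=", 1) as the diff
    args[name] = value                      # user value wins over the default

parsed_args = [f"{name}={value}" for name, value in args.items()]
print(parsed_args)
# ['--context=s3://my-bucket/context.tgz', '--dockerfile=Dockerfile.wandb',
#  '--cache=false', '--snapshot-mode=redo', '--cache-ttl=24h']
```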