wandb 0.16.5__py3-none-any.whl → 0.17.0rc1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (141) hide show
  1. package_readme.md +95 -0
  2. wandb/__init__.py +2 -2
  3. wandb/agents/pyagent.py +0 -1
  4. wandb/analytics/sentry.py +2 -1
  5. wandb/apis/importers/internals/protocols.py +30 -56
  6. wandb/apis/importers/mlflow.py +13 -26
  7. wandb/apis/importers/wandb.py +8 -14
  8. wandb/apis/public/api.py +1 -0
  9. wandb/apis/public/artifacts.py +1 -0
  10. wandb/apis/public/files.py +1 -0
  11. wandb/apis/public/history.py +1 -0
  12. wandb/apis/public/jobs.py +1 -0
  13. wandb/apis/public/projects.py +1 -0
  14. wandb/apis/public/reports.py +1 -0
  15. wandb/apis/public/runs.py +1 -0
  16. wandb/apis/public/sweeps.py +1 -0
  17. wandb/apis/public/teams.py +1 -0
  18. wandb/apis/public/users.py +1 -0
  19. wandb/apis/reports/v1/_blocks.py +2 -6
  20. wandb/apis/reports/v2/gql.py +1 -0
  21. wandb/apis/reports/v2/interface.py +3 -4
  22. wandb/apis/reports/v2/internal.py +5 -8
  23. wandb/cli/cli.py +7 -4
  24. wandb/data_types.py +3 -3
  25. wandb/env.py +35 -5
  26. wandb/errors/__init__.py +5 -0
  27. wandb/integration/catboost/catboost.py +1 -1
  28. wandb/integration/fastai/__init__.py +1 -0
  29. wandb/integration/keras/__init__.py +1 -0
  30. wandb/integration/keras/keras.py +6 -6
  31. wandb/integration/langchain/wandb_tracer.py +1 -0
  32. wandb/integration/lightning/fabric/logger.py +1 -3
  33. wandb/integration/metaflow/metaflow.py +41 -6
  34. wandb/integration/openai/fine_tuning.py +77 -40
  35. wandb/keras/__init__.py +1 -0
  36. wandb/proto/v3/wandb_internal_pb2.py +364 -332
  37. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  38. wandb/proto/v4/wandb_internal_pb2.py +322 -316
  39. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  40. wandb/proto/wandb_internal_codegen.py +0 -25
  41. wandb/sdk/artifacts/artifact.py +41 -13
  42. wandb/sdk/artifacts/artifact_download_logger.py +1 -0
  43. wandb/sdk/artifacts/artifact_file_cache.py +18 -4
  44. wandb/sdk/artifacts/artifact_instance_cache.py +1 -0
  45. wandb/sdk/artifacts/artifact_manifest.py +1 -0
  46. wandb/sdk/artifacts/artifact_manifest_entry.py +1 -0
  47. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
  48. wandb/sdk/artifacts/artifact_saver.py +21 -21
  49. wandb/sdk/artifacts/artifact_state.py +1 -0
  50. wandb/sdk/artifacts/artifact_ttl.py +1 -0
  51. wandb/sdk/artifacts/exceptions.py +1 -0
  52. wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
  53. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +13 -18
  54. wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -0
  55. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -0
  56. wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -0
  57. wandb/sdk/artifacts/storage_handlers/s3_handler.py +5 -3
  58. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +1 -0
  59. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -0
  60. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -0
  61. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +1 -0
  62. wandb/sdk/artifacts/storage_policy.py +1 -0
  63. wandb/sdk/data_types/base_types/media.py +3 -6
  64. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +3 -1
  65. wandb/sdk/integration_utils/auto_logging.py +5 -6
  66. wandb/sdk/integration_utils/data_logging.py +5 -1
  67. wandb/sdk/interface/interface.py +72 -37
  68. wandb/sdk/interface/interface_shared.py +7 -13
  69. wandb/sdk/internal/datastore.py +1 -1
  70. wandb/sdk/internal/handler.py +18 -2
  71. wandb/sdk/internal/internal.py +0 -1
  72. wandb/sdk/internal/internal_util.py +0 -1
  73. wandb/sdk/internal/job_builder.py +4 -3
  74. wandb/sdk/internal/profiler.py +1 -0
  75. wandb/sdk/internal/run.py +1 -0
  76. wandb/sdk/internal/sender.py +1 -1
  77. wandb/sdk/internal/system/assets/gpu_amd.py +44 -44
  78. wandb/sdk/internal/system/assets/gpu_apple.py +56 -11
  79. wandb/sdk/internal/system/assets/interfaces.py +6 -8
  80. wandb/sdk/internal/system/assets/open_metrics.py +2 -2
  81. wandb/sdk/internal/system/assets/trainium.py +1 -3
  82. wandb/sdk/launch/_launch.py +5 -0
  83. wandb/sdk/launch/_project_spec.py +10 -23
  84. wandb/sdk/launch/agent/agent.py +81 -37
  85. wandb/sdk/launch/agent/config.py +80 -11
  86. wandb/sdk/launch/builder/abstract.py +1 -0
  87. wandb/sdk/launch/builder/build.py +28 -1
  88. wandb/sdk/launch/builder/docker_builder.py +1 -0
  89. wandb/sdk/launch/builder/kaniko_builder.py +149 -134
  90. wandb/sdk/launch/builder/noop.py +1 -0
  91. wandb/sdk/launch/create_job.py +61 -48
  92. wandb/sdk/launch/environment/abstract.py +1 -0
  93. wandb/sdk/launch/environment/gcp_environment.py +1 -0
  94. wandb/sdk/launch/environment/local_environment.py +1 -0
  95. wandb/sdk/launch/loader.py +1 -0
  96. wandb/sdk/launch/registry/abstract.py +1 -0
  97. wandb/sdk/launch/registry/azure_container_registry.py +1 -0
  98. wandb/sdk/launch/registry/elastic_container_registry.py +1 -0
  99. wandb/sdk/launch/registry/google_artifact_registry.py +1 -0
  100. wandb/sdk/launch/registry/local_registry.py +1 -0
  101. wandb/sdk/launch/runner/abstract.py +1 -0
  102. wandb/sdk/launch/runner/kubernetes_monitor.py +4 -1
  103. wandb/sdk/launch/runner/kubernetes_runner.py +4 -3
  104. wandb/sdk/launch/runner/sagemaker_runner.py +11 -10
  105. wandb/sdk/launch/sweeps/scheduler.py +4 -1
  106. wandb/sdk/launch/sweeps/scheduler_sweep.py +1 -0
  107. wandb/sdk/launch/sweeps/utils.py +1 -1
  108. wandb/sdk/launch/utils.py +21 -3
  109. wandb/sdk/lib/_settings_toposort_generated.py +1 -0
  110. wandb/sdk/lib/fsm.py +8 -12
  111. wandb/sdk/lib/gitlib.py +4 -4
  112. wandb/sdk/lib/lazyloader.py +0 -1
  113. wandb/sdk/lib/proto_util.py +1 -1
  114. wandb/sdk/lib/retry.py +3 -2
  115. wandb/sdk/lib/run_moment.py +7 -1
  116. wandb/sdk/service/service.py +17 -15
  117. wandb/sdk/verify/verify.py +2 -1
  118. wandb/sdk/wandb_init.py +2 -8
  119. wandb/sdk/wandb_manager.py +2 -2
  120. wandb/sdk/wandb_require.py +5 -0
  121. wandb/sdk/wandb_run.py +64 -46
  122. wandb/sdk/wandb_settings.py +2 -1
  123. wandb/sklearn/__init__.py +1 -0
  124. wandb/sklearn/plot/__init__.py +1 -0
  125. wandb/sklearn/plot/classifier.py +1 -0
  126. wandb/sklearn/plot/clusterer.py +1 -0
  127. wandb/sklearn/plot/regressor.py +1 -0
  128. wandb/sklearn/plot/shared.py +1 -0
  129. wandb/sklearn/utils.py +1 -0
  130. wandb/testing/relay.py +4 -4
  131. wandb/trigger.py +1 -0
  132. wandb/util.py +40 -17
  133. wandb/wandb_controller.py +0 -1
  134. wandb/wandb_torch.py +1 -2
  135. {wandb-0.16.5.dist-info → wandb-0.17.0rc1.dist-info}/METADATA +68 -69
  136. {wandb-0.16.5.dist-info → wandb-0.17.0rc1.dist-info}/RECORD +139 -140
  137. {wandb-0.16.5.dist-info → wandb-0.17.0rc1.dist-info}/WHEEL +1 -2
  138. wandb/bin/apple_gpu_stats +0 -0
  139. wandb-0.16.5.dist-info/top_level.txt +0 -1
  140. {wandb-0.16.5.dist-info → wandb-0.17.0rc1.dist-info}/entry_points.txt +0 -0
  141. {wandb-0.16.5.dist-info → wandb-0.17.0rc1.dist-info/licenses}/LICENSE +0 -0
@@ -2,6 +2,7 @@
2
2
 
3
3
  Arguments can come from a launch spec or call to wandb launch.
4
4
  """
5
+
5
6
  import enum
6
7
  import logging
7
8
  import os
@@ -14,6 +15,7 @@ import wandb.docker as docker
14
15
  from wandb.apis.internal import Api
15
16
  from wandb.errors import CommError
16
17
  from wandb.sdk.launch import utils
18
+ from wandb.sdk.launch.utils import get_entrypoint_file
17
19
  from wandb.sdk.lib.runid import generate_id
18
20
 
19
21
  from .errors import LaunchError
@@ -119,6 +121,7 @@ class LaunchProject:
119
121
  self.override_args: List[str] = overrides.get("args", [])
120
122
  self.override_config: Dict[str, Any] = overrides.get("run_config", {})
121
123
  self.override_artifacts: Dict[str, Any] = overrides.get("artifacts", {})
124
+ self.override_files: Dict[str, Any] = overrides.get("files", {})
122
125
  self.override_entrypoint: Optional[EntryPoint] = None
123
126
  self.override_dockerfile: Optional[str] = overrides.get("dockerfile")
124
127
  self.deps_type: Optional[str] = None
@@ -127,15 +130,15 @@ class LaunchProject:
127
130
  self._queue_name: Optional[str] = None
128
131
  self._queue_entity: Optional[str] = None
129
132
  self._run_queue_item_id: Optional[str] = None
130
- self._entry_point: Optional[
131
- EntryPoint
132
- ] = None # todo: keep multiple entrypoint support?
133
+ self._entry_point: Optional[EntryPoint] = (
134
+ None # todo: keep multiple entrypoint support?
135
+ )
133
136
 
134
137
  override_entrypoint = overrides.get("entry_point")
135
138
  if override_entrypoint:
136
139
  _logger.info("Adding override entry point")
137
140
  self.override_entrypoint = EntryPoint(
138
- name=_get_entrypoint_file(override_entrypoint),
141
+ name=get_entrypoint_file(override_entrypoint),
139
142
  command=override_entrypoint,
140
143
  )
141
144
 
@@ -536,24 +539,6 @@ class LaunchProject:
536
539
  self.git_version = branch_name
537
540
 
538
541
 
539
- def _get_entrypoint_file(entrypoint: List[str]) -> Optional[str]:
540
- """Get the entrypoint file from the given command.
541
-
542
- Args:
543
- entrypoint (List[str]): List of command and arguments.
544
-
545
- Returns:
546
- Optional[str]: The entrypoint file if found, otherwise None.
547
- """
548
- if not entrypoint:
549
- return None
550
- if entrypoint[0].endswith(".py") or entrypoint[0].endswith(".sh"):
551
- return entrypoint[0]
552
- if len(entrypoint) < 2:
553
- return None
554
- return entrypoint[1]
555
-
556
-
557
542
  class EntryPoint:
558
543
  """An entry point into a wandb launch specification."""
559
544
 
@@ -570,7 +555,9 @@ class EntryPoint:
570
555
 
571
556
  def update_entrypoint_path(self, new_path: str) -> None:
572
557
  """Updates the entrypoint path to a new path."""
573
- if len(self.command) == 2 and self.command[0] in ["python", "bash"]:
558
+ if len(self.command) == 2 and (
559
+ self.command[0].startswith("python") or self.command[0] == "bash"
560
+ ):
574
561
  self.command[1] = new_path
575
562
 
576
563
 
@@ -1,4 +1,5 @@
1
1
  """Implementation of launch agent."""
2
+
2
3
  import asyncio
3
4
  import logging
4
5
  import os
@@ -45,7 +46,10 @@ MAX_RESUME_COUNT = 5
45
46
 
46
47
  RUN_INFO_GRACE_PERIOD = 60
47
48
 
48
- MAX_WAIT_RUN_STOPPED = 60
49
+ DEFAULT_STOPPED_RUN_TIMEOUT = 60
50
+
51
+ DEFAULT_PRINT_INTERVAL = 5 * 60
52
+ VERBOSE_PRINT_INTERVAL = 20
49
53
 
50
54
  _env_timeout = os.environ.get("WANDB_LAUNCH_START_TIMEOUT")
51
55
  if _env_timeout:
@@ -105,30 +109,29 @@ def _max_from_config(
105
109
  return max_from_config
106
110
 
107
111
 
108
- def _is_scheduler_job(run_spec: Dict[str, Any]) -> bool:
109
- """Determine whether a job/runSpec is a sweep scheduler."""
110
- if not run_spec:
111
- _logger.debug("Recieved runSpec in _is_scheduler_job that was empty")
112
+ class InternalAgentLogger:
113
+ def __init__(self, verbosity=0):
114
+ self._print_to_terminal = verbosity >= 2
112
115
 
113
- if run_spec.get("uri") != Scheduler.PLACEHOLDER_URI:
114
- return False
115
-
116
- if run_spec.get("resource") == "local-process":
117
- # Any job pushed to a run queue that has a scheduler uri is
118
- # allowed to use local-process
119
- if run_spec.get("job"):
120
- return True
116
+ def error(self, message: str):
117
+ if self._print_to_terminal:
118
+ wandb.termerror(f"{LOG_PREFIX}{message}")
119
+ _logger.error(f"{LOG_PREFIX}{message}")
121
120
 
122
- # If a scheduler is local-process and run through CLI, also
123
- # confirm command is in format: [wandb scheduler <sweep>]
124
- cmd = run_spec.get("overrides", {}).get("entry_point", [])
125
- if len(cmd) < 3:
126
- return False
121
+ def warn(self, message: str):
122
+ if self._print_to_terminal:
123
+ wandb.termwarn(f"{LOG_PREFIX}{message}")
124
+ _logger.warn(f"{LOG_PREFIX}{message}")
127
125
 
128
- if cmd[:2] != ["wandb", "scheduler"]:
129
- return False
126
+ def info(self, message: str):
127
+ if self._print_to_terminal:
128
+ wandb.termlog(f"{LOG_PREFIX}{message}")
129
+ _logger.info(f"{LOG_PREFIX}{message}")
130
130
 
131
- return True
131
+ def debug(self, message: str):
132
+ if self._print_to_terminal:
133
+ wandb.termlog(f"{LOG_PREFIX}{message}")
134
+ _logger.debug(f"{LOG_PREFIX}{message}")
132
135
 
133
136
 
134
137
  class LaunchAgent:
@@ -184,7 +187,13 @@ class LaunchAgent:
184
187
  self._max_jobs = _max_from_config(config, "max_jobs")
185
188
  self._max_schedulers = _max_from_config(config, "max_schedulers")
186
189
  self._secure_mode = config.get("secure_mode", False)
190
+ self._verbosity = config.get("verbosity", 0)
191
+ self._internal_logger = InternalAgentLogger(verbosity=self._verbosity)
192
+ self._last_status_print_time = 0.0
187
193
  self.default_config: Dict[str, Any] = config
194
+ self._stopped_run_timeout = config.get(
195
+ "stopped_run_timeout", DEFAULT_STOPPED_RUN_TIMEOUT
196
+ )
188
197
 
189
198
  # Get agent version from env var if present, otherwise wandb version
190
199
  self.version: str = "wandb@" + wandb.__version__
@@ -228,6 +237,33 @@ class LaunchAgent:
228
237
  self._name = agent_response["name"]
229
238
  self._init_agent_run()
230
239
 
240
+ def _is_scheduler_job(self, run_spec: Dict[str, Any]) -> bool:
241
+ """Determine whether a job/runSpec is a sweep scheduler."""
242
+ if not run_spec:
243
+ self._internal_logger.debug(
244
+ "Recieved runSpec in _is_scheduler_job that was empty"
245
+ )
246
+
247
+ if run_spec.get("uri") != Scheduler.PLACEHOLDER_URI:
248
+ return False
249
+
250
+ if run_spec.get("resource") == "local-process":
251
+ # Any job pushed to a run queue that has a scheduler uri is
252
+ # allowed to use local-process
253
+ if run_spec.get("job"):
254
+ return True
255
+
256
+ # If a scheduler is local-process and run through CLI, also
257
+ # confirm command is in format: [wandb scheduler <sweep>]
258
+ cmd = run_spec.get("overrides", {}).get("entry_point", [])
259
+ if len(cmd) < 3:
260
+ return False
261
+
262
+ if cmd[:2] != ["wandb", "scheduler"]:
263
+ return False
264
+
265
+ return True
266
+
231
267
  async def fail_run_queue_item(
232
268
  self,
233
269
  run_queue_item_id: str,
@@ -298,6 +334,7 @@ class LaunchAgent:
298
334
 
299
335
  def print_status(self) -> None:
300
336
  """Prints the current status of the agent."""
337
+ self._last_status_print_time = time.time()
301
338
  output_str = "agent "
302
339
  if self._name:
303
340
  output_str += f"{self._name} "
@@ -344,8 +381,8 @@ class LaunchAgent:
344
381
  if run_state.lower() != "pending":
345
382
  return True
346
383
  except CommError:
347
- _logger.info(
348
- f"Run {entity}/{project}/{run_id} with rqi id: {rqi_id} did not have associated run"
384
+ self._internal_logger.info(
385
+ f"Run {entity}/{project}/{run_id} with rqi id: {rqi_id} did not have associated run",
349
386
  )
350
387
  return False
351
388
 
@@ -361,8 +398,8 @@ class LaunchAgent:
361
398
  job_and_run_status.entity is not None
362
399
  and job_and_run_status.entity != self._entity
363
400
  ):
364
- _logger.info(
365
- "Skipping check for completed run status because run is on a different entity than agent"
401
+ self._internal_logger.info(
402
+ "Skipping check for completed run status because run is on a different entity than agent",
366
403
  )
367
404
  elif exception is not None:
368
405
  tb_str = traceback.format_exception(
@@ -378,8 +415,8 @@ class LaunchAgent:
378
415
  fnames,
379
416
  )
380
417
  elif job_and_run_status.project is None or job_and_run_status.run_id is None:
381
- _logger.error(
382
- f"called finish_thread_id on thread whose tracker has no project or run id. RunQueueItemID: {job_and_run_status.run_queue_item_id}"
418
+ self._internal_logger.info(
419
+ f"called finish_thread_id on thread whose tracker has no project or run id. RunQueueItemID: {job_and_run_status.run_queue_item_id}",
383
420
  )
384
421
  wandb.termerror(
385
422
  "Missing project or run id on thread called finish thread id"
@@ -430,7 +467,9 @@ class LaunchAgent:
430
467
  job_and_run_status.run_queue_item_id, _msg, "run", fnames
431
468
  )
432
469
  else:
433
- _logger.info(f"Finish thread id {thread_id} had no exception and no run")
470
+ self._internal_logger.info(
471
+ f"Finish thread id {thread_id} had no exception and no run"
472
+ )
434
473
  wandb._sentry.exception(
435
474
  "launch agent called finish thread id on thread without run or exception"
436
475
  )
@@ -458,7 +497,7 @@ class LaunchAgent:
458
497
  await self.update_status(AGENT_RUNNING)
459
498
 
460
499
  # parse job
461
- _logger.info("Parsing launch spec")
500
+ self._internal_logger.info("Parsing launch spec")
462
501
  launch_spec = job["runSpec"]
463
502
 
464
503
  # Abort if this job attempts to override secure mode
@@ -511,6 +550,10 @@ class LaunchAgent:
511
550
  KeyboardInterrupt: if the agent is requested to stop.
512
551
  """
513
552
  self.print_status()
553
+ if self._verbosity == 0:
554
+ print_interval = DEFAULT_PRINT_INTERVAL
555
+ else:
556
+ print_interval = VERBOSE_PRINT_INTERVAL
514
557
  try:
515
558
  while True:
516
559
  job = None
@@ -532,7 +575,7 @@ class LaunchAgent:
532
575
  file_saver = RunQueueItemFileSaver(
533
576
  self._wandb_run, job["runQueueItemId"]
534
577
  )
535
- if _is_scheduler_job(job.get("runSpec", {})):
578
+ if self._is_scheduler_job(job.get("runSpec", {})):
536
579
  # If job is a scheduler, and we are already at the cap, ignore,
537
580
  # don't ack, and it will be pushed back onto the queue in 1 min
538
581
  if self.num_running_schedulers >= self._max_schedulers:
@@ -567,6 +610,7 @@ class LaunchAgent:
567
610
  await self.update_status(AGENT_POLLING)
568
611
  else:
569
612
  await self.update_status(AGENT_RUNNING)
613
+ if time.time() - self._last_status_print_time > print_interval:
570
614
  self.print_status()
571
615
 
572
616
  if self.num_running_jobs == self._max_jobs or job is None:
@@ -634,14 +678,14 @@ class LaunchAgent:
634
678
  await self.check_sweep_state(launch_spec, api)
635
679
 
636
680
  job_tracker.update_run_info(project)
637
- _logger.info("Fetching and validating project...")
681
+ self._internal_logger.info("Fetching and validating project...")
638
682
  project.fetch_and_validate_project()
639
- _logger.info("Fetching resource...")
683
+ self._internal_logger.info("Fetching resource...")
640
684
  resource = launch_spec.get("resource") or "local-container"
641
685
  backend_config: Dict[str, Any] = {
642
686
  PROJECT_SYNCHRONOUS: False, # agent always runs async
643
687
  }
644
- _logger.info("Loading backend")
688
+ self._internal_logger.info("Loading backend")
645
689
  override_build_config = launch_spec.get("builder")
646
690
 
647
691
  _, build_config, registry_config = construct_agent_configs(
@@ -661,13 +705,13 @@ class LaunchAgent:
661
705
  assert entrypoint is not None
662
706
  image_uri = await builder.build_image(project, entrypoint, job_tracker)
663
707
 
664
- _logger.info("Backend loaded...")
708
+ self._internal_logger.info("Backend loaded...")
665
709
  if isinstance(backend, LocalProcessRunner):
666
710
  run = await backend.run(project, image_uri)
667
711
  else:
668
712
  assert image_uri
669
713
  run = await backend.run(project, image_uri)
670
- if _is_scheduler_job(launch_spec):
714
+ if self._is_scheduler_job(launch_spec):
671
715
  with self._jobs_lock:
672
716
  self._jobs[thread_id].is_scheduler = True
673
717
  wandb.termlog(
@@ -700,7 +744,7 @@ class LaunchAgent:
700
744
  if stopped_time is None:
701
745
  stopped_time = time.time()
702
746
  else:
703
- if time.time() - stopped_time > MAX_WAIT_RUN_STOPPED:
747
+ if time.time() - stopped_time > self._stopped_run_timeout:
704
748
  await run.cancel()
705
749
  await asyncio.sleep(AGENT_POLLING_INTERVAL)
706
750
 
@@ -720,7 +764,7 @@ class LaunchAgent:
720
764
  project=launch_spec["project"],
721
765
  )
722
766
  except Exception as e:
723
- _logger.debug(f"Fetch sweep state error: {e}")
767
+ self._internal_logger.debug(f"Fetch sweep state error: {e}")
724
768
  state = None
725
769
 
726
770
  if state != "RUNNING" and state != "PAUSED":
@@ -80,17 +80,7 @@ class RegistryConfig(BaseModel):
80
80
  @validator("uri") # type: ignore
81
81
  @classmethod
82
82
  def validate_uri(cls, uri: str) -> str:
83
- for regex in [
84
- GCP_ARTIFACT_REGISTRY_URI_REGEX,
85
- AZURE_CONTAINER_REGISTRY_URI_REGEX,
86
- ELASTIC_CONTAINER_REGISTRY_URI_REGEX,
87
- ]:
88
- if regex.match(uri):
89
- return uri
90
- raise ValueError(
91
- "Invalid uri. URI must be a repository URI for an "
92
- "ECR, ACR, or GCP Artifact Registry."
93
- )
83
+ return validate_registry_uri(uri)
94
84
 
95
85
 
96
86
  class EnvironmentConfig(BaseModel):
@@ -186,6 +176,14 @@ class BuilderConfig(BaseModel):
186
176
  """Right now there are no required fields for docker builds."""
187
177
  return values
188
178
 
179
+ @validator("destination") # type: ignore
180
+ @classmethod
181
+ def validate_destination(cls, destination: Optional[str]) -> Optional[str]:
182
+ """Validate that the destination is a valid container registry URI."""
183
+ if destination is None:
184
+ return None
185
+ return validate_registry_uri(destination)
186
+
189
187
 
190
188
  class AgentConfig(BaseModel):
191
189
  """Configuration for the Launch agent."""
@@ -225,6 +223,77 @@ class AgentConfig(BaseModel):
225
223
  None,
226
224
  description="The builder to use.",
227
225
  )
226
+ verbosity: Optional[int] = Field(
227
+ 0,
228
+ description="How verbose to print, 0 = default, 1 = verbose, 2 = very verbose",
229
+ )
230
+ stopped_run_timeout: Optional[int] = Field(
231
+ 60,
232
+ description="How many seconds to wait after receiving the stop command before forcibly cancelling a run.",
233
+ )
228
234
 
229
235
  class Config:
230
236
  extra = "forbid"
237
+
238
+
239
+ def validate_registry_uri(uri: str) -> str:
240
+ """Validate that the registry URI is a valid container registry URI.
241
+
242
+ The URI should resolve to an image name in a container registry. The recognized
243
+ formats are for ECR, ACR, and GCP Artifact Registry. If the URI does not match
244
+ any of these formats, a warning is printed indicating the registry type is not
245
+ recognized and the agent can't guarantee that images can be pushed.
246
+
247
+ If the format is recognized but does not resolve to an image name, an
248
+ error is raised. For example, if the URI is an ECR URI but does not include
249
+ an image name or includes a tag as well as an image name, an error is raised.
250
+ """
251
+ tag_msg = (
252
+ "Destination for built images may not include a tag, but the URI provided "
253
+ "includes the suffix '{tag}'. Please remove the tag and try again. The agent "
254
+ "will automatically tag each image with a unique hash of the source code."
255
+ )
256
+ if uri.startswith("https://"):
257
+ uri = uri[8:]
258
+
259
+ match = GCP_ARTIFACT_REGISTRY_URI_REGEX.match(uri)
260
+ if match:
261
+ if match.group("tag"):
262
+ raise ValueError(tag_msg.format(tag=match.group("tag")))
263
+ if not match.group("image_name"):
264
+ raise ValueError(
265
+ "An image name must be specified in the URI for a GCP Artifact Registry. "
266
+ "Please provide a uri with the format "
267
+ "'https://<region>-docker.pkg.dev/<project>/<repository>/<image>'."
268
+ )
269
+ return uri
270
+
271
+ match = AZURE_CONTAINER_REGISTRY_URI_REGEX.match(uri)
272
+ if match:
273
+ if match.group("tag"):
274
+ raise ValueError(tag_msg.format(tag=match.group("tag")))
275
+ if not match.group("repository"):
276
+ raise ValueError(
277
+ "A repository name must be specified in the URI for an "
278
+ "Azure Container Registry. Please provide a uri with the format "
279
+ "'https://<registry-name>.azurecr.io/<repository>'."
280
+ )
281
+ return uri
282
+
283
+ match = ELASTIC_CONTAINER_REGISTRY_URI_REGEX.match(uri)
284
+ if match:
285
+ if match.group("tag"):
286
+ raise ValueError(tag_msg.format(tag=match.group("tag")))
287
+ if not match.group("repository"):
288
+ raise ValueError(
289
+ "A repository name must be specified in the URI for an "
290
+ "Elastic Container Registry. Please provide a uri with the format "
291
+ "'https://<account-id>.dkr.ecr.<region>.amazonaws.com/<repository>'."
292
+ )
293
+ return uri
294
+
295
+ wandb.termwarn(
296
+ f"Unable to recognize registry type in URI {uri}. You are responsible "
297
+ "for ensuring the agent can push images to this registry."
298
+ )
299
+ return uri
@@ -1,4 +1,5 @@
1
1
  """Abstract plugin class defining the interface needed to build container images for W&B Launch."""
2
+
2
3
  from abc import ABC, abstractmethod
3
4
  from typing import TYPE_CHECKING, Any, Dict, Optional
4
5
 
@@ -237,7 +237,11 @@ def get_base_setup(
237
237
 
238
238
  CPU version is built on python, Accelerator version is built on user provided.
239
239
  """
240
- python_base_image = f"python:{py_version}-buster"
240
+ minor = int(py_version.split(".")[1])
241
+ if minor < 12:
242
+ python_base_image = f"python:{py_version}-buster"
243
+ else:
244
+ python_base_image = f"python:{py_version}-bookworm"
241
245
  if launch_project.accelerator_base_image:
242
246
  _logger.info(
243
247
  f"Using accelerator base image: {launch_project.accelerator_base_image}"
@@ -311,6 +315,11 @@ def get_env_vars_dict(
311
315
  _inject_wandb_config_env_vars(
312
316
  launch_project.override_config, env_vars, max_env_length
313
317
  )
318
+
319
+ _inject_file_overrides_env_vars(
320
+ launch_project.override_files, env_vars, max_env_length
321
+ )
322
+
314
323
  artifacts = {}
315
324
  # if we're spinning up a launch process from a job
316
325
  # we should tell the run to use that artifact
@@ -677,3 +686,21 @@ def _inject_wandb_config_env_vars(
677
686
  ]
678
687
  config_chunks_dict = {f"WANDB_CONFIG_{i}": chunk for i, chunk in enumerate(chunks)}
679
688
  env_dict.update(config_chunks_dict)
689
+
690
+
691
+ def _inject_file_overrides_env_vars(
692
+ overrides: Dict[str, Any], env_dict: Dict[str, Any], maximum_env_length: int
693
+ ) -> None:
694
+ str_overrides = json.dumps(overrides)
695
+ if len(str_overrides) <= maximum_env_length:
696
+ env_dict["WANDB_LAUNCH_FILE_OVERRIDES"] = str_overrides
697
+ return
698
+
699
+ chunks = [
700
+ str_overrides[i : i + maximum_env_length]
701
+ for i in range(0, len(str_overrides), maximum_env_length)
702
+ ]
703
+ overrides_chunks_dict = {
704
+ f"WANDB_LAUNCH_FILE_OVERRIDES_{i}": chunk for i, chunk in enumerate(chunks)
705
+ }
706
+ env_dict.update(overrides_chunks_dict)
@@ -1,4 +1,5 @@
1
1
  """Implementation of the docker builder."""
2
+
2
3
  import logging
3
4
  import os
4
5
  from typing import Any, Dict, Optional