wandb 0.15.4__py3-none-any.whl → 0.15.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (102) hide show
  1. wandb/__init__.py +1 -1
  2. wandb/analytics/sentry.py +1 -0
  3. wandb/apis/internal.py +3 -0
  4. wandb/apis/public.py +18 -20
  5. wandb/beta/workflows.py +5 -6
  6. wandb/cli/cli.py +27 -27
  7. wandb/data_types.py +2 -0
  8. wandb/integration/langchain/wandb_tracer.py +16 -179
  9. wandb/integration/sagemaker/config.py +2 -2
  10. wandb/integration/tensorboard/log.py +4 -4
  11. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  12. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  13. wandb/proto/wandb_deprecated.py +3 -1
  14. wandb/sdk/__init__.py +1 -4
  15. wandb/sdk/artifacts/__init__.py +0 -14
  16. wandb/sdk/artifacts/artifact.py +1757 -277
  17. wandb/sdk/artifacts/artifact_manifest_entry.py +26 -6
  18. wandb/sdk/artifacts/artifact_state.py +10 -0
  19. wandb/sdk/artifacts/artifacts_cache.py +7 -8
  20. wandb/sdk/artifacts/exceptions.py +4 -4
  21. wandb/sdk/artifacts/storage_handler.py +2 -2
  22. wandb/sdk/artifacts/storage_handlers/azure_handler.py +16 -6
  23. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +2 -2
  24. wandb/sdk/artifacts/storage_handlers/http_handler.py +2 -2
  25. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +2 -2
  26. wandb/sdk/artifacts/storage_handlers/multi_handler.py +2 -2
  27. wandb/sdk/artifacts/storage_handlers/s3_handler.py +35 -32
  28. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +2 -2
  29. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +5 -9
  30. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +2 -2
  31. wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +2 -2
  32. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +24 -16
  33. wandb/sdk/artifacts/storage_policy.py +3 -3
  34. wandb/sdk/data_types/_dtypes.py +7 -12
  35. wandb/sdk/data_types/base_types/json_metadata.py +2 -2
  36. wandb/sdk/data_types/base_types/media.py +5 -6
  37. wandb/sdk/data_types/base_types/wb_value.py +12 -13
  38. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +4 -5
  39. wandb/sdk/data_types/helper_types/classes.py +5 -8
  40. wandb/sdk/data_types/helper_types/image_mask.py +4 -5
  41. wandb/sdk/data_types/histogram.py +3 -3
  42. wandb/sdk/data_types/html.py +3 -4
  43. wandb/sdk/data_types/image.py +4 -5
  44. wandb/sdk/data_types/molecule.py +2 -2
  45. wandb/sdk/data_types/object_3d.py +3 -3
  46. wandb/sdk/data_types/plotly.py +2 -2
  47. wandb/sdk/data_types/saved_model.py +7 -8
  48. wandb/sdk/data_types/trace_tree.py +4 -4
  49. wandb/sdk/data_types/video.py +4 -4
  50. wandb/sdk/interface/interface.py +8 -10
  51. wandb/sdk/internal/file_stream.py +2 -3
  52. wandb/sdk/internal/internal_api.py +99 -4
  53. wandb/sdk/internal/job_builder.py +15 -7
  54. wandb/sdk/internal/sender.py +4 -0
  55. wandb/sdk/internal/settings_static.py +1 -0
  56. wandb/sdk/launch/_project_spec.py +9 -7
  57. wandb/sdk/launch/agent/agent.py +115 -58
  58. wandb/sdk/launch/agent/job_status_tracker.py +34 -0
  59. wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
  60. wandb/sdk/launch/builder/abstract.py +5 -1
  61. wandb/sdk/launch/builder/build.py +16 -10
  62. wandb/sdk/launch/builder/docker_builder.py +9 -2
  63. wandb/sdk/launch/builder/kaniko_builder.py +108 -22
  64. wandb/sdk/launch/builder/noop.py +3 -1
  65. wandb/sdk/launch/environment/aws_environment.py +2 -1
  66. wandb/sdk/launch/environment/azure_environment.py +124 -0
  67. wandb/sdk/launch/github_reference.py +30 -18
  68. wandb/sdk/launch/launch.py +1 -1
  69. wandb/sdk/launch/loader.py +15 -0
  70. wandb/sdk/launch/registry/azure_container_registry.py +132 -0
  71. wandb/sdk/launch/registry/elastic_container_registry.py +38 -4
  72. wandb/sdk/launch/registry/google_artifact_registry.py +46 -7
  73. wandb/sdk/launch/runner/abstract.py +19 -3
  74. wandb/sdk/launch/runner/kubernetes_runner.py +111 -47
  75. wandb/sdk/launch/runner/local_container.py +101 -48
  76. wandb/sdk/launch/runner/sagemaker_runner.py +59 -9
  77. wandb/sdk/launch/runner/vertex_runner.py +8 -4
  78. wandb/sdk/launch/sweeps/scheduler.py +102 -27
  79. wandb/sdk/launch/sweeps/utils.py +21 -0
  80. wandb/sdk/launch/utils.py +19 -7
  81. wandb/sdk/lib/_settings_toposort_generated.py +3 -0
  82. wandb/sdk/service/server.py +22 -9
  83. wandb/sdk/service/service.py +27 -8
  84. wandb/sdk/verify/verify.py +6 -9
  85. wandb/sdk/wandb_config.py +2 -4
  86. wandb/sdk/wandb_init.py +2 -0
  87. wandb/sdk/wandb_require.py +7 -0
  88. wandb/sdk/wandb_run.py +32 -35
  89. wandb/sdk/wandb_settings.py +10 -3
  90. wandb/testing/relay.py +15 -2
  91. wandb/util.py +55 -23
  92. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/METADATA +11 -8
  93. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/RECORD +97 -97
  94. wandb/integration/langchain/util.py +0 -191
  95. wandb/sdk/artifacts/invalid_artifact.py +0 -23
  96. wandb/sdk/artifacts/lazy_artifact.py +0 -162
  97. wandb/sdk/artifacts/local_artifact.py +0 -719
  98. wandb/sdk/artifacts/public_artifact.py +0 -1188
  99. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
  100. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
  101. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +0 -0
  102. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -5,7 +5,6 @@ import pprint
5
5
  import threading
6
6
  import time
7
7
  import traceback
8
- from dataclasses import dataclass
9
8
  from multiprocessing import Event
10
9
  from multiprocessing.pool import ThreadPool
11
10
  from typing import Any, Dict, List, Optional, Union
@@ -13,7 +12,7 @@ from typing import Any, Dict, List, Optional, Union
13
12
  import wandb
14
13
  from wandb.apis.internal import Api
15
14
  from wandb.errors import CommError
16
- from wandb.sdk.launch._project_spec import LaunchProject
15
+ from wandb.sdk.launch.launch_add import launch_add
17
16
  from wandb.sdk.launch.runner.local_container import LocalSubmittedRun
18
17
  from wandb.sdk.launch.sweeps.scheduler import Scheduler
19
18
  from wandb.sdk.lib import runid
@@ -22,8 +21,9 @@ from .. import loader
22
21
  from .._project_spec import create_project_from_spec, fetch_and_validate_project
23
22
  from ..builder.build import construct_builder_args
24
23
  from ..errors import LaunchDockerError, LaunchError
25
- from ..runner.abstract import AbstractRun
26
24
  from ..utils import LAUNCH_DEFAULT_PROJECT, LOG_PREFIX, PROJECT_SYNCHRONOUS
25
+ from .job_status_tracker import JobAndRunStatusTracker
26
+ from .run_queue_item_file_saver import RunQueueItemFileSaver
27
27
 
28
28
  AGENT_POLLING_INTERVAL = 10
29
29
  ACTIVE_SWEEP_POLLING_INTERVAL = 1 # more frequent when we know we have jobs
@@ -36,28 +36,9 @@ HIDDEN_AGENT_RUN_TYPE = "sweep-controller"
36
36
 
37
37
  MAX_THREADS = 64
38
38
 
39
- _logger = logging.getLogger(__name__)
40
-
41
-
42
- @dataclass
43
- class JobAndRunStatus:
44
- run_queue_item_id: str
45
- run_id: Optional[str] = None
46
- project: Optional[str] = None
47
- entity: Optional[str] = None
48
- run: Optional[AbstractRun] = None
49
- failed_to_start: bool = False
50
- completed_status: Optional[str] = None
51
- is_scheduler: bool = False
52
-
53
- @property
54
- def job_completed(self) -> bool:
55
- return self.failed_to_start or self.completed_status is not None
39
+ MAX_RESUME_COUNT = 5
56
40
 
57
- def update_run_info(self, launch_project: LaunchProject) -> None:
58
- self.run_id = launch_project.run_id
59
- self.project = launch_project.target_project
60
- self.entity = launch_project.target_entity
41
+ _logger = logging.getLogger(__name__)
61
42
 
62
43
 
63
44
  def _convert_access(access: str) -> str:
@@ -139,7 +120,7 @@ class LaunchAgent:
139
120
  self._api = api
140
121
  self._base_url = self._api.settings().get("base_url")
141
122
  self._ticks = 0
142
- self._jobs: Dict[int, JobAndRunStatus] = {}
123
+ self._jobs: Dict[int, JobAndRunStatusTracker] = {}
143
124
  self._jobs_lock = threading.Lock()
144
125
  self._jobs_event = Event()
145
126
  self._jobs_event.set()
@@ -180,22 +161,31 @@ class LaunchAgent:
180
161
  self._id, self.gorilla_supports_agents
181
162
  )
182
163
  self._name = agent_response["name"]
183
- if self.gorilla_supports_agents:
184
- self._init_agent_run()
164
+ self._init_agent_run()
185
165
 
186
- def fail_run_queue_item(self, run_queue_item_id: str) -> None:
166
+ def fail_run_queue_item(
167
+ self,
168
+ run_queue_item_id: str,
169
+ message: str,
170
+ phase: str,
171
+ files: Optional[List[str]] = None,
172
+ ) -> None:
187
173
  if self._gorilla_supports_fail_run_queue_items:
188
- self._api.fail_run_queue_item(run_queue_item_id)
174
+ self._api.fail_run_queue_item(run_queue_item_id, message, phase, files)
189
175
 
190
176
  def _init_agent_run(self) -> None:
191
- settings = wandb.Settings(silent=True, disable_git=True)
192
- wandb.init(
193
- project=self._project,
194
- entity=self._entity,
195
- settings=settings,
196
- id=self._name,
197
- job_type=HIDDEN_AGENT_RUN_TYPE,
198
- )
177
+ # TODO: has it been long enough that all backends support agents?
178
+ if self.gorilla_supports_agents:
179
+ settings = wandb.Settings(silent=True, disable_git=True)
180
+ self._wandb_run = wandb.init(
181
+ project=self._project,
182
+ entity=self._entity,
183
+ settings=settings,
184
+ id=self._name,
185
+ job_type=HIDDEN_AGENT_RUN_TYPE,
186
+ )
187
+ else:
188
+ self._wandb_run = None
199
189
 
200
190
  @property
201
191
  def thread_ids(self) -> List[int]:
@@ -279,24 +269,35 @@ class LaunchAgent:
279
269
  """Removes the job from our list for now."""
280
270
  job_and_run_status = self._jobs[thread_id]
281
271
  if (
282
- not job_and_run_status.run_id
283
- or not job_and_run_status.project
284
- or exception is not None
272
+ job_and_run_status.entity is not None
273
+ and job_and_run_status.entity != self._entity
285
274
  ):
286
- self.fail_run_queue_item(job_and_run_status.run_queue_item_id)
287
- elif job_and_run_status.entity != self._entity:
288
275
  _logger.info(
289
276
  "Skipping check for completed run status because run is on a different entity than agent"
290
277
  )
278
+ elif exception is not None:
279
+ tb_str = traceback.format_exception(
280
+ type(exception), value=exception, tb=exception.__traceback__
281
+ )
282
+ fnames = job_and_run_status.saver.save_contents(
283
+ "".join(tb_str), "error.log", "error"
284
+ )
285
+ self.fail_run_queue_item(
286
+ job_and_run_status.run_queue_item_id,
287
+ str(exception),
288
+ job_and_run_status.err_stage,
289
+ fnames,
290
+ )
291
291
  elif job_and_run_status.completed_status not in ["stopped", "failed"]:
292
292
  _logger.info(
293
293
  "Skipping check for completed run status because run was successful"
294
294
  )
295
- else:
295
+ elif job_and_run_status.run is not None:
296
296
  run_info = None
297
297
  # sweep runs exist but have no info before they are started
298
298
  # so run_info returned will be None
299
299
  # normal runs just throw a comm error
300
+ # TODO: make more clear
300
301
  try:
301
302
  run_info = self._api.get_run_info(
302
303
  self._entity, job_and_run_status.project, job_and_run_status.run_id
@@ -305,7 +306,22 @@ class LaunchAgent:
305
306
  except CommError:
306
307
  pass
307
308
  if run_info is None:
308
- self.fail_run_queue_item(job_and_run_status.run_queue_item_id)
309
+ _msg = "The submitted run was not successfully started"
310
+ fnames = None
311
+
312
+ logs = job_and_run_status.run.get_logs()
313
+ if logs:
314
+ fnames = job_and_run_status.saver.save_contents(
315
+ logs, "error.log", "error"
316
+ )
317
+ self.fail_run_queue_item(
318
+ job_and_run_status.run_queue_item_id, _msg, "run", fnames
319
+ )
320
+ else:
321
+ _logger.info("Finish thread id had no exception, ror run")
322
+ wandb._sentry.exception(
323
+ "launch agent called finish thread id on thread without run or exception"
324
+ )
309
325
 
310
326
  # TODO: keep logs or something for the finished jobs
311
327
  with self._jobs_lock:
@@ -322,7 +338,9 @@ class LaunchAgent:
322
338
  if job.job_completed:
323
339
  self.finish_thread_id(thread_id)
324
340
 
325
- def run_job(self, job: Dict[str, Any]) -> None:
341
+ def run_job(
342
+ self, job: Dict[str, Any], queue: str, file_saver: RunQueueItemFileSaver
343
+ ) -> None:
326
344
  """Set up project and run the job.
327
345
 
328
346
  Arguments:
@@ -348,6 +366,8 @@ class LaunchAgent:
348
366
  job,
349
367
  self.default_config,
350
368
  self._api,
369
+ queue,
370
+ file_saver,
351
371
  ),
352
372
  )
353
373
 
@@ -401,6 +421,9 @@ class LaunchAgent:
401
421
  for queue in self._queues:
402
422
  job = self.pop_from_queue(queue)
403
423
  if job:
424
+ file_saver = RunQueueItemFileSaver(
425
+ self._wandb_run, job["runQueueItemId"]
426
+ )
404
427
  if _is_scheduler_job(job.get("runSpec")):
405
428
  # If job is a scheduler, and we are already at the cap, ignore,
406
429
  # don't ack, and it will be pushed back onto the queue in 1 min
@@ -413,13 +436,25 @@ class LaunchAgent:
413
436
  continue
414
437
 
415
438
  try:
416
- self.run_job(job)
439
+ self.run_job(job, queue, file_saver)
417
440
  except Exception as e:
418
441
  wandb.termerror(
419
442
  f"{LOG_PREFIX}Error running job: {traceback.format_exc()}"
420
443
  )
421
444
  wandb._sentry.exception(e)
422
- self.fail_run_queue_item(job["runQueueItemId"])
445
+
446
+ # always the first phase, because we only enter phase 2 within the thread
447
+ files = file_saver.save_contents(
448
+ contents=traceback.format_exc(),
449
+ fname="error.log",
450
+ file_sub_type="error",
451
+ )
452
+ self.fail_run_queue_item(
453
+ run_queue_item_id=job["runQueueItemId"],
454
+ message=str(e),
455
+ phase="agent",
456
+ files=files,
457
+ )
423
458
 
424
459
  for thread_id in self.thread_ids:
425
460
  self._update_finished(thread_id)
@@ -454,11 +489,18 @@ class LaunchAgent:
454
489
  job: Dict[str, Any],
455
490
  default_config: Dict[str, Any],
456
491
  api: Api,
492
+ queue: str,
493
+ file_saver: RunQueueItemFileSaver,
457
494
  ) -> None:
458
495
  thread_id = threading.current_thread().ident
459
496
  assert thread_id is not None
497
+ job_tracker = JobAndRunStatusTracker(job["runQueueItemId"], queue, file_saver)
498
+ with self._jobs_lock:
499
+ self._jobs[thread_id] = job_tracker
460
500
  try:
461
- self._thread_run_job(launch_spec, job, default_config, api, thread_id)
501
+ self._thread_run_job(
502
+ launch_spec, job, default_config, api, queue, thread_id, job_tracker
503
+ )
462
504
  except LaunchDockerError as e:
463
505
  wandb.termerror(
464
506
  f"{LOG_PREFIX}agent {self._name} encountered an issue while starting Docker, see above output for details."
@@ -476,11 +518,10 @@ class LaunchAgent:
476
518
  job: Dict[str, Any],
477
519
  default_config: Dict[str, Any],
478
520
  api: Api,
521
+ queue: str,
479
522
  thread_id: int,
523
+ job_tracker: JobAndRunStatusTracker,
480
524
  ) -> None:
481
- job_tracker = JobAndRunStatus(job["runQueueItemId"])
482
- with self._jobs_lock:
483
- self._jobs[thread_id] = job_tracker
484
525
  project = create_project_from_spec(launch_spec, api)
485
526
  job_tracker.update_run_info(project)
486
527
  _logger.info("Fetching and validating project...")
@@ -505,8 +546,7 @@ class LaunchAgent:
505
546
  backend = loader.runner_from_config(resource, api, backend_config, environment)
506
547
  _logger.info("Backend loaded...")
507
548
  api.ack_run_queue_item(job["runQueueItemId"], project.run_id)
508
- run = backend.run(project, builder)
509
-
549
+ run = backend.run(project, builder, job_tracker)
510
550
  if _is_scheduler_job(launch_spec):
511
551
  with self._jobs_lock:
512
552
  self._jobs[thread_id].is_scheduler = True
@@ -522,15 +562,17 @@ class LaunchAgent:
522
562
  with self._jobs_lock:
523
563
  job_tracker.run = run
524
564
  while self._jobs_event.is_set():
525
- if self._check_run_finished(job_tracker):
565
+ if self._check_run_finished(job_tracker, launch_spec):
526
566
  return
527
567
  time.sleep(AGENT_POLLING_INTERVAL)
528
568
  # temp: for local, kill all jobs. we don't yet have good handling for different
529
569
  # types of runners in general
530
- if isinstance(run, LocalSubmittedRun):
531
- run.command_proc.kill()
570
+ if isinstance(run, LocalSubmittedRun) and run._command_proc is not None:
571
+ run._command_proc.kill()
532
572
 
533
- def _check_run_finished(self, job_tracker: JobAndRunStatus) -> bool:
573
+ def _check_run_finished(
574
+ self, job_tracker: JobAndRunStatusTracker, launch_spec: Dict[str, Any]
575
+ ) -> bool:
534
576
  if job_tracker.completed_status:
535
577
  return True
536
578
 
@@ -547,13 +589,28 @@ class LaunchAgent:
547
589
  try:
548
590
  run = job_tracker.run
549
591
  status = run.get_status().state
550
- if status in ["stopped", "failed", "finished"]:
592
+ if status in ["stopped", "failed", "finished", "preempted"]:
551
593
  if job_tracker.is_scheduler:
552
594
  wandb.termlog(f"{LOG_PREFIX}Scheduler finished with ID: {run.id}")
553
595
  else:
554
596
  wandb.termlog(f"{LOG_PREFIX}Job finished with ID: {run.id}")
555
597
  with self._jobs_lock:
556
598
  job_tracker.completed_status = status
599
+ if status == "preempted":
600
+ config = launch_spec.copy()
601
+ config["run_id"] = job_tracker.run_id
602
+ config["_resume_count"] = config.get("_resume_count", 0) + 1
603
+ if config["_resume_count"] > MAX_RESUME_COUNT:
604
+ wandb.termlog(
605
+ f"{LOG_PREFIX}Run {job_tracker.run_id} has already resumed {MAX_RESUME_COUNT} times."
606
+ )
607
+ return True
608
+ wandb.termlog(f"{LOG_PREFIX}Requeueing run {job_tracker.run_id}.")
609
+ launch_add(
610
+ config=config,
611
+ project_queue=self._project,
612
+ queue_name=job_tracker.queue,
613
+ )
557
614
  return True
558
615
  return False
559
616
  except LaunchError as e:
@@ -0,0 +1,34 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ from wandb.sdk.launch._project_spec import LaunchProject
5
+
6
+ from ..runner.abstract import AbstractRun
7
+ from .run_queue_item_file_saver import RunQueueItemFileSaver
8
+
9
+
10
+ @dataclass
11
+ class JobAndRunStatusTracker:
12
+ run_queue_item_id: str
13
+ queue: str
14
+ saver: RunQueueItemFileSaver
15
+ run_id: Optional[str] = None
16
+ project: Optional[str] = None
17
+ entity: Optional[str] = None
18
+ run: Optional[AbstractRun] = None
19
+ failed_to_start: bool = False
20
+ completed_status: Optional[str] = None
21
+ is_scheduler: bool = False
22
+ err_stage: str = "agent"
23
+
24
+ @property
25
+ def job_completed(self) -> bool:
26
+ return self.failed_to_start or self.completed_status is not None
27
+
28
+ def update_run_info(self, launch_project: LaunchProject) -> None:
29
+ self.run_id = launch_project.run_id
30
+ self.project = launch_project.target_project
31
+ self.entity = launch_project.target_entity
32
+
33
+ def set_err_stage(self, stage: str) -> None:
34
+ self.err_stage = stage
@@ -0,0 +1,45 @@
1
+ """Implementation of the run queue item file saver class."""
2
+
3
+ import os
4
+ import sys
5
+ from typing import List, Optional, Union
6
+
7
+ import wandb
8
+ from wandb.sdk.lib import RunDisabled
9
+ from wandb.sdk.wandb_run import Run
10
+
11
+ if sys.version_info >= (3, 8):
12
+ from typing import Literal
13
+ else:
14
+ from typing_extensions import Literal
15
+
16
+ FileSubtypes = Literal["warning", "error"]
17
+
18
+
19
+ class RunQueueItemFileSaver:
20
+ def __init__(
21
+ self, agent_run: Optional[Union[Run, RunDisabled]], run_queue_item_id: str
22
+ ):
23
+ self.run_queue_item_id = run_queue_item_id
24
+ self.run = agent_run
25
+
26
+ def save_contents(
27
+ self, contents: str, fname: str, file_sub_type: FileSubtypes
28
+ ) -> Optional[List[str]]:
29
+ if not isinstance(self.run, Run):
30
+ wandb.termwarn("Not saving file contents because agent has no run")
31
+ return None
32
+ root_dir = self.run._settings.files_dir
33
+ saved_run_path = os.path.join(self.run_queue_item_id, file_sub_type, fname)
34
+ local_path = os.path.join(root_dir, saved_run_path)
35
+ os.makedirs(os.path.dirname(local_path), exist_ok=True)
36
+ with open(local_path, "w") as f:
37
+ f.write(contents)
38
+ res = self.run.save(local_path, base_path=root_dir, policy="now")
39
+ if isinstance(res, list):
40
+ return [saved_run_path]
41
+ else:
42
+ wandb.termwarn(
43
+ f"Failed to save files for run queue item: {self.run_queue_item_id}"
44
+ )
45
+ return None
@@ -1,12 +1,15 @@
1
1
  """Abstract plugin class defining the interface needed to build container images for W&B Launch."""
2
2
  from abc import ABC, abstractmethod
3
- from typing import Any, Dict
3
+ from typing import TYPE_CHECKING, Any, Dict, Optional
4
4
 
5
5
  from wandb.sdk.launch.environment.abstract import AbstractEnvironment
6
6
  from wandb.sdk.launch.registry.abstract import AbstractRegistry
7
7
 
8
8
  from .._project_spec import EntryPoint, LaunchProject
9
9
 
10
+ if TYPE_CHECKING:
11
+ from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker
12
+
10
13
 
11
14
  class AbstractBuilder(ABC):
12
15
  """Abstract plugin class defining the interface needed to build container images for W&B Launch."""
@@ -63,6 +66,7 @@ class AbstractBuilder(ABC):
63
66
  self,
64
67
  launch_project: LaunchProject,
65
68
  entrypoint: EntryPoint,
69
+ job_tracker: Optional["JobAndRunStatusTracker"] = None,
66
70
  ) -> str:
67
71
  """Build the image for the given project.
68
72
 
@@ -38,8 +38,6 @@ _logger = logging.getLogger(__name__)
38
38
  _GENERATED_DOCKERFILE_NAME = "Dockerfile.wandb-autogenerated"
39
39
  DEFAULT_ENTRYPOINT = "_wandb_default_entrypoint"
40
40
 
41
- DEFAULT_CUDA_VERSION = "10.0"
42
-
43
41
 
44
42
  def validate_docker_installation() -> None:
45
43
  """Verify if Docker is installed on host machine."""
@@ -103,8 +101,12 @@ FROM {py_base_image} as base
103
101
  """
104
102
 
105
103
  # this goes into base_setup in TEMPLATE
106
- CUDA_SETUP_TEMPLATE = """
107
- FROM {cuda_base_image} as base
104
+ ACCELERATOR_SETUP_TEMPLATE = """
105
+ FROM {accelerator_base_image} as base
106
+
107
+ # make non-interactive so build doesn't block on questions
108
+ ENV DEBIAN_FRONTEND=noninteractive
109
+
108
110
  # TODO: once NVIDIA their linux repository keys for all docker images
109
111
  RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/$(cat /etc/os-release | grep ^ID= | cut -d "=" -f2 )$(cat /etc/os-release | grep ^VERSION_ID= | cut -d "=" -f2 | sed -e 's/[\".]//g' )/$(uname -i)/3bf863cc.pub
110
112
  RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/$(cat /etc/os-release | grep ^ID= | cut -d "=" -f2 )$(cat /etc/os-release | grep ^VERSION_ID= | cut -d "=" -f2 | sed -e 's/[\".]//g' )/$(uname -i)/7fa2af80.pub
@@ -184,12 +186,14 @@ def get_base_setup(
184
186
  ) -> str:
185
187
  """Fill in the Dockerfile templates for stage 2 of build.
186
188
 
187
- CPU version is built on python, GPU version is built on nvidia:cuda.
189
+ CPU version is built on python, Accelerator version is built on user provided.
188
190
  """
189
191
  python_base_image = f"python:{py_version}-buster"
190
- if launch_project.cuda_base_image:
191
- _logger.info(f"Using cuda base image: {launch_project.cuda_base_image}")
192
- # cuda image doesn't come with python tooling
192
+ if launch_project.accelerator_base_image:
193
+ _logger.info(
194
+ f"Using accelerator base image: {launch_project.accelerator_base_image}"
195
+ )
196
+ # accelerator base images doesn't come with python tooling
193
197
  if py_major == "2":
194
198
  python_packages = [
195
199
  f"python{py_version}",
@@ -204,8 +208,8 @@ def get_base_setup(
204
208
  "python3-pip",
205
209
  "python3-setuptools",
206
210
  ]
207
- base_setup = CUDA_SETUP_TEMPLATE.format(
208
- cuda_base_image=launch_project.cuda_base_image,
211
+ base_setup = ACCELERATOR_SETUP_TEMPLATE.format(
212
+ accelerator_base_image=launch_project.accelerator_base_image,
209
213
  python_packages=" \\\n".join(python_packages),
210
214
  py_version=py_version,
211
215
  )
@@ -243,6 +247,8 @@ def get_env_vars_dict(launch_project: LaunchProject, api: Api) -> Dict[str, str]
243
247
  env_vars["WANDB_USERNAME"] = launch_project.launch_spec["author"]
244
248
  if launch_project.sweep_id:
245
249
  env_vars["WANDB_SWEEP_ID"] = launch_project.sweep_id
250
+ if launch_project.launch_spec.get("_resume_count"):
251
+ env_vars["WANDB_RESUME"] = "must"
246
252
 
247
253
  # TODO: handle env vars > 32760 characters
248
254
  env_vars["WANDB_CONFIG"] = json.dumps(launch_project.override_config)
@@ -1,10 +1,11 @@
1
1
  """Implementation of the docker builder."""
2
2
  import logging
3
3
  import os
4
- from typing import Any, Dict
4
+ from typing import Any, Dict, Optional
5
5
 
6
6
  import wandb
7
7
  import wandb.docker as docker
8
+ from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker
8
9
  from wandb.sdk.launch.builder.abstract import AbstractBuilder
9
10
  from wandb.sdk.launch.environment.abstract import AbstractEnvironment
10
11
  from wandb.sdk.launch.registry.abstract import AbstractRegistry
@@ -111,6 +112,7 @@ class DockerBuilder(AbstractBuilder):
111
112
  self,
112
113
  launch_project: LaunchProject,
113
114
  entrypoint: EntryPoint,
115
+ job_tracker: Optional[JobAndRunStatusTracker] = None,
114
116
  ) -> str:
115
117
  """Build the image for the given project.
116
118
 
@@ -159,9 +161,14 @@ class DockerBuilder(AbstractBuilder):
159
161
  context_path=build_ctx_path,
160
162
  platform=self.config.get("platform"),
161
163
  )
162
- warn_failed_packages_from_build_logs(output, image_uri)
164
+
165
+ warn_failed_packages_from_build_logs(
166
+ output, image_uri, launch_project.api, job_tracker
167
+ )
163
168
 
164
169
  except docker.DockerError as e:
170
+ if job_tracker:
171
+ job_tracker.set_err_stage("build")
165
172
  raise LaunchDockerError(f"Error communicating with docker client: {e}")
166
173
 
167
174
  try: