wandb 0.15.4__py3-none-any.whl → 0.15.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (102) hide show
  1. wandb/__init__.py +1 -1
  2. wandb/analytics/sentry.py +1 -0
  3. wandb/apis/internal.py +3 -0
  4. wandb/apis/public.py +18 -20
  5. wandb/beta/workflows.py +5 -6
  6. wandb/cli/cli.py +27 -27
  7. wandb/data_types.py +2 -0
  8. wandb/integration/langchain/wandb_tracer.py +16 -179
  9. wandb/integration/sagemaker/config.py +2 -2
  10. wandb/integration/tensorboard/log.py +4 -4
  11. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  12. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  13. wandb/proto/wandb_deprecated.py +3 -1
  14. wandb/sdk/__init__.py +1 -4
  15. wandb/sdk/artifacts/__init__.py +0 -14
  16. wandb/sdk/artifacts/artifact.py +1757 -277
  17. wandb/sdk/artifacts/artifact_manifest_entry.py +26 -6
  18. wandb/sdk/artifacts/artifact_state.py +10 -0
  19. wandb/sdk/artifacts/artifacts_cache.py +7 -8
  20. wandb/sdk/artifacts/exceptions.py +4 -4
  21. wandb/sdk/artifacts/storage_handler.py +2 -2
  22. wandb/sdk/artifacts/storage_handlers/azure_handler.py +16 -6
  23. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +2 -2
  24. wandb/sdk/artifacts/storage_handlers/http_handler.py +2 -2
  25. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +2 -2
  26. wandb/sdk/artifacts/storage_handlers/multi_handler.py +2 -2
  27. wandb/sdk/artifacts/storage_handlers/s3_handler.py +35 -32
  28. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +2 -2
  29. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +5 -9
  30. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +2 -2
  31. wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +2 -2
  32. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +24 -16
  33. wandb/sdk/artifacts/storage_policy.py +3 -3
  34. wandb/sdk/data_types/_dtypes.py +7 -12
  35. wandb/sdk/data_types/base_types/json_metadata.py +2 -2
  36. wandb/sdk/data_types/base_types/media.py +5 -6
  37. wandb/sdk/data_types/base_types/wb_value.py +12 -13
  38. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +4 -5
  39. wandb/sdk/data_types/helper_types/classes.py +5 -8
  40. wandb/sdk/data_types/helper_types/image_mask.py +4 -5
  41. wandb/sdk/data_types/histogram.py +3 -3
  42. wandb/sdk/data_types/html.py +3 -4
  43. wandb/sdk/data_types/image.py +4 -5
  44. wandb/sdk/data_types/molecule.py +2 -2
  45. wandb/sdk/data_types/object_3d.py +3 -3
  46. wandb/sdk/data_types/plotly.py +2 -2
  47. wandb/sdk/data_types/saved_model.py +7 -8
  48. wandb/sdk/data_types/trace_tree.py +4 -4
  49. wandb/sdk/data_types/video.py +4 -4
  50. wandb/sdk/interface/interface.py +8 -10
  51. wandb/sdk/internal/file_stream.py +2 -3
  52. wandb/sdk/internal/internal_api.py +99 -4
  53. wandb/sdk/internal/job_builder.py +15 -7
  54. wandb/sdk/internal/sender.py +4 -0
  55. wandb/sdk/internal/settings_static.py +1 -0
  56. wandb/sdk/launch/_project_spec.py +9 -7
  57. wandb/sdk/launch/agent/agent.py +115 -58
  58. wandb/sdk/launch/agent/job_status_tracker.py +34 -0
  59. wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
  60. wandb/sdk/launch/builder/abstract.py +5 -1
  61. wandb/sdk/launch/builder/build.py +16 -10
  62. wandb/sdk/launch/builder/docker_builder.py +9 -2
  63. wandb/sdk/launch/builder/kaniko_builder.py +108 -22
  64. wandb/sdk/launch/builder/noop.py +3 -1
  65. wandb/sdk/launch/environment/aws_environment.py +2 -1
  66. wandb/sdk/launch/environment/azure_environment.py +124 -0
  67. wandb/sdk/launch/github_reference.py +30 -18
  68. wandb/sdk/launch/launch.py +1 -1
  69. wandb/sdk/launch/loader.py +15 -0
  70. wandb/sdk/launch/registry/azure_container_registry.py +132 -0
  71. wandb/sdk/launch/registry/elastic_container_registry.py +38 -4
  72. wandb/sdk/launch/registry/google_artifact_registry.py +46 -7
  73. wandb/sdk/launch/runner/abstract.py +19 -3
  74. wandb/sdk/launch/runner/kubernetes_runner.py +111 -47
  75. wandb/sdk/launch/runner/local_container.py +101 -48
  76. wandb/sdk/launch/runner/sagemaker_runner.py +59 -9
  77. wandb/sdk/launch/runner/vertex_runner.py +8 -4
  78. wandb/sdk/launch/sweeps/scheduler.py +102 -27
  79. wandb/sdk/launch/sweeps/utils.py +21 -0
  80. wandb/sdk/launch/utils.py +19 -7
  81. wandb/sdk/lib/_settings_toposort_generated.py +3 -0
  82. wandb/sdk/service/server.py +22 -9
  83. wandb/sdk/service/service.py +27 -8
  84. wandb/sdk/verify/verify.py +6 -9
  85. wandb/sdk/wandb_config.py +2 -4
  86. wandb/sdk/wandb_init.py +2 -0
  87. wandb/sdk/wandb_require.py +7 -0
  88. wandb/sdk/wandb_run.py +32 -35
  89. wandb/sdk/wandb_settings.py +10 -3
  90. wandb/testing/relay.py +15 -2
  91. wandb/util.py +55 -23
  92. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/METADATA +11 -8
  93. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/RECORD +97 -97
  94. wandb/integration/langchain/util.py +0 -191
  95. wandb/sdk/artifacts/invalid_artifact.py +0 -23
  96. wandb/sdk/artifacts/lazy_artifact.py +0 -162
  97. wandb/sdk/artifacts/local_artifact.py +0 -719
  98. wandb/sdk/artifacts/public_artifact.py +0 -1188
  99. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
  100. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
  101. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +0 -0
  102. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,14 @@
1
1
  import logging
2
2
  import os
3
3
  import shlex
4
- import signal
5
4
  import subprocess
6
5
  import sys
6
+ import threading
7
+ import time
7
8
  from typing import Any, Dict, List, Optional
8
9
 
9
10
  import wandb
11
+ from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker
10
12
  from wandb.sdk.launch.builder.abstract import AbstractBuilder
11
13
  from wandb.sdk.launch.environment.abstract import AbstractEnvironment
12
14
 
@@ -28,36 +30,57 @@ _logger = logging.getLogger(__name__)
28
30
  class LocalSubmittedRun(AbstractRun):
29
31
  """Instance of ``AbstractRun`` corresponding to a subprocess launched to run an entry point command locally."""
30
32
 
31
- def __init__(self, command_proc: "subprocess.Popen[bytes]") -> None:
33
+ def __init__(self) -> None:
32
34
  super().__init__()
33
- self.command_proc = command_proc
35
+ self._command_proc: Optional[subprocess.Popen] = None
36
+ self._stdout: Optional[str] = None
37
+ self._terminate_flag: bool = False
38
+ self._thread: Optional[threading.Thread] = None
39
+
40
+ def set_command_proc(self, command_proc: subprocess.Popen) -> None:
41
+ self._command_proc = command_proc
42
+
43
+ def set_thread(self, thread: threading.Thread) -> None:
44
+ self._thread = thread
34
45
 
35
46
  @property
36
- def id(self) -> str:
37
- return str(self.command_proc.pid)
47
+ def id(self) -> Optional[str]:
48
+ if self._command_proc is None:
49
+ return None
50
+ return str(self._command_proc.pid)
38
51
 
39
52
  def wait(self) -> bool:
40
- return self.command_proc.wait() == 0
53
+ assert self._thread is not None
54
+ # if command proc is not set
55
+ # wait for thread to set it
56
+ if self._command_proc is None:
57
+ while self._thread.is_alive():
58
+ time.sleep(5)
59
+ # command proc can be updated by another thread
60
+ if self._command_proc is not None:
61
+ return self._command_proc.wait() == 0 # type: ignore
62
+ return False
63
+
64
+ return self._command_proc.wait() == 0
65
+
66
+ def get_logs(self) -> Optional[str]:
67
+ return self._stdout
41
68
 
42
69
  def cancel(self) -> None:
43
- # Interrupt child process if it hasn't already exited
44
- if self.command_proc.poll() is None:
45
- # Kill the the process tree rooted at the child if it's the leader of its own process
46
- # group, otherwise just kill the child
47
- try:
48
- if self.command_proc.pid == os.getpgid(self.command_proc.pid):
49
- os.killpg(self.command_proc.pid, signal.SIGTERM)
50
- else:
51
- self.command_proc.terminate()
52
- except OSError:
53
- # The child process may have exited before we attempted to terminate it, so we
54
- # ignore OSErrors raised during child process termination
55
- _msg = f"{LOG_PREFIX}Failed to terminate child process PID {self.command_proc.pid}"
56
- _logger.debug(_msg)
57
- self.command_proc.wait()
70
+ # thread is set immediately after starting, should always exist
71
+ assert self._thread is not None
72
+
73
+ # cancel called before the thread subprocess has started
74
+ # indicates to thread to not start command proc if not already started
75
+ self._terminate_flag = True
58
76
 
59
77
  def get_status(self) -> Status:
60
- exit_code = self.command_proc.poll()
78
+ assert self._thread is not None, "Failed to get status, self._thread = None"
79
+ if self._command_proc is None:
80
+ if self._thread.is_alive():
81
+ return Status("running")
82
+ return Status("stopped")
83
+ exit_code = self._command_proc.poll()
61
84
  if exit_code is None:
62
85
  return Status("running")
63
86
  if exit_code == 0:
@@ -77,12 +100,7 @@ class LocalContainerRunner(AbstractRunner):
77
100
  super().__init__(api, backend_config)
78
101
  self.environment = environment
79
102
 
80
- def run(
81
- self,
82
- launch_project: LaunchProject,
83
- builder: Optional[AbstractBuilder],
84
- ) -> Optional[AbstractRun]:
85
- synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
103
+ def _populate_docker_args(self, launch_project: LaunchProject) -> Dict[str, Any]:
86
104
  docker_args: Dict[str, Any] = launch_project.resource_args.get(
87
105
  "local-container", {}
88
106
  )
@@ -95,6 +113,16 @@ class LocalContainerRunner(AbstractRunner):
95
113
  if sys.platform == "linux" or sys.platform == "linux2":
96
114
  docker_args["add-host"] = "host.docker.internal:host-gateway"
97
115
 
116
+ return docker_args
117
+
118
+ def run(
119
+ self,
120
+ launch_project: LaunchProject,
121
+ builder: Optional[AbstractBuilder],
122
+ job_tracker: Optional[JobAndRunStatusTracker] = None,
123
+ ) -> Optional[AbstractRun]:
124
+ docker_args = self._populate_docker_args(launch_project)
125
+ synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
98
126
  entry_point = launch_project.get_single_entry_point()
99
127
  env_vars = get_env_vars_dict(launch_project, self._api)
100
128
 
@@ -106,7 +134,7 @@ class LocalContainerRunner(AbstractRunner):
106
134
  _, _, port = self._api.settings("base_url").split(":")
107
135
  env_vars["WANDB_BASE_URL"] = f"http://host.docker.internal:{port}"
108
136
  elif _is_wandb_dev_uri(self._api.settings("base_url")):
109
- env_vars["WANDB_BASE_URL"] = "http://host.docker.internal:9002"
137
+ env_vars["WANDB_BASE_URL"] = "http://host.docker.internal:9001"
110
138
 
111
139
  if launch_project.docker_image:
112
140
  # user has provided their own docker image
@@ -128,11 +156,7 @@ class LocalContainerRunner(AbstractRunner):
128
156
  assert entry_point is not None
129
157
  _logger.info("Building docker image...")
130
158
  assert builder is not None
131
- image_uri = builder.build_image(
132
- launch_project,
133
- entry_point,
134
- )
135
-
159
+ image_uri = builder.build_image(launch_project, entry_point, job_tracker)
136
160
  _logger.info(f"Docker image built with uri {image_uri}")
137
161
  # entry_cmd and additional_args are empty here because
138
162
  # if launch built the container they've been accounted
@@ -167,20 +191,49 @@ def _run_entry_point(command: str, work_dir: Optional[str]) -> AbstractRun:
167
191
  if work_dir is None:
168
192
  work_dir = os.getcwd()
169
193
  env = os.environ.copy()
170
- if os.name == "nt":
171
- # we are running on windows
172
- process = subprocess.Popen(
173
- ["cmd", "/c", command], close_fds=True, cwd=work_dir, env=env
174
- )
175
- else:
176
- process = subprocess.Popen(
177
- ["bash", "-c", command],
178
- close_fds=True,
179
- cwd=work_dir,
180
- env=env,
181
- )
182
-
183
- return LocalSubmittedRun(process)
194
+ run = LocalSubmittedRun()
195
+ thread = threading.Thread(
196
+ target=_thread_process_runner,
197
+ args=(run, ["bash", "-c", command], work_dir, env),
198
+ )
199
+ run.set_thread(thread)
200
+ thread.start()
201
+ return run
202
+
203
+
204
+ def _thread_process_runner(
205
+ run: LocalSubmittedRun, args: List[str], work_dir: str, env: Dict[str, str]
206
+ ) -> None:
207
+ # cancel was called before we started the subprocess
208
+ if run._terminate_flag:
209
+ return
210
+ process = subprocess.Popen(
211
+ args,
212
+ close_fds=True,
213
+ stdout=subprocess.PIPE,
214
+ stderr=subprocess.STDOUT,
215
+ universal_newlines=True,
216
+ bufsize=1,
217
+ cwd=work_dir,
218
+ env=env,
219
+ )
220
+ run.set_command_proc(process)
221
+ run._stdout = ""
222
+ while True:
223
+ # the agent thread could set the terminate flag
224
+ if run._terminate_flag:
225
+ process.terminate() # type: ignore
226
+ chunk = os.read(process.stdout.fileno(), 4096) # type: ignore
227
+ if not chunk:
228
+ break
229
+ index = chunk.find(b"\r")
230
+ decoded_chunk = chunk.decode()
231
+ if index != -1:
232
+ run._stdout += decoded_chunk
233
+ print(chunk.decode(), end="")
234
+ else:
235
+ run._stdout += decoded_chunk + "\r"
236
+ print(chunk.decode(), end="\r")
184
237
 
185
238
 
186
239
  def get_docker_command(
@@ -8,6 +8,7 @@ if False:
8
8
 
9
9
  import wandb
10
10
  from wandb.apis.internal import Api
11
+ from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker
11
12
  from wandb.sdk.launch.builder.abstract import AbstractBuilder
12
13
  from wandb.sdk.launch.environment.aws_environment import AwsEnvironment
13
14
  from wandb.sdk.launch.errors import LaunchError
@@ -23,9 +24,15 @@ _logger = logging.getLogger(__name__)
23
24
  class SagemakerSubmittedRun(AbstractRun):
24
25
  """Instance of ``AbstractRun`` corresponding to a subprocess launched to run an entry point command on aws sagemaker."""
25
26
 
26
- def __init__(self, training_job_name: str, client: "boto3.Client") -> None:
27
+ def __init__(
28
+ self,
29
+ training_job_name: str,
30
+ client: "boto3.Client",
31
+ log_client: Optional["boto3.Client"] = None,
32
+ ) -> None:
27
33
  super().__init__()
28
34
  self.client = client
35
+ self.log_client = log_client
29
36
  self.training_job_name = training_job_name
30
37
  self._status = Status("running")
31
38
 
@@ -33,6 +40,38 @@ class SagemakerSubmittedRun(AbstractRun):
33
40
  def id(self) -> str:
34
41
  return f"sagemaker-{self.training_job_name}"
35
42
 
43
+ def get_logs(self) -> Optional[str]:
44
+ if self.log_client is None:
45
+ return None
46
+ try:
47
+ describe_res = self.log_client.describe_log_streams(
48
+ logGroupName="/aws/sagemaker/TrainingJobs",
49
+ logStreamNamePrefix=self.training_job_name,
50
+ )
51
+ if len(describe_res["logStreams"]) == 0:
52
+ wandb.termwarn(
53
+ f"Failed to get logs for training job: {self.training_job_name}"
54
+ )
55
+ return None
56
+ log_name = describe_res["logStreams"][0]["logStreamName"]
57
+ res = self.log_client.get_log_events(
58
+ logGroupName="/aws/sagemaker/TrainingJobs",
59
+ logStreamName=log_name,
60
+ )
61
+ return "\n".join(
62
+ [f'{event["timestamp"]}:{event["message"]}' for event in res["events"]]
63
+ )
64
+ except self.log_client.exceptions.ResourceNotFoundException:
65
+ wandb.termwarn(
66
+ f"Failed to get logs for training job: {self.training_job_name}"
67
+ )
68
+ return None
69
+ except Exception as e:
70
+ wandb.termwarn(
71
+ f"Failed to handle logs for training job: {self.training_job_name} with error {str(e)}"
72
+ )
73
+ return None
74
+
36
75
  def wait(self) -> bool:
37
76
  while True:
38
77
  status_state = self.get_status().state
@@ -89,6 +128,7 @@ class SageMakerRunner(AbstractRunner):
89
128
  self,
90
129
  launch_project: LaunchProject,
91
130
  builder: Optional[AbstractBuilder],
131
+ job_tracker: Optional[JobAndRunStatusTracker] = None,
92
132
  ) -> Optional[AbstractRun]:
93
133
  """Run a project on Amazon Sagemaker.
94
134
 
@@ -128,6 +168,13 @@ class SageMakerRunner(AbstractRunner):
128
168
 
129
169
  # Create a sagemaker client to launch the job.
130
170
  sagemaker_client = session.client("sagemaker")
171
+ log_client = None
172
+ try:
173
+ log_client = session.client("logs")
174
+ except Exception as e:
175
+ wandb.termwarn(
176
+ f"Failed to connect to cloudwatch logs with error {str(e)}, logs will not be available"
177
+ )
131
178
 
132
179
  # if the user provided the image they want to use, use that, but warn it won't have swappable artifacts
133
180
  if (
@@ -146,7 +193,9 @@ class SageMakerRunner(AbstractRunner):
146
193
  _logger.info(
147
194
  f"Launching sagemaker job on user supplied image with args: {sagemaker_args}"
148
195
  )
149
- run = launch_sagemaker_job(launch_project, sagemaker_args, sagemaker_client)
196
+ run = launch_sagemaker_job(
197
+ launch_project, sagemaker_args, sagemaker_client, log_client
198
+ )
150
199
  if self.backend_config[PROJECT_SYNCHRONOUS]:
151
200
  run.wait()
152
201
  return run
@@ -158,11 +207,9 @@ class SageMakerRunner(AbstractRunner):
158
207
  assert builder is not None
159
208
  # build our own image
160
209
  _logger.info("Building docker image...")
161
- image = builder.build_image(
162
- launch_project,
163
- entry_point,
164
- )
210
+ image = builder.build_image(launch_project, entry_point, job_tracker)
165
211
  _logger.info(f"Docker image built with uri {image}")
212
+
166
213
  launch_project.fill_macros(image)
167
214
  _logger.info("Connecting to sagemaker client")
168
215
  command_args = get_entry_point_command(
@@ -181,7 +228,9 @@ class SageMakerRunner(AbstractRunner):
181
228
  launch_project, self._api, role_arn, image, default_output_path
182
229
  )
183
230
  _logger.info(f"Launching sagemaker job with args: {sagemaker_args}")
184
- run = launch_sagemaker_job(launch_project, sagemaker_args, sagemaker_client)
231
+ run = launch_sagemaker_job(
232
+ launch_project, sagemaker_args, sagemaker_client, log_client
233
+ )
185
234
  if self.backend_config[PROJECT_SYNCHRONOUS]:
186
235
  run.wait()
187
236
  return run
@@ -296,14 +345,15 @@ def launch_sagemaker_job(
296
345
  launch_project: LaunchProject,
297
346
  sagemaker_args: Dict[str, Any],
298
347
  sagemaker_client: "boto3.Client",
348
+ log_client: Optional["boto3.Client"] = None,
299
349
  ) -> SagemakerSubmittedRun:
300
350
  training_job_name = sagemaker_args.get("TrainingJobName") or launch_project.run_id
301
351
  resp = sagemaker_client.create_training_job(**sagemaker_args)
302
352
 
303
353
  if resp.get("TrainingJobArn") is None:
304
- raise LaunchError("Unable to create training job")
354
+ raise LaunchError("Failed to create training job when submitting to SageMaker")
305
355
 
306
- run = SagemakerSubmittedRun(training_job_name, sagemaker_client)
356
+ run = SagemakerSubmittedRun(training_job_name, sagemaker_client, log_client)
307
357
  wandb.termlog(
308
358
  f"{LOG_PREFIX}Run job submitted with arn: {resp.get('TrainingJobArn')}"
309
359
  )
@@ -14,6 +14,7 @@ from wandb.apis.internal import Api
14
14
  from wandb.util import get_module
15
15
 
16
16
  from .._project_spec import LaunchProject, get_entry_point_command
17
+ from ..agent.job_status_tracker import JobAndRunStatusTracker
17
18
  from ..builder.abstract import AbstractBuilder
18
19
  from ..builder.build import get_env_vars_dict
19
20
  from ..environment.gcp_environment import GcpEnvironment
@@ -35,6 +36,10 @@ class VertexSubmittedRun(AbstractRun):
35
36
  # numeric ID of the custom training job
36
37
  return self._job.name # type: ignore
37
38
 
39
+ def get_logs(self) -> Optional[str]:
40
+ # TODO: implement
41
+ return None
42
+
38
43
  @property
39
44
  def name(self) -> str:
40
45
  return self._job.display_name # type: ignore
@@ -89,6 +94,7 @@ class VertexRunner(AbstractRunner):
89
94
  self,
90
95
  launch_project: LaunchProject,
91
96
  builder: Optional[AbstractBuilder],
97
+ job_tracker: Optional[JobAndRunStatusTracker] = None,
92
98
  ) -> Optional[AbstractRun]:
93
99
  """Run a Vertex job."""
94
100
  aiplatform = get_module( # noqa: F811
@@ -134,10 +140,8 @@ class VertexRunner(AbstractRunner):
134
140
  else:
135
141
  assert entry_point is not None
136
142
  assert builder is not None
137
- image_uri = builder.build_image(
138
- launch_project,
139
- entry_point,
140
- )
143
+ image_uri = builder.build_image(launch_project, entry_point, job_tracker)
144
+
141
145
  launch_project.fill_macros(image_uri)
142
146
  # TODO: how to handle this?
143
147
  entry_cmd = get_entry_point_command(entry_point, launch_project.override_args)
@@ -130,10 +130,10 @@ class Scheduler(ABC):
130
130
  if resp.get("state") == SchedulerState.CANCELLED.name:
131
131
  self._state = SchedulerState.CANCELLED
132
132
  self._sweep_config = yaml.safe_load(resp["config"])
133
- self._num_runs_launched: int = len(resp["runs"])
133
+ self._num_runs_launched: int = self._get_num_runs_launched(resp["runs"])
134
134
  if self._num_runs_launched > 0:
135
135
  wandb.termlog(
136
- f"{LOG_PREFIX}Found {self._num_runs_launched} previous runs for sweep {self._sweep_id}"
136
+ f"{LOG_PREFIX}Found {self._num_runs_launched} previous valid runs for sweep {self._sweep_id}"
137
137
  )
138
138
  except Exception as e:
139
139
  raise SchedulerError(
@@ -295,10 +295,12 @@ class Scheduler(ABC):
295
295
  self.state = SchedulerState.RUNNING
296
296
  try:
297
297
  while True:
298
- wandb.termlog(f"{LOG_PREFIX}Polling for new runs to launch")
298
+ self._update_scheduler_run_state()
299
299
  if not self.is_alive:
300
300
  break
301
301
 
302
+ wandb.termlog(f"{LOG_PREFIX}Polling for new runs to launch")
303
+
302
304
  self._update_run_states()
303
305
  self._poll()
304
306
  if self.state == SchedulerState.FLUSH_RUNS:
@@ -316,8 +318,17 @@ class Scheduler(ABC):
316
318
  self.state = SchedulerState.FLUSH_RUNS
317
319
  break
318
320
 
319
- run: Optional[SweepRun] = self._get_next_sweep_run(worker_id)
320
- if not run:
321
+ try:
322
+ run: Optional[SweepRun] = self._get_next_sweep_run(worker_id)
323
+ if not run:
324
+ break
325
+ except SchedulerError as e:
326
+ raise SchedulerError(e)
327
+ except Exception as e:
328
+ wandb.termerror(
329
+ f"{LOG_PREFIX}Failed to get next sweep run: {e}"
330
+ )
331
+ self.state = SchedulerState.FAILED
321
332
  break
322
333
 
323
334
  if self._add_to_launch_queue(run):
@@ -356,10 +367,29 @@ class Scheduler(ABC):
356
367
  SchedulerState.STOPPED,
357
368
  ]:
358
369
  self.state = SchedulerState.FAILED
370
+ self._set_sweep_state("CRASHED")
371
+ else:
372
+ self._set_sweep_state("FINISHED")
373
+
359
374
  self._stop_runs()
360
- self._set_sweep_state("FINISHED")
361
375
  self._wandb_run.finish()
362
376
 
377
+ def _get_num_runs_launched(self, runs: List[Dict[str, Any]]) -> int:
378
+ """Returns the number of valid runs in the sweep."""
379
+ count = 0
380
+ for run in runs:
381
+ # if bad run, shouldn't be counted against run cap
382
+ if run.get("state", "") in ["killed", "crashed"] and not run.get(
383
+ "summaryMetrics"
384
+ ):
385
+ _logger.debug(
386
+ f"excluding run: {run['name']} with state: {run['state']} from run cap \n{run}"
387
+ )
388
+ continue
389
+ count += 1
390
+
391
+ return count
392
+
363
393
  def _try_load_executable(self) -> bool:
364
394
  """Check existance of valid executable for a run.
365
395
 
@@ -384,12 +414,17 @@ class Scheduler(ABC):
384
414
  def _register_agents(self) -> None:
385
415
  for worker_id in range(self._num_workers):
386
416
  _logger.debug(f"{LOG_PREFIX}Starting AgentHeartbeat worker ({worker_id})")
387
- agent_config = self._api.register_agent(
388
- f"{socket.gethostname()}-{worker_id}", # host
389
- sweep_id=self._sweep_id,
390
- project_name=self._project,
391
- entity=self._entity,
392
- )
417
+ try:
418
+ agent_config = self._api.register_agent(
419
+ f"{socket.gethostname()}-{worker_id}", # host
420
+ sweep_id=self._sweep_id,
421
+ project_name=self._project,
422
+ entity=self._entity,
423
+ )
424
+ except Exception as e:
425
+ _logger.debug(f"failed to register agent: {e}")
426
+ self.fail_sweep(f"failed to register agent: {e}")
427
+
393
428
  self._workers[worker_id] = _Worker(
394
429
  agent_config=agent_config,
395
430
  agent_id=agent_config["id"],
@@ -455,6 +490,30 @@ class Scheduler(ABC):
455
490
 
456
491
  return False
457
492
 
493
+ def _update_scheduler_run_state(self) -> None:
494
+ """Update the scheduler state from state of scheduler run and sweep state."""
495
+ state: RunState = self._get_run_state(self._wandb_run.id)
496
+
497
+ if state == RunState.KILLED:
498
+ self.state = SchedulerState.STOPPED
499
+ elif state in [RunState.FAILED, RunState.CRASHED]:
500
+ self.state = SchedulerState.FAILED
501
+ elif state == RunState.FINISHED:
502
+ self.state = SchedulerState.COMPLETED
503
+
504
+ try:
505
+ sweep_state = self._api.get_sweep_state(
506
+ self._sweep_id, self._entity, self._project
507
+ )
508
+ except Exception as e:
509
+ _logger.debug(f"sweep state error: {sweep_state} e: {e}")
510
+ return
511
+
512
+ if sweep_state in ["FINISHED", "CANCELLED"]:
513
+ self.state = SchedulerState.COMPLETED
514
+ elif sweep_state in ["PAUSED", "STOPPED"]:
515
+ self.state = SchedulerState.FLUSH_RUNS
516
+
458
517
  def _update_run_states(self) -> None:
459
518
  """Iterate through runs.
460
519
 
@@ -530,7 +589,7 @@ class Scheduler(ABC):
530
589
  run_state = RunState.UNKNOWN
531
590
  except (AttributeError, ValueError):
532
591
  wandb.termwarn(
533
- f"Bad state ({state}) for run ({run_id}). Error: {traceback.format_exc()}"
592
+ f"Bad state ({run_state}) for run ({run_id}). Error: {traceback.format_exc()}"
534
593
  )
535
594
  run_state = RunState.UNKNOWN
536
595
  return run_state
@@ -564,6 +623,35 @@ class Scheduler(ABC):
564
623
  base64.b64decode(bytes(_id.encode("utf-8"))).decode("utf-8").split(":")[2]
565
624
  )
566
625
 
626
+ def _make_entry_and_launch_config(
627
+ self, run: SweepRun
628
+ ) -> Tuple[Optional[List[str]], Dict[str, Dict[str, Any]]]:
629
+ args = create_sweep_command_args({"args": run.args})
630
+ entry_point, macro_args = make_launch_sweep_entrypoint(
631
+ args, self._sweep_config.get("command")
632
+ )
633
+ # handle program macro
634
+ if entry_point and "${program}" in entry_point:
635
+ if not self._sweep_config.get("program"):
636
+ raise SchedulerError(
637
+ f"{LOG_PREFIX}Program macro in command has no corresponding 'program' in sweep config."
638
+ )
639
+ pidx = entry_point.index("${program}")
640
+ entry_point[pidx] = self._sweep_config["program"]
641
+
642
+ launch_config = {"overrides": {"run_config": args["args_dict"]}}
643
+ if macro_args: # pipe in hyperparam args as params to launch
644
+ launch_config["overrides"]["args"] = macro_args
645
+
646
+ if entry_point:
647
+ unresolved = [x for x in entry_point if str(x).startswith("${")]
648
+ if unresolved:
649
+ wandb.termwarn(
650
+ f"{LOG_PREFIX}Sweep command contains unresolved macros: "
651
+ f"{unresolved}, see launch docs for supported macros."
652
+ )
653
+ return entry_point, launch_config
654
+
567
655
  def _add_to_launch_queue(self, run: SweepRun) -> bool:
568
656
  """Convert a sweeprun into a launch job then push to runqueue."""
569
657
  # job and image first from CLI args, then from sweep config
@@ -575,25 +663,12 @@ class Scheduler(ABC):
575
663
  elif _job is not None and _image_uri is not None:
576
664
  raise SchedulerError(f"{LOG_PREFIX}Sweep has both 'job' and 'image_uri'")
577
665
 
578
- args = create_sweep_command_args({"args": run.args})
579
- entry_point, macro_args = make_launch_sweep_entrypoint(
580
- args, self._sweep_config.get("command")
581
- )
582
- launch_config = {"overrides": {"run_config": args["args_dict"]}}
583
- if macro_args: # pipe in hyperparam args as params to launch
584
- launch_config["overrides"]["args"] = macro_args
585
-
666
+ entry_point, launch_config = self._make_entry_and_launch_config(run)
586
667
  if entry_point:
587
668
  wandb.termwarn(
588
669
  f"{LOG_PREFIX}Sweep command {entry_point} will override"
589
670
  f' {"job" if _job else "image_uri"} entrypoint'
590
671
  )
591
- unresolved = [x for x in entry_point if str(x).startswith("${")]
592
- if unresolved:
593
- wandb.termwarn(
594
- f"{LOG_PREFIX}Sweep command contains unresolved macros: "
595
- f"{unresolved}, see launch docs for supported macros."
596
- )
597
672
 
598
673
  run_id = run.id or generate_id()
599
674
  queued_run = launch_add(
@@ -291,3 +291,24 @@ def check_job_exists(public_api: PublicApi, job: Optional[str]) -> bool:
291
291
  wandb.termerror(f"Failed to load job. {e}")
292
292
  return False
293
293
  return True
294
+
295
+
296
+ def get_previous_args(
297
+ run_spec: Dict[str, Any]
298
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
299
+ """Parse through previous scheduler run_spec.
300
+
301
+ returns scheduler_args and settings.
302
+ """
303
+ scheduler_args = (
304
+ run_spec.get("overrides", {}).get("run_config", {}).get("scheduler", {})
305
+ )
306
+ # also pipe through top level resource setup
307
+ if run_spec.get("resource"):
308
+ scheduler_args["resource"] = run_spec["resource"]
309
+ if run_spec.get("resource_args"):
310
+ scheduler_args["resource_args"] = run_spec["resource_args"]
311
+
312
+ settings = run_spec.get("overrides", {}).get("run_config", {}).get("settings", {})
313
+
314
+ return scheduler_args, settings
wandb/sdk/launch/utils.py CHANGED
@@ -28,7 +28,8 @@ FAILED_PACKAGES_REGEX = re.compile(
28
28
  )
29
29
 
30
30
  if TYPE_CHECKING: # pragma: no cover
31
- from wandb.sdk.artifacts.public_artifact import Artifact as PublicArtifact
31
+ from wandb.sdk.artifacts.artifact import Artifact
32
+ from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker
32
33
 
33
34
 
34
35
  # TODO: this should be restricted to just Git repos and not S3 and stuff like that
@@ -47,7 +48,7 @@ _WANDB_LOCAL_DEV_URI_REGEX = re.compile(
47
48
  r"^https?://localhost"
48
49
  ) # for testing, not sure if we wanna keep this
49
50
 
50
- API_KEY_REGEX = r"WANDB_API_KEY=\w+"
51
+ API_KEY_REGEX = r"WANDB_API_KEY=\w+(-\w+)?"
51
52
 
52
53
  MACRO_REGEX = re.compile(r"\$\{(\w+)\}")
53
54
 
@@ -493,7 +494,7 @@ def convert_jupyter_notebook_to_script(fname: str, project_dir: str) -> str:
493
494
 
494
495
  def check_and_download_code_artifacts(
495
496
  entity: str, project: str, run_name: str, internal_api: Api, project_dir: str
496
- ) -> Optional["PublicArtifact"]:
497
+ ) -> Optional["Artifact"]:
497
498
  _logger.info("Checking for code artifacts")
498
499
  public_api = wandb.PublicApi(
499
500
  overrides={"base_url": internal_api.settings("base_url")}
@@ -620,12 +621,23 @@ def make_name_dns_safe(name: str) -> str:
620
621
  return resp
621
622
 
622
623
 
623
- def warn_failed_packages_from_build_logs(log: str, image_uri: str) -> None:
624
+ def warn_failed_packages_from_build_logs(
625
+ log: str, image_uri: str, api: Api, job_tracker: Optional["JobAndRunStatusTracker"]
626
+ ) -> None:
624
627
  match = FAILED_PACKAGES_REGEX.search(log)
625
628
  if match:
626
- wandb.termwarn(
627
- f"Failed to install the following packages: {match.group(1)} for image: {image_uri}. Will attempt to launch image without them."
628
- )
629
+ _msg = f"Failed to install the following packages: {match.group(1)} for image: {image_uri}. Will attempt to launch image without them."
630
+ wandb.termwarn(_msg)
631
+ if job_tracker is not None:
632
+ res = job_tracker.saver.save_contents(
633
+ _msg, "failed-packages.log", "warning"
634
+ )
635
+ api.update_run_queue_item_warning(
636
+ job_tracker.run_queue_item_id,
637
+ "Some packages were not successfully installed during the build",
638
+ "build",
639
+ res,
640
+ )
629
641
 
630
642
 
631
643
  def docker_image_exists(docker_image: str, should_raise: bool = False) -> bool: