wandb 0.15.9__py3-none-any.whl → 0.15.11__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (114) hide show
  1. wandb/__init__.py +5 -1
  2. wandb/apis/public.py +137 -17
  3. wandb/apis/reports/_panels.py +1 -1
  4. wandb/apis/reports/blocks.py +1 -0
  5. wandb/apis/reports/report.py +27 -5
  6. wandb/cli/cli.py +52 -41
  7. wandb/docker/__init__.py +17 -0
  8. wandb/docker/auth.py +1 -1
  9. wandb/env.py +24 -4
  10. wandb/filesync/step_checksum.py +3 -3
  11. wandb/integration/openai/openai.py +3 -0
  12. wandb/integration/ultralytics/__init__.py +9 -0
  13. wandb/integration/ultralytics/bbox_utils.py +196 -0
  14. wandb/integration/ultralytics/callback.py +458 -0
  15. wandb/integration/ultralytics/classification_utils.py +66 -0
  16. wandb/integration/ultralytics/mask_utils.py +141 -0
  17. wandb/integration/ultralytics/pose_utils.py +92 -0
  18. wandb/integration/xgboost/xgboost.py +3 -3
  19. wandb/integration/yolov8/__init__.py +0 -7
  20. wandb/integration/yolov8/yolov8.py +22 -3
  21. wandb/old/settings.py +7 -0
  22. wandb/plot/line_series.py +0 -1
  23. wandb/proto/v3/wandb_internal_pb2.py +353 -300
  24. wandb/proto/v3/wandb_server_pb2.py +37 -41
  25. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  26. wandb/proto/v3/wandb_telemetry_pb2.py +16 -16
  27. wandb/proto/v4/wandb_internal_pb2.py +272 -260
  28. wandb/proto/v4/wandb_server_pb2.py +37 -40
  29. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  30. wandb/proto/v4/wandb_telemetry_pb2.py +16 -16
  31. wandb/proto/wandb_internal_codegen.py +7 -31
  32. wandb/sdk/artifacts/artifact.py +321 -189
  33. wandb/sdk/artifacts/artifact_cache.py +14 -0
  34. wandb/sdk/artifacts/artifact_manifest.py +5 -4
  35. wandb/sdk/artifacts/artifact_manifest_entry.py +37 -9
  36. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -9
  37. wandb/sdk/artifacts/artifact_saver.py +13 -50
  38. wandb/sdk/artifacts/artifact_ttl.py +6 -0
  39. wandb/sdk/artifacts/artifacts_cache.py +119 -93
  40. wandb/sdk/artifacts/staging.py +25 -0
  41. wandb/sdk/artifacts/storage_handlers/s3_handler.py +12 -7
  42. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +2 -3
  43. wandb/sdk/artifacts/storage_policies/__init__.py +4 -0
  44. wandb/sdk/artifacts/storage_policies/register.py +1 -0
  45. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +4 -3
  46. wandb/sdk/artifacts/storage_policy.py +4 -2
  47. wandb/sdk/backend/backend.py +0 -16
  48. wandb/sdk/data_types/image.py +3 -1
  49. wandb/sdk/integration_utils/auto_logging.py +38 -13
  50. wandb/sdk/interface/interface.py +16 -135
  51. wandb/sdk/interface/interface_shared.py +9 -147
  52. wandb/sdk/interface/interface_sock.py +0 -26
  53. wandb/sdk/internal/file_pusher.py +20 -3
  54. wandb/sdk/internal/file_stream.py +3 -1
  55. wandb/sdk/internal/handler.py +53 -70
  56. wandb/sdk/internal/internal_api.py +220 -130
  57. wandb/sdk/internal/job_builder.py +41 -37
  58. wandb/sdk/internal/sender.py +7 -25
  59. wandb/sdk/internal/system/assets/disk.py +144 -11
  60. wandb/sdk/internal/system/system_info.py +6 -2
  61. wandb/sdk/launch/__init__.py +5 -0
  62. wandb/sdk/launch/{launch.py → _launch.py} +53 -54
  63. wandb/sdk/launch/{launch_add.py → _launch_add.py} +34 -31
  64. wandb/sdk/launch/_project_spec.py +13 -2
  65. wandb/sdk/launch/agent/agent.py +103 -59
  66. wandb/sdk/launch/agent/run_queue_item_file_saver.py +6 -4
  67. wandb/sdk/launch/builder/build.py +19 -1
  68. wandb/sdk/launch/builder/docker_builder.py +5 -1
  69. wandb/sdk/launch/builder/kaniko_builder.py +5 -1
  70. wandb/sdk/launch/create_job.py +20 -5
  71. wandb/sdk/launch/loader.py +14 -5
  72. wandb/sdk/launch/runner/abstract.py +0 -2
  73. wandb/sdk/launch/runner/kubernetes_monitor.py +329 -0
  74. wandb/sdk/launch/runner/kubernetes_runner.py +66 -209
  75. wandb/sdk/launch/runner/local_container.py +5 -2
  76. wandb/sdk/launch/runner/local_process.py +4 -1
  77. wandb/sdk/launch/sweeps/scheduler.py +43 -25
  78. wandb/sdk/launch/sweeps/utils.py +5 -3
  79. wandb/sdk/launch/utils.py +3 -1
  80. wandb/sdk/lib/_settings_toposort_generate.py +3 -9
  81. wandb/sdk/lib/_settings_toposort_generated.py +27 -3
  82. wandb/sdk/lib/_wburls_generated.py +1 -0
  83. wandb/sdk/lib/filenames.py +27 -6
  84. wandb/sdk/lib/filesystem.py +181 -7
  85. wandb/sdk/lib/fsm.py +5 -3
  86. wandb/sdk/lib/gql_request.py +3 -0
  87. wandb/sdk/lib/ipython.py +7 -0
  88. wandb/sdk/lib/wburls.py +1 -0
  89. wandb/sdk/service/port_file.py +2 -15
  90. wandb/sdk/service/server.py +7 -55
  91. wandb/sdk/service/service.py +56 -26
  92. wandb/sdk/service/service_base.py +1 -1
  93. wandb/sdk/service/streams.py +11 -5
  94. wandb/sdk/verify/verify.py +2 -2
  95. wandb/sdk/wandb_init.py +8 -2
  96. wandb/sdk/wandb_manager.py +4 -14
  97. wandb/sdk/wandb_run.py +143 -53
  98. wandb/sdk/wandb_settings.py +148 -35
  99. wandb/testing/relay.py +85 -38
  100. wandb/util.py +87 -4
  101. wandb/wandb_torch.py +24 -38
  102. {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/METADATA +48 -23
  103. {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/RECORD +107 -103
  104. {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/WHEEL +1 -1
  105. wandb/proto/v3/wandb_server_pb2_grpc.py +0 -1422
  106. wandb/proto/v4/wandb_server_pb2_grpc.py +0 -1422
  107. wandb/proto/wandb_server_pb2_grpc.py +0 -8
  108. wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +0 -61
  109. wandb/sdk/interface/interface_grpc.py +0 -460
  110. wandb/sdk/service/server_grpc.py +0 -444
  111. wandb/sdk/service/service_grpc.py +0 -73
  112. {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/LICENSE +0 -0
  113. {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/entry_points.txt +0 -0
  114. {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/top_level.txt +0 -0
@@ -6,13 +6,15 @@ import logging
6
6
  import time
7
7
  from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
8
8
 
9
+ import yaml
10
+
9
11
  import wandb
10
12
  from wandb.apis.internal import Api
11
13
  from wandb.sdk.launch.environment.abstract import AbstractEnvironment
12
14
  from wandb.sdk.launch.registry.abstract import AbstractRegistry
13
15
  from wandb.sdk.launch.registry.azure_container_registry import AzureContainerRegistry
14
16
  from wandb.sdk.launch.registry.local_registry import LocalRegistry
15
- from wandb.sdk.launch.runner.abstract import State, Status
17
+ from wandb.sdk.launch.runner.abstract import Status
16
18
  from wandb.util import get_module
17
19
 
18
20
  from .._project_spec import EntryPoint, LaunchProject
@@ -26,6 +28,7 @@ from ..utils import (
26
28
  make_name_dns_safe,
27
29
  )
28
30
  from .abstract import AbstractRun, AbstractRunner
31
+ from .kubernetes_monitor import KubernetesRunMonitor
29
32
 
30
33
  get_module(
31
34
  "kubernetes",
@@ -43,32 +46,16 @@ from kubernetes.client.models.v1_secret import V1Secret # type: ignore # noqa:
43
46
  from kubernetes.client.rest import ApiException # type: ignore # noqa: E402
44
47
 
45
48
  TIMEOUT = 5
46
- MAX_KUBERNETES_RETRIES = (
47
- 60 # default 10 second loop time on the agent, this is 10 minutes
48
- )
49
- FAIL_MESSAGE_INTERVAL = 60
50
49
 
51
50
  _logger = logging.getLogger(__name__)
52
51
 
53
52
 
54
- # Dict for mapping possible states of custom objects to the states we want to report
55
- # to the agent.
56
- CRD_STATE_DICT: Dict[str, State] = {
57
- "pending": "starting",
58
- "running": "running",
59
- "completed": "finished",
60
- "failed": "failed",
61
- "aborted": "failed",
62
- "terminating": "stopping",
63
- "terminated": "stopped",
64
- }
65
-
66
-
67
53
  class KubernetesSubmittedRun(AbstractRun):
68
54
  """Wrapper for a launched run on Kubernetes."""
69
55
 
70
56
  def __init__(
71
57
  self,
58
+ monitor: KubernetesRunMonitor,
72
59
  batch_api: "BatchV1Api",
73
60
  core_api: "CoreV1Api",
74
61
  name: str,
@@ -78,6 +65,14 @@ class KubernetesSubmittedRun(AbstractRun):
78
65
  ) -> None:
79
66
  """Initialize a KubernetesSubmittedRun.
80
67
 
68
+ Other implementations of the AbstractRun interface poll on the run
69
+ when `get_status` is called, but KubernetesSubmittedRun uses
70
+ Kubernetes watch streams to update the run status. One thread handles
71
+ events from the job object and another thread handles events from the
72
+ rank 0 pod. These threads updated the `_status` attributed of the
73
+ KubernetesSubmittedRun object. When `get_status` is called, the
74
+ `_status` attribute is returned.
75
+
81
76
  Arguments:
82
77
  batch_api: Kubernetes BatchV1Api object.
83
78
  core_api: Kubernetes CoreV1Api object.
@@ -89,13 +84,11 @@ class KubernetesSubmittedRun(AbstractRun):
89
84
  Returns:
90
85
  None.
91
86
  """
87
+ self.monitor = monitor
92
88
  self.batch_api = batch_api
93
89
  self.core_api = core_api
94
90
  self.name = name
95
91
  self.namespace = namespace
96
- self.job = self.batch_api.read_namespaced_job(
97
- name=self.name, namespace=self.namespace
98
- )
99
92
  self._fail_count = 0
100
93
  self.pod_names = pod_names
101
94
  self.secret = secret
@@ -136,7 +129,7 @@ class KubernetesSubmittedRun(AbstractRun):
136
129
  while True:
137
130
  status = self.get_status()
138
131
  wandb.termlog(f"{LOG_PREFIX}Job {self.name} status: {status}")
139
- if status.state != "running":
132
+ if status.state in ["finished", "failed", "preempted"]:
140
133
  break
141
134
  time.sleep(5)
142
135
  return (
@@ -156,98 +149,20 @@ class KubernetesSubmittedRun(AbstractRun):
156
149
  )
157
150
 
158
151
  def get_status(self) -> Status:
159
- """Return the run status."""
160
- try:
161
- job_response = self.batch_api.read_namespaced_job_status(
162
- name=self.name, namespace=self.namespace
163
- )
164
- except ApiException as e:
165
- if e.status == 404:
166
- wandb.termerror(
167
- f"Could not reach job {self.name} in namespace {self.namespace}"
168
- )
169
- self._delete_secret_if_completed("failed")
170
- return Status("failed")
171
-
172
- status = job_response.status
152
+ return self.monitor.get_status()
173
153
 
154
+ def cancel(self) -> None:
155
+ """Cancel the run."""
156
+ self.monitor.stop()
174
157
  try:
175
- pod = self.core_api.read_namespaced_pod(
176
- name=self.pod_names[0], namespace=self.namespace
158
+ self.batch_api.delete_namespaced_job(
159
+ namespace=self.namespace,
160
+ name=self.name,
177
161
  )
178
162
  except ApiException as e:
179
- if e.status == 404:
180
- wandb.termerror(
181
- f"Could not reach pod {self.pod_names[0]} in namespace {self.namespace}"
182
- )
183
- self._delete_secret_if_completed("failed")
184
- return Status("failed")
185
-
186
- if hasattr(pod.status, "conditions") and pod.status.conditions is not None:
187
- for condition in pod.status.conditions:
188
- if condition.type == "DisruptionTarget" and condition.reason in [
189
- "EvictionByEvictionAPI",
190
- "PreemptionByScheduler",
191
- "TerminationByKubelet",
192
- ]:
193
- return Status("preempted")
194
- if pod.status.phase in ["Pending", "Unknown"]:
195
- now = time.time()
196
- if self._fail_count == 0:
197
- self._fail_first_msg_time = now
198
- self._fail_last_msg_time = 0.0
199
- self._fail_count += 1
200
- if now - self._fail_last_msg_time > FAIL_MESSAGE_INTERVAL:
201
- wandb.termlog(
202
- f"{LOG_PREFIX}Pod has not started yet for job: {self.name}. Will wait up to {round(10 - (now - self._fail_first_msg_time)/60)} minutes."
203
- )
204
- self._fail_last_msg_time = now
205
- if self._fail_count > MAX_KUBERNETES_RETRIES:
206
- raise LaunchError(f"Failed to start job {self.name}")
207
- # todo: we only handle the 1 pod case. see https://kubernetes.io/docs/concepts/workloads/controllers/job/#parallel-jobs for multipod handling
208
- return_status = None
209
- if status.succeeded == 1:
210
- return_status = Status("finished")
211
- elif status.failed is not None and status.failed >= 1:
212
- return_status = Status("failed")
213
- elif status.active == 1:
214
- return Status("running")
215
- elif status.conditions is not None and status.conditions[0].type == "Suspended":
216
- return_status = Status("stopped")
217
- else:
218
- return_status = Status("unknown")
219
-
220
- self._delete_secret_if_completed(return_status.state)
221
- return return_status
222
-
223
- def suspend(self) -> None:
224
- """Suspend the run."""
225
- self.job.spec.suspend = True
226
- self.batch_api.patch_namespaced_job(
227
- name=self.name, namespace=self.namespace, body=self.job
228
- )
229
- timeout = TIMEOUT
230
- job_response = self.batch_api.read_namespaced_job_status(
231
- name=self.name, namespace=self.namespace
232
- )
233
- while job_response.status.conditions is None and timeout > 0:
234
- time.sleep(1)
235
- timeout -= 1
236
- job_response = self.batch_api.read_namespaced_job_status(
237
- name=self.name, namespace=self.namespace
238
- )
239
-
240
- if timeout == 0 or job_response.status.conditions[0].type != "Suspended":
241
163
  raise LaunchError(
242
- "Failed to suspend job {}. Check Kubernetes dashboard for more info.".format(
243
- self.name
244
- )
245
- )
246
-
247
- def cancel(self) -> None:
248
- """Cancel the run."""
249
- self.suspend()
250
- self.batch_api.delete_namespaced_job(name=self.name, namespace=self.namespace)
164
+ f"Failed to delete Kubernetes Job {self.name} in namespace {self.namespace}: {str(e)}"
165
+ ) from e
251
166
 
252
167
 
253
168
  class CrdSubmittedRun(AbstractRun):
@@ -262,7 +177,7 @@ class CrdSubmittedRun(AbstractRun):
262
177
  namespace: str,
263
178
  core_api: CoreV1Api,
264
179
  custom_api: CustomObjectsApi,
265
- pod_names: List[str],
180
+ monitor: KubernetesRunMonitor,
266
181
  ) -> None:
267
182
  """Create a run object for tracking the progress of a CRD.
268
183
 
@@ -274,7 +189,7 @@ class CrdSubmittedRun(AbstractRun):
274
189
  namespace: The namespace of the CRD instance.
275
190
  core_api: The Kubernetes core API client.
276
191
  custom_api: The Kubernetes custom object API client.
277
- pod_names: The names of the pods associated with the CRD instance.
192
+ monitor: The run monitor.
278
193
 
279
194
  Raises:
280
195
  LaunchError: If the CRD instance does not exist.
@@ -286,20 +201,8 @@ class CrdSubmittedRun(AbstractRun):
286
201
  self.namespace = namespace
287
202
  self.core_api = core_api
288
203
  self.custom_api = custom_api
289
- self.pod_names = pod_names
290
204
  self._fail_count = 0
291
- try:
292
- self.job = self.custom_api.get_namespaced_custom_object(
293
- group=self.group,
294
- version=self.version,
295
- namespace=self.namespace,
296
- plural=self.plural,
297
- name=self.name,
298
- )
299
- except ApiException as e:
300
- raise LaunchError(
301
- f"Failed to get CRD {self.name} in namespace {self.namespace}: {str(e)}"
302
- ) from e
205
+ self.monitor = monitor
303
206
 
304
207
  @property
305
208
  def id(self) -> str:
@@ -311,7 +214,11 @@ class CrdSubmittedRun(AbstractRun):
311
214
  # TODO: test more carefully once we release multi-node support
312
215
  logs: Dict[str, Optional[str]] = {}
313
216
  try:
314
- for pod_name in self.pod_names:
217
+ pods = self.core_api.list_namespaced_pod(
218
+ label_selector=f"wandb/run-id={self.name}", namespace=self.namespace
219
+ )
220
+ pod_names = [pi.metadata.name for pi in pods.items]
221
+ for pod_name in pod_names:
315
222
  logs[pod_name] = self.core_api.read_namespaced_pod_log(
316
223
  name=pod_name, namespace=self.namespace
317
224
  )
@@ -325,30 +232,7 @@ class CrdSubmittedRun(AbstractRun):
325
232
 
326
233
  def get_status(self) -> Status:
327
234
  """Get status of custom object."""
328
- try:
329
- job_response = self.custom_api.get_namespaced_custom_object_status(
330
- group=self.group,
331
- version=self.version,
332
- namespace=self.namespace,
333
- plural=self.plural,
334
- name=self.name,
335
- )
336
- except ApiException as e:
337
- raise LaunchError(
338
- f"Failed to get CRD {self.name} in namespace {self.namespace}: {str(e)}"
339
- ) from e
340
- # Custom objects can technically define whater states and format the
341
- # response to the status request however they want. This checks for
342
- # the most common cases.
343
- status = job_response["status"]
344
- state = status.get("state")
345
- if isinstance(state, dict):
346
- state = state.get("phase")
347
- if state is None:
348
- raise LaunchError(
349
- f"Failed to get CRD {self.name} in namespace {self.namespace}: no state found"
350
- )
351
- return Status(CRD_STATE_DICT.get(state.lower(), "unknown"))
235
+ return self.monitor.get_status()
352
236
 
353
237
  def cancel(self) -> None:
354
238
  """Cancel the custom object."""
@@ -370,10 +254,9 @@ class CrdSubmittedRun(AbstractRun):
370
254
  while True:
371
255
  status = self.get_status()
372
256
  wandb.termlog(f"{LOG_PREFIX}Job {self.name} status: {status}")
373
- if status.state != "running":
374
- break
375
257
  time.sleep(5)
376
- return status.state == "finished"
258
+ if status.state in ["finished", "failed", "preempted"]:
259
+ return status.state == "finished"
377
260
 
378
261
 
379
262
  class KubernetesRunner(AbstractRunner):
@@ -400,48 +283,6 @@ class KubernetesRunner(AbstractRunner):
400
283
  self.environment = environment
401
284
  self.registry = registry
402
285
 
403
- def wait_job_launch(
404
- self,
405
- job_name: str,
406
- namespace: str,
407
- core_api: "CoreV1Api",
408
- label: str = "job-name",
409
- ) -> List[str]:
410
- """Wait for a job to be launched and return the pod names.
411
-
412
- Arguments:
413
- job_name: The name of the job.
414
- namespace: The namespace of the job.
415
- core_api: The Kubernetes core API client.
416
- label: The label key to match against job_name.
417
-
418
- Returns:
419
- The names of the pods associated with the job.
420
- """
421
- pods = core_api.list_namespaced_pod(
422
- label_selector=f"{label}={job_name}", namespace=namespace
423
- )
424
- timeout = TIMEOUT
425
- while len(pods.items) == 0 and timeout > 0:
426
- time.sleep(1)
427
- timeout -= 1
428
- pods = core_api.list_namespaced_pod(
429
- label_selector=f"{label}={job_name}", namespace=namespace
430
- )
431
-
432
- if timeout == 0:
433
- raise LaunchError(
434
- "No pods found for job {}. Check dashboard to see if job was launched successfully.".format(
435
- job_name
436
- )
437
- )
438
-
439
- pod_names = [pi.metadata.name for pi in pods.items]
440
- wandb.termlog(
441
- f"{LOG_PREFIX}Job {job_name} created on pod(s) {', '.join(pod_names)}. See logs with e.g. `kubectl logs {pod_names[0]} -n {namespace}`."
442
- )
443
- return pod_names
444
-
445
286
  def get_namespace(
446
287
  self, resource_args: Dict[str, Any], context: Dict[str, Any]
447
288
  ) -> str:
@@ -522,18 +363,10 @@ class KubernetesRunner(AbstractRunner):
522
363
  or launch_project.get_single_entry_point()
523
364
  )
524
365
  if launch_project.docker_image:
525
- if len(containers) > 1:
526
- raise LaunchError(
527
- "Invalid specification of multiple containers. See https://docs.wandb.ai/guides/launch for guidance on submitting jobs."
528
- )
529
366
  # dont specify run id if user provided image, could have multiple runs
530
367
  containers[0]["image"] = image_uri
531
368
  # TODO: handle secret pulling image from registry
532
369
  elif not any(["image" in cont for cont in containers]):
533
- if len(containers) > 1:
534
- raise LaunchError(
535
- "Launch only builds one container at a time. See https://docs.wandb.ai/guides/launch for guidance on submitting jobs."
536
- )
537
370
  assert entry_point is not None
538
371
  # in the non instance case we need to make an imagePullSecret
539
372
  # so the new job can pull the image
@@ -638,16 +471,27 @@ class KubernetesRunner(AbstractRunner):
638
471
  body=resource_args,
639
472
  )
640
473
  except ApiException as e:
474
+ body = json.loads(e.body)
475
+ body_yaml = yaml.dump(body)
641
476
  raise LaunchError(
642
- f"Error creating CRD of kind {kind}: {e.status} {e.reason}"
477
+ f"Error creating CRD of kind {kind}: {e.status} {e.reason}\n{body_yaml}"
643
478
  ) from e
644
479
  name = response.get("metadata", {}).get("name")
645
480
  _logger.info(f"Created {kind} {response['metadata']['name']}")
646
481
  core = client.CoreV1Api(api_client)
647
- pod_names = self.wait_job_launch(
648
- launch_project.run_id, namespace, core, label="wandb/run-id"
482
+ run_monitor = KubernetesRunMonitor(
483
+ job_field_selector=f"metadata.name={name}",
484
+ pod_label_selector=f"wandb/run-id={launch_project.run_id}",
485
+ namespace=namespace,
486
+ batch_api=None,
487
+ core_api=core,
488
+ custom_api=api,
489
+ group=group,
490
+ version=version,
491
+ plural=plural,
649
492
  )
650
- return CrdSubmittedRun(
493
+ run_monitor.start()
494
+ submitted_run = CrdSubmittedRun(
651
495
  name=name,
652
496
  group=group,
653
497
  version=version,
@@ -655,8 +499,11 @@ class KubernetesRunner(AbstractRunner):
655
499
  plural=plural,
656
500
  core_api=client.CoreV1Api(api_client),
657
501
  custom_api=api,
658
- pod_names=pod_names,
502
+ monitor=run_monitor,
659
503
  )
504
+ if self.backend_config[PROJECT_SYNCHRONOUS]:
505
+ submitted_run.wait()
506
+ return submitted_run
660
507
 
661
508
  batch_api = kubernetes.client.BatchV1Api(api_client)
662
509
  core_api = kubernetes.client.CoreV1Api(api_client)
@@ -674,12 +521,22 @@ class KubernetesRunner(AbstractRunner):
674
521
  0
675
522
  ] # create_from_yaml returns a nested list of k8s objects
676
523
  job_name = job_response.metadata.name
677
- pod_names = self.wait_job_launch(job_name, namespace, core_api)
524
+
525
+ # Event stream monitor to ensure pod creation and job completion.
526
+ monitor = KubernetesRunMonitor(
527
+ job_field_selector=f"metadata.name={job_name}",
528
+ pod_label_selector=f"job-name={job_name}",
529
+ namespace=namespace,
530
+ batch_api=batch_api,
531
+ core_api=core_api,
532
+ )
533
+ monitor.start()
678
534
  submitted_job = KubernetesSubmittedRun(
679
- batch_api, core_api, job_name, pod_names, namespace, secret
535
+ monitor, batch_api, core_api, job_name, [], namespace, secret
680
536
  )
681
537
  if self.backend_config[PROJECT_SYNCHRONOUS]:
682
538
  submitted_job.wait()
539
+
683
540
  return submitted_job
684
541
 
685
542
 
@@ -5,7 +5,7 @@ import subprocess
5
5
  import sys
6
6
  import threading
7
7
  import time
8
- from typing import Any, Dict, List, Optional
8
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional
9
9
 
10
10
  import wandb
11
11
  from wandb.sdk.launch.environment.abstract import AbstractEnvironment
@@ -26,6 +26,9 @@ from ..utils import (
26
26
  )
27
27
  from .abstract import AbstractRun, AbstractRunner, Status
28
28
 
29
+ if TYPE_CHECKING:
30
+ from wandb.apis.internal import Api
31
+
29
32
  _logger = logging.getLogger(__name__)
30
33
 
31
34
 
@@ -95,7 +98,7 @@ class LocalContainerRunner(AbstractRunner):
95
98
 
96
99
  def __init__(
97
100
  self,
98
- api: wandb.apis.internal.Api,
101
+ api: "Api",
99
102
  backend_config: Dict[str, Any],
100
103
  environment: AbstractEnvironment,
101
104
  registry: AbstractRegistry,
@@ -46,7 +46,10 @@ class LocalProcessRunner(AbstractRunner):
46
46
  _logger.warning(_msg)
47
47
 
48
48
  synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
49
- entry_point = launch_project.get_single_entry_point()
49
+ entry_point = (
50
+ launch_project.override_entrypoint
51
+ or launch_project.get_single_entry_point()
52
+ )
50
53
 
51
54
  cmd: List[Any] = []
52
55
 
@@ -9,26 +9,28 @@ import traceback
9
9
  from abc import ABC, abstractmethod
10
10
  from dataclasses import dataclass
11
11
  from enum import Enum
12
- from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
12
+ from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union
13
13
 
14
14
  import click
15
15
  import yaml
16
16
 
17
17
  import wandb
18
- import wandb.apis.public as public
19
- from wandb.apis.internal import Api
20
- from wandb.apis.public import Api as PublicApi
21
- from wandb.apis.public import QueuedRun, Run
22
18
  from wandb.errors import CommError
19
+ from wandb.sdk.launch._launch_add import launch_add
23
20
  from wandb.sdk.launch.errors import LaunchError
24
- from wandb.sdk.launch.launch_add import launch_add
25
21
  from wandb.sdk.launch.sweeps import SchedulerError
26
22
  from wandb.sdk.launch.sweeps.utils import (
27
23
  create_sweep_command_args,
28
24
  make_launch_sweep_entrypoint,
29
25
  )
30
26
  from wandb.sdk.lib.runid import generate_id
31
- from wandb.sdk.wandb_run import Run as SdkRun
27
+
28
+ if TYPE_CHECKING:
29
+ import wandb.apis.public as public
30
+ from wandb.apis.internal import Api
31
+ from wandb.apis.public import QueuedRun, Run
32
+ from wandb.sdk.wandb_run import Run as SdkRun
33
+
32
34
 
33
35
  _logger = logging.getLogger(__name__)
34
36
  LOG_PREFIX = f"{click.style('sched:', fg='cyan')} "
@@ -84,7 +86,7 @@ class SweepRun:
84
86
  id: str
85
87
  worker_id: int
86
88
  state: RunState = RunState.RUNNING
87
- queued_run: Optional[public.QueuedRun] = None
89
+ queued_run: Optional["public.QueuedRun"] = None
88
90
  args: Optional[Dict[str, Any]] = None
89
91
  logs: Optional[List[str]] = None
90
92
 
@@ -98,7 +100,7 @@ class Scheduler(ABC):
98
100
 
99
101
  def __init__(
100
102
  self,
101
- api: Api,
103
+ api: "Api",
102
104
  *args: Optional[Any],
103
105
  polling_sleep: Optional[float] = None,
104
106
  sweep_id: Optional[str] = None,
@@ -108,6 +110,8 @@ class Scheduler(ABC):
108
110
  num_workers: Optional[Union[int, str]] = None,
109
111
  **kwargs: Optional[Any],
110
112
  ):
113
+ from wandb.apis.public import Api as PublicApi
114
+
111
115
  self._api = api
112
116
  self._public_api = PublicApi()
113
117
  self._entity = (
@@ -244,7 +248,7 @@ class Scheduler(ABC):
244
248
  _id: w for _id, w in self._workers.items() if _id not in self.busy_workers
245
249
  }
246
250
 
247
- def _init_wandb_run(self) -> SdkRun:
251
+ def _init_wandb_run(self) -> "SdkRun":
248
252
  """Controls resume or init logic for a scheduler wandb run."""
249
253
  _type = self._kwargs.get("sweep_type", "sweep")
250
254
  run: SdkRun = wandb.init(
@@ -346,9 +350,8 @@ class Scheduler(ABC):
346
350
  self.exit()
347
351
  raise e
348
352
  else:
349
- wandb.termlog(f"{LOG_PREFIX}Scheduler completed successfully")
350
- # don't overwrite special states (e.g. STOPPED, FAILED)
351
- if self.state in [SchedulerState.RUNNING, SchedulerState.FLUSH_RUNS]:
353
+ # scheduler succeeds if at runcap
354
+ if self.state == SchedulerState.FLUSH_RUNS and self.at_runcap:
352
355
  self.state = SchedulerState.COMPLETED
353
356
  self.exit()
354
357
 
@@ -362,16 +365,24 @@ class Scheduler(ABC):
362
365
  f"{LOG_PREFIX}Failed to save state: {traceback.format_exc()}"
363
366
  )
364
367
 
365
- if self.state not in [
366
- SchedulerState.COMPLETED,
367
- SchedulerState.STOPPED,
368
- ]:
368
+ status = ""
369
+ if self.state == SchedulerState.FLUSH_RUNS:
370
+ self._set_sweep_state("PAUSED")
371
+ status = "paused"
372
+ elif self.state == SchedulerState.COMPLETED:
373
+ self._set_sweep_state("FINISHED")
374
+ status = "completed"
375
+ elif self.state in [SchedulerState.CANCELLED, SchedulerState.STOPPED]:
376
+ self._set_sweep_state("CANCELED") # one L
377
+ status = "cancelled"
378
+ self._stop_runs()
379
+ else:
369
380
  self.state = SchedulerState.FAILED
370
381
  self._set_sweep_state("CRASHED")
371
- else:
372
- self._set_sweep_state("FINISHED")
382
+ status = "crashed"
383
+ self._stop_runs()
373
384
 
374
- self._stop_runs()
385
+ wandb.termlog(f"{LOG_PREFIX}Scheduler {status}")
375
386
  self._wandb_run.finish()
376
387
 
377
388
  def _get_num_runs_launched(self, runs: List[Dict[str, Any]]) -> int:
@@ -494,6 +505,7 @@ class Scheduler(ABC):
494
505
  """Update the scheduler state from state of scheduler run and sweep state."""
495
506
  state: RunState = self._get_run_state(self._wandb_run.id)
496
507
 
508
+ # map scheduler run-state to scheduler-state
497
509
  if state == RunState.KILLED:
498
510
  self.state = SchedulerState.STOPPED
499
511
  elif state in [RunState.FAILED, RunState.CRASHED]:
@@ -501,17 +513,20 @@ class Scheduler(ABC):
501
513
  elif state == RunState.FINISHED:
502
514
  self.state = SchedulerState.COMPLETED
503
515
 
516
+ # check sweep state for completed states, overwrite scheduler state
504
517
  try:
505
518
  sweep_state = self._api.get_sweep_state(
506
519
  self._sweep_id, self._entity, self._project
507
520
  )
508
521
  except Exception as e:
509
- _logger.debug(f"sweep state error: {sweep_state} e: {e}")
522
+ _logger.debug(f"sweep state error: {e}")
510
523
  return
511
524
 
512
- if sweep_state in ["FINISHED", "CANCELLED"]:
525
+ if sweep_state == "FINISHED":
513
526
  self.state = SchedulerState.COMPLETED
514
- elif sweep_state in ["PAUSED", "STOPPED"]:
527
+ elif sweep_state in ["CANCELLED", "STOPPED"]:
528
+ self.state = SchedulerState.CANCELLED
529
+ elif sweep_state == "PAUSED":
515
530
  self.state = SchedulerState.FLUSH_RUNS
516
531
 
517
532
  def _update_run_states(self) -> None:
@@ -674,6 +689,9 @@ class Scheduler(ABC):
674
689
  f' {"job" if _job else "image_uri"} entrypoint'
675
690
  )
676
691
 
692
+ # override resource and args of job
693
+ _job_launch_config = self._wandb_run.config.get("launch") or {}
694
+
677
695
  run_id = run.id or generate_id()
678
696
  queued_run = launch_add(
679
697
  run_id=run_id,
@@ -685,8 +703,8 @@ class Scheduler(ABC):
685
703
  entity=self._entity,
686
704
  queue_name=self._kwargs.get("queue"),
687
705
  project_queue=self._project_queue,
688
- resource=self._kwargs.get("resource", None),
689
- resource_args=self._kwargs.get("resource_args", None),
706
+ resource=_job_launch_config.get("resource"),
707
+ resource_args=_job_launch_config.get("resource_args"),
690
708
  author=self._kwargs.get("author"),
691
709
  sweep_id=self._sweep_id,
692
710
  )
@@ -1,15 +1,17 @@
1
1
  import json
2
2
  import os
3
3
  import re
4
- from typing import Any, Dict, List, Optional, Tuple, Union
4
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
5
5
 
6
6
  import yaml
7
7
 
8
8
  import wandb
9
9
  from wandb import util
10
- from wandb.apis.public import Api as PublicApi
11
10
  from wandb.sdk.launch.errors import LaunchError
12
11
 
12
+ if TYPE_CHECKING:
13
+ from wandb.apis.public import Api as PublicApi
14
+
13
15
  DEFAULT_SWEEP_COMMAND: List[str] = [
14
16
  "${env}",
15
17
  "${interpreter}",
@@ -276,7 +278,7 @@ def make_launch_sweep_entrypoint(
276
278
  return entry_point, macro_args
277
279
 
278
280
 
279
- def check_job_exists(public_api: PublicApi, job: Optional[str]) -> bool:
281
+ def check_job_exists(public_api: "PublicApi", job: Optional[str]) -> bool:
280
282
  """Check if the job exists using the public api.
281
283
 
282
284
  Returns: True if no job is passed, or if the job exists.
wandb/sdk/launch/utils.py CHANGED
@@ -127,7 +127,9 @@ def set_project_entity_defaults(
127
127
  prefix = ""
128
128
  if platform.system() != "Windows" and sys.stdout.encoding == "UTF-8":
129
129
  prefix = "🚀 "
130
- wandb.termlog(f"{LOG_PREFIX}{prefix}Launching run into {entity}/{project}")
130
+ wandb.termlog(
131
+ f"{LOG_PREFIX}{prefix}Launching run into {entity}{'/' + project if project else ''}"
132
+ )
131
133
  return project, entity
132
134
 
133
135