wandb 0.21.0__py3-none-win_amd64.whl → 0.21.2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. wandb/__init__.py +16 -14
  2. wandb/__init__.pyi +427 -450
  3. wandb/agents/pyagent.py +41 -12
  4. wandb/analytics/sentry.py +7 -2
  5. wandb/apis/importers/mlflow.py +1 -1
  6. wandb/apis/public/__init__.py +1 -1
  7. wandb/apis/public/api.py +525 -360
  8. wandb/apis/public/artifacts.py +207 -13
  9. wandb/apis/public/automations.py +19 -3
  10. wandb/apis/public/files.py +172 -33
  11. wandb/apis/public/history.py +67 -15
  12. wandb/apis/public/integrations.py +25 -2
  13. wandb/apis/public/jobs.py +90 -2
  14. wandb/apis/public/projects.py +130 -79
  15. wandb/apis/public/query_generator.py +11 -1
  16. wandb/apis/public/registries/_utils.py +14 -16
  17. wandb/apis/public/registries/registries_search.py +183 -304
  18. wandb/apis/public/reports.py +96 -15
  19. wandb/apis/public/runs.py +299 -105
  20. wandb/apis/public/sweeps.py +222 -22
  21. wandb/apis/public/teams.py +41 -4
  22. wandb/apis/public/users.py +45 -4
  23. wandb/automations/_generated/delete_automation.py +1 -3
  24. wandb/automations/_generated/enums.py +13 -11
  25. wandb/beta/workflows.py +66 -30
  26. wandb/bin/gpu_stats.exe +0 -0
  27. wandb/bin/wandb-core +0 -0
  28. wandb/cli/cli.py +127 -3
  29. wandb/env.py +8 -0
  30. wandb/errors/errors.py +4 -1
  31. wandb/integration/lightning/fabric/logger.py +3 -4
  32. wandb/integration/metaflow/__init__.py +6 -0
  33. wandb/integration/metaflow/data_pandas.py +74 -0
  34. wandb/integration/metaflow/data_pytorch.py +75 -0
  35. wandb/integration/metaflow/data_sklearn.py +76 -0
  36. wandb/integration/metaflow/errors.py +13 -0
  37. wandb/integration/metaflow/metaflow.py +167 -223
  38. wandb/integration/openai/fine_tuning.py +1 -2
  39. wandb/integration/weave/__init__.py +6 -0
  40. wandb/integration/weave/interface.py +49 -0
  41. wandb/integration/weave/weave.py +63 -0
  42. wandb/jupyter.py +5 -5
  43. wandb/plot/custom_chart.py +30 -7
  44. wandb/proto/v3/wandb_internal_pb2.py +281 -280
  45. wandb/proto/v3/wandb_telemetry_pb2.py +4 -4
  46. wandb/proto/v4/wandb_internal_pb2.py +280 -280
  47. wandb/proto/v4/wandb_telemetry_pb2.py +4 -4
  48. wandb/proto/v5/wandb_internal_pb2.py +280 -280
  49. wandb/proto/v5/wandb_telemetry_pb2.py +4 -4
  50. wandb/proto/v6/wandb_internal_pb2.py +280 -280
  51. wandb/proto/v6/wandb_telemetry_pb2.py +4 -4
  52. wandb/proto/wandb_deprecated.py +6 -0
  53. wandb/sdk/artifacts/_factories.py +17 -0
  54. wandb/sdk/artifacts/_generated/__init__.py +221 -13
  55. wandb/sdk/artifacts/_generated/artifact_by_id.py +17 -0
  56. wandb/sdk/artifacts/_generated/artifact_by_name.py +22 -0
  57. wandb/sdk/artifacts/_generated/artifact_collection_membership_file_urls.py +43 -0
  58. wandb/sdk/artifacts/_generated/artifact_created_by.py +47 -0
  59. wandb/sdk/artifacts/_generated/artifact_file_urls.py +22 -0
  60. wandb/sdk/artifacts/_generated/artifact_type.py +31 -0
  61. wandb/sdk/artifacts/_generated/artifact_used_by.py +43 -0
  62. wandb/sdk/artifacts/_generated/artifact_via_membership_by_name.py +26 -0
  63. wandb/sdk/artifacts/_generated/delete_artifact.py +28 -0
  64. wandb/sdk/artifacts/_generated/enums.py +5 -0
  65. wandb/sdk/artifacts/_generated/fetch_artifact_manifest.py +38 -0
  66. wandb/sdk/artifacts/_generated/fetch_registries.py +32 -0
  67. wandb/sdk/artifacts/_generated/fragments.py +279 -41
  68. wandb/sdk/artifacts/_generated/link_artifact.py +6 -0
  69. wandb/sdk/artifacts/_generated/operations.py +654 -51
  70. wandb/sdk/artifacts/_generated/registry_collections.py +34 -0
  71. wandb/sdk/artifacts/_generated/registry_versions.py +34 -0
  72. wandb/sdk/artifacts/_generated/unlink_artifact.py +25 -0
  73. wandb/sdk/artifacts/_graphql_fragments.py +3 -86
  74. wandb/sdk/artifacts/_internal_artifact.py +19 -8
  75. wandb/sdk/artifacts/_validators.py +14 -4
  76. wandb/sdk/artifacts/artifact.py +512 -618
  77. wandb/sdk/artifacts/artifact_file_cache.py +10 -6
  78. wandb/sdk/artifacts/artifact_manifest.py +10 -9
  79. wandb/sdk/artifacts/artifact_manifest_entry.py +9 -10
  80. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +5 -3
  81. wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -1
  82. wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
  83. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +1 -1
  84. wandb/sdk/data_types/audio.py +38 -10
  85. wandb/sdk/data_types/base_types/media.py +6 -56
  86. wandb/sdk/data_types/graph.py +48 -14
  87. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +1 -3
  88. wandb/sdk/data_types/helper_types/image_mask.py +1 -3
  89. wandb/sdk/data_types/histogram.py +34 -21
  90. wandb/sdk/data_types/html.py +35 -12
  91. wandb/sdk/data_types/image.py +104 -68
  92. wandb/sdk/data_types/molecule.py +32 -19
  93. wandb/sdk/data_types/object_3d.py +36 -17
  94. wandb/sdk/data_types/plotly.py +18 -5
  95. wandb/sdk/data_types/saved_model.py +4 -6
  96. wandb/sdk/data_types/table.py +59 -30
  97. wandb/sdk/data_types/video.py +53 -26
  98. wandb/sdk/integration_utils/auto_logging.py +2 -2
  99. wandb/sdk/interface/interface_queue.py +1 -4
  100. wandb/sdk/interface/interface_shared.py +26 -37
  101. wandb/sdk/interface/interface_sock.py +24 -14
  102. wandb/sdk/internal/internal_api.py +6 -0
  103. wandb/sdk/internal/job_builder.py +6 -0
  104. wandb/sdk/internal/settings_static.py +2 -3
  105. wandb/sdk/launch/agent/agent.py +8 -1
  106. wandb/sdk/launch/agent/run_queue_item_file_saver.py +2 -2
  107. wandb/sdk/launch/create_job.py +15 -2
  108. wandb/sdk/launch/inputs/internal.py +3 -4
  109. wandb/sdk/launch/inputs/schema.py +1 -0
  110. wandb/sdk/launch/runner/kubernetes_monitor.py +1 -0
  111. wandb/sdk/launch/runner/kubernetes_runner.py +323 -1
  112. wandb/sdk/launch/sweeps/scheduler.py +2 -3
  113. wandb/sdk/lib/asyncio_compat.py +19 -16
  114. wandb/sdk/lib/asyncio_manager.py +252 -0
  115. wandb/sdk/lib/deprecate.py +1 -7
  116. wandb/sdk/lib/disabled.py +1 -1
  117. wandb/sdk/lib/hashutil.py +27 -5
  118. wandb/sdk/lib/module.py +7 -13
  119. wandb/sdk/lib/printer.py +2 -2
  120. wandb/sdk/lib/printer_asyncio.py +3 -1
  121. wandb/sdk/lib/progress.py +0 -19
  122. wandb/sdk/lib/retry.py +185 -78
  123. wandb/sdk/lib/service/service_client.py +106 -0
  124. wandb/sdk/lib/service/service_connection.py +20 -26
  125. wandb/sdk/lib/service/service_token.py +30 -13
  126. wandb/sdk/mailbox/mailbox.py +13 -5
  127. wandb/sdk/mailbox/mailbox_handle.py +22 -13
  128. wandb/sdk/mailbox/response_handle.py +42 -106
  129. wandb/sdk/mailbox/wait_with_progress.py +7 -42
  130. wandb/sdk/wandb_init.py +77 -116
  131. wandb/sdk/wandb_login.py +19 -15
  132. wandb/sdk/wandb_metric.py +2 -0
  133. wandb/sdk/wandb_run.py +497 -469
  134. wandb/sdk/wandb_settings.py +145 -4
  135. wandb/sdk/wandb_setup.py +204 -124
  136. wandb/sdk/wandb_sweep.py +14 -13
  137. wandb/sdk/wandb_watch.py +4 -6
  138. wandb/sync/sync.py +10 -0
  139. wandb/util.py +58 -1
  140. wandb/wandb_run.py +1 -2
  141. {wandb-0.21.0.dist-info → wandb-0.21.2.dist-info}/METADATA +1 -1
  142. {wandb-0.21.0.dist-info → wandb-0.21.2.dist-info}/RECORD +145 -129
  143. wandb/sdk/interface/interface_relay.py +0 -38
  144. wandb/sdk/interface/router.py +0 -89
  145. wandb/sdk/interface/router_queue.py +0 -43
  146. wandb/sdk/interface/router_relay.py +0 -50
  147. wandb/sdk/interface/router_sock.py +0 -32
  148. wandb/sdk/lib/sock_client.py +0 -236
  149. wandb/vendor/pynvml/__init__.py +0 -0
  150. wandb/vendor/pynvml/pynvml.py +0 -4779
  151. {wandb-0.21.0.dist-info → wandb-0.21.2.dist-info}/WHEEL +0 -0
  152. {wandb-0.21.0.dist-info → wandb-0.21.2.dist-info}/entry_points.txt +0 -0
  153. {wandb-0.21.0.dist-info → wandb-0.21.2.dist-info}/licenses/LICENSE +0 -0
@@ -6,6 +6,7 @@ import datetime
6
6
  import json
7
7
  import logging
8
8
  import os
9
+ import time
9
10
  from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
10
11
 
11
12
  import yaml
@@ -20,11 +21,13 @@ from wandb.sdk.launch.registry.local_registry import LocalRegistry
20
21
  from wandb.sdk.launch.runner.abstract import Status
21
22
  from wandb.sdk.launch.runner.kubernetes_monitor import (
22
23
  WANDB_K8S_LABEL_AGENT,
24
+ WANDB_K8S_LABEL_AUXILIARY_RESOURCE,
23
25
  WANDB_K8S_LABEL_MONITOR,
24
26
  WANDB_K8S_RUN_ID,
25
27
  CustomResource,
26
28
  LaunchKubernetesMonitor,
27
29
  )
30
+ from wandb.sdk.launch.utils import recursive_macro_sub
28
31
  from wandb.sdk.lib.retry import ExponentialBackoff, retry_async
29
32
  from wandb.util import get_module
30
33
 
@@ -47,6 +50,9 @@ get_module(
47
50
 
48
51
  import kubernetes_asyncio # type: ignore # noqa: E402
49
52
  from kubernetes_asyncio import client # noqa: E402
53
+ from kubernetes_asyncio.client.api.apps_v1_api import ( # type: ignore # noqa: E402
54
+ AppsV1Api,
55
+ )
50
56
  from kubernetes_asyncio.client.api.batch_v1_api import ( # type: ignore # noqa: E402
51
57
  BatchV1Api,
52
58
  )
@@ -56,6 +62,9 @@ from kubernetes_asyncio.client.api.core_v1_api import ( # type: ignore # noqa:
56
62
  from kubernetes_asyncio.client.api.custom_objects_api import ( # type: ignore # noqa: E402
57
63
  CustomObjectsApi,
58
64
  )
65
+ from kubernetes_asyncio.client.api.networking_v1_api import ( # type: ignore # noqa: E402
66
+ NetworkingV1Api,
67
+ )
59
68
  from kubernetes_asyncio.client.models.v1_secret import ( # type: ignore # noqa: E402
60
69
  V1Secret,
61
70
  )
@@ -78,9 +87,12 @@ class KubernetesSubmittedRun(AbstractRun):
78
87
  self,
79
88
  batch_api: "BatchV1Api",
80
89
  core_api: "CoreV1Api",
90
+ apps_api: "AppsV1Api",
91
+ network_api: "NetworkingV1Api",
81
92
  name: str,
82
93
  namespace: Optional[str] = "default",
83
94
  secret: Optional["V1Secret"] = None,
95
+ auxiliary_resource_label_key: Optional[str] = None,
84
96
  ) -> None:
85
97
  """Initialize a KubernetesSubmittedRun.
86
98
 
@@ -95,6 +107,7 @@ class KubernetesSubmittedRun(AbstractRun):
95
107
  Arguments:
96
108
  batch_api: Kubernetes BatchV1Api object.
97
109
  core_api: Kubernetes CoreV1Api object.
110
+ network_api: Kubernetes NetworkV1Api object.
98
111
  name: Name of the job.
99
112
  namespace: Kubernetes namespace.
100
113
  secret: Kubernetes secret.
@@ -104,10 +117,13 @@ class KubernetesSubmittedRun(AbstractRun):
104
117
  """
105
118
  self.batch_api = batch_api
106
119
  self.core_api = core_api
120
+ self.apps_api = apps_api
121
+ self.network_api = network_api
107
122
  self.name = name
108
123
  self.namespace = namespace
109
124
  self._fail_count = 0
110
125
  self.secret = secret
126
+ self.auxiliary_resource_label_key = auxiliary_resource_label_key
111
127
 
112
128
  @property
113
129
  def id(self) -> str:
@@ -149,6 +165,7 @@ class KubernetesSubmittedRun(AbstractRun):
149
165
  await asyncio.sleep(5)
150
166
 
151
167
  await self._delete_secret()
168
+ await self._delete_auxiliary_resources_by_label()
152
169
  return (
153
170
  status.state == "finished"
154
171
  ) # todo: not sure if this (copied from aws runner) is the right approach? should we return false on failure
@@ -157,6 +174,7 @@ class KubernetesSubmittedRun(AbstractRun):
157
174
  status = LaunchKubernetesMonitor.get_status(self.name)
158
175
  if status in ["stopped", "failed", "finished", "preempted"]:
159
176
  await self._delete_secret()
177
+ await self._delete_auxiliary_resources_by_label()
160
178
  return status
161
179
 
162
180
  async def cancel(self) -> None:
@@ -167,6 +185,7 @@ class KubernetesSubmittedRun(AbstractRun):
167
185
  name=self.name,
168
186
  )
169
187
  await self._delete_secret()
188
+ await self._delete_auxiliary_resources_by_label()
170
189
  except ApiException as e:
171
190
  raise LaunchError(
172
191
  f"Failed to delete Kubernetes Job {self.name} in namespace {self.namespace}: {str(e)}"
@@ -181,6 +200,50 @@ class KubernetesSubmittedRun(AbstractRun):
181
200
  )
182
201
  self.secret = None
183
202
 
203
+ async def _delete_auxiliary_resources_by_label(self) -> None:
204
+ if self.auxiliary_resource_label_key is None:
205
+ return
206
+
207
+ label_selector = (
208
+ f"{WANDB_K8S_LABEL_AUXILIARY_RESOURCE}={self.auxiliary_resource_label_key}"
209
+ )
210
+
211
+ try:
212
+ resource_cleanups = [
213
+ (self.core_api, "service"),
214
+ (self.batch_api, "job"),
215
+ (self.core_api, "pod"),
216
+ (self.core_api, "secret"),
217
+ (self.apps_api, "deployment"),
218
+ (self.network_api, "network_policy"),
219
+ ]
220
+
221
+ for api_client, resource_type in resource_cleanups:
222
+ try:
223
+ list_method = getattr(
224
+ api_client, f"list_namespaced_{resource_type}"
225
+ )
226
+ delete_method = getattr(
227
+ api_client, f"delete_namespaced_{resource_type}"
228
+ )
229
+
230
+ # List resources with our label
231
+ resources = await list_method(
232
+ namespace=self.namespace, label_selector=label_selector
233
+ )
234
+
235
+ # Delete each resource
236
+ for resource in resources.items:
237
+ await delete_method(
238
+ name=resource.metadata.name, namespace=self.namespace
239
+ )
240
+
241
+ except (AttributeError, ApiException) as e:
242
+ wandb.termwarn(f"Could not clean up {resource_type}: {e}")
243
+
244
+ except Exception as e:
245
+ wandb.termwarn(f"Failed to clean up some auxiliary resources: {e}")
246
+
184
247
 
185
248
  class CrdSubmittedRun(AbstractRun):
186
249
  """Run submitted to a CRD backend, e.g. Volcano."""
@@ -366,6 +429,7 @@ class KubernetesRunner(AbstractRunner):
366
429
  job_metadata["generateName"] = make_name_dns_safe(
367
430
  f"launch-{launch_project.target_entity}-{launch_project.target_project}-"
368
431
  )
432
+ job_metadata["namespace"] = namespace
369
433
 
370
434
  for i, cont in enumerate(containers):
371
435
  if "name" not in cont:
@@ -489,6 +553,216 @@ class KubernetesRunner(AbstractRunner):
489
553
 
490
554
  return job, api_key_secret
491
555
 
556
+ async def _wait_for_resource_ready(
557
+ self,
558
+ api_client: kubernetes_asyncio.client.ApiClient,
559
+ config: Dict[str, Any],
560
+ namespace: str,
561
+ timeout_seconds: int = 300,
562
+ ) -> None:
563
+ """Wait for a Kubernetes resource to be ready.
564
+
565
+ Arguments:
566
+ api_client: The Kubernetes API client.
567
+ config: The resource configuration.
568
+ namespace: The namespace where the resource was created.
569
+ timeout_seconds: Maximum time to wait for readiness.
570
+ """
571
+ resource_kind = config.get("kind")
572
+ resource_name = config.get("metadata", {}).get("name")
573
+
574
+ if not resource_kind or not resource_name:
575
+ wandb.termerror(
576
+ f"{LOG_PREFIX}Cannot wait for resource without kind or name"
577
+ )
578
+ return
579
+
580
+ wandb.termlog(
581
+ f"{LOG_PREFIX}Waiting for {resource_kind} '{resource_name}' to be ready..."
582
+ )
583
+
584
+ start_time = time.time()
585
+
586
+ if resource_kind == "Deployment":
587
+ await self._wait_for_deployment_ready(
588
+ api_client, resource_name, namespace, timeout_seconds
589
+ )
590
+ elif resource_kind == "Service":
591
+ await self._wait_for_service_ready(
592
+ api_client, resource_name, namespace, timeout_seconds
593
+ )
594
+ elif resource_kind == "Pod":
595
+ await self._wait_for_pod_ready(
596
+ api_client, resource_name, namespace, timeout_seconds
597
+ )
598
+ else:
599
+ wandb.termlog(
600
+ f"{LOG_PREFIX}No specific readiness check for {resource_kind}, waiting 5 seconds..."
601
+ )
602
+ await asyncio.sleep(5)
603
+
604
+ elapsed = time.time() - start_time
605
+ wandb.termlog(
606
+ f"{LOG_PREFIX}{resource_kind} '{resource_name}' is ready after {elapsed:.1f}s"
607
+ )
608
+
609
+ async def _wait_for_deployment_ready(
610
+ self,
611
+ api_client: kubernetes_asyncio.client.ApiClient,
612
+ name: str,
613
+ namespace: str,
614
+ timeout_seconds: int,
615
+ ) -> None:
616
+ """Wait for a Deployment to be ready."""
617
+ apps_api = kubernetes_asyncio.client.AppsV1Api(api_client)
618
+
619
+ async def check_deployment_ready():
620
+ deployment = await apps_api.read_namespaced_deployment(
621
+ name=name, namespace=namespace
622
+ )
623
+ status = deployment.status
624
+
625
+ if status.ready_replicas and status.replicas:
626
+ return status.ready_replicas >= status.replicas
627
+
628
+ return False
629
+
630
+ await self._wait_with_timeout(check_deployment_ready, timeout_seconds, name)
631
+
632
+ async def _wait_for_service_ready(
633
+ self,
634
+ api_client: kubernetes_asyncio.client.ApiClient,
635
+ name: str,
636
+ namespace: str,
637
+ timeout_seconds: int,
638
+ ) -> None:
639
+ """Wait for a Service to have endpoints."""
640
+ core_api = kubernetes_asyncio.client.CoreV1Api(api_client)
641
+
642
+ async def check_service_ready():
643
+ endpoints = await core_api.read_namespaced_endpoints(
644
+ name=name, namespace=namespace
645
+ )
646
+ if endpoints.subsets:
647
+ for subset in endpoints.subsets:
648
+ if subset.addresses: # These are ready pod addresses
649
+ return True
650
+ return False
651
+
652
+ await self._wait_with_timeout(check_service_ready, timeout_seconds, name)
653
+
654
+ async def _wait_for_pod_ready(
655
+ self,
656
+ api_client: kubernetes_asyncio.client.ApiClient,
657
+ name: str,
658
+ namespace: str,
659
+ timeout_seconds: int,
660
+ ) -> None:
661
+ """Wait for a Pod to be ready."""
662
+ core_api = kubernetes_asyncio.client.CoreV1Api(api_client)
663
+
664
+ async def check_pod_ready():
665
+ pod = await core_api.read_namespaced_pod(name=name, namespace=namespace)
666
+ if pod.status.phase == "Running":
667
+ if pod.status.container_statuses:
668
+ return all(status.ready for status in pod.status.container_statuses)
669
+ return True
670
+ return False
671
+
672
+ await self._wait_with_timeout(check_pod_ready, timeout_seconds, name)
673
+
674
+ async def _wait_with_timeout(
675
+ self, check_func, timeout_seconds: int, name: str
676
+ ) -> None:
677
+ """Generic timeout wrapper for readiness checks."""
678
+ start_time = time.time()
679
+
680
+ while time.time() - start_time < timeout_seconds:
681
+ try:
682
+ if await check_func():
683
+ return
684
+ except kubernetes_asyncio.client.ApiException as e:
685
+ if e.status == 404:
686
+ pass
687
+ else:
688
+ wandb.termerror(
689
+ f"{LOG_PREFIX}Error waiting for resource '{name}': {e}"
690
+ )
691
+ raise
692
+ except Exception as e:
693
+ wandb.termerror(f"{LOG_PREFIX}Error waiting for resource '{name}': {e}")
694
+ raise
695
+ await asyncio.sleep(2)
696
+
697
+ raise LaunchError(
698
+ f"Resource '{name}' not ready within {timeout_seconds} seconds"
699
+ )
700
+
701
+ async def _prepare_resource(
702
+ self,
703
+ api_client: kubernetes_asyncio.client.ApiClient,
704
+ config: Dict[str, Any],
705
+ namespace: str,
706
+ run_id: str,
707
+ launch_project: LaunchProject,
708
+ api_key_secret: Optional["V1Secret"] = None,
709
+ wait_for_ready: bool = True,
710
+ wait_timeout: int = 300,
711
+ ) -> None:
712
+ """Prepare a service for launch.
713
+
714
+ Arguments:
715
+ api_client: The Kubernetes API client.
716
+ config: The resource configuration to prepare.
717
+ namespace: The namespace to create the resource in.
718
+ run_id: The run ID to label the resource with.
719
+ launch_project: The launch project to get environment variables from.
720
+ api_key_secret: The API key secret to inject.
721
+ wait_for_ready: Whether to wait for the resource to be ready after creation.
722
+ wait_timeout: Maximum time in seconds to wait for resource readiness.
723
+ """
724
+ config.setdefault("metadata", {})
725
+ config["metadata"].setdefault("labels", {})
726
+ config["metadata"]["labels"][WANDB_K8S_RUN_ID] = run_id
727
+ config["metadata"]["labels"]["wandb.ai/created-by"] = "launch-agent"
728
+
729
+ env_vars = launch_project.get_env_vars_dict(
730
+ self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
731
+ )
732
+ wandb_config_env = {
733
+ "WANDB_CONFIG": env_vars.get("WANDB_CONFIG", "{}"),
734
+ }
735
+ add_wandb_env(config, wandb_config_env)
736
+
737
+ if api_key_secret:
738
+ for cont in yield_containers(config):
739
+ env = cont.setdefault("env", [])
740
+ env.append(
741
+ {
742
+ "name": "WANDB_API_KEY",
743
+ "valueFrom": {
744
+ "secretKeyRef": {
745
+ "name": api_key_secret.metadata.name,
746
+ "key": "password",
747
+ }
748
+ },
749
+ }
750
+ )
751
+ cont["env"] = env
752
+
753
+ try:
754
+ await kubernetes_asyncio.utils.create_from_dict(
755
+ api_client, config, namespace=namespace
756
+ )
757
+
758
+ if wait_for_ready:
759
+ await self._wait_for_resource_ready(
760
+ api_client, config, namespace, wait_timeout
761
+ )
762
+ except Exception as e:
763
+ wandb.termerror(f"{LOG_PREFIX}Failed to create Kubernetes resource: {e}")
764
+ raise LaunchError(f"Failed to create Kubernetes resource: {e}")
765
+
492
766
  async def run(
493
767
  self, launch_project: LaunchProject, image_uri: str
494
768
  ) -> Optional[AbstractRun]:
@@ -630,10 +904,51 @@ class KubernetesRunner(AbstractRunner):
630
904
 
631
905
  batch_api = kubernetes_asyncio.client.BatchV1Api(api_client)
632
906
  core_api = kubernetes_asyncio.client.CoreV1Api(api_client)
907
+ apps_api = kubernetes_asyncio.client.AppsV1Api(api_client)
908
+ network_api = kubernetes_asyncio.client.NetworkingV1Api(api_client)
909
+
633
910
  namespace = self.get_namespace(resource_args, context)
634
911
  job, secret = await self._inject_defaults(
635
912
  resource_args, launch_project, image_uri, namespace, core_api
636
913
  )
914
+
915
+ update_dict = {
916
+ "project_name": launch_project.target_project,
917
+ "entity_name": launch_project.target_entity,
918
+ "run_id": launch_project.run_id,
919
+ "run_name": launch_project.name,
920
+ "image_uri": image_uri,
921
+ "author": launch_project.author,
922
+ }
923
+ update_dict.update(os.environ)
924
+ additional_services: List[Dict[str, Any]] = recursive_macro_sub(
925
+ launch_project.launch_spec.get("additional_services", []), update_dict
926
+ )
927
+ if additional_services:
928
+ wandb.termlog(
929
+ f"{LOG_PREFIX}Creating additional services: {additional_services}"
930
+ )
931
+
932
+ wait_for_ready = resource_args.get("wait_for_ready", True)
933
+ wait_timeout = resource_args.get("wait_timeout", 300)
934
+
935
+ await asyncio.gather(
936
+ *[
937
+ self._prepare_resource(
938
+ api_client,
939
+ resource.get("config", {}),
940
+ namespace,
941
+ launch_project.run_id,
942
+ launch_project,
943
+ secret,
944
+ wait_for_ready,
945
+ wait_timeout,
946
+ )
947
+ for resource in additional_services
948
+ if resource.get("config", {})
949
+ ]
950
+ )
951
+
637
952
  msg = "Creating Kubernetes job"
638
953
  if "name" in resource_args:
639
954
  msg += f": {resource_args['name']}"
@@ -658,7 +973,14 @@ class KubernetesRunner(AbstractRunner):
658
973
  job_name = job_response.metadata.name
659
974
  LaunchKubernetesMonitor.monitor_namespace(namespace)
660
975
  submitted_job = KubernetesSubmittedRun(
661
- batch_api, core_api, job_name, namespace, secret
976
+ batch_api,
977
+ core_api,
978
+ apps_api,
979
+ network_api,
980
+ job_name,
981
+ namespace,
982
+ secret,
983
+ f"aux-{launch_project.target_entity}-{launch_project.target_project}-{launch_project.run_id}",
662
984
  )
663
985
  if self.backend_config[PROJECT_SYNCHRONOUS]:
664
986
  await submitted_job.wait()
@@ -36,7 +36,6 @@ if TYPE_CHECKING:
36
36
  import wandb.apis.public as public
37
37
  from wandb.apis.internal import Api
38
38
  from wandb.apis.public import QueuedRun, Run
39
- from wandb.sdk.wandb_run import Run as SdkRun
40
39
 
41
40
 
42
41
  _logger = logging.getLogger(__name__)
@@ -255,10 +254,10 @@ class Scheduler(ABC):
255
254
  _id: w for _id, w in self._workers.items() if _id not in self.busy_workers
256
255
  }
257
256
 
258
- def _init_wandb_run(self) -> "SdkRun":
257
+ def _init_wandb_run(self) -> "wandb.Run":
259
258
  """Controls resume or init logic for a scheduler wandb run."""
260
259
  settings = wandb.Settings(disable_job_creation=True)
261
- run: SdkRun = wandb.init( # type: ignore
260
+ run: wandb.Run = wandb.init( # type: ignore
262
261
  name=f"Scheduler.{self._sweep_id}",
263
262
  resume="allow",
264
263
  config=self._kwargs, # when run as a job, this sets config
@@ -23,7 +23,7 @@ def run(fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
23
23
  Note that due to starting a new thread, this is slightly slow.
24
24
  """
25
25
  with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
26
- runner = _Runner()
26
+ runner = CancellableRunner()
27
27
  future = executor.submit(runner.run, fn)
28
28
 
29
29
  try:
@@ -33,15 +33,16 @@ def run(fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
33
33
  runner.cancel()
34
34
 
35
35
 
36
- class _RunnerCancelledError(Exception):
37
- """The `_Runner.run()` invocation was cancelled."""
36
+ class RunnerCancelledError(Exception):
37
+ """The `CancellableRunner.run()` invocation was cancelled."""
38
38
 
39
39
 
40
- class _Runner:
40
+ class CancellableRunner:
41
41
  """Runs an asyncio event loop allowing cancellation.
42
42
 
43
- This is like `asyncio.run()`, except it provides a `cancel()` method
44
- meant to be called in a `finally` block.
43
+ The `run()` method is like `asyncio.run()`. The `cancel()` method may
44
+ be used in a different thread, for instance in a `finally` block, to cancel
45
+ all tasks, and it is a no-op if `run()` completed.
45
46
 
46
47
  Without this, it is impossible to make `asyncio.run()` stop if it runs
47
48
  in a non-main thread. In particular, a KeyboardInterrupt causes the
@@ -69,7 +70,7 @@ class _Runner:
69
70
  The result of the coroutine returned by `fn`.
70
71
 
71
72
  Raises:
72
- _RunnerCancelledError: If `cancel()` is called.
73
+ RunnerCancelledError: If `cancel()` is called.
73
74
  """
74
75
  return asyncio.run(self._run_or_cancel(fn))
75
76
 
@@ -79,7 +80,7 @@ class _Runner:
79
80
  ) -> _T:
80
81
  with self._lock:
81
82
  if self._is_cancelled:
82
- raise _RunnerCancelledError()
83
+ raise RunnerCancelledError()
83
84
 
84
85
  self._loop = asyncio.get_running_loop()
85
86
  self._cancel_event = asyncio.Event()
@@ -97,9 +98,12 @@ class _Runner:
97
98
  if fn_task.done():
98
99
  return fn_task.result()
99
100
  else:
100
- raise _RunnerCancelledError()
101
+ raise RunnerCancelledError()
101
102
 
102
103
  finally:
104
+ # NOTE: asyncio.run() cancels all tasks after the main task exits,
105
+ # but this is not documented, so we cancel them explicitly here
106
+ # as well. It also blocks until canceled tasks complete.
103
107
  cancellation_task.cancel()
104
108
  fn_task.cancel()
105
109
 
@@ -154,11 +158,9 @@ class TaskGroup:
154
158
  )
155
159
 
156
160
  for task in done:
157
- try:
161
+ with contextlib.suppress(asyncio.CancelledError):
158
162
  if exc := task.exception():
159
163
  raise exc
160
- except asyncio.CancelledError:
161
- pass
162
164
 
163
165
  def _cancel_all(self) -> None:
164
166
  """Cancel all tasks."""
@@ -196,15 +198,16 @@ async def open_task_group() -> AsyncIterator[TaskGroup]:
196
198
  def cancel_on_exit(coro: Coroutine[Any, Any, Any]) -> Iterator[None]:
197
199
  """Schedule a task, cancelling it when exiting the context manager.
198
200
 
199
- If the given coroutine raises an exception, that exception is raised
200
- when exiting the context manager.
201
+ If the context manager exits successfully but the given coroutine raises
202
+ an exception, that exception is reraised. The exception is suppressed
203
+ if the context manager raises an exception.
201
204
  """
202
205
  task = asyncio.create_task(coro)
203
206
 
204
207
  try:
205
208
  yield
206
- finally:
209
+
207
210
  if task.done() and (exception := task.exception()):
208
211
  raise exception
209
-
212
+ finally:
210
213
  task.cancel()