wandb 0.21.0__py3-none-win_amd64.whl → 0.21.2__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +16 -14
- wandb/__init__.pyi +427 -450
- wandb/agents/pyagent.py +41 -12
- wandb/analytics/sentry.py +7 -2
- wandb/apis/importers/mlflow.py +1 -1
- wandb/apis/public/__init__.py +1 -1
- wandb/apis/public/api.py +525 -360
- wandb/apis/public/artifacts.py +207 -13
- wandb/apis/public/automations.py +19 -3
- wandb/apis/public/files.py +172 -33
- wandb/apis/public/history.py +67 -15
- wandb/apis/public/integrations.py +25 -2
- wandb/apis/public/jobs.py +90 -2
- wandb/apis/public/projects.py +130 -79
- wandb/apis/public/query_generator.py +11 -1
- wandb/apis/public/registries/_utils.py +14 -16
- wandb/apis/public/registries/registries_search.py +183 -304
- wandb/apis/public/reports.py +96 -15
- wandb/apis/public/runs.py +299 -105
- wandb/apis/public/sweeps.py +222 -22
- wandb/apis/public/teams.py +41 -4
- wandb/apis/public/users.py +45 -4
- wandb/automations/_generated/delete_automation.py +1 -3
- wandb/automations/_generated/enums.py +13 -11
- wandb/beta/workflows.py +66 -30
- wandb/bin/gpu_stats.exe +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +127 -3
- wandb/env.py +8 -0
- wandb/errors/errors.py +4 -1
- wandb/integration/lightning/fabric/logger.py +3 -4
- wandb/integration/metaflow/__init__.py +6 -0
- wandb/integration/metaflow/data_pandas.py +74 -0
- wandb/integration/metaflow/data_pytorch.py +75 -0
- wandb/integration/metaflow/data_sklearn.py +76 -0
- wandb/integration/metaflow/errors.py +13 -0
- wandb/integration/metaflow/metaflow.py +167 -223
- wandb/integration/openai/fine_tuning.py +1 -2
- wandb/integration/weave/__init__.py +6 -0
- wandb/integration/weave/interface.py +49 -0
- wandb/integration/weave/weave.py +63 -0
- wandb/jupyter.py +5 -5
- wandb/plot/custom_chart.py +30 -7
- wandb/proto/v3/wandb_internal_pb2.py +281 -280
- wandb/proto/v3/wandb_telemetry_pb2.py +4 -4
- wandb/proto/v4/wandb_internal_pb2.py +280 -280
- wandb/proto/v4/wandb_telemetry_pb2.py +4 -4
- wandb/proto/v5/wandb_internal_pb2.py +280 -280
- wandb/proto/v5/wandb_telemetry_pb2.py +4 -4
- wandb/proto/v6/wandb_internal_pb2.py +280 -280
- wandb/proto/v6/wandb_telemetry_pb2.py +4 -4
- wandb/proto/wandb_deprecated.py +6 -0
- wandb/sdk/artifacts/_factories.py +17 -0
- wandb/sdk/artifacts/_generated/__init__.py +221 -13
- wandb/sdk/artifacts/_generated/artifact_by_id.py +17 -0
- wandb/sdk/artifacts/_generated/artifact_by_name.py +22 -0
- wandb/sdk/artifacts/_generated/artifact_collection_membership_file_urls.py +43 -0
- wandb/sdk/artifacts/_generated/artifact_created_by.py +47 -0
- wandb/sdk/artifacts/_generated/artifact_file_urls.py +22 -0
- wandb/sdk/artifacts/_generated/artifact_type.py +31 -0
- wandb/sdk/artifacts/_generated/artifact_used_by.py +43 -0
- wandb/sdk/artifacts/_generated/artifact_via_membership_by_name.py +26 -0
- wandb/sdk/artifacts/_generated/delete_artifact.py +28 -0
- wandb/sdk/artifacts/_generated/enums.py +5 -0
- wandb/sdk/artifacts/_generated/fetch_artifact_manifest.py +38 -0
- wandb/sdk/artifacts/_generated/fetch_registries.py +32 -0
- wandb/sdk/artifacts/_generated/fragments.py +279 -41
- wandb/sdk/artifacts/_generated/link_artifact.py +6 -0
- wandb/sdk/artifacts/_generated/operations.py +654 -51
- wandb/sdk/artifacts/_generated/registry_collections.py +34 -0
- wandb/sdk/artifacts/_generated/registry_versions.py +34 -0
- wandb/sdk/artifacts/_generated/unlink_artifact.py +25 -0
- wandb/sdk/artifacts/_graphql_fragments.py +3 -86
- wandb/sdk/artifacts/_internal_artifact.py +19 -8
- wandb/sdk/artifacts/_validators.py +14 -4
- wandb/sdk/artifacts/artifact.py +512 -618
- wandb/sdk/artifacts/artifact_file_cache.py +10 -6
- wandb/sdk/artifacts/artifact_manifest.py +10 -9
- wandb/sdk/artifacts/artifact_manifest_entry.py +9 -10
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +5 -3
- wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +1 -1
- wandb/sdk/data_types/audio.py +38 -10
- wandb/sdk/data_types/base_types/media.py +6 -56
- wandb/sdk/data_types/graph.py +48 -14
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +1 -3
- wandb/sdk/data_types/helper_types/image_mask.py +1 -3
- wandb/sdk/data_types/histogram.py +34 -21
- wandb/sdk/data_types/html.py +35 -12
- wandb/sdk/data_types/image.py +104 -68
- wandb/sdk/data_types/molecule.py +32 -19
- wandb/sdk/data_types/object_3d.py +36 -17
- wandb/sdk/data_types/plotly.py +18 -5
- wandb/sdk/data_types/saved_model.py +4 -6
- wandb/sdk/data_types/table.py +59 -30
- wandb/sdk/data_types/video.py +53 -26
- wandb/sdk/integration_utils/auto_logging.py +2 -2
- wandb/sdk/interface/interface_queue.py +1 -4
- wandb/sdk/interface/interface_shared.py +26 -37
- wandb/sdk/interface/interface_sock.py +24 -14
- wandb/sdk/internal/internal_api.py +6 -0
- wandb/sdk/internal/job_builder.py +6 -0
- wandb/sdk/internal/settings_static.py +2 -3
- wandb/sdk/launch/agent/agent.py +8 -1
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +2 -2
- wandb/sdk/launch/create_job.py +15 -2
- wandb/sdk/launch/inputs/internal.py +3 -4
- wandb/sdk/launch/inputs/schema.py +1 -0
- wandb/sdk/launch/runner/kubernetes_monitor.py +1 -0
- wandb/sdk/launch/runner/kubernetes_runner.py +323 -1
- wandb/sdk/launch/sweeps/scheduler.py +2 -3
- wandb/sdk/lib/asyncio_compat.py +19 -16
- wandb/sdk/lib/asyncio_manager.py +252 -0
- wandb/sdk/lib/deprecate.py +1 -7
- wandb/sdk/lib/disabled.py +1 -1
- wandb/sdk/lib/hashutil.py +27 -5
- wandb/sdk/lib/module.py +7 -13
- wandb/sdk/lib/printer.py +2 -2
- wandb/sdk/lib/printer_asyncio.py +3 -1
- wandb/sdk/lib/progress.py +0 -19
- wandb/sdk/lib/retry.py +185 -78
- wandb/sdk/lib/service/service_client.py +106 -0
- wandb/sdk/lib/service/service_connection.py +20 -26
- wandb/sdk/lib/service/service_token.py +30 -13
- wandb/sdk/mailbox/mailbox.py +13 -5
- wandb/sdk/mailbox/mailbox_handle.py +22 -13
- wandb/sdk/mailbox/response_handle.py +42 -106
- wandb/sdk/mailbox/wait_with_progress.py +7 -42
- wandb/sdk/wandb_init.py +77 -116
- wandb/sdk/wandb_login.py +19 -15
- wandb/sdk/wandb_metric.py +2 -0
- wandb/sdk/wandb_run.py +497 -469
- wandb/sdk/wandb_settings.py +145 -4
- wandb/sdk/wandb_setup.py +204 -124
- wandb/sdk/wandb_sweep.py +14 -13
- wandb/sdk/wandb_watch.py +4 -6
- wandb/sync/sync.py +10 -0
- wandb/util.py +58 -1
- wandb/wandb_run.py +1 -2
- {wandb-0.21.0.dist-info → wandb-0.21.2.dist-info}/METADATA +1 -1
- {wandb-0.21.0.dist-info → wandb-0.21.2.dist-info}/RECORD +145 -129
- wandb/sdk/interface/interface_relay.py +0 -38
- wandb/sdk/interface/router.py +0 -89
- wandb/sdk/interface/router_queue.py +0 -43
- wandb/sdk/interface/router_relay.py +0 -50
- wandb/sdk/interface/router_sock.py +0 -32
- wandb/sdk/lib/sock_client.py +0 -236
- wandb/vendor/pynvml/__init__.py +0 -0
- wandb/vendor/pynvml/pynvml.py +0 -4779
- {wandb-0.21.0.dist-info → wandb-0.21.2.dist-info}/WHEEL +0 -0
- {wandb-0.21.0.dist-info → wandb-0.21.2.dist-info}/entry_points.txt +0 -0
- {wandb-0.21.0.dist-info → wandb-0.21.2.dist-info}/licenses/LICENSE +0 -0
@@ -6,6 +6,7 @@ import datetime
|
|
6
6
|
import json
|
7
7
|
import logging
|
8
8
|
import os
|
9
|
+
import time
|
9
10
|
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
10
11
|
|
11
12
|
import yaml
|
@@ -20,11 +21,13 @@ from wandb.sdk.launch.registry.local_registry import LocalRegistry
|
|
20
21
|
from wandb.sdk.launch.runner.abstract import Status
|
21
22
|
from wandb.sdk.launch.runner.kubernetes_monitor import (
|
22
23
|
WANDB_K8S_LABEL_AGENT,
|
24
|
+
WANDB_K8S_LABEL_AUXILIARY_RESOURCE,
|
23
25
|
WANDB_K8S_LABEL_MONITOR,
|
24
26
|
WANDB_K8S_RUN_ID,
|
25
27
|
CustomResource,
|
26
28
|
LaunchKubernetesMonitor,
|
27
29
|
)
|
30
|
+
from wandb.sdk.launch.utils import recursive_macro_sub
|
28
31
|
from wandb.sdk.lib.retry import ExponentialBackoff, retry_async
|
29
32
|
from wandb.util import get_module
|
30
33
|
|
@@ -47,6 +50,9 @@ get_module(
|
|
47
50
|
|
48
51
|
import kubernetes_asyncio # type: ignore # noqa: E402
|
49
52
|
from kubernetes_asyncio import client # noqa: E402
|
53
|
+
from kubernetes_asyncio.client.api.apps_v1_api import ( # type: ignore # noqa: E402
|
54
|
+
AppsV1Api,
|
55
|
+
)
|
50
56
|
from kubernetes_asyncio.client.api.batch_v1_api import ( # type: ignore # noqa: E402
|
51
57
|
BatchV1Api,
|
52
58
|
)
|
@@ -56,6 +62,9 @@ from kubernetes_asyncio.client.api.core_v1_api import ( # type: ignore # noqa:
|
|
56
62
|
from kubernetes_asyncio.client.api.custom_objects_api import ( # type: ignore # noqa: E402
|
57
63
|
CustomObjectsApi,
|
58
64
|
)
|
65
|
+
from kubernetes_asyncio.client.api.networking_v1_api import ( # type: ignore # noqa: E402
|
66
|
+
NetworkingV1Api,
|
67
|
+
)
|
59
68
|
from kubernetes_asyncio.client.models.v1_secret import ( # type: ignore # noqa: E402
|
60
69
|
V1Secret,
|
61
70
|
)
|
@@ -78,9 +87,12 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
78
87
|
self,
|
79
88
|
batch_api: "BatchV1Api",
|
80
89
|
core_api: "CoreV1Api",
|
90
|
+
apps_api: "AppsV1Api",
|
91
|
+
network_api: "NetworkingV1Api",
|
81
92
|
name: str,
|
82
93
|
namespace: Optional[str] = "default",
|
83
94
|
secret: Optional["V1Secret"] = None,
|
95
|
+
auxiliary_resource_label_key: Optional[str] = None,
|
84
96
|
) -> None:
|
85
97
|
"""Initialize a KubernetesSubmittedRun.
|
86
98
|
|
@@ -95,6 +107,7 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
95
107
|
Arguments:
|
96
108
|
batch_api: Kubernetes BatchV1Api object.
|
97
109
|
core_api: Kubernetes CoreV1Api object.
|
110
|
+
network_api: Kubernetes NetworkV1Api object.
|
98
111
|
name: Name of the job.
|
99
112
|
namespace: Kubernetes namespace.
|
100
113
|
secret: Kubernetes secret.
|
@@ -104,10 +117,13 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
104
117
|
"""
|
105
118
|
self.batch_api = batch_api
|
106
119
|
self.core_api = core_api
|
120
|
+
self.apps_api = apps_api
|
121
|
+
self.network_api = network_api
|
107
122
|
self.name = name
|
108
123
|
self.namespace = namespace
|
109
124
|
self._fail_count = 0
|
110
125
|
self.secret = secret
|
126
|
+
self.auxiliary_resource_label_key = auxiliary_resource_label_key
|
111
127
|
|
112
128
|
@property
|
113
129
|
def id(self) -> str:
|
@@ -149,6 +165,7 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
149
165
|
await asyncio.sleep(5)
|
150
166
|
|
151
167
|
await self._delete_secret()
|
168
|
+
await self._delete_auxiliary_resources_by_label()
|
152
169
|
return (
|
153
170
|
status.state == "finished"
|
154
171
|
) # todo: not sure if this (copied from aws runner) is the right approach? should we return false on failure
|
@@ -157,6 +174,7 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
157
174
|
status = LaunchKubernetesMonitor.get_status(self.name)
|
158
175
|
if status in ["stopped", "failed", "finished", "preempted"]:
|
159
176
|
await self._delete_secret()
|
177
|
+
await self._delete_auxiliary_resources_by_label()
|
160
178
|
return status
|
161
179
|
|
162
180
|
async def cancel(self) -> None:
|
@@ -167,6 +185,7 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
167
185
|
name=self.name,
|
168
186
|
)
|
169
187
|
await self._delete_secret()
|
188
|
+
await self._delete_auxiliary_resources_by_label()
|
170
189
|
except ApiException as e:
|
171
190
|
raise LaunchError(
|
172
191
|
f"Failed to delete Kubernetes Job {self.name} in namespace {self.namespace}: {str(e)}"
|
@@ -181,6 +200,50 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
181
200
|
)
|
182
201
|
self.secret = None
|
183
202
|
|
203
|
+
async def _delete_auxiliary_resources_by_label(self) -> None:
|
204
|
+
if self.auxiliary_resource_label_key is None:
|
205
|
+
return
|
206
|
+
|
207
|
+
label_selector = (
|
208
|
+
f"{WANDB_K8S_LABEL_AUXILIARY_RESOURCE}={self.auxiliary_resource_label_key}"
|
209
|
+
)
|
210
|
+
|
211
|
+
try:
|
212
|
+
resource_cleanups = [
|
213
|
+
(self.core_api, "service"),
|
214
|
+
(self.batch_api, "job"),
|
215
|
+
(self.core_api, "pod"),
|
216
|
+
(self.core_api, "secret"),
|
217
|
+
(self.apps_api, "deployment"),
|
218
|
+
(self.network_api, "network_policy"),
|
219
|
+
]
|
220
|
+
|
221
|
+
for api_client, resource_type in resource_cleanups:
|
222
|
+
try:
|
223
|
+
list_method = getattr(
|
224
|
+
api_client, f"list_namespaced_{resource_type}"
|
225
|
+
)
|
226
|
+
delete_method = getattr(
|
227
|
+
api_client, f"delete_namespaced_{resource_type}"
|
228
|
+
)
|
229
|
+
|
230
|
+
# List resources with our label
|
231
|
+
resources = await list_method(
|
232
|
+
namespace=self.namespace, label_selector=label_selector
|
233
|
+
)
|
234
|
+
|
235
|
+
# Delete each resource
|
236
|
+
for resource in resources.items:
|
237
|
+
await delete_method(
|
238
|
+
name=resource.metadata.name, namespace=self.namespace
|
239
|
+
)
|
240
|
+
|
241
|
+
except (AttributeError, ApiException) as e:
|
242
|
+
wandb.termwarn(f"Could not clean up {resource_type}: {e}")
|
243
|
+
|
244
|
+
except Exception as e:
|
245
|
+
wandb.termwarn(f"Failed to clean up some auxiliary resources: {e}")
|
246
|
+
|
184
247
|
|
185
248
|
class CrdSubmittedRun(AbstractRun):
|
186
249
|
"""Run submitted to a CRD backend, e.g. Volcano."""
|
@@ -366,6 +429,7 @@ class KubernetesRunner(AbstractRunner):
|
|
366
429
|
job_metadata["generateName"] = make_name_dns_safe(
|
367
430
|
f"launch-{launch_project.target_entity}-{launch_project.target_project}-"
|
368
431
|
)
|
432
|
+
job_metadata["namespace"] = namespace
|
369
433
|
|
370
434
|
for i, cont in enumerate(containers):
|
371
435
|
if "name" not in cont:
|
@@ -489,6 +553,216 @@ class KubernetesRunner(AbstractRunner):
|
|
489
553
|
|
490
554
|
return job, api_key_secret
|
491
555
|
|
556
|
+
async def _wait_for_resource_ready(
|
557
|
+
self,
|
558
|
+
api_client: kubernetes_asyncio.client.ApiClient,
|
559
|
+
config: Dict[str, Any],
|
560
|
+
namespace: str,
|
561
|
+
timeout_seconds: int = 300,
|
562
|
+
) -> None:
|
563
|
+
"""Wait for a Kubernetes resource to be ready.
|
564
|
+
|
565
|
+
Arguments:
|
566
|
+
api_client: The Kubernetes API client.
|
567
|
+
config: The resource configuration.
|
568
|
+
namespace: The namespace where the resource was created.
|
569
|
+
timeout_seconds: Maximum time to wait for readiness.
|
570
|
+
"""
|
571
|
+
resource_kind = config.get("kind")
|
572
|
+
resource_name = config.get("metadata", {}).get("name")
|
573
|
+
|
574
|
+
if not resource_kind or not resource_name:
|
575
|
+
wandb.termerror(
|
576
|
+
f"{LOG_PREFIX}Cannot wait for resource without kind or name"
|
577
|
+
)
|
578
|
+
return
|
579
|
+
|
580
|
+
wandb.termlog(
|
581
|
+
f"{LOG_PREFIX}Waiting for {resource_kind} '{resource_name}' to be ready..."
|
582
|
+
)
|
583
|
+
|
584
|
+
start_time = time.time()
|
585
|
+
|
586
|
+
if resource_kind == "Deployment":
|
587
|
+
await self._wait_for_deployment_ready(
|
588
|
+
api_client, resource_name, namespace, timeout_seconds
|
589
|
+
)
|
590
|
+
elif resource_kind == "Service":
|
591
|
+
await self._wait_for_service_ready(
|
592
|
+
api_client, resource_name, namespace, timeout_seconds
|
593
|
+
)
|
594
|
+
elif resource_kind == "Pod":
|
595
|
+
await self._wait_for_pod_ready(
|
596
|
+
api_client, resource_name, namespace, timeout_seconds
|
597
|
+
)
|
598
|
+
else:
|
599
|
+
wandb.termlog(
|
600
|
+
f"{LOG_PREFIX}No specific readiness check for {resource_kind}, waiting 5 seconds..."
|
601
|
+
)
|
602
|
+
await asyncio.sleep(5)
|
603
|
+
|
604
|
+
elapsed = time.time() - start_time
|
605
|
+
wandb.termlog(
|
606
|
+
f"{LOG_PREFIX}{resource_kind} '{resource_name}' is ready after {elapsed:.1f}s"
|
607
|
+
)
|
608
|
+
|
609
|
+
async def _wait_for_deployment_ready(
|
610
|
+
self,
|
611
|
+
api_client: kubernetes_asyncio.client.ApiClient,
|
612
|
+
name: str,
|
613
|
+
namespace: str,
|
614
|
+
timeout_seconds: int,
|
615
|
+
) -> None:
|
616
|
+
"""Wait for a Deployment to be ready."""
|
617
|
+
apps_api = kubernetes_asyncio.client.AppsV1Api(api_client)
|
618
|
+
|
619
|
+
async def check_deployment_ready():
|
620
|
+
deployment = await apps_api.read_namespaced_deployment(
|
621
|
+
name=name, namespace=namespace
|
622
|
+
)
|
623
|
+
status = deployment.status
|
624
|
+
|
625
|
+
if status.ready_replicas and status.replicas:
|
626
|
+
return status.ready_replicas >= status.replicas
|
627
|
+
|
628
|
+
return False
|
629
|
+
|
630
|
+
await self._wait_with_timeout(check_deployment_ready, timeout_seconds, name)
|
631
|
+
|
632
|
+
async def _wait_for_service_ready(
|
633
|
+
self,
|
634
|
+
api_client: kubernetes_asyncio.client.ApiClient,
|
635
|
+
name: str,
|
636
|
+
namespace: str,
|
637
|
+
timeout_seconds: int,
|
638
|
+
) -> None:
|
639
|
+
"""Wait for a Service to have endpoints."""
|
640
|
+
core_api = kubernetes_asyncio.client.CoreV1Api(api_client)
|
641
|
+
|
642
|
+
async def check_service_ready():
|
643
|
+
endpoints = await core_api.read_namespaced_endpoints(
|
644
|
+
name=name, namespace=namespace
|
645
|
+
)
|
646
|
+
if endpoints.subsets:
|
647
|
+
for subset in endpoints.subsets:
|
648
|
+
if subset.addresses: # These are ready pod addresses
|
649
|
+
return True
|
650
|
+
return False
|
651
|
+
|
652
|
+
await self._wait_with_timeout(check_service_ready, timeout_seconds, name)
|
653
|
+
|
654
|
+
async def _wait_for_pod_ready(
|
655
|
+
self,
|
656
|
+
api_client: kubernetes_asyncio.client.ApiClient,
|
657
|
+
name: str,
|
658
|
+
namespace: str,
|
659
|
+
timeout_seconds: int,
|
660
|
+
) -> None:
|
661
|
+
"""Wait for a Pod to be ready."""
|
662
|
+
core_api = kubernetes_asyncio.client.CoreV1Api(api_client)
|
663
|
+
|
664
|
+
async def check_pod_ready():
|
665
|
+
pod = await core_api.read_namespaced_pod(name=name, namespace=namespace)
|
666
|
+
if pod.status.phase == "Running":
|
667
|
+
if pod.status.container_statuses:
|
668
|
+
return all(status.ready for status in pod.status.container_statuses)
|
669
|
+
return True
|
670
|
+
return False
|
671
|
+
|
672
|
+
await self._wait_with_timeout(check_pod_ready, timeout_seconds, name)
|
673
|
+
|
674
|
+
async def _wait_with_timeout(
|
675
|
+
self, check_func, timeout_seconds: int, name: str
|
676
|
+
) -> None:
|
677
|
+
"""Generic timeout wrapper for readiness checks."""
|
678
|
+
start_time = time.time()
|
679
|
+
|
680
|
+
while time.time() - start_time < timeout_seconds:
|
681
|
+
try:
|
682
|
+
if await check_func():
|
683
|
+
return
|
684
|
+
except kubernetes_asyncio.client.ApiException as e:
|
685
|
+
if e.status == 404:
|
686
|
+
pass
|
687
|
+
else:
|
688
|
+
wandb.termerror(
|
689
|
+
f"{LOG_PREFIX}Error waiting for resource '{name}': {e}"
|
690
|
+
)
|
691
|
+
raise
|
692
|
+
except Exception as e:
|
693
|
+
wandb.termerror(f"{LOG_PREFIX}Error waiting for resource '{name}': {e}")
|
694
|
+
raise
|
695
|
+
await asyncio.sleep(2)
|
696
|
+
|
697
|
+
raise LaunchError(
|
698
|
+
f"Resource '{name}' not ready within {timeout_seconds} seconds"
|
699
|
+
)
|
700
|
+
|
701
|
+
async def _prepare_resource(
|
702
|
+
self,
|
703
|
+
api_client: kubernetes_asyncio.client.ApiClient,
|
704
|
+
config: Dict[str, Any],
|
705
|
+
namespace: str,
|
706
|
+
run_id: str,
|
707
|
+
launch_project: LaunchProject,
|
708
|
+
api_key_secret: Optional["V1Secret"] = None,
|
709
|
+
wait_for_ready: bool = True,
|
710
|
+
wait_timeout: int = 300,
|
711
|
+
) -> None:
|
712
|
+
"""Prepare a service for launch.
|
713
|
+
|
714
|
+
Arguments:
|
715
|
+
api_client: The Kubernetes API client.
|
716
|
+
config: The resource configuration to prepare.
|
717
|
+
namespace: The namespace to create the resource in.
|
718
|
+
run_id: The run ID to label the resource with.
|
719
|
+
launch_project: The launch project to get environment variables from.
|
720
|
+
api_key_secret: The API key secret to inject.
|
721
|
+
wait_for_ready: Whether to wait for the resource to be ready after creation.
|
722
|
+
wait_timeout: Maximum time in seconds to wait for resource readiness.
|
723
|
+
"""
|
724
|
+
config.setdefault("metadata", {})
|
725
|
+
config["metadata"].setdefault("labels", {})
|
726
|
+
config["metadata"]["labels"][WANDB_K8S_RUN_ID] = run_id
|
727
|
+
config["metadata"]["labels"]["wandb.ai/created-by"] = "launch-agent"
|
728
|
+
|
729
|
+
env_vars = launch_project.get_env_vars_dict(
|
730
|
+
self._api, MAX_ENV_LENGTHS[self.__class__.__name__]
|
731
|
+
)
|
732
|
+
wandb_config_env = {
|
733
|
+
"WANDB_CONFIG": env_vars.get("WANDB_CONFIG", "{}"),
|
734
|
+
}
|
735
|
+
add_wandb_env(config, wandb_config_env)
|
736
|
+
|
737
|
+
if api_key_secret:
|
738
|
+
for cont in yield_containers(config):
|
739
|
+
env = cont.setdefault("env", [])
|
740
|
+
env.append(
|
741
|
+
{
|
742
|
+
"name": "WANDB_API_KEY",
|
743
|
+
"valueFrom": {
|
744
|
+
"secretKeyRef": {
|
745
|
+
"name": api_key_secret.metadata.name,
|
746
|
+
"key": "password",
|
747
|
+
}
|
748
|
+
},
|
749
|
+
}
|
750
|
+
)
|
751
|
+
cont["env"] = env
|
752
|
+
|
753
|
+
try:
|
754
|
+
await kubernetes_asyncio.utils.create_from_dict(
|
755
|
+
api_client, config, namespace=namespace
|
756
|
+
)
|
757
|
+
|
758
|
+
if wait_for_ready:
|
759
|
+
await self._wait_for_resource_ready(
|
760
|
+
api_client, config, namespace, wait_timeout
|
761
|
+
)
|
762
|
+
except Exception as e:
|
763
|
+
wandb.termerror(f"{LOG_PREFIX}Failed to create Kubernetes resource: {e}")
|
764
|
+
raise LaunchError(f"Failed to create Kubernetes resource: {e}")
|
765
|
+
|
492
766
|
async def run(
|
493
767
|
self, launch_project: LaunchProject, image_uri: str
|
494
768
|
) -> Optional[AbstractRun]:
|
@@ -630,10 +904,51 @@ class KubernetesRunner(AbstractRunner):
|
|
630
904
|
|
631
905
|
batch_api = kubernetes_asyncio.client.BatchV1Api(api_client)
|
632
906
|
core_api = kubernetes_asyncio.client.CoreV1Api(api_client)
|
907
|
+
apps_api = kubernetes_asyncio.client.AppsV1Api(api_client)
|
908
|
+
network_api = kubernetes_asyncio.client.NetworkingV1Api(api_client)
|
909
|
+
|
633
910
|
namespace = self.get_namespace(resource_args, context)
|
634
911
|
job, secret = await self._inject_defaults(
|
635
912
|
resource_args, launch_project, image_uri, namespace, core_api
|
636
913
|
)
|
914
|
+
|
915
|
+
update_dict = {
|
916
|
+
"project_name": launch_project.target_project,
|
917
|
+
"entity_name": launch_project.target_entity,
|
918
|
+
"run_id": launch_project.run_id,
|
919
|
+
"run_name": launch_project.name,
|
920
|
+
"image_uri": image_uri,
|
921
|
+
"author": launch_project.author,
|
922
|
+
}
|
923
|
+
update_dict.update(os.environ)
|
924
|
+
additional_services: List[Dict[str, Any]] = recursive_macro_sub(
|
925
|
+
launch_project.launch_spec.get("additional_services", []), update_dict
|
926
|
+
)
|
927
|
+
if additional_services:
|
928
|
+
wandb.termlog(
|
929
|
+
f"{LOG_PREFIX}Creating additional services: {additional_services}"
|
930
|
+
)
|
931
|
+
|
932
|
+
wait_for_ready = resource_args.get("wait_for_ready", True)
|
933
|
+
wait_timeout = resource_args.get("wait_timeout", 300)
|
934
|
+
|
935
|
+
await asyncio.gather(
|
936
|
+
*[
|
937
|
+
self._prepare_resource(
|
938
|
+
api_client,
|
939
|
+
resource.get("config", {}),
|
940
|
+
namespace,
|
941
|
+
launch_project.run_id,
|
942
|
+
launch_project,
|
943
|
+
secret,
|
944
|
+
wait_for_ready,
|
945
|
+
wait_timeout,
|
946
|
+
)
|
947
|
+
for resource in additional_services
|
948
|
+
if resource.get("config", {})
|
949
|
+
]
|
950
|
+
)
|
951
|
+
|
637
952
|
msg = "Creating Kubernetes job"
|
638
953
|
if "name" in resource_args:
|
639
954
|
msg += f": {resource_args['name']}"
|
@@ -658,7 +973,14 @@ class KubernetesRunner(AbstractRunner):
|
|
658
973
|
job_name = job_response.metadata.name
|
659
974
|
LaunchKubernetesMonitor.monitor_namespace(namespace)
|
660
975
|
submitted_job = KubernetesSubmittedRun(
|
661
|
-
batch_api,
|
976
|
+
batch_api,
|
977
|
+
core_api,
|
978
|
+
apps_api,
|
979
|
+
network_api,
|
980
|
+
job_name,
|
981
|
+
namespace,
|
982
|
+
secret,
|
983
|
+
f"aux-{launch_project.target_entity}-{launch_project.target_project}-{launch_project.run_id}",
|
662
984
|
)
|
663
985
|
if self.backend_config[PROJECT_SYNCHRONOUS]:
|
664
986
|
await submitted_job.wait()
|
@@ -36,7 +36,6 @@ if TYPE_CHECKING:
|
|
36
36
|
import wandb.apis.public as public
|
37
37
|
from wandb.apis.internal import Api
|
38
38
|
from wandb.apis.public import QueuedRun, Run
|
39
|
-
from wandb.sdk.wandb_run import Run as SdkRun
|
40
39
|
|
41
40
|
|
42
41
|
_logger = logging.getLogger(__name__)
|
@@ -255,10 +254,10 @@ class Scheduler(ABC):
|
|
255
254
|
_id: w for _id, w in self._workers.items() if _id not in self.busy_workers
|
256
255
|
}
|
257
256
|
|
258
|
-
def _init_wandb_run(self) -> "
|
257
|
+
def _init_wandb_run(self) -> "wandb.Run":
|
259
258
|
"""Controls resume or init logic for a scheduler wandb run."""
|
260
259
|
settings = wandb.Settings(disable_job_creation=True)
|
261
|
-
run:
|
260
|
+
run: wandb.Run = wandb.init( # type: ignore
|
262
261
|
name=f"Scheduler.{self._sweep_id}",
|
263
262
|
resume="allow",
|
264
263
|
config=self._kwargs, # when run as a job, this sets config
|
wandb/sdk/lib/asyncio_compat.py
CHANGED
@@ -23,7 +23,7 @@ def run(fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
|
|
23
23
|
Note that due to starting a new thread, this is slightly slow.
|
24
24
|
"""
|
25
25
|
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
26
|
-
runner =
|
26
|
+
runner = CancellableRunner()
|
27
27
|
future = executor.submit(runner.run, fn)
|
28
28
|
|
29
29
|
try:
|
@@ -33,15 +33,16 @@ def run(fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
|
|
33
33
|
runner.cancel()
|
34
34
|
|
35
35
|
|
36
|
-
class
|
37
|
-
"""The `
|
36
|
+
class RunnerCancelledError(Exception):
|
37
|
+
"""The `CancellableRunner.run()` invocation was cancelled."""
|
38
38
|
|
39
39
|
|
40
|
-
class
|
40
|
+
class CancellableRunner:
|
41
41
|
"""Runs an asyncio event loop allowing cancellation.
|
42
42
|
|
43
|
-
|
44
|
-
|
43
|
+
The `run()` method is like `asyncio.run()`. The `cancel()` method may
|
44
|
+
be used in a different thread, for instance in a `finally` block, to cancel
|
45
|
+
all tasks, and it is a no-op if `run()` completed.
|
45
46
|
|
46
47
|
Without this, it is impossible to make `asyncio.run()` stop if it runs
|
47
48
|
in a non-main thread. In particular, a KeyboardInterrupt causes the
|
@@ -69,7 +70,7 @@ class _Runner:
|
|
69
70
|
The result of the coroutine returned by `fn`.
|
70
71
|
|
71
72
|
Raises:
|
72
|
-
|
73
|
+
RunnerCancelledError: If `cancel()` is called.
|
73
74
|
"""
|
74
75
|
return asyncio.run(self._run_or_cancel(fn))
|
75
76
|
|
@@ -79,7 +80,7 @@ class _Runner:
|
|
79
80
|
) -> _T:
|
80
81
|
with self._lock:
|
81
82
|
if self._is_cancelled:
|
82
|
-
raise
|
83
|
+
raise RunnerCancelledError()
|
83
84
|
|
84
85
|
self._loop = asyncio.get_running_loop()
|
85
86
|
self._cancel_event = asyncio.Event()
|
@@ -97,9 +98,12 @@ class _Runner:
|
|
97
98
|
if fn_task.done():
|
98
99
|
return fn_task.result()
|
99
100
|
else:
|
100
|
-
raise
|
101
|
+
raise RunnerCancelledError()
|
101
102
|
|
102
103
|
finally:
|
104
|
+
# NOTE: asyncio.run() cancels all tasks after the main task exits,
|
105
|
+
# but this is not documented, so we cancel them explicitly here
|
106
|
+
# as well. It also blocks until canceled tasks complete.
|
103
107
|
cancellation_task.cancel()
|
104
108
|
fn_task.cancel()
|
105
109
|
|
@@ -154,11 +158,9 @@ class TaskGroup:
|
|
154
158
|
)
|
155
159
|
|
156
160
|
for task in done:
|
157
|
-
|
161
|
+
with contextlib.suppress(asyncio.CancelledError):
|
158
162
|
if exc := task.exception():
|
159
163
|
raise exc
|
160
|
-
except asyncio.CancelledError:
|
161
|
-
pass
|
162
164
|
|
163
165
|
def _cancel_all(self) -> None:
|
164
166
|
"""Cancel all tasks."""
|
@@ -196,15 +198,16 @@ async def open_task_group() -> AsyncIterator[TaskGroup]:
|
|
196
198
|
def cancel_on_exit(coro: Coroutine[Any, Any, Any]) -> Iterator[None]:
|
197
199
|
"""Schedule a task, cancelling it when exiting the context manager.
|
198
200
|
|
199
|
-
If the
|
200
|
-
|
201
|
+
If the context manager exits successfully but the given coroutine raises
|
202
|
+
an exception, that exception is reraised. The exception is suppressed
|
203
|
+
if the context manager raises an exception.
|
201
204
|
"""
|
202
205
|
task = asyncio.create_task(coro)
|
203
206
|
|
204
207
|
try:
|
205
208
|
yield
|
206
|
-
|
209
|
+
|
207
210
|
if task.done() and (exception := task.exception()):
|
208
211
|
raise exception
|
209
|
-
|
212
|
+
finally:
|
210
213
|
task.cancel()
|