wandb 0.15.10__py3-none-any.whl → 0.15.11__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +2 -1
- wandb/apis/public.py +51 -9
- wandb/apis/reports/blocks.py +1 -0
- wandb/cli/cli.py +14 -9
- wandb/env.py +11 -1
- wandb/integration/xgboost/xgboost.py +3 -3
- wandb/proto/v3/wandb_internal_pb2.py +300 -267
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +16 -16
- wandb/proto/v4/wandb_internal_pb2.py +260 -252
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +16 -16
- wandb/sdk/artifacts/artifact.py +9 -6
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +12 -7
- wandb/sdk/data_types/image.py +1 -1
- wandb/sdk/internal/file_stream.py +2 -1
- wandb/sdk/internal/handler.py +24 -20
- wandb/sdk/internal/internal_api.py +9 -1
- wandb/sdk/internal/sender.py +4 -1
- wandb/sdk/internal/system/system_info.py +2 -2
- wandb/sdk/launch/__init__.py +5 -0
- wandb/sdk/launch/{launch.py → _launch.py} +53 -54
- wandb/sdk/launch/{launch_add.py → _launch_add.py} +34 -31
- wandb/sdk/launch/agent/agent.py +36 -18
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +6 -4
- wandb/sdk/launch/runner/abstract.py +0 -2
- wandb/sdk/launch/runner/kubernetes_monitor.py +329 -0
- wandb/sdk/launch/runner/kubernetes_runner.py +44 -301
- wandb/sdk/launch/runner/local_container.py +5 -2
- wandb/sdk/launch/sweeps/scheduler.py +14 -10
- wandb/sdk/launch/sweeps/utils.py +5 -3
- wandb/sdk/launch/utils.py +3 -1
- wandb/sdk/lib/_settings_toposort_generated.py +5 -0
- wandb/sdk/lib/gql_request.py +3 -0
- wandb/sdk/lib/ipython.py +4 -0
- wandb/sdk/service/service.py +19 -6
- wandb/sdk/wandb_init.py +7 -2
- wandb/sdk/wandb_run.py +2 -5
- wandb/sdk/wandb_settings.py +48 -2
- wandb/util.py +1 -1
- {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/METADATA +4 -1
- {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/RECORD +46 -45
- {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/LICENSE +0 -0
- {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/WHEEL +0 -0
- {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/entry_points.txt +0 -0
- {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/top_level.txt +0 -0
@@ -4,10 +4,9 @@ import base64
|
|
4
4
|
import json
|
5
5
|
import logging
|
6
6
|
import time
|
7
|
-
from threading import Lock, Thread
|
8
7
|
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
9
8
|
|
10
|
-
import urllib3
|
9
|
+
import yaml
|
11
10
|
|
12
11
|
import wandb
|
13
12
|
from wandb.apis.internal import Api
|
@@ -15,7 +14,7 @@ from wandb.sdk.launch.environment.abstract import AbstractEnvironment
|
|
15
14
|
from wandb.sdk.launch.registry.abstract import AbstractRegistry
|
16
15
|
from wandb.sdk.launch.registry.azure_container_registry import AzureContainerRegistry
|
17
16
|
from wandb.sdk.launch.registry.local_registry import LocalRegistry
|
18
|
-
from wandb.sdk.launch.runner.abstract import State, Status
|
17
|
+
from wandb.sdk.launch.runner.abstract import Status
|
19
18
|
from wandb.util import get_module
|
20
19
|
|
21
20
|
from .._project_spec import EntryPoint, LaunchProject
|
@@ -29,22 +28,20 @@ from ..utils import (
|
|
29
28
|
make_name_dns_safe,
|
30
29
|
)
|
31
30
|
from .abstract import AbstractRun, AbstractRunner
|
31
|
+
from .kubernetes_monitor import KubernetesRunMonitor
|
32
32
|
|
33
33
|
get_module(
|
34
34
|
"kubernetes",
|
35
35
|
required="Kubernetes runner requires the kubernetes package. Please install it with `pip install wandb[launch]`.",
|
36
36
|
)
|
37
37
|
|
38
|
-
from kubernetes import client, watch # type: ignore # noqa: E402
|
38
|
+
from kubernetes import client # type: ignore # noqa: E402
|
39
39
|
from kubernetes.client.api.batch_v1_api import BatchV1Api # type: ignore # noqa: E402
|
40
40
|
from kubernetes.client.api.core_v1_api import CoreV1Api # type: ignore # noqa: E402
|
41
41
|
from kubernetes.client.api.custom_objects_api import ( # type: ignore # noqa: E402
|
42
42
|
CustomObjectsApi,
|
43
43
|
)
|
44
44
|
from kubernetes.client.models.v1_job import V1Job # type: ignore # noqa: E402
|
45
|
-
from kubernetes.client.models.v1_pod_status import ( # type: ignore # noqa: E402
|
46
|
-
V1PodStatus,
|
47
|
-
)
|
48
45
|
from kubernetes.client.models.v1_secret import V1Secret # type: ignore # noqa: E402
|
49
46
|
from kubernetes.client.rest import ApiException # type: ignore # noqa: E402
|
50
47
|
|
@@ -53,180 +50,6 @@ TIMEOUT = 5
|
|
53
50
|
_logger = logging.getLogger(__name__)
|
54
51
|
|
55
52
|
|
56
|
-
# Dict for mapping possible states of custom objects to the states we want to report
|
57
|
-
# to the agent.
|
58
|
-
CRD_STATE_DICT: Dict[str, State] = {
|
59
|
-
"pending": "starting",
|
60
|
-
"running": "running",
|
61
|
-
"completed": "finished",
|
62
|
-
"failed": "failed",
|
63
|
-
"aborted": "failed",
|
64
|
-
"terminating": "stopping",
|
65
|
-
"terminated": "stopped",
|
66
|
-
}
|
67
|
-
|
68
|
-
|
69
|
-
def _is_preempted(status: "V1PodStatus") -> bool:
|
70
|
-
"""Check if this pod has been preempted."""
|
71
|
-
if hasattr(status, "conditions") and status.conditions is not None:
|
72
|
-
for condition in status.conditions:
|
73
|
-
if condition.type == "DisruptionTarget" and condition.reason in [
|
74
|
-
"EvictionByEvictionAPI",
|
75
|
-
"PreemptionByScheduler",
|
76
|
-
"TerminationByKubelet",
|
77
|
-
]:
|
78
|
-
return True
|
79
|
-
return False
|
80
|
-
|
81
|
-
|
82
|
-
def _is_container_creating(status: "V1PodStatus") -> bool:
|
83
|
-
"""Check if this pod has started creating containers."""
|
84
|
-
for container_status in status.container_statuses or []:
|
85
|
-
if (
|
86
|
-
container_status.state
|
87
|
-
and container_status.state.waiting
|
88
|
-
and container_status.state.waiting.reason == "ContainerCreating"
|
89
|
-
):
|
90
|
-
return True
|
91
|
-
return False
|
92
|
-
|
93
|
-
|
94
|
-
class KubernetesRunMonitor:
|
95
|
-
def __init__(
|
96
|
-
self,
|
97
|
-
job_field_selector: str,
|
98
|
-
pod_label_selector: str,
|
99
|
-
namespace: str,
|
100
|
-
batch_api: "BatchV1Api",
|
101
|
-
core_api: "CoreV1Api",
|
102
|
-
) -> None:
|
103
|
-
"""Initial KubernetesRunMonitor.
|
104
|
-
|
105
|
-
Arguments:
|
106
|
-
jobname: Name of the job.
|
107
|
-
|
108
|
-
Returns:
|
109
|
-
None.
|
110
|
-
"""
|
111
|
-
self.pod_label_selector = pod_label_selector
|
112
|
-
self.job_field_selector = job_field_selector
|
113
|
-
self.namespace = namespace
|
114
|
-
self.batch_api = batch_api
|
115
|
-
self.core_api = core_api
|
116
|
-
|
117
|
-
self._status_lock = Lock()
|
118
|
-
self._status = Status("starting")
|
119
|
-
|
120
|
-
self._watch_job_thread = Thread(target=self._watch_job, daemon=True)
|
121
|
-
self._watch_pods_thread = Thread(target=self._watch_pods, daemon=True)
|
122
|
-
|
123
|
-
self._job_watcher = watch.Watch()
|
124
|
-
self._pod_watcher = watch.Watch()
|
125
|
-
|
126
|
-
def start(self) -> None:
|
127
|
-
"""Start the run monitor."""
|
128
|
-
if self._watch_job_thread.is_alive() or self._watch_pods_thread.is_alive():
|
129
|
-
raise LaunchError(
|
130
|
-
"Attempted to start monitor that has already started"
|
131
|
-
) # TODO: what should I do here?
|
132
|
-
self._watch_job_thread.start()
|
133
|
-
self._watch_pods_thread.start()
|
134
|
-
|
135
|
-
def stop(self) -> None:
|
136
|
-
"""Stop the run monitor."""
|
137
|
-
self._job_watcher.stop()
|
138
|
-
self._pod_watcher.stop()
|
139
|
-
|
140
|
-
def _set_status(self, status: Status) -> None:
|
141
|
-
"""Set the run status."""
|
142
|
-
with self._status_lock:
|
143
|
-
self._status = status
|
144
|
-
|
145
|
-
def get_status(self) -> Status:
|
146
|
-
"""Get the run status."""
|
147
|
-
with self._status_lock:
|
148
|
-
return self._status
|
149
|
-
|
150
|
-
def _watch_pods(self) -> None:
|
151
|
-
"""Watch for pods created matching the jobname."""
|
152
|
-
try:
|
153
|
-
# Stream with no timeout polling for pod status updates
|
154
|
-
for event in self._pod_watcher.stream(
|
155
|
-
self.core_api.list_namespaced_pod,
|
156
|
-
namespace=self.namespace,
|
157
|
-
label_selector=self.pod_label_selector,
|
158
|
-
):
|
159
|
-
type = event.get("type")
|
160
|
-
object = event.get("object")
|
161
|
-
|
162
|
-
if type == "MODIFIED":
|
163
|
-
if object.status.phase == "Running":
|
164
|
-
self._set_status(Status("running"))
|
165
|
-
if _is_preempted(object.status):
|
166
|
-
self._set_status(Status("preempted"))
|
167
|
-
self.stop()
|
168
|
-
break
|
169
|
-
if _is_container_creating(object.status):
|
170
|
-
self._set_status(Status("starting"))
|
171
|
-
|
172
|
-
# This can happen if the initial cluster connection fails.
|
173
|
-
except ApiException as e:
|
174
|
-
raise LaunchError(
|
175
|
-
f"Exception when calling CoreV1Api.list_namespaced_pod with selector {self.pod_label_selector}: {e}"
|
176
|
-
)
|
177
|
-
|
178
|
-
# This can happen if the stream starts and gets broken, typically because
|
179
|
-
# a thread is hanging. The kubernetes SDK is already implementing a
|
180
|
-
# retry loop so if we get here it means that the pods cannot be monitored.
|
181
|
-
except urllib3.exceptions.ProtocolError as e:
|
182
|
-
state = self.get_status().state
|
183
|
-
if state in ["failed", "finished"]:
|
184
|
-
_logger.warning(
|
185
|
-
f"Hanging pod monitor thread with selector {self.pod_label_selector}: {e}"
|
186
|
-
)
|
187
|
-
return
|
188
|
-
raise LaunchError(
|
189
|
-
f"Broken event stream for pod watcher in state '{state}' and selector {self.pod_label_selector}: {e}"
|
190
|
-
)
|
191
|
-
|
192
|
-
def _watch_job(self) -> None:
|
193
|
-
"""Watch for job matching the jobname."""
|
194
|
-
try:
|
195
|
-
for event in self._job_watcher.stream(
|
196
|
-
self.batch_api.list_namespaced_job,
|
197
|
-
namespace="default",
|
198
|
-
field_selector=self.job_field_selector,
|
199
|
-
):
|
200
|
-
object = event.get("object")
|
201
|
-
if object.status.succeeded == 1:
|
202
|
-
self._set_status(Status("finished"))
|
203
|
-
self.stop()
|
204
|
-
break
|
205
|
-
elif object.status.failed is not None and object.status.failed >= 1:
|
206
|
-
self._set_status(Status("failed"))
|
207
|
-
self.stop()
|
208
|
-
break
|
209
|
-
|
210
|
-
# This can happen if the initial cluster connection fails.
|
211
|
-
except ApiException as e:
|
212
|
-
raise LaunchError(
|
213
|
-
f"Exception when calling CoreV1Api.list_namespaced_job with selector {self.job_field_selector}: {e}"
|
214
|
-
)
|
215
|
-
|
216
|
-
# This can happen if the connection is lost to the Kubernetes API server
|
217
|
-
# and cannot be re-established.
|
218
|
-
except urllib3.exceptions.ProtocolError as e:
|
219
|
-
state = self.get_status().state
|
220
|
-
if state in ["finished", "failed"]:
|
221
|
-
_logger.warning(
|
222
|
-
f"Hanging job monitor thread with select {self.job_field_selector}: {e}"
|
223
|
-
)
|
224
|
-
return
|
225
|
-
raise LaunchError(
|
226
|
-
f"Broken event stream for job watcher in state {state} with selector {self.job_field_selector}: {e}"
|
227
|
-
)
|
228
|
-
|
229
|
-
|
230
53
|
class KubernetesSubmittedRun(AbstractRun):
|
231
54
|
"""Wrapper for a launched run on Kubernetes."""
|
232
55
|
|
@@ -266,9 +89,6 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
266
89
|
self.core_api = core_api
|
267
90
|
self.name = name
|
268
91
|
self.namespace = namespace
|
269
|
-
self.job = self.batch_api.read_namespaced_job(
|
270
|
-
name=self.name, namespace=self.namespace
|
271
|
-
)
|
272
92
|
self._fail_count = 0
|
273
93
|
self.pod_names = pod_names
|
274
94
|
self.secret = secret
|
@@ -309,7 +129,7 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
309
129
|
while True:
|
310
130
|
status = self.get_status()
|
311
131
|
wandb.termlog(f"{LOG_PREFIX}Job {self.name} status: {status}")
|
312
|
-
if status.state
|
132
|
+
if status.state in ["finished", "failed", "preempted"]:
|
313
133
|
break
|
314
134
|
time.sleep(5)
|
315
135
|
return (
|
@@ -331,35 +151,18 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
331
151
|
def get_status(self) -> Status:
|
332
152
|
return self.monitor.get_status()
|
333
153
|
|
334
|
-
def suspend(self) -> None:
|
335
|
-
"""Suspend the run."""
|
336
|
-
self.job.spec.suspend = True
|
337
|
-
self.batch_api.patch_namespaced_job(
|
338
|
-
name=self.name, namespace=self.namespace, body=self.job
|
339
|
-
)
|
340
|
-
timeout = TIMEOUT
|
341
|
-
job_response = self.batch_api.read_namespaced_job_status(
|
342
|
-
name=self.name, namespace=self.namespace
|
343
|
-
)
|
344
|
-
while job_response.status.conditions is None and timeout > 0:
|
345
|
-
time.sleep(1)
|
346
|
-
timeout -= 1
|
347
|
-
job_response = self.batch_api.read_namespaced_job_status(
|
348
|
-
name=self.name, namespace=self.namespace
|
349
|
-
)
|
350
|
-
|
351
|
-
if timeout == 0 or job_response.status.conditions[0].type != "Suspended":
|
352
|
-
raise LaunchError(
|
353
|
-
"Failed to suspend job {}. Check Kubernetes dashboard for more info.".format(
|
354
|
-
self.name
|
355
|
-
)
|
356
|
-
)
|
357
|
-
|
358
154
|
def cancel(self) -> None:
|
359
155
|
"""Cancel the run."""
|
360
|
-
self.suspend()
|
361
156
|
self.monitor.stop()
|
362
|
-
|
157
|
+
try:
|
158
|
+
self.batch_api.delete_namespaced_job(
|
159
|
+
namespace=self.namespace,
|
160
|
+
name=self.name,
|
161
|
+
)
|
162
|
+
except ApiException as e:
|
163
|
+
raise LaunchError(
|
164
|
+
f"Failed to delete Kubernetes Job {self.name} in namespace {self.namespace}: {str(e)}"
|
165
|
+
) from e
|
363
166
|
|
364
167
|
|
365
168
|
class CrdSubmittedRun(AbstractRun):
|
@@ -374,7 +177,7 @@ class CrdSubmittedRun(AbstractRun):
|
|
374
177
|
namespace: str,
|
375
178
|
core_api: CoreV1Api,
|
376
179
|
custom_api: CustomObjectsApi,
|
377
|
-
|
180
|
+
monitor: KubernetesRunMonitor,
|
378
181
|
) -> None:
|
379
182
|
"""Create a run object for tracking the progress of a CRD.
|
380
183
|
|
@@ -386,7 +189,7 @@ class CrdSubmittedRun(AbstractRun):
|
|
386
189
|
namespace: The namespace of the CRD instance.
|
387
190
|
core_api: The Kubernetes core API client.
|
388
191
|
custom_api: The Kubernetes custom object API client.
|
389
|
-
|
192
|
+
monitor: The run monitor.
|
390
193
|
|
391
194
|
Raises:
|
392
195
|
LaunchError: If the CRD instance does not exist.
|
@@ -398,20 +201,8 @@ class CrdSubmittedRun(AbstractRun):
|
|
398
201
|
self.namespace = namespace
|
399
202
|
self.core_api = core_api
|
400
203
|
self.custom_api = custom_api
|
401
|
-
self.pod_names = pod_names
|
402
204
|
self._fail_count = 0
|
403
|
-
|
404
|
-
self.job = self.custom_api.get_namespaced_custom_object(
|
405
|
-
group=self.group,
|
406
|
-
version=self.version,
|
407
|
-
namespace=self.namespace,
|
408
|
-
plural=self.plural,
|
409
|
-
name=self.name,
|
410
|
-
)
|
411
|
-
except ApiException as e:
|
412
|
-
raise LaunchError(
|
413
|
-
f"Failed to get CRD {self.name} in namespace {self.namespace}: {str(e)}"
|
414
|
-
) from e
|
205
|
+
self.monitor = monitor
|
415
206
|
|
416
207
|
@property
|
417
208
|
def id(self) -> str:
|
@@ -423,7 +214,11 @@ class CrdSubmittedRun(AbstractRun):
|
|
423
214
|
# TODO: test more carefully once we release multi-node support
|
424
215
|
logs: Dict[str, Optional[str]] = {}
|
425
216
|
try:
|
426
|
-
|
217
|
+
pods = self.core_api.list_namespaced_pod(
|
218
|
+
label_selector=f"wandb/run-id={self.name}", namespace=self.namespace
|
219
|
+
)
|
220
|
+
pod_names = [pi.metadata.name for pi in pods.items]
|
221
|
+
for pod_name in pod_names:
|
427
222
|
logs[pod_name] = self.core_api.read_namespaced_pod_log(
|
428
223
|
name=pod_name, namespace=self.namespace
|
429
224
|
)
|
@@ -437,30 +232,7 @@ class CrdSubmittedRun(AbstractRun):
|
|
437
232
|
|
438
233
|
def get_status(self) -> Status:
|
439
234
|
"""Get status of custom object."""
|
440
|
-
|
441
|
-
job_response = self.custom_api.get_namespaced_custom_object_status(
|
442
|
-
group=self.group,
|
443
|
-
version=self.version,
|
444
|
-
namespace=self.namespace,
|
445
|
-
plural=self.plural,
|
446
|
-
name=self.name,
|
447
|
-
)
|
448
|
-
except ApiException as e:
|
449
|
-
raise LaunchError(
|
450
|
-
f"Failed to get CRD {self.name} in namespace {self.namespace}: {str(e)}"
|
451
|
-
) from e
|
452
|
-
# Custom objects can technically define whater states and format the
|
453
|
-
# response to the status request however they want. This checks for
|
454
|
-
# the most common cases.
|
455
|
-
status = job_response["status"]
|
456
|
-
state = status.get("state")
|
457
|
-
if isinstance(state, dict):
|
458
|
-
state = state.get("phase")
|
459
|
-
if state is None:
|
460
|
-
raise LaunchError(
|
461
|
-
f"Failed to get CRD {self.name} in namespace {self.namespace}: no state found"
|
462
|
-
)
|
463
|
-
return Status(CRD_STATE_DICT.get(state.lower(), "unknown"))
|
235
|
+
return self.monitor.get_status()
|
464
236
|
|
465
237
|
def cancel(self) -> None:
|
466
238
|
"""Cancel the custom object."""
|
@@ -482,10 +254,9 @@ class CrdSubmittedRun(AbstractRun):
|
|
482
254
|
while True:
|
483
255
|
status = self.get_status()
|
484
256
|
wandb.termlog(f"{LOG_PREFIX}Job {self.name} status: {status}")
|
485
|
-
if status.state != "running":
|
486
|
-
break
|
487
257
|
time.sleep(5)
|
488
|
-
|
258
|
+
if status.state in ["finished", "failed", "preempted"]:
|
259
|
+
return status.state == "finished"
|
489
260
|
|
490
261
|
|
491
262
|
class KubernetesRunner(AbstractRunner):
|
@@ -512,48 +283,6 @@ class KubernetesRunner(AbstractRunner):
|
|
512
283
|
self.environment = environment
|
513
284
|
self.registry = registry
|
514
285
|
|
515
|
-
def wait_job_launch(
|
516
|
-
self,
|
517
|
-
job_name: str,
|
518
|
-
namespace: str,
|
519
|
-
core_api: "CoreV1Api",
|
520
|
-
label: str = "job-name",
|
521
|
-
) -> List[str]:
|
522
|
-
"""Wait for a job to be launched and return the pod names.
|
523
|
-
|
524
|
-
Arguments:
|
525
|
-
job_name: The name of the job.
|
526
|
-
namespace: The namespace of the job.
|
527
|
-
core_api: The Kubernetes core API client.
|
528
|
-
label: The label key to match against job_name.
|
529
|
-
|
530
|
-
Returns:
|
531
|
-
The names of the pods associated with the job.
|
532
|
-
"""
|
533
|
-
pods = core_api.list_namespaced_pod(
|
534
|
-
label_selector=f"{label}={job_name}", namespace=namespace
|
535
|
-
)
|
536
|
-
timeout = TIMEOUT
|
537
|
-
while len(pods.items) == 0 and timeout > 0:
|
538
|
-
time.sleep(1)
|
539
|
-
timeout -= 1
|
540
|
-
pods = core_api.list_namespaced_pod(
|
541
|
-
label_selector=f"{label}={job_name}", namespace=namespace
|
542
|
-
)
|
543
|
-
|
544
|
-
if timeout == 0:
|
545
|
-
raise LaunchError(
|
546
|
-
"No pods found for job {}. Check dashboard to see if job was launched successfully.".format(
|
547
|
-
job_name
|
548
|
-
)
|
549
|
-
)
|
550
|
-
|
551
|
-
pod_names = [pi.metadata.name for pi in pods.items]
|
552
|
-
wandb.termlog(
|
553
|
-
f"{LOG_PREFIX}Job {job_name} created on pod(s) {', '.join(pod_names)}. See logs with e.g. `kubectl logs {pod_names[0]} -n {namespace}`."
|
554
|
-
)
|
555
|
-
return pod_names
|
556
|
-
|
557
286
|
def get_namespace(
|
558
287
|
self, resource_args: Dict[str, Any], context: Dict[str, Any]
|
559
288
|
) -> str:
|
@@ -742,16 +471,27 @@ class KubernetesRunner(AbstractRunner):
|
|
742
471
|
body=resource_args,
|
743
472
|
)
|
744
473
|
except ApiException as e:
|
474
|
+
body = json.loads(e.body)
|
475
|
+
body_yaml = yaml.dump(body)
|
745
476
|
raise LaunchError(
|
746
|
-
f"Error creating CRD of kind {kind}: {e.status} {e.reason}"
|
477
|
+
f"Error creating CRD of kind {kind}: {e.status} {e.reason}\n{body_yaml}"
|
747
478
|
) from e
|
748
479
|
name = response.get("metadata", {}).get("name")
|
749
480
|
_logger.info(f"Created {kind} {response['metadata']['name']}")
|
750
481
|
core = client.CoreV1Api(api_client)
|
751
|
-
|
752
|
-
|
482
|
+
run_monitor = KubernetesRunMonitor(
|
483
|
+
job_field_selector=f"metadata.name={name}",
|
484
|
+
pod_label_selector=f"wandb/run-id={launch_project.run_id}",
|
485
|
+
namespace=namespace,
|
486
|
+
batch_api=None,
|
487
|
+
core_api=core,
|
488
|
+
custom_api=api,
|
489
|
+
group=group,
|
490
|
+
version=version,
|
491
|
+
plural=plural,
|
753
492
|
)
|
754
|
-
|
493
|
+
run_monitor.start()
|
494
|
+
submitted_run = CrdSubmittedRun(
|
755
495
|
name=name,
|
756
496
|
group=group,
|
757
497
|
version=version,
|
@@ -759,8 +499,11 @@ class KubernetesRunner(AbstractRunner):
|
|
759
499
|
plural=plural,
|
760
500
|
core_api=client.CoreV1Api(api_client),
|
761
501
|
custom_api=api,
|
762
|
-
|
502
|
+
monitor=run_monitor,
|
763
503
|
)
|
504
|
+
if self.backend_config[PROJECT_SYNCHRONOUS]:
|
505
|
+
submitted_run.wait()
|
506
|
+
return submitted_run
|
764
507
|
|
765
508
|
batch_api = kubernetes.client.BatchV1Api(api_client)
|
766
509
|
core_api = kubernetes.client.CoreV1Api(api_client)
|
@@ -5,7 +5,7 @@ import subprocess
|
|
5
5
|
import sys
|
6
6
|
import threading
|
7
7
|
import time
|
8
|
-
from typing import Any, Dict, List, Optional
|
8
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
9
9
|
|
10
10
|
import wandb
|
11
11
|
from wandb.sdk.launch.environment.abstract import AbstractEnvironment
|
@@ -26,6 +26,9 @@ from ..utils import (
|
|
26
26
|
)
|
27
27
|
from .abstract import AbstractRun, AbstractRunner, Status
|
28
28
|
|
29
|
+
if TYPE_CHECKING:
|
30
|
+
from wandb.apis.internal import Api
|
31
|
+
|
29
32
|
_logger = logging.getLogger(__name__)
|
30
33
|
|
31
34
|
|
@@ -95,7 +98,7 @@ class LocalContainerRunner(AbstractRunner):
|
|
95
98
|
|
96
99
|
def __init__(
|
97
100
|
self,
|
98
|
-
api:
|
101
|
+
api: "Api",
|
99
102
|
backend_config: Dict[str, Any],
|
100
103
|
environment: AbstractEnvironment,
|
101
104
|
registry: AbstractRegistry,
|
@@ -9,26 +9,28 @@ import traceback
|
|
9
9
|
from abc import ABC, abstractmethod
|
10
10
|
from dataclasses import dataclass
|
11
11
|
from enum import Enum
|
12
|
-
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
12
|
+
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union
|
13
13
|
|
14
14
|
import click
|
15
15
|
import yaml
|
16
16
|
|
17
17
|
import wandb
|
18
|
-
import wandb.apis.public as public
|
19
|
-
from wandb.apis.internal import Api
|
20
|
-
from wandb.apis.public import Api as PublicApi
|
21
|
-
from wandb.apis.public import QueuedRun, Run
|
22
18
|
from wandb.errors import CommError
|
19
|
+
from wandb.sdk.launch._launch_add import launch_add
|
23
20
|
from wandb.sdk.launch.errors import LaunchError
|
24
|
-
from wandb.sdk.launch.launch_add import launch_add
|
25
21
|
from wandb.sdk.launch.sweeps import SchedulerError
|
26
22
|
from wandb.sdk.launch.sweeps.utils import (
|
27
23
|
create_sweep_command_args,
|
28
24
|
make_launch_sweep_entrypoint,
|
29
25
|
)
|
30
26
|
from wandb.sdk.lib.runid import generate_id
|
31
|
-
|
27
|
+
|
28
|
+
if TYPE_CHECKING:
|
29
|
+
import wandb.apis.public as public
|
30
|
+
from wandb.apis.internal import Api
|
31
|
+
from wandb.apis.public import QueuedRun, Run
|
32
|
+
from wandb.sdk.wandb_run import Run as SdkRun
|
33
|
+
|
32
34
|
|
33
35
|
_logger = logging.getLogger(__name__)
|
34
36
|
LOG_PREFIX = f"{click.style('sched:', fg='cyan')} "
|
@@ -84,7 +86,7 @@ class SweepRun:
|
|
84
86
|
id: str
|
85
87
|
worker_id: int
|
86
88
|
state: RunState = RunState.RUNNING
|
87
|
-
queued_run: Optional[public.QueuedRun] = None
|
89
|
+
queued_run: Optional["public.QueuedRun"] = None
|
88
90
|
args: Optional[Dict[str, Any]] = None
|
89
91
|
logs: Optional[List[str]] = None
|
90
92
|
|
@@ -98,7 +100,7 @@ class Scheduler(ABC):
|
|
98
100
|
|
99
101
|
def __init__(
|
100
102
|
self,
|
101
|
-
api: Api,
|
103
|
+
api: "Api",
|
102
104
|
*args: Optional[Any],
|
103
105
|
polling_sleep: Optional[float] = None,
|
104
106
|
sweep_id: Optional[str] = None,
|
@@ -108,6 +110,8 @@ class Scheduler(ABC):
|
|
108
110
|
num_workers: Optional[Union[int, str]] = None,
|
109
111
|
**kwargs: Optional[Any],
|
110
112
|
):
|
113
|
+
from wandb.apis.public import Api as PublicApi
|
114
|
+
|
111
115
|
self._api = api
|
112
116
|
self._public_api = PublicApi()
|
113
117
|
self._entity = (
|
@@ -244,7 +248,7 @@ class Scheduler(ABC):
|
|
244
248
|
_id: w for _id, w in self._workers.items() if _id not in self.busy_workers
|
245
249
|
}
|
246
250
|
|
247
|
-
def _init_wandb_run(self) -> SdkRun:
|
251
|
+
def _init_wandb_run(self) -> "SdkRun":
|
248
252
|
"""Controls resume or init logic for a scheduler wandb run."""
|
249
253
|
_type = self._kwargs.get("sweep_type", "sweep")
|
250
254
|
run: SdkRun = wandb.init(
|
wandb/sdk/launch/sweeps/utils.py
CHANGED
@@ -1,15 +1,17 @@
|
|
1
1
|
import json
|
2
2
|
import os
|
3
3
|
import re
|
4
|
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
5
5
|
|
6
6
|
import yaml
|
7
7
|
|
8
8
|
import wandb
|
9
9
|
from wandb import util
|
10
|
-
from wandb.apis.public import Api as PublicApi
|
11
10
|
from wandb.sdk.launch.errors import LaunchError
|
12
11
|
|
12
|
+
if TYPE_CHECKING:
|
13
|
+
from wandb.apis.public import Api as PublicApi
|
14
|
+
|
13
15
|
DEFAULT_SWEEP_COMMAND: List[str] = [
|
14
16
|
"${env}",
|
15
17
|
"${interpreter}",
|
@@ -276,7 +278,7 @@ def make_launch_sweep_entrypoint(
|
|
276
278
|
return entry_point, macro_args
|
277
279
|
|
278
280
|
|
279
|
-
def check_job_exists(public_api: PublicApi, job: Optional[str]) -> bool:
|
281
|
+
def check_job_exists(public_api: "PublicApi", job: Optional[str]) -> bool:
|
280
282
|
"""Check if the job exists using the public api.
|
281
283
|
|
282
284
|
Returns: True if no job is passed, or if the job exists.
|
wandb/sdk/launch/utils.py
CHANGED
@@ -127,7 +127,9 @@ def set_project_entity_defaults(
|
|
127
127
|
prefix = ""
|
128
128
|
if platform.system() != "Windows" and sys.stdout.encoding == "UTF-8":
|
129
129
|
prefix = "🚀 "
|
130
|
-
wandb.termlog(
|
130
|
+
wandb.termlog(
|
131
|
+
f"{LOG_PREFIX}{prefix}Launching run into {entity}{'/' + project if project else ''}"
|
132
|
+
)
|
131
133
|
return project, entity
|
132
134
|
|
133
135
|
|
@@ -58,6 +58,7 @@ _Setting = Literal[
|
|
58
58
|
"_sync",
|
59
59
|
"_os",
|
60
60
|
"_platform",
|
61
|
+
"_proxies",
|
61
62
|
"_python",
|
62
63
|
"_runqueue_item_id",
|
63
64
|
"_require_nexus",
|
@@ -84,6 +85,7 @@ _Setting = Literal[
|
|
84
85
|
"azure_account_url_to_access_key",
|
85
86
|
"base_url",
|
86
87
|
"code_dir",
|
88
|
+
"colab_url",
|
87
89
|
"config_paths",
|
88
90
|
"console",
|
89
91
|
"deployment",
|
@@ -121,6 +123,7 @@ _Setting = Literal[
|
|
121
123
|
"notebook_name",
|
122
124
|
"problem",
|
123
125
|
"program",
|
126
|
+
"program_abspath",
|
124
127
|
"program_relpath",
|
125
128
|
"project",
|
126
129
|
"project_url",
|
@@ -209,6 +212,7 @@ SETTINGS_TOPOLOGICALLY_SORTED: Final[Tuple[_Setting, ...]] = (
|
|
209
212
|
"tmp_dir",
|
210
213
|
"_tmp_code_dir",
|
211
214
|
"_windows",
|
215
|
+
"colab_url",
|
212
216
|
"is_local",
|
213
217
|
"deployment",
|
214
218
|
"disable_code",
|
@@ -220,6 +224,7 @@ SETTINGS_TOPOLOGICALLY_SORTED: Final[Tuple[_Setting, ...]] = (
|
|
220
224
|
"log_symlink_internal",
|
221
225
|
"log_symlink_user",
|
222
226
|
"log_user",
|
227
|
+
"program",
|
223
228
|
"project_url",
|
224
229
|
"resume_fname",
|
225
230
|
"run_url",
|
wandb/sdk/lib/gql_request.py
CHANGED
@@ -20,6 +20,7 @@ class GraphQLSession(HTTPTransport):
|
|
20
20
|
auth: Optional[Union[Tuple[str, str], Callable]] = None,
|
21
21
|
use_json: bool = False,
|
22
22
|
timeout: Optional[Union[int, float]] = None,
|
23
|
+
proxies: Optional[Dict[str, str]] = None,
|
23
24
|
**kwargs: Any,
|
24
25
|
) -> None:
|
25
26
|
"""Setup a session for sending GraphQL queries and mutations.
|
@@ -32,6 +33,8 @@ class GraphQLSession(HTTPTransport):
|
|
32
33
|
"""
|
33
34
|
super().__init__(url, **kwargs)
|
34
35
|
self.session = requests.Session()
|
36
|
+
if proxies:
|
37
|
+
self.session.proxies.update(proxies)
|
35
38
|
self.session.auth = auth
|
36
39
|
self.default_timeout = timeout
|
37
40
|
self.use_json = use_json
|
wandb/sdk/lib/ipython.py
CHANGED
@@ -45,6 +45,10 @@ def _get_python_type() -> PythonType:
|
|
45
45
|
(get_ipython().config.get("IPKernelApp", {}) or {})
|
46
46
|
.get("connection_file", "")
|
47
47
|
.lower()
|
48
|
+
) or (
|
49
|
+
(get_ipython().config.get("ColabKernelApp", {}) or {})
|
50
|
+
.get("connection_file", "")
|
51
|
+
.lower()
|
48
52
|
)
|
49
53
|
|
50
54
|
if (
|