wandb 0.15.10__py3-none-any.whl → 0.15.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +2 -1
- wandb/apis/public.py +51 -9
- wandb/apis/reports/blocks.py +1 -0
- wandb/cli/cli.py +14 -9
- wandb/env.py +11 -1
- wandb/integration/xgboost/xgboost.py +3 -3
- wandb/proto/v3/wandb_internal_pb2.py +300 -267
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +16 -16
- wandb/proto/v4/wandb_internal_pb2.py +260 -252
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +16 -16
- wandb/sdk/artifacts/artifact.py +9 -6
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +12 -7
- wandb/sdk/data_types/image.py +1 -1
- wandb/sdk/internal/file_stream.py +2 -1
- wandb/sdk/internal/handler.py +24 -20
- wandb/sdk/internal/internal_api.py +9 -1
- wandb/sdk/internal/sender.py +4 -1
- wandb/sdk/internal/system/system_info.py +2 -2
- wandb/sdk/launch/__init__.py +5 -0
- wandb/sdk/launch/{launch.py → _launch.py} +53 -54
- wandb/sdk/launch/{launch_add.py → _launch_add.py} +34 -31
- wandb/sdk/launch/agent/agent.py +36 -18
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +6 -4
- wandb/sdk/launch/runner/abstract.py +0 -2
- wandb/sdk/launch/runner/kubernetes_monitor.py +329 -0
- wandb/sdk/launch/runner/kubernetes_runner.py +44 -301
- wandb/sdk/launch/runner/local_container.py +5 -2
- wandb/sdk/launch/sweeps/scheduler.py +14 -10
- wandb/sdk/launch/sweeps/utils.py +5 -3
- wandb/sdk/launch/utils.py +3 -1
- wandb/sdk/lib/_settings_toposort_generated.py +5 -0
- wandb/sdk/lib/gql_request.py +3 -0
- wandb/sdk/lib/ipython.py +4 -0
- wandb/sdk/service/service.py +19 -6
- wandb/sdk/wandb_init.py +7 -2
- wandb/sdk/wandb_run.py +2 -5
- wandb/sdk/wandb_settings.py +48 -2
- wandb/util.py +1 -1
- {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/METADATA +4 -1
- {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/RECORD +46 -45
- {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/LICENSE +0 -0
- {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/WHEEL +0 -0
- {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/entry_points.txt +0 -0
- {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/top_level.txt +0 -0
@@ -4,10 +4,9 @@ import base64
|
|
4
4
|
import json
|
5
5
|
import logging
|
6
6
|
import time
|
7
|
-
from threading import Lock, Thread
|
8
7
|
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
9
8
|
|
10
|
-
import
|
9
|
+
import yaml
|
11
10
|
|
12
11
|
import wandb
|
13
12
|
from wandb.apis.internal import Api
|
@@ -15,7 +14,7 @@ from wandb.sdk.launch.environment.abstract import AbstractEnvironment
|
|
15
14
|
from wandb.sdk.launch.registry.abstract import AbstractRegistry
|
16
15
|
from wandb.sdk.launch.registry.azure_container_registry import AzureContainerRegistry
|
17
16
|
from wandb.sdk.launch.registry.local_registry import LocalRegistry
|
18
|
-
from wandb.sdk.launch.runner.abstract import
|
17
|
+
from wandb.sdk.launch.runner.abstract import Status
|
19
18
|
from wandb.util import get_module
|
20
19
|
|
21
20
|
from .._project_spec import EntryPoint, LaunchProject
|
@@ -29,22 +28,20 @@ from ..utils import (
|
|
29
28
|
make_name_dns_safe,
|
30
29
|
)
|
31
30
|
from .abstract import AbstractRun, AbstractRunner
|
31
|
+
from .kubernetes_monitor import KubernetesRunMonitor
|
32
32
|
|
33
33
|
get_module(
|
34
34
|
"kubernetes",
|
35
35
|
required="Kubernetes runner requires the kubernetes package. Please install it with `pip install wandb[launch]`.",
|
36
36
|
)
|
37
37
|
|
38
|
-
from kubernetes import client
|
38
|
+
from kubernetes import client # type: ignore # noqa: E402
|
39
39
|
from kubernetes.client.api.batch_v1_api import BatchV1Api # type: ignore # noqa: E402
|
40
40
|
from kubernetes.client.api.core_v1_api import CoreV1Api # type: ignore # noqa: E402
|
41
41
|
from kubernetes.client.api.custom_objects_api import ( # type: ignore # noqa: E402
|
42
42
|
CustomObjectsApi,
|
43
43
|
)
|
44
44
|
from kubernetes.client.models.v1_job import V1Job # type: ignore # noqa: E402
|
45
|
-
from kubernetes.client.models.v1_pod_status import ( # type: ignore # noqa: E402
|
46
|
-
V1PodStatus,
|
47
|
-
)
|
48
45
|
from kubernetes.client.models.v1_secret import V1Secret # type: ignore # noqa: E402
|
49
46
|
from kubernetes.client.rest import ApiException # type: ignore # noqa: E402
|
50
47
|
|
@@ -53,180 +50,6 @@ TIMEOUT = 5
|
|
53
50
|
_logger = logging.getLogger(__name__)
|
54
51
|
|
55
52
|
|
56
|
-
# Dict for mapping possible states of custom objects to the states we want to report
|
57
|
-
# to the agent.
|
58
|
-
CRD_STATE_DICT: Dict[str, State] = {
|
59
|
-
"pending": "starting",
|
60
|
-
"running": "running",
|
61
|
-
"completed": "finished",
|
62
|
-
"failed": "failed",
|
63
|
-
"aborted": "failed",
|
64
|
-
"terminating": "stopping",
|
65
|
-
"terminated": "stopped",
|
66
|
-
}
|
67
|
-
|
68
|
-
|
69
|
-
def _is_preempted(status: "V1PodStatus") -> bool:
|
70
|
-
"""Check if this pod has been preempted."""
|
71
|
-
if hasattr(status, "conditions") and status.conditions is not None:
|
72
|
-
for condition in status.conditions:
|
73
|
-
if condition.type == "DisruptionTarget" and condition.reason in [
|
74
|
-
"EvictionByEvictionAPI",
|
75
|
-
"PreemptionByScheduler",
|
76
|
-
"TerminationByKubelet",
|
77
|
-
]:
|
78
|
-
return True
|
79
|
-
return False
|
80
|
-
|
81
|
-
|
82
|
-
def _is_container_creating(status: "V1PodStatus") -> bool:
|
83
|
-
"""Check if this pod has started creating containers."""
|
84
|
-
for container_status in status.container_statuses or []:
|
85
|
-
if (
|
86
|
-
container_status.state
|
87
|
-
and container_status.state.waiting
|
88
|
-
and container_status.state.waiting.reason == "ContainerCreating"
|
89
|
-
):
|
90
|
-
return True
|
91
|
-
return False
|
92
|
-
|
93
|
-
|
94
|
-
class KubernetesRunMonitor:
|
95
|
-
def __init__(
|
96
|
-
self,
|
97
|
-
job_field_selector: str,
|
98
|
-
pod_label_selector: str,
|
99
|
-
namespace: str,
|
100
|
-
batch_api: "BatchV1Api",
|
101
|
-
core_api: "CoreV1Api",
|
102
|
-
) -> None:
|
103
|
-
"""Initial KubernetesRunMonitor.
|
104
|
-
|
105
|
-
Arguments:
|
106
|
-
jobname: Name of the job.
|
107
|
-
|
108
|
-
Returns:
|
109
|
-
None.
|
110
|
-
"""
|
111
|
-
self.pod_label_selector = pod_label_selector
|
112
|
-
self.job_field_selector = job_field_selector
|
113
|
-
self.namespace = namespace
|
114
|
-
self.batch_api = batch_api
|
115
|
-
self.core_api = core_api
|
116
|
-
|
117
|
-
self._status_lock = Lock()
|
118
|
-
self._status = Status("starting")
|
119
|
-
|
120
|
-
self._watch_job_thread = Thread(target=self._watch_job, daemon=True)
|
121
|
-
self._watch_pods_thread = Thread(target=self._watch_pods, daemon=True)
|
122
|
-
|
123
|
-
self._job_watcher = watch.Watch()
|
124
|
-
self._pod_watcher = watch.Watch()
|
125
|
-
|
126
|
-
def start(self) -> None:
|
127
|
-
"""Start the run monitor."""
|
128
|
-
if self._watch_job_thread.is_alive() or self._watch_pods_thread.is_alive():
|
129
|
-
raise LaunchError(
|
130
|
-
"Attempted to start monitor that has already started"
|
131
|
-
) # TODO: what should I do here?
|
132
|
-
self._watch_job_thread.start()
|
133
|
-
self._watch_pods_thread.start()
|
134
|
-
|
135
|
-
def stop(self) -> None:
|
136
|
-
"""Stop the run monitor."""
|
137
|
-
self._job_watcher.stop()
|
138
|
-
self._pod_watcher.stop()
|
139
|
-
|
140
|
-
def _set_status(self, status: Status) -> None:
|
141
|
-
"""Set the run status."""
|
142
|
-
with self._status_lock:
|
143
|
-
self._status = status
|
144
|
-
|
145
|
-
def get_status(self) -> Status:
|
146
|
-
"""Get the run status."""
|
147
|
-
with self._status_lock:
|
148
|
-
return self._status
|
149
|
-
|
150
|
-
def _watch_pods(self) -> None:
|
151
|
-
"""Watch for pods created matching the jobname."""
|
152
|
-
try:
|
153
|
-
# Stream with no timeout polling for pod status updates
|
154
|
-
for event in self._pod_watcher.stream(
|
155
|
-
self.core_api.list_namespaced_pod,
|
156
|
-
namespace=self.namespace,
|
157
|
-
label_selector=self.pod_label_selector,
|
158
|
-
):
|
159
|
-
type = event.get("type")
|
160
|
-
object = event.get("object")
|
161
|
-
|
162
|
-
if type == "MODIFIED":
|
163
|
-
if object.status.phase == "Running":
|
164
|
-
self._set_status(Status("running"))
|
165
|
-
if _is_preempted(object.status):
|
166
|
-
self._set_status(Status("preempted"))
|
167
|
-
self.stop()
|
168
|
-
break
|
169
|
-
if _is_container_creating(object.status):
|
170
|
-
self._set_status(Status("starting"))
|
171
|
-
|
172
|
-
# This can happen if the initial cluster connection fails.
|
173
|
-
except ApiException as e:
|
174
|
-
raise LaunchError(
|
175
|
-
f"Exception when calling CoreV1Api.list_namespaced_pod with selector {self.pod_label_selector}: {e}"
|
176
|
-
)
|
177
|
-
|
178
|
-
# This can happen if the stream starts and gets broken, typically because
|
179
|
-
# a thread is hanging. The kubernetes SDK is already implementing a
|
180
|
-
# retry loop so if we get here it means that the pods cannot be monitored.
|
181
|
-
except urllib3.exceptions.ProtocolError as e:
|
182
|
-
state = self.get_status().state
|
183
|
-
if state in ["failed", "finished"]:
|
184
|
-
_logger.warning(
|
185
|
-
f"Hanging pod monitor thread with selector {self.pod_label_selector}: {e}"
|
186
|
-
)
|
187
|
-
return
|
188
|
-
raise LaunchError(
|
189
|
-
f"Broken event stream for pod watcher in state '{state}' and selector {self.pod_label_selector}: {e}"
|
190
|
-
)
|
191
|
-
|
192
|
-
def _watch_job(self) -> None:
|
193
|
-
"""Watch for job matching the jobname."""
|
194
|
-
try:
|
195
|
-
for event in self._job_watcher.stream(
|
196
|
-
self.batch_api.list_namespaced_job,
|
197
|
-
namespace="default",
|
198
|
-
field_selector=self.job_field_selector,
|
199
|
-
):
|
200
|
-
object = event.get("object")
|
201
|
-
if object.status.succeeded == 1:
|
202
|
-
self._set_status(Status("finished"))
|
203
|
-
self.stop()
|
204
|
-
break
|
205
|
-
elif object.status.failed is not None and object.status.failed >= 1:
|
206
|
-
self._set_status(Status("failed"))
|
207
|
-
self.stop()
|
208
|
-
break
|
209
|
-
|
210
|
-
# This can happen if the initial cluster connection fails.
|
211
|
-
except ApiException as e:
|
212
|
-
raise LaunchError(
|
213
|
-
f"Exception when calling CoreV1Api.list_namespaced_job with selector {self.job_field_selector}: {e}"
|
214
|
-
)
|
215
|
-
|
216
|
-
# This can happen if the connection is lost to the Kubernetes API server
|
217
|
-
# and cannot be re-established.
|
218
|
-
except urllib3.exceptions.ProtocolError as e:
|
219
|
-
state = self.get_status().state
|
220
|
-
if state in ["finished", "failed"]:
|
221
|
-
_logger.warning(
|
222
|
-
f"Hanging job monitor thread with select {self.job_field_selector}: {e}"
|
223
|
-
)
|
224
|
-
return
|
225
|
-
raise LaunchError(
|
226
|
-
f"Broken event stream for job watcher in state {state} with selector {self.job_field_selector}: {e}"
|
227
|
-
)
|
228
|
-
|
229
|
-
|
230
53
|
class KubernetesSubmittedRun(AbstractRun):
|
231
54
|
"""Wrapper for a launched run on Kubernetes."""
|
232
55
|
|
@@ -266,9 +89,6 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
266
89
|
self.core_api = core_api
|
267
90
|
self.name = name
|
268
91
|
self.namespace = namespace
|
269
|
-
self.job = self.batch_api.read_namespaced_job(
|
270
|
-
name=self.name, namespace=self.namespace
|
271
|
-
)
|
272
92
|
self._fail_count = 0
|
273
93
|
self.pod_names = pod_names
|
274
94
|
self.secret = secret
|
@@ -309,7 +129,7 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
309
129
|
while True:
|
310
130
|
status = self.get_status()
|
311
131
|
wandb.termlog(f"{LOG_PREFIX}Job {self.name} status: {status}")
|
312
|
-
if status.state
|
132
|
+
if status.state in ["finished", "failed", "preempted"]:
|
313
133
|
break
|
314
134
|
time.sleep(5)
|
315
135
|
return (
|
@@ -331,35 +151,18 @@ class KubernetesSubmittedRun(AbstractRun):
|
|
331
151
|
def get_status(self) -> Status:
|
332
152
|
return self.monitor.get_status()
|
333
153
|
|
334
|
-
def suspend(self) -> None:
|
335
|
-
"""Suspend the run."""
|
336
|
-
self.job.spec.suspend = True
|
337
|
-
self.batch_api.patch_namespaced_job(
|
338
|
-
name=self.name, namespace=self.namespace, body=self.job
|
339
|
-
)
|
340
|
-
timeout = TIMEOUT
|
341
|
-
job_response = self.batch_api.read_namespaced_job_status(
|
342
|
-
name=self.name, namespace=self.namespace
|
343
|
-
)
|
344
|
-
while job_response.status.conditions is None and timeout > 0:
|
345
|
-
time.sleep(1)
|
346
|
-
timeout -= 1
|
347
|
-
job_response = self.batch_api.read_namespaced_job_status(
|
348
|
-
name=self.name, namespace=self.namespace
|
349
|
-
)
|
350
|
-
|
351
|
-
if timeout == 0 or job_response.status.conditions[0].type != "Suspended":
|
352
|
-
raise LaunchError(
|
353
|
-
"Failed to suspend job {}. Check Kubernetes dashboard for more info.".format(
|
354
|
-
self.name
|
355
|
-
)
|
356
|
-
)
|
357
|
-
|
358
154
|
def cancel(self) -> None:
|
359
155
|
"""Cancel the run."""
|
360
|
-
self.suspend()
|
361
156
|
self.monitor.stop()
|
362
|
-
|
157
|
+
try:
|
158
|
+
self.batch_api.delete_namespaced_job(
|
159
|
+
namespace=self.namespace,
|
160
|
+
name=self.name,
|
161
|
+
)
|
162
|
+
except ApiException as e:
|
163
|
+
raise LaunchError(
|
164
|
+
f"Failed to delete Kubernetes Job {self.name} in namespace {self.namespace}: {str(e)}"
|
165
|
+
) from e
|
363
166
|
|
364
167
|
|
365
168
|
class CrdSubmittedRun(AbstractRun):
|
@@ -374,7 +177,7 @@ class CrdSubmittedRun(AbstractRun):
|
|
374
177
|
namespace: str,
|
375
178
|
core_api: CoreV1Api,
|
376
179
|
custom_api: CustomObjectsApi,
|
377
|
-
|
180
|
+
monitor: KubernetesRunMonitor,
|
378
181
|
) -> None:
|
379
182
|
"""Create a run object for tracking the progress of a CRD.
|
380
183
|
|
@@ -386,7 +189,7 @@ class CrdSubmittedRun(AbstractRun):
|
|
386
189
|
namespace: The namespace of the CRD instance.
|
387
190
|
core_api: The Kubernetes core API client.
|
388
191
|
custom_api: The Kubernetes custom object API client.
|
389
|
-
|
192
|
+
monitor: The run monitor.
|
390
193
|
|
391
194
|
Raises:
|
392
195
|
LaunchError: If the CRD instance does not exist.
|
@@ -398,20 +201,8 @@ class CrdSubmittedRun(AbstractRun):
|
|
398
201
|
self.namespace = namespace
|
399
202
|
self.core_api = core_api
|
400
203
|
self.custom_api = custom_api
|
401
|
-
self.pod_names = pod_names
|
402
204
|
self._fail_count = 0
|
403
|
-
|
404
|
-
self.job = self.custom_api.get_namespaced_custom_object(
|
405
|
-
group=self.group,
|
406
|
-
version=self.version,
|
407
|
-
namespace=self.namespace,
|
408
|
-
plural=self.plural,
|
409
|
-
name=self.name,
|
410
|
-
)
|
411
|
-
except ApiException as e:
|
412
|
-
raise LaunchError(
|
413
|
-
f"Failed to get CRD {self.name} in namespace {self.namespace}: {str(e)}"
|
414
|
-
) from e
|
205
|
+
self.monitor = monitor
|
415
206
|
|
416
207
|
@property
|
417
208
|
def id(self) -> str:
|
@@ -423,7 +214,11 @@ class CrdSubmittedRun(AbstractRun):
|
|
423
214
|
# TODO: test more carefully once we release multi-node support
|
424
215
|
logs: Dict[str, Optional[str]] = {}
|
425
216
|
try:
|
426
|
-
|
217
|
+
pods = self.core_api.list_namespaced_pod(
|
218
|
+
label_selector=f"wandb/run-id={self.name}", namespace=self.namespace
|
219
|
+
)
|
220
|
+
pod_names = [pi.metadata.name for pi in pods.items]
|
221
|
+
for pod_name in pod_names:
|
427
222
|
logs[pod_name] = self.core_api.read_namespaced_pod_log(
|
428
223
|
name=pod_name, namespace=self.namespace
|
429
224
|
)
|
@@ -437,30 +232,7 @@ class CrdSubmittedRun(AbstractRun):
|
|
437
232
|
|
438
233
|
def get_status(self) -> Status:
|
439
234
|
"""Get status of custom object."""
|
440
|
-
|
441
|
-
job_response = self.custom_api.get_namespaced_custom_object_status(
|
442
|
-
group=self.group,
|
443
|
-
version=self.version,
|
444
|
-
namespace=self.namespace,
|
445
|
-
plural=self.plural,
|
446
|
-
name=self.name,
|
447
|
-
)
|
448
|
-
except ApiException as e:
|
449
|
-
raise LaunchError(
|
450
|
-
f"Failed to get CRD {self.name} in namespace {self.namespace}: {str(e)}"
|
451
|
-
) from e
|
452
|
-
# Custom objects can technically define whater states and format the
|
453
|
-
# response to the status request however they want. This checks for
|
454
|
-
# the most common cases.
|
455
|
-
status = job_response["status"]
|
456
|
-
state = status.get("state")
|
457
|
-
if isinstance(state, dict):
|
458
|
-
state = state.get("phase")
|
459
|
-
if state is None:
|
460
|
-
raise LaunchError(
|
461
|
-
f"Failed to get CRD {self.name} in namespace {self.namespace}: no state found"
|
462
|
-
)
|
463
|
-
return Status(CRD_STATE_DICT.get(state.lower(), "unknown"))
|
235
|
+
return self.monitor.get_status()
|
464
236
|
|
465
237
|
def cancel(self) -> None:
|
466
238
|
"""Cancel the custom object."""
|
@@ -482,10 +254,9 @@ class CrdSubmittedRun(AbstractRun):
|
|
482
254
|
while True:
|
483
255
|
status = self.get_status()
|
484
256
|
wandb.termlog(f"{LOG_PREFIX}Job {self.name} status: {status}")
|
485
|
-
if status.state != "running":
|
486
|
-
break
|
487
257
|
time.sleep(5)
|
488
|
-
|
258
|
+
if status.state in ["finished", "failed", "preempted"]:
|
259
|
+
return status.state == "finished"
|
489
260
|
|
490
261
|
|
491
262
|
class KubernetesRunner(AbstractRunner):
|
@@ -512,48 +283,6 @@ class KubernetesRunner(AbstractRunner):
|
|
512
283
|
self.environment = environment
|
513
284
|
self.registry = registry
|
514
285
|
|
515
|
-
def wait_job_launch(
|
516
|
-
self,
|
517
|
-
job_name: str,
|
518
|
-
namespace: str,
|
519
|
-
core_api: "CoreV1Api",
|
520
|
-
label: str = "job-name",
|
521
|
-
) -> List[str]:
|
522
|
-
"""Wait for a job to be launched and return the pod names.
|
523
|
-
|
524
|
-
Arguments:
|
525
|
-
job_name: The name of the job.
|
526
|
-
namespace: The namespace of the job.
|
527
|
-
core_api: The Kubernetes core API client.
|
528
|
-
label: The label key to match against job_name.
|
529
|
-
|
530
|
-
Returns:
|
531
|
-
The names of the pods associated with the job.
|
532
|
-
"""
|
533
|
-
pods = core_api.list_namespaced_pod(
|
534
|
-
label_selector=f"{label}={job_name}", namespace=namespace
|
535
|
-
)
|
536
|
-
timeout = TIMEOUT
|
537
|
-
while len(pods.items) == 0 and timeout > 0:
|
538
|
-
time.sleep(1)
|
539
|
-
timeout -= 1
|
540
|
-
pods = core_api.list_namespaced_pod(
|
541
|
-
label_selector=f"{label}={job_name}", namespace=namespace
|
542
|
-
)
|
543
|
-
|
544
|
-
if timeout == 0:
|
545
|
-
raise LaunchError(
|
546
|
-
"No pods found for job {}. Check dashboard to see if job was launched successfully.".format(
|
547
|
-
job_name
|
548
|
-
)
|
549
|
-
)
|
550
|
-
|
551
|
-
pod_names = [pi.metadata.name for pi in pods.items]
|
552
|
-
wandb.termlog(
|
553
|
-
f"{LOG_PREFIX}Job {job_name} created on pod(s) {', '.join(pod_names)}. See logs with e.g. `kubectl logs {pod_names[0]} -n {namespace}`."
|
554
|
-
)
|
555
|
-
return pod_names
|
556
|
-
|
557
286
|
def get_namespace(
|
558
287
|
self, resource_args: Dict[str, Any], context: Dict[str, Any]
|
559
288
|
) -> str:
|
@@ -742,16 +471,27 @@ class KubernetesRunner(AbstractRunner):
|
|
742
471
|
body=resource_args,
|
743
472
|
)
|
744
473
|
except ApiException as e:
|
474
|
+
body = json.loads(e.body)
|
475
|
+
body_yaml = yaml.dump(body)
|
745
476
|
raise LaunchError(
|
746
|
-
f"Error creating CRD of kind {kind}: {e.status} {e.reason}"
|
477
|
+
f"Error creating CRD of kind {kind}: {e.status} {e.reason}\n{body_yaml}"
|
747
478
|
) from e
|
748
479
|
name = response.get("metadata", {}).get("name")
|
749
480
|
_logger.info(f"Created {kind} {response['metadata']['name']}")
|
750
481
|
core = client.CoreV1Api(api_client)
|
751
|
-
|
752
|
-
|
482
|
+
run_monitor = KubernetesRunMonitor(
|
483
|
+
job_field_selector=f"metadata.name={name}",
|
484
|
+
pod_label_selector=f"wandb/run-id={launch_project.run_id}",
|
485
|
+
namespace=namespace,
|
486
|
+
batch_api=None,
|
487
|
+
core_api=core,
|
488
|
+
custom_api=api,
|
489
|
+
group=group,
|
490
|
+
version=version,
|
491
|
+
plural=plural,
|
753
492
|
)
|
754
|
-
|
493
|
+
run_monitor.start()
|
494
|
+
submitted_run = CrdSubmittedRun(
|
755
495
|
name=name,
|
756
496
|
group=group,
|
757
497
|
version=version,
|
@@ -759,8 +499,11 @@ class KubernetesRunner(AbstractRunner):
|
|
759
499
|
plural=plural,
|
760
500
|
core_api=client.CoreV1Api(api_client),
|
761
501
|
custom_api=api,
|
762
|
-
|
502
|
+
monitor=run_monitor,
|
763
503
|
)
|
504
|
+
if self.backend_config[PROJECT_SYNCHRONOUS]:
|
505
|
+
submitted_run.wait()
|
506
|
+
return submitted_run
|
764
507
|
|
765
508
|
batch_api = kubernetes.client.BatchV1Api(api_client)
|
766
509
|
core_api = kubernetes.client.CoreV1Api(api_client)
|
@@ -5,7 +5,7 @@ import subprocess
|
|
5
5
|
import sys
|
6
6
|
import threading
|
7
7
|
import time
|
8
|
-
from typing import Any, Dict, List, Optional
|
8
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
9
9
|
|
10
10
|
import wandb
|
11
11
|
from wandb.sdk.launch.environment.abstract import AbstractEnvironment
|
@@ -26,6 +26,9 @@ from ..utils import (
|
|
26
26
|
)
|
27
27
|
from .abstract import AbstractRun, AbstractRunner, Status
|
28
28
|
|
29
|
+
if TYPE_CHECKING:
|
30
|
+
from wandb.apis.internal import Api
|
31
|
+
|
29
32
|
_logger = logging.getLogger(__name__)
|
30
33
|
|
31
34
|
|
@@ -95,7 +98,7 @@ class LocalContainerRunner(AbstractRunner):
|
|
95
98
|
|
96
99
|
def __init__(
|
97
100
|
self,
|
98
|
-
api:
|
101
|
+
api: "Api",
|
99
102
|
backend_config: Dict[str, Any],
|
100
103
|
environment: AbstractEnvironment,
|
101
104
|
registry: AbstractRegistry,
|
@@ -9,26 +9,28 @@ import traceback
|
|
9
9
|
from abc import ABC, abstractmethod
|
10
10
|
from dataclasses import dataclass
|
11
11
|
from enum import Enum
|
12
|
-
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
12
|
+
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union
|
13
13
|
|
14
14
|
import click
|
15
15
|
import yaml
|
16
16
|
|
17
17
|
import wandb
|
18
|
-
import wandb.apis.public as public
|
19
|
-
from wandb.apis.internal import Api
|
20
|
-
from wandb.apis.public import Api as PublicApi
|
21
|
-
from wandb.apis.public import QueuedRun, Run
|
22
18
|
from wandb.errors import CommError
|
19
|
+
from wandb.sdk.launch._launch_add import launch_add
|
23
20
|
from wandb.sdk.launch.errors import LaunchError
|
24
|
-
from wandb.sdk.launch.launch_add import launch_add
|
25
21
|
from wandb.sdk.launch.sweeps import SchedulerError
|
26
22
|
from wandb.sdk.launch.sweeps.utils import (
|
27
23
|
create_sweep_command_args,
|
28
24
|
make_launch_sweep_entrypoint,
|
29
25
|
)
|
30
26
|
from wandb.sdk.lib.runid import generate_id
|
31
|
-
|
27
|
+
|
28
|
+
if TYPE_CHECKING:
|
29
|
+
import wandb.apis.public as public
|
30
|
+
from wandb.apis.internal import Api
|
31
|
+
from wandb.apis.public import QueuedRun, Run
|
32
|
+
from wandb.sdk.wandb_run import Run as SdkRun
|
33
|
+
|
32
34
|
|
33
35
|
_logger = logging.getLogger(__name__)
|
34
36
|
LOG_PREFIX = f"{click.style('sched:', fg='cyan')} "
|
@@ -84,7 +86,7 @@ class SweepRun:
|
|
84
86
|
id: str
|
85
87
|
worker_id: int
|
86
88
|
state: RunState = RunState.RUNNING
|
87
|
-
queued_run: Optional[public.QueuedRun] = None
|
89
|
+
queued_run: Optional["public.QueuedRun"] = None
|
88
90
|
args: Optional[Dict[str, Any]] = None
|
89
91
|
logs: Optional[List[str]] = None
|
90
92
|
|
@@ -98,7 +100,7 @@ class Scheduler(ABC):
|
|
98
100
|
|
99
101
|
def __init__(
|
100
102
|
self,
|
101
|
-
api: Api,
|
103
|
+
api: "Api",
|
102
104
|
*args: Optional[Any],
|
103
105
|
polling_sleep: Optional[float] = None,
|
104
106
|
sweep_id: Optional[str] = None,
|
@@ -108,6 +110,8 @@ class Scheduler(ABC):
|
|
108
110
|
num_workers: Optional[Union[int, str]] = None,
|
109
111
|
**kwargs: Optional[Any],
|
110
112
|
):
|
113
|
+
from wandb.apis.public import Api as PublicApi
|
114
|
+
|
111
115
|
self._api = api
|
112
116
|
self._public_api = PublicApi()
|
113
117
|
self._entity = (
|
@@ -244,7 +248,7 @@ class Scheduler(ABC):
|
|
244
248
|
_id: w for _id, w in self._workers.items() if _id not in self.busy_workers
|
245
249
|
}
|
246
250
|
|
247
|
-
def _init_wandb_run(self) -> SdkRun:
|
251
|
+
def _init_wandb_run(self) -> "SdkRun":
|
248
252
|
"""Controls resume or init logic for a scheduler wandb run."""
|
249
253
|
_type = self._kwargs.get("sweep_type", "sweep")
|
250
254
|
run: SdkRun = wandb.init(
|
wandb/sdk/launch/sweeps/utils.py
CHANGED
@@ -1,15 +1,17 @@
|
|
1
1
|
import json
|
2
2
|
import os
|
3
3
|
import re
|
4
|
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
5
5
|
|
6
6
|
import yaml
|
7
7
|
|
8
8
|
import wandb
|
9
9
|
from wandb import util
|
10
|
-
from wandb.apis.public import Api as PublicApi
|
11
10
|
from wandb.sdk.launch.errors import LaunchError
|
12
11
|
|
12
|
+
if TYPE_CHECKING:
|
13
|
+
from wandb.apis.public import Api as PublicApi
|
14
|
+
|
13
15
|
DEFAULT_SWEEP_COMMAND: List[str] = [
|
14
16
|
"${env}",
|
15
17
|
"${interpreter}",
|
@@ -276,7 +278,7 @@ def make_launch_sweep_entrypoint(
|
|
276
278
|
return entry_point, macro_args
|
277
279
|
|
278
280
|
|
279
|
-
def check_job_exists(public_api: PublicApi, job: Optional[str]) -> bool:
|
281
|
+
def check_job_exists(public_api: "PublicApi", job: Optional[str]) -> bool:
|
280
282
|
"""Check if the job exists using the public api.
|
281
283
|
|
282
284
|
Returns: True if no job is passed, or if the job exists.
|
wandb/sdk/launch/utils.py
CHANGED
@@ -127,7 +127,9 @@ def set_project_entity_defaults(
|
|
127
127
|
prefix = ""
|
128
128
|
if platform.system() != "Windows" and sys.stdout.encoding == "UTF-8":
|
129
129
|
prefix = "🚀 "
|
130
|
-
wandb.termlog(
|
130
|
+
wandb.termlog(
|
131
|
+
f"{LOG_PREFIX}{prefix}Launching run into {entity}{'/' + project if project else ''}"
|
132
|
+
)
|
131
133
|
return project, entity
|
132
134
|
|
133
135
|
|
@@ -58,6 +58,7 @@ _Setting = Literal[
|
|
58
58
|
"_sync",
|
59
59
|
"_os",
|
60
60
|
"_platform",
|
61
|
+
"_proxies",
|
61
62
|
"_python",
|
62
63
|
"_runqueue_item_id",
|
63
64
|
"_require_nexus",
|
@@ -84,6 +85,7 @@ _Setting = Literal[
|
|
84
85
|
"azure_account_url_to_access_key",
|
85
86
|
"base_url",
|
86
87
|
"code_dir",
|
88
|
+
"colab_url",
|
87
89
|
"config_paths",
|
88
90
|
"console",
|
89
91
|
"deployment",
|
@@ -121,6 +123,7 @@ _Setting = Literal[
|
|
121
123
|
"notebook_name",
|
122
124
|
"problem",
|
123
125
|
"program",
|
126
|
+
"program_abspath",
|
124
127
|
"program_relpath",
|
125
128
|
"project",
|
126
129
|
"project_url",
|
@@ -209,6 +212,7 @@ SETTINGS_TOPOLOGICALLY_SORTED: Final[Tuple[_Setting, ...]] = (
|
|
209
212
|
"tmp_dir",
|
210
213
|
"_tmp_code_dir",
|
211
214
|
"_windows",
|
215
|
+
"colab_url",
|
212
216
|
"is_local",
|
213
217
|
"deployment",
|
214
218
|
"disable_code",
|
@@ -220,6 +224,7 @@ SETTINGS_TOPOLOGICALLY_SORTED: Final[Tuple[_Setting, ...]] = (
|
|
220
224
|
"log_symlink_internal",
|
221
225
|
"log_symlink_user",
|
222
226
|
"log_user",
|
227
|
+
"program",
|
223
228
|
"project_url",
|
224
229
|
"resume_fname",
|
225
230
|
"run_url",
|
wandb/sdk/lib/gql_request.py
CHANGED
@@ -20,6 +20,7 @@ class GraphQLSession(HTTPTransport):
|
|
20
20
|
auth: Optional[Union[Tuple[str, str], Callable]] = None,
|
21
21
|
use_json: bool = False,
|
22
22
|
timeout: Optional[Union[int, float]] = None,
|
23
|
+
proxies: Optional[Dict[str, str]] = None,
|
23
24
|
**kwargs: Any,
|
24
25
|
) -> None:
|
25
26
|
"""Setup a session for sending GraphQL queries and mutations.
|
@@ -32,6 +33,8 @@ class GraphQLSession(HTTPTransport):
|
|
32
33
|
"""
|
33
34
|
super().__init__(url, **kwargs)
|
34
35
|
self.session = requests.Session()
|
36
|
+
if proxies:
|
37
|
+
self.session.proxies.update(proxies)
|
35
38
|
self.session.auth = auth
|
36
39
|
self.default_timeout = timeout
|
37
40
|
self.use_json = use_json
|
wandb/sdk/lib/ipython.py
CHANGED
@@ -45,6 +45,10 @@ def _get_python_type() -> PythonType:
|
|
45
45
|
(get_ipython().config.get("IPKernelApp", {}) or {})
|
46
46
|
.get("connection_file", "")
|
47
47
|
.lower()
|
48
|
+
) or (
|
49
|
+
(get_ipython().config.get("ColabKernelApp", {}) or {})
|
50
|
+
.get("connection_file", "")
|
51
|
+
.lower()
|
48
52
|
)
|
49
53
|
|
50
54
|
if (
|