wandb 0.15.10__py3-none-any.whl → 0.15.11__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (46) hide show
  1. wandb/__init__.py +2 -1
  2. wandb/apis/public.py +51 -9
  3. wandb/apis/reports/blocks.py +1 -0
  4. wandb/cli/cli.py +14 -9
  5. wandb/env.py +11 -1
  6. wandb/integration/xgboost/xgboost.py +3 -3
  7. wandb/proto/v3/wandb_internal_pb2.py +300 -267
  8. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  9. wandb/proto/v3/wandb_telemetry_pb2.py +16 -16
  10. wandb/proto/v4/wandb_internal_pb2.py +260 -252
  11. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  12. wandb/proto/v4/wandb_telemetry_pb2.py +16 -16
  13. wandb/sdk/artifacts/artifact.py +9 -6
  14. wandb/sdk/artifacts/storage_handlers/s3_handler.py +12 -7
  15. wandb/sdk/data_types/image.py +1 -1
  16. wandb/sdk/internal/file_stream.py +2 -1
  17. wandb/sdk/internal/handler.py +24 -20
  18. wandb/sdk/internal/internal_api.py +9 -1
  19. wandb/sdk/internal/sender.py +4 -1
  20. wandb/sdk/internal/system/system_info.py +2 -2
  21. wandb/sdk/launch/__init__.py +5 -0
  22. wandb/sdk/launch/{launch.py → _launch.py} +53 -54
  23. wandb/sdk/launch/{launch_add.py → _launch_add.py} +34 -31
  24. wandb/sdk/launch/agent/agent.py +36 -18
  25. wandb/sdk/launch/agent/run_queue_item_file_saver.py +6 -4
  26. wandb/sdk/launch/runner/abstract.py +0 -2
  27. wandb/sdk/launch/runner/kubernetes_monitor.py +329 -0
  28. wandb/sdk/launch/runner/kubernetes_runner.py +44 -301
  29. wandb/sdk/launch/runner/local_container.py +5 -2
  30. wandb/sdk/launch/sweeps/scheduler.py +14 -10
  31. wandb/sdk/launch/sweeps/utils.py +5 -3
  32. wandb/sdk/launch/utils.py +3 -1
  33. wandb/sdk/lib/_settings_toposort_generated.py +5 -0
  34. wandb/sdk/lib/gql_request.py +3 -0
  35. wandb/sdk/lib/ipython.py +4 -0
  36. wandb/sdk/service/service.py +19 -6
  37. wandb/sdk/wandb_init.py +7 -2
  38. wandb/sdk/wandb_run.py +2 -5
  39. wandb/sdk/wandb_settings.py +48 -2
  40. wandb/util.py +1 -1
  41. {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/METADATA +4 -1
  42. {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/RECORD +46 -45
  43. {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/LICENSE +0 -0
  44. {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/WHEEL +0 -0
  45. {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/entry_points.txt +0 -0
  46. {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/top_level.txt +0 -0
@@ -4,10 +4,9 @@ import base64
4
4
  import json
5
5
  import logging
6
6
  import time
7
- from threading import Lock, Thread
8
7
  from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
9
8
 
10
- import urllib3
9
+ import yaml
11
10
 
12
11
  import wandb
13
12
  from wandb.apis.internal import Api
@@ -15,7 +14,7 @@ from wandb.sdk.launch.environment.abstract import AbstractEnvironment
15
14
  from wandb.sdk.launch.registry.abstract import AbstractRegistry
16
15
  from wandb.sdk.launch.registry.azure_container_registry import AzureContainerRegistry
17
16
  from wandb.sdk.launch.registry.local_registry import LocalRegistry
18
- from wandb.sdk.launch.runner.abstract import State, Status
17
+ from wandb.sdk.launch.runner.abstract import Status
19
18
  from wandb.util import get_module
20
19
 
21
20
  from .._project_spec import EntryPoint, LaunchProject
@@ -29,22 +28,20 @@ from ..utils import (
29
28
  make_name_dns_safe,
30
29
  )
31
30
  from .abstract import AbstractRun, AbstractRunner
31
+ from .kubernetes_monitor import KubernetesRunMonitor
32
32
 
33
33
  get_module(
34
34
  "kubernetes",
35
35
  required="Kubernetes runner requires the kubernetes package. Please install it with `pip install wandb[launch]`.",
36
36
  )
37
37
 
38
- from kubernetes import client, watch # type: ignore # noqa: E402
38
+ from kubernetes import client # type: ignore # noqa: E402
39
39
  from kubernetes.client.api.batch_v1_api import BatchV1Api # type: ignore # noqa: E402
40
40
  from kubernetes.client.api.core_v1_api import CoreV1Api # type: ignore # noqa: E402
41
41
  from kubernetes.client.api.custom_objects_api import ( # type: ignore # noqa: E402
42
42
  CustomObjectsApi,
43
43
  )
44
44
  from kubernetes.client.models.v1_job import V1Job # type: ignore # noqa: E402
45
- from kubernetes.client.models.v1_pod_status import ( # type: ignore # noqa: E402
46
- V1PodStatus,
47
- )
48
45
  from kubernetes.client.models.v1_secret import V1Secret # type: ignore # noqa: E402
49
46
  from kubernetes.client.rest import ApiException # type: ignore # noqa: E402
50
47
 
@@ -53,180 +50,6 @@ TIMEOUT = 5
53
50
  _logger = logging.getLogger(__name__)
54
51
 
55
52
 
56
- # Dict for mapping possible states of custom objects to the states we want to report
57
- # to the agent.
58
- CRD_STATE_DICT: Dict[str, State] = {
59
- "pending": "starting",
60
- "running": "running",
61
- "completed": "finished",
62
- "failed": "failed",
63
- "aborted": "failed",
64
- "terminating": "stopping",
65
- "terminated": "stopped",
66
- }
67
-
68
-
69
- def _is_preempted(status: "V1PodStatus") -> bool:
70
- """Check if this pod has been preempted."""
71
- if hasattr(status, "conditions") and status.conditions is not None:
72
- for condition in status.conditions:
73
- if condition.type == "DisruptionTarget" and condition.reason in [
74
- "EvictionByEvictionAPI",
75
- "PreemptionByScheduler",
76
- "TerminationByKubelet",
77
- ]:
78
- return True
79
- return False
80
-
81
-
82
- def _is_container_creating(status: "V1PodStatus") -> bool:
83
- """Check if this pod has started creating containers."""
84
- for container_status in status.container_statuses or []:
85
- if (
86
- container_status.state
87
- and container_status.state.waiting
88
- and container_status.state.waiting.reason == "ContainerCreating"
89
- ):
90
- return True
91
- return False
92
-
93
-
94
- class KubernetesRunMonitor:
95
- def __init__(
96
- self,
97
- job_field_selector: str,
98
- pod_label_selector: str,
99
- namespace: str,
100
- batch_api: "BatchV1Api",
101
- core_api: "CoreV1Api",
102
- ) -> None:
103
- """Initial KubernetesRunMonitor.
104
-
105
- Arguments:
106
- jobname: Name of the job.
107
-
108
- Returns:
109
- None.
110
- """
111
- self.pod_label_selector = pod_label_selector
112
- self.job_field_selector = job_field_selector
113
- self.namespace = namespace
114
- self.batch_api = batch_api
115
- self.core_api = core_api
116
-
117
- self._status_lock = Lock()
118
- self._status = Status("starting")
119
-
120
- self._watch_job_thread = Thread(target=self._watch_job, daemon=True)
121
- self._watch_pods_thread = Thread(target=self._watch_pods, daemon=True)
122
-
123
- self._job_watcher = watch.Watch()
124
- self._pod_watcher = watch.Watch()
125
-
126
- def start(self) -> None:
127
- """Start the run monitor."""
128
- if self._watch_job_thread.is_alive() or self._watch_pods_thread.is_alive():
129
- raise LaunchError(
130
- "Attempted to start monitor that has already started"
131
- ) # TODO: what should I do here?
132
- self._watch_job_thread.start()
133
- self._watch_pods_thread.start()
134
-
135
- def stop(self) -> None:
136
- """Stop the run monitor."""
137
- self._job_watcher.stop()
138
- self._pod_watcher.stop()
139
-
140
- def _set_status(self, status: Status) -> None:
141
- """Set the run status."""
142
- with self._status_lock:
143
- self._status = status
144
-
145
- def get_status(self) -> Status:
146
- """Get the run status."""
147
- with self._status_lock:
148
- return self._status
149
-
150
- def _watch_pods(self) -> None:
151
- """Watch for pods created matching the jobname."""
152
- try:
153
- # Stream with no timeout polling for pod status updates
154
- for event in self._pod_watcher.stream(
155
- self.core_api.list_namespaced_pod,
156
- namespace=self.namespace,
157
- label_selector=self.pod_label_selector,
158
- ):
159
- type = event.get("type")
160
- object = event.get("object")
161
-
162
- if type == "MODIFIED":
163
- if object.status.phase == "Running":
164
- self._set_status(Status("running"))
165
- if _is_preempted(object.status):
166
- self._set_status(Status("preempted"))
167
- self.stop()
168
- break
169
- if _is_container_creating(object.status):
170
- self._set_status(Status("starting"))
171
-
172
- # This can happen if the initial cluster connection fails.
173
- except ApiException as e:
174
- raise LaunchError(
175
- f"Exception when calling CoreV1Api.list_namespaced_pod with selector {self.pod_label_selector}: {e}"
176
- )
177
-
178
- # This can happen if the stream starts and gets broken, typically because
179
- # a thread is hanging. The kubernetes SDK is already implementing a
180
- # retry loop so if we get here it means that the pods cannot be monitored.
181
- except urllib3.exceptions.ProtocolError as e:
182
- state = self.get_status().state
183
- if state in ["failed", "finished"]:
184
- _logger.warning(
185
- f"Hanging pod monitor thread with selector {self.pod_label_selector}: {e}"
186
- )
187
- return
188
- raise LaunchError(
189
- f"Broken event stream for pod watcher in state '{state}' and selector {self.pod_label_selector}: {e}"
190
- )
191
-
192
- def _watch_job(self) -> None:
193
- """Watch for job matching the jobname."""
194
- try:
195
- for event in self._job_watcher.stream(
196
- self.batch_api.list_namespaced_job,
197
- namespace="default",
198
- field_selector=self.job_field_selector,
199
- ):
200
- object = event.get("object")
201
- if object.status.succeeded == 1:
202
- self._set_status(Status("finished"))
203
- self.stop()
204
- break
205
- elif object.status.failed is not None and object.status.failed >= 1:
206
- self._set_status(Status("failed"))
207
- self.stop()
208
- break
209
-
210
- # This can happen if the initial cluster connection fails.
211
- except ApiException as e:
212
- raise LaunchError(
213
- f"Exception when calling CoreV1Api.list_namespaced_job with selector {self.job_field_selector}: {e}"
214
- )
215
-
216
- # This can happen if the connection is lost to the Kubernetes API server
217
- # and cannot be re-established.
218
- except urllib3.exceptions.ProtocolError as e:
219
- state = self.get_status().state
220
- if state in ["finished", "failed"]:
221
- _logger.warning(
222
- f"Hanging job monitor thread with select {self.job_field_selector}: {e}"
223
- )
224
- return
225
- raise LaunchError(
226
- f"Broken event stream for job watcher in state {state} with selector {self.job_field_selector}: {e}"
227
- )
228
-
229
-
230
53
  class KubernetesSubmittedRun(AbstractRun):
231
54
  """Wrapper for a launched run on Kubernetes."""
232
55
 
@@ -266,9 +89,6 @@ class KubernetesSubmittedRun(AbstractRun):
266
89
  self.core_api = core_api
267
90
  self.name = name
268
91
  self.namespace = namespace
269
- self.job = self.batch_api.read_namespaced_job(
270
- name=self.name, namespace=self.namespace
271
- )
272
92
  self._fail_count = 0
273
93
  self.pod_names = pod_names
274
94
  self.secret = secret
@@ -309,7 +129,7 @@ class KubernetesSubmittedRun(AbstractRun):
309
129
  while True:
310
130
  status = self.get_status()
311
131
  wandb.termlog(f"{LOG_PREFIX}Job {self.name} status: {status}")
312
- if status.state != "running":
132
+ if status.state in ["finished", "failed", "preempted"]:
313
133
  break
314
134
  time.sleep(5)
315
135
  return (
@@ -331,35 +151,18 @@ class KubernetesSubmittedRun(AbstractRun):
331
151
  def get_status(self) -> Status:
332
152
  return self.monitor.get_status()
333
153
 
334
- def suspend(self) -> None:
335
- """Suspend the run."""
336
- self.job.spec.suspend = True
337
- self.batch_api.patch_namespaced_job(
338
- name=self.name, namespace=self.namespace, body=self.job
339
- )
340
- timeout = TIMEOUT
341
- job_response = self.batch_api.read_namespaced_job_status(
342
- name=self.name, namespace=self.namespace
343
- )
344
- while job_response.status.conditions is None and timeout > 0:
345
- time.sleep(1)
346
- timeout -= 1
347
- job_response = self.batch_api.read_namespaced_job_status(
348
- name=self.name, namespace=self.namespace
349
- )
350
-
351
- if timeout == 0 or job_response.status.conditions[0].type != "Suspended":
352
- raise LaunchError(
353
- "Failed to suspend job {}. Check Kubernetes dashboard for more info.".format(
354
- self.name
355
- )
356
- )
357
-
358
154
  def cancel(self) -> None:
359
155
  """Cancel the run."""
360
- self.suspend()
361
156
  self.monitor.stop()
362
- self.batch_api.delete_namespaced_job(name=self.name, namespace=self.namespace)
157
+ try:
158
+ self.batch_api.delete_namespaced_job(
159
+ namespace=self.namespace,
160
+ name=self.name,
161
+ )
162
+ except ApiException as e:
163
+ raise LaunchError(
164
+ f"Failed to delete Kubernetes Job {self.name} in namespace {self.namespace}: {str(e)}"
165
+ ) from e
363
166
 
364
167
 
365
168
  class CrdSubmittedRun(AbstractRun):
@@ -374,7 +177,7 @@ class CrdSubmittedRun(AbstractRun):
374
177
  namespace: str,
375
178
  core_api: CoreV1Api,
376
179
  custom_api: CustomObjectsApi,
377
- pod_names: List[str],
180
+ monitor: KubernetesRunMonitor,
378
181
  ) -> None:
379
182
  """Create a run object for tracking the progress of a CRD.
380
183
 
@@ -386,7 +189,7 @@ class CrdSubmittedRun(AbstractRun):
386
189
  namespace: The namespace of the CRD instance.
387
190
  core_api: The Kubernetes core API client.
388
191
  custom_api: The Kubernetes custom object API client.
389
- pod_names: The names of the pods associated with the CRD instance.
192
+ monitor: The run monitor.
390
193
 
391
194
  Raises:
392
195
  LaunchError: If the CRD instance does not exist.
@@ -398,20 +201,8 @@ class CrdSubmittedRun(AbstractRun):
398
201
  self.namespace = namespace
399
202
  self.core_api = core_api
400
203
  self.custom_api = custom_api
401
- self.pod_names = pod_names
402
204
  self._fail_count = 0
403
- try:
404
- self.job = self.custom_api.get_namespaced_custom_object(
405
- group=self.group,
406
- version=self.version,
407
- namespace=self.namespace,
408
- plural=self.plural,
409
- name=self.name,
410
- )
411
- except ApiException as e:
412
- raise LaunchError(
413
- f"Failed to get CRD {self.name} in namespace {self.namespace}: {str(e)}"
414
- ) from e
205
+ self.monitor = monitor
415
206
 
416
207
  @property
417
208
  def id(self) -> str:
@@ -423,7 +214,11 @@ class CrdSubmittedRun(AbstractRun):
423
214
  # TODO: test more carefully once we release multi-node support
424
215
  logs: Dict[str, Optional[str]] = {}
425
216
  try:
426
- for pod_name in self.pod_names:
217
+ pods = self.core_api.list_namespaced_pod(
218
+ label_selector=f"wandb/run-id={self.name}", namespace=self.namespace
219
+ )
220
+ pod_names = [pi.metadata.name for pi in pods.items]
221
+ for pod_name in pod_names:
427
222
  logs[pod_name] = self.core_api.read_namespaced_pod_log(
428
223
  name=pod_name, namespace=self.namespace
429
224
  )
@@ -437,30 +232,7 @@ class CrdSubmittedRun(AbstractRun):
437
232
 
438
233
  def get_status(self) -> Status:
439
234
  """Get status of custom object."""
440
- try:
441
- job_response = self.custom_api.get_namespaced_custom_object_status(
442
- group=self.group,
443
- version=self.version,
444
- namespace=self.namespace,
445
- plural=self.plural,
446
- name=self.name,
447
- )
448
- except ApiException as e:
449
- raise LaunchError(
450
- f"Failed to get CRD {self.name} in namespace {self.namespace}: {str(e)}"
451
- ) from e
452
- # Custom objects can technically define whater states and format the
453
- # response to the status request however they want. This checks for
454
- # the most common cases.
455
- status = job_response["status"]
456
- state = status.get("state")
457
- if isinstance(state, dict):
458
- state = state.get("phase")
459
- if state is None:
460
- raise LaunchError(
461
- f"Failed to get CRD {self.name} in namespace {self.namespace}: no state found"
462
- )
463
- return Status(CRD_STATE_DICT.get(state.lower(), "unknown"))
235
+ return self.monitor.get_status()
464
236
 
465
237
  def cancel(self) -> None:
466
238
  """Cancel the custom object."""
@@ -482,10 +254,9 @@ class CrdSubmittedRun(AbstractRun):
482
254
  while True:
483
255
  status = self.get_status()
484
256
  wandb.termlog(f"{LOG_PREFIX}Job {self.name} status: {status}")
485
- if status.state != "running":
486
- break
487
257
  time.sleep(5)
488
- return status.state == "finished"
258
+ if status.state in ["finished", "failed", "preempted"]:
259
+ return status.state == "finished"
489
260
 
490
261
 
491
262
  class KubernetesRunner(AbstractRunner):
@@ -512,48 +283,6 @@ class KubernetesRunner(AbstractRunner):
512
283
  self.environment = environment
513
284
  self.registry = registry
514
285
 
515
- def wait_job_launch(
516
- self,
517
- job_name: str,
518
- namespace: str,
519
- core_api: "CoreV1Api",
520
- label: str = "job-name",
521
- ) -> List[str]:
522
- """Wait for a job to be launched and return the pod names.
523
-
524
- Arguments:
525
- job_name: The name of the job.
526
- namespace: The namespace of the job.
527
- core_api: The Kubernetes core API client.
528
- label: The label key to match against job_name.
529
-
530
- Returns:
531
- The names of the pods associated with the job.
532
- """
533
- pods = core_api.list_namespaced_pod(
534
- label_selector=f"{label}={job_name}", namespace=namespace
535
- )
536
- timeout = TIMEOUT
537
- while len(pods.items) == 0 and timeout > 0:
538
- time.sleep(1)
539
- timeout -= 1
540
- pods = core_api.list_namespaced_pod(
541
- label_selector=f"{label}={job_name}", namespace=namespace
542
- )
543
-
544
- if timeout == 0:
545
- raise LaunchError(
546
- "No pods found for job {}. Check dashboard to see if job was launched successfully.".format(
547
- job_name
548
- )
549
- )
550
-
551
- pod_names = [pi.metadata.name for pi in pods.items]
552
- wandb.termlog(
553
- f"{LOG_PREFIX}Job {job_name} created on pod(s) {', '.join(pod_names)}. See logs with e.g. `kubectl logs {pod_names[0]} -n {namespace}`."
554
- )
555
- return pod_names
556
-
557
286
  def get_namespace(
558
287
  self, resource_args: Dict[str, Any], context: Dict[str, Any]
559
288
  ) -> str:
@@ -742,16 +471,27 @@ class KubernetesRunner(AbstractRunner):
742
471
  body=resource_args,
743
472
  )
744
473
  except ApiException as e:
474
+ body = json.loads(e.body)
475
+ body_yaml = yaml.dump(body)
745
476
  raise LaunchError(
746
- f"Error creating CRD of kind {kind}: {e.status} {e.reason}"
477
+ f"Error creating CRD of kind {kind}: {e.status} {e.reason}\n{body_yaml}"
747
478
  ) from e
748
479
  name = response.get("metadata", {}).get("name")
749
480
  _logger.info(f"Created {kind} {response['metadata']['name']}")
750
481
  core = client.CoreV1Api(api_client)
751
- pod_names = self.wait_job_launch(
752
- launch_project.run_id, namespace, core, label="wandb/run-id"
482
+ run_monitor = KubernetesRunMonitor(
483
+ job_field_selector=f"metadata.name={name}",
484
+ pod_label_selector=f"wandb/run-id={launch_project.run_id}",
485
+ namespace=namespace,
486
+ batch_api=None,
487
+ core_api=core,
488
+ custom_api=api,
489
+ group=group,
490
+ version=version,
491
+ plural=plural,
753
492
  )
754
- return CrdSubmittedRun(
493
+ run_monitor.start()
494
+ submitted_run = CrdSubmittedRun(
755
495
  name=name,
756
496
  group=group,
757
497
  version=version,
@@ -759,8 +499,11 @@ class KubernetesRunner(AbstractRunner):
759
499
  plural=plural,
760
500
  core_api=client.CoreV1Api(api_client),
761
501
  custom_api=api,
762
- pod_names=pod_names,
502
+ monitor=run_monitor,
763
503
  )
504
+ if self.backend_config[PROJECT_SYNCHRONOUS]:
505
+ submitted_run.wait()
506
+ return submitted_run
764
507
 
765
508
  batch_api = kubernetes.client.BatchV1Api(api_client)
766
509
  core_api = kubernetes.client.CoreV1Api(api_client)
@@ -5,7 +5,7 @@ import subprocess
5
5
  import sys
6
6
  import threading
7
7
  import time
8
- from typing import Any, Dict, List, Optional
8
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional
9
9
 
10
10
  import wandb
11
11
  from wandb.sdk.launch.environment.abstract import AbstractEnvironment
@@ -26,6 +26,9 @@ from ..utils import (
26
26
  )
27
27
  from .abstract import AbstractRun, AbstractRunner, Status
28
28
 
29
+ if TYPE_CHECKING:
30
+ from wandb.apis.internal import Api
31
+
29
32
  _logger = logging.getLogger(__name__)
30
33
 
31
34
 
@@ -95,7 +98,7 @@ class LocalContainerRunner(AbstractRunner):
95
98
 
96
99
  def __init__(
97
100
  self,
98
- api: wandb.apis.internal.Api,
101
+ api: "Api",
99
102
  backend_config: Dict[str, Any],
100
103
  environment: AbstractEnvironment,
101
104
  registry: AbstractRegistry,
@@ -9,26 +9,28 @@ import traceback
9
9
  from abc import ABC, abstractmethod
10
10
  from dataclasses import dataclass
11
11
  from enum import Enum
12
- from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
12
+ from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union
13
13
 
14
14
  import click
15
15
  import yaml
16
16
 
17
17
  import wandb
18
- import wandb.apis.public as public
19
- from wandb.apis.internal import Api
20
- from wandb.apis.public import Api as PublicApi
21
- from wandb.apis.public import QueuedRun, Run
22
18
  from wandb.errors import CommError
19
+ from wandb.sdk.launch._launch_add import launch_add
23
20
  from wandb.sdk.launch.errors import LaunchError
24
- from wandb.sdk.launch.launch_add import launch_add
25
21
  from wandb.sdk.launch.sweeps import SchedulerError
26
22
  from wandb.sdk.launch.sweeps.utils import (
27
23
  create_sweep_command_args,
28
24
  make_launch_sweep_entrypoint,
29
25
  )
30
26
  from wandb.sdk.lib.runid import generate_id
31
- from wandb.sdk.wandb_run import Run as SdkRun
27
+
28
+ if TYPE_CHECKING:
29
+ import wandb.apis.public as public
30
+ from wandb.apis.internal import Api
31
+ from wandb.apis.public import QueuedRun, Run
32
+ from wandb.sdk.wandb_run import Run as SdkRun
33
+
32
34
 
33
35
  _logger = logging.getLogger(__name__)
34
36
  LOG_PREFIX = f"{click.style('sched:', fg='cyan')} "
@@ -84,7 +86,7 @@ class SweepRun:
84
86
  id: str
85
87
  worker_id: int
86
88
  state: RunState = RunState.RUNNING
87
- queued_run: Optional[public.QueuedRun] = None
89
+ queued_run: Optional["public.QueuedRun"] = None
88
90
  args: Optional[Dict[str, Any]] = None
89
91
  logs: Optional[List[str]] = None
90
92
 
@@ -98,7 +100,7 @@ class Scheduler(ABC):
98
100
 
99
101
  def __init__(
100
102
  self,
101
- api: Api,
103
+ api: "Api",
102
104
  *args: Optional[Any],
103
105
  polling_sleep: Optional[float] = None,
104
106
  sweep_id: Optional[str] = None,
@@ -108,6 +110,8 @@ class Scheduler(ABC):
108
110
  num_workers: Optional[Union[int, str]] = None,
109
111
  **kwargs: Optional[Any],
110
112
  ):
113
+ from wandb.apis.public import Api as PublicApi
114
+
111
115
  self._api = api
112
116
  self._public_api = PublicApi()
113
117
  self._entity = (
@@ -244,7 +248,7 @@ class Scheduler(ABC):
244
248
  _id: w for _id, w in self._workers.items() if _id not in self.busy_workers
245
249
  }
246
250
 
247
- def _init_wandb_run(self) -> SdkRun:
251
+ def _init_wandb_run(self) -> "SdkRun":
248
252
  """Controls resume or init logic for a scheduler wandb run."""
249
253
  _type = self._kwargs.get("sweep_type", "sweep")
250
254
  run: SdkRun = wandb.init(
@@ -1,15 +1,17 @@
1
1
  import json
2
2
  import os
3
3
  import re
4
- from typing import Any, Dict, List, Optional, Tuple, Union
4
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
5
5
 
6
6
  import yaml
7
7
 
8
8
  import wandb
9
9
  from wandb import util
10
- from wandb.apis.public import Api as PublicApi
11
10
  from wandb.sdk.launch.errors import LaunchError
12
11
 
12
+ if TYPE_CHECKING:
13
+ from wandb.apis.public import Api as PublicApi
14
+
13
15
  DEFAULT_SWEEP_COMMAND: List[str] = [
14
16
  "${env}",
15
17
  "${interpreter}",
@@ -276,7 +278,7 @@ def make_launch_sweep_entrypoint(
276
278
  return entry_point, macro_args
277
279
 
278
280
 
279
- def check_job_exists(public_api: PublicApi, job: Optional[str]) -> bool:
281
+ def check_job_exists(public_api: "PublicApi", job: Optional[str]) -> bool:
280
282
  """Check if the job exists using the public api.
281
283
 
282
284
  Returns: True if no job is passed, or if the job exists.
wandb/sdk/launch/utils.py CHANGED
@@ -127,7 +127,9 @@ def set_project_entity_defaults(
127
127
  prefix = ""
128
128
  if platform.system() != "Windows" and sys.stdout.encoding == "UTF-8":
129
129
  prefix = "🚀 "
130
- wandb.termlog(f"{LOG_PREFIX}{prefix}Launching run into {entity}/{project}")
130
+ wandb.termlog(
131
+ f"{LOG_PREFIX}{prefix}Launching run into {entity}{'/' + project if project else ''}"
132
+ )
131
133
  return project, entity
132
134
 
133
135
 
@@ -58,6 +58,7 @@ _Setting = Literal[
58
58
  "_sync",
59
59
  "_os",
60
60
  "_platform",
61
+ "_proxies",
61
62
  "_python",
62
63
  "_runqueue_item_id",
63
64
  "_require_nexus",
@@ -84,6 +85,7 @@ _Setting = Literal[
84
85
  "azure_account_url_to_access_key",
85
86
  "base_url",
86
87
  "code_dir",
88
+ "colab_url",
87
89
  "config_paths",
88
90
  "console",
89
91
  "deployment",
@@ -121,6 +123,7 @@ _Setting = Literal[
121
123
  "notebook_name",
122
124
  "problem",
123
125
  "program",
126
+ "program_abspath",
124
127
  "program_relpath",
125
128
  "project",
126
129
  "project_url",
@@ -209,6 +212,7 @@ SETTINGS_TOPOLOGICALLY_SORTED: Final[Tuple[_Setting, ...]] = (
209
212
  "tmp_dir",
210
213
  "_tmp_code_dir",
211
214
  "_windows",
215
+ "colab_url",
212
216
  "is_local",
213
217
  "deployment",
214
218
  "disable_code",
@@ -220,6 +224,7 @@ SETTINGS_TOPOLOGICALLY_SORTED: Final[Tuple[_Setting, ...]] = (
220
224
  "log_symlink_internal",
221
225
  "log_symlink_user",
222
226
  "log_user",
227
+ "program",
223
228
  "project_url",
224
229
  "resume_fname",
225
230
  "run_url",
@@ -20,6 +20,7 @@ class GraphQLSession(HTTPTransport):
20
20
  auth: Optional[Union[Tuple[str, str], Callable]] = None,
21
21
  use_json: bool = False,
22
22
  timeout: Optional[Union[int, float]] = None,
23
+ proxies: Optional[Dict[str, str]] = None,
23
24
  **kwargs: Any,
24
25
  ) -> None:
25
26
  """Setup a session for sending GraphQL queries and mutations.
@@ -32,6 +33,8 @@ class GraphQLSession(HTTPTransport):
32
33
  """
33
34
  super().__init__(url, **kwargs)
34
35
  self.session = requests.Session()
36
+ if proxies:
37
+ self.session.proxies.update(proxies)
35
38
  self.session.auth = auth
36
39
  self.default_timeout = timeout
37
40
  self.use_json = use_json
wandb/sdk/lib/ipython.py CHANGED
@@ -45,6 +45,10 @@ def _get_python_type() -> PythonType:
45
45
  (get_ipython().config.get("IPKernelApp", {}) or {})
46
46
  .get("connection_file", "")
47
47
  .lower()
48
+ ) or (
49
+ (get_ipython().config.get("ColabKernelApp", {}) or {})
50
+ .get("connection_file", "")
51
+ .lower()
48
52
  )
49
53
 
50
54
  if (