wandb 0.15.10__py3-none-any.whl → 0.15.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. wandb/__init__.py +2 -1
  2. wandb/apis/public.py +51 -9
  3. wandb/apis/reports/blocks.py +1 -0
  4. wandb/cli/cli.py +14 -9
  5. wandb/env.py +11 -1
  6. wandb/integration/xgboost/xgboost.py +3 -3
  7. wandb/proto/v3/wandb_internal_pb2.py +300 -267
  8. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  9. wandb/proto/v3/wandb_telemetry_pb2.py +16 -16
  10. wandb/proto/v4/wandb_internal_pb2.py +260 -252
  11. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  12. wandb/proto/v4/wandb_telemetry_pb2.py +16 -16
  13. wandb/sdk/artifacts/artifact.py +9 -6
  14. wandb/sdk/artifacts/storage_handlers/s3_handler.py +12 -7
  15. wandb/sdk/data_types/image.py +1 -1
  16. wandb/sdk/internal/file_stream.py +2 -1
  17. wandb/sdk/internal/handler.py +24 -20
  18. wandb/sdk/internal/internal_api.py +9 -1
  19. wandb/sdk/internal/sender.py +4 -1
  20. wandb/sdk/internal/system/system_info.py +2 -2
  21. wandb/sdk/launch/__init__.py +5 -0
  22. wandb/sdk/launch/{launch.py → _launch.py} +53 -54
  23. wandb/sdk/launch/{launch_add.py → _launch_add.py} +34 -31
  24. wandb/sdk/launch/agent/agent.py +36 -18
  25. wandb/sdk/launch/agent/run_queue_item_file_saver.py +6 -4
  26. wandb/sdk/launch/runner/abstract.py +0 -2
  27. wandb/sdk/launch/runner/kubernetes_monitor.py +329 -0
  28. wandb/sdk/launch/runner/kubernetes_runner.py +44 -301
  29. wandb/sdk/launch/runner/local_container.py +5 -2
  30. wandb/sdk/launch/sweeps/scheduler.py +14 -10
  31. wandb/sdk/launch/sweeps/utils.py +5 -3
  32. wandb/sdk/launch/utils.py +3 -1
  33. wandb/sdk/lib/_settings_toposort_generated.py +5 -0
  34. wandb/sdk/lib/gql_request.py +3 -0
  35. wandb/sdk/lib/ipython.py +4 -0
  36. wandb/sdk/service/service.py +19 -6
  37. wandb/sdk/wandb_init.py +7 -2
  38. wandb/sdk/wandb_run.py +2 -5
  39. wandb/sdk/wandb_settings.py +48 -2
  40. wandb/util.py +1 -1
  41. {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/METADATA +4 -1
  42. {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/RECORD +46 -45
  43. {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/LICENSE +0 -0
  44. {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/WHEEL +0 -0
  45. {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/entry_points.txt +0 -0
  46. {wandb-0.15.10.dist-info → wandb-0.15.11.dist-info}/top_level.txt +0 -0
@@ -4,10 +4,9 @@ import base64
4
4
  import json
5
5
  import logging
6
6
  import time
7
- from threading import Lock, Thread
8
7
  from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
9
8
 
10
- import urllib3
9
+ import yaml
11
10
 
12
11
  import wandb
13
12
  from wandb.apis.internal import Api
@@ -15,7 +14,7 @@ from wandb.sdk.launch.environment.abstract import AbstractEnvironment
15
14
  from wandb.sdk.launch.registry.abstract import AbstractRegistry
16
15
  from wandb.sdk.launch.registry.azure_container_registry import AzureContainerRegistry
17
16
  from wandb.sdk.launch.registry.local_registry import LocalRegistry
18
- from wandb.sdk.launch.runner.abstract import State, Status
17
+ from wandb.sdk.launch.runner.abstract import Status
19
18
  from wandb.util import get_module
20
19
 
21
20
  from .._project_spec import EntryPoint, LaunchProject
@@ -29,22 +28,20 @@ from ..utils import (
29
28
  make_name_dns_safe,
30
29
  )
31
30
  from .abstract import AbstractRun, AbstractRunner
31
+ from .kubernetes_monitor import KubernetesRunMonitor
32
32
 
33
33
  get_module(
34
34
  "kubernetes",
35
35
  required="Kubernetes runner requires the kubernetes package. Please install it with `pip install wandb[launch]`.",
36
36
  )
37
37
 
38
- from kubernetes import client, watch # type: ignore # noqa: E402
38
+ from kubernetes import client # type: ignore # noqa: E402
39
39
  from kubernetes.client.api.batch_v1_api import BatchV1Api # type: ignore # noqa: E402
40
40
  from kubernetes.client.api.core_v1_api import CoreV1Api # type: ignore # noqa: E402
41
41
  from kubernetes.client.api.custom_objects_api import ( # type: ignore # noqa: E402
42
42
  CustomObjectsApi,
43
43
  )
44
44
  from kubernetes.client.models.v1_job import V1Job # type: ignore # noqa: E402
45
- from kubernetes.client.models.v1_pod_status import ( # type: ignore # noqa: E402
46
- V1PodStatus,
47
- )
48
45
  from kubernetes.client.models.v1_secret import V1Secret # type: ignore # noqa: E402
49
46
  from kubernetes.client.rest import ApiException # type: ignore # noqa: E402
50
47
 
@@ -53,180 +50,6 @@ TIMEOUT = 5
53
50
  _logger = logging.getLogger(__name__)
54
51
 
55
52
 
56
- # Dict for mapping possible states of custom objects to the states we want to report
57
- # to the agent.
58
- CRD_STATE_DICT: Dict[str, State] = {
59
- "pending": "starting",
60
- "running": "running",
61
- "completed": "finished",
62
- "failed": "failed",
63
- "aborted": "failed",
64
- "terminating": "stopping",
65
- "terminated": "stopped",
66
- }
67
-
68
-
69
- def _is_preempted(status: "V1PodStatus") -> bool:
70
- """Check if this pod has been preempted."""
71
- if hasattr(status, "conditions") and status.conditions is not None:
72
- for condition in status.conditions:
73
- if condition.type == "DisruptionTarget" and condition.reason in [
74
- "EvictionByEvictionAPI",
75
- "PreemptionByScheduler",
76
- "TerminationByKubelet",
77
- ]:
78
- return True
79
- return False
80
-
81
-
82
- def _is_container_creating(status: "V1PodStatus") -> bool:
83
- """Check if this pod has started creating containers."""
84
- for container_status in status.container_statuses or []:
85
- if (
86
- container_status.state
87
- and container_status.state.waiting
88
- and container_status.state.waiting.reason == "ContainerCreating"
89
- ):
90
- return True
91
- return False
92
-
93
-
94
- class KubernetesRunMonitor:
95
- def __init__(
96
- self,
97
- job_field_selector: str,
98
- pod_label_selector: str,
99
- namespace: str,
100
- batch_api: "BatchV1Api",
101
- core_api: "CoreV1Api",
102
- ) -> None:
103
- """Initial KubernetesRunMonitor.
104
-
105
- Arguments:
106
- jobname: Name of the job.
107
-
108
- Returns:
109
- None.
110
- """
111
- self.pod_label_selector = pod_label_selector
112
- self.job_field_selector = job_field_selector
113
- self.namespace = namespace
114
- self.batch_api = batch_api
115
- self.core_api = core_api
116
-
117
- self._status_lock = Lock()
118
- self._status = Status("starting")
119
-
120
- self._watch_job_thread = Thread(target=self._watch_job, daemon=True)
121
- self._watch_pods_thread = Thread(target=self._watch_pods, daemon=True)
122
-
123
- self._job_watcher = watch.Watch()
124
- self._pod_watcher = watch.Watch()
125
-
126
- def start(self) -> None:
127
- """Start the run monitor."""
128
- if self._watch_job_thread.is_alive() or self._watch_pods_thread.is_alive():
129
- raise LaunchError(
130
- "Attempted to start monitor that has already started"
131
- ) # TODO: what should I do here?
132
- self._watch_job_thread.start()
133
- self._watch_pods_thread.start()
134
-
135
- def stop(self) -> None:
136
- """Stop the run monitor."""
137
- self._job_watcher.stop()
138
- self._pod_watcher.stop()
139
-
140
- def _set_status(self, status: Status) -> None:
141
- """Set the run status."""
142
- with self._status_lock:
143
- self._status = status
144
-
145
- def get_status(self) -> Status:
146
- """Get the run status."""
147
- with self._status_lock:
148
- return self._status
149
-
150
- def _watch_pods(self) -> None:
151
- """Watch for pods created matching the jobname."""
152
- try:
153
- # Stream with no timeout polling for pod status updates
154
- for event in self._pod_watcher.stream(
155
- self.core_api.list_namespaced_pod,
156
- namespace=self.namespace,
157
- label_selector=self.pod_label_selector,
158
- ):
159
- type = event.get("type")
160
- object = event.get("object")
161
-
162
- if type == "MODIFIED":
163
- if object.status.phase == "Running":
164
- self._set_status(Status("running"))
165
- if _is_preempted(object.status):
166
- self._set_status(Status("preempted"))
167
- self.stop()
168
- break
169
- if _is_container_creating(object.status):
170
- self._set_status(Status("starting"))
171
-
172
- # This can happen if the initial cluster connection fails.
173
- except ApiException as e:
174
- raise LaunchError(
175
- f"Exception when calling CoreV1Api.list_namespaced_pod with selector {self.pod_label_selector}: {e}"
176
- )
177
-
178
- # This can happen if the stream starts and gets broken, typically because
179
- # a thread is hanging. The kubernetes SDK is already implementing a
180
- # retry loop so if we get here it means that the pods cannot be monitored.
181
- except urllib3.exceptions.ProtocolError as e:
182
- state = self.get_status().state
183
- if state in ["failed", "finished"]:
184
- _logger.warning(
185
- f"Hanging pod monitor thread with selector {self.pod_label_selector}: {e}"
186
- )
187
- return
188
- raise LaunchError(
189
- f"Broken event stream for pod watcher in state '{state}' and selector {self.pod_label_selector}: {e}"
190
- )
191
-
192
- def _watch_job(self) -> None:
193
- """Watch for job matching the jobname."""
194
- try:
195
- for event in self._job_watcher.stream(
196
- self.batch_api.list_namespaced_job,
197
- namespace="default",
198
- field_selector=self.job_field_selector,
199
- ):
200
- object = event.get("object")
201
- if object.status.succeeded == 1:
202
- self._set_status(Status("finished"))
203
- self.stop()
204
- break
205
- elif object.status.failed is not None and object.status.failed >= 1:
206
- self._set_status(Status("failed"))
207
- self.stop()
208
- break
209
-
210
- # This can happen if the initial cluster connection fails.
211
- except ApiException as e:
212
- raise LaunchError(
213
- f"Exception when calling CoreV1Api.list_namespaced_job with selector {self.job_field_selector}: {e}"
214
- )
215
-
216
- # This can happen if the connection is lost to the Kubernetes API server
217
- # and cannot be re-established.
218
- except urllib3.exceptions.ProtocolError as e:
219
- state = self.get_status().state
220
- if state in ["finished", "failed"]:
221
- _logger.warning(
222
- f"Hanging job monitor thread with select {self.job_field_selector}: {e}"
223
- )
224
- return
225
- raise LaunchError(
226
- f"Broken event stream for job watcher in state {state} with selector {self.job_field_selector}: {e}"
227
- )
228
-
229
-
230
53
  class KubernetesSubmittedRun(AbstractRun):
231
54
  """Wrapper for a launched run on Kubernetes."""
232
55
 
@@ -266,9 +89,6 @@ class KubernetesSubmittedRun(AbstractRun):
266
89
  self.core_api = core_api
267
90
  self.name = name
268
91
  self.namespace = namespace
269
- self.job = self.batch_api.read_namespaced_job(
270
- name=self.name, namespace=self.namespace
271
- )
272
92
  self._fail_count = 0
273
93
  self.pod_names = pod_names
274
94
  self.secret = secret
@@ -309,7 +129,7 @@ class KubernetesSubmittedRun(AbstractRun):
309
129
  while True:
310
130
  status = self.get_status()
311
131
  wandb.termlog(f"{LOG_PREFIX}Job {self.name} status: {status}")
312
- if status.state != "running":
132
+ if status.state in ["finished", "failed", "preempted"]:
313
133
  break
314
134
  time.sleep(5)
315
135
  return (
@@ -331,35 +151,18 @@ class KubernetesSubmittedRun(AbstractRun):
331
151
  def get_status(self) -> Status:
332
152
  return self.monitor.get_status()
333
153
 
334
- def suspend(self) -> None:
335
- """Suspend the run."""
336
- self.job.spec.suspend = True
337
- self.batch_api.patch_namespaced_job(
338
- name=self.name, namespace=self.namespace, body=self.job
339
- )
340
- timeout = TIMEOUT
341
- job_response = self.batch_api.read_namespaced_job_status(
342
- name=self.name, namespace=self.namespace
343
- )
344
- while job_response.status.conditions is None and timeout > 0:
345
- time.sleep(1)
346
- timeout -= 1
347
- job_response = self.batch_api.read_namespaced_job_status(
348
- name=self.name, namespace=self.namespace
349
- )
350
-
351
- if timeout == 0 or job_response.status.conditions[0].type != "Suspended":
352
- raise LaunchError(
353
- "Failed to suspend job {}. Check Kubernetes dashboard for more info.".format(
354
- self.name
355
- )
356
- )
357
-
358
154
  def cancel(self) -> None:
359
155
  """Cancel the run."""
360
- self.suspend()
361
156
  self.monitor.stop()
362
- self.batch_api.delete_namespaced_job(name=self.name, namespace=self.namespace)
157
+ try:
158
+ self.batch_api.delete_namespaced_job(
159
+ namespace=self.namespace,
160
+ name=self.name,
161
+ )
162
+ except ApiException as e:
163
+ raise LaunchError(
164
+ f"Failed to delete Kubernetes Job {self.name} in namespace {self.namespace}: {str(e)}"
165
+ ) from e
363
166
 
364
167
 
365
168
  class CrdSubmittedRun(AbstractRun):
@@ -374,7 +177,7 @@ class CrdSubmittedRun(AbstractRun):
374
177
  namespace: str,
375
178
  core_api: CoreV1Api,
376
179
  custom_api: CustomObjectsApi,
377
- pod_names: List[str],
180
+ monitor: KubernetesRunMonitor,
378
181
  ) -> None:
379
182
  """Create a run object for tracking the progress of a CRD.
380
183
 
@@ -386,7 +189,7 @@ class CrdSubmittedRun(AbstractRun):
386
189
  namespace: The namespace of the CRD instance.
387
190
  core_api: The Kubernetes core API client.
388
191
  custom_api: The Kubernetes custom object API client.
389
- pod_names: The names of the pods associated with the CRD instance.
192
+ monitor: The run monitor.
390
193
 
391
194
  Raises:
392
195
  LaunchError: If the CRD instance does not exist.
@@ -398,20 +201,8 @@ class CrdSubmittedRun(AbstractRun):
398
201
  self.namespace = namespace
399
202
  self.core_api = core_api
400
203
  self.custom_api = custom_api
401
- self.pod_names = pod_names
402
204
  self._fail_count = 0
403
- try:
404
- self.job = self.custom_api.get_namespaced_custom_object(
405
- group=self.group,
406
- version=self.version,
407
- namespace=self.namespace,
408
- plural=self.plural,
409
- name=self.name,
410
- )
411
- except ApiException as e:
412
- raise LaunchError(
413
- f"Failed to get CRD {self.name} in namespace {self.namespace}: {str(e)}"
414
- ) from e
205
+ self.monitor = monitor
415
206
 
416
207
  @property
417
208
  def id(self) -> str:
@@ -423,7 +214,11 @@ class CrdSubmittedRun(AbstractRun):
423
214
  # TODO: test more carefully once we release multi-node support
424
215
  logs: Dict[str, Optional[str]] = {}
425
216
  try:
426
- for pod_name in self.pod_names:
217
+ pods = self.core_api.list_namespaced_pod(
218
+ label_selector=f"wandb/run-id={self.name}", namespace=self.namespace
219
+ )
220
+ pod_names = [pi.metadata.name for pi in pods.items]
221
+ for pod_name in pod_names:
427
222
  logs[pod_name] = self.core_api.read_namespaced_pod_log(
428
223
  name=pod_name, namespace=self.namespace
429
224
  )
@@ -437,30 +232,7 @@ class CrdSubmittedRun(AbstractRun):
437
232
 
438
233
  def get_status(self) -> Status:
439
234
  """Get status of custom object."""
440
- try:
441
- job_response = self.custom_api.get_namespaced_custom_object_status(
442
- group=self.group,
443
- version=self.version,
444
- namespace=self.namespace,
445
- plural=self.plural,
446
- name=self.name,
447
- )
448
- except ApiException as e:
449
- raise LaunchError(
450
- f"Failed to get CRD {self.name} in namespace {self.namespace}: {str(e)}"
451
- ) from e
452
- # Custom objects can technically define whater states and format the
453
- # response to the status request however they want. This checks for
454
- # the most common cases.
455
- status = job_response["status"]
456
- state = status.get("state")
457
- if isinstance(state, dict):
458
- state = state.get("phase")
459
- if state is None:
460
- raise LaunchError(
461
- f"Failed to get CRD {self.name} in namespace {self.namespace}: no state found"
462
- )
463
- return Status(CRD_STATE_DICT.get(state.lower(), "unknown"))
235
+ return self.monitor.get_status()
464
236
 
465
237
  def cancel(self) -> None:
466
238
  """Cancel the custom object."""
@@ -482,10 +254,9 @@ class CrdSubmittedRun(AbstractRun):
482
254
  while True:
483
255
  status = self.get_status()
484
256
  wandb.termlog(f"{LOG_PREFIX}Job {self.name} status: {status}")
485
- if status.state != "running":
486
- break
487
257
  time.sleep(5)
488
- return status.state == "finished"
258
+ if status.state in ["finished", "failed", "preempted"]:
259
+ return status.state == "finished"
489
260
 
490
261
 
491
262
  class KubernetesRunner(AbstractRunner):
@@ -512,48 +283,6 @@ class KubernetesRunner(AbstractRunner):
512
283
  self.environment = environment
513
284
  self.registry = registry
514
285
 
515
- def wait_job_launch(
516
- self,
517
- job_name: str,
518
- namespace: str,
519
- core_api: "CoreV1Api",
520
- label: str = "job-name",
521
- ) -> List[str]:
522
- """Wait for a job to be launched and return the pod names.
523
-
524
- Arguments:
525
- job_name: The name of the job.
526
- namespace: The namespace of the job.
527
- core_api: The Kubernetes core API client.
528
- label: The label key to match against job_name.
529
-
530
- Returns:
531
- The names of the pods associated with the job.
532
- """
533
- pods = core_api.list_namespaced_pod(
534
- label_selector=f"{label}={job_name}", namespace=namespace
535
- )
536
- timeout = TIMEOUT
537
- while len(pods.items) == 0 and timeout > 0:
538
- time.sleep(1)
539
- timeout -= 1
540
- pods = core_api.list_namespaced_pod(
541
- label_selector=f"{label}={job_name}", namespace=namespace
542
- )
543
-
544
- if timeout == 0:
545
- raise LaunchError(
546
- "No pods found for job {}. Check dashboard to see if job was launched successfully.".format(
547
- job_name
548
- )
549
- )
550
-
551
- pod_names = [pi.metadata.name for pi in pods.items]
552
- wandb.termlog(
553
- f"{LOG_PREFIX}Job {job_name} created on pod(s) {', '.join(pod_names)}. See logs with e.g. `kubectl logs {pod_names[0]} -n {namespace}`."
554
- )
555
- return pod_names
556
-
557
286
  def get_namespace(
558
287
  self, resource_args: Dict[str, Any], context: Dict[str, Any]
559
288
  ) -> str:
@@ -742,16 +471,27 @@ class KubernetesRunner(AbstractRunner):
742
471
  body=resource_args,
743
472
  )
744
473
  except ApiException as e:
474
+ body = json.loads(e.body)
475
+ body_yaml = yaml.dump(body)
745
476
  raise LaunchError(
746
- f"Error creating CRD of kind {kind}: {e.status} {e.reason}"
477
+ f"Error creating CRD of kind {kind}: {e.status} {e.reason}\n{body_yaml}"
747
478
  ) from e
748
479
  name = response.get("metadata", {}).get("name")
749
480
  _logger.info(f"Created {kind} {response['metadata']['name']}")
750
481
  core = client.CoreV1Api(api_client)
751
- pod_names = self.wait_job_launch(
752
- launch_project.run_id, namespace, core, label="wandb/run-id"
482
+ run_monitor = KubernetesRunMonitor(
483
+ job_field_selector=f"metadata.name={name}",
484
+ pod_label_selector=f"wandb/run-id={launch_project.run_id}",
485
+ namespace=namespace,
486
+ batch_api=None,
487
+ core_api=core,
488
+ custom_api=api,
489
+ group=group,
490
+ version=version,
491
+ plural=plural,
753
492
  )
754
- return CrdSubmittedRun(
493
+ run_monitor.start()
494
+ submitted_run = CrdSubmittedRun(
755
495
  name=name,
756
496
  group=group,
757
497
  version=version,
@@ -759,8 +499,11 @@ class KubernetesRunner(AbstractRunner):
759
499
  plural=plural,
760
500
  core_api=client.CoreV1Api(api_client),
761
501
  custom_api=api,
762
- pod_names=pod_names,
502
+ monitor=run_monitor,
763
503
  )
504
+ if self.backend_config[PROJECT_SYNCHRONOUS]:
505
+ submitted_run.wait()
506
+ return submitted_run
764
507
 
765
508
  batch_api = kubernetes.client.BatchV1Api(api_client)
766
509
  core_api = kubernetes.client.CoreV1Api(api_client)
@@ -5,7 +5,7 @@ import subprocess
5
5
  import sys
6
6
  import threading
7
7
  import time
8
- from typing import Any, Dict, List, Optional
8
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional
9
9
 
10
10
  import wandb
11
11
  from wandb.sdk.launch.environment.abstract import AbstractEnvironment
@@ -26,6 +26,9 @@ from ..utils import (
26
26
  )
27
27
  from .abstract import AbstractRun, AbstractRunner, Status
28
28
 
29
+ if TYPE_CHECKING:
30
+ from wandb.apis.internal import Api
31
+
29
32
  _logger = logging.getLogger(__name__)
30
33
 
31
34
 
@@ -95,7 +98,7 @@ class LocalContainerRunner(AbstractRunner):
95
98
 
96
99
  def __init__(
97
100
  self,
98
- api: wandb.apis.internal.Api,
101
+ api: "Api",
99
102
  backend_config: Dict[str, Any],
100
103
  environment: AbstractEnvironment,
101
104
  registry: AbstractRegistry,
@@ -9,26 +9,28 @@ import traceback
9
9
  from abc import ABC, abstractmethod
10
10
  from dataclasses import dataclass
11
11
  from enum import Enum
12
- from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
12
+ from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union
13
13
 
14
14
  import click
15
15
  import yaml
16
16
 
17
17
  import wandb
18
- import wandb.apis.public as public
19
- from wandb.apis.internal import Api
20
- from wandb.apis.public import Api as PublicApi
21
- from wandb.apis.public import QueuedRun, Run
22
18
  from wandb.errors import CommError
19
+ from wandb.sdk.launch._launch_add import launch_add
23
20
  from wandb.sdk.launch.errors import LaunchError
24
- from wandb.sdk.launch.launch_add import launch_add
25
21
  from wandb.sdk.launch.sweeps import SchedulerError
26
22
  from wandb.sdk.launch.sweeps.utils import (
27
23
  create_sweep_command_args,
28
24
  make_launch_sweep_entrypoint,
29
25
  )
30
26
  from wandb.sdk.lib.runid import generate_id
31
- from wandb.sdk.wandb_run import Run as SdkRun
27
+
28
+ if TYPE_CHECKING:
29
+ import wandb.apis.public as public
30
+ from wandb.apis.internal import Api
31
+ from wandb.apis.public import QueuedRun, Run
32
+ from wandb.sdk.wandb_run import Run as SdkRun
33
+
32
34
 
33
35
  _logger = logging.getLogger(__name__)
34
36
  LOG_PREFIX = f"{click.style('sched:', fg='cyan')} "
@@ -84,7 +86,7 @@ class SweepRun:
84
86
  id: str
85
87
  worker_id: int
86
88
  state: RunState = RunState.RUNNING
87
- queued_run: Optional[public.QueuedRun] = None
89
+ queued_run: Optional["public.QueuedRun"] = None
88
90
  args: Optional[Dict[str, Any]] = None
89
91
  logs: Optional[List[str]] = None
90
92
 
@@ -98,7 +100,7 @@ class Scheduler(ABC):
98
100
 
99
101
  def __init__(
100
102
  self,
101
- api: Api,
103
+ api: "Api",
102
104
  *args: Optional[Any],
103
105
  polling_sleep: Optional[float] = None,
104
106
  sweep_id: Optional[str] = None,
@@ -108,6 +110,8 @@ class Scheduler(ABC):
108
110
  num_workers: Optional[Union[int, str]] = None,
109
111
  **kwargs: Optional[Any],
110
112
  ):
113
+ from wandb.apis.public import Api as PublicApi
114
+
111
115
  self._api = api
112
116
  self._public_api = PublicApi()
113
117
  self._entity = (
@@ -244,7 +248,7 @@ class Scheduler(ABC):
244
248
  _id: w for _id, w in self._workers.items() if _id not in self.busy_workers
245
249
  }
246
250
 
247
- def _init_wandb_run(self) -> SdkRun:
251
+ def _init_wandb_run(self) -> "SdkRun":
248
252
  """Controls resume or init logic for a scheduler wandb run."""
249
253
  _type = self._kwargs.get("sweep_type", "sweep")
250
254
  run: SdkRun = wandb.init(
@@ -1,15 +1,17 @@
1
1
  import json
2
2
  import os
3
3
  import re
4
- from typing import Any, Dict, List, Optional, Tuple, Union
4
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
5
5
 
6
6
  import yaml
7
7
 
8
8
  import wandb
9
9
  from wandb import util
10
- from wandb.apis.public import Api as PublicApi
11
10
  from wandb.sdk.launch.errors import LaunchError
12
11
 
12
+ if TYPE_CHECKING:
13
+ from wandb.apis.public import Api as PublicApi
14
+
13
15
  DEFAULT_SWEEP_COMMAND: List[str] = [
14
16
  "${env}",
15
17
  "${interpreter}",
@@ -276,7 +278,7 @@ def make_launch_sweep_entrypoint(
276
278
  return entry_point, macro_args
277
279
 
278
280
 
279
- def check_job_exists(public_api: PublicApi, job: Optional[str]) -> bool:
281
+ def check_job_exists(public_api: "PublicApi", job: Optional[str]) -> bool:
280
282
  """Check if the job exists using the public api.
281
283
 
282
284
  Returns: True if no job is passed, or if the job exists.
wandb/sdk/launch/utils.py CHANGED
@@ -127,7 +127,9 @@ def set_project_entity_defaults(
127
127
  prefix = ""
128
128
  if platform.system() != "Windows" and sys.stdout.encoding == "UTF-8":
129
129
  prefix = "🚀 "
130
- wandb.termlog(f"{LOG_PREFIX}{prefix}Launching run into {entity}/{project}")
130
+ wandb.termlog(
131
+ f"{LOG_PREFIX}{prefix}Launching run into {entity}{'/' + project if project else ''}"
132
+ )
131
133
  return project, entity
132
134
 
133
135
 
@@ -58,6 +58,7 @@ _Setting = Literal[
58
58
  "_sync",
59
59
  "_os",
60
60
  "_platform",
61
+ "_proxies",
61
62
  "_python",
62
63
  "_runqueue_item_id",
63
64
  "_require_nexus",
@@ -84,6 +85,7 @@ _Setting = Literal[
84
85
  "azure_account_url_to_access_key",
85
86
  "base_url",
86
87
  "code_dir",
88
+ "colab_url",
87
89
  "config_paths",
88
90
  "console",
89
91
  "deployment",
@@ -121,6 +123,7 @@ _Setting = Literal[
121
123
  "notebook_name",
122
124
  "problem",
123
125
  "program",
126
+ "program_abspath",
124
127
  "program_relpath",
125
128
  "project",
126
129
  "project_url",
@@ -209,6 +212,7 @@ SETTINGS_TOPOLOGICALLY_SORTED: Final[Tuple[_Setting, ...]] = (
209
212
  "tmp_dir",
210
213
  "_tmp_code_dir",
211
214
  "_windows",
215
+ "colab_url",
212
216
  "is_local",
213
217
  "deployment",
214
218
  "disable_code",
@@ -220,6 +224,7 @@ SETTINGS_TOPOLOGICALLY_SORTED: Final[Tuple[_Setting, ...]] = (
220
224
  "log_symlink_internal",
221
225
  "log_symlink_user",
222
226
  "log_user",
227
+ "program",
223
228
  "project_url",
224
229
  "resume_fname",
225
230
  "run_url",
@@ -20,6 +20,7 @@ class GraphQLSession(HTTPTransport):
20
20
  auth: Optional[Union[Tuple[str, str], Callable]] = None,
21
21
  use_json: bool = False,
22
22
  timeout: Optional[Union[int, float]] = None,
23
+ proxies: Optional[Dict[str, str]] = None,
23
24
  **kwargs: Any,
24
25
  ) -> None:
25
26
  """Setup a session for sending GraphQL queries and mutations.
@@ -32,6 +33,8 @@ class GraphQLSession(HTTPTransport):
32
33
  """
33
34
  super().__init__(url, **kwargs)
34
35
  self.session = requests.Session()
36
+ if proxies:
37
+ self.session.proxies.update(proxies)
35
38
  self.session.auth = auth
36
39
  self.default_timeout = timeout
37
40
  self.use_json = use_json
wandb/sdk/lib/ipython.py CHANGED
@@ -45,6 +45,10 @@ def _get_python_type() -> PythonType:
45
45
  (get_ipython().config.get("IPKernelApp", {}) or {})
46
46
  .get("connection_file", "")
47
47
  .lower()
48
+ ) or (
49
+ (get_ipython().config.get("ColabKernelApp", {}) or {})
50
+ .get("connection_file", "")
51
+ .lower()
48
52
  )
49
53
 
50
54
  if (