wandb 0.13.10__py3-none-any.whl → 0.14.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (228) hide show
  1. wandb/__init__.py +2 -3
  2. wandb/apis/__init__.py +1 -3
  3. wandb/apis/importers/__init__.py +4 -0
  4. wandb/apis/importers/base.py +312 -0
  5. wandb/apis/importers/mlflow.py +113 -0
  6. wandb/apis/internal.py +29 -2
  7. wandb/apis/normalize.py +6 -5
  8. wandb/apis/public.py +163 -180
  9. wandb/apis/reports/_templates.py +6 -12
  10. wandb/apis/reports/report.py +1 -1
  11. wandb/apis/reports/runset.py +1 -3
  12. wandb/apis/reports/util.py +12 -10
  13. wandb/beta/workflows.py +57 -34
  14. wandb/catboost/__init__.py +1 -2
  15. wandb/cli/cli.py +215 -133
  16. wandb/data_types.py +63 -56
  17. wandb/docker/__init__.py +78 -16
  18. wandb/docker/auth.py +21 -22
  19. wandb/env.py +0 -1
  20. wandb/errors/__init__.py +8 -116
  21. wandb/errors/term.py +1 -1
  22. wandb/fastai/__init__.py +1 -2
  23. wandb/filesync/dir_watcher.py +8 -5
  24. wandb/filesync/step_prepare.py +76 -75
  25. wandb/filesync/step_upload.py +1 -2
  26. wandb/integration/catboost/__init__.py +1 -3
  27. wandb/integration/catboost/catboost.py +8 -14
  28. wandb/integration/fastai/__init__.py +7 -13
  29. wandb/integration/gym/__init__.py +35 -4
  30. wandb/integration/keras/__init__.py +3 -3
  31. wandb/integration/keras/callbacks/metrics_logger.py +9 -8
  32. wandb/integration/keras/callbacks/model_checkpoint.py +9 -9
  33. wandb/integration/keras/callbacks/tables_builder.py +31 -19
  34. wandb/integration/kfp/kfp_patch.py +20 -17
  35. wandb/integration/kfp/wandb_logging.py +1 -2
  36. wandb/integration/lightgbm/__init__.py +21 -19
  37. wandb/integration/prodigy/prodigy.py +6 -7
  38. wandb/integration/sacred/__init__.py +9 -12
  39. wandb/integration/sagemaker/__init__.py +1 -3
  40. wandb/integration/sagemaker/auth.py +0 -1
  41. wandb/integration/sagemaker/config.py +1 -1
  42. wandb/integration/sagemaker/resources.py +1 -1
  43. wandb/integration/sb3/sb3.py +8 -4
  44. wandb/integration/tensorboard/__init__.py +1 -3
  45. wandb/integration/tensorboard/log.py +8 -8
  46. wandb/integration/tensorboard/monkeypatch.py +11 -9
  47. wandb/integration/tensorflow/__init__.py +1 -3
  48. wandb/integration/xgboost/__init__.py +4 -6
  49. wandb/integration/yolov8/__init__.py +7 -0
  50. wandb/integration/yolov8/yolov8.py +250 -0
  51. wandb/jupyter.py +31 -35
  52. wandb/lightgbm/__init__.py +1 -2
  53. wandb/old/settings.py +2 -2
  54. wandb/plot/bar.py +1 -2
  55. wandb/plot/confusion_matrix.py +1 -3
  56. wandb/plot/histogram.py +1 -2
  57. wandb/plot/line.py +1 -2
  58. wandb/plot/line_series.py +4 -4
  59. wandb/plot/pr_curve.py +17 -20
  60. wandb/plot/roc_curve.py +1 -3
  61. wandb/plot/scatter.py +1 -2
  62. wandb/proto/v3/wandb_server_pb2.py +85 -39
  63. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  64. wandb/proto/v4/wandb_server_pb2.py +51 -39
  65. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  66. wandb/sdk/__init__.py +1 -3
  67. wandb/sdk/backend/backend.py +1 -1
  68. wandb/sdk/data_types/_dtypes.py +38 -30
  69. wandb/sdk/data_types/base_types/json_metadata.py +1 -3
  70. wandb/sdk/data_types/base_types/media.py +17 -17
  71. wandb/sdk/data_types/base_types/wb_value.py +33 -26
  72. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +91 -125
  73. wandb/sdk/data_types/helper_types/classes.py +1 -1
  74. wandb/sdk/data_types/helper_types/image_mask.py +12 -12
  75. wandb/sdk/data_types/histogram.py +5 -4
  76. wandb/sdk/data_types/html.py +1 -2
  77. wandb/sdk/data_types/image.py +11 -11
  78. wandb/sdk/data_types/molecule.py +3 -6
  79. wandb/sdk/data_types/object_3d.py +1 -2
  80. wandb/sdk/data_types/plotly.py +1 -2
  81. wandb/sdk/data_types/saved_model.py +10 -8
  82. wandb/sdk/data_types/video.py +1 -1
  83. wandb/sdk/integration_utils/data_logging.py +5 -5
  84. wandb/sdk/interface/artifacts.py +288 -266
  85. wandb/sdk/interface/interface.py +2 -3
  86. wandb/sdk/interface/interface_grpc.py +1 -1
  87. wandb/sdk/interface/interface_queue.py +1 -1
  88. wandb/sdk/interface/interface_relay.py +1 -1
  89. wandb/sdk/interface/interface_shared.py +1 -2
  90. wandb/sdk/interface/interface_sock.py +1 -1
  91. wandb/sdk/interface/message_future.py +1 -1
  92. wandb/sdk/interface/message_future_poll.py +1 -1
  93. wandb/sdk/interface/router.py +1 -1
  94. wandb/sdk/interface/router_queue.py +1 -1
  95. wandb/sdk/interface/router_relay.py +1 -1
  96. wandb/sdk/interface/router_sock.py +1 -1
  97. wandb/sdk/interface/summary_record.py +1 -1
  98. wandb/sdk/internal/artifacts.py +1 -1
  99. wandb/sdk/internal/datastore.py +2 -3
  100. wandb/sdk/internal/file_pusher.py +5 -3
  101. wandb/sdk/internal/file_stream.py +22 -19
  102. wandb/sdk/internal/handler.py +5 -4
  103. wandb/sdk/internal/internal.py +1 -1
  104. wandb/sdk/internal/internal_api.py +115 -55
  105. wandb/sdk/internal/job_builder.py +1 -3
  106. wandb/sdk/internal/profiler.py +1 -1
  107. wandb/sdk/internal/progress.py +4 -6
  108. wandb/sdk/internal/sample.py +1 -3
  109. wandb/sdk/internal/sender.py +28 -16
  110. wandb/sdk/internal/settings_static.py +5 -5
  111. wandb/sdk/internal/system/assets/__init__.py +1 -0
  112. wandb/sdk/internal/system/assets/cpu.py +3 -9
  113. wandb/sdk/internal/system/assets/disk.py +2 -4
  114. wandb/sdk/internal/system/assets/gpu.py +6 -18
  115. wandb/sdk/internal/system/assets/gpu_apple.py +2 -4
  116. wandb/sdk/internal/system/assets/interfaces.py +50 -22
  117. wandb/sdk/internal/system/assets/ipu.py +1 -3
  118. wandb/sdk/internal/system/assets/memory.py +7 -13
  119. wandb/sdk/internal/system/assets/network.py +4 -8
  120. wandb/sdk/internal/system/assets/open_metrics.py +283 -0
  121. wandb/sdk/internal/system/assets/tpu.py +1 -4
  122. wandb/sdk/internal/system/assets/trainium.py +26 -14
  123. wandb/sdk/internal/system/system_info.py +2 -3
  124. wandb/sdk/internal/system/system_monitor.py +52 -20
  125. wandb/sdk/internal/tb_watcher.py +12 -13
  126. wandb/sdk/launch/_project_spec.py +54 -65
  127. wandb/sdk/launch/agent/agent.py +374 -90
  128. wandb/sdk/launch/builder/abstract.py +61 -7
  129. wandb/sdk/launch/builder/build.py +81 -110
  130. wandb/sdk/launch/builder/docker_builder.py +181 -0
  131. wandb/sdk/launch/builder/kaniko_builder.py +419 -0
  132. wandb/sdk/launch/builder/noop.py +31 -12
  133. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +70 -20
  134. wandb/sdk/launch/environment/abstract.py +28 -0
  135. wandb/sdk/launch/environment/aws_environment.py +276 -0
  136. wandb/sdk/launch/environment/gcp_environment.py +271 -0
  137. wandb/sdk/launch/environment/local_environment.py +65 -0
  138. wandb/sdk/launch/github_reference.py +3 -8
  139. wandb/sdk/launch/launch.py +38 -29
  140. wandb/sdk/launch/launch_add.py +6 -8
  141. wandb/sdk/launch/loader.py +230 -0
  142. wandb/sdk/launch/registry/abstract.py +54 -0
  143. wandb/sdk/launch/registry/elastic_container_registry.py +163 -0
  144. wandb/sdk/launch/registry/google_artifact_registry.py +203 -0
  145. wandb/sdk/launch/registry/local_registry.py +62 -0
  146. wandb/sdk/launch/runner/abstract.py +1 -16
  147. wandb/sdk/launch/runner/{kubernetes.py → kubernetes_runner.py} +83 -95
  148. wandb/sdk/launch/runner/local_container.py +46 -22
  149. wandb/sdk/launch/runner/local_process.py +1 -4
  150. wandb/sdk/launch/runner/{aws.py → sagemaker_runner.py} +53 -212
  151. wandb/sdk/launch/runner/{gcp_vertex.py → vertex_runner.py} +38 -55
  152. wandb/sdk/launch/sweeps/__init__.py +3 -2
  153. wandb/sdk/launch/sweeps/scheduler.py +132 -39
  154. wandb/sdk/launch/sweeps/scheduler_sweep.py +80 -89
  155. wandb/sdk/launch/utils.py +101 -30
  156. wandb/sdk/launch/wandb_reference.py +2 -7
  157. wandb/sdk/lib/_settings_toposort_generate.py +166 -0
  158. wandb/sdk/lib/_settings_toposort_generated.py +201 -0
  159. wandb/sdk/lib/apikey.py +2 -4
  160. wandb/sdk/lib/config_util.py +4 -1
  161. wandb/sdk/lib/console.py +1 -3
  162. wandb/sdk/lib/deprecate.py +3 -3
  163. wandb/sdk/lib/file_stream_utils.py +7 -5
  164. wandb/sdk/lib/filenames.py +1 -1
  165. wandb/sdk/lib/filesystem.py +61 -5
  166. wandb/sdk/lib/git.py +1 -3
  167. wandb/sdk/lib/import_hooks.py +4 -7
  168. wandb/sdk/lib/ipython.py +8 -5
  169. wandb/sdk/lib/lazyloader.py +1 -3
  170. wandb/sdk/lib/mailbox.py +14 -4
  171. wandb/sdk/lib/proto_util.py +10 -5
  172. wandb/sdk/lib/redirect.py +15 -22
  173. wandb/sdk/lib/reporting.py +1 -3
  174. wandb/sdk/lib/retry.py +4 -5
  175. wandb/sdk/lib/runid.py +1 -3
  176. wandb/sdk/lib/server.py +15 -9
  177. wandb/sdk/lib/sock_client.py +1 -1
  178. wandb/sdk/lib/sparkline.py +1 -1
  179. wandb/sdk/lib/wburls.py +1 -1
  180. wandb/sdk/service/port_file.py +1 -2
  181. wandb/sdk/service/service.py +36 -13
  182. wandb/sdk/service/service_base.py +12 -1
  183. wandb/sdk/verify/verify.py +5 -7
  184. wandb/sdk/wandb_artifacts.py +142 -177
  185. wandb/sdk/wandb_config.py +5 -8
  186. wandb/sdk/wandb_helper.py +1 -1
  187. wandb/sdk/wandb_init.py +24 -13
  188. wandb/sdk/wandb_login.py +9 -9
  189. wandb/sdk/wandb_manager.py +39 -4
  190. wandb/sdk/wandb_metric.py +2 -6
  191. wandb/sdk/wandb_require.py +4 -15
  192. wandb/sdk/wandb_require_helpers.py +1 -9
  193. wandb/sdk/wandb_run.py +95 -141
  194. wandb/sdk/wandb_save.py +1 -3
  195. wandb/sdk/wandb_settings.py +149 -54
  196. wandb/sdk/wandb_setup.py +66 -46
  197. wandb/sdk/wandb_summary.py +13 -10
  198. wandb/sdk/wandb_sweep.py +6 -7
  199. wandb/sdk/wandb_watch.py +1 -1
  200. wandb/sklearn/calculate/confusion_matrix.py +1 -1
  201. wandb/sklearn/calculate/learning_curve.py +1 -1
  202. wandb/sklearn/calculate/summary_metrics.py +1 -3
  203. wandb/sklearn/plot/__init__.py +1 -1
  204. wandb/sklearn/plot/classifier.py +27 -18
  205. wandb/sklearn/plot/clusterer.py +4 -5
  206. wandb/sklearn/plot/regressor.py +4 -4
  207. wandb/sklearn/plot/shared.py +2 -2
  208. wandb/sync/__init__.py +1 -3
  209. wandb/sync/sync.py +4 -5
  210. wandb/testing/relay.py +11 -10
  211. wandb/trigger.py +1 -1
  212. wandb/util.py +106 -81
  213. wandb/viz.py +4 -4
  214. wandb/wandb_agent.py +50 -50
  215. wandb/wandb_controller.py +2 -3
  216. wandb/wandb_run.py +1 -2
  217. wandb/wandb_torch.py +1 -1
  218. wandb/xgboost/__init__.py +1 -2
  219. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/METADATA +6 -2
  220. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/RECORD +224 -209
  221. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/WHEEL +1 -1
  222. wandb/sdk/launch/builder/docker.py +0 -80
  223. wandb/sdk/launch/builder/kaniko.py +0 -393
  224. wandb/sdk/launch/builder/loader.py +0 -32
  225. wandb/sdk/launch/runner/loader.py +0 -50
  226. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/LICENSE +0 -0
  227. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/entry_points.txt +0 -0
  228. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
1
1
  import datetime
2
- import os
2
+ import logging
3
3
  import shlex
4
4
  import time
5
5
  from typing import Any, Dict, Optional
@@ -10,17 +10,20 @@ if False:
10
10
  import yaml
11
11
 
12
12
  import wandb
13
- from wandb.errors import LaunchError
13
+ from wandb.apis.internal import Api
14
14
  from wandb.util import get_module
15
15
 
16
16
  from .._project_spec import LaunchProject, get_entry_point_command
17
17
  from ..builder.abstract import AbstractBuilder
18
- from ..builder.build import construct_gcp_registry_uri, get_env_vars_dict
19
- from ..utils import LOG_PREFIX, PROJECT_SYNCHRONOUS, run_shell
18
+ from ..builder.build import get_env_vars_dict
19
+ from ..environment.gcp_environment import GcpEnvironment
20
+ from ..utils import LOG_PREFIX, PROJECT_SYNCHRONOUS, LaunchError, run_shell
20
21
  from .abstract import AbstractRun, AbstractRunner, Status
21
22
 
22
23
  GCP_CONSOLE_URI = "https://console.cloud.google.com"
23
24
 
25
+ _logger = logging.getLogger(__name__)
26
+
24
27
 
25
28
  class VertexSubmittedRun(AbstractRun):
26
29
  def __init__(self, job: Any) -> None:
@@ -57,12 +60,14 @@ class VertexSubmittedRun(AbstractRun):
57
60
 
58
61
  def get_status(self) -> Status:
59
62
  job_state = str(self._job.state) # extract from type PipelineState
60
- if job_state == "PipelineState.PIPELINE_STATE_SUCCEEDED":
63
+ if job_state == "JobState.JOB_STATE_SUCCEEDED":
61
64
  return Status("finished")
62
- if job_state == "PipelineState.PIPELINE_STATE_FAILED":
65
+ if job_state == "JobState.JOB_STATE_FAILED":
63
66
  return Status("failed")
64
- if job_state == "PipelineState.PIPELINE_STATE_RUNNING":
67
+ if job_state == "JobState.JOB_STATE_RUNNING":
65
68
  return Status("running")
69
+ if job_state == "JobState.JOB_STATE_PENDING":
70
+ return Status("starting")
66
71
  return Status("unknown")
67
72
 
68
73
  def cancel(self) -> None:
@@ -70,47 +75,37 @@ class VertexSubmittedRun(AbstractRun):
70
75
 
71
76
 
72
77
  class VertexRunner(AbstractRunner):
73
- """Runner class, uses a project to create a VertexSubmittedRun"""
78
+ """Runner class, uses a project to create a VertexSubmittedRun."""
79
+
80
+ def __init__(
81
+ self, api: Api, backend_config: Dict[str, Any], environment: GcpEnvironment
82
+ ) -> None:
83
+ """Initialize a VertexRunner instance."""
84
+ super().__init__(api, backend_config)
85
+ self.environment = environment
74
86
 
75
87
  def run(
76
88
  self,
77
89
  launch_project: LaunchProject,
78
- builder: AbstractBuilder,
79
- registry_config: Dict[str, Any],
90
+ builder: Optional[AbstractBuilder],
80
91
  ) -> Optional[AbstractRun]:
81
-
92
+ """Run a Vertex job."""
82
93
  aiplatform = get_module( # noqa: F811
83
94
  "google.cloud.aiplatform",
84
95
  "VertexRunner requires google.cloud.aiplatform to be installed",
85
96
  )
86
-
87
- resource_args = launch_project.resource_args.get("gcp_vertex")
97
+ resource_args = launch_project.resource_args.get("vertex")
98
+ if not resource_args:
99
+ resource_args = launch_project.resource_args.get("gcp-vertex")
88
100
  if not resource_args:
89
101
  raise LaunchError(
90
102
  "No Vertex resource args specified. Specify args via --resource-args with a JSON file or string under top-level key gcp_vertex"
91
103
  )
92
- gcp_config = get_gcp_config(resource_args.get("gcp_config") or "default")
93
- gcp_project = (
94
- resource_args.get("gcp_project")
95
- or gcp_config["properties"]["core"]["project"]
96
- )
97
- gcp_zone = resource_args.get("gcp_region") or gcp_config["properties"].get(
98
- "compute", {}
99
- ).get("zone")
100
- gcp_region = "-".join(gcp_zone.split("-")[:2])
101
104
  gcp_staging_bucket = resource_args.get("staging_bucket")
102
105
  if not gcp_staging_bucket:
103
106
  raise LaunchError(
104
107
  "Vertex requires a staging bucket for training and dependency packages in the same region as compute. Specify a bucket under key staging_bucket."
105
108
  )
106
- gcp_artifact_repo = resource_args.get("artifact_repo")
107
- if not gcp_artifact_repo:
108
- raise LaunchError(
109
- "Vertex requires an Artifact Registry repository for the Docker image. Specify a repo under key artifact_repo."
110
- )
111
- gcp_docker_host = (
112
- resource_args.get("docker_host") or f"{gcp_region}-docker.pkg.dev"
113
- )
114
109
  gcp_machine_type = resource_args.get("machine_type") or "n1-standard-4"
115
110
  gcp_accelerator_type = (
116
111
  resource_args.get("accelerator_type") or "ACCELERATOR_TYPE_UNSPECIFIED"
@@ -124,9 +119,10 @@ class VertexRunner(AbstractRunner):
124
119
  )
125
120
  service_account = resource_args.get("service_account")
126
121
  tensorboard = resource_args.get("tensorboard")
127
-
128
122
  aiplatform.init(
129
- project=gcp_project, location=gcp_region, staging_bucket=gcp_staging_bucket
123
+ project=self.environment.project,
124
+ location=self.environment.region,
125
+ staging_bucket=gcp_staging_bucket,
130
126
  )
131
127
  synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
132
128
 
@@ -135,21 +131,13 @@ class VertexRunner(AbstractRunner):
135
131
  if launch_project.docker_image:
136
132
  image_uri = launch_project.docker_image
137
133
  else:
138
-
139
- repository = construct_gcp_registry_uri(
140
- gcp_artifact_repo,
141
- gcp_project,
142
- gcp_docker_host,
143
- )
144
134
  assert entry_point is not None
135
+ assert builder is not None
145
136
  image_uri = builder.build_image(
146
137
  launch_project,
147
- repository,
148
138
  entry_point,
149
139
  )
150
140
 
151
- if not self.ack_run_queue_item(launch_project):
152
- return None
153
141
  # TODO: how to handle this?
154
142
  entry_cmd = get_entry_point_command(entry_point, launch_project.override_args)
155
143
 
@@ -176,18 +164,19 @@ class VertexRunner(AbstractRunner):
176
164
  display_name=gcp_training_job_name, worker_pool_specs=worker_pool_specs
177
165
  )
178
166
 
179
- submitted_run = VertexSubmittedRun(job)
180
-
181
- # todo: support gcp dataset?
182
-
183
167
  wandb.termlog(
184
168
  f"{LOG_PREFIX}Running training job {gcp_training_job_name} on {gcp_machine_type}."
185
169
  )
186
170
 
187
- # when sync is True, vertex blocks the main thread on job completion. when False, vertex returns a Future
188
- # on this thread but continues to block the process on another thread. always set sync=False so we can get
189
- # the job info (dependent on job._gca_resource)
190
- job.run(service_account=service_account, tensorboard=tensorboard, sync=False)
171
+ if synchronous:
172
+ job.run(service_account=service_account, tensorboard=tensorboard, sync=True)
173
+ else:
174
+ job.submit(
175
+ service_account=service_account,
176
+ tensorboard=tensorboard,
177
+ )
178
+
179
+ submitted_run = VertexSubmittedRun(job)
191
180
 
192
181
  while not getattr(job._gca_resource, "name", None):
193
182
  # give time for the gcp job object to be created and named, this should only loop a couple times max
@@ -196,12 +185,6 @@ class VertexRunner(AbstractRunner):
196
185
  wandb.termlog(
197
186
  f"{LOG_PREFIX}View your job status and logs at {submitted_run.get_page_link()}."
198
187
  )
199
-
200
- # hacky: if user doesn't want blocking behavior, kill both main thread and the background thread. job continues
201
- # to run remotely. this obviously doesn't work if we need to do some sort of postprocessing after this run fn
202
- if not synchronous:
203
- os._exit(0)
204
-
205
188
  return submitted_run
206
189
 
207
190
 
@@ -3,9 +3,11 @@ from typing import Any, Callable, Dict
3
3
 
4
4
  log = logging.getLogger(__name__)
5
5
 
6
+ SCHEDULER_URI = "placeholder-uri-scheduler"
7
+
6
8
 
7
9
  class SchedulerError(Exception):
8
- """Raised when a known error occurs with wandb sweep scheduler"""
10
+ """Raised when a known error occurs with wandb sweep scheduler."""
9
11
 
10
12
  pass
11
13
 
@@ -22,7 +24,6 @@ _WANDB_SCHEDULERS: Dict[str, Callable] = {
22
24
 
23
25
 
24
26
  def load_scheduler(scheduler_name: str) -> Any:
25
-
26
27
  scheduler_name = scheduler_name.lower()
27
28
  if scheduler_name not in _WANDB_SCHEDULERS:
28
29
  raise SchedulerError(
@@ -2,12 +2,14 @@
2
2
  import logging
3
3
  import os
4
4
  import threading
5
+ import traceback
5
6
  from abc import ABC, abstractmethod
6
7
  from dataclasses import dataclass
7
8
  from enum import Enum
8
9
  from typing import Any, Dict, Iterator, List, Optional, Tuple
9
10
 
10
11
  import click
12
+ import yaml
11
13
 
12
14
  import wandb
13
15
  import wandb.apis.public as public
@@ -16,21 +18,30 @@ from wandb.errors import CommError
16
18
  from wandb.sdk.launch.launch_add import launch_add
17
19
  from wandb.sdk.launch.sweeps import SchedulerError
18
20
  from wandb.sdk.lib.runid import generate_id
21
+ from wandb.wandb_agent import Agent
19
22
 
20
- logger = logging.getLogger(__name__)
23
+ _logger = logging.getLogger(__name__)
21
24
  LOG_PREFIX = f"{click.style('sched:', fg='cyan')} "
22
25
 
23
26
 
27
+ @dataclass
28
+ class _Worker:
29
+ agent_config: Dict[str, Any]
30
+ agent_id: str
31
+
32
+
24
33
  class SchedulerState(Enum):
25
34
  PENDING = 0
26
35
  STARTING = 1
27
36
  RUNNING = 2
28
- COMPLETED = 3
29
- FAILED = 4
30
- STOPPED = 5
37
+ FLUSH_RUNS = 3
38
+ COMPLETED = 4
39
+ FAILED = 5
40
+ STOPPED = 6
41
+ CANCELLED = 7
31
42
 
32
43
 
33
- class SimpleRunState(Enum):
44
+ class RunState(Enum):
34
45
  ALIVE = 0
35
46
  DEAD = 1
36
47
  UNKNOWN = 2
@@ -39,19 +50,16 @@ class SimpleRunState(Enum):
39
50
  @dataclass
40
51
  class SweepRun:
41
52
  id: str
42
- state: SimpleRunState = SimpleRunState.ALIVE
53
+ state: RunState = RunState.ALIVE
43
54
  queued_run: Optional[public.QueuedRun] = None
44
55
  args: Optional[Dict[str, Any]] = None
45
56
  logs: Optional[List[str]] = None
46
- program: Optional[str] = None
47
57
  # Threading can be used to run multiple workers in parallel
48
58
  worker_id: Optional[int] = None
49
59
 
50
60
 
51
61
  class Scheduler(ABC):
52
- """The Scheduler is a controller/agent that will populate a Launch RunQueue
53
- with jobs from a hyperparameter sweep.
54
- """
62
+ """A controller/agent that populates a Launch RunQueue from a hyperparameter sweep."""
55
63
 
56
64
  def __init__(
57
65
  self,
@@ -73,18 +81,31 @@ class Scheduler(ABC):
73
81
  self._project = (
74
82
  project or os.environ.get("WANDB_PROJECT") or api.settings("project")
75
83
  )
84
+ self._sweep_id: str = sweep_id or "empty-sweep-id"
85
+ self._state: SchedulerState = SchedulerState.PENDING
86
+
76
87
  # Make sure the provided sweep_id corresponds to a valid sweep
77
88
  try:
78
- self._api.sweep(sweep_id, "{}", entity=self._entity, project=self._project)
89
+ resp = self._api.sweep(
90
+ sweep_id, "{}", entity=self._entity, project=self._project
91
+ )
92
+ if resp.get("state") == SchedulerState.CANCELLED.name:
93
+ self._state = SchedulerState.CANCELLED
94
+ self._sweep_config = yaml.safe_load(resp["config"])
79
95
  except Exception as e:
80
96
  raise SchedulerError(f"{LOG_PREFIX}Exception when finding sweep: {e}")
81
- self._sweep_id: str = sweep_id or "empty-sweep-id"
82
- self._state: SchedulerState = SchedulerState.PENDING
97
+
83
98
  # Dictionary of the runs being managed by the scheduler
84
99
  self._runs: Dict[str, SweepRun] = {}
85
100
  # Threading lock to ensure thread-safe access to the runs dictionary
86
101
  self._threading_lock: threading.Lock = threading.Lock()
87
- self._project_queue = project_queue or self._project
102
+ self._project_queue = project_queue
103
+ # Optionally run multiple workers in (pseudo-)parallel. Workers do not
104
+ # actually run training workloads, they simply send heartbeat messages
105
+ # (emulating a real agent) and add new runs to the launch queue. The
106
+ # launch agent is the one that actually runs the training workloads.
107
+ self._workers: Dict[int, _Worker] = {}
108
+
88
109
  # Scheduler may receive additional kwargs which will be piped into the launch command
89
110
  self._kwargs: Dict[str, Any] = kwargs
90
111
 
@@ -102,12 +123,12 @@ class Scheduler(ABC):
102
123
 
103
124
  @property
104
125
  def state(self) -> SchedulerState:
105
- logger.debug(f"{LOG_PREFIX}Scheduler state is {self._state.name}")
126
+ _logger.debug(f"{LOG_PREFIX}Scheduler state is {self._state.name}")
106
127
  return self._state
107
128
 
108
129
  @state.setter
109
130
  def state(self, value: SchedulerState) -> None:
110
- logger.debug(f"{LOG_PREFIX}Scheduler was {self.state.name} is {value.name}")
131
+ _logger.debug(f"{LOG_PREFIX}Scheduler was {self.state.name} is {value.name}")
111
132
  self._state = value
112
133
 
113
134
  def is_alive(self) -> bool:
@@ -115,17 +136,33 @@ class Scheduler(ABC):
115
136
  SchedulerState.COMPLETED,
116
137
  SchedulerState.FAILED,
117
138
  SchedulerState.STOPPED,
139
+ SchedulerState.CANCELLED,
118
140
  ]:
119
141
  return False
120
142
  return True
121
143
 
122
144
  def start(self) -> None:
145
+ """Start a scheduler, confirms prerequisites, begins execution loop."""
123
146
  wandb.termlog(f"{LOG_PREFIX}Scheduler starting.")
147
+ if not self.is_alive():
148
+ wandb.termerror(
149
+ f"{LOG_PREFIX}Sweep already {self.state.name.lower()}! Exiting..."
150
+ )
151
+ self.exit()
152
+ return
153
+
124
154
  self._state = SchedulerState.STARTING
155
+ if not self._try_load_executable():
156
+ wandb.termerror(
157
+ f"{LOG_PREFIX}No 'job' or 'image_uri' loaded from sweep config."
158
+ )
159
+ self.exit()
160
+ return
125
161
  self._start()
126
162
  self.run()
127
163
 
128
164
  def run(self) -> None:
165
+ """Main run function for all external schedulers."""
129
166
  wandb.termlog(f"{LOG_PREFIX}Scheduler Running.")
130
167
  self.state = SchedulerState.RUNNING
131
168
  try:
@@ -134,6 +171,11 @@ class Scheduler(ABC):
134
171
  break
135
172
  self._update_run_states()
136
173
  self._run()
174
+ # if we hit the run_cap, now set to stopped after launching runs
175
+ if self.state == SchedulerState.FLUSH_RUNS:
176
+ if len(self._runs.keys()) == 0:
177
+ wandb.termlog(f"{LOG_PREFIX}Done polling on runs, exiting.")
178
+ self.state = SchedulerState.STOPPED
137
179
  except KeyboardInterrupt:
138
180
  wandb.termlog(f"{LOG_PREFIX}Scheduler received KeyboardInterrupt. Exiting.")
139
181
  self.state = SchedulerState.STOPPED
@@ -157,6 +199,28 @@ class Scheduler(ABC):
157
199
  self.state = SchedulerState.FAILED
158
200
  self._stop_runs()
159
201
 
202
+ def _try_load_executable(self) -> bool:
203
+ """Check existence of valid executable for a run.
204
+
205
+ Logs and returns False when the job is unreachable.
206
+ """
207
+ if self._kwargs.get("job"):
208
+ _public_api = public.Api()
209
+ try:
210
+ _job_artifact = _public_api.artifact(self._kwargs["job"], type="job")
211
+ wandb.termlog(
212
+ f"{LOG_PREFIX}Successfully loaded job: {_job_artifact.name} in scheduler"
213
+ )
214
+ except Exception:
215
+ wandb.termerror(f"{LOG_PREFIX}{traceback.format_exc()}")
216
+ return False
217
+ return True
218
+ elif self._kwargs.get("image_uri"):
219
+ # TODO(gst): check docker existence? Use registry in launch config?
220
+ return True
221
+ else:
222
+ return False
223
+
160
224
  def _yield_runs(self) -> Iterator[Tuple[str, SweepRun]]:
161
225
  """Thread-safe way to iterate over the runs."""
162
226
  with self._threading_lock:
@@ -168,25 +232,38 @@ class Scheduler(ABC):
168
232
  self._stop_run(run_id)
169
233
 
170
234
  def _stop_run(self, run_id: str) -> None:
171
- """Stops a run and removes it from the scheduler"""
235
+ """Stop a run and remove it from the scheduler."""
172
236
  if run_id in self._runs:
173
237
  run: SweepRun = self._runs[run_id]
174
- run.state = SimpleRunState.DEAD
238
+ run.state = RunState.DEAD
175
239
  # TODO(hupo): Send command to backend to stop run
176
240
  wandb.termlog(f"{LOG_PREFIX} Stopped run {run_id}.")
177
241
 
178
242
  def _update_run_states(self) -> None:
243
+ """Iterate through runs.
244
+
245
+ Get state from backend and delete runs if not in running state. Thread-safe.
246
+ """
179
247
  _runs_to_remove: List[str] = []
180
248
  for run_id, run in self._yield_runs():
181
249
  try:
182
250
  _state = self._api.get_run_state(self._entity, self._project, run_id)
183
- if _state is None or _state in [
184
- "crashed",
185
- "failed",
186
- "killed",
187
- "finished",
188
- ]:
189
- run.state = SimpleRunState.DEAD
251
+ _rqi_state = run.queued_run.state if run.queued_run else None
252
+ if (
253
+ not _state
254
+ or _state
255
+ in [
256
+ "crashed",
257
+ "failed",
258
+ "killed",
259
+ "finished",
260
+ ]
261
+ or _rqi_state == "failed"
262
+ ):
263
+ _logger.debug(
264
+ f"({run_id}) run-state:{_state}, rqi-state:{_rqi_state}"
265
+ )
266
+ run.state = RunState.DEAD
190
267
  _runs_to_remove.append(run_id)
191
268
  elif _state in [
192
269
  "running",
@@ -194,12 +271,12 @@ class Scheduler(ABC):
194
271
  "preempted",
195
272
  "preempting",
196
273
  ]:
197
- run.state = SimpleRunState.ALIVE
274
+ run.state = RunState.ALIVE
198
275
  except CommError as e:
199
276
  wandb.termlog(
200
277
  f"{LOG_PREFIX}Issue when getting RunState for Run {run_id}: {e}"
201
278
  )
202
- run.state = SimpleRunState.UNKNOWN
279
+ run.state = RunState.UNKNOWN
203
280
  continue
204
281
  # Remove any runs that are dead
205
282
  with self._threading_lock:
@@ -213,31 +290,47 @@ class Scheduler(ABC):
213
290
  entry_point: Optional[List[str]] = None,
214
291
  config: Optional[Dict[str, Any]] = None,
215
292
  ) -> "public.QueuedRun":
216
- """Add a launch job to the Launch RunQueue."""
293
+ """Add a launch job to the Launch RunQueue.
294
+
295
+ run_id: supplied by gorilla from agentHeartbeat
296
+ entry_point: sweep entrypoint overrides image_uri/job entrypoint
297
+ config: launch config
298
+ """
299
+ # job and image first from CLI args, then from sweep config
300
+ _job = self._kwargs.get("job") or self._sweep_config.get("job")
301
+
302
+ _sweep_config_uri = self._sweep_config.get("image_uri")
303
+ _image_uri = self._kwargs.get("image_uri") or _sweep_config_uri
304
+ if _job is None and _image_uri is None:
305
+ raise SchedulerError(
306
+ f"{LOG_PREFIX}No 'job' nor 'image_uri' (run: {run_id})"
307
+ )
308
+ elif _job is not None and _image_uri is not None:
309
+ raise SchedulerError(f"{LOG_PREFIX}Sweep has both 'job' and 'image_uri'")
310
+
311
+ if self._sweep_config.get("command"):
312
+ entry_point = Agent._create_sweep_command(self._sweep_config["command"])
313
+ wandb.termwarn(
314
+ f"{LOG_PREFIX}Sweep command {entry_point} will override"
315
+ f' {"job" if _job else "image_uri"} entrypoint'
316
+ )
317
+
217
318
  run_id = run_id or generate_id()
218
- # One of Job and URI is required
219
- _job = self._kwargs.get("job", None)
220
- _uri = self._kwargs.get("uri", None)
221
- if _job is None and _uri is None:
222
- # If no Job is specified, use a placeholder URI to prevent Launch failure
223
- _uri = "placeholder-uri-queuedrun-from-scheduler"
224
- # Queue is required
225
- _queue = self._kwargs.get("queue", "default")
226
319
  queued_run = launch_add(
227
320
  run_id=run_id,
228
321
  entry_point=entry_point,
229
322
  config=config,
230
- uri=_uri,
323
+ docker_image=_image_uri, # TODO(gst): make agnostic (github? run uri?)
231
324
  job=_job,
232
325
  project=self._project,
233
326
  entity=self._entity,
234
- queue_name=_queue,
327
+ queue_name=self._kwargs.get("queue"),
235
328
  project_queue=self._project_queue,
236
329
  resource=self._kwargs.get("resource", None),
237
330
  resource_args=self._kwargs.get("resource_args", None),
238
331
  )
239
332
  self._runs[run_id].queued_run = queued_run
240
333
  wandb.termlog(
241
- f"{LOG_PREFIX}Added run to Launch RunQueue: {_queue} RunID:{run_id}."
334
+ f"{LOG_PREFIX}Added run to Launch queue: {self._kwargs.get('queue')} RunID:{run_id}."
242
335
  )
243
336
  return queued_run