wandb 0.13.10__py3-none-any.whl → 0.14.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +2 -3
- wandb/apis/__init__.py +1 -3
- wandb/apis/importers/__init__.py +4 -0
- wandb/apis/importers/base.py +312 -0
- wandb/apis/importers/mlflow.py +113 -0
- wandb/apis/internal.py +29 -2
- wandb/apis/normalize.py +6 -5
- wandb/apis/public.py +163 -180
- wandb/apis/reports/_templates.py +6 -12
- wandb/apis/reports/report.py +1 -1
- wandb/apis/reports/runset.py +1 -3
- wandb/apis/reports/util.py +12 -10
- wandb/beta/workflows.py +57 -34
- wandb/catboost/__init__.py +1 -2
- wandb/cli/cli.py +215 -133
- wandb/data_types.py +63 -56
- wandb/docker/__init__.py +78 -16
- wandb/docker/auth.py +21 -22
- wandb/env.py +0 -1
- wandb/errors/__init__.py +8 -116
- wandb/errors/term.py +1 -1
- wandb/fastai/__init__.py +1 -2
- wandb/filesync/dir_watcher.py +8 -5
- wandb/filesync/step_prepare.py +76 -75
- wandb/filesync/step_upload.py +1 -2
- wandb/integration/catboost/__init__.py +1 -3
- wandb/integration/catboost/catboost.py +8 -14
- wandb/integration/fastai/__init__.py +7 -13
- wandb/integration/gym/__init__.py +35 -4
- wandb/integration/keras/__init__.py +3 -3
- wandb/integration/keras/callbacks/metrics_logger.py +9 -8
- wandb/integration/keras/callbacks/model_checkpoint.py +9 -9
- wandb/integration/keras/callbacks/tables_builder.py +31 -19
- wandb/integration/kfp/kfp_patch.py +20 -17
- wandb/integration/kfp/wandb_logging.py +1 -2
- wandb/integration/lightgbm/__init__.py +21 -19
- wandb/integration/prodigy/prodigy.py +6 -7
- wandb/integration/sacred/__init__.py +9 -12
- wandb/integration/sagemaker/__init__.py +1 -3
- wandb/integration/sagemaker/auth.py +0 -1
- wandb/integration/sagemaker/config.py +1 -1
- wandb/integration/sagemaker/resources.py +1 -1
- wandb/integration/sb3/sb3.py +8 -4
- wandb/integration/tensorboard/__init__.py +1 -3
- wandb/integration/tensorboard/log.py +8 -8
- wandb/integration/tensorboard/monkeypatch.py +11 -9
- wandb/integration/tensorflow/__init__.py +1 -3
- wandb/integration/xgboost/__init__.py +4 -6
- wandb/integration/yolov8/__init__.py +7 -0
- wandb/integration/yolov8/yolov8.py +250 -0
- wandb/jupyter.py +31 -35
- wandb/lightgbm/__init__.py +1 -2
- wandb/old/settings.py +2 -2
- wandb/plot/bar.py +1 -2
- wandb/plot/confusion_matrix.py +1 -3
- wandb/plot/histogram.py +1 -2
- wandb/plot/line.py +1 -2
- wandb/plot/line_series.py +4 -4
- wandb/plot/pr_curve.py +17 -20
- wandb/plot/roc_curve.py +1 -3
- wandb/plot/scatter.py +1 -2
- wandb/proto/v3/wandb_server_pb2.py +85 -39
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_server_pb2.py +51 -39
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/__init__.py +1 -3
- wandb/sdk/backend/backend.py +1 -1
- wandb/sdk/data_types/_dtypes.py +38 -30
- wandb/sdk/data_types/base_types/json_metadata.py +1 -3
- wandb/sdk/data_types/base_types/media.py +17 -17
- wandb/sdk/data_types/base_types/wb_value.py +33 -26
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +91 -125
- wandb/sdk/data_types/helper_types/classes.py +1 -1
- wandb/sdk/data_types/helper_types/image_mask.py +12 -12
- wandb/sdk/data_types/histogram.py +5 -4
- wandb/sdk/data_types/html.py +1 -2
- wandb/sdk/data_types/image.py +11 -11
- wandb/sdk/data_types/molecule.py +3 -6
- wandb/sdk/data_types/object_3d.py +1 -2
- wandb/sdk/data_types/plotly.py +1 -2
- wandb/sdk/data_types/saved_model.py +10 -8
- wandb/sdk/data_types/video.py +1 -1
- wandb/sdk/integration_utils/data_logging.py +5 -5
- wandb/sdk/interface/artifacts.py +288 -266
- wandb/sdk/interface/interface.py +2 -3
- wandb/sdk/interface/interface_grpc.py +1 -1
- wandb/sdk/interface/interface_queue.py +1 -1
- wandb/sdk/interface/interface_relay.py +1 -1
- wandb/sdk/interface/interface_shared.py +1 -2
- wandb/sdk/interface/interface_sock.py +1 -1
- wandb/sdk/interface/message_future.py +1 -1
- wandb/sdk/interface/message_future_poll.py +1 -1
- wandb/sdk/interface/router.py +1 -1
- wandb/sdk/interface/router_queue.py +1 -1
- wandb/sdk/interface/router_relay.py +1 -1
- wandb/sdk/interface/router_sock.py +1 -1
- wandb/sdk/interface/summary_record.py +1 -1
- wandb/sdk/internal/artifacts.py +1 -1
- wandb/sdk/internal/datastore.py +2 -3
- wandb/sdk/internal/file_pusher.py +5 -3
- wandb/sdk/internal/file_stream.py +22 -19
- wandb/sdk/internal/handler.py +5 -4
- wandb/sdk/internal/internal.py +1 -1
- wandb/sdk/internal/internal_api.py +115 -55
- wandb/sdk/internal/job_builder.py +1 -3
- wandb/sdk/internal/profiler.py +1 -1
- wandb/sdk/internal/progress.py +4 -6
- wandb/sdk/internal/sample.py +1 -3
- wandb/sdk/internal/sender.py +28 -16
- wandb/sdk/internal/settings_static.py +5 -5
- wandb/sdk/internal/system/assets/__init__.py +1 -0
- wandb/sdk/internal/system/assets/cpu.py +3 -9
- wandb/sdk/internal/system/assets/disk.py +2 -4
- wandb/sdk/internal/system/assets/gpu.py +6 -18
- wandb/sdk/internal/system/assets/gpu_apple.py +2 -4
- wandb/sdk/internal/system/assets/interfaces.py +50 -22
- wandb/sdk/internal/system/assets/ipu.py +1 -3
- wandb/sdk/internal/system/assets/memory.py +7 -13
- wandb/sdk/internal/system/assets/network.py +4 -8
- wandb/sdk/internal/system/assets/open_metrics.py +283 -0
- wandb/sdk/internal/system/assets/tpu.py +1 -4
- wandb/sdk/internal/system/assets/trainium.py +26 -14
- wandb/sdk/internal/system/system_info.py +2 -3
- wandb/sdk/internal/system/system_monitor.py +52 -20
- wandb/sdk/internal/tb_watcher.py +12 -13
- wandb/sdk/launch/_project_spec.py +54 -65
- wandb/sdk/launch/agent/agent.py +374 -90
- wandb/sdk/launch/builder/abstract.py +61 -7
- wandb/sdk/launch/builder/build.py +81 -110
- wandb/sdk/launch/builder/docker_builder.py +181 -0
- wandb/sdk/launch/builder/kaniko_builder.py +419 -0
- wandb/sdk/launch/builder/noop.py +31 -12
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +70 -20
- wandb/sdk/launch/environment/abstract.py +28 -0
- wandb/sdk/launch/environment/aws_environment.py +276 -0
- wandb/sdk/launch/environment/gcp_environment.py +271 -0
- wandb/sdk/launch/environment/local_environment.py +65 -0
- wandb/sdk/launch/github_reference.py +3 -8
- wandb/sdk/launch/launch.py +38 -29
- wandb/sdk/launch/launch_add.py +6 -8
- wandb/sdk/launch/loader.py +230 -0
- wandb/sdk/launch/registry/abstract.py +54 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +163 -0
- wandb/sdk/launch/registry/google_artifact_registry.py +203 -0
- wandb/sdk/launch/registry/local_registry.py +62 -0
- wandb/sdk/launch/runner/abstract.py +1 -16
- wandb/sdk/launch/runner/{kubernetes.py → kubernetes_runner.py} +83 -95
- wandb/sdk/launch/runner/local_container.py +46 -22
- wandb/sdk/launch/runner/local_process.py +1 -4
- wandb/sdk/launch/runner/{aws.py → sagemaker_runner.py} +53 -212
- wandb/sdk/launch/runner/{gcp_vertex.py → vertex_runner.py} +38 -55
- wandb/sdk/launch/sweeps/__init__.py +3 -2
- wandb/sdk/launch/sweeps/scheduler.py +132 -39
- wandb/sdk/launch/sweeps/scheduler_sweep.py +80 -89
- wandb/sdk/launch/utils.py +101 -30
- wandb/sdk/launch/wandb_reference.py +2 -7
- wandb/sdk/lib/_settings_toposort_generate.py +166 -0
- wandb/sdk/lib/_settings_toposort_generated.py +201 -0
- wandb/sdk/lib/apikey.py +2 -4
- wandb/sdk/lib/config_util.py +4 -1
- wandb/sdk/lib/console.py +1 -3
- wandb/sdk/lib/deprecate.py +3 -3
- wandb/sdk/lib/file_stream_utils.py +7 -5
- wandb/sdk/lib/filenames.py +1 -1
- wandb/sdk/lib/filesystem.py +61 -5
- wandb/sdk/lib/git.py +1 -3
- wandb/sdk/lib/import_hooks.py +4 -7
- wandb/sdk/lib/ipython.py +8 -5
- wandb/sdk/lib/lazyloader.py +1 -3
- wandb/sdk/lib/mailbox.py +14 -4
- wandb/sdk/lib/proto_util.py +10 -5
- wandb/sdk/lib/redirect.py +15 -22
- wandb/sdk/lib/reporting.py +1 -3
- wandb/sdk/lib/retry.py +4 -5
- wandb/sdk/lib/runid.py +1 -3
- wandb/sdk/lib/server.py +15 -9
- wandb/sdk/lib/sock_client.py +1 -1
- wandb/sdk/lib/sparkline.py +1 -1
- wandb/sdk/lib/wburls.py +1 -1
- wandb/sdk/service/port_file.py +1 -2
- wandb/sdk/service/service.py +36 -13
- wandb/sdk/service/service_base.py +12 -1
- wandb/sdk/verify/verify.py +5 -7
- wandb/sdk/wandb_artifacts.py +142 -177
- wandb/sdk/wandb_config.py +5 -8
- wandb/sdk/wandb_helper.py +1 -1
- wandb/sdk/wandb_init.py +24 -13
- wandb/sdk/wandb_login.py +9 -9
- wandb/sdk/wandb_manager.py +39 -4
- wandb/sdk/wandb_metric.py +2 -6
- wandb/sdk/wandb_require.py +4 -15
- wandb/sdk/wandb_require_helpers.py +1 -9
- wandb/sdk/wandb_run.py +95 -141
- wandb/sdk/wandb_save.py +1 -3
- wandb/sdk/wandb_settings.py +149 -54
- wandb/sdk/wandb_setup.py +66 -46
- wandb/sdk/wandb_summary.py +13 -10
- wandb/sdk/wandb_sweep.py +6 -7
- wandb/sdk/wandb_watch.py +1 -1
- wandb/sklearn/calculate/confusion_matrix.py +1 -1
- wandb/sklearn/calculate/learning_curve.py +1 -1
- wandb/sklearn/calculate/summary_metrics.py +1 -3
- wandb/sklearn/plot/__init__.py +1 -1
- wandb/sklearn/plot/classifier.py +27 -18
- wandb/sklearn/plot/clusterer.py +4 -5
- wandb/sklearn/plot/regressor.py +4 -4
- wandb/sklearn/plot/shared.py +2 -2
- wandb/sync/__init__.py +1 -3
- wandb/sync/sync.py +4 -5
- wandb/testing/relay.py +11 -10
- wandb/trigger.py +1 -1
- wandb/util.py +106 -81
- wandb/viz.py +4 -4
- wandb/wandb_agent.py +50 -50
- wandb/wandb_controller.py +2 -3
- wandb/wandb_run.py +1 -2
- wandb/wandb_torch.py +1 -1
- wandb/xgboost/__init__.py +1 -2
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/METADATA +6 -2
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/RECORD +224 -209
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/WHEEL +1 -1
- wandb/sdk/launch/builder/docker.py +0 -80
- wandb/sdk/launch/builder/kaniko.py +0 -393
- wandb/sdk/launch/builder/loader.py +0 -32
- wandb/sdk/launch/runner/loader.py +0 -50
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/LICENSE +0 -0
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
|
|
1
1
|
import datetime
|
2
|
-
import os
|
2
|
+
import logging
|
3
3
|
import shlex
|
4
4
|
import time
|
5
5
|
from typing import Any, Dict, Optional
|
@@ -10,17 +10,20 @@ if False:
|
|
10
10
|
import yaml
|
11
11
|
|
12
12
|
import wandb
|
13
|
-
from wandb.
|
13
|
+
from wandb.apis.internal import Api
|
14
14
|
from wandb.util import get_module
|
15
15
|
|
16
16
|
from .._project_spec import LaunchProject, get_entry_point_command
|
17
17
|
from ..builder.abstract import AbstractBuilder
|
18
|
-
from ..builder.build import
|
19
|
-
from ..
|
18
|
+
from ..builder.build import get_env_vars_dict
|
19
|
+
from ..environment.gcp_environment import GcpEnvironment
|
20
|
+
from ..utils import LOG_PREFIX, PROJECT_SYNCHRONOUS, LaunchError, run_shell
|
20
21
|
from .abstract import AbstractRun, AbstractRunner, Status
|
21
22
|
|
22
23
|
GCP_CONSOLE_URI = "https://console.cloud.google.com"
|
23
24
|
|
25
|
+
_logger = logging.getLogger(__name__)
|
26
|
+
|
24
27
|
|
25
28
|
class VertexSubmittedRun(AbstractRun):
|
26
29
|
def __init__(self, job: Any) -> None:
|
@@ -57,12 +60,14 @@ class VertexSubmittedRun(AbstractRun):
|
|
57
60
|
|
58
61
|
def get_status(self) -> Status:
|
59
62
|
job_state = str(self._job.state) # extract from type PipelineState
|
60
|
-
if job_state == "
|
63
|
+
if job_state == "JobState.JOB_STATE_SUCCEEDED":
|
61
64
|
return Status("finished")
|
62
|
-
if job_state == "
|
65
|
+
if job_state == "JobState.JOB_STATE_FAILED":
|
63
66
|
return Status("failed")
|
64
|
-
if job_state == "
|
67
|
+
if job_state == "JobState.JOB_STATE_RUNNING":
|
65
68
|
return Status("running")
|
69
|
+
if job_state == "JobState.JOB_STATE_PENDING":
|
70
|
+
return Status("starting")
|
66
71
|
return Status("unknown")
|
67
72
|
|
68
73
|
def cancel(self) -> None:
|
@@ -70,47 +75,37 @@ class VertexSubmittedRun(AbstractRun):
|
|
70
75
|
|
71
76
|
|
72
77
|
class VertexRunner(AbstractRunner):
|
73
|
-
"""Runner class, uses a project to create a VertexSubmittedRun"""
|
78
|
+
"""Runner class, uses a project to create a VertexSubmittedRun."""
|
79
|
+
|
80
|
+
def __init__(
|
81
|
+
self, api: Api, backend_config: Dict[str, Any], environment: GcpEnvironment
|
82
|
+
) -> None:
|
83
|
+
"""Initialize a VertexRunner instance."""
|
84
|
+
super().__init__(api, backend_config)
|
85
|
+
self.environment = environment
|
74
86
|
|
75
87
|
def run(
|
76
88
|
self,
|
77
89
|
launch_project: LaunchProject,
|
78
|
-
builder: AbstractBuilder,
|
79
|
-
registry_config: Dict[str, Any],
|
90
|
+
builder: Optional[AbstractBuilder],
|
80
91
|
) -> Optional[AbstractRun]:
|
81
|
-
|
92
|
+
"""Run a Vertex job."""
|
82
93
|
aiplatform = get_module( # noqa: F811
|
83
94
|
"google.cloud.aiplatform",
|
84
95
|
"VertexRunner requires google.cloud.aiplatform to be installed",
|
85
96
|
)
|
86
|
-
|
87
|
-
|
97
|
+
resource_args = launch_project.resource_args.get("vertex")
|
98
|
+
if not resource_args:
|
99
|
+
resource_args = launch_project.resource_args.get("gcp-vertex")
|
88
100
|
if not resource_args:
|
89
101
|
raise LaunchError(
|
90
102
|
"No Vertex resource args specified. Specify args via --resource-args with a JSON file or string under top-level key gcp_vertex"
|
91
103
|
)
|
92
|
-
gcp_config = get_gcp_config(resource_args.get("gcp_config") or "default")
|
93
|
-
gcp_project = (
|
94
|
-
resource_args.get("gcp_project")
|
95
|
-
or gcp_config["properties"]["core"]["project"]
|
96
|
-
)
|
97
|
-
gcp_zone = resource_args.get("gcp_region") or gcp_config["properties"].get(
|
98
|
-
"compute", {}
|
99
|
-
).get("zone")
|
100
|
-
gcp_region = "-".join(gcp_zone.split("-")[:2])
|
101
104
|
gcp_staging_bucket = resource_args.get("staging_bucket")
|
102
105
|
if not gcp_staging_bucket:
|
103
106
|
raise LaunchError(
|
104
107
|
"Vertex requires a staging bucket for training and dependency packages in the same region as compute. Specify a bucket under key staging_bucket."
|
105
108
|
)
|
106
|
-
gcp_artifact_repo = resource_args.get("artifact_repo")
|
107
|
-
if not gcp_artifact_repo:
|
108
|
-
raise LaunchError(
|
109
|
-
"Vertex requires an Artifact Registry repository for the Docker image. Specify a repo under key artifact_repo."
|
110
|
-
)
|
111
|
-
gcp_docker_host = (
|
112
|
-
resource_args.get("docker_host") or f"{gcp_region}-docker.pkg.dev"
|
113
|
-
)
|
114
109
|
gcp_machine_type = resource_args.get("machine_type") or "n1-standard-4"
|
115
110
|
gcp_accelerator_type = (
|
116
111
|
resource_args.get("accelerator_type") or "ACCELERATOR_TYPE_UNSPECIFIED"
|
@@ -124,9 +119,10 @@ class VertexRunner(AbstractRunner):
|
|
124
119
|
)
|
125
120
|
service_account = resource_args.get("service_account")
|
126
121
|
tensorboard = resource_args.get("tensorboard")
|
127
|
-
|
128
122
|
aiplatform.init(
|
129
|
-
project=
|
123
|
+
project=self.environment.project,
|
124
|
+
location=self.environment.region,
|
125
|
+
staging_bucket=gcp_staging_bucket,
|
130
126
|
)
|
131
127
|
synchronous: bool = self.backend_config[PROJECT_SYNCHRONOUS]
|
132
128
|
|
@@ -135,21 +131,13 @@ class VertexRunner(AbstractRunner):
|
|
135
131
|
if launch_project.docker_image:
|
136
132
|
image_uri = launch_project.docker_image
|
137
133
|
else:
|
138
|
-
|
139
|
-
repository = construct_gcp_registry_uri(
|
140
|
-
gcp_artifact_repo,
|
141
|
-
gcp_project,
|
142
|
-
gcp_docker_host,
|
143
|
-
)
|
144
134
|
assert entry_point is not None
|
135
|
+
assert builder is not None
|
145
136
|
image_uri = builder.build_image(
|
146
137
|
launch_project,
|
147
|
-
repository,
|
148
138
|
entry_point,
|
149
139
|
)
|
150
140
|
|
151
|
-
if not self.ack_run_queue_item(launch_project):
|
152
|
-
return None
|
153
141
|
# TODO: how to handle this?
|
154
142
|
entry_cmd = get_entry_point_command(entry_point, launch_project.override_args)
|
155
143
|
|
@@ -176,18 +164,19 @@ class VertexRunner(AbstractRunner):
|
|
176
164
|
display_name=gcp_training_job_name, worker_pool_specs=worker_pool_specs
|
177
165
|
)
|
178
166
|
|
179
|
-
submitted_run = VertexSubmittedRun(job)
|
180
|
-
|
181
|
-
# todo: support gcp dataset?
|
182
|
-
|
183
167
|
wandb.termlog(
|
184
168
|
f"{LOG_PREFIX}Running training job {gcp_training_job_name} on {gcp_machine_type}."
|
185
169
|
)
|
186
170
|
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
171
|
+
if synchronous:
|
172
|
+
job.run(service_account=service_account, tensorboard=tensorboard, sync=True)
|
173
|
+
else:
|
174
|
+
job.submit(
|
175
|
+
service_account=service_account,
|
176
|
+
tensorboard=tensorboard,
|
177
|
+
)
|
178
|
+
|
179
|
+
submitted_run = VertexSubmittedRun(job)
|
191
180
|
|
192
181
|
while not getattr(job._gca_resource, "name", None):
|
193
182
|
# give time for the gcp job object to be created and named, this should only loop a couple times max
|
@@ -196,12 +185,6 @@ class VertexRunner(AbstractRunner):
|
|
196
185
|
wandb.termlog(
|
197
186
|
f"{LOG_PREFIX}View your job status and logs at {submitted_run.get_page_link()}."
|
198
187
|
)
|
199
|
-
|
200
|
-
# hacky: if user doesn't want blocking behavior, kill both main thread and the background thread. job continues
|
201
|
-
# to run remotely. this obviously doesn't work if we need to do some sort of postprocessing after this run fn
|
202
|
-
if not synchronous:
|
203
|
-
os._exit(0)
|
204
|
-
|
205
188
|
return submitted_run
|
206
189
|
|
207
190
|
|
@@ -3,9 +3,11 @@ from typing import Any, Callable, Dict
|
|
3
3
|
|
4
4
|
log = logging.getLogger(__name__)
|
5
5
|
|
6
|
+
SCHEDULER_URI = "placeholder-uri-scheduler"
|
7
|
+
|
6
8
|
|
7
9
|
class SchedulerError(Exception):
|
8
|
-
"""Raised when a known error occurs with wandb sweep scheduler"""
|
10
|
+
"""Raised when a known error occurs with wandb sweep scheduler."""
|
9
11
|
|
10
12
|
pass
|
11
13
|
|
@@ -22,7 +24,6 @@ _WANDB_SCHEDULERS: Dict[str, Callable] = {
|
|
22
24
|
|
23
25
|
|
24
26
|
def load_scheduler(scheduler_name: str) -> Any:
|
25
|
-
|
26
27
|
scheduler_name = scheduler_name.lower()
|
27
28
|
if scheduler_name not in _WANDB_SCHEDULERS:
|
28
29
|
raise SchedulerError(
|
@@ -2,12 +2,14 @@
|
|
2
2
|
import logging
|
3
3
|
import os
|
4
4
|
import threading
|
5
|
+
import traceback
|
5
6
|
from abc import ABC, abstractmethod
|
6
7
|
from dataclasses import dataclass
|
7
8
|
from enum import Enum
|
8
9
|
from typing import Any, Dict, Iterator, List, Optional, Tuple
|
9
10
|
|
10
11
|
import click
|
12
|
+
import yaml
|
11
13
|
|
12
14
|
import wandb
|
13
15
|
import wandb.apis.public as public
|
@@ -16,21 +18,30 @@ from wandb.errors import CommError
|
|
16
18
|
from wandb.sdk.launch.launch_add import launch_add
|
17
19
|
from wandb.sdk.launch.sweeps import SchedulerError
|
18
20
|
from wandb.sdk.lib.runid import generate_id
|
21
|
+
from wandb.wandb_agent import Agent
|
19
22
|
|
20
|
-
|
23
|
+
_logger = logging.getLogger(__name__)
|
21
24
|
LOG_PREFIX = f"{click.style('sched:', fg='cyan')} "
|
22
25
|
|
23
26
|
|
27
|
+
@dataclass
|
28
|
+
class _Worker:
|
29
|
+
agent_config: Dict[str, Any]
|
30
|
+
agent_id: str
|
31
|
+
|
32
|
+
|
24
33
|
class SchedulerState(Enum):
|
25
34
|
PENDING = 0
|
26
35
|
STARTING = 1
|
27
36
|
RUNNING = 2
|
28
|
-
|
29
|
-
|
30
|
-
|
37
|
+
FLUSH_RUNS = 3
|
38
|
+
COMPLETED = 4
|
39
|
+
FAILED = 5
|
40
|
+
STOPPED = 6
|
41
|
+
CANCELLED = 7
|
31
42
|
|
32
43
|
|
33
|
-
class SimpleRunState(Enum):
|
44
|
+
class RunState(Enum):
|
34
45
|
ALIVE = 0
|
35
46
|
DEAD = 1
|
36
47
|
UNKNOWN = 2
|
@@ -39,19 +50,16 @@ class SimpleRunState(Enum):
|
|
39
50
|
@dataclass
|
40
51
|
class SweepRun:
|
41
52
|
id: str
|
42
|
-
state: SimpleRunState = SimpleRunState.ALIVE
|
53
|
+
state: RunState = RunState.ALIVE
|
43
54
|
queued_run: Optional[public.QueuedRun] = None
|
44
55
|
args: Optional[Dict[str, Any]] = None
|
45
56
|
logs: Optional[List[str]] = None
|
46
|
-
program: Optional[str] = None
|
47
57
|
# Threading can be used to run multiple workers in parallel
|
48
58
|
worker_id: Optional[int] = None
|
49
59
|
|
50
60
|
|
51
61
|
class Scheduler(ABC):
|
52
|
-
"""
|
53
|
-
with jobs from a hyperparameter sweep.
|
54
|
-
"""
|
62
|
+
"""A controller/agent that populates a Launch RunQueue from a hyperparameter sweep."""
|
55
63
|
|
56
64
|
def __init__(
|
57
65
|
self,
|
@@ -73,18 +81,31 @@ class Scheduler(ABC):
|
|
73
81
|
self._project = (
|
74
82
|
project or os.environ.get("WANDB_PROJECT") or api.settings("project")
|
75
83
|
)
|
84
|
+
self._sweep_id: str = sweep_id or "empty-sweep-id"
|
85
|
+
self._state: SchedulerState = SchedulerState.PENDING
|
86
|
+
|
76
87
|
# Make sure the provided sweep_id corresponds to a valid sweep
|
77
88
|
try:
|
78
|
-
self._api.sweep(
|
89
|
+
resp = self._api.sweep(
|
90
|
+
sweep_id, "{}", entity=self._entity, project=self._project
|
91
|
+
)
|
92
|
+
if resp.get("state") == SchedulerState.CANCELLED.name:
|
93
|
+
self._state = SchedulerState.CANCELLED
|
94
|
+
self._sweep_config = yaml.safe_load(resp["config"])
|
79
95
|
except Exception as e:
|
80
96
|
raise SchedulerError(f"{LOG_PREFIX}Exception when finding sweep: {e}")
|
81
|
-
|
82
|
-
self._state: SchedulerState = SchedulerState.PENDING
|
97
|
+
|
83
98
|
# Dictionary of the runs being managed by the scheduler
|
84
99
|
self._runs: Dict[str, SweepRun] = {}
|
85
100
|
# Threading lock to ensure thread-safe access to the runs dictionary
|
86
101
|
self._threading_lock: threading.Lock = threading.Lock()
|
87
|
-
self._project_queue = project_queue
|
102
|
+
self._project_queue = project_queue
|
103
|
+
# Optionally run multiple workers in (pseudo-)parallel. Workers do not
|
104
|
+
# actually run training workloads, they simply send heartbeat messages
|
105
|
+
# (emulating a real agent) and add new runs to the launch queue. The
|
106
|
+
# launch agent is the one that actually runs the training workloads.
|
107
|
+
self._workers: Dict[int, _Worker] = {}
|
108
|
+
|
88
109
|
# Scheduler may receive additional kwargs which will be piped into the launch command
|
89
110
|
self._kwargs: Dict[str, Any] = kwargs
|
90
111
|
|
@@ -102,12 +123,12 @@ class Scheduler(ABC):
|
|
102
123
|
|
103
124
|
@property
|
104
125
|
def state(self) -> SchedulerState:
|
105
|
-
|
126
|
+
_logger.debug(f"{LOG_PREFIX}Scheduler state is {self._state.name}")
|
106
127
|
return self._state
|
107
128
|
|
108
129
|
@state.setter
|
109
130
|
def state(self, value: SchedulerState) -> None:
|
110
|
-
|
131
|
+
_logger.debug(f"{LOG_PREFIX}Scheduler was {self.state.name} is {value.name}")
|
111
132
|
self._state = value
|
112
133
|
|
113
134
|
def is_alive(self) -> bool:
|
@@ -115,17 +136,33 @@ class Scheduler(ABC):
|
|
115
136
|
SchedulerState.COMPLETED,
|
116
137
|
SchedulerState.FAILED,
|
117
138
|
SchedulerState.STOPPED,
|
139
|
+
SchedulerState.CANCELLED,
|
118
140
|
]:
|
119
141
|
return False
|
120
142
|
return True
|
121
143
|
|
122
144
|
def start(self) -> None:
|
145
|
+
"""Start a scheduler, confirms prerequisites, begins execution loop."""
|
123
146
|
wandb.termlog(f"{LOG_PREFIX}Scheduler starting.")
|
147
|
+
if not self.is_alive():
|
148
|
+
wandb.termerror(
|
149
|
+
f"{LOG_PREFIX}Sweep already {self.state.name.lower()}! Exiting..."
|
150
|
+
)
|
151
|
+
self.exit()
|
152
|
+
return
|
153
|
+
|
124
154
|
self._state = SchedulerState.STARTING
|
155
|
+
if not self._try_load_executable():
|
156
|
+
wandb.termerror(
|
157
|
+
f"{LOG_PREFIX}No 'job' or 'image_uri' loaded from sweep config."
|
158
|
+
)
|
159
|
+
self.exit()
|
160
|
+
return
|
125
161
|
self._start()
|
126
162
|
self.run()
|
127
163
|
|
128
164
|
def run(self) -> None:
|
165
|
+
"""Main run function for all external schedulers."""
|
129
166
|
wandb.termlog(f"{LOG_PREFIX}Scheduler Running.")
|
130
167
|
self.state = SchedulerState.RUNNING
|
131
168
|
try:
|
@@ -134,6 +171,11 @@ class Scheduler(ABC):
|
|
134
171
|
break
|
135
172
|
self._update_run_states()
|
136
173
|
self._run()
|
174
|
+
# if we hit the run_cap, now set to stopped after launching runs
|
175
|
+
if self.state == SchedulerState.FLUSH_RUNS:
|
176
|
+
if len(self._runs.keys()) == 0:
|
177
|
+
wandb.termlog(f"{LOG_PREFIX}Done polling on runs, exiting.")
|
178
|
+
self.state = SchedulerState.STOPPED
|
137
179
|
except KeyboardInterrupt:
|
138
180
|
wandb.termlog(f"{LOG_PREFIX}Scheduler received KeyboardInterrupt. Exiting.")
|
139
181
|
self.state = SchedulerState.STOPPED
|
@@ -157,6 +199,28 @@ class Scheduler(ABC):
|
|
157
199
|
self.state = SchedulerState.FAILED
|
158
200
|
self._stop_runs()
|
159
201
|
|
202
|
+
def _try_load_executable(self) -> bool:
|
203
|
+
"""Check existance of valid executable for a run.
|
204
|
+
|
205
|
+
logs and returns False when job is unreachable
|
206
|
+
"""
|
207
|
+
if self._kwargs.get("job"):
|
208
|
+
_public_api = public.Api()
|
209
|
+
try:
|
210
|
+
_job_artifact = _public_api.artifact(self._kwargs["job"], type="job")
|
211
|
+
wandb.termlog(
|
212
|
+
f"{LOG_PREFIX}Successfully loaded job: {_job_artifact.name} in scheduler"
|
213
|
+
)
|
214
|
+
except Exception:
|
215
|
+
wandb.termerror(f"{LOG_PREFIX}{traceback.format_exc()}")
|
216
|
+
return False
|
217
|
+
return True
|
218
|
+
elif self._kwargs.get("image_uri"):
|
219
|
+
# TODO(gst): check docker existance? Use registry in launch config?
|
220
|
+
return True
|
221
|
+
else:
|
222
|
+
return False
|
223
|
+
|
160
224
|
def _yield_runs(self) -> Iterator[Tuple[str, SweepRun]]:
|
161
225
|
"""Thread-safe way to iterate over the runs."""
|
162
226
|
with self._threading_lock:
|
@@ -168,25 +232,38 @@ class Scheduler(ABC):
|
|
168
232
|
self._stop_run(run_id)
|
169
233
|
|
170
234
|
def _stop_run(self, run_id: str) -> None:
|
171
|
-
"""
|
235
|
+
"""Stop a run and removes it from the scheduler."""
|
172
236
|
if run_id in self._runs:
|
173
237
|
run: SweepRun = self._runs[run_id]
|
174
|
-
run.state = SimpleRunState.DEAD
|
238
|
+
run.state = RunState.DEAD
|
175
239
|
# TODO(hupo): Send command to backend to stop run
|
176
240
|
wandb.termlog(f"{LOG_PREFIX} Stopped run {run_id}.")
|
177
241
|
|
178
242
|
def _update_run_states(self) -> None:
|
243
|
+
"""Iterate through runs.
|
244
|
+
|
245
|
+
Get state from backend and deletes runs if not in running state. Threadsafe.
|
246
|
+
"""
|
179
247
|
_runs_to_remove: List[str] = []
|
180
248
|
for run_id, run in self._yield_runs():
|
181
249
|
try:
|
182
250
|
_state = self._api.get_run_state(self._entity, self._project, run_id)
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
251
|
+
_rqi_state = run.queued_run.state if run.queued_run else None
|
252
|
+
if (
|
253
|
+
not _state
|
254
|
+
or _state
|
255
|
+
in [
|
256
|
+
"crashed",
|
257
|
+
"failed",
|
258
|
+
"killed",
|
259
|
+
"finished",
|
260
|
+
]
|
261
|
+
or _rqi_state == "failed"
|
262
|
+
):
|
263
|
+
_logger.debug(
|
264
|
+
f"({run_id}) run-state:{_state}, rqi-state:{_rqi_state}"
|
265
|
+
)
|
266
|
+
run.state = RunState.DEAD
|
190
267
|
_runs_to_remove.append(run_id)
|
191
268
|
elif _state in [
|
192
269
|
"running",
|
@@ -194,12 +271,12 @@ class Scheduler(ABC):
|
|
194
271
|
"preempted",
|
195
272
|
"preempting",
|
196
273
|
]:
|
197
|
-
run.state = SimpleRunState.ALIVE
|
274
|
+
run.state = RunState.ALIVE
|
198
275
|
except CommError as e:
|
199
276
|
wandb.termlog(
|
200
277
|
f"{LOG_PREFIX}Issue when getting RunState for Run {run_id}: {e}"
|
201
278
|
)
|
202
|
-
run.state = SimpleRunState.UNKNOWN
|
279
|
+
run.state = RunState.UNKNOWN
|
203
280
|
continue
|
204
281
|
# Remove any runs that are dead
|
205
282
|
with self._threading_lock:
|
@@ -213,31 +290,47 @@ class Scheduler(ABC):
|
|
213
290
|
entry_point: Optional[List[str]] = None,
|
214
291
|
config: Optional[Dict[str, Any]] = None,
|
215
292
|
) -> "public.QueuedRun":
|
216
|
-
"""Add a launch job to the Launch RunQueue.
|
293
|
+
"""Add a launch job to the Launch RunQueue.
|
294
|
+
|
295
|
+
run_id: supplied by gorilla from agentHeartbeat
|
296
|
+
entry_point: sweep entrypoint overrides image_uri/job entrypoint
|
297
|
+
config: launch config
|
298
|
+
"""
|
299
|
+
# job and image first from CLI args, then from sweep config
|
300
|
+
_job = self._kwargs.get("job") or self._sweep_config.get("job")
|
301
|
+
|
302
|
+
_sweep_config_uri = self._sweep_config.get("image_uri")
|
303
|
+
_image_uri = self._kwargs.get("image_uri") or _sweep_config_uri
|
304
|
+
if _job is None and _image_uri is None:
|
305
|
+
raise SchedulerError(
|
306
|
+
f"{LOG_PREFIX}No 'job' nor 'image_uri' (run: {run_id})"
|
307
|
+
)
|
308
|
+
elif _job is not None and _image_uri is not None:
|
309
|
+
raise SchedulerError(f"{LOG_PREFIX}Sweep has both 'job' and 'image_uri'")
|
310
|
+
|
311
|
+
if self._sweep_config.get("command"):
|
312
|
+
entry_point = Agent._create_sweep_command(self._sweep_config["command"])
|
313
|
+
wandb.termwarn(
|
314
|
+
f"{LOG_PREFIX}Sweep command {entry_point} will override"
|
315
|
+
f' {"job" if _job else "image_uri"} entrypoint'
|
316
|
+
)
|
317
|
+
|
217
318
|
run_id = run_id or generate_id()
|
218
|
-
# One of Job and URI is required
|
219
|
-
_job = self._kwargs.get("job", None)
|
220
|
-
_uri = self._kwargs.get("uri", None)
|
221
|
-
if _job is None and _uri is None:
|
222
|
-
# If no Job is specified, use a placeholder URI to prevent Launch failure
|
223
|
-
_uri = "placeholder-uri-queuedrun-from-scheduler"
|
224
|
-
# Queue is required
|
225
|
-
_queue = self._kwargs.get("queue", "default")
|
226
319
|
queued_run = launch_add(
|
227
320
|
run_id=run_id,
|
228
321
|
entry_point=entry_point,
|
229
322
|
config=config,
|
230
|
-
|
323
|
+
docker_image=_image_uri, # TODO(gst): make agnostic (github? run uri?)
|
231
324
|
job=_job,
|
232
325
|
project=self._project,
|
233
326
|
entity=self._entity,
|
234
|
-
queue_name=_queue,
|
327
|
+
queue_name=self._kwargs.get("queue"),
|
235
328
|
project_queue=self._project_queue,
|
236
329
|
resource=self._kwargs.get("resource", None),
|
237
330
|
resource_args=self._kwargs.get("resource_args", None),
|
238
331
|
)
|
239
332
|
self._runs[run_id].queued_run = queued_run
|
240
333
|
wandb.termlog(
|
241
|
-
f"{LOG_PREFIX}Added run to Launch
|
334
|
+
f"{LOG_PREFIX}Added run to Launch queue: {self._kwargs.get('queue')} RunID:{run_id}."
|
242
335
|
)
|
243
336
|
return queued_run
|