wandb 0.15.3__py3-none-any.whl → 0.15.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -1
- wandb/analytics/sentry.py +1 -0
- wandb/apis/importers/base.py +20 -5
- wandb/apis/importers/mlflow.py +7 -1
- wandb/apis/internal.py +12 -0
- wandb/apis/public.py +247 -1387
- wandb/apis/reports/_panels.py +58 -35
- wandb/beta/workflows.py +6 -7
- wandb/cli/cli.py +130 -60
- wandb/data_types.py +3 -1
- wandb/filesync/dir_watcher.py +21 -27
- wandb/filesync/step_checksum.py +8 -8
- wandb/filesync/step_prepare.py +23 -10
- wandb/filesync/step_upload.py +13 -13
- wandb/filesync/upload_job.py +4 -8
- wandb/integration/cohere/__init__.py +3 -0
- wandb/integration/cohere/cohere.py +21 -0
- wandb/integration/cohere/resolver.py +347 -0
- wandb/integration/gym/__init__.py +4 -6
- wandb/integration/huggingface/__init__.py +3 -0
- wandb/integration/huggingface/huggingface.py +18 -0
- wandb/integration/huggingface/resolver.py +213 -0
- wandb/integration/langchain/wandb_tracer.py +16 -179
- wandb/integration/openai/__init__.py +1 -3
- wandb/integration/openai/openai.py +11 -143
- wandb/integration/openai/resolver.py +111 -38
- wandb/integration/sagemaker/config.py +2 -2
- wandb/integration/tensorboard/log.py +4 -4
- wandb/old/settings.py +24 -7
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/wandb_deprecated.py +3 -1
- wandb/sdk/__init__.py +1 -1
- wandb/sdk/artifacts/__init__.py +0 -0
- wandb/sdk/artifacts/artifact.py +2101 -0
- wandb/sdk/artifacts/artifact_download_logger.py +42 -0
- wandb/sdk/artifacts/artifact_manifest.py +67 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
- wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
- wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
- wandb/sdk/artifacts/artifact_state.py +10 -0
- wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
- wandb/sdk/artifacts/exceptions.py +55 -0
- wandb/sdk/artifacts/storage_handler.py +59 -0
- wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
- wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
- wandb/sdk/artifacts/storage_layout.py +6 -0
- wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
- wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
- wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
- wandb/sdk/data_types/_dtypes.py +7 -12
- wandb/sdk/data_types/base_types/json_metadata.py +3 -2
- wandb/sdk/data_types/base_types/media.py +8 -8
- wandb/sdk/data_types/base_types/wb_value.py +12 -13
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
- wandb/sdk/data_types/helper_types/classes.py +6 -8
- wandb/sdk/data_types/helper_types/image_mask.py +5 -6
- wandb/sdk/data_types/histogram.py +4 -3
- wandb/sdk/data_types/html.py +3 -4
- wandb/sdk/data_types/image.py +11 -9
- wandb/sdk/data_types/molecule.py +5 -3
- wandb/sdk/data_types/object_3d.py +7 -5
- wandb/sdk/data_types/plotly.py +3 -2
- wandb/sdk/data_types/saved_model.py +11 -11
- wandb/sdk/data_types/trace_tree.py +5 -4
- wandb/sdk/data_types/utils.py +3 -5
- wandb/sdk/data_types/video.py +5 -4
- wandb/sdk/integration_utils/auto_logging.py +215 -0
- wandb/sdk/interface/interface.py +15 -15
- wandb/sdk/internal/file_pusher.py +8 -16
- wandb/sdk/internal/file_stream.py +5 -11
- wandb/sdk/internal/handler.py +13 -1
- wandb/sdk/internal/internal_api.py +287 -13
- wandb/sdk/internal/job_builder.py +119 -30
- wandb/sdk/internal/sender.py +6 -26
- wandb/sdk/internal/settings_static.py +2 -0
- wandb/sdk/internal/system/assets/__init__.py +2 -0
- wandb/sdk/internal/system/assets/gpu.py +42 -0
- wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
- wandb/sdk/internal/system/env_probe_helpers.py +13 -0
- wandb/sdk/internal/system/system_info.py +3 -3
- wandb/sdk/internal/tb_watcher.py +32 -22
- wandb/sdk/internal/thread_local_settings.py +18 -0
- wandb/sdk/launch/_project_spec.py +57 -11
- wandb/sdk/launch/agent/agent.py +147 -65
- wandb/sdk/launch/agent/job_status_tracker.py +34 -0
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
- wandb/sdk/launch/builder/abstract.py +5 -1
- wandb/sdk/launch/builder/build.py +21 -18
- wandb/sdk/launch/builder/docker_builder.py +10 -4
- wandb/sdk/launch/builder/kaniko_builder.py +113 -23
- wandb/sdk/launch/builder/noop.py +6 -3
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
- wandb/sdk/launch/environment/aws_environment.py +3 -2
- wandb/sdk/launch/environment/azure_environment.py +124 -0
- wandb/sdk/launch/environment/gcp_environment.py +2 -4
- wandb/sdk/launch/environment/local_environment.py +1 -1
- wandb/sdk/launch/errors.py +19 -0
- wandb/sdk/launch/github_reference.py +32 -19
- wandb/sdk/launch/launch.py +3 -8
- wandb/sdk/launch/launch_add.py +6 -2
- wandb/sdk/launch/loader.py +21 -2
- wandb/sdk/launch/registry/azure_container_registry.py +132 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
- wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
- wandb/sdk/launch/registry/local_registry.py +2 -1
- wandb/sdk/launch/runner/abstract.py +24 -3
- wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
- wandb/sdk/launch/runner/local_container.py +103 -51
- wandb/sdk/launch/runner/local_process.py +1 -1
- wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
- wandb/sdk/launch/runner/vertex_runner.py +10 -5
- wandb/sdk/launch/sweeps/__init__.py +7 -9
- wandb/sdk/launch/sweeps/scheduler.py +307 -77
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +82 -35
- wandb/sdk/launch/utils.py +89 -75
- wandb/sdk/lib/_settings_toposort_generated.py +7 -0
- wandb/sdk/lib/capped_dict.py +26 -0
- wandb/sdk/lib/{git.py → gitlib.py} +76 -59
- wandb/sdk/lib/hashutil.py +12 -4
- wandb/sdk/lib/paths.py +96 -8
- wandb/sdk/lib/sock_client.py +2 -2
- wandb/sdk/lib/timer.py +1 -0
- wandb/sdk/service/server.py +22 -9
- wandb/sdk/service/server_sock.py +1 -1
- wandb/sdk/service/service.py +27 -8
- wandb/sdk/verify/verify.py +4 -7
- wandb/sdk/wandb_config.py +2 -6
- wandb/sdk/wandb_init.py +57 -53
- wandb/sdk/wandb_require.py +7 -0
- wandb/sdk/wandb_run.py +61 -223
- wandb/sdk/wandb_settings.py +28 -4
- wandb/testing/relay.py +15 -2
- wandb/util.py +74 -36
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
- wandb/integration/langchain/util.py +0 -191
- wandb/sdk/interface/artifacts/__init__.py +0 -33
- wandb/sdk/interface/artifacts/artifact.py +0 -615
- wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
- wandb/sdk/wandb_artifacts.py +0 -2226
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -9,7 +9,7 @@ import traceback
|
|
9
9
|
from abc import ABC, abstractmethod
|
10
10
|
from dataclasses import dataclass
|
11
11
|
from enum import Enum
|
12
|
-
from typing import Any, Dict, Iterator, List, Optional, Tuple
|
12
|
+
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
13
13
|
|
14
14
|
import click
|
15
15
|
import yaml
|
@@ -17,7 +17,10 @@ import yaml
|
|
17
17
|
import wandb
|
18
18
|
import wandb.apis.public as public
|
19
19
|
from wandb.apis.internal import Api
|
20
|
+
from wandb.apis.public import Api as PublicApi
|
21
|
+
from wandb.apis.public import QueuedRun, Run
|
20
22
|
from wandb.errors import CommError
|
23
|
+
from wandb.sdk.launch.errors import LaunchError
|
21
24
|
from wandb.sdk.launch.launch_add import launch_add
|
22
25
|
from wandb.sdk.launch.sweeps import SchedulerError
|
23
26
|
from wandb.sdk.launch.sweeps.utils import (
|
@@ -25,10 +28,13 @@ from wandb.sdk.launch.sweeps.utils import (
|
|
25
28
|
make_launch_sweep_entrypoint,
|
26
29
|
)
|
27
30
|
from wandb.sdk.lib.runid import generate_id
|
31
|
+
from wandb.sdk.wandb_run import Run as SdkRun
|
28
32
|
|
29
33
|
_logger = logging.getLogger(__name__)
|
30
34
|
LOG_PREFIX = f"{click.style('sched:', fg='cyan')} "
|
31
35
|
|
36
|
+
DEFAULT_POLLING_SLEEP = 5.0
|
37
|
+
|
32
38
|
|
33
39
|
class SchedulerState(Enum):
|
34
40
|
PENDING = 0
|
@@ -42,9 +48,29 @@ class SchedulerState(Enum):
|
|
42
48
|
|
43
49
|
|
44
50
|
class RunState(Enum):
|
45
|
-
|
46
|
-
|
47
|
-
|
51
|
+
RUNNING = "running", "alive"
|
52
|
+
PENDING = "pending", "alive"
|
53
|
+
PREEMPTING = "preempting", "alive"
|
54
|
+
CRASHED = "crashed", "dead"
|
55
|
+
FAILED = "failed", "dead"
|
56
|
+
KILLED = "killed", "dead"
|
57
|
+
FINISHED = "finished", "dead"
|
58
|
+
PREEMPTED = "preempted", "dead"
|
59
|
+
# unknown when api.get_run_state fails or returns unexpected state
|
60
|
+
# assumed alive, unless we get unknown 2x then move to failed (dead)
|
61
|
+
UNKNOWN = "unknown", "alive"
|
62
|
+
|
63
|
+
def __new__(cls: Any, *args: List, **kwds: Any) -> "RunState":
|
64
|
+
obj: "RunState" = object.__new__(cls)
|
65
|
+
obj._value_ = args[0]
|
66
|
+
return obj
|
67
|
+
|
68
|
+
def __init__(self, _: str, life: str = "unknown") -> None:
|
69
|
+
self._life = life
|
70
|
+
|
71
|
+
@property
|
72
|
+
def is_alive(self) -> bool:
|
73
|
+
return self._life == "alive"
|
48
74
|
|
49
75
|
|
50
76
|
@dataclass
|
@@ -57,7 +83,7 @@ class _Worker:
|
|
57
83
|
class SweepRun:
|
58
84
|
id: str
|
59
85
|
worker_id: int
|
60
|
-
state: RunState = RunState.
|
86
|
+
state: RunState = RunState.RUNNING
|
61
87
|
queued_run: Optional[public.QueuedRun] = None
|
62
88
|
args: Optional[Dict[str, Any]] = None
|
63
89
|
logs: Optional[List[str]] = None
|
@@ -66,20 +92,24 @@ class SweepRun:
|
|
66
92
|
class Scheduler(ABC):
|
67
93
|
"""A controller/agent that populates a Launch RunQueue from a hyperparameter sweep."""
|
68
94
|
|
95
|
+
PLACEHOLDER_URI = "placeholder-uri-scheduler"
|
96
|
+
SWEEP_JOB_TYPE = "sweep-controller"
|
97
|
+
ENTRYPOINT = ["wandb", "scheduler", "WANDB_SWEEP_ID"]
|
98
|
+
|
69
99
|
def __init__(
|
70
100
|
self,
|
71
101
|
api: Api,
|
72
102
|
*args: Optional[Any],
|
73
|
-
|
74
|
-
polling_sleep: float = 5.0,
|
103
|
+
polling_sleep: Optional[float] = None,
|
75
104
|
sweep_id: Optional[str] = None,
|
76
105
|
entity: Optional[str] = None,
|
77
106
|
project: Optional[str] = None,
|
78
107
|
project_queue: Optional[str] = None,
|
108
|
+
num_workers: Optional[Union[int, str]] = None,
|
79
109
|
**kwargs: Optional[Any],
|
80
110
|
):
|
81
111
|
self._api = api
|
82
|
-
self._public_api =
|
112
|
+
self._public_api = PublicApi()
|
83
113
|
self._entity = (
|
84
114
|
entity
|
85
115
|
or os.environ.get("WANDB_ENTITY")
|
@@ -100,27 +130,42 @@ class Scheduler(ABC):
|
|
100
130
|
if resp.get("state") == SchedulerState.CANCELLED.name:
|
101
131
|
self._state = SchedulerState.CANCELLED
|
102
132
|
self._sweep_config = yaml.safe_load(resp["config"])
|
133
|
+
self._num_runs_launched: int = self._get_num_runs_launched(resp["runs"])
|
134
|
+
if self._num_runs_launched > 0:
|
135
|
+
wandb.termlog(
|
136
|
+
f"{LOG_PREFIX}Found {self._num_runs_launched} previous valid runs for sweep {self._sweep_id}"
|
137
|
+
)
|
103
138
|
except Exception as e:
|
104
139
|
raise SchedulerError(
|
105
140
|
f"{LOG_PREFIX}Exception when finding sweep ({sweep_id}) {e}"
|
106
141
|
)
|
107
142
|
|
143
|
+
# Scheduler may receive additional kwargs which will be piped into the launch command
|
144
|
+
self._kwargs: Dict[str, Any] = kwargs
|
145
|
+
|
108
146
|
# Dictionary of the runs being managed by the scheduler
|
109
147
|
self._runs: Dict[str, SweepRun] = {}
|
110
148
|
# Threading lock to ensure thread-safe access to the runs dictionary
|
111
149
|
self._threading_lock: threading.Lock = threading.Lock()
|
112
|
-
self._polling_sleep = polling_sleep
|
150
|
+
self._polling_sleep = polling_sleep or DEFAULT_POLLING_SLEEP
|
113
151
|
self._project_queue = project_queue
|
114
152
|
# Optionally run multiple workers in (pseudo-)parallel. Workers do not
|
115
153
|
# actually run training workloads, they simply send heartbeat messages
|
116
154
|
# (emulating a real agent) and add new runs to the launch queue. The
|
117
155
|
# launch agent is the one that actually runs the training workloads.
|
118
156
|
self._workers: Dict[int, _Worker] = {}
|
119
|
-
self._num_workers = num_workers
|
120
|
-
self._num_runs_launched = 0
|
121
157
|
|
122
|
-
#
|
123
|
-
self.
|
158
|
+
# Init wandb scheduler run
|
159
|
+
self._wandb_run = self._init_wandb_run()
|
160
|
+
|
161
|
+
# Grab params from scheduler wandb run config
|
162
|
+
num_workers = num_workers or self._wandb_run.config.get("scheduler", {}).get(
|
163
|
+
"num_workers"
|
164
|
+
)
|
165
|
+
self._num_workers = int(num_workers) if str(num_workers).isdigit() else 8
|
166
|
+
self._settings_config: Dict[str, Any] = self._wandb_run.config.get(
|
167
|
+
"settings", {}
|
168
|
+
)
|
124
169
|
|
125
170
|
@abstractmethod
|
126
171
|
def _get_next_sweep_run(self, worker_id: int) -> Optional[SweepRun]:
|
@@ -168,7 +213,6 @@ class Scheduler(ABC):
|
|
168
213
|
@property
|
169
214
|
def at_runcap(self) -> bool:
|
170
215
|
"""False if under user-specified cap on # of runs."""
|
171
|
-
# TODO(gst): Count previous runs for resumed sweeps
|
172
216
|
run_cap = self._sweep_config.get("run_cap")
|
173
217
|
if not run_cap:
|
174
218
|
return False
|
@@ -200,6 +244,18 @@ class Scheduler(ABC):
|
|
200
244
|
_id: w for _id, w in self._workers.items() if _id not in self.busy_workers
|
201
245
|
}
|
202
246
|
|
247
|
+
def _init_wandb_run(self) -> SdkRun:
|
248
|
+
"""Controls resume or init logic for a scheduler wandb run."""
|
249
|
+
_type = self._kwargs.get("sweep_type", "sweep")
|
250
|
+
run: SdkRun = wandb.init(
|
251
|
+
name=f"{_type}-scheduler-{self._sweep_id}",
|
252
|
+
job_type=self.SWEEP_JOB_TYPE,
|
253
|
+
# WANDB_RUN_ID = sweep_id for scheduler
|
254
|
+
resume="allow",
|
255
|
+
config=self._kwargs, # when run as a job, this sets config
|
256
|
+
)
|
257
|
+
return run
|
258
|
+
|
203
259
|
def stop_sweep(self) -> None:
|
204
260
|
"""Stop the sweep."""
|
205
261
|
self._state = SchedulerState.STOPPED
|
@@ -228,6 +284,7 @@ class Scheduler(ABC):
|
|
228
284
|
self.exit()
|
229
285
|
return
|
230
286
|
|
287
|
+
# For resuming sweeps
|
231
288
|
self._load_state()
|
232
289
|
self._register_agents()
|
233
290
|
self.run()
|
@@ -238,10 +295,12 @@ class Scheduler(ABC):
|
|
238
295
|
self.state = SchedulerState.RUNNING
|
239
296
|
try:
|
240
297
|
while True:
|
241
|
-
|
298
|
+
self._update_scheduler_run_state()
|
242
299
|
if not self.is_alive:
|
243
300
|
break
|
244
301
|
|
302
|
+
wandb.termlog(f"{LOG_PREFIX}Polling for new runs to launch")
|
303
|
+
|
245
304
|
self._update_run_states()
|
246
305
|
self._poll()
|
247
306
|
if self.state == SchedulerState.FLUSH_RUNS:
|
@@ -259,8 +318,17 @@ class Scheduler(ABC):
|
|
259
318
|
self.state = SchedulerState.FLUSH_RUNS
|
260
319
|
break
|
261
320
|
|
262
|
-
|
263
|
-
|
321
|
+
try:
|
322
|
+
run: Optional[SweepRun] = self._get_next_sweep_run(worker_id)
|
323
|
+
if not run:
|
324
|
+
break
|
325
|
+
except SchedulerError as e:
|
326
|
+
raise SchedulerError(e)
|
327
|
+
except Exception as e:
|
328
|
+
wandb.termerror(
|
329
|
+
f"{LOG_PREFIX}Failed to get next sweep run: {e}"
|
330
|
+
)
|
331
|
+
self.state = SchedulerState.FAILED
|
264
332
|
break
|
265
333
|
|
266
334
|
if self._add_to_launch_queue(run):
|
@@ -278,18 +346,49 @@ class Scheduler(ABC):
|
|
278
346
|
self.exit()
|
279
347
|
raise e
|
280
348
|
else:
|
281
|
-
wandb.termlog(f"{LOG_PREFIX}Scheduler completed")
|
349
|
+
wandb.termlog(f"{LOG_PREFIX}Scheduler completed successfully")
|
350
|
+
# don't overwrite special states (e.g. STOPPED, FAILED)
|
351
|
+
if self.state in [SchedulerState.RUNNING, SchedulerState.FLUSH_RUNS]:
|
352
|
+
self.state = SchedulerState.COMPLETED
|
282
353
|
self.exit()
|
283
354
|
|
284
355
|
def exit(self) -> None:
|
285
356
|
self._exit()
|
286
|
-
|
357
|
+
# _save_state isn't controlled, possibly fails
|
358
|
+
try:
|
359
|
+
self._save_state()
|
360
|
+
except Exception:
|
361
|
+
wandb.termerror(
|
362
|
+
f"{LOG_PREFIX}Failed to save state: {traceback.format_exc()}"
|
363
|
+
)
|
364
|
+
|
287
365
|
if self.state not in [
|
288
366
|
SchedulerState.COMPLETED,
|
289
367
|
SchedulerState.STOPPED,
|
290
368
|
]:
|
291
369
|
self.state = SchedulerState.FAILED
|
370
|
+
self._set_sweep_state("CRASHED")
|
371
|
+
else:
|
372
|
+
self._set_sweep_state("FINISHED")
|
373
|
+
|
292
374
|
self._stop_runs()
|
375
|
+
self._wandb_run.finish()
|
376
|
+
|
377
|
+
def _get_num_runs_launched(self, runs: List[Dict[str, Any]]) -> int:
|
378
|
+
"""Returns the number of valid runs in the sweep."""
|
379
|
+
count = 0
|
380
|
+
for run in runs:
|
381
|
+
# if bad run, shouldn't be counted against run cap
|
382
|
+
if run.get("state", "") in ["killed", "crashed"] and not run.get(
|
383
|
+
"summaryMetrics"
|
384
|
+
):
|
385
|
+
_logger.debug(
|
386
|
+
f"excluding run: {run['name']} with state: {run['state']} from run cap \n{run}"
|
387
|
+
)
|
388
|
+
continue
|
389
|
+
count += 1
|
390
|
+
|
391
|
+
return count
|
293
392
|
|
294
393
|
def _try_load_executable(self) -> bool:
|
295
394
|
"""Check existence of valid executable for a run.
|
@@ -297,9 +396,8 @@ class Scheduler(ABC):
|
|
297
396
|
logs and returns False when job is unreachable
|
298
397
|
"""
|
299
398
|
if self._kwargs.get("job"):
|
300
|
-
_public_api = public.Api()
|
301
399
|
try:
|
302
|
-
_job_artifact = _public_api.
|
400
|
+
_job_artifact = self._public_api.job(self._kwargs["job"])
|
303
401
|
wandb.termlog(
|
304
402
|
f"{LOG_PREFIX}Successfully loaded job ({_job_artifact.name}) in scheduler"
|
305
403
|
)
|
@@ -316,12 +414,17 @@ class Scheduler(ABC):
|
|
316
414
|
def _register_agents(self) -> None:
|
317
415
|
for worker_id in range(self._num_workers):
|
318
416
|
_logger.debug(f"{LOG_PREFIX}Starting AgentHeartbeat worker ({worker_id})")
|
319
|
-
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
417
|
+
try:
|
418
|
+
agent_config = self._api.register_agent(
|
419
|
+
f"{socket.gethostname()}-{worker_id}", # host
|
420
|
+
sweep_id=self._sweep_id,
|
421
|
+
project_name=self._project,
|
422
|
+
entity=self._entity,
|
423
|
+
)
|
424
|
+
except Exception as e:
|
425
|
+
_logger.debug(f"failed to register agent: {e}")
|
426
|
+
self.fail_sweep(f"failed to register agent: {e}")
|
427
|
+
|
325
428
|
self._workers[worker_id] = _Worker(
|
326
429
|
agent_config=agent_config,
|
327
430
|
agent_id=agent_config["id"],
|
@@ -332,6 +435,17 @@ class Scheduler(ABC):
|
|
332
435
|
with self._threading_lock:
|
333
436
|
yield from self._runs.items()
|
334
437
|
|
438
|
+
def _cleanup_runs(self, runs_to_remove: List[str]) -> None:
|
439
|
+
"""Helper for removing runs from memory.
|
440
|
+
|
441
|
+
Can be overloaded to prevent deletion of runs, which is useful
|
442
|
+
for debugging or when polling on completed runs.
|
443
|
+
"""
|
444
|
+
with self._threading_lock:
|
445
|
+
for run_id in runs_to_remove:
|
446
|
+
wandb.termlog(f"{LOG_PREFIX}Cleaning up finished run ({run_id})")
|
447
|
+
del self._runs[run_id]
|
448
|
+
|
335
449
|
def _stop_runs(self) -> None:
|
336
450
|
to_delete = []
|
337
451
|
for run_id, _ in self._yield_runs():
|
@@ -357,7 +471,7 @@ class Scheduler(ABC):
|
|
357
471
|
)
|
358
472
|
return False
|
359
473
|
|
360
|
-
if run.state
|
474
|
+
if not run.state.is_alive:
|
361
475
|
# run already dead, just delete reference
|
362
476
|
return True
|
363
477
|
|
@@ -366,82 +480,195 @@ class Scheduler(ABC):
|
|
366
480
|
f"Run:v1:{run_id}:{self._project}:{self._entity}".encode()
|
367
481
|
).decode("utf-8")
|
368
482
|
|
369
|
-
|
370
|
-
|
371
|
-
|
483
|
+
try:
|
484
|
+
success: bool = self._api.stop_run(run_id=encoded_run_id)
|
485
|
+
if success:
|
486
|
+
wandb.termlog(f"{LOG_PREFIX}Stopped run {run_id}.")
|
487
|
+
return True
|
488
|
+
except Exception as e:
|
489
|
+
_logger.debug(f"error stopping run ({run_id}): {e}")
|
490
|
+
|
491
|
+
return False
|
492
|
+
|
493
|
+
def _update_scheduler_run_state(self) -> None:
|
494
|
+
"""Update the scheduler state from state of scheduler run and sweep state."""
|
495
|
+
state: RunState = self._get_run_state(self._wandb_run.id)
|
372
496
|
|
373
|
-
|
497
|
+
if state == RunState.KILLED:
|
498
|
+
self.state = SchedulerState.STOPPED
|
499
|
+
elif state in [RunState.FAILED, RunState.CRASHED]:
|
500
|
+
self.state = SchedulerState.FAILED
|
501
|
+
elif state == RunState.FINISHED:
|
502
|
+
self.state = SchedulerState.COMPLETED
|
503
|
+
|
504
|
+
try:
|
505
|
+
sweep_state = self._api.get_sweep_state(
|
506
|
+
self._sweep_id, self._entity, self._project
|
507
|
+
)
|
508
|
+
except Exception as e:
|
509
|
+
_logger.debug(f"sweep state error: {sweep_state} e: {e}")
|
510
|
+
return
|
511
|
+
|
512
|
+
if sweep_state in ["FINISHED", "CANCELLED"]:
|
513
|
+
self.state = SchedulerState.COMPLETED
|
514
|
+
elif sweep_state in ["PAUSED", "STOPPED"]:
|
515
|
+
self.state = SchedulerState.FLUSH_RUNS
|
374
516
|
|
375
517
|
def _update_run_states(self) -> None:
|
376
518
|
"""Iterate through runs.
|
377
519
|
|
378
520
|
Get state from backend and deletes runs if not in running state. Threadsafe.
|
379
521
|
"""
|
380
|
-
|
381
|
-
end_states = [
|
382
|
-
"crashed",
|
383
|
-
"failed",
|
384
|
-
"killed",
|
385
|
-
"finished",
|
386
|
-
"preempted",
|
387
|
-
]
|
388
|
-
run_states = ["running", "pending", "preempting"]
|
389
|
-
|
390
|
-
_runs_to_remove: List[str] = []
|
522
|
+
runs_to_remove: List[str] = []
|
391
523
|
for run_id, run in self._yield_runs():
|
524
|
+
run.state = self._get_run_state(run_id, run.state)
|
525
|
+
|
392
526
|
try:
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
527
|
+
rqi_state = run.queued_run.state if run.queued_run else None
|
528
|
+
except (CommError, LaunchError) as e:
|
529
|
+
_logger.debug(f"Failed to get queued_run.state: {e}")
|
530
|
+
rqi_state = None
|
531
|
+
|
532
|
+
if not run.state.is_alive or rqi_state == "failed":
|
533
|
+
_logger.debug(f"({run_id}) states: ({run.state}, {rqi_state})")
|
534
|
+
runs_to_remove.append(run_id)
|
535
|
+
self._cleanup_runs(runs_to_remove)
|
536
|
+
|
537
|
+
def _get_metrics_from_run(self, run_id: str) -> List[Any]:
|
538
|
+
"""Use the public api to get metrics from a run.
|
539
|
+
|
540
|
+
Uses the metric name found in the sweep config, any
|
541
|
+
misspellings will result in an empty list.
|
542
|
+
"""
|
543
|
+
try:
|
544
|
+
queued_run: Optional[QueuedRun] = self._runs[run_id].queued_run
|
545
|
+
if not queued_run:
|
546
|
+
return []
|
547
|
+
|
548
|
+
api_run: Run = self._public_api.run(
|
549
|
+
f"{queued_run.entity}/{queued_run.project}/{run_id}"
|
550
|
+
)
|
551
|
+
metric_name = self._sweep_config["metric"]["name"]
|
552
|
+
history = api_run.scan_history(keys=["_step", metric_name])
|
553
|
+
metrics = [x[metric_name] for x in history]
|
554
|
+
|
555
|
+
return metrics
|
556
|
+
except Exception as e:
|
557
|
+
_logger.debug(f"[_get_metrics_from_run] {e}")
|
558
|
+
return []
|
559
|
+
|
560
|
+
def _get_run_info(self, run_id: str) -> Dict[str, Any]:
|
561
|
+
"""Use the public api to get info about a run."""
|
562
|
+
try:
|
563
|
+
info: Dict[str, Any] = self._api.get_run_info(
|
564
|
+
self._entity, self._project, run_id
|
565
|
+
)
|
566
|
+
if info:
|
567
|
+
return info
|
568
|
+
except Exception as e:
|
569
|
+
_logger.debug(f"[_get_run_info] {e}")
|
570
|
+
return {}
|
571
|
+
|
572
|
+
def _get_run_state(
|
573
|
+
self, run_id: str, prev_run_state: RunState = RunState.UNKNOWN
|
574
|
+
) -> RunState:
|
575
|
+
"""Use the public api to get state of a run."""
|
576
|
+
run_state = None
|
577
|
+
try:
|
578
|
+
state = self._api.get_run_state(self._entity, self._project, run_id)
|
579
|
+
run_state = RunState(state)
|
580
|
+
except CommError as e:
|
581
|
+
_logger.debug(f"error getting state for run ({run_id}): {e}")
|
582
|
+
if prev_run_state == RunState.UNKNOWN:
|
583
|
+
# triggers when we get an unknown state for the second time
|
584
|
+
wandb.termwarn(
|
585
|
+
f"Failed to get runstate for run ({run_id}). Error: {traceback.format_exc()}"
|
406
586
|
)
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
587
|
+
run_state = RunState.FAILED
|
588
|
+
else: # first time we get unknown state
|
589
|
+
run_state = RunState.UNKNOWN
|
590
|
+
except (AttributeError, ValueError):
|
591
|
+
wandb.termwarn(
|
592
|
+
f"Bad state ({run_state}) for run ({run_id}). Error: {traceback.format_exc()}"
|
593
|
+
)
|
594
|
+
run_state = RunState.UNKNOWN
|
595
|
+
return run_state
|
414
596
|
|
415
|
-
def
|
416
|
-
"""
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
597
|
+
def _create_run(self) -> Dict[str, Any]:
|
598
|
+
"""Use the public api to create a blank run."""
|
599
|
+
try:
|
600
|
+
run: List[Dict[str, Any]] = self._api.upsert_run(
|
601
|
+
project=self._project,
|
602
|
+
entity=self._entity,
|
603
|
+
sweep_name=self._sweep_id,
|
604
|
+
)
|
605
|
+
if run:
|
606
|
+
return run[0]
|
607
|
+
except Exception as e:
|
608
|
+
_logger.debug(f"[_create_run] {e}")
|
609
|
+
raise SchedulerError(
|
610
|
+
"Error creating run from scheduler, check API connection and CLI version."
|
611
|
+
)
|
612
|
+
return {}
|
613
|
+
|
614
|
+
def _set_sweep_state(self, state: str) -> None:
|
615
|
+
wandb.termlog(f"{LOG_PREFIX}Updating sweep state to: {state.lower()}")
|
616
|
+
try:
|
617
|
+
self._api.set_sweep_state(sweep=self._sweep_id, state=state)
|
618
|
+
except Exception as e:
|
619
|
+
_logger.debug(f"[set_sweep_state] {e}")
|
620
|
+
|
621
|
+
def _encode(self, _id: str) -> str:
|
622
|
+
return (
|
623
|
+
base64.b64decode(bytes(_id.encode("utf-8"))).decode("utf-8").split(":")[2]
|
624
|
+
)
|
425
625
|
|
626
|
+
def _make_entry_and_launch_config(
|
627
|
+
self, run: SweepRun
|
628
|
+
) -> Tuple[Optional[List[str]], Dict[str, Dict[str, Any]]]:
|
426
629
|
args = create_sweep_command_args({"args": run.args})
|
427
630
|
entry_point, macro_args = make_launch_sweep_entrypoint(
|
428
631
|
args, self._sweep_config.get("command")
|
429
632
|
)
|
633
|
+
# handle program macro
|
634
|
+
if entry_point and "${program}" in entry_point:
|
635
|
+
if not self._sweep_config.get("program"):
|
636
|
+
raise SchedulerError(
|
637
|
+
f"{LOG_PREFIX}Program macro in command has no corresponding 'program' in sweep config."
|
638
|
+
)
|
639
|
+
pidx = entry_point.index("${program}")
|
640
|
+
entry_point[pidx] = self._sweep_config["program"]
|
641
|
+
|
430
642
|
launch_config = {"overrides": {"run_config": args["args_dict"]}}
|
431
643
|
if macro_args: # pipe in hyperparam args as params to launch
|
432
644
|
launch_config["overrides"]["args"] = macro_args
|
433
645
|
|
434
646
|
if entry_point:
|
435
|
-
wandb.termwarn(
|
436
|
-
f"{LOG_PREFIX}Sweep command {entry_point} will override"
|
437
|
-
f' {"job" if _job else "image_uri"} entrypoint'
|
438
|
-
)
|
439
647
|
unresolved = [x for x in entry_point if str(x).startswith("${")]
|
440
648
|
if unresolved:
|
441
649
|
wandb.termwarn(
|
442
650
|
f"{LOG_PREFIX}Sweep command contains unresolved macros: "
|
443
651
|
f"{unresolved}, see launch docs for supported macros."
|
444
652
|
)
|
653
|
+
return entry_point, launch_config
|
654
|
+
|
655
|
+
def _add_to_launch_queue(self, run: SweepRun) -> bool:
|
656
|
+
"""Convert a sweeprun into a launch job then push to runqueue."""
|
657
|
+
# job and image first from CLI args, then from sweep config
|
658
|
+
_job = self._kwargs.get("job") or self._sweep_config.get("job")
|
659
|
+
_sweep_config_uri = self._sweep_config.get("image_uri")
|
660
|
+
_image_uri = self._kwargs.get("image_uri") or _sweep_config_uri
|
661
|
+
if _job is None and _image_uri is None:
|
662
|
+
raise SchedulerError(f"{LOG_PREFIX}No 'job' nor 'image_uri' ({run.id})")
|
663
|
+
elif _job is not None and _image_uri is not None:
|
664
|
+
raise SchedulerError(f"{LOG_PREFIX}Sweep has both 'job' and 'image_uri'")
|
665
|
+
|
666
|
+
entry_point, launch_config = self._make_entry_and_launch_config(run)
|
667
|
+
if entry_point:
|
668
|
+
wandb.termwarn(
|
669
|
+
f"{LOG_PREFIX}Sweep command {entry_point} will override"
|
670
|
+
f' {"job" if _job else "image_uri"} entrypoint'
|
671
|
+
)
|
445
672
|
|
446
673
|
run_id = run.id or generate_id()
|
447
674
|
queued_run = launch_add(
|
@@ -457,8 +684,11 @@ class Scheduler(ABC):
|
|
457
684
|
resource=self._kwargs.get("resource", None),
|
458
685
|
resource_args=self._kwargs.get("resource_args", None),
|
459
686
|
author=self._kwargs.get("author"),
|
687
|
+
sweep_id=self._sweep_id,
|
460
688
|
)
|
461
689
|
run.queued_run = queued_run
|
690
|
+
# TODO(gst): unify run and queued_run state
|
691
|
+
run.state = RunState.RUNNING # assume it will get picked up
|
462
692
|
self._runs[run_id] = run
|
463
693
|
|
464
694
|
wandb.termlog(
|
@@ -50,6 +50,7 @@ class SweepScheduler(Scheduler):
|
|
50
50
|
|
51
51
|
return SweepRun(
|
52
52
|
id=_run_id,
|
53
|
+
state=RunState.PENDING,
|
53
54
|
args=command.get("args", {}),
|
54
55
|
logs=command.get("logs", []),
|
55
56
|
worker_id=worker_id,
|
@@ -62,7 +63,7 @@ class SweepScheduler(Scheduler):
|
|
62
63
|
_run_states: Dict[str, bool] = {}
|
63
64
|
for run_id, run in self._yield_runs():
|
64
65
|
# Filter out runs that are from a different worker thread
|
65
|
-
if run.worker_id == worker_id and run.state
|
66
|
+
if run.worker_id == worker_id and run.state.is_alive:
|
66
67
|
_run_states[run_id] = True
|
67
68
|
|
68
69
|
_logger.debug(f"Sending states: \n{pf(_run_states)}\n")
|