wandb 0.13.10__py3-none-any.whl → 0.14.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +2 -3
- wandb/apis/__init__.py +1 -3
- wandb/apis/importers/__init__.py +4 -0
- wandb/apis/importers/base.py +312 -0
- wandb/apis/importers/mlflow.py +113 -0
- wandb/apis/internal.py +29 -2
- wandb/apis/normalize.py +6 -5
- wandb/apis/public.py +163 -180
- wandb/apis/reports/_templates.py +6 -12
- wandb/apis/reports/report.py +1 -1
- wandb/apis/reports/runset.py +1 -3
- wandb/apis/reports/util.py +12 -10
- wandb/beta/workflows.py +57 -34
- wandb/catboost/__init__.py +1 -2
- wandb/cli/cli.py +215 -133
- wandb/data_types.py +63 -56
- wandb/docker/__init__.py +78 -16
- wandb/docker/auth.py +21 -22
- wandb/env.py +0 -1
- wandb/errors/__init__.py +8 -116
- wandb/errors/term.py +1 -1
- wandb/fastai/__init__.py +1 -2
- wandb/filesync/dir_watcher.py +8 -5
- wandb/filesync/step_prepare.py +76 -75
- wandb/filesync/step_upload.py +1 -2
- wandb/integration/catboost/__init__.py +1 -3
- wandb/integration/catboost/catboost.py +8 -14
- wandb/integration/fastai/__init__.py +7 -13
- wandb/integration/gym/__init__.py +35 -4
- wandb/integration/keras/__init__.py +3 -3
- wandb/integration/keras/callbacks/metrics_logger.py +9 -8
- wandb/integration/keras/callbacks/model_checkpoint.py +9 -9
- wandb/integration/keras/callbacks/tables_builder.py +31 -19
- wandb/integration/kfp/kfp_patch.py +20 -17
- wandb/integration/kfp/wandb_logging.py +1 -2
- wandb/integration/lightgbm/__init__.py +21 -19
- wandb/integration/prodigy/prodigy.py +6 -7
- wandb/integration/sacred/__init__.py +9 -12
- wandb/integration/sagemaker/__init__.py +1 -3
- wandb/integration/sagemaker/auth.py +0 -1
- wandb/integration/sagemaker/config.py +1 -1
- wandb/integration/sagemaker/resources.py +1 -1
- wandb/integration/sb3/sb3.py +8 -4
- wandb/integration/tensorboard/__init__.py +1 -3
- wandb/integration/tensorboard/log.py +8 -8
- wandb/integration/tensorboard/monkeypatch.py +11 -9
- wandb/integration/tensorflow/__init__.py +1 -3
- wandb/integration/xgboost/__init__.py +4 -6
- wandb/integration/yolov8/__init__.py +7 -0
- wandb/integration/yolov8/yolov8.py +250 -0
- wandb/jupyter.py +31 -35
- wandb/lightgbm/__init__.py +1 -2
- wandb/old/settings.py +2 -2
- wandb/plot/bar.py +1 -2
- wandb/plot/confusion_matrix.py +1 -3
- wandb/plot/histogram.py +1 -2
- wandb/plot/line.py +1 -2
- wandb/plot/line_series.py +4 -4
- wandb/plot/pr_curve.py +17 -20
- wandb/plot/roc_curve.py +1 -3
- wandb/plot/scatter.py +1 -2
- wandb/proto/v3/wandb_server_pb2.py +85 -39
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_server_pb2.py +51 -39
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/__init__.py +1 -3
- wandb/sdk/backend/backend.py +1 -1
- wandb/sdk/data_types/_dtypes.py +38 -30
- wandb/sdk/data_types/base_types/json_metadata.py +1 -3
- wandb/sdk/data_types/base_types/media.py +17 -17
- wandb/sdk/data_types/base_types/wb_value.py +33 -26
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +91 -125
- wandb/sdk/data_types/helper_types/classes.py +1 -1
- wandb/sdk/data_types/helper_types/image_mask.py +12 -12
- wandb/sdk/data_types/histogram.py +5 -4
- wandb/sdk/data_types/html.py +1 -2
- wandb/sdk/data_types/image.py +11 -11
- wandb/sdk/data_types/molecule.py +3 -6
- wandb/sdk/data_types/object_3d.py +1 -2
- wandb/sdk/data_types/plotly.py +1 -2
- wandb/sdk/data_types/saved_model.py +10 -8
- wandb/sdk/data_types/video.py +1 -1
- wandb/sdk/integration_utils/data_logging.py +5 -5
- wandb/sdk/interface/artifacts.py +288 -266
- wandb/sdk/interface/interface.py +2 -3
- wandb/sdk/interface/interface_grpc.py +1 -1
- wandb/sdk/interface/interface_queue.py +1 -1
- wandb/sdk/interface/interface_relay.py +1 -1
- wandb/sdk/interface/interface_shared.py +1 -2
- wandb/sdk/interface/interface_sock.py +1 -1
- wandb/sdk/interface/message_future.py +1 -1
- wandb/sdk/interface/message_future_poll.py +1 -1
- wandb/sdk/interface/router.py +1 -1
- wandb/sdk/interface/router_queue.py +1 -1
- wandb/sdk/interface/router_relay.py +1 -1
- wandb/sdk/interface/router_sock.py +1 -1
- wandb/sdk/interface/summary_record.py +1 -1
- wandb/sdk/internal/artifacts.py +1 -1
- wandb/sdk/internal/datastore.py +2 -3
- wandb/sdk/internal/file_pusher.py +5 -3
- wandb/sdk/internal/file_stream.py +22 -19
- wandb/sdk/internal/handler.py +5 -4
- wandb/sdk/internal/internal.py +1 -1
- wandb/sdk/internal/internal_api.py +115 -55
- wandb/sdk/internal/job_builder.py +1 -3
- wandb/sdk/internal/profiler.py +1 -1
- wandb/sdk/internal/progress.py +4 -6
- wandb/sdk/internal/sample.py +1 -3
- wandb/sdk/internal/sender.py +28 -16
- wandb/sdk/internal/settings_static.py +5 -5
- wandb/sdk/internal/system/assets/__init__.py +1 -0
- wandb/sdk/internal/system/assets/cpu.py +3 -9
- wandb/sdk/internal/system/assets/disk.py +2 -4
- wandb/sdk/internal/system/assets/gpu.py +6 -18
- wandb/sdk/internal/system/assets/gpu_apple.py +2 -4
- wandb/sdk/internal/system/assets/interfaces.py +50 -22
- wandb/sdk/internal/system/assets/ipu.py +1 -3
- wandb/sdk/internal/system/assets/memory.py +7 -13
- wandb/sdk/internal/system/assets/network.py +4 -8
- wandb/sdk/internal/system/assets/open_metrics.py +283 -0
- wandb/sdk/internal/system/assets/tpu.py +1 -4
- wandb/sdk/internal/system/assets/trainium.py +26 -14
- wandb/sdk/internal/system/system_info.py +2 -3
- wandb/sdk/internal/system/system_monitor.py +52 -20
- wandb/sdk/internal/tb_watcher.py +12 -13
- wandb/sdk/launch/_project_spec.py +54 -65
- wandb/sdk/launch/agent/agent.py +374 -90
- wandb/sdk/launch/builder/abstract.py +61 -7
- wandb/sdk/launch/builder/build.py +81 -110
- wandb/sdk/launch/builder/docker_builder.py +181 -0
- wandb/sdk/launch/builder/kaniko_builder.py +419 -0
- wandb/sdk/launch/builder/noop.py +31 -12
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +70 -20
- wandb/sdk/launch/environment/abstract.py +28 -0
- wandb/sdk/launch/environment/aws_environment.py +276 -0
- wandb/sdk/launch/environment/gcp_environment.py +271 -0
- wandb/sdk/launch/environment/local_environment.py +65 -0
- wandb/sdk/launch/github_reference.py +3 -8
- wandb/sdk/launch/launch.py +38 -29
- wandb/sdk/launch/launch_add.py +6 -8
- wandb/sdk/launch/loader.py +230 -0
- wandb/sdk/launch/registry/abstract.py +54 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +163 -0
- wandb/sdk/launch/registry/google_artifact_registry.py +203 -0
- wandb/sdk/launch/registry/local_registry.py +62 -0
- wandb/sdk/launch/runner/abstract.py +1 -16
- wandb/sdk/launch/runner/{kubernetes.py → kubernetes_runner.py} +83 -95
- wandb/sdk/launch/runner/local_container.py +46 -22
- wandb/sdk/launch/runner/local_process.py +1 -4
- wandb/sdk/launch/runner/{aws.py → sagemaker_runner.py} +53 -212
- wandb/sdk/launch/runner/{gcp_vertex.py → vertex_runner.py} +38 -55
- wandb/sdk/launch/sweeps/__init__.py +3 -2
- wandb/sdk/launch/sweeps/scheduler.py +132 -39
- wandb/sdk/launch/sweeps/scheduler_sweep.py +80 -89
- wandb/sdk/launch/utils.py +101 -30
- wandb/sdk/launch/wandb_reference.py +2 -7
- wandb/sdk/lib/_settings_toposort_generate.py +166 -0
- wandb/sdk/lib/_settings_toposort_generated.py +201 -0
- wandb/sdk/lib/apikey.py +2 -4
- wandb/sdk/lib/config_util.py +4 -1
- wandb/sdk/lib/console.py +1 -3
- wandb/sdk/lib/deprecate.py +3 -3
- wandb/sdk/lib/file_stream_utils.py +7 -5
- wandb/sdk/lib/filenames.py +1 -1
- wandb/sdk/lib/filesystem.py +61 -5
- wandb/sdk/lib/git.py +1 -3
- wandb/sdk/lib/import_hooks.py +4 -7
- wandb/sdk/lib/ipython.py +8 -5
- wandb/sdk/lib/lazyloader.py +1 -3
- wandb/sdk/lib/mailbox.py +14 -4
- wandb/sdk/lib/proto_util.py +10 -5
- wandb/sdk/lib/redirect.py +15 -22
- wandb/sdk/lib/reporting.py +1 -3
- wandb/sdk/lib/retry.py +4 -5
- wandb/sdk/lib/runid.py +1 -3
- wandb/sdk/lib/server.py +15 -9
- wandb/sdk/lib/sock_client.py +1 -1
- wandb/sdk/lib/sparkline.py +1 -1
- wandb/sdk/lib/wburls.py +1 -1
- wandb/sdk/service/port_file.py +1 -2
- wandb/sdk/service/service.py +36 -13
- wandb/sdk/service/service_base.py +12 -1
- wandb/sdk/verify/verify.py +5 -7
- wandb/sdk/wandb_artifacts.py +142 -177
- wandb/sdk/wandb_config.py +5 -8
- wandb/sdk/wandb_helper.py +1 -1
- wandb/sdk/wandb_init.py +24 -13
- wandb/sdk/wandb_login.py +9 -9
- wandb/sdk/wandb_manager.py +39 -4
- wandb/sdk/wandb_metric.py +2 -6
- wandb/sdk/wandb_require.py +4 -15
- wandb/sdk/wandb_require_helpers.py +1 -9
- wandb/sdk/wandb_run.py +95 -141
- wandb/sdk/wandb_save.py +1 -3
- wandb/sdk/wandb_settings.py +149 -54
- wandb/sdk/wandb_setup.py +66 -46
- wandb/sdk/wandb_summary.py +13 -10
- wandb/sdk/wandb_sweep.py +6 -7
- wandb/sdk/wandb_watch.py +1 -1
- wandb/sklearn/calculate/confusion_matrix.py +1 -1
- wandb/sklearn/calculate/learning_curve.py +1 -1
- wandb/sklearn/calculate/summary_metrics.py +1 -3
- wandb/sklearn/plot/__init__.py +1 -1
- wandb/sklearn/plot/classifier.py +27 -18
- wandb/sklearn/plot/clusterer.py +4 -5
- wandb/sklearn/plot/regressor.py +4 -4
- wandb/sklearn/plot/shared.py +2 -2
- wandb/sync/__init__.py +1 -3
- wandb/sync/sync.py +4 -5
- wandb/testing/relay.py +11 -10
- wandb/trigger.py +1 -1
- wandb/util.py +106 -81
- wandb/viz.py +4 -4
- wandb/wandb_agent.py +50 -50
- wandb/wandb_controller.py +2 -3
- wandb/wandb_run.py +1 -2
- wandb/wandb_torch.py +1 -1
- wandb/xgboost/__init__.py +1 -2
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/METADATA +6 -2
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/RECORD +224 -209
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/WHEEL +1 -1
- wandb/sdk/launch/builder/docker.py +0 -80
- wandb/sdk/launch/builder/kaniko.py +0 -393
- wandb/sdk/launch/builder/loader.py +0 -32
- wandb/sdk/launch/runner/loader.py +0 -50
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/LICENSE +0 -0
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/top_level.txt +0 -0
@@ -1,36 +1,28 @@
|
|
1
1
|
"""Scheduler for classic wandb Sweeps."""
|
2
2
|
import logging
|
3
|
-
import pprint
|
4
3
|
import queue
|
5
4
|
import socket
|
6
5
|
import time
|
7
|
-
from
|
6
|
+
from pprint import pformat as pf
|
8
7
|
from typing import Any, Dict, List
|
9
8
|
|
10
9
|
import wandb
|
11
10
|
from wandb.sdk.launch.sweeps import SchedulerError
|
12
11
|
from wandb.sdk.launch.sweeps.scheduler import (
|
13
12
|
LOG_PREFIX,
|
13
|
+
RunState,
|
14
14
|
Scheduler,
|
15
15
|
SchedulerState,
|
16
|
-
SimpleRunState,
|
17
16
|
SweepRun,
|
17
|
+
_Worker,
|
18
18
|
)
|
19
|
-
from wandb.wandb_agent import
|
19
|
+
from wandb.wandb_agent import _create_sweep_command_args
|
20
20
|
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
@dataclass
|
25
|
-
class _Worker:
|
26
|
-
agent_config: Dict[str, Any]
|
27
|
-
agent_id: str
|
21
|
+
_logger = logging.getLogger(__name__)
|
28
22
|
|
29
23
|
|
30
24
|
class SweepScheduler(Scheduler):
|
31
|
-
"""A
|
32
|
-
launch jobs it creates from run suggestions it pulls from an internal sweeps RunQueue.
|
33
|
-
"""
|
25
|
+
"""A controller/agent that populates a Launch RunQueue from a sweeps RunQueue."""
|
34
26
|
|
35
27
|
def __init__(
|
36
28
|
self,
|
@@ -41,11 +33,6 @@ class SweepScheduler(Scheduler):
|
|
41
33
|
**kwargs: Any,
|
42
34
|
):
|
43
35
|
super().__init__(*args, **kwargs)
|
44
|
-
# Optionally run multiple workers in (pseudo-)parallel. Workers do not
|
45
|
-
# actually run training workloads, they simply send heartbeat messages
|
46
|
-
# (emulating a real agent) and add new runs to the launch queue. The
|
47
|
-
# launch agent is the one that actually runs the training workloads.
|
48
|
-
self._workers: Dict[int, _Worker] = {}
|
49
36
|
self._num_workers: int = num_workers
|
50
37
|
# Thread will pop items off the Sweeps RunQueue using AgentHeartbeat
|
51
38
|
# and put them in this internal queue, which will be used to populate
|
@@ -56,7 +43,7 @@ class SweepScheduler(Scheduler):
|
|
56
43
|
|
57
44
|
def _start(self) -> None:
|
58
45
|
for worker_id in range(self._num_workers):
|
59
|
-
|
46
|
+
_logger.debug(f"{LOG_PREFIX}Starting AgentHeartbeat worker {worker_id}\n")
|
60
47
|
agent_config = self._api.register_agent(
|
61
48
|
f"{socket.gethostname()}-{worker_id}", # host
|
62
49
|
sweep_id=self._sweep_id,
|
@@ -68,92 +55,96 @@ class SweepScheduler(Scheduler):
|
|
68
55
|
agent_id=agent_config["id"],
|
69
56
|
)
|
70
57
|
|
71
|
-
def
|
72
|
-
# Make sure Scheduler is alive
|
73
|
-
if not self.is_alive():
|
74
|
-
return
|
58
|
+
def _get_sweep_commands(self, worker_id: int) -> List[Dict[str, Any]]:
|
75
59
|
# AgentHeartbeat wants a Dict of runs which are running or queued
|
76
60
|
_run_states: Dict[str, bool] = {}
|
77
61
|
for run_id, run in self._yield_runs():
|
78
62
|
# Filter out runs that are from a different worker thread
|
79
|
-
if run.worker_id == worker_id and run.state ==
|
63
|
+
if run.worker_id == worker_id and run.state == RunState.ALIVE:
|
80
64
|
_run_states[run_id] = True
|
81
|
-
|
82
|
-
|
83
|
-
)
|
65
|
+
|
66
|
+
_logger.debug(f"{LOG_PREFIX}Sending states: \n{pf(_run_states)}\n")
|
84
67
|
commands: List[Dict[str, Any]] = self._api.agent_heartbeat(
|
85
68
|
self._workers[worker_id].agent_id, # agent_id: str
|
86
69
|
{}, # metrics: dict
|
87
70
|
_run_states, # run_states: dict
|
88
71
|
)
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
72
|
+
_logger.debug(f"{LOG_PREFIX}AgentHeartbeat commands: \n{pf(commands)}\n")
|
73
|
+
|
74
|
+
return commands
|
75
|
+
|
76
|
+
def _heartbeat(self, worker_id: int) -> bool:
|
77
|
+
# Make sure Scheduler is alive
|
78
|
+
if not self.is_alive():
|
79
|
+
return False
|
80
|
+
elif self.state == SchedulerState.FLUSH_RUNS:
|
81
|
+
# already hit run_cap, just noop
|
82
|
+
return False
|
83
|
+
|
84
|
+
commands: List[Dict[str, Any]] = self._get_sweep_commands(worker_id)
|
85
|
+
for command in commands:
|
86
|
+
# The command "type" can be one of "run", "resume", "stop", "exit"
|
87
|
+
_type = command.get("type")
|
88
|
+
if _type in ["exit", "stop"]:
|
89
|
+
run_cap = command.get("run_cap")
|
90
|
+
if run_cap is not None:
|
91
|
+
# If Sweep hit run_cap, go into flushing state
|
92
|
+
wandb.termlog(f"{LOG_PREFIX}Sweep hit run_cap: {run_cap}")
|
93
|
+
self.state = SchedulerState.FLUSH_RUNS
|
94
|
+
else:
|
97
95
|
# Tell (virtual) agent to stop running
|
98
96
|
self.state = SchedulerState.STOPPED
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
self.state = SchedulerState.FAILED
|
105
|
-
raise SchedulerError(
|
106
|
-
f"AgentHeartbeat command {command} missing run_id"
|
107
|
-
)
|
108
|
-
if _run_id in self._runs:
|
109
|
-
wandb.termlog(f"{LOG_PREFIX} Skipping duplicate run {run_id}")
|
110
|
-
else:
|
111
|
-
run = SweepRun(
|
112
|
-
id=_run_id,
|
113
|
-
args=command.get("args", {}),
|
114
|
-
logs=command.get("logs", []),
|
115
|
-
program=command.get("program", None),
|
116
|
-
worker_id=worker_id,
|
117
|
-
)
|
118
|
-
self._runs[run.id] = run
|
119
|
-
self._heartbeat_queue.put(run)
|
120
|
-
else:
|
97
|
+
return False
|
98
|
+
|
99
|
+
if _type in ["run", "resume"]:
|
100
|
+
_run_id = command.get("run_id")
|
101
|
+
if not _run_id:
|
121
102
|
self.state = SchedulerState.FAILED
|
122
|
-
raise SchedulerError(f"
|
103
|
+
raise SchedulerError(f"No runId in agent heartbeat: {command}")
|
104
|
+
if _run_id in self._runs:
|
105
|
+
wandb.termlog(f"{LOG_PREFIX}Skipping duplicate run: {_run_id}")
|
106
|
+
continue
|
107
|
+
|
108
|
+
run = SweepRun(
|
109
|
+
id=_run_id,
|
110
|
+
args=command.get("args", {}),
|
111
|
+
logs=command.get("logs", []),
|
112
|
+
worker_id=worker_id,
|
113
|
+
)
|
114
|
+
self._runs[run.id] = run
|
115
|
+
self._heartbeat_queue.put(run)
|
116
|
+
else:
|
117
|
+
self.state = SchedulerState.FAILED
|
118
|
+
raise SchedulerError(f"AgentHeartbeat unknown command: {_type}")
|
119
|
+
return True
|
123
120
|
|
124
121
|
def _run(self) -> None:
|
125
122
|
# Go through all workers and heartbeat
|
126
|
-
for worker_id in self._workers
|
123
|
+
for worker_id in self._workers:
|
127
124
|
self._heartbeat(worker_id)
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
"run_config": LegacySweepAgent._create_command_args(
|
152
|
-
{"args": run.args}
|
153
|
-
)["args_dict"]
|
154
|
-
}
|
155
|
-
},
|
156
|
-
)
|
125
|
+
|
126
|
+
for _worker_id in self._workers:
|
127
|
+
try:
|
128
|
+
run: SweepRun = self._heartbeat_queue.get(
|
129
|
+
timeout=self._heartbeat_queue_timeout
|
130
|
+
)
|
131
|
+
|
132
|
+
# If run is already stopped just ignore the request
|
133
|
+
if run.state in [RunState.DEAD, RunState.UNKNOWN]:
|
134
|
+
wandb.termwarn(f"{LOG_PREFIX}Ignoring dead run {run.id}")
|
135
|
+
_logger.debug(f"dead run {run.id} state: {run.state}")
|
136
|
+
continue
|
137
|
+
|
138
|
+
sweep_args = _create_sweep_command_args({"args": run.args})["args_dict"]
|
139
|
+
launch_config = {"overrides": {"run_config": sweep_args}}
|
140
|
+
self._add_to_launch_queue(run_id=run.id, config=launch_config)
|
141
|
+
except queue.Empty:
|
142
|
+
if self.state == SchedulerState.FLUSH_RUNS:
|
143
|
+
wandb.termlog(f"{LOG_PREFIX}Sweep stopped, waiting on runs...")
|
144
|
+
else:
|
145
|
+
wandb.termlog(f"{LOG_PREFIX}No new runs to launch, waiting...")
|
146
|
+
time.sleep(self._heartbeat_queue_sleep)
|
147
|
+
return
|
157
148
|
|
158
149
|
def _exit(self) -> None:
|
159
150
|
pass
|
wandb/sdk/launch/utils.py
CHANGED
@@ -10,15 +10,49 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
|
10
10
|
import click
|
11
11
|
|
12
12
|
import wandb
|
13
|
+
import wandb.docker as docker
|
13
14
|
from wandb import util
|
14
15
|
from wandb.apis.internal import Api
|
15
|
-
from wandb.errors import CommError,
|
16
|
+
from wandb.errors import CommError, Error
|
16
17
|
from wandb.sdk.launch.wandb_reference import WandbReference
|
17
18
|
|
19
|
+
from .builder.templates._wandb_bootstrap import (
|
20
|
+
FAILED_PACKAGES_POSTFIX,
|
21
|
+
FAILED_PACKAGES_PREFIX,
|
22
|
+
)
|
23
|
+
|
24
|
+
FAILED_PACKAGES_REGEX = re.compile(
|
25
|
+
f"{re.escape(FAILED_PACKAGES_PREFIX)}(.*){re.escape(FAILED_PACKAGES_POSTFIX)}"
|
26
|
+
)
|
27
|
+
|
18
28
|
if TYPE_CHECKING: # pragma: no cover
|
19
29
|
from wandb.apis.public import Artifact as PublicArtifact
|
20
30
|
|
21
31
|
|
32
|
+
class LaunchError(Error):
|
33
|
+
"""Raised when a known error occurs in wandb launch."""
|
34
|
+
|
35
|
+
pass
|
36
|
+
|
37
|
+
|
38
|
+
class LaunchDockerError(Error):
|
39
|
+
"""Raised when Docker daemon is not running."""
|
40
|
+
|
41
|
+
pass
|
42
|
+
|
43
|
+
|
44
|
+
class ExecutionError(Error):
|
45
|
+
"""Generic execution exception."""
|
46
|
+
|
47
|
+
pass
|
48
|
+
|
49
|
+
|
50
|
+
class SweepError(Error):
|
51
|
+
"""Raised when a known error occurs with wandb sweeps."""
|
52
|
+
|
53
|
+
pass
|
54
|
+
|
55
|
+
|
22
56
|
# TODO: this should be restricted to just Git repos and not S3 and stuff like that
|
23
57
|
_GIT_URI_REGEX = re.compile(r"^[^/|^~|^\.].*(git|bitbucket)")
|
24
58
|
_VALID_IP_REGEX = r"^https?://[0-9]+(?:\.[0-9]+){3}(:[0-9]+)?"
|
@@ -128,11 +162,10 @@ def construct_launch_spec(
|
|
128
162
|
parameters: Optional[Dict[str, Any]],
|
129
163
|
resource_args: Optional[Dict[str, Any]],
|
130
164
|
launch_config: Optional[Dict[str, Any]],
|
131
|
-
cuda: Optional[bool],
|
132
165
|
run_id: Optional[str],
|
133
166
|
repository: Optional[str],
|
134
167
|
) -> Dict[str, Any]:
|
135
|
-
"""
|
168
|
+
"""Construct the launch specification from CLI arguments."""
|
136
169
|
# override base config (if supplied) with supplied args
|
137
170
|
launch_spec = launch_config if launch_config is not None else {}
|
138
171
|
if uri is not None:
|
@@ -184,8 +217,6 @@ def construct_launch_spec(
|
|
184
217
|
|
185
218
|
if entry_point:
|
186
219
|
launch_spec["overrides"]["entry_point"] = entry_point
|
187
|
-
if cuda is not None:
|
188
|
-
launch_spec["cuda"] = cuda
|
189
220
|
|
190
221
|
if run_id is not None:
|
191
222
|
launch_spec["run_id"] = run_id
|
@@ -214,7 +245,7 @@ def validate_launch_spec_source(launch_spec: Dict[str, Any]) -> None:
|
|
214
245
|
|
215
246
|
|
216
247
|
def parse_wandb_uri(uri: str) -> Tuple[str, str, str]:
|
217
|
-
"""
|
248
|
+
"""Parse wandb uri to retrieve entity, project and run name."""
|
218
249
|
ref = WandbReference.parse(uri)
|
219
250
|
if not ref or not ref.entity or not ref.project or not ref.run_id:
|
220
251
|
raise LaunchError(f"Trouble parsing wandb uri {uri}")
|
@@ -222,10 +253,12 @@ def parse_wandb_uri(uri: str) -> Tuple[str, str, str]:
|
|
222
253
|
|
223
254
|
|
224
255
|
def is_bare_wandb_uri(uri: str) -> bool:
|
225
|
-
"""
|
226
|
-
|
256
|
+
"""Check that a wandb uri is valid.
|
257
|
+
|
258
|
+
URI must be in the format
|
259
|
+
`/<entity>/<project>/runs/<run_name>[other stuff]`
|
227
260
|
or
|
228
|
-
|
261
|
+
`/<entity>/<project>/artifacts/job/<job_name>[other stuff]`.
|
229
262
|
"""
|
230
263
|
_logger.info(f"Checking if uri {uri} is bare...")
|
231
264
|
return uri.startswith("/") and WandbReference.is_uri_job_or_run(uri)
|
@@ -306,7 +339,7 @@ def get_local_python_deps(
|
|
306
339
|
|
307
340
|
|
308
341
|
def diff_pip_requirements(req_1: List[str], req_2: List[str]) -> Dict[str, str]:
|
309
|
-
"""
|
342
|
+
"""Return a list of pip requirements that are not in req_1 but are in req_2."""
|
310
343
|
|
311
344
|
def _parse_req(req: List[str]) -> Dict[str, str]:
|
312
345
|
# TODO: This can be made more exhaustive, but for 99% of cases this is fine
|
@@ -366,7 +399,7 @@ def validate_wandb_python_deps(
|
|
366
399
|
requirements_file: Optional[str],
|
367
400
|
dir: str,
|
368
401
|
) -> None:
|
369
|
-
"""
|
402
|
+
"""Warn if local python dependencies differ from wandb requirements.txt."""
|
370
403
|
if requirements_file is not None:
|
371
404
|
requirements_path = os.path.join(dir, requirements_file)
|
372
405
|
with open(requirements_path) as f:
|
@@ -417,10 +450,7 @@ def apply_patch(patch_string: str, dst_dir: str) -> None:
|
|
417
450
|
|
418
451
|
|
419
452
|
def _make_refspec_from_version(version: Optional[str]) -> List[str]:
|
420
|
-
"""
|
421
|
-
Helper to create a refspec that checks for the existence of origin/main
|
422
|
-
and the version, if provided.
|
423
|
-
"""
|
453
|
+
"""Create a refspec that checks for the existence of origin/main and the version."""
|
424
454
|
if version:
|
425
455
|
return [f"+{version}"]
|
426
456
|
|
@@ -452,10 +482,10 @@ def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> str:
|
|
452
482
|
repo.git.checkout(version)
|
453
483
|
except git.exc.GitCommandError as e:
|
454
484
|
raise LaunchError(
|
455
|
-
"Unable to checkout version '
|
485
|
+
f"Unable to checkout version '{version}' of git repo {uri}"
|
456
486
|
"- please ensure that the version exists in the repo. "
|
457
|
-
"Error:
|
458
|
-
)
|
487
|
+
f"Error: {e}"
|
488
|
+
) from e
|
459
489
|
else:
|
460
490
|
if getattr(repo, "references", None) is not None:
|
461
491
|
branches = [ref.name for ref in repo.references]
|
@@ -475,10 +505,10 @@ def _fetch_git_repo(dst_dir: str, uri: str, version: Optional[str]) -> str:
|
|
475
505
|
)
|
476
506
|
except (AttributeError, IndexError) as e:
|
477
507
|
raise LaunchError(
|
478
|
-
"Unable to checkout default version '
|
508
|
+
f"Unable to checkout default version '{version}' of git repo {uri} "
|
479
509
|
"- to specify a git version use: --git-version \n"
|
480
|
-
"Error:
|
481
|
-
)
|
510
|
+
f"Error: {e}"
|
511
|
+
) from e
|
482
512
|
|
483
513
|
repo.submodule_update(init=True, recursive=True)
|
484
514
|
return version
|
@@ -557,10 +587,9 @@ def validate_build_and_registry_configs(
|
|
557
587
|
|
558
588
|
|
559
589
|
def get_kube_context_and_api_client(
|
560
|
-
kubernetes: Any,
|
561
|
-
resource_args: Dict[str, Any],
|
590
|
+
kubernetes: Any,
|
591
|
+
resource_args: Dict[str, Any],
|
562
592
|
) -> Tuple[Any, Any]:
|
563
|
-
|
564
593
|
config_file = resource_args.get("config_file", None)
|
565
594
|
context = None
|
566
595
|
if config_file is not None or os.path.exists(os.path.expanduser("~/.kube/config")):
|
@@ -579,7 +608,14 @@ def get_kube_context_and_api_client(
|
|
579
608
|
raise LaunchError(f"Specified context {context_name} was not found.")
|
580
609
|
else:
|
581
610
|
context = active_context
|
582
|
-
|
611
|
+
# TODO: We should not really be performing this check if the user is not
|
612
|
+
# using EKS but I don't see an obvious way to make an eks specific code path
|
613
|
+
# right here.
|
614
|
+
util.get_module(
|
615
|
+
"awscli",
|
616
|
+
"awscli is required to load a kubernetes context "
|
617
|
+
"from eks. Please run `pip install wandb[launch]` to install it.",
|
618
|
+
)
|
583
619
|
kubernetes.config.load_kube_config(config_file, context["name"])
|
584
620
|
api_client = kubernetes.config.new_client_from_config(
|
585
621
|
config_file, context=context["name"]
|
@@ -598,7 +634,7 @@ def resolve_build_and_registry_config(
|
|
598
634
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
599
635
|
resolved_build_config: Dict[str, Any] = {}
|
600
636
|
if build_config is None and default_launch_config is not None:
|
601
|
-
resolved_build_config = default_launch_config.get("
|
637
|
+
resolved_build_config = default_launch_config.get("builder", {})
|
602
638
|
elif build_config is not None:
|
603
639
|
resolved_build_config = build_config
|
604
640
|
resolved_registry_config: Dict[str, Any] = {}
|
@@ -611,10 +647,10 @@ def resolve_build_and_registry_config(
|
|
611
647
|
|
612
648
|
|
613
649
|
def check_logged_in(api: Api) -> bool:
|
614
|
-
"""
|
615
|
-
|
616
|
-
|
617
|
-
|
650
|
+
"""Check if a user is logged in.
|
651
|
+
|
652
|
+
Raises an error if the viewer doesn't load (likely a broken API key). Expected time
|
653
|
+
cost is 0.1-0.2 seconds.
|
618
654
|
"""
|
619
655
|
res = api.api.viewer()
|
620
656
|
if not res:
|
@@ -633,3 +669,38 @@ def make_name_dns_safe(name: str) -> str:
|
|
633
669
|
# Actual length limit is 253, but we want to leave room for the generated suffix
|
634
670
|
resp = resp[:200]
|
635
671
|
return resp
|
672
|
+
|
673
|
+
|
674
|
+
def warn_failed_packages_from_build_logs(log: str, image_uri: str) -> None:
|
675
|
+
match = FAILED_PACKAGES_REGEX.search(log)
|
676
|
+
if match:
|
677
|
+
wandb.termwarn(
|
678
|
+
f"Failed to install the following packages: {match.group(1)} for image: {image_uri}. Will attempt to launch image without them."
|
679
|
+
)
|
680
|
+
|
681
|
+
|
682
|
+
def docker_image_exists(docker_image: str, should_raise: bool = False) -> bool:
|
683
|
+
"""Check if a specific image is already available.
|
684
|
+
|
685
|
+
Optionally raises an exception if the image is not found.
|
686
|
+
"""
|
687
|
+
_logger.info("Checking if base image exists...")
|
688
|
+
try:
|
689
|
+
docker.run(["docker", "image", "inspect", docker_image])
|
690
|
+
return True
|
691
|
+
except (docker.DockerError, ValueError) as e:
|
692
|
+
if should_raise:
|
693
|
+
raise e
|
694
|
+
_logger.info("Base image not found. Generating new base image")
|
695
|
+
return False
|
696
|
+
|
697
|
+
|
698
|
+
def pull_docker_image(docker_image: str) -> None:
|
699
|
+
"""Pull the requested docker image."""
|
700
|
+
if docker_image_exists(docker_image):
|
701
|
+
# don't pull images if they exist already, eg if they are local images
|
702
|
+
return
|
703
|
+
try:
|
704
|
+
docker.run(["docker", "pull", docker_image])
|
705
|
+
except docker.DockerError as e:
|
706
|
+
raise LaunchError(f"Docker server returned error: {e}")
|
@@ -1,6 +1,4 @@
|
|
1
|
-
"""
|
2
|
-
Support for parsing W&B URLs (which might be user provided) into constituent parts.
|
3
|
-
"""
|
1
|
+
"""Support for parsing W&B URLs (which might be user provided) into constituent parts."""
|
4
2
|
|
5
3
|
from dataclasses import dataclass
|
6
4
|
from enum import IntEnum
|
@@ -35,7 +33,6 @@ RESERVED_JOB_PATHS = ("_view",)
|
|
35
33
|
|
36
34
|
@dataclass
|
37
35
|
class WandbReference:
|
38
|
-
|
39
36
|
# TODO: This will include port, should we separate that out?
|
40
37
|
host: Optional[str] = None
|
41
38
|
|
@@ -88,9 +85,7 @@ class WandbReference:
|
|
88
85
|
|
89
86
|
@staticmethod
|
90
87
|
def parse(uri: str) -> Optional["WandbReference"]:
|
91
|
-
"""
|
92
|
-
Attempt to parse a string as a W&B URL.
|
93
|
-
"""
|
88
|
+
"""Attempt to parse a string as a W&B URL."""
|
94
89
|
# TODO: Error if HTTP and host is not localhost?
|
95
90
|
if (
|
96
91
|
not uri.startswith("/")
|
@@ -0,0 +1,166 @@
|
|
1
|
+
import inspect
|
2
|
+
import sys
|
3
|
+
from typing import Any, Dict, List, Optional, Set, Tuple
|
4
|
+
|
5
|
+
from wandb.errors import UsageError
|
6
|
+
from wandb.sdk.wandb_settings import Settings
|
7
|
+
|
8
|
+
if sys.version_info >= (3, 8):
|
9
|
+
from typing import get_args, get_origin, get_type_hints
|
10
|
+
elif sys.version_info >= (3, 7):
|
11
|
+
from typing_extensions import get_args, get_origin, get_type_hints
|
12
|
+
else:
|
13
|
+
|
14
|
+
def get_args(obj: Any) -> Optional[Any]:
|
15
|
+
return obj.__args__ if hasattr(obj, "__args__") else None
|
16
|
+
|
17
|
+
def get_origin(obj: Any) -> Optional[Any]:
|
18
|
+
return obj.__origin__ if hasattr(obj, "__origin__") else None
|
19
|
+
|
20
|
+
def get_type_hints(obj: Any) -> Dict[str, Any]:
|
21
|
+
return dict(obj.__annotations__) if hasattr(obj, "__annotations__") else dict()
|
22
|
+
|
23
|
+
|
24
|
+
template = """
|
25
|
+
__all__ = ("SETTINGS_TOPOLOGICALLY_SORTED", "_Setting")
|
26
|
+
|
27
|
+
import sys
|
28
|
+
from typing import Tuple
|
29
|
+
|
30
|
+
if sys.version_info >= (3, 8):
|
31
|
+
from typing import Final, Literal
|
32
|
+
else:
|
33
|
+
from typing_extensions import Final, Literal
|
34
|
+
|
35
|
+
|
36
|
+
_Setting = Literal[
|
37
|
+
$settings_literal_list
|
38
|
+
]
|
39
|
+
|
40
|
+
SETTINGS_TOPOLOGICALLY_SORTED: Final[Tuple[_Setting, ...]] = (
|
41
|
+
$settings_topologically_sorted
|
42
|
+
)
|
43
|
+
"""
|
44
|
+
|
45
|
+
|
46
|
+
class Graph:
|
47
|
+
# A simple class representing an unweighted directed graph
|
48
|
+
# that uses an adjacency list representation.
|
49
|
+
# We use to ensure that we don't have cyclic dependencies in the settings
|
50
|
+
# and that modifications to the settings are applied in the correct order.
|
51
|
+
def __init__(self) -> None:
|
52
|
+
self.adj_list: Dict[str, Set[str]] = {}
|
53
|
+
|
54
|
+
def add_node(self, node: str) -> None:
|
55
|
+
if node not in self.adj_list:
|
56
|
+
self.adj_list[node] = set()
|
57
|
+
|
58
|
+
def add_edge(self, node1: str, node2: str) -> None:
|
59
|
+
self.adj_list[node1].add(node2)
|
60
|
+
|
61
|
+
def get_neighbors(self, node: str) -> Set[str]:
|
62
|
+
return self.adj_list[node]
|
63
|
+
|
64
|
+
# return a list of nodes sorted in topological order
|
65
|
+
def topological_sort_dfs(self) -> List[str]:
|
66
|
+
sorted_copy = {k: sorted(v) for k, v in self.adj_list.items()}
|
67
|
+
|
68
|
+
sorted_nodes: List[str] = []
|
69
|
+
visited_nodes: Set[str] = set()
|
70
|
+
current_nodes: Set[str] = set()
|
71
|
+
|
72
|
+
def visit(n: str) -> None:
|
73
|
+
if n in visited_nodes:
|
74
|
+
return None
|
75
|
+
if n in current_nodes:
|
76
|
+
raise UsageError("Cyclic dependency detected in wandb.Settings")
|
77
|
+
|
78
|
+
current_nodes.add(n)
|
79
|
+
for neighbor in sorted_copy[n]:
|
80
|
+
visit(neighbor)
|
81
|
+
|
82
|
+
current_nodes.remove(n)
|
83
|
+
visited_nodes.add(n)
|
84
|
+
sorted_nodes.append(n)
|
85
|
+
|
86
|
+
return None
|
87
|
+
|
88
|
+
for node in self.adj_list:
|
89
|
+
if node not in visited_nodes:
|
90
|
+
visit(node)
|
91
|
+
|
92
|
+
return sorted_nodes
|
93
|
+
|
94
|
+
|
95
|
+
def _get_modification_order(
|
96
|
+
settings: Settings,
|
97
|
+
) -> Tuple[Tuple[str, ...], Tuple[str, ...]]:
|
98
|
+
"""Return the order in which settings should be modified, based on dependencies."""
|
99
|
+
dependency_graph = Graph()
|
100
|
+
|
101
|
+
props = tuple(get_type_hints(Settings).keys())
|
102
|
+
|
103
|
+
# discover prop dependencies from validator methods and runtime hooks
|
104
|
+
|
105
|
+
prefix = "_validate_"
|
106
|
+
symbols = set(dir(settings))
|
107
|
+
validator_methods = tuple(sorted(m for m in symbols if m.startswith(prefix)))
|
108
|
+
|
109
|
+
# extract dependencies from validator methods
|
110
|
+
for m in validator_methods:
|
111
|
+
setting = m.split(prefix)[1]
|
112
|
+
dependency_graph.add_node(setting)
|
113
|
+
# if the method is not static, inspect its code to find the attributes it depends on
|
114
|
+
if (
|
115
|
+
not isinstance(Settings.__dict__[m], staticmethod)
|
116
|
+
and not isinstance(Settings.__dict__[m], classmethod)
|
117
|
+
and Settings.__dict__[m].__code__.co_argcount > 0
|
118
|
+
):
|
119
|
+
unbound_closure_vars = inspect.getclosurevars(Settings.__dict__[m]).unbound
|
120
|
+
dependencies = (v for v in unbound_closure_vars if v in props)
|
121
|
+
for d in dependencies:
|
122
|
+
dependency_graph.add_node(d)
|
123
|
+
dependency_graph.add_edge(setting, d)
|
124
|
+
|
125
|
+
# extract dependencies from props' runtime hooks
|
126
|
+
default_props = settings._default_props()
|
127
|
+
for prop, spec in default_props.items():
|
128
|
+
if "hook" not in spec:
|
129
|
+
continue
|
130
|
+
|
131
|
+
dependency_graph.add_node(prop)
|
132
|
+
|
133
|
+
hook = spec["hook"]
|
134
|
+
if callable(hook):
|
135
|
+
hook = [hook]
|
136
|
+
|
137
|
+
for h in hook:
|
138
|
+
unbound_closure_vars = inspect.getclosurevars(h).unbound
|
139
|
+
dependencies = (v for v in unbound_closure_vars if v in props)
|
140
|
+
for d in dependencies:
|
141
|
+
dependency_graph.add_node(d)
|
142
|
+
dependency_graph.add_edge(prop, d)
|
143
|
+
|
144
|
+
modification_order = dependency_graph.topological_sort_dfs()
|
145
|
+
return props, tuple(modification_order)
|
146
|
+
|
147
|
+
|
148
|
+
def generate(settings: Settings) -> None:
|
149
|
+
_settings_literal_list, _settings_topologically_sorted = _get_modification_order(
|
150
|
+
settings
|
151
|
+
)
|
152
|
+
settings_literal_list = ", ".join(f'"{s}"' for s in _settings_literal_list)
|
153
|
+
settings_topologically_sorted = ", ".join(
|
154
|
+
f'"{s}"' for s in _settings_topologically_sorted
|
155
|
+
)
|
156
|
+
|
157
|
+
print(
|
158
|
+
template.replace("$settings_literal_list", settings_literal_list,).replace(
|
159
|
+
"$settings_topologically_sorted",
|
160
|
+
settings_topologically_sorted,
|
161
|
+
)
|
162
|
+
)
|
163
|
+
|
164
|
+
|
165
|
+
if __name__ == "__main__":
|
166
|
+
generate(Settings())
|