wandb 0.13.10__py3-none-any.whl → 0.14.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +2 -3
- wandb/apis/__init__.py +1 -3
- wandb/apis/importers/__init__.py +4 -0
- wandb/apis/importers/base.py +312 -0
- wandb/apis/importers/mlflow.py +113 -0
- wandb/apis/internal.py +29 -2
- wandb/apis/normalize.py +6 -5
- wandb/apis/public.py +163 -180
- wandb/apis/reports/_templates.py +6 -12
- wandb/apis/reports/report.py +1 -1
- wandb/apis/reports/runset.py +1 -3
- wandb/apis/reports/util.py +12 -10
- wandb/beta/workflows.py +57 -34
- wandb/catboost/__init__.py +1 -2
- wandb/cli/cli.py +215 -133
- wandb/data_types.py +63 -56
- wandb/docker/__init__.py +78 -16
- wandb/docker/auth.py +21 -22
- wandb/env.py +0 -1
- wandb/errors/__init__.py +8 -116
- wandb/errors/term.py +1 -1
- wandb/fastai/__init__.py +1 -2
- wandb/filesync/dir_watcher.py +8 -5
- wandb/filesync/step_prepare.py +76 -75
- wandb/filesync/step_upload.py +1 -2
- wandb/integration/catboost/__init__.py +1 -3
- wandb/integration/catboost/catboost.py +8 -14
- wandb/integration/fastai/__init__.py +7 -13
- wandb/integration/gym/__init__.py +35 -4
- wandb/integration/keras/__init__.py +3 -3
- wandb/integration/keras/callbacks/metrics_logger.py +9 -8
- wandb/integration/keras/callbacks/model_checkpoint.py +9 -9
- wandb/integration/keras/callbacks/tables_builder.py +31 -19
- wandb/integration/kfp/kfp_patch.py +20 -17
- wandb/integration/kfp/wandb_logging.py +1 -2
- wandb/integration/lightgbm/__init__.py +21 -19
- wandb/integration/prodigy/prodigy.py +6 -7
- wandb/integration/sacred/__init__.py +9 -12
- wandb/integration/sagemaker/__init__.py +1 -3
- wandb/integration/sagemaker/auth.py +0 -1
- wandb/integration/sagemaker/config.py +1 -1
- wandb/integration/sagemaker/resources.py +1 -1
- wandb/integration/sb3/sb3.py +8 -4
- wandb/integration/tensorboard/__init__.py +1 -3
- wandb/integration/tensorboard/log.py +8 -8
- wandb/integration/tensorboard/monkeypatch.py +11 -9
- wandb/integration/tensorflow/__init__.py +1 -3
- wandb/integration/xgboost/__init__.py +4 -6
- wandb/integration/yolov8/__init__.py +7 -0
- wandb/integration/yolov8/yolov8.py +250 -0
- wandb/jupyter.py +31 -35
- wandb/lightgbm/__init__.py +1 -2
- wandb/old/settings.py +2 -2
- wandb/plot/bar.py +1 -2
- wandb/plot/confusion_matrix.py +1 -3
- wandb/plot/histogram.py +1 -2
- wandb/plot/line.py +1 -2
- wandb/plot/line_series.py +4 -4
- wandb/plot/pr_curve.py +17 -20
- wandb/plot/roc_curve.py +1 -3
- wandb/plot/scatter.py +1 -2
- wandb/proto/v3/wandb_server_pb2.py +85 -39
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_server_pb2.py +51 -39
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/__init__.py +1 -3
- wandb/sdk/backend/backend.py +1 -1
- wandb/sdk/data_types/_dtypes.py +38 -30
- wandb/sdk/data_types/base_types/json_metadata.py +1 -3
- wandb/sdk/data_types/base_types/media.py +17 -17
- wandb/sdk/data_types/base_types/wb_value.py +33 -26
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +91 -125
- wandb/sdk/data_types/helper_types/classes.py +1 -1
- wandb/sdk/data_types/helper_types/image_mask.py +12 -12
- wandb/sdk/data_types/histogram.py +5 -4
- wandb/sdk/data_types/html.py +1 -2
- wandb/sdk/data_types/image.py +11 -11
- wandb/sdk/data_types/molecule.py +3 -6
- wandb/sdk/data_types/object_3d.py +1 -2
- wandb/sdk/data_types/plotly.py +1 -2
- wandb/sdk/data_types/saved_model.py +10 -8
- wandb/sdk/data_types/video.py +1 -1
- wandb/sdk/integration_utils/data_logging.py +5 -5
- wandb/sdk/interface/artifacts.py +288 -266
- wandb/sdk/interface/interface.py +2 -3
- wandb/sdk/interface/interface_grpc.py +1 -1
- wandb/sdk/interface/interface_queue.py +1 -1
- wandb/sdk/interface/interface_relay.py +1 -1
- wandb/sdk/interface/interface_shared.py +1 -2
- wandb/sdk/interface/interface_sock.py +1 -1
- wandb/sdk/interface/message_future.py +1 -1
- wandb/sdk/interface/message_future_poll.py +1 -1
- wandb/sdk/interface/router.py +1 -1
- wandb/sdk/interface/router_queue.py +1 -1
- wandb/sdk/interface/router_relay.py +1 -1
- wandb/sdk/interface/router_sock.py +1 -1
- wandb/sdk/interface/summary_record.py +1 -1
- wandb/sdk/internal/artifacts.py +1 -1
- wandb/sdk/internal/datastore.py +2 -3
- wandb/sdk/internal/file_pusher.py +5 -3
- wandb/sdk/internal/file_stream.py +22 -19
- wandb/sdk/internal/handler.py +5 -4
- wandb/sdk/internal/internal.py +1 -1
- wandb/sdk/internal/internal_api.py +115 -55
- wandb/sdk/internal/job_builder.py +1 -3
- wandb/sdk/internal/profiler.py +1 -1
- wandb/sdk/internal/progress.py +4 -6
- wandb/sdk/internal/sample.py +1 -3
- wandb/sdk/internal/sender.py +28 -16
- wandb/sdk/internal/settings_static.py +5 -5
- wandb/sdk/internal/system/assets/__init__.py +1 -0
- wandb/sdk/internal/system/assets/cpu.py +3 -9
- wandb/sdk/internal/system/assets/disk.py +2 -4
- wandb/sdk/internal/system/assets/gpu.py +6 -18
- wandb/sdk/internal/system/assets/gpu_apple.py +2 -4
- wandb/sdk/internal/system/assets/interfaces.py +50 -22
- wandb/sdk/internal/system/assets/ipu.py +1 -3
- wandb/sdk/internal/system/assets/memory.py +7 -13
- wandb/sdk/internal/system/assets/network.py +4 -8
- wandb/sdk/internal/system/assets/open_metrics.py +283 -0
- wandb/sdk/internal/system/assets/tpu.py +1 -4
- wandb/sdk/internal/system/assets/trainium.py +26 -14
- wandb/sdk/internal/system/system_info.py +2 -3
- wandb/sdk/internal/system/system_monitor.py +52 -20
- wandb/sdk/internal/tb_watcher.py +12 -13
- wandb/sdk/launch/_project_spec.py +54 -65
- wandb/sdk/launch/agent/agent.py +374 -90
- wandb/sdk/launch/builder/abstract.py +61 -7
- wandb/sdk/launch/builder/build.py +81 -110
- wandb/sdk/launch/builder/docker_builder.py +181 -0
- wandb/sdk/launch/builder/kaniko_builder.py +419 -0
- wandb/sdk/launch/builder/noop.py +31 -12
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +70 -20
- wandb/sdk/launch/environment/abstract.py +28 -0
- wandb/sdk/launch/environment/aws_environment.py +276 -0
- wandb/sdk/launch/environment/gcp_environment.py +271 -0
- wandb/sdk/launch/environment/local_environment.py +65 -0
- wandb/sdk/launch/github_reference.py +3 -8
- wandb/sdk/launch/launch.py +38 -29
- wandb/sdk/launch/launch_add.py +6 -8
- wandb/sdk/launch/loader.py +230 -0
- wandb/sdk/launch/registry/abstract.py +54 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +163 -0
- wandb/sdk/launch/registry/google_artifact_registry.py +203 -0
- wandb/sdk/launch/registry/local_registry.py +62 -0
- wandb/sdk/launch/runner/abstract.py +1 -16
- wandb/sdk/launch/runner/{kubernetes.py → kubernetes_runner.py} +83 -95
- wandb/sdk/launch/runner/local_container.py +46 -22
- wandb/sdk/launch/runner/local_process.py +1 -4
- wandb/sdk/launch/runner/{aws.py → sagemaker_runner.py} +53 -212
- wandb/sdk/launch/runner/{gcp_vertex.py → vertex_runner.py} +38 -55
- wandb/sdk/launch/sweeps/__init__.py +3 -2
- wandb/sdk/launch/sweeps/scheduler.py +132 -39
- wandb/sdk/launch/sweeps/scheduler_sweep.py +80 -89
- wandb/sdk/launch/utils.py +101 -30
- wandb/sdk/launch/wandb_reference.py +2 -7
- wandb/sdk/lib/_settings_toposort_generate.py +166 -0
- wandb/sdk/lib/_settings_toposort_generated.py +201 -0
- wandb/sdk/lib/apikey.py +2 -4
- wandb/sdk/lib/config_util.py +4 -1
- wandb/sdk/lib/console.py +1 -3
- wandb/sdk/lib/deprecate.py +3 -3
- wandb/sdk/lib/file_stream_utils.py +7 -5
- wandb/sdk/lib/filenames.py +1 -1
- wandb/sdk/lib/filesystem.py +61 -5
- wandb/sdk/lib/git.py +1 -3
- wandb/sdk/lib/import_hooks.py +4 -7
- wandb/sdk/lib/ipython.py +8 -5
- wandb/sdk/lib/lazyloader.py +1 -3
- wandb/sdk/lib/mailbox.py +14 -4
- wandb/sdk/lib/proto_util.py +10 -5
- wandb/sdk/lib/redirect.py +15 -22
- wandb/sdk/lib/reporting.py +1 -3
- wandb/sdk/lib/retry.py +4 -5
- wandb/sdk/lib/runid.py +1 -3
- wandb/sdk/lib/server.py +15 -9
- wandb/sdk/lib/sock_client.py +1 -1
- wandb/sdk/lib/sparkline.py +1 -1
- wandb/sdk/lib/wburls.py +1 -1
- wandb/sdk/service/port_file.py +1 -2
- wandb/sdk/service/service.py +36 -13
- wandb/sdk/service/service_base.py +12 -1
- wandb/sdk/verify/verify.py +5 -7
- wandb/sdk/wandb_artifacts.py +142 -177
- wandb/sdk/wandb_config.py +5 -8
- wandb/sdk/wandb_helper.py +1 -1
- wandb/sdk/wandb_init.py +24 -13
- wandb/sdk/wandb_login.py +9 -9
- wandb/sdk/wandb_manager.py +39 -4
- wandb/sdk/wandb_metric.py +2 -6
- wandb/sdk/wandb_require.py +4 -15
- wandb/sdk/wandb_require_helpers.py +1 -9
- wandb/sdk/wandb_run.py +95 -141
- wandb/sdk/wandb_save.py +1 -3
- wandb/sdk/wandb_settings.py +149 -54
- wandb/sdk/wandb_setup.py +66 -46
- wandb/sdk/wandb_summary.py +13 -10
- wandb/sdk/wandb_sweep.py +6 -7
- wandb/sdk/wandb_watch.py +1 -1
- wandb/sklearn/calculate/confusion_matrix.py +1 -1
- wandb/sklearn/calculate/learning_curve.py +1 -1
- wandb/sklearn/calculate/summary_metrics.py +1 -3
- wandb/sklearn/plot/__init__.py +1 -1
- wandb/sklearn/plot/classifier.py +27 -18
- wandb/sklearn/plot/clusterer.py +4 -5
- wandb/sklearn/plot/regressor.py +4 -4
- wandb/sklearn/plot/shared.py +2 -2
- wandb/sync/__init__.py +1 -3
- wandb/sync/sync.py +4 -5
- wandb/testing/relay.py +11 -10
- wandb/trigger.py +1 -1
- wandb/util.py +106 -81
- wandb/viz.py +4 -4
- wandb/wandb_agent.py +50 -50
- wandb/wandb_controller.py +2 -3
- wandb/wandb_run.py +1 -2
- wandb/wandb_torch.py +1 -1
- wandb/xgboost/__init__.py +1 -2
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/METADATA +6 -2
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/RECORD +224 -209
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/WHEEL +1 -1
- wandb/sdk/launch/builder/docker.py +0 -80
- wandb/sdk/launch/builder/kaniko.py +0 -393
- wandb/sdk/launch/builder/loader.py +0 -32
- wandb/sdk/launch/runner/loader.py +0 -50
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/LICENSE +0 -0
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,276 @@
|
|
1
|
+
"""Implements the AWS environment."""
|
2
|
+
|
3
|
+
import logging
|
4
|
+
import os
|
5
|
+
import re
|
6
|
+
from typing import Dict
|
7
|
+
|
8
|
+
from wandb.sdk.launch.utils import LaunchError
|
9
|
+
from wandb.util import get_module
|
10
|
+
|
11
|
+
from .abstract import AbstractEnvironment
|
12
|
+
|
13
|
+
boto3 = get_module(
|
14
|
+
"boto3",
|
15
|
+
required="AWS environment requires boto3 to be installed. Please install "
|
16
|
+
"it with `pip install wandb[launch]`.",
|
17
|
+
)
|
18
|
+
botocore = get_module(
|
19
|
+
"botocore",
|
20
|
+
required="AWS environment requires botocore to be installed. Please install "
|
21
|
+
"it with `pip install wandb[launch]`.",
|
22
|
+
)
|
23
|
+
|
24
|
+
_logger = logging.getLogger(__name__)
|
25
|
+
|
26
|
+
S3_URI_RE = re.compile(r"s3://([^/]+)/(.+)")
|
27
|
+
|
28
|
+
|
29
|
+
class AwsEnvironment(AbstractEnvironment):
    """AWS environment.

    Stores a region together with a set of AWS credentials and uses them to
    build boto3 sessions for credential verification and S3 uploads.
    """

    def __init__(
        self,
        region: str,
        access_key: str,
        secret_key: str,
        session_token: str,
        verify: bool = True,
    ) -> None:
        """Initialize the AWS environment.

        Arguments:
            region (str): The AWS region.
            access_key (str): The AWS access key id.
            secret_key (str): The AWS secret access key.
            session_token (str): The AWS session token.
            verify (bool, optional): Whether to verify the AWS environment.
                Defaults to True.

        Raises:
            LaunchError: If the AWS environment is not configured correctly.
        """
        super().__init__()
        _logger.info(f"Initializing AWS environment in region {region}.")
        self._region = region
        self._access_key = access_key
        self._secret_key = secret_key
        self._session_token = session_token
        if verify:
            self.verify()

    @classmethod
    def from_default(cls, region: str, verify: bool = True) -> "AwsEnvironment":
        """Create an AWS environment from the default AWS environment.

        Credentials are taken from a default ``boto3.Session()``. If ``region``
        is empty, the session's region is used instead.

        Arguments:
            region (str): The AWS region.
            verify (bool, optional): Whether to verify the AWS environment. Defaults to True.

        Returns:
            AwsEnvironment: The AWS environment.

        Raises:
            LaunchError: If no default credentials are available or the
                session could not be created.
        """
        _logger.info("Creating AWS environment from default credentials.")
        try:
            session = boto3.Session()
            region = region or session.region_name
            credentials = session.get_credentials()
            if not credentials:
                raise LaunchError(
                    "Could not create AWS environment from default environment. Please verify that your AWS credentials are configured correctly."
                )
            access_key = credentials.access_key
            secret_key = credentials.secret_key
            session_token = credentials.token
        except botocore.exceptions.ClientError as e:
            # Same exception class the other methods catch; chain the cause so
            # the original botocore traceback is preserved.
            raise LaunchError(
                f"Could not create AWS environment from default environment. Please verify that your AWS credentials are configured correctly. {e}"
            ) from e
        return cls(
            region=region,
            access_key=access_key,
            secret_key=secret_key,
            session_token=session_token,
            verify=verify,
        )

    @classmethod
    def from_config(
        cls, config: Dict[str, str], verify: bool = True
    ) -> "AwsEnvironment":
        """Create an AWS environment from a config dictionary.

        Only the ``region`` key is read from the config; credentials are
        resolved through ``from_default``.

        Arguments:
            config (dict): Configuration dictionary.
            verify (bool, optional): Whether to verify the AWS environment. Defaults to True.

        Returns:
            AwsEnvironment: The AWS environment.

        Raises:
            LaunchError: If the config does not specify a region.
        """
        region = str(config.get("region", ""))
        if not region:
            raise LaunchError(
                "Could not create AWS environment from config. Region not specified."
            )
        return cls.from_default(
            region=region,
            verify=verify,
        )

    @property
    def region(self) -> str:
        """The AWS region."""
        return self._region

    @region.setter
    def region(self, region: str) -> None:
        self._region = region

    def verify(self) -> None:
        """Verify that the AWS environment is configured correctly.

        Issues an STS ``get_caller_identity`` call with the stored credentials.

        Raises:
            LaunchError: If the AWS environment is not configured correctly.
        """
        _logger.debug("Verifying AWS environment.")
        try:
            session = self.get_session()
            client = session.client("sts")
            client.get_caller_identity()
            # TODO: log identity details from the response
        except botocore.exceptions.ClientError as e:
            raise LaunchError(
                f"Could not verify AWS environment. Please verify that your AWS credentials are configured correctly. {e}"
            ) from e

    def get_session(self) -> "boto3.Session":  # type: ignore
        """Get an AWS session.

        Returns:
            boto3.Session: The AWS session.

        Raises:
            LaunchError: If the AWS session could not be created.
        """
        _logger.debug(f"Creating AWS session in region {self._region}")
        try:
            return boto3.Session(
                aws_access_key_id=self._access_key,
                aws_secret_access_key=self._secret_key,
                aws_session_token=self._session_token,
                region_name=self._region,
            )
        except botocore.exceptions.ClientError as e:
            raise LaunchError(f"Could not create AWS session. {e}") from e

    def upload_file(self, source: str, destination: str) -> None:
        """Upload a file to s3 from local storage.

        The destination is a full s3 URI, e.g. s3://bucket/key, and its key
        part is used verbatim as the object key. So if the source is "foo/bar"
        and the destination is "s3://bucket/key", the file "foo/bar" will be
        uploaded to "s3://bucket/key".

        Arguments:
            source (str): The path to the local file.
            destination (str): The uri of the storage destination. This should
                be a valid s3 URI, e.g. s3://bucket/key.

        Raises:
            LaunchError: If the source path is not a file, the destination is
                not a valid s3 URI, or the upload fails.
        """
        _logger.debug(f"Uploading {source} to {destination}")
        if not os.path.isfile(source):
            raise LaunchError(f"Source {source} does not exist.")
        match = S3_URI_RE.match(destination)
        if not match:
            raise LaunchError(f"Destination {destination} is not a valid s3 URI.")
        bucket = match.group(1)
        key = match.group(2).lstrip("/")
        if not key:
            # e.g. "s3://bucket//": nothing is left to use as an object key.
            raise LaunchError(f"Destination {destination} is not a valid s3 URI.")
        session = self.get_session()
        try:
            client = session.client("s3")
            client.upload_file(source, bucket, key)
        except (
            botocore.exceptions.ClientError,
            # boto3's managed transfer re-raises ClientError as this type.
            boto3.exceptions.S3UploadFailedError,
        ) as e:
            raise LaunchError(
                f"botocore error attempting to copy {source} to {destination}. {e}"
            ) from e

    def upload_dir(self, source: str, destination: str) -> None:
        """Upload a directory to s3 from local storage.

        The upload will place the contents of the source directory under the
        destination key with the same directory structure. So if the source is
        "foo/bar" and the destination is "s3://bucket/key", the file
        "foo/bar/baz/qux.txt" will be uploaded to "s3://bucket/key/baz/qux.txt".

        Arguments:
            source (str): The path to the local directory.
            destination (str): The URI of the storage.

        Raises:
            LaunchError: If the copy fails, the source path is not a
                directory, or the destination is not a valid s3 URI.
        """
        _logger.debug(f"Uploading {source} to {destination}")
        if not os.path.isdir(source):
            raise LaunchError(f"Source {source} does not exist.")
        match = S3_URI_RE.match(destination)
        if not match:
            raise LaunchError(f"Destination {destination} is not a valid s3 URI.")
        bucket = match.group(1)
        # Strip slashes on both ends so joining never yields "//" in a key.
        prefix = match.group(2).strip("/")
        session = self.get_session()
        try:
            client = session.client("s3")
            for path, _, files in os.walk(source):
                for file in files:
                    abs_path = os.path.join(path, file)
                    # relpath (not str.replace) so a source string that recurs
                    # deeper in the path cannot be stripped by accident.
                    rel_path = os.path.relpath(abs_path, source).replace("\\", "/")
                    key_path = f"{prefix}/{rel_path}" if prefix else rel_path
                    client.upload_file(
                        abs_path,
                        bucket,
                        key_path,
                    )
        except botocore.exceptions.ClientError as e:
            raise LaunchError(
                f"botocore error attempting to copy {source} to {destination}. {e}"
            ) from e
        except Exception as e:
            raise LaunchError(
                f"Unexpected error attempting to copy {source} to {destination}. {e}"
            ) from e

    def verify_storage_uri(self, uri: str) -> None:
        """Verify that s3 storage is configured correctly.

        This will check that the bucket exists and that the credentials are
        configured correctly by issuing a ``head_bucket`` request.

        Arguments:
            uri (str): The URI of the storage.

        Raises:
            LaunchError: If the storage is not configured correctly or the URI is
                not a valid s3 URI.

        Returns:
            None
        """
        _logger.debug(f"Verifying storage {uri}")
        match = S3_URI_RE.match(uri)
        if not match:
            raise LaunchError(f"Destination {uri} is not a valid s3 URI.")
        bucket = match.group(1)
        try:
            session = self.get_session()
            client = session.client("s3")
            client.head_bucket(Bucket=bucket)
        except botocore.exceptions.ClientError as e:
            raise LaunchError(
                f"Could not verify AWS storage. Please verify that your AWS credentials are configured correctly. {e}"
            ) from e
|
@@ -0,0 +1,271 @@
|
|
1
|
+
"""Implementation of the GCP environment for wandb launch."""
|
2
|
+
import logging
|
3
|
+
import os
|
4
|
+
import re
|
5
|
+
|
6
|
+
from wandb.sdk.launch.utils import LaunchError
|
7
|
+
from wandb.util import get_module
|
8
|
+
|
9
|
+
from .abstract import AbstractEnvironment
|
10
|
+
|
11
|
+
google = get_module(
|
12
|
+
"google",
|
13
|
+
required="Google Cloud Platform support requires the google package. Please"
|
14
|
+
" install it with `pip install wandb[launch]`.",
|
15
|
+
)
|
16
|
+
google.cloud.compute_v1 = get_module(
|
17
|
+
"google.cloud.compute_v1",
|
18
|
+
required="Google Cloud Platform support requires the google-cloud-compute package. "
|
19
|
+
"Please install it with `pip install wandb[launch]`.",
|
20
|
+
)
|
21
|
+
google.auth.credentials = get_module(
|
22
|
+
"google.auth.credentials",
|
23
|
+
required="Google Cloud Platform support requires google-auth. "
|
24
|
+
"Please install it with `pip install wandb[launch]`.",
|
25
|
+
)
|
26
|
+
google.auth.transport.requests = get_module(
|
27
|
+
"google.auth.transport.requests",
|
28
|
+
required="Google Cloud Platform support requires google-auth. "
|
29
|
+
"Please install it with `pip install wandb[launch]`.",
|
30
|
+
)
|
31
|
+
google.api_core.exceptions = get_module(
|
32
|
+
"google.api_core.exceptions",
|
33
|
+
required="Google Cloud Platform support requires google-api-core. "
|
34
|
+
"Please install it with `pip install wandb[launch]`.",
|
35
|
+
)
|
36
|
+
google.cloud.storage = get_module(
|
37
|
+
"google.cloud.storage",
|
38
|
+
required="Google Cloud Platform support requires google-cloud-storage. "
|
39
|
+
"Please install it with `pip install wandb[launch]`.",
|
40
|
+
)
|
41
|
+
|
42
|
+
|
43
|
+
_logger = logging.getLogger(__name__)
|
44
|
+
|
45
|
+
GCS_URI_RE = re.compile(r"gs://([^/]+)/(.+)")
|
46
|
+
|
47
|
+
|
48
|
+
class GcpEnvironment(AbstractEnvironment):
    """GCP Environment.

    Attributes:
        region: The GCP region.
    """

    region: str

    def __init__(self, region: str, verify: bool = True) -> None:
        """Initialize the GCP environment.

        Arguments:
            region: The GCP region.
            verify: Whether to verify the credentials, region, and project.

        Raises:
            LaunchError: If verify is True and the environment is not properly
                configured.
        """
        super().__init__()
        _logger.info(f"Initializing GcpEnvironment in region {region}")
        self.region: str = region
        self._project = ""
        if verify:
            self.verify()

    @classmethod
    def from_config(cls, config: dict, verify: bool = True) -> "GcpEnvironment":
        """Create a GcpEnvironment from a config dictionary.

        Arguments:
            config: The config dictionary. Must have ``type`` equal to "gcp"
                and a non-empty ``region``.
            verify: Whether to verify the environment on creation. Defaults to
                True.

        Returns:
            GcpEnvironment: The GcpEnvironment.

        Raises:
            LaunchError: If the config type is not "gcp" or the region is
                missing.
        """
        if config.get("type") != "gcp":
            raise LaunchError(
                f"Could not create GcpEnvironment from config. Expected type 'gcp' "
                f"but got '{config.get('type')}'."
            )
        region = config.get("region", None)
        if not region:
            raise LaunchError(
                "Could not create GcpEnvironment from config. Missing 'region' "
                "field."
            )
        return cls(region=region, verify=verify)

    @property
    def project(self) -> str:
        """Get the name of the gcp project.

        The project name is recorded by get_credentials(), which verify()
        calls. This property does not fetch credentials itself.

        Returns:
            str: The name of the gcp project.

        Raises:
            LaunchError: If no project has been recorded yet.
        """
        if not self._project:
            raise LaunchError(
                "This GcpEnvironment has not been verified. Please call verify() "
                "before accessing the project."
            )
        return self._project

    def get_credentials(self) -> google.auth.credentials.Credentials:  # type: ignore
        """Get the GCP credentials.

        Uses google.auth.default() to get the credentials and then refreshes
        them. If the credentials are still invalid after refreshing, this
        method raises an error. The project returned by google.auth.default()
        is stored if no project has been recorded yet.

        Returns:
            google.auth.credentials.Credentials: The GCP credentials.

        Raises:
            LaunchError: If the GCP credentials are missing, cannot be
                refreshed, or are invalid.
        """
        _logger.debug("Getting GCP credentials")
        # TODO: Figure out a minimal set of scopes.
        scopes = [
            "https://www.googleapis.com/auth/cloud-platform",
        ]
        try:
            creds, project = google.auth.default(scopes=scopes)
            if not self._project:
                self._project = project
            _logger.debug("Refreshing GCP credentials")
            creds.refresh(google.auth.transport.requests.Request())
        except google.auth.exceptions.DefaultCredentialsError as e:
            raise LaunchError(
                "No Google Cloud Platform credentials found. Please run "
                "`gcloud auth application-default login` or set the environment "
                "variable GOOGLE_APPLICATION_CREDENTIALS to the path of a valid "
                "service account key file."
            ) from e
        except google.auth.exceptions.RefreshError as e:
            raise LaunchError(
                "Could not refresh Google Cloud Platform credentials. Please run "
                "`gcloud auth application-default login` or set the environment "
                "variable GOOGLE_APPLICATION_CREDENTIALS to the path of a valid "
                "service account key file."
            ) from e
        if not creds.valid:
            raise LaunchError(
                "Invalid Google Cloud Platform credentials. Please run "
                "`gcloud auth application-default login` or set the environment "
                "variable GOOGLE_APPLICATION_CREDENTIALS to the path of a valid "
                "service account key file."
            )
        return creds

    def verify(self) -> None:
        """Verify the credentials, region, and project.

        Credentials and project are obtained by calling get_credentials(). The
        region is verified by looking it up through the compute API.

        Raises:
            LaunchError: If the credentials, region, or project are invalid.

        Returns:
            None
        """
        _logger.debug("Verifying GCP environment")
        creds = self.get_credentials()
        try:
            # Check if the region is available using the compute API.
            compute_client = google.cloud.compute_v1.RegionsClient(credentials=creds)
            compute_client.get(project=self.project, region=self.region)
        except google.api_core.exceptions.NotFound as e:
            raise LaunchError(
                f"Region {self.region} is not available in project {self.project}."
            ) from e

    def verify_storage_uri(self, uri: str) -> None:
        """Verify that a storage URI is valid.

        Parses the bucket name out of the URI and looks the bucket up with the
        storage client.

        Arguments:
            uri: The storage URI.

        Raises:
            LaunchError: If the storage URI is invalid or the bucket does not
                exist.
        """
        match = GCS_URI_RE.match(uri)
        if not match:
            raise LaunchError(f"Invalid GCS URI: {uri}")
        bucket_name = match.group(1)
        try:
            storage_client = google.cloud.storage.Client(
                credentials=self.get_credentials()
            )
            # get_bucket performs the lookup and raises NotFound when the
            # bucket is missing; the returned object is not needed.
            storage_client.get_bucket(bucket_name)
        except google.api_core.exceptions.NotFound as e:
            raise LaunchError(f"Bucket {bucket_name} does not exist.") from e

    def upload_file(self, source: str, destination: str) -> None:
        """Upload a file to GCS.

        Arguments:
            source: The path to the local file.
            destination: The GCS URI of the destination object, e.g.
                gs://bucket/key.

        Raises:
            LaunchError: If the source is not a file, the destination is not a
                valid GCS URI, or the file cannot be uploaded.
        """
        _logger.debug(f"Uploading file {source} to {destination}")
        if not os.path.isfile(source):
            raise LaunchError(f"File {source} does not exist.")
        match = GCS_URI_RE.match(destination)
        if not match:
            raise LaunchError(f"Invalid GCS URI: {destination}")
        bucket_name = match.group(1)
        key = match.group(2).lstrip("/")
        try:
            storage_client = google.cloud.storage.Client(
                credentials=self.get_credentials()
            )
            bucket = storage_client.bucket(bucket_name)
            blob = bucket.blob(key)
            blob.upload_from_filename(source)
        except google.api_core.exceptions.GoogleAPICallError as e:
            raise LaunchError(f"Could not upload file to GCS: {e}") from e

    def upload_dir(self, source: str, destination: str) -> None:
        """Upload a directory to GCS.

        Every file under ``source`` is uploaded beneath the destination key,
        keeping its path relative to ``source``.

        Arguments:
            source: The path to the local directory.
            destination: The GCS URI of the destination prefix, e.g.
                gs://bucket/key.

        Raises:
            LaunchError: If the source is not a directory, the destination is
                not a valid GCS URI, or the directory cannot be uploaded.
        """
        _logger.debug(f"Uploading directory {source} to {destination}")
        if not os.path.isdir(source):
            raise LaunchError(f"Directory {source} does not exist.")
        match = GCS_URI_RE.match(destination)
        if not match:
            raise LaunchError(f"Invalid GCS URI: {destination}")
        bucket_name = match.group(1)
        key = match.group(2).lstrip("/")
        try:
            storage_client = google.cloud.storage.Client(
                credentials=self.get_credentials()
            )
            bucket = storage_client.bucket(bucket_name)
            for root, _, files in os.walk(source):
                for file in files:
                    local_path = os.path.join(root, file)
                    # Object names always use "/" regardless of the local OS.
                    gcs_path = os.path.join(
                        key, os.path.relpath(local_path, source)
                    ).replace("\\", "/")
                    blob = bucket.blob(gcs_path)
                    blob.upload_from_filename(local_path)
        except google.api_core.exceptions.GoogleAPICallError as e:
            raise LaunchError(f"Could not upload directory to GCS: {e}") from e
|
@@ -0,0 +1,65 @@
|
|
1
|
+
"""Dummy local environment implementation. This is the default environment."""
|
2
|
+
from typing import Any, Dict, Union
|
3
|
+
|
4
|
+
from wandb.sdk.launch.utils import LaunchError
|
5
|
+
|
6
|
+
from .abstract import AbstractEnvironment
|
7
|
+
|
8
|
+
|
9
|
+
class LocalEnvironment(AbstractEnvironment):
    """Local environment class.

    A placeholder environment: it can be constructed, but every operation
    other than construction raises LaunchError.
    """

    def __init__(self) -> None:
        """Initialize a local environment; there is no state to set up."""
        # Called for consistency with the other environment implementations.
        super().__init__()

    @classmethod
    def from_config(
        cls, config: Dict[str, Union[Dict[str, Any], str]]
    ) -> "LocalEnvironment":
        """Create a local environment from a config.

        Arguments:
            config (dict): The config. This is ignored.

        Returns:
            LocalEnvironment: The local environment.
        """
        return cls()

    def verify(self) -> None:
        """Reject verification; not supported for the local environment.

        Raises:
            LaunchError: Always.
        """
        raise LaunchError("Attempted to verify LocalEnvironment.")

    def verify_storage_uri(self, uri: str) -> None:
        """Reject storage URI verification; not supported locally.

        Arguments:
            uri (str): The storage URI. This is ignored.

        Raises:
            LaunchError: Always.
        """
        raise LaunchError("Attempted to verify storage uri for LocalEnvironment.")

    def upload_file(self, source: str, destination: str) -> None:
        """Reject file uploads; not supported for the local environment.

        Arguments:
            source (str): The source file. This is ignored.
            destination (str): The destination file. This is ignored.

        Raises:
            LaunchError: Always.
        """
        raise LaunchError("Attempted to upload file for LocalEnvironment.")

    def upload_dir(self, source: str, destination: str) -> None:
        """Reject directory uploads; not supported for the local environment.

        Arguments:
            source (str): The source directory. This is ignored.
            destination (str): The destination directory. This is ignored.

        Raises:
            LaunchError: Always.
        """
        raise LaunchError("Attempted to upload directory for LocalEnvironment.")

    def get_project(self) -> str:
        """Reject project lookup; the local environment has no project.

        Raises:
            LaunchError: Always.
        """
        raise LaunchError("Attempted to get project for LocalEnvironment.")
|
@@ -1,6 +1,4 @@
|
|
1
|
-
"""
|
2
|
-
Support for parsing GitHub URLs (which might be user provided) into constituent parts.
|
3
|
-
"""
|
1
|
+
"""Support for parsing GitHub URLs (which might be user provided) into constituent parts."""
|
4
2
|
|
5
3
|
import re
|
6
4
|
from dataclasses import dataclass
|
@@ -9,7 +7,7 @@ from pathlib import Path
|
|
9
7
|
from typing import Optional, Tuple
|
10
8
|
from urllib.parse import urlparse
|
11
9
|
|
12
|
-
from wandb.errors import LaunchError
|
10
|
+
from wandb.sdk.launch.utils import LaunchError
|
13
11
|
|
14
12
|
PREFIX_HTTPS = "https://"
|
15
13
|
PREFIX_SSH = "git@"
|
@@ -43,7 +41,6 @@ def _parse_netloc(netloc: str) -> Tuple[Optional[str], Optional[str], str]:
|
|
43
41
|
|
44
42
|
@dataclass
|
45
43
|
class GitHubReference:
|
46
|
-
|
47
44
|
username: Optional[str] = None
|
48
45
|
password: Optional[str] = None
|
49
46
|
host: Optional[str] = None
|
@@ -107,9 +104,7 @@ class GitHubReference:
|
|
107
104
|
|
108
105
|
@staticmethod
|
109
106
|
def parse(uri: str) -> Optional["GitHubReference"]:
|
110
|
-
"""
|
111
|
-
Attempt to parse a string as a GitHub URL.
|
112
|
-
"""
|
107
|
+
"""Attempt to parse a string as a GitHub URL."""
|
113
108
|
# Special case: git@github.com:wandb/wandb.git
|
114
109
|
ref = GitHubReference()
|
115
110
|
if uri.startswith(PREFIX_SSH):
|