wandb 0.13.10__py3-none-any.whl → 0.14.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (228) hide show
  1. wandb/__init__.py +2 -3
  2. wandb/apis/__init__.py +1 -3
  3. wandb/apis/importers/__init__.py +4 -0
  4. wandb/apis/importers/base.py +312 -0
  5. wandb/apis/importers/mlflow.py +113 -0
  6. wandb/apis/internal.py +29 -2
  7. wandb/apis/normalize.py +6 -5
  8. wandb/apis/public.py +163 -180
  9. wandb/apis/reports/_templates.py +6 -12
  10. wandb/apis/reports/report.py +1 -1
  11. wandb/apis/reports/runset.py +1 -3
  12. wandb/apis/reports/util.py +12 -10
  13. wandb/beta/workflows.py +57 -34
  14. wandb/catboost/__init__.py +1 -2
  15. wandb/cli/cli.py +215 -133
  16. wandb/data_types.py +63 -56
  17. wandb/docker/__init__.py +78 -16
  18. wandb/docker/auth.py +21 -22
  19. wandb/env.py +0 -1
  20. wandb/errors/__init__.py +8 -116
  21. wandb/errors/term.py +1 -1
  22. wandb/fastai/__init__.py +1 -2
  23. wandb/filesync/dir_watcher.py +8 -5
  24. wandb/filesync/step_prepare.py +76 -75
  25. wandb/filesync/step_upload.py +1 -2
  26. wandb/integration/catboost/__init__.py +1 -3
  27. wandb/integration/catboost/catboost.py +8 -14
  28. wandb/integration/fastai/__init__.py +7 -13
  29. wandb/integration/gym/__init__.py +35 -4
  30. wandb/integration/keras/__init__.py +3 -3
  31. wandb/integration/keras/callbacks/metrics_logger.py +9 -8
  32. wandb/integration/keras/callbacks/model_checkpoint.py +9 -9
  33. wandb/integration/keras/callbacks/tables_builder.py +31 -19
  34. wandb/integration/kfp/kfp_patch.py +20 -17
  35. wandb/integration/kfp/wandb_logging.py +1 -2
  36. wandb/integration/lightgbm/__init__.py +21 -19
  37. wandb/integration/prodigy/prodigy.py +6 -7
  38. wandb/integration/sacred/__init__.py +9 -12
  39. wandb/integration/sagemaker/__init__.py +1 -3
  40. wandb/integration/sagemaker/auth.py +0 -1
  41. wandb/integration/sagemaker/config.py +1 -1
  42. wandb/integration/sagemaker/resources.py +1 -1
  43. wandb/integration/sb3/sb3.py +8 -4
  44. wandb/integration/tensorboard/__init__.py +1 -3
  45. wandb/integration/tensorboard/log.py +8 -8
  46. wandb/integration/tensorboard/monkeypatch.py +11 -9
  47. wandb/integration/tensorflow/__init__.py +1 -3
  48. wandb/integration/xgboost/__init__.py +4 -6
  49. wandb/integration/yolov8/__init__.py +7 -0
  50. wandb/integration/yolov8/yolov8.py +250 -0
  51. wandb/jupyter.py +31 -35
  52. wandb/lightgbm/__init__.py +1 -2
  53. wandb/old/settings.py +2 -2
  54. wandb/plot/bar.py +1 -2
  55. wandb/plot/confusion_matrix.py +1 -3
  56. wandb/plot/histogram.py +1 -2
  57. wandb/plot/line.py +1 -2
  58. wandb/plot/line_series.py +4 -4
  59. wandb/plot/pr_curve.py +17 -20
  60. wandb/plot/roc_curve.py +1 -3
  61. wandb/plot/scatter.py +1 -2
  62. wandb/proto/v3/wandb_server_pb2.py +85 -39
  63. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  64. wandb/proto/v4/wandb_server_pb2.py +51 -39
  65. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  66. wandb/sdk/__init__.py +1 -3
  67. wandb/sdk/backend/backend.py +1 -1
  68. wandb/sdk/data_types/_dtypes.py +38 -30
  69. wandb/sdk/data_types/base_types/json_metadata.py +1 -3
  70. wandb/sdk/data_types/base_types/media.py +17 -17
  71. wandb/sdk/data_types/base_types/wb_value.py +33 -26
  72. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +91 -125
  73. wandb/sdk/data_types/helper_types/classes.py +1 -1
  74. wandb/sdk/data_types/helper_types/image_mask.py +12 -12
  75. wandb/sdk/data_types/histogram.py +5 -4
  76. wandb/sdk/data_types/html.py +1 -2
  77. wandb/sdk/data_types/image.py +11 -11
  78. wandb/sdk/data_types/molecule.py +3 -6
  79. wandb/sdk/data_types/object_3d.py +1 -2
  80. wandb/sdk/data_types/plotly.py +1 -2
  81. wandb/sdk/data_types/saved_model.py +10 -8
  82. wandb/sdk/data_types/video.py +1 -1
  83. wandb/sdk/integration_utils/data_logging.py +5 -5
  84. wandb/sdk/interface/artifacts.py +288 -266
  85. wandb/sdk/interface/interface.py +2 -3
  86. wandb/sdk/interface/interface_grpc.py +1 -1
  87. wandb/sdk/interface/interface_queue.py +1 -1
  88. wandb/sdk/interface/interface_relay.py +1 -1
  89. wandb/sdk/interface/interface_shared.py +1 -2
  90. wandb/sdk/interface/interface_sock.py +1 -1
  91. wandb/sdk/interface/message_future.py +1 -1
  92. wandb/sdk/interface/message_future_poll.py +1 -1
  93. wandb/sdk/interface/router.py +1 -1
  94. wandb/sdk/interface/router_queue.py +1 -1
  95. wandb/sdk/interface/router_relay.py +1 -1
  96. wandb/sdk/interface/router_sock.py +1 -1
  97. wandb/sdk/interface/summary_record.py +1 -1
  98. wandb/sdk/internal/artifacts.py +1 -1
  99. wandb/sdk/internal/datastore.py +2 -3
  100. wandb/sdk/internal/file_pusher.py +5 -3
  101. wandb/sdk/internal/file_stream.py +22 -19
  102. wandb/sdk/internal/handler.py +5 -4
  103. wandb/sdk/internal/internal.py +1 -1
  104. wandb/sdk/internal/internal_api.py +115 -55
  105. wandb/sdk/internal/job_builder.py +1 -3
  106. wandb/sdk/internal/profiler.py +1 -1
  107. wandb/sdk/internal/progress.py +4 -6
  108. wandb/sdk/internal/sample.py +1 -3
  109. wandb/sdk/internal/sender.py +28 -16
  110. wandb/sdk/internal/settings_static.py +5 -5
  111. wandb/sdk/internal/system/assets/__init__.py +1 -0
  112. wandb/sdk/internal/system/assets/cpu.py +3 -9
  113. wandb/sdk/internal/system/assets/disk.py +2 -4
  114. wandb/sdk/internal/system/assets/gpu.py +6 -18
  115. wandb/sdk/internal/system/assets/gpu_apple.py +2 -4
  116. wandb/sdk/internal/system/assets/interfaces.py +50 -22
  117. wandb/sdk/internal/system/assets/ipu.py +1 -3
  118. wandb/sdk/internal/system/assets/memory.py +7 -13
  119. wandb/sdk/internal/system/assets/network.py +4 -8
  120. wandb/sdk/internal/system/assets/open_metrics.py +283 -0
  121. wandb/sdk/internal/system/assets/tpu.py +1 -4
  122. wandb/sdk/internal/system/assets/trainium.py +26 -14
  123. wandb/sdk/internal/system/system_info.py +2 -3
  124. wandb/sdk/internal/system/system_monitor.py +52 -20
  125. wandb/sdk/internal/tb_watcher.py +12 -13
  126. wandb/sdk/launch/_project_spec.py +54 -65
  127. wandb/sdk/launch/agent/agent.py +374 -90
  128. wandb/sdk/launch/builder/abstract.py +61 -7
  129. wandb/sdk/launch/builder/build.py +81 -110
  130. wandb/sdk/launch/builder/docker_builder.py +181 -0
  131. wandb/sdk/launch/builder/kaniko_builder.py +419 -0
  132. wandb/sdk/launch/builder/noop.py +31 -12
  133. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +70 -20
  134. wandb/sdk/launch/environment/abstract.py +28 -0
  135. wandb/sdk/launch/environment/aws_environment.py +276 -0
  136. wandb/sdk/launch/environment/gcp_environment.py +271 -0
  137. wandb/sdk/launch/environment/local_environment.py +65 -0
  138. wandb/sdk/launch/github_reference.py +3 -8
  139. wandb/sdk/launch/launch.py +38 -29
  140. wandb/sdk/launch/launch_add.py +6 -8
  141. wandb/sdk/launch/loader.py +230 -0
  142. wandb/sdk/launch/registry/abstract.py +54 -0
  143. wandb/sdk/launch/registry/elastic_container_registry.py +163 -0
  144. wandb/sdk/launch/registry/google_artifact_registry.py +203 -0
  145. wandb/sdk/launch/registry/local_registry.py +62 -0
  146. wandb/sdk/launch/runner/abstract.py +1 -16
  147. wandb/sdk/launch/runner/{kubernetes.py → kubernetes_runner.py} +83 -95
  148. wandb/sdk/launch/runner/local_container.py +46 -22
  149. wandb/sdk/launch/runner/local_process.py +1 -4
  150. wandb/sdk/launch/runner/{aws.py → sagemaker_runner.py} +53 -212
  151. wandb/sdk/launch/runner/{gcp_vertex.py → vertex_runner.py} +38 -55
  152. wandb/sdk/launch/sweeps/__init__.py +3 -2
  153. wandb/sdk/launch/sweeps/scheduler.py +132 -39
  154. wandb/sdk/launch/sweeps/scheduler_sweep.py +80 -89
  155. wandb/sdk/launch/utils.py +101 -30
  156. wandb/sdk/launch/wandb_reference.py +2 -7
  157. wandb/sdk/lib/_settings_toposort_generate.py +166 -0
  158. wandb/sdk/lib/_settings_toposort_generated.py +201 -0
  159. wandb/sdk/lib/apikey.py +2 -4
  160. wandb/sdk/lib/config_util.py +4 -1
  161. wandb/sdk/lib/console.py +1 -3
  162. wandb/sdk/lib/deprecate.py +3 -3
  163. wandb/sdk/lib/file_stream_utils.py +7 -5
  164. wandb/sdk/lib/filenames.py +1 -1
  165. wandb/sdk/lib/filesystem.py +61 -5
  166. wandb/sdk/lib/git.py +1 -3
  167. wandb/sdk/lib/import_hooks.py +4 -7
  168. wandb/sdk/lib/ipython.py +8 -5
  169. wandb/sdk/lib/lazyloader.py +1 -3
  170. wandb/sdk/lib/mailbox.py +14 -4
  171. wandb/sdk/lib/proto_util.py +10 -5
  172. wandb/sdk/lib/redirect.py +15 -22
  173. wandb/sdk/lib/reporting.py +1 -3
  174. wandb/sdk/lib/retry.py +4 -5
  175. wandb/sdk/lib/runid.py +1 -3
  176. wandb/sdk/lib/server.py +15 -9
  177. wandb/sdk/lib/sock_client.py +1 -1
  178. wandb/sdk/lib/sparkline.py +1 -1
  179. wandb/sdk/lib/wburls.py +1 -1
  180. wandb/sdk/service/port_file.py +1 -2
  181. wandb/sdk/service/service.py +36 -13
  182. wandb/sdk/service/service_base.py +12 -1
  183. wandb/sdk/verify/verify.py +5 -7
  184. wandb/sdk/wandb_artifacts.py +142 -177
  185. wandb/sdk/wandb_config.py +5 -8
  186. wandb/sdk/wandb_helper.py +1 -1
  187. wandb/sdk/wandb_init.py +24 -13
  188. wandb/sdk/wandb_login.py +9 -9
  189. wandb/sdk/wandb_manager.py +39 -4
  190. wandb/sdk/wandb_metric.py +2 -6
  191. wandb/sdk/wandb_require.py +4 -15
  192. wandb/sdk/wandb_require_helpers.py +1 -9
  193. wandb/sdk/wandb_run.py +95 -141
  194. wandb/sdk/wandb_save.py +1 -3
  195. wandb/sdk/wandb_settings.py +149 -54
  196. wandb/sdk/wandb_setup.py +66 -46
  197. wandb/sdk/wandb_summary.py +13 -10
  198. wandb/sdk/wandb_sweep.py +6 -7
  199. wandb/sdk/wandb_watch.py +1 -1
  200. wandb/sklearn/calculate/confusion_matrix.py +1 -1
  201. wandb/sklearn/calculate/learning_curve.py +1 -1
  202. wandb/sklearn/calculate/summary_metrics.py +1 -3
  203. wandb/sklearn/plot/__init__.py +1 -1
  204. wandb/sklearn/plot/classifier.py +27 -18
  205. wandb/sklearn/plot/clusterer.py +4 -5
  206. wandb/sklearn/plot/regressor.py +4 -4
  207. wandb/sklearn/plot/shared.py +2 -2
  208. wandb/sync/__init__.py +1 -3
  209. wandb/sync/sync.py +4 -5
  210. wandb/testing/relay.py +11 -10
  211. wandb/trigger.py +1 -1
  212. wandb/util.py +106 -81
  213. wandb/viz.py +4 -4
  214. wandb/wandb_agent.py +50 -50
  215. wandb/wandb_controller.py +2 -3
  216. wandb/wandb_run.py +1 -2
  217. wandb/wandb_torch.py +1 -1
  218. wandb/xgboost/__init__.py +1 -2
  219. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/METADATA +6 -2
  220. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/RECORD +224 -209
  221. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/WHEEL +1 -1
  222. wandb/sdk/launch/builder/docker.py +0 -80
  223. wandb/sdk/launch/builder/kaniko.py +0 -393
  224. wandb/sdk/launch/builder/loader.py +0 -32
  225. wandb/sdk/launch/runner/loader.py +0 -50
  226. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/LICENSE +0 -0
  227. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/entry_points.txt +0 -0
  228. {wandb-0.13.10.dist-info → wandb-0.14.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,276 @@
1
+ """Implements the AWS environment."""
2
+
3
+ import logging
4
+ import os
5
+ import re
6
+ from typing import Dict
7
+
8
+ from wandb.sdk.launch.utils import LaunchError
9
+ from wandb.util import get_module
10
+
11
+ from .abstract import AbstractEnvironment
12
+
13
# Both SDK packages are hard requirements of this module; each get_module call
# is given a `required` install hint (presumably raised when the package is
# missing — see wandb.util.get_module).
boto3 = get_module(
    "boto3",
    required=(
        "AWS environment requires boto3 to be installed. Please install "
        "it with `pip install wandb[launch]`."
    ),
)
botocore = get_module(
    "botocore",
    required=(
        "AWS environment requires botocore to be installed. Please install "
        "it with `pip install wandb[launch]`."
    ),
)

_logger = logging.getLogger(__name__)

# Matches "s3://<bucket>/<key>": group 1 is the bucket name, group 2 is the
# remainder of the URI after the first slash following the bucket.
S3_URI_RE = re.compile(r"s3://([^/]+)/(.+)")
27
+
28
+
29
class AwsEnvironment(AbstractEnvironment):
    """AWS environment.

    Stores an explicit set of AWS credentials together with a region and uses
    them to build boto3 sessions for STS (verification) and S3 (storage)
    operations.
    """

    def __init__(
        self,
        region: str,
        access_key: str,
        secret_key: str,
        session_token: str,
        verify: bool = True,
    ) -> None:
        """Initialize the AWS environment.

        Arguments:
            region (str): The AWS region.
            access_key (str): AWS access key id passed to every boto3 session.
            secret_key (str): AWS secret access key passed to every boto3 session.
            session_token (str): AWS session token passed to every boto3 session.
            verify (bool, optional): Whether to call ``verify()`` immediately.
                Defaults to True.

        Raises:
            LaunchError: If ``verify`` is True and the AWS environment is not
                configured correctly.
        """
        super().__init__()
        _logger.info(f"Initializing AWS environment in region {region}.")
        self._region = region
        self._access_key = access_key
        self._secret_key = secret_key
        self._session_token = session_token
        if verify:
            self.verify()

    @classmethod
    def from_default(cls, region: str, verify: bool = True) -> "AwsEnvironment":
        """Create an AWS environment from the default boto3 credential chain.

        Arguments:
            region (str): The AWS region. If falsy, the region of the default
                boto3 session is used instead.
            verify (bool, optional): Whether to verify the AWS environment. Defaults to True.

        Returns:
            AwsEnvironment: The AWS environment.

        Raises:
            LaunchError: If no default credentials are available, botocore
                fails while resolving them, or verification fails.
        """
        _logger.info("Creating AWS environment from default credentials.")
        try:
            session = boto3.Session()
            region = region or session.region_name
            credentials = session.get_credentials()
            if not credentials:
                raise LaunchError(
                    "Could not create AWS environment from default environment. Please verify that your AWS credentials are configured correctly."
                )
            # Reading these attributes may trigger a credential refresh, which
            # can itself fail, so it stays inside the try block.
            access_key = credentials.access_key
            secret_key = credentials.secret_key
            session_token = credentials.token
        except (
            botocore.exceptions.ClientError,
            botocore.exceptions.BotoCoreError,
        ) as e:
            raise LaunchError(
                f"Could not create AWS environment from default environment. Please verify that your AWS credentials are configured correctly. {e}"
            ) from e
        return cls(
            region=region,
            access_key=access_key,
            secret_key=secret_key,
            session_token=session_token,
            verify=verify,
        )

    @classmethod
    def from_config(
        cls, config: Dict[str, str], verify: bool = True
    ) -> "AwsEnvironment":
        """Create an AWS environment from a launch config dictionary.

        Only the ``region`` key is read from ``config``; credentials come from
        the default boto3 credential chain via ``from_default``.

        Arguments:
            config (dict): Configuration dictionary.
            verify (bool, optional): Whether to verify the AWS environment. Defaults to True.

        Returns:
            AwsEnvironment: The AWS environment.

        Raises:
            LaunchError: If ``region`` is missing, empty or None, or if the
                default credentials cannot be used.
        """
        # Check before converting to str: str(None) == "None" would otherwise
        # slip through as a (bogus) region name.
        region = config.get("region")
        if not region:
            raise LaunchError(
                "Could not create AWS environment from config. Region not specified."
            )
        return cls.from_default(
            region=str(region),
            verify=verify,
        )

    @property
    def region(self) -> str:
        """The AWS region."""
        return self._region

    @region.setter
    def region(self, region: str) -> None:
        self._region = region

    def verify(self) -> None:
        """Verify that the AWS environment is configured correctly.

        Calls STS ``get_caller_identity`` with the stored credentials.

        Raises:
            LaunchError: If the AWS environment is not configured correctly.
        """
        _logger.debug("Verifying AWS environment.")
        try:
            session = self.get_session()
            client = session.client("sts")
            client.get_caller_identity()
            # TODO: log identity details from the response
        except botocore.exceptions.ClientError as e:
            raise LaunchError(
                f"Could not verify AWS environment. Please verify that your AWS credentials are configured correctly. {e}"
            ) from e

    def get_session(self) -> "boto3.Session":  # type: ignore
        """Get an AWS session built from the stored credentials and region.

        Returns:
            boto3.Session: The AWS session.

        Raises:
            LaunchError: If the AWS session could not be created.
        """
        _logger.debug(f"Creating AWS session in region {self._region}")
        try:
            return boto3.Session(
                aws_access_key_id=self._access_key,
                aws_secret_access_key=self._secret_key,
                aws_session_token=self._session_token,
                region_name=self._region,
            )
        except (
            botocore.exceptions.ClientError,
            botocore.exceptions.BotoCoreError,
        ) as e:
            raise LaunchError(f"Could not create AWS session. {e}") from e

    def upload_file(self, source: str, destination: str) -> None:
        """Upload a single local file to s3.

        The destination is the full object URI: the file is stored under
        exactly the key given, and the source filename is NOT appended. For
        example, source "foo/bar" and destination "s3://bucket/key/bar"
        produce the object "s3://bucket/key/bar".

        Arguments:
            source (str): The path to the local file.
            destination (str): The s3 URI of the object to write, e.g.
                s3://bucket/key.

        Raises:
            LaunchError: If the source file does not exist, the destination is
                not a valid s3 object URI, or the upload fails.
        """
        _logger.debug(f"Uploading {source} to {destination}")
        if not os.path.isfile(source):
            raise LaunchError(f"Source {source} does not exist.")
        match = S3_URI_RE.match(destination)
        if not match:
            raise LaunchError(f"Destination {destination} is not a valid s3 URI.")
        bucket = match.group(1)
        key = match.group(2).lstrip("/")
        if not key:
            # e.g. "s3://bucket///": the regex matched but no object key remains.
            raise LaunchError(f"Destination {destination} is not a valid s3 URI.")
        session = self.get_session()
        try:
            client = session.client("s3")
            client.upload_file(source, bucket, key)
        except (
            botocore.exceptions.ClientError,
            # The managed S3 transfer reports failed uploads with this wrapper
            # rather than ClientError.
            boto3.exceptions.S3UploadFailedError,
        ) as e:
            raise LaunchError(
                f"botocore error attempting to copy {source} to {destination}. {e}"
            ) from e

    def upload_dir(self, source: str, destination: str) -> None:
        """Upload the contents of a local directory to s3.

        Every file below ``source`` is uploaded under the destination key
        prefix, preserving its path relative to ``source``. For example, if
        "foo/bar" contains "a/b.txt" and the destination is "s3://bucket/key",
        the object "s3://bucket/key/a/b.txt" is written.

        Arguments:
            source (str): The path to the local directory.
            destination (str): The s3 URI used as the key prefix, e.g.
                s3://bucket/key.

        Raises:
            LaunchError: If the source directory does not exist, the
                destination is not a valid s3 URI, or any upload fails.
        """
        _logger.debug(f"Uploading {source} to {destination}")
        if not os.path.isdir(source):
            raise LaunchError(f"Source {source} does not exist.")
        match = S3_URI_RE.match(destination)
        if not match:
            raise LaunchError(f"Destination {destination} is not a valid s3 URI.")
        bucket = match.group(1)
        # Strip slashes on both ends so joining never yields "prefix//file";
        # an empty prefix means the bucket root.
        prefix = match.group(2).strip("/")
        session = self.get_session()
        try:
            client = session.client("s3")
            for path, _, files in os.walk(source):
                for file in files:
                    abs_path = os.path.join(path, file)
                    # relpath (not str.replace) so a source string that recurs
                    # deeper in the path is not stripped twice.
                    rel_path = os.path.relpath(abs_path, source).replace("\\", "/")
                    key_path = f"{prefix}/{rel_path}" if prefix else rel_path
                    client.upload_file(abs_path, bucket, key_path)
        except (
            botocore.exceptions.ClientError,
            boto3.exceptions.S3UploadFailedError,
        ) as e:
            raise LaunchError(
                f"botocore error attempting to copy {source} to {destination}. {e}"
            ) from e
        except Exception as e:
            raise LaunchError(
                f"Unexpected error attempting to copy {source} to {destination}. {e}"
            ) from e

    def verify_storage_uri(self, uri: str) -> None:
        """Verify that s3 storage is configured correctly.

        Calls ``head_bucket`` on the bucket named in the URI, which checks that
        the bucket exists and is reachable with the stored credentials.

        Arguments:
            uri (str): The URI of the storage.

        Raises:
            LaunchError: If the storage is not configured correctly or the URI is
                not a valid s3 URI.

        Returns:
            None
        """
        _logger.debug(f"Verifying storage {uri}")
        match = S3_URI_RE.match(uri)
        if not match:
            raise LaunchError(f"Destination {uri} is not a valid s3 URI.")
        bucket = match.group(1)
        try:
            session = self.get_session()
            client = session.client("s3")
            client.head_bucket(Bucket=bucket)
        except botocore.exceptions.ClientError as e:
            raise LaunchError(
                f"Could not verify AWS storage. Please verify that your AWS credentials are configured correctly. {e}"
            ) from e
@@ -0,0 +1,271 @@
1
+ """Implementation of the GCP environment for wandb launch."""
2
+ import logging
3
+ import os
4
+ import re
5
+
6
+ from wandb.sdk.launch.utils import LaunchError
7
+ from wandb.util import get_module
8
+
9
+ from .abstract import AbstractEnvironment
10
+
11
# Load every google.* submodule the GcpEnvironment class references. Each
# result is re-bound onto the `google` namespace so the class can use fully
# qualified names such as google.cloud.storage.Client.
google = get_module(
    "google",
    required="Google Cloud Platform support requires the google package. Please"
    " install it with `pip install wandb[launch]`.",
)
google.cloud.compute_v1 = get_module(
    "google.cloud.compute_v1",
    required="Google Cloud Platform support requires the google-cloud-compute package. "
    "Please install it with `pip install wandb[launch]`.",
)
google.auth.credentials = get_module(
    "google.auth.credentials",
    required="Google Cloud Platform support requires google-auth. "
    "Please install it with `pip install wandb[launch]`.",
)
google.auth.transport.requests = get_module(
    "google.auth.transport.requests",
    required="Google Cloud Platform support requires google-auth. "
    "Please install it with `pip install wandb[launch]`.",
)
# GcpEnvironment.get_credentials catches google.auth.exceptions.*; load it
# explicitly instead of relying on a transitive import from google.auth.
google.auth.exceptions = get_module(
    "google.auth.exceptions",
    required="Google Cloud Platform support requires google-auth. "
    "Please install it with `pip install wandb[launch]`.",
)
google.api_core.exceptions = get_module(
    "google.api_core.exceptions",
    required="Google Cloud Platform support requires google-api-core. "
    "Please install it with `pip install wandb[launch]`.",
)
google.cloud.storage = get_module(
    "google.cloud.storage",
    required="Google Cloud Platform support requires google-cloud-storage. "
    "Please install it with `pip install wandb[launch]`.",
)


_logger = logging.getLogger(__name__)

# Matches "gs://<bucket>/<key>": group 1 is the bucket, group 2 the remainder.
GCS_URI_RE = re.compile(r"gs://([^/]+)/(.+)")
46
+
47
+
48
class GcpEnvironment(AbstractEnvironment):
    """GCP Environment.

    Uses Application Default Credentials (``google.auth.default``) for all
    API calls; the project is taken from those credentials.

    Attributes:
        region: The GCP region.
    """

    region: str

    def __init__(self, region: str, verify: bool = True) -> None:
        """Initialize the GCP environment.

        Arguments:
            region: The GCP region.
            verify: Whether to verify the credentials, region, and project.

        Raises:
            LaunchError: If verify is True and the environment is not properly
                configured.
        """
        super().__init__()
        _logger.info(f"Initializing GcpEnvironment in region {region}")
        self.region: str = region
        self._project = ""
        if verify:
            self.verify()

    @classmethod
    def from_config(cls, config: dict, verify: bool = True) -> "GcpEnvironment":
        """Create a GcpEnvironment from a config dictionary.

        Arguments:
            config: The config dictionary. Must have ``type == "gcp"`` and a
                non-empty ``region``.
            verify: Whether to verify the environment on construction.
                Defaults to True.

        Returns:
            GcpEnvironment: The GcpEnvironment.

        Raises:
            LaunchError: If the config has the wrong type, lacks a region, or
                verification fails.
        """
        if config.get("type") != "gcp":
            raise LaunchError(
                f"Could not create GcpEnvironment from config. Expected type 'gcp' "
                f"but got '{config.get('type')}'."
            )
        region = config.get("region", None)
        if not region:
            raise LaunchError(
                "Could not create GcpEnvironment from config. Missing 'region' "
                "field."
            )
        return cls(region=region, verify=verify)

    @property
    def project(self) -> str:
        """Get the name of the gcp project.

        The project name is determined by the credentials and is only known
        after ``verify()`` (or ``get_credentials()``) has run; this property
        does not resolve it lazily.

        Returns:
            str: The name of the gcp project.

        Raises:
            LaunchError: If the project has not been determined yet.
        """
        if not self._project:
            raise LaunchError(
                "This GcpEnvironment has not been verified. Please call verify() "
                "before accessing the project."
            )
        return self._project

    def get_credentials(self) -> google.auth.credentials.Credentials:  # type: ignore
        """Get the GCP credentials.

        Uses google.auth.default() to get the credentials and always refreshes
        them. If the credentials are still invalid after refreshing, this
        method raises an error. The first call also records the project that
        google.auth.default() reports, if any.

        Returns:
            google.auth.credentials.Credentials: The GCP credentials.

        Raises:
            LaunchError: If the GCP credentials are missing or invalid.
        """
        _logger.debug("Getting GCP credentials")
        # TODO: Figure out a minimal set of scopes.
        scopes = [
            "https://www.googleapis.com/auth/cloud-platform",
        ]
        try:
            creds, project = google.auth.default(scopes=scopes)
            if not self._project:
                self._project = project
            _logger.debug("Refreshing GCP credentials")
            creds.refresh(google.auth.transport.requests.Request())
        except google.auth.exceptions.DefaultCredentialsError as e:
            raise LaunchError(
                "No Google Cloud Platform credentials found. Please run "
                "`gcloud auth application-default login` or set the environment "
                "variable GOOGLE_APPLICATION_CREDENTIALS to the path of a valid "
                "service account key file."
            ) from e
        except google.auth.exceptions.RefreshError as e:
            raise LaunchError(
                "Could not refresh Google Cloud Platform credentials. Please run "
                "`gcloud auth application-default login` or set the environment "
                "variable GOOGLE_APPLICATION_CREDENTIALS to the path of a valid "
                "service account key file."
            ) from e
        if not creds.valid:
            raise LaunchError(
                "Invalid Google Cloud Platform credentials. Please run "
                "`gcloud auth application-default login` or set the environment "
                "variable GOOGLE_APPLICATION_CREDENTIALS to the path of a valid "
                "service account key file."
            )
        return creds

    def verify(self) -> None:
        """Verify the credentials, region, and project.

        Credentials and project are verified by calling get_credentials(). The
        region is verified by looking it up with the compute API.

        Raises:
            LaunchError: If the credentials, region, or project are invalid.

        Returns:
            None
        """
        _logger.debug("Verifying GCP environment")
        creds = self.get_credentials()
        if not self._project:
            # google.auth.default() can succeed without reporting a project;
            # say so explicitly rather than the generic "not verified" error.
            raise LaunchError(
                "Could not determine the Google Cloud Platform project from the "
                "default credentials. Please set a default project with "
                "`gcloud config set project` or the GOOGLE_CLOUD_PROJECT "
                "environment variable."
            )
        try:
            # Check if the region is available using the compute API.
            compute_client = google.cloud.compute_v1.RegionsClient(credentials=creds)
            compute_client.get(project=self._project, region=self.region)
        except google.api_core.exceptions.NotFound as e:
            raise LaunchError(
                f"Region {self.region} is not available in project {self.project}."
            ) from e
        except google.api_core.exceptions.GoogleAPICallError as e:
            raise LaunchError(
                f"Could not verify region {self.region} in project "
                f"{self._project}: {e}"
            ) from e

    def verify_storage_uri(self, uri: str) -> None:
        """Verify that a storage URI is valid and its bucket exists.

        Arguments:
            uri: The storage URI.

        Raises:
            LaunchError: If the storage URI is invalid or the bucket cannot be
                accessed.
        """
        match = GCS_URI_RE.match(uri)
        if not match:
            raise LaunchError(f"Invalid GCS URI: {uri}")
        bucket_name = match.group(1)
        try:
            storage_client = google.cloud.storage.Client(
                credentials=self.get_credentials()
            )
            # get_bucket fetches the bucket metadata and raises NotFound when
            # the bucket does not exist.
            storage_client.get_bucket(bucket_name)
        except google.api_core.exceptions.NotFound as e:
            raise LaunchError(f"Bucket {bucket_name} does not exist.") from e
        except google.api_core.exceptions.GoogleAPICallError as e:
            raise LaunchError(
                f"Could not access bucket {bucket_name}: {e}"
            ) from e

    def upload_file(self, source: str, destination: str) -> None:
        """Upload a file to GCS.

        The file is stored under exactly the object path given in
        ``destination``.

        Arguments:
            source: The path to the local file.
            destination: The path to the GCS file, e.g. gs://bucket/path/file.

        Raises:
            LaunchError: If the file cannot be uploaded.
        """
        _logger.debug(f"Uploading file {source} to {destination}")
        if not os.path.isfile(source):
            raise LaunchError(f"File {source} does not exist.")
        match = GCS_URI_RE.match(destination)
        if not match:
            raise LaunchError(f"Invalid GCS URI: {destination}")
        bucket_name = match.group(1)
        key = match.group(2).lstrip("/")
        if not key:
            # e.g. "gs://bucket///": the regex matched but no object name remains.
            raise LaunchError(f"Invalid GCS URI: {destination}")
        try:
            storage_client = google.cloud.storage.Client(
                credentials=self.get_credentials()
            )
            blob = storage_client.bucket(bucket_name).blob(key)
            blob.upload_from_filename(source)
        except google.api_core.exceptions.GoogleAPICallError as e:
            raise LaunchError(f"Could not upload file to GCS: {e}") from e

    def upload_dir(self, source: str, destination: str) -> None:
        """Upload a directory to GCS.

        Every file below ``source`` is uploaded under the destination path,
        preserving its path relative to ``source``.

        Arguments:
            source: The path to the local directory.
            destination: The path to the GCS directory.

        Raises:
            LaunchError: If the directory cannot be uploaded.
        """
        _logger.debug(f"Uploading directory {source} to {destination}")
        if not os.path.isdir(source):
            raise LaunchError(f"Directory {source} does not exist.")
        match = GCS_URI_RE.match(destination)
        if not match:
            raise LaunchError(f"Invalid GCS URI: {destination}")
        bucket_name = match.group(1)
        prefix = match.group(2).lstrip("/")
        try:
            storage_client = google.cloud.storage.Client(
                credentials=self.get_credentials()
            )
            bucket = storage_client.bucket(bucket_name)
            for root, _, files in os.walk(source):
                for file in files:
                    local_path = os.path.join(root, file)
                    # GCS object names always use forward slashes.
                    gcs_path = os.path.join(
                        prefix, os.path.relpath(local_path, source)
                    ).replace("\\", "/")
                    bucket.blob(gcs_path).upload_from_filename(local_path)
        except google.api_core.exceptions.GoogleAPICallError as e:
            raise LaunchError(f"Could not upload directory to GCS: {e}") from e
@@ -0,0 +1,65 @@
1
+ """Dummy local environment implementation. This is the default environment."""
2
+ from typing import Any, Dict, Union
3
+
4
+ from wandb.sdk.launch.utils import LaunchError
5
+
6
+ from .abstract import AbstractEnvironment
7
+
8
+
9
class LocalEnvironment(AbstractEnvironment):
    """Local environment class.

    A dummy, default environment with no credentials or remote storage.
    Construction always succeeds, but every other operation raises
    ``LaunchError`` because there is nothing to verify or upload to.
    """

    def __init__(self) -> None:
        """Initialize a local environment; there is no state to set up."""
        super().__init__()

    @classmethod
    def from_config(
        cls, config: Dict[str, Union[Dict[str, Any], str]]
    ) -> "LocalEnvironment":
        """Create a local environment from a config.

        Arguments:
            config (dict): The config. This is ignored.

        Returns:
            LocalEnvironment: The local environment.
        """
        return cls()

    def verify(self) -> None:
        """Reject verification; a local environment cannot be verified.

        Raises:
            LaunchError: Always.
        """
        raise LaunchError("Attempted to verify LocalEnvironment.")

    def verify_storage_uri(self, uri: str) -> None:
        """Reject storage verification; a local environment has no storage.

        Arguments:
            uri (str): The storage URI. This is ignored.

        Raises:
            LaunchError: Always.
        """
        raise LaunchError("Attempted to verify storage uri for LocalEnvironment.")

    def upload_file(self, source: str, destination: str) -> None:
        """Reject file uploads; a local environment has no remote storage.

        Arguments:
            source (str): The source file. This is ignored.
            destination (str): The destination file. This is ignored.

        Raises:
            LaunchError: Always.
        """
        raise LaunchError("Attempted to upload file for LocalEnvironment.")

    def upload_dir(self, source: str, destination: str) -> None:
        """Reject directory uploads; a local environment has no remote storage.

        Arguments:
            source (str): The source directory. This is ignored.
            destination (str): The destination directory. This is ignored.

        Raises:
            LaunchError: Always.
        """
        raise LaunchError("Attempted to upload directory for LocalEnvironment.")

    def get_project(self) -> str:
        """Reject project lookup; a local environment has no project.

        Raises:
            LaunchError: Always; this method never returns.
        """
        raise LaunchError("Attempted to get project for LocalEnvironment.")
@@ -1,6 +1,4 @@
1
- """
2
- Support for parsing GitHub URLs (which might be user provided) into constituent parts.
3
- """
1
+ """Support for parsing GitHub URLs (which might be user provided) into constituent parts."""
4
2
 
5
3
  import re
6
4
  from dataclasses import dataclass
@@ -9,7 +7,7 @@ from pathlib import Path
9
7
  from typing import Optional, Tuple
10
8
  from urllib.parse import urlparse
11
9
 
12
- from wandb.errors import LaunchError
10
+ from wandb.sdk.launch.utils import LaunchError
13
11
 
14
12
  PREFIX_HTTPS = "https://"
15
13
  PREFIX_SSH = "git@"
@@ -43,7 +41,6 @@ def _parse_netloc(netloc: str) -> Tuple[Optional[str], Optional[str], str]:
43
41
 
44
42
  @dataclass
45
43
  class GitHubReference:
46
-
47
44
  username: Optional[str] = None
48
45
  password: Optional[str] = None
49
46
  host: Optional[str] = None
@@ -107,9 +104,7 @@ class GitHubReference:
107
104
 
108
105
  @staticmethod
109
106
  def parse(uri: str) -> Optional["GitHubReference"]:
110
- """
111
- Attempt to parse a string as a GitHub URL.
112
- """
107
+ """Attempt to parse a string as a GitHub URL."""
113
108
  # Special case: git@github.com:wandb/wandb.git
114
109
  ref = GitHubReference()
115
110
  if uri.startswith(PREFIX_SSH):