wandb 0.15.4__py3-none-any.whl → 0.15.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -1
- wandb/analytics/sentry.py +1 -0
- wandb/apis/internal.py +3 -0
- wandb/apis/public.py +18 -20
- wandb/beta/workflows.py +5 -6
- wandb/cli/cli.py +27 -27
- wandb/data_types.py +2 -0
- wandb/integration/langchain/wandb_tracer.py +16 -179
- wandb/integration/sagemaker/config.py +2 -2
- wandb/integration/tensorboard/log.py +4 -4
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/wandb_deprecated.py +3 -1
- wandb/sdk/__init__.py +1 -4
- wandb/sdk/artifacts/__init__.py +0 -14
- wandb/sdk/artifacts/artifact.py +1757 -277
- wandb/sdk/artifacts/artifact_manifest_entry.py +26 -6
- wandb/sdk/artifacts/artifact_state.py +10 -0
- wandb/sdk/artifacts/artifacts_cache.py +7 -8
- wandb/sdk/artifacts/exceptions.py +4 -4
- wandb/sdk/artifacts/storage_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +16 -6
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/http_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +35 -32
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +5 -9
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +2 -2
- wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +2 -2
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +24 -16
- wandb/sdk/artifacts/storage_policy.py +3 -3
- wandb/sdk/data_types/_dtypes.py +7 -12
- wandb/sdk/data_types/base_types/json_metadata.py +2 -2
- wandb/sdk/data_types/base_types/media.py +5 -6
- wandb/sdk/data_types/base_types/wb_value.py +12 -13
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +4 -5
- wandb/sdk/data_types/helper_types/classes.py +5 -8
- wandb/sdk/data_types/helper_types/image_mask.py +4 -5
- wandb/sdk/data_types/histogram.py +3 -3
- wandb/sdk/data_types/html.py +3 -4
- wandb/sdk/data_types/image.py +4 -5
- wandb/sdk/data_types/molecule.py +2 -2
- wandb/sdk/data_types/object_3d.py +3 -3
- wandb/sdk/data_types/plotly.py +2 -2
- wandb/sdk/data_types/saved_model.py +7 -8
- wandb/sdk/data_types/trace_tree.py +4 -4
- wandb/sdk/data_types/video.py +4 -4
- wandb/sdk/interface/interface.py +8 -10
- wandb/sdk/internal/file_stream.py +2 -3
- wandb/sdk/internal/internal_api.py +99 -4
- wandb/sdk/internal/job_builder.py +15 -7
- wandb/sdk/internal/sender.py +4 -0
- wandb/sdk/internal/settings_static.py +1 -0
- wandb/sdk/launch/_project_spec.py +9 -7
- wandb/sdk/launch/agent/agent.py +115 -58
- wandb/sdk/launch/agent/job_status_tracker.py +34 -0
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
- wandb/sdk/launch/builder/abstract.py +5 -1
- wandb/sdk/launch/builder/build.py +16 -10
- wandb/sdk/launch/builder/docker_builder.py +9 -2
- wandb/sdk/launch/builder/kaniko_builder.py +108 -22
- wandb/sdk/launch/builder/noop.py +3 -1
- wandb/sdk/launch/environment/aws_environment.py +2 -1
- wandb/sdk/launch/environment/azure_environment.py +124 -0
- wandb/sdk/launch/github_reference.py +30 -18
- wandb/sdk/launch/launch.py +1 -1
- wandb/sdk/launch/loader.py +15 -0
- wandb/sdk/launch/registry/azure_container_registry.py +132 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +38 -4
- wandb/sdk/launch/registry/google_artifact_registry.py +46 -7
- wandb/sdk/launch/runner/abstract.py +19 -3
- wandb/sdk/launch/runner/kubernetes_runner.py +111 -47
- wandb/sdk/launch/runner/local_container.py +101 -48
- wandb/sdk/launch/runner/sagemaker_runner.py +59 -9
- wandb/sdk/launch/runner/vertex_runner.py +8 -4
- wandb/sdk/launch/sweeps/scheduler.py +102 -27
- wandb/sdk/launch/sweeps/utils.py +21 -0
- wandb/sdk/launch/utils.py +19 -7
- wandb/sdk/lib/_settings_toposort_generated.py +3 -0
- wandb/sdk/service/server.py +22 -9
- wandb/sdk/service/service.py +27 -8
- wandb/sdk/verify/verify.py +6 -9
- wandb/sdk/wandb_config.py +2 -4
- wandb/sdk/wandb_init.py +2 -0
- wandb/sdk/wandb_require.py +7 -0
- wandb/sdk/wandb_run.py +32 -35
- wandb/sdk/wandb_settings.py +10 -3
- wandb/testing/relay.py +15 -2
- wandb/util.py +55 -23
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/METADATA +11 -8
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/RECORD +97 -97
- wandb/integration/langchain/util.py +0 -191
- wandb/sdk/artifacts/invalid_artifact.py +0 -23
- wandb/sdk/artifacts/lazy_artifact.py +0 -162
- wandb/sdk/artifacts/local_artifact.py +0 -719
- wandb/sdk/artifacts/public_artifact.py +0 -1188
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +0 -0
- {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -8,9 +8,12 @@ import time
|
|
8
8
|
from typing import Optional
|
9
9
|
|
10
10
|
import wandb
|
11
|
+
from wandb.sdk.launch.agent.job_status_tracker import JobAndRunStatusTracker
|
11
12
|
from wandb.sdk.launch.builder.abstract import AbstractBuilder
|
12
13
|
from wandb.sdk.launch.environment.abstract import AbstractEnvironment
|
14
|
+
from wandb.sdk.launch.environment.azure_environment import AzureEnvironment
|
13
15
|
from wandb.sdk.launch.registry.abstract import AbstractRegistry
|
16
|
+
from wandb.sdk.launch.registry.azure_container_registry import AzureContainerRegistry
|
14
17
|
from wandb.sdk.launch.registry.elastic_container_registry import (
|
15
18
|
ElasticContainerRegistry,
|
16
19
|
)
|
@@ -50,13 +53,19 @@ _DEFAULT_BUILD_TIMEOUT_SECS = 1800 # 30 minute build timeout
|
|
50
53
|
|
51
54
|
SERVICE_ACCOUNT_NAME = os.environ.get("WANDB_LAUNCH_SERVICE_ACCOUNT_NAME", "default")
|
52
55
|
|
56
|
+
if os.path.exists("/var/run/secrets/kubernetes.io/serviceaccount/namespace"):
|
57
|
+
with open("/var/run/secrets/kubernetes.io/serviceaccount/namespace") as f:
|
58
|
+
NAMESPACE = f.read().strip()
|
59
|
+
else:
|
60
|
+
NAMESPACE = "wandb"
|
61
|
+
|
53
62
|
|
54
63
|
def _wait_for_completion(
|
55
64
|
batch_client: client.BatchV1Api, job_name: str, deadline_secs: Optional[int] = None
|
56
65
|
) -> bool:
|
57
66
|
start_time = time.time()
|
58
67
|
while True:
|
59
|
-
job = batch_client.read_namespaced_job_status(job_name,
|
68
|
+
job = batch_client.read_namespaced_job_status(job_name, NAMESPACE)
|
60
69
|
if job.status.succeeded is not None and job.status.succeeded >= 1:
|
61
70
|
return True
|
62
71
|
elif job.status.failed is not None and job.status.failed >= 1:
|
@@ -78,6 +87,7 @@ class KanikoBuilder(AbstractBuilder):
|
|
78
87
|
build_context_store: str
|
79
88
|
secret_name: Optional[str]
|
80
89
|
secret_key: Optional[str]
|
90
|
+
image: str
|
81
91
|
|
82
92
|
def __init__(
|
83
93
|
self,
|
@@ -87,6 +97,7 @@ class KanikoBuilder(AbstractBuilder):
|
|
87
97
|
build_context_store: str = "",
|
88
98
|
secret_name: str = "",
|
89
99
|
secret_key: str = "",
|
100
|
+
image: str = "gcr.io/kaniko-project/executor:v1.11.0",
|
90
101
|
verify: bool = True,
|
91
102
|
):
|
92
103
|
"""Initialize a KanikoBuilder.
|
@@ -113,6 +124,7 @@ class KanikoBuilder(AbstractBuilder):
|
|
113
124
|
self.build_context_store = build_context_store.rstrip("/")
|
114
125
|
self.secret_name = secret_name
|
115
126
|
self.secret_key = secret_key
|
127
|
+
self.image = image
|
116
128
|
if verify:
|
117
129
|
self.verify()
|
118
130
|
|
@@ -151,6 +163,7 @@ class KanikoBuilder(AbstractBuilder):
|
|
151
163
|
build_job_name = config.get("build-job-name", "wandb-launch-container-build")
|
152
164
|
secret_name = config.get("secret-name", "")
|
153
165
|
secret_key = config.get("secret-key", "")
|
166
|
+
image = config.get("kaniko-image", "gcr.io/kaniko-project/executor:v1.11.0")
|
154
167
|
return cls(
|
155
168
|
environment,
|
156
169
|
registry,
|
@@ -158,6 +171,7 @@ class KanikoBuilder(AbstractBuilder):
|
|
158
171
|
build_job_name=build_job_name,
|
159
172
|
secret_name=secret_name,
|
160
173
|
secret_key=secret_key,
|
174
|
+
image=image,
|
161
175
|
verify=verify,
|
162
176
|
)
|
163
177
|
|
@@ -187,7 +201,7 @@ class KanikoBuilder(AbstractBuilder):
|
|
187
201
|
kind="ConfigMap",
|
188
202
|
metadata=client.V1ObjectMeta(
|
189
203
|
name=f"docker-config-{job_name}",
|
190
|
-
namespace=
|
204
|
+
namespace=NAMESPACE,
|
191
205
|
),
|
192
206
|
data={
|
193
207
|
"config.json": json.dumps(
|
@@ -196,13 +210,13 @@ class KanikoBuilder(AbstractBuilder):
|
|
196
210
|
},
|
197
211
|
immutable=True,
|
198
212
|
)
|
199
|
-
corev1_client.create_namespaced_config_map(
|
213
|
+
corev1_client.create_namespaced_config_map(NAMESPACE, ecr_config_map)
|
200
214
|
|
201
215
|
def _delete_docker_ecr_config_map(
|
202
216
|
self, job_name: str, client: client.CoreV1Api
|
203
217
|
) -> None:
|
204
218
|
if self.secret_name:
|
205
|
-
client.delete_namespaced_config_map(f"docker-config-{job_name}",
|
219
|
+
client.delete_namespaced_config_map(f"docker-config-{job_name}", NAMESPACE)
|
206
220
|
|
207
221
|
def _upload_build_context(self, run_id: str, context_path: str) -> str:
|
208
222
|
# creat a tar archive of the build context and upload it to s3
|
@@ -220,6 +234,7 @@ class KanikoBuilder(AbstractBuilder):
|
|
220
234
|
self,
|
221
235
|
launch_project: LaunchProject,
|
222
236
|
entrypoint: EntryPoint,
|
237
|
+
job_tracker: Optional[JobAndRunStatusTracker] = None,
|
223
238
|
) -> str:
|
224
239
|
# TODO: this should probably throw an error if the registry is a local registry
|
225
240
|
if not self.registry:
|
@@ -255,35 +270,52 @@ class KanikoBuilder(AbstractBuilder):
|
|
255
270
|
_, api_client = get_kube_context_and_api_client(
|
256
271
|
kubernetes, launch_project.resource_args
|
257
272
|
)
|
273
|
+
# TODO: use same client as kuberentes_runner.py
|
274
|
+
batch_v1 = client.BatchV1Api(api_client)
|
275
|
+
core_v1 = client.CoreV1Api(api_client)
|
276
|
+
|
258
277
|
build_job_name = f"{self.build_job_name}-{run_id}"
|
259
278
|
|
260
279
|
build_context = self._upload_build_context(run_id, context_path)
|
261
280
|
build_job = self._create_kaniko_job(
|
262
|
-
build_job_name,
|
263
|
-
repo_uri,
|
264
|
-
image_uri,
|
265
|
-
build_context,
|
281
|
+
build_job_name, repo_uri, image_uri, build_context, core_v1
|
266
282
|
)
|
267
283
|
wandb.termlog(f"{LOG_PREFIX}Created kaniko job {build_job_name}")
|
268
284
|
|
269
|
-
# TODO: use same client as kuberentes.py
|
270
|
-
batch_v1 = client.BatchV1Api(api_client)
|
271
|
-
core_v1 = client.CoreV1Api(api_client)
|
272
|
-
|
273
285
|
try:
|
286
|
+
if isinstance(self.registry, AzureContainerRegistry):
|
287
|
+
dockerfile_config_map = client.V1ConfigMap(
|
288
|
+
metadata=client.V1ObjectMeta(
|
289
|
+
name=f"docker-config-{build_job_name}"
|
290
|
+
),
|
291
|
+
data={
|
292
|
+
"config.json": json.dumps(
|
293
|
+
{
|
294
|
+
"credHelpers": {
|
295
|
+
f"{self.registry.registry_name}.azurecr.io": "acr-env"
|
296
|
+
}
|
297
|
+
}
|
298
|
+
)
|
299
|
+
},
|
300
|
+
)
|
301
|
+
core_v1.create_namespaced_config_map("wandb", dockerfile_config_map)
|
274
302
|
# core_v1.create_namespaced_config_map("wandb", dockerfile_config_map)
|
275
303
|
if self.secret_name:
|
276
304
|
self._create_docker_ecr_config_map(build_job_name, core_v1, repo_uri)
|
277
|
-
batch_v1.create_namespaced_job(
|
305
|
+
batch_v1.create_namespaced_job(NAMESPACE, build_job)
|
278
306
|
|
279
307
|
# wait for double the job deadline since it might take time to schedule
|
280
308
|
if not _wait_for_completion(
|
281
309
|
batch_v1, build_job_name, 3 * _DEFAULT_BUILD_TIMEOUT_SECS
|
282
310
|
):
|
311
|
+
if job_tracker:
|
312
|
+
job_tracker.set_err_stage("build")
|
283
313
|
raise Exception(f"Failed to build image in kaniko for job {run_id}")
|
284
314
|
try:
|
285
|
-
logs = batch_v1.read_namespaced_job_log(build_job_name,
|
286
|
-
warn_failed_packages_from_build_logs(
|
315
|
+
logs = batch_v1.read_namespaced_job_log(build_job_name, NAMESPACE)
|
316
|
+
warn_failed_packages_from_build_logs(
|
317
|
+
logs, image_uri, launch_project.api, job_tracker
|
318
|
+
)
|
287
319
|
except Exception as e:
|
288
320
|
wandb.termwarn(
|
289
321
|
f"{LOG_PREFIX}Failed to get logs for kaniko job {build_job_name}: {e}"
|
@@ -298,9 +330,13 @@ class KanikoBuilder(AbstractBuilder):
|
|
298
330
|
try:
|
299
331
|
# should we clean up the s3 build contexts? can set bucket level policy to auto deletion
|
300
332
|
# core_v1.delete_namespaced_config_map(config_map_name, "wandb")
|
333
|
+
if isinstance(self.registry, AzureContainerRegistry):
|
334
|
+
core_v1.delete_namespaced_config_map(
|
335
|
+
f"docker-config-{build_job_name}", "wandb"
|
336
|
+
)
|
301
337
|
if self.secret_name:
|
302
338
|
self._delete_docker_ecr_config_map(build_job_name, core_v1)
|
303
|
-
batch_v1.delete_namespaced_job(build_job_name,
|
339
|
+
batch_v1.delete_namespaced_job(build_job_name, NAMESPACE)
|
304
340
|
except Exception as e:
|
305
341
|
raise LaunchError(f"Exception during Kubernetes resource clean up {e}")
|
306
342
|
|
@@ -312,6 +348,7 @@ class KanikoBuilder(AbstractBuilder):
|
|
312
348
|
repository: str,
|
313
349
|
image_tag: str,
|
314
350
|
build_context_path: str,
|
351
|
+
core_client: client.CoreV1Api,
|
315
352
|
) -> "client.V1Job":
|
316
353
|
env = []
|
317
354
|
volume_mounts = []
|
@@ -328,6 +365,33 @@ class KanikoBuilder(AbstractBuilder):
|
|
328
365
|
value=self.registry.environment.region,
|
329
366
|
)
|
330
367
|
]
|
368
|
+
# TODO: Refactor all of this environment/registry
|
369
|
+
# specific stuff into methods of those classes.
|
370
|
+
if isinstance(self.environment, AzureEnvironment):
|
371
|
+
# Use the core api to check if the secret exists
|
372
|
+
try:
|
373
|
+
core_client.read_namespaced_secret(
|
374
|
+
"azure-storage-access-key",
|
375
|
+
"wandb",
|
376
|
+
)
|
377
|
+
except Exception as e:
|
378
|
+
raise LaunchError(
|
379
|
+
"Secret azure-storage-access-key does not exist in "
|
380
|
+
"namespace wandb. Please create it with the key password "
|
381
|
+
"set to your azure storage access key."
|
382
|
+
) from e
|
383
|
+
env += [
|
384
|
+
client.V1EnvVar(
|
385
|
+
name="AZURE_STORAGE_ACCESS_KEY",
|
386
|
+
value_from=client.V1EnvVarSource(
|
387
|
+
secret_key_ref=client.V1SecretKeySelector(
|
388
|
+
name="azure-storage-access-key",
|
389
|
+
key="password",
|
390
|
+
)
|
391
|
+
),
|
392
|
+
)
|
393
|
+
]
|
394
|
+
|
331
395
|
if self.secret_name and self.secret_key:
|
332
396
|
volumes += [
|
333
397
|
client.V1Volume(
|
@@ -382,26 +446,48 @@ class KanikoBuilder(AbstractBuilder):
|
|
382
446
|
),
|
383
447
|
)
|
384
448
|
]
|
385
|
-
|
449
|
+
if isinstance(self.registry, AzureContainerRegistry):
|
450
|
+
# ADd the docker config map
|
451
|
+
volume_mounts += [
|
452
|
+
client.V1VolumeMount(
|
453
|
+
name="docker-config", mount_path="/kaniko/.docker/"
|
454
|
+
),
|
455
|
+
]
|
456
|
+
volumes += [
|
457
|
+
client.V1Volume(
|
458
|
+
name="docker-config",
|
459
|
+
config_map=client.V1ConfigMapVolumeSource(
|
460
|
+
name=f"docker-config-{job_name}",
|
461
|
+
),
|
462
|
+
),
|
463
|
+
]
|
464
|
+
# Kaniko doesn't want https:// at the begining of the image tag.
|
465
|
+
destination = image_tag
|
466
|
+
if destination.startswith("https://"):
|
467
|
+
destination = destination.replace("https://", "")
|
386
468
|
args = [
|
387
469
|
f"--context={build_context_path}",
|
388
470
|
"--dockerfile=Dockerfile.wandb-autogenerated",
|
389
|
-
f"--destination={
|
471
|
+
f"--destination={destination}",
|
390
472
|
"--cache=true",
|
391
|
-
f"--cache-repo={repository}",
|
473
|
+
f"--cache-repo={repository.replace('https://', '')}",
|
392
474
|
"--snapshotMode=redo",
|
393
475
|
"--compressed-caching=false",
|
394
476
|
]
|
395
477
|
container = client.V1Container(
|
396
478
|
name="wandb-container-build",
|
397
|
-
image=
|
479
|
+
image=self.image,
|
398
480
|
args=args,
|
399
481
|
volume_mounts=volume_mounts,
|
400
482
|
env=env if env else None,
|
401
483
|
)
|
402
484
|
# Create and configure a spec section
|
485
|
+
labels = {"wandb": "launch"}
|
486
|
+
# This annotation is required to enable azure workload identity.
|
487
|
+
if isinstance(self.registry, AzureContainerRegistry):
|
488
|
+
labels["azure.workload.identity/use"] = "true"
|
403
489
|
template = client.V1PodTemplateSpec(
|
404
|
-
metadata=client.V1ObjectMeta(labels=
|
490
|
+
metadata=client.V1ObjectMeta(labels=labels),
|
405
491
|
spec=client.V1PodSpec(
|
406
492
|
restart_policy="Never",
|
407
493
|
active_deadline_seconds=_DEFAULT_BUILD_TIMEOUT_SECS,
|
@@ -416,7 +502,7 @@ class KanikoBuilder(AbstractBuilder):
|
|
416
502
|
api_version="batch/v1",
|
417
503
|
kind="Job",
|
418
504
|
metadata=client.V1ObjectMeta(
|
419
|
-
name=job_name, namespace=
|
505
|
+
name=job_name, namespace=NAMESPACE, labels={"wandb": "launch"}
|
420
506
|
),
|
421
507
|
spec=spec,
|
422
508
|
)
|
wandb/sdk/launch/builder/noop.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
"""NoOp builder implementation."""
|
2
|
-
from typing import Any, Dict
|
2
|
+
from typing import Any, Dict, Optional
|
3
3
|
|
4
4
|
from wandb.sdk.launch.builder.abstract import AbstractBuilder
|
5
5
|
from wandb.sdk.launch.environment.abstract import AbstractEnvironment
|
@@ -7,6 +7,7 @@ from wandb.sdk.launch.errors import LaunchError
|
|
7
7
|
from wandb.sdk.launch.registry.abstract import AbstractRegistry
|
8
8
|
|
9
9
|
from .._project_spec import EntryPoint, LaunchProject
|
10
|
+
from ..agent.job_status_tracker import JobAndRunStatusTracker
|
10
11
|
|
11
12
|
|
12
13
|
class NoOpBuilder(AbstractBuilder):
|
@@ -43,6 +44,7 @@ class NoOpBuilder(AbstractBuilder):
|
|
43
44
|
self,
|
44
45
|
launch_project: LaunchProject,
|
45
46
|
entrypoint: EntryPoint,
|
47
|
+
job_tracker: Optional[JobAndRunStatusTracker] = None,
|
46
48
|
) -> str:
|
47
49
|
"""Build the image.
|
48
50
|
|
@@ -51,6 +51,7 @@ class AwsEnvironment(AbstractEnvironment):
|
|
51
51
|
self._access_key = access_key
|
52
52
|
self._secret_key = secret_key
|
53
53
|
self._session_token = session_token
|
54
|
+
self._account = None
|
54
55
|
if verify:
|
55
56
|
self.verify()
|
56
57
|
|
@@ -131,7 +132,7 @@ class AwsEnvironment(AbstractEnvironment):
|
|
131
132
|
try:
|
132
133
|
session = self.get_session()
|
133
134
|
client = session.client("sts")
|
134
|
-
client.get_caller_identity()
|
135
|
+
self._account = client.get_caller_identity().get("Account")
|
135
136
|
# TODO: log identity details from the response
|
136
137
|
except botocore.exceptions.ClientError as e:
|
137
138
|
raise LaunchError(
|
@@ -0,0 +1,124 @@
|
|
1
|
+
"""Implementation of AzureEnvironment class."""
|
2
|
+
|
3
|
+
import re
|
4
|
+
from typing import TYPE_CHECKING, Tuple
|
5
|
+
|
6
|
+
if TYPE_CHECKING:
|
7
|
+
from azure.identity import DefaultAzureCredential # type: ignore
|
8
|
+
from azure.storage.blob import BlobClient, BlobServiceClient # type: ignore
|
9
|
+
|
10
|
+
from wandb.util import get_module
|
11
|
+
|
12
|
+
from ..errors import LaunchError
|
13
|
+
from .abstract import AbstractEnvironment
|
14
|
+
|
15
|
+
AZURE_BLOB_REGEX = re.compile(
|
16
|
+
r"^https://([^\.]+)\.blob\.core\.windows\.net/([^/]+)/?(.*)$"
|
17
|
+
)
|
18
|
+
|
19
|
+
|
20
|
+
DefaultAzureCredential = get_module( # noqa: F811
|
21
|
+
"azure.identity",
|
22
|
+
required="The azure-identity package is required to use launch with Azure. Please install it with `pip install azure-identity`.",
|
23
|
+
).DefaultAzureCredential
|
24
|
+
blob = get_module(
|
25
|
+
"azure.storage.blob",
|
26
|
+
required="The azure-storage-blob package is required to use launch with Azure. Please install it with `pip install azure-storage-blob`.",
|
27
|
+
)
|
28
|
+
BlobClient, BlobServiceClient = blob.BlobClient, blob.BlobServiceClient # noqa: F811
|
29
|
+
|
30
|
+
|
31
|
+
class AzureEnvironment(AbstractEnvironment):
|
32
|
+
"""AzureEnvironment is a helper for accessing Azure resources."""
|
33
|
+
|
34
|
+
def __init__(
|
35
|
+
self,
|
36
|
+
verify: bool = True,
|
37
|
+
):
|
38
|
+
"""Initialize an AzureEnvironment."""
|
39
|
+
if verify:
|
40
|
+
self.verify()
|
41
|
+
|
42
|
+
@classmethod
|
43
|
+
def from_config(cls, config: dict, verify: bool = True) -> "AzureEnvironment":
|
44
|
+
"""Create an AzureEnvironment from a config dict."""
|
45
|
+
return cls(verify=verify)
|
46
|
+
|
47
|
+
@classmethod
|
48
|
+
def get_credentials(cls) -> DefaultAzureCredential:
|
49
|
+
"""Get Azure credentials."""
|
50
|
+
try:
|
51
|
+
return DefaultAzureCredential()
|
52
|
+
except Exception as e:
|
53
|
+
raise LaunchError(
|
54
|
+
"Could not get Azure credentials. Please make sure you have "
|
55
|
+
"configured your Azure CLI correctly."
|
56
|
+
) from e
|
57
|
+
|
58
|
+
def upload_file(self, source: str, destination: str) -> None:
|
59
|
+
"""Upload a file to Azure blob storage.
|
60
|
+
|
61
|
+
Arguments:
|
62
|
+
source (str): The path to the file to upload.
|
63
|
+
destination (str): The destination path in Azure blob storage. Ex:
|
64
|
+
https://<storage_account>.blob.core.windows.net/<storage_container>/<path>
|
65
|
+
Raise:
|
66
|
+
LaunchError: If the file could not be uploaded.
|
67
|
+
"""
|
68
|
+
storage_account, storage_container, path = self.parse_uri(destination)
|
69
|
+
creds = self.get_credentials()
|
70
|
+
try:
|
71
|
+
client = BlobClient(
|
72
|
+
f"https://{storage_account}.blob.core.windows.net",
|
73
|
+
storage_container,
|
74
|
+
path,
|
75
|
+
credential=creds,
|
76
|
+
)
|
77
|
+
with open(source, "rb") as f:
|
78
|
+
client.upload_blob(f)
|
79
|
+
except Exception as e:
|
80
|
+
raise LaunchError(
|
81
|
+
f"Could not upload file {source} to Azure blob {destination}."
|
82
|
+
) from e
|
83
|
+
|
84
|
+
def upload_dir(self, source: str, destination: str) -> None:
|
85
|
+
"""Upload a directory to Azure blob storage."""
|
86
|
+
raise NotImplementedError()
|
87
|
+
|
88
|
+
def verify_storage_uri(self, uri: str) -> None:
|
89
|
+
"""Verify that the given blob storage prefix exists.
|
90
|
+
|
91
|
+
Args:
|
92
|
+
uri (str): The URI to verify.
|
93
|
+
"""
|
94
|
+
creds = self.get_credentials()
|
95
|
+
storage_account, storage_container, _ = self.parse_uri(uri)
|
96
|
+
try:
|
97
|
+
client = BlobServiceClient(
|
98
|
+
f"https://{storage_account}.blob.core.windows.net",
|
99
|
+
credential=creds,
|
100
|
+
)
|
101
|
+
client.get_container_client(storage_container)
|
102
|
+
except Exception as e:
|
103
|
+
raise LaunchError(
|
104
|
+
f"Could not verify storage URI {uri} in container {storage_container}."
|
105
|
+
) from e
|
106
|
+
|
107
|
+
def verify(self) -> None:
|
108
|
+
"""Verify that the AzureEnvironment is valid."""
|
109
|
+
self.get_credentials()
|
110
|
+
|
111
|
+
@staticmethod
|
112
|
+
def parse_uri(uri: str) -> Tuple[str, str, str]:
|
113
|
+
"""Parse an Azure blob storage URI into a storage account and container.
|
114
|
+
|
115
|
+
Args:
|
116
|
+
uri (str): The URI to parse.
|
117
|
+
|
118
|
+
Returns:
|
119
|
+
Tuple[str, str]: The storage account and container.
|
120
|
+
"""
|
121
|
+
match = AZURE_BLOB_REGEX.match(uri)
|
122
|
+
if match is None:
|
123
|
+
raise LaunchError(f"Could not parse Azure blob URI {uri}.")
|
124
|
+
return match.group(1), match.group(2), match.group(3)
|
@@ -58,6 +58,7 @@ class GitHubReference:
|
|
58
58
|
|
59
59
|
ref: Optional[str] = None # branch or commit
|
60
60
|
ref_type: Optional[ReferenceType] = None
|
61
|
+
commit_hash: Optional[str] = None # hash of commit
|
61
62
|
|
62
63
|
directory: Optional[str] = None
|
63
64
|
file: Optional[str] = None
|
@@ -68,6 +69,7 @@ class GitHubReference:
|
|
68
69
|
self.ref_type = None
|
69
70
|
self.ref = ref
|
70
71
|
|
72
|
+
@property
|
71
73
|
def url_host(self) -> str:
|
72
74
|
assert self.host
|
73
75
|
auth = self.username or ""
|
@@ -77,19 +79,23 @@ class GitHubReference:
|
|
77
79
|
auth += "@"
|
78
80
|
return f"{PREFIX_HTTPS}{auth}{self.host}"
|
79
81
|
|
82
|
+
@property
|
80
83
|
def url_organization(self) -> str:
|
81
84
|
assert self.organization
|
82
|
-
return f"{self.url_host
|
85
|
+
return f"{self.url_host}/{self.organization}"
|
83
86
|
|
87
|
+
@property
|
84
88
|
def url_repo(self) -> str:
|
85
89
|
assert self.repo
|
86
|
-
return f"{self.url_organization
|
90
|
+
return f"{self.url_organization}/{self.repo}"
|
87
91
|
|
92
|
+
@property
|
88
93
|
def repo_ssh(self) -> str:
|
89
94
|
return f"{PREFIX_SSH}{self.host}:{self.organization}/{self.repo}{SUFFIX_GIT}"
|
90
95
|
|
96
|
+
@property
|
91
97
|
def url(self) -> str:
|
92
|
-
url = self.url_repo
|
98
|
+
url = self.url_repo
|
93
99
|
if self.view:
|
94
100
|
url += f"/{self.view}"
|
95
101
|
if self.ref:
|
@@ -98,7 +104,7 @@ class GitHubReference:
|
|
98
104
|
url += f"/{self.directory}"
|
99
105
|
if self.file:
|
100
106
|
url += f"/{self.file}"
|
101
|
-
|
107
|
+
if self.path:
|
102
108
|
url += f"/{self.path}"
|
103
109
|
return url
|
104
110
|
|
@@ -127,18 +133,21 @@ class GitHubReference:
|
|
127
133
|
ref.username, ref.password, ref.host = _parse_netloc(parsed.netloc)
|
128
134
|
|
129
135
|
parts = parsed.path.split("/")
|
130
|
-
if len(parts)
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
136
|
+
if len(parts) < 2:
|
137
|
+
return ref
|
138
|
+
if parts[1] == "orgs" and len(parts) > 2:
|
139
|
+
ref.organization = parts[2]
|
140
|
+
return ref
|
141
|
+
ref.organization = parts[1]
|
142
|
+
if len(parts) < 3:
|
143
|
+
return ref
|
144
|
+
repo = parts[2]
|
145
|
+
if repo.endswith(SUFFIX_GIT):
|
146
|
+
repo = repo[: -len(SUFFIX_GIT)]
|
147
|
+
ref.repo = repo
|
148
|
+
ref.view = parts[3] if len(parts) > 3 else None
|
149
|
+
ref.path = "/".join(parts[4:])
|
150
|
+
|
142
151
|
return ref
|
143
152
|
|
144
153
|
def fetch(self, dst_dir: str) -> None:
|
@@ -148,7 +157,7 @@ class GitHubReference:
|
|
148
157
|
import git # type: ignore
|
149
158
|
|
150
159
|
repo = git.Repo.init(dst_dir)
|
151
|
-
origin = repo.create_remote("origin", self.url_repo
|
160
|
+
origin = repo.create_remote("origin", self.url_repo)
|
152
161
|
|
153
162
|
# We fetch the origin so that we have branch and tag references
|
154
163
|
origin.fetch(depth=1)
|
@@ -165,6 +174,7 @@ class GitHubReference:
|
|
165
174
|
self.path = self.path[len(first_segment) + 1 :]
|
166
175
|
head = repo.create_head(first_segment, commit)
|
167
176
|
head.checkout()
|
177
|
+
self.commit_hash = head.commit.hexsha
|
168
178
|
except ValueError:
|
169
179
|
# Apparently it just looked like a commit
|
170
180
|
pass
|
@@ -188,6 +198,7 @@ class GitHubReference:
|
|
188
198
|
self.path = self.path[len(refname) + 1 :]
|
189
199
|
head = repo.create_head(branch, origin.refs[branch])
|
190
200
|
head.checkout()
|
201
|
+
self.commit_hash = head.commit.hexsha
|
191
202
|
break
|
192
203
|
|
193
204
|
# Must be on default branch. Try to figure out what that is.
|
@@ -209,11 +220,12 @@ class GitHubReference:
|
|
209
220
|
# (While the references appear to be sorted, not clear if that's guaranteed.)
|
210
221
|
if not default_branch:
|
211
222
|
raise LaunchError(
|
212
|
-
f"Unable to determine branch or commit to checkout from {self.url
|
223
|
+
f"Unable to determine branch or commit to checkout from {self.url}"
|
213
224
|
)
|
214
225
|
self.default_branch = default_branch
|
215
226
|
head = repo.create_head(default_branch, origin.refs[default_branch])
|
216
227
|
head.checkout()
|
228
|
+
self.commit_hash = head.commit.hexsha
|
217
229
|
repo.submodule_update(init=True, recursive=True)
|
218
230
|
|
219
231
|
# Now that we've checked something out, try to extract directory and file from what remains
|
wandb/sdk/launch/launch.py
CHANGED
@@ -175,7 +175,7 @@ def _run(
|
|
175
175
|
builder = loader.builder_from_config(build_config, environment, registry)
|
176
176
|
backend = loader.runner_from_config(resource, api, runner_config, environment)
|
177
177
|
if backend:
|
178
|
-
submitted_run = backend.run(launch_project, builder)
|
178
|
+
submitted_run = backend.run(launch_project, builder, None)
|
179
179
|
# this check will always pass, run is only optional in the agent case where
|
180
180
|
# a run queue id is present on the backend config
|
181
181
|
assert submitted_run
|
wandb/sdk/launch/loader.py
CHANGED
@@ -54,6 +54,10 @@ def environment_from_config(config: Optional[Dict[str, Any]]) -> AbstractEnviron
|
|
54
54
|
from .environment.gcp_environment import GcpEnvironment
|
55
55
|
|
56
56
|
return GcpEnvironment.from_config(config)
|
57
|
+
if env_type == "azure":
|
58
|
+
from .environment.azure_environment import AzureEnvironment
|
59
|
+
|
60
|
+
return AzureEnvironment.from_config(config)
|
57
61
|
raise LaunchError(
|
58
62
|
f"Could not create environment from config. Invalid type: {env_type}"
|
59
63
|
)
|
@@ -110,6 +114,17 @@ def registry_from_config(
|
|
110
114
|
from .registry.google_artifact_registry import GoogleArtifactRegistry
|
111
115
|
|
112
116
|
return GoogleArtifactRegistry.from_config(config, environment)
|
117
|
+
if registry_type == "acr":
|
118
|
+
from .environment.azure_environment import AzureEnvironment
|
119
|
+
|
120
|
+
if not isinstance(environment, AzureEnvironment):
|
121
|
+
raise LaunchError(
|
122
|
+
"Could not create ACR registry. "
|
123
|
+
"Environment must be an instance of AzureEnvironment."
|
124
|
+
)
|
125
|
+
from .registry.azure_container_registry import AzureContainerRegistry
|
126
|
+
|
127
|
+
return AzureContainerRegistry.from_config(config, environment)
|
113
128
|
raise LaunchError(
|
114
129
|
f"Could not create registry from config. Invalid registry type: {registry_type}"
|
115
130
|
)
|