wandb 0.15.3__py3-none-any.whl → 0.15.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -1
- wandb/analytics/sentry.py +1 -0
- wandb/apis/importers/base.py +20 -5
- wandb/apis/importers/mlflow.py +7 -1
- wandb/apis/internal.py +12 -0
- wandb/apis/public.py +247 -1387
- wandb/apis/reports/_panels.py +58 -35
- wandb/beta/workflows.py +6 -7
- wandb/cli/cli.py +130 -60
- wandb/data_types.py +3 -1
- wandb/filesync/dir_watcher.py +21 -27
- wandb/filesync/step_checksum.py +8 -8
- wandb/filesync/step_prepare.py +23 -10
- wandb/filesync/step_upload.py +13 -13
- wandb/filesync/upload_job.py +4 -8
- wandb/integration/cohere/__init__.py +3 -0
- wandb/integration/cohere/cohere.py +21 -0
- wandb/integration/cohere/resolver.py +347 -0
- wandb/integration/gym/__init__.py +4 -6
- wandb/integration/huggingface/__init__.py +3 -0
- wandb/integration/huggingface/huggingface.py +18 -0
- wandb/integration/huggingface/resolver.py +213 -0
- wandb/integration/langchain/wandb_tracer.py +16 -179
- wandb/integration/openai/__init__.py +1 -3
- wandb/integration/openai/openai.py +11 -143
- wandb/integration/openai/resolver.py +111 -38
- wandb/integration/sagemaker/config.py +2 -2
- wandb/integration/tensorboard/log.py +4 -4
- wandb/old/settings.py +24 -7
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/wandb_deprecated.py +3 -1
- wandb/sdk/__init__.py +1 -1
- wandb/sdk/artifacts/__init__.py +0 -0
- wandb/sdk/artifacts/artifact.py +2101 -0
- wandb/sdk/artifacts/artifact_download_logger.py +42 -0
- wandb/sdk/artifacts/artifact_manifest.py +67 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
- wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
- wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
- wandb/sdk/artifacts/artifact_state.py +10 -0
- wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
- wandb/sdk/artifacts/exceptions.py +55 -0
- wandb/sdk/artifacts/storage_handler.py +59 -0
- wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
- wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
- wandb/sdk/artifacts/storage_layout.py +6 -0
- wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
- wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
- wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
- wandb/sdk/data_types/_dtypes.py +7 -12
- wandb/sdk/data_types/base_types/json_metadata.py +3 -2
- wandb/sdk/data_types/base_types/media.py +8 -8
- wandb/sdk/data_types/base_types/wb_value.py +12 -13
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
- wandb/sdk/data_types/helper_types/classes.py +6 -8
- wandb/sdk/data_types/helper_types/image_mask.py +5 -6
- wandb/sdk/data_types/histogram.py +4 -3
- wandb/sdk/data_types/html.py +3 -4
- wandb/sdk/data_types/image.py +11 -9
- wandb/sdk/data_types/molecule.py +5 -3
- wandb/sdk/data_types/object_3d.py +7 -5
- wandb/sdk/data_types/plotly.py +3 -2
- wandb/sdk/data_types/saved_model.py +11 -11
- wandb/sdk/data_types/trace_tree.py +5 -4
- wandb/sdk/data_types/utils.py +3 -5
- wandb/sdk/data_types/video.py +5 -4
- wandb/sdk/integration_utils/auto_logging.py +215 -0
- wandb/sdk/interface/interface.py +15 -15
- wandb/sdk/internal/file_pusher.py +8 -16
- wandb/sdk/internal/file_stream.py +5 -11
- wandb/sdk/internal/handler.py +13 -1
- wandb/sdk/internal/internal_api.py +287 -13
- wandb/sdk/internal/job_builder.py +119 -30
- wandb/sdk/internal/sender.py +6 -26
- wandb/sdk/internal/settings_static.py +2 -0
- wandb/sdk/internal/system/assets/__init__.py +2 -0
- wandb/sdk/internal/system/assets/gpu.py +42 -0
- wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
- wandb/sdk/internal/system/env_probe_helpers.py +13 -0
- wandb/sdk/internal/system/system_info.py +3 -3
- wandb/sdk/internal/tb_watcher.py +32 -22
- wandb/sdk/internal/thread_local_settings.py +18 -0
- wandb/sdk/launch/_project_spec.py +57 -11
- wandb/sdk/launch/agent/agent.py +147 -65
- wandb/sdk/launch/agent/job_status_tracker.py +34 -0
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
- wandb/sdk/launch/builder/abstract.py +5 -1
- wandb/sdk/launch/builder/build.py +21 -18
- wandb/sdk/launch/builder/docker_builder.py +10 -4
- wandb/sdk/launch/builder/kaniko_builder.py +113 -23
- wandb/sdk/launch/builder/noop.py +6 -3
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
- wandb/sdk/launch/environment/aws_environment.py +3 -2
- wandb/sdk/launch/environment/azure_environment.py +124 -0
- wandb/sdk/launch/environment/gcp_environment.py +2 -4
- wandb/sdk/launch/environment/local_environment.py +1 -1
- wandb/sdk/launch/errors.py +19 -0
- wandb/sdk/launch/github_reference.py +32 -19
- wandb/sdk/launch/launch.py +3 -8
- wandb/sdk/launch/launch_add.py +6 -2
- wandb/sdk/launch/loader.py +21 -2
- wandb/sdk/launch/registry/azure_container_registry.py +132 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
- wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
- wandb/sdk/launch/registry/local_registry.py +2 -1
- wandb/sdk/launch/runner/abstract.py +24 -3
- wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
- wandb/sdk/launch/runner/local_container.py +103 -51
- wandb/sdk/launch/runner/local_process.py +1 -1
- wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
- wandb/sdk/launch/runner/vertex_runner.py +10 -5
- wandb/sdk/launch/sweeps/__init__.py +7 -9
- wandb/sdk/launch/sweeps/scheduler.py +307 -77
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +82 -35
- wandb/sdk/launch/utils.py +89 -75
- wandb/sdk/lib/_settings_toposort_generated.py +7 -0
- wandb/sdk/lib/capped_dict.py +26 -0
- wandb/sdk/lib/{git.py → gitlib.py} +76 -59
- wandb/sdk/lib/hashutil.py +12 -4
- wandb/sdk/lib/paths.py +96 -8
- wandb/sdk/lib/sock_client.py +2 -2
- wandb/sdk/lib/timer.py +1 -0
- wandb/sdk/service/server.py +22 -9
- wandb/sdk/service/server_sock.py +1 -1
- wandb/sdk/service/service.py +27 -8
- wandb/sdk/verify/verify.py +4 -7
- wandb/sdk/wandb_config.py +2 -6
- wandb/sdk/wandb_init.py +57 -53
- wandb/sdk/wandb_require.py +7 -0
- wandb/sdk/wandb_run.py +61 -223
- wandb/sdk/wandb_settings.py +28 -4
- wandb/testing/relay.py +15 -2
- wandb/util.py +74 -36
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
- wandb/integration/langchain/util.py +0 -191
- wandb/sdk/interface/artifacts/__init__.py +0 -33
- wandb/sdk/interface/artifacts/artifact.py +0 -615
- wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
- wandb/sdk/wandb_artifacts.py +0 -2226
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -5,7 +5,7 @@ import os
|
|
5
5
|
import re
|
6
6
|
from typing import Dict
|
7
7
|
|
8
|
-
from wandb.sdk.launch.
|
8
|
+
from wandb.sdk.launch.errors import LaunchError
|
9
9
|
from wandb.util import get_module
|
10
10
|
|
11
11
|
from .abstract import AbstractEnvironment
|
@@ -51,6 +51,7 @@ class AwsEnvironment(AbstractEnvironment):
|
|
51
51
|
self._access_key = access_key
|
52
52
|
self._secret_key = secret_key
|
53
53
|
self._session_token = session_token
|
54
|
+
self._account = None
|
54
55
|
if verify:
|
55
56
|
self.verify()
|
56
57
|
|
@@ -131,7 +132,7 @@ class AwsEnvironment(AbstractEnvironment):
|
|
131
132
|
try:
|
132
133
|
session = self.get_session()
|
133
134
|
client = session.client("sts")
|
134
|
-
client.get_caller_identity()
|
135
|
+
self._account = client.get_caller_identity().get("Account")
|
135
136
|
# TODO: log identity details from the response
|
136
137
|
except botocore.exceptions.ClientError as e:
|
137
138
|
raise LaunchError(
|
@@ -0,0 +1,124 @@
|
|
1
|
+
"""Implementation of AzureEnvironment class."""
|
2
|
+
|
3
|
+
import re
|
4
|
+
from typing import TYPE_CHECKING, Tuple
|
5
|
+
|
6
|
+
if TYPE_CHECKING:
|
7
|
+
from azure.identity import DefaultAzureCredential # type: ignore
|
8
|
+
from azure.storage.blob import BlobClient, BlobServiceClient # type: ignore
|
9
|
+
|
10
|
+
from wandb.util import get_module
|
11
|
+
|
12
|
+
from ..errors import LaunchError
|
13
|
+
from .abstract import AbstractEnvironment
|
14
|
+
|
15
|
+
# Blob storage URIs look like
#   https://<storage_account>.blob.core.windows.net/<container>[/<path>]
# Groups: (1) storage account, (2) container, (3) path, which may be empty.
AZURE_BLOB_REGEX = re.compile(
    r"^https://([^\.]+)\.blob\.core\.windows\.net/([^/]+)/?(.*)$"
)


# The names below intentionally shadow the TYPE_CHECKING-only imports at the
# top of the module (hence the F811 suppressions): static type checkers see the
# real SDK classes, while at runtime they are resolved through wandb's
# get_module. The `required` text is the install hint associated with each
# optional package.
DefaultAzureCredential = get_module(  # noqa: F811
    "azure.identity",
    required="The azure-identity package is required to use launch with Azure. Please install it with `pip install azure-identity`.",
).DefaultAzureCredential
blob = get_module(
    "azure.storage.blob",
    required="The azure-storage-blob package is required to use launch with Azure. Please install it with `pip install azure-storage-blob`.",
)
BlobClient = blob.BlobClient  # noqa: F811
BlobServiceClient = blob.BlobServiceClient  # noqa: F811
|
29
|
+
|
30
|
+
|
31
|
+
class AzureEnvironment(AbstractEnvironment):
    """AzureEnvironment is a helper for accessing Azure resources.

    Credentials are obtained via ``DefaultAzureCredential``. Blob storage URIs
    are expected in the form
    ``https://<storage_account>.blob.core.windows.net/<container>/<path>``.
    """

    def __init__(
        self,
        verify: bool = True,
    ):
        """Initialize an AzureEnvironment.

        Args:
            verify (bool, optional): If True, call ``verify()`` during
                construction so credential problems surface early.
                Defaults to True.
        """
        if verify:
            self.verify()

    @classmethod
    def from_config(cls, config: dict, verify: bool = True) -> "AzureEnvironment":
        """Create an AzureEnvironment from a config dict.

        Args:
            config (dict): The environment config. No keys are currently read
                from it.
            verify (bool, optional): Forwarded to ``__init__``. Defaults to True.

        Returns:
            AzureEnvironment: The new environment.
        """
        return cls(verify=verify)

    @classmethod
    def get_credentials(cls) -> DefaultAzureCredential:
        """Get Azure credentials.

        Returns:
            DefaultAzureCredential: A freshly constructed credential object.

        Raises:
            LaunchError: If the credential object could not be created.
        """
        try:
            return DefaultAzureCredential()
        except Exception as e:
            raise LaunchError(
                "Could not get Azure credentials. Please make sure you have "
                "configured your Azure CLI correctly."
            ) from e

    def upload_file(self, source: str, destination: str) -> None:
        """Upload a file to Azure blob storage.

        Args:
            source (str): The path to the file to upload.
            destination (str): The destination path in Azure blob storage. Ex:
                https://<storage_account>.blob.core.windows.net/<storage_container>/<path>

        Raises:
            LaunchError: If the destination cannot be parsed or names no blob,
                credentials cannot be obtained, or the file could not be
                uploaded.
        """
        storage_account, storage_container, path = self.parse_uri(destination)
        if not path:
            # A blob name is mandatory for an upload; report that directly
            # rather than relying on the SDK to reject an empty blob name.
            raise LaunchError(
                f"Azure blob destination {destination} must include a blob path."
            )
        creds = self.get_credentials()
        try:
            client = BlobClient(
                f"https://{storage_account}.blob.core.windows.net",
                storage_container,
                path,
                credential=creds,
            )
            with open(source, "rb") as f:
                client.upload_blob(f)
        except Exception as e:
            raise LaunchError(
                f"Could not upload file {source} to Azure blob {destination}."
            ) from e

    def upload_dir(self, source: str, destination: str) -> None:
        """Upload a directory to Azure blob storage.

        Raises:
            NotImplementedError: Always; directory upload is not supported yet.
        """
        raise NotImplementedError()

    def verify_storage_uri(self, uri: str) -> None:
        """Verify that the container referenced by a blob storage URI exists.

        Only the storage account and container are checked; the path portion
        of ``uri`` is not inspected.

        Args:
            uri (str): The URI to verify.

        Raises:
            LaunchError: If the URI cannot be parsed, credentials cannot be
                obtained, or the container cannot be reached.
        """
        creds = self.get_credentials()
        storage_account, storage_container, _ = self.parse_uri(uri)
        try:
            client = BlobServiceClient(
                f"https://{storage_account}.blob.core.windows.net",
                credential=creds,
            )
            # get_container_client only builds a local client object (the
            # container need not exist), so request the container's properties
            # to actually confirm it exists and is accessible with these creds.
            client.get_container_client(storage_container).get_container_properties()
        except Exception as e:
            raise LaunchError(
                f"Could not verify storage URI {uri} in container {storage_container}."
            ) from e

    def verify(self) -> None:
        """Verify that the AzureEnvironment is valid.

        Currently this only checks that credentials can be constructed.

        Raises:
            LaunchError: If credentials could not be obtained.
        """
        self.get_credentials()

    @staticmethod
    def parse_uri(uri: str) -> Tuple[str, str, str]:
        """Parse an Azure blob storage URI into account, container and path.

        Args:
            uri (str): The URI to parse.

        Returns:
            Tuple[str, str, str]: The storage account, the container, and the
                blob path within the container (empty string if absent).

        Raises:
            LaunchError: If the URI does not match the expected blob URI form.
        """
        match = AZURE_BLOB_REGEX.match(uri)
        if match is None:
            raise LaunchError(f"Could not parse Azure blob URI {uri}.")
        return match.group(1), match.group(2), match.group(3)
|
@@ -3,7 +3,7 @@ import logging
|
|
3
3
|
import os
|
4
4
|
import re
|
5
5
|
|
6
|
-
from wandb.sdk.launch.
|
6
|
+
from wandb.sdk.launch.errors import LaunchError
|
7
7
|
from wandb.util import get_module
|
8
8
|
|
9
9
|
from .abstract import AbstractEnvironment
|
@@ -202,7 +202,7 @@ class GcpEnvironment(AbstractEnvironment):
|
|
202
202
|
storage_client = google.cloud.storage.Client(
|
203
203
|
credentials=self.get_credentials()
|
204
204
|
)
|
205
|
-
bucket = storage_client.
|
205
|
+
bucket = storage_client.get_bucket(bucket)
|
206
206
|
except google.api_core.exceptions.NotFound as e:
|
207
207
|
raise LaunchError(f"Bucket {bucket} does not exist.") from e
|
208
208
|
|
@@ -267,5 +267,3 @@ class GcpEnvironment(AbstractEnvironment):
|
|
267
267
|
blob.upload_from_filename(local_path)
|
268
268
|
except google.api_core.exceptions.GoogleAPICallError as e:
|
269
269
|
raise LaunchError(f"Could not upload directory to GCS: {e}") from e
|
270
|
-
raise LaunchError(f"Could not upload directory to GCS: {e}") from e
|
271
|
-
raise LaunchError(f"Could not upload directory to GCS: {e}") from e
|
@@ -0,0 +1,19 @@
|
|
1
|
+
from wandb.errors import Error
|
2
|
+
|
3
|
+
|
4
|
+
class LaunchError(Error):
    """Signal an anticipated, known failure inside wandb launch."""
|
8
|
+
|
9
|
+
|
10
|
+
class LaunchDockerError(Error):
    """Signal that the Docker daemon is not running."""
|
14
|
+
|
15
|
+
|
16
|
+
class ExecutionError(Error):
    """General-purpose exception for execution failures."""
|
@@ -7,7 +7,7 @@ from pathlib import Path
|
|
7
7
|
from typing import Optional, Tuple
|
8
8
|
from urllib.parse import urlparse
|
9
9
|
|
10
|
-
from wandb.sdk.launch.
|
10
|
+
from wandb.sdk.launch.errors import LaunchError
|
11
11
|
|
12
12
|
PREFIX_HTTPS = "https://"
|
13
13
|
PREFIX_SSH = "git@"
|
@@ -58,6 +58,7 @@ class GitHubReference:
|
|
58
58
|
|
59
59
|
ref: Optional[str] = None # branch or commit
|
60
60
|
ref_type: Optional[ReferenceType] = None
|
61
|
+
commit_hash: Optional[str] = None # hash of commit
|
61
62
|
|
62
63
|
directory: Optional[str] = None
|
63
64
|
file: Optional[str] = None
|
@@ -68,6 +69,7 @@ class GitHubReference:
|
|
68
69
|
self.ref_type = None
|
69
70
|
self.ref = ref
|
70
71
|
|
72
|
+
@property
|
71
73
|
def url_host(self) -> str:
|
72
74
|
assert self.host
|
73
75
|
auth = self.username or ""
|
@@ -77,19 +79,23 @@ class GitHubReference:
|
|
77
79
|
auth += "@"
|
78
80
|
return f"{PREFIX_HTTPS}{auth}{self.host}"
|
79
81
|
|
82
|
+
@property
|
80
83
|
def url_organization(self) -> str:
|
81
84
|
assert self.organization
|
82
|
-
return f"{self.url_host
|
85
|
+
return f"{self.url_host}/{self.organization}"
|
83
86
|
|
87
|
+
@property
|
84
88
|
def url_repo(self) -> str:
|
85
89
|
assert self.repo
|
86
|
-
return f"{self.url_organization
|
90
|
+
return f"{self.url_organization}/{self.repo}"
|
87
91
|
|
92
|
+
@property
|
88
93
|
def repo_ssh(self) -> str:
|
89
94
|
return f"{PREFIX_SSH}{self.host}:{self.organization}/{self.repo}{SUFFIX_GIT}"
|
90
95
|
|
96
|
+
@property
|
91
97
|
def url(self) -> str:
|
92
|
-
url = self.url_repo
|
98
|
+
url = self.url_repo
|
93
99
|
if self.view:
|
94
100
|
url += f"/{self.view}"
|
95
101
|
if self.ref:
|
@@ -98,7 +104,7 @@ class GitHubReference:
|
|
98
104
|
url += f"/{self.directory}"
|
99
105
|
if self.file:
|
100
106
|
url += f"/{self.file}"
|
101
|
-
|
107
|
+
if self.path:
|
102
108
|
url += f"/{self.path}"
|
103
109
|
return url
|
104
110
|
|
@@ -127,18 +133,21 @@ class GitHubReference:
|
|
127
133
|
ref.username, ref.password, ref.host = _parse_netloc(parsed.netloc)
|
128
134
|
|
129
135
|
parts = parsed.path.split("/")
|
130
|
-
if len(parts)
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
136
|
+
if len(parts) < 2:
|
137
|
+
return ref
|
138
|
+
if parts[1] == "orgs" and len(parts) > 2:
|
139
|
+
ref.organization = parts[2]
|
140
|
+
return ref
|
141
|
+
ref.organization = parts[1]
|
142
|
+
if len(parts) < 3:
|
143
|
+
return ref
|
144
|
+
repo = parts[2]
|
145
|
+
if repo.endswith(SUFFIX_GIT):
|
146
|
+
repo = repo[: -len(SUFFIX_GIT)]
|
147
|
+
ref.repo = repo
|
148
|
+
ref.view = parts[3] if len(parts) > 3 else None
|
149
|
+
ref.path = "/".join(parts[4:])
|
150
|
+
|
142
151
|
return ref
|
143
152
|
|
144
153
|
def fetch(self, dst_dir: str) -> None:
|
@@ -148,7 +157,7 @@ class GitHubReference:
|
|
148
157
|
import git # type: ignore
|
149
158
|
|
150
159
|
repo = git.Repo.init(dst_dir)
|
151
|
-
origin = repo.create_remote("origin", self.url_repo
|
160
|
+
origin = repo.create_remote("origin", self.url_repo)
|
152
161
|
|
153
162
|
# We fetch the origin so that we have branch and tag references
|
154
163
|
origin.fetch(depth=1)
|
@@ -165,6 +174,7 @@ class GitHubReference:
|
|
165
174
|
self.path = self.path[len(first_segment) + 1 :]
|
166
175
|
head = repo.create_head(first_segment, commit)
|
167
176
|
head.checkout()
|
177
|
+
self.commit_hash = head.commit.hexsha
|
168
178
|
except ValueError:
|
169
179
|
# Apparently it just looked like a commit
|
170
180
|
pass
|
@@ -188,6 +198,7 @@ class GitHubReference:
|
|
188
198
|
self.path = self.path[len(refname) + 1 :]
|
189
199
|
head = repo.create_head(branch, origin.refs[branch])
|
190
200
|
head.checkout()
|
201
|
+
self.commit_hash = head.commit.hexsha
|
191
202
|
break
|
192
203
|
|
193
204
|
# Must be on default branch. Try to figure out what that is.
|
@@ -209,11 +220,13 @@ class GitHubReference:
|
|
209
220
|
# (While the references appear to be sorted, not clear if that's guaranteed.)
|
210
221
|
if not default_branch:
|
211
222
|
raise LaunchError(
|
212
|
-
f"Unable to determine branch or commit to checkout from {self.url
|
223
|
+
f"Unable to determine branch or commit to checkout from {self.url}"
|
213
224
|
)
|
214
225
|
self.default_branch = default_branch
|
215
226
|
head = repo.create_head(default_branch, origin.refs[default_branch])
|
216
227
|
head.checkout()
|
228
|
+
self.commit_hash = head.commit.hexsha
|
229
|
+
repo.submodule_update(init=True, recursive=True)
|
217
230
|
|
218
231
|
# Now that we've checked something out, try to extract directory and file from what remains
|
219
232
|
self._update_path(dst_dir)
|
wandb/sdk/launch/launch.py
CHANGED
@@ -11,13 +11,12 @@ from . import loader
|
|
11
11
|
from ._project_spec import create_project_from_spec, fetch_and_validate_project
|
12
12
|
from .agent import LaunchAgent
|
13
13
|
from .builder.build import construct_builder_args
|
14
|
+
from .errors import ExecutionError, LaunchError
|
14
15
|
from .runner.abstract import AbstractRun
|
15
16
|
from .utils import (
|
16
17
|
LAUNCH_CONFIG_FILE,
|
17
18
|
LAUNCH_DEFAULT_PROJECT,
|
18
19
|
PROJECT_SYNCHRONOUS,
|
19
|
-
ExecutionError,
|
20
|
-
LaunchError,
|
21
20
|
construct_launch_spec,
|
22
21
|
validate_launch_spec_source,
|
23
22
|
)
|
@@ -76,14 +75,10 @@ def resolve_agent_config( # noqa: C901
|
|
76
75
|
user_set_project = True
|
77
76
|
if os.environ.get("WANDB_ENTITY") is not None:
|
78
77
|
resolved_config.update({"entity": os.environ.get("WANDB_ENTITY")})
|
79
|
-
if os.environ.get("WANDB_API_KEY") is not None:
|
80
|
-
resolved_config.update({"api_key": os.environ.get("WANDB_API_KEY")})
|
81
78
|
if os.environ.get("WANDB_LAUNCH_MAX_JOBS") is not None:
|
82
79
|
resolved_config.update(
|
83
80
|
{"max_jobs": int(os.environ.get("WANDB_LAUNCH_MAX_JOBS", 1))}
|
84
81
|
)
|
85
|
-
if os.environ.get("WANDB_BASE_URL") is not None:
|
86
|
-
resolved_config.update({"base_url": os.environ.get("WANDB_BASE_URL")})
|
87
82
|
|
88
83
|
if project is not None:
|
89
84
|
resolved_config.update({"project": project})
|
@@ -104,7 +99,7 @@ def resolve_agent_config( # noqa: C901
|
|
104
99
|
+ " (expected str). Specify multiple queues with the 'queues' key"
|
105
100
|
)
|
106
101
|
|
107
|
-
keys = ["
|
102
|
+
keys = ["project", "entity"]
|
108
103
|
settings = {
|
109
104
|
k: resolved_config.get(k) for k in keys if resolved_config.get(k) is not None
|
110
105
|
}
|
@@ -180,7 +175,7 @@ def _run(
|
|
180
175
|
builder = loader.builder_from_config(build_config, environment, registry)
|
181
176
|
backend = loader.runner_from_config(resource, api, runner_config, environment)
|
182
177
|
if backend:
|
183
|
-
submitted_run = backend.run(launch_project, builder)
|
178
|
+
submitted_run = backend.run(launch_project, builder, None)
|
184
179
|
# this check will always pass, run is only optional in the agent case where
|
185
180
|
# a run queue id is present on the backend config
|
186
181
|
assert submitted_run
|
wandb/sdk/launch/launch_add.py
CHANGED
@@ -6,10 +6,10 @@ import wandb.apis.public as public
|
|
6
6
|
from wandb.apis.internal import Api
|
7
7
|
from wandb.sdk.launch._project_spec import create_project_from_spec
|
8
8
|
from wandb.sdk.launch.builder.build import build_image_from_project
|
9
|
+
from wandb.sdk.launch.errors import LaunchError
|
9
10
|
from wandb.sdk.launch.utils import (
|
10
11
|
LAUNCH_DEFAULT_PROJECT,
|
11
12
|
LOG_PREFIX,
|
12
|
-
LaunchError,
|
13
13
|
construct_launch_spec,
|
14
14
|
validate_launch_spec_source,
|
15
15
|
)
|
@@ -43,6 +43,7 @@ def launch_add(
|
|
43
43
|
run_id: Optional[str] = None,
|
44
44
|
build: Optional[bool] = False,
|
45
45
|
repository: Optional[str] = None,
|
46
|
+
sweep_id: Optional[str] = None,
|
46
47
|
author: Optional[str] = None,
|
47
48
|
) -> "public.QueuedRun":
|
48
49
|
"""Enqueue a W&B launch experiment. With either a source uri, job or docker_image.
|
@@ -111,6 +112,7 @@ def launch_add(
|
|
111
112
|
run_id=run_id,
|
112
113
|
build=build,
|
113
114
|
repository=repository,
|
115
|
+
sweep_id=sweep_id,
|
114
116
|
author=author,
|
115
117
|
)
|
116
118
|
|
@@ -133,6 +135,7 @@ def _launch_add(
|
|
133
135
|
run_id: Optional[str] = None,
|
134
136
|
build: Optional[bool] = False,
|
135
137
|
repository: Optional[str] = None,
|
138
|
+
sweep_id: Optional[str] = None,
|
136
139
|
author: Optional[str] = None,
|
137
140
|
) -> "public.QueuedRun":
|
138
141
|
launch_spec = construct_launch_spec(
|
@@ -151,6 +154,7 @@ def _launch_add(
|
|
151
154
|
run_id,
|
152
155
|
repository,
|
153
156
|
author,
|
157
|
+
sweep_id,
|
154
158
|
)
|
155
159
|
|
156
160
|
if build:
|
@@ -207,7 +211,7 @@ def _launch_add(
|
|
207
211
|
container_job = False
|
208
212
|
if job:
|
209
213
|
job_artifact = public_api.job(job)
|
210
|
-
if job_artifact.
|
214
|
+
if job_artifact._job_info.get("source_type") == "image":
|
211
215
|
container_job = True
|
212
216
|
|
213
217
|
queued_run = public_api.queued_run(
|
wandb/sdk/launch/loader.py
CHANGED
@@ -2,7 +2,7 @@
|
|
2
2
|
from typing import Any, Dict, Optional
|
3
3
|
|
4
4
|
from wandb.apis.internal import Api
|
5
|
-
from wandb.sdk.launch.
|
5
|
+
from wandb.sdk.launch.errors import LaunchError
|
6
6
|
|
7
7
|
from .builder.abstract import AbstractBuilder
|
8
8
|
from .environment.abstract import AbstractEnvironment
|
@@ -42,6 +42,10 @@ def environment_from_config(config: Optional[Dict[str, Any]]) -> AbstractEnviron
|
|
42
42
|
raise LaunchError(
|
43
43
|
"Could not create environment from config. Environment type not specified!"
|
44
44
|
)
|
45
|
+
if env_type == "local":
|
46
|
+
from .environment.local_environment import LocalEnvironment
|
47
|
+
|
48
|
+
return LocalEnvironment.from_config(config)
|
45
49
|
if env_type == "aws":
|
46
50
|
from .environment.aws_environment import AwsEnvironment
|
47
51
|
|
@@ -50,6 +54,10 @@ def environment_from_config(config: Optional[Dict[str, Any]]) -> AbstractEnviron
|
|
50
54
|
from .environment.gcp_environment import GcpEnvironment
|
51
55
|
|
52
56
|
return GcpEnvironment.from_config(config)
|
57
|
+
if env_type == "azure":
|
58
|
+
from .environment.azure_environment import AzureEnvironment
|
59
|
+
|
60
|
+
return AzureEnvironment.from_config(config)
|
53
61
|
raise LaunchError(
|
54
62
|
f"Could not create environment from config. Invalid type: {env_type}"
|
55
63
|
)
|
@@ -80,7 +88,7 @@ def registry_from_config(
|
|
80
88
|
|
81
89
|
return LocalRegistry() # This is the default, dummy registry.
|
82
90
|
registry_type = config.get("type")
|
83
|
-
if registry_type is None:
|
91
|
+
if registry_type is None or registry_type == "local":
|
84
92
|
from .registry.local_registry import LocalRegistry
|
85
93
|
|
86
94
|
return LocalRegistry() # This is the default, dummy registry.
|
@@ -106,6 +114,17 @@ def registry_from_config(
|
|
106
114
|
from .registry.google_artifact_registry import GoogleArtifactRegistry
|
107
115
|
|
108
116
|
return GoogleArtifactRegistry.from_config(config, environment)
|
117
|
+
if registry_type == "acr":
|
118
|
+
from .environment.azure_environment import AzureEnvironment
|
119
|
+
|
120
|
+
if not isinstance(environment, AzureEnvironment):
|
121
|
+
raise LaunchError(
|
122
|
+
"Could not create ACR registry. "
|
123
|
+
"Environment must be an instance of AzureEnvironment."
|
124
|
+
)
|
125
|
+
from .registry.azure_container_registry import AzureContainerRegistry
|
126
|
+
|
127
|
+
return AzureContainerRegistry.from_config(config, environment)
|
109
128
|
raise LaunchError(
|
110
129
|
f"Could not create registry from config. Invalid registry type: {registry_type}"
|
111
130
|
)
|
@@ -0,0 +1,132 @@
|
|
1
|
+
"""Implementation of AzureContainerRegistry class."""
|
2
|
+
import re
|
3
|
+
from typing import TYPE_CHECKING, Tuple
|
4
|
+
|
5
|
+
from wandb.util import get_module
|
6
|
+
|
7
|
+
from ..environment.abstract import AbstractEnvironment
|
8
|
+
from ..environment.azure_environment import AzureEnvironment
|
9
|
+
from ..errors import LaunchError
|
10
|
+
from .abstract import AbstractRegistry
|
11
|
+
|
12
|
+
if TYPE_CHECKING:
|
13
|
+
from azure.containerregistry import ContainerRegistryClient # type: ignore
|
14
|
+
from azure.core.exceptions import ResourceNotFoundError # type: ignore
|
15
|
+
|
16
|
+
|
17
|
+
# These runtime bindings intentionally shadow the TYPE_CHECKING-only imports at
# the top of the module (hence F811): type checkers see the real SDK names,
# while at runtime they are resolved through wandb's get_module, with the
# `required` text serving as the install hint for each optional package.
ContainerRegistryClient = get_module(  # noqa: F811
    "azure.containerregistry",
    required="The azure-containerregistry package is required to use launch with Azure. Please install it with `pip install azure-containerregistry`.",
).ContainerRegistryClient

# Raised by the registry client when a repository/tag lookup finds nothing;
# check_image_exists treats it as "image absent" rather than an error.
ResourceNotFoundError = get_module(  # noqa: F811
    "azure.core.exceptions",
    required="The azure-core package is required to use launch with Azure. Please install it with `pip install azure-core`.",
).ResourceNotFoundError
|
26
|
+
|
27
|
+
|
28
|
+
class AzureContainerRegistry(AbstractRegistry):
    """Helper for accessing Azure Container Registry resources."""

    def __init__(
        self,
        environment: AzureEnvironment,
        uri: str,
        verify: bool = True,
    ):
        """Initialize an AzureContainerRegistry.

        Args:
            environment (AzureEnvironment): Environment used to obtain Azure
                credentials.
            uri (str): Registry URI, e.g. ``<registry>.azurecr.io/<repository>``
                (an ``https://`` prefix is also accepted).
            verify (bool, optional): Whether to validate ``uri`` immediately.
                Defaults to True.
        """
        self.environment = environment
        self.uri = uri
        if verify:
            self.verify()

    @classmethod
    def from_config(
        cls, config: dict, environment: AbstractEnvironment, verify: bool = True
    ) -> "AzureContainerRegistry":
        """Create an AzureContainerRegistry from a config dict.

        Args:
            config (dict): The config dict.
            environment (AbstractEnvironment): The environment to use.
            verify (bool, optional): Whether to verify the registry. Defaults to True.

        Returns:
            AzureContainerRegistry: The registry.

        Raises:
            LaunchError: If the config is invalid.
        """
        if not isinstance(environment, AzureEnvironment):
            raise LaunchError(
                "AzureContainerRegistry requires an AzureEnvironment to be passed in."
            )
        uri = config.get("uri")
        if uri is None:
            raise LaunchError(
                "Please specify a registry name to use under the registry.uri."
            )
        return cls(
            uri=uri,
            environment=environment,
            verify=verify,
        )

    def get_username_password(self) -> Tuple[str, str]:
        """Get username and password for container registry.

        Raises:
            NotImplementedError: Always; not implemented for ACR.
        """
        raise NotImplementedError

    def check_image_exists(self, image_uri: str) -> bool:
        """Check if image exists in container registry.

        Args:
            image_uri (str): Image URI to check, e.g.
                ``<registry>.azurecr.io/<repository>:<tag>``. An untagged URI
                is checked against the ``latest`` tag.

        Returns:
            bool: True if image exists, False otherwise.

        Raises:
            LaunchError: If the URI cannot be parsed or the lookup fails for a
                reason other than the image being absent.
        """
        credential = self.environment.get_credentials()
        registry, repository, tag = self.parse_azurecr_uri(image_uri)
        # Docker resolves an untagged reference to "latest"; mirror that rather
        # than querying the registry with an empty tag.
        tag = tag or "latest"
        client = ContainerRegistryClient(f"https://{registry}.azurecr.io", credential)
        try:
            client.get_manifest_properties(repository, tag)
            return True
        except ResourceNotFoundError:
            return False
        except Exception as e:
            raise LaunchError(
                f"Unable to check if image exists in Azure Container Registry: {e}"
            ) from e
        finally:
            # Release the sockets held by the client's HTTP transport.
            client.close()

    def get_repo_uri(self) -> str:
        """Return the configured registry URI (``self.uri``) unchanged."""
        return self.uri

    def verify(self) -> None:
        """Check that the configured URI parses as an ACR URI.

        This does not contact the registry.

        Raises:
            LaunchError: If the URI cannot be parsed.
        """
        try:
            _ = self.registry_name
        except Exception as e:
            raise LaunchError(f"Unable to verify Azure Container Registry: {e}") from e

    @property
    def registry_name(self) -> str:
        """Get registry name."""
        return self.parse_azurecr_uri(self.uri)[0]

    @staticmethod
    def parse_azurecr_uri(uri: str) -> Tuple[str, str, str]:
        """Parse an Azure Container Registry URI.

        Accepts ``[https://]<registry>.azurecr.io/<repository>[:<tag>]``.
        Docker image references carry no scheme, so the ``https://`` prefix is
        optional. The repository may contain ``/`` and ``.`` (nested/namespaced
        repositories) without bleeding into the tag.

        Args:
            uri (str): URI to parse.

        Returns:
            Tuple[str, str, str]: Tuple of registry name, repository name, and
                tag (empty string if no tag is present).

        Raises:
            LaunchError: If unable to parse URI.
        """
        regex = r"(?:https://)?(\w+)\.azurecr\.io/([\w\-./]+):?(.*)"
        match = re.match(regex, uri)
        if match is None:
            raise LaunchError(f"Unable to parse Azure Container Registry URI: {uri}")
        return match.group(1), match.group(2), match.group(3)
|