wandb 0.15.3__py3-none-any.whl → 0.15.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -1
- wandb/analytics/sentry.py +1 -0
- wandb/apis/importers/base.py +20 -5
- wandb/apis/importers/mlflow.py +7 -1
- wandb/apis/internal.py +12 -0
- wandb/apis/public.py +247 -1387
- wandb/apis/reports/_panels.py +58 -35
- wandb/beta/workflows.py +6 -7
- wandb/cli/cli.py +130 -60
- wandb/data_types.py +3 -1
- wandb/filesync/dir_watcher.py +21 -27
- wandb/filesync/step_checksum.py +8 -8
- wandb/filesync/step_prepare.py +23 -10
- wandb/filesync/step_upload.py +13 -13
- wandb/filesync/upload_job.py +4 -8
- wandb/integration/cohere/__init__.py +3 -0
- wandb/integration/cohere/cohere.py +21 -0
- wandb/integration/cohere/resolver.py +347 -0
- wandb/integration/gym/__init__.py +4 -6
- wandb/integration/huggingface/__init__.py +3 -0
- wandb/integration/huggingface/huggingface.py +18 -0
- wandb/integration/huggingface/resolver.py +213 -0
- wandb/integration/langchain/wandb_tracer.py +16 -179
- wandb/integration/openai/__init__.py +1 -3
- wandb/integration/openai/openai.py +11 -143
- wandb/integration/openai/resolver.py +111 -38
- wandb/integration/sagemaker/config.py +2 -2
- wandb/integration/tensorboard/log.py +4 -4
- wandb/old/settings.py +24 -7
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/wandb_deprecated.py +3 -1
- wandb/sdk/__init__.py +1 -1
- wandb/sdk/artifacts/__init__.py +0 -0
- wandb/sdk/artifacts/artifact.py +2101 -0
- wandb/sdk/artifacts/artifact_download_logger.py +42 -0
- wandb/sdk/artifacts/artifact_manifest.py +67 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
- wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
- wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
- wandb/sdk/artifacts/artifact_state.py +10 -0
- wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
- wandb/sdk/artifacts/exceptions.py +55 -0
- wandb/sdk/artifacts/storage_handler.py +59 -0
- wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
- wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
- wandb/sdk/artifacts/storage_layout.py +6 -0
- wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
- wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
- wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
- wandb/sdk/data_types/_dtypes.py +7 -12
- wandb/sdk/data_types/base_types/json_metadata.py +3 -2
- wandb/sdk/data_types/base_types/media.py +8 -8
- wandb/sdk/data_types/base_types/wb_value.py +12 -13
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
- wandb/sdk/data_types/helper_types/classes.py +6 -8
- wandb/sdk/data_types/helper_types/image_mask.py +5 -6
- wandb/sdk/data_types/histogram.py +4 -3
- wandb/sdk/data_types/html.py +3 -4
- wandb/sdk/data_types/image.py +11 -9
- wandb/sdk/data_types/molecule.py +5 -3
- wandb/sdk/data_types/object_3d.py +7 -5
- wandb/sdk/data_types/plotly.py +3 -2
- wandb/sdk/data_types/saved_model.py +11 -11
- wandb/sdk/data_types/trace_tree.py +5 -4
- wandb/sdk/data_types/utils.py +3 -5
- wandb/sdk/data_types/video.py +5 -4
- wandb/sdk/integration_utils/auto_logging.py +215 -0
- wandb/sdk/interface/interface.py +15 -15
- wandb/sdk/internal/file_pusher.py +8 -16
- wandb/sdk/internal/file_stream.py +5 -11
- wandb/sdk/internal/handler.py +13 -1
- wandb/sdk/internal/internal_api.py +287 -13
- wandb/sdk/internal/job_builder.py +119 -30
- wandb/sdk/internal/sender.py +6 -26
- wandb/sdk/internal/settings_static.py +2 -0
- wandb/sdk/internal/system/assets/__init__.py +2 -0
- wandb/sdk/internal/system/assets/gpu.py +42 -0
- wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
- wandb/sdk/internal/system/env_probe_helpers.py +13 -0
- wandb/sdk/internal/system/system_info.py +3 -3
- wandb/sdk/internal/tb_watcher.py +32 -22
- wandb/sdk/internal/thread_local_settings.py +18 -0
- wandb/sdk/launch/_project_spec.py +57 -11
- wandb/sdk/launch/agent/agent.py +147 -65
- wandb/sdk/launch/agent/job_status_tracker.py +34 -0
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
- wandb/sdk/launch/builder/abstract.py +5 -1
- wandb/sdk/launch/builder/build.py +21 -18
- wandb/sdk/launch/builder/docker_builder.py +10 -4
- wandb/sdk/launch/builder/kaniko_builder.py +113 -23
- wandb/sdk/launch/builder/noop.py +6 -3
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
- wandb/sdk/launch/environment/aws_environment.py +3 -2
- wandb/sdk/launch/environment/azure_environment.py +124 -0
- wandb/sdk/launch/environment/gcp_environment.py +2 -4
- wandb/sdk/launch/environment/local_environment.py +1 -1
- wandb/sdk/launch/errors.py +19 -0
- wandb/sdk/launch/github_reference.py +32 -19
- wandb/sdk/launch/launch.py +3 -8
- wandb/sdk/launch/launch_add.py +6 -2
- wandb/sdk/launch/loader.py +21 -2
- wandb/sdk/launch/registry/azure_container_registry.py +132 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
- wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
- wandb/sdk/launch/registry/local_registry.py +2 -1
- wandb/sdk/launch/runner/abstract.py +24 -3
- wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
- wandb/sdk/launch/runner/local_container.py +103 -51
- wandb/sdk/launch/runner/local_process.py +1 -1
- wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
- wandb/sdk/launch/runner/vertex_runner.py +10 -5
- wandb/sdk/launch/sweeps/__init__.py +7 -9
- wandb/sdk/launch/sweeps/scheduler.py +307 -77
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +82 -35
- wandb/sdk/launch/utils.py +89 -75
- wandb/sdk/lib/_settings_toposort_generated.py +7 -0
- wandb/sdk/lib/capped_dict.py +26 -0
- wandb/sdk/lib/{git.py → gitlib.py} +76 -59
- wandb/sdk/lib/hashutil.py +12 -4
- wandb/sdk/lib/paths.py +96 -8
- wandb/sdk/lib/sock_client.py +2 -2
- wandb/sdk/lib/timer.py +1 -0
- wandb/sdk/service/server.py +22 -9
- wandb/sdk/service/server_sock.py +1 -1
- wandb/sdk/service/service.py +27 -8
- wandb/sdk/verify/verify.py +4 -7
- wandb/sdk/wandb_config.py +2 -6
- wandb/sdk/wandb_init.py +57 -53
- wandb/sdk/wandb_require.py +7 -0
- wandb/sdk/wandb_run.py +61 -223
- wandb/sdk/wandb_settings.py +28 -4
- wandb/testing/relay.py +15 -2
- wandb/util.py +74 -36
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
- wandb/integration/langchain/util.py +0 -191
- wandb/sdk/interface/artifacts/__init__.py +0 -33
- wandb/sdk/interface/artifacts/artifact.py +0 -615
- wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
- wandb/sdk/wandb_artifacts.py +0 -2226
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,192 @@
|
|
1
|
+
"""Azure storage handler."""
|
2
|
+
from pathlib import PurePosixPath
|
3
|
+
from types import ModuleType
|
4
|
+
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple, Union
|
5
|
+
from urllib.parse import ParseResult, parse_qsl, urlparse
|
6
|
+
|
7
|
+
import wandb
|
8
|
+
from wandb import util
|
9
|
+
from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
|
10
|
+
from wandb.sdk.artifacts.artifacts_cache import get_artifacts_cache
|
11
|
+
from wandb.sdk.artifacts.storage_handler import DEFAULT_MAX_OBJECTS, StorageHandler
|
12
|
+
from wandb.sdk.lib.hashutil import ETag
|
13
|
+
from wandb.sdk.lib.paths import FilePathStr, LogicalPath, StrPath, URIStr
|
14
|
+
|
15
|
+
if TYPE_CHECKING:
|
16
|
+
import azure.identity # type: ignore
|
17
|
+
import azure.storage.blob # type: ignore
|
18
|
+
|
19
|
+
from wandb.sdk.artifacts.artifact import Artifact
|
20
|
+
|
21
|
+
|
22
|
+
class AzureHandler(StorageHandler):
    """Resolves and stores references to blobs in Azure Blob Storage."""

    def can_handle(self, parsed_url: "ParseResult") -> bool:
        """Claim https URLs that point at the Azure blob endpoint."""
        is_https = parsed_url.scheme == "https"
        return is_https and parsed_url.netloc.endswith(".blob.core.windows.net")

    def load_path(
        self,
        manifest_entry: "ArtifactManifestEntry",
        local: bool = False,
    ) -> Union[URIStr, FilePathStr]:
        """Return the entry's URI, or (when ``local``) a cached local copy.

        Downloads the blob whose etag matches the manifest digest: a pinned
        version id is used when present, otherwise the current blob is tried
        first and older versions are scanned as a fallback.
        """
        assert manifest_entry.ref is not None
        if not local:
            return manifest_entry.ref

        cache_path, cache_hit, cache_open = get_artifacts_cache().check_etag_obj_path(
            URIStr(manifest_entry.ref),
            ETag(manifest_entry.digest),
            manifest_entry.size or 0,
        )
        if cache_hit:
            return cache_path

        account_url, container, blob, _ = self._parse_uri(manifest_entry.ref)
        version_id = manifest_entry.extra.get("versionID")
        service = self._get_module("azure.storage.blob").BlobServiceClient(
            account_url, credential=self._get_credential(account_url)
        )
        client = service.get_blob_client(container=container, blob=blob)

        if version_id is not None:
            stream = client.download_blob(version_id=version_id)
        else:
            # No pinned version: try the current blob (guarded by etag), then
            # fall back to scanning all versions for a matching etag.
            try:
                stream = client.download_blob(
                    etag=manifest_entry.digest,
                    match_condition=self._get_module(
                        "azure.core"
                    ).MatchConditions.IfNotModified,
                )
            except self._get_module("azure.core.exceptions").ResourceModifiedError:
                container_client = service.get_container_client(container)
                for props in container_client.walk_blobs(
                    name_starts_with=blob, include=["versions"]
                ):
                    if (
                        props.name == blob
                        and props.etag == manifest_entry.digest
                        and props.version_id is not None
                    ):
                        stream = client.download_blob(version_id=props.version_id)
                        break
                else:  # no version matched the digest
                    raise ValueError(
                        f"Couldn't find blob version for {manifest_entry.ref} matching "
                        f"etag {manifest_entry.digest}."
                    )

        with cache_open(mode="wb") as f:
            stream.readinto(f)
        return cache_path

    def store_path(
        self,
        artifact: "Artifact",
        path: Union[URIStr, FilePathStr],
        name: Optional[StrPath] = None,
        checksum: bool = True,
        max_objects: Optional[int] = None,
    ) -> Sequence["ArtifactManifestEntry"]:
        """Create manifest entries for a single blob or a blob prefix at ``path``."""
        account_url, container, blob, query = self._parse_uri(path)
        path = URIStr(f"{account_url}/{container}/{blob}")

        if not checksum:
            # Without checksumming the path itself stands in for the digest.
            return [
                ArtifactManifestEntry(path=name or blob, digest=path, ref=path)
            ]

        service = self._get_module("azure.storage.blob").BlobServiceClient(
            account_url, credential=self._get_credential(account_url)
        )
        client = service.get_blob_client(container=container, blob=blob)

        if client.exists(version_id=query.get("versionId")):
            # Single-blob reference.
            props = client.get_blob_properties(version_id=query.get("versionId"))
            return [
                self._create_entry(
                    props,
                    path=name or PurePosixPath(blob).name,
                    ref=URIStr(f"{account_url}/{container}/{props.name}"),
                )
            ]

        # Prefix (directory-style) reference: enumerate the contained blobs.
        entries = []
        container_client = service.get_container_client(container)
        max_objects = max_objects or DEFAULT_MAX_OBJECTS
        for count, props in enumerate(
            container_client.list_blobs(name_starts_with=f"{blob}/")
        ):
            if count >= max_objects:
                raise ValueError(
                    f"Exceeded {max_objects} objects tracked, pass max_objects to "
                    f"add_reference"
                )
            suffix = PurePosixPath(props.name).relative_to(blob)
            entries.append(
                self._create_entry(
                    props,
                    path=LogicalPath(name) / suffix if name else suffix,
                    ref=URIStr(f"{account_url}/{container}/{props.name}"),
                )
            )
        return entries

    def _get_module(self, name: str) -> ModuleType:
        """Lazily import an azure module, raising with install instructions if absent."""
        module = util.get_module(
            name,
            lazy=False,
            required="Azure references require the azure library, run "
            "pip install wandb[azure]",
        )
        assert isinstance(module, ModuleType)
        return module

    def _get_credential(
        self, account_url: str
    ) -> Union["azure.identity.DefaultAzureCredential", str]:
        """Prefer an access key configured on the active run; else default creds."""
        run = wandb.run
        if run and account_url in run.settings.azure_account_url_to_access_key:
            return run.settings.azure_account_url_to_access_key[account_url]
        return self._get_module("azure.identity").DefaultAzureCredential()

    def _parse_uri(self, uri: str) -> Tuple[str, str, str, Dict[str, str]]:
        """Split a blob URI into (account_url, container, blob name, query dict)."""
        parsed = urlparse(uri)
        _, container, blob = parsed.path.split("/", 2)
        return (
            f"{parsed.scheme}://{parsed.netloc}",
            container,
            blob,
            dict(parse_qsl(parsed.query)),
        )

    def _create_entry(
        self,
        blob_properties: "azure.storage.blob.BlobProperties",
        path: StrPath,
        ref: URIStr,
    ) -> ArtifactManifestEntry:
        """Build a manifest entry from Azure blob properties (etag doubles as digest)."""
        etag = blob_properties.etag.strip('"')
        extra = {"etag": etag}
        if blob_properties.version_id:
            extra["versionID"] = blob_properties.version_id
        return ArtifactManifestEntry(
            path=path,
            ref=ref,
            digest=etag,
            size=blob_properties.size,
            extra=extra,
        )
@@ -0,0 +1,224 @@
|
|
1
|
+
"""GCS storage handler."""
|
2
|
+
import base64
|
3
|
+
import time
|
4
|
+
from pathlib import PurePosixPath
|
5
|
+
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple, Union
|
6
|
+
from urllib.parse import ParseResult, urlparse
|
7
|
+
|
8
|
+
from wandb import util
|
9
|
+
from wandb.errors.term import termlog
|
10
|
+
from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
|
11
|
+
from wandb.sdk.artifacts.artifacts_cache import get_artifacts_cache
|
12
|
+
from wandb.sdk.artifacts.storage_handler import DEFAULT_MAX_OBJECTS, StorageHandler
|
13
|
+
from wandb.sdk.lib.hashutil import B64MD5
|
14
|
+
from wandb.sdk.lib.paths import FilePathStr, StrPath, URIStr
|
15
|
+
|
16
|
+
if TYPE_CHECKING:
|
17
|
+
import google.cloud.storage as gcs_module # type: ignore
|
18
|
+
|
19
|
+
from wandb.sdk.artifacts.artifact import Artifact
|
20
|
+
|
21
|
+
|
22
|
+
class GCSHandler(StorageHandler):
    """Resolves and stores gs:// references to objects in Google Cloud Storage."""

    # Lazily-constructed GCS client; None until first use.
    _client: Optional["gcs_module.client.Client"]
    # Cached bucket-versioning flag; None until first queried.
    _versioning_enabled: Optional[bool]

    def __init__(self, scheme: Optional[str] = None) -> None:
        self._scheme = scheme or "gs"
        self._client = None
        self._versioning_enabled = None
        self._cache = get_artifacts_cache()

    def versioning_enabled(self, bucket_path: str) -> bool:
        """Return whether the bucket has object versioning enabled (cached after first call)."""
        if self._versioning_enabled is not None:
            return self._versioning_enabled
        self.init_gcs()
        assert self._client is not None  # mypy: unwraps optionality
        bucket = self._client.bucket(bucket_path)
        bucket.reload()
        self._versioning_enabled = bucket.versioning_enabled
        return self._versioning_enabled

    def can_handle(self, parsed_url: "ParseResult") -> bool:
        """Claim URIs whose scheme matches this handler's scheme (default "gs")."""
        return parsed_url.scheme == self._scheme

    def init_gcs(self) -> "gcs_module.client.Client":
        """Create (once) and return the google-cloud-storage client."""
        if self._client is not None:
            return self._client
        storage = util.get_module(
            "google.cloud.storage",
            required="gs:// references requires the google-cloud-storage library, run pip install wandb[gcp]",
        )
        self._client = storage.Client()
        return self._client

    def _parse_uri(self, uri: str) -> Tuple[str, str, Optional[str]]:
        """Split a gs:// URI into (bucket, key, version).

        The version (object generation) is carried in the URI fragment, or None.
        """
        url = urlparse(uri)
        bucket = url.netloc
        key = url.path[1:]
        version = url.fragment if url.fragment else None
        return bucket, key, version

    def load_path(
        self,
        manifest_entry: ArtifactManifestEntry,
        local: bool = False,
    ) -> Union[URIStr, FilePathStr]:
        """Return the entry's URI, or (when ``local``) download it into the cache.

        Raises:
            ValueError: if the object cannot be found, or the latest version's
                md5 does not match the manifest digest.
        """
        if not local:
            assert manifest_entry.ref is not None
            return manifest_entry.ref

        path, hit, cache_open = self._cache.check_md5_obj_path(
            B64MD5(manifest_entry.digest),  # TODO(spencerpearson): unsafe cast
            manifest_entry.size if manifest_entry.size is not None else 0,
        )
        if hit:
            return path

        self.init_gcs()
        assert self._client is not None  # mypy: unwraps optionality
        assert manifest_entry.ref is not None
        bucket, key, _ = self._parse_uri(manifest_entry.ref)
        version = manifest_entry.extra.get("versionID")

        obj = None
        # First attempt to get the generation specified, this will return None if versioning is not enabled
        if version is not None:
            obj = self._client.bucket(bucket).get_blob(key, generation=version)

        if obj is None:
            # Object versioning is disabled on the bucket, so just get
            # the latest version and make sure the MD5 matches.
            obj = self._client.bucket(bucket).get_blob(key)
            if obj is None:
                raise ValueError(
                    f"Unable to download object {manifest_entry.ref} with generation {version}"
                )
            md5 = obj.md5_hash
            if md5 != manifest_entry.digest:
                raise ValueError(
                    f"Digest mismatch for object {manifest_entry.ref}: expected {manifest_entry.digest} but found {md5}"
                )

        with cache_open(mode="wb") as f:
            obj.download_to_file(f)
        return path

    def store_path(
        self,
        artifact: "Artifact",
        path: Union[URIStr, FilePathStr],
        name: Optional[StrPath] = None,
        checksum: bool = True,
        max_objects: Optional[int] = None,
    ) -> Sequence[ArtifactManifestEntry]:
        """Create manifest entries for a single object or an object prefix at ``path``."""
        self.init_gcs()
        assert self._client is not None  # mypy: unwraps optionality

        # After parsing any query params / fragments for additional context,
        # such as version identifiers, pare down the path to just the bucket
        # and key.
        bucket, key, version = self._parse_uri(path)
        path = URIStr(f"{self._scheme}://{bucket}/{key}")
        max_objects = max_objects or DEFAULT_MAX_OBJECTS
        if not self.versioning_enabled(bucket) and version:
            # BUG FIX: the message previously said "s3://" in this GCS handler;
            # use the handler's own scheme so the error names the right URI.
            raise ValueError(
                f"Specifying a versionId is not valid for {self._scheme}://{bucket} as it does not have versioning enabled."
            )

        if not checksum:
            return [ArtifactManifestEntry(path=name or key, ref=path, digest=path)]

        start_time = None
        obj = self._client.bucket(bucket).get_blob(key, generation=version)
        multi = obj is None  # no single object: treat `key` as a prefix
        if multi:
            start_time = time.time()
            termlog(
                'Generating checksum for up to %i objects with prefix "%s"... '
                % (max_objects, key),
                newline=False,
            )
            objects = self._client.bucket(bucket).list_blobs(
                prefix=key, max_results=max_objects
            )
        else:
            objects = [obj]

        entries = [
            self._entry_from_obj(obj, path, name, prefix=key, multi=multi)
            for obj in objects
        ]
        if start_time is not None:
            termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)
        if len(entries) > max_objects:
            raise ValueError(
                "Exceeded %i objects tracked, pass max_objects to add_reference"
                % max_objects
            )
        return entries

    def _entry_from_obj(
        self,
        obj: "gcs_module.blob.Blob",
        path: str,
        name: Optional[StrPath] = None,
        prefix: str = "",
        multi: bool = False,
    ) -> ArtifactManifestEntry:
        """Create an ArtifactManifestEntry from a GCS object.

        Arguments:
            obj: The GCS object
            path: The GCS-style path (e.g.: "gs://bucket/file.txt")
            name: The user assigned name, or None if not specified
            prefix: The prefix to add (will be the same as `path` for directories)
            multi: Whether or not this is a multi-object add.
        """
        bucket, key, _ = self._parse_uri(path)

        # Always use posix paths, since that's what S3 uses.
        posix_key = PurePosixPath(obj.name)  # the bucket key
        posix_path = PurePosixPath(bucket) / PurePosixPath(
            key
        )  # the path, with the scheme stripped
        posix_prefix = PurePosixPath(prefix)  # the prefix, if adding a prefix
        posix_name = PurePosixPath(name or "")
        posix_ref = posix_path

        if name is None:
            # We're adding a directory (prefix), so calculate a relative path.
            if str(posix_prefix) in str(posix_key) and posix_prefix != posix_key:
                posix_name = posix_key.relative_to(posix_prefix)
                posix_ref = posix_path / posix_name
            else:
                posix_name = PurePosixPath(posix_key.name)
                posix_ref = posix_path
        elif multi:
            # We're adding a directory with a name override.
            relpath = posix_key.relative_to(posix_prefix)
            posix_name = posix_name / relpath
            posix_ref = posix_path / relpath
        return ArtifactManifestEntry(
            path=posix_name,
            ref=URIStr(f"{self._scheme}://{str(posix_ref)}"),
            digest=obj.md5_hash,
            size=obj.size,
            extra=self._extra_from_obj(obj),
        )

    @staticmethod
    def _extra_from_obj(obj: "gcs_module.blob.Blob") -> Dict[str, str]:
        """Extract etag and generation metadata for the manifest's `extra` field."""
        return {
            "etag": obj.etag,
            "versionID": obj.generation,
        }

    @staticmethod
    def _content_addressed_path(md5: str) -> FilePathStr:
        # TODO: is this the structure we want? not at all human
        # readable, but that's probably OK. don't want people
        # poking around in the bucket
        return FilePathStr(
            "wandb/%s" % base64.b64encode(md5.encode("ascii")).decode("ascii")
        )
@@ -0,0 +1,112 @@
|
|
1
|
+
"""HTTP storage handler."""
|
2
|
+
import os
|
3
|
+
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple, Union
|
4
|
+
from urllib.parse import ParseResult
|
5
|
+
|
6
|
+
from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
|
7
|
+
from wandb.sdk.artifacts.artifacts_cache import get_artifacts_cache
|
8
|
+
from wandb.sdk.artifacts.storage_handler import StorageHandler
|
9
|
+
from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
|
10
|
+
from wandb.sdk.lib.hashutil import ETag
|
11
|
+
from wandb.sdk.lib.paths import FilePathStr, StrPath, URIStr
|
12
|
+
|
13
|
+
if TYPE_CHECKING:
|
14
|
+
import requests
|
15
|
+
|
16
|
+
from wandb.sdk.artifacts.artifact import Artifact
|
17
|
+
|
18
|
+
|
19
|
+
class HTTPHandler(StorageHandler):
    """Resolves and stores http/https references, identified by URL and ETag."""

    def __init__(
        self, session: "requests.Session", scheme: Optional[str] = None
    ) -> None:
        """Bind the handler to a requests session; scheme defaults to "http"."""
        self._scheme = scheme or "http"
        self._cache = get_artifacts_cache()
        self._session = session

    def can_handle(self, parsed_url: "ParseResult") -> bool:
        """Claim URLs whose scheme matches this handler's scheme."""
        return self._scheme == parsed_url.scheme

    def load_path(
        self,
        manifest_entry: ArtifactManifestEntry,
        local: bool = False,
    ) -> Union[URIStr, FilePathStr]:
        """Return the entry's URL, or (when ``local``) stream it into the cache.

        Raises:
            ValueError: if the server's ETag no longer matches the manifest digest.
        """
        assert manifest_entry.ref is not None
        if not local:
            return manifest_entry.ref

        cache_path, found, cache_open = self._cache.check_etag_obj_path(
            URIStr(manifest_entry.ref),
            ETag(manifest_entry.digest),  # TODO(spencerpearson): unsafe cast
            manifest_entry.size if manifest_entry.size is not None else 0,
        )
        if found:
            return cache_path

        response = self._session.get(
            manifest_entry.ref,
            stream=True,
            cookies=_thread_local_api_settings.cookies,
            headers=_thread_local_api_settings.headers,
        )
        response.raise_for_status()

        digest: Optional[Union[ETag, FilePathStr, URIStr]]
        digest, _, _ = self._entry_from_headers(response.headers)
        digest = digest or manifest_entry.ref
        if digest != manifest_entry.digest:
            raise ValueError(
                f"Digest mismatch for url {manifest_entry.ref}: expected {manifest_entry.digest} but found {digest}"
            )

        with cache_open(mode="wb") as out:
            for chunk in response.iter_content(chunk_size=16 * 1024):
                out.write(chunk)
        return cache_path

    def store_path(
        self,
        artifact: "Artifact",
        path: Union[URIStr, FilePathStr],
        name: Optional[StrPath] = None,
        checksum: bool = True,
        max_objects: Optional[int] = None,
    ) -> Sequence[ArtifactManifestEntry]:
        """Create a single manifest entry for the URL, using header metadata when checksumming."""
        name = name or os.path.basename(path)
        if not checksum:
            # Skip the network round-trip; the URL doubles as the digest.
            return [ArtifactManifestEntry(path=name, ref=path, digest=path)]

        with self._session.get(
            path,
            stream=True,
            cookies=_thread_local_api_settings.cookies,
            headers=_thread_local_api_settings.headers,
        ) as response:
            response.raise_for_status()
            digest: Optional[Union[ETag, FilePathStr, URIStr]]
            digest, size, extra = self._entry_from_headers(response.headers)
            digest = digest or path
            entry = ArtifactManifestEntry(
                path=name, ref=path, digest=digest, size=size, extra=extra
            )
        return [entry]

    def _entry_from_headers(
        self, headers: "requests.structures.CaseInsensitiveDict"
    ) -> Tuple[Optional[ETag], Optional[int], Dict[str, str]]:
        """Derive (digest, size, extra) from response headers.

        The raw (quoted) ETag is kept in ``extra``; the digest is the ETag with
        surrounding quotes stripped.
        """
        lowered = {key.lower(): value for key, value in headers.items()}

        size = None
        if lowered.get("content-length", None):
            size = int(lowered["content-length"])

        digest = lowered.get("etag", None)
        extra = {}
        if digest:
            extra["etag"] = digest
            if digest[:1] == '"' and digest[-1:] == '"':
                digest = digest[1:-1]  # trim leading and trailing quotes around etag
        return digest, size, extra
@@ -0,0 +1,134 @@
|
|
1
|
+
"""Local file storage handler."""
|
2
|
+
import os
|
3
|
+
import shutil
|
4
|
+
import time
|
5
|
+
from typing import TYPE_CHECKING, Optional, Sequence, Union
|
6
|
+
from urllib.parse import ParseResult
|
7
|
+
|
8
|
+
from wandb import util
|
9
|
+
from wandb.errors.term import termlog
|
10
|
+
from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
|
11
|
+
from wandb.sdk.artifacts.artifacts_cache import get_artifacts_cache
|
12
|
+
from wandb.sdk.artifacts.storage_handler import DEFAULT_MAX_OBJECTS, StorageHandler
|
13
|
+
from wandb.sdk.lib import filesystem
|
14
|
+
from wandb.sdk.lib.hashutil import B64MD5, md5_file_b64, md5_string
|
15
|
+
from wandb.sdk.lib.paths import FilePathStr, StrPath, URIStr
|
16
|
+
|
17
|
+
if TYPE_CHECKING:
|
18
|
+
from wandb.sdk.artifacts.artifact import Artifact
|
19
|
+
|
20
|
+
|
21
|
+
class LocalFileHandler(StorageHandler):
    """Handles file:// references."""

    def __init__(self, scheme: Optional[str] = None) -> None:
        """Track files or directories on a local filesystem.

        Expand directories to create an entry for each file contained.
        """
        self._scheme = scheme or "file"
        self._cache = get_artifacts_cache()

    def can_handle(self, parsed_url: "ParseResult") -> bool:
        """Claim URIs whose scheme matches this handler's scheme (default "file")."""
        return self._scheme == parsed_url.scheme

    def load_path(
        self,
        manifest_entry: ArtifactManifestEntry,
        local: bool = False,
    ) -> Union[URIStr, FilePathStr]:
        """Return the entry's file URI, or (when ``local``) a verified cached copy.

        Raises:
            ValueError: if the entry has no ref, the file is missing, or its
                md5 no longer matches the manifest digest.
        """
        if manifest_entry.ref is None:
            raise ValueError(f"Cannot add path with no ref: {manifest_entry.path}")
        source = util.local_file_uri_to_path(str(manifest_entry.ref))
        if not os.path.exists(source):
            raise ValueError(
                "Local file reference: Failed to find file at path %s" % source
            )

        cache_path, found, cache_open = self._cache.check_md5_obj_path(
            B64MD5(manifest_entry.digest),  # TODO(spencerpearson): unsafe cast
            manifest_entry.size if manifest_entry.size is not None else 0,
        )
        if found:
            return cache_path

        actual_md5 = md5_file_b64(source)
        if actual_md5 != manifest_entry.digest:
            raise ValueError(
                f"Local file reference: Digest mismatch for path {source}: expected {manifest_entry.digest} but found {actual_md5}"
            )

        filesystem.mkdir_exists_ok(os.path.dirname(cache_path))

        with cache_open() as f:
            shutil.copy(source, f.name)
        return cache_path

    def store_path(
        self,
        artifact: "Artifact",
        path: Union[URIStr, FilePathStr],
        name: Optional[StrPath] = None,
        checksum: bool = True,
        max_objects: Optional[int] = None,
    ) -> Sequence[ArtifactManifestEntry]:
        """Create manifest entries for a local file, or one per file under a directory.

        When ``checksum`` is False, a cheap size-based pseudo-digest is used
        instead of hashing file contents.
        """
        local_path = util.local_file_uri_to_path(path)
        max_objects = max_objects or DEFAULT_MAX_OBJECTS
        # We have a single file or directory
        # Note, we follow symlinks for files contained within the directory
        entries = []

        def digest_for(file_path: str) -> B64MD5:
            # Real content hash when checksumming; otherwise hash of the size.
            if checksum:
                return md5_file_b64(file_path)
            return md5_string(str(os.stat(file_path).st_size))

        if os.path.isdir(local_path):
            seen = 0
            start_time = time.time()
            if checksum:
                termlog(
                    'Generating checksum for up to %i files in "%s"...\n'
                    % (max_objects, local_path),
                    newline=False,
                )
            for root, _, files in os.walk(local_path):
                for filename in files:
                    seen += 1
                    if seen > max_objects:
                        raise ValueError(
                            "Exceeded %i objects tracked, pass max_objects to add_reference"
                            % max_objects
                        )
                    physical_path = os.path.join(root, filename)
                    # TODO(spencerpearson): this is not a "logical path" in the sense that
                    # `LogicalPath` returns a "logical path"; it's a relative path
                    # **on the local filesystem**.
                    logical_path = os.path.relpath(physical_path, start=local_path)
                    if name is not None:
                        logical_path = os.path.join(name, logical_path)

                    entries.append(
                        ArtifactManifestEntry(
                            path=logical_path,
                            ref=FilePathStr(os.path.join(path, logical_path)),
                            size=os.path.getsize(physical_path),
                            digest=digest_for(physical_path),
                        )
                    )
            if checksum:
                termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)
        elif os.path.isfile(local_path):
            name = name or os.path.basename(local_path)
            entries.append(
                ArtifactManifestEntry(
                    path=name,
                    ref=path,
                    size=os.path.getsize(local_path),
                    digest=digest_for(local_path),
                )
            )
        else:
            # TODO: update error message if we don't allow directories.
            raise ValueError('Path "%s" must be a valid file or directory path' % path)
        return entries