wandb 0.15.3__py3-none-any.whl → 0.15.5__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +1 -1
- wandb/analytics/sentry.py +1 -0
- wandb/apis/importers/base.py +20 -5
- wandb/apis/importers/mlflow.py +7 -1
- wandb/apis/internal.py +12 -0
- wandb/apis/public.py +247 -1387
- wandb/apis/reports/_panels.py +58 -35
- wandb/beta/workflows.py +6 -7
- wandb/cli/cli.py +130 -60
- wandb/data_types.py +3 -1
- wandb/filesync/dir_watcher.py +21 -27
- wandb/filesync/step_checksum.py +8 -8
- wandb/filesync/step_prepare.py +23 -10
- wandb/filesync/step_upload.py +13 -13
- wandb/filesync/upload_job.py +4 -8
- wandb/integration/cohere/__init__.py +3 -0
- wandb/integration/cohere/cohere.py +21 -0
- wandb/integration/cohere/resolver.py +347 -0
- wandb/integration/gym/__init__.py +4 -6
- wandb/integration/huggingface/__init__.py +3 -0
- wandb/integration/huggingface/huggingface.py +18 -0
- wandb/integration/huggingface/resolver.py +213 -0
- wandb/integration/langchain/wandb_tracer.py +16 -179
- wandb/integration/openai/__init__.py +1 -3
- wandb/integration/openai/openai.py +11 -143
- wandb/integration/openai/resolver.py +111 -38
- wandb/integration/sagemaker/config.py +2 -2
- wandb/integration/tensorboard/log.py +4 -4
- wandb/old/settings.py +24 -7
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/wandb_deprecated.py +3 -1
- wandb/sdk/__init__.py +1 -1
- wandb/sdk/artifacts/__init__.py +0 -0
- wandb/sdk/artifacts/artifact.py +2101 -0
- wandb/sdk/artifacts/artifact_download_logger.py +42 -0
- wandb/sdk/artifacts/artifact_manifest.py +67 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
- wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
- wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
- wandb/sdk/artifacts/artifact_state.py +10 -0
- wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
- wandb/sdk/artifacts/exceptions.py +55 -0
- wandb/sdk/artifacts/storage_handler.py +59 -0
- wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
- wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
- wandb/sdk/artifacts/storage_layout.py +6 -0
- wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
- wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
- wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
- wandb/sdk/data_types/_dtypes.py +7 -12
- wandb/sdk/data_types/base_types/json_metadata.py +3 -2
- wandb/sdk/data_types/base_types/media.py +8 -8
- wandb/sdk/data_types/base_types/wb_value.py +12 -13
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
- wandb/sdk/data_types/helper_types/classes.py +6 -8
- wandb/sdk/data_types/helper_types/image_mask.py +5 -6
- wandb/sdk/data_types/histogram.py +4 -3
- wandb/sdk/data_types/html.py +3 -4
- wandb/sdk/data_types/image.py +11 -9
- wandb/sdk/data_types/molecule.py +5 -3
- wandb/sdk/data_types/object_3d.py +7 -5
- wandb/sdk/data_types/plotly.py +3 -2
- wandb/sdk/data_types/saved_model.py +11 -11
- wandb/sdk/data_types/trace_tree.py +5 -4
- wandb/sdk/data_types/utils.py +3 -5
- wandb/sdk/data_types/video.py +5 -4
- wandb/sdk/integration_utils/auto_logging.py +215 -0
- wandb/sdk/interface/interface.py +15 -15
- wandb/sdk/internal/file_pusher.py +8 -16
- wandb/sdk/internal/file_stream.py +5 -11
- wandb/sdk/internal/handler.py +13 -1
- wandb/sdk/internal/internal_api.py +287 -13
- wandb/sdk/internal/job_builder.py +119 -30
- wandb/sdk/internal/sender.py +6 -26
- wandb/sdk/internal/settings_static.py +2 -0
- wandb/sdk/internal/system/assets/__init__.py +2 -0
- wandb/sdk/internal/system/assets/gpu.py +42 -0
- wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
- wandb/sdk/internal/system/env_probe_helpers.py +13 -0
- wandb/sdk/internal/system/system_info.py +3 -3
- wandb/sdk/internal/tb_watcher.py +32 -22
- wandb/sdk/internal/thread_local_settings.py +18 -0
- wandb/sdk/launch/_project_spec.py +57 -11
- wandb/sdk/launch/agent/agent.py +147 -65
- wandb/sdk/launch/agent/job_status_tracker.py +34 -0
- wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
- wandb/sdk/launch/builder/abstract.py +5 -1
- wandb/sdk/launch/builder/build.py +21 -18
- wandb/sdk/launch/builder/docker_builder.py +10 -4
- wandb/sdk/launch/builder/kaniko_builder.py +113 -23
- wandb/sdk/launch/builder/noop.py +6 -3
- wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
- wandb/sdk/launch/environment/aws_environment.py +3 -2
- wandb/sdk/launch/environment/azure_environment.py +124 -0
- wandb/sdk/launch/environment/gcp_environment.py +2 -4
- wandb/sdk/launch/environment/local_environment.py +1 -1
- wandb/sdk/launch/errors.py +19 -0
- wandb/sdk/launch/github_reference.py +32 -19
- wandb/sdk/launch/launch.py +3 -8
- wandb/sdk/launch/launch_add.py +6 -2
- wandb/sdk/launch/loader.py +21 -2
- wandb/sdk/launch/registry/azure_container_registry.py +132 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
- wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
- wandb/sdk/launch/registry/local_registry.py +2 -1
- wandb/sdk/launch/runner/abstract.py +24 -3
- wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
- wandb/sdk/launch/runner/local_container.py +103 -51
- wandb/sdk/launch/runner/local_process.py +1 -1
- wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
- wandb/sdk/launch/runner/vertex_runner.py +10 -5
- wandb/sdk/launch/sweeps/__init__.py +7 -9
- wandb/sdk/launch/sweeps/scheduler.py +307 -77
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +82 -35
- wandb/sdk/launch/utils.py +89 -75
- wandb/sdk/lib/_settings_toposort_generated.py +7 -0
- wandb/sdk/lib/capped_dict.py +26 -0
- wandb/sdk/lib/{git.py → gitlib.py} +76 -59
- wandb/sdk/lib/hashutil.py +12 -4
- wandb/sdk/lib/paths.py +96 -8
- wandb/sdk/lib/sock_client.py +2 -2
- wandb/sdk/lib/timer.py +1 -0
- wandb/sdk/service/server.py +22 -9
- wandb/sdk/service/server_sock.py +1 -1
- wandb/sdk/service/service.py +27 -8
- wandb/sdk/verify/verify.py +4 -7
- wandb/sdk/wandb_config.py +2 -6
- wandb/sdk/wandb_init.py +57 -53
- wandb/sdk/wandb_require.py +7 -0
- wandb/sdk/wandb_run.py +61 -223
- wandb/sdk/wandb_settings.py +28 -4
- wandb/testing/relay.py +15 -2
- wandb/util.py +74 -36
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
- wandb/integration/langchain/util.py +0 -191
- wandb/sdk/interface/artifacts/__init__.py +0 -33
- wandb/sdk/interface/artifacts/artifact.py +0 -615
- wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
- wandb/sdk/wandb_artifacts.py +0 -2226
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
- {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,61 @@
|
|
1
|
+
"""S3 bucket storage policy."""
|
2
|
+
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Union
|
3
|
+
|
4
|
+
from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
|
5
|
+
from wandb.sdk.artifacts.storage_handlers.local_file_handler import LocalFileHandler
|
6
|
+
from wandb.sdk.artifacts.storage_handlers.multi_handler import MultiHandler
|
7
|
+
from wandb.sdk.artifacts.storage_handlers.s3_handler import S3Handler
|
8
|
+
from wandb.sdk.artifacts.storage_handlers.tracking_handler import TrackingHandler
|
9
|
+
from wandb.sdk.artifacts.storage_policy import StoragePolicy
|
10
|
+
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
from wandb.sdk.artifacts.artifact import Artifact
|
13
|
+
from wandb.sdk.lib.paths import FilePathStr, URIStr
|
14
|
+
|
15
|
+
|
16
|
+
# Don't use this yet!
|
17
|
+
class __S3BucketPolicy(StoragePolicy):  # noqa: N801
    """Experimental storage policy that routes artifact paths through one S3 bucket.

    Paths are delegated to a ``MultiHandler`` made of an ``S3Handler`` and a
    ``LocalFileHandler``, with a ``TrackingHandler`` as the default handler.

    NOTE(review): this class exposes ``load_path``/``store_path`` rather than the
    ``load_file``/``store_file_*``/``*_reference`` methods seen on other policies;
    confirm the intended interface before enabling it.
    """

    @classmethod
    def name(cls) -> str:
        """Return the name string of this policy."""
        return "wandb-s3-bucket-policy-v1"

    @classmethod
    def from_config(cls, config: Dict[str, str]) -> "__S3BucketPolicy":
        """Build a policy from a config mapping.

        Args:
            config: Mapping that must contain a ``"bucket"`` key.

        Raises:
            ValueError: If ``"bucket"`` is missing from ``config``.
        """
        if "bucket" not in config:
            raise ValueError("Bucket name not found in config")
        return cls(config["bucket"])

    def __init__(self, bucket: str) -> None:
        """Remember the bucket name and assemble the delegating handler."""
        self._bucket = bucket
        s3 = S3Handler(bucket)
        local = LocalFileHandler()

        self._handler = MultiHandler(
            handlers=[
                s3,
                local,
            ],
            default_handler=TrackingHandler(),
        )

    def config(self) -> Dict[str, str]:
        """Return the config mapping that ``from_config`` accepts."""
        return {"bucket": self._bucket}

    # ``URIStr``/``FilePathStr`` appear to be imported only under TYPE_CHECKING in
    # this module, so the annotations below are quoted: an unquoted reference
    # would be evaluated when the ``def`` executes and raise NameError.
    def load_path(
        self,
        manifest_entry: ArtifactManifestEntry,
        local: bool = False,
    ) -> "Union[URIStr, FilePathStr]":
        """Resolve a manifest entry through the underlying handler.

        Args:
            manifest_entry: The entry to load.
            local: Forwarded unchanged to the handler's ``load_path``.

        Returns:
            Whatever the handler's ``load_path`` returns for the entry.
        """
        return self._handler.load_path(manifest_entry, local=local)

    def store_path(
        self,
        artifact: "Artifact",
        path: "Union[URIStr, FilePathStr]",
        name: Optional[str] = None,
        checksum: bool = True,
        max_objects: Optional[int] = None,
    ) -> Sequence[ArtifactManifestEntry]:
        """Store ``path`` for ``artifact`` through the underlying handler.

        All keyword options are forwarded unchanged to the handler's
        ``store_path``; its return value is returned as-is.
        """
        return self._handler.store_path(
            artifact, path, name=name, checksum=checksum, max_objects=max_objects
        )
|
@@ -0,0 +1,386 @@
|
|
1
|
+
"""WandB storage policy."""
|
2
|
+
import hashlib
|
3
|
+
import math
|
4
|
+
import shutil
|
5
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
|
6
|
+
from urllib.parse import quote
|
7
|
+
|
8
|
+
import requests
|
9
|
+
import urllib3
|
10
|
+
|
11
|
+
from wandb.apis import InternalApi
|
12
|
+
from wandb.errors.term import termwarn
|
13
|
+
from wandb.sdk.artifacts.artifacts_cache import ArtifactsCache, get_artifacts_cache
|
14
|
+
from wandb.sdk.artifacts.storage_handlers.azure_handler import AzureHandler
|
15
|
+
from wandb.sdk.artifacts.storage_handlers.gcs_handler import GCSHandler
|
16
|
+
from wandb.sdk.artifacts.storage_handlers.http_handler import HTTPHandler
|
17
|
+
from wandb.sdk.artifacts.storage_handlers.local_file_handler import LocalFileHandler
|
18
|
+
from wandb.sdk.artifacts.storage_handlers.multi_handler import MultiHandler
|
19
|
+
from wandb.sdk.artifacts.storage_handlers.s3_handler import S3Handler
|
20
|
+
from wandb.sdk.artifacts.storage_handlers.tracking_handler import TrackingHandler
|
21
|
+
from wandb.sdk.artifacts.storage_handlers.wb_artifact_handler import WBArtifactHandler
|
22
|
+
from wandb.sdk.artifacts.storage_handlers.wb_local_artifact_handler import (
|
23
|
+
WBLocalArtifactHandler,
|
24
|
+
)
|
25
|
+
from wandb.sdk.artifacts.storage_layout import StorageLayout
|
26
|
+
from wandb.sdk.artifacts.storage_policy import StoragePolicy
|
27
|
+
from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
|
28
|
+
from wandb.sdk.lib.hashutil import B64MD5, b64_to_hex_id, hex_to_b64_id
|
29
|
+
from wandb.sdk.lib.paths import FilePathStr, URIStr
|
30
|
+
|
31
|
+
if TYPE_CHECKING:
|
32
|
+
from wandb.filesync.step_prepare import StepPrepare
|
33
|
+
from wandb.sdk.artifacts.artifact import Artifact
|
34
|
+
from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
|
35
|
+
from wandb.sdk.internal import progress
|
36
|
+
|
37
|
+
# Retry policy shared by every HTTP(S) request made through the policy's session.
# Sleep length: 0, 2, 4, 8, 16, 32, 64, 120, 120, 120, 120, 120, 120, 120, 120, 120
# seconds, i.e. a total of 20min 6s.
_REQUEST_RETRY_STRATEGY = urllib3.util.retry.Retry(
    total=16,
    backoff_factor=1,
    status_forcelist=(308, 408, 409, 429, 500, 502, 503, 504),
)

# Connection-pool sizing passed to the requests HTTPAdapter.
_REQUEST_POOL_CONNECTIONS = 64
_REQUEST_POOL_MAXSIZE = 64

# AWS S3 max upload parts without having to make additional requests for extra parts
S3_MAX_PART_NUMBERS = 1000
# Files are uploaded in multiple parts only when their size lies in [2 GiB, 5 TiB].
S3_MIN_MULTI_UPLOAD_SIZE = 2 * 1024**3
S3_MAX_MULTI_UPLOAD_SIZE = 5 * 1024**4
|
51
|
+
|
52
|
+
|
53
|
+
class WandbStoragePolicy(StoragePolicy):
|
54
|
+
@classmethod
def name(cls) -> str:
    """Return the fixed name string of this storage policy."""
    policy_name = "wandb-storage-policy-v1"
    return policy_name
|
57
|
+
|
58
|
+
@classmethod
def from_config(cls, config: Dict) -> "WandbStoragePolicy":
    """Alternate constructor: build a policy from a config mapping.

    The mapping is passed through as the ``config`` keyword argument of the
    class constructor; cache and API client take their defaults.
    """
    policy = cls(config=config)
    return policy
|
61
|
+
|
62
|
+
def __init__(
    self,
    config: Optional[Dict] = None,
    cache: Optional[ArtifactsCache] = None,
    api: Optional[InternalApi] = None,
) -> None:
    """Wire up the cache, the retrying HTTP session, the handlers and the API client.

    Args:
        config: Policy configuration; an empty dict is used when falsy.
        cache: Artifacts cache; ``get_artifacts_cache()`` is used when falsy.
        api: Internal API client; a new ``InternalApi()`` is created when falsy.
    """
    self._cache = cache or get_artifacts_cache()
    self._config = config or {}

    # A single session so both URL prefixes share one adapter, i.e. the same
    # retry strategy and connection-pool settings.
    self._session = session = requests.Session()
    retrying_adapter = requests.adapters.HTTPAdapter(
        max_retries=_REQUEST_RETRY_STRATEGY,
        pool_connections=_REQUEST_POOL_CONNECTIONS,
        pool_maxsize=_REQUEST_POOL_MAXSIZE,
    )
    for url_prefix in ("http://", "https://"):
        session.mount(url_prefix, retrying_adapter)

    # Handlers are created in the same order in which they are registered below.
    registered_handlers = [
        S3Handler(),
        GCSHandler(),
        AzureHandler(),
        HTTPHandler(session),
        HTTPHandler(session, scheme="https"),
        WBArtifactHandler(),
        WBLocalArtifactHandler(),
        LocalFileHandler(),
    ]

    self._api = api or InternalApi()
    self._handler = MultiHandler(
        handlers=registered_handlers,
        default_handler=TrackingHandler(),
    )
|
102
|
+
|
103
|
+
def config(self) -> Dict:
    """Return the policy's configuration mapping (the stored object, not a copy)."""
    stored_config = self._config
    return stored_config
|
105
|
+
|
106
|
+
def load_file(
    self,
    artifact: "Artifact",
    manifest_entry: "ArtifactManifestEntry",
) -> FilePathStr:
    """Return the cache path of an entry's file, downloading it on a cache miss.

    The entry's pre-signed ``_download_url`` is tried first when present. If
    that request reports an error status, the URL is cleared on the entry and
    the file is fetched from the URL built by ``_file_url`` instead, using
    the thread-local cookies/headers (and API-key basic auth when no cookies
    are set).

    Args:
        artifact: Artifact owning the entry; its ``entity`` is used for the
            fallback URL.
        manifest_entry: Entry whose digest and size identify the cached file.

    Returns:
        The path reported by the cache for this entry.

    Raises:
        requests.HTTPError: If the fallback request reports an error status.
    """
    path, hit, cache_open = self._cache.check_md5_obj_path(
        B64MD5(manifest_entry.digest),  # TODO(spencerpearson): unsafe cast
        manifest_entry.size if manifest_entry.size is not None else 0,
    )
    if hit:
        return path

    if manifest_entry._download_url is not None:
        response = self._session.get(manifest_entry._download_url, stream=True)
        try:
            response.raise_for_status()
        except Exception:
            # Signed URL might have expired, fall back to fetching it one by one.
            manifest_entry._download_url = None
            # A streamed response keeps its pooled connection until it is read
            # or closed; release it before issuing the fallback request.
            response.close()
    if manifest_entry._download_url is None:
        auth = None
        if not _thread_local_api_settings.cookies:
            auth = ("api", self._api.api_key)
        response = self._session.get(
            self._file_url(self._api, artifact.entity, manifest_entry),
            auth=auth,
            cookies=_thread_local_api_settings.cookies,
            headers=_thread_local_api_settings.headers,
            stream=True,
        )
        try:
            response.raise_for_status()
        except Exception:
            # Do not leak the connection when the download cannot proceed.
            response.close()
            raise

    try:
        with cache_open(mode="wb") as file:
            for data in response.iter_content(chunk_size=16 * 1024):
                file.write(data)
    finally:
        # Always hand the connection back, even if writing to the cache fails.
        response.close()
    return path
|
142
|
+
|
143
|
+
def store_reference(
    self,
    artifact: "Artifact",
    path: Union[URIStr, FilePathStr],
    name: Optional[str] = None,
    checksum: bool = True,
    max_objects: Optional[int] = None,
) -> Sequence["ArtifactManifestEntry"]:
    """Store a reference to ``path`` by delegating to the policy's handler.

    Every argument is forwarded unchanged to the handler's ``store_path`` and
    its result is returned as-is.
    """
    forwarded_options = {
        "name": name,
        "checksum": checksum,
        "max_objects": max_objects,
    }
    return self._handler.store_path(artifact, path, **forwarded_options)
|
154
|
+
|
155
|
+
def load_reference(
    self,
    manifest_entry: "ArtifactManifestEntry",
    local: bool = False,
) -> Union[FilePathStr, URIStr]:
    """Load a reference entry by delegating to the policy's handler.

    ``manifest_entry`` and ``local`` are forwarded positionally to the
    handler's ``load_path``; its result is returned as-is.
    """
    delegate = self._handler
    return delegate.load_path(manifest_entry, local)
|
161
|
+
|
162
|
+
def _file_url(
    self,
    api: InternalApi,
    entity_name: str,
    manifest_entry: "ArtifactManifestEntry",
) -> str:
    """Build the URL of an entry's file for the configured storage layout.

    The layout comes from the ``"storageLayout"`` config key (default V1) and
    the region from ``"storageRegion"`` (default ``"default"``). The file is
    addressed by the hex form of the entry's digest; the V2 layout also
    includes the region and the URL-quoted birth artifact id (empty if unset).

    Raises:
        Exception: If the configured layout is neither V1 nor V2.
    """
    storage_layout = self._config.get("storageLayout", StorageLayout.V1)
    storage_region = self._config.get("storageRegion", "default")
    md5_hex = b64_to_hex_id(B64MD5(manifest_entry.digest))

    if storage_layout == StorageLayout.V1:
        base_url = api.settings("base_url")
        return "{}/artifacts/{}/{}".format(base_url, entity_name, md5_hex)

    if storage_layout == StorageLayout.V2:
        base_url = api.settings("base_url")
        birth_id = manifest_entry.birth_artifact_id
        if birth_id is None:
            birth_id = ""
        return "{}/artifactsV2/{}/{}/{}/{}".format(
            base_url,
            storage_region,
            entity_name,
            quote(birth_id),
            md5_hex,
        )

    raise Exception(f"unrecognized storage layout: {storage_layout}")
|
190
|
+
|
191
|
+
def s3_multipart_file_upload(
    self,
    file_path: str,
    chunk_size: int,
    hex_digests: Dict[int, str],
    multipart_urls: Dict[int, str],
    extra_headers: Dict[str, str],
) -> List[Dict[str, Any]]:
    """Upload a file in ``chunk_size`` pieces to per-part URLs.

    Parts are numbered from 1 in file order. Each part is sent with its
    base64 MD5 (derived from ``hex_digests``), its length, and the
    ``"Content-Type"`` value taken from ``extra_headers``.

    Returns:
        One ``{"partNumber": ..., "hexMD5": ...}`` dict per uploaded part,
        where ``"hexMD5"`` holds the ``ETag`` header of that part's response.
    """
    uploaded_parts = []

    with open(file_path, "rb") as stream:
        # iter() with a sentinel keeps reading until read() yields no bytes.
        pieces = iter(lambda: stream.read(chunk_size), b"")
        for part_number, piece in enumerate(pieces, start=1):
            md5_b64_str = str(hex_to_b64_id(hex_digests[part_number]))
            part_url = multipart_urls[part_number]
            part_headers = {
                "content-md5": md5_b64_str,
                "content-length": str(len(piece)),
                "content-type": extra_headers.get("Content-Type"),
            }
            upload_resp = self._api.upload_multipart_file_chunk_retry(
                part_url,
                piece,
                extra_headers=part_headers,
            )
            uploaded_parts.append(
                {"partNumber": part_number, "hexMD5": upload_resp.headers["ETag"]}
            )
    return uploaded_parts
|
222
|
+
|
223
|
+
def default_file_upload(
    self,
    upload_url: str,
    file_path: str,
    extra_headers: Dict[str, Any],
    progress_callback: Optional["progress.ProgressFn"] = None,
) -> None:
    """Upload one local file to ``upload_url`` in a single request.

    The file is opened in binary mode and handed, together with the progress
    callback and extra headers, to the API client's ``upload_file_retry``.
    """
    with open(file_path, "rb") as stream:
        # This fails if we don't send the first byte before the signed URL expires.
        self._api.upload_file_retry(
            upload_url, stream, progress_callback, extra_headers=extra_headers
        )
|
239
|
+
|
240
|
+
def calc_chunk_size(self, file_size: int) -> int:
    """Choose the part size for a multipart upload of ``file_size`` bytes.

    The default is 100 MiB. If that many bytes per part would need more than
    ``S3_MAX_PART_NUMBERS`` parts, the size is raised to the smallest value
    that fits the file into ``S3_MAX_PART_NUMBERS`` parts.
    """
    default_chunk_size = 100 * 1024**2
    fits_with_default = file_size <= default_chunk_size * S3_MAX_PART_NUMBERS
    if fits_with_default:
        return default_chunk_size
    return math.ceil(file_size / S3_MAX_PART_NUMBERS)
|
247
|
+
|
248
|
+
def store_file_sync(
    self,
    artifact_id: str,
    artifact_manifest_id: str,
    entry: "ArtifactManifestEntry",
    preparer: "StepPrepare",
    progress_callback: Optional["progress.ProgressFn"] = None,
) -> bool:
    """Upload a file to the artifact store.

    Files whose size lies between ``S3_MIN_MULTI_UPLOAD_SIZE`` and
    ``S3_MAX_MULTI_UPLOAD_SIZE`` are hashed part by part so the prepare step
    can hand back multipart upload URLs; everything else (or a response
    without multipart URLs) uses a single-URL upload. After a successful
    upload the file is copied into the local cache.

    Args:
        artifact_id: Sent to the prepare step as ``artifactID``.
        artifact_manifest_id: Sent as ``artifactManifestID``.
        entry: Manifest entry to upload; its ``birth_artifact_id`` is updated
            from the prepare response.
        preparer: Object whose ``prepare_sync`` requests the upload URLs.
        progress_callback: Forwarded to the single-URL upload.

    Returns:
        True if the file was a duplicate (did not need to be uploaded),
        False if it needed to be uploaded or was a reference (nothing to dedupe).

    Raises:
        ValueError: If neither a usable upload URL nor multipart URLs exist.
    """
    file_size = entry.size if entry.size is not None else 0
    chunk_size = self.calc_chunk_size(file_size)
    upload_parts = []
    hex_digests = {}
    file_path = entry.local_path if entry.local_path is not None else ""
    # Logic for AWS s3 multipart upload.
    # Only chunk files if larger than 2 GiB. Currently can only support up to 5TiB.
    # Entries without a local file (references) have nothing to hash; trying
    # to open the empty placeholder path would raise before the early
    # ``return False`` below could be reached.
    if (
        entry.local_path is not None
        and S3_MIN_MULTI_UPLOAD_SIZE <= file_size <= S3_MAX_MULTI_UPLOAD_SIZE
    ):
        with open(file_path, "rb") as f:
            chunks = iter(lambda: f.read(chunk_size), b"")
            for part_number, data in enumerate(chunks, start=1):
                hex_digest = hashlib.md5(data).hexdigest()
                upload_parts.append(
                    {"hexMD5": hex_digest, "partNumber": part_number}
                )
                hex_digests[part_number] = hex_digest

    resp = preparer.prepare_sync(
        {
            "artifactID": artifact_id,
            "artifactManifestID": artifact_manifest_id,
            "name": entry.path,
            "md5": entry.digest,
            "uploadPartsInput": upload_parts,
        }
    ).get()

    entry.birth_artifact_id = resp.birth_artifact_id

    multipart_urls = resp.multipart_upload_urls
    if resp.upload_url is None:
        return True
    if entry.local_path is None:
        return False

    # Each header arrives as one "Name:value" string; split on the first colon
    # only so values that themselves contain colons stay intact.
    extra_headers = {}
    for header in resp.upload_headers or {}:
        pieces = header.split(":", 1)
        extra_headers[pieces[0]] = pieces[1]

    # This multipart upload isn't available, do a regular single url upload
    if multipart_urls is None and resp.upload_url:
        self.default_file_upload(
            resp.upload_url, file_path, extra_headers, progress_callback
        )
    else:
        if multipart_urls is None:
            raise ValueError(f"No multipart urls to upload for file: {file_path}")
        # Upload files using s3 multipart upload urls
        etags = self.s3_multipart_file_upload(
            file_path,
            chunk_size,
            hex_digests,
            multipart_urls,
            extra_headers,
        )
        self._api.complete_multipart_upload_artifact(
            artifact_id, resp.storage_path, etags, resp.upload_id
        )
    self._write_cache(entry)

    return False
|
331
|
+
|
332
|
+
async def store_file_async(
    self,
    artifact_id: str,
    artifact_manifest_id: str,
    entry: "ArtifactManifestEntry",
    preparer: "StepPrepare",
    progress_callback: Optional["progress.ProgressFn"] = None,
) -> bool:
    """Async equivalent to `store_file_sync`.

    Always uses a single-URL upload; no multipart information is sent to the
    prepare step.

    Returns:
        True when the prepare response carries no upload URL, False otherwise
        (including entries without a local file, which are not uploaded).
    """
    prepare_request = {
        "artifactID": artifact_id,
        "artifactManifestID": artifact_manifest_id,
        "name": entry.path,
        "md5": entry.digest,
    }
    resp = await preparer.prepare_async(prepare_request)

    entry.birth_artifact_id = resp.birth_artifact_id
    if resp.upload_url is None:
        return True
    if entry.local_path is None:
        return False

    with open(entry.local_path, "rb") as stream:
        # Headers arrive as "Name:value" strings; split on the first colon only.
        parsed_headers = {}
        for header in resp.upload_headers or {}:
            pieces = header.split(":", 1)
            parsed_headers[pieces[0]] = pieces[1]
        # This fails if we don't send the first byte before the signed URL expires.
        await self._api.upload_file_retry_async(
            resp.upload_url,
            stream,
            progress_callback,
            extra_headers=parsed_headers,
        )

    self._write_cache(entry)

    return False
|
371
|
+
|
372
|
+
def _write_cache(self, entry: "ArtifactManifestEntry") -> None:
|
373
|
+
if entry.local_path is None:
|
374
|
+
return
|
375
|
+
|
376
|
+
# Cache upon successful upload.
|
377
|
+
_, hit, cache_open = self._cache.check_md5_obj_path(
|
378
|
+
B64MD5(entry.digest),
|
379
|
+
entry.size if entry.size is not None else 0,
|
380
|
+
)
|
381
|
+
if not hit:
|
382
|
+
try:
|
383
|
+
with cache_open() as f:
|
384
|
+
shutil.copyfile(entry.local_path, f.name)
|
385
|
+
except OSError as e:
|
386
|
+
termwarn(f"Failed to cache {entry.local_path}, ignoring {e}")
|
@@ -1,20 +1,15 @@
|
|
1
|
+
"""Storage policy."""
|
1
2
|
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Type, Union
|
2
3
|
|
3
4
|
from wandb.sdk.lib.paths import FilePathStr, URIStr
|
4
5
|
|
5
6
|
if TYPE_CHECKING:
|
6
|
-
from urllib.parse import ParseResult
|
7
|
-
|
8
7
|
from wandb.filesync.step_prepare import StepPrepare
|
9
|
-
from wandb.sdk.
|
8
|
+
from wandb.sdk.artifacts.artifact import Artifact
|
9
|
+
from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
|
10
10
|
from wandb.sdk.internal.progress import ProgressFn
|
11
11
|
|
12
12
|
|
13
|
-
class StorageLayout:
|
14
|
-
V1 = "V1"
|
15
|
-
V2 = "V2"
|
16
|
-
|
17
|
-
|
18
13
|
class StoragePolicy:
|
19
14
|
@classmethod
|
20
15
|
def lookup_by_name(cls, name: str) -> Optional[Type["StoragePolicy"]]:
|
@@ -36,7 +31,7 @@ class StoragePolicy:
|
|
36
31
|
|
37
32
|
def load_file(
|
38
33
|
self, artifact: "Artifact", manifest_entry: "ArtifactManifestEntry"
|
39
|
-
) ->
|
34
|
+
) -> FilePathStr:
|
40
35
|
raise NotImplementedError
|
41
36
|
|
42
37
|
def store_file_sync(
|
@@ -74,52 +69,5 @@ class StoragePolicy:
|
|
74
69
|
self,
|
75
70
|
manifest_entry: "ArtifactManifestEntry",
|
76
71
|
local: bool = False,
|
77
|
-
) ->
|
78
|
-
raise NotImplementedError
|
79
|
-
|
80
|
-
|
81
|
-
class StorageHandler:
|
82
|
-
def can_handle(self, parsed_url: "ParseResult") -> bool:
|
83
|
-
"""Checks whether this handler can handle the given url.
|
84
|
-
|
85
|
-
Returns:
|
86
|
-
Whether this handler can handle the given url.
|
87
|
-
"""
|
88
|
-
raise NotImplementedError
|
89
|
-
|
90
|
-
def load_path(
|
91
|
-
self,
|
92
|
-
manifest_entry: "ArtifactManifestEntry",
|
93
|
-
local: bool = False,
|
94
|
-
) -> Union[URIStr, FilePathStr]:
|
95
|
-
"""Load a file or directory given the corresponding index entry.
|
96
|
-
|
97
|
-
Args:
|
98
|
-
manifest_entry: The index entry to load
|
99
|
-
local: Whether to load the file locally or not
|
100
|
-
|
101
|
-
Returns:
|
102
|
-
A path to the file represented by `index_entry`
|
103
|
-
"""
|
104
|
-
raise NotImplementedError
|
105
|
-
|
106
|
-
def store_path(
|
107
|
-
self,
|
108
|
-
artifact: "Artifact",
|
109
|
-
path: Union[URIStr, FilePathStr],
|
110
|
-
name: Optional[str] = None,
|
111
|
-
checksum: bool = True,
|
112
|
-
max_objects: Optional[int] = None,
|
113
|
-
) -> Sequence["ArtifactManifestEntry"]:
|
114
|
-
"""Store the file or directory at the given path to the specified artifact.
|
115
|
-
|
116
|
-
Args:
|
117
|
-
path: The path to store
|
118
|
-
name: If specified, the logical name that should map to `path`
|
119
|
-
checksum: Whether to compute the checksum of the file
|
120
|
-
max_objects: The maximum number of objects to store
|
121
|
-
|
122
|
-
Returns:
|
123
|
-
A list of manifest entries to store within the artifact
|
124
|
-
"""
|
72
|
+
) -> Union[FilePathStr, URIStr]:
|
125
73
|
raise NotImplementedError
|
wandb/sdk/data_types/_dtypes.py
CHANGED
@@ -13,8 +13,7 @@ from wandb.util import (
|
|
13
13
|
np = get_module("numpy") # intentionally not required
|
14
14
|
|
15
15
|
if t.TYPE_CHECKING:
|
16
|
-
from wandb.
|
17
|
-
from wandb.sdk.wandb_artifacts import Artifact as ArtifactInCreation
|
16
|
+
from wandb.sdk.artifacts.artifact import Artifact
|
18
17
|
|
19
18
|
_TYPES_STRIPPED = not (sys.version_info.major == 3 and sys.version_info.minor >= 6)
|
20
19
|
if not _TYPES_STRIPPED:
|
@@ -77,7 +76,7 @@ class TypeRegistry:
|
|
77
76
|
|
78
77
|
@staticmethod
|
79
78
|
def type_from_dict(
|
80
|
-
json_dict: t.Dict[str, t.Any], artifact: t.Optional["
|
79
|
+
json_dict: t.Dict[str, t.Any], artifact: t.Optional["Artifact"] = None
|
81
80
|
) -> "Type":
|
82
81
|
wb_type = json_dict.get("wb_type")
|
83
82
|
if wb_type is None:
|
@@ -135,7 +134,7 @@ class TypeRegistry:
|
|
135
134
|
|
136
135
|
def _params_obj_to_json_obj(
|
137
136
|
params_obj: t.Any,
|
138
|
-
artifact: t.Optional["
|
137
|
+
artifact: t.Optional["Artifact"] = None,
|
139
138
|
) -> t.Any:
|
140
139
|
"""Helper method."""
|
141
140
|
if params_obj.__class__ == dict:
|
@@ -152,7 +151,7 @@ def _params_obj_to_json_obj(
|
|
152
151
|
|
153
152
|
|
154
153
|
def _json_obj_to_params_obj(
|
155
|
-
json_obj: t.Any, artifact: t.Optional["
|
154
|
+
json_obj: t.Any, artifact: t.Optional["Artifact"] = None
|
156
155
|
) -> t.Any:
|
157
156
|
"""Helper method."""
|
158
157
|
if json_obj.__class__ == dict:
|
@@ -222,9 +221,7 @@ class Type:
|
|
222
221
|
else:
|
223
222
|
return InvalidType()
|
224
223
|
|
225
|
-
def to_json(
|
226
|
-
self, artifact: t.Optional["ArtifactInCreation"] = None
|
227
|
-
) -> t.Dict[str, t.Any]:
|
224
|
+
def to_json(self, artifact: t.Optional["Artifact"] = None) -> t.Dict[str, t.Any]:
|
228
225
|
"""Generate a jsonable dictionary serialization the type.
|
229
226
|
|
230
227
|
If overridden by subclass, ensure that `from_json` is equivalently overridden.
|
@@ -249,7 +246,7 @@ class Type:
|
|
249
246
|
def from_json(
|
250
247
|
cls,
|
251
248
|
json_dict: t.Dict[str, t.Any],
|
252
|
-
artifact: t.Optional["
|
249
|
+
artifact: t.Optional["Artifact"] = None,
|
253
250
|
) -> "Type":
|
254
251
|
"""Construct a new instance of the type using a JSON dictionary.
|
255
252
|
|
@@ -756,9 +753,7 @@ class NDArrayType(Type):
|
|
756
753
|
|
757
754
|
return InvalidType()
|
758
755
|
|
759
|
-
def to_json(
|
760
|
-
self, artifact: t.Optional["ArtifactInCreation"] = None
|
761
|
-
) -> t.Dict[str, t.Any]:
|
756
|
+
def to_json(self, artifact: t.Optional["Artifact"] = None) -> t.Dict[str, t.Any]:
|
762
757
|
# custom override to support serialization path outside of params internal dict
|
763
758
|
res = {
|
764
759
|
"wb_type": self.name,
|
@@ -9,7 +9,8 @@ from .._private import MEDIA_TMP
|
|
9
9
|
from .media import Media
|
10
10
|
|
11
11
|
if TYPE_CHECKING: # pragma: no cover
|
12
|
-
from
|
12
|
+
from wandb.sdk.artifacts.artifact import Artifact
|
13
|
+
|
13
14
|
from ...wandb_run import Run as LocalRun
|
14
15
|
|
15
16
|
|
@@ -39,7 +40,7 @@ class JSONMetadata(Media):
|
|
39
40
|
def get_media_subdir(cls: Type["JSONMetadata"]) -> str:
|
40
41
|
return os.path.join("media", "metadata", cls.type_name())
|
41
42
|
|
42
|
-
def to_json(self, run_or_artifact: Union["LocalRun", "
|
43
|
+
def to_json(self, run_or_artifact: Union["LocalRun", "Artifact"]) -> dict:
|
43
44
|
json_dict = super().to_json(run_or_artifact)
|
44
45
|
json_dict["_type"] = self.type_name()
|
45
46
|
|
@@ -9,15 +9,15 @@ import wandb
|
|
9
9
|
from wandb import util
|
10
10
|
from wandb._globals import _datatypes_callback
|
11
11
|
from wandb.sdk.lib import filesystem
|
12
|
+
from wandb.sdk.lib.paths import LogicalPath
|
12
13
|
|
13
14
|
from .wb_value import WBValue
|
14
15
|
|
15
16
|
if TYPE_CHECKING: # pragma: no cover
|
16
|
-
import numpy as np
|
17
|
+
import numpy as np
|
17
18
|
|
18
|
-
from wandb.
|
19
|
+
from wandb.sdk.artifacts.artifact import Artifact
|
19
20
|
|
20
|
-
from ...wandb_artifacts import Artifact as LocalArtifact
|
21
21
|
from ...wandb_run import Run as LocalRun
|
22
22
|
|
23
23
|
|
@@ -143,7 +143,7 @@ class Media(WBValue):
|
|
143
143
|
self._path = new_path
|
144
144
|
_datatypes_callback(media_path)
|
145
145
|
|
146
|
-
def to_json(self, run: Union["LocalRun", "
|
146
|
+
def to_json(self, run: Union["LocalRun", "Artifact"]) -> dict:
|
147
147
|
"""Serialize the object into a JSON blob.
|
148
148
|
|
149
149
|
Uses run or artifact to store additional data. If `run_or_artifact` is a
|
@@ -193,11 +193,11 @@ class Media(WBValue):
|
|
193
193
|
# by definition is_bound, but are needed for
|
194
194
|
# mypy to understand that these are strings below.
|
195
195
|
assert isinstance(self._path, str)
|
196
|
-
json_obj["path"] =
|
196
|
+
json_obj["path"] = LogicalPath(
|
197
197
|
os.path.relpath(self._path, self._run.dir)
|
198
198
|
)
|
199
199
|
|
200
|
-
elif isinstance(run, wandb.
|
200
|
+
elif isinstance(run, wandb.Artifact):
|
201
201
|
if self.file_is_set():
|
202
202
|
# The following two assertions are guaranteed to pass
|
203
203
|
# by definition of the call above, but are needed for
|
@@ -253,7 +253,7 @@ class Media(WBValue):
|
|
253
253
|
|
254
254
|
@classmethod
|
255
255
|
def from_json(
|
256
|
-
cls: Type["Media"], json_obj: dict, source_artifact: "
|
256
|
+
cls: Type["Media"], json_obj: dict, source_artifact: "Artifact"
|
257
257
|
) -> "Media":
|
258
258
|
"""Likely will need to override for any more complicated media objects."""
|
259
259
|
return cls(source_artifact.get_path(json_obj["path"]).download())
|
@@ -315,4 +315,4 @@ def _numpy_arrays_to_lists(
|
|
315
315
|
# Protects against logging non serializable objects
|
316
316
|
elif isinstance(payload, Media):
|
317
317
|
return str(payload.__class__.__name__)
|
318
|
-
return payload
|
318
|
+
return payload # type: ignore
|