wandb 0.22.0__py3-none-musllinux_1_2_aarch64.whl → 0.22.2__py3-none-musllinux_1_2_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +8 -5
- wandb/_pydantic/__init__.py +12 -11
- wandb/_pydantic/base.py +49 -19
- wandb/apis/__init__.py +2 -0
- wandb/apis/attrs.py +2 -0
- wandb/apis/importers/internals/internal.py +16 -23
- wandb/apis/internal.py +2 -0
- wandb/apis/normalize.py +2 -0
- wandb/apis/public/__init__.py +3 -2
- wandb/apis/public/api.py +215 -164
- wandb/apis/public/artifacts.py +23 -20
- wandb/apis/public/const.py +2 -0
- wandb/apis/public/files.py +33 -24
- wandb/apis/public/history.py +2 -0
- wandb/apis/public/jobs.py +20 -18
- wandb/apis/public/projects.py +4 -2
- wandb/apis/public/query_generator.py +3 -0
- wandb/apis/public/registries/__init__.py +7 -0
- wandb/apis/public/registries/_freezable_list.py +9 -12
- wandb/apis/public/registries/registries_search.py +8 -6
- wandb/apis/public/registries/registry.py +22 -17
- wandb/apis/public/reports.py +2 -0
- wandb/apis/public/runs.py +261 -57
- wandb/apis/public/sweeps.py +10 -9
- wandb/apis/public/teams.py +2 -0
- wandb/apis/public/users.py +2 -0
- wandb/apis/public/utils.py +16 -15
- wandb/automations/_generated/__init__.py +54 -127
- wandb/automations/_generated/create_generic_webhook_integration.py +1 -7
- wandb/automations/_generated/fragments.py +26 -91
- wandb/bin/gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/beta.py +16 -2
- wandb/cli/beta_leet.py +74 -0
- wandb/cli/beta_sync.py +9 -11
- wandb/cli/cli.py +34 -7
- wandb/errors/errors.py +3 -3
- wandb/proto/v3/wandb_api_pb2.py +86 -0
- wandb/proto/v3/wandb_internal_pb2.py +352 -351
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_sync_pb2.py +19 -6
- wandb/proto/v4/wandb_api_pb2.py +37 -0
- wandb/proto/v4/wandb_internal_pb2.py +352 -351
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_sync_pb2.py +10 -6
- wandb/proto/v5/wandb_api_pb2.py +38 -0
- wandb/proto/v5/wandb_internal_pb2.py +352 -351
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_sync_pb2.py +10 -6
- wandb/proto/v6/wandb_api_pb2.py +48 -0
- wandb/proto/v6/wandb_internal_pb2.py +352 -351
- wandb/proto/v6/wandb_settings_pb2.py +2 -2
- wandb/proto/v6/wandb_sync_pb2.py +10 -6
- wandb/proto/wandb_api_pb2.py +18 -0
- wandb/proto/wandb_generate_proto.py +1 -0
- wandb/sdk/artifacts/_factories.py +7 -2
- wandb/sdk/artifacts/_generated/__init__.py +112 -412
- wandb/sdk/artifacts/_generated/fragments.py +65 -0
- wandb/sdk/artifacts/_generated/operations.py +52 -22
- wandb/sdk/artifacts/_generated/run_input_artifacts.py +3 -23
- wandb/sdk/artifacts/_generated/run_output_artifacts.py +3 -23
- wandb/sdk/artifacts/_generated/type_info.py +19 -0
- wandb/sdk/artifacts/_gqlutils.py +47 -0
- wandb/sdk/artifacts/_models/__init__.py +4 -0
- wandb/sdk/artifacts/_models/base_model.py +20 -0
- wandb/sdk/artifacts/_validators.py +40 -12
- wandb/sdk/artifacts/artifact.py +99 -118
- wandb/sdk/artifacts/artifact_file_cache.py +6 -1
- wandb/sdk/artifacts/artifact_manifest_entry.py +67 -14
- wandb/sdk/artifacts/storage_handler.py +18 -12
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +11 -6
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +9 -6
- wandb/sdk/artifacts/storage_handlers/http_handler.py +9 -4
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +10 -6
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +5 -4
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +10 -8
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +6 -4
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +24 -21
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +4 -2
- wandb/sdk/artifacts/storage_policies/_multipart.py +187 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +71 -242
- wandb/sdk/artifacts/storage_policy.py +25 -12
- wandb/sdk/data_types/bokeh.py +5 -1
- wandb/sdk/data_types/image.py +17 -6
- wandb/sdk/data_types/object_3d.py +67 -2
- wandb/sdk/interface/interface.py +31 -4
- wandb/sdk/interface/interface_queue.py +10 -0
- wandb/sdk/interface/interface_shared.py +0 -7
- wandb/sdk/interface/interface_sock.py +9 -3
- wandb/sdk/internal/_generated/__init__.py +2 -12
- wandb/sdk/internal/job_builder.py +27 -10
- wandb/sdk/internal/sender.py +5 -2
- wandb/sdk/internal/settings_static.py +2 -82
- wandb/sdk/launch/create_job.py +2 -1
- wandb/sdk/launch/runner/kubernetes_runner.py +25 -20
- wandb/sdk/launch/utils.py +82 -1
- wandb/sdk/lib/progress.py +8 -74
- wandb/sdk/lib/service/service_client.py +5 -9
- wandb/sdk/lib/service/service_connection.py +39 -23
- wandb/sdk/mailbox/mailbox_handle.py +2 -0
- wandb/sdk/projects/_generated/__init__.py +12 -33
- wandb/sdk/wandb_init.py +23 -3
- wandb/sdk/wandb_login.py +53 -27
- wandb/sdk/wandb_run.py +10 -5
- wandb/sdk/wandb_settings.py +63 -25
- wandb/sync/sync.py +7 -2
- wandb/util.py +1 -1
- {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/METADATA +1 -1
- {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/RECORD +784 -774
- wandb/sdk/artifacts/_graphql_fragments.py +0 -19
- {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/WHEEL +0 -0
- {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/entry_points.txt +0 -0
- {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/artifacts/artifact.py
CHANGED
@@ -3,7 +3,6 @@
|
|
3
3
|
from __future__ import annotations
|
4
4
|
|
5
5
|
import atexit
|
6
|
-
import concurrent.futures
|
7
6
|
import contextlib
|
8
7
|
import json
|
9
8
|
import logging
|
@@ -15,10 +14,10 @@ import stat
|
|
15
14
|
import tempfile
|
16
15
|
import time
|
17
16
|
from collections import deque
|
17
|
+
from concurrent.futures import Executor, ThreadPoolExecutor, as_completed
|
18
18
|
from copy import copy
|
19
|
-
from dataclasses import dataclass
|
19
|
+
from dataclasses import asdict, dataclass, replace
|
20
20
|
from datetime import timedelta
|
21
|
-
from functools import partial
|
22
21
|
from itertools import filterfalse
|
23
22
|
from pathlib import Path, PurePosixPath
|
24
23
|
from typing import (
|
@@ -30,7 +29,6 @@ from typing import (
|
|
30
29
|
Literal,
|
31
30
|
Sequence,
|
32
31
|
Type,
|
33
|
-
cast,
|
34
32
|
final,
|
35
33
|
)
|
36
34
|
from urllib.parse import quote, urljoin, urlparse
|
@@ -51,6 +49,7 @@ from wandb.errors.term import termerror, termlog, termwarn
|
|
51
49
|
from wandb.proto import wandb_internal_pb2 as pb
|
52
50
|
from wandb.proto.wandb_deprecated import Deprecated
|
53
51
|
from wandb.sdk import wandb_setup
|
52
|
+
from wandb.sdk.artifacts.storage_policies._multipart import should_multipart_download
|
54
53
|
from wandb.sdk.data_types._dtypes import Type as WBType
|
55
54
|
from wandb.sdk.data_types._dtypes import TypeRegistry
|
56
55
|
from wandb.sdk.internal.internal_api import Api as InternalApi
|
@@ -109,10 +108,11 @@ from ._generated import (
|
|
109
108
|
TagInput,
|
110
109
|
UpdateArtifact,
|
111
110
|
)
|
112
|
-
from .
|
111
|
+
from ._gqlutils import omit_artifact_fields, supports_enable_tracking_var, type_info
|
113
112
|
from ._validators import (
|
114
113
|
LINKED_ARTIFACT_COLLECTION_TYPE,
|
115
114
|
ArtifactPath,
|
115
|
+
FullArtifactPath,
|
116
116
|
_LinkArtifactFields,
|
117
117
|
ensure_logged,
|
118
118
|
ensure_not_finalized,
|
@@ -210,6 +210,7 @@ class Artifact:
|
|
210
210
|
metadata: dict[str, Any] | None = None,
|
211
211
|
incremental: bool = False,
|
212
212
|
use_as: str | None = None,
|
213
|
+
storage_region: str | None = None,
|
213
214
|
) -> None:
|
214
215
|
if not re.match(r"^[a-zA-Z0-9_\-.]+$", name):
|
215
216
|
raise ValueError(
|
@@ -268,7 +269,7 @@ class Artifact:
|
|
268
269
|
self._use_as: str | None = None
|
269
270
|
self._state: ArtifactState = ArtifactState.PENDING
|
270
271
|
self._manifest: ArtifactManifest | _DeferredArtifactManifest | None = (
|
271
|
-
ArtifactManifestV1(storage_policy=make_storage_policy())
|
272
|
+
ArtifactManifestV1(storage_policy=make_storage_policy(storage_region))
|
272
273
|
)
|
273
274
|
self._commit_hash: str | None = None
|
274
275
|
self._file_count: int | None = None
|
@@ -286,36 +287,33 @@ class Artifact:
|
|
286
287
|
|
287
288
|
@classmethod
|
288
289
|
def _from_id(cls, artifact_id: str, client: RetryingClient) -> Artifact | None:
|
289
|
-
if
|
290
|
-
return
|
290
|
+
if cached_artifact := artifact_instance_cache.get(artifact_id):
|
291
|
+
return cached_artifact
|
291
292
|
|
292
|
-
query = gql_compat(ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields())
|
293
|
+
query = gql_compat(ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields(client))
|
293
294
|
|
294
295
|
data = client.execute(query, variable_values={"id": artifact_id})
|
295
296
|
result = ArtifactByID.model_validate(data)
|
296
297
|
|
297
|
-
if (
|
298
|
+
if (artifact := result.artifact) is None:
|
298
299
|
return None
|
299
300
|
|
300
|
-
src_collection =
|
301
|
+
src_collection = artifact.artifact_sequence
|
301
302
|
src_project = src_collection.project
|
302
303
|
|
303
304
|
entity_name = src_project.entity_name if src_project else ""
|
304
305
|
project_name = src_project.name if src_project else ""
|
305
306
|
|
306
|
-
name = f"{src_collection.name}:v{
|
307
|
-
|
307
|
+
name = f"{src_collection.name}:v{artifact.version_index}"
|
308
|
+
|
309
|
+
path = FullArtifactPath(prefix=entity_name, project=project_name, name=name)
|
310
|
+
return cls._from_attrs(path, artifact, client)
|
308
311
|
|
309
312
|
@classmethod
|
310
313
|
def _membership_from_name(
|
311
|
-
cls,
|
312
|
-
*,
|
313
|
-
entity: str,
|
314
|
-
project: str,
|
315
|
-
name: str,
|
316
|
-
client: RetryingClient,
|
314
|
+
cls, *, path: FullArtifactPath, client: RetryingClient
|
317
315
|
) -> Artifact:
|
318
|
-
if not
|
316
|
+
if not InternalApi()._server_supports(
|
319
317
|
pb.ServerFeature.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP
|
320
318
|
):
|
321
319
|
raise UnsupportedError(
|
@@ -325,74 +323,70 @@ class Artifact:
|
|
325
323
|
|
326
324
|
query = gql_compat(
|
327
325
|
ARTIFACT_VIA_MEMBERSHIP_BY_NAME_GQL,
|
328
|
-
omit_fields=omit_artifact_fields(
|
326
|
+
omit_fields=omit_artifact_fields(client),
|
329
327
|
)
|
330
|
-
|
331
|
-
|
328
|
+
gql_vars = {
|
329
|
+
"entityName": path.prefix,
|
330
|
+
"projectName": path.project,
|
331
|
+
"name": path.name,
|
332
|
+
}
|
332
333
|
data = client.execute(query, variable_values=gql_vars)
|
333
334
|
result = ArtifactViaMembershipByName.model_validate(data)
|
334
335
|
|
335
|
-
if not (
|
336
|
-
raise ValueError(f"project {project!r} not found under entity {entity!r}")
|
337
|
-
|
338
|
-
if not (acm_attrs := project_attrs.artifact_collection_membership):
|
339
|
-
entity_project = f"{entity}/{project}"
|
336
|
+
if not (project := result.project):
|
340
337
|
raise ValueError(
|
341
|
-
f"
|
338
|
+
f"project {path.project!r} not found under entity {path.prefix!r}"
|
342
339
|
)
|
343
|
-
|
344
|
-
|
345
|
-
|
340
|
+
if not (membership := project.artifact_collection_membership):
|
341
|
+
entity_project = f"{path.prefix}/{path.project}"
|
342
|
+
raise ValueError(
|
343
|
+
f"artifact membership {path.name!r} not found in {entity_project!r}"
|
344
|
+
)
|
345
|
+
return cls._from_membership(membership, target=path, client=client)
|
346
346
|
|
347
347
|
@classmethod
|
348
348
|
def _from_name(
|
349
349
|
cls,
|
350
350
|
*,
|
351
|
-
|
352
|
-
project: str,
|
353
|
-
name: str,
|
351
|
+
path: FullArtifactPath,
|
354
352
|
client: RetryingClient,
|
355
353
|
enable_tracking: bool = False,
|
356
354
|
) -> Artifact:
|
357
|
-
if
|
355
|
+
if InternalApi()._server_supports(
|
358
356
|
pb.ServerFeature.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP
|
359
357
|
):
|
360
|
-
return cls._membership_from_name(
|
361
|
-
entity=entity, project=project, name=name, client=client
|
362
|
-
)
|
363
|
-
|
364
|
-
supports_enable_tracking_gql_var = api.server_project_type_introspection()
|
365
|
-
omit_vars = None if supports_enable_tracking_gql_var else {"enableTracking"}
|
358
|
+
return cls._membership_from_name(path=path, client=client)
|
366
359
|
|
360
|
+
omit_vars = None if supports_enable_tracking_var(client) else {"enableTracking"}
|
367
361
|
gql_vars = {
|
368
|
-
"entityName":
|
369
|
-
"projectName": project,
|
370
|
-
"name": name,
|
362
|
+
"entityName": path.prefix,
|
363
|
+
"projectName": path.project,
|
364
|
+
"name": path.name,
|
371
365
|
"enableTracking": enable_tracking,
|
372
366
|
}
|
373
367
|
query = gql_compat(
|
374
368
|
ARTIFACT_BY_NAME_GQL,
|
375
369
|
omit_variables=omit_vars,
|
376
|
-
omit_fields=omit_artifact_fields(
|
370
|
+
omit_fields=omit_artifact_fields(client),
|
377
371
|
)
|
378
|
-
|
379
372
|
data = client.execute(query, variable_values=gql_vars)
|
380
373
|
result = ArtifactByName.model_validate(data)
|
381
374
|
|
382
|
-
if not (
|
383
|
-
raise ValueError(
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
375
|
+
if not (project := result.project):
|
376
|
+
raise ValueError(
|
377
|
+
f"project {path.project!r} not found under entity {path.prefix!r}"
|
378
|
+
)
|
379
|
+
if not (artifact := project.artifact):
|
380
|
+
entity_project = f"{path.prefix}/{path.project}"
|
381
|
+
raise ValueError(f"artifact {path.name!r} not found in {entity_project!r}")
|
388
382
|
|
389
|
-
return cls._from_attrs(
|
383
|
+
return cls._from_attrs(path, artifact, client)
|
390
384
|
|
391
385
|
@classmethod
|
392
386
|
def _from_membership(
|
393
387
|
cls,
|
394
388
|
membership: MembershipWithArtifact,
|
395
|
-
target:
|
389
|
+
target: FullArtifactPath,
|
396
390
|
client: RetryingClient,
|
397
391
|
) -> Artifact:
|
398
392
|
if not (
|
@@ -410,35 +404,31 @@ class Artifact:
|
|
410
404
|
f"Your request was redirected to the corresponding artifact {name!r} in the new registry. "
|
411
405
|
f"Please update your paths to point to the migrated registry directly, '{proj.name}/{name}'."
|
412
406
|
)
|
413
|
-
|
407
|
+
new_target = replace(target, prefix=proj.entity_name, project=proj.name)
|
414
408
|
else:
|
415
|
-
|
416
|
-
new_project = cast(str, target.project)
|
409
|
+
new_target = copy(target)
|
417
410
|
|
418
411
|
if not (artifact := membership.artifact):
|
419
412
|
raise ValueError(f"Artifact {target.to_str()!r} not found in response")
|
420
413
|
|
421
|
-
return cls._from_attrs(
|
414
|
+
return cls._from_attrs(new_target, artifact, client)
|
422
415
|
|
423
416
|
@classmethod
|
424
417
|
def _from_attrs(
|
425
418
|
cls,
|
426
|
-
|
427
|
-
|
428
|
-
name: str,
|
429
|
-
attrs: dict[str, Any] | ArtifactFragment,
|
419
|
+
path: FullArtifactPath,
|
420
|
+
attrs: ArtifactFragment,
|
430
421
|
client: RetryingClient,
|
431
422
|
aliases: list[str] | None = None,
|
432
423
|
) -> Artifact:
|
433
424
|
# Placeholder is required to skip validation.
|
434
425
|
artifact = cls("placeholder", type="placeholder")
|
435
426
|
artifact._client = client
|
436
|
-
artifact._entity =
|
437
|
-
artifact._project = project
|
438
|
-
artifact._name = name
|
427
|
+
artifact._entity = path.prefix
|
428
|
+
artifact._project = path.project
|
429
|
+
artifact._name = path.name
|
439
430
|
|
440
|
-
|
441
|
-
artifact._assign_attrs(validated_attrs, aliases)
|
431
|
+
artifact._assign_attrs(attrs, aliases)
|
442
432
|
|
443
433
|
artifact.finalize()
|
444
434
|
|
@@ -743,13 +733,12 @@ class Artifact:
|
|
743
733
|
raise ValueError("Client is not initialized")
|
744
734
|
|
745
735
|
try:
|
746
|
-
|
747
|
-
|
736
|
+
path = FullArtifactPath(
|
737
|
+
prefix=self.source_entity,
|
748
738
|
project=self.source_project,
|
749
739
|
name=self.source_name,
|
750
|
-
client=self._client,
|
751
740
|
)
|
752
|
-
self._source_artifact =
|
741
|
+
self._source_artifact = self._from_name(path=path, client=self._client)
|
753
742
|
except Exception as e:
|
754
743
|
raise ValueError(
|
755
744
|
f"Unable to fetch source artifact for linked artifact {self.name}"
|
@@ -1234,8 +1223,9 @@ class Artifact:
|
|
1234
1223
|
def _populate_after_save(self, artifact_id: str) -> None:
|
1235
1224
|
assert self._client is not None
|
1236
1225
|
|
1237
|
-
query = gql_compat(
|
1238
|
-
|
1226
|
+
query = gql_compat(
|
1227
|
+
ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields(self._client)
|
1228
|
+
)
|
1239
1229
|
data = self._client.execute(query, variable_values={"id": artifact_id})
|
1240
1230
|
result = ArtifactByID.model_validate(data)
|
1241
1231
|
|
@@ -1258,21 +1248,9 @@ class Artifact:
|
|
1258
1248
|
collection = self.name.split(":")[0]
|
1259
1249
|
|
1260
1250
|
aliases = None
|
1261
|
-
introspect_query = gql(
|
1262
|
-
"""
|
1263
|
-
query ProbeServerAddAliasesInput {
|
1264
|
-
AddAliasesInputInfoType: __type(name: "AddAliasesInput") {
|
1265
|
-
name
|
1266
|
-
inputFields {
|
1267
|
-
name
|
1268
|
-
}
|
1269
|
-
}
|
1270
|
-
}
|
1271
|
-
"""
|
1272
|
-
)
|
1273
1251
|
|
1274
|
-
|
1275
|
-
|
1252
|
+
if type_info(self._client, "AddAliasesInput") is not None:
|
1253
|
+
# wandb backend version >= 0.13.0
|
1276
1254
|
alias_props = {
|
1277
1255
|
"entity_name": entity,
|
1278
1256
|
"project_name": project,
|
@@ -1330,7 +1308,7 @@ class Artifact:
|
|
1330
1308
|
for alias in self.aliases
|
1331
1309
|
]
|
1332
1310
|
|
1333
|
-
omit_fields = omit_artifact_fields()
|
1311
|
+
omit_fields = omit_artifact_fields(self._client)
|
1334
1312
|
omit_variables = set()
|
1335
1313
|
|
1336
1314
|
if {"ttlIsInherited", "ttlDurationSeconds"} & omit_fields:
|
@@ -2068,24 +2046,14 @@ class Artifact:
|
|
2068
2046
|
|
2069
2047
|
download_logger = ArtifactDownloadLogger(nfiles=nfiles)
|
2070
2048
|
|
2071
|
-
def _download_entry(
|
2072
|
-
|
2073
|
-
|
2074
|
-
|
2075
|
-
|
2076
|
-
|
2077
|
-
) -> None:
|
2078
|
-
_thread_local_api_settings.api_key = api_key
|
2079
|
-
_thread_local_api_settings.cookies = cookies
|
2080
|
-
_thread_local_api_settings.headers = headers
|
2081
|
-
|
2049
|
+
def _download_entry(entry: ArtifactManifestEntry, executor: Executor) -> None:
|
2050
|
+
multipart_executor = (
|
2051
|
+
executor
|
2052
|
+
if should_multipart_download(entry.size, override=multipart)
|
2053
|
+
else None
|
2054
|
+
)
|
2082
2055
|
try:
|
2083
|
-
entry.download(
|
2084
|
-
root,
|
2085
|
-
skip_cache=skip_cache,
|
2086
|
-
executor=executor,
|
2087
|
-
multipart=multipart,
|
2088
|
-
)
|
2056
|
+
entry.download(root, skip_cache=skip_cache, executor=multipart_executor)
|
2089
2057
|
except FileNotFoundError as e:
|
2090
2058
|
if allow_missing_references:
|
2091
2059
|
wandb.termwarn(str(e))
|
@@ -2096,15 +2064,23 @@ class Artifact:
|
|
2096
2064
|
return
|
2097
2065
|
download_logger.notify_downloaded()
|
2098
2066
|
|
2099
|
-
|
2100
|
-
|
2101
|
-
|
2102
|
-
|
2103
|
-
|
2104
|
-
|
2105
|
-
|
2106
|
-
)
|
2067
|
+
def _init_thread(
|
2068
|
+
api_key: str | None, cookies: dict | None, headers: dict | None
|
2069
|
+
) -> None:
|
2070
|
+
"""Initialize the thread-local API settings in the CURRENT thread."""
|
2071
|
+
_thread_local_api_settings.api_key = api_key
|
2072
|
+
_thread_local_api_settings.cookies = cookies
|
2073
|
+
_thread_local_api_settings.headers = headers
|
2107
2074
|
|
2075
|
+
with ThreadPoolExecutor(
|
2076
|
+
max_workers=64,
|
2077
|
+
initializer=_init_thread,
|
2078
|
+
initargs=(
|
2079
|
+
_thread_local_api_settings.api_key,
|
2080
|
+
_thread_local_api_settings.cookies,
|
2081
|
+
_thread_local_api_settings.headers,
|
2082
|
+
),
|
2083
|
+
) as executor:
|
2108
2084
|
batch_size = env.get_artifact_fetch_file_url_batch_size()
|
2109
2085
|
|
2110
2086
|
active_futures = set()
|
@@ -2125,7 +2101,9 @@ class Artifact:
|
|
2125
2101
|
# continue
|
2126
2102
|
entry._download_url = node.direct_url
|
2127
2103
|
if (not path_prefix) or entry.path.startswith(str(path_prefix)):
|
2128
|
-
active_futures.add(
|
2104
|
+
active_futures.add(
|
2105
|
+
executor.submit(_download_entry, entry, executor=executor)
|
2106
|
+
)
|
2129
2107
|
|
2130
2108
|
# Wait for download threads to catch up.
|
2131
2109
|
#
|
@@ -2140,14 +2118,14 @@ class Artifact:
|
|
2140
2118
|
# Consider this for a future change, or (depending on risk and risk tolerance)
|
2141
2119
|
# managing this logic via asyncio instead, if viable.
|
2142
2120
|
if len(active_futures) > batch_size:
|
2143
|
-
for future in
|
2121
|
+
for future in as_completed(active_futures):
|
2144
2122
|
future.result() # check for errors
|
2145
2123
|
active_futures.remove(future)
|
2146
2124
|
if len(active_futures) <= batch_size:
|
2147
2125
|
break
|
2148
2126
|
|
2149
2127
|
# Check for errors.
|
2150
|
-
for future in
|
2128
|
+
for future in as_completed(active_futures):
|
2151
2129
|
future.result()
|
2152
2130
|
|
2153
2131
|
if log:
|
@@ -2427,7 +2405,7 @@ class Artifact:
|
|
2427
2405
|
|
2428
2406
|
# Parse the entity (first part of the path) appropriately,
|
2429
2407
|
# depending on whether we're linking to a registry
|
2430
|
-
if target.
|
2408
|
+
if target.is_registry_path():
|
2431
2409
|
# In a Registry linking, the entity is used to fetch the organization of the artifact
|
2432
2410
|
# therefore the source artifact's entity is passed to the backend
|
2433
2411
|
org = target.prefix or settings.get("organization") or ""
|
@@ -2435,6 +2413,9 @@ class Artifact:
|
|
2435
2413
|
else:
|
2436
2414
|
target = target.with_defaults(prefix=self.source_entity)
|
2437
2415
|
|
2416
|
+
# Explicitly convert to FullArtifactPath to ensure all fields are present
|
2417
|
+
target = FullArtifactPath(**asdict(target))
|
2418
|
+
|
2438
2419
|
# Prepare the validated GQL input, send it
|
2439
2420
|
alias_inputs = [
|
2440
2421
|
ArtifactAliasInput(artifact_collection_name=target.name, alias=a)
|
@@ -26,6 +26,11 @@ class Opener(Protocol):
|
|
26
26
|
pass
|
27
27
|
|
28
28
|
|
29
|
+
def artifacts_cache_dir() -> Path:
|
30
|
+
"""Get the artifacts cache directory."""
|
31
|
+
return env.get_cache_dir() / "artifacts"
|
32
|
+
|
33
|
+
|
29
34
|
def _get_sys_umask_threadsafe() -> int:
|
30
35
|
# Workaround to get the current system umask, since
|
31
36
|
# - `os.umask()` isn't thread-safe
|
@@ -248,4 +253,4 @@ def _build_artifact_file_cache(cache_dir: StrPath) -> ArtifactFileCache:
|
|
248
253
|
|
249
254
|
|
250
255
|
def get_artifact_file_cache() -> ArtifactFileCache:
|
251
|
-
return _build_artifact_file_cache(
|
256
|
+
return _build_artifact_file_cache(artifacts_cache_dir())
|
@@ -3,9 +3,11 @@
|
|
3
3
|
from __future__ import annotations
|
4
4
|
|
5
5
|
import concurrent.futures
|
6
|
+
import hashlib
|
6
7
|
import json
|
7
8
|
import logging
|
8
9
|
import os
|
10
|
+
from contextlib import suppress
|
9
11
|
from pathlib import Path
|
10
12
|
from typing import TYPE_CHECKING
|
11
13
|
from urllib.parse import urlparse
|
@@ -43,6 +45,48 @@ if TYPE_CHECKING:
|
|
43
45
|
_WB_ARTIFACT_SCHEME = "wandb-artifact"
|
44
46
|
|
45
47
|
|
48
|
+
def _checksum_cache_path(file_path: str) -> str:
|
49
|
+
"""Get path for checksum in central cache directory."""
|
50
|
+
from wandb.sdk.artifacts.artifact_file_cache import artifacts_cache_dir
|
51
|
+
|
52
|
+
# Create a unique cache key based on the file's absolute path
|
53
|
+
abs_path = os.path.abspath(file_path)
|
54
|
+
path_hash = hashlib.sha256(abs_path.encode()).hexdigest()
|
55
|
+
|
56
|
+
# Store in wandb cache directory under checksums subdirectory
|
57
|
+
cache_dir = artifacts_cache_dir() / "checksums"
|
58
|
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
59
|
+
|
60
|
+
return str(cache_dir / f"{path_hash}.checksum")
|
61
|
+
|
62
|
+
|
63
|
+
def _read_cached_checksum(file_path: str) -> str | None:
|
64
|
+
"""Read checksum from cache if it exists and is valid."""
|
65
|
+
checksum_path = _checksum_cache_path(file_path)
|
66
|
+
|
67
|
+
try:
|
68
|
+
with open(file_path) as f, open(checksum_path) as f_checksum:
|
69
|
+
if os.path.getmtime(f_checksum.name) < os.path.getmtime(f.name):
|
70
|
+
# File was modified after checksum was written
|
71
|
+
return None
|
72
|
+
# Read and return the cached checksum
|
73
|
+
return f_checksum.read().strip()
|
74
|
+
except OSError:
|
75
|
+
# File doesn't exist or couldn't be opened
|
76
|
+
return None
|
77
|
+
|
78
|
+
|
79
|
+
def _write_cached_checksum(file_path: str, checksum: str) -> None:
|
80
|
+
"""Write checksum to cache directory."""
|
81
|
+
checksum_path = _checksum_cache_path(file_path)
|
82
|
+
try:
|
83
|
+
with open(checksum_path, "w") as f:
|
84
|
+
f.write(checksum)
|
85
|
+
except OSError:
|
86
|
+
# Non-critical failure, just log it
|
87
|
+
logger.debug(f"Failed to write checksum cache for {file_path!r}")
|
88
|
+
|
89
|
+
|
46
90
|
class ArtifactManifestEntry:
|
47
91
|
"""A single entry in an artifact manifest."""
|
48
92
|
|
@@ -138,7 +182,6 @@ class ArtifactManifestEntry:
|
|
138
182
|
root: str | None = None,
|
139
183
|
skip_cache: bool | None = None,
|
140
184
|
executor: concurrent.futures.Executor | None = None,
|
141
|
-
multipart: bool | None = None,
|
142
185
|
) -> FilePathStr:
|
143
186
|
"""Download this artifact entry to the specified root path.
|
144
187
|
|
@@ -149,11 +192,10 @@ class ArtifactManifestEntry:
|
|
149
192
|
Returns:
|
150
193
|
(str): The path of the downloaded artifact entry.
|
151
194
|
"""
|
152
|
-
|
153
|
-
raise NotImplementedError
|
195
|
+
artifact = self.parent_artifact()
|
154
196
|
|
155
|
-
root = root or
|
156
|
-
|
197
|
+
root = root or artifact._default_root()
|
198
|
+
artifact._add_download_root(root)
|
157
199
|
path = str(Path(self.path))
|
158
200
|
dest_path = os.path.join(root, path)
|
159
201
|
|
@@ -164,32 +206,43 @@ class ArtifactManifestEntry:
|
|
164
206
|
|
165
207
|
# Skip checking the cache (and possibly downloading) if the file already exists
|
166
208
|
# and has the digest we're expecting.
|
209
|
+
|
210
|
+
# Fast integrity check using cached checksum from persistent cache
|
211
|
+
with suppress(OSError):
|
212
|
+
if self.digest == _read_cached_checksum(dest_path):
|
213
|
+
return FilePathStr(dest_path)
|
214
|
+
|
215
|
+
# Fallback to computing/caching the checksum hash
|
167
216
|
try:
|
168
217
|
md5_hash = md5_file_b64(dest_path)
|
169
218
|
except (FileNotFoundError, IsADirectoryError):
|
170
|
-
logger.debug(f"unable to find {dest_path}, skip searching for file")
|
219
|
+
logger.debug(f"unable to find {dest_path!r}, skip searching for file")
|
171
220
|
else:
|
221
|
+
_write_cached_checksum(dest_path, md5_hash)
|
172
222
|
if self.digest == md5_hash:
|
173
223
|
return FilePathStr(dest_path)
|
174
224
|
|
175
225
|
if self.ref is not None:
|
176
|
-
cache_path =
|
226
|
+
cache_path = artifact.manifest.storage_policy.load_reference(
|
177
227
|
self, local=True, dest_path=override_cache_path
|
178
228
|
)
|
179
229
|
else:
|
180
|
-
cache_path =
|
181
|
-
self
|
182
|
-
self,
|
183
|
-
dest_path=override_cache_path,
|
184
|
-
executor=executor,
|
185
|
-
multipart=multipart,
|
230
|
+
cache_path = artifact.manifest.storage_policy.load_file(
|
231
|
+
artifact, self, dest_path=override_cache_path, executor=executor
|
186
232
|
)
|
187
|
-
|
233
|
+
|
234
|
+
# Determine the final path
|
235
|
+
final_path = (
|
188
236
|
dest_path
|
189
237
|
if skip_cache
|
190
238
|
else copy_or_overwrite_changed(cache_path, dest_path)
|
191
239
|
)
|
192
240
|
|
241
|
+
# Cache the checksum for future downloads
|
242
|
+
_write_cached_checksum(str(final_path), self.digest)
|
243
|
+
|
244
|
+
return FilePathStr(final_path)
|
245
|
+
|
193
246
|
def ref_target(self) -> FilePathStr | URIStr:
|
194
247
|
"""Get the reference URL that is targeted by this artifact entry.
|
195
248
|
|
@@ -2,7 +2,8 @@
|
|
2
2
|
|
3
3
|
from __future__ import annotations
|
4
4
|
|
5
|
-
from
|
5
|
+
from abc import ABC, abstractmethod
|
6
|
+
from typing import TYPE_CHECKING, Final
|
6
7
|
|
7
8
|
from wandb.sdk.lib.paths import FilePathStr, URIStr
|
8
9
|
|
@@ -12,18 +13,11 @@ if TYPE_CHECKING:
|
|
12
13
|
from wandb.sdk.artifacts.artifact import Artifact
|
13
14
|
from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
|
14
15
|
|
15
|
-
DEFAULT_MAX_OBJECTS = 10**7
|
16
|
+
DEFAULT_MAX_OBJECTS: Final[int] = 10_000_000 # 10**7
|
16
17
|
|
17
18
|
|
18
|
-
class
|
19
|
-
|
20
|
-
"""Checks whether this handler can handle the given url.
|
21
|
-
|
22
|
-
Returns:
|
23
|
-
Whether this handler can handle the given url.
|
24
|
-
"""
|
25
|
-
raise NotImplementedError
|
26
|
-
|
19
|
+
class _BaseStorageHandler(ABC):
|
20
|
+
@abstractmethod
|
27
21
|
def load_path(
|
28
22
|
self,
|
29
23
|
manifest_entry: ArtifactManifestEntry,
|
@@ -40,6 +34,7 @@ class StorageHandler:
|
|
40
34
|
"""
|
41
35
|
raise NotImplementedError
|
42
36
|
|
37
|
+
@abstractmethod
|
43
38
|
def store_path(
|
44
39
|
self,
|
45
40
|
artifact: Artifact,
|
@@ -47,7 +42,7 @@ class StorageHandler:
|
|
47
42
|
name: str | None = None,
|
48
43
|
checksum: bool = True,
|
49
44
|
max_objects: int | None = None,
|
50
|
-
) ->
|
45
|
+
) -> list[ArtifactManifestEntry]:
|
51
46
|
"""Store the file or directory at the given path to the specified artifact.
|
52
47
|
|
53
48
|
Args:
|
@@ -60,3 +55,14 @@ class StorageHandler:
|
|
60
55
|
A list of manifest entries to store within the artifact
|
61
56
|
"""
|
62
57
|
raise NotImplementedError
|
58
|
+
|
59
|
+
|
60
|
+
class StorageHandler(_BaseStorageHandler, ABC): # Handles a single storage protocol
|
61
|
+
@abstractmethod
|
62
|
+
def can_handle(self, parsed_url: ParseResult) -> bool:
|
63
|
+
"""Checks whether this handler can handle the given url.
|
64
|
+
|
65
|
+
Returns:
|
66
|
+
Whether this handler can handle the given url.
|
67
|
+
"""
|
68
|
+
raise NotImplementedError
|