wandb 0.22.0__py3-none-win_amd64.whl → 0.22.2__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114) hide show
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +8 -5
  3. wandb/_pydantic/__init__.py +12 -11
  4. wandb/_pydantic/base.py +49 -19
  5. wandb/apis/__init__.py +2 -0
  6. wandb/apis/attrs.py +2 -0
  7. wandb/apis/importers/internals/internal.py +16 -23
  8. wandb/apis/internal.py +2 -0
  9. wandb/apis/normalize.py +2 -0
  10. wandb/apis/public/__init__.py +3 -2
  11. wandb/apis/public/api.py +215 -164
  12. wandb/apis/public/artifacts.py +23 -20
  13. wandb/apis/public/const.py +2 -0
  14. wandb/apis/public/files.py +33 -24
  15. wandb/apis/public/history.py +2 -0
  16. wandb/apis/public/jobs.py +20 -18
  17. wandb/apis/public/projects.py +4 -2
  18. wandb/apis/public/query_generator.py +3 -0
  19. wandb/apis/public/registries/__init__.py +7 -0
  20. wandb/apis/public/registries/_freezable_list.py +9 -12
  21. wandb/apis/public/registries/registries_search.py +8 -6
  22. wandb/apis/public/registries/registry.py +22 -17
  23. wandb/apis/public/reports.py +2 -0
  24. wandb/apis/public/runs.py +261 -57
  25. wandb/apis/public/sweeps.py +10 -9
  26. wandb/apis/public/teams.py +2 -0
  27. wandb/apis/public/users.py +2 -0
  28. wandb/apis/public/utils.py +16 -15
  29. wandb/automations/_generated/__init__.py +54 -127
  30. wandb/automations/_generated/create_generic_webhook_integration.py +1 -7
  31. wandb/automations/_generated/fragments.py +26 -91
  32. wandb/bin/gpu_stats.exe +0 -0
  33. wandb/bin/wandb-core +0 -0
  34. wandb/cli/beta.py +16 -2
  35. wandb/cli/beta_leet.py +74 -0
  36. wandb/cli/beta_sync.py +9 -11
  37. wandb/cli/cli.py +34 -7
  38. wandb/errors/errors.py +3 -3
  39. wandb/proto/v3/wandb_api_pb2.py +86 -0
  40. wandb/proto/v3/wandb_internal_pb2.py +352 -351
  41. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  42. wandb/proto/v3/wandb_sync_pb2.py +19 -6
  43. wandb/proto/v4/wandb_api_pb2.py +37 -0
  44. wandb/proto/v4/wandb_internal_pb2.py +352 -351
  45. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  46. wandb/proto/v4/wandb_sync_pb2.py +10 -6
  47. wandb/proto/v5/wandb_api_pb2.py +38 -0
  48. wandb/proto/v5/wandb_internal_pb2.py +352 -351
  49. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  50. wandb/proto/v5/wandb_sync_pb2.py +10 -6
  51. wandb/proto/v6/wandb_api_pb2.py +48 -0
  52. wandb/proto/v6/wandb_internal_pb2.py +352 -351
  53. wandb/proto/v6/wandb_settings_pb2.py +2 -2
  54. wandb/proto/v6/wandb_sync_pb2.py +10 -6
  55. wandb/proto/wandb_api_pb2.py +18 -0
  56. wandb/proto/wandb_generate_proto.py +1 -0
  57. wandb/sdk/artifacts/_factories.py +7 -2
  58. wandb/sdk/artifacts/_generated/__init__.py +112 -412
  59. wandb/sdk/artifacts/_generated/fragments.py +65 -0
  60. wandb/sdk/artifacts/_generated/operations.py +52 -22
  61. wandb/sdk/artifacts/_generated/run_input_artifacts.py +3 -23
  62. wandb/sdk/artifacts/_generated/run_output_artifacts.py +3 -23
  63. wandb/sdk/artifacts/_generated/type_info.py +19 -0
  64. wandb/sdk/artifacts/_gqlutils.py +47 -0
  65. wandb/sdk/artifacts/_models/__init__.py +4 -0
  66. wandb/sdk/artifacts/_models/base_model.py +20 -0
  67. wandb/sdk/artifacts/_validators.py +40 -12
  68. wandb/sdk/artifacts/artifact.py +99 -118
  69. wandb/sdk/artifacts/artifact_file_cache.py +6 -1
  70. wandb/sdk/artifacts/artifact_manifest_entry.py +67 -14
  71. wandb/sdk/artifacts/storage_handler.py +18 -12
  72. wandb/sdk/artifacts/storage_handlers/azure_handler.py +11 -6
  73. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +9 -6
  74. wandb/sdk/artifacts/storage_handlers/http_handler.py +9 -4
  75. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +10 -6
  76. wandb/sdk/artifacts/storage_handlers/multi_handler.py +5 -4
  77. wandb/sdk/artifacts/storage_handlers/s3_handler.py +10 -8
  78. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +6 -4
  79. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +24 -21
  80. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +4 -2
  81. wandb/sdk/artifacts/storage_policies/_multipart.py +187 -0
  82. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +71 -242
  83. wandb/sdk/artifacts/storage_policy.py +25 -12
  84. wandb/sdk/data_types/bokeh.py +5 -1
  85. wandb/sdk/data_types/image.py +17 -6
  86. wandb/sdk/data_types/object_3d.py +67 -2
  87. wandb/sdk/interface/interface.py +31 -4
  88. wandb/sdk/interface/interface_queue.py +10 -0
  89. wandb/sdk/interface/interface_shared.py +0 -7
  90. wandb/sdk/interface/interface_sock.py +9 -3
  91. wandb/sdk/internal/_generated/__init__.py +2 -12
  92. wandb/sdk/internal/job_builder.py +27 -10
  93. wandb/sdk/internal/sender.py +5 -2
  94. wandb/sdk/internal/settings_static.py +2 -82
  95. wandb/sdk/launch/create_job.py +2 -1
  96. wandb/sdk/launch/runner/kubernetes_runner.py +25 -20
  97. wandb/sdk/launch/utils.py +82 -1
  98. wandb/sdk/lib/progress.py +8 -74
  99. wandb/sdk/lib/service/service_client.py +5 -9
  100. wandb/sdk/lib/service/service_connection.py +39 -23
  101. wandb/sdk/mailbox/mailbox_handle.py +2 -0
  102. wandb/sdk/projects/_generated/__init__.py +12 -33
  103. wandb/sdk/wandb_init.py +23 -3
  104. wandb/sdk/wandb_login.py +53 -27
  105. wandb/sdk/wandb_run.py +10 -5
  106. wandb/sdk/wandb_settings.py +63 -25
  107. wandb/sync/sync.py +7 -2
  108. wandb/util.py +1 -1
  109. {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/METADATA +1 -1
  110. {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/RECORD +113 -103
  111. wandb/sdk/artifacts/_graphql_fragments.py +0 -19
  112. {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/WHEEL +0 -0
  113. {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/entry_points.txt +0 -0
  114. {wandb-0.22.0.dist-info → wandb-0.22.2.dist-info}/licenses/LICENSE +0 -0
@@ -3,7 +3,6 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import atexit
6
- import concurrent.futures
7
6
  import contextlib
8
7
  import json
9
8
  import logging
@@ -15,10 +14,10 @@ import stat
15
14
  import tempfile
16
15
  import time
17
16
  from collections import deque
17
+ from concurrent.futures import Executor, ThreadPoolExecutor, as_completed
18
18
  from copy import copy
19
- from dataclasses import dataclass
19
+ from dataclasses import asdict, dataclass, replace
20
20
  from datetime import timedelta
21
- from functools import partial
22
21
  from itertools import filterfalse
23
22
  from pathlib import Path, PurePosixPath
24
23
  from typing import (
@@ -30,7 +29,6 @@ from typing import (
30
29
  Literal,
31
30
  Sequence,
32
31
  Type,
33
- cast,
34
32
  final,
35
33
  )
36
34
  from urllib.parse import quote, urljoin, urlparse
@@ -51,6 +49,7 @@ from wandb.errors.term import termerror, termlog, termwarn
51
49
  from wandb.proto import wandb_internal_pb2 as pb
52
50
  from wandb.proto.wandb_deprecated import Deprecated
53
51
  from wandb.sdk import wandb_setup
52
+ from wandb.sdk.artifacts.storage_policies._multipart import should_multipart_download
54
53
  from wandb.sdk.data_types._dtypes import Type as WBType
55
54
  from wandb.sdk.data_types._dtypes import TypeRegistry
56
55
  from wandb.sdk.internal.internal_api import Api as InternalApi
@@ -109,10 +108,11 @@ from ._generated import (
109
108
  TagInput,
110
109
  UpdateArtifact,
111
110
  )
112
- from ._graphql_fragments import omit_artifact_fields
111
+ from ._gqlutils import omit_artifact_fields, supports_enable_tracking_var, type_info
113
112
  from ._validators import (
114
113
  LINKED_ARTIFACT_COLLECTION_TYPE,
115
114
  ArtifactPath,
115
+ FullArtifactPath,
116
116
  _LinkArtifactFields,
117
117
  ensure_logged,
118
118
  ensure_not_finalized,
@@ -210,6 +210,7 @@ class Artifact:
210
210
  metadata: dict[str, Any] | None = None,
211
211
  incremental: bool = False,
212
212
  use_as: str | None = None,
213
+ storage_region: str | None = None,
213
214
  ) -> None:
214
215
  if not re.match(r"^[a-zA-Z0-9_\-.]+$", name):
215
216
  raise ValueError(
@@ -268,7 +269,7 @@ class Artifact:
268
269
  self._use_as: str | None = None
269
270
  self._state: ArtifactState = ArtifactState.PENDING
270
271
  self._manifest: ArtifactManifest | _DeferredArtifactManifest | None = (
271
- ArtifactManifestV1(storage_policy=make_storage_policy())
272
+ ArtifactManifestV1(storage_policy=make_storage_policy(storage_region))
272
273
  )
273
274
  self._commit_hash: str | None = None
274
275
  self._file_count: int | None = None
@@ -286,36 +287,33 @@ class Artifact:
286
287
 
287
288
  @classmethod
288
289
  def _from_id(cls, artifact_id: str, client: RetryingClient) -> Artifact | None:
289
- if (artifact := artifact_instance_cache.get(artifact_id)) is not None:
290
- return artifact
290
+ if cached_artifact := artifact_instance_cache.get(artifact_id):
291
+ return cached_artifact
291
292
 
292
- query = gql_compat(ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields())
293
+ query = gql_compat(ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields(client))
293
294
 
294
295
  data = client.execute(query, variable_values={"id": artifact_id})
295
296
  result = ArtifactByID.model_validate(data)
296
297
 
297
- if (art := result.artifact) is None:
298
+ if (artifact := result.artifact) is None:
298
299
  return None
299
300
 
300
- src_collection = art.artifact_sequence
301
+ src_collection = artifact.artifact_sequence
301
302
  src_project = src_collection.project
302
303
 
303
304
  entity_name = src_project.entity_name if src_project else ""
304
305
  project_name = src_project.name if src_project else ""
305
306
 
306
- name = f"{src_collection.name}:v{art.version_index}"
307
- return cls._from_attrs(entity_name, project_name, name, art, client)
307
+ name = f"{src_collection.name}:v{artifact.version_index}"
308
+
309
+ path = FullArtifactPath(prefix=entity_name, project=project_name, name=name)
310
+ return cls._from_attrs(path, artifact, client)
308
311
 
309
312
  @classmethod
310
313
  def _membership_from_name(
311
- cls,
312
- *,
313
- entity: str,
314
- project: str,
315
- name: str,
316
- client: RetryingClient,
314
+ cls, *, path: FullArtifactPath, client: RetryingClient
317
315
  ) -> Artifact:
318
- if not (api := InternalApi())._server_supports(
316
+ if not InternalApi()._server_supports(
319
317
  pb.ServerFeature.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP
320
318
  ):
321
319
  raise UnsupportedError(
@@ -325,74 +323,70 @@ class Artifact:
325
323
 
326
324
  query = gql_compat(
327
325
  ARTIFACT_VIA_MEMBERSHIP_BY_NAME_GQL,
328
- omit_fields=omit_artifact_fields(api=api),
326
+ omit_fields=omit_artifact_fields(client),
329
327
  )
330
-
331
- gql_vars = {"entityName": entity, "projectName": project, "name": name}
328
+ gql_vars = {
329
+ "entityName": path.prefix,
330
+ "projectName": path.project,
331
+ "name": path.name,
332
+ }
332
333
  data = client.execute(query, variable_values=gql_vars)
333
334
  result = ArtifactViaMembershipByName.model_validate(data)
334
335
 
335
- if not (project_attrs := result.project):
336
- raise ValueError(f"project {project!r} not found under entity {entity!r}")
337
-
338
- if not (acm_attrs := project_attrs.artifact_collection_membership):
339
- entity_project = f"{entity}/{project}"
336
+ if not (project := result.project):
340
337
  raise ValueError(
341
- f"artifact membership {name!r} not found in {entity_project!r}"
338
+ f"project {path.project!r} not found under entity {path.prefix!r}"
342
339
  )
343
-
344
- target_path = ArtifactPath(prefix=entity, project=project, name=name)
345
- return cls._from_membership(acm_attrs, target=target_path, client=client)
340
+ if not (membership := project.artifact_collection_membership):
341
+ entity_project = f"{path.prefix}/{path.project}"
342
+ raise ValueError(
343
+ f"artifact membership {path.name!r} not found in {entity_project!r}"
344
+ )
345
+ return cls._from_membership(membership, target=path, client=client)
346
346
 
347
347
  @classmethod
348
348
  def _from_name(
349
349
  cls,
350
350
  *,
351
- entity: str,
352
- project: str,
353
- name: str,
351
+ path: FullArtifactPath,
354
352
  client: RetryingClient,
355
353
  enable_tracking: bool = False,
356
354
  ) -> Artifact:
357
- if (api := InternalApi())._server_supports(
355
+ if InternalApi()._server_supports(
358
356
  pb.ServerFeature.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP
359
357
  ):
360
- return cls._membership_from_name(
361
- entity=entity, project=project, name=name, client=client
362
- )
363
-
364
- supports_enable_tracking_gql_var = api.server_project_type_introspection()
365
- omit_vars = None if supports_enable_tracking_gql_var else {"enableTracking"}
358
+ return cls._membership_from_name(path=path, client=client)
366
359
 
360
+ omit_vars = None if supports_enable_tracking_var(client) else {"enableTracking"}
367
361
  gql_vars = {
368
- "entityName": entity,
369
- "projectName": project,
370
- "name": name,
362
+ "entityName": path.prefix,
363
+ "projectName": path.project,
364
+ "name": path.name,
371
365
  "enableTracking": enable_tracking,
372
366
  }
373
367
  query = gql_compat(
374
368
  ARTIFACT_BY_NAME_GQL,
375
369
  omit_variables=omit_vars,
376
- omit_fields=omit_artifact_fields(api=api),
370
+ omit_fields=omit_artifact_fields(client),
377
371
  )
378
-
379
372
  data = client.execute(query, variable_values=gql_vars)
380
373
  result = ArtifactByName.model_validate(data)
381
374
 
382
- if not (proj_attrs := result.project):
383
- raise ValueError(f"project {project!r} not found under entity {entity!r}")
384
-
385
- if not (art_attrs := proj_attrs.artifact):
386
- entity_project = f"{entity}/{project}"
387
- raise ValueError(f"artifact {name!r} not found in {entity_project!r}")
375
+ if not (project := result.project):
376
+ raise ValueError(
377
+ f"project {path.project!r} not found under entity {path.prefix!r}"
378
+ )
379
+ if not (artifact := project.artifact):
380
+ entity_project = f"{path.prefix}/{path.project}"
381
+ raise ValueError(f"artifact {path.name!r} not found in {entity_project!r}")
388
382
 
389
- return cls._from_attrs(entity, project, name, art_attrs, client)
383
+ return cls._from_attrs(path, artifact, client)
390
384
 
391
385
  @classmethod
392
386
  def _from_membership(
393
387
  cls,
394
388
  membership: MembershipWithArtifact,
395
- target: ArtifactPath,
389
+ target: FullArtifactPath,
396
390
  client: RetryingClient,
397
391
  ) -> Artifact:
398
392
  if not (
@@ -410,35 +404,31 @@ class Artifact:
410
404
  f"Your request was redirected to the corresponding artifact {name!r} in the new registry. "
411
405
  f"Please update your paths to point to the migrated registry directly, '{proj.name}/{name}'."
412
406
  )
413
- new_entity, new_project = proj.entity_name, proj.name
407
+ new_target = replace(target, prefix=proj.entity_name, project=proj.name)
414
408
  else:
415
- new_entity = cast(str, target.prefix)
416
- new_project = cast(str, target.project)
409
+ new_target = copy(target)
417
410
 
418
411
  if not (artifact := membership.artifact):
419
412
  raise ValueError(f"Artifact {target.to_str()!r} not found in response")
420
413
 
421
- return cls._from_attrs(new_entity, new_project, target.name, artifact, client)
414
+ return cls._from_attrs(new_target, artifact, client)
422
415
 
423
416
  @classmethod
424
417
  def _from_attrs(
425
418
  cls,
426
- entity: str,
427
- project: str,
428
- name: str,
429
- attrs: dict[str, Any] | ArtifactFragment,
419
+ path: FullArtifactPath,
420
+ attrs: ArtifactFragment,
430
421
  client: RetryingClient,
431
422
  aliases: list[str] | None = None,
432
423
  ) -> Artifact:
433
424
  # Placeholder is required to skip validation.
434
425
  artifact = cls("placeholder", type="placeholder")
435
426
  artifact._client = client
436
- artifact._entity = entity
437
- artifact._project = project
438
- artifact._name = name
427
+ artifact._entity = path.prefix
428
+ artifact._project = path.project
429
+ artifact._name = path.name
439
430
 
440
- validated_attrs = ArtifactFragment.model_validate(attrs)
441
- artifact._assign_attrs(validated_attrs, aliases)
431
+ artifact._assign_attrs(attrs, aliases)
442
432
 
443
433
  artifact.finalize()
444
434
 
@@ -743,13 +733,12 @@ class Artifact:
743
733
  raise ValueError("Client is not initialized")
744
734
 
745
735
  try:
746
- artifact = self._from_name(
747
- entity=self.source_entity,
736
+ path = FullArtifactPath(
737
+ prefix=self.source_entity,
748
738
  project=self.source_project,
749
739
  name=self.source_name,
750
- client=self._client,
751
740
  )
752
- self._source_artifact = artifact
741
+ self._source_artifact = self._from_name(path=path, client=self._client)
753
742
  except Exception as e:
754
743
  raise ValueError(
755
744
  f"Unable to fetch source artifact for linked artifact {self.name}"
@@ -1234,8 +1223,9 @@ class Artifact:
1234
1223
  def _populate_after_save(self, artifact_id: str) -> None:
1235
1224
  assert self._client is not None
1236
1225
 
1237
- query = gql_compat(ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields())
1238
-
1226
+ query = gql_compat(
1227
+ ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields(self._client)
1228
+ )
1239
1229
  data = self._client.execute(query, variable_values={"id": artifact_id})
1240
1230
  result = ArtifactByID.model_validate(data)
1241
1231
 
@@ -1258,21 +1248,9 @@ class Artifact:
1258
1248
  collection = self.name.split(":")[0]
1259
1249
 
1260
1250
  aliases = None
1261
- introspect_query = gql(
1262
- """
1263
- query ProbeServerAddAliasesInput {
1264
- AddAliasesInputInfoType: __type(name: "AddAliasesInput") {
1265
- name
1266
- inputFields {
1267
- name
1268
- }
1269
- }
1270
- }
1271
- """
1272
- )
1273
1251
 
1274
- data = self._client.execute(introspect_query)
1275
- if data.get("AddAliasesInputInfoType"): # wandb backend version >= 0.13.0
1252
+ if type_info(self._client, "AddAliasesInput") is not None:
1253
+ # wandb backend version >= 0.13.0
1276
1254
  alias_props = {
1277
1255
  "entity_name": entity,
1278
1256
  "project_name": project,
@@ -1330,7 +1308,7 @@ class Artifact:
1330
1308
  for alias in self.aliases
1331
1309
  ]
1332
1310
 
1333
- omit_fields = omit_artifact_fields()
1311
+ omit_fields = omit_artifact_fields(self._client)
1334
1312
  omit_variables = set()
1335
1313
 
1336
1314
  if {"ttlIsInherited", "ttlDurationSeconds"} & omit_fields:
@@ -2068,24 +2046,14 @@ class Artifact:
2068
2046
 
2069
2047
  download_logger = ArtifactDownloadLogger(nfiles=nfiles)
2070
2048
 
2071
- def _download_entry(
2072
- entry: ArtifactManifestEntry,
2073
- executor: concurrent.futures.Executor,
2074
- api_key: str | None,
2075
- cookies: dict | None,
2076
- headers: dict | None,
2077
- ) -> None:
2078
- _thread_local_api_settings.api_key = api_key
2079
- _thread_local_api_settings.cookies = cookies
2080
- _thread_local_api_settings.headers = headers
2081
-
2049
+ def _download_entry(entry: ArtifactManifestEntry, executor: Executor) -> None:
2050
+ multipart_executor = (
2051
+ executor
2052
+ if should_multipart_download(entry.size, override=multipart)
2053
+ else None
2054
+ )
2082
2055
  try:
2083
- entry.download(
2084
- root,
2085
- skip_cache=skip_cache,
2086
- executor=executor,
2087
- multipart=multipart,
2088
- )
2056
+ entry.download(root, skip_cache=skip_cache, executor=multipart_executor)
2089
2057
  except FileNotFoundError as e:
2090
2058
  if allow_missing_references:
2091
2059
  wandb.termwarn(str(e))
@@ -2096,15 +2064,23 @@ class Artifact:
2096
2064
  return
2097
2065
  download_logger.notify_downloaded()
2098
2066
 
2099
- with concurrent.futures.ThreadPoolExecutor(64) as executor:
2100
- download_entry = partial(
2101
- _download_entry,
2102
- executor=executor,
2103
- api_key=_thread_local_api_settings.api_key,
2104
- cookies=_thread_local_api_settings.cookies,
2105
- headers=_thread_local_api_settings.headers,
2106
- )
2067
+ def _init_thread(
2068
+ api_key: str | None, cookies: dict | None, headers: dict | None
2069
+ ) -> None:
2070
+ """Initialize the thread-local API settings in the CURRENT thread."""
2071
+ _thread_local_api_settings.api_key = api_key
2072
+ _thread_local_api_settings.cookies = cookies
2073
+ _thread_local_api_settings.headers = headers
2107
2074
 
2075
+ with ThreadPoolExecutor(
2076
+ max_workers=64,
2077
+ initializer=_init_thread,
2078
+ initargs=(
2079
+ _thread_local_api_settings.api_key,
2080
+ _thread_local_api_settings.cookies,
2081
+ _thread_local_api_settings.headers,
2082
+ ),
2083
+ ) as executor:
2108
2084
  batch_size = env.get_artifact_fetch_file_url_batch_size()
2109
2085
 
2110
2086
  active_futures = set()
@@ -2125,7 +2101,9 @@ class Artifact:
2125
2101
  # continue
2126
2102
  entry._download_url = node.direct_url
2127
2103
  if (not path_prefix) or entry.path.startswith(str(path_prefix)):
2128
- active_futures.add(executor.submit(download_entry, entry))
2104
+ active_futures.add(
2105
+ executor.submit(_download_entry, entry, executor=executor)
2106
+ )
2129
2107
 
2130
2108
  # Wait for download threads to catch up.
2131
2109
  #
@@ -2140,14 +2118,14 @@ class Artifact:
2140
2118
  # Consider this for a future change, or (depending on risk and risk tolerance)
2141
2119
  # managing this logic via asyncio instead, if viable.
2142
2120
  if len(active_futures) > batch_size:
2143
- for future in concurrent.futures.as_completed(active_futures):
2121
+ for future in as_completed(active_futures):
2144
2122
  future.result() # check for errors
2145
2123
  active_futures.remove(future)
2146
2124
  if len(active_futures) <= batch_size:
2147
2125
  break
2148
2126
 
2149
2127
  # Check for errors.
2150
- for future in concurrent.futures.as_completed(active_futures):
2128
+ for future in as_completed(active_futures):
2151
2129
  future.result()
2152
2130
 
2153
2131
  if log:
@@ -2427,7 +2405,7 @@ class Artifact:
2427
2405
 
2428
2406
  # Parse the entity (first part of the path) appropriately,
2429
2407
  # depending on whether we're linking to a registry
2430
- if target.project and is_artifact_registry_project(target.project):
2408
+ if target.is_registry_path():
2431
2409
  # In a Registry linking, the entity is used to fetch the organization of the artifact
2432
2410
  # therefore the source artifact's entity is passed to the backend
2433
2411
  org = target.prefix or settings.get("organization") or ""
@@ -2435,6 +2413,9 @@ class Artifact:
2435
2413
  else:
2436
2414
  target = target.with_defaults(prefix=self.source_entity)
2437
2415
 
2416
+ # Explicitly convert to FullArtifactPath to ensure all fields are present
2417
+ target = FullArtifactPath(**asdict(target))
2418
+
2438
2419
  # Prepare the validated GQL input, send it
2439
2420
  alias_inputs = [
2440
2421
  ArtifactAliasInput(artifact_collection_name=target.name, alias=a)
@@ -26,6 +26,11 @@ class Opener(Protocol):
26
26
  pass
27
27
 
28
28
 
29
+ def artifacts_cache_dir() -> Path:
30
+ """Get the artifacts cache directory."""
31
+ return env.get_cache_dir() / "artifacts"
32
+
33
+
29
34
  def _get_sys_umask_threadsafe() -> int:
30
35
  # Workaround to get the current system umask, since
31
36
  # - `os.umask()` isn't thread-safe
@@ -248,4 +253,4 @@ def _build_artifact_file_cache(cache_dir: StrPath) -> ArtifactFileCache:
248
253
 
249
254
 
250
255
  def get_artifact_file_cache() -> ArtifactFileCache:
251
- return _build_artifact_file_cache(env.get_cache_dir() / "artifacts")
256
+ return _build_artifact_file_cache(artifacts_cache_dir())
@@ -3,9 +3,11 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import concurrent.futures
6
+ import hashlib
6
7
  import json
7
8
  import logging
8
9
  import os
10
+ from contextlib import suppress
9
11
  from pathlib import Path
10
12
  from typing import TYPE_CHECKING
11
13
  from urllib.parse import urlparse
@@ -43,6 +45,48 @@ if TYPE_CHECKING:
43
45
  _WB_ARTIFACT_SCHEME = "wandb-artifact"
44
46
 
45
47
 
48
+ def _checksum_cache_path(file_path: str) -> str:
49
+ """Get path for checksum in central cache directory."""
50
+ from wandb.sdk.artifacts.artifact_file_cache import artifacts_cache_dir
51
+
52
+ # Create a unique cache key based on the file's absolute path
53
+ abs_path = os.path.abspath(file_path)
54
+ path_hash = hashlib.sha256(abs_path.encode()).hexdigest()
55
+
56
+ # Store in wandb cache directory under checksums subdirectory
57
+ cache_dir = artifacts_cache_dir() / "checksums"
58
+ cache_dir.mkdir(parents=True, exist_ok=True)
59
+
60
+ return str(cache_dir / f"{path_hash}.checksum")
61
+
62
+
63
+ def _read_cached_checksum(file_path: str) -> str | None:
64
+ """Read checksum from cache if it exists and is valid."""
65
+ checksum_path = _checksum_cache_path(file_path)
66
+
67
+ try:
68
+ with open(file_path) as f, open(checksum_path) as f_checksum:
69
+ if os.path.getmtime(f_checksum.name) < os.path.getmtime(f.name):
70
+ # File was modified after checksum was written
71
+ return None
72
+ # Read and return the cached checksum
73
+ return f_checksum.read().strip()
74
+ except OSError:
75
+ # File doesn't exist or couldn't be opened
76
+ return None
77
+
78
+
79
+ def _write_cached_checksum(file_path: str, checksum: str) -> None:
80
+ """Write checksum to cache directory."""
81
+ checksum_path = _checksum_cache_path(file_path)
82
+ try:
83
+ with open(checksum_path, "w") as f:
84
+ f.write(checksum)
85
+ except OSError:
86
+ # Non-critical failure, just log it
87
+ logger.debug(f"Failed to write checksum cache for {file_path!r}")
88
+
89
+
46
90
  class ArtifactManifestEntry:
47
91
  """A single entry in an artifact manifest."""
48
92
 
@@ -138,7 +182,6 @@ class ArtifactManifestEntry:
138
182
  root: str | None = None,
139
183
  skip_cache: bool | None = None,
140
184
  executor: concurrent.futures.Executor | None = None,
141
- multipart: bool | None = None,
142
185
  ) -> FilePathStr:
143
186
  """Download this artifact entry to the specified root path.
144
187
 
@@ -149,11 +192,10 @@ class ArtifactManifestEntry:
149
192
  Returns:
150
193
  (str): The path of the downloaded artifact entry.
151
194
  """
152
- if self._parent_artifact is None:
153
- raise NotImplementedError
195
+ artifact = self.parent_artifact()
154
196
 
155
- root = root or self._parent_artifact._default_root()
156
- self._parent_artifact._add_download_root(root)
197
+ root = root or artifact._default_root()
198
+ artifact._add_download_root(root)
157
199
  path = str(Path(self.path))
158
200
  dest_path = os.path.join(root, path)
159
201
 
@@ -164,32 +206,43 @@ class ArtifactManifestEntry:
164
206
 
165
207
  # Skip checking the cache (and possibly downloading) if the file already exists
166
208
  # and has the digest we're expecting.
209
+
210
+ # Fast integrity check using cached checksum from persistent cache
211
+ with suppress(OSError):
212
+ if self.digest == _read_cached_checksum(dest_path):
213
+ return FilePathStr(dest_path)
214
+
215
+ # Fallback to computing/caching the checksum hash
167
216
  try:
168
217
  md5_hash = md5_file_b64(dest_path)
169
218
  except (FileNotFoundError, IsADirectoryError):
170
- logger.debug(f"unable to find {dest_path}, skip searching for file")
219
+ logger.debug(f"unable to find {dest_path!r}, skip searching for file")
171
220
  else:
221
+ _write_cached_checksum(dest_path, md5_hash)
172
222
  if self.digest == md5_hash:
173
223
  return FilePathStr(dest_path)
174
224
 
175
225
  if self.ref is not None:
176
- cache_path = self._parent_artifact.manifest.storage_policy.load_reference(
226
+ cache_path = artifact.manifest.storage_policy.load_reference(
177
227
  self, local=True, dest_path=override_cache_path
178
228
  )
179
229
  else:
180
- cache_path = self._parent_artifact.manifest.storage_policy.load_file(
181
- self._parent_artifact,
182
- self,
183
- dest_path=override_cache_path,
184
- executor=executor,
185
- multipart=multipart,
230
+ cache_path = artifact.manifest.storage_policy.load_file(
231
+ artifact, self, dest_path=override_cache_path, executor=executor
186
232
  )
187
- return FilePathStr(
233
+
234
+ # Determine the final path
235
+ final_path = (
188
236
  dest_path
189
237
  if skip_cache
190
238
  else copy_or_overwrite_changed(cache_path, dest_path)
191
239
  )
192
240
 
241
+ # Cache the checksum for future downloads
242
+ _write_cached_checksum(str(final_path), self.digest)
243
+
244
+ return FilePathStr(final_path)
245
+
193
246
  def ref_target(self) -> FilePathStr | URIStr:
194
247
  """Get the reference URL that is targeted by this artifact entry.
195
248
 
@@ -2,7 +2,8 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from typing import TYPE_CHECKING, Sequence
5
+ from abc import ABC, abstractmethod
6
+ from typing import TYPE_CHECKING, Final
6
7
 
7
8
  from wandb.sdk.lib.paths import FilePathStr, URIStr
8
9
 
@@ -12,18 +13,11 @@ if TYPE_CHECKING:
12
13
  from wandb.sdk.artifacts.artifact import Artifact
13
14
  from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
14
15
 
15
- DEFAULT_MAX_OBJECTS = 10**7
16
+ DEFAULT_MAX_OBJECTS: Final[int] = 10_000_000 # 10**7
16
17
 
17
18
 
18
- class StorageHandler:
19
- def can_handle(self, parsed_url: ParseResult) -> bool:
20
- """Checks whether this handler can handle the given url.
21
-
22
- Returns:
23
- Whether this handler can handle the given url.
24
- """
25
- raise NotImplementedError
26
-
19
+ class _BaseStorageHandler(ABC):
20
+ @abstractmethod
27
21
  def load_path(
28
22
  self,
29
23
  manifest_entry: ArtifactManifestEntry,
@@ -40,6 +34,7 @@ class StorageHandler:
40
34
  """
41
35
  raise NotImplementedError
42
36
 
37
+ @abstractmethod
43
38
  def store_path(
44
39
  self,
45
40
  artifact: Artifact,
@@ -47,7 +42,7 @@ class StorageHandler:
47
42
  name: str | None = None,
48
43
  checksum: bool = True,
49
44
  max_objects: int | None = None,
50
- ) -> Sequence[ArtifactManifestEntry]:
45
+ ) -> list[ArtifactManifestEntry]:
51
46
  """Store the file or directory at the given path to the specified artifact.
52
47
 
53
48
  Args:
@@ -60,3 +55,14 @@ class StorageHandler:
60
55
  A list of manifest entries to store within the artifact
61
56
  """
62
57
  raise NotImplementedError
58
+
59
+
60
+ class StorageHandler(_BaseStorageHandler, ABC): # Handles a single storage protocol
61
+ @abstractmethod
62
+ def can_handle(self, parsed_url: ParseResult) -> bool:
63
+ """Checks whether this handler can handle the given url.
64
+
65
+ Returns:
66
+ Whether this handler can handle the given url.
67
+ """
68
+ raise NotImplementedError