wandb 0.21.4__py3-none-musllinux_1_2_aarch64.whl → 0.22.1__py3-none-musllinux_1_2_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. wandb/__init__.py +1 -1
  2. wandb/__init__.pyi +3 -3
  3. wandb/_pydantic/__init__.py +12 -11
  4. wandb/_pydantic/base.py +49 -19
  5. wandb/apis/__init__.py +2 -0
  6. wandb/apis/attrs.py +2 -0
  7. wandb/apis/importers/internals/internal.py +16 -23
  8. wandb/apis/internal.py +2 -0
  9. wandb/apis/normalize.py +2 -0
  10. wandb/apis/public/__init__.py +44 -1
  11. wandb/apis/public/api.py +215 -164
  12. wandb/apis/public/artifacts.py +23 -20
  13. wandb/apis/public/const.py +2 -0
  14. wandb/apis/public/files.py +33 -24
  15. wandb/apis/public/history.py +2 -0
  16. wandb/apis/public/jobs.py +20 -18
  17. wandb/apis/public/projects.py +4 -2
  18. wandb/apis/public/query_generator.py +3 -0
  19. wandb/apis/public/registries/__init__.py +7 -0
  20. wandb/apis/public/registries/_freezable_list.py +9 -12
  21. wandb/apis/public/registries/registries_search.py +8 -6
  22. wandb/apis/public/registries/registry.py +22 -17
  23. wandb/apis/public/reports.py +2 -0
  24. wandb/apis/public/runs.py +282 -60
  25. wandb/apis/public/sweeps.py +10 -9
  26. wandb/apis/public/teams.py +2 -0
  27. wandb/apis/public/users.py +2 -0
  28. wandb/apis/public/utils.py +16 -15
  29. wandb/automations/_generated/__init__.py +54 -127
  30. wandb/automations/_generated/create_generic_webhook_integration.py +1 -7
  31. wandb/automations/_generated/fragments.py +26 -91
  32. wandb/bin/gpu_stats +0 -0
  33. wandb/bin/wandb-core +0 -0
  34. wandb/cli/beta_sync.py +9 -11
  35. wandb/errors/errors.py +3 -3
  36. wandb/proto/v3/wandb_internal_pb2.py +234 -224
  37. wandb/proto/v3/wandb_sync_pb2.py +19 -6
  38. wandb/proto/v4/wandb_internal_pb2.py +226 -224
  39. wandb/proto/v4/wandb_sync_pb2.py +10 -6
  40. wandb/proto/v5/wandb_internal_pb2.py +226 -224
  41. wandb/proto/v5/wandb_sync_pb2.py +10 -6
  42. wandb/proto/v6/wandb_base_pb2.py +3 -3
  43. wandb/proto/v6/wandb_internal_pb2.py +229 -227
  44. wandb/proto/v6/wandb_server_pb2.py +3 -3
  45. wandb/proto/v6/wandb_settings_pb2.py +3 -3
  46. wandb/proto/v6/wandb_sync_pb2.py +13 -9
  47. wandb/proto/v6/wandb_telemetry_pb2.py +3 -3
  48. wandb/sdk/artifacts/_factories.py +7 -2
  49. wandb/sdk/artifacts/_generated/__init__.py +112 -412
  50. wandb/sdk/artifacts/_generated/fragments.py +65 -0
  51. wandb/sdk/artifacts/_generated/operations.py +52 -22
  52. wandb/sdk/artifacts/_generated/run_input_artifacts.py +3 -23
  53. wandb/sdk/artifacts/_generated/run_output_artifacts.py +3 -23
  54. wandb/sdk/artifacts/_generated/type_info.py +19 -0
  55. wandb/sdk/artifacts/_gqlutils.py +47 -0
  56. wandb/sdk/artifacts/_models/__init__.py +4 -0
  57. wandb/sdk/artifacts/_models/base_model.py +20 -0
  58. wandb/sdk/artifacts/_validators.py +40 -12
  59. wandb/sdk/artifacts/artifact.py +69 -88
  60. wandb/sdk/artifacts/artifact_file_cache.py +6 -1
  61. wandb/sdk/artifacts/artifact_manifest_entry.py +61 -2
  62. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +1 -1
  63. wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -3
  64. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -1
  65. wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
  66. wandb/sdk/artifacts/storage_policies/_factories.py +63 -0
  67. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +69 -124
  68. wandb/sdk/data_types/bokeh.py +5 -1
  69. wandb/sdk/data_types/image.py +17 -6
  70. wandb/sdk/interface/interface.py +41 -4
  71. wandb/sdk/interface/interface_queue.py +10 -0
  72. wandb/sdk/interface/interface_shared.py +9 -7
  73. wandb/sdk/interface/interface_sock.py +9 -3
  74. wandb/sdk/internal/_generated/__init__.py +2 -12
  75. wandb/sdk/internal/sender.py +1 -1
  76. wandb/sdk/internal/settings_static.py +2 -82
  77. wandb/sdk/launch/runner/kubernetes_runner.py +25 -20
  78. wandb/sdk/launch/utils.py +82 -1
  79. wandb/sdk/lib/progress.py +7 -4
  80. wandb/sdk/lib/service/service_client.py +5 -9
  81. wandb/sdk/lib/service/service_connection.py +39 -23
  82. wandb/sdk/mailbox/mailbox_handle.py +2 -0
  83. wandb/sdk/projects/_generated/__init__.py +12 -33
  84. wandb/sdk/wandb_init.py +31 -3
  85. wandb/sdk/wandb_login.py +53 -27
  86. wandb/sdk/wandb_run.py +5 -3
  87. wandb/sdk/wandb_settings.py +50 -13
  88. wandb/sync/sync.py +7 -2
  89. wandb/util.py +1 -1
  90. wandb/wandb_agent.py +35 -4
  91. {wandb-0.21.4.dist-info → wandb-0.22.1.dist-info}/METADATA +1 -1
  92. {wandb-0.21.4.dist-info → wandb-0.22.1.dist-info}/RECORD +818 -814
  93. wandb/sdk/artifacts/_graphql_fragments.py +0 -19
  94. {wandb-0.21.4.dist-info → wandb-0.22.1.dist-info}/WHEEL +0 -0
  95. {wandb-0.21.4.dist-info → wandb-0.22.1.dist-info}/entry_points.txt +0 -0
  96. {wandb-0.21.4.dist-info → wandb-0.22.1.dist-info}/licenses/LICENSE +0 -0
@@ -16,7 +16,7 @@ import tempfile
16
16
  import time
17
17
  from collections import deque
18
18
  from copy import copy
19
- from dataclasses import dataclass
19
+ from dataclasses import asdict, dataclass, replace
20
20
  from datetime import timedelta
21
21
  from functools import partial
22
22
  from itertools import filterfalse
@@ -30,7 +30,6 @@ from typing import (
30
30
  Literal,
31
31
  Sequence,
32
32
  Type,
33
- cast,
34
33
  final,
35
34
  )
36
35
  from urllib.parse import quote, urljoin, urlparse
@@ -109,10 +108,11 @@ from ._generated import (
109
108
  TagInput,
110
109
  UpdateArtifact,
111
110
  )
112
- from ._graphql_fragments import omit_artifact_fields
111
+ from ._gqlutils import omit_artifact_fields, supports_enable_tracking_var, type_info
113
112
  from ._validators import (
114
113
  LINKED_ARTIFACT_COLLECTION_TYPE,
115
114
  ArtifactPath,
115
+ FullArtifactPath,
116
116
  _LinkArtifactFields,
117
117
  ensure_logged,
118
118
  ensure_not_finalized,
@@ -210,6 +210,7 @@ class Artifact:
210
210
  metadata: dict[str, Any] | None = None,
211
211
  incremental: bool = False,
212
212
  use_as: str | None = None,
213
+ storage_region: str | None = None,
213
214
  ) -> None:
214
215
  if not re.match(r"^[a-zA-Z0-9_\-.]+$", name):
215
216
  raise ValueError(
@@ -268,7 +269,7 @@ class Artifact:
268
269
  self._use_as: str | None = None
269
270
  self._state: ArtifactState = ArtifactState.PENDING
270
271
  self._manifest: ArtifactManifest | _DeferredArtifactManifest | None = (
271
- ArtifactManifestV1(storage_policy=make_storage_policy())
272
+ ArtifactManifestV1(storage_policy=make_storage_policy(storage_region))
272
273
  )
273
274
  self._commit_hash: str | None = None
274
275
  self._file_count: int | None = None
@@ -286,36 +287,33 @@ class Artifact:
286
287
 
287
288
  @classmethod
288
289
  def _from_id(cls, artifact_id: str, client: RetryingClient) -> Artifact | None:
289
- if (artifact := artifact_instance_cache.get(artifact_id)) is not None:
290
- return artifact
290
+ if cached_artifact := artifact_instance_cache.get(artifact_id):
291
+ return cached_artifact
291
292
 
292
- query = gql_compat(ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields())
293
+ query = gql_compat(ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields(client))
293
294
 
294
295
  data = client.execute(query, variable_values={"id": artifact_id})
295
296
  result = ArtifactByID.model_validate(data)
296
297
 
297
- if (art := result.artifact) is None:
298
+ if (artifact := result.artifact) is None:
298
299
  return None
299
300
 
300
- src_collection = art.artifact_sequence
301
+ src_collection = artifact.artifact_sequence
301
302
  src_project = src_collection.project
302
303
 
303
304
  entity_name = src_project.entity_name if src_project else ""
304
305
  project_name = src_project.name if src_project else ""
305
306
 
306
- name = f"{src_collection.name}:v{art.version_index}"
307
- return cls._from_attrs(entity_name, project_name, name, art, client)
307
+ name = f"{src_collection.name}:v{artifact.version_index}"
308
+
309
+ path = FullArtifactPath(prefix=entity_name, project=project_name, name=name)
310
+ return cls._from_attrs(path, artifact, client)
308
311
 
309
312
  @classmethod
310
313
  def _membership_from_name(
311
- cls,
312
- *,
313
- entity: str,
314
- project: str,
315
- name: str,
316
- client: RetryingClient,
314
+ cls, *, path: FullArtifactPath, client: RetryingClient
317
315
  ) -> Artifact:
318
- if not (api := InternalApi())._server_supports(
316
+ if not InternalApi()._server_supports(
319
317
  pb.ServerFeature.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP
320
318
  ):
321
319
  raise UnsupportedError(
@@ -325,74 +323,70 @@ class Artifact:
325
323
 
326
324
  query = gql_compat(
327
325
  ARTIFACT_VIA_MEMBERSHIP_BY_NAME_GQL,
328
- omit_fields=omit_artifact_fields(api=api),
326
+ omit_fields=omit_artifact_fields(client),
329
327
  )
330
-
331
- gql_vars = {"entityName": entity, "projectName": project, "name": name}
328
+ gql_vars = {
329
+ "entityName": path.prefix,
330
+ "projectName": path.project,
331
+ "name": path.name,
332
+ }
332
333
  data = client.execute(query, variable_values=gql_vars)
333
334
  result = ArtifactViaMembershipByName.model_validate(data)
334
335
 
335
- if not (project_attrs := result.project):
336
- raise ValueError(f"project {project!r} not found under entity {entity!r}")
337
-
338
- if not (acm_attrs := project_attrs.artifact_collection_membership):
339
- entity_project = f"{entity}/{project}"
336
+ if not (project := result.project):
340
337
  raise ValueError(
341
- f"artifact membership {name!r} not found in {entity_project!r}"
338
+ f"project {path.project!r} not found under entity {path.prefix!r}"
342
339
  )
343
-
344
- target_path = ArtifactPath(prefix=entity, project=project, name=name)
345
- return cls._from_membership(acm_attrs, target=target_path, client=client)
340
+ if not (membership := project.artifact_collection_membership):
341
+ entity_project = f"{path.prefix}/{path.project}"
342
+ raise ValueError(
343
+ f"artifact membership {path.name!r} not found in {entity_project!r}"
344
+ )
345
+ return cls._from_membership(membership, target=path, client=client)
346
346
 
347
347
  @classmethod
348
348
  def _from_name(
349
349
  cls,
350
350
  *,
351
- entity: str,
352
- project: str,
353
- name: str,
351
+ path: FullArtifactPath,
354
352
  client: RetryingClient,
355
353
  enable_tracking: bool = False,
356
354
  ) -> Artifact:
357
- if (api := InternalApi())._server_supports(
355
+ if InternalApi()._server_supports(
358
356
  pb.ServerFeature.PROJECT_ARTIFACT_COLLECTION_MEMBERSHIP
359
357
  ):
360
- return cls._membership_from_name(
361
- entity=entity, project=project, name=name, client=client
362
- )
363
-
364
- supports_enable_tracking_gql_var = api.server_project_type_introspection()
365
- omit_vars = None if supports_enable_tracking_gql_var else {"enableTracking"}
358
+ return cls._membership_from_name(path=path, client=client)
366
359
 
360
+ omit_vars = None if supports_enable_tracking_var(client) else {"enableTracking"}
367
361
  gql_vars = {
368
- "entityName": entity,
369
- "projectName": project,
370
- "name": name,
362
+ "entityName": path.prefix,
363
+ "projectName": path.project,
364
+ "name": path.name,
371
365
  "enableTracking": enable_tracking,
372
366
  }
373
367
  query = gql_compat(
374
368
  ARTIFACT_BY_NAME_GQL,
375
369
  omit_variables=omit_vars,
376
- omit_fields=omit_artifact_fields(api=api),
370
+ omit_fields=omit_artifact_fields(client),
377
371
  )
378
-
379
372
  data = client.execute(query, variable_values=gql_vars)
380
373
  result = ArtifactByName.model_validate(data)
381
374
 
382
- if not (proj_attrs := result.project):
383
- raise ValueError(f"project {project!r} not found under entity {entity!r}")
384
-
385
- if not (art_attrs := proj_attrs.artifact):
386
- entity_project = f"{entity}/{project}"
387
- raise ValueError(f"artifact {name!r} not found in {entity_project!r}")
375
+ if not (project := result.project):
376
+ raise ValueError(
377
+ f"project {path.project!r} not found under entity {path.prefix!r}"
378
+ )
379
+ if not (artifact := project.artifact):
380
+ entity_project = f"{path.prefix}/{path.project}"
381
+ raise ValueError(f"artifact {path.name!r} not found in {entity_project!r}")
388
382
 
389
- return cls._from_attrs(entity, project, name, art_attrs, client)
383
+ return cls._from_attrs(path, artifact, client)
390
384
 
391
385
  @classmethod
392
386
  def _from_membership(
393
387
  cls,
394
388
  membership: MembershipWithArtifact,
395
- target: ArtifactPath,
389
+ target: FullArtifactPath,
396
390
  client: RetryingClient,
397
391
  ) -> Artifact:
398
392
  if not (
@@ -410,35 +404,31 @@ class Artifact:
410
404
  f"Your request was redirected to the corresponding artifact {name!r} in the new registry. "
411
405
  f"Please update your paths to point to the migrated registry directly, '{proj.name}/{name}'."
412
406
  )
413
- new_entity, new_project = proj.entity_name, proj.name
407
+ new_target = replace(target, prefix=proj.entity_name, project=proj.name)
414
408
  else:
415
- new_entity = cast(str, target.prefix)
416
- new_project = cast(str, target.project)
409
+ new_target = copy(target)
417
410
 
418
411
  if not (artifact := membership.artifact):
419
412
  raise ValueError(f"Artifact {target.to_str()!r} not found in response")
420
413
 
421
- return cls._from_attrs(new_entity, new_project, target.name, artifact, client)
414
+ return cls._from_attrs(new_target, artifact, client)
422
415
 
423
416
  @classmethod
424
417
  def _from_attrs(
425
418
  cls,
426
- entity: str,
427
- project: str,
428
- name: str,
429
- attrs: dict[str, Any] | ArtifactFragment,
419
+ path: FullArtifactPath,
420
+ attrs: ArtifactFragment,
430
421
  client: RetryingClient,
431
422
  aliases: list[str] | None = None,
432
423
  ) -> Artifact:
433
424
  # Placeholder is required to skip validation.
434
425
  artifact = cls("placeholder", type="placeholder")
435
426
  artifact._client = client
436
- artifact._entity = entity
437
- artifact._project = project
438
- artifact._name = name
427
+ artifact._entity = path.prefix
428
+ artifact._project = path.project
429
+ artifact._name = path.name
439
430
 
440
- validated_attrs = ArtifactFragment.model_validate(attrs)
441
- artifact._assign_attrs(validated_attrs, aliases)
431
+ artifact._assign_attrs(attrs, aliases)
442
432
 
443
433
  artifact.finalize()
444
434
 
@@ -743,13 +733,12 @@ class Artifact:
743
733
  raise ValueError("Client is not initialized")
744
734
 
745
735
  try:
746
- artifact = self._from_name(
747
- entity=self.source_entity,
736
+ path = FullArtifactPath(
737
+ prefix=self.source_entity,
748
738
  project=self.source_project,
749
739
  name=self.source_name,
750
- client=self._client,
751
740
  )
752
- self._source_artifact = artifact
741
+ self._source_artifact = self._from_name(path=path, client=self._client)
753
742
  except Exception as e:
754
743
  raise ValueError(
755
744
  f"Unable to fetch source artifact for linked artifact {self.name}"
@@ -1234,8 +1223,9 @@ class Artifact:
1234
1223
  def _populate_after_save(self, artifact_id: str) -> None:
1235
1224
  assert self._client is not None
1236
1225
 
1237
- query = gql_compat(ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields())
1238
-
1226
+ query = gql_compat(
1227
+ ARTIFACT_BY_ID_GQL, omit_fields=omit_artifact_fields(self._client)
1228
+ )
1239
1229
  data = self._client.execute(query, variable_values={"id": artifact_id})
1240
1230
  result = ArtifactByID.model_validate(data)
1241
1231
 
@@ -1258,21 +1248,9 @@ class Artifact:
1258
1248
  collection = self.name.split(":")[0]
1259
1249
 
1260
1250
  aliases = None
1261
- introspect_query = gql(
1262
- """
1263
- query ProbeServerAddAliasesInput {
1264
- AddAliasesInputInfoType: __type(name: "AddAliasesInput") {
1265
- name
1266
- inputFields {
1267
- name
1268
- }
1269
- }
1270
- }
1271
- """
1272
- )
1273
1251
 
1274
- data = self._client.execute(introspect_query)
1275
- if data.get("AddAliasesInputInfoType"): # wandb backend version >= 0.13.0
1252
+ if type_info(self._client, "AddAliasesInput") is not None:
1253
+ # wandb backend version >= 0.13.0
1276
1254
  alias_props = {
1277
1255
  "entity_name": entity,
1278
1256
  "project_name": project,
@@ -1330,7 +1308,7 @@ class Artifact:
1330
1308
  for alias in self.aliases
1331
1309
  ]
1332
1310
 
1333
- omit_fields = omit_artifact_fields()
1311
+ omit_fields = omit_artifact_fields(self._client)
1334
1312
  omit_variables = set()
1335
1313
 
1336
1314
  if {"ttlIsInherited", "ttlDurationSeconds"} & omit_fields:
@@ -2427,7 +2405,7 @@ class Artifact:
2427
2405
 
2428
2406
  # Parse the entity (first part of the path) appropriately,
2429
2407
  # depending on whether we're linking to a registry
2430
- if target.project and is_artifact_registry_project(target.project):
2408
+ if target.is_registry_path():
2431
2409
  # In a Registry linking, the entity is used to fetch the organization of the artifact
2432
2410
  # therefore the source artifact's entity is passed to the backend
2433
2411
  org = target.prefix or settings.get("organization") or ""
@@ -2435,6 +2413,9 @@ class Artifact:
2435
2413
  else:
2436
2414
  target = target.with_defaults(prefix=self.source_entity)
2437
2415
 
2416
+ # Explicitly convert to FullArtifactPath to ensure all fields are present
2417
+ target = FullArtifactPath(**asdict(target))
2418
+
2438
2419
  # Prepare the validated GQL input, send it
2439
2420
  alias_inputs = [
2440
2421
  ArtifactAliasInput(artifact_collection_name=target.name, alias=a)
@@ -26,6 +26,11 @@ class Opener(Protocol):
26
26
  pass
27
27
 
28
28
 
29
+ def artifacts_cache_dir() -> Path:
30
+ """Get the artifacts cache directory."""
31
+ return env.get_cache_dir() / "artifacts"
32
+
33
+
29
34
  def _get_sys_umask_threadsafe() -> int:
30
35
  # Workaround to get the current system umask, since
31
36
  # - `os.umask()` isn't thread-safe
@@ -248,4 +253,4 @@ def _build_artifact_file_cache(cache_dir: StrPath) -> ArtifactFileCache:
248
253
 
249
254
 
250
255
  def get_artifact_file_cache() -> ArtifactFileCache:
251
- return _build_artifact_file_cache(env.get_cache_dir() / "artifacts")
256
+ return _build_artifact_file_cache(artifacts_cache_dir())
@@ -3,9 +3,11 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import concurrent.futures
6
+ import hashlib
6
7
  import json
7
8
  import logging
8
9
  import os
10
+ from contextlib import suppress
9
11
  from pathlib import Path
10
12
  from typing import TYPE_CHECKING
11
13
  from urllib.parse import urlparse
@@ -43,6 +45,48 @@ if TYPE_CHECKING:
43
45
  _WB_ARTIFACT_SCHEME = "wandb-artifact"
44
46
 
45
47
 
48
+ def _checksum_cache_path(file_path: str) -> str:
49
+ """Get path for checksum in central cache directory."""
50
+ from wandb.sdk.artifacts.artifact_file_cache import artifacts_cache_dir
51
+
52
+ # Create a unique cache key based on the file's absolute path
53
+ abs_path = os.path.abspath(file_path)
54
+ path_hash = hashlib.sha256(abs_path.encode()).hexdigest()
55
+
56
+ # Store in wandb cache directory under checksums subdirectory
57
+ cache_dir = artifacts_cache_dir() / "checksums"
58
+ cache_dir.mkdir(parents=True, exist_ok=True)
59
+
60
+ return str(cache_dir / f"{path_hash}.checksum")
61
+
62
+
63
+ def _read_cached_checksum(file_path: str) -> str | None:
64
+ """Read checksum from cache if it exists and is valid."""
65
+ checksum_path = _checksum_cache_path(file_path)
66
+
67
+ try:
68
+ with open(file_path) as f, open(checksum_path) as f_checksum:
69
+ if os.path.getmtime(f_checksum.name) < os.path.getmtime(f.name):
70
+ # File was modified after checksum was written
71
+ return None
72
+ # Read and return the cached checksum
73
+ return f_checksum.read().strip()
74
+ except OSError:
75
+ # File doesn't exist or couldn't be opened
76
+ return None
77
+
78
+
79
+ def _write_cached_checksum(file_path: str, checksum: str) -> None:
80
+ """Write checksum to cache directory."""
81
+ checksum_path = _checksum_cache_path(file_path)
82
+ try:
83
+ with open(checksum_path, "w") as f:
84
+ f.write(checksum)
85
+ except OSError:
86
+ # Non-critical failure, just log it
87
+ logger.debug(f"Failed to write checksum cache for {file_path!r}")
88
+
89
+
46
90
  class ArtifactManifestEntry:
47
91
  """A single entry in an artifact manifest."""
48
92
 
@@ -164,11 +208,19 @@ class ArtifactManifestEntry:
164
208
 
165
209
  # Skip checking the cache (and possibly downloading) if the file already exists
166
210
  # and has the digest we're expecting.
211
+
212
+ # Fast integrity check using cached checksum from persistent cache
213
+ with suppress(OSError):
214
+ if self.digest == _read_cached_checksum(dest_path):
215
+ return FilePathStr(dest_path)
216
+
217
+ # Fallback to computing/caching the checksum hash
167
218
  try:
168
219
  md5_hash = md5_file_b64(dest_path)
169
220
  except (FileNotFoundError, IsADirectoryError):
170
- logger.debug(f"unable to find {dest_path}, skip searching for file")
221
+ logger.debug(f"unable to find {dest_path!r}, skip searching for file")
171
222
  else:
223
+ _write_cached_checksum(dest_path, md5_hash)
172
224
  if self.digest == md5_hash:
173
225
  return FilePathStr(dest_path)
174
226
 
@@ -184,12 +236,19 @@ class ArtifactManifestEntry:
184
236
  executor=executor,
185
237
  multipart=multipart,
186
238
  )
187
- return FilePathStr(
239
+
240
+ # Determine the final path
241
+ final_path = (
188
242
  dest_path
189
243
  if skip_cache
190
244
  else copy_or_overwrite_changed(cache_path, dest_path)
191
245
  )
192
246
 
247
+ # Cache the checksum for future downloads
248
+ _write_cached_checksum(str(final_path), self.digest)
249
+
250
+ return FilePathStr(final_path)
251
+
193
252
  def ref_target(self) -> FilePathStr | URIStr:
194
253
  """Get the reference URL that is targeted by this artifact entry.
195
254
 
@@ -65,7 +65,7 @@ class GCSHandler(StorageHandler):
65
65
  path, hit, cache_open = self._cache.check_etag_obj_path(
66
66
  url=URIStr(manifest_entry.ref),
67
67
  etag=ETag(manifest_entry.digest),
68
- size=manifest_entry.size if manifest_entry.size is not None else 0,
68
+ size=manifest_entry.size or 0,
69
69
  )
70
70
  if hit:
71
71
  return path
@@ -43,7 +43,7 @@ class HTTPHandler(StorageHandler):
43
43
  path, hit, cache_open = self._cache.check_etag_obj_path(
44
44
  URIStr(manifest_entry.ref),
45
45
  ETag(manifest_entry.digest),
46
- manifest_entry.size if manifest_entry.size is not None else 0,
46
+ manifest_entry.size or 0,
47
47
  )
48
48
  if hit:
49
49
  return path
@@ -54,7 +54,6 @@ class HTTPHandler(StorageHandler):
54
54
  cookies=_thread_local_api_settings.cookies,
55
55
  headers=_thread_local_api_settings.headers,
56
56
  )
57
- response.raise_for_status()
58
57
 
59
58
  digest: ETag | FilePathStr | URIStr | None
60
59
  digest, size, extra = self._entry_from_headers(response.headers)
@@ -87,7 +86,6 @@ class HTTPHandler(StorageHandler):
87
86
  cookies=_thread_local_api_settings.cookies,
88
87
  headers=_thread_local_api_settings.headers,
89
88
  ) as response:
90
- response.raise_for_status()
91
89
  digest: ETag | FilePathStr | URIStr | None
92
90
  digest, size, extra = self._entry_from_headers(response.headers)
93
91
  digest = digest or path
@@ -51,7 +51,7 @@ class LocalFileHandler(StorageHandler):
51
51
 
52
52
  path, hit, cache_open = self._cache.check_md5_obj_path(
53
53
  B64MD5(manifest_entry.digest), # TODO(spencerpearson): unsafe cast
54
- manifest_entry.size if manifest_entry.size is not None else 0,
54
+ manifest_entry.size or 0,
55
55
  )
56
56
  if hit:
57
57
  return path
@@ -96,7 +96,7 @@ class S3Handler(StorageHandler):
96
96
  path, hit, cache_open = self._cache.check_etag_obj_path(
97
97
  URIStr(manifest_entry.ref),
98
98
  ETag(manifest_entry.digest),
99
- manifest_entry.size if manifest_entry.size is not None else 0,
99
+ manifest_entry.size or 0,
100
100
  )
101
101
  if hit:
102
102
  return path
@@ -0,0 +1,63 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Final
4
+
5
+ from requests import Response, Session
6
+ from requests.adapters import HTTPAdapter
7
+ from urllib3.util.retry import Retry
8
+
9
+ from ..storage_handler import StorageHandler
10
+ from ..storage_handlers.azure_handler import AzureHandler
11
+ from ..storage_handlers.gcs_handler import GCSHandler
12
+ from ..storage_handlers.http_handler import HTTPHandler
13
+ from ..storage_handlers.local_file_handler import LocalFileHandler
14
+ from ..storage_handlers.s3_handler import S3Handler
15
+ from ..storage_handlers.wb_artifact_handler import WBArtifactHandler
16
+ from ..storage_handlers.wb_local_artifact_handler import WBLocalArtifactHandler
17
+
18
+ # Sleep length: 0, 2, 4, 8, 16, 32, 64, 120, 120, 120, 120, 120, 120, 120, 120, 120
19
+ # seconds, i.e. a total of 20min 6s.
20
+ HTTP_RETRY_STRATEGY: Final[Retry] = Retry(
21
+ backoff_factor=1,
22
+ total=16,
23
+ status_forcelist=(308, 408, 409, 429, 500, 502, 503, 504),
24
+ )
25
+ HTTP_POOL_CONNECTIONS: Final[int] = 64
26
+ HTTP_POOL_MAXSIZE: Final[int] = 64
27
+
28
+
29
+ def raise_for_status(response: Response, *_, **__) -> None:
30
+ """A `requests.Session` hook to raise for status on all requests."""
31
+ response.raise_for_status()
32
+
33
+
34
+ def make_http_session() -> Session:
35
+ """A factory that returns a `requests.Session` for use with artifact storage handlers."""
36
+ session = Session()
37
+
38
+ # Explicitly configure the retry strategy for http/https adapters.
39
+ adapter = HTTPAdapter(
40
+ max_retries=HTTP_RETRY_STRATEGY,
41
+ pool_connections=HTTP_POOL_CONNECTIONS,
42
+ pool_maxsize=HTTP_POOL_MAXSIZE,
43
+ )
44
+ session.mount("http://", adapter)
45
+ session.mount("https://", adapter)
46
+
47
+ # Always raise on HTTP status errors.
48
+ session.hooks["response"].append(raise_for_status)
49
+ return session
50
+
51
+
52
+ def make_storage_handlers(session: Session) -> list[StorageHandler]:
53
+ """A factory that returns the default artifact storage handlers."""
54
+ return [
55
+ S3Handler(), # s3
56
+ GCSHandler(), # gcs
57
+ AzureHandler(), # azure
58
+ HTTPHandler(session, scheme="http"), # http
59
+ HTTPHandler(session, scheme="https"), # https
60
+ WBArtifactHandler(), # artifact
61
+ WBLocalArtifactHandler(), # local_artifact
62
+ LocalFileHandler(), # file_handler
63
+ ]