wandb 0.15.3__py3-none-any.whl → 0.15.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (156)
  1. wandb/__init__.py +1 -1
  2. wandb/analytics/sentry.py +1 -0
  3. wandb/apis/importers/base.py +20 -5
  4. wandb/apis/importers/mlflow.py +7 -1
  5. wandb/apis/internal.py +12 -0
  6. wandb/apis/public.py +247 -1387
  7. wandb/apis/reports/_panels.py +58 -35
  8. wandb/beta/workflows.py +6 -7
  9. wandb/cli/cli.py +130 -60
  10. wandb/data_types.py +3 -1
  11. wandb/filesync/dir_watcher.py +21 -27
  12. wandb/filesync/step_checksum.py +8 -8
  13. wandb/filesync/step_prepare.py +23 -10
  14. wandb/filesync/step_upload.py +13 -13
  15. wandb/filesync/upload_job.py +4 -8
  16. wandb/integration/cohere/__init__.py +3 -0
  17. wandb/integration/cohere/cohere.py +21 -0
  18. wandb/integration/cohere/resolver.py +347 -0
  19. wandb/integration/gym/__init__.py +4 -6
  20. wandb/integration/huggingface/__init__.py +3 -0
  21. wandb/integration/huggingface/huggingface.py +18 -0
  22. wandb/integration/huggingface/resolver.py +213 -0
  23. wandb/integration/langchain/wandb_tracer.py +16 -179
  24. wandb/integration/openai/__init__.py +1 -3
  25. wandb/integration/openai/openai.py +11 -143
  26. wandb/integration/openai/resolver.py +111 -38
  27. wandb/integration/sagemaker/config.py +2 -2
  28. wandb/integration/tensorboard/log.py +4 -4
  29. wandb/old/settings.py +24 -7
  30. wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
  31. wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
  32. wandb/proto/wandb_deprecated.py +3 -1
  33. wandb/sdk/__init__.py +1 -1
  34. wandb/sdk/artifacts/__init__.py +0 -0
  35. wandb/sdk/artifacts/artifact.py +2101 -0
  36. wandb/sdk/artifacts/artifact_download_logger.py +42 -0
  37. wandb/sdk/artifacts/artifact_manifest.py +67 -0
  38. wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
  39. wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
  40. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
  41. wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
  42. wandb/sdk/artifacts/artifact_state.py +10 -0
  43. wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
  44. wandb/sdk/artifacts/exceptions.py +55 -0
  45. wandb/sdk/artifacts/storage_handler.py +59 -0
  46. wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
  47. wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
  48. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
  49. wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
  50. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
  51. wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
  52. wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
  53. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
  54. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
  55. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
  56. wandb/sdk/artifacts/storage_layout.py +6 -0
  57. wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
  58. wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
  59. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
  60. wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
  61. wandb/sdk/data_types/_dtypes.py +7 -12
  62. wandb/sdk/data_types/base_types/json_metadata.py +3 -2
  63. wandb/sdk/data_types/base_types/media.py +8 -8
  64. wandb/sdk/data_types/base_types/wb_value.py +12 -13
  65. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
  66. wandb/sdk/data_types/helper_types/classes.py +6 -8
  67. wandb/sdk/data_types/helper_types/image_mask.py +5 -6
  68. wandb/sdk/data_types/histogram.py +4 -3
  69. wandb/sdk/data_types/html.py +3 -4
  70. wandb/sdk/data_types/image.py +11 -9
  71. wandb/sdk/data_types/molecule.py +5 -3
  72. wandb/sdk/data_types/object_3d.py +7 -5
  73. wandb/sdk/data_types/plotly.py +3 -2
  74. wandb/sdk/data_types/saved_model.py +11 -11
  75. wandb/sdk/data_types/trace_tree.py +5 -4
  76. wandb/sdk/data_types/utils.py +3 -5
  77. wandb/sdk/data_types/video.py +5 -4
  78. wandb/sdk/integration_utils/auto_logging.py +215 -0
  79. wandb/sdk/interface/interface.py +15 -15
  80. wandb/sdk/internal/file_pusher.py +8 -16
  81. wandb/sdk/internal/file_stream.py +5 -11
  82. wandb/sdk/internal/handler.py +13 -1
  83. wandb/sdk/internal/internal_api.py +287 -13
  84. wandb/sdk/internal/job_builder.py +119 -30
  85. wandb/sdk/internal/sender.py +6 -26
  86. wandb/sdk/internal/settings_static.py +2 -0
  87. wandb/sdk/internal/system/assets/__init__.py +2 -0
  88. wandb/sdk/internal/system/assets/gpu.py +42 -0
  89. wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
  90. wandb/sdk/internal/system/env_probe_helpers.py +13 -0
  91. wandb/sdk/internal/system/system_info.py +3 -3
  92. wandb/sdk/internal/tb_watcher.py +32 -22
  93. wandb/sdk/internal/thread_local_settings.py +18 -0
  94. wandb/sdk/launch/_project_spec.py +57 -11
  95. wandb/sdk/launch/agent/agent.py +147 -65
  96. wandb/sdk/launch/agent/job_status_tracker.py +34 -0
  97. wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
  98. wandb/sdk/launch/builder/abstract.py +5 -1
  99. wandb/sdk/launch/builder/build.py +21 -18
  100. wandb/sdk/launch/builder/docker_builder.py +10 -4
  101. wandb/sdk/launch/builder/kaniko_builder.py +113 -23
  102. wandb/sdk/launch/builder/noop.py +6 -3
  103. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
  104. wandb/sdk/launch/environment/aws_environment.py +3 -2
  105. wandb/sdk/launch/environment/azure_environment.py +124 -0
  106. wandb/sdk/launch/environment/gcp_environment.py +2 -4
  107. wandb/sdk/launch/environment/local_environment.py +1 -1
  108. wandb/sdk/launch/errors.py +19 -0
  109. wandb/sdk/launch/github_reference.py +32 -19
  110. wandb/sdk/launch/launch.py +3 -8
  111. wandb/sdk/launch/launch_add.py +6 -2
  112. wandb/sdk/launch/loader.py +21 -2
  113. wandb/sdk/launch/registry/azure_container_registry.py +132 -0
  114. wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
  115. wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
  116. wandb/sdk/launch/registry/local_registry.py +2 -1
  117. wandb/sdk/launch/runner/abstract.py +24 -3
  118. wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
  119. wandb/sdk/launch/runner/local_container.py +103 -51
  120. wandb/sdk/launch/runner/local_process.py +1 -1
  121. wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
  122. wandb/sdk/launch/runner/vertex_runner.py +10 -5
  123. wandb/sdk/launch/sweeps/__init__.py +7 -9
  124. wandb/sdk/launch/sweeps/scheduler.py +307 -77
  125. wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
  126. wandb/sdk/launch/sweeps/utils.py +82 -35
  127. wandb/sdk/launch/utils.py +89 -75
  128. wandb/sdk/lib/_settings_toposort_generated.py +7 -0
  129. wandb/sdk/lib/capped_dict.py +26 -0
  130. wandb/sdk/lib/{git.py → gitlib.py} +76 -59
  131. wandb/sdk/lib/hashutil.py +12 -4
  132. wandb/sdk/lib/paths.py +96 -8
  133. wandb/sdk/lib/sock_client.py +2 -2
  134. wandb/sdk/lib/timer.py +1 -0
  135. wandb/sdk/service/server.py +22 -9
  136. wandb/sdk/service/server_sock.py +1 -1
  137. wandb/sdk/service/service.py +27 -8
  138. wandb/sdk/verify/verify.py +4 -7
  139. wandb/sdk/wandb_config.py +2 -6
  140. wandb/sdk/wandb_init.py +57 -53
  141. wandb/sdk/wandb_require.py +7 -0
  142. wandb/sdk/wandb_run.py +61 -223
  143. wandb/sdk/wandb_settings.py +28 -4
  144. wandb/testing/relay.py +15 -2
  145. wandb/util.py +74 -36
  146. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
  147. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
  148. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
  149. wandb/integration/langchain/util.py +0 -191
  150. wandb/sdk/interface/artifacts/__init__.py +0 -33
  151. wandb/sdk/interface/artifacts/artifact.py +0 -615
  152. wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
  153. wandb/sdk/wandb_artifacts.py +0 -2226
  154. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
  155. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
  156. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,61 @@
1
+ """S3 bucket storage policy."""
2
+ from typing import TYPE_CHECKING, Dict, Optional, Sequence, Union
3
+
4
+ from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
5
+ from wandb.sdk.artifacts.storage_handlers.local_file_handler import LocalFileHandler
6
+ from wandb.sdk.artifacts.storage_handlers.multi_handler import MultiHandler
7
+ from wandb.sdk.artifacts.storage_handlers.s3_handler import S3Handler
8
+ from wandb.sdk.artifacts.storage_handlers.tracking_handler import TrackingHandler
9
+ from wandb.sdk.artifacts.storage_policy import StoragePolicy
10
+
11
+ if TYPE_CHECKING:
12
+ from wandb.sdk.artifacts.artifact import Artifact
13
+ from wandb.sdk.lib.paths import FilePathStr, URIStr
14
+
15
+
16
# Don't use this yet!
class __S3BucketPolicy(StoragePolicy):  # noqa: N801
    """Experimental storage policy backed by a single S3 bucket.

    Loading and storing paths is delegated to a ``MultiHandler`` that tries an
    ``S3Handler`` for the configured bucket and a ``LocalFileHandler``, and falls
    back to a ``TrackingHandler`` for anything neither can handle.

    NOTE(review): this policy only exposes ``load_path``/``store_path``; the
    ``WandbStoragePolicy`` in this package exposes ``load_reference``/
    ``store_reference`` instead. Confirm the intended interface before use.
    """

    @classmethod
    def name(cls) -> str:
        """Return the name identifying this storage policy."""
        return "wandb-s3-bucket-policy-v1"

    @classmethod
    def from_config(cls, config: Dict[str, str]) -> "__S3BucketPolicy":
        """Build the policy from a config dict (the inverse of ``config()``).

        Raises:
            ValueError: if ``config`` has no ``"bucket"`` key.
        """
        if "bucket" not in config:
            raise ValueError("Bucket name not found in config")
        return cls(config["bucket"])

    def __init__(self, bucket: str) -> None:
        """Create a policy that targets ``bucket``."""
        self._bucket = bucket
        s3 = S3Handler(bucket)
        local = LocalFileHandler()

        self._handler = MultiHandler(
            handlers=[
                s3,
                local,
            ],
            default_handler=TrackingHandler(),
        )

    def config(self) -> Dict[str, str]:
        """Return the serializable config for this policy."""
        return {"bucket": self._bucket}

    # URIStr/FilePathStr are only imported under TYPE_CHECKING, so annotations
    # that mention them must be strings; otherwise evaluating the signature at
    # class-definition time raises NameError.
    def load_path(
        self,
        manifest_entry: ArtifactManifestEntry,
        local: bool = False,
    ) -> "Union[URIStr, FilePathStr]":
        """Resolve ``manifest_entry`` to a path or URI via the matching handler."""
        return self._handler.load_path(manifest_entry, local=local)

    def store_path(
        self,
        artifact: "Artifact",
        path: "Union[URIStr, FilePathStr]",
        name: Optional[str] = None,
        checksum: bool = True,
        max_objects: Optional[int] = None,
    ) -> Sequence[ArtifactManifestEntry]:
        """Create manifest entries for ``path`` via the matching handler."""
        return self._handler.store_path(
            artifact, path, name=name, checksum=checksum, max_objects=max_objects
        )
@@ -0,0 +1,386 @@
1
+ """WandB storage policy."""
2
+ import hashlib
3
+ import math
4
+ import shutil
5
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
6
+ from urllib.parse import quote
7
+
8
+ import requests
9
+ import urllib3
10
+
11
+ from wandb.apis import InternalApi
12
+ from wandb.errors.term import termwarn
13
+ from wandb.sdk.artifacts.artifacts_cache import ArtifactsCache, get_artifacts_cache
14
+ from wandb.sdk.artifacts.storage_handlers.azure_handler import AzureHandler
15
+ from wandb.sdk.artifacts.storage_handlers.gcs_handler import GCSHandler
16
+ from wandb.sdk.artifacts.storage_handlers.http_handler import HTTPHandler
17
+ from wandb.sdk.artifacts.storage_handlers.local_file_handler import LocalFileHandler
18
+ from wandb.sdk.artifacts.storage_handlers.multi_handler import MultiHandler
19
+ from wandb.sdk.artifacts.storage_handlers.s3_handler import S3Handler
20
+ from wandb.sdk.artifacts.storage_handlers.tracking_handler import TrackingHandler
21
+ from wandb.sdk.artifacts.storage_handlers.wb_artifact_handler import WBArtifactHandler
22
+ from wandb.sdk.artifacts.storage_handlers.wb_local_artifact_handler import (
23
+ WBLocalArtifactHandler,
24
+ )
25
+ from wandb.sdk.artifacts.storage_layout import StorageLayout
26
+ from wandb.sdk.artifacts.storage_policy import StoragePolicy
27
+ from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
28
+ from wandb.sdk.lib.hashutil import B64MD5, b64_to_hex_id, hex_to_b64_id
29
+ from wandb.sdk.lib.paths import FilePathStr, URIStr
30
+
31
+ if TYPE_CHECKING:
32
+ from wandb.filesync.step_prepare import StepPrepare
33
+ from wandb.sdk.artifacts.artifact import Artifact
34
+ from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
35
+ from wandb.sdk.internal import progress
36
+
37
# Retry back-off: 0, 2, 4, 8, 16, 32, 64 s, then 120 s for every remaining
# attempt -- 16 retries spanning 20 min 6 s in total.
_REQUEST_RETRY_STRATEGY = urllib3.util.retry.Retry(
    total=16,
    backoff_factor=1,
    status_forcelist=(308, 408, 409, 429, 500, 502, 503, 504),
)
# Connection-pool sizing for the HTTPAdapter mounted on the policy's session.
_REQUEST_POOL_CONNECTIONS = 64
_REQUEST_POOL_MAXSIZE = 64

# Ceiling on the number of parts in an S3 multipart upload, chosen so that no
# additional requests are needed to handle extra parts.
S3_MAX_PART_NUMBERS = 1000
# Multipart upload is only used for files whose size lies in [2 GiB, 5 TiB].
S3_MIN_MULTI_UPLOAD_SIZE = 2 << 30  # 2 GiB
S3_MAX_MULTI_UPLOAD_SIZE = 5 << 40  # 5 TiB
52
+
53
class WandbStoragePolicy(StoragePolicy):
    """Default storage policy for W&B artifacts.

    Artifact files are uploaded to and downloaded from W&B storage, and cached
    locally in the artifacts cache keyed by MD5 digest. References are delegated
    to a ``MultiHandler`` over S3, GCS, Azure, HTTP, HTTPS, W&B-artifact,
    local-W&B-artifact and local-file handlers, falling back to a
    ``TrackingHandler``.
    """

    @classmethod
    def name(cls) -> str:
        """Return the name identifying this storage policy."""
        return "wandb-storage-policy-v1"

    @classmethod
    def from_config(cls, config: Dict) -> "WandbStoragePolicy":
        """Build a policy from a config dict (the inverse of ``config()``)."""
        return cls(config=config)

    def __init__(
        self,
        config: Optional[Dict] = None,
        cache: Optional[ArtifactsCache] = None,
        api: Optional[InternalApi] = None,
    ) -> None:
        """Create the policy.

        Args:
            config: Policy config. ``storageLayout`` and ``storageRegion`` are
                read when building download URLs. Defaults to ``{}``.
            cache: Artifacts cache. Defaults to ``get_artifacts_cache()``.
            api: Internal API client. Defaults to a new ``InternalApi()``.
        """
        self._cache = cache or get_artifacts_cache()
        self._config = config or {}
        # One pooled, retrying session shared by file downloads and the
        # HTTP/HTTPS reference handlers.
        self._session = requests.Session()
        adapter = requests.adapters.HTTPAdapter(
            max_retries=_REQUEST_RETRY_STRATEGY,
            pool_connections=_REQUEST_POOL_CONNECTIONS,
            pool_maxsize=_REQUEST_POOL_MAXSIZE,
        )
        self._session.mount("http://", adapter)
        self._session.mount("https://", adapter)

        s3 = S3Handler()
        gcs = GCSHandler()
        azure = AzureHandler()
        http = HTTPHandler(self._session)
        https = HTTPHandler(self._session, scheme="https")
        artifact = WBArtifactHandler()
        local_artifact = WBLocalArtifactHandler()
        file_handler = LocalFileHandler()

        self._api = api or InternalApi()
        self._handler = MultiHandler(
            handlers=[
                s3,
                gcs,
                azure,
                http,
                https,
                artifact,
                local_artifact,
                file_handler,
            ],
            default_handler=TrackingHandler(),
        )

    def config(self) -> Dict:
        """Return the policy config dict."""
        return self._config

    def load_file(
        self,
        artifact: "Artifact",
        manifest_entry: "ArtifactManifestEntry",
    ) -> FilePathStr:
        """Return a local path to the file for ``manifest_entry``, downloading it if needed.

        The artifacts cache is checked first by MD5 digest. On a miss, the file
        is streamed from ``manifest_entry._download_url`` when one is set. If
        that request returns an HTTP error, the URL is cleared and the file is
        fetched from the W&B file URL for ``artifact.entity`` instead. That
        request uses API-key auth unless thread-local cookies are set.

        Raises:
            requests.HTTPError: if the fallback download fails.
        """
        path, hit, cache_open = self._cache.check_md5_obj_path(
            B64MD5(manifest_entry.digest),  # TODO(spencerpearson): unsafe cast
            manifest_entry.size if manifest_entry.size is not None else 0,
        )
        if hit:
            return path

        if manifest_entry._download_url is not None:
            response = self._session.get(manifest_entry._download_url, stream=True)
            try:
                response.raise_for_status()
            except requests.HTTPError:
                # Signed URL might have expired, fall back to fetching it one by
                # one. Close the unread streamed response so its pooled
                # connection is released.
                response.close()
                manifest_entry._download_url = None
        if manifest_entry._download_url is None:
            auth = None
            if not _thread_local_api_settings.cookies:
                auth = ("api", self._api.api_key)
            response = self._session.get(
                self._file_url(self._api, artifact.entity, manifest_entry),
                auth=auth,
                cookies=_thread_local_api_settings.cookies,
                headers=_thread_local_api_settings.headers,
                stream=True,
            )
            response.raise_for_status()

        try:
            with cache_open(mode="wb") as file:
                for data in response.iter_content(chunk_size=16 * 1024):
                    file.write(data)
        finally:
            # Release the streamed connection even if writing to the cache fails.
            response.close()
        return path

    def store_reference(
        self,
        artifact: "Artifact",
        path: Union[URIStr, FilePathStr],
        name: Optional[str] = None,
        checksum: bool = True,
        max_objects: Optional[int] = None,
    ) -> Sequence["ArtifactManifestEntry"]:
        """Create manifest entries for the reference ``path`` via the matching handler."""
        return self._handler.store_path(
            artifact, path, name=name, checksum=checksum, max_objects=max_objects
        )

    def load_reference(
        self,
        manifest_entry: "ArtifactManifestEntry",
        local: bool = False,
    ) -> Union[FilePathStr, URIStr]:
        """Resolve a reference manifest entry to a path or URI via the matching handler."""
        return self._handler.load_path(manifest_entry, local)

    def _file_url(
        self,
        api: InternalApi,
        entity_name: str,
        manifest_entry: "ArtifactManifestEntry",
    ) -> str:
        """Return the W&B download URL for ``manifest_entry``.

        The URL shape depends on the configured ``storageLayout`` (default V1).
        V1 addresses files by entity and MD5 hex digest. V2 additionally
        includes the ``storageRegion`` (default ``"default"``) and the URL-quoted
        birth artifact id (empty when unset).

        Raises:
            ValueError: if the configured storage layout is not recognized.
        """
        storage_layout = self._config.get("storageLayout", StorageLayout.V1)
        storage_region = self._config.get("storageRegion", "default")
        md5_hex = b64_to_hex_id(B64MD5(manifest_entry.digest))

        if storage_layout == StorageLayout.V1:
            return "{}/artifacts/{}/{}".format(
                api.settings("base_url"), entity_name, md5_hex
            )
        elif storage_layout == StorageLayout.V2:
            return "{}/artifactsV2/{}/{}/{}/{}".format(
                api.settings("base_url"),
                storage_region,
                entity_name,
                quote(
                    manifest_entry.birth_artifact_id
                    if manifest_entry.birth_artifact_id is not None
                    else ""
                ),
                md5_hex,
            )
        else:
            # ValueError is a subclass of Exception, so existing handlers still match.
            raise ValueError(f"unrecognized storage layout: {storage_layout}")

    def s3_multipart_file_upload(
        self,
        file_path: str,
        chunk_size: int,
        hex_digests: Dict[int, str],
        multipart_urls: Dict[int, str],
        extra_headers: Dict[str, str],
    ) -> List[Dict[str, Any]]:
        """Upload ``file_path`` in ``chunk_size`` parts to the given part URLs.

        Args:
            file_path: Local file to upload.
            chunk_size: Size in bytes of each part (the last may be shorter).
            hex_digests: Hex MD5 of each part, keyed by 1-based part number.
            multipart_urls: Upload URL for each part, keyed by 1-based part number.
            extra_headers: Headers from the prepare response. Only
                ``Content-Type`` is forwarded with each part.

        Returns:
            One ``{"partNumber": n, "hexMD5": <ETag response header>}`` dict per
            uploaded part, in part order.
        """
        etags = []
        part_number = 1

        with open(file_path, "rb") as f:
            while True:
                data = f.read(chunk_size)
                if not data:
                    break
                md5_b64_str = str(hex_to_b64_id(hex_digests[part_number]))
                upload_resp = self._api.upload_multipart_file_chunk_retry(
                    multipart_urls[part_number],
                    data,
                    extra_headers={
                        "content-md5": md5_b64_str,
                        "content-length": str(len(data)),
                        "content-type": extra_headers.get("Content-Type"),
                    },
                )
                etags.append(
                    {"partNumber": part_number, "hexMD5": upload_resp.headers["ETag"]}
                )
                part_number += 1
        return etags

    def default_file_upload(
        self,
        upload_url: str,
        file_path: str,
        extra_headers: Dict[str, Any],
        progress_callback: Optional["progress.ProgressFn"] = None,
    ) -> None:
        """Upload a whole file to ``upload_url`` in a single (retried) request.

        Caching is not done here; callers write the cache after a successful upload.
        """
        with open(file_path, "rb") as file:
            # This fails if we don't send the first byte before the signed URL expires.
            self._api.upload_file_retry(
                upload_url,
                file,
                progress_callback,
                extra_headers=extra_headers,
            )

    def calc_chunk_size(self, file_size: int) -> int:
        """Return the multipart chunk size in bytes for a file of ``file_size`` bytes.

        Uses 100 MiB chunks unless that would need more than
        ``S3_MAX_PART_NUMBERS`` parts. In that case the file is split evenly
        into that many parts.

        NOTE(review): S3 limits a single part to 5 GiB. For files above
        ``S3_MAX_PART_NUMBERS`` * 5 GiB (still below the 5 TiB multipart ceiling)
        the chunk computed here exceeds that limit. Confirm against the backend.
        """
        default_chunk_size = 100 * 1024**2
        if default_chunk_size * S3_MAX_PART_NUMBERS < file_size:
            return math.ceil(file_size / S3_MAX_PART_NUMBERS)
        return default_chunk_size

    @staticmethod
    def _parse_upload_headers(
        upload_headers: Optional[Sequence[str]],
    ) -> Dict[str, str]:
        """Turn ``"Name:value"`` strings from a prepare response into a header dict.

        Only the first ``:`` separates name from value, and the value is kept
        verbatim. A header with no ``:`` raises ``IndexError``. ``None`` yields ``{}``.
        """
        headers: Dict[str, str] = {}
        for header in upload_headers or ():
            parts = header.split(":", 1)
            headers[parts[0]] = parts[1]
        return headers

    @staticmethod
    def _multipart_part_digests(file_path: str, chunk_size: int) -> Dict[int, str]:
        """Return the hex MD5 of each ``chunk_size`` slice of a file, keyed by 1-based part number."""
        hex_digests: Dict[int, str] = {}
        part_number = 1
        with open(file_path, "rb") as f:
            while True:
                data = f.read(chunk_size)
                if not data:
                    break
                hex_digests[part_number] = hashlib.md5(data).hexdigest()
                part_number += 1
        return hex_digests

    def store_file_sync(
        self,
        artifact_id: str,
        artifact_manifest_id: str,
        entry: "ArtifactManifestEntry",
        preparer: "StepPrepare",
        progress_callback: Optional["progress.ProgressFn"] = None,
    ) -> bool:
        """Upload a file to the artifact store.

        Local files with a size in ``[S3_MIN_MULTI_UPLOAD_SIZE,
        S3_MAX_MULTI_UPLOAD_SIZE]`` are hashed per part so the server can hand
        out multipart upload URLs. Other files use a single upload URL. After a
        successful upload, the file is copied into the local artifacts cache.

        Returns:
            True if the file was a duplicate (did not need to be uploaded),
            False if it needed to be uploaded or was a reference (nothing to dedupe).

        Raises:
            ValueError: if the server offers no single upload URL and no
                multipart URLs.
        """
        file_size = entry.size if entry.size is not None else 0
        chunk_size = self.calc_chunk_size(file_size)
        upload_parts: List[Dict[str, Any]] = []
        hex_digests: Dict[int, str] = {}
        file_path = entry.local_path if entry.local_path is not None else ""
        # Logic for AWS s3 multipart upload.
        # Only chunk files if larger than 2 GiB. Currently can only support up to
        # 5TiB. References have no local file to hash, so they are skipped here
        # (previously this tried to open "" and crashed for large references).
        if (
            entry.local_path is not None
            and S3_MIN_MULTI_UPLOAD_SIZE <= file_size <= S3_MAX_MULTI_UPLOAD_SIZE
        ):
            hex_digests = self._multipart_part_digests(file_path, chunk_size)
            upload_parts = [
                {"hexMD5": hex_digest, "partNumber": part_number}
                for part_number, hex_digest in hex_digests.items()
            ]

        resp = preparer.prepare_sync(
            {
                "artifactID": artifact_id,
                "artifactManifestID": artifact_manifest_id,
                "name": entry.path,
                "md5": entry.digest,
                "uploadPartsInput": upload_parts,
            }
        ).get()

        entry.birth_artifact_id = resp.birth_artifact_id

        multipart_urls = resp.multipart_upload_urls
        if resp.upload_url is None:
            return True
        if entry.local_path is None:
            return False

        extra_headers = self._parse_upload_headers(resp.upload_headers)

        # This multipart upload isn't available, do a regular single url upload
        if multipart_urls is None and resp.upload_url:
            self.default_file_upload(
                resp.upload_url, file_path, extra_headers, progress_callback
            )
        else:
            if multipart_urls is None:
                raise ValueError(f"No multipart urls to upload for file: {file_path}")
            # Upload files using s3 multipart upload urls
            etags = self.s3_multipart_file_upload(
                file_path,
                chunk_size,
                hex_digests,
                multipart_urls,
                extra_headers,
            )
            self._api.complete_multipart_upload_artifact(
                artifact_id, resp.storage_path, etags, resp.upload_id
            )
        self._write_cache(entry)

        return False

    async def store_file_async(
        self,
        artifact_id: str,
        artifact_manifest_id: str,
        entry: "ArtifactManifestEntry",
        preparer: "StepPrepare",
        progress_callback: Optional["progress.ProgressFn"] = None,
    ) -> bool:
        """Async equivalent to `store_file_sync`.

        Unlike the sync path, this never requests multipart upload URLs. The
        file is always uploaded with a single request.

        Returns:
            True if the file was a duplicate, False if it was uploaded or was a
            reference.
        """
        resp = await preparer.prepare_async(
            {
                "artifactID": artifact_id,
                "artifactManifestID": artifact_manifest_id,
                "name": entry.path,
                "md5": entry.digest,
            }
        )

        entry.birth_artifact_id = resp.birth_artifact_id
        if resp.upload_url is None:
            return True
        if entry.local_path is None:
            return False

        with open(entry.local_path, "rb") as file:
            # This fails if we don't send the first byte before the signed URL expires.
            await self._api.upload_file_retry_async(
                resp.upload_url,
                file,
                progress_callback,
                extra_headers=self._parse_upload_headers(resp.upload_headers),
            )

        self._write_cache(entry)

        return False

    def _write_cache(self, entry: "ArtifactManifestEntry") -> None:
        """Copy an uploaded local file into the artifacts cache (best effort).

        Entries without a local path are ignored. Copy failures are reported
        with ``termwarn`` and otherwise ignored.
        """
        if entry.local_path is None:
            return

        # Cache upon successful upload.
        _, hit, cache_open = self._cache.check_md5_obj_path(
            B64MD5(entry.digest),
            entry.size if entry.size is not None else 0,
        )
        if not hit:
            try:
                with cache_open() as f:
                    shutil.copyfile(entry.local_path, f.name)
            except OSError as e:
                termwarn(f"Failed to cache {entry.local_path}, ignoring {e}")
@@ -1,20 +1,15 @@
1
+ """Storage policy."""
1
2
  from typing import TYPE_CHECKING, Dict, Optional, Sequence, Type, Union
2
3
 
3
4
  from wandb.sdk.lib.paths import FilePathStr, URIStr
4
5
 
5
6
  if TYPE_CHECKING:
6
- from urllib.parse import ParseResult
7
-
8
7
  from wandb.filesync.step_prepare import StepPrepare
9
- from wandb.sdk.interface.artifacts import Artifact, ArtifactManifestEntry
8
+ from wandb.sdk.artifacts.artifact import Artifact
9
+ from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
10
10
  from wandb.sdk.internal.progress import ProgressFn
11
11
 
12
12
 
13
- class StorageLayout:
14
- V1 = "V1"
15
- V2 = "V2"
16
-
17
-
18
13
  class StoragePolicy:
19
14
  @classmethod
20
15
  def lookup_by_name(cls, name: str) -> Optional[Type["StoragePolicy"]]:
@@ -36,7 +31,7 @@ class StoragePolicy:
36
31
 
37
32
  def load_file(
38
33
  self, artifact: "Artifact", manifest_entry: "ArtifactManifestEntry"
39
- ) -> str:
34
+ ) -> FilePathStr:
40
35
  raise NotImplementedError
41
36
 
42
37
  def store_file_sync(
@@ -74,52 +69,5 @@ class StoragePolicy:
74
69
  self,
75
70
  manifest_entry: "ArtifactManifestEntry",
76
71
  local: bool = False,
77
- ) -> str:
78
- raise NotImplementedError
79
-
80
-
81
- class StorageHandler:
82
- def can_handle(self, parsed_url: "ParseResult") -> bool:
83
- """Checks whether this handler can handle the given url.
84
-
85
- Returns:
86
- Whether this handler can handle the given url.
87
- """
88
- raise NotImplementedError
89
-
90
- def load_path(
91
- self,
92
- manifest_entry: "ArtifactManifestEntry",
93
- local: bool = False,
94
- ) -> Union[URIStr, FilePathStr]:
95
- """Load a file or directory given the corresponding index entry.
96
-
97
- Args:
98
- manifest_entry: The index entry to load
99
- local: Whether to load the file locally or not
100
-
101
- Returns:
102
- A path to the file represented by `index_entry`
103
- """
104
- raise NotImplementedError
105
-
106
- def store_path(
107
- self,
108
- artifact: "Artifact",
109
- path: Union[URIStr, FilePathStr],
110
- name: Optional[str] = None,
111
- checksum: bool = True,
112
- max_objects: Optional[int] = None,
113
- ) -> Sequence["ArtifactManifestEntry"]:
114
- """Store the file or directory at the given path to the specified artifact.
115
-
116
- Args:
117
- path: The path to store
118
- name: If specified, the logical name that should map to `path`
119
- checksum: Whether to compute the checksum of the file
120
- max_objects: The maximum number of objects to store
121
-
122
- Returns:
123
- A list of manifest entries to store within the artifact
124
- """
72
+ ) -> Union[FilePathStr, URIStr]:
125
73
  raise NotImplementedError
@@ -13,8 +13,7 @@ from wandb.util import (
13
13
  np = get_module("numpy") # intentionally not required
14
14
 
15
15
  if t.TYPE_CHECKING:
16
- from wandb.apis.public import Artifact as DownloadedArtifact
17
- from wandb.sdk.wandb_artifacts import Artifact as ArtifactInCreation
16
+ from wandb.sdk.artifacts.artifact import Artifact
18
17
 
19
18
  _TYPES_STRIPPED = not (sys.version_info.major == 3 and sys.version_info.minor >= 6)
20
19
  if not _TYPES_STRIPPED:
@@ -77,7 +76,7 @@ class TypeRegistry:
77
76
 
78
77
  @staticmethod
79
78
  def type_from_dict(
80
- json_dict: t.Dict[str, t.Any], artifact: t.Optional["DownloadedArtifact"] = None
79
+ json_dict: t.Dict[str, t.Any], artifact: t.Optional["Artifact"] = None
81
80
  ) -> "Type":
82
81
  wb_type = json_dict.get("wb_type")
83
82
  if wb_type is None:
@@ -135,7 +134,7 @@ class TypeRegistry:
135
134
 
136
135
  def _params_obj_to_json_obj(
137
136
  params_obj: t.Any,
138
- artifact: t.Optional["ArtifactInCreation"] = None,
137
+ artifact: t.Optional["Artifact"] = None,
139
138
  ) -> t.Any:
140
139
  """Helper method."""
141
140
  if params_obj.__class__ == dict:
@@ -152,7 +151,7 @@ def _params_obj_to_json_obj(
152
151
 
153
152
 
154
153
  def _json_obj_to_params_obj(
155
- json_obj: t.Any, artifact: t.Optional["DownloadedArtifact"] = None
154
+ json_obj: t.Any, artifact: t.Optional["Artifact"] = None
156
155
  ) -> t.Any:
157
156
  """Helper method."""
158
157
  if json_obj.__class__ == dict:
@@ -222,9 +221,7 @@ class Type:
222
221
  else:
223
222
  return InvalidType()
224
223
 
225
- def to_json(
226
- self, artifact: t.Optional["ArtifactInCreation"] = None
227
- ) -> t.Dict[str, t.Any]:
224
+ def to_json(self, artifact: t.Optional["Artifact"] = None) -> t.Dict[str, t.Any]:
228
225
  """Generate a jsonable dictionary serialization the type.
229
226
 
230
227
  If overridden by subclass, ensure that `from_json` is equivalently overridden.
@@ -249,7 +246,7 @@ class Type:
249
246
  def from_json(
250
247
  cls,
251
248
  json_dict: t.Dict[str, t.Any],
252
- artifact: t.Optional["DownloadedArtifact"] = None,
249
+ artifact: t.Optional["Artifact"] = None,
253
250
  ) -> "Type":
254
251
  """Construct a new instance of the type using a JSON dictionary.
255
252
 
@@ -756,9 +753,7 @@ class NDArrayType(Type):
756
753
 
757
754
  return InvalidType()
758
755
 
759
- def to_json(
760
- self, artifact: t.Optional["ArtifactInCreation"] = None
761
- ) -> t.Dict[str, t.Any]:
756
+ def to_json(self, artifact: t.Optional["Artifact"] = None) -> t.Dict[str, t.Any]:
762
757
  # custom override to support serialization path outside of params internal dict
763
758
  res = {
764
759
  "wb_type": self.name,
@@ -9,7 +9,8 @@ from .._private import MEDIA_TMP
9
9
  from .media import Media
10
10
 
11
11
  if TYPE_CHECKING: # pragma: no cover
12
- from ...wandb_artifacts import Artifact as LocalArtifact
12
+ from wandb.sdk.artifacts.artifact import Artifact
13
+
13
14
  from ...wandb_run import Run as LocalRun
14
15
 
15
16
 
@@ -39,7 +40,7 @@ class JSONMetadata(Media):
39
40
  def get_media_subdir(cls: Type["JSONMetadata"]) -> str:
40
41
  return os.path.join("media", "metadata", cls.type_name())
41
42
 
42
- def to_json(self, run_or_artifact: Union["LocalRun", "LocalArtifact"]) -> dict:
43
+ def to_json(self, run_or_artifact: Union["LocalRun", "Artifact"]) -> dict:
43
44
  json_dict = super().to_json(run_or_artifact)
44
45
  json_dict["_type"] = self.type_name()
45
46
 
@@ -9,15 +9,15 @@ import wandb
9
9
  from wandb import util
10
10
  from wandb._globals import _datatypes_callback
11
11
  from wandb.sdk.lib import filesystem
12
+ from wandb.sdk.lib.paths import LogicalPath
12
13
 
13
14
  from .wb_value import WBValue
14
15
 
15
16
  if TYPE_CHECKING: # pragma: no cover
16
- import numpy as np # type: ignore
17
+ import numpy as np
17
18
 
18
- from wandb.apis.public import Artifact as PublicArtifact
19
+ from wandb.sdk.artifacts.artifact import Artifact
19
20
 
20
- from ...wandb_artifacts import Artifact as LocalArtifact
21
21
  from ...wandb_run import Run as LocalRun
22
22
 
23
23
 
@@ -143,7 +143,7 @@ class Media(WBValue):
143
143
  self._path = new_path
144
144
  _datatypes_callback(media_path)
145
145
 
146
- def to_json(self, run: Union["LocalRun", "LocalArtifact"]) -> dict:
146
+ def to_json(self, run: Union["LocalRun", "Artifact"]) -> dict:
147
147
  """Serialize the object into a JSON blob.
148
148
 
149
149
  Uses run or artifact to store additional data. If `run_or_artifact` is a
@@ -193,11 +193,11 @@ class Media(WBValue):
193
193
  # by definition is_bound, but are needed for
194
194
  # mypy to understand that these are strings below.
195
195
  assert isinstance(self._path, str)
196
- json_obj["path"] = util.to_forward_slash_path(
196
+ json_obj["path"] = LogicalPath(
197
197
  os.path.relpath(self._path, self._run.dir)
198
198
  )
199
199
 
200
- elif isinstance(run, wandb.wandb_sdk.wandb_artifacts.Artifact):
200
+ elif isinstance(run, wandb.Artifact):
201
201
  if self.file_is_set():
202
202
  # The following two assertions are guaranteed to pass
203
203
  # by definition of the call above, but are needed for
@@ -253,7 +253,7 @@ class Media(WBValue):
253
253
 
254
254
  @classmethod
255
255
  def from_json(
256
- cls: Type["Media"], json_obj: dict, source_artifact: "PublicArtifact"
256
+ cls: Type["Media"], json_obj: dict, source_artifact: "Artifact"
257
257
  ) -> "Media":
258
258
  """Likely will need to override for any more complicated media objects."""
259
259
  return cls(source_artifact.get_path(json_obj["path"]).download())
@@ -315,4 +315,4 @@ def _numpy_arrays_to_lists(
315
315
  # Protects against logging non serializable objects
316
316
  elif isinstance(payload, Media):
317
317
  return str(payload.__class__.__name__)
318
- return payload
318
+ return payload # type: ignore