wandb 0.15.3-py3-none-any.whl → 0.15.5-py3-none-any.whl

Files changed (156)
  1. wandb/__init__.py +1 -1
  2. wandb/analytics/sentry.py +1 -0
  3. wandb/apis/importers/base.py +20 -5
  4. wandb/apis/importers/mlflow.py +7 -1
  5. wandb/apis/internal.py +12 -0
  6. wandb/apis/public.py +247 -1387
  7. wandb/apis/reports/_panels.py +58 -35
  8. wandb/beta/workflows.py +6 -7
  9. wandb/cli/cli.py +130 -60
  10. wandb/data_types.py +3 -1
  11. wandb/filesync/dir_watcher.py +21 -27
  12. wandb/filesync/step_checksum.py +8 -8
  13. wandb/filesync/step_prepare.py +23 -10
  14. wandb/filesync/step_upload.py +13 -13
  15. wandb/filesync/upload_job.py +4 -8
  16. wandb/integration/cohere/__init__.py +3 -0
  17. wandb/integration/cohere/cohere.py +21 -0
  18. wandb/integration/cohere/resolver.py +347 -0
  19. wandb/integration/gym/__init__.py +4 -6
  20. wandb/integration/huggingface/__init__.py +3 -0
  21. wandb/integration/huggingface/huggingface.py +18 -0
  22. wandb/integration/huggingface/resolver.py +213 -0
  23. wandb/integration/langchain/wandb_tracer.py +16 -179
  24. wandb/integration/openai/__init__.py +1 -3
  25. wandb/integration/openai/openai.py +11 -143
  26. wandb/integration/openai/resolver.py +111 -38
  27. wandb/integration/sagemaker/config.py +2 -2
  28. wandb/integration/tensorboard/log.py +4 -4
  29. wandb/old/settings.py +24 -7
  30. wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
  31. wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
  32. wandb/proto/wandb_deprecated.py +3 -1
  33. wandb/sdk/__init__.py +1 -1
  34. wandb/sdk/artifacts/__init__.py +0 -0
  35. wandb/sdk/artifacts/artifact.py +2101 -0
  36. wandb/sdk/artifacts/artifact_download_logger.py +42 -0
  37. wandb/sdk/artifacts/artifact_manifest.py +67 -0
  38. wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
  39. wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
  40. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
  41. wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
  42. wandb/sdk/artifacts/artifact_state.py +10 -0
  43. wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
  44. wandb/sdk/artifacts/exceptions.py +55 -0
  45. wandb/sdk/artifacts/storage_handler.py +59 -0
  46. wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
  47. wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
  48. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
  49. wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
  50. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
  51. wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
  52. wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
  53. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
  54. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
  55. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
  56. wandb/sdk/artifacts/storage_layout.py +6 -0
  57. wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
  58. wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
  59. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
  60. wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
  61. wandb/sdk/data_types/_dtypes.py +7 -12
  62. wandb/sdk/data_types/base_types/json_metadata.py +3 -2
  63. wandb/sdk/data_types/base_types/media.py +8 -8
  64. wandb/sdk/data_types/base_types/wb_value.py +12 -13
  65. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
  66. wandb/sdk/data_types/helper_types/classes.py +6 -8
  67. wandb/sdk/data_types/helper_types/image_mask.py +5 -6
  68. wandb/sdk/data_types/histogram.py +4 -3
  69. wandb/sdk/data_types/html.py +3 -4
  70. wandb/sdk/data_types/image.py +11 -9
  71. wandb/sdk/data_types/molecule.py +5 -3
  72. wandb/sdk/data_types/object_3d.py +7 -5
  73. wandb/sdk/data_types/plotly.py +3 -2
  74. wandb/sdk/data_types/saved_model.py +11 -11
  75. wandb/sdk/data_types/trace_tree.py +5 -4
  76. wandb/sdk/data_types/utils.py +3 -5
  77. wandb/sdk/data_types/video.py +5 -4
  78. wandb/sdk/integration_utils/auto_logging.py +215 -0
  79. wandb/sdk/interface/interface.py +15 -15
  80. wandb/sdk/internal/file_pusher.py +8 -16
  81. wandb/sdk/internal/file_stream.py +5 -11
  82. wandb/sdk/internal/handler.py +13 -1
  83. wandb/sdk/internal/internal_api.py +287 -13
  84. wandb/sdk/internal/job_builder.py +119 -30
  85. wandb/sdk/internal/sender.py +6 -26
  86. wandb/sdk/internal/settings_static.py +2 -0
  87. wandb/sdk/internal/system/assets/__init__.py +2 -0
  88. wandb/sdk/internal/system/assets/gpu.py +42 -0
  89. wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
  90. wandb/sdk/internal/system/env_probe_helpers.py +13 -0
  91. wandb/sdk/internal/system/system_info.py +3 -3
  92. wandb/sdk/internal/tb_watcher.py +32 -22
  93. wandb/sdk/internal/thread_local_settings.py +18 -0
  94. wandb/sdk/launch/_project_spec.py +57 -11
  95. wandb/sdk/launch/agent/agent.py +147 -65
  96. wandb/sdk/launch/agent/job_status_tracker.py +34 -0
  97. wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
  98. wandb/sdk/launch/builder/abstract.py +5 -1
  99. wandb/sdk/launch/builder/build.py +21 -18
  100. wandb/sdk/launch/builder/docker_builder.py +10 -4
  101. wandb/sdk/launch/builder/kaniko_builder.py +113 -23
  102. wandb/sdk/launch/builder/noop.py +6 -3
  103. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
  104. wandb/sdk/launch/environment/aws_environment.py +3 -2
  105. wandb/sdk/launch/environment/azure_environment.py +124 -0
  106. wandb/sdk/launch/environment/gcp_environment.py +2 -4
  107. wandb/sdk/launch/environment/local_environment.py +1 -1
  108. wandb/sdk/launch/errors.py +19 -0
  109. wandb/sdk/launch/github_reference.py +32 -19
  110. wandb/sdk/launch/launch.py +3 -8
  111. wandb/sdk/launch/launch_add.py +6 -2
  112. wandb/sdk/launch/loader.py +21 -2
  113. wandb/sdk/launch/registry/azure_container_registry.py +132 -0
  114. wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
  115. wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
  116. wandb/sdk/launch/registry/local_registry.py +2 -1
  117. wandb/sdk/launch/runner/abstract.py +24 -3
  118. wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
  119. wandb/sdk/launch/runner/local_container.py +103 -51
  120. wandb/sdk/launch/runner/local_process.py +1 -1
  121. wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
  122. wandb/sdk/launch/runner/vertex_runner.py +10 -5
  123. wandb/sdk/launch/sweeps/__init__.py +7 -9
  124. wandb/sdk/launch/sweeps/scheduler.py +307 -77
  125. wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
  126. wandb/sdk/launch/sweeps/utils.py +82 -35
  127. wandb/sdk/launch/utils.py +89 -75
  128. wandb/sdk/lib/_settings_toposort_generated.py +7 -0
  129. wandb/sdk/lib/capped_dict.py +26 -0
  130. wandb/sdk/lib/{git.py → gitlib.py} +76 -59
  131. wandb/sdk/lib/hashutil.py +12 -4
  132. wandb/sdk/lib/paths.py +96 -8
  133. wandb/sdk/lib/sock_client.py +2 -2
  134. wandb/sdk/lib/timer.py +1 -0
  135. wandb/sdk/service/server.py +22 -9
  136. wandb/sdk/service/server_sock.py +1 -1
  137. wandb/sdk/service/service.py +27 -8
  138. wandb/sdk/verify/verify.py +4 -7
  139. wandb/sdk/wandb_config.py +2 -6
  140. wandb/sdk/wandb_init.py +57 -53
  141. wandb/sdk/wandb_require.py +7 -0
  142. wandb/sdk/wandb_run.py +61 -223
  143. wandb/sdk/wandb_settings.py +28 -4
  144. wandb/testing/relay.py +15 -2
  145. wandb/util.py +74 -36
  146. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
  147. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
  148. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
  149. wandb/integration/langchain/util.py +0 -191
  150. wandb/sdk/interface/artifacts/__init__.py +0 -33
  151. wandb/sdk/interface/artifacts/artifact.py +0 -615
  152. wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
  153. wandb/sdk/wandb_artifacts.py +0 -2226
  154. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
  155. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
  156. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
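
The headline change in this release pair is an artifacts refactor: the monolithic `wandb/sdk/wandb_artifacts.py` (its deletion is shown in full below) and the `wandb/sdk/interface/artifacts/` package are removed, and their contents are split across the new `wandb/sdk/artifacts/` package, with one module per handler, policy, and manifest version. A minimal migration sketch, assuming the new import paths mirror the file names listed above (verify against the 0.15.5 source before relying on them):

```python
# Hypothetical sketch of the import-path migration implied by the file list
# above; the 0.15.5 module names are inferred from the new file paths.

# wandb 0.15.3: classes lived in one monolithic module / interface package.
# from wandb.sdk.wandb_artifacts import Artifact, WandbStoragePolicy
# from wandb.sdk.interface.artifacts import ArtifactNotLoggedError

# wandb 0.15.5: one module per concern under wandb/sdk/artifacts/.
from wandb.sdk.artifacts.artifact import Artifact
from wandb.sdk.artifacts.exceptions import ArtifactNotLoggedError
from wandb.sdk.artifacts.storage_policies.wandb_storage_policy import (
    WandbStoragePolicy,
)

# The public entry point appears unchanged: constructing an artifact still
# works as in the deleted module's docstring below.
artifact = Artifact("mnist", type="dataset")
```

Code importing from the old private paths needs updating; the top-level `wandb.Artifact` API is unaffected.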
@@ -1,2226 +0,0 @@
1
- import base64
2
- import contextlib
3
- import json
4
- import os
5
- import pathlib
6
- import re
7
- import shutil
8
- import tempfile
9
- import time
10
- from types import ModuleType
11
- from typing import (
12
- IO,
13
- TYPE_CHECKING,
14
- Any,
15
- Dict,
16
- Generator,
17
- List,
18
- Mapping,
19
- Optional,
20
- Sequence,
21
- Tuple,
22
- Union,
23
- cast,
24
- )
25
- from urllib.parse import parse_qsl, quote, urlparse
26
-
27
- import requests
28
- import urllib3
29
-
30
- import wandb
31
- import wandb.data_types as data_types
32
- from wandb import env, util
33
- from wandb.apis import InternalApi, PublicApi
34
- from wandb.apis.public import Artifact as PublicArtifact
35
- from wandb.errors import CommError
36
- from wandb.errors.term import termlog, termwarn
37
- from wandb.sdk import lib as wandb_lib
38
- from wandb.sdk.data_types._dtypes import Type, TypeRegistry
39
- from wandb.sdk.interface.artifacts import Artifact as ArtifactInterface
40
- from wandb.sdk.interface.artifacts import (
41
- ArtifactFinalizedError,
42
- ArtifactManifest,
43
- ArtifactManifestEntry,
44
- ArtifactNotLoggedError,
45
- ArtifactsCache,
46
- StorageHandler,
47
- StorageLayout,
48
- StoragePolicy,
49
- get_artifacts_cache,
50
- )
51
- from wandb.sdk.internal import progress
52
- from wandb.sdk.internal.artifact_saver import get_staging_dir
53
- from wandb.sdk.lib import filesystem, runid
54
- from wandb.sdk.lib.hashutil import (
55
- B64MD5,
56
- ETag,
57
- HexMD5,
58
- _md5,
59
- b64_to_hex_id,
60
- hex_to_b64_id,
61
- md5_file_b64,
62
- md5_string,
63
- )
64
- from wandb.sdk.lib.paths import FilePathStr, LogicalFilePathStr, URIStr
65
-
66
- if TYPE_CHECKING:
67
- from urllib.parse import ParseResult
68
-
69
- import azure.storage.blob # type: ignore
70
-
71
- # We could probably use https://pypi.org/project/boto3-stubs/ or something
72
- # instead of `type:ignore`ing these boto imports, but it's nontrivial:
73
- # for some reason, despite being actively maintained as of 2022-09-30,
74
- # the latest release of boto3-stubs doesn't include all the features we use.
75
- import boto3 # type: ignore
76
- import boto3.resources.base # type: ignore
77
- import boto3.s3 # type: ignore
78
- import boto3.session # type: ignore
79
- import google.cloud.storage as gcs_module # type: ignore
80
-
81
- import wandb.apis.public
82
- from wandb.filesync.step_prepare import StepPrepare
83
-
84
- # This makes the first sleep 1s, and then doubles it up to total times,
85
- # which makes for ~18 hours.
86
- _REQUEST_RETRY_STRATEGY = urllib3.util.retry.Retry(
87
- backoff_factor=1,
88
- total=16,
89
- status_forcelist=(308, 408, 409, 429, 500, 502, 503, 504),
90
- )
91
-
92
- _REQUEST_POOL_CONNECTIONS = 64
93
-
94
- _REQUEST_POOL_MAXSIZE = 64
95
-
96
- ARTIFACT_TMP = tempfile.TemporaryDirectory("wandb-artifacts")
97
-
98
-
99
- class _AddedObj:
100
- def __init__(self, entry: ArtifactManifestEntry, obj: data_types.WBValue):
101
- self.entry = entry
102
- self.obj = obj
103
-
104
-
105
- def _normalize_metadata(metadata: Optional[Dict[str, Any]]) -> Dict[str, Any]:
106
- if metadata is None:
107
- return {}
108
- if not isinstance(metadata, dict):
109
- raise TypeError(f"metadata must be dict, not {type(metadata)}")
110
- return cast(
111
- Dict[str, Any], json.loads(json.dumps(util.json_friendly_val(metadata)))
112
- )
113
-
114
-
115
- class Artifact(ArtifactInterface):
116
- """Flexible and lightweight building block for dataset and model versioning.
117
-
118
- Constructs an empty artifact whose contents can be populated using its
119
- `add` family of functions. Once the artifact has all the desired files,
120
- you can call `wandb.log_artifact()` to log it.
121
-
122
- Arguments:
123
- name: (str) A human-readable name for this artifact, which is how you
124
- can identify this artifact in the UI or reference it in `use_artifact`
125
- calls. Names can contain letters, numbers, underscores, hyphens, and
126
- dots. The name must be unique across a project.
127
- type: (str) The type of the artifact, which is used to organize and differentiate
128
- artifacts. Common types include `dataset` or `model`, but you can use any string
129
- containing letters, numbers, underscores, hyphens, and dots.
130
- description: (str, optional) Free text that offers a description of the artifact. The
131
- description is markdown rendered in the UI, so this is a good place to place tables,
132
- links, etc.
133
- metadata: (dict, optional) Structured data associated with the artifact,
134
- for example class distribution of a dataset. This will eventually be queryable
135
- and plottable in the UI. There is a hard limit of 100 total keys.
136
-
137
- Examples:
138
- Basic usage
139
- ```
140
- wandb.init()
141
-
142
- artifact = wandb.Artifact('mnist', type='dataset')
143
- artifact.add_dir('mnist/')
144
- wandb.log_artifact(artifact)
145
- ```
146
-
147
- Returns:
148
- An `Artifact` object.
149
- """
150
-
151
- _added_objs: Dict[int, _AddedObj]
152
- _added_local_paths: Dict[str, ArtifactManifestEntry]
153
- _distributed_id: Optional[str]
154
- _metadata: dict
155
- _logged_artifact: Optional[ArtifactInterface]
156
- _incremental: bool
157
- _client_id: str
158
-
159
- def __init__(
160
- self,
161
- name: str,
162
- type: str,
163
- description: Optional[str] = None,
164
- metadata: Optional[dict] = None,
165
- incremental: Optional[bool] = None,
166
- use_as: Optional[str] = None,
167
- ) -> None:
168
- if not re.match(r"^[a-zA-Z0-9_\-.]+$", name):
169
- raise ValueError(
170
- "Artifact name may only contain alphanumeric characters, dashes, underscores, and dots. "
171
- 'Invalid name: "%s"' % name
172
- )
173
- if type == "job" or type.startswith("wandb-"):
174
- raise ValueError(
175
- "Artifact types 'job' and 'wandb-*' are reserved for internal use. "
176
- "Please use a different type."
177
- )
178
-
179
- metadata = _normalize_metadata(metadata)
180
- # TODO: this shouldn't be a property of the artifact. It's a more like an
181
- # argument to log_artifact.
182
- storage_layout = StorageLayout.V2
183
- if env.get_use_v1_artifacts():
184
- storage_layout = StorageLayout.V1
185
-
186
- self._storage_policy = WandbStoragePolicy(
187
- config={
188
- "storageLayout": storage_layout,
189
- # TODO: storage region
190
- }
191
- )
192
- self._api = InternalApi()
193
- self._final = False
194
- self._digest = ""
195
- self._file_entries = None
196
- self._manifest = ArtifactManifestV1(self._storage_policy)
197
- self._cache = get_artifacts_cache()
198
- self._added_objs = {}
199
- self._added_local_paths = {}
200
- # You can write into this directory when creating artifact files
201
- self._artifact_dir = tempfile.TemporaryDirectory()
202
- self._type = type
203
- self._name = name
204
- self._description = description
205
- self._metadata = metadata
206
- self._distributed_id = None
207
- self._logged_artifact = None
208
- self._incremental = False
209
- self._client_id = runid.generate_id(128)
210
- self._sequence_client_id = runid.generate_id(128)
211
- self._cache.store_client_artifact(self)
212
- self._use_as = use_as
213
-
214
- if incremental:
215
- self._incremental = incremental
216
- wandb.termwarn("Using experimental arg `incremental`")
217
-
218
- @property
219
- def id(self) -> Optional[str]:
220
- if self._logged_artifact:
221
- return self._logged_artifact.id
222
-
223
- # The artifact hasn't been saved so an ID doesn't exist yet.
224
- return None
225
-
226
- @property
227
- def source_version(self) -> Optional[str]:
228
- if self._logged_artifact:
229
- return self._logged_artifact.source_version
230
-
231
- return None
232
-
233
- @property
234
- def version(self) -> str:
235
- if self._logged_artifact:
236
- return self._logged_artifact.version
237
-
238
- raise ArtifactNotLoggedError(self, "version")
239
-
240
- @property
241
- def entity(self) -> str:
242
- if self._logged_artifact:
243
- return self._logged_artifact.entity
244
- return self._api.settings("entity") or self._api.viewer().get("entity") # type: ignore
245
-
246
- @property
247
- def project(self) -> str:
248
- if self._logged_artifact:
249
- return self._logged_artifact.project
250
-
251
- return self._api.settings("project") # type: ignore
252
-
253
- @property
254
- def manifest(self) -> ArtifactManifest:
255
- if self._logged_artifact:
256
- return self._logged_artifact.manifest
257
-
258
- self.finalize()
259
- return self._manifest
260
-
261
- @property
262
- def digest(self) -> str:
263
- if self._logged_artifact:
264
- return self._logged_artifact.digest
265
-
266
- self.finalize()
267
- # Digest will be none if the artifact hasn't been saved yet.
268
- return self._digest
269
-
270
- @property
271
- def type(self) -> str:
272
- if self._logged_artifact:
273
- return self._logged_artifact.type
274
-
275
- return self._type
276
-
277
- @property
278
- def name(self) -> str:
279
- if self._logged_artifact:
280
- return self._logged_artifact.name
281
-
282
- return self._name
283
-
284
- @property
285
- def full_name(self) -> str:
286
- if self._logged_artifact:
287
- return self._logged_artifact.full_name
288
-
289
- return super().full_name
290
-
291
- @property
292
- def state(self) -> str:
293
- if self._logged_artifact:
294
- return self._logged_artifact.state
295
-
296
- return "PENDING"
297
-
298
- @property
299
- def size(self) -> int:
300
- if self._logged_artifact:
301
- return self._logged_artifact.size
302
- sizes: List[int]
303
- sizes = []
304
- for entry in self._manifest.entries:
305
- e_size = self._manifest.entries[entry].size
306
- if e_size is not None:
307
- sizes.append(e_size)
308
- return sum(sizes)
309
-
310
- @property
311
- def commit_hash(self) -> str:
312
- if self._logged_artifact:
313
- return self._logged_artifact.commit_hash
314
-
315
- raise ArtifactNotLoggedError(self, "commit_hash")
316
-
317
- @property
318
- def description(self) -> Optional[str]:
319
- if self._logged_artifact:
320
- return self._logged_artifact.description
321
-
322
- return self._description
323
-
324
- @description.setter
325
- def description(self, desc: Optional[str]) -> None:
326
- if self._logged_artifact:
327
- self._logged_artifact.description = desc
328
- return
329
-
330
- self._description = desc
331
-
332
- @property
333
- def metadata(self) -> dict:
334
- if self._logged_artifact:
335
- return self._logged_artifact.metadata
336
-
337
- return self._metadata
338
-
339
- @metadata.setter
340
- def metadata(self, metadata: dict) -> None:
341
- metadata = _normalize_metadata(metadata)
342
- if self._logged_artifact:
343
- self._logged_artifact.metadata = metadata
344
- return
345
-
346
- self._metadata = metadata
347
-
348
- @property
349
- def aliases(self) -> List[str]:
350
- if self._logged_artifact:
351
- return self._logged_artifact.aliases
352
-
353
- raise ArtifactNotLoggedError(self, "aliases")
354
-
355
- @aliases.setter
356
- def aliases(self, aliases: List[str]) -> None:
357
- """Set artifact aliases.
358
-
359
- Arguments:
360
- aliases: (list) The list of aliases associated with this artifact.
361
- """
362
- if self._logged_artifact:
363
- self._logged_artifact.aliases = aliases
364
- return
365
-
366
- raise ArtifactNotLoggedError(self, "aliases")
367
-
368
- @property
369
- def use_as(self) -> Optional[str]:
370
- return self._use_as
371
-
372
- @property
373
- def distributed_id(self) -> Optional[str]:
374
- return self._distributed_id
375
-
376
- @distributed_id.setter
377
- def distributed_id(self, distributed_id: Optional[str]) -> None:
378
- self._distributed_id = distributed_id
379
-
380
- @property
381
- def incremental(self) -> bool:
382
- return self._incremental
383
-
384
- def used_by(self) -> List["wandb.apis.public.Run"]:
385
- if self._logged_artifact:
386
- return self._logged_artifact.used_by()
387
-
388
- raise ArtifactNotLoggedError(self, "used_by")
389
-
390
- def logged_by(self) -> "wandb.apis.public.Run":
391
- if self._logged_artifact:
392
- return self._logged_artifact.logged_by()
393
-
394
- raise ArtifactNotLoggedError(self, "logged_by")
395
-
396
- @contextlib.contextmanager
397
- def new_file(
398
- self, name: str, mode: str = "w", encoding: Optional[str] = None
399
- ) -> Generator[IO, None, None]:
400
- self._ensure_can_add()
401
- path = os.path.join(self._artifact_dir.name, name.lstrip("/"))
402
- if os.path.exists(path):
403
- raise ValueError(f"File with name {name!r} already exists at {path!r}")
404
-
405
- filesystem.mkdir_exists_ok(os.path.dirname(path))
406
- try:
407
- with util.fsync_open(path, mode, encoding) as f:
408
- yield f
409
- except UnicodeEncodeError as e:
410
- wandb.termerror(
411
- f"Failed to open the provided file (UnicodeEncodeError: {e}). Please provide the proper encoding."
412
- )
413
- raise e
414
- self.add_file(path, name=name)
415
-
416
- def add_file(
417
- self,
418
- local_path: str,
419
- name: Optional[str] = None,
420
- is_tmp: Optional[bool] = False,
421
- ) -> ArtifactManifestEntry:
422
- self._ensure_can_add()
423
- if not os.path.isfile(local_path):
424
- raise ValueError("Path is not a file: %s" % local_path)
425
-
426
- name = util.to_forward_slash_path(name or os.path.basename(local_path))
427
- digest = md5_file_b64(local_path)
428
-
429
- if is_tmp:
430
- file_path, file_name = os.path.split(name)
431
- file_name_parts = file_name.split(".")
432
- file_name_parts[0] = b64_to_hex_id(digest)[:20]
433
- name = os.path.join(file_path, ".".join(file_name_parts))
434
-
435
- return self._add_local_file(name, local_path, digest=digest)
436
-
437
- def add_dir(self, local_path: str, name: Optional[str] = None) -> None:
438
- self._ensure_can_add()
439
- if not os.path.isdir(local_path):
440
- raise ValueError("Path is not a directory: %s" % local_path)
441
-
442
- termlog(
443
- "Adding directory to artifact (%s)... "
444
- % os.path.join(".", os.path.normpath(local_path)),
445
- newline=False,
446
- )
447
- start_time = time.time()
448
-
449
- paths = []
450
- for dirpath, _, filenames in os.walk(local_path, followlinks=True):
451
- for fname in filenames:
452
- physical_path = os.path.join(dirpath, fname)
453
- logical_path = os.path.relpath(physical_path, start=local_path)
454
- if name is not None:
455
- logical_path = os.path.join(name, logical_path)
456
- paths.append((logical_path, physical_path))
457
-
458
- def add_manifest_file(log_phy_path: Tuple[str, str]) -> None:
459
- logical_path, physical_path = log_phy_path
460
- self._add_local_file(logical_path, physical_path)
461
-
462
- import multiprocessing.dummy # this uses threads
463
-
464
- num_threads = 8
465
- pool = multiprocessing.dummy.Pool(num_threads)
466
- pool.map(add_manifest_file, paths)
467
- pool.close()
468
- pool.join()
469
-
470
- termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)
471
-
472
- def add_reference(
473
- self,
474
- uri: Union[ArtifactManifestEntry, str],
475
- name: Optional[str] = None,
476
- checksum: bool = True,
477
- max_objects: Optional[int] = None,
478
- ) -> Sequence[ArtifactManifestEntry]:
479
- self._ensure_can_add()
480
- if name is not None:
481
- name = util.to_forward_slash_path(name)
482
-
483
- # This is a bit of a hack, we want to check if the uri is a of the type
484
- # ArtifactManifestEntry which is a private class returned by Artifact.get_path in
485
- # wandb/apis/public.py. If so, then recover the reference URL.
486
- if isinstance(uri, ArtifactManifestEntry) and uri.parent_artifact() != self:
487
- ref_url_fn = uri.ref_url
488
- uri_str = ref_url_fn()
489
- elif isinstance(uri, str):
490
- uri_str = uri
491
- url = urlparse(str(uri_str))
492
- if not url.scheme:
493
- raise ValueError(
494
- "References must be URIs. To reference a local file, use file://"
495
- )
496
-
497
- manifest_entries = self._storage_policy.store_reference(
498
- self,
499
- URIStr(uri_str),
500
- name=name,
501
- checksum=checksum,
502
- max_objects=max_objects,
503
- )
504
- for entry in manifest_entries:
505
- self._manifest.add_entry(entry)
506
-
507
- return manifest_entries
508
-
509
- def add(self, obj: data_types.WBValue, name: str) -> ArtifactManifestEntry:
510
- self._ensure_can_add()
511
- name = util.to_forward_slash_path(name)
512
-
513
- # This is a "hack" to automatically rename tables added to
514
- # the wandb /media/tables directory to their sha-based name.
515
- # TODO: figure out a more appropriate convention.
516
- is_tmp_name = name.startswith("media/tables")
517
-
518
- # Validate that the object is one of the correct wandb.Media types
519
- # TODO: move this to checking subclass of wandb.Media once all are
520
- # generally supported
521
- allowed_types = [
522
- data_types.Bokeh,
523
- data_types.JoinedTable,
524
- data_types.PartitionedTable,
525
- data_types.Table,
526
- data_types.Classes,
527
- data_types.ImageMask,
528
- data_types.BoundingBoxes2D,
529
- data_types.Audio,
530
- data_types.Image,
531
- data_types.Video,
532
- data_types.Html,
533
- data_types.Object3D,
534
- data_types.Molecule,
535
- data_types._SavedModel,
536
- ]
537
-
538
- if not any(isinstance(obj, t) for t in allowed_types):
539
- raise ValueError(
540
- "Found object of type {}, expected one of {}.".format(
541
- obj.__class__, allowed_types
542
- )
543
- )
544
-
545
- obj_id = id(obj)
546
- if obj_id in self._added_objs:
547
- return self._added_objs[obj_id].entry
548
-
549
- # If the object is coming from another artifact, save it as a reference
550
- ref_path = obj._get_artifact_entry_ref_url()
551
- if ref_path is not None:
552
- return self.add_reference(ref_path, type(obj).with_suffix(name))[0]
553
-
554
- val = obj.to_json(self)
555
- name = obj.with_suffix(name)
556
- entry = self._manifest.get_entry_by_path(name)
557
- if entry is not None:
558
- return entry
559
-
560
- def do_write(f: IO) -> None:
561
- import json
562
-
563
- # TODO: Do we need to open with utf-8 codec?
564
- f.write(json.dumps(val, sort_keys=True))
565
-
566
- if is_tmp_name:
567
- file_path = os.path.join(ARTIFACT_TMP.name, str(id(self)), name)
568
- folder_path, _ = os.path.split(file_path)
569
- if not os.path.exists(folder_path):
570
- os.makedirs(folder_path)
571
- with open(file_path, "w") as tmp_f:
572
- do_write(tmp_f)
573
- else:
574
- with self.new_file(name) as f:
575
- file_path = f.name
576
- do_write(f)
577
-
578
- # Note, we add the file from our temp directory.
579
- # It will be added again later on finalize, but succeed since
580
- # the checksum should match
581
- entry = self.add_file(file_path, name, is_tmp_name)
582
- self._added_objs[obj_id] = _AddedObj(entry, obj)
583
- if obj._artifact_target is None:
584
- obj._set_artifact_target(self, entry.path)
585
-
586
- if is_tmp_name:
587
- if os.path.exists(file_path):
588
- os.remove(file_path)
589
-
590
- return entry
591
-
592
- def get_path(self, name: str) -> ArtifactManifestEntry:
593
- if self._logged_artifact:
594
- return self._logged_artifact.get_path(name)
595
-
596
- raise ArtifactNotLoggedError(self, "get_path")
597
-
598
- def get(self, name: str) -> data_types.WBValue:
599
- if self._logged_artifact:
600
- return self._logged_artifact.get(name)
601
-
602
- raise ArtifactNotLoggedError(self, "get")
603
-
604
- def download(
605
- self, root: Optional[str] = None, recursive: bool = False
606
- ) -> FilePathStr:
607
- if self._logged_artifact:
608
- return self._logged_artifact.download(root=root, recursive=recursive)
609
-
610
- raise ArtifactNotLoggedError(self, "download")
611
-
612
- def checkout(self, root: Optional[str] = None) -> str:
613
- if self._logged_artifact:
614
- return self._logged_artifact.checkout(root=root)
615
-
616
- raise ArtifactNotLoggedError(self, "checkout")
617
-
618
- def verify(self, root: Optional[str] = None) -> bool:
619
- if self._logged_artifact:
620
- return self._logged_artifact.verify(root=root)
621
-
622
- raise ArtifactNotLoggedError(self, "verify")
623
-
624
- def save(
625
- self,
626
- project: Optional[str] = None,
627
- settings: Optional["wandb.wandb_sdk.wandb_settings.Settings"] = None,
628
- ) -> None:
629
- """Persist any changes made to the artifact.
630
-
631
- If currently in a run, that run will log this artifact. If not currently in a
632
- run, a run of type "auto" will be created to track this artifact.
633
-
634
- Arguments:
635
- project: (str, optional) A project to use for the artifact in the case that
636
- a run is not already in context settings: (wandb.Settings, optional) A
637
- settings object to use when initializing an automatic run. Most commonly
638
- used in testing harness.
639
-
640
- Returns:
641
- None
642
- """
643
- if self._incremental:
644
- with wandb_lib.telemetry.context() as tel:
645
- tel.feature.artifact_incremental = True
646
-
647
- if self._logged_artifact:
648
- return self._logged_artifact.save()
649
- else:
650
- if wandb.run is None:
651
- if settings is None:
652
- settings = wandb.Settings(silent="true")
653
- with wandb.init(
654
- project=project, job_type="auto", settings=settings
655
- ) as run:
656
- # redoing this here because in this branch we know we didn't
657
- # have the run at the beginning of the method
658
- if self._incremental:
659
- with wandb_lib.telemetry.context(run=run) as tel:
660
- tel.feature.artifact_incremental = True
661
- run.log_artifact(self)
662
- else:
663
- wandb.run.log_artifact(self)
664
-
665
- def delete(self) -> None:
666
- if self._logged_artifact:
667
- return self._logged_artifact.delete()
668
-
669
- raise ArtifactNotLoggedError(self, "delete")
670
-
671
- def wait(self, timeout: Optional[int] = None) -> ArtifactInterface:
672
- """Wait for an artifact to finish logging.
673
-
674
- Arguments:
675
- timeout: (int, optional) Wait up to this long.
676
- """
677
- if self._logged_artifact:
678
- return self._logged_artifact.wait(timeout) # type: ignore [call-arg]
679
-
680
- raise ArtifactNotLoggedError(self, "wait")
681
-
682
- def get_added_local_path_name(self, local_path: str) -> Optional[str]:
683
- """Get the artifact relative name of a file added by a local filesystem path.
684
-
685
- Arguments:
686
- local_path: (str) The local path to resolve into an artifact relative name.
687
-
688
- Returns:
689
- str: The artifact relative name.
690
-
691
- Examples:
692
- Basic usage
693
- ```
694
- artifact = wandb.Artifact('my_dataset', type='dataset')
695
- artifact.add_file('path/to/file.txt', name='artifact/path/file.txt')
696
-
697
- # Returns `artifact/path/file.txt`:
698
- name = artifact.get_added_local_path_name('path/to/file.txt')
699
- ```
700
- """
701
- entry = self._added_local_paths.get(local_path, None)
702
- if entry is None:
703
- return None
704
- return entry.path
705
-
706
- def finalize(self) -> None:
707
- """Mark this artifact as final, disallowing further modifications.
708
-
709
- This happens automatically when calling `log_artifact`.
710
-
711
- Returns:
712
- None
713
- """
714
- if self._final:
715
- return self._file_entries
716
-
717
- # mark final after all files are added
718
- self._final = True
719
- self._digest = self._manifest.digest()
720
-
721
- def json_encode(self) -> Dict[str, Any]:
722
- if not self._logged_artifact:
723
- raise ArtifactNotLoggedError(self, "json_encode")
724
- return util.artifact_to_json(self)
725
-
726
- def _ensure_can_add(self) -> None:
727
- if self._final:
728
- raise ArtifactFinalizedError(artifact=self)
729
-
730
- def _add_local_file(
731
- self, name: str, path: str, digest: Optional[B64MD5] = None
732
- ) -> ArtifactManifestEntry:
733
- with tempfile.NamedTemporaryFile(dir=get_staging_dir(), delete=False) as f:
734
- staging_path = f.name
735
- shutil.copyfile(path, staging_path)
736
- os.chmod(staging_path, 0o400)
737
-
738
- entry = ArtifactManifestEntry(
739
- path=util.to_forward_slash_path(name),
740
- digest=digest or md5_file_b64(staging_path),
741
- size=os.path.getsize(staging_path),
742
- local_path=staging_path,
743
- )
744
-
745
- self._manifest.add_entry(entry)
746
- self._added_local_paths[path] = entry
747
- return entry
748
-
749
-
750
- class ArtifactManifestV1(ArtifactManifest):
751
- @classmethod
752
- def version(cls) -> int:
753
- return 1
754
-
755
- @classmethod
756
- def from_manifest_json(cls, manifest_json: Dict) -> "ArtifactManifestV1":
757
- if manifest_json["version"] != cls.version():
758
- raise ValueError(
759
- "Expected manifest version 1, got %s" % manifest_json["version"]
760
- )
761
-
762
- storage_policy_name = manifest_json["storagePolicy"]
763
- storage_policy_config = manifest_json.get("storagePolicyConfig", {})
764
- storage_policy_cls = StoragePolicy.lookup_by_name(storage_policy_name)
765
- if storage_policy_cls is None:
766
- raise ValueError('Failed to find storage policy "%s"' % storage_policy_name)
767
- if not issubclass(storage_policy_cls, WandbStoragePolicy):
768
- raise ValueError(
769
- "No handler found for storage handler of type '%s'"
770
- % storage_policy_name
771
- )
772
-
773
- entries: Mapping[str, ArtifactManifestEntry]
774
- entries = {
775
- name: ArtifactManifestEntry(
776
- path=LogicalFilePathStr(name),
777
- digest=val["digest"],
778
- birth_artifact_id=val.get("birthArtifactID"),
779
- ref=val.get("ref"),
780
- size=val.get("size"),
781
- extra=val.get("extra"),
782
- local_path=val.get("local_path"),
783
- )
784
- for name, val in manifest_json["contents"].items()
785
- }
786
-
787
- return cls(storage_policy_cls.from_config(storage_policy_config), entries)
788
-
789
- def __init__(
790
- self,
791
- storage_policy: "WandbStoragePolicy",
792
- entries: Optional[Mapping[str, ArtifactManifestEntry]] = None,
793
- ) -> None:
794
- super().__init__(storage_policy, entries=entries)
795
-
796
- def to_manifest_json(self) -> Dict:
797
- """This is the JSON that's stored in wandb_manifest.json.
798
-
799
- If include_local is True we also include the local paths to files. This is
800
- used to represent an artifact that's waiting to be saved on the current
801
- system. We don't need to include the local paths in the artifact manifest
802
- contents.
803
- """
804
- contents = {}
805
- for entry in sorted(self.entries.values(), key=lambda k: k.path):
806
- json_entry: Dict[str, Any] = {
807
- "digest": entry.digest,
808
- }
809
- if entry.birth_artifact_id:
810
- json_entry["birthArtifactID"] = entry.birth_artifact_id
811
- if entry.ref:
812
- json_entry["ref"] = entry.ref
813
- if entry.extra:
814
- json_entry["extra"] = entry.extra
815
- if entry.size is not None:
816
- json_entry["size"] = entry.size
817
- contents[entry.path] = json_entry
818
- return {
819
- "version": self.__class__.version(),
820
- "storagePolicy": self.storage_policy.name(),
821
- "storagePolicyConfig": self.storage_policy.config() or {},
822
- "contents": contents,
823
- }
824
-
825
- def digest(self) -> HexMD5:
826
- hasher = _md5()
827
- hasher.update(b"wandb-artifact-manifest-v1\n")
828
- for name, entry in sorted(self.entries.items(), key=lambda kv: kv[0]):
829
- hasher.update(f"{name}:{entry.digest}\n".encode())
830
- return HexMD5(hasher.hexdigest())
831
-
832
-
833
- class WandbStoragePolicy(StoragePolicy):
834
- @classmethod
835
- def name(cls) -> str:
836
- return "wandb-storage-policy-v1"
837
-
838
- @classmethod
839
- def from_config(cls, config: Dict) -> "WandbStoragePolicy":
840
- return cls(config=config)
841
-
842
- def __init__(
843
- self,
844
- config: Optional[Dict] = None,
845
- cache: Optional[ArtifactsCache] = None,
846
- api: Optional[InternalApi] = None,
847
- ) -> None:
848
- self._cache = cache or get_artifacts_cache()
849
- self._config = config or {}
850
- self._session = requests.Session()
851
- adapter = requests.adapters.HTTPAdapter(
852
- max_retries=_REQUEST_RETRY_STRATEGY,
853
- pool_connections=_REQUEST_POOL_CONNECTIONS,
854
- pool_maxsize=_REQUEST_POOL_MAXSIZE,
855
- )
856
- self._session.mount("http://", adapter)
857
- self._session.mount("https://", adapter)
858
-
859
- s3 = S3Handler()
860
- gcs = GCSHandler()
861
- azure = AzureHandler()
862
- http = HTTPHandler(self._session)
863
- https = HTTPHandler(self._session, scheme="https")
864
- artifact = WBArtifactHandler()
865
- local_artifact = WBLocalArtifactHandler()
866
- file_handler = LocalFileHandler()
867
-
868
- self._api = api or InternalApi()
869
- self._handler = MultiHandler(
870
- handlers=[
871
- s3,
872
- gcs,
873
- azure,
874
- http,
875
- https,
876
- artifact,
877
- local_artifact,
878
- file_handler,
879
- ],
880
- default_handler=TrackingHandler(),
881
- )
882
-
883
- def config(self) -> Dict:
884
- return self._config
885
-
886
- def load_file(
887
- self,
888
- artifact: ArtifactInterface,
889
- manifest_entry: ArtifactManifestEntry,
890
- ) -> str:
891
- path, hit, cache_open = self._cache.check_md5_obj_path(
892
- B64MD5(manifest_entry.digest), # TODO(spencerpearson): unsafe cast
893
- manifest_entry.size if manifest_entry.size is not None else 0,
894
- )
895
- if hit:
896
- return path
897
-
898
- response = self._session.get(
899
- self._file_url(self._api, artifact.entity, manifest_entry),
900
- auth=("api", self._api.api_key),
901
- stream=True,
902
- )
903
- response.raise_for_status()
904
-
905
- with cache_open(mode="wb") as file:
906
- for data in response.iter_content(chunk_size=16 * 1024):
907
- file.write(data)
908
- return path
909
-
910
- def store_reference(
911
- self,
912
- artifact: ArtifactInterface,
913
- path: Union[URIStr, FilePathStr],
914
- name: Optional[str] = None,
915
- checksum: bool = True,
916
- max_objects: Optional[int] = None,
917
- ) -> Sequence[ArtifactManifestEntry]:
918
- return self._handler.store_path(
919
- artifact, path, name=name, checksum=checksum, max_objects=max_objects
920
- )
921
-
922
- def load_reference(
923
- self,
924
- manifest_entry: ArtifactManifestEntry,
925
- local: bool = False,
926
- ) -> str:
927
- return self._handler.load_path(manifest_entry, local)
928
-
929
- def _file_url(
930
- self, api: InternalApi, entity_name: str, manifest_entry: ArtifactManifestEntry
931
- ) -> str:
932
- storage_layout = self._config.get("storageLayout", StorageLayout.V1)
933
- storage_region = self._config.get("storageRegion", "default")
934
- md5_hex = b64_to_hex_id(B64MD5(manifest_entry.digest))
935
-
936
- if storage_layout == StorageLayout.V1:
937
- return "{}/artifacts/{}/{}".format(
938
- api.settings("base_url"), entity_name, md5_hex
939
- )
940
- elif storage_layout == StorageLayout.V2:
941
- return "{}/artifactsV2/{}/{}/{}/{}".format(
942
- api.settings("base_url"),
943
- storage_region,
944
- entity_name,
945
- quote(
946
- manifest_entry.birth_artifact_id
947
- if manifest_entry.birth_artifact_id is not None
948
- else ""
949
- ),
950
- md5_hex,
951
- )
952
- else:
953
- raise Exception(f"unrecognized storage layout: {storage_layout}")
954
-
955
- def store_file_sync(
956
- self,
957
- artifact_id: str,
958
- artifact_manifest_id: str,
959
- entry: ArtifactManifestEntry,
960
- preparer: "StepPrepare",
961
- progress_callback: Optional["progress.ProgressFn"] = None,
962
- ) -> bool:
963
- """Upload a file to the artifact store.
964
-
965
- Returns:
966
- True if the file was a duplicate (did not need to be uploaded),
967
- False if it needed to be uploaded or was a reference (nothing to dedupe).
968
- """
969
- resp = preparer.prepare_sync(
970
- {
971
- "artifactID": artifact_id,
972
- "artifactManifestID": artifact_manifest_id,
973
- "name": entry.path,
974
- "md5": entry.digest,
975
- }
976
- ).get()
977
-
978
- entry.birth_artifact_id = resp.birth_artifact_id
979
- if resp.upload_url is None:
980
- return True
981
- if entry.local_path is None:
982
- return False
983
-
984
- with open(entry.local_path, "rb") as file:
985
- # This fails if we don't send the first byte before the signed URL expires.
986
- self._api.upload_file_retry(
987
- resp.upload_url,
988
- file,
989
- progress_callback,
990
- extra_headers={
991
- header.split(":", 1)[0]: header.split(":", 1)[1]
992
- for header in (resp.upload_headers or {})
993
- },
994
- )
995
- self._write_cache(entry)
996
-
997
- return False
998
-
999
- async def store_file_async(
1000
- self,
1001
- artifact_id: str,
1002
- artifact_manifest_id: str,
1003
- entry: ArtifactManifestEntry,
1004
- preparer: "StepPrepare",
1005
- progress_callback: Optional["progress.ProgressFn"] = None,
1006
- ) -> bool:
1007
- """Async equivalent to `store_file_sync`."""
1008
- resp = await preparer.prepare_async(
1009
- {
1010
- "artifactID": artifact_id,
1011
- "artifactManifestID": artifact_manifest_id,
1012
- "name": entry.path,
1013
- "md5": entry.digest,
1014
- }
1015
- )
1016
-
1017
- entry.birth_artifact_id = resp.birth_artifact_id
1018
- if resp.upload_url is None:
1019
- return True
1020
- if entry.local_path is None:
1021
- return False
1022
-
1023
- with open(entry.local_path, "rb") as file:
1024
- # This fails if we don't send the first byte before the signed URL expires.
1025
- await self._api.upload_file_retry_async(
1026
- resp.upload_url,
1027
- file,
1028
- progress_callback,
1029
- extra_headers={
1030
- header.split(":", 1)[0]: header.split(":", 1)[1]
1031
- for header in (resp.upload_headers or {})
1032
- },
1033
- )
1034
-
1035
- self._write_cache(entry)
1036
-
1037
- return False
1038
-
1039
- def _write_cache(self, entry: ArtifactManifestEntry) -> None:
1040
- if entry.local_path is None:
1041
- return
1042
-
1043
- # Cache upon successful upload.
1044
- _, hit, cache_open = self._cache.check_md5_obj_path(
1045
- B64MD5(entry.digest),
1046
- entry.size if entry.size is not None else 0,
1047
- )
1048
- if not hit:
1049
- with cache_open() as f:
1050
- shutil.copyfile(entry.local_path, f.name)
1051
-
1052
-
1053
- # Don't use this yet!
1054
- class __S3BucketPolicy(StoragePolicy): # noqa: N801
1055
- @classmethod
1056
- def name(cls) -> str:
1057
- return "wandb-s3-bucket-policy-v1"
1058
-
1059
- @classmethod
1060
- def from_config(cls, config: Dict[str, str]) -> "__S3BucketPolicy":
1061
- if "bucket" not in config:
1062
- raise ValueError("Bucket name not found in config")
1063
- return cls(config["bucket"])
1064
-
1065
- def __init__(self, bucket: str) -> None:
1066
- self._bucket = bucket
1067
- s3 = S3Handler(bucket)
1068
- local = LocalFileHandler()
1069
-
1070
- self._handler = MultiHandler(
1071
- handlers=[
1072
- s3,
1073
- local,
1074
- ],
1075
- default_handler=TrackingHandler(),
1076
- )
1077
-
1078
- def config(self) -> Dict[str, str]:
1079
- return {"bucket": self._bucket}
1080
-
1081
- def load_path(
1082
- self,
1083
- manifest_entry: ArtifactManifestEntry,
1084
- local: bool = False,
1085
- ) -> Union[URIStr, FilePathStr]:
1086
- return self._handler.load_path(manifest_entry, local=local)
1087
-
1088
- def store_path(
1089
- self,
1090
- artifact: ArtifactInterface,
1091
- path: Union[URIStr, FilePathStr],
1092
- name: Optional[str] = None,
1093
- checksum: bool = True,
1094
- max_objects: Optional[int] = None,
1095
- ) -> Sequence[ArtifactManifestEntry]:
1096
- return self._handler.store_path(
1097
- artifact, path, name=name, checksum=checksum, max_objects=max_objects
1098
- )
1099
-
1100
-
1101
- class MultiHandler(StorageHandler):
1102
- _handlers: List[StorageHandler]
1103
-
1104
- def __init__(
1105
- self,
1106
- handlers: Optional[List[StorageHandler]] = None,
1107
- default_handler: Optional[StorageHandler] = None,
1108
- ) -> None:
1109
- self._handlers = handlers or []
1110
- self._default_handler = default_handler
1111
-
1112
- def _get_handler(self, url: Union[FilePathStr, URIStr]) -> StorageHandler:
1113
- parsed_url = urlparse(url)
1114
- for handler in self._handlers:
1115
- if handler.can_handle(parsed_url):
1116
- return handler
1117
- if self._default_handler is not None:
1118
- return self._default_handler
1119
- raise ValueError('No storage handler registered for url "%s"' % str(url))
1120
-
1121
- def load_path(
1122
- self,
1123
- manifest_entry: ArtifactManifestEntry,
1124
- local: bool = False,
1125
- ) -> Union[URIStr, FilePathStr]:
1126
- assert manifest_entry.ref is not None
1127
- handler = self._get_handler(manifest_entry.ref)
1128
- return handler.load_path(manifest_entry, local=local)
1129
-
1130
- def store_path(
1131
- self,
1132
- artifact: ArtifactInterface,
1133
- path: Union[URIStr, FilePathStr],
1134
- name: Optional[str] = None,
1135
- checksum: bool = True,
1136
- max_objects: Optional[int] = None,
1137
- ) -> Sequence[ArtifactManifestEntry]:
1138
- handler = self._get_handler(path)
1139
- return handler.store_path(
1140
- artifact, path, name=name, checksum=checksum, max_objects=max_objects
1141
- )
1142
-
1143
-
1144
- class TrackingHandler(StorageHandler):
1145
- def __init__(self, scheme: Optional[str] = None) -> None:
1146
- """Track paths with no modification or special processing.
1147
-
1148
- Useful when paths being tracked are on file systems mounted at a standardized
1149
- location.
1150
-
1151
- For example, if the data to track is located on an NFS share mounted on
1152
- `/data`, then it is sufficient to just track the paths.
1153
- """
1154
- self._scheme = scheme or ""
1155
-
1156
- def can_handle(self, parsed_url: "ParseResult") -> bool:
1157
- return parsed_url.scheme == self._scheme
1158
-
1159
- def load_path(
1160
- self,
1161
- manifest_entry: ArtifactManifestEntry,
1162
- local: bool = False,
1163
- ) -> Union[URIStr, FilePathStr]:
1164
- if local:
1165
- # Likely a user error. The tracking handler is
1166
- # oblivious to the underlying paths, so it has
1167
- # no way of actually loading it.
1168
- url = urlparse(manifest_entry.ref)
1169
- raise ValueError(
1170
- f"Cannot download file at path {str(manifest_entry.ref)}, scheme {str(url.scheme)} not recognized"
1171
- )
1172
- # TODO(spencerpearson): should this go through util.to_native_slash_path
1173
- # instead of just getting typecast?
1174
- return FilePathStr(manifest_entry.path)
1175
-
1176
- def store_path(
1177
- self,
1178
- artifact: ArtifactInterface,
1179
- path: Union[URIStr, FilePathStr],
1180
- name: Optional[str] = None,
1181
- checksum: bool = True,
1182
- max_objects: Optional[int] = None,
1183
- ) -> Sequence[ArtifactManifestEntry]:
1184
- url = urlparse(path)
1185
- if name is None:
1186
- raise ValueError(
1187
- 'You must pass name="<entry_name>" when tracking references with unknown schemes. ref: %s'
1188
- % path
1189
- )
1190
- termwarn(
1191
- "Artifact references with unsupported schemes cannot be checksummed: %s"
1192
- % path
1193
- )
1194
- name = LogicalFilePathStr(name or url.path[1:]) # strip leading slash
1195
- return [ArtifactManifestEntry(path=name, ref=path, digest=path)]
1196
-
1197
-
1198
- DEFAULT_MAX_OBJECTS = 10000
1199
-
1200
-
1201
- class LocalFileHandler(StorageHandler):
1202
- """Handles file:// references."""
1203
-
1204
- def __init__(self, scheme: Optional[str] = None) -> None:
1205
- """Track files or directories on a local filesystem.
1206
-
1207
- Expand directories to create an entry for each file contained.
1208
- """
1209
- self._scheme = scheme or "file"
1210
- self._cache = get_artifacts_cache()
1211
-
1212
- def can_handle(self, parsed_url: "ParseResult") -> bool:
1213
- return parsed_url.scheme == self._scheme
1214
-
1215
- def load_path(
1216
- self,
1217
- manifest_entry: ArtifactManifestEntry,
1218
- local: bool = False,
1219
- ) -> Union[URIStr, FilePathStr]:
1220
- if manifest_entry.ref is None:
1221
- raise ValueError(f"Cannot add path with no ref: {manifest_entry.path}")
1222
- local_path = util.local_file_uri_to_path(str(manifest_entry.ref))
1223
- if not os.path.exists(local_path):
1224
- raise ValueError(
1225
- "Local file reference: Failed to find file at path %s" % local_path
1226
- )
1227
-
1228
- path, hit, cache_open = self._cache.check_md5_obj_path(
1229
- B64MD5(manifest_entry.digest), # TODO(spencerpearson): unsafe cast
1230
- manifest_entry.size if manifest_entry.size is not None else 0,
1231
- )
1232
- if hit:
1233
- return path
1234
-
1235
- md5 = md5_file_b64(local_path)
1236
- if md5 != manifest_entry.digest:
1237
- raise ValueError(
1238
- f"Local file reference: Digest mismatch for path {local_path}: expected {manifest_entry.digest} but found {md5}"
1239
- )
1240
-
1241
- filesystem.mkdir_exists_ok(os.path.dirname(path))
1242
-
1243
- with cache_open() as f:
1244
- shutil.copy(local_path, f.name)
1245
- return path
1246
-
1247
- def store_path(
1248
- self,
1249
- artifact: ArtifactInterface,
1250
- path: Union[URIStr, FilePathStr],
1251
- name: Optional[str] = None,
1252
- checksum: bool = True,
1253
- max_objects: Optional[int] = None,
1254
- ) -> Sequence[ArtifactManifestEntry]:
1255
- local_path = util.local_file_uri_to_path(path)
1256
- max_objects = max_objects or DEFAULT_MAX_OBJECTS
1257
- # We have a single file or directory
1258
- # Note, we follow symlinks for files contained within the directory
1259
- entries = []
1260
-
1261
- def md5(path: str) -> B64MD5:
1262
- return (
1263
- md5_file_b64(path)
1264
- if checksum
1265
- else md5_string(str(os.stat(path).st_size))
1266
- )
1267
-
1268
- if os.path.isdir(local_path):
1269
- i = 0
1270
- start_time = time.time()
1271
- if checksum:
1272
- termlog(
1273
- 'Generating checksum for up to %i files in "%s"...\n'
1274
- % (max_objects, local_path),
1275
- newline=False,
1276
- )
1277
- for root, _, files in os.walk(local_path):
1278
- for sub_path in files:
1279
- i += 1
1280
- if i > max_objects:
1281
- raise ValueError(
1282
- "Exceeded %i objects tracked, pass max_objects to add_reference"
1283
- % max_objects
1284
- )
1285
- physical_path = os.path.join(root, sub_path)
1286
- # TODO(spencerpearson): this is not a "logical path" in the sense that
1287
- # `util.to_forward_slash_path` returns a "logical path"; it's a relative path
1288
- # **on the local filesystem**.
1289
- logical_path = os.path.relpath(physical_path, start=local_path)
1290
- if name is not None:
1291
- logical_path = os.path.join(name, logical_path)
1292
-
1293
- entry = ArtifactManifestEntry(
1294
- path=LogicalFilePathStr(logical_path),
1295
- ref=FilePathStr(os.path.join(path, logical_path)),
1296
- size=os.path.getsize(physical_path),
1297
- digest=md5(physical_path),
1298
- )
1299
- entries.append(entry)
1300
- if checksum:
1301
- termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)
1302
- elif os.path.isfile(local_path):
1303
- name = name or os.path.basename(local_path)
1304
- entry = ArtifactManifestEntry(
1305
- path=LogicalFilePathStr(name),
1306
- ref=path,
1307
- size=os.path.getsize(local_path),
1308
- digest=md5(local_path),
1309
- )
1310
- entries.append(entry)
1311
- else:
1312
- # TODO: update error message if we don't allow directories.
1313
- raise ValueError('Path "%s" must be a valid file or directory path' % path)
1314
- return entries
1315
-
1316
-
1317
- class S3Handler(StorageHandler):
1318
- _s3: Optional["boto3.resources.base.ServiceResource"]
1319
- _scheme: str
1320
- _versioning_enabled: Optional[bool]
1321
-
1322
- def __init__(self, scheme: Optional[str] = None) -> None:
1323
- self._scheme = scheme or "s3"
1324
- self._s3 = None
1325
- self._versioning_enabled = None
1326
- self._cache = get_artifacts_cache()
1327
-
1328
- def can_handle(self, parsed_url: "ParseResult") -> bool:
1329
- return parsed_url.scheme == self._scheme
1330
-
1331
- def init_boto(self) -> "boto3.resources.base.ServiceResource":
1332
- if self._s3 is not None:
1333
- return self._s3
1334
- boto: "boto3" = util.get_module(
1335
- "boto3",
1336
- required="s3:// references requires the boto3 library, run pip install wandb[aws]",
1337
- lazy=False,
1338
- )
1339
- self._s3 = boto.session.Session().resource(
1340
- "s3",
1341
- endpoint_url=os.getenv("AWS_S3_ENDPOINT_URL"),
1342
- region_name=os.getenv("AWS_REGION"),
1343
- )
1344
- self._botocore = util.get_module("botocore")
1345
- return self._s3
1346
-
1347
- def _parse_uri(self, uri: str) -> Tuple[str, str, Optional[str]]:
1348
- url = urlparse(uri)
1349
- query = dict(parse_qsl(url.query))
1350
-
1351
- bucket = url.netloc
1352
- key = url.path[1:] # strip leading slash
1353
- version = query.get("versionId")
1354
-
1355
- return bucket, key, version
1356
-
1357
- def versioning_enabled(self, bucket: str) -> bool:
1358
- self.init_boto()
1359
- assert self._s3 is not None # mypy: unwraps optionality
1360
- if self._versioning_enabled is not None:
1361
- return self._versioning_enabled
1362
- res = self._s3.BucketVersioning(bucket)
1363
- self._versioning_enabled = res.status == "Enabled"
1364
- return self._versioning_enabled
1365
-
1366
- def load_path(
1367
- self,
1368
- manifest_entry: ArtifactManifestEntry,
1369
- local: bool = False,
1370
- ) -> Union[URIStr, FilePathStr]:
1371
- if not local:
1372
- assert manifest_entry.ref is not None
1373
- return manifest_entry.ref
1374
-
1375
- assert manifest_entry.ref is not None
1376
-
1377
- path, hit, cache_open = self._cache.check_etag_obj_path(
1378
- URIStr(manifest_entry.ref),
1379
-            ETag(manifest_entry.digest),  # TODO(spencerpearson): unsafe cast
-            manifest_entry.size if manifest_entry.size is not None else 0,
-        )
-        if hit:
-            return path
-
-        self.init_boto()
-        assert self._s3 is not None  # mypy: unwraps optionality
-        bucket, key, _ = self._parse_uri(manifest_entry.ref)
-        version = manifest_entry.extra.get("versionID")
-
-        extra_args = {}
-        if version is None:
-            # We don't have version information, so just get the latest version
-            # and fall back to listing all versions if we don't have a match.
-            obj = self._s3.Object(bucket, key)
-            etag = self._etag_from_obj(obj)
-            if etag != manifest_entry.digest:
-                if self.versioning_enabled(bucket):
-                    # Fall back to listing versions
-                    obj = None
-                    object_versions = self._s3.Bucket(bucket).object_versions.filter(
-                        Prefix=key
-                    )
-                    for object_version in object_versions:
-                        if (
-                            manifest_entry.extra.get("etag")
-                            == object_version.e_tag[1:-1]
-                        ):
-                            obj = object_version.Object()
-                            extra_args["VersionId"] = object_version.version_id
-                            break
-                    if obj is None:
-                        raise ValueError(
-                            "Couldn't find object version for {}/{} matching etag {}".format(
-                                bucket, key, manifest_entry.extra.get("etag")
-                            )
-                        )
-                else:
-                    raise ValueError(
-                        f"Digest mismatch for object {manifest_entry.ref}: expected {manifest_entry.digest} but found {etag}"
-                    )
-        else:
-            obj = self._s3.ObjectVersion(bucket, key, version).Object()
-            extra_args["VersionId"] = version
-
-        with cache_open(mode="wb") as f:
-            obj.download_fileobj(f, ExtraArgs=extra_args)
-        return path
-
-    def store_path(
-        self,
-        artifact: ArtifactInterface,
-        path: Union[URIStr, FilePathStr],
-        name: Optional[str] = None,
-        checksum: bool = True,
-        max_objects: Optional[int] = None,
-    ) -> Sequence[ArtifactManifestEntry]:
-        self.init_boto()
-        assert self._s3 is not None  # mypy: unwraps optionality
-
-        # The passed-in path might have query string parameters. We only care
-        # about a subset, like version, when parsing. Once we have that, we can
-        # store the rest of the metadata in the artifact entry itself.
-        bucket, key, version = self._parse_uri(path)
-        path = URIStr(f"{self._scheme}://{bucket}/{key}")
-        if not self.versioning_enabled(bucket) and version:
-            raise ValueError(
-                f"Specifying a versionId is not valid for s3://{bucket} as it does not have versioning enabled."
-            )
-
-        max_objects = max_objects or DEFAULT_MAX_OBJECTS
-        if not checksum:
-            return [
-                ArtifactManifestEntry(
-                    path=LogicalFilePathStr(name or key), ref=path, digest=path
-                )
-            ]
-
-        # If an explicit version is specified, use that. Otherwise, use the head version.
-        objs = (
-            [self._s3.ObjectVersion(bucket, key, version).Object()]
-            if version
-            else [self._s3.Object(bucket, key)]
-        )
-        start_time = None
-        multi = False
-        try:
-            objs[0].load()
-            # S3 doesn't have real folders; however, there are cases where the
-            # folder key holds a valid file, which will not trigger a recursive
-            # upload. We should check whether the object's metadata says it is
-            # a directory, and do a multi-file upload if so.
-            if "x-directory" in objs[0].content_type:
-                multi = True
-        except self._botocore.exceptions.ClientError as e:
-            if e.response["Error"]["Code"] == "404":
-                multi = True
-            else:
-                raise CommError(
-                    "Unable to connect to S3 ({}): {}".format(
-                        e.response["Error"]["Code"], e.response["Error"]["Message"]
-                    )
-                )
-        if multi:
-            start_time = time.time()
-            termlog(
-                'Generating checksum for up to %i objects with prefix "%s"... '
-                % (max_objects, key),
-                newline=False,
-            )
-            objs = self._s3.Bucket(bucket).objects.filter(Prefix=key).limit(max_objects)
-        # Weird iterator scoping makes us assign this to a local function
-        size = self._size_from_obj
-        entries = [
-            self._entry_from_obj(obj, path, name, prefix=key, multi=multi)
-            for obj in objs
-            if size(obj) > 0
-        ]
-        if start_time is not None:
-            termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)
-        if len(entries) > max_objects:
-            raise ValueError(
-                "Exceeded %i objects tracked, pass max_objects to add_reference"
-                % max_objects
-            )
-        return entries
-
-    def _size_from_obj(self, obj: "boto3.s3.Object") -> int:
-        # ObjectSummary has size, Object has content_length
-        size: int
-        if hasattr(obj, "size"):
-            size = obj.size
-        else:
-            size = obj.content_length
-        return size
-
-    def _entry_from_obj(
-        self,
-        obj: "boto3.s3.Object",
-        path: str,
-        name: Optional[str] = None,
-        prefix: str = "",
-        multi: bool = False,
-    ) -> ArtifactManifestEntry:
-        """Create an ArtifactManifestEntry from an S3 object.
-
-        Arguments:
-            obj: The S3 object
-            path: The S3-style path (e.g.: "s3://bucket/file.txt")
-            name: The user-assigned name, or None if not specified
-            prefix: The prefix to add (will be the same as `path` for directories)
-            multi: Whether or not this is a multi-object add.
-        """
-        bucket, key, _ = self._parse_uri(path)
-
-        # Always use posix paths, since that's what S3 uses.
-        posix_key = pathlib.PurePosixPath(obj.key)  # the bucket key
-        posix_path = pathlib.PurePosixPath(bucket) / pathlib.PurePosixPath(
-            key
-        )  # the path, with the scheme stripped
-        posix_prefix = pathlib.PurePosixPath(prefix)  # the prefix, if adding a prefix
-        posix_name = pathlib.PurePosixPath(name or "")
-        posix_ref = posix_path
-
-        if name is None:
-            # We're adding a directory (prefix), so calculate a relative path.
-            if str(posix_prefix) in str(posix_key) and posix_prefix != posix_key:
-                posix_name = posix_key.relative_to(posix_prefix)
-                posix_ref = posix_path / posix_name
-            else:
-                posix_name = pathlib.PurePosixPath(posix_key.name)
-                posix_ref = posix_path
-        elif multi:
-            # We're adding a directory with a name override.
-            relpath = posix_key.relative_to(posix_prefix)
-            posix_name = posix_name / relpath
-            posix_ref = posix_path / relpath
-        return ArtifactManifestEntry(
-            path=LogicalFilePathStr(str(posix_name)),
-            ref=URIStr(f"{self._scheme}://{str(posix_ref)}"),
-            digest=ETag(self._etag_from_obj(obj)),
-            size=self._size_from_obj(obj),
-            extra=self._extra_from_obj(obj),
-        )
-
-    @staticmethod
-    def _etag_from_obj(obj: "boto3.s3.Object") -> ETag:
-        etag: ETag
-        etag = obj.e_tag[1:-1]  # strip the leading and trailing quotes
-        return etag
-
-    @staticmethod
-    def _extra_from_obj(obj: "boto3.s3.Object") -> Dict[str, str]:
-        extra = {
-            "etag": obj.e_tag[1:-1],  # strip the leading and trailing quotes
-        }
-        # ObjectSummary will never have version_id
-        if hasattr(obj, "version_id") and obj.version_id != "null":
-            extra["versionID"] = obj.version_id
-        return extra
-
-    @staticmethod
-    def _content_addressed_path(md5: str) -> FilePathStr:
-        # TODO: is this the structure we want? It's not at all human-readable,
-        # but that's probably OK; we don't want people poking around in the bucket.
-        return FilePathStr(
-            "wandb/%s" % base64.b64encode(md5.encode("ascii")).decode("ascii")
-        )
-
-
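For orientation, the S3Handler removed above is driven through wandb's public reference API. A minimal sketch (not part of this diff; the bucket and prefix are placeholders, and boto3 credentials must already be configured in the environment):

    import wandb

    with wandb.init(project="demo") as run:
        artifact = wandb.Artifact("train-data", type="dataset")
        # A prefix yields one manifest entry per object, up to max_objects;
        # checksum=False would record only the URI without listing the bucket.
        artifact.add_reference("s3://my-bucket/datasets/train", max_objects=10_000)
        run.log_artifact(artifact)
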
-class GCSHandler(StorageHandler):
-    _client: Optional["gcs_module.client.Client"]
-    _versioning_enabled: Optional[bool]
-
-    def __init__(self, scheme: Optional[str] = None) -> None:
-        self._scheme = scheme or "gs"
-        self._client = None
-        self._versioning_enabled = None
-        self._cache = get_artifacts_cache()
-
-    def versioning_enabled(self, bucket_path: str) -> bool:
-        if self._versioning_enabled is not None:
-            return self._versioning_enabled
-        self.init_gcs()
-        assert self._client is not None  # mypy: unwraps optionality
-        bucket = self._client.bucket(bucket_path)
-        bucket.reload()
-        self._versioning_enabled = bucket.versioning_enabled
-        return self._versioning_enabled
-
-    def can_handle(self, parsed_url: "ParseResult") -> bool:
-        return parsed_url.scheme == self._scheme
-
-    def init_gcs(self) -> "gcs_module.client.Client":
-        if self._client is not None:
-            return self._client
-        storage = util.get_module(
-            "google.cloud.storage",
-            required="gs:// references require the google-cloud-storage library, run pip install wandb[gcp]",
-        )
-        self._client = storage.Client()
-        return self._client
-
-    def _parse_uri(self, uri: str) -> Tuple[str, str, Optional[str]]:
-        url = urlparse(uri)
-        bucket = url.netloc
-        key = url.path[1:]
-        version = url.fragment if url.fragment else None
-        return bucket, key, version
-
-    def load_path(
-        self,
-        manifest_entry: ArtifactManifestEntry,
-        local: bool = False,
-    ) -> Union[URIStr, FilePathStr]:
-        if not local:
-            assert manifest_entry.ref is not None
-            return manifest_entry.ref
-
-        path, hit, cache_open = self._cache.check_md5_obj_path(
-            B64MD5(manifest_entry.digest),  # TODO(spencerpearson): unsafe cast
-            manifest_entry.size if manifest_entry.size is not None else 0,
-        )
-        if hit:
-            return path
-
-        self.init_gcs()
-        assert self._client is not None  # mypy: unwraps optionality
-        assert manifest_entry.ref is not None
-        bucket, key, _ = self._parse_uri(manifest_entry.ref)
-        version = manifest_entry.extra.get("versionID")
-
-        obj = None
-        # First, attempt to get the generation specified; this returns None if
-        # versioning is not enabled.
-        if version is not None:
-            obj = self._client.bucket(bucket).get_blob(key, generation=version)
-
-        if obj is None:
-            # Object versioning is disabled on the bucket, so just get
-            # the latest version and make sure the MD5 matches.
-            obj = self._client.bucket(bucket).get_blob(key)
-            if obj is None:
-                raise ValueError(
-                    f"Unable to download object {manifest_entry.ref} with generation {version}"
-                )
-            md5 = obj.md5_hash
-            if md5 != manifest_entry.digest:
-                raise ValueError(
-                    f"Digest mismatch for object {manifest_entry.ref}: expected {manifest_entry.digest} but found {md5}"
-                )
-
-        with cache_open(mode="wb") as f:
-            obj.download_to_file(f)
-        return path
-
-    def store_path(
-        self,
-        artifact: ArtifactInterface,
-        path: Union[URIStr, FilePathStr],
-        name: Optional[str] = None,
-        checksum: bool = True,
-        max_objects: Optional[int] = None,
-    ) -> Sequence[ArtifactManifestEntry]:
-        self.init_gcs()
-        assert self._client is not None  # mypy: unwraps optionality
-
-        # After parsing any query params / fragments for additional context,
-        # such as version identifiers, pare down the path to just the bucket
-        # and key.
-        bucket, key, version = self._parse_uri(path)
-        path = URIStr(f"{self._scheme}://{bucket}/{key}")
-        max_objects = max_objects or DEFAULT_MAX_OBJECTS
-        if not self.versioning_enabled(bucket) and version:
-            raise ValueError(
-                f"Specifying a versionId is not valid for {self._scheme}://{bucket} as it does not have versioning enabled."
-            )
-
-        if not checksum:
-            return [
-                ArtifactManifestEntry(
-                    path=LogicalFilePathStr(name or key), ref=path, digest=path
-                )
-            ]
-
-        start_time = None
-        obj = self._client.bucket(bucket).get_blob(key, generation=version)
-        multi = obj is None
-        if multi:
-            start_time = time.time()
-            termlog(
-                'Generating checksum for up to %i objects with prefix "%s"... '
-                % (max_objects, key),
-                newline=False,
-            )
-            objects = self._client.bucket(bucket).list_blobs(
-                prefix=key, max_results=max_objects
-            )
-        else:
-            objects = [obj]
-
-        entries = [
-            self._entry_from_obj(obj, path, name, prefix=key, multi=multi)
-            for obj in objects
-        ]
-        if start_time is not None:
-            termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)
-        if len(entries) > max_objects:
-            raise ValueError(
-                "Exceeded %i objects tracked, pass max_objects to add_reference"
-                % max_objects
-            )
-        return entries
-
-    def _entry_from_obj(
-        self,
-        obj: "gcs_module.blob.Blob",
-        path: str,
-        name: Optional[str] = None,
-        prefix: str = "",
-        multi: bool = False,
-    ) -> ArtifactManifestEntry:
-        """Create an ArtifactManifestEntry from a GCS object.
-
-        Arguments:
-            obj: The GCS object
-            path: The GCS-style path (e.g.: "gs://bucket/file.txt")
-            name: The user-assigned name, or None if not specified
-            prefix: The prefix to add (will be the same as `path` for directories)
-            multi: Whether or not this is a multi-object add.
-        """
-        bucket, key, _ = self._parse_uri(path)
-
-        # Always use posix paths, since that's what GCS uses.
-        posix_key = pathlib.PurePosixPath(obj.name)  # the bucket key
-        posix_path = pathlib.PurePosixPath(bucket) / pathlib.PurePosixPath(
-            key
-        )  # the path, with the scheme stripped
-        posix_prefix = pathlib.PurePosixPath(prefix)  # the prefix, if adding a prefix
-        posix_name = pathlib.PurePosixPath(name or "")
-        posix_ref = posix_path
-
-        if name is None:
-            # We're adding a directory (prefix), so calculate a relative path.
-            if str(posix_prefix) in str(posix_key) and posix_prefix != posix_key:
-                posix_name = posix_key.relative_to(posix_prefix)
-                posix_ref = posix_path / posix_name
-            else:
-                posix_name = pathlib.PurePosixPath(posix_key.name)
-                posix_ref = posix_path
-        elif multi:
-            # We're adding a directory with a name override.
-            relpath = posix_key.relative_to(posix_prefix)
-            posix_name = posix_name / relpath
-            posix_ref = posix_path / relpath
-        return ArtifactManifestEntry(
-            path=LogicalFilePathStr(str(posix_name)),
-            ref=URIStr(f"{self._scheme}://{str(posix_ref)}"),
-            digest=obj.md5_hash,
-            size=obj.size,
-            extra=self._extra_from_obj(obj),
-        )
-
-    @staticmethod
-    def _extra_from_obj(obj: "gcs_module.blob.Blob") -> Dict[str, str]:
-        return {
-            "etag": obj.etag,
-            "versionID": obj.generation,
-        }
-
-    @staticmethod
-    def _content_addressed_path(md5: str) -> FilePathStr:
-        # TODO: is this the structure we want? It's not at all human-readable,
-        # but that's probably OK; we don't want people poking around in the bucket.
-        return FilePathStr(
-            "wandb/%s" % base64.b64encode(md5.encode("ascii")).decode("ascii")
-        )
-
-
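GCS references work the same way through add_reference; note that _parse_uri above reads an optional URL fragment as the object generation. A sketch with placeholder bucket, object, and generation values, assuming google-cloud-storage is installed (pip install wandb[gcp]):

    import wandb

    with wandb.init(project="demo") as run:
        artifact = wandb.Artifact("model-weights", type="model")
        # The "#<generation>" fragment pins a generation on a versioned
        # bucket; omit it to reference the live object.
        artifact.add_reference("gs://my-bucket/models/weights.h5#1700000000000000")
        run.log_artifact(artifact)
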
-class AzureHandler(StorageHandler):
-    def can_handle(self, parsed_url: "ParseResult") -> bool:
-        return parsed_url.scheme == "https" and parsed_url.netloc.endswith(
-            ".blob.core.windows.net"
-        )
-
-    def load_path(
-        self,
-        manifest_entry: "ArtifactManifestEntry",
-        local: bool = False,
-    ) -> Union[URIStr, FilePathStr]:
-        assert manifest_entry.ref is not None
-        if not local:
-            return manifest_entry.ref
-
-        path, hit, cache_open = get_artifacts_cache().check_etag_obj_path(
-            URIStr(manifest_entry.ref),
-            ETag(manifest_entry.digest),
-            manifest_entry.size or 0,
-        )
-        if hit:
-            return path
-
-        account_url, container_name, blob_name, query = self._parse_uri(
-            manifest_entry.ref
-        )
-        version_id = manifest_entry.extra.get("versionID")
-        blob_service_client = self._get_module("azure.storage.blob").BlobServiceClient(
-            account_url,
-            credential=self._get_module("azure.identity").DefaultAzureCredential(),
-        )
-        blob_client = blob_service_client.get_blob_client(
-            container=container_name, blob=blob_name
-        )
-        if version_id is None:
-            # Try current version, then all versions.
-            try:
-                downloader = blob_client.download_blob(
-                    etag=manifest_entry.digest,
-                    match_condition=self._get_module(
-                        "azure.core"
-                    ).MatchConditions.IfNotModified,
-                )
-            except self._get_module("azure.core.exceptions").ResourceModifiedError:
-                container_client = blob_service_client.get_container_client(
-                    container_name
-                )
-                for blob_properties in container_client.walk_blobs(
-                    name_starts_with=blob_name, include=["versions"]
-                ):
-                    if (
-                        blob_properties.name == blob_name
-                        and blob_properties.etag == manifest_entry.digest
-                        and blob_properties.version_id is not None
-                    ):
-                        downloader = blob_client.download_blob(
-                            version_id=blob_properties.version_id
-                        )
-                        break
-                else:  # didn't break
-                    raise ValueError(
-                        f"Couldn't find blob version for {manifest_entry.ref} matching "
-                        f"etag {manifest_entry.digest}."
-                    )
-        else:
-            downloader = blob_client.download_blob(version_id=version_id)
-        with cache_open(mode="wb") as f:
-            downloader.readinto(f)
-        return path
-
-    def store_path(
-        self,
-        artifact: ArtifactInterface,
-        path: Union[URIStr, FilePathStr],
-        name: Optional[str] = None,
-        checksum: bool = True,
-        max_objects: Optional[int] = None,
-    ) -> Sequence["ArtifactManifestEntry"]:
-        account_url, container_name, blob_name, query = self._parse_uri(path)
-        path = URIStr(f"{account_url}/{container_name}/{blob_name}")
-
-        if not checksum:
-            return [
-                ArtifactManifestEntry(
-                    path=LogicalFilePathStr(name or blob_name), digest=path, ref=path
-                )
-            ]
-
-        blob_service_client = self._get_module("azure.storage.blob").BlobServiceClient(
-            account_url,
-            credential=self._get_module("azure.identity").DefaultAzureCredential(),
-        )
-        blob_client = blob_service_client.get_blob_client(
-            container=container_name, blob=blob_name
-        )
-        if blob_client.exists(version_id=query.get("versionId")):
-            blob_properties = blob_client.get_blob_properties(
-                version_id=query.get("versionId")
-            )
-            return [
-                self._create_entry(
-                    blob_properties,
-                    path=LogicalFilePathStr(
-                        name or pathlib.PurePosixPath(blob_name).name
-                    ),
-                    ref=URIStr(
-                        f"{account_url}/{container_name}/{blob_properties.name}"
-                    ),
-                )
-            ]
-
-        entries = []
-        container_client = blob_service_client.get_container_client(container_name)
-        max_objects = max_objects or DEFAULT_MAX_OBJECTS
-        for i, blob_properties in enumerate(
-            container_client.list_blobs(name_starts_with=f"{blob_name}/")
-        ):
-            if i >= max_objects:
-                raise ValueError(
-                    f"Exceeded {max_objects} objects tracked, pass max_objects to "
-                    f"add_reference"
-                )
-            suffix = pathlib.PurePosixPath(blob_properties.name).relative_to(blob_name)
-            entries.append(
-                self._create_entry(
-                    blob_properties,
-                    path=LogicalFilePathStr(str(name / suffix if name else suffix)),
-                    ref=URIStr(
-                        f"{account_url}/{container_name}/{blob_properties.name}"
-                    ),
-                )
-            )
-        return entries
-
-    def _get_module(self, name: str) -> ModuleType:
-        module = util.get_module(
-            name,
-            lazy=False,
-            required="Azure references require the azure library, run "
-            "pip install wandb[azure]",
-        )
-        assert isinstance(module, ModuleType)
-        return module
-
-    def _parse_uri(self, uri: str) -> Tuple[str, str, str, Dict[str, str]]:
-        parsed_url = urlparse(uri)
-        query = dict(parse_qsl(parsed_url.query))
-        account_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
-        _, container_name, blob_name = parsed_url.path.split("/", 2)
-        return account_url, container_name, blob_name, query
-
-    def _create_entry(
-        self,
-        blob_properties: "azure.storage.blob.BlobProperties",
-        path: LogicalFilePathStr,
-        ref: URIStr,
-    ) -> ArtifactManifestEntry:
-        extra = {"etag": blob_properties.etag.strip('"')}
-        if blob_properties.version_id:
-            extra["versionID"] = blob_properties.version_id
-        return ArtifactManifestEntry(
-            path=path,
-            ref=ref,
-            digest=blob_properties.etag.strip('"'),
-            size=blob_properties.size,
-            extra=extra,
-        )
-
-
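Azure references are plain https:// URLs on *.blob.core.windows.net, which can_handle above matches. A sketch with placeholder account, container, and blob names, assuming the azure extras are installed (pip install wandb[azure]) and DefaultAzureCredential can authenticate:

    import wandb

    with wandb.init(project="demo") as run:
        artifact = wandb.Artifact("raw-logs", type="dataset")
        # A "?versionId=..." query would pin a specific blob version, per
        # _parse_uri above; a container-path prefix crawls the blobs under it.
        artifact.add_reference(
            "https://myaccount.blob.core.windows.net/my-container/logs/day1.jsonl"
        )
        run.log_artifact(artifact)
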
-class HTTPHandler(StorageHandler):
-    def __init__(self, session: requests.Session, scheme: Optional[str] = None) -> None:
-        self._scheme = scheme or "http"
-        self._cache = get_artifacts_cache()
-        self._session = session
-
-    def can_handle(self, parsed_url: "ParseResult") -> bool:
-        return parsed_url.scheme == self._scheme
-
-    def load_path(
-        self,
-        manifest_entry: ArtifactManifestEntry,
-        local: bool = False,
-    ) -> Union[URIStr, FilePathStr]:
-        if not local:
-            assert manifest_entry.ref is not None
-            return manifest_entry.ref
-
-        assert manifest_entry.ref is not None
-
-        path, hit, cache_open = self._cache.check_etag_obj_path(
-            URIStr(manifest_entry.ref),
-            ETag(manifest_entry.digest),  # TODO(spencerpearson): unsafe cast
-            manifest_entry.size if manifest_entry.size is not None else 0,
-        )
-        if hit:
-            return path
-
-        response = self._session.get(manifest_entry.ref, stream=True)
-        response.raise_for_status()
-
-        digest: Optional[Union[ETag, FilePathStr, URIStr]]
-        digest, size, extra = self._entry_from_headers(response.headers)
-        digest = digest or manifest_entry.ref
-        if manifest_entry.digest != digest:
-            raise ValueError(
-                f"Digest mismatch for url {manifest_entry.ref}: expected {manifest_entry.digest} but found {digest}"
-            )
-
-        with cache_open(mode="wb") as file:
-            for data in response.iter_content(chunk_size=16 * 1024):
-                file.write(data)
-        return path
-
-    def store_path(
-        self,
-        artifact: ArtifactInterface,
-        path: Union[URIStr, FilePathStr],
-        name: Optional[str] = None,
-        checksum: bool = True,
-        max_objects: Optional[int] = None,
-    ) -> Sequence[ArtifactManifestEntry]:
-        name = LogicalFilePathStr(name or os.path.basename(path))
-        if not checksum:
-            return [ArtifactManifestEntry(path=name, ref=path, digest=path)]
-
-        with self._session.get(path, stream=True) as response:
-            response.raise_for_status()
-            digest: Optional[Union[ETag, FilePathStr, URIStr]]
-            digest, size, extra = self._entry_from_headers(response.headers)
-            digest = digest or path
-        return [
-            ArtifactManifestEntry(
-                path=name, ref=path, digest=digest, size=size, extra=extra
-            )
-        ]
-
-    def _entry_from_headers(
-        self, headers: requests.structures.CaseInsensitiveDict
-    ) -> Tuple[Optional[ETag], Optional[int], Dict[str, str]]:
-        response_headers = {k.lower(): v for k, v in headers.items()}
-        size = None
-        if response_headers.get("content-length", None):
-            size = int(response_headers["content-length"])
-
-        digest = response_headers.get("etag", None)
-        extra = {}
-        if digest:
-            extra["etag"] = digest
-        if digest and digest[:1] == '"' and digest[-1:] == '"':
-            digest = digest[1:-1]  # trim leading and trailing quotes around etag
-        return digest, size, extra
-
-
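For plain HTTP(S) URLs, the digest and size come from the ETag and Content-Length response headers when the server provides them; otherwise the URL itself serves as the digest. A sketch with a placeholder URL:

    import wandb

    with wandb.init(project="demo") as run:
        artifact = wandb.Artifact("external-file", type="dataset")
        # checksum=False records the URL without contacting the server at all.
        artifact.add_reference("https://example.com/data/archive.zip", name="archive.zip")
        run.log_artifact(artifact)
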
-class WBArtifactHandler(StorageHandler):
-    """Handles loading and storing Artifact reference-type files."""
-
-    _client: Optional[PublicApi]
-
-    def __init__(self) -> None:
-        self._scheme = "wandb-artifact"
-        self._cache = get_artifacts_cache()
-        self._client = None
-
-    def can_handle(self, parsed_url: "ParseResult") -> bool:
-        return parsed_url.scheme == self._scheme
-
-    @property
-    def client(self) -> PublicApi:
-        if self._client is None:
-            self._client = PublicApi()
-        return self._client
-
-    def load_path(
-        self,
-        manifest_entry: ArtifactManifestEntry,
-        local: bool = False,
-    ) -> Union[URIStr, FilePathStr]:
-        """Load the file in the specified artifact given its corresponding entry.
-
-        Download the referenced artifact; create and return a new symlink to the caller.
-
-        Arguments:
-            manifest_entry (ArtifactManifestEntry): The index entry to load
-
-        Returns:
-            (os.PathLike): A path to the file represented by `manifest_entry`
-        """
-        # We don't check for cache hits here. Since this is a cross-artifact
-        # reference and we've made the choice to store 0 in the size field, we
-        # can't confirm that the file is complete, so we just rely on the
-        # dep_artifact entry's download() method to do its own cache check.
-
-        # Parse the reference path and download the artifact if needed
-        artifact_id = util.host_from_path(manifest_entry.ref)
-        artifact_file_path = util.uri_from_path(manifest_entry.ref)
-
-        dep_artifact = PublicArtifact.from_id(hex_to_b64_id(artifact_id), self.client)
-        link_target_path: FilePathStr
-        if local:
-            link_target_path = dep_artifact.get_path(artifact_file_path).download()
-        else:
-            link_target_path = dep_artifact.get_path(artifact_file_path).ref_target()
-
-        return link_target_path
-
-    def store_path(
-        self,
-        artifact: ArtifactInterface,
-        path: Union[URIStr, FilePathStr],
-        name: Optional[str] = None,
-        checksum: bool = True,
-        max_objects: Optional[int] = None,
-    ) -> Sequence[ArtifactManifestEntry]:
-        """Store the file or directory at the given path into the specified artifact.
-
-        Recursively resolves the reference until the result is a concrete asset.
-
-        Arguments:
-            artifact: The artifact doing the storing
-            path (str): The path to store
-            name (str): If specified, the logical name that should map to `path`
-
-        Returns:
-            (list[ArtifactManifestEntry]): A list of manifest entries to store within
-                the artifact
-        """
-        # Recursively resolve the reference until a concrete asset is found
-        # TODO: Consider resolving server-side for performance improvements.
-        while path is not None and urlparse(path).scheme == self._scheme:
-            artifact_id = util.host_from_path(path)
-            artifact_file_path = util.uri_from_path(path)
-            target_artifact = PublicArtifact.from_id(
-                hex_to_b64_id(artifact_id), self.client
-            )
-
-            # This should only have an effect if the user added the reference by
-            # URL string directly (in other words, they did not already load the
-            # artifact into RAM).
-            target_artifact._load_manifest()
-
-            entry = target_artifact._manifest.get_entry_by_path(artifact_file_path)
-            path = entry.ref
-
-        # Create the path reference
-        path = URIStr(
-            "{}://{}/{}".format(
-                self._scheme,
-                b64_to_hex_id(target_artifact.id),
-                artifact_file_path,
-            )
-        )
-
-        # Return the new entry
-        return [
-            ArtifactManifestEntry(
-                path=LogicalFilePathStr(name or os.path.basename(path)),
-                ref=path,
-                size=0,
-                digest=entry.digest,
-            )
-        ]
-
-
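The wandb-artifact:// URIs this handler resolves are produced when one artifact references an entry of another. A sketch, assuming the entity/project/artifact names and the entry path are placeholders, and that public-API entries expose ref_url() as in current wandb releases:

    import wandb

    api = wandb.Api()
    source = api.artifact("my-entity/my-project/train-data:v3")

    with wandb.init(project="demo") as run:
        derived = wandb.Artifact("train-subset", type="dataset")
        # ref_url() yields a wandb-artifact://<artifact-id>/<path> URI, which
        # store_path above resolves recursively to a concrete asset.
        derived.add_reference(source.get_path("images/cat.png").ref_url())
        run.log_artifact(derived)
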
-class WBLocalArtifactHandler(StorageHandler):
-    """Handles loading and storing Artifact reference-type files."""
-
-    _client: Optional[PublicApi]
-
-    def __init__(self) -> None:
-        self._scheme = "wandb-client-artifact"
-        self._cache = get_artifacts_cache()
-
-    def can_handle(self, parsed_url: "ParseResult") -> bool:
-        return parsed_url.scheme == self._scheme
-
-    def load_path(
-        self,
-        manifest_entry: ArtifactManifestEntry,
-        local: bool = False,
-    ) -> Union[URIStr, FilePathStr]:
-        raise NotImplementedError(
-            "Should not be loading a path for an artifact entry with unresolved client id."
-        )
-
-    def store_path(
-        self,
-        artifact: ArtifactInterface,
-        path: Union[URIStr, FilePathStr],
-        name: Optional[str] = None,
-        checksum: bool = True,
-        max_objects: Optional[int] = None,
-    ) -> Sequence[ArtifactManifestEntry]:
-        """Store the file or directory at the given path within the specified artifact.
-
-        Arguments:
-            artifact: The artifact doing the storing
-            path (str): The path to store
-            name (str): If specified, the logical name that should map to `path`
-
-        Returns:
-            (list[ArtifactManifestEntry]): A list of manifest entries to store within the artifact
-        """
-        client_id = util.host_from_path(path)
-        target_path = util.uri_from_path(path)
-        target_artifact = self._cache.get_client_artifact(client_id)
-        if not isinstance(target_artifact, Artifact):
-            raise RuntimeError("Local Artifact not found - invalid reference")
-        target_entry = target_artifact._manifest.entries[target_path]
-        if target_entry is None:
-            raise RuntimeError("Local entry not found - invalid reference")
-
-        # Return the new entry
-        return [
-            ArtifactManifestEntry(
-                path=LogicalFilePathStr(name or os.path.basename(path)),
-                ref=path,
-                size=0,
-                digest=target_entry.digest,
-            )
-        ]
-
-
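Client-artifact references (wandb-client-artifact://) arise within a single process, before an artifact has been logged and assigned a server ID. A sketch of one pattern that can produce them; that adding the same rich object to two artifacts stores the second copy as a client-ID reference is an assumption about media deduplication, not something this diff shows:

    import wandb

    with wandb.init(project="demo") as run:
        table = wandb.Table(columns=["x"], data=[[1], [2]])

        first = wandb.Artifact("tables-a", type="dataset")
        first.add(table, "numbers")
        run.log_artifact(first)

        second = wandb.Artifact("tables-b", type="dataset")
        second.add(table, "numbers")  # assumed: stored as a reference to the first copy
        run.log_artifact(second)
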
-class _ArtifactVersionType(Type):
-    name = "artifactVersion"
-    types = [Artifact, PublicArtifact]
-
-
-TypeRegistry.add(_ArtifactVersionType)