wandb 0.15.3__py3-none-any.whl → 0.15.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (156) hide show
  1. wandb/__init__.py +1 -1
  2. wandb/analytics/sentry.py +1 -0
  3. wandb/apis/importers/base.py +20 -5
  4. wandb/apis/importers/mlflow.py +7 -1
  5. wandb/apis/internal.py +12 -0
  6. wandb/apis/public.py +247 -1387
  7. wandb/apis/reports/_panels.py +58 -35
  8. wandb/beta/workflows.py +6 -7
  9. wandb/cli/cli.py +130 -60
  10. wandb/data_types.py +3 -1
  11. wandb/filesync/dir_watcher.py +21 -27
  12. wandb/filesync/step_checksum.py +8 -8
  13. wandb/filesync/step_prepare.py +23 -10
  14. wandb/filesync/step_upload.py +13 -13
  15. wandb/filesync/upload_job.py +4 -8
  16. wandb/integration/cohere/__init__.py +3 -0
  17. wandb/integration/cohere/cohere.py +21 -0
  18. wandb/integration/cohere/resolver.py +347 -0
  19. wandb/integration/gym/__init__.py +4 -6
  20. wandb/integration/huggingface/__init__.py +3 -0
  21. wandb/integration/huggingface/huggingface.py +18 -0
  22. wandb/integration/huggingface/resolver.py +213 -0
  23. wandb/integration/langchain/wandb_tracer.py +16 -179
  24. wandb/integration/openai/__init__.py +1 -3
  25. wandb/integration/openai/openai.py +11 -143
  26. wandb/integration/openai/resolver.py +111 -38
  27. wandb/integration/sagemaker/config.py +2 -2
  28. wandb/integration/tensorboard/log.py +4 -4
  29. wandb/old/settings.py +24 -7
  30. wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
  31. wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
  32. wandb/proto/wandb_deprecated.py +3 -1
  33. wandb/sdk/__init__.py +1 -1
  34. wandb/sdk/artifacts/__init__.py +0 -0
  35. wandb/sdk/artifacts/artifact.py +2101 -0
  36. wandb/sdk/artifacts/artifact_download_logger.py +42 -0
  37. wandb/sdk/artifacts/artifact_manifest.py +67 -0
  38. wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
  39. wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
  40. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
  41. wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
  42. wandb/sdk/artifacts/artifact_state.py +10 -0
  43. wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
  44. wandb/sdk/artifacts/exceptions.py +55 -0
  45. wandb/sdk/artifacts/storage_handler.py +59 -0
  46. wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
  47. wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
  48. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
  49. wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
  50. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
  51. wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
  52. wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
  53. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
  54. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
  55. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
  56. wandb/sdk/artifacts/storage_layout.py +6 -0
  57. wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
  58. wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
  59. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
  60. wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
  61. wandb/sdk/data_types/_dtypes.py +7 -12
  62. wandb/sdk/data_types/base_types/json_metadata.py +3 -2
  63. wandb/sdk/data_types/base_types/media.py +8 -8
  64. wandb/sdk/data_types/base_types/wb_value.py +12 -13
  65. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
  66. wandb/sdk/data_types/helper_types/classes.py +6 -8
  67. wandb/sdk/data_types/helper_types/image_mask.py +5 -6
  68. wandb/sdk/data_types/histogram.py +4 -3
  69. wandb/sdk/data_types/html.py +3 -4
  70. wandb/sdk/data_types/image.py +11 -9
  71. wandb/sdk/data_types/molecule.py +5 -3
  72. wandb/sdk/data_types/object_3d.py +7 -5
  73. wandb/sdk/data_types/plotly.py +3 -2
  74. wandb/sdk/data_types/saved_model.py +11 -11
  75. wandb/sdk/data_types/trace_tree.py +5 -4
  76. wandb/sdk/data_types/utils.py +3 -5
  77. wandb/sdk/data_types/video.py +5 -4
  78. wandb/sdk/integration_utils/auto_logging.py +215 -0
  79. wandb/sdk/interface/interface.py +15 -15
  80. wandb/sdk/internal/file_pusher.py +8 -16
  81. wandb/sdk/internal/file_stream.py +5 -11
  82. wandb/sdk/internal/handler.py +13 -1
  83. wandb/sdk/internal/internal_api.py +287 -13
  84. wandb/sdk/internal/job_builder.py +119 -30
  85. wandb/sdk/internal/sender.py +6 -26
  86. wandb/sdk/internal/settings_static.py +2 -0
  87. wandb/sdk/internal/system/assets/__init__.py +2 -0
  88. wandb/sdk/internal/system/assets/gpu.py +42 -0
  89. wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
  90. wandb/sdk/internal/system/env_probe_helpers.py +13 -0
  91. wandb/sdk/internal/system/system_info.py +3 -3
  92. wandb/sdk/internal/tb_watcher.py +32 -22
  93. wandb/sdk/internal/thread_local_settings.py +18 -0
  94. wandb/sdk/launch/_project_spec.py +57 -11
  95. wandb/sdk/launch/agent/agent.py +147 -65
  96. wandb/sdk/launch/agent/job_status_tracker.py +34 -0
  97. wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
  98. wandb/sdk/launch/builder/abstract.py +5 -1
  99. wandb/sdk/launch/builder/build.py +21 -18
  100. wandb/sdk/launch/builder/docker_builder.py +10 -4
  101. wandb/sdk/launch/builder/kaniko_builder.py +113 -23
  102. wandb/sdk/launch/builder/noop.py +6 -3
  103. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
  104. wandb/sdk/launch/environment/aws_environment.py +3 -2
  105. wandb/sdk/launch/environment/azure_environment.py +124 -0
  106. wandb/sdk/launch/environment/gcp_environment.py +2 -4
  107. wandb/sdk/launch/environment/local_environment.py +1 -1
  108. wandb/sdk/launch/errors.py +19 -0
  109. wandb/sdk/launch/github_reference.py +32 -19
  110. wandb/sdk/launch/launch.py +3 -8
  111. wandb/sdk/launch/launch_add.py +6 -2
  112. wandb/sdk/launch/loader.py +21 -2
  113. wandb/sdk/launch/registry/azure_container_registry.py +132 -0
  114. wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
  115. wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
  116. wandb/sdk/launch/registry/local_registry.py +2 -1
  117. wandb/sdk/launch/runner/abstract.py +24 -3
  118. wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
  119. wandb/sdk/launch/runner/local_container.py +103 -51
  120. wandb/sdk/launch/runner/local_process.py +1 -1
  121. wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
  122. wandb/sdk/launch/runner/vertex_runner.py +10 -5
  123. wandb/sdk/launch/sweeps/__init__.py +7 -9
  124. wandb/sdk/launch/sweeps/scheduler.py +307 -77
  125. wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
  126. wandb/sdk/launch/sweeps/utils.py +82 -35
  127. wandb/sdk/launch/utils.py +89 -75
  128. wandb/sdk/lib/_settings_toposort_generated.py +7 -0
  129. wandb/sdk/lib/capped_dict.py +26 -0
  130. wandb/sdk/lib/{git.py → gitlib.py} +76 -59
  131. wandb/sdk/lib/hashutil.py +12 -4
  132. wandb/sdk/lib/paths.py +96 -8
  133. wandb/sdk/lib/sock_client.py +2 -2
  134. wandb/sdk/lib/timer.py +1 -0
  135. wandb/sdk/service/server.py +22 -9
  136. wandb/sdk/service/server_sock.py +1 -1
  137. wandb/sdk/service/service.py +27 -8
  138. wandb/sdk/verify/verify.py +4 -7
  139. wandb/sdk/wandb_config.py +2 -6
  140. wandb/sdk/wandb_init.py +57 -53
  141. wandb/sdk/wandb_require.py +7 -0
  142. wandb/sdk/wandb_run.py +61 -223
  143. wandb/sdk/wandb_settings.py +28 -4
  144. wandb/testing/relay.py +15 -2
  145. wandb/util.py +74 -36
  146. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
  147. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
  148. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
  149. wandb/integration/langchain/util.py +0 -191
  150. wandb/sdk/interface/artifacts/__init__.py +0 -33
  151. wandb/sdk/interface/artifacts/artifact.py +0 -615
  152. wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
  153. wandb/sdk/wandb_artifacts.py +0 -2226
  154. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
  155. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
  156. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2101 @@
1
+ """Artifact class."""
2
+ import concurrent.futures
3
+ import contextlib
4
+ import datetime
5
+ import json
6
+ import multiprocessing.dummy
7
+ import os
8
+ import platform
9
+ import re
10
+ import shutil
11
+ import tempfile
12
+ import time
13
+ from copy import copy
14
+ from functools import partial
15
+ from pathlib import PurePosixPath
16
+ from typing import (
17
+ IO,
18
+ TYPE_CHECKING,
19
+ Any,
20
+ Dict,
21
+ Generator,
22
+ List,
23
+ Optional,
24
+ Sequence,
25
+ Set,
26
+ Tuple,
27
+ Type,
28
+ Union,
29
+ cast,
30
+ )
31
+ from urllib.parse import urlparse
32
+
33
+ import requests
34
+
35
+ import wandb
36
+ from wandb import data_types, env, util
37
+ from wandb.apis.normalize import normalize_exceptions
38
+ from wandb.apis.public import ArtifactFiles, RetryingClient, Run
39
+ from wandb.data_types import WBValue
40
+ from wandb.errors.term import termerror, termlog, termwarn
41
+ from wandb.sdk.artifacts.artifact_download_logger import ArtifactDownloadLogger
42
+ from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
43
+ from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
44
+ from wandb.sdk.artifacts.artifact_manifests.artifact_manifest_v1 import (
45
+ ArtifactManifestV1,
46
+ )
47
+ from wandb.sdk.artifacts.artifact_saver import get_staging_dir
48
+ from wandb.sdk.artifacts.artifact_state import ArtifactState
49
+ from wandb.sdk.artifacts.artifacts_cache import get_artifacts_cache
50
+ from wandb.sdk.artifacts.exceptions import (
51
+ ArtifactFinalizedError,
52
+ ArtifactNotLoggedError,
53
+ WaitTimeoutError,
54
+ )
55
+ from wandb.sdk.artifacts.storage_layout import StorageLayout
56
+ from wandb.sdk.artifacts.storage_policies.wandb_storage_policy import WandbStoragePolicy
57
+ from wandb.sdk.data_types._dtypes import Type as WBType
58
+ from wandb.sdk.data_types._dtypes import TypeRegistry
59
+ from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
60
+ from wandb.sdk.lib import filesystem, retry, runid, telemetry
61
+ from wandb.sdk.lib.hashutil import B64MD5, b64_to_hex_id, md5_file_b64
62
+ from wandb.sdk.lib.paths import FilePathStr, LogicalPath, StrPath, URIStr
63
+
64
+ reset_path = util.vendor_setup()
65
+
66
+ from wandb_gql import gql # noqa: E402
67
+
68
+ reset_path()
69
+
70
+ if TYPE_CHECKING:
71
+ from wandb.sdk.interface.message_future import MessageFuture
72
+
73
+
74
class Artifact:
    """Flexible and lightweight building block for dataset and model versioning.

    Constructs an empty artifact whose contents can be populated using its `add` family
    of functions. Once the artifact has all the desired files, you can call
    `wandb.log_artifact()` to log it.

    Arguments:
        name: A human-readable name for this artifact, which is how you can identify
            this artifact in the UI or reference it in `use_artifact` calls. Names can
            contain letters, numbers, underscores, hyphens, and dots. The name must be
            unique across a project.
        type: The type of the artifact, which is used to organize and differentiate
            artifacts. Common types include `dataset` or `model`, but you can use any
            string containing letters, numbers, underscores, hyphens, and dots.
        description: Free text that offers a description of the artifact. The
            description is markdown rendered in the UI, so this is a good place to place
            tables, links, etc.
        metadata: Structured data associated with the artifact, for example class
            distribution of a dataset. This will eventually be queryable and plottable
            in the UI. There is a hard limit of 100 total keys.

    Returns:
        An `Artifact` object.

    Examples:
        Basic usage:
        ```
        wandb.init()

        artifact = wandb.Artifact("mnist", type="dataset")
        artifact.add_dir("mnist/")
        wandb.log_artifact(artifact)
        ```
    """

    # Process-wide scratch directory shared by all Artifact instances.
    # NOTE(review): "wandb-artifacts" is passed as the positional *suffix*
    # argument of TemporaryDirectory, not the prefix — confirm intended.
    _TMP_DIR = tempfile.TemporaryDirectory("wandb-artifacts")
    # GraphQL fragment shared by every artifact query in this class; appended
    # to query strings that spread `...ArtifactFragment`.
    _GQL_FRAGMENT = """
      fragment ArtifactFragment on Artifact {
          id
          artifactSequence {
              project {
                  entityName
                  name
              }
              name
          }
          versionIndex
          artifactType {
              name
          }
          description
          metadata
          aliases {
              artifactCollectionName
              alias
          }
          state
          commitHash
          fileCount
          createdAt
          updatedAt
      }
    """
139
    def __init__(
        self,
        name: str,
        type: str,
        description: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        incremental: bool = False,
        use_as: Optional[str] = None,
    ) -> None:
        # Validate the name before doing anything else; the server enforces
        # the same character set.
        if not re.match(r"^[a-zA-Z0-9_\-.]+$", name):
            raise ValueError(
                f"Artifact name may only contain alphanumeric characters, dashes, "
                f"underscores, and dots. Invalid name: {name}"
            )
        # "job" and "wandb-*" types are reserved for artifacts the SDK creates
        # internally; user code must not claim them.
        if type == "job" or type.startswith("wandb-"):
            raise ValueError(
                "Artifact types 'job' and 'wandb-*' are reserved for internal use. "
                "Please use a different type."
            )
        if incremental:
            termwarn("Using experimental arg `incremental`")

        # Internal.
        self._client: Optional[RetryingClient] = None
        # Storage layout is selectable via environment for backward compat.
        storage_layout = (
            StorageLayout.V1 if env.get_use_v1_artifacts() else StorageLayout.V2
        )
        self._storage_policy = WandbStoragePolicy(
            config={
                "storageLayout": storage_layout,
                # TODO: storage region
            }
        )
        self._tmp_dir: Optional[tempfile.TemporaryDirectory] = None
        # Maps id(obj) -> (obj, manifest entry) so the same WBValue added
        # twice is deduplicated.
        self._added_objs: Dict[
            int, Tuple[data_types.WBValue, ArtifactManifestEntry]
        ] = {}
        self._added_local_paths: Dict[str, ArtifactManifestEntry] = {}
        self._save_future: Optional["MessageFuture"] = None
        self._dependent_artifacts: Set["Artifact"] = set()
        self._download_roots: Set[str] = set()
        # Properties.
        self._id: Optional[str] = None
        self._client_id: str = runid.generate_id(128)
        self._sequence_client_id: str = runid.generate_id(128)
        self._entity: Optional[str] = None
        self._project: Optional[str] = None
        self._name: str = name  # includes version after saving
        self._version: Optional[str] = None
        self._source_entity: Optional[str] = None
        self._source_project: Optional[str] = None
        self._source_name: str = name  # includes version after saving
        self._source_version: Optional[str] = None
        self._type: str = type
        self._description: Optional[str] = description
        self._metadata: dict = self._normalize_metadata(metadata)
        self._aliases: List[str] = []
        self._saved_aliases: List[str] = []
        self._distributed_id: Optional[str] = None
        self._incremental: bool = incremental
        self._use_as: Optional[str] = use_as
        # New artifacts start as PENDING drafts with an empty V1 manifest.
        self._state: ArtifactState = ArtifactState.PENDING
        self._manifest: Optional[ArtifactManifest] = ArtifactManifestV1(
            self._storage_policy
        )
        self._commit_hash: Optional[str] = None
        self._file_count: Optional[int] = None
        self._created_at: Optional[str] = None
        self._updated_at: Optional[str] = None
        self._final: bool = False
        # Cache.
        get_artifacts_cache().store_client_artifact(self)
211
+
212
+ def __repr__(self) -> str:
213
+ return f"<Artifact {self.id or self.name}>"
214
+
215
    @classmethod
    def _from_id(cls, artifact_id: str, client: RetryingClient) -> Optional["Artifact"]:
        """Fetch a logged artifact by its server ID.

        Returns the cached instance when one exists; otherwise queries the
        server and returns None if no artifact has that ID.
        """
        # Fast path: reuse an instance already materialized in this process.
        artifact = get_artifacts_cache().get_artifact(artifact_id)
        if artifact is not None:
            return artifact

        query = gql(
            """
            query ArtifactByID($id: ID!) {
                artifact(id: $id) {
                    ...ArtifactFragment
                    currentManifest {
                        file {
                            directUrl
                        }
                    }
                }
            }
            """
            + cls._GQL_FRAGMENT
        )
        response = client.execute(
            query,
            variable_values={"id": artifact_id},
        )
        attrs = response.get("artifact")
        if attrs is None:
            # Unknown ID: signal "not found" rather than raising.
            return None
        entity = attrs["artifactSequence"]["project"]["entityName"]
        project = attrs["artifactSequence"]["project"]["name"]
        # Reconstruct the versioned name from the sequence name + version index.
        name = "{}:v{}".format(attrs["artifactSequence"]["name"], attrs["versionIndex"])
        return cls._from_attrs(entity, project, name, attrs, client)
247
+
248
    @classmethod
    def _from_name(
        cls, entity: str, project: str, name: str, client: RetryingClient
    ) -> "Artifact":
        """Fetch a logged artifact by entity/project/name via GraphQL.

        Raises:
            ValueError: if no artifact with that name exists in the project.
        """
        query = gql(
            """
            query ArtifactByName(
                $entityName: String!,
                $projectName: String!,
                $name: String!
            ) {
                project(name: $projectName, entityName: $entityName) {
                    artifact(name: $name) {
                        ...ArtifactFragment
                    }
                }
            }
            """
            + cls._GQL_FRAGMENT
        )
        response = client.execute(
            query,
            variable_values={
                "entityName": entity,
                "projectName": project,
                "name": name,
            },
        )
        # `project` may be missing entirely (no access / no such project).
        attrs = response.get("project", {}).get("artifact")
        if attrs is None:
            raise ValueError(
                f"Unable to fetch artifact with name {entity}/{project}/{name}"
            )
        return cls._from_attrs(entity, project, name, attrs, client)
282
+
283
    @classmethod
    def _from_attrs(
        cls,
        entity: str,
        project: str,
        name: str,
        attrs: Dict[str, Any],
        client: RetryingClient,
    ) -> "Artifact":
        """Materialize an Artifact from a GraphQL `ArtifactFragment` dict.

        Populates both the portfolio (entity/project/name) identity passed in
        and the sequence identity read from `artifactSequence`, then caches
        the finished instance.
        """
        # Placeholder is required to skip validation.
        artifact = cls("placeholder", type="placeholder")
        artifact._client = client
        artifact._id = attrs["id"]
        artifact._entity = entity
        artifact._project = project
        artifact._name = name
        # Version aliases ("v0", "v1", ...) for this collection identify the
        # artifact's version within the portfolio collection.
        version_aliases = [
            alias["alias"]
            for alias in attrs.get("aliases", [])
            if alias["artifactCollectionName"] == name.split(":")[0]
            and util.alias_is_version_index(alias["alias"])
        ]
        # assert len(version_aliases) == 1
        # NOTE(review): raises IndexError if the server returns no version
        # alias for this collection — confirm that cannot happen.
        artifact._version = version_aliases[0]
        artifact._source_entity = attrs["artifactSequence"]["project"]["entityName"]
        artifact._source_project = attrs["artifactSequence"]["project"]["name"]
        artifact._source_name = "{}:v{}".format(
            attrs["artifactSequence"]["name"], attrs["versionIndex"]
        )
        artifact._source_version = "v{}".format(attrs["versionIndex"])
        artifact._type = attrs["artifactType"]["name"]
        artifact._description = attrs["description"]
        # Server metadata arrives JSON-encoded (possibly null).
        artifact.metadata = cls._normalize_metadata(
            json.loads(attrs["metadata"] or "{}")
        )
        # Non-version aliases ("latest", user aliases) for this collection.
        artifact._aliases = [
            alias["alias"]
            for alias in attrs.get("aliases", [])
            if alias["artifactCollectionName"] == name.split(":")[0]
            and not util.alias_is_version_index(alias["alias"])
        ]
        artifact._saved_aliases = copy(artifact._aliases)
        artifact._state = ArtifactState(attrs["state"])
        # Manifest is only present on some queries; defer loading otherwise
        # (the `manifest` property fetches it lazily when None).
        if "currentManifest" in attrs:
            artifact._load_manifest(attrs["currentManifest"]["file"]["directUrl"])
        else:
            artifact._manifest = None
        artifact._commit_hash = attrs["commitHash"]
        artifact._file_count = attrs["fileCount"]
        artifact._created_at = attrs["createdAt"]
        artifact._updated_at = attrs["updatedAt"]
        artifact._final = True
        # Cache.
        get_artifacts_cache().store_artifact(artifact)
        return artifact
338
+
339
+ def new_draft(self) -> "Artifact":
340
+ """Create a new draft artifact with the same content as this committed artifact.
341
+
342
+ The artifact returned can be extended or modified and logged as a new version.
343
+
344
+ Raises:
345
+ ArtifactNotLoggedError: if the artifact has not been logged
346
+ """
347
+ if self._state == ArtifactState.PENDING:
348
+ raise ArtifactNotLoggedError(self, "new_draft")
349
+
350
+ artifact = Artifact(self.source_name.split(":")[0], self.type)
351
+ artifact._description = self.description
352
+ artifact._metadata = self.metadata
353
+ artifact._manifest = ArtifactManifest.from_manifest_json(
354
+ self.manifest.to_manifest_json()
355
+ )
356
+ return artifact
357
+
358
+ # Properties.
359
+
360
    @property
    def id(self) -> Optional[str]:
        """The artifact's ID, or None while the artifact is an unsaved draft."""
        if self._state == ArtifactState.PENDING:
            # Drafts have no server-side ID yet.
            return None
        assert self._id is not None
        return self._id
367
+
368
    @property
    def entity(self) -> str:
        """The name of the entity of the secondary (portfolio) artifact collection.

        Raises:
            ArtifactNotLoggedError: if the artifact is still a pending draft.
        """
        if self._state == ArtifactState.PENDING:
            raise ArtifactNotLoggedError(self, "entity")
        assert self._entity is not None
        return self._entity
375
+
376
    @property
    def project(self) -> str:
        """The name of the project of the secondary (portfolio) artifact collection.

        Raises:
            ArtifactNotLoggedError: if the artifact is still a pending draft.
        """
        if self._state == ArtifactState.PENDING:
            raise ArtifactNotLoggedError(self, "project")
        assert self._project is not None
        return self._project
383
+
384
    @property
    def name(self) -> str:
        """The artifact name and version in its secondary (portfolio) collection.

        A string with the format {collection}:{alias}. Before the artifact is saved,
        contains only the name since the version is not yet known.
        """
        return self._name
392
+
393
+ @property
394
+ def qualified_name(self) -> str:
395
+ """The entity/project/name of the secondary (portfolio) collection."""
396
+ return f"{self.entity}/{self.project}/{self.name}"
397
+
398
    @property
    def version(self) -> str:
        """The artifact's version in its secondary (portfolio) collection.

        Raises:
            ArtifactNotLoggedError: if the artifact is still a pending draft.
        """
        if self._state == ArtifactState.PENDING:
            raise ArtifactNotLoggedError(self, "version")
        assert self._version is not None
        return self._version
405
+
406
    @property
    def source_entity(self) -> str:
        """The name of the entity of the primary (sequence) artifact collection.

        Raises:
            ArtifactNotLoggedError: if the artifact is still a pending draft.
        """
        if self._state == ArtifactState.PENDING:
            raise ArtifactNotLoggedError(self, "source_entity")
        assert self._source_entity is not None
        return self._source_entity
413
+
414
    @property
    def source_project(self) -> str:
        """The name of the project of the primary (sequence) artifact collection.

        Raises:
            ArtifactNotLoggedError: if the artifact is still a pending draft.
        """
        if self._state == ArtifactState.PENDING:
            raise ArtifactNotLoggedError(self, "source_project")
        assert self._source_project is not None
        return self._source_project
421
+
422
    @property
    def source_name(self) -> str:
        """The artifact name and version in its primary (sequence) collection.

        A string with the format {collection}:{alias}. Before the artifact is saved,
        contains only the name since the version is not yet known.
        """
        return self._source_name
430
+
431
+ @property
432
+ def source_qualified_name(self) -> str:
433
+ """The entity/project/name of the primary (sequence) collection."""
434
+ return f"{self.source_entity}/{self.source_project}/{self.source_name}"
435
+
436
    @property
    def source_version(self) -> str:
        """The artifact's version in its primary (sequence) collection.

        A string with the format "v{number}".

        Raises:
            ArtifactNotLoggedError: if the artifact is still a pending draft.
        """
        if self._state == ArtifactState.PENDING:
            raise ArtifactNotLoggedError(self, "source_version")
        assert self._source_version is not None
        return self._source_version
446
+
447
    @property
    def type(self) -> str:
        """The artifact's type, e.g. "dataset" or "model"."""
        return self._type
451
+
452
    @property
    def description(self) -> Optional[str]:
        """The artifact description.

        Free text that offers a user-set description of the artifact.
        """
        return self._description
459
+
460
    @description.setter
    def description(self, description: Optional[str]) -> None:
        """Set the description of the artifact.

        The description is markdown rendered in the UI, so this is a good place to put
        links, etc.

        Arguments:
            description: Free text that offers a description of the artifact.
        """
        self._description = description
471
+
472
    @property
    def metadata(self) -> dict:
        """User-defined artifact metadata.

        Structured data associated with the artifact.
        """
        return self._metadata
479
+
480
    @metadata.setter
    def metadata(self, metadata: dict) -> None:
        """User-defined artifact metadata.

        Metadata set this way will eventually be queryable and plottable in the UI; e.g.
        the class distribution of a dataset.

        Note: There is currently a limit of 100 total keys.

        Arguments:
            metadata: Structured data associated with the artifact.
        """
        # Normalization validates and coerces the mapping before storing it.
        self._metadata = self._normalize_metadata(metadata)
493
+
494
    @property
    def aliases(self) -> List[str]:
        """The aliases associated with this artifact.

        The list is mutable and calling `save()` will persist all alias changes.

        Raises:
            ArtifactNotLoggedError: if the artifact is still a pending draft.
        """
        if self._state == ArtifactState.PENDING:
            raise ArtifactNotLoggedError(self, "aliases")
        return self._aliases
503
+
504
+ @aliases.setter
505
+ def aliases(self, aliases: List[str]) -> None:
506
+ """Set the aliases associated with this artifact."""
507
+ if self._state == ArtifactState.PENDING:
508
+ raise ArtifactNotLoggedError(self, "aliases")
509
+
510
+ if any(char in alias for alias in aliases for char in ["/", ":"]):
511
+ raise ValueError(
512
+ "Aliases must not contain any of the following characters: /, :"
513
+ )
514
+ self._aliases = aliases
515
+
516
    @property
    def distributed_id(self) -> Optional[str]:
        """The artifact's distributed ID, if one has been set."""
        return self._distributed_id
519
+
520
    @distributed_id.setter
    def distributed_id(self, distributed_id: Optional[str]) -> None:
        """Set (or clear, with None) the artifact's distributed ID."""
        self._distributed_id = distributed_id
523
+
524
    @property
    def incremental(self) -> bool:
        """Whether this artifact was constructed with `incremental=True`."""
        return self._incremental
527
+
528
    @property
    def use_as(self) -> Optional[str]:
        """The `use_as` value supplied at construction, if any."""
        return self._use_as
531
+
532
    @property
    def state(self) -> str:
        """The status of the artifact. One of: "PENDING", "COMMITTED", or "DELETED"."""
        # Expose the enum's string value, not the enum member itself.
        return self._state.value
536
+
537
    @property
    def manifest(self) -> ArtifactManifest:
        """The artifact's manifest.

        The manifest lists all of its contents, and can't be changed once the artifact
        has been logged.
        """
        # Lazily fetch the manifest: it may have been left unloaded by
        # `_from_attrs` when the query did not include `currentManifest`.
        if self._manifest is None:
            query = gql(
                """
                query ArtifactManifest(
                    $entityName: String!,
                    $projectName: String!,
                    $name: String!
                ) {
                    project(entityName: $entityName, name: $projectName) {
                        artifact(name: $name) {
                            currentManifest {
                                file {
                                    directUrl
                                }
                            }
                        }
                    }
                }
                """
            )
            assert self._client is not None
            response = self._client.execute(
                query,
                variable_values={
                    "entityName": self._entity,
                    "projectName": self._project,
                    "name": self._name,
                },
            )
            attrs = response["project"]["artifact"]
            # Downloads and parses the manifest file, setting self._manifest.
            self._load_manifest(attrs["currentManifest"]["file"]["directUrl"])
            assert self._manifest is not None
        return self._manifest
577
+
578
    @property
    def digest(self) -> str:
        """The logical digest of the artifact.

        The digest is the checksum of the artifact's contents. If an artifact has the
        same digest as the current `latest` version, then `log_artifact` is a no-op.
        """
        return self.manifest.digest()
586
+
587
+ @property
588
+ def size(self) -> int:
589
+ """The total size of the artifact in bytes.
590
+
591
+ Includes any references tracked by this artifact.
592
+ """
593
+ total_size: int = 0
594
+ for entry in self.manifest.entries.values():
595
+ if entry.size is not None:
596
+ total_size += entry.size
597
+ return total_size
598
+
599
    @property
    def commit_hash(self) -> str:
        """The hash returned when this artifact was committed.

        Raises:
            ArtifactNotLoggedError: if the artifact is still a pending draft.
        """
        if self._state == ArtifactState.PENDING:
            raise ArtifactNotLoggedError(self, "commit_hash")
        assert self._commit_hash is not None
        return self._commit_hash
606
+
607
    @property
    def file_count(self) -> int:
        """The number of files (including references).

        Raises:
            ArtifactNotLoggedError: if the artifact is still a pending draft.
        """
        if self._state == ArtifactState.PENDING:
            raise ArtifactNotLoggedError(self, "file_count")
        assert self._file_count is not None
        return self._file_count
614
+
615
    @property
    def created_at(self) -> str:
        """The time at which the artifact was created.

        Raises:
            ArtifactNotLoggedError: if the artifact is still a pending draft.
        """
        if self._state == ArtifactState.PENDING:
            raise ArtifactNotLoggedError(self, "created_at")
        assert self._created_at is not None
        return self._created_at
622
+
623
+ @property
624
+ def updated_at(self) -> str:
625
+ """The time at which the artifact was last updated."""
626
+ if self._state == ArtifactState.PENDING:
627
+ raise ArtifactNotLoggedError(self, "created_at")
628
+ assert self._created_at is not None
629
+ return self._updated_at or self._created_at
630
+
631
+ # State management.
632
+
633
    def finalize(self) -> None:
        """Mark this artifact as final, disallowing further modifications.

        This happens automatically when calling `log_artifact`.
        """
        # Once set, `_ensure_can_add` rejects any further `add*` calls.
        self._final = True
639
+
640
    def _ensure_can_add(self) -> None:
        """Raise ArtifactFinalizedError if the artifact can no longer be modified."""
        if self._final:
            raise ArtifactFinalizedError(artifact=self)
643
+
644
    def is_draft(self) -> bool:
        """Whether the artifact is a draft, i.e. it hasn't been saved yet."""
        return self._state == ArtifactState.PENDING
647
+
648
    def _is_draft_save_started(self) -> bool:
        """Whether an asynchronous save of this draft has already been kicked off."""
        return self._save_future is not None
650
+
651
    def save(
        self,
        project: Optional[str] = None,
        settings: Optional["wandb.wandb_sdk.wandb_settings.Settings"] = None,
    ) -> None:
        """Persist any changes made to the artifact.

        If currently in a run, that run will log this artifact. If not currently in a
        run, a run of type "auto" will be created to track this artifact.

        Arguments:
            project: A project to use for the artifact in the case that a run is not
                already in context
            settings: A settings object to use when initializing an automatic run. Most
                commonly used in testing harness.
        """
        # Already-logged artifacts only need their mutable fields pushed.
        if self._state != ArtifactState.PENDING:
            return self._update()

        if self._incremental:
            with telemetry.context() as tel:
                tel.feature.artifact_incremental = True

        if wandb.run is None:
            if settings is None:
                # Keep the auto-created run quiet by default.
                settings = wandb.Settings(silent="true")
            with wandb.init(project=project, job_type="auto", settings=settings) as run:
                # redoing this here because in this branch we know we didn't
                # have the run at the beginning of the method
                if self._incremental:
                    with telemetry.context(run=run) as tel:
                        tel.feature.artifact_incremental = True
                run.log_artifact(self)
        else:
            wandb.run.log_artifact(self)
686
+
687
    def _set_save_future(
        self, save_future: "MessageFuture", client: RetryingClient
    ) -> None:
        """Record the pending-save future and client used by `wait()` to finish logging."""
        self._save_future = save_future
        self._client = client
692
+
693
    def wait(self, timeout: Optional[int] = None) -> "Artifact":
        """Wait for this artifact to finish logging, if needed.

        Arguments:
            timeout: Wait up to this long.

        Raises:
            ArtifactNotLoggedError: if no save was ever started for this draft.
            WaitTimeoutError: if the save does not complete within `timeout`.
            ValueError: if the server reports an error logging the artifact.
        """
        # Non-pending artifacts are already fully logged; nothing to wait for.
        if self._state == ArtifactState.PENDING:
            if self._save_future is None:
                raise ArtifactNotLoggedError(self, "wait")
            result = self._save_future.get(timeout)
            if not result:
                raise WaitTimeoutError(
                    "Artifact upload wait timed out, failed to fetch Artifact response"
                )
            response = result.response.log_artifact_response
            if response.error_message:
                raise ValueError(response.error_message)
            # Fill in server-assigned identity (id, version, aliases, ...).
            self._populate_after_save(response.artifact_id)
        return self
712
+
713
+ def _populate_after_save(self, artifact_id: str) -> None:
714
+ query = gql(
715
+ """
716
+ query ArtifactByIDShort($id: ID!) {
717
+ artifact(id: $id) {
718
+ artifactSequence {
719
+ project {
720
+ entityName
721
+ name
722
+ }
723
+ name
724
+ }
725
+ versionIndex
726
+ aliases {
727
+ artifactCollectionName
728
+ alias
729
+ }
730
+ state
731
+ currentManifest {
732
+ file {
733
+ directUrl
734
+ }
735
+ }
736
+ commitHash
737
+ fileCount
738
+ createdAt
739
+ updatedAt
740
+ }
741
+ }
742
+ """
743
+ )
744
+ assert self._client is not None
745
+ response = self._client.execute(
746
+ query,
747
+ variable_values={"id": artifact_id},
748
+ )
749
+ attrs = response.get("artifact")
750
+ if attrs is None:
751
+ raise ValueError(f"Unable to fetch artifact with id {artifact_id}")
752
+ self._id = artifact_id
753
+ self._entity = attrs["artifactSequence"]["project"]["entityName"]
754
+ self._project = attrs["artifactSequence"]["project"]["name"]
755
+ self._name = "{}:v{}".format(
756
+ attrs["artifactSequence"]["name"], attrs["versionIndex"]
757
+ )
758
+ self._version = "v{}".format(attrs["versionIndex"])
759
+ self._source_entity = self._entity
760
+ self._source_project = self._project
761
+ self._source_name = self._name
762
+ self._source_version = self._version
763
+ self._aliases = [
764
+ alias["alias"]
765
+ for alias in attrs.get("aliases", [])
766
+ if alias["artifactCollectionName"] == self._name.split(":")[0]
767
+ and not util.alias_is_version_index(alias["alias"])
768
+ ]
769
+ self._state = ArtifactState(attrs["state"])
770
+ with requests.get(attrs["currentManifest"]["file"]["directUrl"]) as request:
771
+ request.raise_for_status()
772
+ self._manifest = ArtifactManifest.from_manifest_json(
773
+ json.loads(util.ensure_text(request.content))
774
+ )
775
+ self._commit_hash = attrs["commitHash"]
776
+ self._file_count = attrs["fileCount"]
777
+ self._created_at = attrs["createdAt"]
778
+ self._updated_at = attrs["updatedAt"]
779
+
780
+ @normalize_exceptions
781
+ def _update(self) -> None:
782
+ """Persists artifact changes to the wandb backend."""
783
+ aliases = None
784
+ introspect_query = gql(
785
+ """
786
+ query ProbeServerAddAliasesInput {
787
+ AddAliasesInputInfoType: __type(name: "AddAliasesInput") {
788
+ name
789
+ inputFields {
790
+ name
791
+ }
792
+ }
793
+ }
794
+ """
795
+ )
796
+ assert self._client is not None
797
+ response = self._client.execute(introspect_query)
798
+ if response.get("AddAliasesInputInfoType"): # wandb backend version >= 0.13.0
799
+ aliases_to_add = set(self._aliases) - set(self._saved_aliases)
800
+ aliases_to_delete = set(self._saved_aliases) - set(self._aliases)
801
+ if len(aliases_to_add) > 0:
802
+ add_mutation = gql(
803
+ """
804
+ mutation addAliases(
805
+ $artifactID: ID!,
806
+ $aliases: [ArtifactCollectionAliasInput!]!,
807
+ ) {
808
+ addAliases(
809
+ input: {artifactID: $artifactID, aliases: $aliases}
810
+ ) {
811
+ success
812
+ }
813
+ }
814
+ """
815
+ )
816
+ assert self._client is not None
817
+ self._client.execute(
818
+ add_mutation,
819
+ variable_values={
820
+ "artifactID": self.id,
821
+ "aliases": [
822
+ {
823
+ "entityName": self._entity,
824
+ "projectName": self._project,
825
+ "artifactCollectionName": self._name.split(":")[0],
826
+ "alias": alias,
827
+ }
828
+ for alias in aliases_to_add
829
+ ],
830
+ },
831
+ )
832
+ if len(aliases_to_delete) > 0:
833
+ delete_mutation = gql(
834
+ """
835
+ mutation deleteAliases(
836
+ $artifactID: ID!,
837
+ $aliases: [ArtifactCollectionAliasInput!]!,
838
+ ) {
839
+ deleteAliases(
840
+ input: {artifactID: $artifactID, aliases: $aliases}
841
+ ) {
842
+ success
843
+ }
844
+ }
845
+ """
846
+ )
847
+ assert self._client is not None
848
+ self._client.execute(
849
+ delete_mutation,
850
+ variable_values={
851
+ "artifactID": self.id,
852
+ "aliases": [
853
+ {
854
+ "entityName": self._entity,
855
+ "projectName": self._project,
856
+ "artifactCollectionName": self._name.split(":")[0],
857
+ "alias": alias,
858
+ }
859
+ for alias in aliases_to_delete
860
+ ],
861
+ },
862
+ )
863
+ self._saved_aliases = copy(self._aliases)
864
+ else: # wandb backend version < 0.13.0
865
+ aliases = [
866
+ {
867
+ "artifactCollectionName": self._name.split(":")[0],
868
+ "alias": alias,
869
+ }
870
+ for alias in self._aliases
871
+ ]
872
+
873
+ mutation = gql(
874
+ """
875
+ mutation updateArtifact(
876
+ $artifactID: ID!,
877
+ $description: String,
878
+ $metadata: JSONString,
879
+ $aliases: [ArtifactAliasInput!]
880
+ ) {
881
+ updateArtifact(
882
+ input: {
883
+ artifactID: $artifactID,
884
+ description: $description,
885
+ metadata: $metadata,
886
+ aliases: $aliases
887
+ }
888
+ ) {
889
+ artifact {
890
+ id
891
+ }
892
+ }
893
+ }
894
+ """
895
+ )
896
+ assert self._client is not None
897
+ self._client.execute(
898
+ mutation,
899
+ variable_values={
900
+ "artifactID": self.id,
901
+ "description": self.description,
902
+ "metadata": util.json_dumps_safer(self.metadata),
903
+ "aliases": aliases,
904
+ },
905
+ )
906
+
907
+ # Adding, removing, getting entries.
908
+
909
+ def __getitem__(self, name: str) -> Optional[data_types.WBValue]:
910
+ """Get the WBValue object located at the artifact relative `name`.
911
+
912
+ Arguments:
913
+ name: The artifact relative name to get
914
+
915
+ Raises:
916
+ ArtifactNotLoggedError: if the artifact isn't logged or the run is offline
917
+
918
+ Examples:
919
+ Basic usage:
920
+ ```
921
+ artifact = wandb.Artifact("my_table", type="dataset")
922
+ table = wandb.Table(
923
+ columns=["a", "b", "c"],
924
+ data=[(i, i * 2, 2**i) for i in range(10)]
925
+ )
926
+ artifact["my_table"] = table
927
+
928
+ wandb.log_artifact(artifact)
929
+ ```
930
+
931
+ Retrieving an object:
932
+ ```
933
+ artifact = wandb.use_artifact("my_table:latest")
934
+ table = artifact["my_table"]
935
+ ```
936
+ """
937
+ return self.get(name)
938
+
939
+ def __setitem__(self, name: str, item: data_types.WBValue) -> ArtifactManifestEntry:
940
+ """Add `item` to the artifact at path `name`.
941
+
942
+ Arguments:
943
+ name: The path within the artifact to add the object.
944
+ item: The object to add.
945
+
946
+ Returns:
947
+ The added manifest entry
948
+
949
+ Raises:
950
+ ArtifactFinalizedError: if the artifact has already been finalized.
951
+
952
+ Examples:
953
+ Basic usage:
954
+ ```
955
+ artifact = wandb.Artifact("my_table", type="dataset")
956
+ table = wandb.Table(
957
+ columns=["a", "b", "c"],
958
+ data=[(i, i * 2, 2**i) for i in range(10)]
959
+ )
960
+ artifact["my_table"] = table
961
+
962
+ wandb.log_artifact(artifact)
963
+ ```
964
+
965
+ Retrieving an object:
966
+ ```
967
+ artifact = wandb.use_artifact("my_table:latest")
968
+ table = artifact["my_table"]
969
+ ```
970
+ """
971
+ return self.add(item, name)
972
+
973
+ @contextlib.contextmanager
974
+ def new_file(
975
+ self, name: str, mode: str = "w", encoding: Optional[str] = None
976
+ ) -> Generator[IO, None, None]:
977
+ """Open a new temporary file that will be automatically added to the artifact.
978
+
979
+ Arguments:
980
+ name: The name of the new file being added to the artifact.
981
+ mode: The mode in which to open the new file.
982
+ encoding: The encoding in which to open the new file.
983
+
984
+ Returns:
985
+ A new file object that can be written to. Upon closing, the file will be
986
+ automatically added to the artifact.
987
+
988
+ Raises:
989
+ ArtifactFinalizedError: if the artifact has already been finalized.
990
+
991
+ Examples:
992
+ ```
993
+ artifact = wandb.Artifact("my_data", type="dataset")
994
+ with artifact.new_file("hello.txt") as f:
995
+ f.write("hello!")
996
+ wandb.log_artifact(artifact)
997
+ ```
998
+ """
999
+ self._ensure_can_add()
1000
+ if self._tmp_dir is None:
1001
+ self._tmp_dir = tempfile.TemporaryDirectory()
1002
+ path = os.path.join(self._tmp_dir.name, name.lstrip("/"))
1003
+ if os.path.exists(path):
1004
+ raise ValueError(f"File with name {name!r} already exists at {path!r}")
1005
+
1006
+ filesystem.mkdir_exists_ok(os.path.dirname(path))
1007
+ try:
1008
+ with util.fsync_open(path, mode, encoding) as f:
1009
+ yield f
1010
+ except UnicodeEncodeError as e:
1011
+ termerror(
1012
+ f"Failed to open the provided file (UnicodeEncodeError: {e}). Please "
1013
+ f"provide the proper encoding."
1014
+ )
1015
+ raise e
1016
+
1017
+ self.add_file(path, name=name)
1018
+
1019
+ def add_file(
1020
+ self,
1021
+ local_path: str,
1022
+ name: Optional[str] = None,
1023
+ is_tmp: Optional[bool] = False,
1024
+ ) -> ArtifactManifestEntry:
1025
+ """Add a local file to the artifact.
1026
+
1027
+ Arguments:
1028
+ local_path: The path to the file being added.
1029
+ name: The path within the artifact to use for the file being added. Defaults
1030
+ to the basename of the file.
1031
+ is_tmp: If true, then the file is renamed deterministically to avoid
1032
+ collisions.
1033
+
1034
+ Returns:
1035
+ The added manifest entry
1036
+
1037
+ Raises:
1038
+ ArtifactFinalizedError: if the artifact has already been finalized
1039
+
1040
+ Examples:
1041
+ Add a file without an explicit name:
1042
+ ```
1043
+ # Add as `file.txt'
1044
+ artifact.add_file("path/to/file.txt")
1045
+ ```
1046
+
1047
+ Add a file with an explicit name:
1048
+ ```
1049
+ # Add as 'new/path/file.txt'
1050
+ artifact.add_file("path/to/file.txt", name="new/path/file.txt")
1051
+ ```
1052
+ """
1053
+ self._ensure_can_add()
1054
+ if not os.path.isfile(local_path):
1055
+ raise ValueError("Path is not a file: %s" % local_path)
1056
+
1057
+ name = LogicalPath(name or os.path.basename(local_path))
1058
+ digest = md5_file_b64(local_path)
1059
+
1060
+ if is_tmp:
1061
+ file_path, file_name = os.path.split(name)
1062
+ file_name_parts = file_name.split(".")
1063
+ file_name_parts[0] = b64_to_hex_id(digest)[:20]
1064
+ name = os.path.join(file_path, ".".join(file_name_parts))
1065
+
1066
+ return self._add_local_file(name, local_path, digest=digest)
1067
+
1068
+ def add_dir(self, local_path: str, name: Optional[str] = None) -> None:
1069
+ """Add a local directory to the artifact.
1070
+
1071
+ Arguments:
1072
+ local_path: The path to the directory being added.
1073
+ name: The path within the artifact to use for the directory being added.
1074
+ Defaults to the root of the artifact.
1075
+
1076
+ Raises:
1077
+ ArtifactFinalizedError: if the artifact has already been finalized
1078
+
1079
+ Examples:
1080
+ Add a directory without an explicit name:
1081
+ ```
1082
+ # All files in `my_dir/` are added at the root of the artifact.
1083
+ artifact.add_dir("my_dir/")
1084
+ ```
1085
+
1086
+ Add a directory and name it explicitly:
1087
+ ```
1088
+ # All files in `my_dir/` are added under `destination/`.
1089
+ artifact.add_dir("my_dir/", name="destination")
1090
+ ```
1091
+ """
1092
+ self._ensure_can_add()
1093
+ if not os.path.isdir(local_path):
1094
+ raise ValueError("Path is not a directory: %s" % local_path)
1095
+
1096
+ termlog(
1097
+ "Adding directory to artifact (%s)... "
1098
+ % os.path.join(".", os.path.normpath(local_path)),
1099
+ newline=False,
1100
+ )
1101
+ start_time = time.time()
1102
+
1103
+ paths = []
1104
+ for dirpath, _, filenames in os.walk(local_path, followlinks=True):
1105
+ for fname in filenames:
1106
+ physical_path = os.path.join(dirpath, fname)
1107
+ logical_path = os.path.relpath(physical_path, start=local_path)
1108
+ if name is not None:
1109
+ logical_path = os.path.join(name, logical_path)
1110
+ paths.append((logical_path, physical_path))
1111
+
1112
+ def add_manifest_file(log_phy_path: Tuple[str, str]) -> None:
1113
+ logical_path, physical_path = log_phy_path
1114
+ self._add_local_file(logical_path, physical_path)
1115
+
1116
+ num_threads = 8
1117
+ pool = multiprocessing.dummy.Pool(num_threads)
1118
+ pool.map(add_manifest_file, paths)
1119
+ pool.close()
1120
+ pool.join()
1121
+
1122
+ termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)
1123
+
1124
+ def add_reference(
1125
+ self,
1126
+ uri: Union[ArtifactManifestEntry, str],
1127
+ name: Optional[StrPath] = None,
1128
+ checksum: bool = True,
1129
+ max_objects: Optional[int] = None,
1130
+ ) -> Sequence[ArtifactManifestEntry]:
1131
+ """Add a reference denoted by a URI to the artifact.
1132
+
1133
+ Unlike adding files or directories, references are NOT uploaded to W&B. However,
1134
+ artifact methods such as `download()` can be used regardless of whether the
1135
+ artifact contains references or uploaded files.
1136
+
1137
+ By default, W&B offers special handling for the following schemes:
1138
+
1139
+ - http(s): The size and digest of the file will be inferred by the
1140
+ `Content-Length` and the `ETag` response headers returned by the server.
1141
+ - s3: The checksum and size will be pulled from the object metadata. If bucket
1142
+ versioning is enabled, then the version ID is also tracked.
1143
+ - gs: The checksum and size will be pulled from the object metadata. If bucket
1144
+ versioning is enabled, then the version ID is also tracked.
1145
+ - https, domain matching *.blob.core.windows.net (Azure): The checksum and size
1146
+ will be pulled from the blob metadata. If storage account versioning is
1147
+ enabled, then the version ID is also tracked.
1148
+ - file: The checksum and size will be pulled from the file system. This scheme
1149
+ is useful if you have an NFS share or other externally mounted volume
1150
+ containing files you wish to track but not necessarily upload.
1151
+
1152
+ For any other scheme, the digest is just a hash of the URI and the size is left
1153
+ blank.
1154
+
1155
+ Arguments:
1156
+ uri: The URI path of the reference to add. Can be an object returned from
1157
+ Artifact.get_path to store a reference to another artifact's entry.
1158
+ name: The path within the artifact to place the contents of this reference
1159
+ checksum: Whether or not to checksum the resource(s) located at the
1160
+ reference URI. Checksumming is strongly recommended as it enables
1161
+ automatic integrity validation, however it can be disabled to speed up
1162
+ artifact creation. (default: True)
1163
+ max_objects: The maximum number of objects to consider when adding a
1164
+ reference that points to directory or bucket store prefix. For S3 and
1165
+ GCS, this limit is 10,000 by default but is uncapped for other URI
1166
+ schemes. (default: None)
1167
+
1168
+ Returns:
1169
+ The added manifest entries.
1170
+
1171
+ Raises:
1172
+ ArtifactFinalizedError: if the artifact has already been finalized.
1173
+
1174
+ Examples:
1175
+ Add an HTTP link:
1176
+ ```python
1177
+ # Adds `file.txt` to the root of the artifact as a reference.
1178
+ artifact.add_reference("http://myserver.com/file.txt")
1179
+ ```
1180
+
1181
+ Add an S3 prefix without an explicit name:
1182
+ ```python
1183
+ # All objects under `prefix/` will be added at the root of the artifact.
1184
+ artifact.add_reference("s3://mybucket/prefix")
1185
+ ```
1186
+
1187
+ Add a GCS prefix with an explicit name:
1188
+ ```python
1189
+ # All objects under `prefix/` will be added under `path/` at the artifact
1190
+ # root.
1191
+ artifact.add_reference("gs://mybucket/prefix", name="path")
1192
+ ```
1193
+ """
1194
+ self._ensure_can_add()
1195
+ if name is not None:
1196
+ name = LogicalPath(name)
1197
+
1198
+ # This is a bit of a hack, we want to check if the uri is a of the type
1199
+ # ArtifactManifestEntry. If so, then recover the reference URL.
1200
+ if isinstance(uri, ArtifactManifestEntry):
1201
+ uri_str = uri.ref_url()
1202
+ elif isinstance(uri, str):
1203
+ uri_str = uri
1204
+ url = urlparse(str(uri_str))
1205
+ if not url.scheme:
1206
+ raise ValueError(
1207
+ "References must be URIs. To reference a local file, use file://"
1208
+ )
1209
+
1210
+ manifest_entries = self._storage_policy.store_reference(
1211
+ self,
1212
+ URIStr(uri_str),
1213
+ name=name,
1214
+ checksum=checksum,
1215
+ max_objects=max_objects,
1216
+ )
1217
+ for entry in manifest_entries:
1218
+ self.manifest.add_entry(entry)
1219
+
1220
+ return manifest_entries
1221
+
1222
+ def add(self, obj: data_types.WBValue, name: StrPath) -> ArtifactManifestEntry:
1223
+ """Add wandb.WBValue `obj` to the artifact.
1224
+
1225
+ Arguments:
1226
+ obj: The object to add. Currently support one of Bokeh, JoinedTable,
1227
+ PartitionedTable, Table, Classes, ImageMask, BoundingBoxes2D, Audio,
1228
+ Image, Video, Html, Object3D
1229
+ name: The path within the artifact to add the object.
1230
+
1231
+ Returns:
1232
+ The added manifest entry
1233
+
1234
+ Raises:
1235
+ ArtifactFinalizedError: if the artifact has already been finalized
1236
+
1237
+ Examples:
1238
+ Basic usage:
1239
+ ```
1240
+ artifact = wandb.Artifact("my_table", type="dataset")
1241
+ table = wandb.Table(
1242
+ columns=["a", "b", "c"],
1243
+ data=[(i, i * 2, 2**i) for i in range(10)]
1244
+ )
1245
+ artifact.add(table, "my_table")
1246
+
1247
+ wandb.log_artifact(artifact)
1248
+ ```
1249
+
1250
+ Retrieve an object:
1251
+ ```
1252
+ artifact = wandb.use_artifact("my_table:latest")
1253
+ table = artifact.get("my_table")
1254
+ ```
1255
+ """
1256
+ self._ensure_can_add()
1257
+ name = LogicalPath(name)
1258
+
1259
+ # This is a "hack" to automatically rename tables added to
1260
+ # the wandb /media/tables directory to their sha-based name.
1261
+ # TODO: figure out a more appropriate convention.
1262
+ is_tmp_name = name.startswith("media/tables")
1263
+
1264
+ # Validate that the object is one of the correct wandb.Media types
1265
+ # TODO: move this to checking subclass of wandb.Media once all are
1266
+ # generally supported
1267
+ allowed_types = [
1268
+ data_types.Bokeh,
1269
+ data_types.JoinedTable,
1270
+ data_types.PartitionedTable,
1271
+ data_types.Table,
1272
+ data_types.Classes,
1273
+ data_types.ImageMask,
1274
+ data_types.BoundingBoxes2D,
1275
+ data_types.Audio,
1276
+ data_types.Image,
1277
+ data_types.Video,
1278
+ data_types.Html,
1279
+ data_types.Object3D,
1280
+ data_types.Molecule,
1281
+ data_types._SavedModel,
1282
+ ]
1283
+
1284
+ if not any(isinstance(obj, t) for t in allowed_types):
1285
+ raise ValueError(
1286
+ "Found object of type {}, expected one of {}.".format(
1287
+ obj.__class__, allowed_types
1288
+ )
1289
+ )
1290
+
1291
+ obj_id = id(obj)
1292
+ if obj_id in self._added_objs:
1293
+ return self._added_objs[obj_id][1]
1294
+
1295
+ # If the object is coming from another artifact, save it as a reference
1296
+ ref_path = obj._get_artifact_entry_ref_url()
1297
+ if ref_path is not None:
1298
+ return self.add_reference(ref_path, type(obj).with_suffix(name))[0]
1299
+
1300
+ val = obj.to_json(self)
1301
+ name = obj.with_suffix(name)
1302
+ entry = self.manifest.get_entry_by_path(name)
1303
+ if entry is not None:
1304
+ return entry
1305
+
1306
+ def do_write(f: IO) -> None:
1307
+ import json
1308
+
1309
+ # TODO: Do we need to open with utf-8 codec?
1310
+ f.write(json.dumps(val, sort_keys=True))
1311
+
1312
+ if is_tmp_name:
1313
+ file_path = os.path.join(self._TMP_DIR.name, str(id(self)), name)
1314
+ folder_path, _ = os.path.split(file_path)
1315
+ if not os.path.exists(folder_path):
1316
+ os.makedirs(folder_path)
1317
+ with open(file_path, "w") as tmp_f:
1318
+ do_write(tmp_f)
1319
+ else:
1320
+ with self.new_file(name) as f:
1321
+ file_path = f.name
1322
+ do_write(f)
1323
+
1324
+ # Note, we add the file from our temp directory.
1325
+ # It will be added again later on finalize, but succeed since
1326
+ # the checksum should match
1327
+ entry = self.add_file(file_path, name, is_tmp_name)
1328
+ # We store a reference to the obj so that its id doesn't get reused.
1329
+ self._added_objs[obj_id] = (obj, entry)
1330
+ if obj._artifact_target is None:
1331
+ obj._set_artifact_target(self, entry.path)
1332
+
1333
+ if is_tmp_name:
1334
+ if os.path.exists(file_path):
1335
+ os.remove(file_path)
1336
+
1337
+ return entry
1338
+
1339
+ def _add_local_file(
1340
+ self, name: StrPath, path: StrPath, digest: Optional[B64MD5] = None
1341
+ ) -> ArtifactManifestEntry:
1342
+ with tempfile.NamedTemporaryFile(dir=get_staging_dir(), delete=False) as f:
1343
+ staging_path = f.name
1344
+ shutil.copyfile(path, staging_path)
1345
+ os.chmod(staging_path, 0o400)
1346
+
1347
+ entry = ArtifactManifestEntry(
1348
+ path=name,
1349
+ digest=digest or md5_file_b64(staging_path),
1350
+ size=os.path.getsize(staging_path),
1351
+ local_path=staging_path,
1352
+ )
1353
+
1354
+ self.manifest.add_entry(entry)
1355
+ self._added_local_paths[os.fspath(path)] = entry
1356
+ return entry
1357
+
1358
+ def remove(self, item: Union[StrPath, "ArtifactManifestEntry"]) -> None:
1359
+ """Remove an item from the artifact.
1360
+
1361
+ Arguments:
1362
+ item: the item to remove. Can be a specific manifest entry or the name of an
1363
+ artifact-relative path. If the item matches a directory all items in
1364
+ that directory will be removed.
1365
+
1366
+ Raises:
1367
+ ArtifactFinalizedError: if the artifact has already been finalized.
1368
+ FileNotFoundError: if the item isn't found in the artifact.
1369
+ """
1370
+ self._ensure_can_add()
1371
+
1372
+ if isinstance(item, ArtifactManifestEntry):
1373
+ self.manifest.remove_entry(item)
1374
+ return
1375
+
1376
+ path = str(PurePosixPath(item))
1377
+ entry = self.manifest.get_entry_by_path(path)
1378
+ if entry:
1379
+ self.manifest.remove_entry(entry)
1380
+ return
1381
+
1382
+ entries = self.manifest.get_entries_in_directory(path)
1383
+ if not entries:
1384
+ raise FileNotFoundError(f"No such file or directory: {path}")
1385
+ for entry in entries:
1386
+ self.manifest.remove_entry(entry)
1387
+
1388
+ def get_path(self, name: StrPath) -> ArtifactManifestEntry:
1389
+ """Get the entry with the given name.
1390
+
1391
+ Arguments:
1392
+ name: The artifact relative name to get
1393
+
1394
+ Raises:
1395
+ ArtifactNotLoggedError: if the artifact isn't logged or the run is offline
1396
+ KeyError: if the artifact doesn't contain an entry with the given name
1397
+
1398
+ Examples:
1399
+ Basic usage:
1400
+ ```
1401
+ # Run logging the artifact
1402
+ with wandb.init() as r:
1403
+ artifact = wandb.Artifact("my_dataset", type="dataset")
1404
+ artifact.add_file("path/to/file.txt")
1405
+ wandb.log_artifact(artifact)
1406
+
1407
+ # Run using the artifact
1408
+ with wandb.init() as r:
1409
+ artifact = r.use_artifact("my_dataset:latest")
1410
+ path = artifact.get_path("file.txt")
1411
+
1412
+ # Can now download 'file.txt' directly:
1413
+ path.download()
1414
+ ```
1415
+ """
1416
+ if self._state == ArtifactState.PENDING:
1417
+ raise ArtifactNotLoggedError(self, "get_path")
1418
+
1419
+ name = LogicalPath(name)
1420
+ entry = self.manifest.entries.get(name) or self._get_obj_entry(name)[0]
1421
+ if entry is None:
1422
+ raise KeyError("Path not contained in artifact: %s" % name)
1423
+ entry._parent_artifact = self
1424
+ return entry
1425
+
1426
+ def get(self, name: str) -> Optional[data_types.WBValue]:
1427
+ """Get the WBValue object located at the artifact relative `name`.
1428
+
1429
+ Arguments:
1430
+ name: The artifact relative name to get
1431
+
1432
+ Raises:
1433
+ ArtifactNotLoggedError: if the artifact isn't logged or the run is offline
1434
+
1435
+ Examples:
1436
+ Basic usage:
1437
+ ```
1438
+ # Run logging the artifact
1439
+ with wandb.init() as r:
1440
+ artifact = wandb.Artifact("my_dataset", type="dataset")
1441
+ table = wandb.Table(
1442
+ columns=["a", "b", "c"],
1443
+ data=[(i, i * 2, 2**i) for i in range(10)]
1444
+ )
1445
+ artifact.add(table, "my_table")
1446
+ wandb.log_artifact(artifact)
1447
+
1448
+ # Run using the artifact
1449
+ with wandb.init() as r:
1450
+ artifact = r.use_artifact("my_dataset:latest")
1451
+ table = artifact.get("my_table")
1452
+ ```
1453
+ """
1454
+ if self._state == ArtifactState.PENDING:
1455
+ raise ArtifactNotLoggedError(self, "get")
1456
+
1457
+ entry, wb_class = self._get_obj_entry(name)
1458
+ if entry is None or wb_class is None:
1459
+ return None
1460
+
1461
+ # If the entry is a reference from another artifact, then get it directly from
1462
+ # that artifact.
1463
+ if entry._is_artifact_reference():
1464
+ assert self._client is not None
1465
+ artifact = entry._get_referenced_artifact(self._client)
1466
+ return artifact.get(util.uri_from_path(entry.ref))
1467
+
1468
+ # Special case for wandb.Table. This is intended to be a short term
1469
+ # optimization. Since tables are likely to download many other assets in
1470
+ # artifact(s), we eagerly download the artifact using the parallelized
1471
+ # `artifact.download`. In the future, we should refactor the deserialization
1472
+ # pattern such that this special case is not needed.
1473
+ if wb_class == wandb.Table:
1474
+ self.download(recursive=True)
1475
+
1476
+ # Get the ArtifactManifestEntry
1477
+ item = self.get_path(entry.path)
1478
+ item_path = item.download()
1479
+
1480
+ # Load the object from the JSON blob
1481
+ result = None
1482
+ json_obj = {}
1483
+ with open(item_path) as file:
1484
+ json_obj = json.load(file)
1485
+ result = wb_class.from_json(json_obj, self)
1486
+ result._set_artifact_source(self, name)
1487
+ return result
1488
+
1489
+ def get_added_local_path_name(self, local_path: str) -> Optional[str]:
1490
+ """Get the artifact relative name of a file added by a local filesystem path.
1491
+
1492
+ Arguments:
1493
+ local_path: The local path to resolve into an artifact relative name.
1494
+
1495
+ Returns:
1496
+ The artifact relative name.
1497
+
1498
+ Examples:
1499
+ Basic usage:
1500
+ ```
1501
+ artifact = wandb.Artifact("my_dataset", type="dataset")
1502
+ artifact.add_file("path/to/file.txt", name="artifact/path/file.txt")
1503
+
1504
+ # Returns `artifact/path/file.txt`:
1505
+ name = artifact.get_added_local_path_name("path/to/file.txt")
1506
+ ```
1507
+ """
1508
+ entry = self._added_local_paths.get(local_path, None)
1509
+ if entry is None:
1510
+ return None
1511
+ return entry.path
1512
+
1513
+ def _get_obj_entry(
1514
+ self, name: str
1515
+ ) -> Tuple[Optional["ArtifactManifestEntry"], Optional[Type[WBValue]]]:
1516
+ """Return an object entry by name, handling any type suffixes.
1517
+
1518
+ When objects are added with `.add(obj, name)`, the name is typically changed to
1519
+ include the suffix of the object type when serializing to JSON. So we need to be
1520
+ able to resolve a name, without tasking the user with appending .THING.json.
1521
+ This method returns an entry if it exists by a suffixed name.
1522
+
1523
+ Arguments:
1524
+ name: name used when adding
1525
+ """
1526
+ for wb_class in WBValue.type_mapping().values():
1527
+ wandb_file_name = wb_class.with_suffix(name)
1528
+ entry = self.manifest.entries.get(wandb_file_name)
1529
+ if entry is not None:
1530
+ return entry, wb_class
1531
+ return None, None
1532
+
1533
+ # Downloading.
1534
+
1535
+ def download(
1536
+ self,
1537
+ root: Optional[str] = None,
1538
+ recursive: bool = False,
1539
+ allow_missing_references: bool = False,
1540
+ ) -> FilePathStr:
1541
+ """Download the contents of the artifact to the specified root directory.
1542
+
1543
+ NOTE: Any existing files at `root` are left untouched. Explicitly delete
1544
+ root before calling `download` if you want the contents of `root` to exactly
1545
+ match the artifact.
1546
+
1547
+ Arguments:
1548
+ root: The directory in which to download this artifact's files.
1549
+ recursive: If true, then all dependent artifacts are eagerly downloaded.
1550
+ Otherwise, the dependent artifacts are downloaded as needed.
1551
+
1552
+ Returns:
1553
+ The path to the downloaded contents.
1554
+
1555
+ Raises:
1556
+ ArtifactNotLoggedError: if the artifact has not been logged
1557
+ """
1558
+ if self._state == ArtifactState.PENDING:
1559
+ raise ArtifactNotLoggedError(self, "download")
1560
+
1561
+ root = root or self._default_root()
1562
+ self._add_download_root(root)
1563
+
1564
+ nfiles = len(self.manifest.entries)
1565
+ size = sum(e.size or 0 for e in self.manifest.entries.values())
1566
+ log = False
1567
+ if nfiles > 5000 or size > 50 * 1024 * 1024:
1568
+ log = True
1569
+ termlog(
1570
+ "Downloading large artifact {}, {:.2f}MB. {} files... ".format(
1571
+ self.name, size / (1024 * 1024), nfiles
1572
+ ),
1573
+ )
1574
+ start_time = datetime.datetime.now()
1575
+ download_logger = ArtifactDownloadLogger(nfiles=nfiles)
1576
+
1577
+ def _download_entry(
1578
+ entry: ArtifactManifestEntry,
1579
+ api_key: Optional[str],
1580
+ cookies: Optional[Dict],
1581
+ headers: Optional[Dict],
1582
+ ) -> None:
1583
+ _thread_local_api_settings.api_key = api_key
1584
+ _thread_local_api_settings.cookies = cookies
1585
+ _thread_local_api_settings.headers = headers
1586
+
1587
+ try:
1588
+ entry.download(root)
1589
+ except FileNotFoundError as e:
1590
+ if allow_missing_references:
1591
+ wandb.termwarn(str(e))
1592
+ return
1593
+ raise
1594
+ download_logger.notify_downloaded()
1595
+
1596
+ download_entry = partial(
1597
+ _download_entry,
1598
+ api_key=_thread_local_api_settings.api_key,
1599
+ cookies=_thread_local_api_settings.cookies,
1600
+ headers=_thread_local_api_settings.headers,
1601
+ )
1602
+
1603
+ with concurrent.futures.ThreadPoolExecutor(64) as executor:
1604
+ active_futures = set()
1605
+ has_next_page = True
1606
+ cursor = None
1607
+ while has_next_page:
1608
+ attrs = self._fetch_file_urls(cursor)
1609
+ has_next_page = attrs["pageInfo"]["hasNextPage"]
1610
+ cursor = attrs["pageInfo"]["endCursor"]
1611
+ for edge in attrs["edges"]:
1612
+ entry = self.get_path(edge["node"]["name"])
1613
+ entry._download_url = edge["node"]["directUrl"]
1614
+ active_futures.add(executor.submit(download_entry, entry))
1615
+ # Wait for download threads to catch up.
1616
+ max_backlog = 5000
1617
+ if len(active_futures) > max_backlog:
1618
+ for future in concurrent.futures.as_completed(active_futures):
1619
+ future.result() # check for errors
1620
+ active_futures.remove(future)
1621
+ if len(active_futures) <= max_backlog:
1622
+ break
1623
+ # Check for errors.
1624
+ for future in concurrent.futures.as_completed(active_futures):
1625
+ future.result()
1626
+
1627
+ if recursive:
1628
+ for dependent_artifact in self._dependent_artifacts:
1629
+ dependent_artifact.download()
1630
+
1631
+ if log:
1632
+ now = datetime.datetime.now()
1633
+ delta = abs((now - start_time).total_seconds())
1634
+ hours = int(delta // 3600)
1635
+ minutes = int((delta - hours * 3600) // 60)
1636
+ seconds = delta - hours * 3600 - minutes * 60
1637
+ termlog(
1638
+ f"Done. {hours}:{minutes}:{seconds:.1f}",
1639
+ prefix=False,
1640
+ )
1641
+ return FilePathStr(root)
1642
+
1643
+ @retry.retriable(
1644
+ retry_timedelta=datetime.timedelta(minutes=3),
1645
+ retryable_exceptions=(requests.RequestException),
1646
+ )
1647
+ def _fetch_file_urls(self, cursor: Optional[str]) -> Any:
1648
+ query = gql(
1649
+ """
1650
+ query ArtifactFileURLs($id: ID!, $cursor: String) {
1651
+ artifact(id: $id) {
1652
+ files(after: $cursor, first: 5000) {
1653
+ pageInfo {
1654
+ hasNextPage
1655
+ endCursor
1656
+ }
1657
+ edges {
1658
+ node {
1659
+ name
1660
+ directUrl
1661
+ }
1662
+ }
1663
+ }
1664
+ }
1665
+ }
1666
+ """
1667
+ )
1668
+ assert self._client is not None
1669
+ response = self._client.execute(
1670
+ query,
1671
+ variable_values={"id": self.id, "cursor": cursor},
1672
+ timeout=60,
1673
+ )
1674
+ return response["artifact"]["files"]
1675
+
1676
+ def checkout(self, root: Optional[str] = None) -> str:
1677
+ """Replace the specified root directory with the contents of the artifact.
1678
+
1679
+ WARNING: This will DELETE all files in `root` that are not included in the
1680
+ artifact.
1681
+
1682
+ Arguments:
1683
+ root: The directory to replace with this artifact's files.
1684
+
1685
+ Returns:
1686
+ The path to the checked out contents.
1687
+
1688
+ Raises:
1689
+ ArtifactNotLoggedError: if the artifact has not been logged
1690
+ """
1691
+ if self._state == ArtifactState.PENDING:
1692
+ raise ArtifactNotLoggedError(self, "checkout")
1693
+
1694
+ root = root or self._default_root(include_version=False)
1695
+
1696
+ for dirpath, _, files in os.walk(root):
1697
+ for file in files:
1698
+ full_path = os.path.join(dirpath, file)
1699
+ artifact_path = os.path.relpath(full_path, start=root)
1700
+ try:
1701
+ self.get_path(artifact_path)
1702
+ except KeyError:
1703
+ # File is not part of the artifact, remove it.
1704
+ os.remove(full_path)
1705
+
1706
+ return self.download(root=root)
1707
+
1708
+ def verify(self, root: Optional[str] = None) -> None:
1709
+ """Verify that the actual contents of an artifact match the manifest.
1710
+
1711
+ All files in the directory are checksummed and the checksums are then
1712
+ cross-referenced against the artifact's manifest.
1713
+
1714
+ NOTE: References are not verified.
1715
+
1716
+ Arguments:
1717
+ root: The directory to verify. If None artifact will be downloaded to
1718
+ './artifacts/self.name/'
1719
+
1720
+ Raises:
1721
+ ArtifactNotLoggedError: if the artifact has not been logged
1722
+ ValueError: If the verification fails.
1723
+ """
1724
+ if self._state == ArtifactState.PENDING:
1725
+ raise ArtifactNotLoggedError(self, "verify")
1726
+
1727
+ root = root or self._default_root()
1728
+
1729
+ for dirpath, _, files in os.walk(root):
1730
+ for file in files:
1731
+ full_path = os.path.join(dirpath, file)
1732
+ artifact_path = os.path.relpath(full_path, start=root)
1733
+ try:
1734
+ self.get_path(artifact_path)
1735
+ except KeyError:
1736
+ raise ValueError(
1737
+ "Found file {} which is not a member of artifact {}".format(
1738
+ full_path, self.name
1739
+ )
1740
+ )
1741
+
1742
+ ref_count = 0
1743
+ for entry in self.manifest.entries.values():
1744
+ if entry.ref is None:
1745
+ if md5_file_b64(os.path.join(root, entry.path)) != entry.digest:
1746
+ raise ValueError("Digest mismatch for file: %s" % entry.path)
1747
+ else:
1748
+ ref_count += 1
1749
+ if ref_count > 0:
1750
+ print("Warning: skipped verification of %s refs" % ref_count)
1751
+
1752
+ def file(self, root: Optional[str] = None) -> StrPath:
1753
+ """Download a single file artifact to dir specified by the root.
1754
+
1755
+ Arguments:
1756
+ root: The root directory in which to place the file. Defaults to
1757
+ './artifacts/self.name/'.
1758
+
1759
+ Returns:
1760
+ The full path of the downloaded file.
1761
+
1762
+ Raises:
1763
+ ArtifactNotLoggedError: if the artifact has not been logged
1764
+ ValueError: if the artifact contains more than one file
1765
+ """
1766
+ if self._state == ArtifactState.PENDING:
1767
+ raise ArtifactNotLoggedError(self, "file")
1768
+
1769
+ if root is None:
1770
+ root = os.path.join(".", "artifacts", self.name)
1771
+
1772
+ if len(self.manifest.entries) > 1:
1773
+ raise ValueError(
1774
+ "This artifact contains more than one file, call `.download()` to get "
1775
+ 'all files or call .get_path("filename").download()'
1776
+ )
1777
+
1778
+ return self.get_path(list(self.manifest.entries)[0]).download(root)
1779
+
1780
+ def files(
1781
+ self, names: Optional[List[str]] = None, per_page: int = 50
1782
+ ) -> ArtifactFiles:
1783
+ """Iterate over all files stored in this artifact.
1784
+
1785
+ Arguments:
1786
+ names: The filename paths relative to the root of the artifact you wish to
1787
+ list.
1788
+ per_page: The number of files to return per request
1789
+
1790
+ Returns:
1791
+ An iterator containing `File` objects
1792
+
1793
+ Raises:
1794
+ ArtifactNotLoggedError: if the artifact has not been logged
1795
+ """
1796
+ if self._state == ArtifactState.PENDING:
1797
+ raise ArtifactNotLoggedError(self, "files")
1798
+
1799
+ return ArtifactFiles(self._client, self, names, per_page)
1800
+
1801
+ def _default_root(self, include_version: bool = True) -> str:
1802
+ name = self.name if include_version else self.name.split(":")[0]
1803
+ root = os.path.join(env.get_artifact_dir(), name)
1804
+ if platform.system() == "Windows":
1805
+ head, tail = os.path.splitdrive(root)
1806
+ root = head + tail.replace(":", "-")
1807
+ return root
1808
+
1809
    def _add_download_root(self, dir_path: str) -> None:
        """Record `dir_path` (absolutized) as a directory this artifact was downloaded to."""
        self._download_roots.add(os.path.abspath(dir_path))
1812
+ def _local_path_to_name(self, file_path: str) -> Optional[str]:
1813
+ """Convert a local file path to a path entry in the artifact."""
1814
+ abs_file_path = os.path.abspath(file_path)
1815
+ abs_file_parts = abs_file_path.split(os.sep)
1816
+ for i in range(len(abs_file_parts) + 1):
1817
+ if os.path.join(os.sep, *abs_file_parts[:i]) in self._download_roots:
1818
+ return os.path.join(*abs_file_parts[i:])
1819
+ return None
1820
+
1821
+ # Others.
1822
+
1823
+ def delete(self, delete_aliases: bool = False) -> None:
1824
+ """Delete an artifact and its files.
1825
+
1826
+ Arguments:
1827
+ delete_aliases: If true, deletes all aliases associated with the artifact.
1828
+ Otherwise, this raises an exception if the artifact has existing
1829
+ aliases.
1830
+
1831
+ Raises:
1832
+ ArtifactNotLoggedError: if the artifact has not been logged
1833
+
1834
+ Examples:
1835
+ Delete all the "model" artifacts a run has logged:
1836
+ ```
1837
+ runs = api.runs(path="my_entity/my_project")
1838
+ for run in runs:
1839
+ for artifact in run.logged_artifacts():
1840
+ if artifact.type == "model":
1841
+ artifact.delete(delete_aliases=True)
1842
+ ```
1843
+ """
1844
+ if self._state == ArtifactState.PENDING:
1845
+ raise ArtifactNotLoggedError(self, "delete")
1846
+ self._delete(delete_aliases)
1847
+
1848
    @normalize_exceptions
    def _delete(self, delete_aliases: bool = False) -> None:
        """Execute the DeleteArtifact mutation for `delete`.

        Arguments:
            delete_aliases: Forwarded as the mutation's `deleteAliases` flag.
        """
        mutation = gql(
            """
            mutation DeleteArtifact($artifactID: ID!, $deleteAliases: Boolean) {
                deleteArtifact(input: {
                    artifactID: $artifactID
                    deleteAliases: $deleteAliases
                }) {
                    artifact {
                        id
                    }
                }
            }
            """
        )
        assert self._client is not None
        self._client.execute(
            mutation,
            variable_values={
                "artifactID": self.id,
                "deleteAliases": delete_aliases,
            },
        )
1873
+ def link(self, target_path: str, aliases: Optional[List[str]] = None) -> None:
1874
+ """Link this artifact to a portfolio (a promoted collection of artifacts).
1875
+
1876
+ Arguments:
1877
+ target_path: The path to the portfolio. It must take the form {portfolio},
1878
+ {project}/{portfolio} or {entity}/{project}/{portfolio}.
1879
+ aliases: A list of strings which uniquely identifies the artifact inside the
1880
+ specified portfolio.
1881
+
1882
+ Raises:
1883
+ ArtifactNotLoggedError: if the artifact has not been logged
1884
+ """
1885
+ if self._state == ArtifactState.PENDING:
1886
+ raise ArtifactNotLoggedError(self, "link")
1887
+ self._link(target_path, aliases)
1888
+
1889
    @normalize_exceptions
    def _link(self, target_path: str, aliases: Optional[List[str]] = None) -> None:
        """Execute the LinkArtifact mutation for `link`.

        Resolves the portfolio's entity/project from the target path first,
        then the active run, then this artifact's own entity/project.

        Raises:
            ValueError: if `target_path` contains a ':' (looks like an alias).
        """
        if ":" in target_path:
            raise ValueError(
                f"target_path {target_path} cannot contain `:` because it is not an "
                f"alias."
            )

        portfolio, project, entity = util._parse_entity_project_item(target_path)
        aliases = util._resolve_aliases(aliases)

        # Prefer the active run's entity/project, fall back to the artifact's.
        run_entity = wandb.run.entity if wandb.run else None
        run_project = wandb.run.project if wandb.run else None
        entity = entity or run_entity or self.entity
        project = project or run_project or self.project

        mutation = gql(
            """
            mutation LinkArtifact(
                $artifactID: ID!,
                $artifactPortfolioName: String!,
                $entityName: String!,
                $projectName: String!,
                $aliases: [ArtifactAliasInput!]
            ) {
                linkArtifact(
                    input: {
                        artifactID: $artifactID,
                        artifactPortfolioName: $artifactPortfolioName,
                        entityName: $entityName,
                        projectName: $projectName,
                        aliases: $aliases
                    }
                ) {
                    versionIndex
                }
            }
            """
        )
        assert self._client is not None
        self._client.execute(
            mutation,
            variable_values={
                "artifactID": self.id,
                "artifactPortfolioName": portfolio,
                "entityName": entity,
                "projectName": project,
                "aliases": [
                    {"alias": alias, "artifactCollectionName": portfolio}
                    for alias in aliases
                ],
            },
        )
1943
    def used_by(self) -> List[Run]:
        """Get a list of the runs that have used this artifact.

        Returns:
            A list of `Run` objects, one per run referenced in the server's
            `usedBy` edges for this artifact.

        Raises:
            ArtifactNotLoggedError: if the artifact has not been logged
        """
        if self._state == ArtifactState.PENDING:
            raise ArtifactNotLoggedError(self, "used_by")

        query = gql(
            """
            query ArtifactUsedBy(
                $id: ID!,
            ) {
                artifact(id: $id) {
                    usedBy {
                        edges {
                            node {
                                name
                                project {
                                    name
                                    entityName
                                }
                            }
                        }
                    }
                }
            }
            """
        )
        assert self._client is not None
        response = self._client.execute(
            query,
            variable_values={"id": self.id},
        )
        # Each edge carries a run name plus its project/entity; wrap each in a Run.
        return [
            Run(
                self._client,
                edge["node"]["project"]["entityName"],
                edge["node"]["project"]["name"],
                edge["node"]["name"],
            )
            for edge in response.get("artifact", {}).get("usedBy", {}).get("edges", [])
        ]
1988
    def logged_by(self) -> Optional[Run]:
        """Get the run that first logged this artifact.

        Returns:
            The creating `Run`, or None when the artifact was not created by a run.

        Raises:
            ArtifactNotLoggedError: if the artifact has not been logged
        """
        if self._state == ArtifactState.PENDING:
            raise ArtifactNotLoggedError(self, "logged_by")

        query = gql(
            """
            query ArtifactCreatedBy(
                $id: ID!
            ) {
                artifact(id: $id) {
                    createdBy {
                        ... on Run {
                            name
                            project {
                                name
                                entityName
                            }
                        }
                    }
                }
            }
            """
        )
        assert self._client is not None
        response = self._client.execute(
            query,
            variable_values={"id": self.id},
        )
        # NOTE(review): this assumes `createdBy` is always a dict (possibly empty);
        # a literal null from the server would make `.get("name")` raise -- confirm.
        creator = response.get("artifact", {}).get("createdBy", {})
        if creator.get("name") is None:
            # createdBy was not a Run (the inline fragment matched nothing).
            return None
        return Run(
            self._client,
            creator["project"]["entityName"],
            creator["project"]["name"],
            creator["name"],
        )
2031
+ def json_encode(self) -> Dict[str, Any]:
2032
+ if self._state == ArtifactState.PENDING:
2033
+ raise ArtifactNotLoggedError(self, "json_encode")
2034
+ return util.artifact_to_json(self)
2035
+
2036
    @staticmethod
    def _expected_type(
        entity_name: str, project_name: str, name: str, client: RetryingClient
    ) -> Optional[str]:
        """Returns the expected type for a given artifact name and project.

        Arguments:
            entity_name: Entity that owns the project.
            project_name: Project containing the artifact.
            name: Artifact name; ':latest' is appended if no version is given.
            client: GraphQL client used to run the query.

        Returns:
            The artifact type's name, or None if the project/artifact/type
            could not be resolved.
        """
        query = gql(
            """
            query ArtifactType(
                $entityName: String,
                $projectName: String,
                $name: String!
            ) {
                project(name: $projectName, entityName: $entityName) {
                    artifact(name: $name) {
                        artifactType {
                            name
                        }
                    }
                }
            }
            """
        )
        # The server requires a versioned name; default to the latest version.
        if ":" not in name:
            name += ":latest"
        response = client.execute(
            query,
            variable_values={
                "entityName": entity_name,
                "projectName": project_name,
                "name": name,
            },
        )
        # Each level may be missing/None; `or {}` keeps the chain null-safe.
        return (
            ((response.get("project") or {}).get("artifact") or {}).get("artifactType")
            or {}
        ).get("name")
2073
+ @staticmethod
2074
+ def _normalize_metadata(metadata: Optional[Dict[str, Any]]) -> Dict[str, Any]:
2075
+ if metadata is None:
2076
+ return {}
2077
+ if not isinstance(metadata, dict):
2078
+ raise TypeError(f"metadata must be dict, not {type(metadata)}")
2079
+ return cast(
2080
+ Dict[str, Any], json.loads(json.dumps(util.json_friendly_val(metadata)))
2081
+ )
2082
+
2083
+ def _load_manifest(self, url: str) -> None:
2084
+ with requests.get(url) as request:
2085
+ request.raise_for_status()
2086
+ self._manifest = ArtifactManifest.from_manifest_json(
2087
+ json.loads(util.ensure_text(request.content))
2088
+ )
2089
+ for entry in self.manifest.entries.values():
2090
+ if entry._is_artifact_reference():
2091
+ assert self._client is not None
2092
+ dep_artifact = entry._get_referenced_artifact(self._client)
2093
+ self._dependent_artifacts.add(dep_artifact)
2094
+
2095
+
2096
class _ArtifactVersionType(WBType):
    # Registers Artifact as the Python type backing the "artifactVersion"
    # media type in wandb's type system.
    name = "artifactVersion"
    types = [Artifact]
2100
+
2101
+ TypeRegistry.add(_ArtifactVersionType)