wandb 0.15.4__py3-none-any.whl → 0.15.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (102) hide show
  1. wandb/__init__.py +1 -1
  2. wandb/analytics/sentry.py +1 -0
  3. wandb/apis/internal.py +3 -0
  4. wandb/apis/public.py +18 -20
  5. wandb/beta/workflows.py +5 -6
  6. wandb/cli/cli.py +27 -27
  7. wandb/data_types.py +2 -0
  8. wandb/integration/langchain/wandb_tracer.py +16 -179
  9. wandb/integration/sagemaker/config.py +2 -2
  10. wandb/integration/tensorboard/log.py +4 -4
  11. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  12. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  13. wandb/proto/wandb_deprecated.py +3 -1
  14. wandb/sdk/__init__.py +1 -4
  15. wandb/sdk/artifacts/__init__.py +0 -14
  16. wandb/sdk/artifacts/artifact.py +1757 -277
  17. wandb/sdk/artifacts/artifact_manifest_entry.py +26 -6
  18. wandb/sdk/artifacts/artifact_state.py +10 -0
  19. wandb/sdk/artifacts/artifacts_cache.py +7 -8
  20. wandb/sdk/artifacts/exceptions.py +4 -4
  21. wandb/sdk/artifacts/storage_handler.py +2 -2
  22. wandb/sdk/artifacts/storage_handlers/azure_handler.py +16 -6
  23. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +2 -2
  24. wandb/sdk/artifacts/storage_handlers/http_handler.py +2 -2
  25. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +2 -2
  26. wandb/sdk/artifacts/storage_handlers/multi_handler.py +2 -2
  27. wandb/sdk/artifacts/storage_handlers/s3_handler.py +35 -32
  28. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +2 -2
  29. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +5 -9
  30. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +2 -2
  31. wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +2 -2
  32. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +24 -16
  33. wandb/sdk/artifacts/storage_policy.py +3 -3
  34. wandb/sdk/data_types/_dtypes.py +7 -12
  35. wandb/sdk/data_types/base_types/json_metadata.py +2 -2
  36. wandb/sdk/data_types/base_types/media.py +5 -6
  37. wandb/sdk/data_types/base_types/wb_value.py +12 -13
  38. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +4 -5
  39. wandb/sdk/data_types/helper_types/classes.py +5 -8
  40. wandb/sdk/data_types/helper_types/image_mask.py +4 -5
  41. wandb/sdk/data_types/histogram.py +3 -3
  42. wandb/sdk/data_types/html.py +3 -4
  43. wandb/sdk/data_types/image.py +4 -5
  44. wandb/sdk/data_types/molecule.py +2 -2
  45. wandb/sdk/data_types/object_3d.py +3 -3
  46. wandb/sdk/data_types/plotly.py +2 -2
  47. wandb/sdk/data_types/saved_model.py +7 -8
  48. wandb/sdk/data_types/trace_tree.py +4 -4
  49. wandb/sdk/data_types/video.py +4 -4
  50. wandb/sdk/interface/interface.py +8 -10
  51. wandb/sdk/internal/file_stream.py +2 -3
  52. wandb/sdk/internal/internal_api.py +99 -4
  53. wandb/sdk/internal/job_builder.py +15 -7
  54. wandb/sdk/internal/sender.py +4 -0
  55. wandb/sdk/internal/settings_static.py +1 -0
  56. wandb/sdk/launch/_project_spec.py +9 -7
  57. wandb/sdk/launch/agent/agent.py +115 -58
  58. wandb/sdk/launch/agent/job_status_tracker.py +34 -0
  59. wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
  60. wandb/sdk/launch/builder/abstract.py +5 -1
  61. wandb/sdk/launch/builder/build.py +16 -10
  62. wandb/sdk/launch/builder/docker_builder.py +9 -2
  63. wandb/sdk/launch/builder/kaniko_builder.py +108 -22
  64. wandb/sdk/launch/builder/noop.py +3 -1
  65. wandb/sdk/launch/environment/aws_environment.py +2 -1
  66. wandb/sdk/launch/environment/azure_environment.py +124 -0
  67. wandb/sdk/launch/github_reference.py +30 -18
  68. wandb/sdk/launch/launch.py +1 -1
  69. wandb/sdk/launch/loader.py +15 -0
  70. wandb/sdk/launch/registry/azure_container_registry.py +132 -0
  71. wandb/sdk/launch/registry/elastic_container_registry.py +38 -4
  72. wandb/sdk/launch/registry/google_artifact_registry.py +46 -7
  73. wandb/sdk/launch/runner/abstract.py +19 -3
  74. wandb/sdk/launch/runner/kubernetes_runner.py +111 -47
  75. wandb/sdk/launch/runner/local_container.py +101 -48
  76. wandb/sdk/launch/runner/sagemaker_runner.py +59 -9
  77. wandb/sdk/launch/runner/vertex_runner.py +8 -4
  78. wandb/sdk/launch/sweeps/scheduler.py +102 -27
  79. wandb/sdk/launch/sweeps/utils.py +21 -0
  80. wandb/sdk/launch/utils.py +19 -7
  81. wandb/sdk/lib/_settings_toposort_generated.py +3 -0
  82. wandb/sdk/service/server.py +22 -9
  83. wandb/sdk/service/service.py +27 -8
  84. wandb/sdk/verify/verify.py +6 -9
  85. wandb/sdk/wandb_config.py +2 -4
  86. wandb/sdk/wandb_init.py +2 -0
  87. wandb/sdk/wandb_require.py +7 -0
  88. wandb/sdk/wandb_run.py +32 -35
  89. wandb/sdk/wandb_settings.py +10 -3
  90. wandb/testing/relay.py +15 -2
  91. wandb/util.py +55 -23
  92. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/METADATA +11 -8
  93. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/RECORD +97 -97
  94. wandb/integration/langchain/util.py +0 -191
  95. wandb/sdk/artifacts/invalid_artifact.py +0 -23
  96. wandb/sdk/artifacts/lazy_artifact.py +0 -162
  97. wandb/sdk/artifacts/local_artifact.py +0 -719
  98. wandb/sdk/artifacts/public_artifact.py +0 -1188
  99. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
  100. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
  101. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +0 -0
  102. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -1,34 +1,385 @@
1
- """Artifact interface."""
1
+ """Artifact class."""
2
+ import concurrent.futures
2
3
  import contextlib
3
- from typing import IO, TYPE_CHECKING, Generator, List, Optional, Sequence, Union
4
+ import datetime
5
+ import json
6
+ import multiprocessing.dummy
7
+ import os
8
+ import platform
9
+ import re
10
+ import shutil
11
+ import tempfile
12
+ import time
13
+ from copy import copy
14
+ from functools import partial
15
+ from pathlib import PurePosixPath
16
+ from typing import (
17
+ IO,
18
+ TYPE_CHECKING,
19
+ Any,
20
+ Dict,
21
+ Generator,
22
+ List,
23
+ Optional,
24
+ Sequence,
25
+ Set,
26
+ Tuple,
27
+ Type,
28
+ Union,
29
+ cast,
30
+ )
31
+ from urllib.parse import urlparse
32
+
33
+ import requests
4
34
 
5
35
  import wandb
36
+ from wandb import data_types, env, util
37
+ from wandb.apis.normalize import normalize_exceptions
38
+ from wandb.apis.public import ArtifactFiles, RetryingClient, Run
6
39
  from wandb.data_types import WBValue
7
- from wandb.sdk.lib.paths import FilePathStr, StrPath
40
+ from wandb.errors.term import termerror, termlog, termwarn
41
+ from wandb.sdk.artifacts.artifact_download_logger import ArtifactDownloadLogger
42
+ from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
43
+ from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
44
+ from wandb.sdk.artifacts.artifact_manifests.artifact_manifest_v1 import (
45
+ ArtifactManifestV1,
46
+ )
47
+ from wandb.sdk.artifacts.artifact_saver import get_staging_dir
48
+ from wandb.sdk.artifacts.artifact_state import ArtifactState
49
+ from wandb.sdk.artifacts.artifacts_cache import get_artifacts_cache
50
+ from wandb.sdk.artifacts.exceptions import (
51
+ ArtifactFinalizedError,
52
+ ArtifactNotLoggedError,
53
+ WaitTimeoutError,
54
+ )
55
+ from wandb.sdk.artifacts.storage_layout import StorageLayout
56
+ from wandb.sdk.artifacts.storage_policies.wandb_storage_policy import WandbStoragePolicy
57
+ from wandb.sdk.data_types._dtypes import Type as WBType
58
+ from wandb.sdk.data_types._dtypes import TypeRegistry
59
+ from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
60
+ from wandb.sdk.lib import filesystem, retry, runid, telemetry
61
+ from wandb.sdk.lib.hashutil import B64MD5, b64_to_hex_id, md5_file_b64
62
+ from wandb.sdk.lib.paths import FilePathStr, LogicalPath, StrPath, URIStr
63
+
64
+ reset_path = util.vendor_setup()
65
+
66
+ from wandb_gql import gql # noqa: E402
67
+
68
+ reset_path()
8
69
 
9
70
  if TYPE_CHECKING:
10
- import os
11
-
12
- import wandb.apis.public
13
- from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
14
- from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
71
+ from wandb.sdk.interface.message_future import MessageFuture
15
72
 
16
73
 
17
74
  class Artifact:
75
+ """Flexible and lightweight building block for dataset and model versioning.
76
+
77
+ Constructs an empty artifact whose contents can be populated using its `add` family
78
+ of functions. Once the artifact has all the desired files, you can call
79
+ `wandb.log_artifact()` to log it.
80
+
81
+ Arguments:
82
+ name: A human-readable name for this artifact, which is how you can identify
83
+ this artifact in the UI or reference it in `use_artifact` calls. Names can
84
+ contain letters, numbers, underscores, hyphens, and dots. The name must be
85
+ unique across a project.
86
+ type: The type of the artifact, which is used to organize and differentiate
87
+ artifacts. Common types include `dataset` or `model`, but you can use any
88
+ string containing letters, numbers, underscores, hyphens, and dots.
89
+ description: Free text that offers a description of the artifact. The
90
+ description is markdown rendered in the UI, so this is a good place to place
91
+ tables, links, etc.
92
+ metadata: Structured data associated with the artifact, for example class
93
+ distribution of a dataset. This will eventually be queryable and plottable
94
+ in the UI. There is a hard limit of 100 total keys.
95
+
96
+ Returns:
97
+ An `Artifact` object.
98
+
99
+ Examples:
100
+ Basic usage:
101
+ ```
102
+ wandb.init()
103
+
104
+ artifact = wandb.Artifact("mnist", type="dataset")
105
+ artifact.add_dir("mnist/")
106
+ wandb.log_artifact(artifact)
107
+ ```
108
+ """
109
+
110
+ _TMP_DIR = tempfile.TemporaryDirectory("wandb-artifacts")
111
+ _GQL_FRAGMENT = """
112
+ fragment ArtifactFragment on Artifact {
113
+ id
114
+ artifactSequence {
115
+ project {
116
+ entityName
117
+ name
118
+ }
119
+ name
120
+ }
121
+ versionIndex
122
+ artifactType {
123
+ name
124
+ }
125
+ description
126
+ metadata
127
+ aliases {
128
+ artifactCollectionName
129
+ alias
130
+ }
131
+ state
132
+ commitHash
133
+ fileCount
134
+ createdAt
135
+ updatedAt
136
+ }
137
+ """
138
+
139
def __init__(
    self,
    name: str,
    type: str,
    description: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    incremental: bool = False,
    use_as: Optional[str] = None,
) -> None:
    """Create an empty draft artifact.

    Validates `name` and `type`, builds the storage policy and an empty V1
    manifest, initialises every property to its draft value (state is
    PENDING) and finally registers the instance with the artifacts cache.

    Raises:
        ValueError: if `name` contains characters other than letters, digits,
            underscores, hyphens and dots, or if `type` is "job" or starts
            with "wandb-".
    """
    name_is_valid = re.match(r"^[a-zA-Z0-9_\-.]+$", name) is not None
    if not name_is_valid:
        raise ValueError(
            f"Artifact name may only contain alphanumeric characters, dashes, "
            f"underscores, and dots. Invalid name: {name}"
        )
    type_is_reserved = type == "job" or type.startswith("wandb-")
    if type_is_reserved:
        raise ValueError(
            "Artifact types 'job' and 'wandb-*' are reserved for internal use. "
            "Please use a different type."
        )
    if incremental:
        termwarn("Using experimental arg `incremental`")

    # Internal.
    self._client: Optional[RetryingClient] = None
    if env.get_use_v1_artifacts():
        layout = StorageLayout.V1
    else:
        layout = StorageLayout.V2
    policy_config = {
        "storageLayout": layout,
        # TODO: storage region
    }
    self._storage_policy = WandbStoragePolicy(config=policy_config)
    self._tmp_dir: Optional[tempfile.TemporaryDirectory] = None
    self._added_objs: Dict[
        int, Tuple[data_types.WBValue, ArtifactManifestEntry]
    ] = {}
    self._added_local_paths: Dict[str, ArtifactManifestEntry] = {}
    self._save_future: Optional["MessageFuture"] = None
    self._dependent_artifacts: Set["Artifact"] = set()
    self._download_roots: Set[str] = set()

    # Properties.
    self._id: Optional[str] = None
    self._client_id: str = runid.generate_id(128)
    self._sequence_client_id: str = runid.generate_id(128)
    self._entity: Optional[str] = None
    self._project: Optional[str] = None
    # Both names include the version once the artifact has been saved.
    self._name: str = name
    self._version: Optional[str] = None
    self._source_entity: Optional[str] = None
    self._source_project: Optional[str] = None
    self._source_name: str = name
    self._source_version: Optional[str] = None
    self._type: str = type
    self._description: Optional[str] = description
    self._metadata: dict = self._normalize_metadata(metadata)
    self._aliases: List[str] = []
    self._saved_aliases: List[str] = []
    self._distributed_id: Optional[str] = None
    self._incremental: bool = incremental
    self._use_as: Optional[str] = use_as
    self._state: ArtifactState = ArtifactState.PENDING
    empty_manifest = ArtifactManifestV1(self._storage_policy)
    self._manifest: Optional[ArtifactManifest] = empty_manifest
    self._commit_hash: Optional[str] = None
    self._file_count: Optional[int] = None
    self._created_at: Optional[str] = None
    self._updated_at: Optional[str] = None
    self._final: bool = False

    # Cache.
    get_artifacts_cache().store_client_artifact(self)
211
+
212
+ def __repr__(self) -> str:
213
+ return f"<Artifact {self.id or self.name}>"
214
+
215
@classmethod
def _from_id(cls, artifact_id: str, client: RetryingClient) -> Optional["Artifact"]:
    """Fetch an artifact by ID, consulting the artifacts cache first.

    Arguments:
        artifact_id: The ID to look up.
        client: Client used to execute the `ArtifactByID` query on a cache miss.

    Returns:
        The cached artifact if one exists for `artifact_id`; otherwise an
        artifact built from the query response via `_from_attrs`, or `None`
        when the response contains no artifact.
    """
    cached = get_artifacts_cache().get_artifact(artifact_id)
    if cached is not None:
        return cached

    query_text = """
        query ArtifactByID($id: ID!) {
            artifact(id: $id) {
                ...ArtifactFragment
                currentManifest {
                    file {
                        directUrl
                    }
                }
            }
        }
        """
    query = gql(query_text + cls._GQL_FRAGMENT)
    response = client.execute(query, variable_values={"id": artifact_id})
    attrs = response.get("artifact")
    if attrs is None:
        return None

    sequence = attrs["artifactSequence"]
    owner = sequence["project"]
    entity = owner["entityName"]
    project = owner["name"]
    # The name is always expressed against the sequence, as "<sequence>:v<index>".
    name = f"{sequence['name']}:v{attrs['versionIndex']}"
    return cls._from_attrs(entity, project, name, attrs, client)
247
+
248
@classmethod
def _from_name(
    cls, entity: str, project: str, name: str, client: RetryingClient
) -> "Artifact":
    """Fetch an artifact by its entity, project and name.

    Arguments:
        entity: Entity name passed to the query as `entityName`.
        project: Project name passed to the query as `projectName`.
        name: Artifact name passed to the query as `name`.
        client: Client used to execute the `ArtifactByName` query.

    Returns:
        The artifact built from the query response via `_from_attrs`.

    Raises:
        ValueError: if the response carries no artifact, including the case
            where the project itself comes back as null.
    """
    query = gql(
        """
        query ArtifactByName(
            $entityName: String!,
            $projectName: String!,
            $name: String!
        ) {
            project(name: $projectName, entityName: $entityName) {
                artifact(name: $name) {
                    ...ArtifactFragment
                }
            }
        }
        """
        + cls._GQL_FRAGMENT
    )
    response = client.execute(
        query,
        variable_values={
            "entityName": entity,
            "projectName": project,
            "name": name,
        },
    )
    # A `.get("project", {})` default only covers a missing key. When the key
    # is present with a null value the lookup yields None, so normalise it to
    # an empty mapping to reach the ValueError below instead of crashing with
    # AttributeError.
    project_attrs = response.get("project") or {}
    attrs = project_attrs.get("artifact")
    if attrs is None:
        raise ValueError(
            f"Unable to fetch artifact with name {entity}/{project}/{name}"
        )
    return cls._from_attrs(entity, project, name, attrs, client)
282
+
283
@classmethod
def _from_attrs(
    cls,
    entity: str,
    project: str,
    name: str,
    attrs: Dict[str, Any],
    client: RetryingClient,
) -> "Artifact":
    """Build a finalized artifact from a GraphQL attribute dictionary.

    Arguments:
        entity: Value stored as the artifact's entity.
        project: Value stored as the artifact's project.
        name: Value stored as the artifact's name; the part before ":" is
            used to select which aliases belong to this collection.
        attrs: Attribute dictionary from the query response.
        client: Client stored on the artifact for later queries.

    Returns:
        The populated artifact, which is also stored in the artifacts cache.
    """

    def in_this_collection(alias: Dict[str, Any]) -> bool:
        # Aliases are filtered to the collection named before the ":" in `name`.
        return alias["artifactCollectionName"] == name.split(":")[0]

    # Placeholder is required to skip validation.
    artifact = cls("placeholder", type="placeholder")
    artifact._client = client
    artifact._id = attrs["id"]
    artifact._entity = entity
    artifact._project = project
    artifact._name = name

    version_aliases = [
        alias["alias"]
        for alias in attrs.get("aliases", [])
        if in_this_collection(alias) and util.alias_is_version_index(alias["alias"])
    ]
    # Exactly one version alias is expected; an empty list raises IndexError here.
    artifact._version = version_aliases[0]

    sequence = attrs["artifactSequence"]
    source_project = sequence["project"]
    artifact._source_entity = source_project["entityName"]
    artifact._source_project = source_project["name"]
    artifact._source_name = f"{sequence['name']}:v{attrs['versionIndex']}"
    artifact._source_version = f"v{attrs['versionIndex']}"
    artifact._type = attrs["artifactType"]["name"]
    artifact._description = attrs["description"]
    raw_metadata = json.loads(attrs["metadata"] or "{}")
    # Assigned through the public setter, which normalizes the value again.
    artifact.metadata = cls._normalize_metadata(raw_metadata)

    artifact._aliases = [
        alias["alias"]
        for alias in attrs.get("aliases", [])
        if in_this_collection(alias)
        and not util.alias_is_version_index(alias["alias"])
    ]
    artifact._saved_aliases = copy(artifact._aliases)
    artifact._state = ArtifactState(attrs["state"])

    if "currentManifest" not in attrs:
        # No manifest in the response: the `manifest` property loads it on demand.
        artifact._manifest = None
    else:
        artifact._load_manifest(attrs["currentManifest"]["file"]["directUrl"])

    artifact._commit_hash = attrs["commitHash"]
    artifact._file_count = attrs["fileCount"]
    artifact._created_at = attrs["createdAt"]
    artifact._updated_at = attrs["updatedAt"]
    artifact._final = True

    # Cache.
    get_artifacts_cache().store_artifact(artifact)
    return artifact
338
+
339
def new_draft(self) -> "Artifact":
    """Create a new draft artifact with the same content as this committed artifact.

    The artifact returned can be extended or modified and logged as a new version.
    The draft gets its own copy of the metadata and a manifest rebuilt from this
    artifact's manifest JSON, so edits to the draft do not leak back into this
    artifact.

    Raises:
        ArtifactNotLoggedError: if the artifact has not been logged
    """
    from copy import deepcopy

    if self._state == ArtifactState.PENDING:
        raise ArtifactNotLoggedError(self, "new_draft")

    artifact = Artifact(self.source_name.split(":")[0], self.type)
    artifact._description = self.description
    # Copy rather than alias: `self.metadata` returns the live dict, and sharing
    # it would let `draft.metadata[...] = ...` silently modify this artifact.
    artifact._metadata = deepcopy(self.metadata)
    artifact._manifest = ArtifactManifest.from_manifest_json(
        self.manifest.to_manifest_json()
    )
    return artifact
357
+
358
+ # Properties.
359
+
18
360
  @property
19
361
  def id(self) -> Optional[str]:
20
362
  """The artifact's ID."""
21
- raise NotImplementedError
363
+ if self._state == ArtifactState.PENDING:
364
+ return None
365
+ assert self._id is not None
366
+ return self._id
22
367
 
23
368
  @property
24
369
  def entity(self) -> str:
25
370
  """The name of the entity of the secondary (portfolio) artifact collection."""
26
- raise NotImplementedError
371
+ if self._state == ArtifactState.PENDING:
372
+ raise ArtifactNotLoggedError(self, "entity")
373
+ assert self._entity is not None
374
+ return self._entity
27
375
 
28
376
  @property
29
377
  def project(self) -> str:
30
378
  """The name of the project of the secondary (portfolio) artifact collection."""
31
- raise NotImplementedError
379
+ if self._state == ArtifactState.PENDING:
380
+ raise ArtifactNotLoggedError(self, "project")
381
+ assert self._project is not None
382
+ return self._project
32
383
 
33
384
  @property
34
385
  def name(self) -> str:
@@ -37,7 +388,7 @@ class Artifact:
37
388
  A string with the format {collection}:{alias}. Before the artifact is saved,
38
389
  contains only the name since the version is not yet known.
39
390
  """
40
- raise NotImplementedError
391
+ return self._name
41
392
 
42
393
  @property
43
394
  def qualified_name(self) -> str:
@@ -46,21 +397,27 @@ class Artifact:
46
397
 
47
398
  @property
48
399
  def version(self) -> str:
49
- """The artifact's version in its secondary (portfolio) collection.
50
-
51
- A string with the format "v{number}".
52
- """
53
- raise NotImplementedError
400
+ """The artifact's version in its secondary (portfolio) collection."""
401
+ if self._state == ArtifactState.PENDING:
402
+ raise ArtifactNotLoggedError(self, "version")
403
+ assert self._version is not None
404
+ return self._version
54
405
 
55
406
  @property
56
407
  def source_entity(self) -> str:
57
408
  """The name of the entity of the primary (sequence) artifact collection."""
58
- raise NotImplementedError
409
+ if self._state == ArtifactState.PENDING:
410
+ raise ArtifactNotLoggedError(self, "source_entity")
411
+ assert self._source_entity is not None
412
+ return self._source_entity
59
413
 
60
414
  @property
61
415
  def source_project(self) -> str:
62
416
  """The name of the project of the primary (sequence) artifact collection."""
63
- raise NotImplementedError
417
+ if self._state == ArtifactState.PENDING:
418
+ raise ArtifactNotLoggedError(self, "source_project")
419
+ assert self._source_project is not None
420
+ return self._source_project
64
421
 
65
422
  @property
66
423
  def source_name(self) -> str:
@@ -69,12 +426,12 @@ class Artifact:
69
426
  A string with the format {collection}:{alias}. Before the artifact is saved,
70
427
  contains only the name since the version is not yet known.
71
428
  """
72
- raise NotImplementedError
429
+ return self._source_name
73
430
 
74
431
  @property
75
432
  def source_qualified_name(self) -> str:
76
433
  """The entity/project/name of the primary (sequence) collection."""
77
- return f"{self.entity}/{self.project}/{self.name}"
434
+ return f"{self.source_entity}/{self.source_project}/{self.source_name}"
78
435
 
79
436
  @property
80
437
  def source_version(self) -> str:
@@ -82,66 +439,26 @@ class Artifact:
82
439
 
83
440
  A string with the format "v{number}".
84
441
  """
85
- raise NotImplementedError
442
+ if self._state == ArtifactState.PENDING:
443
+ raise ArtifactNotLoggedError(self, "source_version")
444
+ assert self._source_version is not None
445
+ return self._source_version
86
446
 
87
447
  @property
88
448
  def type(self) -> str:
89
449
  """The artifact's type."""
90
- raise NotImplementedError
91
-
92
- @property
93
- def manifest(self) -> "ArtifactManifest":
94
- """The artifact's manifest.
95
-
96
- The manifest lists all of its contents, and can't be changed once the artifact
97
- has been logged.
98
- """
99
- raise NotImplementedError
100
-
101
- @property
102
- def digest(self) -> str:
103
- """The logical digest of the artifact.
104
-
105
- The digest is the checksum of the artifact's contents. If an artifact has the
106
- same digest as the current `latest` version, then `log_artifact` is a no-op.
107
- """
108
- raise NotImplementedError
109
-
110
- @property
111
- def state(self) -> str:
112
- """The status of the artifact. One of: "PENDING", "COMMITTED", or "DELETED"."""
113
- raise NotImplementedError
114
-
115
- @property
116
- def size(self) -> int:
117
- """The total size of the artifact in bytes.
118
-
119
- Returns:
120
- (int): The size in bytes of the artifact. Includes any references tracked by
121
- this artifact.
122
- """
123
- raise NotImplementedError
124
-
125
- @property
126
- def commit_hash(self) -> str:
127
- """The hash returned when this artifact was committed.
128
-
129
- Returns:
130
- (str): The artifact's commit hash which is used in http URLs.
131
- """
132
- raise NotImplementedError
450
+ return self._type
133
451
 
134
452
  @property
135
453
  def description(self) -> Optional[str]:
136
454
  """The artifact description.
137
455
 
138
- Returns:
139
- (str): Free text that offers a user-set description of the artifact.
456
+ Free text that offers a user-set description of the artifact.
140
457
  """
141
- raise NotImplementedError
458
+ return self._description
142
459
 
143
460
  @description.setter
144
- def description(self, desc: Optional[str]) -> None:
461
+ def description(self, description: Optional[str]) -> None:
145
462
  """Set the description of the artifact.
146
463
 
147
464
  The description is markdown rendered in the UI, so this is a good place to put
@@ -150,16 +467,15 @@ class Artifact:
150
467
  Arguments:
151
468
  description: Free text that offers a description of the artifact.
152
469
  """
153
- raise NotImplementedError
470
+ self._description = description
154
471
 
155
472
  @property
156
473
  def metadata(self) -> dict:
157
474
  """User-defined artifact metadata.
158
475
 
159
- Returns:
160
- (dict): Structured data associated with the artifact.
476
+ Structured data associated with the artifact.
161
477
  """
162
- raise NotImplementedError
478
+ return self._metadata
163
479
 
164
480
  @metadata.setter
165
481
  def metadata(self, metadata: dict) -> None:
@@ -171,9 +487,9 @@ class Artifact:
171
487
  Note: There is currently a limit of 100 total keys.
172
488
 
173
489
  Arguments:
174
- metadata: (dict) Structured data associated with the artifact.
490
+ metadata: Structured data associated with the artifact.
175
491
  """
176
- raise NotImplementedError
492
+ self._metadata = self._normalize_metadata(metadata)
177
493
 
178
494
  @property
179
495
  def aliases(self) -> List[str]:
@@ -181,122 +497,637 @@ class Artifact:
181
497
 
182
498
  The list is mutable and calling `save()` will persist all alias changes.
183
499
  """
184
- raise NotImplementedError
500
+ if self._state == ArtifactState.PENDING:
501
+ raise ArtifactNotLoggedError(self, "aliases")
502
+ return self._aliases
185
503
 
186
504
  @aliases.setter
187
505
  def aliases(self, aliases: List[str]) -> None:
188
506
  """Set the aliases associated with this artifact."""
189
- raise NotImplementedError
507
+ if self._state == ArtifactState.PENDING:
508
+ raise ArtifactNotLoggedError(self, "aliases")
190
509
 
191
- def used_by(self) -> List["wandb.apis.public.Run"]:
192
- """Get a list of the runs that have used this artifact."""
193
- raise NotImplementedError
510
+ if any(char in alias for alias in aliases for char in ["/", ":"]):
511
+ raise ValueError(
512
+ "Aliases must not contain any of the following characters: /, :"
513
+ )
514
+ self._aliases = aliases
194
515
 
195
- def logged_by(self) -> Optional["wandb.apis.public.Run"]:
196
- """Get the run that first logged this artifact."""
197
- raise NotImplementedError
516
@property
def distributed_id(self) -> Optional[str]:
    """The stored distributed ID, or `None` if none has been set."""
    current = self._distributed_id
    return current
198
519
 
199
- @contextlib.contextmanager
200
- def new_file(
201
- self, name: str, mode: str = "w", encoding: Optional[str] = None
202
- ) -> Generator[IO, None, None]:
203
- """Open a new temporary file that will be automatically added to the artifact.
520
@distributed_id.setter
def distributed_id(self, distributed_id: Optional[str]) -> None:
    """Store `distributed_id` on the artifact without validation."""
    new_value = distributed_id
    self._distributed_id = new_value
523
+
524
@property
def incremental(self) -> bool:
    """The `incremental` flag this artifact was constructed with."""
    flag = self._incremental
    return flag
527
+
528
@property
def use_as(self) -> Optional[str]:
    """The `use_as` value this artifact was constructed with, if any."""
    label = self._use_as
    return label
531
+
532
@property
def state(self) -> str:
    """The status of the artifact. One of: "PENDING", "COMMITTED", or "DELETED"."""
    # The state is held as an enum member; callers get its plain value.
    current = self._state
    return current.value
536
+
537
@property
def manifest(self) -> ArtifactManifest:
    """The artifact's manifest.

    The manifest lists all of its contents, and can't be changed once the artifact
    has been logged. When no manifest is held locally, it is fetched with the
    `ArtifactManifest` query and loaded from the returned direct URL.
    """
    if self._manifest is not None:
        return self._manifest

    query = gql(
        """
        query ArtifactManifest(
            $entityName: String!,
            $projectName: String!,
            $name: String!
        ) {
            project(entityName: $entityName, name: $projectName) {
                artifact(name: $name) {
                    currentManifest {
                        file {
                            directUrl
                        }
                    }
                }
            }
        }
        """
    )
    assert self._client is not None
    lookup = {
        "entityName": self._entity,
        "projectName": self._project,
        "name": self._name,
    }
    response = self._client.execute(query, variable_values=lookup)
    attrs = response["project"]["artifact"]
    manifest_url = attrs["currentManifest"]["file"]["directUrl"]
    self._load_manifest(manifest_url)
    # Loading is expected to have populated the manifest.
    assert self._manifest is not None
    return self._manifest
577
+
578
@property
def digest(self) -> str:
    """The logical digest of the artifact.

    The digest is the checksum of the artifact's contents. If an artifact has the
    same digest as the current `latest` version, then `log_artifact` is a no-op.
    The value is computed by the manifest.
    """
    manifest = self.manifest
    return manifest.digest()
586
+
587
@property
def size(self) -> int:
    """The total size of the artifact in bytes.

    Includes any references tracked by this artifact. Manifest entries whose
    size is `None` contribute nothing to the total.
    """
    entries = self.manifest.entries.values()
    return sum(entry.size for entry in entries if entry.size is not None)
598
+
599
@property
def commit_hash(self) -> str:
    """The hash returned when this artifact was committed.

    Raises:
        ArtifactNotLoggedError: if the artifact is still in the PENDING state.
    """
    still_pending = self._state == ArtifactState.PENDING
    if still_pending:
        raise ArtifactNotLoggedError(self, "commit_hash")
    value = self._commit_hash
    assert value is not None
    return value
606
+
607
@property
def file_count(self) -> int:
    """The number of files (including references).

    Raises:
        ArtifactNotLoggedError: if the artifact is still in the PENDING state.
    """
    still_pending = self._state == ArtifactState.PENDING
    if still_pending:
        raise ArtifactNotLoggedError(self, "file_count")
    count = self._file_count
    assert count is not None
    return count
614
+
615
@property
def created_at(self) -> str:
    """The time at which the artifact was created.

    Raises:
        ArtifactNotLoggedError: if the artifact is still in the PENDING state.
    """
    still_pending = self._state == ArtifactState.PENDING
    if still_pending:
        raise ArtifactNotLoggedError(self, "created_at")
    timestamp = self._created_at
    assert timestamp is not None
    return timestamp
622
+
623
@property
def updated_at(self) -> str:
    """The time at which the artifact was last updated.

    Falls back to the creation time when no update time is recorded.

    Raises:
        ArtifactNotLoggedError: if the artifact is still in the PENDING state.
    """
    if self._state == ArtifactState.PENDING:
        # Name this property in the error, not `created_at`.
        raise ArtifactNotLoggedError(self, "updated_at")
    # The creation time is the fallback value, so it must always be present.
    assert self._created_at is not None
    return self._updated_at or self._created_at
630
+
631
+ # State management.
632
+
633
def finalize(self) -> None:
    """Mark this artifact as final, disallowing further modifications.

    This happens automatically when calling `log_artifact`. Once the flag is
    set, `_ensure_can_add` raises `ArtifactFinalizedError`.
    """
    self._final = True
639
+
640
+ def _ensure_can_add(self) -> None:
641
+ if self._final:
642
+ raise ArtifactFinalizedError(artifact=self)
643
+
644
def is_draft(self) -> bool:
    """Whether the artifact is a draft, i.e. it hasn't been saved yet.

    An artifact counts as a draft while its state is PENDING.
    """
    pending = ArtifactState.PENDING
    return self._state == pending
647
+
648
+ def _is_draft_save_started(self) -> bool:
649
+ return self._save_future is not None
650
+
651
+ def save(
652
+ self,
653
+ project: Optional[str] = None,
654
+ settings: Optional["wandb.wandb_sdk.wandb_settings.Settings"] = None,
655
+ ) -> None:
656
+ """Persist any changes made to the artifact.
657
+
658
+ If currently in a run, that run will log this artifact. If not currently in a
659
+ run, a run of type "auto" will be created to track this artifact.
204
660
 
205
661
  Arguments:
206
- name: (str) The name of the new file being added to the artifact.
207
- mode: (str, optional) The mode in which to open the new file.
208
- encoding: (str, optional) The encoding in which to open the new file.
662
+ project: A project to use for the artifact in the case that a run is not
663
+ already in context
664
+ settings: A settings object to use when initializing an automatic run. Most
665
+ commonly used in testing harness.
666
+ """
667
+ if self._state != ArtifactState.PENDING:
668
+ return self._update()
669
+
670
+ if self._incremental:
671
+ with telemetry.context() as tel:
672
+ tel.feature.artifact_incremental = True
673
+
674
+ if wandb.run is None:
675
+ if settings is None:
676
+ settings = wandb.Settings(silent="true")
677
+ with wandb.init(project=project, job_type="auto", settings=settings) as run:
678
+ # redoing this here because in this branch we know we didn't
679
+ # have the run at the beginning of the method
680
+ if self._incremental:
681
+ with telemetry.context(run=run) as tel:
682
+ tel.feature.artifact_incremental = True
683
+ run.log_artifact(self)
684
+ else:
685
+ wandb.run.log_artifact(self)
686
+
687
def _set_save_future(
    self, save_future: "MessageFuture", client: RetryingClient
) -> None:
    """Attach the pending save future and the client to this artifact.

    Arguments:
        save_future: Future that `wait` reads the log-artifact response from.
        client: Client stored for queries issued after the save.
    """
    self._client = client
    self._save_future = save_future
692
+
693
def wait(self, timeout: Optional[int] = None) -> "Artifact":
    """Wait for this artifact to finish logging, if needed.

    Returns immediately when the artifact is no longer PENDING. Otherwise it
    blocks on the save future and then populates the artifact from the
    returned artifact ID.

    Arguments:
        timeout: Wait up to this long.

    Returns:
        This artifact.

    Raises:
        ArtifactNotLoggedError: if the artifact is PENDING and no save has
            been started.
        WaitTimeoutError: if the save future yields no result.
        ValueError: if the log-artifact response carries an error message.
    """
    if self._state != ArtifactState.PENDING:
        return self

    future = self._save_future
    if future is None:
        raise ArtifactNotLoggedError(self, "wait")

    outcome = future.get(timeout)
    if not outcome:
        raise WaitTimeoutError(
            "Artifact upload wait timed out, failed to fetch Artifact response"
        )

    log_response = outcome.response.log_artifact_response
    if log_response.error_message:
        raise ValueError(log_response.error_message)

    self._populate_after_save(log_response.artifact_id)
    return self
712
+
713
+ def _populate_after_save(self, artifact_id: str) -> None:
714
+ query = gql(
715
+ """
716
+ query ArtifactByIDShort($id: ID!) {
717
+ artifact(id: $id) {
718
+ artifactSequence {
719
+ project {
720
+ entityName
721
+ name
722
+ }
723
+ name
724
+ }
725
+ versionIndex
726
+ aliases {
727
+ artifactCollectionName
728
+ alias
729
+ }
730
+ state
731
+ currentManifest {
732
+ file {
733
+ directUrl
734
+ }
735
+ }
736
+ commitHash
737
+ fileCount
738
+ createdAt
739
+ updatedAt
740
+ }
741
+ }
742
+ """
743
+ )
744
+ assert self._client is not None
745
+ response = self._client.execute(
746
+ query,
747
+ variable_values={"id": artifact_id},
748
+ )
749
+ attrs = response.get("artifact")
750
+ if attrs is None:
751
+ raise ValueError(f"Unable to fetch artifact with id {artifact_id}")
752
+ self._id = artifact_id
753
+ self._entity = attrs["artifactSequence"]["project"]["entityName"]
754
+ self._project = attrs["artifactSequence"]["project"]["name"]
755
+ self._name = "{}:v{}".format(
756
+ attrs["artifactSequence"]["name"], attrs["versionIndex"]
757
+ )
758
+ self._version = "v{}".format(attrs["versionIndex"])
759
+ self._source_entity = self._entity
760
+ self._source_project = self._project
761
+ self._source_name = self._name
762
+ self._source_version = self._version
763
+ self._aliases = [
764
+ alias["alias"]
765
+ for alias in attrs.get("aliases", [])
766
+ if alias["artifactCollectionName"] == self._name.split(":")[0]
767
+ and not util.alias_is_version_index(alias["alias"])
768
+ ]
769
+ self._state = ArtifactState(attrs["state"])
770
+ with requests.get(attrs["currentManifest"]["file"]["directUrl"]) as request:
771
+ request.raise_for_status()
772
+ self._manifest = ArtifactManifest.from_manifest_json(
773
+ json.loads(util.ensure_text(request.content))
774
+ )
775
+ self._commit_hash = attrs["commitHash"]
776
+ self._file_count = attrs["fileCount"]
777
+ self._created_at = attrs["createdAt"]
778
+ self._updated_at = attrs["updatedAt"]
779
+
780
+ @normalize_exceptions
781
+ def _update(self) -> None:
782
+ """Persists artifact changes to the wandb backend."""
783
+ aliases = None
784
+ introspect_query = gql(
785
+ """
786
+ query ProbeServerAddAliasesInput {
787
+ AddAliasesInputInfoType: __type(name: "AddAliasesInput") {
788
+ name
789
+ inputFields {
790
+ name
791
+ }
792
+ }
793
+ }
794
+ """
795
+ )
796
+ assert self._client is not None
797
+ response = self._client.execute(introspect_query)
798
+ if response.get("AddAliasesInputInfoType"): # wandb backend version >= 0.13.0
799
+ aliases_to_add = set(self._aliases) - set(self._saved_aliases)
800
+ aliases_to_delete = set(self._saved_aliases) - set(self._aliases)
801
+ if len(aliases_to_add) > 0:
802
+ add_mutation = gql(
803
+ """
804
+ mutation addAliases(
805
+ $artifactID: ID!,
806
+ $aliases: [ArtifactCollectionAliasInput!]!,
807
+ ) {
808
+ addAliases(
809
+ input: {artifactID: $artifactID, aliases: $aliases}
810
+ ) {
811
+ success
812
+ }
813
+ }
814
+ """
815
+ )
816
+ assert self._client is not None
817
+ self._client.execute(
818
+ add_mutation,
819
+ variable_values={
820
+ "artifactID": self.id,
821
+ "aliases": [
822
+ {
823
+ "entityName": self._entity,
824
+ "projectName": self._project,
825
+ "artifactCollectionName": self._name.split(":")[0],
826
+ "alias": alias,
827
+ }
828
+ for alias in aliases_to_add
829
+ ],
830
+ },
831
+ )
832
+ if len(aliases_to_delete) > 0:
833
+ delete_mutation = gql(
834
+ """
835
+ mutation deleteAliases(
836
+ $artifactID: ID!,
837
+ $aliases: [ArtifactCollectionAliasInput!]!,
838
+ ) {
839
+ deleteAliases(
840
+ input: {artifactID: $artifactID, aliases: $aliases}
841
+ ) {
842
+ success
843
+ }
844
+ }
845
+ """
846
+ )
847
+ assert self._client is not None
848
+ self._client.execute(
849
+ delete_mutation,
850
+ variable_values={
851
+ "artifactID": self.id,
852
+ "aliases": [
853
+ {
854
+ "entityName": self._entity,
855
+ "projectName": self._project,
856
+ "artifactCollectionName": self._name.split(":")[0],
857
+ "alias": alias,
858
+ }
859
+ for alias in aliases_to_delete
860
+ ],
861
+ },
862
+ )
863
+ self._saved_aliases = copy(self._aliases)
864
+ else: # wandb backend version < 0.13.0
865
+ aliases = [
866
+ {
867
+ "artifactCollectionName": self._name.split(":")[0],
868
+ "alias": alias,
869
+ }
870
+ for alias in self._aliases
871
+ ]
872
+
873
+ mutation = gql(
874
+ """
875
+ mutation updateArtifact(
876
+ $artifactID: ID!,
877
+ $description: String,
878
+ $metadata: JSONString,
879
+ $aliases: [ArtifactAliasInput!]
880
+ ) {
881
+ updateArtifact(
882
+ input: {
883
+ artifactID: $artifactID,
884
+ description: $description,
885
+ metadata: $metadata,
886
+ aliases: $aliases
887
+ }
888
+ ) {
889
+ artifact {
890
+ id
891
+ }
892
+ }
893
+ }
894
+ """
895
+ )
896
+ assert self._client is not None
897
+ self._client.execute(
898
+ mutation,
899
+ variable_values={
900
+ "artifactID": self.id,
901
+ "description": self.description,
902
+ "metadata": util.json_dumps_safer(self.metadata),
903
+ "aliases": aliases,
904
+ },
905
+ )
906
+
907
+ # Adding, removing, getting entries.
908
+
909
+ def __getitem__(self, name: str) -> Optional[data_types.WBValue]:
910
+ """Get the WBValue object located at the artifact relative `name`.
911
+
912
+ Arguments:
913
+ name: The artifact relative name to get
914
+
915
+ Raises:
916
+ ArtifactNotLoggedError: if the artifact isn't logged or the run is offline
209
917
 
210
918
  Examples:
919
+ Basic usage:
211
920
  ```
212
- artifact = wandb.Artifact('my_data', type='dataset')
213
- with artifact.new_file('hello.txt') as f:
214
- f.write('hello!')
921
+ artifact = wandb.Artifact("my_table", type="dataset")
922
+ table = wandb.Table(
923
+ columns=["a", "b", "c"],
924
+ data=[(i, i * 2, 2**i) for i in range(10)]
925
+ )
926
+ artifact["my_table"] = table
927
+
215
928
  wandb.log_artifact(artifact)
216
929
  ```
217
930
 
931
+ Retrieving an object:
932
+ ```
933
+ artifact = wandb.use_artifact("my_table:latest")
934
+ table = artifact["my_table"]
935
+ ```
936
+ """
937
+ return self.get(name)
938
+
939
+ def __setitem__(self, name: str, item: data_types.WBValue) -> ArtifactManifestEntry:
940
+ """Add `item` to the artifact at path `name`.
941
+
942
+ Arguments:
943
+ name: The path within the artifact to add the object.
944
+ item: The object to add.
945
+
218
946
  Returns:
219
- (file): A new file object that can be written to. Upon closing,
220
- the file will be automatically added to the artifact.
947
+ The added manifest entry
221
948
 
222
949
  Raises:
223
950
  ArtifactFinalizedError: if the artifact has already been finalized.
951
+
952
+ Examples:
953
+ Basic usage:
954
+ ```
955
+ artifact = wandb.Artifact("my_table", type="dataset")
956
+ table = wandb.Table(
957
+ columns=["a", "b", "c"],
958
+ data=[(i, i * 2, 2**i) for i in range(10)]
959
+ )
960
+ artifact["my_table"] = table
961
+
962
+ wandb.log_artifact(artifact)
963
+ ```
964
+
965
+ Retrieving an object:
966
+ ```
967
+ artifact = wandb.use_artifact("my_table:latest")
968
+ table = artifact["my_table"]
969
+ ```
970
+ """
971
+ return self.add(item, name)
972
+
973
+ @contextlib.contextmanager
974
+ def new_file(
975
+ self, name: str, mode: str = "w", encoding: Optional[str] = None
976
+ ) -> Generator[IO, None, None]:
977
+ """Open a new temporary file that will be automatically added to the artifact.
978
+
979
+ Arguments:
980
+ name: The name of the new file being added to the artifact.
981
+ mode: The mode in which to open the new file.
982
+ encoding: The encoding in which to open the new file.
983
+
984
+ Returns:
985
+ A new file object that can be written to. Upon closing, the file will be
986
+ automatically added to the artifact.
987
+
988
+ Raises:
989
+ ArtifactFinalizedError: if the artifact has already been finalized.
990
+
991
+ Examples:
992
+ ```
993
+ artifact = wandb.Artifact("my_data", type="dataset")
994
+ with artifact.new_file("hello.txt") as f:
995
+ f.write("hello!")
996
+ wandb.log_artifact(artifact)
997
+ ```
224
998
  """
225
- raise NotImplementedError
999
+ self._ensure_can_add()
1000
+ if self._tmp_dir is None:
1001
+ self._tmp_dir = tempfile.TemporaryDirectory()
1002
+ path = os.path.join(self._tmp_dir.name, name.lstrip("/"))
1003
+ if os.path.exists(path):
1004
+ raise ValueError(f"File with name {name!r} already exists at {path!r}")
1005
+
1006
+ filesystem.mkdir_exists_ok(os.path.dirname(path))
1007
+ try:
1008
+ with util.fsync_open(path, mode, encoding) as f:
1009
+ yield f
1010
+ except UnicodeEncodeError as e:
1011
+ termerror(
1012
+ f"Failed to open the provided file (UnicodeEncodeError: {e}). Please "
1013
+ f"provide the proper encoding."
1014
+ )
1015
+ raise e
1016
+
1017
+ self.add_file(path, name=name)
226
1018
 
227
1019
  def add_file(
228
1020
  self,
229
1021
  local_path: str,
230
1022
  name: Optional[str] = None,
231
1023
  is_tmp: Optional[bool] = False,
232
- ) -> "ArtifactManifestEntry":
1024
+ ) -> ArtifactManifestEntry:
233
1025
  """Add a local file to the artifact.
234
1026
 
235
1027
  Arguments:
236
- local_path: (str) The path to the file being added.
237
- name: (str, optional) The path within the artifact to use for the file being
238
- added. Defaults to the basename of the file.
239
- is_tmp: (bool, optional) If true, then the file is renamed deterministically
240
- to avoid collisions. (default: False)
1028
+ local_path: The path to the file being added.
1029
+ name: The path within the artifact to use for the file being added. Defaults
1030
+ to the basename of the file.
1031
+ is_tmp: If true, then the file is renamed deterministically to avoid
1032
+ collisions.
1033
+
1034
+ Returns:
1035
+ The added manifest entry
1036
+
1037
+ Raises:
1038
+ ArtifactFinalizedError: if the artifact has already been finalized
241
1039
 
242
1040
  Examples:
243
1041
  Add a file without an explicit name:
244
1042
  ```
245
1043
  # Add as `file.txt'
246
- artifact.add_file('path/to/file.txt')
1044
+ artifact.add_file("path/to/file.txt")
247
1045
  ```
248
1046
 
249
1047
  Add a file with an explicit name:
250
1048
  ```
251
1049
  # Add as 'new/path/file.txt'
252
- artifact.add_file('path/to/file.txt', name='new/path/file.txt')
1050
+ artifact.add_file("path/to/file.txt", name="new/path/file.txt")
253
1051
  ```
1052
+ """
1053
+ self._ensure_can_add()
1054
+ if not os.path.isfile(local_path):
1055
+ raise ValueError("Path is not a file: %s" % local_path)
254
1056
 
255
- Raises:
256
- ArtifactFinalizedError: if the artifact has already been finalized.
1057
+ name = LogicalPath(name or os.path.basename(local_path))
1058
+ digest = md5_file_b64(local_path)
257
1059
 
258
- Returns:
259
- ArtifactManifestEntry: the added manifest entry
1060
+ if is_tmp:
1061
+ file_path, file_name = os.path.split(name)
1062
+ file_name_parts = file_name.split(".")
1063
+ file_name_parts[0] = b64_to_hex_id(digest)[:20]
1064
+ name = os.path.join(file_path, ".".join(file_name_parts))
260
1065
 
261
- """
262
- raise NotImplementedError
1066
+ return self._add_local_file(name, local_path, digest=digest)
263
1067
 
264
1068
  def add_dir(self, local_path: str, name: Optional[str] = None) -> None:
265
1069
  """Add a local directory to the artifact.
266
1070
 
267
1071
  Arguments:
268
- local_path: (str) The path to the directory being added.
269
- name: (str, optional) The path within the artifact to use for the directory
270
- being added. Defaults to the root of the artifact.
1072
+ local_path: The path to the directory being added.
1073
+ name: The path within the artifact to use for the directory being added.
1074
+ Defaults to the root of the artifact.
1075
+
1076
+ Raises:
1077
+ ArtifactFinalizedError: if the artifact has already been finalized
271
1078
 
272
1079
  Examples:
273
1080
  Add a directory without an explicit name:
274
1081
  ```
275
1082
  # All files in `my_dir/` are added at the root of the artifact.
276
- artifact.add_dir('my_dir/')
1083
+ artifact.add_dir("my_dir/")
277
1084
  ```
278
1085
 
279
1086
  Add a directory and name it explicitly:
280
1087
  ```
281
1088
  # All files in `my_dir/` are added under `destination/`.
282
- artifact.add_dir('my_dir/', name='destination')
1089
+ artifact.add_dir("my_dir/", name="destination")
283
1090
  ```
284
-
285
- Raises:
286
- ArtifactFinalizedError: if the artifact has already been finalized.
287
-
288
- Returns:
289
- None
290
1091
  """
291
- raise NotImplementedError
1092
+ self._ensure_can_add()
1093
+ if not os.path.isdir(local_path):
1094
+ raise ValueError("Path is not a directory: %s" % local_path)
1095
+
1096
+ termlog(
1097
+ "Adding directory to artifact (%s)... "
1098
+ % os.path.join(".", os.path.normpath(local_path)),
1099
+ newline=False,
1100
+ )
1101
+ start_time = time.time()
1102
+
1103
+ paths = []
1104
+ for dirpath, _, filenames in os.walk(local_path, followlinks=True):
1105
+ for fname in filenames:
1106
+ physical_path = os.path.join(dirpath, fname)
1107
+ logical_path = os.path.relpath(physical_path, start=local_path)
1108
+ if name is not None:
1109
+ logical_path = os.path.join(name, logical_path)
1110
+ paths.append((logical_path, physical_path))
1111
+
1112
+ def add_manifest_file(log_phy_path: Tuple[str, str]) -> None:
1113
+ logical_path, physical_path = log_phy_path
1114
+ self._add_local_file(logical_path, physical_path)
1115
+
1116
+ num_threads = 8
1117
+ pool = multiprocessing.dummy.Pool(num_threads)
1118
+ pool.map(add_manifest_file, paths)
1119
+ pool.close()
1120
+ pool.join()
1121
+
1122
+ termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)
292
1123
 
293
1124
  def add_reference(
294
1125
  self,
295
- uri: Union["ArtifactManifestEntry", str],
1126
+ uri: Union[ArtifactManifestEntry, str],
296
1127
  name: Optional[StrPath] = None,
297
1128
  checksum: bool = True,
298
1129
  max_objects: Optional[int] = None,
299
- ) -> Sequence["ArtifactManifestEntry"]:
1130
+ ) -> Sequence[ArtifactManifestEntry]:
300
1131
  """Add a reference denoted by a URI to the artifact.
301
1132
 
302
1133
  Unlike adding files or directories, references are NOT uploaded to W&B. However,
@@ -322,70 +1153,95 @@ class Artifact:
322
1153
  blank.
323
1154
 
324
1155
  Arguments:
325
- uri: (str) The URI path of the reference to add. Can be an object returned
326
- from Artifact.get_path to store a reference to another artifact's entry.
327
- name: (str) The path within the artifact to place the contents of this
328
- reference
329
- checksum: (bool, optional) Whether or not to checksum the resource(s)
330
- located at the reference URI. Checksumming is strongly recommended as it
331
- enables automatic integrity validation, however it can be disabled to
332
- speed up artifact creation. (default: True)
333
- max_objects: (int, optional) The maximum number of objects to consider when
334
- adding a reference that points to directory or bucket store prefix. For
335
- S3 and GCS, this limit is 10,000 by default but is uncapped for other
336
- URI schemes. (default: None)
1156
+ uri: The URI path of the reference to add. Can be an object returned from
1157
+ Artifact.get_path to store a reference to another artifact's entry.
1158
+ name: The path within the artifact to place the contents of this reference
1159
+ checksum: Whether or not to checksum the resource(s) located at the
1160
+ reference URI. Checksumming is strongly recommended as it enables
1161
+ automatic integrity validation, however it can be disabled to speed up
1162
+ artifact creation. (default: True)
1163
+ max_objects: The maximum number of objects to consider when adding a
1164
+ reference that points to directory or bucket store prefix. For S3 and
1165
+ GCS, this limit is 10,000 by default but is uncapped for other URI
1166
+ schemes. (default: None)
1167
+
1168
+ Returns:
1169
+ The added manifest entries.
337
1170
 
338
1171
  Raises:
339
1172
  ArtifactFinalizedError: if the artifact has already been finalized.
340
1173
 
341
- Returns:
342
- List["ArtifactManifestEntry"]: The added manifest entries.
343
-
344
1174
  Examples:
345
- Add an HTTP link:
346
- ```python
347
- # Adds `file.txt` to the root of the artifact as a reference.
348
- artifact.add_reference("http://myserver.com/file.txt")
349
- ```
1175
+ Add an HTTP link:
1176
+ ```python
1177
+ # Adds `file.txt` to the root of the artifact as a reference.
1178
+ artifact.add_reference("http://myserver.com/file.txt")
1179
+ ```
350
1180
 
351
- Add an S3 prefix without an explicit name:
352
- ```python
353
- # All objects under `prefix/` will be added at the root of the artifact.
354
- artifact.add_reference("s3://mybucket/prefix")
355
- ```
1181
+ Add an S3 prefix without an explicit name:
1182
+ ```python
1183
+ # All objects under `prefix/` will be added at the root of the artifact.
1184
+ artifact.add_reference("s3://mybucket/prefix")
1185
+ ```
356
1186
 
357
- Add a GCS prefix with an explicit name:
358
- ```python
359
- # All objects under `prefix/` will be added under `path/` at the artifact root.
360
- artifact.add_reference("gs://mybucket/prefix", name="path")
361
- ```
1187
+ Add a GCS prefix with an explicit name:
1188
+ ```python
1189
+ # All objects under `prefix/` will be added under `path/` at the artifact
1190
+ # root.
1191
+ artifact.add_reference("gs://mybucket/prefix", name="path")
1192
+ ```
362
1193
  """
363
- raise NotImplementedError
364
-
365
- def add(self, obj: WBValue, name: StrPath) -> "ArtifactManifestEntry":
1194
+ self._ensure_can_add()
1195
+ if name is not None:
1196
+ name = LogicalPath(name)
1197
+
1198
+ # This is a bit of a hack, we want to check if the uri is a of the type
1199
+ # ArtifactManifestEntry. If so, then recover the reference URL.
1200
+ if isinstance(uri, ArtifactManifestEntry):
1201
+ uri_str = uri.ref_url()
1202
+ elif isinstance(uri, str):
1203
+ uri_str = uri
1204
+ url = urlparse(str(uri_str))
1205
+ if not url.scheme:
1206
+ raise ValueError(
1207
+ "References must be URIs. To reference a local file, use file://"
1208
+ )
1209
+
1210
+ manifest_entries = self._storage_policy.store_reference(
1211
+ self,
1212
+ URIStr(uri_str),
1213
+ name=name,
1214
+ checksum=checksum,
1215
+ max_objects=max_objects,
1216
+ )
1217
+ for entry in manifest_entries:
1218
+ self.manifest.add_entry(entry)
1219
+
1220
+ return manifest_entries
1221
+
1222
+ def add(self, obj: data_types.WBValue, name: StrPath) -> ArtifactManifestEntry:
366
1223
  """Add wandb.WBValue `obj` to the artifact.
367
1224
 
368
- ```
369
- obj = artifact.get(name)
370
- ```
371
-
372
1225
  Arguments:
373
- obj: (wandb.WBValue) The object to add. Currently support one of
374
- Bokeh, JoinedTable, PartitionedTable, Table, Classes, ImageMask,
375
- BoundingBoxes2D, Audio, Image, Video, Html, Object3D
376
- name: (str) The path within the artifact to add the object.
1226
+ obj: The object to add. Currently support one of Bokeh, JoinedTable,
1227
+ PartitionedTable, Table, Classes, ImageMask, BoundingBoxes2D, Audio,
1228
+ Image, Video, Html, Object3D
1229
+ name: The path within the artifact to add the object.
377
1230
 
378
1231
  Returns:
379
- ArtifactManifestEntry: the added manifest entry
1232
+ The added manifest entry
380
1233
 
381
1234
  Raises:
382
- ArtifactFinalizedError: if the artifact has already been finalized.
1235
+ ArtifactFinalizedError: if the artifact has already been finalized
383
1236
 
384
1237
  Examples:
385
- Basic usage
1238
+ Basic usage:
386
1239
  ```
387
- artifact = wandb.Artifact('my_table', 'dataset')
388
- table = wandb.Table(columns=["a", "b", "c"], data=[[i, i*2, 2**i]])
1240
+ artifact = wandb.Artifact("my_table", type="dataset")
1241
+ table = wandb.Table(
1242
+ columns=["a", "b", "c"],
1243
+ data=[(i, i * 2, 2**i) for i in range(10)]
1244
+ )
389
1245
  artifact.add(table, "my_table")
390
1246
 
391
1247
  wandb.log_artifact(artifact)
@@ -393,87 +1249,294 @@ class Artifact:
393
1249
 
394
1250
  Retrieve an object:
395
1251
  ```
396
- artifact = wandb.use_artifact('my_table:latest')
1252
+ artifact = wandb.use_artifact("my_table:latest")
397
1253
  table = artifact.get("my_table")
398
1254
  ```
399
1255
  """
400
- raise NotImplementedError
401
-
402
- def remove(self, item: Union[str, "os.PathLike", "ArtifactManifestEntry"]) -> None:
1256
+ self._ensure_can_add()
1257
+ name = LogicalPath(name)
1258
+
1259
+ # This is a "hack" to automatically rename tables added to
1260
+ # the wandb /media/tables directory to their sha-based name.
1261
+ # TODO: figure out a more appropriate convention.
1262
+ is_tmp_name = name.startswith("media/tables")
1263
+
1264
+ # Validate that the object is one of the correct wandb.Media types
1265
+ # TODO: move this to checking subclass of wandb.Media once all are
1266
+ # generally supported
1267
+ allowed_types = [
1268
+ data_types.Bokeh,
1269
+ data_types.JoinedTable,
1270
+ data_types.PartitionedTable,
1271
+ data_types.Table,
1272
+ data_types.Classes,
1273
+ data_types.ImageMask,
1274
+ data_types.BoundingBoxes2D,
1275
+ data_types.Audio,
1276
+ data_types.Image,
1277
+ data_types.Video,
1278
+ data_types.Html,
1279
+ data_types.Object3D,
1280
+ data_types.Molecule,
1281
+ data_types._SavedModel,
1282
+ ]
1283
+
1284
+ if not any(isinstance(obj, t) for t in allowed_types):
1285
+ raise ValueError(
1286
+ "Found object of type {}, expected one of {}.".format(
1287
+ obj.__class__, allowed_types
1288
+ )
1289
+ )
1290
+
1291
+ obj_id = id(obj)
1292
+ if obj_id in self._added_objs:
1293
+ return self._added_objs[obj_id][1]
1294
+
1295
+ # If the object is coming from another artifact, save it as a reference
1296
+ ref_path = obj._get_artifact_entry_ref_url()
1297
+ if ref_path is not None:
1298
+ return self.add_reference(ref_path, type(obj).with_suffix(name))[0]
1299
+
1300
+ val = obj.to_json(self)
1301
+ name = obj.with_suffix(name)
1302
+ entry = self.manifest.get_entry_by_path(name)
1303
+ if entry is not None:
1304
+ return entry
1305
+
1306
+ def do_write(f: IO) -> None:
1307
+ import json
1308
+
1309
+ # TODO: Do we need to open with utf-8 codec?
1310
+ f.write(json.dumps(val, sort_keys=True))
1311
+
1312
+ if is_tmp_name:
1313
+ file_path = os.path.join(self._TMP_DIR.name, str(id(self)), name)
1314
+ folder_path, _ = os.path.split(file_path)
1315
+ if not os.path.exists(folder_path):
1316
+ os.makedirs(folder_path)
1317
+ with open(file_path, "w") as tmp_f:
1318
+ do_write(tmp_f)
1319
+ else:
1320
+ with self.new_file(name) as f:
1321
+ file_path = f.name
1322
+ do_write(f)
1323
+
1324
+ # Note, we add the file from our temp directory.
1325
+ # It will be added again later on finalize, but succeed since
1326
+ # the checksum should match
1327
+ entry = self.add_file(file_path, name, is_tmp_name)
1328
+ # We store a reference to the obj so that its id doesn't get reused.
1329
+ self._added_objs[obj_id] = (obj, entry)
1330
+ if obj._artifact_target is None:
1331
+ obj._set_artifact_target(self, entry.path)
1332
+
1333
+ if is_tmp_name:
1334
+ if os.path.exists(file_path):
1335
+ os.remove(file_path)
1336
+
1337
+ return entry
1338
+
1339
+ def _add_local_file(
1340
+ self, name: StrPath, path: StrPath, digest: Optional[B64MD5] = None
1341
+ ) -> ArtifactManifestEntry:
1342
+ with tempfile.NamedTemporaryFile(dir=get_staging_dir(), delete=False) as f:
1343
+ staging_path = f.name
1344
+ shutil.copyfile(path, staging_path)
1345
+ os.chmod(staging_path, 0o400)
1346
+
1347
+ entry = ArtifactManifestEntry(
1348
+ path=name,
1349
+ digest=digest or md5_file_b64(staging_path),
1350
+ size=os.path.getsize(staging_path),
1351
+ local_path=staging_path,
1352
+ )
1353
+
1354
+ self.manifest.add_entry(entry)
1355
+ self._added_local_paths[os.fspath(path)] = entry
1356
+ return entry
1357
+
1358
+ def remove(self, item: Union[StrPath, "ArtifactManifestEntry"]) -> None:
403
1359
  """Remove an item from the artifact.
404
1360
 
405
1361
  Arguments:
406
- item: (str, os.PathLike, ArtifactManifestEntry) the item to remove. Can be a
407
- specific manifest entry or the name of an artifact-relative path. If the
408
- item matches a directory all items in that directory will be removed.
1362
+ item: the item to remove. Can be a specific manifest entry or the name of an
1363
+ artifact-relative path. If the item matches a directory all items in
1364
+ that directory will be removed.
409
1365
 
410
1366
  Raises:
411
1367
  ArtifactFinalizedError: if the artifact has already been finalized.
412
1368
  FileNotFoundError: if the item isn't found in the artifact.
413
-
414
- Returns:
415
- None
416
1369
  """
417
- raise NotImplementedError
1370
+ self._ensure_can_add()
1371
+
1372
+ if isinstance(item, ArtifactManifestEntry):
1373
+ self.manifest.remove_entry(item)
1374
+ return
1375
+
1376
+ path = str(PurePosixPath(item))
1377
+ entry = self.manifest.get_entry_by_path(path)
1378
+ if entry:
1379
+ self.manifest.remove_entry(entry)
1380
+ return
418
1381
 
419
- def get_path(self, name: StrPath) -> "ArtifactManifestEntry":
420
- """Get the path to the file located at the artifact relative `name`.
1382
+ entries = self.manifest.get_entries_in_directory(path)
1383
+ if not entries:
1384
+ raise FileNotFoundError(f"No such file or directory: {path}")
1385
+ for entry in entries:
1386
+ self.manifest.remove_entry(entry)
1387
+
1388
+ def get_path(self, name: StrPath) -> ArtifactManifestEntry:
1389
+ """Get the entry with the given name.
421
1390
 
422
1391
  Arguments:
423
- name: (str) The artifact relative name to get
1392
+ name: The artifact relative name to get
424
1393
 
425
1394
  Raises:
426
1395
  ArtifactNotLoggedError: if the artifact isn't logged or the run is offline
1396
+ KeyError: if the artifact doesn't contain an entry with the given name
427
1397
 
428
1398
  Examples:
429
- Basic usage
1399
+ Basic usage:
430
1400
  ```
431
1401
  # Run logging the artifact
432
1402
  with wandb.init() as r:
433
- artifact = wandb.Artifact('my_dataset', type='dataset')
434
- artifact.add_file('path/to/file.txt')
1403
+ artifact = wandb.Artifact("my_dataset", type="dataset")
1404
+ artifact.add_file("path/to/file.txt")
435
1405
  wandb.log_artifact(artifact)
436
1406
 
437
1407
  # Run using the artifact
438
1408
  with wandb.init() as r:
439
- artifact = r.use_artifact('my_dataset:latest')
440
- path = artifact.get_path('file.txt')
1409
+ artifact = r.use_artifact("my_dataset:latest")
1410
+ path = artifact.get_path("file.txt")
441
1411
 
442
1412
  # Can now download 'file.txt' directly:
443
1413
  path.download()
444
1414
  ```
445
1415
  """
446
- raise NotImplementedError
1416
+ if self._state == ArtifactState.PENDING:
1417
+ raise ArtifactNotLoggedError(self, "get_path")
1418
+
1419
+ name = LogicalPath(name)
1420
+ entry = self.manifest.entries.get(name) or self._get_obj_entry(name)[0]
1421
+ if entry is None:
1422
+ raise KeyError("Path not contained in artifact: %s" % name)
1423
+ entry._parent_artifact = self
1424
+ return entry
447
1425
 
448
- def get(self, name: str) -> Optional[WBValue]:
1426
+ def get(self, name: str) -> Optional[data_types.WBValue]:
449
1427
  """Get the WBValue object located at the artifact relative `name`.
450
1428
 
451
1429
  Arguments:
452
- name: (str) The artifact relative name to get
1430
+ name: The artifact relative name to get
453
1431
 
454
1432
  Raises:
455
1433
  ArtifactNotLoggedError: if the artifact isn't logged or the run is offline
456
1434
 
457
1435
  Examples:
458
- Basic usage
1436
+ Basic usage:
459
1437
  ```
460
1438
  # Run logging the artifact
461
1439
  with wandb.init() as r:
462
- artifact = wandb.Artifact('my_dataset', type='dataset')
463
- table = wandb.Table(columns=["a", "b", "c"], data=[[i, i*2, 2**i]])
1440
+ artifact = wandb.Artifact("my_dataset", type="dataset")
1441
+ table = wandb.Table(
1442
+ columns=["a", "b", "c"],
1443
+ data=[(i, i * 2, 2**i) for i in range(10)]
1444
+ )
464
1445
  artifact.add(table, "my_table")
465
1446
  wandb.log_artifact(artifact)
466
1447
 
467
1448
  # Run using the artifact
468
1449
  with wandb.init() as r:
469
- artifact = r.use_artifact('my_dataset:latest')
470
- table = r.get('my_table')
1450
+ artifact = r.use_artifact("my_dataset:latest")
1451
+ table = artifact.get("my_table")
471
1452
  ```
472
1453
  """
473
- raise NotImplementedError
1454
+ if self._state == ArtifactState.PENDING:
1455
+ raise ArtifactNotLoggedError(self, "get")
1456
+
1457
+ entry, wb_class = self._get_obj_entry(name)
1458
+ if entry is None or wb_class is None:
1459
+ return None
1460
+
1461
+ # If the entry is a reference from another artifact, then get it directly from
1462
+ # that artifact.
1463
+ if entry._is_artifact_reference():
1464
+ assert self._client is not None
1465
+ artifact = entry._get_referenced_artifact(self._client)
1466
+ return artifact.get(util.uri_from_path(entry.ref))
1467
+
1468
+ # Special case for wandb.Table. This is intended to be a short term
1469
+ # optimization. Since tables are likely to download many other assets in
1470
+ # artifact(s), we eagerly download the artifact using the parallelized
1471
+ # `artifact.download`. In the future, we should refactor the deserialization
1472
+ # pattern such that this special case is not needed.
1473
+ if wb_class == wandb.Table:
1474
+ self.download(recursive=True)
1475
+
1476
+ # Get the ArtifactManifestEntry
1477
+ item = self.get_path(entry.path)
1478
+ item_path = item.download()
1479
+
1480
+ # Load the object from the JSON blob
1481
+ result = None
1482
+ json_obj = {}
1483
+ with open(item_path) as file:
1484
+ json_obj = json.load(file)
1485
+ result = wb_class.from_json(json_obj, self)
1486
+ result._set_artifact_source(self, name)
1487
+ return result
1488
+
1489
+ def get_added_local_path_name(self, local_path: str) -> Optional[str]:
1490
+ """Get the artifact relative name of a file added by a local filesystem path.
1491
+
1492
+ Arguments:
1493
+ local_path: The local path to resolve into an artifact relative name.
1494
+
1495
+ Returns:
1496
+ The artifact relative name.
1497
+
1498
+ Examples:
1499
+ Basic usage:
1500
+ ```
1501
+ artifact = wandb.Artifact("my_dataset", type="dataset")
1502
+ artifact.add_file("path/to/file.txt", name="artifact/path/file.txt")
1503
+
1504
+ # Returns `artifact/path/file.txt`:
1505
+ name = artifact.get_added_local_path_name("path/to/file.txt")
1506
+ ```
1507
+ """
1508
+ entry = self._added_local_paths.get(local_path, None)
1509
+ if entry is None:
1510
+ return None
1511
+ return entry.path
1512
+
1513
+ def _get_obj_entry(
1514
+ self, name: str
1515
+ ) -> Tuple[Optional["ArtifactManifestEntry"], Optional[Type[WBValue]]]:
1516
+ """Return an object entry by name, handling any type suffixes.
1517
+
1518
+ When objects are added with `.add(obj, name)`, the name is typically changed to
1519
+ include the suffix of the object type when serializing to JSON. So we need to be
1520
+ able to resolve a name, without tasking the user with appending .THING.json.
1521
+ This method returns an entry if it exists by a suffixed name.
1522
+
1523
+ Arguments:
1524
+ name: name used when adding
1525
+ """
1526
+ for wb_class in WBValue.type_mapping().values():
1527
+ wandb_file_name = wb_class.with_suffix(name)
1528
+ entry = self.manifest.entries.get(wandb_file_name)
1529
+ if entry is not None:
1530
+ return entry, wb_class
1531
+ return None, None
1532
+
1533
+ # Downloading.
474
1534
 
475
1535
  def download(
476
- self, root: Optional[str] = None, recursive: bool = False
1536
+ self,
1537
+ root: Optional[str] = None,
1538
+ recursive: bool = False,
1539
+ allow_missing_references: bool = False,
477
1540
  ) -> FilePathStr:
478
1541
  """Download the contents of the artifact to the specified root directory.
479
1542
 
@@ -482,14 +1545,133 @@ class Artifact:
482
1545
  match the artifact.
483
1546
 
484
1547
  Arguments:
485
- root: (str, optional) The directory in which to download this artifact's files.
486
- recursive: (bool, optional) If true, then all dependent artifacts are eagerly
487
- downloaded. Otherwise, the dependent artifacts are downloaded as needed.
1548
+ root: The directory in which to download this artifact's files.
1549
+ recursive: If true, then all dependent artifacts are eagerly downloaded.
1550
+ Otherwise, the dependent artifacts are downloaded as needed.
488
1551
 
489
1552
  Returns:
490
- (str): The path to the downloaded contents.
1553
+ The path to the downloaded contents.
1554
+
1555
+ Raises:
1556
+ ArtifactNotLoggedError: if the artifact has not been logged
491
1557
  """
492
- raise NotImplementedError
1558
+ if self._state == ArtifactState.PENDING:
1559
+ raise ArtifactNotLoggedError(self, "download")
1560
+
1561
+ root = root or self._default_root()
1562
+ self._add_download_root(root)
1563
+
1564
+ nfiles = len(self.manifest.entries)
1565
+ size = sum(e.size or 0 for e in self.manifest.entries.values())
1566
+ log = False
1567
+ if nfiles > 5000 or size > 50 * 1024 * 1024:
1568
+ log = True
1569
+ termlog(
1570
+ "Downloading large artifact {}, {:.2f}MB. {} files... ".format(
1571
+ self.name, size / (1024 * 1024), nfiles
1572
+ ),
1573
+ )
1574
+ start_time = datetime.datetime.now()
1575
+ download_logger = ArtifactDownloadLogger(nfiles=nfiles)
1576
+
1577
+ def _download_entry(
1578
+ entry: ArtifactManifestEntry,
1579
+ api_key: Optional[str],
1580
+ cookies: Optional[Dict],
1581
+ headers: Optional[Dict],
1582
+ ) -> None:
1583
+ _thread_local_api_settings.api_key = api_key
1584
+ _thread_local_api_settings.cookies = cookies
1585
+ _thread_local_api_settings.headers = headers
1586
+
1587
+ try:
1588
+ entry.download(root)
1589
+ except FileNotFoundError as e:
1590
+ if allow_missing_references:
1591
+ wandb.termwarn(str(e))
1592
+ return
1593
+ raise
1594
+ download_logger.notify_downloaded()
1595
+
1596
+ download_entry = partial(
1597
+ _download_entry,
1598
+ api_key=_thread_local_api_settings.api_key,
1599
+ cookies=_thread_local_api_settings.cookies,
1600
+ headers=_thread_local_api_settings.headers,
1601
+ )
1602
+
1603
+ with concurrent.futures.ThreadPoolExecutor(64) as executor:
1604
+ active_futures = set()
1605
+ has_next_page = True
1606
+ cursor = None
1607
+ while has_next_page:
1608
+ attrs = self._fetch_file_urls(cursor)
1609
+ has_next_page = attrs["pageInfo"]["hasNextPage"]
1610
+ cursor = attrs["pageInfo"]["endCursor"]
1611
+ for edge in attrs["edges"]:
1612
+ entry = self.get_path(edge["node"]["name"])
1613
+ entry._download_url = edge["node"]["directUrl"]
1614
+ active_futures.add(executor.submit(download_entry, entry))
1615
+ # Wait for download threads to catch up.
1616
+ max_backlog = 5000
1617
+ if len(active_futures) > max_backlog:
1618
+ for future in concurrent.futures.as_completed(active_futures):
1619
+ future.result() # check for errors
1620
+ active_futures.remove(future)
1621
+ if len(active_futures) <= max_backlog:
1622
+ break
1623
+ # Check for errors.
1624
+ for future in concurrent.futures.as_completed(active_futures):
1625
+ future.result()
1626
+
1627
+ if recursive:
1628
+ for dependent_artifact in self._dependent_artifacts:
1629
+ dependent_artifact.download()
1630
+
1631
+ if log:
1632
+ now = datetime.datetime.now()
1633
+ delta = abs((now - start_time).total_seconds())
1634
+ hours = int(delta // 3600)
1635
+ minutes = int((delta - hours * 3600) // 60)
1636
+ seconds = delta - hours * 3600 - minutes * 60
1637
+ termlog(
1638
+ f"Done. {hours}:{minutes}:{seconds:.1f}",
1639
+ prefix=False,
1640
+ )
1641
+ return FilePathStr(root)
1642
+
1643
@retry.retriable(
    retry_timedelta=datetime.timedelta(minutes=3),
    # The trailing comma matters: without it the parentheses are plain grouping
    # and a bare class is passed instead of a one-element tuple.
    retryable_exceptions=(requests.RequestException,),
)
def _fetch_file_urls(self, cursor: Optional[str]) -> Any:
    """Fetch one page of file names and direct download URLs for this artifact.

    The query requests up to 5000 files per page. The `retry.retriable`
    decorator is configured with a three-minute `retry_timedelta` and
    `requests.RequestException` as the retryable exception.

    Arguments:
        cursor: Pagination cursor passed as `after`; `None` requests the first
            page.

    Returns:
        The `files` connection of the response, containing `pageInfo`
        (`hasNextPage`, `endCursor`) and `edges` whose nodes carry `name` and
        `directUrl`.
    """
    query = gql(
        """
        query ArtifactFileURLs($id: ID!, $cursor: String) {
            artifact(id: $id) {
                files(after: $cursor, first: 5000) {
                    pageInfo {
                        hasNextPage
                        endCursor
                    }
                    edges {
                        node {
                            name
                            directUrl
                        }
                    }
                }
            }
        }
        """
    )
    assert self._client is not None
    response = self._client.execute(
        query,
        variable_values={"id": self.id, "cursor": cursor},
        timeout=60,
    )
    return response["artifact"]["files"]
493
1675
 
494
1676
  def checkout(self, root: Optional[str] = None) -> str:
495
1677
  """Replace the specified root directory with the contents of the artifact.
@@ -498,12 +1680,30 @@ class Artifact:
498
1680
  artifact.
499
1681
 
500
1682
  Arguments:
501
- root: (str, optional) The directory to replace with this artifact's files.
1683
+ root: The directory to replace with this artifact's files.
502
1684
 
503
1685
  Returns:
504
- (str): The path to the checked out contents.
1686
+ The path to the checked out contents.
1687
+
1688
+ Raises:
1689
+ ArtifactNotLoggedError: if the artifact has not been logged
505
1690
  """
506
- raise NotImplementedError
1691
+ if self._state == ArtifactState.PENDING:
1692
+ raise ArtifactNotLoggedError(self, "checkout")
1693
+
1694
+ root = root or self._default_root(include_version=False)
1695
+
1696
+ for dirpath, _, files in os.walk(root):
1697
+ for file in files:
1698
+ full_path = os.path.join(dirpath, file)
1699
+ artifact_path = os.path.relpath(full_path, start=root)
1700
+ try:
1701
+ self.get_path(artifact_path)
1702
+ except KeyError:
1703
+ # File is not part of the artifact, remove it.
1704
+ os.remove(full_path)
1705
+
1706
+ return self.download(root=root)
507
1707
 
508
1708
  def verify(self, root: Optional[str] = None) -> None:
509
1709
  """Verify that the actual contents of an artifact match the manifest.
@@ -514,108 +1714,388 @@ class Artifact:
514
1714
  NOTE: References are not verified.
515
1715
 
516
1716
  Arguments:
517
- root: (str, optional) The directory to verify. If None
518
- artifact will be downloaded to './artifacts/self.name/'
1717
+ root: The directory to verify. If None artifact will be downloaded to
1718
+ './artifacts/self.name/'
519
1719
 
520
1720
  Raises:
521
- (ValueError): If the verification fails.
1721
+ ArtifactNotLoggedError: if the artifact has not been logged
1722
+ ValueError: If the verification fails.
522
1723
  """
523
- raise NotImplementedError
1724
+ if self._state == ArtifactState.PENDING:
1725
+ raise ArtifactNotLoggedError(self, "verify")
1726
+
1727
+ root = root or self._default_root()
1728
+
1729
+ for dirpath, _, files in os.walk(root):
1730
+ for file in files:
1731
+ full_path = os.path.join(dirpath, file)
1732
+ artifact_path = os.path.relpath(full_path, start=root)
1733
+ try:
1734
+ self.get_path(artifact_path)
1735
+ except KeyError:
1736
+ raise ValueError(
1737
+ "Found file {} which is not a member of artifact {}".format(
1738
+ full_path, self.name
1739
+ )
1740
+ )
1741
+
1742
+ ref_count = 0
1743
+ for entry in self.manifest.entries.values():
1744
+ if entry.ref is None:
1745
+ if md5_file_b64(os.path.join(root, entry.path)) != entry.digest:
1746
+ raise ValueError("Digest mismatch for file: %s" % entry.path)
1747
+ else:
1748
+ ref_count += 1
1749
+ if ref_count > 0:
1750
+ print("Warning: skipped verification of %s refs" % ref_count)
1751
+
1752
def file(self, root: Optional[str] = None) -> StrPath:
    """Download a single file artifact to dir specified by the root.

    Arguments:
        root: The root directory in which to place the file. Defaults to
            './artifacts/self.name/'.

    Returns:
        The full path of the downloaded file.

    Raises:
        ArtifactNotLoggedError: if the artifact has not been logged
        ValueError: if the artifact does not contain exactly one file
    """
    if self._state == ArtifactState.PENDING:
        raise ArtifactNotLoggedError(self, "file")

    if root is None:
        root = os.path.join(".", "artifacts", self.name)

    entry_names = list(self.manifest.entries)
    if len(entry_names) > 1:
        raise ValueError(
            "This artifact contains more than one file, call `.download()` to get "
            'all files or call .get_path("filename").download()'
        )
    if not entry_names:
        # Previously this surfaced as a bare IndexError from `list(...)[0]`.
        raise ValueError("This artifact contains no files")

    return self.get_path(entry_names[0]).download(root)
1779
+
1780
def files(
    self, names: Optional[List[str]] = None, per_page: int = 50
) -> ArtifactFiles:
    """Iterate over all files stored in this artifact.

    Arguments:
        names: The filename paths relative to the root of the artifact you wish
            to list.
        per_page: The number of files to return per request

    Returns:
        An `ArtifactFiles` object built from this artifact's client, the
        artifact itself, `names` and `per_page`.

    Raises:
        ArtifactNotLoggedError: if the artifact has not been logged
    """
    not_logged = self._state == ArtifactState.PENDING
    if not_logged:
        raise ArtifactNotLoggedError(self, "files")

    file_listing = ArtifactFiles(self._client, self, names, per_page)
    return file_listing
549
1800
 
550
- NOTE: Deletion is permanent and CANNOT be undone.
1801
def _default_root(self, include_version: bool = True) -> str:
    """Build the default local directory for this artifact.

    Arguments:
        include_version: When false, everything from the first `:` of the
            artifact name onwards is dropped from the directory name.

    Returns:
        The artifact name joined onto `env.get_artifact_dir()`. On Windows,
        every `:` after the drive part of the path is replaced with `-`.
    """
    if include_version:
        dir_name = self.name
    else:
        dir_name = self.name.split(":")[0]
    default_root = os.path.join(env.get_artifact_dir(), dir_name)
    if platform.system() != "Windows":
        return default_root
    # Keep the drive (e.g. "C:") intact and only rewrite colons after it.
    drive, remainder = os.path.splitdrive(default_root)
    return drive + remainder.replace(":", "-")
551
1808
 
552
- Returns:
553
- None
554
- """
555
- raise NotImplementedError
1809
+ def _add_download_root(self, dir_path: str) -> None:
1810
+ self._download_roots.add(os.path.abspath(dir_path))
556
1811
 
557
- def wait(self) -> "Artifact":
558
- """Wait for this artifact to finish logging, if needed.
1812
+ def _local_path_to_name(self, file_path: str) -> Optional[str]:
1813
+ """Convert a local file path to a path entry in the artifact."""
1814
+ abs_file_path = os.path.abspath(file_path)
1815
+ abs_file_parts = abs_file_path.split(os.sep)
1816
+ for i in range(len(abs_file_parts) + 1):
1817
+ if os.path.join(os.sep, *abs_file_parts[:i]) in self._download_roots:
1818
+ return os.path.join(*abs_file_parts[i:])
1819
+ return None
559
1820
 
560
- Returns:
561
- Artifact
562
- """
563
- raise NotImplementedError
1821
+ # Others.
564
1822
 
565
- def __getitem__(self, name: str) -> Optional[WBValue]:
566
- """Get the WBValue object located at the artifact relative `name`.
1823
def delete(self, delete_aliases: bool = False) -> None:
    """Delete an artifact and its files.

    Arguments:
        delete_aliases: If true, deletes all aliases associated with the
            artifact. Otherwise, this raises an exception if the artifact has
            existing aliases.

    Raises:
        ArtifactNotLoggedError: if the artifact has not been logged

    Examples:
        Delete all the "model" artifacts a run has logged:
        ```
        runs = api.runs(path="my_entity/my_project")
        for run in runs:
            for artifact in run.logged_artifacts():
                if artifact.type == "model":
                    artifact.delete(delete_aliases=True)
        ```
    """
    not_logged = self._state == ArtifactState.PENDING
    if not_logged:
        raise ArtifactNotLoggedError(self, "delete")
    # The request itself lives in `_delete`.
    self._delete(delete_aliases)
1847
+
1848
@normalize_exceptions
def _delete(self, delete_aliases: bool = False) -> None:
    """Send the `deleteArtifact` mutation for this artifact.

    Arguments:
        delete_aliases: Value sent as the mutation's `deleteAliases` input.
    """
    delete_mutation = gql(
        """
        mutation DeleteArtifact($artifactID: ID!, $deleteAliases: Boolean) {
            deleteArtifact(input: {
                artifactID: $artifactID
                deleteAliases: $deleteAliases
            }) {
                artifact {
                    id
                }
            }
        }
        """
    )
    assert self._client is not None
    mutation_variables = {
        "artifactID": self.id,
        "deleteAliases": delete_aliases,
    }
    self._client.execute(delete_mutation, variable_values=mutation_variables)
591
1872
 
592
- def __setitem__(self, name: str, item: WBValue) -> "ArtifactManifestEntry":
593
- """Add `item` to the artifact at path `name`.
1873
def link(self, target_path: str, aliases: Optional[List[str]] = None) -> None:
    """Link this artifact to a portfolio (a promoted collection of artifacts).

    Arguments:
        target_path: The path to the portfolio. It must take the form
            {portfolio}, {project}/{portfolio} or
            {entity}/{project}/{portfolio}.
        aliases: A list of strings which uniquely identifies the artifact
            inside the specified portfolio.

    Raises:
        ArtifactNotLoggedError: if the artifact has not been logged
    """
    not_logged = self._state == ArtifactState.PENDING
    if not_logged:
        raise ArtifactNotLoggedError(self, "link")
    # Validation and the request itself live in `_link`.
    self._link(target_path, aliases)
1888
+
1889
@normalize_exceptions
def _link(self, target_path: str, aliases: Optional[List[str]] = None) -> None:
    """Send the `linkArtifact` mutation for this artifact.

    The entity and project come from `target_path` when it supplies them, then
    from the active `wandb.run` if there is one, and finally from this
    artifact's own entity and project.

    Arguments:
        target_path: Portfolio path; must not contain `:`.
        aliases: Aliases to attach inside the portfolio; passed through
            `util._resolve_aliases` before being sent.

    Raises:
        ValueError: if `target_path` contains `:`.
    """
    if ":" in target_path:
        raise ValueError(
            f"target_path {target_path} cannot contain `:` because it is not an "
            f"alias."
        )

    portfolio, project, entity = util._parse_entity_project_item(target_path)
    aliases = util._resolve_aliases(aliases)

    active_run = wandb.run
    if active_run:
        run_entity, run_project = active_run.entity, active_run.project
    else:
        run_entity, run_project = None, None
    entity = entity or run_entity or self.entity
    project = project or run_project or self.project

    link_mutation = gql(
        """
        mutation LinkArtifact(
            $artifactID: ID!,
            $artifactPortfolioName: String!,
            $entityName: String!,
            $projectName: String!,
            $aliases: [ArtifactAliasInput!]
        ) {
            linkArtifact(
                input: {
                    artifactID: $artifactID,
                    artifactPortfolioName: $artifactPortfolioName,
                    entityName: $entityName,
                    projectName: $projectName,
                    aliases: $aliases
                }
            ) {
                versionIndex
            }
        }
        """
    )
    assert self._client is not None
    # Each alias is sent together with the portfolio it belongs to.
    alias_inputs = [
        {"alias": alias, "artifactCollectionName": portfolio} for alias in aliases
    ]
    self._client.execute(
        link_mutation,
        variable_values={
            "artifactID": self.id,
            "artifactPortfolioName": portfolio,
            "entityName": entity,
            "projectName": project,
            "aliases": alias_inputs,
        },
    )
1942
+
1943
def used_by(self) -> List[Run]:
    """Get a list of the runs that have used this artifact.

    Returns:
        One `Run` per edge of the `usedBy` connection in the response; an empty
        list when the response carries no artifact, no `usedBy` or no edges.

    Raises:
        ArtifactNotLoggedError: if the artifact has not been logged
    """
    if self._state == ArtifactState.PENDING:
        raise ArtifactNotLoggedError(self, "used_by")

    query = gql(
        """
        query ArtifactUsedBy(
            $id: ID!,
        ) {
            artifact(id: $id) {
                usedBy {
                    edges {
                        node {
                            name
                            project {
                                name
                                entityName
                            }
                        }
                    }
                }
            }
        }
        """
    )
    assert self._client is not None
    response = self._client.execute(
        query,
        variable_values={"id": self.id},
    )
    # `or {}` (rather than a `.get` default) also covers keys that are present
    # but null in the response.
    artifact_data = response.get("artifact") or {}
    edges = (artifact_data.get("usedBy") or {}).get("edges") or []
    return [
        Run(
            self._client,
            edge["node"]["project"]["entityName"],
            edge["node"]["project"]["name"],
            edge["node"]["name"],
        )
        for edge in edges
    ]
1987
+
1988
def logged_by(self) -> Optional[Run]:
    """Get the run that first logged this artifact.

    Returns:
        A `Run` built from the `createdBy` data of the response, or `None` when
        the response carries no creator with a `name`.

    Raises:
        ArtifactNotLoggedError: if the artifact has not been logged
    """
    if self._state == ArtifactState.PENDING:
        raise ArtifactNotLoggedError(self, "logged_by")

    query = gql(
        """
        query ArtifactCreatedBy(
            $id: ID!
        ) {
            artifact(id: $id) {
                createdBy {
                    ... on Run {
                        name
                        project {
                            name
                            entityName
                        }
                    }
                }
            }
        }
        """
    )
    assert self._client is not None
    response = self._client.execute(
        query,
        variable_values={"id": self.id},
    )
    # `or {}` (rather than a `.get` default) also covers `artifact` or
    # `createdBy` being present but null in the response.
    creator = (response.get("artifact") or {}).get("createdBy") or {}
    if creator.get("name") is None:
        return None
    return Run(
        self._client,
        creator["project"]["entityName"],
        creator["project"]["name"],
        creator["name"],
    )
2030
+
2031
def json_encode(self) -> Dict[str, Any]:
    """Return `util.artifact_to_json(self)` for a logged artifact.

    Raises:
        ArtifactNotLoggedError: if the artifact has not been logged
    """
    not_logged = self._state == ArtifactState.PENDING
    if not_logged:
        raise ArtifactNotLoggedError(self, "json_encode")
    encoded = util.artifact_to_json(self)
    return encoded
2035
+
2036
@staticmethod
def _expected_type(
    entity_name: str, project_name: str, name: str, client: RetryingClient
) -> Optional[str]:
    """Returns the expected type for a given artifact name and project.

    Arguments:
        entity_name: Sent as the `entityName` query variable.
        project_name: Sent as the `projectName` query variable.
        name: Artifact name; `:latest` is appended when it contains no `:`.
        client: Client used to execute the query.

    Returns:
        The `name` of the artifact's `artifactType`, or `None` when the project,
        artifact or artifact type is missing or null in the response.
    """
    type_query = gql(
        """
        query ArtifactType(
            $entityName: String,
            $projectName: String,
            $name: String!
        ) {
            project(name: $projectName, entityName: $entityName) {
                artifact(name: $name) {
                    artifactType {
                        name
                    }
                }
            }
        }
        """
    )
    qualified_name = name if ":" in name else name + ":latest"
    response = client.execute(
        type_query,
        variable_values={
            "entityName": entity_name,
            "projectName": project_name,
            "name": qualified_name,
        },
    )
    # Walk down one level at a time, treating missing or null levels as empty.
    project_data = response.get("project") or {}
    artifact_data = project_data.get("artifact") or {}
    type_data = artifact_data.get("artifactType") or {}
    return type_data.get("name")
2072
+
2073
+ @staticmethod
2074
+ def _normalize_metadata(metadata: Optional[Dict[str, Any]]) -> Dict[str, Any]:
2075
+ if metadata is None:
2076
+ return {}
2077
+ if not isinstance(metadata, dict):
2078
+ raise TypeError(f"metadata must be dict, not {type(metadata)}")
2079
+ return cast(
2080
+ Dict[str, Any], json.loads(json.dumps(util.json_friendly_val(metadata)))
2081
+ )
2082
+
2083
def _load_manifest(self, url: str) -> None:
    """Fetch the manifest from `url` and record referenced artifacts.

    The response body is parsed as JSON and turned into an `ArtifactManifest`
    stored on `self._manifest`. For every manifest entry that is an artifact
    reference, the referenced artifact is added to `self._dependent_artifacts`.

    Arguments:
        url: Location of the manifest JSON.
    """
    with requests.get(url) as response:
        response.raise_for_status()
        manifest_json = json.loads(util.ensure_text(response.content))
        self._manifest = ArtifactManifest.from_manifest_json(manifest_json)
    for manifest_entry in self.manifest.entries.values():
        if not manifest_entry._is_artifact_reference():
            continue
        assert self._client is not None
        referenced = manifest_entry._get_referenced_artifact(self._client)
        self._dependent_artifacts.add(referenced)
2094
+
2095
+
2096
class _ArtifactVersionType(WBType):
    # Python classes this type is declared for.
    types = [Artifact]
    # Name under which the type is declared.
    name = "artifactVersion"


# Register the type with the TypeRegistry.
TypeRegistry.add(_ArtifactVersionType)