wandb 0.15.4__py3-none-any.whl → 0.15.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (102) hide show
  1. wandb/__init__.py +1 -1
  2. wandb/analytics/sentry.py +1 -0
  3. wandb/apis/internal.py +3 -0
  4. wandb/apis/public.py +18 -20
  5. wandb/beta/workflows.py +5 -6
  6. wandb/cli/cli.py +27 -27
  7. wandb/data_types.py +2 -0
  8. wandb/integration/langchain/wandb_tracer.py +16 -179
  9. wandb/integration/sagemaker/config.py +2 -2
  10. wandb/integration/tensorboard/log.py +4 -4
  11. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  12. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  13. wandb/proto/wandb_deprecated.py +3 -1
  14. wandb/sdk/__init__.py +1 -4
  15. wandb/sdk/artifacts/__init__.py +0 -14
  16. wandb/sdk/artifacts/artifact.py +1757 -277
  17. wandb/sdk/artifacts/artifact_manifest_entry.py +26 -6
  18. wandb/sdk/artifacts/artifact_state.py +10 -0
  19. wandb/sdk/artifacts/artifacts_cache.py +7 -8
  20. wandb/sdk/artifacts/exceptions.py +4 -4
  21. wandb/sdk/artifacts/storage_handler.py +2 -2
  22. wandb/sdk/artifacts/storage_handlers/azure_handler.py +16 -6
  23. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +2 -2
  24. wandb/sdk/artifacts/storage_handlers/http_handler.py +2 -2
  25. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +2 -2
  26. wandb/sdk/artifacts/storage_handlers/multi_handler.py +2 -2
  27. wandb/sdk/artifacts/storage_handlers/s3_handler.py +35 -32
  28. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +2 -2
  29. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +5 -9
  30. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +2 -2
  31. wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +2 -2
  32. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +24 -16
  33. wandb/sdk/artifacts/storage_policy.py +3 -3
  34. wandb/sdk/data_types/_dtypes.py +7 -12
  35. wandb/sdk/data_types/base_types/json_metadata.py +2 -2
  36. wandb/sdk/data_types/base_types/media.py +5 -6
  37. wandb/sdk/data_types/base_types/wb_value.py +12 -13
  38. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +4 -5
  39. wandb/sdk/data_types/helper_types/classes.py +5 -8
  40. wandb/sdk/data_types/helper_types/image_mask.py +4 -5
  41. wandb/sdk/data_types/histogram.py +3 -3
  42. wandb/sdk/data_types/html.py +3 -4
  43. wandb/sdk/data_types/image.py +4 -5
  44. wandb/sdk/data_types/molecule.py +2 -2
  45. wandb/sdk/data_types/object_3d.py +3 -3
  46. wandb/sdk/data_types/plotly.py +2 -2
  47. wandb/sdk/data_types/saved_model.py +7 -8
  48. wandb/sdk/data_types/trace_tree.py +4 -4
  49. wandb/sdk/data_types/video.py +4 -4
  50. wandb/sdk/interface/interface.py +8 -10
  51. wandb/sdk/internal/file_stream.py +2 -3
  52. wandb/sdk/internal/internal_api.py +99 -4
  53. wandb/sdk/internal/job_builder.py +15 -7
  54. wandb/sdk/internal/sender.py +4 -0
  55. wandb/sdk/internal/settings_static.py +1 -0
  56. wandb/sdk/launch/_project_spec.py +9 -7
  57. wandb/sdk/launch/agent/agent.py +115 -58
  58. wandb/sdk/launch/agent/job_status_tracker.py +34 -0
  59. wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
  60. wandb/sdk/launch/builder/abstract.py +5 -1
  61. wandb/sdk/launch/builder/build.py +16 -10
  62. wandb/sdk/launch/builder/docker_builder.py +9 -2
  63. wandb/sdk/launch/builder/kaniko_builder.py +108 -22
  64. wandb/sdk/launch/builder/noop.py +3 -1
  65. wandb/sdk/launch/environment/aws_environment.py +2 -1
  66. wandb/sdk/launch/environment/azure_environment.py +124 -0
  67. wandb/sdk/launch/github_reference.py +30 -18
  68. wandb/sdk/launch/launch.py +1 -1
  69. wandb/sdk/launch/loader.py +15 -0
  70. wandb/sdk/launch/registry/azure_container_registry.py +132 -0
  71. wandb/sdk/launch/registry/elastic_container_registry.py +38 -4
  72. wandb/sdk/launch/registry/google_artifact_registry.py +46 -7
  73. wandb/sdk/launch/runner/abstract.py +19 -3
  74. wandb/sdk/launch/runner/kubernetes_runner.py +111 -47
  75. wandb/sdk/launch/runner/local_container.py +101 -48
  76. wandb/sdk/launch/runner/sagemaker_runner.py +59 -9
  77. wandb/sdk/launch/runner/vertex_runner.py +8 -4
  78. wandb/sdk/launch/sweeps/scheduler.py +102 -27
  79. wandb/sdk/launch/sweeps/utils.py +21 -0
  80. wandb/sdk/launch/utils.py +19 -7
  81. wandb/sdk/lib/_settings_toposort_generated.py +3 -0
  82. wandb/sdk/service/server.py +22 -9
  83. wandb/sdk/service/service.py +27 -8
  84. wandb/sdk/verify/verify.py +6 -9
  85. wandb/sdk/wandb_config.py +2 -4
  86. wandb/sdk/wandb_init.py +2 -0
  87. wandb/sdk/wandb_require.py +7 -0
  88. wandb/sdk/wandb_run.py +32 -35
  89. wandb/sdk/wandb_settings.py +10 -3
  90. wandb/testing/relay.py +15 -2
  91. wandb/util.py +55 -23
  92. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/METADATA +11 -8
  93. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/RECORD +97 -97
  94. wandb/integration/langchain/util.py +0 -191
  95. wandb/sdk/artifacts/invalid_artifact.py +0 -23
  96. wandb/sdk/artifacts/lazy_artifact.py +0 -162
  97. wandb/sdk/artifacts/local_artifact.py +0 -719
  98. wandb/sdk/artifacts/public_artifact.py +0 -1188
  99. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
  100. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
  101. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +0 -0
  102. {wandb-0.15.4.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
@@ -1,1188 +0,0 @@
1
- """Public (saved) artifact."""
2
- import contextlib
3
- import datetime
4
- import json
5
- import os
6
- import platform
7
- import re
8
- import urllib
9
- from functools import partial
10
- from typing import (
11
- IO,
12
- TYPE_CHECKING,
13
- Any,
14
- Dict,
15
- Generator,
16
- Iterable,
17
- List,
18
- Mapping,
19
- Optional,
20
- Sequence,
21
- Set,
22
- Tuple,
23
- Type,
24
- Union,
25
- )
26
-
27
- import requests
28
-
29
- import wandb
30
- from wandb import util
31
- from wandb.apis.normalize import normalize_exceptions
32
- from wandb.apis.public import ArtifactFiles, RetryingClient, Run
33
- from wandb.data_types import WBValue
34
- from wandb.env import get_artifact_dir
35
- from wandb.errors.term import termlog
36
- from wandb.sdk.artifacts.artifact import Artifact as ArtifactInterface
37
- from wandb.sdk.artifacts.artifact_download_logger import ArtifactDownloadLogger
38
- from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
39
- from wandb.sdk.artifacts.artifacts_cache import get_artifacts_cache
40
- from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
41
- from wandb.sdk.lib.hashutil import hex_to_b64_id, md5_file_b64
42
- from wandb.sdk.lib.paths import FilePathStr, LogicalPath, StrPath
43
-
44
- if TYPE_CHECKING:
45
- from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
46
- from wandb.sdk.artifacts.local_artifact import Artifact as LocalArtifact
47
-
48
- reset_path = util.vendor_setup()
49
-
50
- from wandb_gql import gql # noqa: E402
51
-
52
- reset_path()
53
-
54
- ARTIFACT_FRAGMENT = """
55
- fragment ArtifactFragment on Artifact {
56
- id
57
- digest
58
- description
59
- state
60
- size
61
- createdAt
62
- updatedAt
63
- labels
64
- metadata
65
- fileCount
66
- versionIndex
67
- aliases {
68
- artifactCollectionName
69
- alias
70
- }
71
- artifactSequence {
72
- id
73
- name
74
- }
75
- artifactType {
76
- id
77
- name
78
- project {
79
- name
80
- entity {
81
- name
82
- }
83
- }
84
- }
85
- commitHash
86
- }
87
- """
88
-
89
-
90
- class Artifact(ArtifactInterface):
91
- """A wandb Artifact.
92
-
93
- An artifact that has been logged, including all its attributes, links to the runs
94
- that use it, and a link to the run that logged it.
95
-
96
- Examples:
97
- Basic usage
98
- ```
99
- api = wandb.Api()
100
- artifact = api.artifact('project/artifact:alias')
101
-
102
- # Get information about the artifact...
103
- artifact.digest
104
- artifact.aliases
105
- ```
106
-
107
- Updating an artifact
108
- ```
109
- artifact = api.artifact('project/artifact:alias')
110
-
111
- # Update the description
112
- artifact.description = 'My new description'
113
-
114
- # Selectively update metadata keys
115
- artifact.metadata["oldKey"] = "new value"
116
-
117
- # Replace the metadata entirely
118
- artifact.metadata = {"newKey": "new value"}
119
-
120
- # Add an alias
121
- artifact.aliases.append('best')
122
-
123
- # Remove an alias
124
- artifact.aliases.remove('latest')
125
-
126
- # Completely replace the aliases
127
- artifact.aliases = ['replaced']
128
-
129
- # Persist all artifact modifications
130
- artifact.save()
131
- ```
132
-
133
- Artifact graph traversal
134
- ```
135
- artifact = api.artifact('project/artifact:alias')
136
-
137
- # Walk up and down the graph from an artifact:
138
- producer_run = artifact.logged_by()
139
- consumer_runs = artifact.used_by()
140
-
141
- # Walk up and down the graph from a run:
142
- logged_artifacts = run.logged_artifacts()
143
- used_artifacts = run.used_artifacts()
144
- ```
145
-
146
- Deleting an artifact
147
- ```
148
- artifact = api.artifact('project/artifact:alias')
149
- artifact.delete()
150
- ```
151
- """
152
-
153
- QUERY = gql(
154
- """
155
- query ArtifactWithCurrentManifest(
156
- $id: ID!,
157
- ) {
158
- artifact(id: $id) {
159
- currentManifest {
160
- id
161
- file {
162
- id
163
- directUrl
164
- }
165
- }
166
- ...ArtifactFragment
167
- }
168
- }
169
- %s
170
- """
171
- % ARTIFACT_FRAGMENT
172
- )
173
-
174
- @classmethod
175
- def from_id(cls, artifact_id: str, client: RetryingClient) -> Optional["Artifact"]:
176
- artifact = get_artifacts_cache().get_artifact(artifact_id)
177
- if artifact is not None:
178
- assert isinstance(artifact, Artifact)
179
- return artifact
180
- response: Mapping[str, Any] = client.execute(
181
- Artifact.QUERY,
182
- variable_values={"id": artifact_id},
183
- )
184
-
185
- if response.get("artifact") is None:
186
- return None
187
- p = response.get("artifact", {}).get("artifactType", {}).get("project", {})
188
- project = p.get("name") # defaults to None
189
- entity = p.get("entity", {}).get("name")
190
- name = "{}:v{}".format(
191
- response["artifact"]["artifactSequence"]["name"],
192
- response["artifact"]["versionIndex"],
193
- )
194
- artifact = cls(
195
- client=client,
196
- entity=entity,
197
- project=project,
198
- name=name,
199
- attrs=response["artifact"],
200
- )
201
- index_file_url = response["artifact"]["currentManifest"]["file"]["directUrl"]
202
- with requests.get(index_file_url) as req:
203
- req.raise_for_status()
204
- artifact._manifest = ArtifactManifest.from_manifest_json(
205
- json.loads(util.ensure_text(req.content))
206
- )
207
-
208
- artifact._load_dependent_manifests()
209
-
210
- return artifact
211
-
212
- def __init__(
213
- self,
214
- client: RetryingClient,
215
- entity: str,
216
- project: str,
217
- name: str,
218
- attrs: Optional[Dict[str, Any]] = None,
219
- ) -> None:
220
- self.client = client
221
- self._entity = entity
222
- self._project = project
223
- self._name = name
224
- self._artifact_collection_name = name.split(":")[0]
225
- self._attrs = attrs or self._load()
226
-
227
- # The entity and project above are taken from the passed-in artifact version path
228
- # so if the user is pulling an artifact version from an artifact portfolio, the entity/project
229
- # of that portfolio may be different than the birth entity/project of the artifact version.
230
- self._source_project = (
231
- self._attrs.get("artifactType", {}).get("project", {}).get("name")
232
- )
233
- self._source_entity = (
234
- self._attrs.get("artifactType", {})
235
- .get("project", {})
236
- .get("entity", {})
237
- .get("name")
238
- )
239
- self._metadata = json.loads(self._attrs.get("metadata") or "{}")
240
- self._description = self._attrs.get("description", None)
241
- self._source_name = "{}:v{}".format(
242
- self._attrs["artifactSequence"]["name"], self._attrs.get("versionIndex")
243
- )
244
- self._source_version = "v{}".format(self._attrs.get("versionIndex"))
245
- # We will only show aliases under the Collection this artifact version is fetched from
246
- # _aliases will be a mutable copy on which the user can append or remove aliases
247
- self._aliases: List[str] = [
248
- a["alias"]
249
- for a in self._attrs["aliases"]
250
- if not re.match(r"^v\d+$", a["alias"])
251
- and a["artifactCollectionName"] == self._artifact_collection_name
252
- ]
253
- self._frozen_aliases: List[str] = [a for a in self._aliases]
254
- self._manifest: Optional[ArtifactManifest] = None
255
- self._is_downloaded: bool = False
256
- self._dependent_artifacts: List["Artifact"] = []
257
- self._download_roots: Set[str] = set()
258
- get_artifacts_cache().store_artifact(self)
259
-
260
- @property
261
- def id(self) -> Optional[str]:
262
- return self._attrs["id"]
263
-
264
- @property
265
- def entity(self) -> str:
266
- return self._entity
267
-
268
- @property
269
- def project(self) -> str:
270
- return self._project
271
-
272
- @property
273
- def name(self) -> str:
274
- return self._name
275
-
276
- @property
277
- def version(self) -> str:
278
- """The artifact's version index under the given artifact collection.
279
-
280
- A string with the format "v{number}".
281
- """
282
- for a in self._attrs["aliases"]:
283
- if a[
284
- "artifactCollectionName"
285
- ] == self._artifact_collection_name and util.alias_is_version_index(
286
- a["alias"]
287
- ):
288
- return a["alias"]
289
- raise NotImplementedError
290
-
291
- @property
292
- def source_entity(self) -> str:
293
- return self._source_entity
294
-
295
- @property
296
- def source_project(self) -> str:
297
- return self._source_project
298
-
299
- @property
300
- def source_name(self) -> str:
301
- return self._source_name
302
-
303
- @property
304
- def source_version(self) -> str:
305
- """The artifact's version index under its parent artifact collection.
306
-
307
- A string with the format "v{number}".
308
- """
309
- return self._source_version
310
-
311
- @property
312
- def file_count(self) -> int:
313
- return self._attrs["fileCount"]
314
-
315
- @property
316
- def metadata(self) -> dict:
317
- return self._metadata
318
-
319
- @metadata.setter
320
- def metadata(self, metadata: dict) -> None:
321
- self._metadata = metadata
322
-
323
- @property
324
- def manifest(self) -> ArtifactManifest:
325
- return self._load_manifest()
326
-
327
- @property
328
- def digest(self) -> str:
329
- return self._attrs["digest"]
330
-
331
- @property
332
- def state(self) -> str:
333
- return self._attrs["state"]
334
-
335
- @property
336
- def size(self) -> int:
337
- return self._attrs["size"]
338
-
339
- @property
340
- def created_at(self) -> str:
341
- """The time at which the artifact was created."""
342
- return self._attrs["createdAt"]
343
-
344
- @property
345
- def updated_at(self) -> str:
346
- """The time at which the artifact was last updated."""
347
- return self._attrs["updatedAt"] or self._attrs["createdAt"]
348
-
349
- @property
350
- def description(self) -> Optional[str]:
351
- return self._description
352
-
353
- @description.setter
354
- def description(self, desc: Optional[str]) -> None:
355
- self._description = desc
356
-
357
- @property
358
- def type(self) -> str:
359
- return self._attrs["artifactType"]["name"]
360
-
361
- @property
362
- def commit_hash(self) -> str:
363
- return self._attrs.get("commitHash", "")
364
-
365
- @property
366
- def aliases(self) -> List[str]:
367
- """The aliases associated with this artifact.
368
-
369
- Returns:
370
- List[str]: The aliases associated with this artifact.
371
-
372
- """
373
- return self._aliases
374
-
375
- @aliases.setter
376
- def aliases(self, aliases: List[str]) -> None:
377
- for alias in aliases:
378
- if any(char in alias for char in ["/", ":"]):
379
- raise ValueError(
380
- 'Invalid alias "%s", slashes and colons are disallowed' % alias
381
- )
382
- self._aliases = aliases
383
-
384
- @staticmethod
385
- def expected_type(
386
- client: RetryingClient, name: str, entity_name: str, project_name: str
387
- ) -> Optional[str]:
388
- """Returns the expected type for a given artifact name and project."""
389
- query = gql(
390
- """
391
- query ArtifactType(
392
- $entityName: String,
393
- $projectName: String,
394
- $name: String!
395
- ) {
396
- project(name: $projectName, entityName: $entityName) {
397
- artifact(name: $name) {
398
- artifactType {
399
- name
400
- }
401
- }
402
- }
403
- }
404
- """
405
- )
406
- if ":" not in name:
407
- name += ":latest"
408
-
409
- response = client.execute(
410
- query,
411
- variable_values={
412
- "entityName": entity_name,
413
- "projectName": project_name,
414
- "name": name,
415
- },
416
- )
417
-
418
- project = response.get("project")
419
- if project is not None:
420
- artifact = project.get("artifact")
421
- if artifact is not None:
422
- artifact_type = artifact.get("artifactType")
423
- if artifact_type is not None:
424
- return artifact_type.get("name")
425
-
426
- return None
427
-
428
- @property
429
- def _use_as(self) -> Optional[str]:
430
- return self._attrs.get("_use_as")
431
-
432
- @_use_as.setter
433
- def _use_as(self, use_as: Optional[str]) -> None:
434
- self._attrs["_use_as"] = use_as
435
-
436
- @normalize_exceptions
437
- def link(self, target_path: str, aliases: Optional[List[str]] = None) -> None:
438
- if ":" in target_path:
439
- raise ValueError(
440
- f"target_path {target_path} cannot contain `:` because it is not an alias."
441
- )
442
-
443
- portfolio, project, entity = util._parse_entity_project_item(target_path)
444
- aliases = util._resolve_aliases(aliases)
445
-
446
- run_entity = wandb.run.entity if wandb.run else None
447
- run_project = wandb.run.project if wandb.run else None
448
- entity = entity or run_entity or self.entity
449
- project = project or run_project or self.project
450
-
451
- mutation = gql(
452
- """
453
- mutation LinkArtifact($artifactID: ID!, $artifactPortfolioName: String!, $entityName: String!, $projectName: String!, $aliases: [ArtifactAliasInput!]) {
454
- linkArtifact(input: {artifactID: $artifactID, artifactPortfolioName: $artifactPortfolioName,
455
- entityName: $entityName,
456
- projectName: $projectName,
457
- aliases: $aliases
458
- }) {
459
- versionIndex
460
- }
461
- }
462
- """
463
- )
464
- self.client.execute(
465
- mutation,
466
- variable_values={
467
- "artifactID": self.id,
468
- "artifactPortfolioName": portfolio,
469
- "entityName": entity,
470
- "projectName": project,
471
- "aliases": [
472
- {"alias": alias, "artifactCollectionName": portfolio}
473
- for alias in aliases
474
- ],
475
- },
476
- )
477
-
478
- @normalize_exceptions
479
- def delete(self, delete_aliases: bool = False) -> None:
480
- """Delete an artifact and its files.
481
-
482
- Examples:
483
- Delete all the "model" artifacts a run has logged:
484
- ```
485
- runs = api.runs(path="my_entity/my_project")
486
- for run in runs:
487
- for artifact in run.logged_artifacts():
488
- if artifact.type == "model":
489
- artifact.delete(delete_aliases=True)
490
- ```
491
-
492
- Arguments:
493
- delete_aliases: (bool) If true, deletes all aliases associated with the artifact.
494
- Otherwise, this raises an exception if the artifact has existing aliases.
495
- """
496
- mutation = gql(
497
- """
498
- mutation DeleteArtifact($artifactID: ID!, $deleteAliases: Boolean) {
499
- deleteArtifact(input: {
500
- artifactID: $artifactID
501
- deleteAliases: $deleteAliases
502
- }) {
503
- artifact {
504
- id
505
- }
506
- }
507
- }
508
- """
509
- )
510
- self.client.execute(
511
- mutation,
512
- variable_values={
513
- "artifactID": self.id,
514
- "deleteAliases": delete_aliases,
515
- },
516
- )
517
-
518
- @contextlib.contextmanager
519
- def new_file(
520
- self, name: str, mode: str = "w", encoding: Optional[str] = None
521
- ) -> Generator[IO, None, None]:
522
- raise ValueError("Cannot add files to an artifact once it has been saved")
523
-
524
- def add_file(
525
- self,
526
- local_path: str,
527
- name: Optional[str] = None,
528
- is_tmp: Optional[bool] = False,
529
- ) -> "ArtifactManifestEntry":
530
- raise ValueError("Cannot add files to an artifact once it has been saved")
531
-
532
- def add_dir(self, local_path: str, name: Optional[str] = None) -> None:
533
- raise ValueError("Cannot add files to an artifact once it has been saved")
534
-
535
- def add_reference(
536
- self,
537
- uri: Union["ArtifactManifestEntry", str],
538
- name: Optional[StrPath] = None,
539
- checksum: bool = True,
540
- max_objects: Optional[int] = None,
541
- ) -> Sequence["ArtifactManifestEntry"]:
542
- raise ValueError("Cannot add files to an artifact once it has been saved")
543
-
544
- def add(self, obj: WBValue, name: StrPath) -> "ArtifactManifestEntry":
545
- raise ValueError("Cannot add files to an artifact once it has been saved")
546
-
547
- def remove(self, item: Union[str, "os.PathLike", "ArtifactManifestEntry"]) -> None:
548
- raise ValueError("Cannot remove files from an artifact once it has been saved")
549
-
550
- def _add_download_root(self, dir_path: str) -> None:
551
- """Make `dir_path` a root directory for this artifact."""
552
- self._download_roots.add(os.path.abspath(dir_path))
553
-
554
- def _is_download_root(self, dir_path: str) -> bool:
555
- """Determine if `dir_path` is a root directory for this artifact."""
556
- return dir_path in self._download_roots
557
-
558
- def _local_path_to_name(self, file_path: str) -> Optional[str]:
559
- """Convert a local file path to a path entry in the artifact."""
560
- abs_file_path = os.path.abspath(file_path)
561
- abs_file_parts = abs_file_path.split(os.sep)
562
- for i in range(len(abs_file_parts) + 1):
563
- if self._is_download_root(os.path.join(os.sep, *abs_file_parts[:i])):
564
- return os.path.join(*abs_file_parts[i:])
565
- return None
566
-
567
- def _get_obj_entry(
568
- self, name: str
569
- ) -> Tuple[Optional["ArtifactManifestEntry"], Optional[Type[WBValue]]]:
570
- """Return an object entry by name, handling any type suffixes.
571
-
572
- When objects are added with `.add(obj, name)`, the name is typically changed to
573
- include the suffix of the object type when serializing to JSON. So we need to be
574
- able to resolve a name, without tasking the user with appending .THING.json.
575
- This method returns an entry if it exists by a suffixed name.
576
-
577
- Args:
578
- name: (str) name used when adding
579
- """
580
- self._load_manifest()
581
-
582
- type_mapping = WBValue.type_mapping()
583
- for artifact_type_str in type_mapping:
584
- wb_class = type_mapping[artifact_type_str]
585
- wandb_file_name = wb_class.with_suffix(name)
586
- entry = self.manifest.entries.get(wandb_file_name)
587
- if entry is not None:
588
- return entry, wb_class
589
- return None, None
590
-
591
- def get_path(self, name: StrPath) -> "ArtifactManifestEntry":
592
- name = LogicalPath(name)
593
- manifest = self._load_manifest()
594
- entry = manifest.entries.get(name) or self._get_obj_entry(name)[0]
595
- if entry is None:
596
- raise KeyError("Path not contained in artifact: %s" % name)
597
- entry._parent_artifact = self
598
- return entry
599
-
600
- def get(self, name: str) -> Optional[WBValue]:
601
- entry, wb_class = self._get_obj_entry(name)
602
- if entry is None or wb_class is None:
603
- return None
604
- # If the entry is a reference from another artifact, then get it directly from that artifact
605
- if self._manifest_entry_is_artifact_reference(entry):
606
- artifact = self._get_ref_artifact_from_entry(entry)
607
- return artifact.get(util.uri_from_path(entry.ref))
608
-
609
- # Special case for wandb.Table. This is intended to be a short term optimization.
610
- # Since tables are likely to download many other assets in artifact(s), we eagerly download
611
- # the artifact using the parallelized `artifact.download`. In the future, we should refactor
612
- # the deserialization pattern such that this special case is not needed.
613
- if wb_class == wandb.Table:
614
- self.download(recursive=True)
615
-
616
- # Get the ArtifactManifestEntry
617
- item = self.get_path(entry.path)
618
- item_path = item.download()
619
-
620
- # Load the object from the JSON blob
621
- result = None
622
- json_obj = {}
623
- with open(item_path) as file:
624
- json_obj = json.load(file)
625
- result = wb_class.from_json(json_obj, self)
626
- result._set_artifact_source(self, name)
627
- return result
628
-
629
- def download(
630
- self, root: Optional[str] = None, recursive: bool = False
631
- ) -> FilePathStr:
632
- dirpath = root or self._default_root()
633
- self._add_download_root(dirpath)
634
- manifest = self._load_manifest()
635
- nfiles = len(manifest.entries)
636
- size = sum(e.size or 0 for e in manifest.entries.values())
637
- log = False
638
- if nfiles > 5000 or size > 50 * 1024 * 1024:
639
- log = True
640
- termlog(
641
- "Downloading large artifact {}, {:.2f}MB. {} files... ".format(
642
- self.name, size / (1024 * 1024), nfiles
643
- ),
644
- )
645
- start_time = datetime.datetime.now()
646
-
647
- # Force all the files to download into the same directory.
648
- # Download in parallel
649
- import multiprocessing.dummy # this uses threads
650
-
651
- download_logger = ArtifactDownloadLogger(nfiles=nfiles)
652
-
653
- def _download_file_with_thread_local_api_settings(
654
- name: str,
655
- root: str,
656
- download_logger: ArtifactDownloadLogger,
657
- tlas_api_key: Optional[str],
658
- tlas_cookies: Optional[Dict],
659
- tlas_headers: Optional[Dict],
660
- ) -> StrPath:
661
- _thread_local_api_settings.api_key = tlas_api_key
662
- _thread_local_api_settings.cookies = tlas_cookies
663
- _thread_local_api_settings.headers = tlas_headers
664
-
665
- return self._download_file(name, root, download_logger)
666
-
667
- pool = multiprocessing.dummy.Pool(32)
668
- pool.map(
669
- partial(
670
- _download_file_with_thread_local_api_settings,
671
- root=dirpath,
672
- download_logger=download_logger,
673
- tlas_api_key=_thread_local_api_settings.api_key,
674
- tlas_headers={**(_thread_local_api_settings.headers or {})},
675
- tlas_cookies={**(_thread_local_api_settings.cookies or {})},
676
- ),
677
- manifest.entries,
678
- )
679
- if recursive:
680
- pool.map(lambda artifact: artifact.download(), self._dependent_artifacts)
681
- pool.close()
682
- pool.join()
683
-
684
- self._is_downloaded = True
685
-
686
- if log:
687
- now = datetime.datetime.now()
688
- delta = abs((now - start_time).total_seconds())
689
- hours = int(delta // 3600)
690
- minutes = int((delta - hours * 3600) // 60)
691
- seconds = delta - hours * 3600 - minutes * 60
692
- termlog(
693
- f"Done. {hours}:{minutes}:{seconds:.1f}",
694
- prefix=False,
695
- )
696
- return FilePathStr(dirpath)
697
-
698
- def checkout(self, root: Optional[str] = None) -> str:
699
- dirpath = root or self._default_root(include_version=False)
700
-
701
- for root, _, files in os.walk(dirpath):
702
- for file in files:
703
- full_path = os.path.join(root, file)
704
- artifact_path = os.path.relpath(full_path, start=dirpath)
705
- try:
706
- self.get_path(artifact_path)
707
- except KeyError:
708
- # File is not part of the artifact, remove it.
709
- os.remove(full_path)
710
-
711
- return self.download(root=dirpath)
712
-
713
- def verify(self, root: Optional[str] = None) -> None:
714
- dirpath = root or self._default_root()
715
- manifest = self._load_manifest()
716
- ref_count = 0
717
-
718
- for root, _, files in os.walk(dirpath):
719
- for file in files:
720
- full_path = os.path.join(root, file)
721
- artifact_path = os.path.relpath(full_path, start=dirpath)
722
- try:
723
- self.get_path(artifact_path)
724
- except KeyError:
725
- raise ValueError(
726
- "Found file {} which is not a member of artifact {}".format(
727
- full_path, self.name
728
- )
729
- )
730
-
731
- for entry in manifest.entries.values():
732
- if entry.ref is None:
733
- if md5_file_b64(os.path.join(dirpath, entry.path)) != entry.digest:
734
- raise ValueError("Digest mismatch for file: %s" % entry.path)
735
- else:
736
- ref_count += 1
737
- if ref_count > 0:
738
- print("Warning: skipped verification of %s refs" % ref_count)
739
-
740
- def file(self, root: Optional[str] = None) -> StrPath:
741
- """Download a single file artifact to dir specified by the root.
742
-
743
- Arguments:
744
- root: (str, optional) The root directory in which to place the file. Defaults to './artifacts/self.name/'.
745
-
746
- Returns:
747
- (str): The full path of the downloaded file.
748
- """
749
- if root is None:
750
- root = os.path.join(".", "artifacts", self.name)
751
-
752
- manifest = self._load_manifest()
753
- nfiles = len(manifest.entries)
754
- if nfiles > 1:
755
- raise ValueError(
756
- "This artifact contains more than one file, call `.download()` to get all files or call "
757
- '.get_path("filename").download()'
758
- )
759
-
760
- return self._download_file(list(manifest.entries)[0], root=root)
761
-
762
- def _download_file(
763
- self,
764
- name: str,
765
- root: str,
766
- download_logger: Optional[ArtifactDownloadLogger] = None,
767
- ) -> StrPath:
768
- # download file into cache and copy to target dir
769
- downloaded_path = self.get_path(name).download(root)
770
- if download_logger is not None:
771
- download_logger.notify_downloaded()
772
- return downloaded_path
773
-
774
- def _default_root(self, include_version: bool = True) -> str:
775
- name = self.source_name if include_version else self.source_name.split(":")[0]
776
- root = os.path.join(get_artifact_dir(), name)
777
- if platform.system() == "Windows":
778
- head, tail = os.path.splitdrive(root)
779
- root = head + tail.replace(":", "-")
780
- return root
781
-
782
- def json_encode(self) -> Dict[str, Any]:
783
- return util.artifact_to_json(self)
784
-
785
- @normalize_exceptions
786
- def save(self) -> None:
787
- """Persists artifact changes to the wandb backend."""
788
- mutation = gql(
789
- """
790
- mutation updateArtifact(
791
- $artifactID: ID!,
792
- $description: String,
793
- $metadata: JSONString,
794
- $aliases: [ArtifactAliasInput!]
795
- ) {
796
- updateArtifact(input: {
797
- artifactID: $artifactID,
798
- description: $description,
799
- metadata: $metadata,
800
- aliases: $aliases
801
- }) {
802
- artifact {
803
- id
804
- }
805
- }
806
- }
807
- """
808
- )
809
- introspect_query = gql(
810
- """
811
- query ProbeServerAddAliasesInput {
812
- AddAliasesInputInfoType: __type(name: "AddAliasesInput") {
813
- name
814
- inputFields {
815
- name
816
- }
817
- }
818
- }
819
- """
820
- )
821
- res = self.client.execute(introspect_query)
822
- valid = res.get("AddAliasesInputInfoType")
823
- aliases = None
824
- if not valid:
825
- # If valid, wandb backend version >= 0.13.0.
826
- # This means we can safely remove aliases from this updateArtifact request since we'll be calling
827
- # the alias endpoints below in _save_alias_changes.
828
- # If not valid, wandb backend version < 0.13.0. This requires aliases to be sent in updateArtifact.
829
- aliases = [
830
- {
831
- "artifactCollectionName": self._artifact_collection_name,
832
- "alias": alias,
833
- }
834
- for alias in self._aliases
835
- ]
836
-
837
- self.client.execute(
838
- mutation,
839
- variable_values={
840
- "artifactID": self.id,
841
- "description": self.description,
842
- "metadata": util.json_dumps_safer(self.metadata),
843
- "aliases": aliases,
844
- },
845
- )
846
- # Save locally modified aliases
847
- self._save_alias_changes()
848
-
849
- def wait(self) -> "Artifact":
850
- return self
851
-
852
@normalize_exceptions
def _save_alias_changes(self) -> None:
    """Persist alias changes on this artifact to the wandb backend.

    Called by ``save()``. Diffs the current ``self._aliases`` against the
    last-persisted snapshot ``self._frozen_aliases`` and issues ``addAliases``
    and/or ``deleteAliases`` mutations for the difference. Afterwards the
    snapshot is reset to a copy of the current aliases.

    If the server schema does not expose the ``AddAliasesInput`` type, this
    returns without doing anything. For such backends ``save()`` sends the
    aliases through ``updateArtifact`` instead.
    """
    aliases_to_add = set(self._aliases) - set(self._frozen_aliases)
    aliases_to_remove = set(self._frozen_aliases) - set(self._aliases)

    # Introspect: the alias endpoints only exist if the server knows AddAliasesInput.
    introspect_query = gql(
        """
        query ProbeServerAddAliasesInput {
            AddAliasesInputInfoType: __type(name: "AddAliasesInput") {
                name
                inputFields {
                    name
                }
            }
        }
        """
    )
    res = self.client.execute(introspect_query)
    valid = res.get("AddAliasesInputInfoType")
    if not valid:
        return

    def _alias_inputs(aliases):
        # Build the ArtifactCollectionAliasInput payload shared by both mutations.
        return [
            {
                "artifactCollectionName": self._artifact_collection_name,
                "alias": alias,
                "entityName": self._entity,
                "projectName": self._project,
            }
            for alias in aliases
        ]

    if aliases_to_add:
        add_mutation = gql(
            """
            mutation addAliases(
                $artifactID: ID!,
                $aliases: [ArtifactCollectionAliasInput!]!,
            ) {
                addAliases(
                    input: {
                        artifactID: $artifactID,
                        aliases: $aliases,
                    }
                ) {
                    success
                }
            }
            """
        )
        self.client.execute(
            add_mutation,
            variable_values={
                "artifactID": self.id,
                "aliases": _alias_inputs(aliases_to_add),
            },
        )

    if aliases_to_remove:
        delete_mutation = gql(
            """
            mutation deleteAliases(
                $artifactID: ID!,
                $aliases: [ArtifactCollectionAliasInput!]!,
            ) {
                deleteAliases(
                    input: {
                        artifactID: $artifactID,
                        aliases: $aliases,
                    }
                ) {
                    success
                }
            }
            """
        )
        self.client.execute(
            delete_mutation,
            variable_values={
                "artifactID": self.id,
                "aliases": _alias_inputs(aliases_to_remove),
            },
        )

    # Reset local state. Store a *copy*: assigning the same list object would
    # let later in-place edits of self._aliases (e.g. append/remove) also
    # change the snapshot, so the next diff would come out empty.
    self._frozen_aliases = list(self._aliases)
949
-
950
# TODO: not yet public, but we probably want something like this.
def _list(self) -> Iterable[str]:
    """Return the entry paths (the keys of the manifest's entries mapping)."""
    entries = self._load_manifest().entries
    return entries.keys()
954
-
955
def __repr__(self) -> str:
    """Short debug representation embedding the artifact's id."""
    artifact_id = self.id
    return "<Artifact {}>".format(artifact_id)
957
-
958
def _load(self) -> Dict[str, Any]:
    """Fetch this artifact's record from the backend.

    Runs the ``Artifact`` GraphQL query (selecting ``ARTIFACT_FRAGMENT``) for
    ``self.entity`` / ``self.project`` / ``self.name``.

    Returns:
        The ``project.artifact`` object from the response.

    Raises:
        ValueError: if the query fails and ``self.name`` has neither an alias
            (``":"``) nor the length of a raw digest (32 characters), or if the
            project contains no such artifact. In the second case, if the query
            itself failed, the original error is chained as ``__cause__``.
    """
    query = gql(
        """
        query Artifact(
            $entityName: String,
            $projectName: String,
            $name: String!
        ) {
            project(name: $projectName, entityName: $entityName) {
                artifact(name: $name) {
                    ...ArtifactFragment
                }
            }
        }
        %s
        """
        % ARTIFACT_FRAGMENT
    )
    response = None
    fetch_error = None
    try:
        response = self.client.execute(
            query,
            variable_values={
                "entityName": self.entity,
                "projectName": self.project,
                "name": self.name,
            },
        )
    except Exception as err:
        # we check for this after doing the call, since the backend supports raw digest lookups
        # which don't include ":" and are 32 characters long
        if ":" not in self.name and len(self.name) != 32:
            raise ValueError(
                'Attempted to fetch artifact without alias (e.g. "<artifact_name>:v3" or "<artifact_name>:latest")'
            ) from err
        # Keep the failure so the "not found" error below does not hide its cause.
        fetch_error = err

    # Same checks as before: a missing response, project, or artifact all mean "not found".
    project = (response or {}).get("project")
    artifact = project.get("artifact") if project is not None else None
    if artifact is None:
        raise ValueError(
            f'Project {self.entity}/{self.project} does not contain artifact: "{self.name}"'
        ) from fetch_error
    return artifact
1002
-
1003
def files(
    self, names: Optional[List[str]] = None, per_page: int = 50
) -> ArtifactFiles:
    """Iterate over all files stored in this artifact.

    Arguments:
        names: (list of str, optional) Paths, relative to the artifact root,
            of the files to list. ``None`` lists every file.
        per_page: (int, default 50) How many files to request per page.

    Returns:
        (`ArtifactFiles`): An iterator yielding `File` objects.
    """
    client = self.client
    return ArtifactFiles(client, self, names, per_page)
1017
-
1018
def _load_manifest(self) -> ArtifactManifest:
    """Return this artifact's manifest, downloading it on first access.

    On the first call, this queries the manifest file's ``directUrl``, downloads
    and parses the JSON into ``self._manifest``, and then loads the manifests of
    any referenced artifacts. Later calls return the cached manifest.
    """
    if self._manifest is not None:
        return self._manifest

    query = gql(
        """
        query ArtifactManifest(
            $entityName: String!,
            $projectName: String!,
            $name: String!
        ) {
            project(name: $projectName, entityName: $entityName) {
                artifact(name: $name) {
                    currentManifest {
                        id
                        file {
                            id
                            directUrl
                        }
                    }
                }
            }
        }
        """
    )
    response = self.client.execute(
        query,
        variable_values={
            "entityName": self.entity,
            "projectName": self.project,
            "name": self.name,
        },
    )

    manifest_file = response["project"]["artifact"]["currentManifest"]["file"]
    index_file_url = manifest_file["directUrl"]
    with requests.get(index_file_url) as req:
        req.raise_for_status()
        manifest_json = json.loads(util.ensure_text(req.content))
        self._manifest = ArtifactManifest.from_manifest_json(manifest_json)

    # Only on first load; the helper reads self.manifest, which by now is cached.
    self._load_dependent_manifests()
    return self._manifest
1062
-
1063
def _load_dependent_manifests(self) -> None:
    """Load the manifest of every artifact referenced by this manifest's entries.

    Each referenced artifact not already in ``self._dependent_artifacts`` has
    its manifest loaded and is then appended to that list.
    """
    for manifest_entry in self.manifest.entries.values():
        if not self._manifest_entry_is_artifact_reference(manifest_entry):
            continue
        referenced = self._get_ref_artifact_from_entry(manifest_entry)
        if referenced in self._dependent_artifacts:
            continue
        referenced._load_manifest()
        self._dependent_artifacts.append(referenced)
1073
-
1074
@staticmethod
def _manifest_entry_is_artifact_reference(entry: "ArtifactManifestEntry") -> bool:
    """Tell whether a manifest entry points at another artifact.

    True exactly when ``entry.ref`` is set and its URL scheme is
    ``wandb-artifact``.
    """
    ref = entry.ref
    if ref is None:
        return False
    return urllib.parse.urlparse(ref).scheme == "wandb-artifact"
1081
-
1082
def _get_ref_artifact_from_entry(
    self, entry: "ArtifactManifestEntry"
) -> "Artifact":
    """Resolve an artifact-reference manifest entry to the Artifact it names."""
    # util.host_from_path extracts the id from the reference URI;
    # hex_to_b64_id converts it to the form Artifact.from_id expects.
    b64_id = hex_to_b64_id(util.host_from_path(entry.ref))
    referenced = Artifact.from_id(b64_id, self.client)
    assert referenced is not None
    return referenced
1090
-
1091
def used_by(self) -> List[Run]:
    """Retrieve the runs which use this artifact directly.

    Returns:
        [Run]: a list of Run objects which use this artifact. The list is
        empty when the server returns no artifact, no ``usedBy`` connection,
        or no edges (including explicit ``null`` values).
    """
    query = gql(
        """
        query ArtifactUsedBy(
            $id: ID!,
            $before: String,
            $after: String,
            $first: Int,
            $last: Int
        ) {
            artifact(id: $id) {
                usedBy(before: $before, after: $after, first: $first, last: $last) {
                    edges {
                        node {
                            name
                            project {
                                name
                                entityName
                            }
                        }
                    }
                }
            }
        }
        """
    )
    response = self.client.execute(
        query,
        variable_values={"id": self.id},
    )
    # `dict.get(key, {})` only covers a *missing* key; GraphQL returns explicit
    # nulls, so normalize None to an empty container at each level.
    artifact = response.get("artifact") or {}
    used_by = artifact.get("usedBy") or {}
    edges = used_by.get("edges") or []
    # yes, "name" is actually id
    return [
        Run(
            self.client,
            edge["node"]["project"]["entityName"],
            edge["node"]["project"]["name"],
            edge["node"]["name"],
        )
        for edge in edges
    ]
1137
-
1138
def logged_by(self) -> Optional[Run]:
    """Retrieve the run which logged this artifact.

    Returns:
        Run: Run object which logged this artifact, or ``None`` when the
        creator is missing or is not a run.
    """
    query = gql(
        """
        query ArtifactCreatedBy(
            $id: ID!
        ) {
            artifact(id: $id) {
                createdBy {
                    ... on Run {
                        name
                        project {
                            name
                            entityName
                        }
                    }
                }
            }
        }
        """
    )
    response = self.client.execute(
        query,
        variable_values={"id": self.id},
    )
    # Normalize explicit nulls: dict.get's default only applies to missing keys.
    artifact = response.get("artifact") or {}
    run_obj = artifact.get("createdBy") or {}
    # createdBy is selected only through `... on Run`; GraphQL omits the fields
    # of a non-matching inline fragment, so a non-Run creator arrives as `{}`.
    # Without a run name there is no run to return.
    if run_obj.get("name") is None:
        return None
    return Run(
        self.client,
        run_obj["project"]["entityName"],
        run_obj["project"]["name"],
        run_obj["name"],
    )
1176
-
1177
def new_draft(self) -> "LocalArtifact":
    """Create a new draft artifact with the same content as this committed artifact.

    The draft uses this artifact's name without its ``:alias`` suffix, the same
    type, description and metadata, and a manifest rebuilt from this artifact's
    manifest JSON. It can be extended or modified and logged as a new version.
    """
    base_name = self.name.partition(":")[0]
    draft = wandb.Artifact(base_name, self.type)
    draft._description = self.description
    # NOTE(review): this assigns the object returned by self.metadata; unless
    # that property returns a copy, the draft and this artifact share one dict.
    draft._metadata = self.metadata
    manifest_json = self.manifest.to_manifest_json()
    draft._manifest = ArtifactManifest.from_manifest_json(manifest_json)
    return draft