wandb 0.15.9__py3-none-any.whl → 0.15.11__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (114) hide show
  1. wandb/__init__.py +5 -1
  2. wandb/apis/public.py +137 -17
  3. wandb/apis/reports/_panels.py +1 -1
  4. wandb/apis/reports/blocks.py +1 -0
  5. wandb/apis/reports/report.py +27 -5
  6. wandb/cli/cli.py +52 -41
  7. wandb/docker/__init__.py +17 -0
  8. wandb/docker/auth.py +1 -1
  9. wandb/env.py +24 -4
  10. wandb/filesync/step_checksum.py +3 -3
  11. wandb/integration/openai/openai.py +3 -0
  12. wandb/integration/ultralytics/__init__.py +9 -0
  13. wandb/integration/ultralytics/bbox_utils.py +196 -0
  14. wandb/integration/ultralytics/callback.py +458 -0
  15. wandb/integration/ultralytics/classification_utils.py +66 -0
  16. wandb/integration/ultralytics/mask_utils.py +141 -0
  17. wandb/integration/ultralytics/pose_utils.py +92 -0
  18. wandb/integration/xgboost/xgboost.py +3 -3
  19. wandb/integration/yolov8/__init__.py +0 -7
  20. wandb/integration/yolov8/yolov8.py +22 -3
  21. wandb/old/settings.py +7 -0
  22. wandb/plot/line_series.py +0 -1
  23. wandb/proto/v3/wandb_internal_pb2.py +353 -300
  24. wandb/proto/v3/wandb_server_pb2.py +37 -41
  25. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  26. wandb/proto/v3/wandb_telemetry_pb2.py +16 -16
  27. wandb/proto/v4/wandb_internal_pb2.py +272 -260
  28. wandb/proto/v4/wandb_server_pb2.py +37 -40
  29. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  30. wandb/proto/v4/wandb_telemetry_pb2.py +16 -16
  31. wandb/proto/wandb_internal_codegen.py +7 -31
  32. wandb/sdk/artifacts/artifact.py +321 -189
  33. wandb/sdk/artifacts/artifact_cache.py +14 -0
  34. wandb/sdk/artifacts/artifact_manifest.py +5 -4
  35. wandb/sdk/artifacts/artifact_manifest_entry.py +37 -9
  36. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -9
  37. wandb/sdk/artifacts/artifact_saver.py +13 -50
  38. wandb/sdk/artifacts/artifact_ttl.py +6 -0
  39. wandb/sdk/artifacts/artifacts_cache.py +119 -93
  40. wandb/sdk/artifacts/staging.py +25 -0
  41. wandb/sdk/artifacts/storage_handlers/s3_handler.py +12 -7
  42. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +2 -3
  43. wandb/sdk/artifacts/storage_policies/__init__.py +4 -0
  44. wandb/sdk/artifacts/storage_policies/register.py +1 -0
  45. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +4 -3
  46. wandb/sdk/artifacts/storage_policy.py +4 -2
  47. wandb/sdk/backend/backend.py +0 -16
  48. wandb/sdk/data_types/image.py +3 -1
  49. wandb/sdk/integration_utils/auto_logging.py +38 -13
  50. wandb/sdk/interface/interface.py +16 -135
  51. wandb/sdk/interface/interface_shared.py +9 -147
  52. wandb/sdk/interface/interface_sock.py +0 -26
  53. wandb/sdk/internal/file_pusher.py +20 -3
  54. wandb/sdk/internal/file_stream.py +3 -1
  55. wandb/sdk/internal/handler.py +53 -70
  56. wandb/sdk/internal/internal_api.py +220 -130
  57. wandb/sdk/internal/job_builder.py +41 -37
  58. wandb/sdk/internal/sender.py +7 -25
  59. wandb/sdk/internal/system/assets/disk.py +144 -11
  60. wandb/sdk/internal/system/system_info.py +6 -2
  61. wandb/sdk/launch/__init__.py +5 -0
  62. wandb/sdk/launch/{launch.py → _launch.py} +53 -54
  63. wandb/sdk/launch/{launch_add.py → _launch_add.py} +34 -31
  64. wandb/sdk/launch/_project_spec.py +13 -2
  65. wandb/sdk/launch/agent/agent.py +103 -59
  66. wandb/sdk/launch/agent/run_queue_item_file_saver.py +6 -4
  67. wandb/sdk/launch/builder/build.py +19 -1
  68. wandb/sdk/launch/builder/docker_builder.py +5 -1
  69. wandb/sdk/launch/builder/kaniko_builder.py +5 -1
  70. wandb/sdk/launch/create_job.py +20 -5
  71. wandb/sdk/launch/loader.py +14 -5
  72. wandb/sdk/launch/runner/abstract.py +0 -2
  73. wandb/sdk/launch/runner/kubernetes_monitor.py +329 -0
  74. wandb/sdk/launch/runner/kubernetes_runner.py +66 -209
  75. wandb/sdk/launch/runner/local_container.py +5 -2
  76. wandb/sdk/launch/runner/local_process.py +4 -1
  77. wandb/sdk/launch/sweeps/scheduler.py +43 -25
  78. wandb/sdk/launch/sweeps/utils.py +5 -3
  79. wandb/sdk/launch/utils.py +3 -1
  80. wandb/sdk/lib/_settings_toposort_generate.py +3 -9
  81. wandb/sdk/lib/_settings_toposort_generated.py +27 -3
  82. wandb/sdk/lib/_wburls_generated.py +1 -0
  83. wandb/sdk/lib/filenames.py +27 -6
  84. wandb/sdk/lib/filesystem.py +181 -7
  85. wandb/sdk/lib/fsm.py +5 -3
  86. wandb/sdk/lib/gql_request.py +3 -0
  87. wandb/sdk/lib/ipython.py +7 -0
  88. wandb/sdk/lib/wburls.py +1 -0
  89. wandb/sdk/service/port_file.py +2 -15
  90. wandb/sdk/service/server.py +7 -55
  91. wandb/sdk/service/service.py +56 -26
  92. wandb/sdk/service/service_base.py +1 -1
  93. wandb/sdk/service/streams.py +11 -5
  94. wandb/sdk/verify/verify.py +2 -2
  95. wandb/sdk/wandb_init.py +8 -2
  96. wandb/sdk/wandb_manager.py +4 -14
  97. wandb/sdk/wandb_run.py +143 -53
  98. wandb/sdk/wandb_settings.py +148 -35
  99. wandb/testing/relay.py +85 -38
  100. wandb/util.py +87 -4
  101. wandb/wandb_torch.py +24 -38
  102. {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/METADATA +48 -23
  103. {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/RECORD +107 -103
  104. {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/WHEEL +1 -1
  105. wandb/proto/v3/wandb_server_pb2_grpc.py +0 -1422
  106. wandb/proto/v4/wandb_server_pb2_grpc.py +0 -1422
  107. wandb/proto/wandb_server_pb2_grpc.py +0 -8
  108. wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +0 -61
  109. wandb/sdk/interface/interface_grpc.py +0 -460
  110. wandb/sdk/service/server_grpc.py +0 -444
  111. wandb/sdk/service/service_grpc.py +0 -73
  112. {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/LICENSE +0 -0
  113. {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/entry_points.txt +0 -0
  114. {wandb-0.15.9.dist-info → wandb-0.15.11.dist-info}/top_level.txt +0 -0
@@ -1,16 +1,15 @@
1
1
  """Artifact class."""
2
2
  import concurrent.futures
3
3
  import contextlib
4
- import datetime
5
4
  import json
6
5
  import multiprocessing.dummy
7
6
  import os
8
- import platform
9
7
  import re
10
8
  import shutil
11
9
  import tempfile
12
10
  import time
13
11
  from copy import copy
12
+ from datetime import datetime, timedelta
14
13
  from functools import partial
15
14
  from pathlib import PurePosixPath
16
15
  from typing import (
@@ -35,27 +34,30 @@ import requests
35
34
  import wandb
36
35
  from wandb import data_types, env, util
37
36
  from wandb.apis.normalize import normalize_exceptions
38
- from wandb.apis.public import ArtifactFiles, RetryingClient, Run
37
+ from wandb.apis.public import ArtifactCollection, ArtifactFiles, RetryingClient, Run
39
38
  from wandb.data_types import WBValue
40
39
  from wandb.errors.term import termerror, termlog, termwarn
40
+ from wandb.sdk.artifacts.artifact_cache import artifact_cache
41
41
  from wandb.sdk.artifacts.artifact_download_logger import ArtifactDownloadLogger
42
42
  from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
43
43
  from wandb.sdk.artifacts.artifact_manifest_entry import ArtifactManifestEntry
44
44
  from wandb.sdk.artifacts.artifact_manifests.artifact_manifest_v1 import (
45
45
  ArtifactManifestV1,
46
46
  )
47
- from wandb.sdk.artifacts.artifact_saver import get_staging_dir
48
47
  from wandb.sdk.artifacts.artifact_state import ArtifactState
49
- from wandb.sdk.artifacts.artifacts_cache import get_artifacts_cache
48
+ from wandb.sdk.artifacts.artifact_ttl import ArtifactTTL
50
49
  from wandb.sdk.artifacts.exceptions import (
51
50
  ArtifactFinalizedError,
52
51
  ArtifactNotLoggedError,
53
52
  WaitTimeoutError,
54
53
  )
54
+ from wandb.sdk.artifacts.staging import get_staging_dir
55
55
  from wandb.sdk.artifacts.storage_layout import StorageLayout
56
- from wandb.sdk.artifacts.storage_policies.wandb_storage_policy import WandbStoragePolicy
56
+ from wandb.sdk.artifacts.storage_policies import WANDB_STORAGE_POLICY
57
+ from wandb.sdk.artifacts.storage_policy import StoragePolicy
57
58
  from wandb.sdk.data_types._dtypes import Type as WBType
58
59
  from wandb.sdk.data_types._dtypes import TypeRegistry
60
+ from wandb.sdk.internal.internal_api import Api as InternalApi
59
61
  from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
60
62
  from wandb.sdk.lib import filesystem, retry, runid, telemetry
61
63
  from wandb.sdk.lib.hashutil import B64MD5, b64_to_hex_id, md5_file_b64
@@ -108,39 +110,6 @@ class Artifact:
108
110
  """
109
111
 
110
112
  _TMP_DIR = tempfile.TemporaryDirectory("wandb-artifacts")
111
- _GQL_FRAGMENT = """
112
- fragment ArtifactFragment on Artifact {
113
- id
114
- artifactSequence {
115
- project {
116
- entityName
117
- name
118
- }
119
- name
120
- }
121
- versionIndex
122
- artifactType {
123
- name
124
- }
125
- description
126
- metadata
127
- aliases {
128
- artifactCollection {
129
- project {
130
- entityName
131
- name
132
- }
133
- name
134
- }
135
- alias
136
- }
137
- state
138
- commitHash
139
- fileCount
140
- createdAt
141
- updatedAt
142
- }
143
- """
144
113
 
145
114
  def __init__(
146
115
  self,
@@ -166,15 +135,12 @@ class Artifact:
166
135
 
167
136
  # Internal.
168
137
  self._client: Optional[RetryingClient] = None
169
- storage_layout = (
170
- StorageLayout.V1 if env.get_use_v1_artifacts() else StorageLayout.V2
171
- )
172
- self._storage_policy = WandbStoragePolicy(
173
- config={
174
- "storageLayout": storage_layout,
175
- # TODO: storage region
176
- }
177
- )
138
+
139
+ storage_policy_cls = StoragePolicy.lookup_by_name(WANDB_STORAGE_POLICY)
140
+ layout = StorageLayout.V1 if env.get_use_v1_artifacts() else StorageLayout.V2
141
+ policy_config = {"storageLayout": layout}
142
+ self._storage_policy = storage_policy_cls.from_config(config=policy_config)
143
+
178
144
  self._tmp_dir: Optional[tempfile.TemporaryDirectory] = None
179
145
  self._added_objs: Dict[
180
146
  int, Tuple[data_types.WBValue, ArtifactManifestEntry]
@@ -200,6 +166,9 @@ class Artifact:
200
166
  self._type: str = type
201
167
  self._description: Optional[str] = description
202
168
  self._metadata: dict = self._normalize_metadata(metadata)
169
+ self._ttl_duration_seconds: Optional[int] = None
170
+ self._ttl_is_inherited: bool = True
171
+ self._ttl_changed: bool = False
203
172
  self._aliases: List[str] = []
204
173
  self._saved_aliases: List[str] = []
205
174
  self._distributed_id: Optional[str] = None
@@ -214,15 +183,16 @@ class Artifact:
214
183
  self._created_at: Optional[str] = None
215
184
  self._updated_at: Optional[str] = None
216
185
  self._final: bool = False
186
+
217
187
  # Cache.
218
- get_artifacts_cache().store_client_artifact(self)
188
+ artifact_cache[self._client_id] = self
219
189
 
220
190
  def __repr__(self) -> str:
221
191
  return f"<Artifact {self.id or self.name}>"
222
192
 
223
193
  @classmethod
224
194
  def _from_id(cls, artifact_id: str, client: RetryingClient) -> Optional["Artifact"]:
225
- artifact = get_artifacts_cache().get_artifact(artifact_id)
195
+ artifact = artifact_cache.get(artifact_id)
226
196
  if artifact is not None:
227
197
  return artifact
228
198
 
@@ -239,7 +209,7 @@ class Artifact:
239
209
  }
240
210
  }
241
211
  """
242
- + cls._GQL_FRAGMENT
212
+ + cls._get_gql_artifact_fragment()
243
213
  )
244
214
  response = client.execute(
245
215
  query,
@@ -271,7 +241,7 @@ class Artifact:
271
241
  }
272
242
  }
273
243
  """
274
- + cls._GQL_FRAGMENT
244
+ + cls._get_gql_artifact_fragment()
275
245
  )
276
246
  response = client.execute(
277
247
  query,
@@ -327,6 +297,12 @@ class Artifact:
327
297
  artifact.metadata = cls._normalize_metadata(
328
298
  json.loads(attrs["metadata"] or "{}")
329
299
  )
300
+ artifact._ttl_duration_seconds = artifact._ttl_duration_seconds_from_gql(
301
+ attrs.get("ttlDurationSeconds")
302
+ )
303
+ artifact._ttl_is_inherited = (
304
+ True if attrs.get("ttlIsInherited") is None else attrs["ttlIsInherited"]
305
+ )
330
306
  artifact._aliases = [
331
307
  alias for alias in aliases if not util.alias_is_version_index(alias)
332
308
  ]
@@ -342,7 +318,9 @@ class Artifact:
342
318
  artifact._updated_at = attrs["updatedAt"]
343
319
  artifact._final = True
344
320
  # Cache.
345
- get_artifacts_cache().store_artifact(artifact)
321
+
322
+ assert artifact.id is not None
323
+ artifact_cache[artifact.id] = artifact
346
324
  return artifact
347
325
 
348
326
  def new_draft(self) -> "Artifact":
@@ -353,11 +331,22 @@ class Artifact:
353
331
  Raises:
354
332
  ArtifactNotLoggedError: if the artifact has not been logged
355
333
  """
356
- if self._state == ArtifactState.PENDING:
357
- raise ArtifactNotLoggedError(self, "new_draft")
334
+ self._ensure_logged("new_draft")
358
335
 
336
+ # Name, _entity and _project are set to the *source* name/entity/project:
337
+ # if this artifact is saved it must be saved to the source sequence.
359
338
  artifact = Artifact(self.source_name.split(":")[0], self.type)
339
+ artifact._entity = self._source_entity
340
+ artifact._project = self._source_project
341
+ artifact._source_entity = self._source_entity
342
+ artifact._source_project = self._source_project
343
+
344
+ # This artifact's parent is the one we are making a draft from.
360
345
  artifact._base_id = self.id
346
+
347
+ # We can reuse the client, and copy over all the attributes that aren't
348
+ # version-dependent and don't depend on having been logged.
349
+ artifact._client = self._client
361
350
  artifact._description = self.description
362
351
  artifact._metadata = self.metadata
363
352
  artifact._manifest = ArtifactManifest.from_manifest_json(
@@ -370,7 +359,7 @@ class Artifact:
370
359
  @property
371
360
  def id(self) -> Optional[str]:
372
361
  """The artifact's ID."""
373
- if self._state == ArtifactState.PENDING:
362
+ if self.is_draft():
374
363
  return None
375
364
  assert self._id is not None
376
365
  return self._id
@@ -378,16 +367,14 @@ class Artifact:
378
367
  @property
379
368
  def entity(self) -> str:
380
369
  """The name of the entity of the secondary (portfolio) artifact collection."""
381
- if self._state == ArtifactState.PENDING:
382
- raise ArtifactNotLoggedError(self, "entity")
370
+ self._ensure_logged("entity")
383
371
  assert self._entity is not None
384
372
  return self._entity
385
373
 
386
374
  @property
387
375
  def project(self) -> str:
388
376
  """The name of the project of the secondary (portfolio) artifact collection."""
389
- if self._state == ArtifactState.PENDING:
390
- raise ArtifactNotLoggedError(self, "project")
377
+ self._ensure_logged("project")
391
378
  assert self._project is not None
392
379
  return self._project
393
380
 
@@ -408,24 +395,34 @@ class Artifact:
408
395
  @property
409
396
  def version(self) -> str:
410
397
  """The artifact's version in its secondary (portfolio) collection."""
411
- if self._state == ArtifactState.PENDING:
412
- raise ArtifactNotLoggedError(self, "version")
398
+ self._ensure_logged("version")
413
399
  assert self._version is not None
414
400
  return self._version
415
401
 
402
+ @property
403
+ def collection(self) -> ArtifactCollection:
404
+ """The collection this artifact was retrieved from.
405
+
406
+ If this artifact was retrieved from a portfolio / linked collection, that
407
+ collection will be returned rather than the source sequence.
408
+ """
409
+ self._ensure_logged("collection")
410
+ base_name = self.name.split(":")[0]
411
+ return ArtifactCollection(
412
+ self._client, self.entity, self.project, base_name, self.type
413
+ )
414
+
416
415
  @property
417
416
  def source_entity(self) -> str:
418
417
  """The name of the entity of the primary (sequence) artifact collection."""
419
- if self._state == ArtifactState.PENDING:
420
- raise ArtifactNotLoggedError(self, "source_entity")
418
+ self._ensure_logged("source_entity")
421
419
  assert self._source_entity is not None
422
420
  return self._source_entity
423
421
 
424
422
  @property
425
423
  def source_project(self) -> str:
426
424
  """The name of the project of the primary (sequence) artifact collection."""
427
- if self._state == ArtifactState.PENDING:
428
- raise ArtifactNotLoggedError(self, "source_project")
425
+ self._ensure_logged("source_project")
429
426
  assert self._source_project is not None
430
427
  return self._source_project
431
428
 
@@ -449,11 +446,19 @@ class Artifact:
449
446
 
450
447
  A string with the format "v{number}".
451
448
  """
452
- if self._state == ArtifactState.PENDING:
453
- raise ArtifactNotLoggedError(self, "source_version")
449
+ self._ensure_logged("source_version")
454
450
  assert self._source_version is not None
455
451
  return self._source_version
456
452
 
453
+ @property
454
+ def source_collection(self) -> ArtifactCollection:
455
+ """The artifact's primary (sequence) collection."""
456
+ self._ensure_logged("source_collection")
457
+ base_name = self.source_name.split(":")[0]
458
+ return ArtifactCollection(
459
+ self._client, self.source_entity, self.source_project, base_name, self.type
460
+ )
461
+
457
462
  @property
458
463
  def type(self) -> str:
459
464
  """The artifact's type."""
@@ -501,21 +506,69 @@ class Artifact:
501
506
  """
502
507
  self._metadata = self._normalize_metadata(metadata)
503
508
 
509
+ @property
510
+ def ttl(self) -> Union[timedelta, None]:
511
+ """Time To Live (TTL).
512
+
513
+ The artifact will be deleted shortly after TTL since its creation.
514
+ None means the artifact will never expire.
515
+ If TTL is not set on an artifact, it will inherit the default for its collection.
516
+
517
+ Raises:
518
+ ArtifactNotLoggedError: Unable to fetch inherited TTL if the artifact has not been logged or saved
519
+ """
520
+ if self._ttl_is_inherited and (self.is_draft() or self._ttl_changed):
521
+ raise ArtifactNotLoggedError(self, "ttl")
522
+ if self._ttl_duration_seconds is None:
523
+ return None
524
+ return timedelta(seconds=self._ttl_duration_seconds)
525
+
526
+ @ttl.setter
527
+ def ttl(self, ttl: Union[timedelta, ArtifactTTL, None]) -> None:
528
+ """Time To Live (TTL).
529
+
530
+ The artifact will be deleted shortly after TTL since its creation. None means the artifact will never expire.
531
+ If TTL is not set on an artifact, it will inherit the default TTL rules for its collection.
532
+
533
+ Arguments:
534
+ ttl: How long the artifact will remain active from its creation.
535
+ - Timedelta must be positive.
536
+ - `None` means the artifact will never expire.
537
+ - wandb.ArtifactTTL.INHERIT will set the TTL to go back to the default and inherit from collection rules.
538
+ """
539
+ if self.type == "wandb-history":
540
+ raise ValueError("Cannot set artifact TTL for type wandb-history")
541
+
542
+ self._ttl_changed = True
543
+ if isinstance(ttl, ArtifactTTL):
544
+ if ttl == ArtifactTTL.INHERIT:
545
+ self._ttl_is_inherited = True
546
+ else:
547
+ raise ValueError(f"Unhandled ArtifactTTL enum {ttl}")
548
+ else:
549
+ self._ttl_is_inherited = False
550
+ if ttl is None:
551
+ self._ttl_duration_seconds = None
552
+ else:
553
+ if ttl.total_seconds() <= 0:
554
+ raise ValueError(
555
+ f"Artifact TTL Duration has to be positive. ttl: {ttl.total_seconds()}"
556
+ )
557
+ self._ttl_duration_seconds = int(ttl.total_seconds())
558
+
504
559
  @property
505
560
  def aliases(self) -> List[str]:
506
561
  """The aliases associated with this artifact.
507
562
 
508
563
  The list is mutable and calling `save()` will persist all alias changes.
509
564
  """
510
- if self._state == ArtifactState.PENDING:
511
- raise ArtifactNotLoggedError(self, "aliases")
565
+ self._ensure_logged("aliases")
512
566
  return self._aliases
513
567
 
514
568
  @aliases.setter
515
569
  def aliases(self, aliases: List[str]) -> None:
516
570
  """Set the aliases associated with this artifact."""
517
- if self._state == ArtifactState.PENDING:
518
- raise ArtifactNotLoggedError(self, "aliases")
571
+ self._ensure_logged("aliases")
519
572
 
520
573
  if any(char in alias for alias in aliases for char in ["/", ":"]):
521
574
  raise ValueError(
@@ -609,32 +662,28 @@ class Artifact:
609
662
  @property
610
663
  def commit_hash(self) -> str:
611
664
  """The hash returned when this artifact was committed."""
612
- if self._state == ArtifactState.PENDING:
613
- raise ArtifactNotLoggedError(self, "commit_hash")
665
+ self._ensure_logged("commit_hash")
614
666
  assert self._commit_hash is not None
615
667
  return self._commit_hash
616
668
 
617
669
  @property
618
670
  def file_count(self) -> int:
619
671
  """The number of files (including references)."""
620
- if self._state == ArtifactState.PENDING:
621
- raise ArtifactNotLoggedError(self, "file_count")
672
+ self._ensure_logged("file_count")
622
673
  assert self._file_count is not None
623
674
  return self._file_count
624
675
 
625
676
  @property
626
677
  def created_at(self) -> str:
627
678
  """The time at which the artifact was created."""
628
- if self._state == ArtifactState.PENDING:
629
- raise ArtifactNotLoggedError(self, "created_at")
679
+ self._ensure_logged("created_at")
630
680
  assert self._created_at is not None
631
681
  return self._created_at
632
682
 
633
683
  @property
634
684
  def updated_at(self) -> str:
635
685
  """The time at which the artifact was last updated."""
636
- if self._state == ArtifactState.PENDING:
637
- raise ArtifactNotLoggedError(self, "created_at")
686
+ self._ensure_logged("updated_at")
638
687
  assert self._created_at is not None
639
688
  return self._updated_at or self._created_at
640
689
 
@@ -651,6 +700,10 @@ class Artifact:
651
700
  if self._final:
652
701
  raise ArtifactFinalizedError(artifact=self)
653
702
 
703
+ def _ensure_logged(self, attr: Optional[str] = None) -> None:
704
+ if self.is_draft():
705
+ raise ArtifactNotLoggedError(self, attr)
706
+
654
707
  def is_draft(self) -> bool:
655
708
  """Whether the artifact is a draft, i.e. it hasn't been saved yet."""
656
709
  return self._state == ArtifactState.PENDING
@@ -684,7 +737,12 @@ class Artifact:
684
737
  if wandb.run is None:
685
738
  if settings is None:
686
739
  settings = wandb.Settings(silent="true")
687
- with wandb.init(project=project, job_type="auto", settings=settings) as run:
740
+ with wandb.init(
741
+ entity=self._source_entity,
742
+ project=project or self._source_project,
743
+ job_type="auto",
744
+ settings=settings,
745
+ ) as run:
688
746
  # redoing this here because in this branch we know we didn't
689
747
  # have the run at the beginning of the method
690
748
  if self._incremental:
@@ -706,7 +764,7 @@ class Artifact:
706
764
  Arguments:
707
765
  timeout: Wait up to this long.
708
766
  """
709
- if self._state == ArtifactState.PENDING:
767
+ if self.is_draft():
710
768
  if self._save_future is None:
711
769
  raise ArtifactNotLoggedError(self, "wait")
712
770
  result = self._save_future.get(timeout)
@@ -721,8 +779,7 @@ class Artifact:
721
779
  return self
722
780
 
723
781
  def _populate_after_save(self, artifact_id: str) -> None:
724
- query = gql(
725
- """
782
+ query_template = """
726
783
  query ArtifactByIDShort($id: ID!) {
727
784
  artifact(id: $id) {
728
785
  artifactSequence {
@@ -733,6 +790,8 @@ class Artifact:
733
790
  name
734
791
  }
735
792
  versionIndex
793
+ ttlDurationSeconds
794
+ ttlIsInherited
736
795
  aliases {
737
796
  artifactCollection {
738
797
  project {
@@ -755,8 +814,16 @@ class Artifact:
755
814
  updatedAt
756
815
  }
757
816
  }
758
- """
759
- )
817
+ """
818
+
819
+ fields = InternalApi().server_artifact_introspection()
820
+ if "ttlIsInherited" not in fields:
821
+ query_template = query_template.replace("ttlDurationSeconds", "").replace(
822
+ "ttlIsInherited",
823
+ "",
824
+ )
825
+ query = gql(query_template)
826
+
760
827
  assert self._client is not None
761
828
  response = self._client.execute(
762
829
  query,
@@ -776,6 +843,13 @@ class Artifact:
776
843
  self._source_project = self._project
777
844
  self._source_name = self._name
778
845
  self._source_version = self._version
846
+ self._ttl_duration_seconds = self._ttl_duration_seconds_from_gql(
847
+ attrs.get("ttlDurationSeconds")
848
+ )
849
+ self._ttl_is_inherited = (
850
+ True if attrs.get("ttlIsInherited") is None else attrs["ttlIsInherited"]
851
+ )
852
+ self._ttl_changed = False # Reset after saving artifact
779
853
  self._aliases = [
780
854
  alias["alias"]
781
855
  for alias in attrs["aliases"]
@@ -884,12 +958,12 @@ class Artifact:
884
958
  for alias in self._aliases
885
959
  ]
886
960
 
887
- mutation = gql(
888
- """
961
+ mutation_template = """
889
962
  mutation updateArtifact(
890
963
  $artifactID: ID!,
891
964
  $description: String,
892
965
  $metadata: JSONString,
966
+ _TTL_DURATION_SECONDS_TYPE_
893
967
  $aliases: [ArtifactAliasInput!]
894
968
  ) {
895
969
  updateArtifact(
@@ -897,26 +971,68 @@ class Artifact:
897
971
  artifactID: $artifactID,
898
972
  description: $description,
899
973
  metadata: $metadata,
974
+ _TTL_DURATION_SECONDS_VALUE_
900
975
  aliases: $aliases
901
976
  }
902
977
  ) {
903
978
  artifact {
904
979
  id
980
+ _TTL_DURATION_SECONDS_FIELDS_
905
981
  }
906
982
  }
907
983
  }
908
- """
909
- )
984
+ """
985
+ fields = InternalApi().server_artifact_introspection()
986
+ if "ttlIsInherited" in fields:
987
+ mutation_template = (
988
+ mutation_template.replace(
989
+ "_TTL_DURATION_SECONDS_TYPE_", "$ttlDurationSeconds: Int64,"
990
+ )
991
+ .replace(
992
+ "_TTL_DURATION_SECONDS_VALUE_",
993
+ "ttlDurationSeconds: $ttlDurationSeconds,",
994
+ )
995
+ .replace(
996
+ "_TTL_DURATION_SECONDS_FIELDS_", "ttlDurationSeconds ttlIsInherited"
997
+ )
998
+ )
999
+ else:
1000
+ if self._ttl_changed:
1001
+ termwarn(
1002
+ "Server not compatible with setting Artifact TTLs, please upgrade the server to use Artifact TTL"
1003
+ )
1004
+ mutation_template = (
1005
+ mutation_template.replace("_TTL_DURATION_SECONDS_TYPE_", "")
1006
+ .replace(
1007
+ "_TTL_DURATION_SECONDS_VALUE_",
1008
+ "",
1009
+ )
1010
+ .replace("_TTL_DURATION_SECONDS_FIELDS_", "")
1011
+ )
1012
+ mutation = gql(mutation_template)
910
1013
  assert self._client is not None
911
- self._client.execute(
1014
+
1015
+ ttl_duration_input = self._ttl_duration_seconds_to_gql()
1016
+ response = self._client.execute(
912
1017
  mutation,
913
1018
  variable_values={
914
1019
  "artifactID": self.id,
915
1020
  "description": self.description,
916
1021
  "metadata": util.json_dumps_safer(self.metadata),
1022
+ "ttlDurationSeconds": ttl_duration_input,
917
1023
  "aliases": aliases,
918
1024
  },
919
1025
  )
1026
+ attrs = response["updateArtifact"]["artifact"]
1027
+
1028
+ # Update ttl_duration_seconds based on updateArtifact
1029
+ self._ttl_duration_seconds = self._ttl_duration_seconds_from_gql(
1030
+ attrs.get("ttlDurationSeconds")
1031
+ )
1032
+ self._ttl_is_inherited = (
1033
+ True if attrs.get("ttlIsInherited") is None else attrs["ttlIsInherited"]
1034
+ )
1035
+ self._ttl_changed = False # Reset after updating artifact
920
1036
 
921
1037
  # Adding, removing, getting entries.
922
1038
 
@@ -1424,8 +1540,7 @@ class Artifact:
1424
1540
  path.download()
1425
1541
  ```
1426
1542
  """
1427
- if self._state == ArtifactState.PENDING:
1428
- raise ArtifactNotLoggedError(self, "get_path")
1543
+ self._ensure_logged("get_path")
1429
1544
 
1430
1545
  name = LogicalPath(name)
1431
1546
  entry = self.manifest.entries.get(name) or self._get_obj_entry(name)[0]
@@ -1461,8 +1576,7 @@ class Artifact:
1461
1576
  table = artifact.get("my_table")
1462
1577
  ```
1463
1578
  """
1464
- if self._state == ArtifactState.PENDING:
1465
- raise ArtifactNotLoggedError(self, "get")
1579
+ self._ensure_logged("get")
1466
1580
 
1467
1581
  entry, wb_class = self._get_obj_entry(name)
1468
1582
  if entry is None or wb_class is None:
@@ -1470,9 +1584,11 @@ class Artifact:
1470
1584
 
1471
1585
  # If the entry is a reference from another artifact, then get it directly from
1472
1586
  # that artifact.
1473
- if entry._is_artifact_reference():
1587
+ referenced_id = entry._referenced_artifact_id()
1588
+ if referenced_id:
1474
1589
  assert self._client is not None
1475
- artifact = entry._get_referenced_artifact(self._client)
1590
+ artifact = self._from_id(referenced_id, client=self._client)
1591
+ assert artifact is not None
1476
1592
  return artifact.get(util.uri_from_path(entry.ref))
1477
1593
 
1478
1594
  # Special case for wandb.Table. This is intended to be a short term
@@ -1565,8 +1681,7 @@ class Artifact:
1565
1681
  Raises:
1566
1682
  ArtifactNotLoggedError: if the artifact has not been logged
1567
1683
  """
1568
- if self._state == ArtifactState.PENDING:
1569
- raise ArtifactNotLoggedError(self, "download")
1684
+ self._ensure_logged("download")
1570
1685
 
1571
1686
  root = root or self._default_root()
1572
1687
  self._add_download_root(root)
@@ -1581,7 +1696,7 @@ class Artifact:
1581
1696
  self.name, size / (1024 * 1024), nfiles
1582
1697
  ),
1583
1698
  )
1584
- start_time = datetime.datetime.now()
1699
+ start_time = datetime.now()
1585
1700
  download_logger = ArtifactDownloadLogger(nfiles=nfiles)
1586
1701
 
1587
1702
  def _download_entry(
@@ -1615,7 +1730,8 @@ class Artifact:
1615
1730
  has_next_page = True
1616
1731
  cursor = None
1617
1732
  while has_next_page:
1618
- attrs = self._fetch_file_urls(cursor)
1733
+ fetch_url_batch_size = env.get_artifact_fetch_file_url_batch_size()
1734
+ attrs = self._fetch_file_urls(cursor, fetch_url_batch_size)
1619
1735
  has_next_page = attrs["pageInfo"]["hasNextPage"]
1620
1736
  cursor = attrs["pageInfo"]["endCursor"]
1621
1737
  for edge in attrs["edges"]:
@@ -1623,7 +1739,7 @@ class Artifact:
1623
1739
  entry._download_url = edge["node"]["directUrl"]
1624
1740
  active_futures.add(executor.submit(download_entry, entry))
1625
1741
  # Wait for download threads to catch up.
1626
- max_backlog = 5000
1742
+ max_backlog = fetch_url_batch_size
1627
1743
  if len(active_futures) > max_backlog:
1628
1744
  for future in concurrent.futures.as_completed(active_futures):
1629
1745
  future.result() # check for errors
@@ -1639,7 +1755,7 @@ class Artifact:
1639
1755
  dependent_artifact.download()
1640
1756
 
1641
1757
  if log:
1642
- now = datetime.datetime.now()
1758
+ now = datetime.now()
1643
1759
  delta = abs((now - start_time).total_seconds())
1644
1760
  hours = int(delta // 3600)
1645
1761
  minutes = int((delta - hours * 3600) // 60)
@@ -1651,15 +1767,17 @@ class Artifact:
1651
1767
  return FilePathStr(root)
1652
1768
 
1653
1769
  @retry.retriable(
1654
- retry_timedelta=datetime.timedelta(minutes=3),
1770
+ retry_timedelta=timedelta(minutes=3),
1655
1771
  retryable_exceptions=(requests.RequestException),
1656
1772
  )
1657
- def _fetch_file_urls(self, cursor: Optional[str]) -> Any:
1773
+ def _fetch_file_urls(
1774
+ self, cursor: Optional[str], per_page: Optional[int] = 5000
1775
+ ) -> Any:
1658
1776
  query = gql(
1659
1777
  """
1660
- query ArtifactFileURLs($id: ID!, $cursor: String) {
1778
+ query ArtifactFileURLs($id: ID!, $cursor: String, $perPage: Int) {
1661
1779
  artifact(id: $id) {
1662
- files(after: $cursor, first: 5000) {
1780
+ files(after: $cursor, first: $perPage) {
1663
1781
  pageInfo {
1664
1782
  hasNextPage
1665
1783
  endCursor
@@ -1678,7 +1796,7 @@ class Artifact:
1678
1796
  assert self._client is not None
1679
1797
  response = self._client.execute(
1680
1798
  query,
1681
- variable_values={"id": self.id, "cursor": cursor},
1799
+ variable_values={"id": self.id, "cursor": cursor, "perPage": per_page},
1682
1800
  timeout=60,
1683
1801
  )
1684
1802
  return response["artifact"]["files"]
@@ -1698,8 +1816,7 @@ class Artifact:
1698
1816
  Raises:
1699
1817
  ArtifactNotLoggedError: if the artifact has not been logged
1700
1818
  """
1701
- if self._state == ArtifactState.PENDING:
1702
- raise ArtifactNotLoggedError(self, "checkout")
1819
+ self._ensure_logged("checkout")
1703
1820
 
1704
1821
  root = root or self._default_root(include_version=False)
1705
1822
 
@@ -1731,8 +1848,7 @@ class Artifact:
1731
1848
  ArtifactNotLoggedError: if the artifact has not been logged
1732
1849
  ValueError: If the verification fails.
1733
1850
  """
1734
- if self._state == ArtifactState.PENDING:
1735
- raise ArtifactNotLoggedError(self, "verify")
1851
+ self._ensure_logged("verify")
1736
1852
 
1737
1853
  root = root or self._default_root()
1738
1854
 
@@ -1773,8 +1889,7 @@ class Artifact:
1773
1889
  ArtifactNotLoggedError: if the artifact has not been logged
1774
1890
  ValueError: if the artifact contains more than one file
1775
1891
  """
1776
- if self._state == ArtifactState.PENDING:
1777
- raise ArtifactNotLoggedError(self, "file")
1892
+ self._ensure_logged("file")
1778
1893
 
1779
1894
  if root is None:
1780
1895
  root = os.path.join(".", "artifacts", self.name)
@@ -1803,18 +1918,17 @@ class Artifact:
1803
1918
  Raises:
1804
1919
  ArtifactNotLoggedError: if the artifact has not been logged
1805
1920
  """
1806
- if self._state == ArtifactState.PENDING:
1807
- raise ArtifactNotLoggedError(self, "files")
1808
-
1921
+ self._ensure_logged("files")
1809
1922
  return ArtifactFiles(self._client, self, names, per_page)
1810
1923
 
1811
- def _default_root(self, include_version: bool = True) -> str:
1924
+ def _default_root(self, include_version: bool = True) -> FilePathStr:
1812
1925
  name = self.source_name if include_version else self.source_name.split(":")[0]
1813
1926
  root = os.path.join(env.get_artifact_dir(), name)
1814
- if platform.system() == "Windows":
1815
- head, tail = os.path.splitdrive(root)
1816
- root = head + tail.replace(":", "-")
1817
- return root
1927
+ # In case we're on a system where the artifact dir has a name corresponding to
1928
+ # an unexpected filesystem, we'll check for alternate roots. If one exists we'll
1929
+ # use that, otherwise we'll fall back to the system-preferred path.
1930
+ path = filesystem.check_exists(root) or filesystem.system_preferred_path(root)
1931
+ return FilePathStr(str(path))
1818
1932
 
1819
1933
  def _add_download_root(self, dir_path: str) -> None:
1820
1934
  self._download_roots.add(os.path.abspath(dir_path))
@@ -1851,8 +1965,7 @@ class Artifact:
1851
1965
  artifact.delete(delete_aliases=True)
1852
1966
  ```
1853
1967
  """
1854
- if self._state == ArtifactState.PENDING:
1855
- raise ArtifactNotLoggedError(self, "delete")
1968
+ self._ensure_logged("delete")
1856
1969
  self._delete(delete_aliases)
1857
1970
 
1858
1971
  @normalize_exceptions
@@ -1880,6 +1993,7 @@ class Artifact:
1880
1993
  },
1881
1994
  )
1882
1995
 
1996
+ @normalize_exceptions
1883
1997
  def link(self, target_path: str, aliases: Optional[List[str]] = None) -> None:
1884
1998
  """Link this artifact to a portfolio (a promoted collection of artifacts).
1885
1999
 
@@ -1892,63 +2006,16 @@ class Artifact:
1892
2006
  Raises:
1893
2007
  ArtifactNotLoggedError: if the artifact has not been logged
1894
2008
  """
1895
- if self._state == ArtifactState.PENDING:
1896
- raise ArtifactNotLoggedError(self, "link")
1897
- self._link(target_path, aliases)
1898
-
1899
- @normalize_exceptions
1900
- def _link(self, target_path: str, aliases: Optional[List[str]] = None) -> None:
1901
- if ":" in target_path:
1902
- raise ValueError(
1903
- f"target_path {target_path} cannot contain `:` because it is not an "
1904
- f"alias."
1905
- )
1906
-
1907
- portfolio, project, entity = util._parse_entity_project_item(target_path)
1908
- aliases = util._resolve_aliases(aliases)
1909
-
1910
- run_entity = wandb.run.entity if wandb.run else None
1911
- run_project = wandb.run.project if wandb.run else None
1912
- entity = entity or run_entity or self.entity
1913
- project = project or run_project or self.project
1914
-
1915
- mutation = gql(
1916
- """
1917
- mutation LinkArtifact(
1918
- $artifactID: ID!,
1919
- $artifactPortfolioName: String!,
1920
- $entityName: String!,
1921
- $projectName: String!,
1922
- $aliases: [ArtifactAliasInput!]
1923
- ) {
1924
- linkArtifact(
1925
- input: {
1926
- artifactID: $artifactID,
1927
- artifactPortfolioName: $artifactPortfolioName,
1928
- entityName: $entityName,
1929
- projectName: $projectName,
1930
- aliases: $aliases
1931
- }
1932
- ) {
1933
- versionIndex
1934
- }
1935
- }
1936
- """
1937
- )
1938
- assert self._client is not None
1939
- self._client.execute(
1940
- mutation,
1941
- variable_values={
1942
- "artifactID": self.id,
1943
- "artifactPortfolioName": portfolio,
1944
- "entityName": entity,
1945
- "projectName": project,
1946
- "aliases": [
1947
- {"alias": alias, "artifactCollectionName": portfolio}
1948
- for alias in aliases
1949
- ],
1950
- },
1951
- )
2009
+ if wandb.run is None:
2010
+ with wandb.init(
2011
+ entity=self._source_entity,
2012
+ project=self._source_project,
2013
+ job_type="auto",
2014
+ settings=wandb.Settings(silent="true"),
2015
+ ) as run:
2016
+ run.link_artifact(self, target_path, aliases)
2017
+ else:
2018
+ wandb.run.link_artifact(self, target_path, aliases)
1952
2019
 
1953
2020
  def used_by(self) -> List[Run]:
1954
2021
  """Get a list of the runs that have used this artifact.
@@ -1956,8 +2023,7 @@ class Artifact:
1956
2023
  Raises:
1957
2024
  ArtifactNotLoggedError: if the artifact has not been logged
1958
2025
  """
1959
- if self._state == ArtifactState.PENDING:
1960
- raise ArtifactNotLoggedError(self, "used_by")
2026
+ self._ensure_logged("used_by")
1961
2027
 
1962
2028
  query = gql(
1963
2029
  """
@@ -2001,8 +2067,7 @@ class Artifact:
2001
2067
  Raises:
2002
2068
  ArtifactNotLoggedError: if the artifact has not been logged
2003
2069
  """
2004
- if self._state == ArtifactState.PENDING:
2005
- raise ArtifactNotLoggedError(self, "logged_by")
2070
+ self._ensure_logged("logged_by")
2006
2071
 
2007
2072
  query = gql(
2008
2073
  """
@@ -2039,8 +2104,7 @@ class Artifact:
2039
2104
  )
2040
2105
 
2041
2106
  def json_encode(self) -> Dict[str, Any]:
2042
- if self._state == ArtifactState.PENDING:
2043
- raise ArtifactNotLoggedError(self, "json_encode")
2107
+ self._ensure_logged("json_encode")
2044
2108
  return util.artifact_to_json(self)
2045
2109
 
2046
2110
  @staticmethod
@@ -2097,11 +2161,79 @@ class Artifact:
2097
2161
  json.loads(util.ensure_text(request.content))
2098
2162
  )
2099
2163
  for entry in self.manifest.entries.values():
2100
- if entry._is_artifact_reference():
2164
+ referenced_id = entry._referenced_artifact_id()
2165
+ if referenced_id:
2101
2166
  assert self._client is not None
2102
- dep_artifact = entry._get_referenced_artifact(self._client)
2167
+ dep_artifact = self._from_id(referenced_id, client=self._client)
2168
+ assert dep_artifact is not None
2103
2169
  self._dependent_artifacts.add(dep_artifact)
2104
2170
 
2171
+ @staticmethod
2172
+ def _get_gql_artifact_fragment() -> str:
2173
+ fields = InternalApi().server_artifact_introspection()
2174
+ fragment = """
2175
+ fragment ArtifactFragment on Artifact {
2176
+ id
2177
+ artifactSequence {
2178
+ project {
2179
+ entityName
2180
+ name
2181
+ }
2182
+ name
2183
+ }
2184
+ versionIndex
2185
+ artifactType {
2186
+ name
2187
+ }
2188
+ description
2189
+ metadata
2190
+ ttlDurationSeconds
2191
+ ttlIsInherited
2192
+ aliases {
2193
+ artifactCollection {
2194
+ project {
2195
+ entityName
2196
+ name
2197
+ }
2198
+ name
2199
+ }
2200
+ alias
2201
+ }
2202
+ state
2203
+ commitHash
2204
+ fileCount
2205
+ createdAt
2206
+ updatedAt
2207
+ }
2208
+ """
2209
+ if "ttlIsInherited" not in fields:
2210
+ return fragment.replace("ttlDurationSeconds", "").replace(
2211
+ "ttlIsInherited", ""
2212
+ )
2213
+ return fragment
2214
+
2215
+ def _ttl_duration_seconds_to_gql(self) -> Optional[int]:
2216
+ # Set artifact ttl value to ttl_duration_seconds if the user set a value
2217
+ # otherwise use ttl_status to indicate the backend INHERIT(-1) or DISABLED(-2) when the TTL is None
2218
+ # When ttl_change = None its a no op since nothing changed
2219
+ INHERIT = -1 # noqa: N806
2220
+ DISABLED = -2 # noqa: N806
2221
+
2222
+ if not self._ttl_changed:
2223
+ return None
2224
+ if self._ttl_is_inherited:
2225
+ return INHERIT
2226
+ return self._ttl_duration_seconds or DISABLED
2227
+
2228
+ def _ttl_duration_seconds_from_gql(
2229
+ self, gql_ttl_duration_seconds: Optional[int]
2230
+ ) -> Optional[int]:
2231
+ # If gql_ttl_duration_seconds is not positive, its indicating that TTL is DISABLED(-2)
2232
+ # gql_ttl_duration_seconds only returns None if the server is not compatible with setting Artifact TTLs
2233
+ if gql_ttl_duration_seconds and gql_ttl_duration_seconds > 0:
2234
+ return gql_ttl_duration_seconds
2235
+ return None
2236
+
2105
2237
 
2106
2238
  class _ArtifactVersionType(WBType):
2107
2239
  name = "artifactVersion"