wandb 0.15.3__py3-none-any.whl → 0.15.5__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (156) hide show
  1. wandb/__init__.py +1 -1
  2. wandb/analytics/sentry.py +1 -0
  3. wandb/apis/importers/base.py +20 -5
  4. wandb/apis/importers/mlflow.py +7 -1
  5. wandb/apis/internal.py +12 -0
  6. wandb/apis/public.py +247 -1387
  7. wandb/apis/reports/_panels.py +58 -35
  8. wandb/beta/workflows.py +6 -7
  9. wandb/cli/cli.py +130 -60
  10. wandb/data_types.py +3 -1
  11. wandb/filesync/dir_watcher.py +21 -27
  12. wandb/filesync/step_checksum.py +8 -8
  13. wandb/filesync/step_prepare.py +23 -10
  14. wandb/filesync/step_upload.py +13 -13
  15. wandb/filesync/upload_job.py +4 -8
  16. wandb/integration/cohere/__init__.py +3 -0
  17. wandb/integration/cohere/cohere.py +21 -0
  18. wandb/integration/cohere/resolver.py +347 -0
  19. wandb/integration/gym/__init__.py +4 -6
  20. wandb/integration/huggingface/__init__.py +3 -0
  21. wandb/integration/huggingface/huggingface.py +18 -0
  22. wandb/integration/huggingface/resolver.py +213 -0
  23. wandb/integration/langchain/wandb_tracer.py +16 -179
  24. wandb/integration/openai/__init__.py +1 -3
  25. wandb/integration/openai/openai.py +11 -143
  26. wandb/integration/openai/resolver.py +111 -38
  27. wandb/integration/sagemaker/config.py +2 -2
  28. wandb/integration/tensorboard/log.py +4 -4
  29. wandb/old/settings.py +24 -7
  30. wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
  31. wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
  32. wandb/proto/wandb_deprecated.py +3 -1
  33. wandb/sdk/__init__.py +1 -1
  34. wandb/sdk/artifacts/__init__.py +0 -0
  35. wandb/sdk/artifacts/artifact.py +2101 -0
  36. wandb/sdk/artifacts/artifact_download_logger.py +42 -0
  37. wandb/sdk/artifacts/artifact_manifest.py +67 -0
  38. wandb/sdk/artifacts/artifact_manifest_entry.py +159 -0
  39. wandb/sdk/artifacts/artifact_manifests/__init__.py +0 -0
  40. wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +91 -0
  41. wandb/sdk/{internal → artifacts}/artifact_saver.py +6 -5
  42. wandb/sdk/artifacts/artifact_state.py +10 -0
  43. wandb/sdk/{interface/artifacts/artifact_cache.py → artifacts/artifacts_cache.py} +22 -12
  44. wandb/sdk/artifacts/exceptions.py +55 -0
  45. wandb/sdk/artifacts/storage_handler.py +59 -0
  46. wandb/sdk/artifacts/storage_handlers/__init__.py +0 -0
  47. wandb/sdk/artifacts/storage_handlers/azure_handler.py +192 -0
  48. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +224 -0
  49. wandb/sdk/artifacts/storage_handlers/http_handler.py +112 -0
  50. wandb/sdk/artifacts/storage_handlers/local_file_handler.py +134 -0
  51. wandb/sdk/artifacts/storage_handlers/multi_handler.py +53 -0
  52. wandb/sdk/artifacts/storage_handlers/s3_handler.py +301 -0
  53. wandb/sdk/artifacts/storage_handlers/tracking_handler.py +67 -0
  54. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +132 -0
  55. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +72 -0
  56. wandb/sdk/artifacts/storage_layout.py +6 -0
  57. wandb/sdk/artifacts/storage_policies/__init__.py +0 -0
  58. wandb/sdk/artifacts/storage_policies/s3_bucket_policy.py +61 -0
  59. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +386 -0
  60. wandb/sdk/{interface/artifacts/artifact_storage.py → artifacts/storage_policy.py} +5 -57
  61. wandb/sdk/data_types/_dtypes.py +7 -12
  62. wandb/sdk/data_types/base_types/json_metadata.py +3 -2
  63. wandb/sdk/data_types/base_types/media.py +8 -8
  64. wandb/sdk/data_types/base_types/wb_value.py +12 -13
  65. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +5 -6
  66. wandb/sdk/data_types/helper_types/classes.py +6 -8
  67. wandb/sdk/data_types/helper_types/image_mask.py +5 -6
  68. wandb/sdk/data_types/histogram.py +4 -3
  69. wandb/sdk/data_types/html.py +3 -4
  70. wandb/sdk/data_types/image.py +11 -9
  71. wandb/sdk/data_types/molecule.py +5 -3
  72. wandb/sdk/data_types/object_3d.py +7 -5
  73. wandb/sdk/data_types/plotly.py +3 -2
  74. wandb/sdk/data_types/saved_model.py +11 -11
  75. wandb/sdk/data_types/trace_tree.py +5 -4
  76. wandb/sdk/data_types/utils.py +3 -5
  77. wandb/sdk/data_types/video.py +5 -4
  78. wandb/sdk/integration_utils/auto_logging.py +215 -0
  79. wandb/sdk/interface/interface.py +15 -15
  80. wandb/sdk/internal/file_pusher.py +8 -16
  81. wandb/sdk/internal/file_stream.py +5 -11
  82. wandb/sdk/internal/handler.py +13 -1
  83. wandb/sdk/internal/internal_api.py +287 -13
  84. wandb/sdk/internal/job_builder.py +119 -30
  85. wandb/sdk/internal/sender.py +6 -26
  86. wandb/sdk/internal/settings_static.py +2 -0
  87. wandb/sdk/internal/system/assets/__init__.py +2 -0
  88. wandb/sdk/internal/system/assets/gpu.py +42 -0
  89. wandb/sdk/internal/system/assets/gpu_amd.py +216 -0
  90. wandb/sdk/internal/system/env_probe_helpers.py +13 -0
  91. wandb/sdk/internal/system/system_info.py +3 -3
  92. wandb/sdk/internal/tb_watcher.py +32 -22
  93. wandb/sdk/internal/thread_local_settings.py +18 -0
  94. wandb/sdk/launch/_project_spec.py +57 -11
  95. wandb/sdk/launch/agent/agent.py +147 -65
  96. wandb/sdk/launch/agent/job_status_tracker.py +34 -0
  97. wandb/sdk/launch/agent/run_queue_item_file_saver.py +45 -0
  98. wandb/sdk/launch/builder/abstract.py +5 -1
  99. wandb/sdk/launch/builder/build.py +21 -18
  100. wandb/sdk/launch/builder/docker_builder.py +10 -4
  101. wandb/sdk/launch/builder/kaniko_builder.py +113 -23
  102. wandb/sdk/launch/builder/noop.py +6 -3
  103. wandb/sdk/launch/builder/templates/_wandb_bootstrap.py +46 -14
  104. wandb/sdk/launch/environment/aws_environment.py +3 -2
  105. wandb/sdk/launch/environment/azure_environment.py +124 -0
  106. wandb/sdk/launch/environment/gcp_environment.py +2 -4
  107. wandb/sdk/launch/environment/local_environment.py +1 -1
  108. wandb/sdk/launch/errors.py +19 -0
  109. wandb/sdk/launch/github_reference.py +32 -19
  110. wandb/sdk/launch/launch.py +3 -8
  111. wandb/sdk/launch/launch_add.py +6 -2
  112. wandb/sdk/launch/loader.py +21 -2
  113. wandb/sdk/launch/registry/azure_container_registry.py +132 -0
  114. wandb/sdk/launch/registry/elastic_container_registry.py +39 -5
  115. wandb/sdk/launch/registry/google_artifact_registry.py +68 -26
  116. wandb/sdk/launch/registry/local_registry.py +2 -1
  117. wandb/sdk/launch/runner/abstract.py +24 -3
  118. wandb/sdk/launch/runner/kubernetes_runner.py +479 -26
  119. wandb/sdk/launch/runner/local_container.py +103 -51
  120. wandb/sdk/launch/runner/local_process.py +1 -1
  121. wandb/sdk/launch/runner/sagemaker_runner.py +60 -10
  122. wandb/sdk/launch/runner/vertex_runner.py +10 -5
  123. wandb/sdk/launch/sweeps/__init__.py +7 -9
  124. wandb/sdk/launch/sweeps/scheduler.py +307 -77
  125. wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
  126. wandb/sdk/launch/sweeps/utils.py +82 -35
  127. wandb/sdk/launch/utils.py +89 -75
  128. wandb/sdk/lib/_settings_toposort_generated.py +7 -0
  129. wandb/sdk/lib/capped_dict.py +26 -0
  130. wandb/sdk/lib/{git.py → gitlib.py} +76 -59
  131. wandb/sdk/lib/hashutil.py +12 -4
  132. wandb/sdk/lib/paths.py +96 -8
  133. wandb/sdk/lib/sock_client.py +2 -2
  134. wandb/sdk/lib/timer.py +1 -0
  135. wandb/sdk/service/server.py +22 -9
  136. wandb/sdk/service/server_sock.py +1 -1
  137. wandb/sdk/service/service.py +27 -8
  138. wandb/sdk/verify/verify.py +4 -7
  139. wandb/sdk/wandb_config.py +2 -6
  140. wandb/sdk/wandb_init.py +57 -53
  141. wandb/sdk/wandb_require.py +7 -0
  142. wandb/sdk/wandb_run.py +61 -223
  143. wandb/sdk/wandb_settings.py +28 -4
  144. wandb/testing/relay.py +15 -2
  145. wandb/util.py +74 -36
  146. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/METADATA +15 -9
  147. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/RECORD +151 -116
  148. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/entry_points.txt +1 -0
  149. wandb/integration/langchain/util.py +0 -191
  150. wandb/sdk/interface/artifacts/__init__.py +0 -33
  151. wandb/sdk/interface/artifacts/artifact.py +0 -615
  152. wandb/sdk/interface/artifacts/artifact_manifest.py +0 -131
  153. wandb/sdk/wandb_artifacts.py +0 -2226
  154. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/LICENSE +0 -0
  155. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/WHEEL +0 -0
  156. {wandb-0.15.3.dist-info → wandb-0.15.5.dist-info}/top_level.txt +0 -0
wandb/apis/public.py CHANGED
@@ -15,20 +15,15 @@ import datetime
15
15
  import io
16
16
  import json
17
17
  import logging
18
- import multiprocessing.dummy # this uses threads
19
18
  import os
20
19
  import platform
21
- import re
22
20
  import shutil
23
21
  import tempfile
24
22
  import time
25
23
  import urllib
26
- from collections import namedtuple
27
- from functools import partial
28
24
  from typing import (
29
25
  TYPE_CHECKING,
30
26
  Any,
31
- Callable,
32
27
  Dict,
33
28
  List,
34
29
  Mapping,
@@ -45,21 +40,19 @@ import wandb
45
40
  from wandb import __version__, env, util
46
41
  from wandb.apis.internal import Api as InternalApi
47
42
  from wandb.apis.normalize import normalize_exceptions
48
- from wandb.data_types import WBValue
49
- from wandb.env import get_artifact_dir
50
43
  from wandb.errors import CommError
51
- from wandb.errors.term import termlog
52
44
  from wandb.sdk.data_types._dtypes import InvalidType, Type, TypeRegistry
53
- from wandb.sdk.interface import artifacts
45
+ from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
46
+ from wandb.sdk.launch.errors import LaunchError
54
47
  from wandb.sdk.launch.utils import (
55
48
  LAUNCH_DEFAULT_PROJECT,
56
- LaunchError,
57
49
  _fetch_git_repo,
58
50
  apply_patch,
51
+ convert_jupyter_notebook_to_script,
59
52
  )
60
- from wandb.sdk.lib import filesystem, ipython, retry, runid
53
+ from wandb.sdk.lib import ipython, retry, runid
61
54
  from wandb.sdk.lib.gql_request import GraphQLSession
62
- from wandb.sdk.lib.hashutil import b64_to_hex_id, hex_to_b64_id, md5_file_b64
55
+ from wandb.sdk.lib.paths import LogicalPath
63
56
 
64
57
  if TYPE_CHECKING:
65
58
  import wandb.apis.reports
@@ -144,41 +137,6 @@ fragment ArtifactTypesFragment on ArtifactTypeConnection {
144
137
  }
145
138
  """
146
139
 
147
- ARTIFACT_FRAGMENT = """
148
- fragment ArtifactFragment on Artifact {
149
- id
150
- digest
151
- description
152
- state
153
- size
154
- createdAt
155
- updatedAt
156
- labels
157
- metadata
158
- fileCount
159
- versionIndex
160
- aliases {
161
- artifactCollectionName
162
- alias
163
- }
164
- artifactSequence {
165
- id
166
- name
167
- }
168
- artifactType {
169
- id
170
- name
171
- project {
172
- name
173
- entity {
174
- name
175
- }
176
- }
177
- }
178
- commitHash
179
- }
180
- """
181
-
182
140
  # TODO, factor out common file fragment
183
141
  ARTIFACT_FILES_FRAGMENT = """fragment ArtifactFilesFragment on Artifact {
184
142
  files(names: $fileNames, after: $fileCursor, first: $fileLimit) {
@@ -407,7 +365,7 @@ class Api:
407
365
  self.settings = InternalApi().settings()
408
366
  _overrides = overrides or {}
409
367
  self._api_key = api_key
410
- if self.api_key is None:
368
+ if self.api_key is None and _thread_local_api_settings.cookies is None:
411
369
  wandb.login(host=_overrides.get("base_url"))
412
370
  self.settings.update(_overrides)
413
371
  if "username" in _overrides and "entity" not in _overrides:
@@ -424,15 +382,23 @@ class Api:
424
382
  self._reports = {}
425
383
  self._default_entity = None
426
384
  self._timeout = timeout if timeout is not None else self._HTTP_TIMEOUT
385
+ auth = None
386
+ if not _thread_local_api_settings.cookies:
387
+ auth = ("api", self.api_key)
427
388
  self._base_client = Client(
428
389
  transport=GraphQLSession(
429
- headers={"User-Agent": self.user_agent, "Use-Admin-Privileges": "true"},
390
+ headers={
391
+ "User-Agent": self.user_agent,
392
+ "Use-Admin-Privileges": "true",
393
+ **(_thread_local_api_settings.headers or {}),
394
+ },
430
395
  use_json=True,
431
396
  # this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
432
397
  # https://bugs.python.org/issue22889
433
398
  timeout=self._timeout,
434
- auth=("api", self.api_key),
399
+ auth=auth,
435
400
  url="%s/graphql" % self.settings["base_url"],
401
+ cookies=_thread_local_api_settings.cookies,
436
402
  )
437
403
  )
438
404
  self._client = RetryingClient(self._base_client)
@@ -525,6 +491,9 @@ class Api:
525
491
 
526
492
  @property
527
493
  def api_key(self):
494
+ # just use thread local api key if it's set
495
+ if _thread_local_api_settings.api_key:
496
+ return _thread_local_api_settings.api_key
528
497
  if self._api_key is not None:
529
498
  return self._api_key
530
499
  auth = requests.utils.get_netrc_auth(self.settings["base_url"])
@@ -939,14 +908,13 @@ class Api:
939
908
 
940
909
  @normalize_exceptions
941
910
  def artifact(self, name, type=None):
942
- """Return a single artifact by parsing path in the form `entity/project/run_id`.
911
+ """Return a single artifact by parsing path in the form `entity/project/name`.
943
912
 
944
913
  Arguments:
945
914
  name: (str) An artifact name. May be prefixed with entity/project. Valid names
946
915
  can be in the following forms:
947
916
  name:version
948
917
  name:alias
949
- digest
950
918
  type: (str, optional) The type of artifact to fetch.
951
919
 
952
920
  Returns:
@@ -955,7 +923,9 @@ class Api:
955
923
  if name is None:
956
924
  raise ValueError("You must specify name= to fetch an artifact.")
957
925
  entity, project, artifact_name = self._parse_artifact_path(name)
958
- artifact = Artifact(self.client, entity, project, artifact_name)
926
+ artifact = wandb.Artifact._from_name(
927
+ entity, project, artifact_name, self.client
928
+ )
959
929
  if type is not None and artifact.type != type:
960
930
  raise ValueError(
961
931
  f"type {type} specified but this artifact is of type {artifact.type}"
@@ -966,6 +936,10 @@ class Api:
966
936
  def job(self, name, path=None):
967
937
  if name is None:
968
938
  raise ValueError("You must specify name= to fetch a job.")
939
+ elif name.count("/") != 2 or ":" not in name:
940
+ raise ValueError(
941
+ "Invalid job specification. A job must be of the form: <entity>/<project>/<job-name>:<alias-or-version>"
942
+ )
969
943
  return Job(self, name, path)
970
944
 
971
945
 
@@ -2036,7 +2010,7 @@ class Run(Attrs):
2036
2010
  root = os.path.abspath(root)
2037
2011
  name = os.path.relpath(path, root)
2038
2012
  with open(os.path.join(root, name), "rb") as f:
2039
- api.push({util.to_forward_slash_path(name): f})
2013
+ api.push({LogicalPath(name): f})
2040
2014
  return Files(self.client, self, [name])[0]
2041
2015
 
2042
2016
  @normalize_exceptions
@@ -2166,10 +2140,10 @@ class Run(Attrs):
2166
2140
  )
2167
2141
  api.set_current_run_id(self.id)
2168
2142
 
2169
- if isinstance(artifact, Artifact):
2143
+ if isinstance(artifact, wandb.Artifact) and not artifact.is_draft():
2170
2144
  api.use_artifact(artifact.id, use_as=use_as or artifact.name)
2171
2145
  return artifact
2172
- elif isinstance(artifact, wandb.Artifact):
2146
+ elif isinstance(artifact, wandb.Artifact) and artifact.is_draft():
2173
2147
  raise ValueError(
2174
2148
  "Only existing artifacts are accepted by this api. "
2175
2149
  "Manually create one with `wandb artifacts put`"
@@ -2194,7 +2168,7 @@ class Run(Attrs):
2194
2168
  )
2195
2169
  api.set_current_run_id(self.id)
2196
2170
 
2197
- if isinstance(artifact, Artifact):
2171
+ if isinstance(artifact, wandb.Artifact) and not artifact.is_draft():
2198
2172
  artifact_collection_name = artifact.name.split(":")[0]
2199
2173
  api.create_artifact(
2200
2174
  artifact.type,
@@ -2203,7 +2177,7 @@ class Run(Attrs):
2203
2177
  aliases=aliases,
2204
2178
  )
2205
2179
  return artifact
2206
- elif isinstance(artifact, wandb.Artifact):
2180
+ elif isinstance(artifact, wandb.Artifact) and artifact.is_draft():
2207
2181
  raise ValueError(
2208
2182
  "Only existing artifacts are accepted by this api. "
2209
2183
  "Manually create one with `wandb artifacts put`"
@@ -3812,72 +3786,72 @@ class ProjectArtifactCollections(Paginator):
3812
3786
 
3813
3787
 
3814
3788
  class RunArtifacts(Paginator):
3815
- OUTPUT_QUERY = gql(
3816
- """
3817
- query RunOutputArtifacts(
3818
- $entity: String!, $project: String!, $runName: String!, $cursor: String, $perPage: Int,
3819
- ) {
3820
- project(name: $project, entityName: $entity) {
3821
- run(name: $runName) {
3822
- outputArtifacts(after: $cursor, first: $perPage) {
3823
- totalCount
3824
- edges {
3825
- node {
3826
- ...ArtifactFragment
3789
+ def __init__(
3790
+ self, client: Client, run: "Run", mode="logged", per_page: Optional[int] = 50
3791
+ ):
3792
+ output_query = gql(
3793
+ """
3794
+ query RunOutputArtifacts(
3795
+ $entity: String!, $project: String!, $runName: String!, $cursor: String, $perPage: Int,
3796
+ ) {
3797
+ project(name: $project, entityName: $entity) {
3798
+ run(name: $runName) {
3799
+ outputArtifacts(after: $cursor, first: $perPage) {
3800
+ totalCount
3801
+ edges {
3802
+ node {
3803
+ ...ArtifactFragment
3804
+ }
3805
+ cursor
3806
+ }
3807
+ pageInfo {
3808
+ endCursor
3809
+ hasNextPage
3827
3810
  }
3828
- cursor
3829
- }
3830
- pageInfo {
3831
- endCursor
3832
- hasNextPage
3833
3811
  }
3834
3812
  }
3835
3813
  }
3836
3814
  }
3837
- }
3838
- %s
3839
- """
3840
- % ARTIFACT_FRAGMENT
3841
- )
3815
+ %s
3816
+ """
3817
+ % wandb.Artifact._GQL_FRAGMENT
3818
+ )
3842
3819
 
3843
- INPUT_QUERY = gql(
3844
- """
3845
- query RunInputArtifacts(
3846
- $entity: String!, $project: String!, $runName: String!, $cursor: String, $perPage: Int,
3847
- ) {
3848
- project(name: $project, entityName: $entity) {
3849
- run(name: $runName) {
3850
- inputArtifacts(after: $cursor, first: $perPage) {
3851
- totalCount
3852
- edges {
3853
- node {
3854
- ...ArtifactFragment
3820
+ input_query = gql(
3821
+ """
3822
+ query RunInputArtifacts(
3823
+ $entity: String!, $project: String!, $runName: String!, $cursor: String, $perPage: Int,
3824
+ ) {
3825
+ project(name: $project, entityName: $entity) {
3826
+ run(name: $runName) {
3827
+ inputArtifacts(after: $cursor, first: $perPage) {
3828
+ totalCount
3829
+ edges {
3830
+ node {
3831
+ ...ArtifactFragment
3832
+ }
3833
+ cursor
3834
+ }
3835
+ pageInfo {
3836
+ endCursor
3837
+ hasNextPage
3855
3838
  }
3856
- cursor
3857
- }
3858
- pageInfo {
3859
- endCursor
3860
- hasNextPage
3861
3839
  }
3862
3840
  }
3863
3841
  }
3864
3842
  }
3865
- }
3866
- %s
3867
- """
3868
- % ARTIFACT_FRAGMENT
3869
- )
3843
+ %s
3844
+ """
3845
+ % wandb.Artifact._GQL_FRAGMENT
3846
+ )
3870
3847
 
3871
- def __init__(
3872
- self, client: Client, run: "Run", mode="logged", per_page: Optional[int] = 50
3873
- ):
3874
3848
  self.run = run
3875
3849
  if mode == "logged":
3876
3850
  self.run_key = "outputArtifacts"
3877
- self.QUERY = self.OUTPUT_QUERY
3851
+ self.QUERY = output_query
3878
3852
  elif mode == "used":
3879
3853
  self.run_key = "inputArtifacts"
3880
- self.QUERY = self.INPUT_QUERY
3854
+ self.QUERY = input_query
3881
3855
  else:
3882
3856
  raise ValueError("mode must be logged or used")
3883
3857
 
@@ -3916,14 +3890,14 @@ class RunArtifacts(Paginator):
3916
3890
 
3917
3891
  def convert_objects(self):
3918
3892
  return [
3919
- Artifact(
3920
- self.client,
3893
+ wandb.Artifact._from_attrs(
3921
3894
  self.run.entity,
3922
3895
  self.run.project,
3923
3896
  "{}:v{}".format(
3924
3897
  r["node"]["artifactSequence"]["name"], r["node"]["versionIndex"]
3925
3898
  ),
3926
3899
  r["node"],
3900
+ self.client,
3927
3901
  )
3928
3902
  for r in self.last_response["project"]["run"][self.run_key]["edges"]
3929
3903
  ]
@@ -4109,1293 +4083,154 @@ class ArtifactCollection:
4109
4083
  return f"<ArtifactCollection {self.name} ({self.type})>"
4110
4084
 
4111
4085
 
4112
- class _DownloadedArtifactEntry(artifacts.ArtifactManifestEntry):
4086
+ class ArtifactVersions(Paginator):
4087
+ """An iterable collection of artifact versions associated with a project and optional filter.
4088
+
4089
+ This is generally used indirectly via the `Api`.artifact_versions method.
4090
+ """
4091
+
4113
4092
  def __init__(
4114
4093
  self,
4115
- entry: "artifacts.ArtifactManifestEntry",
4116
- parent_artifact: "Artifact",
4094
+ client: Client,
4095
+ entity: str,
4096
+ project: str,
4097
+ collection_name: str,
4098
+ type: str,
4099
+ filters: Optional[Mapping[str, Any]] = None,
4100
+ order: Optional[str] = None,
4101
+ per_page: int = 50,
4117
4102
  ):
4118
- super().__init__(
4119
- path=entry.path,
4120
- digest=entry.digest,
4121
- ref=entry.ref,
4122
- birth_artifact_id=entry.birth_artifact_id,
4123
- size=entry.size,
4124
- extra=entry.extra,
4125
- local_path=entry.local_path,
4103
+ self.entity = entity
4104
+ self.collection_name = collection_name
4105
+ self.type = type
4106
+ self.project = project
4107
+ self.filters = {"state": "COMMITTED"} if filters is None else filters
4108
+ self.order = order
4109
+ variables = {
4110
+ "project": self.project,
4111
+ "entity": self.entity,
4112
+ "order": self.order,
4113
+ "type": self.type,
4114
+ "collection": self.collection_name,
4115
+ "filters": json.dumps(self.filters),
4116
+ }
4117
+ self.QUERY = gql(
4118
+ """
4119
+ query Artifacts($project: String!, $entity: String!, $type: String!, $collection: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {{
4120
+ project(name: $project, entityName: $entity) {{
4121
+ artifactType(name: $type) {{
4122
+ artifactCollection: {}(name: $collection) {{
4123
+ name
4124
+ artifacts(filters: $filters, after: $cursor, first: $perPage, order: $order) {{
4125
+ totalCount
4126
+ edges {{
4127
+ node {{
4128
+ ...ArtifactFragment
4129
+ }}
4130
+ version
4131
+ cursor
4132
+ }}
4133
+ pageInfo {{
4134
+ endCursor
4135
+ hasNextPage
4136
+ }}
4137
+ }}
4138
+ }}
4139
+ }}
4140
+ }}
4141
+ }}
4142
+ {}
4143
+ """.format(
4144
+ artifact_collection_edge_name(
4145
+ server_supports_artifact_collections_gql_edges(client)
4146
+ ),
4147
+ wandb.Artifact._GQL_FRAGMENT,
4148
+ )
4126
4149
  )
4127
- self._parent_artifact = parent_artifact
4150
+ super().__init__(client, variables, per_page)
4128
4151
 
4129
4152
  @property
4130
- def name(self):
4131
- # TODO(hugh): add telemetry to see if anyone is still using this.
4132
- wandb.termwarn("ArtifactManifestEntry.name is deprecated, use .path instead")
4133
- return self.path
4134
-
4135
- def parent_artifact(self):
4136
- return self._parent_artifact
4137
-
4138
- def copy(self, cache_path, target_path):
4139
- raise NotImplementedError()
4140
-
4141
- def download(self, root=None):
4142
- root = root or self._parent_artifact._default_root()
4143
- dest_path = os.path.join(root, self.path)
4144
-
4145
- self._parent_artifact._add_download_root(root)
4146
- manifest = self._parent_artifact._load_manifest()
4147
-
4148
- # Skip checking the cache (and possibly downloading) if the file already exists
4149
- # and has the digest we're expecting.
4150
- entry = manifest.entries[self.path]
4151
- if os.path.exists(dest_path) and entry.digest == md5_file_b64(dest_path):
4152
- return dest_path
4153
+ def length(self):
4154
+ if self.last_response:
4155
+ return self.last_response["project"]["artifactType"]["artifactCollection"][
4156
+ "artifacts"
4157
+ ]["totalCount"]
4158
+ else:
4159
+ return None
4153
4160
 
4154
- if self.ref is not None:
4155
- cache_path = manifest.storage_policy.load_reference(entry, local=True)
4161
+ @property
4162
+ def more(self):
4163
+ if self.last_response:
4164
+ return self.last_response["project"]["artifactType"]["artifactCollection"][
4165
+ "artifacts"
4166
+ ]["pageInfo"]["hasNextPage"]
4156
4167
  else:
4157
- cache_path = manifest.storage_policy.load_file(self._parent_artifact, entry)
4168
+ return True
4158
4169
 
4159
- return filesystem.copy_or_overwrite_changed(cache_path, dest_path)
4170
+ @property
4171
+ def cursor(self):
4172
+ if self.last_response:
4173
+ return self.last_response["project"]["artifactType"]["artifactCollection"][
4174
+ "artifacts"
4175
+ ]["edges"][-1]["cursor"]
4176
+ else:
4177
+ return None
4160
4178
 
4161
- def ref_target(self):
4162
- manifest = self._parent_artifact._load_manifest()
4163
- if self.ref is not None:
4164
- return manifest.storage_policy.load_reference(
4165
- manifest.entries[self.path],
4166
- local=False,
4179
+ def convert_objects(self):
4180
+ if self.last_response["project"]["artifactType"]["artifactCollection"] is None:
4181
+ return []
4182
+ return [
4183
+ wandb.Artifact._from_attrs(
4184
+ self.entity,
4185
+ self.project,
4186
+ self.collection_name + ":" + a["version"],
4187
+ a["node"],
4188
+ self.client,
4167
4189
  )
4168
- raise ValueError("Only reference entries support ref_target().")
4169
-
4170
- def ref_url(self):
4171
- return (
4172
- "wandb-artifact://"
4173
- + b64_to_hex_id(self._parent_artifact.id)
4174
- + "/"
4175
- + self.path
4176
- )
4177
-
4178
-
4179
- class _ArtifactDownloadLogger:
4180
- def __init__(
4181
- self,
4182
- nfiles: int,
4183
- clock_for_testing: Callable[[], float] = time.monotonic,
4184
- termlog_for_testing=termlog,
4185
- ) -> None:
4186
- self._nfiles = nfiles
4187
- self._clock = clock_for_testing
4188
- self._termlog = termlog_for_testing
4189
-
4190
- self._n_files_downloaded = 0
4191
- self._spinner_index = 0
4192
- self._last_log_time = self._clock()
4193
- self._lock = multiprocessing.dummy.Lock()
4194
-
4195
- def notify_downloaded(self) -> None:
4196
- with self._lock:
4197
- self._n_files_downloaded += 1
4198
- if self._n_files_downloaded == self._nfiles:
4199
- self._termlog(
4200
- f" {self._nfiles} of {self._nfiles} files downloaded. ",
4201
- # ^ trailing spaces to wipe out ellipsis from previous logs
4202
- newline=True,
4203
- )
4204
- self._last_log_time = self._clock()
4205
- elif self._clock() - self._last_log_time > 0.1:
4206
- self._spinner_index += 1
4207
- spinner = r"-\|/"[self._spinner_index % 4]
4208
- self._termlog(
4209
- f"{spinner} {self._n_files_downloaded} of {self._nfiles} files downloaded...\r",
4210
- newline=False,
4211
- )
4212
- self._last_log_time = self._clock()
4213
-
4214
-
4215
- class Artifact(artifacts.Artifact):
4216
- """A wandb Artifact.
4217
-
4218
- An artifact that has been logged, including all its attributes, links to the runs
4219
- that use it, and a link to the run that logged it.
4220
-
4221
- Examples:
4222
- Basic usage
4223
- ```
4224
- api = wandb.Api()
4225
- artifact = api.artifact('project/artifact:alias')
4226
-
4227
- # Get information about the artifact...
4228
- artifact.digest
4229
- artifact.aliases
4230
- ```
4231
-
4232
- Updating an artifact
4233
- ```
4234
- artifact = api.artifact('project/artifact:alias')
4235
-
4236
- # Update the description
4237
- artifact.description = 'My new description'
4238
-
4239
- # Selectively update metadata keys
4240
- artifact.metadata["oldKey"] = "new value"
4241
-
4242
- # Replace the metadata entirely
4243
- artifact.metadata = {"newKey": "new value"}
4244
-
4245
- # Add an alias
4246
- artifact.aliases.append('best')
4247
-
4248
- # Remove an alias
4249
- artifact.aliases.remove('latest')
4250
-
4251
- # Completely replace the aliases
4252
- artifact.aliases = ['replaced']
4253
-
4254
- # Persist all artifact modifications
4255
- artifact.save()
4256
- ```
4257
-
4258
- Artifact graph traversal
4259
- ```
4260
- artifact = api.artifact('project/artifact:alias')
4261
-
4262
- # Walk up and down the graph from an artifact:
4263
- producer_run = artifact.logged_by()
4264
- consumer_runs = artifact.used_by()
4265
-
4266
- # Walk up and down the graph from a run:
4267
- logged_artifacts = run.logged_artifacts()
4268
- used_artifacts = run.used_artifacts()
4269
- ```
4190
+ for a in self.last_response["project"]["artifactType"][
4191
+ "artifactCollection"
4192
+ ]["artifacts"]["edges"]
4193
+ ]
4270
4194
 
4271
- Deleting an artifact
4272
- ```
4273
- artifact = api.artifact('project/artifact:alias')
4274
- artifact.delete()
4275
- ```
4276
- """
4277
4195
 
4196
+ class ArtifactFiles(Paginator):
4278
4197
  QUERY = gql(
4279
4198
  """
4280
- query ArtifactWithCurrentManifest(
4281
- $id: ID!,
4199
+ query ArtifactFiles(
4200
+ $entityName: String!,
4201
+ $projectName: String!,
4202
+ $artifactTypeName: String!,
4203
+ $artifactName: String!
4204
+ $fileNames: [String!],
4205
+ $fileCursor: String,
4206
+ $fileLimit: Int = 50
4282
4207
  ) {
4283
- artifact(id: $id) {
4284
- currentManifest {
4285
- id
4286
- file {
4287
- id
4288
- directUrl
4208
+ project(name: $projectName, entityName: $entityName) {
4209
+ artifactType(name: $artifactTypeName) {
4210
+ artifact(name: $artifactName) {
4211
+ ...ArtifactFilesFragment
4289
4212
  }
4290
4213
  }
4291
- ...ArtifactFragment
4292
4214
  }
4293
4215
  }
4294
4216
  %s
4295
4217
  """
4296
- % ARTIFACT_FRAGMENT
4297
- )
4298
-
4299
- @classmethod
4300
- def from_id(cls, artifact_id: str, client: Client):
4301
- artifact = artifacts.get_artifacts_cache().get_artifact(artifact_id)
4302
- if artifact is not None:
4303
- return artifact
4304
- response: Mapping[str, Any] = client.execute(
4305
- Artifact.QUERY,
4306
- variable_values={"id": artifact_id},
4307
- )
4308
-
4309
- name = None
4310
- if response.get("artifact") is not None:
4311
- if response["artifact"].get("aliases") is not None:
4312
- aliases = response["artifact"]["aliases"]
4313
- name = ":".join(
4314
- [aliases[0]["artifactCollectionName"], aliases[0]["alias"]]
4315
- )
4316
- if len(aliases) > 1:
4317
- for alias in aliases:
4318
- if alias["alias"] != "latest":
4319
- name = ":".join(
4320
- [alias["artifactCollectionName"], alias["alias"]]
4321
- )
4322
- break
4323
-
4324
- p = response.get("artifact", {}).get("artifactType", {}).get("project", {})
4325
- project = p.get("name") # defaults to None
4326
- entity = p.get("entity", {}).get("name")
4327
-
4328
- artifact = cls(
4329
- client=client,
4330
- entity=entity,
4331
- project=project,
4332
- name=name,
4333
- attrs=response["artifact"],
4334
- )
4335
- index_file_url = response["artifact"]["currentManifest"]["file"][
4336
- "directUrl"
4337
- ]
4338
- with requests.get(index_file_url) as req:
4339
- req.raise_for_status()
4340
- artifact._manifest = artifacts.ArtifactManifest.from_manifest_json(
4341
- json.loads(util.ensure_text(req.content))
4342
- )
4343
-
4344
- artifact._load_dependent_manifests()
4345
-
4346
- return artifact
4347
-
4348
- def __init__(self, client, entity, project, name, attrs=None):
4349
- self.client = client
4350
- self._entity = entity
4351
- self._project = project
4352
- self._artifact_name = name
4353
- self._artifact_collection_name = name.split(":")[0]
4354
- self._attrs = attrs
4355
- if self._attrs is None:
4356
- self._load()
4357
-
4358
- # The entity and project above are taken from the passed-in artifact version path
4359
- # so if the user is pulling an artifact version from an artifact portfolio, the entity/project
4360
- # of that portfolio may be different than the birth entity/project of the artifact version.
4361
- self._birth_project = (
4362
- self._attrs.get("artifactType", {}).get("project", {}).get("name")
4363
- )
4364
- self._birth_entity = (
4365
- self._attrs.get("artifactType", {})
4366
- .get("project", {})
4367
- .get("entity", {})
4368
- .get("name")
4369
- )
4370
- self._metadata = json.loads(self._attrs.get("metadata") or "{}")
4371
- self._description = self._attrs.get("description", None)
4372
- self._sequence_name = self._attrs["artifactSequence"]["name"]
4373
- self._sequence_version_index = self._attrs.get("versionIndex", None)
4374
- # We will only show aliases under the Collection this artifact version is fetched from
4375
- # _aliases will be a mutable copy on which the user can append or remove aliases
4376
- self._aliases = [
4377
- a["alias"]
4378
- for a in self._attrs["aliases"]
4379
- if not re.match(r"^v\d+$", a["alias"])
4380
- and a["artifactCollectionName"] == self._artifact_collection_name
4381
- ]
4382
- self._frozen_aliases = [a for a in self._aliases]
4383
- self._manifest = None
4384
- self._is_downloaded = False
4385
- self._dependent_artifacts = []
4386
- self._download_roots = set()
4387
- artifacts.get_artifacts_cache().store_artifact(self)
4388
-
4389
- @property
4390
- def id(self):
4391
- return self._attrs["id"]
4392
-
4393
- @property
4394
- def file_count(self):
4395
- return self._attrs["fileCount"]
4396
-
4397
- @property
4398
- def source_version(self):
4399
- """The artifact's version index under its parent artifact collection.
4400
-
4401
- A string with the format "v{number}".
4402
- """
4403
- return f"v{self._sequence_version_index}"
4404
-
4405
- @property
4406
- def version(self):
4407
- """The artifact's version index under the given artifact collection.
4408
-
4409
- A string with the format "v{number}".
4410
- """
4411
- for a in self._attrs["aliases"]:
4412
- if a[
4413
- "artifactCollectionName"
4414
- ] == self._artifact_collection_name and util.alias_is_version_index(
4415
- a["alias"]
4416
- ):
4417
- return a["alias"]
4418
- return None
4419
-
4420
- @property
4421
- def entity(self):
4422
- return self._entity
4423
-
4424
- @property
4425
- def project(self):
4426
- return self._project
4427
-
4428
- @property
4429
- def metadata(self):
4430
- return self._metadata
4431
-
4432
- @metadata.setter
4433
- def metadata(self, metadata):
4434
- self._metadata = metadata
4435
-
4436
- @property
4437
- def manifest(self):
4438
- return self._load_manifest()
4439
-
4440
- @property
4441
- def digest(self):
4442
- return self._attrs["digest"]
4443
-
4444
- @property
4445
- def state(self):
4446
- return self._attrs["state"]
4447
-
4448
- @property
4449
- def size(self):
4450
- return self._attrs["size"]
4451
-
4452
- @property
4453
- def created_at(self):
4454
- """The time at which the artifact was created."""
4455
- return self._attrs["createdAt"]
4456
-
4457
- @property
4458
- def updated_at(self):
4459
- """The time at which the artifact was last updated."""
4460
- return self._attrs["updatedAt"] or self._attrs["createdAt"]
4461
-
4462
- @property
4463
- def description(self):
4464
- return self._description
4465
-
4466
- @description.setter
4467
- def description(self, desc):
4468
- self._description = desc
4469
-
4470
- @property
4471
- def type(self):
4472
- return self._attrs["artifactType"]["name"]
4473
-
4474
- @property
4475
- def commit_hash(self):
4476
- return self._attrs.get("commitHash", "")
4477
-
4478
- @property
4479
- def name(self):
4480
- if self._sequence_version_index is None:
4481
- return self.digest
4482
- return f"{self._sequence_name}:v{self._sequence_version_index}"
4483
-
4484
- @property
4485
- def aliases(self):
4486
- """The aliases associated with this artifact.
4487
-
4488
- Returns:
4489
- List[str]: The aliases associated with this artifact.
4490
-
4491
- """
4492
- return self._aliases
4493
-
4494
- @aliases.setter
4495
- def aliases(self, aliases):
4496
- for alias in aliases:
4497
- if any(char in alias for char in ["/", ":"]):
4498
- raise ValueError(
4499
- 'Invalid alias "%s", slashes and colons are disallowed' % alias
4500
- )
4501
- self._aliases = aliases
4502
-
4503
- @staticmethod
4504
- def expected_type(client, name, entity_name, project_name):
4505
- """Returns the expected type for a given artifact name and project."""
4506
- query = gql(
4507
- """
4508
- query ArtifactType(
4509
- $entityName: String,
4510
- $projectName: String,
4511
- $name: String!
4512
- ) {
4513
- project(name: $projectName, entityName: $entityName) {
4514
- artifact(name: $name) {
4515
- artifactType {
4516
- name
4517
- }
4518
- }
4519
- }
4520
- }
4521
- """
4522
- )
4523
- if ":" not in name:
4524
- name += ":latest"
4525
-
4526
- response = client.execute(
4527
- query,
4528
- variable_values={
4529
- "entityName": entity_name,
4530
- "projectName": project_name,
4531
- "name": name,
4532
- },
4533
- )
4534
-
4535
- project = response.get("project")
4536
- if project is not None:
4537
- artifact = project.get("artifact")
4538
- if artifact is not None:
4539
- artifact_type = artifact.get("artifactType")
4540
- if artifact_type is not None:
4541
- return artifact_type.get("name")
4542
-
4543
- return None
4544
-
4545
- @property
4546
- def _use_as(self):
4547
- return self._attrs.get("_use_as")
4548
-
4549
- @_use_as.setter
4550
- def _use_as(self, use_as):
4551
- self._attrs["_use_as"] = use_as
4552
- return use_as
4553
-
4554
- @normalize_exceptions
4555
- def link(self, target_path: str, aliases=None):
4556
- if ":" in target_path:
4557
- raise ValueError(
4558
- f"target_path {target_path} cannot contain `:` because it is not an alias."
4559
- )
4560
-
4561
- portfolio, project, entity = util._parse_entity_project_item(target_path)
4562
- aliases = util._resolve_aliases(aliases)
4563
-
4564
- EmptyRunProps = namedtuple("Empty", "entity project")
4565
- r = wandb.run if wandb.run else EmptyRunProps(entity=None, project=None)
4566
- entity = entity or r.entity or self.entity
4567
- project = project or r.project or self.project
4568
-
4569
- mutation = gql(
4570
- """
4571
- mutation LinkArtifact($artifactID: ID!, $artifactPortfolioName: String!, $entityName: String!, $projectName: String!, $aliases: [ArtifactAliasInput!]) {
4572
- linkArtifact(input: {artifactID: $artifactID, artifactPortfolioName: $artifactPortfolioName,
4573
- entityName: $entityName,
4574
- projectName: $projectName,
4575
- aliases: $aliases
4576
- }) {
4577
- versionIndex
4578
- }
4579
- }
4580
- """
4581
- )
4582
- self.client.execute(
4583
- mutation,
4584
- variable_values={
4585
- "artifactID": self.id,
4586
- "artifactPortfolioName": portfolio,
4587
- "entityName": entity,
4588
- "projectName": project,
4589
- "aliases": [
4590
- {"alias": alias, "artifactCollectionName": portfolio}
4591
- for alias in aliases
4592
- ],
4593
- },
4594
- )
4595
- return True
4596
-
4597
- @normalize_exceptions
4598
- def delete(self, delete_aliases=False):
4599
- """Delete an artifact and its files.
4600
-
4601
- Examples:
4602
- Delete all the "model" artifacts a run has logged:
4603
- ```
4604
- runs = api.runs(path="my_entity/my_project")
4605
- for run in runs:
4606
- for artifact in run.logged_artifacts():
4607
- if artifact.type == "model":
4608
- artifact.delete(delete_aliases=True)
4609
- ```
4610
-
4611
- Arguments:
4612
- delete_aliases: (bool) If true, deletes all aliases associated with the artifact.
4613
- Otherwise, this raises an exception if the artifact has existing aliases.
4614
- """
4615
- mutation = gql(
4616
- """
4617
- mutation DeleteArtifact($artifactID: ID!, $deleteAliases: Boolean) {
4618
- deleteArtifact(input: {
4619
- artifactID: $artifactID
4620
- deleteAliases: $deleteAliases
4621
- }) {
4622
- artifact {
4623
- id
4624
- }
4625
- }
4626
- }
4627
- """
4628
- )
4629
- self.client.execute(
4630
- mutation,
4631
- variable_values={
4632
- "artifactID": self.id,
4633
- "deleteAliases": delete_aliases,
4634
- },
4635
- )
4636
- return True
4637
-
4638
- def new_file(self, name, mode=None):
4639
- raise ValueError("Cannot add files to an artifact once it has been saved")
4640
-
4641
- def add_file(self, local_path, name=None, is_tmp=False):
4642
- raise ValueError("Cannot add files to an artifact once it has been saved")
4643
-
4644
- def add_dir(self, path, name=None):
4645
- raise ValueError("Cannot add files to an artifact once it has been saved")
4646
-
4647
- def add_reference(self, uri, name=None, checksum=True, max_objects=None):
4648
- raise ValueError("Cannot add files to an artifact once it has been saved")
4649
-
4650
- def add(self, obj, name):
4651
- raise ValueError("Cannot add files to an artifact once it has been saved")
4652
-
4653
- def _add_download_root(self, dir_path):
4654
- """Make `dir_path` a root directory for this artifact."""
4655
- self._download_roots.add(os.path.abspath(dir_path))
4656
-
4657
- def _is_download_root(self, dir_path):
4658
- """Determine if `dir_path` is a root directory for this artifact."""
4659
- return dir_path in self._download_roots
4660
-
4661
- def _local_path_to_name(self, file_path):
4662
- """Convert a local file path to a path entry in the artifact."""
4663
- abs_file_path = os.path.abspath(file_path)
4664
- abs_file_parts = abs_file_path.split(os.sep)
4665
- for i in range(len(abs_file_parts) + 1):
4666
- if self._is_download_root(os.path.join(os.sep, *abs_file_parts[:i])):
4667
- return os.path.join(*abs_file_parts[i:])
4668
- return None
4669
-
4670
- def _get_obj_entry(self, name):
4671
- """Return an object entry by name, handling any type suffixes.
4672
-
4673
- When objects are added with `.add(obj, name)`, the name is typically changed to
4674
- include the suffix of the object type when serializing to JSON. So we need to be
4675
- able to resolve a name, without tasking the user with appending .THING.json.
4676
- This method returns an entry if it exists by a suffixed name.
4677
-
4678
- Args:
4679
- name: (str) name used when adding
4680
- """
4681
- self._load_manifest()
4682
-
4683
- type_mapping = WBValue.type_mapping()
4684
- for artifact_type_str in type_mapping:
4685
- wb_class = type_mapping[artifact_type_str]
4686
- wandb_file_name = wb_class.with_suffix(name)
4687
- entry = self._manifest.entries.get(wandb_file_name)
4688
- if entry is not None:
4689
- return entry, wb_class
4690
- return None, None
4691
-
4692
- def get_path(self, name):
4693
- manifest = self._load_manifest()
4694
- entry = manifest.entries.get(name) or self._get_obj_entry(name)[0]
4695
- if entry is None:
4696
- raise KeyError("Path not contained in artifact: %s" % name)
4697
-
4698
- return _DownloadedArtifactEntry(entry, self)
4699
-
4700
- def get(self, name):
4701
- entry, wb_class = self._get_obj_entry(name)
4702
- if entry is not None:
4703
- # If the entry is a reference from another artifact, then get it directly from that artifact
4704
- if self._manifest_entry_is_artifact_reference(entry):
4705
- artifact = self._get_ref_artifact_from_entry(entry)
4706
- return artifact.get(util.uri_from_path(entry.ref))
4707
-
4708
- # Special case for wandb.Table. This is intended to be a short term optimization.
4709
- # Since tables are likely to download many other assets in artifact(s), we eagerly download
4710
- # the artifact using the parallelized `artifact.download`. In the future, we should refactor
4711
- # the deserialization pattern such that this special case is not needed.
4712
- if wb_class == wandb.Table:
4713
- self.download(recursive=True)
4714
-
4715
- # Get the ArtifactManifestEntry
4716
- item = self.get_path(entry.path)
4717
- item_path = item.download()
4718
-
4719
- # Load the object from the JSON blob
4720
- result = None
4721
- json_obj = {}
4722
- with open(item_path) as file:
4723
- json_obj = json.load(file)
4724
- result = wb_class.from_json(json_obj, self)
4725
- result._set_artifact_source(self, name)
4726
- return result
4727
-
4728
- def download(self, root=None, recursive=False):
4729
- dirpath = root or self._default_root()
4730
- self._add_download_root(dirpath)
4731
- manifest = self._load_manifest()
4732
- nfiles = len(manifest.entries)
4733
- size = sum(e.size for e in manifest.entries.values())
4734
- log = False
4735
- if nfiles > 5000 or size > 50 * 1024 * 1024:
4736
- log = True
4737
- termlog(
4738
- "Downloading large artifact {}, {:.2f}MB. {} files... ".format(
4739
- self._artifact_name, size / (1024 * 1024), nfiles
4740
- ),
4741
- )
4742
- start_time = datetime.datetime.now()
4743
-
4744
- # Force all the files to download into the same directory.
4745
- # Download in parallel
4746
- import multiprocessing.dummy # this uses threads
4747
-
4748
- download_logger = _ArtifactDownloadLogger(nfiles=nfiles)
4749
-
4750
- pool = multiprocessing.dummy.Pool(32)
4751
- pool.map(
4752
- partial(self._download_file, root=dirpath, download_logger=download_logger),
4753
- manifest.entries,
4754
- )
4755
- if recursive:
4756
- pool.map(lambda artifact: artifact.download(), self._dependent_artifacts)
4757
- pool.close()
4758
- pool.join()
4759
-
4760
- self._is_downloaded = True
4761
-
4762
- if log:
4763
- now = datetime.datetime.now()
4764
- delta = abs((now - start_time).total_seconds())
4765
- hours = int(delta // 3600)
4766
- minutes = int((delta - hours * 3600) // 60)
4767
- seconds = delta - hours * 3600 - minutes * 60
4768
- termlog(
4769
- f"Done. {hours}:{minutes}:{seconds:.1f}",
4770
- prefix=False,
4771
- )
4772
- return dirpath
4773
-
4774
- def checkout(self, root=None):
4775
- dirpath = root or self._default_root(include_version=False)
4776
-
4777
- for root, _, files in os.walk(dirpath):
4778
- for file in files:
4779
- full_path = os.path.join(root, file)
4780
- artifact_path = util.to_forward_slash_path(
4781
- os.path.relpath(full_path, start=dirpath)
4782
- )
4783
- try:
4784
- self.get_path(artifact_path)
4785
- except KeyError:
4786
- # File is not part of the artifact, remove it.
4787
- os.remove(full_path)
4788
-
4789
- return self.download(root=dirpath)
4790
-
4791
- def verify(self, root=None):
4792
- dirpath = root or self._default_root()
4793
- manifest = self._load_manifest()
4794
- ref_count = 0
4795
-
4796
- for root, _, files in os.walk(dirpath):
4797
- for file in files:
4798
- full_path = os.path.join(root, file)
4799
- artifact_path = util.to_forward_slash_path(
4800
- os.path.relpath(full_path, start=dirpath)
4801
- )
4802
- try:
4803
- self.get_path(artifact_path)
4804
- except KeyError:
4805
- raise ValueError(
4806
- "Found file {} which is not a member of artifact {}".format(
4807
- full_path, self.name
4808
- )
4809
- )
4810
-
4811
- for entry in manifest.entries.values():
4812
- if entry.ref is None:
4813
- if md5_file_b64(os.path.join(dirpath, entry.path)) != entry.digest:
4814
- raise ValueError("Digest mismatch for file: %s" % entry.path)
4815
- else:
4816
- ref_count += 1
4817
- if ref_count > 0:
4818
- print("Warning: skipped verification of %s refs" % ref_count)
4819
-
4820
- def file(self, root=None):
4821
- """Download a single file artifact to dir specified by the root.
4822
-
4823
- Arguments:
4824
- root: (str, optional) The root directory in which to place the file. Defaults to './artifacts/self.name/'.
4825
-
4826
- Returns:
4827
- (str): The full path of the downloaded file.
4828
- """
4829
- if root is None:
4830
- root = os.path.join(".", "artifacts", self.name)
4831
-
4832
- manifest = self._load_manifest()
4833
- nfiles = len(manifest.entries)
4834
- if nfiles > 1:
4835
- raise ValueError(
4836
- "This artifact contains more than one file, call `.download()` to get all files or call "
4837
- '.get_path("filename").download()'
4838
- )
4839
-
4840
- return self._download_file(list(manifest.entries)[0], root=root)
4841
-
4842
- def _download_file(
4843
- self, name, root, download_logger: Optional[_ArtifactDownloadLogger] = None
4844
- ):
4845
- # download file into cache and copy to target dir
4846
- downloaded_path = self.get_path(name).download(root)
4847
- if download_logger is not None:
4848
- download_logger.notify_downloaded()
4849
- return downloaded_path
4850
-
4851
- def _default_root(self, include_version=True):
4852
- name = self.name if include_version else self._sequence_name
4853
- root = os.path.join(get_artifact_dir(), name)
4854
- if platform.system() == "Windows":
4855
- head, tail = os.path.splitdrive(root)
4856
- root = head + tail.replace(":", "-")
4857
- return root
4858
-
4859
- def json_encode(self):
4860
- return util.artifact_to_json(self)
4861
-
4862
- @normalize_exceptions
4863
- def save(self):
4864
- """Persists artifact changes to the wandb backend."""
4865
- mutation = gql(
4866
- """
4867
- mutation updateArtifact(
4868
- $artifactID: ID!,
4869
- $description: String,
4870
- $metadata: JSONString,
4871
- $aliases: [ArtifactAliasInput!]
4872
- ) {
4873
- updateArtifact(input: {
4874
- artifactID: $artifactID,
4875
- description: $description,
4876
- metadata: $metadata,
4877
- aliases: $aliases
4878
- }) {
4879
- artifact {
4880
- id
4881
- }
4882
- }
4883
- }
4884
- """
4885
- )
4886
- introspect_query = gql(
4887
- """
4888
- query ProbeServerAddAliasesInput {
4889
- AddAliasesInputInfoType: __type(name: "AddAliasesInput") {
4890
- name
4891
- inputFields {
4892
- name
4893
- }
4894
- }
4895
- }
4896
- """
4897
- )
4898
- res = self.client.execute(introspect_query)
4899
- valid = res.get("AddAliasesInputInfoType")
4900
- aliases = None
4901
- if not valid:
4902
- # If valid, wandb backend version >= 0.13.0.
4903
- # This means we can safely remove aliases from this updateArtifact request since we'll be calling
4904
- # the alias endpoints below in _save_alias_changes.
4905
- # If not valid, wandb backend version < 0.13.0. This requires aliases to be sent in updateArtifact.
4906
- aliases = [
4907
- {
4908
- "artifactCollectionName": self._artifact_collection_name,
4909
- "alias": alias,
4910
- }
4911
- for alias in self._aliases
4912
- ]
4913
-
4914
- self.client.execute(
4915
- mutation,
4916
- variable_values={
4917
- "artifactID": self.id,
4918
- "description": self.description,
4919
- "metadata": util.json_dumps_safer(self.metadata),
4920
- "aliases": aliases,
4921
- },
4922
- )
4923
- # Save locally modified aliases
4924
- self._save_alias_changes()
4925
- return True
4926
-
4927
- def wait(self):
4928
- return self
4929
-
4930
- @normalize_exceptions
4931
- def _save_alias_changes(self):
4932
- """Persist alias changes on this artifact to the wandb backend.
4933
-
4934
- Called by artifact.save().
4935
- """
4936
- aliases_to_add = set(self._aliases) - set(self._frozen_aliases)
4937
- aliases_to_remove = set(self._frozen_aliases) - set(self._aliases)
4938
-
4939
- # Introspect
4940
- introspect_query = gql(
4941
- """
4942
- query ProbeServerAddAliasesInput {
4943
- AddAliasesInputInfoType: __type(name: "AddAliasesInput") {
4944
- name
4945
- inputFields {
4946
- name
4947
- }
4948
- }
4949
- }
4950
- """
4951
- )
4952
- res = self.client.execute(introspect_query)
4953
- valid = res.get("AddAliasesInputInfoType")
4954
- if not valid:
4955
- return
4956
-
4957
- if len(aliases_to_add) > 0:
4958
- add_mutation = gql(
4959
- """
4960
- mutation addAliases(
4961
- $artifactID: ID!,
4962
- $aliases: [ArtifactCollectionAliasInput!]!,
4963
- ) {
4964
- addAliases(
4965
- input: {
4966
- artifactID: $artifactID,
4967
- aliases: $aliases,
4968
- }
4969
- ) {
4970
- success
4971
- }
4972
- }
4973
- """
4974
- )
4975
- self.client.execute(
4976
- add_mutation,
4977
- variable_values={
4978
- "artifactID": self.id,
4979
- "aliases": [
4980
- {
4981
- "artifactCollectionName": self._artifact_collection_name,
4982
- "alias": alias,
4983
- "entityName": self._entity,
4984
- "projectName": self._project,
4985
- }
4986
- for alias in aliases_to_add
4987
- ],
4988
- },
4989
- )
4990
-
4991
- if len(aliases_to_remove) > 0:
4992
- delete_mutation = gql(
4993
- """
4994
- mutation deleteAliases(
4995
- $artifactID: ID!,
4996
- $aliases: [ArtifactCollectionAliasInput!]!,
4997
- ) {
4998
- deleteAliases(
4999
- input: {
5000
- artifactID: $artifactID,
5001
- aliases: $aliases,
5002
- }
5003
- ) {
5004
- success
5005
- }
5006
- }
5007
- """
5008
- )
5009
- self.client.execute(
5010
- delete_mutation,
5011
- variable_values={
5012
- "artifactID": self.id,
5013
- "aliases": [
5014
- {
5015
- "artifactCollectionName": self._artifact_collection_name,
5016
- "alias": alias,
5017
- "entityName": self._entity,
5018
- "projectName": self._project,
5019
- }
5020
- for alias in aliases_to_remove
5021
- ],
5022
- },
5023
- )
5024
-
5025
- # reset local state
5026
- self._frozen_aliases = self._aliases
5027
- return True
5028
-
5029
- # TODO: not yet public, but we probably want something like this.
5030
- def _list(self):
5031
- manifest = self._load_manifest()
5032
- return manifest.entries.keys()
5033
-
5034
- def __repr__(self):
5035
- return f"<Artifact {self.id}>"
5036
-
5037
- def _load(self):
5038
- query = gql(
5039
- """
5040
- query Artifact(
5041
- $entityName: String,
5042
- $projectName: String,
5043
- $name: String!
5044
- ) {
5045
- project(name: $projectName, entityName: $entityName) {
5046
- artifact(name: $name) {
5047
- ...ArtifactFragment
5048
- }
5049
- }
5050
- }
5051
- %s
5052
- """
5053
- % ARTIFACT_FRAGMENT
5054
- )
5055
- response = None
5056
- try:
5057
- response = self.client.execute(
5058
- query,
5059
- variable_values={
5060
- "entityName": self.entity,
5061
- "projectName": self.project,
5062
- "name": self._artifact_name,
5063
- },
5064
- )
5065
- except Exception:
5066
- # we check for this after doing the call, since the backend supports raw digest lookups
5067
- # which don't include ":" and are 32 characters long
5068
- if ":" not in self._artifact_name and len(self._artifact_name) != 32:
5069
- raise ValueError(
5070
- 'Attempted to fetch artifact without alias (e.g. "<artifact_name>:v3" or "<artifact_name>:latest")'
5071
- )
5072
- if (
5073
- response is None
5074
- or response.get("project") is None
5075
- or response["project"].get("artifact") is None
5076
- ):
5077
- raise ValueError(
5078
- f'Project {self.entity}/{self.project} does not contain artifact: "{self._artifact_name}"'
5079
- )
5080
- self._attrs = response["project"]["artifact"]
5081
- return self._attrs
5082
-
5083
- def files(self, names=None, per_page=50):
5084
- """Iterate over all files stored in this artifact.
5085
-
5086
- Arguments:
5087
- names: (list of str, optional) The filename paths relative to the
5088
- root of the artifact you wish to list.
5089
- per_page: (int, default 50) The number of files to return per request
5090
-
5091
- Returns:
5092
- (`ArtifactFiles`): An iterator containing `File` objects
5093
- """
5094
- return ArtifactFiles(self.client, self, names, per_page)
5095
-
5096
- def _load_manifest(self):
5097
- if self._manifest is None:
5098
- query = gql(
5099
- """
5100
- query ArtifactManifest(
5101
- $entityName: String!,
5102
- $projectName: String!,
5103
- $name: String!
5104
- ) {
5105
- project(name: $projectName, entityName: $entityName) {
5106
- artifact(name: $name) {
5107
- currentManifest {
5108
- id
5109
- file {
5110
- id
5111
- directUrl
5112
- }
5113
- }
5114
- }
5115
- }
5116
- }
5117
- """
5118
- )
5119
- response = self.client.execute(
5120
- query,
5121
- variable_values={
5122
- "entityName": self.entity,
5123
- "projectName": self.project,
5124
- "name": self._artifact_name,
5125
- },
5126
- )
5127
-
5128
- index_file_url = response["project"]["artifact"]["currentManifest"]["file"][
5129
- "directUrl"
5130
- ]
5131
- with requests.get(index_file_url) as req:
5132
- req.raise_for_status()
5133
- self._manifest = artifacts.ArtifactManifest.from_manifest_json(
5134
- json.loads(util.ensure_text(req.content))
5135
- )
5136
-
5137
- self._load_dependent_manifests()
5138
-
5139
- return self._manifest
5140
-
5141
- def _load_dependent_manifests(self):
5142
- """Interrogate entries and ensure we have loaded their manifests."""
5143
- # Make sure dependencies are avail
5144
- for entry_key in self._manifest.entries:
5145
- entry = self._manifest.entries[entry_key]
5146
- if self._manifest_entry_is_artifact_reference(entry):
5147
- dep_artifact = self._get_ref_artifact_from_entry(entry)
5148
- if dep_artifact not in self._dependent_artifacts:
5149
- dep_artifact._load_manifest()
5150
- self._dependent_artifacts.append(dep_artifact)
5151
-
5152
- @staticmethod
5153
- def _manifest_entry_is_artifact_reference(entry):
5154
- """Determine if an ArtifactManifestEntry is an artifact reference."""
5155
- return (
5156
- entry.ref is not None
5157
- and urllib.parse.urlparse(entry.ref).scheme == "wandb-artifact"
5158
- )
5159
-
5160
- def _get_ref_artifact_from_entry(self, entry):
5161
- """Helper function returns the referenced artifact from an entry."""
5162
- artifact_id = util.host_from_path(entry.ref)
5163
- return Artifact.from_id(hex_to_b64_id(artifact_id), self.client)
5164
-
5165
- def used_by(self):
5166
- """Retrieve the runs which use this artifact directly.
5167
-
5168
- Returns:
5169
- [Run]: a list of Run objects which use this artifact
5170
- """
5171
- query = gql(
5172
- """
5173
- query ArtifactUsedBy(
5174
- $id: ID!,
5175
- $before: String,
5176
- $after: String,
5177
- $first: Int,
5178
- $last: Int
5179
- ) {
5180
- artifact(id: $id) {
5181
- usedBy(before: $before, after: $after, first: $first, last: $last) {
5182
- edges {
5183
- node {
5184
- name
5185
- project {
5186
- name
5187
- entityName
5188
- }
5189
- }
5190
- }
5191
- }
5192
- }
5193
- }
5194
- """
5195
- )
5196
- response = self.client.execute(
5197
- query,
5198
- variable_values={"id": self.id},
5199
- )
5200
- # yes, "name" is actually id
5201
- runs = [
5202
- Run(
5203
- self.client,
5204
- edge["node"]["project"]["entityName"],
5205
- edge["node"]["project"]["name"],
5206
- edge["node"]["name"],
5207
- )
5208
- for edge in response.get("artifact", {}).get("usedBy", {}).get("edges", [])
5209
- ]
5210
- return runs
5211
-
5212
- def logged_by(self):
5213
- """Retrieve the run which logged this artifact.
5214
-
5215
- Returns:
5216
- Run: Run object which logged this artifact
5217
- """
5218
- query = gql(
5219
- """
5220
- query ArtifactCreatedBy(
5221
- $id: ID!
5222
- ) {
5223
- artifact(id: $id) {
5224
- createdBy {
5225
- ... on Run {
5226
- name
5227
- project {
5228
- name
5229
- entityName
5230
- }
5231
- }
5232
- }
5233
- }
5234
- }
5235
- """
5236
- )
5237
- response = self.client.execute(
5238
- query,
5239
- variable_values={"id": self.id},
5240
- )
5241
- run_obj = response.get("artifact", {}).get("createdBy", {})
5242
- if run_obj is not None:
5243
- return Run(
5244
- self.client,
5245
- run_obj["project"]["entityName"],
5246
- run_obj["project"]["name"],
5247
- run_obj["name"],
5248
- )
5249
-
5250
-
5251
- class ArtifactVersions(Paginator):
5252
- """An iterable collection of artifact versions associated with a project and optional filter.
5253
-
5254
- This is generally used indirectly via the `Api`.artifact_versions method.
5255
- """
5256
-
5257
- def __init__(
5258
- self,
5259
- client: Client,
5260
- entity: str,
5261
- project: str,
5262
- collection_name: str,
5263
- type: str,
5264
- filters: Optional[Mapping[str, Any]] = None,
5265
- order: Optional[str] = None,
5266
- per_page: int = 50,
5267
- ):
5268
- self.entity = entity
5269
- self.collection_name = collection_name
5270
- self.type = type
5271
- self.project = project
5272
- self.filters = {"state": "COMMITTED"} if filters is None else filters
5273
- self.order = order
5274
- variables = {
5275
- "project": self.project,
5276
- "entity": self.entity,
5277
- "order": self.order,
5278
- "type": self.type,
5279
- "collection": self.collection_name,
5280
- "filters": json.dumps(self.filters),
5281
- }
5282
- self.QUERY = gql(
5283
- """
5284
- query Artifacts($project: String!, $entity: String!, $type: String!, $collection: String!, $cursor: String, $perPage: Int = 50, $order: String, $filters: JSONString) {{
5285
- project(name: $project, entityName: $entity) {{
5286
- artifactType(name: $type) {{
5287
- artifactCollection: {}(name: $collection) {{
5288
- name
5289
- artifacts(filters: $filters, after: $cursor, first: $perPage, order: $order) {{
5290
- totalCount
5291
- edges {{
5292
- node {{
5293
- ...ArtifactFragment
5294
- }}
5295
- version
5296
- cursor
5297
- }}
5298
- pageInfo {{
5299
- endCursor
5300
- hasNextPage
5301
- }}
5302
- }}
5303
- }}
5304
- }}
5305
- }}
5306
- }}
5307
- {}
5308
- """.format(
5309
- artifact_collection_edge_name(
5310
- server_supports_artifact_collections_gql_edges(client)
5311
- ),
5312
- ARTIFACT_FRAGMENT,
5313
- )
5314
- )
5315
- super().__init__(client, variables, per_page)
5316
-
5317
- @property
5318
- def length(self):
5319
- if self.last_response:
5320
- return self.last_response["project"]["artifactType"]["artifactCollection"][
5321
- "artifacts"
5322
- ]["totalCount"]
5323
- else:
5324
- return None
5325
-
5326
- @property
5327
- def more(self):
5328
- if self.last_response:
5329
- return self.last_response["project"]["artifactType"]["artifactCollection"][
5330
- "artifacts"
5331
- ]["pageInfo"]["hasNextPage"]
5332
- else:
5333
- return True
5334
-
5335
- @property
5336
- def cursor(self):
5337
- if self.last_response:
5338
- return self.last_response["project"]["artifactType"]["artifactCollection"][
5339
- "artifacts"
5340
- ]["edges"][-1]["cursor"]
5341
- else:
5342
- return None
5343
-
5344
- def convert_objects(self):
5345
- if self.last_response["project"]["artifactType"]["artifactCollection"] is None:
5346
- return []
5347
- return [
5348
- Artifact(
5349
- self.client,
5350
- self.entity,
5351
- self.project,
5352
- self.collection_name + ":" + a["version"],
5353
- a["node"],
5354
- )
5355
- for a in self.last_response["project"]["artifactType"][
5356
- "artifactCollection"
5357
- ]["artifacts"]["edges"]
5358
- ]
5359
-
5360
-
5361
- class ArtifactFiles(Paginator):
5362
- QUERY = gql(
5363
- """
5364
- query ArtifactFiles(
5365
- $entityName: String!,
5366
- $projectName: String!,
5367
- $artifactTypeName: String!,
5368
- $artifactName: String!
5369
- $fileNames: [String!],
5370
- $fileCursor: String,
5371
- $fileLimit: Int = 50
5372
- ) {
5373
- project(name: $projectName, entityName: $entityName) {
5374
- artifactType(name: $artifactTypeName) {
5375
- artifact(name: $artifactName) {
5376
- ...ArtifactFilesFragment
5377
- }
5378
- }
5379
- }
5380
- }
5381
- %s
5382
- """
5383
- % ARTIFACT_FILES_FRAGMENT
4218
+ % ARTIFACT_FILES_FRAGMENT
5384
4219
  )
5385
4220
 
5386
4221
  def __init__(
5387
4222
  self,
5388
4223
  client: Client,
5389
- artifact: Artifact,
4224
+ artifact: "wandb.Artifact",
5390
4225
  names: Optional[Sequence[str]] = None,
5391
4226
  per_page: int = 50,
5392
4227
  ):
5393
4228
  self.artifact = artifact
5394
4229
  variables = {
5395
- "entityName": artifact._birth_entity,
5396
- "projectName": artifact._birth_project,
4230
+ "entityName": artifact.source_entity,
4231
+ "projectName": artifact.source_project,
5397
4232
  "artifactTypeName": artifact.type,
5398
- "artifactName": artifact.name,
4233
+ "artifactName": artifact.source_name,
5399
4234
  "fileNames": names,
5400
4235
  }
5401
4236
  # The server must advertise at least SDK 0.12.21
@@ -5452,6 +4287,7 @@ class Job:
5452
4287
  _entity: str
5453
4288
  _project: str
5454
4289
  _entrypoint: List[str]
4290
+ _notebook_job: bool
5455
4291
 
5456
4292
  def __init__(self, api: Api, name, path: Optional[str] = None) -> None:
5457
4293
  try:
@@ -5468,22 +4304,25 @@ class Job:
5468
4304
  self._entity = api.default_entity
5469
4305
 
5470
4306
  with open(os.path.join(self._fpath, "wandb-job.json")) as f:
5471
- self._source_info: Mapping[str, Any] = json.load(f)
5472
- self._entrypoint = self._source_info.get("source", {}).get("entrypoint")
5473
- self._args = self._source_info.get("source", {}).get("args")
4307
+ self._job_info: Mapping[str, Any] = json.load(f)
4308
+ source_info = self._job_info.get("source", {})
4309
+ # only use notebook job if entrypoint not set and notebook is set
4310
+ self._notebook_job = source_info.get("notebook", False)
4311
+ self._entrypoint = source_info.get("entrypoint")
4312
+ self._args = source_info.get("args")
5474
4313
  self._requirements_file = os.path.join(self._fpath, "requirements.frozen.txt")
5475
4314
  self._input_types = TypeRegistry.type_from_dict(
5476
- self._source_info.get("input_types")
4315
+ self._job_info.get("input_types")
5477
4316
  )
5478
4317
  self._output_types = TypeRegistry.type_from_dict(
5479
- self._source_info.get("output_types")
4318
+ self._job_info.get("output_types")
5480
4319
  )
5481
4320
 
5482
- if self._source_info.get("source_type") == "artifact":
4321
+ if self._job_info.get("source_type") == "artifact":
5483
4322
  self._set_configure_launch_project(self._configure_launch_project_artifact)
5484
- if self._source_info.get("source_type") == "repo":
4323
+ if self._job_info.get("source_type") == "repo":
5485
4324
  self._set_configure_launch_project(self._configure_launch_project_repo)
5486
- if self._source_info.get("source_type") == "image":
4325
+ if self._job_info.get("source_type") == "image":
5487
4326
  self._set_configure_launch_project(self._configure_launch_project_container)
5488
4327
 
5489
4328
  @property
@@ -5493,8 +4332,26 @@ class Job:
5493
4332
  def _set_configure_launch_project(self, func):
5494
4333
  self.configure_launch_project = func
5495
4334
 
4335
+ def _get_code_artifact(self, artifact_string):
4336
+ artifact_string, base_url, is_id = util.parse_artifact_string(artifact_string)
4337
+ if is_id:
4338
+ code_artifact = wandb.Artifact._from_id(artifact_string, self._api._client)
4339
+ else:
4340
+ code_artifact = self._api.artifact(name=artifact_string, type="code")
4341
+ if code_artifact is None:
4342
+ raise LaunchError("No code artifact found")
4343
+ return code_artifact
4344
+
4345
+ def _configure_launch_project_notebook(self, launch_project):
4346
+ new_fname = convert_jupyter_notebook_to_script(
4347
+ self._entrypoint[-1], launch_project.project_dir
4348
+ )
4349
+ new_entrypoint = self._entrypoint
4350
+ new_entrypoint[-1] = new_fname
4351
+ launch_project.add_entry_point(new_entrypoint)
4352
+
5496
4353
  def _configure_launch_project_repo(self, launch_project):
5497
- git_info = self._source_info.get("source", {}).get("git", {})
4354
+ git_info = self._job_info.get("source", {}).get("git", {})
5498
4355
  _fetch_git_repo(
5499
4356
  launch_project.project_dir,
5500
4357
  git_info["remote"],
@@ -5504,27 +4361,30 @@ class Job:
5504
4361
  with open(os.path.join(self._fpath, "diff.patch")) as f:
5505
4362
  apply_patch(f.read(), launch_project.project_dir)
5506
4363
  shutil.copy(self._requirements_file, launch_project.project_dir)
5507
- launch_project.add_entry_point(self._entrypoint)
5508
- launch_project.python_version = self._source_info.get("runtime")
4364
+ launch_project.python_version = self._job_info.get("runtime")
4365
+ if self._notebook_job:
4366
+ self._configure_launch_project_notebook(launch_project)
4367
+ else:
4368
+ launch_project.add_entry_point(self._entrypoint)
5509
4369
 
5510
4370
  def _configure_launch_project_artifact(self, launch_project):
5511
- artifact_string = self._source_info.get("source", {}).get("artifact")
4371
+ artifact_string = self._job_info.get("source", {}).get("artifact")
5512
4372
  if artifact_string is None:
5513
4373
  raise LaunchError(f"Job {self.name} had no source artifact")
5514
- artifact_string, base_url, is_id = util.parse_artifact_string(artifact_string)
5515
- if is_id:
5516
- code_artifact = Artifact.from_id(artifact_string, self._api._client)
5517
- else:
5518
- code_artifact = self._api.artifact(name=artifact_string, type="code")
5519
- if code_artifact is None:
5520
- raise LaunchError("No code artifact found")
5521
- code_artifact.download(launch_project.project_dir)
4374
+
4375
+ code_artifact = self._get_code_artifact(artifact_string)
4376
+ launch_project.python_version = self._job_info.get("runtime")
5522
4377
  shutil.copy(self._requirements_file, launch_project.project_dir)
5523
- launch_project.add_entry_point(self._entrypoint)
5524
- launch_project.python_version = self._source_info.get("runtime")
4378
+
4379
+ code_artifact.download(launch_project.project_dir)
4380
+
4381
+ if self._notebook_job:
4382
+ self._configure_launch_project_notebook(launch_project)
4383
+ else:
4384
+ launch_project.add_entry_point(self._entrypoint)
5525
4385
 
5526
4386
  def _configure_launch_project_container(self, launch_project):
5527
- launch_project.docker_image = self._source_info.get("source", {}).get("image")
4387
+ launch_project.docker_image = self._job_info.get("source", {}).get("image")
5528
4388
  if launch_project.docker_image is None:
5529
4389
  raise LaunchError(
5530
4390
  "Job had malformed source dictionary without an image key"
@@ -5550,7 +4410,7 @@ class Job:
5550
4410
  run_config = {}
5551
4411
  for key, item in config.items():
5552
4412
  if util._is_artifact_object(item):
5553
- if isinstance(item, wandb.Artifact) and item.id is None:
4413
+ if isinstance(item, wandb.Artifact) and item.is_draft():
5554
4414
  raise ValueError("Cannot queue jobs with unlogged artifacts")
5555
4415
  run_config[key] = util.artifact_to_json(item)
5556
4416