wandb 0.19.8__py3-none-macosx_11_0_arm64.whl → 0.19.9__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their respective public registries.
Files changed (65)
  1. wandb/__init__.py +5 -1
  2. wandb/__init__.pyi +12 -8
  3. wandb/_pydantic/__init__.py +23 -0
  4. wandb/_pydantic/base.py +113 -0
  5. wandb/_pydantic/v1_compat.py +262 -0
  6. wandb/apis/paginator.py +82 -38
  7. wandb/apis/public/api.py +10 -64
  8. wandb/apis/public/artifacts.py +73 -17
  9. wandb/apis/public/files.py +2 -2
  10. wandb/apis/public/projects.py +3 -2
  11. wandb/apis/public/reports.py +2 -2
  12. wandb/apis/public/runs.py +19 -11
  13. wandb/bin/gpu_stats +0 -0
  14. wandb/bin/wandb-core +0 -0
  15. wandb/integration/metaflow/metaflow.py +19 -17
  16. wandb/integration/sacred/__init__.py +1 -1
  17. wandb/jupyter.py +18 -15
  18. wandb/proto/v3/wandb_internal_pb2.py +7 -3
  19. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  20. wandb/proto/v3/wandb_telemetry_pb2.py +4 -4
  21. wandb/proto/v4/wandb_internal_pb2.py +3 -3
  22. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  23. wandb/proto/v4/wandb_telemetry_pb2.py +4 -4
  24. wandb/proto/v5/wandb_internal_pb2.py +3 -3
  25. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  26. wandb/proto/v5/wandb_telemetry_pb2.py +4 -4
  27. wandb/proto/wandb_deprecated.py +2 -0
  28. wandb/sdk/artifacts/_graphql_fragments.py +18 -20
  29. wandb/sdk/artifacts/_validators.py +1 -0
  30. wandb/sdk/artifacts/artifact.py +70 -36
  31. wandb/sdk/artifacts/artifact_saver.py +16 -2
  32. wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +23 -2
  33. wandb/sdk/data_types/audio.py +1 -3
  34. wandb/sdk/data_types/base_types/media.py +11 -4
  35. wandb/sdk/data_types/image.py +44 -25
  36. wandb/sdk/data_types/molecule.py +1 -5
  37. wandb/sdk/data_types/object_3d.py +2 -1
  38. wandb/sdk/data_types/saved_model.py +7 -9
  39. wandb/sdk/data_types/video.py +1 -4
  40. wandb/{apis/public → sdk/internal}/_generated/__init__.py +0 -6
  41. wandb/sdk/internal/_generated/base.py +226 -0
  42. wandb/{apis/public → sdk/internal}/_generated/server_features_query.py +3 -3
  43. wandb/{apis/public → sdk/internal}/_generated/typing_compat.py +1 -1
  44. wandb/sdk/internal/internal_api.py +138 -47
  45. wandb/sdk/internal/sender.py +2 -0
  46. wandb/sdk/internal/sender_config.py +8 -11
  47. wandb/sdk/internal/settings_static.py +24 -2
  48. wandb/sdk/lib/apikey.py +15 -16
  49. wandb/sdk/lib/run_moment.py +4 -6
  50. wandb/sdk/lib/wb_logging.py +161 -0
  51. wandb/sdk/wandb_config.py +44 -43
  52. wandb/sdk/wandb_init.py +141 -79
  53. wandb/sdk/wandb_metadata.py +107 -91
  54. wandb/sdk/wandb_run.py +152 -44
  55. wandb/sdk/wandb_settings.py +403 -201
  56. wandb/sdk/wandb_setup.py +3 -1
  57. {wandb-0.19.8.dist-info → wandb-0.19.9.dist-info}/METADATA +3 -3
  58. {wandb-0.19.8.dist-info → wandb-0.19.9.dist-info}/RECORD +64 -60
  59. wandb/apis/public/_generated/base.py +0 -128
  60. /wandb/{apis/public → sdk/internal}/_generated/enums.py +0 -0
  61. /wandb/{apis/public → sdk/internal}/_generated/input_types.py +0 -0
  62. /wandb/{apis/public → sdk/internal}/_generated/operations.py +0 -0
  63. {wandb-0.19.8.dist-info → wandb-0.19.9.dist-info}/WHEEL +0 -0
  64. {wandb-0.19.8.dist-info → wandb-0.19.9.dist-info}/entry_points.txt +0 -0
  65. {wandb-0.19.8.dist-info → wandb-0.19.9.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/data_types/molecule.py
@@ -53,9 +53,7 @@ class Molecule(BatchableMedia):
         caption: Optional[str] = None,
         **kwargs: str,
     ) -> None:
-        super().__init__()
-
-        self._caption = caption
+        super().__init__(caption=caption)
 
         if hasattr(data_or_path, "name"):
             # if the file has a path, we just detect the type and copy it from there
@@ -208,8 +206,6 @@ class Molecule(BatchableMedia):
     def to_json(self, run_or_artifact: Union["LocalRun", "Artifact"]) -> dict:
         json_dict = super().to_json(run_or_artifact)
         json_dict["_type"] = self._log_type
-        if self._caption:
-            json_dict["caption"] = self._caption
         return json_dict
 
     @classmethod
wandb/sdk/data_types/object_3d.py
@@ -215,9 +215,10 @@ class Object3D(BatchableMedia):
     def __init__(
         self,
         data_or_path: Union["np.ndarray", str, "TextIO", dict],
+        caption: Optional[str] = None,
         **kwargs: Optional[Union[str, "FileFormat3D"]],
     ) -> None:
-        super().__init__()
+        super().__init__(caption=caption)
 
         if hasattr(data_or_path, "name"):
             # if the file has a path, we just detect the type and copy it from there
wandb/sdk/data_types/saved_model.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 import os
 import shutil
 import sys
+from types import ModuleType
 from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, cast
 
 import wandb
@@ -15,9 +16,6 @@ from ._private import MEDIA_TMP
 from .base_types.wb_value import WBValue
 
 if TYPE_CHECKING:
-    from types import ModuleType
-
-    import cloudpickle  # type: ignore
     import sklearn  # type: ignore
     import tensorflow  # type: ignore
     import torch  # type: ignore
@@ -264,9 +262,9 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
         self._serialize(self._model_obj, target_path)
 
 
-def _get_cloudpickle() -> "cloudpickle":
+def _get_cloudpickle() -> ModuleType:
     return cast(
-        "cloudpickle",
+        ModuleType,
         util.get_module("cloudpickle", "ModelAdapter requires `cloudpickle`"),
     )
 
@@ -338,9 +336,9 @@ class _PicklingSavedModel(_SavedModel[SavedModelObjType]):
         return json_obj
 
 
-def _get_torch() -> "torch":
+def _get_torch() -> ModuleType:
     return cast(
-        "torch",
+        ModuleType,
         util.get_module("torch", "ModelAdapter requires `torch`"),
     )
 
@@ -366,9 +364,9 @@ class _PytorchSavedModel(_PicklingSavedModel["torch.nn.Module"]):
         )
 
 
-def _get_sklearn() -> "sklearn":
+def _get_sklearn() -> ModuleType:
     return cast(
-        "sklearn",
+        ModuleType,
         util.get_module("sklearn", "ModelAdapter requires `sklearn`"),
     )
 
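The return-type change above replaces string-literal annotations naming optional modules (`"cloudpickle"`, `"torch"`, `"sklearn"`) with the accurate `ModuleType`, so the heavy libraries no longer need to be importable for type checking. Below is a minimal standalone sketch of the same lazy optional-import pattern; the helper here is illustrative and is not wandb's actual `util.get_module` implementation:

```python
import importlib
from types import ModuleType
from typing import Optional


def get_module(name: str, required_msg: Optional[str] = None) -> Optional[ModuleType]:
    """Import a module by name on first use; raise a helpful error if it is required."""
    try:
        return importlib.import_module(name)
    except ImportError:
        if required_msg:
            raise ImportError(required_msg) from None
        return None


# The adapter only touches cloudpickle when a pickling-based save is requested:
# pickler = get_module("cloudpickle", "ModelAdapter requires `cloudpickle`")
```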
wandb/sdk/data_types/video.py
@@ -90,13 +90,12 @@ class Video(BatchableMedia):
         fps: Optional[int] = None,
         format: Optional[str] = None,
     ):
-        super().__init__()
+        super().__init__(caption=caption)
 
         self._format = format or "gif"
         self._width = None
         self._height = None
         self._channels = None
-        self._caption = caption
         if self._format not in Video.EXTS:
             raise ValueError(
                 "wandb.Video accepts {} formats".format(", ".join(Video.EXTS))
@@ -190,8 +189,6 @@ class Video(BatchableMedia):
             json_dict["width"] = self._width
         if self._height is not None:
             json_dict["height"] = self._height
-        if self._caption:
-            json_dict["caption"] = self._caption
 
         return json_dict
 
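Taken together with the Molecule and Object3D hunks above, caption storage and JSON serialization now live in the shared media base class rather than being duplicated in each subclass. A usage sketch of the unchanged public surface (assumes numpy and a configured wandb install; shapes and values are illustrative):

```python
import numpy as np
import wandb

run = wandb.init(project="demo")

# caption is forwarded to the media base class and serialized once in to_json()
frames = np.random.randint(0, 255, size=(10, 3, 64, 64), dtype=np.uint8)
run.log({"clip": wandb.Video(frames, caption="rollout 42", fps=4, format="gif")})

run.finish()
```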
wandb/{apis/public → sdk/internal}/_generated/__init__.py
@@ -1,6 +1,5 @@
 # Generated by ariadne-codegen
 
-from .base import Base, GQLBase, GQLId, SerializedToJson, Typename
 from .operations import SERVER_FEATURES_QUERY_GQL
 from .server_features_query import (
     ServerFeaturesQuery,
@@ -9,11 +8,6 @@ from .server_features_query import (
 )
 
 __all__ = [
-    "Base",
-    "GQLBase",
-    "GQLId",
-    "SerializedToJson",
-    "Typename",
     "SERVER_FEATURES_QUERY_GQL",
     "ServerFeaturesQuery",
     "ServerFeaturesQueryServerInfoFeatures",
wandb/sdk/internal/_generated/base.py
@@ -0,0 +1,226 @@
+# Generated by ariadne-codegen
+
+"""This module defines base classes for generated types, including partial support for compatibility with Pydantic v1."""
+
+from __future__ import annotations
+
+from contextlib import suppress
+from importlib.metadata import version
+from typing import TYPE_CHECKING, Any, Literal, TypeVar
+
+import pydantic
+from pydantic import BaseModel, ConfigDict, Field
+from pydantic import main as pydantic_main
+from typing_extensions import Annotated, override
+
+if TYPE_CHECKING:
+    from pydantic.main import IncEx
+
+pydantic_major_version, *_ = version(pydantic.__name__).split(".")
+IS_PYDANTIC_V2: bool = int(pydantic_major_version) >= 2
+
+# Maps {v2 -> v1} model config keys that were renamed in v2.
+# See: https://docs.pydantic.dev/latest/migration/#changes-to-config
+_V1_CONFIG_KEYS = {
+    "populate_by_name": "allow_population_by_field_name",
+    "str_to_lower": "anystr_lower",
+    "str_strip_whitespace": "anystr_strip_whitespace",
+    "str_to_upper": "anystr_upper",
+    "ignored_types": "keep_untouched",
+    "str_max_length": "max_anystr_length",
+    "str_min_length": "min_anystr_length",
+    "from_attributes": "orm_mode",
+    "json_schema_extra": "schema_extra",
+    "validate_default": "validate_all",
+}
+
+
+def _convert_v2_config(v2_config: dict[str, Any]) -> dict[str, Any]:
+    """Return a copy of the v2 ConfigDict with renamed v1 keys."""
+    return {_V1_CONFIG_KEYS.get(k, k): v for k, v in v2_config.items()}
+
+
+if IS_PYDANTIC_V2:
+    PydanticModelMetaclass = type  # placeholder
+else:
+    PydanticModelMetaclass = pydantic_main.ModelMetaclass
+
+
+class _V1MixinMetaclass(PydanticModelMetaclass):
+    def __new__(
+        cls,
+        name: str,
+        bases: tuple[type, ...],
+        namespace: dict[str, Any],
+        **kwargs: Any,
+    ):
+        # Converts a model config in a v2 class definition, e.g.:
+        #
+        #     class MyModel(BaseModel):
+        #         model_config = ConfigDict(populate_by_name=True)
+        #
+        # ...to a Config class in a v1 class definition, e.g.:
+        #
+        #     class MyModel(BaseModel):
+        #         class Config:
+        #             populate_by_name = True
+        #
+        if config_dict := namespace.pop("model_config", None):
+            namespace["Config"] = type("Config", (), _convert_v2_config(config_dict))
+        return super().__new__(cls, name, bases, namespace, **kwargs)
+
+
+class PydanticV1Mixin(metaclass=_V1MixinMetaclass):
+    @classmethod
+    def __try_update_forward_refs__(cls, **localns: Any) -> None:
+        with suppress(AttributeError):
+            super().__try_update_forward_refs__(**localns)
+
+    @classmethod
+    def model_rebuild(cls, *args: Any, **kwargs: Any) -> None:
+        return cls.update_forward_refs(*args, **kwargs)
+
+    @classmethod
+    def model_construct(cls, *args: Any, **kwargs: Any) -> Any:
+        return cls.construct(*args, **kwargs)
+
+    @classmethod
+    def model_validate(cls, *args: Any, **kwargs: Any) -> Any:
+        return cls.parse_obj(*args, **kwargs)
+
+    @classmethod
+    def model_validate_json(cls, *args: Any, **kwargs: Any) -> Any:
+        return cls.parse_raw(*args, **kwargs)
+
+    def model_dump(self, *args: Any, **kwargs: Any) -> Any:
+        return self.dict(*args, **kwargs)
+
+    def model_dump_json(self, *args: Any, **kwargs: Any) -> Any:
+        return self.json(*args, **kwargs)
+
+    def model_copy(self, *args: Any, **kwargs: Any) -> Any:
+        return self.copy(*args, **kwargs)
+
+
+class PydanticV2Mixin:
+    # Placeholder: Pydantic v2 is already compatible with itself, so no need for extra mixins.
+    pass
+
+
+# Pick the mixin type based on the detected Pydantic version.
+PydanticCompatMixin = PydanticV2Mixin if IS_PYDANTIC_V2 else PydanticV1Mixin
+
+
+# Base class for all generated classes/types.
+# This is deliberately not a docstring to prevent inclusion in generated docs.
+class Base(BaseModel, PydanticCompatMixin):
+    model_config = ConfigDict(
+        populate_by_name=True,
+        validate_assignment=True,
+        validate_default=True,
+        extra="forbid",
+        use_attribute_docstrings=True,
+        from_attributes=True,
+        revalidate_instances="always",
+    )
+
+    @override
+    def model_dump(
+        self,
+        *,
+        mode: Literal["json", "python"] | str = "json",  # NOTE: changed default
+        include: IncEx | None = None,
+        exclude: IncEx | None = None,
+        context: dict[str, Any] | None = None,
+        by_alias: bool = True,  # NOTE: changed default
+        exclude_unset: bool = False,
+        exclude_defaults: bool = False,
+        exclude_none: bool = False,
+        round_trip: bool = True,  # NOTE: changed default
+        warnings: bool | Literal["none", "warn", "error"] = True,
+        serialize_as_any: bool = False,
+    ) -> dict[str, Any]:
+        return super().model_dump(
+            mode=mode,
+            include=include,
+            exclude=exclude,
+            context=context,
+            by_alias=by_alias,
+            exclude_unset=exclude_unset,
+            exclude_defaults=exclude_defaults,
+            exclude_none=exclude_none,
+            round_trip=round_trip,
+            warnings=warnings,
+            serialize_as_any=serialize_as_any,
+        )
+
+    @override
+    def model_dump_json(
+        self,
+        *,
+        indent: int | None = None,
+        include: IncEx | None = None,
+        exclude: IncEx | None = None,
+        context: dict[str, Any] | None = None,
+        by_alias: bool = True,  # NOTE: changed default
+        exclude_unset: bool = False,
+        exclude_defaults: bool = False,
+        exclude_none: bool = False,
+        round_trip: bool = True,  # NOTE: changed default
+        warnings: bool | Literal["none", "warn", "error"] = True,
+        serialize_as_any: bool = False,
+    ) -> str:
+        return super().model_dump_json(
+            indent=indent,
+            include=include,
+            exclude=exclude,
+            context=context,
+            by_alias=by_alias,
+            exclude_unset=exclude_unset,
+            exclude_defaults=exclude_defaults,
+            exclude_none=exclude_none,
+            round_trip=round_trip,
+            warnings=warnings,
+            serialize_as_any=serialize_as_any,
+        )
+
+
+# Base class with extra customization for GQL generated types.
+# This is deliberately not a docstring to prevent inclusion in generated docs.
+class GQLBase(Base):
+    model_config = ConfigDict(
+        extra="ignore",
+        protected_namespaces=(),
+    )
+
+
+# ------------------------------------------------------------------------------
+# Reusable annotations for field types
+T = TypeVar("T")
+
+GQLId = Annotated[
+    str,
+    Field(repr=False, strict=True, frozen=True),
+]
+
+Typename = Annotated[
+    T,
+    Field(repr=False, alias="__typename", frozen=True),
+]
+
+
+# FIXME: Restore or modify this after ensuring pydantic v1 compatibility.
+# def validate_maybe_json(v: Any, handler: ValidatorFunctionWrapHandler) -> Any:
+#     """Wraps default Json[...] field validator to allow instantiation with an already-decoded value."""
+#     try:
+#         return handler(v)
+#     except ValidationError:
+#         # Try revalidating after properly jsonifying the value
+#         return handler(to_json(v, by_alias=True, round_trip=True))
+
+
+# SerializedToJson = Annotated[
+#     Json[T],
+#     # Allow lenient instantiation/validation: incoming data may already be deserialized.
+#     WrapValidator(validate_maybe_json),
+# ]
wandb/{apis/public → sdk/internal}/_generated/server_features_query.py
@@ -7,15 +7,15 @@ from typing import List, Optional
 
 from pydantic import Field
 
-from .base import GQLBase
+from wandb._pydantic import GQLBase
 
 
 class ServerFeaturesQuery(GQLBase):
-    server_info: Optional["ServerFeaturesQueryServerInfo"] = Field(alias="serverInfo")
+    server_info: Optional[ServerFeaturesQueryServerInfo] = Field(alias="serverInfo")
 
 
 class ServerFeaturesQueryServerInfo(GQLBase):
-    features: List[Optional["ServerFeaturesQueryServerInfoFeatures"]]
+    features: List[Optional[ServerFeaturesQueryServerInfoFeatures]]
 
 
 class ServerFeaturesQueryServerInfoFeatures(GQLBase):
wandb/{apis/public → sdk/internal}/_generated/typing_compat.py
@@ -10,5 +10,5 @@ else:
     from typing_extensions import Annotated, override
 
 
-Annnotated = Annotated
+Annotated = Annotated
 override = override
wandb/sdk/internal/internal_api.py
@@ -36,6 +36,7 @@ import requests
 import yaml
 from wandb_gql import Client, gql
 from wandb_gql.client import RetryError
+from wandb_graphql.language.ast import Document
 
 import wandb
 from wandb import env, util
@@ -43,7 +44,9 @@ from wandb.apis.normalize import normalize_exceptions, parse_backend_error_messages
 from wandb.errors import AuthenticationError, CommError, UnsupportedError, UsageError
 from wandb.integration.sagemaker import parse_sm_secrets
 from wandb.old.settings import Settings
+from wandb.proto.wandb_internal_pb2 import ServerFeature
 from wandb.sdk.artifacts._validators import is_artifact_registry_project
+from wandb.sdk.internal._generated import SERVER_FEATURES_QUERY_GQL, ServerFeaturesQuery
 from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
 from wandb.sdk.lib.gql_request import GraphQLSession
 from wandb.sdk.lib.hashutil import B64MD5, md5_file_b64
@@ -365,6 +368,7 @@ class Api:
         self.server_create_run_queue_supports_priority: Optional[bool] = None
         self.server_supports_template_variables: Optional[bool] = None
         self.server_push_to_run_queue_supports_priority: Optional[bool] = None
+        self._server_features_cache: Optional[dict[str, bool]] = None
 
     def gql(self, *args: Any, **kwargs: Any) -> Any:
         ret = self._retry_gql(
@@ -869,6 +873,52 @@
         _, _, mutations = self.server_info_introspection()
         return "updateRunQueueItemWarning" in mutations
 
+    def _check_server_feature(self, feature_value: ServerFeature) -> bool:
+        """Check if a server feature is enabled.
+
+        Args:
+            feature_value (ServerFeature): The enum value of the feature to check.
+
+        Returns:
+            bool: True if the feature is enabled, False otherwise.
+
+        Raises:
+            Exception: If server doesn't support feature queries or other errors occur
+        """
+        if self._server_features_cache is None:
+            query = gql(SERVER_FEATURES_QUERY_GQL)
+            response = self.gql(query)
+            server_info = ServerFeaturesQuery.model_validate(response).server_info
+            if server_info and (features := server_info.features):
+                self._server_features_cache = {
+                    f.name: f.is_enabled for f in features if f
+                }
+            else:
+                self._server_features_cache = {}
+
+        return self._server_features_cache.get(ServerFeature.Name(feature_value), False)
+
+    def _check_server_feature_with_fallback(self, feature_value: ServerFeature) -> bool:
+        """Wrapper around check_server_feature that warns and returns False for older unsupported servers.
+
+        Good to use for features that have a fallback mechanism for older servers.
+
+        Args:
+            feature_value (ServerFeature): The enum value of the feature to check.
+
+        Returns:
+            bool: True if the feature is enabled, False otherwise.
+
+        Exceptions:
+            Exception: If an error other than the server not supporting feature queries occurs.
+        """
+        try:
+            return self._check_server_feature(feature_value)
+        except Exception as e:
+            if 'Cannot query field "features" on type "ServerInfo".' in str(e):
+                return False
+            raise e
+
     @normalize_exceptions
     def update_run_queue_item_warning(
         self,
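A hedged sketch of how the new cache behaves: the first call issues `ServerFeaturesQuery` over GraphQL and memoizes a `{feature name: enabled}` dict on the `Api` instance, so subsequent checks are plain dictionary lookups. Illustrative only; these are private methods, and constructing `Api` this way assumes valid credentials and settings:

```python
from wandb.proto.wandb_internal_pb2 import ServerFeature
from wandb.sdk.internal.internal_api import Api

api = Api()
feature = ServerFeature.USE_ARTIFACT_WITH_ENTITY_AND_PROJECT_INFORMATION

# First call populates api._server_features_cache; later calls hit the dict.
enabled = api._check_server_feature_with_fallback(feature)
# Older servers without the `features` field return False instead of raising.
print("entity/project info on useArtifact:", enabled)
```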
wandb/sdk/internal/internal_api.py (continued)
@@ -3703,67 +3753,108 @@
         else:
             raise ValueError(f"Unable to find an organization under entity {entity!r}.")
 
-    def use_artifact(
+    def _construct_use_artifact_query(
         self,
         artifact_id: str,
         entity_name: Optional[str] = None,
         project_name: Optional[str] = None,
         run_name: Optional[str] = None,
         use_as: Optional[str] = None,
-    ) -> Optional[Dict[str, Any]]:
-        query_template = """
-        mutation UseArtifact(
-            $entityName: String!,
-            $projectName: String!,
-            $runName: String!,
-            $artifactID: ID!,
-            _USED_AS_TYPE_
-        ) {
-            useArtifact(input: {
-                entityName: $entityName,
-                projectName: $projectName,
-                runName: $runName,
-                artifactID: $artifactID,
-                _USED_AS_VALUE_
-            }) {
-                artifact {
-                    id
-                    digest
-                    description
-                    state
-                    createdAt
-                    metadata
-                }
-            }
-        }
-        """
+        artifact_entity_name: Optional[str] = None,
+        artifact_project_name: Optional[str] = None,
+    ) -> Tuple[Document, Dict[str, Any]]:
+        query_vars = [
+            "$entityName: String!",
+            "$projectName: String!",
+            "$runName: String!",
+            "$artifactID: ID!",
+        ]
+        query_args = [
+            "entityName: $entityName",
+            "projectName: $projectName",
+            "runName: $runName",
+            "artifactID: $artifactID",
+        ]
 
         artifact_types = self.server_use_artifact_input_introspection()
-        if "usedAs" in artifact_types:
-            query_template = query_template.replace(
-                "_USED_AS_TYPE_", "$usedAs: String"
-            ).replace("_USED_AS_VALUE_", "usedAs: $usedAs")
-        else:
-            query_template = query_template.replace("_USED_AS_TYPE_", "").replace(
-                "_USED_AS_VALUE_", ""
-            )
-
-        query = gql(query_template)
+        if "usedAs" in artifact_types and use_as:
+            query_vars.append("$usedAs: String")
+            query_args.append("usedAs: $usedAs")
 
         entity_name = entity_name or self.settings("entity")
         project_name = project_name or self.settings("project")
         run_name = run_name or self.current_run_id
 
-        response = self.gql(
-            query,
-            variable_values={
-                "entityName": entity_name,
-                "projectName": project_name,
-                "runName": run_name,
-                "artifactID": artifact_id,
-                "usedAs": use_as,
-            },
+        variable_values: Dict[str, Any] = {
+            "entityName": entity_name,
+            "projectName": project_name,
+            "runName": run_name,
+            "artifactID": artifact_id,
+            "usedAs": use_as,
+        }
+
+        server_allows_entity_project_information = (
+            self._check_server_feature_with_fallback(
+                ServerFeature.USE_ARTIFACT_WITH_ENTITY_AND_PROJECT_INFORMATION  # type: ignore
+            )
+        )
+        if server_allows_entity_project_information:
+            query_vars.extend(
+                [
+                    "$artifactEntityName: String",
+                    "$artifactProjectName: String",
+                ]
+            )
+            query_args.extend(
+                [
+                    "artifactEntityName: $artifactEntityName",
+                    "artifactProjectName: $artifactProjectName",
+                ]
+            )
+            variable_values["artifactEntityName"] = artifact_entity_name
+            variable_values["artifactProjectName"] = artifact_project_name
+
+        vars_str = ", ".join(query_vars)
+        args_str = ", ".join(query_args)
+
+        query = gql(
+            f"""
+            mutation UseArtifact({vars_str}) {{
+                useArtifact(input: {{{args_str}}}) {{
+                    artifact {{
+                        id
+                        digest
+                        description
+                        state
+                        createdAt
+                        metadata
+                    }}
+                }}
+            }}
+            """
+        )
+        return query, variable_values
+
+    def use_artifact(
+        self,
+        artifact_id: str,
+        entity_name: Optional[str] = None,
+        project_name: Optional[str] = None,
+        run_name: Optional[str] = None,
+        artifact_entity_name: Optional[str] = None,
+        artifact_project_name: Optional[str] = None,
+        use_as: Optional[str] = None,
+    ) -> Optional[Dict[str, Any]]:
+        query, variable_values = self._construct_use_artifact_query(
+            artifact_id,
+            entity_name,
+            project_name,
+            run_name,
+            use_as,
+            artifact_entity_name,
+            artifact_project_name,
        )
+        response = self.gql(query, variable_values)
 
         if response["useArtifact"]["artifact"]:
             artifact: Dict[str, Any] = response["useArtifact"]["artifact"]
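The refactor replaces the old placeholder substitution (`_USED_AS_TYPE_` / `_USED_AS_VALUE_`) with plain list building, which makes conditionally appending variables straightforward. A standalone reproduction of just the string assembly (not the wandb code itself; the feature-check result is hard-coded here for illustration):

```python
query_vars = ["$entityName: String!", "$projectName: String!",
              "$runName: String!", "$artifactID: ID!"]
query_args = ["entityName: $entityName", "projectName: $projectName",
              "runName: $runName", "artifactID: $artifactID"]

server_allows_entity_project_information = True  # assumed feature-check result
if server_allows_entity_project_information:
    query_vars += ["$artifactEntityName: String", "$artifactProjectName: String"]
    query_args += ["artifactEntityName: $artifactEntityName",
                   "artifactProjectName: $artifactProjectName"]

# Join the accumulated pieces into the final mutation text.
mutation = f"""
mutation UseArtifact({", ".join(query_vars)}) {{
  useArtifact(input: {{{", ".join(query_args)}}}) {{
    artifact {{ id digest description state createdAt metadata }}
  }}
}}
"""
print(mutation)
```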
wandb/sdk/internal/sender.py
@@ -1531,6 +1531,8 @@ class SendManager:
 
         metadata = json.loads(artifact.metadata) if artifact.metadata else None
         res = saver.save(
+            entity=artifact.entity,
+            project=artifact.project,
             type=artifact.type,
             name=artifact.name,
             client_id=artifact.client_id,
wandb/sdk/internal/sender_config.py
@@ -47,17 +47,14 @@ class ConfigState:
         # Add any top-level keys that aren't already set.
         self._add_unset_keys_from_subtree(old_config_tree, [])
 
-        # Unfortunately, when a user logs visualizations, we store them in the
-        # run's config. When resuming a run, we want to avoid erasing previously
-        # logged visualizations, hence this special handling:
-        self._add_unset_keys_from_subtree(
-            old_config_tree,
-            [_WANDB_INTERNAL_KEY, "visualize"],
-        )
-        self._add_unset_keys_from_subtree(
-            old_config_tree,
-            [_WANDB_INTERNAL_KEY, "viz"],
-        )
+        # When resuming a run, we want to ensure the some of the old configs keys
+        # are maintained. So we have this logic here to add back
+        # any keys that were in the old config but not in the new config
+        for key in ["viz", "visualize", "mask/class_labels"]:
+            self._add_unset_keys_from_subtree(
+                old_config_tree,
+                [_WANDB_INTERNAL_KEY, key],
+            )
 
     def _add_unset_keys_from_subtree(
         self,
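The loop above now restores three internal subtrees (`viz`, `visualize`, and the newly added `mask/class_labels`) when resuming a run. A standalone sketch of the merge semantics, with an assumed helper and internal key name (the real `_add_unset_keys_from_subtree` is a method on `ConfigState`):

```python
from typing import Any, Dict, List

INTERNAL_KEY = "_wandb"  # assumed name behind _WANDB_INTERNAL_KEY


def add_unset_keys_from_subtree(
    new: Dict[str, Any], old: Dict[str, Any], path: List[str]
) -> None:
    """Copy values from old[path] into new[path], but only where new has no value."""
    old_sub, new_sub = old, new
    for key in path:
        if key not in old_sub:
            return  # nothing to restore at this path
        old_sub = old_sub[key]
        new_sub = new_sub.setdefault(key, {})
    for k, v in old_sub.items():
        new_sub.setdefault(k, v)


old = {INTERNAL_KEY: {"viz": {"panel": "spec"}, "mask/class_labels": {"0": "car"}}}
new = {INTERNAL_KEY: {"viz": {}}}
for key in ["viz", "visualize", "mask/class_labels"]:
    add_unset_keys_from_subtree(new, old, [INTERNAL_KEY, key])

print(new)
# {'_wandb': {'viz': {'panel': 'spec'}, 'mask/class_labels': {'0': 'car'}}}
```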