wandb 0.19.8__py3-none-win_amd64.whl → 0.19.9__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +5 -1
- wandb/__init__.pyi +12 -8
- wandb/_pydantic/__init__.py +23 -0
- wandb/_pydantic/base.py +113 -0
- wandb/_pydantic/v1_compat.py +262 -0
- wandb/apis/paginator.py +82 -38
- wandb/apis/public/api.py +10 -64
- wandb/apis/public/artifacts.py +73 -17
- wandb/apis/public/files.py +2 -2
- wandb/apis/public/projects.py +3 -2
- wandb/apis/public/reports.py +2 -2
- wandb/apis/public/runs.py +19 -11
- wandb/bin/gpu_stats.exe +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/integration/metaflow/metaflow.py +19 -17
- wandb/integration/sacred/__init__.py +1 -1
- wandb/jupyter.py +18 -15
- wandb/proto/v3/wandb_internal_pb2.py +7 -3
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +4 -4
- wandb/proto/v4/wandb_internal_pb2.py +3 -3
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +4 -4
- wandb/proto/v5/wandb_internal_pb2.py +3 -3
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +4 -4
- wandb/proto/wandb_deprecated.py +2 -0
- wandb/sdk/artifacts/_graphql_fragments.py +18 -20
- wandb/sdk/artifacts/_validators.py +1 -0
- wandb/sdk/artifacts/artifact.py +70 -36
- wandb/sdk/artifacts/artifact_saver.py +16 -2
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +23 -2
- wandb/sdk/data_types/audio.py +1 -3
- wandb/sdk/data_types/base_types/media.py +11 -4
- wandb/sdk/data_types/image.py +44 -25
- wandb/sdk/data_types/molecule.py +1 -5
- wandb/sdk/data_types/object_3d.py +2 -1
- wandb/sdk/data_types/saved_model.py +7 -9
- wandb/sdk/data_types/video.py +1 -4
- wandb/{apis/public → sdk/internal}/_generated/__init__.py +0 -6
- wandb/sdk/internal/_generated/base.py +226 -0
- wandb/{apis/public → sdk/internal}/_generated/server_features_query.py +3 -3
- wandb/{apis/public → sdk/internal}/_generated/typing_compat.py +1 -1
- wandb/sdk/internal/internal_api.py +138 -47
- wandb/sdk/internal/sender.py +2 -0
- wandb/sdk/internal/sender_config.py +8 -11
- wandb/sdk/internal/settings_static.py +24 -2
- wandb/sdk/lib/apikey.py +15 -16
- wandb/sdk/lib/run_moment.py +4 -6
- wandb/sdk/lib/wb_logging.py +161 -0
- wandb/sdk/wandb_config.py +44 -43
- wandb/sdk/wandb_init.py +141 -79
- wandb/sdk/wandb_metadata.py +107 -91
- wandb/sdk/wandb_run.py +152 -44
- wandb/sdk/wandb_settings.py +403 -201
- wandb/sdk/wandb_setup.py +3 -1
- {wandb-0.19.8.dist-info → wandb-0.19.9.dist-info}/METADATA +3 -3
- {wandb-0.19.8.dist-info → wandb-0.19.9.dist-info}/RECORD +64 -60
- wandb/apis/public/_generated/base.py +0 -128
- /wandb/{apis/public → sdk/internal}/_generated/enums.py +0 -0
- /wandb/{apis/public → sdk/internal}/_generated/input_types.py +0 -0
- /wandb/{apis/public → sdk/internal}/_generated/operations.py +0 -0
- {wandb-0.19.8.dist-info → wandb-0.19.9.dist-info}/WHEEL +0 -0
- {wandb-0.19.8.dist-info → wandb-0.19.9.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.8.dist-info → wandb-0.19.9.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/data_types/molecule.py
CHANGED
@@ -53,9 +53,7 @@ class Molecule(BatchableMedia):
|
|
53
53
|
caption: Optional[str] = None,
|
54
54
|
**kwargs: str,
|
55
55
|
) -> None:
|
56
|
-
super().__init__()
|
57
|
-
|
58
|
-
self._caption = caption
|
56
|
+
super().__init__(caption=caption)
|
59
57
|
|
60
58
|
if hasattr(data_or_path, "name"):
|
61
59
|
# if the file has a path, we just detect the type and copy it from there
|
@@ -208,8 +206,6 @@ class Molecule(BatchableMedia):
|
|
208
206
|
def to_json(self, run_or_artifact: Union["LocalRun", "Artifact"]) -> dict:
|
209
207
|
json_dict = super().to_json(run_or_artifact)
|
210
208
|
json_dict["_type"] = self._log_type
|
211
|
-
if self._caption:
|
212
|
-
json_dict["caption"] = self._caption
|
213
209
|
return json_dict
|
214
210
|
|
215
211
|
@classmethod
|
@@ -215,9 +215,10 @@ class Object3D(BatchableMedia):
|
|
215
215
|
def __init__(
|
216
216
|
self,
|
217
217
|
data_or_path: Union["np.ndarray", str, "TextIO", dict],
|
218
|
+
caption: Optional[str] = None,
|
218
219
|
**kwargs: Optional[Union[str, "FileFormat3D"]],
|
219
220
|
) -> None:
|
220
|
-
super().__init__()
|
221
|
+
super().__init__(caption=caption)
|
221
222
|
|
222
223
|
if hasattr(data_or_path, "name"):
|
223
224
|
# if the file has a path, we just detect the type and copy it from there
|
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|
3
3
|
import os
|
4
4
|
import shutil
|
5
5
|
import sys
|
6
|
+
from types import ModuleType
|
6
7
|
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, cast
|
7
8
|
|
8
9
|
import wandb
|
@@ -15,9 +16,6 @@ from ._private import MEDIA_TMP
|
|
15
16
|
from .base_types.wb_value import WBValue
|
16
17
|
|
17
18
|
if TYPE_CHECKING:
|
18
|
-
from types import ModuleType
|
19
|
-
|
20
|
-
import cloudpickle # type: ignore
|
21
19
|
import sklearn # type: ignore
|
22
20
|
import tensorflow # type: ignore
|
23
21
|
import torch # type: ignore
|
@@ -264,9 +262,9 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
|
|
264
262
|
self._serialize(self._model_obj, target_path)
|
265
263
|
|
266
264
|
|
267
|
-
def _get_cloudpickle() ->
|
265
|
+
def _get_cloudpickle() -> ModuleType:
|
268
266
|
return cast(
|
269
|
-
|
267
|
+
ModuleType,
|
270
268
|
util.get_module("cloudpickle", "ModelAdapter requires `cloudpickle`"),
|
271
269
|
)
|
272
270
|
|
@@ -338,9 +336,9 @@ class _PicklingSavedModel(_SavedModel[SavedModelObjType]):
|
|
338
336
|
return json_obj
|
339
337
|
|
340
338
|
|
341
|
-
def _get_torch() ->
|
339
|
+
def _get_torch() -> ModuleType:
|
342
340
|
return cast(
|
343
|
-
|
341
|
+
ModuleType,
|
344
342
|
util.get_module("torch", "ModelAdapter requires `torch`"),
|
345
343
|
)
|
346
344
|
|
@@ -366,9 +364,9 @@ class _PytorchSavedModel(_PicklingSavedModel["torch.nn.Module"]):
|
|
366
364
|
)
|
367
365
|
|
368
366
|
|
369
|
-
def _get_sklearn() ->
|
367
|
+
def _get_sklearn() -> ModuleType:
|
370
368
|
return cast(
|
371
|
-
|
369
|
+
ModuleType,
|
372
370
|
util.get_module("sklearn", "ModelAdapter requires `sklearn`"),
|
373
371
|
)
|
374
372
|
|
wandb/sdk/data_types/video.py
CHANGED
@@ -90,13 +90,12 @@ class Video(BatchableMedia):
|
|
90
90
|
fps: Optional[int] = None,
|
91
91
|
format: Optional[str] = None,
|
92
92
|
):
|
93
|
-
super().__init__()
|
93
|
+
super().__init__(caption=caption)
|
94
94
|
|
95
95
|
self._format = format or "gif"
|
96
96
|
self._width = None
|
97
97
|
self._height = None
|
98
98
|
self._channels = None
|
99
|
-
self._caption = caption
|
100
99
|
if self._format not in Video.EXTS:
|
101
100
|
raise ValueError(
|
102
101
|
"wandb.Video accepts {} formats".format(", ".join(Video.EXTS))
|
@@ -190,8 +189,6 @@ class Video(BatchableMedia):
|
|
190
189
|
json_dict["width"] = self._width
|
191
190
|
if self._height is not None:
|
192
191
|
json_dict["height"] = self._height
|
193
|
-
if self._caption:
|
194
|
-
json_dict["caption"] = self._caption
|
195
192
|
|
196
193
|
return json_dict
|
197
194
|
|
@@ -1,6 +1,5 @@
|
|
1
1
|
# Generated by ariadne-codegen
|
2
2
|
|
3
|
-
from .base import Base, GQLBase, GQLId, SerializedToJson, Typename
|
4
3
|
from .operations import SERVER_FEATURES_QUERY_GQL
|
5
4
|
from .server_features_query import (
|
6
5
|
ServerFeaturesQuery,
|
@@ -9,11 +8,6 @@ from .server_features_query import (
|
|
9
8
|
)
|
10
9
|
|
11
10
|
__all__ = [
|
12
|
-
"Base",
|
13
|
-
"GQLBase",
|
14
|
-
"GQLId",
|
15
|
-
"SerializedToJson",
|
16
|
-
"Typename",
|
17
11
|
"SERVER_FEATURES_QUERY_GQL",
|
18
12
|
"ServerFeaturesQuery",
|
19
13
|
"ServerFeaturesQueryServerInfo",
|
@@ -0,0 +1,226 @@
|
|
1
|
+
# Generated by ariadne-codegen
|
2
|
+
|
3
|
+
"""This module defines base classes for generated types, including partial support for compatibility with Pydantic v1."""
|
4
|
+
|
5
|
+
from __future__ import annotations
|
6
|
+
|
7
|
+
from contextlib import suppress
|
8
|
+
from importlib.metadata import version
|
9
|
+
from typing import TYPE_CHECKING, Any, Literal, TypeVar
|
10
|
+
|
11
|
+
import pydantic
|
12
|
+
from pydantic import BaseModel, ConfigDict, Field
|
13
|
+
from pydantic import main as pydantic_main
|
14
|
+
from typing_extensions import Annotated, override
|
15
|
+
|
16
|
+
if TYPE_CHECKING:
|
17
|
+
from pydantic.main import IncEx
|
18
|
+
|
19
|
+
pydantic_major_version, *_ = version(pydantic.__name__).split(".")
|
20
|
+
IS_PYDANTIC_V2: bool = int(pydantic_major_version) >= 2
|
21
|
+
|
22
|
+
# Maps {v2 -> v1} model config keys that were renamed in v2.
|
23
|
+
# See: https://docs.pydantic.dev/latest/migration/#changes-to-config
|
24
|
+
_V1_CONFIG_KEYS = {
|
25
|
+
"populate_by_name": "allow_population_by_field_name",
|
26
|
+
"str_to_lower": "anystr_lower",
|
27
|
+
"str_strip_whitespace": "anystr_strip_whitespace",
|
28
|
+
"str_to_upper": "anystr_upper",
|
29
|
+
"ignored_types": "keep_untouched",
|
30
|
+
"str_max_length": "max_anystr_length",
|
31
|
+
"str_min_length": "min_anystr_length",
|
32
|
+
"from_attributes": "orm_mode",
|
33
|
+
"json_schema_extra": "schema_extra",
|
34
|
+
"validate_default": "validate_all",
|
35
|
+
}
|
36
|
+
|
37
|
+
|
38
|
+
def _convert_v2_config(v2_config: dict[str, Any]) -> dict[str, Any]:
|
39
|
+
"""Return a copy of the v2 ConfigDict with renamed v1 keys."""
|
40
|
+
return {_V1_CONFIG_KEYS.get(k, k): v for k, v in v2_config.items()}
|
41
|
+
|
42
|
+
|
43
|
+
if IS_PYDANTIC_V2:
|
44
|
+
PydanticModelMetaclass = type # placeholder
|
45
|
+
else:
|
46
|
+
PydanticModelMetaclass = pydantic_main.ModelMetaclass
|
47
|
+
|
48
|
+
|
49
|
+
class _V1MixinMetaclass(PydanticModelMetaclass):
|
50
|
+
def __new__(
|
51
|
+
cls,
|
52
|
+
name: str,
|
53
|
+
bases: tuple[type, ...],
|
54
|
+
namespace: dict[str, Any],
|
55
|
+
**kwargs: Any,
|
56
|
+
):
|
57
|
+
# Converts a model config in a v2 class definition, e.g.:
|
58
|
+
#
|
59
|
+
# class MyModel(BaseModel):
|
60
|
+
# model_config = ConfigDict(populate_by_name=True)
|
61
|
+
#
|
62
|
+
# ...to a Config class in a v1 class definition, e.g.:
|
63
|
+
#
|
64
|
+
# class MyModel(BaseModel):
|
65
|
+
# class Config:
|
66
|
+
# allow_population_by_field_name = True
|
67
|
+
#
|
68
|
+
if config_dict := namespace.pop("model_config", None):
|
69
|
+
namespace["Config"] = type("Config", (), _convert_v2_config(config_dict))
|
70
|
+
return super().__new__(cls, name, bases, namespace, **kwargs)
|
71
|
+
|
72
|
+
|
73
|
+
class PydanticV1Mixin(metaclass=_V1MixinMetaclass):
|
74
|
+
@classmethod
|
75
|
+
def __try_update_forward_refs__(cls, **localns: Any) -> None:
|
76
|
+
with suppress(AttributeError):
|
77
|
+
super().__try_update_forward_refs__(**localns)
|
78
|
+
|
79
|
+
@classmethod
|
80
|
+
def model_rebuild(cls, *args: Any, **kwargs: Any) -> None:
|
81
|
+
return cls.update_forward_refs(*args, **kwargs)
|
82
|
+
|
83
|
+
@classmethod
|
84
|
+
def model_construct(cls, *args: Any, **kwargs: Any) -> Any:
|
85
|
+
return cls.construct(*args, **kwargs)
|
86
|
+
|
87
|
+
@classmethod
|
88
|
+
def model_validate(cls, *args: Any, **kwargs: Any) -> Any:
|
89
|
+
return cls.parse_obj(*args, **kwargs)
|
90
|
+
|
91
|
+
@classmethod
|
92
|
+
def model_validate_json(cls, *args: Any, **kwargs: Any) -> Any:
|
93
|
+
return cls.parse_raw(*args, **kwargs)
|
94
|
+
|
95
|
+
def model_dump(self, *args: Any, **kwargs: Any) -> Any:
|
96
|
+
return self.dict(*args, **kwargs)
|
97
|
+
|
98
|
+
def model_dump_json(self, *args: Any, **kwargs: Any) -> Any:
|
99
|
+
return self.json(*args, **kwargs)
|
100
|
+
|
101
|
+
def model_copy(self, *args: Any, **kwargs: Any) -> Any:
|
102
|
+
return self.copy(*args, **kwargs)
|
103
|
+
|
104
|
+
|
105
|
+
class PydanticV2Mixin:
|
106
|
+
# Placeholder: Pydantic v2 is already compatible with itself, so no need for extra mixins.
|
107
|
+
pass
|
108
|
+
|
109
|
+
|
110
|
+
# Pick the mixin type based on the detected Pydantic version.
|
111
|
+
PydanticCompatMixin = PydanticV2Mixin if IS_PYDANTIC_V2 else PydanticV1Mixin
|
112
|
+
|
113
|
+
|
114
|
+
# Base class for all generated classes/types.
|
115
|
+
# This is deliberately not a docstring to prevent inclusion in generated docs.
|
116
|
+
class Base(BaseModel, PydanticCompatMixin):
|
117
|
+
model_config = ConfigDict(
|
118
|
+
populate_by_name=True,
|
119
|
+
validate_assignment=True,
|
120
|
+
validate_default=True,
|
121
|
+
extra="forbid",
|
122
|
+
use_attribute_docstrings=True,
|
123
|
+
from_attributes=True,
|
124
|
+
revalidate_instances="always",
|
125
|
+
)
|
126
|
+
|
127
|
+
@override
|
128
|
+
def model_dump(
|
129
|
+
self,
|
130
|
+
*,
|
131
|
+
mode: Literal["json", "python"] | str = "json", # NOTE: changed default
|
132
|
+
include: IncEx | None = None,
|
133
|
+
exclude: IncEx | None = None,
|
134
|
+
context: dict[str, Any] | None = None,
|
135
|
+
by_alias: bool = True, # NOTE: changed default
|
136
|
+
exclude_unset: bool = False,
|
137
|
+
exclude_defaults: bool = False,
|
138
|
+
exclude_none: bool = False,
|
139
|
+
round_trip: bool = True, # NOTE: changed default
|
140
|
+
warnings: bool | Literal["none", "warn", "error"] = True,
|
141
|
+
serialize_as_any: bool = False,
|
142
|
+
) -> dict[str, Any]:
|
143
|
+
return super().model_dump(
|
144
|
+
mode=mode,
|
145
|
+
include=include,
|
146
|
+
exclude=exclude,
|
147
|
+
context=context,
|
148
|
+
by_alias=by_alias,
|
149
|
+
exclude_unset=exclude_unset,
|
150
|
+
exclude_defaults=exclude_defaults,
|
151
|
+
exclude_none=exclude_none,
|
152
|
+
round_trip=round_trip,
|
153
|
+
warnings=warnings,
|
154
|
+
serialize_as_any=serialize_as_any,
|
155
|
+
)
|
156
|
+
|
157
|
+
@override
|
158
|
+
def model_dump_json(
|
159
|
+
self,
|
160
|
+
*,
|
161
|
+
indent: int | None = None,
|
162
|
+
include: IncEx | None = None,
|
163
|
+
exclude: IncEx | None = None,
|
164
|
+
context: dict[str, Any] | None = None,
|
165
|
+
by_alias: bool = True, # NOTE: changed default
|
166
|
+
exclude_unset: bool = False,
|
167
|
+
exclude_defaults: bool = False,
|
168
|
+
exclude_none: bool = False,
|
169
|
+
round_trip: bool = True, # NOTE: changed default
|
170
|
+
warnings: bool | Literal["none", "warn", "error"] = True,
|
171
|
+
serialize_as_any: bool = False,
|
172
|
+
) -> str:
|
173
|
+
return super().model_dump_json(
|
174
|
+
indent=indent,
|
175
|
+
include=include,
|
176
|
+
exclude=exclude,
|
177
|
+
context=context,
|
178
|
+
by_alias=by_alias,
|
179
|
+
exclude_unset=exclude_unset,
|
180
|
+
exclude_defaults=exclude_defaults,
|
181
|
+
exclude_none=exclude_none,
|
182
|
+
round_trip=round_trip,
|
183
|
+
warnings=warnings,
|
184
|
+
serialize_as_any=serialize_as_any,
|
185
|
+
)
|
186
|
+
|
187
|
+
|
188
|
+
# Base class with extra customization for GQL generated types.
|
189
|
+
# This is deliberately not a docstring to prevent inclusion in generated docs.
|
190
|
+
class GQLBase(Base):
|
191
|
+
model_config = ConfigDict(
|
192
|
+
extra="ignore",
|
193
|
+
protected_namespaces=(),
|
194
|
+
)
|
195
|
+
|
196
|
+
|
197
|
+
# ------------------------------------------------------------------------------
|
198
|
+
# Reusable annotations for field types
|
199
|
+
T = TypeVar("T")
|
200
|
+
|
201
|
+
GQLId = Annotated[
|
202
|
+
str,
|
203
|
+
Field(repr=False, strict=True, frozen=True),
|
204
|
+
]
|
205
|
+
|
206
|
+
Typename = Annotated[
|
207
|
+
T,
|
208
|
+
Field(repr=False, alias="__typename", frozen=True),
|
209
|
+
]
|
210
|
+
|
211
|
+
|
212
|
+
# FIXME: Restore or modify this after ensuring pydantic v1 compatibility.
|
213
|
+
# def validate_maybe_json(v: Any, handler: ValidatorFunctionWrapHandler) -> Any:
|
214
|
+
# """Wraps default Json[...] field validator to allow instantiation with an already-decoded value."""
|
215
|
+
# try:
|
216
|
+
# return handler(v)
|
217
|
+
# except ValidationError:
|
218
|
+
# # Try revalidating after properly jsonifying the value
|
219
|
+
# return handler(to_json(v, by_alias=True, round_trip=True))
|
220
|
+
|
221
|
+
|
222
|
+
# SerializedToJson = Annotated[
|
223
|
+
# Json[T],
|
224
|
+
# # Allow lenient instantiation/validation: incoming data may already be deserialized.
|
225
|
+
# WrapValidator(validate_maybe_json),
|
226
|
+
# ]
|
@@ -7,15 +7,15 @@ from typing import List, Optional
|
|
7
7
|
|
8
8
|
from pydantic import Field
|
9
9
|
|
10
|
-
from .
|
10
|
+
from wandb._pydantic import GQLBase
|
11
11
|
|
12
12
|
|
13
13
|
class ServerFeaturesQuery(GQLBase):
|
14
|
-
server_info: Optional[
|
14
|
+
server_info: Optional[ServerFeaturesQueryServerInfo] = Field(alias="serverInfo")
|
15
15
|
|
16
16
|
|
17
17
|
class ServerFeaturesQueryServerInfo(GQLBase):
|
18
|
-
features: List[Optional[
|
18
|
+
features: List[Optional[ServerFeaturesQueryServerInfoFeatures]]
|
19
19
|
|
20
20
|
|
21
21
|
class ServerFeaturesQueryServerInfoFeatures(GQLBase):
|
@@ -36,6 +36,7 @@ import requests
|
|
36
36
|
import yaml
|
37
37
|
from wandb_gql import Client, gql
|
38
38
|
from wandb_gql.client import RetryError
|
39
|
+
from wandb_graphql.language.ast import Document
|
39
40
|
|
40
41
|
import wandb
|
41
42
|
from wandb import env, util
|
@@ -43,7 +44,9 @@ from wandb.apis.normalize import normalize_exceptions, parse_backend_error_messa
|
|
43
44
|
from wandb.errors import AuthenticationError, CommError, UnsupportedError, UsageError
|
44
45
|
from wandb.integration.sagemaker import parse_sm_secrets
|
45
46
|
from wandb.old.settings import Settings
|
47
|
+
from wandb.proto.wandb_internal_pb2 import ServerFeature
|
46
48
|
from wandb.sdk.artifacts._validators import is_artifact_registry_project
|
49
|
+
from wandb.sdk.internal._generated import SERVER_FEATURES_QUERY_GQL, ServerFeaturesQuery
|
47
50
|
from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
|
48
51
|
from wandb.sdk.lib.gql_request import GraphQLSession
|
49
52
|
from wandb.sdk.lib.hashutil import B64MD5, md5_file_b64
|
@@ -365,6 +368,7 @@ class Api:
|
|
365
368
|
self.server_create_run_queue_supports_priority: Optional[bool] = None
|
366
369
|
self.server_supports_template_variables: Optional[bool] = None
|
367
370
|
self.server_push_to_run_queue_supports_priority: Optional[bool] = None
|
371
|
+
self._server_features_cache: Optional[dict[str, bool]] = None
|
368
372
|
|
369
373
|
def gql(self, *args: Any, **kwargs: Any) -> Any:
|
370
374
|
ret = self._retry_gql(
|
@@ -869,6 +873,52 @@ class Api:
|
|
869
873
|
_, _, mutations = self.server_info_introspection()
|
870
874
|
return "updateRunQueueItemWarning" in mutations
|
871
875
|
|
876
|
+
def _check_server_feature(self, feature_value: ServerFeature) -> bool:
|
877
|
+
"""Check if a server feature is enabled.
|
878
|
+
|
879
|
+
Args:
|
880
|
+
feature_value (ServerFeature): The enum value of the feature to check.
|
881
|
+
|
882
|
+
Returns:
|
883
|
+
bool: True if the feature is enabled, False otherwise.
|
884
|
+
|
885
|
+
Raises:
|
886
|
+
Exception: If server doesn't support feature queries or other errors occur
|
887
|
+
"""
|
888
|
+
if self._server_features_cache is None:
|
889
|
+
query = gql(SERVER_FEATURES_QUERY_GQL)
|
890
|
+
response = self.gql(query)
|
891
|
+
server_info = ServerFeaturesQuery.model_validate(response).server_info
|
892
|
+
if server_info and (features := server_info.features):
|
893
|
+
self._server_features_cache = {
|
894
|
+
f.name: f.is_enabled for f in features if f
|
895
|
+
}
|
896
|
+
else:
|
897
|
+
self._server_features_cache = {}
|
898
|
+
|
899
|
+
return self._server_features_cache.get(ServerFeature.Name(feature_value), False)
|
900
|
+
|
901
|
+
def _check_server_feature_with_fallback(self, feature_value: ServerFeature) -> bool:
|
902
|
+
"""Wrapper around check_server_feature that warns and returns False for older unsupported servers.
|
903
|
+
|
904
|
+
Good to use for features that have a fallback mechanism for older servers.
|
905
|
+
|
906
|
+
Args:
|
907
|
+
feature_value (ServerFeature): The enum value of the feature to check.
|
908
|
+
|
909
|
+
Returns:
|
910
|
+
bool: True if the feature is enabled, False otherwise.
|
911
|
+
|
912
|
+
Raises:
|
913
|
+
Exception: If an error other than the server not supporting feature queries occurs.
|
914
|
+
"""
|
915
|
+
try:
|
916
|
+
return self._check_server_feature(feature_value)
|
917
|
+
except Exception as e:
|
918
|
+
if 'Cannot query field "features" on type "ServerInfo".' in str(e):
|
919
|
+
return False
|
920
|
+
raise e
|
921
|
+
|
872
922
|
@normalize_exceptions
|
873
923
|
def update_run_queue_item_warning(
|
874
924
|
self,
|
@@ -3703,67 +3753,108 @@ class Api:
|
|
3703
3753
|
else:
|
3704
3754
|
raise ValueError(f"Unable to find an organization under entity {entity!r}.")
|
3705
3755
|
|
3706
|
-
def
|
3756
|
+
def _construct_use_artifact_query(
|
3707
3757
|
self,
|
3708
3758
|
artifact_id: str,
|
3709
3759
|
entity_name: Optional[str] = None,
|
3710
3760
|
project_name: Optional[str] = None,
|
3711
3761
|
run_name: Optional[str] = None,
|
3712
3762
|
use_as: Optional[str] = None,
|
3713
|
-
|
3714
|
-
|
3715
|
-
|
3716
|
-
|
3717
|
-
$
|
3718
|
-
$
|
3719
|
-
$
|
3720
|
-
|
3721
|
-
|
3722
|
-
|
3723
|
-
|
3724
|
-
|
3725
|
-
|
3726
|
-
|
3727
|
-
|
3728
|
-
}) {
|
3729
|
-
artifact {
|
3730
|
-
id
|
3731
|
-
digest
|
3732
|
-
description
|
3733
|
-
state
|
3734
|
-
createdAt
|
3735
|
-
metadata
|
3736
|
-
}
|
3737
|
-
}
|
3738
|
-
}
|
3739
|
-
"""
|
3763
|
+
artifact_entity_name: Optional[str] = None,
|
3764
|
+
artifact_project_name: Optional[str] = None,
|
3765
|
+
) -> Tuple[Document, Dict[str, Any]]:
|
3766
|
+
query_vars = [
|
3767
|
+
"$entityName: String!",
|
3768
|
+
"$projectName: String!",
|
3769
|
+
"$runName: String!",
|
3770
|
+
"$artifactID: ID!",
|
3771
|
+
]
|
3772
|
+
query_args = [
|
3773
|
+
"entityName: $entityName",
|
3774
|
+
"projectName: $projectName",
|
3775
|
+
"runName: $runName",
|
3776
|
+
"artifactID: $artifactID",
|
3777
|
+
]
|
3740
3778
|
|
3741
3779
|
artifact_types = self.server_use_artifact_input_introspection()
|
3742
|
-
if "usedAs" in artifact_types:
|
3743
|
-
|
3744
|
-
|
3745
|
-
).replace("_USED_AS_VALUE_", "usedAs: $usedAs")
|
3746
|
-
else:
|
3747
|
-
query_template = query_template.replace("_USED_AS_TYPE_", "").replace(
|
3748
|
-
"_USED_AS_VALUE_", ""
|
3749
|
-
)
|
3750
|
-
|
3751
|
-
query = gql(query_template)
|
3780
|
+
if "usedAs" in artifact_types and use_as:
|
3781
|
+
query_vars.append("$usedAs: String")
|
3782
|
+
query_args.append("usedAs: $usedAs")
|
3752
3783
|
|
3753
3784
|
entity_name = entity_name or self.settings("entity")
|
3754
3785
|
project_name = project_name or self.settings("project")
|
3755
3786
|
run_name = run_name or self.current_run_id
|
3756
3787
|
|
3757
|
-
|
3758
|
-
|
3759
|
-
|
3760
|
-
|
3761
|
-
|
3762
|
-
|
3763
|
-
|
3764
|
-
|
3765
|
-
|
3788
|
+
variable_values: Dict[str, Any] = {
|
3789
|
+
"entityName": entity_name,
|
3790
|
+
"projectName": project_name,
|
3791
|
+
"runName": run_name,
|
3792
|
+
"artifactID": artifact_id,
|
3793
|
+
"usedAs": use_as,
|
3794
|
+
}
|
3795
|
+
|
3796
|
+
server_allows_entity_project_information = (
|
3797
|
+
self._check_server_feature_with_fallback(
|
3798
|
+
ServerFeature.USE_ARTIFACT_WITH_ENTITY_AND_PROJECT_INFORMATION # type: ignore
|
3799
|
+
)
|
3800
|
+
)
|
3801
|
+
if server_allows_entity_project_information:
|
3802
|
+
query_vars.extend(
|
3803
|
+
[
|
3804
|
+
"$artifactEntityName: String",
|
3805
|
+
"$artifactProjectName: String",
|
3806
|
+
]
|
3807
|
+
)
|
3808
|
+
query_args.extend(
|
3809
|
+
[
|
3810
|
+
"artifactEntityName: $artifactEntityName",
|
3811
|
+
"artifactProjectName: $artifactProjectName",
|
3812
|
+
]
|
3813
|
+
)
|
3814
|
+
variable_values["artifactEntityName"] = artifact_entity_name
|
3815
|
+
variable_values["artifactProjectName"] = artifact_project_name
|
3816
|
+
|
3817
|
+
vars_str = ", ".join(query_vars)
|
3818
|
+
args_str = ", ".join(query_args)
|
3819
|
+
|
3820
|
+
query = gql(
|
3821
|
+
f"""
|
3822
|
+
mutation UseArtifact({vars_str}) {{
|
3823
|
+
useArtifact(input: {{{args_str}}}) {{
|
3824
|
+
artifact {{
|
3825
|
+
id
|
3826
|
+
digest
|
3827
|
+
description
|
3828
|
+
state
|
3829
|
+
createdAt
|
3830
|
+
metadata
|
3831
|
+
}}
|
3832
|
+
}}
|
3833
|
+
}}
|
3834
|
+
"""
|
3835
|
+
)
|
3836
|
+
return query, variable_values
|
3837
|
+
|
3838
|
+
def use_artifact(
|
3839
|
+
self,
|
3840
|
+
artifact_id: str,
|
3841
|
+
entity_name: Optional[str] = None,
|
3842
|
+
project_name: Optional[str] = None,
|
3843
|
+
run_name: Optional[str] = None,
|
3844
|
+
artifact_entity_name: Optional[str] = None,
|
3845
|
+
artifact_project_name: Optional[str] = None,
|
3846
|
+
use_as: Optional[str] = None,
|
3847
|
+
) -> Optional[Dict[str, Any]]:
|
3848
|
+
query, variable_values = self._construct_use_artifact_query(
|
3849
|
+
artifact_id,
|
3850
|
+
entity_name,
|
3851
|
+
project_name,
|
3852
|
+
run_name,
|
3853
|
+
use_as,
|
3854
|
+
artifact_entity_name,
|
3855
|
+
artifact_project_name,
|
3766
3856
|
)
|
3857
|
+
response = self.gql(query, variable_values)
|
3767
3858
|
|
3768
3859
|
if response["useArtifact"]["artifact"]:
|
3769
3860
|
artifact: Dict[str, Any] = response["useArtifact"]["artifact"]
|
wandb/sdk/internal/sender.py
CHANGED
@@ -1531,6 +1531,8 @@ class SendManager:
|
|
1531
1531
|
|
1532
1532
|
metadata = json.loads(artifact.metadata) if artifact.metadata else None
|
1533
1533
|
res = saver.save(
|
1534
|
+
entity=artifact.entity,
|
1535
|
+
project=artifact.project,
|
1534
1536
|
type=artifact.type,
|
1535
1537
|
name=artifact.name,
|
1536
1538
|
client_id=artifact.client_id,
|
@@ -47,17 +47,14 @@ class ConfigState:
|
|
47
47
|
# Add any top-level keys that aren't already set.
|
48
48
|
self._add_unset_keys_from_subtree(old_config_tree, [])
|
49
49
|
|
50
|
-
#
|
51
|
-
#
|
52
|
-
#
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
old_config_tree,
|
59
|
-
[_WANDB_INTERNAL_KEY, "viz"],
|
60
|
-
)
|
50
|
+
# When resuming a run, we want to ensure that some of the old config keys
|
51
|
+
# are maintained. So we have this logic here to add back
|
52
|
+
# any keys that were in the old config but not in the new config
|
53
|
+
for key in ["viz", "visualize", "mask/class_labels"]:
|
54
|
+
self._add_unset_keys_from_subtree(
|
55
|
+
old_config_tree,
|
56
|
+
[_WANDB_INTERNAL_KEY, key],
|
57
|
+
)
|
61
58
|
|
62
59
|
def _add_unset_keys_from_subtree(
|
63
60
|
self,
|