wandb 0.19.8__py3-none-any.whl → 0.19.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +5 -1
- wandb/__init__.pyi +15 -8
- wandb/_pydantic/__init__.py +30 -0
- wandb/_pydantic/base.py +148 -0
- wandb/_pydantic/utils.py +66 -0
- wandb/_pydantic/v1_compat.py +284 -0
- wandb/apis/paginator.py +82 -38
- wandb/apis/public/__init__.py +2 -2
- wandb/apis/public/api.py +111 -53
- wandb/apis/public/artifacts.py +387 -639
- wandb/apis/public/automations.py +69 -0
- wandb/apis/public/files.py +2 -2
- wandb/apis/public/integrations.py +168 -0
- wandb/apis/public/projects.py +32 -2
- wandb/apis/public/reports.py +2 -2
- wandb/apis/public/runs.py +19 -11
- wandb/apis/public/utils.py +107 -1
- wandb/automations/__init__.py +81 -0
- wandb/automations/_filters/__init__.py +40 -0
- wandb/automations/_filters/expressions.py +179 -0
- wandb/automations/_filters/operators.py +267 -0
- wandb/automations/_filters/run_metrics.py +183 -0
- wandb/automations/_generated/__init__.py +184 -0
- wandb/automations/_generated/create_filter_trigger.py +21 -0
- wandb/automations/_generated/create_generic_webhook_integration.py +43 -0
- wandb/automations/_generated/delete_trigger.py +19 -0
- wandb/automations/_generated/enums.py +33 -0
- wandb/automations/_generated/fragments.py +343 -0
- wandb/automations/_generated/generic_webhook_integrations_by_entity.py +22 -0
- wandb/automations/_generated/get_triggers.py +24 -0
- wandb/automations/_generated/get_triggers_by_entity.py +24 -0
- wandb/automations/_generated/input_types.py +104 -0
- wandb/automations/_generated/integrations_by_entity.py +22 -0
- wandb/automations/_generated/operations.py +710 -0
- wandb/automations/_generated/slack_integrations_by_entity.py +22 -0
- wandb/automations/_generated/update_filter_trigger.py +21 -0
- wandb/automations/_utils.py +123 -0
- wandb/automations/_validators.py +73 -0
- wandb/automations/actions.py +205 -0
- wandb/automations/automations.py +109 -0
- wandb/automations/events.py +235 -0
- wandb/automations/integrations.py +26 -0
- wandb/automations/scopes.py +76 -0
- wandb/beta/workflows.py +9 -10
- wandb/bin/gpu_stats +0 -0
- wandb/cli/cli.py +3 -3
- wandb/integration/keras/keras.py +2 -1
- wandb/integration/langchain/wandb_tracer.py +2 -1
- wandb/integration/metaflow/metaflow.py +19 -17
- wandb/integration/sacred/__init__.py +1 -1
- wandb/jupyter.py +155 -133
- wandb/old/summary.py +0 -2
- wandb/proto/v3/wandb_internal_pb2.py +297 -292
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +292 -292
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v5/wandb_internal_pb2.py +292 -292
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v6/wandb_base_pb2.py +41 -0
- wandb/proto/v6/wandb_internal_pb2.py +393 -0
- wandb/proto/v6/wandb_server_pb2.py +78 -0
- wandb/proto/v6/wandb_settings_pb2.py +58 -0
- wandb/proto/v6/wandb_telemetry_pb2.py +52 -0
- wandb/proto/wandb_base_pb2.py +2 -0
- wandb/proto/wandb_deprecated.py +10 -0
- wandb/proto/wandb_internal_pb2.py +3 -1
- wandb/proto/wandb_server_pb2.py +2 -0
- wandb/proto/wandb_settings_pb2.py +2 -0
- wandb/proto/wandb_telemetry_pb2.py +2 -0
- wandb/sdk/artifacts/_generated/__init__.py +248 -0
- wandb/sdk/artifacts/_generated/artifact_collection_membership_files.py +43 -0
- wandb/sdk/artifacts/_generated/artifact_version_files.py +36 -0
- wandb/sdk/artifacts/_generated/create_artifact_collection_tag_assignments.py +36 -0
- wandb/sdk/artifacts/_generated/delete_artifact_collection_tag_assignments.py +25 -0
- wandb/sdk/artifacts/_generated/delete_artifact_portfolio.py +35 -0
- wandb/sdk/artifacts/_generated/delete_artifact_sequence.py +35 -0
- wandb/sdk/artifacts/_generated/enums.py +17 -0
- wandb/sdk/artifacts/_generated/fragments.py +186 -0
- wandb/sdk/artifacts/_generated/input_types.py +16 -0
- wandb/sdk/artifacts/_generated/move_artifact_collection.py +35 -0
- wandb/sdk/artifacts/_generated/operations.py +510 -0
- wandb/sdk/artifacts/_generated/project_artifact_collection.py +101 -0
- wandb/sdk/artifacts/_generated/project_artifact_collections.py +33 -0
- wandb/sdk/artifacts/_generated/project_artifact_type.py +24 -0
- wandb/sdk/artifacts/_generated/project_artifact_types.py +24 -0
- wandb/sdk/artifacts/_generated/project_artifacts.py +42 -0
- wandb/sdk/artifacts/_generated/run_input_artifacts.py +51 -0
- wandb/sdk/artifacts/_generated/run_output_artifacts.py +51 -0
- wandb/sdk/artifacts/_generated/update_artifact_portfolio.py +35 -0
- wandb/sdk/artifacts/_generated/update_artifact_sequence.py +35 -0
- wandb/sdk/artifacts/_graphql_fragments.py +56 -81
- wandb/sdk/artifacts/_validators.py +1 -0
- wandb/sdk/artifacts/artifact.py +110 -49
- wandb/sdk/artifacts/artifact_manifest_entry.py +2 -1
- wandb/sdk/artifacts/artifact_saver.py +16 -2
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +23 -2
- wandb/sdk/data_types/audio.py +1 -3
- wandb/sdk/data_types/base_types/media.py +13 -7
- wandb/sdk/data_types/base_types/wb_value.py +34 -11
- wandb/sdk/data_types/html.py +36 -9
- wandb/sdk/data_types/image.py +56 -37
- wandb/sdk/data_types/molecule.py +1 -5
- wandb/sdk/data_types/object_3d.py +2 -1
- wandb/sdk/data_types/saved_model.py +7 -9
- wandb/sdk/data_types/table.py +5 -0
- wandb/sdk/data_types/trace_tree.py +2 -0
- wandb/sdk/data_types/utils.py +1 -1
- wandb/sdk/data_types/video.py +15 -30
- wandb/sdk/interface/interface.py +2 -0
- wandb/{apis/public → sdk/internal}/_generated/__init__.py +0 -6
- wandb/{apis/public → sdk/internal}/_generated/server_features_query.py +3 -3
- wandb/sdk/internal/internal_api.py +138 -47
- wandb/sdk/internal/profiler.py +6 -5
- wandb/sdk/internal/run.py +13 -6
- wandb/sdk/internal/sender.py +2 -0
- wandb/sdk/internal/sender_config.py +8 -11
- wandb/sdk/internal/settings_static.py +24 -2
- wandb/sdk/lib/apikey.py +40 -20
- wandb/sdk/lib/asyncio_compat.py +1 -1
- wandb/sdk/lib/deprecate.py +13 -22
- wandb/sdk/lib/disabled.py +2 -1
- wandb/sdk/lib/printer.py +37 -8
- wandb/sdk/lib/printer_asyncio.py +46 -0
- wandb/sdk/lib/redirect.py +10 -5
- wandb/sdk/lib/run_moment.py +4 -6
- wandb/sdk/lib/wb_logging.py +161 -0
- wandb/sdk/service/server_sock.py +19 -14
- wandb/sdk/service/service.py +9 -7
- wandb/sdk/service/streams.py +5 -0
- wandb/sdk/verify/verify.py +6 -3
- wandb/sdk/wandb_config.py +44 -43
- wandb/sdk/wandb_init.py +323 -141
- wandb/sdk/wandb_login.py +13 -4
- wandb/sdk/wandb_metadata.py +107 -91
- wandb/sdk/wandb_run.py +529 -325
- wandb/sdk/wandb_settings.py +422 -202
- wandb/sdk/wandb_setup.py +52 -1
- wandb/util.py +29 -29
- {wandb-0.19.8.dist-info → wandb-0.19.10.dist-info}/METADATA +7 -7
- {wandb-0.19.8.dist-info → wandb-0.19.10.dist-info}/RECORD +150 -93
- wandb/_globals.py +0 -19
- wandb/apis/public/_generated/base.py +0 -128
- wandb/apis/public/_generated/typing_compat.py +0 -14
- /wandb/{apis/public → sdk/internal}/_generated/enums.py +0 -0
- /wandb/{apis/public → sdk/internal}/_generated/input_types.py +0 -0
- /wandb/{apis/public → sdk/internal}/_generated/operations.py +0 -0
- {wandb-0.19.8.dist-info → wandb-0.19.10.dist-info}/WHEEL +0 -0
- {wandb-0.19.8.dist-info → wandb-0.19.10.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.8.dist-info → wandb-0.19.10.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/data_types/video.py
CHANGED
@@ -1,11 +1,12 @@
|
|
1
|
+
import functools
|
1
2
|
import logging
|
2
3
|
import os
|
3
4
|
from io import BytesIO
|
4
|
-
from typing import TYPE_CHECKING, Any,
|
5
|
+
from typing import TYPE_CHECKING, Any, Optional, Sequence, Type, Union
|
5
6
|
|
6
7
|
import wandb
|
7
8
|
from wandb import util
|
8
|
-
from wandb.sdk.lib import filesystem, runid
|
9
|
+
from wandb.sdk.lib import filesystem, printer, printer_asyncio, runid
|
9
10
|
|
10
11
|
from . import _dtypes
|
11
12
|
from ._private import MEDIA_TMP
|
@@ -90,13 +91,12 @@ class Video(BatchableMedia):
|
|
90
91
|
fps: Optional[int] = None,
|
91
92
|
format: Optional[str] = None,
|
92
93
|
):
|
93
|
-
super().__init__()
|
94
|
+
super().__init__(caption=caption)
|
94
95
|
|
95
96
|
self._format = format or "gif"
|
96
97
|
self._width = None
|
97
98
|
self._height = None
|
98
99
|
self._channels = None
|
99
|
-
self._caption = caption
|
100
100
|
if self._format not in Video.EXTS:
|
101
101
|
raise ValueError(
|
102
102
|
"wandb.Video accepts {} formats".format(", ".join(Video.EXTS))
|
@@ -135,7 +135,11 @@ class Video(BatchableMedia):
|
|
135
135
|
"wandb.Video accepts a file path or numpy like data as input"
|
136
136
|
)
|
137
137
|
fps = fps or 4
|
138
|
-
|
138
|
+
printer_asyncio.run_async_with_spinner(
|
139
|
+
printer.new_printer(),
|
140
|
+
"Encoding video...",
|
141
|
+
functools.partial(self.encode, fps=fps),
|
142
|
+
)
|
139
143
|
|
140
144
|
def encode(self, fps: int = 4) -> None:
|
141
145
|
# import ImageSequenceClip from the appropriate MoviePy module
|
@@ -153,29 +157,12 @@ class Video(BatchableMedia):
|
|
153
157
|
filename = os.path.join(
|
154
158
|
MEDIA_TMP.name, runid.generate_id() + "." + self._format
|
155
159
|
)
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
else:
|
163
|
-
clip.write_videofile(filename, **kwargs)
|
164
|
-
except TypeError:
|
165
|
-
try: # even older versions of moviepy do not support progress_bar argument
|
166
|
-
kwargs = {"verbose": False, "progress_bar": False}
|
167
|
-
if self._format == "gif":
|
168
|
-
clip.write_gif(filename, **kwargs)
|
169
|
-
else:
|
170
|
-
clip.write_videofile(filename, **kwargs)
|
171
|
-
except TypeError:
|
172
|
-
kwargs = {
|
173
|
-
"verbose": False,
|
174
|
-
}
|
175
|
-
if self._format == "gif":
|
176
|
-
clip.write_gif(filename, **kwargs)
|
177
|
-
else:
|
178
|
-
clip.write_videofile(filename, **kwargs)
|
160
|
+
|
161
|
+
if self._format == "gif":
|
162
|
+
write_gif_with_image_io(clip, filename)
|
163
|
+
else:
|
164
|
+
clip.write_videofile(filename, logger=None)
|
165
|
+
|
179
166
|
self._set_file(filename, is_tmp=True)
|
180
167
|
|
181
168
|
@classmethod
|
@@ -190,8 +177,6 @@ class Video(BatchableMedia):
|
|
190
177
|
json_dict["width"] = self._width
|
191
178
|
if self._height is not None:
|
192
179
|
json_dict["height"] = self._height
|
193
|
-
if self._caption:
|
194
|
-
json_dict["caption"] = self._caption
|
195
180
|
|
196
181
|
return json_dict
|
197
182
|
|
wandb/sdk/interface/interface.py
CHANGED
@@ -173,6 +173,8 @@ class InterfaceBase:
|
|
173
173
|
self._make_config(data=config_dict, obj=proto_run.config)
|
174
174
|
if run._telemetry_obj:
|
175
175
|
proto_run.telemetry.MergeFrom(run._telemetry_obj)
|
176
|
+
if run._start_runtime:
|
177
|
+
proto_run.runtime = run._start_runtime
|
176
178
|
return proto_run
|
177
179
|
|
178
180
|
def publish_run(self, run: "Run") -> None:
|
@@ -1,6 +1,5 @@
|
|
1
1
|
# Generated by ariadne-codegen
|
2
2
|
|
3
|
-
from .base import Base, GQLBase, GQLId, SerializedToJson, Typename
|
4
3
|
from .operations import SERVER_FEATURES_QUERY_GQL
|
5
4
|
from .server_features_query import (
|
6
5
|
ServerFeaturesQuery,
|
@@ -9,11 +8,6 @@ from .server_features_query import (
|
|
9
8
|
)
|
10
9
|
|
11
10
|
__all__ = [
|
12
|
-
"Base",
|
13
|
-
"GQLBase",
|
14
|
-
"GQLId",
|
15
|
-
"SerializedToJson",
|
16
|
-
"Typename",
|
17
11
|
"SERVER_FEATURES_QUERY_GQL",
|
18
12
|
"ServerFeaturesQuery",
|
19
13
|
"ServerFeaturesQueryServerInfo",
|
@@ -7,15 +7,15 @@ from typing import List, Optional
|
|
7
7
|
|
8
8
|
from pydantic import Field
|
9
9
|
|
10
|
-
from .
|
10
|
+
from wandb._pydantic import GQLBase
|
11
11
|
|
12
12
|
|
13
13
|
class ServerFeaturesQuery(GQLBase):
|
14
|
-
server_info: Optional[
|
14
|
+
server_info: Optional[ServerFeaturesQueryServerInfo] = Field(alias="serverInfo")
|
15
15
|
|
16
16
|
|
17
17
|
class ServerFeaturesQueryServerInfo(GQLBase):
|
18
|
-
features: List[Optional[
|
18
|
+
features: List[Optional[ServerFeaturesQueryServerInfoFeatures]]
|
19
19
|
|
20
20
|
|
21
21
|
class ServerFeaturesQueryServerInfoFeatures(GQLBase):
|
@@ -36,6 +36,7 @@ import requests
|
|
36
36
|
import yaml
|
37
37
|
from wandb_gql import Client, gql
|
38
38
|
from wandb_gql.client import RetryError
|
39
|
+
from wandb_graphql.language.ast import Document
|
39
40
|
|
40
41
|
import wandb
|
41
42
|
from wandb import env, util
|
@@ -43,7 +44,9 @@ from wandb.apis.normalize import normalize_exceptions, parse_backend_error_messa
|
|
43
44
|
from wandb.errors import AuthenticationError, CommError, UnsupportedError, UsageError
|
44
45
|
from wandb.integration.sagemaker import parse_sm_secrets
|
45
46
|
from wandb.old.settings import Settings
|
47
|
+
from wandb.proto.wandb_internal_pb2 import ServerFeature
|
46
48
|
from wandb.sdk.artifacts._validators import is_artifact_registry_project
|
49
|
+
from wandb.sdk.internal._generated import SERVER_FEATURES_QUERY_GQL, ServerFeaturesQuery
|
47
50
|
from wandb.sdk.internal.thread_local_settings import _thread_local_api_settings
|
48
51
|
from wandb.sdk.lib.gql_request import GraphQLSession
|
49
52
|
from wandb.sdk.lib.hashutil import B64MD5, md5_file_b64
|
@@ -365,6 +368,7 @@ class Api:
|
|
365
368
|
self.server_create_run_queue_supports_priority: Optional[bool] = None
|
366
369
|
self.server_supports_template_variables: Optional[bool] = None
|
367
370
|
self.server_push_to_run_queue_supports_priority: Optional[bool] = None
|
371
|
+
self._server_features_cache: Optional[dict[str, bool]] = None
|
368
372
|
|
369
373
|
def gql(self, *args: Any, **kwargs: Any) -> Any:
|
370
374
|
ret = self._retry_gql(
|
@@ -869,6 +873,52 @@ class Api:
|
|
869
873
|
_, _, mutations = self.server_info_introspection()
|
870
874
|
return "updateRunQueueItemWarning" in mutations
|
871
875
|
|
876
|
+
def _check_server_feature(self, feature_value: ServerFeature) -> bool:
|
877
|
+
"""Check if a server feature is enabled.
|
878
|
+
|
879
|
+
Args:
|
880
|
+
feature_value (ServerFeature): The enum value of the feature to check.
|
881
|
+
|
882
|
+
Returns:
|
883
|
+
bool: True if the feature is enabled, False otherwise.
|
884
|
+
|
885
|
+
Raises:
|
886
|
+
Exception: If server doesn't support feature queries or other errors occur
|
887
|
+
"""
|
888
|
+
if self._server_features_cache is None:
|
889
|
+
query = gql(SERVER_FEATURES_QUERY_GQL)
|
890
|
+
response = self.gql(query)
|
891
|
+
server_info = ServerFeaturesQuery.model_validate(response).server_info
|
892
|
+
if server_info and (features := server_info.features):
|
893
|
+
self._server_features_cache = {
|
894
|
+
f.name: f.is_enabled for f in features if f
|
895
|
+
}
|
896
|
+
else:
|
897
|
+
self._server_features_cache = {}
|
898
|
+
|
899
|
+
return self._server_features_cache.get(ServerFeature.Name(feature_value), False)
|
900
|
+
|
901
|
+
def _check_server_feature_with_fallback(self, feature_value: ServerFeature) -> bool:
|
902
|
+
"""Wrapper around check_server_feature that warns and returns False for older unsupported servers.
|
903
|
+
|
904
|
+
Good to use for features that have a fallback mechanism for older servers.
|
905
|
+
|
906
|
+
Args:
|
907
|
+
feature_value (ServerFeature): The enum value of the feature to check.
|
908
|
+
|
909
|
+
Returns:
|
910
|
+
bool: True if the feature is enabled, False otherwise.
|
911
|
+
|
912
|
+
Exceptions:
|
913
|
+
Exception: If an error other than the server not supporting feature queries occurs.
|
914
|
+
"""
|
915
|
+
try:
|
916
|
+
return self._check_server_feature(feature_value)
|
917
|
+
except Exception as e:
|
918
|
+
if 'Cannot query field "features" on type "ServerInfo".' in str(e):
|
919
|
+
return False
|
920
|
+
raise e
|
921
|
+
|
872
922
|
@normalize_exceptions
|
873
923
|
def update_run_queue_item_warning(
|
874
924
|
self,
|
@@ -3703,67 +3753,108 @@ class Api:
|
|
3703
3753
|
else:
|
3704
3754
|
raise ValueError(f"Unable to find an organization under entity {entity!r}.")
|
3705
3755
|
|
3706
|
-
def
|
3756
|
+
def _construct_use_artifact_query(
|
3707
3757
|
self,
|
3708
3758
|
artifact_id: str,
|
3709
3759
|
entity_name: Optional[str] = None,
|
3710
3760
|
project_name: Optional[str] = None,
|
3711
3761
|
run_name: Optional[str] = None,
|
3712
3762
|
use_as: Optional[str] = None,
|
3713
|
-
|
3714
|
-
|
3715
|
-
|
3716
|
-
|
3717
|
-
$
|
3718
|
-
$
|
3719
|
-
$
|
3720
|
-
|
3721
|
-
|
3722
|
-
|
3723
|
-
|
3724
|
-
|
3725
|
-
|
3726
|
-
|
3727
|
-
|
3728
|
-
}) {
|
3729
|
-
artifact {
|
3730
|
-
id
|
3731
|
-
digest
|
3732
|
-
description
|
3733
|
-
state
|
3734
|
-
createdAt
|
3735
|
-
metadata
|
3736
|
-
}
|
3737
|
-
}
|
3738
|
-
}
|
3739
|
-
"""
|
3763
|
+
artifact_entity_name: Optional[str] = None,
|
3764
|
+
artifact_project_name: Optional[str] = None,
|
3765
|
+
) -> Tuple[Document, Dict[str, Any]]:
|
3766
|
+
query_vars = [
|
3767
|
+
"$entityName: String!",
|
3768
|
+
"$projectName: String!",
|
3769
|
+
"$runName: String!",
|
3770
|
+
"$artifactID: ID!",
|
3771
|
+
]
|
3772
|
+
query_args = [
|
3773
|
+
"entityName: $entityName",
|
3774
|
+
"projectName: $projectName",
|
3775
|
+
"runName: $runName",
|
3776
|
+
"artifactID: $artifactID",
|
3777
|
+
]
|
3740
3778
|
|
3741
3779
|
artifact_types = self.server_use_artifact_input_introspection()
|
3742
|
-
if "usedAs" in artifact_types:
|
3743
|
-
|
3744
|
-
|
3745
|
-
).replace("_USED_AS_VALUE_", "usedAs: $usedAs")
|
3746
|
-
else:
|
3747
|
-
query_template = query_template.replace("_USED_AS_TYPE_", "").replace(
|
3748
|
-
"_USED_AS_VALUE_", ""
|
3749
|
-
)
|
3750
|
-
|
3751
|
-
query = gql(query_template)
|
3780
|
+
if "usedAs" in artifact_types and use_as:
|
3781
|
+
query_vars.append("$usedAs: String")
|
3782
|
+
query_args.append("usedAs: $usedAs")
|
3752
3783
|
|
3753
3784
|
entity_name = entity_name or self.settings("entity")
|
3754
3785
|
project_name = project_name or self.settings("project")
|
3755
3786
|
run_name = run_name or self.current_run_id
|
3756
3787
|
|
3757
|
-
|
3758
|
-
|
3759
|
-
|
3760
|
-
|
3761
|
-
|
3762
|
-
|
3763
|
-
|
3764
|
-
|
3765
|
-
|
3788
|
+
variable_values: Dict[str, Any] = {
|
3789
|
+
"entityName": entity_name,
|
3790
|
+
"projectName": project_name,
|
3791
|
+
"runName": run_name,
|
3792
|
+
"artifactID": artifact_id,
|
3793
|
+
"usedAs": use_as,
|
3794
|
+
}
|
3795
|
+
|
3796
|
+
server_allows_entity_project_information = (
|
3797
|
+
self._check_server_feature_with_fallback(
|
3798
|
+
ServerFeature.USE_ARTIFACT_WITH_ENTITY_AND_PROJECT_INFORMATION # type: ignore
|
3799
|
+
)
|
3800
|
+
)
|
3801
|
+
if server_allows_entity_project_information:
|
3802
|
+
query_vars.extend(
|
3803
|
+
[
|
3804
|
+
"$artifactEntityName: String",
|
3805
|
+
"$artifactProjectName: String",
|
3806
|
+
]
|
3807
|
+
)
|
3808
|
+
query_args.extend(
|
3809
|
+
[
|
3810
|
+
"artifactEntityName: $artifactEntityName",
|
3811
|
+
"artifactProjectName: $artifactProjectName",
|
3812
|
+
]
|
3813
|
+
)
|
3814
|
+
variable_values["artifactEntityName"] = artifact_entity_name
|
3815
|
+
variable_values["artifactProjectName"] = artifact_project_name
|
3816
|
+
|
3817
|
+
vars_str = ", ".join(query_vars)
|
3818
|
+
args_str = ", ".join(query_args)
|
3819
|
+
|
3820
|
+
query = gql(
|
3821
|
+
f"""
|
3822
|
+
mutation UseArtifact({vars_str}) {{
|
3823
|
+
useArtifact(input: {{{args_str}}}) {{
|
3824
|
+
artifact {{
|
3825
|
+
id
|
3826
|
+
digest
|
3827
|
+
description
|
3828
|
+
state
|
3829
|
+
createdAt
|
3830
|
+
metadata
|
3831
|
+
}}
|
3832
|
+
}}
|
3833
|
+
}}
|
3834
|
+
"""
|
3835
|
+
)
|
3836
|
+
return query, variable_values
|
3837
|
+
|
3838
|
+
def use_artifact(
|
3839
|
+
self,
|
3840
|
+
artifact_id: str,
|
3841
|
+
entity_name: Optional[str] = None,
|
3842
|
+
project_name: Optional[str] = None,
|
3843
|
+
run_name: Optional[str] = None,
|
3844
|
+
artifact_entity_name: Optional[str] = None,
|
3845
|
+
artifact_project_name: Optional[str] = None,
|
3846
|
+
use_as: Optional[str] = None,
|
3847
|
+
) -> Optional[Dict[str, Any]]:
|
3848
|
+
query, variable_values = self._construct_use_artifact_query(
|
3849
|
+
artifact_id,
|
3850
|
+
entity_name,
|
3851
|
+
project_name,
|
3852
|
+
run_name,
|
3853
|
+
use_as,
|
3854
|
+
artifact_entity_name,
|
3855
|
+
artifact_project_name,
|
3766
3856
|
)
|
3857
|
+
response = self.gql(query, variable_values)
|
3767
3858
|
|
3768
3859
|
if response["useArtifact"]["artifact"]:
|
3769
3860
|
artifact: Dict[str, Any] = response["useArtifact"]["artifact"]
|
wandb/sdk/internal/profiler.py
CHANGED
@@ -18,12 +18,13 @@ def torch_trace_handler():
|
|
18
18
|
torch.profiler.profile(..., on_trace_ready=wandb.profiler.torch_trace_handler())
|
19
19
|
```
|
20
20
|
|
21
|
-
Calling this function ensures that profiler charts & tables can be viewed in
|
22
|
-
on wandb.ai.
|
21
|
+
Calling this function ensures that profiler charts & tables can be viewed in
|
22
|
+
your run dashboard on wandb.ai.
|
23
23
|
|
24
|
-
Please note that `wandb.init()` must be called before this function is
|
25
|
-
|
26
|
-
|
24
|
+
Please note that `wandb.init()` must be called before this function is
|
25
|
+
invoked, and the reinit setting must not be set to "create_new". The PyTorch
|
26
|
+
(torch) version must also be at least 1.9, in order to ensure stability of
|
27
|
+
their Profiler API.
|
27
28
|
|
28
29
|
Args:
|
29
30
|
None
|
wandb/sdk/internal/run.py
CHANGED
@@ -5,21 +5,28 @@ Semi-stubbed run for internal process use.
|
|
5
5
|
|
6
6
|
"""
|
7
7
|
|
8
|
-
|
8
|
+
import sys
|
9
9
|
|
10
|
-
|
10
|
+
if sys.version_info >= (3, 12):
|
11
|
+
from typing import override
|
12
|
+
else:
|
13
|
+
from typing_extensions import override
|
14
|
+
|
15
|
+
from wandb.sdk import wandb_run
|
11
16
|
|
12
17
|
|
13
18
|
class InternalRun(wandb_run.Run):
|
14
19
|
def __init__(self, run_obj, settings, datatypes_cb):
|
15
20
|
super().__init__(settings=settings)
|
16
21
|
self._run_obj = run_obj
|
22
|
+
self._datatypes_cb = datatypes_cb
|
17
23
|
|
18
|
-
|
19
|
-
# We really want a common interface for wandb_run.Run and InternalRun.
|
20
|
-
_datatypes_set_callback(datatypes_cb)
|
21
|
-
|
24
|
+
@override
|
22
25
|
def _set_backend(self, backend):
|
23
26
|
# This type of run object can't have a backend
|
24
27
|
# or do any writes.
|
25
28
|
pass
|
29
|
+
|
30
|
+
@override
|
31
|
+
def _publish_file(self, fname: str) -> None:
|
32
|
+
self._datatypes_cb(fname)
|
wandb/sdk/internal/sender.py
CHANGED
@@ -1531,6 +1531,8 @@ class SendManager:
|
|
1531
1531
|
|
1532
1532
|
metadata = json.loads(artifact.metadata) if artifact.metadata else None
|
1533
1533
|
res = saver.save(
|
1534
|
+
entity=artifact.entity,
|
1535
|
+
project=artifact.project,
|
1534
1536
|
type=artifact.type,
|
1535
1537
|
name=artifact.name,
|
1536
1538
|
client_id=artifact.client_id,
|
@@ -47,17 +47,14 @@ class ConfigState:
|
|
47
47
|
# Add any top-level keys that aren't already set.
|
48
48
|
self._add_unset_keys_from_subtree(old_config_tree, [])
|
49
49
|
|
50
|
-
#
|
51
|
-
#
|
52
|
-
#
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
old_config_tree,
|
59
|
-
[_WANDB_INTERNAL_KEY, "viz"],
|
60
|
-
)
|
50
|
+
# When resuming a run, we want to ensure the some of the old configs keys
|
51
|
+
# are maintained. So we have this logic here to add back
|
52
|
+
# any keys that were in the old config but not in the new config
|
53
|
+
for key in ["viz", "visualize", "mask/class_labels"]:
|
54
|
+
self._add_unset_keys_from_subtree(
|
55
|
+
old_config_tree,
|
56
|
+
[_WANDB_INTERNAL_KEY, key],
|
57
|
+
)
|
61
58
|
|
62
59
|
def _add_unset_keys_from_subtree(
|
63
60
|
self,
|
@@ -21,11 +21,33 @@ class SettingsStatic(Settings):
|
|
21
21
|
def _proto_to_dict(self, proto: wandb_settings_pb2.Settings) -> dict:
|
22
22
|
data = {}
|
23
23
|
|
24
|
+
exclude_fields = {
|
25
|
+
"model_config",
|
26
|
+
"model_fields",
|
27
|
+
"model_fields_set",
|
28
|
+
"__fields__",
|
29
|
+
"__model_fields_set",
|
30
|
+
"__pydantic_self__",
|
31
|
+
"__pydantic_initialised__",
|
32
|
+
}
|
33
|
+
|
34
|
+
fields = (
|
35
|
+
Settings.model_fields
|
36
|
+
if hasattr(Settings, "model_fields")
|
37
|
+
else Settings.__fields__
|
38
|
+
) # type: ignore [attr-defined]
|
39
|
+
|
40
|
+
fields = {k: v for k, v in fields.items() if k not in exclude_fields} # type: ignore [union-attr]
|
41
|
+
|
24
42
|
forks_specified: list[str] = []
|
25
|
-
for key in
|
43
|
+
for key in fields:
|
44
|
+
# Skip Python-only keys that do not exist on the proto.
|
45
|
+
if key in ("reinit",):
|
46
|
+
continue
|
47
|
+
|
26
48
|
value: Any = None
|
27
49
|
|
28
|
-
field_info =
|
50
|
+
field_info = fields[key]
|
29
51
|
annotation = str(field_info.annotation)
|
30
52
|
|
31
53
|
if key == "_stats_open_metrics_filters":
|
wandb/sdk/lib/apikey.py
CHANGED
@@ -11,7 +11,7 @@ import textwrap
|
|
11
11
|
from functools import partial
|
12
12
|
|
13
13
|
# import Literal
|
14
|
-
from typing import TYPE_CHECKING, Callable,
|
14
|
+
from typing import TYPE_CHECKING, Callable, Literal
|
15
15
|
from urllib.parse import urlparse
|
16
16
|
|
17
17
|
import click
|
@@ -23,6 +23,9 @@ from wandb.errors import term
|
|
23
23
|
from wandb.errors.links import url_registry
|
24
24
|
from wandb.util import _is_databricks, isatty, prompt_choices
|
25
25
|
|
26
|
+
if TYPE_CHECKING:
|
27
|
+
from wandb.sdk.wandb_settings import Settings
|
28
|
+
|
26
29
|
LOGIN_CHOICE_ANON = "Private W&B dashboard, no account required"
|
27
30
|
LOGIN_CHOICE_NEW = "Create a W&B account"
|
28
31
|
LOGIN_CHOICE_EXISTS = "Use an existing W&B account"
|
@@ -50,18 +53,14 @@ class WriteNetrcError(Exception):
|
|
50
53
|
Mode = Literal["allow", "must", "never", "false", "true"]
|
51
54
|
|
52
55
|
|
53
|
-
if TYPE_CHECKING:
|
54
|
-
from wandb.sdk.wandb_settings import Settings
|
55
|
-
|
56
|
-
|
57
56
|
getpass = partial(click.prompt, hide_input=True, err=True)
|
58
57
|
|
59
58
|
|
60
|
-
def _fixup_anon_mode(default:
|
59
|
+
def _fixup_anon_mode(default: Mode | None) -> Mode | None:
|
61
60
|
# Convert weird anonymode values from legacy settings files
|
62
61
|
# into one of our expected values.
|
63
62
|
anon_mode = default or "never"
|
64
|
-
mapping:
|
63
|
+
mapping: dict[Mode, Mode] = {"true": "allow", "false": "never"}
|
65
64
|
return mapping.get(anon_mode, anon_mode)
|
66
65
|
|
67
66
|
|
@@ -84,15 +83,35 @@ def get_netrc_file_path() -> str:
|
|
84
83
|
return os.path.join(os.path.expanduser("~"), netrc_file)
|
85
84
|
|
86
85
|
|
86
|
+
def _api_key_prompt_str(app_url: str, referrer: str | None = None) -> str:
|
87
|
+
"""Generate a prompt string for API key authorization.
|
88
|
+
|
89
|
+
Creates a URL string that directs users to the authorization page where they
|
90
|
+
can find their API key.
|
91
|
+
|
92
|
+
Args:
|
93
|
+
app_url: The base URL of the W&B application.
|
94
|
+
referrer: Optional referrer parameter to include in the URL.
|
95
|
+
|
96
|
+
Returns:
|
97
|
+
A formatted string with instructions and the authorization URL.
|
98
|
+
"""
|
99
|
+
ref = ""
|
100
|
+
if referrer:
|
101
|
+
ref = f"?ref={referrer}"
|
102
|
+
return f"You can find your API key in your browser here: {app_url}/authorize{ref}"
|
103
|
+
|
104
|
+
|
87
105
|
def prompt_api_key( # noqa: C901
|
88
|
-
settings:
|
89
|
-
api:
|
90
|
-
input_callback:
|
91
|
-
browser_callback:
|
106
|
+
settings: Settings,
|
107
|
+
api: InternalApi | None = None,
|
108
|
+
input_callback: Callable | None = None,
|
109
|
+
browser_callback: Callable | None = None,
|
92
110
|
no_offline: bool = False,
|
93
111
|
no_create: bool = False,
|
94
112
|
local: bool = False,
|
95
|
-
|
113
|
+
referrer: str | None = None,
|
114
|
+
) -> str | bool | None:
|
96
115
|
"""Prompt for api key.
|
97
116
|
|
98
117
|
Returns:
|
@@ -150,7 +169,10 @@ def prompt_api_key( # noqa: C901
|
|
150
169
|
key = browser_callback(signup=True) if browser_callback else None
|
151
170
|
|
152
171
|
if not key:
|
153
|
-
|
172
|
+
ref = f"&ref={referrer}" if referrer else ""
|
173
|
+
wandb.termlog(
|
174
|
+
f"Create an account here: {app_url}/authorize?signup=true{ref}"
|
175
|
+
)
|
154
176
|
key = input_callback(api_ask).strip()
|
155
177
|
elif result == LOGIN_CHOICE_EXISTS:
|
156
178
|
key = browser_callback() if browser_callback else None
|
@@ -165,9 +187,7 @@ def prompt_api_key( # noqa: C901
|
|
165
187
|
f"Logging into {host}. (Learn how to deploy a W&B server "
|
166
188
|
f"locally: {url_registry.url('wandb-server')})"
|
167
189
|
)
|
168
|
-
wandb.termlog(
|
169
|
-
f"You can find your API key in your browser here: {app_url}/authorize"
|
170
|
-
)
|
190
|
+
wandb.termlog(_api_key_prompt_str(app_url, referrer))
|
171
191
|
key = input_callback(api_ask).strip()
|
172
192
|
elif result == LOGIN_CHOICE_NOTTY:
|
173
193
|
# TODO: Needs refactor as this needs to be handled by caller
|
@@ -276,9 +296,9 @@ def write_netrc(host: str, entity: str, key: str):
|
|
276
296
|
|
277
297
|
|
278
298
|
def write_key(
|
279
|
-
settings:
|
280
|
-
key:
|
281
|
-
api:
|
299
|
+
settings: Settings,
|
300
|
+
key: str | None,
|
301
|
+
api: InternalApi | None = None,
|
282
302
|
) -> None:
|
283
303
|
if not key:
|
284
304
|
raise ValueError("No API key specified.")
|
@@ -298,7 +318,7 @@ def write_key(
|
|
298
318
|
write_netrc(settings.base_url, "user", key)
|
299
319
|
|
300
320
|
|
301
|
-
def api_key(settings:
|
321
|
+
def api_key(settings: Settings | None = None) -> str | None:
|
302
322
|
if settings is None:
|
303
323
|
settings = wandb.setup().settings
|
304
324
|
if settings.api_key:
|