wandb 0.19.8__py3-none-macosx_11_0_arm64.whl → 0.19.9__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +5 -1
- wandb/__init__.pyi +12 -8
- wandb/_pydantic/__init__.py +23 -0
- wandb/_pydantic/base.py +113 -0
- wandb/_pydantic/v1_compat.py +262 -0
- wandb/apis/paginator.py +82 -38
- wandb/apis/public/api.py +10 -64
- wandb/apis/public/artifacts.py +73 -17
- wandb/apis/public/files.py +2 -2
- wandb/apis/public/projects.py +3 -2
- wandb/apis/public/reports.py +2 -2
- wandb/apis/public/runs.py +19 -11
- wandb/bin/gpu_stats +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/integration/metaflow/metaflow.py +19 -17
- wandb/integration/sacred/__init__.py +1 -1
- wandb/jupyter.py +18 -15
- wandb/proto/v3/wandb_internal_pb2.py +7 -3
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +4 -4
- wandb/proto/v4/wandb_internal_pb2.py +3 -3
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +4 -4
- wandb/proto/v5/wandb_internal_pb2.py +3 -3
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +4 -4
- wandb/proto/wandb_deprecated.py +2 -0
- wandb/sdk/artifacts/_graphql_fragments.py +18 -20
- wandb/sdk/artifacts/_validators.py +1 -0
- wandb/sdk/artifacts/artifact.py +70 -36
- wandb/sdk/artifacts/artifact_saver.py +16 -2
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +23 -2
- wandb/sdk/data_types/audio.py +1 -3
- wandb/sdk/data_types/base_types/media.py +11 -4
- wandb/sdk/data_types/image.py +44 -25
- wandb/sdk/data_types/molecule.py +1 -5
- wandb/sdk/data_types/object_3d.py +2 -1
- wandb/sdk/data_types/saved_model.py +7 -9
- wandb/sdk/data_types/video.py +1 -4
- wandb/{apis/public → sdk/internal}/_generated/__init__.py +0 -6
- wandb/sdk/internal/_generated/base.py +226 -0
- wandb/{apis/public → sdk/internal}/_generated/server_features_query.py +3 -3
- wandb/{apis/public → sdk/internal}/_generated/typing_compat.py +1 -1
- wandb/sdk/internal/internal_api.py +138 -47
- wandb/sdk/internal/sender.py +2 -0
- wandb/sdk/internal/sender_config.py +8 -11
- wandb/sdk/internal/settings_static.py +24 -2
- wandb/sdk/lib/apikey.py +15 -16
- wandb/sdk/lib/run_moment.py +4 -6
- wandb/sdk/lib/wb_logging.py +161 -0
- wandb/sdk/wandb_config.py +44 -43
- wandb/sdk/wandb_init.py +141 -79
- wandb/sdk/wandb_metadata.py +107 -91
- wandb/sdk/wandb_run.py +152 -44
- wandb/sdk/wandb_settings.py +403 -201
- wandb/sdk/wandb_setup.py +3 -1
- {wandb-0.19.8.dist-info → wandb-0.19.9.dist-info}/METADATA +3 -3
- {wandb-0.19.8.dist-info → wandb-0.19.9.dist-info}/RECORD +64 -60
- wandb/apis/public/_generated/base.py +0 -128
- /wandb/{apis/public → sdk/internal}/_generated/enums.py +0 -0
- /wandb/{apis/public → sdk/internal}/_generated/input_types.py +0 -0
- /wandb/{apis/public → sdk/internal}/_generated/operations.py +0 -0
- {wandb-0.19.8.dist-info → wandb-0.19.9.dist-info}/WHEEL +0 -0
- {wandb-0.19.8.dist-info → wandb-0.19.9.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.8.dist-info → wandb-0.19.9.dist-info}/licenses/LICENSE +0 -0
wandb/apis/public/api.py
CHANGED
@@ -25,7 +25,6 @@ import wandb
|
|
25
25
|
from wandb import env, util
|
26
26
|
from wandb.apis import public
|
27
27
|
from wandb.apis.normalize import normalize_exceptions
|
28
|
-
from wandb.apis.public._generated import SERVER_FEATURES_QUERY_GQL, ServerFeaturesQuery
|
29
28
|
from wandb.apis.public.const import RETRY_TIMEDELTA
|
30
29
|
from wandb.apis.public.registries import Registries
|
31
30
|
from wandb.apis.public.utils import (
|
@@ -289,7 +288,7 @@ class Api:
|
|
289
288
|
)
|
290
289
|
)
|
291
290
|
self._client = RetryingClient(self._base_client)
|
292
|
-
self._server_features_cache = None
|
291
|
+
self._server_features_cache: Optional[dict[str, bool]] = None
|
293
292
|
|
294
293
|
def create_project(self, name: str, entity: str) -> None:
|
295
294
|
"""Create a new project.
|
@@ -757,15 +756,14 @@ class Api:
|
|
757
756
|
return parts
|
758
757
|
|
759
758
|
def projects(
|
760
|
-
self, entity: Optional[str] = None, per_page:
|
759
|
+
self, entity: Optional[str] = None, per_page: int = 200
|
761
760
|
) -> "public.Projects":
|
762
761
|
"""Get projects for a given entity.
|
763
762
|
|
764
763
|
Args:
|
765
764
|
entity: (str) Name of the entity requested. If None, will fall back to the
|
766
765
|
default entity passed to `Api`. If no default entity, will raise a `ValueError`.
|
767
|
-
per_page: (int) Sets the page size for query pagination.
|
768
|
-
Usually there is no reason to change this.
|
766
|
+
per_page: (int) Sets the page size for query pagination. Usually there is no reason to change this.
|
769
767
|
|
770
768
|
Returns:
|
771
769
|
A `Projects` object which is an iterable collection of `Project` objects.
|
@@ -808,7 +806,7 @@ class Api:
|
|
808
806
|
return public.Project(self.client, entity, name, {})
|
809
807
|
|
810
808
|
def reports(
|
811
|
-
self, path: str = "", name: Optional[str] = None, per_page:
|
809
|
+
self, path: str = "", name: Optional[str] = None, per_page: int = 50
|
812
810
|
) -> "public.Reports":
|
813
811
|
"""Get reports for a given project path.
|
814
812
|
|
@@ -817,8 +815,7 @@ class Api:
|
|
817
815
|
Args:
|
818
816
|
path: (str) path to project the report resides in, should be in the form: "entity/project"
|
819
817
|
name: (str, optional) optional name of the report requested.
|
820
|
-
per_page: (int) Sets the page size for query pagination.
|
821
|
-
Usually there is no reason to change this.
|
818
|
+
per_page: (int) Sets the page size for query pagination. Usually there is no reason to change this.
|
822
819
|
|
823
820
|
Returns:
|
824
821
|
A `Reports` object which is an iterable collection of `BetaReport` objects.
|
@@ -1153,15 +1150,14 @@ class Api:
|
|
1153
1150
|
|
1154
1151
|
@normalize_exceptions
|
1155
1152
|
def artifact_collections(
|
1156
|
-
self, project_name: str, type_name: str, per_page:
|
1153
|
+
self, project_name: str, type_name: str, per_page: int = 50
|
1157
1154
|
) -> "public.ArtifactCollections":
|
1158
1155
|
"""Return a collection of matching artifact collections.
|
1159
1156
|
|
1160
1157
|
Args:
|
1161
1158
|
project_name: (str) The name of the project to filter on.
|
1162
1159
|
type_name: (str) The name of the artifact type to filter on.
|
1163
|
-
per_page: (int
|
1164
|
-
Usually there is no reason to change this.
|
1160
|
+
per_page: (int) Sets the page size for query pagination. Usually there is no reason to change this.
|
1165
1161
|
|
1166
1162
|
Returns:
|
1167
1163
|
An iterable `ArtifactCollections` object.
|
@@ -1226,7 +1222,7 @@ class Api:
|
|
1226
1222
|
self,
|
1227
1223
|
type_name: str,
|
1228
1224
|
name: str,
|
1229
|
-
per_page:
|
1225
|
+
per_page: int = 50,
|
1230
1226
|
tags: Optional[List[str]] = None,
|
1231
1227
|
) -> "public.Artifacts":
|
1232
1228
|
"""Return an `Artifacts` collection from the given parameters.
|
@@ -1234,8 +1230,7 @@ class Api:
|
|
1234
1230
|
Args:
|
1235
1231
|
type_name: (str) The type of artifacts to fetch.
|
1236
1232
|
name: (str) An artifact collection name. May be prefixed with entity/project.
|
1237
|
-
per_page: (int
|
1238
|
-
Usually there is no reason to change this.
|
1233
|
+
per_page: (int) Sets the page size for query pagination. Usually there is no reason to change this.
|
1239
1234
|
tags: (list[str], optional) Only return artifacts with all of these tags.
|
1240
1235
|
|
1241
1236
|
Returns:
|
@@ -1510,7 +1505,7 @@ class Api:
|
|
1510
1505
|
Returns:
|
1511
1506
|
A registry iterator.
|
1512
1507
|
"""
|
1513
|
-
if not
|
1508
|
+
if not InternalApi()._check_server_feature_with_fallback(
|
1514
1509
|
ServerFeature.ARTIFACT_REGISTRY_SEARCH
|
1515
1510
|
):
|
1516
1511
|
raise RuntimeError(
|
@@ -1522,52 +1517,3 @@ class Api:
|
|
1522
1517
|
self.settings, self.default_entity
|
1523
1518
|
)
|
1524
1519
|
return Registries(self.client, organization, filter)
|
1525
|
-
|
1526
|
-
def _check_server_feature(self, feature: ServerFeature) -> bool:
|
1527
|
-
"""Check if a server feature is enabled.
|
1528
|
-
|
1529
|
-
Args:
|
1530
|
-
feature (ServerFeature): The feature to check.
|
1531
|
-
|
1532
|
-
Returns:
|
1533
|
-
bool: True if the feature is enabled, False otherwise.
|
1534
|
-
|
1535
|
-
Raises:
|
1536
|
-
Exception: If server doesn't support feature queries or other errors occur
|
1537
|
-
"""
|
1538
|
-
if self._server_features_cache is None:
|
1539
|
-
response = self.client.execute(gql(SERVER_FEATURES_QUERY_GQL))
|
1540
|
-
self._server_features_cache = ServerFeaturesQuery.model_validate(response)
|
1541
|
-
|
1542
|
-
feature_name = ServerFeature.Name(feature)
|
1543
|
-
if (
|
1544
|
-
self._server_features_cache
|
1545
|
-
and self._server_features_cache.server_info
|
1546
|
-
and self._server_features_cache.server_info.features
|
1547
|
-
):
|
1548
|
-
for feature_info in self._server_features_cache.server_info.features:
|
1549
|
-
if feature_info and feature_info.name == feature_name:
|
1550
|
-
return feature_info.is_enabled
|
1551
|
-
|
1552
|
-
return False
|
1553
|
-
|
1554
|
-
def _check_server_feature_with_fallback(self, feature: ServerFeature) -> bool:
|
1555
|
-
"""Wrapper around check_server_feature that warns and returns False for older unsupported servers.
|
1556
|
-
|
1557
|
-
Good to use for features that have a fallback mechanism for older servers.
|
1558
|
-
|
1559
|
-
Args:
|
1560
|
-
feature (ServerFeature): The feature to check.
|
1561
|
-
|
1562
|
-
Returns:
|
1563
|
-
bool: True if the feature is enabled, False otherwise.
|
1564
|
-
|
1565
|
-
Exceptions:
|
1566
|
-
Exception: If an error other than the server not supporting feature queries occurs.
|
1567
|
-
"""
|
1568
|
-
try:
|
1569
|
-
return self._check_server_feature(feature)
|
1570
|
-
except Exception as e:
|
1571
|
-
if 'Cannot query field "features" on type "ServerInfo".' in str(e):
|
1572
|
-
return False
|
1573
|
-
raise e
|
wandb/apis/public/artifacts.py
CHANGED
@@ -10,19 +10,21 @@ from wandb_gql import Client, gql
|
|
10
10
|
import wandb
|
11
11
|
from wandb.apis import public
|
12
12
|
from wandb.apis.normalize import normalize_exceptions
|
13
|
-
from wandb.apis.paginator import Paginator
|
13
|
+
from wandb.apis.paginator import Paginator, SizedPaginator
|
14
14
|
from wandb.errors.term import termlog
|
15
|
+
from wandb.proto.wandb_internal_pb2 import ServerFeature
|
15
16
|
from wandb.sdk.artifacts._graphql_fragments import (
|
16
17
|
ARTIFACT_FILES_FRAGMENT,
|
17
18
|
ARTIFACTS_TYPES_FRAGMENT,
|
18
19
|
)
|
20
|
+
from wandb.sdk.internal.internal_api import Api as InternalApi
|
19
21
|
from wandb.sdk.lib import deprecate
|
20
22
|
|
21
23
|
if TYPE_CHECKING:
|
22
24
|
from wandb.apis.public import RetryingClient, Run
|
23
25
|
|
24
26
|
|
25
|
-
class ArtifactTypes(Paginator):
|
27
|
+
class ArtifactTypes(Paginator["ArtifactType"]):
|
26
28
|
QUERY = gql(
|
27
29
|
"""
|
28
30
|
query ProjectArtifacts(
|
@@ -45,7 +47,7 @@ class ArtifactTypes(Paginator):
|
|
45
47
|
client: Client,
|
46
48
|
entity: str,
|
47
49
|
project: str,
|
48
|
-
per_page:
|
50
|
+
per_page: int = 50,
|
49
51
|
):
|
50
52
|
self.entity = entity
|
51
53
|
self.project = project
|
@@ -58,7 +60,7 @@ class ArtifactTypes(Paginator):
|
|
58
60
|
super().__init__(client, variable_values, per_page)
|
59
61
|
|
60
62
|
@property
|
61
|
-
def length(self):
|
63
|
+
def length(self) -> None:
|
62
64
|
# TODO
|
63
65
|
return None
|
64
66
|
|
@@ -167,14 +169,14 @@ class ArtifactType:
|
|
167
169
|
return f"<ArtifactType {self.type}>"
|
168
170
|
|
169
171
|
|
170
|
-
class ArtifactCollections(
|
172
|
+
class ArtifactCollections(SizedPaginator["ArtifactCollection"]):
|
171
173
|
def __init__(
|
172
174
|
self,
|
173
175
|
client: Client,
|
174
176
|
entity: str,
|
175
177
|
project: str,
|
176
178
|
type_name: str,
|
177
|
-
per_page:
|
179
|
+
per_page: int = 50,
|
178
180
|
):
|
179
181
|
self.entity = entity
|
180
182
|
self.project = project
|
@@ -710,7 +712,7 @@ class ArtifactCollection:
|
|
710
712
|
return f"<ArtifactCollection {self._name} ({self._type})>"
|
711
713
|
|
712
714
|
|
713
|
-
class Artifacts(
|
715
|
+
class Artifacts(SizedPaginator["wandb.Artifact"]):
|
714
716
|
"""An iterable collection of artifact versions associated with a project and optional filter.
|
715
717
|
|
716
718
|
This is generally used indirectly via the `Api`.artifact_versions method.
|
@@ -826,10 +828,8 @@ class Artifacts(Paginator):
|
|
826
828
|
]
|
827
829
|
|
828
830
|
|
829
|
-
class RunArtifacts(
|
830
|
-
def __init__(
|
831
|
-
self, client: Client, run: "Run", mode="logged", per_page: Optional[int] = 50
|
832
|
-
):
|
831
|
+
class RunArtifacts(SizedPaginator["wandb.Artifact"]):
|
832
|
+
def __init__(self, client: Client, run: "Run", mode="logged", per_page: int = 50):
|
833
833
|
from wandb.sdk.artifacts.artifact import _gql_artifact_fragment
|
834
834
|
|
835
835
|
output_query = gql(
|
@@ -944,9 +944,9 @@ class RunArtifacts(Paginator):
|
|
944
944
|
]
|
945
945
|
|
946
946
|
|
947
|
-
class ArtifactFiles(
|
948
|
-
|
949
|
-
"""
|
947
|
+
class ArtifactFiles(SizedPaginator["public.File"]):
|
948
|
+
ARTIFACT_VERSION_FILES_QUERY = gql(
|
949
|
+
f"""
|
950
950
|
query ArtifactFiles(
|
951
951
|
$entityName: String!,
|
952
952
|
$projectName: String!,
|
@@ -959,13 +959,40 @@ class ArtifactFiles(Paginator):
|
|
959
959
|
project(name: $projectName, entityName: $entityName) {{
|
960
960
|
artifactType(name: $artifactTypeName) {{
|
961
961
|
artifact(name: $artifactName) {{
|
962
|
-
|
962
|
+
files(names: $fileNames, after: $fileCursor, first: $fileLimit) {{
|
963
|
+
...FilesFragment
|
964
|
+
}}
|
963
965
|
}}
|
964
966
|
}}
|
965
967
|
}}
|
966
968
|
}}
|
967
|
-
{}
|
968
|
-
"""
|
969
|
+
{ARTIFACT_FILES_FRAGMENT}
|
970
|
+
"""
|
971
|
+
)
|
972
|
+
|
973
|
+
ARTIFACT_COLLECTION_MEMBERSHIP_FILES_QUERY = gql(
|
974
|
+
f"""
|
975
|
+
query ArtifactCollectionMembershipFiles(
|
976
|
+
$entityName: String!,
|
977
|
+
$projectName: String!,
|
978
|
+
$artifactName: String!,
|
979
|
+
$artifactVersionIndex: String!,
|
980
|
+
$fileNames: [String!],
|
981
|
+
$fileCursor: String,
|
982
|
+
$fileLimit: Int = 50
|
983
|
+
) {{
|
984
|
+
project(name: $projectName, entityName: $entityName) {{
|
985
|
+
artifactCollection(name: $artifactName) {{
|
986
|
+
artifactMembership (aliasName: $artifactVersionIndex) {{
|
987
|
+
files(names: $fileNames, after: $fileCursor, first: $fileLimit) {{
|
988
|
+
...FilesFragment
|
989
|
+
}}
|
990
|
+
}}
|
991
|
+
}}
|
992
|
+
}}
|
993
|
+
}}
|
994
|
+
{ARTIFACT_FILES_FRAGMENT}
|
995
|
+
"""
|
969
996
|
)
|
970
997
|
|
971
998
|
def __init__(
|
@@ -975,6 +1002,9 @@ class ArtifactFiles(Paginator):
|
|
975
1002
|
names: Optional[Sequence[str]] = None,
|
976
1003
|
per_page: int = 50,
|
977
1004
|
):
|
1005
|
+
self.query_via_membership = InternalApi()._check_server_feature_with_fallback(
|
1006
|
+
ServerFeature.ARTIFACT_COLLECTION_MEMBERSHIP_FILES
|
1007
|
+
)
|
978
1008
|
self.artifact = artifact
|
979
1009
|
variables = {
|
980
1010
|
"entityName": artifact.source_entity,
|
@@ -983,6 +1013,17 @@ class ArtifactFiles(Paginator):
|
|
983
1013
|
"artifactName": artifact.source_name,
|
984
1014
|
"fileNames": names,
|
985
1015
|
}
|
1016
|
+
if self.query_via_membership:
|
1017
|
+
self.QUERY = self.ARTIFACT_COLLECTION_MEMBERSHIP_FILES_QUERY
|
1018
|
+
variables = {
|
1019
|
+
"entityName": artifact.entity,
|
1020
|
+
"projectName": artifact.project,
|
1021
|
+
"artifactName": artifact.name.split(":")[0],
|
1022
|
+
"artifactVersionIndex": artifact.version,
|
1023
|
+
"fileNames": names,
|
1024
|
+
}
|
1025
|
+
else:
|
1026
|
+
self.QUERY = self.ARTIFACT_VERSION_FILES_QUERY
|
986
1027
|
# The server must advertise at least SDK 0.12.21
|
987
1028
|
# to get storagePath
|
988
1029
|
if not client.version_supported("0.12.21"):
|
@@ -1000,6 +1041,10 @@ class ArtifactFiles(Paginator):
|
|
1000
1041
|
@property
|
1001
1042
|
def more(self):
|
1002
1043
|
if self.last_response:
|
1044
|
+
if self.query_via_membership:
|
1045
|
+
return self.last_response["project"]["artifactCollection"][
|
1046
|
+
"artifactMembership"
|
1047
|
+
]["files"]["pageInfo"]["hasNextPage"]
|
1003
1048
|
return self.last_response["project"]["artifactType"]["artifact"]["files"][
|
1004
1049
|
"pageInfo"
|
1005
1050
|
]["hasNextPage"]
|
@@ -1009,6 +1054,10 @@ class ArtifactFiles(Paginator):
|
|
1009
1054
|
@property
|
1010
1055
|
def cursor(self):
|
1011
1056
|
if self.last_response:
|
1057
|
+
if self.query_via_membership:
|
1058
|
+
return self.last_response["project"]["artifactCollection"][
|
1059
|
+
"artifactMembership"
|
1060
|
+
]["files"]["edges"][-1]["cursor"]
|
1012
1061
|
return self.last_response["project"]["artifactType"]["artifact"]["files"][
|
1013
1062
|
"edges"
|
1014
1063
|
][-1]["cursor"]
|
@@ -1019,6 +1068,13 @@ class ArtifactFiles(Paginator):
|
|
1019
1068
|
self.variables.update({"fileLimit": self.per_page, "fileCursor": self.cursor})
|
1020
1069
|
|
1021
1070
|
def convert_objects(self):
|
1071
|
+
if self.query_via_membership:
|
1072
|
+
return [
|
1073
|
+
public.File(self.client, r["node"])
|
1074
|
+
for r in self.last_response["project"]["artifactCollection"][
|
1075
|
+
"artifactMembership"
|
1076
|
+
]["files"]["edges"]
|
1077
|
+
]
|
1022
1078
|
return [
|
1023
1079
|
public.File(self.client, r["node"])
|
1024
1080
|
for r in self.last_response["project"]["artifactType"]["artifact"]["files"][
|
wandb/apis/public/files.py
CHANGED
@@ -12,7 +12,7 @@ import wandb
|
|
12
12
|
from wandb import util
|
13
13
|
from wandb.apis.attrs import Attrs
|
14
14
|
from wandb.apis.normalize import normalize_exceptions
|
15
|
-
from wandb.apis.paginator import
|
15
|
+
from wandb.apis.paginator import SizedPaginator
|
16
16
|
from wandb.apis.public import utils
|
17
17
|
from wandb.apis.public.api import Api
|
18
18
|
from wandb.apis.public.const import RETRY_TIMEDELTA
|
@@ -41,7 +41,7 @@ FILE_FRAGMENT = """fragment RunFilesFragment on Run {
|
|
41
41
|
}"""
|
42
42
|
|
43
43
|
|
44
|
-
class Files(
|
44
|
+
class Files(SizedPaginator["File"]):
|
45
45
|
"""An iterable collection of `File` objects."""
|
46
46
|
|
47
47
|
QUERY = gql(
|
wandb/apis/public/projects.py
CHANGED
@@ -17,7 +17,7 @@ PROJECT_FRAGMENT = """fragment ProjectFragment on Project {
|
|
17
17
|
}"""
|
18
18
|
|
19
19
|
|
20
|
-
class Projects(Paginator):
|
20
|
+
class Projects(Paginator["Project"]):
|
21
21
|
"""An iterable collection of `Project` objects."""
|
22
22
|
|
23
23
|
QUERY = gql(
|
@@ -49,7 +49,8 @@ class Projects(Paginator):
|
|
49
49
|
super().__init__(client, variables, per_page)
|
50
50
|
|
51
51
|
@property
|
52
|
-
def length(self):
|
52
|
+
def length(self) -> None:
|
53
|
+
# For backwards compatibility, even though this isn't a SizedPaginator
|
53
54
|
return None
|
54
55
|
|
55
56
|
@property
|
wandb/apis/public/reports.py
CHANGED
@@ -9,11 +9,11 @@ from wandb_gql import gql
|
|
9
9
|
import wandb
|
10
10
|
from wandb.apis import public
|
11
11
|
from wandb.apis.attrs import Attrs
|
12
|
-
from wandb.apis.paginator import
|
12
|
+
from wandb.apis.paginator import SizedPaginator
|
13
13
|
from wandb.sdk.lib import ipython
|
14
14
|
|
15
15
|
|
16
|
-
class Reports(
|
16
|
+
class Reports(SizedPaginator["BetaReport"]):
|
17
17
|
"""Reports is an iterable collection of `BetaReport` objects."""
|
18
18
|
|
19
19
|
QUERY = gql(
|
wandb/apis/public/runs.py
CHANGED
@@ -24,7 +24,7 @@ from wandb.apis import public
|
|
24
24
|
from wandb.apis.attrs import Attrs
|
25
25
|
from wandb.apis.internal import Api as InternalApi
|
26
26
|
from wandb.apis.normalize import normalize_exceptions
|
27
|
-
from wandb.apis.paginator import
|
27
|
+
from wandb.apis.paginator import SizedPaginator
|
28
28
|
from wandb.apis.public.const import RETRY_TIMEDELTA
|
29
29
|
from wandb.sdk.lib import ipython, json_util, runid
|
30
30
|
from wandb.sdk.lib.paths import LogicalPath
|
@@ -61,7 +61,7 @@ RUN_FRAGMENT = """fragment RunFragment on Run {
|
|
61
61
|
}"""
|
62
62
|
|
63
63
|
|
64
|
-
class Runs(
|
64
|
+
class Runs(SizedPaginator["Run"]):
|
65
65
|
"""An iterable collection of runs associated with a project and optional filter.
|
66
66
|
|
67
67
|
This is generally used indirectly via the `Api`.runs method.
|
@@ -421,16 +421,15 @@ class Run(Attrs):
|
|
421
421
|
"""
|
422
422
|
query Run($project: String!, $entity: String!, $name: String!) {{
|
423
423
|
project(name: $project, entityName: $entity) {{
|
424
|
-
{}
|
425
424
|
run(name: $name) {{
|
425
|
+
{}
|
426
426
|
...RunFragment
|
427
427
|
}}
|
428
428
|
}}
|
429
429
|
}}
|
430
430
|
{}
|
431
431
|
""".format(
|
432
|
-
|
433
|
-
"internalId" if self._server_provides_internal_id_for_project() else "",
|
432
|
+
"projectId" if self._server_provides_internal_id_for_project() else "",
|
434
433
|
RUN_FRAGMENT,
|
435
434
|
)
|
436
435
|
)
|
@@ -444,7 +443,11 @@ class Run(Attrs):
|
|
444
443
|
raise ValueError("Could not find run {}".format(self))
|
445
444
|
self._attrs = response["project"]["run"]
|
446
445
|
self._state = self._attrs["state"]
|
447
|
-
|
446
|
+
|
447
|
+
self._project_internal_id = (
|
448
|
+
int(self._attrs["projectId"]) if "projectId" in self._attrs else None
|
449
|
+
)
|
450
|
+
|
448
451
|
if self._include_sweeps and self.sweep_name and not self.sweep:
|
449
452
|
# There may be a lot of runs. Don't bother pulling them all
|
450
453
|
# just for the sake of this one.
|
@@ -847,7 +850,12 @@ class Run(Attrs):
|
|
847
850
|
api.set_current_run_id(self.id)
|
848
851
|
|
849
852
|
if isinstance(artifact, wandb.Artifact) and not artifact.is_draft():
|
850
|
-
api.use_artifact(
|
853
|
+
api.use_artifact(
|
854
|
+
artifact.id,
|
855
|
+
use_as=use_as or artifact.name,
|
856
|
+
artifact_entity_name=artifact.entity,
|
857
|
+
artifact_project_name=artifact.project,
|
858
|
+
)
|
851
859
|
return artifact
|
852
860
|
elif isinstance(artifact, wandb.Artifact) and artifact.is_draft():
|
853
861
|
raise ValueError(
|
@@ -911,8 +919,8 @@ class Run(Attrs):
|
|
911
919
|
This check is done by utilizing GraphQL introspection in the available fields on the Project type.
|
912
920
|
"""
|
913
921
|
query_string = """
|
914
|
-
query
|
915
|
-
|
922
|
+
query ProbeRunInput {
|
923
|
+
RunType: __type(name:"Run") {
|
916
924
|
fields {
|
917
925
|
name
|
918
926
|
}
|
@@ -924,8 +932,8 @@ class Run(Attrs):
|
|
924
932
|
if self.server_provides_internal_id_field is None:
|
925
933
|
query = gql(query_string)
|
926
934
|
res = self.client.execute(query)
|
927
|
-
self.server_provides_internal_id_field = "
|
928
|
-
x["name"] for x in (res.get("
|
935
|
+
self.server_provides_internal_id_field = "projectId" in [
|
936
|
+
x["name"] for x in (res.get("RunType", {}).get("fields", [{}]))
|
929
937
|
]
|
930
938
|
|
931
939
|
return self.server_provides_internal_id_field
|
wandb/bin/gpu_stats
CHANGED
Binary file
|
wandb/bin/wandb-core
CHANGED
Binary file
|
@@ -13,6 +13,7 @@ import inspect
|
|
13
13
|
import pickle
|
14
14
|
from functools import wraps
|
15
15
|
from pathlib import Path
|
16
|
+
from typing import Union
|
16
17
|
|
17
18
|
import wandb
|
18
19
|
from wandb.sdk.lib import telemetry as wb_telemetry
|
@@ -25,17 +26,18 @@ except ImportError as e:
|
|
25
26
|
) from e
|
26
27
|
|
27
28
|
try:
|
28
|
-
from
|
29
|
+
from plum import dispatch
|
29
30
|
except ImportError as e:
|
30
31
|
raise Exception(
|
31
|
-
"Error: `
|
32
|
+
"Error: `plum-dispatch` not installed >> "
|
33
|
+
"This integration requires plum-dispatch! To fix, please `pip install -Uqq plum-dispatch`"
|
32
34
|
) from e
|
33
35
|
|
34
36
|
|
35
37
|
try:
|
36
38
|
import pandas as pd
|
37
39
|
|
38
|
-
@
|
40
|
+
@dispatch # noqa: F811
|
39
41
|
def _wandb_use(
|
40
42
|
name: str,
|
41
43
|
data: pd.DataFrame,
|
@@ -52,7 +54,7 @@ try:
|
|
52
54
|
run.use_artifact(f"{name}:latest")
|
53
55
|
wandb.termlog(f"Using artifact: {name} ({type(data)})")
|
54
56
|
|
55
|
-
@
|
57
|
+
@dispatch # noqa: F811
|
56
58
|
def wandb_track(
|
57
59
|
name: str,
|
58
60
|
data: pd.DataFrame,
|
@@ -81,7 +83,7 @@ try:
|
|
81
83
|
import torch
|
82
84
|
import torch.nn as nn
|
83
85
|
|
84
|
-
@
|
86
|
+
@dispatch # noqa: F811
|
85
87
|
def _wandb_use(
|
86
88
|
name: str,
|
87
89
|
data: nn.Module,
|
@@ -98,7 +100,7 @@ try:
|
|
98
100
|
run.use_artifact(f"{name}:latest")
|
99
101
|
wandb.termlog(f"Using artifact: {name} ({type(data)})")
|
100
102
|
|
101
|
-
@
|
103
|
+
@dispatch # noqa: F811
|
102
104
|
def wandb_track(
|
103
105
|
name: str,
|
104
106
|
data: nn.Module,
|
@@ -126,7 +128,7 @@ except ImportError:
|
|
126
128
|
try:
|
127
129
|
from sklearn.base import BaseEstimator
|
128
130
|
|
129
|
-
@
|
131
|
+
@dispatch # noqa: F811
|
130
132
|
def _wandb_use(
|
131
133
|
name: str,
|
132
134
|
data: BaseEstimator,
|
@@ -143,7 +145,7 @@ try:
|
|
143
145
|
run.use_artifact(f"{name}:latest")
|
144
146
|
wandb.termlog(f"Using artifact: {name} ({type(data)})")
|
145
147
|
|
146
|
-
@
|
148
|
+
@dispatch # noqa: F811
|
147
149
|
def wandb_track(
|
148
150
|
name: str,
|
149
151
|
data: BaseEstimator,
|
@@ -192,10 +194,10 @@ class ArtifactProxy:
|
|
192
194
|
return getattr(self.flow, key)
|
193
195
|
|
194
196
|
|
195
|
-
@
|
197
|
+
@dispatch # noqa: F811
|
196
198
|
def wandb_track(
|
197
199
|
name: str,
|
198
|
-
data:
|
200
|
+
data: Union[dict, list, set, str, int, float, bool],
|
199
201
|
run=None,
|
200
202
|
testing=False,
|
201
203
|
*args,
|
@@ -207,7 +209,7 @@ def wandb_track(
|
|
207
209
|
run.log({name: data})
|
208
210
|
|
209
211
|
|
210
|
-
@
|
212
|
+
@dispatch # noqa: F811
|
211
213
|
def wandb_track(
|
212
214
|
name: str, data: Path, datasets=False, run=None, testing=False, *args, **kwargs
|
213
215
|
):
|
@@ -225,7 +227,7 @@ def wandb_track(
|
|
225
227
|
|
226
228
|
|
227
229
|
# this is the base case
|
228
|
-
@
|
230
|
+
@dispatch # noqa: F811
|
229
231
|
def wandb_track(
|
230
232
|
name: str, data, others=False, run=None, testing=False, *args, **kwargs
|
231
233
|
):
|
@@ -240,7 +242,7 @@ def wandb_track(
|
|
240
242
|
wandb.termlog(f"Logging artifact: {name} ({type(data)})")
|
241
243
|
|
242
244
|
|
243
|
-
@
|
245
|
+
@dispatch # noqa: F811
|
244
246
|
def wandb_use(name: str, data, *args, **kwargs):
|
245
247
|
try:
|
246
248
|
return _wandb_use(name, data, *args, **kwargs)
|
@@ -252,14 +254,14 @@ def wandb_use(name: str, data, *args, **kwargs):
|
|
252
254
|
)
|
253
255
|
|
254
256
|
|
255
|
-
@
|
257
|
+
@dispatch # noqa: F811
|
256
258
|
def wandb_use(
|
257
|
-
name: str, data:
|
259
|
+
name: str, data: Union[dict, list, set, str, int, float, bool], *args, **kwargs
|
258
260
|
): # type: ignore
|
259
261
|
pass # do nothing for these types
|
260
262
|
|
261
263
|
|
262
|
-
@
|
264
|
+
@dispatch # noqa: F811
|
263
265
|
def _wandb_use(
|
264
266
|
name: str, data: Path, datasets=False, run=None, testing=False, *args, **kwargs
|
265
267
|
): # type: ignore
|
@@ -271,7 +273,7 @@ def _wandb_use(
|
|
271
273
|
wandb.termlog(f"Using artifact: {name} ({type(data)})")
|
272
274
|
|
273
275
|
|
274
|
-
@
|
276
|
+
@dispatch # noqa: F811
|
275
277
|
def _wandb_use(name: str, data, others=False, run=None, testing=False, *args, **kwargs): # type: ignore
|
276
278
|
if testing:
|
277
279
|
return "others" if others else None
|
@@ -23,7 +23,7 @@ class WandbObserver(RunObserver):
|
|
23
23
|
job_type — the type of job you are logging, e.g. eval, worker, ps (default: training)
|
24
24
|
save_code — save the main python or notebook file to wandb to enable diffing (default: editable from your settings page)
|
25
25
|
group — a string by which to group other runs; see Grouping
|
26
|
-
reinit —
|
26
|
+
reinit — Shorthand for the reinit setting that defines what to do when `wandb.init()` is called while a run is active. See the setting's documentation.
|
27
27
|
id — A unique ID for this run primarily used for Resuming. It must be globally unique, and if you delete a run you can't reuse the ID. Use the name field for a descriptive, useful name for the run. The ID cannot contain special characters.
|
28
28
|
resume — if set to True, the run auto resumes; can also be a unique string for manual resuming; see Resuming (default: False)
|
29
29
|
anonymous — can be "allow", "never", or "must". This enables or explicitly disables anonymous logging. (default: never)
|