wandb 0.19.12rc1__py3-none-win32.whl → 0.20.1__py3-none-win32.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -2
- wandb/__init__.pyi +3 -6
- wandb/_iterutils.py +26 -7
- wandb/_pydantic/__init__.py +2 -1
- wandb/_pydantic/utils.py +7 -0
- wandb/agents/pyagent.py +9 -15
- wandb/analytics/sentry.py +1 -2
- wandb/apis/attrs.py +3 -4
- wandb/apis/importers/internals/util.py +1 -1
- wandb/apis/importers/validation.py +2 -2
- wandb/apis/importers/wandb.py +30 -25
- wandb/apis/normalize.py +2 -2
- wandb/apis/public/__init__.py +1 -0
- wandb/apis/public/api.py +37 -33
- wandb/apis/public/artifacts.py +103 -72
- wandb/apis/public/jobs.py +3 -2
- wandb/apis/public/registries/registries_search.py +4 -2
- wandb/apis/public/registries/registry.py +1 -1
- wandb/apis/public/registries/utils.py +9 -9
- wandb/apis/public/runs.py +18 -6
- wandb/automations/_filters/expressions.py +1 -1
- wandb/automations/_filters/operators.py +1 -1
- wandb/automations/_filters/run_metrics.py +1 -1
- wandb/beta/workflows.py +6 -5
- wandb/bin/gpu_stats.exe +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +54 -73
- wandb/docker/__init__.py +21 -74
- wandb/docker/names.py +40 -0
- wandb/env.py +0 -1
- wandb/errors/util.py +1 -1
- wandb/filesync/step_checksum.py +1 -1
- wandb/filesync/step_upload.py +1 -1
- wandb/integration/diffusers/resolvers/multimodal.py +1 -2
- wandb/integration/gym/__init__.py +5 -6
- wandb/integration/keras/callbacks/model_checkpoint.py +2 -2
- wandb/integration/keras/keras.py +13 -19
- wandb/integration/kfp/kfp_patch.py +2 -3
- wandb/integration/langchain/wandb_tracer.py +1 -1
- wandb/integration/metaflow/metaflow.py +13 -13
- wandb/integration/openai/fine_tuning.py +3 -2
- wandb/integration/sagemaker/auth.py +2 -1
- wandb/integration/sklearn/utils.py +2 -1
- wandb/integration/tensorboard/__init__.py +1 -1
- wandb/integration/tensorboard/log.py +2 -5
- wandb/integration/tensorflow/__init__.py +2 -2
- wandb/jupyter.py +20 -17
- wandb/plot/confusion_matrix.py +1 -1
- wandb/plot/utils.py +8 -7
- wandb/proto/v3/wandb_internal_pb2.py +355 -335
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_internal_pb2.py +339 -335
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v5/wandb_internal_pb2.py +339 -335
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v6/wandb_internal_pb2.py +339 -335
- wandb/proto/v6/wandb_settings_pb2.py +2 -2
- wandb/proto/v6/wandb_telemetry_pb2.py +12 -12
- wandb/proto/wandb_deprecated.py +6 -8
- wandb/sdk/artifacts/_internal_artifact.py +43 -0
- wandb/sdk/artifacts/_validators.py +55 -35
- wandb/sdk/artifacts/artifact.py +117 -115
- wandb/sdk/artifacts/artifact_download_logger.py +2 -0
- wandb/sdk/artifacts/artifact_saver.py +1 -3
- wandb/sdk/artifacts/artifact_state.py +2 -0
- wandb/sdk/artifacts/artifact_ttl.py +2 -0
- wandb/sdk/artifacts/exceptions.py +14 -0
- wandb/sdk/artifacts/staging.py +2 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +2 -6
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +2 -6
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -5
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -1
- wandb/sdk/artifacts/storage_layout.py +2 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +3 -3
- wandb/sdk/backend/backend.py +11 -182
- wandb/sdk/data_types/_dtypes.py +2 -6
- wandb/sdk/data_types/audio.py +20 -3
- wandb/sdk/data_types/base_types/media.py +12 -7
- wandb/sdk/data_types/base_types/wb_value.py +8 -18
- wandb/sdk/data_types/bokeh.py +19 -2
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +17 -1
- wandb/sdk/data_types/helper_types/image_mask.py +7 -1
- wandb/sdk/data_types/html.py +4 -4
- wandb/sdk/data_types/image.py +178 -103
- wandb/sdk/data_types/molecule.py +6 -6
- wandb/sdk/data_types/object_3d.py +10 -5
- wandb/sdk/data_types/saved_model.py +11 -6
- wandb/sdk/data_types/table.py +313 -83
- wandb/sdk/data_types/table_decorators.py +108 -0
- wandb/sdk/data_types/utils.py +43 -7
- wandb/sdk/data_types/video.py +21 -3
- wandb/sdk/interface/interface.py +10 -0
- wandb/sdk/internal/datastore.py +2 -6
- wandb/sdk/internal/file_pusher.py +1 -5
- wandb/sdk/internal/file_stream.py +8 -17
- wandb/sdk/internal/handler.py +2 -2
- wandb/sdk/internal/incremental_table_util.py +53 -0
- wandb/sdk/internal/internal.py +3 -5
- wandb/sdk/internal/internal_api.py +66 -89
- wandb/sdk/internal/job_builder.py +2 -7
- wandb/sdk/internal/profiler.py +2 -2
- wandb/sdk/internal/progress.py +1 -3
- wandb/sdk/internal/run.py +1 -6
- wandb/sdk/internal/sender.py +24 -36
- wandb/sdk/internal/system/assets/aggregators.py +1 -7
- wandb/sdk/internal/system/assets/disk.py +3 -3
- wandb/sdk/internal/system/assets/gpu.py +4 -4
- wandb/sdk/internal/system/assets/gpu_amd.py +4 -4
- wandb/sdk/internal/system/assets/interfaces.py +6 -6
- wandb/sdk/internal/system/assets/tpu.py +1 -1
- wandb/sdk/internal/system/assets/trainium.py +6 -6
- wandb/sdk/internal/system/system_info.py +5 -7
- wandb/sdk/internal/system/system_monitor.py +4 -4
- wandb/sdk/internal/tb_watcher.py +5 -7
- wandb/sdk/launch/_launch.py +1 -1
- wandb/sdk/launch/_project_spec.py +19 -20
- wandb/sdk/launch/agent/agent.py +3 -3
- wandb/sdk/launch/agent/config.py +1 -1
- wandb/sdk/launch/agent/job_status_tracker.py +2 -2
- wandb/sdk/launch/builder/build.py +2 -3
- wandb/sdk/launch/builder/kaniko_builder.py +5 -4
- wandb/sdk/launch/environment/gcp_environment.py +1 -2
- wandb/sdk/launch/registry/azure_container_registry.py +2 -2
- wandb/sdk/launch/registry/elastic_container_registry.py +2 -2
- wandb/sdk/launch/registry/google_artifact_registry.py +3 -3
- wandb/sdk/launch/runner/abstract.py +5 -5
- wandb/sdk/launch/runner/kubernetes_monitor.py +2 -2
- wandb/sdk/launch/runner/kubernetes_runner.py +1 -1
- wandb/sdk/launch/runner/sagemaker_runner.py +2 -4
- wandb/sdk/launch/runner/vertex_runner.py +2 -7
- wandb/sdk/launch/sweeps/__init__.py +1 -1
- wandb/sdk/launch/sweeps/scheduler.py +2 -2
- wandb/sdk/launch/sweeps/utils.py +3 -3
- wandb/sdk/launch/utils.py +3 -4
- wandb/sdk/lib/apikey.py +5 -8
- wandb/sdk/lib/config_util.py +3 -3
- wandb/sdk/lib/fsm.py +3 -18
- wandb/sdk/lib/gitlib.py +6 -5
- wandb/sdk/lib/ipython.py +2 -2
- wandb/sdk/lib/json_util.py +9 -14
- wandb/sdk/lib/printer.py +3 -8
- wandb/sdk/lib/redirect.py +1 -1
- wandb/sdk/lib/retry.py +3 -7
- wandb/sdk/lib/run_moment.py +2 -2
- wandb/sdk/lib/service_connection.py +3 -1
- wandb/sdk/lib/service_token.py +1 -2
- wandb/sdk/mailbox/mailbox_handle.py +3 -7
- wandb/sdk/mailbox/response_handle.py +2 -6
- wandb/sdk/service/streams.py +3 -7
- wandb/sdk/verify/verify.py +5 -6
- wandb/sdk/wandb_config.py +1 -1
- wandb/sdk/wandb_init.py +38 -106
- wandb/sdk/wandb_login.py +7 -6
- wandb/sdk/wandb_run.py +52 -240
- wandb/sdk/wandb_settings.py +71 -60
- wandb/sdk/wandb_setup.py +40 -14
- wandb/sdk/wandb_watch.py +5 -7
- wandb/sync/__init__.py +1 -1
- wandb/sync/sync.py +13 -13
- wandb/util.py +17 -35
- wandb/wandb_agent.py +8 -11
- {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/METADATA +5 -5
- {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/RECORD +170 -168
- wandb/docker/auth.py +0 -435
- wandb/docker/www_authenticate.py +0 -94
- {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/WHEEL +0 -0
- {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.12rc1.dist-info → wandb-0.20.1.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/internal/internal_api.py
CHANGED
@@ -1,4 +1,3 @@
-import ast
 import base64
 import datetime
 import functools
@@ -12,7 +11,6 @@ import sys
 import threading
 from copy import deepcopy
 from pathlib import Path
-from types import MappingProxyType
 from typing import (
     IO,
     TYPE_CHECKING,
@@ -70,42 +68,42 @@ if TYPE_CHECKING:
     class CreateArtifactFileSpecInput(TypedDict, total=False):
         """Corresponds to `type CreateArtifactFileSpecInput` in schema.graphql."""

-        artifactID: str
+        artifactID: str
         name: str
         md5: str
         mimetype: Optional[str]
-        artifactManifestID: Optional[str]
-        uploadPartsInput: Optional[List[Dict[str, object]]]
+        artifactManifestID: Optional[str]
+        uploadPartsInput: Optional[List[Dict[str, object]]]

     class CreateArtifactFilesResponseFile(TypedDict):
         id: str
         name: str
-        displayName: str
-        uploadUrl: Optional[str]
-        uploadHeaders: Sequence[str]
-        uploadMultipartUrls: "UploadPartsResponse"
-        storagePath: str
+        displayName: str
+        uploadUrl: Optional[str]
+        uploadHeaders: Sequence[str]
+        uploadMultipartUrls: "UploadPartsResponse"
+        storagePath: str
         artifact: "CreateArtifactFilesResponseFileNode"

     class CreateArtifactFilesResponseFileNode(TypedDict):
         id: str

     class UploadPartsResponse(TypedDict):
-        uploadUrlParts: List["UploadUrlParts"]
-        uploadID: str
+        uploadUrlParts: List["UploadUrlParts"]
+        uploadID: str

     class UploadUrlParts(TypedDict):
-        partNumber: int
-        uploadUrl: str
+        partNumber: int
+        uploadUrl: str

     class CompleteMultipartUploadArtifactInput(TypedDict):
         """Corresponds to `type CompleteMultipartUploadArtifactInput` in schema.graphql."""

-        completeMultipartAction: str
-        completedParts: Dict[int, str]
-        artifactID: str
-        storagePath: str
-        uploadID: str
+        completeMultipartAction: str
+        completedParts: Dict[int, str]
+        artifactID: str
+        storagePath: str
+        uploadID: str
         md5: str

     class CompleteMultipartUploadArtifactResponse(TypedDict):
@@ -238,7 +236,7 @@ class Api:
             ]
         ] = None,
         load_settings: bool = True,
-        retry_timedelta: datetime.timedelta = datetime.timedelta(  #
+        retry_timedelta: datetime.timedelta = datetime.timedelta(  # okay because it's immutable
            days=7
        ),
         environ: MutableMapping = os.environ,
@@ -364,7 +362,7 @@ class Api:
         self.server_create_run_queue_supports_priority: Optional[bool] = None
         self.server_supports_template_variables: Optional[bool] = None
         self.server_push_to_run_queue_supports_priority: Optional[bool] = None
-        self._server_features_cache: Optional[
+        self._server_features_cache: Optional[Dict[str, bool]] = None

     def gql(self, *args: Any, **kwargs: Any) -> Any:
         ret = self._retry_gql(
@@ -399,8 +397,7 @@ class Api:
         except requests.exceptions.HTTPError as err:
             response = err.response
             assert response is not None
-            logger.
-            logger.error(response.text)
+            logger.exception("Error executing GraphQL.")
             for error in parse_backend_error_messages(response):
                 wandb.termerror(f"Error while calling W&B API: {error} ({response})")
             raise
@@ -869,45 +866,43 @@ class Api:
         _, _, mutations = self.server_info_introspection()
         return "updateRunQueueItemWarning" in mutations

-    def _server_features(self) ->
-
-
-
-
-
-
-
-
-
-
-        raise
+    def _server_features(self) -> Dict[str, bool]:
+        # NOTE: Avoid caching via `@cached_property`, due to undocumented
+        # locking behavior before Python 3.12.
+        # See: https://github.com/python/cpython/issues/87634
+        query = gql(SERVER_FEATURES_QUERY_GQL)
+        try:
+            response = self.gql(query)
+        except Exception as e:
+            # Unfortunately we currently have to match on the text of the error message,
+            # as the `gql` client raises `Exception` rather than a more specific error.
+            if 'Cannot query field "features" on type "ServerInfo".' in str(e):
+                self._server_features_cache = {}
             else:
-
-
-
-
-
-
-
+                raise
+        else:
+            info = ServerFeaturesQuery.model_validate(response).server_info
+            if info and (feats := info.features):
+                self._server_features_cache = {f.name: f.is_enabled for f in feats if f}
+            else:
+                self._server_features_cache = {}
+        return self._server_features_cache

-
+    def _server_supports(self, feature: Union[int, str]) -> bool:
+        """Return whether the current server supports the given feature.

-
-
+        This also caches the underlying lookup of server feature flags,
+        and it maps {feature_name (str) -> is_enabled (bool)}.

         Good to use for features that have a fallback mechanism for older servers.
-
-        Args:
-            feature_value (ServerFeature): The enum value of the feature to check.
-
-        Returns:
-            bool: True if the feature is enabled, False otherwise.
-
-        Exceptions:
-            Exception: If an error other than the server not supporting feature queries occurs.
         """
-
+        # If we're given the protobuf enum value, convert to a string name.
+        # NOTE: We deliberately use names (str) instead of enum values (int)
+        # as the keys here, since:
+        # - the server identifies features by their name, rather than (client-side) enum value
+        # - the defined list of client-side flags may be behind the server-side list of flags
+        key = ServerFeature.Name(feature) if isinstance(feature, int) else feature
+        return self._server_features().get(key) or False

     @normalize_exceptions
     def update_run_queue_item_warning(
@@ -2092,9 +2087,7 @@ class Api:
         )
         if default is None or default.get("queueID") is None:
             raise CommError(
-                "Unable to create default queue for {}/{}. No queues for agent to poll"
-                    entity, project
-                )
+                f"Unable to create default queue for {entity}/{project}. No queues for agent to poll"
             )
         project_queues = [{"id": default["queueID"], "name": "default"}]
         polling_queue_ids = [
@@ -2571,15 +2564,11 @@ class Api:
         res = self.gql(query, variable_values)
         if res.get("project") is None:
             raise CommError(
-                "Error fetching run info for {}/{}/{}. Check that this project exists and you have access to this entity and project"
-                    entity, project, name
-                )
+                f"Error fetching run info for {entity}/{project}/{name}. Check that this project exists and you have access to this entity and project"
             )
         elif res["project"].get("run") is None:
             raise CommError(
-                "Error fetching run info for {}/{}/{}. Check that this run id exists"
-                    entity, project, name
-                )
+                f"Error fetching run info for {entity}/{project}/{name}. Check that this run id exists"
             )
         run_info: dict = res["project"]["run"]["runInfo"]
         return run_info
@@ -2993,11 +2982,8 @@ class Api:
             logger.debug("upload_file: %s complete", url)
             response.raise_for_status()
         except requests.exceptions.RequestException as e:
-            logger.
-            request_headers = e.request.headers if e.request is not None else ""
-            logger.error(f"upload_file request headers: {request_headers!r}")
+            logger.exception(f"upload_file exception for {url=}")
             response_content = e.response.content if e.response is not None else ""
-            logger.error(f"upload_file response body: {response_content!r}")
             status_code = e.response.status_code if e.response is not None else 0
             # S3 reports retryable request timeouts out-of-band
             is_aws_retryable = status_code == 400 and "RequestTimeout" in str(
@@ -3059,11 +3045,8 @@ class Api:
             logger.debug("upload_file: %s complete", url)
             response.raise_for_status()
         except requests.exceptions.RequestException as e:
-            logger.
-            request_headers = e.request.headers if e.request is not None else ""
-            logger.error(f"upload_file request headers: {request_headers}")
+            logger.exception(f"upload_file exception for {url=}")
             response_content = e.response.content if e.response is not None else ""
-            logger.error(f"upload_file response body: {response_content!r}")
             status_code = e.response.status_code if e.response is not None else 0
             # S3 reports retryable request timeouts out-of-band
             is_aws_retryable = (
@@ -3190,10 +3173,8 @@ class Api:
                 },
                 timeout=60,
             )
-        except Exception
-
-            message = ast.literal_eval(e.args[0])["message"]
-            logger.error("Error communicating with W&B: %s", message)
+        except Exception:
+            logger.exception("Error communicating with W&B.")
             return []
         else:
             result: List[Dict[str, Any]] = json.loads(
@@ -3235,10 +3216,8 @@ class Api:
                 parameter["distribution"] = "uniform"
             else:
                 raise ValueError(
-                    "Parameter {} is ambiguous, please specify bounds as both floats (for a float_"
-                    "uniform distribution) or ints (for an int_uniform distribution)."
-                        parameter_name
-                    )
+                    f"Parameter {parameter_name} is ambiguous, please specify bounds as both floats (for a float_"
+                    "uniform distribution) or ints (for an int_uniform distribution)."
                 )
         return config

@@ -3387,8 +3366,8 @@ class Api:
                 variable_values=variables,
                 check_retry_fn=util.no_retry_4xx,
             )
-        except UsageError
-            raise
+        except UsageError:
+            raise
         except Exception as e:
             # graphql schema exception is generic
             err = e
@@ -3783,10 +3762,8 @@ class Api:
             "usedAs": use_as,
         }

-        server_allows_entity_project_information = (
-
-            ServerFeature.USE_ARTIFACT_WITH_ENTITY_AND_PROJECT_INFORMATION  # type: ignore
-        )
+        server_allows_entity_project_information = self._server_supports(
+            ServerFeature.USE_ARTIFACT_WITH_ENTITY_AND_PROJECT_INFORMATION
         )
         if server_allows_entity_project_information:
             query_vars.extend(
@@ -4565,9 +4542,9 @@ class Api:
         s = self.sweep(sweep=sweep, entity=entity, project=project, specs="{}")
         curr_state = s["state"].upper()
         if state == "PAUSED" and curr_state not in ("PAUSED", "RUNNING"):
-            raise Exception("Cannot pause {} sweep."
+            raise Exception(f"Cannot pause {curr_state.lower()} sweep.")
         elif state != "RUNNING" and curr_state not in ("RUNNING", "PAUSED", "PENDING"):
-            raise Exception("Sweep already {
+            raise Exception(f"Sweep already {curr_state.lower()}.")
         sweep_id = s["id"]
         mutation = gql(
             """
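The new `_server_features` / `_server_supports` pair above caches server feature flags in a plain `Dict[str, bool]` keyed by feature name. A minimal, self-contained sketch of that lazy-cache pattern follows; the class, the fake fetch, and the flag name are illustrative stand-ins, not the wandb internals:

from typing import Dict, Optional

class FeatureGate:
    def __init__(self) -> None:
        # Plain attribute instead of @cached_property: before Python 3.12,
        # cached_property took a per-instance lock that could serialize threads.
        self._features_cache: Optional[Dict[str, bool]] = None

    def _fetch_features(self) -> Dict[str, bool]:
        # Stand-in for the GraphQL round trip in the real client.
        return {"USE_ARTIFACT_WITH_ENTITY_AND_PROJECT_INFORMATION": True}

    def _features(self) -> Dict[str, bool]:
        if self._features_cache is None:
            try:
                self._features_cache = self._fetch_features()
            except Exception as e:
                # Older servers don't expose the query at all; treat that
                # as "no features" rather than an error.
                if "Cannot query field" in str(e):
                    self._features_cache = {}
                else:
                    raise
        return self._features_cache

    def supports(self, feature_name: str) -> bool:
        # Keys are server-side feature *names*, so unknown flags are False.
        return self._features().get(feature_name) or False

Keeping the cache in an ordinary attribute also matches the diff's own NOTE about the undocumented locking behavior of `@cached_property` before Python 3.12.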
wandb/sdk/internal/job_builder.py
CHANGED
@@ -19,6 +19,7 @@ from typing import (
 )

 import wandb
+from wandb.sdk.artifacts._internal_artifact import InternalArtifact
 from wandb.sdk.artifacts.artifact import Artifact
 from wandb.sdk.data_types._dtypes import TypeRegistry
 from wandb.sdk.internal.internal_api import Api
@@ -128,12 +129,6 @@ def get_min_supported_for_source_dict(
     return min_seen


-class JobArtifact(Artifact):
-    def __init__(self, name: str, *args: Any, **kwargs: Any):
-        super().__init__(name, "placeholder", *args, **kwargs)
-        self._type = JOB_ARTIFACT_TYPE  # Get around type restriction.
-
-
 class JobBuilder:
     _settings: SettingsStatic
     _metadatafile_path: Optional[str]
@@ -552,7 +547,7 @@ class JobBuilder:
         assert source_info is not None
         assert name is not None

-        artifact =
+        artifact = InternalArtifact(name, JOB_ARTIFACT_TYPE)

         _logger.info("adding wandb-job metadata file")
         with artifact.new_file("wandb-job.json") as f:
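The removed `JobArtifact` class shows the trick that the shared `InternalArtifact` helper presumably centralizes: construct with a placeholder type, then overwrite the private type field so reserved internal artifact types bypass the public validation. A toy sketch of that pattern, with stand-in classes rather than the real wandb `Artifact`:

class Artifact:
    def __init__(self, name: str, type: str) -> None:
        if type in ("job",):  # the public API rejects reserved types
            raise ValueError(f"artifact type {type!r} is reserved")
        self.name = name
        self._type = type

class InternalArtifact(Artifact):
    def __init__(self, name: str, type: str) -> None:
        # Pass a harmless placeholder through validation, then override it.
        super().__init__(name, "placeholder")
        self._type = type

artifact = InternalArtifact("my-job", "job")  # would raise via Artifact directly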
wandb/sdk/internal/profiler.py
CHANGED
@@ -54,12 +54,12 @@ def torch_trace_handler():
        prof.step()
     ```
     """
-    from
+    from packaging.version import parse

     torch = wandb.util.get_module(PYTORCH_MODULE, required=True)
     torch_profiler = wandb.util.get_module(PYTORCH_PROFILER_MODULE, required=True)

-    if
+    if parse(torch.__version__) < parse("1.9.0"):
         raise Error(
             f"torch version must be at least 1.9 in order to use the PyTorch Profiler API.\
             \nVersion of torch currently installed: {torch.__version__}"
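The version gate above now goes through `packaging.version.parse`, which compares release segments numerically rather than lexically. In isolation, with an example string standing in for `torch.__version__`:

from packaging.version import parse

installed = "1.10.2"  # example stand-in for torch.__version__
# Plain string comparison gets this wrong: "1.10.2" sorts before "1.9.0".
assert "1.10.2" < "1.9.0"
# Parsed versions compare segment by segment, so 1.10.2 >= 1.9.0 holds.
assert parse(installed) >= parse("1.9.0")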
wandb/sdk/internal/progress.py
CHANGED
@@ -43,9 +43,7 @@ class Progress:
             # files getting truncated while uploading seems like something
             # that shouldn't really be happening anyway.
             raise CommError(
-                "File {} size shrank from {} to {} while it was being uploaded."
-                    self.file.name, self.len, self.bytes_read
-                )
+                f"File {self.file.name} size shrank from {self.len} to {self.bytes_read} while it was being uploaded."
             )
             # Growing files are also likely to be bad, but our code didn't break
             # on those in the past, so it's riskier to make that an error now.
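This is one of many template-plus-arguments messages converted to f-strings in this release. The two forms produce the same text; the f-string keeps each value next to its placeholder. A small check with example values:

name, before, after = "weights.h5", 1024, 512  # example values
old = "File {} size shrank from {} to {} while it was being uploaded.".format(
    name, before, after
)
new = f"File {name} size shrank from {before} to {after} while it was being uploaded."
assert old == new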
wandb/sdk/internal/run.py
CHANGED
@@ -5,12 +5,7 @@ Semi-stubbed run for internal process use.

 """

-import sys
-
-if sys.version_info >= (3, 12):
-    from typing import override
-else:
-    from typing_extensions import override
+from typing_extensions import override

 from wandb.sdk import wandb_run

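The version branch can go because `typing_extensions` re-exports the stdlib `typing.override` on Python 3.12+, so the single import behaves the same everywhere. A small usage sketch with illustrative classes:

from typing_extensions import override

class Base:
    def shutdown(self) -> None:
        ...

class Child(Base):
    @override  # type checkers flag this if Base.shutdown is renamed or removed
    def shutdown(self) -> None:
        print("shutting down")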
wandb/sdk/internal/sender.py
CHANGED
@@ -749,14 +749,12 @@ class SendManager:
                 self._resume_state.wandb_runtime = new_runtime
             tags = resume_status.get("tags") or []

-        except (IndexError, ValueError)
-            logger.
+        except (IndexError, ValueError):
+            logger.exception("unable to load resume tails")
             if self._settings.resume == "must":
                 error = wandb_internal_pb2.ErrorInfo()
                 error.code = wandb_internal_pb2.ErrorInfo.ErrorCode.USAGE
-                error.message = "resume='must' but could not resume ({}) "
-                    run.run_id
-                )
+                error.message = f"resume='must' but could not resume ({run.run_id}) "
                 return error

         # TODO: Do we need to restore config / summary?
@@ -772,7 +770,7 @@ class SendManager:
         self._resume_state.summary = summary
         self._resume_state.tags = tags
         self._resume_state.resumed = True
-        logger.info("configured resuming with: {
+        logger.info(f"configured resuming with: {self._resume_state}")
         return None

     def _telemetry_get_framework(self) -> str:
@@ -816,9 +814,7 @@ class SendManager:
             self._interface.publish_config(
                 key=("_wandb", "spell_url"), val=env.get("SPELL_RUN_URL")
             )
-            url = "{}/{}/{}/runs/{}"
-                self._api.app_url, self._run.entity, self._run.project, self._run.run_id
-            )
+            url = f"{self._api.app_url}/{self._run.entity}/{self._run.project}/runs/{self._run.run_id}"
             requests.put(
                 env.get("SPELL_API_URL", "https://api.spell.run") + "/wandb_url",
                 json={"access_token": env.get("WANDB_ACCESS_TOKEN"), "url": url},
@@ -829,23 +825,22 @@ class SendManager:
         # TODO: do something if sync spell is not successful?

     def _setup_fork(self, server_run: dict):
-        assert self._settings.fork_from
-        assert self._settings.fork_from.metric == "_step"
         assert self._run
-
+        assert self._run.branch_point
+        first_step = int(self._run.branch_point.value) + 1
         self._resume_state.step = first_step
         self._resume_state.history = server_run.get("historyLineCount", 0)
         self._run.forked = True
         self._run.starting_step = first_step

     def _load_rewind_state(self, run: "RunRecord"):
-        assert
+        assert run.branch_point
         self._rewind_response = self._api.rewind_run(
             run_name=run.run_id,
             entity=run.entity or None,
             project=run.project or None,
-            metric_name=
-            metric_value=
+            metric_name=run.branch_point.metric,
+            metric_value=run.branch_point.value,
             program_path=self._settings.program or None,
         )
         self._resume_state.history = self._rewind_response.get("historyLineCount", 0)
@@ -854,12 +849,11 @@ class SendManager:
         )

     def _install_rewind_state(self):
-        assert self._settings.resume_from
-        assert self._settings.resume_from.metric == "_step"
         assert self._run
+        assert self._run.branch_point
         assert self._rewind_response

-        first_step = int(self.
+        first_step = int(self._run.branch_point.value) + 1
         self._resume_state.step = first_step

         # We set the fork flag here because rewind uses the forking
@@ -903,8 +897,8 @@ class SendManager:
         config_value_dict = self._config_backend_dict()
         self._config_save(config_value_dict)

-
-
+        do_rewind = run.branch_point.run == run.run_id
+        do_fork = not do_rewind and run.branch_point.run != ""
         do_resume = bool(self._settings.resume)

         num_resume_options_set = sum([do_fork, do_rewind, do_resume])
@@ -1188,7 +1182,7 @@ class SendManager:
                 try:
                     d[item.key] = json.loads(item.value_json)
                 except json.JSONDecodeError:
-                    logger.
+                    logger.exception("error decoding stats json: %s", item.value_json)
         row: Dict[str, Any] = dict(system=d)
         self._flatten(row)
         row["_wandb"] = True
@@ -1500,17 +1494,15 @@ class SendManager:
         try:
             res = self._send_artifact(artifact)
             logger.info(f"sent artifact {artifact.name} - {res}")
-        except Exception
-            logger.
-                'send_artifact: failed for artifact "{}/{}"
-                artifact.type, artifact.name, e
-            )
+        except Exception:
+            logger.exception(
+                f'send_artifact: failed for artifact "{artifact.type}/{artifact.name}"'
             )

     def _send_artifact(
         self, artifact: "ArtifactRecord", history_step: Optional[int] = None
     ) -> Optional[Dict]:
-        from
+        from packaging.version import parse

         assert self._pusher
         saver = ArtifactSaver(
@@ -1523,9 +1515,7 @@ class SendManager:

         if artifact.distributed_id:
             max_cli_version = self._max_cli_version()
-            if max_cli_version is None or
-                max_cli_version
-            ) < parse_version("0.10.16"):
+            if max_cli_version is None or parse(max_cli_version) < parse("0.10.16"):
                 logger.warning(
                     "This W&B Server doesn't support distributed artifacts, "
                     "have your administrator install wandb/local >= 0.9.37"
@@ -1561,13 +1551,11 @@ class SendManager:
         return res

     def send_alert(self, record: "Record") -> None:
-        from
+        from packaging.version import parse

         alert = record.alert
         max_cli_version = self._max_cli_version()
-        if max_cli_version is None or
-            "0.10.9"
-        ):
+        if max_cli_version is None or parse(max_cli_version) < parse("0.10.9"):
             logger.warning(
                 "This W&B server doesn't support alerts, "
                 "have your administrator install wandb/local >= 0.9.31"
@@ -1580,8 +1568,8 @@ class SendManager:
                 level=alert.level,
                 wait_duration=alert.wait_duration,
             )
-        except Exception
-            logger.
+        except Exception:
+            logger.exception(f"send_alert: failed for alert {alert.title!r}")

     def finish(self) -> None:
         logger.info("shutting down sender")
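Most of the sender.py edits replace `logger.error(...)` calls inside `except` blocks with `logger.exception(...)`, which logs at ERROR level and appends the active traceback automatically, so the extra lines that hand-formatted error details can be dropped. A self-contained illustration:

import json
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("demo")

try:
    json.loads("not json")
except json.JSONDecodeError:
    # One call records both the message and the full traceback. Note that
    # logger.exception only makes sense inside an exception handler.
    logger.exception("error decoding stats json: %s", "not json")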
wandb/sdk/internal/system/assets/disk.py
CHANGED
@@ -33,7 +33,7 @@ class DiskUsagePercent:
             try:
                 psutil.disk_usage(path)
                 self.paths.append(path)
-            except Exception as e:
+            except Exception as e:
                 termwarn(f"Could not access disk path {path}: {e}", repeat=False)

     def sample(self) -> None:
@@ -74,7 +74,7 @@ class DiskUsage:
             try:
                 psutil.disk_usage(path)
                 self.paths.append(path)
-            except Exception as e:
+            except Exception as e:
                 termwarn(f"Could not access disk path {path}: {e}", repeat=False)

     def sample(self) -> None:
@@ -198,7 +198,7 @@ class Disk:
                 "total": total,
                 "used": used,
             }
-        except Exception as e:
+        except Exception as e:
             termwarn(f"Could not access disk path {disk_path}: {e}", repeat=False)

         return {self.name: disk_metrics}
wandb/sdk/internal/system/assets/gpu.py
CHANGED
@@ -377,8 +377,8 @@ class GPU:
             return True
         except pynvml.NVMLError_LibraryNotFound:  # type: ignore
             return False
-        except Exception
-            logger.
+        except Exception:
+            logger.exception("Error initializing NVML.")
             return False

     def start(self) -> None:
@@ -410,7 +410,7 @@ class GPU:

         except pynvml.NVMLError:
             pass
-        except Exception
-            logger.
+        except Exception:
+            logger.exception("Error Probing GPU.")

         return info
wandb/sdk/internal/system/assets/gpu_amd.py
CHANGED
@@ -104,8 +104,8 @@ class GPUAMDStats:
             if cards:
                 self.samples.append(cards)

-        except (OSError, ValueError, TypeError, subprocess.CalledProcessError)
-            logger.exception(
+        except (OSError, ValueError, TypeError, subprocess.CalledProcessError):
+            logger.exception("GPU stats error")

     def clear(self) -> None:
         self.samples.clear()
@@ -228,6 +228,6 @@ class GPUAMD:
                 for key in stats.keys()
                 if key.startswith("card")
             ]
-        except Exception
-            logger.exception(
+        except Exception:
+            logger.exception("GPUAMD probe error")
         return info
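The AMD sampler above keeps a narrowed exception tuple covering the failure modes of shelling out to a stats binary and parsing its JSON output. A generic sketch of the same shape; the command and return type are stand-ins, not the actual rocm-smi invocation:

import json
import logging
import subprocess
from typing import Any, Dict, List

logger = logging.getLogger(__name__)

def sample_stats(cmd: List[str]) -> Dict[str, Any]:
    try:
        out = subprocess.check_output(cmd, universal_newlines=True)
        return json.loads(out)
    except (OSError, ValueError, TypeError, subprocess.CalledProcessError):
        # OSError: binary missing or not executable; CalledProcessError:
        # nonzero exit; ValueError: covers json.JSONDecodeError, its subclass.
        logger.exception("GPU stats error")
        return {}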
wandb/sdk/internal/system/assets/interfaces.py
CHANGED
@@ -136,8 +136,8 @@ class MetricsMonitor:
                     logger.info(f"Process {metric.name} has exited.")
                     self._shutdown_event.set()
                     break
-            except Exception
-                logger.
+            except Exception:
+                logger.exception("Failed to sample metric.")
             self._shutdown_event.wait(self.sampling_interval)
             if self._shutdown_event.is_set():
                 break
@@ -153,8 +153,8 @@ class MetricsMonitor:
                 # aggregated_metrics = wandb.util.merge_dicts(
                 #     aggregated_metrics, metric.serialize()
                 # )
-            except Exception
-                logger.
+            except Exception:
+                logger.exception("Failed to serialize metric.")
         return aggregated_metrics

     def publish(self) -> None:
@@ -165,8 +165,8 @@ class MetricsMonitor:
             self._interface.publish_stats(aggregated_metrics)
             for metric in self.metrics:
                 metric.clear()
-        except Exception
-            logger.
+        except Exception:
+            logger.exception("Failed to publish metrics.")

     def start(self) -> None:
         if (self._process is not None) or self._shutdown_event.is_set():
|