wandb 0.17.8rc1__py3-none-any.whl → 0.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package_readme.md +47 -53
- wandb/__init__.py +8 -7
- wandb/__init__.pyi +134 -14
- wandb/analytics/sentry.py +1 -1
- wandb/bin/nvidia_gpu_stats +0 -0
- wandb/cli/cli.py +15 -10
- wandb/data_types.py +1 -0
- wandb/env.py +4 -3
- wandb/integration/keras/__init__.py +2 -5
- wandb/integration/keras/callbacks/metrics_logger.py +10 -4
- wandb/integration/keras/callbacks/model_checkpoint.py +0 -5
- wandb/integration/keras/keras.py +12 -1
- wandb/integration/openai/fine_tuning.py +5 -5
- wandb/integration/tensorboard/log.py +1 -4
- wandb/jupyter.py +18 -3
- wandb/proto/v3/wandb_internal_pb2.py +238 -228
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v4/wandb_internal_pb2.py +230 -228
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +12 -12
- wandb/proto/v5/wandb_internal_pb2.py +230 -228
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +12 -12
- wandb/proto/wandb_deprecated.py +4 -0
- wandb/sdk/__init__.py +1 -1
- wandb/sdk/artifacts/_validators.py +45 -0
- wandb/sdk/artifacts/artifact.py +109 -56
- wandb/sdk/artifacts/artifact_manifest_entry.py +10 -2
- wandb/sdk/artifacts/artifact_saver.py +6 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +31 -0
- wandb/sdk/interface/interface.py +7 -4
- wandb/sdk/internal/internal_api.py +16 -6
- wandb/sdk/internal/sender.py +1 -0
- wandb/sdk/internal/system/assets/gpu_amd.py +11 -2
- wandb/sdk/internal/system/assets/trainium.py +2 -1
- wandb/sdk/internal/tb_watcher.py +1 -1
- wandb/sdk/launch/inputs/internal.py +2 -2
- wandb/sdk/lib/_settings_toposort_generated.py +3 -3
- wandb/sdk/service/server_sock.py +1 -1
- wandb/sdk/service/service.py +6 -2
- wandb/sdk/wandb_init.py +12 -3
- wandb/sdk/wandb_login.py +1 -0
- wandb/sdk/wandb_manager.py +0 -3
- wandb/sdk/wandb_require.py +7 -2
- wandb/sdk/wandb_run.py +40 -14
- wandb/sdk/wandb_settings.py +32 -12
- wandb/sdk/wandb_setup.py +3 -0
- {wandb-0.17.8rc1.dist-info → wandb-0.18.0.dist-info}/METADATA +52 -58
- {wandb-0.17.8rc1.dist-info → wandb-0.18.0.dist-info}/RECORD +54 -54
- wandb/testing/relay.py +0 -874
- /wandb/{viz.py → sdk/lib/viz.py} +0 -0
- {wandb-0.17.8rc1.dist-info → wandb-0.18.0.dist-info}/WHEEL +0 -0
- {wandb-0.17.8rc1.dist-info → wandb-0.18.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.17.8rc1.dist-info → wandb-0.18.0.dist-info}/licenses/LICENSE +0 -0
@@ -19,6 +19,12 @@ if TYPE_CHECKING:
|
|
19
19
|
from wandb.sdk.artifacts.artifact import Artifact
|
20
20
|
|
21
21
|
|
22
|
+
class _GCSIsADirectoryError(Exception):
|
23
|
+
"""Raised when we try to download a GCS folder."""
|
24
|
+
|
25
|
+
pass
|
26
|
+
|
27
|
+
|
22
28
|
class GCSHandler(StorageHandler):
|
23
29
|
_client: Optional["gcs_module.client.Client"]
|
24
30
|
|
@@ -70,6 +76,11 @@ class GCSHandler(StorageHandler):
|
|
70
76
|
bucket, key, _ = self._parse_uri(manifest_entry.ref)
|
71
77
|
version = manifest_entry.extra.get("versionID")
|
72
78
|
|
79
|
+
if self._is_dir(manifest_entry):
|
80
|
+
raise _GCSIsADirectoryError(
|
81
|
+
f"Unable to download GCS folder {manifest_entry.ref!r}, skipping"
|
82
|
+
)
|
83
|
+
|
73
84
|
obj = None
|
74
85
|
# First attempt to get the generation specified, this will return None if versioning is not enabled
|
75
86
|
if version is not None:
|
@@ -135,6 +146,7 @@ class GCSHandler(StorageHandler):
|
|
135
146
|
entries = [
|
136
147
|
self._entry_from_obj(obj, path, name, prefix=key, multi=multi)
|
137
148
|
for obj in objects
|
149
|
+
if not obj.name.endswith("/")
|
138
150
|
]
|
139
151
|
if start_time is not None:
|
140
152
|
termlog("Done. %.1fs" % (time.time() - start_time), prefix=False)
|
@@ -193,3 +205,22 @@ class GCSHandler(StorageHandler):
|
|
193
205
|
size=obj.size,
|
194
206
|
extra={"versionID": obj.generation},
|
195
207
|
)
|
208
|
+
|
209
|
+
def _is_dir(
|
210
|
+
self,
|
211
|
+
manifest_entry: ArtifactManifestEntry,
|
212
|
+
) -> bool:
|
213
|
+
assert self._client is not None
|
214
|
+
assert manifest_entry.ref is not None
|
215
|
+
bucket, key, _ = self._parse_uri(manifest_entry.ref)
|
216
|
+
bucket_obj = self._client.bucket(bucket)
|
217
|
+
# A gcs bucket key should end with a forward slash on gcloud, but
|
218
|
+
# we save these refs without the forward slash in the manifest entry
|
219
|
+
# so we check the size and extension, make sure its not referring to
|
220
|
+
# an actual file with this reference, and that the ref with the slash
|
221
|
+
# exists on gcloud
|
222
|
+
return key.endswith("/") or (
|
223
|
+
not (manifest_entry.size or PurePosixPath(key).suffix)
|
224
|
+
and bucket_obj.get_blob(key) is None
|
225
|
+
and bucket_obj.get_blob(f"{key}/") is not None
|
226
|
+
)
|
wandb/sdk/interface/interface.py
CHANGED
@@ -519,6 +519,7 @@ class InterfaceBase:
|
|
519
519
|
run: "Run",
|
520
520
|
artifact: "Artifact",
|
521
521
|
aliases: Iterable[str],
|
522
|
+
tags: Optional[Iterable[str]] = None,
|
522
523
|
history_step: Optional[int] = None,
|
523
524
|
is_user_created: bool = False,
|
524
525
|
use_after_commit: bool = False,
|
@@ -532,8 +533,9 @@ class InterfaceBase:
|
|
532
533
|
proto_artifact.user_created = is_user_created
|
533
534
|
proto_artifact.use_after_commit = use_after_commit
|
534
535
|
proto_artifact.finalize = finalize
|
535
|
-
|
536
|
-
|
536
|
+
|
537
|
+
proto_artifact.aliases.extend(aliases or [])
|
538
|
+
proto_artifact.tags.extend(tags or [])
|
537
539
|
|
538
540
|
log_artifact = pb.LogArtifactRequest()
|
539
541
|
log_artifact.artifact.CopyFrom(proto_artifact)
|
@@ -577,6 +579,7 @@ class InterfaceBase:
|
|
577
579
|
run: "Run",
|
578
580
|
artifact: "Artifact",
|
579
581
|
aliases: Iterable[str],
|
582
|
+
tags: Optional[Iterable[str]] = None,
|
580
583
|
is_user_created: bool = False,
|
581
584
|
use_after_commit: bool = False,
|
582
585
|
finalize: bool = True,
|
@@ -589,8 +592,8 @@ class InterfaceBase:
|
|
589
592
|
proto_artifact.user_created = is_user_created
|
590
593
|
proto_artifact.use_after_commit = use_after_commit
|
591
594
|
proto_artifact.finalize = finalize
|
592
|
-
|
593
|
-
|
595
|
+
proto_artifact.aliases.extend(aliases or [])
|
596
|
+
proto_artifact.tags.extend(tags or [])
|
594
597
|
self._publish_artifact(proto_artifact)
|
595
598
|
|
596
599
|
@abstractmethod
|
@@ -3551,7 +3551,7 @@ class Api:
|
|
3551
3551
|
_id: Optional[str] = response["createArtifactType"]["artifactType"]["id"]
|
3552
3552
|
return _id
|
3553
3553
|
|
3554
|
-
def server_artifact_introspection(self) -> List:
|
3554
|
+
def server_artifact_introspection(self) -> List[str]:
|
3555
3555
|
query_string = """
|
3556
3556
|
query ProbeServerArtifact {
|
3557
3557
|
ArtifactInfoType: __type(name:"Artifact") {
|
@@ -3572,7 +3572,7 @@ class Api:
|
|
3572
3572
|
|
3573
3573
|
return self.server_artifact_fields_info
|
3574
3574
|
|
3575
|
-
def server_create_artifact_introspection(self) -> List:
|
3575
|
+
def server_create_artifact_introspection(self) -> List[str]:
|
3576
3576
|
query_string = """
|
3577
3577
|
query ProbeServerCreateArtifactInput {
|
3578
3578
|
CreateArtifactInputInfoType: __type(name:"CreateArtifactInput") {
|
@@ -3627,6 +3627,10 @@ class Api:
|
|
3627
3627
|
types += "$ttlDurationSeconds: Int64,"
|
3628
3628
|
values += "ttlDurationSeconds: $ttlDurationSeconds,"
|
3629
3629
|
|
3630
|
+
if "tags" in fields:
|
3631
|
+
types += "$tags: [TagInput!],"
|
3632
|
+
values += "tags: $tags,"
|
3633
|
+
|
3630
3634
|
query_template = """
|
3631
3635
|
mutation CreateArtifact(
|
3632
3636
|
$artifactTypeName: String!,
|
@@ -3686,18 +3690,25 @@ class Api:
|
|
3686
3690
|
metadata: Optional[Dict] = None,
|
3687
3691
|
ttl_duration_seconds: Optional[int] = None,
|
3688
3692
|
aliases: Optional[List[Dict[str, str]]] = None,
|
3693
|
+
tags: Optional[List[Dict[str, str]]] = None,
|
3689
3694
|
distributed_id: Optional[str] = None,
|
3690
3695
|
is_user_created: Optional[bool] = False,
|
3691
3696
|
history_step: Optional[int] = None,
|
3692
3697
|
) -> Tuple[Dict, Dict]:
|
3693
3698
|
fields = self.server_create_artifact_introspection()
|
3694
3699
|
artifact_fields = self.server_artifact_introspection()
|
3695
|
-
if "ttlIsInherited" not in artifact_fields and ttl_duration_seconds:
|
3700
|
+
if ("ttlIsInherited" not in artifact_fields) and ttl_duration_seconds:
|
3696
3701
|
wandb.termwarn(
|
3697
3702
|
"Server not compatible with setting Artifact TTLs, please upgrade the server to use Artifact TTL"
|
3698
3703
|
)
|
3699
3704
|
# ttlDurationSeconds is only usable if ttlIsInherited is also present
|
3700
3705
|
ttl_duration_seconds = None
|
3706
|
+
if ("tags" not in artifact_fields) and tags:
|
3707
|
+
wandb.termwarn(
|
3708
|
+
"Server not compatible with Artifact tags. "
|
3709
|
+
"To use Artifact tags, please upgrade the server to v0.85 or higher."
|
3710
|
+
)
|
3711
|
+
|
3701
3712
|
query_template = self._get_create_artifact_mutation(
|
3702
3713
|
fields, history_step, distributed_id
|
3703
3714
|
)
|
@@ -3706,8 +3717,6 @@ class Api:
|
|
3706
3717
|
project_name = project_name or self.settings("project")
|
3707
3718
|
if not is_user_created:
|
3708
3719
|
run_name = run_name or self.current_run_id
|
3709
|
-
if aliases is None:
|
3710
|
-
aliases = []
|
3711
3720
|
|
3712
3721
|
mutation = gql(query_template)
|
3713
3722
|
response = self.gql(
|
@@ -3722,7 +3731,8 @@ class Api:
|
|
3722
3731
|
"sequenceClientID": sequence_client_id,
|
3723
3732
|
"digest": digest,
|
3724
3733
|
"description": description,
|
3725
|
-
"aliases":
|
3734
|
+
"aliases": list(aliases or []),
|
3735
|
+
"tags": list(tags or []),
|
3726
3736
|
"metadata": json.dumps(util.make_safe_for_json(metadata))
|
3727
3737
|
if metadata
|
3728
3738
|
else None,
|
wandb/sdk/internal/sender.py
CHANGED
@@ -1589,6 +1589,7 @@ class SendManager:
|
|
1589
1589
|
ttl_duration_seconds=artifact.ttl_duration_seconds or None,
|
1590
1590
|
description=artifact.description or None,
|
1591
1591
|
aliases=artifact.aliases,
|
1592
|
+
tags=artifact.tags,
|
1592
1593
|
use_after_commit=artifact.use_after_commit,
|
1593
1594
|
distributed_id=artifact.distributed_id,
|
1594
1595
|
finalize=artifact.finalize,
|
@@ -183,8 +183,17 @@ class GPUAMD:
|
|
183
183
|
|
184
184
|
can_read_rocm_smi = False
|
185
185
|
try:
|
186
|
-
|
187
|
-
|
186
|
+
# try to read stats from rocm-smi and parse them
|
187
|
+
raw_stats = get_rocm_smi_stats()
|
188
|
+
card_keys = [
|
189
|
+
key for key in sorted(raw_stats.keys()) if key.startswith("card")
|
190
|
+
]
|
191
|
+
|
192
|
+
for card_key in card_keys:
|
193
|
+
card_stats = raw_stats[card_key]
|
194
|
+
parse_stats(card_stats)
|
195
|
+
|
196
|
+
can_read_rocm_smi = True
|
188
197
|
except Exception:
|
189
198
|
pass
|
190
199
|
|
@@ -224,7 +224,8 @@ class NeuronCoreStats:
|
|
224
224
|
for k, v in usage_breakdown["neuroncore_memory_usage"].items()
|
225
225
|
}
|
226
226
|
|
227
|
-
#
|
227
|
+
# When the training script is executed with torchrun,
|
228
|
+
# we only want to keep the relevant LOCAL_RANK stats
|
228
229
|
local_rank = int(os.environ.get("LOCAL_RANK", -1337))
|
229
230
|
if local_rank >= 0:
|
230
231
|
neuroncore_utilization = {
|
wandb/sdk/internal/tb_watcher.py
CHANGED
@@ -141,9 +141,9 @@ def _replace_refs_and_allofs(schema: dict, defs: Optional[dict]) -> dict:
|
|
141
141
|
ret: Dict[str, Any] = {}
|
142
142
|
if "$ref" in schema and defs:
|
143
143
|
# Reference found, replace it with its definition
|
144
|
-
def_key = schema
|
144
|
+
def_key = schema.pop("$ref").split("#/$defs/")[1]
|
145
145
|
# Also run recursive replacement in case a ref contains more refs
|
146
|
-
|
146
|
+
ret = _replace_refs_and_allofs(defs.pop(def_key), defs)
|
147
147
|
for key, val in schema.items():
|
148
148
|
if isinstance(val, dict):
|
149
149
|
# Step into dicts recursively
|
@@ -62,7 +62,7 @@ _Setting = Literal[
|
|
62
62
|
"_proxies",
|
63
63
|
"_python",
|
64
64
|
"_runqueue_item_id",
|
65
|
-
"
|
65
|
+
"_require_legacy_service",
|
66
66
|
"_save_requirements",
|
67
67
|
"_service_transport",
|
68
68
|
"_service_wait",
|
@@ -70,6 +70,7 @@ _Setting = Literal[
|
|
70
70
|
"_start_datetime",
|
71
71
|
"_start_time",
|
72
72
|
"_stats_pid",
|
73
|
+
"_stats_sampling_interval",
|
73
74
|
"_stats_sample_rate_seconds",
|
74
75
|
"_stats_samples_to_average",
|
75
76
|
"_stats_join_assets",
|
@@ -173,8 +174,6 @@ _Setting = Literal[
|
|
173
174
|
"sync_dir",
|
174
175
|
"sync_file",
|
175
176
|
"sync_symlink_latest",
|
176
|
-
"system_sample",
|
177
|
-
"system_sample_seconds",
|
178
177
|
"table_raise_on_max_row_limit_exceeded",
|
179
178
|
"timespec",
|
180
179
|
"tmp_dir",
|
@@ -186,6 +185,7 @@ SETTINGS_TOPOLOGICALLY_SORTED: Final[Tuple[_Setting, ...]] = (
|
|
186
185
|
"_service_wait",
|
187
186
|
"_stats_sample_rate_seconds",
|
188
187
|
"_stats_samples_to_average",
|
188
|
+
"_stats_sampling_interval",
|
189
189
|
"anonymous",
|
190
190
|
"api_key",
|
191
191
|
"base_url",
|
wandb/sdk/service/server_sock.py
CHANGED
@@ -199,7 +199,6 @@ class SockAcceptThread(threading.Thread):
|
|
199
199
|
self._clients = ClientDict()
|
200
200
|
|
201
201
|
def run(self) -> None:
|
202
|
-
self._sock.listen(5)
|
203
202
|
read_threads = []
|
204
203
|
|
205
204
|
while not self._stopped.is_set():
|
@@ -254,6 +253,7 @@ class SocketServer:
|
|
254
253
|
|
255
254
|
def start(self) -> None:
|
256
255
|
self._bind()
|
256
|
+
self._sock.listen(5)
|
257
257
|
self._thread = SockAcceptThread(sock=self._sock, mux=self._mux)
|
258
258
|
self._thread.start()
|
259
259
|
# Note: Uncomment to figure out what thread is not exiting properly
|
wandb/sdk/service/service.py
CHANGED
@@ -15,7 +15,11 @@ import time
|
|
15
15
|
from typing import TYPE_CHECKING, Any, Dict, Optional
|
16
16
|
|
17
17
|
from wandb import _sentry, termlog
|
18
|
-
from wandb.env import
|
18
|
+
from wandb.env import (
|
19
|
+
core_debug,
|
20
|
+
core_error_reporting_enabled,
|
21
|
+
is_require_legacy_service,
|
22
|
+
)
|
19
23
|
from wandb.errors import Error, WandbCoreNotAvailableError
|
20
24
|
from wandb.sdk.lib.wburls import wburls
|
21
25
|
from wandb.util import get_core_path, get_module
|
@@ -164,7 +168,7 @@ class _Service:
|
|
164
168
|
|
165
169
|
service_args = []
|
166
170
|
|
167
|
-
if
|
171
|
+
if not is_require_legacy_service():
|
168
172
|
try:
|
169
173
|
core_path = get_core_path()
|
170
174
|
except WandbCoreNotAvailableError as e:
|
wandb/sdk/wandb_init.py
CHANGED
@@ -273,8 +273,9 @@ class _WandbInit:
|
|
273
273
|
|
274
274
|
tensorboard = kwargs.pop("tensorboard", None)
|
275
275
|
sync_tensorboard = kwargs.pop("sync_tensorboard", None)
|
276
|
-
if tensorboard or sync_tensorboard
|
277
|
-
wandb.tensorboard
|
276
|
+
if tensorboard or sync_tensorboard:
|
277
|
+
if len(wandb.patched["tensorboard"]) == 0:
|
278
|
+
wandb.tensorboard.patch() # type: ignore
|
278
279
|
with telemetry.context(obj=self._init_telemetry_obj) as tel:
|
279
280
|
tel.feature.tensorboard_sync = True
|
280
281
|
|
@@ -566,6 +567,7 @@ class _WandbInit:
|
|
566
567
|
"watch",
|
567
568
|
"unwatch",
|
568
569
|
"upsert_artifact",
|
570
|
+
"_finish",
|
569
571
|
):
|
570
572
|
setattr(drun, symbol, lambda *_, **__: None) # type: ignore
|
571
573
|
# attributes
|
@@ -736,7 +738,7 @@ class _WandbInit:
|
|
736
738
|
tel.feature.flow_control_disabled = True
|
737
739
|
if self.settings._flow_control_custom:
|
738
740
|
tel.feature.flow_control_custom = True
|
739
|
-
if self.settings.
|
741
|
+
if not self.settings._require_legacy_service:
|
740
742
|
tel.feature.core = True
|
741
743
|
if self.settings._shared:
|
742
744
|
wandb.termwarn(
|
@@ -1153,6 +1155,7 @@ def init(
|
|
1153
1155
|
mode if a user isn't logged in to W&B. (default: `False`)
|
1154
1156
|
sync_tensorboard: (bool, optional) Synchronize wandb logs from tensorboard or
|
1155
1157
|
tensorboardX and save the relevant events file. (default: `False`)
|
1158
|
+
tensorboard: (bool, optional) Alias for `sync_tensorboard`, deprecated.
|
1156
1159
|
monitor_gym: (bool, optional) Automatically log videos of environment when
|
1157
1160
|
using OpenAI Gym. (default: `False`)
|
1158
1161
|
See [our guide to this integration](https://docs.wandb.com/guides/integrations/openai-gym).
|
@@ -1166,6 +1169,12 @@ def init(
|
|
1166
1169
|
a moment in a previous run to fork a new run from. Creates a new run that picks up
|
1167
1170
|
logging history from the specified run at the specified moment. The target run must
|
1168
1171
|
be in the current project. Example: `fork_from="my-run-id?_step=1234"`.
|
1172
|
+
resume_from: (str, optional) A string with the format {run_id}?_step={step} describing
|
1173
|
+
a moment in a previous run to resume a run from. This allows users to truncate
|
1174
|
+
the history logged to a run at an intermediate step and resume logging from that step.
|
1175
|
+
It uses run forking under the hood. The target run must be in the
|
1176
|
+
current project. Example: `resume_from="my-run-id?_step=1234"`.
|
1177
|
+
settings: (dict, wandb.Settings, optional) Settings to use for this run. (default: None)
|
1169
1178
|
|
1170
1179
|
Examples:
|
1171
1180
|
### Set where the run is logged
|
wandb/sdk/wandb_login.py
CHANGED
@@ -64,6 +64,7 @@ def login(
|
|
64
64
|
"allow", only create an anonymous user if the user
|
65
65
|
isn't already logged in. If set to "never", never log a
|
66
66
|
user anonymously. Default set to "never".
|
67
|
+
key: (string, optional) The API key to use.
|
67
68
|
relogin: (bool, optional) If true, will re-prompt for API key.
|
68
69
|
host: (string, optional) The host to connect to.
|
69
70
|
force: (bool, optional) If true, will force a relogin.
|
wandb/sdk/wandb_manager.py
CHANGED
@@ -13,7 +13,6 @@ import wandb
|
|
13
13
|
from wandb import env, trigger
|
14
14
|
from wandb.errors import Error
|
15
15
|
from wandb.sdk.lib.exit_hooks import ExitHooks
|
16
|
-
from wandb.sdk.lib.import_hooks import unregister_all_post_import_hooks
|
17
16
|
|
18
17
|
if TYPE_CHECKING:
|
19
18
|
from wandb.proto import wandb_settings_pb2
|
@@ -163,8 +162,6 @@ class _Manager:
|
|
163
162
|
This sends a teardown record to the process. An exception is raised if
|
164
163
|
the process has already been shut down.
|
165
164
|
"""
|
166
|
-
unregister_all_post_import_hooks()
|
167
|
-
|
168
165
|
if self._atexit_lambda:
|
169
166
|
atexit.unregister(self._atexit_lambda)
|
170
167
|
self._atexit_lambda = None
|
wandb/sdk/wandb_require.py
CHANGED
@@ -13,7 +13,7 @@ import os
|
|
13
13
|
from typing import Optional, Sequence, Union
|
14
14
|
|
15
15
|
import wandb
|
16
|
-
from wandb.env import
|
16
|
+
from wandb.env import _REQUIRE_LEGACY_SERVICE
|
17
17
|
from wandb.errors import UnsupportedError
|
18
18
|
from wandb.sdk import wandb_run
|
19
19
|
from wandb.sdk.lib.wburls import wburls
|
@@ -41,7 +41,12 @@ class _Requires:
|
|
41
41
|
self._require_service()
|
42
42
|
|
43
43
|
def require_core(self) -> None:
|
44
|
-
|
44
|
+
wandb.termwarn(
|
45
|
+
"`wandb.require('core')` is redundant as it is now the default behavior."
|
46
|
+
)
|
47
|
+
|
48
|
+
def require_legacy_service(self) -> None:
|
49
|
+
os.environ[_REQUIRE_LEGACY_SERVICE] = "true"
|
45
50
|
|
46
51
|
def apply(self) -> None:
|
47
52
|
"""Call require_* method for supported features."""
|
wandb/sdk/wandb_run.py
CHANGED
@@ -57,6 +57,7 @@ from wandb.sdk.lib.import_hooks import (
|
|
57
57
|
unregister_post_import_hook,
|
58
58
|
)
|
59
59
|
from wandb.sdk.lib.paths import FilePathStr, LogicalPath, StrPath
|
60
|
+
from wandb.sdk.lib.viz import CustomChart, Visualize, custom_chart
|
60
61
|
from wandb.util import (
|
61
62
|
_is_artifact_object,
|
62
63
|
_is_artifact_string,
|
@@ -66,9 +67,9 @@ from wandb.util import (
|
|
66
67
|
add_import_hook,
|
67
68
|
parse_artifact_string,
|
68
69
|
)
|
69
|
-
from wandb.viz import CustomChart, Visualize, custom_chart
|
70
70
|
|
71
71
|
from . import wandb_config, wandb_metric, wandb_summary
|
72
|
+
from .artifacts._validators import validate_aliases, validate_tags
|
72
73
|
from .data_types._dtypes import TypeRegistry
|
73
74
|
from .interface.interface import FilesDict, GlobStr, InterfaceBase, PolicyName
|
74
75
|
from .interface.summary_record import SummaryRecord
|
@@ -1654,7 +1655,8 @@ class Run:
|
|
1654
1655
|
if os.getpid() != self._init_pid or self._is_attached:
|
1655
1656
|
wandb.termwarn(
|
1656
1657
|
"Note that setting step in multiprocessing can result in data loss. "
|
1657
|
-
"Please
|
1658
|
+
"Please use `run.define_metric(...)` to define a custom metric "
|
1659
|
+
"to log your step values.",
|
1658
1660
|
repeat=False,
|
1659
1661
|
)
|
1660
1662
|
# if step is passed in when tensorboard_sync is used we honor the step passed
|
@@ -1662,8 +1664,9 @@ class Run:
|
|
1662
1664
|
# this history later on in publish_history()
|
1663
1665
|
if len(wandb.patched["tensorboard"]) > 0:
|
1664
1666
|
wandb.termwarn(
|
1665
|
-
"Step cannot be set when using syncing
|
1666
|
-
"Please
|
1667
|
+
"Step cannot be set when using tensorboard syncing. "
|
1668
|
+
"Please use `run.define_metric(...)` to define a custom metric "
|
1669
|
+
"to log your step values.",
|
1667
1670
|
repeat=False,
|
1668
1671
|
)
|
1669
1672
|
if step > self._step:
|
@@ -2785,6 +2788,14 @@ class Run:
|
|
2785
2788
|
self,
|
2786
2789
|
)
|
2787
2790
|
|
2791
|
+
if (summary and "best" in summary) or goal is not None:
|
2792
|
+
deprecate.deprecate(
|
2793
|
+
deprecate.Deprecated.run__define_metric_best_goal,
|
2794
|
+
"define_metric(summary='best', goal=...) is deprecated and will be removed. "
|
2795
|
+
"Use define_metric(summary='min') or define_metric(summary='max') instead.",
|
2796
|
+
self,
|
2797
|
+
)
|
2798
|
+
|
2788
2799
|
return self._define_metric(
|
2789
2800
|
name,
|
2790
2801
|
step_metric,
|
@@ -3101,6 +3112,7 @@ class Run:
|
|
3101
3112
|
name: Optional[str] = None,
|
3102
3113
|
type: Optional[str] = None,
|
3103
3114
|
aliases: Optional[List[str]] = None,
|
3115
|
+
tags: Optional[List[str]] = None,
|
3104
3116
|
) -> Artifact:
|
3105
3117
|
"""Declare an artifact as an output of a run.
|
3106
3118
|
|
@@ -3121,12 +3133,17 @@ class Run:
|
|
3121
3133
|
type: (str) The type of artifact to log, examples include `dataset`, `model`
|
3122
3134
|
aliases: (list, optional) Aliases to apply to this artifact,
|
3123
3135
|
defaults to `["latest"]`
|
3136
|
+
tags: (list, optional) Tags to apply to this artifact, if any.
|
3124
3137
|
|
3125
3138
|
Returns:
|
3126
3139
|
An `Artifact` object.
|
3127
3140
|
"""
|
3128
3141
|
return self._log_artifact(
|
3129
|
-
artifact_or_path,
|
3142
|
+
artifact_or_path,
|
3143
|
+
name=name,
|
3144
|
+
type=type,
|
3145
|
+
aliases=aliases,
|
3146
|
+
tags=tags,
|
3130
3147
|
)
|
3131
3148
|
|
3132
3149
|
@_run_decorator._noop_on_finish()
|
@@ -3243,6 +3260,7 @@ class Run:
|
|
3243
3260
|
name: Optional[str] = None,
|
3244
3261
|
type: Optional[str] = None,
|
3245
3262
|
aliases: Optional[List[str]] = None,
|
3263
|
+
tags: Optional[List[str]] = None,
|
3246
3264
|
distributed_id: Optional[str] = None,
|
3247
3265
|
finalize: bool = True,
|
3248
3266
|
is_user_created: bool = False,
|
@@ -3253,13 +3271,17 @@ class Run:
|
|
3253
3271
|
wandb.termwarn(
|
3254
3272
|
"Artifacts logged anonymously cannot be claimed and expire after 7 days."
|
3255
3273
|
)
|
3274
|
+
|
3256
3275
|
if not finalize and distributed_id is None:
|
3257
3276
|
raise TypeError("Must provide distributed_id if artifact is not finalize")
|
3277
|
+
|
3258
3278
|
if aliases is not None:
|
3259
|
-
|
3260
|
-
|
3261
|
-
|
3262
|
-
|
3279
|
+
aliases = validate_aliases(aliases)
|
3280
|
+
|
3281
|
+
# Check if artifact tags are supported
|
3282
|
+
if tags is not None:
|
3283
|
+
tags = validate_tags(tags)
|
3284
|
+
|
3263
3285
|
artifact, aliases = self._prepare_artifact(
|
3264
3286
|
artifact_or_path, name, type, aliases
|
3265
3287
|
)
|
@@ -3271,6 +3293,7 @@ class Run:
|
|
3271
3293
|
self,
|
3272
3294
|
artifact,
|
3273
3295
|
aliases,
|
3296
|
+
tags,
|
3274
3297
|
self.step,
|
3275
3298
|
finalize=finalize,
|
3276
3299
|
is_user_created=is_user_created,
|
@@ -3282,6 +3305,7 @@ class Run:
|
|
3282
3305
|
self,
|
3283
3306
|
artifact,
|
3284
3307
|
aliases,
|
3308
|
+
tags,
|
3285
3309
|
finalize=finalize,
|
3286
3310
|
is_user_created=is_user_created,
|
3287
3311
|
use_after_commit=use_after_commit,
|
@@ -3291,6 +3315,7 @@ class Run:
|
|
3291
3315
|
self,
|
3292
3316
|
artifact,
|
3293
3317
|
aliases,
|
3318
|
+
tags,
|
3294
3319
|
finalize=finalize,
|
3295
3320
|
is_user_created=is_user_created,
|
3296
3321
|
use_after_commit=use_after_commit,
|
@@ -3364,6 +3389,7 @@ class Run:
|
|
3364
3389
|
"You must pass an instance of wandb.Artifact or a "
|
3365
3390
|
"valid file path to log_artifact"
|
3366
3391
|
)
|
3392
|
+
|
3367
3393
|
artifact.finalize()
|
3368
3394
|
return artifact, _resolve_aliases(aliases)
|
3369
3395
|
|
@@ -3502,7 +3528,7 @@ class Run:
|
|
3502
3528
|
registered_model_name: (str) - the name of the registered model that the model is to be linked to.
|
3503
3529
|
A registered model is a collection of model versions linked to the model registry, typically representing a
|
3504
3530
|
team's specific ML Task. The entity that this registered model belongs to will be derived from the run
|
3505
|
-
|
3531
|
+
name: (str, optional) - the name of the model artifact that files in 'path' will be logged to. This will
|
3506
3532
|
default to the basename of the path prepended with the current run id if not specified.
|
3507
3533
|
aliases: (List[str], optional) - alias(es) that will only be applied on this linked artifact
|
3508
3534
|
inside the registered model.
|
@@ -4234,13 +4260,13 @@ class Run:
|
|
4234
4260
|
printer: Union["PrinterTerm", "PrinterJupyter"],
|
4235
4261
|
) -> None:
|
4236
4262
|
"""Prints a message advertising the upcoming core release."""
|
4237
|
-
if quiet or settings.
|
4263
|
+
if quiet or not settings._require_legacy_service:
|
4238
4264
|
return
|
4239
4265
|
|
4240
4266
|
printer.display(
|
4241
|
-
"The
|
4242
|
-
|
4243
|
-
"
|
4267
|
+
"The legacy backend is deprecated. In future versions, `wandb-core` will become "
|
4268
|
+
"the sole backend service, and the `wandb.require('legacy-service')` flag will be removed. "
|
4269
|
+
"For more information, visit https://wandb.me/wandb-core",
|
4244
4270
|
level="warn",
|
4245
4271
|
)
|
4246
4272
|
|