wandb 0.20.1rc20250604__py3-none-any.whl → 0.21.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +3 -6
- wandb/__init__.pyi +24 -23
- wandb/analytics/sentry.py +2 -2
- wandb/apis/importers/internals/internal.py +0 -3
- wandb/apis/internal.py +3 -0
- wandb/apis/paginator.py +17 -4
- wandb/apis/public/api.py +85 -4
- wandb/apis/public/artifacts.py +10 -8
- wandb/apis/public/files.py +5 -5
- wandb/apis/public/projects.py +44 -3
- wandb/apis/public/registries/{utils.py → _utils.py} +12 -12
- wandb/apis/public/registries/registries_search.py +2 -2
- wandb/apis/public/registries/registry.py +19 -18
- wandb/apis/public/reports.py +64 -8
- wandb/apis/public/runs.py +16 -23
- wandb/automations/__init__.py +10 -10
- wandb/automations/_filters/run_metrics.py +0 -2
- wandb/automations/_utils.py +0 -2
- wandb/automations/actions.py +0 -2
- wandb/automations/automations.py +0 -2
- wandb/automations/events.py +0 -2
- wandb/bin/gpu_stats +0 -0
- wandb/cli/beta.py +1 -7
- wandb/cli/cli.py +0 -30
- wandb/env.py +0 -6
- wandb/integration/catboost/catboost.py +6 -2
- wandb/integration/kfp/kfp_patch.py +3 -1
- wandb/integration/sb3/sb3.py +3 -3
- wandb/integration/ultralytics/callback.py +6 -2
- wandb/plot/__init__.py +2 -0
- wandb/plot/bar.py +30 -29
- wandb/plot/confusion_matrix.py +75 -71
- wandb/plot/histogram.py +26 -25
- wandb/plot/line.py +33 -32
- wandb/plot/line_series.py +100 -103
- wandb/plot/pr_curve.py +33 -32
- wandb/plot/roc_curve.py +38 -38
- wandb/plot/scatter.py +27 -27
- wandb/proto/v3/wandb_internal_pb2.py +366 -385
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +352 -356
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v5/wandb_internal_pb2.py +352 -356
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v6/wandb_internal_pb2.py +352 -356
- wandb/proto/v6/wandb_settings_pb2.py +2 -2
- wandb/proto/v6/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/artifacts/_generated/__init__.py +12 -1
- wandb/sdk/artifacts/_generated/input_types.py +20 -2
- wandb/sdk/artifacts/_generated/link_artifact.py +21 -0
- wandb/sdk/artifacts/_generated/operations.py +9 -0
- wandb/sdk/artifacts/_validators.py +40 -2
- wandb/sdk/artifacts/artifact.py +163 -21
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +42 -1
- wandb/sdk/backend/backend.py +1 -1
- wandb/sdk/data_types/base_types/media.py +9 -7
- wandb/sdk/data_types/base_types/wb_value.py +6 -6
- wandb/sdk/data_types/saved_model.py +3 -3
- wandb/sdk/data_types/table.py +41 -41
- wandb/sdk/data_types/trace_tree.py +12 -12
- wandb/sdk/interface/interface.py +8 -19
- wandb/sdk/interface/interface_shared.py +7 -16
- wandb/sdk/internal/datastore.py +18 -18
- wandb/sdk/internal/handler.py +4 -74
- wandb/sdk/internal/internal_api.py +54 -0
- wandb/sdk/internal/sender.py +23 -3
- wandb/sdk/internal/sender_config.py +9 -0
- wandb/sdk/launch/_project_spec.py +3 -3
- wandb/sdk/launch/agent/agent.py +3 -3
- wandb/sdk/launch/agent/job_status_tracker.py +3 -1
- wandb/sdk/launch/utils.py +3 -3
- wandb/sdk/lib/console_capture.py +66 -19
- wandb/sdk/lib/printer.py +6 -7
- wandb/sdk/lib/progress.py +1 -3
- wandb/sdk/lib/service/ipc_support.py +13 -0
- wandb/sdk/lib/{service_connection.py → service/service_connection.py} +20 -56
- wandb/sdk/lib/service/service_port_file.py +105 -0
- wandb/sdk/lib/service/service_process.py +111 -0
- wandb/sdk/lib/service/service_token.py +164 -0
- wandb/sdk/lib/sock_client.py +8 -12
- wandb/sdk/wandb_init.py +1 -5
- wandb/sdk/wandb_require.py +9 -21
- wandb/sdk/wandb_run.py +23 -137
- wandb/sdk/wandb_settings.py +233 -80
- wandb/sdk/wandb_setup.py +2 -13
- {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/METADATA +1 -3
- {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/RECORD +93 -119
- wandb/sdk/internal/flow_control.py +0 -263
- wandb/sdk/internal/internal.py +0 -401
- wandb/sdk/internal/internal_util.py +0 -97
- wandb/sdk/internal/system/__init__.py +0 -0
- wandb/sdk/internal/system/assets/__init__.py +0 -25
- wandb/sdk/internal/system/assets/aggregators.py +0 -31
- wandb/sdk/internal/system/assets/asset_registry.py +0 -20
- wandb/sdk/internal/system/assets/cpu.py +0 -163
- wandb/sdk/internal/system/assets/disk.py +0 -210
- wandb/sdk/internal/system/assets/gpu.py +0 -416
- wandb/sdk/internal/system/assets/gpu_amd.py +0 -233
- wandb/sdk/internal/system/assets/interfaces.py +0 -205
- wandb/sdk/internal/system/assets/ipu.py +0 -177
- wandb/sdk/internal/system/assets/memory.py +0 -166
- wandb/sdk/internal/system/assets/network.py +0 -125
- wandb/sdk/internal/system/assets/open_metrics.py +0 -293
- wandb/sdk/internal/system/assets/tpu.py +0 -154
- wandb/sdk/internal/system/assets/trainium.py +0 -393
- wandb/sdk/internal/system/env_probe_helpers.py +0 -13
- wandb/sdk/internal/system/system_info.py +0 -248
- wandb/sdk/internal/system/system_monitor.py +0 -224
- wandb/sdk/internal/writer.py +0 -204
- wandb/sdk/lib/service_token.py +0 -93
- wandb/sdk/service/__init__.py +0 -0
- wandb/sdk/service/_startup_debug.py +0 -22
- wandb/sdk/service/port_file.py +0 -53
- wandb/sdk/service/server.py +0 -107
- wandb/sdk/service/server_sock.py +0 -286
- wandb/sdk/service/service.py +0 -252
- wandb/sdk/service/streams.py +0 -425
- wandb/sdk/wandb_metadata.py +0 -623
- {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/WHEEL +0 -0
- {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/licenses/LICENSE +0 -0
@@ -3,6 +3,7 @@
|
|
3
3
|
from __future__ import annotations
|
4
4
|
|
5
5
|
import os
|
6
|
+
import re
|
6
7
|
import time
|
7
8
|
from pathlib import PurePosixPath
|
8
9
|
from typing import TYPE_CHECKING, Sequence
|
@@ -52,10 +53,20 @@ class S3Handler(StorageHandler):
|
|
52
53
|
required="s3:// references requires the boto3 library, run pip install wandb[aws]",
|
53
54
|
lazy=False,
|
54
55
|
)
|
56
|
+
|
57
|
+
from botocore.client import Config # type: ignore
|
58
|
+
|
59
|
+
s3_endpoint = os.getenv("AWS_S3_ENDPOINT_URL")
|
60
|
+
config = (
|
61
|
+
Config(s3={"addressing_style": "virtual"})
|
62
|
+
if s3_endpoint and self._is_coreweave_endpoint(s3_endpoint)
|
63
|
+
else None
|
64
|
+
)
|
55
65
|
self._s3 = boto.session.Session().resource(
|
56
66
|
"s3",
|
57
|
-
endpoint_url=
|
67
|
+
endpoint_url=s3_endpoint,
|
58
68
|
region_name=os.getenv("AWS_REGION"),
|
69
|
+
config=config,
|
59
70
|
)
|
60
71
|
self._botocore = util.get_module("botocore")
|
61
72
|
return self._s3
|
@@ -296,3 +307,33 @@ class S3Handler(StorageHandler):
|
|
296
307
|
if hasattr(obj, "version_id") and obj.version_id and obj.version_id != "null":
|
297
308
|
extra["versionID"] = obj.version_id
|
298
309
|
return extra
|
310
|
+
|
311
|
+
_CW_LEGACY_NETLOC_REGEX: re.Pattern[str] = re.compile(
|
312
|
+
r"""
|
313
|
+
# accelerated endpoints like "accel-object.<region>.coreweave.com"
|
314
|
+
accel-object\.[a-z0-9-]+\.coreweave\.com
|
315
|
+
|
|
316
|
+
# URLs like "object.<region>.coreweave.com"
|
317
|
+
object\.[a-z0-9-]+\.coreweave\.com
|
318
|
+
""",
|
319
|
+
flags=re.VERBOSE,
|
320
|
+
)
|
321
|
+
|
322
|
+
def _is_coreweave_endpoint(self, endpoint_url: str) -> bool:
|
323
|
+
if not (url := endpoint_url.strip().rstrip("/")):
|
324
|
+
return False
|
325
|
+
|
326
|
+
# Only http://cwlota.com is supported using HTTP
|
327
|
+
if url == "http://cwlota.com":
|
328
|
+
return True
|
329
|
+
|
330
|
+
# Enforce HTTPS otherwise
|
331
|
+
https_url = url if url.startswith("https://") else f"https://{url}"
|
332
|
+
netloc = urlparse(https_url).netloc
|
333
|
+
return bool(
|
334
|
+
# Match for https://cwobject.com
|
335
|
+
(netloc == "cwobject.com")
|
336
|
+
or
|
337
|
+
# Check for legacy endpoints
|
338
|
+
self._CW_LEGACY_NETLOC_REGEX.fullmatch(netloc)
|
339
|
+
)
|
wandb/sdk/backend/backend.py
CHANGED
@@ -11,7 +11,7 @@ from wandb.sdk.interface.interface import InterfaceBase
|
|
11
11
|
from wandb.sdk.wandb_settings import Settings
|
12
12
|
|
13
13
|
if TYPE_CHECKING:
|
14
|
-
from wandb.sdk.lib import service_connection
|
14
|
+
from wandb.sdk.lib.service import service_connection
|
15
15
|
|
16
16
|
logger = logging.getLogger("wandb")
|
17
17
|
|
@@ -127,9 +127,9 @@ class Media(WBValue):
|
|
127
127
|
self._path = path
|
128
128
|
self._is_tmp = is_tmp
|
129
129
|
self._extension = extension
|
130
|
-
assert extension is None or path.endswith(
|
131
|
-
extension
|
132
|
-
)
|
130
|
+
assert extension is None or path.endswith(extension), (
|
131
|
+
f'Media file extension "{extension}" must occur at the end of path "{path}".'
|
132
|
+
)
|
133
133
|
|
134
134
|
with open(self._path, "rb") as f:
|
135
135
|
self._sha256 = hashlib.sha256(f.read()).hexdigest()
|
@@ -247,11 +247,13 @@ class Media(WBValue):
|
|
247
247
|
json_obj["_latest_artifact_path"] = artifact_entry_latest_url
|
248
248
|
|
249
249
|
if artifact_entry_url is None or self.is_bound():
|
250
|
-
assert self.is_bound(),
|
250
|
+
assert self.is_bound(), (
|
251
|
+
f"Value of type {type(self).__name__} must be bound to a run with bind_to_run() before being serialized to JSON."
|
252
|
+
)
|
251
253
|
|
252
|
-
assert (
|
253
|
-
|
254
|
-
)
|
254
|
+
assert self._run is run, (
|
255
|
+
"We don't support referring to media files across runs."
|
256
|
+
)
|
255
257
|
|
256
258
|
# The following two assertions are guaranteed to pass
|
257
259
|
# by definition is_bound, but are needed for
|
@@ -218,17 +218,17 @@ class WBValue:
|
|
218
218
|
def _set_artifact_source(
|
219
219
|
self, artifact: "Artifact", name: Optional[str] = None
|
220
220
|
) -> None:
|
221
|
-
assert (
|
222
|
-
self._artifact_source
|
223
|
-
)
|
221
|
+
assert self._artifact_source is None, (
|
222
|
+
f"Cannot update artifact_source. Existing source: {self._artifact_source.artifact}/{self._artifact_source.name}"
|
223
|
+
)
|
224
224
|
self._artifact_source = _WBValueArtifactSource(artifact, name)
|
225
225
|
|
226
226
|
def _set_artifact_target(
|
227
227
|
self, artifact: "Artifact", name: Optional[str] = None
|
228
228
|
) -> None:
|
229
|
-
assert (
|
230
|
-
self._artifact_target
|
231
|
-
)
|
229
|
+
assert self._artifact_target is None, (
|
230
|
+
f"Cannot update artifact_target. Existing target: {self._artifact_target.artifact}/{self._artifact_target.name}"
|
231
|
+
)
|
232
232
|
self._artifact_target = _WBValueArtifactTarget(artifact, name)
|
233
233
|
|
234
234
|
def _get_artifact_entry_ref_url(self) -> Optional[str]:
|
@@ -257,9 +257,9 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
|
|
257
257
|
self._model_obj = None
|
258
258
|
|
259
259
|
def _set_obj(self, model_obj: Any) -> None:
|
260
|
-
assert model_obj is not None and self._validate_obj(
|
261
|
-
model_obj
|
262
|
-
)
|
260
|
+
assert model_obj is not None and self._validate_obj(model_obj), (
|
261
|
+
f"Invalid model object {model_obj}"
|
262
|
+
)
|
263
263
|
self._model_obj = model_obj
|
264
264
|
|
265
265
|
def _dump(self, target_path: str) -> None:
|
wandb/sdk/data_types/table.py
CHANGED
@@ -203,8 +203,8 @@ class Table(Media):
|
|
203
203
|
This means you can embed `Images`, `Video`, `Audio`, and other sorts of rich, annotated media
|
204
204
|
directly in Tables, alongside other traditional scalar values.
|
205
205
|
|
206
|
-
This class is the primary class used to generate the
|
207
|
-
|
206
|
+
This class is the primary class used to generate the W&B Tables
|
207
|
+
https://docs.wandb.ai/guides/models/tables/.
|
208
208
|
"""
|
209
209
|
|
210
210
|
MAX_ROWS = 10000
|
@@ -292,9 +292,9 @@ class Table(Media):
|
|
292
292
|
self._init_from_list([], columns, optional, dtype)
|
293
293
|
|
294
294
|
def _validate_log_mode(self, log_mode):
|
295
|
-
assert (
|
296
|
-
log_mode
|
297
|
-
)
|
295
|
+
assert log_mode in _SUPPORTED_LOGGING_MODES, (
|
296
|
+
f"Invalid log_mode: {log_mode}. Must be one of {_SUPPORTED_LOGGING_MODES}"
|
297
|
+
)
|
298
298
|
|
299
299
|
@staticmethod
|
300
300
|
def _assert_valid_columns(columns):
|
@@ -314,9 +314,9 @@ class Table(Media):
|
|
314
314
|
self.add_data(*row)
|
315
315
|
|
316
316
|
def _init_from_ndarray(self, ndarray, columns, optional=True, dtype=None):
|
317
|
-
assert util.is_numpy_array(
|
318
|
-
ndarray
|
319
|
-
)
|
317
|
+
assert util.is_numpy_array(ndarray), (
|
318
|
+
"ndarray argument expects a `numpy.ndarray` object"
|
319
|
+
)
|
320
320
|
self.data = []
|
321
321
|
self._assert_valid_columns(columns)
|
322
322
|
self.columns = columns
|
@@ -325,9 +325,9 @@ class Table(Media):
|
|
325
325
|
self.add_data(*row)
|
326
326
|
|
327
327
|
def _init_from_dataframe(self, dataframe, columns, optional=True, dtype=None):
|
328
|
-
assert util.is_pandas_data_frame(
|
329
|
-
dataframe
|
330
|
-
)
|
328
|
+
assert util.is_pandas_data_frame(dataframe), (
|
329
|
+
"dataframe argument expects a `pandas.core.frame.DataFrame` object"
|
330
|
+
)
|
331
331
|
self.data = []
|
332
332
|
columns = list(dataframe.columns)
|
333
333
|
self._assert_valid_columns(columns)
|
@@ -440,17 +440,17 @@ class Table(Media):
|
|
440
440
|
is_fk = isinstance(wbtype, _ForeignKeyType)
|
441
441
|
is_fi = isinstance(wbtype, _ForeignIndexType)
|
442
442
|
if is_pk or is_fk or is_fi:
|
443
|
-
assert (
|
444
|
-
|
445
|
-
)
|
443
|
+
assert not optional, (
|
444
|
+
"Primary keys, foreign keys, and foreign indexes cannot be optional."
|
445
|
+
)
|
446
446
|
|
447
447
|
if (is_fk or is_fk) and id(wbtype.params["table"]) == id(self):
|
448
448
|
raise AssertionError("Cannot set a foreign table reference to same table.")
|
449
449
|
|
450
450
|
if is_pk:
|
451
|
-
assert (
|
452
|
-
self._pk_col is
|
453
|
-
)
|
451
|
+
assert self._pk_col is None, (
|
452
|
+
f"Cannot have multiple primary keys - {self._pk_col} is already set as the primary key."
|
453
|
+
)
|
454
454
|
|
455
455
|
# Update the column type
|
456
456
|
self._column_types.params["type_map"][col_name] = wbtype
|
@@ -464,21 +464,21 @@ class Table(Media):
|
|
464
464
|
|
465
465
|
def _eq_debug(self, other, should_assert=False):
|
466
466
|
eq = isinstance(other, Table)
|
467
|
-
assert (
|
468
|
-
|
469
|
-
)
|
467
|
+
assert not should_assert or eq, (
|
468
|
+
f"Found type {other.__class__}, expected {Table}"
|
469
|
+
)
|
470
470
|
eq = eq and len(self.data) == len(other.data)
|
471
|
-
assert (
|
472
|
-
|
473
|
-
)
|
471
|
+
assert not should_assert or eq, (
|
472
|
+
f"Found {len(other.data)} rows, expected {len(self.data)}"
|
473
|
+
)
|
474
474
|
eq = eq and self.columns == other.columns
|
475
|
-
assert (
|
476
|
-
|
477
|
-
)
|
475
|
+
assert not should_assert or eq, (
|
476
|
+
f"Found columns {other.columns}, expected {self.columns}"
|
477
|
+
)
|
478
478
|
eq = eq and self._column_types == other._column_types
|
479
|
-
assert (
|
480
|
-
|
481
|
-
)
|
479
|
+
assert not should_assert or eq, (
|
480
|
+
f"Found column type {other._column_types}, expected column type {self._column_types}"
|
481
|
+
)
|
482
482
|
if eq:
|
483
483
|
for row_ndx in range(len(self.data)):
|
484
484
|
for col_ndx in range(len(self.data[row_ndx])):
|
@@ -487,9 +487,9 @@ class Table(Media):
|
|
487
487
|
if util.is_numpy_array(_eq):
|
488
488
|
_eq = ((_eq * -1) + 1).sum() == 0
|
489
489
|
eq = eq and _eq
|
490
|
-
assert (
|
491
|
-
|
492
|
-
)
|
490
|
+
assert not should_assert or eq, (
|
491
|
+
f"Unequal data at row_ndx {row_ndx} col_ndx {col_ndx}: found {other.data[row_ndx][col_ndx]}, expected {self.data[row_ndx][col_ndx]}"
|
492
|
+
)
|
493
493
|
if not eq:
|
494
494
|
return eq
|
495
495
|
return eq
|
@@ -932,9 +932,9 @@ class Table(Media):
|
|
932
932
|
assert isinstance(data, list) or is_np
|
933
933
|
assert isinstance(optional, bool)
|
934
934
|
is_first_col = len(self.columns) == 0
|
935
|
-
assert is_first_col or len(data) == len(
|
936
|
-
self.data
|
937
|
-
)
|
935
|
+
assert is_first_col or len(data) == len(self.data), (
|
936
|
+
f"Expected length {len(self.data)}, found {len(data)}"
|
937
|
+
)
|
938
938
|
|
939
939
|
# Add the new data
|
940
940
|
for ndx in range(max(len(data), len(self.data))):
|
@@ -1257,13 +1257,13 @@ class JoinedTable(Media):
|
|
1257
1257
|
|
1258
1258
|
def _eq_debug(self, other, should_assert=False):
|
1259
1259
|
eq = isinstance(other, JoinedTable)
|
1260
|
-
assert (
|
1261
|
-
|
1262
|
-
)
|
1260
|
+
assert not should_assert or eq, (
|
1261
|
+
f"Found type {other.__class__}, expected {JoinedTable}"
|
1262
|
+
)
|
1263
1263
|
eq = eq and self._join_key == other._join_key
|
1264
|
-
assert (
|
1265
|
-
|
1266
|
-
)
|
1264
|
+
assert not should_assert or eq, (
|
1265
|
+
f"Found {other._join_key} join key, expected {self._join_key}"
|
1266
|
+
)
|
1267
1267
|
eq = eq and self._table1._eq_debug(other._table1, should_assert)
|
1268
1268
|
eq = eq and self._table2._eq_debug(other._table2, should_assert)
|
1269
1269
|
return eq
|
@@ -261,14 +261,14 @@ class Trace:
|
|
261
261
|
A Span object.
|
262
262
|
"""
|
263
263
|
if kind is not None:
|
264
|
-
assert (
|
265
|
-
kind
|
266
|
-
)
|
264
|
+
assert kind.upper() in SpanKind.__members__, (
|
265
|
+
"Invalid span kind, can be one of 'LLM', 'AGENT', 'CHAIN', 'TOOL'"
|
266
|
+
)
|
267
267
|
kind = SpanKind(kind.upper())
|
268
268
|
if status_code is not None:
|
269
|
-
assert (
|
270
|
-
|
271
|
-
)
|
269
|
+
assert status_code.upper() in StatusCode.__members__, (
|
270
|
+
"Invalid status code, can be one of 'SUCCESS' or 'ERROR'"
|
271
|
+
)
|
272
272
|
status_code = StatusCode(status_code.upper())
|
273
273
|
if inputs is not None:
|
274
274
|
assert isinstance(inputs, dict), "Inputs must be a dictionary"
|
@@ -419,9 +419,9 @@ class Trace:
|
|
419
419
|
Args:
|
420
420
|
value: The kind of the trace to be set.
|
421
421
|
"""
|
422
|
-
assert (
|
423
|
-
|
424
|
-
)
|
422
|
+
assert value.upper() in SpanKind.__members__, (
|
423
|
+
"Invalid span kind, can be one of 'LLM', 'AGENT', 'CHAIN', 'TOOL'"
|
424
|
+
)
|
425
425
|
self._span.span_kind = SpanKind(value.upper())
|
426
426
|
|
427
427
|
def log(self, name: str) -> None:
|
@@ -433,8 +433,8 @@ class Trace:
|
|
433
433
|
trace_tree = WBTraceTree(self._span, self._model_dict)
|
434
434
|
# NOTE: Does not work for reinit="create_new" runs.
|
435
435
|
# This method should be deprecated and users should call run.log().
|
436
|
-
assert (
|
437
|
-
wandb.
|
438
|
-
)
|
436
|
+
assert wandb.run is not None, (
|
437
|
+
"You must call wandb.init() before logging a trace"
|
438
|
+
)
|
439
439
|
assert len(name.strip()) > 0, "You must provide a valid name to log the trace"
|
440
440
|
wandb.run.log({name: trace_tree})
|
wandb/sdk/interface/interface.py
CHANGED
@@ -150,8 +150,7 @@ class InterfaceBase:
|
|
150
150
|
if run._settings.run_notes is not None:
|
151
151
|
proto_run.notes = run._settings.run_notes
|
152
152
|
if run._settings.run_tags is not None:
|
153
|
-
|
154
|
-
proto_run.tags.append(tag)
|
153
|
+
proto_run.tags.extend(run._settings.run_tags)
|
155
154
|
if run._start_time is not None:
|
156
155
|
proto_run.start_time.FromMicroseconds(int(run._start_time * 1e6))
|
157
156
|
if run._starting_step is not None:
|
@@ -217,13 +216,6 @@ class InterfaceBase:
|
|
217
216
|
def _publish_config(self, cfg: pb.ConfigRecord) -> None:
|
218
217
|
raise NotImplementedError
|
219
218
|
|
220
|
-
def publish_metadata(self, metadata: pb.MetadataRequest) -> None:
|
221
|
-
self._publish_metadata(metadata)
|
222
|
-
|
223
|
-
@abstractmethod
|
224
|
-
def _publish_metadata(self, metadata: pb.MetadataRequest) -> None:
|
225
|
-
raise NotImplementedError
|
226
|
-
|
227
219
|
@abstractmethod
|
228
220
|
def _publish_metric(self, metric: pb.MetricRecord) -> None:
|
229
221
|
raise NotImplementedError
|
@@ -671,6 +663,13 @@ class InterfaceBase:
|
|
671
663
|
def _publish_telemetry(self, telem: tpb.TelemetryRecord) -> None:
|
672
664
|
raise NotImplementedError
|
673
665
|
|
666
|
+
def publish_environment(self, environment: pb.EnvironmentRecord) -> None:
|
667
|
+
self._publish_environment(environment)
|
668
|
+
|
669
|
+
@abstractmethod
|
670
|
+
def _publish_environment(self, environment: pb.EnvironmentRecord) -> None:
|
671
|
+
raise NotImplementedError
|
672
|
+
|
674
673
|
def publish_partial_history(
|
675
674
|
self,
|
676
675
|
run: "Run",
|
@@ -1000,16 +999,6 @@ class InterfaceBase:
|
|
1000
999
|
) -> MailboxHandle[pb.Result]:
|
1001
1000
|
raise NotImplementedError
|
1002
1001
|
|
1003
|
-
def deliver_get_system_metadata(self) -> MailboxHandle[pb.Result]:
|
1004
|
-
get_system_metadata = pb.GetSystemMetadataRequest()
|
1005
|
-
return self._deliver_get_system_metadata(get_system_metadata)
|
1006
|
-
|
1007
|
-
@abstractmethod
|
1008
|
-
def _deliver_get_system_metadata(
|
1009
|
-
self, get_system_metadata: pb.GetSystemMetadataRequest
|
1010
|
-
) -> MailboxHandle[pb.Result]:
|
1011
|
-
raise NotImplementedError
|
1012
|
-
|
1013
1002
|
def deliver_exit(self, exit_code: Optional[int]) -> MailboxHandle[pb.Result]:
|
1014
1003
|
exit_data = self._make_exit(exit_code)
|
1015
1004
|
return self._deliver_exit(exit_data)
|
@@ -59,6 +59,10 @@ class InterfaceShared(InterfaceBase):
|
|
59
59
|
rec = self._make_record(telemetry=telem)
|
60
60
|
self._publish(rec)
|
61
61
|
|
62
|
+
def _publish_environment(self, environment: pb.EnvironmentRecord) -> None:
|
63
|
+
rec = self._make_record(environment=environment)
|
64
|
+
self._publish(rec)
|
65
|
+
|
62
66
|
def _publish_job_input(
|
63
67
|
self, job_input: pb.JobInputRequest
|
64
68
|
) -> MailboxHandle[pb.Result]:
|
@@ -106,11 +110,9 @@ class InterfaceShared(InterfaceBase):
|
|
106
110
|
summary_record: Optional[pb.SummaryRecordRequest] = None,
|
107
111
|
telemetry_record: Optional[pb.TelemetryRecordRequest] = None,
|
108
112
|
get_system_metrics: Optional[pb.GetSystemMetricsRequest] = None,
|
109
|
-
get_system_metadata: Optional[pb.GetSystemMetadataRequest] = None,
|
110
113
|
python_packages: Optional[pb.PythonPackagesRequest] = None,
|
111
114
|
job_input: Optional[pb.JobInputRequest] = None,
|
112
115
|
run_finish_without_exit: Optional[pb.RunFinishWithoutExitRequest] = None,
|
113
|
-
metadata: Optional[pb.MetadataRequest] = None,
|
114
116
|
) -> pb.Record:
|
115
117
|
request = pb.Request()
|
116
118
|
if get_summary:
|
@@ -169,8 +171,6 @@ class InterfaceShared(InterfaceBase):
|
|
169
171
|
request.telemetry_record.CopyFrom(telemetry_record)
|
170
172
|
elif get_system_metrics:
|
171
173
|
request.get_system_metrics.CopyFrom(get_system_metrics)
|
172
|
-
elif get_system_metadata:
|
173
|
-
request.get_system_metadata.CopyFrom(get_system_metadata)
|
174
174
|
elif sync_finish:
|
175
175
|
request.sync_finish.CopyFrom(sync_finish)
|
176
176
|
elif python_packages:
|
@@ -179,8 +179,6 @@ class InterfaceShared(InterfaceBase):
|
|
179
179
|
request.job_input.CopyFrom(job_input)
|
180
180
|
elif run_finish_without_exit:
|
181
181
|
request.run_finish_without_exit.CopyFrom(run_finish_without_exit)
|
182
|
-
elif metadata:
|
183
|
-
request.metadata.CopyFrom(metadata)
|
184
182
|
else:
|
185
183
|
raise Exception("Invalid request")
|
186
184
|
record = self._make_record(request=request)
|
@@ -212,6 +210,7 @@ class InterfaceShared(InterfaceBase):
|
|
212
210
|
use_artifact: Optional[pb.UseArtifactRecord] = None,
|
213
211
|
output: Optional[pb.OutputRecord] = None,
|
214
212
|
output_raw: Optional[pb.OutputRawRecord] = None,
|
213
|
+
environment: Optional[pb.EnvironmentRecord] = None,
|
215
214
|
) -> pb.Record:
|
216
215
|
record = pb.Record()
|
217
216
|
if run:
|
@@ -254,6 +253,8 @@ class InterfaceShared(InterfaceBase):
|
|
254
253
|
record.output.CopyFrom(output)
|
255
254
|
elif output_raw:
|
256
255
|
record.output_raw.CopyFrom(output_raw)
|
256
|
+
elif environment:
|
257
|
+
record.environment.CopyFrom(environment)
|
257
258
|
else:
|
258
259
|
raise Exception("Invalid record")
|
259
260
|
return record
|
@@ -304,10 +305,6 @@ class InterfaceShared(InterfaceBase):
|
|
304
305
|
rec = self._make_record(summary=summary)
|
305
306
|
self._publish(rec)
|
306
307
|
|
307
|
-
def _publish_metadata(self, metadata: pb.MetadataRequest) -> None:
|
308
|
-
rec = self._make_request(metadata=metadata)
|
309
|
-
self._publish(rec)
|
310
|
-
|
311
308
|
def _publish_metric(self, metric: pb.MetricRecord) -> None:
|
312
309
|
rec = self._make_record(metric=metric)
|
313
310
|
self._publish(rec)
|
@@ -422,12 +419,6 @@ class InterfaceShared(InterfaceBase):
|
|
422
419
|
record = self._make_request(get_system_metrics=get_system_metrics)
|
423
420
|
return self._deliver_record(record)
|
424
421
|
|
425
|
-
def _deliver_get_system_metadata(
|
426
|
-
self, get_system_metadata: pb.GetSystemMetadataRequest
|
427
|
-
) -> MailboxHandle[pb.Result]:
|
428
|
-
record = self._make_request(get_system_metadata=get_system_metadata)
|
429
|
-
return self._deliver_record(record)
|
430
|
-
|
431
422
|
def _deliver_exit(
|
432
423
|
self,
|
433
424
|
exit_data: pb.RunExitRecord,
|
wandb/sdk/internal/datastore.py
CHANGED
@@ -124,18 +124,18 @@ class DataStore:
|
|
124
124
|
header = self._fp.read(LEVELDBLOG_HEADER_LEN)
|
125
125
|
if len(header) == 0:
|
126
126
|
return None
|
127
|
-
assert (
|
128
|
-
len(header)
|
129
|
-
)
|
127
|
+
assert len(header) == LEVELDBLOG_HEADER_LEN, (
|
128
|
+
f"record header is {len(header)} bytes instead of the expected {LEVELDBLOG_HEADER_LEN}"
|
129
|
+
)
|
130
130
|
fields = struct.unpack("<IHB", header)
|
131
131
|
checksum, dlength, dtype = fields
|
132
132
|
# check len, better fit in the block
|
133
133
|
self._index += LEVELDBLOG_HEADER_LEN
|
134
134
|
data = self._fp.read(dlength)
|
135
135
|
checksum_computed = zlib.crc32(data, self._crc[dtype]) & 0xFFFFFFFF
|
136
|
-
assert (
|
137
|
-
checksum
|
138
|
-
)
|
136
|
+
assert checksum == checksum_computed, (
|
137
|
+
"record checksum is invalid, data may be corrupt"
|
138
|
+
)
|
139
139
|
self._index += dlength
|
140
140
|
return dtype, data
|
141
141
|
|
@@ -158,9 +158,9 @@ class DataStore:
|
|
158
158
|
if dtype == LEVELDBLOG_FULL:
|
159
159
|
return data
|
160
160
|
|
161
|
-
assert (
|
162
|
-
|
163
|
-
)
|
161
|
+
assert dtype == LEVELDBLOG_FIRST, (
|
162
|
+
f"expected record to be type {LEVELDBLOG_FIRST} but found {dtype}"
|
163
|
+
)
|
164
164
|
while True:
|
165
165
|
offset = self._index % LEVELDBLOG_BLOCK_LEN
|
166
166
|
record = self.scan_record()
|
@@ -170,9 +170,9 @@ class DataStore:
|
|
170
170
|
if dtype == LEVELDBLOG_LAST:
|
171
171
|
data += new_data
|
172
172
|
break
|
173
|
-
assert (
|
174
|
-
|
175
|
-
)
|
173
|
+
assert dtype == LEVELDBLOG_MIDDLE, (
|
174
|
+
f"expected record to be type {LEVELDBLOG_MIDDLE} but found {dtype}"
|
175
|
+
)
|
176
176
|
data += new_data
|
177
177
|
return data
|
178
178
|
|
@@ -183,17 +183,17 @@ class DataStore:
|
|
183
183
|
LEVELDBLOG_HEADER_MAGIC,
|
184
184
|
LEVELDBLOG_HEADER_VERSION,
|
185
185
|
)
|
186
|
-
assert (
|
187
|
-
len(data)
|
188
|
-
)
|
186
|
+
assert len(data) == LEVELDBLOG_HEADER_LEN, (
|
187
|
+
f"header size is {len(data)} bytes, expected {LEVELDBLOG_HEADER_LEN}"
|
188
|
+
)
|
189
189
|
self._fp.write(data)
|
190
190
|
self._index += len(data)
|
191
191
|
|
192
192
|
def _read_header(self):
|
193
193
|
header = self._fp.read(LEVELDBLOG_HEADER_LEN)
|
194
|
-
assert (
|
195
|
-
len(header)
|
196
|
-
)
|
194
|
+
assert len(header) == LEVELDBLOG_HEADER_LEN, (
|
195
|
+
f"header is {len(header)} bytes instead of the expected {LEVELDBLOG_HEADER_LEN}"
|
196
|
+
)
|
197
197
|
ident, magic, version = struct.unpack("<4sHB", header)
|
198
198
|
if ident != strtobytes(LEVELDBLOG_HEADER_IDENT):
|
199
199
|
raise Exception("Invalid header")
|