wandb 0.20.2rc20250616__py3-none-win_amd64.whl → 0.21.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +24 -23
- wandb/apis/internal.py +3 -0
- wandb/apis/paginator.py +17 -4
- wandb/apis/public/api.py +83 -2
- wandb/apis/public/artifacts.py +10 -8
- wandb/apis/public/files.py +5 -5
- wandb/apis/public/projects.py +44 -3
- wandb/apis/public/reports.py +64 -8
- wandb/apis/public/runs.py +16 -23
- wandb/automations/__init__.py +10 -10
- wandb/automations/_filters/run_metrics.py +0 -2
- wandb/automations/_utils.py +0 -2
- wandb/automations/actions.py +0 -2
- wandb/automations/automations.py +0 -2
- wandb/automations/events.py +0 -2
- wandb/bin/gpu_stats.exe +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/integration/catboost/catboost.py +6 -2
- wandb/integration/kfp/kfp_patch.py +3 -1
- wandb/integration/sb3/sb3.py +3 -3
- wandb/integration/ultralytics/callback.py +6 -2
- wandb/plot/__init__.py +2 -0
- wandb/plot/bar.py +30 -29
- wandb/plot/confusion_matrix.py +75 -71
- wandb/plot/histogram.py +26 -25
- wandb/plot/line.py +33 -32
- wandb/plot/line_series.py +100 -103
- wandb/plot/pr_curve.py +33 -32
- wandb/plot/roc_curve.py +38 -38
- wandb/plot/scatter.py +27 -27
- wandb/proto/v3/wandb_internal_pb2.py +366 -385
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_internal_pb2.py +352 -356
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_internal_pb2.py +352 -356
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v6/wandb_internal_pb2.py +352 -356
- wandb/proto/v6/wandb_settings_pb2.py +2 -2
- wandb/sdk/artifacts/_generated/__init__.py +12 -1
- wandb/sdk/artifacts/_generated/input_types.py +20 -2
- wandb/sdk/artifacts/_generated/link_artifact.py +21 -0
- wandb/sdk/artifacts/_generated/operations.py +9 -0
- wandb/sdk/artifacts/_validators.py +40 -2
- wandb/sdk/artifacts/artifact.py +163 -21
- wandb/sdk/data_types/base_types/media.py +9 -7
- wandb/sdk/data_types/base_types/wb_value.py +6 -6
- wandb/sdk/data_types/saved_model.py +3 -3
- wandb/sdk/data_types/table.py +41 -41
- wandb/sdk/data_types/trace_tree.py +12 -12
- wandb/sdk/interface/interface.py +8 -19
- wandb/sdk/interface/interface_shared.py +7 -16
- wandb/sdk/internal/datastore.py +18 -18
- wandb/sdk/internal/handler.py +3 -5
- wandb/sdk/internal/internal_api.py +54 -0
- wandb/sdk/internal/sender.py +23 -3
- wandb/sdk/internal/sender_config.py +9 -0
- wandb/sdk/launch/_project_spec.py +3 -3
- wandb/sdk/launch/agent/agent.py +3 -3
- wandb/sdk/launch/agent/job_status_tracker.py +3 -1
- wandb/sdk/launch/utils.py +3 -3
- wandb/sdk/lib/console_capture.py +66 -19
- wandb/sdk/wandb_init.py +1 -2
- wandb/sdk/wandb_require.py +0 -1
- wandb/sdk/wandb_run.py +23 -113
- wandb/sdk/wandb_settings.py +234 -72
- {wandb-0.20.2rc20250616.dist-info → wandb-0.21.0.dist-info}/METADATA +1 -1
- {wandb-0.20.2rc20250616.dist-info → wandb-0.21.0.dist-info}/RECORD +71 -71
- wandb/sdk/wandb_metadata.py +0 -623
- {wandb-0.20.2rc20250616.dist-info → wandb-0.21.0.dist-info}/WHEEL +0 -0
- {wandb-0.20.2rc20250616.dist-info → wandb-0.21.0.dist-info}/entry_points.txt +0 -0
- {wandb-0.20.2rc20250616.dist-info → wandb-0.21.0.dist-info}/licenses/LICENSE +0 -0
wandb/sdk/data_types/table.py
CHANGED
@@ -203,8 +203,8 @@ class Table(Media):
|
|
203
203
|
This means you can embed `Images`, `Video`, `Audio`, and other sorts of rich, annotated media
|
204
204
|
directly in Tables, alongside other traditional scalar values.
|
205
205
|
|
206
|
-
This class is the primary class used to generate the
|
207
|
-
|
206
|
+
This class is the primary class used to generate the W&B Tables
|
207
|
+
https://docs.wandb.ai/guides/models/tables/.
|
208
208
|
"""
|
209
209
|
|
210
210
|
MAX_ROWS = 10000
|
@@ -292,9 +292,9 @@ class Table(Media):
|
|
292
292
|
self._init_from_list([], columns, optional, dtype)
|
293
293
|
|
294
294
|
def _validate_log_mode(self, log_mode):
|
295
|
-
assert (
|
296
|
-
log_mode
|
297
|
-
)
|
295
|
+
assert log_mode in _SUPPORTED_LOGGING_MODES, (
|
296
|
+
f"Invalid log_mode: {log_mode}. Must be one of {_SUPPORTED_LOGGING_MODES}"
|
297
|
+
)
|
298
298
|
|
299
299
|
@staticmethod
|
300
300
|
def _assert_valid_columns(columns):
|
@@ -314,9 +314,9 @@ class Table(Media):
|
|
314
314
|
self.add_data(*row)
|
315
315
|
|
316
316
|
def _init_from_ndarray(self, ndarray, columns, optional=True, dtype=None):
|
317
|
-
assert util.is_numpy_array(
|
318
|
-
ndarray
|
319
|
-
)
|
317
|
+
assert util.is_numpy_array(ndarray), (
|
318
|
+
"ndarray argument expects a `numpy.ndarray` object"
|
319
|
+
)
|
320
320
|
self.data = []
|
321
321
|
self._assert_valid_columns(columns)
|
322
322
|
self.columns = columns
|
@@ -325,9 +325,9 @@ class Table(Media):
|
|
325
325
|
self.add_data(*row)
|
326
326
|
|
327
327
|
def _init_from_dataframe(self, dataframe, columns, optional=True, dtype=None):
|
328
|
-
assert util.is_pandas_data_frame(
|
329
|
-
dataframe
|
330
|
-
)
|
328
|
+
assert util.is_pandas_data_frame(dataframe), (
|
329
|
+
"dataframe argument expects a `pandas.core.frame.DataFrame` object"
|
330
|
+
)
|
331
331
|
self.data = []
|
332
332
|
columns = list(dataframe.columns)
|
333
333
|
self._assert_valid_columns(columns)
|
@@ -440,17 +440,17 @@ class Table(Media):
|
|
440
440
|
is_fk = isinstance(wbtype, _ForeignKeyType)
|
441
441
|
is_fi = isinstance(wbtype, _ForeignIndexType)
|
442
442
|
if is_pk or is_fk or is_fi:
|
443
|
-
assert (
|
444
|
-
|
445
|
-
)
|
443
|
+
assert not optional, (
|
444
|
+
"Primary keys, foreign keys, and foreign indexes cannot be optional."
|
445
|
+
)
|
446
446
|
|
447
447
|
if (is_fk or is_fk) and id(wbtype.params["table"]) == id(self):
|
448
448
|
raise AssertionError("Cannot set a foreign table reference to same table.")
|
449
449
|
|
450
450
|
if is_pk:
|
451
|
-
assert (
|
452
|
-
self._pk_col is
|
453
|
-
)
|
451
|
+
assert self._pk_col is None, (
|
452
|
+
f"Cannot have multiple primary keys - {self._pk_col} is already set as the primary key."
|
453
|
+
)
|
454
454
|
|
455
455
|
# Update the column type
|
456
456
|
self._column_types.params["type_map"][col_name] = wbtype
|
@@ -464,21 +464,21 @@ class Table(Media):
|
|
464
464
|
|
465
465
|
def _eq_debug(self, other, should_assert=False):
|
466
466
|
eq = isinstance(other, Table)
|
467
|
-
assert (
|
468
|
-
|
469
|
-
)
|
467
|
+
assert not should_assert or eq, (
|
468
|
+
f"Found type {other.__class__}, expected {Table}"
|
469
|
+
)
|
470
470
|
eq = eq and len(self.data) == len(other.data)
|
471
|
-
assert (
|
472
|
-
|
473
|
-
)
|
471
|
+
assert not should_assert or eq, (
|
472
|
+
f"Found {len(other.data)} rows, expected {len(self.data)}"
|
473
|
+
)
|
474
474
|
eq = eq and self.columns == other.columns
|
475
|
-
assert (
|
476
|
-
|
477
|
-
)
|
475
|
+
assert not should_assert or eq, (
|
476
|
+
f"Found columns {other.columns}, expected {self.columns}"
|
477
|
+
)
|
478
478
|
eq = eq and self._column_types == other._column_types
|
479
|
-
assert (
|
480
|
-
|
481
|
-
)
|
479
|
+
assert not should_assert or eq, (
|
480
|
+
f"Found column type {other._column_types}, expected column type {self._column_types}"
|
481
|
+
)
|
482
482
|
if eq:
|
483
483
|
for row_ndx in range(len(self.data)):
|
484
484
|
for col_ndx in range(len(self.data[row_ndx])):
|
@@ -487,9 +487,9 @@ class Table(Media):
|
|
487
487
|
if util.is_numpy_array(_eq):
|
488
488
|
_eq = ((_eq * -1) + 1).sum() == 0
|
489
489
|
eq = eq and _eq
|
490
|
-
assert (
|
491
|
-
|
492
|
-
)
|
490
|
+
assert not should_assert or eq, (
|
491
|
+
f"Unequal data at row_ndx {row_ndx} col_ndx {col_ndx}: found {other.data[row_ndx][col_ndx]}, expected {self.data[row_ndx][col_ndx]}"
|
492
|
+
)
|
493
493
|
if not eq:
|
494
494
|
return eq
|
495
495
|
return eq
|
@@ -932,9 +932,9 @@ class Table(Media):
|
|
932
932
|
assert isinstance(data, list) or is_np
|
933
933
|
assert isinstance(optional, bool)
|
934
934
|
is_first_col = len(self.columns) == 0
|
935
|
-
assert is_first_col or len(data) == len(
|
936
|
-
self.data
|
937
|
-
)
|
935
|
+
assert is_first_col or len(data) == len(self.data), (
|
936
|
+
f"Expected length {len(self.data)}, found {len(data)}"
|
937
|
+
)
|
938
938
|
|
939
939
|
# Add the new data
|
940
940
|
for ndx in range(max(len(data), len(self.data))):
|
@@ -1257,13 +1257,13 @@ class JoinedTable(Media):
|
|
1257
1257
|
|
1258
1258
|
def _eq_debug(self, other, should_assert=False):
|
1259
1259
|
eq = isinstance(other, JoinedTable)
|
1260
|
-
assert (
|
1261
|
-
|
1262
|
-
)
|
1260
|
+
assert not should_assert or eq, (
|
1261
|
+
f"Found type {other.__class__}, expected {JoinedTable}"
|
1262
|
+
)
|
1263
1263
|
eq = eq and self._join_key == other._join_key
|
1264
|
-
assert (
|
1265
|
-
|
1266
|
-
)
|
1264
|
+
assert not should_assert or eq, (
|
1265
|
+
f"Found {other._join_key} join key, expected {self._join_key}"
|
1266
|
+
)
|
1267
1267
|
eq = eq and self._table1._eq_debug(other._table1, should_assert)
|
1268
1268
|
eq = eq and self._table2._eq_debug(other._table2, should_assert)
|
1269
1269
|
return eq
|
@@ -261,14 +261,14 @@ class Trace:
|
|
261
261
|
A Span object.
|
262
262
|
"""
|
263
263
|
if kind is not None:
|
264
|
-
assert (
|
265
|
-
kind
|
266
|
-
)
|
264
|
+
assert kind.upper() in SpanKind.__members__, (
|
265
|
+
"Invalid span kind, can be one of 'LLM', 'AGENT', 'CHAIN', 'TOOL'"
|
266
|
+
)
|
267
267
|
kind = SpanKind(kind.upper())
|
268
268
|
if status_code is not None:
|
269
|
-
assert (
|
270
|
-
|
271
|
-
)
|
269
|
+
assert status_code.upper() in StatusCode.__members__, (
|
270
|
+
"Invalid status code, can be one of 'SUCCESS' or 'ERROR'"
|
271
|
+
)
|
272
272
|
status_code = StatusCode(status_code.upper())
|
273
273
|
if inputs is not None:
|
274
274
|
assert isinstance(inputs, dict), "Inputs must be a dictionary"
|
@@ -419,9 +419,9 @@ class Trace:
|
|
419
419
|
Args:
|
420
420
|
value: The kind of the trace to be set.
|
421
421
|
"""
|
422
|
-
assert (
|
423
|
-
|
424
|
-
)
|
422
|
+
assert value.upper() in SpanKind.__members__, (
|
423
|
+
"Invalid span kind, can be one of 'LLM', 'AGENT', 'CHAIN', 'TOOL'"
|
424
|
+
)
|
425
425
|
self._span.span_kind = SpanKind(value.upper())
|
426
426
|
|
427
427
|
def log(self, name: str) -> None:
|
@@ -433,8 +433,8 @@ class Trace:
|
|
433
433
|
trace_tree = WBTraceTree(self._span, self._model_dict)
|
434
434
|
# NOTE: Does not work for reinit="create_new" runs.
|
435
435
|
# This method should be deprecated and users should call run.log().
|
436
|
-
assert (
|
437
|
-
wandb.
|
438
|
-
)
|
436
|
+
assert wandb.run is not None, (
|
437
|
+
"You must call wandb.init() before logging a trace"
|
438
|
+
)
|
439
439
|
assert len(name.strip()) > 0, "You must provide a valid name to log the trace"
|
440
440
|
wandb.run.log({name: trace_tree})
|
wandb/sdk/interface/interface.py
CHANGED
@@ -150,8 +150,7 @@ class InterfaceBase:
|
|
150
150
|
if run._settings.run_notes is not None:
|
151
151
|
proto_run.notes = run._settings.run_notes
|
152
152
|
if run._settings.run_tags is not None:
|
153
|
-
|
154
|
-
proto_run.tags.append(tag)
|
153
|
+
proto_run.tags.extend(run._settings.run_tags)
|
155
154
|
if run._start_time is not None:
|
156
155
|
proto_run.start_time.FromMicroseconds(int(run._start_time * 1e6))
|
157
156
|
if run._starting_step is not None:
|
@@ -217,13 +216,6 @@ class InterfaceBase:
|
|
217
216
|
def _publish_config(self, cfg: pb.ConfigRecord) -> None:
|
218
217
|
raise NotImplementedError
|
219
218
|
|
220
|
-
def publish_metadata(self, metadata: pb.MetadataRequest) -> None:
|
221
|
-
self._publish_metadata(metadata)
|
222
|
-
|
223
|
-
@abstractmethod
|
224
|
-
def _publish_metadata(self, metadata: pb.MetadataRequest) -> None:
|
225
|
-
raise NotImplementedError
|
226
|
-
|
227
219
|
@abstractmethod
|
228
220
|
def _publish_metric(self, metric: pb.MetricRecord) -> None:
|
229
221
|
raise NotImplementedError
|
@@ -671,6 +663,13 @@ class InterfaceBase:
|
|
671
663
|
def _publish_telemetry(self, telem: tpb.TelemetryRecord) -> None:
|
672
664
|
raise NotImplementedError
|
673
665
|
|
666
|
+
def publish_environment(self, environment: pb.EnvironmentRecord) -> None:
|
667
|
+
self._publish_environment(environment)
|
668
|
+
|
669
|
+
@abstractmethod
|
670
|
+
def _publish_environment(self, environment: pb.EnvironmentRecord) -> None:
|
671
|
+
raise NotImplementedError
|
672
|
+
|
674
673
|
def publish_partial_history(
|
675
674
|
self,
|
676
675
|
run: "Run",
|
@@ -1000,16 +999,6 @@ class InterfaceBase:
|
|
1000
999
|
) -> MailboxHandle[pb.Result]:
|
1001
1000
|
raise NotImplementedError
|
1002
1001
|
|
1003
|
-
def deliver_get_system_metadata(self) -> MailboxHandle[pb.Result]:
|
1004
|
-
get_system_metadata = pb.GetSystemMetadataRequest()
|
1005
|
-
return self._deliver_get_system_metadata(get_system_metadata)
|
1006
|
-
|
1007
|
-
@abstractmethod
|
1008
|
-
def _deliver_get_system_metadata(
|
1009
|
-
self, get_system_metadata: pb.GetSystemMetadataRequest
|
1010
|
-
) -> MailboxHandle[pb.Result]:
|
1011
|
-
raise NotImplementedError
|
1012
|
-
|
1013
1002
|
def deliver_exit(self, exit_code: Optional[int]) -> MailboxHandle[pb.Result]:
|
1014
1003
|
exit_data = self._make_exit(exit_code)
|
1015
1004
|
return self._deliver_exit(exit_data)
|
@@ -59,6 +59,10 @@ class InterfaceShared(InterfaceBase):
|
|
59
59
|
rec = self._make_record(telemetry=telem)
|
60
60
|
self._publish(rec)
|
61
61
|
|
62
|
+
def _publish_environment(self, environment: pb.EnvironmentRecord) -> None:
|
63
|
+
rec = self._make_record(environment=environment)
|
64
|
+
self._publish(rec)
|
65
|
+
|
62
66
|
def _publish_job_input(
|
63
67
|
self, job_input: pb.JobInputRequest
|
64
68
|
) -> MailboxHandle[pb.Result]:
|
@@ -106,11 +110,9 @@ class InterfaceShared(InterfaceBase):
|
|
106
110
|
summary_record: Optional[pb.SummaryRecordRequest] = None,
|
107
111
|
telemetry_record: Optional[pb.TelemetryRecordRequest] = None,
|
108
112
|
get_system_metrics: Optional[pb.GetSystemMetricsRequest] = None,
|
109
|
-
get_system_metadata: Optional[pb.GetSystemMetadataRequest] = None,
|
110
113
|
python_packages: Optional[pb.PythonPackagesRequest] = None,
|
111
114
|
job_input: Optional[pb.JobInputRequest] = None,
|
112
115
|
run_finish_without_exit: Optional[pb.RunFinishWithoutExitRequest] = None,
|
113
|
-
metadata: Optional[pb.MetadataRequest] = None,
|
114
116
|
) -> pb.Record:
|
115
117
|
request = pb.Request()
|
116
118
|
if get_summary:
|
@@ -169,8 +171,6 @@ class InterfaceShared(InterfaceBase):
|
|
169
171
|
request.telemetry_record.CopyFrom(telemetry_record)
|
170
172
|
elif get_system_metrics:
|
171
173
|
request.get_system_metrics.CopyFrom(get_system_metrics)
|
172
|
-
elif get_system_metadata:
|
173
|
-
request.get_system_metadata.CopyFrom(get_system_metadata)
|
174
174
|
elif sync_finish:
|
175
175
|
request.sync_finish.CopyFrom(sync_finish)
|
176
176
|
elif python_packages:
|
@@ -179,8 +179,6 @@ class InterfaceShared(InterfaceBase):
|
|
179
179
|
request.job_input.CopyFrom(job_input)
|
180
180
|
elif run_finish_without_exit:
|
181
181
|
request.run_finish_without_exit.CopyFrom(run_finish_without_exit)
|
182
|
-
elif metadata:
|
183
|
-
request.metadata.CopyFrom(metadata)
|
184
182
|
else:
|
185
183
|
raise Exception("Invalid request")
|
186
184
|
record = self._make_record(request=request)
|
@@ -212,6 +210,7 @@ class InterfaceShared(InterfaceBase):
|
|
212
210
|
use_artifact: Optional[pb.UseArtifactRecord] = None,
|
213
211
|
output: Optional[pb.OutputRecord] = None,
|
214
212
|
output_raw: Optional[pb.OutputRawRecord] = None,
|
213
|
+
environment: Optional[pb.EnvironmentRecord] = None,
|
215
214
|
) -> pb.Record:
|
216
215
|
record = pb.Record()
|
217
216
|
if run:
|
@@ -254,6 +253,8 @@ class InterfaceShared(InterfaceBase):
|
|
254
253
|
record.output.CopyFrom(output)
|
255
254
|
elif output_raw:
|
256
255
|
record.output_raw.CopyFrom(output_raw)
|
256
|
+
elif environment:
|
257
|
+
record.environment.CopyFrom(environment)
|
257
258
|
else:
|
258
259
|
raise Exception("Invalid record")
|
259
260
|
return record
|
@@ -304,10 +305,6 @@ class InterfaceShared(InterfaceBase):
|
|
304
305
|
rec = self._make_record(summary=summary)
|
305
306
|
self._publish(rec)
|
306
307
|
|
307
|
-
def _publish_metadata(self, metadata: pb.MetadataRequest) -> None:
|
308
|
-
rec = self._make_request(metadata=metadata)
|
309
|
-
self._publish(rec)
|
310
|
-
|
311
308
|
def _publish_metric(self, metric: pb.MetricRecord) -> None:
|
312
309
|
rec = self._make_record(metric=metric)
|
313
310
|
self._publish(rec)
|
@@ -422,12 +419,6 @@ class InterfaceShared(InterfaceBase):
|
|
422
419
|
record = self._make_request(get_system_metrics=get_system_metrics)
|
423
420
|
return self._deliver_record(record)
|
424
421
|
|
425
|
-
def _deliver_get_system_metadata(
|
426
|
-
self, get_system_metadata: pb.GetSystemMetadataRequest
|
427
|
-
) -> MailboxHandle[pb.Result]:
|
428
|
-
record = self._make_request(get_system_metadata=get_system_metadata)
|
429
|
-
return self._deliver_record(record)
|
430
|
-
|
431
422
|
def _deliver_exit(
|
432
423
|
self,
|
433
424
|
exit_data: pb.RunExitRecord,
|
wandb/sdk/internal/datastore.py
CHANGED
@@ -124,18 +124,18 @@ class DataStore:
|
|
124
124
|
header = self._fp.read(LEVELDBLOG_HEADER_LEN)
|
125
125
|
if len(header) == 0:
|
126
126
|
return None
|
127
|
-
assert (
|
128
|
-
len(header)
|
129
|
-
)
|
127
|
+
assert len(header) == LEVELDBLOG_HEADER_LEN, (
|
128
|
+
f"record header is {len(header)} bytes instead of the expected {LEVELDBLOG_HEADER_LEN}"
|
129
|
+
)
|
130
130
|
fields = struct.unpack("<IHB", header)
|
131
131
|
checksum, dlength, dtype = fields
|
132
132
|
# check len, better fit in the block
|
133
133
|
self._index += LEVELDBLOG_HEADER_LEN
|
134
134
|
data = self._fp.read(dlength)
|
135
135
|
checksum_computed = zlib.crc32(data, self._crc[dtype]) & 0xFFFFFFFF
|
136
|
-
assert (
|
137
|
-
checksum
|
138
|
-
)
|
136
|
+
assert checksum == checksum_computed, (
|
137
|
+
"record checksum is invalid, data may be corrupt"
|
138
|
+
)
|
139
139
|
self._index += dlength
|
140
140
|
return dtype, data
|
141
141
|
|
@@ -158,9 +158,9 @@ class DataStore:
|
|
158
158
|
if dtype == LEVELDBLOG_FULL:
|
159
159
|
return data
|
160
160
|
|
161
|
-
assert (
|
162
|
-
|
163
|
-
)
|
161
|
+
assert dtype == LEVELDBLOG_FIRST, (
|
162
|
+
f"expected record to be type {LEVELDBLOG_FIRST} but found {dtype}"
|
163
|
+
)
|
164
164
|
while True:
|
165
165
|
offset = self._index % LEVELDBLOG_BLOCK_LEN
|
166
166
|
record = self.scan_record()
|
@@ -170,9 +170,9 @@ class DataStore:
|
|
170
170
|
if dtype == LEVELDBLOG_LAST:
|
171
171
|
data += new_data
|
172
172
|
break
|
173
|
-
assert (
|
174
|
-
|
175
|
-
)
|
173
|
+
assert dtype == LEVELDBLOG_MIDDLE, (
|
174
|
+
f"expected record to be type {LEVELDBLOG_MIDDLE} but found {dtype}"
|
175
|
+
)
|
176
176
|
data += new_data
|
177
177
|
return data
|
178
178
|
|
@@ -183,17 +183,17 @@ class DataStore:
|
|
183
183
|
LEVELDBLOG_HEADER_MAGIC,
|
184
184
|
LEVELDBLOG_HEADER_VERSION,
|
185
185
|
)
|
186
|
-
assert (
|
187
|
-
len(data)
|
188
|
-
)
|
186
|
+
assert len(data) == LEVELDBLOG_HEADER_LEN, (
|
187
|
+
f"header size is {len(data)} bytes, expected {LEVELDBLOG_HEADER_LEN}"
|
188
|
+
)
|
189
189
|
self._fp.write(data)
|
190
190
|
self._index += len(data)
|
191
191
|
|
192
192
|
def _read_header(self):
|
193
193
|
header = self._fp.read(LEVELDBLOG_HEADER_LEN)
|
194
|
-
assert (
|
195
|
-
len(header)
|
196
|
-
)
|
194
|
+
assert len(header) == LEVELDBLOG_HEADER_LEN, (
|
195
|
+
f"header is {len(header)} bytes instead of the expected {LEVELDBLOG_HEADER_LEN}"
|
196
|
+
)
|
197
197
|
ident, magic, version = struct.unpack("<4sHB", header)
|
198
198
|
if ident != strtobytes(LEVELDBLOG_HEADER_IDENT):
|
199
199
|
raise Exception("Invalid header")
|
wandb/sdk/internal/handler.py
CHANGED
@@ -37,7 +37,6 @@ from wandb.proto.wandb_internal_pb2 import (
|
|
37
37
|
|
38
38
|
from ..interface.interface_queue import InterfaceQueue
|
39
39
|
from ..lib import handler_util, proto_util
|
40
|
-
from ..wandb_metadata import Metadata
|
41
40
|
from . import context, sample, tb_watcher
|
42
41
|
from .settings_static import SettingsStatic
|
43
42
|
|
@@ -115,7 +114,6 @@ class HandleManager:
|
|
115
114
|
self._context_keeper = context_keeper
|
116
115
|
|
117
116
|
self._tb_watcher = None
|
118
|
-
self._metadata: Optional[Metadata] = None
|
119
117
|
self._step = 0
|
120
118
|
|
121
119
|
self._track_time = None
|
@@ -173,9 +171,6 @@ class HandleManager:
|
|
173
171
|
def handle_request_cancel(self, record: Record) -> None:
|
174
172
|
self._dispatch_record(record)
|
175
173
|
|
176
|
-
def handle_request_metadata(self, record: Record) -> None:
|
177
|
-
logger.warning("Metadata updates are ignored when using the legacy service.")
|
178
|
-
|
179
174
|
def handle_request_defer(self, record: Record) -> None:
|
180
175
|
defer = record.request.defer
|
181
176
|
state = defer.state
|
@@ -649,6 +644,9 @@ class HandleManager:
|
|
649
644
|
def handle_footer(self, record: Record) -> None:
|
650
645
|
self._dispatch_record(record)
|
651
646
|
|
647
|
+
def handle_metadata(self, record: Record) -> None:
|
648
|
+
self._dispatch_record(record)
|
649
|
+
|
652
650
|
def handle_request_attach(self, record: Record) -> None:
|
653
651
|
result = proto_util._result_from_record(record)
|
654
652
|
attach_id = record.request.attach.attach_id
|
@@ -362,6 +362,7 @@ class Api:
|
|
362
362
|
self.server_create_run_queue_supports_priority: Optional[bool] = None
|
363
363
|
self.server_supports_template_variables: Optional[bool] = None
|
364
364
|
self.server_push_to_run_queue_supports_priority: Optional[bool] = None
|
365
|
+
|
365
366
|
self._server_features_cache: Optional[Dict[str, bool]] = None
|
366
367
|
|
367
368
|
def gql(self, *args: Any, **kwargs: Any) -> Any:
|
@@ -4661,3 +4662,56 @@ class Api:
|
|
4661
4662
|
success: bool = response["stopRun"].get("success")
|
4662
4663
|
|
4663
4664
|
return success
|
4665
|
+
|
4666
|
+
@normalize_exceptions
|
4667
|
+
def create_custom_chart(
|
4668
|
+
self,
|
4669
|
+
entity: str,
|
4670
|
+
name: str,
|
4671
|
+
display_name: str,
|
4672
|
+
spec_type: str,
|
4673
|
+
access: str,
|
4674
|
+
spec: Union[str, Mapping[str, Any]],
|
4675
|
+
) -> Optional[Dict[str, Any]]:
|
4676
|
+
if not isinstance(spec, str):
|
4677
|
+
spec = json.dumps(spec)
|
4678
|
+
|
4679
|
+
mutation = gql(
|
4680
|
+
"""
|
4681
|
+
mutation CreateCustomChart(
|
4682
|
+
$entity: String!
|
4683
|
+
$name: String!
|
4684
|
+
$displayName: String!
|
4685
|
+
$type: String!
|
4686
|
+
$access: String!
|
4687
|
+
$spec: JSONString!
|
4688
|
+
) {
|
4689
|
+
createCustomChart(
|
4690
|
+
input: {
|
4691
|
+
entity: $entity
|
4692
|
+
name: $name
|
4693
|
+
displayName: $displayName
|
4694
|
+
type: $type
|
4695
|
+
access: $access
|
4696
|
+
spec: $spec
|
4697
|
+
}
|
4698
|
+
) {
|
4699
|
+
chart { id }
|
4700
|
+
}
|
4701
|
+
}
|
4702
|
+
"""
|
4703
|
+
)
|
4704
|
+
|
4705
|
+
variable_values = {
|
4706
|
+
"entity": entity,
|
4707
|
+
"name": name,
|
4708
|
+
"displayName": display_name,
|
4709
|
+
"type": spec_type,
|
4710
|
+
"access": access,
|
4711
|
+
"spec": spec,
|
4712
|
+
}
|
4713
|
+
|
4714
|
+
result: Optional[Dict[str, Any]] = self.gql(mutation, variable_values)[
|
4715
|
+
"createCustomChart"
|
4716
|
+
]
|
4717
|
+
return result
|
wandb/sdk/internal/sender.py
CHANGED
@@ -63,6 +63,7 @@ if TYPE_CHECKING:
|
|
63
63
|
ArtifactManifest,
|
64
64
|
ArtifactManifestEntry,
|
65
65
|
ArtifactRecord,
|
66
|
+
EnvironmentRecord,
|
66
67
|
HttpResponse,
|
67
68
|
LocalInfo,
|
68
69
|
Record,
|
@@ -212,6 +213,7 @@ class SendManager:
|
|
212
213
|
_context_keeper: context.ContextKeeper
|
213
214
|
|
214
215
|
_telemetry_obj: telemetry.TelemetryRecord
|
216
|
+
_environment_obj: "EnvironmentRecord"
|
215
217
|
_fs: Optional["file_stream.FileStreamApi"]
|
216
218
|
_run: Optional["RunRecord"]
|
217
219
|
_entity: Optional[str]
|
@@ -268,6 +270,7 @@ class SendManager:
|
|
268
270
|
|
269
271
|
self._start_time: int = 0
|
270
272
|
self._telemetry_obj = telemetry.TelemetryRecord()
|
273
|
+
self._environment_obj = wandb_internal_pb2.EnvironmentRecord()
|
271
274
|
self._config_metric_pbdict_list: List[Dict[int, Any]] = []
|
272
275
|
self._metadata_summary: Dict[str, Any] = defaultdict()
|
273
276
|
self._cached_summary: Dict[str, Any] = dict()
|
@@ -790,12 +793,12 @@ class SendManager:
|
|
790
793
|
|
791
794
|
def _config_backend_dict(self) -> sender_config.BackendConfigDict:
|
792
795
|
config = self._consolidated_config or sender_config.ConfigState()
|
793
|
-
|
794
796
|
return config.to_backend_dict(
|
795
797
|
telemetry_record=self._telemetry_obj,
|
796
798
|
framework=self._telemetry_get_framework(),
|
797
799
|
start_time_millis=self._start_time,
|
798
800
|
metric_pbdicts=self._config_metric_pbdict_list,
|
801
|
+
environment_record=self._environment_obj,
|
799
802
|
)
|
800
803
|
|
801
804
|
def _config_save(
|
@@ -1379,11 +1382,11 @@ class SendManager:
|
|
1379
1382
|
next_idx = len(self._config_metric_pbdict_list)
|
1380
1383
|
self._config_metric_pbdict_list.append(md)
|
1381
1384
|
self._config_metric_index_dict[metric.name] = next_idx
|
1382
|
-
self.
|
1385
|
+
self._debounce_config()
|
1383
1386
|
|
1384
1387
|
def _update_telemetry_record(self, telemetry: telemetry.TelemetryRecord) -> None:
|
1385
1388
|
self._telemetry_obj.MergeFrom(telemetry)
|
1386
|
-
self.
|
1389
|
+
self._debounce_config()
|
1387
1390
|
|
1388
1391
|
def send_telemetry(self, record: "Record") -> None:
|
1389
1392
|
self._update_telemetry_record(record.telemetry)
|
@@ -1417,6 +1420,23 @@ class SendManager:
|
|
1417
1420
|
# tbrecord watching threads are handled by handler.py
|
1418
1421
|
pass
|
1419
1422
|
|
1423
|
+
def _update_environment_record(self, environment: "EnvironmentRecord") -> None:
|
1424
|
+
self._environment_obj.MergeFrom(environment)
|
1425
|
+
self._debounce_config()
|
1426
|
+
|
1427
|
+
def send_environment(self, record: "Record") -> None:
|
1428
|
+
"""Inject environment info into config and upload as a JSON file."""
|
1429
|
+
self._update_environment_record(record.environment)
|
1430
|
+
|
1431
|
+
environment_json = json.dumps(proto_util.message_to_dict(self._environment_obj))
|
1432
|
+
|
1433
|
+
with open(
|
1434
|
+
os.path.join(self._settings.files_dir, filenames.METADATA_FNAME), "w"
|
1435
|
+
) as f:
|
1436
|
+
f.write(environment_json)
|
1437
|
+
|
1438
|
+
self._save_file(interface.GlobStr(filenames.METADATA_FNAME), policy="now")
|
1439
|
+
|
1420
1440
|
def send_request_link_artifact(self, record: "Record") -> None:
|
1421
1441
|
if not (record.control.req_resp or record.control.mailbox_slot):
|
1422
1442
|
raise ValueError(
|
@@ -79,6 +79,7 @@ class ConfigState:
|
|
79
79
|
framework: Optional[str],
|
80
80
|
start_time_millis: int,
|
81
81
|
metric_pbdicts: Sequence[Dict[int, Any]],
|
82
|
+
environment_record: wandb_internal_pb2.EnvironmentRecord,
|
82
83
|
) -> BackendConfigDict:
|
83
84
|
"""Returns a dictionary representation expected by the backend.
|
84
85
|
|
@@ -125,6 +126,14 @@ class ConfigState:
|
|
125
126
|
if metric_pbdicts:
|
126
127
|
wandb_internal["m"] = metric_pbdicts
|
127
128
|
|
129
|
+
###################################################
|
130
|
+
# Environment
|
131
|
+
###################################################
|
132
|
+
writer_id = environment_record.writer_id
|
133
|
+
if writer_id:
|
134
|
+
environment_dict = proto_util.message_to_dict(environment_record)
|
135
|
+
wandb_internal["e"] = {writer_id: environment_dict}
|
136
|
+
|
128
137
|
return BackendConfigDict(
|
129
138
|
{
|
130
139
|
key: {
|
@@ -370,9 +370,9 @@ class LaunchProject:
|
|
370
370
|
|
371
371
|
def set_job_entry_point(self, command: List[str]) -> "EntryPoint":
|
372
372
|
"""Set job entrypoint for the project."""
|
373
|
-
assert (
|
374
|
-
|
375
|
-
)
|
373
|
+
assert self._entry_point is None, (
|
374
|
+
"Cannot set entry point twice. Use LaunchProject.override_entrypoint"
|
375
|
+
)
|
376
376
|
new_entrypoint = EntryPoint(name=command[-1], command=command)
|
377
377
|
self._entry_point = new_entrypoint
|
378
378
|
return new_entrypoint
|
wandb/sdk/launch/agent/agent.py
CHANGED
@@ -76,9 +76,9 @@ class JobSpecAndQueue:
|
|
76
76
|
def _convert_access(access: str) -> str:
|
77
77
|
"""Convert access string to a value accepted by wandb."""
|
78
78
|
access = access.upper()
|
79
|
-
assert (
|
80
|
-
access
|
81
|
-
)
|
79
|
+
assert access == "PROJECT" or access == "USER", (
|
80
|
+
"Queue access must be either project or user"
|
81
|
+
)
|
82
82
|
return access
|
83
83
|
|
84
84
|
|
@@ -44,7 +44,9 @@ class JobAndRunStatusTracker:
|
|
44
44
|
self.run_id is not None
|
45
45
|
and self.project is not None
|
46
46
|
and self.entity is not None
|
47
|
-
),
|
47
|
+
), (
|
48
|
+
"Job tracker does not contain run info. Update with run info before checking if run stopped"
|
49
|
+
)
|
48
50
|
check_stop = event_loop_thread_exec(api.api.check_stop_requested)
|
49
51
|
try:
|
50
52
|
return bool(await check_stop(self.project, self.entity, self.run_id))
|
wandb/sdk/launch/utils.py
CHANGED
@@ -380,9 +380,9 @@ def diff_pip_requirements(req_1: List[str], req_2: List[str]) -> Dict[str, str]:
|
|
380
380
|
else:
|
381
381
|
raise ValueError(f"Unable to parse pip requirements file line: {line}")
|
382
382
|
if _name is not None:
|
383
|
-
assert re.match(
|
384
|
-
|
385
|
-
)
|
383
|
+
assert re.match(_VALID_PIP_PACKAGE_REGEX, _name), (
|
384
|
+
f"Invalid pip package name {_name}"
|
385
|
+
)
|
386
386
|
d[_name] = _version
|
387
387
|
return d
|
388
388
|
|