wandb 0.20.1rc20250604__py3-none-win32.whl → 0.21.0__py3-none-win32.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (125) hide show
  1. wandb/__init__.py +3 -6
  2. wandb/__init__.pyi +24 -23
  3. wandb/analytics/sentry.py +2 -2
  4. wandb/apis/importers/internals/internal.py +0 -3
  5. wandb/apis/internal.py +3 -0
  6. wandb/apis/paginator.py +17 -4
  7. wandb/apis/public/api.py +85 -4
  8. wandb/apis/public/artifacts.py +10 -8
  9. wandb/apis/public/files.py +5 -5
  10. wandb/apis/public/projects.py +44 -3
  11. wandb/apis/public/registries/{utils.py → _utils.py} +12 -12
  12. wandb/apis/public/registries/registries_search.py +2 -2
  13. wandb/apis/public/registries/registry.py +19 -18
  14. wandb/apis/public/reports.py +64 -8
  15. wandb/apis/public/runs.py +16 -23
  16. wandb/automations/__init__.py +10 -10
  17. wandb/automations/_filters/run_metrics.py +0 -2
  18. wandb/automations/_utils.py +0 -2
  19. wandb/automations/actions.py +0 -2
  20. wandb/automations/automations.py +0 -2
  21. wandb/automations/events.py +0 -2
  22. wandb/bin/gpu_stats.exe +0 -0
  23. wandb/bin/wandb-core +0 -0
  24. wandb/cli/beta.py +1 -7
  25. wandb/cli/cli.py +0 -30
  26. wandb/env.py +0 -6
  27. wandb/integration/catboost/catboost.py +6 -2
  28. wandb/integration/kfp/kfp_patch.py +3 -1
  29. wandb/integration/sb3/sb3.py +3 -3
  30. wandb/integration/ultralytics/callback.py +6 -2
  31. wandb/plot/__init__.py +2 -0
  32. wandb/plot/bar.py +30 -29
  33. wandb/plot/confusion_matrix.py +75 -71
  34. wandb/plot/histogram.py +26 -25
  35. wandb/plot/line.py +33 -32
  36. wandb/plot/line_series.py +100 -103
  37. wandb/plot/pr_curve.py +33 -32
  38. wandb/plot/roc_curve.py +38 -38
  39. wandb/plot/scatter.py +27 -27
  40. wandb/proto/v3/wandb_internal_pb2.py +366 -385
  41. wandb/proto/v3/wandb_settings_pb2.py +2 -2
  42. wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
  43. wandb/proto/v4/wandb_internal_pb2.py +352 -356
  44. wandb/proto/v4/wandb_settings_pb2.py +2 -2
  45. wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
  46. wandb/proto/v5/wandb_internal_pb2.py +352 -356
  47. wandb/proto/v5/wandb_settings_pb2.py +2 -2
  48. wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
  49. wandb/proto/v6/wandb_internal_pb2.py +352 -356
  50. wandb/proto/v6/wandb_settings_pb2.py +2 -2
  51. wandb/proto/v6/wandb_telemetry_pb2.py +10 -10
  52. wandb/sdk/artifacts/_generated/__init__.py +12 -1
  53. wandb/sdk/artifacts/_generated/input_types.py +20 -2
  54. wandb/sdk/artifacts/_generated/link_artifact.py +21 -0
  55. wandb/sdk/artifacts/_generated/operations.py +9 -0
  56. wandb/sdk/artifacts/_validators.py +40 -2
  57. wandb/sdk/artifacts/artifact.py +163 -21
  58. wandb/sdk/artifacts/storage_handlers/s3_handler.py +42 -1
  59. wandb/sdk/backend/backend.py +1 -1
  60. wandb/sdk/data_types/base_types/media.py +9 -7
  61. wandb/sdk/data_types/base_types/wb_value.py +6 -6
  62. wandb/sdk/data_types/saved_model.py +3 -3
  63. wandb/sdk/data_types/table.py +41 -41
  64. wandb/sdk/data_types/trace_tree.py +12 -12
  65. wandb/sdk/interface/interface.py +8 -19
  66. wandb/sdk/interface/interface_shared.py +7 -16
  67. wandb/sdk/internal/datastore.py +18 -18
  68. wandb/sdk/internal/handler.py +4 -74
  69. wandb/sdk/internal/internal_api.py +54 -0
  70. wandb/sdk/internal/sender.py +23 -3
  71. wandb/sdk/internal/sender_config.py +9 -0
  72. wandb/sdk/launch/_project_spec.py +3 -3
  73. wandb/sdk/launch/agent/agent.py +3 -3
  74. wandb/sdk/launch/agent/job_status_tracker.py +3 -1
  75. wandb/sdk/launch/utils.py +3 -3
  76. wandb/sdk/lib/console_capture.py +66 -19
  77. wandb/sdk/lib/printer.py +6 -7
  78. wandb/sdk/lib/progress.py +1 -3
  79. wandb/sdk/lib/service/ipc_support.py +13 -0
  80. wandb/sdk/lib/{service_connection.py → service/service_connection.py} +20 -56
  81. wandb/sdk/lib/service/service_port_file.py +105 -0
  82. wandb/sdk/lib/service/service_process.py +111 -0
  83. wandb/sdk/lib/service/service_token.py +164 -0
  84. wandb/sdk/lib/sock_client.py +8 -12
  85. wandb/sdk/wandb_init.py +1 -5
  86. wandb/sdk/wandb_require.py +9 -21
  87. wandb/sdk/wandb_run.py +23 -137
  88. wandb/sdk/wandb_settings.py +233 -80
  89. wandb/sdk/wandb_setup.py +2 -13
  90. {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/METADATA +1 -3
  91. {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/RECORD +94 -120
  92. wandb/sdk/internal/flow_control.py +0 -263
  93. wandb/sdk/internal/internal.py +0 -401
  94. wandb/sdk/internal/internal_util.py +0 -97
  95. wandb/sdk/internal/system/__init__.py +0 -0
  96. wandb/sdk/internal/system/assets/__init__.py +0 -25
  97. wandb/sdk/internal/system/assets/aggregators.py +0 -31
  98. wandb/sdk/internal/system/assets/asset_registry.py +0 -20
  99. wandb/sdk/internal/system/assets/cpu.py +0 -163
  100. wandb/sdk/internal/system/assets/disk.py +0 -210
  101. wandb/sdk/internal/system/assets/gpu.py +0 -416
  102. wandb/sdk/internal/system/assets/gpu_amd.py +0 -233
  103. wandb/sdk/internal/system/assets/interfaces.py +0 -205
  104. wandb/sdk/internal/system/assets/ipu.py +0 -177
  105. wandb/sdk/internal/system/assets/memory.py +0 -166
  106. wandb/sdk/internal/system/assets/network.py +0 -125
  107. wandb/sdk/internal/system/assets/open_metrics.py +0 -293
  108. wandb/sdk/internal/system/assets/tpu.py +0 -154
  109. wandb/sdk/internal/system/assets/trainium.py +0 -393
  110. wandb/sdk/internal/system/env_probe_helpers.py +0 -13
  111. wandb/sdk/internal/system/system_info.py +0 -248
  112. wandb/sdk/internal/system/system_monitor.py +0 -224
  113. wandb/sdk/internal/writer.py +0 -204
  114. wandb/sdk/lib/service_token.py +0 -93
  115. wandb/sdk/service/__init__.py +0 -0
  116. wandb/sdk/service/_startup_debug.py +0 -22
  117. wandb/sdk/service/port_file.py +0 -53
  118. wandb/sdk/service/server.py +0 -107
  119. wandb/sdk/service/server_sock.py +0 -286
  120. wandb/sdk/service/service.py +0 -252
  121. wandb/sdk/service/streams.py +0 -425
  122. wandb/sdk/wandb_metadata.py +0 -623
  123. {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/WHEEL +0 -0
  124. {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/entry_points.txt +0 -0
  125. {wandb-0.20.1rc20250604.dist-info → wandb-0.21.0.dist-info}/licenses/LICENSE +0 -0
@@ -3,6 +3,7 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import os
6
+ import re
6
7
  import time
7
8
  from pathlib import PurePosixPath
8
9
  from typing import TYPE_CHECKING, Sequence
@@ -52,10 +53,20 @@ class S3Handler(StorageHandler):
52
53
  required="s3:// references requires the boto3 library, run pip install wandb[aws]",
53
54
  lazy=False,
54
55
  )
56
+
57
+ from botocore.client import Config # type: ignore
58
+
59
+ s3_endpoint = os.getenv("AWS_S3_ENDPOINT_URL")
60
+ config = (
61
+ Config(s3={"addressing_style": "virtual"})
62
+ if s3_endpoint and self._is_coreweave_endpoint(s3_endpoint)
63
+ else None
64
+ )
55
65
  self._s3 = boto.session.Session().resource(
56
66
  "s3",
57
- endpoint_url=os.getenv("AWS_S3_ENDPOINT_URL"),
67
+ endpoint_url=s3_endpoint,
58
68
  region_name=os.getenv("AWS_REGION"),
69
+ config=config,
59
70
  )
60
71
  self._botocore = util.get_module("botocore")
61
72
  return self._s3
@@ -296,3 +307,33 @@ class S3Handler(StorageHandler):
296
307
  if hasattr(obj, "version_id") and obj.version_id and obj.version_id != "null":
297
308
  extra["versionID"] = obj.version_id
298
309
  return extra
310
+
311
+ _CW_LEGACY_NETLOC_REGEX: re.Pattern[str] = re.compile(
312
+ r"""
313
+ # accelerated endpoints like "accel-object.<region>.coreweave.com"
314
+ accel-object\.[a-z0-9-]+\.coreweave\.com
315
+ |
316
+ # URLs like "object.<region>.coreweave.com"
317
+ object\.[a-z0-9-]+\.coreweave\.com
318
+ """,
319
+ flags=re.VERBOSE,
320
+ )
321
+
322
+ def _is_coreweave_endpoint(self, endpoint_url: str) -> bool:
323
+ if not (url := endpoint_url.strip().rstrip("/")):
324
+ return False
325
+
326
+ # Only http://cwlota.com is supported using HTTP
327
+ if url == "http://cwlota.com":
328
+ return True
329
+
330
+ # Enforce HTTPS otherwise
331
+ https_url = url if url.startswith("https://") else f"https://{url}"
332
+ netloc = urlparse(https_url).netloc
333
+ return bool(
334
+ # Match for https://cwobject.com
335
+ (netloc == "cwobject.com")
336
+ or
337
+ # Check for legacy endpoints
338
+ self._CW_LEGACY_NETLOC_REGEX.fullmatch(netloc)
339
+ )
@@ -11,7 +11,7 @@ from wandb.sdk.interface.interface import InterfaceBase
11
11
  from wandb.sdk.wandb_settings import Settings
12
12
 
13
13
  if TYPE_CHECKING:
14
- from wandb.sdk.lib import service_connection
14
+ from wandb.sdk.lib.service import service_connection
15
15
 
16
16
  logger = logging.getLogger("wandb")
17
17
 
@@ -127,9 +127,9 @@ class Media(WBValue):
127
127
  self._path = path
128
128
  self._is_tmp = is_tmp
129
129
  self._extension = extension
130
- assert extension is None or path.endswith(
131
- extension
132
- ), f'Media file extension "{extension}" must occur at the end of path "{path}".'
130
+ assert extension is None or path.endswith(extension), (
131
+ f'Media file extension "{extension}" must occur at the end of path "{path}".'
132
+ )
133
133
 
134
134
  with open(self._path, "rb") as f:
135
135
  self._sha256 = hashlib.sha256(f.read()).hexdigest()
@@ -247,11 +247,13 @@ class Media(WBValue):
247
247
  json_obj["_latest_artifact_path"] = artifact_entry_latest_url
248
248
 
249
249
  if artifact_entry_url is None or self.is_bound():
250
- assert self.is_bound(), f"Value of type {type(self).__name__} must be bound to a run with bind_to_run() before being serialized to JSON."
250
+ assert self.is_bound(), (
251
+ f"Value of type {type(self).__name__} must be bound to a run with bind_to_run() before being serialized to JSON."
252
+ )
251
253
 
252
- assert (
253
- self._run is run
254
- ), "We don't support referring to media files across runs."
254
+ assert self._run is run, (
255
+ "We don't support referring to media files across runs."
256
+ )
255
257
 
256
258
  # The following two assertions are guaranteed to pass
257
259
  # by definition is_bound, but are needed for
@@ -218,17 +218,17 @@ class WBValue:
218
218
  def _set_artifact_source(
219
219
  self, artifact: "Artifact", name: Optional[str] = None
220
220
  ) -> None:
221
- assert (
222
- self._artifact_source is None
223
- ), f"Cannot update artifact_source. Existing source: {self._artifact_source.artifact}/{self._artifact_source.name}"
221
+ assert self._artifact_source is None, (
222
+ f"Cannot update artifact_source. Existing source: {self._artifact_source.artifact}/{self._artifact_source.name}"
223
+ )
224
224
  self._artifact_source = _WBValueArtifactSource(artifact, name)
225
225
 
226
226
  def _set_artifact_target(
227
227
  self, artifact: "Artifact", name: Optional[str] = None
228
228
  ) -> None:
229
- assert (
230
- self._artifact_target is None
231
- ), f"Cannot update artifact_target. Existing target: {self._artifact_target.artifact}/{self._artifact_target.name}"
229
+ assert self._artifact_target is None, (
230
+ f"Cannot update artifact_target. Existing target: {self._artifact_target.artifact}/{self._artifact_target.name}"
231
+ )
232
232
  self._artifact_target = _WBValueArtifactTarget(artifact, name)
233
233
 
234
234
  def _get_artifact_entry_ref_url(self) -> Optional[str]:
@@ -257,9 +257,9 @@ class _SavedModel(WBValue, Generic[SavedModelObjType]):
257
257
  self._model_obj = None
258
258
 
259
259
  def _set_obj(self, model_obj: Any) -> None:
260
- assert model_obj is not None and self._validate_obj(
261
- model_obj
262
- ), f"Invalid model object {model_obj}"
260
+ assert model_obj is not None and self._validate_obj(model_obj), (
261
+ f"Invalid model object {model_obj}"
262
+ )
263
263
  self._model_obj = model_obj
264
264
 
265
265
  def _dump(self, target_path: str) -> None:
@@ -203,8 +203,8 @@ class Table(Media):
203
203
  This means you can embed `Images`, `Video`, `Audio`, and other sorts of rich, annotated media
204
204
  directly in Tables, alongside other traditional scalar values.
205
205
 
206
- This class is the primary class used to generate the Table Visualizer
207
- in the UI: https://docs.wandb.ai/guides/data-vis/tables.
206
+ This class is the primary class used to generate the W&B Tables
207
+ https://docs.wandb.ai/guides/models/tables/.
208
208
  """
209
209
 
210
210
  MAX_ROWS = 10000
@@ -292,9 +292,9 @@ class Table(Media):
292
292
  self._init_from_list([], columns, optional, dtype)
293
293
 
294
294
  def _validate_log_mode(self, log_mode):
295
- assert (
296
- log_mode in _SUPPORTED_LOGGING_MODES
297
- ), f"Invalid log_mode: {log_mode}. Must be one of {_SUPPORTED_LOGGING_MODES}"
295
+ assert log_mode in _SUPPORTED_LOGGING_MODES, (
296
+ f"Invalid log_mode: {log_mode}. Must be one of {_SUPPORTED_LOGGING_MODES}"
297
+ )
298
298
 
299
299
  @staticmethod
300
300
  def _assert_valid_columns(columns):
@@ -314,9 +314,9 @@ class Table(Media):
314
314
  self.add_data(*row)
315
315
 
316
316
  def _init_from_ndarray(self, ndarray, columns, optional=True, dtype=None):
317
- assert util.is_numpy_array(
318
- ndarray
319
- ), "ndarray argument expects a `numpy.ndarray` object"
317
+ assert util.is_numpy_array(ndarray), (
318
+ "ndarray argument expects a `numpy.ndarray` object"
319
+ )
320
320
  self.data = []
321
321
  self._assert_valid_columns(columns)
322
322
  self.columns = columns
@@ -325,9 +325,9 @@ class Table(Media):
325
325
  self.add_data(*row)
326
326
 
327
327
  def _init_from_dataframe(self, dataframe, columns, optional=True, dtype=None):
328
- assert util.is_pandas_data_frame(
329
- dataframe
330
- ), "dataframe argument expects a `pandas.core.frame.DataFrame` object"
328
+ assert util.is_pandas_data_frame(dataframe), (
329
+ "dataframe argument expects a `pandas.core.frame.DataFrame` object"
330
+ )
331
331
  self.data = []
332
332
  columns = list(dataframe.columns)
333
333
  self._assert_valid_columns(columns)
@@ -440,17 +440,17 @@ class Table(Media):
440
440
  is_fk = isinstance(wbtype, _ForeignKeyType)
441
441
  is_fi = isinstance(wbtype, _ForeignIndexType)
442
442
  if is_pk or is_fk or is_fi:
443
- assert (
444
- not optional
445
- ), "Primary keys, foreign keys, and foreign indexes cannot be optional."
443
+ assert not optional, (
444
+ "Primary keys, foreign keys, and foreign indexes cannot be optional."
445
+ )
446
446
 
447
447
  if (is_fk or is_fk) and id(wbtype.params["table"]) == id(self):
448
448
  raise AssertionError("Cannot set a foreign table reference to same table.")
449
449
 
450
450
  if is_pk:
451
- assert (
452
- self._pk_col is None
453
- ), f"Cannot have multiple primary keys - {self._pk_col} is already set as the primary key."
451
+ assert self._pk_col is None, (
452
+ f"Cannot have multiple primary keys - {self._pk_col} is already set as the primary key."
453
+ )
454
454
 
455
455
  # Update the column type
456
456
  self._column_types.params["type_map"][col_name] = wbtype
@@ -464,21 +464,21 @@ class Table(Media):
464
464
 
465
465
  def _eq_debug(self, other, should_assert=False):
466
466
  eq = isinstance(other, Table)
467
- assert (
468
- not should_assert or eq
469
- ), f"Found type {other.__class__}, expected {Table}"
467
+ assert not should_assert or eq, (
468
+ f"Found type {other.__class__}, expected {Table}"
469
+ )
470
470
  eq = eq and len(self.data) == len(other.data)
471
- assert (
472
- not should_assert or eq
473
- ), f"Found {len(other.data)} rows, expected {len(self.data)}"
471
+ assert not should_assert or eq, (
472
+ f"Found {len(other.data)} rows, expected {len(self.data)}"
473
+ )
474
474
  eq = eq and self.columns == other.columns
475
- assert (
476
- not should_assert or eq
477
- ), f"Found columns {other.columns}, expected {self.columns}"
475
+ assert not should_assert or eq, (
476
+ f"Found columns {other.columns}, expected {self.columns}"
477
+ )
478
478
  eq = eq and self._column_types == other._column_types
479
- assert (
480
- not should_assert or eq
481
- ), f"Found column type {other._column_types}, expected column type {self._column_types}"
479
+ assert not should_assert or eq, (
480
+ f"Found column type {other._column_types}, expected column type {self._column_types}"
481
+ )
482
482
  if eq:
483
483
  for row_ndx in range(len(self.data)):
484
484
  for col_ndx in range(len(self.data[row_ndx])):
@@ -487,9 +487,9 @@ class Table(Media):
487
487
  if util.is_numpy_array(_eq):
488
488
  _eq = ((_eq * -1) + 1).sum() == 0
489
489
  eq = eq and _eq
490
- assert (
491
- not should_assert or eq
492
- ), f"Unequal data at row_ndx {row_ndx} col_ndx {col_ndx}: found {other.data[row_ndx][col_ndx]}, expected {self.data[row_ndx][col_ndx]}"
490
+ assert not should_assert or eq, (
491
+ f"Unequal data at row_ndx {row_ndx} col_ndx {col_ndx}: found {other.data[row_ndx][col_ndx]}, expected {self.data[row_ndx][col_ndx]}"
492
+ )
493
493
  if not eq:
494
494
  return eq
495
495
  return eq
@@ -932,9 +932,9 @@ class Table(Media):
932
932
  assert isinstance(data, list) or is_np
933
933
  assert isinstance(optional, bool)
934
934
  is_first_col = len(self.columns) == 0
935
- assert is_first_col or len(data) == len(
936
- self.data
937
- ), f"Expected length {len(self.data)}, found {len(data)}"
935
+ assert is_first_col or len(data) == len(self.data), (
936
+ f"Expected length {len(self.data)}, found {len(data)}"
937
+ )
938
938
 
939
939
  # Add the new data
940
940
  for ndx in range(max(len(data), len(self.data))):
@@ -1257,13 +1257,13 @@ class JoinedTable(Media):
1257
1257
 
1258
1258
  def _eq_debug(self, other, should_assert=False):
1259
1259
  eq = isinstance(other, JoinedTable)
1260
- assert (
1261
- not should_assert or eq
1262
- ), f"Found type {other.__class__}, expected {JoinedTable}"
1260
+ assert not should_assert or eq, (
1261
+ f"Found type {other.__class__}, expected {JoinedTable}"
1262
+ )
1263
1263
  eq = eq and self._join_key == other._join_key
1264
- assert (
1265
- not should_assert or eq
1266
- ), f"Found {other._join_key} join key, expected {self._join_key}"
1264
+ assert not should_assert or eq, (
1265
+ f"Found {other._join_key} join key, expected {self._join_key}"
1266
+ )
1267
1267
  eq = eq and self._table1._eq_debug(other._table1, should_assert)
1268
1268
  eq = eq and self._table2._eq_debug(other._table2, should_assert)
1269
1269
  return eq
@@ -261,14 +261,14 @@ class Trace:
261
261
  A Span object.
262
262
  """
263
263
  if kind is not None:
264
- assert (
265
- kind.upper() in SpanKind.__members__
266
- ), "Invalid span kind, can be one of 'LLM', 'AGENT', 'CHAIN', 'TOOL'"
264
+ assert kind.upper() in SpanKind.__members__, (
265
+ "Invalid span kind, can be one of 'LLM', 'AGENT', 'CHAIN', 'TOOL'"
266
+ )
267
267
  kind = SpanKind(kind.upper())
268
268
  if status_code is not None:
269
- assert (
270
- status_code.upper() in StatusCode.__members__
271
- ), "Invalid status code, can be one of 'SUCCESS' or 'ERROR'"
269
+ assert status_code.upper() in StatusCode.__members__, (
270
+ "Invalid status code, can be one of 'SUCCESS' or 'ERROR'"
271
+ )
272
272
  status_code = StatusCode(status_code.upper())
273
273
  if inputs is not None:
274
274
  assert isinstance(inputs, dict), "Inputs must be a dictionary"
@@ -419,9 +419,9 @@ class Trace:
419
419
  Args:
420
420
  value: The kind of the trace to be set.
421
421
  """
422
- assert (
423
- value.upper() in SpanKind.__members__
424
- ), "Invalid span kind, can be one of 'LLM', 'AGENT', 'CHAIN', 'TOOL'"
422
+ assert value.upper() in SpanKind.__members__, (
423
+ "Invalid span kind, can be one of 'LLM', 'AGENT', 'CHAIN', 'TOOL'"
424
+ )
425
425
  self._span.span_kind = SpanKind(value.upper())
426
426
 
427
427
  def log(self, name: str) -> None:
@@ -433,8 +433,8 @@ class Trace:
433
433
  trace_tree = WBTraceTree(self._span, self._model_dict)
434
434
  # NOTE: Does not work for reinit="create_new" runs.
435
435
  # This method should be deprecated and users should call run.log().
436
- assert (
437
- wandb.run is not None
438
- ), "You must call wandb.init() before logging a trace"
436
+ assert wandb.run is not None, (
437
+ "You must call wandb.init() before logging a trace"
438
+ )
439
439
  assert len(name.strip()) > 0, "You must provide a valid name to log the trace"
440
440
  wandb.run.log({name: trace_tree})
@@ -150,8 +150,7 @@ class InterfaceBase:
150
150
  if run._settings.run_notes is not None:
151
151
  proto_run.notes = run._settings.run_notes
152
152
  if run._settings.run_tags is not None:
153
- for tag in run._settings.run_tags:
154
- proto_run.tags.append(tag)
153
+ proto_run.tags.extend(run._settings.run_tags)
155
154
  if run._start_time is not None:
156
155
  proto_run.start_time.FromMicroseconds(int(run._start_time * 1e6))
157
156
  if run._starting_step is not None:
@@ -217,13 +216,6 @@ class InterfaceBase:
217
216
  def _publish_config(self, cfg: pb.ConfigRecord) -> None:
218
217
  raise NotImplementedError
219
218
 
220
- def publish_metadata(self, metadata: pb.MetadataRequest) -> None:
221
- self._publish_metadata(metadata)
222
-
223
- @abstractmethod
224
- def _publish_metadata(self, metadata: pb.MetadataRequest) -> None:
225
- raise NotImplementedError
226
-
227
219
  @abstractmethod
228
220
  def _publish_metric(self, metric: pb.MetricRecord) -> None:
229
221
  raise NotImplementedError
@@ -671,6 +663,13 @@ class InterfaceBase:
671
663
  def _publish_telemetry(self, telem: tpb.TelemetryRecord) -> None:
672
664
  raise NotImplementedError
673
665
 
666
+ def publish_environment(self, environment: pb.EnvironmentRecord) -> None:
667
+ self._publish_environment(environment)
668
+
669
+ @abstractmethod
670
+ def _publish_environment(self, environment: pb.EnvironmentRecord) -> None:
671
+ raise NotImplementedError
672
+
674
673
  def publish_partial_history(
675
674
  self,
676
675
  run: "Run",
@@ -1000,16 +999,6 @@ class InterfaceBase:
1000
999
  ) -> MailboxHandle[pb.Result]:
1001
1000
  raise NotImplementedError
1002
1001
 
1003
- def deliver_get_system_metadata(self) -> MailboxHandle[pb.Result]:
1004
- get_system_metadata = pb.GetSystemMetadataRequest()
1005
- return self._deliver_get_system_metadata(get_system_metadata)
1006
-
1007
- @abstractmethod
1008
- def _deliver_get_system_metadata(
1009
- self, get_system_metadata: pb.GetSystemMetadataRequest
1010
- ) -> MailboxHandle[pb.Result]:
1011
- raise NotImplementedError
1012
-
1013
1002
  def deliver_exit(self, exit_code: Optional[int]) -> MailboxHandle[pb.Result]:
1014
1003
  exit_data = self._make_exit(exit_code)
1015
1004
  return self._deliver_exit(exit_data)
@@ -59,6 +59,10 @@ class InterfaceShared(InterfaceBase):
59
59
  rec = self._make_record(telemetry=telem)
60
60
  self._publish(rec)
61
61
 
62
+ def _publish_environment(self, environment: pb.EnvironmentRecord) -> None:
63
+ rec = self._make_record(environment=environment)
64
+ self._publish(rec)
65
+
62
66
  def _publish_job_input(
63
67
  self, job_input: pb.JobInputRequest
64
68
  ) -> MailboxHandle[pb.Result]:
@@ -106,11 +110,9 @@ class InterfaceShared(InterfaceBase):
106
110
  summary_record: Optional[pb.SummaryRecordRequest] = None,
107
111
  telemetry_record: Optional[pb.TelemetryRecordRequest] = None,
108
112
  get_system_metrics: Optional[pb.GetSystemMetricsRequest] = None,
109
- get_system_metadata: Optional[pb.GetSystemMetadataRequest] = None,
110
113
  python_packages: Optional[pb.PythonPackagesRequest] = None,
111
114
  job_input: Optional[pb.JobInputRequest] = None,
112
115
  run_finish_without_exit: Optional[pb.RunFinishWithoutExitRequest] = None,
113
- metadata: Optional[pb.MetadataRequest] = None,
114
116
  ) -> pb.Record:
115
117
  request = pb.Request()
116
118
  if get_summary:
@@ -169,8 +171,6 @@ class InterfaceShared(InterfaceBase):
169
171
  request.telemetry_record.CopyFrom(telemetry_record)
170
172
  elif get_system_metrics:
171
173
  request.get_system_metrics.CopyFrom(get_system_metrics)
172
- elif get_system_metadata:
173
- request.get_system_metadata.CopyFrom(get_system_metadata)
174
174
  elif sync_finish:
175
175
  request.sync_finish.CopyFrom(sync_finish)
176
176
  elif python_packages:
@@ -179,8 +179,6 @@ class InterfaceShared(InterfaceBase):
179
179
  request.job_input.CopyFrom(job_input)
180
180
  elif run_finish_without_exit:
181
181
  request.run_finish_without_exit.CopyFrom(run_finish_without_exit)
182
- elif metadata:
183
- request.metadata.CopyFrom(metadata)
184
182
  else:
185
183
  raise Exception("Invalid request")
186
184
  record = self._make_record(request=request)
@@ -212,6 +210,7 @@ class InterfaceShared(InterfaceBase):
212
210
  use_artifact: Optional[pb.UseArtifactRecord] = None,
213
211
  output: Optional[pb.OutputRecord] = None,
214
212
  output_raw: Optional[pb.OutputRawRecord] = None,
213
+ environment: Optional[pb.EnvironmentRecord] = None,
215
214
  ) -> pb.Record:
216
215
  record = pb.Record()
217
216
  if run:
@@ -254,6 +253,8 @@ class InterfaceShared(InterfaceBase):
254
253
  record.output.CopyFrom(output)
255
254
  elif output_raw:
256
255
  record.output_raw.CopyFrom(output_raw)
256
+ elif environment:
257
+ record.environment.CopyFrom(environment)
257
258
  else:
258
259
  raise Exception("Invalid record")
259
260
  return record
@@ -304,10 +305,6 @@ class InterfaceShared(InterfaceBase):
304
305
  rec = self._make_record(summary=summary)
305
306
  self._publish(rec)
306
307
 
307
- def _publish_metadata(self, metadata: pb.MetadataRequest) -> None:
308
- rec = self._make_request(metadata=metadata)
309
- self._publish(rec)
310
-
311
308
  def _publish_metric(self, metric: pb.MetricRecord) -> None:
312
309
  rec = self._make_record(metric=metric)
313
310
  self._publish(rec)
@@ -422,12 +419,6 @@ class InterfaceShared(InterfaceBase):
422
419
  record = self._make_request(get_system_metrics=get_system_metrics)
423
420
  return self._deliver_record(record)
424
421
 
425
- def _deliver_get_system_metadata(
426
- self, get_system_metadata: pb.GetSystemMetadataRequest
427
- ) -> MailboxHandle[pb.Result]:
428
- record = self._make_request(get_system_metadata=get_system_metadata)
429
- return self._deliver_record(record)
430
-
431
422
  def _deliver_exit(
432
423
  self,
433
424
  exit_data: pb.RunExitRecord,
@@ -124,18 +124,18 @@ class DataStore:
124
124
  header = self._fp.read(LEVELDBLOG_HEADER_LEN)
125
125
  if len(header) == 0:
126
126
  return None
127
- assert (
128
- len(header) == LEVELDBLOG_HEADER_LEN
129
- ), f"record header is {len(header)} bytes instead of the expected {LEVELDBLOG_HEADER_LEN}"
127
+ assert len(header) == LEVELDBLOG_HEADER_LEN, (
128
+ f"record header is {len(header)} bytes instead of the expected {LEVELDBLOG_HEADER_LEN}"
129
+ )
130
130
  fields = struct.unpack("<IHB", header)
131
131
  checksum, dlength, dtype = fields
132
132
  # check len, better fit in the block
133
133
  self._index += LEVELDBLOG_HEADER_LEN
134
134
  data = self._fp.read(dlength)
135
135
  checksum_computed = zlib.crc32(data, self._crc[dtype]) & 0xFFFFFFFF
136
- assert (
137
- checksum == checksum_computed
138
- ), "record checksum is invalid, data may be corrupt"
136
+ assert checksum == checksum_computed, (
137
+ "record checksum is invalid, data may be corrupt"
138
+ )
139
139
  self._index += dlength
140
140
  return dtype, data
141
141
 
@@ -158,9 +158,9 @@ class DataStore:
158
158
  if dtype == LEVELDBLOG_FULL:
159
159
  return data
160
160
 
161
- assert (
162
- dtype == LEVELDBLOG_FIRST
163
- ), f"expected record to be type {LEVELDBLOG_FIRST} but found {dtype}"
161
+ assert dtype == LEVELDBLOG_FIRST, (
162
+ f"expected record to be type {LEVELDBLOG_FIRST} but found {dtype}"
163
+ )
164
164
  while True:
165
165
  offset = self._index % LEVELDBLOG_BLOCK_LEN
166
166
  record = self.scan_record()
@@ -170,9 +170,9 @@ class DataStore:
170
170
  if dtype == LEVELDBLOG_LAST:
171
171
  data += new_data
172
172
  break
173
- assert (
174
- dtype == LEVELDBLOG_MIDDLE
175
- ), f"expected record to be type {LEVELDBLOG_MIDDLE} but found {dtype}"
173
+ assert dtype == LEVELDBLOG_MIDDLE, (
174
+ f"expected record to be type {LEVELDBLOG_MIDDLE} but found {dtype}"
175
+ )
176
176
  data += new_data
177
177
  return data
178
178
 
@@ -183,17 +183,17 @@ class DataStore:
183
183
  LEVELDBLOG_HEADER_MAGIC,
184
184
  LEVELDBLOG_HEADER_VERSION,
185
185
  )
186
- assert (
187
- len(data) == LEVELDBLOG_HEADER_LEN
188
- ), f"header size is {len(data)} bytes, expected {LEVELDBLOG_HEADER_LEN}"
186
+ assert len(data) == LEVELDBLOG_HEADER_LEN, (
187
+ f"header size is {len(data)} bytes, expected {LEVELDBLOG_HEADER_LEN}"
188
+ )
189
189
  self._fp.write(data)
190
190
  self._index += len(data)
191
191
 
192
192
  def _read_header(self):
193
193
  header = self._fp.read(LEVELDBLOG_HEADER_LEN)
194
- assert (
195
- len(header) == LEVELDBLOG_HEADER_LEN
196
- ), f"header is {len(header)} bytes instead of the expected {LEVELDBLOG_HEADER_LEN}"
194
+ assert len(header) == LEVELDBLOG_HEADER_LEN, (
195
+ f"header is {len(header)} bytes instead of the expected {LEVELDBLOG_HEADER_LEN}"
196
+ )
197
197
  ident, magic, version = struct.unpack("<4sHB", header)
198
198
  if ident != strtobytes(LEVELDBLOG_HEADER_IDENT):
199
199
  raise Exception("Invalid header")