lsst-pipe-base 30.0.1rc1__py3-none-any.whl → 30.2025.5100__py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
Files changed (69)
  1. lsst/pipe/base/_instrument.py +20 -31
  2. lsst/pipe/base/_quantumContext.py +3 -3
  3. lsst/pipe/base/_status.py +10 -43
  4. lsst/pipe/base/_task_metadata.py +2 -2
  5. lsst/pipe/base/all_dimensions_quantum_graph_builder.py +3 -8
  6. lsst/pipe/base/automatic_connection_constants.py +1 -20
  7. lsst/pipe/base/cli/cmd/__init__.py +2 -18
  8. lsst/pipe/base/cli/cmd/commands.py +4 -149
  9. lsst/pipe/base/connectionTypes.py +160 -72
  10. lsst/pipe/base/connections.py +9 -6
  11. lsst/pipe/base/execution_reports.py +5 -0
  12. lsst/pipe/base/graph/graph.py +10 -11
  13. lsst/pipe/base/graph/quantumNode.py +4 -4
  14. lsst/pipe/base/graph_walker.py +10 -8
  15. lsst/pipe/base/log_capture.py +80 -40
  16. lsst/pipe/base/mp_graph_executor.py +15 -51
  17. lsst/pipe/base/pipeline.py +6 -5
  18. lsst/pipe/base/pipelineIR.py +8 -2
  19. lsst/pipe/base/pipelineTask.py +7 -5
  20. lsst/pipe/base/pipeline_graph/_dataset_types.py +2 -2
  21. lsst/pipe/base/pipeline_graph/_edges.py +22 -32
  22. lsst/pipe/base/pipeline_graph/_mapping_views.py +7 -4
  23. lsst/pipe/base/pipeline_graph/_pipeline_graph.py +7 -14
  24. lsst/pipe/base/pipeline_graph/expressions.py +2 -2
  25. lsst/pipe/base/pipeline_graph/io.py +10 -7
  26. lsst/pipe/base/pipeline_graph/visualization/_dot.py +12 -13
  27. lsst/pipe/base/pipeline_graph/visualization/_layout.py +18 -16
  28. lsst/pipe/base/pipeline_graph/visualization/_merge.py +7 -4
  29. lsst/pipe/base/pipeline_graph/visualization/_printer.py +10 -10
  30. lsst/pipe/base/pipeline_graph/visualization/_status_annotator.py +0 -7
  31. lsst/pipe/base/prerequisite_helpers.py +1 -2
  32. lsst/pipe/base/quantum_graph/_common.py +20 -19
  33. lsst/pipe/base/quantum_graph/_multiblock.py +31 -37
  34. lsst/pipe/base/quantum_graph/_predicted.py +13 -111
  35. lsst/pipe/base/quantum_graph/_provenance.py +45 -1136
  36. lsst/pipe/base/quantum_graph/aggregator/__init__.py +1 -0
  37. lsst/pipe/base/quantum_graph/aggregator/_communicators.py +289 -204
  38. lsst/pipe/base/quantum_graph/aggregator/_config.py +9 -87
  39. lsst/pipe/base/quantum_graph/aggregator/_ingester.py +12 -13
  40. lsst/pipe/base/quantum_graph/aggregator/_scanner.py +235 -49
  41. lsst/pipe/base/quantum_graph/aggregator/_structs.py +116 -6
  42. lsst/pipe/base/quantum_graph/aggregator/_supervisor.py +39 -29
  43. lsst/pipe/base/quantum_graph/aggregator/_writer.py +351 -34
  44. lsst/pipe/base/quantum_graph/visualization.py +1 -5
  45. lsst/pipe/base/quantum_graph_builder.py +8 -21
  46. lsst/pipe/base/quantum_graph_executor.py +13 -116
  47. lsst/pipe/base/quantum_graph_skeleton.py +29 -31
  48. lsst/pipe/base/quantum_provenance_graph.py +12 -29
  49. lsst/pipe/base/separable_pipeline_executor.py +3 -19
  50. lsst/pipe/base/single_quantum_executor.py +42 -67
  51. lsst/pipe/base/struct.py +0 -4
  52. lsst/pipe/base/testUtils.py +3 -3
  53. lsst/pipe/base/tests/mocks/_storage_class.py +1 -2
  54. lsst/pipe/base/version.py +1 -1
  55. {lsst_pipe_base-30.0.1rc1.dist-info → lsst_pipe_base-30.2025.5100.dist-info}/METADATA +3 -3
  56. lsst_pipe_base-30.2025.5100.dist-info/RECORD +125 -0
  57. {lsst_pipe_base-30.0.1rc1.dist-info → lsst_pipe_base-30.2025.5100.dist-info}/WHEEL +1 -1
  58. lsst/pipe/base/log_on_close.py +0 -76
  59. lsst/pipe/base/quantum_graph/aggregator/_workers.py +0 -303
  60. lsst/pipe/base/quantum_graph/formatter.py +0 -171
  61. lsst/pipe/base/quantum_graph/ingest_graph.py +0 -413
  62. lsst_pipe_base-30.0.1rc1.dist-info/RECORD +0 -129
  63. {lsst_pipe_base-30.0.1rc1.dist-info → lsst_pipe_base-30.2025.5100.dist-info}/entry_points.txt +0 -0
  64. {lsst_pipe_base-30.0.1rc1.dist-info → lsst_pipe_base-30.2025.5100.dist-info}/licenses/COPYRIGHT +0 -0
  65. {lsst_pipe_base-30.0.1rc1.dist-info → lsst_pipe_base-30.2025.5100.dist-info}/licenses/LICENSE +0 -0
  66. {lsst_pipe_base-30.0.1rc1.dist-info → lsst_pipe_base-30.2025.5100.dist-info}/licenses/bsd_license.txt +0 -0
  67. {lsst_pipe_base-30.0.1rc1.dist-info → lsst_pipe_base-30.2025.5100.dist-info}/licenses/gpl-v3.0.txt +0 -0
  68. {lsst_pipe_base-30.0.1rc1.dist-info → lsst_pipe_base-30.2025.5100.dist-info}/top_level.txt +0 -0
  69. {lsst_pipe_base-30.0.1rc1.dist-info → lsst_pipe_base-30.2025.5100.dist-info}/zip-safe +0 -0
lsst/pipe/base/pipeline_graph/expressions.py

@@ -43,7 +43,7 @@ __all__ = (
 
  import dataclasses
  import functools
- from typing import TYPE_CHECKING, Any, Literal
+ from typing import TYPE_CHECKING, Any, Literal, TypeAlias
 
  from lsst.daf.butler.queries.expressions.parser.ply import lex, yacc
 
@@ -268,4 +268,4 @@ def parse(expression: str) -> Node:
  return _ParserYacc().parse(expression)
 
 
- type Node = IdentifierNode | DirectionNode | NotNode | UnionNode | IntersectionNode
+ Node: TypeAlias = IdentifierNode | DirectionNode | NotNode | UnionNode | IntersectionNode
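
Note: a recurring change in this release is replacing Python 3.12's PEP 695 `type` statements with `typing.TypeAlias` assignments so the annotations also work on older interpreters. A minimal sketch of the two spellings (the names below are illustrative, not from the package):

    from typing import TypeAlias

    # Python 3.12+ spelling (PEP 695), shown as a comment because it is a
    # syntax error on earlier interpreters:
    #     type UserId = int | str
    # Pre-3.12 spelling, as adopted throughout this release:
    UserId: TypeAlias = int | str

    def lookup(user: UserId) -> str:
        # Both spellings are equivalent for type checkers.
        return str(user)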
lsst/pipe/base/pipeline_graph/io.py

@@ -33,10 +33,11 @@ __all__ = (
  "SerializedTaskInitNode",
  "SerializedTaskNode",
  "SerializedTaskSubset",
+ "expect_not_none",
  )
 
  from collections.abc import Mapping
- from typing import Any
+ from typing import Any, TypeVar
 
  import networkx
  import pydantic
@@ -52,12 +53,14 @@ from ._pipeline_graph import PipelineGraph
  from ._task_subsets import StepDefinitions, TaskSubset
  from ._tasks import TaskImportMode, TaskInitNode, TaskNode
 
+ _U = TypeVar("_U")
+
  _IO_VERSION_INFO = (0, 0, 1)
  """Version tuple embedded in saved PipelineGraphs.
  """
 
 
- def _expect_not_none[U](value: U | None, msg: str) -> U:
+ def expect_not_none(value: _U | None, msg: str) -> _U:
  """Check that a value is not `None` and return it.
 
  Parameters
@@ -415,7 +418,7 @@ class SerializedTaskNode(pydantic.BaseModel):
  init = self.init.deserialize(
  init_key,
  task_class_name=self.task_class,
- config_str=_expect_not_none(
+ config_str=expect_not_none(
  self.config_str, f"No serialized config file for task with label {key.name!r}."
  ),
  dataset_type_keys=dataset_type_keys,
@@ -544,16 +547,16 @@ class SerializedDatasetTypeNode(pydantic.BaseModel):
  if self.dimensions is not None:
  dataset_type = DatasetType(
  key.name,
- _expect_not_none(
+ expect_not_none(
  self.dimensions,
  f"Serialized dataset type {key.name!r} has no dimensions.",
  ),
- storageClass=_expect_not_none(
+ storageClass=expect_not_none(
  self.storage_class,
  f"Serialized dataset type {key.name!r} has no storage class.",
  ),
  isCalibration=self.is_calibration,
- universe=_expect_not_none(
+ universe=expect_not_none(
  universe,
  f"Serialized dataset type {key.name!r} has dimensions, "
  "but no dimension universe was stored.",
@@ -744,7 +747,7 @@ class SerializedPipelineGraph(pydantic.BaseModel):
  if self.dimensions is not None:
  universe = DimensionUniverse(
  config=DimensionConfig(
- _expect_not_none(
+ expect_not_none(
  self.dimensions,
  "Serialized pipeline graph has not been resolved; "
  "load it is a MutablePipelineGraph instead.",
lsst/pipe/base/pipeline_graph/visualization/_dot.py

@@ -66,7 +66,7 @@ def show_dot(
  ----------
  pipeline_graph : `PipelineGraph`
  Pipeline graph to show.
- stream : `io.TextIO`, optional
+ stream : `TextIO`, optional
  Stream to write the DOT representation to.
  label_edge_connections : `bool`, optional
  If `True`, label edges with their connection names.
@@ -167,22 +167,21 @@ def _render_dataset_type_node(
 
  Parameters
  ----------
- node_key : `NodeKey`
- The key for the node.
- node_data : `~collections.abc.Mapping` [`str`, `typing.Any`]
- The data associated with the node.
- options : `NodeAttributeOptions`
- Options for rendering the node.
- stream : `io.TextIO`
- The stream to write the node to.
- overflow_ref : `int`, optional
+ node_key : NodeKey
+ The key for the node
+ node_data : Mapping[str, Any]
+ The data associated with the node
+ options : NodeAttributeOptions
+ Options for rendering the node
+ stream : TextIO
+ The stream to write the node to
 
  Returns
  -------
  overflow_ref : int
- The reference number for the next overflow node.
+ The reference number for the next overflow node
  overflow_ids : str | None
- The ID of the overflow node, if any.
+ The ID of the overflow node, if any
  """
  labels, label_extras, common_prefix = _format_label(str(node_key), _LABEL_MAX_LINES_SOFT)
  if len(labels) + len(label_extras) <= _LABEL_MAX_LINES_HARD:
@@ -272,7 +271,7 @@ def _render_edge(from_node_id: str, to_node_id: str, stream: TextIO, **kwargs: A
  The unique ID of the node the edge is going to
  stream : TextIO
  The stream to write the edge to
- **kwargs : Any
+ kwargs : Any
  Additional keyword arguments to pass to the edge
  """
  if kwargs:
lsst/pipe/base/pipeline_graph/visualization/_layout.py

@@ -30,7 +30,7 @@ __all__ = ("ColumnSelector", "Layout", "LayoutRow")
 
  import dataclasses
  from collections.abc import Iterable, Iterator, Mapping, Set
- from typing import TextIO
+ from typing import Generic, TextIO, TypeVar
 
  import networkx
  import networkx.algorithms.components
@@ -38,8 +38,10 @@ import networkx.algorithms.dag
  import networkx.algorithms.shortest_paths
  import networkx.algorithms.traversal
 
+ _K = TypeVar("_K")
 
- class Layout[K]:
+
+ class Layout(Generic[_K]):
  """A class that positions nodes and edges in text-art graph visualizations.
 
  Parameters
@@ -71,9 +73,9 @@ class Layout[K]:
  # to be close to that text when possible (or maybe it's historical, and
  # it's just a lot of work to re-invert the algorithm now that it's
  # written).
- self._active_columns: dict[int, set[K]] = {}
+ self._active_columns: dict[int, set[_K]] = {}
  # Mapping from node key to its column.
- self._locations: dict[K, int] = {}
+ self._locations: dict[_K, int] = {}
  # Minimum and maximum column (may go negative; will be shifted as
  # needed before actual display).
  self._x_min = 0
@@ -114,7 +116,7 @@ class Layout[K]:
  for component_xgraph, component_order in component_xgraphs_and_orders:
  self._add_connected_graph(component_xgraph, component_order)
 
- def _add_single_node(self, node: K) -> None:
+ def _add_single_node(self, node: _K) -> None:
  """Add a single node to the layout."""
  assert node not in self._locations
  if not self._locations:
@@ -182,7 +184,7 @@ class Layout[K]:
  return x + 1
 
  def _add_connected_graph(
- self, xgraph: networkx.DiGraph | networkx.MultiDiGraph, order: list[K] | None = None
+ self, xgraph: networkx.DiGraph | networkx.MultiDiGraph, order: list[_K] | None = None
  ) -> None:
  """Add a subgraph whose nodes are connected.
 
@@ -200,7 +202,7 @@ class Layout[K]:
  # "backbone" of our layout; we'll step through this path and add
  # recurse via calls to `_add_graph` on the nodes that we think should
  # go between the backbone nodes.
- backbone: list[K] = networkx.algorithms.dag.dag_longest_path(xgraph, topo_order=order)
+ backbone: list[_K] = networkx.algorithms.dag.dag_longest_path(xgraph, topo_order=order)
  # Add the first backbone node and any ancestors according to the full
  # graph (it can't have ancestors in this _subgraph_ because they'd have
  # been part of the longest path themselves, but the subgraph doesn't
@@ -235,7 +237,7 @@ class Layout[K]:
  remaining.remove_nodes_from(self._locations.keys())
  self._add_graph(remaining)
 
- def _add_blockers_of(self, node: K) -> None:
+ def _add_blockers_of(self, node: _K) -> None:
  """Add all nodes that are ancestors of the given node according to the
  full graph.
  """
@@ -249,7 +251,7 @@ class Layout[K]:
  return (self._x_max - self._x_min) // 2
 
  @property
- def nodes(self) -> Iterable[K]:
+ def nodes(self) -> Iterable[_K]:
  """The graph nodes in the order they appear in the layout."""
  return self._locations.keys()
 
@@ -275,7 +277,7 @@ class Layout[K]:
  return (self._x_max - x) // 2
 
  def __iter__(self) -> Iterator[LayoutRow]:
- active_edges: dict[K, set[K]] = {}
+ active_edges: dict[_K, set[_K]] = {}
  for node, node_x in self._locations.items():
  row = LayoutRow(node, self._external_location(node_x))
  for origin, destinations in active_edges.items():
@@ -293,20 +295,20 @@
 
 
  @dataclasses.dataclass
- class LayoutRow[K]:
+ class LayoutRow(Generic[_K]):
  """Information about a single text-art row in a graph."""
 
- node: K
+ node: _K
  """Key for the node in the exported NetworkX graph."""
 
  x: int
  """Column of the node's symbol and its outgoing edges."""
 
- connecting: list[tuple[int, K]] = dataclasses.field(default_factory=list)
+ connecting: list[tuple[int, _K]] = dataclasses.field(default_factory=list)
  """The columns and node keys of edges that terminate at this row.
  """
 
- continuing: list[tuple[int, K, frozenset[K]]] = dataclasses.field(default_factory=list)
+ continuing: list[tuple[int, _K, frozenset[_K]]] = dataclasses.field(default_factory=list)
  """The columns and node keys of edges that continue through this row.
  """
 
@@ -335,11 +337,11 @@ class ColumnSelector:
  out in that case because it's applied to all candidate columns.
  """
 
- def __call__[K](
+ def __call__(
  self,
  connecting_x: list[int],
  node_x: int,
- active_columns: Mapping[int, Set[K]],
+ active_columns: Mapping[int, Set[_K]],
  x_min: int,
  x_max: int,
  ) -> int:
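
_layout.py applies the same backport to classes: `class Layout[K]:` and `class LayoutRow[K]:` become subclasses of `Generic[_K]` with a module-level `TypeVar`. A self-contained sketch of the pre-3.12 generic-class spelling (illustrative names and methods only, not the package's API):

    from typing import Generic, TypeVar

    _K = TypeVar("_K")

    class Layout(Generic[_K]):
        """A container parameterized over a node-key type (sketch)."""

        def __init__(self) -> None:
            self._locations: dict[_K, int] = {}

        def add(self, node: _K, column: int) -> None:
            self._locations[node] = column

    # The type parameter can be bound explicitly at instantiation time:
    layout: Layout[str] = Layout()
    layout.add("task_a", 0)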
lsst/pipe/base/pipeline_graph/visualization/_merge.py

@@ -38,7 +38,7 @@ import hashlib
  from collections import defaultdict
  from collections.abc import Iterable
  from functools import cached_property
- from typing import Any
+ from typing import Any, TypeVar
 
  import networkx
  import networkx.algorithms.dag
@@ -49,6 +49,9 @@ from lsst.daf.butler import DimensionGroup
  from .._nodes import NodeKey, NodeType
  from ._options import NodeAttributeOptions
 
+ _P = TypeVar("_P")
+ _C = TypeVar("_C")
+
 
  class MergedNodeKey(frozenset[NodeKey]):
  """A key for NetworkX graph nodes that represent multiple similar tasks
@@ -222,11 +225,11 @@ class _MergeKey:
  """
 
  @classmethod
- def from_node_state[P, C](
+ def from_node_state(
  cls,
  state: dict[str, Any],
- parents: Iterable[P],
- children: Iterable[C],
+ parents: Iterable[_P],
+ children: Iterable[_C],
  options: NodeAttributeOptions,
  ) -> _MergeKey:
  """Construct from a NetworkX node attribute state dictionary.
lsst/pipe/base/pipeline_graph/visualization/_printer.py

@@ -30,9 +30,9 @@ __all__ = ("Printer", "make_colorama_printer", "make_default_printer", "make_sim
 
  import sys
  from collections.abc import Callable, Sequence
- from typing import TextIO
+ from typing import Generic, TextIO
 
- from ._layout import Layout, LayoutRow
+ from ._layout import _K, Layout, LayoutRow
 
  _CHAR_DECOMPOSITION = {
  # This mapping provides the "logic" for how to decompose the relevant
@@ -170,7 +170,7 @@ class PrintRow:
  return "".join(self._cells)
 
 
- def _default_get_text[K](node: K, x: int, style: tuple[str, str]) -> str:
+ def _default_get_text(node: _K, x: int, style: tuple[str, str]) -> str:
  """Return the default text to associate with a node.
 
  This function is the default value for the ``get_text`` argument to
@@ -179,7 +179,7 @@ def _default_get_text[K](node: K, x: int, style: tuple[str, str]) -> str:
  return str(node)
 
 
- def _default_get_symbol[K](node: K, x: int) -> str:
+ def _default_get_symbol(node: _K, x: int) -> str:
  """Return the default symbol for a node.
 
  This function is the default value for the ``get_symbol`` argument to
@@ -188,7 +188,7 @@ def _default_get_symbol[K](node: K, x: int) -> str:
  return "⬤"
 
 
- def _default_get_style[K](node: K, x: int) -> tuple[str, str]:
+ def _default_get_style(node: _K, x: int) -> tuple[str, str]:
  """Get the default styling suffix/prefix strings.
 
  This function is the default value for the ``get_style`` argument to
@@ -197,7 +197,7 @@ def _default_get_style[K](node: K, x: int) -> tuple[str, str]:
  return "", ""
 
 
- class Printer[K]:
+ class Printer(Generic[_K]):
  """High-level tool for drawing a text-based DAG visualization.
 
  Parameters
@@ -231,9 +231,9 @@ class Printer[K]:
  *,
  pad: str = " ",
  make_blank_row: Callable[[int, str], PrintRow] = PrintRow,
- get_text: Callable[[K, int, tuple[str, str]], str] = _default_get_text,
- get_symbol: Callable[[K, int], str] = _default_get_symbol,
- get_style: Callable[[K, int], tuple[str, str]] = _default_get_style,
+ get_text: Callable[[_K, int, tuple[str, str]], str] = _default_get_text,
+ get_symbol: Callable[[_K, int], str] = _default_get_symbol,
+ get_style: Callable[[_K, int], tuple[str, str]] = _default_get_style,
  ):
  self.width = layout_width * 2 + 1
  self.pad = pad
@@ -245,7 +245,7 @@ class Printer[K]:
  def print_row(
  self,
  stream: TextIO,
- layout_row: LayoutRow[K],
+ layout_row: LayoutRow[_K],
  ) -> None:
  """Print a single row of the DAG visualization to a file-like object.
 
lsst/pipe/base/pipeline_graph/visualization/_status_annotator.py

@@ -200,13 +200,6 @@ class QuantumGraphExecutionStatusAnnotator:
  """Annotates a networkx graph with task and dataset status information from
  a quantum graph execution summary, implementing the StatusAnnotator
  protocol to update the graph with status data.
-
- Parameters
- ----------
- *args : `typing.Any`
- Arbitrary arguments.
- **kwargs : `typing.Any`
- Arbitrary keyword arguments.
  """
 
  def __init__(self, *args: Any, **kwargs: Any) -> None:
lsst/pipe/base/prerequisite_helpers.py

@@ -252,8 +252,7 @@ class PrerequisiteFinder:
  Sequence of collections to search, in order.
  data_id : `lsst.daf.butler.DataCoordinate`
  Data ID for the quantum.
- skypix_bounds : `~collections.abc.Mapping` \
- [ `str`, `lsst.sphgeom.RangeSet` ]
+ skypix_bounds : `Mapping` [ `str`, `lsst.sphgeom.RangeSet` ]
  The spatial bounds of this quantum in various skypix dimensions.
  Keys are skypix dimension names (a superset of those in
  `dataset_skypix`) and values are sets of integer pixel ID ranges.
lsst/pipe/base/quantum_graph/_common.py

@@ -50,7 +50,9 @@ from typing import (
  TYPE_CHECKING,
  Any,
  Self,
+ TypeAlias,
  TypedDict,
+ TypeVar,
  )
 
  import networkx
@@ -79,16 +81,18 @@ if TYPE_CHECKING:
  # These aliases make it a lot easier how the various pydantic models are
  # structured, but they're too verbose to be worth exporting to code outside the
  # quantum_graph subpackage.
- type TaskLabel = str
- type DatasetTypeName = str
- type ConnectionName = str
- type DatasetIndex = int
- type QuantumIndex = int
- type DatastoreName = str
- type DimensionElementName = str
- type DataCoordinateValues = list[DataIdValue]
+ TaskLabel: TypeAlias = str
+ DatasetTypeName: TypeAlias = str
+ ConnectionName: TypeAlias = str
+ DatasetIndex: TypeAlias = int
+ QuantumIndex: TypeAlias = int
+ DatastoreName: TypeAlias = str
+ DimensionElementName: TypeAlias = str
+ DataCoordinateValues: TypeAlias = list[DataIdValue]
 
 
+ _T = TypeVar("_T", bound=pydantic.BaseModel)
+
  FORMAT_VERSION: int = 1
  """
  File format version number for new files.
@@ -444,17 +448,14 @@ class BaseQuantumGraphWriter:
  uri: ResourcePathExpression,
  header: HeaderModel,
  pipeline_graph: PipelineGraph,
+ indices: dict[uuid.UUID, int],
  *,
  address_filename: str,
+ compressor: Compressor,
  cdict_data: bytes | None = None,
- zstd_level: int = 10,
  ) -> Iterator[Self]:
- uri = ResourcePath(uri, forceDirectory=False)
- address_writer = AddressWriter()
- if uri.isLocal:
- os.makedirs(uri.dirname().ospath, exist_ok=True)
- cdict = zstandard.ZstdCompressionDict(cdict_data) if cdict_data is not None else None
- compressor = zstandard.ZstdCompressor(level=zstd_level, dict_data=cdict)
+ uri = ResourcePath(uri)
+ address_writer = AddressWriter(indices)
  with uri.open(mode="wb") as stream:
  with zipfile.ZipFile(stream, mode="w", compression=zipfile.ZIP_STORED) as zf:
  self = cls(zf, compressor, address_writer, header.int_size)
@@ -593,9 +594,9 @@ class BaseQuantumGraphReader:
  )
 
  @staticmethod
- def _read_single_block_static[T: pydantic.BaseModel](
- name: str, model_type: type[T], zf: zipfile.ZipFile, decompressor: Decompressor
- ) -> T:
+ def _read_single_block_static(
+ name: str, model_type: type[_T], zf: zipfile.ZipFile, decompressor: Decompressor
+ ) -> _T:
  """Read a single compressed JSON block from a 'file' in a zip archive.
 
  Parameters
@@ -618,7 +619,7 @@
  json_data = decompressor.decompress(compressed_data)
  return model_type.model_validate_json(json_data)
 
- def _read_single_block[T: pydantic.BaseModel](self, name: str, model_type: type[T]) -> T:
+ def _read_single_block(self, name: str, model_type: type[_T]) -> _T:
  """Read a single compressed JSON block from a 'file' in a zip archive.
 
  Parameters
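
In _common.py the method-level `[T: pydantic.BaseModel]` parameters give way to a single module-level bound `TypeVar`, while the reader logic keeps the same shape: decompress a block of bytes, then validate it into the requested model type. A minimal sketch of that idea using zstandard directly instead of the package's `Decompressor` protocol (`Header` and `read_block` are illustrative names, not package API):

    from typing import TypeVar

    import pydantic
    import zstandard

    _T = TypeVar("_T", bound=pydantic.BaseModel)

    def read_block(compressed: bytes, model_type: type[_T]) -> _T:
        """Decompress a JSON block and validate it into the given model."""
        json_data = zstandard.ZstdDecompressor().decompress(compressed)
        return model_type.model_validate_json(json_data)

    class Header(pydantic.BaseModel):
        version: int

    blob = zstandard.ZstdCompressor().compress(b'{"version": 1}')
    print(read_block(blob, Header).version)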
lsst/pipe/base/quantum_graph/_multiblock.py

@@ -43,22 +43,25 @@ import dataclasses
  import logging
  import tempfile
  import uuid
- import zipfile
- from collections.abc import Iterator, Set
+ from collections.abc import Iterator
  from contextlib import contextmanager
  from io import BufferedReader, BytesIO
  from operator import attrgetter
- from typing import IO, Protocol, TypeVar
+ from typing import IO, TYPE_CHECKING, Protocol, TypeAlias, TypeVar
 
  import pydantic
 
+ if TYPE_CHECKING:
+ import zipfile
+
+
  _LOG = logging.getLogger(__name__)
 
 
  _T = TypeVar("_T", bound=pydantic.BaseModel)
 
 
- type UUID_int = int
+ UUID_int: TypeAlias = int
 
  MAX_UUID_INT: UUID_int = 2**128
 
@@ -74,7 +77,7 @@ individual quanta (especially for execution).
 
 
  class Compressor(Protocol):
- """A protocol for objects with a ``compress`` method that takes and returns
+ """A protocol for objects with a `compress` method that takes and returns
  `bytes`.
  """
 
@@ -202,14 +205,21 @@ class AddressRow:
  class AddressWriter:
  """A helper object for writing address files for multi-block files."""
 
+ indices: dict[uuid.UUID, int] = dataclasses.field(default_factory=dict)
+ """Mapping from UUID to internal integer ID.
+
+ The internal integer ID must always correspond to the index into the
+ sorted list of all UUIDs, but this `dict` need not be sorted itself.
+ """
+
  addresses: list[dict[uuid.UUID, Address]] = dataclasses.field(default_factory=list)
  """Addresses to store with each UUID.
 
- Every key in one of these dictionaries must have an entry in ``indices``.
+ Every key in one of these dictionaries must have an entry in `indices`.
  The converse is not true.
  """
 
- def write(self, stream: IO[bytes], int_size: int, all_ids: Set[uuid.UUID] | None = None) -> None:
+ def write(self, stream: IO[bytes], int_size: int) -> None:
  """Write all addresses to a file-like object.
 
  Parameters
@@ -218,18 +228,19 @@ class AddressWriter:
  Binary file-like object.
  int_size : `int`
  Number of bytes to use for all integers.
- all_ids : `~collections.abc.Set` [`uuid.UUID`], optional
- Set of the union of all UUIDs in any dictionary from a call to
- `get_all_ids`.
  """
- if all_ids is None:
- all_ids = self.get_all_ids()
+ for n, address_map in enumerate(self.addresses):
+ if not self.indices.keys() >= address_map.keys():
+ raise AssertionError(
+ f"Logic bug in quantum graph I/O: address map {n} of {len(self.addresses)} has IDs "
+ f"{address_map.keys() - self.indices.keys()} not in the index map."
+ )
  stream.write(int_size.to_bytes(1))
- stream.write(len(all_ids).to_bytes(int_size))
+ stream.write(len(self.indices).to_bytes(int_size))
  stream.write(len(self.addresses).to_bytes(int_size))
  empty_address = Address()
- for n, key in enumerate(sorted(all_ids, key=attrgetter("int"))):
- row = AddressRow(key, n, [m.get(key, empty_address) for m in self.addresses])
+ for key in sorted(self.indices.keys(), key=attrgetter("int")):
+ row = AddressRow(key, self.indices[key], [m.get(key, empty_address) for m in self.addresses])
  _LOG.debug("Wrote address %s.", row)
  row.write(stream, int_size)
 
@@ -245,25 +256,8 @@ class AddressWriter:
  int_size : `int`
  Number of bytes to use for all integers.
  """
- all_ids = self.get_all_ids()
- zip_info = zipfile.ZipInfo(f"{name}.addr")
- row_size = AddressReader.compute_row_size(int_size, len(self.addresses))
- zip_info.file_size = AddressReader.compute_header_size(int_size) + len(all_ids) * row_size
- with zf.open(zip_info, mode="w") as stream:
- self.write(stream, int_size=int_size, all_ids=all_ids)
-
- def get_all_ids(self) -> Set[uuid.UUID]:
- """Return all IDs used by any address dictionary.
-
- Returns
- -------
- all_ids : `~collections.abc.Set` [`uuid.UUID`]
- Set of all IDs.
- """
- all_ids: set[uuid.UUID] = set()
- for address_map in self.addresses:
- all_ids.update(address_map.keys())
- return all_ids
+ with zf.open(f"{name}.addr", mode="w") as stream:
+ self.write(stream, int_size=int_size)
 
 
  @dataclasses.dataclass
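
The rewritten `AddressWriter` takes an explicit `indices` mapping rather than recomputing the union of IDs, and its docstring requires each integer to be the UUID's position in the sorted list of all UUIDs. A small sketch of building a mapping that satisfies that invariant (the helper name is made up; the sort key mirrors the diff):

    import uuid
    from operator import attrgetter

    def make_indices(all_ids: set[uuid.UUID]) -> dict[uuid.UUID, int]:
        """Map each UUID to its index in the list sorted by integer value."""
        return {key: n for n, key in enumerate(sorted(all_ids, key=attrgetter("int")))}

    ids = {uuid.uuid4() for _ in range(3)}
    indices = make_indices(ids)
    assert sorted(indices.values()) == [0, 1, 2]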
@@ -662,7 +656,7 @@ class MultiblockWriter:
  model : `pydantic.BaseModel`
  Model to convert to JSON and compress.
  compressor : `Compressor`
- Object with a ``compress`` method that takes and returns `bytes`.
+ Object with a `compress` method that takes and returns `bytes`.
 
  Returns
  -------
@@ -759,7 +753,7 @@ class MultiblockReader:
  model_type : `type` [ `pydantic.BaseModel` ]
  Pydantic model to validate JSON with.
  decompressor : `Decompressor`
- Object with a ``decompress`` method that takes and returns `bytes`.
+ Object with a `decompress` method that takes and returns `bytes`.
  int_size : `int`
  Number of bytes to use for all integers.
  page_size : `int`
@@ -809,7 +803,7 @@
  model_type : `type` [ `pydantic.BaseModel` ]
  Pydantic model to validate JSON with.
  decompressor : `Decompressor`
- Object with a ``decompress`` method that takes and returns `bytes`.
+ Object with a `decompress` method that takes and returns `bytes`.
 
  Returns
  -------
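
The `Compressor` and `Decompressor` protocols referenced throughout these hunks only require `compress`/`decompress` methods that take and return `bytes`, so the zstandard objects that `BaseQuantumGraphWriter.open` used to construct inline (see the _common.py hunk above) satisfy them structurally and are now supplied by the caller instead. A sketch of a caller building one, mirroring the removed lines (the helper name is made up):

    import zstandard

    def make_compressor(cdict_data: bytes | None = None, level: int = 10) -> zstandard.ZstdCompressor:
        """Build a zstd compressor, optionally seeded with a compression dictionary."""
        cdict = zstandard.ZstdCompressionDict(cdict_data) if cdict_data is not None else None
        return zstandard.ZstdCompressor(level=level, dict_data=cdict)

    payload = make_compressor().compress(b'{"quantum": 42}')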