lsst-pipe-base 29.2025.1000-py3-none-any.whl → 29.2025.1200-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (20)
  1. lsst/pipe/base/_datasetQueryConstraints.py +1 -1
  2. lsst/pipe/base/all_dimensions_quantum_graph_builder.py +642 -357
  3. lsst/pipe/base/connections.py +179 -2
  4. lsst/pipe/base/pipeline_graph/visualization/_mermaid.py +157 -24
  5. lsst/pipe/base/prerequisite_helpers.py +1 -1
  6. lsst/pipe/base/quantum_graph_builder.py +91 -60
  7. lsst/pipe/base/quantum_graph_skeleton.py +20 -0
  8. lsst/pipe/base/quantum_provenance_graph.py +790 -421
  9. lsst/pipe/base/tests/mocks/_data_id_match.py +4 -0
  10. lsst/pipe/base/version.py +1 -1
  11. {lsst_pipe_base-29.2025.1000.dist-info → lsst_pipe_base-29.2025.1200.dist-info}/METADATA +5 -2
  12. {lsst_pipe_base-29.2025.1000.dist-info → lsst_pipe_base-29.2025.1200.dist-info}/RECORD +20 -20
  13. {lsst_pipe_base-29.2025.1000.dist-info → lsst_pipe_base-29.2025.1200.dist-info}/WHEEL +1 -1
  14. {lsst_pipe_base-29.2025.1000.dist-info → lsst_pipe_base-29.2025.1200.dist-info}/entry_points.txt +0 -0
  15. {lsst_pipe_base-29.2025.1000.dist-info → lsst_pipe_base-29.2025.1200.dist-info/licenses}/COPYRIGHT +0 -0
  16. {lsst_pipe_base-29.2025.1000.dist-info → lsst_pipe_base-29.2025.1200.dist-info/licenses}/LICENSE +0 -0
  17. {lsst_pipe_base-29.2025.1000.dist-info → lsst_pipe_base-29.2025.1200.dist-info/licenses}/bsd_license.txt +0 -0
  18. {lsst_pipe_base-29.2025.1000.dist-info → lsst_pipe_base-29.2025.1200.dist-info/licenses}/gpl-v3.0.txt +0 -0
  19. {lsst_pipe_base-29.2025.1000.dist-info → lsst_pipe_base-29.2025.1200.dist-info}/top_level.txt +0 -0
  20. {lsst_pipe_base-29.2025.1000.dist-info → lsst_pipe_base-29.2025.1200.dist-info}/zip-safe +0 -0
@@ -35,6 +35,7 @@ __all__ = [
     "InputQuantizedConnection",
     "OutputQuantizedConnection",
     "PipelineTaskConnections",
+    "QuantaAdjuster",
    "QuantizedConnection",
     "ScalarError",
     "ScalarError",
@@ -45,8 +46,8 @@ import dataclasses
 import itertools
 import string
 import warnings
-from collections import UserDict
-from collections.abc import Collection, Generator, Iterable, Mapping, Sequence, Set
+from collections import UserDict, defaultdict
+from collections.abc import Collection, Generator, Iterable, Iterator, Mapping, Sequence, Set
 from dataclasses import dataclass
 from types import MappingProxyType, SimpleNamespace
 from typing import TYPE_CHECKING, Any
@@ -58,6 +59,8 @@ from .connectionTypes import BaseConnection, BaseInput, Output, PrerequisiteInpu
 
 if TYPE_CHECKING:
     from .config import PipelineTaskConfig
+    from .pipeline_graph import PipelineGraph, TaskNode
+    from .quantum_graph_skeleton import QuantumGraphSkeleton
 
 
 class ScalarError(TypeError):
@@ -999,6 +1002,25 @@ class PipelineTaskConnections(metaclass=PipelineTaskConnectionsMetaclass):
         """
         return ()
 
+    def adjust_all_quanta(self, adjuster: QuantaAdjuster) -> None:
+        """Customize the set of quanta predicted for this task during quantum
+        graph generation.
+
+        Parameters
+        ----------
+        adjuster : `QuantaAdjuster`
+            A helper object that implementations can use to modify the
+            under-construction quantum graph.
+
+        Notes
+        -----
+        This hook is called before `adjustQuantum`, which is where built-in
+        checks for `NoWorkFound` cases and missing prerequisites are handled.
+        This means that the set of preliminary quanta seen by this method could
+        include some that would normally be dropped later.
+        """
+        pass
+
 
 def iterConnections(
     connections: PipelineTaskConnections, connectionType: str | Iterable[str]
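
Since adjust_all_quanta is the new entry point for whole-task quantum adjustment, a minimal sketch of an override may help. It assumes a hypothetical "exposures" input connection and uses only the QuantaAdjuster methods defined in the next hunk:

    from lsst.pipe.base.connections import PipelineTaskConnections, QuantaAdjuster

    class ExampleConnections(PipelineTaskConnections, dimensions=("visit",)):
        # ... connection declarations omitted; assume an "exposures" input ...

        def adjust_all_quanta(self, adjuster: QuantaAdjuster) -> None:
            # Snapshot the data IDs first: remove_quantum mutates the graph
            # being iterated over.
            for data_id in list(adjuster.iter_data_ids()):
                if len(adjuster.get_inputs(data_id)["exposures"]) < 2:
                    adjuster.remove_quantum(data_id)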
@@ -1130,3 +1152,158 @@ class AdjustQuantumHelper:
             self.outputs_adjusted = True
         else:
             self.outputs_adjusted = False
+
+
+class QuantaAdjuster:
+    """A helper class for the `PipelineTaskConnections.adjust_all_quanta` hook.
+
+    Parameters
+    ----------
+    task_label : `str`
+        Label of the task whose quanta are being adjusted.
+    pipeline_graph : `pipeline_graph.PipelineGraph`
+        Pipeline graph the quantum graph is being built from.
+    skeleton : `quantum_graph_skeleton.QuantumGraphSkeleton`
+        Under-construction quantum graph that will be modified in place.
+    """
+
+    def __init__(self, task_label: str, pipeline_graph: PipelineGraph, skeleton: QuantumGraphSkeleton):
+        self._task_node = pipeline_graph.tasks[task_label]
+        self._pipeline_graph = pipeline_graph
+        self._skeleton = skeleton
+        self._n_removed = 0
+
+    @property
+    def task_label(self) -> str:
+        """The label this task has been configured with."""
+        return self._task_node.label
+
+    @property
+    def task_node(self) -> TaskNode:
+        """The node for this task in the pipeline graph."""
+        return self._task_node
+
+    def iter_data_ids(self) -> Iterator[DataCoordinate]:
+        """Iterate over the data IDs of all quanta for this task.
+
+        Returns
+        -------
+        data_ids : `~collections.abc.Iterator` [ \
+                `~lsst.daf.butler.DataCoordinate` ]
+            Data IDs.  These are minimal data IDs without dimension records or
+            implied values; use `expand_quantum_data_id` to get a full data ID
+            when needed.
+        """
+        for key in self._skeleton.get_quanta(self._task_node.label):
+            yield DataCoordinate.from_required_values(self._task_node.dimensions, key.data_id_values)
+
+    def remove_quantum(self, data_id: DataCoordinate) -> None:
+        """Remove a quantum from the graph.
+
+        Parameters
+        ----------
+        data_id : `~lsst.daf.butler.DataCoordinate`
+            Data ID of the quantum to remove.  All outputs will be removed as
+            well.
+        """
+        from .quantum_graph_skeleton import QuantumKey
+
+        self._skeleton.remove_quantum_node(
+            QuantumKey(self._task_node.label, data_id.required_values), remove_outputs=True
+        )
+        self._n_removed += 1
+
+    def get_inputs(self, quantum_data_id: DataCoordinate) -> dict[str, list[DataCoordinate]]:
+        """Return the data IDs of all regular inputs to a quantum.
+
+        Parameters
+        ----------
+        quantum_data_id : `~lsst.daf.butler.DataCoordinate`
+            Data ID of the quantum to get the inputs of.
+
+        Returns
+        -------
+        inputs : `dict` [ `str`, `list` [ `~lsst.daf.butler.DataCoordinate` ] ]
+            Data IDs of inputs, keyed by the connection name (the internal task
+            name, not the dataset type name).  This only contains regular
+            inputs, not init-inputs or prerequisite inputs.
+
+        Notes
+        -----
+        If two connections have the same dataset type, the current
+        implementation assumes the set of datasets is the same for the two
+        connections.  This limitation may be removed in the future.
+        """
+        from .quantum_graph_skeleton import DatasetKey, QuantumKey
+
+        by_dataset_type_name: defaultdict[str, list[DataCoordinate]] = defaultdict(list)
+        quantum_key = QuantumKey(self._task_node.label, quantum_data_id.required_values)
+        for dataset_key in self._skeleton.iter_inputs_of(quantum_key):
+            if not isinstance(dataset_key, DatasetKey):
+                continue
+            dataset_type_node = self._pipeline_graph.dataset_types[dataset_key.parent_dataset_type_name]
+            by_dataset_type_name[dataset_key.parent_dataset_type_name].append(
+                DataCoordinate.from_required_values(dataset_type_node.dimensions, dataset_key.data_id_values)
+            )
+        return {
+            edge.connection_name: by_dataset_type_name[edge.parent_dataset_type_name]
+            for edge in self._task_node.iter_all_inputs()
+        }
+
+    def add_input(
+        self, quantum_data_id: DataCoordinate, connection_name: str, dataset_data_id: DataCoordinate
+    ) -> None:
+        """Add a new input to a quantum.
+
+        Parameters
+        ----------
+        quantum_data_id : `~lsst.daf.butler.DataCoordinate`
+            Data ID of the quantum to add an input to.
+        connection_name : `str`
+            Name of the connection (the task-internal name, not the butler
+            dataset type name).
+        dataset_data_id : `~lsst.daf.butler.DataCoordinate`
+            Data ID of the input dataset.  Must already exist in the graph
+            as an input to a different quantum of this task, and must be a
+            regular input, not a prerequisite input or init-input.
+
+        Notes
+        -----
+        If two connections have the same dataset type, the current
+        implementation assumes the set of datasets is the same for the two
+        connections.  This limitation may be removed in the future.
+        """
+        from .quantum_graph_skeleton import DatasetKey, QuantumKey
+
+        quantum_key = QuantumKey(self._task_node.label, quantum_data_id.required_values)
+        read_edge = self._task_node.inputs[connection_name]
+        dataset_key = DatasetKey(read_edge.parent_dataset_type_name, dataset_data_id.required_values)
+        if dataset_key not in self._skeleton:
+            raise LookupError(
+                f"Dataset {read_edge.parent_dataset_type_name}@{dataset_data_id} is not already in the graph."
+            )
+        self._skeleton.add_input_edge(quantum_key, dataset_key)
+
+    def expand_quantum_data_id(self, data_id: DataCoordinate) -> DataCoordinate:
+        """Expand a quantum data ID to include implied values and dimension
+        records.
+
+        Parameters
+        ----------
+        data_id : `~lsst.daf.butler.DataCoordinate`
+            A data ID of a quantum already in the graph.
+
+        Returns
+        -------
+        expanded_data_id : `~lsst.daf.butler.DataCoordinate`
+            The same data ID, with implied values included and dimension
+            records attached.
+        """
+        from .quantum_graph_skeleton import QuantumKey
+
+        return self._skeleton.get_data_id(QuantumKey(self._task_node.label, data_id.required_values))
+
+    @property
+    def n_removed(self) -> int:
+        """The number of quanta that have been removed by this helper."""
+        return self._n_removed
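
Beyond dropping quanta, the same hook can rewire inputs. A hedged sketch of an adjust_all_quanta body that copies one quantum's inputs onto its neighbor; the "template" connection name and the consecutive-pair logic are invented for illustration:

    def adjust_all_quanta(self, adjuster: QuantaAdjuster) -> None:
        data_ids = list(adjuster.iter_data_ids())
        for donor, receiver in zip(data_ids, data_ids[1:]):
            # add_input only accepts datasets already present in the graph as
            # regular inputs of another quantum of this same task.
            for dataset_data_id in adjuster.get_inputs(donor)["template"]:
                adjuster.add_input(receiver, "template", dataset_data_id)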
@@ -32,6 +32,7 @@ import html
 import os
 import sys
 from collections.abc import Mapping
+from io import BufferedIOBase, BytesIO, StringIO, TextIOBase
 from typing import Any, TextIO
 
 from .._nodes import NodeType
@@ -40,6 +41,14 @@ from ._formatting import NodeKey, format_dimensions, format_task_class
 from ._options import NodeAttributeOptions
 from ._show import parse_display_args
 
+try:
+    from mermaid import Mermaid  # type: ignore
+    from mermaid.graph import Graph  # type: ignore
+
+    MERMAID_AVAILABLE = True
+except ImportError:
+    MERMAID_AVAILABLE = False
+
 # Configuration constants for label formatting and overflow handling.
 _LABEL_PX_SIZE = 18
 _LABEL_MAX_LINES_SOFT = 10
@@ -49,7 +58,11 @@ _OVERFLOW_MAX_LINES = 20
 
 def show_mermaid(
     pipeline_graph: PipelineGraph,
-    stream: TextIO = sys.stdout,
+    stream: TextIO | BytesIO = sys.stdout,
+    output_format: str = "mmd",
+    width: int | None = None,
+    height: int | None = None,
+    scale: float | None = None,
     **kwargs: Any,
 ) -> None:
     """Write a Mermaid flowchart representation of the pipeline graph to a
@@ -65,9 +78,20 @@
     ----------
     pipeline_graph : `PipelineGraph`
         The pipeline graph to visualize.
-    stream : `TextIO`, optional
+    stream : `TextIO` or `BytesIO`, optional
         The output stream where Mermaid code is written. Defaults to
         `sys.stdout`.
+    output_format : `str`, optional
+        Defines the output format. 'mmd' (default) generates the Mermaid
+        definition as text, while 'svg' and 'png' produce rendered images
+        written as binary streams.
+    width : `int`, optional
+        The width of the rendered image in pixels.
+    height : `int`, optional
+        The height of the rendered image in pixels.
+    scale : `float`, optional
+        The scale factor for the rendered image. Must be a float between 1
+        and 3, and one of ``height`` or ``width`` must be provided.
     **kwargs : Any
         Additional arguments passed to `parse_display_args` to control aspects
         such as displaying dimensions, storage classes, or full task class
@@ -85,27 +109,61 @@
     - If a node's label is too long, overflow nodes are created to hold extra
       lines.
     """
+    # Generate Mermaid source code in-memory.
+    mermaid_source = _generate_mermaid_source(pipeline_graph, **kwargs)
+
+    if output_format == "mmd":
+        if isinstance(stream, TextIOBase):
+            # Write Mermaid source as a string.
+            stream.write(mermaid_source)
+        else:
+            raise TypeError(f"Expected a text stream, but got {type(stream)}.")
+    else:
+        if isinstance(stream, BufferedIOBase):
+            # Render Mermaid source as an image and write to binary stream.
+            _render_mermaid_image(
+                mermaid_source, stream, output_format, width=width, height=height, scale=scale
+            )
+        else:
+            raise ValueError(f"Expected a binary stream, but got {type(stream)}.")
+
+
+def _generate_mermaid_source(pipeline_graph: PipelineGraph, **kwargs: Any) -> str:
+    """Generate the Mermaid source code from the pipeline graph.
+
+    Parameters
+    ----------
+    pipeline_graph : `PipelineGraph`
+        The pipeline graph to visualize.
+    **kwargs : Any
+        Additional arguments passed to `parse_display_args` for rendering.
+
+    Returns
+    -------
+    str
+        The Mermaid source code as a string.
+    """
+    # A buffer to collect Mermaid source code.
+    buffer = StringIO()
+
     # Parse display arguments to determine what to show.
     xgraph, options = parse_display_args(pipeline_graph, **kwargs)
 
     # Begin the Mermaid code block.
-    print("flowchart TD", file=stream)
+    buffer.write("flowchart TD\n")
 
     # Define Mermaid classes for node styling.
-    print(
+    buffer.write(
         f"classDef task fill:#B1F2EF,color:#000,stroke:#000,stroke-width:3px,"
-        f"font-family:Monospace,font-size:{_LABEL_PX_SIZE}px,text-align:left;",
-        file=stream,
+        f"font-family:Monospace,font-size:{_LABEL_PX_SIZE}px,text-align:left;\n"
     )
-    print(
+    buffer.write(
         f"classDef dsType fill:#F5F5F5,color:#000,stroke:#00BABC,stroke-width:3px,"
-        f"font-family:Monospace,font-size:{_LABEL_PX_SIZE}px,text-align:left,rx:8,ry:8;",
-        file=stream,
+        f"font-family:Monospace,font-size:{_LABEL_PX_SIZE}px,text-align:left,rx:8,ry:8;\n"
     )
-    print(
+    buffer.write(
         f"classDef taskInit fill:#F4DEFA,color:#000,stroke:#000,stroke-width:3px,"
-        f"font-family:Monospace,font-size:{_LABEL_PX_SIZE}px,text-align:left;",
-        file=stream,
+        f"font-family:Monospace,font-size:{_LABEL_PX_SIZE}px,text-align:left;\n"
     )
 
     # `overflow_ref` tracks the reference numbers for overflow nodes.
@@ -116,30 +174,27 @@
     for node_key, node_data in xgraph.nodes.items():
         match node_key.node_type:
             case NodeType.TASK | NodeType.TASK_INIT:
-                # Render a task or task-init node.
-                _render_task_node(node_key, node_data, options, stream)
+                _render_task_node(node_key, node_data, options, buffer)
             case NodeType.DATASET_TYPE:
-                # Render a dataset-type node with possible overflow handling.
                 overflow_ref, node_overflow_ids = _render_dataset_type_node(
-                    node_key, node_data, options, stream, overflow_ref
+                    node_key, node_data, options, buffer, overflow_ref
                 )
-                if node_overflow_ids:
-                    overflow_ids += node_overflow_ids
+                overflow_ids += node_overflow_ids if node_overflow_ids else []
             case _:
                 raise AssertionError(f"Unexpected node type: {node_key.node_type}")
 
-    # Collect edges for printing and track which ones are prerequisite
-    # so we can apply dashed styling after printing them.
+    # Collect edges for adding to the Mermaid code and track which ones are
+    # prerequisite so we can apply dashed styling to them later.
     edges = []
     for _, (from_node, to_node, *_rest) in enumerate(xgraph.edges):
         is_prereq = xgraph.nodes[from_node].get("is_prerequisite", False)
         edges.append((from_node.node_id, to_node.node_id, is_prereq))
 
-    # Print all edges
+    # Render all edges.
     for _, (f, t, p) in enumerate(edges):
-        _render_edge(f, t, p, stream)
+        _render_edge(f, t, p, buffer)
 
-    # After printing all edges, apply linkStyle to prerequisite edges to make
+    # After rendering all edges, apply linkStyle to prerequisite edges to make
     # them dashed:
 
     # First, gather indices of prerequisite edges.
@@ -147,7 +202,85 @@
 
     # Then apply dashed styling to all prerequisite edges in one line.
     if prereq_indices:
-        print(f"linkStyle {','.join(prereq_indices)} stroke-dasharray:5;", file=stream)
+        buffer.write(f"linkStyle {','.join(prereq_indices)} stroke-dasharray:5;\n")
+
+    # Return Mermaid source as a string.
+    return buffer.getvalue()
+
+
+def _render_mermaid_image(
+    mermaid_source: str,
+    binary_stream: BytesIO,
+    output_format: str,
+    width: int | None = None,
+    height: int | None = None,
+    scale: float | None = None,
+) -> None:
+    """Render a Mermaid diagram as an image and write the output to a binary
+    stream.
+
+    Parameters
+    ----------
+    mermaid_source : `str`
+        The Mermaid diagram source code.
+    binary_stream : `BytesIO`
+        The binary stream where the output content will be written.
+    output_format : `str`
+        The desired output format for the image. Supported image formats are
+        'svg' and 'png'.
+    width : `int`, optional
+        The width of the rendered image in pixels.
+    height : `int`, optional
+        The height of the rendered image in pixels.
+    scale : `float`, optional
+        The scale factor for the rendered image. Must be a float between 1 and
+        3, and one of height or width must be provided.
+
+    Raises
+    ------
+    ImportError
+        If `mermaid-py` is not installed.
+    ValueError
+        If the requested ``output_format`` is not supported.
+    RuntimeError
+        If the rendering process fails.
+    """
+    if output_format.lower() not in {"svg", "png"}:
+        raise ValueError(f"Unsupported format: {output_format}. Use 'svg' or 'png'.")
+
+    if not MERMAID_AVAILABLE:
+        raise ImportError("The `mermaid-py` package is required for rendering images but is not installed.")
+
+    # Generate the Mermaid graph object.
+    graph = Graph(title="Mermaid Diagram", script=mermaid_source)
+    diagram = Mermaid(graph, width=width, height=height, scale=scale)
+
+    # Determine the response type based on the output format.
+    if output_format.lower() == "svg":
+        response_type = "svg_response"
+    else:
+        response_type = "img_response"
+
+    # Fetch the rendered content and write it to the stream.
+    try:
+        content = getattr(diagram, response_type).content
+
+        # Check whether the response is actually an image or an HTML error page.
+        if content.startswith(b"<!DOCTYPE html>") or b"<title>" in content[:200]:
+            error_msg = content.decode(errors="ignore")[:1000]
+            if "524" in error_msg or "timeout" in error_msg.lower():
+                raise RuntimeError(
+                    f"Mermaid rendering service (mermaid.ink) timed out while generating {response_type}. "
+                    "This may be due to server overload. Try again later or use a local rendering option."
+                )
+            raise RuntimeError(
+                f"Unexpected error from Mermaid API while generating {response_type}. Response:\n{error_msg}"
+            )
+
+        # Write the content to the binary stream if it is a valid image.
+        binary_stream.write(content)
+    except AttributeError as exc:
+        raise RuntimeError(f"Failed to generate {response_type} content") from exc
 
 
 def _render_task_node(
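
Taken together, these hunks turn show_mermaid into a front end that either writes the generated Mermaid source or hands it to the optional mermaid-py renderer. A hedged usage sketch (the pipeline_graph value is a stand-in for a real PipelineGraph, and the import path is assumed from the file layout):

    from io import BytesIO

    from lsst.pipe.base.pipeline_graph.visualization import show_mermaid

    show_mermaid(pipeline_graph)  # Mermaid definition text to sys.stdout, as before.

    # Binary formats need the mermaid-py dependency and a binary stream;
    # scale requires width or height to be given.
    buffer = BytesIO()
    show_mermaid(pipeline_graph, stream=buffer, output_format="png", width=1600, scale=2)
    with open("pipeline.png", "wb") as f:
        f.write(buffer.getvalue())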
@@ -340,7 +340,7 @@ class PrerequisiteFinder:
         where_terms: list[str] = []
         bind: dict[str, list[int]] = {}
         for name in self.dataset_skypix:
-            where_terms.append(f"{name} IN ({name}_pixels)")
+            where_terms.append(f"{name} IN (:{name}_pixels)")
             pixels: list[int] = []
             for begin, end in skypix_bounds[name]:
                 pixels.extend(range(begin, end))
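
This last fix is a single character: the skypix IN-list is now written with the leading colon that marks a bind parameter, so it pairs with the bind mapping built just below it instead of being interpolated as a bare name. A small string-level illustration, with an invented skypix system name:

    name = "htm7"
    skypix_bounds = {"htm7": [(100, 103), (200, 202)]}
    pixels: list[int] = []
    for begin, end in skypix_bounds[name]:
        pixels.extend(range(begin, end))
    where_term = f"{name} IN (:{name}_pixels)"  # "htm7 IN (:htm7_pixels)"
    bind = {f"{name}_pixels": pixels}           # {"htm7_pixels": [100, 101, 102, 200, 201]}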