lsst-pipe-base 29.2025.1400__py3-none-any.whl → 29.2025.1600__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. lsst/pipe/base/all_dimensions_quantum_graph_builder.py +17 -0
  2. lsst/pipe/base/graph/_loadHelpers.py +4 -0
  3. lsst/pipe/base/graph/graph.py +2 -2
  4. lsst/pipe/base/pipeline.py +1 -1
  5. lsst/pipe/base/pipelineIR.py +10 -1
  6. lsst/pipe/base/pipeline_graph/__main__.py +1 -0
  7. lsst/pipe/base/pipeline_graph/_exceptions.py +7 -0
  8. lsst/pipe/base/pipeline_graph/_pipeline_graph.py +360 -11
  9. lsst/pipe/base/pipeline_graph/expressions.py +271 -0
  10. lsst/pipe/base/pipeline_graph/visualization/__init__.py +1 -0
  11. lsst/pipe/base/pipeline_graph/visualization/_formatting.py +300 -5
  12. lsst/pipe/base/pipeline_graph/visualization/_mermaid.py +17 -25
  13. lsst/pipe/base/pipeline_graph/visualization/_options.py +11 -3
  14. lsst/pipe/base/pipeline_graph/visualization/_show.py +23 -3
  15. lsst/pipe/base/pipeline_graph/visualization/_status_annotator.py +250 -0
  16. lsst/pipe/base/quantum_provenance_graph.py +28 -0
  17. lsst/pipe/base/version.py +1 -1
  18. {lsst_pipe_base-29.2025.1400.dist-info → lsst_pipe_base-29.2025.1600.dist-info}/METADATA +2 -1
  19. {lsst_pipe_base-29.2025.1400.dist-info → lsst_pipe_base-29.2025.1600.dist-info}/RECORD +27 -25
  20. {lsst_pipe_base-29.2025.1400.dist-info → lsst_pipe_base-29.2025.1600.dist-info}/WHEEL +0 -0
  21. {lsst_pipe_base-29.2025.1400.dist-info → lsst_pipe_base-29.2025.1600.dist-info}/entry_points.txt +0 -0
  22. {lsst_pipe_base-29.2025.1400.dist-info → lsst_pipe_base-29.2025.1600.dist-info}/licenses/COPYRIGHT +0 -0
  23. {lsst_pipe_base-29.2025.1400.dist-info → lsst_pipe_base-29.2025.1600.dist-info}/licenses/LICENSE +0 -0
  24. {lsst_pipe_base-29.2025.1400.dist-info → lsst_pipe_base-29.2025.1600.dist-info}/licenses/bsd_license.txt +0 -0
  25. {lsst_pipe_base-29.2025.1400.dist-info → lsst_pipe_base-29.2025.1600.dist-info}/licenses/gpl-v3.0.txt +0 -0
  26. {lsst_pipe_base-29.2025.1400.dist-info → lsst_pipe_base-29.2025.1600.dist-info}/top_level.txt +0 -0
  27. {lsst_pipe_base-29.2025.1400.dist-info → lsst_pipe_base-29.2025.1600.dist-info}/zip-safe +0 -0
@@ -39,6 +39,8 @@ from collections import defaultdict
 from collections.abc import Iterable, Mapping
 from typing import TYPE_CHECKING, Any, TypeAlias, final
 
+import astropy.table
+
 from lsst.daf.butler import (
     Butler,
     DataCoordinate,
@@ -85,6 +87,11 @@ class AllDimensionsQuantumGraphBuilder(QuantumGraphBuilder):
         (sometimes catastrophically bad) query plan.
     bind : `~collections.abc.Mapping`, optional
         Variable substitutions for the ``where`` expression.
+    data_id_tables : `~collections.abc.Iterable` [ `astropy.table.Table` ],\
+            optional
+        Tables of data IDs to join in as constraints. Missing dimensions that
+        are constrained by the ``where`` argument or pipeline data ID will be
+        filled in automatically.
     **kwargs
         Additional keyword arguments forwarded to `QuantumGraphBuilder`.
 
@@ -113,6 +120,7 @@ class AllDimensionsQuantumGraphBuilder(QuantumGraphBuilder):
         where: str = "",
         dataset_query_constraint: DatasetQueryConstraintVariant = DatasetQueryConstraintVariant.ALL,
         bind: Mapping[str, Any] | None = None,
+        data_id_tables: Iterable[astropy.table.Table] = (),
         **kwargs: Any,
     ):
         super().__init__(pipeline_graph, butler, **kwargs)
@@ -120,6 +128,7 @@ class AllDimensionsQuantumGraphBuilder(QuantumGraphBuilder):
         self.where = where
         self.dataset_query_constraint = dataset_query_constraint
         self.bind = bind
+        self.data_id_tables = list(data_id_tables)
 
     @timeMethod
     def process_subgraph(self, subgraph: PipelineGraph) -> QuantumGraphSkeleton:
@@ -194,6 +203,14 @@ class AllDimensionsQuantumGraphBuilder(QuantumGraphBuilder):
                 f"{self.where!r}, bind={self.bind!r})"
             )
             query = query.where(tree.subgraph.data_id, self.where, bind=self.bind)
+            # It's important for tables to be joined in last, so data IDs from
+            # pipeline and where can be used to fill in missing columns.
+            for table in self.data_id_tables:
+                # If this is from ctrl_mpexec's pipetask, it'll have added
+                # a filename to the metadata for us.
+                table_name = table.meta.get("filename", "unknown")
+                query_cmd.append(f"    query = query.join_data_coordinate_table(<{table_name}>)")
+                query = query.join_data_coordinate_table(table)
             self.log.verbose("Querying for data IDs via: %s", "\n".join(query_cmd))
             # Allow duplicates from common skypix overlaps to make some queries
             # run faster.
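
A minimal sketch of how the new ``data_id_tables`` argument might be used, assuming ``pipeline_graph`` and ``butler`` objects already exist; the column names, values, and filename here are hypothetical:

    import astropy.table

    from lsst.pipe.base.all_dimensions_quantum_graph_builder import (
        AllDimensionsQuantumGraphBuilder,
    )

    # Constrain the quantum graph to an explicit set of data IDs; dimensions
    # not present as columns can be filled in by the ``where`` expression or
    # the pipeline data ID.
    data_ids = astropy.table.Table(
        {"visit": [903334, 903336], "detector": [22, 22]},
        meta={"filename": "my_data_ids.ecsv"},  # only used in log messages
    )
    builder = AllDimensionsQuantumGraphBuilder(
        pipeline_graph,
        butler,
        where="instrument = 'HSC'",
        data_id_tables=[data_ids],
    )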
@@ -65,6 +65,7 @@ class LoadHelper(AbstractContextManager["LoadHelper"]):
     to upgrade them to the latest format before they can be used in
     production.
     """
+    fullRead: bool = False
 
     def __post_init__(self) -> None:
         self._resourceHandle: ResourceHandleProtocol | None = None
@@ -261,6 +262,9 @@ class LoadHelper(AbstractContextManager["LoadHelper"]):
     def __enter__(self) -> LoadHelper:
         if isinstance(self.uri, BinaryIO | BytesIO | BufferedRandom):
             self._resourceHandle = self.uri
+        elif self.fullRead:
+            local = self._exitStack.enter_context(self.uri.as_local())
+            self._resourceHandle = self._exitStack.enter_context(local.open("rb"))
         else:
             self._resourceHandle = self._exitStack.enter_context(self.uri.open("rb"))
         self._initialize()
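
The new branch relies on a standard `lsst.resources` idiom: `ResourcePath.as_local` materializes a remote resource as a temporary local file (a no-op for local files), so a full read becomes one download plus ordinary buffered reads rather than many small remote reads. A self-contained sketch of that idiom, with a hypothetical S3 URI:

    from lsst.resources import ResourcePath

    uri = ResourcePath("s3://some-bucket/graph.qgraph")
    # as_local() yields a local ResourcePath and cleans up any temporary
    # copy when the context exits.
    with uri.as_local() as local:
        with local.open("rb") as handle:
            magic = handle.read(6)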
@@ -963,7 +963,7 @@ class QuantumGraph:
         """
         uri = ResourcePath(uri)
         if uri.getExtension() in {".qgraph"}:
-            with LoadHelper(uri, minimumVersion) as loader:
+            with LoadHelper(uri, minimumVersion, fullRead=(nodes is None)) as loader:
                 qgraph = loader.load(universe, nodes, graphID)
         else:
             raise ValueError(f"Only know how to handle files saved as `.qgraph`, not {uri}")
@@ -1230,7 +1230,7 @@ class QuantumGraph:
             being loaded or if the supplied uri does not point at a valid
             `QuantumGraph` save file.
         """
-        with LoadHelper(file, minimumVersion) as loader:
+        with LoadHelper(file, minimumVersion, fullRead=(nodes is None)) as loader:
            qgraph = loader.load(universe, nodes, graphID)
        if not isinstance(qgraph, QuantumGraph):
            raise TypeError(f"QuantumGraph file contains unexpected object type: {type(qgraph)}")
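
Both load paths now set `fullRead` based on whether specific nodes were requested: loading everything (`nodes=None`) fetches the file wholesale, while loading a subset keeps the incremental-read behavior. A hedged usage sketch, assuming a `DimensionUniverse` in ``universe`` and a sequence of node IDs in ``node_ids``:

    from lsst.pipe.base import QuantumGraph

    # Whole graph wanted: the file is localized in one go (fullRead=True).
    qgraph = QuantumGraph.loadUri("s3://some-bucket/run.qgraph", universe)

    # Specific nodes wanted: incremental reads are still used.
    partial = QuantumGraph.loadUri("s3://some-bucket/run.qgraph", universe, nodes=node_ids)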
@@ -427,7 +427,7 @@ class Pipeline:
         if "," in label_subset:
             if ".." in label_subset:
                 raise ValueError(
-                    "Can only specify a list of labels or a rangewhen loading a Pipline not both"
+                    "Can only specify a list of labels or a range when loading a Pipeline, not both."
                 )
             args = {"labels": set(label_subset.split(","))}
         # labels supplied as a range
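
For context, this error guards the URI-fragment syntax for loading part of a pipeline: a fragment can name either a comma-separated list of labels or a ``..`` range, but not both. A sketch with a hypothetical file and task labels:

    from lsst.pipe.base import Pipeline

    p1 = Pipeline.from_uri("pipeline.yaml#isr,characterizeImage")  # list
    p2 = Pipeline.from_uri("pipeline.yaml#isr..calibrate")         # range
    # "pipeline.yaml#isr,characterizeImage..calibrate" raises ValueError.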
@@ -980,10 +980,19 @@ class PipelineIR:
             if extraTaskLabels := (labeled_subset.subset - pipeline.tasks.keys()):
                 match subsetCtrl:
                     case PipelineSubsetCtrl.DROP:
-                        pipeline.labeled_subsets.pop(label)
+                        del pipeline.labeled_subsets[label]
                     case PipelineSubsetCtrl.EDIT:
                         for extra in extraTaskLabels:
                             labeled_subset.subset.discard(extra)
+            elif subsetCtrl is PipelineSubsetCtrl.DROP and not labeled_subset.subset:
+                # When mode is DROP, also drop any subsets that were already
+                # empty. This ensures we drop steps that were emptied-out by
+                # (earlier) imports with exclude in EDIT mode. Note that we
+                # don't want to drop those steps when they're first excluded
+                # down to nothing, because the pipeline might be about to add
+                # new tasks back into them, and then we'd want to preserve the
+                # step definitions.
+                del pipeline.labeled_subsets[label]
 
         # remove any steps that correspond to removed subsets
         new_steps = []
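
A standalone sketch of the pruning rule with plain dict/set stand-ins for the IR classes (not the real API): subsets referencing unknown tasks are dropped or edited depending on the mode, and in DROP mode subsets that are already empty are now dropped as well:

    tasks = {"isr", "calibrate"}
    labeled_subsets = {
        "step1": {"isr", "calibrate"},
        "step2": {"removedTask"},  # references a task that no longer exists
        "step3": set(),            # emptied out by an earlier EDIT-mode import
    }

    def prune(mode: str) -> None:
        for label, subset in list(labeled_subsets.items()):
            if extra := (subset - tasks):
                if mode == "DROP":
                    del labeled_subsets[label]
                else:  # EDIT: keep the subset, discard the unknown labels
                    subset.difference_update(extra)
            elif mode == "DROP" and not subset:
                del labeled_subsets[label]

    prune("DROP")
    assert labeled_subsets == {"step1": {"isr", "calibrate"}}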
@@ -334,6 +334,7 @@ class DisplayArguments:
             dimensions=args.dimensions,
             task_classes=args.task_classes,
             storage_classes=args.storage_classes,
+            status=None,
         ),
         merge_input_trees=args.merge_input_trees,
         merge_output_trees=args.merge_output_trees,
@@ -31,6 +31,7 @@ __all__ = (
     "DuplicateOutputError",
     "EdgesChangedError",
     "IncompatibleDatasetTypeError",
+    "InvalidExpressionError",
     "InvalidStepsError",
     "PipelineDataCycleError",
     "PipelineGraphError",
@@ -102,5 +103,11 @@ class PipelineGraphExceptionSafetyError(PipelineGraphError):
     """
 
 
+class InvalidExpressionError(PipelineGraphError):
+    """Exception raised when a pipeline subset expression could not be parsed
+    or applied.
+    """
+
+
 class InvalidStepsError(PipelineGraphError):
     """Exception raised when the step definitions are invalid."""
@@ -55,11 +55,13 @@ from lsst.utils.packages import Packages
 
 from .._dataset_handle import InMemoryDatasetHandle
 from ..automatic_connection_constants import PACKAGES_INIT_OUTPUT_NAME, PACKAGES_INIT_OUTPUT_STORAGE_CLASS
+from . import expressions
 from ._dataset_types import DatasetTypeNode
 from ._edges import Edge, ReadEdge, WriteEdge
 from ._exceptions import (
     DuplicateOutputError,
     EdgesChangedError,
+    InvalidExpressionError,
     InvalidStepsError,
     PipelineDataCycleError,
     PipelineGraphError,
@@ -1149,16 +1151,7 @@ class PipelineGraph:
         See `TaskNode` and `TaskInitNode` for the descriptive node and
         attributes added.
         """
-        bipartite_xgraph = self._make_bipartite_xgraph_internal(init)
-        task_keys = [
-            key
-            for key, bipartite in bipartite_xgraph.nodes(data="bipartite")
-            if bipartite == NodeType.TASK.bipartite
-        ]
-        return self._transform_xgraph_state(
-            networkx.algorithms.bipartite.projected_graph(networkx.DiGraph(bipartite_xgraph), task_keys),
-            skip_edges=True,
-        )
+        return self._transform_xgraph_state(self._make_task_xgraph_internal(init), skip_edges=True)
 
     def make_dataset_type_xgraph(self, init: bool = False) -> networkx.DiGraph:
         """Return a networkx representation of just the dataset types in the
@@ -1197,6 +1190,62 @@ class PipelineGraph:
             skip_edges=True,
         )
 
+    ###########################################################################
+    #
+    # Expression-based Selection Interface.
+    #
+    ###########################################################################
+
+    def select_tasks(self, expression: str) -> set[str]:
+        """Return the tasks that match an expression.
+
+        Parameters
+        ----------
+        expression : `str`
+            String expression to evaluate. See
+            :ref:`pipeline-graph-subset-expressions`.
+
+        Returns
+        -------
+        task_labels : `set` [ `str` ]
+            Set of matching task labels.
+        """
+        task_xgraph = self._make_task_xgraph_internal(init=False)
+        expr_tree = expressions.parse(expression)
+        matching_task_keys = self._select_expression(expr_tree, task_xgraph)
+        return {key.name for key in matching_task_keys}
+
+    def select(self, expression: str) -> PipelineGraph:
+        """Return a new pipeline graph with the tasks that match an expression.
+
+        Parameters
+        ----------
+        expression : `str`
+            String expression to evaluate. See
+            :ref:`pipeline-graph-subset-expressions`.
+
+        Returns
+        -------
+        new_graph : `PipelineGraph`
+            New pipeline graph with just the matching tasks.
+
+        Notes
+        -----
+        All resolved dataset type nodes will be preserved.
+
+        If `has_been_sorted`, the new graph will be sorted as well.
+
+        Task subsets will not be included in the returned graph.
+        """
+        selected_tasks = self.select_tasks(expression)
+        new_pipeline_graph = PipelineGraph(universe=self._universe, data_id=self._raw_data_id)
+        new_pipeline_graph.add_task_nodes(
+            [self.tasks[task_label] for task_label in selected_tasks], parent=self
+        )
+        if self.has_been_sorted:
+            new_pipeline_graph.sort()
+        return new_pipeline_graph
+
     ###########################################################################
     #
     # Serialization Interface.
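
A hedged sketch of the new selection interface, assuming `pipeline_graph` is a `PipelineGraph` and the labels are hypothetical. Bare identifiers match a task label, dataset type name, or task subset; `T:`, `D:`, and `S:` prefixes disambiguate, and the full grammar (including the `<`/`<=`/`>`/`>=` ancestor and descendant searches) is documented under pipeline-graph-subset-expressions:

    # Just the matching task labels.
    labels = pipeline_graph.select_tasks("T:calibrate")

    # A new PipelineGraph with only the matching tasks; resolved dataset
    # type nodes are preserved, task subsets are not.
    smaller = pipeline_graph.select("S:step1")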
@@ -1575,6 +1624,8 @@ class PipelineGraph:
             element in the iterable.
 
         If `has_been_sorted`, all subgraphs will be sorted as well.
+
+        Task subsets will not be included in the returned graphs.
         """
         # Having an overall input in common isn't enough to make subgraphs
         # dependent on each other, so we want to look for connected component
@@ -1595,7 +1646,7 @@ class PipelineGraph:
                 yield self
                 return
             else:
-                component_subgraph = PipelineGraph(universe=self._universe)
+                component_subgraph = PipelineGraph(universe=self._universe, data_id=self._raw_data_id)
                 component_subgraph.add_task_nodes(
                     [self._xgraph.nodes[key]["instance"] for key in component_task_keys], parent=self
                 )
@@ -2053,6 +2104,26 @@ class PipelineGraph:
         """
         return self._xgraph.edge_subgraph([edge.key for edge in self.iter_edges(init)])
 
+    def _make_task_xgraph_internal(self, init: bool) -> networkx.DiGraph:
+        """Make an init-only or runtime-only internal task subgraph.
+
+        See `make_task_xgraph` for parameters and return values.
+
+        Notes
+        -----
+        This method returns a view of the `PipelineGraph` object's internal
+        backing graph, and hence should only be called in methods that copy the
+        result either explicitly or by running a copying algorithm before
+        returning it to the user.
+        """
+        bipartite_xgraph = self._make_bipartite_xgraph_internal(init=init)
+        task_keys = [
+            key
+            for key, bipartite in bipartite_xgraph.nodes(data="bipartite")
+            if bipartite == NodeType.TASK.bipartite
+        ]
+        return networkx.algorithms.bipartite.projected_graph(networkx.DiGraph(bipartite_xgraph), task_keys)
+
     def _transform_xgraph_state(self, xgraph: _G, skip_edges: bool) -> _G:
         """Transform networkx graph attributes in-place from the internal
         "instance" attributes to the documented exported attributes.
@@ -2342,6 +2413,284 @@ class PipelineGraph:
                 f"{step_label!r}."
             )
 
+    def _select_expression(self, expr_tree: expressions.Node, task_xgraph: networkx.DiGraph) -> set[NodeKey]:
+        """Select tasks from a pipeline based on a string expression.
+
+        This is the primary implementation method for `select` and
+        `select_tasks`.
+
+        Parameters
+        ----------
+        expr_tree : `expressions.Node`
+            Expression [sub]tree to process (recursively).
+        task_xgraph : `networkx.DiGraph`
+            NetworkX graph of all tasks (runtime nodes only) in the pipeline.
+
+        Returns
+        -------
+        selected : `set` [ `NodeKey` ]
+            Set of `NodeKey` objects for matching tasks (only; no dataset type
+            or task-init nodes).
+        """
+        match expr_tree:
+            case expressions.IdentifierNode(qualifier=qualifier, label=label):
+                match self._select_identifier(qualifier, label):
+                    case NodeKey(node_type=NodeType.TASK) as task_key:
+                        return {task_key}
+                    case NodeKey(node_type=NodeType.DATASET_TYPE) as dataset_type_key:
+                        # Since a dataset type can have only one producer, this
+                        # yields 0- (for overall inputs) or 1-element sets.
+                        for producer_key, _ in self._xgraph.in_edges(dataset_type_key):
+                            if producer_key.node_type is NodeType.TASK_INIT:
+                                raise InvalidExpressionError(
+                                    f"Init-output dataset type {label!r} cannot be used directly in an "
+                                    "expression."
+                                )
+                            return {producer_key}
+                        return set()
+                    case TaskSubset() as task_subset:
+                        return {NodeKey(NodeType.TASK, label) for label in task_subset}
+                    case _:  # pragma: no cover
+                        raise AssertionError("Identifier type inconsistent with grammar.")
+            case expressions.DirectionNode(operator=operator, start=start):
+                match self._select_identifier(start.qualifier, start.label):
+                    case NodeKey(node_type=NodeType.TASK) as task_key:
+                        if operator.startswith("<"):
+                            return self._select_task_ancestors(
+                                task_key, task_xgraph, inclusive=operator.endswith("=")
+                            )
+                        else:
+                            assert operator.startswith(">"), "Guaranteed by grammar."
+                            return self._select_task_descendants(
+                                task_key, task_xgraph, inclusive=operator.endswith("=")
+                            )
+                    case NodeKey(node_type=NodeType.DATASET_TYPE) as dataset_type_key:
+                        if operator.startswith("<"):
+                            return self._select_dataset_type_ancestors(
+                                dataset_type_key, task_xgraph, inclusive=operator.endswith("=")
+                            )
+                        else:
+                            assert operator.startswith(">"), "Guaranteed by grammar."
+                            return self._select_dataset_type_descendants(
+                                dataset_type_key, task_xgraph, inclusive=operator.endswith("=")
+                            )
+                    case TaskSubset():
+                        raise InvalidExpressionError(
+                            f"Task subset identifier {start!r} cannot be used as the start of an "
+                            "ancestor/descendant search."
+                        )
+                    case _:  # pragma: no cover
+                        raise AssertionError("Unexpected parsed identifier result type.")
+            case expressions.NotNode(operand=operand):
+                operand_result = self._select_expression(operand, task_xgraph)
+                return set(task_xgraph.nodes.keys() - operand_result)
+            case expressions.UnionNode(lhs=lhs, rhs=rhs):
+                lhs_result = self._select_expression(lhs, task_xgraph)
+                rhs_result = self._select_expression(rhs, task_xgraph)
+                return lhs_result.union(rhs_result)
+            case expressions.IntersectionNode(lhs=lhs, rhs=rhs):
+                lhs_result = self._select_expression(lhs, task_xgraph)
+                rhs_result = self._select_expression(rhs, task_xgraph)
+                return lhs_result.intersection(rhs_result)
+            case _:  # pragma: no cover
+                raise AssertionError("Expression parse node inconsistent with grammar.")
+
+    def _select_task_ancestors(
+        self, start: NodeKey, task_xgraph: networkx.DiGraph, inclusive: bool
+    ) -> set[NodeKey]:
+        """Return all task-node ancestors of the given task node, as defined by
+        the `select` expression language.
+
+        Parameters
+        ----------
+        start : `NodeKey`
+            A runtime task node key.
+        task_xgraph : `networkx.DiGraph`
+            NetworkX graph of all tasks (runtime nodes only) in the pipeline.
+        inclusive : `bool`
+            Whether to include the ``start`` node in the results.
+
+        Returns
+        -------
+        selected : `set` [ `NodeKey` ]
+            Set of `NodeKey` objects for matching tasks (only; no dataset type
+            or task-init nodes).
+        """
+        result = set(networkx.dag.ancestors(task_xgraph, start))
+        if inclusive:
+            result.add(start)
+        return result
+
+    def _select_task_descendants(
+        self, start: NodeKey, task_xgraph: networkx.DiGraph, inclusive: bool
+    ) -> set[NodeKey]:
+        """Return all task-node descendants of the given task node, as defined
+        by the `select` expression language.
+
+        Parameters
+        ----------
+        start : `NodeKey`
+            A runtime task node key.
+        task_xgraph : `networkx.DiGraph`
+            NetworkX graph of all tasks (runtime nodes only) in the pipeline.
+        inclusive : `bool`
+            Whether to include the ``start`` node in the results.
+
+        Returns
+        -------
+        selected : `set` [ `NodeKey` ]
+            Set of `NodeKey` objects for matching tasks (only; no dataset type
+            or task-init nodes).
+        """
+        result = set(networkx.dag.descendants(task_xgraph, start))
+        if inclusive:
+            result.add(start)
+        return result
+
+    def _select_dataset_type_ancestors(
+        self, start: NodeKey, task_xgraph: networkx.DiGraph, inclusive: bool
+    ) -> set[NodeKey]:
+        """Return all task-node ancestors of the given dataset type node, as
+        defined by the `select` expression language.
+
+        Parameters
+        ----------
+        start : `NodeKey`
+            A dataset type node key. May not be an init-output.
+        task_xgraph : `networkx.DiGraph`
+            NetworkX graph of all tasks (runtime nodes only) in the pipeline.
+        inclusive : `bool`
+            Whether to include the producer of the ``start`` node in the
+            results.
+
+        Returns
+        -------
+        selected : `set` [ `NodeKey` ]
+            Set of `NodeKey` objects for matching tasks (only; no dataset type
+            or task-init nodes).
+        """
+        result: set[NodeKey] = set()
+        for producer_key, _ in self._xgraph.in_edges(start):
+            if producer_key.node_type is NodeType.TASK_INIT:
+                raise InvalidExpressionError(
+                    f"Init-output dataset type {start.name!r} cannot be used as the "
+                    "starting point for an ancestor ('<' or '<=') search."
+                )
+            result.update(networkx.dag.ancestors(task_xgraph, producer_key))
+            if inclusive:
+                result.add(producer_key)
+        return result
+
+    def _select_dataset_type_descendants(
+        self, start: NodeKey, task_xgraph: networkx.DiGraph, inclusive: bool
+    ) -> set[NodeKey]:
+        """Return all task-node descendants of the given dataset type node, as
+        defined by the `select` expression language.
+
+        Parameters
+        ----------
+        start : `NodeKey`
+            A dataset type node key. May not be an init-output if
+            ``inclusive=True``.
+        task_xgraph : `networkx.DiGraph`
+            NetworkX graph of all tasks (runtime nodes only) in the pipeline.
+        inclusive : `bool`
+            Whether to include the producer of the ``start`` node in the
+            results.
+
+        Returns
+        -------
+        selected : `set` [ `NodeKey` ]
+            Set of `NodeKey` objects for matching tasks (only; no dataset type
+            or task-init nodes).
+        """
+        result: set[NodeKey] = set()
+        if inclusive:
+            for producer_key, _ in self._xgraph.in_edges(start):
+                if producer_key.node_type is NodeType.TASK_INIT:
+                    raise InvalidExpressionError(
+                        f"Init-output dataset type {start.name!r} cannot be used as the "
+                        "starting point for an inclusive descendant ('>=') search."
+                    )
+                result.add(producer_key)
+        # We also include tasks that consume a dataset type as an init-input,
+        # since that can affect their runtime behavior.
+        consumer_keys: set[NodeKey] = {
+            (
+                consumer_key
+                if consumer_key.node_type is NodeType.TASK
+                else NodeKey(NodeType.TASK, consumer_key.name)
+            )
+            for _, consumer_key in self._xgraph.out_edges(start)
+        }
+        for consumer_key in consumer_keys:
+            result.add(consumer_key)
+            result.update(networkx.dag.descendants(task_xgraph, consumer_key))
+        return result
+
+    def _select_identifier(
+        self, qualifier: Literal["T", "D", "S"] | None, label: str
+    ) -> NodeKey | TaskSubset:
+        """Return the node key or task subset that corresponds to a `select`
+        expression identifier.
+
+        Parameters
+        ----------
+        qualifier : `str` or `None`
+            Task, dataset type, or task subset qualifier included in the
+            identifier, if any.
+        label : `str`
+            Task label, dataset type name, or task subset label.
+
+        Returns
+        -------
+        key_or_subset : `NodeKey` or `TaskSubset`
+            A `NodeKey` for a task or dataset type, or a `TaskSubset` for a
+            task subset.
+        """
+        match qualifier:
+            case None:
+                task_key = NodeKey(NodeType.TASK, label)
+                dataset_type_key = NodeKey(NodeType.DATASET_TYPE, label)
+                if task_key in self._xgraph.nodes:
+                    if dataset_type_key in self._xgraph.nodes:
+                        raise InvalidExpressionError(
+                            f"{label!r} is both a task label and a dataset type name; "
+                            "prefix with 'T:' or 'D:' (respectively) to specify which."
+                        )
+                    assert label not in self._task_subsets, "Should be prohibited at construction."
+                    return task_key
+                elif dataset_type_key in self._xgraph.nodes:
+                    if label in self._task_subsets:
+                        raise InvalidExpressionError(
+                            f"{label!r} is both a subset label and a dataset type name; "
+                            "prefix with 'S:' or 'D:' (respectively) to specify which."
+                        )
+                    return dataset_type_key
+                elif label in self._task_subsets:
+                    return self._task_subsets[label]
+                else:
+                    raise InvalidExpressionError(
+                        f"{label!r} is not a task label, task subset label, or dataset type name."
+                    )
+            case "T":
+                task_key = NodeKey(NodeType.TASK, label)
+                if task_key not in self._xgraph.nodes:
+                    raise InvalidExpressionError(f"Task with label {label!r} does not exist.")
+                return task_key
+            case "D":
+                dataset_type_key = NodeKey(NodeType.DATASET_TYPE, label)
+                if dataset_type_key not in self._xgraph.nodes:
+                    raise InvalidExpressionError(f"Dataset type with name {label!r} does not exist.")
+                return dataset_type_key
+            case "S":
+                try:
+                    return self._task_subsets[label]
+                except KeyError:
+                    raise InvalidExpressionError(f"Task subset with label {label!r} does not exist.")
+            case _:  # pragma: no cover
+                raise AssertionError("Unexpected identifier qualifier in expression.")
+
     _xgraph: networkx.MultiDiGraph
     _sorted_keys: Sequence[NodeKey] | None
     _task_subsets: dict[str, TaskSubset]
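
The ancestor/descendant helpers above are thin wrappers around networkx's DAG reachability functions, with ``inclusive`` deciding whether the start node itself is kept. A self-contained illustration with hypothetical task labels:

    import networkx

    g = networkx.DiGraph([("isr", "calibrate"), ("calibrate", "coadd")])

    # Exclusive ancestor search ('<'): everything upstream of the start.
    assert networkx.ancestors(g, "coadd") == {"isr", "calibrate"}

    # Inclusive ('<='): add the start node back in.
    assert networkx.ancestors(g, "coadd") | {"coadd"} == {"isr", "calibrate", "coadd"}

    # Inclusive descendant search ('>='): the start plus everything downstream.
    assert networkx.descendants(g, "isr") | {"isr"} == {"isr", "calibrate", "coadd"}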