easylink 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
easylink/_version.py CHANGED
@@ -1 +1 @@
-__version__ = "0.1.12"
+__version__ = "0.1.14"
easylink/configuration.py CHANGED
@@ -274,10 +274,10 @@ class Config(LayeredConfigTree):
 
 
 def load_params_from_specification(
-    pipeline_specification: str,
-    input_data: str,
-    computing_environment: str | None,
-    results_dir: str,
+    pipeline_specification: str | Path,
+    input_data: str | Path,
+    computing_environment: str | Path | None,
+    results_dir: str | Path,
 ) -> dict[str, Any]:
     """Gathers together all specification data.
 
@@ -325,7 +325,7 @@ def _load_input_data_paths(
 
 
 def _load_computing_environment(
-    computing_environment_specification_path: str | None,
+    computing_environment_specification_path: str | Path | None,
 ) -> dict[Any, Any]:
     """Loads the computing environment specification file and returns the contents as a dict."""
     if not computing_environment_specification_path:
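With these signature changes, callers can now hand pathlib.Path objects (or plain strings) straight to the specification loader. A minimal sketch of what that looks like, assuming the function lives in easylink.configuration as the file header above indicates; the filenames are illustrative and the files would need to exist for the call to succeed:

    from pathlib import Path

    from easylink.configuration import load_params_from_specification

    # Hypothetical specification files; only the signature is the point here.
    params = load_params_from_specification(
        pipeline_specification=Path("pipeline.yaml"),  # Path now accepted alongside str
        input_data=Path("input_data.yaml"),
        computing_environment=None,  # still optional
        results_dir=Path("results"),
    )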
easylink/graph_components.py CHANGED
@@ -18,6 +18,12 @@ from typing import TYPE_CHECKING, Any
 
 import networkx as nx
 
+from easylink.implementation import (
+    NullAggregatorImplementation,
+    NullImplementation,
+    NullSplitterImplementation,
+)
+
 if TYPE_CHECKING:
     from easylink.implementation import Implementation
     from easylink.step import Step
@@ -42,48 +48,38 @@ class InputSlot:
     env_var: str | None
     """The environment variable that is used to pass a list of data filepaths to
     an ``Implementation``."""
-    validator: Callable[[str], None] = field(compare=False)
+    validator: Callable[[str], None] | None = field(compare=False)
     """A function that validates the input data being passed into the pipeline via
     this ``InputSlot``. If the data is invalid, the function should raise an exception
     with a descriptive error message which will then be reported to the user.
     **Note that the function *must* be defined in the** :mod:`easylink.utilities.validation_utils`
     **module!**"""
-    splitter: Callable[[list[str], str, Any], None] | None = field(
-        default=None, compare=False
-    )
-    """A function that splits the incoming data to this ``InputSlot`` into smaller
-    pieces. The primary purpose of this functionality is to run sections of the
-    pipeline in an embarrassingly parallel manner. **Note that the function *must*
-    be defined in the **:mod:`easylink.utilities.splitter_utils`** module!**"""
 
     def __eq__(self, other: Any) -> bool | NotImplementedType:
-        """Checks if two ``InputSlots`` are equal.
-
-        Two ``InputSlots`` are considered equal if their names, ``env_vars``, and
-        names of their ``validators`` and ``splitters`` are all the same.
-        """
+        """Checks if two ``InputSlots`` are equal."""
         if not isinstance(other, InputSlot):
             return NotImplemented
-        splitter_name = self.splitter.__name__ if self.splitter else None
-        other_splitter_name = other.splitter.__name__ if other.splitter else None
+        validator_name = self.validator.__name__ if self.validator else None
+        other_validator_name = other.validator.__name__ if other.validator else None
         return (
             self.name == other.name
             and self.env_var == other.env_var
-            and self.validator.__name__ == other.validator.__name__
-            and splitter_name == other_splitter_name
+            and validator_name == other_validator_name
         )
 
     def __hash__(self) -> int:
-        """Hashes an ``InputSlot``.
-
-        The hash is based on the name of the ``InputSlot``, its ``env_var``, and
-        the names of its ``validator`` and ``splitter``.
-        """
-        splitter_name = self.splitter.__name__ if self.splitter else None
-        return hash((self.name, self.env_var, self.validator.__name__, splitter_name))
+        """Hashes an ``InputSlot``."""
+        validator_name = self.validator.__name__ if self.validator else None
+        return hash(
+            (
+                self.name,
+                self.env_var,
+                validator_name,
+            )
+        )
 
 
-@dataclass()
+@dataclass(frozen=True)
 class OutputSlot:
     """A single output slot from a specific node.
 
@@ -104,31 +100,6 @@ class OutputSlot:
 
     name: str
     """The name of the ``OutputSlot``."""
-    aggregator: Callable[[list[str], str], None] = field(default=None, compare=False)
-    """A function that aggregates all of the generated data to be passed out via this
-    ``OutputSlot``. The primary purpose of this functionality is to run sections
-    of the pipeline in an embarrassingly parallel manner. **Note that the function
-    *must* be defined in the **:py:mod:`easylink.utilities.aggregator_utils`** module!**"""
-
-    def __eq__(self, other: Any) -> bool | NotImplementedType:
-        """Checks if two ``OutputSlots`` are equal.
-
-        Two ``OutputSlots`` are considered equal if their names and the names of their
-        ``aggregators`` are the same.
-        """
-        if not isinstance(other, OutputSlot):
-            return NotImplemented
-        aggregator_name = self.aggregator.__name__ if self.aggregator else None
-        other_aggregator_name = other.aggregator.__name__ if other.aggregator else None
-        return self.name == other.name and aggregator_name == other_aggregator_name
-
-    def __hash__(self) -> int:
-        """Hashes an ``OutputSlot``.
-
-        The hash is based on the name of the ``OutputSlot`` and the name of its ``aggregator``.
-        """
-        aggregator_name = self.aggregator.__name__ if self.aggregator else None
-        return hash((self.name, aggregator_name))
 
 
 @dataclass(frozen=True)
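For orientation, the slimmed-down InputSlot now compares and hashes on just its name, env_var, and the *name* of its validator (the splitter machinery has moved elsewhere). A hedged sketch of that behavior; the validator below is a stand-in for demonstration only, whereas real pipelines must use functions from easylink.utilities.validation_utils:

    from easylink.graph_components import InputSlot

    def check_parquet(filepath: str) -> None:  # stand-in validator, not part of the package
        if not filepath.endswith(".parquet"):
            raise ValueError(f"{filepath} is not a parquet file")

    a = InputSlot(name="main_input", env_var="INPUT_FILE_PATHS", validator=check_parquet)
    b = InputSlot(name="main_input", env_var="INPUT_FILE_PATHS", validator=check_parquet)
    assert a == b              # equal: same name, env_var, and validator name
    assert hash(a) == hash(b)  # hashed on the same three fields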
@@ -263,7 +234,33 @@ class ImplementationGraph(nx.MultiDiGraph):
     def implementation_nodes(self) -> list[str]:
         """The topologically sorted list of ``Implementation`` names."""
         ordered_nodes = list(nx.topological_sort(self))
-        return [node for node in ordered_nodes if node != "input_data" and node != "results"]
+        # Remove nodes that do not actually have implementations
+        null_implementations = [
+            node
+            for node in ordered_nodes
+            if isinstance(self.nodes[node]["implementation"], NullImplementation)
+        ]
+        return [node for node in ordered_nodes if node not in null_implementations]
+
+    @property
+    def splitter_nodes(self) -> list[str]:
+        """The topologically sorted list of splitter nodes (which have no implementations)."""
+        ordered_nodes = list(nx.topological_sort(self))
+        return [
+            node
+            for node in ordered_nodes
+            if isinstance(self.nodes[node]["implementation"], NullSplitterImplementation)
+        ]
+
+    @property
+    def aggregator_nodes(self) -> list[str]:
+        """The topologically sorted list of aggregator nodes (which have no implementations)."""
+        ordered_nodes = list(nx.topological_sort(self))
+        return [
+            node
+            for node in ordered_nodes
+            if isinstance(self.nodes[node]["implementation"], NullAggregatorImplementation)
+        ]
 
     @property
     def implementations(self) -> list[Implementation]:
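The three new graph properties classify nodes purely by the type of their attached "implementation" object. The same filtering technique in a self-contained toy; the classes here are stand-ins for the real ones in easylink.implementation, and the node names are invented:

    import networkx as nx

    # Stand-in marker classes; the real ones live in easylink.implementation.
    class NullImplementation: ...
    class NullSplitterImplementation(NullImplementation): ...
    class NullAggregatorImplementation(NullImplementation): ...
    class Implementation: ...

    graph = nx.MultiDiGraph()
    graph.add_node("step_1_split", implementation=NullSplitterImplementation())
    graph.add_node("step_1_python_pandas", implementation=Implementation())
    graph.add_node("step_1_aggregate", implementation=NullAggregatorImplementation())
    graph.add_edge("step_1_split", "step_1_python_pandas")
    graph.add_edge("step_1_python_pandas", "step_1_aggregate")

    ordered = list(nx.topological_sort(graph))
    splitters = [n for n in ordered if isinstance(graph.nodes[n]["implementation"], NullSplitterImplementation)]
    aggregators = [n for n in ordered if isinstance(graph.nodes[n]["implementation"], NullAggregatorImplementation)]
    runnable = [n for n in ordered if not isinstance(graph.nodes[n]["implementation"], NullImplementation)]
    print(splitters, runnable, aggregators)
    # ['step_1_split'] ['step_1_python_pandas'] ['step_1_aggregate']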
easylink/implementation.py CHANGED
@@ -9,15 +9,20 @@ information about what container to run for a given step and other related detai
 
 """
 
+from __future__ import annotations
+
 from collections.abc import Iterable
 from pathlib import Path
+from typing import TYPE_CHECKING
 
 from layered_config_tree import LayeredConfigTree
 
-from easylink.graph_components import InputSlot, OutputSlot
 from easylink.utilities import paths
 from easylink.utilities.data_utils import load_yaml
 
+if TYPE_CHECKING:
+    from easylink.graph_components import InputSlot, OutputSlot
+
 
 class Implementation:
     """A representation of an actual container that will be executed for a :class:`~easylink.step.Step`.
@@ -43,8 +48,8 @@ class Implementation:
         self,
         schema_steps: list[str],
         implementation_config: LayeredConfigTree,
-        input_slots: Iterable["InputSlot"] = (),
-        output_slots: Iterable["OutputSlot"] = (),
+        input_slots: Iterable[InputSlot] = (),
+        output_slots: Iterable[OutputSlot] = (),
         is_embarrassingly_parallel: bool = False,
     ):
         self.name = implementation_config.name
@@ -137,9 +142,8 @@ class Implementation:
 class NullImplementation:
     """A partial :class:`Implementation` interface when no container is needed to run.
 
-    The primary use case for this class is when adding an
-    :class:`~easylink.step.IOStep` - which does not have a corresponding
-    ``Implementation`` - to an :class:`~easylink.graph_components.ImplementationGraph`
+    The primary use case for this class is to be able to add a :class:`~easylink.step.Step`
+    that does *not* have a corresponding ``Implementation`` to an :class:`~easylink.graph_components.ImplementationGraph`
     since adding any new node requires an object with :class:`~easylink.graph_components.InputSlot`
     and :class:`~easylink.graph_components.OutputSlot` names.
 
@@ -151,13 +155,14 @@ class NullImplementation:
         All required ``InputSlots``.
     output_slots
         All required ``OutputSlots``.
+
     """
 
     def __init__(
         self,
         name: str,
-        input_slots: Iterable["InputSlot"] = (),
-        output_slots: Iterable["OutputSlot"] = (),
+        input_slots: Iterable[InputSlot] = (),
+        output_slots: Iterable[OutputSlot] = (),
     ):
         self.name = name
         """The name of this ``NullImplementation``."""
@@ -172,6 +177,61 @@ class NullImplementation:
     is a constituent. This is definitionally None."""
 
 
+class NullSplitterImplementation(NullImplementation):
+    """A type of :class:`NullImplementation` specifically for :class:`SplitterSteps<easylink.step.SplitterStep>`.
+
+    See ``NullImplementation`` for inherited attributes.
+
+    Parameters
+    ----------
+    splitter_func_name
+        The name of the splitter function to use.
+
+    """
+
+    def __init__(
+        self,
+        name: str,
+        input_slots: Iterable[InputSlot],
+        output_slots: Iterable[OutputSlot],
+        splitter_func_name: str,
+    ):
+        super().__init__(name, input_slots, output_slots)
+        self.splitter_func_name = splitter_func_name
+        """The name of the splitter function to use."""
+
+
+class NullAggregatorImplementation(NullImplementation):
+    """A type of :class:`NullImplementation` specifically for :class:`AggregatorSteps<easylink.step.AggregatorStep>`.
+
+    See ``NullImplementation`` for inherited attributes.
+
+    Parameters
+    ----------
+    aggregator_func_name
+        The name of the aggregation function to use.
+    splitter_node_name
+        The name of the :class:`~easylink.step.SplitterStep` and its corresponding
+        :class:`NullSplitterImplementation` that did the splitting.
+
+    """
+
+    def __init__(
+        self,
+        name: str,
+        input_slots: Iterable[InputSlot],
+        output_slots: Iterable[OutputSlot],
+        aggregator_func_name: str,
+        splitter_node_name: str,
+    ):
+        super().__init__(name, input_slots, output_slots)
+        self.aggregator_func_name = aggregator_func_name
+        """The name of the aggregation function to use."""
+        self.splitter_node_name = splitter_node_name
+        """The name of the :class:`~easylink.step.SplitterStep` and its corresponding
+        :class:`NullSplitterImplementation` that did the splitting."""
+
+
 class PartialImplementation:
     """One part of a combined implementation that spans multiple :class:`Steps<easylink.step.Step>`.
 
@@ -205,8 +265,8 @@ class PartialImplementation:
         self,
         combined_name: str,
         schema_step: str,
-        input_slots: Iterable["InputSlot"] = (),
-        output_slots: Iterable["OutputSlot"] = (),
+        input_slots: Iterable[InputSlot] = (),
+        output_slots: Iterable[OutputSlot] = (),
     ):
         self.combined_name = combined_name
         """The name of the combined implementation of which this ``PartialImplementation``
easylink/pipeline.py CHANGED
@@ -93,8 +93,18 @@ class Pipeline:
         self._write_spark_config()
         self._write_target_rules()
         self._write_spark_module()
-        for node in self.pipeline_graph.implementation_nodes:
-            self._write_implementation_rules(node)
+        for node_name in self.pipeline_graph.implementation_nodes:
+            self._write_implementation_rules(node_name)
+        checkpoint_filepaths = self._get_checkpoint_filepaths()
+        for node_name in self.pipeline_graph.splitter_nodes:
+            self._write_checkpoint_rule(node_name, checkpoint_filepaths[node_name])
+        for node_name in self.pipeline_graph.aggregator_nodes:
+            self._write_aggregation_rule(
+                node_name,
+                checkpoint_filepaths[
+                    self.pipeline_graph.nodes[node_name]["implementation"].splitter_node_name
+                ],
+            )
         return self.snakefile_path
 
     ##################
@@ -130,6 +140,42 @@ class Pipeline:
             errors[IMPLEMENTATION_ERRORS_KEY][implementation.name] = implementation_errors
         return errors
 
+    @staticmethod
+    def _get_input_slots_to_split(
+        input_slot_dict: dict[str, dict[str, str | list[str]]]
+    ) -> list[str]:
+        """Gets any input slots that have a splitter attribute."""
+        return [
+            slot_name
+            for slot_name, slot_attrs in input_slot_dict.items()
+            if slot_attrs.get("splitter", None)
+        ]
+
+    def _get_checkpoint_filepaths(self) -> dict[str, str]:
+        """Gets a checkpoint filepath for each splitter node."""
+        checkpoint_filepaths = {}
+        for node_name in self.pipeline_graph.splitter_nodes:
+            _input_files, output_files = self.pipeline_graph.get_io_filepaths(node_name)
+            if len(set(output_files)) > 1:
+                raise ValueError(
+                    "The list of output files from a CheckpointRule should always be "
+                    "length 1; wildcards handle the fact that there are actually "
+                    "multiple files."
+                )
+            # The snakemake checkpoint rule requires the output parent directory
+            # to the chunked sub-directories (which are created by the splitter).
+            # e.g. if the chunks are eventually going to be written to
+            # 'intermediate/split_step_1_python_pandas/{chunk}/result.parquet',
+            # we need the output directory 'intermediate/split_step_1_python_pandas'
+            checkpoint_filepaths[node_name] = str(
+                Path(output_files[0]).parent.parent / "checkpoint.txt"
+            )
+        return checkpoint_filepaths
+
+    #################################
+    # Snakefile Rule Writer Methods #
+    #################################
+
     def _write_imports(self) -> None:
         if not self.any_embarrassingly_parallel:
             imports = "from easylink.utilities import validation_utils\n"
@@ -157,7 +203,7 @@ wildcard_constraints:
     def _write_target_rules(self) -> None:
         """Writes the rule for the final output and its validation.
 
-        The input files to the the target rule (i.e. the result node) are the final
+        The input files to the target rule (i.e. the result node) are the final
         output themselves.
         """
         final_output, _ = self.pipeline_graph.get_io_filepaths("results")
@@ -246,29 +292,17 @@ use rule start_spark_worker from spark_cluster with:
             The name of the ``Implementation`` to write the rule(s) for.
         """
 
-        input_slots, output_slots = self.pipeline_graph.get_io_slot_attributes(node_name)
-        validation_files, validation_rules = self._get_validations(node_name, input_slots)
+        is_embarrassingly_parallel = self.pipeline_graph.get_whether_embarrassingly_parallel(
+            node_name
+        )
+        input_slots, _output_slots = self.pipeline_graph.get_io_slot_attributes(node_name)
+        validation_files, validation_rules = self._get_validations(
+            node_name, input_slots, is_embarrassingly_parallel
+        )
         for validation_rule in validation_rules:
             validation_rule.write_to_snakefile(self.snakefile_path)
 
         _input_files, output_files = self.pipeline_graph.get_io_filepaths(node_name)
-        is_embarrassingly_parallel = self.pipeline_graph.get_whether_embarrassingly_parallel(
-            node_name
-        )
-        if is_embarrassingly_parallel:
-            CheckpointRule(
-                name=node_name,
-                input_slots=input_slots,
-                validations=validation_files,
-                output=output_files,
-            ).write_to_snakefile(self.snakefile_path)
-            for name, attrs in output_slots.items():
-                AggregationRule(
-                    name=node_name,
-                    input_slots=input_slots,
-                    output_slot_name=name,
-                    output_slot=attrs,
-                ).write_to_snakefile(self.snakefile_path)
 
         implementation = self.pipeline_graph.nodes[node_name]["implementation"]
         diagnostics_dir = Path("diagnostics") / node_name
@@ -294,9 +328,74 @@ use rule start_spark_worker from spark_cluster with:
             is_embarrassingly_parallel=is_embarrassingly_parallel,
         ).write_to_snakefile(self.snakefile_path)
 
+    def _write_checkpoint_rule(self, node_name: str, checkpoint_filepath: str) -> None:
+        """Writes the snakemake checkpoint rule.
+
+        This builds the ``CheckpointRule`` which splits the data into (unprocessed)
+        chunks and saves them in the output directory using wildcards.
+        """
+        splitter_func_name = self.pipeline_graph.nodes[node_name][
+            "implementation"
+        ].splitter_func_name
+        input_files, output_files = self.pipeline_graph.get_io_filepaths(node_name)
+        if len(output_files) > 1:
+            raise ValueError(
+                "The list of output files from a CheckpointRule should always be "
+                "length 1; wildcards handle the fact that there are actually "
+                "multiple files."
+            )
+        # The snakemake checkpoint rule requires the output parent directory
+        # to the chunked sub-directories (which are created by the splitter).
+        # e.g. if the chunks are eventually going to be written to
+        # 'intermediate/split_step_1_python_pandas/{chunk}/result.parquet',
+        # we need the output directory 'intermediate/split_step_1_python_pandas'
+        output_dir = str(Path(output_files[0]).parent.parent)
+        CheckpointRule(
+            name=node_name,
+            input_files=input_files,
+            splitter_func_name=splitter_func_name,
+            output_dir=output_dir,
+            checkpoint_filepath=checkpoint_filepath,
+        ).write_to_snakefile(self.snakefile_path)
+
+    def _write_aggregation_rule(self, node_name: str, checkpoint_filepath: str) -> None:
+        """Writes the snakemake aggregation rule.
+
+        This builds the ``AggregationRule`` which aggregates the processed data
+        from the chunks originally created by the ``SplitterRule``.
+        """
+        _input_slots, output_slots = self.pipeline_graph.get_io_slot_attributes(node_name)
+        input_files, output_files = self.pipeline_graph.get_io_filepaths(node_name)
+        if len(output_slots) > 1:
+            raise NotImplementedError(
+                "FIXME [MIC-5883] Multiple output slots/files of EmbarrassinglyParallelSteps not yet supported"
+            )
+        if len(output_files) > 1:
+            raise ValueError(
+                "There should always only be a single output file from an AggregationRule."
+            )
+        implementation = self.pipeline_graph.nodes[node_name]["implementation"]
+        output_slot_name = list(output_slots.keys())[0]
+        output_slot_attrs = list(output_slots.values())[0]
+        if len(output_slot_attrs["filepaths"]) > 1:
+            raise NotImplementedError(
+                "FIXME [MIC-5883] Multiple output slots/files of EmbarrassinglyParallelSteps not yet supported"
+            )
+        checkpoint_rule_name = f"checkpoints.{implementation.splitter_node_name}"
+        AggregationRule(
+            name=f"{node_name}_{output_slot_name}",
+            input_files=input_files,
+            aggregated_output_file=output_files[0],
+            aggregator_func_name=implementation.aggregator_func_name,
+            checkpoint_filepath=checkpoint_filepath,
+            checkpoint_rule_name=checkpoint_rule_name,
+        ).write_to_snakefile(self.snakefile_path)
+
     @staticmethod
     def _get_validations(
-        node_name, input_slots
+        node_name: str,
+        input_slots: dict[str, dict[str, str | list[str]]],
+        is_embarrassingly_parallel: bool,
     ) -> tuple[list[str], list[InputValidationRule]]:
         """Gets the validation rule and its output filepath for each slot for a given node.
 
@@ -315,7 +414,11 @@ use rule start_spark_worker from spark_cluster with:
         validation_rules = []
 
         for input_slot_name, input_slot_attrs in input_slots.items():
-            validation_file = f"input_validations/{node_name}/{input_slot_name}_validator"
+            # embarrassingly parallel implementations rely on snakemake wildcards
+            # TODO: [MIC-5787] - need to support multiple wildcards at once
+            validation_file = f"input_validations/{node_name}/{input_slot_name}_validator" + (
+                "-{chunk}" if is_embarrassingly_parallel else ""
+            )
            validation_files.append(validation_file)
            validation_rules.append(
                InputValidationRule(
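To make the new wildcard suffix concrete, here is the same filename logic rendered standalone (node and slot names are invented):

    node_name, input_slot_name = "step_1_python_pandas", "step_1_main_input"
    for is_embarrassingly_parallel in (False, True):
        validation_file = f"input_validations/{node_name}/{input_slot_name}_validator" + (
            "-{chunk}" if is_embarrassingly_parallel else ""
        )
        print(validation_file)
    # input_validations/step_1_python_pandas/step_1_main_input_validator
    # input_validations/step_1_python_pandas/step_1_main_input_validator-{chunk}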
easylink/pipeline_graph.py CHANGED
@@ -82,7 +82,7 @@ class PipelineGraph(ImplementationGraph):
             ]
         )
 
-    def get_whether_embarrassingly_parallel(self, node: str) -> bool:
+    def get_whether_embarrassingly_parallel(self, node: str) -> dict[str, bool]:
         """Determines whether a node is to be run in an embarrassingly parallel way.
 
         Parameters
@@ -119,11 +119,13 @@ class PipelineGraph(ImplementationGraph):
             )
         )
         output_files = list(
-            itertools.chain.from_iterable(
-                [
-                    edge_attrs["filepaths"]
-                    for _, _, edge_attrs in self.out_edges(node, data=True)
-                ]
+            set(
+                itertools.chain.from_iterable(
+                    [
+                        edge_attrs["filepaths"]
+                        for _, _, edge_attrs in self.out_edges(node, data=True)
+                    ]
+                )
             )
         )
         return input_files, output_files
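The only behavioral change here is deduplication of output filepaths before the list is returned; the same pattern in isolation (note that set() does not preserve order):

    import itertools

    filepaths_by_edge = [
        ["intermediate/step_1/{chunk}/result.parquet"],
        ["intermediate/step_1/{chunk}/result.parquet"],  # second edge pointing at the same file
    ]
    output_files = list(set(itertools.chain.from_iterable(filepaths_by_edge)))
    print(output_files)  # ['intermediate/step_1/{chunk}/result.parquet']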
@@ -480,10 +482,31 @@ class PipelineGraph(ImplementationGraph):
                 str(
                     Path("intermediate")
                     / node
+                    # embarrassingly parallel implementations rely on snakemake wildcards
+                    # TODO: [MIC-5787] - need to support multiple wildcards at once
+                    / ("{chunk}" if implementation.is_embarrassingly_parallel else "")
                     / imp_outputs[edge_attrs["output_slot"].name]
                 ),
             )
 
+        # Update splitters and aggregators with their filepaths
+        for node in self.splitter_nodes:
+            implementation = self.nodes[node]["implementation"]
+            for src, sink, edge_attrs in self.out_edges(node, data=True):
+                for edge_idx in self[node][sink]:
+                    # splitter nodes rely on snakemake wildcards
+                    # TODO: [MIC-5787] - need to support multiple wildcards at once
+                    self[src][sink][edge_idx]["filepaths"] = (
+                        str(Path("intermediate") / node / "{chunk}" / "result.parquet"),
+                    )
+        for node in self.aggregator_nodes:
+            implementation = self.nodes[node]["implementation"]
+            for src, sink, edge_attrs in self.out_edges(node, data=True):
+                for edge_idx in self[node][sink]:
+                    self[src][sink][edge_idx]["filepaths"] = (
+                        str(Path("intermediate") / node / "result.parquet"),
+                    )
+
     @staticmethod
     def _deduplicate_input_slots(
@@ -509,23 +532,21 @@ class PipelineGraph(ImplementationGraph):
         """
         condensed_slot_dict = {}
         for input_slot, filepaths in zip(input_slots, filepaths_by_slot):
-            slot_name, env_var, validator, splitter = (
+            slot_name, env_var, validator = (
                 input_slot.name,
                 input_slot.env_var,
                 input_slot.validator,
-                input_slot.splitter,
             )
+            attrs = {
+                "env_var": env_var,
+                "validator": validator,
+            }
             if slot_name in condensed_slot_dict:
-                if env_var != condensed_slot_dict[slot_name]["env_var"]:
-                    raise ValueError(
-                        f"Duplicate input slots named '{slot_name}' have different env vars."
-                    )
-                condensed_slot_validator = condensed_slot_dict[slot_name]["validator"]
-                if validator != condensed_slot_validator:
-                    raise ValueError(
-                        f"Duplicate input slots named '{slot_name}' have different validators: "
-                        f"'{validator.__name__}' and '{condensed_slot_validator.__name__}'."
-                    )
+                for key, value in attrs.items():
+                    if value != condensed_slot_dict[slot_name][key]:
+                        raise ValueError(
+                            f"Duplicate input slots named '{slot_name}' have different {key} values."
+                        )
                 # Add the new filepaths to the existing slot
                 condensed_slot_dict[slot_name]["filepaths"].extend(filepaths)
             else:
@@ -533,7 +554,6 @@ class PipelineGraph(ImplementationGraph):
                     "env_var": env_var,
                     "validator": validator,
                     "filepaths": filepaths,
-                    "splitter": splitter,
                 }
         return condensed_slot_dict
 
@@ -556,16 +576,16 @@ class PipelineGraph(ImplementationGraph):
         """
         condensed_slot_dict = {}
         for output_slot, filepaths in zip(output_slots, filepaths_by_slot):
-            slot_name, aggregator = (
-                output_slot.name,
-                output_slot.aggregator,
-            )
+            slot_name = output_slot.name
             if slot_name in condensed_slot_dict:
-                # Add the new filepaths to the existing slot
-                condensed_slot_dict[slot_name]["filepaths"].extend(filepaths)
+                # Add any new/unique filepaths to the existing slot
+                condensed_slot_dict[slot_name]["filepaths"].extend(
+                    item
+                    for item in filepaths
+                    if item not in condensed_slot_dict[slot_name]["filepaths"]
+                )
             else:
                 condensed_slot_dict[slot_name] = {
                     "filepaths": filepaths,
-                    "aggregator": aggregator,
                 }
         return condensed_slot_dict
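The reworked extend call only appends filepaths that are not already present; because the generator is consumed lazily, items appended earlier in the same call also count as already present. In isolation:

    existing = ["intermediate/step_1/result.parquet"]
    incoming = ["intermediate/step_1/result.parquet", "intermediate/step_2/result.parquet"]
    existing.extend(item for item in incoming if item not in existing)
    print(existing)
    # ['intermediate/step_1/result.parquet', 'intermediate/step_2/result.parquet']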
@@ -16,11 +16,15 @@ ALLOWED_SCHEMA_PARAMS = {
 }
 
 TESTING_SCHEMA_PARAMS = {
-    "integration": testing.SINGLE_STEP_SCHEMA_PARAMS,
-    "combine_bad_topology": testing.BAD_COMBINED_TOPOLOGY_SCHEMA_PARAMS,
-    "combine_bad_implementation_names": testing.BAD_COMBINED_TOPOLOGY_SCHEMA_PARAMS,
-    "nested_templated_steps": testing.NESTED_TEMPLATED_STEPS_SCHEMA_PARAMS,
-    "combine_with_iteration": testing.COMBINE_WITH_ITERATION_SCHEMA_PARAMS,
-    "combine_with_iteration_cycle": testing.COMBINE_WITH_ITERATION_SCHEMA_PARAMS,
-    "combine_with_extra_node": testing.TRIPLE_STEP_SCHEMA_PARAMS,
+    "integration": testing.SCHEMA_PARAMS_ONE_STEP,
+    "combine_bad_topology": testing.SCHEMA_PARAMS_BAD_COMBINED_TOPOLOGY,
+    "combine_bad_implementation_names": testing.SCHEMA_PARAMS_BAD_COMBINED_TOPOLOGY,
+    "nested_templated_steps": testing.SCHEMA_PARAMS_NESTED_TEMPLATED_STEPS,
+    "combine_with_iteration": testing.SCHEMA_PARAMS_COMBINE_WITH_ITERATION,
+    "combine_with_iteration_cycle": testing.SCHEMA_PARAMS_COMBINE_WITH_ITERATION,
+    "combine_with_extra_node": testing.SCHEMA_PARAMS_THREE_STEPS,
+    "looping_ep_step": testing.SCHEMA_PARAMS_LOOPING_EP_STEP,
+    "ep_parallel_step": testing.SCHEMA_PARAMS_EP_PARALLEL_STEP,
+    "ep_loop_step": testing.SCHEMA_PARAMS_EP_LOOP_STEP,
+    "ep_hierarchical_step": testing.SCHEMA_PARAMS_EP_HIERARCHICAL_STEP,
 }