runnable 0.17.1__py3-none-any.whl → 0.18.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (41) hide show
  1. extensions/README.md +0 -0
  2. extensions/__init__.py +0 -0
  3. extensions/catalog/README.md +0 -0
  4. extensions/catalog/file_system.py +253 -0
  5. extensions/catalog/pyproject.toml +14 -0
  6. extensions/job_executor/README.md +0 -0
  7. extensions/job_executor/__init__.py +160 -0
  8. extensions/job_executor/k8s.py +362 -0
  9. extensions/job_executor/k8s_job_spec.yaml +37 -0
  10. extensions/job_executor/local.py +61 -0
  11. extensions/job_executor/local_container.py +192 -0
  12. extensions/job_executor/pyproject.toml +16 -0
  13. extensions/nodes/README.md +0 -0
  14. extensions/nodes/nodes.py +954 -0
  15. extensions/nodes/pyproject.toml +15 -0
  16. extensions/pipeline_executor/README.md +0 -0
  17. extensions/pipeline_executor/__init__.py +644 -0
  18. extensions/pipeline_executor/argo.py +1307 -0
  19. extensions/pipeline_executor/argo_specification.yaml +51 -0
  20. extensions/pipeline_executor/local.py +62 -0
  21. extensions/pipeline_executor/local_container.py +363 -0
  22. extensions/pipeline_executor/mocked.py +161 -0
  23. extensions/pipeline_executor/pyproject.toml +16 -0
  24. extensions/pipeline_executor/retry.py +180 -0
  25. extensions/run_log_store/README.md +0 -0
  26. extensions/run_log_store/__init__.py +0 -0
  27. extensions/run_log_store/chunked_fs.py +113 -0
  28. extensions/run_log_store/db/implementation_FF.py +163 -0
  29. extensions/run_log_store/db/integration_FF.py +0 -0
  30. extensions/run_log_store/file_system.py +145 -0
  31. extensions/run_log_store/generic_chunked.py +599 -0
  32. extensions/run_log_store/pyproject.toml +15 -0
  33. extensions/secrets/README.md +0 -0
  34. extensions/secrets/dotenv.py +62 -0
  35. extensions/secrets/pyproject.toml +15 -0
  36. {runnable-0.17.1.dist-info → runnable-0.18.0.dist-info}/METADATA +1 -7
  37. runnable-0.18.0.dist-info/RECORD +58 -0
  38. runnable-0.17.1.dist-info/RECORD +0 -23
  39. {runnable-0.17.1.dist-info → runnable-0.18.0.dist-info}/WHEEL +0 -0
  40. {runnable-0.17.1.dist-info → runnable-0.18.0.dist-info}/entry_points.txt +0 -0
  41. {runnable-0.17.1.dist-info → runnable-0.18.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,954 @@
1
+ import importlib
2
+ import logging
3
+ import os
4
+ import sys
5
+ from collections import OrderedDict
6
+ from copy import deepcopy
7
+ from datetime import datetime
8
+ from typing import Annotated, Any, Callable, Dict, List, Optional, Tuple, Union, cast
9
+
10
+ from pydantic import (
11
+ ConfigDict,
12
+ Field,
13
+ ValidationInfo,
14
+ field_serializer,
15
+ field_validator,
16
+ )
17
+
18
+ from runnable import console, datastore, defaults, utils
19
+ from runnable.datastore import (
20
+ JsonParameter,
21
+ MetricParameter,
22
+ ObjectParameter,
23
+ Parameter,
24
+ StepLog,
25
+ )
26
+ from runnable.defaults import TypeMapVariable
27
+ from runnable.graph import Graph, create_graph
28
+ from runnable.nodes import CompositeNode, ExecutableNode, TerminalNode
29
+ from runnable.tasks import BaseTaskType, create_task
30
+
31
+ logger = logging.getLogger(defaults.LOGGER_NAME)
32
+
33
+
34
class TaskNode(ExecutableNode):
    """
    A node of type Task.

    This node does the actual function execution of the graph in all cases.
    """

    executable: BaseTaskType = Field(exclude=True)
    node_type: str = Field(default="task", serialization_alias="type")

    # Extra keys are technically not allowed as parse_from_config filters them.
    # Allowing them keeps the task level configuration present during serialization.
    model_config = ConfigDict(extra="allow")

    @classmethod
    def parse_from_config(cls, config: Dict[str, Any]) -> "TaskNode":
        """Build a TaskNode from raw config, splitting node vs task settings."""
        node_fields = TaskNode.model_fields.keys()

        # Anything that is not a declared node field belongs to the task.
        task_config = {
            key: value for key, value in config.items() if key not in node_fields
        }
        node_config = {
            key: value for key, value in config.items() if key in node_fields
        }

        return cls(
            executable=create_task(task_config), **node_config, **task_config
        )

    def get_summary(self) -> Dict[str, Any]:
        """Return a serializable summary of this node."""
        return {
            "name": self.name,
            "type": self.node_type,
            "executable": self.executable.get_summary(),
            "catalog": self._get_catalog_settings(),
        }

    def execute(
        self,
        mock=False,
        map_variable: TypeMapVariable = None,
        attempt_number: int = 1,
        **kwargs,
    ) -> StepLog:
        """
        All that we do in runnable is to come to this point where we actually execute the command.

        Args:
            mock (bool, optional): If we should just mock and not execute. Defaults to False.
            map_variable (dict, optional): If the node is part of internal branch. Defaults to None.
            attempt_number (int, optional): The attempt number of this execution. Defaults to 1.

        Returns:
            StepLog: The step log with the attempt appended.
        """
        step_log = self._context.run_log_store.get_step_log(
            self._get_step_log_name(map_variable), self._context.run_id
        )

        if mock:
            # Do not run if we are mocking the execution, could be useful for
            # caching and dry runs: fabricate a successful attempt instead.
            attempt_log = datastore.StepAttempt(
                status=defaults.SUCCESS,
                start_time=str(datetime.now()),
                end_time=str(datetime.now()),
                attempt_number=attempt_number,
            )
        else:
            attempt_log = self.executable.execute_command(map_variable=map_variable)
            attempt_log.attempt_number = attempt_number

        logger.info(f"attempt_log: {attempt_log}")
        logger.info(f"Step {self.name} completed with status: {attempt_log.status}")

        step_log.status = attempt_log.status
        step_log.attempts.append(attempt_log)

        return step_log
112
+
113
+
114
class FailNode(TerminalNode):
    """
    A leaf node of the graph that represents a failure node
    """

    node_type: str = Field(default="fail", serialization_alias="type")

    @classmethod
    def parse_from_config(cls, config: Dict[str, Any]) -> "FailNode":
        """Delegate to the terminal-node parser, narrowing the return type."""
        return cast("FailNode", super().parse_from_config(config))

    def get_summary(self) -> Dict[str, Any]:
        """Return a serializable summary of this node."""
        return {
            "name": self.name,
            "type": self.node_type,
        }

    def execute(
        self,
        mock=False,
        map_variable: TypeMapVariable = None,
        attempt_number: int = 1,
        **kwargs,
    ) -> StepLog:
        """
        Execute the failure node.
        Set the run or branch log status to failure.

        Args:
            mock (bool, optional): If we should just mock and not do the actual execution. Defaults to False.
            map_variable (dict, optional): If the node belongs to internal branches. Defaults to None.
            attempt_number (int, optional): The attempt number of this execution. Defaults to 1.

        Returns:
            StepLog: The step log with the attempt appended.
        """
        run_log_store = self._context.run_log_store
        run_id = self._context.run_id

        step_log = run_log_store.get_step_log(
            self._get_step_log_name(map_variable), run_id
        )

        # The step itself always "succeeds"; it is the enclosing run/branch
        # that gets marked as failed below.
        attempt_log = datastore.StepAttempt(
            status=defaults.SUCCESS,
            start_time=str(datetime.now()),
            end_time=str(datetime.now()),
            attempt_number=attempt_number,
        )

        run_or_branch_log = run_log_store.get_branch_log(
            self._get_branch_log_name(map_variable), run_id
        )
        run_or_branch_log.status = defaults.FAIL
        run_log_store.add_branch_log(run_or_branch_log, run_id)

        step_log.status = attempt_log.status
        step_log.attempts.append(attempt_log)

        return step_log
176
+
177
+
178
class SuccessNode(TerminalNode):
    """
    A leaf node of the graph that represents a success node
    """

    node_type: str = Field(default="success", serialization_alias="type")

    @classmethod
    def parse_from_config(cls, config: Dict[str, Any]) -> "SuccessNode":
        """Delegate to the terminal-node parser, narrowing the return type."""
        return cast("SuccessNode", super().parse_from_config(config))

    def get_summary(self) -> Dict[str, Any]:
        """Return a serializable summary of this node."""
        return {
            "name": self.name,
            "type": self.node_type,
        }

    def execute(
        self,
        mock=False,
        map_variable: TypeMapVariable = None,
        attempt_number: int = 1,
        **kwargs,
    ) -> StepLog:
        """
        Execute the success node.
        Set the run or branch log status to success.

        Args:
            mock (bool, optional): If we should just mock and not perform anything. Defaults to False.
            map_variable (dict, optional): If the node belongs to an internal branch. Defaults to None.
            attempt_number (int, optional): The attempt number of this execution. Defaults to 1.

        Returns:
            StepLog: The step log with the attempt appended.
        """
        run_log_store = self._context.run_log_store
        run_id = self._context.run_id

        step_log = run_log_store.get_step_log(
            self._get_step_log_name(map_variable), run_id
        )

        attempt_log = datastore.StepAttempt(
            status=defaults.SUCCESS,
            start_time=str(datetime.now()),
            end_time=str(datetime.now()),
            attempt_number=attempt_number,
        )

        # Reaching this node marks the enclosing run/branch as successful.
        run_or_branch_log = run_log_store.get_branch_log(
            self._get_branch_log_name(map_variable), run_id
        )
        run_or_branch_log.status = defaults.SUCCESS
        run_log_store.add_branch_log(run_or_branch_log, run_id)

        step_log.status = attempt_log.status
        step_log.attempts.append(attempt_log)

        return step_log
240
+
241
+
242
class ParallelNode(CompositeNode):
    """
    A composite node containing many graph objects within itself.

    The structure is generally:
        ParallelNode:
            Branch A:
                Sub graph definition
            Branch B:
                Sub graph definition
            . . .

    """

    node_type: str = Field(default="parallel", serialization_alias="type")
    branches: Dict[str, Graph]
    is_composite: bool = Field(default=True, exclude=True)

    def get_summary(self) -> Dict[str, Any]:
        """Return a serializable summary of this node and its branches."""
        return {
            "name": self.name,
            "type": self.node_type,
            "branches": [branch.get_summary() for branch in self.branches.values()],
        }

    @field_serializer("branches")
    def ser_branches(self, branches: Dict[str, Graph]) -> Dict[str, Graph]:
        # Serialize using only the leaf segment of the dotted internal name.
        return {name.split(".")[-1]: graph for name, graph in branches.items()}

    @classmethod
    def parse_from_config(cls, config: Dict[str, Any]) -> "ParallelNode":
        """Build a ParallelNode, creating one sub-graph per configured branch."""
        internal_name = cast(str, config.get("internal_name"))

        branches: Dict[str, Graph] = {}
        for branch_name, branch_config in config.pop("branches", {}).items():
            qualified_name = internal_name + "." + branch_name
            branches[qualified_name] = create_graph(
                deepcopy(branch_config),
                internal_branch_name=qualified_name,
            )

        if not branches:
            raise Exception("A parallel node should have branches")
        return cls(branches=branches, **config)

    def _get_branch_by_name(self, branch_name: str) -> Graph:
        """Return the branch graph registered under branch_name or raise."""
        if branch_name not in self.branches:
            raise Exception(f"Branch {branch_name} does not exist")

        return self.branches[branch_name]

    def fan_out(self, map_variable: TypeMapVariable = None, **kwargs):
        """
        The general fan out method for a node of type Parallel.
        This method assumes that the step log has already been created.

        3rd party orchestrators should create the step log and use this method to create the branch logs.

        Args:
            map_variable (dict, optional): If the node is part of a map node. Defaults to None.
        """
        # Prepare one branch log per branch, marked as processing.
        for internal_branch_name in self.branches:
            effective_branch_name = self._resolve_map_placeholders(
                internal_branch_name, map_variable=map_variable
            )

            branch_log = self._context.run_log_store.create_branch_log(
                effective_branch_name
            )
            branch_log.status = defaults.PROCESSING
            self._context.run_log_store.add_branch_log(branch_log, self._context.run_id)

    def execute_as_graph(self, map_variable: TypeMapVariable = None, **kwargs):
        """
        This function does the actual execution of the sub-branches of the parallel node.

        From a design perspective, this function should not be called if the execution is 3rd party orchestrated.

        The modes that render the job specifications, do not need to interact with this node at all as they have their
        own internal mechanisms of handing parallel states.
        If they do not, you can find a way using as-is nodes as hack nodes.

        The execution of a dag, could result in
            * The dag being completely executed with a definite (fail, success) state in case of
                local or local-container execution
            * The dag being in a processing state with PROCESSING status in case of local-aws-batch

        Only fail state is considered failure during this phase of execution.

        Args:
            **kwargs: Optional kwargs passed around
        """
        self.fan_out(map_variable=map_variable, **kwargs)

        for branch in self.branches.values():
            self._context.executor.execute_graph(
                branch, map_variable=map_variable, **kwargs
            )

        self.fan_in(map_variable=map_variable, **kwargs)

    def fan_in(self, map_variable: TypeMapVariable = None, **kwargs):
        """
        The general fan in method for a node of type Parallel.

        3rd party orchestrators should use this method to find the status of the composite step.

        Args:
            map_variable (dict, optional): If the node is part of a map. Defaults to None.
        """
        effective_internal_name = self._resolve_map_placeholders(
            self.internal_name, map_variable=map_variable
        )

        # The step succeeds only if every branch log reports SUCCESS.
        all_succeeded = True
        for internal_branch_name in self.branches:
            effective_branch_name = self._resolve_map_placeholders(
                internal_branch_name, map_variable=map_variable
            )
            branch_log = self._context.run_log_store.get_branch_log(
                effective_branch_name, self._context.run_id
            )

            if branch_log.status != defaults.SUCCESS:
                all_succeeded = False

        # Collate all the results and update the status of the step
        step_log = self._context.run_log_store.get_step_log(
            effective_internal_name, self._context.run_id
        )
        step_log.status = defaults.SUCCESS if all_succeeded else defaults.FAIL

        self._context.run_log_store.add_step_log(step_log, self._context.run_id)
391
+
392
+
393
class MapNode(CompositeNode):
    """
    A composite node that contains ONE graph object within itself that has to be executed with an iterable.

    The structure is generally:
        MapNode:
            branch

        The config is expected to have a variable 'iterate_on' and iterate_as which are looked for in the parameters.
        for iter_variable in parameters['iterate_on']:
            Execute the Branch by sending {'iterate_as': iter_variable}

    The internal naming convention creates branches dynamically based on the iteration value
    """

    # TODO: Should it be one function or a dict of functions indexed by the return name

    node_type: str = Field(default="map", serialization_alias="type")
    iterate_on: str
    iterate_as: str
    reducer: Optional[str] = Field(default=None)
    branch: Graph
    is_composite: bool = True

    def get_summary(self) -> Dict[str, Any]:
        """Return a serializable summary of this node and its branch."""
        summary = {
            "name": self.name,
            "type": self.node_type,
            "branch": self.branch.get_summary(),
            "iterate_on": self.iterate_on,
            "iterate_as": self.iterate_as,
            "reducer": self.reducer,
        }

        return summary

    def get_reducer_function(self):
        """
        Return the callable used to reduce the branch returns.

        Defaults to collecting all values into a list. A configured reducer
        may be either a lambda source string or a dotted path to an importable
        callable (e.g. "my_module.my_reducer").
        """
        if not self.reducer:
            return lambda *x: list(x)  # returns a list of the args

        # Try to interpret the reducer as a lambda source string first.
        # SECURITY NOTE: eval of a pipeline-config supplied string; the
        # reducer must come from a trusted pipeline definition.
        try:
            f = eval(self.reducer)
            if callable(f):
                return f
        except (SyntaxError, NameError):
            # A dotted path such as "module.func" evaluates to a NameError
            # (not a SyntaxError) since "module" is not in scope here.
            # Catching only SyntaxError would crash dotted-path reducers
            # before ever reaching the import fallback below.
            logger.info(f"{self.reducer} is not a lambda function")

        # Load the reducer function from the dotted path.
        mod, func = utils.get_module_and_attr_names(self.reducer)
        sys.path.insert(0, os.getcwd())  # Need to add the current directory to path
        imported_module = importlib.import_module(mod)
        f = getattr(imported_module, func)

        return f

    @classmethod
    def parse_from_config(cls, config: Dict[str, Any]) -> "MapNode":
        """Build a MapNode, creating the single templated branch graph."""
        internal_name = cast(str, config.get("internal_name"))

        config_branch = config.pop("branch", {})
        if not config_branch:
            raise Exception("A map node should have a branch")

        branch = create_graph(
            deepcopy(config_branch),
            internal_branch_name=internal_name + "." + defaults.MAP_PLACEHOLDER,
        )
        return cls(branch=branch, **config)

    @property
    def branch_returns(self):
        """Collect (name, placeholder parameter) pairs for every task return in the branch."""
        branch_returns: List[
            Tuple[str, Union[ObjectParameter, MetricParameter, JsonParameter]]
        ] = []
        for _, node in self.branch.nodes.items():
            if isinstance(node, TaskNode):
                for task_return in node.executable.returns:
                    if task_return.kind == "json":
                        branch_returns.append(
                            (
                                task_return.name,
                                JsonParameter(kind="json", value="", reduced=False),
                            )
                        )
                    elif task_return.kind == "object":
                        branch_returns.append(
                            (
                                task_return.name,
                                ObjectParameter(
                                    kind="object",
                                    value="Will be reduced",
                                    reduced=False,
                                ),
                            )
                        )
                    elif task_return.kind == "metric":
                        branch_returns.append(
                            (
                                task_return.name,
                                MetricParameter(kind="metric", value="", reduced=False),
                            )
                        )
                    else:
                        raise Exception("kind should be either json or object")

        return branch_returns

    def _get_branch_by_name(self, branch_name: str) -> Graph:
        """
        Retrieve a branch by name.

        In the case of a Map Object, the branch naming is dynamic as it is parameterized on iterable.
        This method takes no responsibility in checking the validity of the naming.

        Returns a Graph Object

        Args:
            branch_name (str): The name of the branch to retrieve
        """
        return self.branch

    def fan_out(self, map_variable: TypeMapVariable = None, **kwargs):
        """
        The general method to fan out for a node of type map.
        This method assumes that the step log has already been created.

        3rd party orchestrators should call this method to create the individual branch logs.

        Args:
            map_variable (dict, optional): If the node is part of map. Defaults to None.
        """
        iterate_on = self._context.run_log_store.get_parameters(self._context.run_id)[
            self.iterate_on
        ].get_value()

        # Prepare the branch logs
        for iter_variable in iterate_on:
            effective_branch_name = self._resolve_map_placeholders(
                self.internal_name + "." + str(iter_variable), map_variable=map_variable
            )
            branch_log = self._context.run_log_store.create_branch_log(
                effective_branch_name
            )

            console.print(
                f"Branch log created for {effective_branch_name}: {branch_log}"
            )
            branch_log.status = defaults.PROCESSING
            self._context.run_log_store.add_branch_log(branch_log, self._context.run_id)

        # Gather all the returns of the task nodes and create parameters in reduced=False state.
        # TODO: Why are we preemptively creating the parameters?
        raw_parameters = {}
        if map_variable:
            # If we are in a map state already, the param should have an index of the map variable.
            for _, v in map_variable.items():
                for branch_return in self.branch_returns:
                    param_name, param_type = branch_return
                    raw_parameters[f"{v}_{param_name}"] = param_type.copy()
        else:
            for branch_return in self.branch_returns:
                param_name, param_type = branch_return
                raw_parameters[f"{param_name}"] = param_type.copy()

        self._context.run_log_store.set_parameters(
            parameters=raw_parameters, run_id=self._context.run_id
        )

    def execute_as_graph(self, map_variable: TypeMapVariable = None, **kwargs):
        """
        This function does the actual execution of the branch of the map node.

        From a design perspective, this function should not be called if the execution is 3rd party orchestrated.

        The modes that render the job specifications, do not need to interact with this node at all as
        they have their own internal mechanisms of handing map states or dynamic parallel states.
        If they do not, you can find a way using as-is nodes as hack nodes.

        The actual logic is :
            * We iterate over the iterable as mentioned in the config
            * For every value in the iterable we call the executor.execute_graph(branch, iterate_as: iter_variable)

        The execution of a dag, could result in
            * The dag being completely executed with a definite (fail, success) state in case of local
                or local-container execution
            * The dag being in a processing state with PROCESSING status in case of local-aws-batch

        Only fail state is considered failure during this phase of execution.

        Args:
            map_variable (dict): The map variables the graph belongs to
            **kwargs: Optional kwargs passed around
        """

        iterate_on = None
        try:
            iterate_on = self._context.run_log_store.get_parameters(
                self._context.run_id
            )[self.iterate_on].get_value()
        except KeyError as e:
            # NOTE: the message used to be a tuple of strings, which rendered
            # as a tuple; join it into a single readable sentence instead.
            raise Exception(
                f"Expected parameter {self.iterate_on} "
                "not present in Run Log parameters, "
                "was it ever set before?"
            ) from e

        if not isinstance(iterate_on, list):
            raise Exception("Only list is allowed as a valid iterator type")

        self.fan_out(map_variable=map_variable, **kwargs)

        for iter_variable in iterate_on:
            effective_map_variable = map_variable or OrderedDict()
            effective_map_variable[self.iterate_as] = iter_variable

            self._context.executor.execute_graph(
                self.branch, map_variable=effective_map_variable, **kwargs
            )

        self.fan_in(map_variable=map_variable, **kwargs)

    def fan_in(self, map_variable: TypeMapVariable = None, **kwargs):
        """
        The general method to fan in for a node of type map.

        3rd party orchestrators should call this method to find the status of the step log.

        Args:
            map_variable (dict, optional): If the node is part of map node. Defaults to None.
        """
        params = self._context.run_log_store.get_parameters(self._context.run_id)
        iterate_on = params[self.iterate_on].get_value()
        # Find status of the branches
        step_success_bool = True
        effective_internal_name = self._resolve_map_placeholders(
            self.internal_name, map_variable=map_variable
        )

        for iter_variable in iterate_on:
            effective_branch_name = self._resolve_map_placeholders(
                self.internal_name + "." + str(iter_variable), map_variable=map_variable
            )
            branch_log = self._context.run_log_store.get_branch_log(
                effective_branch_name, self._context.run_id
            )

            if branch_log.status != defaults.SUCCESS:
                step_success_bool = False

        # Collate all the results and update the status of the step
        step_log = self._context.run_log_store.get_step_log(
            effective_internal_name, self._context.run_id
        )

        if step_success_bool:  #  If none failed and nothing is waiting
            step_log.status = defaults.SUCCESS
        else:
            step_log.status = defaults.FAIL

        self._context.run_log_store.add_step_log(step_log, self._context.run_id)

        # Apply the reduce function and reduce the returns of the task nodes.
        # The final value of the parameter is the result of the reduce function.
        reducer_f = self.get_reducer_function()

        def update_param(
            params: Dict[str, Parameter], reducer_f: Callable, map_prefix: str = ""
        ):
            # Reduce every branch return across all iterations into one value.
            for branch_return in self.branch_returns:
                param_name, _ = branch_return

                to_reduce = []
                for iter_variable in iterate_on:
                    try:
                        to_reduce.append(
                            params[f"{iter_variable}_{param_name}"].get_value()
                        )
                    except KeyError as e:
                        from extensions.pipeline_executor.mocked import MockedExecutor

                        # Mocked executions may legitimately skip branches,
                        # leaving their parameters unset.
                        if isinstance(self._context.executor, MockedExecutor):
                            pass
                        else:
                            # NOTE: the message used to be a tuple of strings,
                            # which rendered as a tuple; join it instead.
                            raise Exception(
                                f"Expected parameter {iter_variable}_{param_name} "
                                "not present in Run Log parameters, "
                                "was it ever set before?"
                            ) from e

                param_name = f"{map_prefix}{param_name}"
                if to_reduce:
                    params[param_name].value = reducer_f(*to_reduce)
                else:
                    params[param_name].value = ""
                params[param_name].reduced = True

        if map_variable:
            # If we are in a map state already, the param should have an index of the map variable.
            for _, v in map_variable.items():
                update_param(params, reducer_f, map_prefix=f"{v}_")
        else:
            update_param(params, reducer_f)

        self._context.run_log_store.set_parameters(
            parameters=params, run_id=self._context.run_id
        )
711
+
712
+
713
class DagNode(CompositeNode):
    """
    A composite node that internally holds a dag.

    The structure is generally:
        DagNode:
            dag_definition: A YAML file that holds the dag in 'dag' block

        The config is expected to have a variable 'dag_definition'.
    """

    node_type: str = Field(default="dag", serialization_alias="type")
    dag_definition: str
    branch: Graph
    is_composite: bool = True
    internal_branch_name: Annotated[str, Field(validate_default=True)] = ""

    def get_summary(self) -> Dict[str, Any]:
        """Return a serializable summary of this node."""
        summary = {
            "name": self.name,
            "type": self.node_type,
        }
        return summary

    @field_validator("internal_branch_name")
    @classmethod
    def validate_internal_branch_name(
        cls, internal_branch_name: str, info: ValidationInfo
    ):
        # The branch name is always derived from the node's internal name;
        # any supplied value is intentionally ignored.
        internal_name = info.data["internal_name"]
        return internal_name + "." + defaults.DAG_BRANCH_NAME

    @field_validator("dag_definition")
    @classmethod
    def validate_dag_definition(cls, value):
        if not value.endswith(".yaml"):  # TODO: Might have a problem with the SDK
            raise ValueError("dag_definition must be a YAML file")
        return value

    @classmethod
    def parse_from_config(cls, config: Dict[str, Any]) -> "DagNode":
        """Build a DagNode by loading the sub-dag from its YAML definition.

        Raises:
            Exception: If 'dag_definition' is missing from the config, or the
                YAML file does not contain a 'dag' block.
        """
        internal_name = cast(str, config.get("internal_name"))

        if "dag_definition" not in config:
            raise Exception(f"No dag definition found in {config}")

        dag_config = utils.load_yaml(config["dag_definition"])
        if "dag" not in dag_config:
            raise Exception(
                "No DAG found in dag_definition, please provide it in dag block"
            )

        branch = create_graph(
            dag_config["dag"],
            internal_branch_name=internal_name + "." + defaults.DAG_BRANCH_NAME,
        )

        return cls(branch=branch, **config)

    def _get_branch_by_name(self, branch_name: str):
        """
        Retrieve a branch by name.
        The name is expected to follow a dot path convention.

        Returns a Graph Object

        Args:
            branch_name (str): The name of the branch to retrieve

        Raises:
            Exception: If the branch_name is not 'dag'
        """
        if branch_name != self.internal_branch_name:
            raise Exception(
                f"Node of type {self.node_type} only allows a branch of name {defaults.DAG_BRANCH_NAME}"
            )

        return self.branch

    def fan_out(self, map_variable: TypeMapVariable = None, **kwargs):
        """
        The general method to fan out for a node of type dag.
        The method assumes that the step log has already been created.

        Args:
            map_variable (dict, optional): If the node is part of a map. Defaults to None.
        """
        effective_branch_name = self._resolve_map_placeholders(
            self.internal_branch_name, map_variable=map_variable
        )

        branch_log = self._context.run_log_store.create_branch_log(
            effective_branch_name
        )
        branch_log.status = defaults.PROCESSING
        self._context.run_log_store.add_branch_log(branch_log, self._context.run_id)

    def execute_as_graph(self, map_variable: TypeMapVariable = None, **kwargs):
        """
        This function does the actual execution of the branch of the dag node.

        From a design perspective, this function should not be called if the execution is 3rd party orchestrated.

        The modes that render the job specifications, do not need to interact with this node at all
        as they have their own internal mechanisms of handling sub dags.
        If they do not, you can find a way using as-is nodes as hack nodes.

        The actual logic is :
            * We just execute the branch as with any other composite nodes
            * The branch name is called 'dag'

        The execution of a dag, could result in
            * The dag being completely executed with a definite (fail, success) state in case of
                local or local-container execution
            * The dag being in a processing state with PROCESSING status in case of local-aws-batch

        Only fail state is considered failure during this phase of execution.

        Args:
            **kwargs: Optional kwargs passed around
        """
        self.fan_out(map_variable=map_variable, **kwargs)
        self._context.executor.execute_graph(
            self.branch, map_variable=map_variable, **kwargs
        )
        self.fan_in(map_variable=map_variable, **kwargs)

    def fan_in(self, map_variable: TypeMapVariable = None, **kwargs):
        """
        The general method to fan in for a node of type dag.

        3rd party orchestrators should call this method to find the status of the step log.

        Args:
            map_variable (dict, optional): If the node is part of type dag. Defaults to None.
        """
        step_success_bool = True
        effective_branch_name = self._resolve_map_placeholders(
            self.internal_branch_name, map_variable=map_variable
        )
        effective_internal_name = self._resolve_map_placeholders(
            self.internal_name, map_variable=map_variable
        )

        branch_log = self._context.run_log_store.get_branch_log(
            effective_branch_name, self._context.run_id
        )
        if branch_log.status != defaults.SUCCESS:
            step_success_bool = False

        step_log = self._context.run_log_store.get_step_log(
            effective_internal_name, self._context.run_id
        )
        # The original code assigned defaults.PROCESSING here, but that value
        # was unconditionally overwritten by the if/else below — removed.

        if step_success_bool:  #  If none failed and nothing is waiting
            step_log.status = defaults.SUCCESS
        else:
            step_log.status = defaults.FAIL

        self._context.run_log_store.add_step_log(step_log, self._context.run_id)
877
+
878
+
879
class StubNode(ExecutableNode):
    """
    Stub is a convenience design node.
    It always returns success in the attempt log and does nothing.

    This node is very similar to pass state in Step functions.

    This node type could be handy when designing the pipeline and stubbing functions
    --8<-- [start:stub_reference]
    An stub execution node of the pipeline.
    Please refer to define pipeline/tasks/stub for more information.

    As part of the dag definition, a stub task is defined as follows:

    dag:
      steps:
        stub_task: # The name of the node
        type: stub
        on_failure: The name of the step to traverse in case of failure
        next: The next node to execute after this task, use "success" to terminate the pipeline successfully
          or "fail" to terminate the pipeline with an error.

    It can take arbitrary number of parameters, which is handy to temporarily silence a task node.
    --8<-- [end:stub_reference]
    """

    node_type: str = Field(default="stub", serialization_alias="type")
    # Arbitrary extra keys are tolerated (and dropped) so a task node can be
    # stubbed without stripping its configuration.
    model_config = ConfigDict(extra="ignore")

    def get_summary(self) -> Dict[str, Any]:
        """Return a serializable summary of this node."""
        return {
            "name": self.name,
            "type": self.node_type,
        }

    @classmethod
    def parse_from_config(cls, config: Dict[str, Any]) -> "StubNode":
        return cls(**config)

    def execute(
        self,
        mock=False,
        map_variable: TypeMapVariable = None,
        attempt_number: int = 1,
        **kwargs,
    ) -> StepLog:
        """
        Do Nothing node.
        We just send a success attempt log back to the caller.

        Args:
            mock (bool, optional): Ignored; a stub never executes anything. Defaults to False.
            map_variable (dict, optional): If the node is part of an internal branch. Defaults to None.
            attempt_number (int, optional): The attempt number of this execution. Defaults to 1.

        Returns:
            StepLog: The step log with a fabricated successful attempt appended.
        """
        run_log_store = self._context.run_log_store

        step_log = run_log_store.get_step_log(
            self._get_step_log_name(map_variable), self._context.run_id
        )

        attempt_log = datastore.StepAttempt(
            status=defaults.SUCCESS,
            start_time=str(datetime.now()),
            end_time=str(datetime.now()),
            attempt_number=attempt_number,
        )

        step_log.status = attempt_log.status
        step_log.attempts.append(attempt_log)

        return step_log