runnable 0.1.0-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. runnable/__init__.py +34 -0
  2. runnable/catalog.py +141 -0
  3. runnable/cli.py +272 -0
  4. runnable/context.py +34 -0
  5. runnable/datastore.py +686 -0
  6. runnable/defaults.py +179 -0
  7. runnable/entrypoints.py +484 -0
  8. runnable/exceptions.py +94 -0
  9. runnable/executor.py +431 -0
  10. runnable/experiment_tracker.py +139 -0
  11. runnable/extensions/catalog/__init__.py +21 -0
  12. runnable/extensions/catalog/file_system/__init__.py +0 -0
  13. runnable/extensions/catalog/file_system/implementation.py +226 -0
  14. runnable/extensions/catalog/k8s_pvc/__init__.py +0 -0
  15. runnable/extensions/catalog/k8s_pvc/implementation.py +16 -0
  16. runnable/extensions/catalog/k8s_pvc/integration.py +59 -0
  17. runnable/extensions/executor/__init__.py +714 -0
  18. runnable/extensions/executor/argo/__init__.py +0 -0
  19. runnable/extensions/executor/argo/implementation.py +1182 -0
  20. runnable/extensions/executor/argo/specification.yaml +51 -0
  21. runnable/extensions/executor/k8s_job/__init__.py +0 -0
  22. runnable/extensions/executor/k8s_job/implementation_FF.py +259 -0
  23. runnable/extensions/executor/k8s_job/integration_FF.py +69 -0
  24. runnable/extensions/executor/local/__init__.py +0 -0
  25. runnable/extensions/executor/local/implementation.py +69 -0
  26. runnable/extensions/executor/local_container/__init__.py +0 -0
  27. runnable/extensions/executor/local_container/implementation.py +367 -0
  28. runnable/extensions/executor/mocked/__init__.py +0 -0
  29. runnable/extensions/executor/mocked/implementation.py +220 -0
  30. runnable/extensions/experiment_tracker/__init__.py +0 -0
  31. runnable/extensions/experiment_tracker/mlflow/__init__.py +0 -0
  32. runnable/extensions/experiment_tracker/mlflow/implementation.py +94 -0
  33. runnable/extensions/nodes.py +675 -0
  34. runnable/extensions/run_log_store/__init__.py +0 -0
  35. runnable/extensions/run_log_store/chunked_file_system/__init__.py +0 -0
  36. runnable/extensions/run_log_store/chunked_file_system/implementation.py +106 -0
  37. runnable/extensions/run_log_store/chunked_k8s_pvc/__init__.py +0 -0
  38. runnable/extensions/run_log_store/chunked_k8s_pvc/implementation.py +21 -0
  39. runnable/extensions/run_log_store/chunked_k8s_pvc/integration.py +61 -0
  40. runnable/extensions/run_log_store/db/implementation_FF.py +157 -0
  41. runnable/extensions/run_log_store/db/integration_FF.py +0 -0
  42. runnable/extensions/run_log_store/file_system/__init__.py +0 -0
  43. runnable/extensions/run_log_store/file_system/implementation.py +136 -0
  44. runnable/extensions/run_log_store/generic_chunked.py +541 -0
  45. runnable/extensions/run_log_store/k8s_pvc/__init__.py +0 -0
  46. runnable/extensions/run_log_store/k8s_pvc/implementation.py +21 -0
  47. runnable/extensions/run_log_store/k8s_pvc/integration.py +56 -0
  48. runnable/extensions/secrets/__init__.py +0 -0
  49. runnable/extensions/secrets/dotenv/__init__.py +0 -0
  50. runnable/extensions/secrets/dotenv/implementation.py +100 -0
  51. runnable/extensions/secrets/env_secrets/__init__.py +0 -0
  52. runnable/extensions/secrets/env_secrets/implementation.py +42 -0
  53. runnable/graph.py +464 -0
  54. runnable/integration.py +205 -0
  55. runnable/interaction.py +399 -0
  56. runnable/names.py +546 -0
  57. runnable/nodes.py +489 -0
  58. runnable/parameters.py +183 -0
  59. runnable/pickler.py +102 -0
  60. runnable/sdk.py +470 -0
  61. runnable/secrets.py +95 -0
  62. runnable/tasks.py +392 -0
  63. runnable/utils.py +630 -0
  64. runnable-0.2.0.dist-info/METADATA +437 -0
  65. runnable-0.2.0.dist-info/RECORD +69 -0
  66. runnable-0.2.0.dist-info/entry_points.txt +44 -0
  67. runnable-0.1.0.dist-info/METADATA +0 -16
  68. runnable-0.1.0.dist-info/RECORD +0 -6
  69. /runnable/{.gitkeep → extensions/__init__.py} +0 -0
  70. {runnable-0.1.0.dist-info → runnable-0.2.0.dist-info}/LICENSE +0 -0
  71. {runnable-0.1.0.dist-info → runnable-0.2.0.dist-info}/WHEEL +0 -0
runnable/extensions/nodes.py
@@ -0,0 +1,675 @@
+ import json
+ import logging
+ import multiprocessing
+ from collections import OrderedDict
+ from copy import deepcopy
+ from datetime import datetime
+ from typing import Any, Dict, cast
+
+ from pydantic import ConfigDict, Field, ValidationInfo, field_serializer, field_validator
+ from typing_extensions import Annotated
+
+ from runnable import defaults, utils
+ from runnable.datastore import StepAttempt
+ from runnable.defaults import TypeMapVariable
+ from runnable.graph import Graph, create_graph
+ from runnable.nodes import CompositeNode, ExecutableNode, TerminalNode
+ from runnable.tasks import BaseTaskType, create_task
+
+ logger = logging.getLogger(defaults.LOGGER_NAME)
+
+
+ class TaskNode(ExecutableNode):
+     """
+     A node of type Task.
+
+     This node does the actual function execution of the graph in all cases.
+     """
+
+     executable: BaseTaskType = Field(exclude=True)
+     node_type: str = Field(default="task", serialization_alias="type")
+
+     # Extra fields are technically not allowed, as parse_from_config filters them out.
+     # This only makes the task-level configuration available during serialization.
+     model_config = ConfigDict(extra="allow")
+
+     @classmethod
+     def parse_from_config(cls, config: Dict[str, Any]) -> "TaskNode":
+         # Separate the task config from the node config
+         task_config = {k: v for k, v in config.items() if k not in TaskNode.model_fields.keys()}
+         node_config = {k: v for k, v in config.items() if k in TaskNode.model_fields.keys()}
+
+         task_config["node_name"] = config.get("name")
+
+         executable = create_task(task_config)
+         return cls(executable=executable, **node_config, **task_config)
+
+     def execute(self, mock=False, map_variable: TypeMapVariable = None, **kwargs) -> StepAttempt:
+         """
+         All execution in runnable eventually reaches this point, where the command is actually run.
+
+         Args:
+             mock (bool, optional): If we should just mock and not execute. Defaults to False.
+             map_variable (dict, optional): If the node is part of an internal branch. Defaults to None.
+
+         Returns:
+             StepAttempt: The attempt object
+         """
+         print("Executing task:", self._context.executor._context_node)
+         # Here is where the juice is
+         attempt_log = self._context.run_log_store.create_attempt_log()
+         try:
+             attempt_log.start_time = str(datetime.now())
+             attempt_log.status = defaults.SUCCESS
+             if not mock:
+                 # Do not run if we are mocking the execution; useful for caching and dry runs
+                 self.executable.execute_command(map_variable=map_variable)
+         except Exception as _e:  # pylint: disable=W0703
+             logger.exception("Task failed")
+             attempt_log.status = defaults.FAIL
+             attempt_log.message = str(_e)
+         finally:
+             attempt_log.end_time = str(datetime.now())
+             attempt_log.duration = utils.get_duration_between_datetime_strings(
+                 attempt_log.start_time, attempt_log.end_time
+             )
+         return attempt_log
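
The parse_from_config method above partitions one flat config dict by the node's pydantic model fields: recognized keys configure the node itself, and everything else is handed to create_task. A minimal self-contained sketch of that pattern (the field and key names here are illustrative assumptions, not runnable's actual schema):

# Sketch of the config partition done by TaskNode.parse_from_config.
# "node_fields" stands in for TaskNode.model_fields; key names are assumed.
node_fields = {"name", "internal_name", "next", "node_type"}
config = {"name": "train", "next": "success", "command": "my_module.train"}

task_config = {k: v for k, v in config.items() if k not in node_fields}
node_config = {k: v for k, v in config.items() if k in node_fields}

print(task_config)  # {'command': 'my_module.train'} -> handed to create_task
print(node_config)  # {'name': 'train', 'next': 'success'} -> node fields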
+
+
+ class FailNode(TerminalNode):
+     """
+     A leaf node of the graph that represents a failure node.
+     """
+
+     node_type: str = Field(default="fail", serialization_alias="type")
+
+     @classmethod
+     def parse_from_config(cls, config: Dict[str, Any]) -> "FailNode":
+         return cast("FailNode", super().parse_from_config(config))
+
+     def execute(self, mock=False, map_variable: TypeMapVariable = None, **kwargs) -> StepAttempt:
+         """
+         Execute the failure node.
+         Set the run or branch log status to failure.
+
+         Args:
+             mock (bool, optional): If we should just mock and not do the actual execution. Defaults to False.
+             map_variable (dict, optional): If the node belongs to internal branches. Defaults to None.
+
+         Returns:
+             StepAttempt: The step attempt object
+         """
+         attempt_log = self._context.run_log_store.create_attempt_log()
+         try:
+             attempt_log.start_time = str(datetime.now())
+             attempt_log.status = defaults.SUCCESS
+             # Could be a branch log or a run log
+             run_or_branch_log = self._context.run_log_store.get_branch_log(
+                 self._get_branch_log_name(map_variable), self._context.run_id
+             )
+             run_or_branch_log.status = defaults.FAIL
+             self._context.run_log_store.add_branch_log(run_or_branch_log, self._context.run_id)
+         except BaseException:  # pylint: disable=W0703
+             logger.exception("Fail node execution failed")
+         finally:
+             attempt_log.status = defaults.SUCCESS  # This is a dummy node, so we ignore errors and mark SUCCESS
+             attempt_log.end_time = str(datetime.now())
+             attempt_log.duration = utils.get_duration_between_datetime_strings(
+                 attempt_log.start_time, attempt_log.end_time
+             )
+         return attempt_log
+
+
+ class SuccessNode(TerminalNode):
+     """
+     A leaf node of the graph that represents a success node.
+     """
+
+     node_type: str = Field(default="success", serialization_alias="type")
+
+     @classmethod
+     def parse_from_config(cls, config: Dict[str, Any]) -> "SuccessNode":
+         return cast("SuccessNode", super().parse_from_config(config))
+
+     def execute(self, mock=False, map_variable: TypeMapVariable = None, **kwargs) -> StepAttempt:
+         """
+         Execute the success node.
+         Set the run or branch log status to success.
+
+         Args:
+             mock (bool, optional): If we should just mock and not perform anything. Defaults to False.
+             map_variable (dict, optional): If the node belongs to an internal branch. Defaults to None.
+
+         Returns:
+             StepAttempt: The step attempt object
+         """
+         attempt_log = self._context.run_log_store.create_attempt_log()
+         try:
+             attempt_log.start_time = str(datetime.now())
+             attempt_log.status = defaults.SUCCESS
+             # Could be a branch log or a run log
+             run_or_branch_log = self._context.run_log_store.get_branch_log(
+                 self._get_branch_log_name(map_variable), self._context.run_id
+             )
+             run_or_branch_log.status = defaults.SUCCESS
+             self._context.run_log_store.add_branch_log(run_or_branch_log, self._context.run_id)
+         except BaseException:  # pylint: disable=W0703
+             logger.exception("Success node execution failed")
+         finally:
+             attempt_log.status = defaults.SUCCESS  # This is a dummy node and we make sure we mark it as success
+             attempt_log.end_time = str(datetime.now())
+             attempt_log.duration = utils.get_duration_between_datetime_strings(
+                 attempt_log.start_time, attempt_log.end_time
+             )
+         return attempt_log
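
FailNode and SuccessNode share the same attempt-log bookkeeping as TaskNode: stamp a start time, do the work, and always stamp the end time and duration in the finally block, so even a failing body produces a complete attempt log. Reduced to plain Python (a sketch; the inline duration arithmetic stands in for utils.get_duration_between_datetime_strings):

from datetime import datetime

start_time = str(datetime.now())
status = "SUCCESS"
try:
    pass  # the node's actual work goes here
except BaseException:
    status = "FAIL"
finally:
    end_time = str(datetime.now())
    # str(datetime.now()) round-trips through fromisoformat, so the
    # duration can be recovered from the stored strings
    duration = str(datetime.fromisoformat(end_time) - datetime.fromisoformat(start_time))
print(status, duration)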
+
+
+ class ParallelNode(CompositeNode):
+     """
+     A composite node containing many graph objects within itself.
+
+     The structure is generally:
+         ParallelNode:
+             Branch A:
+                 Sub graph definition
+             Branch B:
+                 Sub graph definition
+             . . .
+     """
+
+     node_type: str = Field(default="parallel", serialization_alias="type")
+     branches: Dict[str, Graph]
+     is_composite: bool = Field(default=True, exclude=True)
+
+     @field_serializer("branches")
+     def ser_branches(self, branches: Dict[str, Graph]) -> Dict[str, Graph]:
+         ret: Dict[str, Graph] = {}
+
+         for branch_name, branch in branches.items():
+             ret[branch_name.split(".")[-1]] = branch
+
+         return ret
+
+     @classmethod
+     def parse_from_config(cls, config: Dict[str, Any]) -> "ParallelNode":
+         internal_name = cast(str, config.get("internal_name"))
+
+         config_branches = config.pop("branches", {})
+         branches = {}
+         for branch_name, branch_config in config_branches.items():
+             sub_graph = create_graph(
+                 deepcopy(branch_config),
+                 internal_branch_name=internal_name + "." + branch_name,
+             )
+             branches[internal_name + "." + branch_name] = sub_graph
+
+         if not branches:
+             raise Exception("A parallel node should have branches")
+         return cls(branches=branches, **config)
+
+     def _get_branch_by_name(self, branch_name: str) -> Graph:
+         if branch_name in self.branches:
+             return self.branches[branch_name]
+
+         raise Exception(f"Branch {branch_name} does not exist")
+
+     def fan_out(self, map_variable: TypeMapVariable = None, **kwargs):
+         """
+         The general fan-out method for a node of type Parallel.
+         This method assumes that the step log has already been created.
+
+         Third-party orchestrators should create the step log and use this method to create the branch logs.
+
+         Args:
+             map_variable (dict, optional): If the node is part of a map node. Defaults to None.
+         """
+         # Prepare the branch logs
+         for internal_branch_name, _ in self.branches.items():
+             effective_branch_name = self._resolve_map_placeholders(internal_branch_name, map_variable=map_variable)
+
+             branch_log = self._context.run_log_store.create_branch_log(effective_branch_name)
+             branch_log.status = defaults.PROCESSING
+             self._context.run_log_store.add_branch_log(branch_log, self._context.run_id)
+
+     def execute_as_graph(self, map_variable: TypeMapVariable = None, **kwargs):
+         """
+         This function does the actual execution of the sub-branches of the parallel node.
+
+         From a design perspective, this function should not be called if the execution is orchestrated by a third party.
+
+         The modes that render job specifications do not need to interact with this node at all, as they have
+         their own internal mechanisms for handling parallel states.
+         If they do not, you can find a way using as-is nodes as hack nodes.
+
+         The execution of a dag could result in:
+             * The dag being completely executed with a definite (fail, success) state in case of
+               local or local-container execution
+             * The dag being in a processing state with PROCESSING status in case of local-aws-batch
+
+         Only the fail state is considered a failure during this phase of execution.
+
+         Args:
+             map_variable (dict, optional): If the node is part of a map node. Defaults to None.
+             **kwargs: Optional kwargs passed around
+         """
+         from runnable import entrypoints
+
+         self.fan_out(map_variable=map_variable, **kwargs)
+
+         jobs = []
+         # Given that we can have nesting and complex graphs, controlling the number of processes is hard.
+         # A better way is to actually submit the job to some process scheduler which does resource management.
+         for internal_branch_name, branch in self.branches.items():
+             if self._context.executor._is_parallel_execution():
+                 # Trigger parallel jobs
+                 action = entrypoints.execute_single_brach
+                 kwargs = {
+                     "configuration_file": self._context.configuration_file,
+                     "pipeline_file": self._context.pipeline_file,
+                     "branch_name": internal_branch_name.replace(" ", defaults.COMMAND_FRIENDLY_CHARACTER),
+                     "run_id": self._context.run_id,
+                     "map_variable": json.dumps(map_variable),
+                     "tag": self._context.tag,
+                 }
+                 process = multiprocessing.Process(target=action, kwargs=kwargs)
+                 jobs.append(process)
+                 process.start()
+
+             else:
+                 # If parallel is not enabled, execute them sequentially
+                 self._context.executor.execute_graph(branch, map_variable=map_variable, **kwargs)
+
+         for job in jobs:
+             job.join()  # Wait for all branches to finish
+
+         self.fan_in(map_variable=map_variable, **kwargs)
+
+     def fan_in(self, map_variable: TypeMapVariable = None, **kwargs):
+         """
+         The general fan-in method for a node of type Parallel.
+
+         Third-party orchestrators should use this method to find the status of the composite step.
+
+         Args:
+             map_variable (dict, optional): If the node is part of a map. Defaults to None.
+         """
+         step_success_bool = True
+         for internal_branch_name, _ in self.branches.items():
+             effective_branch_name = self._resolve_map_placeholders(internal_branch_name, map_variable=map_variable)
+             branch_log = self._context.run_log_store.get_branch_log(effective_branch_name, self._context.run_id)
+             if branch_log.status != defaults.SUCCESS:
+                 step_success_bool = False
+
+         # Collate all the results and update the status of the step
+         effective_internal_name = self._resolve_map_placeholders(self.internal_name, map_variable=map_variable)
+         step_log = self._context.run_log_store.get_step_log(effective_internal_name, self._context.run_id)
+
+         if step_success_bool:  # If none failed
+             step_log.status = defaults.SUCCESS
+         else:
+             step_log.status = defaults.FAIL
+
+         self._context.run_log_store.add_step_log(step_log, self._context.run_id)
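
execute_as_graph above follows a fan-out / execute / fan-in lifecycle, spawning one OS process per branch when the executor reports parallel execution is enabled. Stripped of the executor and run-log plumbing, the process handling reduces to this sketch (the branch names and work function are hypothetical):

import multiprocessing

def run_branch(branch_name: str) -> None:
    print(f"executing {branch_name}")  # stand-in for executor.execute_graph

if __name__ == "__main__":
    branches = ["parallel_step.branch_a", "parallel_step.branch_b"]

    jobs = []
    for name in branches:  # fan out: one process per branch
        process = multiprocessing.Process(target=run_branch, args=(name,))
        jobs.append(process)
        process.start()

    for job in jobs:  # wait for every branch before fanning in
        job.join()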
+
+
+ class MapNode(CompositeNode):
+     """
+     A composite node that contains ONE graph object within itself that has to be executed with an iterable.
+
+     The structure is generally:
+         MapNode:
+             branch
+
+     The config is expected to have the variables 'iterate_on' and 'iterate_as', which are looked up in the parameters:
+         for iter_variable in parameters['iterate_on']:
+             execute the branch, sending {'iterate_as': iter_variable}
+
+     The internal naming convention creates branches dynamically based on the iteration value.
+     """
+
+     node_type: str = Field(default="map", serialization_alias="type")
+     iterate_on: str
+     iterate_as: str
+     branch: Graph
+     is_composite: bool = True
+
+     @classmethod
+     def parse_from_config(cls, config: Dict[str, Any]) -> "MapNode":
+         internal_name = cast(str, config.get("internal_name"))
+
+         config_branch = config.pop("branch", {})
+         if not config_branch:
+             raise Exception("A map node should have a branch")
+
+         branch = create_graph(
+             deepcopy(config_branch),
+             internal_branch_name=internal_name + "." + defaults.MAP_PLACEHOLDER,
+         )
+         return cls(branch=branch, **config)
+
+     def _get_branch_by_name(self, branch_name: str) -> Graph:
+         """
+         Retrieve a branch by name.
+
+         In the case of a map node, the branch naming is dynamic as it is parameterized on the iterable.
+         This method takes no responsibility for checking the validity of the name.
+
+         Returns a Graph object.
+
+         Args:
+             branch_name (str): The name of the branch to retrieve
+
+         Raises:
+             Exception: If the branch by that name does not exist
+         """
+         return self.branch
+
+     def fan_out(self, map_variable: TypeMapVariable = None, **kwargs):
+         """
+         The general method to fan out for a node of type map.
+         This method assumes that the step log has already been created.
+
+         Third-party orchestrators should call this method to create the individual branch logs.
+
+         Args:
+             map_variable (dict, optional): If the node is part of a map. Defaults to None.
+         """
+         iterate_on = self._context.run_log_store.get_parameters(self._context.run_id)[self.iterate_on]
+
+         # Prepare the branch logs
+         for iter_variable in iterate_on:
+             effective_branch_name = self._resolve_map_placeholders(
+                 self.internal_name + "." + str(iter_variable), map_variable=map_variable
+             )
+             branch_log = self._context.run_log_store.create_branch_log(effective_branch_name)
+             branch_log.status = defaults.PROCESSING
+             self._context.run_log_store.add_branch_log(branch_log, self._context.run_id)
+
+     def execute_as_graph(self, map_variable: TypeMapVariable = None, **kwargs):
+         """
+         This function does the actual execution of the branch of the map node.
+
+         From a design perspective, this function should not be called if the execution is orchestrated by a third party.
+
+         The modes that render job specifications do not need to interact with this node at all, as they have
+         their own internal mechanisms for handling map states or dynamic parallel states.
+         If they do not, you can find a way using as-is nodes as hack nodes.
+
+         The actual logic is:
+             * We iterate over the iterable as mentioned in the config
+             * For every value in the iterable we call executor.execute_graph(branch, iterate_as: iter_variable)
+
+         The execution of a dag could result in:
+             * The dag being completely executed with a definite (fail, success) state in case of local
+               or local-container execution
+             * The dag being in a processing state with PROCESSING status in case of local-aws-batch
+
+         Only the fail state is considered a failure during this phase of execution.
+
+         Args:
+             map_variable (dict): The map variables the graph belongs to
+             **kwargs: Optional kwargs passed around
+         """
+         from runnable import entrypoints
+
+         iterate_on = None
+         try:
+             iterate_on = self._context.run_log_store.get_parameters(self._context.run_id)[self.iterate_on]
+         except KeyError:
+             raise Exception(
+                 f"Expected parameter {self.iterate_on} not present in Run Log parameters, was it ever set before?"
+             )
+
+         if not isinstance(iterate_on, list):
+             raise Exception("Only list is allowed as a valid iterator type")
+
+         self.fan_out(map_variable=map_variable, **kwargs)
+
+         jobs = []
+         # Given that we can have nesting and complex graphs, controlling the number of processes is hard.
+         # A better way is to actually submit the job to some process scheduler which does resource management.
+         for iter_variable in iterate_on:
+             effective_map_variable = map_variable or OrderedDict()
+             effective_map_variable[self.iterate_as] = iter_variable
+
+             if self._context.executor._is_parallel_execution():
+                 # Trigger parallel jobs
+                 action = entrypoints.execute_single_brach
+                 kwargs = {
+                     "configuration_file": self._context.configuration_file,
+                     "pipeline_file": self._context.pipeline_file,
+                     "branch_name": self.branch.internal_branch_name.replace(" ", defaults.COMMAND_FRIENDLY_CHARACTER),
+                     "run_id": self._context.run_id,
+                     "map_variable": json.dumps(effective_map_variable),
+                     "tag": self._context.tag,
+                 }
+                 process = multiprocessing.Process(target=action, kwargs=kwargs)
+                 jobs.append(process)
+                 process.start()
+
+             else:
+                 # If parallel is not enabled, execute them sequentially
+                 self._context.executor.execute_graph(self.branch, map_variable=effective_map_variable, **kwargs)
+
+         for job in jobs:
+             job.join()
+
+         self.fan_in(map_variable=map_variable, **kwargs)
+
+     def fan_in(self, map_variable: TypeMapVariable = None, **kwargs):
+         """
+         The general method to fan in for a node of type map.
+
+         Third-party orchestrators should call this method to find the status of the step log.
+
+         Args:
+             map_variable (dict, optional): If the node is part of a map node. Defaults to None.
+         """
+         iterate_on = self._context.run_log_store.get_parameters(self._context.run_id)[self.iterate_on]
+         # Find the status of the branches
+         step_success_bool = True
+
+         for iter_variable in iterate_on:
+             effective_branch_name = self._resolve_map_placeholders(
+                 self.internal_name + "." + str(iter_variable), map_variable=map_variable
+             )
+             branch_log = self._context.run_log_store.get_branch_log(effective_branch_name, self._context.run_id)
+             if branch_log.status != defaults.SUCCESS:
+                 step_success_bool = False
+
+         # Collate all the results and update the status of the step
+         effective_internal_name = self._resolve_map_placeholders(self.internal_name, map_variable=map_variable)
+         step_log = self._context.run_log_store.get_step_log(effective_internal_name, self._context.run_id)
+
+         if step_success_bool:  # If none failed and nothing is waiting
+             step_log.status = defaults.SUCCESS
+         else:
+             step_log.status = defaults.FAIL
+
+         self._context.run_log_store.add_step_log(step_log, self._context.run_id)
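
The effective map variable built in execute_as_graph is what turns one branch definition into many runs: each iteration extends the incoming map variable with this node's iterate_as key, and fan_out names each branch log after the iteration value (internal_name + "." + value). A small sketch of that naming (the node name and parameter values are illustrative):

from collections import OrderedDict

internal_name = "map_step"       # illustrative node name
iterate_as = "chunk"             # the 'iterate_as' key
iterate_on = ["alpha", "beta"]   # value of the 'iterate_on' parameter

for iter_variable in iterate_on:
    effective_map_variable = OrderedDict({iterate_as: iter_variable})
    # fan_out creates one branch log per iteration value
    branch_name = internal_name + "." + str(iter_variable)
    print(branch_name, dict(effective_map_variable))
# map_step.alpha {'chunk': 'alpha'}
# map_step.beta {'chunk': 'beta'}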
+
+
+ class DagNode(CompositeNode):
+     """
+     A composite node that internally holds a dag.
+
+     The structure is generally:
+         DagNode:
+             dag_definition: A YAML file that holds the dag in a 'dag' block
+
+     The config is expected to have a variable 'dag_definition'.
+     """
+
+     node_type: str = Field(default="dag", serialization_alias="type")
+     dag_definition: str
+     branch: Graph
+     is_composite: bool = True
+     internal_branch_name: Annotated[str, Field(validate_default=True)] = ""
+
+     @field_validator("internal_branch_name")
+     @classmethod
+     def validate_internal_branch_name(cls, internal_branch_name: str, info: ValidationInfo):
+         internal_name = info.data["internal_name"]
+         return internal_name + "." + defaults.DAG_BRANCH_NAME
+
+     @field_validator("dag_definition")
+     @classmethod
+     def validate_dag_definition(cls, value):
+         if not value.endswith(".yaml"):  # TODO: Might have a problem with the SDK
+             raise ValueError("dag_definition must be a YAML file")
+         return value
+
+     @classmethod
+     def parse_from_config(cls, config: Dict[str, Any]) -> "DagNode":
+         internal_name = cast(str, config.get("internal_name"))
+
+         if "dag_definition" not in config:
+             raise Exception(f"No dag definition found in {config}")
+
+         dag_config = utils.load_yaml(config["dag_definition"])
+         if "dag" not in dag_config:
+             raise Exception("No DAG found in dag_definition, please provide it in dag block")
+
+         branch = create_graph(dag_config["dag"], internal_branch_name=internal_name + "." + defaults.DAG_BRANCH_NAME)
+
+         return cls(branch=branch, **config)
+
+     def _get_branch_by_name(self, branch_name: str):
+         """
+         Retrieve a branch by name.
+         The name is expected to follow a dot-path convention.
+
+         Returns a Graph object.
+
+         Args:
+             branch_name (str): The name of the branch to retrieve
+
+         Raises:
+             Exception: If the branch_name is not 'dag'
+         """
+         if branch_name != self.internal_branch_name:
+             raise Exception(f"Node of type {self.node_type} only allows a branch of name {defaults.DAG_BRANCH_NAME}")
+
+         return self.branch
+
+     def fan_out(self, map_variable: TypeMapVariable = None, **kwargs):
+         """
+         The general method to fan out for a node of type dag.
+         The method assumes that the step log has already been created.
+
+         Args:
+             map_variable (dict, optional): If the node is part of a map. Defaults to None.
+         """
+         effective_branch_name = self._resolve_map_placeholders(self.internal_branch_name, map_variable=map_variable)
+
+         branch_log = self._context.run_log_store.create_branch_log(effective_branch_name)
+         branch_log.status = defaults.PROCESSING
+         self._context.run_log_store.add_branch_log(branch_log, self._context.run_id)
+
+     def execute_as_graph(self, map_variable: TypeMapVariable = None, **kwargs):
+         """
+         This function does the actual execution of the branch of the dag node.
+
+         From a design perspective, this function should not be called if the execution is orchestrated by a third party.
+
+         The modes that render job specifications do not need to interact with this node at all,
+         as they have their own internal mechanisms for handling sub-dags.
+         If they do not, you can find a way using as-is nodes as hack nodes.
+
+         The actual logic is:
+             * We just execute the branch as with any other composite node
+             * The branch name is called 'dag'
+
+         The execution of a dag could result in:
+             * The dag being completely executed with a definite (fail, success) state in case of
+               local or local-container execution
+             * The dag being in a processing state with PROCESSING status in case of local-aws-batch
+
+         Only the fail state is considered a failure during this phase of execution.
+
+         Args:
+             map_variable (dict, optional): If the node is part of a map node. Defaults to None.
+             **kwargs: Optional kwargs passed around
+         """
+         self.fan_out(map_variable=map_variable, **kwargs)
+         self._context.executor.execute_graph(self.branch, map_variable=map_variable, **kwargs)
+         self.fan_in(map_variable=map_variable, **kwargs)
+
+     def fan_in(self, map_variable: TypeMapVariable = None, **kwargs):
+         """
+         The general method to fan in for a node of type dag.
+
+         Third-party orchestrators should call this method to find the status of the step log.
+
+         Args:
+             map_variable (dict, optional): If the node is part of a map. Defaults to None.
+         """
+         step_success_bool = True
+         effective_branch_name = self._resolve_map_placeholders(self.internal_branch_name, map_variable=map_variable)
+         effective_internal_name = self._resolve_map_placeholders(self.internal_name, map_variable=map_variable)
+
+         branch_log = self._context.run_log_store.get_branch_log(effective_branch_name, self._context.run_id)
+         if branch_log.status != defaults.SUCCESS:
+             step_success_bool = False
+
+         step_log = self._context.run_log_store.get_step_log(effective_internal_name, self._context.run_id)
+         step_log.status = defaults.PROCESSING
+
+         if step_success_bool:  # If none failed and nothing is waiting
+             step_log.status = defaults.SUCCESS
+         else:
+             step_log.status = defaults.FAIL
+
+         self._context.run_log_store.add_step_log(step_log, self._context.run_id)
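
parse_from_config requires the dag_definition file to carry the sub-graph under a top-level 'dag' block. A minimal sketch of a file that would pass that check (the step names and layout are assumptions about the graph schema; yaml.safe_load stands in for utils.load_yaml):

import yaml  # PyYAML, standing in for utils.load_yaml

dag_definition = """
dag:
  start_at: step_one
  steps:
    step_one:
      type: stub
      next: success
    success:
      type: success
    fail:
      type: fail
"""

dag_config = yaml.safe_load(dag_definition)
assert "dag" in dag_config  # parse_from_config raises if this block is missing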
+
+
+ class StubNode(ExecutableNode):
+     """
+     Stub is a convenience design node.
+
+     It always returns success in the attempt log and does nothing.
+
+     This node is very similar to the Pass state in AWS Step Functions.
+
+     This node type can be handy when designing the pipeline and stubbing functions.
+     """
+
+     node_type: str = Field(default="stub", serialization_alias="type")
+     model_config = ConfigDict(extra="allow")
+
+     @classmethod
+     def parse_from_config(cls, config: Dict[str, Any]) -> "StubNode":
+         return cls(**config)
+
+     def execute(self, mock=False, map_variable: TypeMapVariable = None, **kwargs) -> StepAttempt:
+         """
+         Do-nothing node.
+         We just send a success attempt log back to the caller.
+
+         Args:
+             mock (bool, optional): If we should just mock and not do anything. Defaults to False.
+             map_variable (dict, optional): If the node belongs to an internal branch. Defaults to None.
+
+         Returns:
+             StepAttempt: The attempt object
+         """
+         attempt_log = self._context.run_log_store.create_attempt_log()
+
+         attempt_log.start_time = str(datetime.now())
+         attempt_log.status = defaults.SUCCESS  # This is a dummy node and will always be success
+
+         attempt_log.end_time = str(datetime.now())
+         attempt_log.duration = utils.get_duration_between_datetime_strings(attempt_log.start_time, attempt_log.end_time)
+         return attempt_log
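
StubNode's ConfigDict(extra="allow") is what lets a stub accept the eventual task's configuration without validation errors, so a step can be stubbed out and later swapped for a real task node with the same config. The behavior in isolation (a self-contained pydantic v2 sketch; the extra field names are illustrative):

from pydantic import BaseModel, ConfigDict, Field

class Stub(BaseModel):
    model_config = ConfigDict(extra="allow")
    node_type: str = Field(default="stub", serialization_alias="type")

stub = Stub(name="placeholder", command="my_module.train")  # extra keys accepted
print(stub.model_dump(by_alias=True))
# {'type': 'stub', 'name': 'placeholder', 'command': 'my_module.train'}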