runnable 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- runnable/__init__.py +34 -0
- runnable/catalog.py +141 -0
- runnable/cli.py +272 -0
- runnable/context.py +34 -0
- runnable/datastore.py +686 -0
- runnable/defaults.py +179 -0
- runnable/entrypoints.py +484 -0
- runnable/exceptions.py +94 -0
- runnable/executor.py +431 -0
- runnable/experiment_tracker.py +139 -0
- runnable/extensions/catalog/__init__.py +21 -0
- runnable/extensions/catalog/file_system/__init__.py +0 -0
- runnable/extensions/catalog/file_system/implementation.py +226 -0
- runnable/extensions/catalog/k8s_pvc/__init__.py +0 -0
- runnable/extensions/catalog/k8s_pvc/implementation.py +16 -0
- runnable/extensions/catalog/k8s_pvc/integration.py +59 -0
- runnable/extensions/executor/__init__.py +714 -0
- runnable/extensions/executor/argo/__init__.py +0 -0
- runnable/extensions/executor/argo/implementation.py +1182 -0
- runnable/extensions/executor/argo/specification.yaml +51 -0
- runnable/extensions/executor/k8s_job/__init__.py +0 -0
- runnable/extensions/executor/k8s_job/implementation_FF.py +259 -0
- runnable/extensions/executor/k8s_job/integration_FF.py +69 -0
- runnable/extensions/executor/local/__init__.py +0 -0
- runnable/extensions/executor/local/implementation.py +69 -0
- runnable/extensions/executor/local_container/__init__.py +0 -0
- runnable/extensions/executor/local_container/implementation.py +367 -0
- runnable/extensions/executor/mocked/__init__.py +0 -0
- runnable/extensions/executor/mocked/implementation.py +220 -0
- runnable/extensions/experiment_tracker/__init__.py +0 -0
- runnable/extensions/experiment_tracker/mlflow/__init__.py +0 -0
- runnable/extensions/experiment_tracker/mlflow/implementation.py +94 -0
- runnable/extensions/nodes.py +675 -0
- runnable/extensions/run_log_store/__init__.py +0 -0
- runnable/extensions/run_log_store/chunked_file_system/__init__.py +0 -0
- runnable/extensions/run_log_store/chunked_file_system/implementation.py +106 -0
- runnable/extensions/run_log_store/chunked_k8s_pvc/__init__.py +0 -0
- runnable/extensions/run_log_store/chunked_k8s_pvc/implementation.py +21 -0
- runnable/extensions/run_log_store/chunked_k8s_pvc/integration.py +61 -0
- runnable/extensions/run_log_store/db/implementation_FF.py +157 -0
- runnable/extensions/run_log_store/db/integration_FF.py +0 -0
- runnable/extensions/run_log_store/file_system/__init__.py +0 -0
- runnable/extensions/run_log_store/file_system/implementation.py +136 -0
- runnable/extensions/run_log_store/generic_chunked.py +541 -0
- runnable/extensions/run_log_store/k8s_pvc/__init__.py +0 -0
- runnable/extensions/run_log_store/k8s_pvc/implementation.py +21 -0
- runnable/extensions/run_log_store/k8s_pvc/integration.py +56 -0
- runnable/extensions/secrets/__init__.py +0 -0
- runnable/extensions/secrets/dotenv/__init__.py +0 -0
- runnable/extensions/secrets/dotenv/implementation.py +100 -0
- runnable/extensions/secrets/env_secrets/__init__.py +0 -0
- runnable/extensions/secrets/env_secrets/implementation.py +42 -0
- runnable/graph.py +464 -0
- runnable/integration.py +205 -0
- runnable/interaction.py +399 -0
- runnable/names.py +546 -0
- runnable/nodes.py +489 -0
- runnable/parameters.py +183 -0
- runnable/pickler.py +102 -0
- runnable/sdk.py +470 -0
- runnable/secrets.py +95 -0
- runnable/tasks.py +392 -0
- runnable/utils.py +630 -0
- runnable-0.2.0.dist-info/METADATA +437 -0
- runnable-0.2.0.dist-info/RECORD +69 -0
- runnable-0.2.0.dist-info/entry_points.txt +44 -0
- runnable-0.1.0.dist-info/METADATA +0 -16
- runnable-0.1.0.dist-info/RECORD +0 -6
- /runnable/{.gitkeep → extensions/__init__.py} +0 -0
- {runnable-0.1.0.dist-info → runnable-0.2.0.dist-info}/LICENSE +0 -0
- {runnable-0.1.0.dist-info → runnable-0.2.0.dist-info}/WHEEL +0 -0
runnable/nodes.py
ADDED
@@ -0,0 +1,489 @@
|
|
1
|
+
import logging
|
2
|
+
from abc import ABC, abstractmethod
|
3
|
+
from typing import Any, Dict, List, Optional
|
4
|
+
|
5
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
6
|
+
|
7
|
+
import runnable.context as context
|
8
|
+
from runnable import defaults, exceptions
|
9
|
+
from runnable.datastore import StepAttempt
|
10
|
+
from runnable.defaults import TypeMapVariable
|
11
|
+
|
12
|
+
logger = logging.getLogger(defaults.LOGGER_NAME)
|
13
|
+
|
14
|
+
# --8<-- [start:docs]
|
15
|
+
|
16
|
+
|
17
|
+
class BaseNode(ABC, BaseModel):
    """
    Base class with common functionality provided for a Node of a graph.

    A node of a graph could be a
    * single execution node as task, success, fail.
    * Could be graph in itself as parallel, dag and map.
    * could be a convenience function like as-is.

    The name is relative to the DAG.
    The internal name of the node, is absolute name in dot path convention.
    This has one to one mapping to the name in the run log
    The internal name of a node, should always be odd when split against dot.

    The internal branch name, only applies for branched nodes, is the branch it belongs to.
    The internal branch name should always be even when split against dot.
    """

    # Serialized as "type"; the excluded fields are runtime-only bookkeeping.
    node_type: str = Field(serialization_alias="type")
    name: str
    internal_name: str = Field(exclude=True)
    internal_branch_name: str = Field(default="", exclude=True)
    is_composite: bool = Field(default=False, exclude=True)

    @property
    def _context(self):
        # The active run context is a module-level global set during execution.
        return context.run_context

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=False)

    @field_validator("name")
    @classmethod
    def validate_name(cls, name: str):
        # "." is the dot-path separator and "%" the command-friendly
        # replacement character, so neither may appear in a node name.
        if "." in name or "%" in name:
            raise ValueError("Node names cannot have . or '%' in them")
        return name

    def _command_friendly_name(self, replace_with=defaults.COMMAND_FRIENDLY_CHARACTER) -> str:
        """
        Replace spaces with special character for spaces.
        Spaces in the naming of the node is convenient for the user but causes issues when used programmatically.

        Returns:
            str: The command friendly name of the node
        """
        return self.internal_name.replace(" ", replace_with)

    @classmethod
    def _get_internal_name_from_command_name(cls, command_name: str) -> str:
        """
        Replace the command-friendly character (%) with whitespace.
        The opposite of _command_friendly_name.

        Args:
            command_name (str): The command friendly node name

        Returns:
            str: The internal name of the step
        """
        return command_name.replace(defaults.COMMAND_FRIENDLY_CHARACTER, " ")

    @classmethod
    def _resolve_map_placeholders(cls, name: str, map_variable: TypeMapVariable = None) -> str:
        """
        If there is no map step used, then we just return the name as we find it.

        If there is a map variable being used, replace every occurrence of the map variable placeholder with
        the value sequentially.

        For example:
        1). dag:
              start_at: step1
              steps:
                step1:
                  type: map
                  iterate_on: y
                  iterate_as: y_i
                  branch:
                    start_at: map_step1
                    steps:
                      map_step1: # internal_name step1.placeholder.map_step1
                        type: task
                        command: a.map_func
                        command_type: python
                        next: map_success
                      map_success:
                        type: success
                      map_failure:
                        type: fail

            and if y is ['a', 'b', 'c'].

        This method would be called 3 times with map_variable = {'y_i': 'a'}, map_variable = {'y_i': 'b'} and
        map_variable = {'y_i': 'c'} corresponding to the three branches.

        For nested map branches, we would get the map_variables ordered hierarchically.

        Args:
            name (str): The name to resolve
            map_variable (dict): The dictionary of map variables

        Returns:
            [str]: The resolved name
        """
        if not map_variable:
            return name

        # Each map level consumes exactly one placeholder, left to right,
        # hence the count=1 on every replace.
        for _, value in map_variable.items():
            name = name.replace(defaults.MAP_PLACEHOLDER, str(value), 1)

        return name

    def _get_step_log_name(self, map_variable: TypeMapVariable = None) -> str:
        """
        For every step in the dag, there is a corresponding step log name.
        This method returns the step log name in dot path convention.

        All node types except a map state has a "static" defined step_log names and are equivalent to internal_name.
        For nodes belonging to map state, the internal name has a placeholder that is replaced at runtime.

        Args:
            map_variable (dict): If the node is of type map, the names are based on the current iteration state of the
            parameter.

        Returns:
            str: The dot path name of the step log name
        """
        return self._resolve_map_placeholders(self.internal_name, map_variable=map_variable)

    def _get_branch_log_name(self, map_variable: TypeMapVariable = None) -> str:
        """
        For nodes that are internally branches, this method returns the branch log name.
        The branch log name is in dot path convention.

        For nodes that are not map, the internal branch name is equivalent to the branch name.
        For map nodes, the internal branch name has a placeholder that is replaced at runtime.

        Args:
            map_variable (dict): If the node is of type map, the names are based on the current iteration state of the
            parameter.

        Returns:
            str: The dot path name of the branch log
        """
        return self._resolve_map_placeholders(self.internal_branch_name, map_variable=map_variable)

    def __str__(self) -> str:  # pragma: no cover
        """
        String representation of the node.

        Returns:
            str: The string representation of the node.
        """
        return f"Node of type {self.node_type} and name {self.internal_name}"

    @abstractmethod
    def _get_on_failure_node(self) -> str:
        """
        If the node defines a on_failure node in the config, return this or ''.

        The naming is relative to the dag, the caller is supposed to resolve it to the correct graph.

        Returns:
            str: The on_failure node defined by the dag or ''
        """
        ...

    @abstractmethod
    def _get_next_node(self) -> str:
        """
        Return the next node as defined by the config.

        Returns:
            str: The node name, relative to the dag, as defined by the config
        """
        ...

    @abstractmethod
    def _is_terminal_node(self) -> bool:
        """
        Returns whether a node has a next node.

        Returns:
            bool: True or False of whether there is next node.
        """
        ...

    @abstractmethod
    def _get_catalog_settings(self) -> Dict[str, Any]:
        """
        If the node defines catalog settings, return them, else an empty dict.

        Returns:
            dict: catalog settings defined as per the node, or empty
        """
        ...

    @abstractmethod
    def _get_branch_by_name(self, branch_name: str):
        """
        Retrieve a branch by name.

        The name is expected to follow a dot path convention.

        Args:
            branch_name (str): The dot path name of the branch.

        Raises:
            Exception: If the node does not have branches.
        """
        ...

    def _get_neighbors(self) -> List[str]:
        """
        Gets the connecting neighbor nodes, either the "next" node or "on_failure" node.

        Returns:
            list: List of connected neighbors for a given node. Empty if terminal node.
        """
        neighbors = []
        # Terminal nodes signal the absence of a neighbor by raising,
        # which is why both lookups are wrapped.
        try:
            next_node = self._get_next_node()
            neighbors += [next_node]
        except exceptions.TerminalNodeError:
            pass

        try:
            fail_node = self._get_on_failure_node()
            if fail_node:
                neighbors += [fail_node]
        except exceptions.TerminalNodeError:
            pass

        return neighbors

    @abstractmethod
    def _get_executor_config(self, executor_type: str) -> str:
        """
        Return the executor override configured for the node, or ''.

        NOTE(review): annotated ``-> str`` — concrete implementations return the
        override *name* (see TraversalNode), not a config dict.

        Args:
            executor_type (str): The executor type that the config refers to.

        Returns:
            str: The executor override name, if defined, or ''
        """
        ...

    @abstractmethod
    def _get_max_attempts(self) -> int:
        """
        The number of max attempts as defined by the config or 1.

        Returns:
            int: The number of maximum retries as defined by the config or 1.
        """
        ...

    @abstractmethod
    def execute(self, mock=False, map_variable: TypeMapVariable = None, **kwargs) -> StepAttempt:
        """
        The actual function that does the execution of the command in the config.

        Should only be implemented for task, success, fail and as-is and never for
        composite nodes.

        Args:
            mock (bool, optional): Don't run, just pretend. Defaults to False.
            map_variable (dict, optional): The value of the map iteration variable, if part of a map node.
                Defaults to None.

        Returns:
            StepAttempt: The record of this execution attempt.
        """
        ...

    @abstractmethod
    def execute_as_graph(self, map_variable: TypeMapVariable = None, **kwargs):
        """
        This function would be called to set up the execution of the individual
        branches of a composite node.

        Function should only be implemented for composite nodes like dag, map, parallel.

        Args:
            map_variable (dict, optional): The value of the map iteration variable, if part of a map node.
        """
        ...

    @abstractmethod
    def fan_out(self, map_variable: TypeMapVariable = None, **kwargs):
        """
        This function would be called to set up the execution of the individual
        branches of a composite node.

        Function should only be implemented for composite nodes like dag, map, parallel.

        Args:
            map_variable (dict, optional): The value of the map iteration variable, if part of a map node.

        Raises:
            Exception: If the node is not a composite node.
        """
        ...

    @abstractmethod
    def fan_in(self, map_variable: TypeMapVariable = None, **kwargs):
        """
        This function would be called to tear down the execution of the individual
        branches of a composite node.

        Function should only be implemented for composite nodes like dag, map, parallel.

        Args:
            map_variable (dict, optional): The value of the map iteration variable, if part of a map node.

        Raises:
            Exception: If the node is not a composite node.
        """
        ...

    @classmethod
    @abstractmethod
    def parse_from_config(cls, config: Dict[str, Any]) -> "BaseNode":
        """
        Parse the config from the user and create the corresponding node.

        Args:
            config (Dict[str, Any]): The config of the node from the yaml or from the sdk.

        Returns:
            BaseNode: The corresponding node.
        """
        ...
|
358
|
+
|
359
|
+
|
360
|
+
# --8<-- [end:docs]
|
361
|
+
class TraversalNode(BaseNode):
    """
    A node that participates in graph traversal: it always knows its next
    node and may optionally name an on-failure node and executor overrides.
    """

    next_node: str = Field(serialization_alias="next")
    on_failure: str = Field(default="")
    overrides: Dict[str, str] = Field(default_factory=dict)

    def _get_on_failure_node(self) -> str:
        """
        Return the on_failure node configured for this node, or ''.

        The name is relative to the dag; resolution to the actual graph is
        the caller's responsibility.
        """
        return self.on_failure

    def _get_next_node(self) -> str:
        """Return the name of the next node, relative to the dag."""
        return self.next_node

    def _is_terminal_node(self) -> bool:
        """A traversal node always has a next node, so it is never terminal."""
        return False

    def _get_executor_config(self, executor_type) -> str:
        """Return the override name configured for the executor type, or ''."""
        configured = self.overrides.get(executor_type)
        return configured if configured else ""
|
399
|
+
|
400
|
+
|
401
|
+
class CatalogStructure(BaseModel):
    # Catalog settings of an executable node: "get" lists artifacts to fetch
    # before execution, "put" lists artifacts to store afterwards.
    # extra="forbid" makes typos in user-supplied config fail loudly.
    model_config = ConfigDict(extra="forbid")  # Need to forbid

    get: List[str] = Field(default_factory=list)
    put: List[str] = Field(default_factory=list)
|
406
|
+
|
407
|
+
|
408
|
+
class ExecutableNode(TraversalNode):
    """
    A traversal node that executes a command itself (task-like nodes).

    It carries optional catalog settings and a retry budget, and rejects
    every composite-only operation (branches, fan in/out, graph execution).
    """

    catalog: Optional[CatalogStructure] = Field(default=None)
    max_attempts: int = Field(default=1, ge=1)

    def _get_catalog_settings(self) -> Dict[str, Any]:
        """
        Return the node's catalog settings as a plain dict, or an empty
        dict when no catalog is configured.
        """
        if not self.catalog:
            return {}
        return self.catalog.model_dump()

    def _get_max_attempts(self) -> int:
        """Return the configured number of attempts (always >= 1)."""
        return self.max_attempts

    def _get_branch_by_name(self, branch_name: str):
        raise Exception("This is an executable node and does not have branches")

    def execute_as_graph(self, map_variable: TypeMapVariable = None, **kwargs):
        raise Exception("This is an executable node and does not have a graph")

    def fan_in(self, map_variable: TypeMapVariable = None, **kwargs):
        raise Exception("This is an executable node and does not have a fan in")

    def fan_out(self, map_variable: TypeMapVariable = None, **kwargs):
        raise Exception("This is an executable node and does not have a fan out")
|
437
|
+
|
438
|
+
|
439
|
+
class CompositeNode(TraversalNode):
    """
    A traversal node whose body is a graph (parallel, dag, map).

    Composite nodes are executed via their branches, so the direct-execution
    surface (catalog settings, retries, execute) is explicitly disabled.
    """

    def _get_catalog_settings(self) -> Dict[str, Any]:
        """Composite nodes carry no catalog settings; asking for them is an error."""
        raise Exception("This is a composite node and does not have a catalog settings")

    def _get_max_attempts(self) -> int:
        """Retries apply to executable steps, never to a composite node."""
        raise Exception("This is a composite node and does not have a max_attempts")

    def execute(self, mock=False, map_variable: TypeMapVariable = None, **kwargs) -> StepAttempt:
        """Composite nodes run through execute_as_graph, never directly."""
        raise Exception("This is a composite node and does not have an execute function")
|
454
|
+
|
455
|
+
|
456
|
+
class TerminalNode(BaseNode):
    """
    A node that ends traversal (success / fail).

    Terminal nodes have no neighbors and no branch machinery: every
    traversal-related accessor raises TerminalNodeError, which callers
    (e.g. BaseNode._get_neighbors) treat as "no such neighbor".
    """

    def _get_on_failure_node(self) -> str:
        raise exceptions.TerminalNodeError()

    def _get_next_node(self) -> str:
        raise exceptions.TerminalNodeError()

    def _is_terminal_node(self) -> bool:
        """Terminal by definition."""
        return True

    def _get_catalog_settings(self) -> Dict[str, Any]:
        raise exceptions.TerminalNodeError()

    def _get_branch_by_name(self, branch_name: str):
        raise exceptions.TerminalNodeError()

    def _get_executor_config(self, executor_type) -> str:
        raise exceptions.TerminalNodeError()

    def _get_max_attempts(self) -> int:
        """A terminal node is executed exactly once."""
        return 1

    def execute_as_graph(self, map_variable: TypeMapVariable = None, **kwargs):
        raise exceptions.TerminalNodeError()

    def fan_in(self, map_variable: TypeMapVariable = None, **kwargs):
        raise exceptions.TerminalNodeError()

    def fan_out(self, map_variable: TypeMapVariable = None, **kwargs):
        raise exceptions.TerminalNodeError()

    @classmethod
    def parse_from_config(cls, config: Dict[str, Any]) -> "TerminalNode":
        """Build the node straight from the user config mapping."""
        return cls(**config)
|
runnable/parameters.py
ADDED
@@ -0,0 +1,183 @@
|
|
1
|
+
import inspect
|
2
|
+
import json
|
3
|
+
import logging
|
4
|
+
import os
|
5
|
+
from typing import Any, Dict, Optional, Type, Union
|
6
|
+
|
7
|
+
from pydantic import BaseModel, ConfigDict
|
8
|
+
from typing_extensions import Callable
|
9
|
+
|
10
|
+
from runnable import defaults
|
11
|
+
from runnable.defaults import TypeMapVariable
|
12
|
+
from runnable.utils import remove_prefix
|
13
|
+
|
14
|
+
logger = logging.getLogger(defaults.LOGGER_NAME)
|
15
|
+
|
16
|
+
|
17
|
+
def get_user_set_parameters(remove: bool = False) -> Dict[str, Any]:
    """
    Scans the environment variables for any user returned parameters that have
    the configured prefix (defaults.PARAMETER_PREFIX).

    This function does not deal with any type conversion of the parameters.
    It just deserializes the parameters (JSON when possible, the literal string
    otherwise) and returns them as a dictionary keyed by lower-cased name.

    Args:
        remove (bool, optional): Flag to remove the parameter if needed. Defaults to False.

    Returns:
        dict: The dictionary of found user returned parameters
    """
    parameters = {}
    # BUGFIX: iterate over a snapshot. Deleting from os.environ while
    # iterating os.environ.items() directly mutates the mapping mid-iteration
    # and raises "RuntimeError: dictionary changed size during iteration"
    # whenever remove=True matches a variable.
    for env_var, value in list(os.environ.items()):
        if env_var.startswith(defaults.PARAMETER_PREFIX):
            key = remove_prefix(env_var, defaults.PARAMETER_PREFIX)
            try:
                parameters[key.lower()] = json.loads(value)
            except json.decoder.JSONDecodeError:
                # Not JSON — keep the raw string so the caller still sees it.
                logger.error(f"Parameter {key} could not be JSON decoded, adding the literal value")
                parameters[key.lower()] = value

            if remove:
                del os.environ[env_var]
    return parameters
|
43
|
+
|
44
|
+
|
45
|
+
def set_user_defined_params_as_environment_variables(params: Dict[str, Any]):
    """
    Export the given parameters as prefixed environment variables.

    Each value is serialized to a string first (JSON for plain values,
    model_dump-then-JSON for pydantic models).

    Args:
        params (Dict[str, Any]): The parameters to set as environment variables
    """
    for key, value in params.items():
        logger.info(f"Storing parameter {key} with value: {value}")
        os.environ[defaults.PARAMETER_PREFIX + key] = serialize_parameter_as_str(value)
|
61
|
+
|
62
|
+
|
63
|
+
def cast_parameters_as_type(value: Any, newT: Optional[Type] = None) -> Union[Any, BaseModel, Dict[str, Any]]:
    """
    Casts the value to the given type.

    Note: Only pydantic models and dict-like targets are special, everything
    else just goes through the target type's constructor.

    Args:
        value (Any): The value to cast
        newT (Type, optional): The type to cast to. None means no casting.

    Returns:
        The casted value

    Examples:
        >>> class MyBaseModel(BaseModel):
        ...     a: int
        ...     b: str
        >>> cast_parameters_as_type({"a": 1, "b": "2"}, MyBaseModel)
        MyBaseModel(a=1, b="2")
        >>> cast_parameters_as_type({"a": 1, "b": "2"}, Dict[str, int])
        {'a': 1, 'b': '2'}
    """
    if not newT:
        return value

    # BUGFIX: issubclass() raises TypeError when its first argument is not a
    # class — which is exactly what subscripted generics (Dict[str, int]) and
    # typing constructs (Optional[int]) are. The original code therefore
    # crashed on the very examples its docstring advertised. Guard with
    # inspect.isclass and fall back to the generic's origin class.
    if inspect.isclass(newT) and issubclass(newT, BaseModel):
        return newT(**value)

    target = newT if inspect.isclass(newT) else getattr(newT, "__origin__", None)

    if target is None:
        # A typing construct with no runtime origin (e.g. Optional[int]):
        # nothing sensible to construct, pass the value through unchanged.
        return value

    if issubclass(target, dict):
        return dict(value)

    if type(value) != target:
        logger.warning(f"Casting {value} of {type(value)} to {newT} seems wrong!!")

    return target(value)
|
115
|
+
|
116
|
+
|
117
|
+
def serialize_parameter_as_str(value: Any) -> str:
    """Serialize a parameter to a JSON string; pydantic models are dumped first."""
    if isinstance(value, BaseModel):
        value = value.model_dump()
    return json.dumps(value)
|
122
|
+
|
123
|
+
|
124
|
+
def filter_arguments_for_func(
    func: Callable[..., Any], params: Dict[str, Any], map_variable: TypeMapVariable = None
) -> Dict[str, Any]:
    """
    Inspects the function to be called as part of the pipeline to find the arguments of the function.
    Matches the function arguments to the parameters available either by command line or by up stream steps.

    Args:
        func (Callable): The function to inspect
        params (dict): The parameters available for the run. NOTE: mutated in
            place by the map-variable update below (kept for backward compat).
        map_variable (dict, optional): Current map iteration values; they act
            as parameters for the call.

    Returns:
        dict: The parameters matching the function signature

    Raises:
        ValueError: When a required (no-default) parameter of ``func`` is not
            provided.
    """
    function_args = inspect.signature(func).parameters

    # Update parameters with the map variables
    params.update(map_variable or {})

    bound_args = {}
    for name, value in function_args.items():
        if name not in params:
            # No parameter of this name was provided
            if value.default == inspect.Parameter.empty:
                # No default value is given in the function signature. error as parameter is required.
                raise ValueError(f"Parameter {name} is required for {func.__name__} but not provided")
            # default value is given in the function signature, nothing further to do.
            continue

        annotation = value.annotation
        if inspect.isclass(annotation) and issubclass(annotation, BaseModel):
            # We try to cast it as a pydantic model.
            named_param = params[name]

            if not isinstance(named_param, dict):
                # A case where the parameter is a one attribute model
                named_param = {name: named_param}

            bound_args[name] = bind_args_for_pydantic_model(named_param, annotation)
        elif annotation is inspect.Parameter.empty or not inspect.isclass(annotation):
            # BUGFIX: the original passed every annotation to issubclass()/
            # the cast helper. Missing annotations ended up calling
            # inspect._empty(value) and typing constructs (Optional[int])
            # made issubclass() raise TypeError. With no usable class
            # annotation, hand the parameter over unchanged.
            bound_args[name] = params[name]
        else:
            # simple python data type.
            bound_args[name] = cast_parameters_as_type(params[name], annotation)

    # NOTE: the original also tracked "unassigned" params only to rebind the
    # local name `params` right before returning — dead code, removed.
    return bound_args
|
175
|
+
|
176
|
+
|
177
|
+
def bind_args_for_pydantic_model(params: Dict[str, Any], model: Type[BaseModel]) -> BaseModel:
    """
    Build an instance of ``model`` from ``params``, tolerating extra keys.

    A throwaway subclass with ``extra="ignore"`` swallows unknown keys first;
    the real model is then constructed from the filtered dump so that its own
    validation still applies.
    """

    class EasyModel(model):  # type: ignore
        model_config = ConfigDict(extra="ignore")

    swallow_all = EasyModel(**params)
    return model(**swallow_all.model_dump())
|