runnable 0.1.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71) hide show
  1. runnable/__init__.py +34 -0
  2. runnable/catalog.py +141 -0
  3. runnable/cli.py +272 -0
  4. runnable/context.py +34 -0
  5. runnable/datastore.py +687 -0
  6. runnable/defaults.py +182 -0
  7. runnable/entrypoints.py +448 -0
  8. runnable/exceptions.py +94 -0
  9. runnable/executor.py +421 -0
  10. runnable/experiment_tracker.py +139 -0
  11. runnable/extensions/catalog/__init__.py +21 -0
  12. runnable/extensions/catalog/file_system/__init__.py +0 -0
  13. runnable/extensions/catalog/file_system/implementation.py +227 -0
  14. runnable/extensions/catalog/k8s_pvc/__init__.py +0 -0
  15. runnable/extensions/catalog/k8s_pvc/implementation.py +16 -0
  16. runnable/extensions/catalog/k8s_pvc/integration.py +59 -0
  17. runnable/extensions/executor/__init__.py +725 -0
  18. runnable/extensions/executor/argo/__init__.py +0 -0
  19. runnable/extensions/executor/argo/implementation.py +1183 -0
  20. runnable/extensions/executor/argo/specification.yaml +51 -0
  21. runnable/extensions/executor/k8s_job/__init__.py +0 -0
  22. runnable/extensions/executor/k8s_job/implementation_FF.py +259 -0
  23. runnable/extensions/executor/k8s_job/integration_FF.py +69 -0
  24. runnable/extensions/executor/local/__init__.py +0 -0
  25. runnable/extensions/executor/local/implementation.py +70 -0
  26. runnable/extensions/executor/local_container/__init__.py +0 -0
  27. runnable/extensions/executor/local_container/implementation.py +361 -0
  28. runnable/extensions/executor/mocked/__init__.py +0 -0
  29. runnable/extensions/executor/mocked/implementation.py +189 -0
  30. runnable/extensions/experiment_tracker/__init__.py +0 -0
  31. runnable/extensions/experiment_tracker/mlflow/__init__.py +0 -0
  32. runnable/extensions/experiment_tracker/mlflow/implementation.py +94 -0
  33. runnable/extensions/nodes.py +655 -0
  34. runnable/extensions/run_log_store/__init__.py +0 -0
  35. runnable/extensions/run_log_store/chunked_file_system/__init__.py +0 -0
  36. runnable/extensions/run_log_store/chunked_file_system/implementation.py +106 -0
  37. runnable/extensions/run_log_store/chunked_k8s_pvc/__init__.py +0 -0
  38. runnable/extensions/run_log_store/chunked_k8s_pvc/implementation.py +21 -0
  39. runnable/extensions/run_log_store/chunked_k8s_pvc/integration.py +61 -0
  40. runnable/extensions/run_log_store/db/implementation_FF.py +157 -0
  41. runnable/extensions/run_log_store/db/integration_FF.py +0 -0
  42. runnable/extensions/run_log_store/file_system/__init__.py +0 -0
  43. runnable/extensions/run_log_store/file_system/implementation.py +136 -0
  44. runnable/extensions/run_log_store/generic_chunked.py +541 -0
  45. runnable/extensions/run_log_store/k8s_pvc/__init__.py +0 -0
  46. runnable/extensions/run_log_store/k8s_pvc/implementation.py +21 -0
  47. runnable/extensions/run_log_store/k8s_pvc/integration.py +56 -0
  48. runnable/extensions/secrets/__init__.py +0 -0
  49. runnable/extensions/secrets/dotenv/__init__.py +0 -0
  50. runnable/extensions/secrets/dotenv/implementation.py +100 -0
  51. runnable/extensions/secrets/env_secrets/__init__.py +0 -0
  52. runnable/extensions/secrets/env_secrets/implementation.py +42 -0
  53. runnable/graph.py +464 -0
  54. runnable/integration.py +205 -0
  55. runnable/interaction.py +404 -0
  56. runnable/names.py +546 -0
  57. runnable/nodes.py +501 -0
  58. runnable/parameters.py +183 -0
  59. runnable/pickler.py +102 -0
  60. runnable/sdk.py +472 -0
  61. runnable/secrets.py +95 -0
  62. runnable/tasks.py +395 -0
  63. runnable/utils.py +630 -0
  64. runnable-0.3.0.dist-info/METADATA +437 -0
  65. runnable-0.3.0.dist-info/RECORD +69 -0
  66. {runnable-0.1.0.dist-info → runnable-0.3.0.dist-info}/WHEEL +1 -1
  67. runnable-0.3.0.dist-info/entry_points.txt +44 -0
  68. runnable-0.1.0.dist-info/METADATA +0 -16
  69. runnable-0.1.0.dist-info/RECORD +0 -6
  70. /runnable/{.gitkeep → extensions/__init__.py} +0 -0
  71. {runnable-0.1.0.dist-info → runnable-0.3.0.dist-info}/LICENSE +0 -0
runnable/nodes.py ADDED
@@ -0,0 +1,501 @@
1
+ import logging
2
+ from abc import ABC, abstractmethod
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ from pydantic import BaseModel, ConfigDict, Field, field_validator
6
+
7
+ import runnable.context as context
8
+ from runnable import defaults, exceptions
9
+ from runnable.datastore import StepAttempt
10
+ from runnable.defaults import TypeMapVariable
11
+
12
+ logger = logging.getLogger(defaults.LOGGER_NAME)
13
+
14
+ # --8<-- [start:docs]
15
+
16
+
17
class BaseNode(ABC, BaseModel):
    """
    Base class with common functionality provided for a Node of a graph.

    A node of a graph could be a
    * single execution node as task, success, fail.
    * Could be graph in itself as parallel, dag and map.
    * could be a convenience function like as-is.

    The name is relative to the DAG.
    The internal name of the node, is absolute name in dot path convention.
    This has one to one mapping to the name in the run log
    The internal name of a node, should always be odd when split against dot.

    The internal branch name, only applies for branched nodes, is the branch it belongs to.
    The internal branch name should always be even when split against dot.
    """

    # Serialized as "type" in the run log/spec; identifies the concrete node kind.
    node_type: str = Field(serialization_alias="type")
    # User facing name, relative to the containing DAG ('.' and '%' forbidden, see validator).
    name: str
    # Absolute dot-path name; internal bookkeeping, excluded from serialization.
    internal_name: str = Field(exclude=True)
    # Dot-path of the branch this node belongs to; '' for nodes at the top level.
    internal_branch_name: str = Field(default="", exclude=True)
    # True for nodes that contain branches (e.g. parallel, dag, map).
    is_composite: bool = Field(default=False, exclude=True)

    @property
    def _context(self):
        # The active run context; a module-level singleton set up at execution time.
        return context.run_context

    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=False)

    @field_validator("name")
    @classmethod
    def validate_name(cls, name: str):
        # '.' is reserved as the dot-path separator and '%' as the space
        # placeholder in command friendly names, so neither may appear in a name.
        if "." in name or "%" in name:
            raise ValueError("Node names cannot have . or '%' in them")
        return name

    def _command_friendly_name(self, replace_with=defaults.COMMAND_FRIENDLY_CHARACTER) -> str:
        """
        Replace spaces with special character for spaces.
        Spaces in the naming of the node is convenient for the user but causes issues when used programmatically.

        Returns:
            str: The command friendly name of the node
        """
        return self.internal_name.replace(" ", replace_with)

    @classmethod
    def _get_internal_name_from_command_name(cls, command_name: str) -> str:
        """
        Replace runnable specific character (%) with whitespace.
        The opposite of _command_friendly_name.

        Args:
            command_name (str): The command friendly node name

        Returns:
            str: The internal name of the step
        """
        return command_name.replace(defaults.COMMAND_FRIENDLY_CHARACTER, " ")

    @classmethod
    def _resolve_map_placeholders(cls, name: str, map_variable: TypeMapVariable = None) -> str:
        """
        If there is no map step used, then we just return the name as we find it.

        If there is a map variable being used, replace every occurrence of the map variable placeholder with
        the value sequentially.

        For example:
        1). dag:
              start_at: step1
              steps:
                step1:
                  type: map
                  iterate_on: y
                  iterate_as: y_i
                  branch:
                    start_at: map_step1
                    steps:
                      map_step1: # internal_name step1.placeholder.map_step1
                        type: task
                        command: a.map_func
                        command_type: python
                        next: map_success
                      map_success:
                        type: success
                      map_failure:
                        type: fail

            and if y is ['a', 'b', 'c'].

        This method would be called 3 times with map_variable = {'y_i': 'a'}, map_variable = {'y_i': 'b'} and
        map_variable = {'y_i': 'c'} corresponding to the three branches.

        For nested map branches, we would get the map_variables ordered hierarchically.

        Args:
            name (str): The name to resolve
            map_variable (dict): The dictionary of map variables

        Returns:
            [str]: The resolved name
        """
        if not map_variable:
            return name

        # Replace ONE placeholder per map variable, in insertion order of the
        # dict (outermost map first for nested maps).
        for _, value in map_variable.items():
            name = name.replace(defaults.MAP_PLACEHOLDER, str(value), 1)

        return name

    def _get_step_log_name(self, map_variable: TypeMapVariable = None) -> str:
        """
        For every step in the dag, there is a corresponding step log name.
        This method returns the step log name in dot path convention.

        All node types except a map state has a "static" defined step_log names and are equivalent to internal_name.
        For nodes belonging to map state, the internal name has a placeholder that is replaced at runtime.

        Args:
            map_variable (dict): If the node is of type map, the names are based on the current iteration state of the
            parameter.

        Returns:
            str: The dot path name of the step log name
        """
        return self._resolve_map_placeholders(self.internal_name, map_variable=map_variable)

    def _get_branch_log_name(self, map_variable: TypeMapVariable = None) -> str:
        """
        For nodes that are internally branches, this method returns the branch log name.
        The branch log name is in dot path convention.

        For nodes that are not map, the internal branch name is equivalent to the branch name.
        For map nodes, the internal branch name has a placeholder that is replaced at runtime.

        Args:
            map_variable (dict): If the node is of type map, the names are based on the current iteration state of the
            parameter.

        Returns:
            str: The dot path name of the branch log
        """
        return self._resolve_map_placeholders(self.internal_branch_name, map_variable=map_variable)

    def __str__(self) -> str:  # pragma: no cover
        """
        String representation of the node.

        Returns:
            str: The string representation of the node.
        """
        return f"Node of type {self.node_type} and name {self.internal_name}"

    @abstractmethod
    def _get_on_failure_node(self) -> str:
        """
        If the node defines a on_failure node in the config, return this or ''.

        The naming is relative to the dag, the caller is supposed to resolve it to the correct graph.

        Returns:
            str: The on_failure node defined by the dag or ''
        """
        ...

    @abstractmethod
    def _get_next_node(self) -> str:
        """
        Return the next node as defined by the config.

        Returns:
            str: The node name, relative to the dag, as defined by the config
        """
        ...

    @abstractmethod
    def _is_terminal_node(self) -> bool:
        """
        Returns whether a node has a next node.

        Returns:
            bool: True or False of whether there is next node.
        """
        ...

    @abstractmethod
    def _get_catalog_settings(self) -> Dict[str, Any]:
        """
        If the node defines a catalog settings, return it or an empty dict.

        Returns:
            dict: catalog settings defined as per the node or empty
        """
        ...

    @abstractmethod
    def _get_branch_by_name(self, branch_name: str):
        """
        Retrieve a branch by name.

        The name is expected to follow a dot path convention.

        Args:
            branch_name (str): The dot path name of the branch

        Raises:
            Exception: If the node does not have branches.
        """
        ...

    def _get_neighbors(self) -> List[str]:
        """
        Gets the connecting neighbor nodes, either the "next" node or "on_failure" node.

        Returns:
            list: List of connected neighbors for a given node. Empty if terminal node.
        """
        neighbors = []
        try:
            next_node = self._get_next_node()
            neighbors += [next_node]
        except exceptions.TerminalNodeError:
            # Terminal nodes (success/fail) have no successor.
            pass

        try:
            fail_node = self._get_on_failure_node()
            # on_failure is optional; '' means "not configured".
            if fail_node:
                neighbors += [fail_node]
        except exceptions.TerminalNodeError:
            pass

        return neighbors

    @abstractmethod
    def _get_executor_config(self, executor_type: str) -> str:
        """
        Return the executor config of the node, if defined, or an empty string.

        Args:
            executor_type (str): The executor type that the config refers to.

        Returns:
            str: The executor override name, if defined, or ''
        """
        ...

    @abstractmethod
    def _get_max_attempts(self) -> int:
        """
        The number of max attempts as defined by the config or 1.

        Returns:
            int: The number of maximum retries as defined by the config or 1.
        """
        ...

    @abstractmethod
    def execute(
        self,
        mock=False,
        params: Optional[Dict[str, Any]] = None,
        map_variable: TypeMapVariable = None,
        **kwargs,
    ) -> StepAttempt:
        """
        The actual function that does the execution of the command in the config.

        Should only be implemented for task, success, fail and as-is and never for
        composite nodes.

        Args:
            mock (bool, optional): Don't run, just pretend. Defaults to False.
            params (dict, optional): The parameters available to the execution.
            map_variable (dict, optional): The value of the map iteration variable, if part of a map node.

        Returns:
            StepAttempt: The attempt record of this execution.
        """
        ...

    @abstractmethod
    def execute_as_graph(self, map_variable: TypeMapVariable = None, **kwargs):
        """
        This function would be called to set up the execution of the individual
        branches of a composite node.

        Function should only be implemented for composite nodes like dag, map, parallel.

        Args:
            map_variable (dict, optional): The value of the map iteration variable, if part of a map node.
        """
        ...

    @abstractmethod
    def fan_out(self, map_variable: TypeMapVariable = None, **kwargs):
        """
        This function would be called to set up the execution of the individual
        branches of a composite node.

        Function should only be implemented for composite nodes like dag, map, parallel.

        Args:
            map_variable (dict, optional): The value of the map iteration variable, if part of a map node.

        Raises:
            Exception: If the node is not a composite node.
        """
        ...

    @abstractmethod
    def fan_in(self, map_variable: TypeMapVariable = None, **kwargs):
        """
        This function would be called to tear down the execution of the individual
        branches of a composite node.

        Function should only be implemented for composite nodes like dag, map, parallel.

        Args:
            map_variable (dict, optional): The value of the map iteration variable, if part of a map node.

        Raises:
            Exception: If the node is not a composite node.
        """
        ...

    @classmethod
    @abstractmethod
    def parse_from_config(cls, config: Dict[str, Any]) -> "BaseNode":
        """
        Parse the config from the user and create the corresponding node.

        Args:
            config (Dict[str, Any]): The config of the node from the yaml or from the sdk.

        Returns:
            BaseNode: The corresponding node.
        """
        ...
364
+
365
+
366
+ # --8<-- [end:docs]
367
class TraversalNode(BaseNode):
    """
    A node that takes part in graph traversal: it knows its successor, an
    optional failure handler and per-executor override names.
    """

    next_node: str = Field(serialization_alias="next")
    on_failure: str = Field(default="")
    overrides: Dict[str, str] = Field(default_factory=dict)

    def _get_on_failure_node(self) -> str:
        """
        Name of the failure-handler node, relative to the dag.

        Returns:
            str: The configured on_failure node, or '' when none was given.
        """
        return self.on_failure

    def _get_next_node(self) -> str:
        """
        Name of the successor node, relative to the dag.

        Returns:
            str: The node name as defined by the config.
        """
        return self.next_node

    def _is_terminal_node(self) -> bool:
        """
        Traversal nodes always have a successor, so they are never terminal.

        Returns:
            bool: Always False.
        """
        return False

    def _get_executor_config(self, executor_type) -> str:
        """
        Override name configured for the given executor type.

        Args:
            executor_type (str): The executor type to look up.

        Returns:
            str: The override name, or '' when absent or empty.
        """
        override = self.overrides.get(executor_type)
        return override if override else ""
405
+
406
+
407
class CatalogStructure(BaseModel):
    # Reject unknown keys so typos in user supplied catalog config fail fast.
    model_config = ConfigDict(extra="forbid")  # Need to forbid

    # Entries to fetch from the catalog before execution — presumably file name
    # patterns; confirm against the catalog implementation.
    get: List[str] = Field(default_factory=list)
    # Entries to publish to the catalog after execution.
    put: List[str] = Field(default_factory=list)
412
+
413
+
414
class ExecutableNode(TraversalNode):
    """
    A leaf node that actually runs work; carries catalog settings and a retry
    budget. The composite-only hooks below are invalid here and raise.
    """

    catalog: Optional[CatalogStructure] = Field(default=None)
    max_attempts: int = Field(default=1, ge=1)

    def _get_catalog_settings(self) -> Dict[str, Any]:
        """
        Catalog configuration of this node as a plain dict.

        Returns:
            dict: The catalog settings, or an empty dict when none configured.
        """
        if self.catalog is None:
            return {}
        return self.catalog.model_dump()

    def _get_max_attempts(self) -> int:
        """Maximum number of attempts configured for this node (at least 1)."""
        return self.max_attempts

    def _get_branch_by_name(self, branch_name: str):
        raise Exception("This is an executable node and does not have branches")

    def execute_as_graph(self, map_variable: TypeMapVariable = None, **kwargs):
        raise Exception("This is an executable node and does not have a graph")

    def fan_in(self, map_variable: TypeMapVariable = None, **kwargs):
        raise Exception("This is an executable node and does not have a fan in")

    def fan_out(self, map_variable: TypeMapVariable = None, **kwargs):
        raise Exception("This is an executable node and does not have a fan out")
443
+
444
+
445
class CompositeNode(TraversalNode):
    """
    A traversal node that is itself a graph (e.g. parallel, dag, map).

    Composite nodes delegate work to their branches; the direct execution
    hooks below are therefore invalid and raise.
    """

    def _get_catalog_settings(self) -> Dict[str, Any]:
        """
        Composite nodes cannot carry catalog settings; configure them on the
        executable nodes inside the branches instead.

        Raises:
            Exception: Always.
        """
        raise Exception("This is a composite node and does not have a catalog settings")

    def _get_max_attempts(self) -> int:
        # Retries apply to executable nodes, not to the composite wrapper.
        raise Exception("This is a composite node and does not have a max_attempts")

    def execute(
        self,
        mock=False,
        params: Optional[Dict[str, Any]] = None,
        map_variable: TypeMapVariable = None,
        **kwargs,
    ) -> StepAttempt:
        # Composite nodes run via execute_as_graph/fan_out/fan_in, never directly.
        raise Exception("This is a composite node and does not have an execute function")
466
+
467
+
468
class TerminalNode(BaseNode):
    """
    Common behaviour for nodes that end a graph (success / fail).

    Almost every traversal hook is invalid for a terminal node and signals
    TerminalNodeError so callers (e.g. _get_neighbors) can detect the end.
    """

    def _get_on_failure_node(self) -> str:
        raise exceptions.TerminalNodeError()

    def _get_next_node(self) -> str:
        raise exceptions.TerminalNodeError()

    def _is_terminal_node(self) -> bool:
        # Terminal by definition.
        return True

    def _get_catalog_settings(self) -> Dict[str, Any]:
        raise exceptions.TerminalNodeError()

    def _get_branch_by_name(self, branch_name: str):
        raise exceptions.TerminalNodeError()

    def _get_executor_config(self, executor_type) -> str:
        raise exceptions.TerminalNodeError()

    def _get_max_attempts(self) -> int:
        # Terminal nodes run exactly once; no retry semantics.
        return 1

    def execute_as_graph(self, map_variable: TypeMapVariable = None, **kwargs):
        raise exceptions.TerminalNodeError()

    def fan_in(self, map_variable: TypeMapVariable = None, **kwargs):
        raise exceptions.TerminalNodeError()

    def fan_out(self, map_variable: TypeMapVariable = None, **kwargs):
        raise exceptions.TerminalNodeError()

    @classmethod
    def parse_from_config(cls, config: Dict[str, Any]) -> "TerminalNode":
        # Terminal nodes are simple enough to build straight from the config mapping.
        return cls(**config)
runnable/parameters.py ADDED
@@ -0,0 +1,183 @@
1
+ import inspect
2
+ import json
3
+ import logging
4
+ import os
5
+ from typing import Any, Dict, Optional, Type, Union
6
+
7
+ from pydantic import BaseModel, ConfigDict
8
+ from typing_extensions import Callable
9
+
10
+ from runnable import defaults
11
+ from runnable.defaults import TypeMapVariable
12
+ from runnable.utils import remove_prefix
13
+
14
+ logger = logging.getLogger(defaults.LOGGER_NAME)
15
+
16
+
17
def get_user_set_parameters(remove: bool = False) -> Dict[str, Any]:
    """
    Scans the environment variables for any user returned parameters that have a prefix runnable_PRM_.

    This function does not deal with any type conversion of the parameters.
    It just deserializes the parameters and returns them as a dictionary.

    Args:
        remove (bool, optional): Flag to remove the parameter from the environment if needed. Defaults to False.

    Returns:
        dict: The dictionary of found user returned parameters
    """
    parameters: Dict[str, Any] = {}
    # Iterate over a snapshot: deleting keys from os.environ while iterating the
    # live items() view (remove=True) raises RuntimeError on Python < 3.9.
    for env_var, value in list(os.environ.items()):
        if not env_var.startswith(defaults.PARAMETER_PREFIX):
            continue

        key = remove_prefix(env_var, defaults.PARAMETER_PREFIX)
        try:
            # Values are stored JSON-encoded; fall back to the raw string on failure.
            parameters[key.lower()] = json.loads(value)
        except json.decoder.JSONDecodeError:
            logger.error(f"Parameter {key} could not be JSON decoded, adding the literal value")
            parameters[key.lower()] = value

        if remove:
            del os.environ[env_var]
    return parameters
43
+
44
+
45
def set_user_defined_params_as_environment_variables(params: Dict[str, Any]):
    """
    Sets the user set parameters as environment variables.

    At this point in time, the params are already in Dict or some kind of literal.

    Args:
        params (Dict[str, Any]): The parameters to set as environment variables
    """
    for key, value in params.items():
        logger.info(f"Storing parameter {key} with value: {value}")
        # Prefix marks the variable as a runnable parameter; the value is JSON-serialized.
        os.environ[defaults.PARAMETER_PREFIX + key] = serialize_parameter_as_str(value)
61
+
62
+
63
def cast_parameters_as_type(value: Any, newT: Optional[Type] = None) -> Union[Any, BaseModel, Dict[str, Any]]:
    """
    Casts the environment variable to the given type.

    Note: Only pydantic models special, everything else just goes through.

    Args:
        value (Any): The value to cast
        newT (T): The type to cast to; None returns the value unchanged.

    Returns:
        T: The casted value

    Examples:
        >>> class MyBaseModel(BaseModel):
        ...     a: int
        ...     b: str
        >>>
        >>> cast_parameters_as_type({"a": 1, "b": "2"}, MyBaseModel)
        MyBaseModel(a=1, b="2")
        >>> cast_parameters_as_type({"a": 1, "b": "2"}, dict)
        {'a': 1, 'b': '2'}
        >>> cast_parameters_as_type("3", None)
        '3'
    """
    if not newT:
        # No target type requested: hand the value back untouched.
        return value

    # Subscripted generics (e.g. Dict[str, int]) are not classes, so passing
    # them to issubclass() raises TypeError and instantiating them fails on
    # 3.9+. Fall back to their runtime origin (dict for Dict[str, int]).
    target = getattr(newT, "__origin__", newT)

    if isinstance(target, type) and issubclass(target, BaseModel):
        # Pydantic models are built from the mapping's key/value pairs.
        return target(**value)

    if isinstance(target, type) and issubclass(target, dict):
        # Any dict-like target gets a plain dict copy.
        return dict(value)

    if type(value) != target:
        logger.warning(f"Casting {value} of {type(value)} to {newT} seems wrong!!")

    return target(value)
115
+
116
+
117
def serialize_parameter_as_str(value: Any) -> str:
    """
    Serialize a parameter value to a JSON string.

    Pydantic models are dumped to a plain dict first so they JSON-encode cleanly.
    """
    payload = value.model_dump() if isinstance(value, BaseModel) else value
    return json.dumps(payload)
122
+
123
+
124
def filter_arguments_for_func(
    func: Callable[..., Any], params: Dict[str, Any], map_variable: TypeMapVariable = None
) -> Dict[str, Any]:
    """
    Inspects the function to be called as part of the pipeline to find the arguments of the function.
    Matches the function arguments to the parameters available either by command line or by up stream steps.

    Args:
        func (Callable): The function to inspect
        params (dict): The parameters available for the run
        map_variable (dict, optional): Map-iteration values; take precedence over params of the same name.

    Returns:
        dict: The parameters matching the function signature

    Raises:
        ValueError: If a required argument (one with no default) has no matching parameter.
    """
    function_args = inspect.signature(func).parameters

    # Update parameters with the map variables
    # NOTE(review): this mutates the caller's `params` dict in place — confirm callers expect that.
    params.update(map_variable or {})

    unassigned_params = set(params.keys())
    bound_args = {}
    for name, value in function_args.items():
        if name not in params:
            # No parameter of this name was provided
            if value.default == inspect.Parameter.empty:
                # No default value is given in the function signature. error as parameter is required.
                raise ValueError(f"Parameter {name} is required for {func.__name__} but not provided")
            # default value is given in the function signature, nothing further to do.
            continue

        # NOTE(review): issubclass raises TypeError for non-class annotations
        # (e.g. Optional[int] or a missing annotation) — assumes task functions
        # always annotate with plain classes; verify against callers.
        if issubclass(value.annotation, BaseModel):
            # We try to cast it as a pydantic model.
            named_param = params[name]

            if not isinstance(named_param, dict):
                # A case where the parameter is a one attribute model
                named_param = {name: named_param}

            bound_model = bind_args_for_pydantic_model(named_param, value.annotation)
            bound_args[name] = bound_model
            # The model consumed these keys; drop them from the unassigned pool.
            unassigned_params = unassigned_params.difference(bound_model.model_fields.keys())
        else:
            # simple python data type.
            bound_args[name] = cast_parameters_as_type(params[name], value.annotation)  # type: ignore

        # NOTE(review): if a model field name equals the argument name, the
        # difference above already discarded it and this remove raises KeyError — verify.
        unassigned_params.remove(name)

    params = {key: params[key] for key in unassigned_params}  # remove keys from params if they are assigned

    return bound_args
175
+
176
+
177
def bind_args_for_pydantic_model(params: Dict[str, Any], model: Type[BaseModel]) -> BaseModel:
    """
    Build an instance of `model` from `params`, silently dropping extra keys.

    A throwaway subclass with extra="ignore" swallows keys the model does not
    declare; the surviving fields are then re-validated through the real model.
    """

    class EasyModel(model):  # type: ignore
        model_config = ConfigDict(extra="ignore")

    lenient = EasyModel(**params)
    return model(**lenient.model_dump())