dstklib 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dstk/workflow_tools.py ADDED
@@ -0,0 +1,257 @@
1
+ from functools import wraps
2
+ import warnings
3
+ import inspect
4
+ from copy import deepcopy
5
+
6
+ from .lib_types.dstk_types import Function, MethodSpec
7
+ from inspect import Signature, BoundArguments
8
+ from typing import Any, cast, Callable, Type, TypeAlias, TypeGuard
9
+
10
class WorkflowManager:
    """
    Manages the execution of processing methods in workflow mode.

    Tracks workflow state, controls stage transitions, and stores intermediate
    results for chained method execution with enforced sequencing and unit context.
    """

    def __init__(self) -> None:
        """
        Initializes WorkflowManager.

        Only ``_called_methods`` is given a value here. The other attributes are declared for type checkers and receive their values later (for example in ``_set_workflow``).
        """

        # Annotation-only declarations: nothing is bound at runtime by these lines.
        self._flow: bool
        self._current_stage: str
        self._processing_unit: str

        self._stages: list[str]

        self._start: Any
        self._end: Any

        # Names of the methods already executed on this instance.
        self._called_methods: list[str] = []

    def _set_workflow(self, input_arg: Any) -> None:
        """
        Initializes workflow mode based on the presence of the input argument.

        If ``input_arg`` is not None, it is stored as the starting point, the current stage is set to "start" and workflow mode is enabled. Otherwise workflow mode is disabled.

        :param input_arg: The initial data of the workflow, or None to disable workflow mode.
        """

        if input_arg is not None:
            self._start = input_arg
            self._current_stage = "start"
            self._flow = True
        else:
            self._flow = False

    @property
    def result(self) -> Any:
        """
        Returns the current output of the processing workflow.

        Use this property to retrieve the final result after a chain of workflow method calls. The value stored for the current stage is deep-copied when possible to prevent side effects; if copying fails, the stored object itself is returned.

        :return: The result of the most recent workflow stage.
        """

        # Stage results are stored on the instance as "_<stage name>".
        result: Any = getattr(self, f"_{self._current_stage}")
        try:
            return deepcopy(result)
        except Exception:
            # Best-effort fallback for objects that cannot be copied. Catching
            # Exception (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
            return result
68
+
69
+
70
class WorkflowBuilder:
    """
    Automates the execution of a sequence of methods on a WorkflowManager subclass.

    :param work_class: A subclass of WorkflowManager representing the workflow to execute.
    :param method_representation: A dictionary mapping method names to their keyword arguments.
    :param result: If True, returns the result of the workflow. Else, returns the instance of the working class. Defaults to True.
    """

    def __init__(self, work_class: Type[WorkflowManager], method_representation: MethodSpec, result: bool = True) -> None:
        """
        Initializes WorkflowBuilder with given attributes.
        """

        self.work_class: Type[WorkflowManager] = work_class
        self.methods: MethodSpec = method_representation
        self.result: bool = result

    def __call__(self, *args, **kwargs) -> Any:
        """
        Instantiates the working class and runs the configured methods on it.

        The methods are called in the iteration order of the method specification, each with its mapped keyword arguments.

        :param args: Positional arguments forwarded to the working class constructor.
        :param kwargs: Keyword arguments forwarded to the working class constructor.

        :return: The ``result`` of the instance if ``self.result`` is True, otherwise the instance itself.
        """

        # Named "instance" (not "workflow") so the module-level workflow decorator is not shadowed.
        instance: WorkflowManager = self.work_class(*args, **kwargs)

        for method_name, method_kwargs in self.methods.items():
            method: Callable = getattr(instance, method_name)
            method(**method_kwargs)

        return instance.result if self.result else instance
96
+
97
+
98
+ def workflow(input_arg: str, input_process: str, output_process: str, input_attrs: dict[str, Any] | None = None, next_stage: str | None = None, set_unit: str | None = None,) -> Callable[[Function], Function]:
99
+ """
100
+ Enables workflow execution for a method by automatically injecting inputs,
101
+ storing outputs, and transitioning stages when in workflow mode.
102
+
103
+ :param input_arg: Name of the keyword argument to inject into the method.
104
+ :param input_process: Attribute name to retrieve the input data from if not provided.
105
+ :param output_process: Attribute name to store the method's output in workflow mode.
106
+ :param input_attrs: Optional mapping of argument names to extract from input data. Supports str for attribute access or nested dicts for deep lookup. Defaults to None.
107
+ :param next_stage: Optional name of the next workflow stage to transition to after method execution. Defaults to None.
108
+ :param set_unit: Optional name of the processing unit to set for the next workflow step. Defaults to None.
109
+
110
+ :return: A decorator that wraps the method with workflow logic.
111
+ :raises ValueError: If the user is not in workflow mode and he did not passed thte input arg. Also, if the value of input_attrs is different from None, str or dict.
112
+ """
113
+
114
+ def decorator(method: Function) -> Function:
115
+ @wraps(method)
116
+ def wrapper(self, *args, **kwargs) -> Any:
117
+ method_name: str = method.__name__
118
+
119
+ if input_arg not in kwargs:
120
+ if not self._flow:
121
+ raise ValueError(f"{input_arg} must be provided if not using workflow mode")
122
+
123
+ input_data: Any = getattr(self, input_process)
124
+
125
+ if input_attrs:
126
+ for key, value in input_attrs.items():
127
+ if value is None:
128
+ kwargs[key] = input_data
129
+ elif isinstance(value, str):
130
+ kwargs[key] = getattr(input_data, value)
131
+ elif isinstance(value, dict):
132
+ for attr_name, subattr in value.items():
133
+ mapping = getattr(input_data, attr_name)[subattr]
134
+ kwargs[key] = mapping
135
+ else:
136
+ raise ValueError(f"Type {type(value)} of value bot recognized")
137
+ else:
138
+ kwargs[input_arg] = input_data
139
+
140
+ result: Any = method(self, *args, **kwargs)
141
+
142
+ if self._flow:
143
+ setattr(self, output_process, result)
144
+ if next_stage:
145
+ self._current_stage = next_stage
146
+ if set_unit:
147
+ self._processing_unit = set_unit
148
+ if next_stage == "end":
149
+ warnings.warn(UserWarning(f"After calling method {method_name} you must necessarily call result to continue with analysis. Further chaining will result in error."))
150
+
151
+ self._called_methods.append(method_name)
152
+ return self
153
+
154
+ return result
155
+
156
+ return cast(Function, wrapper)
157
+
158
+ return decorator
159
+
160
+ def requires(stages: list[str], unit: str | None = None, multiple_calls: bool = False) -> Callable[[Function], Function]:
161
+ """
162
+ Ensures a method is only callable in workflow mode at allowed stages and units,
163
+ and prevents repeated calls to the same method.
164
+
165
+ :param stages: A list of stages where the method is allowed to be used.
166
+ :param unit: The required processing unit; if specified, the method can only run when the current unit matches. Defaults to None.
167
+
168
+ :return: A decorator that enforces workflow constraints on the wrapped method.
169
+ """
170
+
171
+ def decorator(method: Function) -> Function:
172
+ @wraps(method)
173
+ def wrapper(self, *args, **kwargs) -> Any:
174
+ method_name: str = method.__name__
175
+
176
+ if self._flow:
177
+ if method_name in self._called_methods and not multiple_calls:
178
+ raise RuntimeError(f"Method {method_name} already called. You can only call each method exactly once in workflow mode.")
179
+ if not hasattr(self, "_current_stage"):
180
+ raise RuntimeError("Current phase is not initialized.")
181
+ if self._current_stage not in stages:
182
+ raise RuntimeError(
183
+ f"Method '{method_name}' requires stages '{', '.join(stages)}' but current phase is '{self._current_stage}'"
184
+ )
185
+ if unit and unit != self._processing_unit:
186
+ raise RuntimeError(
187
+ f"Method '{method_name}' can only be used when the you are processing by {unit}, but you are currently processing by {self._processing_unit}."
188
+ )
189
+
190
+ return method(self, *args, **kwargs)
191
+
192
+ return cast(Function, wrapper)
193
+
194
+ return decorator
195
+
196
+ def accepts_generic(*, type_checker: Callable, input_arg: str, accepts: bool, intercept: bool, interceptor: Callable, input_type: TypeAlias, custom_error_message: str = "") -> Callable[[Function], Function]:
197
+ """
198
+ A generic decorator factory that conditionally intercepts and transforms a method's input based on runtime type checks.
199
+
200
+ This utility enables creating flexible, reusable decorators (like `accepts_sentences` or `accepts_tags`) by delegating the interception logic to a custom `interceptor` and input validation to a `type_checker`.
201
+
202
+ If the input value (specified by `input_arg`) passes the `type_checker`, and both `accepts` and `intercept` are True, the `interceptor` is called instead of the original method. Otherwise, the original method is called directly. If the input is of the expected type but `accepts` is False, a `ValueError` is raised.
203
+
204
+ :param type_checker: Function that checks whether the input value is of the expected structure/type.
205
+ :param input_arg: The name of the argument to inspect in the decorated method.
206
+ :param accepts: Whether the method is allowed to handle this type of input.
207
+ :param intercept: Whether the input should be transformed before calling the method.
208
+ :param interceptor: Function that handles the interception logic, replacing the original method call when triggered.
209
+ :param input_type: A human-readable type description used in error messages and type casts. This is for documentation/static typing purposes only.
210
+ :param custom_error_message: Optional custom message to append to error if input is rejected.
211
+
212
+ :returns: Callable: A decorator that wraps the target method with conditional input handling logic.
213
+
214
+ :raises: ValueError: If the input matches the expected type but `accepts` is False.
215
+ """
216
+
217
+ def decorator(method: Function) -> Function:
218
+ @wraps(method)
219
+ def wrapper(self, *args, **kwargs) -> Any:
220
+ signature: Signature = inspect.signature(method)
221
+ bound_args: BoundArguments = signature.bind(self, *args, **kwargs)
222
+ bound_args.apply_defaults()
223
+
224
+ input_value: Any = bound_args.arguments.get(input_arg, None)
225
+
226
+ if type_checker(input_value) and (accepts and intercept):
227
+ filtered_kwargs: dict[str, Any] = kwargs.copy()
228
+ filtered_kwargs.pop(input_arg, None)
229
+
230
+ return interceptor(self, input_value=input_value, method=method, *args, **filtered_kwargs)
231
+ elif not type_checker(input_value) or (accepts and not intercept):
232
+ return cast(input_type, method(self, *args, **kwargs))
233
+ else:
234
+ raise ValueError(f"Method {method.__name__} does not accept {input_type} as input. {custom_error_message}")
235
+
236
+ return cast(Function, wrapper)
237
+ return decorator
238
+
239
+ def is_method_spec(spec: Any) -> TypeGuard[MethodSpec]:
240
+ """
241
+ If spec is of type MethodSpec, returns True. Else, returns False.
242
+
243
+ :param spec: A dict to check its type.
244
+ """
245
+
246
+ return (
247
+ isinstance(spec, dict) and
248
+ all(
249
+ isinstance(method_name, str) and
250
+ isinstance(params, dict) and
251
+ all(
252
+ isinstance(key, str)
253
+ for key in params.keys()
254
+ )
255
+ for method_name, params in spec.items()
256
+ )
257
+ )