flowrep 0.post0.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flowrep/__init__.py +3 -0
- flowrep/_version.py +716 -0
- flowrep/workflow.py +1256 -0
- flowrep-0.post0.dev1.dist-info/METADATA +59 -0
- flowrep-0.post0.dev1.dist-info/RECORD +8 -0
- flowrep-0.post0.dev1.dist-info/WHEEL +5 -0
- flowrep-0.post0.dev1.dist-info/licenses/LICENSE +29 -0
- flowrep-0.post0.dev1.dist-info/top_level.txt +1 -0
flowrep/workflow.py
ADDED
|
@@ -0,0 +1,1256 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import builtins
|
|
3
|
+
import copy
|
|
4
|
+
import dataclasses
|
|
5
|
+
import inspect
|
|
6
|
+
import textwrap
|
|
7
|
+
from collections import deque
|
|
8
|
+
from collections.abc import Callable, Iterable
|
|
9
|
+
from functools import cached_property, update_wrapper
|
|
10
|
+
from typing import Any, Generic, TypeVar, cast, get_args, get_origin
|
|
11
|
+
|
|
12
|
+
import networkx as nx
|
|
13
|
+
from networkx.algorithms.dag import topological_sort
|
|
14
|
+
from semantikon.converter import (
|
|
15
|
+
get_annotated_type_hints,
|
|
16
|
+
get_return_expressions,
|
|
17
|
+
get_return_labels,
|
|
18
|
+
meta_to_dict,
|
|
19
|
+
parse_input_args,
|
|
20
|
+
parse_output_args,
|
|
21
|
+
)
|
|
22
|
+
from semantikon.datastructure import (
|
|
23
|
+
MISSING,
|
|
24
|
+
CoreMetadata,
|
|
25
|
+
Edges,
|
|
26
|
+
Function,
|
|
27
|
+
Input,
|
|
28
|
+
Inputs,
|
|
29
|
+
Missing,
|
|
30
|
+
Nodes,
|
|
31
|
+
Output,
|
|
32
|
+
Outputs,
|
|
33
|
+
PortType,
|
|
34
|
+
TypeMetadata,
|
|
35
|
+
Workflow,
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
F = TypeVar("F", bound=Callable[..., object])


class FunctionWithWorkflow(Generic[F]):
    """Callable wrapper that pairs a function with its workflow representation.

    Instances behave exactly like the wrapped function when called, carry the
    workflow dictionary in ``_semantikon_workflow``, and forward any unknown
    attribute lookup to the wrapped function.
    """

    def __init__(self, func: F, workflow: dict[str, object], run) -> None:
        self._semantikon_workflow: dict[str, object] = workflow
        self.run = run
        self.func = func
        # Copies __name__, __doc__, __module__, etc. from the wrapped function
        update_wrapper(self, func)

    def __call__(self, *args, **kwargs):
        """Delegate the call directly to the wrapped function."""
        return self.func(*args, **kwargs)

    def __getattr__(self, item):
        """Fall back to the wrapped function for attributes not set here."""
        return getattr(self.func, item)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def separate_types(
|
|
56
|
+
data: dict[str, Any], class_dict: dict[str, type] | None = None
|
|
57
|
+
) -> tuple[dict[str, Any], dict[str, type]]:
|
|
58
|
+
"""
|
|
59
|
+
Separate types from the data dictionary and store them in a class dictionary.
|
|
60
|
+
The types inside the data dictionary will be replaced by their name (which
|
|
61
|
+
would for example make it easier to hash it).
|
|
62
|
+
|
|
63
|
+
Args:
|
|
64
|
+
data (dict[str, Any]): The data dictionary containing nodes and types.
|
|
65
|
+
class_dict (dict[str, type], optional): A dictionary to store types. It
|
|
66
|
+
is mainly used due to the recursivity of this function. Defaults to
|
|
67
|
+
None.
|
|
68
|
+
|
|
69
|
+
Returns:
|
|
70
|
+
tuple: A tuple containing the modified data dictionary and the
|
|
71
|
+
class dictionary.
|
|
72
|
+
"""
|
|
73
|
+
data = copy.deepcopy(data)
|
|
74
|
+
if class_dict is None:
|
|
75
|
+
class_dict = {}
|
|
76
|
+
if "nodes" in data:
|
|
77
|
+
for key, node in data["nodes"].items():
|
|
78
|
+
child_node, child_class_dict = separate_types(node, class_dict)
|
|
79
|
+
class_dict.update(child_class_dict)
|
|
80
|
+
data["nodes"][key] = child_node
|
|
81
|
+
for io_ in ["inputs", "outputs"]:
|
|
82
|
+
for key, content in data[io_].items():
|
|
83
|
+
if "dtype" in content and isinstance(content["dtype"], type):
|
|
84
|
+
class_dict[content["dtype"].__name__] = content["dtype"]
|
|
85
|
+
data[io_][key]["dtype"] = content["dtype"].__name__
|
|
86
|
+
return data, class_dict
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def separate_functions(
|
|
90
|
+
data: dict[str, Any], function_dict: dict[str, Callable] | None = None
|
|
91
|
+
) -> tuple[dict[str, Any], dict[str, Callable]]:
|
|
92
|
+
"""
|
|
93
|
+
Separate functions from the data dictionary and store them in a function
|
|
94
|
+
dictionary. The functions inside the data dictionary will be replaced by
|
|
95
|
+
their name (which would for example make it easier to hash it)
|
|
96
|
+
|
|
97
|
+
Args:
|
|
98
|
+
data (dict[str, Any]): The data dictionary containing nodes and
|
|
99
|
+
functions.
|
|
100
|
+
function_dict (dict[str, Callable], optional): A dictionary to store
|
|
101
|
+
functions. It is mainly used due to the recursivity of this
|
|
102
|
+
function. Defaults to None.
|
|
103
|
+
|
|
104
|
+
Returns:
|
|
105
|
+
tuple: A tuple containing the modified data dictionary and the
|
|
106
|
+
function dictionary.
|
|
107
|
+
"""
|
|
108
|
+
data = copy.deepcopy(data)
|
|
109
|
+
if function_dict is None:
|
|
110
|
+
function_dict = {}
|
|
111
|
+
if "nodes" in data:
|
|
112
|
+
for key, node in data["nodes"].items():
|
|
113
|
+
child_node, child_function_dict = separate_functions(node, function_dict)
|
|
114
|
+
function_dict.update(child_function_dict)
|
|
115
|
+
data["nodes"][key] = child_node
|
|
116
|
+
elif "function" in data and not isinstance(data["function"], str):
|
|
117
|
+
fnc_object = data["function"]
|
|
118
|
+
as_string = fnc_object.__module__ + "." + fnc_object.__qualname__
|
|
119
|
+
function_dict[as_string] = fnc_object
|
|
120
|
+
data["function"] = as_string
|
|
121
|
+
if "test" in data and not isinstance(data["test"]["function"], str):
|
|
122
|
+
fnc_object = data["test"]["function"]
|
|
123
|
+
as_string = fnc_object.__module__ + fnc_object.__qualname__
|
|
124
|
+
function_dict[as_string] = fnc_object
|
|
125
|
+
data["test"]["function"] = as_string
|
|
126
|
+
return data, function_dict
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
class FunctionDictFlowAnalyzer:
    """Build a variable-flow graph from a dict-converted function AST.

    The analyzer walks the AST dictionary produced by ``get_ast_dict`` and
    records in a ``networkx`` DiGraph how values flow from function inputs
    through function calls to the return statement.  Variables are versioned
    as ``<name>_<index>`` so that re-assignments create new graph nodes.
    Control-flow constructs (If/While/For whose test/iter is a function call)
    are tagged on edges via a ``control_flow`` edge attribute.
    """

    def __init__(self, ast_dict, scope):
        # Directed graph mixing versioned-variable nodes and call nodes.
        self.graph = nx.DiGraph()
        self.scope = scope  # mapping from function names to objects
        # unique call name -> {"function": obj[, "control_flow": tag]}
        self.function_defs = {}
        # dict representation of the FunctionDef being analyzed
        self.ast_dict = ast_dict
        # per-base-name counter making call names unique (f_0, f_1, ...)
        self._call_counter = {}
        # every control-flow tag handed out so far (keeps tags unique)
        self._control_flow_list = []
        # variable -> parallel counterparts produced in mutually exclusive
        # If/Else branches; consumed in _add_input_edge
        self._parallel_var = {}

    def analyze(self) -> tuple[nx.DiGraph, dict[str, Any]]:
        """Walk the function body; return (flow graph, function definitions)."""
        # Each function argument becomes an output of the virtual "input" node.
        for arg in self.ast_dict.get("args", {}).get("args", []):
            if arg["_type"] == "arg":
                self._add_output_edge("input", arg["arg"])
        return_was_called = False
        for node in self.ast_dict.get("body", []):
            # No statement may follow a Return in the analyzed body.
            assert not return_was_called
            self._visit_node(node)
            if node["_type"] == "Return":
                return_was_called = True
        return self.graph, self.function_defs

    def _visit_node(self, node, control_flow: str | None = None):
        """Dispatch a single statement node to its handler."""
        if node["_type"] == "Assign":
            self._handle_assign(node, control_flow=control_flow)
        elif node["_type"] == "Expr":
            self._handle_expr(node, control_flow=control_flow)
        elif node["_type"] == "While":
            self._handle_while(node, control_flow=control_flow)
        elif node["_type"] == "For":
            self._handle_for(node, control_flow=control_flow)
        elif node["_type"] == "If":
            self._handle_if(node, control_flow=control_flow)
        elif node["_type"] == "Return":
            self._handle_return(node, control_flow=control_flow)
        else:
            raise NotImplementedError(f"Node type {node['_type']} not implemented")

    def _handle_return(self, node, control_flow: str | None = None):
        """Connect returned variables to the virtual "output" node.

        Only bare names (or tuples of bare names) may be returned.
        """
        if not node["value"]:
            return
        if node["value"]["_type"] == "Tuple":
            for idx, elt in enumerate(node["value"]["elts"]):
                if elt["_type"] != "Name":
                    raise NotImplementedError("Only variable returns supported")
                self._add_input_edge(elt, "output", input_index=idx)
        elif node["value"]["_type"] == "Name":
            self._add_input_edge(node["value"], "output")

    def _handle_if(self, node, control_flow: str | None = None):
        """Handle an ``if`` whose test is a function call.

        The else branch gets a parallel "Else" tag; variables assigned in
        both branches are registered as parallel so later consumers connect
        to both versions.
        """
        assert node["test"]["_type"] == "Call"
        control_flow = self._convert_control_flow(control_flow, tag="If")
        self._parse_function_call(node["test"], control_flow=f"{control_flow}-test")
        for n in node["body"]:
            self._visit_node(n, control_flow=f"{control_flow}-body")
        for n in node.get("orelse", []):
            # Same tag as the if-branch, but with "If" replaced by "Else"
            # in the last path component.
            cf_else = "/".join(
                control_flow.split("/")[:-1]
                + [control_flow.split("/")[-1].replace("If", "Else") + "-body"]
            )
            self._visit_node(n, control_flow=cf_else)
            self._reconnect_parallel(cf_else, f"{control_flow}-body")
            self._register_parallel_variables(cf_else, f"{control_flow}-body")

    def _reconnect_parallel(self, control_flow: str, ref_control_flow: str):
        """Rewire else-branch inputs that point at if-branch variable versions.

        An input edge inside ``control_flow`` must not consume a variable
        version produced inside the parallel ``ref_control_flow`` branch; it
        is re-attached to the preceding version instead.
        """
        all_edges = list(self.graph.edges.data())
        for edge in all_edges:
            if (
                "control_flow" not in edge[2]
                or edge[2]["control_flow"] != control_flow
                or edge[2]["type"] == "output"
            ):
                continue
            # Split "name_3" into ("name", 3).
            var, ind = "_".join(edge[0].split("_")[:-1]), int(edge[0].split("_")[-1])
            # NOTE(review): as written this loop body runs exactly once, so
            # the version index is decremented at most one step — confirm
            # whether repeated decrementing was intended.
            while True:
                if any(
                    [
                        e[2].get("control_flow") == ref_control_flow
                        for e in self.graph.in_edges(f"{var}_{ind}", data=True)
                    ]
                ):
                    ind -= 1
                break
            if f"{var}_{ind}" != edge[0]:
                self.graph.add_edge(f"{var}_{ind}", edge[1], **edge[2])
                self.graph.remove_edge(edge[0], edge[1])

    def _register_parallel_variables(self, control_flow: str, ref_control_flow: str):
        """Record variables written in BOTH parallel branches.

        For each base variable name written in both control flows, the
        higher-sorted version is mapped to the lower one(s) in
        ``_parallel_var``, so that a later consumer connects to all of them.
        """
        data: dict[str, dict] = {control_flow: {}, ref_control_flow: {}}
        for edge in self.graph.edges.data():
            if (
                edge[2].get("control_flow", "") in [control_flow, ref_control_flow]
                and edge[2]["type"] == "output"
            ):
                # base name (without version index) -> versioned node
                data[edge[2]["control_flow"]][edge[1].rsplit("_", 1)[0]] = edge[1]
        for key in set(data[control_flow].keys()).intersection(
            data[ref_control_flow].keys()
        ):
            values = sorted([data[control_flow][key], data[ref_control_flow][key]])
            self._parallel_var[values[-1]] = [values[0]]
            # Chain previously registered parallels onto the newest version.
            if values[0] in self._parallel_var:
                self._parallel_var[values[-1]].extend(self._parallel_var.pop(values[0]))

    def _handle_while(self, node, control_flow: str | None = None):
        """Handle a ``while`` loop whose test is a function call."""
        assert node["test"]["_type"] == "Call"
        control_flow = self._convert_control_flow(control_flow, tag="While")
        self._parse_function_call(node["test"], control_flow=f"{control_flow}-test")
        for n in node["body"]:
            self._visit_node(n, control_flow=f"{control_flow}-body")

    def _handle_for(self, node, control_flow: str | None = None):
        """Handle a ``for`` loop whose iterable is a function call.

        The loop target is treated as an output of the iterable call.
        """
        assert node["iter"]["_type"] == "Call"
        control_flow = self._convert_control_flow(control_flow, tag="For")

        unique_func_name = self._parse_function_call(
            node["iter"], control_flow=f"{control_flow}-iter"
        )
        self._parse_outputs(
            [node["target"]], unique_func_name, control_flow=control_flow
        )
        for n in node["body"]:
            self._visit_node(n, control_flow=f"{control_flow}-body")

    def _handle_expr(self, node, control_flow: str | None = None) -> str:
        """Handle a bare expression statement (must be a function call)."""
        value = node["value"]
        return self._parse_function_call(value, control_flow=control_flow)

    def _parse_function_call(self, value, control_flow: str | None = None) -> str:
        """Register a call node and wire its arguments as input edges.

        Returns the unique name assigned to this call (e.g. ``f_0``).
        """
        if value["_type"] != "Call":
            raise NotImplementedError(
                f"Only function calls allowed on RHS: {value['_type']}"
            )

        func_node = value["func"]
        if func_node["_type"] != "Name":
            raise NotImplementedError("Only simple functions allowed")

        func_name = func_node["id"]
        unique_func_name = self._get_unique_func_name(func_name)

        if func_name not in self.scope:
            raise ValueError(f"Function {func_name} not found in scope")

        self.function_defs[unique_func_name] = {"function": self.scope[func_name]}
        if control_flow is not None:
            self.function_defs[unique_func_name]["control_flow"] = control_flow

        # Parse inputs (positional + keyword)
        for i, arg in enumerate(value.get("args", [])):
            self._add_input_edge(
                arg, unique_func_name, input_index=i, control_flow=control_flow
            )
        for kw in value.get("keywords", []):
            self._add_input_edge(
                kw["value"],
                unique_func_name,
                input_name=kw["arg"],
                control_flow=control_flow,
            )
        return unique_func_name

    def _handle_assign(self, node, control_flow: str | None = None):
        """Handle ``targets = f(...)``: parse the call, then wire targets."""
        unique_func_name = self._handle_expr(node, control_flow=control_flow)
        self._parse_outputs(
            node["targets"], unique_func_name, control_flow=control_flow
        )

    def _parse_outputs(
        self, targets, unique_func_name, control_flow: str | None = None
    ):
        """Wire assignment targets as outputs of a call node.

        Tuple unpacking records each element's position via ``output_index``.
        """
        if len(targets) == 1 and targets[0]["_type"] == "Tuple":
            for idx, elt in enumerate(targets[0]["elts"]):
                self._add_output_edge(
                    unique_func_name,
                    elt["id"],
                    output_index=idx,
                    control_flow=control_flow,
                )
        else:
            for target in targets:
                self._add_output_edge(
                    unique_func_name, target["id"], control_flow=control_flow
                )

    def _get_max_index(self, variable: str) -> int:
        """Return the smallest version index of ``variable`` with no producer."""
        index = 0
        while True:
            if len(self.graph.in_edges(f"{variable}_{index}")) > 0:
                index += 1
                continue
            break
        return index

    def _get_var_index(self, variable: str, output: bool = False) -> int:
        """Version index for ``variable``: next free one when writing
        (``output=True``), latest defined one when reading.

        Raises:
            KeyError: when reading a variable that was never defined.
        """
        index = self._get_max_index(variable)
        if index == 0 and not output:
            raise KeyError(
                f"Variable {variable} not found in graph. "
                "This usually means that the variable was never defined."
            )
        if output:
            return index
        else:
            return index - 1

    def _add_output_edge(
        self, source: str, target: str, control_flow: str | None = None, **kwargs
    ):
        """
        Add an output edge from the source to the target variable.

        Args:
            source (str): The source node (function name).
            target (str): The target variable name.
            control_flow (str | None): The control flow tag, if any.
            **kwargs: Additional keyword arguments to pass to the edge.

        In the case of the following line:

        >>> y = f(x)

        This function will add an edge from the function `f` to the variable `y`.
        """
        versioned = f"{target}_{self._get_var_index(target, output=True)}"
        if control_flow is not None:
            kwargs["control_flow"] = control_flow
        self.graph.add_edge(source, versioned, type="output", **kwargs)

    def _add_input_edge(
        self, source: dict, target: str, control_flow: str | None = None, **kwargs
    ):
        """
        Add an input edge from the source variable to the target node.

        Args:
            source (dict): The source variable node.
            target (str): The target node (function name).
            control_flow (str | None): The control flow tag, if any.
            **kwargs: Additional keyword arguments to pass to the edge.

        In the case of the following line:

        >>> y = f(x)

        This function will add an edge from the variable `x` to the function `f`.
        """
        if source["_type"] != "Name":
            raise NotImplementedError(f"Only variable inputs supported, got: {source}")
        var_name = source["id"]
        if control_flow is not None:
            kwargs["control_flow"] = control_flow
        versioned = f"{var_name}_{self._get_var_index(var_name)}"
        self.graph.add_edge(versioned, target, type="input", **kwargs)
        # If the variable was written in parallel If/Else branches, the
        # consumer must also connect to the counterpart versions.
        if versioned in self._parallel_var:
            for key in self._parallel_var.pop(versioned):
                self.graph.add_edge(key, target, type="input", **kwargs)

    def _get_unique_func_name(self, base_name):
        """Return ``base_name`` suffixed with a per-name running counter."""
        i = self._call_counter.get(base_name, 0)
        self._call_counter[base_name] = i + 1
        return f"{base_name}_{i}"

    def _convert_control_flow(self, control_flow: str | None, tag: str) -> str:
        """Create a unique control-flow tag, nested under the parent flow.

        E.g. an ``If`` inside ``While_0-body`` becomes ``While_0/If_0``.
        """
        control_flow = "" if control_flow is None else f"{control_flow.split('-')[0]}/"
        counter = 0
        while True:
            if f"{control_flow}{tag}_{counter}" not in self._control_flow_list:
                self._control_flow_list.append(f"{control_flow}{tag}_{counter}")
                break
            counter += 1
        return f"{control_flow}{tag}_{counter}"
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
def _get_variables_from_subgraph(graph: nx.DiGraph, io_: str) -> set[str]:
    """
    Get variable node names from a graph based on the type of I/O edge.

    Args:
        graph (nx.DiGraph): The directed graph representing the function.
        io_ (str): The type of I/O to filter by ("input" or "output").

    Returns:
        set[str]: For "input", the source nodes of all input-typed edges;
            for "output", the target nodes of all output-typed edges.
    """
    assert io_ in ["input", "output"], "io_ must be 'input' or 'output'"
    # Input edges run variable -> function (variable at index 0);
    # output edges run function -> variable (variable at index 1).
    edge_ind = 0 if io_ == "input" else 1
    return {edge[edge_ind] for edge in graph.edges.data() if edge[2]["type"] == io_}
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
def _get_parent_graph(graph: nx.DiGraph, control_flow: str) -> nx.DiGraph:
    """
    Get parent body of the indented body.

    Keeps every edge of ``graph`` whose control-flow tag does not lie inside
    the given control flow, i.e. strips the nested body itself.

    Args:
        graph (nx.DiGraph): Full graph to look for the parent graph from
        control_flow (str): Control flow whose parent graph is to look for

    Returns:
        (nx.DiGraph): Parent graph
    """
    kept = []
    for edge in graph.edges.data():
        if not _get_control_flow(edge[2]).startswith(control_flow):
            kept.append(edge)
    return nx.DiGraph(kept)
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
def _detect_io_variables_from_control_flow(
    graph: nx.DiGraph, subgraph: nx.DiGraph
) -> dict[str, list]:
    """
    Detect input and output variables of a control-flow subgraph.

    Inputs are variables consumed inside the subgraph body that are produced
    in the parent graph; outputs are variables produced inside the body that
    are consumed in the parent graph.

    Args:
        graph (nx.DiGraph): The directed graph representing the function.
        subgraph (nx.DiGraph): The edges belonging to one control flow.

    Returns:
        dict[str, list]: A dictionary with keys "inputs" and "outputs", each
            containing a list of versioned variable names.

    Take a look at the unit tests for examples of how to use this function.
    """
    # Body of the subgraph without the virtual "input"/"output" nodes.
    sg_body = nx.DiGraph(
        [
            edge
            for edge in subgraph.edges.data()
            if edge[0] != "input" and edge[1] != "output"
        ]
    )
    cf = sorted(
        [
            _get_control_flow(edge[2])
            for edge in sg_body.edges.data()
            if "control_flow" in edge[2]
        ]
    )
    if len(cf) == 0:
        return {"inputs": [], "outputs": []}
    # All tags must be nested along a single chain (each a prefix of the next).
    assert all([cf[ii + 1].startswith(cf[ii]) for ii in range(len(cf) - 1)])
    parent_graph = _get_parent_graph(graph, cf[0])
    # Consumed inside the body...
    var_inp_1 = _get_variables_from_subgraph(graph=sg_body, io_="input")
    # ...and produced by the parent -> inputs of the subgraph.
    var_inp_2 = _get_variables_from_subgraph(graph=parent_graph, io_="output")
    # Consumed by the parent...
    var_out_1 = _get_variables_from_subgraph(graph=parent_graph, io_="input")
    # ...and produced inside the body -> outputs of the subgraph.
    var_out_2 = _get_variables_from_subgraph(graph=sg_body, io_="output")
    return {
        "inputs": list(var_inp_1.intersection(var_inp_2)),
        "outputs": list(var_out_1.intersection(var_out_2)),
    }
|
|
486
|
+
|
|
487
|
+
|
|
488
|
+
def _extract_control_flows(graph: nx.DiGraph) -> list[str]:
    """Return the distinct control-flow tags appearing on the graph's edges."""
    return list({_get_control_flow(attrs) for _, _, attrs in graph.edges.data()})
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
def _split_graphs_into_subgraphs(graph: nx.DiGraph) -> dict[str, nx.DiGraph]:
    """Group the edges of ``graph`` by control-flow tag, one subgraph per tag."""
    result: dict[str, nx.DiGraph] = {}
    for flow in _extract_control_flows(graph):
        matching = [
            edge for edge in graph.edges.data() if _get_control_flow(edge[2]) == flow
        ]
        result[flow] = nx.DiGraph(matching)
    return result
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
def _get_subgraphs(graph: nx.DiGraph, cf_graph: nx.DiGraph) -> dict[str, nx.DiGraph]:
    """
    Separate a flat graph into subgraphs nested by control flows

    Args:
        graph (nx.DiGraph): Flat workflow graph
        cf_graph (nx.DiGraph): Control flow graph (cf. _get_control_flow_graph)

    Returns:
        dict[str, nx.DiGraph]: Subgraphs
    """
    subgraphs = _split_graphs_into_subgraphs(graph)
    # Innermost control flows first (reverse topological order), so nested
    # bodies are packed up before their parents are wired.
    for key in list(topological_sort(cf_graph))[::-1]:
        subgraph = subgraphs[key]
        # Placeholder node representing this body inside its parent graph.
        node_name = "injected_" + key.replace("/", "_")
        io_ = _detect_io_variables_from_control_flow(graph, subgraph)
        for parent_graph_name in cf_graph.predecessors(key):
            parent_graph = subgraphs[parent_graph_name]
            # In the parent, the body appears as a single node with the
            # detected inputs and outputs.
            for inp in io_["inputs"]:
                parent_graph.add_edge(
                    inp, node_name, type="input", input_name=_remove_index(inp)
                )
            for out in io_["outputs"]:
                parent_graph.add_edge(
                    node_name, out, type="output", output_name=_remove_index(out)
                )
        # Inside the body, wire the detected I/O to the virtual
        # "input"/"output" nodes.
        for inp in io_["inputs"]:
            subgraph.add_edge("input", inp, type="output")
        for out in io_["outputs"]:
            subgraph.add_edge(out, "output", type="input")
    return subgraphs
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
def _extract_functions_from_graph(graph: nx.DiGraph) -> set:
    """Collect the function-node names of the graph.

    A function node is the source of an "output" edge (unless it is the
    virtual "input" node) or the target of an "input" edge (unless it is the
    virtual "output" node).
    """
    functions: set = set()
    for source, target, attrs in graph.edges.data():
        if attrs["type"] == "output" and source != "input":
            functions.add(source)
        elif attrs["type"] == "input" and target != "output":
            functions.add(target)
    return functions
|
|
546
|
+
|
|
547
|
+
|
|
548
|
+
def _get_control_flow_graph(control_flows: list[str]) -> nx.DiGraph:
    """
    Create a graph based on the control flows. The indentation level
    corresponds to the graph level. The higher level body is the parent node
    of the lower body.


    Args:
        control_flows (list[str]): All control flows present in a workflow

    Returns:
        nx.DiGraph: Control flow graph
    """
    pairs = []
    for flow in control_flows:
        if flow == "":
            continue
        # The parent of "a/b/c" is "a/b"; top-level flows hang off "".
        parent = flow.rsplit("/", 1)[0] if "/" in flow else ""
        pairs.append([parent, flow])
    graph = nx.DiGraph(pairs)
    if len(graph) == 0:
        # No tagged flows at all: keep a single root node.
        graph.add_node("")
    return graph
|
|
573
|
+
|
|
574
|
+
|
|
575
|
+
def _function_to_ast_dict(node):
|
|
576
|
+
if isinstance(node, ast.AST):
|
|
577
|
+
result = {"_type": type(node).__name__}
|
|
578
|
+
for field, value in ast.iter_fields(node):
|
|
579
|
+
result[field] = _function_to_ast_dict(value)
|
|
580
|
+
return result
|
|
581
|
+
elif isinstance(node, list):
|
|
582
|
+
return [_function_to_ast_dict(item) for item in node]
|
|
583
|
+
else:
|
|
584
|
+
return node
|
|
585
|
+
|
|
586
|
+
|
|
587
|
+
def get_ast_dict(func: Callable) -> dict:
    """Get the AST dictionary representation of a function."""
    # dedent so that methods/nested functions parse as top-level code
    source = textwrap.dedent(inspect.getsource(func))
    return _function_to_ast_dict(ast.parse(source))
|
|
592
|
+
|
|
593
|
+
|
|
594
|
+
def analyze_function(func: Callable) -> tuple[nx.DiGraph, dict[str, Any]]:
    """Extracts the variable flow graph from a function"""
    ast_dict = get_ast_dict(func)
    # Names resolvable from the function's module plus the builtins.
    scope = inspect.getmodule(func).__dict__ | vars(builtins)
    # body[0] is the FunctionDef node of the parsed source.
    return FunctionDictFlowAnalyzer(ast_dict["body"][0], scope).analyze()
|
|
600
|
+
|
|
601
|
+
|
|
602
|
+
def _get_node_outputs(func: Callable, counts: int | None = None) -> dict[str, dict]:
    """Build the output-port dictionary for a single function node.

    Args:
        func (Callable): Function whose return annotation and return
            expressions are inspected via the semantikon helpers.
        counts (int | None): Number of outputs expected by the surrounding
            graph; None means unconstrained.

    Returns:
        dict[str, dict]: Mapping from output name to its (possibly empty)
            annotation metadata.
    """
    # Keep tuple annotations together when exactly one output is expected.
    output_hints = parse_output_args(
        func, separate_tuple=(counts is None or counts > 1)
    )
    output_vars = get_return_expressions(func)
    if output_vars is None or len(output_vars) == 0:
        # No return expression -> no outputs.
        return {}
    if (counts is not None and counts == 1) or isinstance(output_vars, str):
        # Single output: use the returned variable's name when available,
        # otherwise fall back to the generic "output" label.
        if isinstance(output_vars, str):
            return {output_vars: cast(dict, output_hints)}
        else:
            return {"output": cast(dict, output_hints)}
    assert isinstance(output_vars, tuple), output_vars
    assert counts is None or len(output_vars) == counts, output_vars
    if output_hints == {}:
        # No annotations available: empty metadata per output variable.
        return {key: {} for key in output_vars}
    else:
        assert counts is None or len(output_hints) == counts
        return {key: hint for key, hint in zip(output_vars, output_hints, strict=False)}
|
|
621
|
+
|
|
622
|
+
|
|
623
|
+
def _get_output_counts(graph: nx.DiGraph) -> dict[str, int]:
    """
    Get the number of outputs for each node in the graph.

    Args:
        graph (nx.DiGraph): The directed graph representing the function.

    Returns:
        dict: A dictionary mapping node names to the number of outputs,
            excluding the virtual "input" node.
    """
    counts: dict[str, int] = {}
    for source, _, attrs in graph.edges.data():
        if attrs["type"] == "output":
            counts[source] = counts.get(source, 0) + 1
    # The virtual "input" node is not a function node.
    counts.pop("input", None)
    return counts
|
|
641
|
+
|
|
642
|
+
|
|
643
|
+
def _get_nodes(
    data: dict[str, dict],
    output_counts: dict[str, int],
    control_flow: None | str = None,
) -> dict[str, dict]:
    """Build node dictionaries for all function definitions.

    Args:
        data: Mapping from node label to {"function": callable, ...} as
            produced by the flow analyzer.
        output_counts: Expected number of outputs per label (from the graph).
        control_flow: Currently unused here — presumably reserved for nested
            control-flow handling; confirm before removing.

    Returns:
        dict[str, dict]: Mapping from label to the node dictionary.
    """
    result = {}
    for label, function in data.items():
        func = function["function"]
        if hasattr(func, "_semantikon_workflow"):
            # Nested workflow (function decorated as a workflow): reuse its
            # stored representation, but check the output count matches.
            if output_counts[label] != len(func._semantikon_workflow["outputs"]):
                raise ValueError(
                    f"{label} has {len(func._semantikon_workflow['outputs'])} outputs, "
                    f"but {output_counts[label]} expected"
                )
            data_dict = func._semantikon_workflow.copy()
            result[label] = data_dict
            result[label]["label"] = label
            if hasattr(func, "_semantikon_metadata"):
                result[label].update(func._semantikon_metadata)
        else:
            # Plain function node: derive inputs/outputs from its signature.
            result[label] = get_node_dict(
                function=func,
                inputs=parse_input_args(func),
                outputs=_get_node_outputs(func, output_counts.get(label, 1)),
            )
    return result
|
|
669
|
+
|
|
670
|
+
|
|
671
|
+
def _remove_index(s: str) -> str:
|
|
672
|
+
return "_".join(s.split("_")[:-1])
|
|
673
|
+
|
|
674
|
+
|
|
675
|
+
def _get_control_flow(data: dict[str, Any]) -> str:
|
|
676
|
+
"""
|
|
677
|
+
Get the control flow name
|
|
678
|
+
|
|
679
|
+
Args:
|
|
680
|
+
data (dict[str, Any]): metadata of the edge (which is stored in the
|
|
681
|
+
third element of each edge of nx.Digraph)
|
|
682
|
+
|
|
683
|
+
Returns:
|
|
684
|
+
(str): Control flow name (e.g. While_0, For_3 etc.)
|
|
685
|
+
"""
|
|
686
|
+
return data.get("control_flow", "").split("-")[0]
|
|
687
|
+
|
|
688
|
+
|
|
689
|
+
def _get_sorted_edges(graph: nx.DiGraph) -> list:
    """
    Sort the edges of the graph based on the topological order of the nodes.

    Args:
        graph (nx.DiGraph): The directed graph representing the function.

    Returns:
        list: A sorted list of edges in the graph.

    Example:

    >>> graph.add_edges_from([('A', 'B'), ('B', 'D'), ('A', 'C'), ('C', 'D')])
    >>> sorted_edges = _get_sorted_edges(graph)
    >>> print(sorted_edges)

    Output:

    >>> [('A', 'B', {}), ('A', 'C', {}), ('B', 'D', {}), ('C', 'D', {})]
    """
    # Rank nodes by their topological position, then sort edges by the
    # rank of their source node.
    rank = {node: pos for pos, node in enumerate(topological_sort(graph))}
    return sorted(graph.edges.data(), key=lambda edge: rank[edge[0]])
|
|
712
|
+
|
|
713
|
+
|
|
714
|
+
def _remove_and_reconnect_nodes(
    G: nx.DiGraph, nodes_to_remove: list[str]
) -> nx.DiGraph:
    """Remove each listed node, bridging its predecessors to its successors.

    The graph is modified in place and also returned.
    """
    for victim in set(nodes_to_remove):
        sources = list(G.predecessors(victim))
        targets = list(G.successors(victim))
        for src in sources:
            for dst in targets:
                G.add_edge(src, dst)
        G.remove_node(victim)
    return G
|
|
725
|
+
|
|
726
|
+
|
|
727
|
+
def _get_edges(graph: nx.DiGraph, nodes: dict[str, dict]) -> list[tuple[str, str]]:
    """Translate graph edges into dotted workflow port labels.

    Variable nodes are eliminated: each edge becomes a connection between
    port labels such as "inputs.x", "f_0.outputs.y" or "outputs.z", with the
    intermediate variable nodes removed and reconnected.

    Args:
        graph (nx.DiGraph): Flat workflow graph with typed edges.
        nodes (dict[str, dict]): Node dictionaries providing the ordered
            input/output port names per function node.

    Returns:
        list[tuple[str, str]]: Edges between dotted port labels.
    """
    # Ordered port names per node, used to resolve positional indices.
    io_dict = {
        key: {
            "input": list(data["inputs"].keys()),
            "output": list(data["outputs"].keys()),
        }
        for key, data in nodes.items()
    }
    edges = []
    nodes_to_remove = []
    for edge in graph.edges.data():
        if edge[0] == "input":
            # Workflow input: "inputs.<name>" -> variable node.
            edges.append([edge[0] + "s." + _remove_index(edge[1]), edge[1]])
            nodes_to_remove.append(edge[1])
        elif edge[1] == "output":
            # Workflow output: variable node -> "outputs.<name>".
            edges.append([edge[0], edge[1] + "s." + _remove_index(edge[0])])
            nodes_to_remove.append(edge[0])
        elif edge[2]["type"] == "input":
            # Variable feeding a function: resolve the input port label
            # either by keyword name or by positional index.
            if "input_name" in edge[2]:
                tag = edge[2]["input_name"]
            elif "input_index" in edge[2]:
                tag = io_dict[edge[1]]["input"][edge[2]["input_index"]]
            else:
                raise ValueError
            edges.append([edge[0], edge[1] + ".inputs." + tag])
            nodes_to_remove.append(edge[0])
        elif edge[2]["type"] == "output":
            # Function producing a variable: resolve the output port label by
            # index, explicit name, or default to the first output.
            if "output_index" in edge[2]:
                tag = io_dict[edge[0]]["output"][edge[2]["output_index"]]
            elif "output_name" in edge[2]:
                tag = edge[2]["output_name"]
            else:
                tag = io_dict[edge[0]]["output"][0]
            edges.append([edge[0] + ".outputs." + tag, edge[1]])
            nodes_to_remove.append(edge[1])
    # Drop the variable nodes, directly connecting producer and consumer ports.
    new_graph = _remove_and_reconnect_nodes(nx.DiGraph(edges), nodes_to_remove)
    return list(new_graph.edges)
|
|
764
|
+
|
|
765
|
+
|
|
766
|
+
def get_node_dict(
|
|
767
|
+
function: Callable,
|
|
768
|
+
inputs: dict[str, dict] | None = None,
|
|
769
|
+
outputs: dict[str, dict] | None = None,
|
|
770
|
+
) -> dict:
|
|
771
|
+
"""
|
|
772
|
+
Get a dictionary representation of the function node.
|
|
773
|
+
|
|
774
|
+
Args:
|
|
775
|
+
func (Callable): The function to be analyzed.
|
|
776
|
+
data_format (str): The format of the output. Options are "semantikon" and
|
|
777
|
+
"ape".
|
|
778
|
+
|
|
779
|
+
Returns:
|
|
780
|
+
(dict) A dictionary representation of the function node.
|
|
781
|
+
"""
|
|
782
|
+
if inputs is None:
|
|
783
|
+
inputs = parse_input_args(function)
|
|
784
|
+
if outputs is None:
|
|
785
|
+
outputs = _get_node_outputs(function)
|
|
786
|
+
data = {
|
|
787
|
+
"inputs": inputs,
|
|
788
|
+
"outputs": outputs,
|
|
789
|
+
"function": function,
|
|
790
|
+
"type": "Function",
|
|
791
|
+
}
|
|
792
|
+
if hasattr(function, "_semantikon_metadata"):
|
|
793
|
+
data.update(function._semantikon_metadata)
|
|
794
|
+
return data
|
|
795
|
+
|
|
796
|
+
|
|
797
|
+
def _to_workflow_dict_entry(
|
|
798
|
+
inputs: dict[str, dict],
|
|
799
|
+
outputs: dict[str, dict],
|
|
800
|
+
nodes: dict[str, dict],
|
|
801
|
+
edges: list[tuple[str, str]],
|
|
802
|
+
label: str,
|
|
803
|
+
**kwargs,
|
|
804
|
+
) -> dict[str, object]:
|
|
805
|
+
assert all("inputs" in v for v in nodes.values())
|
|
806
|
+
assert all("outputs" in v for v in nodes.values())
|
|
807
|
+
assert all(
|
|
808
|
+
"function" in v or ("nodes" in v and "edges" in v) for v in nodes.values()
|
|
809
|
+
)
|
|
810
|
+
return {
|
|
811
|
+
"inputs": inputs,
|
|
812
|
+
"outputs": outputs,
|
|
813
|
+
"nodes": nodes,
|
|
814
|
+
"edges": edges,
|
|
815
|
+
"label": label,
|
|
816
|
+
"type": "Workflow",
|
|
817
|
+
} | kwargs
|
|
818
|
+
|
|
819
|
+
|
|
820
|
+
def _get_test_dict(f_dict: dict[str, dict]) -> dict[str, str]:
|
|
821
|
+
"""
|
|
822
|
+
dict to translate test and iter nodes into "test" and "iter"
|
|
823
|
+
|
|
824
|
+
Args:
|
|
825
|
+
f_dict (dict[str, dict]): Function dictionary
|
|
826
|
+
|
|
827
|
+
Returns:
|
|
828
|
+
dict[str, str]: Translation of node name to "test" or "iter"
|
|
829
|
+
"""
|
|
830
|
+
return {
|
|
831
|
+
key: tag
|
|
832
|
+
for key, value in f_dict.items()
|
|
833
|
+
for tag in ["test", "iter"]
|
|
834
|
+
if value.get("control_flow", "").endswith(tag)
|
|
835
|
+
}
|
|
836
|
+
|
|
837
|
+
|
|
838
|
+
def _nest_nodes(
    graph: nx.DiGraph, nodes: dict[str, dict], f_dict: dict[str, dict]
) -> tuple[dict[str, dict], list[tuple[str, str]]]:
    """
    Nest workflow nodes according to control-flow scopes.

    Args:
        graph (nx.DiGraph): The directed graph representing the function.
        nodes (dict[str, dict]): The dictionary of nodes.
        f_dict (dict[str, dict]): The dictionary of functions.

    Returns:
        tuple[dict[str, dict], list[tuple[str, str]]]: The (possibly nested)
            node dictionary of the top-level scope and its edge list.
    """
    test_dict = _get_test_dict(f_dict=f_dict)
    cf_graph = _get_control_flow_graph(_extract_control_flows(graph))
    subgraphs = _get_subgraphs(graph, cf_graph)
    injected_nodes: dict[str, Any] = {}
    # Walk control-flow scopes innermost-first (reversed topological order) so
    # that nested scopes are already built when their parent consumes them.
    for cf_key in list(topological_sort(cf_graph))[::-1]:
        subgraph = nx.relabel_nodes(subgraphs[cf_key], test_dict)
        # The empty key denotes the top-level scope; nested scopes get an
        # "injected_" label derived from their control-flow path.
        new_key = "injected_" + cf_key.replace("/", "_") if len(cf_key) > 0 else cf_key
        current_nodes = {}
        for key in _extract_functions_from_graph(subgraphs[cf_key]):
            if key in test_dict:
                # Control-flow node: store under its canonical "test"/"iter" name.
                current_nodes[test_dict[key]] = nodes[key]
            elif key in nodes:
                current_nodes[key] = nodes[key]
            else:
                # A previously built nested scope is consumed (pop!) by its parent.
                current_nodes[key] = injected_nodes.pop(key)
        io_ = _detect_io_variables_from_control_flow(graph, subgraph)
        injected_nodes[new_key] = {
            "nodes": current_nodes,
            "edges": _get_edges(subgraph, current_nodes),
            "label": new_key,
            "inputs": {_remove_index(key): {} for key in io_["inputs"]},
            "outputs": {_remove_index(key): {} for key in io_["outputs"]},
        }
        # Promote "test"/"iter" members from plain nodes to dedicated scope keys.
        for tag in ["test", "iter"]:
            if tag in injected_nodes[new_key]["nodes"]:
                injected_nodes[new_key][tag] = injected_nodes[new_key]["nodes"].pop(tag)
    return injected_nodes[""]["nodes"], injected_nodes[""]["edges"]
|
|
879
|
+
|
|
880
|
+
|
|
881
|
+
def get_workflow_dict(func: Callable) -> dict[str, object]:
    """
    Get a dictionary representation of the workflow for a given function.

    Args:
        func (Callable): The function to be analyzed.

    Returns:
        dict: A dictionary representation of the workflow, including inputs,
            outputs, nodes, edges, and label.
    """
    graph, function_dict = analyze_function(func)
    flat_nodes = _get_nodes(function_dict, _get_output_counts(graph))
    nested_nodes, workflow_edges = _nest_nodes(graph, flat_nodes, function_dict)
    return _to_workflow_dict_entry(
        inputs=parse_input_args(func),
        outputs=_get_node_outputs(func),
        nodes=nested_nodes,
        edges=workflow_edges,
        label=func.__name__,
    )
|
|
902
|
+
|
|
903
|
+
|
|
904
|
+
def _get_missing_edges(edge_list: list[tuple[str, str]]) -> list[tuple[str, str]]:
|
|
905
|
+
"""
|
|
906
|
+
Insert processes into the data edges. Take the following workflow:
|
|
907
|
+
|
|
908
|
+
>>> y = f(x=x)
|
|
909
|
+
>>> z = g(y=y)
|
|
910
|
+
|
|
911
|
+
The data flow is
|
|
912
|
+
|
|
913
|
+
- f.inputs.x -> f.outputs.y
|
|
914
|
+
- f.outputs.y -> g.inputs.y
|
|
915
|
+
- g.inputs.y -> g.outputs.z
|
|
916
|
+
|
|
917
|
+
`_get_missing_edges` adds the processes:
|
|
918
|
+
|
|
919
|
+
- f.inputs.x -> f
|
|
920
|
+
- f -> f.outputs.y
|
|
921
|
+
- f.outputs.y -> g.inputs.y
|
|
922
|
+
- g.inputs.y -> g
|
|
923
|
+
- g -> g.outputs.z
|
|
924
|
+
"""
|
|
925
|
+
extra_edges = []
|
|
926
|
+
for edge in edge_list:
|
|
927
|
+
for tag in edge:
|
|
928
|
+
if len(tag.split(".")) < 3:
|
|
929
|
+
continue
|
|
930
|
+
if tag.split(".")[1] == "inputs":
|
|
931
|
+
new_edge = (tag, tag.split(".")[0])
|
|
932
|
+
elif tag.split(".")[1] == "outputs":
|
|
933
|
+
new_edge = (tag.split(".")[0], tag)
|
|
934
|
+
if new_edge not in extra_edges:
|
|
935
|
+
extra_edges.append(new_edge)
|
|
936
|
+
return extra_edges
|
|
937
|
+
|
|
938
|
+
|
|
939
|
+
class _Workflow:
    """Executable wrapper around a workflow dictionary.

    Runs the nodes of a workflow dictionary (as produced by
    ``get_workflow_dict``) level by level, writing all intermediate values
    back into the dictionary in place.
    """

    def __init__(self, workflow_dict: dict[str, Any]):
        # The dictionary is mutated in place during `run`.
        self._workflow = workflow_dict

    @cached_property
    def _all_edges(self) -> list[tuple[str, str]]:
        # Data edges plus the port<->process edges they imply.
        edges = cast(dict[str, list], self._workflow)["edges"]
        return edges + _get_missing_edges(edges)

    @cached_property
    def _graph(self) -> nx.DiGraph:
        return nx.DiGraph(self._all_edges)

    @cached_property
    def _execution_list(self) -> list[list[str]]:
        # Groups of nodes that may run in parallel, in dependency order.
        return find_parallel_execution_levels(self._graph)

    def _sanitize_input(self, *args, **kwargs) -> dict[str, Any]:
        """Merge positional args into kwargs by input-declaration order.

        Raises:
            TypeError: If an argument is supplied both positionally and by
                keyword.
        """
        keys = list(self._workflow["inputs"].keys())
        for ii, arg in enumerate(args):
            if keys[ii] in kwargs:
                # Fix: the second literal was missing the `f` prefix, so the
                # message showed the raw text "{keys[ii]}" instead of the name.
                raise TypeError(
                    f"{self._workflow['label']}() got multiple values for"
                    f" argument '{keys[ii]}'"
                )
            kwargs[keys[ii]] = arg
        return kwargs

    def _set_inputs(self, *args, **kwargs):
        """Store call arguments as the workflow's global input values."""
        kwargs = self._sanitize_input(*args, **kwargs)
        for key, value in kwargs.items():
            if key not in self._workflow["inputs"]:
                raise TypeError(f"Unexpected keyword argument '{key}'")
            self._workflow["inputs"][key]["value"] = value

    def _get_value_from_data(self, node: dict[str, Any]) -> Any:
        # Falls back to (and caches) the default when no value was set yet.
        if "value" not in node:
            node["value"] = node["default"]
        return node["value"]

    def _get_value_from_global(self, path: str) -> Any:
        # path: "<inputs|outputs>.<var>"
        io, var = path.split(".")
        return self._get_value_from_data(self._workflow[io][var])

    def _get_value_from_node(self, path: str) -> Any:
        # path: "<node>.<inputs|outputs>.<var>"
        node, io, var = path.split(".")
        return self._get_value_from_data(self._workflow["nodes"][node][io][var])

    def _set_value_from_global(self, path, value):
        io, var = path.split(".")
        self._workflow[io][var]["value"] = value

    def _set_value_from_node(self, path, value):
        node, io, var = path.split(".")
        try:
            self._workflow["nodes"][node][io][var]["value"] = value
        except KeyError:
            raise KeyError(f"{path} not found in {node}") from None

    def _execute_node(self, function: str) -> Any:
        """Run a single node (a plain function or a nested workflow).

        Raises:
            KeyError: If an input has neither a value nor a default.
        """
        node = self._workflow["nodes"][function]
        input_data = {}
        try:
            for key, content in node["inputs"].items():
                if "value" not in content:
                    content["value"] = content["default"]
                input_data[key] = content["value"]
        except KeyError:
            raise KeyError(f"value not defined for {function}") from None
        if "function" not in node:
            # Nested workflow: run it recursively and collect its outputs;
            # a single output is unwrapped to mirror a plain function call.
            workflow = _Workflow(node)
            outputs = [
                d["value"] for d in workflow.run(**input_data)["outputs"].values()
            ]
            if len(outputs) == 1:
                outputs = outputs[0]
        else:
            outputs = node["function"](**input_data)
        return outputs

    def _set_value(self, tag, value):
        """Store `value` at the port addressed by `tag` (no-op for bare node names)."""
        if len(tag.split(".")) == 2 and tag.split(".")[0] in ("inputs", "outputs"):
            self._set_value_from_global(tag, value)
        elif len(tag.split(".")) == 3 and tag.split(".")[1] in ("inputs", "outputs"):
            self._set_value_from_node(tag, value)
        elif "." in tag:
            raise ValueError(f"{tag} not recognized")

    def _get_value(self, tag: str):
        """Fetch the value at a port path, or execute `tag` if it names a node."""
        if len(tag.split(".")) == 2 and tag.split(".")[0] in ("inputs", "outputs"):
            return self._get_value_from_global(tag)
        elif len(tag.split(".")) == 3 and tag.split(".")[1] in ("inputs", "outputs"):
            return self._get_value_from_node(tag)
        elif "." not in tag:
            return self._execute_node(tag)
        else:
            raise ValueError(f"{tag} not recognized")

    def run(self, *args, **kwargs) -> dict[str, Any]:
        """Execute the whole workflow and return the filled-in dictionary."""
        self._set_inputs(*args, **kwargs)
        for current_list in self._execution_list:
            for item in current_list:
                values = self._get_value(item)
                nodes = self._graph.edges(item)
                # A bare node with several outgoing edges returns one value
                # per output port; distribute them in edge order.
                if "." not in item and len(nodes) > 1:
                    for value, node in zip(values, nodes, strict=False):
                        self._set_value(node[1], value)
                else:
                    for node in nodes:
                        self._set_value(node[1], values)
        return self._workflow
|
|
1050
|
+
|
|
1051
|
+
|
|
1052
|
+
def find_parallel_execution_levels(G: nx.DiGraph) -> list[list[str]]:
    """
    Group the nodes of a DAG into levels of parallel execution.

    Performs a Kahn-style breadth-first traversal: each level contains the
    nodes whose remaining in-degree dropped to zero after the previous level.
    Levels consisting of the sentinel nodes "input" or "output" are skipped.

    Args:
        G (nx.DiGraph): The directed graph representing the function.

    Returns:
        list[list[str]]: A list of lists, where each inner list contains nodes
            that can be executed in parallel.

    Comment:
        This function only gives you a list of nodes that can be executed in
        parallel, but does not tell you which processes can be executed in
        case there is a process that takes longer at a higher level.
    """
    remaining = dict(cast(Iterable[tuple[Any, int]], G.in_degree()))
    frontier: deque = deque(node for node in G.nodes if remaining[node] == 0)
    levels = []

    while frontier:
        level = list(frontier)
        if "input" not in level and "output" not in level:
            levels.append(level)

        upcoming: deque = deque()
        for node in level:
            for successor in G.successors(node):
                remaining[successor] -= 1
                if remaining[successor] == 0:
                    upcoming.append(successor)

        frontier = upcoming

    return levels
|
|
1087
|
+
|
|
1088
|
+
|
|
1089
|
+
def workflow(func: Callable) -> FunctionWithWorkflow:
    """
    Decorator to convert a function into a workflow with metadata.

    Args:
        func (Callable): The function to be converted into a workflow.

    Returns:
        FunctionWithWorkflow: A callable object that includes the original function

    Example:

    >>> def operation(x: float, y: float) -> tuple[float, float]:
    >>>     return x + y, x - y
    >>>
    >>>
    >>> def add(x: float = 2.0, y: float = 1) -> float:
    >>>     return x + y
    >>>
    >>>
    >>> def multiply(x: float, y: float = 5) -> float:
    >>>     return x * y
    >>>
    >>>
    >>> @workflow
    >>> def example_macro(a=10, b=20):
    >>>     c, d = operation(a, b)
    >>>     e = add(c, y=d)
    >>>     f = multiply(e)
    >>>     return f
    >>>
    >>>
    >>> @workflow
    >>> def example_workflow(a=10, b=20):
    >>>     y = example_macro(a, b)
    >>>     z = add(y, b)
    >>>     return z

    This example defines a workflow `example_macro`, that includes `operation`,
    `add`, and `multiply`, which is nested inside another workflow
    `example_workflow`. Both workflows can be executed using their `run` method,
    which returns the dictionary representation of the workflow with all the
    intermediate steps and outputs.
    """
    wf_dict = get_workflow_dict(func)
    runner = _Workflow(wf_dict)
    return FunctionWithWorkflow(func, wf_dict, runner.run)
|
|
1137
|
+
|
|
1138
|
+
|
|
1139
|
+
def get_ports(
    func: Callable, separate_return_tuple: bool = True, strict: bool = False
) -> tuple[Inputs, Outputs]:
    """
    Build the input and output port collections for a function.

    Args:
        func (Callable): Function whose signature and return annotation are
            inspected.
        separate_return_tuple (bool): When True and the return annotation is a
            tuple, each tuple element becomes its own output port.
        strict (bool): Forwarded to the return-label extraction.

    Returns:
        tuple[Inputs, Outputs]: Port collections keyed by parameter and
            return labels.
    """
    hints = get_annotated_type_hints(func)
    ret_hint = hints.pop("return", inspect.Parameter.empty)
    labels = get_return_labels(
        func, separate_tuple=separate_return_tuple, strict=strict
    )
    if separate_return_tuple and get_origin(ret_hint) is tuple:
        # One output port per tuple element, paired with its label.
        out_meta = {}
        for lab, ann in zip(labels, get_args(ret_hint), strict=False):
            out_meta[lab] = meta_to_dict(ann, flatten_metadata=False)
    else:
        # Single output port under the first (only) return label.
        out_meta = {labels[0]: meta_to_dict(ret_hint, flatten_metadata=False)}
    in_meta = {}
    for name, param in inspect.signature(func).parameters.items():
        in_meta[name] = meta_to_dict(
            hints.get(name, param.annotation), param.default, flatten_metadata=False
        )
    inputs = Inputs(**{k: Input(label=k, **v) for k, v in in_meta.items()})
    outputs = Outputs(**{k: Output(label=k, **v) for k, v in out_meta.items()})
    return inputs, outputs
|
|
1166
|
+
|
|
1167
|
+
|
|
1168
|
+
def get_node(func: Callable, label: str | None = None) -> Function | Workflow:
    """
    Convert a function into a node dataclass.

    Functions decorated with ``@workflow`` (carrying ``_semantikon_workflow``)
    become :class:`Workflow` objects; plain functions become
    :class:`Function` objects. Any attached semantikon metadata is parsed
    into a :class:`CoreMetadata` instance.
    """
    raw_metadata = getattr(func, "_semantikon_metadata", MISSING)
    if isinstance(raw_metadata, Missing):
        metadata = raw_metadata
    else:
        metadata = CoreMetadata.from_dict(raw_metadata)

    if hasattr(func, "_semantikon_workflow"):
        return parse_workflow(func._semantikon_workflow, metadata)
    return parse_function(func, metadata, label=label)
|
|
1182
|
+
|
|
1183
|
+
|
|
1184
|
+
def parse_function(
    func: Callable, metadata: CoreMetadata | Missing, label: str | None = None
) -> Function:
    """Wrap a plain function into a :class:`Function` node dataclass."""
    inputs, outputs = get_ports(func)
    if label is None:
        label = func.__name__
    return Function(
        label=label,
        inputs=inputs,
        outputs=outputs,
        function=func,
        metadata=metadata,
    )
|
|
1195
|
+
|
|
1196
|
+
|
|
1197
|
+
def _port_from_dictionary(
    io_dictionary: dict[str, object], label: str, port_class: type[PortType]
) -> PortType:
    """
    Convert an input/output sub-dictionary of a classic ``_semantikon_workflow``
    dictionary into a port dataclass, nesting any metadata fields.

    Note: mutates ``io_dictionary`` in place (metadata keys are popped and the
    label is written back).
    """
    meta_field_names = [f.name for f in dataclasses.fields(TypeMetadata)]
    extracted = {
        name: io_dictionary.pop(name)
        for name in meta_field_names
        if name in io_dictionary
    }
    if extracted:
        io_dictionary["metadata"] = TypeMetadata.from_dict(extracted)
    io_dictionary["label"] = label
    return port_class.from_dict(io_dictionary)
|
|
1212
|
+
|
|
1213
|
+
|
|
1214
|
+
def _input_from_dictionary(io_dictionary: dict[str, object], label: str) -> Input:
    """Convert an input sub-dictionary into an ``Input`` port (mutates the dict)."""
    return _port_from_dictionary(io_dictionary, label, Input)
|
|
1216
|
+
|
|
1217
|
+
|
|
1218
|
+
def _output_from_dictionary(io_dictionary: dict[str, object], label: str) -> Output:
    """Convert an output sub-dictionary into an ``Output`` port (mutates the dict)."""
    return _port_from_dictionary(io_dictionary, label, Output)
|
|
1220
|
+
|
|
1221
|
+
|
|
1222
|
+
def parse_workflow(
    semantikon_workflow: dict[str, Any], metadata: CoreMetadata | Missing = MISSING
) -> Workflow:
    """
    Recursively convert a workflow dictionary into a :class:`Workflow`
    dataclass, turning I/O sub-dictionaries into ports and child entries into
    :class:`Function` or nested :class:`Workflow` nodes.
    """
    spec = semantikon_workflow
    input_ports = {
        name: _input_from_dictionary(data, label=name)
        for name, data in spec["inputs"].items()
    }
    output_ports = {
        name: _output_from_dictionary(data, label=name)
        for name, data in spec["outputs"].items()
    }
    node_objects: dict[str, Any] = {}
    for name, data in spec["nodes"].items():
        if "function" in data:
            node_objects[name] = get_node(data["function"], label=name)
        else:
            node_objects[name] = parse_workflow(data)
    # Edges are stored keyed by target (each target has exactly one source).
    edge_map = {target: source for source, target in spec["edges"]}
    return Workflow(
        label=spec["label"],
        inputs=Inputs(**input_ports),
        outputs=Outputs(**output_ports),
        nodes=Nodes(**node_objects),
        edges=Edges(**edge_map),
        metadata=metadata,
    )
|