flowrep 0.post0.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flowrep/workflow.py ADDED
@@ -0,0 +1,1256 @@
1
+ import ast
2
+ import builtins
3
+ import copy
4
+ import dataclasses
5
+ import inspect
6
+ import textwrap
7
+ from collections import deque
8
+ from collections.abc import Callable, Iterable
9
+ from functools import cached_property, update_wrapper
10
+ from typing import Any, Generic, TypeVar, cast, get_args, get_origin
11
+
12
+ import networkx as nx
13
+ from networkx.algorithms.dag import topological_sort
14
+ from semantikon.converter import (
15
+ get_annotated_type_hints,
16
+ get_return_expressions,
17
+ get_return_labels,
18
+ meta_to_dict,
19
+ parse_input_args,
20
+ parse_output_args,
21
+ )
22
+ from semantikon.datastructure import (
23
+ MISSING,
24
+ CoreMetadata,
25
+ Edges,
26
+ Function,
27
+ Input,
28
+ Inputs,
29
+ Missing,
30
+ Nodes,
31
+ Output,
32
+ Outputs,
33
+ PortType,
34
+ TypeMetadata,
35
+ Workflow,
36
+ )
37
+
38
# Type variable preserving the exact callable type wrapped by FunctionWithWorkflow
F = TypeVar("F", bound=Callable[..., object])
39
+
40
+
41
class FunctionWithWorkflow(Generic[F]):
    """Callable wrapper that attaches a semantikon workflow description.

    Calling the wrapper is identical to calling the wrapped function. The
    workflow dictionary is exposed as ``_semantikon_workflow`` and the
    caller-supplied ``run`` callable as ``run``; unknown attribute lookups
    are delegated to the wrapped function.
    """

    def __init__(self, func: F, workflow: dict[str, object], run) -> None:
        self.func = func
        self.run = run
        self._semantikon_workflow: dict[str, object] = workflow
        # Mirror __name__, __doc__, __wrapped__, etc. of the wrapped function
        update_wrapper(self, func)

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)

    def __getattr__(self, item):
        # Only reached for attributes not set on the wrapper itself;
        # fall back to the wrapped function.
        return getattr(self.func, item)
53
+
54
+
55
+ def separate_types(
56
+ data: dict[str, Any], class_dict: dict[str, type] | None = None
57
+ ) -> tuple[dict[str, Any], dict[str, type]]:
58
+ """
59
+ Separate types from the data dictionary and store them in a class dictionary.
60
+ The types inside the data dictionary will be replaced by their name (which
61
+ would for example make it easier to hash it).
62
+
63
+ Args:
64
+ data (dict[str, Any]): The data dictionary containing nodes and types.
65
+ class_dict (dict[str, type], optional): A dictionary to store types. It
66
+ is mainly used due to the recursivity of this function. Defaults to
67
+ None.
68
+
69
+ Returns:
70
+ tuple: A tuple containing the modified data dictionary and the
71
+ class dictionary.
72
+ """
73
+ data = copy.deepcopy(data)
74
+ if class_dict is None:
75
+ class_dict = {}
76
+ if "nodes" in data:
77
+ for key, node in data["nodes"].items():
78
+ child_node, child_class_dict = separate_types(node, class_dict)
79
+ class_dict.update(child_class_dict)
80
+ data["nodes"][key] = child_node
81
+ for io_ in ["inputs", "outputs"]:
82
+ for key, content in data[io_].items():
83
+ if "dtype" in content and isinstance(content["dtype"], type):
84
+ class_dict[content["dtype"].__name__] = content["dtype"]
85
+ data[io_][key]["dtype"] = content["dtype"].__name__
86
+ return data, class_dict
87
+
88
+
89
+ def separate_functions(
90
+ data: dict[str, Any], function_dict: dict[str, Callable] | None = None
91
+ ) -> tuple[dict[str, Any], dict[str, Callable]]:
92
+ """
93
+ Separate functions from the data dictionary and store them in a function
94
+ dictionary. The functions inside the data dictionary will be replaced by
95
+ their name (which would for example make it easier to hash it)
96
+
97
+ Args:
98
+ data (dict[str, Any]): The data dictionary containing nodes and
99
+ functions.
100
+ function_dict (dict[str, Callable], optional): A dictionary to store
101
+ functions. It is mainly used due to the recursivity of this
102
+ function. Defaults to None.
103
+
104
+ Returns:
105
+ tuple: A tuple containing the modified data dictionary and the
106
+ function dictionary.
107
+ """
108
+ data = copy.deepcopy(data)
109
+ if function_dict is None:
110
+ function_dict = {}
111
+ if "nodes" in data:
112
+ for key, node in data["nodes"].items():
113
+ child_node, child_function_dict = separate_functions(node, function_dict)
114
+ function_dict.update(child_function_dict)
115
+ data["nodes"][key] = child_node
116
+ elif "function" in data and not isinstance(data["function"], str):
117
+ fnc_object = data["function"]
118
+ as_string = fnc_object.__module__ + "." + fnc_object.__qualname__
119
+ function_dict[as_string] = fnc_object
120
+ data["function"] = as_string
121
+ if "test" in data and not isinstance(data["test"]["function"], str):
122
+ fnc_object = data["test"]["function"]
123
+ as_string = fnc_object.__module__ + fnc_object.__qualname__
124
+ function_dict[as_string] = fnc_object
125
+ data["test"]["function"] = as_string
126
+ return data, function_dict
127
+
128
+
129
class FunctionDictFlowAnalyzer:
    """Build a variable-flow graph from the dict-converted AST of a function.

    The analyzer walks a function body in which every statement is expected
    to be a function call, an assignment of a function call, a supported
    control-flow block (If/While/For whose condition/iterator is a call), or
    the final return. It produces:

    - ``graph``: a directed graph connecting versioned variable nodes
      (``x_0``, ``x_1``, ... — one version per assignment) with uniquely
      named call nodes (``f_0``, ``f_1``, ...) via edges tagged
      ``type="input"`` / ``type="output"``;
    - ``function_defs``: mapping from unique call names to the callable
      (resolved via ``scope``) plus its control-flow tag, if any.
    """

    def __init__(self, ast_dict, scope):
        # ast_dict: dict form of one FunctionDef node (cf. get_ast_dict)
        self.graph = nx.DiGraph()
        self.scope = scope  # mapping from function names to objects
        self.function_defs = {}
        self.ast_dict = ast_dict
        # number of calls seen per function name, for unique node names
        self._call_counter = {}
        # all control-flow tags handed out so far (e.g. "If_0", "While_0/If_0")
        self._control_flow_list = []
        # variable versions assigned in parallel if/else branches; reading
        # one version implies input edges from its parallel siblings as well
        self._parallel_var = {}

    def analyze(self) -> tuple[nx.DiGraph, dict[str, Any]]:
        """Walk the function body and return ``(graph, function_defs)``."""
        # Function parameters become outputs of the virtual "input" node
        for arg in self.ast_dict.get("args", {}).get("args", []):
            if arg["_type"] == "arg":
                self._add_output_edge("input", arg["arg"])
        return_was_called = False
        for node in self.ast_dict.get("body", []):
            # nothing may follow the return statement
            assert not return_was_called
            self._visit_node(node)
            if node["_type"] == "Return":
                return_was_called = True
        return self.graph, self.function_defs

    def _visit_node(self, node, control_flow: str | None = None):
        """Dispatch one statement node to its handler."""
        if node["_type"] == "Assign":
            self._handle_assign(node, control_flow=control_flow)
        elif node["_type"] == "Expr":
            self._handle_expr(node, control_flow=control_flow)
        elif node["_type"] == "While":
            self._handle_while(node, control_flow=control_flow)
        elif node["_type"] == "For":
            self._handle_for(node, control_flow=control_flow)
        elif node["_type"] == "If":
            self._handle_if(node, control_flow=control_flow)
        elif node["_type"] == "Return":
            self._handle_return(node, control_flow=control_flow)
        else:
            raise NotImplementedError(f"Node type {node['_type']} not implemented")

    def _handle_return(self, node, control_flow: str | None = None):
        """Wire returned variables into the virtual "output" node."""
        if not node["value"]:
            return
        if node["value"]["_type"] == "Tuple":
            # multiple returns: remember the position of each variable
            for idx, elt in enumerate(node["value"]["elts"]):
                if elt["_type"] != "Name":
                    raise NotImplementedError("Only variable returns supported")
                self._add_input_edge(elt, "output", input_index=idx)
        elif node["value"]["_type"] == "Name":
            self._add_input_edge(node["value"], "output")

    def _handle_if(self, node, control_flow: str | None = None):
        """Process an if/else block whose condition is a function call."""
        assert node["test"]["_type"] == "Call"
        control_flow = self._convert_control_flow(control_flow, tag="If")
        self._parse_function_call(node["test"], control_flow=f"{control_flow}-test")
        for n in node["body"]:
            self._visit_node(n, control_flow=f"{control_flow}-body")
        for n in node.get("orelse", []):
            # the else branch gets its own tag, e.g. "Else_0-body"
            cf_else = "/".join(
                control_flow.split("/")[:-1]
                + [control_flow.split("/")[-1].replace("If", "Else") + "-body"]
            )
            self._visit_node(n, control_flow=cf_else)
            # both branches must read the same variable versions and their
            # parallel assignments are recorded
            self._reconnect_parallel(cf_else, f"{control_flow}-body")
            self._register_parallel_variables(cf_else, f"{control_flow}-body")

    def _reconnect_parallel(self, control_flow: str, ref_control_flow: str):
        """Rewire else-branch input edges so they do not consume variable
        versions produced inside the parallel if-branch.

        NOTE(review): the inner ``while`` only terminates via the ``break``
        inside the ``if``; it assumes the matching version produced by
        ``ref_control_flow`` exists — confirm against upstream sources.
        """
        all_edges = list(self.graph.edges.data())
        for edge in all_edges:
            # only input edges belonging to this (else-) control flow
            if (
                "control_flow" not in edge[2]
                or edge[2]["control_flow"] != control_flow
                or edge[2]["type"] == "output"
            ):
                continue
            # split e.g. "x_3" into variable name and version index
            var, ind = "_".join(edge[0].split("_")[:-1]), int(edge[0].split("_")[-1])
            while True:
                if any(
                    [
                        e[2].get("control_flow") == ref_control_flow
                        for e in self.graph.in_edges(f"{var}_{ind}", data=True)
                    ]
                ):
                    # this version was produced by the reference branch;
                    # step back to the version preceding it
                    ind -= 1
                    break
            if f"{var}_{ind}" != edge[0]:
                self.graph.add_edge(f"{var}_{ind}", edge[1], **edge[2])
                self.graph.remove_edge(edge[0], edge[1])

    def _register_parallel_variables(self, control_flow: str, ref_control_flow: str):
        """Record variables assigned in both parallel branches so a later
        read of the newest version also links to its sibling versions."""
        data: dict[str, dict] = {control_flow: {}, ref_control_flow: {}}
        # collect, per branch, the latest versioned name of each variable
        for edge in self.graph.edges.data():
            if (
                edge[2].get("control_flow", "") in [control_flow, ref_control_flow]
                and edge[2]["type"] == "output"
            ):
                data[edge[2]["control_flow"]][edge[1].rsplit("_", 1)[0]] = edge[1]
        for key in set(data[control_flow].keys()).intersection(
            data[ref_control_flow].keys()
        ):
            values = sorted([data[control_flow][key], data[ref_control_flow][key]])
            self._parallel_var[values[-1]] = [values[0]]
            # merge chains from earlier parallel registrations
            if values[0] in self._parallel_var:
                self._parallel_var[values[-1]].extend(self._parallel_var.pop(values[0]))

    def _handle_while(self, node, control_flow: str | None = None):
        """Process a while block whose condition is a function call."""
        assert node["test"]["_type"] == "Call"
        control_flow = self._convert_control_flow(control_flow, tag="While")
        self._parse_function_call(node["test"], control_flow=f"{control_flow}-test")
        for n in node["body"]:
            self._visit_node(n, control_flow=f"{control_flow}-body")

    def _handle_for(self, node, control_flow: str | None = None):
        """Process a for block whose iterator is a function call."""
        assert node["iter"]["_type"] == "Call"
        control_flow = self._convert_control_flow(control_flow, tag="For")

        unique_func_name = self._parse_function_call(
            node["iter"], control_flow=f"{control_flow}-iter"
        )
        # the loop variable is an output of the iterator call
        self._parse_outputs(
            [node["target"]], unique_func_name, control_flow=control_flow
        )
        for n in node["body"]:
            self._visit_node(n, control_flow=f"{control_flow}-body")

    def _handle_expr(self, node, control_flow: str | None = None) -> str:
        """Process a bare expression statement (must be a call)."""
        value = node["value"]
        return self._parse_function_call(value, control_flow=control_flow)

    def _parse_function_call(self, value, control_flow: str | None = None) -> str:
        """Register a call node and wire its arguments as input edges.

        Returns:
            str: The unique node name given to this call (e.g. ``f_0``).
        """
        if value["_type"] != "Call":
            raise NotImplementedError(
                f"Only function calls allowed on RHS: {value['_type']}"
            )

        func_node = value["func"]
        if func_node["_type"] != "Name":
            raise NotImplementedError("Only simple functions allowed")

        func_name = func_node["id"]
        unique_func_name = self._get_unique_func_name(func_name)

        if func_name not in self.scope:
            raise ValueError(f"Function {func_name} not found in scope")

        self.function_defs[unique_func_name] = {"function": self.scope[func_name]}
        if control_flow is not None:
            self.function_defs[unique_func_name]["control_flow"] = control_flow

        # Parse inputs (positional + keyword)
        for i, arg in enumerate(value.get("args", [])):
            self._add_input_edge(
                arg, unique_func_name, input_index=i, control_flow=control_flow
            )
        for kw in value.get("keywords", []):
            self._add_input_edge(
                kw["value"],
                unique_func_name,
                input_name=kw["arg"],
                control_flow=control_flow,
            )
        return unique_func_name

    def _handle_assign(self, node, control_flow: str | None = None):
        """Process ``targets = call(...)``: the call plus its output edges."""
        unique_func_name = self._handle_expr(node, control_flow=control_flow)
        self._parse_outputs(
            node["targets"], unique_func_name, control_flow=control_flow
        )

    def _parse_outputs(
        self, targets, unique_func_name, control_flow: str | None = None
    ):
        """Add output edges from a call node to its assignment targets."""
        if len(targets) == 1 and targets[0]["_type"] == "Tuple":
            # tuple unpacking: remember each target's position
            for idx, elt in enumerate(targets[0]["elts"]):
                self._add_output_edge(
                    unique_func_name,
                    elt["id"],
                    output_index=idx,
                    control_flow=control_flow,
                )
        else:
            for target in targets:
                self._add_output_edge(
                    unique_func_name, target["id"], control_flow=control_flow
                )

    def _get_max_index(self, variable: str) -> int:
        """Return the first version index of ``variable`` that has not been
        assigned yet (i.e. has no incoming edge)."""
        index = 0
        while True:
            if len(self.graph.in_edges(f"{variable}_{index}")) > 0:
                index += 1
                continue
            break
        return index

    def _get_var_index(self, variable: str, output: bool = False) -> int:
        """Return the version index to use for ``variable``.

        For outputs this is the next free version; for inputs it is the most
        recently assigned version.

        Raises:
            KeyError: When reading a variable that was never assigned.
        """
        index = self._get_max_index(variable)
        if index == 0 and not output:
            raise KeyError(
                f"Variable {variable} not found in graph. "
                "This usually means that the variable was never defined."
            )
        if output:
            return index
        else:
            return index - 1

    def _add_output_edge(
        self, source: str, target: str, control_flow: str | None = None, **kwargs
    ):
        """
        Add an output edge from the source to the target variable.

        Args:
            source (str): The source node (function name).
            target (str): The target variable name.
            control_flow (str | None): The control flow tag, if any.
            **kwargs: Additional keyword arguments to pass to the edge.

        In the case of the following line:

        >>> y = f(x)

        This function will add an edge from the function `f` to the variable `y`.
        """
        versioned = f"{target}_{self._get_var_index(target, output=True)}"
        if control_flow is not None:
            kwargs["control_flow"] = control_flow
        self.graph.add_edge(source, versioned, type="output", **kwargs)

    def _add_input_edge(
        self, source: dict, target: str, control_flow: str | None = None, **kwargs
    ):
        """
        Add an input edge from the source variable to the target node.

        Args:
            source (dict): The source variable node.
            target (str): The target node (function name).
            control_flow (str | None): The control flow tag, if any.
            **kwargs: Additional keyword arguments to pass to the edge.

        In the case of the following line:

        >>> y = f(x)

        This function will add an edge from the variable `x` to the function `f`.
        """
        if source["_type"] != "Name":
            raise NotImplementedError(f"Only variable inputs supported, got: {source}")
        var_name = source["id"]
        if control_flow is not None:
            kwargs["control_flow"] = control_flow
        versioned = f"{var_name}_{self._get_var_index(var_name)}"
        self.graph.add_edge(versioned, target, type="input", **kwargs)
        # variables assigned in parallel if/else branches: also read siblings
        if versioned in self._parallel_var:
            for key in self._parallel_var.pop(versioned):
                self.graph.add_edge(key, target, type="input", **kwargs)

    def _get_unique_func_name(self, base_name):
        """Return ``base_name`` suffixed with a per-name running counter."""
        i = self._call_counter.get(base_name, 0)
        self._call_counter[base_name] = i + 1
        return f"{base_name}_{i}"

    def _convert_control_flow(self, control_flow: str | None, tag: str) -> str:
        """Create and register a fresh control-flow name such as
        ``While_0`` or ``If_0/While_1`` (nested under the parent flow)."""
        control_flow = "" if control_flow is None else f"{control_flow.split('-')[0]}/"
        counter = 0
        while True:
            if f"{control_flow}{tag}_{counter}" not in self._control_flow_list:
                self._control_flow_list.append(f"{control_flow}{tag}_{counter}")
                break
            counter += 1
        return f"{control_flow}{tag}_{counter}"
400
+
401
+
402
+ def _get_variables_from_subgraph(graph: nx.DiGraph, io_: str) -> set[str]:
403
+ """
404
+ Get variables from a subgraph based on the type of I/O and control flow.
405
+
406
+ Args:
407
+ graph (nx.DiGraph): The directed graph representing the function.
408
+ io_ (str): The type of I/O to filter by ("input" or "output").
409
+ control_flow (list, str): A list of control flow types to filter by.
410
+
411
+ Returns:
412
+ set[str]: A set of variable names that match the specified I/O type and
413
+ control flow.
414
+ """
415
+ assert io_ in ["input", "output"], "io_ must be 'input' or 'output'"
416
+ if io_ == "input":
417
+ edge_ind = 0
418
+ elif io_ == "output":
419
+ edge_ind = 1
420
+ return set(
421
+ [edge[edge_ind] for edge in graph.edges.data() if edge[2]["type"] == io_]
422
+ )
423
+
424
+
425
def _get_parent_graph(graph: nx.DiGraph, control_flow: str) -> nx.DiGraph:
    """
    Get parent body of the indented body

    Args:
        graph (nx.DiGraph): Full graph to look for the parent graph from
        control_flow (str): Control flow whose parent graph is to look for

    Returns:
        (nx.DiGraph): Parent graph
    """
    # keep every edge that does NOT belong to the given control flow
    # (or one nested inside it)
    kept_edges = [
        (source, target, attrs)
        for source, target, attrs in graph.edges.data()
        if not _get_control_flow(attrs).startswith(control_flow)
    ]
    return nx.DiGraph(kept_edges)
443
+
444
+
445
def _detect_io_variables_from_control_flow(
    graph: nx.DiGraph, subgraph: nx.DiGraph
) -> dict[str, list]:
    """
    Detect input and output variables of a control-flow subgraph.

    A variable is an input of the subgraph when it is read inside the
    subgraph body but produced in the surrounding (parent) graph; it is an
    output when it is produced inside the body and read outside.

    Args:
        graph (nx.DiGraph): The directed graph representing the function.
        subgraph (nx.DiGraph): Subgraph of one control flow.

    Returns:
        dict[str, list]: A dictionary with keys "inputs" and "outputs",
            each containing the variable names crossing the boundary.

    Take a look at the unit tests for examples of how to use this function.
    """
    # subgraph body without the virtual "input"/"output" ports
    sg_body = nx.DiGraph(
        [
            edge
            for edge in subgraph.edges.data()
            if edge[0] != "input" and edge[1] != "output"
        ]
    )
    cf = sorted(
        [
            _get_control_flow(edge[2])
            for edge in sg_body.edges.data()
            if "control_flow" in edge[2]
        ]
    )
    if len(cf) == 0:
        return {"inputs": [], "outputs": []}
    # all flows must be nested under the outermost one (sorted order)
    assert all([cf[ii + 1].startswith(cf[ii]) for ii in range(len(cf) - 1)])
    parent_graph = _get_parent_graph(graph, cf[0])
    # inputs: read inside the body AND produced in the parent graph
    var_inp_1 = _get_variables_from_subgraph(graph=sg_body, io_="input")
    var_inp_2 = _get_variables_from_subgraph(graph=parent_graph, io_="output")
    # outputs: produced inside the body AND read in the parent graph
    var_out_1 = _get_variables_from_subgraph(graph=parent_graph, io_="input")
    var_out_2 = _get_variables_from_subgraph(graph=sg_body, io_="output")
    return {
        "inputs": list(var_inp_1.intersection(var_inp_2)),
        "outputs": list(var_out_1.intersection(var_out_2)),
    }
486
+
487
+
488
def _extract_control_flows(graph: nx.DiGraph) -> list[str]:
    """Return the distinct control-flow names attached to the graph's edges
    (the root flow appears as the empty string)."""
    return list({_get_control_flow(attrs) for _, _, attrs in graph.edges.data()})
490
+
491
+
492
def _split_graphs_into_subgraphs(graph: nx.DiGraph) -> dict[str, nx.DiGraph]:
    """Group the edges of the flat graph by control-flow name, returning one
    subgraph per control flow (the root flow is keyed by "")."""
    result: dict[str, nx.DiGraph] = {}
    for cf in _extract_control_flows(graph):
        matching = [
            edge for edge in graph.edges.data() if _get_control_flow(edge[2]) == cf
        ]
        result[cf] = nx.DiGraph(matching)
    return result
503
+
504
+
505
def _get_subgraphs(graph: nx.DiGraph, cf_graph: nx.DiGraph) -> dict[str, nx.DiGraph]:
    """
    Separate a flat graph into subgraphs nested by control flows

    Args:
        graph (nx.DiGraph): Flat workflow graph
        cf_graph (nx.DiGraph): Control flow graph (cf. _get_control_flow_graph)

    Returns:
        dict[str, nx.DiGraph]: Subgraphs keyed by control-flow name
    """
    subgraphs = _split_graphs_into_subgraphs(graph)
    # bottom-up: handle inner bodies before their parents
    for key in list(topological_sort(cf_graph))[::-1]:
        subgraph = subgraphs[key]
        node_name = "injected_" + key.replace("/", "_")
        io_ = _detect_io_variables_from_control_flow(graph, subgraph)
        for parent_graph_name in cf_graph.predecessors(key):
            parent_graph = subgraphs[parent_graph_name]
            # represent the whole body as a single "injected" node in the
            # parent graph, wired to the crossing variables
            for inp in io_["inputs"]:
                parent_graph.add_edge(
                    inp, node_name, type="input", input_name=_remove_index(inp)
                )
            for out in io_["outputs"]:
                parent_graph.add_edge(
                    node_name, out, type="output", output_name=_remove_index(out)
                )
        # expose crossing variables as the body's own virtual I/O ports
        for inp in io_["inputs"]:
            subgraph.add_edge("input", inp, type="output")
        for out in io_["outputs"]:
            subgraph.add_edge(out, "output", type="input")
    return subgraphs
536
+
537
+
538
+ def _extract_functions_from_graph(graph: nx.DiGraph) -> set:
539
+ function_names = []
540
+ for edge in graph.edges.data():
541
+ if edge[2]["type"] == "output" and edge[0] != "input":
542
+ function_names.append(edge[0])
543
+ elif edge[2]["type"] == "input" and edge[1] != "output":
544
+ function_names.append(edge[1])
545
+ return set(function_names)
546
+
547
+
548
def _get_control_flow_graph(control_flows: list[str]) -> nx.DiGraph:
    """
    Create a graph based on the control flows. The indentation level
    corresponds to the graph level. The higher level body is the parent node
    of the lower body.

    Args:
        control_flows (list[str]): All control flows present in a workflow

    Returns:
        nx.DiGraph: Control flow graph
    """
    parent_child_pairs = []
    for cf in control_flows:
        if not cf:
            continue
        # nested flows ("While_0/If_0") hang below their prefix, top-level
        # flows below the root ""
        parent = cf.rsplit("/", 1)[0] if "/" in cf else ""
        parent_child_pairs.append([parent, cf])
    graph = nx.DiGraph(parent_child_pairs)
    if len(graph) == 0:
        # ensure the root control flow exists even without any nesting
        graph.add_node("")
    return graph
573
+
574
+
575
+ def _function_to_ast_dict(node):
576
+ if isinstance(node, ast.AST):
577
+ result = {"_type": type(node).__name__}
578
+ for field, value in ast.iter_fields(node):
579
+ result[field] = _function_to_ast_dict(value)
580
+ return result
581
+ elif isinstance(node, list):
582
+ return [_function_to_ast_dict(item) for item in node]
583
+ else:
584
+ return node
585
+
586
+
587
def get_ast_dict(func: Callable) -> dict:
    """Get the AST dictionary representation of a function."""
    source = inspect.getsource(func)
    tree = ast.parse(textwrap.dedent(source))
    return _function_to_ast_dict(tree)
592
+
593
+
594
def analyze_function(func: Callable) -> tuple[nx.DiGraph, dict[str, Any]]:
    """Extracts the variable flow graph from a function"""
    # name resolution scope: the function's module, with builtins on top
    module_scope = dict(inspect.getmodule(func).__dict__)
    module_scope.update(vars(builtins))
    tree_dict = get_ast_dict(func)
    # body[0] is the FunctionDef node of the parsed source
    return FunctionDictFlowAnalyzer(tree_dict["body"][0], module_scope).analyze()
600
+
601
+
602
def _get_node_outputs(func: Callable, counts: int | None = None) -> dict[str, dict]:
    """Map the return-expression labels of ``func`` to their output hints.

    Args:
        func (Callable): Function whose outputs are inspected.
        counts (int | None): Expected number of outputs, or None if unknown.

    Returns:
        dict[str, dict]: output label -> hint dictionary; empty when the
            function returns nothing.
    """
    output_hints = parse_output_args(
        func, separate_tuple=(counts is None or counts > 1)
    )
    output_vars = get_return_expressions(func)
    if output_vars is None or len(output_vars) == 0:
        return {}
    # single output: label is the returned variable name, or the generic
    # "output" when no name is available
    if (counts is not None and counts == 1) or isinstance(output_vars, str):
        if isinstance(output_vars, str):
            return {output_vars: cast(dict, output_hints)}
        else:
            return {"output": cast(dict, output_hints)}
    assert isinstance(output_vars, tuple), output_vars
    assert counts is None or len(output_vars) == counts, output_vars
    if output_hints == {}:
        return {key: {} for key in output_vars}
    else:
        assert counts is None or len(output_hints) == counts
        # NOTE(review): if parse_output_args returned a dict here, zipping it
        # would iterate its keys, not the hint values — this assumes it
        # returns a sequence of hint dicts in the multi-output case; verify.
        return {key: hint for key, hint in zip(output_vars, output_hints, strict=False)}
621
+
622
+
623
+ def _get_output_counts(graph: nx.DiGraph) -> dict[str, int]:
624
+ """
625
+ Get the number of outputs for each node in the graph.
626
+
627
+ Args:
628
+ graph (nx.DiGraph): The directed graph representing the function.
629
+
630
+ Returns:
631
+ dict: A dictionary mapping node names to the number of outputs.
632
+ """
633
+ f_dict: dict[str, int] = {}
634
+ for edge in graph.edges.data():
635
+ if edge[2]["type"] != "output":
636
+ continue
637
+ f_dict[edge[0]] = f_dict.get(edge[0], 0) + 1
638
+ if "input" in f_dict:
639
+ del f_dict["input"]
640
+ return f_dict
641
+
642
+
643
def _get_nodes(
    data: dict[str, dict],
    output_counts: dict[str, int],
    control_flow: None | str = None,
) -> dict[str, dict]:
    """Build the node dictionary for every function in ``data``.

    Functions carrying a ``_semantikon_workflow`` attribute are embedded as
    sub-workflows after checking that their number of outputs matches what
    the surrounding graph expects; plain functions are converted via
    ``get_node_dict``.

    Args:
        data: Mapping from node label to {"function": callable, ...}.
        output_counts: Expected output count per node label.
        control_flow: Kept for interface compatibility (currently unused).

    Returns:
        dict[str, dict]: Node dictionaries keyed by label.
    """
    result = {}
    for label, function in data.items():
        func = function["function"]
        if not hasattr(func, "_semantikon_workflow"):
            result[label] = get_node_dict(
                function=func,
                inputs=parse_input_args(func),
                outputs=_get_node_outputs(func, output_counts.get(label, 1)),
            )
            continue
        n_outputs = len(func._semantikon_workflow["outputs"])
        if output_counts[label] != n_outputs:
            raise ValueError(
                f"{label} has {len(func._semantikon_workflow['outputs'])} outputs, "
                f"but {output_counts[label]} expected"
            )
        node = func._semantikon_workflow.copy()
        node["label"] = label
        result[label] = node
        if hasattr(func, "_semantikon_metadata"):
            result[label].update(func._semantikon_metadata)
    return result
669
+
670
+
671
+ def _remove_index(s: str) -> str:
672
+ return "_".join(s.split("_")[:-1])
673
+
674
+
675
+ def _get_control_flow(data: dict[str, Any]) -> str:
676
+ """
677
+ Get the control flow name
678
+
679
+ Args:
680
+ data (dict[str, Any]): metadata of the edge (which is stored in the
681
+ third element of each edge of nx.Digraph)
682
+
683
+ Returns:
684
+ (str): Control flow name (e.g. While_0, For_3 etc.)
685
+ """
686
+ return data.get("control_flow", "").split("-")[0]
687
+
688
+
689
def _get_sorted_edges(graph: nx.DiGraph) -> list:
    """
    Sort the edges of the graph based on the topological order of the nodes.

    Args:
        graph (nx.DiGraph): The directed graph representing the function.

    Returns:
        list: Edges (with data) ordered by the topological rank of their
            source node.

    Example:

    >>> graph.add_edges_from([('A', 'B'), ('B', 'D'), ('A', 'C'), ('C', 'D')])
    >>> sorted_edges = _get_sorted_edges(graph)
    >>> print(sorted_edges)

    Output:

    >>> [('A', 'B', {}), ('A', 'C', {}), ('B', 'D', {}), ('C', 'D', {})]
    """
    rank = {name: position for position, name in enumerate(topological_sort(graph))}
    return sorted(graph.edges.data(), key=lambda edge: rank[edge[0]])
712
+
713
+
714
def _remove_and_reconnect_nodes(
    G: nx.DiGraph, nodes_to_remove: list[str]
) -> nx.DiGraph:
    """Remove each listed node from ``G`` in place, bridging every
    predecessor directly to every successor, and return ``G``."""
    for node in set(nodes_to_remove):
        sources = list(G.predecessors(node))
        sinks = list(G.successors(node))
        G.add_edges_from((s, t) for s in sources for t in sinks)
        G.remove_node(node)
    return G
725
+
726
+
727
def _get_edges(graph: nx.DiGraph, nodes: dict[str, dict]) -> list[tuple[str, str]]:
    """Translate the versioned variable graph into workflow port edges.

    Variable nodes (e.g. ``x_0``) are rewritten into fully qualified port
    names (``node.inputs.x`` / ``node.outputs.x`` / ``inputs.x`` /
    ``outputs.x``) and then removed, directly reconnecting the ports.

    Args:
        graph (nx.DiGraph): Variable-level graph of one (sub)workflow.
        nodes (dict[str, dict]): Node dictionaries providing the ordered
            input/output port names used to resolve positional indices.

    Returns:
        list[tuple[str, str]]: Edges between qualified port names.
    """
    # ordered port names per node; positional resolution below relies on the
    # insertion order of the parsed inputs/outputs dictionaries
    io_dict = {
        key: {
            "input": list(data["inputs"].keys()),
            "output": list(data["outputs"].keys()),
        }
        for key, data in nodes.items()
    }
    edges = []
    nodes_to_remove = []
    for edge in graph.edges.data():
        if edge[0] == "input":
            # workflow-level input port feeding a variable
            edges.append([edge[0] + "s." + _remove_index(edge[1]), edge[1]])
            nodes_to_remove.append(edge[1])
        elif edge[1] == "output":
            # variable feeding a workflow-level output port
            edges.append([edge[0], edge[1] + "s." + _remove_index(edge[0])])
            nodes_to_remove.append(edge[0])
        elif edge[2]["type"] == "input":
            # variable consumed by a node: resolve the port by keyword name
            # or by positional index
            if "input_name" in edge[2]:
                tag = edge[2]["input_name"]
            elif "input_index" in edge[2]:
                tag = io_dict[edge[1]]["input"][edge[2]["input_index"]]
            else:
                raise ValueError
            edges.append([edge[0], edge[1] + ".inputs." + tag])
            nodes_to_remove.append(edge[0])
        elif edge[2]["type"] == "output":
            # variable produced by a node: resolve by index, explicit name,
            # or default to the node's first output port
            if "output_index" in edge[2]:
                tag = io_dict[edge[0]]["output"][edge[2]["output_index"]]
            elif "output_name" in edge[2]:
                tag = edge[2]["output_name"]
            else:
                tag = io_dict[edge[0]]["output"][0]
            edges.append([edge[0] + ".outputs." + tag, edge[1]])
            nodes_to_remove.append(edge[1])
    # drop the intermediate variable nodes, bridging their neighbors
    new_graph = _remove_and_reconnect_nodes(nx.DiGraph(edges), nodes_to_remove)
    return list(new_graph.edges)
764
+
765
+
766
+ def get_node_dict(
767
+ function: Callable,
768
+ inputs: dict[str, dict] | None = None,
769
+ outputs: dict[str, dict] | None = None,
770
+ ) -> dict:
771
+ """
772
+ Get a dictionary representation of the function node.
773
+
774
+ Args:
775
+ func (Callable): The function to be analyzed.
776
+ data_format (str): The format of the output. Options are "semantikon" and
777
+ "ape".
778
+
779
+ Returns:
780
+ (dict) A dictionary representation of the function node.
781
+ """
782
+ if inputs is None:
783
+ inputs = parse_input_args(function)
784
+ if outputs is None:
785
+ outputs = _get_node_outputs(function)
786
+ data = {
787
+ "inputs": inputs,
788
+ "outputs": outputs,
789
+ "function": function,
790
+ "type": "Function",
791
+ }
792
+ if hasattr(function, "_semantikon_metadata"):
793
+ data.update(function._semantikon_metadata)
794
+ return data
795
+
796
+
797
+ def _to_workflow_dict_entry(
798
+ inputs: dict[str, dict],
799
+ outputs: dict[str, dict],
800
+ nodes: dict[str, dict],
801
+ edges: list[tuple[str, str]],
802
+ label: str,
803
+ **kwargs,
804
+ ) -> dict[str, object]:
805
+ assert all("inputs" in v for v in nodes.values())
806
+ assert all("outputs" in v for v in nodes.values())
807
+ assert all(
808
+ "function" in v or ("nodes" in v and "edges" in v) for v in nodes.values()
809
+ )
810
+ return {
811
+ "inputs": inputs,
812
+ "outputs": outputs,
813
+ "nodes": nodes,
814
+ "edges": edges,
815
+ "label": label,
816
+ "type": "Workflow",
817
+ } | kwargs
818
+
819
+
820
+ def _get_test_dict(f_dict: dict[str, dict]) -> dict[str, str]:
821
+ """
822
+ dict to translate test and iter nodes into "test" and "iter"
823
+
824
+ Args:
825
+ f_dict (dict[str, dict]): Function dictionary
826
+
827
+ Returns:
828
+ dict[str, str]: Translation of node name to "test" or "iter"
829
+ """
830
+ return {
831
+ key: tag
832
+ for key, value in f_dict.items()
833
+ for tag in ["test", "iter"]
834
+ if value.get("control_flow", "").endswith(tag)
835
+ }
836
+
837
+
838
def _nest_nodes(
    graph: nx.DiGraph, nodes: dict[str, dict], f_dict: dict[str, dict]
) -> tuple[dict[str, dict], list[tuple[str, str]]]:
    """
    Nest workflow nodes

    Args:
        graph (nx.DiGraph): The directed graph representing the function.
        nodes (dict[str, dict]): The dictionary of nodes.
        f_dict (dict[str, dict]): The dictionary of functions.

    Returns:
        tuple: The top-level node dictionary (with control-flow bodies
            injected as nested workflow nodes) and the top-level edges.
    """
    test_dict = _get_test_dict(f_dict=f_dict)
    cf_graph = _get_control_flow_graph(_extract_control_flows(graph))
    subgraphs = _get_subgraphs(graph, cf_graph)
    injected_nodes: dict[str, Any] = {}
    # bottom-up: inner bodies are already injected when their parent is built
    for cf_key in list(topological_sort(cf_graph))[::-1]:
        # rename condition/iterator call nodes to plain "test"/"iter"
        subgraph = nx.relabel_nodes(subgraphs[cf_key], test_dict)
        new_key = "injected_" + cf_key.replace("/", "_") if len(cf_key) > 0 else cf_key
        current_nodes = {}
        for key in _extract_functions_from_graph(subgraphs[cf_key]):
            if key in test_dict:
                current_nodes[test_dict[key]] = nodes[key]
            elif key in nodes:
                current_nodes[key] = nodes[key]
            else:
                # child control-flow body injected in a previous iteration
                current_nodes[key] = injected_nodes.pop(key)
        io_ = _detect_io_variables_from_control_flow(graph, subgraph)
        injected_nodes[new_key] = {
            "nodes": current_nodes,
            "edges": _get_edges(subgraph, current_nodes),
            "label": new_key,
            "inputs": {_remove_index(key): {} for key in io_["inputs"]},
            "outputs": {_remove_index(key): {} for key in io_["outputs"]},
        }
        # promote the condition/iterator node out of the body dictionary
        for tag in ["test", "iter"]:
            if tag in injected_nodes[new_key]["nodes"]:
                injected_nodes[new_key][tag] = injected_nodes[new_key]["nodes"].pop(tag)
    # "" is the root control flow, i.e. the top-level workflow
    return injected_nodes[""]["nodes"], injected_nodes[""]["edges"]
879
+
880
+
881
def get_workflow_dict(func: Callable) -> dict[str, object]:
    """
    Get a dictionary representation of the workflow for a given function.

    Args:
        func (Callable): The function to be analyzed.

    Returns:
        dict: A dictionary representation of the workflow, including inputs,
            outputs, nodes, edges, and label.
    """
    graph, f_dict = analyze_function(func)
    node_dicts = _get_nodes(f_dict, _get_output_counts(graph))
    nested_nodes, edges = _nest_nodes(graph, node_dicts, f_dict)
    workflow_inputs = parse_input_args(func)
    workflow_outputs = _get_node_outputs(func)
    return _to_workflow_dict_entry(
        inputs=workflow_inputs,
        outputs=workflow_outputs,
        nodes=nested_nodes,
        edges=edges,
        label=func.__name__,
    )
902
+
903
+
904
+ def _get_missing_edges(edge_list: list[tuple[str, str]]) -> list[tuple[str, str]]:
905
+ """
906
+ Insert processes into the data edges. Take the following workflow:
907
+
908
+ >>> y = f(x=x)
909
+ >>> z = g(y=y)
910
+
911
+ The data flow is
912
+
913
+ - f.inputs.x -> f.outputs.y
914
+ - f.outputs.y -> g.inputs.y
915
+ - g.inputs.y -> g.outputs.z
916
+
917
+ `_get_missing_edges` adds the processes:
918
+
919
+ - f.inputs.x -> f
920
+ - f -> f.outputs.y
921
+ - f.outputs.y -> g.inputs.y
922
+ - g.inputs.y -> g
923
+ - g -> g.outputs.z
924
+ """
925
+ extra_edges = []
926
+ for edge in edge_list:
927
+ for tag in edge:
928
+ if len(tag.split(".")) < 3:
929
+ continue
930
+ if tag.split(".")[1] == "inputs":
931
+ new_edge = (tag, tag.split(".")[0])
932
+ elif tag.split(".")[1] == "outputs":
933
+ new_edge = (tag.split(".")[0], tag)
934
+ if new_edge not in extra_edges:
935
+ extra_edges.append(new_edge)
936
+ return extra_edges
937
+
938
+
939
class _Workflow:
    """
    Executable wrapper around a workflow dictionary produced by
    ``get_workflow_dict``.

    The dictionary is held by reference and mutated in place during
    :meth:`run`: every input and output entry gains a ``"value"`` key as
    execution proceeds.
    """

    def __init__(self, workflow_dict: dict[str, Any]):
        self._workflow = workflow_dict

    @cached_property
    def _all_edges(self) -> list[tuple[str, str]]:
        # Data edges plus the implicit port -> process -> port edges.
        edges = cast(dict[str, list], self._workflow)["edges"]
        return edges + _get_missing_edges(edges)

    @cached_property
    def _graph(self) -> nx.DiGraph:
        return nx.DiGraph(self._all_edges)

    @cached_property
    def _execution_list(self) -> list[list[str]]:
        return find_parallel_execution_levels(self._graph)

    def _sanitize_input(self, *args, **kwargs) -> dict[str, Any]:
        """Map positional args onto input names and merge them into kwargs."""
        keys = list(self._workflow["inputs"].keys())
        if len(args) > len(keys):
            # Raise a TypeError (like a normal Python call) instead of an
            # IndexError from keys[ii] below.
            raise TypeError(
                f"{self._workflow['label']}() takes {len(keys)} positional"
                f" arguments but {len(args)} were given"
            )
        for ii, arg in enumerate(args):
            if keys[ii] in kwargs:
                # Bug fix: the second fragment was not an f-string, so the
                # message printed the literal text "{keys[ii]}".
                raise TypeError(
                    f"{self._workflow['label']}() got multiple values for"
                    f" argument '{keys[ii]}'"
                )
            kwargs[keys[ii]] = arg
        return kwargs

    def _set_inputs(self, *args, **kwargs):
        """Store the given call arguments on the workflow's input ports."""
        kwargs = self._sanitize_input(*args, **kwargs)
        for key, value in kwargs.items():
            if key not in self._workflow["inputs"]:
                raise TypeError(f"Unexpected keyword argument '{key}'")
            self._workflow["inputs"][key]["value"] = value

    def _get_value_from_data(self, node: dict[str, Any]) -> Any:
        # Fall back to the declared default when no value has been set yet;
        # the default is cached as the value for subsequent reads.
        if "value" not in node:
            node["value"] = node["default"]
        return node["value"]

    def _get_value_from_global(self, path: str) -> Any:
        """Resolve a two-part tag ("inputs.x"/"outputs.y") on the workflow."""
        io, var = path.split(".")
        return self._get_value_from_data(self._workflow[io][var])

    def _get_value_from_node(self, path: str) -> Any:
        """Resolve a three-part tag ("node.inputs.x") on a child node."""
        node, io, var = path.split(".")
        return self._get_value_from_data(self._workflow["nodes"][node][io][var])

    def _set_value_from_global(self, path, value):
        io, var = path.split(".")
        self._workflow[io][var]["value"] = value

    def _set_value_from_node(self, path, value):
        node, io, var = path.split(".")
        try:
            self._workflow["nodes"][node][io][var]["value"] = value
        except KeyError:
            raise KeyError(f"{path} not found in {node}") from None

    def _execute_node(self, function: str) -> Any:
        """
        Run one child node and return its output value(s).

        A node without a "function" entry is treated as a nested workflow and
        executed recursively; its outputs are collapsed to a scalar when there
        is exactly one.
        """
        node = self._workflow["nodes"][function]
        input_data = {}
        try:
            for key, content in node["inputs"].items():
                if "value" not in content:
                    content["value"] = content["default"]
                input_data[key] = content["value"]
        except KeyError:
            raise KeyError(f"value not defined for {function}") from None
        if "function" not in node:
            workflow = _Workflow(node)
            outputs = [
                d["value"] for d in workflow.run(**input_data)["outputs"].values()
            ]
            if len(outputs) == 1:
                outputs = outputs[0]
        else:
            outputs = node["function"](**input_data)
        return outputs

    def _set_value(self, tag, value):
        # Bare node names (no ".") are silently ignored: process nodes have
        # no value storage of their own.
        if len(tag.split(".")) == 2 and tag.split(".")[0] in ("inputs", "outputs"):
            self._set_value_from_global(tag, value)
        elif len(tag.split(".")) == 3 and tag.split(".")[1] in ("inputs", "outputs"):
            self._set_value_from_node(tag, value)
        elif "." in tag:
            raise ValueError(f"{tag} not recognized")

    def _get_value(self, tag: str):
        if len(tag.split(".")) == 2 and tag.split(".")[0] in ("inputs", "outputs"):
            return self._get_value_from_global(tag)
        elif len(tag.split(".")) == 3 and tag.split(".")[1] in ("inputs", "outputs"):
            return self._get_value_from_node(tag)
        elif "." not in tag:
            return self._execute_node(tag)
        else:
            raise ValueError(f"{tag} not recognized")

    def run(self, *args, **kwargs) -> dict[str, Any]:
        """
        Execute the workflow level by level and return the mutated workflow
        dictionary with all intermediate and final values filled in.
        """
        self._set_inputs(*args, **kwargs)
        for current_list in self._execution_list:
            for item in current_list:
                values = self._get_value(item)
                nodes = self._graph.edges(item)
                # A process node with several outgoing edges returns a tuple;
                # distribute its elements over the successors positionally.
                if "." not in item and len(nodes) > 1:
                    for value, node in zip(values, nodes, strict=False):
                        self._set_value(node[1], value)
                else:
                    for node in nodes:
                        self._set_value(node[1], values)
        return self._workflow
1050
+
1051
+
1052
def find_parallel_execution_levels(G: nx.DiGraph) -> list[list[str]]:
    """
    Find levels of parallel execution in a directed acyclic graph (DAG).

    Performs a Kahn-style layered traversal: each level collects the nodes
    whose remaining in-degree has dropped to zero.

    Args:
        G (nx.DiGraph): The directed graph representing the function.

    Returns:
        list[list[str]]: A list of lists, where each inner list contains nodes
            that can be executed in parallel.

    Comment:
        This function only gives you a list of nodes that can be executed in
        parallel, but does not tell you which processes can be executed in
        case there is a process that takes longer at a higher level.
    """
    remaining = {
        node: degree for node, degree in cast(Iterable[tuple[Any, int]], G.in_degree())
    }
    frontier = deque(node for node in G.nodes if remaining[node] == 0)
    levels: list[list[str]] = []

    while frontier:
        level = list(frontier)
        # Levels containing the artificial "input"/"output" markers are
        # dropped entirely.
        if "input" not in level and "output" not in level:
            levels.append(level)

        upcoming: deque = deque()
        for node in level:
            for successor in G.successors(node):
                remaining[successor] -= 1
                if remaining[successor] == 0:
                    upcoming.append(successor)

        frontier = upcoming

    return levels
1087
+
1088
+
1089
def workflow(func: Callable) -> FunctionWithWorkflow:
    """
    Decorator turning a plain Python function into an executable workflow.

    Args:
        func (Callable): The function to be converted into a workflow.

    Returns:
        FunctionWithWorkflow: A callable object that wraps the original
        function, exposing its workflow dictionary and a ``run`` method.

    Example:

        >>> def operation(x: float, y: float) -> tuple[float, float]:
        >>>     return x + y, x - y
        >>>
        >>>
        >>> def add(x: float = 2.0, y: float = 1) -> float:
        >>>     return x + y
        >>>
        >>>
        >>> def multiply(x: float, y: float = 5) -> float:
        >>>     return x * y
        >>>
        >>>
        >>> @workflow
        >>> def example_macro(a=10, b=20):
        >>>     c, d = operation(a, b)
        >>>     e = add(c, y=d)
        >>>     f = multiply(e)
        >>>     return f
        >>>
        >>>
        >>> @workflow
        >>> def example_workflow(a=10, b=20):
        >>>     y = example_macro(a, b)
        >>>     z = add(y, b)
        >>>     return z

    Here ``example_macro`` is a workflow composed of ``operation``, ``add``
    and ``multiply``; it is in turn nested inside ``example_workflow``.
    Calling ``run`` on either returns the dictionary representation of the
    workflow with every intermediate step and output filled in.
    """
    workflow_dict = get_workflow_dict(func)
    executor = _Workflow(workflow_dict)
    return FunctionWithWorkflow(func, workflow_dict, executor.run)
1137
+
1138
+
1139
def get_ports(
    func: Callable, separate_return_tuple: bool = True, strict: bool = False
) -> tuple[Inputs, Outputs]:
    """
    Derive the input and output port collections of ``func`` from its
    signature and (annotated) type hints.

    Args:
        func (Callable): The function to inspect.
        separate_return_tuple (bool): Whether a ``tuple`` return annotation is
            split into one output port per element.
        strict (bool): Forwarded to ``get_return_labels``.

    Returns:
        tuple[Inputs, Outputs]: The derived input and output ports.
    """
    hints = get_annotated_type_hints(func)
    return_annotation = hints.pop("return", inspect.Parameter.empty)
    labels = get_return_labels(
        func, separate_tuple=separate_return_tuple, strict=strict
    )
    if separate_return_tuple and get_origin(return_annotation) is tuple:
        out_meta = {
            label: meta_to_dict(element, flatten_metadata=False)
            for label, element in zip(
                labels, get_args(return_annotation), strict=False
            )
        }
    else:
        out_meta = {labels[0]: meta_to_dict(return_annotation, flatten_metadata=False)}
    in_meta = {}
    for name, parameter in inspect.signature(func).parameters.items():
        annotation = hints.get(name, parameter.annotation)
        in_meta[name] = meta_to_dict(
            annotation, parameter.default, flatten_metadata=False
        )
    inputs = Inputs(**{name: Input(label=name, **meta) for name, meta in in_meta.items()})
    outputs = Outputs(
        **{name: Output(label=name, **meta) for name, meta in out_meta.items()}
    )
    return inputs, outputs
1166
+
1167
+
1168
def get_node(func: Callable, label: str | None = None) -> Function | Workflow:
    """
    Convert ``func`` into a node representation.

    A function carrying a ``_semantikon_workflow`` attribute is parsed as a
    nested workflow; anything else becomes a plain function node. Metadata is
    taken from the optional ``_semantikon_metadata`` attribute.

    Args:
        func (Callable): The function to convert.
        label (str | None): Optional label overriding ``func.__name__``.

    Returns:
        Function | Workflow: The node representation.
    """
    raw_metadata = getattr(func, "_semantikon_metadata", MISSING)
    if isinstance(raw_metadata, Missing):
        metadata: CoreMetadata | Missing = raw_metadata
    else:
        metadata = CoreMetadata.from_dict(raw_metadata)

    if hasattr(func, "_semantikon_workflow"):
        return parse_workflow(func._semantikon_workflow, metadata)
    return parse_function(func, metadata, label=label)
1182
+
1183
+
1184
def parse_function(
    func: Callable, metadata: CoreMetadata | Missing, label: str | None = None
) -> Function:
    """
    Wrap ``func`` as a :class:`Function` node with ports derived from its
    signature.

    Args:
        func (Callable): The function to wrap.
        metadata (CoreMetadata | Missing): Node metadata, if any.
        label (str | None): Optional label; defaults to ``func.__name__``.

    Returns:
        Function: The function node.
    """
    inputs, outputs = get_ports(func)
    node_label = label if label is not None else func.__name__
    return Function(
        label=node_label,
        inputs=inputs,
        outputs=outputs,
        function=func,
        metadata=metadata,
    )
1195
+
1196
+
1197
def _port_from_dictionary(
    io_dictionary: dict[str, object], label: str, port_class: type[PortType]
) -> PortType:
    """
    Convert one input/output subdictionary of a traditional
    ``_semantikon_workflow`` dictionary into a port dataclass, nesting any
    loose metadata fields into a :class:`TypeMetadata` instance.

    Note: ``io_dictionary`` is mutated in place (metadata keys are popped,
    "metadata" and "label" are set).
    """
    metadata_field_names = [
        field.name
        for field in dataclasses.fields(TypeMetadata)
        if field.name in io_dictionary
    ]
    if metadata_field_names:
        extracted = {name: io_dictionary.pop(name) for name in metadata_field_names}
        io_dictionary["metadata"] = TypeMetadata.from_dict(extracted)
    io_dictionary["label"] = label
    return port_class.from_dict(io_dictionary)
1212
+
1213
+
1214
def _input_from_dictionary(io_dictionary: dict[str, object], label: str) -> Input:
    """Build an :class:`Input` port from a workflow-dictionary entry."""
    return _port_from_dictionary(
        io_dictionary=io_dictionary, label=label, port_class=Input
    )
1216
+
1217
+
1218
def _output_from_dictionary(io_dictionary: dict[str, object], label: str) -> Output:
    """Build an :class:`Output` port from a workflow-dictionary entry."""
    return _port_from_dictionary(
        io_dictionary=io_dictionary, label=label, port_class=Output
    )
1220
+
1221
+
1222
def parse_workflow(
    semantikon_workflow: dict[str, Any], metadata: CoreMetadata | Missing = MISSING
) -> Workflow:
    """
    Convert a workflow dictionary into a :class:`Workflow` dataclass tree.

    Child entries with a "function" key become function nodes; all others are
    parsed recursively as nested workflows. Edges are stored target-keyed,
    i.e. as a mapping from each edge's target to its source.

    Args:
        semantikon_workflow (dict[str, Any]): The workflow dictionary.
        metadata (CoreMetadata | Missing): Optional workflow metadata.

    Returns:
        Workflow: The parsed workflow.
    """
    label = semantikon_workflow["label"]
    input_ports = {}
    for key, data in semantikon_workflow["inputs"].items():
        input_ports[key] = _input_from_dictionary(data, label=key)
    output_ports = {}
    for key, data in semantikon_workflow["outputs"].items():
        output_ports[key] = _output_from_dictionary(data, label=key)
    node_map = {}
    for key, data in semantikon_workflow["nodes"].items():
        if "function" in data:
            node_map[key] = get_node(data["function"], label=key)
        else:
            node_map[key] = parse_workflow(data)
    edge_map = {target: source for source, target in semantikon_workflow["edges"]}
    return Workflow(
        label=label,
        inputs=Inputs(**input_ports),
        outputs=Outputs(**output_ports),
        nodes=Nodes(**node_map),
        edges=Edges(**edge_map),
        metadata=metadata,
    )