python-workflow-definition 0.0.1__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
@@ -0,0 +1,138 @@
1
+ from importlib import import_module
2
+ import json
3
+ import traceback
4
+
5
+ from aiida import orm
6
+ from aiida_pythonjob.data.serializer import general_serializer
7
+ from aiida_workgraph import WorkGraph, task
8
+ from aiida_workgraph.socket import TaskSocketNamespace
9
+
10
+ from python_workflow_definition.shared import (
11
+ convert_nodes_list_to_dict,
12
+ NODES_LABEL,
13
+ EDGES_LABEL,
14
+ SOURCE_LABEL,
15
+ SOURCE_PORT_LABEL,
16
+ TARGET_LABEL,
17
+ TARGET_PORT_LABEL,
18
+ )
19
+
20
+
21
def load_workflow_json(file_name: str) -> WorkGraph:
    """Build an aiida-workgraph WorkGraph from a workflow-definition JSON file.

    Function nodes (string "module.name") become WorkGraph tasks; all other
    nodes are serialized to AiiDA data nodes and wired into task inputs.

    :param file_name: path of the workflow JSON file to read.
    :return: the assembled WorkGraph.
    """
    with open(file_name) as f:
        data = json.load(f)

    wg = WorkGraph()
    # Maps node id (str) -> WorkGraph task or AiiDA data node.
    task_name_mapping = {}

    for id, identifier in convert_nodes_list_to_dict(
        nodes_list=data[NODES_LABEL]
    ).items():
        if isinstance(identifier, str) and "." in identifier:
            # Dotted string: import the referenced function and add it as a task.
            p, m = identifier.rsplit(".", 1)
            mod = import_module(p)
            func = getattr(mod, m)
            wg.add_task(func)
            # Remove the default "result" output; actual outputs are added
            # later from the edge data.
            del wg.tasks[-1].outputs["result"]
            task_name_mapping[id] = wg.tasks[-1]
        else:
            # Plain value: serialize it into an AiiDA data node.
            data_node = general_serializer(identifier)
            task_name_mapping[id] = data_node

    # Wire up the edges between tasks / data nodes.
    for link in data[EDGES_LABEL]:
        to_task = task_name_mapping[str(link[TARGET_LABEL])]
        # If the target port does not exist yet, the value is passed through
        # kwargs, so add a generic input socket for it.
        if link[TARGET_PORT_LABEL] not in to_task.inputs:
            to_socket = to_task.add_input("workgraph.any", name=link[TARGET_PORT_LABEL])
        else:
            to_socket = to_task.inputs[link[TARGET_PORT_LABEL]]
        from_task = task_name_mapping[str(link[SOURCE_LABEL])]
        if isinstance(from_task, orm.Data):
            # Static data: assign directly instead of creating a graph link.
            to_socket.value = from_task
        else:
            try:
                # An anonymous source port means the task's whole result.
                if link[SOURCE_PORT_LABEL] is None:
                    link[SOURCE_PORT_LABEL] = "result"
                # Outputs were not declared at task creation, so create the
                # output socket on demand and assume it exists at runtime.
                if link[SOURCE_PORT_LABEL] not in from_task.outputs:
                    from_socket = from_task.add_output(
                        "workgraph.any",
                        name=link[SOURCE_PORT_LABEL],
                        metadata={"is_function_output": True},
                    )
                else:
                    from_socket = from_task.outputs[link[SOURCE_PORT_LABEL]]

                wg.add_link(from_socket, to_socket)
            except Exception as e:
                # Best-effort linking: report and continue with remaining edges.
                traceback.print_exc()
                print("Failed to link", link, "with error:", e)
    return wg
78
+
79
+
80
def write_workflow_json(wg: WorkGraph, file_name: str) -> dict:
    """Serialize a WorkGraph to the shared workflow-definition JSON format.

    Tasks become function nodes, task links become edges, and plain AiiDA
    data inputs become value nodes connected by anonymous edges.

    :param wg: the WorkGraph to export.
    :param file_name: destination path of the JSON file.
    :return: the serialized data dict that was written.
    """
    data = {NODES_LABEL: [], EDGES_LABEL: []}
    # task name -> integer node id
    node_name_mapping = {}
    # data-node uuid -> integer node id (to reuse shared inputs)
    data_node_name_mapping = {}
    i = 0
    for node in wg.tasks:
        executor = node.get_executor()
        node_name_mapping[node.name] = i

        # Store the fully qualified "module.function" path.
        callable_name = executor["callable_name"]
        callable_name = f"{executor['module_path']}.{callable_name}"
        data[NODES_LABEL].append({"id": i, "function": callable_name})
        i += 1

    for link in wg.links:
        link_data = link.to_dict()
        # The default "result" socket is stored as the anonymous port None.
        if link_data["from_socket"] == "result":
            link_data["from_socket"] = None
        # Rename workgraph link fields to the shared edge schema.
        link_data[TARGET_LABEL] = node_name_mapping[link_data.pop("to_node")]
        link_data[TARGET_PORT_LABEL] = link_data.pop("to_socket")
        link_data[SOURCE_LABEL] = node_name_mapping[link_data.pop("from_node")]
        link_data[SOURCE_PORT_LABEL] = link_data.pop("from_socket")
        data[EDGES_LABEL].append(link_data)

    for node in wg.tasks:
        for input in node.inputs:
            # Namespaces are assumed never to be used as direct inputs.
            if isinstance(input, TaskSocketNamespace):
                continue
            if isinstance(input.value, orm.Data):
                if input.value.uuid not in data_node_name_mapping:
                    # Unwrap the AiiDA node back into a plain JSON value.
                    if isinstance(input.value, orm.List):
                        raw_value = input.value.get_list()
                    elif isinstance(input.value, orm.Dict):
                        raw_value = input.value.get_dict()
                        # get_dict() injects a spurious "node_type" key; drop it.
                        raw_value.pop("node_type", None)
                    else:
                        raw_value = input.value.value
                    data[NODES_LABEL].append({"id": i, "value": raw_value})
                    input_node_name = i
                    data_node_name_mapping[input.value.uuid] = input_node_name
                    i += 1
                else:
                    # Reuse the node id of an already-exported data node.
                    input_node_name = data_node_name_mapping[input.value.uuid]
                data[EDGES_LABEL].append(
                    {
                        TARGET_LABEL: node_name_mapping[node.name],
                        TARGET_PORT_LABEL: input._name,
                        SOURCE_LABEL: input_node_name,
                        SOURCE_PORT_LABEL: None,
                    }
                )
    with open(file_name, "w") as f:
        json.dump(data, f, indent=2)

    return data
@@ -0,0 +1,74 @@
1
+ from concurrent.futures import Executor
2
+ from importlib import import_module
3
+ from inspect import isfunction
4
+ import json
5
+
6
+
7
+ from python_workflow_definition.shared import (
8
+ get_dict,
9
+ get_list,
10
+ get_kwargs,
11
+ get_source_handles,
12
+ convert_nodes_list_to_dict,
13
+ NODES_LABEL,
14
+ EDGES_LABEL,
15
+ SOURCE_LABEL,
16
+ SOURCE_PORT_LABEL,
17
+ )
18
+ from python_workflow_definition.purepython import resort_total_lst, group_edges
19
+
20
+
21
def get_item(obj, key):
    """Subscript helper (``obj[key]``) that can be shipped to an Executor."""
    value = obj[key]
    return value
23
+
24
+
25
def _get_value(result_dict: dict, nodes_new_dict: dict, link_dict: dict, exe: Executor):
    """Resolve an edge source to a value or a future.

    Looks the source node up first among already-submitted futures, then among
    static node values. With an anonymous source port the object itself is
    returned; otherwise a ``get_item`` subscript is scheduled on the executor.

    :param result_dict: node id -> Future of an already-submitted task.
    :param nodes_new_dict: node id -> static value or imported function.
    :param link_dict: one edge's {source, sourcePort} mapping.
    :param exe: executor used to defer the subscript.
    :raises KeyError: if the source node id is unknown.
    """
    source, source_handle = link_dict[SOURCE_LABEL], link_dict[SOURCE_PORT_LABEL]
    if source in result_dict.keys():
        result = result_dict[source]
    elif source in nodes_new_dict.keys():
        result = nodes_new_dict[source]
    else:
        raise KeyError()
    if source_handle is None:
        return result
    else:
        # Bug fix: Executor.submit(fn, /, *args, **kwargs) takes the callable
        # positionally only (since Python 3.9); submit(fn=...) raises TypeError.
        return exe.submit(get_item, obj=result, key=source_handle)
37
+
38
+
39
def load_workflow_json(file_name: str, exe: Executor):
    """Execute a workflow JSON on a concurrent.futures Executor.

    Nodes are imported (dotted strings) or kept as static values, edges are
    grouped and topologically sorted, then each function node is submitted in
    dependency order.

    :param file_name: path of the workflow JSON file.
    :param exe: executor that runs the individual tasks.
    :return: the Future of the last submitted function node.
    """
    with open(file_name, "r") as f:
        content = json.load(f)

    edges_new_lst = content[EDGES_LABEL]
    nodes_new_dict = {}

    for k, v in convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL]).items():
        if isinstance(v, str) and "." in v:
            # Dotted string: resolve to the actual python function.
            p, m = v.rsplit(".", 1)
            mod = import_module(p)
            nodes_new_dict[int(k)] = getattr(mod, m)
        else:
            nodes_new_dict[int(k)] = v

    total_lst = group_edges(edges_new_lst)
    # Reorder so every node's sources are scheduled before the node itself.
    total_new_lst = resort_total_lst(total_lst=total_lst, nodes_dict=nodes_new_dict)

    result_dict = {}
    last_key = None
    for lst in total_new_lst:
        node = nodes_new_dict[lst[0]]
        if isfunction(node):
            kwargs = {
                k: _get_value(
                    result_dict=result_dict,
                    nodes_new_dict=nodes_new_dict,
                    link_dict=v,
                    exe=exe,
                )
                for k, v in lst[1].items()
            }
            result_dict[lst[0]] = exe.submit(node, **kwargs)
            last_key = lst[0]

    # The final node in topological order is taken as the workflow result.
    return result_dict[last_key]
@@ -0,0 +1,333 @@
1
+ import json
2
+ from importlib import import_module
3
+ from inspect import isfunction
4
+
5
+ import numpy as np
6
+ from jobflow import job, Flow
7
+
8
+ from python_workflow_definition.shared import (
9
+ get_dict,
10
+ get_list,
11
+ get_kwargs,
12
+ get_source_handles,
13
+ convert_nodes_list_to_dict,
14
+ NODES_LABEL,
15
+ EDGES_LABEL,
16
+ SOURCE_LABEL,
17
+ SOURCE_PORT_LABEL,
18
+ TARGET_LABEL,
19
+ TARGET_PORT_LABEL,
20
+ )
21
+
22
+
23
def _get_function_dict(flow: Flow):
    """Map each job's uuid to the python function it wraps."""
    mapping = {}
    for flow_job in flow.jobs:
        mapping[flow_job.uuid] = flow_job.function
    return mapping
25
+
26
+
27
+ def _get_nodes_dict(function_dict: dict):
28
+ nodes_dict, nodes_mapping_dict = {}, {}
29
+ for i, [k, v] in enumerate(function_dict.items()):
30
+ nodes_dict[i] = v
31
+ nodes_mapping_dict[k] = i
32
+
33
+ return nodes_dict, nodes_mapping_dict
34
+
35
+
36
def _get_edge_from_dict(
    target: str, key: str, value_dict: dict, nodes_mapping_dict: dict
) -> dict:
    """Translate a jobflow output reference into a shared-format edge.

    A single referenced attribute identifies one named output port of the
    source job; any other attribute count means the whole job output is
    consumed (anonymous port ``None``).

    :param target: integer id of the consuming node.
    :param key: name of the target input port.
    :param value_dict: serialized jobflow OutputReference (uuid + attributes).
    :param nodes_mapping_dict: job uuid -> integer node id.
    """
    # DRY: compute the only differing field once instead of duplicating the
    # whole edge dict in both branches.
    if len(value_dict["attributes"]) == 1:
        source_port = value_dict["attributes"][0][1]
    else:
        source_port = None
    return {
        TARGET_LABEL: target,
        TARGET_PORT_LABEL: key,
        SOURCE_LABEL: nodes_mapping_dict[value_dict["uuid"]],
        SOURCE_PORT_LABEL: source_port,
    }
53
+
54
+
55
def _get_edges_and_extend_nodes(
    flow_dict: dict, nodes_mapping_dict: dict, nodes_dict: dict
):
    """Extract edges from a serialized jobflow Flow, extending nodes as needed.

    For every job kwarg, one of four cases applies:
    1. the value is a serialized OutputReference (has @module/@class/@version)
       -> direct edge from the referenced job;
    2. a dict containing at least one such reference -> insert a synthetic
       ``get_dict`` node that assembles the dict, plus edges for each entry;
    3. a list containing at least one such reference -> insert a synthetic
       ``get_list`` node analogously;
    4. a plain value -> insert (or reuse) a value node with an anonymous edge.

    NOTE: ``nodes_dict`` is mutated in place (synthetic and value nodes are
    appended) and also returned.

    :return: (edges list in shared format, extended nodes dict)
    """
    edges_lst = []
    for job in flow_dict["jobs"]:
        for k, v in job["function_kwargs"].items():
            if (
                isinstance(v, dict)
                and "@module" in v
                and "@class" in v
                and "@version" in v
            ):
                # Case 1: direct reference to another job's output.
                edges_lst.append(
                    _get_edge_from_dict(
                        target=nodes_mapping_dict[job["uuid"]],
                        key=k,
                        value_dict=v,
                        nodes_mapping_dict=nodes_mapping_dict,
                    )
                )
            elif isinstance(v, dict) and any(
                [
                    isinstance(el, dict)
                    and "@module" in el
                    and "@class" in el
                    and "@version" in el
                    for el in v.values()
                ]
            ):
                # Case 2: dict mixing references and plain values -> get_dict node.
                node_dict_index = len(nodes_dict)
                nodes_dict[node_dict_index] = get_dict
                for kt, vt in v.items():
                    if (
                        isinstance(vt, dict)
                        and "@module" in vt
                        and "@class" in vt
                        and "@version" in vt
                    ):
                        edges_lst.append(
                            _get_edge_from_dict(
                                target=node_dict_index,
                                key=kt,
                                value_dict=vt,
                                nodes_mapping_dict=nodes_mapping_dict,
                            )
                        )
                    else:
                        # Reuse an existing value node when an equal value
                        # (compared via str) was already stored.
                        if vt not in nodes_dict.values():
                            node_index = len(nodes_dict)
                            nodes_dict[node_index] = vt
                        else:
                            node_index = {str(tv): tk for tk, tv in nodes_dict.items()}[
                                str(vt)
                            ]
                        edges_lst.append(
                            {
                                TARGET_LABEL: node_dict_index,
                                TARGET_PORT_LABEL: kt,
                                SOURCE_LABEL: node_index,
                                SOURCE_PORT_LABEL: None,
                            }
                        )
                # Connect the synthetic dict node to the consuming job.
                edges_lst.append(
                    {
                        TARGET_LABEL: nodes_mapping_dict[job["uuid"]],
                        TARGET_PORT_LABEL: k,
                        SOURCE_LABEL: node_dict_index,
                        SOURCE_PORT_LABEL: None,
                    }
                )
            elif isinstance(v, list) and any(
                [
                    isinstance(el, dict)
                    and "@module" in el
                    and "@class" in el
                    and "@version" in el
                    for el in v
                ]
            ):
                # Case 3: list mixing references and plain values -> get_list node.
                node_list_index = len(nodes_dict)
                nodes_dict[node_list_index] = get_list
                for kt, vt in enumerate(v):
                    if (
                        isinstance(vt, dict)
                        and "@module" in vt
                        and "@class" in vt
                        and "@version" in vt
                    ):
                        edges_lst.append(
                            _get_edge_from_dict(
                                target=node_list_index,
                                key=str(kt),
                                value_dict=vt,
                                nodes_mapping_dict=nodes_mapping_dict,
                            )
                        )
                    else:
                        if vt not in nodes_dict.values():
                            node_index = len(nodes_dict)
                            nodes_dict[node_index] = vt
                        else:
                            node_index = {str(tv): tk for tk, tv in nodes_dict.items()}[
                                str(vt)
                            ]
                        edges_lst.append(
                            {
                                TARGET_LABEL: node_list_index,
                                TARGET_PORT_LABEL: kt,
                                SOURCE_LABEL: node_index,
                                SOURCE_PORT_LABEL: None,
                            }
                        )
                # Connect the synthetic list node to the consuming job.
                edges_lst.append(
                    {
                        TARGET_LABEL: nodes_mapping_dict[job["uuid"]],
                        TARGET_PORT_LABEL: k,
                        SOURCE_LABEL: node_list_index,
                        SOURCE_PORT_LABEL: None,
                    }
                )
            else:
                # Case 4: plain value -> value node with anonymous edge.
                if v not in nodes_dict.values():
                    node_index = len(nodes_dict)
                    nodes_dict[node_index] = v
                else:
                    node_index = {tv: tk for tk, tv in nodes_dict.items()}[v]
                edges_lst.append(
                    {
                        TARGET_LABEL: nodes_mapping_dict[job["uuid"]],
                        TARGET_PORT_LABEL: k,
                        SOURCE_LABEL: node_index,
                        SOURCE_PORT_LABEL: None,
                    }
                )
    return edges_lst, nodes_dict
190
+
191
+
192
def _resort_total_lst(total_dict: dict, nodes_dict: dict) -> dict:
    """Topologically order the grouped edges so sources precede their targets.

    Nodes absent from ``total_dict`` have no dependencies and are always
    considered resolved.
    """
    dependent_nodes = sorted(total_dict)
    independent_nodes = [n for n in nodes_dict if n not in dependent_nodes]
    resolved = []
    ordered = {}
    # Repeated sweeps: each pass admits every node whose sources are resolved.
    while len(ordered) < len(total_dict):
        for node_id in dependent_nodes:
            if node_id in resolved:
                continue
            links = total_dict[node_id]
            sources = [link[SOURCE_LABEL] for link in links.values()]
            if all(s in resolved or s in independent_nodes for s in sources):
                resolved.append(node_id)
                ordered[node_id] = links
    return ordered
210
+
211
+
212
def _group_edges(edges_lst: list) -> dict:
    """Group edges by target node id, in order of first appearance.

    Each entry maps a target id to the kwargs mapping built by
    :func:`get_kwargs` from all edges pointing at that target.

    Performance: the original rescanned the full edge list for every edge
    (O(n^2)); a single grouping pass gives the same result in O(n).
    """
    grouped = {}
    for edge in edges_lst:
        grouped.setdefault(edge[TARGET_LABEL], []).append(edge)
    return {
        target_id: get_kwargs(lst=edge_group)
        for target_id, edge_group in grouped.items()
    }
223
+
224
+
225
+ def _get_input_dict(nodes_dict: dict) -> dict:
226
+ return {k: v for k, v in nodes_dict.items() if not isfunction(v)}
227
+
228
+
229
def _get_workflow(
    nodes_dict: dict, input_dict: dict, total_dict: dict, source_handles_dict: dict
) -> list:
    """Instantiate jobflow jobs for all function nodes, in topological order.

    :param nodes_dict: node id -> function or static value.
    :param input_dict: node id -> static value (non-function nodes).
    :param total_dict: topologically ordered node id -> kwargs edge mapping.
    :param source_handles_dict: node id -> list of named output ports used.
    :return: list of instantiated jobs (one per function node).
    """
    # Navigate from a job object to its (possibly named) output attribute.
    def get_attr_helper(obj, source_handle):
        if source_handle is None:
            return getattr(obj, "output")
        else:
            return getattr(getattr(obj, "output"), source_handle)

    memory_dict = {}
    for k in total_dict.keys():
        v = nodes_dict[k]
        if isfunction(v):
            # Declare named outputs via `data=` when downstream edges use them.
            if k in source_handles_dict.keys():
                fn = job(
                    method=v,
                    data=[el for el in source_handles_dict[k] if el is not None],
                )
            else:
                fn = job(method=v)
            # Each kwarg is either a static value or an upstream job output.
            kwargs = {
                kw: (
                    input_dict[vw[SOURCE_LABEL]]
                    if vw[SOURCE_LABEL] in input_dict
                    else get_attr_helper(
                        obj=memory_dict[vw[SOURCE_LABEL]],
                        source_handle=vw[SOURCE_PORT_LABEL],
                    )
                )
                for kw, vw in total_dict[k].items()
            }
            memory_dict[k] = fn(**kwargs)
    return list(memory_dict.values())
262
+
263
+
264
+ def _get_item_from_tuple(input_obj, index, index_lst):
265
+ if isinstance(input_obj, dict):
266
+ return input_obj[index]
267
+ else:
268
+ return list(input_obj)[index_lst.index(index)]
269
+
270
+
271
def load_workflow_json(file_name: str) -> Flow:
    """Reconstruct a jobflow Flow from a workflow-definition JSON file.

    :param file_name: path of the workflow JSON file to read.
    :return: a Flow containing one job per function node, wired per the edges.
    """
    with open(file_name, "r") as f:
        content = json.load(f)

    # Named source ports must be strings for attribute access on job outputs.
    edges_new_lst = []
    for edge in content[EDGES_LABEL]:
        if edge[SOURCE_PORT_LABEL] is None:
            edges_new_lst.append(edge)
        else:
            edges_new_lst.append(
                {
                    TARGET_LABEL: edge[TARGET_LABEL],
                    TARGET_PORT_LABEL: edge[TARGET_PORT_LABEL],
                    SOURCE_LABEL: edge[SOURCE_LABEL],
                    SOURCE_PORT_LABEL: str(edge[SOURCE_PORT_LABEL]),
                }
            )

    # Resolve dotted function strings; keep plain values as-is.
    nodes_new_dict = {}
    for k, v in convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL]).items():
        if isinstance(v, str) and "." in v:
            p, m = v.rsplit(".", 1)
            mod = import_module(p)
            nodes_new_dict[int(k)] = getattr(mod, m)
        else:
            nodes_new_dict[int(k)] = v

    source_handles_dict = get_source_handles(edges_lst=edges_new_lst)
    total_dict = _group_edges(edges_lst=edges_new_lst)
    input_dict = _get_input_dict(nodes_dict=nodes_new_dict)
    # Order nodes so every source job exists before its consumers.
    new_total_dict = _resort_total_lst(total_dict=total_dict, nodes_dict=nodes_new_dict)
    task_lst = _get_workflow(
        nodes_dict=nodes_new_dict,
        input_dict=input_dict,
        total_dict=new_total_dict,
        source_handles_dict=source_handles_dict,
    )
    return Flow(task_lst)
309
+
310
+
311
def write_workflow_json(flow: Flow, file_name: str = "workflow.json"):
    """Serialize a jobflow Flow to the shared workflow-definition JSON format.

    :param flow: the Flow to export.
    :param file_name: destination path of the JSON file.
    """
    flow_dict = flow.as_dict()
    function_dict = _get_function_dict(flow=flow)
    nodes_dict, nodes_mapping_dict = _get_nodes_dict(function_dict=function_dict)
    # nodes_dict is extended in place with synthetic/value nodes.
    edges_lst, nodes_dict = _get_edges_and_extend_nodes(
        flow_dict=flow_dict,
        nodes_mapping_dict=nodes_mapping_dict,
        nodes_dict=nodes_dict,
    )

    nodes_store_lst = []
    for k, v in nodes_dict.items():
        if isfunction(v):
            nodes_store_lst.append(
                {"id": k, "function": v.__module__ + "." + v.__name__}
            )
        elif isinstance(v, np.ndarray):
            # numpy arrays are not JSON serializable; store as nested lists.
            nodes_store_lst.append({"id": k, "value": v.tolist()})
        else:
            nodes_store_lst.append({"id": k, "value": v})

    with open(file_name, "w") as f:
        json.dump({NODES_LABEL: nodes_store_lst, EDGES_LABEL: edges_lst}, f)
@@ -0,0 +1,45 @@
1
+ import json
2
+
3
+ from IPython.display import SVG, display
4
+ import networkx as nx
5
+
6
+
7
+ from python_workflow_definition.purepython import group_edges
8
+ from python_workflow_definition.shared import (
9
+ get_kwargs,
10
+ convert_nodes_list_to_dict,
11
+ NODES_LABEL,
12
+ EDGES_LABEL,
13
+ SOURCE_LABEL,
14
+ SOURCE_PORT_LABEL,
15
+ )
16
+
17
+
18
def plot(file_name: str):
    """Render a workflow JSON file as an SVG graph in a Jupyter notebook.

    Uses networkx + graphviz ("dot") for layout and IPython display for
    output; edges between the same pair of nodes are merged into one edge
    whose label lists all transferred ports.

    :param file_name: path of the workflow JSON file to visualize.
    """
    with open(file_name, "r") as f:
        content = json.load(f)

    graph = nx.DiGraph()
    node_dict = convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL])
    total_lst = group_edges(edges_lst=content[EDGES_LABEL])

    for node_id, node_name in node_dict.items():
        graph.add_node(node_id, name=str(node_name), label=str(node_name))

    for edge_tuple in total_lst:
        target_node, edge_dict = edge_tuple
        # Collect one label per source node: plain kwarg name for anonymous
        # ports, "kwarg=result[port]" for named output ports.
        edge_label_dict = {}
        for k, v in edge_dict.items():
            if v[SOURCE_LABEL] not in edge_label_dict:
                edge_label_dict[v[SOURCE_LABEL]] = []
            if v[SOURCE_PORT_LABEL] is None:
                edge_label_dict[v[SOURCE_LABEL]].append(k)
            else:
                edge_label_dict[v[SOURCE_LABEL]].append(
                    k + "=result[" + v[SOURCE_PORT_LABEL] + "]"
                )
        for k, v in edge_label_dict.items():
            graph.add_edge(str(k), str(target_node), label=", ".join(v))

    svg = nx.nx_agraph.to_agraph(graph).draw(prog="dot", format="svg")
    display(SVG(svg))
@@ -0,0 +1,99 @@
1
+ import json
2
+ from importlib import import_module
3
+ from inspect import isfunction
4
+
5
+
6
+ from python_workflow_definition.shared import (
7
+ get_dict,
8
+ get_list,
9
+ get_kwargs,
10
+ get_source_handles,
11
+ convert_nodes_list_to_dict,
12
+ NODES_LABEL,
13
+ EDGES_LABEL,
14
+ SOURCE_LABEL,
15
+ SOURCE_PORT_LABEL,
16
+ TARGET_LABEL,
17
+ TARGET_PORT_LABEL,
18
+ )
19
+
20
+
21
def resort_total_lst(total_lst: list, nodes_dict: dict) -> list:
    """Topologically sort (node id, kwargs) pairs so sources come first.

    Nodes that never appear as a target have no dependencies and count as
    always available.
    """
    targets = sorted(entry[0] for entry in total_lst)
    dependency_free = [node for node in nodes_dict if node not in targets]
    done, reordered = [], []
    # Sweep repeatedly, admitting any node whose sources are all available.
    while len(reordered) < len(total_lst):
        for node_id, links in total_lst:
            if node_id in done:
                continue
            needed = [link[SOURCE_LABEL] for link in links.values()]
            if all(src in done or src in dependency_free for src in needed):
                done.append(node_id)
                reordered.append([node_id, links])
    return reordered
37
+
38
+
39
def group_edges(edges_lst: list) -> list:
    """Group edges by target id (descending) into (target, kwargs) tuples."""
    ordered_edges = sorted(edges_lst, key=lambda e: e[TARGET_LABEL], reverse=True)
    grouped = []
    current_target = ordered_edges[0][TARGET_LABEL]
    run = []
    for edge in ordered_edges:
        if edge[TARGET_LABEL] == current_target:
            run.append(edge)
        else:
            # Target changed: close the finished run and start the next one.
            grouped.append((current_target, get_kwargs(lst=run)))
            current_target = edge[TARGET_LABEL]
            run = [edge]
    grouped.append((current_target, get_kwargs(lst=run)))
    return grouped
52
+
53
+
54
def _get_value(result_dict: dict, nodes_new_dict: dict, link_dict: dict):
    """Resolve an edge source to its concrete value.

    The source node is looked up first among computed results, then among
    static node values; an anonymous port returns the whole object, a named
    port subscripts it.

    :param result_dict: node id -> already-computed result.
    :param nodes_new_dict: node id -> static value or function.
    :param link_dict: one edge's {source, sourcePort} mapping.
    :raises KeyError: if the source node id is unknown.
    """
    source, source_handle = link_dict[SOURCE_LABEL], link_dict[SOURCE_PORT_LABEL]
    if source in result_dict.keys():
        result = result_dict[source]
    elif source in nodes_new_dict.keys():
        result = nodes_new_dict[source]
    else:
        # Include the offending id: a bare KeyError() gave no diagnostic.
        raise KeyError(source)
    if source_handle is None:
        return result
    else:
        return result[source_handle]
66
+
67
+
68
def load_workflow_json(file_name: str):
    """Execute a workflow JSON synchronously with plain python calls.

    :param file_name: path of the workflow JSON file.
    :return: the result of the last function node in topological order.
    """
    with open(file_name, "r") as f:
        content = json.load(f)

    edges_new_lst = content[EDGES_LABEL]
    nodes_new_dict = {}
    for k, v in convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL]).items():
        if isinstance(v, str) and "." in v:
            # Dotted string: resolve to the actual python function.
            p, m = v.rsplit(".", 1)
            mod = import_module(p)
            nodes_new_dict[int(k)] = getattr(mod, m)
        else:
            nodes_new_dict[int(k)] = v

    total_lst = group_edges(edges_new_lst)
    # Order nodes so every dependency is evaluated before its consumers.
    total_new_lst = resort_total_lst(total_lst=total_lst, nodes_dict=nodes_new_dict)

    result_dict = {}
    last_key = None
    for lst in total_new_lst:
        node = nodes_new_dict[lst[0]]
        if isfunction(node):
            kwargs = {
                k: _get_value(
                    result_dict=result_dict, nodes_new_dict=nodes_new_dict, link_dict=v
                )
                for k, v in lst[1].items()
            }
            result_dict[lst[0]] = node(**kwargs)
            last_key = lst[0]

    return result_dict[last_key]
@@ -0,0 +1,292 @@
1
+ from importlib import import_module
2
+ from inspect import isfunction
3
+ import json
4
+ from typing import Optional
5
+
6
+ import numpy as np
7
+ from pyiron_base import job, Project
8
+ from pyiron_base.project.delayed import DelayedObject
9
+
10
+ from python_workflow_definition.shared import (
11
+ get_kwargs,
12
+ get_source_handles,
13
+ convert_nodes_list_to_dict,
14
+ NODES_LABEL,
15
+ EDGES_LABEL,
16
+ SOURCE_LABEL,
17
+ SOURCE_PORT_LABEL,
18
+ TARGET_LABEL,
19
+ TARGET_PORT_LABEL,
20
+ )
21
+
22
+
23
def _resort_total_lst(total_lst: list, nodes_dict: dict) -> list:
    """Topologically order (node id, kwargs) pairs; sources before consumers."""
    target_ids = sorted(pair[0] for pair in total_lst)
    free_nodes = [node_id for node_id in nodes_dict if node_id not in target_ids]
    scheduled = []
    result = []
    # Keep sweeping until every pair has been placed after its dependencies.
    while len(result) < len(total_lst):
        for node_id, connections in total_lst:
            if node_id in scheduled:
                continue
            sources = [c[SOURCE_LABEL] for c in connections.values()]
            ready = all(s in scheduled or s in free_nodes for s in sources)
            if ready:
                scheduled.append(node_id)
                result.append([node_id, connections])
    return result
39
+
40
+
41
def _group_edges(edges_lst: list) -> list:
    """Group edges into (target id, kwargs) tuples, targets in descending order."""
    by_target_desc = sorted(edges_lst, key=lambda e: e[TARGET_LABEL], reverse=True)
    grouped = []
    current = by_target_desc[0][TARGET_LABEL]
    batch = []
    for edge in by_target_desc:
        if edge[TARGET_LABEL] != current:
            # Flush the completed batch before starting the next target.
            grouped.append((current, get_kwargs(lst=batch)))
            current = edge[TARGET_LABEL]
            batch = [edge]
        else:
            batch.append(edge)
    grouped.append((current, get_kwargs(lst=batch)))
    return grouped
54
+
55
+
56
def _get_source(
    nodes_dict: dict, delayed_object_dict: dict, source: str, source_handle: str
):
    """Resolve an edge source to a delayed-object output or a raw node value.

    :param nodes_dict: node id -> static value or function.
    :param delayed_object_dict: node id -> already-created DelayedObject.
    :param source: id of the source node.
    :param source_handle: named output port, or None for the whole output.
    """
    if source in delayed_object_dict.keys() and source_handle is not None:
        # Use getattr() instead of calling __getattr__ directly: the builtin
        # follows the full attribute lookup chain (instance dict, class,
        # descriptors) before falling back to __getattr__.
        return getattr(getattr(delayed_object_dict[source], "output"), source_handle)
    elif source in delayed_object_dict.keys():
        return delayed_object_dict[source]
    else:
        return nodes_dict[source]
67
+
68
+
69
def _get_delayed_object_dict(
    total_lst: list, nodes_dict: dict, source_handle_dict: dict, pyiron_project: Project
) -> dict:
    """Create pyiron delayed objects for all grouped nodes, in given order.

    :param total_lst: topologically ordered (node id, kwargs edges) pairs.
    :param nodes_dict: node id -> function or static value.
    :param source_handle_dict: node id -> named output ports used downstream.
    :param pyiron_project: project the delayed jobs are attached to.
    :return: node id -> DelayedObject, in creation order.
    """
    delayed_object_dict = {}
    for item in total_lst:
        key, input_dict = item
        # Resolve every kwarg to a static value or an upstream delayed output.
        kwargs = {
            k: _get_source(
                nodes_dict=nodes_dict,
                delayed_object_dict=delayed_object_dict,
                source=v[SOURCE_LABEL],
                source_handle=v[SOURCE_PORT_LABEL],
            )
            for k, v in input_dict.items()
        }
        delayed_object_dict[key] = job(
            funct=nodes_dict[key],
            output_key_lst=source_handle_dict.get(key, []),
        )(**kwargs, pyiron_project=pyiron_project)
    return delayed_object_dict
89
+
90
+
91
def get_dict(**kwargs) -> dict:
    """Return a shallow copy of the dict passed via the ``kwargs`` keyword."""
    return dict(kwargs["kwargs"])
93
+
94
+
95
def get_list(**kwargs) -> list:
    """Return the values of the dict passed via the ``kwargs`` keyword, in order."""
    return [value for value in kwargs["kwargs"].values()]
97
+
98
+
99
+ def _remove_server_obj(nodes_dict: dict, edges_lst: list):
100
+ server_lst = [k for k in nodes_dict.keys() if k.startswith("server_obj_")]
101
+ for s in server_lst:
102
+ del nodes_dict[s]
103
+ edges_lst = [ep for ep in edges_lst if s not in ep]
104
+ return nodes_dict, edges_lst
105
+
106
+
107
def _get_nodes(connection_dict: dict, delayed_object_updated_dict: dict):
    """Re-key nodes by integer id, unwrapping DelayedObjects to their functions."""
    renumbered = {}
    for key, value in delayed_object_updated_dict.items():
        if isinstance(value, DelayedObject):
            renumbered[connection_dict[key]] = value._python_function
        else:
            renumbered[connection_dict[key]] = value
    return renumbered
112
+
113
+
114
def _get_unique_objects(nodes_dict: dict):
    """Deduplicate delayed objects in a pyiron node graph.

    Lists/dicts containing DelayedObjects are first wrapped in synthetic
    get_list/get_dict DelayedObjects; then objects with the same wrapped
    function and the same inputs are merged, recording the duplicates in
    ``match_dict`` (duplicate key -> canonical key). Plain values are carried
    over unchanged.

    :return: (key -> unique object/value, duplicate key -> canonical key)
    """
    delayed_object_dict = {}
    for k, v in nodes_dict.items():
        if isinstance(v, DelayedObject):
            delayed_object_dict[k] = v
        elif isinstance(v, list) and any([isinstance(el, DelayedObject) for el in v]):
            # Wrap lists of delayed objects in a synthetic get_list node.
            delayed_object_dict[k] = DelayedObject(function=get_list)
            delayed_object_dict[k]._input = {i: el for i, el in enumerate(v)}
            delayed_object_dict[k]._python_function = get_list
        elif isinstance(v, dict) and any(
            [isinstance(el, DelayedObject) for el in v.values()]
        ):
            # Wrap dicts of delayed objects in a synthetic get_dict node.
            delayed_object_dict[k] = DelayedObject(
                function=get_dict,
                **v,
            )
            delayed_object_dict[k]._python_function = get_dict
            delayed_object_dict[k]._input = v
    unique_lst = []
    delayed_object_updated_dict, match_dict = {}, {}
    for dobj in delayed_object_dict.keys():
        match = False
        for obj in unique_lst:
            # Two delayed objects are duplicates when both the wrapped
            # function and the full input mapping compare equal.
            if (
                delayed_object_updated_dict[obj]._python_function
                == delayed_object_dict[dobj]._python_function
                and delayed_object_dict[dobj]._input == delayed_object_dict[obj]._input
            ):
                delayed_object_updated_dict[obj] = delayed_object_dict[obj]
                match_dict[dobj] = obj
                match = True
                break
        if not match:
            unique_lst.append(dobj)
            delayed_object_updated_dict[dobj] = delayed_object_dict[dobj]
    # Carry over plain values (nothing delayed inside them) unchanged.
    update_dict = {}
    for k, v in nodes_dict.items():
        if not (
            isinstance(v, DelayedObject)
            or (
                isinstance(v, list) and any([isinstance(el, DelayedObject) for el in v])
            )
            or (
                isinstance(v, dict)
                and any([isinstance(el, DelayedObject) for el in v.values()])
            )
        ):
            update_dict[k] = v
    delayed_object_updated_dict.update(update_dict)
    return delayed_object_updated_dict, match_dict
164
+
165
+
166
+ def _get_connection_dict(delayed_object_updated_dict: dict, match_dict: dict):
167
+ new_obj_dict = {}
168
+ connection_dict = {}
169
+ lookup_dict = {}
170
+ for i, [k, v] in enumerate(delayed_object_updated_dict.items()):
171
+ new_obj_dict[i] = v
172
+ connection_dict[k] = i
173
+ lookup_dict[i] = k
174
+
175
+ for k, v in match_dict.items():
176
+ if v in connection_dict.keys():
177
+ connection_dict[k] = connection_dict[v]
178
+
179
+ return connection_dict, lookup_dict
180
+
181
+
182
def _get_edges_dict(
    edges_lst: list, nodes_dict: dict, connection_dict: dict, lookup_dict: dict
):
    """Translate pyiron graph edges into the shared edge schema.

    :param edges_lst: (input name, output name) pairs from the pyiron graph.
    :param nodes_dict: original node name -> object/value mapping.
    :param connection_dict: node name -> integer node id.
    :param lookup_dict: integer node id -> canonical node name.
    :return: list of edge dicts; duplicate (target, port) connections are
        emitted only once.
    """
    edges_dict_lst = []
    existing_connection_lst = []
    for ep in edges_lst:
        input_name, output_name = ep
        target = connection_dict[input_name]
        # The port name is the output name minus its trailing "_<suffix>".
        target_handle = "_".join(output_name.split("_")[:-1])
        connection_name = lookup_dict[target] + "_" + target_handle
        if connection_name not in existing_connection_lst:
            output = nodes_dict[output_name]
            if isinstance(output, DelayedObject):
                if output._list_index is not None:
                    # Output consumed through a list index -> synthetic port.
                    edges_dict_lst.append(
                        {
                            TARGET_LABEL: target,
                            TARGET_PORT_LABEL: target_handle,
                            SOURCE_LABEL: connection_dict[output_name],
                            SOURCE_PORT_LABEL: f"s_{output._list_index}",
                        }
                    )
                else:
                    # Named output key of the delayed object.
                    edges_dict_lst.append(
                        {
                            TARGET_LABEL: target,
                            TARGET_PORT_LABEL: target_handle,
                            SOURCE_LABEL: connection_dict[output_name],
                            SOURCE_PORT_LABEL: output._output_key,
                        }
                    )
            else:
                # Plain value source: anonymous port.
                edges_dict_lst.append(
                    {
                        TARGET_LABEL: target,
                        TARGET_PORT_LABEL: target_handle,
                        SOURCE_LABEL: connection_dict[output_name],
                        SOURCE_PORT_LABEL: None,
                    }
                )
            existing_connection_lst.append(connection_name)
    return edges_dict_lst
224
+
225
+
226
def load_workflow_json(file_name: str, project: Optional[Project] = None):
    """Reconstruct pyiron delayed objects from a workflow-definition JSON file.

    :param file_name: path of the workflow JSON file.
    :param project: pyiron Project; defaults to one in the current directory.
    :return: list of DelayedObject instances in topological order.
    """
    if project is None:
        project = Project(".")

    with open(file_name, "r") as f:
        content = json.load(f)

    edges_new_lst = content[EDGES_LABEL]
    nodes_new_dict = {}
    for k, v in convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL]).items():
        if isinstance(v, str) and "." in v:
            p, m = v.rsplit(".", 1)
            # Swap in the pyiron-specific get_dict/get_list variants, which
            # expect their inputs under the "kwargs" keyword.
            if p == "python_workflow_definition.shared":
                p = "python_workflow_definition.pyiron_base"
            mod = import_module(p)
            nodes_new_dict[int(k)] = getattr(mod, m)
        else:
            nodes_new_dict[int(k)] = v

    total_lst = _group_edges(edges_new_lst)
    total_new_lst = _resort_total_lst(total_lst=total_lst, nodes_dict=nodes_new_dict)
    source_handle_dict = get_source_handles(edges_new_lst)
    delayed_object_dict = _get_delayed_object_dict(
        total_lst=total_new_lst,
        nodes_dict=nodes_new_dict,
        source_handle_dict=source_handle_dict,
        pyiron_project=project,
    )
    return list(delayed_object_dict.values())
255
+
256
+
257
def write_workflow_json(
    delayed_object: DelayedObject, file_name: str = "workflow.json"
):
    """Serialize a pyiron delayed-object graph to the shared JSON format.

    :param delayed_object: the terminal DelayedObject whose graph is exported.
    :param file_name: destination path of the JSON file.
    """
    nodes_dict, edges_lst = delayed_object.get_graph()
    # Server objects are pyiron-internal and must not appear in the export.
    nodes_dict, edges_lst = _remove_server_obj(
        nodes_dict=nodes_dict, edges_lst=edges_lst
    )
    delayed_object_updated_dict, match_dict = _get_unique_objects(nodes_dict=nodes_dict)
    connection_dict, lookup_dict = _get_connection_dict(
        delayed_object_updated_dict=delayed_object_updated_dict, match_dict=match_dict
    )
    nodes_new_dict = _get_nodes(
        connection_dict=connection_dict,
        delayed_object_updated_dict=delayed_object_updated_dict,
    )
    edges_new_lst = _get_edges_dict(
        edges_lst=edges_lst,
        nodes_dict=nodes_dict,
        connection_dict=connection_dict,
        lookup_dict=lookup_dict,
    )

    nodes_store_lst = []
    for k, v in nodes_new_dict.items():
        if isfunction(v):
            mod = v.__module__
            # Export the pyiron-specific helpers under their shared names so
            # other backends can re-import them.
            if mod == "python_workflow_definition.pyiron_base":
                mod = "python_workflow_definition.shared"
            nodes_store_lst.append({"id": k, "function": mod + "." + v.__name__})
        elif isinstance(v, np.ndarray):
            # numpy arrays are not JSON serializable; store as nested lists.
            nodes_store_lst.append({"id": k, "value": v.tolist()})
        else:
            nodes_store_lst.append({"id": k, "value": v})

    with open(file_name, "w") as f:
        json.dump({NODES_LABEL: nodes_store_lst, EDGES_LABEL: edges_new_lst}, f)
@@ -0,0 +1,45 @@
1
# JSON keys of the shared workflow-definition file format.
NODES_LABEL = "nodes"
EDGES_LABEL = "edges"
# Edge fields: an edge transfers source[sourcePort] -> target[targetPort];
# a port of None denotes the whole (anonymous) result.
SOURCE_LABEL = "source"
SOURCE_PORT_LABEL = "sourcePort"
TARGET_LABEL = "target"
TARGET_PORT_LABEL = "targetPort"
7
+
8
+
9
def get_dict(**kwargs) -> dict:
    """Bundle the received keyword arguments into a plain dict.

    NOTE: In WG, this will automatically be wrapped in a dict with the
    ``result`` key.
    """
    # dict(kwargs) copies the keyword mapping directly; the previous
    # comprehension rebuilt it item by item.
    return dict(kwargs)
13
+
14
+
15
def get_list(**kwargs) -> list:
    """Collect the received keyword-argument values into a list, in order."""
    return [*kwargs.values()]
17
+
18
+
19
def get_kwargs(lst: list) -> dict:
    """Index a list of edges by target port.

    Each target port maps to the {source, sourcePort} pair it receives; when
    several edges share a target port, the last one wins.
    """
    kwargs = {}
    for edge in lst:
        kwargs[edge[TARGET_PORT_LABEL]] = {
            SOURCE_LABEL: edge[SOURCE_LABEL],
            SOURCE_PORT_LABEL: edge[SOURCE_PORT_LABEL],
        }
    return kwargs
27
+
28
+
29
def get_source_handles(edges_lst: list) -> dict:
    """Collect, per source node, the source ports used by its outgoing edges.

    When a source has several edges that all use the anonymous port (None),
    the ports are replaced by positional indices 0..n-1.
    """
    source_handle_dict = {}
    for ed in edges_lst:
        # setdefault replaces the original membership-check-then-append pair.
        source_handle_dict.setdefault(ed[SOURCE_LABEL], []).append(
            ed[SOURCE_PORT_LABEL]
        )
    return {
        k: list(range(len(v))) if len(v) > 1 and all(el is None for el in v) else v
        for k, v in source_handle_dict.items()
    }
39
+
40
+
41
def convert_nodes_list_to_dict(nodes_list: list) -> dict:
    """Turn a list of node records into {str(id): payload}, sorted by id.

    The payload is the record's "value" when present, otherwise its
    "function" entry.
    """
    result = {}
    for entry in sorted(nodes_list, key=lambda d: d["id"]):
        payload = entry["value"] if "value" in entry else entry["function"]
        result[str(entry["id"])] = payload
    return result
@@ -0,0 +1,42 @@
1
+ Metadata-Version: 2.4
2
+ Name: python_workflow_definition
3
+ Version: 0.0.1
4
+ Summary: Python Workflow Definition - workflow interoperability for aiida, jobflow and pyiron
5
+ Author-email: Jan Janssen <janssen@mpie.de>, Janine George <janine.geogre@bam.de>, Julian Geiger <julian.geiger@psi.ch>, Xing Wang <xing.wang@psi.ch>, Marnik Bercx <marnik.bercx@psi.ch>, Christina Ertural <christina.ertural@bam.de>
6
+ License: BSD 3-Clause License
7
+
8
+ Copyright (c) 2025, Jan Janssen
9
+ All rights reserved.
10
+
11
+ Redistribution and use in source and binary forms, with or without
12
+ modification, are permitted provided that the following conditions are met:
13
+
14
+ * Redistributions of source code must retain the above copyright notice, this
15
+ list of conditions and the following disclaimer.
16
+
17
+ * Redistributions in binary form must reproduce the above copyright notice,
18
+ this list of conditions and the following disclaimer in the documentation
19
+ and/or other materials provided with the distribution.
20
+
21
+ * Neither the name of the copyright holder nor the names of its
22
+ contributors may be used to endorse or promote products derived from
23
+ this software without specific prior written permission.
24
+
25
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
28
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
29
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
30
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
31
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
32
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
33
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
34
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
35
+ Requires-Dist: aiida-workgraph<=0.5.2,>=0.5.1
36
+ Requires-Dist: jobflow<=0.1.19,>=0.1.18
37
+ Requires-Dist: numpy<2,>=1.21
38
+ Requires-Dist: pyiron-base<=0.11.11,>=0.11.10
39
+ Provides-Extra: plot
40
+ Requires-Dist: ipython<=9.0.2,>=7.33.0; extra == 'plot'
41
+ Requires-Dist: networkx<=3.4.2,>=2.8.8; extra == 'plot'
42
+ Requires-Dist: pygraphviz<=1.14,>=1.10; extra == 'plot'
@@ -0,0 +1,11 @@
1
+ python_workflow_definition/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ python_workflow_definition/aiida.py,sha256=WvDE_2Bhv2IhqYdv-PZ3eGxHRw8k35PsSV4vk5qkOhg,5550
3
+ python_workflow_definition/executorlib.py,sha256=x6Nw01s3WsH7MGnw8-jvkdD5Yy_V4O0LTsDO_OKpUHs,2145
4
+ python_workflow_definition/jobflow.py,sha256=hkmwCl-xTOmHPd-9peawJolFrpwd6nvWwWu9K3L3Hxc,11809
5
+ python_workflow_definition/plot.py,sha256=L_FOSLp1kyNSkj3_owJpxIFe2raKCF0KBqRXa69xacE,1423
6
+ python_workflow_definition/purepython.py,sha256=YgJQaBP60GjOCAAhISf-Alc2DVrAs6U71hSBbcnSxlk,3154
7
+ python_workflow_definition/pyiron_base.py,sha256=ehFMKaZE2U5hhLx9KGwLkAZv1LqKsw7fsB0qZ-UfV3c,10562
8
+ python_workflow_definition/shared.py,sha256=ZjcxB0BdXspsMIjiZakqWBGvrwzxleM_7t4XsQ9topg,1320
9
+ python_workflow_definition-0.0.1.dist-info/METADATA,sha256=XBAKbtfMI54OXAPsfhlxCKA8RBXLsI5e0-ubBfsC49M,2491
10
+ python_workflow_definition-0.0.1.dist-info/WHEEL,sha256=tkmg4JIqwd9H8mL30xA7crRmoStyCtGp0VWshokd1Jc,105
11
+ python_workflow_definition-0.0.1.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.27.0
3
+ Root-Is-Purelib: true
4
+ Tag: py2-none-any
5
+ Tag: py3-none-any