python-workflow-definition 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,202 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # Ruff stuff:
171
+ .ruff_cache/
172
+
173
+ # PyPI configuration file
174
+ .pypirc
175
+
176
+ # Files of this package
177
+ .vscode/
178
+ input_tmp.in
179
+ pyiron.log
180
+ pyiron_draw.png
181
+ python_workflow_definition/src/python_workflow_definition/__pycache__/
182
+ test/
183
+ mini/
184
+ evcurve.png
185
+ strain_0/
186
+ strain_1/
187
+ strain_2/
188
+ strain_3/
189
+ strain_4/
190
+ aiida_to_pyiron_base_qe.json
191
+ pyiron_base_to_aiida_qe.json
192
+ jobflow_to_pyiron_base_simple.json
193
+ aiida_to_pyiron_base_simple.json
194
+ jobflow_to_aiida_simple.json
195
+ pyiron_base_to_jobflow_simple.json
196
+ aiida_to_jobflow_simple.json
197
+ jobflow_to_pyiron_base_qe.json
198
+ jobflow_to_aiida_qe.json
199
+ aiida_to_jobflow_qe.json
200
+ pyiron_base_to_aiida_simple.json
201
+ pyiron_base_to_jobflow_qe.json
202
+
@@ -0,0 +1,42 @@
1
+ Metadata-Version: 2.4
2
+ Name: python_workflow_definition
3
+ Version: 0.0.1
4
+ Summary: Python Workflow Definition - workflow interoperability for aiida, jobflow and pyiron
5
+ Author-email: Jan Janssen <janssen@mpie.de>, Janine George <janine.george@bam.de>, Julian Geiger <julian.geiger@psi.ch>, Xing Wang <xing.wang@psi.ch>, Marnik Bercx <marnik.bercx@psi.ch>, Christina Ertural <christina.ertural@bam.de>
6
+ License: BSD 3-Clause License
7
+
8
+ Copyright (c) 2025, Jan Janssen
9
+ All rights reserved.
10
+
11
+ Redistribution and use in source and binary forms, with or without
12
+ modification, are permitted provided that the following conditions are met:
13
+
14
+ * Redistributions of source code must retain the above copyright notice, this
15
+ list of conditions and the following disclaimer.
16
+
17
+ * Redistributions in binary form must reproduce the above copyright notice,
18
+ this list of conditions and the following disclaimer in the documentation
19
+ and/or other materials provided with the distribution.
20
+
21
+ * Neither the name of the copyright holder nor the names of its
22
+ contributors may be used to endorse or promote products derived from
23
+ this software without specific prior written permission.
24
+
25
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
28
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
29
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
30
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
31
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
32
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
33
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
34
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
35
+ Requires-Dist: aiida-workgraph<=0.5.2,>=0.5.1
36
+ Requires-Dist: jobflow<=0.1.19,>=0.1.18
37
+ Requires-Dist: numpy<2,>=1.21
38
+ Requires-Dist: pyiron-base<=0.11.11,>=0.11.10
39
+ Provides-Extra: plot
40
+ Requires-Dist: ipython<=9.0.2,>=7.33.0; extra == 'plot'
41
+ Requires-Dist: networkx<=3.4.2,>=2.8.8; extra == 'plot'
42
+ Requires-Dist: pygraphviz<=1.14,>=1.10; extra == 'plot'
@@ -0,0 +1,30 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "python_workflow_definition"
7
+ version = "0.0.1"
8
+ description = "Python Workflow Definition - workflow interoperability for aiida, jobflow and pyiron"
9
+ authors = [
10
+ { name = "Jan Janssen", email = "janssen@mpie.de" },
11
+ { name = "Janine George", email = "janine.george@bam.de" },
12
+ { name = "Julian Geiger", email = "julian.geiger@psi.ch" },
13
+ { name = "Xing Wang", email = "xing.wang@psi.ch" },
14
+ { name = "Marnik Bercx", email = "marnik.bercx@psi.ch" },
15
+ { name = "Christina Ertural", email = "christina.ertural@bam.de" },
16
+ ]
17
+ license = { file = "../LICENSE" }
18
+ dependencies = [
19
+ "aiida-workgraph>=0.5.1,<=0.5.2",
20
+ "numpy>=1.21,<2",
21
+ "jobflow>=0.1.18,<=0.1.19",
22
+ "pyiron_base>=0.11.10,<=0.11.11",
23
+ ]
24
+
25
+ [project.optional-dependencies]
26
+ plot = [
27
+ "pygraphviz>=1.10,<=1.14",
28
+ "networkx>=2.8.8,<=3.4.2",
29
+ "ipython>=7.33.0,<=9.0.2",
30
+ ]
@@ -0,0 +1,138 @@
1
+ from importlib import import_module
2
+ import json
3
+ import traceback
4
+
5
+ from aiida import orm
6
+ from aiida_pythonjob.data.serializer import general_serializer
7
+ from aiida_workgraph import WorkGraph, task
8
+ from aiida_workgraph.socket import TaskSocketNamespace
9
+
10
+ from python_workflow_definition.shared import (
11
+ convert_nodes_list_to_dict,
12
+ NODES_LABEL,
13
+ EDGES_LABEL,
14
+ SOURCE_LABEL,
15
+ SOURCE_PORT_LABEL,
16
+ TARGET_LABEL,
17
+ TARGET_PORT_LABEL,
18
+ )
19
+
20
+
21
def load_workflow_json(file_name: str) -> WorkGraph:
    """Build a ``WorkGraph`` from a workflow JSON file.

    Node values that are strings containing a dot are treated as dotted
    import paths: the attribute is imported and added to the graph as a task.
    Every other node value is converted with ``general_serializer`` and used
    as a data node. Edges are then turned into links (task -> task) or into
    socket values (data node -> task).

    Args:
        file_name: Path of the JSON file to read.

    Returns:
        The assembled ``WorkGraph``.
    """
    with open(file_name) as f:
        data = json.load(f)

    wg = WorkGraph()
    task_name_mapping = {}
    nodes = convert_nodes_list_to_dict(nodes_list=data[NODES_LABEL])

    for node_id, identifier in nodes.items():
        is_function_path = isinstance(identifier, str) and "." in identifier
        if not is_function_path:
            # Plain value: keep it as a data node.
            task_name_mapping[node_id] = general_serializer(identifier)
            continue
        module_path, attribute = identifier.rsplit(".", 1)
        wg.add_task(getattr(import_module(module_path), attribute))
        new_task = wg.tasks[-1]
        # Drop the default "result" output; the outputs are re-created below
        # from the source ports named in the edges.
        del new_task.outputs["result"]
        task_name_mapping[node_id] = new_task

    for link in data[EDGES_LABEL]:
        to_task = task_name_mapping[str(link[TARGET_LABEL])]
        target_port = link[TARGET_PORT_LABEL]
        if target_port in to_task.inputs:
            to_socket = to_task.inputs[target_port]
        else:
            # The input does not exist yet, i.e. the value is passed through
            # the kwargs of the function, so the input socket is added here.
            to_socket = to_task.add_input("workgraph.any", name=target_port)

        from_task = task_name_mapping[str(link[SOURCE_LABEL])]
        if isinstance(from_task, orm.Data):
            to_socket.value = from_task
            continue

        try:
            if link[SOURCE_PORT_LABEL] is None:
                link[SOURCE_PORT_LABEL] = "result"
            source_port = link[SOURCE_PORT_LABEL]
            if source_port in from_task.outputs:
                from_socket = from_task.outputs[source_port]
            else:
                # Outputs are not declared when the task is created, so the
                # output is added here on the assumption that it exists.
                from_socket = from_task.add_output(
                    "workgraph.any",
                    name=source_port,
                    metadata={"is_function_output": True},
                )
            wg.add_link(from_socket, to_socket)
        except Exception as e:
            # Best effort: report the failing link and continue with the rest.
            traceback.print_exc()
            print("Failed to link", link, "with error:", e)
    return wg
78
+
79
+
80
def write_workflow_json(wg: WorkGraph, file_name: str) -> dict:
    """Serialize a ``WorkGraph`` to a workflow JSON file.

    Tasks become function nodes, links become edges, and every ``orm.Data``
    value attached to a task input becomes a value node plus an edge feeding
    that input. A data node that is used several times (same uuid) is written
    only once.

    Args:
        wg: The work graph to export.
        file_name: Path of the JSON file to write.

    Returns:
        The dictionary that was written to ``file_name``.
    """
    data = {NODES_LABEL: [], EDGES_LABEL: []}
    node_name_mapping = {}
    data_node_name_mapping = {}
    next_id = 0

    # Function nodes: one per task, identified by "<module_path>.<callable_name>".
    for task_node in wg.tasks:
        executor = task_node.get_executor()
        node_name_mapping[task_node.name] = next_id
        function_path = f"{executor['module_path']}.{executor['callable_name']}"
        data[NODES_LABEL].append({"id": next_id, "function": function_path})
        next_id += 1

    # Task -> task edges.
    for link in wg.links:
        link_data = link.to_dict()
        link_data[TARGET_LABEL] = node_name_mapping[link_data.pop("to_node")]
        link_data[TARGET_PORT_LABEL] = link_data.pop("to_socket")
        link_data[SOURCE_LABEL] = node_name_mapping[link_data.pop("from_node")]
        from_socket = link_data.pop("from_socket")
        # The default "result" socket is stored as a missing source port.
        link_data[SOURCE_PORT_LABEL] = None if from_socket == "result" else from_socket
        data[EDGES_LABEL].append(link_data)

    # Data nodes attached directly to task inputs.
    for task_node in wg.tasks:
        for socket in task_node.inputs:
            # Namespaces are assumed not to be used as inputs.
            if isinstance(socket, TaskSocketNamespace):
                continue
            value = socket.value
            if not isinstance(value, orm.Data):
                continue
            if value.uuid in data_node_name_mapping:
                input_node_name = data_node_name_mapping[value.uuid]
            else:
                if isinstance(value, orm.List):
                    raw_value = value.get_list()
                elif isinstance(value, orm.Dict):
                    raw_value = value.get_dict()
                    # For an unknown reason the dict carries a "node_type" key.
                    raw_value.pop("node_type", None)
                else:
                    raw_value = value.value
                data[NODES_LABEL].append({"id": next_id, "value": raw_value})
                input_node_name = next_id
                data_node_name_mapping[value.uuid] = input_node_name
                next_id += 1
            data[EDGES_LABEL].append(
                {
                    TARGET_LABEL: node_name_mapping[task_node.name],
                    TARGET_PORT_LABEL: socket._name,
                    SOURCE_LABEL: input_node_name,
                    SOURCE_PORT_LABEL: None,
                }
            )

    with open(file_name, "w") as f:
        json.dump(data, f, indent=2)

    return data
@@ -0,0 +1,74 @@
1
+ from concurrent.futures import Executor
2
+ from importlib import import_module
3
+ from inspect import isfunction
4
+ import json
5
+
6
+
7
+ from python_workflow_definition.shared import (
8
+ get_dict,
9
+ get_list,
10
+ get_kwargs,
11
+ get_source_handles,
12
+ convert_nodes_list_to_dict,
13
+ NODES_LABEL,
14
+ EDGES_LABEL,
15
+ SOURCE_LABEL,
16
+ SOURCE_PORT_LABEL,
17
+ )
18
+ from python_workflow_definition.purepython import resort_total_lst, group_edges
19
+
20
+
21
def get_item(obj, key):
    """Return ``obj[key]``.

    Kept as a module-level function so it can be handed to an executor.
    """
    item = obj[key]
    return item
23
+
24
+
25
def _get_value(result_dict: dict, nodes_new_dict: dict, link_dict: dict, exe: Executor):
    """Resolve the value that an edge feeds into its target node.

    Args:
        result_dict: Results of nodes that were already submitted, keyed by
            node id.
        nodes_new_dict: All nodes of the workflow, keyed by node id.
        link_dict: Edge description holding the source node id and the
            source port.
        exe: Executor used to extract a single item from the source result.

    Returns:
        The source result itself when the edge has no source port, otherwise
        the return value of ``exe.submit`` for extracting
        ``result[source_port]``.

    Raises:
        KeyError: If the source node is neither a computed result nor a
            known node.
    """
    source, source_handle = link_dict[SOURCE_LABEL], link_dict[SOURCE_PORT_LABEL]
    if source in result_dict:
        result = result_dict[source]
    elif source in nodes_new_dict:
        result = nodes_new_dict[source]
    else:
        # Name the missing node id instead of raising an empty KeyError.
        raise KeyError(source)
    if source_handle is None:
        return result
    # The callable is passed positionally: ``concurrent.futures.Executor.submit``
    # declares ``fn`` as positional-only, so ``submit(fn=...)`` fails there.
    return exe.submit(get_item, obj=result, key=source_handle)
37
+
38
+
39
def load_workflow_json(file_name: str, exe: Executor):
    """Load a workflow JSON file and submit its function nodes to ``exe``.

    Node values that are strings containing a dot are imported as callables;
    all other node values are kept as plain input values. Function nodes are
    submitted in dependency order.

    Args:
        file_name: Path of the JSON file to read.
        exe: Executor that receives one ``submit`` call per function node.

    Returns:
        The object returned by ``exe.submit`` for the last submitted node.
    """
    with open(file_name, "r") as f:
        content = json.load(f)

    edges = content[EDGES_LABEL]
    nodes = {}
    raw_nodes = convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL])
    for key, value in raw_nodes.items():
        if isinstance(value, str) and "." in value:
            module_path, attribute = value.rsplit(".", 1)
            nodes[int(key)] = getattr(import_module(module_path), attribute)
        else:
            nodes[int(key)] = value

    ordered = resort_total_lst(total_lst=group_edges(edges), nodes_dict=nodes)

    result_dict = {}
    last_key = None
    for node_id, connections in ordered:
        node = nodes[node_id]
        if not isfunction(node):
            continue
        kwargs = {}
        for port, link in connections.items():
            kwargs[port] = _get_value(
                result_dict=result_dict,
                nodes_new_dict=nodes,
                link_dict=link,
                exe=exe,
            )
        result_dict[node_id] = exe.submit(node, **kwargs)
        last_key = node_id

    return result_dict[last_key]
@@ -0,0 +1,333 @@
1
+ import json
2
+ from importlib import import_module
3
+ from inspect import isfunction
4
+
5
+ import numpy as np
6
+ from jobflow import job, Flow
7
+
8
+ from python_workflow_definition.shared import (
9
+ get_dict,
10
+ get_list,
11
+ get_kwargs,
12
+ get_source_handles,
13
+ convert_nodes_list_to_dict,
14
+ NODES_LABEL,
15
+ EDGES_LABEL,
16
+ SOURCE_LABEL,
17
+ SOURCE_PORT_LABEL,
18
+ TARGET_LABEL,
19
+ TARGET_PORT_LABEL,
20
+ )
21
+
22
+
23
def _get_function_dict(flow: Flow):
    """Map the uuid of every job in ``flow.jobs`` to that job's function."""
    function_dict = {}
    for flow_job in flow.jobs:
        function_dict[flow_job.uuid] = flow_job.function
    return function_dict
25
+
26
+
27
+ def _get_nodes_dict(function_dict: dict):
28
+ nodes_dict, nodes_mapping_dict = {}, {}
29
+ for i, [k, v] in enumerate(function_dict.items()):
30
+ nodes_dict[i] = v
31
+ nodes_mapping_dict[k] = i
32
+
33
+ return nodes_dict, nodes_mapping_dict
34
+
35
+
36
def _get_edge_from_dict(
    target: str, key: str, value_dict: dict, nodes_mapping_dict: dict
) -> dict:
    """Build the edge for an input that refers to the output of another job.

    Args:
        target: Id of the node that receives the value.
        key: Name of the target port.
        value_dict: Reference dictionary; its ``"uuid"`` entry names the
            source job and its ``"attributes"`` entry selects the output.
        nodes_mapping_dict: Mapping from job uuid to node id.

    Returns:
        The edge dictionary. The source port is the second item of the single
        attribute entry, or ``None`` when there is not exactly one entry.
    """
    attributes = value_dict["attributes"]
    source_port = attributes[0][1] if len(attributes) == 1 else None
    return {
        TARGET_LABEL: target,
        TARGET_PORT_LABEL: key,
        SOURCE_LABEL: nodes_mapping_dict[value_dict["uuid"]],
        SOURCE_PORT_LABEL: source_port,
    }
53
+
54
+
55
def _is_output_reference(value) -> bool:
    """Return True if ``value`` is a dict carrying "@module", "@class" and "@version".

    NOTE(review): presumably this is the serialized form of a jobflow output
    reference; only the three marker keys are checked here.
    """
    return (
        isinstance(value, dict)
        and "@module" in value
        and "@class" in value
        and "@version" in value
    )


def _values_match(left, right) -> bool:
    """Return True if two node values are the same value of the same type."""
    # Requiring the same type keeps 1, 1.0 and True apart, so the value that
    # ends up in the JSON file is the one the job was actually called with.
    if type(left) is not type(right):
        return False
    if isinstance(left, np.ndarray):
        return bool(np.array_equal(left, right))
    try:
        return bool(left == right)
    except (TypeError, ValueError):
        # e.g. containers holding arrays whose comparison has no single truth value
        return False


def _get_or_add_value_node(nodes_dict: dict, value) -> int:
    """Return the id of the node holding ``value``, adding a new node if needed."""
    for node_index, existing in nodes_dict.items():
        if _values_match(existing, value):
            return node_index
    node_index = len(nodes_dict)
    nodes_dict[node_index] = value
    return node_index


def _add_collection_node(
    collector, items, nodes_dict: dict, nodes_mapping_dict: dict, edges_lst: list
) -> int:
    """Add a ``collector`` node plus one incoming edge per ``(port, element)`` item."""
    collection_index = len(nodes_dict)
    nodes_dict[collection_index] = collector
    for port, element in items:
        if _is_output_reference(element):
            edges_lst.append(
                _get_edge_from_dict(
                    target=collection_index,
                    key=port,
                    value_dict=element,
                    nodes_mapping_dict=nodes_mapping_dict,
                )
            )
        else:
            edges_lst.append(
                {
                    TARGET_LABEL: collection_index,
                    TARGET_PORT_LABEL: port,
                    SOURCE_LABEL: _get_or_add_value_node(nodes_dict, element),
                    SOURCE_PORT_LABEL: None,
                }
            )
    return collection_index


def _get_edges_and_extend_nodes(
    flow_dict: dict, nodes_mapping_dict: dict, nodes_dict: dict
):
    """Derive the edge list from the job kwargs and add the required value nodes.

    Each keyword argument of each job in ``flow_dict["jobs"]`` is handled as
    one of four cases:

    * a reference dict: a direct edge from the referenced job;
    * a dict with at least one reference among its values: a ``get_dict``
      node that collects the entries, plus an edge from it;
    * a list with at least one reference among its elements: a ``get_list``
      node that collects the elements, plus an edge from it;
    * anything else: a plain value node, reused when the same value already
      exists, plus an edge from it.

    Args:
        flow_dict: Dictionary form of the flow; each job provides ``"uuid"``
            and ``"function_kwargs"``.
        nodes_mapping_dict: Mapping from job uuid to node id.
        nodes_dict: Mapping from node id to node; extended in place.

    Returns:
        A tuple ``(edges_lst, nodes_dict)``.
    """
    edges_lst = []
    for job_dict in flow_dict["jobs"]:
        for k, v in job_dict["function_kwargs"].items():
            target = nodes_mapping_dict[job_dict["uuid"]]
            if _is_output_reference(v):
                edges_lst.append(
                    _get_edge_from_dict(
                        target=target,
                        key=k,
                        value_dict=v,
                        nodes_mapping_dict=nodes_mapping_dict,
                    )
                )
                continue
            if isinstance(v, dict) and any(
                _is_output_reference(el) for el in v.values()
            ):
                source = _add_collection_node(
                    get_dict, v.items(), nodes_dict, nodes_mapping_dict, edges_lst
                )
            elif isinstance(v, list) and any(_is_output_reference(el) for el in v):
                # List positions are used as port names. They are always
                # written as strings so that reference and plain elements
                # agree and the ports can be used as keyword names on load.
                source = _add_collection_node(
                    get_list,
                    ((str(position), el) for position, el in enumerate(v)),
                    nodes_dict,
                    nodes_mapping_dict,
                    edges_lst,
                )
            else:
                source = _get_or_add_value_node(nodes_dict, v)
            edges_lst.append(
                {
                    TARGET_LABEL: target,
                    TARGET_PORT_LABEL: k,
                    SOURCE_LABEL: source,
                    SOURCE_PORT_LABEL: None,
                }
            )
    return edges_lst, nodes_dict
190
+
191
+
192
def _resort_total_lst(total_dict: dict, nodes_dict: dict) -> dict:
    """Order the nodes that have inputs so that sources precede their targets.

    Args:
        total_dict: Mapping from target node id to its incoming connections
            (port name -> edge description containing the source node id).
        nodes_dict: All nodes of the workflow, keyed by node id.

    Returns:
        A dictionary with the entries of ``total_dict`` inserted in an order
        in which every node appears after all of its sources.

    Raises:
        ValueError: If no valid order exists, e.g. because the edges form a
            cycle or an edge names a source that is not a known node.
    """
    # Nodes without incoming edges are available from the start.
    nodes_without_dep = {k for k in nodes_dict if k not in total_dict}
    sorted_keys = sorted(total_dict)
    total_new_dict = {}
    while len(total_new_dict) < len(total_dict):
        progressed = False
        for ind in sorted_keys:
            if ind in total_new_dict:
                continue
            connect = total_dict[ind]
            if all(
                sd[SOURCE_LABEL] in total_new_dict
                or sd[SOURCE_LABEL] in nodes_without_dep
                for sd in connect.values()
            ):
                total_new_dict[ind] = connect
                progressed = True
        if not progressed:
            # A full pass without progress would repeat forever.
            unresolved = [ind for ind in sorted_keys if ind not in total_new_dict]
            raise ValueError(
                f"Unable to order workflow nodes {unresolved}: "
                "the edges are cyclic or refer to an unknown source node."
            )
    return total_new_dict
210
+
211
+
212
def _group_edges(edges_lst: list) -> dict:
    """Group the edges by target node.

    Returns:
        A dictionary mapping each target node id, in order of first
        appearance, to ``get_kwargs`` applied to the edges pointing at it.
    """
    edges_by_target = {}
    for edge in edges_lst:
        edges_by_target.setdefault(edge[TARGET_LABEL], []).append(edge)
    return {
        target_id: get_kwargs(lst=target_edges)
        for target_id, target_edges in edges_by_target.items()
    }
223
+
224
+
225
+ def _get_input_dict(nodes_dict: dict) -> dict:
226
+ return {k: v for k, v in nodes_dict.items() if not isfunction(v)}
227
+
228
+
229
def _get_workflow(
    nodes_dict: dict, input_dict: dict, total_dict: dict, source_handles_dict: dict
) -> list:
    """Create one job per function node, wired to its inputs.

    Args:
        nodes_dict: All nodes of the workflow, keyed by node id.
        input_dict: The plain value nodes, keyed by node id.
        total_dict: Incoming connections per target node id; iterated in the
            given order, so sources must come before their targets.
        source_handles_dict: Per node id, the source ports other nodes read
            from it; the non-``None`` ones are passed to ``job`` as ``data``.

    Returns:
        The created jobs in creation order.
    """

    def resolve_output(upstream_job, handle):
        # Without a handle the whole output is forwarded, otherwise one attribute of it.
        output = upstream_job.output
        return output if handle is None else getattr(output, handle)

    memory_dict = {}
    for node_id, connections in total_dict.items():
        node = nodes_dict[node_id]
        if not isfunction(node):
            continue
        if node_id in source_handles_dict:
            handles = [h for h in source_handles_dict[node_id] if h is not None]
            job_factory = job(method=node, data=handles)
        else:
            job_factory = job(method=node)
        kwargs = {}
        for name, link in connections.items():
            source = link[SOURCE_LABEL]
            if source in input_dict:
                kwargs[name] = input_dict[source]
            else:
                kwargs[name] = resolve_output(
                    memory_dict[source], link[SOURCE_PORT_LABEL]
                )
        memory_dict[node_id] = job_factory(**kwargs)
    return list(memory_dict.values())
262
+
263
+
264
+ def _get_item_from_tuple(input_obj, index, index_lst):
265
+ if isinstance(input_obj, dict):
266
+ return input_obj[index]
267
+ else:
268
+ return list(input_obj)[index_lst.index(index)]
269
+
270
+
271
def load_workflow_json(file_name: str) -> Flow:
    """Build a jobflow ``Flow`` from a workflow JSON file.

    Source ports that are not ``None`` are converted to strings. Node values
    that are strings containing a dot are imported as callables; all other
    node values are kept as plain input values.

    Args:
        file_name: Path of the JSON file to read.

    Returns:
        A ``Flow`` containing one job per function node that has inputs.
    """
    with open(file_name, "r") as f:
        content = json.load(f)

    edges_new_lst = []
    for edge in content[EDGES_LABEL]:
        if edge[SOURCE_PORT_LABEL] is not None:
            edge = {
                TARGET_LABEL: edge[TARGET_LABEL],
                TARGET_PORT_LABEL: edge[TARGET_PORT_LABEL],
                SOURCE_LABEL: edge[SOURCE_LABEL],
                SOURCE_PORT_LABEL: str(edge[SOURCE_PORT_LABEL]),
            }
        edges_new_lst.append(edge)

    nodes_new_dict = {}
    raw_nodes = convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL])
    for key, value in raw_nodes.items():
        if isinstance(value, str) and "." in value:
            module_path, attribute = value.rsplit(".", 1)
            nodes_new_dict[int(key)] = getattr(import_module(module_path), attribute)
        else:
            nodes_new_dict[int(key)] = value

    source_handles_dict = get_source_handles(edges_lst=edges_new_lst)
    total_dict = _group_edges(edges_lst=edges_new_lst)
    input_dict = _get_input_dict(nodes_dict=nodes_new_dict)
    new_total_dict = _resort_total_lst(total_dict=total_dict, nodes_dict=nodes_new_dict)
    return Flow(
        _get_workflow(
            nodes_dict=nodes_new_dict,
            input_dict=input_dict,
            total_dict=new_total_dict,
            source_handles_dict=source_handles_dict,
        )
    )
309
+
310
+
311
def write_workflow_json(flow: Flow, file_name: str = "workflow.json"):
    """Serialize a jobflow ``Flow`` to a workflow JSON file.

    Function nodes are stored as ``"<module>.<name>"``, NumPy arrays as
    nested lists, and every other node value as it is.

    Args:
        flow: The flow to export.
        file_name: Path of the JSON file to write.
    """
    flow_dict = flow.as_dict()
    nodes_dict, nodes_mapping_dict = _get_nodes_dict(
        function_dict=_get_function_dict(flow=flow)
    )
    edges_lst, nodes_dict = _get_edges_and_extend_nodes(
        flow_dict=flow_dict,
        nodes_mapping_dict=nodes_mapping_dict,
        nodes_dict=nodes_dict,
    )

    def to_record(node_id, node):
        # One JSON record per node, depending on what kind of node it is.
        if isfunction(node):
            return {"id": node_id, "function": node.__module__ + "." + node.__name__}
        if isinstance(node, np.ndarray):
            return {"id": node_id, "value": node.tolist()}
        return {"id": node_id, "value": node}

    nodes_store_lst = [to_record(node_id, node) for node_id, node in nodes_dict.items()]

    with open(file_name, "w") as f:
        json.dump({NODES_LABEL: nodes_store_lst, EDGES_LABEL: edges_lst}, f)
@@ -0,0 +1,45 @@
1
+ import json
2
+
3
+ from IPython.display import SVG, display
4
+ import networkx as nx
5
+
6
+
7
+ from python_workflow_definition.purepython import group_edges
8
+ from python_workflow_definition.shared import (
9
+ get_kwargs,
10
+ convert_nodes_list_to_dict,
11
+ NODES_LABEL,
12
+ EDGES_LABEL,
13
+ SOURCE_LABEL,
14
+ SOURCE_PORT_LABEL,
15
+ )
16
+
17
+
18
def plot(file_name: str):
    """Render the workflow stored in a JSON file as an SVG graph.

    Every node of the file becomes a graph node labelled with its value.
    All edges between the same pair of nodes are merged into one graph edge
    whose label lists the target ports, with ``=result[<source port>]``
    appended for edges that read a specific source port.

    Args:
        file_name: Path of the workflow JSON file.
    """
    with open(file_name, "r") as f:
        content = json.load(f)

    graph = nx.DiGraph()
    node_dict = convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL])
    total_lst = group_edges(edges_lst=content[EDGES_LABEL])

    for node_id, node_name in node_dict.items():
        # Edges below address nodes by str(id), so nodes use the same key form.
        graph.add_node(str(node_id), name=str(node_name), label=str(node_name))

    for target_node, edge_dict in total_lst:
        edge_label_dict = {}
        for port, link in edge_dict.items():
            source_port = link[SOURCE_PORT_LABEL]
            # Formatting instead of "+" concatenation: ports may be
            # non-string values (e.g. integers), which "+" and join() reject.
            if source_port is None:
                label = str(port)
            else:
                label = f"{port}=result[{source_port}]"
            edge_label_dict.setdefault(link[SOURCE_LABEL], []).append(label)
        for source_node, labels in edge_label_dict.items():
            graph.add_edge(str(source_node), str(target_node), label=", ".join(labels))

    svg = nx.nx_agraph.to_agraph(graph).draw(prog="dot", format="svg")
    display(SVG(svg))
@@ -0,0 +1,99 @@
1
+ import json
2
+ from importlib import import_module
3
+ from inspect import isfunction
4
+
5
+
6
+ from python_workflow_definition.shared import (
7
+ get_dict,
8
+ get_list,
9
+ get_kwargs,
10
+ get_source_handles,
11
+ convert_nodes_list_to_dict,
12
+ NODES_LABEL,
13
+ EDGES_LABEL,
14
+ SOURCE_LABEL,
15
+ SOURCE_PORT_LABEL,
16
+ TARGET_LABEL,
17
+ TARGET_PORT_LABEL,
18
+ )
19
+
20
+
21
def resort_total_lst(total_lst: list, nodes_dict: dict) -> list:
    """Order the nodes that have inputs so that sources precede their targets.

    Args:
        total_lst: Entries ``(target node id, connections)`` where
            ``connections`` maps a port name to an edge description
            containing the source node id.
        nodes_dict: All nodes of the workflow, keyed by node id.

    Returns:
        A list of ``[node id, connections]`` entries in which every node
        appears after all of its sources.

    Raises:
        ValueError: If no valid order exists, e.g. because the edges form a
            cycle, an edge names a source that is not a known node, or a
            target id occurs more than once in ``total_lst``.
    """
    nodes_with_dep = {entry[0] for entry in total_lst}
    # Nodes without incoming edges are available from the start.
    nodes_without_dep = {k for k in nodes_dict if k not in nodes_with_dep}
    ordered, total_new_lst = set(), []
    while len(total_new_lst) < len(total_lst):
        progressed = False
        for ind, connect in total_lst:
            if ind in ordered:
                continue
            if all(
                sd[SOURCE_LABEL] in ordered or sd[SOURCE_LABEL] in nodes_without_dep
                for sd in connect.values()
            ):
                ordered.add(ind)
                total_new_lst.append([ind, connect])
                progressed = True
        if not progressed:
            # A full pass without progress would repeat forever.
            unresolved = [ind for ind, _ in total_lst if ind not in ordered]
            raise ValueError(
                f"Unable to order workflow nodes (unresolved: {unresolved}): "
                "the edges are cyclic, refer to an unknown source node, "
                "or a target id is listed more than once."
            )
    return total_new_lst
37
+
38
+
39
def group_edges(edges_lst: list) -> list:
    """Group the edges by target node.

    Args:
        edges_lst: Edge descriptions, each containing a target node id.

    Returns:
        A list of ``(target node id, kwargs)`` tuples ordered by descending
        target id, where ``kwargs`` is ``get_kwargs`` applied to the edges
        pointing at that target. An empty edge list yields an empty list.
    """
    # The sort is stable, so edges of one target keep their relative order.
    edges_sorted_lst = sorted(edges_lst, key=lambda x: x[TARGET_LABEL], reverse=True)
    edges_by_target = {}
    for ed in edges_sorted_lst:
        edges_by_target.setdefault(ed[TARGET_LABEL], []).append(ed)
    return [
        (target_id, get_kwargs(lst=target_edges))
        for target_id, target_edges in edges_by_target.items()
    ]
52
+
53
+
54
def _get_value(result_dict: dict, nodes_new_dict: dict, link_dict: dict):
    """Resolve the value that an edge feeds into its target node.

    The source is looked up among the computed results first and among the
    workflow nodes second. Without a source port the value is returned as it
    is, otherwise it is indexed with the source port.

    Raises:
        KeyError: If the source is found in neither dictionary.
    """
    source = link_dict[SOURCE_LABEL]
    source_handle = link_dict[SOURCE_PORT_LABEL]
    for lookup in (result_dict, nodes_new_dict):
        if source in lookup:
            result = lookup[source]
            break
    else:
        raise KeyError()
    return result if source_handle is None else result[source_handle]
66
+
67
+
68
def load_workflow_json(file_name: str):
    """Load a workflow JSON file and execute it directly in this process.

    Node values that are strings containing a dot are imported as callables;
    all other node values are kept as plain input values. Function nodes are
    called in dependency order with the values their incoming edges provide.

    Args:
        file_name: Path of the JSON file to read.

    Returns:
        The return value of the last function node that was called.
    """
    with open(file_name, "r") as f:
        content = json.load(f)

    edges = content[EDGES_LABEL]
    nodes = {}
    raw_nodes = convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL])
    for key, value in raw_nodes.items():
        if isinstance(value, str) and "." in value:
            module_path, attribute = value.rsplit(".", 1)
            nodes[int(key)] = getattr(import_module(module_path), attribute)
        else:
            nodes[int(key)] = value

    ordered = resort_total_lst(total_lst=group_edges(edges), nodes_dict=nodes)

    result_dict = {}
    last_key = None
    for node_id, connections in ordered:
        node = nodes[node_id]
        if not isfunction(node):
            continue
        kwargs = {}
        for port, link in connections.items():
            kwargs[port] = _get_value(
                result_dict=result_dict, nodes_new_dict=nodes, link_dict=link
            )
        result_dict[node_id] = node(**kwargs)
        last_key = node_id

    return result_dict[last_key]
@@ -0,0 +1,292 @@
1
+ from importlib import import_module
2
+ from inspect import isfunction
3
+ import json
4
+ from typing import Optional
5
+
6
+ import numpy as np
7
+ from pyiron_base import job, Project
8
+ from pyiron_base.project.delayed import DelayedObject
9
+
10
+ from python_workflow_definition.shared import (
11
+ get_kwargs,
12
+ get_source_handles,
13
+ convert_nodes_list_to_dict,
14
+ NODES_LABEL,
15
+ EDGES_LABEL,
16
+ SOURCE_LABEL,
17
+ SOURCE_PORT_LABEL,
18
+ TARGET_LABEL,
19
+ TARGET_PORT_LABEL,
20
+ )
21
+
22
+
23
def _resort_total_lst(total_lst: list, nodes_dict: dict) -> list:
    """Order the grouped edges so every node appears after its sources.

    Args:
        total_lst: ``(node_id, connections)`` pairs, where ``connections``
            maps an input name to a dict containing ``SOURCE_LABEL``.
        nodes_dict: All nodes of the workflow, keyed by node id.

    Returns:
        A list of ``[node_id, connections]`` entries in an order in which each
        node's sources are either nodes without incoming edges or nodes that
        appear earlier in the list.

    Raises:
        ValueError: If the remaining nodes can never be resolved (cyclic
            dependency, unknown source, or duplicated node id). The previous
            implementation looped forever in this situation.
    """
    nodes_with_dep = {entry[0] for entry in total_lst}
    nodes_without_dep = {k for k in nodes_dict if k not in nodes_with_dep}
    resolved = set()
    total_new_lst = []
    while len(total_new_lst) < len(total_lst):
        progressed = False
        for ind, connect in total_lst:
            if ind in resolved:
                continue
            if all(
                sd[SOURCE_LABEL] in resolved or sd[SOURCE_LABEL] in nodes_without_dep
                for sd in connect.values()
            ):
                resolved.add(ind)
                total_new_lst.append([ind, connect])
                progressed = True
        if not progressed:
            # A full pass without any newly resolved node means no further
            # pass can succeed either.
            unresolved = [ind for ind, _ in total_lst if ind not in resolved]
            raise ValueError(
                f"cannot order workflow nodes, unresolvable dependencies for: {unresolved}"
            )
    return total_new_lst
39
+
40
+
41
def _group_edges(edges_lst: list) -> list:
    """Group edges by their target node.

    Args:
        edges_lst: Edge dictionaries; each must contain ``TARGET_LABEL`` plus
            the keys consumed by ``get_kwargs``.

    Returns:
        A list of ``(target_id, kwargs)`` tuples, ordered by descending target
        id, where ``kwargs`` is ``get_kwargs`` applied to all edges pointing
        at that target. An empty edge list yields an empty list (the previous
        implementation raised ``IndexError`` on ``edges_sorted_lst[0]``).
    """
    edges_sorted_lst = sorted(edges_lst, key=lambda x: x[TARGET_LABEL], reverse=True)
    # dicts preserve insertion order, so the groups come out in the same
    # (descending target) order the sorted edges are visited in.
    grouped = {}
    for ed in edges_sorted_lst:
        grouped.setdefault(ed[TARGET_LABEL], []).append(ed)
    return [
        (target_id, get_kwargs(lst=target_edges))
        for target_id, target_edges in grouped.items()
    ]
54
+
55
+
56
+ def _get_source(
57
+ nodes_dict: dict, delayed_object_dict: dict, source: str, source_handle: str
58
+ ):
59
+ if source in delayed_object_dict.keys() and source_handle is not None:
60
+ return (
61
+ delayed_object_dict[source].__getattr__("output").__getattr__(source_handle)
62
+ )
63
+ elif source in delayed_object_dict.keys():
64
+ return delayed_object_dict[source]
65
+ else:
66
+ return nodes_dict[source]
67
+
68
+
69
def _get_delayed_object_dict(
    total_lst: list, nodes_dict: dict, source_handle_dict: dict, pyiron_project: Project
) -> dict:
    """Create one delayed call per node that has incoming edges.

    Args:
        total_lst: ``(node_id, connections)`` entries in dependency order.
        nodes_dict: All nodes of the workflow, keyed by node id.
        source_handle_dict: Output keys per node id; nodes without an entry
            get an empty list.
        pyiron_project: Project forwarded to every wrapped call.

    Returns:
        Mapping of node id to the object returned by calling the
        ``job``-wrapped node function with its resolved inputs.
    """
    delayed_object_dict = {}
    for key, input_dict in total_lst:
        # Resolve every input against the nodes and the delayed objects
        # created so far; total_lst is expected to be dependency-ordered.
        call_kwargs = {}
        for name, link in input_dict.items():
            call_kwargs[name] = _get_source(
                nodes_dict=nodes_dict,
                delayed_object_dict=delayed_object_dict,
                source=link[SOURCE_LABEL],
                source_handle=link[SOURCE_PORT_LABEL],
            )
        wrapped = job(
            funct=nodes_dict[key],
            output_key_lst=source_handle_dict.get(key, []),
        )
        delayed_object_dict[key] = wrapped(**call_kwargs, pyiron_project=pyiron_project)
    return delayed_object_dict
89
+
90
+
91
def get_dict(**kwargs) -> dict:
    """Return a shallow copy of the mapping passed as the ``kwargs`` keyword.

    Raises:
        KeyError: If no keyword argument named ``kwargs`` is supplied.
    """
    payload = kwargs["kwargs"]
    return dict(payload.items())
93
+
94
+
95
def get_list(**kwargs) -> list:
    """Return the values of the mapping passed as the ``kwargs`` keyword.

    Raises:
        KeyError: If no keyword argument named ``kwargs`` is supplied.
    """
    payload = kwargs["kwargs"]
    return [*payload.values()]
97
+
98
+
99
+ def _remove_server_obj(nodes_dict: dict, edges_lst: list):
100
+ server_lst = [k for k in nodes_dict.keys() if k.startswith("server_obj_")]
101
+ for s in server_lst:
102
+ del nodes_dict[s]
103
+ edges_lst = [ep for ep in edges_lst if s not in ep]
104
+ return nodes_dict, edges_lst
105
+
106
+
107
def _get_nodes(connection_dict: dict, delayed_object_updated_dict: dict):
    """Map numeric node ids to the node content that gets serialized.

    Args:
        connection_dict: Maps an object name to its numeric node id.
        delayed_object_updated_dict: Objects keyed by name.

    Returns:
        Dict of node id to either the ``_python_function`` of a
        ``DelayedObject`` or, for any other object, the object itself.
    """
    nodes = {}
    for name, obj in delayed_object_updated_dict.items():
        node_id = connection_dict[name]
        if isinstance(obj, DelayedObject):
            nodes[node_id] = obj._python_function
        else:
            nodes[node_id] = obj
    return nodes
112
+
113
+
114
def _get_unique_objects(nodes_dict: dict):
    """Deduplicate delayed objects and collect the remaining plain nodes.

    Lists and dicts that contain at least one ``DelayedObject`` are wrapped
    in a synthetic ``DelayedObject`` calling ``get_list`` / ``get_dict``.
    Two delayed objects count as duplicates when both their
    ``_python_function`` and their ``_input`` compare equal.

    Returns:
        A tuple ``(delayed_object_updated_dict, match_dict)``: the first maps
        names to the unique delayed objects followed by all non-delayed
        nodes; the second maps the name of each duplicate to the name of the
        unique object it matched.
    """
    delayed_object_dict = {}
    for name, node in nodes_dict.items():
        if isinstance(node, DelayedObject):
            delayed_object_dict[name] = node
        elif isinstance(node, list) and any(
            isinstance(el, DelayedObject) for el in node
        ):
            wrapper = DelayedObject(function=get_list)
            delayed_object_dict[name] = wrapper
            wrapper._input = dict(enumerate(node))
            wrapper._python_function = get_list
        elif isinstance(node, dict) and any(
            isinstance(el, DelayedObject) for el in node.values()
        ):
            wrapper = DelayedObject(function=get_dict, **node)
            delayed_object_dict[name] = wrapper
            wrapper._python_function = get_dict
            wrapper._input = node

    unique_names = []
    delayed_object_updated_dict, match_dict = {}, {}
    for name, candidate in delayed_object_dict.items():
        for known in unique_names:
            if (
                delayed_object_updated_dict[known]._python_function
                == candidate._python_function
                and candidate._input == delayed_object_dict[known]._input
            ):
                match_dict[name] = known
                break
        else:
            # No equal object registered so far: this one is unique.
            unique_names.append(name)
            delayed_object_updated_dict[name] = candidate

    # Everything that was not turned into a delayed object above is a plain
    # node; it is appended after the unique delayed objects.
    plain_nodes = {
        name: node
        for name, node in nodes_dict.items()
        if name not in delayed_object_dict
    }
    delayed_object_updated_dict.update(plain_nodes)
    return delayed_object_updated_dict, match_dict
164
+
165
+
166
+ def _get_connection_dict(delayed_object_updated_dict: dict, match_dict: dict):
167
+ new_obj_dict = {}
168
+ connection_dict = {}
169
+ lookup_dict = {}
170
+ for i, [k, v] in enumerate(delayed_object_updated_dict.items()):
171
+ new_obj_dict[i] = v
172
+ connection_dict[k] = i
173
+ lookup_dict[i] = k
174
+
175
+ for k, v in match_dict.items():
176
+ if v in connection_dict.keys():
177
+ connection_dict[k] = connection_dict[v]
178
+
179
+ return connection_dict, lookup_dict
180
+
181
+
182
def _get_edges_dict(
    edges_lst: list, nodes_dict: dict, connection_dict: dict, lookup_dict: dict
):
    """Translate name-based edge pairs into id-based edge dictionaries.

    Args:
        edges_lst: ``(input_name, output_name)`` pairs.
        nodes_dict: Original objects keyed by name.
        connection_dict: Maps object names to numeric node ids.
        lookup_dict: Maps numeric node ids back to the unique object name.

    Returns:
        A list of dicts with the keys ``TARGET_LABEL``, ``TARGET_PORT_LABEL``,
        ``SOURCE_LABEL`` and ``SOURCE_PORT_LABEL``. Only the first edge for
        each (target, target port) combination is kept.
    """
    edges_dict_lst = []
    seen_connections = set()
    for input_name, output_name in edges_lst:
        target = connection_dict[input_name]
        # The target port is the output name without its last "_<suffix>".
        target_handle = output_name.rpartition("_")[0]
        connection_name = lookup_dict[target] + "_" + target_handle
        if connection_name in seen_connections:
            continue

        output = nodes_dict[output_name]
        if not isinstance(output, DelayedObject):
            source_port = None
        elif output._list_index is not None:
            # The source is addressed by its list index.
            source_port = f"s_{output._list_index}"
        else:
            source_port = output._output_key

        edges_dict_lst.append(
            {
                TARGET_LABEL: target,
                TARGET_PORT_LABEL: target_handle,
                SOURCE_LABEL: connection_dict[output_name],
                SOURCE_PORT_LABEL: source_port,
            }
        )
        seen_connections.add(connection_name)
    return edges_dict_lst
224
+
225
+
226
def load_workflow_json(file_name: str, project: Optional[Project] = None):
    """Load a JSON workflow and build the corresponding delayed objects.

    Args:
        file_name: Path of the JSON workflow file containing the
            ``NODES_LABEL`` and ``EDGES_LABEL`` entries.
        project: Project the delayed objects are created in; defaults to
            ``Project(".")``.

    Returns:
        List of the delayed objects, one per node that has incoming edges,
        in dependency order.
    """
    if project is None:
        project = Project(".")

    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(file_name, "r", encoding="utf-8") as f:
        content = json.load(f)

    edges_new_lst = content[EDGES_LABEL]
    nodes_new_dict = {}
    for k, v in convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL]).items():
        # NOTE(review): every string containing a "." is treated as a dotted
        # import path, so a plain string value such as "data.txt" ends up in
        # import_module() as well - confirm this is intended.
        if isinstance(v, str) and "." in v:
            # SECURITY: the module/attribute names come straight from the
            # file and the resolved callables are wrapped for execution.
            # Only load workflow files from trusted sources.
            p, m = v.rsplit(".", 1)
            # Functions stored under the shared module are served by this
            # module's own implementations of the same name.
            if p == "python_workflow_definition.shared":
                p = "python_workflow_definition.pyiron_base"
            mod = import_module(p)
            nodes_new_dict[int(k)] = getattr(mod, m)
        else:
            nodes_new_dict[int(k)] = v

    total_lst = _group_edges(edges_new_lst)
    total_new_lst = _resort_total_lst(total_lst=total_lst, nodes_dict=nodes_new_dict)
    source_handle_dict = get_source_handles(edges_new_lst)
    delayed_object_dict = _get_delayed_object_dict(
        total_lst=total_new_lst,
        nodes_dict=nodes_new_dict,
        source_handle_dict=source_handle_dict,
        pyiron_project=project,
    )
    return list(delayed_object_dict.values())
255
+
256
+
257
def write_workflow_json(
    delayed_object: DelayedObject, file_name: str = "workflow.json"
):
    """Serialize the graph behind ``delayed_object`` to a JSON workflow file.

    Args:
        delayed_object: Object whose ``get_graph()`` provides the nodes and
            edges to export.
        file_name: Path of the JSON file to write.

    Raises:
        TypeError: If a node value cannot be serialized to JSON. The target
            file is left untouched in that case.
    """
    nodes_dict, edges_lst = delayed_object.get_graph()
    nodes_dict, edges_lst = _remove_server_obj(
        nodes_dict=nodes_dict, edges_lst=edges_lst
    )
    delayed_object_updated_dict, match_dict = _get_unique_objects(nodes_dict=nodes_dict)
    connection_dict, lookup_dict = _get_connection_dict(
        delayed_object_updated_dict=delayed_object_updated_dict, match_dict=match_dict
    )
    nodes_new_dict = _get_nodes(
        connection_dict=connection_dict,
        delayed_object_updated_dict=delayed_object_updated_dict,
    )
    edges_new_lst = _get_edges_dict(
        edges_lst=edges_lst,
        nodes_dict=nodes_dict,
        connection_dict=connection_dict,
        lookup_dict=lookup_dict,
    )

    nodes_store_lst = []
    for k, v in nodes_new_dict.items():
        if isfunction(v):
            mod = v.__module__
            # Helpers defined in this module are stored under the shared
            # module name; load_workflow_json maps them back on import.
            if mod == "python_workflow_definition.pyiron_base":
                mod = "python_workflow_definition.shared"
            nodes_store_lst.append({"id": k, "function": mod + "." + v.__name__})
        elif isinstance(v, np.ndarray):
            # numpy arrays are not JSON serializable; store nested lists.
            nodes_store_lst.append({"id": k, "value": v.tolist()})
        else:
            nodes_store_lst.append({"id": k, "value": v})

    # Serialize before opening the file: a failure inside json.dump() would
    # otherwise leave an already truncated, partially written file behind.
    serialized = json.dumps({NODES_LABEL: nodes_store_lst, EDGES_LABEL: edges_new_lst})
    with open(file_name, "w", encoding="utf-8") as f:
        f.write(serialized)
@@ -0,0 +1,45 @@
1
# Top-level keys of the workflow JSON document.
NODES_LABEL, EDGES_LABEL = "nodes", "edges"
# Keys of a single edge entry: the source node and its output port ...
SOURCE_LABEL, SOURCE_PORT_LABEL = "source", "sourcePort"
# ... and the target node and its input port.
TARGET_LABEL, TARGET_PORT_LABEL = "target", "targetPort"
7
+
8
+
9
def get_dict(**kwargs) -> dict:
    """Return the received keyword arguments as a new dictionary."""
    # NOTE: In WG, this will automatically be wrapped in a dict with the `result` key
    return dict(kwargs)
13
+
14
+
15
def get_list(**kwargs) -> list:
    """Return the values of the received keyword arguments, in call order."""
    return [*kwargs.values()]
17
+
18
+
19
def get_kwargs(lst: list) -> dict:
    """Index a list of edges by their target port.

    Args:
        lst: Edge dictionaries containing ``TARGET_PORT_LABEL``,
            ``SOURCE_LABEL`` and ``SOURCE_PORT_LABEL``.

    Returns:
        Dict mapping each target port to a dict holding the edge's source
        and source port. If several edges share a target port, the last one
        wins.
    """
    kwargs_by_port = {}
    for edge in lst:
        kwargs_by_port[edge[TARGET_PORT_LABEL]] = {
            SOURCE_LABEL: edge[SOURCE_LABEL],
            SOURCE_PORT_LABEL: edge[SOURCE_PORT_LABEL],
        }
    return kwargs_by_port
27
+
28
+
29
def get_source_handles(edges_lst: list) -> dict:
    """Collect the source ports used by each source node.

    Args:
        edges_lst: Edge dictionaries containing ``SOURCE_LABEL`` and
            ``SOURCE_PORT_LABEL``.

    Returns:
        Dict mapping each source node to the list of its source ports in
        edge order. When a node has more than one edge and every port is
        ``None``, the list is replaced by ``[0, 1, ..., n - 1]``.
    """
    ports_by_source = {}
    for edge in edges_lst:
        ports_by_source.setdefault(edge[SOURCE_LABEL], []).append(
            edge[SOURCE_PORT_LABEL]
        )

    handles = {}
    for source, ports in ports_by_source.items():
        if len(ports) > 1 and all(port is None for port in ports):
            handles[source] = list(range(len(ports)))
        else:
            handles[source] = ports
    return handles
39
+
40
+
41
def convert_nodes_list_to_dict(nodes_list: list) -> dict:
    """Convert the list of node entries into a dict keyed by the node id.

    Args:
        nodes_list: Node entries with an ``"id"`` and either a ``"value"``
            or a ``"function"`` key.

    Returns:
        Dict mapping ``str(id)`` to the entry's ``"value"`` if that key is
        present, otherwise to its ``"function"``; entries are inserted in
        ascending id order.
    """
    nodes = {}
    for entry in sorted(nodes_list, key=lambda d: d["id"]):
        if "value" in entry:
            nodes[str(entry["id"])] = entry["value"]
        else:
            nodes[str(entry["id"])] = entry["function"]
    return nodes