nerdd-link 0.2.11__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58) hide show
  1. nerdd_link-0.2.11/LICENSE +21 -0
  2. nerdd_link-0.2.11/PKG-INFO +116 -0
  3. nerdd_link-0.2.11/README.md +39 -0
  4. nerdd_link-0.2.11/nerdd_link/__init__.py +6 -0
  5. nerdd_link-0.2.11/nerdd_link/actions/__init__.py +5 -0
  6. nerdd_link-0.2.11/nerdd_link/actions/action.py +41 -0
  7. nerdd_link-0.2.11/nerdd_link/actions/predict_checkpoints_action.py +76 -0
  8. nerdd_link-0.2.11/nerdd_link/actions/process_jobs_action.py +157 -0
  9. nerdd_link-0.2.11/nerdd_link/actions/register_module_action.py +37 -0
  10. nerdd_link-0.2.11/nerdd_link/actions/serialize_job_action.py +64 -0
  11. nerdd_link-0.2.11/nerdd_link/channels/__init__.py +3 -0
  12. nerdd_link-0.2.11/nerdd_link/channels/channel.py +147 -0
  13. nerdd_link-0.2.11/nerdd_link/channels/kafka_channel.py +109 -0
  14. nerdd_link-0.2.11/nerdd_link/channels/memory_channel.py +47 -0
  15. nerdd_link-0.2.11/nerdd_link/cli/__init__.py +4 -0
  16. nerdd_link-0.2.11/nerdd_link/cli/initialize_system.py +49 -0
  17. nerdd_link-0.2.11/nerdd_link/cli/run_job_server.py +115 -0
  18. nerdd_link-0.2.11/nerdd_link/cli/run_prediction_server.py +85 -0
  19. nerdd_link-0.2.11/nerdd_link/cli/run_serialization_server.py +73 -0
  20. nerdd_link-0.2.11/nerdd_link/converters/__init__.py +6 -0
  21. nerdd_link-0.2.11/nerdd_link/converters/image_converter.py +15 -0
  22. nerdd_link-0.2.11/nerdd_link/converters/mol_pickle_converter.py +20 -0
  23. nerdd_link-0.2.11/nerdd_link/converters/mol_to_image_converter.py +77 -0
  24. nerdd_link-0.2.11/nerdd_link/converters/pickle_converter.py +15 -0
  25. nerdd_link-0.2.11/nerdd_link/converters/problem_list_converter.py +15 -0
  26. nerdd_link-0.2.11/nerdd_link/converters/source_list_converter.py +15 -0
  27. nerdd_link-0.2.11/nerdd_link/delegates/__init__.py +5 -0
  28. nerdd_link-0.2.11/nerdd_link/delegates/pickle_writer.py +18 -0
  29. nerdd_link-0.2.11/nerdd_link/delegates/read_checkpoint_model.py +73 -0
  30. nerdd_link-0.2.11/nerdd_link/delegates/read_pickle_step.py +21 -0
  31. nerdd_link-0.2.11/nerdd_link/delegates/serialize_job_model.py +21 -0
  32. nerdd_link-0.2.11/nerdd_link/delegates/split_and_merge_step.py +51 -0
  33. nerdd_link-0.2.11/nerdd_link/delegates/topic_writer.py +19 -0
  34. nerdd_link-0.2.11/nerdd_link/files/__init__.py +1 -0
  35. nerdd_link-0.2.11/nerdd_link/files/file_system.py +89 -0
  36. nerdd_link-0.2.11/nerdd_link/input/__init__.py +1 -0
  37. nerdd_link-0.2.11/nerdd_link/input/structure_json_reader.py +37 -0
  38. nerdd_link-0.2.11/nerdd_link/py.typed +0 -0
  39. nerdd_link-0.2.11/nerdd_link/tests/__init__.py +3 -0
  40. nerdd_link-0.2.11/nerdd_link/tests/async_step.py +22 -0
  41. nerdd_link-0.2.11/nerdd_link/tests/channels.py +81 -0
  42. nerdd_link-0.2.11/nerdd_link/tests/files.py +9 -0
  43. nerdd_link-0.2.11/nerdd_link/types/__init__.py +70 -0
  44. nerdd_link-0.2.11/nerdd_link/utils/__init__.py +4 -0
  45. nerdd_link-0.2.11/nerdd_link/utils/async_to_sync.py +26 -0
  46. nerdd_link-0.2.11/nerdd_link/utils/batched.py +27 -0
  47. nerdd_link-0.2.11/nerdd_link/utils/observable_list.py +72 -0
  48. nerdd_link-0.2.11/nerdd_link/utils/safetee.py +39 -0
  49. nerdd_link-0.2.11/nerdd_link/version.py +11 -0
  50. nerdd_link-0.2.11/nerdd_link.egg-info/PKG-INFO +116 -0
  51. nerdd_link-0.2.11/nerdd_link.egg-info/SOURCES.txt +56 -0
  52. nerdd_link-0.2.11/nerdd_link.egg-info/dependency_links.txt +1 -0
  53. nerdd_link-0.2.11/nerdd_link.egg-info/entry_points.txt +5 -0
  54. nerdd_link-0.2.11/nerdd_link.egg-info/requires.txt +34 -0
  55. nerdd_link-0.2.11/nerdd_link.egg-info/top_level.txt +1 -0
  56. nerdd_link-0.2.11/pyproject.toml +139 -0
  57. nerdd_link-0.2.11/setup.cfg +4 -0
  58. nerdd_link-0.2.11/tests/test_features.py +3 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Molecular Informatics Vienna
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,116 @@
1
+ Metadata-Version: 2.2
2
+ Name: nerdd-link
3
+ Version: 0.2.11
4
+ Summary: Run a NERDD module as a service
5
+ Author-email: Steffen Hirte <steffen.hirte@univie.ac.at>
6
+ Maintainer-email: Steffen Hirte <steffen.hirte@univie.ac.at>
7
+ License: MIT License
8
+
9
+ Copyright (c) 2023 Molecular Informatics Vienna
10
+
11
+ Permission is hereby granted, free of charge, to any person obtaining a copy
12
+ of this software and associated documentation files (the "Software"), to deal
13
+ in the Software without restriction, including without limitation the rights
14
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15
+ copies of the Software, and to permit persons to whom the Software is
16
+ furnished to do so, subject to the following conditions:
17
+
18
+ The above copyright notice and this permission notice shall be included in all
19
+ copies or substantial portions of the Software.
20
+
21
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
27
+ SOFTWARE.
28
+
29
+ Project-URL: Repository, https://github.com/molinfo-vienna/nerdd-link
30
+ Keywords: science,research,development,nerdd
31
+ Classifier: Intended Audience :: Science/Research
32
+ Classifier: Intended Audience :: Developers
33
+ Classifier: License :: OSI Approved :: MIT License
34
+ Classifier: Programming Language :: Python
35
+ Classifier: Topic :: Software Development
36
+ Classifier: Topic :: Scientific/Engineering
37
+ Classifier: Operating System :: Microsoft :: Windows
38
+ Classifier: Operating System :: POSIX
39
+ Classifier: Operating System :: Unix
40
+ Classifier: Operating System :: MacOS
41
+ Classifier: Programming Language :: Python :: 3
42
+ Classifier: Programming Language :: Python :: 3.9
43
+ Classifier: Programming Language :: Python :: 3.10
44
+ Classifier: Programming Language :: Python :: 3.11
45
+ Classifier: Programming Language :: Python :: 3.12
46
+ Description-Content-Type: text/markdown
47
+ License-File: LICENSE
48
+ Requires-Dist: nerdd-module>=0.3.6
49
+ Requires-Dist: pandas>=1.2.1
50
+ Requires-Dist: pyyaml~=6.0
51
+ Requires-Dist: filetype~=1.2.0
52
+ Requires-Dist: rich-click>=1.7.1
53
+ Requires-Dist: stringcase~=1.2.0
54
+ Requires-Dist: numpy
55
+ Requires-Dist: simplejson>=3
56
+ Requires-Dist: pydantic>=2
57
+ Requires-Dist: aiokafka>=0.12.0
58
+ Requires-Dist: importlib-metadata>=4.6; python_version < "3.10"
59
+ Provides-Extra: dev
60
+ Requires-Dist: mypy; extra == "dev"
61
+ Requires-Dist: ruff==0.8.0; extra == "dev"
62
+ Requires-Dist: pre-commit>=2; extra == "dev"
63
+ Provides-Extra: test
64
+ Requires-Dist: pytest; extra == "test"
65
+ Requires-Dist: pytest-sugar; extra == "test"
66
+ Requires-Dist: pytest-cov; extra == "test"
67
+ Requires-Dist: pytest-asyncio; extra == "test"
68
+ Requires-Dist: pytest-bdd==7.3.0; extra == "test"
69
+ Requires-Dist: pytest-mock; extra == "test"
70
+ Requires-Dist: pytest-watcher; extra == "test"
71
+ Requires-Dist: hypothesis; extra == "test"
72
+ Requires-Dist: hypothesis-rdkit; extra == "test"
73
+ Provides-Extra: docs
74
+ Requires-Dist: mkdocs; extra == "docs"
75
+ Requires-Dist: mkdocs-material; extra == "docs"
76
+ Requires-Dist: mkdocstrings; extra == "docs"
77
+
78
+ # NERDD-Link
79
+
80
+ Run a [NERDD module](https://github.com/molinfo-vienna/nerdd-module) as a
81
+ service that consumes input molecules and produces prediction tuples.
82
+
83
+
84
+ ## Installation
85
+
86
+ ```bash
87
+ pip install -U nerdd-link
88
+ ```
89
+
90
+ ## Usage
91
+
92
+ When a class inherits from ```nerdd_module.AbstractModel``` (see
93
+ [NERDD Module GitHub page](https://github.com/molinfo-vienna/nerdd-module)), it can be
94
+ used to create a Kafka service.
95
+
96
+ ```bash
97
+ # run a Kafka service for NerddModel on localhost:9092
98
+ run_nerdd_server package.path.to.NerddModel
99
+
100
+ # modify broker url, input topic and batch size
101
+ run_nerdd_server package.path.to.NerddModel \
102
+ --broker-url my-cluster-kafka-bootstrap.kafka:9092 \
103
+ --input-topic examples \
104
+ --batch-size 10
105
+
106
+ # more information via --help
107
+ run_nerdd_server --help
108
+ ```
109
+
110
+ If the model class is called ```ExamplePredictionModel```, the server will read input
111
+ tuples from the input topic ```example-prediction-inputs``` in batches of size 100
112
+ and write results to the ```results``` topic. The batch size specifies the number
113
+ of input tuples that are given to the model at once.
114
+
115
+ ## Communication
116
+
@@ -0,0 +1,39 @@
1
+ # NERDD-Link
2
+
3
+ Run a [NERDD module](https://github.com/molinfo-vienna/nerdd-module) as a
4
+ service that consumes input molecules and produces prediction tuples.
5
+
6
+
7
+ ## Installation
8
+
9
+ ```bash
10
+ pip install -U nerdd-link
11
+ ```
12
+
13
+ ## Usage
14
+
15
+ When a class inherits from ```nerdd_module.AbstractModel``` (see
16
+ [NERDD Module GitHub page](https://github.com/molinfo-vienna/nerdd-module)), it can be
17
+ used to create a Kafka service.
18
+
19
+ ```bash
20
+ # run a Kafka service for NerddModel on localhost:9092
21
+ run_nerdd_server package.path.to.NerddModel
22
+
23
+ # modify broker url, input topic and batch size
24
+ run_nerdd_server package.path.to.NerddModel \
25
+ --broker-url my-cluster-kafka-bootstrap.kafka:9092 \
26
+ --input-topic examples \
27
+ --batch-size 10
28
+
29
+ # more information via --help
30
+ run_nerdd_server --help
31
+ ```
32
+
33
+ If the model class is called ```ExamplePredictionModel```, the server will read input
34
+ tuples from the input topic ```example-prediction-inputs``` in batches of size 100
35
+ and write results to the ```results``` topic. The batch size specifies the number
36
+ of input tuples that are given to the model at once.
37
+
38
+ ## Communication
39
+
@@ -0,0 +1,6 @@
1
+ from .actions import *
2
+ from .channels import *
3
+ from .converters import *
4
+ from .files import *
5
+ from .input import *
6
+ from .types import *
@@ -0,0 +1,5 @@
1
+ from .action import *
2
+ from .predict_checkpoints_action import *
3
+ from .process_jobs_action import *
4
+ from .register_module_action import *
5
+ from .serialize_job_action import *
@@ -0,0 +1,41 @@
1
+ import logging
2
+ from abc import ABC, abstractmethod
3
+ from asyncio import CancelledError
4
+ from typing import Generic, TypeVar
5
+
6
+ from stringcase import spinalcase
7
+
8
+ from ..channels import Channel, Topic
9
+ from ..types import Message
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ T = TypeVar("T", bound=Message)
14
+
15
+
16
class Action(ABC, Generic[T]):
    """Base class for a consumer that handles the messages of one input topic.

    Subclasses implement ``_process_message``. ``run`` receives messages from the
    input topic and hands them to that method one at a time.
    """

    def __init__(self, input_topic: Topic[T]):
        self._input_topic = input_topic

    @property
    def channel(self) -> Channel:
        """The channel that the input topic belongs to."""
        return self._input_topic.channel

    def _get_group_name(self) -> str:
        """Return the name used (in spinal case) as consumer group; defaults to the class name."""
        return type(self).__name__

    @abstractmethod
    async def _process_message(self, message: T) -> None:
        """Handle a single message received from the input topic."""
        ...

    async def run(self) -> None:
        """Receive messages from the input topic and process them until cancelled.

        A failure while processing one message is logged and does not stop the loop.
        """
        group = spinalcase(self._get_group_name())
        incoming = self._input_topic.receive(group)
        async for message in incoming:
            try:
                await self._process_message(message)
            except CancelledError:
                # the consumer was cancelled: stop receiving messages
                return
            except Exception:
                # report the failure (with traceback) and move on to the next message
                logger.exception("Error processing message")
@@ -0,0 +1,76 @@
1
+ import logging
2
+ from multiprocessing import Queue
3
+
4
+ from nerdd_module import SimpleModel
5
+
6
+ from ..channels import Channel
7
+ from ..delegates import ReadCheckpointModel
8
+ from ..files import FileSystem
9
+ from ..types import CheckpointMessage, ResultCheckpointMessage, ResultMessage
10
+ from .action import Action
11
+
12
+ __all__ = ["PredictCheckpointsAction"]
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class PredictCheckpointsAction(Action[CheckpointMessage]):
    """Predict the checkpoints announced on the checkpoints topic of a model.

    Every incoming message names one checkpoint of a job. The checkpoint is predicted
    with a wrapper around the given model, each produced record is sent to the results
    topic and, once the wrapper signals the end, a message is sent to the
    result-checkpoints topic.
    """

    def __init__(self, channel: Channel, model: SimpleModel, data_dir: str) -> None:
        super().__init__(channel.checkpoints_topic(model))
        self._model = model
        self._data_dir = data_dir

    def _get_group_name(self) -> str:
        """Return a consumer group name that is specific to the model id."""
        return f"predict-checkpoints-{self._model.get_config().id}"

    async def _process_message(self, message: CheckpointMessage) -> None:
        """Predict one checkpoint and forward its records to the results topic."""
        job_id = message.job_id
        checkpoint_id = message.checkpoint_id
        params = message.params
        logger.info(f"Predict checkpoint {checkpoint_id} of job {job_id}")

        # By observation, calling the Kafka producer from a different event loop / thread /
        # process hangs indefinitely. Hence, other tasks put their records into this queue
        # and the current event loop forwards them to the producer.
        queue: Queue = Queue()

        # The wrapper model
        # * reads the checkpoint file instead of normal input
        # * does preprocessing, prediction, and postprocessing like the encapsulated model
        # * does not write to the specified results file, but to the checkpoints file instead
        # * sends the results to the results topic (via the queue)
        checkpoint_model = ReadCheckpointModel(
            base_model=self._model,
            job_id=job_id,
            file_system=FileSystem(self._data_dir),
            checkpoint_id=checkpoint_id,
            queue=queue,
        )

        # input=None, because the checkpoint file is already known to ReadCheckpointModel
        checkpoint_model.predict(input=None, **params)

        # Forward records until the None marker arrives.
        # NOTE(review): queue.get() blocks the running event loop while waiting - confirm
        # that this is acceptable for the consumers sharing this loop.
        while (record := queue.get()) is not None:
            await self.channel.results_topic().send(ResultMessage(job_id=job_id, **record))

        # all records were forwarded: report the checkpoint as finished
        await self.channel.result_checkpoints_topic().send(
            ResultCheckpointMessage(job_id=job_id, checkpoint_id=checkpoint_id)
        )
@@ -0,0 +1,157 @@
1
+ import logging
2
+ from pickle import dump
3
+ from typing import Any
4
+
5
+ from nerdd_module.input import DepthFirstExplorer
6
+ from nerdd_module.model import ReadInputStep
7
+ from rdkit.Chem import Mol
8
+ from rdkit.Chem.PropertyMol import PropertyMol
9
+
10
+ from ..channels import Channel
11
+ from ..files import FileSystem
12
+ from ..types import CheckpointMessage, JobMessage, LogMessage
13
+ from ..utils import batched
14
+ from .action import Action
15
+
16
+ __all__ = ["ProcessJobsAction"]
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class ProcessJobsAction(Action[JobMessage]):
    """Split the input of every new job into checkpoint files.

    Consumes messages from the jobs topic. For each job, the entries of the input source
    are read, grouped into batches of ``checkpoint_size`` and pickled into one checkpoint
    file per batch. For every batch a ``CheckpointMessage`` is sent to the checkpoints
    topic of the job type. At most ``max_num_molecules`` entries are stored; if the input
    contains more, a warning is sent to the logs topic. Finally, the number of stored
    entries and checkpoints is reported to the logs topic (``report_job_size``).
    """

    def __init__(
        self,
        channel: Channel,
        checkpoint_size: int,
        max_num_molecules: int,
        num_test_entries: int,
        ratio_valid_entries: float,
        maximum_depth: int,
        max_num_lines_mol_block: int,
        data_dir: str,
    ) -> None:
        """Create the action.

        Args:
            channel: Channel providing the jobs, checkpoints and logs topics.
            checkpoint_size: Number of entries per batch / checkpoint file.
            max_num_molecules: Upper limit of entries stored per job.
            num_test_entries: Forwarded to ``DepthFirstExplorer`` as ``num_test_entries``.
            ratio_valid_entries: Forwarded to ``DepthFirstExplorer`` as ``threshold``.
            maximum_depth: Forwarded to ``DepthFirstExplorer`` as ``maximum_depth``.
            max_num_lines_mol_block: Forwarded to ``DepthFirstExplorer`` as extra argument.
            data_dir: Root directory handed to ``FileSystem``.
        """
        super().__init__(channel.jobs_topic())
        # relevant for chunking
        self._checkpoint_size = checkpoint_size
        self._max_num_molecules = max_num_molecules
        # parameters of DepthFirstExplorer
        self._num_test_entries = num_test_entries
        self._ratio_valid_entries = ratio_valid_entries
        self._maximum_depth = maximum_depth
        # used as kwargs in DepthFirstExplorer
        self._max_num_lines_mol_block = max_num_lines_mol_block
        self._file_system = FileSystem(data_dir)

    @staticmethod
    def _with_property_mols(entry: dict) -> dict:
        """Return a copy of ``entry`` with every RDKit ``Mol`` value wrapped in ``PropertyMol``."""
        # PropertyMol is used in order to keep molecular properties when pickling
        return {
            key: PropertyMol(value) if isinstance(value, Mol) else value
            for key, value in entry.items()
        }

    async def _process_message(self, message: JobMessage) -> None:
        """Write the checkpoints of one job and report its size.

        Args:
            message: The job to process (id, job type, source id and params are read).
        """
        job_id = message.id
        job_type = message.job_type
        logger.info(f"Received a new job {job_id} of type {job_type}")

        # the input file to the job is stored in a designated sources directory
        # (the file is allowed to reference other files, but setting the data_dir
        # to the sources directory ensures that we never read files outside of the
        # sources directory)
        data_dir = self._file_system.get_sources_dir()

        # create a reader (explorer) for the input file
        explorer = DepthFirstExplorer(
            num_test_entries=self._num_test_entries,
            threshold=self._ratio_valid_entries,
            maximum_depth=self._maximum_depth,
            # extra args
            max_num_lines_mol_block=self._max_num_lines_mol_block,
            data_dir=data_dir,
        )

        read_input_step = ReadInputStep(explorer, message.source_id)

        # read the input file
        entries = read_input_step()

        # iterate through the entries in batches of size checkpoint_size and
        # limit the number of stored molecules to max_num_molecules
        batches = batched(entries, self._checkpoint_size)
        num_entries = 0
        num_checkpoints = 0
        # Initialized before the loop, because the loop body never runs for an empty input
        # and the code below must still be able to report a job of size zero.
        too_many_molecules = False
        for i, batch in enumerate(batches):
            # max_num_molecules might be reached within the batch
            num_store = min(len(batch), self._max_num_molecules - num_entries)
            too_many_molecules = num_store < len(batch)

            # store batch in data_dir
            with self._file_system.get_checkpoint_file_handle(job_id, i, "wb") as f:
                # TODO: use a model for storing the batches
                results = [self._with_property_mols(item) for item in batch[:num_store]]
                dump(results, f)

            # announce the batch on the checkpoints topic of this job type
            await self.channel.checkpoints_topic(job_type).send(
                CheckpointMessage(
                    job_id=job_id,
                    checkpoint_id=i,
                    params=message.params,
                )
            )

            num_entries += num_store
            num_checkpoints += 1

            if num_entries >= self._max_num_molecules:
                break

        logger.info(
            f"Wrote {num_checkpoints} checkpoints containing {num_entries} entries "
            f"for job {job_id}"
        )

        # send a warning message if there were more molecules in the job than allowed
        try:
            # try to get another entry
            next(entries)
        except StopIteration:
            pass
        else:
            # there was another entry and we need to send a warning
            too_many_molecules = True

        if too_many_molecules:
            await self.channel.logs_topic().send(
                LogMessage(
                    job_id=job_id,
                    message_type="warning",
                    message=(
                        f"The provided job contains more than "
                        f"{self._max_num_molecules} input structures. Only the "
                        f"first {self._max_num_molecules} will be processed."
                    ),
                )
            )

        # at the end, report the overall size of the job on the logs topic
        await self.channel.logs_topic().send(
            LogMessage(
                job_id=job_id,
                message_type="report_job_size",
                num_entries=num_entries,
                num_checkpoints=num_checkpoints,
            )
        )
@@ -0,0 +1,37 @@
1
+ import json
2
+ import logging
3
+
4
+ from nerdd_module import Model
5
+
6
+ from ..channels import Channel
7
+ from ..files import FileSystem
8
+ from ..types import ModuleMessage, SystemMessage
9
+ from .action import Action
10
+
11
+ __all__ = ["RegisterModuleAction"]
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class RegisterModuleAction(Action[SystemMessage]):
    """Register a module whenever a message arrives on the system topic.

    The configuration of the model is written as JSON to the module file of the file
    system and a ``ModuleMessage`` carrying the module id is sent to the modules topic.
    """

    def __init__(self, channel: Channel, model: Model, data_dir: str):
        """Create the action.

        Args:
            channel: Channel providing the system and modules topics.
            model: Model to register; it has to provide ``get_config``.
            data_dir: Root directory handed to ``FileSystem``.
        """
        super().__init__(channel.system_topic())
        # TODO: do this differently
        assert hasattr(model, "get_config")
        self._model = model
        self._file_system = FileSystem(data_dir)

    async def _process_message(self, message: SystemMessage) -> None:
        """Store the module configuration and announce the module."""
        config = self._model.get_config()
        logger.info(f"Registering module with id {config.id}")

        # Save module as json to file. The context manager closes (and thereby flushes)
        # the file before the module is announced below.
        module_file = self._file_system.get_module_file_path(config.id)
        with open(module_file, "w") as f:
            json.dump(config.model_dump(), f)

        # send the initialization message
        await self.channel.modules_topic().send(ModuleMessage(id=config.id))

    def _get_group_name(self) -> str:
        """Return a consumer group name that is specific to the model id."""
        model_id = self._model.get_config().id
        return f"register-module-{model_id}"
@@ -0,0 +1,64 @@
1
+ import json
2
+ import logging
3
+
4
+ from nerdd_module import OutputStep
5
+
6
+ from ..channels import Channel
7
+ from ..delegates import ReadPickleStep, SerializeJobModel
8
+ from ..files import FileSystem
9
+ from ..types import SerializationRequestMessage, SerializationResultMessage
10
+ from .action import Action
11
+
12
+ __all__ = ["SerializeJobAction"]
13
+
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class SerializeJobAction(Action[SerializationRequestMessage]):
    """Write the results of a job to an output file in a requested format.

    Consumes messages from the serialization-requests topic and sends a
    ``SerializationResultMessage`` to the serialization-results topic after the output
    file was written.
    """

    def __init__(self, channel: Channel, data_dir: str) -> None:
        """Create the action.

        Args:
            channel: Channel providing the serialization request and result topics.
            data_dir: Root directory handed to ``FileSystem``.
        """
        super().__init__(channel.serialization_requests_topic())
        self._file_system = FileSystem(data_dir)

    async def _process_message(self, message: SerializationRequestMessage) -> None:
        """Serialize the results of one job and report completion."""
        job_id = message.job_id
        job_type = message.job_type
        output_format = message.output_format
        logger.info(f"Write output for job {job_id} in format {output_format}")

        # Remove specific parameter keys that could induce vulnerabilities. A filtered copy
        # is used so that the params of the received message are left untouched.
        params = {
            key: value
            for key, value in message.params.items()
            if key not in ("output_file", "output_format")
        }

        # obtain output file
        output_file = self._file_system.get_output_file(job_id, output_format)

        # get the configuration for the job_type (the file is closed right after reading)
        config_file = self._file_system.get_module_file_path(job_type)
        with open(config_file, "r") as f:
            config = json.load(f)

        # create a fake model instance to get the postprocessing steps
        model = SerializeJobModel(config)

        steps = [
            # read the result checkpoint files of the job
            ReadPickleStep(self._file_system.iter_results_file_handles(job_id)),
            # don't preprocess, don't do prediction, only post-process
            *model._get_postprocessing_steps(output_format, output_file=output_file, **params),
        ]

        output_step = steps[-1]
        assert isinstance(output_step, OutputStep), "The last step must be an OutputStep."

        # build the pipeline by feeding the result of each step into the next one
        pipeline = None
        for step in steps:
            pipeline = step(pipeline)

        # run the pipeline by calling the get_result method of the last step
        output_step.get_result()

        await self.channel.serialization_results_topic().send(
            SerializationResultMessage(job_id=job_id, output_format=output_format)
        )
@@ -0,0 +1,3 @@
1
+ from .channel import *
2
+ from .kafka_channel import *
3
+ from .memory_channel import *