nerdd-link 0.2.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nerdd_link-0.2.11/LICENSE +21 -0
- nerdd_link-0.2.11/PKG-INFO +116 -0
- nerdd_link-0.2.11/README.md +39 -0
- nerdd_link-0.2.11/nerdd_link/__init__.py +6 -0
- nerdd_link-0.2.11/nerdd_link/actions/__init__.py +5 -0
- nerdd_link-0.2.11/nerdd_link/actions/action.py +41 -0
- nerdd_link-0.2.11/nerdd_link/actions/predict_checkpoints_action.py +76 -0
- nerdd_link-0.2.11/nerdd_link/actions/process_jobs_action.py +157 -0
- nerdd_link-0.2.11/nerdd_link/actions/register_module_action.py +37 -0
- nerdd_link-0.2.11/nerdd_link/actions/serialize_job_action.py +64 -0
- nerdd_link-0.2.11/nerdd_link/channels/__init__.py +3 -0
- nerdd_link-0.2.11/nerdd_link/channels/channel.py +147 -0
- nerdd_link-0.2.11/nerdd_link/channels/kafka_channel.py +109 -0
- nerdd_link-0.2.11/nerdd_link/channels/memory_channel.py +47 -0
- nerdd_link-0.2.11/nerdd_link/cli/__init__.py +4 -0
- nerdd_link-0.2.11/nerdd_link/cli/initialize_system.py +49 -0
- nerdd_link-0.2.11/nerdd_link/cli/run_job_server.py +115 -0
- nerdd_link-0.2.11/nerdd_link/cli/run_prediction_server.py +85 -0
- nerdd_link-0.2.11/nerdd_link/cli/run_serialization_server.py +73 -0
- nerdd_link-0.2.11/nerdd_link/converters/__init__.py +6 -0
- nerdd_link-0.2.11/nerdd_link/converters/image_converter.py +15 -0
- nerdd_link-0.2.11/nerdd_link/converters/mol_pickle_converter.py +20 -0
- nerdd_link-0.2.11/nerdd_link/converters/mol_to_image_converter.py +77 -0
- nerdd_link-0.2.11/nerdd_link/converters/pickle_converter.py +15 -0
- nerdd_link-0.2.11/nerdd_link/converters/problem_list_converter.py +15 -0
- nerdd_link-0.2.11/nerdd_link/converters/source_list_converter.py +15 -0
- nerdd_link-0.2.11/nerdd_link/delegates/__init__.py +5 -0
- nerdd_link-0.2.11/nerdd_link/delegates/pickle_writer.py +18 -0
- nerdd_link-0.2.11/nerdd_link/delegates/read_checkpoint_model.py +73 -0
- nerdd_link-0.2.11/nerdd_link/delegates/read_pickle_step.py +21 -0
- nerdd_link-0.2.11/nerdd_link/delegates/serialize_job_model.py +21 -0
- nerdd_link-0.2.11/nerdd_link/delegates/split_and_merge_step.py +51 -0
- nerdd_link-0.2.11/nerdd_link/delegates/topic_writer.py +19 -0
- nerdd_link-0.2.11/nerdd_link/files/__init__.py +1 -0
- nerdd_link-0.2.11/nerdd_link/files/file_system.py +89 -0
- nerdd_link-0.2.11/nerdd_link/input/__init__.py +1 -0
- nerdd_link-0.2.11/nerdd_link/input/structure_json_reader.py +37 -0
- nerdd_link-0.2.11/nerdd_link/py.typed +0 -0
- nerdd_link-0.2.11/nerdd_link/tests/__init__.py +3 -0
- nerdd_link-0.2.11/nerdd_link/tests/async_step.py +22 -0
- nerdd_link-0.2.11/nerdd_link/tests/channels.py +81 -0
- nerdd_link-0.2.11/nerdd_link/tests/files.py +9 -0
- nerdd_link-0.2.11/nerdd_link/types/__init__.py +70 -0
- nerdd_link-0.2.11/nerdd_link/utils/__init__.py +4 -0
- nerdd_link-0.2.11/nerdd_link/utils/async_to_sync.py +26 -0
- nerdd_link-0.2.11/nerdd_link/utils/batched.py +27 -0
- nerdd_link-0.2.11/nerdd_link/utils/observable_list.py +72 -0
- nerdd_link-0.2.11/nerdd_link/utils/safetee.py +39 -0
- nerdd_link-0.2.11/nerdd_link/version.py +11 -0
- nerdd_link-0.2.11/nerdd_link.egg-info/PKG-INFO +116 -0
- nerdd_link-0.2.11/nerdd_link.egg-info/SOURCES.txt +56 -0
- nerdd_link-0.2.11/nerdd_link.egg-info/dependency_links.txt +1 -0
- nerdd_link-0.2.11/nerdd_link.egg-info/entry_points.txt +5 -0
- nerdd_link-0.2.11/nerdd_link.egg-info/requires.txt +34 -0
- nerdd_link-0.2.11/nerdd_link.egg-info/top_level.txt +1 -0
- nerdd_link-0.2.11/pyproject.toml +139 -0
- nerdd_link-0.2.11/setup.cfg +4 -0
- nerdd_link-0.2.11/tests/test_features.py +3 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2023 Molecular Informatics Vienna
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
|
+
Name: nerdd-link
|
|
3
|
+
Version: 0.2.11
|
|
4
|
+
Summary: Run a NERDD module as a service
|
|
5
|
+
Author-email: Steffen Hirte <steffen.hirte@univie.ac.at>
|
|
6
|
+
Maintainer-email: Steffen Hirte <steffen.hirte@univie.ac.at>
|
|
7
|
+
License: MIT License
|
|
8
|
+
|
|
9
|
+
Copyright (c) 2023 Molecular Informatics Vienna
|
|
10
|
+
|
|
11
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
12
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
13
|
+
in the Software without restriction, including without limitation the rights
|
|
14
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
15
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
16
|
+
furnished to do so, subject to the following conditions:
|
|
17
|
+
|
|
18
|
+
The above copyright notice and this permission notice shall be included in all
|
|
19
|
+
copies or substantial portions of the Software.
|
|
20
|
+
|
|
21
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
22
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
23
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
24
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
25
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
26
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
27
|
+
SOFTWARE.
|
|
28
|
+
|
|
29
|
+
Project-URL: Repository, https://github.com/molinfo-vienna/nerdd-link
|
|
30
|
+
Keywords: science,research,development,nerdd
|
|
31
|
+
Classifier: Intended Audience :: Science/Research
|
|
32
|
+
Classifier: Intended Audience :: Developers
|
|
33
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
34
|
+
Classifier: Programming Language :: Python
|
|
35
|
+
Classifier: Topic :: Software Development
|
|
36
|
+
Classifier: Topic :: Scientific/Engineering
|
|
37
|
+
Classifier: Operating System :: Microsoft :: Windows
|
|
38
|
+
Classifier: Operating System :: POSIX
|
|
39
|
+
Classifier: Operating System :: Unix
|
|
40
|
+
Classifier: Operating System :: MacOS
|
|
41
|
+
Classifier: Programming Language :: Python :: 3
|
|
42
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
43
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
44
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
45
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
46
|
+
Description-Content-Type: text/markdown
|
|
47
|
+
License-File: LICENSE
|
|
48
|
+
Requires-Dist: nerdd-module>=0.3.6
|
|
49
|
+
Requires-Dist: pandas>=1.2.1
|
|
50
|
+
Requires-Dist: pyyaml~=6.0
|
|
51
|
+
Requires-Dist: filetype~=1.2.0
|
|
52
|
+
Requires-Dist: rich-click>=1.7.1
|
|
53
|
+
Requires-Dist: stringcase~=1.2.0
|
|
54
|
+
Requires-Dist: numpy
|
|
55
|
+
Requires-Dist: simplejson>=3
|
|
56
|
+
Requires-Dist: pydantic>=2
|
|
57
|
+
Requires-Dist: aiokafka>=0.12.0
|
|
58
|
+
Requires-Dist: importlib-metadata>=4.6; python_version < "3.10"
|
|
59
|
+
Provides-Extra: dev
|
|
60
|
+
Requires-Dist: mypy; extra == "dev"
|
|
61
|
+
Requires-Dist: ruff==0.8.0; extra == "dev"
|
|
62
|
+
Requires-Dist: pre-commit>=2; extra == "dev"
|
|
63
|
+
Provides-Extra: test
|
|
64
|
+
Requires-Dist: pytest; extra == "test"
|
|
65
|
+
Requires-Dist: pytest-sugar; extra == "test"
|
|
66
|
+
Requires-Dist: pytest-cov; extra == "test"
|
|
67
|
+
Requires-Dist: pytest-asyncio; extra == "test"
|
|
68
|
+
Requires-Dist: pytest-bdd==7.3.0; extra == "test"
|
|
69
|
+
Requires-Dist: pytest-mock; extra == "test"
|
|
70
|
+
Requires-Dist: pytest-watcher; extra == "test"
|
|
71
|
+
Requires-Dist: hypothesis; extra == "test"
|
|
72
|
+
Requires-Dist: hypothesis-rdkit; extra == "test"
|
|
73
|
+
Provides-Extra: docs
|
|
74
|
+
Requires-Dist: mkdocs; extra == "docs"
|
|
75
|
+
Requires-Dist: mkdocs-material; extra == "docs"
|
|
76
|
+
Requires-Dist: mkdocstrings; extra == "docs"
|
|
77
|
+
|
|
78
|
+
# NERDD-Link
|
|
79
|
+
|
|
80
|
+
Run a [NERDD module](https://github.com/molinfo-vienna/nerdd-module) as a
|
|
81
|
+
service that consumes input molecules and produces prediction tuples.
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
## Installation
|
|
85
|
+
|
|
86
|
+
```bash
|
|
87
|
+
pip install -U nerdd-link
|
|
88
|
+
```
|
|
89
|
+
|
|
90
|
+
## Usage
|
|
91
|
+
|
|
92
|
+
When a class inherits from `nerdd_module.AbstractModel` (see
|
|
93
|
+
[NERDD Module GitHub page](https://github.com/molinfo-vienna/nerdd-module)), it can be
|
|
94
|
+
used to create a Kafka service.
|
|
95
|
+
|
|
96
|
+
```bash
|
|
97
|
+
# run a Kafka service for NerddModel on localhost:9092
|
|
98
|
+
run_nerdd_server package.path.to.NerddModel
|
|
99
|
+
|
|
100
|
+
# modify broker url, input topic and batch size
|
|
101
|
+
run_nerdd_server package.path.to.NerddModel \
|
|
102
|
+
--broker-url my-cluster-kafka-bootstrap.kafka:9092 \
|
|
103
|
+
--input-topic examples \
|
|
104
|
+
--batch-size 10
|
|
105
|
+
|
|
106
|
+
# more information via --help
|
|
107
|
+
run_nerdd_server --help
|
|
108
|
+
```
|
|
109
|
+
|
|
110
|
+
If the model class is called `ExamplePredictionModel`, the server will read input
|
|
111
|
+
tuples from the input topic `example-prediction-inputs` in batches of size 100
|
|
112
|
+
and write results to the `results` topic. The batch size specifies the number
|
|
113
|
+
of input tuples that are given to the model at once.
|
|
114
|
+
|
|
115
|
+
## Communication
|
|
116
|
+
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
# NERDD-Link
|
|
2
|
+
|
|
3
|
+
Run a [NERDD module](https://github.com/molinfo-vienna/nerdd-module) as a
|
|
4
|
+
service that consumes input molecules and produces prediction tuples.
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
## Installation
|
|
8
|
+
|
|
9
|
+
```bash
|
|
10
|
+
pip install -U nerdd-link
|
|
11
|
+
```
|
|
12
|
+
|
|
13
|
+
## Usage
|
|
14
|
+
|
|
15
|
+
When a class inherits from `nerdd_module.AbstractModel` (see
|
|
16
|
+
[NERDD Module GitHub page](https://github.com/molinfo-vienna/nerdd-module)), it can be
|
|
17
|
+
used to create a Kafka service.
|
|
18
|
+
|
|
19
|
+
```bash
|
|
20
|
+
# run a Kafka service for NerddModel on localhost:9092
|
|
21
|
+
run_nerdd_server package.path.to.NerddModel
|
|
22
|
+
|
|
23
|
+
# modify broker url, input topic and batch size
|
|
24
|
+
run_nerdd_server package.path.to.NerddModel \
|
|
25
|
+
--broker-url my-cluster-kafka-bootstrap.kafka:9092 \
|
|
26
|
+
--input-topic examples \
|
|
27
|
+
--batch-size 10
|
|
28
|
+
|
|
29
|
+
# more information via --help
|
|
30
|
+
run_nerdd_server --help
|
|
31
|
+
```
|
|
32
|
+
|
|
33
|
+
If the model class is called `ExamplePredictionModel`, the server will read input
|
|
34
|
+
tuples from the input topic `example-prediction-inputs` in batches of size 100
|
|
35
|
+
and write results to the `results` topic. The batch size specifies the number
|
|
36
|
+
of input tuples that are given to the model at once.
|
|
37
|
+
|
|
38
|
+
## Communication
|
|
39
|
+
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from asyncio import CancelledError
|
|
4
|
+
from typing import Generic, TypeVar
|
|
5
|
+
|
|
6
|
+
from stringcase import spinalcase
|
|
7
|
+
|
|
8
|
+
from ..channels import Channel, Topic
|
|
9
|
+
from ..types import Message
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)

# Message type consumed by an action; bound to the common Message base type.
T = TypeVar("T", bound=Message)


class Action(ABC, Generic[T]):
    """Base class for a service that consumes messages from one input topic.

    Subclasses implement :meth:`_process_message`, which :meth:`run` invokes for
    every message received on the input topic.
    """

    def __init__(self, input_topic: Topic[T]):
        self._input_topic = input_topic

    async def run(self) -> None:
        """Receive messages from the input topic and hand each one to the subclass.

        The consumer group is ``_get_group_name()`` converted to spinal case.
        Errors raised while handling a single message are logged and skipped; a
        ``CancelledError`` raised by the handler stops the loop and makes this
        coroutine return normally.
        """
        group_name = spinalcase(self._get_group_name())
        incoming = self._input_topic.receive(group_name)
        async for msg in incoming:
            try:
                await self._process_message(msg)
            except CancelledError:
                # the consumer was cancelled: stop handling messages altogether
                return
            except Exception:
                # one failing message must not stop the whole service
                logger.exception("Error processing message")

    @abstractmethod
    async def _process_message(self, message: T) -> None:
        """Handle a single message received on the input topic."""

    @property
    def channel(self) -> Channel:
        """The channel that the input topic belongs to."""
        return self._input_topic.channel

    def _get_group_name(self) -> str:
        """Return the raw consumer group name (defaults to the class name)."""
        return type(self).__name__
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from multiprocessing import Queue
|
|
3
|
+
|
|
4
|
+
from nerdd_module import SimpleModel
|
|
5
|
+
|
|
6
|
+
from ..channels import Channel
|
|
7
|
+
from ..delegates import ReadCheckpointModel
|
|
8
|
+
from ..files import FileSystem
|
|
9
|
+
from ..types import CheckpointMessage, ResultCheckpointMessage, ResultMessage
|
|
10
|
+
from .action import Action
|
|
11
|
+
|
|
12
|
+
__all__ = ["PredictCheckpointsAction"]

logger = logging.getLogger(__name__)


class PredictCheckpointsAction(Action[CheckpointMessage]):
    """Run predictions for the checkpoint batches of a job.

    Consumes batches of input molecules from the "<job-type>-checkpoints" topic
    (produced by the previous, job-processing step), predicts them with the wrapped
    model and publishes every result record on the "results" topic. When a
    checkpoint is finished, a ``ResultCheckpointMessage`` is sent to the result
    checkpoints topic.
    """

    def __init__(self, channel: Channel, model: SimpleModel, data_dir: str) -> None:
        super().__init__(channel.checkpoints_topic(model))
        self._model = model
        self._data_dir = data_dir

    async def _process_message(self, message: CheckpointMessage) -> None:
        """Predict one checkpoint and forward its results to the result topics."""
        job_id = message.job_id
        checkpoint_id = message.checkpoint_id
        params = message.params
        logger.info(f"Predict checkpoint {checkpoint_id} of job {job_id}")

        # By observation, producing to a Kafka producer from another event loop,
        # thread or process hangs indefinitely. The Kafka clients live in the current
        # event loop, so results are passed through this queue and forwarded to Kafka
        # from here instead of being sent directly by other tasks.
        record_queue: Queue = Queue()

        # Wrapper model that
        # * reads the checkpoint file instead of normal input,
        # * preprocesses, predicts and postprocesses like the encapsulated model,
        # * writes to the checkpoints file instead of the specified results file,
        # * hands the result records to record_queue.
        wrapped_model = ReadCheckpointModel(
            base_model=self._model,
            job_id=job_id,
            file_system=FileSystem(self._data_dir),
            checkpoint_id=checkpoint_id,
            queue=record_queue,
        )

        # input=None: the checkpoint file is already known to ReadCheckpointModel
        wrapped_model.predict(input=None, **params)

        # Drain the queue: every non-None item is a result record, None marks the end.
        # NOTE(review): Queue.get() blocks inside this coroutine and therefore stalls
        # the event loop while waiting; consider loop.run_in_executor(None, get) if the
        # producer can be slow.
        while True:
            record = record_queue.get()
            if record is None:
                break
            await self.channel.results_topic().send(ResultMessage(job_id=job_id, **record))

        await self.channel.result_checkpoints_topic().send(
            ResultCheckpointMessage(job_id=job_id, checkpoint_id=checkpoint_id)
        )

    def _get_group_name(self) -> str:
        """Consumer group name, unique per served model."""
        return f"predict-checkpoints-{self._model.get_config().id}"
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from pickle import dump
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from nerdd_module.input import DepthFirstExplorer
|
|
6
|
+
from nerdd_module.model import ReadInputStep
|
|
7
|
+
from rdkit.Chem import Mol
|
|
8
|
+
from rdkit.Chem.PropertyMol import PropertyMol
|
|
9
|
+
|
|
10
|
+
from ..channels import Channel
|
|
11
|
+
from ..files import FileSystem
|
|
12
|
+
from ..types import CheckpointMessage, JobMessage, LogMessage
|
|
13
|
+
from ..utils import batched
|
|
14
|
+
from .action import Action
|
|
15
|
+
|
|
16
|
+
__all__ = ["ProcessJobsAction"]

logger = logging.getLogger(__name__)

# Sentinel returned by next() when the input iterator has no further entries.
_NO_MORE_ENTRIES = object()


def _to_picklable_value(value: Any) -> Any:
    """Wrap RDKit molecules in ``PropertyMol``; return any other value unchanged.

    ``PropertyMol`` is used in order to keep molecular properties when the
    checkpoint is pickled (thanks, RDKit! :/ ).
    """
    if isinstance(value, Mol):
        return PropertyMol(value)
    return value


def _to_picklable_record(item: dict) -> dict:
    """Return a copy of the entry *item* with all molecule values wrapped."""
    return {key: _to_picklable_value(value) for key, value in item.items()}


class ProcessJobsAction(Action[JobMessage]):
    """Split incoming jobs into checkpoint batches.

    Accepts new jobs on the jobs topic. For each job, all molecules in the input
    (files) are read and written in batches of ``checkpoint_size`` into checkpoint
    files. For each batch, a ``CheckpointMessage`` is sent to the
    "<job_type>-checkpoints" topic. At most ``max_num_molecules`` molecules are
    processed; if the input contains more, a warning is sent to the logs topic.
    Finally, the number of entries and checkpoints is reported on the logs topic
    as a "report_job_size" message.
    """

    def __init__(
        self,
        channel: Channel,
        checkpoint_size: int,
        max_num_molecules: int,
        num_test_entries: int,
        ratio_valid_entries: float,
        maximum_depth: int,
        max_num_lines_mol_block: int,
        data_dir: str,
    ) -> None:
        super().__init__(channel.jobs_topic())
        # relevant for chunking
        self._checkpoint_size = checkpoint_size
        self._max_num_molecules = max_num_molecules
        # parameters of DepthFirstExplorer
        self._num_test_entries = num_test_entries
        self._ratio_valid_entries = ratio_valid_entries
        self._maximum_depth = maximum_depth
        # used as kwargs in DepthFirstExplorer
        self._max_num_lines_mol_block = max_num_lines_mol_block
        self._file_system = FileSystem(data_dir)

    async def _process_message(self, message: JobMessage) -> None:
        """Split the input of one job into checkpoints and report the job size.

        Inputs without any entries are handled as well. In that case no checkpoint
        message is sent and a job size of zero is reported (assuming ``batched``
        yields no batch for an empty input).
        """
        job_id = message.id
        job_type = message.job_type
        logger.info(f"Received a new job {job_id} of type {job_type}")

        # The input file of the job is stored in a designated sources directory.
        # It may reference other files, but using the sources directory as data_dir
        # ensures that we never read files outside of it.
        data_dir = self._file_system.get_sources_dir()

        # create a reader (explorer) for the input file
        explorer = DepthFirstExplorer(
            num_test_entries=self._num_test_entries,
            threshold=self._ratio_valid_entries,
            maximum_depth=self._maximum_depth,
            # extra args
            max_num_lines_mol_block=self._max_num_lines_mol_block,
            data_dir=data_dir,
        )

        read_input_step = ReadInputStep(explorer, message.source_id)

        # iter() guarantees that batching and the look-ahead below consume the same
        # iterator (for a generator, iter() returns the generator itself)
        entries = iter(read_input_step())

        # Iterate through the entries in batches of size checkpoint_size and limit
        # the number of molecules to max_num_molecules.
        num_entries = 0
        num_checkpoints = 0
        too_many_molecules = False
        for checkpoint_id, batch in enumerate(batched(entries, self._checkpoint_size)):
            # max_num_molecules might be reached within the batch
            num_store = min(len(batch), self._max_num_molecules - num_entries)
            if num_store < len(batch):
                too_many_molecules = True

            self._write_checkpoint(job_id, checkpoint_id, batch[:num_store])

            # announce the batch on the "<job_type>-checkpoints" topic
            await self.channel.checkpoints_topic(job_type).send(
                CheckpointMessage(
                    job_id=job_id,
                    checkpoint_id=checkpoint_id,
                    params=message.params,
                )
            )

            num_entries += num_store
            num_checkpoints += 1

            if num_entries >= self._max_num_molecules:
                break

        logger.info(
            f"Wrote {num_checkpoints} checkpoints containing {num_entries} entries "
            f"for job {job_id}"
        )

        if not too_many_molecules:
            # the limit may have been reached exactly at a batch boundary: peek for
            # one more entry to find out whether the input was truncated
            too_many_molecules = next(entries, _NO_MORE_ENTRIES) is not _NO_MORE_ENTRIES

        # send a warning message if there were more molecules in the job than allowed
        if too_many_molecules:
            await self.channel.logs_topic().send(
                LogMessage(
                    job_id=job_id,
                    message_type="warning",
                    message=(
                        f"The provided job contains more than "
                        f"{self._max_num_molecules} input structures. Only the "
                        f"first {self._max_num_molecules} will be processed."
                    ),
                )
            )

        # at the end, report the overall size of the job (sent to the logs topic)
        await self.channel.logs_topic().send(
            LogMessage(
                job_id=job_id,
                message_type="report_job_size",
                num_entries=num_entries,
                num_checkpoints=num_checkpoints,
            )
        )

    def _write_checkpoint(self, job_id: Any, checkpoint_id: int, entries: Any) -> None:
        """Pickle the given batch entries into the job's checkpoint file."""
        # TODO: use a model for storing the batches
        records = [_to_picklable_record(item) for item in entries]
        with self._file_system.get_checkpoint_file_handle(job_id, checkpoint_id, "wb") as f:
            dump(records, f)
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
from nerdd_module import Model
|
|
5
|
+
|
|
6
|
+
from ..channels import Channel
|
|
7
|
+
from ..files import FileSystem
|
|
8
|
+
from ..types import ModuleMessage, SystemMessage
|
|
9
|
+
from .action import Action
|
|
10
|
+
|
|
11
|
+
__all__ = ["RegisterModuleAction"]

logger = logging.getLogger(__name__)


class RegisterModuleAction(Action[SystemMessage]):
    """Register the served module whenever a message arrives on the system topic.

    The module configuration is written as JSON to the module file in the data
    directory, and a ``ModuleMessage`` carrying the module id is then sent to the
    modules topic.
    """

    def __init__(self, channel: Channel, model: Model, data_dir: str):
        super().__init__(channel.system_topic())
        # TODO: do this differently
        assert hasattr(model, "get_config")
        self._model = model
        self._file_system = FileSystem(data_dir)

    async def _process_message(self, message: SystemMessage) -> None:
        """Write the module configuration to disk and announce the module."""
        config = self._model.get_config()
        logger.info(f"Registering module with id {config.id}")

        # Save the module as JSON. The context manager guarantees that the file is
        # flushed and closed before the module is announced below, instead of relying
        # on garbage collection to close a dangling file object.
        module_file = self._file_system.get_module_file_path(config.id)
        with open(module_file, "w", encoding="utf-8") as f:
            json.dump(config.model_dump(), f)

        # send the initialization message
        await self.channel.modules_topic().send(ModuleMessage(id=config.id))

    def _get_group_name(self) -> str:
        """Consumer group name, unique per served model."""
        model_id = self._model.get_config().id
        return f"register-module-{model_id}"
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
from nerdd_module import OutputStep
|
|
5
|
+
|
|
6
|
+
from ..channels import Channel
|
|
7
|
+
from ..delegates import ReadPickleStep, SerializeJobModel
|
|
8
|
+
from ..files import FileSystem
|
|
9
|
+
from ..types import SerializationRequestMessage, SerializationResultMessage
|
|
10
|
+
from .action import Action
|
|
11
|
+
|
|
12
|
+
__all__ = ["SerializeJobAction"]


logger = logging.getLogger(__name__)


class SerializeJobAction(Action[SerializationRequestMessage]):
    """Serialize the results of a job into a requested output format.

    Listens on the serialization requests topic. It reads the job's result
    checkpoint files and runs the module's post-processing steps to write the
    output file. It then announces completion on the serialization results topic.
    """

    def __init__(self, channel: Channel, data_dir: str) -> None:
        super().__init__(channel.serialization_requests_topic())
        self._file_system = FileSystem(data_dir)

    async def _process_message(self, message: SerializationRequestMessage) -> None:
        """Write the output file for one job and send a serialization result."""
        job_id = message.job_id
        job_type = message.job_type
        # work on a copy so that sanitizing the parameters below does not mutate the
        # received message
        params = dict(message.params)
        output_format = message.output_format
        logger.info(f"Write output for job {job_id} in format {output_format}")

        # remove specific parameter keys that could induce vulnerabilities
        params.pop("output_file", None)
        params.pop("output_format", None)

        # obtain output file
        output_file = self._file_system.get_output_file(job_id, output_format)

        # get the configuration for the job_type (close the file deterministically)
        config_file = self._file_system.get_module_file_path(job_type)
        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)

        # create a fake model instance to get the postprocessing steps
        model = SerializeJobModel(config)

        steps = [
            # read the result checkpoint files in the correct order
            ReadPickleStep(self._file_system.iter_results_file_handles(job_id)),
            # don't preprocess, don't do prediction, only post-process
            *model._get_postprocessing_steps(output_format, output_file=output_file, **params),
        ]

        output_step = steps[-1]
        assert isinstance(output_step, OutputStep), "The last step must be an OutputStep."

        # build the pipeline from the list of steps (each step wraps its predecessor)
        pipeline = None
        for t in steps:
            pipeline = t(pipeline)

        # run the pipeline by calling the get_result method of the last step
        output_step.get_result()

        await self.channel.serialization_results_topic().send(
            SerializationResultMessage(job_id=job_id, output_format=output_format)
        )
|