nerdd-link 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nerdd_link-0.1.0/LICENSE +21 -0
- nerdd_link-0.1.0/PKG-INFO +116 -0
- nerdd_link-0.1.0/README.md +39 -0
- nerdd_link-0.1.0/nerdd_link/__init__.py +4 -0
- nerdd_link-0.1.0/nerdd_link/actions/__init__.py +5 -0
- nerdd_link-0.1.0/nerdd_link/actions/action.py +30 -0
- nerdd_link-0.1.0/nerdd_link/actions/predict_checkpoints_action.py +59 -0
- nerdd_link-0.1.0/nerdd_link/actions/process_jobs_action.py +138 -0
- nerdd_link-0.1.0/nerdd_link/actions/register_module_action.py +30 -0
- nerdd_link-0.1.0/nerdd_link/actions/write_output_action.py +21 -0
- nerdd_link-0.1.0/nerdd_link/channels/__init__.py +3 -0
- nerdd_link-0.1.0/nerdd_link/channels/channel.py +114 -0
- nerdd_link-0.1.0/nerdd_link/channels/kafka_channel.py +97 -0
- nerdd_link-0.1.0/nerdd_link/channels/memory_channel.py +33 -0
- nerdd_link-0.1.0/nerdd_link/cli/__init__.py +3 -0
- nerdd_link-0.1.0/nerdd_link/cli/initialize_system.py +45 -0
- nerdd_link-0.1.0/nerdd_link/cli/run_job_server.py +107 -0
- nerdd_link-0.1.0/nerdd_link/cli/run_prediction_server.py +81 -0
- nerdd_link-0.1.0/nerdd_link/delegates/__init__.py +3 -0
- nerdd_link-0.1.0/nerdd_link/delegates/pickle_writer.py +18 -0
- nerdd_link-0.1.0/nerdd_link/delegates/read_checkpoint_model.py +65 -0
- nerdd_link-0.1.0/nerdd_link/delegates/read_pickle_step.py +18 -0
- nerdd_link-0.1.0/nerdd_link/delegates/split_and_merge_step.py +51 -0
- nerdd_link-0.1.0/nerdd_link/delegates/topic_writer.py +27 -0
- nerdd_link-0.1.0/nerdd_link/input/__init__.py +1 -0
- nerdd_link-0.1.0/nerdd_link/input/structure_json_reader.py +41 -0
- nerdd_link-0.1.0/nerdd_link/py.typed +0 -0
- nerdd_link-0.1.0/nerdd_link/tests/__init__.py +3 -0
- nerdd_link-0.1.0/nerdd_link/tests/async_step.py +22 -0
- nerdd_link-0.1.0/nerdd_link/tests/channels.py +82 -0
- nerdd_link-0.1.0/nerdd_link/tests/files.py +9 -0
- nerdd_link-0.1.0/nerdd_link/types/__init__.py +54 -0
- nerdd_link-0.1.0/nerdd_link/utils/__init__.py +4 -0
- nerdd_link-0.1.0/nerdd_link/utils/async_to_sync.py +26 -0
- nerdd_link-0.1.0/nerdd_link/utils/batched.py +27 -0
- nerdd_link-0.1.0/nerdd_link/utils/observable_list.py +72 -0
- nerdd_link-0.1.0/nerdd_link/utils/safetee.py +39 -0
- nerdd_link-0.1.0/nerdd_link/version.py +11 -0
- nerdd_link-0.1.0/nerdd_link.egg-info/PKG-INFO +116 -0
- nerdd_link-0.1.0/nerdd_link.egg-info/SOURCES.txt +45 -0
- nerdd_link-0.1.0/nerdd_link.egg-info/dependency_links.txt +1 -0
- nerdd_link-0.1.0/nerdd_link.egg-info/entry_points.txt +4 -0
- nerdd_link-0.1.0/nerdd_link.egg-info/requires.txt +34 -0
- nerdd_link-0.1.0/nerdd_link.egg-info/top_level.txt +1 -0
- nerdd_link-0.1.0/pyproject.toml +137 -0
- nerdd_link-0.1.0/setup.cfg +4 -0
- nerdd_link-0.1.0/tests/test_features.py +3 -0
nerdd_link-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2023 Molecular Informatics Vienna
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: nerdd-link
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Run a NERDD module as a service
|
|
5
|
+
Author-email: Steffen Hirte <steffen.hirte@univie.ac.at>
|
|
6
|
+
Maintainer-email: Steffen Hirte <steffen.hirte@univie.ac.at>
|
|
7
|
+
License: MIT License
|
|
8
|
+
|
|
9
|
+
Copyright (c) 2023 Molecular Informatics Vienna
|
|
10
|
+
|
|
11
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
12
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
13
|
+
in the Software without restriction, including without limitation the rights
|
|
14
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
15
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
16
|
+
furnished to do so, subject to the following conditions:
|
|
17
|
+
|
|
18
|
+
The above copyright notice and this permission notice shall be included in all
|
|
19
|
+
copies or substantial portions of the Software.
|
|
20
|
+
|
|
21
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
22
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
23
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
24
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
25
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
26
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
27
|
+
SOFTWARE.
|
|
28
|
+
|
|
29
|
+
Project-URL: Repository, https://github.com/molinfo-vienna/nerdd-link
|
|
30
|
+
Keywords: science,research,development,nerdd
|
|
31
|
+
Classifier: Intended Audience :: Science/Research
|
|
32
|
+
Classifier: Intended Audience :: Developers
|
|
33
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
34
|
+
Classifier: Programming Language :: Python
|
|
35
|
+
Classifier: Topic :: Software Development
|
|
36
|
+
Classifier: Topic :: Scientific/Engineering
|
|
37
|
+
Classifier: Operating System :: Microsoft :: Windows
|
|
38
|
+
Classifier: Operating System :: POSIX
|
|
39
|
+
Classifier: Operating System :: Unix
|
|
40
|
+
Classifier: Operating System :: MacOS
|
|
41
|
+
Classifier: Programming Language :: Python :: 3
|
|
42
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
43
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
44
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
45
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
46
|
+
Description-Content-Type: text/markdown
|
|
47
|
+
License-File: LICENSE
|
|
48
|
+
Requires-Dist: nerdd-module>=0.3.6
|
|
49
|
+
Requires-Dist: pandas>=1.2.1
|
|
50
|
+
Requires-Dist: pyyaml~=6.0
|
|
51
|
+
Requires-Dist: filetype~=1.2.0
|
|
52
|
+
Requires-Dist: rich-click>=1.7.1
|
|
53
|
+
Requires-Dist: stringcase~=1.2.0
|
|
54
|
+
Requires-Dist: numpy
|
|
55
|
+
Requires-Dist: simplejson>=3
|
|
56
|
+
Requires-Dist: pydantic>=2
|
|
57
|
+
Requires-Dist: aiokafka>=0.11.0
|
|
58
|
+
Requires-Dist: importlib-metadata>=4.6; python_version < "3.10"
|
|
59
|
+
Provides-Extra: dev
|
|
60
|
+
Requires-Dist: mypy; extra == "dev"
|
|
61
|
+
Requires-Dist: ruff==0.8.0; extra == "dev"
|
|
62
|
+
Requires-Dist: pre-commit>=2; extra == "dev"
|
|
63
|
+
Provides-Extra: test
|
|
64
|
+
Requires-Dist: pytest; extra == "test"
|
|
65
|
+
Requires-Dist: pytest-sugar; extra == "test"
|
|
66
|
+
Requires-Dist: pytest-cov; extra == "test"
|
|
67
|
+
Requires-Dist: pytest-asyncio; extra == "test"
|
|
68
|
+
Requires-Dist: pytest-bdd==7.3.0; extra == "test"
|
|
69
|
+
Requires-Dist: pytest-mock; extra == "test"
|
|
70
|
+
Requires-Dist: pytest-watcher; extra == "test"
|
|
71
|
+
Requires-Dist: hypothesis; extra == "test"
|
|
72
|
+
Requires-Dist: hypothesis-rdkit; extra == "test"
|
|
73
|
+
Provides-Extra: docs
|
|
74
|
+
Requires-Dist: mkdocs; extra == "docs"
|
|
75
|
+
Requires-Dist: mkdocs-material; extra == "docs"
|
|
76
|
+
Requires-Dist: mkdocstrings; extra == "docs"
|
|
77
|
+
|
|
78
|
+
# NERDD-Link
|
|
79
|
+
|
|
80
|
+
Run a [NERDD module](https://github.com/molinfo-vienna/nerdd-module) as a
|
|
81
|
+
service that consumes input molecules and produces prediction tuples.
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
## Installation
|
|
85
|
+
|
|
86
|
+
```bash
|
|
87
|
+
pip install -U nerdd-link
|
|
88
|
+
```
|
|
89
|
+
|
|
90
|
+
## Usage
|
|
91
|
+
|
|
92
|
+
When a class inherits from ```nerdd_module.AbstractModel``` (see
|
|
93
|
+
[NERDD Module Github page](https://github.com/molinfo-vienna/nerdd-module)), it can be
|
|
94
|
+
used to create a Kafka service.
|
|
95
|
+
|
|
96
|
+
```bash
|
|
97
|
+
# run a Kafka service for NerddModel on localhost:9092
|
|
98
|
+
run_nerdd_server package.path.to.NerddModel
|
|
99
|
+
|
|
100
|
+
# modify broker url, input topic and batch size
|
|
101
|
+
run_nerdd_server package.path.to.NerddModel \
|
|
102
|
+
--broker-url my-cluster-kafka-bootstrap.kafka:9092 \
|
|
103
|
+
--input-topic examples \
|
|
104
|
+
--batch-size 10
|
|
105
|
+
|
|
106
|
+
# more information via --help
|
|
107
|
+
run_nerdd_server --help
|
|
108
|
+
```
|
|
109
|
+
|
|
110
|
+
If the model class is called ```ExamplePredictionModel```, the server will read input
|
|
111
|
+
tuples from the input topic ```example-prediction-inputs``` in batches of size 100
|
|
112
|
+
and write results to the ```results``` topic. The batch size specifies the number
|
|
113
|
+
of input tuples that are given to the model at once.
|
|
114
|
+
|
|
115
|
+
## Communication
|
|
116
|
+
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
# NERDD-Link
|
|
2
|
+
|
|
3
|
+
Run a [NERDD module](https://github.com/molinfo-vienna/nerdd-module) as a
|
|
4
|
+
service that consumes input molecules and produces prediction tuples.
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
## Installation
|
|
8
|
+
|
|
9
|
+
```bash
|
|
10
|
+
pip install -U nerdd-link
|
|
11
|
+
```
|
|
12
|
+
|
|
13
|
+
## Usage
|
|
14
|
+
|
|
15
|
+
When a class inherits from ```nerdd_module.AbstractModel``` (see
|
|
16
|
+
[NERDD Module Github page](https://github.com/molinfo-vienna/nerdd-module)), it can be
|
|
17
|
+
used to create a Kafka service.
|
|
18
|
+
|
|
19
|
+
```bash
|
|
20
|
+
# run a Kafka service for NerddModel on localhost:9092
|
|
21
|
+
run_nerdd_server package.path.to.NerddModel
|
|
22
|
+
|
|
23
|
+
# modify broker url, input topic and batch size
|
|
24
|
+
run_nerdd_server package.path.to.NerddModel \
|
|
25
|
+
--broker-url my-cluster-kafka-bootstrap.kafka:9092 \
|
|
26
|
+
--input-topic examples \
|
|
27
|
+
--batch-size 10
|
|
28
|
+
|
|
29
|
+
# more information via --help
|
|
30
|
+
run_nerdd_server --help
|
|
31
|
+
```
|
|
32
|
+
|
|
33
|
+
If the model class is called ```ExamplePredictionModel```, the server will read input
|
|
34
|
+
tuples from the input topic ```example-prediction-inputs``` in batches of size 100
|
|
35
|
+
and write results to the ```results``` topic. The batch size specifies the number
|
|
36
|
+
of input tuples that are given to the model at once.
|
|
37
|
+
|
|
38
|
+
## Communication
|
|
39
|
+
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Generic, TypeVar
|
|
3
|
+
|
|
4
|
+
from stringcase import spinalcase
|
|
5
|
+
|
|
6
|
+
from ..channels import Channel, Topic
|
|
7
|
+
from ..types import Message
|
|
8
|
+
|
|
9
|
+
T = TypeVar("T", bound=Message)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Action(ABC, Generic[T]):
    """Base class for workers that handle every message of one input topic.

    Subclasses implement ``_process_message`` and may override
    ``_get_group_name`` to choose the consumer group used when receiving.
    """

    def __init__(self, input_topic: Topic[T]):
        self._input_topic = input_topic

    async def run(self) -> None:
        """Receive messages from the input topic and process them one by one."""
        # The group name is converted to spinal case before it is handed to
        # the topic.
        group = spinalcase(self._get_group_name())
        incoming = self._input_topic.receive(group)
        async for item in incoming:
            await self._process_message(item)

    @abstractmethod
    async def _process_message(self, message: T) -> None:
        """Handle a single message received from the input topic."""

    @property
    def channel(self) -> Channel:
        """The channel that the input topic belongs to."""
        return self._input_topic.channel

    def _get_group_name(self) -> str:
        """Return the consumer group name; defaults to the concrete class name."""
        return type(self).__name__
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from nerdd_module import Model
|
|
5
|
+
|
|
6
|
+
from ..channels import Channel
|
|
7
|
+
from ..delegates import ReadCheckpointModel
|
|
8
|
+
from ..types import CheckpointMessage
|
|
9
|
+
from .action import Action
|
|
10
|
+
|
|
11
|
+
__all__ = ["PredictCheckpointsAction"]
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class PredictCheckpointsAction(Action[CheckpointMessage]):
    """Run the model on checkpoints announced on the model's checkpoints topic.

    Each message names a job and a checkpoint. The matching checkpoint file is
    handed to a ``ReadCheckpointModel`` wrapping the base model, together with
    the path where the results of that checkpoint should be stored.
    """

    def __init__(self, channel: Channel, model: Model, data_dir: str) -> None:
        super().__init__(channel.checkpoints_topic(model))
        self.model = model
        self.data_dir = data_dir

    async def _process_message(self, message: CheckpointMessage) -> None:
        """Predict the checkpoint referenced by ``message``."""
        job_id = message.job_id
        checkpoint_id = message.checkpoint_id
        params = message.params
        logger.info(f"Predict checkpoint {checkpoint_id} of job {job_id}")

        # Layout below data_dir:
        #   jobs/<job_id>/input/checkpoint_<id>.pickle    (input batch)
        #   jobs/<job_id>/results/checkpoint_<id>.pickle  (results of the batch)
        job_dir = f"{self.data_dir}/jobs/{job_id}"
        results_dir = f"{job_dir}/results"
        input_path = f"{job_dir}/input/checkpoint_{checkpoint_id}.pickle"
        output_path = f"{results_dir}/checkpoint_{checkpoint_id}.pickle"

        # make sure the results directory exists
        os.makedirs(results_dir, exist_ok=True)

        # wrap the base model so that it takes its input from the checkpoint file
        checkpoint_model = ReadCheckpointModel(
            base_model=self.model,
            job_id=job_id,
            checkpoint_id=checkpoint_id,
            channel=self.channel,
            checkpoints_file=input_path,
            results_file=output_path,
        )

        # no explicit input is passed; the job parameters are forwarded as
        # keyword arguments
        checkpoint_model.predict(input=None, **params)

    def _get_group_name(self) -> str:
        """Use the class name of the wrapped model as consumer group name."""
        return self.model.__class__.__name__
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from pickle import dump
|
|
4
|
+
|
|
5
|
+
from nerdd_module.input import DepthFirstExplorer
|
|
6
|
+
from nerdd_module.model import ReadInputStep
|
|
7
|
+
|
|
8
|
+
from ..channels import Channel
|
|
9
|
+
from ..types import CheckpointMessage, JobMessage, LogMessage
|
|
10
|
+
from ..utils import batched
|
|
11
|
+
from .action import Action
|
|
12
|
+
|
|
13
|
+
__all__ = ["ProcessJobsAction"]
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ProcessJobsAction(Action[JobMessage]):
    """Split incoming jobs into checkpoint files and announce them.

    For each job received on the jobs topic, the entries of the job's source
    are read, grouped into batches of ``checkpoint_size`` and written as
    pickle files to ``<data_dir>/jobs/<job_id>/input/``. For every batch a
    ``CheckpointMessage`` is sent to the checkpoints topic of the job type.
    Finally, the number of stored entries is reported on the logs topic
    (``message_type="report_job_size"``); a warning is sent on the same topic
    if the job contained more than ``max_num_molecules`` entries.
    """

    def __init__(
        self,
        channel: Channel,
        checkpoint_size: int,
        max_num_molecules: int,
        num_test_entries: int,
        ratio_valid_entries: float,
        maximum_depth: int,
        max_num_lines_mol_block: int,
        data_dir: str,
    ) -> None:
        super().__init__(channel.jobs_topic())
        # relevant for chunking
        self.checkpoint_size = checkpoint_size
        self.max_num_molecules = max_num_molecules
        # parameters of DepthFirstExplorer
        self.num_test_entries = num_test_entries
        self.ratio_valid_entries = ratio_valid_entries
        self.maximum_depth = maximum_depth
        # used as kwargs in DepthFirstExplorer
        self.max_num_lines_mol_block = max_num_lines_mol_block
        self.data_dir = data_dir

    async def _process_message(self, message: JobMessage) -> None:
        """Write the checkpoints of one job and send the related messages."""
        job_id = message.id
        job_type = message.job_type
        logger.info(f"Received a new job {job_id} of type {job_type}")

        # the input file to the job is stored in the directory data_dir/sources/
        # (the file is allowed to reference other files, but setting the data_dir
        # to the sources directory ensures that we never read files outside of the
        # sources directory)
        sources_dir = os.path.join(self.data_dir, "sources")

        # create a reader (explorer) for the input file
        explorer = DepthFirstExplorer(
            num_test_entries=self.num_test_entries,
            threshold=self.ratio_valid_entries,
            maximum_depth=self.maximum_depth,
            # extra args
            max_num_lines_mol_block=self.max_num_lines_mol_block,
            data_dir=sources_dir,
        )

        read_input_step = ReadInputStep(explorer, message.source_id)

        # create a directory for the job
        input_dir = f"{self.data_dir}/jobs/{job_id}/input"
        os.makedirs(input_dir, exist_ok=True)

        # read the input file
        entries = read_input_step()

        # iterate through the entries
        # create batches of size checkpoint_size
        # limit the number of molecules to max_num_molecules
        batches = batched(entries, self.checkpoint_size)

        # These counters are initialised up front so that a job without any
        # entries (the loop body never runs) is handled instead of failing
        # with a NameError on the loop variables.
        num_entries = 0
        num_checkpoints = 0
        too_many_molecules = False

        for i, batch in enumerate(batches):
            # max_num_molecules might be reached within the batch
            num_store = min(len(batch), self.max_num_molecules - num_entries)

            # store batch in data_dir
            with open(f"{input_dir}/checkpoint_{i}.pickle", "wb") as f:
                dump(batch[:num_store], f)

            # announce the checkpoint on the "<job_type>-checkpoints" topic
            await self.channel.checkpoints_topic(job_type).send(
                CheckpointMessage(
                    job_id=job_id,
                    checkpoint_id=i,
                    params=message.params,
                )
            )

            num_entries += num_store
            num_checkpoints = i + 1

            # part of this batch had to be dropped -> the job was too large
            if num_store < len(batch):
                too_many_molecules = True

            if num_entries >= self.max_num_molecules:
                break

        logger.info(
            f"Wrote {num_checkpoints} checkpoints containing {num_entries} entries "
            f"for job {job_id}"
        )

        # send a warning message if there were more molecules in the job than allowed
        try:
            # try to get another entry
            next(entries)

            # if we get here, there was another entry and we need to send a warning
            too_many_molecules = True
        except StopIteration:
            pass

        if too_many_molecules:
            await self.channel.logs_topic().send(
                LogMessage(
                    job_id=job_id,
                    message_type="warning",
                    message=(
                        f"The provided job contains more than "
                        f"{self.max_num_molecules} input structures. Only the "
                        f"first {self.max_num_molecules} will be processed."
                    ),
                )
            )

        # at the end, report the overall size of the job on the logs topic
        await self.channel.logs_topic().send(
            LogMessage(
                job_id=job_id,
                message_type="report_job_size",
                size=num_entries,
            )
        )
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from nerdd_module import Model
|
|
4
|
+
from stringcase import spinalcase
|
|
5
|
+
|
|
6
|
+
from ..channels import Channel
|
|
7
|
+
from ..types import ModuleMessage, SystemMessage
|
|
8
|
+
from .action import Action
|
|
9
|
+
|
|
10
|
+
__all__ = ["RegisterModuleAction"]
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class RegisterModuleAction(Action[SystemMessage]):
    """Publish the model's configuration whenever a system message arrives."""

    def __init__(self, channel: Channel, model: Model):
        super().__init__(channel.system_topic())
        # TODO: do this differently
        assert hasattr(model, "get_config")
        self._model = model

    async def _process_message(self, message: SystemMessage) -> None:
        """Send a ``ModuleMessage`` built from the model config to the modules topic."""
        module_config = self._model.get_config()
        logger.info(f"Send registration message for module {module_config.name}")

        modules = self.channel.modules_topic()
        registration = ModuleMessage(**module_config.model_dump())
        await modules.send(registration)

    def _get_group_name(self) -> str:
        """Derive the consumer group name from the model's class name."""
        return spinalcase(type(self._model).__name__)
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from nerdd_module import Model
|
|
2
|
+
from stringcase import spinalcase
|
|
3
|
+
|
|
4
|
+
from ..channels import Channel
|
|
5
|
+
from ..types import SystemMessage
|
|
6
|
+
from .action import Action
|
|
7
|
+
|
|
8
|
+
__all__ = ["WriteOutputAction"]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class WriteOutputAction(Action[SystemMessage]):
    """Action on the system topic that currently ignores every message."""

    def __init__(self, channel: Channel, model: Model):
        super().__init__(channel.system_topic())
        self._model = model

    async def _process_message(self, message: SystemMessage) -> None:
        """Do nothing (no processing is implemented here)."""
        return None

    def _get_group_name(self) -> str:
        """Derive the consumer group name from the model's class name."""
        return spinalcase(type(self._model).__name__)
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import AsyncIterable, Generic, TypeVar, Union, cast
|
|
5
|
+
|
|
6
|
+
from nerdd_module import Model
|
|
7
|
+
from stringcase import spinalcase # type: ignore
|
|
8
|
+
|
|
9
|
+
from ..types import (
|
|
10
|
+
CheckpointMessage,
|
|
11
|
+
JobMessage,
|
|
12
|
+
LogMessage,
|
|
13
|
+
Message,
|
|
14
|
+
ModuleMessage,
|
|
15
|
+
ResultCheckpointMessage,
|
|
16
|
+
ResultMessage,
|
|
17
|
+
SystemMessage,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
__all__ = ["Channel", "Topic"]
|
|
21
|
+
|
|
22
|
+
T = TypeVar("T", bound=Message)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def get_job_type(job_type_or_model: Union[str, Model]) -> str:
    """Return the job type used as prefix of job-specific topic names.

    A string is converted to spinal case and returned as is. For a model, the
    model name is converted to spinal case (e.g. "MyModel" -> "my-model"),
    lowercased (just to be sure) and stripped of every character that is
    neither alphanumeric nor a dash.
    """
    # NOTE(review): plain strings are only spinal-cased, not lowercased or
    # filtered like model names -- confirm that callers pass clean job types.
    if not isinstance(job_type_or_model, Model):
        return spinalcase(job_type_or_model)

    normalized = spinalcase(job_type_or_model.name).lower()
    return "".join(ch for ch in normalized if ch.isalnum() or ch == "-")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class Topic(Generic[T]):
    """A named topic on a channel, typed by the messages it carries."""

    def __init__(self, channel: Channel, name: str):
        self._channel = channel
        self._name = name

    async def receive(self, consumer_group: str) -> AsyncIterable[T]:
        """Yield the messages of this topic for the given consumer group."""
        stream = self.channel.iter_messages(self._name, consumer_group)
        async for raw in stream:
            # the channel yields plain messages; the cast only informs the
            # type checker
            yield cast(T, raw)

    async def send(self, message: T) -> None:
        """Send ``message`` to this topic via the owning channel."""
        await self.channel.send(self._name, message)

    @property
    def channel(self) -> Channel:
        """The channel this topic was created on."""
        return self._channel

    def __repr__(self) -> str:
        return f"Topic({self._name})"
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class Channel(ABC):
    """Abstract message channel with typed accessors for the known topics.

    Implementations provide ``_iter_messages`` and ``_send``; the public
    ``iter_messages`` and ``send`` methods delegate to them.
    """

    #
    # RECEIVE
    #
    async def iter_messages(self, topic: str, consumer_group: str) -> AsyncIterable[Message]:
        """Yield the messages of ``topic`` for ``consumer_group``."""
        source = self._iter_messages(topic, consumer_group)
        async for item in source:
            yield item

    # This is deliberately declared with "def" instead of "async def": the
    # abstract body contains no "yield", so a type checker would otherwise
    # infer Coroutine[AsyncIterable[Message], None, None] as the actual type.
    @abstractmethod
    def _iter_messages(self, topic: str, consumer_group: str) -> AsyncIterable[Message]:
        """Return an async iterable over the messages of ``topic``."""

    #
    # SEND
    #
    async def send(self, topic: str, message: Message) -> None:
        """Send ``message`` to ``topic``."""
        await self._send(topic, message)

    @abstractmethod
    async def _send(self, topic: str, message: Message) -> None:
        """Deliver ``message`` to ``topic`` (implemented by subclasses)."""

    #
    # TOPICS
    #
    def modules_topic(self) -> Topic[ModuleMessage]:
        """Topic named "modules"."""
        return Topic[ModuleMessage](self, "modules")

    def jobs_topic(self) -> Topic[JobMessage]:
        """Topic named "jobs"."""
        return Topic[JobMessage](self, "jobs")

    def checkpoints_topic(self, job_type_or_model: Union[str, Model]) -> Topic[CheckpointMessage]:
        """Topic named "<job-type>-checkpoints"."""
        prefix = get_job_type(job_type_or_model)
        return Topic[CheckpointMessage](self, f"{prefix}-checkpoints")

    def results_topic(self) -> Topic[ResultMessage]:
        """Topic named "results"."""
        return Topic[ResultMessage](self, "results")

    def result_checkpoints_topic(
        self, job_type_or_model: Union[str, Model]
    ) -> Topic[ResultCheckpointMessage]:
        """Topic named "<job-type>-result-checkpoints"."""
        prefix = get_job_type(job_type_or_model)
        return Topic[ResultCheckpointMessage](self, f"{prefix}-result-checkpoints")

    def logs_topic(self) -> Topic[LogMessage]:
        """Topic named "logs"."""
        return Topic[LogMessage](self, "logs")

    def system_topic(self) -> Topic[SystemMessage]:
        """Topic named "system"."""
        return Topic[SystemMessage](self, "system")
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
from typing import AsyncIterable, Dict, Tuple
|
|
5
|
+
|
|
6
|
+
from aiokafka import AIOKafkaConsumer, AIOKafkaProducer
|
|
7
|
+
|
|
8
|
+
from ..types import Message
|
|
9
|
+
from .channel import Channel
|
|
10
|
+
|
|
11
|
+
__all__ = ["KafkaChannel"]
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class KafkaChannel(Channel):
    """Channel implementation that talks to a Kafka broker via aiokafka.

    One producer is created per channel. Consumers are created lazily, one
    per (topic, consumer group) pair, when messages are iterated.

    The constructor schedules the producer start-up as an asyncio task, so it
    has to be called while an event loop is running.
    """

    def __init__(self, broker_url: str) -> None:
        super().__init__()
        self._broker_url = broker_url
        self._consumers: Dict[Tuple[str, str], AIOKafkaConsumer] = {}

        self._producer = AIOKafkaProducer(
            bootstrap_servers=[self._broker_url],
        )
        # TODO: check value_serializer
        # Keep a reference to the start-up task: the event loop only holds
        # weak references to tasks, so an unreferenced task may be garbage
        # collected before it finishes. The reference also lets _send wait
        # until the producer is actually started.
        self._producer_start_task = asyncio.create_task(self._producer.start())
        logger.info(f"Connecting to Kafka broker {self._broker_url} and starting a producer.")

    async def _iter_messages(self, topic: str, consumer_group: str) -> AsyncIterable[Message]:
        """Yield the messages of ``topic``, committing after each one is handled.

        The offset of a message is committed only after the caller has resumed
        the generator, i.e. after it has finished with the yielded message.
        """
        if consumer_group is not None:
            consumer_group = f"{consumer_group}-consumer-group"

        key = (topic, consumer_group)

        if key not in self._consumers:
            # create consumer
            consumer = AIOKafkaConsumer(
                topic,
                bootstrap_servers=[self._broker_url],
                auto_offset_reset="earliest",
                group_id=consumer_group,
                enable_auto_commit=False,
            )
            await consumer.start()
            self._consumers[key] = consumer
            logger.info(
                f"Connecting to Kafka broker {self._broker_url} and starting a consumer on "
                f"topic {topic}."
            )

        consumer = self._consumers[key]

        try:
            async for message in consumer:
                message_obj = json.loads(message.value)
                yield Message(**message_obj)
                await consumer.commit()
        finally:
            # Drop the consumer from the cache before stopping it; otherwise
            # a later call with the same key would reuse a stopped consumer.
            self._consumers.pop(key, None)
            await consumer.stop()

    async def _send(self, topic: str, message: Message) -> None:
        """Serialize ``message`` as UTF-8 encoded JSON and send it to ``topic``."""
        # Make sure the producer has finished starting. Awaiting a completed
        # task returns immediately and re-raises a start-up failure, if any.
        await self._producer_start_task
        await self._producer.send_and_wait(
            topic,
            json.dumps(message.model_dump()).encode("utf-8"),
        )
|