nerdd-link 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. nerdd_link-0.1.0/LICENSE +21 -0
  2. nerdd_link-0.1.0/PKG-INFO +116 -0
  3. nerdd_link-0.1.0/README.md +39 -0
  4. nerdd_link-0.1.0/nerdd_link/__init__.py +4 -0
  5. nerdd_link-0.1.0/nerdd_link/actions/__init__.py +5 -0
  6. nerdd_link-0.1.0/nerdd_link/actions/action.py +30 -0
  7. nerdd_link-0.1.0/nerdd_link/actions/predict_checkpoints_action.py +59 -0
  8. nerdd_link-0.1.0/nerdd_link/actions/process_jobs_action.py +138 -0
  9. nerdd_link-0.1.0/nerdd_link/actions/register_module_action.py +30 -0
  10. nerdd_link-0.1.0/nerdd_link/actions/write_output_action.py +21 -0
  11. nerdd_link-0.1.0/nerdd_link/channels/__init__.py +3 -0
  12. nerdd_link-0.1.0/nerdd_link/channels/channel.py +114 -0
  13. nerdd_link-0.1.0/nerdd_link/channels/kafka_channel.py +97 -0
  14. nerdd_link-0.1.0/nerdd_link/channels/memory_channel.py +33 -0
  15. nerdd_link-0.1.0/nerdd_link/cli/__init__.py +3 -0
  16. nerdd_link-0.1.0/nerdd_link/cli/initialize_system.py +45 -0
  17. nerdd_link-0.1.0/nerdd_link/cli/run_job_server.py +107 -0
  18. nerdd_link-0.1.0/nerdd_link/cli/run_prediction_server.py +81 -0
  19. nerdd_link-0.1.0/nerdd_link/delegates/__init__.py +3 -0
  20. nerdd_link-0.1.0/nerdd_link/delegates/pickle_writer.py +18 -0
  21. nerdd_link-0.1.0/nerdd_link/delegates/read_checkpoint_model.py +65 -0
  22. nerdd_link-0.1.0/nerdd_link/delegates/read_pickle_step.py +18 -0
  23. nerdd_link-0.1.0/nerdd_link/delegates/split_and_merge_step.py +51 -0
  24. nerdd_link-0.1.0/nerdd_link/delegates/topic_writer.py +27 -0
  25. nerdd_link-0.1.0/nerdd_link/input/__init__.py +1 -0
  26. nerdd_link-0.1.0/nerdd_link/input/structure_json_reader.py +41 -0
  27. nerdd_link-0.1.0/nerdd_link/py.typed +0 -0
  28. nerdd_link-0.1.0/nerdd_link/tests/__init__.py +3 -0
  29. nerdd_link-0.1.0/nerdd_link/tests/async_step.py +22 -0
  30. nerdd_link-0.1.0/nerdd_link/tests/channels.py +82 -0
  31. nerdd_link-0.1.0/nerdd_link/tests/files.py +9 -0
  32. nerdd_link-0.1.0/nerdd_link/types/__init__.py +54 -0
  33. nerdd_link-0.1.0/nerdd_link/utils/__init__.py +4 -0
  34. nerdd_link-0.1.0/nerdd_link/utils/async_to_sync.py +26 -0
  35. nerdd_link-0.1.0/nerdd_link/utils/batched.py +27 -0
  36. nerdd_link-0.1.0/nerdd_link/utils/observable_list.py +72 -0
  37. nerdd_link-0.1.0/nerdd_link/utils/safetee.py +39 -0
  38. nerdd_link-0.1.0/nerdd_link/version.py +11 -0
  39. nerdd_link-0.1.0/nerdd_link.egg-info/PKG-INFO +116 -0
  40. nerdd_link-0.1.0/nerdd_link.egg-info/SOURCES.txt +45 -0
  41. nerdd_link-0.1.0/nerdd_link.egg-info/dependency_links.txt +1 -0
  42. nerdd_link-0.1.0/nerdd_link.egg-info/entry_points.txt +4 -0
  43. nerdd_link-0.1.0/nerdd_link.egg-info/requires.txt +34 -0
  44. nerdd_link-0.1.0/nerdd_link.egg-info/top_level.txt +1 -0
  45. nerdd_link-0.1.0/pyproject.toml +137 -0
  46. nerdd_link-0.1.0/setup.cfg +4 -0
  47. nerdd_link-0.1.0/tests/test_features.py +3 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Molecular Informatics Vienna
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,116 @@
1
+ Metadata-Version: 2.1
2
+ Name: nerdd-link
3
+ Version: 0.1.0
4
+ Summary: Run a NERDD module as a service
5
+ Author-email: Steffen Hirte <steffen.hirte@univie.ac.at>
6
+ Maintainer-email: Steffen Hirte <steffen.hirte@univie.ac.at>
7
+ License: MIT License
8
+
9
+ Copyright (c) 2023 Molecular Informatics Vienna
10
+
11
+ Permission is hereby granted, free of charge, to any person obtaining a copy
12
+ of this software and associated documentation files (the "Software"), to deal
13
+ in the Software without restriction, including without limitation the rights
14
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15
+ copies of the Software, and to permit persons to whom the Software is
16
+ furnished to do so, subject to the following conditions:
17
+
18
+ The above copyright notice and this permission notice shall be included in all
19
+ copies or substantial portions of the Software.
20
+
21
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
27
+ SOFTWARE.
28
+
29
+ Project-URL: Repository, https://github.com/molinfo-vienna/nerdd-link
30
+ Keywords: science,research,development,nerdd
31
+ Classifier: Intended Audience :: Science/Research
32
+ Classifier: Intended Audience :: Developers
33
+ Classifier: License :: OSI Approved :: MIT License
34
+ Classifier: Programming Language :: Python
35
+ Classifier: Topic :: Software Development
36
+ Classifier: Topic :: Scientific/Engineering
37
+ Classifier: Operating System :: Microsoft :: Windows
38
+ Classifier: Operating System :: POSIX
39
+ Classifier: Operating System :: Unix
40
+ Classifier: Operating System :: MacOS
41
+ Classifier: Programming Language :: Python :: 3
42
+ Classifier: Programming Language :: Python :: 3.9
43
+ Classifier: Programming Language :: Python :: 3.10
44
+ Classifier: Programming Language :: Python :: 3.11
45
+ Classifier: Programming Language :: Python :: 3.12
46
+ Description-Content-Type: text/markdown
47
+ License-File: LICENSE
48
+ Requires-Dist: nerdd-module>=0.3.6
49
+ Requires-Dist: pandas>=1.2.1
50
+ Requires-Dist: pyyaml~=6.0
51
+ Requires-Dist: filetype~=1.2.0
52
+ Requires-Dist: rich-click>=1.7.1
53
+ Requires-Dist: stringcase~=1.2.0
54
+ Requires-Dist: numpy
55
+ Requires-Dist: simplejson>=3
56
+ Requires-Dist: pydantic>=2
57
+ Requires-Dist: aiokafka>=0.11.0
58
+ Requires-Dist: importlib-metadata>=4.6; python_version < "3.10"
59
+ Provides-Extra: dev
60
+ Requires-Dist: mypy; extra == "dev"
61
+ Requires-Dist: ruff==0.8.0; extra == "dev"
62
+ Requires-Dist: pre-commit>=2; extra == "dev"
63
+ Provides-Extra: test
64
+ Requires-Dist: pytest; extra == "test"
65
+ Requires-Dist: pytest-sugar; extra == "test"
66
+ Requires-Dist: pytest-cov; extra == "test"
67
+ Requires-Dist: pytest-asyncio; extra == "test"
68
+ Requires-Dist: pytest-bdd==7.3.0; extra == "test"
69
+ Requires-Dist: pytest-mock; extra == "test"
70
+ Requires-Dist: pytest-watcher; extra == "test"
71
+ Requires-Dist: hypothesis; extra == "test"
72
+ Requires-Dist: hypothesis-rdkit; extra == "test"
73
+ Provides-Extra: docs
74
+ Requires-Dist: mkdocs; extra == "docs"
75
+ Requires-Dist: mkdocs-material; extra == "docs"
76
+ Requires-Dist: mkdocstrings; extra == "docs"
77
+
78
+ # NERDD-Link
79
+
80
+ Run a [NERDD module](https://github.com/molinfo-vienna/nerdd-module) as a
81
+ service that consumes input molecules and produces prediction tuples.
82
+
83
+
84
+ ## Installation
85
+
86
+ ```bash
87
+ pip install -U nerdd-link
88
+ ```
89
+
90
+ ## Usage
91
+
92
+ When a class inherits from ```nerdd_module.AbstractModel``` (see
93
+ [NERDD Module GitHub page](https://github.com/molinfo-vienna/nerdd-module)), it can be
94
+ used to create a Kafka service.
95
+
96
+ ```bash
97
+ # run a Kafka service for NerddModel on localhost:9092
98
+ run_nerdd_server package.path.to.NerddModel
99
+
100
+ # modify broker url, input topic and batch size
101
+ run_nerdd_server package.path.to.NerddModel \
102
+ --broker-url my-cluster-kafka-bootstrap.kafka:9092 \
103
+ --input-topic examples \
104
+ --batch-size 10
105
+
106
+ # more information via --help
107
+ run_nerdd_server --help
108
+ ```
109
+
110
+ If the model class is called ```ExamplePredictionModel```, the server will read input
111
+ tuples from the input topic ```example-prediction-inputs``` in batches of size 100
112
+ and write results to the ```results``` topic. The batch size specifies the number
113
+ of input tuples that are given to the model at once.
114
+
115
+ ## Communication
116
+
@@ -0,0 +1,39 @@
1
+ # NERDD-Link
2
+
3
+ Run a [NERDD module](https://github.com/molinfo-vienna/nerdd-module) as a
4
+ service that consumes input molecules and produces prediction tuples.
5
+
6
+
7
+ ## Installation
8
+
9
+ ```bash
10
+ pip install -U nerdd-link
11
+ ```
12
+
13
+ ## Usage
14
+
15
+ When a class inherits from ```nerdd_module.AbstractModel``` (see
16
+ [NERDD Module GitHub page](https://github.com/molinfo-vienna/nerdd-module)), it can be
17
+ used to create a Kafka service.
18
+
19
+ ```bash
20
+ # run a Kafka service for NerddModel on localhost:9092
21
+ run_nerdd_server package.path.to.NerddModel
22
+
23
+ # modify broker url, input topic and batch size
24
+ run_nerdd_server package.path.to.NerddModel \
25
+ --broker-url my-cluster-kafka-bootstrap.kafka:9092 \
26
+ --input-topic examples \
27
+ --batch-size 10
28
+
29
+ # more information via --help
30
+ run_nerdd_server --help
31
+ ```
32
+
33
+ If the model class is called ```ExamplePredictionModel```, the server will read input
34
+ tuples from the input topic ```example-prediction-inputs``` in batches of size 100
35
+ and write results to the ```results``` topic. The batch size specifies the number
36
+ of input tuples that are given to the model at once.
37
+
38
+ ## Communication
39
+
@@ -0,0 +1,4 @@
1
+ from .actions import *
2
+ from .channels import *
3
+ from .input import *
4
+ from .types import *
@@ -0,0 +1,5 @@
1
+ from .action import *
2
+ from .predict_checkpoints_action import *
3
+ from .process_jobs_action import *
4
+ from .register_module_action import *
5
+ from .write_output_action import *
@@ -0,0 +1,30 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Generic, TypeVar
3
+
4
+ from stringcase import spinalcase
5
+
6
+ from ..channels import Channel, Topic
7
+ from ..types import Message
8
+
9
+ T = TypeVar("T", bound=Message)
10
+
11
+
12
class Action(ABC, Generic[T]):
    """Base class for an action that consumes messages from a single topic.

    Subclasses implement :meth:`_process_message`; :meth:`run` passes every
    message received on the input topic to it, one message at a time.
    """

    def __init__(self, input_topic: Topic[T]):
        self._input_topic = input_topic

    async def run(self) -> None:
        """Receive messages from the input topic and handle each one in order.

        The consumer group is the spinal-cased result of :meth:`_get_group_name`.
        """
        group_name = self._get_group_name()
        incoming_messages = self._input_topic.receive(spinalcase(group_name))
        async for incoming in incoming_messages:
            await self._process_message(incoming)

    @abstractmethod
    async def _process_message(self, message: T) -> None:
        """Handle a single message received on the input topic."""

    @property
    def channel(self) -> Channel:
        """The channel that the input topic belongs to."""
        return self._input_topic.channel

    def _get_group_name(self) -> str:
        """Return the consumer group name (defaults to the class name)."""
        return type(self).__name__
@@ -0,0 +1,59 @@
1
+ import logging
2
+ import os
3
+
4
+ from nerdd_module import Model
5
+
6
+ from ..channels import Channel
7
+ from ..delegates import ReadCheckpointModel
8
+ from ..types import CheckpointMessage
9
+ from .action import Action
10
+
11
+ __all__ = ["PredictCheckpointsAction"]
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class PredictCheckpointsAction(Action[CheckpointMessage]):
    """Predict checkpoints announced on the "<job-type>-checkpoints" topic.

    Each checkpoint message refers to a pickled batch of input entries written
    by a previous step to
    ``<data_dir>/jobs/<job_id>/input/checkpoint_<checkpoint_id>.pickle``.
    The batch is predicted by a :class:`ReadCheckpointModel` wrapping the base
    model. That wrapper receives the channel and the results file path
    ``<data_dir>/jobs/<job_id>/results/checkpoint_<checkpoint_id>.pickle``
    (how results are emitted is decided by ReadCheckpointModel — presumably
    on the "results" topic; verify there).
    """

    def __init__(self, channel: Channel, model: Model, data_dir: str) -> None:
        super().__init__(channel.checkpoints_topic(model))
        self.model = model
        self.data_dir = data_dir

    async def _process_message(self, message: CheckpointMessage) -> None:
        """Run the wrapped model on the checkpoint referenced by ``message``."""
        job_id, checkpoint_id = message.job_id, message.checkpoint_id
        prediction_params = message.params
        logger.info(f"Predict checkpoint {checkpoint_id} of job {job_id}")

        # layout: <data_dir>/jobs/<job_id>/{input,results}/checkpoint_<id>.pickle
        job_dir = f"{self.data_dir}/jobs/{job_id}"
        file_name = f"checkpoint_{checkpoint_id}.pickle"
        results_dir = f"{job_dir}/results"

        # the results directory has to exist before anything is written into it
        os.makedirs(results_dir, exist_ok=True)

        checkpoint_model = ReadCheckpointModel(
            base_model=self.model,
            job_id=job_id,
            checkpoint_id=checkpoint_id,
            channel=self.channel,
            checkpoints_file=f"{job_dir}/input/{file_name}",
            results_file=f"{results_dir}/{file_name}",
        )

        # no direct input is given; presumably ReadCheckpointModel reads the
        # entries from checkpoints_file (TODO confirm in ReadCheckpointModel)
        checkpoint_model.predict(input=None, **prediction_params)

    def _get_group_name(self) -> str:
        """Use the model's class name as consumer group name."""
        return type(self.model).__name__
@@ -0,0 +1,138 @@
1
+ import logging
2
+ import os
3
+ from pickle import dump
4
+
5
+ from nerdd_module.input import DepthFirstExplorer
6
+ from nerdd_module.model import ReadInputStep
7
+
8
+ from ..channels import Channel
9
+ from ..types import CheckpointMessage, JobMessage, LogMessage
10
+ from ..utils import batched
11
+ from .action import Action
12
+
13
+ __all__ = ["ProcessJobsAction"]
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class ProcessJobsAction(Action[JobMessage]):
    """Split incoming jobs into checkpoint files and announce them.

    Accepts new jobs on the "jobs" topic. For each job, the entries of the
    input (read with a DepthFirstExplorer restricted to ``<data_dir>/sources``)
    are grouped into batches of ``checkpoint_size`` entries. Each batch is
    pickled to ``<data_dir>/jobs/<job_id>/input/checkpoint_<i>.pickle`` and
    announced with a CheckpointMessage on the "<job_type>-checkpoints" topic.

    At most ``max_num_molecules`` entries are stored. If the input contains
    more, a warning LogMessage is sent to the "logs" topic. Finally, the number
    of stored entries is reported to the "logs" topic with message type
    "report_job_size".
    """

    def __init__(
        self,
        channel: Channel,
        checkpoint_size: int,
        max_num_molecules: int,
        num_test_entries: int,
        ratio_valid_entries: float,
        maximum_depth: int,
        max_num_lines_mol_block: int,
        data_dir: str,
    ) -> None:
        super().__init__(channel.jobs_topic())
        # relevant for chunking
        self.checkpoint_size = checkpoint_size
        self.max_num_molecules = max_num_molecules
        # parameters of DepthFirstExplorer
        self.num_test_entries = num_test_entries
        self.ratio_valid_entries = ratio_valid_entries
        self.maximum_depth = maximum_depth
        # used as kwargs in DepthFirstExplorer
        self.max_num_lines_mol_block = max_num_lines_mol_block
        self.data_dir = data_dir

    async def _process_message(self, message: JobMessage) -> None:
        """Read the job's input, write checkpoint files and announce them.

        An empty input is handled gracefully: no checkpoint is written and a
        job size of 0 is reported.
        """
        job_id = message.id
        job_type = message.job_type
        logger.info(f"Received a new job {job_id} of type {job_type}")

        # The input file of the job is stored in the directory data_dir/sources/.
        # The file is allowed to reference other files, but setting the
        # explorer's data_dir to the sources directory ensures that we never
        # read files outside of the sources directory.
        sources_dir = os.path.join(self.data_dir, "sources")

        # create a reader (explorer) for the input file
        explorer = DepthFirstExplorer(
            num_test_entries=self.num_test_entries,
            threshold=self.ratio_valid_entries,
            maximum_depth=self.maximum_depth,
            # extra args
            max_num_lines_mol_block=self.max_num_lines_mol_block,
            data_dir=sources_dir,
        )

        read_input_step = ReadInputStep(explorer, message.source_id)

        # create a directory for the job
        input_dir = f"{self.data_dir}/jobs/{job_id}/input"
        os.makedirs(input_dir, exist_ok=True)

        # read the input file
        entries = read_input_step()

        # the topic object does not depend on the batch -> create it once
        checkpoints_topic = self.channel.checkpoints_topic(job_type)

        # Iterate through the entries in batches of size checkpoint_size and
        # stop as soon as max_num_molecules entries have been stored. The
        # counters are initialized up front so that an empty input (no batches
        # at all) does not leave them unbound.
        num_entries = 0
        num_checkpoints = 0
        limit_reached = False
        too_many_molecules = False
        for checkpoint_id, batch in enumerate(batched(entries, self.checkpoint_size)):
            # max_num_molecules might be reached within the batch
            num_store = min(len(batch), self.max_num_molecules - num_entries)

            # store batch in data_dir
            with open(f"{input_dir}/checkpoint_{checkpoint_id}.pickle", "wb") as f:
                dump(batch[:num_store], f)

            # announce the checkpoint on the "<job_type>-checkpoints" topic
            await checkpoints_topic.send(
                CheckpointMessage(
                    job_id=job_id,
                    checkpoint_id=checkpoint_id,
                    params=message.params,
                )
            )

            num_checkpoints += 1
            num_entries += num_store

            if num_entries >= self.max_num_molecules:
                limit_reached = True
                # entries were dropped if the limit was hit inside this batch
                too_many_molecules = num_store < len(batch)
                break

        logger.info(
            f"Wrote {num_checkpoints} checkpoints containing {num_entries} entries "
            f"for job {job_id}"
        )

        if limit_reached and not too_many_molecules:
            # The limit was hit exactly at a batch boundary: the input contains
            # more structures if at least one more entry can be read.
            no_entry = object()
            too_many_molecules = next(entries, no_entry) is not no_entry

        # send a warning message if there were more molecules in the job than allowed
        if too_many_molecules:
            await self.channel.logs_topic().send(
                LogMessage(
                    job_id=job_id,
                    message_type="warning",
                    message=(
                        f"The provided job contains more than "
                        f"{self.max_num_molecules} input structures. Only the "
                        f"first {self.max_num_molecules} will be processed."
                    ),
                )
            )

        # at the end, report the overall size of the job to the "logs" topic
        await self.channel.logs_topic().send(
            LogMessage(
                job_id=job_id,
                message_type="report_job_size",
                size=num_entries,
            )
        )
@@ -0,0 +1,30 @@
1
+ import logging
2
+
3
+ from nerdd_module import Model
4
+ from stringcase import spinalcase
5
+
6
+ from ..channels import Channel
7
+ from ..types import ModuleMessage, SystemMessage
8
+ from .action import Action
9
+
10
+ __all__ = ["RegisterModuleAction"]
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
class RegisterModuleAction(Action[SystemMessage]):
    """Announce the module's configuration on the "modules" topic.

    Every message received on the "system" topic triggers a ModuleMessage that
    is built from ``model.get_config()``.
    """

    def __init__(self, channel: Channel, model: Model):
        """Subscribe to the system topic for ``model``.

        Raises:
            TypeError: if ``model`` has no ``get_config`` attribute.
        """
        super().__init__(channel.system_topic())
        # TODO: do this differently
        # Explicit check instead of ``assert``: assert statements are removed
        # when Python runs with -O, which would defer the failure to the first
        # system message.
        if not hasattr(model, "get_config"):
            raise TypeError(
                f"{type(model).__name__} does not provide get_config(), which is "
                "required to register the module"
            )
        self._model = model

    async def _process_message(self, message: SystemMessage) -> None:
        """Send the module's configuration as a ModuleMessage."""
        config = self._model.get_config()
        logger.info(f"Send registration message for module {config.name}")
        await self.channel.modules_topic().send(ModuleMessage(**config.model_dump()))

    def _get_group_name(self) -> str:
        """Use the spinal-cased model class name as consumer group name."""
        model_name = spinalcase(self._model.__class__.__name__)
        return model_name
@@ -0,0 +1,21 @@
1
+ from nerdd_module import Model
2
+ from stringcase import spinalcase
3
+
4
+ from ..channels import Channel
5
+ from ..types import SystemMessage
6
+ from .action import Action
7
+
8
+ __all__ = ["WriteOutputAction"]
9
+
10
+
11
class WriteOutputAction(Action[SystemMessage]):
    """Placeholder action subscribed to the "system" topic.

    Received messages are currently ignored.
    """

    def __init__(self, channel: Channel, model: Model):
        super().__init__(channel.system_topic())
        self._model = model

    async def _process_message(self, message: SystemMessage) -> None:
        """Ignore the message (nothing is written yet)."""
        return None

    def _get_group_name(self) -> str:
        """Use the spinal-cased model class name as consumer group name."""
        return spinalcase(type(self._model).__name__)
@@ -0,0 +1,3 @@
1
+ from .channel import *
2
+ from .kafka_channel import *
3
+ from .memory_channel import *
@@ -0,0 +1,114 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import AsyncIterable, Generic, TypeVar, Union, cast
5
+
6
+ from nerdd_module import Model
7
+ from stringcase import spinalcase # type: ignore
8
+
9
+ from ..types import (
10
+ CheckpointMessage,
11
+ JobMessage,
12
+ LogMessage,
13
+ Message,
14
+ ModuleMessage,
15
+ ResultCheckpointMessage,
16
+ ResultMessage,
17
+ SystemMessage,
18
+ )
19
+
20
+ __all__ = ["Channel", "Topic"]
21
+
22
+ T = TypeVar("T", bound=Message)
23
+
24
+
25
def get_job_type(job_type_or_model: Union[str, Model]) -> str:
    """Return the normalized job type used to build topic names.

    The name (``model.name`` for a Model, the string itself otherwise) is
    normalized by
    * converting to spinal case (e.g. "MyModel" -> "my-model"),
    * converting to lowercase (just to be sure) and
    * removing all characters except dash and alphanumeric characters.

    Both branches use the same normalization so that a job type given as a
    string (e.g. from a JobMessage) maps to the same topic as the model that
    consumes it.
    """
    if isinstance(job_type_or_model, Model):
        raw_name = job_type_or_model.name
    else:
        raw_name = job_type_or_model

    topic_name = spinalcase(raw_name).lower()
    return "".join(c for c in topic_name if c.isalnum() or c == "-")
39
+
40
+
41
class Topic(Generic[T]):
    """Typed handle for a named topic on a :class:`Channel`.

    The type parameter only informs type checkers: received messages are cast
    to ``T`` without any runtime validation.
    """

    def __init__(self, channel: Channel, name: str):
        self._channel = channel
        self._name = name

    async def receive(self, consumer_group: str) -> AsyncIterable[T]:
        """Yield the messages of this topic, consumed as ``consumer_group``."""
        stream = self.channel.iter_messages(self._name, consumer_group)
        async for raw_message in stream:
            yield cast(T, raw_message)

    async def send(self, message: T) -> None:
        """Publish ``message`` on this topic."""
        target_channel = self.channel
        await target_channel.send(self._name, message)

    @property
    def channel(self) -> Channel:
        """The channel this topic belongs to."""
        return self._channel

    def __repr__(self) -> str:
        return f"Topic({self._name})"
59
+
60
+
61
class Channel(ABC):
    """Abstract message transport with typed topic accessors.

    Implementations provide :meth:`_iter_messages` and :meth:`_send`. The
    public :meth:`iter_messages` and :meth:`send` delegate to them, and the
    ``*_topic`` methods wrap the topic names in typed :class:`Topic` objects.
    """

    # ------------------------------------------------------------------ receive

    async def iter_messages(self, topic: str, consumer_group: str) -> AsyncIterable[Message]:
        """Yield the messages of ``topic`` as delivered by the implementation."""
        source = self._iter_messages(topic, consumer_group)
        async for item in source:
            yield item

    # NOTE: declared with plain "def" instead of "async def" on purpose. This
    # abstract stub contains no "yield", so with "async def" type checkers would
    # treat it as a coroutine returning an AsyncIterable (i.e.
    # Coroutine[..., AsyncIterable[Message]]) rather than an async iterable.
    # Implementations may still be written as async generators.
    @abstractmethod
    def _iter_messages(self, topic: str, consumer_group: str) -> AsyncIterable[Message]:
        pass

    # --------------------------------------------------------------------- send

    async def send(self, topic: str, message: Message) -> None:
        """Publish ``message`` on ``topic`` via the implementation."""
        await self._send(topic, message)

    @abstractmethod
    async def _send(self, topic: str, message: Message) -> None:
        pass

    # ------------------------------------------------------------------- topics

    def modules_topic(self) -> Topic[ModuleMessage]:
        """Topic "modules"."""
        return Topic[ModuleMessage](self, "modules")

    def jobs_topic(self) -> Topic[JobMessage]:
        """Topic "jobs"."""
        return Topic[JobMessage](self, "jobs")

    def checkpoints_topic(self, job_type_or_model: Union[str, Model]) -> Topic[CheckpointMessage]:
        """Topic "<job-type>-checkpoints" for the given job type or model."""
        name = f"{get_job_type(job_type_or_model)}-checkpoints"
        return Topic[CheckpointMessage](self, name)

    def results_topic(self) -> Topic[ResultMessage]:
        """Topic "results"."""
        return Topic[ResultMessage](self, "results")

    def result_checkpoints_topic(
        self, job_type_or_model: Union[str, Model]
    ) -> Topic[ResultCheckpointMessage]:
        """Topic "<job-type>-result-checkpoints" for the given job type or model."""
        name = f"{get_job_type(job_type_or_model)}-result-checkpoints"
        return Topic[ResultCheckpointMessage](self, name)

    def logs_topic(self) -> Topic[LogMessage]:
        """Topic "logs"."""
        return Topic[LogMessage](self, "logs")

    def system_topic(self) -> Topic[SystemMessage]:
        """Topic "system"."""
        return Topic[SystemMessage](self, "system")
@@ -0,0 +1,97 @@
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ from typing import AsyncIterable, Dict, Tuple
5
+
6
+ from aiokafka import AIOKafkaConsumer, AIOKafkaProducer
7
+
8
+ from ..types import Message
9
+ from .channel import Channel
10
+
11
+ __all__ = ["KafkaChannel"]
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
class KafkaChannel(Channel):
    """Channel implementation backed by a Kafka broker (via aiokafka).

    Messages are sent as UTF-8 encoded JSON of ``Message.model_dump()`` and
    parsed back into ``Message`` objects on receive. One consumer is cached per
    ``(topic, consumer_group)`` pair while it is being iterated.

    The constructor schedules the producer start with ``asyncio.create_task``,
    so it must be called while an event loop is running.
    """

    def __init__(self, broker_url: str) -> None:
        super().__init__()
        self._broker_url = broker_url
        self._consumers: Dict[Tuple[str, str], AIOKafkaConsumer] = {}

        # TODO: check value_serializer (values are currently serialized in _send)
        self._producer = AIOKafkaProducer(
            bootstrap_servers=[self._broker_url],
        )
        # Keep a reference to the start task:
        # (1) the event loop only keeps weak references to tasks, so an
        #     unreferenced task may be garbage collected before it finishes;
        # (2) _send awaits it, so no message is sent before the producer has
        #     started, and a failed start is re-raised there instead of being lost.
        self._producer_start_task = asyncio.create_task(self._producer.start())
        logger.info(f"Connecting to Kafka broker {self._broker_url} and starting a producer.")

    async def _iter_messages(self, topic: str, consumer_group: str) -> AsyncIterable[Message]:
        """Yield messages of ``topic`` for ``consumer_group``.

        Offsets are committed manually: a message is committed once the caller
        requests the next one, i.e. after it finished handling the current one.
        The consumer is stopped when the iteration ends.
        """
        if consumer_group is not None:
            consumer_group = f"{consumer_group}-consumer-group"

        key = (topic, consumer_group)

        consumer = self._consumers.get(key)
        if consumer is None:
            consumer = AIOKafkaConsumer(
                topic,
                bootstrap_servers=[self._broker_url],
                auto_offset_reset="earliest",
                group_id=consumer_group,
                enable_auto_commit=False,
            )
            await consumer.start()
            self._consumers[key] = consumer
            logger.info(
                f"Connecting to Kafka broker {self._broker_url} and starting a consumer on "
                f"topic {topic}."
            )

        try:
            async for message in consumer:
                message_obj = json.loads(message.value)
                yield Message(**message_obj)
                await consumer.commit()
        finally:
            # A stopped consumer cannot be restarted. Drop it from the cache so a
            # later iteration over the same topic/group creates a fresh consumer
            # instead of reusing this one.
            if self._consumers.get(key) is consumer:
                del self._consumers[key]
            await consumer.stop()

    async def _send(self, topic: str, message: Message) -> None:
        """Serialize ``message`` as JSON and send it to ``topic``."""
        # wait for the producer start scheduled in __init__ (re-raises its error)
        await self._producer_start_task
        await self._producer.send_and_wait(
            topic,
            json.dumps(message.model_dump()).encode("utf-8"),
        )