nerdd-link 0.2.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nerdd_link/__init__.py +6 -0
- nerdd_link/actions/__init__.py +5 -0
- nerdd_link/actions/action.py +41 -0
- nerdd_link/actions/predict_checkpoints_action.py +76 -0
- nerdd_link/actions/process_jobs_action.py +157 -0
- nerdd_link/actions/register_module_action.py +37 -0
- nerdd_link/actions/serialize_job_action.py +64 -0
- nerdd_link/channels/__init__.py +3 -0
- nerdd_link/channels/channel.py +147 -0
- nerdd_link/channels/kafka_channel.py +109 -0
- nerdd_link/channels/memory_channel.py +47 -0
- nerdd_link/cli/__init__.py +4 -0
- nerdd_link/cli/initialize_system.py +49 -0
- nerdd_link/cli/run_job_server.py +115 -0
- nerdd_link/cli/run_prediction_server.py +85 -0
- nerdd_link/cli/run_serialization_server.py +73 -0
- nerdd_link/converters/__init__.py +6 -0
- nerdd_link/converters/image_converter.py +15 -0
- nerdd_link/converters/mol_pickle_converter.py +20 -0
- nerdd_link/converters/mol_to_image_converter.py +77 -0
- nerdd_link/converters/pickle_converter.py +15 -0
- nerdd_link/converters/problem_list_converter.py +15 -0
- nerdd_link/converters/source_list_converter.py +15 -0
- nerdd_link/delegates/__init__.py +5 -0
- nerdd_link/delegates/pickle_writer.py +18 -0
- nerdd_link/delegates/read_checkpoint_model.py +73 -0
- nerdd_link/delegates/read_pickle_step.py +21 -0
- nerdd_link/delegates/serialize_job_model.py +21 -0
- nerdd_link/delegates/split_and_merge_step.py +51 -0
- nerdd_link/delegates/topic_writer.py +19 -0
- nerdd_link/files/__init__.py +1 -0
- nerdd_link/files/file_system.py +89 -0
- nerdd_link/input/__init__.py +1 -0
- nerdd_link/input/structure_json_reader.py +37 -0
- nerdd_link/py.typed +0 -0
- nerdd_link/tests/__init__.py +3 -0
- nerdd_link/tests/async_step.py +22 -0
- nerdd_link/tests/channels.py +81 -0
- nerdd_link/tests/files.py +9 -0
- nerdd_link/types/__init__.py +70 -0
- nerdd_link/utils/__init__.py +4 -0
- nerdd_link/utils/async_to_sync.py +26 -0
- nerdd_link/utils/batched.py +27 -0
- nerdd_link/utils/observable_list.py +72 -0
- nerdd_link/utils/safetee.py +39 -0
- nerdd_link/version.py +11 -0
- nerdd_link-0.2.11.dist-info/LICENSE +21 -0
- nerdd_link-0.2.11.dist-info/METADATA +116 -0
- nerdd_link-0.2.11.dist-info/RECORD +52 -0
- nerdd_link-0.2.11.dist-info/WHEEL +5 -0
- nerdd_link-0.2.11.dist-info/entry_points.txt +5 -0
- nerdd_link-0.2.11.dist-info/top_level.txt +1 -0
nerdd_link/actions/action.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from asyncio import CancelledError
|
|
4
|
+
from typing import Generic, TypeVar
|
|
5
|
+
|
|
6
|
+
from stringcase import spinalcase
|
|
7
|
+
|
|
8
|
+
from ..channels import Channel, Topic
|
|
9
|
+
from ..types import Message
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
T = TypeVar("T", bound=Message)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Action(ABC, Generic[T]):
    """Base class for consumers that handle every message of one input topic.

    Subclasses implement ``_process_message``; ``run`` feeds it each message
    received on the input topic, one at a time.
    """

    def __init__(self, input_topic: Topic[T]):
        self._input_topic = input_topic

    @property
    def channel(self) -> Channel:
        """The channel that the input topic belongs to."""
        return self._input_topic.channel

    def _get_group_name(self) -> str:
        """Return the raw consumer group name (the class name by default).

        ``run`` converts the returned name to spinal case before using it.
        """
        return type(self).__name__

    @abstractmethod
    async def _process_message(self, message: T) -> None:
        """Handle a single message received on the input topic."""

    async def run(self) -> None:
        """Consume the input topic until it ends or processing is cancelled."""
        group = spinalcase(self._get_group_name())
        incoming = self._input_topic.receive(group)
        async for message in incoming:
            try:
                await self._process_message(message)
            except CancelledError:
                # the consumer was cancelled, stop processing messages
                break
            except Exception:
                # a failing message must not take the consumer down:
                # log the error and continue with the next message
                logger.error("Error processing message", exc_info=True)
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from multiprocessing import Queue
|
|
3
|
+
|
|
4
|
+
from nerdd_module import SimpleModel
|
|
5
|
+
|
|
6
|
+
from ..channels import Channel
|
|
7
|
+
from ..delegates import ReadCheckpointModel
|
|
8
|
+
from ..files import FileSystem
|
|
9
|
+
from ..types import CheckpointMessage, ResultCheckpointMessage, ResultMessage
|
|
10
|
+
from .action import Action
|
|
11
|
+
|
|
12
|
+
__all__ = ["PredictCheckpointsAction"]
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class PredictCheckpointsAction(Action[CheckpointMessage]):
    """Run the model on one checkpoint per message and publish the outcome.

    Messages arrive on the checkpoints topic of the given model. Every record
    produced by the prediction is sent to the results topic; once the
    prediction signals completion, a message is sent to the result-checkpoints
    topic.
    """

    def __init__(self, channel: Channel, model: SimpleModel, data_dir: str) -> None:
        super().__init__(channel.checkpoints_topic(model))
        self._model = model
        self._data_dir = data_dir

    def _get_group_name(self) -> str:
        """Build the consumer group name from the id in the model's config."""
        model_id = self._model.get_config().id
        return f"predict-checkpoints-{model_id}"

    async def _process_message(self, message: CheckpointMessage) -> None:
        """Predict the checkpoint referenced by *message* and forward all results."""
        job_id, checkpoint_id = message.job_id, message.checkpoint_id
        logger.info(f"Predict checkpoint {checkpoint_id} of job {job_id}")

        # The Kafka consumers and producers live in the current asyncio event loop and
        # (by observation) producing from another event loop / thread / process hangs
        # indefinitely. Hence the prediction side only pushes records into this queue,
        # and this coroutine forwards them to the producer.
        queue: Queue = Queue()

        # Wrapper around the actual model. According to the original author's notes it
        #   * reads the checkpoint file instead of normal input
        #   * does preprocessing, prediction, and postprocessing like the wrapped model
        #   * writes to the checkpoints file instead of the specified results file
        #   * sends the results to the results topic (via the queue)
        wrapped_model = ReadCheckpointModel(
            base_model=self._model,
            job_id=job_id,
            file_system=FileSystem(self._data_dir),
            checkpoint_id=checkpoint_id,
            queue=queue,
        )

        # input=None, because the checkpoint file is already provided to the wrapper
        wrapped_model.predict(input=None, **message.params)

        # Forward records until the end marker (None) shows up in the queue.
        results_topic = self.channel.results_topic()
        record = queue.get()
        while record is not None:
            await results_topic.send(ResultMessage(job_id=job_id, **record))
            record = queue.get()

        # All records were forwarded: announce that the checkpoint is complete.
        await self.channel.result_checkpoints_topic().send(
            ResultCheckpointMessage(job_id=job_id, checkpoint_id=checkpoint_id)
        )
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from pickle import dump
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from nerdd_module.input import DepthFirstExplorer
|
|
6
|
+
from nerdd_module.model import ReadInputStep
|
|
7
|
+
from rdkit.Chem import Mol
|
|
8
|
+
from rdkit.Chem.PropertyMol import PropertyMol
|
|
9
|
+
|
|
10
|
+
from ..channels import Channel
|
|
11
|
+
from ..files import FileSystem
|
|
12
|
+
from ..types import CheckpointMessage, JobMessage, LogMessage
|
|
13
|
+
from ..utils import batched
|
|
14
|
+
from .action import Action
|
|
15
|
+
|
|
16
|
+
__all__ = ["ProcessJobsAction"]
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ProcessJobsAction(Action[JobMessage]):
    """Split the input of each incoming job into checkpoint files.

    Jobs are received on the jobs topic. For each job, the entries of the
    input are read, written in batches of ``checkpoint_size`` into checkpoint
    files, and one checkpoint message per batch is sent to the
    "<job_type>-checkpoints" topic. At most ``max_num_molecules`` entries are
    stored; a warning is sent to the logs topic when the input contains more.
    Finally, the number of stored entries and checkpoints is reported on the
    logs topic (message type "report_job_size").
    """

    def __init__(
        self,
        channel: Channel,
        checkpoint_size: int,
        max_num_molecules: int,
        num_test_entries: int,
        ratio_valid_entries: float,
        maximum_depth: int,
        max_num_lines_mol_block: int,
        data_dir: str,
    ) -> None:
        super().__init__(channel.jobs_topic())
        # relevant for chunking
        self._checkpoint_size = checkpoint_size
        self._max_num_molecules = max_num_molecules
        # parameters of DepthFirstExplorer
        self._num_test_entries = num_test_entries
        self._ratio_valid_entries = ratio_valid_entries
        self._maximum_depth = maximum_depth
        # used as kwargs in DepthFirstExplorer
        self._max_num_lines_mol_block = max_num_lines_mol_block
        self._file_system = FileSystem(data_dir)

    @staticmethod
    def _to_picklable(entry: dict) -> dict:
        """Return a copy of *entry* with every RDKit ``Mol`` wrapped in ``PropertyMol``.

        ``PropertyMol`` is used so that molecular properties are kept when
        the entry is pickled.
        """
        return {
            key: PropertyMol(value) if isinstance(value, Mol) else value
            for key, value in entry.items()
        }

    async def _process_message(self, message: JobMessage) -> None:
        """Write the checkpoints of one job and report its size."""
        job_id = message.id
        job_type = message.job_type
        logger.info(f"Received a new job {job_id} of type {job_type}")

        # the input file to the job is stored in a designated sources directory
        # (the file is allowed to reference other files, but setting the data_dir
        # to the sources directory ensures that we never read files outside of the
        # sources directory)
        data_dir = self._file_system.get_sources_dir()

        # create a reader (explorer) for the input file
        explorer = DepthFirstExplorer(
            num_test_entries=self._num_test_entries,
            threshold=self._ratio_valid_entries,
            maximum_depth=self._maximum_depth,
            # extra args
            max_num_lines_mol_block=self._max_num_lines_mol_block,
            data_dir=data_dir,
        )
        read_input_step = ReadInputStep(explorer, message.source_id)

        # Hold an explicit iterator: it is shared between the batching below and the
        # "is there anything left?" probe after the loop (iter() returns an iterator
        # unchanged, so this is a no-op if the step already yields one).
        entries = iter(read_input_step())

        # iterate through the entries in batches of size checkpoint_size and
        # limit the number of stored molecules to max_num_molecules
        batches = batched(entries, self._checkpoint_size)
        num_entries = 0
        num_checkpoints = 0
        # stays False for an empty input (the loop body never runs in that case)
        too_many_molecules = False
        for checkpoint_id, batch in enumerate(batches):
            # max_num_molecules might be reached within the batch
            # (never negative, even for a non-positive limit)
            num_store = max(0, min(len(batch), self._max_num_molecules - num_entries))

            # TODO: use a model for storing the batches
            results = [self._to_picklable(item) for item in batch[:num_store]]

            # store batch in data_dir
            with self._file_system.get_checkpoint_file_handle(job_id, checkpoint_id, "wb") as f:
                dump(results, f)

            # announce the checkpoint on the "<job_type>-checkpoints" topic
            # (after the file was closed, so that it is completely written)
            await self.channel.checkpoints_topic(job_type).send(
                CheckpointMessage(
                    job_id=job_id,
                    checkpoint_id=checkpoint_id,
                    params=message.params,
                )
            )

            num_entries += num_store
            num_checkpoints += 1

            if num_entries >= self._max_num_molecules:
                # part of this batch had to be dropped if it was cut short
                too_many_molecules = num_store < len(batch)
                break

        logger.info(
            f"Wrote {num_checkpoints} checkpoints containing {num_entries} entries "
            f"for job {job_id}"
        )

        # send a warning message if there were more molecules in the job than allowed
        if not too_many_molecules:
            # try to get another entry: if there is one, the input was truncated
            exhausted = object()
            too_many_molecules = next(entries, exhausted) is not exhausted

        if too_many_molecules:
            await self.channel.logs_topic().send(
                LogMessage(
                    job_id=job_id,
                    message_type="warning",
                    message=(
                        f"The provided job contains more than "
                        f"{self._max_num_molecules} input structures. Only the "
                        f"first {self._max_num_molecules} will be processed."
                    ),
                )
            )

        # at the end, report the overall size of the job on the logs topic
        await self.channel.logs_topic().send(
            LogMessage(
                job_id=job_id,
                message_type="report_job_size",
                num_entries=num_entries,
                num_checkpoints=num_checkpoints,
            )
        )
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
from nerdd_module import Model
|
|
5
|
+
|
|
6
|
+
from ..channels import Channel
|
|
7
|
+
from ..files import FileSystem
|
|
8
|
+
from ..types import ModuleMessage, SystemMessage
|
|
9
|
+
from .action import Action
|
|
10
|
+
|
|
11
|
+
__all__ = ["RegisterModuleAction"]
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class RegisterModuleAction(Action[SystemMessage]):
    """Register the model as a module whenever a system message arrives.

    The model's configuration is written as JSON to the module file of the
    file system and a ``ModuleMessage`` carrying the module id is sent to the
    modules topic.
    """

    def __init__(self, channel: Channel, model: Model, data_dir: str):
        super().__init__(channel.system_topic())
        # TODO: do this differently
        assert hasattr(model, "get_config")
        self._model = model
        self._file_system = FileSystem(data_dir)

    async def _process_message(self, message: SystemMessage) -> None:
        """Store the module configuration on disk and announce the module."""
        config = self._model.get_config()
        logger.info(f"Registering module with id {config.id}")

        # Serialize the configuration before touching the file, so that a failure
        # here does not truncate an existing module file.
        payload = config.model_dump()

        # save module as json to file; the context manager guarantees that the file
        # is flushed and closed before the module is announced below
        module_file = self._file_system.get_module_file_path(config.id)
        with open(module_file, "w") as f:
            json.dump(payload, f)

        # send the initialization message
        await self.channel.modules_topic().send(ModuleMessage(id=config.id))

    def _get_group_name(self) -> str:
        """Build the consumer group name from the id in the model's config."""
        model_id = self._model.get_config().id
        return f"register-module-{model_id}"
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
|
|
4
|
+
from nerdd_module import OutputStep
|
|
5
|
+
|
|
6
|
+
from ..channels import Channel
|
|
7
|
+
from ..delegates import ReadPickleStep, SerializeJobModel
|
|
8
|
+
from ..files import FileSystem
|
|
9
|
+
from ..types import SerializationRequestMessage, SerializationResultMessage
|
|
10
|
+
from .action import Action
|
|
11
|
+
|
|
12
|
+
__all__ = ["SerializeJobAction"]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class SerializeJobAction(Action[SerializationRequestMessage]):
    """Write the results of a job to an output file in the requested format.

    Requests arrive on the serialization-requests topic. The stored result
    files of the job are read and passed through the post-processing steps
    obtained from the module configuration; afterwards a
    ``SerializationResultMessage`` is sent to the serialization-results topic.
    """

    def __init__(self, channel: Channel, data_dir: str) -> None:
        super().__init__(channel.serialization_requests_topic())
        self._file_system = FileSystem(data_dir)

    async def _process_message(self, message: SerializationRequestMessage) -> None:
        """Serialize the results of the job referenced by *message*."""
        job_id = message.job_id
        job_type = message.job_type
        output_format = message.output_format
        logger.info(f"Write output for job {job_id} in format {output_format}")

        # work on a copy, so that the params of the incoming message stay untouched
        params = dict(message.params)

        # remove specific parameter keys that could induce vulnerabilities
        params.pop("output_file", None)
        params.pop("output_format", None)

        # obtain output file
        output_file = self._file_system.get_output_file(job_id, output_format)

        # get the configuration for the job_type (the file is closed right after reading)
        config_file = self._file_system.get_module_file_path(job_type)
        with open(config_file, "r") as f:
            config = json.load(f)

        # create a fake model instance to get the postprocessing steps
        model = SerializeJobModel(config)

        steps = [
            # read the result checkpoint files in the correct order
            ReadPickleStep(self._file_system.iter_results_file_handles(job_id)),
            # don't preprocess, don't do prediction, only post-process
            *model._get_postprocessing_steps(output_format, output_file=output_file, **params),
        ]

        output_step = steps[-1]
        assert isinstance(output_step, OutputStep), "The last step must be an OutputStep."

        # build the pipeline from the list of steps: each step is called with the
        # pipeline built so far (None for the first step)
        pipeline = None
        for step in steps:
            pipeline = step(pipeline)

        # run the pipeline by calling the get_result method of the last step
        output_step.get_result()

        await self.channel.serialization_results_topic().send(
            SerializationResultMessage(job_id=job_id, output_format=output_format)
        )
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import AsyncIterable, Generic, TypeVar, Union, cast
|
|
5
|
+
|
|
6
|
+
from nerdd_module import Model
|
|
7
|
+
from stringcase import spinalcase
|
|
8
|
+
|
|
9
|
+
from ..types import (
|
|
10
|
+
CheckpointMessage,
|
|
11
|
+
JobMessage,
|
|
12
|
+
LogMessage,
|
|
13
|
+
Message,
|
|
14
|
+
ModuleMessage,
|
|
15
|
+
ResultCheckpointMessage,
|
|
16
|
+
ResultMessage,
|
|
17
|
+
SerializationRequestMessage,
|
|
18
|
+
SerializationResultMessage,
|
|
19
|
+
SystemMessage,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
__all__ = ["Channel", "Topic"]
|
|
23
|
+
|
|
24
|
+
T = TypeVar("T", bound=Message)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def get_job_type(job_type_or_model: Union[str, Model]) -> str:
    """Derive the job type (used in topic names) from a string or a model.

    A string is converted to spinal case and returned. For a model, the
    model's name is converted to spinal case (e.g. "MyModel" -> "my-model"),
    lowercased, and stripped of every character that is neither alphanumeric
    nor a dash.
    """
    if not isinstance(job_type_or_model, Model):
        return spinalcase(job_type_or_model)

    # TODO: move to Module Id
    lowered = spinalcase(job_type_or_model.name).lower()
    return "".join(char for char in lowered if char.isalnum() or char == "-")
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class Topic(Generic[T]):
    """A named topic on a channel, typed by the messages it carries."""

    def __init__(self, channel: Channel, name: str):
        self._channel = channel
        self._name = name

    def __repr__(self) -> str:
        return f"Topic({self._name})"

    @property
    def channel(self) -> Channel:
        """The channel this topic belongs to."""
        return self._channel

    async def send(self, message: T) -> None:
        """Send *message* to this topic via the channel."""
        await self.channel.send(self._name, message)

    async def receive(self, consumer_group: str) -> AsyncIterable[T]:
        """Yield the messages of this topic for the given consumer group.

        Messages come from the channel untouched; they are only cast to the
        topic's message type for the type checker.
        """
        incoming = self.channel.iter_messages(self._name, consumer_group)
        async for raw in incoming:
            yield cast(T, raw)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class Channel(ABC):
    """Abstract message channel that offers typed topics.

    Subclasses implement ``_iter_messages`` and ``_send`` and may override
    the ``_start`` / ``_stop`` hooks. Sending and receiving is only allowed
    while the channel is running, i.e. between ``start()`` and ``stop()``
    (or inside an ``async with`` block).
    """

    def __init__(self) -> None:
        self._is_running = False

    async def start(self) -> None:
        """Mark the channel as running and call the ``_start`` hook."""
        self._is_running = True
        await self._start()

    async def _start(self) -> None:  # noqa: B027
        """Hook for subclasses; does nothing by default."""
        pass

    async def stop(self) -> None:
        """Call the ``_stop`` hook and mark the channel as not running.

        The channel is marked as stopped even if the hook raises; the
        exception is propagated to the caller.
        """
        try:
            await self._stop()
        finally:
            self._is_running = False

    async def _stop(self) -> None:  # noqa: B027
        """Hook for subclasses; does nothing by default."""
        pass

    async def __aenter__(self) -> Channel:
        await self.start()
        return self

    async def __aexit__(
        self,
        exc_type: Union[type, None],
        exc_value: Union[BaseException, None],
        traceback: object,
    ) -> None:
        # all three arguments are None when the block exits without an exception
        await self.stop()

    #
    # RECEIVE
    #
    async def iter_messages(self, topic: str, consumer_group: str) -> AsyncIterable[Message]:
        """Yield the messages of *topic* for *consumer_group*.

        Raises:
            RuntimeError: if the channel is not running. Since this is an
                async generator, the check happens when iteration starts.
        """
        if not self._is_running:
            raise RuntimeError("Channel is not running. Call start() first.")
        async for message in self._iter_messages(topic, consumer_group):
            yield message

    # Insane glitch: we need to use "def _iter_messages" instead of "async def _iter_messages"
    # here, because the method doesn't use "yield" and so the type checker will assume that the
    # actual type is Coroutine[AsyncIterable[Message], None, None].
    @abstractmethod
    def _iter_messages(self, topic: str, consumer_group: str) -> AsyncIterable[Message]:
        """Return an async iterable over the messages of *topic*."""
        pass

    #
    # SEND
    #
    async def send(self, topic: str, message: Message) -> None:
        """Send *message* to *topic*.

        Raises:
            RuntimeError: if the channel is not running.
        """
        if not self._is_running:
            raise RuntimeError("Channel is not running. Call start() first.")
        await self._send(topic, message)

    @abstractmethod
    async def _send(self, topic: str, message: Message) -> None:
        """Deliver *message* to *topic*."""
        pass

    #
    # TOPICS
    #
    def modules_topic(self) -> Topic[ModuleMessage]:
        return Topic[ModuleMessage](self, "modules")

    def jobs_topic(self) -> Topic[JobMessage]:
        return Topic[JobMessage](self, "jobs")

    def checkpoints_topic(self, job_type_or_model: Union[str, Model]) -> Topic[CheckpointMessage]:
        """Return the "<job_type>-checkpoints" topic of a job type or model."""
        job_type = get_job_type(job_type_or_model)
        topic_name = f"{job_type}-checkpoints"
        return Topic[CheckpointMessage](self, topic_name)

    def results_topic(self) -> Topic[ResultMessage]:
        return Topic[ResultMessage](self, "results")

    def result_checkpoints_topic(self) -> Topic[ResultCheckpointMessage]:
        return Topic[ResultCheckpointMessage](self, "result-checkpoints")

    def serialization_requests_topic(self) -> Topic[SerializationRequestMessage]:
        return Topic[SerializationRequestMessage](self, "serialization-requests")

    def serialization_results_topic(self) -> Topic[SerializationResultMessage]:
        return Topic[SerializationResultMessage](self, "serialization-results")

    def logs_topic(self) -> Topic[LogMessage]:
        return Topic[LogMessage](self, "logs")

    def system_topic(self) -> Topic[SystemMessage]:
        return Topic[SystemMessage](self, "system")
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
from typing import AsyncIterable, Dict, Tuple
|
|
4
|
+
|
|
5
|
+
from aiokafka import AIOKafkaConsumer, AIOKafkaProducer
|
|
6
|
+
from aiokafka.errors import CommitFailedError
|
|
7
|
+
|
|
8
|
+
from ..types import Message
|
|
9
|
+
from .channel import Channel
|
|
10
|
+
|
|
11
|
+
__all__ = ["KafkaChannel"]
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class KafkaChannel(Channel):
    """Channel implementation backed by a Kafka broker (via aiokafka).

    One producer is created when the channel starts. Consumers are created
    lazily, one per (topic, consumer group) pair, when the messages of a
    topic are iterated. Messages are exchanged as UTF-8 encoded JSON.
    """

    def __init__(self, broker_url: str) -> None:
        super().__init__()
        self._broker_url = broker_url
        # running consumers, keyed by (topic, consumer group)
        self._consumers: Dict[Tuple[str, str], AIOKafkaConsumer] = {}

    async def _start(self) -> None:
        """Create and start the producer (and any registered consumers)."""
        self._producer = AIOKafkaProducer(
            bootstrap_servers=[self._broker_url],
            value_serializer=lambda v: json.dumps(v.model_dump()).encode("utf-8"),
        )
        logger.info(f"Connecting to Kafka broker {self._broker_url} and starting a producer...")
        await self._producer.start()

        for consumer in self._consumers.values():
            await consumer.start()

    async def _stop(self) -> None:
        """Stop the producer and all consumers, and forget the consumers."""
        # _start might never have been called (or might have failed early)
        producer = getattr(self, "_producer", None)
        try:
            if producer is not None:
                await producer.stop()
        finally:
            # Work on a snapshot: message iterators remove their consumer from the
            # registry while we are awaiting here. Stopped consumers are dropped, so
            # that a later iteration creates a fresh one.
            consumers = list(self._consumers.values())
            self._consumers.clear()
            for consumer in consumers:
                await consumer.stop()

    async def _iter_messages(self, topic: str, consumer_group: str) -> AsyncIterable[Message]:
        """Yield the messages of *topic*, committing each one after it was handled.

        The offset is committed when the caller asks for the next message,
        i.e. after it finished handling the previous one. The consumer is
        stopped and unregistered when the iteration ends for any reason.
        """
        key = (topic, consumer_group)

        consumer = self._consumers.get(key)
        if consumer is None:
            # create consumer
            consumer = AIOKafkaConsumer(
                topic,
                bootstrap_servers=[self._broker_url],
                auto_offset_reset="earliest",
                group_id=consumer_group,
                enable_auto_commit=False,
                # consume only one message at a time
                max_poll_records=1,
                # one message should be consumed within 10 minutes
                max_poll_interval_ms=10 * 60 * 1000,
                session_timeout_ms=10 * 60 * 1000,
                # send heartbeat every minute
                heartbeat_interval_ms=1 * 60 * 1000,
            )
            logger.info(
                f"Connecting to Kafka broker {self._broker_url} and starting a consumer on "
                f"topic {topic}."
            )
            await consumer.start()
            self._consumers[key] = consumer

        try:
            async for message in consumer:
                yield Message(**json.loads(message.value))
                try:
                    await consumer.commit()
                except CommitFailedError as e:
                    # NOTE(review): the commit is not retried here; the failure is
                    # only logged and the iteration continues with the next message.
                    logger.error(f"Commit failed: {e}... trying again.")
        finally:
            # A stopped consumer must not be handed out again: unregister it (unless
            # the registry already holds a different consumer for this key).
            if self._consumers.get(key) is consumer:
                del self._consumers[key]
            await consumer.stop()

    async def _send(self, topic: str, message: Message) -> None:
        """Send *message* to *topic* and wait until the send has completed."""
        await self._producer.send_and_wait(
            topic,
            message,
        )
|