nerdd-link 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52):
  1. nerdd_link/__init__.py +6 -0
  2. nerdd_link/actions/__init__.py +5 -0
  3. nerdd_link/actions/action.py +41 -0
  4. nerdd_link/actions/predict_checkpoints_action.py +76 -0
  5. nerdd_link/actions/process_jobs_action.py +157 -0
  6. nerdd_link/actions/register_module_action.py +37 -0
  7. nerdd_link/actions/serialize_job_action.py +64 -0
  8. nerdd_link/channels/__init__.py +3 -0
  9. nerdd_link/channels/channel.py +147 -0
  10. nerdd_link/channels/kafka_channel.py +109 -0
  11. nerdd_link/channels/memory_channel.py +47 -0
  12. nerdd_link/cli/__init__.py +4 -0
  13. nerdd_link/cli/initialize_system.py +49 -0
  14. nerdd_link/cli/run_job_server.py +115 -0
  15. nerdd_link/cli/run_prediction_server.py +85 -0
  16. nerdd_link/cli/run_serialization_server.py +73 -0
  17. nerdd_link/converters/__init__.py +6 -0
  18. nerdd_link/converters/image_converter.py +15 -0
  19. nerdd_link/converters/mol_pickle_converter.py +20 -0
  20. nerdd_link/converters/mol_to_image_converter.py +77 -0
  21. nerdd_link/converters/pickle_converter.py +15 -0
  22. nerdd_link/converters/problem_list_converter.py +15 -0
  23. nerdd_link/converters/source_list_converter.py +15 -0
  24. nerdd_link/delegates/__init__.py +5 -0
  25. nerdd_link/delegates/pickle_writer.py +18 -0
  26. nerdd_link/delegates/read_checkpoint_model.py +73 -0
  27. nerdd_link/delegates/read_pickle_step.py +21 -0
  28. nerdd_link/delegates/serialize_job_model.py +21 -0
  29. nerdd_link/delegates/split_and_merge_step.py +51 -0
  30. nerdd_link/delegates/topic_writer.py +19 -0
  31. nerdd_link/files/__init__.py +1 -0
  32. nerdd_link/files/file_system.py +89 -0
  33. nerdd_link/input/__init__.py +1 -0
  34. nerdd_link/input/structure_json_reader.py +37 -0
  35. nerdd_link/py.typed +0 -0
  36. nerdd_link/tests/__init__.py +3 -0
  37. nerdd_link/tests/async_step.py +22 -0
  38. nerdd_link/tests/channels.py +81 -0
  39. nerdd_link/tests/files.py +9 -0
  40. nerdd_link/types/__init__.py +70 -0
  41. nerdd_link/utils/__init__.py +4 -0
  42. nerdd_link/utils/async_to_sync.py +26 -0
  43. nerdd_link/utils/batched.py +27 -0
  44. nerdd_link/utils/observable_list.py +72 -0
  45. nerdd_link/utils/safetee.py +39 -0
  46. nerdd_link/version.py +11 -0
  47. nerdd_link-0.2.11.dist-info/LICENSE +21 -0
  48. nerdd_link-0.2.11.dist-info/METADATA +116 -0
  49. nerdd_link-0.2.11.dist-info/RECORD +52 -0
  50. nerdd_link-0.2.11.dist-info/WHEEL +5 -0
  51. nerdd_link-0.2.11.dist-info/entry_points.txt +5 -0
  52. nerdd_link-0.2.11.dist-info/top_level.txt +1 -0
nerdd_link/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from .actions import *
2
+ from .channels import *
3
+ from .converters import *
4
+ from .files import *
5
+ from .input import *
6
+ from .types import *
@@ -0,0 +1,5 @@
1
+ from .action import *
2
+ from .predict_checkpoints_action import *
3
+ from .process_jobs_action import *
4
+ from .register_module_action import *
5
+ from .serialize_job_action import *
@@ -0,0 +1,41 @@
1
+ import logging
2
+ from abc import ABC, abstractmethod
3
+ from asyncio import CancelledError
4
+ from typing import Generic, TypeVar
5
+
6
+ from stringcase import spinalcase
7
+
8
+ from ..channels import Channel, Topic
9
+ from ..types import Message
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ T = TypeVar("T", bound=Message)
14
+
15
+
16
+ class Action(ABC, Generic[T]):
17
+ def __init__(self, input_topic: Topic[T]):
18
+ self._input_topic = input_topic
19
+
20
+ async def run(self) -> None:
21
+ consumer_group = spinalcase(self._get_group_name())
22
+ async for message in self._input_topic.receive(consumer_group):
23
+ try:
24
+ await self._process_message(message)
25
+ except CancelledError:
26
+ # the consumer was cancelled, stop processing messages
27
+ break
28
+ except Exception:
29
+ # log the error and continue processing the next message
30
+ logger.error("Error processing message", exc_info=True)
31
+
32
+ @abstractmethod
33
+ async def _process_message(self, message: T) -> None:
34
+ pass
35
+
36
+ @property
37
+ def channel(self) -> Channel:
38
+ return self._input_topic.channel
39
+
40
+ def _get_group_name(self) -> str:
41
+ return self.__class__.__name__
@@ -0,0 +1,76 @@
1
+ import logging
2
+ from multiprocessing import Queue
3
+
4
+ from nerdd_module import SimpleModel
5
+
6
+ from ..channels import Channel
7
+ from ..delegates import ReadCheckpointModel
8
+ from ..files import FileSystem
9
+ from ..types import CheckpointMessage, ResultCheckpointMessage, ResultMessage
10
+ from .action import Action
11
+
12
__all__ = ["PredictCheckpointsAction"]

logger = logging.getLogger(__name__)


class PredictCheckpointsAction(Action[CheckpointMessage]):
    """Run the prediction for one checkpoint (batch of input molecules) of a job.

    Accepts checkpoint messages on the "<job-type>-checkpoints" topic (generated
    in the previous step) and predicts the referenced checkpoint with the wrapped
    model. Every result record is sent to the "results" topic; afterwards the
    finished checkpoint is reported on the "result-checkpoints" topic.
    """

    def __init__(self, channel: Channel, model: SimpleModel, data_dir: str) -> None:
        super().__init__(channel.checkpoints_topic(model))
        self._model = model
        self._data_dir = data_dir

    async def _process_message(self, message: CheckpointMessage) -> None:
        # local import keeps the module-level imports untouched
        import asyncio

        job_id = message.job_id
        checkpoint_id = message.checkpoint_id
        params = message.params
        logger.info(f"Predict checkpoint {checkpoint_id} of job {job_id}")

        # The Kafka consumers and producers run in the current asyncio event loop and (by
        # observation) it seems that calling the produce method of a Kafka producer in a different
        # event loop / thread / process doesn't seem to work (hangs indefinitely). Therefore, we
        # create a queue in this event loop / thread and other tasks send messages to the queue
        # instead of directly to the Kafka producer. This event loop will wait for new messages in
        # this queue and forward them to the Kafka producer.
        queue: Queue = Queue()

        file_system = FileSystem(self._data_dir)

        # create a wrapper model that
        # * reads the checkpoint file instead of normal input
        # * does preprocessing, prediction, and postprocessing like the encapsulated model
        # * does not write to the specified results file, but to the checkpoints file instead
        # * sends the results to the results topic
        model = ReadCheckpointModel(
            base_model=self._model,
            job_id=job_id,
            file_system=file_system,
            checkpoint_id=checkpoint_id,
            queue=queue,
        )

        # predict the checkpoint
        # assign input=None, because the checkpoint file is already provided in ReadCheckpointModel
        model.predict(
            input=None,
            **params,
        )

        # Forward records from the queue until the end marker (None) arrives.
        # multiprocessing.Queue.get blocks the calling thread, so it runs in the
        # default executor. This keeps the event loop (and everything else that
        # runs in it) responsive while waiting. Sending still happens here, in
        # the event loop that owns the producer.
        # NOTE(review): if this task is cancelled while waiting, the executor
        # thread stays blocked in queue.get until another item arrives.
        loop = asyncio.get_running_loop()
        results_topic = self.channel.results_topic()
        while True:
            record = await loop.run_in_executor(None, queue.get)
            if record is None:
                break
            await results_topic.send(ResultMessage(job_id=job_id, **record))

        await self.channel.result_checkpoints_topic().send(
            ResultCheckpointMessage(job_id=job_id, checkpoint_id=checkpoint_id)
        )

    def _get_group_name(self) -> str:
        """Use one consumer group per model, so that each model sees every checkpoint."""
        model_id = self._model.get_config().id
        return f"predict-checkpoints-{model_id}"
@@ -0,0 +1,157 @@
1
+ import logging
2
+ from pickle import dump
3
+ from typing import Any
4
+
5
+ from nerdd_module.input import DepthFirstExplorer
6
+ from nerdd_module.model import ReadInputStep
7
+ from rdkit.Chem import Mol
8
+ from rdkit.Chem.PropertyMol import PropertyMol
9
+
10
+ from ..channels import Channel
11
+ from ..files import FileSystem
12
+ from ..types import CheckpointMessage, JobMessage, LogMessage
13
+ from ..utils import batched
14
+ from .action import Action
15
+
16
__all__ = ["ProcessJobsAction"]

logger = logging.getLogger(__name__)


def _wrap_mol(value: Any) -> Any:
    """Wrap RDKit molecules in PropertyMol so that molecular properties survive pickling."""
    if isinstance(value, Mol):
        return PropertyMol(value)
    return value


def _wrap_mols_in_record(record: dict) -> dict:
    """Return a copy of *record* with every RDKit molecule value wrapped in PropertyMol."""
    return {key: _wrap_mol(value) for key, value in record.items()}


class ProcessJobsAction(Action[JobMessage]):
    """Split new jobs into checkpoint files and announce the checkpoints.

    Accepts new jobs on the "jobs" topic. For each job, the input (files) is
    read and the entries are written in batches of ``checkpoint_size`` into
    pickled checkpoint files. A checkpoint message is sent for every batch to
    the "<job_type>-checkpoints" topic. At most ``max_num_molecules`` entries
    are processed; a warning is sent to the "logs" topic if the input contains
    more. Finally, the number of entries and checkpoints is reported on the
    "logs" topic.
    """

    def __init__(
        self,
        channel: Channel,
        checkpoint_size: int,
        max_num_molecules: int,
        num_test_entries: int,
        ratio_valid_entries: float,
        maximum_depth: int,
        max_num_lines_mol_block: int,
        data_dir: str,
    ) -> None:
        """
        Args:
            channel: channel used to receive jobs and send follow-up messages.
            checkpoint_size: number of entries per checkpoint file.
            max_num_molecules: maximum number of entries processed per job.
            num_test_entries: passed to DepthFirstExplorer as ``num_test_entries``.
            ratio_valid_entries: passed to DepthFirstExplorer as ``threshold``.
            maximum_depth: passed to DepthFirstExplorer as ``maximum_depth``.
            max_num_lines_mol_block: extra keyword argument for DepthFirstExplorer.
            data_dir: root directory of the file system used for sources and
                checkpoint files.
        """
        super().__init__(channel.jobs_topic())
        # relevant for chunking
        self._checkpoint_size = checkpoint_size
        self._max_num_molecules = max_num_molecules
        # parameters of DepthFirstExplorer
        self._num_test_entries = num_test_entries
        self._ratio_valid_entries = ratio_valid_entries
        self._maximum_depth = maximum_depth
        # used as kwargs in DepthFirstExplorer
        self._max_num_lines_mol_block = max_num_lines_mol_block
        self._file_system = FileSystem(data_dir)

    async def _process_message(self, message: JobMessage) -> None:
        job_id = message.id
        job_type = message.job_type
        logger.info(f"Received a new job {job_id} of type {job_type}")

        # the input file to the job is stored in a designated sources directory
        # (the file is allowed to reference other files, but setting the data_dir
        # to the sources directory ensures that we never read files outside of the
        # sources directory)
        data_dir = self._file_system.get_sources_dir()

        # create a reader (explorer) for the input file
        explorer = DepthFirstExplorer(
            num_test_entries=self._num_test_entries,
            threshold=self._ratio_valid_entries,
            maximum_depth=self._maximum_depth,
            # extra args
            max_num_lines_mol_block=self._max_num_lines_mol_block,
            data_dir=data_dir,
        )

        read_input_step = ReadInputStep(explorer, message.source_id)

        # read the input file
        entries = read_input_step()

        # Iterate through the entries in batches of checkpoint_size, limiting the
        # number of stored entries to max_num_molecules. All counters are
        # initialized up front so that an empty input does not leave names unbound.
        checkpoints_topic = self.channel.checkpoints_topic(job_type)
        num_entries = 0
        num_checkpoints = 0
        too_many_molecules = False
        for checkpoint_id, batch in enumerate(batched(entries, self._checkpoint_size)):
            # max_num_molecules might be reached within the batch (clamped at 0 so
            # that a negative slice bound never selects entries from the end)
            num_store = max(0, min(len(batch), self._max_num_molecules - num_entries))
            if num_store < len(batch):
                too_many_molecules = True

            # TODO: use a model for storing the batches
            records = [_wrap_mols_in_record(item) for item in batch[:num_store]]

            # store batch in data_dir
            with self._file_system.get_checkpoint_file_handle(job_id, checkpoint_id, "wb") as f:
                dump(records, f)

            # announce the checkpoint on the "<job_type>-checkpoints" topic
            await checkpoints_topic.send(
                CheckpointMessage(
                    job_id=job_id,
                    checkpoint_id=checkpoint_id,
                    params=message.params,
                )
            )

            num_entries += num_store
            num_checkpoints += 1

            if num_entries >= self._max_num_molecules:
                break

        logger.info(
            f"Wrote {num_checkpoints} checkpoints containing {num_entries} entries for job {job_id}"
        )

        # if the loop stopped at the limit, the input might still contain entries
        try:
            # try to get another entry
            next(entries)

            # if we get here, there was another entry and we need to send a warning
            too_many_molecules = True
        except StopIteration:
            pass

        # send a warning message if there were more molecules in the job than allowed
        if too_many_molecules:
            await self.channel.logs_topic().send(
                LogMessage(
                    job_id=job_id,
                    message_type="warning",
                    message=(
                        f"The provided job contains more than "
                        f"{self._max_num_molecules} input structures. Only the "
                        f"first {self._max_num_molecules} will be processed."
                    ),
                )
            )

        # at the end, report the overall size of the job on the logs topic
        await self.channel.logs_topic().send(
            LogMessage(
                job_id=job_id,
                message_type="report_job_size",
                num_entries=num_entries,
                num_checkpoints=num_checkpoints,
            )
        )
@@ -0,0 +1,37 @@
1
+ import json
2
+ import logging
3
+
4
+ from nerdd_module import Model
5
+
6
+ from ..channels import Channel
7
+ from ..files import FileSystem
8
+ from ..types import ModuleMessage, SystemMessage
9
+ from .action import Action
10
+
11
__all__ = ["RegisterModuleAction"]

logger = logging.getLogger(__name__)


class RegisterModuleAction(Action[SystemMessage]):
    """Register the wrapped module whenever a message arrives on the "system" topic.

    The module configuration is written as JSON to the module file in the data
    directory, and the module id is then announced on the "modules" topic.
    """

    def __init__(self, channel: Channel, model: Model, data_dir: str):
        super().__init__(channel.system_topic())
        # TODO: do this differently
        # NOTE(review): assert is stripped under `python -O`
        assert hasattr(model, "get_config")
        self._model = model
        self._file_system = FileSystem(data_dir)

    async def _process_message(self, message: SystemMessage) -> None:
        config = self._model.get_config()
        logger.info(f"Registering module with id {config.id}")

        # Save the module as JSON to a file. The context manager makes sure the
        # file is flushed and closed before the module is announced below.
        module_file = self._file_system.get_module_file_path(config.id)
        with open(module_file, "w") as f:
            json.dump(config.model_dump(), f)

        # send the initialization message
        await self.channel.modules_topic().send(ModuleMessage(id=config.id))

    def _get_group_name(self) -> str:
        """Use one consumer group per module, so that every module registers itself."""
        model_id = self._model.get_config().id
        return f"register-module-{model_id}"
@@ -0,0 +1,64 @@
1
+ import json
2
+ import logging
3
+
4
+ from nerdd_module import OutputStep
5
+
6
+ from ..channels import Channel
7
+ from ..delegates import ReadPickleStep, SerializeJobModel
8
+ from ..files import FileSystem
9
+ from ..types import SerializationRequestMessage, SerializationResultMessage
10
+ from .action import Action
11
+
12
__all__ = ["SerializeJobAction"]


logger = logging.getLogger(__name__)


class SerializeJobAction(Action[SerializationRequestMessage]):
    """Write the results of a job in a requested output format.

    Accepts requests on the "serialization-requests" topic. For each request,
    the stored result checkpoint files are run through the post-processing
    steps of the job's module, and completion is reported on the
    "serialization-results" topic.
    """

    def __init__(self, channel: Channel, data_dir: str) -> None:
        super().__init__(channel.serialization_requests_topic())
        self._file_system = FileSystem(data_dir)

    async def _process_message(self, message: SerializationRequestMessage) -> None:
        job_id = message.job_id
        job_type = message.job_type
        output_format = message.output_format
        logger.info(f"Write output for job {job_id} in format {output_format}")

        # Work on a copy so that the incoming message is not mutated, and remove
        # specific parameter keys that could induce vulnerabilities (the output
        # file and format are determined by this action below).
        params = dict(message.params)
        params.pop("output_file", None)
        params.pop("output_format", None)

        # obtain output file
        output_file = self._file_system.get_output_file(job_id, output_format)

        # get the configuration for the job_type (the file is closed right away)
        config_file = self._file_system.get_module_file_path(job_type)
        with open(config_file, "r") as f:
            config = json.load(f)

        # create a fake model instance to get the postprocessing steps
        model = SerializeJobModel(config)

        steps = [
            # read the result checkpoint files in the correct order
            ReadPickleStep(self._file_system.iter_results_file_handles(job_id)),
            # don't preprocess, don't do prediction, only post-process
            *model._get_postprocessing_steps(output_format, output_file=output_file, **params),
        ]

        output_step = steps[-1]
        assert isinstance(output_step, OutputStep), "The last step must be an OutputStep."

        # build the pipeline: each step receives the previous one (None for the first)
        pipeline = None
        for step in steps:
            pipeline = step(pipeline)

        # run the pipeline by calling the get_result method of the last step
        output_step.get_result()

        await self.channel.serialization_results_topic().send(
            SerializationResultMessage(job_id=job_id, output_format=output_format)
        )
@@ -0,0 +1,3 @@
1
+ from .channel import *
2
+ from .kafka_channel import *
3
+ from .memory_channel import *
@@ -0,0 +1,147 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import AsyncIterable, Generic, TypeVar, Union, cast
5
+
6
+ from nerdd_module import Model
7
+ from stringcase import spinalcase
8
+
9
+ from ..types import (
10
+ CheckpointMessage,
11
+ JobMessage,
12
+ LogMessage,
13
+ Message,
14
+ ModuleMessage,
15
+ ResultCheckpointMessage,
16
+ ResultMessage,
17
+ SerializationRequestMessage,
18
+ SerializationResultMessage,
19
+ SystemMessage,
20
+ )
21
+
22
__all__ = ["Channel", "Topic"]

T = TypeVar("T", bound=Message)


def get_job_type(job_type_or_model: Union[str, Model]) -> str:
    """Return the job type (the prefix of job-specific topic names).

    For a model, the job type is derived from ``model.name``. The name is
    converted to spinal case (e.g. "MyModel" -> "my-model"), lowercased (just to
    be sure), and stripped of every character that is neither alphanumeric nor a
    dash. A plain string is only converted to spinal case.
    """
    # TODO: move to Module Id
    if not isinstance(job_type_or_model, Model):
        return spinalcase(job_type_or_model)

    spinal_name = spinalcase(job_type_or_model.name).lower()
    return "".join(char for char in spinal_name if char.isalnum() or char == "-")
42
+
43
+
44
class Topic(Generic[T]):
    """A named topic of a channel, typed with the kind of message it carries.

    Receiving and sending are delegated to the owning channel. The type
    parameter is only used to cast received messages for type checkers.
    """

    def __init__(self, channel: Channel, name: str):
        self._channel = channel
        self._name = name

    @property
    def channel(self) -> Channel:
        """The channel this topic belongs to."""
        return self._channel

    async def receive(self, consumer_group: str) -> AsyncIterable[T]:
        """Yield the messages of this topic, consumed as part of *consumer_group*."""
        incoming = self.channel.iter_messages(self._name, consumer_group)
        async for item in incoming:
            yield cast(T, item)

    async def send(self, message: T) -> None:
        """Publish *message* on this topic."""
        channel = self.channel
        await channel.send(self._name, message)

    def __repr__(self) -> str:
        return "Topic({})".format(self._name)
62
+
63
+
64
class Channel(ABC):
    """Abstract message channel that provides the named topics used by nerdd-link.

    Subclasses implement ``_iter_messages`` and ``_send``. They may also
    override the ``_start``/``_stop`` hooks. Receiving and sending are only
    allowed while the channel is running (between ``start`` and ``stop``, or
    inside ``async with``).
    """

    def __init__(self) -> None:
        self._is_running = False

    async def start(self) -> None:
        """Start the channel.

        The channel is marked as running before ``_start`` is awaited. If
        ``_start`` raises, the flag is reset and the exception propagates.
        """
        self._is_running = True
        try:
            await self._start()
        except BaseException:
            self._is_running = False
            raise

    async def _start(self) -> None:  # noqa: B027
        """Hook for subclasses: acquire resources when the channel starts."""
        pass

    async def stop(self) -> None:
        """Stop the channel. It is marked as not running even if ``_stop`` raises."""
        try:
            await self._stop()
        finally:
            self._is_running = False

    async def _stop(self) -> None:  # noqa: B027
        """Hook for subclasses: release resources when the channel stops."""
        pass

    async def __aenter__(self) -> Channel:
        await self.start()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: object,
    ) -> None:
        await self.stop()

    #
    # RECEIVE
    #
    async def iter_messages(self, topic: str, consumer_group: str) -> AsyncIterable[Message]:
        """Yield messages of *topic* for *consumer_group*.

        Raises:
            RuntimeError: if the channel is not running.
        """
        if not self._is_running:
            raise RuntimeError("Channel is not running. Call start() first.")
        async for message in self._iter_messages(topic, consumer_group):
            yield message

    # This is declared with a plain "def" instead of "async def" on purpose.
    # The abstract body contains no "yield", so an "async def" would be typed as
    # Coroutine[..., AsyncIterable[Message]] instead of AsyncIterable[Message].
    # Implementations may still be async generators ("async def" with "yield").
    @abstractmethod
    def _iter_messages(self, topic: str, consumer_group: str) -> AsyncIterable[Message]:
        pass

    #
    # SEND
    #
    async def send(self, topic: str, message: Message) -> None:
        """Send *message* to *topic*.

        Raises:
            RuntimeError: if the channel is not running.
        """
        if not self._is_running:
            raise RuntimeError("Channel is not running. Call start() first.")
        await self._send(topic, message)

    @abstractmethod
    async def _send(self, topic: str, message: Message) -> None:
        pass

    #
    # TOPICS
    #
    def modules_topic(self) -> Topic[ModuleMessage]:
        return Topic[ModuleMessage](self, "modules")

    def jobs_topic(self) -> Topic[JobMessage]:
        return Topic[JobMessage](self, "jobs")

    def checkpoints_topic(self, job_type_or_model: Union[str, Model]) -> Topic[CheckpointMessage]:
        """Return the job-type specific "<job-type>-checkpoints" topic."""
        job_type = get_job_type(job_type_or_model)
        topic_name = f"{job_type}-checkpoints"
        return Topic[CheckpointMessage](self, topic_name)

    def results_topic(self) -> Topic[ResultMessage]:
        return Topic[ResultMessage](self, "results")

    def result_checkpoints_topic(self) -> Topic[ResultCheckpointMessage]:
        return Topic[ResultCheckpointMessage](self, "result-checkpoints")

    def serialization_requests_topic(self) -> Topic[SerializationRequestMessage]:
        return Topic[SerializationRequestMessage](self, "serialization-requests")

    def serialization_results_topic(self) -> Topic[SerializationResultMessage]:
        return Topic[SerializationResultMessage](self, "serialization-results")

    def logs_topic(self) -> Topic[LogMessage]:
        return Topic[LogMessage](self, "logs")

    def system_topic(self) -> Topic[SystemMessage]:
        return Topic[SystemMessage](self, "system")
@@ -0,0 +1,109 @@
1
+ import json
2
+ import logging
3
+ from typing import AsyncIterable, Dict, Tuple
4
+
5
+ from aiokafka import AIOKafkaConsumer, AIOKafkaProducer
6
+ from aiokafka.errors import CommitFailedError
7
+
8
+ from ..types import Message
9
+ from .channel import Channel
10
+
11
__all__ = ["KafkaChannel"]

logger = logging.getLogger(__name__)


class KafkaChannel(Channel):
    """Channel implementation backed by a Kafka broker (via aiokafka).

    A single producer (created in ``_start``) is shared by all topics. Messages
    are serialized as the JSON of their ``model_dump()``. One consumer is
    created lazily for every (topic, consumer group) pair that is iterated.
    """

    def __init__(self, broker_url: str) -> None:
        super().__init__()
        self._broker_url = broker_url
        # one consumer per (topic, consumer group), created in _iter_messages
        self._consumers: Dict[Tuple[str, str], AIOKafkaConsumer] = {}

    async def _start(self) -> None:
        self._producer = AIOKafkaProducer(
            bootstrap_servers=[self._broker_url],
            value_serializer=lambda v: json.dumps(v.model_dump()).encode("utf-8"),
        )
        logger.info(f"Connecting to Kafka broker {self._broker_url} and starting a producer...")
        await self._producer.start()

        # iterate over a snapshot: the dict may change while we await
        for consumer in list(self._consumers.values()):
            await consumer.start()

    async def _stop(self) -> None:
        await self._producer.stop()
        # Iterate over a snapshot. Stopping a consumer can finish a running
        # _iter_messages generator, and its cleanup removes entries from the dict.
        for consumer in list(self._consumers.values()):
            await consumer.stop()

    async def _iter_messages(self, topic: str, consumer_group: str) -> AsyncIterable[Message]:
        key = (topic, consumer_group)

        if key not in self._consumers:
            # create consumer
            consumer = AIOKafkaConsumer(
                topic,
                bootstrap_servers=[self._broker_url],
                auto_offset_reset="earliest",
                group_id=consumer_group,
                enable_auto_commit=False,
                # consume only one message at a time
                max_poll_records=1,
                # one message should be consumed within 10 minutes
                max_poll_interval_ms=10 * 60 * 1000,
                session_timeout_ms=10 * 60 * 1000,
                # send heartbeat every minute
                heartbeat_interval_ms=1 * 60 * 1000,
            )
            await consumer.start()
            self._consumers[key] = consumer
            logger.info(
                f"Connecting to Kafka broker {self._broker_url} and starting a consumer on "
                f"topic {topic}."
            )

        consumer = self._consumers[key]

        try:
            async for record in consumer:
                yield Message(**json.loads(record.value))
                # The generator resumes only when the caller asks for the next
                # message, i.e. after it has handled this one. Committing here
                # gives at-least-once processing.
                try:
                    await consumer.commit()
                except CommitFailedError as e:
                    # the commit is not retried; the offset may be committed by a
                    # later successful commit, otherwise the message is redelivered
                    logger.error(f"Commit failed: {e}. The message may be delivered again.")
        finally:
            await consumer.stop()
            # Drop the stopped consumer so that a later iteration on the same
            # topic and consumer group creates a fresh consumer instead of
            # reusing a stopped one.
            if self._consumers.get(key) is consumer:
                del self._consumers[key]

    async def _send(self, topic: str, message: Message) -> None:
        await self._producer.send_and_wait(
            topic,
            message,
        )