nerdd-link 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nerdd_link/__init__.py +4 -0
- nerdd_link/actions/__init__.py +5 -0
- nerdd_link/actions/action.py +30 -0
- nerdd_link/actions/predict_checkpoints_action.py +59 -0
- nerdd_link/actions/process_jobs_action.py +138 -0
- nerdd_link/actions/register_module_action.py +30 -0
- nerdd_link/actions/write_output_action.py +21 -0
- nerdd_link/channels/__init__.py +3 -0
- nerdd_link/channels/channel.py +114 -0
- nerdd_link/channels/kafka_channel.py +97 -0
- nerdd_link/channels/memory_channel.py +33 -0
- nerdd_link/cli/__init__.py +3 -0
- nerdd_link/cli/initialize_system.py +45 -0
- nerdd_link/cli/run_job_server.py +107 -0
- nerdd_link/cli/run_prediction_server.py +81 -0
- nerdd_link/delegates/__init__.py +3 -0
- nerdd_link/delegates/pickle_writer.py +18 -0
- nerdd_link/delegates/read_checkpoint_model.py +65 -0
- nerdd_link/delegates/read_pickle_step.py +18 -0
- nerdd_link/delegates/split_and_merge_step.py +51 -0
- nerdd_link/delegates/topic_writer.py +27 -0
- nerdd_link/input/__init__.py +1 -0
- nerdd_link/input/structure_json_reader.py +41 -0
- nerdd_link/py.typed +0 -0
- nerdd_link/tests/__init__.py +3 -0
- nerdd_link/tests/async_step.py +22 -0
- nerdd_link/tests/channels.py +82 -0
- nerdd_link/tests/files.py +9 -0
- nerdd_link/types/__init__.py +54 -0
- nerdd_link/utils/__init__.py +4 -0
- nerdd_link/utils/async_to_sync.py +26 -0
- nerdd_link/utils/batched.py +27 -0
- nerdd_link/utils/observable_list.py +72 -0
- nerdd_link/utils/safetee.py +39 -0
- nerdd_link/version.py +11 -0
- nerdd_link-0.1.0.dist-info/LICENSE +21 -0
- nerdd_link-0.1.0.dist-info/METADATA +116 -0
- nerdd_link-0.1.0.dist-info/RECORD +41 -0
- nerdd_link-0.1.0.dist-info/WHEEL +5 -0
- nerdd_link-0.1.0.dist-info/entry_points.txt +4 -0
- nerdd_link-0.1.0.dist-info/top_level.txt +1 -0
nerdd_link/actions/action.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Generic, TypeVar
|
|
3
|
+
|
|
4
|
+
from stringcase import spinalcase
|
|
5
|
+
|
|
6
|
+
from ..channels import Channel, Topic
|
|
7
|
+
from ..types import Message
|
|
8
|
+
|
|
9
|
+
# Type of the messages an action consumes; restricted to Message (sub)classes.
T = TypeVar("T", bound=Message)


class Action(ABC, Generic[T]):
    """Base class for a long-running consumer of a single topic.

    Subclasses implement ``_process_message``; ``run`` feeds every message
    received on the input topic to it, strictly one after another.
    """

    def __init__(self, input_topic: Topic[T]):
        # the topic this action listens on (also gives access to the channel)
        self._input_topic = input_topic

    async def run(self) -> None:
        """Consume the input topic and handle each message sequentially.

        The consumer group is ``_get_group_name()`` converted to spinal case,
        so subclasses decide which action instances share a group.
        """
        group = spinalcase(self._get_group_name())
        messages = self._input_topic.receive(group)
        async for incoming in messages:
            await self._process_message(incoming)

    @abstractmethod
    async def _process_message(self, message: T) -> None:
        """Handle a single message received on the input topic."""

    @property
    def channel(self) -> Channel:
        """The channel the input topic belongs to."""
        return self._input_topic.channel

    def _get_group_name(self) -> str:
        """Return the raw consumer-group name (defaults to the class name)."""
        return type(self).__name__
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from nerdd_module import Model
|
|
5
|
+
|
|
6
|
+
from ..channels import Channel
|
|
7
|
+
from ..delegates import ReadCheckpointModel
|
|
8
|
+
from ..types import CheckpointMessage
|
|
9
|
+
from .action import Action
|
|
10
|
+
|
|
11
|
+
__all__ = ["PredictCheckpointsAction"]

logger = logging.getLogger(__name__)


class PredictCheckpointsAction(Action[CheckpointMessage]):
    """Predict batches announced on the "<job-type>-checkpoints" topic.

    Each checkpoint message refers to a batch of input molecules that a
    previous step pickled to
    ``<data_dir>/jobs/<job_id>/input/checkpoint_<checkpoint_id>.pickle``.
    The batch is predicted through ``ReadCheckpointModel`` wrapping the
    configured model. That wrapper is given both the channel and the path
    ``<data_dir>/jobs/<job_id>/results/checkpoint_<checkpoint_id>.pickle``.
    Where exactly the results end up (file and/or topic) is decided inside
    ``ReadCheckpointModel``.
    """

    def __init__(self, channel: Channel, model: Model, data_dir: str) -> None:
        super().__init__(channel.checkpoints_topic(model))
        self.model = model
        self.data_dir = data_dir

    async def _process_message(self, message: CheckpointMessage) -> None:
        job_id, checkpoint_id, params = (
            message.job_id,
            message.checkpoint_id,
            message.params,
        )
        logger.info(f"Predict checkpoint {checkpoint_id} of job {job_id}")

        # all files of a job live below data_dir/jobs/<job_id>/
        job_dir = f"{self.data_dir}/jobs/{job_id}"
        file_name = f"checkpoint_{checkpoint_id}.pickle"
        results_dir = f"{job_dir}/results"

        # make sure the results directory exists before anything is written to it
        os.makedirs(results_dir, exist_ok=True)

        checkpoint_model = ReadCheckpointModel(
            base_model=self.model,
            job_id=job_id,
            checkpoint_id=checkpoint_id,
            channel=self.channel,
            checkpoints_file=f"{job_dir}/input/{file_name}",
            results_file=f"{results_dir}/{file_name}",
        )

        # input=None: the wrapper presumably reads its molecules from checkpoints_file
        checkpoint_model.predict(input=None, **params)

    def _get_group_name(self) -> str:
        # one consumer group per model class
        return type(self.model).__name__
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from pickle import dump
|
|
4
|
+
|
|
5
|
+
from nerdd_module.input import DepthFirstExplorer
|
|
6
|
+
from nerdd_module.model import ReadInputStep
|
|
7
|
+
|
|
8
|
+
from ..channels import Channel
|
|
9
|
+
from ..types import CheckpointMessage, JobMessage, LogMessage
|
|
10
|
+
from ..utils import batched
|
|
11
|
+
from .action import Action
|
|
12
|
+
|
|
13
|
+
__all__ = ["ProcessJobsAction"]

logger = logging.getLogger(__name__)


class ProcessJobsAction(Action[JobMessage]):
    """Split new jobs into checkpoint files and announce each checkpoint.

    Listens on the "jobs" topic. For every job:

    * The molecules of the job's input are read, with file access rooted at
      ``<data_dir>/sources``.
    * They are grouped into batches of ``checkpoint_size`` and each batch is
      pickled to ``<data_dir>/jobs/<job_id>/input/checkpoint_<i>.pickle``.
    * For each batch a ``CheckpointMessage`` is sent to the
      "<job_type>-checkpoints" topic.

    At most ``max_num_molecules`` molecules are processed. If the input
    contains more, a warning is sent to the "logs" topic. Finally, the number
    of processed molecules is reported on the "logs" topic as a
    "report_job_size" message.
    """

    def __init__(
        self,
        channel: Channel,
        checkpoint_size: int,
        max_num_molecules: int,
        num_test_entries: int,
        ratio_valid_entries: float,
        maximum_depth: int,
        max_num_lines_mol_block: int,
        data_dir: str,
    ) -> None:
        """
        Args:
            channel: channel to receive jobs from and to send messages to.
            checkpoint_size: number of molecules per checkpoint file.
            max_num_molecules: maximum number of molecules processed per job.
            num_test_entries: passed to DepthFirstExplorer as ``num_test_entries``.
            ratio_valid_entries: passed to DepthFirstExplorer as ``threshold``.
            maximum_depth: passed to DepthFirstExplorer as ``maximum_depth``.
            max_num_lines_mol_block: extra keyword argument for DepthFirstExplorer.
            data_dir: root directory containing ``sources/`` and ``jobs/``.
        """
        super().__init__(channel.jobs_topic())
        # relevant for chunking
        self.checkpoint_size = checkpoint_size
        self.max_num_molecules = max_num_molecules
        # parameters of DepthFirstExplorer
        self.num_test_entries = num_test_entries
        self.ratio_valid_entries = ratio_valid_entries
        self.maximum_depth = maximum_depth
        # passed to DepthFirstExplorer as an extra keyword argument
        self.max_num_lines_mol_block = max_num_lines_mol_block
        self.data_dir = data_dir

    async def _process_message(self, message: JobMessage) -> None:
        job_id = message.id
        job_type = message.job_type
        logger.info(f"Received a new job {job_id} of type {job_type}")

        # Input files are stored in data_dir/sources/. An input file may
        # reference other files; passing the sources directory as data_dir to
        # the explorer is meant to keep all reads inside that directory.
        sources_dir = os.path.join(self.data_dir, "sources")

        # create a reader (explorer) for the input file
        explorer = DepthFirstExplorer(
            num_test_entries=self.num_test_entries,
            threshold=self.ratio_valid_entries,
            maximum_depth=self.maximum_depth,
            # extra args
            max_num_lines_mol_block=self.max_num_lines_mol_block,
            data_dir=sources_dir,
        )
        read_input_step = ReadInputStep(explorer, message.source_id)

        # create a directory for the job's checkpoint files
        job_input_dir = f"{self.data_dir}/jobs/{job_id}/input"
        os.makedirs(job_input_dir, exist_ok=True)

        # read the input file
        entries = read_input_step()

        # Write batches of checkpoint_size entries, stopping once
        # max_num_molecules entries were stored. The counters are initialized
        # up front so that an empty input (no batches at all) is handled too.
        num_entries = 0
        num_checkpoints = 0
        too_many_molecules = False
        for i, batch in enumerate(batched(entries, self.checkpoint_size)):
            # max_num_molecules might be reached within the batch; clamp at 0 so
            # a non-positive limit never yields a negative slice bound
            num_store = max(0, min(len(batch), self.max_num_molecules - num_entries))
            too_many_molecules = num_store < len(batch)

            # store batch in data_dir
            with open(f"{job_input_dir}/checkpoint_{i}.pickle", "wb") as f:
                dump(batch[:num_store], f)

            # announce the checkpoint on the "<job_type>-checkpoints" topic
            await self.channel.checkpoints_topic(job_type).send(
                CheckpointMessage(
                    job_id=job_id,
                    checkpoint_id=i,
                    params=message.params,
                )
            )

            num_checkpoints = i + 1
            num_entries += num_store

            if num_entries >= self.max_num_molecules:
                break

        logger.info(
            f"Wrote {num_checkpoints} checkpoints containing {num_entries} entries for job {job_id}"
        )

        # If the last batch was not truncated, peek whether the input still has
        # more entries than were processed (entries is consumed as an iterator).
        if not too_many_molecules:
            try:
                next(entries)
                too_many_molecules = True
            except StopIteration:
                pass

        # send a warning message if there were more molecules in the job than allowed
        if too_many_molecules:
            await self.channel.logs_topic().send(
                LogMessage(
                    job_id=job_id,
                    message_type="warning",
                    message=(
                        f"The provided job contains more than "
                        f"{self.max_num_molecules} input structures. Only the "
                        f"first {self.max_num_molecules} will be processed."
                    ),
                )
            )

        # at the end, report the overall number of processed molecules on the
        # "logs" topic
        await self.channel.logs_topic().send(
            LogMessage(
                job_id=job_id,
                message_type="report_job_size",
                size=num_entries,
            )
        )
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
from nerdd_module import Model
|
|
4
|
+
from stringcase import spinalcase
|
|
5
|
+
|
|
6
|
+
from ..channels import Channel
|
|
7
|
+
from ..types import ModuleMessage, SystemMessage
|
|
8
|
+
from .action import Action
|
|
9
|
+
|
|
10
|
+
__all__ = ["RegisterModuleAction"]

logger = logging.getLogger(__name__)


class RegisterModuleAction(Action[SystemMessage]):
    """Announce the served module on the "modules" topic.

    Listens on the "system" topic. Every system message (for example the one
    sent by the ``initialize_system`` command) triggers sending the model's
    configuration as a ``ModuleMessage`` to the "modules" topic.
    """

    def __init__(self, channel: Channel, model: Model):
        super().__init__(channel.system_topic())
        # TODO: do this differently
        # The model must expose get_config(). This is checked explicitly rather
        # than with assert, because asserts are stripped when running with -O.
        if not hasattr(model, "get_config"):
            raise TypeError(
                f"Model {type(model).__name__} must provide a get_config() method."
            )
        self._model = model

    async def _process_message(self, message: SystemMessage) -> None:
        # The content of the system message is not inspected; any message
        # triggers sending the registration message.
        config = self._model.get_config()
        logger.info(f"Send registration message for module {config.name}")
        await self.channel.modules_topic().send(ModuleMessage(**config.model_dump()))

    def _get_group_name(self) -> str:
        model_name = spinalcase(self._model.__class__.__name__)
        return model_name
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from nerdd_module import Model
|
|
2
|
+
from stringcase import spinalcase
|
|
3
|
+
|
|
4
|
+
from ..channels import Channel
|
|
5
|
+
from ..types import SystemMessage
|
|
6
|
+
from .action import Action
|
|
7
|
+
|
|
8
|
+
__all__ = ["WriteOutputAction"]


class WriteOutputAction(Action[SystemMessage]):
    """Placeholder action on the "system" topic; it currently does nothing.

    NOTE(review): the consumer group (spinal-cased model class name on the
    "system" topic) is the same one RegisterModuleAction derives for the same
    model. If both run together, they may end up sharing one consumer group
    or consumer. Confirm before giving this action real work.
    """

    def __init__(self, channel: Channel, model: Model):
        super().__init__(channel.system_topic())
        self._model = model

    async def _process_message(self, message: SystemMessage) -> None:
        """Ignore the message (not implemented yet)."""

    def _get_group_name(self) -> str:
        return spinalcase(type(self._model).__name__)
|
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import AsyncIterable, Generic, TypeVar, Union, cast
|
|
5
|
+
|
|
6
|
+
from nerdd_module import Model
|
|
7
|
+
from stringcase import spinalcase # type: ignore
|
|
8
|
+
|
|
9
|
+
from ..types import (
|
|
10
|
+
CheckpointMessage,
|
|
11
|
+
JobMessage,
|
|
12
|
+
LogMessage,
|
|
13
|
+
Message,
|
|
14
|
+
ModuleMessage,
|
|
15
|
+
ResultCheckpointMessage,
|
|
16
|
+
ResultMessage,
|
|
17
|
+
SystemMessage,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
__all__ = [
    "Channel",
    "Topic",
]

# Type variable for the message type carried by a Topic; bound to Message so
# only Message (sub)classes are accepted.
T = TypeVar("T", bound=Message)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def sanitize_topic_name(name: str) -> str:
    """Lowercase *name* and keep only alphanumeric characters and dashes."""
    return "".join(c for c in name.lower() if c.isalnum() or c == "-")


def get_job_type(job_type_or_model: Union[str, Model]) -> str:
    """Return the normalized job type used to build topic names.

    Accepts either a job type string or a ``Model``, in which case the model's
    ``name`` is used. In both cases the name is:

    * converted to spinal case (e.g. "MyModel" -> "my-model"),
    * lowercased (just to be sure), and
    * stripped of all characters except dashes and alphanumerics.

    Applying the same normalization to both branches means a job type string
    and a model with the same name map to the same topic.
    """
    if isinstance(job_type_or_model, Model):
        name = job_type_or_model.name
    else:
        name = job_type_or_model
    return sanitize_topic_name(spinalcase(name))
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class Topic(Generic[T]):
    """A named topic on a channel that carries messages of type ``T``."""

    def __init__(self, channel: Channel, name: str):
        self._channel = channel
        self._name = name

    @property
    def channel(self) -> Channel:
        """The channel this topic belongs to."""
        return self._channel

    async def receive(self, consumer_group: str) -> AsyncIterable[T]:
        """Yield the messages of this topic as delivered to *consumer_group*.

        Messages are only cast to ``T`` for the type checker; no runtime
        conversion or validation happens here.
        """
        stream = self.channel.iter_messages(self._name, consumer_group)
        async for raw in stream:
            yield cast(T, raw)

    async def send(self, message: T) -> None:
        """Publish *message* on this topic."""
        await self.channel.send(self._name, message)

    def __repr__(self) -> str:
        return f"Topic({self._name})"
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class Channel(ABC):
    """Abstract message transport with a fixed set of named topics.

    Concrete channels implement ``_iter_messages`` and ``_send``. The public
    ``iter_messages`` and ``send`` methods delegate to them, and the
    ``*_topic`` methods return typed ``Topic`` handles.
    """

    # ----------------------------------------------------------------- receive

    async def iter_messages(self, topic: str, consumer_group: str) -> AsyncIterable[Message]:
        """Yield every message the implementation delivers for *topic*."""
        source = self._iter_messages(topic, consumer_group)
        async for item in source:
            yield item

    @abstractmethod
    def _iter_messages(self, topic: str, consumer_group: str) -> AsyncIterable[Message]:
        """Return an async iterable of the messages on *topic* for *consumer_group*.

        Declared with a plain ``def`` on purpose. An ``async def`` without
        ``yield`` would be typed as a coroutine returning an async iterable,
        not as the async iterable itself. The implementations in this package
        are ``async def`` generators (with ``yield``), which match this
        signature.
        """

    # -------------------------------------------------------------------- send

    async def send(self, topic: str, message: Message) -> None:
        """Publish *message* on *topic*."""
        await self._send(topic, message)

    @abstractmethod
    async def _send(self, topic: str, message: Message) -> None:
        """Transport-specific implementation of ``send``."""

    # ------------------------------------------------------------------ topics

    def modules_topic(self) -> Topic[ModuleMessage]:
        return Topic[ModuleMessage](self, "modules")

    def jobs_topic(self) -> Topic[JobMessage]:
        return Topic[JobMessage](self, "jobs")

    def checkpoints_topic(self, job_type_or_model: Union[str, Model]) -> Topic[CheckpointMessage]:
        """Topic "<job-type>-checkpoints" of the given job type or model."""
        return Topic[CheckpointMessage](
            self, f"{get_job_type(job_type_or_model)}-checkpoints"
        )

    def results_topic(self) -> Topic[ResultMessage]:
        return Topic[ResultMessage](self, "results")

    def result_checkpoints_topic(
        self, job_type_or_model: Union[str, Model]
    ) -> Topic[ResultCheckpointMessage]:
        """Topic "<job-type>-result-checkpoints" of the given job type or model."""
        return Topic[ResultCheckpointMessage](
            self, f"{get_job_type(job_type_or_model)}-result-checkpoints"
        )

    def logs_topic(self) -> Topic[LogMessage]:
        return Topic[LogMessage](self, "logs")

    def system_topic(self) -> Topic[SystemMessage]:
        return Topic[SystemMessage](self, "system")
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
from typing import AsyncIterable, Dict, Tuple
|
|
5
|
+
|
|
6
|
+
from aiokafka import AIOKafkaConsumer, AIOKafkaProducer
|
|
7
|
+
|
|
8
|
+
from ..types import Message
|
|
9
|
+
from .channel import Channel
|
|
10
|
+
|
|
11
|
+
__all__ = ["KafkaChannel"]

logger = logging.getLogger(__name__)


class KafkaChannel(Channel):
    """Channel backed by a Kafka broker (via aiokafka).

    Messages are sent as UTF-8 encoded JSON of ``message.model_dump()``. The
    instance must be created while an asyncio event loop is running, because
    the producer is started in a background task.
    """

    def __init__(self, broker_url: str) -> None:
        super().__init__()
        self._broker_url = broker_url
        self._consumers: Dict[Tuple[str, str], AIOKafkaConsumer] = {}

        # TODO: consider a value_serializer on the producer instead of encoding
        # messages manually in _send
        self._producer = AIOKafkaProducer(
            bootstrap_servers=[self._broker_url],
        )
        # Start the producer in the background and keep a reference to the task:
        # the event loop only keeps weak references to tasks, so an unreferenced
        # task may vanish mid-execution. _send awaits this task so that no
        # message is sent before the producer has started.
        self._producer_start_task = asyncio.create_task(self._producer.start())
        logger.info(f"Connecting to Kafka broker {self._broker_url} and starting a producer.")

    async def _iter_messages(self, topic: str, consumer_group: str) -> AsyncIterable[Message]:
        """Yield messages from *topic* for *consumer_group*.

        Offsets are committed manually (auto-commit is disabled), one message
        at a time, after the caller has requested the next message.
        """
        if consumer_group is not None:
            consumer_group = f"{consumer_group}-consumer-group"

        key = (topic, consumer_group)

        if key not in self._consumers:
            # create consumer
            consumer = AIOKafkaConsumer(
                topic,
                bootstrap_servers=[self._broker_url],
                auto_offset_reset="earliest",
                group_id=consumer_group,
                enable_auto_commit=False,
            )
            await consumer.start()
            self._consumers[key] = consumer
            logger.info(
                f"Connecting to Kafka broker {self._broker_url} and starting a consumer on "
                f"topic {topic}."
            )

        consumer = self._consumers[key]

        try:
            async for message in consumer:
                message_obj = json.loads(message.value)
                # NOTE(review): this builds the base Message type; whether
                # subclass-specific fields survive depends on Message's model
                # config (see nerdd_link.types) — confirm.
                yield Message(**message_obj)
                # reached only after the caller has handled the message
                await consumer.commit()
        finally:
            await consumer.stop()
            # The consumer is not restarted here. Forget it so that a later call
            # for the same topic/group creates a fresh consumer instead of
            # reusing the stopped one.
            if self._consumers.get(key) is consumer:
                del self._consumers[key]

    async def _send(self, topic: str, message: Message) -> None:
        # Wait until the producer started in __init__ is ready. Awaiting an
        # already finished task returns immediately (or re-raises its error).
        await self._producer_start_task
        await self._producer.send_and_wait(
            topic,
            json.dumps(message.model_dump()).encode("utf-8"),
        )
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import AsyncIterable, List, Tuple
|
|
3
|
+
|
|
4
|
+
from ..types import Message
|
|
5
|
+
from ..utils import ObservableList
|
|
6
|
+
from .channel import Channel
|
|
7
|
+
|
|
8
|
+
__all__ = ["MemoryChannel"]

logger = logging.getLogger(__name__)


class MemoryChannel(Channel):
    """In-process channel that records every sent message in an ObservableList.

    Intended for tests: sent messages can be inspected via
    ``get_produced_messages``. Consumer groups are not used by this channel.
    """

    def __init__(self) -> None:
        super().__init__()
        # (topic, message) pairs in the order they were sent
        self._messages = ObservableList[Tuple[str, Message]]()

    def get_produced_messages(self) -> List[Tuple[str, Message]]:
        """Return the (topic, message) pairs sent through this channel."""
        return self._messages.get_items()

    async def _iter_messages(self, topic: str, consumer_group: str) -> AsyncIterable[Message]:
        # consumer_group is intentionally not used; only the topic filters messages
        async for _old, entry in self._messages.changes():
            assert entry is not None
            entry_topic, payload = entry
            if entry_topic != topic:
                continue
            yield payload

    async def _send(self, topic: str, message: Message) -> None:
        logger.info(f"Send message to topic {topic}")
        self._messages.append((topic, message))

    async def stop(self) -> None:
        """Stop the underlying ObservableList.

        NOTE(review): presumably this ends pending ``changes()`` iterators; confirm.
        """
        await self._messages.stop()
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
|
|
3
|
+
import rich_click as click
|
|
4
|
+
|
|
5
|
+
from ..channels import KafkaChannel
|
|
6
|
+
from ..types import SystemMessage
|
|
7
|
+
from ..utils import async_to_sync
|
|
8
|
+
|
|
9
|
+
__all__ = ["initialize_system"]

logger = logging.getLogger(__name__)


@click.command(context_settings={"show_default": True})
@click.option(
    "--channel",
    type=click.Choice(["kafka"], case_sensitive=False),
    default="kafka",
    help="Channel to use for communication with the model.",
)
@click.option("--broker-url", default="localhost:9092", help="Kafka broker to connect to.")
@click.option(
    "--log-level",
    default="info",
    type=click.Choice(["debug", "info", "warning", "error", "critical"], case_sensitive=False),
    help="The logging level.",
)
@async_to_sync
async def initialize_system(
    # communication options
    channel: str,
    broker_url: str,
    # log level
    log_level: str,
) -> None:
    """Send a system initialization message to the "system" topic.

    Actions listening on the "system" topic react to this message; for
    example, RegisterModuleAction answers by announcing its module.

    Raises:
        ValueError: if *channel* is not a supported channel type.
    """
    logging.basicConfig(level=log_level.upper())

    if channel == "kafka":
        channel_instance = KafkaChannel(broker_url)
    else:
        raise ValueError(f"Channel {channel} not supported.")

    # use the module logger (instead of the root logger) like the rest of the package
    logger.info("Sending the system initialization message...")
    await channel_instance.system_topic().send(SystemMessage())
|