nerdd-link 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41):
  1. nerdd_link/__init__.py +4 -0
  2. nerdd_link/actions/__init__.py +5 -0
  3. nerdd_link/actions/action.py +30 -0
  4. nerdd_link/actions/predict_checkpoints_action.py +59 -0
  5. nerdd_link/actions/process_jobs_action.py +138 -0
  6. nerdd_link/actions/register_module_action.py +30 -0
  7. nerdd_link/actions/write_output_action.py +21 -0
  8. nerdd_link/channels/__init__.py +3 -0
  9. nerdd_link/channels/channel.py +114 -0
  10. nerdd_link/channels/kafka_channel.py +97 -0
  11. nerdd_link/channels/memory_channel.py +33 -0
  12. nerdd_link/cli/__init__.py +3 -0
  13. nerdd_link/cli/initialize_system.py +45 -0
  14. nerdd_link/cli/run_job_server.py +107 -0
  15. nerdd_link/cli/run_prediction_server.py +81 -0
  16. nerdd_link/delegates/__init__.py +3 -0
  17. nerdd_link/delegates/pickle_writer.py +18 -0
  18. nerdd_link/delegates/read_checkpoint_model.py +65 -0
  19. nerdd_link/delegates/read_pickle_step.py +18 -0
  20. nerdd_link/delegates/split_and_merge_step.py +51 -0
  21. nerdd_link/delegates/topic_writer.py +27 -0
  22. nerdd_link/input/__init__.py +1 -0
  23. nerdd_link/input/structure_json_reader.py +41 -0
  24. nerdd_link/py.typed +0 -0
  25. nerdd_link/tests/__init__.py +3 -0
  26. nerdd_link/tests/async_step.py +22 -0
  27. nerdd_link/tests/channels.py +82 -0
  28. nerdd_link/tests/files.py +9 -0
  29. nerdd_link/types/__init__.py +54 -0
  30. nerdd_link/utils/__init__.py +4 -0
  31. nerdd_link/utils/async_to_sync.py +26 -0
  32. nerdd_link/utils/batched.py +27 -0
  33. nerdd_link/utils/observable_list.py +72 -0
  34. nerdd_link/utils/safetee.py +39 -0
  35. nerdd_link/version.py +11 -0
  36. nerdd_link-0.1.0.dist-info/LICENSE +21 -0
  37. nerdd_link-0.1.0.dist-info/METADATA +116 -0
  38. nerdd_link-0.1.0.dist-info/RECORD +41 -0
  39. nerdd_link-0.1.0.dist-info/WHEEL +5 -0
  40. nerdd_link-0.1.0.dist-info/entry_points.txt +4 -0
  41. nerdd_link-0.1.0.dist-info/top_level.txt +1 -0
nerdd_link/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .actions import *
2
+ from .channels import *
3
+ from .input import *
4
+ from .types import *
@@ -0,0 +1,5 @@
1
+ from .action import *
2
+ from .predict_checkpoints_action import *
3
+ from .process_jobs_action import *
4
+ from .register_module_action import *
5
+ from .write_output_action import *
@@ -0,0 +1,30 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Generic, TypeVar
3
+
4
+ from stringcase import spinalcase
5
+
6
+ from ..channels import Channel, Topic
7
+ from ..types import Message
8
+
9
T = TypeVar("T", bound=Message)


class Action(ABC, Generic[T]):
    """Base class for a consumer of a single input topic.

    Subclasses implement :meth:`_process_message`, which is awaited once for
    each message delivered on the input topic, and may override
    :meth:`_get_group_name` to choose the consumer group.
    """

    def __init__(self, input_topic: Topic[T]):
        self._input_topic = input_topic

    @property
    def channel(self) -> Channel:
        """The channel the input topic belongs to."""
        return self._input_topic.channel

    async def run(self) -> None:
        """Process messages of the input topic until its message stream ends."""
        group = spinalcase(self._get_group_name())
        incoming = self._input_topic.receive(group)
        async for msg in incoming:
            await self._process_message(msg)

    @abstractmethod
    async def _process_message(self, message: T) -> None:
        """Handle one message received on the input topic."""

    def _get_group_name(self) -> str:
        """Consumer group name (spinal-cased by :meth:`run`); the class name by default."""
        return type(self).__name__
@@ -0,0 +1,59 @@
1
+ import logging
2
+ import os
3
+
4
+ from nerdd_module import Model
5
+
6
+ from ..channels import Channel
7
+ from ..delegates import ReadCheckpointModel
8
+ from ..types import CheckpointMessage
9
+ from .action import Action
10
+
11
__all__ = ["PredictCheckpointsAction"]

logger = logging.getLogger(__name__)


class PredictCheckpointsAction(Action[CheckpointMessage]):
    """Run the model on one checkpoint (a batch of input molecules) at a time.

    Consumes the "<job-type>-checkpoints" topic of ``model``; ProcessJobsAction
    sends one message to that topic per checkpoint file it writes. For each
    message, a ReadCheckpointModel wrapping ``model`` is built for the
    checkpoint's input and results files, and its ``predict`` is invoked with
    the job parameters from the message.
    """

    def __init__(self, channel: Channel, model: Model, data_dir: str) -> None:
        super().__init__(channel.checkpoints_topic(model))
        self.model = model
        self.data_dir = data_dir

    async def _process_message(self, message: CheckpointMessage) -> None:
        job_id = message.job_id
        checkpoint_id = message.checkpoint_id
        params = message.params
        logger.info(f"Predict checkpoint {checkpoint_id} of job {job_id}")

        # The batch of a checkpoint is stored (by ProcessJobsAction) in
        # data_dir/jobs/<job_id>/input/checkpoint_<checkpoint_id>.pickle, and the
        # results file of the same checkpoint goes to data_dir/jobs/<job_id>/results/.
        # The paths are built exactly like in ProcessJobsAction so both sides agree.
        job_dir = f"{self.data_dir}/jobs/{job_id}"
        results_dir = f"{job_dir}/results"
        file_name = f"checkpoint_{checkpoint_id}.pickle"
        checkpoints_file = f"{job_dir}/input/{file_name}"
        checkpoint_results_file = f"{results_dir}/{file_name}"

        # create the results directory
        os.makedirs(results_dir, exist_ok=True)

        # create a model that reads the checkpoint file
        model = ReadCheckpointModel(
            base_model=self.model,
            job_id=job_id,
            checkpoint_id=checkpoint_id,
            channel=self.channel,
            checkpoints_file=checkpoints_file,
            results_file=checkpoint_results_file,
        )

        # predict the checkpoint
        # NOTE(review): predict() is called synchronously inside this coroutine,
        # so the event loop is blocked while it runs — confirm this is intended.
        model.predict(
            input=None,
            **params,
        )

    def _get_group_name(self) -> str:
        # one consumer group per model class
        model_name = self.model.__class__.__name__
        return model_name
@@ -0,0 +1,138 @@
1
+ import logging
2
+ import os
3
+ from pickle import dump
4
+
5
+ from nerdd_module.input import DepthFirstExplorer
6
+ from nerdd_module.model import ReadInputStep
7
+
8
+ from ..channels import Channel
9
+ from ..types import CheckpointMessage, JobMessage, LogMessage
10
+ from ..utils import batched
11
+ from .action import Action
12
+
13
__all__ = ["ProcessJobsAction"]

logger = logging.getLogger(__name__)


class ProcessJobsAction(Action[JobMessage]):
    """Split incoming jobs into checkpoint files and announce every checkpoint.

    New jobs arrive on the "jobs" topic. For each job, the molecules of the input
    (files) are read and written in batches of ``checkpoint_size`` entries to
    ``<data_dir>/jobs/<job_id>/input/checkpoint_<i>.pickle``. For every batch, a
    CheckpointMessage is sent to the "<job_type>-checkpoints" topic. At most
    ``max_num_molecules`` entries are stored; if the input contains more, a
    warning is sent to the "logs" topic. Finally, the number of stored entries is
    reported on the "logs" topic with message_type "report_job_size".
    """

    def __init__(
        self,
        channel: Channel,
        checkpoint_size: int,
        max_num_molecules: int,
        num_test_entries: int,
        ratio_valid_entries: float,
        maximum_depth: int,
        max_num_lines_mol_block: int,
        data_dir: str,
    ) -> None:
        super().__init__(channel.jobs_topic())
        # relevant for chunking
        self.checkpoint_size = checkpoint_size
        self.max_num_molecules = max_num_molecules
        # parameters of DepthFirstExplorer
        self.num_test_entries = num_test_entries
        self.ratio_valid_entries = ratio_valid_entries
        self.maximum_depth = maximum_depth
        # passed to DepthFirstExplorer as an extra keyword argument
        self.max_num_lines_mol_block = max_num_lines_mol_block
        self.data_dir = data_dir

    async def _process_message(self, message: JobMessage) -> None:
        job_id = message.id
        job_type = message.job_type
        logger.info(f"Received a new job {job_id} of type {job_type}")

        # The input file of the job is stored in data_dir/sources/. The file may
        # reference other files; handing the sources directory to the explorer as
        # its data_dir is meant to keep all reads inside that directory.
        sources_dir = os.path.join(self.data_dir, "sources")

        # create a reader (explorer) for the input file
        explorer = DepthFirstExplorer(
            num_test_entries=self.num_test_entries,
            threshold=self.ratio_valid_entries,
            maximum_depth=self.maximum_depth,
            # extra args
            max_num_lines_mol_block=self.max_num_lines_mol_block,
            data_dir=sources_dir,
        )
        read_input_step = ReadInputStep(explorer, message.source_id)

        # create a directory for the job
        input_dir = f"{self.data_dir}/jobs/{job_id}/input"
        os.makedirs(input_dir, exist_ok=True)

        # iter() guarantees that the batching below and the look-ahead after the
        # loop (next(entries)) advance one shared iterator, and that next() also
        # works if the step returns a re-iterable (e.g. a list).
        entries = iter(read_input_step())

        # store batches of size checkpoint_size, but at most max_num_molecules entries
        num_entries = 0
        num_checkpoints = 0
        too_many_molecules = False
        for checkpoint_id, batch in enumerate(batched(entries, self.checkpoint_size)):
            # max_num_molecules might be reached within the batch; clamp at 0 so
            # that a non-positive limit never turns into a negative slice
            num_store = max(0, min(len(batch), self.max_num_molecules - num_entries))
            if num_store < len(batch):
                too_many_molecules = True

            # store batch in data_dir
            with open(f"{input_dir}/checkpoint_{checkpoint_id}.pickle", "wb") as f:
                dump(batch[:num_store], f)

            # announce the checkpoint on the "<job_type>-checkpoints" topic
            await self.channel.checkpoints_topic(job_type).send(
                CheckpointMessage(
                    job_id=job_id,
                    checkpoint_id=checkpoint_id,
                    params=message.params,
                )
            )

            num_entries += num_store
            num_checkpoints += 1

            if num_entries >= self.max_num_molecules:
                break

        # num_checkpoints (instead of the loop index) stays defined for empty inputs
        logger.info(
            f"Wrote {num_checkpoints} checkpoints containing {num_entries} entries for job {job_id}"
        )

        # If every batch was stored completely, the limit may still have been hit
        # exactly at a batch boundary: peek once to see whether input is left over.
        if not too_many_molecules:
            try:
                next(entries)
            except StopIteration:
                pass
            else:
                too_many_molecules = True

        # send a warning message if there were more molecules in the job than allowed
        if too_many_molecules:
            await self.channel.logs_topic().send(
                LogMessage(
                    job_id=job_id,
                    message_type="warning",
                    message=(
                        f"The provided job contains more than "
                        f"{self.max_num_molecules} input structures. Only the "
                        f"first {self.max_num_molecules} will be processed."
                    ),
                )
            )

        # at the end, report the overall number of stored entries on the logs topic
        await self.channel.logs_topic().send(
            LogMessage(
                job_id=job_id,
                message_type="report_job_size",
                size=num_entries,
            )
        )
@@ -0,0 +1,30 @@
1
+ import logging
2
+
3
+ from nerdd_module import Model
4
+ from stringcase import spinalcase
5
+
6
+ from ..channels import Channel
7
+ from ..types import ModuleMessage, SystemMessage
8
+ from .action import Action
9
+
10
__all__ = ["RegisterModuleAction"]

logger = logging.getLogger(__name__)


class RegisterModuleAction(Action[SystemMessage]):
    """Answer every message on the "system" topic by registering the module.

    The model's configuration (``model.get_config()``) is dumped and published
    as a ModuleMessage on the "modules" topic.
    """

    def __init__(self, channel: Channel, model: Model):
        super().__init__(channel.system_topic())
        # TODO: do this differently
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would defer the failure to the first system
        # message (where get_config() is called).
        if not hasattr(model, "get_config"):
            raise TypeError(
                f"{type(model).__name__} does not provide get_config(), which is "
                "required to register the module."
            )
        self._model = model

    async def _process_message(self, message: SystemMessage) -> None:
        # send the registration message (the content of the system message is not used)
        config = self._model.get_config()
        logger.info(f"Send registration message for module {config.name}")
        await self.channel.modules_topic().send(ModuleMessage(**config.model_dump()))

    def _get_group_name(self) -> str:
        # one consumer group per model class
        model_name = spinalcase(self._model.__class__.__name__)
        return model_name
@@ -0,0 +1,21 @@
1
+ from nerdd_module import Model
2
+ from stringcase import spinalcase
3
+
4
+ from ..channels import Channel
5
+ from ..types import SystemMessage
6
+ from .action import Action
7
+
8
__all__ = ["WriteOutputAction"]


class WriteOutputAction(Action[SystemMessage]):
    """Action subscribed to the "system" topic; currently a no-op.

    Every received message is ignored.
    """

    def __init__(self, channel: Channel, model: Model):
        system_topic = channel.system_topic()
        super().__init__(system_topic)
        self._model = model

    async def _process_message(self, message: SystemMessage) -> None:
        """Ignore *message* (nothing is implemented yet)."""
        return None

    def _get_group_name(self) -> str:
        """Consumer group derived from the wrapped model's class name."""
        return spinalcase(type(self._model).__name__)
@@ -0,0 +1,3 @@
1
+ from .channel import *
2
+ from .kafka_channel import *
3
+ from .memory_channel import *
@@ -0,0 +1,114 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import AsyncIterable, Generic, TypeVar, Union, cast
5
+
6
+ from nerdd_module import Model
7
+ from stringcase import spinalcase # type: ignore
8
+
9
+ from ..types import (
10
+ CheckpointMessage,
11
+ JobMessage,
12
+ LogMessage,
13
+ Message,
14
+ ModuleMessage,
15
+ ResultCheckpointMessage,
16
+ ResultMessage,
17
+ SystemMessage,
18
+ )
19
+
20
__all__ = ["Channel", "Topic"]

T = TypeVar("T", bound=Message)


def _to_topic_component(name: str) -> str:
    """Normalize *name* into a safe topic-name component.

    * convert to spinal case (e.g. "MyModel" -> "my-model"),
    * convert to lowercase (just to be sure) and
    * remove all characters except dashes and alphanumeric characters.
    """
    topic_name = spinalcase(name).lower()
    return "".join(c for c in topic_name if c.isalnum() or c == "-")


def get_job_type(job_type_or_model: Union[str, Model]) -> str:
    """Return the job type used to build per-model topic names.

    A Model is identified by its ``name``; a string is taken as the job type
    itself. Both are normalized identically, so ``checkpoints_topic(model)`` and
    ``checkpoints_topic(job_type)`` resolve to the same topic whenever
    ``job_type`` equals the model's name (producers use the string form,
    consumers the model form).
    """
    if isinstance(job_type_or_model, Model):
        return _to_topic_component(job_type_or_model.name)
    return _to_topic_component(job_type_or_model)
39
+
40
+
41
class Topic(Generic[T]):
    """Named topic on a channel, typed by the messages it carries."""

    def __init__(self, channel: Channel, name: str):
        self._channel = channel
        self._name = name

    @property
    def channel(self) -> Channel:
        """The channel this topic lives on."""
        return self._channel

    async def send(self, message: T) -> None:
        """Publish *message* on this topic."""
        target = self.channel
        await target.send(self._name, message)

    async def receive(self, consumer_group: str) -> AsyncIterable[T]:
        """Yield the messages of this topic as delivered to *consumer_group*.

        The messages are only cast to ``T`` for the type checker; nothing is
        converted at runtime.
        """
        stream = self.channel.iter_messages(self._name, consumer_group)
        async for raw in stream:
            yield cast(T, raw)

    def __repr__(self) -> str:
        return f"Topic({self._name})"
59
+
60
+
61
class Channel(ABC):
    """Abstract message transport.

    Subclasses provide the transport by implementing :meth:`_iter_messages` and
    :meth:`_send`. The ``*_topic`` helpers return typed :class:`Topic` handles
    for the topic names used by the system.
    """

    # --- receiving -----------------------------------------------------------

    async def iter_messages(self, topic: str, consumer_group: str) -> AsyncIterable[Message]:
        """Yield the messages of *topic* as delivered to *consumer_group*."""
        stream = self._iter_messages(topic, consumer_group)
        async for item in stream:
            yield item

    # Declared with "def" rather than "async def" on purpose: without a "yield" in
    # the body, "async def" would make type checkers treat this as a coroutine
    # returning an AsyncIterable, whereas implementations are async generators whose
    # call returns the AsyncIterable directly.
    @abstractmethod
    def _iter_messages(self, topic: str, consumer_group: str) -> AsyncIterable[Message]:
        """Return an async iterable over the messages of *topic*."""

    # --- sending -------------------------------------------------------------

    async def send(self, topic: str, message: Message) -> None:
        """Publish *message* on *topic*."""
        await self._send(topic, message)

    @abstractmethod
    async def _send(self, topic: str, message: Message) -> None:
        """Transport-specific implementation of :meth:`send`."""

    # --- topics --------------------------------------------------------------

    def modules_topic(self) -> Topic[ModuleMessage]:
        """The "modules" topic."""
        return Topic[ModuleMessage](self, "modules")

    def jobs_topic(self) -> Topic[JobMessage]:
        """The "jobs" topic."""
        return Topic[JobMessage](self, "jobs")

    def checkpoints_topic(self, job_type_or_model: Union[str, Model]) -> Topic[CheckpointMessage]:
        """The per-job-type "<job-type>-checkpoints" topic."""
        return Topic[CheckpointMessage](self, f"{get_job_type(job_type_or_model)}-checkpoints")

    def results_topic(self) -> Topic[ResultMessage]:
        """The "results" topic."""
        return Topic[ResultMessage](self, "results")

    def result_checkpoints_topic(
        self, job_type_or_model: Union[str, Model]
    ) -> Topic[ResultCheckpointMessage]:
        """The per-job-type "<job-type>-result-checkpoints" topic."""
        return Topic[ResultCheckpointMessage](
            self, f"{get_job_type(job_type_or_model)}-result-checkpoints"
        )

    def logs_topic(self) -> Topic[LogMessage]:
        """The "logs" topic."""
        return Topic[LogMessage](self, "logs")

    def system_topic(self) -> Topic[SystemMessage]:
        """The "system" topic."""
        return Topic[SystemMessage](self, "system")
@@ -0,0 +1,97 @@
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ from typing import AsyncIterable, Dict, Tuple
5
+
6
+ from aiokafka import AIOKafkaConsumer, AIOKafkaProducer
7
+
8
+ from ..types import Message
9
+ from .channel import Channel
10
+
11
__all__ = ["KafkaChannel"]

logger = logging.getLogger(__name__)


class KafkaChannel(Channel):
    """Channel backed by a Kafka broker (via aiokafka).

    Must be constructed while an asyncio event loop is running, because the
    producer is started in a task created in ``__init__``.
    """

    def __init__(self, broker_url: str) -> None:
        super().__init__()
        self._broker_url = broker_url
        self._consumers: Dict[Tuple[str, str], AIOKafkaConsumer] = {}

        # TODO: consider configuring a value_serializer instead of encoding in _send
        self._producer = AIOKafkaProducer(
            bootstrap_servers=[self._broker_url],
        )
        # Keep a reference to the start task: the event loop only keeps weak
        # references to tasks, so an unreferenced task may be garbage collected
        # before it completes. _send awaits it, so no message is sent before the
        # producer is started (and start-up errors surface there).
        self._producer_start = asyncio.create_task(self._producer.start())
        logger.info(f"Connecting to Kafka broker {self._broker_url} and starting a producer.")

    async def _iter_messages(self, topic: str, consumer_group: str) -> AsyncIterable[Message]:
        if consumer_group is not None:
            consumer_group = f"{consumer_group}-consumer-group"

        key = (topic, consumer_group)

        consumer = self._consumers.get(key)
        if consumer is None:
            # create consumer
            consumer = AIOKafkaConsumer(
                topic,
                bootstrap_servers=[self._broker_url],
                auto_offset_reset="earliest",
                group_id=consumer_group,
                enable_auto_commit=False,
            )
            await consumer.start()
            self._consumers[key] = consumer
            logger.info(
                f"Connecting to Kafka broker {self._broker_url} and starting a consumer on "
                f"topic {topic}."
            )

        try:
            async for record in consumer:
                message_obj = json.loads(record.value)
                # NOTE(review): this builds a plain Message, not the topic-specific
                # subclass (Topic.receive only casts) — confirm Message keeps all fields.
                yield Message(**message_obj)
                # The offset is committed only when the caller asks for the next
                # message, i.e. after it has handled this one.
                await consumer.commit()
        finally:
            # Drop the consumer from the cache before stopping it, so that a later
            # call starts a fresh consumer instead of reusing a stopped one.
            if self._consumers.get(key) is consumer:
                del self._consumers[key]
            await consumer.stop()

    async def _send(self, topic: str, message: Message) -> None:
        # wait until the producer is started (re-raises a failed start)
        await self._producer_start
        await self._producer.send_and_wait(
            topic,
            json.dumps(message.model_dump()).encode("utf-8"),
        )
@@ -0,0 +1,33 @@
1
+ import logging
2
+ from typing import AsyncIterable, List, Tuple
3
+
4
+ from ..types import Message
5
+ from ..utils import ObservableList
6
+ from .channel import Channel
7
+
8
__all__ = ["MemoryChannel"]

logger = logging.getLogger(__name__)


class MemoryChannel(Channel):
    """In-memory channel that records every sent (topic, message) pair.

    Receivers are fed from the changes of an ObservableList; the consumer group
    argument is ignored and messages are filtered by topic only.
    """

    def __init__(self) -> None:
        super().__init__()
        self._messages = ObservableList[Tuple[str, Message]]()

    def get_produced_messages(self) -> List[Tuple[str, Message]]:
        """Return the (topic, message) pairs held by the underlying list."""
        return self._messages.get_items()

    async def _iter_messages(self, topic: str, consumer_group: str) -> AsyncIterable[Message]:
        changes = self._messages.changes()
        async for _previous, entry in changes:
            assert entry is not None
            entry_topic, payload = entry
            if entry_topic != topic:
                continue
            yield payload

    async def _send(self, topic: str, message: Message) -> None:
        logger.info(f"Send message to topic {topic}")
        record = (topic, message)
        self._messages.append(record)

    async def stop(self) -> None:
        """Stop the underlying observable list."""
        await self._messages.stop()
@@ -0,0 +1,3 @@
1
+ from .initialize_system import *
2
+ from .run_job_server import *
3
+ from .run_prediction_server import *
@@ -0,0 +1,45 @@
1
+ import logging
2
+
3
+ import rich_click as click
4
+
5
+ from ..channels import KafkaChannel
6
+ from ..types import SystemMessage
7
+ from ..utils import async_to_sync
8
+
9
__all__ = ["initialize_system"]

logger = logging.getLogger(__name__)


@click.command(context_settings={"show_default": True})
@click.option(
    "--channel",
    type=click.Choice(["kafka"], case_sensitive=False),
    default="kafka",
    help="Channel to use for communication with the model.",
)
@click.option("--broker-url", default="localhost:9092", help="Kafka broker to connect to.")
@click.option(
    "--log-level",
    default="info",
    type=click.Choice(["debug", "info", "warning", "error", "critical"], case_sensitive=False),
    help="The logging level.",
)
@async_to_sync
async def initialize_system(
    # communication options
    channel: str,
    broker_url: str,
    # log level
    log_level: str,
) -> None:
    """Send the system initialization message on the "system" topic."""
    logging.basicConfig(level=log_level.upper())

    # click.Choice already restricts the value; keep an explicit guard anyway
    if channel != "kafka":
        raise ValueError(f"Channel {channel} not supported.")
    channel_instance = KafkaChannel(broker_url)

    # log through the module logger (not the root logger) like the rest of the package
    logger.info("Sending the system initialization message...")
    await channel_instance.system_topic().send(SystemMessage())