pinexq_procon-2.1.0.dev3-py3-none-any.whl
- pinexq/procon/__init__.py +0 -0
- pinexq/procon/core/__init__.py +0 -0
- pinexq/procon/core/cli.py +442 -0
- pinexq/procon/core/exceptions.py +64 -0
- pinexq/procon/core/helpers.py +61 -0
- pinexq/procon/core/logconfig.py +48 -0
- pinexq/procon/core/naming.py +36 -0
- pinexq/procon/core/types.py +15 -0
- pinexq/procon/dataslots/__init__.py +19 -0
- pinexq/procon/dataslots/abstractionlayer.py +215 -0
- pinexq/procon/dataslots/annotation.py +389 -0
- pinexq/procon/dataslots/dataslots.py +369 -0
- pinexq/procon/dataslots/datatypes.py +50 -0
- pinexq/procon/dataslots/default_reader_writer.py +26 -0
- pinexq/procon/dataslots/filebackend.py +126 -0
- pinexq/procon/dataslots/metadata.py +137 -0
- pinexq/procon/jobmanagement/__init__.py +9 -0
- pinexq/procon/jobmanagement/api_helpers.py +287 -0
- pinexq/procon/remote/__init__.py +0 -0
- pinexq/procon/remote/messages.py +250 -0
- pinexq/procon/remote/rabbitmq.py +420 -0
- pinexq/procon/runtime/__init__.py +3 -0
- pinexq/procon/runtime/foreman.py +128 -0
- pinexq/procon/runtime/job.py +384 -0
- pinexq/procon/runtime/settings.py +12 -0
- pinexq/procon/runtime/tool.py +16 -0
- pinexq/procon/runtime/worker.py +437 -0
- pinexq/procon/step/__init__.py +3 -0
- pinexq/procon/step/introspection.py +234 -0
- pinexq/procon/step/schema.py +99 -0
- pinexq/procon/step/step.py +119 -0
- pinexq/procon/step/versioning.py +84 -0
- pinexq_procon-2.1.0.dev3.dist-info/METADATA +83 -0
- pinexq_procon-2.1.0.dev3.dist-info/RECORD +35 -0
- pinexq_procon-2.1.0.dev3.dist-info/WHEEL +4 -0
pinexq/procon/runtime/worker.py
@@ -0,0 +1,437 @@
+import asyncio
+import logging
+from asyncio import CancelledError, Task
+from typing import Awaitable, Callable, TypeVar, Union
+
+import pydantic
+
+from ..core.exceptions import (
+    ProConBadMessage,
+    ProConException,
+    ProConMessageRejected,
+    ProConShutdown,
+    ProConUnknownFunctionError,
+)
+from ..core.naming import escape_version_string, is_valid_version_string
+from ..remote.messages import (
+    JobOfferMessageBody,
+    WorkerCommandMessageBody,
+    WorkerStatus,
+)
+from ..remote.rabbitmq import QueueSubscriber, RabbitMQClient
+from ..step import Step
+from .job import FunctionVersion, ProConJob
+from .settings import JOB_OFFERS_TOPIC
+from .tool import handle_task_result
+
+
+log = logging.getLogger(__name__)
+
+T = TypeVar('T')  # Any type.
+CallableOrAwaitable = Union[Callable[[T], bool], Callable[[T], Awaitable[bool]]]
+
+
+class WorkerCommunicationRMQ:
+    """Worker-specific communication using RabbitMQ"""
+    rmq_client: RabbitMQClient
+    _job_offer_subs: list[QueueSubscriber]
+    _outgoing_msg_queue: asyncio.Queue
+    _worker_cmd_cb: CallableOrAwaitable[WorkerCommandMessageBody] | None
+
+    def __init__(
+        self,
+        rmq_parameters: dict,
+        sender_id: str,
+        worker_id: str,
+        context_id: str,
+        functions: list[FunctionVersion],
+        msg_queue: asyncio.Queue,
+        worker_cmd_cb: CallableOrAwaitable[WorkerCommandMessageBody] | None = None,
+    ):
+        self.sender_id = sender_id
+        self.worker_id = worker_id
+        self.context_id = context_id
+        self._functions = functions
+        self._outgoing_msg_queue = msg_queue
+        self._worker_cmd_cb = worker_cmd_cb
+        self._job_offer_subs = []
+
+        self.is_initialized = False
+
+        self.rmq_parameters = rmq_parameters
+        self.rmq_client = RabbitMQClient(**rmq_parameters)
+
+    async def start(self):
+        """Connect to the RabbitMQ server and start accepting jobs"""
+        await self.rmq_client.connect_and_run()
+        await self._connect_to_worker_queues(self._functions)
+
+    async def stop(self):
+        """Stop receiving jobs and disconnect from RMQ server"""
+        await self.disconnect_from_worker_queues()
+        await self.rmq_client.close()
+
+    @property
+    def is_ready_to_send(self) -> bool:
+        """Returns if connection to RMQ server is established and queues initialized."""
+        return self.rmq_client.is_connected and self.is_initialized
+
+    async def report_worker_state(self, status: WorkerStatus | str):
+        """Log the current state of the worker"""
+        log.info(f"Worker state changed to: '{status.value}'")
+
+    async def report_worker_error(self, title: str, exception: Exception | None = None):
+        """Log error message from an exception"""
+        if exception:
+            log.exception(f"Error in worker: {title}", exc_info=exception)
+        else:
+            log.error(f"Error in worker: {title}")
+
+    async def _create_job_offer_subscription(self, func_name: str, func_version: str) -> QueueSubscriber | None:
+        topic = JOB_OFFERS_TOPIC.format(ContextId=self.context_id, FunctionName=func_name, Version=func_version)
+        # log.debug("Subscribing to Job.Offers queue: '%s'", topic)
+        exists = await self.rmq_client.does_queue_exist(topic)
+        if exists:
+            return await self.rmq_client.subscriber(
+                queue_name=topic,
+                routing_key=topic,
+                callback=self.on_job_offer_message
+            )
+        else:
+            return None
+
+    async def _connect_to_worker_queues(self, functions: list[FunctionVersion]):
+        """
+        Create the worker command channel, subscribe to the status channel,
+        and subscribe to a `job.offer` channel for each function this worker publishes.
+
+        Args:
+            functions: List of tuples with ("the_function_name", "version")
+        """
+
+        self._job_offer_subs = []
+        for f in functions:
+            sub = await self._create_job_offer_subscription(func_name=f.name, func_version=f.version)
+            if sub:
+                self._job_offer_subs.append(sub)
+                log.info(f"Listening for JobOffers: {f.name}:{f.version}")
+            else:
+                log.warning(f"Could not connect to JobOffers of: {f.name}:{f.version}")
+
+        if not self._job_offer_subs:
+            log.critical(msg := "Failed to connect to any Job.Offer queues!")
+            raise ProConException(msg)
+
+        self.is_initialized = True
+
+    async def disconnect_from_worker_queues(self):
+        """Stop listening to job.offer queues and stop & delete the command queue for this worker."""
+        if not self.is_initialized:
+            return
+        await asyncio.gather(
+            *(sub.stop_consumer_loop() for sub in self._job_offer_subs)
+        )
+
+        self.is_initialized = False
+
+    async def start_consuming_job_offers(self):
+        """Start polling on all `job.offer` queues."""
+        try:
+            async with asyncio.TaskGroup() as group:
+                for i, sub in enumerate(self._job_offer_subs):
+                    group.create_task(sub.run_consumer_loop(), name=f"job-offer-consumer-{i}")
+
+        except ExceptionGroup as eg:
+            for exc in eg.exceptions:
+                log.error(f"Error in Job.Offer consumer loop! {exc}")
+                # log.exception("Error in Job.Offer consumer loop!", exc_info=exc)
+            raise
+
+    async def stop_consuming_job_offers(self):
+        """Stop polling on all `job.offer` queues."""
+        await asyncio.gather(
+            *(sub.stop_consumer_loop() for sub in self._job_offer_subs)
+        )
+
+    # Protocol handlers ------------------------
+
+    async def on_job_offer_message(self, message: str | bytes):
+        """Process the raw message body of incoming job.offers"""
+        try:
+            job_offer_msg = JobOfferMessageBody.model_validate_json(message)
+        except pydantic.ValidationError as ex:
+            msg = "Deserialization of 'job.offer' message failed!"
+            await self.report_worker_error(msg, exception=ex)
+            log.exception("⚠ %s", msg, exc_info=ex)
+            raise ProConBadMessage(msg) from ex
+        else:
+            log.debug("'job.offer' message deserialized: %s", str(job_offer_msg))
+            try:
+                self._outgoing_msg_queue.put_nowait(job_offer_msg)
+            except asyncio.QueueFull:
+                # A full message queue indicates that the worker is busy. The worker
+                # should have stopped receiving further job.offer messages. If we
+                # receive a message nonetheless (e.g. due to concurrency or prefetch)
+                # reject the message back to the queue.
+                raise ProConMessageRejected("Worker busy -> Job.offer rejected!")
+
+    async def on_worker_cmd_message(self, message: str | bytes):
+        """Process the raw message body of incoming worker.commands"""
+        try:
+            job_cmd_msg = WorkerCommandMessageBody.model_validate_json(message)
+        except pydantic.ValidationError as ex:
+            job_cmd_msg = "Deserialization of 'worker.command' message failed!"
+            await self.report_worker_error(job_cmd_msg, exception=ex)
+            log.exception("⚠ %s", job_cmd_msg, exc_info=ex)
+        else:
+            log.debug("'worker.command' message deserialized: %s", str(job_cmd_msg))
+            raise NotImplementedError('Worker commands not yet supported!')
+            # if self._worker_cmd_cb is not None:
+            #     result = self._worker_cmd_cb(job_cmd_msg)
+            #     if isawaitable(result):
+            #         await result
+
+
+class ProConWorker:
+    """
+    Worker class providing a `Step` function as remote resource via RabbitMQ
+    and handling its execution.
+
+    This class is instantiated and supervised by the Foreman class.
+    """
+
+    _job_task: Task | None = None
+    _processing_task: Task | None = None
+    _worker_com: WorkerCommunicationRMQ
+    _busy_lock: asyncio.BoundedSemaphore
+    _message_queue: asyncio.Queue
+    _running: asyncio.Event
+
+    step: Step
+    connected: bool
+    available_functions: dict[str, str]
+    configured_functions: dict[str, str]
+
+    def __init__(
+        self,
+        step: Step,
+        function_names: list[str],
+        rmq_parameters: dict,
+        worker_id: str,
+        context_id: str,
+        idle_timeout: int = 0,
+    ):
+        """Initialize a worker that can offer multiple Step functions for remote execution.
+        Functions are executed one at a time according to job.offer messages received via
+        RabbitMQ (RMQ) queues.
+
+        Args:
+            step: Step class containing the functions.
+            function_names: A list of function names to be offered remotely.
+            rmq_parameters: RabbitMQ connection parameters.
+            worker_id: Unique identifier for this worker.
+            context_id: Context id this worker is running in.
+            idle_timeout: Maximum idle time in seconds
+        """
+        log.info("Initializing worker for function '%s'", function_names)
+        self.worker_id = worker_id
+        self.context_id = context_id
+        self.step = step
+        self.connected = False
+
+        # Functions as defined in the `Step` class (name and version)
+        self.available_functions = {
+            name: escape_version_string(signature.version)
+            for name, signature in self.step.step_signatures.items()
+        }
+
+        # Functions configured by the --function/-f commandline parameter
+        self.configured_functions = self.resolve_function_and_version(function_names)
+
+        # This event is set when the worker is actively receiving job.offers
+        self._running = asyncio.Event()
+
+        # Semaphore limiting the number of concurrently started Jobs.
+        # Currently, there is only one worker that can process a single job.
+        self._busy_lock = asyncio.BoundedSemaphore(value=1)
+
+        # Incoming job.offer messages are passed through this message queue.
+        # This limits the number of messages pre-fetched from the RMQ queue.
+        self._incoming_msg_queue = asyncio.Queue(maxsize=1)
+
+        self._worker_com = WorkerCommunicationRMQ(
+            rmq_parameters=rmq_parameters,  # TODO: make this a dataclass
+            sender_id=self.sender_id,
+            worker_id=self.worker_id,
+            context_id=self.context_id,
+            functions=[
+                FunctionVersion(name=name, version=version)
+                for name, version in self.configured_functions.items()
+            ],
+            msg_queue=self._incoming_msg_queue
+        )
+
+        self.idle_timeout = idle_timeout
+
+    async def run(self) -> None:
+        """Connect to the RabbitMQ server and start accepting jobs"""
+        try:
+            async with asyncio.TaskGroup() as group:
+                log.info("Starting worker communication ...")
+                await self._worker_com.start()
+                await self._worker_com.report_worker_state(WorkerStatus.starting)
+                group.create_task(self._worker_com.start_consuming_job_offers())
+
+                log.info("▶ Start processing jobs")
+                await self._worker_com.report_worker_state(WorkerStatus.running)
+
+                group.create_task(self.process_jobs(), name="process_jobs")
+
+        except* ProConShutdown as exc:
+            log.info(f"Worker is shutting down! Reason: {str(exc)}")
+
+        except* CancelledError:
+            log.warning("Worker got cancelled!")
+
+        except* ProConException as exc:
+            log.exception(f"Unhandled exception in worker! -> {str(exc)}")
+
+        finally:
+            log.info("⏹ Stopped processing jobs")
+            await self._worker_com.report_worker_state(WorkerStatus.exiting)
+            await self._worker_com.stop()
+
+    @property
+    def sender_id(self) -> str:
+        """Unique sender ID of this worker"""
+        return f"worker:{self.worker_id}"
+
+    def resolve_function_and_version(self, function_version_strings: list[str]) -> dict[str, str]:
+        """Resolves a list with "function_name:version" specifiers.
+        Check if container exposes the given function names,
+        substitute '*' with all available names and split off the version string.
+        """
+        # If the requested functions contain a wildcard, just return all available functions
+        if "*" in function_version_strings:
+            return self.available_functions
+
+        function_version_dict: dict[str, str] = {}
+        for s in function_version_strings:
+            # Split function+version string at the first colon
+            name, _, version = s.partition(":")
+
+            if name not in self.available_functions:
+                raise ProConUnknownFunctionError(f"Unknown function: '{name}'!", func_name=name)
+
+            # If no version was provided as parameter, return the version annotated at the function
+            if not version:
+                version = self.available_functions[name]
+            else:
+                if not is_valid_version_string(version):
+                    raise ValueError(f"Given version: '{version}' is not valid. Allowed: a-Z 0-9 and '_' '-' '.'")
+
+                # Ensure the version string contains no '.' so we don't mess with messaging
+                version = escape_version_string(version)
+
+            function_version_dict[name] = version
+
+        return function_version_dict
+
+    def check_function_from_offer(self, job_offer_msg: JobOfferMessageBody) -> FunctionVersion | None:
+        """Ensures that the function and version in the job.offer is available in this container."""
+
+        if (function_name := job_offer_msg.content.algorithm) not in self.configured_functions:
+            msg = (f"Received job offer for unknown function: '{function_name}'!"
+                   f" Available are: {list(self.configured_functions.keys())}")
+            log.error(msg)
+            raise ProConUnknownFunctionError(msg)
+
+        requested_version = escape_version_string(job_offer_msg.content.algorithm_version)
+        available_version = self.configured_functions[function_name]
+        if requested_version != available_version:
+            msg = (
+                f"Function '{function_name}' is available in version '{available_version}', "
+                f"but the requested version is: '{requested_version}'"
+            )
+            log.error(msg)
+            raise ProConUnknownFunctionError(msg)
+
+        return FunctionVersion(name=function_name, version=requested_version)
+
+    async def run_job(self, job_offer_msg: JobOfferMessageBody) -> None:
+        """Run a specific job according to a job.offer message."""
+
+        function_and_version = self.check_function_from_offer(job_offer_msg)
+
+        await self._busy_lock.acquire()
+        try:
+            job = ProConJob(
+                step=self.step,
+                function=function_and_version,
+                context_id=self.context_id,
+                job_offer=job_offer_msg.content,
+                _rmq_client=self._worker_com.rmq_client,
+            )
+            await job.start()
+            await job.process()
+        except Exception:
+            raise
+        finally:
+            self._busy_lock.release()
+
+    async def process_jobs(self):
+        """Processing loop for incoming `job.offer` messages."""
+        log.debug("🔄 Entering processing loop.")
+        self._running.set()
+        while self._running.is_set():
+            try:
+                if self.idle_timeout > 0:
+                    log.info(f"Idle timeout set to: {self.idle_timeout}s")
+                    async with asyncio.timeout(delay=self.idle_timeout):
+                        message = await self._incoming_msg_queue.get()
+                else:
+                    message = await self._incoming_msg_queue.get()
+
+            except (TimeoutError, asyncio.TimeoutError):
+                log.info(f"⌛ Idle timeout! Worker was idle for more than {self.idle_timeout}s.")
+                raise ProConShutdown("Idle-timeout exceeded")
+            except asyncio.CancelledError:
+                raise ProConShutdown("Processing loop canceled")
+
+            try:
+                self._job_task = asyncio.create_task(self.run_job(message), name="job_main")
+                await self._job_task
+            except CancelledError:
+                pass  # Wait for the computation to finish, when this task is cancelled
+            finally:
+                self._incoming_msg_queue.task_done()
+
+        self._running.clear()
+        log.debug("⏹ Exiting processing loop.")
+        raise ProConShutdown("Processing loop stopped")
+
+    async def stop(self, timeout: float | None = None) -> None:
+        """Stop the processing loop for incoming job offers.
+        If a job is still running, wait for the computation to finish."""
+        self._running.clear()
+        await self._worker_com.stop_consuming_job_offers()
+
+        # Wait for running jobs by waiting for the busy lock
+        if self._busy_lock.locked():
+            log.debug("Waiting for running Job to finish ...")
+            await asyncio.wait_for(self._busy_lock.acquire(), timeout=timeout)
+            self._busy_lock.release()
+
+        # Force the processing loop to exit, when waiting for the queue's .get()
+        if self._processing_task is not None:
+            self._processing_task.cancel()
+
+    def _on_idle_timeout(self) -> None:
+        """A wrapper to call self.stop() from synchronous code."""
+        log.info("Worker reached idle timeout. Shutting down ...")
+        self._stop_task = asyncio.create_task(self.stop(), name="stop-worker-task")
+        self._stop_task.add_done_callback(handle_task_result)
+
+    async def _on_worker_cmd(self, worker_cmd_msg: WorkerCommandMessageBody) -> None:
+        # TODO
+        raise NotImplementedError()
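Usage sketch (not part of the package): the snippet below illustrates how the ProConWorker above might be wired up and run. Only the constructor arguments and run() are taken from the code above; the MyStep class, the assumption that Step subclasses are default-constructible, and the rmq_parameters keys are illustrative guesses, and a reachable RabbitMQ broker is required for run() to succeed.

import asyncio

from pinexq.procon.runtime.worker import ProConWorker
from pinexq.procon.step import Step


class MyStep(Step):  # hypothetical Step subclass; the real Step contract may differ
    def add(self, a: int, b: int) -> int:
        """Add two integers."""
        return a + b


async def main() -> None:
    worker = ProConWorker(
        step=MyStep(),                         # assumed: Step subclasses are default-constructible
        function_names=["*"],                  # wildcard resolves to every public step function
        rmq_parameters={"host": "localhost"},  # assumed keys; forwarded to RabbitMQClient(**...)
        worker_id="worker-1",
        context_id="demo-context",
        idle_timeout=60,                       # shut down after 60 s without a job.offer
    )
    await worker.run()


if __name__ == "__main__":
    asyncio.run(main())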
pinexq/procon/step/introspection.py
@@ -0,0 +1,234 @@
+import inspect
+from importlib.metadata import version as get_package_version
+from types import NoneType
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type
+
+import docstring_parser
+import pydantic
+from pydantic import Field
+
+from ..core.exceptions import ProConSchemaError
+from ..dataslots.annotation import RETURN_SLOT_NAME, SlotType, dataslot, get_dataslot_metadata
+from ..dataslots.dataslots import isdataslot
+from .schema import DataslotModel, DynamicParameters, DynamicReturns, FunctionModel
+from .versioning import get_version_metadata
+
+
+if TYPE_CHECKING:
+    from ..step import Step
+
+
+class StepClassInfo:
+    """Adapter to extract signatures from all methods in a step class."""
+
+    def __init__(self, step: "Step"):
+        self.cls = step
+        self.funcs = StepClassInfo.funcs_from_cls(step)
+
+    @staticmethod
+    def funcs_from_cls(cls: object) -> List[Tuple[str, Callable]]:
+        """Return a list with _(name, func)_ for all public functions in _cls_"""
+        return [
+            (name, f)
+            for name, f in inspect.getmembers(cls)
+            if not name.startswith("_") and inspect.ismethod(f)
+        ]
+
+    def get_func_schemas(self) -> Dict[str, "FunctionSchema"]:
+        return {name: FunctionSchema(f) for name, f in self.funcs}
+
+
+class FunctionSchema:
+    """Extract the signature of a given function from its annotations and docstrings"""
+
+    name: str
+    function: Callable[[Any], Any]
+    version: str
+
+    _docs: docstring_parser.Docstring
+    _returns_docs: str
+    _param_docs: dict[str, str]
+    _param_annotations: dict[str, inspect.Parameter]
+    _dataslot_annotations: dict[str, inspect.Parameter]
+
+    _decorator_slots: dict[str, dataslot]
+    _parameter_slots: dict[str, dataslot]
+    _signature: inspect.Signature
+
+    def __init__(self, function: Callable):
+        self.name = function.__name__
+        self.function = function
+        self.version = str(get_version_metadata(function))
+        self._signature = inspect.signature(function)
+        self._init_docstrings()
+        self._init_decorated_dataslots()
+        self._init_annotations()
+        self._init_annotated_dataslots()
+
+    def _init_docstrings(self):
+        self._docs = docstring_parser.parse(inspect.getdoc(self.function))
+        self._param_docs = {p.arg_name: p.description for p in self._docs.params}
+        self._returns_docs = self._docs.returns.description if self._docs.returns else ""
+
+    def _init_decorated_dataslots(self):
+        # Dataslots defined by function decorator
+        self._decorator_slots = get_dataslot_metadata(self.function)
+
+        # Validate for each dataslot that a parameter with that name exists
+        decorated_names = set(self._decorator_slots.keys()) - {RETURN_SLOT_NAME}
+        parameter_names = set(self._signature.parameters.keys())
+        if diff := decorated_names - parameter_names:
+            raise ProConSchemaError(
+                f"The dataslots {diff} of function '{self.name}' don't match any parameter name!"
+            )
+
+        # Update description for the return value with the docstring, if not set in decorator
+        if ((RETURN_SLOT_NAME in self._decorator_slots)
+                and not self._decorator_slots[RETURN_SLOT_NAME].description):
+            self._decorator_slots[RETURN_SLOT_NAME].description = self._returns_docs
+
+    def _init_annotations(self):
+        """Get the function's signature and sort between dataslot and "other" parameters."""
+        self._dataslot_annotations = {}
+        self._param_annotations = {}
+        sig = self._signature
+
+        # Check if all parameters and return values have type annotations
+        if (any((p.annotation is sig.empty for p in sig.parameters.values()))
+                or sig.return_annotation is sig.empty
+        ):
+            raise ProConSchemaError(
+                f"Can not generate schema for function '{self.name}'. Type annotation is missing!"
+            )
+
+        # Disallow wildcard `Any` or `object` type annotation
+        if (any((p.annotation in (Any, object) for p in sig.parameters.values()))
+                or sig.return_annotation in (Any, object)
+        ):
+            raise ProConSchemaError(
+                f"Can not generate schema for function '{self.name}'. "
+                f"Wildcard types like `Any` or `object` are not allowed for type annotations!"
+            )
+
+        for name, p in sig.parameters.items():
+            # Collect all parameters that are by type annotation or function decorator a Dataslot
+            if isdataslot(p.annotation) or (name in self._decorator_slots):
+                self._dataslot_annotations[name] = p
+            # ... and keep every parameter (dataslots included) as a regular parameter.
+            self._param_annotations[name] = p
+
+    def _init_annotated_dataslots(self):
+        """Create `dataslot` objects for all parameters with a `Dataslot` type annotation."""
+        self._parameter_slots = {}
+        for name, p in self._dataslot_annotations.items():
+            self._parameter_slots[name] = dataslot(
+                name=p.name, dtype=p.annotation, description=self._param_docs.get(name, "")
+            )
+
+    def _get_parameters_signature(self) -> Dict[str, tuple[Any, Field]]:
+        """Returns the parameters annotation combined with their docstring as a dict"""
+        params_schema = {}
+        for name, p in self._param_annotations.items():
+            params_schema[name] = (
+                p.annotation,  # parameter type
+                Field(
+                    title=name,
+                    default=p.default if p.default is not p.empty else ...,
+                    description=self._param_docs.get(name, None),
+                ),
+            )
+        return params_schema
+
+    def _get_return_signature(self) -> dict[str, tuple[Any, Field]]:
+        """Returns the return type combined with its docstring as a dict"""
+        return_type = self._signature.return_annotation
+        if self._signature.return_annotation in (self._signature.empty, None):
+            return_type = NoneType
+        returns_schema = {"value": (return_type, Field(..., description=self._returns_docs))}
+        return returns_schema
+
+    def get_parameters_model(self, exclude_dataslots: bool = False) -> Type[DynamicParameters]:
+        """Function parameters as pydantic model for schema generation"""
+        params_signature = self._get_parameters_signature()
+        if exclude_dataslots:
+            params_signature = {
+                n: p for n, p in params_signature.items() if n not in self.dataslots
+            }
+        return pydantic.create_model("Parameters", __base__=DynamicParameters, **params_signature)
+
+    def get_returns_model(self, exclude_dataslots: bool = False) -> Type[DynamicReturns]:
+        """Function's return type as pydantic model for schema generation"""
+        if exclude_dataslots and (RETURN_SLOT_NAME in self._decorator_slots):
+            returns_signature = {"value": (NoneType, Field(...))}
+        else:
+            returns_signature = self._get_return_signature()
+        return pydantic.create_model("Returns", __base__=DynamicReturns, **returns_signature)
+
+    @property
+    def dataslots(self) -> dict[str, dataslot]:
+        """Merge all dataslot information from decorators and type annotation"""
+        param_ds = self._parameter_slots
+        deco_ds = self._decorator_slots
+        # Look for a dataslot's name in parameters and decorators,
+        # if not found create a default dataslot object and merge them.
+        dataslot_names = list(self._dataslot_annotations.keys())
+        if RETURN_SLOT_NAME in self._decorator_slots:
+            dataslot_names.append(RETURN_SLOT_NAME)
+        return {
+            name: (param_ds.get(name, dataslot(name))).update(deco_ds.get(name, dataslot(name)))
+            for name in dataslot_names
+        }
+
+    def get_dataslot_models(self) -> tuple[list[DataslotModel], list[DataslotModel]]:
+        """Create the signatures for the dataslots"""
+        input_dataslots = []
+        output_dataslots = []
+        for name, d in self.dataslots.items():
+            slot_model = DataslotModel(
+                name=d.alias or d.name,
+                title=d.title,
+                description=d.description,
+                mediatype=str(d.media_type),  # user could use a strenum or similar
+                metadata={},
+                max_slots=d.max_slots,
+                min_slots=d.min_slots,
+
+            )
+
+            if d.slot_type is SlotType.INPUT:
+                input_dataslots.append(slot_model)
+            else:
+                output_dataslots.append(slot_model)
+        return input_dataslots, output_dataslots
+
+    @property
+    def signature(self) -> dict:
+        """Create the function's signature with a parameter list, return type and docstrings"""
+        in_ds, out_ds = self.get_dataslot_models()
+
+        return_schema = self.get_returns_model(exclude_dataslots=True).model_json_schema()
+
+        param_schema = None
+        p = self.get_parameters_model(exclude_dataslots=True)
+        if len(p.model_fields) > 0:
+            param_schema = p.model_json_schema()
+
+        fields = dict(
+            version=str(self.version),
+            function_name=self.name,
+            short_description=self._docs.short_description or "",
+            long_description=self._docs.long_description or "",
+            parameters=param_schema,
+            returns=return_schema,
+            input_dataslots=[m.model_dump() for m in in_ds],
+            output_dataslots=[m.model_dump() for m in out_ds],
+            procon_version=get_procon_version(),
+        )
+        return fields
+
+    def get_function_model(self) -> FunctionModel:
+        return FunctionModel(**self.signature)
+
+
+def get_procon_version() -> str:
+    return get_package_version("pinexq-procon")