pinexq-procon 2.1.0.dev3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,437 @@
1
+ import asyncio
2
+ import logging
3
+ from asyncio import CancelledError, Task
4
+ from typing import Awaitable, Callable, TypeVar, Union
5
+
6
+ import pydantic
7
+
8
+ from ..core.exceptions import (
9
+ ProConBadMessage,
10
+ ProConException,
11
+ ProConMessageRejected,
12
+ ProConShutdown,
13
+ ProConUnknownFunctionError,
14
+ )
15
+ from ..core.naming import escape_version_string, is_valid_version_string
16
+ from ..remote.messages import (
17
+ JobOfferMessageBody,
18
+ WorkerCommandMessageBody,
19
+ WorkerStatus,
20
+ )
21
+ from ..remote.rabbitmq import QueueSubscriber, RabbitMQClient
22
+ from ..step import Step
23
+ from .job import FunctionVersion, ProConJob
24
+ from .settings import JOB_OFFERS_TOPIC
25
+ from .tool import handle_task_result
26
+
27
+
28
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Placeholder for the single argument type a callback receives.
T = TypeVar('T')

# A one-argument callback yielding a bool, either directly (plain callable)
# or through an awaitable (coroutine function).
CallableOrAwaitable = Union[Callable[[T], bool], Callable[[T], Awaitable[bool]]]
32
+
33
+
34
class WorkerCommunicationRMQ:
    """Worker-specific communication using RabbitMQ.

    Owns the RabbitMQ client, subscribes to one `job.offer` queue per offered
    function and puts every successfully validated offer into the message
    queue handed over by the caller.
    """

    rmq_client: RabbitMQClient
    _job_offer_subs: list[QueueSubscriber]
    _outgoing_msg_queue: asyncio.Queue
    _worker_cmd_cb: CallableOrAwaitable[WorkerCommandMessageBody] | None

    def __init__(
        self,
        rmq_parameters: dict,
        sender_id: str,
        worker_id: str,
        context_id: str,
        functions: list[FunctionVersion],
        msg_queue: asyncio.Queue,
        worker_cmd_cb: CallableOrAwaitable[WorkerCommandMessageBody] | None = None,
    ):
        """Store the configuration and create the (not yet connected) RMQ client.

        Args:
            rmq_parameters: Keyword arguments passed on to `RabbitMQClient`.
            sender_id: Sender id of the owning worker.
            worker_id: Id of the owning worker.
            context_id: Context id used to build the job.offer topic names.
            functions: Functions (name and version) to subscribe job.offers for.
            msg_queue: Queue that receives the deserialized job.offer messages.
            worker_cmd_cb: Optional callback for worker commands (currently unused).
        """
        self.sender_id = sender_id
        self.worker_id = worker_id
        self.context_id = context_id
        self._functions = functions
        self._outgoing_msg_queue = msg_queue
        self._worker_cmd_cb = worker_cmd_cb
        self._job_offer_subs = []

        self.is_initialized = False

        self.rmq_parameters = rmq_parameters
        self.rmq_client = RabbitMQClient(**rmq_parameters)

    async def start(self):
        """Connect to the RabbitMQ server and subscribe to the job.offer queues."""
        await self.rmq_client.connect_and_run()
        await self._connect_to_worker_queues(self._functions)

    async def stop(self):
        """Stop receiving jobs and disconnect from the RMQ server."""
        await self.disconnect_from_worker_queues()
        await self.rmq_client.close()

    @property
    def is_ready_to_send(self) -> bool:
        """Returns if connection to RMQ server is established and queues initialized."""
        return self.rmq_client.is_connected and self.is_initialized

    async def report_worker_state(self, status: WorkerStatus | str):
        """Log the current state of the worker.

        Args:
            status: A `WorkerStatus` member or a plain status string.
        """
        # The signature also allows plain strings, which have no `.value`.
        state = getattr(status, "value", status)
        log.info(f"Worker state changed to: '{state}'")

    async def report_worker_error(self, title: str, exception: Exception | None = None):
        """Log an error, including the traceback if an exception is given."""
        if exception:
            log.exception(f"Error in worker: {title}", exc_info=exception)
        else:
            log.error(f"Error in worker: {title}")

    async def _create_job_offer_subscription(self, func_name: str, func_version: str) -> QueueSubscriber | None:
        """Subscribe to the job.offer queue of one function.

        Returns:
            The subscriber, or None if the queue does not exist on the server.
        """
        topic = JOB_OFFERS_TOPIC.format(ContextId=self.context_id, FunctionName=func_name, Version=func_version)
        if not await self.rmq_client.does_queue_exist(topic):
            return None
        return await self.rmq_client.subscriber(
            queue_name=topic,
            routing_key=topic,
            callback=self.on_job_offer_message,
        )

    async def _connect_to_worker_queues(self, functions: list[FunctionVersion]):
        """Subscribe to a `job.offer` queue for each function this worker publishes.

        Any previously stored subscriptions are discarded first.

        Args:
            functions: Functions (name and version) to subscribe for.

        Raises:
            ProConException: If not a single job.offer queue could be subscribed.
        """
        self._job_offer_subs = []
        for f in functions:
            sub = await self._create_job_offer_subscription(func_name=f.name, func_version=f.version)
            if sub:
                self._job_offer_subs.append(sub)
                log.info(f"Listening for JobOffers: {f.name}:{f.version}")
            else:
                log.warning(f"Could not connect to JobOffers of: {f.name}:{f.version}")

        if not self._job_offer_subs:
            msg = "Failed to connect to any Job.Offer queues!"
            log.critical(msg)
            raise ProConException(msg)

        self.is_initialized = True

    async def disconnect_from_worker_queues(self):
        """Stop the consumer loops of all job.offer subscriptions.

        Does nothing if the queues were never initialized.
        """
        if not self.is_initialized:
            return
        await self.stop_consuming_job_offers()
        self.is_initialized = False

    async def start_consuming_job_offers(self):
        """Run the consumer loops of all `job.offer` subscriptions concurrently.

        Raises:
            ExceptionGroup: Re-raised after logging each contained error.
        """
        try:
            async with asyncio.TaskGroup() as group:
                for i, sub in enumerate(self._job_offer_subs):
                    group.create_task(sub.run_consumer_loop(), name=f"job-offer-consumer-{i}")

        except ExceptionGroup as eg:
            for exc in eg.exceptions:
                log.error(f"Error in Job.Offer consumer loop! {exc}")
            raise

    async def stop_consuming_job_offers(self):
        """Stop polling on all `job.offer` queues."""
        await asyncio.gather(
            *(sub.stop_consumer_loop() for sub in self._job_offer_subs)
        )

    # Protocol handlers ------------------------

    async def on_job_offer_message(self, message: str | bytes):
        """Process the raw message body of incoming job.offers.

        Raises:
            ProConBadMessage: If the message can not be deserialized.
            ProConMessageRejected: If the outgoing message queue is full.
        """
        try:
            job_offer_msg = JobOfferMessageBody.model_validate_json(message)
        except pydantic.ValidationError as ex:
            msg = "Deserialization of 'job.offer' message failed!"
            await self.report_worker_error(msg, exception=ex)
            log.exception("⚠ %s", msg, exc_info=ex)
            raise ProConBadMessage(msg) from ex
        else:
            log.debug("'job.offer' message deserialized: %s", str(job_offer_msg))
            try:
                self._outgoing_msg_queue.put_nowait(job_offer_msg)
            except asyncio.QueueFull:
                # A full message queue indicates that the worker is busy. The worker
                # should have stopped receiving further job.offer messages. If we
                # receive a message nonetheless (e.g. due to concurrency or prefetch)
                # reject the message back to the queue.
                raise ProConMessageRejected("Worker busy -> Job.offer rejected!")

    async def on_worker_cmd_message(self, message: str | bytes):
        """Process the raw message body of incoming worker.commands.

        A message that fails validation is only logged.

        Raises:
            NotImplementedError: For every valid message, as worker commands
                are not supported yet.
        """
        try:
            job_cmd_msg = WorkerCommandMessageBody.model_validate_json(message)
        except pydantic.ValidationError as ex:
            error_msg = "Deserialization of 'worker.command' message failed!"
            await self.report_worker_error(error_msg, exception=ex)
            log.exception("⚠ %s", error_msg, exc_info=ex)
        else:
            log.debug("'worker.command' message deserialized: %s", str(job_cmd_msg))
            # TODO: dispatch to `self._worker_cmd_cb` (awaiting the result if it
            #  is awaitable) once worker commands are supported.
            raise NotImplementedError('Worker commands not yet supported!')
193
+
194
+
195
+ class ProConWorker:
196
+ """
197
+ Worker class providing a `Step` function as remote resource via RabbitMQ
198
+ and handling its execution.
199
+
200
+ This class is instantiated and supervised by the Foreman class.
201
+ """
202
+
203
+ _job_task: Task | None = None
204
+ _processing_task: Task | None = None
205
+ _worker_com: WorkerCommunicationRMQ
206
+ _busy_lock: asyncio.BoundedSemaphore
207
+ _message_queue: asyncio.Queue
208
+ _running: asyncio.Event
209
+
210
+ step: Step
211
+ connected: bool
212
+ available_functions: dict[str, str]
213
+ configured_functions: dict[str, str]
214
+
215
+ def __init__(
216
+ self,
217
+ step: Step,
218
+ function_names: list[str],
219
+ rmq_parameters: dict,
220
+ worker_id: str,
221
+ context_id: str,
222
+ idle_timeout: int = 0,
223
+ ):
224
+ """Initialize a worker that can offer multiple Step functions for remote execution.
225
+ Functions are executed one at a time according to job.offer messages received via
226
+ RabbitMQ (RMQ) queues.
227
+
228
+ Args:
229
+ step: Step class containing the functions.
230
+ function_names: A list of function names to be offered remotely.
231
+ rmq_parameters: RabbitMQ connection parameters.
232
+ worker_id: Unique identifier for this worker.
233
+ context_id: Context id this worker is running in.
234
+ idle_timeout: Maximum idle time in seconds
235
+ """
236
+ log.info("Initializing worker for function '%s'", function_names)
237
+ self.worker_id = worker_id
238
+ self.context_id = context_id
239
+ self.step = step
240
+ self.connected = False
241
+
242
+ # Functions as defined in the `Step` class (name and version)
243
+ self.available_functions = {
244
+ name: escape_version_string(signature.version)
245
+ for name, signature in self.step.step_signatures.items()
246
+ }
247
+
248
+ # Functions configured by the --function/-f commandline parameter
249
+ self.configured_functions = self.resolve_function_and_version(function_names)
250
+
251
+ # This event is set when the worker is actively receiving job.offers
252
+ self._running = asyncio.Event()
253
+
254
+ # Semaphore limiting the number of concurrently started Jobs.
255
+ # Currently, there is only one worker that can process a single job.
256
+ self._busy_lock = asyncio.BoundedSemaphore(value=1)
257
+
258
+ # Incoming job.offer messages are passed through this message queue.
259
+ # This limits the number of messages pre-fetched from the RMQ queue.
260
+ self._incoming_msg_queue = asyncio.Queue(maxsize=1)
261
+
262
+ self._worker_com = WorkerCommunicationRMQ(
263
+ rmq_parameters=rmq_parameters, # TODO: make this a dataclass
264
+ sender_id=self.sender_id,
265
+ worker_id=self.worker_id,
266
+ context_id=self.context_id,
267
+ functions=[
268
+ FunctionVersion(name=name, version=version)
269
+ for name, version in self.configured_functions.items()
270
+ ],
271
+ msg_queue=self._incoming_msg_queue
272
+ )
273
+
274
+ self.idle_timeout = idle_timeout
275
+
276
+ async def run(self) -> None:
277
+ """Connect to the RabbitMQ server and start accepting jobs"""
278
+ try:
279
+ async with asyncio.TaskGroup() as group:
280
+ log.info("Starting worker communication ...")
281
+ await self._worker_com.start()
282
+ await self._worker_com.report_worker_state(WorkerStatus.starting)
283
+ group.create_task(self._worker_com.start_consuming_job_offers())
284
+
285
+ log.info("▶ Start processing jobs")
286
+ await self._worker_com.report_worker_state(WorkerStatus.running)
287
+
288
+ group.create_task(self.process_jobs(), name="process_jobs")
289
+
290
+ except* ProConShutdown as exc:
291
+ log.info(f"Worker is shutting down! Reason: {str(exc)}")
292
+
293
+ except* CancelledError:
294
+ log.warning("Worker got cancelled!")
295
+
296
+ except* ProConException as exc:
297
+ log.exception(f"Unhandled exception in worker! -> {str(exc)}")
298
+
299
+ finally:
300
+ log.info("⏹ Stopped processing jobs")
301
+ await self._worker_com.report_worker_state(WorkerStatus.exiting)
302
+ await self._worker_com.stop()
303
+
304
+ @property
305
+ def sender_id(self) -> str:
306
+ """Unique sender ID of this worker"""
307
+ return f"worker:{self.worker_id}"
308
+
309
+ def resolve_function_and_version(self, function_version_strings: list[str]) -> dict[str, str]:
310
+ """Resolves a list with "function_name:version" specifiers.
311
+ Check if container exposes the given function names,
312
+ substitute '*' with all available names and split off the version string.
313
+ """
314
+ # If the requested functions contain a wildcard, just return all available functions
315
+ if "*" in function_version_strings:
316
+ return self.available_functions
317
+
318
+ function_version_dict: dict[str, str] = {}
319
+ for s in function_version_strings:
320
+ # Split function+version string at the first colon
321
+ name, _, version = s.partition(":")
322
+
323
+ if name not in self.available_functions:
324
+ raise ProConUnknownFunctionError(f"Unknown function: '{name}'!", func_name=name)
325
+
326
+ # If no version was provided as parameter, return the version annotated at the function
327
+ if not version:
328
+ version = self.available_functions[name]
329
+ else:
330
+ if not is_valid_version_string(version):
331
+ raise ValueError(f"Given version: '{version}' is not valid. Allowed: a-Z 0-9 and '_' '-' '.'")
332
+
333
+ # Ensure the version string contains no '.' so we don't mess with messaging
334
+ version = escape_version_string(version)
335
+
336
+ function_version_dict[name] = version
337
+
338
+ return function_version_dict
339
+
340
+ def check_function_from_offer(self, job_offer_msg: JobOfferMessageBody) -> FunctionVersion | None:
341
+ """Ensures that the function and version in the job.offer is available in this container."""
342
+
343
+ if (function_name := job_offer_msg.content.algorithm) not in self.configured_functions:
344
+ msg = (f"Received job offer for unknown function: '{function_name}'!"
345
+ f" Available are: {list(self.configured_functions.keys())}")
346
+ log.error(msg)
347
+ raise ProConUnknownFunctionError(msg)
348
+
349
+ requested_version = escape_version_string(job_offer_msg.content.algorithm_version)
350
+ available_version = self.configured_functions[function_name]
351
+ if requested_version != available_version:
352
+ msg = (
353
+ f"Function '{function_name}' is available in version '{available_version}', "
354
+ f"but the requested version is: '{requested_version}'"
355
+ )
356
+ log.error(msg)
357
+ raise ProConUnknownFunctionError(msg)
358
+
359
+ return FunctionVersion(name=function_name, version=requested_version)
360
+
361
+ async def run_job(self, job_offer_msg: JobOfferMessageBody) -> None:
362
+ """Run a specific job according to a job.offer message."""
363
+
364
+ function_and_version = self.check_function_from_offer(job_offer_msg)
365
+
366
+ await self._busy_lock.acquire()
367
+ try:
368
+ job = ProConJob(
369
+ step=self.step,
370
+ function=function_and_version,
371
+ context_id=self.context_id,
372
+ job_offer=job_offer_msg.content,
373
+ _rmq_client=self._worker_com.rmq_client,
374
+ )
375
+ await job.start()
376
+ await job.process()
377
+ except Exception:
378
+ raise
379
+ finally:
380
+ self._busy_lock.release()
381
+
382
+ async def process_jobs(self):
383
+ """Processing loop for incoming `job.offer` messages."""
384
+ log.debug("🔄 Entering processing loop.")
385
+ self._running.set()
386
+ while self._running.is_set():
387
+ try:
388
+ if self.idle_timeout > 0:
389
+ log.info(f"Idle timeout set to: {self.idle_timeout}s")
390
+ async with asyncio.timeout(delay=self.idle_timeout):
391
+ message = await self._incoming_msg_queue.get()
392
+ else:
393
+ message = await self._incoming_msg_queue.get()
394
+
395
+ except (TimeoutError, asyncio.TimeoutError):
396
+ log.info(f"⌛ Idle timeout! Worker was idle for more than {self.idle_timeout}s.")
397
+ raise ProConShutdown("Idle-timeout exceeded")
398
+ except asyncio.CancelledError:
399
+ raise ProConShutdown("Processing loop canceled")
400
+
401
+ try:
402
+ self._job_task = asyncio.create_task(self.run_job(message), name="job_main")
403
+ await self._job_task
404
+ except CancelledError:
405
+ pass # Wait for the computation to finish, when this task is cancelled
406
+ finally:
407
+ self._incoming_msg_queue.task_done()
408
+
409
+ self._running.clear()
410
+ log.debug("⏹ Exiting processing loop.")
411
+ raise ProConShutdown("Processing loop stopped")
412
+
413
+ async def stop(self, timeout: float | None = None) -> None:
414
+ """Stop the processing loop for incoming job offers.
415
+ If a job is still running, wait for the computation to finish."""
416
+ self._running.clear()
417
+ await self._worker_com.stop_consuming_job_offers()
418
+
419
+ # Wait for running jobs by waiting for the busy lock
420
+ if self._busy_lock.locked():
421
+ log.debug("Waiting for running Job to finish ...")
422
+ await asyncio.wait_for(self._busy_lock.acquire(), timeout=timeout)
423
+ self._busy_lock.release()
424
+
425
+ # Force the processing loop to exit, when waiting for the queue's .get()
426
+ if self._processing_task is not None:
427
+ self._processing_task.cancel()
428
+
429
+ def _on_idle_timeout(self) -> None:
430
+ """A wrapper to call self.stop() from synchronous code."""
431
+ log.info("Worker reached idle timeout. Shutting down ...")
432
+ self._stop_task = asyncio.create_task(self.stop(), name="stop-worker-task")
433
+ self._stop_task.add_done_callback(handle_task_result)
434
+
435
+ async def _on_worker_cmd(self, worker_cmd_msg: WorkerCommandMessageBody) -> None:
436
+ # TODO
437
+ raise NotImplementedError()
@@ -0,0 +1,3 @@
1
+ # ruff: noqa: F401
2
+ from .step import Step, ExecutionContext
3
+ from .versioning import version
@@ -0,0 +1,234 @@
1
+ import inspect
2
+ from importlib.metadata import version as get_package_version
3
+ from types import NoneType
4
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Type
5
+
6
+ import docstring_parser
7
+ import pydantic
8
+ from pydantic import Field
9
+
10
+ from ..core.exceptions import ProConSchemaError
11
+ from ..dataslots.annotation import RETURN_SLOT_NAME, SlotType, dataslot, get_dataslot_metadata
12
+ from ..dataslots.dataslots import isdataslot
13
+ from .schema import DataslotModel, DynamicParameters, DynamicReturns, FunctionModel
14
+ from .versioning import get_version_metadata
15
+
16
+
17
+ if TYPE_CHECKING:
18
+ from ..step import Step
19
+
20
+
21
class StepClassInfo:
    """Collects the public methods of a step object and builds their schemas."""

    def __init__(self, step: "Step"):
        self.cls = step
        self.funcs = StepClassInfo.funcs_from_cls(step)

    @staticmethod
    def funcs_from_cls(cls: object) -> List[Tuple[str, Callable]]:
        """List _(name, func)_ pairs for every public bound method found on _cls_.

        Members whose name starts with an underscore are skipped.
        """
        public_methods = []
        for member_name, member in inspect.getmembers(cls):
            if member_name.startswith("_"):
                continue
            if inspect.ismethod(member):
                public_methods.append((member_name, member))
        return public_methods

    def get_func_schemas(self) -> Dict[str, "FunctionSchema"]:
        """Build a `FunctionSchema` for each collected method, keyed by method name."""
        schemas = {}
        for func_name, func in self.funcs:
            schemas[func_name] = FunctionSchema(func)
        return schemas
39
+
40
+
41
class FunctionSchema:
    """Extract the signature of a given function from its annotations and docstrings.

    On construction the function's docstring, its dataslot decorators and its
    type annotations are read; the remaining methods turn that information
    into pydantic models and a plain signature dict.
    """

    name: str
    function: Callable[[Any], Any]
    version: str

    _docs: docstring_parser.Docstring
    _returns_docs: str
    _param_docs: dict[str, str]
    _param_annotations: dict[str, inspect.Parameter]
    _dataslot_annotations: dict[str, inspect.Parameter]

    _decorator_slots: dict[str, dataslot]
    _parameter_slots: dict[str, dataslot]
    _signature: inspect.Signature

    def __init__(self, function: Callable):
        """Inspect `function` and collect everything needed for schema generation.

        Raises:
            ProConSchemaError: If a decorated dataslot matches no parameter, a type
                annotation is missing, or a wildcard annotation is used.
        """
        self.name = function.__name__
        self.function = function
        self.version = str(get_version_metadata(function))
        self._signature = inspect.signature(function)
        # The order matters: later steps read the results of earlier ones.
        self._init_docstrings()
        self._init_decorated_dataslots()
        self._init_annotations()
        self._init_annotated_dataslots()

    def _init_docstrings(self):
        """Parse the docstring and index parameter/return descriptions."""
        self._docs = docstring_parser.parse(inspect.getdoc(self.function))
        self._param_docs = {p.arg_name: p.description for p in self._docs.params}
        self._returns_docs = self._docs.returns.description if self._docs.returns else ""

    def _init_decorated_dataslots(self):
        """Read the dataslots defined by decorators and validate their names."""
        # Dataslots defined by function decorator
        self._decorator_slots = get_dataslot_metadata(self.function)

        # Validate for each dataslot that a parameter with that name exists
        decorated_names = set(self._decorator_slots.keys()) - {RETURN_SLOT_NAME}
        parameter_names = set(self._signature.parameters.keys())
        if diff := decorated_names - parameter_names:
            raise ProConSchemaError(
                f"The dataslots {diff} of function '{self.name}' don't match any parameters name!"
            )

        # Update description for the return value with the docstring, if not set in decorator
        if ((RETURN_SLOT_NAME in self._decorator_slots)
                and not self._decorator_slots[RETURN_SLOT_NAME].description):
            self._decorator_slots[RETURN_SLOT_NAME].description = self._returns_docs

    def _init_annotations(self):
        """Validate the function's annotations and collect its parameters.

        Raises:
            ProConSchemaError: If an annotation is missing or is `Any`/`object`.
        """
        self._dataslot_annotations = {}
        self._param_annotations = {}
        sig = self._signature

        # Check if all parameters and return values have type annotations
        if (any((p.annotation is sig.empty for p in sig.parameters.values()))
                or sig.return_annotation is sig.empty):
            raise ProConSchemaError(
                f"Can not generate schema for function '{self.name}'. Type annotation is missing!"
            )

        # Disallow wildcard `Any` or `object` type annotation
        if (any((p.annotation in (Any, object) for p in sig.parameters.values()))
                or sig.return_annotation in (Any, object)):
            raise ProConSchemaError(
                f"Can not generate schema for function '{self.name}'. "
                f"Wildcard types like `Any` or `object` are not allowed for type annotations!"
            )

        for name, p in sig.parameters.items():
            # Collect all parameters that are by type annotation or function decorator a Dataslot
            if isdataslot(p.annotation) or (name in self._decorator_slots):
                self._dataslot_annotations[name] = p
            # Every parameter is recorded here as well; dataslots are filtered
            # out on demand in `get_parameters_model`.
            self._param_annotations[name] = p

    def _init_annotated_dataslots(self):
        """Create `dataslot` objects for all parameters collected as dataslots."""
        self._parameter_slots = {}
        for name, p in self._dataslot_annotations.items():
            # Look up the description by the parameter's name
            # (not by the literal string "name").
            self._parameter_slots[name] = dataslot(
                name=p.name, dtype=p.annotation, description=self._param_docs.get(name, "")
            )

    def _get_parameters_signature(self) -> Dict[str, tuple[Any, Field]]:
        """Returns the parameters annotation combined with their docstring as a dict"""
        params_schema = {}
        for name, p in self._param_annotations.items():
            params_schema[name] = (
                p.annotation,  # parameter type
                Field(
                    title=name,
                    # `...` marks the field as required when there is no default
                    default=p.default if p.default is not p.empty else ...,
                    description=self._param_docs.get(name, None),
                ),
            )
        return params_schema

    def _get_return_signature(self) -> dict[str, tuple[Any, Field]]:
        """Returns the return type combined with its docstring as a dict"""
        return_type = self._signature.return_annotation
        if return_type in (self._signature.empty, None):
            return_type = NoneType
        return {"value": (return_type, Field(..., description=self._returns_docs))}

    def get_parameters_model(self, exclude_dataslots: bool = False) -> Type[DynamicParameters]:
        """Function parameters as pydantic model for schema generation.

        Args:
            exclude_dataslots: If True, parameters that are dataslots are left out.
        """
        params_signature = self._get_parameters_signature()
        if exclude_dataslots:
            slot_names = self.dataslots  # evaluate the property only once
            params_signature = {
                n: p for n, p in params_signature.items() if n not in slot_names
            }
        return pydantic.create_model("Parameters", __base__=DynamicParameters, **params_signature)

    def get_returns_model(self, exclude_dataslots: bool = False) -> Type[DynamicReturns]:
        """Function's return type as pydantic model for schema generation.

        Args:
            exclude_dataslots: If True and the return value is a dataslot,
                the return type is reported as `NoneType`.
        """
        if exclude_dataslots and (RETURN_SLOT_NAME in self._decorator_slots):
            returns_signature = {"value": (NoneType, Field(...))}
        else:
            returns_signature = self._get_return_signature()
        return pydantic.create_model("Returns", __base__=DynamicReturns, **returns_signature)

    @property
    def dataslots(self) -> dict[str, dataslot]:
        """Merge all dataslot information from decorators and type annotation"""
        param_ds = self._parameter_slots
        deco_ds = self._decorator_slots
        # Look for a dataslots name in parameters and decorators,
        # if not found create a default dataslot object and merge them.
        dataslot_names = list(self._dataslot_annotations.keys())
        if RETURN_SLOT_NAME in self._decorator_slots:
            dataslot_names.append(RETURN_SLOT_NAME)
        return {
            name: (param_ds.get(name, dataslot(name))).update(deco_ds.get(name, dataslot(name)))
            for name in dataslot_names
        }

    def get_dataslot_models(self) -> tuple[list[DataslotModel], list[DataslotModel]]:
        """Create the signatures for the dataslots.

        Returns:
            Two lists: the input dataslot models and the output dataslot models.
        """
        input_dataslots = []
        output_dataslots = []
        for d in self.dataslots.values():
            slot_model = DataslotModel(
                name=d.alias or d.name,
                title=d.title,
                description=d.description,
                mediatype=str(d.media_type),  # user could use a strenum or similar
                metadata={},
                max_slots=d.max_slots,
                min_slots=d.min_slots,
            )

            if d.slot_type is SlotType.INPUT:
                input_dataslots.append(slot_model)
            else:
                output_dataslots.append(slot_model)
        return input_dataslots, output_dataslots

    @property
    def signature(self) -> dict:
        """Create the functions signature with a parameter list, return type and docstrings"""
        in_ds, out_ds = self.get_dataslot_models()

        return_schema = self.get_returns_model(exclude_dataslots=True).model_json_schema()

        # The parameter schema stays None if no non-dataslot parameters exist
        param_schema = None
        p = self.get_parameters_model(exclude_dataslots=True)
        if len(p.model_fields) > 0:
            param_schema = p.model_json_schema()

        fields = dict(
            version=str(self.version),
            function_name=self.name,
            short_description=self._docs.short_description or "",
            long_description=self._docs.long_description or "",
            parameters=param_schema,
            returns=return_schema,
            input_dataslots=[m.model_dump() for m in in_ds],
            output_dataslots=[m.model_dump() for m in out_ds],
            procon_version=get_procon_version(),
        )
        return fields

    def get_function_model(self) -> FunctionModel:
        """Wrap the signature dict in a `FunctionModel`."""
        return FunctionModel(**self.signature)
231
+
232
+
233
def get_procon_version() -> str:
    """Return the version of the `pinexq-procon` distribution as reported by its package metadata."""
    distribution_name = "pinexq-procon"
    return get_package_version(distribution_name)