qena-shared-lib 0.1.12__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,8 +1,7 @@
1
- from typing import Annotated
1
+ from typing import cast
2
2
 
3
3
  from pydantic import ValidationError
4
4
 
5
- from ..dependencies.miscellaneous import DependsOn
6
5
  from ..exceptions import (
7
6
  HTTPServiceError,
8
7
  RabbitMQServiceException,
@@ -10,138 +9,174 @@ from ..exceptions import (
10
9
  Severity,
11
10
  )
12
11
  from ..logging import LoggerProvider
13
- from ..logstash._base import BaseLogstashSender
12
+ from ..remotelogging._base import BaseRemoteLogSender
14
13
  from ._listener import ListenerContext
15
14
 
16
15
  __all__ = [
17
- "handle_general_mq_exception",
18
- "handle_rabbit_mq_service_exception",
19
- "handle_validation_error",
16
+ "AbstractRabbitMqExceptionHandler",
17
+ "GeneralMqExceptionHandler",
18
+ "RabbitMqServiceExceptionHandler",
19
+ "ValidationErrorHandler",
20
20
  ]
21
21
 
22
22
  RABBITMQ_EXCEPTION_HANDLER_LOGGER_NAME = "rabbitmq.exception_handler"
23
23
 
24
24
 
25
- def handle_rabbit_mq_service_exception(
26
- context: ListenerContext,
27
- exception: ServiceException,
28
- logstash: Annotated[BaseLogstashSender, DependsOn(BaseLogstashSender)],
29
- logger_provider: Annotated[LoggerProvider, DependsOn(LoggerProvider)],
30
- ) -> None:
31
- logger = logger_provider.get_logger(RABBITMQ_EXCEPTION_HANDLER_LOGGER_NAME)
32
- tags = [
33
- "RabbitMQ",
34
- context.queue,
35
- context.listener_name or "__default__",
36
- exception.__class__.__name__,
37
- ]
38
- extra = {
39
- "serviceType": "RabbitMQ",
40
- "queue": context.queue,
41
- "listenerName": context.listener_name,
42
- "exception": exception.__class__.__name__,
43
- }
44
-
45
- match exception:
46
- case HTTPServiceError() as http_service_error:
47
- if http_service_error.status_code is not None:
48
- str_status_code = str(http_service_error.status_code)
49
- extra["statusCode"] = str_status_code
50
-
51
- tags.append(str_status_code)
52
-
53
- if http_service_error.response_code is not None:
54
- str_response_code = str(http_service_error.response_code)
55
- extra["responseCode"] = str_response_code
56
-
57
- tags.append(str_response_code)
58
- case RabbitMQServiceException() as rabbitmq_service_exception:
59
- str_error_code = str(rabbitmq_service_exception.code)
60
- extra["code"] = str_error_code
61
-
62
- tags.append(str_error_code)
63
-
64
- if exception.tags:
65
- tags.extend(exception.tags)
66
-
67
- if exception.extra:
68
- extra.update(exception.extra)
69
-
70
- exc_info = (
71
- (type(exception), exception, exception.__traceback__)
72
- if exception.extract_exc_info
73
- else None
74
- )
75
-
76
- match exception.severity:
77
- case Severity.HIGH:
78
- logstash_logger_method = logstash.error
79
- logger_method = logger.error
80
- case Severity.MEDIUM:
81
- logstash_logger_method = logstash.warning
82
- logger_method = logger.warning
83
- case _:
84
- logstash_logger_method = logstash.info
85
- logger_method = logger.info
86
-
87
- if exception.logstash_logging:
88
- logstash_logger_method(
89
- message=exception.message,
90
- tags=tags,
91
- extra=extra,
92
- exception=exception if exception.extract_exc_info else None,
93
- )
94
- else:
95
- logger_method(
96
- "\nRabbitMQ `%s` -> `%s`\n%s",
97
- context.queue,
98
- context.listener_name,
99
- exception.message,
100
- exc_info=exc_info,
101
- )
25
+ class AbstractRabbitMqExceptionHandler:
26
+ @property
27
+ def exception(self) -> type[Exception]:
28
+ raise NotImplementedError()
102
29
 
103
30
 
104
- def handle_validation_error(
105
- context: ListenerContext,
106
- exception: ValidationError,
107
- logstash: Annotated[BaseLogstashSender, DependsOn(BaseLogstashSender)],
108
- ) -> None:
109
- logstash.error(
110
- message=f"invalid rabbitmq request data at queue `{context.queue}` and listener `{context.listener_name}`",
111
- tags=[
112
- "RabbitMQ",
113
- context.queue,
114
- context.listener_name or "__default__",
115
- "ValidationError",
116
- ],
117
- extra={
118
- "serviceType": "RabbitMQ",
119
- "queue": context.queue,
120
- "listenerName": context.listener_name,
121
- "exception": "ValidationError",
122
- },
123
- exception=exception,
124
- )
125
-
126
-
127
- def handle_general_mq_exception(
128
- context: ListenerContext,
129
- exception: Exception,
130
- logstash: Annotated[BaseLogstashSender, DependsOn(BaseLogstashSender)],
131
- ) -> None:
132
- logstash.error(
133
- message=f"something went wrong while consuming message on queue `{context.queue}` and listener `{context.listener_name}`",
134
- tags=[
31
+ class RabbitMqServiceExceptionHandler(AbstractRabbitMqExceptionHandler):
32
+ @property
33
+ def exception(self) -> type[Exception]:
34
+ return cast(type[Exception], ServiceException)
35
+
36
+ def __init__(
37
+ self,
38
+ remote_logger: BaseRemoteLogSender,
39
+ logger_provider: LoggerProvider,
40
+ ):
41
+ self._logger = logger_provider.get_logger(
42
+ RABBITMQ_EXCEPTION_HANDLER_LOGGER_NAME
43
+ )
44
+ self._remote_logger = remote_logger
45
+
46
+ def __call__(
47
+ self,
48
+ context: ListenerContext,
49
+ exception: ServiceException,
50
+ ) -> None:
51
+ tags = [
135
52
  "RabbitMQ",
136
53
  context.queue,
137
54
  context.listener_name or "__default__",
138
55
  exception.__class__.__name__,
139
- ],
140
- extra={
56
+ ]
57
+ extra = {
141
58
  "serviceType": "RabbitMQ",
142
59
  "queue": context.queue,
143
60
  "listenerName": context.listener_name,
144
61
  "exception": exception.__class__.__name__,
145
- },
146
- exception=exception,
147
- )
62
+ }
63
+
64
+ match exception:
65
+ case HTTPServiceError() as http_service_error:
66
+ if http_service_error.status_code is not None:
67
+ str_status_code = str(http_service_error.status_code)
68
+ extra["statusCode"] = str_status_code
69
+
70
+ tags.append(str_status_code)
71
+
72
+ if http_service_error.response_code is not None:
73
+ str_response_code = str(http_service_error.response_code)
74
+ extra["responseCode"] = str_response_code
75
+
76
+ tags.append(str_response_code)
77
+ case RabbitMQServiceException() as rabbitmq_service_exception:
78
+ str_error_code = str(rabbitmq_service_exception.code)
79
+ extra["code"] = str_error_code
80
+
81
+ tags.append(str_error_code)
82
+
83
+ if exception.tags:
84
+ tags.extend(exception.tags)
85
+
86
+ if exception.extra:
87
+ extra.update(exception.extra)
88
+
89
+ exc_info = (
90
+ (type(exception), exception, exception.__traceback__)
91
+ if exception.extract_exc_info
92
+ else None
93
+ )
94
+
95
+ match exception.severity:
96
+ case Severity.HIGH:
97
+ remote_logger_method = self._remote_logger.error
98
+ logger_method = self._logger.error
99
+ case Severity.MEDIUM:
100
+ remote_logger_method = self._remote_logger.warning
101
+ logger_method = self._logger.warning
102
+ case _:
103
+ remote_logger_method = self._remote_logger.info
104
+ logger_method = self._logger.info
105
+
106
+ if exception.remote_logging:
107
+ remote_logger_method(
108
+ message=exception.message,
109
+ tags=tags,
110
+ extra=extra,
111
+ exception=exception if exception.extract_exc_info else None,
112
+ )
113
+ else:
114
+ logger_method(
115
+ "\nRabbitMQ `%s` -> `%s`\n%s",
116
+ context.queue,
117
+ context.listener_name,
118
+ exception.message,
119
+ exc_info=exc_info,
120
+ )
121
+
122
+
123
+ class ValidationErrorHandler(AbstractRabbitMqExceptionHandler):
124
+ @property
125
+ def exception(self) -> type[Exception]:
126
+ return cast(type[Exception], ValidationError)
127
+
128
+ def __init__(self, remote_logger: BaseRemoteLogSender):
129
+ self._remote_logger = remote_logger
130
+
131
+ def __call__(
132
+ self,
133
+ context: ListenerContext,
134
+ exception: ValidationError,
135
+ ) -> None:
136
+ self._remote_logger.error(
137
+ message=f"invalid rabbitmq request data at queue `{context.queue}` and listener `{context.listener_name}`",
138
+ tags=[
139
+ "RabbitMQ",
140
+ context.queue,
141
+ context.listener_name or "__default__",
142
+ "ValidationError",
143
+ ],
144
+ extra={
145
+ "serviceType": "RabbitMQ",
146
+ "queue": context.queue,
147
+ "listenerName": context.listener_name,
148
+ "exception": "ValidationError",
149
+ },
150
+ exception=exception,
151
+ )
152
+
153
+
154
+ class GeneralMqExceptionHandler(AbstractRabbitMqExceptionHandler):
155
+ @property
156
+ def exception(self) -> type[Exception]:
157
+ return Exception
158
+
159
+ def __init__(self, remote_logger: BaseRemoteLogSender):
160
+ self._remote_logger = remote_logger
161
+
162
+ def __call__(
163
+ self,
164
+ context: ListenerContext,
165
+ exception: Exception,
166
+ ) -> None:
167
+ self._remote_logger.error(
168
+ message=f"something went wrong while consuming message on queue `{context.queue}` and listener `{context.listener_name}`",
169
+ tags=[
170
+ "RabbitMQ",
171
+ context.queue,
172
+ context.listener_name or "__default__",
173
+ exception.__class__.__name__,
174
+ ],
175
+ extra={
176
+ "serviceType": "RabbitMQ",
177
+ "queue": context.queue,
178
+ "listenerName": context.listener_name,
179
+ "exception": exception.__class__.__name__,
180
+ },
181
+ exception=exception,
182
+ )
@@ -1,12 +1,20 @@
1
1
  from abc import ABC, abstractmethod
2
- from asyncio import AbstractEventLoop, Future, Lock, Task, iscoroutinefunction
2
+ from asyncio import (
3
+ AbstractEventLoop,
4
+ Future,
5
+ Lock,
6
+ Task,
7
+ gather,
8
+ iscoroutinefunction,
9
+ )
3
10
  from dataclasses import dataclass
11
+ from decimal import Decimal
4
12
  from functools import partial
5
13
  from inspect import Parameter, signature
6
14
  from random import uniform
7
15
  from time import time
8
16
  from types import MappingProxyType
9
- from typing import Any, Callable, Collection, TypeVar
17
+ from typing import Any, Callable, Collection, TypeVar, cast
10
18
 
11
19
  from pika import BasicProperties
12
20
  from pika.adapters.asyncio_connection import AsyncioConnection
@@ -21,7 +29,7 @@ from pydantic_core import from_json, to_json
21
29
  from ..dependencies.miscellaneous import validate_annotation
22
30
  from ..exceptions import RabbitMQServiceException
23
31
  from ..logging import LoggerProvider
24
- from ..logstash import BaseLogstashSender
32
+ from ..remotelogging import BaseRemoteLogSender
25
33
  from ..utils import AsyncEventLoopMixin, TypeAdapterCache
26
34
  from ._channel import BaseChannel
27
35
  from ._pool import ChannelPool
@@ -59,7 +67,7 @@ class FlowControl:
59
67
  self._channel = channel
60
68
  self._loop = loop
61
69
  self._lock = Lock()
62
- self._flow_control_future: Future | None = None
70
+ self._flow_control_future: Future[None] | None = None
63
71
 
64
72
  async def request(self, prefetch_count: int) -> None:
65
73
  async with self._lock:
@@ -126,7 +134,7 @@ class ListenerContext:
126
134
  body: bytes
127
135
  flow_control: FlowControl
128
136
  rpc_reply: RpcReply | None = None
129
- context_dispose_callback: Callable | None = None
137
+ context_dispose_callback: Callable[..., None] | None = None
130
138
 
131
139
  def dispose(self) -> None:
132
140
  if self.context_dispose_callback is not None:
@@ -202,7 +210,7 @@ class RetryPolicy:
202
210
 
203
211
  @dataclass
204
212
  class ListenerMethodContainer:
205
- listener_method: Callable
213
+ listener_method: Callable[..., Any]
206
214
  parameters: MappingProxyType[str, Parameter]
207
215
  dependencies: dict[str, type]
208
216
  retry_policy: RetryPolicy | None = None
@@ -213,7 +221,7 @@ class ListenerChannelAdapter(BaseChannel):
213
221
  self,
214
222
  connection: AsyncioConnection,
215
223
  on_channel_open_callback: Callable[[Channel], None],
216
- on_cancel_callback: Callable,
224
+ on_cancel_callback: Callable[..., None],
217
225
  ) -> None:
218
226
  super().__init__(
219
227
  connection=connection,
@@ -273,6 +281,9 @@ class Listener(AsyncEventLoopMixin):
273
281
  self._purge_on_startup = purge_on_startup
274
282
  self._retry_policy = retry_policy
275
283
  self._listeners: dict[str, ListenerMethodContainer] = {}
284
+ self._listeners_tasks_and_futures: list[Task[Any] | Future[Any]] = []
285
+ self._consumer_tag: str | None = None
286
+ self._cancelled = False
276
287
  self._logger = LoggerProvider.default().get_logger("rabbitmq.listener")
277
288
 
278
289
  @property
@@ -290,7 +301,7 @@ class Listener(AsyncEventLoopMixin):
290
301
  def add_listener_method(
291
302
  self,
292
303
  listener_name: str | None,
293
- listener_method: Callable,
304
+ listener_method: Callable[..., Any],
294
305
  retry_policy: RetryPolicy | None = None,
295
306
  ) -> None:
296
307
  self._register_listener_method(
@@ -303,7 +314,7 @@ class Listener(AsyncEventLoopMixin):
303
314
  def _register_listener_method(
304
315
  self,
305
316
  listener_name: str | None,
306
- listener_method: Callable,
317
+ listener_method: Callable[..., Any],
307
318
  parameters: MappingProxyType[str, Parameter],
308
319
  retry_policy: RetryPolicy | None = None,
309
320
  ) -> None:
@@ -343,7 +354,7 @@ class Listener(AsyncEventLoopMixin):
343
354
  channel_pool: ChannelPool,
344
355
  on_exception_callback: Callable[[ListenerContext, BaseException], bool],
345
356
  container: Container,
346
- logstash: BaseLogstashSender,
357
+ remote_logger: BaseRemoteLogSender,
347
358
  global_retry_policy: RetryPolicy | None = None,
348
359
  ) -> None:
349
360
  self._connection = connection
@@ -351,7 +362,7 @@ class Listener(AsyncEventLoopMixin):
351
362
  self._listener_future = self.loop.create_future()
352
363
  self._on_exception_callback = on_exception_callback
353
364
  self._container = container
354
- self._logstash = logstash
365
+ self._remote_logger = remote_logger
355
366
  self._global_retry_policy = global_retry_policy
356
367
  self._listener_channel = ListenerChannelAdapter(
357
368
  connection=connection,
@@ -362,6 +373,16 @@ class Listener(AsyncEventLoopMixin):
362
373
 
363
374
  await self._listener_future
364
375
 
376
+ async def cancel(self) -> None:
377
+ self._cancelled = True
378
+
379
+ if self._consumer_tag is not None:
380
+ self._channel.basic_cancel(self._consumer_tag)
381
+
382
+ _ = await gather(
383
+ *self._listeners_tasks_and_futures, return_exceptions=True
384
+ )
385
+
365
386
  def _on_channel_opened(self, channel: Channel) -> None:
366
387
  self._channel = channel
367
388
  self._flow_control = FlowControl(channel=self._channel, loop=self.loop)
@@ -369,6 +390,9 @@ class Listener(AsyncEventLoopMixin):
369
390
  self._declare_queue()
370
391
 
371
392
  def _on_cancelled(self) -> None:
393
+ if self._cancelled:
394
+ return
395
+
372
396
  self._declare_queue()
373
397
 
374
398
  def _declare_queue(self) -> None:
@@ -406,7 +430,7 @@ class Listener(AsyncEventLoopMixin):
406
430
 
407
431
  def _register_listener(self) -> None:
408
432
  try:
409
- _ = self._channel.basic_consume(
433
+ self._consumer_tag = self._channel.basic_consume(
410
434
  queue=self._queue,
411
435
  auto_ack=True,
412
436
  on_message_callback=self._on_message,
@@ -443,6 +467,9 @@ class Listener(AsyncEventLoopMixin):
443
467
  else:
444
468
  listener_name = "__default__"
445
469
 
470
+ if not isinstance(listener_name, str):
471
+ listener_name = str(listener_name)
472
+
446
473
  self._logger.debug(
447
474
  "message recieved from `%s` queue for listener `%s`",
448
475
  self._queue,
@@ -452,7 +479,7 @@ class Listener(AsyncEventLoopMixin):
452
479
  listener_method_container = self._listeners.get(listener_name)
453
480
 
454
481
  if listener_method_container is None:
455
- self._logstash.error(
482
+ self._remote_logger.error(
456
483
  message=f"no listener registered with the name `{listener_name}` on queue `{self._queue}`",
457
484
  tags=[
458
485
  "RabbitMQ",
@@ -485,7 +512,7 @@ class Listener(AsyncEventLoopMixin):
485
512
  )
486
513
 
487
514
  def _on_submitted_listener_error(
488
- self, listener_message_meta: ListenerMessageMeta, future: Future
515
+ self, listener_message_meta: ListenerMessageMeta, future: Future[None]
489
516
  ) -> None:
490
517
  if future.cancelled():
491
518
  return
@@ -515,31 +542,36 @@ class Listener(AsyncEventLoopMixin):
515
542
 
516
543
  return
517
544
 
518
- listener_done_callback = partial(
519
- self._on_listener_done_executing, listener_message_meta
520
- )
545
+ listener_task_or_future: Task[Any] | Future[Any] | None = None
521
546
 
522
547
  if iscoroutinefunction(
523
548
  listener_message_meta.listener_method_container.listener_method
524
549
  ):
525
- self.loop.create_task(
550
+ listener_task_or_future = self.loop.create_task(
526
551
  listener_message_meta.listener_method_container.listener_method(
527
552
  *listener_method_args, **listener_method_kwargs
528
553
  )
529
- ).add_done_callback(listener_done_callback)
554
+ )
530
555
  else:
531
- self.loop.run_in_executor(
556
+ listener_task_or_future = self.loop.run_in_executor(
532
557
  executor=None,
533
558
  func=partial(
534
559
  listener_message_meta.listener_method_container.listener_method,
535
560
  *listener_method_args,
536
561
  **listener_method_kwargs,
537
562
  ),
538
- ).add_done_callback(listener_done_callback)
563
+ )
564
+
565
+ assert listener_task_or_future is not None
566
+
567
+ self._listeners_tasks_and_futures.append(listener_task_or_future)
568
+ listener_task_or_future.add_done_callback(
569
+ partial(self._on_listener_done_executing, listener_message_meta)
570
+ )
539
571
 
540
572
  def _parse_args(
541
573
  self, listener_message_meta: ListenerMessageMeta
542
- ) -> tuple[list, dict]:
574
+ ) -> tuple[list[Any], dict[str, Any]]:
543
575
  try:
544
576
  message = from_json(listener_message_meta.body)
545
577
  except:
@@ -739,8 +771,14 @@ class Listener(AsyncEventLoopMixin):
739
771
  def _on_listener_done_executing(
740
772
  self,
741
773
  listener_message_meta: ListenerMessageMeta,
742
- task_or_future: Task | Future,
774
+ task_or_future: Task[Any] | Future[Any],
743
775
  ) -> None:
776
+ if (
777
+ not self._cancelled
778
+ and task_or_future in self._listeners_tasks_and_futures
779
+ ):
780
+ self._listeners_tasks_and_futures.remove(task_or_future)
781
+
744
782
  if task_or_future.cancelled():
745
783
  return
746
784
 
@@ -773,11 +811,19 @@ class Listener(AsyncEventLoopMixin):
773
811
 
774
812
  if listener_message_meta.properties.headers is not None:
775
813
  try:
776
- times_rejected = int(
814
+ _times_rejected = (
777
815
  listener_message_meta.properties.headers.get(
778
816
  "times_rejected"
779
817
  )
780
818
  )
819
+
820
+ if isinstance(_times_rejected, int):
821
+ times_rejected = _times_rejected
822
+ elif isinstance(
823
+ _times_rejected,
824
+ (str, bytes, bytearray, Decimal),
825
+ ):
826
+ times_rejected = int(_times_rejected)
781
827
  except:
782
828
  pass
783
829
 
@@ -911,7 +957,7 @@ class Listener(AsyncEventLoopMixin):
911
957
  exception,
912
958
  )
913
959
  except:
914
- self._logstash.exception(
960
+ self._remote_logger.exception(
915
961
  message=f"error occured while invoking rabbitmq exception handler callback in listener `{listener_message_meta.listener_name}` and queue `{self._queue}`",
916
962
  tags=[
917
963
  "RabbitMQ",
@@ -930,7 +976,7 @@ class Listener(AsyncEventLoopMixin):
930
976
  return
931
977
 
932
978
  if not exception_callback_succeeded:
933
- self._logstash.exception(
979
+ self._remote_logger.exception(
934
980
  message=(
935
981
  message
936
982
  or f"error occured while handling event in listener `{listener_message_meta.listener_name}` and queue `{self._queue}`"
@@ -1013,7 +1059,9 @@ class Listener(AsyncEventLoopMixin):
1013
1059
  if listener_message_meta.properties.headers is None:
1014
1060
  listener_message_meta.properties.headers = headers
1015
1061
  else:
1016
- listener_message_meta.properties.headers.update(headers)
1062
+ cast(
1063
+ dict[str, Any], listener_message_meta.properties.headers
1064
+ ).update(headers)
1017
1065
 
1018
1066
  try:
1019
1067
  with task.result() as channel:
@@ -1079,7 +1127,9 @@ class Listener(AsyncEventLoopMixin):
1079
1127
  )
1080
1128
  )
1081
1129
 
1082
- def _reponse_from_exception(self, exception: BaseException) -> dict:
1130
+ def _reponse_from_exception(
1131
+ self, exception: BaseException
1132
+ ) -> dict[str, Any]:
1083
1133
  match exception:
1084
1134
  case RabbitMQServiceException() as rabbitmq_exception:
1085
1135
  code = rabbitmq_exception.code
@@ -1123,6 +1173,8 @@ class Listener(AsyncEventLoopMixin):
1123
1173
  return
1124
1174
 
1125
1175
  try:
1176
+ assert listener_message_meta.properties.reply_to is not None
1177
+
1126
1178
  with task.result() as channel:
1127
1179
  channel.basic_publish(
1128
1180
  exchange=DEFAULT_EXCHANGE,
@@ -1168,8 +1220,8 @@ class Consumer(Listener):
1168
1220
 
1169
1221
  def consume(
1170
1222
  self, target: str | None = None, retry_policy: RetryPolicy | None = None
1171
- ) -> Callable[[Callable], Callable]:
1172
- def wrapper(consumer_method: Callable) -> Callable:
1223
+ ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
1224
+ def wrapper(consumer_method: Callable[..., Any]) -> Callable[..., Any]:
1173
1225
  if not callable(consumer_method):
1174
1226
  raise TypeError(
1175
1227
  f"consumer method argument not a callable, got {type(consumer_method)}"
@@ -1201,8 +1253,8 @@ def consumer(
1201
1253
 
1202
1254
  def consume(
1203
1255
  target: str | None = None, retry_policy: RetryPolicy | None = None
1204
- ) -> Callable[[Callable], Callable]:
1205
- def wrapper(consumer_method: Callable) -> Callable:
1256
+ ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
1257
+ def wrapper(consumer_method: Callable[..., Any]) -> Callable[..., Any]:
1206
1258
  if not callable(consumer_method):
1207
1259
  raise TypeError(
1208
1260
  f"consumer method argument not a callable, got {type(consumer_method)}"
@@ -1230,8 +1282,8 @@ class RpcWorker(Listener):
1230
1282
 
1231
1283
  def execute(
1232
1284
  self, procedure: str | None = None
1233
- ) -> Callable[[Callable], Callable]:
1234
- def wrapper(worker_method: Callable) -> Callable:
1285
+ ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
1286
+ def wrapper(worker_method: Callable[..., Any]) -> Callable[..., Any]:
1235
1287
  if not callable(worker_method):
1236
1288
  raise TypeError(
1237
1289
  f"worker method argument not a callable, got {type(worker_method)}"
@@ -1252,8 +1304,10 @@ def rpc_worker(queue: str, prefetch_count: int = 250) -> RpcWorker:
1252
1304
  return RpcWorker(queue=queue, prefetch_count=prefetch_count)
1253
1305
 
1254
1306
 
1255
- def execute(procedure: str | None = None) -> Callable[[Callable], Callable]:
1256
- def wrapper(worker_method: Callable) -> Callable:
1307
+ def execute(
1308
+ procedure: str | None = None,
1309
+ ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
1310
+ def wrapper(worker_method: Callable[..., Any]) -> Callable[..., Any]:
1257
1311
  if not callable(worker_method):
1258
1312
  raise TypeError(
1259
1313
  f"worker method argument not a callable, got {type(worker_method)}"
@@ -1283,7 +1337,7 @@ class ListenerBase:
1283
1337
  f"{self.__class__.__name__} not a listener, possibly no annotated with either `Consumer` or `RpcWorker`"
1284
1338
  )
1285
1339
 
1286
- return listener
1340
+ return cast(Listener, listener)
1287
1341
 
1288
1342
  def register_listener_methods(self) -> Listener:
1289
1343
  listener = self.get_inner_listener()
@@ -1330,7 +1384,9 @@ class ListenerBase:
1330
1384
  attribute: Any,
1331
1385
  listener_method_attribute: str,
1332
1386
  previous_listener_method_attribute: str | None,
1333
- ) -> tuple[str | None, Callable | None, ListenerMethodMeta | None]:
1387
+ ) -> tuple[
1388
+ str | None, Callable[..., Any] | None, ListenerMethodMeta | None
1389
+ ]:
1334
1390
  listener_method_meta = getattr(
1335
1391
  attribute, listener_method_attribute, None
1336
1392
  )
@@ -27,19 +27,19 @@ __all__ = ["RpcClient"]
27
27
 
28
28
  class ExitHandler:
29
29
  _exiting = False
30
- _rpc_futures: list[Future] = []
31
- _original_exit_handler: Callable
30
+ _rpc_futures: list[Future[Any]] = []
31
+ _original_exit_handler: Callable[..., None]
32
32
 
33
33
  @classmethod
34
34
  def is_exising(cls) -> bool:
35
35
  return cls._exiting
36
36
 
37
37
  @classmethod
38
- def add_rpc_future(cls, rpc_future: Future) -> None:
38
+ def add_rpc_future(cls, rpc_future: Future[Any]) -> None:
39
39
  cls._rpc_futures.append(rpc_future)
40
40
 
41
41
  @classmethod
42
- def remove_rpc_future(cls, rpc_future: Future) -> None:
42
+ def remove_rpc_future(cls, rpc_future: Future[Any]) -> None:
43
43
  try:
44
44
  cls._rpc_futures.remove(rpc_future)
45
45
  except: