taskiq-redis 0.5.6__tar.gz → 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2022-2024 Pavel Kirilin
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: taskiq-redis
3
- Version: 0.5.6
3
+ Version: 1.0.0
4
4
  Summary: Redis integration for taskiq
5
5
  Home-page: https://github.com/taskiq-python/taskiq-redis
6
6
  Keywords: taskiq,tasks,distributed,async,redis,result_backend
@@ -16,7 +16,7 @@ Classifier: Programming Language :: Python :: 3.12
16
16
  Classifier: Programming Language :: Python :: 3 :: Only
17
17
  Classifier: Programming Language :: Python :: 3.8
18
18
  Requires-Dist: redis (>=5,<6)
19
- Requires-Dist: taskiq (>=0.10.3,<1)
19
+ Requires-Dist: taskiq (>=0.11.1,<1)
20
20
  Project-URL: Repository, https://github.com/taskiq-python/taskiq-redis
21
21
  Description-Content-Type: text/markdown
22
22
 
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "taskiq-redis"
3
- version = "0.5.6"
3
+ version = "1.0.0"
4
4
  description = "Redis integration for taskiq"
5
5
  authors = ["taskiq-team <taskiq@norely.com>"]
6
6
  readme = "README.md"
@@ -26,7 +26,7 @@ keywords = [
26
26
 
27
27
  [tool.poetry.dependencies]
28
28
  python = "^3.8.1"
29
- taskiq = ">=0.10.3,<1"
29
+ taskiq = ">=0.11.1,<1"
30
30
  redis = "^5"
31
31
 
32
32
  [tool.poetry.group.dev.dependencies]
@@ -40,7 +40,7 @@ fakeredis = "^2"
40
40
  pre-commit = "^2.20.0"
41
41
  pytest-xdist = { version = "^2.5.0", extras = ["psutil"] }
42
42
  ruff = "^0.1.0"
43
- types-redis = "^4.6.0.7"
43
+ types-redis = "^4.6.0.20240425"
44
44
 
45
45
  [tool.mypy]
46
46
  strict = true
@@ -2,20 +2,30 @@
2
2
  from taskiq_redis.redis_backend import (
3
3
  RedisAsyncClusterResultBackend,
4
4
  RedisAsyncResultBackend,
5
+ RedisAsyncSentinelResultBackend,
5
6
  )
6
7
  from taskiq_redis.redis_broker import ListQueueBroker, PubSubBroker
7
8
  from taskiq_redis.redis_cluster_broker import ListQueueClusterBroker
9
+ from taskiq_redis.redis_sentinel_broker import (
10
+ ListQueueSentinelBroker,
11
+ PubSubSentinelBroker,
12
+ )
8
13
  from taskiq_redis.schedule_source import (
9
14
  RedisClusterScheduleSource,
10
15
  RedisScheduleSource,
16
+ RedisSentinelScheduleSource,
11
17
  )
12
18
 
13
19
  __all__ = [
14
20
  "RedisAsyncClusterResultBackend",
15
21
  "RedisAsyncResultBackend",
22
+ "RedisAsyncSentinelResultBackend",
16
23
  "ListQueueBroker",
17
24
  "PubSubBroker",
18
25
  "ListQueueClusterBroker",
26
+ "ListQueueSentinelBroker",
27
+ "PubSubSentinelBroker",
19
28
  "RedisScheduleSource",
20
29
  "RedisClusterScheduleSource",
30
+ "RedisSentinelScheduleSource",
21
31
  ]
@@ -1,10 +1,25 @@
1
- import pickle
2
- from typing import Any, Dict, Optional, TypeVar, Union
1
+ import sys
2
+ from contextlib import asynccontextmanager
3
+ from typing import (
4
+ TYPE_CHECKING,
5
+ Any,
6
+ AsyncIterator,
7
+ Dict,
8
+ List,
9
+ Optional,
10
+ Tuple,
11
+ TypeVar,
12
+ Union,
13
+ )
3
14
 
4
- from redis.asyncio import BlockingConnectionPool, Redis
15
+ from redis.asyncio import BlockingConnectionPool, Redis, Sentinel
5
16
  from redis.asyncio.cluster import RedisCluster
17
+ from redis.asyncio.connection import Connection
6
18
  from taskiq import AsyncResultBackend
7
19
  from taskiq.abc.result_backend import TaskiqResult
20
+ from taskiq.abc.serializer import TaskiqSerializer
21
+ from taskiq.compat import model_dump, model_validate
22
+ from taskiq.serializers import PickleSerializer
8
23
 
9
24
  from taskiq_redis.exceptions import (
10
25
  DuplicateExpireTimeSelectedError,
@@ -12,6 +27,18 @@ from taskiq_redis.exceptions import (
12
27
  ResultIsMissingError,
13
28
  )
14
29
 
30
+ if sys.version_info >= (3, 10):
31
+ from typing import TypeAlias
32
+ else:
33
+ from typing_extensions import TypeAlias
34
+
35
+ if TYPE_CHECKING:
36
+ _Redis: TypeAlias = Redis[bytes]
37
+ _BlockingConnectionPool: TypeAlias = BlockingConnectionPool[Connection]
38
+ else:
39
+ _Redis: TypeAlias = Redis
40
+ _BlockingConnectionPool: TypeAlias = BlockingConnectionPool
41
+
15
42
  _ReturnType = TypeVar("_ReturnType")
16
43
 
17
44
 
@@ -25,6 +52,7 @@ class RedisAsyncResultBackend(AsyncResultBackend[_ReturnType]):
25
52
  result_ex_time: Optional[int] = None,
26
53
  result_px_time: Optional[int] = None,
27
54
  max_connection_pool_size: Optional[int] = None,
55
+ serializer: Optional[TaskiqSerializer] = None,
28
56
  **connection_kwargs: Any,
29
57
  ) -> None:
30
58
  """
@@ -42,11 +70,12 @@ class RedisAsyncResultBackend(AsyncResultBackend[_ReturnType]):
42
70
  :raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time
43
71
  and result_px_time are equal zero.
44
72
  """
45
- self.redis_pool = BlockingConnectionPool.from_url(
73
+ self.redis_pool: _BlockingConnectionPool = BlockingConnectionPool.from_url(
46
74
  url=redis_url,
47
75
  max_connections=max_connection_pool_size,
48
76
  **connection_kwargs,
49
77
  )
78
+ self.serializer = serializer or PickleSerializer()
50
79
  self.keep_results = keep_results
51
80
  self.result_ex_time = result_ex_time
52
81
  self.result_px_time = result_px_time
@@ -86,9 +115,9 @@ class RedisAsyncResultBackend(AsyncResultBackend[_ReturnType]):
86
115
  :param task_id: ID of the task.
87
116
  :param result: TaskiqResult instance.
88
117
  """
89
- redis_set_params: Dict[str, Union[str, bytes, int]] = {
118
+ redis_set_params: Dict[str, Union[str, int, bytes]] = {
90
119
  "name": task_id,
91
- "value": pickle.dumps(result),
120
+ "value": self.serializer.dumpb(model_dump(result)),
92
121
  }
93
122
  if self.result_ex_time:
94
123
  redis_set_params["ex"] = self.result_ex_time
@@ -135,8 +164,9 @@ class RedisAsyncResultBackend(AsyncResultBackend[_ReturnType]):
135
164
  if result_value is None:
136
165
  raise ResultIsMissingError
137
166
 
138
- taskiq_result: TaskiqResult[_ReturnType] = pickle.loads( # noqa: S301
139
- result_value,
167
+ taskiq_result = model_validate(
168
+ TaskiqResult[_ReturnType],
169
+ self.serializer.loadb(result_value),
140
170
  )
141
171
 
142
172
  if not with_logs:
@@ -154,6 +184,7 @@ class RedisAsyncClusterResultBackend(AsyncResultBackend[_ReturnType]):
154
184
  keep_results: bool = True,
155
185
  result_ex_time: Optional[int] = None,
156
186
  result_px_time: Optional[int] = None,
187
+ serializer: Optional[TaskiqSerializer] = None,
157
188
  **connection_kwargs: Any,
158
189
  ) -> None:
159
190
  """
@@ -174,6 +205,7 @@ class RedisAsyncClusterResultBackend(AsyncResultBackend[_ReturnType]):
174
205
  redis_url,
175
206
  **connection_kwargs,
176
207
  )
208
+ self.serializer = serializer or PickleSerializer()
177
209
  self.keep_results = keep_results
178
210
  self.result_ex_time = result_ex_time
179
211
  self.result_px_time = result_px_time
@@ -215,7 +247,7 @@ class RedisAsyncClusterResultBackend(AsyncResultBackend[_ReturnType]):
215
247
  """
216
248
  redis_set_params: Dict[str, Union[str, bytes, int]] = {
217
249
  "name": task_id,
218
- "value": pickle.dumps(result),
250
+ "value": self.serializer.dumpb(model_dump(result)),
219
251
  }
220
252
  if self.result_ex_time:
221
253
  redis_set_params["ex"] = self.result_ex_time
@@ -259,11 +291,155 @@ class RedisAsyncClusterResultBackend(AsyncResultBackend[_ReturnType]):
259
291
  if result_value is None:
260
292
  raise ResultIsMissingError
261
293
 
262
- taskiq_result: TaskiqResult[_ReturnType] = pickle.loads( # noqa: S301
263
- result_value,
294
+ taskiq_result: TaskiqResult[_ReturnType] = model_validate(
295
+ TaskiqResult[_ReturnType],
296
+ self.serializer.loadb(result_value),
297
+ )
298
+
299
+ if not with_logs:
300
+ taskiq_result.log = None
301
+
302
+ return taskiq_result
303
+
304
+
305
+ class RedisAsyncSentinelResultBackend(AsyncResultBackend[_ReturnType]):
306
+ """Async result based on redis sentinel."""
307
+
308
+ def __init__(
309
+ self,
310
+ sentinels: List[Tuple[str, int]],
311
+ master_name: str,
312
+ keep_results: bool = True,
313
+ result_ex_time: Optional[int] = None,
314
+ result_px_time: Optional[int] = None,
315
+ min_other_sentinels: int = 0,
316
+ sentinel_kwargs: Optional[Any] = None,
317
+ serializer: Optional[TaskiqSerializer] = None,
318
+ **connection_kwargs: Any,
319
+ ) -> None:
320
+ """
321
+ Constructs a new result backend.
322
+
323
+ :param sentinels: list of sentinel host and ports pairs.
324
+ :param master_name: sentinel master name.
325
+ :param keep_results: flag to not remove results from Redis after reading.
326
+ :param result_ex_time: expire time in seconds for result.
327
+ :param result_px_time: expire time in milliseconds for result.
328
+ :param max_connection_pool_size: maximum number of connections in pool.
329
+ :param connection_kwargs: additional arguments for redis BlockingConnectionPool.
330
+
331
+ :raises DuplicateExpireTimeSelectedError: if result_ex_time
332
+ and result_px_time are selected.
333
+ :raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time
334
+ and result_px_time are equal zero.
335
+ """
336
+ self.sentinel = Sentinel(
337
+ sentinels=sentinels,
338
+ min_other_sentinels=min_other_sentinels,
339
+ sentinel_kwargs=sentinel_kwargs,
340
+ **connection_kwargs,
341
+ )
342
+ self.master_name = master_name
343
+ self.serializer = serializer or PickleSerializer()
344
+ self.keep_results = keep_results
345
+ self.result_ex_time = result_ex_time
346
+ self.result_px_time = result_px_time
347
+
348
+ unavailable_conditions = any(
349
+ (
350
+ self.result_ex_time is not None and self.result_ex_time <= 0,
351
+ self.result_px_time is not None and self.result_px_time <= 0,
352
+ ),
353
+ )
354
+ if unavailable_conditions:
355
+ raise ExpireTimeMustBeMoreThanZeroError(
356
+ "You must select one expire time param and it must be more than zero.",
357
+ )
358
+
359
+ if self.result_ex_time and self.result_px_time:
360
+ raise DuplicateExpireTimeSelectedError(
361
+ "Choose either result_ex_time or result_px_time.",
362
+ )
363
+
364
+ @asynccontextmanager
365
+ async def _acquire_master_conn(self) -> AsyncIterator[_Redis]:
366
+ async with self.sentinel.master_for(self.master_name) as redis_conn:
367
+ yield redis_conn
368
+
369
+ async def set_result(
370
+ self,
371
+ task_id: str,
372
+ result: TaskiqResult[_ReturnType],
373
+ ) -> None:
374
+ """
375
+ Sets task result in redis.
376
+
377
+ Dumps TaskiqResult instance into the bytes and writes
378
+ it to redis.
379
+
380
+ :param task_id: ID of the task.
381
+ :param result: TaskiqResult instance.
382
+ """
383
+ redis_set_params: Dict[str, Union[str, bytes, int]] = {
384
+ "name": task_id,
385
+ "value": self.serializer.dumpb(model_dump(result)),
386
+ }
387
+ if self.result_ex_time:
388
+ redis_set_params["ex"] = self.result_ex_time
389
+ elif self.result_px_time:
390
+ redis_set_params["px"] = self.result_px_time
391
+
392
+ async with self._acquire_master_conn() as redis:
393
+ await redis.set(**redis_set_params) # type: ignore
394
+
395
+ async def is_result_ready(self, task_id: str) -> bool:
396
+ """
397
+ Returns whether the result is ready.
398
+
399
+ :param task_id: ID of the task.
400
+
401
+ :returns: True if the result is ready else False.
402
+ """
403
+ async with self._acquire_master_conn() as redis:
404
+ return bool(await redis.exists(task_id))
405
+
406
+ async def get_result(
407
+ self,
408
+ task_id: str,
409
+ with_logs: bool = False,
410
+ ) -> TaskiqResult[_ReturnType]:
411
+ """
412
+ Gets result from the task.
413
+
414
+ :param task_id: task's id.
415
+ :param with_logs: if True it will download task's logs.
416
+ :raises ResultIsMissingError: if there is no result when trying to get it.
417
+ :return: task's return value.
418
+ """
419
+ async with self._acquire_master_conn() as redis:
420
+ if self.keep_results:
421
+ result_value = await redis.get(
422
+ name=task_id,
423
+ )
424
+ else:
425
+ result_value = await redis.getdel(
426
+ name=task_id,
427
+ )
428
+
429
+ if result_value is None:
430
+ raise ResultIsMissingError
431
+
432
+ taskiq_result = model_validate(
433
+ TaskiqResult[_ReturnType],
434
+ self.serializer.loadb(result_value),
264
435
  )
265
436
 
266
437
  if not with_logs:
267
438
  taskiq_result.log = None
268
439
 
269
440
  return taskiq_result
441
+
442
+ async def shutdown(self) -> None:
443
+ """Shutdown sentinel connections."""
444
+ for sentinel in self.sentinel.sentinels:
445
+ await sentinel.aclose() # type: ignore[attr-defined]
@@ -1,7 +1,8 @@
1
+ import sys
1
2
  from logging import getLogger
2
- from typing import Any, AsyncGenerator, Callable, Optional, TypeVar
3
+ from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Optional, TypeVar
3
4
 
4
- from redis.asyncio import BlockingConnectionPool, ConnectionPool, Redis
5
+ from redis.asyncio import BlockingConnectionPool, Connection, Redis
5
6
  from taskiq.abc.broker import AsyncBroker
6
7
  from taskiq.abc.result_backend import AsyncResultBackend
7
8
  from taskiq.message import BrokerMessage
@@ -10,6 +11,16 @@ _T = TypeVar("_T")
10
11
 
11
12
  logger = getLogger("taskiq.redis_broker")
12
13
 
14
+ if sys.version_info >= (3, 10):
15
+ from typing import TypeAlias
16
+ else:
17
+ from typing_extensions import TypeAlias
18
+
19
+ if TYPE_CHECKING:
20
+ _BlockingConnectionPool: TypeAlias = BlockingConnectionPool[Connection]
21
+ else:
22
+ _BlockingConnectionPool: TypeAlias = BlockingConnectionPool
23
+
13
24
 
14
25
  class BaseRedisBroker(AsyncBroker):
15
26
  """Base broker that works with Redis."""
@@ -40,7 +51,7 @@ class BaseRedisBroker(AsyncBroker):
40
51
  task_id_generator=task_id_generator,
41
52
  )
42
53
 
43
- self.connection_pool: ConnectionPool = BlockingConnectionPool.from_url(
54
+ self.connection_pool: _BlockingConnectionPool = BlockingConnectionPool.from_url(
44
55
  url=url,
45
56
  max_connections=max_connection_pool_size,
46
57
  **connection_kwargs,
@@ -0,0 +1,132 @@
1
+ import sys
2
+ from contextlib import asynccontextmanager
3
+ from logging import getLogger
4
+ from typing import (
5
+ TYPE_CHECKING,
6
+ Any,
7
+ AsyncGenerator,
8
+ AsyncIterator,
9
+ Callable,
10
+ List,
11
+ Optional,
12
+ Tuple,
13
+ TypeVar,
14
+ )
15
+
16
+ from redis.asyncio import Redis, Sentinel
17
+ from taskiq import AsyncResultBackend, BrokerMessage
18
+ from taskiq.abc.broker import AsyncBroker
19
+
20
+ if sys.version_info >= (3, 10):
21
+ from typing import TypeAlias
22
+ else:
23
+ from typing_extensions import TypeAlias
24
+
25
+ if TYPE_CHECKING:
26
+ _Redis: TypeAlias = Redis[bytes]
27
+ else:
28
+ _Redis: TypeAlias = Redis
29
+
30
+ _T = TypeVar("_T")
31
+
32
+ logger = getLogger("taskiq.redis_sentinel_broker")
33
+
34
+
35
+ class BaseSentinelBroker(AsyncBroker):
36
+ """Base broker that works with Sentinel."""
37
+
38
+ def __init__(
39
+ self,
40
+ sentinels: List[Tuple[str, int]],
41
+ master_name: str,
42
+ result_backend: Optional[AsyncResultBackend[_T]] = None,
43
+ task_id_generator: Optional[Callable[[], str]] = None,
44
+ queue_name: str = "taskiq",
45
+ min_other_sentinels: int = 0,
46
+ sentinel_kwargs: Optional[Any] = None,
47
+ **connection_kwargs: Any,
48
+ ) -> None:
49
+ super().__init__(
50
+ result_backend=result_backend,
51
+ task_id_generator=task_id_generator,
52
+ )
53
+
54
+ self.sentinel = Sentinel(
55
+ sentinels=sentinels,
56
+ min_other_sentinels=min_other_sentinels,
57
+ sentinel_kwargs=sentinel_kwargs,
58
+ **connection_kwargs,
59
+ )
60
+ self.master_name = master_name
61
+ self.queue_name = queue_name
62
+
63
+ @asynccontextmanager
64
+ async def _acquire_master_conn(self) -> AsyncIterator[_Redis]:
65
+ async with self.sentinel.master_for(self.master_name) as redis_conn:
66
+ yield redis_conn
67
+
68
+
69
+ class PubSubSentinelBroker(BaseSentinelBroker):
70
+ """Broker that works with Sentinel and broadcasts tasks to all workers."""
71
+
72
+ async def kick(self, message: BrokerMessage) -> None:
73
+ """
74
+ Publish message over PUBSUB channel.
75
+
76
+ :param message: message to send.
77
+ """
78
+ queue_name = message.labels.get("queue_name") or self.queue_name
79
+ async with self._acquire_master_conn() as redis_conn:
80
+ await redis_conn.publish(queue_name, message.message)
81
+
82
+ async def listen(self) -> AsyncGenerator[bytes, None]:
83
+ """
84
+ Listen redis queue for new messages.
85
+
86
+ This function listens to the pubsub channel
87
+ and yields all messages with proper types.
88
+
89
+ :yields: broker messages.
90
+ """
91
+ async with self._acquire_master_conn() as redis_conn:
92
+ redis_pubsub_channel = redis_conn.pubsub()
93
+ await redis_pubsub_channel.subscribe(self.queue_name)
94
+ async for message in redis_pubsub_channel.listen():
95
+ if not message:
96
+ continue
97
+ if message["type"] != "message":
98
+ logger.debug("Received non-message from redis: %s", message)
99
+ continue
100
+ yield message["data"]
101
+
102
+
103
+ class ListQueueSentinelBroker(BaseSentinelBroker):
104
+ """Broker that works with Sentinel and distributes tasks between workers."""
105
+
106
+ async def kick(self, message: BrokerMessage) -> None:
107
+ """
108
+ Put a message in a list.
109
+
110
+ This method appends a message to the list of all messages.
111
+
112
+ :param message: message to append.
113
+ """
114
+ queue_name = message.labels.get("queue_name") or self.queue_name
115
+ async with self._acquire_master_conn() as redis_conn:
116
+ await redis_conn.lpush(queue_name, message.message)
117
+
118
+ async def listen(self) -> AsyncGenerator[bytes, None]:
119
+ """
120
+ Listen redis queue for new messages.
121
+
122
+ This function listens to the queue
123
+ and yields new messages if they have BrokerMessage type.
124
+
125
+ :yields: broker messages.
126
+ """
127
+ redis_brpop_data_position = 1
128
+ async with self._acquire_master_conn() as redis_conn:
129
+ while True:
130
+ yield (await redis_conn.brpop(self.queue_name))[
131
+ redis_brpop_data_position
132
+ ]
@@ -1,12 +1,31 @@
1
- from typing import Any, List, Optional
1
+ import sys
2
+ from contextlib import asynccontextmanager
3
+ from typing import TYPE_CHECKING, Any, AsyncIterator, List, Optional, Tuple
2
4
 
3
- from redis.asyncio import BlockingConnectionPool, ConnectionPool, Redis, RedisCluster
5
+ from redis.asyncio import (
6
+ BlockingConnectionPool,
7
+ Connection,
8
+ Redis,
9
+ RedisCluster,
10
+ Sentinel,
11
+ )
4
12
  from taskiq import ScheduleSource
5
13
  from taskiq.abc.serializer import TaskiqSerializer
6
14
  from taskiq.compat import model_dump, model_validate
7
15
  from taskiq.scheduler.scheduled_task import ScheduledTask
16
+ from taskiq.serializers import PickleSerializer
8
17
 
9
- from taskiq_redis.serializer import PickleSerializer
18
+ if sys.version_info >= (3, 10):
19
+ from typing import TypeAlias
20
+ else:
21
+ from typing_extensions import TypeAlias
22
+
23
+ if TYPE_CHECKING:
24
+ _Redis: TypeAlias = Redis[bytes]
25
+ _BlockingConnectionPool: TypeAlias = BlockingConnectionPool[Connection]
26
+ else:
27
+ _Redis: TypeAlias = Redis
28
+ _BlockingConnectionPool: TypeAlias = BlockingConnectionPool
10
29
 
11
30
 
12
31
  class RedisScheduleSource(ScheduleSource):
@@ -35,7 +54,7 @@ class RedisScheduleSource(ScheduleSource):
35
54
  **connection_kwargs: Any,
36
55
  ) -> None:
37
56
  self.prefix = prefix
38
- self.connection_pool: ConnectionPool = BlockingConnectionPool.from_url(
57
+ self.connection_pool: _BlockingConnectionPool = BlockingConnectionPool.from_url(
39
58
  url=url,
40
59
  max_connections=max_connection_pool_size,
41
60
  **connection_kwargs,
@@ -117,7 +136,6 @@ class RedisClusterScheduleSource(ScheduleSource):
117
136
  self,
118
137
  url: str,
119
138
  prefix: str = "schedule",
120
- buffer_size: int = 50,
121
139
  serializer: Optional[TaskiqSerializer] = None,
122
140
  **connection_kwargs: Any,
123
141
  ) -> None:
@@ -126,7 +144,6 @@ class RedisClusterScheduleSource(ScheduleSource):
126
144
  url,
127
145
  **connection_kwargs,
128
146
  )
129
- self.buffer_size = buffer_size
130
147
  if serializer is None:
131
148
  serializer = PickleSerializer()
132
149
  self.serializer = serializer
@@ -156,14 +173,107 @@ class RedisClusterScheduleSource(ScheduleSource):
156
173
  :return: list of schedules.
157
174
  """
158
175
  schedules = []
159
- buffer = []
160
176
  async for key in self.redis.scan_iter(f"{self.prefix}:*"): # type: ignore[attr-defined]
161
- buffer.append(key)
162
- if len(buffer) >= self.buffer_size:
163
- schedules.extend(await self.redis.mget(buffer)) # type: ignore[attr-defined]
164
- buffer = []
165
- if buffer:
166
- schedules.extend(await self.redis.mget(buffer)) # type: ignore[attr-defined]
177
+ raw_schedule = await self.redis.get(key) # type: ignore[attr-defined]
178
+ parsed_schedule = model_validate(
179
+ ScheduledTask,
180
+ self.serializer.loadb(raw_schedule),
181
+ )
182
+ schedules.append(parsed_schedule)
183
+ return schedules
184
+
185
+ async def post_send(self, task: ScheduledTask) -> None:
186
+ """Delete a task after it's completed."""
187
+ if task.time is not None:
188
+ await self.delete_schedule(task.schedule_id)
189
+
190
+ async def shutdown(self) -> None:
191
+ """Shut down the schedule source."""
192
+ await self.redis.aclose() # type: ignore[attr-defined]
193
+
194
+
195
+ class RedisSentinelScheduleSource(ScheduleSource):
196
+ """
197
+ Source of schedules for redis cluster.
198
+
199
+ This class allows you to store schedules in redis.
200
+ Also it supports dynamic schedules.
201
+
202
+ :param sentinels: list of sentinel host and ports pairs.
203
+ :param master_name: sentinel master name.
204
+ :param prefix: prefix for redis schedule keys.
205
+ :param buffer_size: buffer size for redis scan.
206
+ This is how many keys will be fetched at once.
207
+ :param max_connection_pool_size: maximum number of connections in pool.
208
+ :param serializer: serializer for data.
209
+ :param connection_kwargs: additional arguments for RedisCluster.
210
+ """
211
+
212
+ def __init__(
213
+ self,
214
+ sentinels: List[Tuple[str, int]],
215
+ master_name: str,
216
+ prefix: str = "schedule",
217
+ buffer_size: int = 50,
218
+ serializer: Optional[TaskiqSerializer] = None,
219
+ min_other_sentinels: int = 0,
220
+ sentinel_kwargs: Optional[Any] = None,
221
+ **connection_kwargs: Any,
222
+ ) -> None:
223
+ self.prefix = prefix
224
+ self.sentinel = Sentinel(
225
+ sentinels=sentinels,
226
+ min_other_sentinels=min_other_sentinels,
227
+ sentinel_kwargs=sentinel_kwargs,
228
+ **connection_kwargs,
229
+ )
230
+ self.master_name = master_name
231
+ self.buffer_size = buffer_size
232
+ if serializer is None:
233
+ serializer = PickleSerializer()
234
+ self.serializer = serializer
235
+
236
+ @asynccontextmanager
237
+ async def _acquire_master_conn(self) -> AsyncIterator[_Redis]:
238
+ async with self.sentinel.master_for(self.master_name) as redis_conn:
239
+ yield redis_conn
240
+
241
+ async def delete_schedule(self, schedule_id: str) -> None:
242
+ """Remove schedule by id."""
243
+ async with self._acquire_master_conn() as redis:
244
+ await redis.delete(f"{self.prefix}:{schedule_id}")
245
+
246
+ async def add_schedule(self, schedule: ScheduledTask) -> None:
247
+ """
248
+ Add schedule to redis.
249
+
250
+ :param schedule: schedule to add.
251
+ :param schedule_id: schedule id.
252
+ """
253
+ async with self._acquire_master_conn() as redis:
254
+ await redis.set(
255
+ f"{self.prefix}:{schedule.schedule_id}",
256
+ self.serializer.dumpb(model_dump(schedule)),
257
+ )
258
+
259
+ async def get_schedules(self) -> List[ScheduledTask]:
260
+ """
261
+ Get all schedules from redis.
262
+
263
+ This method is used by scheduler to get all schedules.
264
+
265
+ :return: list of schedules.
266
+ """
267
+ schedules = []
268
+ async with self._acquire_master_conn() as redis:
269
+ buffer = []
270
+ async for key in redis.scan_iter(f"{self.prefix}:*"):
271
+ buffer.append(key)
272
+ if len(buffer) >= self.buffer_size:
273
+ schedules.extend(await redis.mget(buffer))
274
+ buffer = []
275
+ if buffer:
276
+ schedules.extend(await redis.mget(buffer))
167
277
  return [
168
278
  model_validate(ScheduledTask, self.serializer.loadb(schedule))
169
279
  for schedule in schedules
@@ -174,3 +284,8 @@ class RedisClusterScheduleSource(ScheduleSource):
174
284
  """Delete a task after it's completed."""
175
285
  if task.time is not None:
176
286
  await self.delete_schedule(task.schedule_id)
287
+
288
+ async def shutdown(self) -> None:
289
+ """Shut down the schedule source."""
290
+ for sentinel in self.sentinel.sentinels:
291
+ await sentinel.aclose() # type: ignore[attr-defined]
@@ -1,16 +0,0 @@
1
- import pickle
2
- from typing import Any
3
-
4
- from taskiq.abc.serializer import TaskiqSerializer
5
-
6
-
7
- class PickleSerializer(TaskiqSerializer):
8
- """Serializer that uses pickle."""
9
-
10
- def dumpb(self, value: Any) -> bytes:
11
- """Dumps value to bytes."""
12
- return pickle.dumps(value)
13
-
14
- def loadb(self, value: bytes) -> Any:
15
- """Loads value from bytes."""
16
- return pickle.loads(value) # noqa: S301
File without changes