taskiq-redis 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- taskiq_redis/__init__.py +17 -7
- taskiq_redis/exceptions.py +6 -0
- taskiq_redis/list_schedule_source.py +229 -0
- taskiq_redis/redis_backend.py +57 -45
- taskiq_redis/redis_broker.py +142 -5
- taskiq_redis/redis_cluster_broker.py +127 -6
- taskiq_redis/redis_sentinel_broker.py +131 -4
- taskiq_redis/schedule_source.py +9 -9
- taskiq_redis-1.0.4.dist-info/METADATA +215 -0
- taskiq_redis-1.0.4.dist-info/RECORD +13 -0
- {taskiq_redis-1.0.2.dist-info → taskiq_redis-1.0.4.dist-info}/WHEEL +1 -1
- taskiq_redis-1.0.2.dist-info/METADATA +0 -125
- taskiq_redis-1.0.2.dist-info/RECORD +0 -12
- {taskiq_redis-1.0.2.dist-info → taskiq_redis-1.0.4.dist-info}/LICENSE +0 -0
taskiq_redis/__init__.py
CHANGED
|
@@ -1,14 +1,20 @@
|
|
|
1
1
|
"""Package for redis integration."""
|
|
2
|
+
|
|
3
|
+
from taskiq_redis.list_schedule_source import ListRedisScheduleSource
|
|
2
4
|
from taskiq_redis.redis_backend import (
|
|
3
5
|
RedisAsyncClusterResultBackend,
|
|
4
6
|
RedisAsyncResultBackend,
|
|
5
7
|
RedisAsyncSentinelResultBackend,
|
|
6
8
|
)
|
|
7
|
-
from taskiq_redis.redis_broker import ListQueueBroker, PubSubBroker
|
|
8
|
-
from taskiq_redis.redis_cluster_broker import
|
|
9
|
+
from taskiq_redis.redis_broker import ListQueueBroker, PubSubBroker, RedisStreamBroker
|
|
10
|
+
from taskiq_redis.redis_cluster_broker import (
|
|
11
|
+
ListQueueClusterBroker,
|
|
12
|
+
RedisStreamClusterBroker,
|
|
13
|
+
)
|
|
9
14
|
from taskiq_redis.redis_sentinel_broker import (
|
|
10
15
|
ListQueueSentinelBroker,
|
|
11
16
|
PubSubSentinelBroker,
|
|
17
|
+
RedisStreamSentinelBroker,
|
|
12
18
|
)
|
|
13
19
|
from taskiq_redis.schedule_source import (
|
|
14
20
|
RedisClusterScheduleSource,
|
|
@@ -17,15 +23,19 @@ from taskiq_redis.schedule_source import (
|
|
|
17
23
|
)
|
|
18
24
|
|
|
19
25
|
# Names re-exported as the public API of ``taskiq_redis`` (alphabetical order).
__all__ = [
    "ListQueueBroker",
    "ListQueueClusterBroker",
    "ListQueueSentinelBroker",
    "ListRedisScheduleSource",
    "PubSubBroker",
    "PubSubSentinelBroker",
    "RedisAsyncClusterResultBackend",
    "RedisAsyncResultBackend",
    "RedisAsyncSentinelResultBackend",
    "RedisClusterScheduleSource",
    "RedisScheduleSource",
    "RedisSentinelScheduleSource",
    "RedisStreamBroker",
    "RedisStreamClusterBroker",
    "RedisStreamSentinelBroker",
]
|
taskiq_redis/exceptions.py
CHANGED
|
@@ -8,10 +8,16 @@ class TaskIQRedisError(TaskiqError):
|
|
|
8
8
|
class DuplicateExpireTimeSelectedError(ResultBackendError, TaskIQRedisError):
    """Raised when both ``result_ex_time`` and ``result_px_time`` are set."""

    # Message template rendered by the taskiq error machinery.
    __template__ = (
        "Choose either result_ex_time "
        "or result_px_time."
    )
|
|
12
|
+
|
|
11
13
|
|
|
12
14
|
class ExpireTimeMustBeMoreThanZeroError(ResultBackendError, TaskIQRedisError):
    """
    Raised when the result backend's expire-time settings are rejected.

    The message asks for one expire-time parameter with a value above zero.
    """

    # Message template rendered by the taskiq error machinery.
    __template__ = (
        "You must select one expire time param "
        "and it must be more than zero."
    )
|
|
20
|
+
|
|
15
21
|
|
|
16
22
|
class ResultIsMissingError(TaskIQRedisError, ResultGetError):
|
|
17
23
|
"""Error if there is no result when trying to get it."""
|
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
import datetime
|
|
2
|
+
from logging import getLogger
|
|
3
|
+
from typing import Any, List, Optional
|
|
4
|
+
|
|
5
|
+
from redis.asyncio import BlockingConnectionPool, Redis
|
|
6
|
+
from taskiq import ScheduledTask, ScheduleSource
|
|
7
|
+
from taskiq.abc.serializer import TaskiqSerializer
|
|
8
|
+
from taskiq.compat import model_dump, model_validate
|
|
9
|
+
from taskiq.serializers import PickleSerializer
|
|
10
|
+
from typing_extensions import Self
|
|
11
|
+
|
|
12
|
+
logger = getLogger("taskiq.redis_schedule_source")


class ListRedisScheduleSource(ScheduleSource):
    """
    Schedule source based on Redis lists.

    Keys written by this source (``prefix`` defaults to ``"schedule"``):

    * ``{prefix}:data:{schedule_id}`` - the serialized ``ScheduledTask``;
    * ``{prefix}:cron`` - list of ids of all cron-based schedules;
    * ``{prefix}:time:{YYYY-MM-DDTHH:MM}`` - list of ids of the time-based
      schedules due in that UTC minute.

    Bucketing time-based schedules by minute lets ``get_schedules`` read only
    the cron list and the current minute's list instead of every schedule.
    """

    def __init__(
        self,
        url: str,
        prefix: str = "schedule",
        max_connection_pool_size: Optional[int] = None,
        serializer: Optional[TaskiqSerializer] = None,
        bufffer_size: int = 50,
        skip_past_schedules: bool = False,
        **connection_kwargs: Any,
    ) -> None:
        """
        Create the schedule source.

        :param url: Redis connection URL.
        :param prefix: prefix of every key this source reads and writes.
        :param max_connection_pool_size: passed as ``max_connections`` to
            ``BlockingConnectionPool.from_url``.
        :param serializer: serializer for stored schedules;
            ``PickleSerializer`` when None.
        :param bufffer_size: maximum number of data keys fetched per MGET in
            ``get_schedules``. (The misspelled name is part of the public
            signature and is kept for compatibility.)
        :param skip_past_schedules: when True, time-based schedules whose
            minute already passed are not collected on the first run.
        :param connection_kwargs: extra keyword arguments for the pool.
        :raises ValueError: if ``bufffer_size`` is less than 1.
        """
        super().__init__()
        if bufffer_size < 1:
            # A non-positive batch size would make get_schedules issue an
            # MGET with no keys; fail fast before creating the pool.
            raise ValueError("bufffer_size must be a positive integer.")
        self._prefix = prefix
        self._buffer_size = bufffer_size
        self._connection_pool = BlockingConnectionPool.from_url(
            url=url,
            max_connections=max_connection_pool_size,
            **connection_kwargs,
        )
        if serializer is None:
            serializer = PickleSerializer()
        self._serializer = serializer
        # Past-due time schedules are collected only on the first
        # get_schedules() call (see _get_previous_time_schedules).
        self._is_first_run = True
        self._previous_schedule_source: Optional[ScheduleSource] = None
        self._delete_schedules_after_migration: bool = True
        self._skip_past_schedules = skip_past_schedules

    async def startup(self) -> None:
        """
        Startup the schedule source.

        By default this function does nothing.
        But if the previous schedule source is set,
        it will try to migrate schedules from it.
        The previous source is shut down even if the migration fails.
        """
        previous = self._previous_schedule_source
        if previous is None:
            return
        logger.info("Migrating schedules from previous source")
        await previous.startup()
        try:
            schedules = await previous.get_schedules()
            logger.info("Found %d schedules to migrate", len(schedules))
            for schedule in schedules:
                await self.add_schedule(schedule)
                if self._delete_schedules_after_migration:
                    await previous.delete_schedule(schedule.schedule_id)
        finally:
            # Release the previous source's resources on success and failure.
            await previous.shutdown()
        logger.info("Migration complete")

    async def shutdown(self) -> None:
        """Disconnect the connection pool, then run the base shutdown."""
        await self._connection_pool.disconnect()
        await super().shutdown()

    def _get_time_key(self, time: datetime.datetime) -> str:
        """
        Get the key for a time-based schedule.

        Naive datetimes are treated as UTC; the key is truncated to minutes.
        """
        if time.tzinfo is None:
            time = time.replace(tzinfo=datetime.timezone.utc)
        iso_time = time.astimezone(datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M")
        return f"{self._prefix}:time:{iso_time}"

    def _get_cron_key(self) -> str:
        """Get the key for a cron-based schedule."""
        return f"{self._prefix}:cron"

    def _get_data_key(self, schedule_id: str) -> str:
        """Get the key for a schedule data."""
        return f"{self._prefix}:data:{schedule_id}"

    def _parse_time_key(self, key: str) -> Optional[datetime.datetime]:
        """
        Get time value from the timed-key.

        :param key: key in the ``{prefix}:time:{YYYY-MM-DDTHH:MM}`` format.
        :return: UTC-aware datetime, or None if the key cannot be parsed.
        """
        # Strip the known prefix rather than splitting on ":" - both the
        # user-supplied prefix and the timestamp itself may contain colons.
        time_prefix = f"{self._prefix}:time:"
        if not key.startswith(time_prefix):
            logger.debug("Failed to parse time key %s", key)
            return None
        try:
            parsed = datetime.datetime.strptime(
                key[len(time_prefix) :],
                "%Y-%m-%dT%H:%M",
            )
        except ValueError:
            logger.debug("Failed to parse time key %s", key)
            return None
        return parsed.replace(tzinfo=datetime.timezone.utc)

    async def _get_previous_time_schedules(
        self,
        current_time: Optional[datetime.datetime] = None,
    ) -> list[bytes]:
        """
        Function that gets all timed schedules that are in the past.

        Since this source doesn't retrieve all the schedules at once,
        we need to get all the schedules that are in the past and haven't
        been sent yet.

        We do this by getting all the time keys and checking if the time
        is at least one minute before the current one.

        This function is called only during the first run to minimize
        the number of requests to the Redis server.

        :param current_time: timezone-aware reference time; defaults to now
            (UTC). Callers that also read the current minute's list should
            pass the same value so the two never overlap at a minute
            boundary.
        :return: schedule ids stored in the past minutes' lists.
        """
        logger.info("Getting previous time schedules")
        if current_time is None:
            current_time = datetime.datetime.now(datetime.timezone.utc)
        minute_before = current_time.replace(
            second=0,
            microsecond=0,
        ) - datetime.timedelta(minutes=1)
        schedules: list[bytes] = []
        async with Redis(connection_pool=self._connection_pool) as redis:
            time_keys: list[str] = []
            # Collect every time key whose minute is strictly before the
            # current one; the current minute is read by get_schedules.
            async for raw_key in redis.scan_iter(f"{self._prefix}:time:*"):
                key = raw_key.decode()
                key_time = self._parse_time_key(key)
                if key_time is not None and key_time <= minute_before:
                    time_keys.append(key)
            for key in time_keys:
                schedules.extend(await redis.lrange(key, 0, -1))  # type: ignore

        return schedules

    async def delete_schedule(self, schedule_id: str) -> None:
        """Delete a schedule's data and its id from the cron or time list."""
        async with Redis(connection_pool=self._connection_pool) as redis:
            raw_schedule = await redis.getdel(self._get_data_key(schedule_id))
            if raw_schedule is None:
                return
            logger.debug("Deleting schedule %s", schedule_id)
            schedule = model_validate(
                ScheduledTask,
                self._serializer.loadb(raw_schedule),
            )
            # We need to remove the schedule from the cron or time list.
            if schedule.cron is not None:
                await redis.lrem(self._get_cron_key(), 0, schedule_id)  # type: ignore
            elif schedule.time is not None:
                time_key = self._get_time_key(schedule.time)
                await redis.lrem(time_key, 0, schedule_id)  # type: ignore

    async def add_schedule(self, schedule: "ScheduledTask") -> None:
        """Add a schedule to the source."""
        async with Redis(connection_pool=self._connection_pool) as redis:
            # At first we set data key which contains the schedule data.
            await redis.set(
                self._get_data_key(schedule.schedule_id),
                self._serializer.dumpb(model_dump(schedule)),
            )
            # Then we add the schedule to the cron or time list.
            # This is an optimization, so we can get all the schedules
            # for the current time much faster.
            if schedule.cron is not None:
                await redis.rpush(self._get_cron_key(), schedule.schedule_id)  # type: ignore
            elif schedule.time is not None:
                await redis.rpush(  # type: ignore
                    self._get_time_key(schedule.time),
                    schedule.schedule_id,
                )

    async def post_send(self, task: ScheduledTask) -> None:
        """Delete a time-based task after it has been sent."""
        if task.time is not None:
            await self.delete_schedule(task.schedule_id)

    async def get_schedules(self) -> List["ScheduledTask"]:
        """
        Get all schedules.

        This function gets all the schedules from the schedule source.
        What it does is get all the cron schedules and time schedules
        for the current time and return them.

        If it's the first run, it also gets all the time schedules
        that are in the past and haven't been sent yet.
        """
        # One reference time for both the past-minutes scan and the current
        # minute's list, so a minute rollover can't return a list twice.
        current_time = datetime.datetime.now(datetime.timezone.utc)
        timed: list[bytes] = []
        # Only during first run, we need to get previous time schedules
        if self._is_first_run and not self._skip_past_schedules:
            timed = await self._get_previous_time_schedules(current_time)
            self._is_first_run = False
        raw_schedules: list[Optional[bytes]] = []
        async with Redis(connection_pool=self._connection_pool) as redis:
            crons = await redis.lrange(self._get_cron_key(), 0, -1)  # type: ignore
            logger.debug("Got cron scheduleds: %s", crons)
            current_key = self._get_time_key(current_time)
            timed.extend(await redis.lrange(current_key, 0, -1))  # type: ignore
            logger.debug("Got timed scheduleds: %s", timed)
            schedule_ids = [*crons, *timed]
            # Fetch payloads with at most self._buffer_size keys per MGET.
            for start in range(0, len(schedule_ids), self._buffer_size):
                batch = schedule_ids[start : start + self._buffer_size]
                raw_schedules.extend(
                    await redis.mget(
                        [self._get_data_key(x.decode()) for x in batch],
                    ),
                )

        # MGET yields None for ids whose data key no longer exists.
        return [
            model_validate(ScheduledTask, self._serializer.loadb(schedule))
            for schedule in raw_schedules
            if schedule
        ]

    def with_migrate_from(
        self,
        source: ScheduleSource,
        delete_schedules: bool = True,
    ) -> Self:
        """
        Enable migration from previous schedule source.

        If this function is called during declaration,
        the source will try to migrate schedules from the previous source.

        :param source: previous schedule source
        :param delete_schedules: delete schedules during migration process
            from the previous source.
        :return: this source, to allow chaining.
        """
        self._previous_schedule_source = source
        self._delete_schedules_after_migration = delete_schedules
        return self
|
taskiq_redis/redis_backend.py
CHANGED
|
@@ -16,10 +16,10 @@ from redis.asyncio import BlockingConnectionPool, Redis, Sentinel
|
|
|
16
16
|
from redis.asyncio.cluster import RedisCluster
|
|
17
17
|
from redis.asyncio.connection import Connection
|
|
18
18
|
from taskiq import AsyncResultBackend
|
|
19
|
-
from taskiq.abc.result_backend import TaskiqResult
|
|
20
19
|
from taskiq.abc.serializer import TaskiqSerializer
|
|
21
20
|
from taskiq.compat import model_dump, model_validate
|
|
22
21
|
from taskiq.depends.progress_tracker import TaskProgress
|
|
22
|
+
from taskiq.result import TaskiqResult
|
|
23
23
|
from taskiq.serializers import PickleSerializer
|
|
24
24
|
|
|
25
25
|
from taskiq_redis.exceptions import (
|
|
@@ -34,8 +34,8 @@ else:
|
|
|
34
34
|
from typing_extensions import TypeAlias
|
|
35
35
|
|
|
36
36
|
if TYPE_CHECKING:
|
|
37
|
-
_Redis: TypeAlias = Redis[bytes]
|
|
38
|
-
_BlockingConnectionPool: TypeAlias = BlockingConnectionPool[Connection]
|
|
37
|
+
_Redis: TypeAlias = Redis[bytes] # type: ignore
|
|
38
|
+
_BlockingConnectionPool: TypeAlias = BlockingConnectionPool[Connection] # type: ignore
|
|
39
39
|
else:
|
|
40
40
|
_Redis: TypeAlias = Redis
|
|
41
41
|
_BlockingConnectionPool: TypeAlias = BlockingConnectionPool
|
|
@@ -56,6 +56,7 @@ class RedisAsyncResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
56
56
|
result_px_time: Optional[int] = None,
|
|
57
57
|
max_connection_pool_size: Optional[int] = None,
|
|
58
58
|
serializer: Optional[TaskiqSerializer] = None,
|
|
59
|
+
prefix_str: Optional[str] = None,
|
|
59
60
|
**connection_kwargs: Any,
|
|
60
61
|
) -> None:
|
|
61
62
|
"""
|
|
@@ -82,6 +83,7 @@ class RedisAsyncResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
82
83
|
self.keep_results = keep_results
|
|
83
84
|
self.result_ex_time = result_ex_time
|
|
84
85
|
self.result_px_time = result_px_time
|
|
86
|
+
self.prefix_str = prefix_str
|
|
85
87
|
|
|
86
88
|
unavailable_conditions = any(
|
|
87
89
|
(
|
|
@@ -90,14 +92,15 @@ class RedisAsyncResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
90
92
|
),
|
|
91
93
|
)
|
|
92
94
|
if unavailable_conditions:
|
|
93
|
-
raise ExpireTimeMustBeMoreThanZeroError
|
|
94
|
-
"You must select one expire time param and it must be more than zero.",
|
|
95
|
-
)
|
|
95
|
+
raise ExpireTimeMustBeMoreThanZeroError
|
|
96
96
|
|
|
97
97
|
if self.result_ex_time and self.result_px_time:
|
|
98
|
-
raise DuplicateExpireTimeSelectedError
|
|
99
|
-
|
|
100
|
-
|
|
98
|
+
raise DuplicateExpireTimeSelectedError
|
|
99
|
+
|
|
100
|
+
def _task_name(self, task_id: str) -> str:
|
|
101
|
+
if self.prefix_str is None:
|
|
102
|
+
return task_id
|
|
103
|
+
return f"{self.prefix_str}:{task_id}"
|
|
101
104
|
|
|
102
105
|
async def shutdown(self) -> None:
|
|
103
106
|
"""Closes redis connection."""
|
|
@@ -119,7 +122,7 @@ class RedisAsyncResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
119
122
|
:param result: TaskiqResult instance.
|
|
120
123
|
"""
|
|
121
124
|
redis_set_params: Dict[str, Union[str, int, bytes]] = {
|
|
122
|
-
"name": task_id,
|
|
125
|
+
"name": self._task_name(task_id),
|
|
123
126
|
"value": self.serializer.dumpb(model_dump(result)),
|
|
124
127
|
}
|
|
125
128
|
if self.result_ex_time:
|
|
@@ -139,7 +142,7 @@ class RedisAsyncResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
139
142
|
:returns: True if the result is ready else False.
|
|
140
143
|
"""
|
|
141
144
|
async with Redis(connection_pool=self.redis_pool) as redis:
|
|
142
|
-
return bool(await redis.exists(task_id))
|
|
145
|
+
return bool(await redis.exists(self._task_name(task_id)))
|
|
143
146
|
|
|
144
147
|
async def get_result(
|
|
145
148
|
self,
|
|
@@ -154,14 +157,15 @@ class RedisAsyncResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
154
157
|
:raises ResultIsMissingError: if there is no result when trying to get it.
|
|
155
158
|
:return: task's return value.
|
|
156
159
|
"""
|
|
160
|
+
task_name = self._task_name(task_id)
|
|
157
161
|
async with Redis(connection_pool=self.redis_pool) as redis:
|
|
158
162
|
if self.keep_results:
|
|
159
163
|
result_value = await redis.get(
|
|
160
|
-
name=
|
|
164
|
+
name=task_name,
|
|
161
165
|
)
|
|
162
166
|
else:
|
|
163
167
|
result_value = await redis.getdel(
|
|
164
|
-
name=
|
|
168
|
+
name=task_name,
|
|
165
169
|
)
|
|
166
170
|
|
|
167
171
|
if result_value is None:
|
|
@@ -192,7 +196,7 @@ class RedisAsyncResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
192
196
|
:param result: task's TaskProgress instance.
|
|
193
197
|
"""
|
|
194
198
|
redis_set_params: Dict[str, Union[str, int, bytes]] = {
|
|
195
|
-
"name": task_id + PROGRESS_KEY_SUFFIX,
|
|
199
|
+
"name": self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
|
|
196
200
|
"value": self.serializer.dumpb(model_dump(progress)),
|
|
197
201
|
}
|
|
198
202
|
if self.result_ex_time:
|
|
@@ -215,7 +219,7 @@ class RedisAsyncResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
215
219
|
"""
|
|
216
220
|
async with Redis(connection_pool=self.redis_pool) as redis:
|
|
217
221
|
result_value = await redis.get(
|
|
218
|
-
name=task_id + PROGRESS_KEY_SUFFIX,
|
|
222
|
+
name=self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
|
|
219
223
|
)
|
|
220
224
|
|
|
221
225
|
if result_value is None:
|
|
@@ -237,6 +241,7 @@ class RedisAsyncClusterResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
237
241
|
result_ex_time: Optional[int] = None,
|
|
238
242
|
result_px_time: Optional[int] = None,
|
|
239
243
|
serializer: Optional[TaskiqSerializer] = None,
|
|
244
|
+
prefix_str: Optional[str] = None,
|
|
240
245
|
**connection_kwargs: Any,
|
|
241
246
|
) -> None:
|
|
242
247
|
"""
|
|
@@ -253,7 +258,7 @@ class RedisAsyncClusterResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
253
258
|
:raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time
|
|
254
259
|
and result_px_time are equal zero.
|
|
255
260
|
"""
|
|
256
|
-
self.redis: RedisCluster
|
|
261
|
+
self.redis: "RedisCluster" = RedisCluster.from_url(
|
|
257
262
|
redis_url,
|
|
258
263
|
**connection_kwargs,
|
|
259
264
|
)
|
|
@@ -261,6 +266,7 @@ class RedisAsyncClusterResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
261
266
|
self.keep_results = keep_results
|
|
262
267
|
self.result_ex_time = result_ex_time
|
|
263
268
|
self.result_px_time = result_px_time
|
|
269
|
+
self.prefix_str = prefix_str
|
|
264
270
|
|
|
265
271
|
unavailable_conditions = any(
|
|
266
272
|
(
|
|
@@ -269,18 +275,19 @@ class RedisAsyncClusterResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
269
275
|
),
|
|
270
276
|
)
|
|
271
277
|
if unavailable_conditions:
|
|
272
|
-
raise ExpireTimeMustBeMoreThanZeroError
|
|
273
|
-
"You must select one expire time param and it must be more than zero.",
|
|
274
|
-
)
|
|
278
|
+
raise ExpireTimeMustBeMoreThanZeroError
|
|
275
279
|
|
|
276
280
|
if self.result_ex_time and self.result_px_time:
|
|
277
|
-
raise DuplicateExpireTimeSelectedError
|
|
278
|
-
|
|
279
|
-
|
|
281
|
+
raise DuplicateExpireTimeSelectedError
|
|
282
|
+
|
|
283
|
+
def _task_name(self, task_id: str) -> str:
|
|
284
|
+
if self.prefix_str is None:
|
|
285
|
+
return task_id
|
|
286
|
+
return f"{self.prefix_str}:{task_id}"
|
|
280
287
|
|
|
281
288
|
async def shutdown(self) -> None:
    """
    Close the cluster client, then run the base backend shutdown.

    The ``RedisCluster`` client is closed first; ``super().shutdown()``
    is awaited afterwards.
    """
    cluster_client = self.redis
    await cluster_client.aclose()
    await super().shutdown()
|
|
285
292
|
|
|
286
293
|
async def set_result(
|
|
@@ -298,7 +305,7 @@ class RedisAsyncClusterResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
298
305
|
:param result: TaskiqResult instance.
|
|
299
306
|
"""
|
|
300
307
|
redis_set_params: Dict[str, Union[str, bytes, int]] = {
|
|
301
|
-
"name": task_id,
|
|
308
|
+
"name": self._task_name(task_id),
|
|
302
309
|
"value": self.serializer.dumpb(model_dump(result)),
|
|
303
310
|
}
|
|
304
311
|
if self.result_ex_time:
|
|
@@ -316,7 +323,7 @@ class RedisAsyncClusterResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
316
323
|
|
|
317
324
|
:returns: True if the result is ready else False.
|
|
318
325
|
"""
|
|
319
|
-
return bool(await self.redis.exists(task_id))
|
|
326
|
+
return bool(await self.redis.exists(self._task_name(task_id)))
|
|
320
327
|
|
|
321
328
|
async def get_result(
|
|
322
329
|
self,
|
|
@@ -331,13 +338,14 @@ class RedisAsyncClusterResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
331
338
|
:raises ResultIsMissingError: if there is no result when trying to get it.
|
|
332
339
|
:return: task's return value.
|
|
333
340
|
"""
|
|
341
|
+
task_name = self._task_name(task_id)
|
|
334
342
|
if self.keep_results:
|
|
335
|
-
result_value = await self.redis.get(
|
|
336
|
-
name=
|
|
343
|
+
result_value = await self.redis.get(
|
|
344
|
+
name=task_name,
|
|
337
345
|
)
|
|
338
346
|
else:
|
|
339
|
-
result_value = await self.redis.getdel(
|
|
340
|
-
name=
|
|
347
|
+
result_value = await self.redis.getdel(
|
|
348
|
+
name=task_name,
|
|
341
349
|
)
|
|
342
350
|
|
|
343
351
|
if result_value is None:
|
|
@@ -368,7 +376,7 @@ class RedisAsyncClusterResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
368
376
|
:param result: task's TaskProgress instance.
|
|
369
377
|
"""
|
|
370
378
|
redis_set_params: Dict[str, Union[str, int, bytes]] = {
|
|
371
|
-
"name": task_id + PROGRESS_KEY_SUFFIX,
|
|
379
|
+
"name": self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
|
|
372
380
|
"value": self.serializer.dumpb(model_dump(progress)),
|
|
373
381
|
}
|
|
374
382
|
if self.result_ex_time:
|
|
@@ -388,8 +396,8 @@ class RedisAsyncClusterResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
388
396
|
:param task_id: task's id.
|
|
389
397
|
:return: task's TaskProgress instance.
|
|
390
398
|
"""
|
|
391
|
-
result_value = await self.redis.get(
|
|
392
|
-
name=task_id + PROGRESS_KEY_SUFFIX,
|
|
399
|
+
result_value = await self.redis.get(
|
|
400
|
+
name=self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
|
|
393
401
|
)
|
|
394
402
|
|
|
395
403
|
if result_value is None:
|
|
@@ -414,6 +422,7 @@ class RedisAsyncSentinelResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
414
422
|
min_other_sentinels: int = 0,
|
|
415
423
|
sentinel_kwargs: Optional[Any] = None,
|
|
416
424
|
serializer: Optional[TaskiqSerializer] = None,
|
|
425
|
+
prefix_str: Optional[str] = None,
|
|
417
426
|
**connection_kwargs: Any,
|
|
418
427
|
) -> None:
|
|
419
428
|
"""
|
|
@@ -443,6 +452,7 @@ class RedisAsyncSentinelResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
443
452
|
self.keep_results = keep_results
|
|
444
453
|
self.result_ex_time = result_ex_time
|
|
445
454
|
self.result_px_time = result_px_time
|
|
455
|
+
self.prefix_str = prefix_str
|
|
446
456
|
|
|
447
457
|
unavailable_conditions = any(
|
|
448
458
|
(
|
|
@@ -451,14 +461,15 @@ class RedisAsyncSentinelResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
451
461
|
),
|
|
452
462
|
)
|
|
453
463
|
if unavailable_conditions:
|
|
454
|
-
raise ExpireTimeMustBeMoreThanZeroError
|
|
455
|
-
"You must select one expire time param and it must be more than zero.",
|
|
456
|
-
)
|
|
464
|
+
raise ExpireTimeMustBeMoreThanZeroError
|
|
457
465
|
|
|
458
466
|
if self.result_ex_time and self.result_px_time:
|
|
459
|
-
raise DuplicateExpireTimeSelectedError
|
|
460
|
-
|
|
461
|
-
|
|
467
|
+
raise DuplicateExpireTimeSelectedError
|
|
468
|
+
|
|
469
|
+
def _task_name(self, task_id: str) -> str:
|
|
470
|
+
if self.prefix_str is None:
|
|
471
|
+
return task_id
|
|
472
|
+
return f"{self.prefix_str}:{task_id}"
|
|
462
473
|
|
|
463
474
|
@asynccontextmanager
|
|
464
475
|
async def _acquire_master_conn(self) -> AsyncIterator[_Redis]:
|
|
@@ -480,7 +491,7 @@ class RedisAsyncSentinelResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
480
491
|
:param result: TaskiqResult instance.
|
|
481
492
|
"""
|
|
482
493
|
redis_set_params: Dict[str, Union[str, bytes, int]] = {
|
|
483
|
-
"name": task_id,
|
|
494
|
+
"name": self._task_name(task_id),
|
|
484
495
|
"value": self.serializer.dumpb(model_dump(result)),
|
|
485
496
|
}
|
|
486
497
|
if self.result_ex_time:
|
|
@@ -500,7 +511,7 @@ class RedisAsyncSentinelResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
500
511
|
:returns: True if the result is ready else False.
|
|
501
512
|
"""
|
|
502
513
|
async with self._acquire_master_conn() as redis:
|
|
503
|
-
return bool(await redis.exists(task_id))
|
|
514
|
+
return bool(await redis.exists(self._task_name(task_id)))
|
|
504
515
|
|
|
505
516
|
async def get_result(
|
|
506
517
|
self,
|
|
@@ -515,14 +526,15 @@ class RedisAsyncSentinelResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
515
526
|
:raises ResultIsMissingError: if there is no result when trying to get it.
|
|
516
527
|
:return: task's return value.
|
|
517
528
|
"""
|
|
529
|
+
task_name = self._task_name(task_id)
|
|
518
530
|
async with self._acquire_master_conn() as redis:
|
|
519
531
|
if self.keep_results:
|
|
520
532
|
result_value = await redis.get(
|
|
521
|
-
name=
|
|
533
|
+
name=task_name,
|
|
522
534
|
)
|
|
523
535
|
else:
|
|
524
536
|
result_value = await redis.getdel(
|
|
525
|
-
name=
|
|
537
|
+
name=task_name,
|
|
526
538
|
)
|
|
527
539
|
|
|
528
540
|
if result_value is None:
|
|
@@ -553,7 +565,7 @@ class RedisAsyncSentinelResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
553
565
|
:param result: task's TaskProgress instance.
|
|
554
566
|
"""
|
|
555
567
|
redis_set_params: Dict[str, Union[str, int, bytes]] = {
|
|
556
|
-
"name": task_id + PROGRESS_KEY_SUFFIX,
|
|
568
|
+
"name": self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
|
|
557
569
|
"value": self.serializer.dumpb(model_dump(progress)),
|
|
558
570
|
}
|
|
559
571
|
if self.result_ex_time:
|
|
@@ -576,7 +588,7 @@ class RedisAsyncSentinelResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
576
588
|
"""
|
|
577
589
|
async with self._acquire_master_conn() as redis:
|
|
578
590
|
result_value = await redis.get(
|
|
579
|
-
name=task_id + PROGRESS_KEY_SUFFIX,
|
|
591
|
+
name=self._task_name(task_id) + PROGRESS_KEY_SUFFIX,
|
|
580
592
|
)
|
|
581
593
|
|
|
582
594
|
if result_value is None:
|
|
@@ -590,4 +602,4 @@ class RedisAsyncSentinelResultBackend(AsyncResultBackend[_ReturnType]):
|
|
|
590
602
|
async def shutdown(self) -> None:
    """
    Shutdown sentinel connections.

    Closes the client of every sentinel in ``self.sentinel.sentinels``.
    It then awaits the base ``AsyncResultBackend.shutdown``, matching what
    ``RedisAsyncClusterResultBackend.shutdown`` does.
    """
    for sentinel in self.sentinel.sentinels:
        await sentinel.aclose()
    # Previously skipped here; keeps the shutdown chain intact for subclasses
    # and mixins that extend the base backend's shutdown.
    await super().shutdown()
|