redis-stream-queue 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,27 @@
1
"""Public API of the redis-stream-queue package.

Re-exports the client, consumer, producer, consumer-group, retry/DLQ,
message/metrics and serializer names from their submodules; ``__all__``
lists exactly these re-exported names.
"""
from .client import StreamClient
from .consumer import StreamConsumer
from .group import ConsumerConfig, ConsumerGroup
from .message import ConsumerInfo, ConsumerMetrics, PendingEntry, ProducerMetrics, StreamMessage, StreamStats
from .producer import StreamProducer
from .retry import DLQHandler, RetryHandler
from .serializers import JsonSerializer, MsgpackSerializer, PickleSerializer, Serializer

# Names exported by ``from <package> import *``.
__all__ = [
    "StreamClient",
    "StreamConsumer",
    "StreamProducer",
    "ConsumerConfig",
    "ConsumerGroup",
    "RetryHandler",
    "StreamMessage",
    "PendingEntry",
    "StreamStats",
    "ConsumerInfo",
    "ConsumerMetrics",
    "ProducerMetrics",
    "DLQHandler",
    "Serializer",
    "JsonSerializer",
    "MsgpackSerializer",
    "PickleSerializer",
]
@@ -0,0 +1,306 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ from typing import ClassVar
6
+
7
+ import redis.asyncio as aioredis
8
+
9
+ from .message import PendingEntry, StreamMessage
10
+
11
logger = logging.getLogger(__name__)

# Marker looked for in the error text by StreamClient.create_group(): an error
# containing it is swallowed (Redis reports an already-existing consumer group
# with a BUSYGROUP error), any other error is re-raised.
_BUSYGROUP = "BUSYGROUP"
14
+
15
+
16
class StreamClient:
    """
    Async Redis Stream client backed by a shared, process-wide connection pool.

    Instances whose connection settings produce the same ``_pool_key()`` reuse
    one underlying ``redis.asyncio`` client stored in the class-level
    ``_pools`` registry. Supports single-node, cluster (startup nodes, or a
    ``redis+cluster://`` / ``rediss+cluster://`` URL) and URL connection modes.

    Stream names are namespaced as ``"<prefix>_<stream>"`` when ``prefix`` is set.
    """

    # One shared client per connection key; creation/removal happens under _lock.
    _pools: ClassVar[dict[tuple, aioredis.Redis]] = {}
    # Created lazily by _get_lock().
    _lock: ClassVar[asyncio.Lock | None] = None

    def __init__(
        self,
        host: str = "localhost",
        port: int = 6379,
        db: int = 0,
        username: str | None = None,
        password: str | None = None,
        prefix: str = "",
        max_connections: int = 1000,
        pool_timeout: int = 5,
        client_name: str | None = None,
        ssl: bool = False,
    ) -> None:
        """Configure a single-node client; no connection is made until first use.

        Args:
            host, port, db: Redis server address and logical database.
            username, password: Optional ACL credentials.
            prefix: Optional namespace prepended to every stream key.
            max_connections: Upper bound for the blocking connection pool.
            pool_timeout: Seconds to wait for a free pooled connection.
            client_name: Optional CLIENT SETNAME value for each connection.
            ssl: Connect over TLS.
        """
        self._host = host
        self._port = port
        self._db = db
        self._username = username
        self._password = password
        self._prefix = prefix
        self._max_connections = max_connections
        self._pool_timeout = pool_timeout
        self._client_name = client_name
        self._ssl = ssl
        self._cluster = False
        self._startup_nodes: list[dict] | None = None
        self._url: str | None = None

    @classmethod
    def from_url(cls, url: str, **kwargs) -> "StreamClient":
        """Build a client from a Redis URL.

        ``redis+cluster://`` and ``rediss+cluster://`` URLs select cluster mode.
        Recognised kwargs: ``prefix``, ``max_connections``, ``pool_timeout``,
        ``client_name`` and ``db`` (``db`` is recorded on the instance but is not
        forwarded to the connection; the URL decides the database).
        """
        instance = cls.__new__(cls)
        instance._url = url
        instance._cluster = url.startswith(("redis+cluster://", "rediss+cluster://"))
        instance._startup_nodes = None
        instance._prefix = kwargs.get("prefix", "")
        instance._max_connections = kwargs.get("max_connections", 1000)
        instance._pool_timeout = kwargs.get("pool_timeout", 5)
        instance._client_name = kwargs.get("client_name", None)
        instance._host = None
        instance._port = None
        instance._db = kwargs.get("db", 0)
        instance._username = None
        instance._password = None
        instance._ssl = False
        return instance

    @classmethod
    def from_cluster(
        cls,
        startup_nodes: list[dict],
        password: str | None = None,
        username: str | None = None,
        prefix: str = "",
        **kwargs,
    ) -> "StreamClient":
        """Build a cluster client from ``[{"host": ..., "port": ...}, ...]`` startup nodes.

        Recognised kwargs: ``max_connections``, ``pool_timeout``, ``client_name``, ``ssl``.
        """
        instance = cls.__new__(cls)
        instance._startup_nodes = startup_nodes
        instance._cluster = True
        instance._password = password
        instance._username = username
        instance._prefix = prefix
        instance._max_connections = kwargs.get("max_connections", 1000)
        instance._pool_timeout = kwargs.get("pool_timeout", 5)
        instance._client_name = kwargs.get("client_name", None)
        instance._host = None
        instance._port = None
        instance._db = 0
        instance._ssl = kwargs.get("ssl", False)
        instance._url = None
        return instance

    @classmethod
    def _get_lock(cls) -> asyncio.Lock:
        """Return the class-wide lock guarding ``_pools``, creating it on first use."""
        if cls._lock is None:
            cls._lock = asyncio.Lock()
        return cls._lock

    def _pool_key(self) -> tuple:
        """Key identifying which shared client this instance uses.

        Credentials and the TLS flag are part of the key so clients that
        authenticate differently never share (and silently reuse) one pool.
        """
        auth = (self._username, self._password, self._ssl)
        if self._cluster and self._startup_nodes:
            return ("cluster", tuple(sorted(str(n) for n in self._startup_nodes)), *auth)
        return (self._url or f"{self._host}:{self._port}:{self._db}", self._prefix, *auth)

    async def _get_client(self) -> aioredis.Redis:
        """Return the shared client for this instance, creating it once if needed."""
        key = self._pool_key()
        # Fast path: no await between lookup and return, so this cannot race on
        # a single event loop, and avoids taking the class-wide lock per command.
        client = self._pools.get(key)
        if client is not None:
            return client
        async with self._get_lock():
            client = self._pools.get(key)
            if client is None:
                client = await self._create_client()
                self._pools[key] = client
            return client

    async def _create_client(self) -> aioredis.Redis:
        """Build the underlying redis client for this instance's connection mode.

        Raises:
            ValueError: cluster mode without startup nodes or a cluster URL.
        """
        if self._cluster:
            from redis.asyncio.cluster import ClusterNode, RedisCluster

            if self._startup_nodes:
                nodes = [ClusterNode(n["host"], n["port"]) for n in self._startup_nodes]
                return RedisCluster(
                    startup_nodes=nodes,
                    username=self._username,
                    password=self._password,
                    client_name=self._client_name,
                    ssl=self._ssl,
                    decode_responses=False,
                )
            if self._url:
                # redis-py only parses redis:// and rediss://; drop the "+cluster"
                # marker that from_url() used to select cluster mode.
                url = self._url.replace("+cluster://", "://", 1)
                return RedisCluster.from_url(
                    url, client_name=self._client_name, decode_responses=False
                )
            raise ValueError("Cluster mode requires startup_nodes or a cluster URL")

        if self._url:
            pool = aioredis.BlockingConnectionPool.from_url(
                self._url,
                max_connections=self._max_connections,
                timeout=self._pool_timeout,
                client_name=self._client_name,
                decode_responses=False,
            )
        else:
            pool_kwargs: dict = dict(
                host=self._host,
                port=self._port,
                db=self._db,
                username=self._username,
                password=self._password,
                client_name=self._client_name,
                max_connections=self._max_connections,
                timeout=self._pool_timeout,
                decode_responses=False,
            )
            if self._ssl:
                from redis.asyncio.connection import SSLConnection

                # The pool forwards extra kwargs to the connection class and a
                # plain Connection has no ``ssl`` parameter; TLS is selected by
                # the connection class instead.
                pool_kwargs["connection_class"] = SSLConnection
            pool = aioredis.BlockingConnectionPool(**pool_kwargs)
        return aioredis.Redis(connection_pool=pool)

    def _prefixed(self, key: str) -> str:
        """Apply the configured namespace prefix (``"<prefix>_<key>"``) if any."""
        return f"{self._prefix}_{key}" if self._prefix else key

    @staticmethod
    def _as_str(value):
        """Decode ``bytes`` replies to ``str``; return anything else unchanged."""
        return value.decode() if isinstance(value, bytes) else value

    @staticmethod
    def _field_data(fields) -> bytes:
        """Return an entry's ``data`` field, or ``b""`` when it is missing.

        ``fields`` may be ``None`` for entries redis-py reports without a body
        (e.g. deleted entries surfaced by XAUTOCLAIM).
        """
        if not fields:
            return b""
        return fields.get(b"data", fields.get("data", b""))

    # ── Stream operations ────────────────────────────────────────────────────

    async def push(self, stream: str, data: bytes, max_len: int = 100_000) -> str:
        """XADD one payload under the ``data`` field and return the entry ID.

        ``max_len`` trims the stream approximately (``MAXLEN ~``).
        """
        client = await self._get_client()
        msg_id = await client.xadd(
            self._prefixed(stream), {"data": data}, maxlen=max_len, approximate=True
        )
        return self._as_str(msg_id)

    async def push_many(self, stream: str, items: list[bytes], max_len: int = 100_000) -> list[str]:
        """XADD several payloads in one non-transactional pipeline; return their IDs."""
        if not items:
            return []
        client = await self._get_client()
        key = self._prefixed(stream)
        async with client.pipeline(transaction=False) as pipe:
            for data in items:
                pipe.xadd(key, {"data": data}, maxlen=max_len, approximate=True)
            results = await pipe.execute()
        return [self._as_str(r) for r in results]

    async def create_group(self, stream: str, group: str, start_id: str = "0") -> None:
        """Create ``group`` on ``stream`` (creating the stream too); no-op if it exists."""
        client = await self._get_client()
        try:
            await client.xgroup_create(self._prefixed(stream), group, id=start_id, mkstream=True)
        except Exception as e:
            if _BUSYGROUP not in str(e):
                raise

    async def read(
        self, stream: str, group: str, consumer: str, count: int, block_ms: int
    ) -> list[StreamMessage]:
        """XREADGROUP up to ``count`` never-delivered (``">"``) messages for ``consumer``.

        Blocks up to ``block_ms`` milliseconds; returns ``[]`` when nothing arrives.
        """
        client = await self._get_client()
        resp = await client.xreadgroup(
            groupname=group,
            consumername=consumer,
            streams={self._prefixed(stream): ">"},
            count=count,
            block=block_ms,
        )
        if not resp:
            return []
        return [
            StreamMessage(id=self._as_str(msg_id), data=self._field_data(fields), delivery_count=1)
            for _stream_name, entries in resp
            for msg_id, fields in entries
        ]

    async def ack(self, stream: str, group: str, *ids: str) -> int:
        """XACK ``ids``; returns the count Redis reports as acknowledged (0 if no IDs)."""
        if not ids:
            return 0
        client = await self._get_client()
        return await client.xack(self._prefixed(stream), group, *ids)

    async def autoclaim(
        self,
        stream: str,
        group: str,
        consumer: str,
        min_idle_ms: int,
        count: int,
        max_passes: int | None = None,
    ) -> list[StreamMessage]:
        """
        Reclaim idle PEL entries via XAUTOCLAIM.

        Follows the cursor until the full PEL is swept (cursor returns "0-0").
        max_passes caps the number of XAUTOCLAIM calls per invocation; None = unlimited.
        Stall detection (cursor unchanged) guards against misbehaving Redis forks.
        """
        client = await self._get_client()
        key = self._prefixed(stream)
        messages: list[StreamMessage] = []
        cursor = "0-0"
        passes = 0
        while max_passes is None or passes < max_passes:
            reply = await client.xautoclaim(
                name=key,
                groupname=group,
                consumername=consumer,
                min_idle_time=min_idle_ms,
                start_id=cursor,
                count=count,
            )
            passes += 1
            # Redis >= 7.0 replies [cursor, entries, deleted_ids]; 6.2 replies
            # [cursor, entries]. Index rather than unpack exactly three values.
            next_cursor, entries = reply[0], reply[1]
            for msg_id, fields in entries:
                if msg_id is None:
                    # Placeholder for an entry that no longer exists: nothing to process.
                    continue
                # XAUTOCLAIM doesn't return delivery_count in entries — 1 is minimum
                # since reclaimed msgs were already delivered (real count lives in PEL)
                messages.append(
                    StreamMessage(
                        id=self._as_str(msg_id), data=self._field_data(fields), delivery_count=1
                    )
                )
            next_str = next_cursor.decode() if isinstance(next_cursor, bytes) else str(next_cursor)
            if next_str == "0-0" or next_str == cursor:
                # "0-0" → full PEL sweep complete; unchanged cursor → no progress (stall guard).
                break
            cursor = next_str
        return messages

    async def pending_range(
        self, stream: str, group: str, count: int, consumer: str | None = None
    ) -> list[PendingEntry]:
        """Return up to ``count`` PEL entries for ``group``, optionally for one consumer."""
        client = await self._get_client()
        resp = await client.xpending_range(
            name=self._prefixed(stream), groupname=group, min="-", max="+",
            count=count, consumername=consumer,
        )
        return [
            PendingEntry(
                id=self._as_str(entry["message_id"]),
                consumer=self._as_str(entry["consumer"]),
                idle_ms=entry.get("time_since_delivered", 0),
                delivery_count=entry.get("times_delivered", 0),
            )
            for entry in resp
        ]

    async def fetch_by_ids(self, stream: str, ids: list[str]) -> list[StreamMessage]:
        """Fetch entries by exact ID (one pipelined XRANGE per ID); missing IDs are skipped."""
        if not ids:
            return []
        client = await self._get_client()
        key = self._prefixed(stream)
        async with client.pipeline(transaction=False) as pipe:
            for msg_id in ids:
                pipe.xrange(key, min=msg_id, max=msg_id, count=1)
            replies = await pipe.execute()
        return [
            StreamMessage(id=self._as_str(entry_id), data=self._field_data(fields))
            for resp in replies
            for entry_id, fields in resp
        ]

    async def stream_len(self, stream: str) -> int:
        """Return XLEN of the (prefixed) stream."""
        client = await self._get_client()
        return await client.xlen(self._prefixed(stream))

    async def group_info(self, stream: str) -> list[dict]:
        """Return XINFO GROUPS for the (prefixed) stream."""
        client = await self._get_client()
        return await client.xinfo_groups(self._prefixed(stream))

    async def consumer_info(self, stream: str, group: str) -> list[dict]:
        """Return XINFO CONSUMERS for ``group`` on the (prefixed) stream."""
        client = await self._get_client()
        return await client.xinfo_consumers(self._prefixed(stream), group)

    async def close(self) -> None:
        """Close and forget the shared client for this instance's pool key.

        Other instances with the same key lose that client too; they recreate
        one on their next call.
        """
        key = self._pool_key()
        async with self._get_lock():
            client = self._pools.pop(key, None)
        if client:
            await client.aclose()

    @classmethod
    async def close_all(cls) -> None:
        """Best-effort close of every shared client, then clear the registry."""
        async with cls._get_lock():
            for client in cls._pools.values():
                try:
                    await client.aclose()
                except Exception:
                    # Keep closing the rest, but leave a trace. Pool keys are not
                    # logged because they can contain credentials.
                    logger.warning("Failed to close a pooled Redis client", exc_info=True)
            cls._pools.clear()
@@ -0,0 +1,218 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ import time
6
+ import weakref
7
+ from typing import Awaitable, Callable, ClassVar
8
+
9
+ from .client import StreamClient
10
+ from .group import ConsumerConfig, ConsumerGroup
11
+ from .message import ConsumerMetrics, StreamMessage, _TpsTracker
12
+ from .retry import DLQHandler, RetryHandler
13
+ from .serializers import JsonSerializer, Serializer
14
+
15
logger = logging.getLogger(__name__)

# Consumer handler: receives a batch of decoded messages and returns the IDs it
# processed successfully; StreamConsumer ACKs only the returned IDs.
MessageHandler = Callable[[list[StreamMessage]], Awaitable[list[str]]]
# Awaited by StreamConsumer with the handler's exception and the batch it failed on.
ErrorHook = Callable[[Exception, list[StreamMessage]], Awaitable[None]]

# Marker searched for in exception text by StreamConsumer.run(); when present the
# consumer-group registry is reset so the group is re-created next iteration.
_NOGROUP = "NOGROUP"
21
+
22
+
23
class StreamConsumer:
    """
    Async Redis Stream consumer using consumer groups.

    Lifecycle per iteration (run_once):
    1. Ensure consumer group exists (idempotent — no-op after first call)
    2. XREADGROUP ">" — fetch new messages, pass to handler, ACK returned IDs
    3. XAUTOCLAIM — reclaim orphaned messages from crashed consumers, re-process
    4. Poison-pill sweep — messages exceeding max_deliveries go to DLQ via RetryHandler

    handler(messages) must return the list of IDs it successfully processed.
    Messages whose IDs are NOT returned stay in PEL for reclaim recovery.

    Safe for multiple pods: each pod gets a unique worker_name; consumer groups
    distribute messages and XAUTOCLAIM provides crash recovery across pods.

    Call metrics() for this instance or all_metrics() for all live instances in this process.
    """

    # Weak registry of live consumers for all_metrics(); entries vanish on GC.
    _instances: ClassVar[weakref.WeakSet[StreamConsumer]] = weakref.WeakSet()

    def __init__(
        self,
        client: StreamClient,
        stream: str,
        config: ConsumerConfig,
        handler: MessageHandler,
        dlq_handler: DLQHandler | None = None,
        serializer: Serializer | None = None,
        error_hook: ErrorHook | None = None,
    ) -> None:
        """Wire the consumer to ``stream`` using ``config.group``.

        Args:
            client: Shared stream client used for all Redis calls.
            stream: Stream name (the client applies its prefix).
            config: Group, batch size, blocking, reclaim and DLQ settings.
            handler: Async callable returning the IDs it processed.
            dlq_handler: Passed through to the RetryHandler.
            serializer: Payload decoder; defaults to JsonSerializer.
            error_hook: Awaited with (exception, batch) when the handler raises.
        """
        self._client = client
        self._stream = stream
        self._config = config
        self._handler = handler
        self._serializer = serializer or JsonSerializer()
        self._error_hook = error_hook
        self._worker_name = config.resolved_worker_name()

        self._cg = ConsumerGroup(client, stream, config.group)
        self._retry = RetryHandler(
            client=client,
            stream=stream,
            group=config.group,
            dlq_handler=dlq_handler,
            max_deliveries=config.max_deliveries,
            batch_size=config.batch_size,
            serializer=self._serializer,
        )
        self._stop = False

        # Throughput counters
        self._total_read: int = 0
        self._total_acked: int = 0
        self._total_dlq: int = 0
        self._total_errors: int = 0
        self._started_at: float | None = None
        self._tps_in = _TpsTracker()
        self._tps_out = _TpsTracker()

        # Register last so all_metrics() never sees an instance without counters.
        StreamConsumer._instances.add(self)

    @classmethod
    def all_metrics(cls) -> list[ConsumerMetrics]:
        """Return metrics for all live StreamConsumer instances in this process."""
        return [inst.metrics() for inst in cls._instances]

    def stop(self) -> None:
        """Signal the consumer to exit after the current iteration."""
        self._stop = True

    def metrics(self) -> ConsumerMetrics:
        """Return current throughput counters and TPS (non-blocking, no Redis calls).

        Uptime is measured from the first message read; ``avg_tps`` stays 0
        until at least one second of uptime has elapsed.
        """
        uptime = (time.monotonic() - self._started_at) if self._started_at else 0.0
        avg_tps = self._total_acked / uptime if uptime >= 1.0 else 0.0
        tps_in = round(self._tps_in.tps(), 2)
        tps_out = round(self._tps_out.tps(), 2)
        return ConsumerMetrics(
            total_read=self._total_read,
            total_acked=self._total_acked,
            total_dlq=self._total_dlq,
            total_errors=self._total_errors,
            tps_in=tps_in,
            tps_out=tps_out,
            tps_total=round(tps_in + tps_out, 2),
            avg_tps=round(avg_tps, 2),
            uptime_secs=round(uptime, 1),
        )

    async def run(self) -> None:
        """Infinite loop. Catches all exceptions and restarts. Exits on stop() or CancelledError."""
        while not self._stop:
            try:
                await self.run_once()
            except asyncio.CancelledError:
                raise
            except Exception as e:
                if _NOGROUP in str(e):
                    # Stream or group was externally deleted — reset registry so ensure() re-creates
                    self._cg.reset()
                    logger.warning("Consumer group lost (NOGROUP) — will re-create on next iteration")
                else:
                    logger.exception("Consumer loop error: %s", e)
                await asyncio.sleep(1)

    async def run_once(self) -> None:
        """Execute one full consumer iteration (new messages, reclaim, poison sweep)."""
        cfg = self._config

        await self._cg.ensure(dlq_stream=cfg.dlq_stream, dlq_group=cfg.dlq_group)

        # 1. New messages
        raw_messages = await self._client.read(
            stream=self._stream,
            group=cfg.group,
            consumer=self._worker_name,
            count=cfg.batch_size,
            block_ms=cfg.block_ms,
        )
        if raw_messages:
            self._record_intake(raw_messages)
            await self._process_batch(raw_messages)

        # 2. Orphan recovery — reclaim from crashed consumers
        claimed = await self._client.autoclaim(
            stream=self._stream,
            group=cfg.group,
            consumer=self._worker_name,
            min_idle_ms=cfg.min_idle_claim_ms,
            count=cfg.batch_size,
            max_passes=cfg.max_claim_passes,
        )
        if claimed:
            self._record_intake(claimed)
            await self._process_batch(claimed)

        # 3. Poison-pill sweep
        dlq_count = await self._retry.handle_poison_pills()
        self._total_dlq += dlq_count

    # ── Internal ──────────────────────────────────────────────────────────────

    def _record_intake(self, batch: list[StreamMessage]) -> None:
        """Start the uptime clock on first traffic and count messages read."""
        if self._started_at is None:
            self._started_at = time.monotonic()
        self._total_read += len(batch)
        self._tps_in.record(len(batch))

    async def _process_batch(self, messages: list[StreamMessage]) -> None:
        """Decode a batch, run the handler and ACK the IDs it reports as processed.

        Undecodable messages are sent to the DLQ and ACKed immediately. If the
        handler raises, nothing from the batch is ACKed (it stays in the PEL).
        """
        cfg = self._config
        decoded: list[StreamMessage] = []
        bad_ids: list[str] = []

        for msg in messages:
            try:
                decoded.append(self._decode(msg))
            except Exception as e:
                logger.warning("Decode error for msg %s: %s", msg.id, e)
                bad_ids.append(msg.id)
                await self._retry.send_to_dlq(msg, "decode_error")

        if bad_ids:
            self._total_dlq += len(bad_ids)
            await self._client.ack(self._stream, cfg.group, *bad_ids)

        if not decoded:
            return

        try:
            acked_ids = await self._handler(decoded)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            logger.exception("Handler error: %s", e)
            self._total_errors += 1
            await self._notify_error_hook(e, decoded)
            return

        if acked_ids is None:
            logger.warning(
                "Handler returned None instead of a list — no messages ACKed. "
                "Return [] to explicitly ACK nothing, or [msg.id, ...] to ACK."
            )
            return

        # Drop duplicate IDs (keeping order) and count what XACK actually
        # acknowledged, so repeated or unknown IDs don't inflate the metrics.
        ids = list(dict.fromkeys(acked_ids))
        if ids:
            acked = await self._client.ack(self._stream, cfg.group, *ids)
            if acked:
                self._total_acked += acked
                self._tps_out.record(acked)

    async def _notify_error_hook(self, exc: Exception, batch: list[StreamMessage]) -> None:
        """Await the user error hook, logging (not propagating) its own failures."""
        if self._error_hook is None:
            return
        try:
            await self._error_hook(exc, batch)
        except asyncio.CancelledError:
            raise
        except Exception:
            # A faulty hook must not abort the iteration (reclaim/poison sweep).
            logger.exception("Error hook raised while handling a handler failure")

    def _decode(self, msg: StreamMessage) -> StreamMessage:
        """Return a copy of ``msg`` whose ``data`` is deserialized by the serializer."""
        raw = msg.data
        if isinstance(raw, str):
            raw = raw.encode()
        decoded_data = self._serializer.decode(raw)
        return StreamMessage(id=msg.id, data=decoded_data, delivery_count=msg.delivery_count)
@@ -0,0 +1,14 @@
1
class RedisStreamError(Exception):
    """Base class for every exception type defined in this module."""


class DLQError(RedisStreamError):
    """Dead-letter-queue error category.

    NOTE(review): raise sites are outside this module; meaning inferred from the name.
    """


class SerializationError(RedisStreamError):
    """Serialization error category (name-based; raise sites are not shown here)."""


class ClusterSlotError(RedisStreamError):
    """Cluster hash-slot error category (name-based; raise sites are not shown here)."""