redis-stream-queue 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,143 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import random
5
+ import socket
6
+ from dataclasses import dataclass
7
+ from typing import ClassVar
8
+
9
+ from .client import StreamClient
10
+ from .message import ConsumerInfo, PendingEntry, StreamStats
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
def _default_worker_name(group: str, max_len: int = 64) -> str:
    """Build a consumer name of the form ``<group>_<hostname>_<4-digit suffix>``.

    Only the group prefix is shortened to make the name fit; the hostname and
    the random suffix are kept whole because they are the parts that tell
    workers apart.

    Args:
        group: Consumer group name used as the prefix.
        max_len: Target maximum length of the generated name.

    Returns:
        The generated name. If ``_<hostname>_<suffix>`` is already ``max_len``
        characters or longer, the group prefix is dropped completely and the
        result is just that tail (which may exceed ``max_len``).
    """
    hostname = socket.gethostname()
    suffix = f"{random.randint(0, 9999):04d}"
    tail = f"_{hostname}_{suffix}"
    # Clamp at zero: a negative slice bound would cut characters off the END
    # of ``group`` rather than yielding an empty prefix.
    max_group = max(0, max_len - len(tail))
    return f"{group[:max_group]}{tail}"
22
+
23
+
24
@dataclass
class ConsumerConfig:
    """Settings bundle for a stream consumer.

    Only ``group`` is mandatory; every other field has a default.
    """

    group: str
    worker_name: str | None = None
    dlq_stream: str | None = None
    dlq_group: str | None = None
    batch_size: int = 100
    block_ms: int = 5_000
    min_idle_claim_ms: int = 10_000
    max_deliveries: int = 3
    max_stream_size: int = 100_000
    max_claim_passes: int | None = None  # None = sweep full PEL per iteration

    def resolved_worker_name(self) -> str:
        """Return ``worker_name`` when set, otherwise a generated default.

        The generated default comes from ``_default_worker_name`` and contains
        a random suffix, so repeated calls without an explicit ``worker_name``
        can return different values.
        """
        if self.worker_name:
            return self.worker_name
        return _default_worker_name(self.group)
40
+
41
+
42
class ConsumerGroup:
    """
    Manages and inspects a Redis Stream consumer group.

    Provides group creation (idempotent), stats, health checks,
    and pending entry inspection.

    Group existence is tracked in a class-level registry keyed by
    (pool_key, stream, group), so multiple instances targeting the same
    group share the cache. DLQ groups created through ensure() are recorded
    in the same registry as (pool_key, dlq_stream, dlq_group). Call reset()
    to force re-creation (e.g. after external FLUSHALL or XGROUP DESTROY).
    """

    _known: ClassVar[set[tuple]] = set()

    def __init__(self, client: StreamClient, stream: str, group: str) -> None:
        self._client = client
        self._stream = stream
        self._group = group

    def _key(self) -> tuple:
        """Return the registry key for this group: (pool_key, stream, group)."""
        return (self._client._pool_key(), self._stream, self._group)

    async def ensure(self, dlq_stream: str | None = None, dlq_group: str | None = None) -> None:
        """Create the consumer group (and DLQ group) unless already known.

        The main group and the DLQ group are tracked separately, so a DLQ
        that is first requested on a later call is still created even though
        the main group is already cached.

        Args:
            dlq_stream: Stream holding dead-lettered messages, if any.
            dlq_group: Consumer group to create on ``dlq_stream``. The DLQ
                group is only handled when both DLQ arguments are truthy.
        """
        key = self._key()
        main_known = key in ConsumerGroup._known
        if not main_known:
            await self._client.create_group(self._stream, self._group)

        if dlq_stream and dlq_group:
            dlq_key = (key[0], dlq_stream, dlq_group)
            # When the main group is new (or was reset()) the DLQ group is
            # re-created too, matching what a fresh ensure() has always done.
            if not main_known or dlq_key not in ConsumerGroup._known:
                await self._client.create_group(dlq_stream, dlq_group)
                ConsumerGroup._known.add(dlq_key)

        # Registered last so a failure above leaves the group "unknown" and
        # the next ensure() retries.
        ConsumerGroup._known.add(key)

    def reset(self) -> None:
        """Remove this group from the known registry — next ensure() will re-create it."""
        ConsumerGroup._known.discard(self._key())

    async def stats(self, dlq_stream: str | None = None) -> StreamStats:
        """Return stream length, lag, PEL size, and per-consumer info.

        Args:
            dlq_stream: When given, its length is reported as ``dlq_length``.
                Reading it is best effort: on failure a warning is logged and
                ``dlq_length`` stays ``None``.

        Returns:
            A StreamStats snapshot. If this group is not present in the
            client's group info, the PEL size and lag are reported as 0 and
            the last delivered ID as ``"0-0"``.
        """
        stream_length = await self._client.stream_len(self._stream)
        groups = await self._client.group_info(self._stream)
        consumers_raw = await self._client.consumer_info(self._stream, self._group)

        group_data = next(
            (g for g in groups if _decode(g.get("name")) == self._group), None
        )
        pel_size = int(group_data.get("pending", 0)) if group_data else 0
        # ``or 0`` also turns a present-but-None lag value into 0.
        lag = int(group_data.get("lag") or 0) if group_data else 0
        last_delivered_id = (
            _decode(group_data.get("last-delivered-id", b"0-0")) if group_data else "0-0"
        )

        consumers = [
            ConsumerInfo(
                name=_decode(c.get("name", b"")),
                pending=int(c.get("pending", 0)),
                idle_ms=int(c.get("idle", 0)),
            )
            for c in consumers_raw
        ]

        dlq_length: int | None = None
        if dlq_stream:
            try:
                dlq_length = await self._client.stream_len(dlq_stream)
            except Exception as exc:
                # Best effort: the DLQ length is optional, but the failure
                # should be visible instead of silently discarded.
                logger.warning(
                    "Could not read length of DLQ stream %s: %s", dlq_stream, exc
                )

        return StreamStats(
            stream=self._stream,
            group=self._group,
            stream_length=stream_length,
            group_pel_size=pel_size,
            lag=lag,
            last_delivered_id=last_delivered_id,
            consumers=consumers,
            dlq_length=dlq_length,
        )

    async def health_check(
        self,
        max_lag: int = 10_000,
        max_idle_ms: int = 60_000,
        dlq_stream: str | None = None,
    ) -> dict:
        """Return health status dict with issue descriptions.

        Args:
            max_lag: Group lag above this value is reported as an issue.
            max_idle_ms: Any consumer idle longer than this is reported.
            dlq_stream: Forwarded to stats().

        Returns:
            ``{"healthy": bool, "issues": list[str], "stats": StreamStats}``;
            ``healthy`` is True exactly when ``issues`` is empty.
        """
        s = await self.stats(dlq_stream=dlq_stream)
        issues: list[str] = []
        if s.lag > max_lag:
            issues.append(f"lag={s.lag} exceeds max_lag={max_lag}")
        for c in s.consumers:
            if c.idle_ms > max_idle_ms:
                issues.append(f"consumer '{c.name}' idle for {c.idle_ms}ms (max {max_idle_ms}ms)")
        return {"healthy": not issues, "issues": issues, "stats": s}

    async def pending_details(self, count: int = 100) -> list[PendingEntry]:
        """Return details for pending (unacknowledged) entries.

        ``count`` is passed to the client's pending_range() call, so the
        result is bounded by it rather than covering the whole PEL.
        """
        return await self._client.pending_range(self._stream, self._group, count)
138
+
139
+
140
def _decode(value) -> str:
    """Coerce a possibly-bytes value to ``str``.

    ``None`` becomes the empty string, ``bytes`` are decoded with the default
    codec, and anything else goes through ``str()``.
    """
    if value is None:
        return ""
    if isinstance(value, bytes):
        return value.decode()
    return str(value)
@@ -0,0 +1,88 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from collections import deque
5
+ from dataclasses import dataclass, field
6
+
7
_TPS_WINDOW = 60.0  # seconds


class _TpsTracker:
    """Sliding-window TPS counter. Stores (timestamp, batch_count) pairs.

    Timestamps come from ``time.monotonic()``. Entries older than the window
    are dropped both when recording and when reading, so the buffer stays
    bounded even if ``tps()`` is never called.
    """

    def __init__(self, window_secs: float = _TPS_WINDOW) -> None:
        self._window = window_secs
        self._batches: deque[tuple[float, int]] = deque()

    def record(self, count: int) -> None:
        """Register ``count`` events happening now; non-positive counts are ignored."""
        if count > 0:
            now = time.monotonic()
            # Prune here as well: a tracker that is written to but never read
            # would otherwise keep every batch forever.
            self._evict(now)
            self._batches.append((now, count))

    def tps(self) -> float:
        """Return events per second over the batches still inside the window.

        The rate is the sum of retained counts divided by the time since the
        oldest retained batch. Returns 0.0 when nothing is retained.
        """
        now = time.monotonic()
        self._evict(now)
        if not self._batches:
            return 0.0
        total = sum(c for _, c in self._batches)
        span = now - self._batches[0][0]
        # span < 1s → assume 1s to avoid artificially inflated rate on startup
        return total / max(span, 1.0)

    def _evict(self, now: float) -> None:
        """Drop batches whose timestamp is older than ``now - window``."""
        cutoff = now - self._window
        batches = self._batches
        while batches and batches[0][0] < cutoff:
            batches.popleft()
32
+
33
+
34
@dataclass
class StreamMessage:
    """A stream entry: its ID, its payload, and its delivery count."""

    # Stream entry ID.
    id: str
    # Payload. Annotated as dict, but RetryHandler._try_decode also checks for
    # bytes/str here, so undecoded payloads appear to be possible —
    # NOTE(review): confirm against the client.
    data: dict
    # Number of deliveries recorded for this entry; defaults to 0.
    delivery_count: int = 0
39
+
40
+
41
@dataclass
class PendingEntry:
    """One pending (unacknowledged) entry, as returned by the client's pending_range()."""

    # Stream entry ID.
    id: str
    # Name of the consumer the entry is associated with —
    # presumably its current owner in the PEL; confirm against the client.
    consumer: str
    # Idle time in milliseconds — presumably time since last delivery; confirm.
    idle_ms: int
    # Delivery count; RetryHandler compares it against max_deliveries.
    delivery_count: int
47
+
48
+
49
@dataclass
class ConsumerInfo:
    """Per-consumer figures, built by ConsumerGroup.stats() from the client's consumer info."""

    # Consumer name (decoded from the "name" field).
    name: str
    # Value of the "pending" field for this consumer.
    pending: int
    # Value of the "idle" field, in milliseconds.
    idle_ms: int
54
+
55
+
56
@dataclass
class StreamStats:
    """Snapshot of a stream and one of its consumer groups, as built by ConsumerGroup.stats()."""

    # Stream name.
    stream: str
    # Consumer group name.
    group: str
    # Length of the stream as reported by the client.
    stream_length: int
    # The group's "pending" value (0 when the group was not found).
    group_pel_size: int
    # The group's "lag" value (0 when missing, None, or the group was not found).
    lag: int
    # The group's "last-delivered-id" ("0-0" when the group was not found).
    last_delivered_id: str
    # One ConsumerInfo per consumer reported for the group.
    consumers: list[ConsumerInfo] = field(default_factory=list)
    # Length of the DLQ stream; None when not requested or when reading it failed.
    dlq_length: int | None = None
66
+
67
+
68
@dataclass
class ConsumerMetrics:
    """Runtime throughput and error counters for a StreamConsumer.

    NOTE(review): the code that fills this in (StreamConsumer) is not visible
    here; the per-field comments describe the intended meaning.
    """

    total_read: int     # messages pulled from stream (incl. reclaimed)
    total_acked: int    # successfully processed + ACKed by handler
    total_dlq: int      # routed to DLQ (decode_error + max_deliveries)
    total_errors: int   # handler exceptions (message stays in PEL)
    tps_in: float       # reads/sec — XREADGROUP + autoclaim, sliding 60s window
    tps_out: float      # acked/sec — sliding 60s window
    tps_total: float    # tps_in + tps_out
    avg_tps: float      # total_acked / uptime_secs since first message
    uptime_secs: float  # seconds since first message was processed
80
+
81
+
82
@dataclass
class ProducerMetrics:
    """Runtime throughput counters for a StreamProducer.

    StreamProducer.metrics() rounds ``tps`` and ``avg_tps`` to two decimals
    and ``uptime_secs`` to one, and reports ``avg_tps`` as 0.0 until the
    uptime reaches one second.
    """

    total_pushed: int   # messages pushed since instance creation
    tps: float          # pushed/sec — sliding 60s window
    avg_tps: float      # total_pushed / uptime_secs since first push
    uptime_secs: float  # seconds since first push
@@ -0,0 +1,84 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ import weakref
5
+ from typing import ClassVar
6
+
7
+ from .client import StreamClient
8
+ from .message import ProducerMetrics, _TpsTracker
9
+ from .serializers import JsonSerializer, Serializer
10
+
11
+
12
class StreamProducer:
    """
    Pushes messages onto a Redis Stream.

    Safe to use from multiple pods simultaneously — Redis XADD is atomic.
    Call metrics() for this instance or all_metrics() for all live instances in this process.

    Instances register themselves in a class-level WeakSet, so all_metrics()
    only sees producers that are still referenced elsewhere.
    """

    _instances: ClassVar[weakref.WeakSet[StreamProducer]] = weakref.WeakSet()

    def __init__(
        self,
        client: StreamClient,
        stream: str,
        group: str | None = None,
        max_len: int = 100_000,
        serializer: Serializer | None = None,
    ) -> None:
        """Create a producer bound to one stream.

        Args:
            client: Stream client used for all Redis calls.
            stream: Name of the stream to push to.
            group: Default consumer group for ensure_group().
            max_len: Forwarded as ``max_len`` on every push.
            serializer: Payload encoder; a JsonSerializer is used when omitted.
        """
        self._client = client
        self._stream = stream
        self._group = group
        self._max_len = max_len
        self._serializer = serializer or JsonSerializer()

        self._total_pushed: int = 0
        # Monotonic timestamp of the first successful push; None until then.
        self._started_at: float | None = None
        self._tps = _TpsTracker()
        StreamProducer._instances.add(self)

    async def push(self, data: dict) -> str:
        """Serialize and publish one message. Returns the Redis stream entry ID."""
        encoded = self._serializer.encode(data)
        msg_id = await self._client.push(self._stream, encoded, max_len=self._max_len)
        # Counters are only touched after the client call returned, so a
        # failed push is not counted.
        if self._started_at is None:
            self._started_at = time.monotonic()
        self._total_pushed += 1
        self._tps.record(1)
        return msg_id

    async def push_many(self, data: list[dict]) -> list[str]:
        """Serialize and publish multiple messages via pipeline (one round-trip). Returns list of entry IDs.

        An empty ``data`` list returns ``[]`` without calling the client.
        """
        if not data:
            return []
        encoded = [self._serializer.encode(d) for d in data]
        ids = await self._client.push_many(self._stream, encoded, max_len=self._max_len)
        if self._started_at is None:
            self._started_at = time.monotonic()
        self._total_pushed += len(data)
        self._tps.record(len(data))
        return ids

    async def ensure_group(self, group: str | None = None) -> None:
        """Idempotently create the consumer group for this stream.

        Args:
            group: Group name; falls back to the ``group`` given at construction.

        Raises:
            ValueError: If neither ``group`` nor the constructor default is set.
        """
        g = group or self._group
        if not g:
            raise ValueError("group name is required")
        await self._client.create_group(self._stream, g)

    @classmethod
    def all_metrics(cls) -> list[ProducerMetrics]:
        """Return metrics for all live StreamProducer instances in this process."""
        return [inst.metrics() for inst in cls._instances]

    def metrics(self) -> ProducerMetrics:
        """Return current push counters and TPS (non-blocking, no Redis calls)."""
        started_at = self._started_at
        # Compare against None explicitly: the reference point of
        # time.monotonic() is undefined, so 0.0 is a valid timestamp and must
        # not be mistaken for "never pushed".
        uptime = 0.0 if started_at is None else time.monotonic() - started_at
        # Below one second the average is reported as 0.0 instead of dividing
        # by a tiny uptime.
        avg_tps = self._total_pushed / uptime if uptime >= 1.0 else 0.0
        return ProducerMetrics(
            total_pushed=self._total_pushed,
            tps=round(self._tps.tps(), 2),
            avg_tps=round(avg_tps, 2),
            uptime_secs=round(uptime, 1),
        )
File without changes
@@ -0,0 +1,92 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from typing import Awaitable, Callable
5
+
6
+ from .client import StreamClient
7
+ from .message import StreamMessage
8
+ from .serializers import JsonSerializer, Serializer
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
# Async callback that receives a dead-lettered message and a reason string;
# RetryHandler awaits it as ``handler(msg, reason)`` and ignores its result.
DLQHandler = Callable[[StreamMessage, str], Awaitable[None]]
13
+
14
+
15
class RetryHandler:
    """
    Detects poison pills (messages exceeding max_deliveries) and routes them to DLQ.

    Integrates with StreamConsumer to offload all retry/DLQ concerns from the main loop.
    """

    def __init__(
        self,
        client: StreamClient,
        stream: str,
        group: str,
        dlq_handler: DLQHandler | None = None,
        max_deliveries: int = 3,
        batch_size: int = 100,
        serializer: Serializer | None = None,
    ) -> None:
        """Configure poison-pill handling for one stream/group pair.

        Args:
            client: Stream client used for pending lookups, fetches and ACKs.
            stream: Stream to inspect.
            group: Consumer group whose pending entries are inspected.
            dlq_handler: Async callback receiving ``(message, reason)``. When
                omitted, poison pills are ACKed without being routed anywhere.
            max_deliveries: Entries with a delivery count at or above this
                value are treated as poison pills.
            batch_size: Count passed to the client's pending_range() call.
            serializer: Payload decoder; a JsonSerializer is used when omitted.
        """
        self._client = client
        self._stream = stream
        self._group = group
        self._dlq_handler = dlq_handler
        self._max_deliveries = max_deliveries
        self._batch_size = batch_size
        self._serializer = serializer or JsonSerializer()

    async def handle_poison_pills(self) -> int:
        """
        Find messages in PEL that exceeded max_deliveries, route to DLQ, then ACK.
        Returns count of poison pills processed. Called once per consumer loop iteration.

        Only the entries returned by a single pending_range() call (sized by
        ``batch_size``) are inspected per invocation.

        NOTE(review): entries are selected by delivery count alone, with no
        idle-time check — confirm an entry still being processed cannot be
        dead-lettered and ACKed here.
        """
        pending = await self._client.pending_range(
            self._stream, self._group, self._batch_size
        )
        poison_ids = [e.id for e in pending if e.delivery_count >= self._max_deliveries]
        if not poison_ids:
            return 0

        if not self._dlq_handler:
            logger.warning(
                "Poison pills detected (%d msgs) but no dlq_handler set — "
                "messages will be ACKed and lost: %s",
                len(poison_ids), poison_ids[:5],
            )

        fetched = await self._client.fetch_by_ids(self._stream, poison_ids)
        fetched_by_id = {m.id: m for m in fetched}

        for msg_id in poison_ids:
            raw_msg = fetched_by_id.get(msg_id)
            if raw_msg is None:
                logger.warning("Poison pill msg %s missing from stream (deleted?) — ACKing anyway", msg_id)
            else:
                decoded = self._try_decode(raw_msg)
                await self.send_to_dlq(decoded, "max_deliveries")

        # Batch ACK after all DLQ attempts — send_to_dlq swallows errors
        await self._client.ack(self._stream, self._group, *poison_ids)
        return len(poison_ids)

    async def send_to_dlq(self, msg: StreamMessage, reason: str) -> None:
        """Route a single message to the DLQ handler. Swallows handler errors.

        Does nothing when no handler is configured. Handler exceptions are
        logged with their traceback and not re-raised.
        """
        if not self._dlq_handler:
            return
        try:
            await self._dlq_handler(msg, reason)
        except Exception as e:
            # logger.exception keeps the traceback, which the plain error
            # message alone does not carry.
            logger.exception("DLQ handler error for msg %s: %s", msg.id, e)

    def _try_decode(self, msg: StreamMessage) -> StreamMessage:
        """Return ``msg`` with a decoded payload when its data is bytes/str.

        Messages whose data is neither bytes nor str are returned unchanged.
        If decoding fails the original message is returned as-is, and the
        failure is logged at debug level.
        """
        raw = msg.data
        if not isinstance(raw, (bytes, str)):
            return msg
        try:
            if isinstance(raw, str):
                raw = raw.encode()
            decoded_data = self._serializer.decode(raw)
            return StreamMessage(id=msg.id, data=decoded_data, delivery_count=msg.delivery_count)
        except Exception as exc:
            # Best effort: an undecodable payload is still forwarded raw, but
            # the reason is recorded instead of being silently dropped.
            logger.debug("Could not decode poison pill msg %s: %s", msg.id, exc)
            return msg
@@ -0,0 +1,45 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import pickle
5
+ from typing import Protocol, runtime_checkable
6
+
7
+
8
@runtime_checkable
class Serializer(Protocol):
    """Structural interface for payload codecs (dict to bytes and back).

    Any object providing both methods matches; being ``runtime_checkable``,
    the protocol can also be used with ``isinstance()``.
    """

    # Turn a message payload into the bytes stored in the stream.
    def encode(self, data: dict) -> bytes: ...
    # Turn stored bytes back into a message payload.
    def decode(self, raw: bytes) -> dict: ...
12
+
13
+
14
class JsonSerializer:
    """Serializer that stores payloads as UTF-8 encoded JSON text."""

    def encode(self, data: dict) -> bytes:
        """Dump ``data`` to JSON (non-ASCII characters kept as-is) and encode it as UTF-8."""
        text = json.dumps(data, ensure_ascii=False)
        return text.encode("utf-8")

    def decode(self, raw: bytes) -> dict:
        """Parse ``raw`` as JSON and return the resulting object."""
        payload = json.loads(raw)
        return payload
20
+
21
+
22
class MsgpackSerializer:
    """Serializer backed by the optional ``msgpack`` package.

    ``msgpack`` is imported lazily inside each method, so this module can be
    imported without it; constructing the serializer checks that it is
    available.
    """

    def __init__(self) -> None:
        """Verify that ``msgpack`` can be imported.

        Raises:
            ImportError: If ``msgpack`` is not installed. The original import
                failure is attached as the exception's cause.
        """
        try:
            import msgpack  # noqa: F401
        except ImportError as err:
            # Chain explicitly so the underlying import failure stays visible
            # as the cause instead of an implicit "during handling" context.
            raise ImportError(
                "msgpack is required: pip install redis-stream-queue[msgpack]"
            ) from err

    def encode(self, data: dict) -> bytes:
        """Pack ``data`` with msgpack (``use_bin_type=True``)."""
        import msgpack
        return msgpack.packb(data, use_bin_type=True)

    def decode(self, raw: bytes) -> dict:
        """Unpack msgpack bytes (``raw=False``)."""
        import msgpack
        return msgpack.unpackb(raw, raw=False)
38
+
39
+
40
class PickleSerializer:
    """Serializer that round-trips payloads through ``pickle``.

    WARNING: ``pickle.loads`` can execute arbitrary code while loading. Use
    this serializer only when every producer writing to the stream is trusted.
    """

    def encode(self, data: dict) -> bytes:
        # Uses pickle's default protocol.
        return pickle.dumps(data)

    def decode(self, raw: bytes) -> dict:
        # SECURITY: unpickling runs code embedded in the payload; ``raw`` must
        # come from a trusted source (lint rule S301 is suppressed on purpose).
        return pickle.loads(raw)  # noqa: S301