lockwatch 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
lockwatch/__init__.py ADDED
@@ -0,0 +1,47 @@
1
+ """
2
+ LockWatch — API security middleware for FastAPI and Flask.
3
+
4
+ Public API surface:
5
+
6
+ from lockwatch import LockWatchMiddleware, RateLimiter, JWTRotationManager
7
+ from lockwatch import AnomalyDetector, AuditLogger
8
+
9
+ Quickstart (FastAPI):
10
+
11
+ from fastapi import FastAPI
12
+ from lockwatch import LockWatchMiddleware
13
+
14
+ app = FastAPI()
15
+ app.add_middleware(
16
+ LockWatchMiddleware,
17
+ redis_url="redis://localhost:6379",
18
+ rate_limit_requests=100,
19
+ rate_limit_window_seconds=60,
20
+ )
21
+ """
22
+
23
+ from lockwatch.rate_limiter import RateLimiter, RateLimitConfig, RateLimitState
24
+ from lockwatch.jwt_rotation import JWTRotationManager, TokenPair, JWTConfig
25
+ from lockwatch.anomaly import AnomalyDetector, AnomalyConfig
26
+ from lockwatch.audit import AuditLogger, AuditLogEntry, AuditQuery
27
+ from lockwatch.middleware import LockWatchMiddleware, LockWatchFlaskMiddleware, flask_jwt_required
28
+
29
# Package version string.
__version__ = "0.1.0"

# Names re-exported by `from lockwatch import *`, in documentation order.
__all__ = [
    # Framework integration — the usual starting point for an application.
    "LockWatchMiddleware", "LockWatchFlaskMiddleware", "flask_jwt_required",
    # Lower-level building blocks for wiring the components together by hand.
    "RateLimiter", "RateLimitConfig", "RateLimitState",
    "JWTRotationManager", "TokenPair", "JWTConfig",
    "AnomalyDetector", "AnomalyConfig",
    "AuditLogger", "AuditLogEntry", "AuditQuery",
]
lockwatch/anomaly.py ADDED
@@ -0,0 +1,161 @@
1
+ """
2
+ IP burst anomaly detector.
3
+
4
+ Uses a separate Redis sorted set per IP (distinct from rate limiter keys) to
5
+ track request timestamps in a short burst window. When count exceeds threshold,
6
+ fires async webhook alerts and records a dedup key so we don't spam on every
7
+ subsequent request from the same IP.
8
+
9
+ Alert dedup: after firing, set Redis key lockwatch:anomaly:alerted:{ip} with
10
+ TTL=60s. Skip alert if key exists. This means at most one alert per IP per minute.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import asyncio
16
+ import logging
17
+ import time
18
+ from dataclasses import dataclass, field
19
+ from datetime import datetime, timezone
20
+ from typing import Any, Optional
21
+
22
+ import httpx
23
+ from redis.asyncio import Redis
24
+
25
logger = logging.getLogger(__name__)


@dataclass
class AnomalyConfig:
    """
    Configuration for burst anomaly detection.

    burst_threshold: number of requests within burst_window_seconds that counts
        as a burst (the detector compares ``count >= burst_threshold``).
    burst_window_seconds: length of the sliding window in seconds; also the TTL
        applied to each per-IP burst key.
    webhook_urls: list of HTTPS endpoints to POST alert payloads to. When empty,
        bursts are still reported to the caller but no alert is sent.
    alert_cooldown_seconds: minimum seconds between alerts for the same IP; used
        as the TTL of the per-IP "already alerted" key.
    key_prefix: namespace prefix for every Redis key the detector writes.

    Raises:
        ValueError: if burst_threshold, burst_window_seconds or
            alert_cooldown_seconds is not positive. A zero window would make
            EXPIRE delete the burst key immediately (so nothing accumulates),
            and Redis rejects a zero expire time when setting the alert key.
    """

    burst_threshold: int = 50
    burst_window_seconds: int = 10
    webhook_urls: list[str] = field(default_factory=list)
    alert_cooldown_seconds: int = 60
    key_prefix: str = "lockwatch:anomaly"

    def __post_init__(self) -> None:
        # Fail fast at construction rather than on the first request.
        for name in ("burst_threshold", "burst_window_seconds", "alert_cooldown_seconds"):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"AnomalyConfig.{name} must be a positive integer, got {value!r}")
43
+
44
+
45
@dataclass
class AnomalyEvent:
    """
    Snapshot of one detected burst.

    AnomalyDetector fills one in when an IP reaches the burst threshold and
    turns it into the JSON body POSTed to each configured webhook URL.
    """

    # Client IP the burst was attributed to
    ip: str
    # Detection time as an ISO-8601 string
    timestamp_utc: str
    # Requests seen from this IP inside the sliding window
    request_count_in_window: int
    # Length of the sliding window, in seconds
    window_seconds: int
    # Path and HTTP method of the request that triggered the alert
    endpoint: str
    method: str
    # Client user agent string, if the caller supplied one
    user_agent: Optional[str] = field(default=None)
56
+
57
+
58
class AnomalyDetector:
    """
    Per-IP burst detector with async webhook alerting.

    Each check() call records one request for the IP in a Redis sorted set and
    reports whether the IP has reached the burst threshold. Alerts are
    deduplicated per IP with an atomic ``SET NX EX`` (at most one alert per IP
    per ``alert_cooldown_seconds``) and are sent from background tasks, so
    webhook latency never reaches the response path.

    Usage:
        detector = AnomalyDetector(redis_client, AnomalyConfig(burst_threshold=50, burst_window_seconds=10))
        is_anomaly = await detector.check(ip="1.2.3.4", endpoint="/api/login", method="POST")
        ...
        await detector.close()  # on application shutdown
    """

    def __init__(self, redis: Redis, config: AnomalyConfig) -> None:
        self._redis = redis
        self._config = config
        # Shared httpx client — created on the first alert, then reused so
        # connections are pooled across dispatches
        self._http: Optional[httpx.AsyncClient] = None
        # Strong references to in-flight alert tasks. The event loop keeps only
        # weak references to tasks, so an otherwise-unreferenced fire-and-forget
        # task can be garbage-collected before it finishes.
        self._background_tasks: set[asyncio.Task[Any]] = set()

    async def check(
        self,
        ip: str,
        endpoint: str,
        method: str,
        user_agent: Optional[str] = None,
    ) -> bool:
        """
        Record a request from ip and return True if a burst anomaly is detected.

        Does NOT block the request — that decision is left to the middleware.
        Alert dispatch runs as a background task so it never adds latency to
        the response path.

        Args:
            ip: client IP; used to build the per-IP Redis keys.
            endpoint, method, user_agent: copied into the alert payload.

        Returns:
            True when the number of requests from ip inside the current window
            (including this one) is >= burst_threshold.
        """
        import uuid

        now_ms = int(time.time() * 1000)
        window_start_ms = now_ms - (self._config.burst_window_seconds * 1000)
        burst_key = self._ip_redis_key(ip)
        # UUID suffix ensures uniqueness within the same millisecond — same reason as rate_limiter
        member = f"{now_ms}:{uuid.uuid4().hex}"

        pipe = self._redis.pipeline()
        pipe.zadd(burst_key, {member: now_ms})
        # Evict entries older than the window, then count what remains
        pipe.zremrangebyscore(burst_key, 0, window_start_ms)
        pipe.zcard(burst_key)
        # An idle IP's key disappears once its window has passed
        pipe.expire(burst_key, self._config.burst_window_seconds)
        results = await pipe.execute()

        count: int = results[2]
        is_burst = count >= self._config.burst_threshold

        if is_burst and self._config.webhook_urls:
            # SET NX is one atomic check-and-set: among concurrent requests from
            # the same bursting IP exactly one gets a truthy reply and alerts.
            # A separate EXISTS followed by SET would let all of them through.
            acquired = await self._redis.set(
                self._alerted_redis_key(ip),
                "1",
                ex=self._config.alert_cooldown_seconds,
                nx=True,
            )
            if acquired:
                event = AnomalyEvent(
                    ip=ip,
                    timestamp_utc=datetime.now(timezone.utc).isoformat(),
                    request_count_in_window=count,
                    window_seconds=self._config.burst_window_seconds,
                    endpoint=endpoint,
                    method=method,
                    user_agent=user_agent,
                )
                # Fire-and-forget — never let webhook delay block a response
                self._spawn(self._dispatch_alert(event))

        return is_burst

    def _spawn(self, coro: Any) -> None:
        """Schedule coro as a background task and keep a reference until it completes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def _dispatch_alert(self, event: AnomalyEvent) -> None:
        """
        POST alert payload to all configured webhook URLs concurrently.

        Failures — transport errors and non-2xx responses — are logged but
        swallowed: a broken webhook should never affect the application.
        Requests use the shared client's 5s httpx timeout.
        """
        if self._http is None:
            self._http = httpx.AsyncClient(timeout=5.0)
        client = self._http

        payload = {
            "ip": event.ip,
            "timestamp": event.timestamp_utc,
            "request_count_in_window": event.request_count_in_window,
            "window_seconds": event.window_seconds,
            "endpoint": event.endpoint,
            "method": event.method,
            "user_agent": event.user_agent,
        }

        async def post_to(url: str) -> None:
            try:
                response = await client.post(url, json=payload)
                # Without this a 4xx/5xx from the receiver counted as delivered
                response.raise_for_status()
            except Exception as exc:
                logger.warning("LockWatch anomaly webhook failed (%s): %s", url, exc)

        await asyncio.gather(*(post_to(url) for url in self._config.webhook_urls))

    async def close(self) -> None:
        """
        Release resources. Call on application shutdown; safe to call twice.

        Waits for in-flight alert dispatches, then closes the shared HTTP
        client and clears it, so a later alert opens a fresh client instead of
        posting through a closed one.
        """
        if self._background_tasks:
            await asyncio.gather(*list(self._background_tasks), return_exceptions=True)
        if self._http is not None:
            await self._http.aclose()
            self._http = None

    def _ip_redis_key(self, ip: str) -> str:
        return f"{self._config.key_prefix}:burst:{ip}"

    def _alerted_redis_key(self, ip: str) -> str:
        return f"{self._config.key_prefix}:alerted:{ip}"
lockwatch/audit.py ADDED
@@ -0,0 +1,179 @@
1
+ """
2
+ Postgres audit log for every request processed by LockWatch middleware.
3
+
4
+ Design choices:
5
+ - Writes are async background tasks (fire-and-forget) so audit logging is
6
+ never on the critical response path. A failed write logs a warning but
7
+ does not affect the response.
8
+ - SQLAlchemy 2.0 async session with asyncpg driver — fastest available Python
9
+ Postgres driver; async avoids blocking the event loop.
10
+ - Indices on (timestamp, user_id) and (timestamp, ip) support the two most
11
+ common audit query patterns: "what did user X do?" and "what came from IP Y?".
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import logging
17
+ from dataclasses import dataclass
18
+ from datetime import datetime, timezone
19
+ from typing import Optional
20
+
21
logger = logging.getLogger(__name__)

# SQLAlchemy imports are conditional — audit is an optional dependency.
# On ImportError the except branch defines a stub AuditLogEntry and sets
# _SQLALCHEMY_AVAILABLE = False; AuditLogger.__init__ checks that flag before
# anything below (_Base, select, create_async_engine, ...) is used.
try:
    from sqlalchemy import Index, String, Boolean, Integer, Float, DateTime, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
    from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, create_async_engine, async_sessionmaker

    class _Base(DeclarativeBase):
        # Module-private declarative base; its metadata is what
        # AuditLogger.init() passes to create_all.
        pass

    class AuditLogEntry(_Base):
        """
        SQLAlchemy model for a single audited request.

        Every request that passes through LockWatchMiddleware gets one row.
        Columns are intentionally wide — storage is cheap, retroactive debugging
        of security incidents is not.
        """

        __tablename__ = "lockwatch_audit_log"

        id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
        # When the request was received. The column is DateTime without a
        # timezone, so values are meant to be naive datetimes in UTC.
        # NOTE(review): asyncpg rejects tz-aware datetimes for a timezone-naive
        # timestamp column — confirm callers strip tzinfo before logging.
        timestamp: Mapped[datetime] = mapped_column(DateTime, nullable=False)
        # Authenticated user identifier if available; NULL for unauthenticated requests
        user_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
        # 45 chars fits the longest textual IPv6 form (IPv4-mapped, e.g.
        # "ffff:ffff:ffff:ffff:ffff:ffff:255.255.255.255"); plain IPv6 needs 39.
        ip: Mapped[str] = mapped_column(String(45), nullable=False)
        endpoint: Mapped[str] = mapped_column(String(2048), nullable=False)
        method: Mapped[str] = mapped_column(String(10), nullable=False)
        status_code: Mapped[int] = mapped_column(Integer, nullable=False)
        # True if this request was counted against a rate limit (even if not blocked)
        rate_limit_hit: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
        # True if anomaly detector fired for this request
        anomaly_detected: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
        # Response latency in milliseconds (measured from middleware entry to exit)
        latency_ms: Mapped[Optional[float]] = mapped_column(Float, nullable=True)

        __table_args__ = (
            # Lookups: "all requests from user X in time range"
            # NOTE(review): timestamp is the leading column; for an equality
            # filter on user_id plus a time range, (user_id, timestamp) is
            # usually the more selective ordering — confirm with query plans.
            Index("ix_audit_timestamp_user", "timestamp", "user_id"),
            # Lookups: "all requests from IP Y in time range" (same caveat)
            Index("ix_audit_timestamp_ip", "timestamp", "ip"),
        )

    _SQLALCHEMY_AVAILABLE = True

except ImportError:
    _SQLALCHEMY_AVAILABLE = False

    # sqlalchemy not installed — provide a stub so the rest of the library
    # can import audit.py without crashing even if [audit] extras are absent
    class AuditLogEntry:  # type: ignore[no-redef]
        """Stub: install lockwatch[audit] to enable Postgres audit logging."""
        pass
76
+
77
+
78
@dataclass
class AuditQuery:
    """
    Filters for querying the audit log.

    Every filter is optional; those that are set are combined with AND.
        user_id: exact match on the authenticated user id.
        ip: exact match on the client IP.
        endpoint_prefix: match endpoints that start with this string.
        from_time / to_time: inclusive bounds on the request timestamp.
        rate_limit_hits_only: only rows with rate_limit_hit set.
        anomalies_only: only rows with anomaly_detected set.
        limit / offset: pagination over results ordered newest first.

    Raises:
        ValueError: if limit or offset is negative. Postgres would reject a
            negative LIMIT/OFFSET at query time; failing here is clearer.
    """

    user_id: Optional[str] = None
    ip: Optional[str] = None
    endpoint_prefix: Optional[str] = None
    from_time: Optional[datetime] = None
    to_time: Optional[datetime] = None
    rate_limit_hits_only: bool = False
    anomalies_only: bool = False
    limit: int = 100
    offset: int = 0

    def __post_init__(self) -> None:
        if self.limit < 0:
            raise ValueError(f"AuditQuery.limit must be >= 0, got {self.limit!r}")
        if self.offset < 0:
            raise ValueError(f"AuditQuery.offset must be >= 0, got {self.offset!r}")
91
+
92
+
93
class AuditLogger:
    """
    Async audit logger backed by Postgres via SQLAlchemy.

    Usage:
        audit = AuditLogger(database_url="postgresql+asyncpg://user:pass@host/db")
        await audit.init()  # creates table if not exists
        await audit.log(entry)
        rows = await audit.query(AuditQuery(ip="1.2.3.4"))
        await audit.close()  # on application shutdown
    """

    def __init__(self, database_url: str) -> None:
        """
        Store the connection URL; nothing connects until init().

        Raises:
            ImportError: if SQLAlchemy could not be imported (the [audit] extra is missing).
        """
        if not _SQLALCHEMY_AVAILABLE:
            raise ImportError(
                "AuditLogger requires SQLAlchemy and asyncpg. "
                "Install with: pip install lockwatch[audit]"
            )
        self._database_url = database_url
        self._engine: Optional[AsyncEngine] = None
        self._session_factory: Optional[async_sessionmaker[AsyncSession]] = None

    async def init(self) -> None:
        """
        Create the audit log table if it doesn't exist.

        Call once at application startup. Idempotent: create_all leaves an
        existing table alone, and a repeated init() on the same instance reuses
        its engine instead of creating (and leaking) a second connection pool.
        """
        if self._engine is None:
            self._engine = create_async_engine(self._database_url, echo=False)
            self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False)
        async with self._engine.begin() as conn:
            await conn.run_sync(_Base.metadata.create_all)

    async def log(self, entry: AuditLogEntry) -> None:
        """
        Persist an audit log entry asynchronously.

        Designed to be called via asyncio.create_task() so it never blocks the
        response. A tz-aware entry.timestamp is converted to naive UTC to fit
        the timezone-naive column. Failures are logged at WARNING level —
        never re-raised.
        """
        if self._session_factory is None:
            logger.warning("AuditLogger.log() called before init() — skipping")
            return
        try:
            if entry.timestamp is not None:
                entry.timestamp = self._to_naive_utc(entry.timestamp)
            async with self._session_factory() as session:
                session.add(entry)
                await session.commit()
        except Exception as exc:
            # Deliberate best-effort: an audit failure must not break the request
            logger.warning("LockWatch audit log write failed: %s", exc)

    async def query(self, filters: AuditQuery) -> list[AuditLogEntry]:
        """
        Query audit log entries matching filters.

        Returns at most filters.limit rows, ordered by timestamp DESC.
        endpoint_prefix is matched literally ("%" and "_" are escaped), and
        tz-aware time bounds are converted to naive UTC before comparison.

        Raises:
            RuntimeError: if init() has not been called.
        """
        if self._session_factory is None:
            raise RuntimeError("AuditLogger not initialized — call await logger.init() first")

        stmt = select(AuditLogEntry)

        if filters.user_id is not None:
            stmt = stmt.where(AuditLogEntry.user_id == filters.user_id)
        if filters.ip is not None:
            stmt = stmt.where(AuditLogEntry.ip == filters.ip)
        if filters.endpoint_prefix is not None:
            # autoescape=True escapes LIKE wildcards, so "/v1_" does not match "/v1x"
            stmt = stmt.where(
                AuditLogEntry.endpoint.startswith(filters.endpoint_prefix, autoescape=True)
            )
        if filters.from_time is not None:
            stmt = stmt.where(AuditLogEntry.timestamp >= self._to_naive_utc(filters.from_time))
        if filters.to_time is not None:
            stmt = stmt.where(AuditLogEntry.timestamp <= self._to_naive_utc(filters.to_time))
        if filters.rate_limit_hits_only:
            stmt = stmt.where(AuditLogEntry.rate_limit_hit == True)  # noqa: E712
        if filters.anomalies_only:
            stmt = stmt.where(AuditLogEntry.anomaly_detected == True)  # noqa: E712

        stmt = stmt.order_by(AuditLogEntry.timestamp.desc())
        stmt = stmt.limit(filters.limit).offset(filters.offset)

        async with self._session_factory() as session:
            result = await session.execute(stmt)
            return list(result.scalars().all())

    async def close(self) -> None:
        """Dispose the SQLAlchemy engine. Call on application shutdown; safe to call twice."""
        if self._engine:
            await self._engine.dispose()
            self._engine = None
            self._session_factory = None

    @staticmethod
    def _to_naive_utc(value: datetime) -> datetime:
        """Convert a tz-aware datetime to naive UTC; return naive values unchanged (assumed UTC)."""
        if value.tzinfo is None or value.utcoffset() is None:
            return value
        return value.astimezone(timezone.utc).replace(tzinfo=None)
lockwatch/jwt_rotation.py ADDED
@@ -0,0 +1,272 @@
1
+ """
2
+ JWT rotation manager: access token (15min RS256) + refresh token (7-day).
3
+
4
+ Security model:
5
+ - Refresh tokens are stored as SHA-256(token) in Redis — raw tokens never
6
+ persisted, so a Redis dump doesn't leak usable tokens.
7
+ - Rotation is single-use: rotate() atomically deletes the old refresh token's
8
+ stored hash (WATCH/MULTI/EXEC) before issuing the new pair. If issuing fails
9
+ after that delete commits, the old token stays spent — rotation fails closed.
10
+ - Access token blacklist: stored in Redis with TTL = remaining token lifetime.
11
+ After expiry the key auto-deletes; no unbounded blacklist growth.
12
+ - JTIs are UUID4 — unguessable, no sequential enumeration possible.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import hashlib
18
+ import uuid
19
+ from dataclasses import dataclass
20
+ from datetime import datetime, timedelta, timezone
21
+ from typing import Any, Optional
22
+
23
+ from jose import jwt, JWTError
24
+ from redis.asyncio import Redis
25
+
26
+
27
@dataclass
class JWTConfig:
    """
    Configuration for the JWT rotation manager.

    Either provide rsa_private_key_pem (PEM string) or a path to a PEM file via
    rsa_private_key_path. The manager reads the key lazily, the first time it
    signs or verifies; if both are set, rsa_private_key_pem wins.

    Attributes:
        issuer / audience: written into the ``iss`` / ``aud`` claims and
            required to match when tokens are decoded.
        access_token_ttl_seconds: access token lifetime (default 15 minutes).
        refresh_token_ttl_seconds: refresh token lifetime, also the TTL of its
            Redis record (default 7 days).
        key_prefix: namespace prefix for the Redis keys the manager writes.

    Raises:
        ValueError: if either TTL is not positive. A non-positive TTL would
            mint already-expired tokens, and Redis rejects a non-positive
            expire time when the refresh-token record is stored.
    """

    rsa_private_key_pem: Optional[str] = None
    rsa_private_key_path: Optional[str] = None
    issuer: str = "lockwatch"
    audience: str = "lockwatch-api"
    access_token_ttl_seconds: int = 900  # 15 minutes
    refresh_token_ttl_seconds: int = 604800  # 7 days
    key_prefix: str = "lockwatch:jwt"

    def __post_init__(self) -> None:
        for name in ("access_token_ttl_seconds", "refresh_token_ttl_seconds"):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"JWTConfig.{name} must be a positive integer, got {value!r}")
43
+
44
+
45
@dataclass
class TokenPair:
    """
    Result of issuing or rotating tokens.

    Holds a signed access token plus the refresh token that can later be
    exchanged (once) for a new pair.
    """

    # Signed access JWT (15-minute lifetime by default)
    access_token: str
    # Signed refresh JWT; consumed by JWTRotationManager.rotate()
    refresh_token: str
    # Timezone-aware UTC expiry instants corresponding to each token's exp claim
    access_expires_at: datetime
    refresh_expires_at: datetime
    # Authorization scheme to send the access token with
    token_type: str = "Bearer"
54
+
55
+
56
class JWTRotationManager:
    """
    Issues RS256 JWTs with automatic rotation and Redis-backed blacklisting.

    Usage:
        manager = JWTRotationManager(redis_client, JWTConfig(rsa_private_key_pem=pem))
        pair = await manager.issue_token_pair(subject="user:42", claims={"roles": ["admin"]})
        # ... on refresh endpoint ...
        new_pair = await manager.rotate(incoming_refresh_token)
    """

    # Claims the manager sets itself. Callers may not override them through
    # `claims`: they drive expiry (iat/exp), revocation (jti), access/refresh
    # separation (type) and validation (sub/iss/aud).
    _RESERVED_CLAIMS = frozenset({"sub", "iss", "aud", "iat", "exp", "jti", "type"})

    def __init__(self, redis: Redis, config: JWTConfig) -> None:
        self._redis = redis
        self._config = config
        self._private_key: Optional[str] = None  # loaded lazily
        self._public_key: Optional[str] = None  # derived lazily from the private key, then cached

    def _load_private_key(self) -> str:
        """Return the private key PEM, taken from config (inline PEM first, then file) on first use and cached."""
        if self._private_key:
            return self._private_key
        if self._config.rsa_private_key_pem:
            self._private_key = self._config.rsa_private_key_pem
        elif self._config.rsa_private_key_path:
            with open(self._config.rsa_private_key_path) as f:
                self._private_key = f.read()
        else:
            raise ValueError("JWTConfig requires either rsa_private_key_pem or rsa_private_key_path")
        return self._private_key

    def _extract_public_key(self) -> str:
        """
        Return the RSA public key (PEM, SubjectPublicKeyInfo) for verification.

        Derived once from the private key and cached: re-parsing the private
        PEM on every verify/rotate call put an RSA key load on each request.
        """
        if self._public_key is None:
            from cryptography.hazmat.primitives.serialization import (
                load_pem_private_key,
                Encoding,
                PublicFormat,
            )

            private_key = load_pem_private_key(self._load_private_key().encode(), password=None)
            self._public_key = (
                private_key.public_key()
                .public_bytes(Encoding.PEM, PublicFormat.SubjectPublicKeyInfo)
                .decode()
            )
        return self._public_key

    async def issue_token_pair(
        self, subject: str, claims: Optional[dict[str, Any]] = None
    ) -> TokenPair:
        """
        Issue a new access + refresh token pair for a subject.

        Generates a UUID4 JTI for each token so they're individually revocable.
        Stores SHA-256(refresh_token) in Redis with refresh_token_ttl_seconds TTL.

        Args:
            subject: value of the ``sub`` claim in both tokens.
            claims: extra claims added to the access token only; the refresh
                token does not carry them.

        Raises:
            ValueError: if claims tries to set a reserved claim, or no private
                key is configured.
        """
        if claims:
            clashing = self._RESERVED_CLAIMS.intersection(claims)
            if clashing:
                raise ValueError(f"claims may not override reserved JWT claims: {sorted(clashing)}")

        private_key = self._load_private_key()
        now = datetime.now(timezone.utc)

        access_jti = str(uuid.uuid4())
        access_expires_at = now + timedelta(seconds=self._config.access_token_ttl_seconds)
        access_payload: dict[str, Any] = {
            "sub": subject,
            "iss": self._config.issuer,
            "aud": self._config.audience,
            "iat": int(now.timestamp()),
            "exp": int(access_expires_at.timestamp()),
            "jti": access_jti,
            "type": "access",
        }
        if claims:
            access_payload.update(claims)

        access_token = jwt.encode(access_payload, private_key, algorithm="RS256")

        refresh_jti = str(uuid.uuid4())
        refresh_expires_at = now + timedelta(seconds=self._config.refresh_token_ttl_seconds)
        refresh_payload: dict[str, Any] = {
            "sub": subject,
            "iss": self._config.issuer,
            "aud": self._config.audience,
            "iat": int(now.timestamp()),
            "exp": int(refresh_expires_at.timestamp()),
            "jti": refresh_jti,
            "type": "refresh",
        }

        refresh_token = jwt.encode(refresh_payload, private_key, algorithm="RS256")

        # Store hash so Redis never holds a usable token — hash lookup on rotate()
        token_hash = self._hash_token(refresh_token)
        redis_key = self._refresh_redis_key(token_hash)
        await self._redis.set(redis_key, refresh_jti, ex=self._config.refresh_token_ttl_seconds)

        return TokenPair(
            access_token=access_token,
            refresh_token=refresh_token,
            access_expires_at=access_expires_at,
            refresh_expires_at=refresh_expires_at,
        )

    async def rotate(self, refresh_token: str) -> TokenPair:
        """
        Validate a refresh token, atomically consume its stored hash, and issue a new pair.

        Uses WATCH/MULTI/EXEC so two concurrent callers presenting the same
        refresh token cannot both consume it: if the key changes between WATCH
        and EXEC, WatchError is raised and the loop retries, and the retry's
        GET returns None → ValueError.

        The old hash is deleted *before* the new pair is issued, so if issuing
        fails after that delete commits, the old refresh token stays spent and
        the client must re-authenticate (fail closed).

        NOTE(review): the new pair is issued with ``subject`` only — custom
        claims given to the original issue_token_pair() (e.g. roles) are not
        stored in the refresh token and therefore do not survive rotation.

        Raises:
            ValueError: bad signature/expiry/issuer/audience, not a refresh
                token, missing claims, already used or revoked, or blacklisted JTI.
        """
        public_key = self._extract_public_key()

        try:
            claims = jwt.decode(
                refresh_token,
                public_key,
                algorithms=["RS256"],
                audience=self._config.audience,
                issuer=self._config.issuer,
            )
        except JWTError as e:
            raise ValueError(f"Invalid refresh token: {e}") from e

        if claims.get("type") != "refresh":
            raise ValueError("Token is not a refresh token")

        jti = claims.get("jti")
        subject = claims.get("sub")
        if not jti or not subject:
            raise ValueError("Refresh token missing required claims")

        token_hash = self._hash_token(refresh_token)
        refresh_key = self._refresh_redis_key(token_hash)
        blacklist_key = self._blacklist_redis_key(jti)

        from redis.exceptions import WatchError

        async with self._redis.pipeline() as pipe:
            while True:
                try:
                    await pipe.watch(refresh_key)
                    # pipe is now in immediate-execution mode (not buffered)
                    stored_jti = await pipe.get(refresh_key)
                    if not stored_jti:
                        await pipe.unwatch()
                        raise ValueError("Refresh token has already been used or revoked")

                    is_blacklisted = await pipe.exists(blacklist_key)
                    if is_blacklisted:
                        await pipe.unwatch()
                        raise ValueError("Refresh token JTI is blacklisted")

                    pipe.multi()  # switch to buffered mode
                    pipe.delete(refresh_key)
                    await pipe.execute()  # atomic; raises WatchError if key was changed
                    break
                except WatchError:
                    continue  # concurrent rotation: retry sees empty key → ValueError

        return await self.issue_token_pair(subject=subject)

    async def blacklist_access_token(self, token: str) -> None:
        """
        Revoke a token before its natural expiry.

        The blacklist key's TTL is the token's remaining lifetime, so the key
        auto-expires once the token would have expired anyway. The token
        ``type`` is not checked, so a refresh token can be revoked this way
        too — rotate() consults the same blacklist.

        Raises:
            ValueError: if the token fails verification or lacks jti/exp.
        """
        import math

        public_key = self._extract_public_key()

        try:
            claims = jwt.decode(
                token,
                public_key,
                algorithms=["RS256"],
                audience=self._config.audience,
                issuer=self._config.issuer,
            )
        except JWTError as e:
            raise ValueError(f"Cannot blacklist invalid token: {e}") from e

        jti = claims.get("jti")
        exp = claims.get("exp")
        if not jti or not exp:
            raise ValueError("Token missing jti or exp claim")

        # Round up: int() truncation turned a sub-second remainder into 0,
        # which skipped the blacklist write while the token was still valid.
        remaining_ttl = math.ceil(exp - datetime.now(timezone.utc).timestamp())
        if remaining_ttl > 0:
            await self._redis.set(self._blacklist_redis_key(jti), "1", ex=remaining_ttl)

    async def verify_access_token(self, token: str) -> dict[str, Any]:
        """
        Verify signature, expiry, issuer, audience, type and blacklist status.

        Returns decoded claims dict on success. Raises jose.JWTError on failure,
        including for tokens without a jti (they could never be revoked).
        The blacklist check is a single Redis EXISTS per call.
        """
        public_key = self._extract_public_key()

        claims = jwt.decode(
            token,
            public_key,
            algorithms=["RS256"],
            audience=self._config.audience,
            issuer=self._config.issuer,
        )

        if claims.get("type") != "access":
            raise JWTError("Token is not an access token")

        jti = claims.get("jti")
        if not jti:
            raise JWTError("Token missing jti claim")
        if await self._redis.exists(self._blacklist_redis_key(jti)):
            raise JWTError("Token has been revoked")

        return claims

    def _hash_token(self, token: str) -> str:
        # SHA-256 hex digest — deterministic, non-reversible
        return hashlib.sha256(token.encode()).hexdigest()

    def _refresh_redis_key(self, token_hash: str) -> str:
        return f"{self._config.key_prefix}:refresh:{token_hash}"

    def _blacklist_redis_key(self, jti: str) -> str:
        return f"{self._config.key_prefix}:blacklist:{jti}"