lockwatch 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lockwatch/__init__.py +47 -0
- lockwatch/anomaly.py +161 -0
- lockwatch/audit.py +179 -0
- lockwatch/jwt_rotation.py +272 -0
- lockwatch/middleware.py +429 -0
- lockwatch/rate_limiter.py +203 -0
- lockwatch-0.1.0.dist-info/METADATA +309 -0
- lockwatch-0.1.0.dist-info/RECORD +9 -0
- lockwatch-0.1.0.dist-info/WHEEL +4 -0
lockwatch/__init__.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
"""
|
|
2
|
+
LockWatch — API security middleware for FastAPI and Flask.
|
|
3
|
+
|
|
4
|
+
Public API surface:
|
|
5
|
+
|
|
6
|
+
from lockwatch import LockWatchMiddleware, RateLimiter, JWTRotationManager
|
|
7
|
+
from lockwatch import AnomalyDetector, AuditLogger
|
|
8
|
+
|
|
9
|
+
Quickstart (FastAPI):
|
|
10
|
+
|
|
11
|
+
from fastapi import FastAPI
|
|
12
|
+
from lockwatch import LockWatchMiddleware
|
|
13
|
+
|
|
14
|
+
app = FastAPI()
|
|
15
|
+
app.add_middleware(
|
|
16
|
+
LockWatchMiddleware,
|
|
17
|
+
redis_url="redis://localhost:6379",
|
|
18
|
+
rate_limit_requests=100,
|
|
19
|
+
rate_limit_window_seconds=60,
|
|
20
|
+
)
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from lockwatch.rate_limiter import RateLimiter, RateLimitConfig, RateLimitState
|
|
24
|
+
from lockwatch.jwt_rotation import JWTRotationManager, TokenPair, JWTConfig
|
|
25
|
+
from lockwatch.anomaly import AnomalyDetector, AnomalyConfig
|
|
26
|
+
from lockwatch.audit import AuditLogger, AuditLogEntry, AuditQuery
|
|
27
|
+
from lockwatch.middleware import LockWatchMiddleware, LockWatchFlaskMiddleware, flask_jwt_required
|
|
28
|
+
|
|
29
|
+
__version__ = "0.1.0"

# Names re-exported by ``from lockwatch import *``, grouped by component.
__all__ = [
    # -- middleware: the usual entrypoint --
    "LockWatchMiddleware", "LockWatchFlaskMiddleware", "flask_jwt_required",
    # -- rate limiting (for custom wiring) --
    "RateLimiter", "RateLimitConfig", "RateLimitState",
    # -- JWT issuance / rotation --
    "JWTRotationManager", "TokenPair", "JWTConfig",
    # -- anomaly detection --
    "AnomalyDetector", "AnomalyConfig",
    # -- audit logging --
    "AuditLogger", "AuditLogEntry", "AuditQuery",
]
|
lockwatch/anomaly.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
"""
|
|
2
|
+
IP burst anomaly detector.
|
|
3
|
+
|
|
4
|
+
Uses a separate Redis sorted set per IP (distinct from rate limiter keys) to
|
|
5
|
+
track request timestamps in a short burst window. When count exceeds threshold,
|
|
6
|
+
fires async webhook alerts and records a dedup key so we don't spam on every
|
|
7
|
+
subsequent request from the same IP.
|
|
8
|
+
|
|
9
|
+
Alert dedup: after firing, set Redis key lockwatch:anomaly:alerted:{ip} with
|
|
10
|
+
TTL=alert_cooldown_seconds (default 60s). Skip alert if key exists. This means at most one alert per IP per cooldown period.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import asyncio
|
|
16
|
+
import logging
|
|
17
|
+
import time
|
|
18
|
+
from dataclasses import dataclass, field
|
|
19
|
+
from datetime import datetime, timezone
|
|
20
|
+
from typing import Any, Optional
|
|
21
|
+
|
|
22
|
+
import httpx
|
|
23
|
+
from redis.asyncio import Redis
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class AnomalyConfig:
    """
    Tuning knobs for the per-IP burst detector.

    Attributes:
        burst_threshold: request count inside ``burst_window_seconds`` at or
            above which an IP is reported as bursting.
        burst_window_seconds: length of the sliding window, in seconds.
        webhook_urls: endpoints that receive a JSON POST when an alert fires;
            with an empty list no alerts are sent.
        alert_cooldown_seconds: minimum gap between two alerts for one IP.
        key_prefix: namespace prepended to every Redis key the detector writes.
    """

    burst_threshold: int = field(default=50)
    burst_window_seconds: int = field(default=10)
    # default_factory gives each instance its own list (no shared mutable default)
    webhook_urls: list[str] = field(default_factory=list)
    alert_cooldown_seconds: int = field(default=60)
    key_prefix: str = field(default="lockwatch:anomaly")
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
class AnomalyEvent:
    """
    Details of one detected burst, serialized into the webhook alert body.

    Attributes:
        ip: client address that triggered the burst.
        timestamp_utc: ISO-8601 string for when the burst was detected.
        request_count_in_window: requests counted inside the window.
        window_seconds: width of that window.
        endpoint: path of the request that tipped the count over the threshold.
        method: HTTP method of that request.
        user_agent: its User-Agent header, when one was supplied.
    """

    ip: str
    timestamp_utc: str
    request_count_in_window: int
    window_seconds: int
    endpoint: str
    method: str
    user_agent: Optional[str] = field(default=None)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class AnomalyDetector:
    """
    Per-IP burst detector with async webhook alerting.

    Every call to ``check`` records the request in a per-IP Redis sorted set,
    trims entries older than ``burst_window_seconds`` and compares the
    remaining count with ``burst_threshold``. Alerts are deduplicated per IP
    for ``alert_cooldown_seconds``.

    Usage:
        detector = AnomalyDetector(redis_client, AnomalyConfig(burst_threshold=50, burst_window_seconds=10))
        is_anomaly = await detector.check(ip="1.2.3.4", endpoint="/api/login", method="POST")
        ...
        await detector.close()  # on application shutdown
    """

    def __init__(self, redis: Redis, config: AnomalyConfig) -> None:
        self._redis = redis
        self._config = config
        # Shared httpx client — created on the first alert, reused across dispatches.
        self._http: Optional[httpx.AsyncClient] = None
        # Strong references to in-flight alert tasks. The event loop keeps only
        # weak references to tasks, so an unreferenced fire-and-forget task can
        # be garbage-collected before it finishes. Each task removes itself here
        # when done.
        self._alert_tasks: set[asyncio.Task[None]] = set()

    async def check(
        self,
        ip: str,
        endpoint: str,
        method: str,
        user_agent: Optional[str] = None,
    ) -> bool:
        """
        Record a request from ``ip`` and report whether it is part of a burst.

        Does NOT block the request — that decision is left to the middleware.
        Alert dispatch runs as a background task (asyncio.create_task) so it
        never adds latency to the response path.

        Args:
            ip: client address; used in the Redis keys and the alert payload.
            endpoint: request path, copied into the alert payload.
            method: HTTP method, copied into the alert payload.
            user_agent: optional User-Agent, copied into the alert payload.

        Returns:
            True when the number of requests from ``ip`` inside the window
            (including this one) is at least ``burst_threshold``.
        """
        import uuid

        now_ms = int(time.time() * 1000)
        window_start_ms = now_ms - (self._config.burst_window_seconds * 1000)
        burst_key = self._ip_redis_key(ip)

        pipe = self._redis.pipeline()
        # UUID suffix keeps members unique when several requests share a millisecond.
        member = f"{now_ms}:{uuid.uuid4().hex}"
        pipe.zadd(burst_key, {member: now_ms})
        pipe.zremrangebyscore(burst_key, 0, window_start_ms)
        pipe.zcard(burst_key)
        pipe.expire(burst_key, self._config.burst_window_seconds)
        results = await pipe.execute()

        count: int = results[2]  # ZCARD, taken after the trim
        is_burst = count >= self._config.burst_threshold

        if is_burst and self._config.webhook_urls:
            # SET ... NX EX is one atomic command: only the request that
            # creates the dedup key dispatches an alert. A separate EXISTS
            # check followed by SET would let concurrent requests from the
            # same IP both observe "not alerted yet" and both fire.
            claimed = await self._redis.set(
                self._alerted_redis_key(ip),
                "1",
                ex=self._config.alert_cooldown_seconds,
                nx=True,
            )
            if claimed:
                event = AnomalyEvent(
                    ip=ip,
                    timestamp_utc=datetime.now(timezone.utc).isoformat(),
                    request_count_in_window=count,
                    window_seconds=self._config.burst_window_seconds,
                    endpoint=endpoint,
                    method=method,
                    user_agent=user_agent,
                )
                # Fire-and-forget — never let webhook delay block a response,
                # but keep a reference so the task is not collected mid-flight.
                task = asyncio.create_task(self._dispatch_alert(event))
                self._alert_tasks.add(task)
                task.add_done_callback(self._alert_tasks.discard)

        return is_burst

    async def _dispatch_alert(self, event: AnomalyEvent) -> None:
        """
        POST the alert payload to all configured webhook URLs concurrently.

        Failures (transport errors and non-2xx responses) are logged at WARNING
        and swallowed — a broken webhook should never affect the application.
        Each request uses the client's 5s timeout.
        """
        if self._http is None:
            self._http = httpx.AsyncClient(timeout=5.0)
        client = self._http

        payload = {
            "ip": event.ip,
            "timestamp": event.timestamp_utc,
            "request_count_in_window": event.request_count_in_window,
            "window_seconds": event.window_seconds,
            "endpoint": event.endpoint,
            "method": event.method,
            "user_agent": event.user_agent,
        }

        async def post_to(url: str) -> None:
            try:
                response = await client.post(url, json=payload)
                # httpx does not raise on error status codes by itself; without
                # this a rejected alert would look like a successful delivery.
                response.raise_for_status()
            except Exception as exc:
                logger.warning("LockWatch anomaly webhook failed (%s): %s", url, exc)

        await asyncio.gather(*[post_to(url) for url in self._config.webhook_urls])

    async def close(self) -> None:
        """
        Wait for in-flight alerts, then close the shared HTTP client.

        Call on application shutdown. The client reference is cleared, so a
        later alert creates a fresh client instead of reusing a closed one.
        """
        if self._alert_tasks:
            # Alert tasks swallow their own errors; return_exceptions guards
            # against cancellation surfacing here.
            await asyncio.gather(*list(self._alert_tasks), return_exceptions=True)
        if self._http is not None:
            await self._http.aclose()
            self._http = None

    def _ip_redis_key(self, ip: str) -> str:
        """Redis sorted-set key holding recent request timestamps for ``ip``."""
        return f"{self._config.key_prefix}:burst:{ip}"

    def _alerted_redis_key(self, ip: str) -> str:
        """Redis key whose presence suppresses further alerts for ``ip``."""
        return f"{self._config.key_prefix}:alerted:{ip}"
|
lockwatch/audit.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Postgres audit log for every request processed by LockWatch middleware.
|
|
3
|
+
|
|
4
|
+
Design choices:
|
|
5
|
+
- Writes are async background tasks (fire-and-forget) so audit logging is
|
|
6
|
+
never on the critical response path. A failed write logs a warning but
|
|
7
|
+
does not affect the response.
|
|
8
|
+
- SQLAlchemy 2.0 async session with asyncpg driver — fastest available Python
|
|
9
|
+
Postgres driver; async avoids blocking the event loop.
|
|
10
|
+
- Indices on (timestamp, user_id) and (timestamp, ip) support the two most
|
|
11
|
+
common audit query patterns: "what did user X do?" and "what came from IP Y?".
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
from dataclasses import dataclass
|
|
18
|
+
from datetime import datetime, timezone
|
|
19
|
+
from typing import Optional
|
|
20
|
+
|
|
21
|
+
logger = logging.getLogger(__name__)
|
|
22
|
+
|
|
23
|
+
# SQLAlchemy imports are conditional — audit is an optional dependency
|
|
24
|
+
try:
    from sqlalchemy import Index, String, Boolean, Integer, Float, DateTime, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
    from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, create_async_engine, async_sessionmaker

    class _Base(DeclarativeBase):
        """Private declarative base; AuditLogger.init() runs create_all on its metadata."""

    class AuditLogEntry(_Base):
        """
        ORM row describing one request handled by LockWatchMiddleware.

        The schema is deliberately wide: extra columns are cheap, while context
        that was never recorded cannot be recovered during a security incident.
        """

        __tablename__ = "lockwatch_audit_log"

        id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
        # Request receipt time. The column is a naive DateTime (no timezone);
        # NOTE(review): values are presumably UTC without tzinfo — confirm at call sites.
        timestamp: Mapped[datetime] = mapped_column(DateTime, nullable=False)
        # Authenticated user id, or NULL for anonymous requests.
        user_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
        # 45 chars fits the longest textual IP address: an IPv4-mapped IPv6
        # address such as ffff:ffff:ffff:ffff:ffff:ffff:255.255.255.255.
        ip: Mapped[str] = mapped_column(String(45), nullable=False)
        endpoint: Mapped[str] = mapped_column(String(2048), nullable=False)
        method: Mapped[str] = mapped_column(String(10), nullable=False)
        status_code: Mapped[int] = mapped_column(Integer, nullable=False)
        # Set when the request counted against a rate limit, blocked or not.
        rate_limit_hit: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
        # Set when the anomaly detector flagged this request.
        anomaly_detected: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
        # Middleware entry-to-exit latency in milliseconds.
        latency_ms: Mapped[Optional[float]] = mapped_column(Float, nullable=True)

        __table_args__ = (
            # "everything user X did in a time range"
            Index("ix_audit_timestamp_user", "timestamp", "user_id"),
            # "everything IP Y sent in a time range"
            Index("ix_audit_timestamp_ip", "timestamp", "ip"),
        )

    _SQLALCHEMY_AVAILABLE = True

except ImportError:
    # The optional [audit] extra is missing. Keep a placeholder class so the
    # module (and the package) still imports; AuditLogger refuses to be
    # constructed in this state.
    _SQLALCHEMY_AVAILABLE = False

    class AuditLogEntry:  # type: ignore[no-redef]
        """Stub: install lockwatch[audit] to enable Postgres audit logging."""
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@dataclass
class AuditQuery:
    """
    Filter set for AuditLogger.query(); every criterion is optional.

    Attributes:
        user_id: exact match on the user id column.
        ip: exact match on the client address.
        endpoint_prefix: match endpoints starting with this string.
        from_time: inclusive lower bound on the timestamp.
        to_time: inclusive upper bound on the timestamp.
        rate_limit_hits_only: keep only rows flagged as rate-limit hits.
        anomalies_only: keep only rows flagged by the anomaly detector.
        limit: maximum number of rows returned (page size).
        offset: number of rows skipped before the page starts.
    """

    # --- row filters (None means "don't filter on this") ---
    user_id: Optional[str] = None
    ip: Optional[str] = None
    endpoint_prefix: Optional[str] = None
    from_time: Optional[datetime] = None
    to_time: Optional[datetime] = None
    # --- flag filters ---
    rate_limit_hits_only: bool = False
    anomalies_only: bool = False
    # --- pagination ---
    limit: int = 100
    offset: int = 0
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class AuditLogger:
    """
    Async audit logger backed by Postgres via SQLAlchemy.

    Usage:
        audit = AuditLogger(database_url="postgresql+asyncpg://user:pass@host/db")
        await audit.init()  # creates table if not exists
        await audit.log(entry)
    """

    def __init__(self, database_url: str) -> None:
        """
        Args:
            database_url: SQLAlchemy async database URL.

        Raises:
            ImportError: if the optional SQLAlchemy dependency is not importable.
        """
        if not _SQLALCHEMY_AVAILABLE:
            raise ImportError(
                "AuditLogger requires SQLAlchemy and asyncpg. "
                "Install with: pip install lockwatch[audit]"
            )
        self._database_url = database_url
        self._engine: Optional[AsyncEngine] = None
        self._session_factory: Optional[async_sessionmaker[AsyncSession]] = None

    async def init(self) -> None:
        """
        Create the engine (once) and the audit log table if it doesn't exist.

        Call at application startup. Safe to call repeatedly: an engine created
        by an earlier call is reused rather than replaced, so repeated calls do
        not leave an undisposed connection pool behind, and ``create_all``
        skips tables that already exist.
        """
        if self._engine is None:
            self._engine = create_async_engine(self._database_url, echo=False)
            self._session_factory = async_sessionmaker(self._engine, expire_on_commit=False)
        async with self._engine.begin() as conn:
            await conn.run_sync(_Base.metadata.create_all)

    async def log(self, entry: AuditLogEntry) -> None:
        """
        Persist an audit log entry asynchronously.

        Designed to be called via asyncio.create_task() so it never blocks the
        response. Failures are logged at WARNING level — never re-raised. If
        ``init()`` has not run, the entry is dropped with a warning.
        """
        if self._session_factory is None:
            logger.warning("AuditLogger.log() called before init() — skipping")
            return
        try:
            async with self._session_factory() as session:
                session.add(entry)
                await session.commit()
        except Exception as exc:
            logger.warning("LockWatch audit log write failed: %s", exc)

    async def query(self, filters: AuditQuery) -> list[AuditLogEntry]:
        """
        Query audit log entries matching ``filters``.

        Returns at most ``filters.limit`` rows (after skipping
        ``filters.offset``), ordered by timestamp DESC.

        Raises:
            RuntimeError: if ``init()`` has not been called.
        """
        if self._session_factory is None:
            raise RuntimeError("AuditLogger not initialized — call await logger.init() first")

        stmt = select(AuditLogEntry)

        if filters.user_id is not None:
            stmt = stmt.where(AuditLogEntry.user_id == filters.user_id)
        if filters.ip is not None:
            stmt = stmt.where(AuditLogEntry.ip == filters.ip)
        if filters.endpoint_prefix is not None:
            # autoescape makes '%' and '_' in the prefix match literally;
            # interpolating the prefix into a LIKE pattern would treat them
            # as wildcards (e.g. "/api/user_" would also match "/api/userX").
            stmt = stmt.where(
                AuditLogEntry.endpoint.startswith(filters.endpoint_prefix, autoescape=True)
            )
        if filters.from_time is not None:
            stmt = stmt.where(AuditLogEntry.timestamp >= filters.from_time)
        if filters.to_time is not None:
            stmt = stmt.where(AuditLogEntry.timestamp <= filters.to_time)
        if filters.rate_limit_hits_only:
            stmt = stmt.where(AuditLogEntry.rate_limit_hit == True)  # noqa: E712
        if filters.anomalies_only:
            stmt = stmt.where(AuditLogEntry.anomaly_detected == True)  # noqa: E712

        stmt = stmt.order_by(AuditLogEntry.timestamp.desc())
        stmt = stmt.limit(filters.limit).offset(filters.offset)

        async with self._session_factory() as session:
            result = await session.execute(stmt)
            return list(result.scalars().all())

    async def close(self) -> None:
        """Dispose the SQLAlchemy engine. Call on application shutdown."""
        if self._engine:
            await self._engine.dispose()
            self._engine = None
            self._session_factory = None
|
|
@@ -0,0 +1,272 @@
|
|
|
1
|
+
"""
|
|
2
|
+
JWT rotation manager: access token (15min RS256) + refresh token (7-day).
|
|
3
|
+
|
|
4
|
+
Security model:
|
|
5
|
+
- Refresh tokens are stored as SHA-256(token) in Redis — raw tokens never
|
|
6
|
+
persisted, so a Redis dump doesn't leak usable tokens.
|
|
7
|
+
- Rotation is single-use: the old refresh token's stored hash is deleted in a
|
|
8
|
+
WATCH/MULTI/EXEC transaction before the new pair is issued. If issuing then
|
|
9
|
+
fails, the old refresh token is already consumed — rotation fails closed.
|
|
10
|
+
- Access token blacklist: stored in Redis with TTL = remaining token lifetime.
|
|
11
|
+
After expiry the key auto-deletes; no unbounded blacklist growth.
|
|
12
|
+
- JTIs are UUID4 — unguessable, no sequential enumeration possible.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import hashlib
|
|
18
|
+
import uuid
|
|
19
|
+
from dataclasses import dataclass
|
|
20
|
+
from datetime import datetime, timedelta, timezone
|
|
21
|
+
from typing import Any, Optional
|
|
22
|
+
|
|
23
|
+
from jose import jwt, JWTError
|
|
24
|
+
from redis.asyncio import Redis
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class JWTConfig:
    """
    Settings for JWTRotationManager.

    Supply the RSA private key either inline (``rsa_private_key_pem``) or as a
    file path (``rsa_private_key_path``). The inline PEM wins when both are set.
    The manager reads the key lazily, on first use.

    Attributes:
        issuer / audience: values written to and required in the ``iss`` / ``aud`` claims.
        access_token_ttl_seconds: access token lifetime (default 15 minutes).
        refresh_token_ttl_seconds: refresh token lifetime (default 7 days).
        key_prefix: namespace for the manager's Redis keys.
    """

    # key material (one of the two is required)
    rsa_private_key_pem: Optional[str] = None
    rsa_private_key_path: Optional[str] = None
    # token identity
    issuer: str = "lockwatch"
    audience: str = "lockwatch-api"
    # lifetimes
    access_token_ttl_seconds: int = 15 * 60
    refresh_token_ttl_seconds: int = 7 * 24 * 60 * 60
    # storage
    key_prefix: str = "lockwatch:jwt"
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
@dataclass
class TokenPair:
    """
    Result of issuing or rotating tokens.

    Attributes:
        access_token / refresh_token: the encoded JWT strings.
        access_expires_at / refresh_expires_at: UTC expiry of each token.
        token_type: authorization scheme to present the access token with.
    """

    access_token: str
    refresh_token: str
    access_expires_at: datetime
    refresh_expires_at: datetime
    token_type: str = "Bearer"
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class JWTRotationManager:
    """
    Issues RS256 JWTs with automatic rotation and Redis-backed blacklisting.

    Usage:
        manager = JWTRotationManager(redis_client, JWTConfig(rsa_private_key_pem=pem))
        pair = await manager.issue_token_pair(subject="user:42", claims={"roles": ["admin"]})
        # ... on refresh endpoint ...
        new_pair = await manager.rotate(incoming_refresh_token, claims={"roles": ["admin"]})

    Custom claims are embedded only in the access token, never in the refresh
    token, so ``rotate`` re-issues them only when they are passed again.
    """

    def __init__(self, redis: Redis, config: JWTConfig) -> None:
        self._redis = redis
        self._config = config
        self._private_key: Optional[str] = None  # loaded lazily
        # Derived from the private key on first use and cached; re-parsing the
        # PEM on every verification would be pure overhead.
        self._public_key: Optional[str] = None

    def _load_private_key(self) -> str:
        """
        Return the RSA private key PEM, reading it from config on first use.

        Raises:
            ValueError: if neither rsa_private_key_pem nor rsa_private_key_path is set.
            OSError: if rsa_private_key_path cannot be read.
        """
        if self._private_key:
            return self._private_key
        if self._config.rsa_private_key_pem:
            self._private_key = self._config.rsa_private_key_pem
        elif self._config.rsa_private_key_path:
            with open(self._config.rsa_private_key_path) as f:
                self._private_key = f.read()
        else:
            raise ValueError("JWTConfig requires either rsa_private_key_pem or rsa_private_key_path")
        return self._private_key

    def _extract_public_key(self) -> str:
        """Return the PEM (SubjectPublicKeyInfo) public key for the private key; cached."""
        if self._public_key is None:
            from cryptography.hazmat.primitives.serialization import (
                Encoding,
                PublicFormat,
                load_pem_private_key,
            )

            private_key = load_pem_private_key(self._load_private_key().encode(), password=None)
            self._public_key = (
                private_key.public_key()
                .public_bytes(Encoding.PEM, PublicFormat.SubjectPublicKeyInfo)
                .decode()
            )
        return self._public_key

    def _decode(self, token: str) -> dict[str, Any]:
        """
        Verify signature, expiry, issuer and audience; return the claims.

        Raises:
            JWTError: on any verification failure.
        """
        return jwt.decode(
            token,
            self._extract_public_key(),
            algorithms=["RS256"],
            audience=self._config.audience,
            issuer=self._config.issuer,
        )

    async def issue_token_pair(
        self, subject: str, claims: Optional[dict[str, Any]] = None
    ) -> TokenPair:
        """
        Issue a new access + refresh token pair for a subject.

        Generates a UUID4 JTI for each token so they're individually revocable.
        Stores SHA-256(refresh_token) in Redis with refresh_token_ttl_seconds TTL.

        ``claims`` are merged into the access token only. They cannot override
        the claims this manager sets (sub, iss, aud, iat, exp, jti, type), which
        are applied last. Otherwise a caller-supplied ``exp`` or ``type`` could
        extend a token's lifetime or change its kind.
        """
        private_key = self._load_private_key()
        now = datetime.now(timezone.utc)
        issued_at = int(now.timestamp())

        access_jti = str(uuid.uuid4())
        access_expires_at = now + timedelta(seconds=self._config.access_token_ttl_seconds)
        access_payload: dict[str, Any] = {
            **(claims or {}),
            "sub": subject,
            "iss": self._config.issuer,
            "aud": self._config.audience,
            "iat": issued_at,
            "exp": int(access_expires_at.timestamp()),
            "jti": access_jti,
            "type": "access",
        }
        access_token = jwt.encode(access_payload, private_key, algorithm="RS256")

        refresh_jti = str(uuid.uuid4())
        refresh_expires_at = now + timedelta(seconds=self._config.refresh_token_ttl_seconds)
        refresh_payload: dict[str, Any] = {
            "sub": subject,
            "iss": self._config.issuer,
            "aud": self._config.audience,
            "iat": issued_at,
            "exp": int(refresh_expires_at.timestamp()),
            "jti": refresh_jti,
            "type": "refresh",
        }
        refresh_token = jwt.encode(refresh_payload, private_key, algorithm="RS256")

        # Store hash so Redis never holds a usable token — hash lookup on rotate()
        token_hash = self._hash_token(refresh_token)
        redis_key = self._refresh_redis_key(token_hash)
        await self._redis.set(redis_key, refresh_jti, ex=self._config.refresh_token_ttl_seconds)

        return TokenPair(
            access_token=access_token,
            refresh_token=refresh_token,
            access_expires_at=access_expires_at,
            refresh_expires_at=refresh_expires_at,
        )

    async def rotate(
        self, refresh_token: str, claims: Optional[dict[str, Any]] = None
    ) -> TokenPair:
        """
        Validate a refresh token, atomically delete its hash, and issue a new pair.

        Uses WATCH/MULTI/EXEC to eliminate the race condition where two concurrent
        callers could both pass the existence check before either deletes the key.
        If another client deletes the key between WATCH and EXEC, WatchError is
        raised and we retry — on retry the GET returns None → raises ValueError.

        Args:
            refresh_token: the refresh JWT presented by the client.
            claims: optional custom claims for the new access token. They are
                not carried over from the previous pair automatically.

        Raises:
            ValueError: if the token is invalid, not a refresh token, missing
                claims, already used/revoked, or blacklisted.
        """
        try:
            token_claims = self._decode(refresh_token)
        except JWTError as e:
            raise ValueError(f"Invalid refresh token: {e}") from e

        if token_claims.get("type") != "refresh":
            raise ValueError("Token is not a refresh token")

        jti = token_claims.get("jti")
        subject = token_claims.get("sub")
        if not jti or not subject:
            raise ValueError("Refresh token missing required claims")

        token_hash = self._hash_token(refresh_token)
        refresh_key = self._refresh_redis_key(token_hash)
        blacklist_key = self._blacklist_redis_key(jti)

        from redis.exceptions import WatchError

        async with self._redis.pipeline() as pipe:
            while True:
                try:
                    await pipe.watch(refresh_key)
                    # pipe is now in immediate-execution mode (not buffered)
                    stored_jti = await pipe.get(refresh_key)
                    if not stored_jti:
                        await pipe.unwatch()
                        raise ValueError("Refresh token has already been used or revoked")

                    is_blacklisted = await pipe.exists(blacklist_key)
                    if is_blacklisted:
                        await pipe.unwatch()
                        raise ValueError("Refresh token JTI is blacklisted")

                    pipe.multi()  # switch to buffered mode
                    pipe.delete(refresh_key)
                    await pipe.execute()  # atomic; raises WatchError if key was changed
                    break
                except WatchError:
                    continue  # concurrent rotation: retry sees empty key → ValueError

        return await self.issue_token_pair(subject=subject, claims=claims)

    async def blacklist_access_token(self, token: str) -> None:
        """
        Revoke an access token before its natural expiry.

        The blacklist key's TTL is the token's remaining lifetime, rounded up
        to whole seconds so a token with under a second left is still revoked.
        The key therefore expires no earlier than the token itself.

        Raises:
            ValueError: if the token fails verification or lacks jti/exp.
        """
        import math

        try:
            claims = self._decode(token)
        except JWTError as e:
            raise ValueError(f"Cannot blacklist invalid token: {e}") from e

        jti = claims.get("jti")
        exp = claims.get("exp")
        if not jti or not exp:
            raise ValueError("Token missing jti or exp claim")

        remaining_ttl = math.ceil(exp - datetime.now(timezone.utc).timestamp())
        if remaining_ttl > 0:
            await self._redis.set(self._blacklist_redis_key(jti), "1", ex=remaining_ttl)

    async def verify_access_token(self, token: str) -> dict[str, Any]:
        """
        Verify signature, expiry, issuer, audience, and blacklist status.

        Returns decoded claims dict on success. Raises jose.JWTError on failure.
        Blacklist check is a single O(1) Redis EXISTS — fast enough for every request.
        """
        claims = self._decode(token)

        if claims.get("type") != "access":
            raise JWTError("Token is not an access token")

        jti = claims.get("jti")
        if jti and await self._redis.exists(self._blacklist_redis_key(jti)):
            raise JWTError("Token has been revoked")

        return claims

    def _hash_token(self, token: str) -> str:
        """SHA-256 hex digest of the token — deterministic, non-reversible."""
        return hashlib.sha256(token.encode()).hexdigest()

    def _refresh_redis_key(self, token_hash: str) -> str:
        """Redis key storing the JTI of a live refresh token, addressed by its hash."""
        return f"{self._config.key_prefix}:refresh:{token_hash}"

    def _blacklist_redis_key(self, jti: str) -> str:
        """Redis key marking a JTI as revoked."""
        return f"{self._config.key_prefix}:blacklist:{jti}"
|