mdb-engine 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mdb_engine/README.md +144 -0
- mdb_engine/__init__.py +37 -0
- mdb_engine/auth/README.md +631 -0
- mdb_engine/auth/__init__.py +128 -0
- mdb_engine/auth/casbin_factory.py +199 -0
- mdb_engine/auth/casbin_models.py +46 -0
- mdb_engine/auth/config_defaults.py +71 -0
- mdb_engine/auth/config_helpers.py +213 -0
- mdb_engine/auth/cookie_utils.py +158 -0
- mdb_engine/auth/decorators.py +350 -0
- mdb_engine/auth/dependencies.py +747 -0
- mdb_engine/auth/helpers.py +64 -0
- mdb_engine/auth/integration.py +578 -0
- mdb_engine/auth/jwt.py +225 -0
- mdb_engine/auth/middleware.py +241 -0
- mdb_engine/auth/oso_factory.py +323 -0
- mdb_engine/auth/provider.py +570 -0
- mdb_engine/auth/restrictions.py +271 -0
- mdb_engine/auth/session_manager.py +477 -0
- mdb_engine/auth/token_lifecycle.py +213 -0
- mdb_engine/auth/token_store.py +289 -0
- mdb_engine/auth/users.py +1516 -0
- mdb_engine/auth/utils.py +614 -0
- mdb_engine/cli/__init__.py +13 -0
- mdb_engine/cli/commands/__init__.py +7 -0
- mdb_engine/cli/commands/generate.py +105 -0
- mdb_engine/cli/commands/migrate.py +83 -0
- mdb_engine/cli/commands/show.py +70 -0
- mdb_engine/cli/commands/validate.py +63 -0
- mdb_engine/cli/main.py +41 -0
- mdb_engine/cli/utils.py +92 -0
- mdb_engine/config.py +217 -0
- mdb_engine/constants.py +160 -0
- mdb_engine/core/README.md +542 -0
- mdb_engine/core/__init__.py +42 -0
- mdb_engine/core/app_registration.py +392 -0
- mdb_engine/core/connection.py +243 -0
- mdb_engine/core/engine.py +749 -0
- mdb_engine/core/index_management.py +162 -0
- mdb_engine/core/manifest.py +2793 -0
- mdb_engine/core/seeding.py +179 -0
- mdb_engine/core/service_initialization.py +355 -0
- mdb_engine/core/types.py +413 -0
- mdb_engine/database/README.md +522 -0
- mdb_engine/database/__init__.py +31 -0
- mdb_engine/database/abstraction.py +635 -0
- mdb_engine/database/connection.py +387 -0
- mdb_engine/database/scoped_wrapper.py +1721 -0
- mdb_engine/embeddings/README.md +184 -0
- mdb_engine/embeddings/__init__.py +62 -0
- mdb_engine/embeddings/dependencies.py +193 -0
- mdb_engine/embeddings/service.py +759 -0
- mdb_engine/exceptions.py +167 -0
- mdb_engine/indexes/README.md +651 -0
- mdb_engine/indexes/__init__.py +21 -0
- mdb_engine/indexes/helpers.py +145 -0
- mdb_engine/indexes/manager.py +895 -0
- mdb_engine/memory/README.md +451 -0
- mdb_engine/memory/__init__.py +30 -0
- mdb_engine/memory/service.py +1285 -0
- mdb_engine/observability/README.md +515 -0
- mdb_engine/observability/__init__.py +42 -0
- mdb_engine/observability/health.py +296 -0
- mdb_engine/observability/logging.py +161 -0
- mdb_engine/observability/metrics.py +297 -0
- mdb_engine/routing/README.md +462 -0
- mdb_engine/routing/__init__.py +73 -0
- mdb_engine/routing/websockets.py +813 -0
- mdb_engine/utils/__init__.py +7 -0
- mdb_engine-0.1.6.dist-info/METADATA +213 -0
- mdb_engine-0.1.6.dist-info/RECORD +75 -0
- mdb_engine-0.1.6.dist-info/WHEEL +5 -0
- mdb_engine-0.1.6.dist-info/entry_points.txt +2 -0
- mdb_engine-0.1.6.dist-info/licenses/LICENSE +661 -0
- mdb_engine-0.1.6.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,477 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Session Management
|
|
3
|
+
|
|
4
|
+
Provides session tracking and management for user authentication sessions.
|
|
5
|
+
|
|
6
|
+
This module is part of MDB_ENGINE - MongoDB Engine.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import logging
|
|
10
|
+
from datetime import datetime, timedelta
|
|
11
|
+
from typing import Any, Dict, List, Optional
|
|
12
|
+
|
|
13
|
+
from bson.objectid import ObjectId
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
from pymongo.errors import (ConnectionFailure, OperationFailure,
|
|
17
|
+
ServerSelectionTimeoutError)
|
|
18
|
+
except ImportError:
|
|
19
|
+
ConnectionFailure = Exception
|
|
20
|
+
OperationFailure = Exception
|
|
21
|
+
ServerSelectionTimeoutError = Exception
|
|
22
|
+
|
|
23
|
+
from ..config import MAX_SESSIONS_PER_USER as CONFIG_MAX_SESSIONS
|
|
24
|
+
from ..config import SESSION_INACTIVITY_TIMEOUT as CONFIG_INACTIVITY_TIMEOUT
|
|
25
|
+
|
|
26
|
+
logger = logging.getLogger(__name__)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class SessionManager:
    """
    Manages user sessions with device tracking and activity monitoring.

    Tracks active sessions per user, supports multiple concurrent sessions,
    and provides session management capabilities.

    Sessions are never deleted by this class: revoking a session sets its
    ``active`` flag to False and stamps ``revoked_at``. All timestamps written
    here are naive UTC values produced by ``datetime.utcnow()``.
    """

    def __init__(self, db, collection_name: str = "user_sessions"):
        """
        Initialize session manager.

        Args:
            db: MongoDB database instance (Motor AsyncIOMotorDatabase)
            collection_name: Name of the sessions collection (default: "user_sessions")
        """
        self.db = db
        self.collection = db[collection_name]
        # Flipped to True only after every index was created successfully.
        self._indexes_created = False
        self.max_sessions = CONFIG_MAX_SESSIONS
        self.inactivity_timeout = CONFIG_INACTIVITY_TIMEOUT
        # Fingerprinting is opt-in; see configure_fingerprinting().
        self.fingerprinting_enabled = False
        self.fingerprinting_strict = False

    async def ensure_indexes(self):
        """
        Ensure required indexes exist for the sessions collection.

        Creates:
        - Index on 'user_id' + 'last_seen' for session cleanup queries
        - Index on 'device_id' for device lookup
        - Unique index on 'refresh_jti' for token-to-session mapping

        Failures are logged and swallowed; because the success flag is then
        left unset, the next call attempts index creation again.
        """
        if self._indexes_created:
            return

        try:
            # Index for user session queries
            await self.collection.create_index(
                [("user_id", 1), ("last_seen", -1)], name="user_id_last_seen_idx"
            )

            # Index for device lookup
            await self.collection.create_index("device_id", name="device_id_idx")

            # Index for refresh token lookup
            await self.collection.create_index(
                "refresh_jti", unique=True, name="refresh_jti_unique_idx"
            )

            self._indexes_created = True
            logger.info("Session manager indexes created successfully")
        except (
            OperationFailure,
            ConnectionFailure,
            ServerSelectionTimeoutError,
            AttributeError,
            TypeError,
        ) as e:
            logger.error(f"Error creating session manager indexes: {e}", exc_info=True)
            # Don't raise - indexes might already exist

    async def create_session(
        self,
        user_id: str,
        device_id: str,
        refresh_jti: str,
        device_info: Optional[Dict[str, Any]] = None,
        ip_address: Optional[str] = None,
        session_fingerprint: Optional[str] = None,
    ) -> Optional[Dict[str, Any]]:
        """
        Create a new user session.

        Before inserting, the per-user session limit is enforced: sessions
        past the inactivity timeout are deactivated first, and if the user is
        still at or above ``max_sessions`` the least recently seen sessions
        are revoked until there is room for exactly one new session.

        Args:
            user_id: User identifier (email or user_id)
            device_id: Unique device identifier
            refresh_jti: Refresh token JWT ID
            device_info: Optional device metadata (user_agent, browser, os, etc.)
            ip_address: Optional IP address
            session_fingerprint: Optional session fingerprint hash for security
                (stored only when fingerprinting is enabled)

        Returns:
            Created session document or None if creation failed
        """
        try:
            await self.ensure_indexes()

            # Check session limit
            active_sessions = await self.get_user_sessions(user_id, active_only=True)
            if len(active_sessions) >= self.max_sessions:
                # Deactivate every session that exceeded the inactivity timeout
                await self.cleanup_inactive_sessions(user_id)
                # Check again
                active_sessions = await self.get_user_sessions(
                    user_id, active_only=True
                )
                if len(active_sessions) >= self.max_sessions:
                    # Still full. The list is sorted newest-first, so the tail
                    # holds the oldest sessions. Revoke everything beyond the
                    # newest (max_sessions - 1) so the limit holds even when
                    # the user was already over it (e.g. the limit was lowered).
                    keep = max(int(self.max_sessions) - 1, 0)
                    for stale in active_sessions[keep:]:
                        await self.revoke_session(stale["_id"])

            now = datetime.utcnow()
            session_doc = {
                "user_id": user_id,
                "device_id": device_id,
                "refresh_jti": refresh_jti,
                "created_at": now,
                "last_seen": now,
                "ip_address": ip_address,
                "active": True,
            }

            if self.fingerprinting_enabled and session_fingerprint:
                session_doc["session_fingerprint"] = session_fingerprint

            if device_info:
                session_doc.update(
                    {
                        "user_agent": device_info.get("user_agent"),
                        "browser": device_info.get("browser"),
                        "os": device_info.get("os"),
                        "device_type": device_info.get("device_type"),
                        "location": device_info.get("location"),
                    }
                )

            result = await self.collection.insert_one(session_doc)
            session_doc["_id"] = result.inserted_id

            logger.debug(
                f"Created session {result.inserted_id} for user {user_id} on device {device_id}"
            )
            return session_doc
        except (
            OperationFailure,
            ConnectionFailure,
            ServerSelectionTimeoutError,
            ValueError,
            TypeError,
        ) as e:
            logger.error(
                f"Error creating session for user {user_id}: {e}", exc_info=True
            )
            return None

    async def update_session_activity(
        self, refresh_jti: str, ip_address: Optional[str] = None
    ) -> bool:
        """
        Update session last_seen timestamp (activity tracking).

        Only an active session matching ``refresh_jti`` is touched.

        Args:
            refresh_jti: Refresh token JWT ID
            ip_address: Optional IP address update

        Returns:
            True if session was updated, False otherwise
        """
        try:
            update_data = {
                "last_seen": datetime.utcnow(),
            }
            if ip_address:
                update_data["ip_address"] = ip_address

            result = await self.collection.update_one(
                {"refresh_jti": refresh_jti, "active": True}, {"$set": update_data}
            )

            return result.modified_count > 0
        except (
            OperationFailure,
            ConnectionFailure,
            ServerSelectionTimeoutError,
            ValueError,
            TypeError,
        ) as e:
            logger.error(
                f"Error updating session activity for {refresh_jti}: {e}", exc_info=True
            )
            return False

    async def get_session_by_refresh_token(
        self, refresh_jti: str
    ) -> Optional[Dict[str, Any]]:
        """
        Get the active session for a refresh token JWT ID.

        Args:
            refresh_jti: Refresh token JWT ID

        Returns:
            Session document, or None if no active session matches or the
            lookup failed
        """
        try:
            session = await self.collection.find_one(
                {"refresh_jti": refresh_jti, "active": True}
            )
            return session
        except (
            OperationFailure,
            ConnectionFailure,
            ServerSelectionTimeoutError,
            ValueError,
            TypeError,
        ) as e:
            logger.error(
                f"Error getting session for refresh token {refresh_jti}: {e}",
                exc_info=True,
            )
            return None

    async def validate_session_fingerprint(
        self, refresh_jti: str, current_fingerprint: str, strict: Optional[bool] = None
    ) -> bool:
        """
        Validate session fingerprint matches stored fingerprint.

        Args:
            refresh_jti: Refresh token JWT ID
            current_fingerprint: Current session fingerprint to validate
            strict: Governs only the case where the session has no stored
                fingerprint: True rejects it, False accepts it. A stored
                fingerprint that differs from ``current_fingerprint`` is
                always rejected regardless of this flag.
                If None, uses self.fingerprinting_strict

        Returns:
            True if fingerprinting is disabled or the fingerprint is valid;
            False on mismatch, when no active session exists, or on error
        """
        if not self.fingerprinting_enabled:
            return True

        try:
            session = await self.get_session_by_refresh_token(refresh_jti)
            if not session:
                return False

            stored_fingerprint = session.get("session_fingerprint")

            if not stored_fingerprint:
                # Nothing to compare against: the strict flag decides.
                strict_mode = (
                    strict if strict is not None else self.fingerprinting_strict
                )
                return not strict_mode

            return stored_fingerprint == current_fingerprint
        except (
            OperationFailure,
            ConnectionFailure,
            ServerSelectionTimeoutError,
            ValueError,
            TypeError,
            AttributeError,
        ) as e:
            logger.error(f"Error validating session fingerprint: {e}", exc_info=True)
            return False

    def configure_fingerprinting(self, enabled: bool = True, strict: bool = False):
        """
        Configure session fingerprinting settings.

        Args:
            enabled: Enable/disable session fingerprinting
            strict: If True, validation rejects sessions that have no stored
                fingerprint (see validate_session_fingerprint)
        """
        self.fingerprinting_enabled = enabled
        self.fingerprinting_strict = strict

    async def get_user_sessions(
        self, user_id: str, active_only: bool = True
    ) -> List[Dict[str, Any]]:
        """
        Get all sessions for a user.

        Args:
            user_id: User identifier
            active_only: If True, only return active sessions

        Returns:
            List of session documents, sorted by last_seen (newest first);
            an empty list if the query failed
        """
        try:
            query = {"user_id": user_id}
            if active_only:
                query["active"] = True

            sessions = (
                await self.collection.find(query).sort("last_seen", -1).to_list(None)
            )
            return sessions
        except (
            OperationFailure,
            ConnectionFailure,
            ServerSelectionTimeoutError,
            ValueError,
            TypeError,
        ) as e:
            logger.error(
                f"Error getting sessions for user {user_id}: {e}", exc_info=True
            )
            return []

    async def revoke_session(self, session_id: Any) -> bool:
        """
        Revoke a specific session.

        Args:
            session_id: Session _id (ObjectId or string)

        Returns:
            True if session was revoked, False otherwise (including when a
            string ``session_id`` is not a valid ObjectId)
        """
        try:
            # Convert to ObjectId if string
            if isinstance(session_id, str):
                # ObjectId() signals a malformed string with bson's InvalidId,
                # which is neither a ValueError nor a TypeError, so validate
                # up front instead of relying on those handlers.
                if not ObjectId.is_valid(session_id):
                    # Type 2: Recoverable - invalid format, return False
                    logger.warning(f"Invalid session_id format: {session_id}")
                    return False
                session_id = ObjectId(session_id)

            result = await self.collection.update_one(
                {"_id": session_id},
                {"$set": {"active": False, "revoked_at": datetime.utcnow()}},
            )

            if result.modified_count > 0:
                logger.debug(f"Session {session_id} revoked")
            return result.modified_count > 0
        except (
            OperationFailure,
            ConnectionFailure,
            ServerSelectionTimeoutError,
            ValueError,
            TypeError,
        ) as e:
            logger.error(f"Error revoking session {session_id}: {e}", exc_info=True)
            return False

    async def revoke_user_sessions(
        self, user_id: str, exclude_device_id: Optional[str] = None
    ) -> int:
        """
        Revoke all active sessions for a user.

        Args:
            user_id: User identifier
            exclude_device_id: Optional device_id to exclude from revocation

        Returns:
            Number of sessions revoked (0 on error)
        """
        try:
            query = {"user_id": user_id, "active": True}
            if exclude_device_id:
                query["device_id"] = {"$ne": exclude_device_id}

            result = await self.collection.update_many(
                query, {"$set": {"active": False, "revoked_at": datetime.utcnow()}}
            )

            revoked_count = result.modified_count
            if revoked_count > 0:
                logger.info(f"Revoked {revoked_count} sessions for user {user_id}")
            return revoked_count
        except (
            OperationFailure,
            ConnectionFailure,
            ServerSelectionTimeoutError,
            ValueError,
            TypeError,
        ) as e:
            logger.error(
                f"Error revoking sessions for user {user_id}: {e}", exc_info=True
            )
            return 0

    async def cleanup_inactive_sessions(self, user_id: Optional[str] = None) -> int:
        """
        Clean up inactive sessions (beyond inactivity timeout).

        Active sessions whose ``last_seen`` is older than
        ``inactivity_timeout`` seconds are deactivated and tagged with
        ``reason="inactivity"``.

        Args:
            user_id: Optional user_id to limit cleanup to specific user

        Returns:
            Number of sessions cleaned up (0 on error)
        """
        try:
            cutoff_time = datetime.utcnow() - timedelta(seconds=self.inactivity_timeout)

            query = {"active": True, "last_seen": {"$lt": cutoff_time}}
            if user_id:
                query["user_id"] = user_id

            result = await self.collection.update_many(
                query,
                {
                    "$set": {
                        "active": False,
                        "revoked_at": datetime.utcnow(),
                        "reason": "inactivity",
                    }
                },
            )

            cleaned_count = result.modified_count
            if cleaned_count > 0:
                logger.debug(f"Cleaned up {cleaned_count} inactive sessions")
            return cleaned_count
        except (
            OperationFailure,
            ConnectionFailure,
            ServerSelectionTimeoutError,
            ValueError,
            TypeError,
        ) as e:
            logger.error(f"Error cleaning up inactive sessions: {e}", exc_info=True)
            return 0

    async def revoke_session_by_refresh_token(self, refresh_jti: str) -> bool:
        """
        Revoke session by refresh token JWT ID.

        Args:
            refresh_jti: Refresh token JWT ID

        Returns:
            True if session was revoked, False otherwise
        """
        try:
            result = await self.collection.update_one(
                {"refresh_jti": refresh_jti, "active": True},
                {"$set": {"active": False, "revoked_at": datetime.utcnow()}},
            )
            return result.modified_count > 0
        except (
            OperationFailure,
            ConnectionFailure,
            ServerSelectionTimeoutError,
            ValueError,
            TypeError,
        ) as e:
            logger.error(
                f"Error revoking session by refresh token {refresh_jti}: {e}",
                exc_info=True,
            )
            return False
|
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Token Lifecycle Management
|
|
3
|
+
|
|
4
|
+
Provides utilities for managing token lifecycle, rotation, and expiration handling.
|
|
5
|
+
|
|
6
|
+
This module is part of MDB_ENGINE - MongoDB Engine.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import logging
|
|
10
|
+
from datetime import datetime
|
|
11
|
+
from typing import Any, Dict, Optional
|
|
12
|
+
|
|
13
|
+
from ..config import ACCESS_TOKEN_TTL as CONFIG_ACCESS_TTL
|
|
14
|
+
from .jwt import extract_token_metadata
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_token_expiry_time(token: str, secret_key: str) -> Optional[datetime]:
    """
    Get the expiration time of a token.

    Args:
        token: JWT token string
        secret_key: Secret key for decoding

    Returns:
        Expiration as a naive UTC datetime, or None if the token metadata is
        unavailable, carries no numeric ``exp`` claim, or the claim cannot be
        converted to a datetime on this platform
    """
    try:
        metadata = extract_token_metadata(token, secret_key)
        if metadata and metadata.get("exp"):
            exp_timestamp = metadata["exp"]
            if isinstance(exp_timestamp, (int, float)):
                try:
                    return datetime.utcfromtimestamp(exp_timestamp)
                except (OverflowError, OSError) as e:
                    # An out-of-range timestamp is just another invalid
                    # token; keep the "None if invalid" contract.
                    logger.debug(f"Error getting token expiry time: {e}")
                    return None
        return None
    except (ValueError, TypeError, AttributeError, KeyError) as e:
        logger.debug(f"Error getting token expiry time: {e}")
        return None
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def is_token_expiring_soon(
    token: str, secret_key: str, threshold_seconds: Optional[int] = None
) -> bool:
    """
    Report whether a token is within its "expiring soon" window.

    Args:
        token: JWT token string
        secret_key: Secret key for decoding
        threshold_seconds: Size of the window in seconds before expiry
            (default: 10% of the configured access-token TTL)

    Returns:
        True if the remaining lifetime is at or below the threshold (this
        includes tokens that have already expired); False otherwise, or when
        the expiry cannot be determined
    """
    try:
        window = (
            int(CONFIG_ACCESS_TTL * 0.1)  # 10% of TTL
            if threshold_seconds is None
            else threshold_seconds
        )

        expires_at = get_token_expiry_time(token, secret_key)
        if expires_at is not None:
            remaining = (expires_at - datetime.utcnow()).total_seconds()
            return remaining <= window
        return False
    except (ValueError, TypeError, AttributeError, KeyError) as err:
        logger.debug(f"Error checking if token expiring soon: {err}")
        return False
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def should_refresh_token(
    token: str, secret_key: str, refresh_threshold: Optional[int] = None
) -> bool:
    """
    Decide whether a token is close enough to expiry to be refreshed.

    Args:
        token: JWT token string
        secret_key: Secret key for decoding
        refresh_threshold: Seconds before expiry at which a refresh is due
            (default: 20% of the configured access-token TTL)

    Returns:
        True if the remaining lifetime is at or below the threshold (this
        includes tokens that have already expired); False otherwise, or when
        the expiry cannot be determined
    """
    try:
        window = (
            int(CONFIG_ACCESS_TTL * 0.2)  # 20% of TTL
            if refresh_threshold is None
            else refresh_threshold
        )

        expires_at = get_token_expiry_time(token, secret_key)
        if expires_at is not None:
            remaining = (expires_at - datetime.utcnow()).total_seconds()
            return remaining <= window
        return False
    except (ValueError, TypeError, AttributeError, KeyError) as err:
        logger.debug(f"Error determining if token should refresh: {err}")
        return False
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def get_token_age(token: str, secret_key: str) -> Optional[float]:
    """
    Get the age of a token in seconds.

    Args:
        token: JWT token string
        secret_key: Secret key for decoding

    Returns:
        Seconds elapsed since the ``iat`` claim, or None if the token
        metadata is unavailable, carries no numeric ``iat`` claim, or the
        claim cannot be converted to a datetime on this platform
    """
    try:
        metadata = extract_token_metadata(token, secret_key)
        if metadata and metadata.get("iat"):
            iat_timestamp = metadata["iat"]
            if isinstance(iat_timestamp, (int, float)):
                try:
                    issued_at = datetime.utcfromtimestamp(iat_timestamp)
                except (OverflowError, OSError) as e:
                    # An out-of-range timestamp is just another invalid
                    # token; keep the "None if invalid" contract.
                    logger.debug(f"Error getting token age: {e}")
                    return None
                age = (datetime.utcnow() - issued_at).total_seconds()
                return age
        return None
    except (ValueError, TypeError, AttributeError, KeyError) as e:
        logger.debug(f"Error getting token age: {e}")
        return None
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def get_time_until_expiry(token: str, secret_key: str) -> Optional[float]:
    """
    Compute how many seconds remain before a token expires.

    Args:
        token: JWT token string
        secret_key: Secret key for decoding

    Returns:
        Seconds until expiry (negative once the token has expired), or None
        when the expiry cannot be determined
    """
    try:
        expires_at = get_token_expiry_time(token, secret_key)
        if expires_at is not None:
            return (expires_at - datetime.utcnow()).total_seconds()
        return None
    except (ValueError, TypeError, AttributeError, KeyError) as err:
        logger.debug(f"Error getting time until expiry: {err}")
        return None
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def validate_token_version(
    token: str, secret_key: str, required_version: Optional[str] = None
) -> bool:
    """
    Check that a token's ``version`` claim matches the expected version.

    Args:
        token: JWT token string
        secret_key: Secret key for decoding
        required_version: Version the token must carry (defaults to
            CURRENT_TOKEN_VERSION from the package constants)

    Returns:
        True if the token's version equals the expected one; False on
        mismatch, when no metadata is available, or on error
    """
    try:
        from ..constants import CURRENT_TOKEN_VERSION

        claims = extract_token_metadata(token, secret_key)
        if not claims:
            return False

        expected = (
            CURRENT_TOKEN_VERSION if required_version is None else required_version
        )

        # For now, exact match required (can be extended for version ranges)
        return claims.get("version") == expected
    except (ValueError, TypeError, AttributeError, KeyError) as err:
        logger.debug(f"Error validating token version: {err}")
        return False
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def get_token_info(token: str, secret_key: str) -> Optional[Dict[str, Any]]:
    """
    Collect a token's metadata together with derived lifecycle fields.

    The result contains every key of the extracted metadata plus
    ``expiry_time`` (ISO-8601 string or None), ``age_seconds``,
    ``time_until_expiry_seconds``, ``is_expired``, ``is_expiring_soon`` and
    ``should_refresh``. A derived key overrides a metadata key of the same
    name.

    Args:
        token: JWT token string
        secret_key: Secret key for decoding

    Returns:
        Dictionary with token information, or None if no metadata could be
        extracted or an error occurred
    """
    try:
        claims = extract_token_metadata(token, secret_key)
        if not claims:
            return None

        expires_at = get_token_expiry_time(token, secret_key)
        age_seconds = get_token_age(token, secret_key)
        remaining = get_time_until_expiry(token, secret_key)

        # An unknown expiry is reported as "not expired".
        expired = False if remaining is None else remaining < 0

        details = dict(claims)
        details.update(
            {
                "expiry_time": expires_at.isoformat() if expires_at else None,
                "age_seconds": age_seconds,
                "time_until_expiry_seconds": remaining,
                "is_expired": expired,
                "is_expiring_soon": is_token_expiring_soon(token, secret_key),
                "should_refresh": should_refresh_token(token, secret_key),
            }
        )
        return details
    except (ValueError, TypeError, AttributeError, KeyError) as err:
        logger.debug(f"Error getting token info: {err}")
        return None
|