fastapi-cachex 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastapi_cachex/__init__.py +20 -0
- fastapi_cachex/backends/__init__.py +4 -4
- fastapi_cachex/backends/memcached.py +21 -2
- fastapi_cachex/backends/memory.py +33 -5
- fastapi_cachex/backends/redis.py +29 -6
- fastapi_cachex/cache.py +59 -19
- fastapi_cachex/dependencies.py +2 -2
- fastapi_cachex/proxy.py +9 -2
- fastapi_cachex/routes.py +6 -5
- fastapi_cachex/session/__init__.py +21 -0
- fastapi_cachex/session/config.py +70 -0
- fastapi_cachex/session/dependencies.py +65 -0
- fastapi_cachex/session/exceptions.py +25 -0
- fastapi_cachex/session/manager.py +389 -0
- fastapi_cachex/session/middleware.py +149 -0
- fastapi_cachex/session/models.py +185 -0
- fastapi_cachex/session/security.py +111 -0
- fastapi_cachex/state/__init__.py +8 -0
- fastapi_cachex/state/exceptions.py +19 -0
- fastapi_cachex/state/manager.py +258 -0
- fastapi_cachex/state/models.py +31 -0
- fastapi_cachex/types.py +9 -0
- {fastapi_cachex-0.2.1.dist-info → fastapi_cachex-0.2.3.dist-info}/METADATA +23 -5
- fastapi_cachex-0.2.3.dist-info/RECORD +29 -0
- fastapi_cachex-0.2.1.dist-info/RECORD +0 -17
- {fastapi_cachex-0.2.1.dist-info → fastapi_cachex-0.2.3.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
"""Session configuration settings."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
from pydantic import Field
|
|
7
|
+
from pydantic import SecretStr
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class SessionConfig(BaseModel):
    """Session configuration settings.

    Groups lifetime, token-transport, security-binding and storage options.
    ``secret_key`` is the only field without a default.
    """

    # Session lifetime
    # A value of 0 means "no expiry": SessionManager only sets ``expires_at``
    # when ``session_ttl`` is truthy. Negative values would create sessions
    # that are expired on arrival, so they are rejected.
    session_ttl: int = Field(
        default=3600,
        ge=0,
        description="Session time-to-live in seconds (default: 1 hour)",
    )
    # NOTE(review): SessionManager in this package does not read
    # absolute_timeout — confirm it is enforced elsewhere before relying on it.
    absolute_timeout: int | None = Field(
        default=None,
        description="Absolute session timeout in seconds (None = no absolute timeout)",
    )
    sliding_expiration: bool = Field(
        default=True,
        description="Whether to refresh session expiry on each access",
    )
    # SessionManager renews when ``remaining < session_ttl * sliding_threshold``,
    # i.e. this is the fraction of the TTL that may still REMAIN, not the
    # fraction that must have elapsed (the two only coincide at 0.5).
    sliding_threshold: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description=(
            "Renew the session once its remaining lifetime drops below this "
            "fraction of session_ttl (0.5 = renew after half the TTL has elapsed)"
        ),
    )

    # Token settings
    header_name: str = Field(
        default="X-Session-Token",
        description="Custom header name for session token",
    )
    use_bearer_token: bool = Field(
        default=True,
        description="Whether to accept Authorization Bearer tokens",
    )
    # Pydantic copies mutable defaults per instance, so a list literal is safe here.
    token_source_priority: list[Literal["header", "bearer"]] = Field(
        default=["header", "bearer"],
        description="Priority order for token sources",
    )

    # Security settings
    secret_key: SecretStr = Field(
        ...,
        min_length=32,
        description="Secret key for signing session tokens (min 32 characters)",
    )
    ip_binding: bool = Field(
        default=False,
        description="Whether to bind session to client IP address",
    )
    user_agent_binding: bool = Field(
        default=False,
        description="Whether to bind session to User-Agent",
    )
    regenerate_on_login: bool = Field(
        default=True,
        description="Whether to regenerate session ID on login",
    )

    # Backend settings
    backend_key_prefix: str = Field(
        default="session:",
        description="Prefix for session keys in backend storage",
    )
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
"""FastAPI dependency injection utilities for session management."""
|
|
2
|
+
|
|
3
|
+
from typing import Annotated
|
|
4
|
+
|
|
5
|
+
from fastapi import Depends
|
|
6
|
+
from fastapi import HTTPException
|
|
7
|
+
from fastapi import Request
|
|
8
|
+
from fastapi import status
|
|
9
|
+
|
|
10
|
+
from .models import Session
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def get_optional_session(request: Request) -> Session | None:
    """Return the session attached to this request, if there is one.

    Args:
        request: Incoming FastAPI request whose ``state`` may carry a session.

    Returns:
        The stored ``Session``, or ``None`` when the request has no valid session.
    """
    state = request.state
    session = getattr(state, "__fastapi_cachex_session", None)
    return session
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def get_session(request: Request) -> Session:
    """Return the request's session, rejecting requests that have none.

    Args:
        request: Incoming FastAPI request.

    Returns:
        The ``Session`` stored on ``request.state``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` challenge when
            no session is attached to the request.
    """
    session = get_optional_session(request)
    if session is not None:
        return session
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Authentication required",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def require_session(request: Request) -> Session:
    """Alias of ``get_session`` for code that prefers the "require" wording.

    Args:
        request: Incoming FastAPI request.

    Returns:
        The authenticated ``Session``.

    Raises:
        HTTPException: 401 when the request carries no session.
    """
    session = get_session(request)
    return session
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# ``Annotated`` aliases so route handlers can declare a session parameter
# without spelling out ``Depends(...)`` each time.
RequiredSession = Annotated[Session, Depends(get_session)]  # 401 if absent
SessionDep = Annotated[Session, Depends(require_session)]  # same check, alias name
OptionalSession = Annotated[Session | None, Depends(get_optional_session)]  # None if absent
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""Session-related exceptions."""
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class SessionError(Exception):
    """Root of the session exception hierarchy; catch it to handle any session failure."""
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SessionNotFoundError(SessionError):
    """No stored session exists for the requested session ID."""
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SessionExpiredError(SessionError):
    """The session's expiry time has already passed."""
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class SessionInvalidError(SessionError):
    """The session exists but is not in a usable (active) state."""
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SessionSecurityError(SessionError):
    """A session failed a security check such as signature, IP or User-Agent binding."""
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class SessionTokenError(SessionError):
    """The session token string is malformed or could not be parsed."""
|
|
@@ -0,0 +1,389 @@
|
|
|
1
|
+
"""Session manager for CRUD operations."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from datetime import timedelta
|
|
6
|
+
from datetime import timezone
|
|
7
|
+
|
|
8
|
+
from fastapi_cachex.backends.base import BaseCacheBackend
|
|
9
|
+
from fastapi_cachex.types import ETagContent
|
|
10
|
+
|
|
11
|
+
from .config import SessionConfig
|
|
12
|
+
from .exceptions import SessionExpiredError
|
|
13
|
+
from .exceptions import SessionInvalidError
|
|
14
|
+
from .exceptions import SessionNotFoundError
|
|
15
|
+
from .exceptions import SessionSecurityError
|
|
16
|
+
from .exceptions import SessionTokenError
|
|
17
|
+
from .models import Session
|
|
18
|
+
from .models import SessionStatus
|
|
19
|
+
from .models import SessionToken
|
|
20
|
+
from .models import SessionUser
|
|
21
|
+
from .security import SecurityManager
|
|
22
|
+
|
|
23
|
+
# Module-level logger named after this module, so its records can be filtered
# or routed through the standard logging hierarchy.
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class SessionManager:
    """Manages session lifecycle and storage.

    Sessions are serialized to JSON and stored in a cache backend under
    ``config.backend_key_prefix + session_id``. Clients receive a token that
    pairs the session ID with a signature produced by ``SecurityManager``
    from ``config.secret_key``.
    """

    def __init__(self, backend: BaseCacheBackend, config: SessionConfig) -> None:
        """Initialize session manager.

        Args:
            backend: Cache backend for session storage
            config: Session configuration
        """
        self.backend = backend
        self.config = config
        secret_value = config.secret_key.get_secret_value()
        self.security = SecurityManager(secret_value)
        logger.debug(
            "SessionManager initialized with backend prefix=%s",
            config.backend_key_prefix,
        )

    def _get_backend_key(self, session_id: str) -> str:
        """Get backend storage key for a session.

        Args:
            session_id: The session ID

        Returns:
            Backend storage key (configured prefix + session ID)
        """
        return f"{self.config.backend_key_prefix}{session_id}"

    async def create_session(
        self,
        user: SessionUser | None = None,
        ip_address: str | None = None,
        user_agent: str | None = None,
        **extra_data: object,
    ) -> tuple[Session, str]:
        """Create, persist and sign a new session.

        Args:
            user: Optional user data
            ip_address: Client IP address; stored only when ``config.ip_binding`` is on
            user_agent: Client User-Agent; stored only when
                ``config.user_agent_binding`` is on
            **extra_data: Additional session data, stored as ``Session.data``

        Returns:
            Tuple of (Session, token_string)
        """
        session = Session(
            user=user,
            data=extra_data,
        )

        # A session_ttl of 0 leaves expires_at at the model default, so the
        # session is never time-expired by this manager.
        if self.config.session_ttl:
            session.expires_at = datetime.now(timezone.utc) + timedelta(
                seconds=self.config.session_ttl,
            )

        # Bind IP and User-Agent if configured
        if self.config.ip_binding and ip_address:
            session.ip_address = ip_address
        if self.config.user_agent_binding and user_agent:
            session.user_agent = user_agent

        await self._save_session(session)

        token = self._create_token(session.session_id)
        logger.debug(
            "Session created; id=%s ttl=%s ip=%s ua=%s",
            session.session_id,
            self.config.session_ttl,
            session.ip_address,
            session.user_agent,
        )

        return session, token.to_string()

    async def get_session(
        self,
        token_string: str,
        ip_address: str | None = None,
        user_agent: str | None = None,
    ) -> Session:
        """Retrieve and validate a session.

        Checks run in this order: token parse, signature, existence, active
        status, expiry, IP binding, User-Agent binding. On success the session's
        access time is updated (and its expiry possibly slid forward) and saved.

        Args:
            token_string: Session token string
            ip_address: Current request IP address
            user_agent: Current request User-Agent

        Returns:
            Session object

        Raises:
            SessionTokenError: If token is invalid
            SessionNotFoundError: If session not found
            SessionExpiredError: If session has expired
            SessionInvalidError: If session is not active
            SessionSecurityError: If security checks fail
        """
        # Parse and verify token
        try:
            token = SessionToken.from_string(token_string)
        except ValueError as e:
            logger.debug("Session token parse error: %s", e)
            raise SessionTokenError(str(e)) from e

        if not self.security.verify_signature(token.session_id, token.signature):
            msg = "Invalid session signature"
            logger.debug(
                "Session signature verification failed; id=%s", token.session_id
            )
            raise SessionSecurityError(msg)

        # Load session from backend
        session = await self._load_session(token.session_id)
        if not session:
            msg = f"Session {token.session_id} not found"
            logger.debug("Session not found; id=%s", token.session_id)
            raise SessionNotFoundError(msg)

        # Validate session
        if session.status != SessionStatus.ACTIVE:
            msg = f"Session is {session.status}"
            logger.debug(
                "Session not active; id=%s status=%s",
                session.session_id,
                session.status,
            )
            raise SessionInvalidError(msg)

        # NOTE(review): config.absolute_timeout is not checked here — confirm
        # whether Session.is_expired() or another layer enforces it.
        if session.is_expired():
            # Persist the EXPIRED status so later lookups fail fast at the
            # status check above.
            session.status = SessionStatus.EXPIRED
            await self._save_session(session)
            msg = "Session has expired"
            logger.debug("Session expired; id=%s", session.session_id)
            raise SessionExpiredError(msg)

        # Security checks
        if self.config.ip_binding and not self.security.check_ip_match(
            session,
            ip_address,
        ):
            msg = "IP address mismatch"
            logger.debug(
                "IP mismatch; id=%s expected=%s got=%s",
                session.session_id,
                session.ip_address,
                ip_address,
            )
            raise SessionSecurityError(msg)

        if self.config.user_agent_binding and not self.security.check_user_agent_match(
            session,
            user_agent,
        ):
            msg = "User-Agent mismatch"
            logger.debug(
                "UA mismatch; id=%s expected=%s got=%s",
                session.session_id,
                session.user_agent,
                user_agent,
            )
            raise SessionSecurityError(msg)

        session.update_last_accessed()

        # Sliding expiration: renew once the remaining lifetime drops below
        # ``sliding_threshold`` of the configured TTL (0.5 => after half the
        # TTL has elapsed).
        if self.config.sliding_expiration and session.expires_at:
            time_remaining = (
                session.expires_at - datetime.now(timezone.utc)
            ).total_seconds()
            threshold = self.config.session_ttl * self.config.sliding_threshold

            if time_remaining < threshold:
                session.renew(self.config.session_ttl)
                logger.debug(
                    "Session renewed (sliding expiration); id=%s ttl=%s",
                    session.session_id,
                    self.config.session_ttl,
                )

        await self._save_session(session)

        return session

    async def update_session(self, session: Session) -> None:
        """Touch the session's access time and persist it.

        Args:
            session: Session to update
        """
        session.update_last_accessed()
        await self._save_session(session)
        logger.debug("Session updated; id=%s", session.session_id)

    async def delete_session(self, session_id: str) -> None:
        """Delete a session from the backend.

        Args:
            session_id: Session ID to delete
        """
        key = self._get_backend_key(session_id)
        await self.backend.delete(key)
        logger.debug("Session deleted; id=%s", session_id)

    async def invalidate_session(self, session: Session) -> None:
        """Mark a session invalid and persist it (the record is kept, not deleted).

        Args:
            session: Session to invalidate
        """
        session.invalidate()
        await self._save_session(session)
        logger.debug("Session invalidated; id=%s", session.session_id)

    async def regenerate_session_id(
        self,
        session: Session,
    ) -> tuple[Session, str]:
        """Regenerate session ID (after login for security).

        The session is saved under its new ID before the old entry is deleted,
        so a backend failure part-way through never leaves the client with no
        stored session at all.

        Args:
            session: Session to regenerate

        Returns:
            Tuple of (updated session, new token string)
        """
        old_id = session.session_id
        session.regenerate_id()

        await self._save_session(session)
        await self.delete_session(old_id)

        token = self._create_token(session.session_id)
        logger.debug(
            "Session ID regenerated; old_id=%s new_id=%s", old_id, session.session_id
        )

        return session, token.to_string()

    async def delete_user_sessions(self, user_id: str) -> int:
        """Delete all sessions for a user.

        Requires a backend that can enumerate keys (``get_all_keys``); every
        session key is loaded and checked, so cost grows with total sessions.

        Args:
            user_id: User ID

        Returns:
            Number of sessions deleted (0 if the backend cannot enumerate keys)
        """
        count = 0

        try:
            all_keys = await self.backend.get_all_keys()
        except NotImplementedError:  # pragma: no cover
            logger.warning(
                "Backend %s cannot enumerate keys; sessions for user_id=%s not deleted",
                type(self.backend).__name__,
                user_id,
            )
            all_keys = []

        for key in all_keys:
            if not key.startswith(self.config.backend_key_prefix):
                continue
            session = await self._load_session_by_key(key)
            if session and session.user and session.user.user_id == user_id:
                await self.backend.delete(key)
                count += 1

        logger.debug("User sessions deleted; user_id=%s count=%s", user_id, count)
        return count

    async def clear_expired_sessions(self) -> int:
        """Clear all expired sessions.

        Requires a backend that can enumerate keys (``get_all_keys``).

        Returns:
            Number of sessions cleared (0 if the backend cannot enumerate keys)
        """
        count = 0

        try:
            all_keys = await self.backend.get_all_keys()
        except NotImplementedError:  # pragma: no cover
            logger.warning(
                "Backend %s cannot enumerate keys; expired sessions not cleared",
                type(self.backend).__name__,
            )
            all_keys = []

        for key in all_keys:
            if not key.startswith(self.config.backend_key_prefix):
                continue
            session = await self._load_session_by_key(key)
            if session and session.is_expired():
                await self.backend.delete(key)
                count += 1

        logger.debug("Expired sessions cleared; count=%s", count)
        return count

    def _create_token(self, session_id: str) -> SessionToken:
        """Create a signed session token.

        Args:
            session_id: Session ID to sign

        Returns:
            SessionToken object
        """
        signature = self.security.sign_session_id(session_id)
        return SessionToken(session_id=session_id, signature=signature)

    async def _save_session(self, session: Session) -> None:
        """Serialize a session to JSON and store it with a TTL matching its expiry.

        Args:
            session: Session to save
        """
        key = self._get_backend_key(session.session_id)
        payload = session.model_dump_json()
        value = payload.encode("utf-8")

        # No expiry -> no backend TTL. Otherwise clamp to >= 1 second so the
        # backend never receives a zero or negative TTL.
        ttl = None
        if session.expires_at:
            remaining = (session.expires_at - datetime.now(timezone.utc)).total_seconds()
            ttl = max(int(remaining), 1)

        # Stored wrapped in ETagContent so the generic cache backend can hold it.
        etag = self.security.hash_data(payload)
        await self.backend.set(key, ETagContent(etag=etag, content=value), ttl=ttl)
        logger.debug("Session saved; id=%s ttl=%s", session.session_id, ttl)

    async def _load_session(self, session_id: str) -> Session | None:
        """Load session from backend.

        Args:
            session_id: Session ID to load

        Returns:
            Session object or None if not found
        """
        key = self._get_backend_key(session_id)
        return await self._load_session_by_key(key)

    async def _load_session_by_key(self, key: str) -> Session | None:
        """Load session from backend by key.

        Args:
            key: Backend key

        Returns:
            Session object, or None if missing or not deserializable
        """
        cached = await self.backend.get(key)
        if not cached:
            logger.debug("Session load MISS; key=%s", key)
            return None

        try:
            return Session.model_validate_json(cached.content)
        except (ValueError, TypeError):  # pragma: no cover
            # Corrupt or schema-incompatible data is treated as "no session".
            logger.debug("Session load DESERIALIZE ERROR; key=%s", key)
            return None
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
"""Session middleware for FastAPI."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
|
|
6
|
+
from fastapi import Request
|
|
7
|
+
from fastapi import Response
|
|
8
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
9
|
+
from starlette.middleware.base import RequestResponseEndpoint
|
|
10
|
+
from starlette.types import ASGIApp
|
|
11
|
+
|
|
12
|
+
from .config import SessionConfig
|
|
13
|
+
from .exceptions import SessionError
|
|
14
|
+
from .manager import SessionManager
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from .models import Session
|
|
18
|
+
|
|
19
|
+
# Module-level logger named after this module, so its records can be filtered
# or routed through the standard logging hierarchy.
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class SessionMiddleware(BaseHTTPMiddleware):
    """Resolve the request's session token into a ``Session`` on ``request.state``.

    The token is looked up in the order given by
    ``config.token_source_priority`` (custom header and/or
    ``Authorization: Bearer``). Any ``SessionError`` during validation is
    treated as "no session". This middleware does not read or write cookies.
    """

    def __init__(
        self,
        app: ASGIApp,
        session_manager: SessionManager,
        config: SessionConfig | None = None,
        *,
        trust_proxy_headers: bool = True,
    ) -> None:
        """Initialize session middleware.

        Args:
            app: ASGI application
            session_manager: Session manager instance
            config: Session configuration; defaults to ``session_manager.config``
            trust_proxy_headers: Take the client IP from ``X-Forwarded-For`` /
                ``X-Real-IP``. Those headers are client-controlled unless a
                trusted reverse proxy overwrites them, so set this to False when
                the app is reachable directly; otherwise IP binding can be
                bypassed by forging the header.
        """
        super().__init__(app)
        self.session_manager = session_manager
        self.config = config if config is not None else session_manager.config
        self.trust_proxy_headers = trust_proxy_headers

        logger.debug(
            "SessionMiddleware initialized; header=%s bearer=%s",
            self.config.header_name,
            self.config.use_bearer_token,
        )

    async def dispatch(
        self,
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        """Load the session (if any), expose it on request.state, then continue.

        Args:
            request: Incoming request
            call_next: Next handler in chain

        Returns:
            Response
        """
        token = self._extract_token(request)

        session: Session | None = None
        if token:
            try:
                ip_address = self._get_client_ip(request)
                user_agent = request.headers.get("user-agent")
                session = await self.session_manager.get_session(
                    token,
                    ip_address=ip_address,
                    user_agent=user_agent,
                )
                logger.debug("Session loaded in middleware; id=%s", session.session_id)
            except SessionError:
                # Session invalid/expired: continue as an anonymous request.
                session = None
                logger.debug("Session failed to load; token invalid/expired")

        # Read back by the session dependencies via getattr on request.state.
        setattr(request.state, "__fastapi_cachex_session", session)

        response: Response = await call_next(request)

        return response

    def _extract_token(self, request: Request) -> str | None:
        """Extract session token from request.

        Sources are tried in ``config.token_source_priority`` order; an empty
        value in one source falls through to the next.

        Args:
            request: Incoming request

        Returns:
            Session token or None
        """
        for source in self.config.token_source_priority:
            if source == "header":
                token = request.headers.get(self.config.header_name, "").strip()
                if token:
                    logger.debug("Token extracted from header")
                    return token

            elif source == "bearer" and self.config.use_bearer_token:
                auth_header = request.headers.get("authorization", "")
                scheme, _, credentials = auth_header.partition(" ")
                token_value = credentials.strip()
                # Authentication scheme names are case-insensitive (RFC 7235 §2.1).
                if scheme.lower() == "bearer" and token_value:
                    logger.debug("Token extracted from bearer auth")
                    return token_value

        return None

    def _get_client_ip(self, request: Request) -> str | None:
        """Get client IP address from request.

        When ``trust_proxy_headers`` is enabled, ``X-Forwarded-For`` (first
        entry) and then ``X-Real-IP`` take precedence over the socket peer.

        Args:
            request: Incoming request

        Returns:
            Client IP address or None
        """
        if self.trust_proxy_headers:
            forwarded_for = request.headers.get("x-forwarded-for")
            if forwarded_for:
                # First entry is the originating client in the proxy chain.
                ip = forwarded_for.split(",")[0].strip()
                if ip:
                    logger.debug("Client IP from X-Forwarded-For: %s", ip)
                    return ip

            real_ip = request.headers.get("x-real-ip")
            if real_ip:
                logger.debug("Client IP from X-Real-IP: %s", real_ip)
                return real_ip

        if request.client:
            logger.debug("Client IP from connection: %s", request.client.host)
            return request.client.host

        return None
|