fastapi-cachex 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,70 @@
1
+ """Session configuration settings."""
2
+
3
+ from typing import Literal
4
+
5
+ from pydantic import BaseModel
6
+ from pydantic import Field
7
+ from pydantic import SecretStr
8
+
9
+
10
class SessionConfig(BaseModel):
    """Declarative settings for session lifetime, token lookup, security and storage.

    Every field is a pydantic ``Field`` carrying its own default and
    human-readable description; ``secret_key`` is the only required field.
    """

    # ------------------------------------------------------------------
    # Lifetime
    # ------------------------------------------------------------------
    session_ttl: int = Field(
        description="Session time-to-live in seconds (default: 1 hour)",
        default=3600,
    )
    # NOTE(review): nothing in the code visible alongside this model reads
    # ``absolute_timeout`` -- confirm it is enforced somewhere.
    absolute_timeout: int | None = Field(
        description="Absolute session timeout in seconds (None = no absolute timeout)",
        default=None,
    )
    sliding_expiration: bool = Field(
        description="Whether to refresh session expiry on each access",
        default=True,
    )
    # Constrained to the closed interval [0.0, 1.0] by ``ge`` / ``le``.
    sliding_threshold: float = Field(
        description="Fraction of TTL that must pass before sliding refresh (0.5 = refresh after half TTL)",
        default=0.5,
        le=1.0,
        ge=0.0,
    )

    # ------------------------------------------------------------------
    # Token lookup
    # ------------------------------------------------------------------
    header_name: str = Field(
        description="Custom header name for session token",
        default="X-Session-Token",
    )
    use_bearer_token: bool = Field(
        description="Whether to accept Authorization Bearer tokens",
        default=True,
    )
    # Sources are listed in the order they should be consulted.
    token_source_priority: list[Literal["header", "bearer"]] = Field(
        description="Priority order for token sources",
        default=["header", "bearer"],
    )

    # ------------------------------------------------------------------
    # Security
    # ------------------------------------------------------------------
    # Required (``...``); wrapped in SecretStr so the raw value has to be
    # requested explicitly via ``get_secret_value()``.
    secret_key: SecretStr = Field(
        ...,
        description="Secret key for signing session tokens (min 32 characters)",
        min_length=32,
    )
    ip_binding: bool = Field(
        description="Whether to bind session to client IP address",
        default=False,
    )
    user_agent_binding: bool = Field(
        description="Whether to bind session to User-Agent",
        default=False,
    )
    regenerate_on_login: bool = Field(
        description="Whether to regenerate session ID on login",
        default=True,
    )

    # ------------------------------------------------------------------
    # Backend storage
    # ------------------------------------------------------------------
    backend_key_prefix: str = Field(
        description="Prefix for session keys in backend storage",
        default="session:",
    )
@@ -0,0 +1,65 @@
1
+ """FastAPI dependency injection utilities for session management."""
2
+
3
+ from typing import Annotated
4
+
5
+ from fastapi import Depends
6
+ from fastapi import HTTPException
7
+ from fastapi import Request
8
+ from fastapi import status
9
+
10
+ from .models import Session
11
+
12
+
13
def get_optional_session(request: Request) -> Session | None:
    """Return the session attached to the request, if there is one.

    The value is read from ``request.state`` under the
    ``"__fastapi_cachex_session"`` attribute name; a missing attribute is
    treated the same as an explicit ``None``.

    Args:
        request: FastAPI request object

    Returns:
        The stored Session, or None when no session is attached
    """
    state = request.state
    return getattr(state, "__fastapi_cachex_session", None)
23
+
24
+
25
def get_session(request: Request) -> Session:
    """Return the session attached to the request, rejecting anonymous callers.

    Args:
        request: FastAPI request object

    Returns:
        The Session stored on ``request.state``

    Raises:
        HTTPException: 401 (with a ``WWW-Authenticate: Bearer`` header) when
            no session is attached to the request
    """
    current = get_optional_session(request)
    if current is not None:
        return current

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Authentication required",
        headers={"WWW-Authenticate": "Bearer"},
    )
45
+
46
+
47
def require_session(request: Request) -> Session:
    """Return the request's session or fail with 401.

    This is a thin alias that delegates straight to ``get_session``.

    Args:
        request: FastAPI request object

    Returns:
        The Session stored on the request

    Raises:
        HTTPException: 401 if no session is attached to the request
    """
    session = get_session(request)
    return session
60
+
61
+
62
# Ready-made ``Annotated`` aliases so endpoints can declare a session
# parameter without spelling out ``Depends(...)`` themselves.

# Resolves to the session or None (never raises).
OptionalSession = Annotated[
    Session | None,
    Depends(get_optional_session),
]

# Resolves to the session; a 401 is raised when none is attached.
RequiredSession = Annotated[
    Session,
    Depends(get_session),
]

# Same contract as RequiredSession, routed through ``require_session``.
SessionDep = Annotated[
    Session,
    Depends(require_session),
]
@@ -0,0 +1,25 @@
1
+ """Session-related exceptions."""
2
+
3
+
4
class SessionError(Exception):
    """Root of the session exception hierarchy.

    Catching this type covers every session-specific failure defined in
    this module.
    """
6
+
7
+
8
class SessionNotFoundError(SessionError):
    """Signals that no stored session exists for the requested ID."""
10
+
11
+
12
class SessionExpiredError(SessionError):
    """Signals that the session exists but its lifetime has run out."""
14
+
15
+
16
class SessionInvalidError(SessionError):
    """Signals that the session is in a state that cannot be used."""
18
+
19
+
20
class SessionSecurityError(SessionError):
    """Signals that a security check on the session did not pass."""
22
+
23
+
24
class SessionTokenError(SessionError):
    """Signals a problem with the session token itself."""
@@ -0,0 +1,389 @@
1
+ """Session manager for CRUD operations."""
2
+
3
+ import logging
4
+ from datetime import datetime
5
+ from datetime import timedelta
6
+ from datetime import timezone
7
+
8
+ from fastapi_cachex.backends.base import BaseCacheBackend
9
+ from fastapi_cachex.types import ETagContent
10
+
11
+ from .config import SessionConfig
12
+ from .exceptions import SessionExpiredError
13
+ from .exceptions import SessionInvalidError
14
+ from .exceptions import SessionNotFoundError
15
+ from .exceptions import SessionSecurityError
16
+ from .exceptions import SessionTokenError
17
+ from .models import Session
18
+ from .models import SessionStatus
19
+ from .models import SessionToken
20
+ from .models import SessionUser
21
+ from .security import SecurityManager
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
class SessionManager:
    """Create, validate, persist and remove sessions held in a cache backend.

    Sessions are serialised to JSON and stored under
    ``config.backend_key_prefix + session_id``. Clients receive a signed
    token (session ID plus signature) produced by ``SecurityManager``.

    NOTE(review): ``config.absolute_timeout`` is not consulted anywhere in
    this class -- confirm whether it is meant to be enforced here.
    """

    def __init__(self, backend: BaseCacheBackend, config: SessionConfig) -> None:
        """Initialize session manager.

        Args:
            backend: Cache backend for session storage
            config: Session configuration
        """
        self.backend = backend
        self.config = config
        # The secret is unwrapped once and handed to the signer.
        secret_value = config.secret_key.get_secret_value()
        self.security = SecurityManager(secret_value)
        logger.debug(
            "SessionManager initialized with backend prefix=%s",
            config.backend_key_prefix,
        )

    def _get_backend_key(self, session_id: str) -> str:
        """Get backend storage key for a session.

        Args:
            session_id: The session ID

        Returns:
            The configured key prefix followed by the session ID
        """
        return f"{self.config.backend_key_prefix}{session_id}"

    async def create_session(
        self,
        user: SessionUser | None = None,
        ip_address: str | None = None,
        user_agent: str | None = None,
        **extra_data: object,
    ) -> tuple[Session, str]:
        """Create a new session, persist it and return it with its token.

        Args:
            user: Optional user data
            ip_address: Client IP address (recorded only if IP binding is
                enabled and a value is supplied)
            user_agent: Client User-Agent (recorded only if UA binding is
                enabled and a value is supplied)
            **extra_data: Additional session data, stored as the session's
                ``data`` mapping

        Returns:
            Tuple of (Session, token_string)
        """
        session = Session(
            user=user,
            data=extra_data,
        )

        # A falsy TTL (0) leaves ``expires_at`` untouched.
        if self.config.session_ttl:
            session.expires_at = datetime.now(timezone.utc) + timedelta(
                seconds=self.config.session_ttl,
            )

        # Bind IP and User-Agent only when the feature is on AND a value exists.
        if self.config.ip_binding and ip_address:
            session.ip_address = ip_address
        if self.config.user_agent_binding and user_agent:
            session.user_agent = user_agent

        await self._save_session(session)

        token = self._create_token(session.session_id)
        logger.debug(
            "Session created; id=%s ttl=%s ip=%s ua=%s",
            session.session_id,
            self.config.session_ttl,
            session.ip_address,
            session.user_agent,
        )

        return session, token.to_string()

    async def get_session(
        self,
        token_string: str,
        ip_address: str | None = None,
        user_agent: str | None = None,
    ) -> Session:
        """Retrieve and validate a session.

        Checks run in this order: token parse, signature, backend lookup,
        status, expiry, IP binding, User-Agent binding. On success the
        session's last-access time is updated, sliding expiration is applied
        if configured, and the session is saved back to the backend.

        Args:
            token_string: Session token string
            ip_address: Current request IP address
            user_agent: Current request User-Agent

        Returns:
            Session object

        Raises:
            SessionTokenError: If token is invalid
            SessionNotFoundError: If session not found
            SessionExpiredError: If session has expired
            SessionInvalidError: If session is not active
            SessionSecurityError: If security checks fail
        """
        # Parse and verify token
        try:
            token = SessionToken.from_string(token_string)
        except ValueError as e:
            logger.debug("Session token parse error: %s", e)
            raise SessionTokenError(str(e)) from e

        if not self.security.verify_signature(token.session_id, token.signature):
            msg = "Invalid session signature"
            logger.debug(
                "Session signature verification failed; id=%s", token.session_id
            )
            raise SessionSecurityError(msg)

        # Load session from backend
        session = await self._load_session(token.session_id)
        if not session:
            msg = f"Session {token.session_id} not found"
            logger.debug("Session not found; id=%s", token.session_id)
            raise SessionNotFoundError(msg)

        # Validate session
        if session.status != SessionStatus.ACTIVE:
            msg = f"Session is {session.status}"
            logger.debug(
                "Session not active; id=%s status=%s",
                session.session_id,
                session.status,
            )
            raise SessionInvalidError(msg)

        if session.is_expired():
            # Persist the EXPIRED status before reporting the failure.
            session.status = SessionStatus.EXPIRED
            await self._save_session(session)
            msg = "Session has expired"
            logger.debug("Session expired; id=%s", session.session_id)
            raise SessionExpiredError(msg)

        # Security checks
        if self.config.ip_binding and not self.security.check_ip_match(
            session,
            ip_address,
        ):
            msg = "IP address mismatch"
            logger.debug(
                "IP mismatch; id=%s expected=%s got=%s",
                session.session_id,
                session.ip_address,
                ip_address,
            )
            raise SessionSecurityError(msg)

        if self.config.user_agent_binding and not self.security.check_user_agent_match(
            session,
            user_agent,
        ):
            msg = "User-Agent mismatch"
            logger.debug(
                "UA mismatch; id=%s expected=%s got=%s",
                session.session_id,
                session.user_agent,
                user_agent,
            )
            raise SessionSecurityError(msg)

        # Update last accessed and handle sliding expiration
        session.update_last_accessed()

        if self.config.sliding_expiration and session.expires_at:
            time_remaining = (
                session.expires_at - datetime.now(timezone.utc)
            ).total_seconds()
            # NOTE(review): the SessionConfig description calls
            # sliding_threshold the fraction of the TTL that must *pass*,
            # whereas this compares the *remaining* time with
            # ttl * threshold. The two readings agree only at 0.5 --
            # confirm which one is intended before changing the default.
            threshold = self.config.session_ttl * self.config.sliding_threshold

            if time_remaining < threshold:
                session.renew(self.config.session_ttl)
                logger.debug(
                    "Session renewed (sliding expiration); id=%s ttl=%s",
                    session.session_id,
                    self.config.session_ttl,
                )

        await self._save_session(session)

        return session

    async def update_session(self, session: Session) -> None:
        """Refresh the session's last-access time and persist it.

        Args:
            session: Session to update
        """
        session.update_last_accessed()
        await self._save_session(session)
        logger.debug("Session updated; id=%s", session.session_id)

    async def delete_session(self, session_id: str) -> None:
        """Remove a session's record from the backend.

        Args:
            session_id: Session ID to delete
        """
        key = self._get_backend_key(session_id)
        await self.backend.delete(key)
        logger.debug("Session deleted; id=%s", session_id)

    async def invalidate_session(self, session: Session) -> None:
        """Mark a session as invalidated and persist that state.

        Unlike ``delete_session`` the record is written back to the backend
        rather than removed.

        Args:
            session: Session to invalidate
        """
        session.invalidate()
        await self._save_session(session)
        logger.debug("Session invalidated; id=%s", session.session_id)

    async def regenerate_session_id(
        self,
        session: Session,
    ) -> tuple[Session, str]:
        """Regenerate session ID (after login for security).

        The record under the new ID is written *before* the record under the
        old ID is removed, so a failing backend write cannot leave the caller
        with no stored session at all.

        Args:
            session: Session to regenerate

        Returns:
            Tuple of (updated session, new token string)
        """
        old_id = session.session_id
        session.regenerate_id()

        # Save under the new ID first; only then drop the old record.
        await self._save_session(session)
        await self.delete_session(old_id)

        token = self._create_token(session.session_id)
        logger.debug(
            "Session ID regenerated; old_id=%s new_id=%s", old_id, session.session_id
        )

        return session, token.to_string()

    async def delete_user_sessions(self, user_id: str) -> int:
        """Delete all sessions for a user.

        Requires a backend that implements ``get_all_keys``; every key with
        the session prefix is loaded and inspected. If the backend raises
        ``NotImplementedError`` the scan is abandoned, a warning is logged and
        the count reached so far is returned.

        Args:
            user_id: User ID

        Returns:
            Number of sessions deleted
        """
        count = 0
        prefix = self.config.backend_key_prefix

        try:
            all_keys = await self.backend.get_all_keys()
            for key in all_keys:
                if not key.startswith(prefix):
                    continue
                session = await self._load_session_by_key(key)
                if session and session.user and session.user.user_id == user_id:
                    await self.backend.delete(key)
                    count += 1
        except NotImplementedError:  # pragma: no cover
            # Best effort: surface the limitation instead of hiding it.
            logger.warning(
                "Backend does not support key enumeration; "
                "cannot delete sessions by user; user_id=%s",
                user_id,
            )

        logger.debug("User sessions deleted; user_id=%s count=%s", user_id, count)
        return count

    async def clear_expired_sessions(self) -> int:
        """Clear all expired sessions.

        Requires a backend that implements ``get_all_keys``. If the backend
        raises ``NotImplementedError`` the scan is abandoned, a warning is
        logged and the count reached so far is returned.

        Returns:
            Number of sessions cleared
        """
        count = 0
        prefix = self.config.backend_key_prefix

        try:
            all_keys = await self.backend.get_all_keys()
            for key in all_keys:
                if not key.startswith(prefix):
                    continue
                session = await self._load_session_by_key(key)
                if session and session.is_expired():
                    await self.backend.delete(key)
                    count += 1
        except NotImplementedError:  # pragma: no cover
            # Best effort: surface the limitation instead of hiding it.
            logger.warning(
                "Backend does not support key enumeration; "
                "cannot clear expired sessions"
            )

        logger.debug("Expired sessions cleared; count=%s", count)
        return count

    def _create_token(self, session_id: str) -> SessionToken:
        """Create a signed session token.

        Args:
            session_id: Session ID to sign

        Returns:
            SessionToken pairing the ID with its signature
        """
        signature = self.security.sign_session_id(session_id)
        return SessionToken(session_id=session_id, signature=signature)

    async def _save_session(self, session: Session) -> None:
        """Serialise the session and write it to the backend.

        The backend TTL is the whole number of seconds until
        ``session.expires_at``, floored at 1; it is ``None`` when the session
        has no expiry.

        Args:
            session: Session to save
        """
        key = self._get_backend_key(session.session_id)
        payload = session.model_dump_json()

        ttl: int | None = None
        if session.expires_at:
            remaining = (
                session.expires_at - datetime.now(timezone.utc)
            ).total_seconds()
            ttl = max(int(remaining), 1)  # Ensure at least 1 second

        # Stored as bytes wrapped in ETagContent for backend compatibility;
        # the etag is a hash of the JSON text.
        etag = self.security.hash_data(payload)
        content = payload.encode("utf-8")
        await self.backend.set(key, ETagContent(etag=etag, content=content), ttl=ttl)
        logger.debug("Session saved; id=%s ttl=%s", session.session_id, ttl)

    async def _load_session(self, session_id: str) -> Session | None:
        """Load session from backend.

        Args:
            session_id: Session ID to load

        Returns:
            Session object or None if not found
        """
        key = self._get_backend_key(session_id)
        return await self._load_session_by_key(key)

    async def _load_session_by_key(self, key: str) -> Session | None:
        """Load session from backend by key.

        Args:
            key: Backend key

        Returns:
            Session object, or None if the key is absent or its content
            cannot be deserialised into a Session
        """
        cached = await self.backend.get(key)
        if not cached:
            logger.debug("Session load MISS; key=%s", key)
            return None

        try:
            return Session.model_validate_json(cached.content)
        except (ValueError, TypeError):  # pragma: no cover
            # Undecodable data is treated the same as a missing session.
            logger.debug("Session load DESERIALIZE ERROR; key=%s", key)
            return None
@@ -0,0 +1,149 @@
1
+ """Session middleware for FastAPI."""
2
+
3
+ import logging
4
+ from typing import TYPE_CHECKING
5
+
6
+ from fastapi import Request
7
+ from fastapi import Response
8
+ from starlette.middleware.base import BaseHTTPMiddleware
9
+ from starlette.middleware.base import RequestResponseEndpoint
10
+ from starlette.types import ASGIApp
11
+
12
+ from .config import SessionConfig
13
+ from .exceptions import SessionError
14
+ from .manager import SessionManager
15
+
16
+ if TYPE_CHECKING:
17
+ from .models import Session
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
class SessionMiddleware(BaseHTTPMiddleware):
    """Middleware that resolves the request's session token into a Session.

    The resolved session (or ``None``) is stored on ``request.state`` for
    downstream handlers and dependencies. The response is passed through
    unchanged.
    """

    def __init__(
        self,
        app: ASGIApp,
        session_manager: SessionManager,
        config: SessionConfig | None = None,
        *,
        trust_proxy_headers: bool = True,
    ) -> None:
        """Initialize session middleware.

        Args:
            app: ASGI application
            session_manager: Session manager instance
            config: Session configuration; defaults to the manager's config
            trust_proxy_headers: Whether ``X-Forwarded-For`` / ``X-Real-IP``
                may be used to determine the client IP. These headers are
                client-controlled unless a trusted proxy overwrites them, so
                set this to ``False`` when the app is reachable directly;
                otherwise IP binding can be bypassed by forging the header.
                Defaults to ``True`` to keep the existing behaviour.
        """
        super().__init__(app)
        self.session_manager = session_manager

        if config is None:
            config = self.session_manager.config

        self.config = config
        self.trust_proxy_headers = trust_proxy_headers

        logger.debug(
            "SessionMiddleware initialized; header=%s bearer=%s",
            config.header_name,
            config.use_bearer_token,
        )

    async def dispatch(
        self,
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        """Process request and handle session.

        Any ``SessionError`` raised while loading the session is swallowed on
        purpose: the request proceeds without a session.

        Args:
            request: Incoming request
            call_next: Next handler in chain

        Returns:
            The downstream response, unmodified
        """
        token = self._extract_token(request)

        session: Session | None = None
        if token:
            try:
                ip_address = self._get_client_ip(request)
                user_agent = request.headers.get("user-agent")
                session = await self.session_manager.get_session(
                    token,
                    ip_address=ip_address,
                    user_agent=user_agent,
                )
                logger.debug("Session loaded in middleware; id=%s", session.session_id)
            except SessionError as exc:
                # Session invalid/expired, continue without session
                session = None
                logger.debug(
                    "Session failed to load; token invalid/expired (%s)",
                    type(exc).__name__,
                )

        # setattr with a string literal is deliberate: inside a class body a
        # dotted ``request.state.__fastapi_cachex_session`` would be
        # name-mangled, and readers look the attribute up by this exact name.
        setattr(request.state, "__fastapi_cachex_session", session)

        response: Response = await call_next(request)

        return response

    def _extract_token(self, request: Request) -> str | None:
        """Extract session token from request.

        Sources are tried in ``config.token_source_priority`` order and the
        first non-empty token wins. The ``Authorization`` scheme is matched
        case-insensitively, and an empty bearer credential is skipped so that
        lower-priority sources are still tried.

        Args:
            request: Incoming request

        Returns:
            Session token or None
        """
        for source in self.config.token_source_priority:
            if source == "header":
                token = request.headers.get(self.config.header_name)
                if token:
                    logger.debug("Token extracted from header")
                    return token

            elif source == "bearer" and self.config.use_bearer_token:
                auth_header = request.headers.get("authorization")
                if not auth_header:
                    continue
                scheme, _, credentials = auth_header.partition(" ")
                credentials = credentials.strip()
                if scheme.lower() == "bearer" and credentials:
                    logger.debug("Token extracted from bearer auth")
                    return credentials

        return None

    def _get_client_ip(self, request: Request) -> str | None:
        """Get client IP address from request.

        When ``trust_proxy_headers`` is enabled, the first non-empty entry of
        ``X-Forwarded-For`` is preferred, then ``X-Real-IP``. Otherwise, or
        when neither header yields a value, the connection's peer address is
        used.

        Args:
            request: Incoming request

        Returns:
            Client IP address or None
        """
        if self.trust_proxy_headers:
            # Check X-Forwarded-For header (for proxied requests)
            forwarded_for = request.headers.get("x-forwarded-for")
            if forwarded_for:
                # Get first IP from comma-separated list
                ip = forwarded_for.split(",")[0].strip()
                if ip:
                    logger.debug("Client IP from X-Forwarded-For: %s", ip)
                    return ip

            # Check X-Real-IP header
            real_ip = request.headers.get("x-real-ip")
            if real_ip:
                logger.debug("Client IP from X-Real-IP: %s", real_ip)
                return real_ip

        # Fallback to direct client IP
        if request.client:
            logger.debug("Client IP from connection: %s", request.client.host)
            return request.client.host

        return None