fastapi-cachex 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastapi_cachex/__init__.py +20 -0
- fastapi_cachex/backends/__init__.py +4 -4
- fastapi_cachex/backends/memcached.py +21 -2
- fastapi_cachex/backends/memory.py +33 -5
- fastapi_cachex/backends/redis.py +29 -6
- fastapi_cachex/cache.py +59 -19
- fastapi_cachex/dependencies.py +2 -2
- fastapi_cachex/proxy.py +9 -2
- fastapi_cachex/routes.py +6 -5
- fastapi_cachex/session/__init__.py +21 -0
- fastapi_cachex/session/config.py +70 -0
- fastapi_cachex/session/dependencies.py +65 -0
- fastapi_cachex/session/exceptions.py +25 -0
- fastapi_cachex/session/manager.py +389 -0
- fastapi_cachex/session/middleware.py +149 -0
- fastapi_cachex/session/models.py +185 -0
- fastapi_cachex/session/security.py +111 -0
- fastapi_cachex/state/__init__.py +8 -0
- fastapi_cachex/state/exceptions.py +19 -0
- fastapi_cachex/state/manager.py +258 -0
- fastapi_cachex/state/models.py +31 -0
- fastapi_cachex/types.py +9 -0
- {fastapi_cachex-0.2.1.dist-info → fastapi_cachex-0.2.3.dist-info}/METADATA +23 -5
- fastapi_cachex-0.2.3.dist-info/RECORD +29 -0
- fastapi_cachex-0.2.1.dist-info/RECORD +0 -17
- {fastapi_cachex-0.2.1.dist-info → fastapi_cachex-0.2.3.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
"""Session data models and user structures."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from datetime import timedelta
|
|
6
|
+
from datetime import timezone
|
|
7
|
+
from enum import Enum
|
|
8
|
+
from typing import Any
|
|
9
|
+
from uuid import uuid4
|
|
10
|
+
|
|
11
|
+
from pydantic import BaseModel
|
|
12
|
+
from pydantic import Field
|
|
13
|
+
|
|
14
|
+
logger: logging.Logger = logging.getLogger(__name__)

# A serialized SessionToken is "{session_id}.{signature}.{timestamp}",
# i.e. exactly three dot-separated components.
TOKEN_PARTS_COUNT: int = 3
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class SessionStatus(str, Enum):
    """Enumerate the lifecycle states a session can be in.

    Members subclass ``str``, so each one compares equal to its plain string
    value (for example ``SessionStatus.ACTIVE == "active"``).
    """

    # The session may be used.
    ACTIVE = "active"
    # The session's expiry time has passed (not assigned by Session itself).
    EXPIRED = "expired"
    # The session was explicitly ended, e.g. via Session.invalidate().
    INVALIDATED = "invalidated"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class SessionUser(BaseModel):
    """Identity of the user attached to a session.

    Only ``user_id`` is required. The model is configured with
    ``extra="allow"``, so unknown fields passed at construction are kept on
    the instance. Applications can also subclass it to declare their own
    user fields.
    """

    # Mandatory identifier of the user.
    user_id: str
    # Optional descriptive details.
    username: str | None = None
    email: str | None = None
    # Authorization data; default_factory gives every instance its own list.
    roles: list[str] = Field(default_factory=list)
    permissions: list[str] = Field(default_factory=list)
    # Free-form application data; every instance gets its own dict.
    metadata: dict[str, Any] = Field(default_factory=dict)

    # Keep fields that are not declared above instead of rejecting them.
    model_config = {"extra": "allow"}
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def _as_utc(moment: datetime) -> datetime:
    """Return ``moment`` as a timezone-aware datetime.

    Naive datetimes are interpreted as UTC, matching how this module creates
    every timestamp (``datetime.now(timezone.utc)``). Aware values are
    returned unchanged.
    """
    if moment.tzinfo is None or moment.utcoffset() is None:
        return moment.replace(tzinfo=timezone.utc)
    return moment


class Session(BaseModel):
    """Core session model containing all session data.

    Timestamps generated by this model are timezone-aware UTC. An
    ``expires_at`` of ``None`` means the session has no expiry time.
    """

    session_id: str = Field(default_factory=lambda: str(uuid4()))
    user: SessionUser | None = None
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    last_accessed: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    expires_at: datetime | None = None
    status: SessionStatus = SessionStatus.ACTIVE
    ip_address: str | None = None
    user_agent: str | None = None
    data: dict[str, Any] = Field(default_factory=dict)
    flash_messages: list[dict[str, Any]] = Field(default_factory=list)

    # Validated ``status`` values are stored as plain strings. SessionStatus
    # subclasses str, so comparisons against its members still work.
    model_config = {"use_enum_values": True}

    def is_valid(self) -> bool:
        """Check if session is valid (active and not expired)."""
        return self.status == SessionStatus.ACTIVE and not self.is_expired()

    def is_expired(self) -> bool:
        """Check if session has expired.

        Returns:
            False when ``expires_at`` is None. Otherwise, True if the current
            UTC time is later than ``expires_at``. A naive ``expires_at`` is
            treated as UTC rather than raising ``TypeError`` when compared
            with the aware current time.
        """
        if self.expires_at is None:
            return False
        return datetime.now(timezone.utc) > _as_utc(self.expires_at)

    def update_last_accessed(self) -> None:
        """Set ``last_accessed`` to the current UTC time."""
        self.last_accessed = datetime.now(timezone.utc)
        logger.debug("Session last_accessed updated; id=%s", self.session_id)

    def renew(self, ttl: int) -> None:
        """Renew session expiry time.

        Sets ``expires_at`` to now + ``ttl`` and also refreshes
        ``last_accessed``.

        Args:
            ttl: Time-to-live in seconds
        """
        self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=ttl)
        self.update_last_accessed()
        logger.debug("Session renewed; id=%s ttl=%s", self.session_id, ttl)

    def invalidate(self) -> None:
        """Mark session as invalidated (``is_valid`` then returns False)."""
        self.status = SessionStatus.INVALIDATED
        logger.debug("Session invalidated; id=%s", self.session_id)

    def regenerate_id(self) -> str:
        """Regenerate session ID (for security after login).

        Replaces ``session_id`` with a fresh random UUID4 string. All other
        session data is kept.

        Returns:
            The new session ID
        """
        old_id = self.session_id
        self.session_id = str(uuid4())
        logger.debug(
            "Session ID regenerated; old_id=%s new_id=%s", old_id, self.session_id
        )
        return self.session_id

    def add_flash_message(self, message: str, category: str = "info") -> None:
        """Add a flash message.

        Args:
            message: The message content
            category: Message category (info, success, warning, error)
        """
        self.flash_messages.append(
            {
                "message": message,
                "category": category,
                "timestamp": datetime.now(timezone.utc).isoformat(),
            }
        )
        logger.debug(
            "Flash message added; id=%s category=%s", self.session_id, category
        )

    def get_flash_messages(self, clear: bool = True) -> list[dict[str, Any]]:
        """Get and optionally clear flash messages.

        Args:
            clear: Whether to clear messages after retrieving

        Returns:
            A shallow copy of the flash message list, taken before clearing
        """
        messages = self.flash_messages.copy()
        if clear:
            self.flash_messages.clear()
            logger.debug(
                "Flash messages cleared; id=%s count=%s", self.session_id, len(messages)
            )
        return messages
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
class SessionToken(BaseModel):
    """Session token containing signed data.

    Serialized as ``{session_id}.{signature}.{timestamp}``, where
    ``timestamp`` is ``issued_at`` in whole seconds since the Unix epoch.
    """

    session_id: str
    signature: str
    issued_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))

    def to_string(self) -> str:
        """Convert token to string format.

        Format: {session_id}.{signature}.{timestamp}

        Sub-second precision of ``issued_at`` is truncated by ``int()``.
        """
        timestamp = int(self.issued_at.timestamp())

        logger.debug("SessionToken to_string called; id=%s", self.session_id)
        return f"{self.session_id}.{self.signature}.{timestamp}"

    @classmethod
    def from_string(cls, token_str: str) -> "SessionToken":
        """Parse token from string format.

        Args:
            token_str: Token string in format {session_id}.{signature}.{timestamp}

        Returns:
            SessionToken instance, with ``issued_at`` as an aware UTC datetime

        Raises:
            ValueError: If token format is invalid. This covers a wrong number
                of parts, an empty session ID or signature, and a timestamp
                that is not an integer or is outside the representable range.
        """
        parts = token_str.split(".")
        if len(parts) != TOKEN_PARTS_COUNT:
            msg = "Invalid token format"
            raise ValueError(msg)

        session_id, signature, timestamp = parts
        if not session_id or not signature:
            msg = "Invalid token format"
            raise ValueError(msg)

        try:
            issued_at = datetime.fromtimestamp(int(timestamp), tz=timezone.utc)
        except (ValueError, OverflowError, OSError) as e:
            # Very large integers make fromtimestamp raise OverflowError.
            # Convert it so malformed input still surfaces as the documented
            # ValueError.
            msg = f"Invalid timestamp in token: {e}"
            raise ValueError(msg) from e

        logger.debug("SessionToken parsed from string; id=%s", session_id)
        return cls(session_id=session_id, signature=signature, issued_at=issued_at)
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
"""Security utilities for session management."""
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import hmac
|
|
5
|
+
import logging
|
|
6
|
+
|
|
7
|
+
from .models import Session
|
|
8
|
+
|
|
9
|
+
logger: logging.Logger = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class SecurityManager:
    """Handles session security operations.

    Provides HMAC-SHA256 signing and verification of session IDs, optional
    IP-address / User-Agent binding checks, and plain SHA-256 hashing.
    """

    # Minimum accepted length, in characters, of the signing secret.
    _MIN_SECRET_KEY_LENGTH = 32

    def __init__(self, secret_key: str) -> None:
        """Initialize security manager.

        Args:
            secret_key: Secret key for signing tokens

        Raises:
            ValueError: If ``secret_key`` is shorter than 32 characters
        """
        if len(secret_key) < self._MIN_SECRET_KEY_LENGTH:
            msg = f"Secret key must be at least {self._MIN_SECRET_KEY_LENGTH} characters"
            raise ValueError(msg)
        self.secret_key = secret_key.encode("utf-8")

        logger.debug(
            "SecurityManager initialized with secret length=%s", len(secret_key)
        )

    def sign_session_id(self, session_id: str) -> str:
        """Sign a session ID using HMAC-SHA256.

        Args:
            session_id: The session ID to sign

        Returns:
            The signature as a hex string
        """
        return hmac.new(
            self.secret_key,
            session_id.encode("utf-8"),
            hashlib.sha256,
        ).hexdigest()

    def verify_signature(self, session_id: str, signature: str) -> bool:
        """Verify a session signature.

        Args:
            session_id: The session ID
            signature: The signature to verify (may be arbitrary client input)

        Returns:
            True if signature is valid, False otherwise
        """
        expected_signature = self.sign_session_id(session_id)
        # Use constant-time comparison to prevent timing attacks. Compare bytes:
        # hmac.compare_digest raises TypeError for str arguments containing
        # non-ASCII characters, which would turn a malformed signature into a
        # crash instead of a plain False.
        valid = hmac.compare_digest(
            expected_signature.encode("utf-8"),
            signature.encode("utf-8"),
        )

        if not valid:
            logger.debug("Signature verification failed; id=%s", session_id)
        return valid

    def check_ip_match(self, session: Session, current_ip: str | None) -> bool:
        """Check if session IP matches current request IP.

        Args:
            session: The session to check
            current_ip: Current request IP address

        Returns:
            True if IPs match or session has no IP binding. False if the
            session is bound but ``current_ip`` is missing or different.
        """
        return session.ip_address is None or session.ip_address == current_ip

    def check_user_agent_match(
        self,
        session: Session,
        current_user_agent: str | None,
    ) -> bool:
        """Check if session User-Agent matches current request.

        Args:
            session: The session to check
            current_user_agent: Current request User-Agent

        Returns:
            True if User-Agents match or session has no UA binding. False if
            the session is bound but the current value is missing or different.
        """
        return (
            session.user_agent is None or session.user_agent == current_user_agent
        )

    def hash_data(self, data: str) -> str:
        """Hash data using SHA-256.

        Args:
            data: Data to hash (encoded as UTF-8)

        Returns:
            Hex digest of the hash
        """
        digest = hashlib.sha256(data.encode("utf-8")).hexdigest()

        logger.debug("Data hashed for session operations")
        return digest
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
"""State management extension for FastAPI-CacheX."""
|
|
2
|
+
|
|
3
|
+
from .exceptions import InvalidStateError as InvalidStateError
|
|
4
|
+
from .exceptions import StateDataError as StateDataError
|
|
5
|
+
from .exceptions import StateError as StateError
|
|
6
|
+
from .exceptions import StateExpiredError as StateExpiredError
|
|
7
|
+
from .manager import StateManager as StateManager
|
|
8
|
+
from .models import StateData as StateData
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Custom exception classes for state management."""
|
|
2
|
+
|
|
3
|
+
from fastapi_cachex.exceptions import CacheXError
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class StateError(CacheXError):
    """Root of the state-management exception hierarchy.

    Catch this to handle any of the more specific state errors below.
    """
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class InvalidStateError(StateError):
    """Signals that a state value is invalid or cannot be found in storage."""
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class StateExpiredError(StateError):
    """Signals that a state's expiry time has already passed."""
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class StateDataError(StateError):
    """Signals that stored state data could not be parsed or is malformed."""
|
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
"""State manager for OAuth and session state handling."""
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import json
|
|
5
|
+
import logging
|
|
6
|
+
import secrets
|
|
7
|
+
from datetime import datetime
|
|
8
|
+
from datetime import timedelta
|
|
9
|
+
from datetime import timezone
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from fastapi_cachex.proxy import BackendProxy
|
|
13
|
+
from fastapi_cachex.types import ETagContent
|
|
14
|
+
|
|
15
|
+
from .exceptions import InvalidStateError
|
|
16
|
+
from .exceptions import StateDataError
|
|
17
|
+
from .exceptions import StateExpiredError
|
|
18
|
+
from .models import StateData
|
|
19
|
+
|
|
20
|
+
logger: logging.Logger = logging.getLogger(__name__)

# Default TTL for OAuth state, in seconds (10 minutes)
DEFAULT_STATE_TTL = 10 * 60
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class StateManager:
    """Manages OAuth state and session state lifecycle and storage.

    Each state is stored in the cache backend under ``{key_prefix}{state}``
    as an ``ETagContent`` whose ``content`` is the JSON form of a
    ``StateData`` model.
    """

    def __init__(
        self, key_prefix: str = "oauth_state:", default_ttl: int = DEFAULT_STATE_TTL
    ) -> None:
        """Initialize StateManager.

        Args:
            key_prefix: Prefix for state keys in cache backend
            default_ttl: Default time-to-live in seconds for state
        """
        self.backend = BackendProxy.get_backend()
        self.key_prefix = key_prefix
        self.default_ttl = default_ttl

    def _cache_key(self, state: str) -> str:
        """Return the backend key under which ``state`` is stored."""
        return f"{self.key_prefix}{state}"

    async def create_state(
        self,
        ttl: int | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """Create a new random OAuth state and store it with metadata.

        Errors raised by the cache backend while storing are not caught; they
        propagate to the caller unchanged.

        Args:
            ttl: Time-to-live in seconds (uses default_ttl if not provided)
            metadata: Additional metadata to store with the state (e.g., callback_url, user_info)

        Returns:
            The generated state string
        """
        # Generate a random state string (32 bytes = 256 bits of entropy)
        state = secrets.token_urlsafe(32)

        # Use provided TTL or default
        effective_ttl = ttl if ttl is not None else self.default_ttl

        # Create state data model
        state_data = StateData(
            state=state,
            expires_at=datetime.now(timezone.utc) + timedelta(seconds=effective_ttl),
            metadata=metadata or {},
        )

        # Serialize to JSON
        json_content = json.dumps(state_data.model_dump(mode="json"))

        # Create ETag from hash of state data
        etag = hashlib.sha256(json_content.encode()).hexdigest()

        # Store in backend with TTL using ETagContent
        cache_key = self._cache_key(state)
        etag_content = ETagContent(etag=etag, content=json_content)
        await self.backend.set(cache_key, etag_content, ttl=effective_ttl)

        logger.debug("OAuth state created; state=%s ttl=%s", state, effective_ttl)
        return state

    async def consume_state(self, state: str) -> StateData:
        """Consume and validate an OAuth state, removing it from storage.

        The stored entry is deleted as soon as it has been read. A state can
        therefore be consumed at most once, even if its data turns out to be
        malformed or expired.

        Args:
            state: The state string to validate and consume

        Returns:
            StateData object containing state data and metadata

        Raises:
            InvalidStateError: If state is invalid or not found
            StateExpiredError: If state has expired
            StateDataError: If state data format is invalid
        """
        cache_key = self._cache_key(state)

        # Retrieve state data from backend
        cached_etag_content = await self.backend.get(cache_key)
        if cached_etag_content is None:
            logger.warning("OAuth state not found or expired; state=%s", state)
            msg = "Invalid or expired state"
            raise InvalidStateError(msg)

        # Delete immediately after reading to keep the replay window as small
        # as possible. Only get/delete are used, so this is still not atomic:
        # two concurrent callers can both read the entry before either deletes.
        await self.backend.delete(cache_key)

        # Extract content from ETagContent
        json_content = cached_etag_content.content
        if not isinstance(json_content, str):
            msg = "Unexpected state data format"
            logger.error(
                "Unexpected content type in state; state=%s type=%s",
                state,
                type(json_content),
            )
            raise StateDataError(msg)

        # Parse the stored state data
        try:
            raw_state: Any = json.loads(json_content)
        except json.JSONDecodeError as e:
            msg = f"Failed to parse state data: {e}"
            logger.exception("Failed to parse state data; state=%s", state)
            raise StateDataError(msg) from e

        # model_validate also rejects non-object JSON (e.g. a list) with a
        # ValidationError (a ValueError). ``StateData(**raw_state)`` would
        # instead raise an uncaught TypeError.
        try:
            state_data = StateData.model_validate(raw_state)
        except ValueError as e:
            msg = f"Invalid state data structure: {e}"
            logger.exception("Failed to create StateData model; state=%s", state)
            raise StateDataError(msg) from e

        # Verify expiry
        if datetime.now(timezone.utc) > state_data.expires_at:
            logger.warning("OAuth state expired; state=%s", state)
            msg = "State has expired"
            raise StateExpiredError(msg)

        logger.debug("OAuth state consumed and deleted; state=%s", state)
        return state_data

    async def validate_state(self, state: str) -> bool:
        """Validate if a state exists and is not expired (without consuming it).

        Args:
            state: The state string to validate

        Returns:
            True if state is valid and not expired, False otherwise
        """
        cache_key = self._cache_key(state)

        # Try to retrieve state data from backend
        cached_etag_content = await self.backend.get(cache_key)
        if cached_etag_content is None:
            logger.debug("State validation failed - not found; state=%s", state)
            return False

        # Extract content from ETagContent
        json_content = cached_etag_content.content
        if not isinstance(json_content, str):
            logger.error(
                "Unexpected content type in state; state=%s type=%s",
                state,
                type(json_content),
            )
            return False

        try:
            raw_state: Any = json.loads(json_content)
        except json.JSONDecodeError:
            logger.exception(
                "Failed to parse state data during validation; state=%s",
                state,
            )
            return False

        # Validate and create StateData model (non-dict JSON -> ValueError)
        try:
            state_data = StateData.model_validate(raw_state)
        except ValueError:
            logger.exception(
                "Failed to create StateData model during validation; state=%s",
                state,
            )
            return False

        # Check expiry
        if datetime.now(timezone.utc) > state_data.expires_at:
            logger.debug("State validation failed - expired; state=%s", state)
            return False

        logger.debug("State validation succeeded; state=%s", state)
        return True

    async def get_state_metadata(self, state: str) -> dict[str, Any] | None:
        """Retrieve metadata for a state without consuming it.

        Args:
            state: The state string

        Returns:
            Metadata dictionary if state exists and is valid, None otherwise
        """
        cache_key = self._cache_key(state)

        cached_etag_content = await self.backend.get(cache_key)
        if cached_etag_content is None:
            return None

        # Extract content from ETagContent
        json_content = cached_etag_content.content
        if not isinstance(json_content, str):
            logger.error(
                "Unexpected content type in state; state=%s type=%s",
                state,
                type(json_content),
            )
            return None

        try:
            raw_state: Any = json.loads(json_content)
        except json.JSONDecodeError:
            logger.exception("Failed to parse state data; state=%s", state)
            return None

        # Validate and create StateData model (non-dict JSON -> ValueError)
        try:
            state_data = StateData.model_validate(raw_state)
        except ValueError:
            logger.exception("Failed to create StateData model; state=%s", state)
            return None

        # Check expiry
        if datetime.now(timezone.utc) > state_data.expires_at:
            return None

        return state_data.metadata

    async def delete_state(self, state: str) -> bool:
        """Manually delete a state from storage.

        Args:
            state: The state string to delete

        Returns:
            True if state was deleted, False if it didn't exist
        """
        cache_key = self._cache_key(state)
        # Look the entry up first so the return value reflects whether it
        # existed. Delete unconditionally in case it appeared in between.
        existed = await self.backend.get(cache_key) is not None
        await self.backend.delete(cache_key)
        if existed:
            logger.debug("OAuth state deleted; state=%s", state)
        return existed
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""Data models for state management."""
|
|
2
|
+
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from datetime import timezone
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel
|
|
8
|
+
from pydantic import ConfigDict
|
|
9
|
+
from pydantic import Field
|
|
10
|
+
from pydantic import field_serializer
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class StateData(BaseModel):
    """Record stored for a single OAuth state value.

    In JSON-mode dumps, both timestamps are written as ISO 8601 strings
    produced by ``datetime.isoformat()``.
    """

    # Empty configuration: pydantic defaults apply.
    model_config = ConfigDict()

    # Required fields: the state identifier and its expiry time.
    state: str = Field(..., description="The unique state identifier")
    # Defaults to the current aware UTC time when not supplied.
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
        description="When the state was created",
    )
    expires_at: datetime = Field(..., description="When the state expires")
    # Arbitrary caller data; each instance gets its own dict by default.
    metadata: dict[str, Any] = Field(
        default_factory=dict, description="Additional metadata associated with state"
    )

    @field_serializer("created_at", "expires_at", when_used="json")
    def serialize_datetime(self, value: datetime) -> str:
        """Render a timestamp as an ISO 8601 string in JSON-mode output.

        Python-mode dumps are unaffected and keep ``datetime`` objects.
        """
        iso_text = value.isoformat()
        return iso_text
|
fastapi_cachex/types.py
CHANGED
|
@@ -1,8 +1,17 @@
|
|
|
1
1
|
"""Type definitions and type aliases for FastAPI-CacheX."""
|
|
2
2
|
|
|
3
|
+
from collections.abc import Callable
|
|
3
4
|
from dataclasses import dataclass
|
|
4
5
|
from typing import Any
|
|
5
6
|
|
|
7
|
+
from fastapi import Request
|
|
8
|
+
|
|
9
|
+
# Cache key separator - using ||| to avoid conflicts with port numbers in host (e.g., 127.0.0.1:8000)
|
|
10
|
+
CACHE_KEY_SEPARATOR: str = "|||"
|
|
11
|
+
|
|
12
|
+
# Type for custom cache key builder function
|
|
13
|
+
CacheKeyBuilder = Callable[[Request], str]  # takes the incoming Request, returns the cache key string
|
|
14
|
+
|
|
6
15
|
|
|
7
16
|
@dataclass
|
|
8
17
|
class ETagContent:
|