fastapi-cachex 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,185 @@
1
+ """Session data models and user structures."""
2
+
3
+ import logging
4
+ from datetime import datetime
5
+ from datetime import timedelta
6
+ from datetime import timezone
7
+ from enum import Enum
8
+ from typing import Any
9
+ from uuid import uuid4
10
+
11
+ from pydantic import BaseModel
12
+ from pydantic import Field
13
+
14
logger = logging.getLogger(__name__)

# How many "."-separated segments a serialized session token has:
# session_id, signature and timestamp.
TOKEN_PARTS_COUNT = 3
18
+
19
+
20
class SessionStatus(str, Enum):
    """Lifecycle states a session can be in.

    Members subclass ``str``, so they compare equal to their raw string values.
    """

    # Session may be used.
    ACTIVE = "active"
    # Session has passed its expiry time.
    EXPIRED = "expired"
    # Session was explicitly ended.
    INVALIDATED = "invalidated"
26
+
27
+
28
class SessionUser(BaseModel):
    """User identity attached to a session.

    Applications may subclass this with their own fields. Unknown keys are
    also kept on the instance because ``extra`` is set to ``"allow"``.
    """

    model_config = {"extra": "allow"}

    # Required identifier of the user.
    user_id: str
    # Optional profile attributes.
    username: str | None = Field(default=None)
    email: str | None = Field(default=None)
    # Authorization data; every instance gets its own fresh list.
    roles: list[str] = Field(default_factory=list)
    permissions: list[str] = Field(default_factory=list)
    # Free-form, application-defined extras.
    metadata: dict[str, Any] = Field(default_factory=dict)
42
+
43
+
44
class Session(BaseModel):
    """Core session model containing all session data.

    Timestamps generated by this model (defaults, ``renew``) are
    timezone-aware UTC datetimes.
    """

    session_id: str = Field(default_factory=lambda: str(uuid4()))
    user: SessionUser | None = None
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    last_accessed: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    expires_at: datetime | None = None
    status: SessionStatus = SessionStatus.ACTIVE
    ip_address: str | None = None
    user_agent: str | None = None
    data: dict[str, Any] = Field(default_factory=dict)
    flash_messages: list[dict[str, Any]] = Field(default_factory=list)

    model_config = {"use_enum_values": True}

    def is_valid(self) -> bool:
        """Check if session is valid (active and not expired).

        Returns:
            True when ``status`` is ACTIVE and :meth:`is_expired` is False.
        """
        # With use_enum_values, ``status`` may be a plain string; SessionStatus
        # subclasses str, so this comparison works for both representations.
        if self.status != SessionStatus.ACTIVE:
            return False
        return not self.is_expired()

    def is_expired(self) -> bool:
        """Check if session has expired.

        Returns:
            False when no ``expires_at`` is set, otherwise whether the current
            UTC time is past ``expires_at``.
        """
        expires_at = self.expires_at
        if expires_at is None:
            return False
        if expires_at.tzinfo is None:
            # NOTE: naive timestamps are assumed to be UTC. Comparing a naive
            # datetime with the aware "now" below would raise TypeError.
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        return datetime.now(timezone.utc) > expires_at

    def update_last_accessed(self) -> None:
        """Update the last accessed timestamp to the current UTC time."""
        self.last_accessed = datetime.now(timezone.utc)
        logger.debug("Session last_accessed updated; id=%s", self.session_id)

    def renew(self, ttl: int) -> None:
        """Renew session expiry time.

        Sets ``expires_at`` to now + ``ttl`` seconds and refreshes
        ``last_accessed``.

        Args:
            ttl: Time-to-live in seconds
        """
        self.expires_at = datetime.now(timezone.utc) + timedelta(seconds=ttl)
        self.update_last_accessed()
        logger.debug("Session renewed; id=%s ttl=%s", self.session_id, ttl)

    def invalidate(self) -> None:
        """Mark session as invalidated."""
        self.status = SessionStatus.INVALIDATED
        logger.debug("Session invalidated; id=%s", self.session_id)

    def regenerate_id(self) -> str:
        """Regenerate session ID (for security after login).

        Only the in-memory ``session_id`` changes. Persisting the change is
        left to the caller.

        Returns:
            The new session ID
        """
        old_id = self.session_id
        self.session_id = str(uuid4())
        logger.debug(
            "Session ID regenerated; old_id=%s new_id=%s", old_id, self.session_id
        )
        return self.session_id

    def add_flash_message(self, message: str, category: str = "info") -> None:
        """Add a flash message.

        Args:
            message: The message content
            category: Message category (info, success, warning, error)
        """
        self.flash_messages.append(
            {
                "message": message,
                "category": category,
                "timestamp": datetime.now(timezone.utc).isoformat(),
            }
        )
        logger.debug(
            "Flash message added; id=%s category=%s", self.session_id, category
        )

    def get_flash_messages(self, clear: bool = True) -> list[dict[str, Any]]:
        """Get and optionally clear flash messages.

        Args:
            clear: Whether to clear messages after retrieving

        Returns:
            A shallow copy of the list of flash messages
        """
        messages = self.flash_messages.copy()
        if clear:
            self.flash_messages.clear()
            logger.debug(
                "Flash messages cleared; id=%s count=%s", self.session_id, len(messages)
            )
        return messages
140
+
141
+
142
class SessionToken(BaseModel):
    """Session token containing signed data."""

    session_id: str
    signature: str
    issued_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))

    def to_string(self) -> str:
        """Convert token to string format.

        Format: {session_id}.{signature}.{timestamp}, where timestamp is
        ``issued_at`` as whole Unix seconds (sub-second precision is dropped).
        """
        timestamp = int(self.issued_at.timestamp())

        logger.debug("SessionToken to_string called; id=%s", self.session_id)
        return f"{self.session_id}.{self.signature}.{timestamp}"

    @classmethod
    def from_string(cls, token_str: str) -> "SessionToken":
        """Parse token from string format.

        Args:
            token_str: Token string in format {session_id}.{signature}.{timestamp}

        Returns:
            SessionToken instance with a UTC ``issued_at``

        Raises:
            ValueError: If token format is invalid (wrong segment count, empty
                session_id or signature, or an unparseable/out-of-range
                timestamp)
        """
        parts = token_str.split(".")
        if len(parts) != TOKEN_PARTS_COUNT:
            msg = "Invalid token format"
            raise ValueError(msg)

        session_id, signature, timestamp = parts
        # Reject tokens such as "..123" whose id or signature segment is empty.
        if not session_id or not signature:
            msg = "Invalid token format"
            raise ValueError(msg)

        try:
            issued_at = datetime.fromtimestamp(int(timestamp), tz=timezone.utc)
        except (ValueError, OSError, OverflowError) as e:
            # OverflowError is raised for integers beyond the platform time_t
            # range; report it as a ValueError like other bad timestamps.
            msg = f"Invalid timestamp in token: {e}"
            raise ValueError(msg) from e

        logger.debug("SessionToken parsed from string; id=%s", session_id)
        return cls(session_id=session_id, signature=signature, issued_at=issued_at)
@@ -0,0 +1,111 @@
1
+ """Security utilities for session management."""
2
+
3
+ import hashlib
4
+ import hmac
5
+ import logging
6
+
7
+ from .models import Session
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
+ class SecurityManager:
13
+ """Handles session security operations."""
14
+
15
+ def __init__(self, secret_key: str) -> None:
16
+ """Initialize security manager.
17
+
18
+ Args:
19
+ secret_key: Secret key for signing tokens
20
+ """
21
+ if len(secret_key) < 32: # noqa: PLR2004
22
+ msg = "Secret key must be at least 32 characters"
23
+ raise ValueError(msg)
24
+ self.secret_key = secret_key.encode("utf-8")
25
+
26
+ logger.debug(
27
+ "SecurityManager initialized with secret length=%s", len(secret_key)
28
+ )
29
+
30
+ def sign_session_id(self, session_id: str) -> str:
31
+ """Sign a session ID using HMAC-SHA256.
32
+
33
+ Args:
34
+ session_id: The session ID to sign
35
+
36
+ Returns:
37
+ The signature as a hex string
38
+ """
39
+ return hmac.new(
40
+ self.secret_key,
41
+ session_id.encode("utf-8"),
42
+ hashlib.sha256,
43
+ ).hexdigest()
44
+
45
+ def verify_signature(self, session_id: str, signature: str) -> bool:
46
+ """Verify a session signature.
47
+
48
+ Args:
49
+ session_id: The session ID
50
+ signature: The signature to verify
51
+
52
+ Returns:
53
+ True if signature is valid, False otherwise
54
+ """
55
+ expected_signature = self.sign_session_id(session_id)
56
+ # Use constant-time comparison to prevent timing attacks
57
+ valid = hmac.compare_digest(expected_signature, signature)
58
+
59
+ if not valid:
60
+ logger.debug("Signature verification failed; id=%s", session_id)
61
+ return valid
62
+
63
+ def check_ip_match(self, session: Session, current_ip: str | None) -> bool:
64
+ """Check if session IP matches current request IP.
65
+
66
+ Args:
67
+ session: The session to check
68
+ current_ip: Current request IP address
69
+
70
+ Returns:
71
+ True if IPs match or session has no IP binding
72
+ """
73
+ if session.ip_address is None:
74
+ return True
75
+ if current_ip is None:
76
+ return False
77
+ return session.ip_address == current_ip
78
+
79
+ def check_user_agent_match(
80
+ self,
81
+ session: Session,
82
+ current_user_agent: str | None,
83
+ ) -> bool:
84
+ """Check if session User-Agent matches current request.
85
+
86
+ Args:
87
+ session: The session to check
88
+ current_user_agent: Current request User-Agent
89
+
90
+ Returns:
91
+ True if User-Agents match or session has no UA binding
92
+ """
93
+ if session.user_agent is None:
94
+ return True
95
+ if current_user_agent is None:
96
+ return False
97
+ return session.user_agent == current_user_agent
98
+
99
+ def hash_data(self, data: str) -> str:
100
+ """Hash data using SHA-256.
101
+
102
+ Args:
103
+ data: Data to hash
104
+
105
+ Returns:
106
+ Hex digest of the hash
107
+ """
108
+ digest = hashlib.sha256(data.encode("utf-8")).hexdigest()
109
+
110
+ logger.debug("Data hashed for session operations")
111
+ return digest
@@ -0,0 +1,8 @@
1
+ """State management extension for FastAPI-CacheX."""
2
+
3
+ from .exceptions import InvalidStateError as InvalidStateError
4
+ from .exceptions import StateDataError as StateDataError
5
+ from .exceptions import StateError as StateError
6
+ from .exceptions import StateExpiredError as StateExpiredError
7
+ from .manager import StateManager as StateManager
8
+ from .models import StateData as StateData
@@ -0,0 +1,19 @@
1
+ """Custom exception classes for state management."""
2
+
3
+ from fastapi_cachex.exceptions import CacheXError
4
+
5
+
6
class StateError(CacheXError):
    """Root of the state-management exception hierarchy."""


class InvalidStateError(StateError):
    """A state value is invalid or could not be found."""


class StateExpiredError(StateError):
    """A state value was found but is past its expiry time."""


class StateDataError(StateError):
    """Stored state data could not be parsed or has an invalid format."""
@@ -0,0 +1,258 @@
1
+ """State manager for OAuth and session state handling."""
2
+
3
+ import hashlib
4
+ import json
5
+ import logging
6
+ import secrets
7
+ from datetime import datetime
8
+ from datetime import timedelta
9
+ from datetime import timezone
10
+ from typing import Any
11
+
12
+ from fastapi_cachex.proxy import BackendProxy
13
+ from fastapi_cachex.types import ETagContent
14
+
15
+ from .exceptions import InvalidStateError
16
+ from .exceptions import StateDataError
17
+ from .exceptions import StateExpiredError
18
+ from .models import StateData
19
+
20
logger = logging.getLogger(__name__)

# Lifetime, in seconds, of an OAuth state when no TTL is given: ten minutes.
DEFAULT_STATE_TTL = 10 * 60
24
+
25
+
26
class StateManager:
    """Manages OAuth state and session state lifecycle and storage.

    Each state is stored in the cache backend under ``{key_prefix}{state}`` as
    an ``ETagContent`` whose ``content`` is the JSON-serialized ``StateData``.
    """

    def __init__(
        self, key_prefix: str = "oauth_state:", default_ttl: int = DEFAULT_STATE_TTL
    ) -> None:
        """Initialize StateManager.

        Args:
            key_prefix: Prefix for state keys in cache backend
            default_ttl: Default time-to-live in seconds for state
        """
        self.backend = BackendProxy.get_backend()
        self.key_prefix = key_prefix
        self.default_ttl = default_ttl

    def _cache_key(self, state: str) -> str:
        """Build the backend key under which ``state`` is stored."""
        return f"{self.key_prefix}{state}"

    @staticmethod
    def _is_expired(state_data: StateData) -> bool:
        """Return True if the current UTC time is past ``state_data.expires_at``."""
        return datetime.now(timezone.utc) > state_data.expires_at

    @staticmethod
    def _parse_cached_state(state: str, cached_etag_content: ETagContent) -> StateData:
        """Decode and validate the cached payload stored for ``state``.

        Args:
            state: The state string (used for logging only)
            cached_etag_content: The entry returned by the backend

        Returns:
            The validated StateData

        Raises:
            StateDataError: If the content is not a string, is not valid JSON,
                or does not match the StateData schema
        """
        json_content = cached_etag_content.content
        if not isinstance(json_content, str):
            logger.error(
                "Unexpected content type in state; state=%s type=%s",
                state,
                type(json_content),
            )
            msg = "Unexpected state data format"
            raise StateDataError(msg)

        try:
            state_dict: Any = json.loads(json_content)
        except json.JSONDecodeError as e:
            logger.exception("Failed to parse state data; state=%s", state)
            msg = f"Failed to parse state data: {e}"
            raise StateDataError(msg) from e

        # Use model_validate rather than StateData(**state_dict). JSON that
        # decodes to a non-object (list, number, ...) then fails with a
        # ValidationError, a ValueError subclass, instead of an uncaught
        # TypeError from ** unpacking.
        try:
            return StateData.model_validate(state_dict)
        except ValueError as e:
            logger.exception("Failed to create StateData model; state=%s", state)
            msg = f"Invalid state data structure: {e}"
            raise StateDataError(msg) from e

    async def create_state(
        self,
        ttl: int | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """Create a new random OAuth state and store it with metadata.

        Args:
            ttl: Time-to-live in seconds (uses default_ttl if not provided)
            metadata: Additional metadata to store with the state (e.g., callback_url, user_info)

        Returns:
            The generated state string

        Raises:
            StateDataError: If backend storage fails
        """
        # Generate a random state string (32 bytes = 256 bits of entropy).
        state = secrets.token_urlsafe(32)

        effective_ttl = ttl if ttl is not None else self.default_ttl

        state_data = StateData(
            state=state,
            expires_at=datetime.now(timezone.utc) + timedelta(seconds=effective_ttl),
            metadata=metadata or {},
        )

        json_content = json.dumps(state_data.model_dump(mode="json"))

        # The ETag is the SHA-256 of the serialized state data.
        etag = hashlib.sha256(json_content.encode()).hexdigest()

        etag_content = ETagContent(etag=etag, content=json_content)
        await self.backend.set(self._cache_key(state), etag_content, ttl=effective_ttl)

        logger.debug("OAuth state created; state=%s ttl=%s", state, effective_ttl)
        return state

    async def consume_state(self, state: str) -> StateData:
        """Consume and validate an OAuth state, removing it from storage.

        Args:
            state: The state string to validate and consume

        Returns:
            StateData object containing state data and metadata

        Raises:
            InvalidStateError: If state is invalid or not found
            StateExpiredError: If state has expired
            StateDataError: If state data format is invalid
        """
        cache_key = self._cache_key(state)

        cached_etag_content = await self.backend.get(cache_key)
        if cached_etag_content is None:
            logger.warning("OAuth state not found or expired; state=%s", state)
            msg = "Invalid or expired state"
            raise InvalidStateError(msg)

        state_data = self._parse_cached_state(state, cached_etag_content)

        if self._is_expired(state_data):
            logger.warning("OAuth state expired; state=%s", state)
            msg = "State has expired"
            raise StateExpiredError(msg)

        # NOTE(review): get() followed by delete() is not atomic. Two concurrent
        # requests presenting the same state can both pass the checks above.
        # Strict single use would need an atomic get-and-delete in the backend.
        await self.backend.delete(cache_key)
        logger.debug("OAuth state consumed and deleted; state=%s", state)

        return state_data

    async def validate_state(self, state: str) -> bool:
        """Validate if a state exists and is not expired (without consuming it).

        Args:
            state: The state string to validate

        Returns:
            True if state is valid and not expired, False otherwise
        """
        cached_etag_content = await self.backend.get(self._cache_key(state))
        if cached_etag_content is None:
            logger.debug("State validation failed - not found; state=%s", state)
            return False

        try:
            state_data = self._parse_cached_state(state, cached_etag_content)
        except StateDataError:
            # Already logged by _parse_cached_state; this method never raises.
            return False

        if self._is_expired(state_data):
            logger.debug("State validation failed - expired; state=%s", state)
            return False

        logger.debug("State validation succeeded; state=%s", state)
        return True

    async def get_state_metadata(self, state: str) -> dict[str, Any] | None:
        """Retrieve metadata for a state without consuming it.

        Args:
            state: The state string

        Returns:
            Metadata dictionary if state exists and is valid, None otherwise
        """
        cached_etag_content = await self.backend.get(self._cache_key(state))
        if cached_etag_content is None:
            return None

        try:
            state_data = self._parse_cached_state(state, cached_etag_content)
        except StateDataError:
            # Already logged by _parse_cached_state; this method never raises.
            return None

        if self._is_expired(state_data):
            return None

        return state_data.metadata

    async def delete_state(self, state: str) -> bool:
        """Manually delete a state from storage.

        Args:
            state: The state string to delete

        Returns:
            True if state was deleted, False if it didn't exist
        """
        cache_key = self._cache_key(state)
        # The return value of backend.delete() is not relied upon, so check
        # existence first to honor the documented return value.
        existed = await self.backend.get(cache_key) is not None
        await self.backend.delete(cache_key)
        logger.debug("OAuth state deleted; state=%s existed=%s", state, existed)
        return existed
@@ -0,0 +1,31 @@
1
+ """Data models for state management."""
2
+
3
+ from datetime import datetime
4
+ from datetime import timezone
5
+ from typing import Any
6
+
7
+ from pydantic import BaseModel
8
+ from pydantic import ConfigDict
9
+ from pydantic import Field
10
+ from pydantic import field_serializer
11
+
12
+
13
class StateData(BaseModel):
    """Payload persisted for a single OAuth state value.

    In JSON-mode serialization, ``created_at`` and ``expires_at`` are rendered
    with ``datetime.isoformat()``.
    """

    model_config = ConfigDict()

    # Required: the unique state identifier.
    state: str = Field(description="The unique state identifier")
    # Defaults to the current UTC time when the model is constructed.
    created_at: datetime = Field(
        description="When the state was created",
        default_factory=lambda: datetime.now(timezone.utc),
    )
    # Required: when the state stops being valid.
    expires_at: datetime = Field(description="When the state expires")
    # Arbitrary caller-supplied data; each instance gets its own fresh dict.
    metadata: dict[str, Any] = Field(
        description="Additional metadata associated with state",
        default_factory=dict,
    )

    @field_serializer("created_at", "expires_at", when_used="json")
    def serialize_datetime(self, value: datetime) -> str:
        """Render a datetime field as an ISO 8601 string in JSON output."""
        iso_text = value.isoformat()
        return iso_text
fastapi_cachex/types.py CHANGED
@@ -1,8 +1,17 @@
1
1
  """Type definitions and type aliases for FastAPI-CacheX."""
2
2
 
3
+ from collections.abc import Callable
3
4
  from dataclasses import dataclass
4
5
  from typing import Any
5
6
 
7
+ from fastapi import Request
8
+
9
# Separator placed between cache-key components. "|||" is used rather than
# something like ":" so host values that carry a port (e.g. 127.0.0.1:8000)
# do not collide with the separator.
CACHE_KEY_SEPARATOR = "|||"

# Signature of a user-supplied function that derives a cache key from a request.
CacheKeyBuilder = Callable[[Request], str]
14
+
6
15
 
7
16
  @dataclass
8
17
  class ETagContent: