belgie 0.1.0a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- belgie/__init__.py +97 -0
- belgie/alchemy.py +24 -0
- belgie/auth/__init__.py +65 -0
- belgie/auth/core/__init__.py +0 -0
- belgie/auth/core/auth.py +368 -0
- belgie/auth/core/client.py +204 -0
- belgie/auth/core/exceptions.py +26 -0
- belgie/auth/core/hooks.py +87 -0
- belgie/auth/core/settings.py +100 -0
- belgie/auth/providers/__init__.py +10 -0
- belgie/auth/providers/google.py +284 -0
- belgie/auth/providers/protocols.py +43 -0
- belgie/auth/py.typed +0 -0
- belgie/auth/session/__init__.py +3 -0
- belgie/auth/session/manager.py +168 -0
- belgie/auth/utils/__init__.py +9 -0
- belgie/auth/utils/crypto.py +10 -0
- belgie/auth/utils/scopes.py +49 -0
- belgie/mcp.py +12 -0
- belgie/oauth.py +12 -0
- belgie/proto.py +22 -0
- belgie-0.1.0a4.dist-info/METADATA +231 -0
- belgie-0.1.0a4.dist-info/RECORD +24 -0
- belgie-0.1.0a4.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,284 @@
|
|
|
1
|
+
from collections.abc import AsyncGenerator, Callable
|
|
2
|
+
from datetime import UTC, datetime, timedelta
|
|
3
|
+
from typing import Literal
|
|
4
|
+
from urllib.parse import urlencode, urlparse, urlunparse
|
|
5
|
+
|
|
6
|
+
import httpx
|
|
7
|
+
from belgie_proto import AdapterProtocol, DBConnection
|
|
8
|
+
from fastapi import APIRouter, Depends
|
|
9
|
+
from fastapi.responses import RedirectResponse
|
|
10
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
11
|
+
from pydantic_settings import SettingsConfigDict
|
|
12
|
+
|
|
13
|
+
from belgie.auth.core.exceptions import InvalidStateError, OAuthError
|
|
14
|
+
from belgie.auth.core.hooks import HookContext, HookRunner
|
|
15
|
+
from belgie.auth.core.settings import CookieSettings, ProviderSettings
|
|
16
|
+
from belgie.auth.utils.crypto import generate_state_token
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class GoogleProviderSettings(ProviderSettings):
    """Settings for the Google OAuth provider, read from the environment.

    Values come from ``BELGIE_GOOGLE_*`` environment variables (and a local
    ``.env`` file); unrelated variables are ignored. Only Google-specific OAuth
    options live here -- session and redirect behaviour is supplied through
    ``get_router()`` arguments instead.
    """

    model_config = SettingsConfigDict(
        env_prefix="BELGIE_GOOGLE_",
        env_file=".env",
        extra="ignore",
    )

    # Scopes requested on the consent screen (joined with spaces in the auth URL).
    scopes: list[str] = Field(default=["openid", "email", "profile"])
    # Forwarded verbatim as the ``access_type`` authorization parameter.
    access_type: str = Field(default="offline")
    # Forwarded verbatim as the ``prompt`` authorization parameter.
    prompt: str = Field(default="consent")

    def __call__(self) -> "GoogleOAuthProvider":
        """Build a provider bound to these settings.

        Returns:
            A new ``GoogleOAuthProvider`` that uses ``self`` as its settings.
        """
        provider = GoogleOAuthProvider(settings=self)
        return provider
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class GoogleUserInfo(BaseModel):
    """User profile payload returned by Google's v2 ``userinfo`` endpoint.

    Validation is strict (no type coercion) and keys not declared below are
    silently dropped.
    """

    model_config = ConfigDict(strict=True, extra="ignore")

    # Required fields: validation fails if any of these is missing.
    id: str
    email: str
    verified_email: bool

    # Optional profile details; ``None`` when absent from the payload.
    name: str | None = Field(default=None)
    given_name: str | None = Field(default=None)
    family_name: str | None = Field(default=None)
    picture: str | None = Field(default=None)
    locale: str | None = Field(default=None)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class GoogleOAuthProvider:
    """Google OAuth 2.0 provider - self-contained implementation.

    Builds the authorization URL, exchanges authorization codes for tokens,
    fetches the user's profile and exposes a FastAPI router with ``/signin``
    and ``/callback`` endpoints that drive the sign-in flow end to end.
    """

    AUTHORIZATION_URL = "https://accounts.google.com/o/oauth2/v2/auth"
    TOKEN_URL = "https://oauth2.googleapis.com/token"  # noqa: S105
    USER_INFO_URL = "https://www.googleapis.com/oauth2/v2/userinfo"

    def __init__(self, settings: GoogleProviderSettings) -> None:
        """Bind the provider to its settings (client credentials, redirect URI, scopes)."""
        self.settings = settings

    @property
    def provider_id(self) -> Literal["google"]:
        """Identifier used as the router prefix and as the stored account provider."""
        return "google"

    def generate_authorization_url(self, state: str) -> str:
        """Build the Google consent-screen URL for a new sign-in attempt.

        Args:
            state: State token; it is checked again when Google redirects to the callback.

        Returns:
            The authorization URL with client, redirect, scope, state,
            ``access_type`` and ``prompt`` query parameters encoded.
        """
        params = {
            "client_id": self.settings.client_id,
            "redirect_uri": self.settings.redirect_uri,
            "response_type": "code",
            "scope": " ".join(self.settings.scopes),
            "state": state,
            "access_type": self.settings.access_type,
            "prompt": self.settings.prompt,
        }
        parsed = urlparse(self.AUTHORIZATION_URL)
        return urlunparse(
            (
                parsed.scheme,
                parsed.netloc,
                parsed.path,
                "",
                urlencode(params),
                "",
            ),
        )

    @staticmethod
    def _error_detail(response: httpx.Response) -> str:
        """Return `` (<error code>)`` from an OAuth error body, or ``""`` if unavailable.

        Only the short ``error`` field is surfaced; other fields in the body are
        dropped so provider-side details are not exposed in the error message.
        """
        try:
            error_data = response.json()
        except ValueError:
            # Body is not JSON (e.g. an HTML error page); report the status code only.
            return ""
        if isinstance(error_data, dict) and "error" in error_data:
            return f" ({error_data['error']})"
        return ""

    async def exchange_code_for_tokens(self, code: str) -> dict:
        """Exchange an authorization code for access and refresh tokens.

        Args:
            code: Authorization code received on the OAuth callback.

        Returns:
            Dict with ``access_token`` plus ``token_type``, ``refresh_token``,
            ``scope`` and ``id_token`` (each ``None`` when absent) and
            ``expires_at``: a timezone-aware UTC datetime derived from
            ``expires_in``, or ``None`` when Google did not send one.

        Raises:
            OAuthError: On transport failure, a non-2xx status, a non-JSON body,
                or a response without ``access_token``.
        """
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    self.TOKEN_URL,
                    data={
                        "client_id": self.settings.client_id,
                        "client_secret": self.settings.client_secret.get_secret_value(),
                        "code": code,
                        "redirect_uri": self.settings.redirect_uri,
                        "grant_type": "authorization_code",
                    },
                )
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            msg = f"oauth token exchange failed: {e.response.status_code}{self._error_detail(e.response)}"
            raise OAuthError(msg) from e
        except httpx.RequestError as e:
            msg = "oauth token exchange request failed"
            raise OAuthError(msg) from e

        try:
            tokens = response.json()
        except ValueError as e:
            msg = "oauth token exchange returned an invalid response"
            raise OAuthError(msg) from e

        if not isinstance(tokens, dict) or "access_token" not in tokens:
            msg = "missing required field in token response: access_token"
            raise OAuthError(msg)

        # ``expires_in`` is optional; a missing or null value means "unknown expiry".
        expires_in = tokens.get("expires_in")
        return {
            "access_token": tokens["access_token"],
            "token_type": tokens.get("token_type"),
            "refresh_token": tokens.get("refresh_token"),
            "scope": tokens.get("scope"),
            "id_token": tokens.get("id_token"),
            "expires_at": (datetime.now(UTC) + timedelta(seconds=expires_in) if expires_in is not None else None),
        }

    async def get_user_info(self, access_token: str) -> GoogleUserInfo:
        """Fetch the signed-in user's profile from Google.

        Args:
            access_token: Bearer token obtained from ``exchange_code_for_tokens``.

        Returns:
            The validated ``GoogleUserInfo`` payload.

        Raises:
            OAuthError: On transport failure, a non-2xx status, or a body that is
                not JSON / does not match ``GoogleUserInfo``.
        """
        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    self.USER_INFO_URL,
                    headers={"Authorization": f"Bearer {access_token}"},
                )
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            msg = f"failed to fetch user info: {e.response.status_code}{self._error_detail(e.response)}"
            raise OAuthError(msg) from e
        except httpx.RequestError as e:
            msg = "user info request failed"
            raise OAuthError(msg) from e

        try:
            # Both JSON decode errors and pydantic's ValidationError subclass ValueError.
            return GoogleUserInfo.model_validate(response.json())
        except ValueError as e:
            msg = "invalid user info response"
            raise OAuthError(msg) from e

    def get_router(  # noqa: PLR0913
        self,
        adapter: AdapterProtocol,
        cookie_settings: CookieSettings,
        session_max_age: int,
        signin_redirect: str,
        signout_redirect: str,  # noqa: ARG002
        hook_runner: HookRunner,
        db_dependency: Callable[[], DBConnection | AsyncGenerator[DBConnection, None]],
    ) -> APIRouter:
        """Create a router exposing ``GET /google/signin`` and ``GET /google/callback``.

        Args:
            adapter: Persistence adapter for OAuth states, users, accounts and sessions.
            cookie_settings: Name and flags for the session cookie.
            session_max_age: Session lifetime in seconds (DB expiry and cookie ``max_age``).
            signin_redirect: URL the browser is redirected to after a successful sign-in.
            signout_redirect: Not used by this provider's routes.
            hook_runner: Dispatches the ``on_signup`` / ``on_signin`` hooks.
            db_dependency: FastAPI dependency providing a database connection.

        Returns:
            The configured ``APIRouter``.
        """
        router = APIRouter(prefix=f"/{self.provider_id}", tags=["auth", "oauth"])

        async def signin(db: DBConnection = Depends(db_dependency)) -> RedirectResponse:  # noqa: B008
            """Initiate Google OAuth flow."""
            # Persist a fresh state token that the callback must present within 10 minutes.
            state = generate_state_token()
            expires_at = datetime.now(UTC) + timedelta(minutes=10)
            await adapter.create_oauth_state(
                db,
                state=state,
                # Timestamps handed to the adapter by this router are naive UTC.
                expires_at=expires_at.replace(tzinfo=None),
            )

            auth_url = self.generate_authorization_url(state)
            return RedirectResponse(url=auth_url, status_code=302)

        async def callback(code: str, state: str, db: DBConnection = Depends(db_dependency)) -> RedirectResponse:  # noqa: B008
            """Handle Google OAuth callback."""
            # State tokens are single-use: reject unknown ones, then consume the
            # token before doing any further work so it cannot be replayed.
            if not (oauth_state := await adapter.get_oauth_state(db, state)):
                msg = "Invalid OAuth state"
                raise InvalidStateError(msg)
            await adapter.delete_oauth_state(db, state)

            # Enforce the expiry recorded by ``signin``. ``getattr`` keeps this tolerant
            # of adapters whose state objects do not expose ``expires_at``.
            state_expires_at = getattr(oauth_state, "expires_at", None)
            if state_expires_at is not None:
                if state_expires_at.tzinfo is None:
                    state_expires_at = state_expires_at.replace(tzinfo=UTC)
                if state_expires_at <= datetime.now(UTC):
                    msg = "Expired OAuth state"
                    raise InvalidStateError(msg)

            tokens = await self.exchange_code_for_tokens(code)
            user_info = await self.get_user_info(tokens["access_token"])

            # ``exchange_code_for_tokens`` returns an aware UTC datetime; store it as
            # naive UTC like the state and session timestamps written by this router.
            token_expires_at = tokens.get("expires_at")
            if token_expires_at is not None:
                token_expires_at = token_expires_at.replace(tzinfo=None)

            # Look up the local user by email, creating one on first sign-in.
            # NOTE(review): an existing user is matched by email even when Google reports
            # ``verified_email=False``; consider refusing to link in that case.
            created = False
            if not (user := await adapter.get_user_by_email(db, user_info.email)):
                user = await adapter.create_user(
                    db,
                    email=user_info.email,
                    email_verified=user_info.verified_email,
                    name=user_info.name,
                    image=user_info.picture,
                )
                created = True

            # Upsert this user's Google account row with the latest tokens.
            if await adapter.get_account_by_user_and_provider(
                db,
                user.id,
                self.provider_id,
            ):
                await adapter.update_account(
                    db,
                    user_id=user.id,
                    provider=self.provider_id,
                    access_token=tokens["access_token"],
                    refresh_token=tokens.get("refresh_token"),
                    expires_at=token_expires_at,
                    scope=tokens.get("scope"),
                )
            else:
                await adapter.create_account(
                    db,
                    user_id=user.id,
                    provider=self.provider_id,
                    provider_account_id=user_info.id,
                    access_token=tokens["access_token"],
                    refresh_token=tokens.get("refresh_token"),
                    expires_at=token_expires_at,
                    scope=tokens.get("scope"),
                )

            # The signup hook fires only for users created by this request.
            if created:
                async with hook_runner.dispatch("on_signup", HookContext(user=user, db=db)):
                    pass

            expires_at = datetime.now(UTC) + timedelta(seconds=session_max_age)
            session = await adapter.create_session(
                db,
                user_id=user.id,
                expires_at=expires_at.replace(tzinfo=None),
            )

            async with hook_runner.dispatch("on_signin", HookContext(user=user, db=db)):
                pass

            # Hand the session id to the browser using the centralized cookie settings.
            response = RedirectResponse(url=signin_redirect, status_code=302)
            response.set_cookie(
                key=cookie_settings.name,
                value=str(session.id),
                max_age=session_max_age,
                httponly=cookie_settings.http_only,
                secure=cookie_settings.secure,
                samesite=cookie_settings.same_site,
                domain=cookie_settings.domain,
            )
            return response

        router.add_api_route("/signin", signin, methods=["GET"])
        router.add_api_route("/callback", callback, methods=["GET"])

        return router
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Callable # noqa: TC003
|
|
4
|
+
from typing import TYPE_CHECKING, NotRequired, Protocol, TypedDict, runtime_checkable
|
|
5
|
+
|
|
6
|
+
from pydantic_settings import BaseSettings
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from collections.abc import AsyncGenerator
|
|
10
|
+
|
|
11
|
+
from belgie_proto import AdapterProtocol, DBConnection
|
|
12
|
+
from fastapi import APIRouter
|
|
13
|
+
|
|
14
|
+
from belgie.auth.core.hooks import HookRunner
|
|
15
|
+
from belgie.auth.core.settings import CookieSettings
|
|
16
|
+
from belgie.auth.providers.google import GoogleProviderSettings
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@runtime_checkable
|
|
20
|
+
class OAuthProviderProtocol[S: BaseSettings](Protocol):
|
|
21
|
+
"""Protocol that all OAuth providers must implement."""
|
|
22
|
+
|
|
23
|
+
def __init__(self, settings: S) -> None: ...
|
|
24
|
+
|
|
25
|
+
@property
|
|
26
|
+
def provider_id(self) -> str: ...
|
|
27
|
+
|
|
28
|
+
def get_router( # noqa: PLR0913
|
|
29
|
+
self,
|
|
30
|
+
adapter: AdapterProtocol,
|
|
31
|
+
cookie_settings: CookieSettings,
|
|
32
|
+
session_max_age: int,
|
|
33
|
+
signin_redirect: str,
|
|
34
|
+
signout_redirect: str,
|
|
35
|
+
hook_runner: HookRunner,
|
|
36
|
+
db_dependency: Callable[[], DBConnection | AsyncGenerator[DBConnection, None]],
|
|
37
|
+
) -> APIRouter: ...
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class Providers(TypedDict, total=False):
|
|
41
|
+
"""Type-safe provider registry for Auth initialization."""
|
|
42
|
+
|
|
43
|
+
google: NotRequired[GoogleProviderSettings]
|
belgie/auth/py.typed
ADDED
|
File without changes
|
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
from datetime import UTC, datetime, timedelta
|
|
2
|
+
from uuid import UUID
|
|
3
|
+
|
|
4
|
+
from belgie_proto import (
|
|
5
|
+
AccountProtocol,
|
|
6
|
+
AdapterProtocol,
|
|
7
|
+
DBConnection,
|
|
8
|
+
OAuthStateProtocol,
|
|
9
|
+
SessionProtocol,
|
|
10
|
+
UserProtocol,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SessionManager[
|
|
15
|
+
UserT: UserProtocol,
|
|
16
|
+
AccountT: AccountProtocol,
|
|
17
|
+
SessionT: SessionProtocol,
|
|
18
|
+
OAuthStateT: OAuthStateProtocol,
|
|
19
|
+
]:
|
|
20
|
+
"""Manages user sessions with sliding window expiration.
|
|
21
|
+
|
|
22
|
+
The SessionManager handles session creation, retrieval, validation, and automatic
|
|
23
|
+
expiration refresh. It implements a sliding window mechanism where sessions are
|
|
24
|
+
automatically extended when accessed within the update_age threshold.
|
|
25
|
+
|
|
26
|
+
Attributes:
|
|
27
|
+
adapter: Database adapter for session persistence
|
|
28
|
+
max_age: Maximum session lifetime in seconds
|
|
29
|
+
update_age: Minimum time before expiry to trigger session refresh (in seconds)
|
|
30
|
+
|
|
31
|
+
Example:
|
|
32
|
+
>>> manager = SessionManager(
|
|
33
|
+
... adapter=adapter,
|
|
34
|
+
... max_age=3600 * 24 * 7, # 7 days
|
|
35
|
+
... update_age=3600, # Refresh if < 1 hour until expiry
|
|
36
|
+
... )
|
|
37
|
+
>>> session = await manager.create_session(db, user_id=user.id)
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
def __init__(
|
|
41
|
+
self,
|
|
42
|
+
adapter: AdapterProtocol[UserT, AccountT, SessionT, OAuthStateT],
|
|
43
|
+
max_age: int,
|
|
44
|
+
update_age: int,
|
|
45
|
+
) -> None:
|
|
46
|
+
"""Initialize the SessionManager.
|
|
47
|
+
|
|
48
|
+
Args:
|
|
49
|
+
adapter: Database adapter for session persistence
|
|
50
|
+
max_age: Maximum session lifetime in seconds
|
|
51
|
+
update_age: Minimum time before expiry to trigger session refresh (in seconds)
|
|
52
|
+
"""
|
|
53
|
+
self.adapter = adapter
|
|
54
|
+
self.max_age = max_age
|
|
55
|
+
self.update_age = update_age
|
|
56
|
+
|
|
57
|
+
async def create_session(
|
|
58
|
+
self,
|
|
59
|
+
db: DBConnection,
|
|
60
|
+
user_id: UUID,
|
|
61
|
+
ip_address: str | None = None,
|
|
62
|
+
user_agent: str | None = None,
|
|
63
|
+
) -> SessionT:
|
|
64
|
+
"""Create a new session for a user.
|
|
65
|
+
|
|
66
|
+
Args:
|
|
67
|
+
db: Database connection
|
|
68
|
+
user_id: UUID of the user
|
|
69
|
+
ip_address: Optional IP address of the client
|
|
70
|
+
user_agent: Optional User-Agent string of the client
|
|
71
|
+
|
|
72
|
+
Returns:
|
|
73
|
+
Newly created session object
|
|
74
|
+
|
|
75
|
+
Example:
|
|
76
|
+
>>> session = await manager.create_session(
|
|
77
|
+
... db, user_id=user.id, ip_address="192.168.1.1", user_agent="Mozilla/5.0..."
|
|
78
|
+
... )
|
|
79
|
+
"""
|
|
80
|
+
expires_at = datetime.now(UTC) + timedelta(seconds=self.max_age)
|
|
81
|
+
return await self.adapter.create_session(
|
|
82
|
+
db,
|
|
83
|
+
user_id=user_id,
|
|
84
|
+
expires_at=expires_at,
|
|
85
|
+
ip_address=ip_address,
|
|
86
|
+
user_agent=user_agent,
|
|
87
|
+
)
|
|
88
|
+
|
|
89
|
+
async def get_session(
|
|
90
|
+
self,
|
|
91
|
+
db: DBConnection,
|
|
92
|
+
session_id: UUID,
|
|
93
|
+
) -> SessionT | None:
|
|
94
|
+
"""Retrieve and validate a session with sliding window refresh.
|
|
95
|
+
|
|
96
|
+
Retrieves the session, checks if it's expired (deletes if expired), and
|
|
97
|
+
automatically extends the expiration if the session is within update_age
|
|
98
|
+
of expiring (sliding window mechanism).
|
|
99
|
+
|
|
100
|
+
Args:
|
|
101
|
+
db: Database connection
|
|
102
|
+
session_id: UUID of the session to retrieve
|
|
103
|
+
|
|
104
|
+
Returns:
|
|
105
|
+
Valid session object or None if not found/expired
|
|
106
|
+
|
|
107
|
+
Example:
|
|
108
|
+
>>> session = await manager.get_session(db, session_id)
|
|
109
|
+
>>> if session:
|
|
110
|
+
... print(f"Session expires at {session.expires_at}")
|
|
111
|
+
... else:
|
|
112
|
+
... print("Session not found or expired")
|
|
113
|
+
"""
|
|
114
|
+
session = await self.adapter.get_session(db, session_id)
|
|
115
|
+
|
|
116
|
+
if not session:
|
|
117
|
+
return None
|
|
118
|
+
|
|
119
|
+
now = datetime.now(UTC)
|
|
120
|
+
|
|
121
|
+
if session.expires_at.replace(tzinfo=UTC) <= now:
|
|
122
|
+
await self.adapter.delete_session(db, session_id)
|
|
123
|
+
return None
|
|
124
|
+
|
|
125
|
+
time_until_expiry = session.expires_at.replace(tzinfo=UTC) - now
|
|
126
|
+
if time_until_expiry.total_seconds() < self.update_age:
|
|
127
|
+
new_expires_at = now + timedelta(seconds=self.max_age)
|
|
128
|
+
session = await self.adapter.update_session(
|
|
129
|
+
db,
|
|
130
|
+
session_id,
|
|
131
|
+
expires_at=new_expires_at,
|
|
132
|
+
)
|
|
133
|
+
|
|
134
|
+
return session
|
|
135
|
+
|
|
136
|
+
async def delete_session(self, db: DBConnection, session_id: UUID) -> bool:
|
|
137
|
+
"""Delete a session.
|
|
138
|
+
|
|
139
|
+
Args:
|
|
140
|
+
db: Database connection
|
|
141
|
+
session_id: UUID of the session to delete
|
|
142
|
+
|
|
143
|
+
Returns:
|
|
144
|
+
True if session was deleted, False if it didn't exist
|
|
145
|
+
|
|
146
|
+
Example:
|
|
147
|
+
>>> deleted = await manager.delete_session(db, session_id)
|
|
148
|
+
>>> if deleted:
|
|
149
|
+
... print("Session deleted successfully")
|
|
150
|
+
"""
|
|
151
|
+
return await self.adapter.delete_session(db, session_id)
|
|
152
|
+
|
|
153
|
+
async def cleanup_expired_sessions(self, db: DBConnection) -> int:
|
|
154
|
+
"""Delete all expired sessions from the database.
|
|
155
|
+
|
|
156
|
+
Useful for periodic cleanup tasks to remove stale session data.
|
|
157
|
+
|
|
158
|
+
Args:
|
|
159
|
+
db: Database connection
|
|
160
|
+
|
|
161
|
+
Returns:
|
|
162
|
+
Number of sessions deleted
|
|
163
|
+
|
|
164
|
+
Example:
|
|
165
|
+
>>> count = await manager.cleanup_expired_sessions(db)
|
|
166
|
+
>>> print(f"Deleted {count} expired sessions")
|
|
167
|
+
"""
|
|
168
|
+
return await self.adapter.delete_expired_sessions(db)
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from collections.abc import Sequence
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def parse_scopes(scopes_str: str) -> list[str]:
    """Parse a scope list supplied as text (e.g. from an environment variable).

    Accepted forms:

    * a JSON array, e.g. ``'["openid", "email"]'`` -- items are converted with
      ``str()``, stripped, and empty items dropped;
    * a delimited string, e.g. ``"openid,email"`` or ``"openid email"`` --
      commas and/or whitespace separate scopes (OAuth 2.0 scope strings are
      space-delimited, RFC 6749 section 3.3).

    A bracketed value that is not valid JSON (e.g. ``"[openid, email]"``) has
    its surrounding brackets removed and is then split as a delimited string.

    Args:
        scopes_str: Raw scope text.

    Returns:
        The scopes in input order; ``[]`` for blank input.
    """
    scopes_str = scopes_str.strip()
    if not scopes_str:
        return []

    if scopes_str.startswith("["):
        try:
            parsed = json.loads(scopes_str)
        except json.JSONDecodeError:
            # Looks like a list but is not JSON: drop the brackets and split below.
            scopes_str = scopes_str.removeprefix("[").removesuffix("]")
        else:
            if isinstance(parsed, list):
                return [item for item in (str(scope).strip() for scope in parsed) if item]

    # Treat commas as whitespace so both delimiter styles (and mixes) are accepted.
    return scopes_str.replace(",", " ").split()
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def validate_scopes[S: str](
|
|
23
|
+
user_scopes: Sequence[S] | None,
|
|
24
|
+
required_scopes: Sequence[S],
|
|
25
|
+
) -> bool:
|
|
26
|
+
# Normalize to sets for comparison
|
|
27
|
+
# Generic over any str subclass (including StrEnum)
|
|
28
|
+
# Accepts any sequence type (list, tuple, set, etc.)
|
|
29
|
+
# None is treated as empty set (no scopes)
|
|
30
|
+
user_scopes_set = set(user_scopes) if user_scopes is not None else set()
|
|
31
|
+
required_scopes_set = set(required_scopes)
|
|
32
|
+
|
|
33
|
+
# Check if all required scopes are present in user scopes
|
|
34
|
+
return required_scopes_set.issubset(user_scopes_set)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def has_any_scope[S: str](
|
|
38
|
+
user_scopes: Sequence[S] | None,
|
|
39
|
+
required_scopes: Sequence[S],
|
|
40
|
+
) -> bool:
|
|
41
|
+
# Normalize to sets for comparison
|
|
42
|
+
# Generic over any str subclass (including StrEnum)
|
|
43
|
+
# Accepts any sequence type (list, tuple, set, etc.)
|
|
44
|
+
# None is treated as empty set (no scopes)
|
|
45
|
+
user_scopes_set = set(user_scopes) if user_scopes is not None else set()
|
|
46
|
+
required_scopes_set = set(required_scopes)
|
|
47
|
+
|
|
48
|
+
# Check if user has any of the required scopes
|
|
49
|
+
return bool(user_scopes_set & required_scopes_set)
|
belgie/mcp.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""MCP re-exports for belgie consumers."""
|
|
2
|
+
|
|
3
|
+
# Quoted so the suggested command also works in shells that glob ``[...]`` (e.g. zsh).
_MCP_IMPORT_ERROR = "belgie.mcp requires the 'mcp' extra. Install with: uv add 'belgie[mcp]'"

try:
    from belgie_mcp import hello  # type: ignore[import-not-found]
except ModuleNotFoundError as exc:
    # Replace the bare missing-module error with an actionable install hint.
    raise ImportError(_MCP_IMPORT_ERROR) from exc

__all__ = [
    "hello",
]
|
belgie/oauth.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""OAuth re-exports for belgie consumers."""
|
|
2
|
+
|
|
3
|
+
# Quoted so the suggested command also works in shells that glob ``[...]`` (e.g. zsh).
_OAUTH_IMPORT_ERROR = "belgie.oauth requires the 'oauth' extra. Install with: uv add 'belgie[oauth]'"

try:
    from belgie_oauth import hello  # type: ignore[import-not-found]
except ModuleNotFoundError as exc:
    # Replace the bare missing-module error with an actionable install hint.
    raise ImportError(_OAUTH_IMPORT_ERROR) from exc

__all__ = [
    "hello",
]
|
belgie/proto.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""Protocol re-exports for belgie consumers."""
|
|
2
|
+
|
|
3
|
+
_PROTO_IMPORT_ERROR = "belgie.proto requires belgie-proto. Install with: uv add belgie-proto"

try:
    from belgie_proto import (  # type: ignore[import-not-found]
        AccountProtocol,
        AdapterProtocol,
        OAuthStateProtocol,
        SessionProtocol,
        UserProtocol,
    )
except ModuleNotFoundError as missing:
    # belgie-proto is a separate distribution; explain how to install it.
    raise ImportError(_PROTO_IMPORT_ERROR) from missing

# Public surface re-exported from belgie_proto.
__all__ = [
    "AccountProtocol",
    "AdapterProtocol",
    "OAuthStateProtocol",
    "SessionProtocol",
    "UserProtocol",
]
|