belgie 0.1.0a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,284 @@
1
+ from collections.abc import AsyncGenerator, Callable
2
+ from datetime import UTC, datetime, timedelta
3
+ from typing import Literal
4
+ from urllib.parse import urlencode, urlparse, urlunparse
5
+
6
+ import httpx
7
+ from belgie_proto import AdapterProtocol, DBConnection
8
+ from fastapi import APIRouter, Depends
9
+ from fastapi.responses import RedirectResponse
10
+ from pydantic import BaseModel, ConfigDict, Field
11
+ from pydantic_settings import SettingsConfigDict
12
+
13
+ from belgie.auth.core.exceptions import InvalidStateError, OAuthError
14
+ from belgie.auth.core.hooks import HookContext, HookRunner
15
+ from belgie.auth.core.settings import CookieSettings, ProviderSettings
16
+ from belgie.auth.utils.crypto import generate_state_token
17
+
18
+
19
class GoogleProviderSettings(ProviderSettings):
    """Settings for the Google OAuth provider, sourced from the environment.

    Values are read from ``BELGIE_GOOGLE_*`` environment variables (and a
    local ``.env`` file); unrelated keys are ignored. Only Google-specific
    OAuth options live here -- session and post-sign-in redirect behaviour is
    supplied through the arguments of ``GoogleOAuthProvider.get_router()``.

    Calling an instance builds a ``GoogleOAuthProvider`` bound to it.
    """

    model_config = SettingsConfigDict(
        extra="ignore",
        env_file=".env",
        env_prefix="BELGIE_GOOGLE_",
    )

    # Joined with spaces into the ``scope`` authorization parameter.
    scopes: list[str] = Field(default=["openid", "email", "profile"])
    # Forwarded verbatim as the ``access_type`` authorization parameter.
    access_type: str = Field(default="offline")
    # Forwarded verbatim as the ``prompt`` authorization parameter.
    prompt: str = Field(default="consent")

    def __call__(self) -> "GoogleOAuthProvider":
        """Build a provider instance that uses these settings.

        Returns:
            A ``GoogleOAuthProvider`` whose ``settings`` is this object.
        """
        provider = GoogleOAuthProvider(settings=self)
        return provider
43
+
44
+
45
class GoogleUserInfo(BaseModel):
    """Profile returned by Google's OAuth2 v2 ``userinfo`` endpoint.

    Only the fields declared here are kept (extra keys are ignored), and
    validation is strict, so incoming values must already have the declared
    types.
    """

    model_config = ConfigDict(extra="ignore", strict=True)

    # Required identity fields.
    id: str
    email: str
    verified_email: bool

    # Optional profile fields; they default to None when the response omits them.
    name: str | None = Field(default=None)
    given_name: str | None = Field(default=None)
    family_name: str | None = Field(default=None)
    picture: str | None = Field(default=None)
    locale: str | None = Field(default=None)
56
+
57
+
58
class GoogleOAuthProvider:
    """Google OAuth 2.0 provider - self-contained implementation.

    Implements the authorization-code flow against Google's endpoints:
    building the consent URL, exchanging the returned code for tokens,
    fetching the user's profile, and exposing a FastAPI router with
    ``/{provider_id}/signin`` and ``/{provider_id}/callback`` routes.
    """

    AUTHORIZATION_URL = "https://accounts.google.com/o/oauth2/v2/auth"
    TOKEN_URL = "https://oauth2.googleapis.com/token"  # noqa: S105
    USER_INFO_URL = "https://www.googleapis.com/oauth2/v2/userinfo"

    def __init__(self, settings: GoogleProviderSettings) -> None:
        """Bind the provider to its settings (client credentials, redirect URI, scopes)."""
        self.settings = settings

    @property
    def provider_id(self) -> Literal["google"]:
        """Provider identifier, used as the router prefix and the account ``provider`` value."""
        return "google"

    @staticmethod
    def _error_detail(response: httpx.Response) -> str:
        """Return `` (<error code>)`` taken from an OAuth error body, or ``""``.

        Only the short ``error`` field is surfaced so exception messages never
        carry error descriptions or other response content.
        """
        try:
            error_data = response.json()
        except ValueError:
            # Body is not JSON: there is no error code to report.
            return ""
        if isinstance(error_data, dict) and "error" in error_data:
            return f" ({error_data['error']})"
        return ""

    @staticmethod
    def _is_expired(expires_at: datetime) -> bool:
        """Return True if *expires_at* is at or before the current UTC time.

        Naive datetimes are interpreted as UTC, matching how this router
        writes expiry timestamps (UTC values with ``tzinfo`` stripped).
        """
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=UTC)
        return expires_at <= datetime.now(UTC)

    def generate_authorization_url(self, state: str) -> str:
        """Generate the Google OAuth authorization URL.

        Args:
            state: Anti-forgery token that the callback validates.

        Returns:
            ``AUTHORIZATION_URL`` with client id, redirect URI, scopes, state,
            access type and prompt encoded as query parameters.
        """
        params = {
            "client_id": self.settings.client_id,
            "redirect_uri": self.settings.redirect_uri,
            "response_type": "code",
            "scope": " ".join(self.settings.scopes),
            "state": state,
            "access_type": self.settings.access_type,
            "prompt": self.settings.prompt,
        }
        parsed = urlparse(self.AUTHORIZATION_URL)
        return urlunparse(
            (
                parsed.scheme,
                parsed.netloc,
                parsed.path,
                "",
                urlencode(params),
                "",
            ),
        )

    async def exchange_code_for_tokens(self, code: str) -> dict:
        """Exchange an authorization code for access and refresh tokens.

        Args:
            code: Authorization code received on the callback.

        Returns:
            Dict with ``access_token`` plus ``token_type``, ``refresh_token``,
            ``scope``, ``id_token`` (each ``None`` when absent) and
            ``expires_at`` (timezone-aware UTC datetime, or ``None`` when the
            response has no ``expires_in``).

        Raises:
            OAuthError: On HTTP/transport failure, a non-JSON body, a missing
                ``access_token`` or an unusable ``expires_in``.
        """
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    self.TOKEN_URL,
                    data={
                        "client_id": self.settings.client_id,
                        "client_secret": self.settings.client_secret.get_secret_value(),
                        "code": code,
                        "redirect_uri": self.settings.redirect_uri,
                        "grant_type": "authorization_code",
                    },
                )
                response.raise_for_status()
                tokens = response.json()
        except httpx.HTTPStatusError as e:
            msg = f"oauth token exchange failed: {e.response.status_code}{self._error_detail(e.response)}"
            raise OAuthError(msg) from e
        except httpx.RequestError as e:
            msg = "oauth token exchange request failed"
            raise OAuthError(msg) from e
        except ValueError as e:
            # response.json() on a non-JSON success body.
            msg = "oauth token exchange returned an invalid response"
            raise OAuthError(msg) from e

        if not isinstance(tokens, dict) or "access_token" not in tokens:
            msg = "missing required field in token response: access_token"
            raise OAuthError(msg)

        # Convert the relative lifetime into an absolute expiry when present.
        expires_at = None
        if (expires_in := tokens.get("expires_in")) is not None:
            try:
                expires_at = datetime.now(UTC) + timedelta(seconds=int(expires_in))
            except (TypeError, ValueError, OverflowError) as e:
                msg = "invalid expires_in in token response"
                raise OAuthError(msg) from e

        return {
            "access_token": tokens["access_token"],
            "token_type": tokens.get("token_type"),
            "refresh_token": tokens.get("refresh_token"),
            "scope": tokens.get("scope"),
            "id_token": tokens.get("id_token"),
            "expires_at": expires_at,
        }

    async def get_user_info(self, access_token: str) -> GoogleUserInfo:
        """Fetch the user's profile from Google using an access token.

        Args:
            access_token: Bearer token from ``exchange_code_for_tokens``.

        Returns:
            The validated ``GoogleUserInfo``.

        Raises:
            OAuthError: On HTTP/transport failure, a non-JSON body, or a
                payload that does not match ``GoogleUserInfo``.
        """
        from pydantic import ValidationError

        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    self.USER_INFO_URL,
                    headers={"Authorization": f"Bearer {access_token}"},
                )
                response.raise_for_status()
                user_data = response.json()
        except httpx.HTTPStatusError as e:
            msg = f"failed to fetch user info: {e.response.status_code}{self._error_detail(e.response)}"
            raise OAuthError(msg) from e
        except httpx.RequestError as e:
            msg = "user info request failed"
            raise OAuthError(msg) from e
        except ValueError as e:
            # response.json() on a non-JSON success body.
            msg = "user info request returned an invalid response"
            raise OAuthError(msg) from e

        try:
            # model_validate also rejects non-dict payloads with a ValidationError.
            return GoogleUserInfo.model_validate(user_data)
        except ValidationError as e:
            msg = "user info response has an unexpected shape"
            raise OAuthError(msg) from e

    def get_router(  # noqa: PLR0913
        self,
        adapter: AdapterProtocol,
        cookie_settings: CookieSettings,
        session_max_age: int,
        signin_redirect: str,
        signout_redirect: str,  # noqa: ARG002
        hook_runner: HookRunner,
        db_dependency: Callable[[], DBConnection | AsyncGenerator[DBConnection, None]],
    ) -> APIRouter:
        """Create a router with the Google OAuth ``/signin`` and ``/callback`` endpoints.

        Args:
            adapter: Persistence adapter for OAuth states, users, accounts and sessions.
            cookie_settings: Attributes for the session cookie set on callback.
            session_max_age: Session lifetime in seconds (DB expiry and cookie max-age).
            signin_redirect: URL the callback redirects to after sign-in.
            signout_redirect: Accepted for interface compatibility; unused here.
            hook_runner: Dispatches the ``on_signup`` / ``on_signin`` hooks.
            db_dependency: FastAPI dependency yielding a DB connection.

        Returns:
            An ``APIRouter`` mounted under ``/{provider_id}``.
        """
        router = APIRouter(prefix=f"/{self.provider_id}", tags=["auth", "oauth"])

        async def signin(db: DBConnection = Depends(db_dependency)) -> RedirectResponse:  # noqa: B008
            """Initiate the Google OAuth flow."""
            # Persist a fresh state token valid for 10 minutes (stored as naive UTC).
            # NOTE(review): the state lives server-side only and is not bound to the
            # initiating browser (e.g. via a cookie), so it does not prevent login CSRF.
            state = generate_state_token()
            expires_at = datetime.now(UTC) + timedelta(minutes=10)
            await adapter.create_oauth_state(
                db,
                state=state,
                expires_at=expires_at.replace(tzinfo=None),
            )

            auth_url = self.generate_authorization_url(state)
            return RedirectResponse(url=auth_url, status_code=302)

        async def callback(
            state: str,
            code: str | None = None,
            error: str | None = None,
            db: DBConnection = Depends(db_dependency),  # noqa: B008
        ) -> RedirectResponse:
            """Handle the Google OAuth callback."""
            # Validate the state token, then delete it so it cannot be replayed.
            if not (oauth_state := await adapter.get_oauth_state(db, state)):
                msg = "Invalid OAuth state"
                raise InvalidStateError(msg)
            await adapter.delete_oauth_state(db, state)

            # signin() stores an expiry on every state; reject stale ones.
            # NOTE(review): assumes the stored state exposes ``expires_at`` (the value
            # passed to create_oauth_state); the check is skipped if it does not.
            state_expires_at = getattr(oauth_state, "expires_at", None)
            if state_expires_at is not None and self._is_expired(state_expires_at):
                msg = "Expired OAuth state"
                raise InvalidStateError(msg)

            # On denial or failure the provider redirects with ``error`` and no ``code``.
            if not code:
                msg = f"oauth authorization failed: {error or 'missing authorization code'}"
                raise OAuthError(msg)

            tokens = await self.exchange_code_for_tokens(code)
            user_info = await self.get_user_info(tokens["access_token"])

            # Get or create the local user, matched by email.
            # NOTE(review): an existing user is linked by email alone even when Google
            # reports verified_email=False; consider refusing that case.
            created = False
            if not (user := await adapter.get_user_by_email(db, user_info.email)):
                user = await adapter.create_user(
                    db,
                    email=user_info.email,
                    email_verified=user_info.verified_email,
                    name=user_info.name,
                    image=user_info.picture,
                )
                created = True

            # Create or refresh the linked OAuth account; optional tokens may be None.
            # NOTE(review): token expires_at is timezone-aware, while state/session
            # expiries in this router are written naive - confirm the adapter accepts both.
            if await adapter.get_account_by_user_and_provider(
                db,
                user.id,
                self.provider_id,
            ):
                await adapter.update_account(
                    db,
                    user_id=user.id,
                    provider=self.provider_id,
                    access_token=tokens["access_token"],
                    refresh_token=tokens.get("refresh_token"),
                    expires_at=tokens.get("expires_at"),
                    scope=tokens.get("scope"),
                )
            else:
                await adapter.create_account(
                    db,
                    user_id=user.id,
                    provider=self.provider_id,
                    provider_account_id=user_info.id,
                    access_token=tokens["access_token"],
                    refresh_token=tokens.get("refresh_token"),
                    expires_at=tokens.get("expires_at"),
                    scope=tokens.get("scope"),
                )

            # Signup hook only fires for newly created users.
            if created:
                async with hook_runner.dispatch("on_signup", HookContext(user=user, db=db)):
                    pass

            # Create the session; expiry stored as naive UTC like the OAuth state.
            expires_at = datetime.now(UTC) + timedelta(seconds=session_max_age)
            session = await adapter.create_session(
                db,
                user_id=user.id,
                expires_at=expires_at.replace(tzinfo=None),
            )

            async with hook_runner.dispatch("on_signin", HookContext(user=user, db=db)):
                pass

            # Set the session cookie using the centralized cookie settings.
            response = RedirectResponse(url=signin_redirect, status_code=302)
            response.set_cookie(
                key=cookie_settings.name,
                value=str(session.id),
                max_age=session_max_age,
                httponly=cookie_settings.http_only,
                secure=cookie_settings.secure,
                samesite=cookie_settings.same_site,
                domain=cookie_settings.domain,
            )
            return response

        router.add_api_route("/signin", signin, methods=["GET"])
        router.add_api_route("/callback", callback, methods=["GET"])

        return router
@@ -0,0 +1,43 @@
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Callable # noqa: TC003
4
+ from typing import TYPE_CHECKING, NotRequired, Protocol, TypedDict, runtime_checkable
5
+
6
+ from pydantic_settings import BaseSettings
7
+
8
+ if TYPE_CHECKING:
9
+ from collections.abc import AsyncGenerator
10
+
11
+ from belgie_proto import AdapterProtocol, DBConnection
12
+ from fastapi import APIRouter
13
+
14
+ from belgie.auth.core.hooks import HookRunner
15
+ from belgie.auth.core.settings import CookieSettings
16
+ from belgie.auth.providers.google import GoogleProviderSettings
17
+
18
+
19
+ @runtime_checkable
20
+ class OAuthProviderProtocol[S: BaseSettings](Protocol):
21
+ """Protocol that all OAuth providers must implement."""
22
+
23
+ def __init__(self, settings: S) -> None: ...
24
+
25
+ @property
26
+ def provider_id(self) -> str: ...
27
+
28
+ def get_router( # noqa: PLR0913
29
+ self,
30
+ adapter: AdapterProtocol,
31
+ cookie_settings: CookieSettings,
32
+ session_max_age: int,
33
+ signin_redirect: str,
34
+ signout_redirect: str,
35
+ hook_runner: HookRunner,
36
+ db_dependency: Callable[[], DBConnection | AsyncGenerator[DBConnection, None]],
37
+ ) -> APIRouter: ...
38
+
39
+
40
class Providers(TypedDict, total=False):
    """Type-safe provider registry for Auth initialization.

    Each key names a provider and maps to that provider's settings object.
    Every key is optional, so only the providers being configured need to be
    present.
    """

    # Settings for the Google OAuth provider.
    google: NotRequired[GoogleProviderSettings]
belgie/auth/py.typed ADDED
File without changes
@@ -0,0 +1,3 @@
1
+ from belgie.auth.session.manager import SessionManager
2
+
3
# Names re-exported by this package.
__all__ = ["SessionManager"]
@@ -0,0 +1,168 @@
1
+ from datetime import UTC, datetime, timedelta
2
+ from uuid import UUID
3
+
4
+ from belgie_proto import (
5
+ AccountProtocol,
6
+ AdapterProtocol,
7
+ DBConnection,
8
+ OAuthStateProtocol,
9
+ SessionProtocol,
10
+ UserProtocol,
11
+ )
12
+
13
+
14
+ class SessionManager[
15
+ UserT: UserProtocol,
16
+ AccountT: AccountProtocol,
17
+ SessionT: SessionProtocol,
18
+ OAuthStateT: OAuthStateProtocol,
19
+ ]:
20
+ """Manages user sessions with sliding window expiration.
21
+
22
+ The SessionManager handles session creation, retrieval, validation, and automatic
23
+ expiration refresh. It implements a sliding window mechanism where sessions are
24
+ automatically extended when accessed within the update_age threshold.
25
+
26
+ Attributes:
27
+ adapter: Database adapter for session persistence
28
+ max_age: Maximum session lifetime in seconds
29
+ update_age: Minimum time before expiry to trigger session refresh (in seconds)
30
+
31
+ Example:
32
+ >>> manager = SessionManager(
33
+ ... adapter=adapter,
34
+ ... max_age=3600 * 24 * 7, # 7 days
35
+ ... update_age=3600, # Refresh if < 1 hour until expiry
36
+ ... )
37
+ >>> session = await manager.create_session(db, user_id=user.id)
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ adapter: AdapterProtocol[UserT, AccountT, SessionT, OAuthStateT],
43
+ max_age: int,
44
+ update_age: int,
45
+ ) -> None:
46
+ """Initialize the SessionManager.
47
+
48
+ Args:
49
+ adapter: Database adapter for session persistence
50
+ max_age: Maximum session lifetime in seconds
51
+ update_age: Minimum time before expiry to trigger session refresh (in seconds)
52
+ """
53
+ self.adapter = adapter
54
+ self.max_age = max_age
55
+ self.update_age = update_age
56
+
57
+ async def create_session(
58
+ self,
59
+ db: DBConnection,
60
+ user_id: UUID,
61
+ ip_address: str | None = None,
62
+ user_agent: str | None = None,
63
+ ) -> SessionT:
64
+ """Create a new session for a user.
65
+
66
+ Args:
67
+ db: Database connection
68
+ user_id: UUID of the user
69
+ ip_address: Optional IP address of the client
70
+ user_agent: Optional User-Agent string of the client
71
+
72
+ Returns:
73
+ Newly created session object
74
+
75
+ Example:
76
+ >>> session = await manager.create_session(
77
+ ... db, user_id=user.id, ip_address="192.168.1.1", user_agent="Mozilla/5.0..."
78
+ ... )
79
+ """
80
+ expires_at = datetime.now(UTC) + timedelta(seconds=self.max_age)
81
+ return await self.adapter.create_session(
82
+ db,
83
+ user_id=user_id,
84
+ expires_at=expires_at,
85
+ ip_address=ip_address,
86
+ user_agent=user_agent,
87
+ )
88
+
89
+ async def get_session(
90
+ self,
91
+ db: DBConnection,
92
+ session_id: UUID,
93
+ ) -> SessionT | None:
94
+ """Retrieve and validate a session with sliding window refresh.
95
+
96
+ Retrieves the session, checks if it's expired (deletes if expired), and
97
+ automatically extends the expiration if the session is within update_age
98
+ of expiring (sliding window mechanism).
99
+
100
+ Args:
101
+ db: Database connection
102
+ session_id: UUID of the session to retrieve
103
+
104
+ Returns:
105
+ Valid session object or None if not found/expired
106
+
107
+ Example:
108
+ >>> session = await manager.get_session(db, session_id)
109
+ >>> if session:
110
+ ... print(f"Session expires at {session.expires_at}")
111
+ ... else:
112
+ ... print("Session not found or expired")
113
+ """
114
+ session = await self.adapter.get_session(db, session_id)
115
+
116
+ if not session:
117
+ return None
118
+
119
+ now = datetime.now(UTC)
120
+
121
+ if session.expires_at.replace(tzinfo=UTC) <= now:
122
+ await self.adapter.delete_session(db, session_id)
123
+ return None
124
+
125
+ time_until_expiry = session.expires_at.replace(tzinfo=UTC) - now
126
+ if time_until_expiry.total_seconds() < self.update_age:
127
+ new_expires_at = now + timedelta(seconds=self.max_age)
128
+ session = await self.adapter.update_session(
129
+ db,
130
+ session_id,
131
+ expires_at=new_expires_at,
132
+ )
133
+
134
+ return session
135
+
136
+ async def delete_session(self, db: DBConnection, session_id: UUID) -> bool:
137
+ """Delete a session.
138
+
139
+ Args:
140
+ db: Database connection
141
+ session_id: UUID of the session to delete
142
+
143
+ Returns:
144
+ True if session was deleted, False if it didn't exist
145
+
146
+ Example:
147
+ >>> deleted = await manager.delete_session(db, session_id)
148
+ >>> if deleted:
149
+ ... print("Session deleted successfully")
150
+ """
151
+ return await self.adapter.delete_session(db, session_id)
152
+
153
+ async def cleanup_expired_sessions(self, db: DBConnection) -> int:
154
+ """Delete all expired sessions from the database.
155
+
156
+ Useful for periodic cleanup tasks to remove stale session data.
157
+
158
+ Args:
159
+ db: Database connection
160
+
161
+ Returns:
162
+ Number of sessions deleted
163
+
164
+ Example:
165
+ >>> count = await manager.cleanup_expired_sessions(db)
166
+ >>> print(f"Deleted {count} expired sessions")
167
+ """
168
+ return await self.adapter.delete_expired_sessions(db)
@@ -0,0 +1,9 @@
1
from belgie.auth.utils.crypto import generate_session_id, generate_state_token
from belgie.auth.utils.scopes import has_any_scope, parse_scopes, validate_scopes

# Public helpers re-exported from the crypto and scopes submodules.
# has_any_scope is exported alongside its sibling validate_scopes.
__all__ = [
    "generate_session_id",
    "generate_state_token",
    "has_any_scope",
    "parse_scopes",
    "validate_scopes",
]
@@ -0,0 +1,10 @@
1
+ import secrets
2
+ from uuid import UUID, uuid4
3
+
4
+
5
def generate_state_token() -> str:
    """Return a fresh URL-safe random token for use as an OAuth ``state`` value.

    The token encodes 32 bytes drawn from the ``secrets`` CSPRNG.
    """
    entropy_bytes = 32
    token = secrets.token_urlsafe(entropy_bytes)
    return token
7
+
8
+
9
def generate_session_id() -> UUID:
    """Return a new random (version 4) UUID to identify a session."""
    session_id = uuid4()
    return session_id
@@ -0,0 +1,49 @@
1
+ import json
2
+ from collections.abc import Sequence
3
+
4
+
5
def parse_scopes(scopes_str: str) -> list[str]:
    """Parse a list of scopes from a string.

    Accepted forms:

    * a JSON array such as ``'["openid", "email"]'``: each item is converted
      with ``str()`` and stripped, and empty items are dropped;
    * a delimited string using commas and/or whitespace, such as
      ``"openid,email"`` or ``"openid email profile"`` (the space-delimited
      form used by OAuth 2.0 ``scope`` values).

    Input that starts with ``[`` but is not valid JSON falls back to
    delimited parsing. Blank input yields an empty list.
    """
    scopes_str = scopes_str.strip()

    if not scopes_str:
        return []

    if scopes_str.startswith("["):
        try:
            parsed = json.loads(scopes_str)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, list):
            return [text for item in parsed if (text := str(item).strip())]

    # OAuth scope tokens cannot contain spaces (RFC 6749 section 3.3), so commas and
    # any whitespace are both safe separators; split() also drops empty pieces.
    return scopes_str.replace(",", " ").split()
20
+
21
+
22
+ def validate_scopes[S: str](
23
+ user_scopes: Sequence[S] | None,
24
+ required_scopes: Sequence[S],
25
+ ) -> bool:
26
+ # Normalize to sets for comparison
27
+ # Generic over any str subclass (including StrEnum)
28
+ # Accepts any sequence type (list, tuple, set, etc.)
29
+ # None is treated as empty set (no scopes)
30
+ user_scopes_set = set(user_scopes) if user_scopes is not None else set()
31
+ required_scopes_set = set(required_scopes)
32
+
33
+ # Check if all required scopes are present in user scopes
34
+ return required_scopes_set.issubset(user_scopes_set)
35
+
36
+
37
+ def has_any_scope[S: str](
38
+ user_scopes: Sequence[S] | None,
39
+ required_scopes: Sequence[S],
40
+ ) -> bool:
41
+ # Normalize to sets for comparison
42
+ # Generic over any str subclass (including StrEnum)
43
+ # Accepts any sequence type (list, tuple, set, etc.)
44
+ # None is treated as empty set (no scopes)
45
+ user_scopes_set = set(user_scopes) if user_scopes is not None else set()
46
+ required_scopes_set = set(required_scopes)
47
+
48
+ # Check if user has any of the required scopes
49
+ return bool(user_scopes_set & required_scopes_set)
belgie/mcp.py ADDED
@@ -0,0 +1,12 @@
1
+ """MCP re-exports for belgie consumers."""
2
+
3
_MCP_IMPORT_ERROR = "belgie.mcp requires the 'mcp' extra. Install with: uv add belgie[mcp]"

# belgie_mcp is optional: report a missing install as an ImportError that
# carries an install hint instead of a bare ModuleNotFoundError.
# NOTE: this also triggers if belgie_mcp is present but one of its own
# imports is missing, since that raises ModuleNotFoundError too.
try:
    from belgie_mcp import hello  # type: ignore[import-not-found]
except ModuleNotFoundError as exc:
    raise ImportError(_MCP_IMPORT_ERROR) from exc

# Public names re-exported from belgie_mcp.
__all__ = [
    "hello",
]
belgie/oauth.py ADDED
@@ -0,0 +1,12 @@
1
+ """OAuth re-exports for belgie consumers."""
2
+
3
_OAUTH_IMPORT_ERROR = "belgie.oauth requires the 'oauth' extra. Install with: uv add belgie[oauth]"

# belgie_oauth is optional: report a missing install as an ImportError that
# carries an install hint instead of a bare ModuleNotFoundError.
# NOTE: this also triggers if belgie_oauth is present but one of its own
# imports is missing, since that raises ModuleNotFoundError too.
try:
    from belgie_oauth import hello  # type: ignore[import-not-found]
except ModuleNotFoundError as exc:
    raise ImportError(_OAUTH_IMPORT_ERROR) from exc

# Public names re-exported from belgie_oauth.
__all__ = [
    "hello",
]
belgie/proto.py ADDED
@@ -0,0 +1,22 @@
1
+ """Protocol re-exports for belgie consumers."""
2
+
3
_PROTO_IMPORT_ERROR = "belgie.proto requires belgie-proto. Install with: uv add belgie-proto"

# belgie_proto is a separate distribution: report a missing install as an
# ImportError that carries an install hint instead of a bare ModuleNotFoundError.
# NOTE: this also triggers if belgie_proto is present but one of its own
# imports is missing, since that raises ModuleNotFoundError too.
try:
    from belgie_proto import (  # type: ignore[import-not-found]
        AccountProtocol,
        AdapterProtocol,
        OAuthStateProtocol,
        SessionProtocol,
        UserProtocol,
    )
except ModuleNotFoundError as exc:
    raise ImportError(_PROTO_IMPORT_ERROR) from exc

# Protocol types re-exported from belgie_proto.
__all__ = [
    "AccountProtocol",
    "AdapterProtocol",
    "OAuthStateProtocol",
    "SessionProtocol",
    "UserProtocol",
]