belgie 0.1.0a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- belgie/__init__.py +97 -0
- belgie/alchemy.py +24 -0
- belgie/auth/__init__.py +65 -0
- belgie/auth/core/__init__.py +0 -0
- belgie/auth/core/auth.py +368 -0
- belgie/auth/core/client.py +204 -0
- belgie/auth/core/exceptions.py +26 -0
- belgie/auth/core/hooks.py +87 -0
- belgie/auth/core/settings.py +100 -0
- belgie/auth/providers/__init__.py +10 -0
- belgie/auth/providers/google.py +284 -0
- belgie/auth/providers/protocols.py +43 -0
- belgie/auth/py.typed +0 -0
- belgie/auth/session/__init__.py +3 -0
- belgie/auth/session/manager.py +168 -0
- belgie/auth/utils/__init__.py +9 -0
- belgie/auth/utils/crypto.py +10 -0
- belgie/auth/utils/scopes.py +49 -0
- belgie/mcp.py +12 -0
- belgie/oauth.py +12 -0
- belgie/proto.py +22 -0
- belgie-0.1.0a4.dist-info/METADATA +231 -0
- belgie-0.1.0a4.dist-info/RECORD +24 -0
- belgie-0.1.0a4.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from uuid import UUID
|
|
3
|
+
|
|
4
|
+
from belgie_proto import (
|
|
5
|
+
AccountProtocol,
|
|
6
|
+
AdapterProtocol,
|
|
7
|
+
DBConnection,
|
|
8
|
+
OAuthStateProtocol,
|
|
9
|
+
SessionProtocol,
|
|
10
|
+
UserProtocol,
|
|
11
|
+
)
|
|
12
|
+
from fastapi import HTTPException, Request, status
|
|
13
|
+
from fastapi.security import SecurityScopes
|
|
14
|
+
|
|
15
|
+
from belgie.auth.core.hooks import HookContext, HookRunner, Hooks
|
|
16
|
+
from belgie.auth.session.manager import SessionManager
|
|
17
|
+
from belgie.auth.utils.scopes import validate_scopes
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass(frozen=True, slots=True, kw_only=True)
|
|
21
|
+
class AuthClient[
|
|
22
|
+
UserT: UserProtocol,
|
|
23
|
+
AccountT: AccountProtocol,
|
|
24
|
+
SessionT: SessionProtocol,
|
|
25
|
+
OAuthStateT: OAuthStateProtocol,
|
|
26
|
+
]:
|
|
27
|
+
"""Client for authentication operations with injected database session.
|
|
28
|
+
|
|
29
|
+
This class provides authentication methods with a captured database session,
|
|
30
|
+
allowing for convenient auth operations without explicitly passing db to each method.
|
|
31
|
+
|
|
32
|
+
Typically obtained via Auth.__call__() as a FastAPI dependency:
|
|
33
|
+
client: AuthClient = Depends(auth)
|
|
34
|
+
|
|
35
|
+
Type Parameters:
|
|
36
|
+
UserT: User model type implementing UserProtocol
|
|
37
|
+
AccountT: Account model type implementing AccountProtocol
|
|
38
|
+
SessionT: Session model type implementing SessionProtocol
|
|
39
|
+
OAuthStateT: OAuth state model type implementing OAuthStateProtocol
|
|
40
|
+
|
|
41
|
+
Attributes:
|
|
42
|
+
db: Captured database connection
|
|
43
|
+
adapter: Database adapter for persistence operations
|
|
44
|
+
session_manager: Session manager for session lifecycle operations
|
|
45
|
+
cookie_name: Name of the session cookie
|
|
46
|
+
|
|
47
|
+
Example:
|
|
48
|
+
>>> @app.delete("/account")
|
|
49
|
+
>>> async def delete_account(
|
|
50
|
+
... client: AuthClient = Depends(auth),
|
|
51
|
+
... request: Request,
|
|
52
|
+
... ):
|
|
53
|
+
... user = await client.get_user(SecurityScopes(), request)
|
|
54
|
+
... await client.delete_user(user)
|
|
55
|
+
... return {"message": "Account deleted"}
|
|
56
|
+
"""
|
|
57
|
+
|
|
58
|
+
db: DBConnection
|
|
59
|
+
adapter: AdapterProtocol[UserT, AccountT, SessionT, OAuthStateT]
|
|
60
|
+
session_manager: SessionManager[UserT, AccountT, SessionT, OAuthStateT]
|
|
61
|
+
cookie_name: str
|
|
62
|
+
hook_runner: HookRunner = field(default_factory=lambda: HookRunner(Hooks()))
|
|
63
|
+
|
|
64
|
+
async def _get_session_from_cookie(self, request: Request) -> SessionT | None:
|
|
65
|
+
"""Extract and validate session from request cookies.
|
|
66
|
+
|
|
67
|
+
Args:
|
|
68
|
+
request: FastAPI Request object containing cookies
|
|
69
|
+
|
|
70
|
+
Returns:
|
|
71
|
+
Valid session object or None if cookie missing/invalid/expired
|
|
72
|
+
"""
|
|
73
|
+
session_id_str = request.cookies.get(self.cookie_name)
|
|
74
|
+
if not session_id_str:
|
|
75
|
+
return None
|
|
76
|
+
|
|
77
|
+
try:
|
|
78
|
+
session_id = UUID(session_id_str)
|
|
79
|
+
except ValueError:
|
|
80
|
+
return None
|
|
81
|
+
|
|
82
|
+
return await self.session_manager.get_session(self.db, session_id)
|
|
83
|
+
|
|
84
|
+
async def get_user(self, security_scopes: SecurityScopes, request: Request) -> UserT:
|
|
85
|
+
"""Get the authenticated user from the request session.
|
|
86
|
+
|
|
87
|
+
Extracts the session from cookies, validates it, and returns the authenticated user.
|
|
88
|
+
Optionally validates user-level scopes if specified.
|
|
89
|
+
|
|
90
|
+
Args:
|
|
91
|
+
security_scopes: FastAPI SecurityScopes for scope validation
|
|
92
|
+
request: FastAPI Request object containing cookies
|
|
93
|
+
|
|
94
|
+
Returns:
|
|
95
|
+
Authenticated user object
|
|
96
|
+
|
|
97
|
+
Raises:
|
|
98
|
+
HTTPException: 401 if not authenticated or session invalid
|
|
99
|
+
HTTPException: 403 if required scopes are not granted
|
|
100
|
+
|
|
101
|
+
Example:
|
|
102
|
+
>>> user = await client.get_user(SecurityScopes(scopes=["read"]), request)
|
|
103
|
+
>>> print(user.email)
|
|
104
|
+
"""
|
|
105
|
+
session = await self._get_session_from_cookie(request)
|
|
106
|
+
if not session:
|
|
107
|
+
raise HTTPException(
|
|
108
|
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
109
|
+
detail="not authenticated",
|
|
110
|
+
)
|
|
111
|
+
|
|
112
|
+
user = await self.adapter.get_user_by_id(self.db, session.user_id)
|
|
113
|
+
if not user:
|
|
114
|
+
raise HTTPException(
|
|
115
|
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
116
|
+
detail="user not found",
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
# Validate user-level scopes if required
|
|
120
|
+
if security_scopes.scopes and not validate_scopes(user.scopes, security_scopes.scopes):
|
|
121
|
+
raise HTTPException(
|
|
122
|
+
status_code=status.HTTP_403_FORBIDDEN,
|
|
123
|
+
detail="Insufficient permissions",
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
return user
|
|
127
|
+
|
|
128
|
+
async def get_session(self, request: Request) -> SessionT:
|
|
129
|
+
"""Get the current session from the request.
|
|
130
|
+
|
|
131
|
+
Extracts and validates the session from cookies.
|
|
132
|
+
|
|
133
|
+
Args:
|
|
134
|
+
request: FastAPI Request object containing cookies
|
|
135
|
+
|
|
136
|
+
Returns:
|
|
137
|
+
Active session object
|
|
138
|
+
|
|
139
|
+
Raises:
|
|
140
|
+
HTTPException: 401 if not authenticated or session invalid/expired
|
|
141
|
+
|
|
142
|
+
Example:
|
|
143
|
+
>>> session = await client.get_session(request)
|
|
144
|
+
>>> print(session.expires_at)
|
|
145
|
+
"""
|
|
146
|
+
session = await self._get_session_from_cookie(request)
|
|
147
|
+
if not session:
|
|
148
|
+
raise HTTPException(
|
|
149
|
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
150
|
+
detail="not authenticated",
|
|
151
|
+
)
|
|
152
|
+
|
|
153
|
+
return session
|
|
154
|
+
|
|
155
|
+
async def delete_user(self, user: UserT) -> bool:
|
|
156
|
+
"""Delete a user and all associated data."""
|
|
157
|
+
async with self.hook_runner.dispatch("on_delete", HookContext(user=user, db=self.db)):
|
|
158
|
+
return await self.adapter.delete_user(self.db, user.id)
|
|
159
|
+
|
|
160
|
+
async def get_user_from_session(self, session_id: UUID) -> UserT | None:
|
|
161
|
+
"""Retrieve user from a session ID.
|
|
162
|
+
|
|
163
|
+
Args:
|
|
164
|
+
session_id: UUID of the session
|
|
165
|
+
|
|
166
|
+
Returns:
|
|
167
|
+
User object if session is valid and user exists, None otherwise
|
|
168
|
+
|
|
169
|
+
Example:
|
|
170
|
+
>>> from uuid import UUID
|
|
171
|
+
>>> session_id = UUID("...")
|
|
172
|
+
>>> user = await client.get_user_from_session(session_id)
|
|
173
|
+
>>> if user:
|
|
174
|
+
... print(f"Found user: {user.email}")
|
|
175
|
+
"""
|
|
176
|
+
session = await self.session_manager.get_session(self.db, session_id)
|
|
177
|
+
if not session:
|
|
178
|
+
return None
|
|
179
|
+
|
|
180
|
+
return await self.adapter.get_user_by_id(self.db, session.user_id)
|
|
181
|
+
|
|
182
|
+
async def sign_out(self, session_id: UUID) -> bool:
|
|
183
|
+
"""Sign out a user by deleting their session.
|
|
184
|
+
|
|
185
|
+
Args:
|
|
186
|
+
session_id: UUID of the session to delete
|
|
187
|
+
|
|
188
|
+
Returns:
|
|
189
|
+
True if session was deleted, False if session didn't exist
|
|
190
|
+
|
|
191
|
+
Example:
|
|
192
|
+
>>> session = await client.get_session(request)
|
|
193
|
+
>>> await client.sign_out(session.id)
|
|
194
|
+
"""
|
|
195
|
+
session = await self.session_manager.get_session(self.db, session_id)
|
|
196
|
+
if not session:
|
|
197
|
+
return False
|
|
198
|
+
|
|
199
|
+
user = await self.adapter.get_user_by_id(self.db, session.user_id)
|
|
200
|
+
if not user:
|
|
201
|
+
return False
|
|
202
|
+
|
|
203
|
+
async with self.hook_runner.dispatch("on_signout", HookContext(user=user, db=self.db)):
|
|
204
|
+
return await self.session_manager.delete_session(self.db, session_id)
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
class BelgieError(Exception):
    """Root of the belgie exception hierarchy; every other error here derives from it."""
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class AuthenticationError(BelgieError):
    """Error category for authentication failures; ``SessionExpiredError`` derives from it."""
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class AuthorizationError(BelgieError):
    """Error category for authorization (permission) failures."""
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class SessionExpiredError(AuthenticationError):
    """Authentication error signalling an expired session."""
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class InvalidStateError(BelgieError):
    """Error for an invalid state value; when it is raised is decided by its callers."""
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class OAuthError(BelgieError):
    """Error category for OAuth-related failures."""
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ConfigurationError(BelgieError):
    """Error category for invalid or incomplete configuration."""
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
from collections.abc import AsyncIterator, Awaitable, Callable, Sequence
|
|
5
|
+
from contextlib import AbstractAsyncContextManager, AbstractContextManager, AsyncExitStack, asynccontextmanager
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import TYPE_CHECKING, Literal, cast
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from belgie_proto import DBConnection
|
|
11
|
+
else: # pragma: no cover
|
|
12
|
+
DBConnection = object
|
|
13
|
+
|
|
14
|
+
from belgie_proto import UserProtocol
|
|
15
|
+
|
|
16
|
+
# Event names HookRunner.dispatch knows about; each one is also a field name on Hooks.
HookEvent = Literal["on_signup", "on_signin", "on_signout", "on_delete"]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(frozen=True, slots=True, kw_only=True)
|
|
20
|
+
class HookContext[UserT: UserProtocol]:
|
|
21
|
+
user: UserT
|
|
22
|
+
db: DBConnection
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# A plain hook: called with the context. If it returns an awaitable,
# HookRunner.dispatch awaits it before the wrapped operation runs.
type HookFunc = Callable[[HookContext], None | Awaitable[None]]
# A hook returning a sync or async context manager. HookRunner.dispatch keeps
# it entered for the duration of the wrapped operation.
type HookCtxMgr = Callable[[HookContext], AbstractContextManager[None] | AbstractAsyncContextManager[None]]
# Anything that may be registered on a Hooks field (alone or in a sequence).
type HookHandler = HookFunc | HookCtxMgr
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass(frozen=True, slots=True, kw_only=True)
|
|
31
|
+
class Hooks:
|
|
32
|
+
on_signup: HookHandler | Sequence[HookHandler] | None = None
|
|
33
|
+
on_signin: HookHandler | Sequence[HookHandler] | None = None
|
|
34
|
+
on_signout: HookHandler | Sequence[HookHandler] | None = None
|
|
35
|
+
on_delete: HookHandler | Sequence[HookHandler] | None = None
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class HookRunner:
    """Dispatch lifecycle events to the handlers stored on a ``Hooks`` object.

    Each handler is called with the ``HookContext``. Its return value decides
    how it takes part:

    * async context manager: entered, and kept open around the wrapped operation;
    * sync context manager: entered, and kept open around the wrapped operation;
    * awaitable: awaited before the wrapped operation runs;
    * anything else (typically ``None``): ignored.
    """

    def __init__(self, hooks: Hooks) -> None:
        self._hooks = hooks

    @asynccontextmanager
    async def dispatch(self, event: HookEvent | str, context: HookContext) -> AsyncIterator[None]:
        """Run the handlers for *event* around the body of an ``async with``.

        Handlers are invoked in registration order before control passes to the
        body. The exit stack closes any context managers they returned in reverse
        order once the body finishes or raises. An event with no handlers (including
        an unknown event name) just runs the body.
        """
        registered = self._normalize(self._handlers_for(event))

        if registered:
            async with AsyncExitStack() as exit_stack:
                for handler in registered:
                    outcome = handler(context)
                    if hasattr(outcome, "__aenter__") and hasattr(outcome, "__aexit__"):
                        await exit_stack.enter_async_context(outcome)  # type: ignore[arg-type]
                    elif hasattr(outcome, "__enter__") and hasattr(outcome, "__exit__"):
                        exit_stack.enter_context(outcome)  # type: ignore[arg-type]
                    elif inspect.isawaitable(outcome):
                        await outcome
                yield
        else:
            yield

    def _handlers_for(self, event: HookEvent | str) -> HookHandler | Sequence[HookHandler] | None:
        """Return the ``Hooks`` field named by *event*, or ``None`` for unknown events."""
        known_events = ("on_signup", "on_signin", "on_signout", "on_delete")
        return getattr(self._hooks, event) if event in known_events else None

    def _normalize(self, handlers: HookHandler | Sequence[HookHandler] | None) -> list[HookHandler]:
        """Turn a ``Hooks`` field value into a (possibly empty) list of handlers."""
        if handlers is None:
            return []
        # str/bytes are Sequences too, but must never be unpacked into "handlers".
        is_collection = isinstance(handlers, Sequence) and not isinstance(handlers, (str, bytes))
        if is_collection:
            return list(cast("Sequence[HookHandler]", handlers))
        return [cast("HookHandler", handlers)]
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
from typing import TYPE_CHECKING, Literal
|
|
3
|
+
|
|
4
|
+
from pydantic import Field, SecretStr, field_validator
|
|
5
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from belgie.auth.providers.protocols import OAuthProviderProtocol
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ProviderSettings(BaseSettings):
    """Shared base for OAuth provider settings.

    Declares the credentials every provider needs and validates that none of
    them is blank. Concrete provider settings inherit from this class and
    implement ``__call__`` to build their provider object.
    """

    client_id: str
    client_secret: SecretStr
    redirect_uri: str

    @field_validator("client_id", "redirect_uri")
    @classmethod
    def validate_non_empty(cls, v: str, info) -> str:  # noqa: ANN001
        """Reject blank values; return the value with surrounding whitespace removed."""
        cleaned = v.strip() if v else ""
        if cleaned:
            return cleaned
        msg = f"{info.field_name} must be a non-empty string"
        raise ValueError(msg)

    @field_validator("client_secret")
    @classmethod
    def validate_client_secret(cls, v: SecretStr) -> SecretStr:
        """Reject a blank secret; re-wrap the whitespace-trimmed value in a new SecretStr."""
        trimmed = v.get_secret_value().strip()
        if not trimmed:
            msg = "client_secret must be a non-empty string"
            raise ValueError(msg)
        return SecretStr(trimmed)

    @abstractmethod
    def __call__(self) -> "OAuthProviderProtocol":
        """Build the OAuth provider described by these settings.

        Returns:
            OAuth provider configured with these settings
        """
        ...
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class SessionSettings(BaseSettings):
    """Session timing settings, overridable via ``BELGIE_SESSION_*`` environment variables."""

    model_config = SettingsConfigDict(env_prefix="BELGIE_SESSION_")

    # 604800 = 7 * 24 * 60 * 60 -> one week, assuming the unit is seconds (confirm with SessionManager).
    max_age: int = 7 * 24 * 60 * 60
    # 86400 = 24 * 60 * 60 -> one day under the same seconds assumption; exact semantics live in SessionManager.
    update_age: int = 24 * 60 * 60
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class CookieSettings(BaseSettings):
    """Session-cookie attributes, overridable via ``BELGIE_COOKIE_*`` environment variables."""

    model_config = SettingsConfigDict(env_prefix="BELGIE_COOKIE_")

    name: str = "belgie_session"
    secure: bool = True
    http_only: bool = True
    # NOTE(review): browsers generally refuse SameSite=None cookies that are not
    # also Secure; nothing here enforces secure=True when same_site == "none".
    same_site: Literal["lax", "strict", "none"] = "lax"
    domain: str | None = None
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class URLSettings(BaseSettings):
    """Redirect targets, overridable via ``BELGIE_URLS_*`` environment variables."""

    model_config = SettingsConfigDict(env_prefix="BELGIE_URLS_")

    signin_redirect: str = "/dashboard"
    signout_redirect: str = "/"
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class AuthSettings(BaseSettings):
    """Top-level auth configuration.

    Read from ``BELGIE_*`` environment variables (case-insensitive) and from a
    UTF-8 ``.env`` file. ``secret`` and ``base_url`` are required and must not
    be blank.
    """

    model_config = SettingsConfigDict(
        case_sensitive=False,
        env_file=".env",
        env_file_encoding="utf-8",
        env_prefix="BELGIE_",
    )

    secret: str
    base_url: str

    # NOTE(review): these nested settings are built by their own default
    # factories, whose model_config sets no env_file. Confirm whether values for
    # them placed in .env are expected to be picked up.
    session: SessionSettings = Field(default_factory=SessionSettings)
    cookie: CookieSettings = Field(default_factory=CookieSettings)
    urls: URLSettings = Field(default_factory=URLSettings)

    @field_validator("secret", "base_url")
    @classmethod
    def validate_non_empty(cls, value: str, info) -> str:  # noqa: ANN001
        """Reject blank ``secret``/``base_url`` values and strip surrounding whitespace."""
        trimmed = value.strip() if value else ""
        if not trimmed:
            msg = f"{info.field_name} must be a non-empty string"
            raise ValueError(msg)
        return trimmed
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
from belgie.auth.providers.google import GoogleOAuthProvider, GoogleProviderSettings, GoogleUserInfo
|
|
2
|
+
from belgie.auth.providers.protocols import OAuthProviderProtocol, Providers
|
|
3
|
+
|
|
4
|
+
# Public API of this package: names re-exported from the google and protocols submodules.
__all__ = [
    "GoogleOAuthProvider",
    "GoogleProviderSettings",
    "GoogleUserInfo",
    "OAuthProviderProtocol",
    "Providers",
]
|