belgie-alchemy 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,323 @@
1
+ from datetime import UTC, datetime
2
+ from typing import Any
3
+ from uuid import UUID
4
+
5
+ from belgie_proto import (
6
+ AccountProtocol,
7
+ AdapterProtocol,
8
+ OAuthStateProtocol,
9
+ SessionProtocol,
10
+ UserProtocol,
11
+ )
12
+ from sqlalchemy import delete, select
13
+ from sqlalchemy.ext.asyncio import AsyncSession
14
+
15
+
16
+ class AlchemyAdapter[
17
+ UserT: UserProtocol,
18
+ AccountT: AccountProtocol,
19
+ SessionT: SessionProtocol,
20
+ OAuthStateT: OAuthStateProtocol,
21
+ ](AdapterProtocol[UserT, AccountT, SessionT, OAuthStateT]):
22
+ def __init__(
23
+ self,
24
+ *,
25
+ user: type[UserT],
26
+ account: type[AccountT],
27
+ session: type[SessionT],
28
+ oauth_state: type[OAuthStateT],
29
+ ) -> None:
30
+ self.user_model = user
31
+ self.account_model = account
32
+ self.session_model = session
33
+ self.oauth_state_model = oauth_state
34
+
35
+ async def create_user(
36
+ self,
37
+ session: AsyncSession,
38
+ email: str,
39
+ name: str | None = None,
40
+ image: str | None = None,
41
+ *,
42
+ email_verified: bool = False,
43
+ ) -> UserT:
44
+ user = self.user_model(
45
+ email=email,
46
+ email_verified=email_verified,
47
+ name=name,
48
+ image=image,
49
+ )
50
+ session.add(user)
51
+ try:
52
+ await session.commit()
53
+ await session.refresh(user)
54
+ except Exception:
55
+ await session.rollback()
56
+ raise
57
+ return user
58
+
59
+ async def get_user_by_id(self, session: AsyncSession, user_id: UUID) -> UserT | None:
60
+ stmt = select(self.user_model).where(self.user_model.id == user_id)
61
+ result = await session.execute(stmt)
62
+ return result.scalar_one_or_none()
63
+
64
+ async def get_user_by_email(self, session: AsyncSession, email: str) -> UserT | None:
65
+ stmt = select(self.user_model).where(self.user_model.email == email)
66
+ result = await session.execute(stmt)
67
+ return result.scalar_one_or_none()
68
+
69
+ async def update_user(
70
+ self,
71
+ session: AsyncSession,
72
+ user_id: UUID,
73
+ **updates: Any, # noqa: ANN401
74
+ ) -> UserT | None:
75
+ user = await self.get_user_by_id(session, user_id)
76
+ if not user:
77
+ return None
78
+
79
+ for key, value in updates.items():
80
+ if hasattr(user, key):
81
+ setattr(user, key, value)
82
+
83
+ user.updated_at = datetime.now(UTC)
84
+ try:
85
+ await session.commit()
86
+ await session.refresh(user)
87
+ except Exception:
88
+ await session.rollback()
89
+ raise
90
+ return user
91
+
92
+ async def create_account(
93
+ self,
94
+ session: AsyncSession,
95
+ user_id: UUID,
96
+ provider: str,
97
+ provider_account_id: str,
98
+ **tokens: Any, # noqa: ANN401
99
+ ) -> AccountT:
100
+ account = self.account_model(
101
+ user_id=user_id,
102
+ provider=provider,
103
+ provider_account_id=provider_account_id,
104
+ access_token=tokens.get("access_token"),
105
+ refresh_token=tokens.get("refresh_token"),
106
+ expires_at=tokens.get("expires_at"),
107
+ token_type=tokens.get("token_type"),
108
+ scope=tokens.get("scope"),
109
+ id_token=tokens.get("id_token"),
110
+ )
111
+ session.add(account)
112
+ try:
113
+ await session.commit()
114
+ await session.refresh(account)
115
+ except Exception:
116
+ await session.rollback()
117
+ raise
118
+ return account
119
+
120
+ async def get_account(
121
+ self,
122
+ session: AsyncSession,
123
+ provider: str,
124
+ provider_account_id: str,
125
+ ) -> AccountT | None:
126
+ stmt = select(self.account_model).where(
127
+ self.account_model.provider == provider,
128
+ self.account_model.provider_account_id == provider_account_id,
129
+ )
130
+ result = await session.execute(stmt)
131
+ return result.scalar_one_or_none()
132
+
133
+ async def get_account_by_user_and_provider(
134
+ self,
135
+ session: AsyncSession,
136
+ user_id: UUID,
137
+ provider: str,
138
+ ) -> AccountT | None:
139
+ stmt = select(self.account_model).where(
140
+ self.account_model.user_id == user_id,
141
+ self.account_model.provider == provider,
142
+ )
143
+ result = await session.execute(stmt)
144
+ return result.scalar_one_or_none()
145
+
146
+ async def update_account(
147
+ self,
148
+ session: AsyncSession,
149
+ user_id: UUID,
150
+ provider: str,
151
+ **tokens: Any, # noqa: ANN401
152
+ ) -> AccountT | None:
153
+ account = await self.get_account_by_user_and_provider(session, user_id, provider)
154
+ if not account:
155
+ return None
156
+
157
+ for key, value in tokens.items():
158
+ if hasattr(account, key) and value is not None:
159
+ setattr(account, key, value)
160
+
161
+ account.updated_at = datetime.now(UTC)
162
+ try:
163
+ await session.commit()
164
+ await session.refresh(account)
165
+ except Exception:
166
+ await session.rollback()
167
+ raise
168
+ return account
169
+
170
+ async def create_session(
171
+ self,
172
+ session: AsyncSession,
173
+ user_id: UUID,
174
+ expires_at: datetime,
175
+ ip_address: str | None = None,
176
+ user_agent: str | None = None,
177
+ ) -> SessionT:
178
+ session_obj = self.session_model(
179
+ user_id=user_id,
180
+ expires_at=expires_at,
181
+ ip_address=ip_address,
182
+ user_agent=user_agent,
183
+ )
184
+ session.add(session_obj)
185
+ try:
186
+ await session.commit()
187
+ await session.refresh(session_obj)
188
+ except Exception:
189
+ await session.rollback()
190
+ raise
191
+ return session_obj
192
+
193
+ async def get_session(
194
+ self,
195
+ session: AsyncSession,
196
+ session_id: UUID,
197
+ ) -> SessionT | None:
198
+ stmt = select(self.session_model).where(self.session_model.id == session_id)
199
+ result = await session.execute(stmt)
200
+ return result.scalar_one_or_none()
201
+
202
+ async def update_session(
203
+ self,
204
+ session: AsyncSession,
205
+ session_id: UUID,
206
+ **updates: Any, # noqa: ANN401
207
+ ) -> SessionT | None:
208
+ session_obj = await self.get_session(session, session_id)
209
+ if not session_obj:
210
+ return None
211
+
212
+ for key, value in updates.items():
213
+ if hasattr(session_obj, key):
214
+ setattr(session_obj, key, value)
215
+
216
+ session_obj.updated_at = datetime.now(UTC)
217
+ try:
218
+ await session.commit()
219
+ await session.refresh(session_obj)
220
+ except Exception:
221
+ await session.rollback()
222
+ raise
223
+ return session_obj
224
+
225
+ async def delete_session(self, session: AsyncSession, session_id: UUID) -> bool:
226
+ stmt = delete(self.session_model).where(self.session_model.id == session_id)
227
+ result = await session.execute(stmt)
228
+ try:
229
+ await session.commit()
230
+ except Exception:
231
+ await session.rollback()
232
+ raise
233
+ return result.rowcount > 0 # type: ignore[attr-defined]
234
+
235
+ async def delete_expired_sessions(self, session: AsyncSession) -> int:
236
+ now_naive = datetime.now(UTC).replace(tzinfo=None)
237
+ stmt = delete(self.session_model).where(self.session_model.expires_at < now_naive)
238
+ result = await session.execute(stmt)
239
+ try:
240
+ await session.commit()
241
+ except Exception:
242
+ await session.rollback()
243
+ raise
244
+ return result.rowcount # type: ignore[attr-defined]
245
+
246
+ async def create_oauth_state(
247
+ self,
248
+ session: AsyncSession,
249
+ state: str,
250
+ expires_at: datetime,
251
+ code_verifier: str | None = None,
252
+ redirect_url: str | None = None,
253
+ ) -> OAuthStateT:
254
+ # Create the model instance - some models have user_id, some don't
255
+ try:
256
+ oauth_state = self.oauth_state_model(
257
+ state=state,
258
+ user_id=None,
259
+ code_verifier=code_verifier,
260
+ redirect_url=redirect_url,
261
+ expires_at=expires_at,
262
+ )
263
+ except TypeError:
264
+ # Model doesn't accept user_id (like auth package models)
265
+ oauth_state = self.oauth_state_model(
266
+ state=state,
267
+ code_verifier=code_verifier,
268
+ redirect_url=redirect_url,
269
+ expires_at=expires_at,
270
+ )
271
+ session.add(oauth_state)
272
+ try:
273
+ await session.commit()
274
+ await session.refresh(oauth_state)
275
+ except Exception:
276
+ await session.rollback()
277
+ raise
278
+ return oauth_state
279
+
280
+ async def get_oauth_state(
281
+ self,
282
+ session: AsyncSession,
283
+ state: str,
284
+ ) -> OAuthStateT | None:
285
+ stmt = select(self.oauth_state_model).where(self.oauth_state_model.state == state)
286
+ result = await session.execute(stmt)
287
+ return result.scalar_one_or_none()
288
+
289
+ async def delete_oauth_state(self, session: AsyncSession, state: str) -> bool:
290
+ stmt = delete(self.oauth_state_model).where(self.oauth_state_model.state == state)
291
+ result = await session.execute(stmt)
292
+ try:
293
+ await session.commit()
294
+ except Exception:
295
+ await session.rollback()
296
+ raise
297
+ return result.rowcount > 0 # type: ignore[attr-defined]
298
+
299
+ async def delete_user(self, session: AsyncSession, user_id: UUID) -> bool:
300
+ """Delete a user and all associated data.
301
+
302
+ Deletes the user record. Related data (sessions, accounts) are automatically
303
+ deleted by the database via CASCADE constraints on the foreign keys.
304
+
305
+ Note: OAuth states are not user-specific and are not deleted.
306
+ They will expire based on their expires_at timestamp.
307
+
308
+ Args:
309
+ session: Database session
310
+ user_id: UUID of the user to delete
311
+
312
+ Returns:
313
+ True if user was deleted, False if user didn't exist
314
+ """
315
+ stmt = delete(self.user_model).where(self.user_model.id == user_id)
316
+ result = await session.execute(stmt)
317
+ try:
318
+ await session.commit()
319
+ except Exception:
320
+ await session.rollback()
321
+ raise
322
+
323
+ return result.rowcount > 0 # type: ignore[attr-defined]
belgie_alchemy/base.py ADDED
@@ -0,0 +1,25 @@
1
+ from datetime import datetime
2
+ from typing import Any, ClassVar, Final
3
+
4
+ from sqlalchemy import MetaData
5
+ from sqlalchemy.orm import DeclarativeBase, MappedAsDataclass
6
+
7
+ from belgie_alchemy.types import DateTimeUTC
8
+
9
# Deterministic names for indexes and constraints created through the shared
# MetaData. Keys are SQLAlchemy naming-convention tokens; values are the
# %-style templates SQLAlchemy fills in.
NAMING_CONVENTION: Final[dict[str, str]] = {
    "ix": "ix_%(column_0_label)s",  # indexes
    "uq": "uq_%(table_name)s_%(column_0_name)s",  # unique constraints
    "ck": "ck_%(table_name)s_%(constraint_name)s",  # check constraints (must be named)
    "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",  # foreign keys
    "pk": "pk_%(table_name)s",  # primary keys
}
16
+
17
# Python annotation -> SQL column type used by the declarative Base, so that a
# plain ``Mapped[datetime]`` column is stored with the DateTimeUTC type.
TYPE_ANNOTATION_MAP: Final[dict[type | Any, object]] = {datetime: DateTimeUTC}
20
+
21
+
22
class Base(MappedAsDataclass, DeclarativeBase):
    """Declarative, dataclass-mapped base class for belgie ORM models.

    Models inheriting from it share one ``MetaData`` configured with
    ``NAMING_CONVENTION``. They resolve Python annotations to column types
    through ``TYPE_ANNOTATION_MAP``, which maps ``datetime`` to DateTimeUTC.
    """

    type_annotation_map = TYPE_ANNOTATION_MAP
    metadata = MetaData(naming_convention=NAMING_CONVENTION)

    # NOTE(review): ``__sa_dataclass_kwargs__`` does not appear to be a
    # documented SQLAlchemy hook; the documented way to pass dataclass options
    # is class keywords, e.g. ``class Base(MappedAsDataclass, DeclarativeBase,
    # kw_only=True)``. Confirm that these options actually take effect before
    # relying on keyword-only constructors.
    __sa_dataclass_kwargs__: ClassVar[dict[str, bool]] = {"kw_only": True, "repr": True, "eq": True}
@@ -0,0 +1,83 @@
1
+ from datetime import datetime
2
+ from uuid import UUID, uuid4
3
+
4
+ from sqlalchemy import func
5
+ from sqlalchemy.orm import Mapped, MappedAsDataclass, declarative_mixin, mapped_column
6
+
7
+ from belgie_alchemy.types import DateTimeUTC
8
+
9
+
10
@declarative_mixin
class PrimaryKeyMixin(MappedAsDataclass):
    """Give a mapped class a UUID ``id`` primary key.

    Subclassing MappedAsDataclass lets this mixin be used on its own as well
    as alongside ``Base`` (also a MappedAsDataclass); the MRO collapses the
    shared base class.

    ``id`` is not an ``__init__`` argument (``init=False``). Its value comes
    from ``uuid4()`` on the Python side. ``gen_random_uuid()`` is declared as
    a server-side default fallback; that function is PostgreSQL-specific.

    Usage:
        class MyModel(Base, PrimaryKeyMixin, TimestampMixin):
            __tablename__ = "my_table"
            name: Mapped[str]
    """

    # NOTE(review): a primary key is already backed by a unique index on
    # PostgreSQL and SQLite, so ``unique=True``/``index=True`` likely add a
    # second, redundant index. They are left as-is because removing them
    # changes the generated schema — confirm before dropping.
    id: Mapped[UUID] = mapped_column(
        init=False,
        primary_key=True,
        unique=True,
        index=True,
        default_factory=uuid4,
        server_default=func.gen_random_uuid(),
    )
41
+
42
+
43
@declarative_mixin
class TimestampMixin(MappedAsDataclass):
    """Add ``created_at``, ``updated_at`` and ``deleted_at`` timestamp columns.

    Subclassing MappedAsDataclass lets this mixin be used on its own as well
    as alongside ``Base`` (also a MappedAsDataclass); the MRO collapses the
    shared base class.

    All three fields are excluded from ``__init__`` (``init=False``) and use
    the DateTimeUTC column type.

    Usage:
        class MyModel(Base, PrimaryKeyMixin, TimestampMixin):
            __tablename__ = "my_table"
            name: Mapped[str]

    Fields:
        created_at: Defaults to SQL ``now()`` on insert.
        updated_at: Defaults to SQL ``now()`` on insert and again on update.
        deleted_at: ``None`` until ``mark_deleted()`` is called (soft delete).

    Soft Deletion:
        ``mark_deleted()`` marks an entity as deleted without removing the
        row. Commit the session afterwards to persist it.
    """

    # NOTE(review): with MappedAsDataclass, ``default=`` is also used as the
    # dataclass default, so before flush these attributes may hold the
    # ``func.now()`` expression rather than a datetime. SQLAlchemy's dataclass
    # docs suggest ``insert_default=`` for SQL-side defaults — confirm before
    # changing, since it alters the mapping.
    created_at: Mapped[datetime] = mapped_column(DateTimeUTC, default=func.now(), init=False)
    updated_at: Mapped[datetime] = mapped_column(DateTimeUTC, default=func.now(), onupdate=func.now(), init=False)
    deleted_at: Mapped[datetime | None] = mapped_column(DateTimeUTC, nullable=True, default=None, init=False)

    def mark_deleted(self) -> None:
        """Soft-delete this entity by stamping ``deleted_at`` with the current UTC time.

        Only the in-memory attribute is set; commit the session to persist it::

            user.mark_deleted()
            await session.commit()

        A concrete aware ``datetime`` is assigned rather than the SQL ``now()``
        expression, for two reasons. First, the attribute is then a real
        datetime straight away. Second, SQLAlchemy expires attributes that were
        assigned SQL expressions once they are flushed. Reading such an
        attribute back would need an implicit lazy load, which an
        ``AsyncSession`` cannot perform.
        """
        # Local import: this module only imports the ``datetime`` class.
        from datetime import UTC

        self.deleted_at = datetime.now(UTC)
File without changes
@@ -0,0 +1,146 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from functools import cached_property
5
+ from typing import TYPE_CHECKING, Annotated, Literal, cast
6
+
7
+ from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveInt, SecretStr
8
+ from pydantic_settings import BaseSettings, SettingsConfigDict
9
+ from sqlalchemy import event
10
+ from sqlalchemy.engine import URL
11
+ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
12
+
13
+ if TYPE_CHECKING:
14
+ import sqlite3
15
+ from collections.abc import AsyncGenerator
16
+
17
+
18
class PostgresSettings(BaseSettings):
    """Connection and connection-pool settings for PostgreSQL.

    Each field can be supplied through an environment variable named
    ``BELGIE_POSTGRES_<FIELD>`` (for example ``BELGIE_POSTGRES_HOST=localhost``).
    Unrecognised keys are ignored.
    """

    model_config = SettingsConfigDict(extra="ignore", env_prefix="BELGIE_POSTGRES_")

    # Discriminator value for DatabaseSettings.dialect.
    type: Literal["postgres"] = "postgres"

    # Connection parameters, assembled into a ``postgresql+asyncpg`` URL by
    # DatabaseSettings.engine.
    host: str
    port: PositiveInt = 5432
    database: str
    username: str
    password: SecretStr

    # Passed straight through to ``create_async_engine`` by DatabaseSettings.engine.
    pool_size: PositiveInt = 5
    max_overflow: NonNegativeInt = 10
    pool_timeout: NonNegativeFloat = 30.0
    pool_recycle: PositiveInt = 3600
    pool_pre_ping: bool = True
    echo: bool = False
39
+
40
+
41
class SqliteSettings(BaseSettings):
    """Settings for a SQLite database (driven through aiosqlite).

    Each field can be supplied through an environment variable named
    ``BELGIE_SQLITE_<FIELD>`` (for example ``BELGIE_SQLITE_DATABASE=:memory:``).
    Unrecognised keys are ignored.
    """

    model_config = SettingsConfigDict(extra="ignore", env_prefix="BELGIE_SQLITE_")

    # Discriminator value for DatabaseSettings.dialect.
    type: Literal["sqlite"] = "sqlite"

    # Database file path, or ":memory:".
    database: str

    # When true, DatabaseSettings.engine runs ``PRAGMA foreign_keys=ON`` on
    # every new DBAPI connection.
    enable_foreign_keys: bool = True

    # Forwarded as ``echo=`` to ``create_async_engine``.
    echo: bool = False
54
+
55
+
56
class DatabaseSettings(BaseSettings):
    """Database settings with support for PostgreSQL and SQLite.

    Environment variables:
        - BELGIE_DATABASE_TYPE: "postgres" or "sqlite" (default: "sqlite"),
          read by ``from_env``.
        - For PostgreSQL: BELGIE_POSTGRES_HOST, BELGIE_POSTGRES_PORT, etc.
        - For SQLite: BELGIE_SQLITE_DATABASE, BELGIE_SQLITE_ENABLE_FOREIGN_KEYS, etc.

    Example usage:
        # From environment variables
        db = DatabaseSettings.from_env()

        # Direct instantiation
        db = DatabaseSettings(dialect={"type": "sqlite", "database": ":memory:"})
    """

    model_config = SettingsConfigDict(
        env_prefix="BELGIE_DATABASE_",
        extra="ignore",
    )

    # Tagged union: each dialect model's ``type`` field selects which one is parsed.
    dialect: Annotated[PostgresSettings | SqliteSettings, Field(discriminator="type")]

    @classmethod
    def from_env(cls) -> DatabaseSettings:
        """Load database settings from environment variables.

        Reads BELGIE_DATABASE_TYPE to pick the dialect, then loads that
        dialect's settings from the BELGIE_POSTGRES_* or BELGIE_SQLITE_*
        variables. The type is matched case-insensitively and surrounding
        whitespace is ignored. Unset or empty means "sqlite".

        Returns:
            DatabaseSettings instance configured from environment variables.

        Raises:
            ValueError: If BELGIE_DATABASE_TYPE is set to an unsupported value.

        Example:
            # Set environment variables
            os.environ["BELGIE_DATABASE_TYPE"] = "postgres"
            os.environ["BELGIE_POSTGRES_HOST"] = "localhost"
            os.environ["BELGIE_POSTGRES_DATABASE"] = "mydb"
            # ... other postgres settings

            db = DatabaseSettings.from_env()
        """
        raw_type = os.getenv("BELGIE_DATABASE_TYPE", "sqlite")
        db_type = raw_type.strip().lower() or "sqlite"

        if db_type == "postgres":
            return cls(dialect=PostgresSettings())  # type: ignore[call-arg]
        if db_type == "sqlite":
            return cls(dialect=SqliteSettings())  # type: ignore[call-arg]

        # Fail loudly rather than silently using SQLite for a typo or an
        # alternate spelling such as "postgresql".
        msg = f"Unsupported BELGIE_DATABASE_TYPE {raw_type!r}: expected 'postgres' or 'sqlite'"
        raise ValueError(msg)

    @cached_property
    def engine(self) -> AsyncEngine:
        """Build the AsyncEngine for the configured dialect (cached per instance).

        PostgreSQL uses the ``postgresql+asyncpg`` driver with the configured
        pool options. SQLite uses ``sqlite+aiosqlite``. When
        ``enable_foreign_keys`` is set, a ``connect`` listener also runs
        ``PRAGMA foreign_keys=ON`` on each new DBAPI connection.
        """
        if self.dialect.type == "postgres":
            dialect = cast("PostgresSettings", self.dialect)
            url = URL.create(
                "postgresql+asyncpg",
                username=dialect.username,
                password=dialect.password.get_secret_value(),
                host=dialect.host,
                port=dialect.port,
                database=dialect.database,
            )
            return create_async_engine(
                url,
                echo=dialect.echo,
                pool_size=dialect.pool_size,
                max_overflow=dialect.max_overflow,
                pool_timeout=dialect.pool_timeout,
                pool_recycle=dialect.pool_recycle,
                pool_pre_ping=dialect.pool_pre_ping,
            )

        dialect = cast("SqliteSettings", self.dialect)
        url = URL.create("sqlite+aiosqlite", database=dialect.database)
        engine = create_async_engine(url, echo=dialect.echo)

        if dialect.enable_foreign_keys:

            @event.listens_for(engine.sync_engine, "connect")
            def _enable_foreign_keys(dbapi_conn: sqlite3.Connection, _conn_record: object) -> None:
                cursor = dbapi_conn.cursor()
                cursor.execute("PRAGMA foreign_keys=ON")
                cursor.close()

        return engine

    @cached_property
    def session_maker(self) -> async_sessionmaker[AsyncSession]:
        """Session factory bound to ``engine``; objects are not expired on commit."""
        return async_sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)

    async def dependency(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield one AsyncSession, closing it when the generator is finished."""
        async with self.session_maker() as session:
            yield session
@@ -0,0 +1,32 @@
1
+ from datetime import UTC, datetime
2
+ from typing import Any
3
+
4
+ from sqlalchemy import DateTime
5
+ from sqlalchemy.types import TypeDecorator
6
+
7
+
8
class DateTimeUTC(TypeDecorator[datetime]):
    """Timezone-aware ``DateTime`` column type that normalises values to UTC.

    Bound parameters must be ``datetime`` instances; ``None`` passes through.
    A naive value is interpreted as UTC and an aware value is converted to
    UTC. Values coming back from the driver get the same treatment.
    """

    impl = DateTime(timezone=True)
    cache_ok = True

    @staticmethod
    def _as_utc(moment: Any) -> datetime:  # noqa: ANN401
        """Return ``moment`` as an aware UTC datetime, treating naive input as UTC."""
        aware = moment.replace(tzinfo=UTC) if moment.tzinfo is None else moment
        return aware.astimezone(UTC)

    def process_bind_param(self, value: datetime | None, _dialect: Any) -> datetime | None:  # type: ignore[override] # noqa: ANN401
        if value is None:
            return None
        if isinstance(value, datetime):
            return self._as_utc(value)
        # Anything else (notably a plain ``date``) is rejected with a hint.
        type_name = type(value).__name__
        msg = (
            f"DateTimeUTC requires datetime object, got {type_name}. "
            "If using a date, convert to datetime first: "
            "datetime.combine(your_date, time())"
        )
        raise TypeError(msg)

    def process_result_value(self, value: Any, _dialect: Any) -> datetime | None:  # type: ignore[override] # noqa: ANN401
        return None if value is None else self._as_utc(value)
+ return value.astimezone(UTC)