belgie-alchemy 0.1.0a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- belgie_alchemy/__init__.py +33 -0
- belgie_alchemy/__tests__/__init__.py +0 -0
- belgie_alchemy/__tests__/adapter/__init__.py +0 -0
- belgie_alchemy/__tests__/adapter/test_adapter.py +493 -0
- belgie_alchemy/__tests__/auth_models/__init__.py +0 -0
- belgie_alchemy/__tests__/auth_models/test_auth_models.py +91 -0
- belgie_alchemy/__tests__/base/__init__.py +0 -0
- belgie_alchemy/__tests__/base/test_base.py +90 -0
- belgie_alchemy/__tests__/conftest.py +39 -0
- belgie_alchemy/__tests__/fixtures/__init__.py +12 -0
- belgie_alchemy/__tests__/fixtures/database.py +38 -0
- belgie_alchemy/__tests__/fixtures/models.py +119 -0
- belgie_alchemy/__tests__/mixins/__init__.py +0 -0
- belgie_alchemy/__tests__/mixins/test_mixins.py +80 -0
- belgie_alchemy/__tests__/settings/__init__.py +0 -0
- belgie_alchemy/__tests__/settings/test_settings.py +342 -0
- belgie_alchemy/__tests__/settings/test_settings_integration.py +416 -0
- belgie_alchemy/__tests__/types/__init__.py +0 -0
- belgie_alchemy/__tests__/types/test_types.py +155 -0
- belgie_alchemy/adapter.py +323 -0
- belgie_alchemy/base.py +25 -0
- belgie_alchemy/mixins.py +83 -0
- belgie_alchemy/py.typed +0 -0
- belgie_alchemy/settings.py +146 -0
- belgie_alchemy/types.py +32 -0
- belgie_alchemy-0.1.0a4.dist-info/METADATA +266 -0
- belgie_alchemy-0.1.0a4.dist-info/RECORD +28 -0
- belgie_alchemy-0.1.0a4.dist-info/WHEEL +4 -0
|
belgie_alchemy/adapter.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
from datetime import UTC, datetime
|
|
2
|
+
from typing import Any
|
|
3
|
+
from uuid import UUID
|
|
4
|
+
|
|
5
|
+
from belgie_proto import (
|
|
6
|
+
AccountProtocol,
|
|
7
|
+
AdapterProtocol,
|
|
8
|
+
OAuthStateProtocol,
|
|
9
|
+
SessionProtocol,
|
|
10
|
+
UserProtocol,
|
|
11
|
+
)
|
|
12
|
+
from sqlalchemy import delete, select
|
|
13
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class AlchemyAdapter[
|
|
17
|
+
UserT: UserProtocol,
|
|
18
|
+
AccountT: AccountProtocol,
|
|
19
|
+
SessionT: SessionProtocol,
|
|
20
|
+
OAuthStateT: OAuthStateProtocol,
|
|
21
|
+
](AdapterProtocol[UserT, AccountT, SessionT, OAuthStateT]):
|
|
22
|
+
def __init__(
|
|
23
|
+
self,
|
|
24
|
+
*,
|
|
25
|
+
user: type[UserT],
|
|
26
|
+
account: type[AccountT],
|
|
27
|
+
session: type[SessionT],
|
|
28
|
+
oauth_state: type[OAuthStateT],
|
|
29
|
+
) -> None:
|
|
30
|
+
self.user_model = user
|
|
31
|
+
self.account_model = account
|
|
32
|
+
self.session_model = session
|
|
33
|
+
self.oauth_state_model = oauth_state
|
|
34
|
+
|
|
35
|
+
async def create_user(
|
|
36
|
+
self,
|
|
37
|
+
session: AsyncSession,
|
|
38
|
+
email: str,
|
|
39
|
+
name: str | None = None,
|
|
40
|
+
image: str | None = None,
|
|
41
|
+
*,
|
|
42
|
+
email_verified: bool = False,
|
|
43
|
+
) -> UserT:
|
|
44
|
+
user = self.user_model(
|
|
45
|
+
email=email,
|
|
46
|
+
email_verified=email_verified,
|
|
47
|
+
name=name,
|
|
48
|
+
image=image,
|
|
49
|
+
)
|
|
50
|
+
session.add(user)
|
|
51
|
+
try:
|
|
52
|
+
await session.commit()
|
|
53
|
+
await session.refresh(user)
|
|
54
|
+
except Exception:
|
|
55
|
+
await session.rollback()
|
|
56
|
+
raise
|
|
57
|
+
return user
|
|
58
|
+
|
|
59
|
+
async def get_user_by_id(self, session: AsyncSession, user_id: UUID) -> UserT | None:
|
|
60
|
+
stmt = select(self.user_model).where(self.user_model.id == user_id)
|
|
61
|
+
result = await session.execute(stmt)
|
|
62
|
+
return result.scalar_one_or_none()
|
|
63
|
+
|
|
64
|
+
async def get_user_by_email(self, session: AsyncSession, email: str) -> UserT | None:
|
|
65
|
+
stmt = select(self.user_model).where(self.user_model.email == email)
|
|
66
|
+
result = await session.execute(stmt)
|
|
67
|
+
return result.scalar_one_or_none()
|
|
68
|
+
|
|
69
|
+
async def update_user(
|
|
70
|
+
self,
|
|
71
|
+
session: AsyncSession,
|
|
72
|
+
user_id: UUID,
|
|
73
|
+
**updates: Any, # noqa: ANN401
|
|
74
|
+
) -> UserT | None:
|
|
75
|
+
user = await self.get_user_by_id(session, user_id)
|
|
76
|
+
if not user:
|
|
77
|
+
return None
|
|
78
|
+
|
|
79
|
+
for key, value in updates.items():
|
|
80
|
+
if hasattr(user, key):
|
|
81
|
+
setattr(user, key, value)
|
|
82
|
+
|
|
83
|
+
user.updated_at = datetime.now(UTC)
|
|
84
|
+
try:
|
|
85
|
+
await session.commit()
|
|
86
|
+
await session.refresh(user)
|
|
87
|
+
except Exception:
|
|
88
|
+
await session.rollback()
|
|
89
|
+
raise
|
|
90
|
+
return user
|
|
91
|
+
|
|
92
|
+
async def create_account(
|
|
93
|
+
self,
|
|
94
|
+
session: AsyncSession,
|
|
95
|
+
user_id: UUID,
|
|
96
|
+
provider: str,
|
|
97
|
+
provider_account_id: str,
|
|
98
|
+
**tokens: Any, # noqa: ANN401
|
|
99
|
+
) -> AccountT:
|
|
100
|
+
account = self.account_model(
|
|
101
|
+
user_id=user_id,
|
|
102
|
+
provider=provider,
|
|
103
|
+
provider_account_id=provider_account_id,
|
|
104
|
+
access_token=tokens.get("access_token"),
|
|
105
|
+
refresh_token=tokens.get("refresh_token"),
|
|
106
|
+
expires_at=tokens.get("expires_at"),
|
|
107
|
+
token_type=tokens.get("token_type"),
|
|
108
|
+
scope=tokens.get("scope"),
|
|
109
|
+
id_token=tokens.get("id_token"),
|
|
110
|
+
)
|
|
111
|
+
session.add(account)
|
|
112
|
+
try:
|
|
113
|
+
await session.commit()
|
|
114
|
+
await session.refresh(account)
|
|
115
|
+
except Exception:
|
|
116
|
+
await session.rollback()
|
|
117
|
+
raise
|
|
118
|
+
return account
|
|
119
|
+
|
|
120
|
+
async def get_account(
|
|
121
|
+
self,
|
|
122
|
+
session: AsyncSession,
|
|
123
|
+
provider: str,
|
|
124
|
+
provider_account_id: str,
|
|
125
|
+
) -> AccountT | None:
|
|
126
|
+
stmt = select(self.account_model).where(
|
|
127
|
+
self.account_model.provider == provider,
|
|
128
|
+
self.account_model.provider_account_id == provider_account_id,
|
|
129
|
+
)
|
|
130
|
+
result = await session.execute(stmt)
|
|
131
|
+
return result.scalar_one_or_none()
|
|
132
|
+
|
|
133
|
+
async def get_account_by_user_and_provider(
|
|
134
|
+
self,
|
|
135
|
+
session: AsyncSession,
|
|
136
|
+
user_id: UUID,
|
|
137
|
+
provider: str,
|
|
138
|
+
) -> AccountT | None:
|
|
139
|
+
stmt = select(self.account_model).where(
|
|
140
|
+
self.account_model.user_id == user_id,
|
|
141
|
+
self.account_model.provider == provider,
|
|
142
|
+
)
|
|
143
|
+
result = await session.execute(stmt)
|
|
144
|
+
return result.scalar_one_or_none()
|
|
145
|
+
|
|
146
|
+
async def update_account(
|
|
147
|
+
self,
|
|
148
|
+
session: AsyncSession,
|
|
149
|
+
user_id: UUID,
|
|
150
|
+
provider: str,
|
|
151
|
+
**tokens: Any, # noqa: ANN401
|
|
152
|
+
) -> AccountT | None:
|
|
153
|
+
account = await self.get_account_by_user_and_provider(session, user_id, provider)
|
|
154
|
+
if not account:
|
|
155
|
+
return None
|
|
156
|
+
|
|
157
|
+
for key, value in tokens.items():
|
|
158
|
+
if hasattr(account, key) and value is not None:
|
|
159
|
+
setattr(account, key, value)
|
|
160
|
+
|
|
161
|
+
account.updated_at = datetime.now(UTC)
|
|
162
|
+
try:
|
|
163
|
+
await session.commit()
|
|
164
|
+
await session.refresh(account)
|
|
165
|
+
except Exception:
|
|
166
|
+
await session.rollback()
|
|
167
|
+
raise
|
|
168
|
+
return account
|
|
169
|
+
|
|
170
|
+
async def create_session(
|
|
171
|
+
self,
|
|
172
|
+
session: AsyncSession,
|
|
173
|
+
user_id: UUID,
|
|
174
|
+
expires_at: datetime,
|
|
175
|
+
ip_address: str | None = None,
|
|
176
|
+
user_agent: str | None = None,
|
|
177
|
+
) -> SessionT:
|
|
178
|
+
session_obj = self.session_model(
|
|
179
|
+
user_id=user_id,
|
|
180
|
+
expires_at=expires_at,
|
|
181
|
+
ip_address=ip_address,
|
|
182
|
+
user_agent=user_agent,
|
|
183
|
+
)
|
|
184
|
+
session.add(session_obj)
|
|
185
|
+
try:
|
|
186
|
+
await session.commit()
|
|
187
|
+
await session.refresh(session_obj)
|
|
188
|
+
except Exception:
|
|
189
|
+
await session.rollback()
|
|
190
|
+
raise
|
|
191
|
+
return session_obj
|
|
192
|
+
|
|
193
|
+
async def get_session(
|
|
194
|
+
self,
|
|
195
|
+
session: AsyncSession,
|
|
196
|
+
session_id: UUID,
|
|
197
|
+
) -> SessionT | None:
|
|
198
|
+
stmt = select(self.session_model).where(self.session_model.id == session_id)
|
|
199
|
+
result = await session.execute(stmt)
|
|
200
|
+
return result.scalar_one_or_none()
|
|
201
|
+
|
|
202
|
+
async def update_session(
|
|
203
|
+
self,
|
|
204
|
+
session: AsyncSession,
|
|
205
|
+
session_id: UUID,
|
|
206
|
+
**updates: Any, # noqa: ANN401
|
|
207
|
+
) -> SessionT | None:
|
|
208
|
+
session_obj = await self.get_session(session, session_id)
|
|
209
|
+
if not session_obj:
|
|
210
|
+
return None
|
|
211
|
+
|
|
212
|
+
for key, value in updates.items():
|
|
213
|
+
if hasattr(session_obj, key):
|
|
214
|
+
setattr(session_obj, key, value)
|
|
215
|
+
|
|
216
|
+
session_obj.updated_at = datetime.now(UTC)
|
|
217
|
+
try:
|
|
218
|
+
await session.commit()
|
|
219
|
+
await session.refresh(session_obj)
|
|
220
|
+
except Exception:
|
|
221
|
+
await session.rollback()
|
|
222
|
+
raise
|
|
223
|
+
return session_obj
|
|
224
|
+
|
|
225
|
+
async def delete_session(self, session: AsyncSession, session_id: UUID) -> bool:
|
|
226
|
+
stmt = delete(self.session_model).where(self.session_model.id == session_id)
|
|
227
|
+
result = await session.execute(stmt)
|
|
228
|
+
try:
|
|
229
|
+
await session.commit()
|
|
230
|
+
except Exception:
|
|
231
|
+
await session.rollback()
|
|
232
|
+
raise
|
|
233
|
+
return result.rowcount > 0 # type: ignore[attr-defined]
|
|
234
|
+
|
|
235
|
+
async def delete_expired_sessions(self, session: AsyncSession) -> int:
|
|
236
|
+
now_naive = datetime.now(UTC).replace(tzinfo=None)
|
|
237
|
+
stmt = delete(self.session_model).where(self.session_model.expires_at < now_naive)
|
|
238
|
+
result = await session.execute(stmt)
|
|
239
|
+
try:
|
|
240
|
+
await session.commit()
|
|
241
|
+
except Exception:
|
|
242
|
+
await session.rollback()
|
|
243
|
+
raise
|
|
244
|
+
return result.rowcount # type: ignore[attr-defined]
|
|
245
|
+
|
|
246
|
+
async def create_oauth_state(
|
|
247
|
+
self,
|
|
248
|
+
session: AsyncSession,
|
|
249
|
+
state: str,
|
|
250
|
+
expires_at: datetime,
|
|
251
|
+
code_verifier: str | None = None,
|
|
252
|
+
redirect_url: str | None = None,
|
|
253
|
+
) -> OAuthStateT:
|
|
254
|
+
# Create the model instance - some models have user_id, some don't
|
|
255
|
+
try:
|
|
256
|
+
oauth_state = self.oauth_state_model(
|
|
257
|
+
state=state,
|
|
258
|
+
user_id=None,
|
|
259
|
+
code_verifier=code_verifier,
|
|
260
|
+
redirect_url=redirect_url,
|
|
261
|
+
expires_at=expires_at,
|
|
262
|
+
)
|
|
263
|
+
except TypeError:
|
|
264
|
+
# Model doesn't accept user_id (like auth package models)
|
|
265
|
+
oauth_state = self.oauth_state_model(
|
|
266
|
+
state=state,
|
|
267
|
+
code_verifier=code_verifier,
|
|
268
|
+
redirect_url=redirect_url,
|
|
269
|
+
expires_at=expires_at,
|
|
270
|
+
)
|
|
271
|
+
session.add(oauth_state)
|
|
272
|
+
try:
|
|
273
|
+
await session.commit()
|
|
274
|
+
await session.refresh(oauth_state)
|
|
275
|
+
except Exception:
|
|
276
|
+
await session.rollback()
|
|
277
|
+
raise
|
|
278
|
+
return oauth_state
|
|
279
|
+
|
|
280
|
+
async def get_oauth_state(
|
|
281
|
+
self,
|
|
282
|
+
session: AsyncSession,
|
|
283
|
+
state: str,
|
|
284
|
+
) -> OAuthStateT | None:
|
|
285
|
+
stmt = select(self.oauth_state_model).where(self.oauth_state_model.state == state)
|
|
286
|
+
result = await session.execute(stmt)
|
|
287
|
+
return result.scalar_one_or_none()
|
|
288
|
+
|
|
289
|
+
async def delete_oauth_state(self, session: AsyncSession, state: str) -> bool:
|
|
290
|
+
stmt = delete(self.oauth_state_model).where(self.oauth_state_model.state == state)
|
|
291
|
+
result = await session.execute(stmt)
|
|
292
|
+
try:
|
|
293
|
+
await session.commit()
|
|
294
|
+
except Exception:
|
|
295
|
+
await session.rollback()
|
|
296
|
+
raise
|
|
297
|
+
return result.rowcount > 0 # type: ignore[attr-defined]
|
|
298
|
+
|
|
299
|
+
async def delete_user(self, session: AsyncSession, user_id: UUID) -> bool:
|
|
300
|
+
"""Delete a user and all associated data.
|
|
301
|
+
|
|
302
|
+
Deletes the user record. Related data (sessions, accounts) are automatically
|
|
303
|
+
deleted by the database via CASCADE constraints on the foreign keys.
|
|
304
|
+
|
|
305
|
+
Note: OAuth states are not user-specific and are not deleted.
|
|
306
|
+
They will expire based on their expires_at timestamp.
|
|
307
|
+
|
|
308
|
+
Args:
|
|
309
|
+
session: Database session
|
|
310
|
+
user_id: UUID of the user to delete
|
|
311
|
+
|
|
312
|
+
Returns:
|
|
313
|
+
True if user was deleted, False if user didn't exist
|
|
314
|
+
"""
|
|
315
|
+
stmt = delete(self.user_model).where(self.user_model.id == user_id)
|
|
316
|
+
result = await session.execute(stmt)
|
|
317
|
+
try:
|
|
318
|
+
await session.commit()
|
|
319
|
+
except Exception:
|
|
320
|
+
await session.rollback()
|
|
321
|
+
raise
|
|
322
|
+
|
|
323
|
+
return result.rowcount > 0 # type: ignore[attr-defined]
|
belgie_alchemy/base.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import Any, ClassVar, Final
|
|
3
|
+
|
|
4
|
+
from sqlalchemy import MetaData
|
|
5
|
+
from sqlalchemy.orm import DeclarativeBase, MappedAsDataclass
|
|
6
|
+
|
|
7
|
+
from belgie_alchemy.types import DateTimeUTC
|
|
8
|
+
|
|
9
|
+
# Constraint/index naming templates handed to ``MetaData(naming_convention=...)``
# on ``Base``, so generated names are deterministic across databases and
# migrations. Keys are SQLAlchemy's constraint-kind tokens:
# ix=index, uq=unique, ck=check, fk=foreign key, pk=primary key.
NAMING_CONVENTION: Final[dict[str, str]] = {
    "ix": "ix_%(column_0_label)s",
    "uq": "uq_%(table_name)s_%(column_0_name)s",
    # NOTE(review): %(constraint_name)s means check constraints need an explicit
    # name; per SQLAlchemy docs an unnamed CheckConstraint fails here — confirm.
    "ck": "ck_%(table_name)s_%(constraint_name)s",
    "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    "pk": "pk_%(table_name)s",
}
|
|
16
|
+
|
|
17
|
+
# Used as ``Base.type_annotation_map``: any mapped attribute annotated with a bare
# ``datetime`` (e.g. ``Mapped[datetime]``) gets the ``DateTimeUTC`` column type,
# so values are normalised to UTC on write and returned UTC-aware on read.
TYPE_ANNOTATION_MAP: Final[dict[type | Any, object]] = {
    datetime: DateTimeUTC,
}
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class Base(MappedAsDataclass, DeclarativeBase):
    """Declarative base for belgie models, mapped as dataclasses.

    - ``type_annotation_map`` makes plain ``datetime`` annotations use
      ``DateTimeUTC`` columns.
    - ``metadata`` carries ``NAMING_CONVENTION`` so constraint and index names
      are generated deterministically.
    """

    type_annotation_map = TYPE_ANNOTATION_MAP
    metadata = MetaData(naming_convention=NAMING_CONVENTION)

    # NOTE(review): MappedAsDataclass documents dataclass options as class keyword
    # arguments (e.g. ``class Base(MappedAsDataclass, DeclarativeBase, kw_only=True)``);
    # it does not appear to read ``__sa_dataclass_kwargs__``, so this mapping may be
    # inert and constructors may not actually be keyword-only — confirm.
    __sa_dataclass_kwargs__: ClassVar[dict[str, bool]] = {"kw_only": True, "repr": True, "eq": True}
|
belgie_alchemy/mixins.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
from datetime import UTC, datetime
from uuid import UUID, uuid4

from sqlalchemy import func
from sqlalchemy.orm import Mapped, MappedAsDataclass, declarative_mixin, mapped_column

from belgie_alchemy.types import DateTimeUTC
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@declarative_mixin
class PrimaryKeyMixin(MappedAsDataclass):
    """Give a mapped class a UUID ``id`` primary key.

    Subclassing ``MappedAsDataclass`` lets the mixin work on its own; combined
    with ``Base`` (which also derives from it) the repeated base is resolved by
    the MRO.

    ``id`` is not an ``__init__`` parameter (``init=False``). It is filled in
    client-side by ``uuid4`` and also carries a server-side default of
    ``gen_random_uuid()`` (a PostgreSQL function) as a fallback.

    Usage:
        class MyModel(Base, PrimaryKeyMixin, TimestampMixin):
            __tablename__ = "my_table"
            name: Mapped[str]
    """

    id: Mapped[UUID] = mapped_column(
        primary_key=True,
        init=False,
        # Client-side value generated when the dataclass __init__ runs.
        default_factory=uuid4,
        # Fallback for inserts that omit ``id``.
        server_default=func.gen_random_uuid(),
        # NOTE(review): a primary key is already indexed and unique, so these
        # flags likely add a redundant index/constraint; kept because removing
        # them would change the generated schema.
        index=True,
        unique=True,
    )
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@declarative_mixin
class TimestampMixin(MappedAsDataclass):
    """Mixin that adds automatic timestamp tracking columns.

    Inherits from MappedAsDataclass to support standalone usage without Base.
    When used with Base (which also inherits MappedAsDataclass), the duplicate
    inheritance is safely handled by Python's MRO (Method Resolution Order).

    All timestamp fields are excluded from __init__ (init=False) and use the
    UTC-normalising ``DateTimeUTC`` column type.

    Usage:
        class MyModel(Base, PrimaryKeyMixin, TimestampMixin):
            __tablename__ = "my_table"
            name: Mapped[str]

    Fields:
        created_at: Set automatically on insert (database ``now()``)
        updated_at: Set automatically on insert and update (database ``now()``)
        deleted_at: NULL by default, set via mark_deleted() for soft deletion

    Soft Deletion:
        Use mark_deleted() to mark an entity as deleted without removing it
        from the database. Remember to commit the session after calling.
    """

    created_at: Mapped[datetime] = mapped_column(DateTimeUTC, default=func.now(), init=False)
    updated_at: Mapped[datetime] = mapped_column(DateTimeUTC, default=func.now(), onupdate=func.now(), init=False)
    deleted_at: Mapped[datetime | None] = mapped_column(DateTimeUTC, nullable=True, default=None, init=False)

    def mark_deleted(self) -> None:
        """Mark this entity as deleted by setting ``deleted_at`` to the current UTC time.

        The timestamp is computed in Python as a timezone-aware UTC ``datetime``
        rather than assigning the SQL expression ``func.now()``. A SQL expression
        would leave a non-``datetime`` object on a ``Mapped[datetime | None]``
        attribute until flush, and SQLAlchemy may then expire the attribute and
        reload it on next access, which under ``AsyncSession`` is implicit I/O.

        Note: This only sets the field, it does not persist to the database.
        You must commit the session to save the change.

        Example:
            user.mark_deleted()
            await session.commit()  # Persist the soft delete
        """
        self.deleted_at = datetime.now(UTC)
|
belgie_alchemy/py.typed
ADDED
|
File without changes
|
|
belgie_alchemy/settings.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from functools import cached_property
|
|
5
|
+
from typing import TYPE_CHECKING, Annotated, Literal, cast
|
|
6
|
+
|
|
7
|
+
from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveInt, SecretStr
|
|
8
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
9
|
+
from sqlalchemy import event
|
|
10
|
+
from sqlalchemy.engine import URL
|
|
11
|
+
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
import sqlite3
|
|
15
|
+
from collections.abc import AsyncGenerator
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class PostgresSettings(BaseSettings):
    """PostgreSQL database settings.

    Environment variables use the prefix: BELGIE_POSTGRES_
    Example: BELGIE_POSTGRES_HOST=localhost

    ``host``, ``database``, ``username`` and ``password`` have no defaults and
    must be provided (keyword arguments or environment). The pool_* and echo
    values are forwarded to ``create_async_engine`` by ``DatabaseSettings.engine``.
    """

    # extra="ignore": inputs that don't match a field are dropped rather than rejected.
    model_config = SettingsConfigDict(env_prefix="BELGIE_POSTGRES_", extra="ignore")

    # Discriminator tag for the PostgresSettings | SqliteSettings union.
    type: Literal["postgres"] = "postgres"
    host: str
    port: PositiveInt = 5432
    database: str
    username: str
    # SecretStr keeps the password out of repr/logs; unwrapped only when building the URL.
    password: SecretStr
    pool_size: PositiveInt = 5
    max_overflow: NonNegativeInt = 10
    pool_timeout: NonNegativeFloat = 30.0  # seconds
    pool_recycle: PositiveInt = 3600  # seconds
    pool_pre_ping: bool = True
    echo: bool = False
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class SqliteSettings(BaseSettings):
    """SQLite database settings.

    Environment variables use the prefix: BELGIE_SQLITE_
    Example: BELGIE_SQLITE_DATABASE=:memory:

    ``database`` has no default and must be provided; it is passed as the
    database part of a ``sqlite+aiosqlite`` URL (a file path or ``:memory:``).
    """

    # extra="ignore": inputs that don't match a field are dropped rather than rejected.
    model_config = SettingsConfigDict(env_prefix="BELGIE_SQLITE_", extra="ignore")

    # Discriminator tag for the PostgresSettings | SqliteSettings union.
    type: Literal["sqlite"] = "sqlite"
    database: str
    # When True, DatabaseSettings.engine issues "PRAGMA foreign_keys=ON" on each
    # new connection (SQLite leaves FK enforcement off by default).
    enable_foreign_keys: bool = True
    echo: bool = False
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class DatabaseSettings(BaseSettings):
    """Database settings with support for PostgreSQL and SQLite.

    Environment variables:
    - BELGIE_DATABASE_TYPE: "postgres" or "sqlite" (default: "sqlite"); matched
      case-insensitively, and "postgresql" is accepted as an alias for "postgres"
    - For PostgreSQL: BELGIE_POSTGRES_HOST, BELGIE_POSTGRES_PORT, etc.
    - For SQLite: BELGIE_SQLITE_DATABASE, BELGIE_SQLITE_ENABLE_FOREIGN_KEYS, etc.

    Example usage:
        # From environment variables
        db = DatabaseSettings.from_env()

        # Direct instantiation
        db = DatabaseSettings(dialect={"type": "sqlite", "database": ":memory:"})
    """

    model_config = SettingsConfigDict(
        env_prefix="BELGIE_DATABASE_",
        extra="ignore",
    )

    # Tagged union: each dialect model's ``type`` literal selects the variant.
    dialect: Annotated[PostgresSettings | SqliteSettings, Field(discriminator="type")]

    @classmethod
    def from_env(cls) -> DatabaseSettings:
        """Load database settings from environment variables.

        Reads BELGIE_DATABASE_TYPE to determine which dialect to use,
        then loads the appropriate settings from BELGIE_POSTGRES_* or BELGIE_SQLITE_* vars.
        An unset or blank BELGIE_DATABASE_TYPE selects SQLite.

        Returns:
            DatabaseSettings instance configured from environment variables.

        Raises:
            ValueError: If BELGIE_DATABASE_TYPE names an unsupported dialect.

        Example:
            # Set environment variables
            os.environ["BELGIE_DATABASE_TYPE"] = "postgres"
            os.environ["BELGIE_POSTGRES_HOST"] = "localhost"
            os.environ["BELGIE_POSTGRES_DATABASE"] = "mydb"
            # ... other postgres settings

            db = DatabaseSettings.from_env()
        """
        raw_type = os.getenv("BELGIE_DATABASE_TYPE") or ""
        db_type = raw_type.strip().lower() or "sqlite"

        if db_type in {"postgres", "postgresql"}:
            return cls(dialect=PostgresSettings())  # type: ignore[call-arg]
        if db_type == "sqlite":
            return cls(dialect=SqliteSettings())  # type: ignore[call-arg]

        # Fail fast instead of silently falling back to SQLite: a typo such as
        # "postgress" would otherwise point a deployment at the wrong database.
        msg = f"Unsupported BELGIE_DATABASE_TYPE {raw_type!r}; expected 'postgres' or 'sqlite'"
        raise ValueError(msg)

    @cached_property
    def engine(self) -> AsyncEngine:
        """Build the async engine for the configured dialect (cached per instance).

        PostgreSQL uses the asyncpg driver with the configured pool options.
        SQLite uses aiosqlite and, when ``enable_foreign_keys`` is set, turns on
        foreign-key enforcement for every new DBAPI connection.
        """
        if self.dialect.type == "postgres":
            dialect = cast("PostgresSettings", self.dialect)
            url = URL.create(
                "postgresql+asyncpg",
                username=dialect.username,
                password=dialect.password.get_secret_value(),
                host=dialect.host,
                port=dialect.port,
                database=dialect.database,
            )
            return create_async_engine(
                url,
                echo=dialect.echo,
                pool_size=dialect.pool_size,
                max_overflow=dialect.max_overflow,
                pool_timeout=dialect.pool_timeout,
                pool_recycle=dialect.pool_recycle,
                pool_pre_ping=dialect.pool_pre_ping,
            )

        dialect = cast("SqliteSettings", self.dialect)
        url = URL.create("sqlite+aiosqlite", database=dialect.database)
        engine = create_async_engine(url, echo=dialect.echo)

        if dialect.enable_foreign_keys:
            # SQLite only enforces foreign keys (and ON DELETE CASCADE) when the
            # pragma is enabled on each connection, so hook every new connect.
            @event.listens_for(engine.sync_engine, "connect")
            def _enable_foreign_keys(dbapi_conn: sqlite3.Connection, _conn_record: object) -> None:
                cursor = dbapi_conn.cursor()
                cursor.execute("PRAGMA foreign_keys=ON")
                cursor.close()

        return engine

    @cached_property
    def session_maker(self) -> async_sessionmaker[AsyncSession]:
        """Session factory bound to ``engine`` (cached per instance).

        ``expire_on_commit=False`` keeps loaded attributes readable after commit
        without triggering a reload.
        """
        return async_sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False)

    async def dependency(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield one ``AsyncSession`` and close it when the generator finishes.

        NOTE(review): presumably meant as a web-framework dependency (e.g. FastAPI
        ``Depends``) — confirm against callers.
        """
        async with self.session_maker() as session:
            yield session
|
belgie_alchemy/types.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from datetime import UTC, datetime
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from sqlalchemy import DateTime
|
|
5
|
+
from sqlalchemy.types import TypeDecorator
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class DateTimeUTC(TypeDecorator[datetime]):
    """Timezone-aware ``DateTime`` column type that normalises values to UTC.

    Outgoing values are converted to UTC, with naive datetimes assumed to be
    UTC already. Incoming values are returned as UTC-aware datetimes; a naive
    value from the driver has UTC attached rather than being shifted.
    """

    cache_ok = True
    impl = DateTime(timezone=True)

    def process_bind_param(self, value: datetime | None, _dialect: Any) -> datetime | None:  # type: ignore[override] # noqa: ANN401
        """Validate an outgoing value and convert it to UTC.

        Raises:
            TypeError: If ``value`` is neither ``None`` nor a ``datetime``.
        """
        if value is not None and not isinstance(value, datetime):
            msg = (
                f"DateTimeUTC requires datetime object, got {type(value).__name__}. "
                "If using a date, convert to datetime first: "
                "datetime.combine(your_date, time())"
            )
            raise TypeError(msg)
        if value is None:
            return None
        aware = value if value.tzinfo is not None else value.replace(tzinfo=UTC)
        return aware.astimezone(UTC)

    def process_result_value(self, value: Any, _dialect: Any) -> datetime | None:  # type: ignore[override] # noqa: ANN401
        """Return a UTC-aware datetime for a value read from the database.

        Naive values (presumably from backends that store no offset, such as
        SQLite) are interpreted as UTC.
        """
        if value is None:
            return None
        aware = value if value.tzinfo is not None else value.replace(tzinfo=UTC)
        return aware.astimezone(UTC)
|