vention-storage 0.1.0__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- storage/accessor.py +280 -0
- storage/auditor.py +69 -0
- storage/bootstrap.py +47 -0
- storage/database.py +107 -0
- storage/hooks.py +47 -0
- storage/io_helpers.py +171 -0
- storage/router_database.py +275 -0
- storage/router_model.py +247 -0
- storage/utils.py +20 -0
- vention_storage-0.5.2.dist-info/METADATA +318 -0
- vention_storage-0.5.2.dist-info/RECORD +13 -0
- {vention_storage-0.1.0.dist-info → vention_storage-0.5.2.dist-info}/WHEEL +1 -1
- vention_storage-0.1.0.dist-info/METADATA +0 -17
- vention_storage-0.1.0.dist-info/RECORD +0 -4
storage/accessor.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from contextlib import contextmanager
|
|
4
|
+
from typing import (
|
|
5
|
+
Any,
|
|
6
|
+
Callable,
|
|
7
|
+
Generic,
|
|
8
|
+
List,
|
|
9
|
+
Optional,
|
|
10
|
+
Sequence,
|
|
11
|
+
Type,
|
|
12
|
+
TypeVar,
|
|
13
|
+
cast,
|
|
14
|
+
Iterator,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
from sqlmodel import SQLModel, Session, select
|
|
18
|
+
|
|
19
|
+
from storage.auditor import audit_operation
|
|
20
|
+
from storage import database
|
|
21
|
+
from storage.hooks import HookFn, HookRegistry, HookEvent
|
|
22
|
+
from storage.utils import ModelType, utcnow, Operation
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# Return type of the callback passed to `ModelAccessor._run_write`; lets each
# write operation keep its own precise result type (model, list, bool, int).
WriteResult = TypeVar("WriteResult")
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ModelAccessor(Generic[ModelType]):
    """
    Data-access facade for a single SQLModel table.

    Provides:
    - typed lifecycle hooks (before/after insert, update, delete), registered
      through the decorator-returning methods below
    - writes that run atomically together with their audit-log entry
    - soft delete for models that define a `deleted_at` attribute
    - batch insert/delete helpers
    - implicit session sharing: hooks run with the active session installed as
      the current session, so accessor calls made from a hook join it
      (no explicit binding needed)
    """

    def __init__(self, model: Type[ModelType], component_name: str) -> None:
        """
        Args:
            model: SQLModel table class this accessor reads and writes.
            component_name: Label stored in the `component` field of every
                audit entry this accessor produces.
        """
        self.model = model
        self.component = component_name
        self._hooks: HookRegistry[ModelType] = HookRegistry()
        # Soft delete is opted into purely by the presence of `deleted_at`.
        self._has_soft_delete = hasattr(model, "deleted_at")

    # ---------- Hook decorators ----------
    def before_insert(self) -> Callable[[HookFn[ModelType]], HookFn[ModelType]]:
        """Return a decorator registering a hook run before a new row is added to the session."""
        return self._hooks.decorator("before_insert")

    def after_insert(self) -> Callable[[HookFn[ModelType]], HookFn[ModelType]]:
        """Return a decorator registering a hook run after a row is flushed, refreshed and audited."""
        return self._hooks.decorator("after_insert")

    def before_update(self) -> Callable[[HookFn[ModelType]], HookFn[ModelType]]:
        """Return a decorator registering a hook run in `save` after merging and before flushing."""
        return self._hooks.decorator("before_update")

    def after_update(self) -> Callable[[HookFn[ModelType]], HookFn[ModelType]]:
        """Return a decorator registering a hook run in `save` after the update is flushed and audited."""
        return self._hooks.decorator("after_update")

    def before_delete(self) -> Callable[[HookFn[ModelType]], HookFn[ModelType]]:
        """Return a decorator registering a hook run before a row is soft- or hard-deleted."""
        return self._hooks.decorator("before_delete")

    def after_delete(self) -> Callable[[HookFn[ModelType]], HookFn[ModelType]]:
        """Return a decorator registering a hook run after a delete has been audited."""
        return self._hooks.decorator("after_delete")

    # ---------- Internal helpers ----------
    def _emit(self, event: HookEvent, *, session: Session, instance: ModelType) -> None:
        """Run the hooks for `event` with `session` installed as the current session.

        Accessor calls made from inside a hook therefore reuse this session
        instead of opening their own.
        """
        with database.use_session(session):
            self._hooks.emit(event, session=session, instance=instance)

    def _audit_create_operation(
        self, *, session: Session, instance: ModelType, actor: str
    ) -> None:
        """Stage a "create" audit entry for a freshly inserted `instance`."""
        new_id = int(getattr(instance, "id"))
        audit_operation(
            session=session,
            component=self.component,
            operation="create",
            record_id=new_id,
            actor=actor,
            before=None,
            after=instance.model_dump(),
        )

    def _stage_insert(self, session: Session, obj: ModelType) -> None:
        """Run before_insert hooks for `obj`, then add it to `session` (no flush)."""
        self._emit("before_insert", session=session, instance=obj)
        session.add(obj)

    def _finalize_insert(self, session: Session, obj: ModelType, actor: str) -> None:
        """After a flush: reload `obj`, audit the create, then run after_insert hooks."""
        session.refresh(obj)
        self._audit_create_operation(session=session, instance=obj, actor=actor)
        self._emit("after_insert", session=session, instance=obj)

    def _delete_loaded(
        self, session: Session, obj: ModelType, *, record_id: int, actor: str
    ) -> None:
        """Delete an already-loaded `obj`, wrapped in delete hooks and an audit entry."""
        self._emit("before_delete", session=session, instance=obj)
        operation, before_payload, after_payload = _soft_or_hard_delete(session, obj)
        audit_operation(
            session=session,
            component=self.component,
            operation=operation,
            record_id=record_id,
            actor=actor,
            before=before_payload,
            after=after_payload,
        )
        self._emit("after_delete", session=session, instance=obj)

    def _run_write(self, fn: Callable[[Session], WriteResult]) -> WriteResult:
        """Execute `fn` on the current session, or inside a new transaction if none is active."""
        active = database.CURRENT_SESSION.get()
        if active is None:
            with database.transaction() as session:
                return fn(session)
        return fn(active)

    @contextmanager
    def _read_session(self) -> Iterator[Session]:
        """Yield the current session if one is active, else a short-lived session."""
        active = database.CURRENT_SESSION.get()
        if active is None:
            with Session(database.get_engine(), expire_on_commit=False) as session:
                yield session
        else:
            yield active

    # ---------- Reads ----------
    def get(self, id: int, *, include_deleted: bool = False) -> Optional[ModelType]:
        """Return the row with primary key `id`, or None.

        Soft-deleted rows are treated as missing unless `include_deleted` is True.
        """
        with self._read_session() as session:
            found = session.get(self.model, id)
            if found is None:
                return None
            hidden = (
                self._has_soft_delete
                and not include_deleted
                and getattr(found, "deleted_at") is not None
            )
            return None if hidden else cast(ModelType, found)

    def all(self, *, include_deleted: bool = False) -> List[ModelType]:
        """Return every row, excluding soft-deleted ones unless `include_deleted` is True."""
        with self._read_session() as session:
            query = select(self.model)
            if self._has_soft_delete and not include_deleted:
                query = query.where(getattr(self.model, "deleted_at").is_(None))
            return cast(List[ModelType], session.exec(query).all())

    # ---------- Writes ----------
    def insert(self, obj: ModelType, *, actor: str = "internal") -> ModelType:
        """Insert `obj` as a new row and return it refreshed from the database.

        Emits before_insert/after_insert and records a "create" audit entry.
        """

        def write_operation(session: Session) -> ModelType:
            self._stage_insert(session, obj)
            session.flush()
            self._finalize_insert(session, obj, actor)
            return obj

        return self._run_write(write_operation)

    def save(self, obj: ModelType, *, actor: str = "internal") -> ModelType:
        """Upsert `obj`: update the stored row with the same id, or insert it.

        Objects with no id, or whose id has no stored row, are delegated to
        `insert`. Otherwise `obj` is merged into the session, before_update
        hooks run, the change is flushed and audited as "update" with
        before/after snapshots, and after_update hooks run. Returns the
        session-bound (merged) instance.
        """

        def write_operation(session: Session) -> ModelType:
            obj_id = cast(Optional[int], getattr(obj, "id", None))
            current = None if obj_id is None else session.get(self.model, obj_id)
            if current is None:
                return self.insert(obj, actor=actor)

            # NOTE(review): if `obj` is already attached to this session (e.g.
            # loaded and mutated earlier in the same transaction), `session.get`
            # returns that same instance, so `before` already holds the new
            # values. Confirm whether the snapshot should come from the
            # committed state instead.
            before = current.model_dump()
            merged = session.merge(obj)
            self._emit("before_update", session=session, instance=merged)
            session.flush()
            session.refresh(merged)
            audit_operation(
                session=session,
                component=self.component,
                operation="update",
                record_id=int(getattr(merged, "id")),
                actor=actor,
                before=before,
                after=merged.model_dump(),
            )
            self._emit("after_update", session=session, instance=merged)
            return cast(ModelType, merged)

        return self._run_write(write_operation)

    def delete(self, id: int, *, actor: str = "internal") -> bool:
        """Delete the row with primary key `id`; return False if it does not exist.

        Soft-deletes (stamps `deleted_at`) when the model supports it,
        otherwise removes the row. `session.get` does not filter on
        `deleted_at`, so an already soft-deleted row is stamped and audited
        again.
        """

        def write_operation(session: Session) -> bool:
            target = session.get(self.model, id)
            if target is None:
                return False
            self._delete_loaded(session, target, record_id=id, actor=actor)
            return True

        return self._run_write(write_operation)

    def restore(self, id: int, *, actor: str = "internal") -> bool:
        """Clear `deleted_at` on a soft-deleted row.

        Returns False when the row does not exist or the model has no soft
        delete. Returns True when the row was restored, or was not deleted in
        the first place (no audit entry is written in that case). No hooks are
        emitted.
        """

        def write_operation(session: Session) -> bool:
            target = session.get(self.model, id)
            if target is None or not self._has_soft_delete:
                return False
            if getattr(target, "deleted_at") is not None:
                snapshot = target.model_dump()
                setattr(target, "deleted_at", None)
                session.add(target)
                session.flush()
                session.refresh(target)
                audit_operation(
                    session=session,
                    component=self.component,
                    operation="restore",
                    record_id=id,
                    actor=actor,
                    before=snapshot,
                    after=target.model_dump(),
                )
            return True

        return self._run_write(write_operation)

    # ---------- Batch helpers ----------
    def insert_many(
        self, objs: Sequence[ModelType], *, actor: str = "internal"
    ) -> List[ModelType]:
        """Insert every object in `objs` within one write and return them in order.

        All before_insert hooks run and all objects are added first. A single
        flush follows, then each object is refreshed, audited and passed to
        after_insert hooks.
        """

        def write_operation(session: Session) -> List[ModelType]:
            for pending in objs:
                self._stage_insert(session, pending)
            session.flush()
            inserted: List[ModelType] = []
            for pending in objs:
                self._finalize_insert(session, pending, actor)
                inserted.append(pending)
            return inserted

        return self._run_write(write_operation)

    def delete_many(self, ids: Sequence[int], *, actor: str = "internal") -> int:
        """Delete each id in `ids` within one write; return how many rows were found and deleted."""

        def write_operation(session: Session) -> int:
            deleted = 0
            for record_id in ids:
                target = session.get(self.model, record_id)
                if target is not None:
                    self._delete_loaded(
                        session, target, record_id=record_id, actor=actor
                    )
                    deleted += 1
            return deleted

        return self._run_write(write_operation)
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def _soft_or_hard_delete(
    session: Session, instance: SQLModel
) -> tuple[Operation, dict[str, Any], dict[str, Any] | None]:
    """Soft-delete `instance` when it supports it, otherwise hard-delete it.

    If the instance has a `deleted_at` attribute, it is stamped with `utcnow()`
    and flushed. Otherwise it is handed to `session.delete`, and the actual
    DELETE is issued on the next flush or commit.

    Returns:
        (operation, before, after):
        - operation: the audit operation name, "soft_delete" or "delete".
        - before: the `model_dump()` taken before the change.
        - after: the dump taken after the change, or None for a hard delete.
    """
    snapshot = instance.model_dump()
    if not hasattr(instance, "deleted_at"):
        session.delete(instance)
        return "delete", snapshot, None

    setattr(instance, "deleted_at", utcnow())
    session.add(instance)
    session.flush()
    return "soft_delete", snapshot, instance.model_dump()
|
storage/auditor.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from datetime import datetime, date
|
|
4
|
+
from typing import Any, Dict, Optional, cast
|
|
5
|
+
|
|
6
|
+
from sqlalchemy import Column, String
|
|
7
|
+
from sqlalchemy.dialects.sqlite import JSON
|
|
8
|
+
from sqlmodel import Field, SQLModel, Session
|
|
9
|
+
from storage.utils import utcnow, Operation
|
|
10
|
+
import json
|
|
11
|
+
|
|
12
|
+
# Public API of the auditing module.
__all__ = ["AuditLog", "audit_operation"]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Audit table: one row per audited write, recording which record of which
# component was changed, how, by whom, and the state before/after the change.
# Rows are staged by `audit_operation`.
class AuditLog(SQLModel, table=True):  # type: ignore[misc, call-arg]
    # Primary key; None until the database assigns it on insert.
    id: Optional[int] = Field(default=None, primary_key=True)

    # Queryable identifiers
    # Time the audit entry was created (audit_operation passes utcnow()).
    timestamp: datetime = Field(index=True)
    # Accessor/component label that produced the entry.
    component: str = Field(index=True)
    # Primary key of the audited row in its own table.
    record_id: int = Field(index=True)

    # What happened and by whom
    # Stored as a plain string column; values come from the `Operation` type
    # (the accessor passes e.g. "create", "update", "soft_delete", "delete", "restore").
    operation: Operation = Field(sa_column=Column(String, nullable=False))
    actor: str

    # Snapshot of state change
    # JSON snapshots of the record: `before` is None for creates and `after`
    # is None for hard deletes. NOTE(review): JSON is imported from the SQLite
    # dialect; confirm it behaves as intended on other database backends.
    before: Optional[Dict[str, Any]] = Field(default=None, sa_column=Column(JSON))
    after: Optional[Dict[str, Any]] = Field(default=None, sa_column=Column(JSON))
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _json_default(obj: Any) -> Any:
    """Fallback encoder for ``json.dumps`` covering values common in model dumps.

    Conversions:
    - ``datetime`` / ``date`` / ``time`` -> ISO-8601 string
    - ``uuid.UUID`` -> canonical hyphenated string
    - ``decimal.Decimal`` -> string (keeps exact precision, unlike float)
    - ``enum.Enum`` -> its ``.value``; ``json`` re-encodes the value,
      calling this function again if needed

    Previously only dates were handled, so one UUID/Decimal/Enum field made
    the audit snapshot unserializable and aborted the whole write.

    Raises:
        TypeError: for any other type, mirroring ``json``'s own error.
    """
    # Local imports keep the module's top-level import block unchanged.
    from datetime import time
    from decimal import Decimal
    from enum import Enum
    from uuid import UUID

    if isinstance(obj, (datetime, date, time)):
        return obj.isoformat()
    if isinstance(obj, (UUID, Decimal)):
        return str(obj)
    if isinstance(obj, Enum):
        return obj.value
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _jsonify(value: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Return a JSON-safe copy of `value`, or None when `value` is None.

    The dict is serialized with `_json_default` as the fallback encoder, then
    parsed back. Values such as datetimes therefore come out as plain strings
    a JSON column can store. Tuples become lists in the process.
    """
    if value is not None:
        encoded = json.dumps(value, default=_json_default)
        return cast(Dict[str, Any], json.loads(encoded))
    return None
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def audit_operation(
    *,
    session: Session,
    component: str,
    operation: Operation,
    record_id: int,
    actor: str,
    before: Optional[Dict[str, Any]] = None,
    after: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Record a single audit event. Call this INSIDE the same transaction as the data change.

    The entry is only added to `session`; it is written on the session's next
    flush/commit. It therefore commits or rolls back together with the change
    it describes.

    Args:
        session: Session carrying the data change.
        component: Label of the component that made the change.
        operation: Kind of change (e.g. "create", "update", "delete").
        record_id: Primary key of the affected row.
        actor: Who performed the change.
        before: Snapshot before the change; converted to JSON-safe values.
        after: Snapshot after the change; converted to JSON-safe values.
    """
    entry = AuditLog(
        timestamp=utcnow(),
        component=component,
        record_id=record_id,
        operation=operation,
        actor=actor,
        before=_jsonify(before),
        after=_jsonify(after),
    )
    session.add(entry)
|
storage/bootstrap.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, Iterable, Optional
|
|
4
|
+
|
|
5
|
+
from fastapi import FastAPI
|
|
6
|
+
from sqlmodel import SQLModel
|
|
7
|
+
|
|
8
|
+
from storage.database import get_engine, set_database_url
|
|
9
|
+
from storage.accessor import ModelAccessor
|
|
10
|
+
from storage.router_model import build_crud_router
|
|
11
|
+
from storage.router_database import build_db_router
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def bootstrap(
    app: FastAPI,
    *,
    accessors: Iterable[ModelAccessor[Any]],
    database_url: Optional[str] = None,
    create_tables: bool = True,
    max_records_per_model: Optional[int] = 5,
    enable_db_router: bool = True,
) -> None:
    """
    Wire the storage layer into a FastAPI application.

    Steps, in order:
    1. Optionally override the database URL. This must happen before the
       engine is first created.
    2. Create the engine, and optionally all tables registered on
       `SQLModel.metadata`.
    3. Mount one CRUD router per accessor, passing `max_records_per_model`
       to each.
    4. Mount the global /db router (health, audit, diagram, backup/restore)
       unless `enable_db_router` is False.
    """
    if database_url is not None:
        set_database_url(database_url)

    # The engine is initialized even when tables are not created.
    engine = get_engine()
    if create_tables:
        SQLModel.metadata.create_all(engine)

    for model_accessor in accessors:
        crud_router = build_crud_router(
            model_accessor, max_records=max_records_per_model
        )
        app.include_router(crud_router)

    if not enable_db_router:
        return
    app.include_router(build_db_router())
|
storage/database.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from contextlib import contextmanager
|
|
5
|
+
from contextvars import ContextVar
|
|
6
|
+
from typing import Any, Iterator, Optional
|
|
7
|
+
|
|
8
|
+
from sqlalchemy import event
|
|
9
|
+
from sqlalchemy.engine import Engine
|
|
10
|
+
from sqlmodel import Session, create_engine
|
|
11
|
+
|
|
12
|
+
# Public API of the database module.
__all__ = [
    "set_database_url",
    "get_engine",
    "transaction",
    "use_session",
    "CURRENT_SESSION",
]

# Database URL. It is read from VENTION_STORAGE_DATABASE_URL at import time and
# can be overridden with set_database_url() before the engine exists. It
# defaults to a SQLite file `storage.db` relative to the current working directory.
_DATABASE_URL = os.getenv("VENTION_STORAGE_DATABASE_URL", "sqlite:///./storage.db")
# Lazily created singleton engine; see get_engine().
_ENGINE: Optional[Engine] = None
# Session active in the current context, set by transaction() and use_session().
# Accessors reuse it instead of opening their own session when it is not None.
CURRENT_SESSION: ContextVar[Optional[Session]] = ContextVar(
    "VENTION_STORAGE_CURRENT_SESSION", default=None
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def set_database_url(url: str) -> None:
    """
    Set the URL that get_engine() will use when it creates the engine.

    Args:
        url: SQLAlchemy database URL.

    Raises:
        RuntimeError: if the engine already exists, because the new URL could
            no longer take effect.
    """
    global _DATABASE_URL
    if _ENGINE is None:
        _DATABASE_URL = url
        return
    raise RuntimeError(
        "Database engine already initialized. Call set_database_url() before first use."
    )
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _attach_sqlite_pragmas(engine: Engine) -> None:
    """
    Enable SQLite foreign-key enforcement on every new connection of `engine`.

    In SQLite foreign key enforcement is disabled by default, per connection.
    Without it, insertions of invalid foreign keys will succeed.

    This is a no-op for non-SQLite engines: `PRAGMA` is not valid SQL on
    other backends such as PostgreSQL or MySQL, and would make every new
    connection fail.
    """
    if engine.dialect.name != "sqlite":
        return

    def _set_sqlite_pragmas(dbapi_connection: Any, _record: Any) -> None:
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("PRAGMA foreign_keys=ON")
        finally:
            cursor.close()

    # "connect" fires for each new DBAPI connection the engine's pool opens.
    event.listen(engine, "connect", _set_sqlite_pragmas)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def get_engine() -> Engine:
    """
    Return the singleton SQLAlchemy engine, creating it on first use.

    The engine is built from the configured database URL. SQLite-only
    settings (the `check_same_thread` connect argument and the foreign-key
    pragma) are applied only when the URL is a SQLite URL, because other
    backends reject them.
    """
    global _ENGINE
    if _ENGINE is None:
        is_sqlite = _DATABASE_URL.startswith("sqlite")
        # Python's sqlite3 rejects using a connection from a thread other than
        # the one that created it unless check_same_thread is False.
        connect_args = {"check_same_thread": False} if is_sqlite else {}
        _ENGINE = create_engine(_DATABASE_URL, echo=False, connect_args=connect_args)
        if is_sqlite:
            _attach_sqlite_pragmas(_ENGINE)
    return _ENGINE
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@contextmanager
def transaction() -> Iterator[Session]:
    """
    Yield a Session wrapped in a single atomic transaction.
    Commits on success; rolls back on error.
    Use this when you have multiple all-or-nothing operations.

    Also sets CURRENT_SESSION for the duration, so accessors/hooks can reuse it.

    A new Session is opened on every call, even when CURRENT_SESSION is
    already set. Nested calls therefore do not join an outer transaction.
    """
    engine = get_engine()
    # expire_on_commit=False keeps already-loaded attributes readable after
    # commit, so objects produced inside the block remain usable afterwards.
    with Session(engine, expire_on_commit=False) as session:
        token = CURRENT_SESSION.set(session)
        try:
            # session.begin() commits when the block exits normally and rolls
            # back when it raises.
            with session.begin():
                yield session
        finally:
            # Restore the previous CURRENT_SESSION only after commit/rollback
            # has completed.
            CURRENT_SESSION.reset(token)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@contextmanager
def use_session(session: Optional[Session] = None) -> Iterator[Session]:
    """
    Install a session as CURRENT_SESSION for the duration of the block.

    If `session` is given, it is installed as-is and left open afterwards.
    Otherwise a new Session (expire_on_commit=False) is opened, installed,
    and closed on exit. The previous CURRENT_SESSION is restored on exit,
    including when the block raises.

    This helper lets accessors and hooks run DB work without duplicating
    session-management code. It does not open a transaction by itself;
    for atomic multi-write operations prefer `transaction()`.
    """

    def _bound(active: Session) -> Iterator[Session]:
        # Installs `active`, yields it, and always restores the previous value.
        token = CURRENT_SESSION.set(active)
        try:
            yield active
        finally:
            CURRENT_SESSION.reset(token)

    if session is None:
        with Session(get_engine(), expire_on_commit=False) as owned:
            yield from _bound(owned)
    else:
        yield from _bound(session)
|
storage/hooks.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Callable, DefaultDict, Generic, List, TypeVar
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
|
|
6
|
+
from sqlmodel import Session, SQLModel
|
|
7
|
+
from typing_extensions import Literal
|
|
8
|
+
|
|
9
|
+
# Type variable for the concrete SQLModel class a registry and its hooks work with.
ModelRecord = TypeVar("ModelRecord", bound=SQLModel)

# Names of the lifecycle events hooks can be registered for.
HookEvent = Literal[
    "before_insert",
    "after_insert",
    "before_update",
    "after_update",
    "before_delete",
    "after_delete",
]

# Hook signature: called with the active session and the affected instance.
# Any return value is ignored by HookRegistry.emit.
HookFn = Callable[[Session, ModelRecord], None]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class HookRegistry(Generic[ModelRecord]):
    """Per-accessor mapping from lifecycle events to ordered lists of hook callables."""

    def __init__(self) -> None:
        # A defaultdict lets `decorator` append to an event without initializing it first.
        self._hooks: DefaultDict[HookEvent, List[HookFn[ModelRecord]]] = defaultdict(list)

    def decorator(
        self, event: HookEvent
    ) -> Callable[[HookFn[ModelRecord]], HookFn[ModelRecord]]:
        """Build a decorator that registers the decorated function for `event`.

        The function is returned unchanged, so it stays directly callable.
        """

        def register(hook: HookFn[ModelRecord]) -> HookFn[ModelRecord]:
            self._hooks[event].append(hook)
            return hook

        return register

    def emit(
        self, event: HookEvent, *, session: Session, instance: ModelRecord
    ) -> None:
        """Call every hook registered for `event`, in registration order.

        An exception raised by a hook propagates to the caller and skips the
        remaining hooks.
        """
        # `.get` (unlike indexing) does not create an empty entry for unused events.
        registered = self._hooks.get(event)
        if not registered:
            return
        for hook in registered:
            hook(session, instance)
|