dbdm 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dbdm-0.1.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Roma Koshel
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
dbdm-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,191 @@
1
+ Metadata-Version: 2.1
2
+ Name: dbdm
3
+ Version: 0.1.0
4
+ Summary:
5
+ License: MIT
6
+ Author: Roman Koshel
7
+ Author-email: roma.koshel@gmail.com
8
+ Requires-Python: >=3.11,<4.0
9
+ Classifier: Development Status :: 3 - Alpha
10
+ Classifier: License :: OSI Approved :: MIT License
11
+ Classifier: Programming Language :: Python
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Classifier: Programming Language :: Python :: 3.13
16
+ Classifier: Programming Language :: Python :: 3 :: Only
17
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
18
+ Classifier: Topic :: Utilities
19
+ Requires-Dist: asyncpg
20
+ Requires-Dist: sqlalchemy[asyncio] (>=2,<3)
21
+ Description-Content-Type: text/markdown
22
+
23
+ # DBDM [wip]
24
+
25
+
26
+ ## Examples
27
+
28
+ ```python
29
+ from typing import ClassVar, Type
30
+
31
+ import sqlalchemy as sa
32
+
33
+ from cwtch import dataclass, resolve_types, view
34
+ from cwtch.types import UNSET, Unset
35
+ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
36
+ from sqlalchemy.ext.asyncio import create_async_engine
37
+
38
+ from dbdm import DM, NotFoundError, bind_engine
39
+
40
+
41
+ class BaseDB(DeclarativeBase):
42
+ pass
43
+
44
+
45
+ class ParentDB(BaseDB):
46
+ __tablename__ = "parents"
47
+
48
+ id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
49
+ name: Mapped[str]
50
+ data: Mapped[str]
51
+ children = relationship("ChildDB", uselist=True, viewonly=True)
52
+
53
+
54
+ class ChildDB(BaseDB):
55
+ __tablename__ = "children"
56
+
57
+ id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
58
+ name: Mapped[str]
59
+ parent_id: Mapped[int] = mapped_column(sa.ForeignKey("parents.id"))
60
+ parent = relationship("ParentDB", uselist=False, viewonly=True)
61
+
62
+
63
+ @dataclass(handle_circular_refs=True)
64
+ class Parent:
65
+ id: int
66
+ name: str
67
+ data: str
68
+ children: Unset[list["Child"]] = UNSET
69
+
70
+ # Parent views
71
+ Create: ClassVar[Type["ParentCreate"]]
72
+ Save: ClassVar[Type["ParentSave"]]
73
+ Update: ClassVar[Type["ParentUpdate"]]
74
+
75
+
76
+ @view(Parent, "Create", exclude=["id", "children"])
77
+ class ParentCreate:
78
+ pass
79
+
80
+
81
+ @view(Parent, "Save", exclude=["children"])
82
+ class ParentSave:
83
+ pass
84
+
85
+
86
+ @view(Parent, "Update", exclude=["children"])
87
+ class ParentUpdate:
88
+ name: Unset[str] = UNSET
89
+ data: Unset[str] = UNSET
90
+
91
+
92
+ @dataclass(handle_circular_refs=True)
93
+ class Child:
94
+ id: int
95
+ name: str
96
+ parent_id: int
97
+ parent: Unset[Parent] = UNSET
98
+
99
+ # Child views
100
+ Create: ClassVar[Type["ChildCreate"]]
101
+ Save: ClassVar[Type["ChildSave"]]
102
+ Update: ClassVar[Type["ChildUpdate"]]
103
+
104
+
105
+ @view(Child, "Create", exclude=["id", "parent"])
106
+ class ChildCreate:
107
+ pass
108
+
109
+
110
+ @view(Child, "Save", exclude=["parent"])
111
+ class ChildSave:
112
+ pass
113
+
114
+
115
+ @view(Child, "Update", exclude=["parent"])
116
+ class ChildUpdate:
117
+ name: Unset[str] = UNSET
118
+ parent_id: Unset[int] = UNSET
119
+
120
+
121
+ resolve_types(Parent, globals(), locals())
122
+
123
+
124
+ class ParentDM(DM):
125
+ model_db = ParentDB
126
+ model = Parent
127
+ model_create = Parent.Create
128
+ model_save = Parent.Save
129
+ model_update = Parent.Update
130
+ key = "id"
131
+ index_elements = ["id"]
132
+ joinedload = {"children": lambda m: sa.orm.joinedload(m.children)}
133
+
134
+
135
+ class ChildDM(DM):
136
+ model_db = ChildDB
137
+ model = Child
138
+ model_create = Child.Create
139
+ model_save = Child.Save
140
+ model_update = Child.Update
141
+ key = "id"
142
+ index_elements = ["id"]
143
+ joinedload = {"parent": lambda m: sa.orm.joinedload(m.parent)}
144
+
145
+
146
+ @pytest_asyncio.fixture
147
+ async def create_all(engine):
148
+ async with engine.begin() as conn:
149
+ await conn.run_sync(BaseDB.metadata.create_all)
150
+
151
+
152
+ async def example():
153
+ engine = create_async_engine(...)
154
+
155
+ async with engine.begin() as conn:
156
+ await conn.run_sync(BaseDB.metadata.create_all)
157
+
158
+ bind_engine(engine)
159
+
160
+ parent = await ParentDM.create(Parent.Create(name="Parent_1", data="data"))
161
+
162
+ # parents: list[Parent]
163
+ total, parents = await ParentDM.get_many()
164
+
165
+ # parents: list[Parent]
166
+ total, parents = await ParentDM.get_many(page_size=1)
167
+
168
+ # parents: list[Parent]
169
+ total, parents = await ParentDM.get_many(page_size=1, page=2)
170
+
171
+ # parent: Parent
172
+ parent = await ParentDM.get(1)
173
+
174
+ # parent: Parent
175
+ parent = await ParentDM.get(1, joinedload={"children": True})
176
+
177
+ # parent: Parent
178
+ parent = await ParentDM.save(parent.Save())
179
+
180
+ # parent: Parent
181
+ parent = await ParentDM.update(Parent.Update(id=1, data="new data"), key="id")
182
+
183
+ await ParentDM.delete(1)
184
+
185
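+ # child: Child (parent 1 was deleted above, so create a fresh parent first)
164
+ parent = await ParentDM.create(Parent.Create(name="Parent_2", data="data"))
165
+ child = await ChildDM.create(Child.Create(name="Child_1", parent_id=parent.id))
166
+ # parent: Parent
167
+ parent = await ParentDM.get(parent.id, joinedload={"children": True})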
190
+ ```
191
+
dbdm-0.1.0/README.md ADDED
@@ -0,0 +1,168 @@
1
+ # DBDM [wip]
2
+
3
+
4
+ ## Examples
5
+
6
+ ```python
7
+ from typing import ClassVar, Type
8
+
9
+ import sqlalchemy as sa
10
+
11
+ from cwtch import dataclass, resolve_types, view
12
+ from cwtch.types import UNSET, Unset
13
+ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
14
+ from sqlalchemy.ext.asyncio import create_async_engine
15
+
16
+ from dbdm import DM, NotFoundError, bind_engine
17
+
18
+
19
+ class BaseDB(DeclarativeBase):
20
+ pass
21
+
22
+
23
+ class ParentDB(BaseDB):
24
+ __tablename__ = "parents"
25
+
26
+ id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
27
+ name: Mapped[str]
28
+ data: Mapped[str]
29
+ children = relationship("ChildDB", uselist=True, viewonly=True)
30
+
31
+
32
+ class ChildDB(BaseDB):
33
+ __tablename__ = "children"
34
+
35
+ id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
36
+ name: Mapped[str]
37
+ parent_id: Mapped[int] = mapped_column(sa.ForeignKey("parents.id"))
38
+ parent = relationship("ParentDB", uselist=False, viewonly=True)
39
+
40
+
41
+ @dataclass(handle_circular_refs=True)
42
+ class Parent:
43
+ id: int
44
+ name: str
45
+ data: str
46
+ children: Unset[list["Child"]] = UNSET
47
+
48
+ # Parent views
49
+ Create: ClassVar[Type["ParentCreate"]]
50
+ Save: ClassVar[Type["ParentSave"]]
51
+ Update: ClassVar[Type["ParentUpdate"]]
52
+
53
+
54
+ @view(Parent, "Create", exclude=["id", "children"])
55
+ class ParentCreate:
56
+ pass
57
+
58
+
59
+ @view(Parent, "Save", exclude=["children"])
60
+ class ParentSave:
61
+ pass
62
+
63
+
64
+ @view(Parent, "Update", exclude=["children"])
65
+ class ParentUpdate:
66
+ name: Unset[str] = UNSET
67
+ data: Unset[str] = UNSET
68
+
69
+
70
+ @dataclass(handle_circular_refs=True)
71
+ class Child:
72
+ id: int
73
+ name: str
74
+ parent_id: int
75
+ parent: Unset[Parent] = UNSET
76
+
77
+ # Child views
78
+ Create: ClassVar[Type["ChildCreate"]]
79
+ Save: ClassVar[Type["ChildSave"]]
80
+ Update: ClassVar[Type["ChildUpdate"]]
81
+
82
+
83
+ @view(Child, "Create", exclude=["id", "parent"])
84
+ class ChildCreate:
85
+ pass
86
+
87
+
88
+ @view(Child, "Save", exclude=["parent"])
89
+ class ChildSave:
90
+ pass
91
+
92
+
93
+ @view(Child, "Update", exclude=["parent"])
94
+ class ChildUpdate:
95
+ name: Unset[str] = UNSET
96
+ parent_id: Unset[int] = UNSET
97
+
98
+
99
+ resolve_types(Parent, globals(), locals())
100
+
101
+
102
+ class ParentDM(DM):
103
+ model_db = ParentDB
104
+ model = Parent
105
+ model_create = Parent.Create
106
+ model_save = Parent.Save
107
+ model_update = Parent.Update
108
+ key = "id"
109
+ index_elements = ["id"]
110
+ joinedload = {"children": lambda m: sa.orm.joinedload(m.children)}
111
+
112
+
113
+ class ChildDM(DM):
114
+ model_db = ChildDB
115
+ model = Child
116
+ model_create = Child.Create
117
+ model_save = Child.Save
118
+ model_update = Child.Update
119
+ key = "id"
120
+ index_elements = ["id"]
121
+ joinedload = {"parent": lambda m: sa.orm.joinedload(m.parent)}
122
+
123
+
124
+ @pytest_asyncio.fixture
125
+ async def create_all(engine):
126
+ async with engine.begin() as conn:
127
+ await conn.run_sync(BaseDB.metadata.create_all)
128
+
129
+
130
+ async def example():
131
+ engine = create_async_engine(...)
132
+
133
+ async with engine.begin() as conn:
134
+ await conn.run_sync(BaseDB.metadata.create_all)
135
+
136
+ bind_engine(engine)
137
+
138
+ parent = await ParentDM.create(Parent.Create(name="Parent_1", data="data"))
139
+
140
+ # parents: list[Parent]
141
+ total, parents = await ParentDM.get_many()
142
+
143
+ # parents: list[Parent]
144
+ total, parents = await ParentDM.get_many(page_size=1)
145
+
146
+ # parents: list[Parent]
147
+ total, parents = await ParentDM.get_many(page_size=1, page=2)
148
+
149
+ # parent: Parent
150
+ parent = await ParentDM.get(1)
151
+
152
+ # parent: Parent
153
+ parent = await ParentDM.get(1, joinedload={"children": True})
154
+
155
+ # parent: Parent
156
+ parent = await ParentDM.save(parent.Save())
157
+
158
+ # parent: Parent
159
+ parent = await ParentDM.update(Parent.Update(id=1, data="new data"), key="id")
160
+
161
+ await ParentDM.delete(1)
162
+
163
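+ # child: Child (parent 1 was deleted above, so create a fresh parent first)
164
+ parent = await ParentDM.create(Parent.Create(name="Parent_2", data="data"))
165
+ child = await ChildDM.create(Child.Create(name="Child_1", parent_id=parent.id))
166
+ # parent: Parent
167
+ parent = await ParentDM.get(parent.id, joinedload={"children": True})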
168
+ ```
@@ -0,0 +1,7 @@
1
+ import importlib.metadata
2
+
3
+
4
# Version string of the installed "dbdm" distribution, read from its package metadata.
__version__ = importlib.metadata.version(distribution_name="dbdm")
5
+
6
+
7
+ from .common import * # noqa: F403
@@ -0,0 +1,736 @@
1
+ import re
2
+ import typing
3
+
4
+ from asyncio import gather
5
+ from contextlib import asynccontextmanager
6
+ from contextvars import ContextVar
7
+ from dataclasses import dataclass
8
+ from typing import Any, AsyncIterator, Callable, Literal, Never, Optional, Protocol, Type, TypeVar
9
+
10
+ import sqlalchemy as sa
11
+
12
+ from sqlalchemy import delete, func, literal, select, union_all, update
13
+ from sqlalchemy.dialects.postgresql import insert
14
+ from sqlalchemy.ext.asyncio import async_sessionmaker
15
+ from sqlalchemy.orm import DeclarativeBase, aliased
16
+
17
+
18
+ __all__ = [
19
+ "DMError",
20
+ "AlreadyExistsError",
21
+ "BadParamsError",
22
+ "NotFoundError",
23
+ "OrderBy",
24
+ "bind_engine",
25
+ "transaction",
26
+ ]
27
+
28
+
29
# Registries keyed by engine name (None = default engine), both filled by bind_engine().
# _conn holds, per engine, a ContextVar with the connection of the currently open
# transaction() (None outside of one); _engine holds the bound AsyncEngine.
_conn: dict[Optional[str], ContextVar] = dict()
_engine: dict[Optional[str], sa.ext.asyncio.AsyncEngine] = dict()
31
+
32
+
33
class DMError(Exception):
    """Base class of every error raised by dbdm; catch it to handle any of them."""
35
+
36
+
37
class BadParamsError(DMError):
    """Raised for an invalid caller-supplied parameter (e.g. an unknown ``order_by`` name).

    Attributes:
        message: human-readable description; also what ``str()`` and ``repr()`` return.
        param: name of the offending parameter, when known.
    """

    def __init__(self, message: str, param: Optional[str] = None):
        self.message, self.param = message, param

    def __str__(self):
        return self.message

    def __repr__(self):
        return str(self)
47
+
48
+
49
class NotFoundError(DMError):
    """Raised when no row matches ``key == value``.

    Attributes:
        key: the lookup key (attribute name or column expression).
        value: the value that was looked up.
    """

    def __init__(self, key, value):
        # Forward to Exception so ``args`` is populated even when the error is raised
        # with keyword arguments (as the DM classes do). Without it ``args`` is ``()``,
        # ``repr()`` carries no information and unpickling the exception fails because
        # it is re-created as ``NotFoundError()``.
        super().__init__(key, value)
        self.key = key
        self.value = value

    def __str__(self):
        return f"item with ({self.key})=({self.value}) not found"
56
+
57
+
58
class AlreadyExistsError(DMError):
    """Raised when a write violates a unique constraint on ``key`` with ``value``.

    Attributes:
        key: column list reported by PostgreSQL (e.g. ``"id"`` or ``"a, b"``).
        value: the conflicting value(s) as reported by PostgreSQL.
    """

    def __init__(self, key, value):
        # Populate ``args`` so repr() is informative and the exception round-trips
        # through pickle regardless of whether it was raised positionally or by keyword.
        super().__init__(key, value)
        self.key = key
        self.value = value

    def __str__(self):
        return f"key ({self.key})=({self.value}) already exists"
65
+
66
+
67
@dataclass
class OrderBy:
    """A single ordering term for ``get_many``.

    Attributes:
        by: attribute/column name (resolved against the queried model or CTE columns)
            or a column expression.
        order: ``"asc"`` or ``"desc"``; names the SQLAlchemy function applied to ``by``.
    """

    by: Any
    order: Literal["asc", "desc"] = "asc"
71
+
72
+
73
def bind_engine(engine: sa.ext.asyncio.AsyncEngine, name: Optional[str] = None):
    """Register ``engine`` under ``name`` (``None`` = default engine) for DM operations.

    A fresh ContextVar is created for ``name``; transaction() uses it to share one
    connection with every DM call made inside it.

    Raises:
        DMError: the engine's dialect is not PostgreSQL.
    """
    if engine.dialect.name == "postgresql":
        _engine[name] = engine
        _conn[name] = ContextVar("conn", default=None)
        return
    raise DMError("only 'postgresql' dialect is supported")
78
+
79
+
80
@asynccontextmanager
async def transaction(engine_name: Optional[str] = None) -> AsyncIterator[sa.ext.asyncio.AsyncConnection]:
    """Run the enclosed DM calls on one connection inside one database transaction.

    If a transaction() for ``engine_name`` is already active in the current context, its
    connection is reused (nesting joins the outer transaction). Otherwise a connection is
    opened, a transaction is begun, and the connection is published through the engine's
    ContextVar until the block exits.
    """
    existing = _conn[engine_name].get()
    if existing is not None:
        yield existing
        return
    async with _engine[engine_name].connect() as conn:
        async with conn.begin():
            _conn[engine_name].set(conn)
            try:
                yield conn
            finally:
                # Unpublish so later calls in this context open their own connection.
                _conn[engine_name].set(None)
92
+
93
+
94
@asynccontextmanager
async def get_conn(engine_name: Optional[str] = None) -> AsyncIterator[sa.ext.asyncio.AsyncConnection]:
    """Yield the connection of the active transaction(), or a private one.

    A private connection runs inside its own ``conn.begin()`` block and is not published
    to the ContextVar, so it is not shared with other calls.
    """
    active = _conn[engine_name].get()
    if active is not None:
        yield active
        return
    async with _engine[engine_name].connect() as conn, conn.begin():
        yield conn
102
+
103
+
104
@asynccontextmanager
async def get_sess(engine_name: Optional[str] = None):
    """Yield an ORM AsyncSession bound to the connection provided by get_conn().

    The session uses ``expire_on_commit=False`` and runs inside ``sess.begin()``.
    """
    async with get_conn(engine_name=engine_name) as conn:
        session_factory = async_sessionmaker(conn, expire_on_commit=False)
        async with session_factory() as sess, sess.begin():
            yield sess
111
+
112
+
113
def _raise_exc(e: Exception) -> Never:
    """Re-raise a database error, translating PostgreSQL unique-key violations.

    An ``IntegrityError`` whose driver message has a line
    ``DETAIL:  Key (<cols>)=(<values>) already exists.`` becomes
    :class:`AlreadyExistsError` (chained to ``e``). Anything else is re-raised unchanged.

    Raises:
        AlreadyExistsError: for a recognised unique violation.
        Exception: ``e`` itself in every other case.
    """
    if isinstance(e, sa.exc.IntegrityError):
        # Guard the message lookup: a driver error without a string first argument must not
        # replace the original IntegrityError with an IndexError/TypeError.
        orig_args = getattr(e.orig, "args", None) or ()
        message = orig_args[0] if orig_args and isinstance(orig_args[0], str) else ""
        # The DETAIL line may be preceded and followed by other lines, so search for it
        # line-by-line instead of requiring it to be exactly the second and last line.
        detail_match = re.search(r"^DETAIL:\s*(?P<text>.*)$", message, re.MULTILINE)
        if detail_match:
            text = detail_match.group("text").strip()
            m = re.match(r"Key \((?P<key>.*)\)=\((?P<key_value>.*)\) already exists\.", text)
            if m:
                key = m.group("key")
                key_value = m.group("key_value").strip('\\"')
                raise AlreadyExistsError(key, key_value) from e
    raise e
124
+
125
+
126
# Type variable for SQLAlchemy declarative model instances (used in the _from_db signatures).
ModelDB = typing.TypeVar("ModelDB", bound=DeclarativeBase)
127
+
128
+
129
class _BaseProtocol(Protocol):
    """Structural interface shared by the DM mixins (_GetDM, _CreateDM, ...).

    It declares the configuration attributes and hooks that the mixins call on ``cls``;
    _GetDM and the model-library specific subclasses provide the implementations.
    """

    engine_name: Optional[str]

    model_db: Type[DeclarativeBase]
    model: Type

    @classmethod
    async def _execute(cls, query) -> sa.ResultProxy:
        """Execute ``query`` and return its result."""

    @classmethod
    def _make_from_db_data(
        cls,
        db_item,
        row: Optional[tuple] = None,
        joinedload: Optional[dict] = None,
    ) -> dict:
        """Return extra output data derived from a database item/row."""

    @classmethod
    def _from_db(
        cls,
        item: sa.Row | ModelDB,
        data: Optional[dict] = None,
        exclude: Optional[list] = None,
        suffix: Optional[str] = None,
        model_out=None,
    ):
        """Convert a database row or entity into an output model instance."""

    @classmethod
    def _get_key(cls, key=None):
        """Resolve the lookup key (defaulting to ``cls.key``)."""

    @classmethod
    async def get(
        cls,
        key_value,
        key=None,
        raise_not_found: Optional[bool] = None,
        joinedload: Optional[dict] = None,
        model_out: Optional[Type] = None,
        **kwds,
    ) -> Optional[type]:
        """Fetch a single item by key."""

    @classmethod
    async def get_many(
        cls,
        flt: Optional[sa.sql.elements.BinaryExpression] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        order_by: Optional[list[OrderBy | str] | OrderBy | str] = None,
        joinedload: Optional[dict] = None,
        model_out: Optional[Type] = None,
        **kwds,
    ) -> tuple[int, list]:
        """Fetch ``(total, items)`` for a filtered, ordered, optionally paginated query."""
181
+
182
+
183
+ class _Meta(typing._ProtocolMeta):
184
+ def __new__(cls, name, bases, ns, skip_checks: bool = False):
185
+ if not skip_checks:
186
+
187
+ def get_all_bases(bases: tuple) -> set:
188
+ result = set()
189
+ for base in bases:
190
+ result.add(base)
191
+ result |= get_all_bases(getattr(base, "__bases__", ()))
192
+ return result
193
+
194
+ all_bases = get_all_bases(bases)
195
+
196
+ if _GetDM not in all_bases:
197
+ raise TypeError("Any DM class should be subclassed from GetDM class")
198
+
199
+ if (model_db := ns.get("model_db")) is None:
200
+ raise ValueError("'model_db' field is required")
201
+
202
+ if (model := ns.get("model")) is None:
203
+ raise ValueError("'model' field is required")
204
+
205
+ if (
206
+ getattr(model, "__dataclass_fields__", None) is None
207
+ and getattr(model, "__pydantic_fields__", None) is None
208
+ ):
209
+ raise ValueError("'model' is not a valid model")
210
+
211
+ if _CreateDM in all_bases or _SaveDM in all_bases:
212
+ if (model_create := ns.get("model_create")) is None:
213
+ raise ValueError("'model_create' field is required")
214
+ if (
215
+ getattr(model_create, "__dataclass_fields__", None) is None
216
+ and getattr(model_create, "__pydantic_fields__", None) is None
217
+ ):
218
+ raise ValueError("'model_create' is not a valid model")
219
+ ns["_fields_create"] = set(model_db.__table__.columns.keys()) & set(
220
+ getattr(model_create, "__dataclass_fields__", None)
221
+ or getattr(model_create, "__pydantic_fields__", None)
222
+ or set()
223
+ )
224
+ if not ns["_fields_create"]:
225
+ raise DMError("model_create does not contain valid fields")
226
+
227
+ if _SaveDM in all_bases:
228
+ if (model_save := ns.get("model_save")) is None:
229
+ raise ValueError("'model_save' field is required")
230
+ if (
231
+ getattr(model_save, "__dataclass_fields__", None) is None
232
+ and getattr(model_save, "__pydantic_fields__", None) is None
233
+ ):
234
+ raise ValueError("'model_save' is not a valid model")
235
+ ns["_fields_save"] = set(model_db.__table__.columns.keys()) & set(
236
+ getattr(model_save, "__dataclass_fields__", None)
237
+ or getattr(model_save, "__pydantic_fields__", None)
238
+ or set()
239
+ )
240
+ if not ns["_fields_save"]:
241
+ raise DMError("model_save does not contain valid fields")
242
+
243
+ if _UpdateDM in all_bases:
244
+ if (model_update := ns.get("model_update")) is None:
245
+ raise ValueError("'model_update' field is required")
246
+ if (
247
+ getattr(model_update, "__dataclass_fields__", None) is None
248
+ and getattr(model_update, "__pydantic_fields__", None) is None
249
+ ):
250
+ raise ValueError("'model_update' is not a valid model")
251
+ ns["_fields_update"] = set(model_db.__table__.columns.keys()) & set(
252
+ getattr(model_update, "__dataclass_fields__", None)
253
+ or getattr(model_update, "__pydantic_fields__", None)
254
+ or set()
255
+ )
256
+ if not ns["_fields_update"]:
257
+ raise DMError("model_update does not contain valid fields")
258
+
259
+ return super().__new__(cls, name, bases, ns)
260
+
261
+
262
class _GetDM(_BaseProtocol, metaclass=_Meta, skip_checks=True):
    """Read operations (``get`` / ``get_many``) shared by every DM class.

    Class-level configuration:
        engine_name: name given to ``bind_engine`` (``None`` = default engine).
        model_db: SQLAlchemy declarative model that is queried.
        model: default output model class.
        key: default lookup attribute (name or column) used by ``get``.
        order_by: default ordering for ``get_many``; an OrderBy/str/column, a list of them,
            or a callable receiving the column namespace. Falls back to ``key``.
        joinedload: relationship name -> callable building a loader option for an entity.
    """

    engine_name: Optional[str] = None

    model_db: Type[DeclarativeBase] = typing.cast(Type[DeclarativeBase], None)
    model: Type = typing.cast(Type, None)

    key = None
    order_by: Optional[list[OrderBy | str | Callable] | OrderBy | str | Callable] = None

    # Only read, never mutated, so sharing the class-level dict is safe.
    joinedload: dict = {}

    @classmethod
    async def _execute(cls, query) -> sa.ResultProxy:
        """Execute ``query`` via get_conn(), translating errors with _raise_exc()."""
        async with get_conn(engine_name=cls.engine_name) as conn:
            try:
                return await conn.execute(query)
            except Exception as e:
                _raise_exc(e)

    @classmethod
    def _make_from_db_data(
        cls,
        db_item,
        row: Optional[tuple] = None,
        joinedload: Optional[dict] = None,
    ) -> dict:
        """Hook returning extra output data for a database item; the default adds nothing."""
        return {}

    @classmethod
    def _from_db(
        cls,
        item: sa.Row | ModelDB,
        data: Optional[dict] = None,
        exclude: Optional[list] = None,
        suffix: Optional[str] = None,
        model_out=None,
    ):
        """Convert a row/entity into an output model; implemented by library subclasses."""

    @classmethod
    def _get_key(cls, key=None):
        """Resolve ``key`` (default ``cls.key``); a string is looked up on ``model_db``.

        Raises:
            ValueError: neither ``key`` nor ``cls.key`` is set.
        """
        if key is None:
            key = cls.key
        if key is None:
            raise ValueError("key is None")
        if isinstance(key, str):
            key = getattr(cls.model_db, key)
        return key

    @classmethod
    def _get_order_by(cls, c, order_by: Optional[list[OrderBy | str] | OrderBy | str] = None) -> list:
        """Translate an ordering spec into SQLAlchemy ``asc``/``desc`` clauses.

        Args:
            c: namespace that string names are resolved against; ``cls.model_db`` for a
                plain query, or ``cte.c`` when ordering rows selected from a CTE.
            order_by: explicit spec; when None the class-level ``order_by`` (called with
                ``c`` if callable) is used, falling back to ``cls.key``.

        Raises:
            BadParamsError: a string entry names no attribute of ``c``, or an entry's
                ``order`` is not ``"asc"``/``"desc"``.
            ValueError: nothing is configured and ``cls.key`` is None.
        """
        _order_by = order_by
        if _order_by is None:
            if callable(cls.order_by):
                _order_by = cls.order_by(c)
            elif cls.order_by:
                _order_by = cls.order_by
            elif isinstance(cls.key, str):
                # Keep the key as a *name* so it is resolved against ``c`` below. Going
                # through _get_key() yields the model_db attribute, which references the
                # base table; when ``c`` is a CTE's columns that pulls the table into the
                # FROM clause (a cartesian product) instead of ordering the CTE rows.
                _order_by = cls.key
            else:
                _order_by = cls._get_key()
        if _order_by is None:
            # A callable order_by may decline to order; emit no clauses instead of
            # failing later on ``*None``.
            return []
        entries = list(_order_by) if isinstance(_order_by, list) else [_order_by]
        clauses = []
        for entry in entries:
            # Copy OrderBy entries so resolving ``by`` never mutates the caller's object.
            x = OrderBy(by=entry.by, order=entry.order) if isinstance(entry, OrderBy) else OrderBy(by=entry)
            if isinstance(x.by, str):
                if getattr(c, x.by, None) is None:
                    raise BadParamsError(f"invalid order_by '{x.by}'", param=f"{x.by}")
                x.by = getattr(c, x.by)
            # ``order`` selects an attribute of the sqlalchemy module, so restrict it to the
            # two intended functions.
            if x.order not in ("asc", "desc"):
                raise BadParamsError(f"invalid order '{x.order}'", param="order")
            clauses.append(getattr(sa, x.order)(x.by))
        return clauses

    @classmethod
    def _make_get_query(cls, key, key_value, **kwds) -> sa.sql.Select:
        """Build ``SELECT model_db WHERE key == key_value``; ``kwds`` is for subclasses."""
        return select(cls.model_db).where(key == key_value)

    @classmethod
    async def get(
        cls,
        key_value,
        key=None,
        raise_not_found: Optional[bool] = None,
        joinedload: Optional[dict] = None,
        model_out: Optional[Type] = None,
        **kwds,
    ) -> Optional[type]:
        """Fetch one item where ``key == key_value``.

        Args:
            key_value: value looked up.
            key: attribute name or column to match (default ``cls.key``).
            raise_not_found: raise NotFoundError instead of returning None.
            joinedload: ``{name: True}`` for relationships from ``cls.joinedload`` to
                eager-load; relationships not requested are excluded from the output.
            model_out: output model class (default ``cls.model``).
            **kwds: forwarded to ``_make_get_query``.

        Raises:
            NotFoundError: no match and ``raise_not_found`` is truthy.
        """
        key = cls._get_key(key=key)
        query = cls._make_get_query(key, key_value, **kwds)
        model_out = model_out or cls.model

        if joinedload is None or not any(joinedload.values()):
            item = (await cls._execute(query)).one_or_none()
            if item:
                return cls._from_db(item, data=cls._make_from_db_data(item), model_out=model_out)
            if raise_not_found:
                raise NotFoundError(key=key, value=key_value)
            return None

        exclude = []
        for k, v in cls.joinedload.items():
            if joinedload.get(k) is True:
                query = query.options(v(cls.model_db))
            else:
                exclude.append(k)

        async with get_sess(engine_name=cls.engine_name) as sess:
            item = (await sess.execute(query)).unique().scalars().one_or_none()
            if item:
                return cls._from_db(
                    item,
                    exclude=exclude,
                    data=cls._make_from_db_data(item, joinedload=joinedload),
                    model_out=model_out,
                )
            if raise_not_found:
                raise NotFoundError(key=key, value=key_value)
        return None

    @classmethod
    def _make_get_many_query(
        cls,
        flt: Optional[sa.sql.elements.BinaryExpression] = None,
        order_by: Optional[list[OrderBy | str] | OrderBy | str] = None,
        **kwds,
    ) -> sa.sql.Select:
        """Build the ordered, filtered base query with a ``rows_total`` window count."""
        query = select(func.count(literal("*")).over().label("rows_total"), cls.model_db)
        if flt is not None:
            query = query.where(flt)
        query = query.order_by(*cls._get_order_by(cls.model_db, order_by))
        return query

    @classmethod
    async def get_many(
        cls,
        flt: Optional[sa.sql.elements.BinaryExpression] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        order_by: Optional[list[OrderBy | str] | OrderBy | str] = None,
        joinedload: Optional[dict] = None,
        model_out: Optional[Type] = None,
        **kwds,
    ) -> tuple[int, list]:
        """Return ``(total, items)`` for a filtered, ordered, optionally paginated query.

        Args:
            flt: optional WHERE expression.
            page: 1-based page number, used only with ``page_size`` (None/0 mean page 1).
            page_size: rows per page; falsy means no pagination.
            order_by: ordering spec (see ``_get_order_by``).
            joinedload: ``{name: True}`` relationships to eager-load (see ``get``).
            model_out: output model class (default ``cls.model``).
            **kwds: forwarded to ``_make_get_many_query``.

        Raises:
            BadParamsError: negative ``page_size``, ``page`` < 1, or an invalid ordering.
        """
        if page_size is not None and page_size < 0:
            raise BadParamsError(f"invalid page_size '{page_size}'", param="page_size")
        if page_size:
            page = page or 1
            if page < 1:
                raise BadParamsError(f"invalid page '{page}'", param="page")

        model_db = cls.model_db
        model_out = model_out or cls.model

        query = cls._make_get_many_query(flt=flt, order_by=order_by, **kwds)

        cte = query.cte("cte")

        # Page rows are marked i=1. One extra row marked i=0 is always added so the total
        # (rows_total) is known even when the requested page is past the end.
        query = select(literal(1).label("i"), cte)

        if page_size:
            query = query.limit(page_size).offset((page - 1) * page_size)

        query = union_all(select(literal(0).label("i"), cte).limit(1), query)

        from_db = cls._from_db
        make_from_orm_data = cls._make_from_db_data

        if joinedload is None or not any(joinedload.values()):
            rows = (await cls._execute(query)).all()
            total = rows[0].rows_total if rows else 0
            # UNION ALL has no ordering guarantee, so pick page rows by their marker
            # (column 0) instead of assuming the i=0 row comes first.
            return total, [
                from_db(row, data=make_from_orm_data(row), model_out=model_out) for row in rows if row[0] == 1
            ]

        main_cte = query.cte("main_cte")

        model_db_alias = aliased(model_db, main_cte)
        query = select(main_cte, model_db_alias)

        exclude = []
        for k, v in cls.joinedload.items():
            if joinedload.get(k) is True:
                query = query.options(v(model_db_alias))
            else:
                exclude.append(k)

        query = query.order_by(sa.asc(main_cte.c.i), *cls._get_order_by(main_cte.c, order_by))

        def _hash(row):
            # De-duplicate joined-eager-load rows on (marker, total, entity).
            return hash((row[0], row[1], row[-1]))

        async with get_sess(engine_name=cls.engine_name) as sess:
            rows = (await sess.execute(query)).unique(_hash).all()
            return rows[0].rows_total if rows else 0, [
                from_db(
                    row[-1],
                    exclude=exclude,
                    data=make_from_orm_data(row[-1], row=typing.cast(tuple, row), joinedload=joinedload),
                    model_out=model_out,
                )
                for row in rows
                if row[0] == 1
            ]
456
+
457
+
458
class _CreateDM(_BaseProtocol, metaclass=_Meta, skip_checks=True):
    """Insert operations: ``create``, ``create_many`` and ``get_or_create``.

    ``_fields_create`` is filled by _Meta with the ``model_create`` fields that are also
    table columns. ``_get_values_for_create_query`` is implemented by the model-library
    specific subclasses.
    """

    model_create: Type = typing.cast(Type, None)

    index_elements: Optional[list] = None

    _fields_create: set[str] = set()

    @classmethod
    def _get_values_for_create_query(cls, model) -> dict: ...

    @classmethod
    def _make_create_query(cls, model, returning: Optional[bool] = None) -> sa.sql.Insert:
        """Build an INSERT for one ``model``, optionally RETURNING the full row."""
        model_db = cls.model_db
        query = insert(model_db).values(cls._get_values_for_create_query(model))
        if returning:
            query = query.returning(model_db)
        return query

    @classmethod
    async def create(
        cls,
        model,
        model_out: Optional[Type] = None,
        returning: bool = True,
    ):
        """Insert ``model``; return the stored row as ``model_out`` (default ``cls.model``).

        Returns None when ``returning`` is False.

        Raises:
            AlreadyExistsError: a unique constraint is violated.
        """
        result = await cls._execute(cls._make_create_query(model, returning=returning))
        if returning:
            item = result.one()
            model_out = model_out or cls.model
            return cls._from_db(item, data=cls._make_from_db_data(item), model_out=model_out)

    @classmethod
    def _make_create_many_query(cls, models: list, returning: bool | None = None) -> sa.sql.Insert:
        """Build a multi-row INSERT for ``models``, optionally RETURNING the rows."""
        model_db = cls.model_db
        query = insert(model_db).values([cls._get_values_for_create_query(model) for model in models])
        if returning:
            query = query.returning(model_db)
        return query

    @classmethod
    async def create_many(
        cls,
        models: list,
        model_out: Optional[Type] = None,
        returning: bool = True,
    ) -> list | None:
        """Insert all ``models`` in one statement; return the stored rows (or None).

        An empty ``models`` returns ``[]`` without touching the database.
        """
        if not models:
            return []
        result = await cls._execute(cls._make_create_many_query(models, returning=returning))
        if returning:
            from_db = cls._from_db
            make_from_db_data = cls._make_from_db_data
            model_out = model_out or cls.model
            return [from_db(item, data=make_from_db_data(item), model_out=model_out) for item in result.all()]

    @classmethod
    def _make_get_or_create_query(
        cls,
        model,
        update_element: str,
        index_elements: Optional[list] = None,
    ) -> sa.sql.Insert:
        """Build ``INSERT ... ON CONFLICT (index_elements) DO UPDATE SET update_element``.

        The no-op style update makes PostgreSQL return the existing row on conflict.

        Raises:
            ValueError: neither ``index_elements`` nor ``cls.index_elements`` is set.
        """
        model_db = cls.model_db
        query = insert(model_db).values(cls._get_values_for_create_query(model))
        index_elements = index_elements or cls.index_elements
        if index_elements is None:
            raise ValueError("index_elements is None")
        # Conditional expression instead of ``cond and a or b``; the latter would evaluate
        # the truthiness of the resolved SQLAlchemy attribute.
        index_elements = [getattr(model_db, e) if isinstance(e, str) else e for e in index_elements]
        return query.on_conflict_do_update(
            index_elements=index_elements,
            set_={update_element: getattr(query.excluded, update_element)},
        ).returning(model_db)

    @classmethod
    async def get_or_create(
        cls,
        model,
        update_element: str,
        index_elements: Optional[list] = None,
        model_out: Optional[Type] = None,
    ):
        """Insert ``model`` or return the existing row conflicting on ``index_elements``."""
        model_out = model_out or cls.model
        item = (
            await cls._execute(
                cls._make_get_or_create_query(
                    model,
                    update_element,
                    index_elements=index_elements,
                )
            )
        ).one()
        return cls._from_db(item, data=cls._make_from_db_data(item), model_out=model_out)
550
+
551
+
552
class _SaveDM(_CreateDM, metaclass=_Meta, skip_checks=True):
    """Upsert operations (``save``, ``save_many``) built on _CreateDM.

    A ``model_create`` instance is written with ``_fields_create``, and those columns are
    updated on conflict. Any other model uses ``_fields_save`` (from ``model_save``).
    Conflicts are detected on ``index_elements``.
    """

    model_create: Type = typing.cast(Type, None)
    model_save: Type = typing.cast(Type, None)

    index_elements: Optional[list] = None

    _fields_create: set[str] = set()
    _fields_save: set[str] = set()

    @classmethod
    def _get_values_for_save_query(cls, model) -> dict: ...

    @classmethod
    def _get_on_conflict_do_update_set_for_save_query(cls, excluded, model) -> dict:
        """Map each written column to ``EXCLUDED.<column>`` for ON CONFLICT DO UPDATE."""
        if isinstance(model, typing.cast(Type, cls.model_create)):
            return {k: getattr(excluded, k) for k in cls._fields_create}
        return {k: getattr(excluded, k) for k in cls._fields_save}

    @classmethod
    def _resolve_save_index_elements(cls, index_elements: Optional[list]) -> list:
        """Resolve conflict targets (default ``cls.index_elements``); names -> model_db columns.

        Raises:
            ValueError: no index elements are configured.
        """
        model_db = cls.model_db
        index_elements = index_elements or cls.index_elements
        if index_elements is None:
            raise ValueError("index_elements is None")
        # Conditional expression instead of ``cond and a or b``; the latter would evaluate
        # the truthiness of the resolved SQLAlchemy attribute.
        return [getattr(model_db, e) if isinstance(e, str) else e for e in index_elements]

    @classmethod
    def _make_save_query(
        cls,
        model,
        index_elements: Optional[list] = None,
        returning: Optional[bool] = None,
    ) -> sa.sql.Insert:
        """Build an upsert for one ``model``, optionally RETURNING the row."""
        model_db = cls.model_db
        query = insert(model_db).values(cls._get_values_for_save_query(model))
        query = query.on_conflict_do_update(
            index_elements=cls._resolve_save_index_elements(index_elements),
            set_=cls._get_on_conflict_do_update_set_for_save_query(query.excluded, model),
        )
        if returning:
            query = query.returning(model_db)
        return query

    @classmethod
    async def save(
        cls,
        model,
        index_elements: Optional[list] = None,
        model_out: Optional[Type] = None,
        returning: bool = True,
    ):
        """Insert or update ``model``; return the stored row (None if not ``returning``)."""
        result = await cls._execute(
            cls._make_save_query(
                model,
                index_elements=index_elements,
                returning=returning,
            )
        )
        if returning:
            item = result.one()
            model_out = model_out or cls.model
            return cls._from_db(item, data=cls._make_from_db_data(item), model_out=model_out)

    @classmethod
    def _make_save_many_query(
        cls,
        models: list,
        index_elements: Optional[list] = None,
        returning: Optional[bool] = None,
    ) -> sa.sql.Insert:
        """Build a multi-row upsert for a non-empty ``models`` list.

        The DO UPDATE column set is chosen from ``models[0]``. NOTE(review): this assumes
        all models are of the same kind (all create or all save models).
        """
        model_db = cls.model_db
        query = insert(model_db).values([cls._get_values_for_save_query(model) for model in models])
        query = query.on_conflict_do_update(
            index_elements=cls._resolve_save_index_elements(index_elements),
            set_=cls._get_on_conflict_do_update_set_for_save_query(query.excluded, models[0]),
        )
        if returning:
            query = query.returning(model_db)
        return query

    @classmethod
    async def save_many(
        cls,
        models: list,
        index_elements: Optional[list] = None,
        model_out: Optional[Type] = None,
        returning: bool = True,
    ) -> list | None:
        """Upsert all ``models`` in one statement; ``[]`` for an empty list."""
        if not models:
            return []
        result = await cls._execute(
            cls._make_save_many_query(
                models,
                index_elements=index_elements,
                returning=returning,
            )
        )
        if returning:
            from_db = cls._from_db
            make_from_db_data = cls._make_from_db_data
            model_out = model_out or cls.model
            return [from_db(item, data=make_from_db_data(item), model_out=model_out) for item in result.all()]
654
+
655
+
656
class _UpdateDM(_CreateDM, metaclass=_Meta, skip_checks=True):
    """Update operations (``update``, ``update_many``) keyed by a model attribute.

    ``_fields_update`` is filled by _Meta with the ``model_update`` fields that are also
    table columns. ``_get_values_for_update_query`` is implemented by library subclasses.
    """

    model_update: Type = typing.cast(Type, None)
    _fields_update: set[str] = set()

    @classmethod
    def _get_values_for_update_query(cls, model) -> dict: ...

    @classmethod
    def _make_update_query(cls, model, key: str, returning: Optional[bool] = None) -> sa.sql.Update:
        """Build ``UPDATE model_db SET ... WHERE model_db.<key> == model.<key>``.

        The key column is always RETURNed; with ``returning`` the full row is added too.
        """
        model_db = cls.model_db
        query = (
            update(model_db)
            .values(cls._get_values_for_update_query(model))
            .where(getattr(model_db, key) == getattr(model, key))
            .returning(getattr(model_db, key))
        )
        if returning:
            query = query.returning(model_db)
        return query

    @classmethod
    async def update(
        cls,
        model,
        key: str,
        raise_not_found: Optional[bool] = None,
        model_out: Optional[Type] = None,
        returning: bool = True,
    ):
        """Update the row whose ``key`` equals ``getattr(model, key)``.

        Returns the updated row as ``model_out`` (default ``cls.model``), or None when
        nothing matched or ``returning`` is False.

        Raises:
            NotFoundError: no row matched and ``raise_not_found`` is truthy.
        """
        result = await cls._execute(cls._make_update_query(model, key, returning=returning or raise_not_found))
        if raise_not_found and result.rowcount == 0:
            raise NotFoundError(key=key, value=getattr(model, key))
        if returning:
            item = result.one_or_none()
            if item:
                model_out = model_out or cls.model
                return cls._from_db(item, data=cls._make_from_db_data(item), model_out=model_out)

    @classmethod
    async def update_many(
        cls,
        models: list,
        key: str,
        model_out: Optional[Type] = None,
        returning: bool = True,
    ) -> list | None:
        """Update every model in one transaction on this DM's engine.

        Returns the per-model results of ``update`` in input order (None entries for
        unmatched rows), or None when ``returning`` is False.
        """
        if not models:
            return [] if returning else None
        # The transaction must be opened on this DM's engine; the default engine would
        # leave the updates on their own connections (not atomic) or fail if no default
        # engine is bound. Updates share one connection, so run them sequentially instead
        # of gather(); the first failure stops the batch and rolls the transaction back.
        async with transaction(engine_name=cls.engine_name):
            results = [await cls.update(model, key, model_out=model_out, returning=returning) for model in models]
        if returning:
            return typing.cast(list | None, results)
        return None
708
+
709
+
710
class _DeleteDM(_CreateDM, metaclass=_Meta, skip_checks=True):
    """Delete operation keyed by ``cls.key`` (or an explicit key)."""

    @classmethod
    def _make_delete_query(cls, key, key_value, returning: Optional[bool] = None):
        """Build ``DELETE ... WHERE key == key_value``, optionally RETURNING the row."""
        query = delete(cls.model_db).where(key == key_value)
        if returning:
            query = query.returning(cls.model_db)
        return query

    @classmethod
    async def delete(
        cls,
        key_value,
        key=None,
        raise_not_found: Optional[bool] = None,
        model_out: Optional[Type] = None,
        returning: bool = True,
    ):
        """Delete the row matching ``key == key_value``.

        Returns the deleted row as ``model_out`` (default ``cls.model``), or None when
        nothing matched or ``returning`` is False.

        Raises:
            NotFoundError: no row matched and ``raise_not_found`` is truthy.
        """
        key = cls._get_key(key=key)
        result = await cls._execute(cls._make_delete_query(key, key_value, returning=returning))
        if raise_not_found and result.rowcount == 0:
            raise NotFoundError(key=key, value=key_value)
        if not returning:
            return None
        item = result.one_or_none()
        if item:
            model_out = model_out or cls.model
            return cls._from_db(item, data=cls._make_from_db_data(item), model_out=model_out)
        return None
@@ -0,0 +1,66 @@
1
+ import typing
2
+
3
+ from typing import Optional, Type
4
+
5
+ import sqlalchemy as sa
6
+
7
+ from cwtch import asdict, from_attributes
8
+
9
+ from .common import ModelDB, OrderBy, _CreateDM, _DeleteDM, _GetDM, _SaveDM, _UpdateDM
10
+
11
+
12
# Names exported by ``from <this module> import *``; ``OrderBy`` is re-exported from ``.common``.
__all__ = ["OrderBy", "GetDM", "CreateDM", "SaveDM", "UpdateDM", "DeleteDM", "DM"]
21
+
22
+
23
class GetDM(_GetDM, skip_checks=True):
    """Read-side data manager that builds output models with ``cwtch.from_attributes``."""

    @classmethod
    def _from_db(
        cls,
        item: sa.Row | ModelDB,
        data: Optional[dict] = None,
        exclude: Optional[list] = None,
        suffix: Optional[str] = None,
        model_out=None,
    ):
        """Convert a database row or ORM object into an output model.

        Args:
            item: The ``sa.Row`` or ``ModelDB`` instance to read attributes from.
            data: Forwarded to ``from_attributes`` as ``data``.
            exclude: Forwarded to ``from_attributes`` as ``exclude``.
            suffix: Forwarded to ``from_attributes`` as ``suffix``.
            model_out: Model class to build; ``cls.model`` is used when falsy.

        Returns:
            Whatever ``from_attributes`` produces for the chosen model class.
        """
        target = model_out or cls.model
        # ``reset_circular_refs`` is always enabled here; see cwtch's
        # ``from_attributes`` for its exact meaning.
        options = dict(data=data, exclude=exclude, suffix=suffix, reset_circular_refs=True)
        return from_attributes(target, item, **options)
41
+
42
+
43
class CreateDM(_CreateDM, skip_checks=True):
    """Create-side data manager that serializes models with ``cwtch.asdict``."""

    @classmethod
    def _get_values_for_create_query(cls, model) -> dict:
        """Return the ``cls._fields_create`` subset of ``model`` for the create query."""
        allowed = typing.cast(list[str], cls._fields_create)
        return asdict(model, include=allowed)
47
+
48
+
49
class SaveDM(_SaveDM, skip_checks=True):
    """Save-side data manager that serializes models with ``cwtch.asdict``."""

    @classmethod
    def _get_values_for_save_query(cls, model) -> dict:
        """Return the values of ``model`` used by the save query.

        Instances of ``cls.model_create`` are restricted to ``cls._fields_create``;
        any other model is restricted to ``cls._fields_save``.
        """
        is_create_model = isinstance(model, typing.cast(Type, cls.model_create))
        chosen = cls._fields_create if is_create_model else cls._fields_save
        return asdict(model, include=typing.cast(list[str], chosen))
55
+
56
+
57
class UpdateDM(_UpdateDM, skip_checks=True):
    """Update-side data manager that serializes models with ``cwtch.asdict``."""

    @classmethod
    def _get_values_for_update_query(cls, model) -> dict:
        """Return the ``cls._fields_update`` values of ``model`` for the update query.

        ``exclude_unset=True`` is passed so that, presumably, fields never set on
        ``model`` are left out of the UPDATE — confirm against cwtch's ``asdict``.
        """
        update_fields = typing.cast(list[str], cls._fields_update)
        return asdict(model, include=update_fields, exclude_unset=True)
61
+
62
+
63
class DeleteDM(_DeleteDM, skip_checks=True):
    """Delete-side data manager; all behavior is inherited from ``_DeleteDM``."""
64
+
65
+
66
class DM(GetDM, CreateDM, SaveDM, UpdateDM, DeleteDM, skip_checks=True):
    """Combined data manager mixing in the get, create, save, update and delete managers."""
@@ -0,0 +1,76 @@
1
+ import typing
2
+
3
+ from typing import Optional, Type
4
+
5
+ import pydantic # noqa: F401 # type: ignore
6
+ import sqlalchemy as sa
7
+
8
+ from .common import ModelDB, OrderBy, _CreateDM, _DeleteDM, _GetDM, _SaveDM, _UpdateDM
9
+
10
+
11
# Names exported by ``from <this module> import *``; ``OrderBy`` is re-exported from ``.common``.
__all__ = ["OrderBy", "GetDM", "CreateDM", "SaveDM", "UpdateDM", "DeleteDM", "DM"]
20
+
21
+
22
class _object:
    """Bare attribute holder whose ``__dict__`` is filled in and then validated by pydantic."""
24
+
25
+
26
class GetDM(_GetDM, skip_checks=True):
    """Read-side data manager that builds output models with pydantic ``model_validate``."""

    @classmethod
    def _from_db(
        cls,
        item: sa.Row | ModelDB,
        data: Optional[dict] = None,
        exclude: Optional[list] = None,
        suffix: Optional[str] = None,
        model_out=None,
    ):
        """Convert a database row or ORM object into a pydantic output model.

        Args:
            item: An ``sa.Row`` (read via ``_asdict()``) or a ``ModelDB`` instance
                (read via its ``__dict__``).
            data: Extra values merged over the ones read from ``item``.
            exclude: Keys to drop, compared before any suffix is removed.
            suffix: When non-empty, keep only keys ending with ``suffix`` and
                remove exactly that trailing suffix from each kept key.
            model_out: Pydantic model class to validate into; ``cls.model`` is
                used when falsy.

        Returns:
            ``(model_out or cls.model).model_validate(..., from_attributes=True)``
            applied to the filtered values.
        """
        if isinstance(item, sa.Row):
            values = {**item._asdict(), **(data or {})}
        else:
            values = {**item.__dict__, **(data or {})}

        excluded = exclude or ()
        if suffix:
            # ``removesuffix`` drops the exact trailing suffix. ``rstrip`` would strip
            # any trailing characters found in ``suffix`` and mangle keys, e.g.
            # "name_user" -> "nam" for suffix "_user".
            fields = {
                k.removesuffix(suffix): v for k, v in values.items() if k not in excluded and k.endswith(suffix)
            }
        else:
            fields = {k: v for k, v in values.items() if k not in excluded}

        # ``from_attributes`` validation reads attributes, so expose the mapping
        # as the instance namespace of a plain object.
        obj = _object()
        obj.__dict__ = fields
        return (model_out or cls.model).model_validate(obj, from_attributes=True)
51
+
52
+
53
class CreateDM(_CreateDM, skip_checks=True):
    """Create-side data manager that serializes pydantic models with ``model_dump``."""

    @classmethod
    def _get_values_for_create_query(cls, model) -> dict:
        """Return the ``cls._fields_create`` subset of ``model`` for the create query."""
        # NOTE(review): pydantic documents ``include`` as a set or mapping; the runtime
        # type of ``cls._fields_create`` is not visible here — confirm it is accepted.
        allowed = typing.cast(list[str], cls._fields_create)
        return model.model_dump(include=allowed)
57
+
58
+
59
class SaveDM(_SaveDM, skip_checks=True):
    """Save-side data manager that serializes pydantic models with ``model_dump``."""

    @classmethod
    def _get_values_for_save_query(cls, model) -> dict:
        """Return the values of ``model`` used by the save query.

        Non-``cls.model_create`` models are restricted to ``cls._fields_save``;
        ``cls.model_create`` instances are restricted to ``cls._fields_create``.
        """
        if not isinstance(model, typing.cast(Type, cls.model_create)):
            return model.model_dump(include=typing.cast(list[str], cls._fields_save))
        return model.model_dump(include=typing.cast(list[str], cls._fields_create))
65
+
66
+
67
class UpdateDM(_UpdateDM, skip_checks=True):
    """Update-side data manager that serializes pydantic models with ``model_dump``."""

    @classmethod
    def _get_values_for_update_query(cls, model) -> dict:
        """Return the explicitly set ``cls._fields_update`` values of ``model``.

        ``exclude_unset=True`` makes pydantic omit fields that were never set on
        ``model``, so they are not part of the update values.
        """
        update_fields = typing.cast(list[str], cls._fields_update)
        return model.model_dump(include=update_fields, exclude_unset=True)
71
+
72
+
73
class DeleteDM(_DeleteDM, skip_checks=True):
    """Delete-side data manager; all behavior is inherited from ``_DeleteDM``."""
74
+
75
+
76
class DM(GetDM, CreateDM, SaveDM, UpdateDM, DeleteDM, skip_checks=True):
    """Combined data manager mixing in the get, create, save, update and delete managers."""
@@ -0,0 +1,51 @@
1
+ [tool.poetry]
2
+ name = "dbdm"
3
+ version = "0.1.0"
4
+ description = ""
5
+ authors = ["Roman Koshel <roma.koshel@gmail.com>"]
6
+ license = "MIT"
7
+ readme = "README.md"
8
+ keywords = []
9
+ packages = [{ include = "dbdm" }]
10
+ classifiers = [
11
+ "Development Status :: 3 - Alpha",
12
+ "Programming Language :: Python",
13
+ "Programming Language :: Python :: 3 :: Only",
14
+ "Topic :: Software Development :: Libraries :: Python Modules",
15
+ "Topic :: Utilities",
16
+ "License :: OSI Approved :: MIT License",
17
+ ]
18
+
19
+ [tool.poetry.dependencies]
20
+ python = ">=3.11, <4.0"
21
+ asyncpg = "*"
22
+ sqlalchemy = { version="^2", extras=["asyncio"] }
23
+
24
+ [tool.poetry.group.test.dependencies]
25
+ coverage = "*"
26
+ cwtch = ">=0.10.0"
27
+ docker = "*"
28
+ pydantic = "^2"
29
+ pylint = "*"
30
+ pytest = "*"
31
+ pytest_asyncio = "*"
32
+ pytest_timeout = "*"
33
+
34
+ [tool.poetry.group.docs.dependencies]
35
+ mkdocs = "*"
36
+ mkdocs-material = "*"
37
+ mkdocstrings = { extras = ["python"], version = "*" }
38
+
39
+ [build-system]
40
+ requires = ["poetry-core"]
41
+ build-backend = "poetry.core.masonry.api"
42
+
43
+ [tool.black]
44
+ line-length = 120
45
+
46
+ [tool.isort]
47
+ indent = 4
48
+ lines_after_imports = 2
49
+ lines_between_types = 1
50
+ src_paths = ["."]
51
+ profile = "black"