dbdm 0.1.0__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
dbdm-0.1.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Roma Koshel
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
dbdm-0.1.0/PKG-INFO ADDED
@@ -0,0 +1,191 @@
1
+ Metadata-Version: 2.1
2
+ Name: dbdm
3
+ Version: 0.1.0
4
+ Summary:
5
+ License: MIT
6
+ Author: Roman Koshel
7
+ Author-email: roma.koshel@gmail.com
8
+ Requires-Python: >=3.11,<4.0
9
+ Classifier: Development Status :: 3 - Alpha
10
+ Classifier: License :: OSI Approved :: MIT License
11
+ Classifier: Programming Language :: Python
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Classifier: Programming Language :: Python :: 3.13
16
+ Classifier: Programming Language :: Python :: 3 :: Only
17
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
18
+ Classifier: Topic :: Utilities
19
+ Requires-Dist: asyncpg
20
+ Requires-Dist: sqlalchemy[asyncio] (>=2,<3)
21
+ Description-Content-Type: text/markdown
22
+
23
+ # DBDM [wip]
24
+
25
+
26
+ ## Examples
27
+
28
+ ```python
29
+ from typing import ClassVar, Type
30
+
31
+ import sqlalchemy as sa
32
+
33
+ from cwtch import dataclass, resolve_types, view
34
+ from cwtch.types import UNSET, Unset
35
+ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
36
+ from sqlalchemy.ext.asyncio import create_async_engine
37
+
38
+ from dbdm import DM, NotFoundError, bind_engine
39
+
40
+
41
+ class BaseDB(DeclarativeBase):
42
+ pass
43
+
44
+
45
+ class ParentDB(BaseDB):
46
+ __tablename__ = "parents"
47
+
48
+ id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
49
+ name: Mapped[str]
50
+ data: Mapped[str]
51
+ children = relationship("ChildDB", uselist=True, viewonly=True)
52
+
53
+
54
+ class ChildDB(BaseDB):
55
+ __tablename__ = "children"
56
+
57
+ id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
58
+ name: Mapped[str]
59
+ parent_id: Mapped[int] = mapped_column(sa.ForeignKey("parents.id"))
60
+ parent = relationship("ParentDB", uselist=False, viewonly=True)
61
+
62
+
63
+ @dataclass(handle_circular_refs=True)
64
+ class Parent:
65
+ id: int
66
+ name: str
67
+ data: str
68
+ children: Unset[list["Child"]] = UNSET
69
+
70
+ # Parent views
71
+ Create: ClassVar[Type["ParentCreate"]]
72
+ Save: ClassVar[Type["ParentSave"]]
73
+ Update: ClassVar[Type["ParentUpdate"]]
74
+
75
+
76
+ @view(Parent, "Create", exclude=["id", "children"])
77
+ class ParentCreate:
78
+ pass
79
+
80
+
81
+ @view(Parent, "Save", exclude=["children"])
82
+ class ParentSave:
83
+ pass
84
+
85
+
86
+ @view(Parent, "Update", exclude=["children"])
87
+ class ParentUpdate:
88
+ name: Unset[str] = UNSET
89
+ data: Unset[str] = UNSET
90
+
91
+
92
+ @dataclass(handle_circular_refs=True)
93
+ class Child:
94
+ id: int
95
+ name: str
96
+ parent_id: int
97
+ parent: Unset[Parent] = UNSET
98
+
99
+ # Child views
100
+ Create: ClassVar[Type["ChildCreate"]]
101
+ Save: ClassVar[Type["ChildSave"]]
102
+ Update: ClassVar[Type["ChildUpdate"]]
103
+
104
+
105
+ @view(Child, "Create", exclude=["id", "parent"])
106
+ class ChildCreate:
107
+ pass
108
+
109
+
110
+ @view(Child, "Save", exclude=["parent"])
111
+ class ChildSave:
112
+ pass
113
+
114
+
115
+ @view(Child, "Update", exclude=["parent"])
116
+ class ChildUpdate:
117
+ name: Unset[str] = UNSET
118
+ parent_id: Unset[int] = UNSET
119
+
120
+
121
+ resolve_types(Parent, globals(), locals())
122
+
123
+
124
+ class ParentDM(DM):
125
+ model_db = ParentDB
126
+ model = Parent
127
+ model_create = Parent.Create
128
+ model_save = Parent.Save
129
+ model_update = Parent.Update
130
+ key = "id"
131
+ index_elements = ["id"]
132
+ joinedload = {"children": lambda m: sa.orm.joinedload(m.children)}
133
+
134
+
135
+ class ChildDM(DM):
136
+ model_db = ChildDB
137
+ model = Child
138
+ model_create = Child.Create
139
+ model_save = Child.Save
140
+ model_update = Child.Update
141
+ key = "id"
142
+ index_elements = ["id"]
143
+ joinedload = {"parent": lambda m: sa.orm.joinedload(m.parent)}
144
+
145
+
146
+ @pytest_asyncio.fixture
147
+ async def create_all(engine):
148
+ async with engine.begin() as conn:
149
+ await conn.run_sync(BaseDB.metadata.create_all)
150
+
151
+
152
+ async def example(engine):
153
+ engine = create_async_engine(...)
154
+
155
+ async with engine.begin() as conn:
156
+ await conn.run_sync(BaseDB.metadata.create_all)
157
+
158
+ bind_engine(engine)
159
+
160
+ parent = await ParentDM.create(Parent.Create(name=f"Parent_{i}", data="data"))
161
+
162
+ # parents: list[Parent]
163
+ total, parents = await ParentDM.get_many()
164
+
165
+ # parents: list[Parent]
166
+ total, parents = await ParentDM.get_many(page_size=1)
167
+
168
+ # parents: list[Parent]
169
+ total, parents = await ParentDM.get_many(page_size=1, page=2)
170
+
171
+ # parent: Parent
172
+ parent = await ParentDM.get(1)
173
+
174
+ # parent: Parent
175
+ parent = await ParentDM.get(1, joinedload={"children": True})
176
+
177
+ # parent: Parent
178
+ parent = await ParentDM.save(parent.Save())
179
+
180
+ # parent: Parent
181
+ parent = await ParentDM.update(Parent.Update(id=1, data="new data"), key="id")
182
+
183
+ await ParentDM.delete(1)
184
+
185
+ # child : Child
186
+ child = await ChildDM.create(Child.Create(name=f"Child_{i}", parent_id=i))
187
+
188
+ # parent: Parent
189
+ parent = await ParentDM.get(1, joinedload={"children": True})
190
+ ```
191
+
dbdm-0.1.0/README.md ADDED
@@ -0,0 +1,168 @@
1
+ # DBDM [wip]
2
+
3
+
4
+ ## Examples
5
+
6
+ ```python
7
+ from typing import ClassVar, Type
8
+
9
+ import sqlalchemy as sa
10
+
11
+ from cwtch import dataclass, resolve_types, view
12
+ from cwtch.types import UNSET, Unset
13
+ from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
14
+ from sqlalchemy.ext.asyncio import create_async_engine
15
+
16
+ from dbdm import DM, NotFoundError, bind_engine
17
+
18
+
19
+ class BaseDB(DeclarativeBase):
20
+ pass
21
+
22
+
23
+ class ParentDB(BaseDB):
24
+ __tablename__ = "parents"
25
+
26
+ id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
27
+ name: Mapped[str]
28
+ data: Mapped[str]
29
+ children = relationship("ChildDB", uselist=True, viewonly=True)
30
+
31
+
32
+ class ChildDB(BaseDB):
33
+ __tablename__ = "children"
34
+
35
+ id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
36
+ name: Mapped[str]
37
+ parent_id: Mapped[int] = mapped_column(sa.ForeignKey("parents.id"))
38
+ parent = relationship("ParentDB", uselist=False, viewonly=True)
39
+
40
+
41
+ @dataclass(handle_circular_refs=True)
42
+ class Parent:
43
+ id: int
44
+ name: str
45
+ data: str
46
+ children: Unset[list["Child"]] = UNSET
47
+
48
+ # Parent views
49
+ Create: ClassVar[Type["ParentCreate"]]
50
+ Save: ClassVar[Type["ParentSave"]]
51
+ Update: ClassVar[Type["ParentUpdate"]]
52
+
53
+
54
+ @view(Parent, "Create", exclude=["id", "children"])
55
+ class ParentCreate:
56
+ pass
57
+
58
+
59
+ @view(Parent, "Save", exclude=["children"])
60
+ class ParentSave:
61
+ pass
62
+
63
+
64
+ @view(Parent, "Update", exclude=["children"])
65
+ class ParentUpdate:
66
+ name: Unset[str] = UNSET
67
+ data: Unset[str] = UNSET
68
+
69
+
70
+ @dataclass(handle_circular_refs=True)
71
+ class Child:
72
+ id: int
73
+ name: str
74
+ parent_id: int
75
+ parent: Unset[Parent] = UNSET
76
+
77
+ # Child views
78
+ Create: ClassVar[Type["ChildCreate"]]
79
+ Save: ClassVar[Type["ChildSave"]]
80
+ Update: ClassVar[Type["ChildUpdate"]]
81
+
82
+
83
+ @view(Child, "Create", exclude=["id", "parent"])
84
+ class ChildCreate:
85
+ pass
86
+
87
+
88
+ @view(Child, "Save", exclude=["parent"])
89
+ class ChildSave:
90
+ pass
91
+
92
+
93
+ @view(Child, "Update", exclude=["parent"])
94
+ class ChildUpdate:
95
+ name: Unset[str] = UNSET
96
+ parent_id: Unset[int] = UNSET
97
+
98
+
99
+ resolve_types(Parent, globals(), locals())
100
+
101
+
102
+ class ParentDM(DM):
103
+ model_db = ParentDB
104
+ model = Parent
105
+ model_create = Parent.Create
106
+ model_save = Parent.Save
107
+ model_update = Parent.Update
108
+ key = "id"
109
+ index_elements = ["id"]
110
+ joinedload = {"children": lambda m: sa.orm.joinedload(m.children)}
111
+
112
+
113
+ class ChildDM(DM):
114
+ model_db = ChildDB
115
+ model = Child
116
+ model_create = Child.Create
117
+ model_save = Child.Save
118
+ model_update = Child.Update
119
+ key = "id"
120
+ index_elements = ["id"]
121
+ joinedload = {"parent": lambda m: sa.orm.joinedload(m.parent)}
122
+
123
+
124
+ @pytest_asyncio.fixture
125
+ async def create_all(engine):
126
+ async with engine.begin() as conn:
127
+ await conn.run_sync(BaseDB.metadata.create_all)
128
+
129
+
130
+ async def example(engine):
131
+ engine = create_async_engine(...)
132
+
133
+ async with engine.begin() as conn:
134
+ await conn.run_sync(BaseDB.metadata.create_all)
135
+
136
+ bind_engine(engine)
137
+
138
+ parent = await ParentDM.create(Parent.Create(name=f"Parent_{i}", data="data"))
139
+
140
+ # parents: list[Parent]
141
+ total, parents = await ParentDM.get_many()
142
+
143
+ # parents: list[Parent]
144
+ total, parents = await ParentDM.get_many(page_size=1)
145
+
146
+ # parents: list[Parent]
147
+ total, parents = await ParentDM.get_many(page_size=1, page=2)
148
+
149
+ # parent: Parent
150
+ parent = await ParentDM.get(1)
151
+
152
+ # parent: Parent
153
+ parent = await ParentDM.get(1, joinedload={"children": True})
154
+
155
+ # parent: Parent
156
+ parent = await ParentDM.save(parent.Save())
157
+
158
+ # parent: Parent
159
+ parent = await ParentDM.update(Parent.Update(id=1, data="new data"), key="id")
160
+
161
+ await ParentDM.delete(1)
162
+
163
+ # child : Child
164
+ child = await ChildDM.create(Child.Create(name=f"Child_{i}", parent_id=i))
165
+
166
+ # parent: Parent
167
+ parent = await ParentDM.get(1, joinedload={"children": True})
168
+ ```
@@ -0,0 +1,7 @@
1
+ import importlib.metadata
2
+
3
+
4
# Installed distribution version. When the package is imported without being
# installed (e.g. straight from a source checkout) the metadata lookup raises
# PackageNotFoundError; fall back to a placeholder instead of failing the import.
try:
    __version__ = importlib.metadata.version("dbdm")
except importlib.metadata.PackageNotFoundError:
    __version__ = "0.0.0"
5
+
6
+
7
+ from .common import * # noqa: F403
@@ -0,0 +1,736 @@
1
+ import re
2
+ import typing
3
+
4
+ from asyncio import gather
5
+ from contextlib import asynccontextmanager
6
+ from contextvars import ContextVar
7
+ from dataclasses import dataclass
8
+ from typing import Any, AsyncIterator, Callable, Literal, Never, Optional, Protocol, Type, TypeVar
9
+
10
+ import sqlalchemy as sa
11
+
12
+ from sqlalchemy import delete, func, literal, select, union_all, update
13
+ from sqlalchemy.dialects.postgresql import insert
14
+ from sqlalchemy.ext.asyncio import async_sessionmaker
15
+ from sqlalchemy.orm import DeclarativeBase, aliased
16
+
17
+
18
# Names exported by ``from <this module> import *``.
__all__ = [
    # exceptions
    "DMError",
    "AlreadyExistsError",
    "BadParamsError",
    "NotFoundError",
    # ordering helper for get_many
    "OrderBy",
    # engine / transaction management
    "bind_engine",
    "transaction",
]


# Per-engine registries keyed by the name given to ``bind_engine``
# (``None`` is the default engine); both are filled in by ``bind_engine``.
# _engine: the bound AsyncEngine.
# _conn:   ContextVar holding the connection of the active ``transaction()``.
_engine: dict[Optional[str], sa.ext.asyncio.AsyncEngine] = {}
_conn: dict[Optional[str], ContextVar] = {}
31
+
32
+
33
class DMError(Exception):
    """Base class for every error raised by the data managers."""


class BadParamsError(DMError):
    """Raised when a caller-supplied parameter (e.g. an ``order_by`` entry) is invalid.

    Attributes:
        message: Human-readable description; also what ``str()``/``repr()`` return.
        param: The offending parameter value/name, if known.
    """

    def __init__(self, message: str, param: Optional[str] = None):
        # Populate ``args`` so the exception logs and pickles like a normal one.
        super().__init__(message)
        self.message = message
        self.param = param

    def __str__(self):
        return self.message

    def __repr__(self):
        return self.__str__()


class NotFoundError(DMError):
    """Raised when no row matches ``key == value``.

    Attributes:
        key: The lookup key (a column name or column expression).
        value: The value that was looked up.
    """

    def __init__(self, key, value):
        # ``args`` must mirror the constructor signature: pickle re-creates
        # exceptions via ``cls(*args)``.
        super().__init__(key, value)
        self.key = key
        self.value = value

    def __str__(self):
        return f"item with ({self.key})=({self.value}) not found"


class AlreadyExistsError(DMError):
    """Raised when an insert violates a unique key.

    Attributes:
        key: Column list reported by the database (e.g. ``"id"``).
        value: Conflicting value(s) as reported by the database.
    """

    def __init__(self, key, value):
        # See NotFoundError: keep ``args`` in constructor order for pickling.
        super().__init__(key, value)
        self.key = key
        self.value = value

    def __str__(self):
        return f"key ({self.key})=({self.value}) already exists"
65
+
66
+
67
@dataclass
class OrderBy:
    """A single ORDER BY entry accepted by ``get_many``.

    Attributes:
        by: A column name (looked up on the queried model / CTE columns) or a
            column expression.
        order: ``"asc"`` or ``"desc"``; mapped to ``sa.asc`` / ``sa.desc``.
    """

    by: Any
    order: Literal["asc", "desc"] = "asc"
71
+
72
+
73
def bind_engine(engine: sa.ext.asyncio.AsyncEngine, name: Optional[str] = None):
    """Register *engine* under *name* (``None`` registers the default engine).

    Also installs a fresh ContextVar for the name, which ``transaction()`` uses
    to share one connection within the current context.

    Raises:
        DMError: if the engine's dialect is not PostgreSQL (queries are built
            with the PostgreSQL ``insert`` construct).
    """
    if engine.dialect.name == "postgresql":
        _engine[name] = engine
        _conn[name] = ContextVar("conn", default=None)
        return
    raise DMError("only 'postgresql' dialect is supported")
78
+
79
+
80
@asynccontextmanager
async def transaction(engine_name: Optional[str] = None) -> AsyncIterator[sa.ext.asyncio.AsyncConnection]:
    """Open, or join, a transaction on the engine bound as *engine_name*.

    The outermost call opens a connection, begins a transaction and publishes
    the connection through the engine's ContextVar so DM calls made inside the
    block (which go through ``get_conn``) reuse it. A nested call just yields the
    already-active connection; commit/rollback stays with the outermost call.

    Raises:
        DMError: if no engine was bound under *engine_name*.
    """
    try:
        conn_var = _conn[engine_name]
        engine = _engine[engine_name]
    except KeyError:
        raise DMError(f"engine {engine_name!r} is not bound; call bind_engine() first") from None

    if (conn := conn_var.get()) is not None:
        # Nested call: join the outer transaction.
        yield conn
        return

    async with engine.connect() as conn:
        async with conn.begin():
            conn_var.set(conn)
            try:
                yield conn
            finally:
                conn_var.set(None)
92
+
93
+
94
@asynccontextmanager
async def get_conn(engine_name: Optional[str] = None) -> AsyncIterator[sa.ext.asyncio.AsyncConnection]:
    """Yield the connection of the active ``transaction()``, or a fresh one.

    Inside ``transaction(engine_name)`` the shared connection is yielded as-is.
    Otherwise a new connection is opened and wrapped in its own transaction
    (``conn.begin()``), which commits when the block exits normally and rolls
    back if it raises.

    Raises:
        DMError: if no engine was bound under *engine_name*.
    """
    try:
        conn_var = _conn[engine_name]
        engine = _engine[engine_name]
    except KeyError:
        raise DMError(f"engine {engine_name!r} is not bound; call bind_engine() first") from None

    if (conn := conn_var.get()) is not None:
        yield conn
        return

    async with engine.connect() as conn:
        async with conn.begin():
            yield conn
102
+
103
+
104
@asynccontextmanager
async def get_sess(engine_name: Optional[str] = None):
    """Yield an ``AsyncSession`` in a transaction on the connection from ``get_conn``.

    ``expire_on_commit=False`` keeps attributes of loaded ORM objects readable
    after the session's transaction ends.
    """
    async with get_conn(engine_name=engine_name) as conn:
        session_factory = async_sessionmaker(conn, expire_on_commit=False)
        async with session_factory() as sess, sess.begin():
            yield sess
111
+
112
+
113
# PostgreSQL reports the conflicting key on a separate "DETAIL:" line of the
# error message; it is not necessarily the last line.
_DETAIL_RE = re.compile(r"^DETAIL:\s*(?P<text>.*)$", re.MULTILINE)
_ALREADY_EXISTS_RE = re.compile(r"Key \((?P<key>.*)\)=\((?P<key_value>.*)\) already exists\.")


def _raise_exc(e: Exception) -> Never:
    """Re-raise *e*, translating unique-key violations into ``AlreadyExistsError``.

    An ``IntegrityError`` whose driver message contains a
    ``DETAIL: Key (<cols>)=(<values>) already exists.`` line becomes
    ``AlreadyExistsError(cols, values)``, chained to the original error.
    Everything else (including messages that cannot be parsed) is re-raised
    unchanged.
    """
    if isinstance(e, sa.exc.IntegrityError):
        orig_args = getattr(getattr(e, "orig", None), "args", None) or ()
        message = orig_args[0] if orig_args and isinstance(orig_args[0], str) else ""
        detail = _DETAIL_RE.search(message)
        if detail:
            m = _ALREADY_EXISTS_RE.match(detail.group("text").strip())
            if m:
                key_value = m.group("key_value").strip('\\"')
                raise AlreadyExistsError(m.group("key"), key_value) from e
    raise e
124
+
125
+
126
# An instance of some SQLAlchemy declarative model (an ORM object).
ModelDB = TypeVar("ModelDB", bound=DeclarativeBase)


class _BaseProtocol(Protocol):
    """Structural interface the DM mixins rely on when they call ``cls.<member>``.

    ``_GetDM`` (and its subclasses) provide the concrete implementations; the
    create/save/update/delete mixins only see this interface.
    """

    # Name given to ``bind_engine`` (``None`` selects the default engine).
    engine_name: Optional[str]

    # SQLAlchemy declarative model that is queried / written.
    model_db: Type[DeclarativeBase]
    # Output model (a dataclass or pydantic model) built from rows.
    model: Type

    @classmethod
    async def _execute(cls, query) -> sa.ResultProxy:
        """Execute *query* and return its result."""
        ...

    @classmethod
    def _make_from_db_data(
        cls,
        db_item,
        row: Optional[tuple] = None,
        joinedload: Optional[dict] = None,
    ) -> dict:
        """Return extra values handed to ``_from_db`` as ``data``."""
        ...

    @classmethod
    def _from_db(
        cls,
        item: sa.Row | ModelDB,
        data: Optional[dict] = None,
        exclude: Optional[list] = None,
        suffix: Optional[str] = None,
        model_out=None,
    ):
        """Convert a result row or ORM object into an output model."""
        ...

    @classmethod
    def _get_key(cls, key=None):
        """Resolve the lookup key to a column."""
        ...

    @classmethod
    async def get(
        cls,
        key_value,
        key=None,
        raise_not_found: Optional[bool] = None,
        joinedload: Optional[dict] = None,
        model_out: Optional[Type] = None,
        **kwds,
    ) -> Optional[type]:
        """Fetch one item by key."""
        ...

    @classmethod
    async def get_many(
        cls,
        flt: Optional[sa.sql.elements.BinaryExpression] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        order_by: Optional[list[OrderBy | str] | OrderBy | str] = None,
        joinedload: Optional[dict] = None,
        model_out: Optional[Type] = None,
        **kwds,
    ) -> tuple[int, list]:
        """Fetch a page of items together with the total count."""
        ...
181
+
182
+
183
def _collect_bases(bases: tuple) -> set:
    """Return *bases* together with every ancestor reachable through ``__bases__``."""
    found: set = set()
    pending = list(bases)
    while pending:
        klass = pending.pop()
        if klass in found:
            continue
        found.add(klass)
        pending.extend(getattr(klass, "__bases__", ()))
    return found


def _view_column_names(ns: dict, model_db, attr: str) -> set:
    """Validate the model stored in ``ns[attr]``; return its fields that are table columns.

    Raises:
        ValueError: if the model is missing or is neither a dataclass nor a
            pydantic model.
        DMError: if none of its fields is a column of ``model_db``.
    """
    view = ns.get(attr)
    if view is None:
        raise ValueError(f"'{attr}' field is required")
    dataclass_fields = getattr(view, "__dataclass_fields__", None)
    pydantic_fields = getattr(view, "__pydantic_fields__", None)
    if dataclass_fields is None and pydantic_fields is None:
        raise ValueError(f"'{attr}' is not a valid model")
    columns = set(model_db.__table__.columns.keys()) & set(dataclass_fields or pydantic_fields or set())
    if not columns:
        raise DMError(f"{attr} does not contain valid fields")
    return columns


class _Meta(typing._ProtocolMeta):
    """Metaclass validating DM class declarations when the class is created.

    Internal base classes pass ``skip_checks=True``. Any other class must have a
    ``_GetDM`` ancestor and define ``model_db`` and a dataclass/pydantic
    ``model`` in its own body, plus the input models required by the mixins it
    inherits. The column subsets of those input models are stored on the class
    as ``_fields_create`` / ``_fields_save`` / ``_fields_update``.
    """

    def __new__(cls, name, bases, ns, skip_checks: bool = False):
        if skip_checks:
            return super().__new__(cls, name, bases, ns)

        ancestors = _collect_bases(bases)

        if _GetDM not in ancestors:
            raise TypeError("Any DM class should be subclassed from GetDM class")

        # Only the class body is consulted; values are not inherited from bases.
        model_db = ns.get("model_db")
        if model_db is None:
            raise ValueError("'model_db' field is required")

        model = ns.get("model")
        if model is None:
            raise ValueError("'model' field is required")
        if getattr(model, "__dataclass_fields__", None) is None and getattr(model, "__pydantic_fields__", None) is None:
            raise ValueError("'model' is not a valid model")

        if _CreateDM in ancestors or _SaveDM in ancestors:
            ns["_fields_create"] = _view_column_names(ns, model_db, "model_create")
        if _SaveDM in ancestors:
            ns["_fields_save"] = _view_column_names(ns, model_db, "model_save")
        if _UpdateDM in ancestors:
            ns["_fields_update"] = _view_column_names(ns, model_db, "model_update")

        return super().__new__(cls, name, bases, ns)
260
+
261
+
262
class _GetDM(_BaseProtocol, metaclass=_Meta, skip_checks=True):
    """Read operations of a data manager: ``get`` and ``get_many``.

    Subclasses configure:

    * ``model_db`` - the SQLAlchemy declarative model that is queried;
    * ``model`` - the default output model built from each row;
    * ``key`` - default lookup column (attribute name of ``model_db`` or a
      column expression);
    * ``order_by`` - default ordering for ``get_many`` (falls back to ``key``);
      a callable receives the column namespace being ordered;
    * ``joinedload`` - relationship name -> callable taking the ORM entity and
      returning a loader option, e.g.
      ``{"children": lambda m: sa.orm.joinedload(m.children)}``.
    """

    # Name given to ``bind_engine`` (``None`` selects the default engine).
    engine_name: Optional[str] = None

    model_db: Type[DeclarativeBase] = typing.cast(Type[DeclarativeBase], None)
    model: Type = typing.cast(Type, None)

    key = None
    order_by: Optional[list[OrderBy | str | Callable] | OrderBy | str | Callable] = None

    joinedload: dict = {}

    @classmethod
    async def _execute(cls, query) -> sa.ResultProxy:
        """Execute *query* on the connection provided by ``get_conn``.

        Errors are routed through ``_raise_exc`` so unique violations surface
        as ``AlreadyExistsError``.
        """
        async with get_conn(engine_name=cls.engine_name) as conn:
            try:
                return await conn.execute(query)
            except Exception as e:
                _raise_exc(e)

    @classmethod
    def _make_from_db_data(
        cls,
        db_item,
        row: Optional[tuple] = None,
        joinedload: Optional[dict] = None,
    ) -> dict:
        """Return extra values passed as ``data`` to ``_from_db``; none by default."""
        return {}

    @classmethod
    def _from_db(
        cls,
        item: sa.Row | ModelDB,
        data: Optional[dict] = None,
        exclude: Optional[list] = None,
        suffix: Optional[str] = None,
        model_out=None,
    ):
        """Build an output model from *item*; left to subclasses (returns ``None`` here)."""

    @classmethod
    def _get_key(cls, key=None):
        """Resolve *key* (default ``cls.key``) to a column; strings name ``model_db`` attributes.

        Raises:
            ValueError: if neither *key* nor ``cls.key`` is set.
        """
        if key is None:
            key = cls.key
        if key is None:
            raise ValueError("key is None")
        if isinstance(key, str):
            key = getattr(cls.model_db, key)
        return key

    @classmethod
    def _get_order_by(cls, c, order_by: Optional[list[OrderBy | str] | OrderBy | str] = None) -> list:
        """Build ORDER BY clauses against the column namespace *c*.

        *c* is ``model_db`` or the ``.c`` collection of a CTE; string entries are
        resolved as attributes of *c*. Without *order_by*, ``cls.order_by`` is
        used, falling back to ``cls.key``.

        Raises:
            BadParamsError: for an unknown column name or a direction other
                than ``"asc"``/``"desc"``.
        """
        requested = order_by
        if requested is None:
            if callable(cls.order_by):
                requested = cls.order_by(c)
            elif cls.order_by:
                requested = cls.order_by
            elif isinstance(cls.key, str):
                # Resolve the key by name against *c*: ``_get_key()`` returns the
                # ``model_db`` attribute, which is not a column of a CTE and would
                # pull the base table into the FROM clause.
                requested = cls.key
            else:
                requested = cls._get_key()
        if requested is None:
            return []

        entries = list(requested) if isinstance(requested, (list, tuple)) else [requested]
        clauses = []
        for entry in entries:
            # Work on a copy so resolving ``by`` never mutates the caller's OrderBy.
            spec = OrderBy(by=entry.by, order=entry.order) if isinstance(entry, OrderBy) else OrderBy(by=entry)
            if spec.order not in ("asc", "desc"):
                # Never feed an arbitrary string to ``getattr(sa, ...)``.
                raise BadParamsError(f"invalid order '{spec.order}'", param=f"{spec.order}")
            if isinstance(spec.by, str):
                column = getattr(c, spec.by, None)
                if column is None:
                    raise BadParamsError(f"invalid order_by '{spec.by}'", param=f"{spec.by}")
                spec.by = column
            clauses.append(getattr(sa, spec.order)(spec.by))
        return clauses

    @classmethod
    def _apply_joinedload(cls, query, entity, joinedload: dict):
        """Add loader options for the relationships requested in *joinedload*.

        Returns ``(query, exclude)`` where *exclude* lists the configured
        relationships (keys of ``cls.joinedload``) that were not requested.
        """
        exclude = []
        for name, make_option in cls.joinedload.items():
            if joinedload.get(name) is True:
                query = query.options(make_option(entity))
            else:
                exclude.append(name)
        return query, exclude

    @classmethod
    def _make_get_query(cls, key, key_value, **kwds) -> sa.sql.Select:
        """Return ``SELECT model_db WHERE key = key_value``; *kwds* is unused here."""
        return select(cls.model_db).where(key == key_value)

    @classmethod
    async def get(
        cls,
        key_value,
        key=None,
        raise_not_found: Optional[bool] = None,
        joinedload: Optional[dict] = None,
        model_out: Optional[Type] = None,
        **kwds,
    ) -> Optional[type]:
        """Fetch the single row whose *key* equals *key_value*.

        Args:
            key_value: Value to look up.
            key: Column name or expression; defaults to ``cls.key``.
            raise_not_found: Raise ``NotFoundError`` instead of returning ``None``.
            joinedload: Relationship name -> ``True`` to eager-load it; names
                must be configured in ``cls.joinedload``.
            model_out: Output model; defaults to ``cls.model``.
            **kwds: Forwarded to ``_make_get_query``.

        Returns:
            An instance of *model_out*, or ``None`` when nothing matched.
        """
        key = cls._get_key(key=key)
        query = cls._make_get_query(key, key_value, **kwds)
        model_out = model_out or cls.model

        if joinedload is None or not any(joinedload.values()):
            # Plain Core execution: the row's columns feed the output model.
            item = (await cls._execute(query)).one_or_none()
            if item:
                return cls._from_db(item, data=cls._make_from_db_data(item), model_out=model_out)
        else:
            # Eager loading needs the ORM, hence a session.
            query, exclude = cls._apply_joinedload(query, cls.model_db, joinedload)
            async with get_sess(engine_name=cls.engine_name) as sess:
                item = (await sess.execute(query)).unique().scalars().one_or_none()
                if item:
                    return cls._from_db(
                        item,
                        exclude=exclude,
                        data=cls._make_from_db_data(item, joinedload=joinedload),
                        model_out=model_out,
                    )

        # Raised after the session has closed so a miss does not roll back the
        # caller's surrounding transaction.
        if raise_not_found:
            raise NotFoundError(key=key, value=key_value)
        return None

    @classmethod
    def _make_get_many_query(
        cls,
        flt: Optional[sa.sql.elements.BinaryExpression] = None,
        order_by: Optional[list[OrderBy | str] | OrderBy | str] = None,
        **kwds,
    ) -> sa.sql.Select:
        """Return the filtered, ordered base query with a ``rows_total`` window count."""
        query = select(func.count(literal("*")).over().label("rows_total"), cls.model_db)
        if flt is not None:
            query = query.where(flt)
        return query.order_by(*cls._get_order_by(cls.model_db, order_by))

    @classmethod
    async def get_many(
        cls,
        flt: Optional[sa.sql.elements.BinaryExpression] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        order_by: Optional[list[OrderBy | str] | OrderBy | str] = None,
        joinedload: Optional[dict] = None,
        model_out: Optional[Type] = None,
        **kwds,
    ) -> tuple[int, list]:
        """Fetch one page of rows plus the total number of matching rows.

        Args:
            flt: Optional WHERE expression.
            page: 1-based page number (``None``/0 means the first page); only
                used together with *page_size*.
            page_size: Rows per page; falsy means no pagination.
            order_by: Ordering; defaults to ``cls.order_by`` then ``cls.key``.
            joinedload: Relationship name -> ``True`` to eager-load it.
            model_out: Output model; defaults to ``cls.model``.
            **kwds: Forwarded to ``_make_get_many_query``.

        Returns:
            ``(total, items)``.
        """
        model_db = cls.model_db
        model_out = model_out or cls.model
        from_db = cls._from_db
        make_from_db_data = cls._make_from_db_data

        cte = cls._make_get_many_query(flt=flt, order_by=order_by, **kwds).cte("cte")

        page_query = select(literal(1).label("i"), cte)
        if page_size:
            page = page or 1
            page_query = page_query.limit(page_size).offset((page - 1) * page_size)

        # The i=0 branch contributes one row whenever anything matched, so the
        # total is known even when the requested page is empty.
        query = union_all(select(literal(0).label("i"), cte).limit(1), page_query)

        if joinedload is None or not any(joinedload.values()):
            rows = (await cls._execute(query)).all()
            total = rows[0].rows_total if rows else 0
            return total, [from_db(row, data=make_from_db_data(row), model_out=model_out) for row in rows[1:]]

        main_cte = query.cte("main_cte")
        model_db_alias = aliased(model_db, main_cte)
        query, exclude = cls._apply_joinedload(select(main_cte, model_db_alias), model_db_alias, joinedload)
        query = query.order_by(sa.asc(main_cte.c.i), *cls._get_order_by(main_cte.c, order_by))

        def _hash(row):
            # Joined eager loading repeats an entity once per related row; dedupe
            # on (i, rows_total, entity) so the i=0 count row and a page row for
            # the same entity are both kept.
            return hash((row[0], row[1], row[-1]))

        async with get_sess(engine_name=cls.engine_name) as sess:
            rows = (await sess.execute(query)).unique(_hash).all()
            total = rows[0].rows_total if rows else 0
            return total, [
                from_db(
                    row[-1],
                    exclude=exclude,
                    data=make_from_db_data(row[-1], row=typing.cast(tuple, row), joinedload=joinedload),
                    model_out=model_out,
                )
                for row in rows[1:]
            ]
456
+
457
+
458
class _CreateDM(_BaseProtocol, metaclass=_Meta, skip_checks=True):
    """Insert operations of a data manager: ``create``, ``create_many``, ``get_or_create``.

    ``model_create`` is the input model for inserts. ``_fields_create`` is filled
    in by ``_Meta`` with its fields that are also columns of ``model_db``.
    ``index_elements`` is the default ``ON CONFLICT`` target.
    """

    model_create: Type = typing.cast(Type, None)

    index_elements: Optional[list] = None

    _fields_create: set[str] = set()

    @classmethod
    def _get_values_for_create_query(cls, model) -> dict:
        """Return the column values to INSERT for *model* (implemented by subclasses)."""
        ...

    @classmethod
    def _resolve_index_elements(cls, index_elements: Optional[list] = None) -> list:
        """Return the conflict target, mapping attribute names to ``model_db`` columns.

        Raises:
            ValueError: if neither *index_elements* nor ``cls.index_elements`` is set.
        """
        elements = index_elements or cls.index_elements
        if elements is None:
            raise ValueError("index_elements is None")
        # Conditional expression rather than ``and``/``or`` so a column object is
        # never truth-tested (SQLAlchemy clause objects may refuse ``bool()``).
        return [getattr(cls.model_db, e) if isinstance(e, str) else e for e in elements]

    @classmethod
    def _make_create_query(cls, model, returning: Optional[bool] = None) -> sa.sql.Insert:
        """Build ``INSERT INTO model_db`` for *model*, optionally RETURNING the row."""
        model_db = cls.model_db
        query = insert(model_db).values(cls._get_values_for_create_query(model))
        if returning:
            query = query.returning(model_db)
        return query

    @classmethod
    async def create(
        cls,
        model,
        model_out: Optional[Type] = None,
        returning: bool = True,
    ):
        """Insert *model*.

        Returns:
            The inserted row as *model_out* (default ``cls.model``) when
            *returning*, else ``None``.

        Raises:
            AlreadyExistsError: on a unique-key conflict (translated in ``_execute``).
        """
        result = await cls._execute(cls._make_create_query(model, returning=returning))
        if returning:
            item = result.one()
            model_out = model_out or cls.model
            return cls._from_db(item, data=cls._make_from_db_data(item), model_out=model_out)

    @classmethod
    def _make_create_many_query(cls, models: list, returning: bool | None = None) -> sa.sql.Insert:
        """Build one multi-row INSERT for *models*, optionally RETURNING the rows."""
        model_db = cls.model_db
        query = insert(model_db).values([cls._get_values_for_create_query(model) for model in models])
        if returning:
            query = query.returning(model_db)
        return query

    @classmethod
    async def create_many(
        cls,
        models: list,
        model_out: Optional[Type] = None,
        returning: bool = True,
    ) -> list | None:
        """Insert all *models* with a single statement.

        Returns:
            ``[]`` for empty input; otherwise the inserted rows as *model_out*
            when *returning*, else ``None``.
        """
        if not models:
            return []
        result = await cls._execute(cls._make_create_many_query(models, returning=returning))
        if returning:
            from_db = cls._from_db
            make_from_db_data = cls._make_from_db_data
            model_out = model_out or cls.model
            return [from_db(item, data=make_from_db_data(item), model_out=model_out) for item in result.all()]

    @classmethod
    def _make_get_or_create_query(
        cls,
        model,
        update_element: str,
        index_elements: Optional[list] = None,
    ) -> sa.sql.Insert:
        """Build an upsert that, on conflict, rewrites only *update_element*.

        Using ``DO UPDATE`` (rather than ``DO NOTHING``) makes ``RETURNING``
        yield the row in both the insert and the conflict case.
        """
        model_db = cls.model_db
        query = insert(model_db).values(cls._get_values_for_create_query(model))
        return query.on_conflict_do_update(
            index_elements=cls._resolve_index_elements(index_elements),
            set_={update_element: getattr(query.excluded, update_element)},
        ).returning(model_db)

    @classmethod
    async def get_or_create(
        cls,
        model,
        update_element: str,
        index_elements: Optional[list] = None,
        model_out: Optional[Type] = None,
    ):
        """Insert *model*, or return the existing row on conflict over *index_elements*.

        On conflict the column *update_element* is overwritten with the value
        from *model*. Returns the row as *model_out* (default ``cls.model``).
        """
        model_out = model_out or cls.model
        query = cls._make_get_or_create_query(model, update_element, index_elements=index_elements)
        item = (await cls._execute(query)).one()
        return cls._from_db(item, data=cls._make_from_db_data(item), model_out=model_out)
550
+
551
+
552
class _SaveDM(_CreateDM, metaclass=_Meta, skip_checks=True):
    """Upsert operations of a data manager: ``save`` and ``save_many``.

    Rows are written with ``INSERT ... ON CONFLICT (index_elements) DO UPDATE``.
    Instances of ``model_create`` are written with ``_fields_create``; any other
    model is treated as a ``model_save`` and written with ``_fields_save``.
    """

    model_create: Type = typing.cast(Type, None)
    model_save: Type = typing.cast(Type, None)

    index_elements: Optional[list] = None

    _fields_create: set[str] = set()
    _fields_save: set[str] = set()

    @classmethod
    def _get_values_for_save_query(cls, model) -> dict:
        """Return the column values to INSERT for *model* (implemented by subclasses)."""
        ...

    @classmethod
    def _resolve_index_elements(cls, index_elements: Optional[list] = None) -> list:
        """Return the conflict target, mapping attribute names to ``model_db`` columns.

        Raises:
            ValueError: if neither *index_elements* nor ``cls.index_elements`` is set.
        """
        elements = index_elements or cls.index_elements
        if elements is None:
            raise ValueError("index_elements is None")
        # Conditional expression rather than ``and``/``or`` so a column object is
        # never truth-tested (SQLAlchemy clause objects may refuse ``bool()``).
        return [getattr(cls.model_db, e) if isinstance(e, str) else e for e in elements]

    @classmethod
    def _get_on_conflict_do_update_set_for_save_query(cls, excluded, model) -> dict:
        """Map every written column to its ``EXCLUDED`` value for the DO UPDATE SET clause."""
        if isinstance(model, typing.cast(Type, cls.model_create)):
            return {k: getattr(excluded, k) for k in cls._fields_create}
        return {k: getattr(excluded, k) for k in cls._fields_save}

    @classmethod
    def _make_save_query(
        cls,
        model,
        index_elements: Optional[list] = None,
        returning: Optional[bool] = None,
    ) -> sa.sql.Insert:
        """Build the upsert for a single *model*, optionally RETURNING the row."""
        model_db = cls.model_db
        query = insert(model_db).values(cls._get_values_for_save_query(model))
        query = query.on_conflict_do_update(
            index_elements=cls._resolve_index_elements(index_elements),
            set_=cls._get_on_conflict_do_update_set_for_save_query(query.excluded, model),
        )
        if returning:
            query = query.returning(model_db)
        return query

    @classmethod
    async def save(
        cls,
        model,
        index_elements: Optional[list] = None,
        model_out: Optional[Type] = None,
        returning: bool = True,
    ):
        """Insert or update *model*.

        Returns:
            The stored row as *model_out* (default ``cls.model``) when
            *returning*, else ``None``.
        """
        result = await cls._execute(
            cls._make_save_query(
                model,
                index_elements=index_elements,
                returning=returning,
            )
        )
        if returning:
            item = result.one()
            model_out = model_out or cls.model
            return cls._from_db(item, data=cls._make_from_db_data(item), model_out=model_out)

    @classmethod
    def _make_save_many_query(
        cls,
        models: list,
        index_elements: Optional[list] = None,
        returning: Optional[bool] = None,
    ) -> sa.sql.Insert:
        """Build one multi-row upsert for *models*, optionally RETURNING the rows.

        NOTE(review): the SET clause is derived from ``models[0]`` only, so all
        models are assumed to be of the same kind (all create or all save models).
        """
        model_db = cls.model_db
        query = insert(model_db).values([cls._get_values_for_save_query(model) for model in models])
        query = query.on_conflict_do_update(
            index_elements=cls._resolve_index_elements(index_elements),
            set_=cls._get_on_conflict_do_update_set_for_save_query(query.excluded, models[0]),
        )
        if returning:
            query = query.returning(model_db)
        return query

    @classmethod
    async def save_many(
        cls,
        models: list,
        index_elements: Optional[list] = None,
        model_out: Optional[Type] = None,
        returning: bool = True,
    ) -> list | None:
        """Insert or update all *models* with a single statement.

        Returns:
            ``[]`` for empty input; otherwise the stored rows as *model_out*
            when *returning*, else ``None``.
        """
        if not models:
            return []
        result = await cls._execute(
            cls._make_save_many_query(
                models,
                index_elements=index_elements,
                returning=returning,
            )
        )
        if returning:
            from_db = cls._from_db
            make_from_db_data = cls._make_from_db_data
            model_out = model_out or cls.model
            return [from_db(item, data=make_from_db_data(item), model_out=model_out) for item in result.all()]
654
+
655
+
656
class _UpdateDM(_CreateDM, metaclass=_Meta, skip_checks=True):
    """Update operations of a data manager: ``update`` and ``update_many``.

    ``model_update`` is the input model; ``_fields_update`` (filled in by
    ``_Meta``) holds its fields that are also columns of ``model_db``.
    """

    model_update: Type = typing.cast(Type, None)
    _fields_update: set[str] = set()

    @classmethod
    def _get_values_for_update_query(cls, model) -> dict:
        """Return the column values to SET for *model* (implemented by subclasses)."""
        ...

    @classmethod
    def _make_update_query(cls, model, key: str, returning: Optional[bool] = None) -> sa.sql.Update:
        """Build ``UPDATE model_db SET ... WHERE <key> = model.<key>``.

        RETURNING yields the whole row when *returning*, otherwise just the key
        column (a single ``returning()`` call, so no column is returned twice).
        """
        model_db = cls.model_db
        key_column = getattr(model_db, key)
        query = (
            update(model_db)
            .values(cls._get_values_for_update_query(model))
            .where(key_column == getattr(model, key))
        )
        return query.returning(model_db if returning else key_column)

    @classmethod
    async def update(
        cls,
        model,
        key: str,
        raise_not_found: Optional[bool] = None,
        model_out: Optional[Type] = None,
        returning: bool = True,
    ):
        """Update the row whose *key* column equals ``getattr(model, key)``.

        Returns:
            The updated row as *model_out* (default ``cls.model``) when
            *returning* and a row matched, else ``None``.

        Raises:
            NotFoundError: if *raise_not_found* and no row was updated.
        """
        result = await cls._execute(cls._make_update_query(model, key, returning=returning))
        if raise_not_found and result.rowcount == 0:
            raise NotFoundError(key=key, value=getattr(model, key))
        if returning:
            item = result.one_or_none()
            if item:
                model_out = model_out or cls.model
                return cls._from_db(item, data=cls._make_from_db_data(item), model_out=model_out)

    @classmethod
    async def update_many(
        cls,
        models: list,
        key: str,
        model_out: Optional[Type] = None,
        returning: bool = True,
        raise_not_found: Optional[bool] = None,
    ) -> list | None:
        """Update every model in *models* inside one transaction on ``cls.engine_name``.

        Updates run one after another: they share the transaction's connection,
        and SQLAlchemy's async connections must not be used by concurrent tasks.

        Returns:
            One result per model (``None`` where nothing matched) when
            *returning*, else ``None``.

        Raises:
            NotFoundError: if *raise_not_found* and some model matched no row; the
                whole transaction is then rolled back.
        """
        results = []
        async with transaction(engine_name=cls.engine_name):
            for model in models:
                results.append(
                    await cls.update(
                        model,
                        key,
                        raise_not_found=raise_not_found,
                        model_out=model_out,
                        returning=returning,
                    )
                )
        if returning:
            return results
        return None
708
+
709
+
710
class _DeleteDM(_CreateDM, metaclass=_Meta, skip_checks=True):
    """Delete operation of a data manager."""

    @classmethod
    def _make_delete_query(cls, key, key_value, returning: Optional[bool] = None):
        """Build ``DELETE FROM model_db WHERE key = key_value``, optionally RETURNING rows."""
        stmt = delete(cls.model_db).where(key == key_value)
        return stmt.returning(cls.model_db) if returning else stmt

    @classmethod
    async def delete(
        cls,
        key_value,
        key=None,
        raise_not_found: Optional[bool] = None,
        model_out: Optional[Type] = None,
        returning: bool = True,
    ):
        """Delete the row(s) whose *key* (default ``cls.key``) equals *key_value*.

        Returns:
            The deleted row as *model_out* (default ``cls.model``) when
            *returning* and a row matched, else ``None``.

        Raises:
            NotFoundError: if *raise_not_found* and nothing was deleted.
        """
        column = cls._get_key(key=key)
        result = await cls._execute(cls._make_delete_query(column, key_value, returning=returning))
        if raise_not_found and result.rowcount == 0:
            raise NotFoundError(key=column, value=key_value)
        if not returning:
            return None
        deleted = result.one_or_none()
        if not deleted:
            return None
        target_model = model_out or cls.model
        return cls._from_db(deleted, data=cls._make_from_db_data(deleted), model_out=target_model)
@@ -0,0 +1,66 @@
1
+ import typing
2
+
3
+ from typing import Optional, Type
4
+
5
+ import sqlalchemy as sa
6
+
7
+ from cwtch import asdict, from_attributes
8
+
9
+ from .common import ModelDB, OrderBy, _CreateDM, _DeleteDM, _GetDM, _SaveDM, _UpdateDM
10
+
11
+
12
# Public API of this module.
__all__ = ["OrderBy", "GetDM", "CreateDM", "SaveDM", "UpdateDM", "DeleteDM", "DM"]
21
+
22
+
23
class GetDM(_GetDM, skip_checks=True):
    """Read-side data manager that converts database items with cwtch."""

    @classmethod
    def _from_db(
        cls,
        item: sa.Row | ModelDB,
        data: Optional[dict] = None,
        exclude: Optional[list] = None,
        suffix: Optional[str] = None,
        model_out=None,
    ):
        """Build ``model_out`` (default ``cls.model``) from ``item`` via cwtch ``from_attributes``.

        ``data``, ``exclude`` and ``suffix`` are forwarded unchanged, and
        ``reset_circular_refs=True`` is always passed.
        """
        target = model_out or cls.model
        return from_attributes(
            target,
            item,
            reset_circular_refs=True,
            data=data,
            exclude=exclude,
            suffix=suffix,
        )
41
+
42
+
43
class CreateDM(_CreateDM, skip_checks=True):
    """Create-side data manager for cwtch models."""

    @classmethod
    def _get_values_for_create_query(cls, model) -> dict:
        """Return ``model`` as a dict limited to ``cls._fields_create``."""
        create_fields = typing.cast(list[str], cls._fields_create)
        return asdict(model, include=create_fields)
47
+
48
+
49
class SaveDM(_SaveDM, skip_checks=True):
    """Save-side data manager for cwtch models."""

    @classmethod
    def _get_values_for_save_query(cls, model) -> dict:
        """Return ``model`` as a dict for the save query.

        Instances of ``cls.model_create`` are limited to ``cls._fields_create``;
        any other model is limited to ``cls._fields_save``.
        """
        is_create_model = isinstance(model, typing.cast(Type, cls.model_create))
        fields = cls._fields_create if is_create_model else cls._fields_save
        return asdict(model, include=typing.cast(list[str], fields))
55
+
56
+
57
class UpdateDM(_UpdateDM, skip_checks=True):
    """Update-side data manager for cwtch models."""

    @classmethod
    def _get_values_for_update_query(cls, model) -> dict:
        """Return ``model`` as a dict limited to ``cls._fields_update``, skipping unset fields."""
        update_fields = typing.cast(list[str], cls._fields_update)
        return asdict(model, include=update_fields, exclude_unset=True)
61
+
62
+
63
class DeleteDM(_DeleteDM, skip_checks=True):
    """Delete-side data manager; all behavior is inherited from ``_DeleteDM``."""
64
+
65
+
66
class DM(GetDM, CreateDM, SaveDM, UpdateDM, DeleteDM, skip_checks=True):
    """Combined data manager mixing in the get, create, save, update and delete managers."""
@@ -0,0 +1,76 @@
1
+ import typing
2
+
3
+ from typing import Optional, Type
4
+
5
+ import pydantic # noqa: F401 # type: ignore
6
+ import sqlalchemy as sa
7
+
8
+ from .common import ModelDB, OrderBy, _CreateDM, _DeleteDM, _GetDM, _SaveDM, _UpdateDM
9
+
10
+
11
# Public API of this module.
__all__ = ["OrderBy", "GetDM", "CreateDM", "SaveDM", "UpdateDM", "DeleteDM", "DM"]
20
+
21
+
22
class _object:
    """Bare attribute container; ``GetDM._from_db`` assigns its ``__dict__`` before validation."""
24
+
25
+
26
class GetDM(_GetDM, skip_checks=True):
    """Read-side data manager that validates database items into pydantic models."""

    @classmethod
    def _from_db(
        cls,
        item: sa.Row | ModelDB,
        data: Optional[dict] = None,
        exclude: Optional[list] = None,
        suffix: Optional[str] = None,
        model_out=None,
    ):
        """Validate ``item`` into ``model_out`` (default ``cls.model``).

        Source values come from ``item._asdict()`` for a SQLAlchemy row, or
        ``item.__dict__`` otherwise, with ``data`` merged on top.

        Args:
            item: A SQLAlchemy ``Row`` or a ``ModelDB`` instance.
            data: Extra values that override the item's values.
            exclude: Source keys to drop.
            suffix: When given, keep only keys ending with ``suffix`` and remove
                that suffix from them once (``"name_user"`` -> ``"name"`` for
                suffix ``"_user"``).
            model_out: Pydantic model class to validate into; defaults to ``cls.model``.

        Returns:
            The result of ``model_validate(..., from_attributes=True)``.
        """
        source = item._asdict() if isinstance(item, sa.Row) else item.__dict__
        values = {**source, **(data or {})}

        if exclude:
            values = {k: v for k, v in values.items() if k not in exclude}

        if suffix:
            # ``removesuffix`` drops the exact suffix once. ``rstrip(suffix)`` would
            # strip any trailing run of the suffix's *characters* and mangle names,
            # e.g. "name_user".rstrip("_user") == "nam".
            values = {k.removesuffix(suffix): v for k, v in values.items() if k.endswith(suffix)}

        obj = _object()
        obj.__dict__ = values
        return (model_out or cls.model).model_validate(obj, from_attributes=True)
51
+
52
+
53
class CreateDM(_CreateDM, skip_checks=True):
    """Create-side data manager for pydantic models."""

    @classmethod
    def _get_values_for_create_query(cls, model) -> dict:
        """Dump ``model`` limited to ``cls._fields_create``."""
        create_fields = typing.cast(list[str], cls._fields_create)
        return model.model_dump(include=create_fields)
57
+
58
+
59
class SaveDM(_SaveDM, skip_checks=True):
    """Save-side data manager for pydantic models."""

    @classmethod
    def _get_values_for_save_query(cls, model) -> dict:
        """Dump ``model`` for the save query.

        Instances of ``cls.model_create`` are limited to ``cls._fields_create``;
        any other model is limited to ``cls._fields_save``.
        """
        is_create_model = isinstance(model, typing.cast(Type, cls.model_create))
        fields = cls._fields_create if is_create_model else cls._fields_save
        return model.model_dump(include=typing.cast(list[str], fields))
65
+
66
+
67
class UpdateDM(_UpdateDM, skip_checks=True):
    """Update-side data manager for pydantic models."""

    @classmethod
    def _get_values_for_update_query(cls, model) -> dict:
        """Dump ``model`` limited to ``cls._fields_update``, skipping unset fields."""
        update_fields = typing.cast(list[str], cls._fields_update)
        return model.model_dump(include=update_fields, exclude_unset=True)
71
+
72
+
73
class DeleteDM(_DeleteDM, skip_checks=True):
    """Delete-side data manager; all behavior is inherited from ``_DeleteDM``."""
74
+
75
+
76
class DM(GetDM, CreateDM, SaveDM, UpdateDM, DeleteDM, skip_checks=True):
    """Combined data manager mixing in the get, create, save, update and delete managers."""
@@ -0,0 +1,51 @@
1
# ---- Package metadata -------------------------------------------------------
[tool.poetry]
name = "dbdm"
version = "0.1.0"
description = ""
authors = ["Roman Koshel <roma.koshel@gmail.com>"]
license = "MIT"
readme = "README.md"
keywords = []
packages = [
    { include = "dbdm" },
]
classifiers = [
    "Development Status :: 3 - Alpha",
    "Programming Language :: Python",
    "Programming Language :: Python :: 3 :: Only",
    "Topic :: Software Development :: Libraries :: Python Modules",
    "Topic :: Utilities",
    "License :: OSI Approved :: MIT License",
]

# ---- Runtime dependencies ---------------------------------------------------
[tool.poetry.dependencies]
python = ">=3.11, <4.0"
asyncpg = "*"
sqlalchemy = { version = "^2", extras = ["asyncio"] }

# ---- Test dependencies ------------------------------------------------------
# NOTE(review): the cwtch- and pydantic-based modules import these packages at
# runtime, yet they are only declared here; confirm whether they should be
# exposed as optional runtime extras.
[tool.poetry.group.test.dependencies]
coverage = "*"
cwtch = ">=0.10.0"
docker = "*"
pydantic = "^2"
pylint = "*"
pytest = "*"
pytest_asyncio = "*"
pytest_timeout = "*"

# ---- Documentation dependencies ---------------------------------------------
[tool.poetry.group.docs.dependencies]
mkdocs = "*"
mkdocs-material = "*"
mkdocstrings = { version = "*", extras = ["python"] }

# ---- Build backend ----------------------------------------------------------
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

# ---- Formatting tools -------------------------------------------------------
[tool.black]
line-length = 120

[tool.isort]
indent = 4
lines_after_imports = 2
lines_between_types = 1
src_paths = ["."]
profile = "black"