dbdm 0.1.0__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- dbdm-0.1.0/LICENSE +21 -0
- dbdm-0.1.0/PKG-INFO +191 -0
- dbdm-0.1.0/README.md +168 -0
- dbdm-0.1.0/dbdm/__init__.py +7 -0
- dbdm-0.1.0/dbdm/common.py +736 -0
- dbdm-0.1.0/dbdm/cwtch.py +66 -0
- dbdm-0.1.0/dbdm/pydantic.py +76 -0
- dbdm-0.1.0/pyproject.toml +51 -0
dbdm-0.1.0/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
1
|
+
MIT License
|
2
|
+
|
3
|
+
Copyright (c) 2023 Roma Koshel
|
4
|
+
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
7
|
+
in the Software without restriction, including without limitation the rights
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
10
|
+
furnished to do so, subject to the following conditions:
|
11
|
+
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
13
|
+
copies or substantial portions of the Software.
|
14
|
+
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21
|
+
SOFTWARE.
|
dbdm-0.1.0/PKG-INFO
ADDED
@@ -0,0 +1,191 @@
|
|
1
|
+
Metadata-Version: 2.1
|
2
|
+
Name: dbdm
|
3
|
+
Version: 0.1.0
|
4
|
+
Summary:
|
5
|
+
License: MIT
|
6
|
+
Author: Roman Koshel
|
7
|
+
Author-email: roma.koshel@gmail.com
|
8
|
+
Requires-Python: >=3.11,<4.0
|
9
|
+
Classifier: Development Status :: 3 - Alpha
|
10
|
+
Classifier: License :: OSI Approved :: MIT License
|
11
|
+
Classifier: Programming Language :: Python
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
15
|
+
Classifier: Programming Language :: Python :: 3.13
|
16
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
17
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
18
|
+
Classifier: Topic :: Utilities
|
19
|
+
Requires-Dist: asyncpg
|
20
|
+
Requires-Dist: sqlalchemy[asyncio] (>=2,<3)
|
21
|
+
Description-Content-Type: text/markdown
|
22
|
+
|
23
|
+
# DBDM [wip]
|
24
|
+
|
25
|
+
|
26
|
+
## Examples
|
27
|
+
|
28
|
+
```python
|
29
|
+
from typing import ClassVar, Type
|
30
|
+
|
31
|
+
import sqlalchemy as sa
|
32
|
+
|
33
|
+
from cwtch import dataclass, resolve_types, view
|
34
|
+
from cwtch.types import UNSET, Unset
|
35
|
+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
36
|
+
from sqlalchemy.ext.asyncio import create_async_engine
|
37
|
+
|
38
|
+
from dbdm import DM, NotFoundError, bind_engine
|
39
|
+
|
40
|
+
|
41
|
+
class BaseDB(DeclarativeBase):
|
42
|
+
pass
|
43
|
+
|
44
|
+
|
45
|
+
class ParentDB(BaseDB):
|
46
|
+
__tablename__ = "parents"
|
47
|
+
|
48
|
+
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
49
|
+
name: Mapped[str]
|
50
|
+
data: Mapped[str]
|
51
|
+
children = relationship("ChildDB", uselist=True, viewonly=True)
|
52
|
+
|
53
|
+
|
54
|
+
class ChildDB(BaseDB):
|
55
|
+
__tablename__ = "children"
|
56
|
+
|
57
|
+
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
58
|
+
name: Mapped[str]
|
59
|
+
parent_id: Mapped[int] = mapped_column(sa.ForeignKey("parents.id"))
|
60
|
+
parent = relationship("ParentDB", uselist=False, viewonly=True)
|
61
|
+
|
62
|
+
|
63
|
+
@dataclass(handle_circular_refs=True)
|
64
|
+
class Parent:
|
65
|
+
id: int
|
66
|
+
name: str
|
67
|
+
data: str
|
68
|
+
children: Unset[list["Child"]] = UNSET
|
69
|
+
|
70
|
+
# Parent views
|
71
|
+
Create: ClassVar[Type["ParentCreate"]]
|
72
|
+
Save: ClassVar[Type["ParentSave"]]
|
73
|
+
Update: ClassVar[Type["ParentUpdate"]]
|
74
|
+
|
75
|
+
|
76
|
+
@view(Parent, "Create", exclude=["id", "children"])
|
77
|
+
class ParentCreate:
|
78
|
+
pass
|
79
|
+
|
80
|
+
|
81
|
+
@view(Parent, "Save", exclude=["children"])
|
82
|
+
class ParentSave:
|
83
|
+
pass
|
84
|
+
|
85
|
+
|
86
|
+
@view(Parent, "Update", exclude=["children"])
|
87
|
+
class ParentUpdate:
|
88
|
+
name: Unset[str] = UNSET
|
89
|
+
data: Unset[str] = UNSET
|
90
|
+
|
91
|
+
|
92
|
+
@dataclass(handle_circular_refs=True)
|
93
|
+
class Child:
|
94
|
+
id: int
|
95
|
+
name: str
|
96
|
+
parent_id: int
|
97
|
+
parent: Unset[Parent] = UNSET
|
98
|
+
|
99
|
+
# Child views
|
100
|
+
Create: ClassVar[Type["ChildCreate"]]
|
101
|
+
Save: ClassVar[Type["ChildSave"]]
|
102
|
+
Update: ClassVar[Type["ChildUpdate"]]
|
103
|
+
|
104
|
+
|
105
|
+
@view(Child, "Create", exclude=["id", "parent"])
|
106
|
+
class ChildCreate:
|
107
|
+
pass
|
108
|
+
|
109
|
+
|
110
|
+
@view(Child, "Save", exclude=["parent"])
|
111
|
+
class ChildSave:
|
112
|
+
pass
|
113
|
+
|
114
|
+
|
115
|
+
@view(Child, "Update", exclude=["parent"])
|
116
|
+
class ChildUpdate:
|
117
|
+
name: Unset[str] = UNSET
|
118
|
+
parent_id: Unset[int] = UNSET
|
119
|
+
|
120
|
+
|
121
|
+
resolve_types(Parent, globals(), locals())
|
122
|
+
|
123
|
+
|
124
|
+
class ParentDM(DM):
|
125
|
+
model_db = ParentDB
|
126
|
+
model = Parent
|
127
|
+
model_create = Parent.Create
|
128
|
+
model_save = Parent.Save
|
129
|
+
model_update = Parent.Update
|
130
|
+
key = "id"
|
131
|
+
index_elements = ["id"]
|
132
|
+
joinedload = {"children": lambda m: sa.orm.joinedload(m.children)}
|
133
|
+
|
134
|
+
|
135
|
+
class ChildDM(DM):
|
136
|
+
model_db = ChildDB
|
137
|
+
model = Child
|
138
|
+
model_create = Child.Create
|
139
|
+
model_save = Child.Save
|
140
|
+
model_update = Child.Update
|
141
|
+
key = "id"
|
142
|
+
index_elements = ["id"]
|
143
|
+
joinedload = {"parent": lambda m: sa.orm.joinedload(m.parent)}
|
144
|
+
|
145
|
+
|
146
|
+
@pytest_asyncio.fixture
|
147
|
+
async def create_all(engine):
|
148
|
+
async with engine.begin() as conn:
|
149
|
+
await conn.run_sync(BaseDB.metadata.create_all)
|
150
|
+
|
151
|
+
|
152
|
+
async def example(engine):
|
153
|
+
engine = create_async_engine(...)
|
154
|
+
|
155
|
+
async with engine.begin() as conn:
|
156
|
+
await conn.run_sync(BaseDB.metadata.create_all)
|
157
|
+
|
158
|
+
bind_engine(engine)
|
159
|
+
|
160
|
+
parent = await ParentDM.create(Parent.Create(name=f"Parent_{i}", data="data"))
|
161
|
+
|
162
|
+
# parents: list[Parent]
|
163
|
+
total, parents = await ParentDM.get_many()
|
164
|
+
|
165
|
+
# parents: list[Parent]
|
166
|
+
total, parents = await ParentDM.get_many(page_size=1)
|
167
|
+
|
168
|
+
# parents: list[Parent]
|
169
|
+
total, parents = await ParentDM.get_many(page_size=1, page=2)
|
170
|
+
|
171
|
+
# parent: Parent
|
172
|
+
parent = await ParentDM.get(1)
|
173
|
+
|
174
|
+
# parent: Parent
|
175
|
+
parent = await ParentDM.get(1, joinedload={"children": True})
|
176
|
+
|
177
|
+
# parent: Parent
|
178
|
+
parent = await ParentDM.save(parent.Save())
|
179
|
+
|
180
|
+
# parent: Parent
|
181
|
+
parent = await ParentDM.update(Parent.Update(id=1, data="new data"), key="id")
|
182
|
+
|
183
|
+
await ParentDM.delete(1)
|
184
|
+
|
185
|
+
# child: Child
|
186
|
+
child = await ChildDM.create(Child.Create(name=f"Child_{i}", parent_id=i))
|
187
|
+
|
188
|
+
# parent: Parent
|
189
|
+
parent = await ParentDM.get(1, joinedload={"children": True})
|
190
|
+
```
|
191
|
+
|
dbdm-0.1.0/README.md
ADDED
@@ -0,0 +1,168 @@
|
|
1
|
+
# DBDM [wip]
|
2
|
+
|
3
|
+
|
4
|
+
## Examples
|
5
|
+
|
6
|
+
```python
|
7
|
+
from typing import ClassVar, Type
|
8
|
+
|
9
|
+
import sqlalchemy as sa
|
10
|
+
|
11
|
+
from cwtch import dataclass, resolve_types, view
|
12
|
+
from cwtch.types import UNSET, Unset
|
13
|
+
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
14
|
+
from sqlalchemy.ext.asyncio import create_async_engine
|
15
|
+
|
16
|
+
from dbdm import DM, NotFoundError, bind_engine
|
17
|
+
|
18
|
+
|
19
|
+
class BaseDB(DeclarativeBase):
|
20
|
+
pass
|
21
|
+
|
22
|
+
|
23
|
+
class ParentDB(BaseDB):
|
24
|
+
__tablename__ = "parents"
|
25
|
+
|
26
|
+
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
27
|
+
name: Mapped[str]
|
28
|
+
data: Mapped[str]
|
29
|
+
children = relationship("ChildDB", uselist=True, viewonly=True)
|
30
|
+
|
31
|
+
|
32
|
+
class ChildDB(BaseDB):
|
33
|
+
__tablename__ = "children"
|
34
|
+
|
35
|
+
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
36
|
+
name: Mapped[str]
|
37
|
+
parent_id: Mapped[int] = mapped_column(sa.ForeignKey("parents.id"))
|
38
|
+
parent = relationship("ParentDB", uselist=False, viewonly=True)
|
39
|
+
|
40
|
+
|
41
|
+
@dataclass(handle_circular_refs=True)
|
42
|
+
class Parent:
|
43
|
+
id: int
|
44
|
+
name: str
|
45
|
+
data: str
|
46
|
+
children: Unset[list["Child"]] = UNSET
|
47
|
+
|
48
|
+
# Parent views
|
49
|
+
Create: ClassVar[Type["ParentCreate"]]
|
50
|
+
Save: ClassVar[Type["ParentSave"]]
|
51
|
+
Update: ClassVar[Type["ParentUpdate"]]
|
52
|
+
|
53
|
+
|
54
|
+
@view(Parent, "Create", exclude=["id", "children"])
|
55
|
+
class ParentCreate:
|
56
|
+
pass
|
57
|
+
|
58
|
+
|
59
|
+
@view(Parent, "Save", exclude=["children"])
|
60
|
+
class ParentSave:
|
61
|
+
pass
|
62
|
+
|
63
|
+
|
64
|
+
@view(Parent, "Update", exclude=["children"])
|
65
|
+
class ParentUpdate:
|
66
|
+
name: Unset[str] = UNSET
|
67
|
+
data: Unset[str] = UNSET
|
68
|
+
|
69
|
+
|
70
|
+
@dataclass(handle_circular_refs=True)
|
71
|
+
class Child:
|
72
|
+
id: int
|
73
|
+
name: str
|
74
|
+
parent_id: int
|
75
|
+
parent: Unset[Parent] = UNSET
|
76
|
+
|
77
|
+
# Child views
|
78
|
+
Create: ClassVar[Type["ChildCreate"]]
|
79
|
+
Save: ClassVar[Type["ChildSave"]]
|
80
|
+
Update: ClassVar[Type["ChildUpdate"]]
|
81
|
+
|
82
|
+
|
83
|
+
@view(Child, "Create", exclude=["id", "parent"])
|
84
|
+
class ChildCreate:
|
85
|
+
pass
|
86
|
+
|
87
|
+
|
88
|
+
@view(Child, "Save", exclude=["parent"])
|
89
|
+
class ChildSave:
|
90
|
+
pass
|
91
|
+
|
92
|
+
|
93
|
+
@view(Child, "Update", exclude=["parent"])
|
94
|
+
class ChildUpdate:
|
95
|
+
name: Unset[str] = UNSET
|
96
|
+
parent_id: Unset[int] = UNSET
|
97
|
+
|
98
|
+
|
99
|
+
resolve_types(Parent, globals(), locals())
|
100
|
+
|
101
|
+
|
102
|
+
class ParentDM(DM):
|
103
|
+
model_db = ParentDB
|
104
|
+
model = Parent
|
105
|
+
model_create = Parent.Create
|
106
|
+
model_save = Parent.Save
|
107
|
+
model_update = Parent.Update
|
108
|
+
key = "id"
|
109
|
+
index_elements = ["id"]
|
110
|
+
joinedload = {"children": lambda m: sa.orm.joinedload(m.children)}
|
111
|
+
|
112
|
+
|
113
|
+
class ChildDM(DM):
|
114
|
+
model_db = ChildDB
|
115
|
+
model = Child
|
116
|
+
model_create = Child.Create
|
117
|
+
model_save = Child.Save
|
118
|
+
model_update = Child.Update
|
119
|
+
key = "id"
|
120
|
+
index_elements = ["id"]
|
121
|
+
joinedload = {"parent": lambda m: sa.orm.joinedload(m.parent)}
|
122
|
+
|
123
|
+
|
124
|
+
@pytest_asyncio.fixture
|
125
|
+
async def create_all(engine):
|
126
|
+
async with engine.begin() as conn:
|
127
|
+
await conn.run_sync(BaseDB.metadata.create_all)
|
128
|
+
|
129
|
+
|
130
|
+
async def example(engine):
|
131
|
+
engine = create_async_engine(...)
|
132
|
+
|
133
|
+
async with engine.begin() as conn:
|
134
|
+
await conn.run_sync(BaseDB.metadata.create_all)
|
135
|
+
|
136
|
+
bind_engine(engine)
|
137
|
+
|
138
|
+
parent = await ParentDM.create(Parent.Create(name=f"Parent_{i}", data="data"))
|
139
|
+
|
140
|
+
# parents: list[Parent]
|
141
|
+
total, parents = await ParentDM.get_many()
|
142
|
+
|
143
|
+
# parents: list[Parent]
|
144
|
+
total, parents = await ParentDM.get_many(page_size=1)
|
145
|
+
|
146
|
+
# parents: list[Parent]
|
147
|
+
total, parents = await ParentDM.get_many(page_size=1, page=2)
|
148
|
+
|
149
|
+
# parent: Parent
|
150
|
+
parent = await ParentDM.get(1)
|
151
|
+
|
152
|
+
# parent: Parent
|
153
|
+
parent = await ParentDM.get(1, joinedload={"children": True})
|
154
|
+
|
155
|
+
# parent: Parent
|
156
|
+
parent = await ParentDM.save(parent.Save())
|
157
|
+
|
158
|
+
# parent: Parent
|
159
|
+
parent = await ParentDM.update(Parent.Update(id=1, data="new data"), key="id")
|
160
|
+
|
161
|
+
await ParentDM.delete(1)
|
162
|
+
|
163
|
+
# child: Child
|
164
|
+
child = await ChildDM.create(Child.Create(name=f"Child_{i}", parent_id=i))
|
165
|
+
|
166
|
+
# parent: Parent
|
167
|
+
parent = await ParentDM.get(1, joinedload={"children": True})
|
168
|
+
```
|
@@ -0,0 +1,736 @@
|
|
1
|
+
import re
|
2
|
+
import typing
|
3
|
+
|
4
|
+
from asyncio import gather
|
5
|
+
from contextlib import asynccontextmanager
|
6
|
+
from contextvars import ContextVar
|
7
|
+
from dataclasses import dataclass
|
8
|
+
from typing import Any, AsyncIterator, Callable, Literal, Never, Optional, Protocol, Type, TypeVar
|
9
|
+
|
10
|
+
import sqlalchemy as sa
|
11
|
+
|
12
|
+
from sqlalchemy import delete, func, literal, select, union_all, update
|
13
|
+
from sqlalchemy.dialects.postgresql import insert
|
14
|
+
from sqlalchemy.ext.asyncio import async_sessionmaker
|
15
|
+
from sqlalchemy.orm import DeclarativeBase, aliased
|
16
|
+
|
17
|
+
|
18
|
+
# Public API of the module; internal DM base classes are deliberately omitted.
__all__ = [
    "DMError",
    "AlreadyExistsError",
    "BadParamsError",
    "NotFoundError",
    "OrderBy",
    "bind_engine",
    "transaction",
]


# Registry of connection context variables, one per bound engine name
# (None key = the default, unnamed engine).  Each ContextVar holds the
# connection of the currently open transaction, or None.
_conn: dict[Optional[str], ContextVar] = {}
# Registry of engines registered through bind_engine(), keyed by name.
_engine: dict[Optional[str], sa.ext.asyncio.AsyncEngine] = {}
|
31
|
+
|
32
|
+
|
33
|
+
class DMError(Exception):
    """Base class for every error raised by the data managers."""
|
35
|
+
|
36
|
+
|
37
|
+
class BadParamsError(DMError):
    """Raised when caller-supplied parameters are invalid.

    Attributes:
        message: Human-readable description of the problem.
        param: Name of the offending parameter, when known.
    """

    def __init__(self, message: str, param: Optional[str] = None):
        # Forward the arguments to the base class so ``e.args`` is populated
        # and the exception round-trips through pickling (the original left
        # args empty, which breaks unpickling and makes repr uninformative).
        super().__init__(message, param)
        self.message = message
        self.param = param

    def __str__(self):
        return self.message

    def __repr__(self):
        return self.__str__()
|
47
|
+
|
48
|
+
|
49
|
+
class NotFoundError(DMError):
    """Raised when no item with ``(key)=(value)`` exists in the database."""

    def __init__(self, key, value):
        # Forward args to the base class so the exception pickles correctly
        # (Exception pickling re-calls __init__ with ``e.args``).
        super().__init__(key, value)
        self.key = key
        self.value = value

    def __str__(self):
        return f"item with ({self.key})=({self.value}) not found"
|
56
|
+
|
57
|
+
|
58
|
+
class AlreadyExistsError(DMError):
    """Raised when an insert violates a unique constraint on ``(key)=(value)``."""

    def __init__(self, key, value):
        # Forward args to the base class so the exception pickles correctly
        # (Exception pickling re-calls __init__ with ``e.args``).
        super().__init__(key, value)
        self.key = key
        self.value = value

    def __str__(self):
        return f"key ({self.key})=({self.value}) already exists"
|
65
|
+
|
66
|
+
|
67
|
+
@dataclass
class OrderBy:
    """A single sorting instruction: what to sort by and in which direction."""

    # A column object or the string name of an attribute on the DB model.
    by: Any
    # SQL sort direction; resolved to sa.asc / sa.desc by the data manager.
    order: Literal["asc", "desc"] = "asc"
|
71
|
+
|
72
|
+
|
73
|
+
def bind_engine(engine: sa.ext.asyncio.AsyncEngine, name: Optional[str] = None):
    """Register *engine* under *name* and create its connection context var.

    Only the PostgreSQL dialect is accepted: the data managers build
    PostgreSQL-specific ``INSERT ... ON CONFLICT`` statements.
    """
    dialect = engine.dialect.name
    if dialect != "postgresql":
        raise DMError("only 'postgresql' dialect is supported")
    _engine[name] = engine
    _conn[name] = ContextVar("conn", default=None)
|
78
|
+
|
79
|
+
|
80
|
+
@asynccontextmanager
async def transaction(engine_name: Optional[str] = None) -> AsyncIterator[sa.ext.asyncio.AsyncConnection]:
    """Yield a transactional connection for the named engine.

    When called inside an already-open transaction the active connection is
    reused, so nested ``transaction()`` blocks share one transaction.
    Otherwise a new connection is opened, a transaction is begun, and the
    connection is published through the engine's context variable for the
    duration of the block.
    """
    if (conn := _conn[engine_name].get()) is not None:
        # Re-entrant call: reuse the connection of the enclosing transaction.
        yield conn
        return
    async with _engine[engine_name].connect() as conn:
        async with conn.begin():
            # Use the reset token instead of set(None) on exit: reset()
            # restores whatever value the context variable had before this
            # call, rather than unconditionally clobbering it with None.
            token = _conn[engine_name].set(conn)
            try:
                yield conn
            finally:
                _conn[engine_name].reset(token)
|
92
|
+
|
93
|
+
|
94
|
+
@asynccontextmanager
async def get_conn(engine_name: Optional[str] = None) -> AsyncIterator[sa.ext.asyncio.AsyncConnection]:
    """Yield the current transactional connection, opening one if needed.

    Unlike :func:`transaction`, a freshly opened connection is *not*
    published to the context variable, so sibling calls do not share it.
    """
    active = _conn[engine_name].get()
    if active is not None:
        # Inside an open transaction() block: reuse its connection.
        yield active
        return
    async with _engine[engine_name].connect() as fresh:
        async with fresh.begin():
            yield fresh
|
102
|
+
|
103
|
+
|
104
|
+
@asynccontextmanager
async def get_sess(engine_name: Optional[str] = None):
    """Yield an ``AsyncSession`` bound to the current (or a new) connection.

    The session joins whatever connection :func:`get_conn` provides, so ORM
    work participates in an enclosing ``transaction()`` when one is open.
    """
    async with get_conn(engine_name=engine_name) as conn:
        session_factory = async_sessionmaker(conn, expire_on_commit=False)
        async with session_factory() as sess, sess.begin():
            yield sess
|
111
|
+
|
112
|
+
|
113
|
+
def _raise_exc(e: Exception) -> Never:
    """Re-raise *e*, translating PostgreSQL unique violations on the way.

    An ``IntegrityError`` whose DETAIL text reads
    ``Key (k)=(v) already exists.`` becomes ``AlreadyExistsError(k, v)``;
    every other exception propagates unchanged.
    """
    if isinstance(e, sa.exc.IntegrityError):
        # NOTE(review): the pattern assumes DETAIL sits on the *second* line
        # of the driver message ('.' does not cross newlines) — holds for
        # asyncpg-formatted errors; confirm for other drivers.
        detail = re.match(r".*\nDETAIL:\s*(?P<text>.*)$", e.orig.args[0])  # type: ignore
        if detail:
            found = re.match(
                r"Key \((?P<key>.*)\)=\((?P<key_value>.*)\) already exists.",
                detail.groupdict()["text"].strip(),
            )
            if found:
                groups = found.groupdict()
                raise AlreadyExistsError(groups["key"], groups["key_value"].strip('\\"'))
    raise e
|
124
|
+
|
125
|
+
|
126
|
+
# Type variable covering any SQLAlchemy declarative model class.
ModelDB = TypeVar("ModelDB", bound=DeclarativeBase)
|
127
|
+
|
128
|
+
|
129
|
+
class _BaseProtocol(Protocol):
    """Structural interface shared by all DM mixins.

    Declares the class attributes and hook methods that concrete DM classes
    provide; the mixin classes (_GetDM, _CreateDM, ...) implement or override
    them.  All methods here are stubs — only signatures matter.
    """

    # Name of the engine registered via bind_engine(); None = default engine.
    engine_name: Optional[str]

    # SQLAlchemy declarative model backing this data manager.
    model_db: Type[DeclarativeBase]
    # Domain model (dataclass or pydantic model) returned to callers.
    model: Type

    @classmethod
    async def _execute(cls, query) -> sa.ResultProxy: ...

    @classmethod
    def _make_from_db_data(
        cls,
        db_item,
        row: Optional[tuple] = None,
        joinedload: Optional[dict] = None,
    ) -> dict: ...

    @classmethod
    def _from_db(
        cls,
        item: sa.Row | ModelDB,
        data: Optional[dict] = None,
        exclude: Optional[list] = None,
        suffix: Optional[str] = None,
        model_out=None,
    ): ...

    @classmethod
    def _get_key(cls, key=None): ...

    @classmethod
    async def get(
        cls,
        key_value,
        key=None,
        raise_not_found: Optional[bool] = None,
        joinedload: Optional[dict] = None,
        model_out: Optional[Type] = None,
        **kwds,
    ) -> Optional[type]: ...

    @classmethod
    async def get_many(
        cls,
        flt: Optional[sa.sql.elements.BinaryExpression] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        order_by: Optional[list[OrderBy | str] | OrderBy | str] = None,
        joinedload: Optional[dict] = None,
        model_out: Optional[Type] = None,
        **kwds,
    ) -> tuple[int, list]: ...
|
181
|
+
|
182
|
+
|
183
|
+
class _Meta(typing._ProtocolMeta):
    """Metaclass behind every DM class.

    At class-creation time it validates that the declared models are present
    and look like dataclass/pydantic models, and precomputes the sets of DB
    column names shared with each model (``_fields_create``, ``_fields_save``,
    ``_fields_update``).  ``skip_checks=True`` is used by the internal base
    classes themselves (_GetDM, _CreateDM, ...), which declare no concrete
    models.
    """

    def __new__(cls, name, bases, ns, skip_checks: bool = False):
        if not skip_checks:

            def get_all_bases(bases: tuple) -> set:
                # Transitively collect every ancestor class of `bases`.
                result = set()
                for base in bases:
                    result.add(base)
                    result |= get_all_bases(getattr(base, "__bases__", ()))
                return result

            all_bases = get_all_bases(bases)

            if _GetDM not in all_bases:
                raise TypeError("Any DM class should be subclassed from GetDM class")

            if (model_db := ns.get("model_db")) is None:
                raise ValueError("'model_db' field is required")

            if (model := ns.get("model")) is None:
                raise ValueError("'model' field is required")

            # A "valid model" is either a dataclass (__dataclass_fields__)
            # or a pydantic model (__pydantic_fields__).
            if (
                getattr(model, "__dataclass_fields__", None) is None
                and getattr(model, "__pydantic_fields__", None) is None
            ):
                raise ValueError("'model' is not a valid model")

            # _SaveDM subclasses _CreateDM, so saving also needs model_create.
            if _CreateDM in all_bases or _SaveDM in all_bases:
                if (model_create := ns.get("model_create")) is None:
                    raise ValueError("'model_create' field is required")
                if (
                    getattr(model_create, "__dataclass_fields__", None) is None
                    and getattr(model_create, "__pydantic_fields__", None) is None
                ):
                    raise ValueError("'model_create' is not a valid model")
                # Only columns that exist both on the table and on the model
                # participate in INSERT statements.
                ns["_fields_create"] = set(model_db.__table__.columns.keys()) & set(
                    getattr(model_create, "__dataclass_fields__", None)
                    or getattr(model_create, "__pydantic_fields__", None)
                    or set()
                )
                if not ns["_fields_create"]:
                    raise DMError("model_create does not contain valid fields")

            if _SaveDM in all_bases:
                if (model_save := ns.get("model_save")) is None:
                    raise ValueError("'model_save' field is required")
                if (
                    getattr(model_save, "__dataclass_fields__", None) is None
                    and getattr(model_save, "__pydantic_fields__", None) is None
                ):
                    raise ValueError("'model_save' is not a valid model")
                ns["_fields_save"] = set(model_db.__table__.columns.keys()) & set(
                    getattr(model_save, "__dataclass_fields__", None)
                    or getattr(model_save, "__pydantic_fields__", None)
                    or set()
                )
                if not ns["_fields_save"]:
                    raise DMError("model_save does not contain valid fields")

            if _UpdateDM in all_bases:
                if (model_update := ns.get("model_update")) is None:
                    raise ValueError("'model_update' field is required")
                if (
                    getattr(model_update, "__dataclass_fields__", None) is None
                    and getattr(model_update, "__pydantic_fields__", None) is None
                ):
                    raise ValueError("'model_update' is not a valid model")
                ns["_fields_update"] = set(model_db.__table__.columns.keys()) & set(
                    getattr(model_update, "__dataclass_fields__", None)
                    or getattr(model_update, "__pydantic_fields__", None)
                    or set()
                )
                if not ns["_fields_update"]:
                    raise DMError("model_update does not contain valid fields")

        return super().__new__(cls, name, bases, ns)
|
260
|
+
|
261
|
+
|
262
|
+
class _GetDM(_BaseProtocol, metaclass=_Meta, skip_checks=True):
    """Read-side mixin: single-item lookup and paginated many-item queries."""

    engine_name: Optional[str] = None

    # Placeholders — concrete subclasses must set real values (enforced by _Meta).
    model_db: Type[DeclarativeBase] = typing.cast(Type[DeclarativeBase], None)
    model: Type = typing.cast(Type, None)

    # Default lookup key (column name or column object) for get()/update().
    key = None
    # Default ordering; may be a callable taking the column namespace.
    order_by: Optional[list[OrderBy | str | Callable] | OrderBy | str | Callable] = None

    # Map of relation name -> callable producing a joinedload option for it.
    joinedload: dict = {}

    @classmethod
    async def _execute(cls, query) -> sa.ResultProxy:
        """Run *query* on a core connection, translating integrity errors."""
        async with get_conn(engine_name=cls.engine_name) as conn:
            try:
                return await conn.execute(query)
            except Exception as e:
                _raise_exc(e)

    @classmethod
    def _make_from_db_data(
        cls,
        db_item,
        row: Optional[tuple] = None,
        joinedload: Optional[dict] = None,
    ) -> dict:
        # Hook: extra keyword data merged into the output model; no-op here.
        return {}

    @classmethod
    def _from_db(
        cls,
        item: sa.Row | ModelDB,
        data: Optional[dict] = None,
        exclude: Optional[list] = None,
        suffix: Optional[str] = None,
        model_out=None,
    ): ...
    # Hook: convert a DB row/ORM object into the domain model.  Implemented
    # by the per-model-framework mixins (cwtch / pydantic modules).

    @classmethod
    def _get_key(cls, key=None):
        """Resolve *key* (or the class default) to a model_db column."""
        if key is None:
            key = cls.key
        if key is None:
            raise ValueError("key is None")
        if isinstance(key, str):
            key = getattr(cls.model_db, key)
        return key

    @classmethod
    def _get_order_by(cls, c, order_by: Optional[list[OrderBy | str] | OrderBy | str] = None) -> list:
        """Normalize *order_by* into a list of sa.asc()/sa.desc() clauses.

        *c* is the namespace on which string column names are resolved
        (the model class or a CTE's ``.c`` collection).
        """
        _order_by = order_by
        if _order_by is None:
            if callable(cls.order_by):
                _order_by = cls.order_by(c)
            else:
                # Fall back to ordering by the default key.
                _order_by = cls.order_by or cls._get_key()
        if _order_by is not None:
            if not isinstance(_order_by, list):
                _order_by = [_order_by]
            else:
                # Copy so the caller's (or the class-level) list is not mutated.
                _order_by = list(_order_by)
            for i, x in enumerate(_order_by):
                if isinstance(x, OrderBy):
                    # Copy before mutating x.by below.
                    x = OrderBy(by=x.by, order=x.order)
                else:
                    x = OrderBy(by=x)
                if isinstance(x.by, str):
                    if getattr(c, x.by, None) is None:
                        raise BadParamsError(f"invalid order_by '{x.by}'", param=f"{x.by}")
                    x.by = getattr(c, x.by)
                _order_by[i] = getattr(sa, x.order)(x.by)
        return typing.cast(list, _order_by)

    @classmethod
    def _make_get_query(cls, key, key_value, **kwds) -> sa.sql.Select:
        # Hook: subclasses may extend with extra filters via **kwds.
        return select(cls.model_db).where(key == key_value)

    @classmethod
    async def get(
        cls,
        key_value,
        key=None,
        raise_not_found: Optional[bool] = None,
        joinedload: Optional[dict] = None,
        model_out: Optional[Type] = None,
        **kwds,
    ) -> Optional[type]:
        """Fetch one item by key.

        Without joinedload the query runs on a core connection; with it, an
        ORM session is used so relationship loading options apply.  Returns
        None for a missing item unless *raise_not_found* is truthy.
        """
        key = cls._get_key(key=key)
        query = cls._make_get_query(key, key_value, **kwds)
        model_out = model_out or cls.model

        if joinedload is None or not any(joinedload.values()):
            # Fast path: plain core execution, no relationships.
            item = (await cls._execute(query)).one_or_none()
            if item:
                return cls._from_db(item, data=cls._make_from_db_data(item), model_out=model_out)
            if raise_not_found:
                raise NotFoundError(key=key, value=key_value)
            return

        # Relations not requested are excluded from the output model.
        exclude = []
        for k, v in cls.joinedload.items():
            if joinedload.get(k) is True:
                query = query.options(v(cls.model_db))
            else:
                exclude.append(k)

        async with get_sess(engine_name=cls.engine_name) as sess:
            item = (await sess.execute(query)).unique().scalars().one_or_none()
            if item:
                return cls._from_db(
                    item,
                    exclude=exclude,
                    data=cls._make_from_db_data(item, joinedload=joinedload),
                    model_out=model_out,
                )
            if raise_not_found:
                raise NotFoundError(key=key, value=key_value)

    @classmethod
    def _make_get_many_query(
        cls,
        flt: Optional[sa.sql.elements.BinaryExpression] = None,
        order_by: Optional[list[OrderBy | str] | OrderBy | str] = None,
        **kwds,
    ) -> sa.sql.Select:
        # count(*) OVER () yields the pre-pagination total on every row.
        query = select(func.count(literal("*")).over().label("rows_total"), cls.model_db)
        if flt is not None:
            query = query.where(flt)
        query = query.order_by(*cls._get_order_by(cls.model_db, order_by))
        return query

    @classmethod
    async def get_many(
        cls,
        flt: Optional[sa.sql.elements.BinaryExpression] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        order_by: Optional[list[OrderBy | str] | OrderBy | str] = None,
        joinedload: Optional[dict] = None,
        model_out: Optional[Type] = None,
        **kwds,
    ) -> tuple[int, list]:
        """Return ``(total, items)`` with optional pagination and relations.

        The base query is wrapped in a CTE; a sentinel row tagged i=0 is
        union-ed in (LIMIT 1, before pagination) so ``rows_total`` is
        available even when the requested page is empty.  Rows tagged i=1
        are the actual page of data.
        """
        model_db = cls.model_db
        model_out = model_out or cls.model

        query = cls._make_get_many_query(flt=flt, order_by=order_by, **kwds)

        cte = query.cte("cte")

        query = select(literal(1).label("i"), cte)

        if page_size:
            page = page or 1  # pages are 1-based
            query = query.limit(page_size).offset((page - 1) * page_size)

        # Sentinel row carrying rows_total, ordered first via the i column.
        query = union_all(select(literal(0).label("i"), cte).limit(1), query)

        from_db = cls._from_db
        make_from_orm_data = cls._make_from_db_data

        if joinedload is None or not any(joinedload.values()):
            rows = (await cls._execute(query)).all()
            # rows[0] is the sentinel; data rows start at index 1.
            return rows[0].rows_total if rows else 0, [
                from_db(row, data=make_from_orm_data(row), model_out=model_out) for row in rows[1:]
            ]

        main_cte = query.cte("main_cte")

        # Alias the ORM entity onto the paginated CTE so relationship
        # loader options can be attached.
        model_db_alias = aliased(model_db, main_cte)
        query = select(main_cte, model_db_alias)

        exclude = []
        for k, v in cls.joinedload.items():
            if joinedload.get(k) is True:
                query = query.options(v(model_db_alias))
            else:
                exclude.append(k)

        # Re-apply ordering: sentinel first, then the requested order.
        query = query.order_by(sa.asc(main_cte.c.i), *cls._get_order_by(main_cte.c, order_by))

        def _hash(row):
            # Deduplicate joined rows by (i, rows_total, ORM entity).
            return hash((row[0], row[1], row[-1]))

        async with get_sess(engine_name=cls.engine_name) as sess:
            rows = (await sess.execute(query)).unique(_hash).all()
            return rows[0].rows_total if rows else 0, [
                from_db(
                    row[-1],
                    exclude=exclude,
                    data=make_from_orm_data(row[-1], row=typing.cast(tuple, row), joinedload=joinedload),
                    model_out=model_out,
                )
                for row in rows[1:]
            ]
|
456
|
+
|
457
|
+
|
458
|
+
class _CreateDM(_BaseProtocol, metaclass=_Meta, skip_checks=True):
|
459
|
+
model_create: Type = typing.cast(Type, None)
|
460
|
+
|
461
|
+
index_elements: Optional[list] = None
|
462
|
+
|
463
|
+
_fields_create: set[str] = set()
|
464
|
+
|
465
|
+
@classmethod
|
466
|
+
def _get_values_for_create_query(cls, model) -> dict: ...
|
467
|
+
|
468
|
+
@classmethod
|
469
|
+
def _make_create_query(cls, model, returning: Optional[bool] = None) -> sa.sql.Insert:
|
470
|
+
model_db = cls.model_db
|
471
|
+
query = insert(model_db).values(cls._get_values_for_create_query(model))
|
472
|
+
if returning:
|
473
|
+
query = query.returning(model_db)
|
474
|
+
return query
|
475
|
+
|
476
|
+
@classmethod
|
477
|
+
async def create(
|
478
|
+
cls,
|
479
|
+
model,
|
480
|
+
model_out: Optional[Type] = None,
|
481
|
+
returning: bool = True,
|
482
|
+
):
|
483
|
+
result = await cls._execute(cls._make_create_query(model, returning=returning))
|
484
|
+
if returning:
|
485
|
+
item = result.one()
|
486
|
+
model_out = model_out or cls.model
|
487
|
+
return cls._from_db(item, data=cls._make_from_db_data(item), model_out=model_out)
|
488
|
+
|
489
|
+
@classmethod
|
490
|
+
def _make_create_many_query(cls, models: list, returning: bool | None = None) -> sa.sql.Insert:
|
491
|
+
model_db = cls.model_db
|
492
|
+
query = insert(cls.model_db).values([cls._get_values_for_create_query(model) for model in models])
|
493
|
+
if returning:
|
494
|
+
query = query.returning(model_db)
|
495
|
+
return query
|
496
|
+
|
497
|
+
@classmethod
|
498
|
+
async def create_many(
|
499
|
+
cls,
|
500
|
+
models: list,
|
501
|
+
model_out: Optional[Type] = None,
|
502
|
+
returning: bool = True,
|
503
|
+
) -> list | None:
|
504
|
+
if not models:
|
505
|
+
return []
|
506
|
+
result = await cls._execute(cls._make_create_many_query(models, returning=returning))
|
507
|
+
if returning:
|
508
|
+
from_db = cls._from_db
|
509
|
+
make_from_db_data = cls._make_from_db_data
|
510
|
+
model_out = model_out or cls.model
|
511
|
+
return [from_db(item, data=make_from_db_data(item), model_out=model_out) for item in result.all()]
|
512
|
+
|
513
|
+
@classmethod
|
514
|
+
def _make_get_or_create_query(
|
515
|
+
cls,
|
516
|
+
model,
|
517
|
+
update_element: str,
|
518
|
+
index_elements: Optional[list] = None,
|
519
|
+
) -> sa.sql.Insert:
|
520
|
+
model_db = cls.model_db
|
521
|
+
query = insert(model_db).values(cls._get_values_for_create_query(model))
|
522
|
+
index_elements = index_elements or cls.index_elements
|
523
|
+
if index_elements is None:
|
524
|
+
raise ValueError("index_elements is None")
|
525
|
+
index_elements = list(map(lambda e: isinstance(e, str) and getattr(model_db, e) or e, index_elements))
|
526
|
+
return query.on_conflict_do_update(
|
527
|
+
index_elements=index_elements,
|
528
|
+
set_={update_element: getattr(query.excluded, update_element)},
|
529
|
+
).returning(model_db)
|
530
|
+
|
531
|
+
@classmethod
async def get_or_create(
    cls,
    model,
    update_element: str,
    index_elements: Optional[list] = None,
    model_out: Optional[Type] = None,
):
    """Upsert *model* and return the resulting row as *model_out*.

    Delegates to ``_make_get_or_create_query`` and converts the single
    returned row via ``_from_db``.
    """
    out_model = model_out or cls.model
    query = cls._make_get_or_create_query(
        model,
        update_element,
        index_elements=index_elements,
    )
    result = await cls._execute(query)
    row = result.one()
    return cls._from_db(row, data=cls._make_from_db_data(row), model_out=out_model)
|
550
|
+
|
551
|
+
|
552
|
+
class _SaveDM(_CreateDM, metaclass=_Meta, skip_checks=True):
    """Mixin adding upsert ("save") operations on top of ``_CreateDM``.

    A save is an INSERT ... ON CONFLICT DO UPDATE keyed on
    ``index_elements``; the conflict update set depends on whether the
    incoming model is a "create" model or a "save" model.
    """

    model_create: Type = typing.cast(Type, None)
    model_save: Type = typing.cast(Type, None)

    # Default conflict-target columns; may be overridden per call.
    index_elements: Optional[list] = None

    _fields_create: set[str] = set()
    _fields_save: set[str] = set()

    @classmethod
    def _get_values_for_save_query(cls, model) -> dict: ...

    @classmethod
    def _get_on_conflict_do_update_set_for_save_query(cls, excluded, model) -> dict:
        """Return the columns to overwrite on conflict, taken from EXCLUDED."""
        if isinstance(model, typing.cast(Type, cls.model_create)):
            return {k: getattr(excluded, k) for k in cls._fields_create}
        return {k: getattr(excluded, k) for k in cls._fields_save}

    @classmethod
    def _resolve_index_elements(cls, index_elements: Optional[list]) -> list:
        """Resolve conflict-target names to model columns.

        Shared by ``_make_save_query`` and ``_make_save_many_query`` (was
        duplicated). A conditional expression replaces the previous
        `and/or` hack, which relied on the truthiness of the column
        object — SQLAlchemy clause elements raise TypeError on bool().

        Raises:
            ValueError: when neither the argument nor ``cls.index_elements``
                is provided.
        """
        index_elements = index_elements or cls.index_elements
        if index_elements is None:
            raise ValueError("index_elements is None")
        model_db = cls.model_db
        return [getattr(model_db, e) if isinstance(e, str) else e for e in index_elements]

    @classmethod
    def _make_save_query(
        cls,
        model,
        index_elements: Optional[list] = None,
        returning: Optional[bool] = None,
    ) -> sa.sql.Insert:
        """Build the upsert statement for a single model."""
        model_db = cls.model_db
        query = insert(model_db).values(cls._get_values_for_save_query(model))
        query = query.on_conflict_do_update(
            index_elements=cls._resolve_index_elements(index_elements),
            set_=cls._get_on_conflict_do_update_set_for_save_query(query.excluded, model),
        )
        if returning:
            query = query.returning(model_db)
        return query

    @classmethod
    async def save(
        cls,
        model,
        index_elements: Optional[list] = None,
        model_out: Optional[Type] = None,
        returning: bool = True,
    ):
        """Upsert *model*; return the stored row as *model_out* when *returning*."""
        result = await cls._execute(
            cls._make_save_query(
                model,
                index_elements=index_elements,
                returning=returning,
            )
        )
        if returning:
            item = result.one()
            model_out = model_out or cls.model
            return cls._from_db(item, data=cls._make_from_db_data(item), model_out=model_out)

    @classmethod
    def _make_save_many_query(
        cls,
        models: list,
        index_elements: Optional[list] = None,
        returning: Optional[bool] = None,
    ) -> sa.sql.Insert:
        """Build a single bulk upsert statement for *models*.

        NOTE(review): the conflict update set is derived from ``models[0]``
        only, so one call is assumed to carry models of the same kind
        (all "create" or all "save") — confirm with callers.
        """
        model_db = cls.model_db
        query = insert(model_db).values([cls._get_values_for_save_query(m) for m in models])
        query = query.on_conflict_do_update(
            index_elements=cls._resolve_index_elements(index_elements),
            set_=cls._get_on_conflict_do_update_set_for_save_query(query.excluded, models[0]),
        )
        if returning:
            query = query.returning(model_db)
        return query

    @classmethod
    async def save_many(
        cls,
        models: list,
        index_elements: Optional[list] = None,
        model_out: Optional[Type] = None,
        returning: bool = True,
    ) -> list | None:
        """Bulk upsert *models*.

        Returns an empty list for empty input; when *returning* is true,
        returns the stored rows converted into *model_out*, otherwise ``None``.
        """
        if not models:
            return []
        result = await cls._execute(
            cls._make_save_many_query(
                models,
                index_elements=index_elements,
                returning=returning,
            )
        )
        if returning:
            model_out = model_out or cls.model
            return [
                cls._from_db(item, data=cls._make_from_db_data(item), model_out=model_out)
                for item in result.all()
            ]
|
654
|
+
|
655
|
+
|
656
|
+
class _UpdateDM(_CreateDM, metaclass=_Meta, skip_checks=True):
    """Mixin adding UPDATE-by-key operations on top of ``_CreateDM``."""

    model_update: Type = typing.cast(Type, None)
    _fields_update: set[str] = set()

    @classmethod
    def _get_values_for_update_query(cls, model) -> dict: ...

    @classmethod
    def _make_update_query(cls, model, key: str, returning: Optional[bool] = None) -> sa.sql.Update:
        """Build UPDATE ... WHERE key == model.key.

        NOTE(review): the key column is unconditionally added to RETURNING
        and the full row is appended on top when *returning* — presumably
        so a matched row is always observable; confirm against the driver's
        rowcount semantics.
        """
        model_db = cls.model_db
        key_column = getattr(model_db, key)
        query = update(model_db).values(cls._get_values_for_update_query(model))
        query = query.where(key_column == getattr(model, key))
        query = query.returning(key_column)
        if returning:
            query = query.returning(model_db)
        return query

    @classmethod
    async def update(
        cls,
        model,
        key: str,
        raise_not_found: Optional[bool] = None,
        model_out: Optional[Type] = None,
        returning: bool = True,
    ):
        """Update the row whose *key* matches *model*'s.

        Raises NotFoundError when *raise_not_found* is set and no row
        matched; returns the updated row as *model_out* when *returning*.
        """
        query = cls._make_update_query(model, key, returning=returning or raise_not_found)
        result = await cls._execute(query)
        if raise_not_found and result.rowcount == 0:
            raise NotFoundError(key=key, value=getattr(model, key))
        if not returning:
            return None
        item = result.one_or_none()
        if not item:
            return None
        out_model = model_out or cls.model
        return cls._from_db(item, data=cls._make_from_db_data(item), model_out=out_model)

    @classmethod
    async def update_many(
        cls,
        models: list,
        key: str,
        model_out: Optional[Type] = None,
        returning: bool = True,
    ) -> list | None:
        """Update every model concurrently inside one transaction."""
        async with transaction():
            pending = [
                cls.update(m, key, model_out=model_out, returning=returning)
                for m in models
            ]
            results = await gather(*pending)
        if returning:
            return typing.cast(list | None, results)
|
708
|
+
|
709
|
+
|
710
|
+
class _DeleteDM(_CreateDM, metaclass=_Meta, skip_checks=True):
    """Mixin adding DELETE-by-key operations on top of ``_CreateDM``."""

    @classmethod
    def _make_delete_query(cls, key, key_value, returning: Optional[bool] = None):
        """Build DELETE ... WHERE key == key_value, optionally RETURNING the row."""
        query = delete(cls.model_db).where(key == key_value)
        if returning:
            query = query.returning(cls.model_db)
        return query

    @classmethod
    async def delete(
        cls,
        key_value,
        key=None,
        raise_not_found: Optional[bool] = None,
        model_out: Optional[Type] = None,
        returning: bool = True,
    ):
        """Delete the row whose *key* equals *key_value*.

        Raises NotFoundError when *raise_not_found* is set and nothing was
        deleted; returns the deleted row converted into *model_out* when
        *returning* and a row was removed.
        """
        key = cls._get_key(key=key)
        result = await cls._execute(cls._make_delete_query(key, key_value, returning=returning))
        # Collapsed the original redundant nested check
        # (`if raise_not_found: if raise_not_found and ...`).
        if raise_not_found and result.rowcount == 0:
            raise NotFoundError(key=key, value=key_value)
        if returning:
            item = result.one_or_none()
            if item:
                model_out = model_out or cls.model
                return cls._from_db(item, data=cls._make_from_db_data(item), model_out=model_out)
|
dbdm-0.1.0/dbdm/cwtch.py
ADDED
@@ -0,0 +1,66 @@
|
|
1
|
+
import typing
|
2
|
+
|
3
|
+
from typing import Optional, Type
|
4
|
+
|
5
|
+
import sqlalchemy as sa
|
6
|
+
|
7
|
+
from cwtch import asdict, from_attributes
|
8
|
+
|
9
|
+
from .common import ModelDB, OrderBy, _CreateDM, _DeleteDM, _GetDM, _SaveDM, _UpdateDM
|
10
|
+
|
11
|
+
|
12
|
+
__all__ = [
|
13
|
+
"OrderBy",
|
14
|
+
"GetDM",
|
15
|
+
"CreateDM",
|
16
|
+
"SaveDM",
|
17
|
+
"UpdateDM",
|
18
|
+
"DeleteDM",
|
19
|
+
"DM",
|
20
|
+
]
|
21
|
+
|
22
|
+
|
23
|
+
class GetDM(_GetDM, skip_checks=True):
    """cwtch-backed read mixin: rows are converted via ``cwtch.from_attributes``."""

    @classmethod
    def _from_db(
        cls,
        item: sa.Row | ModelDB,
        data: Optional[dict] = None,
        exclude: Optional[list] = None,
        suffix: Optional[str] = None,
        model_out=None,
    ):
        """Build a *model_out* (default ``cls.model``) instance from *item*."""
        target = model_out or cls.model
        return from_attributes(
            target,
            item,
            data=data,
            exclude=exclude,
            suffix=suffix,
            reset_circular_refs=True,
        )
|
41
|
+
|
42
|
+
|
43
|
+
class CreateDM(_CreateDM, skip_checks=True):
    """cwtch-backed create mixin."""

    @classmethod
    def _get_values_for_create_query(cls, model) -> dict:
        """Serialize *model* to the column values used by INSERT."""
        fields = typing.cast(list[str], cls._fields_create)
        return asdict(model, include=fields)
|
47
|
+
|
48
|
+
|
49
|
+
class SaveDM(_SaveDM, skip_checks=True):
    """cwtch-backed save (upsert) mixin."""

    @classmethod
    def _get_values_for_save_query(cls, model) -> dict:
        """Serialize *model*, choosing create fields for create models."""
        if isinstance(model, typing.cast(Type, cls.model_create)):
            fields = cls._fields_create
        else:
            fields = cls._fields_save
        return asdict(model, include=typing.cast(list[str], fields))
|
55
|
+
|
56
|
+
|
57
|
+
class UpdateDM(_UpdateDM, skip_checks=True):
    """cwtch-backed update mixin."""

    @classmethod
    def _get_values_for_update_query(cls, model) -> dict:
        """Serialize only the explicitly set update fields of *model*."""
        fields = typing.cast(list[str], cls._fields_update)
        return asdict(model, include=fields, exclude_unset=True)
|
61
|
+
|
62
|
+
|
63
|
+
# Delete needs no serializer overrides; the base class behavior suffices.
class DeleteDM(_DeleteDM, skip_checks=True): ...
|
64
|
+
|
65
|
+
|
66
|
+
# Aggregate data manager combining all CRUD mixins of this module.
class DM(GetDM, CreateDM, SaveDM, UpdateDM, DeleteDM, skip_checks=True): ...
|
@@ -0,0 +1,76 @@
|
|
1
|
+
import typing
|
2
|
+
|
3
|
+
from typing import Optional, Type
|
4
|
+
|
5
|
+
import pydantic # noqa: F401 # type: ignore
|
6
|
+
import sqlalchemy as sa
|
7
|
+
|
8
|
+
from .common import ModelDB, OrderBy, _CreateDM, _DeleteDM, _GetDM, _SaveDM, _UpdateDM
|
9
|
+
|
10
|
+
|
11
|
+
__all__ = [
|
12
|
+
"OrderBy",
|
13
|
+
"GetDM",
|
14
|
+
"CreateDM",
|
15
|
+
"SaveDM",
|
16
|
+
"UpdateDM",
|
17
|
+
"DeleteDM",
|
18
|
+
"DM",
|
19
|
+
]
|
20
|
+
|
21
|
+
|
22
|
+
# Bare attribute container: its __dict__ is filled with row data and the
# instance is handed to pydantic's model_validate(..., from_attributes=True).
class _object:
    pass
|
24
|
+
|
25
|
+
|
26
|
+
class GetDM(_GetDM, skip_checks=True):
    """pydantic-backed read mixin: rows are converted via ``model_validate``."""

    @classmethod
    def _from_db(
        cls,
        item: sa.Row | ModelDB,
        data: Optional[dict] = None,
        exclude: Optional[list] = None,
        suffix: Optional[str] = None,
        model_out=None,
    ):
        """Build a *model_out* (default ``cls.model``) instance from *item*.

        *data* overrides the item's attributes; *exclude* drops keys; when
        *suffix* is given, only keys ending with it are kept and the suffix
        is stripped from them.
        """
        if isinstance(item, sa.Row):
            data = {**item._asdict(), **(data or {})}
        else:
            data = {**item.__dict__, **(data or {})}

        obj = _object()

        if suffix:
            # BUG FIX: the original used k.rstrip(suffix), which strips any
            # trailing run of characters from the suffix's character SET
            # ("class_s".rstrip("_s") -> "clas"); removesuffix removes
            # exactly the suffix string.
            obj.__dict__ = {
                k.removesuffix(suffix): v
                for k, v in data.items()
                if (not exclude or k not in exclude) and k.endswith(suffix)
            }
        else:
            obj.__dict__ = {k: v for k, v in data.items() if not exclude or k not in exclude}

        return (model_out or cls.model).model_validate(obj, from_attributes=True)
|
51
|
+
|
52
|
+
|
53
|
+
class CreateDM(_CreateDM, skip_checks=True):
    """pydantic-backed create mixin."""

    @classmethod
    def _get_values_for_create_query(cls, model) -> dict:
        """Dump *model* to the column values used by INSERT."""
        fields = typing.cast(list[str], cls._fields_create)
        return model.model_dump(include=fields)
|
57
|
+
|
58
|
+
|
59
|
+
class SaveDM(_SaveDM, skip_checks=True):
    """pydantic-backed save (upsert) mixin."""

    @classmethod
    def _get_values_for_save_query(cls, model) -> dict:
        """Dump *model*, choosing create fields for create models."""
        if isinstance(model, typing.cast(Type, cls.model_create)):
            fields = cls._fields_create
        else:
            fields = cls._fields_save
        return model.model_dump(include=typing.cast(list[str], fields))
|
65
|
+
|
66
|
+
|
67
|
+
class UpdateDM(_UpdateDM, skip_checks=True):
    """pydantic-backed update mixin."""

    @classmethod
    def _get_values_for_update_query(cls, model) -> dict:
        """Dump only the explicitly set update fields of *model*."""
        fields = typing.cast(list[str], cls._fields_update)
        return model.model_dump(include=fields, exclude_unset=True)
|
71
|
+
|
72
|
+
|
73
|
+
# Delete needs no serializer overrides; the base class behavior suffices.
class DeleteDM(_DeleteDM, skip_checks=True): ...
|
74
|
+
|
75
|
+
|
76
|
+
# Aggregate data manager combining all CRUD mixins of this module.
class DM(GetDM, CreateDM, SaveDM, UpdateDM, DeleteDM, skip_checks=True): ...
|
@@ -0,0 +1,51 @@
|
|
1
|
+
[tool.poetry]
|
2
|
+
name = "dbdm"
|
3
|
+
version = "0.1.0"
|
4
|
+
description = ""
|
5
|
+
authors = ["Roman Koshel <roma.koshel@gmail.com>"]
|
6
|
+
license = "MIT"
|
7
|
+
readme = "README.md"
|
8
|
+
keywords = []
|
9
|
+
packages = [{ include = "dbdm" }]
|
10
|
+
classifiers = [
|
11
|
+
"Development Status :: 3 - Alpha",
|
12
|
+
"Programming Language :: Python",
|
13
|
+
"Programming Language :: Python :: 3 :: Only",
|
14
|
+
"Topic :: Software Development :: Libraries :: Python Modules",
|
15
|
+
"Topic :: Utilities",
|
16
|
+
"License :: OSI Approved :: MIT License",
|
17
|
+
]
|
18
|
+
|
19
|
+
[tool.poetry.dependencies]
|
20
|
+
python = ">=3.11, <4.0"
|
21
|
+
asyncpg = "*"
|
22
|
+
sqlalchemy = { version="^2", extras=["asyncio"] }
|
23
|
+
|
24
|
+
[tool.poetry.group.test.dependencies]
|
25
|
+
coverage = "*"
|
26
|
+
cwtch = ">=0.10.0"
|
27
|
+
docker = "*"
|
28
|
+
pydantic = "^2"
|
29
|
+
pylint = "*"
|
30
|
+
pytest = "*"
|
31
|
+
pytest_asyncio = "*"
|
32
|
+
pytest_timeout = "*"
|
33
|
+
|
34
|
+
[tool.poetry.group.docs.dependencies]
|
35
|
+
mkdocs = "*"
|
36
|
+
mkdocs-material = "*"
|
37
|
+
mkdocstrings = { extras = ["python"], version = "*" }
|
38
|
+
|
39
|
+
[build-system]
|
40
|
+
requires = ["poetry-core"]
|
41
|
+
build-backend = "poetry.core.masonry.api"
|
42
|
+
|
43
|
+
[tool.black]
|
44
|
+
line-length = 120
|
45
|
+
|
46
|
+
[tool.isort]
|
47
|
+
indent = 4
|
48
|
+
lines_after_imports = 2
|
49
|
+
lines_between_types = 1
|
50
|
+
src_paths = ["."]
|
51
|
+
profile = "black"
|