dbdm 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dbdm-0.1.0/LICENSE +21 -0
- dbdm-0.1.0/PKG-INFO +191 -0
- dbdm-0.1.0/README.md +168 -0
- dbdm-0.1.0/dbdm/__init__.py +7 -0
- dbdm-0.1.0/dbdm/common.py +736 -0
- dbdm-0.1.0/dbdm/cwtch.py +66 -0
- dbdm-0.1.0/dbdm/pydantic.py +76 -0
- dbdm-0.1.0/pyproject.toml +51 -0
dbdm-0.1.0/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
1
|
+
MIT License
|
2
|
+
|
3
|
+
Copyright (c) 2023 Roma Koshel
|
4
|
+
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
7
|
+
in the Software without restriction, including without limitation the rights
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
10
|
+
furnished to do so, subject to the following conditions:
|
11
|
+
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
13
|
+
copies or substantial portions of the Software.
|
14
|
+
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21
|
+
SOFTWARE.
|
dbdm-0.1.0/PKG-INFO
ADDED
@@ -0,0 +1,191 @@
|
|
1
|
+
Metadata-Version: 2.1
|
2
|
+
Name: dbdm
|
3
|
+
Version: 0.1.0
|
4
|
+
Summary:
|
5
|
+
License: MIT
|
6
|
+
Author: Roman Koshel
|
7
|
+
Author-email: roma.koshel@gmail.com
|
8
|
+
Requires-Python: >=3.11,<4.0
|
9
|
+
Classifier: Development Status :: 3 - Alpha
|
10
|
+
Classifier: License :: OSI Approved :: MIT License
|
11
|
+
Classifier: Programming Language :: Python
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
15
|
+
Classifier: Programming Language :: Python :: 3.13
|
16
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
17
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
18
|
+
Classifier: Topic :: Utilities
|
19
|
+
Requires-Dist: asyncpg
|
20
|
+
Requires-Dist: sqlalchemy[asyncio] (>=2,<3)
|
21
|
+
Description-Content-Type: text/markdown
|
22
|
+
|
23
|
+
# DBDM [wip]
|
24
|
+
|
25
|
+
|
26
|
+
## Examples
|
27
|
+
|
28
|
+
```python
|
29
|
+
import asyncio
from typing import ClassVar, Type

import sqlalchemy as sa

from cwtch import dataclass, resolve_types, view
from cwtch.types import UNSET, Unset
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from sqlalchemy.ext.asyncio import create_async_engine

from dbdm import DM, NotFoundError, bind_engine


class BaseDB(DeclarativeBase):
    pass


class ParentDB(BaseDB):
    __tablename__ = "parents"

    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    name: Mapped[str]
    data: Mapped[str]
    children = relationship("ChildDB", uselist=True, viewonly=True)


class ChildDB(BaseDB):
    __tablename__ = "children"

    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    name: Mapped[str]
    parent_id: Mapped[int] = mapped_column(sa.ForeignKey("parents.id"))
    parent = relationship("ParentDB", uselist=False, viewonly=True)


@dataclass(handle_circular_refs=True)
class Parent:
    id: int
    name: str
    data: str
    children: Unset[list["Child"]] = UNSET

    # Parent views
    Create: ClassVar[Type["ParentCreate"]]
    Save: ClassVar[Type["ParentSave"]]
    Update: ClassVar[Type["ParentUpdate"]]


@view(Parent, "Create", exclude=["id", "children"])
class ParentCreate:
    pass


@view(Parent, "Save", exclude=["children"])
class ParentSave:
    pass


@view(Parent, "Update", exclude=["children"])
class ParentUpdate:
    name: Unset[str] = UNSET
    data: Unset[str] = UNSET


@dataclass(handle_circular_refs=True)
class Child:
    id: int
    name: str
    parent_id: int
    parent: Unset[Parent] = UNSET

    # Child views
    Create: ClassVar[Type["ChildCreate"]]
    Save: ClassVar[Type["ChildSave"]]
    Update: ClassVar[Type["ChildUpdate"]]


@view(Child, "Create", exclude=["id", "parent"])
class ChildCreate:
    pass


@view(Child, "Save", exclude=["parent"])
class ChildSave:
    pass


@view(Child, "Update", exclude=["parent"])
class ChildUpdate:
    name: Unset[str] = UNSET
    parent_id: Unset[int] = UNSET


resolve_types(Parent, globals(), locals())


class ParentDM(DM):
    model_db = ParentDB
    model = Parent
    model_create = Parent.Create
    model_save = Parent.Save
    model_update = Parent.Update
    key = "id"
    index_elements = ["id"]
    joinedload = {"children": lambda m: sa.orm.joinedload(m.children)}


class ChildDM(DM):
    model_db = ChildDB
    model = Child
    model_create = Child.Create
    model_save = Child.Save
    model_update = Child.Update
    key = "id"
    index_elements = ["id"]
    joinedload = {"parent": lambda m: sa.orm.joinedload(m.parent)}


async def example():
    # Only the "postgresql" dialect is supported by bind_engine().
    engine = create_async_engine("postgresql+asyncpg://user:password@localhost:5432/dbdm")

    async with engine.begin() as conn:
        await conn.run_sync(BaseDB.metadata.create_all)

    bind_engine(engine)

    # parent: Parent
    parent = await ParentDM.create(Parent.Create(name="Parent_1", data="data"))

    # total: int, parents: list[Parent]
    total, parents = await ParentDM.get_many()

    # first page, one item per page
    total, parents = await ParentDM.get_many(page_size=1)

    # second page, one item per page
    total, parents = await ParentDM.get_many(page_size=1, page=2)

    # parent: Parent
    parent = await ParentDM.get(parent.id)

    # parent: Parent (with parent.children loaded)
    parent = await ParentDM.get(parent.id, joinedload={"children": True})

    # parent: Parent
    parent = await ParentDM.save(parent.Save())

    # parent: Parent
    parent = await ParentDM.update(Parent.Update(id=parent.id, data="new data"), key="id")

    # child: Child
    child = await ChildDM.create(Child.Create(name="Child_1", parent_id=parent.id))

    # parent: Parent (parent.children now contains the child)
    parent = await ParentDM.get(parent.id, joinedload={"children": True})

    # Delete the child first: children.parent_id references parents.id.
    await ChildDM.delete(child.id)
    await ParentDM.delete(parent.id)

    try:
        await ParentDM.get(parent.id, raise_not_found=True)
    except NotFoundError as e:
        print(e)  # item with (...)=(...) not found


if __name__ == "__main__":
    asyncio.run(example())
|
190
|
+
```
|
191
|
+
|
dbdm-0.1.0/README.md
ADDED
@@ -0,0 +1,168 @@
|
|
1
|
+
# DBDM [wip]
|
2
|
+
|
3
|
+
|
4
|
+
## Examples
|
5
|
+
|
6
|
+
```python
|
7
|
+
import asyncio
from typing import ClassVar, Type

import sqlalchemy as sa

from cwtch import dataclass, resolve_types, view
from cwtch.types import UNSET, Unset
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from sqlalchemy.ext.asyncio import create_async_engine

from dbdm import DM, NotFoundError, bind_engine


class BaseDB(DeclarativeBase):
    pass


class ParentDB(BaseDB):
    __tablename__ = "parents"

    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    name: Mapped[str]
    data: Mapped[str]
    children = relationship("ChildDB", uselist=True, viewonly=True)


class ChildDB(BaseDB):
    __tablename__ = "children"

    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    name: Mapped[str]
    parent_id: Mapped[int] = mapped_column(sa.ForeignKey("parents.id"))
    parent = relationship("ParentDB", uselist=False, viewonly=True)


@dataclass(handle_circular_refs=True)
class Parent:
    id: int
    name: str
    data: str
    children: Unset[list["Child"]] = UNSET

    # Parent views
    Create: ClassVar[Type["ParentCreate"]]
    Save: ClassVar[Type["ParentSave"]]
    Update: ClassVar[Type["ParentUpdate"]]


@view(Parent, "Create", exclude=["id", "children"])
class ParentCreate:
    pass


@view(Parent, "Save", exclude=["children"])
class ParentSave:
    pass


@view(Parent, "Update", exclude=["children"])
class ParentUpdate:
    name: Unset[str] = UNSET
    data: Unset[str] = UNSET


@dataclass(handle_circular_refs=True)
class Child:
    id: int
    name: str
    parent_id: int
    parent: Unset[Parent] = UNSET

    # Child views
    Create: ClassVar[Type["ChildCreate"]]
    Save: ClassVar[Type["ChildSave"]]
    Update: ClassVar[Type["ChildUpdate"]]


@view(Child, "Create", exclude=["id", "parent"])
class ChildCreate:
    pass


@view(Child, "Save", exclude=["parent"])
class ChildSave:
    pass


@view(Child, "Update", exclude=["parent"])
class ChildUpdate:
    name: Unset[str] = UNSET
    parent_id: Unset[int] = UNSET


resolve_types(Parent, globals(), locals())


class ParentDM(DM):
    model_db = ParentDB
    model = Parent
    model_create = Parent.Create
    model_save = Parent.Save
    model_update = Parent.Update
    key = "id"
    index_elements = ["id"]
    joinedload = {"children": lambda m: sa.orm.joinedload(m.children)}


class ChildDM(DM):
    model_db = ChildDB
    model = Child
    model_create = Child.Create
    model_save = Child.Save
    model_update = Child.Update
    key = "id"
    index_elements = ["id"]
    joinedload = {"parent": lambda m: sa.orm.joinedload(m.parent)}


async def example():
    # Only the "postgresql" dialect is supported by bind_engine().
    engine = create_async_engine("postgresql+asyncpg://user:password@localhost:5432/dbdm")

    async with engine.begin() as conn:
        await conn.run_sync(BaseDB.metadata.create_all)

    bind_engine(engine)

    # parent: Parent
    parent = await ParentDM.create(Parent.Create(name="Parent_1", data="data"))

    # total: int, parents: list[Parent]
    total, parents = await ParentDM.get_many()

    # first page, one item per page
    total, parents = await ParentDM.get_many(page_size=1)

    # second page, one item per page
    total, parents = await ParentDM.get_many(page_size=1, page=2)

    # parent: Parent
    parent = await ParentDM.get(parent.id)

    # parent: Parent (with parent.children loaded)
    parent = await ParentDM.get(parent.id, joinedload={"children": True})

    # parent: Parent
    parent = await ParentDM.save(parent.Save())

    # parent: Parent
    parent = await ParentDM.update(Parent.Update(id=parent.id, data="new data"), key="id")

    # child: Child
    child = await ChildDM.create(Child.Create(name="Child_1", parent_id=parent.id))

    # parent: Parent (parent.children now contains the child)
    parent = await ParentDM.get(parent.id, joinedload={"children": True})

    # Delete the child first: children.parent_id references parents.id.
    await ChildDM.delete(child.id)
    await ParentDM.delete(parent.id)

    try:
        await ParentDM.get(parent.id, raise_not_found=True)
    except NotFoundError as e:
        print(e)  # item with (...)=(...) not found


if __name__ == "__main__":
    asyncio.run(example())
|
168
|
+
```
|
@@ -0,0 +1,736 @@
|
|
1
|
+
import re
|
2
|
+
import typing
|
3
|
+
|
4
|
+
from asyncio import gather
|
5
|
+
from contextlib import asynccontextmanager
|
6
|
+
from contextvars import ContextVar
|
7
|
+
from dataclasses import dataclass
|
8
|
+
from typing import Any, AsyncIterator, Callable, Literal, Never, Optional, Protocol, Type, TypeVar
|
9
|
+
|
10
|
+
import sqlalchemy as sa
|
11
|
+
|
12
|
+
from sqlalchemy import delete, func, literal, select, union_all, update
|
13
|
+
from sqlalchemy.dialects.postgresql import insert
|
14
|
+
from sqlalchemy.ext.asyncio import async_sessionmaker
|
15
|
+
from sqlalchemy.orm import DeclarativeBase, aliased
|
16
|
+
|
17
|
+
|
18
|
+
# Public names of this module (what a star-import of it exposes).
__all__ = [
    "DMError",
    "AlreadyExistsError",
    "BadParamsError",
    "NotFoundError",
    "OrderBy",
    "bind_engine",
    "transaction",
]


# Per engine name: a ContextVar holding the connection of the currently open
# ``transaction()`` in this context, or ``None`` when no transaction is open.
# (Re)created by ``bind_engine``.
_conn: dict[Optional[str], ContextVar[Optional[sa.ext.asyncio.AsyncConnection]]] = {}
# Engines registered by ``bind_engine``, keyed by engine name (``None`` is the default name).
_engine: dict[Optional[str], sa.ext.asyncio.AsyncEngine] = {}
|
31
|
+
|
32
|
+
|
33
|
+
class DMError(Exception):
    """Base class for dbdm-specific errors."""


class BadParamsError(DMError):
    """Raised for an invalid caller-supplied parameter, e.g. an unknown ``order_by`` column.

    Attributes:
        message: human-readable description of the problem.
        param: name of the offending parameter, if known.
    """

    def __init__(self, message: str, param: Optional[str] = None):
        # Populate ``args`` so the exception survives pickling (unpickling re-calls
        # ``cls(*args)``; with empty args the required ``message`` would be missing).
        super().__init__(message)
        self.message = message
        self.param = param

    def __str__(self):
        return self.message

    def __repr__(self):
        return self.__str__()


class NotFoundError(DMError):
    """Raised when no item matches ``key == value``.

    Attributes:
        key: the key (column name or column object) that was looked up.
        value: the value that was searched for.
    """

    def __init__(self, key, value):
        # Populate ``args`` so the exception can be pickled/unpickled.
        super().__init__(key, value)
        self.key = key
        self.value = value

    def __str__(self):
        return f"item with ({self.key})=({self.value}) not found"


class AlreadyExistsError(DMError):
    """Raised when a write violates a uniqueness constraint on ``key``.

    Attributes:
        key: the conflicting key as reported by the database.
        value: the conflicting value as reported by the database.
    """

    def __init__(self, key, value):
        # Populate ``args`` so the exception can be pickled/unpickled.
        super().__init__(key, value)
        self.key = key
        self.value = value

    def __str__(self):
        return f"key ({self.key})=({self.value}) already exists"
|
65
|
+
|
66
|
+
|
67
|
+
@dataclass
class OrderBy:
    """One ORDER BY item.

    ``by`` is either a column name (looked up on the queried columns) or a SQLAlchemy
    column expression. ``order`` names the ``sqlalchemy`` function that is applied to it
    (``sa.asc`` or ``sa.desc``).
    """

    # Column name or SQLAlchemy column expression to sort by.
    by: Any
    # Sort direction; resolved via ``getattr(sqlalchemy, order)``.
    order: Literal["asc", "desc"] = "asc"
|
71
|
+
|
72
|
+
|
73
|
+
def bind_engine(engine: sa.ext.asyncio.AsyncEngine, name: Optional[str] = None):
    """Register *engine* under *name* for use by DM classes whose ``engine_name`` equals *name*.

    Also (re)creates the per-engine ContextVar that ``transaction()`` uses to share its
    connection, so any previously tracked connection for *name* is forgotten.

    Raises:
        DMError: if *engine* does not use the PostgreSQL dialect.
    """
    dialect_name = engine.dialect.name
    if dialect_name == "postgresql":
        _engine[name] = engine
        _conn[name] = ContextVar("conn", default=None)
        return
    raise DMError("only 'postgresql' dialect is supported")
|
78
|
+
|
79
|
+
|
80
|
+
@asynccontextmanager
async def transaction(engine_name: Optional[str] = None) -> AsyncIterator[sa.ext.asyncio.AsyncConnection]:
    """Open, or join, a transaction on the engine bound as *engine_name*.

    The outermost ``transaction()`` opens a connection, begins a transaction on it and
    publishes it through the engine's ContextVar. Nested ``transaction()`` / ``get_conn()``
    calls in the same context then reuse that connection. Commit/rollback is governed by
    the outermost block's ``conn.begin()`` context.

    Raises:
        KeyError: if nothing was registered under *engine_name* via ``bind_engine``.
    """
    joined = _conn[engine_name].get()
    if joined is not None:
        # Already inside a transaction in this context: hand out the same connection.
        yield joined
        return

    async with _engine[engine_name].connect() as conn, conn.begin():
        _conn[engine_name].set(conn)
        try:
            yield conn
        finally:
            # Stop advertising the connection once this outermost block ends.
            _conn[engine_name].set(None)
|
92
|
+
|
93
|
+
|
94
|
+
@asynccontextmanager
async def get_conn(engine_name: Optional[str] = None) -> AsyncIterator[sa.ext.asyncio.AsyncConnection]:
    """Yield a connection for the engine bound as *engine_name*.

    Inside an active ``transaction()`` this is that transaction's connection. Otherwise
    a fresh connection is opened with its own transaction, scoped to this block.

    Raises:
        KeyError: if nothing was registered under *engine_name* via ``bind_engine``.
    """
    ambient = _conn[engine_name].get()
    if ambient is not None:
        yield ambient
        return

    async with _engine[engine_name].connect() as conn, conn.begin():
        yield conn
|
102
|
+
|
103
|
+
|
104
|
+
@asynccontextmanager
async def get_sess(engine_name: Optional[str] = None):
    """Yield an ORM ``AsyncSession`` bound to the connection from ``get_conn(engine_name)``.

    The session is created with ``expire_on_commit=False`` and runs inside
    ``sess.begin()`` for the duration of the block.
    """
    async with get_conn(engine_name=engine_name) as conn:
        session_factory = async_sessionmaker(conn, expire_on_commit=False)
        async with session_factory() as sess, sess.begin():
            yield sess
|
111
|
+
|
112
|
+
|
113
|
+
def _raise_exc(e: Exception) -> Never:
|
114
|
+
if isinstance(e, sa.exc.IntegrityError):
|
115
|
+
detail_match = re.match(r".*\nDETAIL:\s*(?P<text>.*)$", e.orig.args[0]) # type: ignore
|
116
|
+
if detail_match:
|
117
|
+
text = detail_match.groupdict()["text"].strip()
|
118
|
+
m = re.match(r"Key \((?P<key>.*)\)=\((?P<key_value>.*)\) already exists.", text)
|
119
|
+
if m:
|
120
|
+
key = m.groupdict()["key"]
|
121
|
+
key_value = m.groupdict()["key_value"].strip('\\"')
|
122
|
+
raise AlreadyExistsError(key, key_value)
|
123
|
+
raise e
|
124
|
+
|
125
|
+
|
126
|
+
# Type variable bounded by ``DeclarativeBase``: annotates ORM model instances, which
# ``_from_db`` accepts as an alternative to a Core ``sa.Row``.
ModelDB = TypeVar("ModelDB", bound=DeclarativeBase)
|
127
|
+
|
128
|
+
|
129
|
+
class _BaseProtocol(Protocol):
    """Structural interface shared by the DM mixins (``_GetDM``, ``_CreateDM``, ...).

    Declares the attributes and class methods the mixins call on ``cls`` (presumably so
    type checkers can resolve them across mixins); the implementations live in the mixins.
    """

    # Name of the engine registered with ``bind_engine`` (``None`` = default name).
    engine_name: Optional[str]

    # SQLAlchemy declarative model the DM reads from / writes to.
    model_db: Type[DeclarativeBase]
    # Output model class produced by the DM methods.
    model: Type

    # Execute *query* on the DM's connection and return the result.
    @classmethod
    async def _execute(cls, query) -> sa.ResultProxy: ...

    # Hook returning extra data that is passed to ``_from_db`` for *db_item*.
    @classmethod
    def _make_from_db_data(
        cls,
        db_item,
        row: Optional[tuple] = None,
        joinedload: Optional[dict] = None,
    ) -> dict: ...

    # Convert a Core row or ORM object into an instance of ``model_out`` (or ``model``).
    @classmethod
    def _from_db(
        cls,
        item: sa.Row | ModelDB,
        data: Optional[dict] = None,
        exclude: Optional[list] = None,
        suffix: Optional[str] = None,
        model_out=None,
    ): ...

    # Resolve *key* (or the class default) to a column/attribute usable in a WHERE clause.
    @classmethod
    def _get_key(cls, key=None): ...

    # Fetch a single item by key; see ``_GetDM.get``.
    # NOTE(review): the annotation says ``Optional[type]`` but an instance is returned.
    @classmethod
    async def get(
        cls,
        key_value,
        key=None,
        raise_not_found: Optional[bool] = None,
        joinedload: Optional[dict] = None,
        model_out: Optional[Type] = None,
        **kwds,
    ) -> Optional[type]: ...

    # Fetch ``(total, items)`` with optional filtering/paging/ordering; see ``_GetDM.get_many``.
    @classmethod
    async def get_many(
        cls,
        flt: Optional[sa.sql.elements.BinaryExpression] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        order_by: Optional[list[OrderBy | str] | OrderBy | str] = None,
        joinedload: Optional[dict] = None,
        model_out: Optional[Type] = None,
        **kwds,
    ) -> tuple[int, list]: ...
|
181
|
+
|
182
|
+
|
183
|
+
def _collect_bases(bases: tuple) -> set:
    """Return *bases* together with all their ancestors, walked through ``__bases__``."""
    result = set()
    for base in bases:
        result.add(base)
        result |= _collect_bases(getattr(base, "__bases__", ()))
    return result


def _is_model(model) -> bool:
    """Return True if *model* exposes ``__dataclass_fields__`` or ``__pydantic_fields__``."""
    return (
        getattr(model, "__dataclass_fields__", None) is not None
        or getattr(model, "__pydantic_fields__", None) is not None
    )


def _model_field_names(model) -> set:
    """Return the field names declared by *model* (empty set if it declares none)."""
    return set(
        getattr(model, "__dataclass_fields__", None)
        or getattr(model, "__pydantic_fields__", None)
        or set()
    )


def _require_model(ns: dict, attr: str):
    """Return ``ns[attr]`` after checking it is present and is a valid model class.

    Raises:
        ValueError: if ``ns[attr]`` is missing/None or is not a model class.
    """
    if (model := ns.get(attr)) is None:
        raise ValueError(f"'{attr}' field is required")
    if not _is_model(model):
        raise ValueError(f"'{attr}' is not a valid model")
    return model


def _set_table_fields(ns: dict, model_db, attr: str, fields_attr: str) -> None:
    """Store in ``ns[fields_attr]`` the ``model_db`` columns that ``ns[attr]`` also declares.

    Raises:
        ValueError: see ``_require_model``.
        DMError: if the model shares no field with ``model_db``'s table.
    """
    model = _require_model(ns, attr)
    fields = set(model_db.__table__.columns.keys()) & _model_field_names(model)
    if not fields:
        raise DMError(f"{attr} does not contain valid fields")
    ns[fields_attr] = fields


class _Meta(typing._ProtocolMeta):
    """Metaclass validating DM class definitions.

    Unless the class statement passes ``skip_checks=True`` (as the internal mixins do),
    the new class must:
    - derive from ``_GetDM``;
    - define ``model_db`` and a valid ``model`` in its own body;
    - depending on which mixins it includes, define valid ``model_create``, ``model_save``
      and ``model_update`` models that share fields with ``model_db``'s columns.
    The shared column names are stored as ``_fields_create``, ``_fields_save`` and
    ``_fields_update``.
    """

    def __new__(cls, name, bases, ns, skip_checks: bool = False):
        if not skip_checks:
            all_bases = _collect_bases(bases)

            if _GetDM not in all_bases:
                raise TypeError("Any DM class should be subclassed from GetDM class")

            # Only the class body is inspected (``ns``), not inherited attributes.
            if (model_db := ns.get("model_db")) is None:
                raise ValueError("'model_db' field is required")

            _require_model(ns, "model")

            if _CreateDM in all_bases or _SaveDM in all_bases:
                _set_table_fields(ns, model_db, "model_create", "_fields_create")

            if _SaveDM in all_bases:
                _set_table_fields(ns, model_db, "model_save", "_fields_save")

            if _UpdateDM in all_bases:
                _set_table_fields(ns, model_db, "model_update", "_fields_update")

        return super().__new__(cls, name, bases, ns)
|
260
|
+
|
261
|
+
|
262
|
+
class _GetDM(_BaseProtocol, metaclass=_Meta, skip_checks=True):
    """Read operations for a data manager: ``get`` one item and ``get_many`` with paging.

    Concrete DMs set ``model_db`` (SQLAlchemy declarative model) and ``model`` (output
    model). ``_from_db`` converts rows/ORM objects into output models; the base version
    here returns None, so it is presumably overridden by the cwtch/pydantic integrations.
    """

    # Name of the engine registered with ``bind_engine`` (``None`` = default name).
    engine_name: Optional[str] = None

    model_db: Type[DeclarativeBase] = typing.cast(Type[DeclarativeBase], None)
    model: Type = typing.cast(Type, None)

    # Default lookup key: a column name on ``model_db`` or a column/attribute object.
    key = None
    # Default ordering for ``get_many``; a callable receives the columns being ordered over.
    order_by: Optional[list[OrderBy | str | Callable] | OrderBy | str | Callable] = None

    # Relationship name -> callable(model_db or its alias) returning a loader option.
    joinedload: dict = {}

    @classmethod
    async def _execute(cls, query) -> sa.ResultProxy:
        """Execute *query* on this DM's connection (joining an enclosing ``transaction()``).

        Errors are routed through ``_raise_exc``, which re-raises them, turning duplicate-key
        integrity errors into ``AlreadyExistsError``.
        """
        async with get_conn(engine_name=cls.engine_name) as conn:
            try:
                return await conn.execute(query)
            except Exception as e:
                _raise_exc(e)

    @classmethod
    def _make_from_db_data(
        cls,
        db_item,
        row: Optional[tuple] = None,
        joinedload: Optional[dict] = None,
    ) -> dict:
        """Hook returning extra data passed to ``_from_db``; the base adds nothing."""
        return {}

    @classmethod
    def _from_db(
        cls,
        item: sa.Row | ModelDB,
        data: Optional[dict] = None,
        exclude: Optional[list] = None,
        suffix: Optional[str] = None,
        model_out=None,
    ):
        """Convert *item* into a ``model_out`` instance (base implementation returns None)."""
        ...

    @classmethod
    def _get_key(cls, key=None):
        """Resolve *key* (default: ``cls.key``); a string is looked up on ``model_db``.

        Raises:
            ValueError: if neither *key* nor ``cls.key`` is set.
        """
        if key is None:
            key = cls.key
        if key is None:
            raise ValueError("key is None")
        if isinstance(key, str):
            key = getattr(cls.model_db, key)
        return key

    @classmethod
    def _get_order_by(cls, c, order_by: Optional[list[OrderBy | str] | OrderBy | str] = None) -> list:
        """Build SQLAlchemy ORDER BY clauses, resolving column names against *c*.

        *c* is where string names are looked up: the ORM class, or a CTE's ``.c`` collection.
        When *order_by* is None the class default is used: ``cls.order_by`` (called with *c*
        if callable), else the class key.

        Raises:
            BadParamsError: if a string column name does not exist on *c*.
            ValueError: if no ordering is available and ``cls.key`` is None.
        """
        _order_by = order_by
        if _order_by is None:
            if callable(cls.order_by):
                _order_by = cls.order_by(c)
            elif cls.order_by:
                _order_by = cls.order_by
            elif isinstance(cls.key, str):
                # Keep a string key as a *name* so it is resolved against ``c`` below.
                # ``_get_key()`` would bind it to ``model_db``'s table. When ``c`` is a CTE
                # (joinedload branch of get_many), that table is not in the FROM clause.
                _order_by = cls.key
            else:
                _order_by = cls._get_key()
        if _order_by is None:
            # A callable default may return None: no ordering.
            return []
        _order_by = list(_order_by) if isinstance(_order_by, list) else [_order_by]
        for i, x in enumerate(_order_by):
            # Work on a copy so a caller-supplied OrderBy is never mutated.
            x = OrderBy(by=x.by, order=x.order) if isinstance(x, OrderBy) else OrderBy(by=x)
            if isinstance(x.by, str):
                if getattr(c, x.by, None) is None:
                    raise BadParamsError(f"invalid order_by '{x.by}'", param=f"{x.by}")
                x.by = getattr(c, x.by)
            _order_by[i] = getattr(sa, x.order)(x.by)
        return typing.cast(list, _order_by)

    @classmethod
    def _make_get_query(cls, key, key_value, **kwds) -> sa.sql.Select:
        """Build the SELECT used by ``get``; *kwds* are available to overriding subclasses."""
        return select(cls.model_db).where(key == key_value)

    @classmethod
    async def get(
        cls,
        key_value,
        key=None,
        raise_not_found: Optional[bool] = None,
        joinedload: Optional[dict] = None,
        model_out: Optional[Type] = None,
        **kwds,
    ) -> Optional[type]:
        """Fetch the single item whose *key* equals *key_value*.

        Args:
            key_value: value to match.
            key: column name/object to match on (default: ``cls.key``).
            raise_not_found: raise ``NotFoundError`` instead of returning None when absent.
            joinedload: relationship name -> True to eager-load it via ``cls.joinedload``.
                Relationships not requested are passed to ``_from_db`` as ``exclude``.
            model_out: output model class (default: ``cls.model``).
            **kwds: forwarded to ``_make_get_query``.

        Returns:
            The converted item, or None if not found and *raise_not_found* is falsy.
        """
        key = cls._get_key(key=key)
        query = cls._make_get_query(key, key_value, **kwds)
        model_out = model_out or cls.model

        if joinedload is None or not any(joinedload.values()):
            item = (await cls._execute(query)).one_or_none()
            if item is not None:
                return cls._from_db(item, data=cls._make_from_db_data(item), model_out=model_out)
            if raise_not_found:
                raise NotFoundError(key=key, value=key_value)
            return None

        exclude = []
        for k, v in cls.joinedload.items():
            if joinedload.get(k) is True:
                query = query.options(v(cls.model_db))
            else:
                exclude.append(k)

        async with get_sess(engine_name=cls.engine_name) as sess:
            item = (await sess.execute(query)).unique().scalars().one_or_none()
            if item is not None:
                return cls._from_db(
                    item,
                    exclude=exclude,
                    data=cls._make_from_db_data(item, joinedload=joinedload),
                    model_out=model_out,
                )
        # Raised after the session block has exited normally.
        if raise_not_found:
            raise NotFoundError(key=key, value=key_value)
        return None

    @classmethod
    def _make_get_many_query(
        cls,
        flt: Optional[sa.sql.elements.BinaryExpression] = None,
        order_by: Optional[list[OrderBy | str] | OrderBy | str] = None,
        **kwds,
    ) -> sa.sql.Select:
        """Build the unpaged ``get_many`` SELECT, with a windowed ``rows_total`` count column."""
        query = select(func.count(literal("*")).over().label("rows_total"), cls.model_db)
        if flt is not None:
            query = query.where(flt)
        query = query.order_by(*cls._get_order_by(cls.model_db, order_by))
        return query

    @classmethod
    async def get_many(
        cls,
        flt: Optional[sa.sql.elements.BinaryExpression] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        order_by: Optional[list[OrderBy | str] | OrderBy | str] = None,
        joinedload: Optional[dict] = None,
        model_out: Optional[Type] = None,
        **kwds,
    ) -> tuple[int, list]:
        """Return ``(total, items)`` for the rows matching *flt*, optionally paged.

        The statement is a ``UNION ALL`` of two parts:
        - one "total" row (``i = 0``) taken from the unpaged CTE, so ``rows_total`` is known
          even when the requested page is empty;
        - the requested page (``i = 1``).
        The first result row carries the total; the remaining rows are the items.

        Args:
            flt: optional WHERE expression.
            page: 1-based page number (falsy -> 1); only used with *page_size*.
            page_size: items per page; falsy disables paging.
            order_by: ordering (see ``_get_order_by``).
            joinedload: relationship name -> True to eager-load it via ``cls.joinedload``.
            model_out: output model class (default: ``cls.model``).
            **kwds: forwarded to ``_make_get_many_query``.
        """
        model_db = cls.model_db
        model_out = model_out or cls.model

        query = cls._make_get_many_query(flt=flt, order_by=order_by, **kwds)

        cte = query.cte("cte")

        query = select(literal(1).label("i"), cte)

        if page_size:
            page = page or 1
            query = query.limit(page_size).offset((page - 1) * page_size)

        query = union_all(select(literal(0).label("i"), cte).limit(1), query)

        from_db = cls._from_db
        make_from_db_data = cls._make_from_db_data

        if joinedload is None or not any(joinedload.values()):
            # NOTE(review): no outer ORDER BY here; this relies on PostgreSQL returning the
            # UNION ALL branches in order (total row first) - TODO confirm.
            rows = (await cls._execute(query)).all()
            return rows[0].rows_total if rows else 0, [
                from_db(row, data=make_from_db_data(row), model_out=model_out) for row in rows[1:]
            ]

        main_cte = query.cte("main_cte")

        model_db_alias = aliased(model_db, main_cte)
        query = select(main_cte, model_db_alias)

        exclude = []
        for k, v in cls.joinedload.items():
            if joinedload.get(k) is True:
                query = query.options(v(model_db_alias))
            else:
                exclude.append(k)

        query = query.order_by(sa.asc(main_cte.c.i), *cls._get_order_by(main_cte.c, order_by))

        def _hash(row):
            # Joined eager loading repeats an entity once per related row; dedupe on
            # (branch marker ``i``, ``rows_total``, entity).
            return hash((row[0], row[1], row[-1]))

        async with get_sess(engine_name=cls.engine_name) as sess:
            rows = (await sess.execute(query)).unique(_hash).all()
            return rows[0].rows_total if rows else 0, [
                from_db(
                    row[-1],
                    exclude=exclude,
                    data=make_from_db_data(row[-1], row=typing.cast(tuple, row), joinedload=joinedload),
                    model_out=model_out,
                )
                for row in rows[1:]
            ]
|
456
|
+
|
457
|
+
|
458
|
+
class _CreateDM(_BaseProtocol, metaclass=_Meta, skip_checks=True):
|
459
|
+
model_create: Type = typing.cast(Type, None)
|
460
|
+
|
461
|
+
index_elements: Optional[list] = None
|
462
|
+
|
463
|
+
_fields_create: set[str] = set()
|
464
|
+
|
465
|
+
@classmethod
|
466
|
+
def _get_values_for_create_query(cls, model) -> dict: ...
|
467
|
+
|
468
|
+
@classmethod
|
469
|
+
def _make_create_query(cls, model, returning: Optional[bool] = None) -> sa.sql.Insert:
|
470
|
+
model_db = cls.model_db
|
471
|
+
query = insert(model_db).values(cls._get_values_for_create_query(model))
|
472
|
+
if returning:
|
473
|
+
query = query.returning(model_db)
|
474
|
+
return query
|
475
|
+
|
476
|
+
@classmethod
|
477
|
+
async def create(
|
478
|
+
cls,
|
479
|
+
model,
|
480
|
+
model_out: Optional[Type] = None,
|
481
|
+
returning: bool = True,
|
482
|
+
):
|
483
|
+
result = await cls._execute(cls._make_create_query(model, returning=returning))
|
484
|
+
if returning:
|
485
|
+
item = result.one()
|
486
|
+
model_out = model_out or cls.model
|
487
|
+
return cls._from_db(item, data=cls._make_from_db_data(item), model_out=model_out)
|
488
|
+
|
489
|
+
@classmethod
|
490
|
+
def _make_create_many_query(cls, models: list, returning: bool | None = None) -> sa.sql.Insert:
|
491
|
+
model_db = cls.model_db
|
492
|
+
query = insert(cls.model_db).values([cls._get_values_for_create_query(model) for model in models])
|
493
|
+
if returning:
|
494
|
+
query = query.returning(model_db)
|
495
|
+
return query
|
496
|
+
|
497
|
+
@classmethod
|
498
|
+
async def create_many(
|
499
|
+
cls,
|
500
|
+
models: list,
|
501
|
+
model_out: Optional[Type] = None,
|
502
|
+
returning: bool = True,
|
503
|
+
) -> list | None:
|
504
|
+
if not models:
|
505
|
+
return []
|
506
|
+
result = await cls._execute(cls._make_create_many_query(models, returning=returning))
|
507
|
+
if returning:
|
508
|
+
from_db = cls._from_db
|
509
|
+
make_from_db_data = cls._make_from_db_data
|
510
|
+
model_out = model_out or cls.model
|
511
|
+
return [from_db(item, data=make_from_db_data(item), model_out=model_out) for item in result.all()]
|
512
|
+
|
513
|
+
@classmethod
|
514
|
+
def _make_get_or_create_query(
|
515
|
+
cls,
|
516
|
+
model,
|
517
|
+
update_element: str,
|
518
|
+
index_elements: Optional[list] = None,
|
519
|
+
) -> sa.sql.Insert:
|
520
|
+
model_db = cls.model_db
|
521
|
+
query = insert(model_db).values(cls._get_values_for_create_query(model))
|
522
|
+
index_elements = index_elements or cls.index_elements
|
523
|
+
if index_elements is None:
|
524
|
+
raise ValueError("index_elements is None")
|
525
|
+
index_elements = list(map(lambda e: isinstance(e, str) and getattr(model_db, e) or e, index_elements))
|
526
|
+
return query.on_conflict_do_update(
|
527
|
+
index_elements=index_elements,
|
528
|
+
set_={update_element: getattr(query.excluded, update_element)},
|
529
|
+
).returning(model_db)
|
530
|
+
|
531
|
+
@classmethod
async def get_or_create(
    cls,
    model,
    update_element: str,
    index_elements: Optional[list] = None,
    model_out: Optional[Type] = None,
):
    """Upsert ``model`` via ``_make_get_or_create_query`` and return the resulting row.

    Exactly one row is expected back (``result.one()``); it is converted with
    ``cls._from_db`` into ``model_out`` (default ``cls.model``).
    """
    target = model_out or cls.model
    statement = cls._make_get_or_create_query(model, update_element, index_elements=index_elements)
    result = await cls._execute(statement)
    row = result.one()
    return cls._from_db(row, data=cls._make_from_db_data(row), model_out=target)
|
550
|
+
|
551
|
+
|
552
|
+
class _SaveDM(_CreateDM, metaclass=_Meta, skip_checks=True):
    """Upsert support built on ``INSERT ... ON CONFLICT (index_elements) DO UPDATE``.

    A model that is an instance of ``model_create`` is handled with the
    ``_fields_create`` field set; any other model uses ``_fields_save``.
    Backend-specific subclasses provide ``_get_values_for_save_query``.
    """

    model_create: Type = typing.cast(Type, None)
    model_save: Type = typing.cast(Type, None)

    # Default conflict target, used when callers do not pass ``index_elements``.
    index_elements: Optional[list] = None

    _fields_create: set[str] = set()
    _fields_save: set[str] = set()

    @classmethod
    def _get_values_for_save_query(cls, model) -> dict:
        """Return the column/value mapping to insert for ``model``.

        Hook for subclasses; this base version has an empty body and so returns ``None``.
        """
        ...

    @classmethod
    def _get_on_conflict_do_update_set_for_save_query(cls, excluded, model) -> dict:
        """Map each field to ``excluded.<field>`` for the ON CONFLICT SET clause.

        The field set follows the same create/save split as the inserted values.
        """
        if isinstance(model, typing.cast(Type, cls.model_create)):
            return {k: getattr(excluded, k) for k in cls._fields_create}
        return {k: getattr(excluded, k) for k in cls._fields_save}

    @classmethod
    def _make_save_query(
        cls,
        model,
        index_elements: Optional[list] = None,
        returning: Optional[bool] = None,
    ) -> sa.sql.Insert:
        """Build the upsert statement for a single model.

        Raises:
            ValueError: if neither ``index_elements`` nor ``cls.index_elements`` is set.
        """
        model_db = cls.model_db
        query = insert(model_db).values(cls._get_values_for_save_query(model))
        index_elements = index_elements or cls.index_elements
        if index_elements is None:
            raise ValueError("index_elements is None")
        # Resolve attribute names with a real conditional: the ``and``/``or`` idiom would
        # fall back to the raw string for a falsy attribute and calls bool() on it, which
        # SQLAlchemy clause objects may refuse.
        index_elements = [getattr(model_db, e) if isinstance(e, str) else e for e in index_elements]
        query = query.on_conflict_do_update(
            index_elements=index_elements,
            set_=cls._get_on_conflict_do_update_set_for_save_query(query.excluded, model),
        )
        if returning:
            query = query.returning(model_db)
        return query

    @classmethod
    async def save(
        cls,
        model,
        index_elements: Optional[list] = None,
        model_out: Optional[Type] = None,
        returning: bool = True,
    ):
        """Upsert ``model``; when ``returning`` is true return the stored row as ``model_out``."""
        result = await cls._execute(
            cls._make_save_query(
                model,
                index_elements=index_elements,
                returning=returning,
            )
        )
        if returning:
            item = result.one()
            model_out = model_out or cls.model
            return cls._from_db(item, data=cls._make_from_db_data(item), model_out=model_out)

    @classmethod
    def _make_save_many_query(
        cls,
        models: list,
        index_elements: Optional[list] = None,
        returning: Optional[bool] = None,
    ) -> sa.sql.Insert:
        """Build one multi-row upsert statement for ``models``.

        The ON CONFLICT SET clause is derived from ``models[0]`` only, so a batch
        mixing ``model_create`` and other model types uses the first model's field set.
        ``models`` must be non-empty (``save_many`` guards this).

        Raises:
            ValueError: if neither ``index_elements`` nor ``cls.index_elements`` is set.
        """
        model_db = cls.model_db
        query = insert(model_db).values([cls._get_values_for_save_query(model) for model in models])
        index_elements = index_elements or cls.index_elements
        if index_elements is None:
            raise ValueError("index_elements is None")
        # See ``_make_save_query``: conditional expression instead of ``and``/``or``.
        index_elements = [getattr(model_db, e) if isinstance(e, str) else e for e in index_elements]
        query = query.on_conflict_do_update(
            index_elements=index_elements,
            set_=cls._get_on_conflict_do_update_set_for_save_query(query.excluded, models[0]),
        )
        if returning:
            query = query.returning(model_db)
        return query

    @classmethod
    async def save_many(
        cls,
        models: list,
        index_elements: Optional[list] = None,
        model_out: Optional[Type] = None,
        returning: bool = True,
    ) -> list | None:
        """Upsert ``models`` in one statement; an empty list returns ``[]`` without a query."""
        if not models:
            return []
        result = await cls._execute(
            cls._make_save_many_query(
                models,
                index_elements=index_elements,
                returning=returning,
            )
        )
        if returning:
            from_db = cls._from_db
            make_from_db_data = cls._make_from_db_data
            model_out = model_out or cls.model
            return [from_db(item, data=make_from_db_data(item), model_out=model_out) for item in result.all()]
|
654
|
+
|
655
|
+
|
656
|
+
class _UpdateDM(_CreateDM, metaclass=_Meta, skip_checks=True):
    """UPDATE support keyed on a single model attribute."""

    model_update: Type = typing.cast(Type, None)
    _fields_update: set[str] = set()

    @classmethod
    def _get_values_for_update_query(cls, model) -> dict:
        """Return the column/value mapping to write for ``model`` (subclass hook; base returns ``None``)."""
        ...

    @classmethod
    def _make_update_query(cls, model, key: str, returning: Optional[bool] = None) -> sa.sql.Update:
        """Build ``UPDATE model_db SET ... WHERE model_db.<key> == model.<key>``.

        The key column is always in RETURNING; when ``returning`` is truthy the
        full ``model_db`` is appended to RETURNING as well.
        """
        model_db = cls.model_db
        new_values = cls._get_values_for_update_query(model)
        key_column = getattr(model_db, key)
        statement = update(model_db).values(new_values).where(key_column == getattr(model, key))
        statement = statement.returning(key_column)
        if returning:
            statement = statement.returning(model_db)
        return statement

    @classmethod
    async def update(
        cls,
        model,
        key: str,
        raise_not_found: Optional[bool] = None,
        model_out: Optional[Type] = None,
        returning: bool = True,
    ):
        """Update the row matching ``model.<key>``.

        Raises:
            NotFoundError: when ``raise_not_found`` is true and no row was updated.

        Returns the updated row as ``model_out`` (default ``cls.model``) when
        ``returning`` is true and a row came back, otherwise ``None``.
        """
        statement = cls._make_update_query(model, key, returning=returning or raise_not_found)
        result = await cls._execute(statement)
        if raise_not_found and result.rowcount == 0:
            raise NotFoundError(key=key, value=getattr(model, key))
        if not returning:
            return None
        row = result.one_or_none()
        if not row:
            return None
        return cls._from_db(row, data=cls._make_from_db_data(row), model_out=model_out or cls.model)

    @classmethod
    async def update_many(
        cls,
        models: list,
        key: str,
        model_out: Optional[Type] = None,
        returning: bool = True,
    ) -> list | None:
        """Run ``update`` for every model inside one ``transaction()``.

        Returns the per-model results in input order when ``returning`` is true.
        """
        async with transaction():
            # NOTE(review): gather() issues the updates concurrently; confirm the
            # connection behind transaction() tolerates concurrent statements.
            pending = [cls.update(model, key, model_out=model_out, returning=returning) for model in models]
            results = await gather(*pending)
        if not returning:
            return None
        return typing.cast(list | None, results)
|
708
|
+
|
709
|
+
|
710
|
+
class _DeleteDM(_CreateDM, metaclass=_Meta, skip_checks=True):
    """DELETE support filtered on a single key column."""

    @classmethod
    def _make_delete_query(cls, key, key_value, returning: Optional[bool] = None):
        """Build ``DELETE FROM model_db WHERE key == key_value``, optionally with RETURNING."""
        query = delete(cls.model_db).where(key == key_value)
        if returning:
            query = query.returning(cls.model_db)
        return query

    @classmethod
    async def delete(
        cls,
        key_value,
        key=None,
        raise_not_found: Optional[bool] = None,
        model_out: Optional[Type] = None,
        returning: bool = True,
    ):
        """Delete the row(s) whose ``key`` equals ``key_value``.

        ``key`` is resolved through ``cls._get_key``.

        Raises:
            NotFoundError: when ``raise_not_found`` is true and nothing was deleted.

        Returns the deleted row as ``model_out`` (default ``cls.model``) when
        ``returning`` is true and a row came back, otherwise ``None``.
        """
        key = cls._get_key(key=key)
        result = await cls._execute(cls._make_delete_query(key, key_value, returning=returning))
        # Single guard; the previous nested ``if raise_not_found:`` re-tested the same flag.
        if raise_not_found and result.rowcount == 0:
            raise NotFoundError(key=key, value=key_value)
        if returning:
            item = result.one_or_none()
            if item:
                model_out = model_out or cls.model
                return cls._from_db(item, data=cls._make_from_db_data(item), model_out=model_out)
|
dbdm-0.1.0/dbdm/cwtch.py
ADDED
@@ -0,0 +1,66 @@
|
|
1
|
+
import typing
|
2
|
+
|
3
|
+
from typing import Optional, Type
|
4
|
+
|
5
|
+
import sqlalchemy as sa
|
6
|
+
|
7
|
+
from cwtch import asdict, from_attributes
|
8
|
+
|
9
|
+
from .common import ModelDB, OrderBy, _CreateDM, _DeleteDM, _GetDM, _SaveDM, _UpdateDM
|
10
|
+
|
11
|
+
|
12
|
+
# Public API of the cwtch-backed managers: the classes defined below plus the
# OrderBy helper re-exported from .common.
__all__ = [
    "OrderBy",
    "GetDM",
    "CreateDM",
    "SaveDM",
    "UpdateDM",
    "DeleteDM",
    "DM",
]
|
21
|
+
|
22
|
+
|
23
|
+
class GetDM(_GetDM, skip_checks=True):
    """Read support whose row-to-model conversion is delegated to ``cwtch.from_attributes``."""

    @classmethod
    def _from_db(
        cls,
        item: sa.Row | ModelDB,
        data: Optional[dict] = None,
        exclude: Optional[list] = None,
        suffix: Optional[str] = None,
        model_out=None,
    ):
        """Build an output model instance from a result row or ORM object.

        ``data``, ``exclude`` and ``suffix`` are forwarded unchanged to
        ``from_attributes``; ``model_out`` overrides ``cls.model`` as the target
        type. Circular-reference tracking is reset on every call.
        """
        target = cls.model if not model_out else model_out
        options = {"data": data, "exclude": exclude, "suffix": suffix, "reset_circular_refs": True}
        return from_attributes(target, item, **options)
|
41
|
+
|
42
|
+
|
43
|
+
class CreateDM(_CreateDM, skip_checks=True):
    """INSERT support that serialises models with ``cwtch.asdict``."""

    @classmethod
    def _get_values_for_create_query(cls, model) -> dict:
        """Return ``model`` as a dict restricted to ``cls._fields_create``."""
        included = typing.cast(list[str], cls._fields_create)
        return asdict(model, include=included)
|
47
|
+
|
48
|
+
|
49
|
+
class SaveDM(_SaveDM, skip_checks=True):
    """Upsert support that serialises models with ``cwtch.asdict``."""

    @classmethod
    def _get_values_for_save_query(cls, model) -> dict:
        """Dump ``model`` using the create field set for ``model_create`` instances, else the save set."""
        create_type = typing.cast(Type, cls.model_create)
        fields = cls._fields_create if isinstance(model, create_type) else cls._fields_save
        return asdict(model, include=typing.cast(list[str], fields))
|
55
|
+
|
56
|
+
|
57
|
+
class UpdateDM(_UpdateDM, skip_checks=True):
    """UPDATE support that serialises models with ``cwtch.asdict``."""

    @classmethod
    def _get_values_for_update_query(cls, model) -> dict:
        """Dump only the explicitly set fields of ``model`` that belong to ``cls._fields_update``."""
        included = typing.cast(list[str], cls._fields_update)
        return asdict(model, include=included, exclude_unset=True)
|
61
|
+
|
62
|
+
|
63
|
+
class DeleteDM(_DeleteDM, skip_checks=True):
    """DELETE support for the cwtch backend; all behaviour is inherited from ``_DeleteDM``."""

    ...
|
64
|
+
|
65
|
+
|
66
|
+
class DM(GetDM, CreateDM, SaveDM, UpdateDM, DeleteDM, skip_checks=True):
    """Combined cwtch-backed manager: get, create, save, update and delete via its bases."""

    ...
|
@@ -0,0 +1,76 @@
|
|
1
|
+
import typing
|
2
|
+
|
3
|
+
from typing import Optional, Type
|
4
|
+
|
5
|
+
import pydantic # noqa: F401 # type: ignore
|
6
|
+
import sqlalchemy as sa
|
7
|
+
|
8
|
+
from .common import ModelDB, OrderBy, _CreateDM, _DeleteDM, _GetDM, _SaveDM, _UpdateDM
|
9
|
+
|
10
|
+
|
11
|
+
# Public API of the pydantic-backed managers: the classes defined below plus the
# OrderBy helper re-exported from .common.
__all__ = [
    "OrderBy",
    "GetDM",
    "CreateDM",
    "SaveDM",
    "UpdateDM",
    "DeleteDM",
    "DM",
]
|
20
|
+
|
21
|
+
|
22
|
+
class _object:
    """Bare attribute holder whose ``__dict__`` is filled in and validated with ``from_attributes=True``."""
|
24
|
+
|
25
|
+
|
26
|
+
class GetDM(_GetDM, skip_checks=True):
    """Read support that validates DB rows / ORM objects into pydantic models."""

    @classmethod
    def _from_db(
        cls,
        item: sa.Row | ModelDB,
        data: Optional[dict] = None,
        exclude: Optional[list] = None,
        suffix: Optional[str] = None,
        model_out=None,
    ):
        """Convert ``item`` into an instance of ``model_out`` (default ``cls.model``).

        Args:
            item: a result row (read via ``_asdict()``) or an object (read via ``__dict__``).
            data: extra values merged over the item's own values.
            exclude: keys to drop (compared against the original, suffixed key).
            suffix: when given, keep only keys ending with ``suffix`` and strip that
                suffix exactly once from each kept key.
            model_out: pydantic model class to validate into.
        """
        if isinstance(item, sa.Row):
            data = {**item._asdict(), **(data or {})}
        else:
            data = {**item.__dict__, **(data or {})}

        obj = _object()

        if suffix:
            # removesuffix drops exactly one trailing ``suffix``; rstrip(suffix) would strip
            # any trailing run of the suffix's characters ("status_user".rstrip("_user") == "stat").
            obj.__dict__ = {
                k.removesuffix(suffix): v
                for k, v in data.items()
                if (not exclude or k not in exclude) and k.endswith(suffix)
            }
        else:
            obj.__dict__ = {k: v for k, v in data.items() if not exclude or k not in exclude}

        return (model_out or cls.model).model_validate(obj, from_attributes=True)
|
51
|
+
|
52
|
+
|
53
|
+
class CreateDM(_CreateDM, skip_checks=True):
    """INSERT support that serialises pydantic models with ``model_dump``."""

    @classmethod
    def _get_values_for_create_query(cls, model) -> dict:
        """Return ``model`` dumped to a dict restricted to ``cls._fields_create``."""
        included = typing.cast(list[str], cls._fields_create)
        return model.model_dump(include=included)
|
57
|
+
|
58
|
+
|
59
|
+
class SaveDM(_SaveDM, skip_checks=True):
    """Upsert support that serialises pydantic models with ``model_dump``."""

    @classmethod
    def _get_values_for_save_query(cls, model) -> dict:
        """Dump ``model`` using the create field set for ``model_create`` instances, else the save set."""
        create_type = typing.cast(Type, cls.model_create)
        fields = cls._fields_create if isinstance(model, create_type) else cls._fields_save
        return model.model_dump(include=typing.cast(list[str], fields))
|
65
|
+
|
66
|
+
|
67
|
+
class UpdateDM(_UpdateDM, skip_checks=True):
    """UPDATE support that serialises pydantic models with ``model_dump``."""

    @classmethod
    def _get_values_for_update_query(cls, model) -> dict:
        """Dump only the explicitly set fields of ``model`` that belong to ``cls._fields_update``."""
        included = typing.cast(list[str], cls._fields_update)
        return model.model_dump(include=included, exclude_unset=True)
|
71
|
+
|
72
|
+
|
73
|
+
class DeleteDM(_DeleteDM, skip_checks=True):
    """DELETE support for the pydantic backend; all behaviour is inherited from ``_DeleteDM``."""

    ...
|
74
|
+
|
75
|
+
|
76
|
+
class DM(GetDM, CreateDM, SaveDM, UpdateDM, DeleteDM, skip_checks=True):
    """Combined pydantic-backed manager: get, create, save, update and delete via its bases."""

    ...
|
@@ -0,0 +1,51 @@
|
|
1
|
+
# Poetry project metadata for the dbdm distribution.
[tool.poetry]
name = "dbdm"
version = "0.1.0"
# NOTE(review): description is empty; consider adding a one-line summary.
description = ""
authors = ["Roman Koshel <roma.koshel@gmail.com>"]
license = "MIT"
readme = "README.md"
keywords = []
packages = [{ include = "dbdm" }]
classifiers = [
    "Development Status :: 3 - Alpha",
    "Programming Language :: Python",
    "Programming Language :: Python :: 3 :: Only",
    "Topic :: Software Development :: Libraries :: Python Modules",
    "Topic :: Utilities",
    "License :: OSI Approved :: MIT License",
]

# Runtime dependencies installed with the package.
# NOTE(review): dbdm/cwtch.py imports cwtch and dbdm/pydantic.py imports pydantic, but
# neither is listed here (only in the test group below); users of those modules must
# install them separately - consider declaring them as optional extras.
[tool.poetry.dependencies]
python = ">=3.11, <4.0"
asyncpg = "*"
sqlalchemy = { version="^2", extras=["asyncio"] }

# Tooling used by the test suite.
[tool.poetry.group.test.dependencies]
coverage = "*"
cwtch = ">=0.10.0"
docker = "*"
pydantic = "^2"
pylint = "*"
pytest = "*"
pytest_asyncio = "*"
pytest_timeout = "*"

# Documentation build tooling.
[tool.poetry.group.docs.dependencies]
mkdocs = "*"
mkdocs-material = "*"
mkdocstrings = { extras = ["python"], version = "*" }

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.black]
line-length = 120

[tool.isort]
indent = 4
lines_after_imports = 2
lines_between_types = 1
src_paths = ["."]
profile = "black"
|