entari-plugin-database 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- entari_plugin_database-0.1.0/LICENSE +21 -0
- entari_plugin_database-0.1.0/PKG-INFO +16 -0
- entari_plugin_database-0.1.0/README.md +2 -0
- entari_plugin_database-0.1.0/pyproject.toml +33 -0
- entari_plugin_database-0.1.0/src/entari_plugin_database/__init__.py +154 -0
- entari_plugin_database-0.1.0/src/entari_plugin_database/param.py +244 -0
- entari_plugin_database-0.1.0/tests/__init__.py +0 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 ARCLET
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: entari-plugin-database
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Entari plugin for SQLAlchemy ORM
|
|
5
|
+
Author-Email: RF-Tar-Railt <rf_tar_railt@qq.com>
|
|
6
|
+
License: MIT
|
|
7
|
+
Requires-Python: >=3.9
|
|
8
|
+
Requires-Dist: sqlalchemy>=2.0.42
|
|
9
|
+
Requires-Dist: aiosqlite>=0.21.0
|
|
10
|
+
Requires-Dist: graia-amnesia>=0.10.1
|
|
11
|
+
Requires-Dist: tarina<0.8.0,>=0.7.1
|
|
12
|
+
Requires-Dist: arclet-entari>=0.15.0
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
|
|
15
|
+
# entari-plugin-database
|
|
16
|
+
Entari plugin for SQLAlchemy ORM
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "entari-plugin-database"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Entari plugin for SQLAlchemy ORM"
|
|
5
|
+
authors = [
|
|
6
|
+
{ name = "RF-Tar-Railt", email = "rf_tar_railt@qq.com" },
|
|
7
|
+
]
|
|
8
|
+
dependencies = [
|
|
9
|
+
"sqlalchemy>=2.0.42",
|
|
10
|
+
"aiosqlite>=0.21.0",
|
|
11
|
+
"graia-amnesia>=0.10.1",
|
|
12
|
+
"tarina<0.8.0,>=0.7.1",
|
|
13
|
+
"arclet-entari>=0.15.0",
|
|
14
|
+
]
|
|
15
|
+
requires-python = ">=3.9"
|
|
16
|
+
readme = "README.md"
|
|
17
|
+
|
|
18
|
+
[project.license]
|
|
19
|
+
text = "MIT"
|
|
20
|
+
|
|
21
|
+
[build-system]
|
|
22
|
+
requires = [
|
|
23
|
+
"pdm-backend",
|
|
24
|
+
]
|
|
25
|
+
build-backend = "pdm.backend"
|
|
26
|
+
|
|
27
|
+
[tool.pdm]
|
|
28
|
+
distribution = true
|
|
29
|
+
|
|
30
|
+
[tool.pdm.dev-dependencies]
|
|
31
|
+
dev = [
|
|
32
|
+
"arclet-entari[full]>=0.15.0",
|
|
33
|
+
]
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
from dataclasses import field
|
|
2
|
+
from typing import Optional, Any, Literal
|
|
3
|
+
from typing import Union
|
|
4
|
+
|
|
5
|
+
from sqlalchemy.ext.asyncio import create_async_engine
|
|
6
|
+
from arclet.letoderea.scope import global_providers, global_propagators
|
|
7
|
+
from arclet.entari import BasicConfModel, plugin, logger
|
|
8
|
+
from arclet.entari.config import config_model_validate
|
|
9
|
+
from arclet.entari.event.config import ConfigReload
|
|
10
|
+
from graia.amnesia.builtins.sqla import SqlalchemyService
|
|
11
|
+
from graia.amnesia.builtins.sqla.model import register_callback, remove_callback
|
|
12
|
+
from graia.amnesia.builtins.sqla.model import Base as Base
|
|
13
|
+
from graia.amnesia.builtins.sqla.types import EngineOptions
|
|
14
|
+
from sqlalchemy.engine.url import URL
|
|
15
|
+
from sqlalchemy.ext import asyncio as sa_async
|
|
16
|
+
from sqlalchemy.orm import Mapped as Mapped
|
|
17
|
+
from sqlalchemy.orm import mapped_column as mapped_column
|
|
18
|
+
|
|
19
|
+
from .param import db_supplier, sess_provider, orm_factory
|
|
20
|
+
from .param import SQLDepends as SQLDepends
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Connection settings for a single database; the `url` property turns them into a SQLAlchemy URL.
class UrlInfo(BasicConfModel):
    # Database backend name (first half of the drivername); defaults to "sqlite".
    type: str = "sqlite"
    """数据库类型,默认为 sqlite"""
    # Database name, or the database file path for SQLite.
    name: str = "data.db"
    """数据库名称/文件路径"""
    # DBAPI driver (second half of the drivername); defaults to "aiosqlite".
    driver: str = "aiosqlite"
    """数据库驱动,默认为 aiosqlite;其他类型的数据库驱动参考 SQLAlchemy 文档"""
    # Server host; not used for SQLite.
    host: Optional[str] = None
    """数据库主机地址。如果是 SQLite 数据库,此项可不填。"""
    # Server port; not used for SQLite.
    port: Optional[int] = None
    """数据库端口号。如果是 SQLite 数据库,此项可不填。"""
    # Login user name; not used for SQLite.
    username: Optional[str] = None
    """数据库用户名。如果是 SQLite 数据库,此项可不填。"""
    # Login password; not used for SQLite.
    password: Optional[str] = None
    """数据库密码。如果是 SQLite 数据库,此项可不填。"""
    # Extra URL query parameters, e.g. {"timeout": "30"}; empty by default.
    query: dict[str, Union[list[str], str]] = field(default_factory=dict)
    """数据库连接参数,默认为空字典。可以传入如 `{"timeout": "30"}` 的参数。"""

    @property
    def url(self) -> URL:
        """Build the SQLAlchemy ``URL`` described by these settings.

        The drivername is ``"<type>+<driver>"``. SQLite URLs carry only the
        database path and query parameters; every other backend additionally
        receives username, password, host and port.
        """
        drivername = f"{self.type}+{self.driver}"
        if self.type != "sqlite":
            return URL.create(
                drivername,
                username=self.username,
                password=self.password,
                host=self.host,
                port=self.port,
                database=self.name,
                query=self.query,
            )
        return URL.create(drivername, database=self.name, query=self.query)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# Plugin configuration: the primary database (fields inherited from UrlInfo) plus engine and
# session options, additional named binds, and the stage at which tables are created.
class Config(UrlInfo):
    # Keyword options for the async engine; passed to SqlalchemyService as its engine options.
    options: EngineOptions = field(default_factory=lambda: {"echo": None, "pool_pre_ping": True})
    """数据库连接选项,默认为 `{"echo": None, "pool_pre_ping": True}`"""
    # Session options (e.g. {"expire_on_commit": False}); None by default, passed to SqlalchemyService as-is.
    session_options: Union[dict[str, Any], None] = field(default=None)
    """数据库会话选项,默认为 None。可以传入如 `{"expire_on_commit": False}` 的字典。"""
    # Additional named databases; each value's `url` is handed to SqlalchemyService as a bind.
    # NOTE: the actual default is an empty dict, not None as the docstring below states.
    binds: dict[str, UrlInfo] = field(default_factory=dict)
    """数据库绑定配置,默认为 None。可以传入如 `{"bind1": UrlInfo(...), "bind2": UrlInfo(...)}` 的字典。"""
    # Stage at which tables are created: 'preparing' (default), 'prepared' or 'blocking'.
    create_table_at: Literal['preparing', 'prepared', 'blocking'] = "preparing"
    """在指定阶段创建数据库表,默认为 'preparing'。可选值为 'preparing', 'prepared', 'blocking'。"""
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
# Declare this plugin static (NOTE(review): presumably exempts it from normal unload/reload — confirm in entari docs).
plugin.declare_static()
# Register display name, authors, version, links and the config model used by plugin.get_config below.
plugin.metadata(
    "Database 服务",
    [{"name": "RF-Tar-Railt", "email": "rf_tar_railt@qq.com"}],
    "0.1.0",
    description="基于 SQLAlchemy 的数据库服务插件",
    urls={
        "homepage": "https://github.com/ArcletProject/entari-plugin-database",
    },
    config=Config,
)
# On dispose, remove the propagator and providers that `.param` appended to the global scope at import time.
plugin.collect_disposes(
    lambda: global_propagators.remove(db_supplier),
    lambda: global_providers.remove(sess_provider),
    lambda: global_providers.remove(orm_factory),
)
|
|
77
|
+
|
|
78
|
+
# Plugin logger tagged with "[Database]".
log = logger.log.wrapper("[Database]")
# This plugin's configuration, validated as Config.
_config = plugin.get_config(Config)

# Build the SqlalchemyService from the configuration and register it with the plugin.
# `service` is bound via the walrus so it can be re-exported at module level.
# Any failure is re-raised as RuntimeError, chained to the original exception.
try:
    plugin.add_service(
        service := SqlalchemyService(
            _config.url,
            _config.options,
            _config.session_options,
            # One URL per named bind, keyed by the bind name.
            {key: value.url for key, value in _config.binds.items()},
            _config.create_table_at
        )
    )
except Exception as e:
    raise RuntimeError("Failed to initialize SqlalchemyService. Please check your database configuration.") from e
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@plugin.listen(ConfigReload)
async def reload_config(event: ConfigReload, serv: SqlalchemyService):
    """Rebuild the service's engines after this plugin's configuration is reloaded.

    Ignores reloads that are not plugin-scoped or that target another plugin's
    key. Otherwise it disposes every existing engine and creates new primary
    and bind engines from the new config. It then updates the service's table-creation
    stage and session options, re-initializes the service, and creates the
    tables of all bound models (skipping existing ones).

    Returns:
        ``True`` after a reload was applied, ``None`` when the event was ignored.
    """
    # Only react to reloads of this plugin's own config section.
    if event.scope != "plugin" or event.key not in ("database", "entari_plugin_database"):
        return None
    new_conf = config_model_validate(Config, event.value)

    # Close every engine built from the previous configuration before replacing them.
    for old_engine in serv.engines.values():
        await old_engine.dispose(close=True)

    # NOTE(review): this fallback (used only when `options` is empty) differs from Config's
    # documented default ({"echo": None, ...}) — confirm that echo="debug" is intended.
    fallback_options = {"echo": "debug", "pool_pre_ping": True}
    engine_kwargs = new_conf.options or fallback_options

    serv.engines = {"": create_async_engine(new_conf.url, **engine_kwargs)}
    for bind_name, bind_info in (new_conf.binds or {}).items():
        serv.engines[bind_name] = create_async_engine(bind_info.url, **engine_kwargs)
    serv.create_table_at = new_conf.create_table_at
    serv.session_options = new_conf.session_options or {"expire_on_commit": False}

    bound_models = await serv.initialize()
    log.success("Database initialized!")
    for bind_name, models in bound_models.items():
        async with serv.engines[bind_name].begin() as conn:
            table_list = [model.__table__ for model in models]
            await conn.run_sync(serv.base_class.metadata.create_all, tables=table_list, checkfirst=True)
    log.success("Database tables created!")
    return True
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def _setup_tablename(cls: type[Base], kwargs: dict):
    """Assign a default ``__tablename__`` to ORM model classes that lack one.

    Registered through ``register_callback`` (NOTE(review): presumably invoked by
    graia-amnesia when a ``Base`` subclass is created — confirm).

    Leaves the class untouched when a ``tablename`` keyword was passed, or when
    it already has a truthy ``__tablename__`` or ``__table__`` (inherited ones
    included, since ``getattr`` is used). Otherwise the table name is the
    lowercased class name. It is prefixed with ``<plugin id>_`` (dashes replaced
    by underscores) when ``plugin.get_plugin(3)`` finds an owning plugin.

    Args:
        cls: The model class being set up.
        kwargs: Keyword arguments supplied for the class definition.
    """
    if "tablename" in kwargs:
        return
    for attr in ("__tablename__", "__table__"):
        if getattr(cls, attr, None):
            return

    cls.__tablename__ = cls.__name__.lower()

    # NOTE: 3 is presumably a call-stack depth used to locate the plugin whose module defines
    # the model; keep this call directly in this function, since adding frames would change it.
    if plg := plugin.get_plugin(3):
        cls.__tablename__ = f"{plg.id.replace('-', '_')}_{cls.__tablename__}"
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
# Install the default-tablename hook, and remove it again when the plugin is disposed.
register_callback(_setup_tablename)
plugin.collect_disposes(lambda: remove_callback(_setup_tablename))


# Convenience aliases re-exported for plugin users.
BaseOrm = Base
AsyncSession = sa_async.AsyncSession
# Bound `get_session` of the service instance constructed above.
get_session = service.get_session


# Public API of the plugin package.
__all__ = [
    "AsyncSession",
    "Base",
    "BaseOrm",
    "Mapped",
    "mapped_column",
    "service",
    "SQLDepends",
    "get_session",
    "SqlalchemyService"
]
|
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from inspect import Signature, Parameter
|
|
3
|
+
from operator import methodcaller
|
|
4
|
+
|
|
5
|
+
from collections.abc import Iterator, Sequence, AsyncIterator
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
from sqlalchemy import Row, Result, ScalarResult, select
|
|
9
|
+
from sqlalchemy.ext.asyncio import AsyncResult, AsyncScalarResult
|
|
10
|
+
from sqlalchemy.sql.selectable import ExecutableReturnsRows
|
|
11
|
+
from tarina import generic_issubclass
|
|
12
|
+
from tarina.generic import origin_is_union, isclass
|
|
13
|
+
from typing_extensions import Any, cast, get_args, get_origin
|
|
14
|
+
|
|
15
|
+
from creart import it
|
|
16
|
+
from launart import Launart
|
|
17
|
+
from graia.amnesia.builtins.sqla import SqlalchemyService, Base
|
|
18
|
+
|
|
19
|
+
from arclet.letoderea import Propagator, Contexts, STACK, Provider, ProviderFactory, Param, Depend, Subscriber
|
|
20
|
+
from arclet.letoderea.ref import Deref, generate
|
|
21
|
+
from arclet.letoderea.scope import global_propagators, global_providers
|
|
22
|
+
from sqlalchemy.ext import asyncio as sa_async
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class DatabasePropagator(Propagator):
    """Propagator that opens a database session each time it supplies a context.

    The session is entered on the context's exit stack and published under the
    ``"$db_session"`` key, where the providers in this module look for it.
    """

    async def supply(self, ctx: Contexts, serv: Optional[SqlalchemyService] = None):
        # No resolved SqlalchemyService means there is nothing to supply.
        if serv is not None:
            new_session = serv.get_session()
            exit_stack = ctx[STACK]
            entered = await exit_stack.enter_async_context(new_session)
            return {"$db_session": entered}
        return None

    def compose(self):
        # NOTE(review): the boolean and 20 are passed through to letoderea (presumably a
        # prepend flag and a priority) — confirm against letoderea's Propagator.compose.
        yield self.supply, True, 20
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class SessionProvider(Provider[sa_async.AsyncSession]):
    """Provide an ``AsyncSession`` for parameters annotated with it.

    Reuses the session stored under ``"$db_session"`` in the context when one
    exists. Otherwise it opens a session from the ``SqlalchemyService`` registered
    with Launart and enters it on the context's exit stack. The new session is
    cached in the context for later parameters. Returns ``None`` when the
    service lookup fails with ``ValueError``.
    """

    priority = 10

    async def __call__(self, context: Contexts):
        if "$db_session" in context:
            return context["$db_session"]
        # Only the service lookup is guarded; errors raised while opening the session
        # itself must propagate instead of being silently turned into "no session".
        try:
            db = it(Launart).get_component(SqlalchemyService)
        except ValueError:
            # NOTE(review): presumably raised when no SqlalchemyService is registered.
            return None
        stack = context[STACK]
        sess = await stack.enter_async_context(db.get_session())
        context["$db_session"] = sess
        return sess
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
@dataclass(unsafe_hash=True)
class Option:
    """How a statement is executed and how its result is post-processed.

    Attributes:
        stream: Execute with ``AsyncSession.stream`` instead of ``AsyncSession.execute``.
        scalars: Convert the result with ``.scalars()``.
        calls: Methods applied, in order, to the (scalar) result.
        result: Optional final method applied after ``calls``; on streamed results its
            return value is awaited.
    """

    stream: bool = True
    scalars: bool = False
    calls: tuple[methodcaller, ...] = field(default_factory=tuple)
    # ``Optional[...]`` instead of ``methodcaller | None``: dataclass annotations are evaluated when
    # the class is created, and ``type | None`` raises TypeError before Python 3.10 (package supports >=3.9).
    result: Optional[methodcaller] = None
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
# Annotation shape -> execution/post-processing Option.
# ``stream=True`` executes via ``session.stream``; ``calls`` are applied in order, then ``result``.
# Insertion order matters: ORMProviderFactory.validate uses the first entry that matches,
# so the catch-all ``Any`` entry must stay last.
PATTERNS = {
    # Partitioned batches from a streamed result.
    AsyncIterator[Sequence[Row[tuple[Any, ...]]]]: Option(
        stream=True, scalars=False, calls=(methodcaller("partitions"),)
    ),
    AsyncIterator[Sequence[tuple[Any, ...]]]: Option(
        stream=True, scalars=False, calls=(methodcaller("partitions"),)
    ),
    AsyncIterator[Sequence[Any]]: Option(
        stream=True, scalars=True, calls=(methodcaller("partitions"),)
    ),
    # Partitioned batches from a buffered result.
    Iterator[Sequence[Row[tuple[Any, ...]]]]: Option(
        stream=False, scalars=False, calls=(methodcaller("partitions"),)
    ),
    Iterator[Sequence[tuple[Any, ...]]]: Option(
        stream=False, scalars=False, calls=(methodcaller("partitions"),)
    ),
    Iterator[Sequence[Any]]: Option(
        stream=False, scalars=True, calls=(methodcaller("partitions"),)
    ),
    # Raw result objects.
    AsyncResult[tuple[Any, ...]]: Option(stream=True, scalars=False),
    AsyncScalarResult[Any]: Option(stream=True, scalars=True),
    Result[tuple[Any, ...]]: Option(stream=False, scalars=False),
    ScalarResult[Any]: Option(stream=False, scalars=True),
    # Row-by-row iteration over the result itself.
    AsyncIterator[Row[tuple[Any, ...]]]: Option(stream=True, scalars=False),
    Iterator[Row[tuple[Any, ...]]]: Option(stream=False, scalars=False),
    # Every row, via ``all()``.
    Sequence[Row[tuple[Any, ...]]]: Option(
        stream=True, scalars=False, calls=(), result=methodcaller("all")
    ),
    Sequence[tuple[Any, ...]]: Option(
        stream=True, scalars=False, calls=(), result=methodcaller("all")
    ),
    Sequence[Any]: Option(
        stream=True, scalars=True, calls=(), result=methodcaller("all")
    ),
    # At most one row / object, via ``one_or_none()``.
    tuple[Any, ...]: Option(
        stream=True, scalars=False, calls=(), result=methodcaller("one_or_none")
    ),
    Any: Option(
        stream=True, scalars=True, calls=(), result=methodcaller("one_or_none")
    ),
}
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class SQLDepend(Depend):
    """A ``Depend`` that executes a SQL statement and injects its post-processed result.

    Bind parameters of ``statement`` whose value is a ``Depend`` or ``Deref`` are
    exposed as keyword-only parameters of the generated target. The dependency
    system resolves them, and they are passed to the query as parameters. Other bind
    parameters are not exposed and not passed explicitly.

    Args:
        statement: The statement to execute.
        option: Execution/post-processing settings; a fresh ``Option()`` when omitted.
        cache: Forwarded to ``Depend``.
    """

    def __init__(self, statement: ExecutableReturnsRows, option: Optional[Option] = None, cache: bool = False):
        super().__init__(lambda : None, cache)
        self.statement = statement
        # Option is a mutable dataclass, so build one per instance instead of sharing a default.
        self.option = option if option is not None else Option()

        async def target(session: sa_async.AsyncSession, **params):
            opt = self.option
            if opt.stream:
                result = await session.stream(self.statement, params)
            else:
                result = await session.execute(self.statement, params)
            if opt.scalars:
                result = result.scalars()
            for call in opt.calls:
                result = call(result)
            if (call := opt.result) is not None:
                result = call(result)
                # Terminal methods of a streamed result (e.g. ``all()``, ``one_or_none()``) return
                # coroutines. The intermediate objects (AsyncResult, AsyncScalarResult, the async
                # iterator from ``partitions()``) are not awaitable, so only this value is awaited.
                if opt.stream:
                    result = await result
            return result

        # The session comes from SessionProvider; dependency-backed bind parameters are resolved by DI.
        parameters = [Parameter("session", Parameter.KEYWORD_ONLY, annotation=sa_async.AsyncSession)]
        for name, depends in self.statement.compile().params.items():
            if isinstance(depends, Depend):
                parameters.append(Parameter(name, Parameter.KEYWORD_ONLY, default=depends))
            elif isinstance(depends, Deref):
                parameters.append(Parameter(name, Parameter.KEYWORD_ONLY, default=Depend(generate(depends))))
        target.__signature__ = Signature(parameters)  # type: ignore
        self.target = target
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def SQLDepends(statement: ExecutableReturnsRows, option: Option = Option(), cache: bool = False) -> Any:
    """Create a ``SQLDepend`` for ``statement`` to use as a parameter default.

    The return type is ``Any`` (presumably so the call type-checks as the default
    of a parameter of any annotation).
    """
    dependency = SQLDepend(statement, option=option, cache=cache)
    return dependency
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
class ORMProviderFactory(ProviderFactory):
    """Provide ORM query results for parameters whose annotation names ``Base`` models.

    The annotation is matched against ``PATTERNS``. The matched model classes are
    selected with ``select(*models)`` and post-processed according to the pattern's
    ``Option``.
    """

    priority = 10

    class _ModelProvider(Provider[Any]):
        """Run a prepared statement on the context's session and return the processed result."""

        def __init__(self, statement: ExecutableReturnsRows, option: Option):
            super().__init__()
            self.statement = statement
            self.option = option

        async def __call__(self, context: Contexts):
            # The session is published as "$db_session" by DatabasePropagator / SessionProvider.
            if "$db_session" not in context:
                return
            sess: sa_async.AsyncSession = context["$db_session"]
            if self.option.stream:
                result = await sess.stream(self.statement)
            else:
                result = await sess.execute(self.statement)
            if self.option.scalars:
                result = result.scalars()
            for call in self.option.calls:
                result = call(result)
            if (call := self.option.result) is not None:
                result = call(result)
                # Only the terminal method of a streamed result (e.g. ``all()``) yields a coroutine;
                # AsyncResult, AsyncScalarResult and ``partitions()`` iterators are not awaitable.
                if self.option.stream:
                    result = await result
            return result

    def validate(self, param: Param):
        """Return a provider for ``param``, or ``None`` if this factory does not apply.

        It does not apply when the parameter's default is already a ``SQLDepend``, or
        when any matched type argument is not a ``Base`` subclass. For a union, the
        first ``Base`` subclass among its members is used.
        """
        if isinstance(param.default, SQLDepend):
            return
        # First matching pattern wins. NOTE(review): `models` is presumably the list of matched
        # type arguments returned by tarina.generic_issubclass(list_=True).
        for pattern, option in PATTERNS.items():
            if models := cast("list[Any]", generic_issubclass(pattern, param.annotation, list_=True)):
                break
        else:
            models, option = [], Option()

        for index, model in enumerate(models):
            if origin_is_union(get_origin(model)):
                models[index] = next(
                    (
                        arg
                        for arg in get_args(model)
                        if isclass(arg) and issubclass(arg, Base)
                    ),
                    None,
                )

            # Every candidate must be an ORM model class; otherwise give up on this parameter.
            if not (isclass(models[index]) and issubclass(models[index], Base)):
                models = []
                break
        if not models:
            return

        statement = select(*models)
        return self._ModelProvider(statement, option)
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
# Register the session propagator and the two providers globally at import time.
# The walrus-bound names let the plugin package remove exactly these instances on dispose.
global_propagators.append(db_supplier := DatabasePropagator())
global_providers.append(sess_provider := SessionProvider())
global_providers.append(orm_factory := ORMProviderFactory())
|
|
File without changes
|