entari-plugin-database 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 ARCLET
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,16 @@
1
+ Metadata-Version: 2.1
2
+ Name: entari-plugin-database
3
+ Version: 0.1.0
4
+ Summary: Entari plugin for SQLAlchemy ORM
5
+ Author-Email: RF-Tar-Railt <rf_tar_railt@qq.com>
6
+ License: MIT
7
+ Requires-Python: >=3.9
8
+ Requires-Dist: sqlalchemy>=2.0.42
9
+ Requires-Dist: aiosqlite>=0.21.0
10
+ Requires-Dist: graia-amnesia>=0.10.1
11
+ Requires-Dist: tarina<0.8.0,>=0.7.1
12
+ Requires-Dist: arclet-entari>=0.15.0
13
+ Description-Content-Type: text/markdown
14
+
15
+ # entari-plugin-database
16
+ Entari plugin for SQLAlchemy ORM
@@ -0,0 +1,2 @@
1
+ # entari-plugin-database
2
+ Entari plugin for SQLAlchemy ORM
@@ -0,0 +1,33 @@
1
+ [project]
2
+ name = "entari-plugin-database"
3
+ version = "0.1.0"
4
+ description = "Entari plugin for SQLAlchemy ORM"
5
+ authors = [
6
+ { name = "RF-Tar-Railt", email = "rf_tar_railt@qq.com" },
7
+ ]
8
+ dependencies = [
9
+ "sqlalchemy>=2.0.42",
10
+ "aiosqlite>=0.21.0",
11
+ "graia-amnesia>=0.10.1",
12
+ "tarina<0.8.0,>=0.7.1",
13
+ "arclet-entari>=0.15.0",
14
+ ]
15
+ requires-python = ">=3.9"
16
+ readme = "README.md"
17
+
18
+ [project.license]
19
+ text = "MIT"
20
+
21
+ [build-system]
22
+ requires = [
23
+ "pdm-backend",
24
+ ]
25
+ build-backend = "pdm.backend"
26
+
27
+ [tool.pdm]
28
+ distribution = true
29
+
30
+ [tool.pdm.dev-dependencies]
31
+ dev = [
32
+ "arclet-entari[full]>=0.15.0",
33
+ ]
@@ -0,0 +1,154 @@
1
+ from dataclasses import field
2
+ from typing import Optional, Any, Literal
3
+ from typing import Union
4
+
5
+ from sqlalchemy.ext.asyncio import create_async_engine
6
+ from arclet.letoderea.scope import global_providers, global_propagators
7
+ from arclet.entari import BasicConfModel, plugin, logger
8
+ from arclet.entari.config import config_model_validate
9
+ from arclet.entari.event.config import ConfigReload
10
+ from graia.amnesia.builtins.sqla import SqlalchemyService
11
+ from graia.amnesia.builtins.sqla.model import register_callback, remove_callback
12
+ from graia.amnesia.builtins.sqla.model import Base as Base
13
+ from graia.amnesia.builtins.sqla.types import EngineOptions
14
+ from sqlalchemy.engine.url import URL
15
+ from sqlalchemy.ext import asyncio as sa_async
16
+ from sqlalchemy.orm import Mapped as Mapped
17
+ from sqlalchemy.orm import mapped_column as mapped_column
18
+
19
+ from .param import db_supplier, sess_provider, orm_factory
20
+ from .param import SQLDepends as SQLDepends
21
+
22
+
23
# Connection settings for a single database, convertible to a SQLAlchemy URL.
# The bare string under each field is the original field description and is
# kept verbatim; the `#` comment above each field gives the English meaning.
class UrlInfo(BasicConfModel):
    # Database type; defaults to sqlite.
    type: str = "sqlite"
    """数据库类型,默认为 sqlite"""
    # Database name, or the file path for SQLite.
    name: str = "data.db"
    """数据库名称/文件路径"""
    # Database driver; defaults to aiosqlite. See the SQLAlchemy docs for others.
    driver: str = "aiosqlite"
    """数据库驱动,默认为 aiosqlite;其他类型的数据库驱动参考 SQLAlchemy 文档"""
    # Database host. May be omitted for SQLite.
    host: Optional[str] = None
    """数据库主机地址。如果是 SQLite 数据库,此项可不填。"""
    # Database port. May be omitted for SQLite.
    port: Optional[int] = None
    """数据库端口号。如果是 SQLite 数据库,此项可不填。"""
    # Database user name. May be omitted for SQLite.
    username: Optional[str] = None
    """数据库用户名。如果是 SQLite 数据库,此项可不填。"""
    # Database password. May be omitted for SQLite.
    password: Optional[str] = None
    """数据库密码。如果是 SQLite 数据库,此项可不填。"""
    # Extra connection parameters, e.g. {"timeout": "30"}; empty by default.
    query: dict[str, Union[list[str], str]] = field(default_factory=dict)
    """数据库连接参数,默认为空字典。可以传入如 `{"timeout": "30"}` 的参数。"""

    @property
    def url(self) -> URL:
        """Return the SQLAlchemy ``URL`` described by this entry.

        The driver name is always ``"<type>+<driver>"``. For ``sqlite`` only
        the database name and query are passed on; every other type also
        receives username, password, host and port.
        """
        drivername = f"{self.type}+{self.driver}"
        if self.type != "sqlite":
            return URL.create(
                drivername,
                username=self.username,
                password=self.password,
                host=self.host,
                port=self.port,
                database=self.name,
                query=self.query,
            )
        return URL.create(drivername, database=self.name, query=self.query)
48
+
49
+
50
# Plugin configuration: the default database is described by the fields
# inherited from UrlInfo; the fields below add engine/session options, extra
# named binds, and the stage at which tables are created.
class Config(UrlInfo):
    # Engine options; defaults to {"echo": None, "pool_pre_ping": True}.
    options: EngineOptions = field(default_factory=lambda: {"echo": None, "pool_pre_ping": True})
    """数据库连接选项,默认为 `{"echo": None, "pool_pre_ping": True}`"""
    # Session options; defaults to None. Accepts e.g. {"expire_on_commit": False}.
    session_options: Union[dict[str, Any], None] = field(default=None)
    """数据库会话选项,默认为 None。可以传入如 `{"expire_on_commit": False}` 的字典。"""
    # Additional named databases, e.g. {"bind1": UrlInfo(...), "bind2": UrlInfo(...)}.
    # NOTE(review): the description string below says the default is None, but
    # the default_factory actually yields an empty dict.
    binds: dict[str, UrlInfo] = field(default_factory=dict)
    """数据库绑定配置,默认为 None。可以传入如 `{"bind1": UrlInfo(...), "bind2": UrlInfo(...)}` 的字典。"""
    # Stage at which tables are created: 'preparing' (default), 'prepared' or 'blocking'.
    create_table_at: Literal['preparing', 'prepared', 'blocking'] = "preparing"
    """在指定阶段创建数据库表,默认为 'preparing'。可选值为 'preparing', 'prepared', 'blocking'。"""
59
+
60
+
61
# Module-level plugin setup. These statements run at import time, in order.
# NOTE(review): declare_static() presumably marks the plugin as static
# (not hot-reloadable) — confirm against the entari documentation.
plugin.declare_static()
# Register plugin metadata (display name, authors, version, description,
# homepage) and declare Config as this plugin's configuration model.
plugin.metadata(
    "Database 服务",
    [{"name": "RF-Tar-Railt", "email": "rf_tar_railt@qq.com"}],
    "0.1.0",
    description="基于 SQLAlchemy 的数据库服务插件",
    urls={
        "homepage": "https://github.com/ArcletProject/entari-plugin-database",
    },
    config=Config,
)
# On disposal, remove the propagator and the two providers (imported from
# .param) from the global letoderea registries.
plugin.collect_disposes(
    lambda: global_propagators.remove(db_supplier),
    lambda: global_providers.remove(sess_provider),
    lambda: global_providers.remove(orm_factory),
)

# Logger wrapper that tags this plugin's messages with "[Database]".
log = logger.log.wrapper("[Database]")
# The plugin's configuration, parsed into a Config instance.
_config = plugin.get_config(Config)

# Create the SQLAlchemy service from the configuration and register it with
# the plugin. The walrus binds `service` at module level; it is used further
# down for `get_session` and exported through `__all__`.
try:
    plugin.add_service(
        service := SqlalchemyService(
            _config.url,
            _config.options,
            _config.session_options,
            # Each bind's UrlInfo is converted to its URL, keyed by bind name.
            {key: value.url for key, value in _config.binds.items()},
            _config.create_table_at
        )
    )
except Exception as e:
    # Any failure is re-raised as a RuntimeError pointing at the database
    # configuration, with the original exception chained as the cause.
    raise RuntimeError("Failed to initialize SqlalchemyService. Please check your database configuration.") from e
93
+
94
+
95
@plugin.listen(ConfigReload)
async def reload_config(event: ConfigReload, serv: SqlalchemyService):
    """Rebuild the database engines when this plugin's configuration is reloaded.

    Events whose scope is not ``"plugin"``, or whose key is neither
    ``"database"`` nor ``"entari_plugin_database"``, are ignored and the
    handler returns ``None``. Otherwise the existing engines are disposed,
    new engines are created from the new configuration, the service is
    re-initialized, tables are created, and ``True`` is returned.
    """
    if event.scope != "plugin":
        return None
    if event.key not in ("database", "entari_plugin_database"):
        return None
    # Validate the raw reloaded value into a Config instance.
    new_conf = config_model_validate(Config, event.value)
    # Dispose every engine currently held by the service before replacing them.
    for engine in serv.engines.values():
        await engine.dispose(close=True)
    # Fallback engine options, used only when new_conf.options is falsy.
    engine_options = {"echo": "debug", "pool_pre_ping": True}
    # The default engine is stored under the empty-string key.
    serv.engines = {"": create_async_engine(new_conf.url, **(new_conf.options or engine_options))}
    # One engine per named bind; all binds share the same engine options.
    for key, bind in (new_conf.binds or {}).items():
        serv.engines[key] = create_async_engine(bind.url, **(new_conf.options or engine_options))
    serv.create_table_at = new_conf.create_table_at
    # Session options fall back to {"expire_on_commit": False} when unset/empty.
    serv.session_options = new_conf.session_options or {"expire_on_commit": False}

    # initialize() returns a mapping from engine key to the models bound to it.
    binds = await serv.initialize()
    log.success("Database initialized!")
    # Create the tables of each bind on its own engine; checkfirst=True is
    # passed through to create_all.
    for key, models in binds.items():
        async with serv.engines[key].begin() as conn:
            await conn.run_sync(
                serv.base_class.metadata.create_all, tables=[m.__table__ for m in models], checkfirst=True
            )
    log.success("Database tables created!")
    return True
120
+
121
+
122
def _setup_tablename(cls: type[Base], kwargs: dict):
    """Assign a default ``__tablename__`` to a model class that lacks one.

    Nothing is changed when ``kwargs`` contains ``"tablename"`` or when the
    class already has a truthy ``__tablename__`` or ``__table__``. Otherwise
    the table name becomes the lower-cased class name, prefixed with the
    owning plugin's id (dashes replaced by underscores) when a plugin is found.
    """
    already_named = "tablename" in kwargs or any(
        getattr(cls, attr, None) for attr in ("__tablename__", "__table__")
    )
    if already_named:
        return

    default_name = cls.__name__.lower()
    cls.__tablename__ = default_name

    # NOTE(review): the argument 3 is presumably a call-stack depth used to
    # locate the plugin that defines the model — keep this call directly in
    # this function so the depth stays valid.
    plg = plugin.get_plugin(3)
    if plg:
        prefix = plg.id.replace("-", "_")
        cls.__tablename__ = f"{prefix}_{default_name}"
133
+
134
+
135
# Register the table-name hook with the model module, and remove it again
# when the plugin is disposed.
register_callback(_setup_tablename)
plugin.collect_disposes(lambda: remove_callback(_setup_tablename))


# Convenience aliases exposed by this plugin.
BaseOrm = Base  # alternative name for the declarative base
AsyncSession = sa_async.AsyncSession  # re-export of SQLAlchemy's async session class
get_session = service.get_session  # bound to the service created at import time


# Public API of the plugin package.
__all__ = [
    "AsyncSession",
    "Base",
    "BaseOrm",
    "Mapped",
    "mapped_column",
    "service",
    "SQLDepends",
    "get_session",
    "SqlalchemyService"
]
@@ -0,0 +1,244 @@
1
+ from dataclasses import dataclass, field
2
+ from inspect import Signature, Parameter
3
+ from operator import methodcaller
4
+
5
+ from collections.abc import Iterator, Sequence, AsyncIterator
6
+ from typing import Optional
7
+
8
+ from sqlalchemy import Row, Result, ScalarResult, select
9
+ from sqlalchemy.ext.asyncio import AsyncResult, AsyncScalarResult
10
+ from sqlalchemy.sql.selectable import ExecutableReturnsRows
11
+ from tarina import generic_issubclass
12
+ from tarina.generic import origin_is_union, isclass
13
+ from typing_extensions import Any, cast, get_args, get_origin
14
+
15
+ from creart import it
16
+ from launart import Launart
17
+ from graia.amnesia.builtins.sqla import SqlalchemyService, Base
18
+
19
+ from arclet.letoderea import Propagator, Contexts, STACK, Provider, ProviderFactory, Param, Depend, Subscriber
20
+ from arclet.letoderea.ref import Deref, generate
21
+ from arclet.letoderea.scope import global_propagators, global_providers
22
+ from sqlalchemy.ext import asyncio as sa_async
23
+
24
+
25
class DatabasePropagator(Propagator):
    """Propagator that opens a database session and publishes it in the context."""

    async def supply(self, ctx: Contexts, serv: Optional[SqlalchemyService] = None):
        """Open a session on ``serv`` and return it under ``"$db_session"``.

        Returns ``None`` when no service was injected. The session is entered
        on the exit stack stored at ``ctx[STACK]``, so closing is left to
        that stack.
        """
        if serv is not None:
            pending = serv.get_session()
            opened = await ctx[STACK].enter_async_context(pending)
            return {"$db_session": opened}
        return None

    def compose(self):
        """Yield ``supply`` together with the values ``True`` and ``20``.

        NOTE(review): the two trailing values are presumably the "prepend"
        flag and the priority expected by letoderea — confirm.
        """
        yield (self.supply, True, 20)
+
37
+
38
class SessionProvider(Provider[sa_async.AsyncSession]):
    """Provider that supplies an ``AsyncSession`` to handler parameters."""

    priority = 10

    async def __call__(self, context: Contexts):
        """Return the context's session, opening one if none exists yet.

        A session already stored under ``"$db_session"`` is reused. Otherwise
        the ``SqlalchemyService`` component is looked up through Launart, a
        new session is entered on ``context[STACK]``, cached in the context
        and returned. A ``ValueError`` raised along the way yields ``None``.
        """
        if "$db_session" not in context:
            try:
                database = it(Launart).get_component(SqlalchemyService)
                exit_stack = context[STACK]
                opened = await exit_stack.enter_async_context(database.get_session())
                context["$db_session"] = opened
                return opened
            except ValueError:
                return None
        return context["$db_session"]
52
+
53
+
54
@dataclass(unsafe_hash=True)
class Option:
    """Describes how a statement is executed and how its result is post-processed.

    Instances are hashable (``unsafe_hash=True``) and compare by field values.
    """

    # True: run the statement with ``session.stream``; False: ``session.execute``.
    stream: bool = True
    # True: call ``.scalars()`` on the result before any further processing.
    scalars: bool = False
    # Callers applied to the result in order (e.g. ``methodcaller("partitions")``).
    calls: tuple[methodcaller, ...] = field(default_factory=tuple)
    # Final caller applied last (e.g. ``methodcaller("all")``); None to skip.
    # ``Optional[...]`` is used instead of ``methodcaller | None`` because the
    # annotation is evaluated when the class is created and the ``|`` form
    # raises TypeError on Python 3.9, which this package declares as supported.
    result: Optional[methodcaller] = None
60
+
61
+
62
# Maps a parameter annotation shape to the Option describing how the query
# result must be produced for it. ORMProviderFactory.validate walks this dict
# in insertion order and stops at the first match, so the more specific
# shapes must stay ahead of the more general ones.
PATTERNS = {
    # Partitioned results, async (streamed).
    AsyncIterator[Sequence[Row[tuple[Any, ...]]]]: Option(
        stream=True, scalars=False, calls=(methodcaller("partitions"),)
    ),
    AsyncIterator[Sequence[tuple[Any, ...]]]: Option(
        stream=True, scalars=False, calls=(methodcaller("partitions"),)
    ),
    AsyncIterator[Sequence[Any]]: Option(
        stream=True, scalars=True, calls=(methodcaller("partitions"),)
    ),
    # Partitioned results, sync (executed).
    Iterator[Sequence[Row[tuple[Any, ...]]]]: Option(
        stream=False, scalars=False, calls=(methodcaller("partitions"),)
    ),
    Iterator[Sequence[tuple[Any, ...]]]: Option(
        stream=False, scalars=False, calls=(methodcaller("partitions"),)
    ),
    Iterator[Sequence[Any]]: Option(
        stream=False, scalars=True, calls=(methodcaller("partitions"),)
    ),
    # Raw result objects.
    AsyncResult[tuple[Any, ...]]: Option(stream=True, scalars=False),
    AsyncScalarResult[Any]: Option(stream=True, scalars=True),
    Result[tuple[Any, ...]]: Option(stream=False, scalars=False),
    ScalarResult[Any]: Option(stream=False, scalars=True),
    # Row iterators.
    AsyncIterator[Row[tuple[Any, ...]]]: Option(stream=True, scalars=False),
    Iterator[Row[tuple[Any, ...]]]: Option(stream=False, scalars=False),
    # Fully fetched sequences.
    Sequence[Row[tuple[Any, ...]]]: Option(
        stream=True, scalars=False, calls=(), result=methodcaller("all")
    ),
    Sequence[tuple[Any, ...]]: Option(
        stream=True, scalars=False, calls=(), result=methodcaller("all")
    ),
    Sequence[Any]: Option(
        stream=True, scalars=True, calls=(), result=methodcaller("all")
    ),
    # Single row / single scalar, or None.
    tuple[Any, ...]: Option(
        stream=True, scalars=False, calls=(), result=methodcaller("one_or_none")
    ),
    Any: Option(
        stream=True, scalars=True, calls=(), result=methodcaller("one_or_none")
    ),
}
148
+
149
+
150
class SQLDepend(Depend):
    """Dependency that executes a SQL statement and injects the processed result.

    The statement's bound parameters whose values are ``Depend`` or ``Deref``
    objects become keyword-only parameters of the generated target, so they
    are resolved by the dependency system before the statement runs.
    """

    def __init__(self, statement: ExecutableReturnsRows, option: Option = Option(), cache: bool = False):
        """Store the statement/option and build the injectable ``target``.

        Args:
            statement: The statement to execute.
            option: How to execute the statement and post-process its result.
            cache: Forwarded to ``Depend.__init__``.
        """
        # A placeholder callable is handed to the base class; the real target
        # is assigned at the end of this method.
        super().__init__(lambda: None, cache)
        self.statement = statement
        self.option = option

        async def target(session: sa_async.AsyncSession, **params):
            # Streamed or plain execution, depending on the option.
            run = session.stream if self.option.stream else session.execute
            result = await run(self.statement, params)
            if self.option.scalars:
                result = result.scalars()
            for step in self.option.calls:
                result = step(result)
            finisher = self.option.result
            if finisher:
                result = finisher(result)
                # On a streamed result the final call is awaited.
                if self.option.stream:
                    result = await result
            return result

        parameters = [Parameter("session", Parameter.KEYWORD_ONLY, annotation=sa_async.AsyncSession)]
        for name, bound in self.statement.compile().params.items():
            if isinstance(bound, Depend):
                default = bound
            elif isinstance(bound, Deref):
                default = Depend(generate(bound))
            else:
                # Ordinary bound values are not exposed as parameters.
                continue
            parameters.append(Parameter(name, Parameter.KEYWORD_ONLY, default=default))
        # Expose the synthesized signature so the injector sees the parameters.
        target.__signature__ = Signature(parameters)  # type: ignore
        self.target = target
179
+
180
+
181
def SQLDepends(statement: ExecutableReturnsRows, option: Option = Option(), cache: bool = False) -> Any:
    """Build a ``SQLDepend`` for ``statement``.

    The return type is declared as ``Any`` so the call can be used as a
    parameter default regardless of the parameter's annotation.
    """
    depend = SQLDepend(statement=statement, option=option, cache=cache)
    return depend
183
+
184
+
185
class ORMProviderFactory(ProviderFactory):
    """Creates providers that fill model-annotated parameters with query results."""

    priority = 10

    class _ModelProvider(Provider[Any]):
        """Provider that runs a fixed statement on the context's session."""

        def __init__(self, statement: ExecutableReturnsRows, option: Option):
            super().__init__()
            self.statement = statement
            self.option = option

        async def __call__(self, context: Contexts):
            """Execute the statement and return the processed result.

            Returns ``None`` when the context holds no ``"$db_session"``.
            """
            if "$db_session" not in context:
                return None
            sess: sa_async.AsyncSession = context["$db_session"]
            # Streamed or plain execution, depending on the option.
            run = sess.stream if self.option.stream else sess.execute
            result = await run(self.statement)
            if self.option.scalars:
                result = result.scalars()
            for step in self.option.calls:
                result = step(result)
            finisher = self.option.result
            if finisher:
                result = finisher(result)
                # On a streamed result the final call is awaited.
                if self.option.stream:
                    result = await result
            return result

    def validate(self, param: Param):
        """Return a provider for ``param`` when its annotation names ORM models.

        Parameters whose default is a ``SQLDepend`` are skipped. The
        annotation is matched against ``PATTERNS`` in order; every matched
        type (or, for a union, its first ``Base`` subclass member) must be a
        ``Base`` subclass, otherwise ``None`` is returned.
        """
        if isinstance(param.default, SQLDepend):
            return None

        models, option = [], Option()
        for pattern, candidate in PATTERNS.items():
            matched = cast("list[Any]", generic_issubclass(pattern, param.annotation, list_=True))
            if matched:
                models, option = matched, candidate
                break

        for index, model in enumerate(models):
            if origin_is_union(get_origin(model)):
                # Pick the first union member that is a model class, if any.
                model = next(
                    (arg for arg in get_args(model) if isclass(arg) and issubclass(arg, Base)),
                    None,
                )
                models[index] = model
            if not isclass(model) or not issubclass(model, Base):
                return None

        if not models:
            return None
        return self._ModelProvider(select(*models), option)
240
+
241
+
242
+ global_propagators.append(db_supplier := DatabasePropagator())
243
+ global_providers.append(sess_provider := SessionProvider())
244
+ global_providers.append(orm_factory := ORMProviderFactory())
File without changes