entari-plugin-database 0.2.4__tar.gz → 0.3.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: entari-plugin-database
3
- Version: 0.2.4
3
+ Version: 0.3.1
4
4
  Summary: Entari plugin for SQLAlchemy ORM
5
5
  Author-Email: RF-Tar-Railt <rf_tar_railt@qq.com>
6
6
  License: MIT
@@ -11,7 +11,7 @@ Requires-Dist: graia-amnesia>=0.11.4
11
11
  Requires-Dist: tarina<0.8.0,>=0.7.1
12
12
  Requires-Dist: arclet-letoderea>=0.19.5
13
13
  Requires-Dist: alembic>=1.16.5
14
- Requires-Dist: arclet-entari<0.18.0,>=0.17.0
14
+ Requires-Dist: arclet-entari<0.19.0,>=0.18.0rc1
15
15
  Description-Content-Type: text/markdown
16
16
 
17
17
  # entari-plugin-database
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "entari-plugin-database"
3
- version = "0.2.4"
3
+ version = "0.3.1"
4
4
  description = "Entari plugin for SQLAlchemy ORM"
5
5
  authors = [
6
6
  { name = "RF-Tar-Railt", email = "rf_tar_railt@qq.com" },
@@ -12,7 +12,7 @@ dependencies = [
12
12
  "tarina<0.8.0,>=0.7.1",
13
13
  "arclet-letoderea>=0.19.5",
14
14
  "alembic>=1.16.5",
15
- "arclet-entari<0.18.0,>=0.17.0",
15
+ "arclet-entari<0.19.0,>=0.18.0rc1",
16
16
  ]
17
17
  requires-python = ">=3.10"
18
18
  readme = "README.md"
@@ -29,11 +29,6 @@ build-backend = "pdm.backend"
29
29
  [tool.pdm]
30
30
  distribution = true
31
31
 
32
- [tool.pdm.dev-dependencies]
33
- dev = [
34
- "arclet-entari[full]>=0.17.0",
35
- ]
36
-
37
32
  [tool.ruff]
38
33
  line-length = 120
39
34
  target-version = "py310"
@@ -61,3 +56,13 @@ ignore = [
61
56
  "PYI055",
62
57
  "UP038",
63
58
  ]
59
+
60
+ [tool.pyright]
61
+ pythonVersion = "3.10"
62
+ pythonPlatform = "All"
63
+ typeCheckingMode = "basic"
64
+
65
+ [dependency-groups]
66
+ dev = [
67
+ "arclet-entari[full]>=0.17.0",
68
+ ]
@@ -1,8 +1,6 @@
1
1
  from sqlalchemy.ext.asyncio import create_async_engine
2
- from arclet.letoderea.provider import global_providers
3
- from arclet.letoderea.scope import global_propagators
4
- from arclet.letoderea.core import add_task
5
- from arclet.entari import plugin
2
+ from arclet.letoderea.utils import add_task
3
+ from arclet.entari import Plugin, plugin
6
4
  from arclet.entari.config import config_model_validate
7
5
  from arclet.entari.event.config import ConfigReload
8
6
  from graia.amnesia.builtins.sqla import SqlalchemyService
@@ -14,7 +12,6 @@ from sqlalchemy.ext import asyncio as sa_async
14
12
  from sqlalchemy.orm import Mapped as Mapped, instrumentation
15
13
  from sqlalchemy.orm import mapped_column as mapped_column
16
14
 
17
- from .param import db_supplier, sess_provider, orm_factory
18
15
  from .param import SQLDepends as SQLDepends
19
16
  from .utils import logger
20
17
  from .migration import run_migration, register_custom_migration
@@ -25,7 +22,7 @@ plugin.declare_static()
25
22
  plugin.metadata(
26
23
  "Database 服务",
27
24
  [{"name": "RF-Tar-Railt", "email": "rf_tar_railt@qq.com"}],
28
- "0.2.4",
25
+ "0.3.1",
29
26
  description="基于 SQLAlchemy 的数据库服务插件",
30
27
  urls={
31
28
  "homepage": "https://github.com/ArcletProject/entari-plugin-database",
@@ -33,11 +30,6 @@ plugin.metadata(
33
30
  config=Config,
34
31
  readme="README.md",
35
32
  )
36
- plugin.collect_disposes(
37
- lambda: global_propagators.remove(db_supplier),
38
- lambda: global_providers.remove(sess_provider),
39
- lambda: global_providers.remove(orm_factory),
40
- )
41
33
 
42
34
  _config = plugin.get_config(Config)
43
35
 
@@ -133,12 +125,17 @@ def migration_callback(cls: type[Base], kwargs: dict):
133
125
  task.add_done_callback(_PENDING_TASKS.discard)
134
126
 
135
127
 
136
- register_callback(_setup_tablename)
137
- register_callback(_clean_exist)
138
- register_callback(migration_callback, after=True)
139
- plugin.collect_disposes(lambda: remove_callback(_clean_exist))
140
- plugin.collect_disposes(lambda: remove_callback(_setup_tablename))
141
- plugin.collect_disposes(lambda: remove_callback(migration_callback))
128
+ def _register_callbacks():
129
+ register_callback(_setup_tablename)
130
+ register_callback(_clean_exist)
131
+ register_callback(migration_callback, after=True)
132
+ yield lambda: remove_callback(_clean_exist)
133
+ yield lambda: remove_callback(_setup_tablename)
134
+ yield lambda: remove_callback(migration_callback)
135
+
136
+
137
+ plg = Plugin.current()
138
+ plg.effect(_register_callbacks, "database model callbacks")
142
139
 
143
140
 
144
141
  BaseOrm = Base
@@ -2,6 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ import re
5
6
  import hashlib
6
7
  import json
7
8
  from collections.abc import Callable, Iterable
@@ -31,6 +32,9 @@ from sqlalchemy import (
31
32
  MetaData,
32
33
  PrimaryKeyConstraint,
33
34
  UniqueConstraint,
35
+ Column,
36
+ DefaultClause,
37
+ TextClause,
34
38
  )
35
39
  from sqlalchemy.schema import Table
36
40
 
@@ -202,6 +206,44 @@ def _serialize_constraint(const: Any) -> dict[str, Any]:
202
206
  return const_info
203
207
 
204
208
 
209
+ def _serialize_column(column: Column[Any]) -> str:
210
+ kwargs = {}
211
+ if column.key != column.name:
212
+ kwargs["key"] = repr(column.key)
213
+ if column.primary_key:
214
+ kwargs["primary_key"] = repr(column.primary_key)
215
+ if not column.nullable:
216
+ kwargs["nullable"] = repr(column.nullable)
217
+ if column.onupdate:
218
+ kwargs["onupdate"] = repr(column.onupdate)
219
+ if column.default:
220
+ kwargs["default"] = repr(column.default)
221
+ if column.server_default:
222
+ if isinstance(column.server_default, DefaultClause) and isinstance(column.server_default.arg, TextClause):
223
+ kwargs["server_default"] = f"text({column.server_default.arg.text!r})"
224
+ else:
225
+ kwargs["server_default"] = repr(column.server_default)
226
+ if column.comment:
227
+ kwargs["comment"] = repr(column.comment)
228
+ ans = (
229
+ "Column("
230
+ + ", ".join(
231
+ [repr(column.name)]
232
+ + [repr(column.type)]
233
+ + [repr(x) for x in column.foreign_keys if x is not None]
234
+ + [repr(_serialize_constraint(x)) for x in column.constraints]
235
+ + [f"table={repr(column.table.description) if column.table is not None else 'None'}"]
236
+ + [f"{k}={v}" for k, v in kwargs.items()]
237
+ )
238
+ + ")"
239
+ )
240
+ return re.sub(
241
+ r"\s*at\s*0x[0-9a-fA-F]+",
242
+ "",
243
+ ans,
244
+ )
245
+
246
+
205
247
  def _get_table_structure(table: Table) -> dict[str, Any]:
206
248
  """将 SQLAlchemy Table 对象序列化为字典,用于后续哈希计算。"""
207
249
  # 按名称排序约束和索引以确保稳定性
@@ -214,8 +256,8 @@ def _get_table_structure(table: Table) -> dict[str, Any]:
214
256
  return {
215
257
  "name": table.name,
216
258
  "metadata": repr(table.metadata),
217
- "columns": [repr(col) for col in sorted(table.columns, key=lambda c: c.name)],
218
- "schema": f"schema={table.schema!r}",
259
+ "columns": [_serialize_column(col) for col in sorted(table.columns, key=lambda c: c.name)],
260
+ "schema": f"{table.schema!r}",
219
261
  "constraints": [_serialize_constraint(c) for c in sorted_constraints],
220
262
  "indexes": [repr(i) for i in sorted_indexes],
221
263
  }
@@ -269,6 +311,29 @@ def _include_tables_factory(target_tables: set[str]) -> Callable[[Any, str, str,
269
311
  return include
270
312
 
271
313
 
314
+ def _unwrap_default_clause(value: Any) -> Any:
315
+ """将 SQLAlchemy DefaultClause 还原为 Alembic/SQLAlchemy 期望的原始默认值。"""
316
+ if isinstance(value, DefaultClause):
317
+ return value.arg
318
+ return value
319
+
320
+
321
+ def _normalize_alembic_ops(ops_list: Iterable[Any]) -> list[Any]:
322
+ """在执行前归一化 Alembic 操作对象,修复 server_default 类型不兼容问题。"""
323
+ normalized: list[Any] = []
324
+
325
+ for op in ops_list:
326
+ if isinstance(op, alembic_ops.ModifyTableOps):
327
+ op.ops = _normalize_alembic_ops(op.ops)
328
+ elif isinstance(op, AlterColumnOp):
329
+ op.modify_server_default = _unwrap_default_clause(op.modify_server_default)
330
+ op.existing_server_default = _unwrap_default_clause(op.existing_server_default)
331
+
332
+ normalized.append(op)
333
+
334
+ return normalized
335
+
336
+
272
337
  def _execute_script(
273
338
  sync_conn: Connection,
274
339
  table: str,
@@ -722,6 +787,7 @@ async def _execute_auto_migration(
722
787
  def _apply_ops_direct(op_runner: Operations, ops_list: Iterable[Any]) -> bool:
723
788
  """直接应用迁移操作(非 SQLite)。"""
724
789
  applied = False
790
+ ops_list = _normalize_alembic_ops(ops_list)
725
791
 
726
792
  def apply(ops: Iterable[Any]) -> None:
727
793
  nonlocal applied
@@ -742,6 +808,7 @@ def _apply_ops_sqlite_batch(
742
808
  target_tables: set[str],
743
809
  ) -> bool:
744
810
  """使用 batch 模式应用迁移操作(SQLite)。"""
811
+ ops_list = _normalize_alembic_ops(ops_list)
745
812
 
746
813
  def iter_ops(ops: Iterable[Any]):
747
814
  for op in ops:
@@ -20,22 +20,40 @@ from arclet.letoderea import Propagator, Contexts, STACK, Provider, ProviderFact
20
20
  from arclet.letoderea.ref import Deref, generate
21
21
  from arclet.letoderea.provider import global_providers
22
22
  from arclet.letoderea.scope import global_propagators
23
+ from arclet.entari.plugin.model import Plugin
23
24
  from sqlalchemy.ext import asyncio as sa_async
24
25
 
25
26
 
27
+ class SessionProvider(Provider[sa_async.AsyncSession]):
28
+ priority = 10
29
+
30
+ async def __call__(self, context: Contexts):
31
+ if "$db_session" in context:
32
+ return context["$db_session"]
33
+ try:
34
+ db = it(Launart).get_component(SqlalchemyService)
35
+ stack = context[STACK]
36
+ sess = await stack.enter_async_context(db.get_session())
37
+ context["$db_session"] = sess
38
+ return sess
39
+ except ValueError:
40
+ return
41
+
42
+
26
43
  class DatabasePropagator(Propagator):
27
44
  def validate(self, subscriber: Subscriber):
28
45
  params = subscriber.params
29
46
  if any((p.depend and isinstance(p.depend, SQLDepend)) for p in params):
30
47
  return True
31
- if any(
32
- isinstance(prod, (SessionProvider, ORMProviderFactory._ModelProvider))
33
- for p in params
34
- for prod in p.providers
35
- ): # noqa: E501,UP038
48
+ if any(isinstance(p.annotation, sa_async.AsyncSession) for p in params):
49
+ return True
50
+ if any(isinstance(prod, ORMProviderFactory._ModelProvider) for p in params for prod in p.providers): # noqa: E501,UP038
36
51
  return True
37
52
  return False
38
53
 
54
+ def providers(self):
55
+ return [SessionProvider()]
56
+
39
57
  async def supply(self, ctx: Contexts, serv: SqlalchemyService | None = None):
40
58
  if serv is None:
41
59
  return
@@ -48,22 +66,6 @@ class DatabasePropagator(Propagator):
48
66
  yield self.supply, True, 20
49
67
 
50
68
 
51
- class SessionProvider(Provider[sa_async.AsyncSession]):
52
- priority = 10
53
-
54
- async def __call__(self, context: Contexts):
55
- if "$db_session" in context:
56
- return context["$db_session"]
57
- try:
58
- db = it(Launart).get_component(SqlalchemyService)
59
- stack = context[STACK]
60
- sess = await stack.enter_async_context(db.get_session())
61
- context["$db_session"] = sess
62
- return sess
63
- except ValueError:
64
- return
65
-
66
-
67
69
  @dataclass(unsafe_hash=True)
68
70
  class Option:
69
71
  stream: bool = True
@@ -254,6 +256,14 @@ class ORMProviderFactory(ProviderFactory):
254
256
  return self._ModelProvider(statement, option)
255
257
 
256
258
 
257
- global_propagators.append(db_supplier := DatabasePropagator())
258
- global_providers.append(sess_provider := SessionProvider())
259
- global_providers.append(orm_factory := ORMProviderFactory())
259
+ plg = Plugin.current()
260
+
261
+
262
+ def _provides():
263
+ global_propagators.append(db_supplier := DatabasePropagator())
264
+ global_providers.append(orm_factory := ORMProviderFactory())
265
+ yield lambda: global_propagators.remove(db_supplier)
266
+ yield lambda: global_providers.remove(orm_factory)
267
+
268
+
269
+ plg.effect(_provides, "database providers")