entari-plugin-database 0.3.0__tar.gz → 0.3.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: entari-plugin-database
3
- Version: 0.3.0
3
+ Version: 0.3.1
4
4
  Summary: Entari plugin for SQLAlchemy ORM
5
5
  Author-Email: RF-Tar-Railt <rf_tar_railt@qq.com>
6
6
  License: MIT
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "entari-plugin-database"
3
- version = "0.3.0"
3
+ version = "0.3.1"
4
4
  description = "Entari plugin for SQLAlchemy ORM"
5
5
  authors = [
6
6
  { name = "RF-Tar-Railt", email = "rf_tar_railt@qq.com" },
@@ -22,7 +22,7 @@ plugin.declare_static()
22
22
  plugin.metadata(
23
23
  "Database 服务",
24
24
  [{"name": "RF-Tar-Railt", "email": "rf_tar_railt@qq.com"}],
25
- "0.3.0",
25
+ "0.3.1",
26
26
  description="基于 SQLAlchemy 的数据库服务插件",
27
27
  urls={
28
28
  "homepage": "https://github.com/ArcletProject/entari-plugin-database",
@@ -2,6 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ import re
5
6
  import hashlib
6
7
  import json
7
8
  from collections.abc import Callable, Iterable
@@ -31,6 +32,9 @@ from sqlalchemy import (
31
32
  MetaData,
32
33
  PrimaryKeyConstraint,
33
34
  UniqueConstraint,
35
+ Column,
36
+ DefaultClause,
37
+ TextClause,
34
38
  )
35
39
  from sqlalchemy.schema import Table
36
40
 
@@ -202,6 +206,44 @@ def _serialize_constraint(const: Any) -> dict[str, Any]:
202
206
  return const_info
203
207
 
204
208
 
209
def _serialize_column(column: Column[Any]) -> str:
    """Build a reproducible string description of ``column`` for schema hashing.

    The result mimics a ``Column(...)`` constructor call. It contains the
    column name, its type, its foreign keys and column-level constraints, and
    the owning table. After those come only the keyword options that differ
    from the defaults: ``key``, ``primary_key``, ``nullable``, ``onupdate``,
    ``default``, ``server_default`` and ``comment``.

    Memory addresses of the form ``at 0x7f...`` that appear in object reprs
    (e.g. callable defaults) are stripped, so the string does not change
    between processes.

    Args:
        column: The SQLAlchemy column to describe.

    Returns:
        The serialized column description.
    """
    kwargs: dict[str, str] = {}
    if column.key != column.name:
        kwargs["key"] = repr(column.key)
    if column.primary_key:
        kwargs["primary_key"] = repr(column.primary_key)
    if not column.nullable:
        kwargs["nullable"] = repr(column.nullable)
    if column.onupdate is not None:
        kwargs["onupdate"] = repr(column.onupdate)
    if column.default is not None:
        kwargs["default"] = repr(column.default)
    server_default = column.server_default
    if server_default is not None:
        if isinstance(server_default, DefaultClause) and isinstance(server_default.arg, TextClause):
            # Render text() defaults by their SQL text rather than the object repr.
            kwargs["server_default"] = f"text({server_default.arg.text!r})"
        else:
            kwargs["server_default"] = repr(server_default)
    if column.comment:
        kwargs["comment"] = repr(column.comment)

    # ``foreign_keys`` and ``constraints`` are sets, whose iteration order is
    # not stable between runs; sort their string forms so the hash is reproducible.
    foreign_keys = sorted(repr(fk) for fk in column.foreign_keys if fk is not None)
    constraints = sorted(repr(_serialize_constraint(c)) for c in column.constraints)
    table = repr(column.table.description) if column.table is not None else "None"

    parts = [
        repr(column.name),
        repr(column.type),
        *foreign_keys,
        *constraints,
        f"table={table}",
        *(f"{k}={v}" for k, v in kwargs.items()),
    ]
    ans = "Column(" + ", ".join(parts) + ")"
    # Require whitespace around "at" so literal text such as "cat0x1F" inside a
    # default value is not mangled; "<function f at 0x...>" is still stripped.
    return re.sub(r"\s+at\s+0x[0-9a-fA-F]+", "", ans)
245
+
246
+
205
247
  def _get_table_structure(table: Table) -> dict[str, Any]:
206
248
  """将 SQLAlchemy Table 对象序列化为字典,用于后续哈希计算。"""
207
249
  # 按名称排序约束和索引以确保稳定性
@@ -214,8 +256,8 @@ def _get_table_structure(table: Table) -> dict[str, Any]:
214
256
  return {
215
257
  "name": table.name,
216
258
  "metadata": repr(table.metadata),
217
- "columns": [repr(col) for col in sorted(table.columns, key=lambda c: c.name)],
218
- "schema": f"schema={table.schema!r}",
259
+ "columns": [_serialize_column(col) for col in sorted(table.columns, key=lambda c: c.name)],
260
+ "schema": f"{table.schema!r}",
219
261
  "constraints": [_serialize_constraint(c) for c in sorted_constraints],
220
262
  "indexes": [repr(i) for i in sorted_indexes],
221
263
  }
@@ -269,6 +311,29 @@ def _include_tables_factory(target_tables: set[str]) -> Callable[[Any, str, str,
269
311
  return include
270
312
 
271
313
 
314
def _unwrap_default_clause(value: Any) -> Any:
    """Return the raw default wrapped by a SQLAlchemy ``DefaultClause``.

    If ``value`` is a ``DefaultClause``, its ``arg`` attribute is returned so
    Alembic/SQLAlchemy receive the underlying default value. Any other value
    is returned unchanged.
    """
    return value.arg if isinstance(value, DefaultClause) else value
319
+
320
+
321
def _normalize_alembic_ops(ops_list: Iterable[Any]) -> list[Any]:
    """Normalize Alembic operation objects before they are executed.

    The function recurses into ``ModifyTableOps`` containers. For every
    ``AlterColumnOp`` it passes ``modify_server_default`` and
    ``existing_server_default`` through ``_unwrap_default_clause``, which
    works around ``server_default`` type incompatibilities.

    The operation objects are updated in place. A new list holding the same
    objects, in their original order, is returned.
    """

    def _fix(operation: Any) -> Any:
        # Containers: normalize their children recursively.
        if isinstance(operation, alembic_ops.ModifyTableOps):
            operation.ops = _normalize_alembic_ops(operation.ops)
        # Column alterations: unwrap DefaultClause-wrapped server defaults.
        elif isinstance(operation, AlterColumnOp):
            operation.modify_server_default = _unwrap_default_clause(operation.modify_server_default)
            operation.existing_server_default = _unwrap_default_clause(operation.existing_server_default)
        return operation

    return [_fix(operation) for operation in ops_list]
335
+
336
+
272
337
  def _execute_script(
273
338
  sync_conn: Connection,
274
339
  table: str,
@@ -722,6 +787,7 @@ async def _execute_auto_migration(
722
787
  def _apply_ops_direct(op_runner: Operations, ops_list: Iterable[Any]) -> bool:
723
788
  """直接应用迁移操作(非 SQLite)。"""
724
789
  applied = False
790
+ ops_list = _normalize_alembic_ops(ops_list)
725
791
 
726
792
  def apply(ops: Iterable[Any]) -> None:
727
793
  nonlocal applied
@@ -742,6 +808,7 @@ def _apply_ops_sqlite_batch(
742
808
  target_tables: set[str],
743
809
  ) -> bool:
744
810
  """使用 batch 模式应用迁移操作(SQLite)。"""
811
+ ops_list = _normalize_alembic_ops(ops_list)
745
812
 
746
813
  def iter_ops(ops: Iterable[Any]):
747
814
  for op in ops:
@@ -47,11 +47,7 @@ class DatabasePropagator(Propagator):
47
47
  return True
48
48
  if any(isinstance(p.annotation, sa_async.AsyncSession) for p in params):
49
49
  return True
50
- if any(
51
- isinstance(prod, ORMProviderFactory._ModelProvider)
52
- for p in params
53
- for prod in p.providers
54
- ): # noqa: E501,UP038
50
+ if any(isinstance(prod, ORMProviderFactory._ModelProvider) for p in params for prod in p.providers): # noqa: E501,UP038
55
51
  return True
56
52
  return False
57
53