lush-sqlalchemyx 0.3.0__tar.gz → 0.3.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/PKG-INFO +1 -1
  2. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/pyproject.toml +6 -4
  3. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/base/dal/__init__.py +1 -33
  4. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/base/dal/_async.py +32 -130
  5. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/base/dal/_common.py +5 -2
  6. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/base/dal/_pagination.py +2 -2
  7. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/base/dal/_repository.py +36 -24
  8. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/base/dal/_sync.py +32 -130
  9. lush_sqlalchemyx-0.3.0/src/lush_sqlalchemyx/base/dal/_async_v2.py +0 -184
  10. lush_sqlalchemyx-0.3.0/src/lush_sqlalchemyx/base/dal/_params.py +0 -32
  11. lush_sqlalchemyx-0.3.0/src/lush_sqlalchemyx/base/dal/_sync_v2.py +0 -183
  12. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/README.md +0 -0
  13. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/__init__.py +0 -0
  14. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/_compat.py +0 -0
  15. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/base/__init__.py +0 -0
  16. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/integrations/__init__.py +0 -0
  17. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/integrations/fastapi/__init__.py +0 -0
  18. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/integrations/fastapi/depends/__init__.py +0 -0
  19. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/integrations/flask/__init__.py +0 -0
  20. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/integrations/flask/ext.py +0 -0
  21. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/mgrs/__init__.py +0 -0
  22. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/mgrs/mysql/__init__.py +0 -0
  23. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/mgrs/mysql/manager.py +0 -0
  24. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/mgrs/mysql/mapper.py +0 -0
  25. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/mgrs/mysql/sync_manager.py +0 -0
  26. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/mgrs/mysql/sync_mapper.py +0 -0
  27. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/py.typed +0 -0
  28. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/same_impl_just_warn_wrapper.py +0 -0
  29. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/shortcuts/__init__.py +0 -0
  30. {lush_sqlalchemyx-0.3.0 → lush_sqlalchemyx-0.3.2}/src/lush_sqlalchemyx/shortcuts/meta.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: lush-sqlalchemyx
3
- Version: 0.3.0
3
+ Version: 0.3.2
4
4
  Summary: SQLAlchemy helpers (DAL) and async MySQL managers, with some web frameworks integrations
5
5
  Author: straydragon
6
6
  Author-email: straydragon <straydragonl@foxmail.com>
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "lush-sqlalchemyx"
3
- version = "0.3.0"
3
+ version = "0.3.2"
4
4
  description = "SQLAlchemy helpers (DAL) and async MySQL managers, with some web frameworks integrations"
5
5
  readme = "README.md"
6
6
  requires-python = ">=3.10"
@@ -149,9 +149,6 @@ ignore = [
149
149
  ]
150
150
 
151
151
  [tool.ruff.lint.per-file-ignores]
152
- "**/dal/_*_v2.py" = [
153
- "ARG003", # extra 参数是 ABC 合规所必需但在部分方法中未使用
154
- ]
155
152
  "tests/**/*.py" = [
156
153
  # === 测试代码质量 ===
157
154
  "B011", # assert False应改为raise AssertionError
@@ -304,3 +301,8 @@ reportUnknownArgumentType = false
304
301
 
305
302
  [[tool.basedpyright.executionEnvironments]]
306
303
  root = "src"
304
+
305
+ [[tool.basedpyright.executionEnvironments]]
306
+ root = "examples"
307
+ # examples 作为下游类型检查 SSOT, 仅继承全局设置, 不做额外放宽.
308
+ # 如果此处报错, 说明库的类型变更破坏了合法下游用法.
@@ -6,21 +6,14 @@
6
6
  # --- shared (sync/async agnostic) ---
7
7
  # --- lush-dal-protocol ABCs (ORM 无关的抽象层) ---
8
8
  from lush_dal_protocol import (
9
- AbstractAsyncAdvancedWriteDAL,
10
9
  AbstractAsyncBaseDAL,
11
- AbstractAsyncBatchFieldDAL,
12
- AbstractAsyncLockDAL,
13
- AbstractAsyncRawSQLDAL,
14
10
  AbstractAsyncReadDAL,
15
11
  AbstractAsyncWriteDAL,
16
- AbstractSyncAdvancedWriteDAL,
17
12
  AbstractSyncBaseDAL,
18
- AbstractSyncBatchFieldDAL,
19
- AbstractSyncLockDAL,
20
- AbstractSyncRawSQLDAL,
21
13
  AbstractSyncReadDAL,
22
14
  AbstractSyncWriteDAL,
23
15
  )
16
+ from lush_dal_protocol.params.pagination import CursorPagination, CursorResult, OffsetPagination, PageResult
24
17
 
25
18
  # --- async (requires sqlalchemy[asyncio]) ---
26
19
  from ._async import (
@@ -40,9 +33,6 @@ from ._async import (
40
33
  async_temp_set_lock_wait_timeout,
41
34
  async_with_retry,
42
35
  )
43
-
44
- # --- V2 (ABC-compliant, options-based) ---
45
- from ._async_v2 import AsyncBaseDALV2, AsyncReadDALV2, AsyncWriteDALV2
46
36
  from ._common import (
47
37
  DEFAULT_RETRY_CONFIG,
48
38
  OPTIMISTIC_LOCK_ERROR_MSG_TRAIT,
@@ -68,10 +58,6 @@ from ._common import (
68
58
  from ._common import __prevent_readonly_write as __prevent_readonly_write # pyright: ignore[reportPrivateUsage]
69
59
  from ._common import __receive_before_flush as __receive_before_flush # pyright: ignore[reportPrivateUsage]
70
60
  from ._pagination import (
71
- CursorPagination,
72
- CursorResult,
73
- OffsetPagination,
74
- PageResult,
75
61
  build_cursor_stmt,
76
62
  build_offset_stmt,
77
63
  decode_cursor,
@@ -79,7 +65,6 @@ from ._pagination import (
79
65
  make_cursor_result,
80
66
  make_page_result,
81
67
  )
82
- from ._params import SQLAExtra
83
68
  from ._repository import AsyncSQLAlchemyRepository, SyncSQLAlchemyRepository
84
69
 
85
70
  # --- sync ---
@@ -100,7 +85,6 @@ from ._sync import (
100
85
  sync_temp_set_lock_wait_timeout,
101
86
  sync_with_retry,
102
87
  )
103
- from ._sync_v2 import SyncBaseDALV2, SyncReadDALV2, SyncWriteDALV2
104
88
 
105
89
  __all__ = (
106
90
  # common
@@ -165,30 +149,14 @@ __all__ = (
165
149
  "encode_cursor",
166
150
  "make_cursor_result",
167
151
  "make_page_result",
168
- # V2
169
- "AsyncBaseDALV2",
170
- "AsyncReadDALV2",
171
- "AsyncWriteDALV2",
172
- "SQLAExtra",
173
- "SyncBaseDALV2",
174
- "SyncReadDALV2",
175
- "SyncWriteDALV2",
176
152
  # repository
177
153
  "AsyncSQLAlchemyRepository",
178
154
  "SyncSQLAlchemyRepository",
179
155
  # lush-dal-protocol ABCs
180
- "AbstractAsyncAdvancedWriteDAL",
181
156
  "AbstractAsyncBaseDAL",
182
- "AbstractAsyncBatchFieldDAL",
183
- "AbstractAsyncLockDAL",
184
- "AbstractAsyncRawSQLDAL",
185
157
  "AbstractAsyncReadDAL",
186
158
  "AbstractAsyncWriteDAL",
187
- "AbstractSyncAdvancedWriteDAL",
188
159
  "AbstractSyncBaseDAL",
189
- "AbstractSyncBatchFieldDAL",
190
- "AbstractSyncLockDAL",
191
- "AbstractSyncRawSQLDAL",
192
160
  "AbstractSyncReadDAL",
193
161
  "AbstractSyncWriteDAL",
194
162
  )
@@ -15,6 +15,7 @@ from contextlib import asynccontextmanager, suppress
15
15
  from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, ParamSpec, TypeVar, cast
16
16
 
17
17
  import sqlalchemy as sa
18
+ from lush_dal_protocol.abc import AbstractAsyncReadDAL, AbstractAsyncWriteDAL
18
19
  from pydantic import BaseModel
19
20
  from sqlalchemy import ColumnExpressionArgument
20
21
  from sqlalchemy.ext.asyncio import AsyncAttrs, AsyncSession
@@ -121,7 +122,8 @@ async def async_temp_set_lock_wait_timeout(
121
122
  # Async Table bases
122
123
  # ---------------------------------------------------------------------------
123
124
 
124
- AsyncSQLATableT = TypeVar("AsyncSQLATableT", bound="AsyncSqlATableBase")
125
+ # 隐含约束: AsyncSQLATableT 应为 AsyncSqlATableBase (AsyncAttrs + DeclarativeBase) 子类.
126
+ AsyncSQLATableT = TypeVar("AsyncSQLATableT")
125
127
 
126
128
 
127
129
  class AsyncSqlATableBase(AsyncAttrs, DeclarativeBase):
@@ -315,7 +317,11 @@ class AsyncRawReadDAL:
315
317
  last_id = getattr(batch[-1], id_attr.key)
316
318
 
317
319
 
318
- class AsyncReadDAL(AsyncRawReadDAL, Generic[AsyncSQLATableT, DTOModelT]):
320
+ class AsyncReadDAL(
321
+ AsyncRawReadDAL,
322
+ AbstractAsyncReadDAL[AsyncSession, AsyncSQLATableT, DTOModelT, int],
323
+ Generic[AsyncSQLATableT, DTOModelT],
324
+ ):
319
325
  """抽象只读数据访问层基类."""
320
326
 
321
327
  _Table: ClassVar[type[AsyncSQLATableT]] # pyright: ignore[reportGeneralTypeIssues]
@@ -425,15 +431,15 @@ class AsyncReadDAL(AsyncRawReadDAL, Generic[AsyncSQLATableT, DTOModelT]):
425
431
  return entity is not None
426
432
 
427
433
  @classmethod
428
- async def _get_by_id_for_update_core(
434
+ async def get_by_id_for_update(
429
435
  cls,
430
436
  session: AsyncSession,
431
437
  entity_id: int,
432
438
  *,
433
- timeout: int | None = None,
439
+ lock_wait_timeout: int | None = None,
434
440
  ) -> AsyncSQLATableT | None:
435
441
  try:
436
- async with async_temp_set_lock_wait_timeout(session, timeout):
442
+ async with async_temp_set_lock_wait_timeout(session, lock_wait_timeout):
437
443
  stmt = (
438
444
  sa.select(cls._Table)
439
445
  .where(cls._Table.id == entity_id) # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue, reportUnknownArgumentType]
@@ -448,29 +454,19 @@ class AsyncReadDAL(AsyncRawReadDAL, Generic[AsyncSQLATableT, DTOModelT]):
448
454
  raise
449
455
 
450
456
  @classmethod
451
- async def get_by_id_for_update(
452
- cls,
453
- session: AsyncSession,
454
- entity_id: int,
455
- *,
456
- lock_wait_timeout: int | None = None,
457
- ) -> AsyncSQLATableT | None:
458
- return await cls._get_by_id_for_update_core(session, entity_id, timeout=lock_wait_timeout)
459
-
460
- @classmethod
461
- async def _batch_get_for_update_core(
457
+ async def batch_get_for_update(
462
458
  cls,
463
459
  session: AsyncSession,
464
460
  entity_ids: Iterable[int],
465
461
  *,
466
- timeout: int | None = None,
462
+ lock_wait_timeout: int | None = None,
467
463
  ) -> list[AsyncSQLATableT]:
468
464
  filtered_ids = filtered_in_sql_values(entity_ids, int)
469
465
  if not filtered_ids:
470
466
  return []
471
467
 
472
468
  try:
473
- async with async_temp_set_lock_wait_timeout(session, timeout):
469
+ async with async_temp_set_lock_wait_timeout(session, lock_wait_timeout):
474
470
  stmt = (
475
471
  sa.select(cls._Table)
476
472
  .where(cls._Table.id.in_(filtered_ids)) # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue, reportUnknownArgumentType]
@@ -485,25 +481,15 @@ class AsyncReadDAL(AsyncRawReadDAL, Generic[AsyncSQLATableT, DTOModelT]):
485
481
  raise
486
482
 
487
483
  @classmethod
488
- async def batch_get_for_update(
489
- cls,
490
- session: AsyncSession,
491
- entity_ids: Iterable[int],
492
- *,
493
- lock_wait_timeout: int | None = None,
494
- ) -> list[AsyncSQLATableT]:
495
- return await cls._batch_get_for_update_core(session, entity_ids, timeout=lock_wait_timeout)
496
-
497
- @classmethod
498
- async def _get_one_for_update_core(
484
+ async def get_one_for_update(
499
485
  cls,
500
486
  session: AsyncSession,
501
487
  *,
502
488
  where_clauses: list[ColumnExpressionArgument[bool]],
503
- timeout: int | None = None,
489
+ lock_wait_timeout: int | None = None,
504
490
  ) -> AsyncSQLATableT | None:
505
491
  try:
506
- async with async_temp_set_lock_wait_timeout(session, timeout):
492
+ async with async_temp_set_lock_wait_timeout(session, lock_wait_timeout):
507
493
  stmt = sa.select(cls._Table).with_for_update()
508
494
 
509
495
  for clause in where_clauses:
@@ -517,16 +503,6 @@ class AsyncReadDAL(AsyncRawReadDAL, Generic[AsyncSQLATableT, DTOModelT]):
517
503
  raise DBRetryableError(f"{PESSIMISTIC_LOCK_ERROR_MSG_TRAIT}-条件锁等待超时: {error_msg}") from e
518
504
  raise
519
505
 
520
- @classmethod
521
- async def get_one_for_update(
522
- cls,
523
- session: AsyncSession,
524
- *,
525
- where_clauses: list[ColumnExpressionArgument[bool]],
526
- lock_wait_timeout: int | None = None,
527
- ) -> AsyncSQLATableT | None:
528
- return await cls._get_one_for_update_core(session, where_clauses=where_clauses, timeout=lock_wait_timeout)
529
-
530
506
  @classmethod
531
507
  async def iter_record_dtos(
532
508
  cls,
@@ -563,7 +539,12 @@ class AsyncRawDAL:
563
539
  return await session.execute(stmt, params)
564
540
 
565
541
 
566
- class AsyncWriteDAL(AsyncRawDAL, AsyncRawReadDAL, Generic[AsyncSQLATableT, DTOModelT, CUModelT]):
542
+ class AsyncWriteDAL(
543
+ AsyncRawDAL,
544
+ AsyncRawReadDAL,
545
+ AbstractAsyncWriteDAL[AsyncSession, AsyncSQLATableT, DTOModelT, CUModelT, int],
546
+ Generic[AsyncSQLATableT, DTOModelT, CUModelT],
547
+ ):
567
548
  """写入数据访问层基类."""
568
549
 
569
550
  _Table: ClassVar[type[AsyncSQLATableT]] # pyright: ignore[reportGeneralTypeIssues]
@@ -643,7 +624,7 @@ class AsyncWriteDAL(AsyncRawDAL, AsyncRawReadDAL, Generic[AsyncSQLATableT, DTOMo
643
624
  _ensure_strict_fields(provided_keys=provided_keys, allowed_names=allowed_names, strict=strict)
644
625
 
645
626
  @classmethod
646
- async def _update_full_by_id_core(
627
+ async def update_full_by_id(
647
628
  cls,
648
629
  session: AsyncSession,
649
630
  entity_id: int,
@@ -677,19 +658,7 @@ class AsyncWriteDAL(AsyncRawDAL, AsyncRawReadDAL, Generic[AsyncSQLATableT, DTOMo
677
658
  return entity
678
659
 
679
660
  @classmethod
680
- async def update_full_by_id(
681
- cls,
682
- session: AsyncSession,
683
- entity_id: int,
684
- cu: CUModelT,
685
- *,
686
- need_refresh: bool = False,
687
- strict_missing: bool = True,
688
- ) -> AsyncSQLATableT | None:
689
- return await cls._update_full_by_id_core(session, entity_id, cu, need_refresh=need_refresh, strict_missing=strict_missing)
690
-
691
- @classmethod
692
- async def _update_partial_by_id_core(
661
+ async def update_partial_by_id(
693
662
  cls,
694
663
  session: AsyncSession,
695
664
  entity_id: int,
@@ -756,30 +725,6 @@ class AsyncWriteDAL(AsyncRawDAL, AsyncRawReadDAL, Generic[AsyncSQLATableT, DTOMo
756
725
  await session.refresh(entity)
757
726
  return entity
758
727
 
759
- @classmethod
760
- async def update_partial_by_id(
761
- cls,
762
- session: AsyncSession,
763
- entity_id: int,
764
- cu: CUModelT,
765
- *,
766
- need_refresh: bool = False,
767
- fields: set[InstrumentedAttribute[Any]] | set[sa.Column[Any]] | None = None,
768
- none_policy: Literal["ignore", "allow", "forbid"] = "ignore",
769
- none_policy_overrides: dict[InstrumentedAttribute[Any] | sa.Column[Any], Literal["ignore", "allow", "forbid"]] | None = None,
770
- strict: bool = False,
771
- ) -> AsyncSQLATableT | None:
772
- return await cls._update_partial_by_id_core(
773
- session,
774
- entity_id,
775
- cu,
776
- need_refresh=need_refresh,
777
- fields=fields,
778
- none_policy=none_policy,
779
- none_policy_overrides=none_policy_overrides,
780
- strict=strict,
781
- )
782
-
783
728
  @classmethod
784
729
  async def delete_by_id(
785
730
  cls,
@@ -815,11 +760,11 @@ class AsyncWriteDAL(AsyncRawDAL, AsyncRawReadDAL, Generic[AsyncSQLATableT, DTOMo
815
760
  yield entity
816
761
 
817
762
  @classmethod
818
- async def _batch_update_by_conditions_core(
763
+ async def batch_update_by_conditions(
819
764
  cls,
820
765
  session: AsyncSession,
821
766
  *,
822
- conditions: list[ColumnExpressionArgument[bool]],
767
+ whereclause: list[ColumnExpressionArgument[bool]],
823
768
  update_data: dict[InstrumentedAttribute[Any], Any] | dict[sa.Column[Any], Any],
824
769
  updater_id: int | None = None,
825
770
  ) -> int:
@@ -844,24 +789,13 @@ class AsyncWriteDAL(AsyncRawDAL, AsyncRawReadDAL, Generic[AsyncSQLATableT, DTOMo
844
789
  if hasattr(cls._Table, "update_operator_id") and updater_id is not None:
845
790
  final_update_data["update_operator_id"] = updater_id
846
791
 
847
- stmt = sa.update(cls._Table).where(*conditions).values(**final_update_data)
792
+ stmt = sa.update(cls._Table).where(*whereclause).values(**final_update_data)
848
793
 
849
794
  result = await session.execute(stmt)
850
795
  await session.flush()
851
796
 
852
797
  return result.rowcount # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType, reportUnknownVariableType]
853
798
 
854
- @classmethod
855
- async def batch_update_by_conditions(
856
- cls,
857
- session: AsyncSession,
858
- *,
859
- whereclause: list[ColumnExpressionArgument[bool]],
860
- update_data: dict[InstrumentedAttribute[Any], Any] | dict[sa.Column[Any], Any],
861
- updater_id: int | None = None,
862
- ) -> int:
863
- return await cls._batch_update_by_conditions_core(session, conditions=whereclause, update_data=update_data, updater_id=updater_id)
864
-
865
799
  @classmethod
866
800
  async def batch_update_by_ids(
867
801
  cls,
@@ -875,15 +809,15 @@ class AsyncWriteDAL(AsyncRawDAL, AsyncRawReadDAL, Generic[AsyncSQLATableT, DTOMo
875
809
  if not filtered_ids:
876
810
  return 0
877
811
  _id_column = cls._Table.id # pyright: ignore[reportAttributeAccessIssue,reportUnknownVariableType, reportUnknownMemberType]
878
- return await cls._batch_update_by_conditions_core(
812
+ return await cls.batch_update_by_conditions(
879
813
  session,
880
- conditions=[_id_column.in_(filtered_ids)], # pyright: ignore[reportUnknownMemberType]
814
+ whereclause=[_id_column.in_(filtered_ids)], # pyright: ignore[reportUnknownMemberType]
881
815
  update_data=update_data,
882
816
  updater_id=updater_id,
883
817
  )
884
818
 
885
819
  @classmethod
886
- async def _update_only_set_with_optimistic_lock_core(
820
+ async def update_only_set_with_optimistic_lock(
887
821
  cls,
888
822
  session: AsyncSession,
889
823
  entity_id: int,
@@ -934,45 +868,13 @@ class AsyncWriteDAL(AsyncRawDAL, AsyncRawReadDAL, Generic[AsyncSQLATableT, DTOMo
934
868
 
935
869
  raise DBRetryableError(f"{OPTIMISTIC_LOCK_ERROR_MSG_TRAIT}-版本号不匹配({entity_id=}, {expected_version=})")
936
870
 
937
- @classmethod
938
- async def update_only_set_with_optimistic_lock(
939
- cls,
940
- session: AsyncSession,
941
- entity_id: int,
942
- cu: CUModelT,
943
- *,
944
- expected_version: int,
945
- need_refresh: bool = False,
946
- version_field: str = "version",
947
- ) -> AsyncSQLATableT | None:
948
- return await cls._update_only_set_with_optimistic_lock_core(
949
- session,
950
- entity_id,
951
- cu,
952
- expected_version=expected_version,
953
- need_refresh=need_refresh,
954
- version_field=version_field,
955
- )
956
-
957
871
 
958
872
  class AsyncXDALOp(AsyncRawReadDAL, AsyncRawDAL):
959
873
  """扩展数据访问操作类."""
960
874
 
961
875
 
962
876
  class AsyncBaseDAL(AsyncReadDAL[AsyncSQLATableT, DTOModelT], AsyncWriteDAL[AsyncSQLATableT, DTOModelT, CUModelT]):
963
- """基础数据访问层.
964
-
965
- .. deprecated:: 0.3.0
966
- V1 DAL 将在 1.0 移除, 请迁移至 ``AsyncBaseDALV2``.
967
- """
968
-
969
- def __init_subclass__(cls, **kwargs: Any) -> None:
970
- super().__init_subclass__(**kwargs)
971
- warnings.warn(
972
- f"{cls.__name__} 继承了 V1 AsyncBaseDAL, 建议迁移至 AsyncBaseDALV2",
973
- DeprecationWarning,
974
- stacklevel=2,
975
- )
877
+ """基础数据访问层."""
976
878
 
977
879
 
978
880
  class ReadOnlyAsyncBaseDAL(AsyncReadDAL[AsyncSQLATableT, ReadOnlyDTOModelT]):
@@ -21,7 +21,7 @@ from lush_dal_protocol.dto import DTOModelT as DTOModelT # noqa: PLC0414
21
21
  from pydantic import BaseModel, ConfigDict, Field
22
22
  from sqlalchemy import event as sa_event
23
23
  from sqlalchemy.exc import OperationalError as SQLAlchemyOperationalError
24
- from sqlalchemy.orm import DeclarativeBase, Mapped, ORMExecuteState, mapped_column, with_loader_criteria
24
+ from sqlalchemy.orm import Mapped, ORMExecuteState, mapped_column, with_loader_criteria
25
25
  from sqlalchemy.orm import Session as SyncSession
26
26
 
27
27
  READONLY_SESSION_FLAG: Final[str] = "__lush_sqlalchemyx__readonly_session__"
@@ -139,7 +139,10 @@ DEFAULT_RETRY_CONFIG = RetryConfig(max_attempts=3, initial_delay=0.1, max_delay=
139
139
  # Pydantic CU / DTO models
140
140
  # ---------------------------------------------------------------------------
141
141
 
142
- SQLATableT = TypeVar("SQLATableT", bound=DeclarativeBase)
142
+ # 隐含约束: SQLATableT 应为 SQLAlchemy DeclarativeBase 子类 (含 Flask-SQLAlchemy db.Model).
143
+ # 不设 bound 是因为 Flask-SQLAlchemy 的 db.Model 运行时继承 DeclarativeBase,
144
+ # 但静态类型系统看不到该链路, bound 会误拦合法下游.
145
+ SQLATableT = TypeVar("SQLATableT")
143
146
 
144
147
 
145
148
  class BaseCU(_ProtocolBaseCU[SQLATableT]):
@@ -42,7 +42,7 @@ def build_offset_stmt(
42
42
  if order_by is not None:
43
43
  stmt = stmt.order_by(order_by)
44
44
  else:
45
- stmt = stmt.order_by(table.id) # pyright: ignore[reportAttributeAccessIssue]
45
+ stmt = stmt.order_by(table.id)
46
46
  return stmt.offset(p.skip).limit(p.limit)
47
47
 
48
48
 
@@ -55,7 +55,7 @@ def build_cursor_stmt(
55
55
  使用 id > cursor_value 的 keyset 分页方式.
56
56
  """
57
57
  p = pagination or CursorPagination()
58
- id_col = table.id # pyright: ignore[reportAttributeAccessIssue]
58
+ id_col = table.id
59
59
  stmt = sa.select(table).order_by(id_col)
60
60
 
61
61
  if p.cursor is not None:
@@ -1,15 +1,21 @@
1
1
  """SQLAlchemy 具体 Repository 实现.
2
2
 
3
- 提供高层声明式 CRUD 接口, 内部委托给 DAL V2.
3
+ 提供高层声明式 CRUD 接口, 内部委托给 DAL.
4
4
  """
5
5
 
6
6
  from __future__ import annotations
7
7
 
8
- from collections.abc import Callable, Iterable
8
+ from collections.abc import Callable, Generator, Iterable
9
9
  from contextlib import contextmanager
10
- from typing import Any, ClassVar, Generic, TypeVar
10
+ from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeVar, cast
11
11
 
12
12
  import sqlalchemy as sa
13
+
14
+ if TYPE_CHECKING:
15
+ from contextlib import AbstractAsyncContextManager
16
+
17
+ from sqlalchemy.ext.asyncio import AsyncSession
18
+ from sqlalchemy.orm import Session
13
19
  from lush_dal_protocol.params.pagination import CursorPagination, CursorResult, OffsetPagination, PageResult
14
20
  from lush_dal_protocol.repository import AbstractAsyncRepository, AbstractSyncRepository
15
21
 
@@ -31,15 +37,21 @@ class SyncSQLAlchemyRepository(
31
37
  _session_factory: 返回 Session 的工厂函数
32
38
  """
33
39
 
34
- _Table: ClassVar[type]
35
- _DTO: ClassVar[type]
36
- _session_factory: ClassVar[Callable[[], Any]]
40
+ _Table: ClassVar[type[TableT]] # pyright: ignore[reportGeneralTypeIssues] — ClassVar + TypeVar 是 pyright 已知限制
41
+ _DTO: ClassVar[type[DTOModelT]] # pyright: ignore[reportGeneralTypeIssues]
42
+ _session_factory: ClassVar[Callable[[], Session]]
37
43
 
38
44
  @classmethod
39
- @contextmanager
40
- def _get_session(cls) -> Any:
41
- session = cls._session_factory() # pyright: ignore[reportAttributeAccessIssue]
45
+ def _make_session(cls) -> Session:
46
+ """创建并配置一个新的 Session."""
47
+ session = cls._session_factory() # pyright: ignore[reportAttributeAccessIssue] — pyright generic classmethod limitation
42
48
  session.expire_on_commit = False
49
+ return session
50
+
51
+ @classmethod
52
+ @contextmanager
53
+ def _get_session(cls) -> Generator[Session, None, None]:
54
+ session = cls._make_session()
43
55
  try:
44
56
  yield session
45
57
  session.commit()
@@ -150,10 +162,10 @@ class SyncSQLAlchemyRepository(
150
162
  if not pk_list:
151
163
  return 0
152
164
  with cls._get_session() as session:
153
- id_col = cls._Table.id # pyright: ignore[reportAttributeAccessIssue]
165
+ id_col: sa.Column[int] = cast("sa.Column[int]", cls._Table.id) # pyright: ignore[reportAttributeAccessIssue] — SA 列描述符对 pyright 不可见
154
166
  stmt = sa.update(cls._Table).where(id_col.in_(pk_list)).values(**data)
155
- result = session.execute(stmt)
156
- return result.rowcount # pyright: ignore[reportReturnType]
167
+ result = cast("sa.CursorResult[Any]", session.execute(stmt)) # SA stubs 返回 Result, 但运行时为 CursorResult
168
+ return result.rowcount
157
169
 
158
170
  @classmethod
159
171
  def bulk_delete(cls, pks: Iterable[int]) -> int:
@@ -162,10 +174,10 @@ class SyncSQLAlchemyRepository(
162
174
  if not pk_list:
163
175
  return 0
164
176
  with cls._get_session() as session:
165
- id_col = cls._Table.id # pyright: ignore[reportAttributeAccessIssue]
177
+ id_col: sa.Column[int] = cast("sa.Column[int]", cls._Table.id) # pyright: ignore[reportAttributeAccessIssue] — SA 列描述符对 pyright 不可见
166
178
  stmt = sa.delete(cls._Table).where(id_col.in_(pk_list))
167
- result = session.execute(stmt)
168
- return result.rowcount # pyright: ignore[reportReturnType]
179
+ result = cast("sa.CursorResult[Any]", session.execute(stmt))
180
+ return result.rowcount
169
181
 
170
182
 
171
183
  class AsyncSQLAlchemyRepository(
@@ -180,9 +192,9 @@ class AsyncSQLAlchemyRepository(
180
192
  _session_factory: 返回 async context manager 的工厂
181
193
  """
182
194
 
183
- _Table: ClassVar[type]
184
- _DTO: ClassVar[type]
185
- _session_factory: ClassVar[Callable[..., Any]]
195
+ _Table: ClassVar[type[TableT]] # pyright: ignore[reportGeneralTypeIssues]
196
+ _DTO: ClassVar[type[DTOModelT]] # pyright: ignore[reportGeneralTypeIssues]
197
+ _session_factory: ClassVar[Callable[..., AbstractAsyncContextManager[AsyncSession]]]
186
198
 
187
199
  @classmethod
188
200
  async def get(cls, pk: int) -> TableT | None:
@@ -291,11 +303,11 @@ class AsyncSQLAlchemyRepository(
291
303
  if not pk_list:
292
304
  return 0
293
305
  async with cls._session_factory() as session:
294
- id_col = cls._Table.id # pyright: ignore[reportAttributeAccessIssue]
306
+ id_col: sa.Column[int] = cast("sa.Column[int]", cls._Table.id) # pyright: ignore[reportAttributeAccessIssue] — SA 列描述符对 pyright 不可见
295
307
  stmt = sa.update(cls._Table).where(id_col.in_(pk_list)).values(**data)
296
- result = await session.execute(stmt)
308
+ result = cast("sa.CursorResult[Any]", await session.execute(stmt))
297
309
  await session.commit()
298
- return result.rowcount # pyright: ignore[reportReturnType]
310
+ return result.rowcount
299
311
 
300
312
  @classmethod
301
313
  async def bulk_delete(cls, pks: Iterable[int]) -> int:
@@ -304,8 +316,8 @@ class AsyncSQLAlchemyRepository(
304
316
  if not pk_list:
305
317
  return 0
306
318
  async with cls._session_factory() as session:
307
- id_col = cls._Table.id # pyright: ignore[reportAttributeAccessIssue]
319
+ id_col: sa.Column[int] = cast("sa.Column[int]", cls._Table.id) # pyright: ignore[reportAttributeAccessIssue] — SA 列描述符对 pyright 不可见
308
320
  stmt = sa.delete(cls._Table).where(id_col.in_(pk_list))
309
- result = await session.execute(stmt)
321
+ result = cast("sa.CursorResult[Any]", await session.execute(stmt))
310
322
  await session.commit()
311
- return result.rowcount # pyright: ignore[reportReturnType]
323
+ return result.rowcount