dclassql 0.1.6__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. {dclassql-0.1.6 → dclassql-0.2.0}/PKG-INFO +1 -1
  2. {dclassql-0.1.6 → dclassql-0.2.0}/pyproject.toml +1 -1
  3. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/__init__.py +5 -2
  4. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/client.py +81 -21
  5. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/codegen.py +45 -0
  6. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/runtime/backends/base.py +302 -27
  7. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/runtime/backends/protocols.py +66 -2
  8. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/runtime/backends/sqlite.py +66 -0
  9. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/runtime/backends/where_compiler.py +5 -3
  10. dclassql-0.2.0/src/dclassql/runtime/sql_recorder.py +44 -0
  11. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/templates/partials/imports.jinja +1 -1
  12. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/templates/partials/model_section.jinja +69 -4
  13. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/typing.py +1 -0
  14. {dclassql-0.1.6 → dclassql-0.2.0}/README.md +0 -0
  15. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/.gitignore +0 -0
  16. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/asdict.py +0 -0
  17. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/asdict.pyi +0 -0
  18. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/cli.py +0 -0
  19. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/db_pool.py +0 -0
  20. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/generated_models/__init__.py +0 -0
  21. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/generated_models/test_models.py +0 -0
  22. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/model_inspector.py +0 -0
  23. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/push/__init__.py +0 -0
  24. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/push/base.py +0 -0
  25. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/push/sqlite.py +0 -0
  26. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/runtime/backends/__init__.py +0 -0
  27. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/runtime/backends/lazy.py +0 -0
  28. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/runtime/backends/metadata.py +0 -0
  29. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/runtime/datasource.py +0 -0
  30. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/runtime/sqlite_adapters.py +0 -0
  31. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/table_spec.py +0 -0
  32. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/templates/__init__.py +0 -0
  33. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/templates/asdict_stub.pyi.jinja +0 -0
  34. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/templates/client_module.py.jinja +0 -0
  35. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/templates/partials/client_class.jinja +0 -0
  36. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/templates/partials/exports.jinja +0 -0
  37. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/templates/partials/macros.jinja +0 -0
  38. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/templates/partials/scalar_filters.jinja +0 -0
  39. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/unwarp.py +0 -0
  40. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/utils/__init__.py +0 -0
  41. {dclassql-0.1.6 → dclassql-0.2.0}/src/dclassql/utils/ensure.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: dclassql
3
- Version: 0.1.6
3
+ Version: 0.2.0
4
4
  Summary: A type-safe ORM generator for Python, creating fully type-hinted database clients from plain dataclass definitions.
5
5
  Keywords: orm,codegen,sqlite,dataclass,typed
6
6
  Author: myuanz
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "dclassql"
3
- version = "0.1.6"
3
+ version = "0.2.0"
4
4
  description = "A type-safe ORM generator for Python, creating fully type-hinted database clients from plain dataclass definitions."
5
5
  readme = "README.md"
6
6
  authors = [
@@ -1,10 +1,12 @@
1
- from .model_inspector import DataSourceConfig
1
+ from .asdict import asdict
2
2
  from .db_pool import BaseDBPool, save_local
3
+ from .model_inspector import DataSourceConfig
3
4
  from .push import db_push
4
5
  from .runtime.backends.lazy import eager
5
- from .asdict import asdict
6
+ from .runtime.sql_recorder import record_sql
6
7
  from .unwarp import unwarp, unwarp_or, unwarp_or_raise
7
8
 
9
+
8
10
  class _MissingClient:
9
11
  def __init__(self, *args: object, **kwargs: object) -> None:
10
12
  raise RuntimeError(
@@ -28,4 +30,5 @@ __all__ = [
28
30
  'BaseDBPool',
29
31
  'save_local',
30
32
  'DataSourceConfig',
33
+ 'record_sql',
31
34
  ]
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  from dataclasses import dataclass, field
4
4
  from enum import Enum
5
5
  from types import MappingProxyType
6
- from typing import Any, Literal, Mapping, Sequence, NotRequired
6
+ from typing import Any, Literal, Mapping, Sequence, NotRequired, overload
7
7
  from typing_extensions import TypedDict
8
8
 
9
9
  from dclassql import DataSourceConfig
@@ -55,6 +55,7 @@ class StringFilter(TypedDict, total=False, closed=True):
55
55
 
56
56
  TAddressIncludeCol = Literal['user']
57
57
  TAddressSortableCol = Literal['id', 'location', 'user_id']
58
+ TAddressDistinctCol = Literal['id', 'location', 'user_id']
58
59
 
59
60
  @dataclass(slots=True, kw_only=True)
60
61
  class AddressInsert:
@@ -168,21 +169,32 @@ class AddressTable(TableProtocol):
168
169
  def insert_many(self, data: Sequence[AddressInsert | AddressInsertDict], *, batch_size: int | None = None) -> list[Address]:
169
170
  return self._backend.insert_many(self, data, batch_size=batch_size)
170
171
 
171
- def find_many(self, *, where: AddressWhereDict | None = None, include: AddressIncludeDict | None = None, order_by: AddressOrderByDict | None = None, take: int | None = None, skip: int | None = None) -> list[Address]:
172
+ def find_many(self, *, where: AddressWhereDict | None = None, include: AddressIncludeDict | None = None, order_by: AddressOrderByDict | None = None, distinct: TAddressDistinctCol | Sequence[TAddressDistinctCol] | None = None, take: int | None = None, skip: int | None = None) -> list[Address]:
172
173
  return self._backend.find_many(
173
174
  self,
174
- where=where, include=include, order_by=order_by,
175
+ where=where, include=include, order_by=order_by, distinct=distinct,
175
176
  take=take, skip=skip
176
177
  )
177
178
 
178
- def find_first(self, *, where: AddressWhereDict | None = None, include: AddressIncludeDict | None = None, order_by: AddressOrderByDict | None = None, skip: int | None = None) -> Address | None:
179
+ def find_first(self, *, where: AddressWhereDict | None = None, include: AddressIncludeDict | None = None, order_by: AddressOrderByDict | None = None, distinct: TAddressDistinctCol | Sequence[TAddressDistinctCol] | None = None, skip: int | None = None) -> Address | None:
179
180
  return self._backend.find_first(
180
181
  self,
181
- where=where, include=include, order_by=order_by,
182
+ where=where, include=include, order_by=order_by, distinct=distinct,
182
183
  skip=skip
183
184
  )
185
+
186
+ def delete(self, *, where: AddressWhereDict, include: AddressIncludeDict | None = None) -> Address | None:
187
+ return self._backend.delete(self, where=where, include=include)
188
+
189
+ @overload
190
+ def delete_many(self, *, where: AddressWhereDict | None = None, return_records: Literal[False] = False) -> int: ...
191
+ @overload
192
+ def delete_many(self, *, where: AddressWhereDict | None = None, return_records: Literal[True]) -> list[Address]: ...
193
+ def delete_many(self, *, where: AddressWhereDict | None = None, return_records: Literal[False, True] = False) -> int | list[Address]:
194
+ return self._backend.delete_many(self, where=where, return_records=return_records)
184
195
  TBirthDayIncludeCol = Literal['user']
185
196
  TBirthDaySortableCol = Literal['user_id', 'date']
197
+ TBirthDayDistinctCol = Literal['user_id', 'date']
186
198
 
187
199
  @dataclass(slots=True, kw_only=True)
188
200
  class BirthDayInsert:
@@ -286,21 +298,32 @@ class BirthDayTable(TableProtocol):
286
298
  def insert_many(self, data: Sequence[BirthDayInsert | BirthDayInsertDict], *, batch_size: int | None = None) -> list[BirthDay]:
287
299
  return self._backend.insert_many(self, data, batch_size=batch_size)
288
300
 
289
- def find_many(self, *, where: BirthDayWhereDict | None = None, include: BirthDayIncludeDict | None = None, order_by: BirthDayOrderByDict | None = None, take: int | None = None, skip: int | None = None) -> list[BirthDay]:
301
+ def find_many(self, *, where: BirthDayWhereDict | None = None, include: BirthDayIncludeDict | None = None, order_by: BirthDayOrderByDict | None = None, distinct: TBirthDayDistinctCol | Sequence[TBirthDayDistinctCol] | None = None, take: int | None = None, skip: int | None = None) -> list[BirthDay]:
290
302
  return self._backend.find_many(
291
303
  self,
292
- where=where, include=include, order_by=order_by,
304
+ where=where, include=include, order_by=order_by, distinct=distinct,
293
305
  take=take, skip=skip
294
306
  )
295
307
 
296
- def find_first(self, *, where: BirthDayWhereDict | None = None, include: BirthDayIncludeDict | None = None, order_by: BirthDayOrderByDict | None = None, skip: int | None = None) -> BirthDay | None:
308
+ def find_first(self, *, where: BirthDayWhereDict | None = None, include: BirthDayIncludeDict | None = None, order_by: BirthDayOrderByDict | None = None, distinct: TBirthDayDistinctCol | Sequence[TBirthDayDistinctCol] | None = None, skip: int | None = None) -> BirthDay | None:
297
309
  return self._backend.find_first(
298
310
  self,
299
- where=where, include=include, order_by=order_by,
311
+ where=where, include=include, order_by=order_by, distinct=distinct,
300
312
  skip=skip
301
313
  )
314
+
315
+ def delete(self, *, where: BirthDayWhereDict, include: BirthDayIncludeDict | None = None) -> BirthDay | None:
316
+ return self._backend.delete(self, where=where, include=include)
317
+
318
+ @overload
319
+ def delete_many(self, *, where: BirthDayWhereDict | None = None, return_records: Literal[False] = False) -> int: ...
320
+ @overload
321
+ def delete_many(self, *, where: BirthDayWhereDict | None = None, return_records: Literal[True]) -> list[BirthDay]: ...
322
+ def delete_many(self, *, where: BirthDayWhereDict | None = None, return_records: Literal[False, True] = False) -> int | list[BirthDay]:
323
+ return self._backend.delete_many(self, where=where, return_records=return_records)
302
324
  TBookIncludeCol = Literal['users']
303
325
  TBookSortableCol = Literal['id', 'name']
326
+ TBookDistinctCol = Literal['id', 'name']
304
327
 
305
328
  @dataclass(slots=True, kw_only=True)
306
329
  class BookInsert:
@@ -398,21 +421,32 @@ class BookTable(TableProtocol):
398
421
  def insert_many(self, data: Sequence[BookInsert | BookInsertDict], *, batch_size: int | None = None) -> list[Book]:
399
422
  return self._backend.insert_many(self, data, batch_size=batch_size)
400
423
 
401
- def find_many(self, *, where: BookWhereDict | None = None, include: BookIncludeDict | None = None, order_by: BookOrderByDict | None = None, take: int | None = None, skip: int | None = None) -> list[Book]:
424
+ def find_many(self, *, where: BookWhereDict | None = None, include: BookIncludeDict | None = None, order_by: BookOrderByDict | None = None, distinct: TBookDistinctCol | Sequence[TBookDistinctCol] | None = None, take: int | None = None, skip: int | None = None) -> list[Book]:
402
425
  return self._backend.find_many(
403
426
  self,
404
- where=where, include=include, order_by=order_by,
427
+ where=where, include=include, order_by=order_by, distinct=distinct,
405
428
  take=take, skip=skip
406
429
  )
407
430
 
408
- def find_first(self, *, where: BookWhereDict | None = None, include: BookIncludeDict | None = None, order_by: BookOrderByDict | None = None, skip: int | None = None) -> Book | None:
431
+ def find_first(self, *, where: BookWhereDict | None = None, include: BookIncludeDict | None = None, order_by: BookOrderByDict | None = None, distinct: TBookDistinctCol | Sequence[TBookDistinctCol] | None = None, skip: int | None = None) -> Book | None:
409
432
  return self._backend.find_first(
410
433
  self,
411
- where=where, include=include, order_by=order_by,
434
+ where=where, include=include, order_by=order_by, distinct=distinct,
412
435
  skip=skip
413
436
  )
437
+
438
+ def delete(self, *, where: BookWhereDict, include: BookIncludeDict | None = None) -> Book | None:
439
+ return self._backend.delete(self, where=where, include=include)
440
+
441
+ @overload
442
+ def delete_many(self, *, where: BookWhereDict | None = None, return_records: Literal[False] = False) -> int: ...
443
+ @overload
444
+ def delete_many(self, *, where: BookWhereDict | None = None, return_records: Literal[True]) -> list[Book]: ...
445
+ def delete_many(self, *, where: BookWhereDict | None = None, return_records: Literal[False, True] = False) -> int | list[Book]:
446
+ return self._backend.delete_many(self, where=where, return_records=return_records)
414
447
  TUserIncludeCol = Literal['addresses', 'birthday', 'books']
415
448
  TUserSortableCol = Literal['id', 'name', 'email', 'last_login']
449
+ TUserDistinctCol = Literal['id', 'name', 'email', 'last_login']
416
450
 
417
451
  @dataclass(slots=True, kw_only=True)
418
452
  class UserInsert:
@@ -551,21 +585,32 @@ class UserTable(TableProtocol):
551
585
  def insert_many(self, data: Sequence[UserInsert | UserInsertDict], *, batch_size: int | None = None) -> list[User]:
552
586
  return self._backend.insert_many(self, data, batch_size=batch_size)
553
587
 
554
- def find_many(self, *, where: UserWhereDict | None = None, include: UserIncludeDict | None = None, order_by: UserOrderByDict | None = None, take: int | None = None, skip: int | None = None) -> list[User]:
588
+ def find_many(self, *, where: UserWhereDict | None = None, include: UserIncludeDict | None = None, order_by: UserOrderByDict | None = None, distinct: TUserDistinctCol | Sequence[TUserDistinctCol] | None = None, take: int | None = None, skip: int | None = None) -> list[User]:
555
589
  return self._backend.find_many(
556
590
  self,
557
- where=where, include=include, order_by=order_by,
591
+ where=where, include=include, order_by=order_by, distinct=distinct,
558
592
  take=take, skip=skip
559
593
  )
560
594
 
561
- def find_first(self, *, where: UserWhereDict | None = None, include: UserIncludeDict | None = None, order_by: UserOrderByDict | None = None, skip: int | None = None) -> User | None:
595
+ def find_first(self, *, where: UserWhereDict | None = None, include: UserIncludeDict | None = None, order_by: UserOrderByDict | None = None, distinct: TUserDistinctCol | Sequence[TUserDistinctCol] | None = None, skip: int | None = None) -> User | None:
562
596
  return self._backend.find_first(
563
597
  self,
564
- where=where, include=include, order_by=order_by,
598
+ where=where, include=include, order_by=order_by, distinct=distinct,
565
599
  skip=skip
566
600
  )
601
+
602
+ def delete(self, *, where: UserWhereDict, include: UserIncludeDict | None = None) -> User | None:
603
+ return self._backend.delete(self, where=where, include=include)
604
+
605
+ @overload
606
+ def delete_many(self, *, where: UserWhereDict | None = None, return_records: Literal[False] = False) -> int: ...
607
+ @overload
608
+ def delete_many(self, *, where: UserWhereDict | None = None, return_records: Literal[True]) -> list[User]: ...
609
+ def delete_many(self, *, where: UserWhereDict | None = None, return_records: Literal[False, True] = False) -> int | list[User]:
610
+ return self._backend.delete_many(self, where=where, return_records=return_records)
567
611
  TUserBookIncludeCol = Literal['book', 'user']
568
612
  TUserBookSortableCol = Literal['user_id', 'book_id', 'created_at']
613
+ TUserBookDistinctCol = Literal['user_id', 'book_id', 'created_at']
569
614
 
570
615
  @dataclass(slots=True, kw_only=True)
571
616
  class UserBookInsert:
@@ -695,19 +740,29 @@ class UserBookTable(TableProtocol):
695
740
  def insert_many(self, data: Sequence[UserBookInsert | UserBookInsertDict], *, batch_size: int | None = None) -> list[UserBook]:
696
741
  return self._backend.insert_many(self, data, batch_size=batch_size)
697
742
 
698
- def find_many(self, *, where: UserBookWhereDict | None = None, include: UserBookIncludeDict | None = None, order_by: UserBookOrderByDict | None = None, take: int | None = None, skip: int | None = None) -> list[UserBook]:
743
+ def find_many(self, *, where: UserBookWhereDict | None = None, include: UserBookIncludeDict | None = None, order_by: UserBookOrderByDict | None = None, distinct: TUserBookDistinctCol | Sequence[TUserBookDistinctCol] | None = None, take: int | None = None, skip: int | None = None) -> list[UserBook]:
699
744
  return self._backend.find_many(
700
745
  self,
701
- where=where, include=include, order_by=order_by,
746
+ where=where, include=include, order_by=order_by, distinct=distinct,
702
747
  take=take, skip=skip
703
748
  )
704
749
 
705
- def find_first(self, *, where: UserBookWhereDict | None = None, include: UserBookIncludeDict | None = None, order_by: UserBookOrderByDict | None = None, skip: int | None = None) -> UserBook | None:
750
+ def find_first(self, *, where: UserBookWhereDict | None = None, include: UserBookIncludeDict | None = None, order_by: UserBookOrderByDict | None = None, distinct: TUserBookDistinctCol | Sequence[TUserBookDistinctCol] | None = None, skip: int | None = None) -> UserBook | None:
706
751
  return self._backend.find_first(
707
752
  self,
708
- where=where, include=include, order_by=order_by,
753
+ where=where, include=include, order_by=order_by, distinct=distinct,
709
754
  skip=skip
710
755
  )
756
+
757
+ def delete(self, *, where: UserBookWhereDict, include: UserBookIncludeDict | None = None) -> UserBook | None:
758
+ return self._backend.delete(self, where=where, include=include)
759
+
760
+ @overload
761
+ def delete_many(self, *, where: UserBookWhereDict | None = None, return_records: Literal[False] = False) -> int: ...
762
+ @overload
763
+ def delete_many(self, *, where: UserBookWhereDict | None = None, return_records: Literal[True]) -> list[UserBook]: ...
764
+ def delete_many(self, *, where: UserBookWhereDict | None = None, return_records: Literal[False, True] = False) -> int | list[UserBook]:
765
+ return self._backend.delete_many(self, where=where, return_records=return_records)
711
766
  class Client(BaseDBPool):
712
767
  _echo_sql: bool = False
713
768
  datasources = {
@@ -749,6 +804,7 @@ __all__ = (
749
804
  "Client",
750
805
  "TAddressIncludeCol",
751
806
  "TAddressSortableCol",
807
+ "TAddressDistinctCol",
752
808
  "AddressIncludeDict",
753
809
  "AddressOrderByDict",
754
810
  "AddressDict",
@@ -759,6 +815,7 @@ __all__ = (
759
815
  "AddressUserRelationFilter",
760
816
  "TBirthDayIncludeCol",
761
817
  "TBirthDaySortableCol",
818
+ "TBirthDayDistinctCol",
762
819
  "BirthDayIncludeDict",
763
820
  "BirthDayOrderByDict",
764
821
  "BirthDayDict",
@@ -769,6 +826,7 @@ __all__ = (
769
826
  "BirthDayUserRelationFilter",
770
827
  "TBookIncludeCol",
771
828
  "TBookSortableCol",
829
+ "TBookDistinctCol",
772
830
  "BookIncludeDict",
773
831
  "BookOrderByDict",
774
832
  "BookDict",
@@ -779,6 +837,7 @@ __all__ = (
779
837
  "BookUsersRelationFilter",
780
838
  "TUserIncludeCol",
781
839
  "TUserSortableCol",
840
+ "TUserDistinctCol",
782
841
  "UserIncludeDict",
783
842
  "UserOrderByDict",
784
843
  "UserDict",
@@ -791,6 +850,7 @@ __all__ = (
791
850
  "UserBooksRelationFilter",
792
851
  "TUserBookIncludeCol",
793
852
  "TUserBookSortableCol",
853
+ "TUserBookDistinctCol",
794
854
  "UserBookIncludeDict",
795
855
  "UserBookOrderByDict",
796
856
  "UserBookDict",
@@ -100,6 +100,12 @@ class ScalarFilterRender:
100
100
  fields: tuple[TypedDictFieldSpec, ...]
101
101
 
102
102
 
103
+ @dataclass(slots=True)
104
+ class UpsertWhereRender:
105
+ name: str
106
+ fields: tuple[TypedDictFieldSpec, ...]
107
+
108
+
103
109
  @dataclass(slots=True)
104
110
  class ModelRenderContext:
105
111
  name: str
@@ -107,6 +113,8 @@ class ModelRenderContext:
107
113
  table_name_literal: str
108
114
  insert_fields: tuple[InsertFieldSpec, ...]
109
115
  typed_dict_fields: tuple[TypedDictFieldSpec, ...]
116
+ update_fields: tuple[TypedDictFieldSpec, ...]
117
+ upsert_where_dicts: tuple["UpsertWhereRender", ...]
110
118
  dict_fields: tuple[TypedDictFieldSpec, ...]
111
119
  where_fields: tuple[WhereFieldSpec, ...]
112
120
  relation_filters: tuple[RelationFilterRender, ...]
@@ -116,6 +124,7 @@ class ModelRenderContext:
116
124
  primary_key_literal: str
117
125
  indexes_literal: str
118
126
  unique_indexes_literal: str
127
+ primary_value_types: tuple[str, ...]
119
128
  row_assignments: tuple[RowAssignmentRender, ...]
120
129
  default_factories: tuple[DefaultFactoryRender, ...]
121
130
  model_info: ModelInfo
@@ -222,8 +231,11 @@ def _build_model_context(
222
231
 
223
232
  insert_fields: list[InsertFieldSpec] = []
224
233
  typed_dict_fields: list[TypedDictFieldSpec] = []
234
+ update_fields: list[TypedDictFieldSpec] = []
235
+ upsert_where_dicts: list[UpsertWhereRender] = []
225
236
  dict_field_map: dict[str, str] = {}
226
237
  enum_type_map: dict[str, type[Enum] | None] = {}
238
+ column_lookup: dict[str, ColumnInfo] = {col.name: col for col in info.columns}
227
239
  for col in info.columns:
228
240
  annotation = _format_insert_annotation(col, renderer)
229
241
  default_fragment = _render_default_fragment(name, col)
@@ -248,6 +260,27 @@ def _build_model_context(
248
260
  enum_type = _resolve_enum_class(col.python_type)
249
261
  enum_type_map[col.name] = enum_type
250
262
 
263
+ update_fields.append(TypedDictFieldSpec(name=col.name, annotation=renderer.render(col.python_type)))
264
+
265
+ if info.primary_key:
266
+ pk_fields: list[TypedDictFieldSpec] = []
267
+ for pk_col in info.primary_key:
268
+ col_info = column_lookup.get(pk_col)
269
+ annotation = renderer.render(col_info.python_type) if col_info else "object"
270
+ pk_fields.append(TypedDictFieldSpec(name=pk_col, annotation=annotation))
271
+ upsert_where_dicts.append(UpsertWhereRender(name=f"{name}UpsertWherePK", fields=tuple(pk_fields)))
272
+
273
+ if info.unique_indexes:
274
+ for idx, unique_cols in enumerate(info.unique_indexes, start=1):
275
+ unique_fields: list[TypedDictFieldSpec] = []
276
+ for col_name in unique_cols:
277
+ col_info = column_lookup.get(col_name)
278
+ annotation = renderer.render(col_info.python_type) if col_info else "object"
279
+ unique_fields.append(TypedDictFieldSpec(name=col_name, annotation=annotation))
280
+ upsert_where_dicts.append(
281
+ UpsertWhereRender(name=f"{name}UpsertWhereUnique{idx}", fields=tuple(unique_fields))
282
+ )
283
+
251
284
  where_fields: list[WhereFieldSpec] = []
252
285
  for col in info.columns:
253
286
  annotation = renderer.render(col.python_type)
@@ -334,6 +367,13 @@ def _build_model_context(
334
367
  )
335
368
 
336
369
  row_assignments, default_factories = _build_row_assignment_context(info, enum_type_map)
370
+ primary_value_types: list[str] = []
371
+ for column_name in info.primary_key:
372
+ column = column_lookup.get(column_name)
373
+ if column is None:
374
+ primary_value_types.append("object")
375
+ continue
376
+ primary_value_types.append(renderer.render(column.python_type))
337
377
 
338
378
  relation_lookup = {relation.name: relation for relation in info.relations}
339
379
  dataclass_fields = fields(info.model)
@@ -362,6 +402,8 @@ def _build_model_context(
362
402
  table_name_literal=repr(name),
363
403
  insert_fields=tuple(insert_fields),
364
404
  typed_dict_fields=tuple(typed_dict_fields),
405
+ update_fields=tuple(update_fields),
406
+ upsert_where_dicts=tuple(upsert_where_dicts),
365
407
  dict_fields=tuple(dict_fields),
366
408
  where_fields=tuple(where_fields),
367
409
  relation_filters=tuple(relation_filters),
@@ -371,6 +413,7 @@ def _build_model_context(
371
413
  primary_key_literal=_tuple_literal(info.primary_key),
372
414
  indexes_literal=indexes_literal,
373
415
  unique_indexes_literal=unique_indexes_literal,
416
+ primary_value_types=tuple(primary_value_types),
374
417
  row_assignments=tuple(row_assignments),
375
418
  default_factories=tuple(default_factories),
376
419
  model_info=info,
@@ -436,6 +479,8 @@ def _collect_exports(model_contexts: Sequence[ModelRenderContext]) -> list[str]:
436
479
  f"{name}Dict",
437
480
  f"{name}Insert",
438
481
  f"{name}InsertDict",
482
+ f"{name}UpdateDict",
483
+ f"{name}UpsertWhereDict",
439
484
  f"{name}WhereDict",
440
485
  f"{name}Table",
441
486
  ]
@@ -2,16 +2,17 @@ from __future__ import annotations
2
2
 
3
3
  import sys
4
4
  from abc import ABC, abstractmethod
5
- from typing import Any, Mapping, Sequence, cast
5
+ from typing import Any, Literal, Mapping, Sequence, cast, overload
6
6
  from weakref import ReferenceType, ref
7
7
 
8
8
  from pypika import Query, Table
9
9
  from pypika.enums import Order
10
- from pypika.terms import Criterion, Parameter
11
10
  from pypika.queries import QueryBuilder
11
+ from pypika.terms import Criterion, Parameter
12
12
  from pypika.utils import format_quotes
13
13
 
14
- from dclassql.typing import IncludeT, InsertT, ModelT, OrderByT, WhereT
14
+ from dclassql.runtime.sql_recorder import push_sql
15
+ from dclassql.typing import IncludeT, InsertT, ModelT, OrderByT, UpsertWhereT, WhereT
15
16
 
16
17
  from .lazy import ensure_lazy_state, finalize_lazy_state, reset_lazy_backref
17
18
  from .protocols import BackendProtocol, RelationSpec, TableProtocol
@@ -67,6 +68,118 @@ class BackendBase(BackendProtocol, ABC):
67
68
  _ = batch_size # 基础实现不做批量优化
68
69
  return [self.insert(table, item) for item in data]
69
70
 
71
+ def update(
72
+ self,
73
+ table: TableProtocol[ModelT, InsertT, WhereT, IncludeT, OrderByT],
74
+ *,
75
+ data: Mapping[str, object],
76
+ where: WhereT,
77
+ include: Mapping[str, bool] | None = None,
78
+ ) -> ModelT:
79
+ payload = table.serialize_update(data)
80
+ if not payload:
81
+ raise ValueError("Update payload cannot be empty")
82
+
83
+ sql_table = self.table_cls(table.model.__name__)
84
+ update_query: QueryBuilder = self.query_cls.update(sql_table)
85
+ params: list[Any] = []
86
+ for column, value in payload.items():
87
+ update_query = update_query.set(sql_table.field(column), self.new_parameter())
88
+ params.append(value)
89
+
90
+ criterion, where_params = self._compile_where(table, sql_table, where)
91
+ if criterion is not None:
92
+ update_query = update_query.where(criterion)
93
+ params.extend(where_params)
94
+
95
+ sql = self._render_query(update_query)
96
+ returning_columns = [spec.name for spec in table.column_specs]
97
+ sql_with_returning = self._append_returning(sql, returning_columns)
98
+
99
+ rows = self.query_raw(sql_with_returning, params, auto_commit=True)
100
+ if len(rows) != 1:
101
+ raise RuntimeError(f"update() expected exactly 1 row, got {len(rows)}")
102
+
103
+ row = rows[0]
104
+ include_map = include or {}
105
+ instance = self._materialize_instance(table, row, include_map)
106
+ identity_key = self._identity_key(table, row)
107
+ if identity_key is not None:
108
+ self._identity_map.pop(identity_key, None)
109
+ self._invalidate_backrefs(table, instance)
110
+ return instance
111
+
112
+ def upsert(
113
+ self,
114
+ table: TableProtocol[ModelT, InsertT, WhereT, IncludeT, OrderByT],
115
+ *,
116
+ where: UpsertWhereT,
117
+ update: Mapping[str, object],
118
+ insert: InsertT | Mapping[str, object],
119
+ include: Mapping[str, bool] | None = None,
120
+ ) -> ModelT:
121
+ where_payload = dict(where)
122
+ conflict_targets: list[tuple[str, ...]] = []
123
+ if table.primary_key:
124
+ conflict_targets.append(tuple(table.primary_key))
125
+ conflict_targets.extend(tuple(idx) for idx in getattr(table, "unique_indexes", ()))
126
+ if not conflict_targets:
127
+ raise ValueError("upsert requires primary key or unique index")
128
+
129
+ where_keys = set(where_payload.keys())
130
+ conflict_target: tuple[str, ...] | None = None
131
+ for target in conflict_targets:
132
+ if set(target) == where_keys:
133
+ conflict_target = target
134
+ break
135
+ if conflict_target is None:
136
+ raise ValueError("Upsert where must exactly match primary key or unique index")
137
+
138
+ insert_payload = table.serialize_insert(insert)
139
+ for column in conflict_target:
140
+ if column not in insert_payload:
141
+ insert_payload[column] = where_payload[column]
142
+ if not insert_payload:
143
+ raise ValueError("Upsert insert payload cannot be empty")
144
+
145
+ update_payload = table.serialize_update(update)
146
+ if not update_payload:
147
+ raise ValueError("Upsert update payload cannot be empty")
148
+
149
+ sql_table = self.table_cls(table.model.__name__)
150
+ insert_columns = list(insert_payload.keys())
151
+ params: list[Any] = [insert_payload[column] for column in insert_columns]
152
+
153
+ insert_query: QueryBuilder = (
154
+ self.query_cls.into(sql_table)
155
+ .columns(*insert_columns)
156
+ .insert(*(self.new_parameter() for _ in insert_columns))
157
+ )
158
+ sql_base = self._render_query(insert_query).rstrip().removesuffix(";")
159
+
160
+ conflict_target_sql = ", ".join(self.escape_identifier(col) for col in conflict_target)
161
+ update_assignments: list[str] = []
162
+ for column, value in update_payload.items():
163
+ update_assignments.append(f"{self.escape_identifier(column)} = {self.parameter_token}")
164
+ params.append(value)
165
+ update_clause = ", ".join(update_assignments)
166
+
167
+ sql = f"{sql_base} ON CONFLICT ({conflict_target_sql}) DO UPDATE SET {update_clause}"
168
+ returning_columns = [spec.name for spec in table.column_specs]
169
+ sql_with_returning = self._append_returning(sql, returning_columns)
170
+ rows = self.query_raw(sql_with_returning, params, auto_commit=True)
171
+ if len(rows) != 1:
172
+ raise RuntimeError(f"upsert() expected exactly 1 row, got {len(rows)}")
173
+
174
+ row = rows[0]
175
+ include_map = include or {}
176
+ instance = self._materialize_instance(table, row, include_map)
177
+ identity_key = self._identity_key(table, row)
178
+ if identity_key is not None:
179
+ self._identity_map.pop(identity_key, None)
180
+ self._invalidate_backrefs(table, instance)
181
+ return instance
182
+
70
183
  def find_many(
71
184
  self,
72
185
  table: TableProtocol[ModelT, InsertT, WhereT, IncludeT, OrderByT],
@@ -80,7 +193,33 @@ class BackendBase(BackendProtocol, ABC):
80
193
  ) -> list[ModelT]:
81
194
  sql_table = self.table_cls(table.model.__name__)
82
195
  distinct_columns = self._normalize_distinct(table, distinct)
83
- select_query = self.query_cls.from_(sql_table).select(
196
+ select_query, params = self._build_select_query(table, sql_table, where, order_by)
197
+
198
+ if skip is not None and not distinct_columns:
199
+ select_query = select_query.offset(skip)
200
+ if take is not None and not distinct_columns:
201
+ select_query = select_query.limit(take)
202
+
203
+ sql = self._render_query(select_query)
204
+ rows = self.query_raw(sql, params)
205
+ row_list = list(rows)
206
+ if distinct_columns:
207
+ row_list = self._deduplicate_rows(row_list, distinct_columns)
208
+ if skip:
209
+ row_list = row_list[skip:]
210
+ if take is not None:
211
+ row_list = row_list[:take]
212
+ include_map = include or {}
213
+ return [self._materialize_instance(table, row, include_map) for row in row_list]
214
+
215
+ def _build_select_query(
216
+ self,
217
+ table: TableProtocol[ModelT, InsertT, WhereT, IncludeT, OrderByT],
218
+ sql_table: Table,
219
+ where: WhereT | None,
220
+ order_by: Mapping[str, str] | None,
221
+ ) -> tuple[QueryBuilder, list[Any]]:
222
+ select_query: QueryBuilder = self.query_cls.from_(sql_table).select(
84
223
  *[sql_table.field(spec.name) for spec in table.column_specs]
85
224
  )
86
225
  params: list[Any] = []
@@ -100,22 +239,7 @@ class BackendBase(BackendProtocol, ABC):
100
239
  raise ValueError("order_by direction must be 'asc' or 'desc'")
101
240
  select_query = select_query.orderby(sql_table.field(column), order=Order[direction_lower])
102
241
 
103
- if skip is not None and not distinct_columns:
104
- select_query = select_query.offset(skip)
105
- if take is not None and not distinct_columns:
106
- select_query = select_query.limit(take)
107
-
108
- sql = self._render_query(select_query)
109
- rows = self.query_raw(sql, params)
110
- row_list = list(rows)
111
- if distinct_columns:
112
- row_list = self._deduplicate_rows(row_list, distinct_columns)
113
- if skip:
114
- row_list = row_list[skip:]
115
- if take is not None:
116
- row_list = row_list[:take]
117
- include_map = include or {}
118
- return [self._materialize_instance(table, row, include_map) for row in row_list]
242
+ return select_query, params
119
243
 
120
244
  def find_first(
121
245
  self,
@@ -138,6 +262,157 @@ class BackendBase(BackendProtocol, ABC):
138
262
  )
139
263
  return results[0] if results else None
140
264
 
265
+ def delete(
266
+ self,
267
+ table: TableProtocol[ModelT, InsertT, WhereT, IncludeT, OrderByT],
268
+ *,
269
+ where: WhereT,
270
+ include: Mapping[str, bool] | None = None,
271
+ ) -> ModelT | None:
272
+ sql_table = self.table_cls(table.model.__name__)
273
+ delete_query: QueryBuilder = self.query_cls.from_(sql_table).delete()
274
+ criterion, params = self._compile_where(table, sql_table, where)
275
+ if criterion is None:
276
+ raise ValueError("delete() requires a where clause")
277
+ delete_query = delete_query.where(criterion)
278
+ sql = self._render_query(delete_query)
279
+ returning_columns = [spec.name for spec in table.column_specs]
280
+ sql_with_returning = self._append_returning(sql, returning_columns)
281
+ rows = self.query_raw(sql_with_returning, params, auto_commit=True)
282
+ if not rows:
283
+ return None
284
+ if len(rows) != 1:
285
+ raise RuntimeError(f"delete() expected exactly 1 row, got {len(rows)}")
286
+
287
+ row = rows[0]
288
+ include_map = include or {}
289
+ instance = self._materialize_instance(table, row, include_map)
290
+ identity_key = self._identity_key(table, row)
291
+ if identity_key is not None:
292
+ self._identity_map.pop(identity_key, None)
293
+ self._invalidate_backrefs(table, instance)
294
+ return instance
295
+
296
+ @overload
297
+ def delete_many(
298
+ self,
299
+ table: TableProtocol[ModelT, InsertT, WhereT, IncludeT, OrderByT],
300
+ *,
301
+ where: WhereT | None = None,
302
+ return_records: Literal[False] = False,
303
+ ) -> int: ...
304
+
305
+ @overload
306
+ def delete_many(
307
+ self,
308
+ table: TableProtocol[ModelT, InsertT, WhereT, IncludeT, OrderByT],
309
+ *,
310
+ where: WhereT | None = None,
311
+ return_records: Literal[True],
312
+ ) -> list[ModelT]: ...
313
+
314
+ def delete_many(
315
+ self,
316
+ table: TableProtocol[ModelT, InsertT, WhereT, IncludeT, OrderByT],
317
+ *,
318
+ where: WhereT | None = None,
319
+ return_records: Literal[False, True] = False,
320
+ ) -> int | list[ModelT]:
321
+ sql_table = self.table_cls(table.model.__name__)
322
+ delete_query: QueryBuilder = self.query_cls.from_(sql_table).delete()
323
+ params: list[Any] = []
324
+
325
+ if where:
326
+ criterion, where_params = self._compile_where(table, sql_table, where)
327
+ if criterion is not None:
328
+ delete_query = delete_query.where(criterion)
329
+ params.extend(where_params)
330
+
331
+ sql = self._render_query(delete_query)
332
+
333
+ if return_records:
334
+ returning_columns = [spec.name for spec in table.column_specs]
335
+ sql_with_returning = self._append_returning(sql, returning_columns)
336
+ rows = self.query_raw(sql_with_returning, params, auto_commit=True)
337
+ include_map: Mapping[str, bool] = {}
338
+ results: list[ModelT] = []
339
+ for row in rows:
340
+ identity_key = self._identity_key(table, row)
341
+ if identity_key is not None:
342
+ self._identity_map.pop(identity_key, None)
343
+ instance = self._materialize_instance(table, row, include_map)
344
+ results.append(instance)
345
+ return results
346
+
347
+ affected = self.execute_raw(sql, params, auto_commit=True)
348
+ self._purge_identity_map(table.model)
349
+ return affected
350
+
351
+ @overload
352
+ def update_many(
353
+ self,
354
+ table: TableProtocol[ModelT, InsertT, WhereT, IncludeT, OrderByT],
355
+ *,
356
+ data: Mapping[str, object],
357
+ where: WhereT | None = None,
358
+ return_records: Literal[False] = False,
359
+ ) -> int: ...
360
+
361
+ @overload
362
+ def update_many(
363
+ self,
364
+ table: TableProtocol[ModelT, InsertT, WhereT, IncludeT, OrderByT],
365
+ *,
366
+ data: Mapping[str, object],
367
+ where: WhereT | None = None,
368
+ return_records: Literal[True],
369
+ ) -> list[ModelT]: ...
370
+
371
+ def update_many(
372
+ self,
373
+ table: TableProtocol[ModelT, InsertT, WhereT, IncludeT, OrderByT],
374
+ *,
375
+ data: Mapping[str, object],
376
+ where: WhereT | None = None,
377
+ return_records: Literal[False, True] = False,
378
+ ) -> int | list[ModelT]:
379
+ payload = table.serialize_update(data)
380
+ if not payload:
381
+ raise ValueError("Update payload cannot be empty")
382
+
383
+ sql_table = self.table_cls(table.model.__name__)
384
+ update_query: QueryBuilder = self.query_cls.update(sql_table)
385
+ params: list[Any] = []
386
+ for column, value in payload.items():
387
+ update_query = update_query.set(sql_table.field(column), self.new_parameter())
388
+ params.append(value)
389
+
390
+ if where:
391
+ criterion, where_params = self._compile_where(table, sql_table, where)
392
+ if criterion is not None:
393
+ update_query = update_query.where(criterion)
394
+ params.extend(where_params)
395
+
396
+ sql = self._render_query(update_query)
397
+
398
+ if return_records:
399
+ returning_columns = [spec.name for spec in table.column_specs]
400
+ sql_with_returning = self._append_returning(sql, returning_columns)
401
+ rows = self.query_raw(sql_with_returning, params, auto_commit=True)
402
+ include_map: Mapping[str, bool] = {}
403
+ results: list[ModelT] = []
404
+ for row in rows:
405
+ identity_key = self._identity_key(table, row)
406
+ if identity_key is not None:
407
+ self._identity_map.pop(identity_key, None)
408
+ instance = self._materialize_instance(table, row, include_map)
409
+ results.append(instance)
410
+ return results
411
+
412
+ affected = self.execute_raw(sql, params, auto_commit=True)
413
+ self._purge_identity_map(table.model)
414
+ return affected
415
+
141
416
  def _fetch_single(
142
417
  self,
143
418
  table: TableProtocol[ModelT, InsertT, WhereT, IncludeT, OrderByT],
@@ -273,6 +548,11 @@ class BackendBase(BackendProtocol, ABC):
273
548
  def _clear_identity_map(self) -> None:
274
549
  self._identity_map.clear()
275
550
 
551
+ def _purge_identity_map(self, model: type[Any]) -> None:
552
+ stale_keys = [key for key in self._identity_map if key[0] is model]
553
+ for key in stale_keys:
554
+ self._identity_map.pop(key, None)
555
+
276
556
  def _render_query(self, query: QueryBuilder) -> str:
277
557
  return query.get_sql(quote_char=self.quote_char) + ';'
278
558
 
@@ -302,13 +582,8 @@ class BackendBase(BackendProtocol, ABC):
302
582
  return criterion, compiler.params
303
583
 
304
584
  def _log_sql(self, sql: str, params: Sequence[object] | None) -> None:
305
- if not self._echo_sql:
306
- return
307
- if params is None:
308
- display_params = []
309
- else:
310
- display_params = list(params)
311
- print(f"[dclassql] SQL: {sql} | params={display_params}")
585
+ params_seq = list(params) if params is not None else []
586
+ push_sql(sql, params_seq, echo=self._echo_sql)
312
587
 
313
588
  def _normalize_distinct(
314
589
  self,
@@ -1,13 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import sqlite3
4
- from typing import Callable, Literal, Mapping, Protocol, Sequence, runtime_checkable
4
+ from typing import Callable, Literal, Mapping, Protocol, Sequence, overload, runtime_checkable
5
5
 
6
6
  from pypika import Query, Table
7
7
  from pypika.terms import Parameter
8
8
 
9
9
  from dclassql.model_inspector import DataSourceConfig
10
- from dclassql.typing import IncludeT, InsertT, ModelT, OrderByT, WhereT
10
+ from dclassql.typing import IncludeT, InsertT, ModelT, OrderByT, WhereT, UpsertWhereT
11
11
 
12
12
  from .metadata import ColumnSpec, ForeignKeySpec, RelationSpec
13
13
 
@@ -26,10 +26,13 @@ class TableProtocol[ModelT, InsertT, WhereT, IncludeT, OrderByT](Protocol):
26
26
 
27
27
    @classmethod
    def serialize_insert(cls, data: InsertT | Mapping[str, object]) -> dict[str, object]:
        """Serialize an insert payload (insert object or plain mapping) to a ``dict[str, object]``.

        The result is presumably keyed by column name — TODO confirm.
        """
        ...

    @classmethod
    def serialize_update(cls, data: Mapping[str, object]) -> dict[str, object]:
        """Serialize a (partial) update mapping to a dict keyed by column name.

        ``BackendBase.update_many`` turns each key into a SET column and
        rejects an empty result.
        """
        ...

    @classmethod
    def deserialize_row(cls, row: Mapping[str, object]) -> ModelT:
        """Build a model instance from a row mapping returned by the backend."""
        ...
    # Names of the primary-key columns.
    primary_key: tuple[str, ...]
    def primary_values(self, instance: ModelT) -> tuple[object, ...]:
        """Return ``instance``'s primary-key values.

        Generated tables emit them in the model's primary-key field order.
        """
        ...
    # Column-name tuples, one per declared index.
    indexes: tuple[tuple[str, ...], ...]
    # Column-name tuples, one per unique index.
    unique_indexes: tuple[tuple[str, ...], ...]
    # Foreign-key specifications declared on this table.
    foreign_keys: tuple[ForeignKeySpec, ...]
@@ -49,6 +52,23 @@ class BackendProtocol(Protocol):
49
52
  table: TableProtocol[ModelT, InsertT, WhereT, IncludeT, OrderByT],
50
53
  data: InsertT | Mapping[str, object],
51
54
  ) -> ModelT: ...
55
    def update(
        self,
        table: TableProtocol[ModelT, InsertT, WhereT, IncludeT, OrderByT],
        *,
        data: Mapping[str, object],
        where: WhereT,
        include: IncludeT | None = None,
    ) -> ModelT:
        """Apply ``data`` to the row matching ``where`` and return it as a model instance.

        ``include`` selects relations to populate on the returned instance.
        """
        ...
    def upsert(
        self,
        table: TableProtocol[ModelT, InsertT, WhereT, IncludeT, OrderByT],
        *,
        where: UpsertWhereT,
        update: Mapping[str, object],
        insert: InsertT | Mapping[str, object],
        include: IncludeT | None = None,
    ) -> ModelT:
        """Update-or-insert keyed by ``where``; returns the resulting model instance.

        NOTE(review): presumably ``update`` is applied when a row matches
        ``where`` and ``insert`` is used otherwise — confirm against the
        backend implementation.
        """
        ...
52
72
 
53
73
  def insert_many(
54
74
  self,
@@ -57,6 +77,24 @@ class BackendProtocol(Protocol):
57
77
  *,
58
78
  batch_size: int | None = None,
59
79
  ) -> list[ModelT]: ...
80
    # update_many(): returns the affected-row count by default (return_records=False) ...
    @overload
    def update_many(
        self,
        table: TableProtocol[ModelT, InsertT, WhereT, IncludeT, OrderByT],
        *,
        data: Mapping[str, object],
        where: WhereT | None = None,
        return_records: Literal[False] = False,
    ) -> int: ...
    # ... or the updated rows as model instances when return_records=True.
    @overload
    def update_many(
        self,
        table: TableProtocol[ModelT, InsertT, WhereT, IncludeT, OrderByT],
        *,
        data: Mapping[str, object],
        where: WhereT | None = None,
        return_records: Literal[True],
    ) -> list[ModelT]: ...
60
98
 
61
99
  def find_many(
62
100
  self,
@@ -81,6 +119,32 @@ class BackendProtocol(Protocol):
81
119
  skip: int | None = None,
82
120
  ) -> ModelT | None: ...
83
121
 
122
    def delete(
        self,
        table: TableProtocol[ModelT, InsertT, WhereT, IncludeT, OrderByT],
        *,
        where: WhereT,
        include: IncludeT | None = None,
    ) -> ModelT | None:
        """Delete the row matching ``where``; return it, or ``None`` when nothing matched."""
        ...

    # delete_many(): returns the deleted-row count by default (return_records=False) ...
    @overload
    def delete_many(
        self,
        table: TableProtocol[ModelT, InsertT, WhereT, IncludeT, OrderByT],
        *,
        where: WhereT | None = None,
        return_records: Literal[False] = False,
    ) -> int: ...

    # ... or the deleted rows as model instances when return_records=True.
    @overload
    def delete_many(
        self,
        table: TableProtocol[ModelT, InsertT, WhereT, IncludeT, OrderByT],
        *,
        where: WhereT | None = None,
        return_records: Literal[True],
    ) -> list[ModelT]: ...
147
+
84
148
    def query_raw(self, sql: str, params: Sequence[object] | None = None, auto_commit: bool = False) -> Sequence[dict[str, object]]:
        """Run ``sql`` with ``params`` and return the result rows as dicts.

        ``auto_commit`` defaults to False. BackendBase passes True for
        DELETE/UPDATE ... RETURNING statements.
        """
        ...

    def execute_raw(self, sql: str, params: Sequence[object] | None = None, auto_commit: bool = True) -> int:
        """Run a statement without fetching rows; ``auto_commit`` defaults to True.

        BackendBase.delete_many()/update_many() return this value as the
        affected-row count.
        """
        ...
@@ -4,7 +4,9 @@ import sqlite3
4
4
  import threading
5
5
  from typing import Any, Literal, Mapping, Sequence, overload
6
6
 
7
+ from pypika import analytics as an
7
8
  from pypika.dialects import SQLLiteQuery
9
+ from pypika.enums import Order
8
10
  from pypika.queries import QueryBuilder
9
11
 
10
12
  from dclassql.typing import IncludeT, InsertT, ModelT, OrderByT, WhereT
@@ -118,6 +120,70 @@ class SQLiteBackend(BackendBase):
118
120
  result = self._execute_sql(connection, sql, params, fetch=False, auto_commit=auto_commit)
119
121
  return result
120
122
 
123
+ def find_many(
124
+ self,
125
+ table: TableProtocol[ModelT, InsertT, WhereT, IncludeT, OrderByT],
126
+ *,
127
+ where: WhereT | None = None,
128
+ include: Mapping[str, bool] | None = None,
129
+ order_by: Mapping[str, str] | None = None,
130
+ distinct: Sequence[str] | str | None = None,
131
+ take: int | None = None,
132
+ skip: int | None = None,
133
+ ) -> list[ModelT]:
134
+ distinct_columns = self._normalize_distinct(table, distinct)
135
+ if not distinct_columns:
136
+ return super().find_many(
137
+ table,
138
+ where=where,
139
+ include=include,
140
+ order_by=order_by,
141
+ distinct=distinct,
142
+ take=take,
143
+ skip=skip,
144
+ )
145
+
146
+ sql_table = self.table_cls(table.model.__name__)
147
+ # 基础查询不带 order_by,方便后续在窗口与外层统一处理
148
+ base_query, params = self._build_select_query(table, sql_table, where, None)
149
+
150
+ partition_fields = [sql_table.field(col) for col in distinct_columns]
151
+ order_pairs: list[tuple[Any, Order]] = []
152
+ if order_by:
153
+ for col, direction in order_by.items():
154
+ direction_lower = direction.lower()
155
+ if direction_lower not in {"asc", "desc"}:
156
+ raise ValueError("order_by direction must be 'asc' or 'desc'")
157
+ order_pairs.append((sql_table.field(col), Order[direction_lower]))
158
+ else:
159
+ for col in distinct_columns:
160
+ order_pairs.append((sql_table.field(col), Order.asc))
161
+
162
+ row_number = an.RowNumber().over(*partition_fields)
163
+ for field, ord_ in order_pairs:
164
+ row_number = row_number.orderby(field, order=ord_)
165
+ row_number = row_number.as_("rn")
166
+ inner_query = base_query.select(row_number)
167
+ sub = inner_query.as_("__d")
168
+ outer_query: QueryBuilder = self.query_cls.from_(sub).select(sub.star).where(sub.rn == 1)
169
+
170
+ if order_by:
171
+ for col, direction in order_by.items():
172
+ direction_lower = direction.lower()
173
+ outer_query = outer_query.orderby(getattr(sub, col), order=Order[direction_lower])
174
+
175
+ if take is not None:
176
+ outer_query = outer_query.limit(take)
177
+ if skip is not None:
178
+ outer_query = outer_query.offset(skip)
179
+ elif skip is not None:
180
+ outer_query = outer_query.limit(-1).offset(skip)
181
+
182
+ sql = outer_query.get_sql(quote_char=self.quote_char) + ";"
183
+ rows = self.query_raw(sql, params)
184
+ include_map = include or {}
185
+ return [self._materialize_instance(table, row, include_map) for row in rows]
186
+
121
187
 
122
188
  @overload
123
189
  def _execute_sql(
@@ -167,9 +167,11 @@ class WhereCompiler:
167
167
  if operator == "NOT":
168
168
  if isinstance(operand, ABCMapping):
169
169
  compiled = self._compile_filter(field, operand)
170
- else:
171
- compiled = self._compile_direct(field, operand)
172
- return compiled.negate() if compiled is not None else None
170
+ return compiled.negate() if compiled is not None else None
171
+ if operand is None:
172
+ return field.isnull().negate()
173
+ bound = self._bind_value(operand)
174
+ return (field != bound) | field.isnull()
173
175
  raise ValueError(f"Unsupported filter operator '{operator}'")
174
176
 
175
177
  def _apply_like(self, field: Field, pattern: str) -> Criterion:
@@ -0,0 +1,44 @@
1
+ from __future__ import annotations
2
+
3
+ import contextvars
4
+ from contextlib import contextmanager
5
+ from typing import Any, Iterator, Sequence
6
+
7
_current_recorder: contextvars.ContextVar[tuple[list[tuple[str, tuple[Any, ...]]], bool] | None] = (
    contextvars.ContextVar("_current_recorder", default=None)
)


@contextmanager
def record_sql(echo: bool = False) -> Iterator[list[tuple[str, tuple[Any, ...]]]]:
    """Capture the SQL statements executed in the current context.

    Yields a list that receives one ``(sql, params)`` tuple per statement
    passed to :func:`push_sql` while the context is active. Entering a nested
    ``record_sql`` shadows the outer recorder until the inner one exits.

    Args:
        echo: Also print every statement captured by this recorder.

    Example::

        with record_sql() as sqls:
            client.user.find_many(where={"id": 1})
        print(sqls)  # [('SELECT ...', (1,))]
    """
    captured: list[tuple[str, tuple[Any, ...]]] = []
    token = _current_recorder.set((captured, echo))
    try:
        yield captured
    finally:
        _current_recorder.reset(token)


def push_sql(sql: str, params: Sequence[Any], *, echo: bool) -> None:
    """Record ``sql``/``params`` in the active recorder (if any) and maybe print them.

    The statement is printed when ``echo`` is true or the active recorder was
    opened with ``echo=True``.
    """
    should_print = echo
    active = _current_recorder.get()
    if active is not None:
        sink, recorder_echo = active
        sink.append((sql, tuple(params)))
        should_print = should_print or recorder_echo
    if should_print:
        print(f"[dclassql] SQL: {sql} | params={list(params)}")
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  from dataclasses import dataclass, field
4
4
  from enum import Enum
5
5
  from types import MappingProxyType
6
- from typing import Any, Literal, Mapping, Sequence, NotRequired
6
+ from typing import Any, Literal, Mapping, Sequence, NotRequired, overload
7
7
  from typing_extensions import TypedDict
8
8
 
9
9
  from dclassql import DataSourceConfig
@@ -11,17 +11,24 @@
11
11
{%- set sortable_alias = 'T' ~ name ~ 'SortableCol' -%}
{%- set include_relations = info.relations | sort(attribute='name', case_sensitive=True) -%}
{%- set sortable_columns = info.columns -%}
{#- Names of the per-model update/upsert helper types rendered further down.
    The upsert_* names currently alias the update dict and the insert types;
    they are used in the generated upsert() signature.
    Literal members below are rendered with the tojson filter, i.e. as
    double-quoted string literals. -#}
{%- set update_dict_class = name ~ 'UpdateDict' -%}
{%- set upsert_where_alias = name ~ 'UpsertWhereDict' -%}
{%- set upsert_update_class = update_dict_class -%}
{%- set upsert_insert_class = insert_class -%}
{%- set upsert_insert_dict_class = insert_dict_class -%}

{% set backend_signature = 'BackendProtocol' %}

{% if include_relations %}
{% set include_names = include_relations | map(attribute='name') | list %}
{{ include_alias }} = Literal[{{ include_names | map('tojson') | join(', ') }}]
{% else %}
{{ include_alias }} = Literal[()]
{% endif %}
22
28
  {% if sortable_columns %}
23
- {{ sortable_alias }} = Literal[{% for column in sortable_columns %}'{{ column.name }}'{% if not loop.last %}, {% endif %}{% endfor %}]
24
- {{ distinct_alias }} = Literal[{% for column in sortable_columns %}'{{ column.name }}'{% if not loop.last %}, {% endif %}{% endfor %}]
29
+ {% set col_names = sortable_columns | map(attribute='name') | list %}
30
+ {{ sortable_alias }} = Literal[{{ col_names | map('tojson') | join(', ') }}]
31
+ {{ distinct_alias }} = Literal[{{ col_names | map('tojson') | join(', ') }}]
25
32
  {% else %}
26
33
  {{ sortable_alias }} = Literal[()]
27
34
  {{ distinct_alias }} = Literal[()]
@@ -37,6 +44,20 @@ class {{ name }}Dict(TypedDict, closed=True):
37
44
class {{ insert_dict_class }}(TypedDict, closed=True):
{{ macros.typed_dict_fields(model.typed_dict_fields)|indent(4, True) }}

class {{ update_dict_class }}(TypedDict, total=False, closed=True):
{{ macros.typed_dict_fields(model.update_fields)|indent(4, True) }}

{% for upsert_dict in model.upsert_where_dicts %}
class {{ upsert_dict.name }}(TypedDict, closed=True):
{{ macros.typed_dict_fields(upsert_dict.fields)|indent(4, True) }}

{% endfor %}
{# The upsert-where alias is the union of every generated upsert-where dict
   (just the single name when there is only one). With no such dicts, an empty
   join would render "Alias = " (a SyntaxError), so fall back to a plain
   mapping to keep the generated module importable.
   NOTE(review): confirm the desired fallback type for models without an
   upsert key. #}
{% if model.upsert_where_dicts %}
{{ upsert_where_alias }} = {{ model.upsert_where_dicts | map(attribute='name') | join(' | ') }}
{% else %}
{{ upsert_where_alias }} = Mapping[str, object]
{% endif %}
60
+
40
61
  {% for relation_filter in model.relation_filters %}
41
62
  class {{ relation_filter.name }}(TypedDict, total=False, closed=True):
42
63
  {{ macros.typed_dict_fields(relation_filter.fields)|indent(4, True) }}
@@ -67,7 +88,9 @@ class {{ table_class }}(TableProtocol):
67
88
  datasource = {{ model.datasource_expr }}
68
89
  {{ macros.column_specs(model.column_specs)|indent(4, True) }}
69
90
  column_specs_by_name: Mapping[str, ColumnSpec] = MappingProxyType({spec.name: spec for spec in column_specs})
70
- primary_key: tuple[str, ...] = {{ model.primary_key_literal }}
91
+ {% set pk_types = ['str'] * (model.model_info.primary_key | length) %}
92
+ primary_key: tuple[{{ pk_types | join(', ') }}] = {{ model.primary_key_literal }}
93
+
71
94
  indexes: tuple[tuple[str, ...], ...] = {{ model.indexes_literal }}
72
95
  unique_indexes: tuple[tuple[str, ...], ...] = {{ model.unique_indexes_literal }}
73
96
  {{ macros.foreign_keys(model.foreign_keys)|indent(4, True) }}
@@ -75,6 +98,12 @@ class {{ table_class }}(TableProtocol):
75
98
{% if model.default_factories %}
{% for factory in model.default_factories %}    {{ factory.var_name }} = {{ factory.expression }}
{% endfor %}{% endif %}
    def primary_values(self, instance: {{ name }}) -> tuple[{% if model.primary_value_types %}{{ model.primary_value_types | join(', ') }}{% else %}(){% endif %}]:
        """Return this instance's primary-key values in primary-key field order."""
        return (
{% for pk in model.model_info.primary_key %}
            instance.{{ pk }},
{% endfor %}
        )
78
107
 
79
108
  @classmethod
80
109
  def serialize_insert(cls, data: {{ insert_class }} | Mapping[str, object]) -> dict[str, object]:
@@ -91,6 +120,12 @@ class {{ table_class }}(TableProtocol):
91
120
  {% endfor %} }
92
121
  raise TypeError("Unsupported insert payload type for {{ name }}")
93
122
 
123
    @classmethod
    def serialize_update(cls, data: Mapping[str, object]) -> dict[str, object]:
        # Update payloads are serialized with the insert serializer.
        # NOTE(review): confirm serialize_insert handles a *partial* mapping
        # without filling defaults or requiring absent keys; otherwise an
        # update would overwrite columns the caller did not mention.
        if not isinstance(data, Mapping):
            raise TypeError("Update payload must be a mapping")
        return cls.serialize_insert(data)
128
+
94
129
  @classmethod
95
130
  def deserialize_row(cls, row: Mapping[str, object]) -> {{ name }}:
96
131
  instance = {{ name }}.__new__({{ name }}) # type: ignore[call-arg]
@@ -109,6 +144,26 @@ class {{ table_class }}(TableProtocol):
109
144
  def insert_many(self, data: Sequence[{{ insert_class }} | {{ insert_dict_class }}], *, batch_size: int | None = None) -> list[{{ name }}]:
110
145
  return self._backend.insert_many(self, data, batch_size=batch_size)
111
146
 
147
    def update(self, *, data: {{ update_dict_class }}, where: {{ where_dict_class }}, include: {{ include_dict_class }} | None = None) -> {{ name }}:
        """Apply data to the row matching where (delegates to the backend's update())."""
        return self._backend.update(self, data=data, where=where, include=include)

    @overload
    def update_many(self, *, data: {{ update_dict_class }}, where: {{ where_dict_class }} | None = None, return_records: Literal[False] = False) -> int: ...
    @overload
    def update_many(self, *, data: {{ update_dict_class }}, where: {{ where_dict_class }} | None = None, return_records: Literal[True]) -> list[{{ name }}]: ...
    def update_many(self, *, data: {{ update_dict_class }}, where: {{ where_dict_class }} | None = None, return_records: Literal[False, True] = False) -> int | list[{{ name }}]:
        """Update every matching row; returns the affected count, or the updated records when return_records is true."""
        return self._backend.update_many(self, data=data, where=where, return_records=return_records)

    def upsert(
        self,
        *,
        where: {{ upsert_where_alias }},
        update: {{ upsert_update_class }},
        insert: {{ upsert_insert_class }} | {{ upsert_insert_dict_class }},
        include: {{ include_dict_class }} | None = None,
    ) -> {{ name }}:
        """Delegate to the backend's upsert(): presumably update the row matching where, otherwise insert."""
        return self._backend.upsert(self, where=where, update=update, insert=insert, include=include)
166
+
112
167
  def find_many(self, *, where: {{ where_dict_class }} | None = None, include: {{ include_dict_class }} | None = None, order_by: {{ order_by_dict_class }} | None = None, distinct: {{ distinct_alias }} | Sequence[{{ distinct_alias }}] | None = None, take: int | None = None, skip: int | None = None) -> list[{{ name }}]:
113
168
  return self._backend.find_many(
114
169
  self,
@@ -122,3 +177,13 @@ class {{ table_class }}(TableProtocol):
122
177
  where=where, include=include, order_by=order_by, distinct=distinct,
123
178
  skip=skip
124
179
  )
180
+
181
    def delete(self, *, where: {{ where_dict_class }}, include: {{ include_dict_class }} | None = None) -> {{ name }} | None:
        """Delete the row matching where; returns it, or None when nothing matched."""
        return self._backend.delete(self, where=where, include=include)

    @overload
    def delete_many(self, *, where: {{ where_dict_class }} | None = None, return_records: Literal[False] = False) -> int: ...
    @overload
    def delete_many(self, *, where: {{ where_dict_class }} | None = None, return_records: Literal[True]) -> list[{{ name }}]: ...
    def delete_many(self, *, where: {{ where_dict_class }} | None = None, return_records: Literal[False, True] = False) -> int | list[{{ name }}]:
        """Delete every matching row (all rows when where is omitted); returns the count or, with return_records, the deleted records."""
        return self._backend.delete_many(self, where=where, return_records=return_records)
@@ -5,3 +5,4 @@ InsertT = TypeVar("InsertT", default=Mapping[str, object])
5
5
# Where-filter payload accepted by table/backend query methods.
WhereT = TypeVar("WhereT", bound=Mapping[str, object], default=Mapping[str, object])
# Relation-include flags: relation name -> bool.
IncludeT = TypeVar("IncludeT", bound=Mapping[str, bool], default=Mapping[str, bool])
# Ordering spec: column name -> 'asc' | 'desc'.
OrderByT = TypeVar("OrderByT", bound=Mapping[str, Literal['asc','desc']], default=Mapping[str, Literal['asc','desc']])
# Where payload for upsert(); generated tables narrow it to their <Model>UpsertWhereDict alias.
UpsertWhereT = TypeVar("UpsertWhereT", bound=Mapping[str, object], default=Mapping[str, object])
File without changes
File without changes