sql-athame 0.4.0a11__py3-none-any.whl → 0.4.0a13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sql_athame/base.py +479 -0
- sql_athame/dataclasses.py +841 -24
- {sql_athame-0.4.0a11.dist-info → sql_athame-0.4.0a13.dist-info}/METADATA +8 -5
- sql_athame-0.4.0a13.dist-info/RECORD +11 -0
- {sql_athame-0.4.0a11.dist-info → sql_athame-0.4.0a13.dist-info}/licenses/LICENSE +1 -1
- sql_athame-0.4.0a11.dist-info/RECORD +0 -11
- {sql_athame-0.4.0a11.dist-info → sql_athame-0.4.0a13.dist-info}/WHEEL +0 -0
sql_athame/dataclasses.py
CHANGED
@@ -33,6 +33,35 @@ Pool: TypeAlias = Any
|
|
33
33
|
|
34
34
|
@dataclass
|
35
35
|
class ColumnInfo:
|
36
|
+
"""Column metadata for dataclass fields.
|
37
|
+
|
38
|
+
This class specifies SQL column properties that can be applied to dataclass fields
|
39
|
+
to control how they are mapped to database columns.
|
40
|
+
|
41
|
+
Attributes:
|
42
|
+
type: SQL type name for query parameters (e.g., 'TEXT', 'INTEGER')
|
43
|
+
create_type: SQL type for CREATE TABLE statements (defaults to type if not specified)
|
44
|
+
nullable: Whether the column allows NULL values (inferred from Optional types if not specified)
|
45
|
+
constraints: Additional SQL constraints (e.g., 'UNIQUE', 'CHECK (value > 0)')
|
46
|
+
serialize: Function to transform Python values before database storage
|
47
|
+
deserialize: Function to transform database values back to Python objects
|
48
|
+
insert_only: Whether this field should only be set on INSERT, not UPDATE in upsert operations
|
49
|
+
|
50
|
+
Example:
|
51
|
+
>>> from dataclasses import dataclass
|
52
|
+
>>> from typing import Annotated
|
53
|
+
>>> from sql_athame import ModelBase, ColumnInfo
|
54
|
+
>>> import json
|
55
|
+
>>>
|
56
|
+
>>> @dataclass
|
57
|
+
... class Product(ModelBase, table_name="products", primary_key="id"):
|
58
|
+
... id: int
|
59
|
+
... name: str
|
60
|
+
... price: Annotated[float, ColumnInfo(constraints="CHECK (price > 0)")]
|
61
|
+
... tags: Annotated[list, ColumnInfo(type="JSONB", serialize=json.dumps, deserialize=json.loads)]
|
62
|
+
... created_at: Annotated[datetime, ColumnInfo(insert_only=True)]
|
63
|
+
"""
|
64
|
+
|
36
65
|
type: Optional[str] = None
|
37
66
|
create_type: Optional[str] = None
|
38
67
|
nullable: Optional[bool] = None
|
@@ -42,6 +71,7 @@ class ColumnInfo:
|
|
42
71
|
|
43
72
|
serialize: Optional[Callable[[Any], Any]] = None
|
44
73
|
deserialize: Optional[Callable[[Any], Any]] = None
|
74
|
+
insert_only: Optional[bool] = None
|
45
75
|
|
46
76
|
def __post_init__(self, constraints: Union[str, Iterable[str], None]) -> None:
|
47
77
|
if constraints is not None:
|
@@ -51,6 +81,15 @@ class ColumnInfo:
|
|
51
81
|
|
52
82
|
@staticmethod
|
53
83
|
def merge(a: "ColumnInfo", b: "ColumnInfo") -> "ColumnInfo":
|
84
|
+
"""Merge two ColumnInfo instances, with b taking precedence over a.
|
85
|
+
|
86
|
+
Args:
|
87
|
+
a: Base ColumnInfo
|
88
|
+
b: ColumnInfo to overlay on top of a
|
89
|
+
|
90
|
+
Returns:
|
91
|
+
New ColumnInfo with b's non-None values overriding a's values
|
92
|
+
"""
|
54
93
|
return ColumnInfo(
|
55
94
|
type=b.type if b.type is not None else a.type,
|
56
95
|
create_type=b.create_type if b.create_type is not None else a.create_type,
|
@@ -58,11 +97,29 @@ class ColumnInfo:
|
|
58
97
|
_constraints=(*a._constraints, *b._constraints),
|
59
98
|
serialize=b.serialize if b.serialize is not None else a.serialize,
|
60
99
|
deserialize=b.deserialize if b.deserialize is not None else a.deserialize,
|
100
|
+
insert_only=b.insert_only if b.insert_only is not None else a.insert_only,
|
61
101
|
)
|
62
102
|
|
63
103
|
|
64
104
|
@dataclass
|
65
105
|
class ConcreteColumnInfo:
|
106
|
+
"""Resolved column information for a specific dataclass field.
|
107
|
+
|
108
|
+
This is the final, computed column metadata after resolving type hints,
|
109
|
+
merging ColumnInfo instances, and applying defaults.
|
110
|
+
|
111
|
+
Attributes:
|
112
|
+
field: The dataclass Field object
|
113
|
+
type_hint: The resolved Python type hint
|
114
|
+
type: SQL type for query parameters
|
115
|
+
create_type: SQL type for CREATE TABLE statements
|
116
|
+
nullable: Whether the column allows NULL values
|
117
|
+
constraints: Tuple of SQL constraint strings
|
118
|
+
serialize: Optional serialization function
|
119
|
+
deserialize: Optional deserialization function
|
120
|
+
insert_only: Whether this field should only be set on INSERT, not UPDATE
|
121
|
+
"""
|
122
|
+
|
66
123
|
field: Field
|
67
124
|
type_hint: type
|
68
125
|
type: str
|
@@ -71,11 +128,25 @@ class ConcreteColumnInfo:
|
|
71
128
|
constraints: tuple[str, ...]
|
72
129
|
serialize: Optional[Callable[[Any], Any]] = None
|
73
130
|
deserialize: Optional[Callable[[Any], Any]] = None
|
131
|
+
insert_only: bool = False
|
74
132
|
|
75
133
|
@staticmethod
|
76
134
|
def from_column_info(
|
77
135
|
field: Field, type_hint: Any, *args: ColumnInfo
|
78
136
|
) -> "ConcreteColumnInfo":
|
137
|
+
"""Create ConcreteColumnInfo from a field and its ColumnInfo metadata.
|
138
|
+
|
139
|
+
Args:
|
140
|
+
field: The dataclass Field
|
141
|
+
type_hint: The resolved type hint for the field
|
142
|
+
*args: ColumnInfo instances to merge (later ones take precedence)
|
143
|
+
|
144
|
+
Returns:
|
145
|
+
ConcreteColumnInfo with all metadata resolved
|
146
|
+
|
147
|
+
Raises:
|
148
|
+
ValueError: If no SQL type can be determined for the field
|
149
|
+
"""
|
79
150
|
info = functools.reduce(ColumnInfo.merge, args, ColumnInfo())
|
80
151
|
if info.create_type is None and info.type is not None:
|
81
152
|
info.create_type = info.type
|
@@ -91,9 +162,15 @@ class ConcreteColumnInfo:
|
|
91
162
|
constraints=info._constraints,
|
92
163
|
serialize=info.serialize,
|
93
164
|
deserialize=info.deserialize,
|
165
|
+
insert_only=bool(info.insert_only),
|
94
166
|
)
|
95
167
|
|
96
168
|
def create_table_string(self) -> str:
|
169
|
+
"""Generate the SQL column definition for CREATE TABLE statements.
|
170
|
+
|
171
|
+
Returns:
|
172
|
+
SQL string like "TEXT NOT NULL CHECK (length > 0)"
|
173
|
+
"""
|
97
174
|
parts = (
|
98
175
|
self.create_type,
|
99
176
|
*(() if self.nullable else ("NOT NULL",)),
|
@@ -102,6 +179,14 @@ class ConcreteColumnInfo:
|
|
102
179
|
return " ".join(parts)
|
103
180
|
|
104
181
|
def maybe_serialize(self, value: Any) -> Any:
|
182
|
+
"""Apply serialization function if configured, otherwise return value unchanged.
|
183
|
+
|
184
|
+
Args:
|
185
|
+
value: The Python value to potentially serialize
|
186
|
+
|
187
|
+
Returns:
|
188
|
+
Serialized value if serialize function is configured, otherwise original value
|
189
|
+
"""
|
105
190
|
if self.serialize:
|
106
191
|
return self.serialize(value)
|
107
192
|
return value
|
@@ -179,6 +264,15 @@ class ModelBase:
|
|
179
264
|
|
180
265
|
@classmethod
|
181
266
|
def _cached(cls, key: tuple, thunk: Callable[[], U]) -> U:
|
267
|
+
"""Cache computation results by key.
|
268
|
+
|
269
|
+
Args:
|
270
|
+
key: Cache key tuple
|
271
|
+
thunk: Function to compute the value if not cached
|
272
|
+
|
273
|
+
Returns:
|
274
|
+
Cached or computed value
|
275
|
+
"""
|
182
276
|
try:
|
183
277
|
return cls._cache[key]
|
184
278
|
except KeyError:
|
@@ -187,6 +281,18 @@ class ModelBase:
|
|
187
281
|
|
188
282
|
@classmethod
|
189
283
|
def column_info_for_field(cls, field: Field, type_hint: type) -> ConcreteColumnInfo:
|
284
|
+
"""Generate ConcreteColumnInfo for a dataclass field.
|
285
|
+
|
286
|
+
Analyzes the field's type hint and metadata to determine SQL column properties.
|
287
|
+
Looks for ColumnInfo in the field's metadata and merges it with type-based defaults.
|
288
|
+
|
289
|
+
Args:
|
290
|
+
field: The dataclass Field object
|
291
|
+
type_hint: The resolved type hint for the field
|
292
|
+
|
293
|
+
Returns:
|
294
|
+
ConcreteColumnInfo with all column metadata resolved
|
295
|
+
"""
|
190
296
|
base_type = type_hint
|
191
297
|
metadata = []
|
192
298
|
if get_origin(type_hint) is Annotated:
|
@@ -202,6 +308,14 @@ class ModelBase:
|
|
202
308
|
|
203
309
|
@classmethod
|
204
310
|
def column_info(cls) -> dict[str, ConcreteColumnInfo]:
|
311
|
+
"""Get column information for all fields in this model.
|
312
|
+
|
313
|
+
Returns a cached mapping of field names to their resolved column information.
|
314
|
+
This is computed once per class and cached for performance.
|
315
|
+
|
316
|
+
Returns:
|
317
|
+
Dictionary mapping field names to ConcreteColumnInfo objects
|
318
|
+
"""
|
205
319
|
try:
|
206
320
|
return cls._column_info
|
207
321
|
except AttributeError:
|
@@ -214,35 +328,131 @@ class ModelBase:
|
|
214
328
|
|
215
329
|
@classmethod
|
216
330
|
def table_name_sql(cls, *, prefix: Optional[str] = None) -> Fragment:
|
331
|
+
"""Generate SQL fragment for the table name.
|
332
|
+
|
333
|
+
Args:
|
334
|
+
prefix: Optional schema or alias prefix
|
335
|
+
|
336
|
+
Returns:
|
337
|
+
Fragment containing the properly quoted table identifier
|
338
|
+
|
339
|
+
Example:
|
340
|
+
>>> list(User.table_name_sql())
|
341
|
+
['"users"']
|
342
|
+
>>> list(User.table_name_sql(prefix="public"))
|
343
|
+
['"public"."users"']
|
344
|
+
"""
|
217
345
|
return sql.identifier(cls.table_name, prefix=prefix)
|
218
346
|
|
219
347
|
@classmethod
|
220
348
|
def primary_key_names_sql(cls, *, prefix: Optional[str] = None) -> list[Fragment]:
|
349
|
+
"""Generate SQL fragments for primary key column names.
|
350
|
+
|
351
|
+
Args:
|
352
|
+
prefix: Optional table alias prefix
|
353
|
+
|
354
|
+
Returns:
|
355
|
+
List of Fragment objects for each primary key column
|
356
|
+
"""
|
221
357
|
return [sql.identifier(pk, prefix=prefix) for pk in cls.primary_key_names]
|
222
358
|
|
223
359
|
@classmethod
|
224
360
|
def field_names(cls, *, exclude: FieldNamesSet = ()) -> list[str]:
|
361
|
+
"""Get list of field names for this model.
|
362
|
+
|
363
|
+
Args:
|
364
|
+
exclude: Field names to exclude from the result
|
365
|
+
|
366
|
+
Returns:
|
367
|
+
List of field names as strings
|
368
|
+
"""
|
225
369
|
return [
|
226
370
|
ci.field.name
|
227
371
|
for ci in cls.column_info().values()
|
228
372
|
if ci.field.name not in exclude
|
229
373
|
]
|
230
374
|
|
375
|
+
@classmethod
|
376
|
+
def insert_only_field_names(cls) -> set[str]:
|
377
|
+
"""Get set of field names marked as insert_only in ColumnInfo.
|
378
|
+
|
379
|
+
Returns:
|
380
|
+
Set of field names that should only be set on INSERT, not UPDATE
|
381
|
+
"""
|
382
|
+
return cls._cached(
|
383
|
+
("insert_only_field_names",),
|
384
|
+
lambda: {
|
385
|
+
ci.field.name for ci in cls.column_info().values() if ci.insert_only
|
386
|
+
},
|
387
|
+
)
|
388
|
+
|
231
389
|
@classmethod
|
232
390
|
def field_names_sql(
|
233
|
-
cls,
|
391
|
+
cls,
|
392
|
+
*,
|
393
|
+
prefix: Optional[str] = None,
|
394
|
+
exclude: FieldNamesSet = (),
|
395
|
+
as_prepended: Optional[str] = None,
|
234
396
|
) -> list[Fragment]:
|
397
|
+
"""Generate SQL fragments for field names.
|
398
|
+
|
399
|
+
Args:
|
400
|
+
prefix: Optional table alias prefix for column names
|
401
|
+
exclude: Field names to exclude from the result
|
402
|
+
as_prepended: If provided, generate "column AS prepended_column" aliases
|
403
|
+
|
404
|
+
Returns:
|
405
|
+
List of Fragment objects for each field
|
406
|
+
|
407
|
+
Example:
|
408
|
+
>>> list(sql.list(User.field_names_sql()))
|
409
|
+
['"id", "name", "email"']
|
410
|
+
>>> list(sql.list(User.field_names_sql(prefix="u")))
|
411
|
+
['"u"."id", "u"."name", "u"."email"']
|
412
|
+
>>> list(sql.list(User.field_names_sql(as_prepended="user_")))
|
413
|
+
['"id" AS "user_id", "name" AS "user_name", "email" AS "user_email"']
|
414
|
+
"""
|
415
|
+
if as_prepended:
|
416
|
+
return [
|
417
|
+
sql(
|
418
|
+
"{} AS {}",
|
419
|
+
sql.identifier(f, prefix=prefix),
|
420
|
+
sql.identifier(f"{as_prepended}{f}"),
|
421
|
+
)
|
422
|
+
for f in cls.field_names(exclude=exclude)
|
423
|
+
]
|
235
424
|
return [
|
236
425
|
sql.identifier(f, prefix=prefix) for f in cls.field_names(exclude=exclude)
|
237
426
|
]
|
238
427
|
|
239
428
|
def primary_key(self) -> tuple:
|
429
|
+
"""Get the primary key value(s) for this instance.
|
430
|
+
|
431
|
+
Returns:
|
432
|
+
Tuple containing the primary key field values
|
433
|
+
|
434
|
+
Example:
|
435
|
+
>>> user = User(id=UUID(...), name="Alice")
|
436
|
+
>>> user.primary_key()
|
437
|
+
(UUID('...'),)
|
438
|
+
"""
|
240
439
|
return tuple(getattr(self, pk) for pk in self.primary_key_names)
|
241
440
|
|
242
441
|
@classmethod
|
243
442
|
def _get_field_values_fn(
|
244
443
|
cls: type[T], exclude: FieldNamesSet = ()
|
245
444
|
) -> Callable[[T], list[Any]]:
|
445
|
+
"""Generate optimized function to extract field values from instances.
|
446
|
+
|
447
|
+
This method generates and compiles a function that efficiently extracts
|
448
|
+
field values from model instances, applying serialization where needed.
|
449
|
+
|
450
|
+
Args:
|
451
|
+
exclude: Field names to exclude from value extraction
|
452
|
+
|
453
|
+
Returns:
|
454
|
+
Compiled function that takes an instance and returns field values
|
455
|
+
"""
|
246
456
|
env: dict[str, Any] = {}
|
247
457
|
func = ["def get_field_values(self): return ["]
|
248
458
|
for ci in cls.column_info().values():
|
@@ -257,6 +467,17 @@ class ModelBase:
|
|
257
467
|
return env["get_field_values"]
|
258
468
|
|
259
469
|
def field_values(self, *, exclude: FieldNamesSet = ()) -> list[Any]:
|
470
|
+
"""Get field values for this instance, with serialization applied.
|
471
|
+
|
472
|
+
Args:
|
473
|
+
exclude: Field names to exclude from the result
|
474
|
+
|
475
|
+
Returns:
|
476
|
+
List of field values in the same order as field_names()
|
477
|
+
|
478
|
+
Note:
|
479
|
+
This method applies any configured serialize functions to the values.
|
480
|
+
"""
|
260
481
|
get_field_values = self._cached(
|
261
482
|
("get_field_values", tuple(sorted(exclude))),
|
262
483
|
lambda: self._get_field_values_fn(exclude),
|
@@ -266,6 +487,15 @@ class ModelBase:
|
|
266
487
|
def field_values_sql(
|
267
488
|
self, *, exclude: FieldNamesSet = (), default_none: bool = False
|
268
489
|
) -> list[Fragment]:
|
490
|
+
"""Generate SQL fragments for field values.
|
491
|
+
|
492
|
+
Args:
|
493
|
+
exclude: Field names to exclude
|
494
|
+
default_none: If True, None values become DEFAULT literals instead of NULL
|
495
|
+
|
496
|
+
Returns:
|
497
|
+
List of Fragment objects containing value placeholders or DEFAULT
|
498
|
+
"""
|
269
499
|
if default_none:
|
270
500
|
return [
|
271
501
|
sql.literal("DEFAULT") if value is None else sql.value(value)
|
@@ -276,6 +506,15 @@ class ModelBase:
|
|
276
506
|
|
277
507
|
@classmethod
|
278
508
|
def _get_from_mapping_fn(cls: type[T]) -> Callable[[Mapping[str, Any]], T]:
|
509
|
+
"""Generate optimized function to create instances from mappings.
|
510
|
+
|
511
|
+
This method generates and compiles a function that efficiently creates
|
512
|
+
model instances from dictionary-like mappings, applying deserialization
|
513
|
+
where needed.
|
514
|
+
|
515
|
+
Returns:
|
516
|
+
Compiled function that takes a mapping and returns a model instance
|
517
|
+
"""
|
279
518
|
env: dict[str, Any] = {"cls": cls}
|
280
519
|
func = ["def from_mapping(mapping):"]
|
281
520
|
if not any(ci.deserialize for ci in cls.column_info().values()):
|
@@ -295,19 +534,77 @@ class ModelBase:
|
|
295
534
|
|
296
535
|
@classmethod
|
297
536
|
def from_mapping(cls: type[T], mapping: Mapping[str, Any], /) -> T:
|
537
|
+
"""Create a model instance from a dictionary-like mapping.
|
538
|
+
|
539
|
+
This method applies any configured deserialize functions to the values
|
540
|
+
before creating the instance.
|
541
|
+
|
542
|
+
Args:
|
543
|
+
mapping: Dictionary-like object with field names as keys
|
544
|
+
|
545
|
+
Returns:
|
546
|
+
New instance of this model class
|
547
|
+
|
548
|
+
Example:
|
549
|
+
>>> row = {"id": UUID(...), "name": "Alice", "email": None}
|
550
|
+
>>> user = User.from_mapping(row)
|
551
|
+
"""
|
298
552
|
# KLUDGE nasty but... efficient?
|
299
553
|
from_mapping_fn = cls._get_from_mapping_fn()
|
300
554
|
cls.from_mapping = from_mapping_fn # type: ignore
|
301
555
|
return from_mapping_fn(mapping)
|
302
556
|
|
557
|
+
@classmethod
|
558
|
+
def from_prepended_mapping(
|
559
|
+
cls: type[T], mapping: Mapping[str, Any], prepend: str
|
560
|
+
) -> T:
|
561
|
+
"""Create a model instance from a mapping with prefixed keys.
|
562
|
+
|
563
|
+
Useful for creating instances from JOIN query results where columns
|
564
|
+
are prefixed to avoid name conflicts.
|
565
|
+
|
566
|
+
Args:
|
567
|
+
mapping: Dictionary with prefixed keys
|
568
|
+
prepend: Prefix to strip from keys
|
569
|
+
|
570
|
+
Returns:
|
571
|
+
New instance of this model class
|
572
|
+
|
573
|
+
Example:
|
574
|
+
>>> row = {"user_id": UUID(...), "user_name": "Alice", "user_email": None}
|
575
|
+
>>> user = User.from_prepended_mapping(row, "user_")
|
576
|
+
"""
|
577
|
+
filtered_dict: dict[str, Any] = {}
|
578
|
+
for k, v in mapping.items():
|
579
|
+
if k.startswith(prepend):
|
580
|
+
filtered_dict[k[len(prepend) :]] = v
|
581
|
+
return cls.from_mapping(filtered_dict)
|
582
|
+
|
303
583
|
@classmethod
|
304
584
|
def ensure_model(cls: type[T], row: Union[T, Mapping[str, Any]]) -> T:
|
585
|
+
"""Ensure the input is a model instance, converting from mapping if needed.
|
586
|
+
|
587
|
+
Args:
|
588
|
+
row: Either a model instance or a mapping to convert
|
589
|
+
|
590
|
+
Returns:
|
591
|
+
Model instance
|
592
|
+
"""
|
305
593
|
if isinstance(row, cls):
|
306
594
|
return row
|
307
595
|
return cls.from_mapping(row) # type: ignore
|
308
596
|
|
309
597
|
@classmethod
|
310
598
|
def create_table_sql(cls) -> Fragment:
|
599
|
+
"""Generate CREATE TABLE SQL for this model.
|
600
|
+
|
601
|
+
Returns:
|
602
|
+
Fragment containing CREATE TABLE IF NOT EXISTS statement
|
603
|
+
|
604
|
+
Example:
|
605
|
+
>>> list(User.create_table_sql())
|
606
|
+
['CREATE TABLE IF NOT EXISTS "users" ("id" UUID NOT NULL, "name" TEXT NOT NULL, "email" TEXT, PRIMARY KEY ("id"))']
|
607
|
+
"""
|
311
608
|
entries = [
|
312
609
|
sql(
|
313
610
|
"{} {}",
|
@@ -331,6 +628,20 @@ class ModelBase:
|
|
331
628
|
order_by: Union[FieldNames, str] = (),
|
332
629
|
for_update: bool = False,
|
333
630
|
) -> Fragment:
|
631
|
+
"""Generate SELECT SQL for this model.
|
632
|
+
|
633
|
+
Args:
|
634
|
+
where: WHERE conditions as Fragment or iterable of Fragments
|
635
|
+
order_by: ORDER BY field names
|
636
|
+
for_update: Whether to add FOR UPDATE clause
|
637
|
+
|
638
|
+
Returns:
|
639
|
+
Fragment containing SELECT statement
|
640
|
+
|
641
|
+
Example:
|
642
|
+
>>> list(User.select_sql(where=sql("name = {}", "Alice")))
|
643
|
+
['SELECT "id", "name", "email" FROM "users" WHERE name = $1', 'Alice']
|
644
|
+
"""
|
334
645
|
if isinstance(order_by, str):
|
335
646
|
order_by = (order_by,)
|
336
647
|
if not isinstance(where, Fragment):
|
@@ -360,6 +671,16 @@ class ModelBase:
|
|
360
671
|
query: Fragment,
|
361
672
|
prefetch: int = 1000,
|
362
673
|
) -> AsyncGenerator[T, None]:
|
674
|
+
"""Create an async generator from a query result.
|
675
|
+
|
676
|
+
Args:
|
677
|
+
connection: Database connection
|
678
|
+
query: SQL query Fragment
|
679
|
+
prefetch: Number of rows to prefetch
|
680
|
+
|
681
|
+
Yields:
|
682
|
+
Model instances from the query results
|
683
|
+
"""
|
363
684
|
async for row in connection.cursor(*query, prefetch=prefetch):
|
364
685
|
yield cls.from_mapping(row)
|
365
686
|
|
@@ -372,6 +693,22 @@ class ModelBase:
|
|
372
693
|
where: Where = (),
|
373
694
|
prefetch: int = 1000,
|
374
695
|
) -> AsyncGenerator[T, None]:
|
696
|
+
"""Create an async generator for SELECT results.
|
697
|
+
|
698
|
+
Args:
|
699
|
+
connection: Database connection
|
700
|
+
order_by: ORDER BY field names
|
701
|
+
for_update: Whether to add FOR UPDATE clause
|
702
|
+
where: WHERE conditions
|
703
|
+
prefetch: Number of rows to prefetch
|
704
|
+
|
705
|
+
Yields:
|
706
|
+
Model instances from the SELECT results
|
707
|
+
|
708
|
+
Example:
|
709
|
+
>>> async for user in User.select_cursor(conn, where=sql("active = {}", True)):
|
710
|
+
... print(user.name)
|
711
|
+
"""
|
375
712
|
return cls.cursor_from(
|
376
713
|
connection,
|
377
714
|
cls.select_sql(order_by=order_by, for_update=for_update, where=where),
|
@@ -384,6 +721,15 @@ class ModelBase:
|
|
384
721
|
connection_or_pool: Union[Connection, Pool],
|
385
722
|
query: Fragment,
|
386
723
|
) -> list[T]:
|
724
|
+
"""Execute a query and return model instances.
|
725
|
+
|
726
|
+
Args:
|
727
|
+
connection_or_pool: Database connection or pool
|
728
|
+
query: SQL query Fragment
|
729
|
+
|
730
|
+
Returns:
|
731
|
+
List of model instances from the query results
|
732
|
+
"""
|
387
733
|
return [cls.from_mapping(row) for row in await connection_or_pool.fetch(*query)]
|
388
734
|
|
389
735
|
@classmethod
|
@@ -394,6 +740,20 @@ class ModelBase:
|
|
394
740
|
for_update: bool = False,
|
395
741
|
where: Where = (),
|
396
742
|
) -> list[T]:
|
743
|
+
"""Execute a SELECT query and return model instances.
|
744
|
+
|
745
|
+
Args:
|
746
|
+
connection_or_pool: Database connection or pool
|
747
|
+
order_by: ORDER BY field names
|
748
|
+
for_update: Whether to add FOR UPDATE clause
|
749
|
+
where: WHERE conditions
|
750
|
+
|
751
|
+
Returns:
|
752
|
+
List of model instances from the SELECT results
|
753
|
+
|
754
|
+
Example:
|
755
|
+
>>> users = await User.select(pool, where=sql("active = {}", True))
|
756
|
+
"""
|
397
757
|
return await cls.fetch_from(
|
398
758
|
connection_or_pool,
|
399
759
|
cls.select_sql(order_by=order_by, for_update=for_update, where=where),
|
@@ -401,6 +761,18 @@ class ModelBase:
|
|
401
761
|
|
402
762
|
@classmethod
|
403
763
|
def create_sql(cls: type[T], **kwargs: Any) -> Fragment:
|
764
|
+
"""Generate INSERT SQL for creating a new record with RETURNING clause.
|
765
|
+
|
766
|
+
Args:
|
767
|
+
**kwargs: Field values for the new record
|
768
|
+
|
769
|
+
Returns:
|
770
|
+
Fragment containing INSERT ... RETURNING statement
|
771
|
+
|
772
|
+
Example:
|
773
|
+
>>> list(User.create_sql(name="Alice", email="alice@example.com"))
|
774
|
+
['INSERT INTO "users" ("name", "email") VALUES ($1, $2) RETURNING "id", "name", "email"', 'Alice', 'alice@example.com']
|
775
|
+
"""
|
404
776
|
column_info = cls.column_info()
|
405
777
|
return sql(
|
406
778
|
"INSERT INTO {table} ({fields}) VALUES ({values}) RETURNING {out_fields}",
|
@@ -416,10 +788,35 @@ class ModelBase:
|
|
416
788
|
async def create(
|
417
789
|
cls: type[T], connection_or_pool: Union[Connection, Pool], **kwargs: Any
|
418
790
|
) -> T:
|
791
|
+
"""Create a new record in the database.
|
792
|
+
|
793
|
+
Args:
|
794
|
+
connection_or_pool: Database connection or pool
|
795
|
+
**kwargs: Field values for the new record
|
796
|
+
|
797
|
+
Returns:
|
798
|
+
Model instance representing the created record
|
799
|
+
|
800
|
+
Example:
|
801
|
+
>>> user = await User.create(pool, name="Alice", email="alice@example.com")
|
802
|
+
"""
|
419
803
|
row = await connection_or_pool.fetchrow(*cls.create_sql(**kwargs))
|
420
804
|
return cls.from_mapping(row)
|
421
805
|
|
422
806
|
def insert_sql(self, exclude: FieldNamesSet = ()) -> Fragment:
|
807
|
+
"""Generate INSERT SQL for this instance.
|
808
|
+
|
809
|
+
Args:
|
810
|
+
exclude: Field names to exclude from the INSERT
|
811
|
+
|
812
|
+
Returns:
|
813
|
+
Fragment containing INSERT statement
|
814
|
+
|
815
|
+
Example:
|
816
|
+
>>> user = User(name="Alice", email="alice@example.com")
|
817
|
+
>>> list(user.insert_sql())
|
818
|
+
['INSERT INTO "users" ("name", "email") VALUES ($1, $2)', 'Alice', 'alice@example.com']
|
819
|
+
"""
|
423
820
|
cached = self._cached(
|
424
821
|
("insert_sql", tuple(sorted(exclude))),
|
425
822
|
lambda: sql(
|
@@ -435,38 +832,136 @@ class ModelBase:
|
|
435
832
|
async def insert(
|
436
833
|
self, connection_or_pool: Union[Connection, Pool], exclude: FieldNamesSet = ()
|
437
834
|
) -> str:
|
835
|
+
"""Insert this instance into the database.
|
836
|
+
|
837
|
+
Args:
|
838
|
+
connection_or_pool: Database connection or pool
|
839
|
+
exclude: Field names to exclude from the INSERT
|
840
|
+
|
841
|
+
Returns:
|
842
|
+
Result string from the database operation
|
843
|
+
"""
|
438
844
|
return await connection_or_pool.execute(*self.insert_sql(exclude))
|
439
845
|
|
440
846
|
@classmethod
|
441
|
-
def upsert_sql(
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
847
|
+
def upsert_sql(
|
848
|
+
cls,
|
849
|
+
insert_sql: Fragment,
|
850
|
+
insert_only: FieldNamesSet = (),
|
851
|
+
force_update: FieldNamesSet = (),
|
852
|
+
) -> Fragment:
|
853
|
+
"""Generate UPSERT (INSERT ... ON CONFLICT DO UPDATE) SQL.
|
854
|
+
|
855
|
+
Args:
|
856
|
+
insert_sql: Base INSERT statement Fragment
|
857
|
+
insert_only: Field names to exclude from the UPDATE clause
|
858
|
+
force_update: Field names to force include in UPDATE clause, overriding insert_only settings
|
859
|
+
|
860
|
+
Returns:
|
861
|
+
Fragment containing INSERT ... ON CONFLICT DO UPDATE statement
|
862
|
+
|
863
|
+
Example:
|
864
|
+
>>> insert = user.insert_sql()
|
865
|
+
>>> list(User.upsert_sql(insert))
|
866
|
+
['INSERT INTO "users" ("name", "email") VALUES ($1, $2) ON CONFLICT ("id") DO UPDATE SET "name"=EXCLUDED."name", "email"=EXCLUDED."email"', 'Alice', 'alice@example.com']
|
867
|
+
|
868
|
+
Note:
|
869
|
+
Fields marked with ColumnInfo(insert_only=True) are automatically
|
870
|
+
excluded from the UPDATE clause, unless overridden by force_update.
|
871
|
+
"""
|
872
|
+
# Combine insert_only parameter with auto-detected insert_only fields, but remove force_update fields
|
873
|
+
auto_insert_only = cls.insert_only_field_names() - set(force_update)
|
874
|
+
manual_insert_only = set(insert_only) - set(
|
875
|
+
force_update
|
876
|
+
) # Remove force_update from manual insert_only too
|
877
|
+
all_insert_only = manual_insert_only | auto_insert_only
|
878
|
+
|
879
|
+
def generate_upsert_fragment():
|
880
|
+
updatable_fields = cls.field_names(
|
881
|
+
exclude=(*cls.primary_key_names, *all_insert_only)
|
882
|
+
)
|
883
|
+
return sql(
|
884
|
+
" ON CONFLICT ({pks}) DO {action}",
|
446
885
|
insert_sql=insert_sql,
|
447
886
|
pks=sql.list(cls.primary_key_names_sql()),
|
448
|
-
|
449
|
-
sql(
|
450
|
-
|
451
|
-
|
887
|
+
action=(
|
888
|
+
sql(
|
889
|
+
"UPDATE SET {assignments}",
|
890
|
+
assignments=sql.list(
|
891
|
+
sql("{field}=EXCLUDED.{field}", field=sql.identifier(field))
|
892
|
+
for field in updatable_fields
|
893
|
+
),
|
452
894
|
)
|
895
|
+
if updatable_fields
|
896
|
+
else sql.literal("NOTHING")
|
453
897
|
),
|
454
|
-
).flatten()
|
898
|
+
).flatten()
|
899
|
+
|
900
|
+
cached = cls._cached(
|
901
|
+
("upsert_sql", tuple(sorted(all_insert_only))),
|
902
|
+
generate_upsert_fragment,
|
455
903
|
)
|
456
904
|
return Fragment([insert_sql, cached])
|
457
905
|
|
458
906
|
async def upsert(
|
459
|
-
self,
|
907
|
+
self,
|
908
|
+
connection_or_pool: Union[Connection, Pool],
|
909
|
+
exclude: FieldNamesSet = (),
|
910
|
+
insert_only: FieldNamesSet = (),
|
911
|
+
force_update: FieldNamesSet = (),
|
460
912
|
) -> bool:
|
913
|
+
"""Insert or update this instance in the database.
|
914
|
+
|
915
|
+
Args:
|
916
|
+
connection_or_pool: Database connection or pool
|
917
|
+
exclude: Field names to exclude from INSERT and UPDATE
|
918
|
+
insert_only: Field names that should only be set on INSERT, not UPDATE
|
919
|
+
force_update: Field names to force include in UPDATE clause, overriding insert_only settings
|
920
|
+
|
921
|
+
Returns:
|
922
|
+
True if the record was updated, False if it was inserted
|
923
|
+
|
924
|
+
Example:
|
925
|
+
>>> user = User(id=1, name="Alice", created_at=datetime.now())
|
926
|
+
>>> # Only set created_at on INSERT, not UPDATE
|
927
|
+
>>> was_updated = await user.upsert(pool, insert_only={'created_at'})
|
928
|
+
>>> # Force update created_at even if it's marked insert_only in ColumnInfo
|
929
|
+
>>> was_updated = await user.upsert(pool, force_update={'created_at'})
|
930
|
+
|
931
|
+
Note:
|
932
|
+
Fields marked with ColumnInfo(insert_only=True) are automatically
|
933
|
+
treated as insert-only and combined with the insert_only parameter,
|
934
|
+
unless overridden by force_update.
|
935
|
+
"""
|
936
|
+
# upsert_sql automatically handles insert_only fields from ColumnInfo
|
937
|
+
# We only need to combine manual insert_only with exclude for the UPDATE clause
|
938
|
+
update_exclude = set(exclude) | set(insert_only)
|
461
939
|
query = sql(
|
462
940
|
"{} RETURNING xmax",
|
463
|
-
self.upsert_sql(
|
941
|
+
self.upsert_sql(
|
942
|
+
self.insert_sql(exclude=exclude),
|
943
|
+
insert_only=update_exclude,
|
944
|
+
force_update=force_update,
|
945
|
+
),
|
464
946
|
)
|
465
947
|
result = await connection_or_pool.fetchrow(*query)
|
466
948
|
return result["xmax"] != 0
|
467
949
|
|
468
950
|
@classmethod
|
469
951
|
def delete_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
|
952
|
+
"""Generate DELETE SQL for multiple records.
|
953
|
+
|
954
|
+
Args:
|
955
|
+
rows: Model instances to delete
|
956
|
+
|
957
|
+
Returns:
|
958
|
+
Fragment containing DELETE statement with UNNEST-based WHERE clause
|
959
|
+
|
960
|
+
Example:
|
961
|
+
>>> users = [user1, user2, user3]
|
962
|
+
>>> list(User.delete_multiple_sql(users))
|
963
|
+
['DELETE FROM "users" WHERE ("id") IN (SELECT * FROM UNNEST($1::UUID[]))', (uuid1, uuid2, uuid3)]
|
964
|
+
"""
|
470
965
|
cached = cls._cached(
|
471
966
|
("delete_multiple_sql",),
|
472
967
|
lambda: sql(
|
@@ -487,10 +982,34 @@ class ModelBase:
|
|
487
982
|
async def delete_multiple(
|
488
983
|
cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
|
489
984
|
) -> str:
|
985
|
+
"""Delete multiple records from the database.
|
986
|
+
|
987
|
+
Args:
|
988
|
+
connection_or_pool: Database connection or pool
|
989
|
+
rows: Model instances to delete
|
990
|
+
|
991
|
+
Returns:
|
992
|
+
Result string from the database operation
|
993
|
+
"""
|
490
994
|
return await connection_or_pool.execute(*cls.delete_multiple_sql(rows))
|
491
995
|
|
492
996
|
@classmethod
|
493
997
|
def insert_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
|
998
|
+
"""Generate bulk INSERT SQL using UNNEST.
|
999
|
+
|
1000
|
+
This is the most efficient method for bulk inserts in PostgreSQL.
|
1001
|
+
|
1002
|
+
Args:
|
1003
|
+
rows: Model instances to insert
|
1004
|
+
|
1005
|
+
Returns:
|
1006
|
+
Fragment containing INSERT ... SELECT FROM UNNEST statement
|
1007
|
+
|
1008
|
+
Example:
|
1009
|
+
>>> users = [User(name="Alice"), User(name="Bob")]
|
1010
|
+
>>> list(User.insert_multiple_sql(users))
|
1011
|
+
['INSERT INTO "users" ("name", "email") SELECT * FROM UNNEST($1::TEXT[], $2::TEXT[])', ('Alice', 'Bob'), (None, None)]
|
1012
|
+
"""
|
494
1013
|
cached = cls._cached(
|
495
1014
|
("insert_multiple_sql",),
|
496
1015
|
lambda: sql(
|
@@ -509,6 +1028,18 @@ class ModelBase:
|
|
509
1028
|
|
510
1029
|
@classmethod
|
511
1030
|
def insert_multiple_array_safe_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
|
1031
|
+
"""Generate bulk INSERT SQL using VALUES syntax.
|
1032
|
+
|
1033
|
+
This method is required when your model contains array columns, because
|
1034
|
+
PostgreSQL doesn't support arrays-of-arrays (which UNNEST would require).
|
1035
|
+
Use this instead of the UNNEST method when you have array-typed fields.
|
1036
|
+
|
1037
|
+
Args:
|
1038
|
+
rows: Model instances to insert
|
1039
|
+
|
1040
|
+
Returns:
|
1041
|
+
Fragment containing INSERT ... VALUES statement
|
1042
|
+
"""
|
512
1043
|
return sql(
|
513
1044
|
"INSERT INTO {table} ({fields}) VALUES {values}",
|
514
1045
|
table=cls.table_name_sql(),
|
@@ -523,6 +1054,15 @@ class ModelBase:
|
|
523
1054
|
def insert_multiple_executemany_chunk_sql(
|
524
1055
|
cls: type[T], chunk_size: int
|
525
1056
|
) -> Fragment:
|
1057
|
+
"""Generate INSERT SQL template for executemany with specific chunk size.
|
1058
|
+
|
1059
|
+
Args:
|
1060
|
+
chunk_size: Number of records per batch
|
1061
|
+
|
1062
|
+
Returns:
|
1063
|
+
Fragment containing INSERT statement with numbered placeholders
|
1064
|
+
"""
|
1065
|
+
|
526
1066
|
def generate() -> Fragment:
|
527
1067
|
columns = len(cls.column_info())
|
528
1068
|
values = ", ".join(
|
@@ -545,6 +1085,14 @@ class ModelBase:
|
|
545
1085
|
async def insert_multiple_executemany(
|
546
1086
|
cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
|
547
1087
|
) -> None:
|
1088
|
+
"""Insert multiple records using asyncpg's executemany.
|
1089
|
+
|
1090
|
+
This is the most compatible but slowest bulk insert method.
|
1091
|
+
|
1092
|
+
Args:
|
1093
|
+
connection_or_pool: Database connection or pool
|
1094
|
+
rows: Model instances to insert
|
1095
|
+
"""
|
548
1096
|
args = [r.field_values() for r in rows]
|
549
1097
|
query = cls.insert_multiple_executemany_chunk_sql(1).query()[0]
|
550
1098
|
if args:
|
@@ -554,12 +1102,36 @@ class ModelBase:
|
|
554
1102
|
async def insert_multiple_unnest(
|
555
1103
|
cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
|
556
1104
|
) -> str:
|
1105
|
+
"""Insert multiple records using PostgreSQL UNNEST.
|
1106
|
+
|
1107
|
+
This is the most efficient bulk insert method for PostgreSQL.
|
1108
|
+
|
1109
|
+
Args:
|
1110
|
+
connection_or_pool: Database connection or pool
|
1111
|
+
rows: Model instances to insert
|
1112
|
+
|
1113
|
+
Returns:
|
1114
|
+
Result string from the database operation
|
1115
|
+
"""
|
557
1116
|
return await connection_or_pool.execute(*cls.insert_multiple_sql(rows))
|
558
1117
|
|
559
1118
|
@classmethod
|
560
1119
|
async def insert_multiple_array_safe(
|
561
1120
|
cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
|
562
1121
|
) -> str:
|
1122
|
+
"""Insert multiple records using VALUES syntax with chunking.
|
1123
|
+
|
1124
|
+
This method is required when your model contains array columns, because
|
1125
|
+
PostgreSQL doesn't support arrays-of-arrays (which UNNEST would require).
|
1126
|
+
Data is processed in chunks to manage memory usage.
|
1127
|
+
|
1128
|
+
Args:
|
1129
|
+
connection_or_pool: Database connection or pool
|
1130
|
+
rows: Model instances to insert
|
1131
|
+
|
1132
|
+
Returns:
|
1133
|
+
Result string from the last chunk operation
|
1134
|
+
"""
|
563
1135
|
last = ""
|
564
1136
|
for chunk in chunked(rows, 100):
|
565
1137
|
last = await connection_or_pool.execute(
|
@@ -571,6 +1143,21 @@ class ModelBase:
|
|
571
1143
|
async def insert_multiple(
|
572
1144
|
cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
|
573
1145
|
) -> str:
|
1146
|
+
"""Insert multiple records using the configured insert_multiple_mode.
|
1147
|
+
|
1148
|
+
Args:
|
1149
|
+
connection_or_pool: Database connection or pool
|
1150
|
+
rows: Model instances to insert
|
1151
|
+
|
1152
|
+
Returns:
|
1153
|
+
Result string from the database operation
|
1154
|
+
|
1155
|
+
Note:
|
1156
|
+
The actual method used depends on the insert_multiple_mode setting:
|
1157
|
+
- 'unnest': Most efficient, uses UNNEST (default)
|
1158
|
+
- 'array_safe': Uses VALUES syntax; required when model has array columns
|
1159
|
+
- 'executemany': Uses asyncpg's executemany, slowest but most compatible
|
1160
|
+
"""
|
574
1161
|
if cls.insert_multiple_mode == "executemany":
|
575
1162
|
await cls.insert_multiple_executemany(connection_or_pool, rows)
|
576
1163
|
return "INSERT"
|
@@ -585,10 +1172,21 @@ class ModelBase:
|
|
585
1172
|
connection_or_pool: Union[Connection, Pool],
|
586
1173
|
rows: Iterable[T],
|
587
1174
|
insert_only: FieldNamesSet = (),
|
1175
|
+
force_update: FieldNamesSet = (),
|
588
1176
|
) -> None:
|
1177
|
+
"""Bulk upsert using asyncpg's executemany.
|
1178
|
+
|
1179
|
+
Args:
|
1180
|
+
connection_or_pool: Database connection or pool
|
1181
|
+
rows: Model instances to upsert
|
1182
|
+
insert_only: Field names that should only be set on INSERT, not UPDATE
|
1183
|
+
force_update: Field names to force include in UPDATE clause, overriding insert_only settings
|
1184
|
+
"""
|
589
1185
|
args = [r.field_values() for r in rows]
|
590
1186
|
query = cls.upsert_sql(
|
591
|
-
cls.insert_multiple_executemany_chunk_sql(1),
|
1187
|
+
cls.insert_multiple_executemany_chunk_sql(1),
|
1188
|
+
insert_only=insert_only,
|
1189
|
+
force_update=force_update,
|
592
1190
|
).query()[0]
|
593
1191
|
if args:
|
594
1192
|
await connection_or_pool.executemany(query, args)
|
@@ -599,9 +1197,25 @@ class ModelBase:
|
|
599
1197
|
connection_or_pool: Union[Connection, Pool],
|
600
1198
|
rows: Iterable[T],
|
601
1199
|
insert_only: FieldNamesSet = (),
|
1200
|
+
force_update: FieldNamesSet = (),
|
602
1201
|
) -> str:
|
1202
|
+
"""Bulk upsert using PostgreSQL UNNEST.
|
1203
|
+
|
1204
|
+
Args:
|
1205
|
+
connection_or_pool: Database connection or pool
|
1206
|
+
rows: Model instances to upsert
|
1207
|
+
insert_only: Field names that should only be set on INSERT, not UPDATE
|
1208
|
+
force_update: Field names to force include in UPDATE clause, overriding insert_only settings
|
1209
|
+
|
1210
|
+
Returns:
|
1211
|
+
Result string from the database operation
|
1212
|
+
"""
|
603
1213
|
return await connection_or_pool.execute(
|
604
|
-
*cls.upsert_sql(
|
1214
|
+
*cls.upsert_sql(
|
1215
|
+
cls.insert_multiple_sql(rows),
|
1216
|
+
insert_only=insert_only,
|
1217
|
+
force_update=force_update,
|
1218
|
+
)
|
605
1219
|
)
|
606
1220
|
|
607
1221
|
@classmethod
|
@@ -610,12 +1224,29 @@ class ModelBase:
|
|
610
1224
|
connection_or_pool: Union[Connection, Pool],
|
611
1225
|
rows: Iterable[T],
|
612
1226
|
insert_only: FieldNamesSet = (),
|
1227
|
+
force_update: FieldNamesSet = (),
|
613
1228
|
) -> str:
|
1229
|
+
"""Bulk upsert using VALUES syntax with chunking.
|
1230
|
+
|
1231
|
+
This method is required when your model contains array columns, because
|
1232
|
+
PostgreSQL doesn't support arrays-of-arrays (which UNNEST would require).
|
1233
|
+
|
1234
|
+
Args:
|
1235
|
+
connection_or_pool: Database connection or pool
|
1236
|
+
rows: Model instances to upsert
|
1237
|
+
insert_only: Field names that should only be set on INSERT, not UPDATE
|
1238
|
+
force_update: Field names to force include in UPDATE clause, overriding insert_only settings
|
1239
|
+
|
1240
|
+
Returns:
|
1241
|
+
Result string from the last chunk operation
|
1242
|
+
"""
|
614
1243
|
last = ""
|
615
1244
|
for chunk in chunked(rows, 100):
|
616
1245
|
last = await connection_or_pool.execute(
|
617
1246
|
*cls.upsert_sql(
|
618
|
-
cls.insert_multiple_array_safe_sql(chunk),
|
1247
|
+
cls.insert_multiple_array_safe_sql(chunk),
|
1248
|
+
insert_only=insert_only,
|
1249
|
+
force_update=force_update,
|
619
1250
|
)
|
620
1251
|
)
|
621
1252
|
return last
|
@@ -626,25 +1257,66 @@ class ModelBase:
|
|
626
1257
|
connection_or_pool: Union[Connection, Pool],
|
627
1258
|
rows: Iterable[T],
|
628
1259
|
insert_only: FieldNamesSet = (),
|
1260
|
+
force_update: FieldNamesSet = (),
|
629
1261
|
) -> str:
|
1262
|
+
"""Bulk upsert (INSERT ... ON CONFLICT DO UPDATE) multiple records.
|
1263
|
+
|
1264
|
+
Args:
|
1265
|
+
connection_or_pool: Database connection or pool
|
1266
|
+
rows: Model instances to upsert
|
1267
|
+
insert_only: Field names that should only be set on INSERT, not UPDATE
|
1268
|
+
force_update: Field names to force include in UPDATE clause, overriding insert_only settings
|
1269
|
+
|
1270
|
+
Returns:
|
1271
|
+
Result string from the database operation
|
1272
|
+
|
1273
|
+
Example:
|
1274
|
+
>>> await User.upsert_multiple(pool, users, insert_only={'created_at'})
|
1275
|
+
>>> await User.upsert_multiple(pool, users, force_update={'created_at'})
|
1276
|
+
|
1277
|
+
Note:
|
1278
|
+
Fields marked with ColumnInfo(insert_only=True) are automatically
|
1279
|
+
treated as insert-only and combined with the insert_only parameter,
|
1280
|
+
unless overridden by force_update.
|
1281
|
+
"""
|
1282
|
+
# upsert_sql automatically handles insert_only fields from ColumnInfo
|
1283
|
+
# Pass manual insert_only parameter through to the specific implementations
|
1284
|
+
|
630
1285
|
if cls.insert_multiple_mode == "executemany":
|
631
1286
|
await cls.upsert_multiple_executemany(
|
632
|
-
connection_or_pool,
|
1287
|
+
connection_or_pool,
|
1288
|
+
rows,
|
1289
|
+
insert_only=insert_only,
|
1290
|
+
force_update=force_update,
|
633
1291
|
)
|
634
1292
|
return "INSERT"
|
635
1293
|
elif cls.insert_multiple_mode == "array_safe":
|
636
1294
|
return await cls.upsert_multiple_array_safe(
|
637
|
-
connection_or_pool,
|
1295
|
+
connection_or_pool,
|
1296
|
+
rows,
|
1297
|
+
insert_only=insert_only,
|
1298
|
+
force_update=force_update,
|
638
1299
|
)
|
639
1300
|
else:
|
640
1301
|
return await cls.upsert_multiple_unnest(
|
641
|
-
connection_or_pool,
|
1302
|
+
connection_or_pool,
|
1303
|
+
rows,
|
1304
|
+
insert_only=insert_only,
|
1305
|
+
force_update=force_update,
|
642
1306
|
)
|
643
1307
|
|
644
1308
|
@classmethod
|
645
1309
|
def _get_equal_ignoring_fn(
|
646
1310
|
cls: type[T], ignore: FieldNamesSet = ()
|
647
1311
|
) -> Callable[[T, T], bool]:
|
1312
|
+
"""Generate optimized function to compare instances ignoring certain fields.
|
1313
|
+
|
1314
|
+
Args:
|
1315
|
+
ignore: Field names to ignore during comparison
|
1316
|
+
|
1317
|
+
Returns:
|
1318
|
+
Compiled function that compares two instances, returning True if equal
|
1319
|
+
"""
|
648
1320
|
env: dict[str, Any] = {}
|
649
1321
|
func = ["def equal_ignoring(a, b):"]
|
650
1322
|
for ci in cls.column_info().values():
|
@@ -663,8 +1335,38 @@ class ModelBase:
|
|
663
1335
|
where: Where,
|
664
1336
|
ignore: FieldNamesSet = (),
|
665
1337
|
insert_only: FieldNamesSet = (),
|
1338
|
+
force_update: FieldNamesSet = (),
|
666
1339
|
) -> "ReplaceMultiplePlan[T]":
|
667
|
-
|
1340
|
+
"""Plan a replace operation by comparing new data with existing records.
|
1341
|
+
|
1342
|
+
This method analyzes the differences between the provided rows and existing
|
1343
|
+
database records, determining which records need to be created, updated, or deleted.
|
1344
|
+
|
1345
|
+
Args:
|
1346
|
+
connection: Database connection (must support FOR UPDATE)
|
1347
|
+
rows: New data as model instances or mappings
|
1348
|
+
where: WHERE clause to limit which existing records to consider
|
1349
|
+
ignore: Field names to ignore when comparing records
|
1350
|
+
insert_only: Field names that should only be set on INSERT, not UPDATE
|
1351
|
+
force_update: Field names to force include in UPDATE clause, overriding insert_only settings
|
1352
|
+
|
1353
|
+
Returns:
|
1354
|
+
ReplaceMultiplePlan containing the planned operations
|
1355
|
+
|
1356
|
+
Example:
|
1357
|
+
>>> plan = await User.plan_replace_multiple(
|
1358
|
+
... conn, new_users, where=sql("department_id = {}", dept_id)
|
1359
|
+
... )
|
1360
|
+
>>> print(f"Will create {len(plan.created)}, update {len(plan.updated)}, delete {len(plan.deleted)}")
|
1361
|
+
|
1362
|
+
Note:
|
1363
|
+
Fields marked with ColumnInfo(insert_only=True) are automatically
|
1364
|
+
treated as insert-only and combined with the insert_only parameter,
|
1365
|
+
unless overridden by force_update.
|
1366
|
+
"""
|
1367
|
+
# For comparison purposes, combine auto-detected insert_only fields with manual ones
|
1368
|
+
all_insert_only = cls.insert_only_field_names() | set(insert_only)
|
1369
|
+
ignore = sorted(set(ignore) | all_insert_only)
|
668
1370
|
equal_ignoring = cls._cached(
|
669
1371
|
("equal_ignoring", tuple(ignore)),
|
670
1372
|
lambda: cls._get_equal_ignoring_fn(ignore),
|
@@ -687,7 +1389,11 @@ class ModelBase:
|
|
687
1389
|
|
688
1390
|
created = list(pending.values())
|
689
1391
|
|
690
|
-
|
1392
|
+
# Pass only manual insert_only and force_update to the plan
|
1393
|
+
# since upsert_multiple handles auto-detected ones
|
1394
|
+
return ReplaceMultiplePlan(
|
1395
|
+
cls, insert_only, force_update, created, updated, deleted
|
1396
|
+
)
|
691
1397
|
|
692
1398
|
@classmethod
|
693
1399
|
async def replace_multiple(
|
@@ -698,9 +1404,42 @@ class ModelBase:
|
|
698
1404
|
where: Where,
|
699
1405
|
ignore: FieldNamesSet = (),
|
700
1406
|
insert_only: FieldNamesSet = (),
|
1407
|
+
force_update: FieldNamesSet = (),
|
701
1408
|
) -> tuple[list[T], list[T], list[T]]:
|
1409
|
+
"""Replace records in the database with the provided data.
|
1410
|
+
|
1411
|
+
This is a complete replace operation: records matching the WHERE clause
|
1412
|
+
that aren't in the new data will be deleted, new records will be inserted,
|
1413
|
+
and changed records will be updated.
|
1414
|
+
|
1415
|
+
Args:
|
1416
|
+
connection: Database connection (must support FOR UPDATE)
|
1417
|
+
rows: New data as model instances or mappings
|
1418
|
+
where: WHERE clause to limit which existing records to consider for replacement
|
1419
|
+
ignore: Field names to ignore when comparing records
|
1420
|
+
insert_only: Field names that should only be set on INSERT, not UPDATE
|
1421
|
+
force_update: Field names to force include in UPDATE clause, overriding insert_only settings
|
1422
|
+
|
1423
|
+
Returns:
|
1424
|
+
Tuple of (created_records, updated_records, deleted_records)
|
1425
|
+
|
1426
|
+
Example:
|
1427
|
+
>>> created, updated, deleted = await User.replace_multiple(
|
1428
|
+
... conn, new_users, where=sql("department_id = {}", dept_id)
|
1429
|
+
... )
|
1430
|
+
|
1431
|
+
Note:
|
1432
|
+
Fields marked with ColumnInfo(insert_only=True) are automatically
|
1433
|
+
treated as insert-only and combined with the insert_only parameter,
|
1434
|
+
unless overridden by force_update.
|
1435
|
+
"""
|
702
1436
|
plan = await cls.plan_replace_multiple(
|
703
|
-
connection,
|
1437
|
+
connection,
|
1438
|
+
rows,
|
1439
|
+
where=where,
|
1440
|
+
ignore=ignore,
|
1441
|
+
insert_only=insert_only,
|
1442
|
+
force_update=force_update,
|
704
1443
|
)
|
705
1444
|
await plan.execute(connection)
|
706
1445
|
return plan.cud
|
@@ -709,6 +1448,14 @@ class ModelBase:
|
|
709
1448
|
def _get_differences_ignoring_fn(
|
710
1449
|
cls: type[T], ignore: FieldNamesSet = ()
|
711
1450
|
) -> Callable[[T, T], list[str]]:
|
1451
|
+
"""Generate optimized function to find field differences between instances.
|
1452
|
+
|
1453
|
+
Args:
|
1454
|
+
ignore: Field names to ignore during comparison
|
1455
|
+
|
1456
|
+
Returns:
|
1457
|
+
Compiled function that returns list of field names that differ
|
1458
|
+
"""
|
712
1459
|
env: dict[str, Any] = {}
|
713
1460
|
func = [
|
714
1461
|
"def differences_ignoring(a, b):",
|
@@ -732,8 +1479,40 @@ class ModelBase:
|
|
732
1479
|
where: Where,
|
733
1480
|
ignore: FieldNamesSet = (),
|
734
1481
|
insert_only: FieldNamesSet = (),
|
1482
|
+
force_update: FieldNamesSet = (),
|
735
1483
|
) -> tuple[list[T], list[tuple[T, T, list[str]]], list[T]]:
|
736
|
-
|
1484
|
+
"""Replace records and report the specific field differences for updates.
|
1485
|
+
|
1486
|
+
Like replace_multiple, but provides detailed information about which
|
1487
|
+
fields changed for each updated record.
|
1488
|
+
|
1489
|
+
Args:
|
1490
|
+
connection: Database connection (must support FOR UPDATE)
|
1491
|
+
rows: New data as model instances or mappings
|
1492
|
+
where: WHERE clause to limit which existing records to consider
|
1493
|
+
ignore: Field names to ignore when comparing records
|
1494
|
+
insert_only: Field names that should only be set on INSERT, not UPDATE
|
1495
|
+
force_update: Field names to force include in UPDATE clause, overriding insert_only settings
|
1496
|
+
|
1497
|
+
Returns:
|
1498
|
+
Tuple of (created_records, update_triples, deleted_records)
|
1499
|
+
where update_triples contains (old_record, new_record, changed_field_names)
|
1500
|
+
|
1501
|
+
Example:
|
1502
|
+
>>> created, updates, deleted = await User.replace_multiple_reporting_differences(
|
1503
|
+
... conn, new_users, where=sql("department_id = {}", dept_id)
|
1504
|
+
... )
|
1505
|
+
>>> for old, new, fields in updates:
|
1506
|
+
... print(f"Updated {old.name}: changed {', '.join(fields)}")
|
1507
|
+
|
1508
|
+
Note:
|
1509
|
+
Fields marked with ColumnInfo(insert_only=True) are automatically
|
1510
|
+
treated as insert-only and combined with the insert_only parameter,
|
1511
|
+
unless overridden by force_update.
|
1512
|
+
"""
|
1513
|
+
# For comparison purposes, combine auto-detected insert_only fields with manual ones
|
1514
|
+
all_insert_only = cls.insert_only_field_names() | set(insert_only)
|
1515
|
+
ignore = sorted(set(ignore) | all_insert_only)
|
737
1516
|
differences_ignoring = cls._cached(
|
738
1517
|
("differences_ignoring", tuple(ignore)),
|
739
1518
|
lambda: cls._get_differences_ignoring_fn(ignore),
|
@@ -763,6 +1542,7 @@ class ModelBase:
|
|
763
1542
|
connection,
|
764
1543
|
(*created, *(t[1] for t in updated_triples)),
|
765
1544
|
insert_only=insert_only,
|
1545
|
+
force_update=force_update,
|
766
1546
|
)
|
767
1547
|
if deleted:
|
768
1548
|
await cls.delete_multiple(connection, deleted)
|
@@ -774,30 +1554,67 @@ class ModelBase:
|
|
774
1554
|
class ReplaceMultiplePlan(Generic[T]):
|
775
1555
|
model_class: type[T]
|
776
1556
|
insert_only: FieldNamesSet
|
1557
|
+
force_update: FieldNamesSet
|
777
1558
|
created: list[T]
|
778
1559
|
updated: list[T]
|
779
1560
|
deleted: list[T]
|
780
1561
|
|
781
1562
|
@property
|
782
1563
|
def cud(self) -> tuple[list[T], list[T], list[T]]:
|
1564
|
+
"""Get the create, update, delete lists as a tuple.
|
1565
|
+
|
1566
|
+
Returns:
|
1567
|
+
Tuple of (created, updated, deleted) record lists
|
1568
|
+
"""
|
783
1569
|
return (self.created, self.updated, self.deleted)
|
784
1570
|
|
785
1571
|
async def execute_upserts(self, connection: Connection) -> None:
|
1572
|
+
"""Execute the upsert operations (creates and updates).
|
1573
|
+
|
1574
|
+
Args:
|
1575
|
+
connection: Database connection
|
1576
|
+
"""
|
786
1577
|
if self.created or self.updated:
|
787
1578
|
await self.model_class.upsert_multiple(
|
788
|
-
connection,
|
1579
|
+
connection,
|
1580
|
+
(*self.created, *self.updated),
|
1581
|
+
insert_only=self.insert_only,
|
1582
|
+
force_update=self.force_update,
|
789
1583
|
)
|
790
1584
|
|
791
1585
|
async def execute_deletes(self, connection: Connection) -> None:
|
1586
|
+
"""Execute the delete operations.
|
1587
|
+
|
1588
|
+
Args:
|
1589
|
+
connection: Database connection
|
1590
|
+
"""
|
792
1591
|
if self.deleted:
|
793
1592
|
await self.model_class.delete_multiple(connection, self.deleted)
|
794
1593
|
|
795
1594
|
async def execute(self, connection: Connection) -> None:
|
1595
|
+
"""Execute all planned operations (upserts then deletes).
|
1596
|
+
|
1597
|
+
Args:
|
1598
|
+
connection: Database connection
|
1599
|
+
"""
|
796
1600
|
await self.execute_upserts(connection)
|
797
1601
|
await self.execute_deletes(connection)
|
798
1602
|
|
799
1603
|
|
800
1604
|
def chunked(lst, n):
|
1605
|
+
"""Split an iterable into chunks of size n.
|
1606
|
+
|
1607
|
+
Args:
|
1608
|
+
lst: Iterable to chunk
|
1609
|
+
n: Chunk size
|
1610
|
+
|
1611
|
+
Yields:
|
1612
|
+
Lists of up to n items from the input
|
1613
|
+
|
1614
|
+
Example:
|
1615
|
+
>>> list(chunked([1, 2, 3, 4, 5], 2))
|
1616
|
+
[[1, 2], [3, 4], [5]]
|
1617
|
+
"""
|
801
1618
|
if type(lst) is not list:
|
802
1619
|
lst = list(lst)
|
803
1620
|
for i in range(0, len(lst), n):
|