sql-athame 0.4.0a12__py3-none-any.whl → 0.4.0a13__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
sql_athame/dataclasses.py CHANGED
@@ -33,6 +33,35 @@ Pool: TypeAlias = Any
33
33
 
34
34
  @dataclass
35
35
  class ColumnInfo:
36
+ """Column metadata for dataclass fields.
37
+
38
+ This class specifies SQL column properties that can be applied to dataclass fields
39
+ to control how they are mapped to database columns.
40
+
41
+ Attributes:
42
+ type: SQL type name for query parameters (e.g., 'TEXT', 'INTEGER')
43
+ create_type: SQL type for CREATE TABLE statements (defaults to type if not specified)
44
+ nullable: Whether the column allows NULL values (inferred from Optional types if not specified)
45
+ constraints: Additional SQL constraints (e.g., 'UNIQUE', 'CHECK (value > 0)')
46
+ serialize: Function to transform Python values before database storage
47
+ deserialize: Function to transform database values back to Python objects
48
+ insert_only: Whether this field should only be set on INSERT, not UPDATE in upsert operations
49
+
50
+ Example:
51
+ >>> from dataclasses import dataclass
52
+ >>> from typing import Annotated
53
+ >>> from sql_athame import ModelBase, ColumnInfo
54
+ >>> import json
55
+ >>>
56
+ >>> @dataclass
57
+ ... class Product(ModelBase, table_name="products", primary_key="id"):
58
+ ... id: int
59
+ ... name: str
60
+ ... price: Annotated[float, ColumnInfo(constraints="CHECK (price > 0)")]
61
+ ... tags: Annotated[list, ColumnInfo(type="JSONB", serialize=json.dumps, deserialize=json.loads)]
62
+ ... created_at: Annotated[datetime, ColumnInfo(insert_only=True)]
63
+ """
64
+
36
65
  type: Optional[str] = None
37
66
  create_type: Optional[str] = None
38
67
  nullable: Optional[bool] = None
@@ -42,6 +71,7 @@ class ColumnInfo:
42
71
 
43
72
  serialize: Optional[Callable[[Any], Any]] = None
44
73
  deserialize: Optional[Callable[[Any], Any]] = None
74
+ insert_only: Optional[bool] = None
45
75
 
46
76
  def __post_init__(self, constraints: Union[str, Iterable[str], None]) -> None:
47
77
  if constraints is not None:
@@ -51,6 +81,15 @@ class ColumnInfo:
51
81
 
52
82
  @staticmethod
53
83
  def merge(a: "ColumnInfo", b: "ColumnInfo") -> "ColumnInfo":
84
+ """Merge two ColumnInfo instances, with b taking precedence over a.
85
+
86
+ Args:
87
+ a: Base ColumnInfo
88
+ b: ColumnInfo to overlay on top of a
89
+
90
+ Returns:
91
+ New ColumnInfo with b's non-None values overriding a's values
92
+ """
54
93
  return ColumnInfo(
55
94
  type=b.type if b.type is not None else a.type,
56
95
  create_type=b.create_type if b.create_type is not None else a.create_type,
@@ -58,11 +97,29 @@ class ColumnInfo:
58
97
  _constraints=(*a._constraints, *b._constraints),
59
98
  serialize=b.serialize if b.serialize is not None else a.serialize,
60
99
  deserialize=b.deserialize if b.deserialize is not None else a.deserialize,
100
+ insert_only=b.insert_only if b.insert_only is not None else a.insert_only,
61
101
  )
62
102
 
63
103
 
64
104
  @dataclass
65
105
  class ConcreteColumnInfo:
106
+ """Resolved column information for a specific dataclass field.
107
+
108
+ This is the final, computed column metadata after resolving type hints,
109
+ merging ColumnInfo instances, and applying defaults.
110
+
111
+ Attributes:
112
+ field: The dataclass Field object
113
+ type_hint: The resolved Python type hint
114
+ type: SQL type for query parameters
115
+ create_type: SQL type for CREATE TABLE statements
116
+ nullable: Whether the column allows NULL values
117
+ constraints: Tuple of SQL constraint strings
118
+ serialize: Optional serialization function
119
+ deserialize: Optional deserialization function
120
+ insert_only: Whether this field should only be set on INSERT, not UPDATE
121
+ """
122
+
66
123
  field: Field
67
124
  type_hint: type
68
125
  type: str
@@ -71,11 +128,25 @@ class ConcreteColumnInfo:
71
128
  constraints: tuple[str, ...]
72
129
  serialize: Optional[Callable[[Any], Any]] = None
73
130
  deserialize: Optional[Callable[[Any], Any]] = None
131
+ insert_only: bool = False
74
132
 
75
133
  @staticmethod
76
134
  def from_column_info(
77
135
  field: Field, type_hint: Any, *args: ColumnInfo
78
136
  ) -> "ConcreteColumnInfo":
137
+ """Create ConcreteColumnInfo from a field and its ColumnInfo metadata.
138
+
139
+ Args:
140
+ field: The dataclass Field
141
+ type_hint: The resolved type hint for the field
142
+ *args: ColumnInfo instances to merge (later ones take precedence)
143
+
144
+ Returns:
145
+ ConcreteColumnInfo with all metadata resolved
146
+
147
+ Raises:
148
+ ValueError: If no SQL type can be determined for the field
149
+ """
79
150
  info = functools.reduce(ColumnInfo.merge, args, ColumnInfo())
80
151
  if info.create_type is None and info.type is not None:
81
152
  info.create_type = info.type
@@ -91,9 +162,15 @@ class ConcreteColumnInfo:
91
162
  constraints=info._constraints,
92
163
  serialize=info.serialize,
93
164
  deserialize=info.deserialize,
165
+ insert_only=bool(info.insert_only),
94
166
  )
95
167
 
96
168
  def create_table_string(self) -> str:
169
+ """Generate the SQL column definition for CREATE TABLE statements.
170
+
171
+ Returns:
172
+ SQL string like "TEXT NOT NULL CHECK (length > 0)"
173
+ """
97
174
  parts = (
98
175
  self.create_type,
99
176
  *(() if self.nullable else ("NOT NULL",)),
@@ -102,6 +179,14 @@ class ConcreteColumnInfo:
102
179
  return " ".join(parts)
103
180
 
104
181
  def maybe_serialize(self, value: Any) -> Any:
182
+ """Apply serialization function if configured, otherwise return value unchanged.
183
+
184
+ Args:
185
+ value: The Python value to potentially serialize
186
+
187
+ Returns:
188
+ Serialized value if serialize function is configured, otherwise original value
189
+ """
105
190
  if self.serialize:
106
191
  return self.serialize(value)
107
192
  return value
@@ -179,6 +264,15 @@ class ModelBase:
179
264
 
180
265
  @classmethod
181
266
  def _cached(cls, key: tuple, thunk: Callable[[], U]) -> U:
267
+ """Cache computation results by key.
268
+
269
+ Args:
270
+ key: Cache key tuple
271
+ thunk: Function to compute the value if not cached
272
+
273
+ Returns:
274
+ Cached or computed value
275
+ """
182
276
  try:
183
277
  return cls._cache[key]
184
278
  except KeyError:
@@ -187,6 +281,18 @@ class ModelBase:
187
281
 
188
282
  @classmethod
189
283
  def column_info_for_field(cls, field: Field, type_hint: type) -> ConcreteColumnInfo:
284
+ """Generate ConcreteColumnInfo for a dataclass field.
285
+
286
+ Analyzes the field's type hint and metadata to determine SQL column properties.
287
+ Looks for ColumnInfo in the field's metadata and merges it with type-based defaults.
288
+
289
+ Args:
290
+ field: The dataclass Field object
291
+ type_hint: The resolved type hint for the field
292
+
293
+ Returns:
294
+ ConcreteColumnInfo with all column metadata resolved
295
+ """
190
296
  base_type = type_hint
191
297
  metadata = []
192
298
  if get_origin(type_hint) is Annotated:
@@ -202,6 +308,14 @@ class ModelBase:
202
308
 
203
309
  @classmethod
204
310
  def column_info(cls) -> dict[str, ConcreteColumnInfo]:
311
+ """Get column information for all fields in this model.
312
+
313
+ Returns a cached mapping of field names to their resolved column information.
314
+ This is computed once per class and cached for performance.
315
+
316
+ Returns:
317
+ Dictionary mapping field names to ConcreteColumnInfo objects
318
+ """
205
319
  try:
206
320
  return cls._column_info
207
321
  except AttributeError:
@@ -214,20 +328,64 @@ class ModelBase:
214
328
 
215
329
  @classmethod
216
330
  def table_name_sql(cls, *, prefix: Optional[str] = None) -> Fragment:
331
+ """Generate SQL fragment for the table name.
332
+
333
+ Args:
334
+ prefix: Optional schema or alias prefix
335
+
336
+ Returns:
337
+ Fragment containing the properly quoted table identifier
338
+
339
+ Example:
340
+ >>> list(User.table_name_sql())
341
+ ['"users"']
342
+ >>> list(User.table_name_sql(prefix="public"))
343
+ ['"public"."users"']
344
+ """
217
345
  return sql.identifier(cls.table_name, prefix=prefix)
218
346
 
219
347
  @classmethod
220
348
  def primary_key_names_sql(cls, *, prefix: Optional[str] = None) -> list[Fragment]:
349
+ """Generate SQL fragments for primary key column names.
350
+
351
+ Args:
352
+ prefix: Optional table alias prefix
353
+
354
+ Returns:
355
+ List of Fragment objects for each primary key column
356
+ """
221
357
  return [sql.identifier(pk, prefix=prefix) for pk in cls.primary_key_names]
222
358
 
223
359
  @classmethod
224
360
  def field_names(cls, *, exclude: FieldNamesSet = ()) -> list[str]:
361
+ """Get list of field names for this model.
362
+
363
+ Args:
364
+ exclude: Field names to exclude from the result
365
+
366
+ Returns:
367
+ List of field names as strings
368
+ """
225
369
  return [
226
370
  ci.field.name
227
371
  for ci in cls.column_info().values()
228
372
  if ci.field.name not in exclude
229
373
  ]
230
374
 
375
+ @classmethod
376
+ def insert_only_field_names(cls) -> set[str]:
377
+ """Get set of field names marked as insert_only in ColumnInfo.
378
+
379
+ Returns:
380
+ Set of field names that should only be set on INSERT, not UPDATE
381
+ """
382
+ return cls._cached(
383
+ ("insert_only_field_names",),
384
+ lambda: {
385
+ ci.field.name for ci in cls.column_info().values() if ci.insert_only
386
+ },
387
+ )
388
+
231
389
  @classmethod
232
390
  def field_names_sql(
233
391
  cls,
@@ -236,6 +394,24 @@ class ModelBase:
236
394
  exclude: FieldNamesSet = (),
237
395
  as_prepended: Optional[str] = None,
238
396
  ) -> list[Fragment]:
397
+ """Generate SQL fragments for field names.
398
+
399
+ Args:
400
+ prefix: Optional table alias prefix for column names
401
+ exclude: Field names to exclude from the result
402
+ as_prepended: If provided, generate "column AS prepended_column" aliases
403
+
404
+ Returns:
405
+ List of Fragment objects for each field
406
+
407
+ Example:
408
+ >>> list(sql.list(User.field_names_sql()))
409
+ ['"id", "name", "email"']
410
+ >>> list(sql.list(User.field_names_sql(prefix="u")))
411
+ ['"u"."id", "u"."name", "u"."email"']
412
+ >>> list(sql.list(User.field_names_sql(as_prepended="user_")))
413
+ ['"id" AS "user_id", "name" AS "user_name", "email" AS "user_email"']
414
+ """
239
415
  if as_prepended:
240
416
  return [
241
417
  sql(
@@ -250,12 +426,33 @@ class ModelBase:
250
426
  ]
251
427
 
252
428
  def primary_key(self) -> tuple:
429
+ """Get the primary key value(s) for this instance.
430
+
431
+ Returns:
432
+ Tuple containing the primary key field values
433
+
434
+ Example:
435
+ >>> user = User(id=UUID(...), name="Alice")
436
+ >>> user.primary_key()
437
+ (UUID('...'),)
438
+ """
253
439
  return tuple(getattr(self, pk) for pk in self.primary_key_names)
254
440
 
255
441
  @classmethod
256
442
  def _get_field_values_fn(
257
443
  cls: type[T], exclude: FieldNamesSet = ()
258
444
  ) -> Callable[[T], list[Any]]:
445
+ """Generate optimized function to extract field values from instances.
446
+
447
+ This method generates and compiles a function that efficiently extracts
448
+ field values from model instances, applying serialization where needed.
449
+
450
+ Args:
451
+ exclude: Field names to exclude from value extraction
452
+
453
+ Returns:
454
+ Compiled function that takes an instance and returns field values
455
+ """
259
456
  env: dict[str, Any] = {}
260
457
  func = ["def get_field_values(self): return ["]
261
458
  for ci in cls.column_info().values():
@@ -270,6 +467,17 @@ class ModelBase:
270
467
  return env["get_field_values"]
271
468
 
272
469
  def field_values(self, *, exclude: FieldNamesSet = ()) -> list[Any]:
470
+ """Get field values for this instance, with serialization applied.
471
+
472
+ Args:
473
+ exclude: Field names to exclude from the result
474
+
475
+ Returns:
476
+ List of field values in the same order as field_names()
477
+
478
+ Note:
479
+ This method applies any configured serialize functions to the values.
480
+ """
273
481
  get_field_values = self._cached(
274
482
  ("get_field_values", tuple(sorted(exclude))),
275
483
  lambda: self._get_field_values_fn(exclude),
@@ -279,6 +487,15 @@ class ModelBase:
279
487
  def field_values_sql(
280
488
  self, *, exclude: FieldNamesSet = (), default_none: bool = False
281
489
  ) -> list[Fragment]:
490
+ """Generate SQL fragments for field values.
491
+
492
+ Args:
493
+ exclude: Field names to exclude
494
+ default_none: If True, None values become DEFAULT literals instead of NULL
495
+
496
+ Returns:
497
+ List of Fragment objects containing value placeholders or DEFAULT
498
+ """
282
499
  if default_none:
283
500
  return [
284
501
  sql.literal("DEFAULT") if value is None else sql.value(value)
@@ -289,6 +506,15 @@ class ModelBase:
289
506
 
290
507
  @classmethod
291
508
  def _get_from_mapping_fn(cls: type[T]) -> Callable[[Mapping[str, Any]], T]:
509
+ """Generate optimized function to create instances from mappings.
510
+
511
+ This method generates and compiles a function that efficiently creates
512
+ model instances from dictionary-like mappings, applying deserialization
513
+ where needed.
514
+
515
+ Returns:
516
+ Compiled function that takes a mapping and returns a model instance
517
+ """
292
518
  env: dict[str, Any] = {"cls": cls}
293
519
  func = ["def from_mapping(mapping):"]
294
520
  if not any(ci.deserialize for ci in cls.column_info().values()):
@@ -308,6 +534,21 @@ class ModelBase:
308
534
 
309
535
  @classmethod
310
536
  def from_mapping(cls: type[T], mapping: Mapping[str, Any], /) -> T:
537
+ """Create a model instance from a dictionary-like mapping.
538
+
539
+ This method applies any configured deserialize functions to the values
540
+ before creating the instance.
541
+
542
+ Args:
543
+ mapping: Dictionary-like object with field names as keys
544
+
545
+ Returns:
546
+ New instance of this model class
547
+
548
+ Example:
549
+ >>> row = {"id": UUID(...), "name": "Alice", "email": None}
550
+ >>> user = User.from_mapping(row)
551
+ """
311
552
  # KLUDGE nasty but... efficient?
312
553
  from_mapping_fn = cls._get_from_mapping_fn()
313
554
  cls.from_mapping = from_mapping_fn # type: ignore
@@ -317,6 +558,22 @@ class ModelBase:
317
558
  def from_prepended_mapping(
318
559
  cls: type[T], mapping: Mapping[str, Any], prepend: str
319
560
  ) -> T:
561
+ """Create a model instance from a mapping with prefixed keys.
562
+
563
+ Useful for creating instances from JOIN query results where columns
564
+ are prefixed to avoid name conflicts.
565
+
566
+ Args:
567
+ mapping: Dictionary with prefixed keys
568
+ prepend: Prefix to strip from keys
569
+
570
+ Returns:
571
+ New instance of this model class
572
+
573
+ Example:
574
+ >>> row = {"user_id": UUID(...), "user_name": "Alice", "user_email": None}
575
+ >>> user = User.from_prepended_mapping(row, "user_")
576
+ """
320
577
  filtered_dict: dict[str, Any] = {}
321
578
  for k, v in mapping.items():
322
579
  if k.startswith(prepend):
@@ -325,12 +582,29 @@ class ModelBase:
325
582
 
326
583
  @classmethod
327
584
  def ensure_model(cls: type[T], row: Union[T, Mapping[str, Any]]) -> T:
585
+ """Ensure the input is a model instance, converting from mapping if needed.
586
+
587
+ Args:
588
+ row: Either a model instance or a mapping to convert
589
+
590
+ Returns:
591
+ Model instance
592
+ """
328
593
  if isinstance(row, cls):
329
594
  return row
330
595
  return cls.from_mapping(row) # type: ignore
331
596
 
332
597
  @classmethod
333
598
  def create_table_sql(cls) -> Fragment:
599
+ """Generate CREATE TABLE SQL for this model.
600
+
601
+ Returns:
602
+ Fragment containing CREATE TABLE IF NOT EXISTS statement
603
+
604
+ Example:
605
+ >>> list(User.create_table_sql())
606
+ ['CREATE TABLE IF NOT EXISTS "users" ("id" UUID NOT NULL, "name" TEXT NOT NULL, "email" TEXT, PRIMARY KEY ("id"))']
607
+ """
334
608
  entries = [
335
609
  sql(
336
610
  "{} {}",
@@ -354,6 +628,20 @@ class ModelBase:
354
628
  order_by: Union[FieldNames, str] = (),
355
629
  for_update: bool = False,
356
630
  ) -> Fragment:
631
+ """Generate SELECT SQL for this model.
632
+
633
+ Args:
634
+ where: WHERE conditions as Fragment or iterable of Fragments
635
+ order_by: ORDER BY field names
636
+ for_update: Whether to add FOR UPDATE clause
637
+
638
+ Returns:
639
+ Fragment containing SELECT statement
640
+
641
+ Example:
642
+ >>> list(User.select_sql(where=sql("name = {}", "Alice")))
643
+ ['SELECT "id", "name", "email" FROM "users" WHERE name = $1', 'Alice']
644
+ """
357
645
  if isinstance(order_by, str):
358
646
  order_by = (order_by,)
359
647
  if not isinstance(where, Fragment):
@@ -383,6 +671,16 @@ class ModelBase:
383
671
  query: Fragment,
384
672
  prefetch: int = 1000,
385
673
  ) -> AsyncGenerator[T, None]:
674
+ """Create an async generator from a query result.
675
+
676
+ Args:
677
+ connection: Database connection
678
+ query: SQL query Fragment
679
+ prefetch: Number of rows to prefetch
680
+
681
+ Yields:
682
+ Model instances from the query results
683
+ """
386
684
  async for row in connection.cursor(*query, prefetch=prefetch):
387
685
  yield cls.from_mapping(row)
388
686
 
@@ -395,6 +693,22 @@ class ModelBase:
395
693
  where: Where = (),
396
694
  prefetch: int = 1000,
397
695
  ) -> AsyncGenerator[T, None]:
696
+ """Create an async generator for SELECT results.
697
+
698
+ Args:
699
+ connection: Database connection
700
+ order_by: ORDER BY field names
701
+ for_update: Whether to add FOR UPDATE clause
702
+ where: WHERE conditions
703
+ prefetch: Number of rows to prefetch
704
+
705
+ Yields:
706
+ Model instances from the SELECT results
707
+
708
+ Example:
709
+ >>> async for user in User.select_cursor(conn, where=sql("active = {}", True)):
710
+ ... print(user.name)
711
+ """
398
712
  return cls.cursor_from(
399
713
  connection,
400
714
  cls.select_sql(order_by=order_by, for_update=for_update, where=where),
@@ -407,6 +721,15 @@ class ModelBase:
407
721
  connection_or_pool: Union[Connection, Pool],
408
722
  query: Fragment,
409
723
  ) -> list[T]:
724
+ """Execute a query and return model instances.
725
+
726
+ Args:
727
+ connection_or_pool: Database connection or pool
728
+ query: SQL query Fragment
729
+
730
+ Returns:
731
+ List of model instances from the query results
732
+ """
410
733
  return [cls.from_mapping(row) for row in await connection_or_pool.fetch(*query)]
411
734
 
412
735
  @classmethod
@@ -417,6 +740,20 @@ class ModelBase:
417
740
  for_update: bool = False,
418
741
  where: Where = (),
419
742
  ) -> list[T]:
743
+ """Execute a SELECT query and return model instances.
744
+
745
+ Args:
746
+ connection_or_pool: Database connection or pool
747
+ order_by: ORDER BY field names
748
+ for_update: Whether to add FOR UPDATE clause
749
+ where: WHERE conditions
750
+
751
+ Returns:
752
+ List of model instances from the SELECT results
753
+
754
+ Example:
755
+ >>> users = await User.select(pool, where=sql("active = {}", True))
756
+ """
420
757
  return await cls.fetch_from(
421
758
  connection_or_pool,
422
759
  cls.select_sql(order_by=order_by, for_update=for_update, where=where),
@@ -424,6 +761,18 @@ class ModelBase:
424
761
 
425
762
  @classmethod
426
763
  def create_sql(cls: type[T], **kwargs: Any) -> Fragment:
764
+ """Generate INSERT SQL for creating a new record with RETURNING clause.
765
+
766
+ Args:
767
+ **kwargs: Field values for the new record
768
+
769
+ Returns:
770
+ Fragment containing INSERT ... RETURNING statement
771
+
772
+ Example:
773
+ >>> list(User.create_sql(name="Alice", email="alice@example.com"))
774
+ ['INSERT INTO "users" ("name", "email") VALUES ($1, $2) RETURNING "id", "name", "email"', 'Alice', 'alice@example.com']
775
+ """
427
776
  column_info = cls.column_info()
428
777
  return sql(
429
778
  "INSERT INTO {table} ({fields}) VALUES ({values}) RETURNING {out_fields}",
@@ -439,10 +788,35 @@ class ModelBase:
439
788
  async def create(
440
789
  cls: type[T], connection_or_pool: Union[Connection, Pool], **kwargs: Any
441
790
  ) -> T:
791
+ """Create a new record in the database.
792
+
793
+ Args:
794
+ connection_or_pool: Database connection or pool
795
+ **kwargs: Field values for the new record
796
+
797
+ Returns:
798
+ Model instance representing the created record
799
+
800
+ Example:
801
+ >>> user = await User.create(pool, name="Alice", email="alice@example.com")
802
+ """
442
803
  row = await connection_or_pool.fetchrow(*cls.create_sql(**kwargs))
443
804
  return cls.from_mapping(row)
444
805
 
445
806
  def insert_sql(self, exclude: FieldNamesSet = ()) -> Fragment:
807
+ """Generate INSERT SQL for this instance.
808
+
809
+ Args:
810
+ exclude: Field names to exclude from the INSERT
811
+
812
+ Returns:
813
+ Fragment containing INSERT statement
814
+
815
+ Example:
816
+ >>> user = User(name="Alice", email="alice@example.com")
817
+ >>> list(user.insert_sql())
818
+ ['INSERT INTO "users" ("name", "email") VALUES ($1, $2)', 'Alice', 'alice@example.com']
819
+ """
446
820
  cached = self._cached(
447
821
  ("insert_sql", tuple(sorted(exclude))),
448
822
  lambda: sql(
@@ -458,38 +832,136 @@ class ModelBase:
458
832
  async def insert(
459
833
  self, connection_or_pool: Union[Connection, Pool], exclude: FieldNamesSet = ()
460
834
  ) -> str:
835
+ """Insert this instance into the database.
836
+
837
+ Args:
838
+ connection_or_pool: Database connection or pool
839
+ exclude: Field names to exclude from the INSERT
840
+
841
+ Returns:
842
+ Result string from the database operation
843
+ """
461
844
  return await connection_or_pool.execute(*self.insert_sql(exclude))
462
845
 
463
846
  @classmethod
464
- def upsert_sql(cls, insert_sql: Fragment, exclude: FieldNamesSet = ()) -> Fragment:
465
- cached = cls._cached(
466
- ("upsert_sql", tuple(sorted(exclude))),
467
- lambda: sql(
468
- " ON CONFLICT ({pks}) DO UPDATE SET {assignments}",
847
+ def upsert_sql(
848
+ cls,
849
+ insert_sql: Fragment,
850
+ insert_only: FieldNamesSet = (),
851
+ force_update: FieldNamesSet = (),
852
+ ) -> Fragment:
853
+ """Generate UPSERT (INSERT ... ON CONFLICT DO UPDATE) SQL.
854
+
855
+ Args:
856
+ insert_sql: Base INSERT statement Fragment
857
+ insert_only: Field names to exclude from the UPDATE clause
858
+ force_update: Field names to force include in UPDATE clause, overriding insert_only settings
859
+
860
+ Returns:
861
+ Fragment containing INSERT ... ON CONFLICT DO UPDATE statement
862
+
863
+ Example:
864
+ >>> insert = user.insert_sql()
865
+ >>> list(User.upsert_sql(insert))
866
+ ['INSERT INTO "users" ("name", "email") VALUES ($1, $2) ON CONFLICT ("id") DO UPDATE SET "name"=EXCLUDED."name", "email"=EXCLUDED."email"', 'Alice', 'alice@example.com']
867
+
868
+ Note:
869
+ Fields marked with ColumnInfo(insert_only=True) are automatically
870
+ excluded from the UPDATE clause, unless overridden by force_update.
871
+ """
872
+ # Combine insert_only parameter with auto-detected insert_only fields, but remove force_update fields
873
+ auto_insert_only = cls.insert_only_field_names() - set(force_update)
874
+ manual_insert_only = set(insert_only) - set(
875
+ force_update
876
+ ) # Remove force_update from manual insert_only too
877
+ all_insert_only = manual_insert_only | auto_insert_only
878
+
879
+ def generate_upsert_fragment():
880
+ updatable_fields = cls.field_names(
881
+ exclude=(*cls.primary_key_names, *all_insert_only)
882
+ )
883
+ return sql(
884
+ " ON CONFLICT ({pks}) DO {action}",
469
885
  insert_sql=insert_sql,
470
886
  pks=sql.list(cls.primary_key_names_sql()),
471
- assignments=sql.list(
472
- sql("{field}=EXCLUDED.{field}", field=x)
473
- for x in cls.field_names_sql(
474
- exclude=(*cls.primary_key_names, *exclude)
887
+ action=(
888
+ sql(
889
+ "UPDATE SET {assignments}",
890
+ assignments=sql.list(
891
+ sql("{field}=EXCLUDED.{field}", field=sql.identifier(field))
892
+ for field in updatable_fields
893
+ ),
475
894
  )
895
+ if updatable_fields
896
+ else sql.literal("NOTHING")
476
897
  ),
477
- ).flatten(),
898
+ ).flatten()
899
+
900
+ cached = cls._cached(
901
+ ("upsert_sql", tuple(sorted(all_insert_only))),
902
+ generate_upsert_fragment,
478
903
  )
479
904
  return Fragment([insert_sql, cached])
480
905
 
481
906
  async def upsert(
482
- self, connection_or_pool: Union[Connection, Pool], exclude: FieldNamesSet = ()
907
+ self,
908
+ connection_or_pool: Union[Connection, Pool],
909
+ exclude: FieldNamesSet = (),
910
+ insert_only: FieldNamesSet = (),
911
+ force_update: FieldNamesSet = (),
483
912
  ) -> bool:
913
+ """Insert or update this instance in the database.
914
+
915
+ Args:
916
+ connection_or_pool: Database connection or pool
917
+ exclude: Field names to exclude from INSERT and UPDATE
918
+ insert_only: Field names that should only be set on INSERT, not UPDATE
919
+ force_update: Field names to force include in UPDATE clause, overriding insert_only settings
920
+
921
+ Returns:
922
+ True if the record was updated, False if it was inserted
923
+
924
+ Example:
925
+ >>> user = User(id=1, name="Alice", created_at=datetime.now())
926
+ >>> # Only set created_at on INSERT, not UPDATE
927
+ >>> was_updated = await user.upsert(pool, insert_only={'created_at'})
928
+ >>> # Force update created_at even if it's marked insert_only in ColumnInfo
929
+ >>> was_updated = await user.upsert(pool, force_update={'created_at'})
930
+
931
+ Note:
932
+ Fields marked with ColumnInfo(insert_only=True) are automatically
933
+ treated as insert-only and combined with the insert_only parameter,
934
+ unless overridden by force_update.
935
+ """
936
+ # upsert_sql automatically handles insert_only fields from ColumnInfo
937
+ # We only need to combine manual insert_only with exclude for the UPDATE clause
938
+ update_exclude = set(exclude) | set(insert_only)
484
939
  query = sql(
485
940
  "{} RETURNING xmax",
486
- self.upsert_sql(self.insert_sql(exclude=exclude), exclude=exclude),
941
+ self.upsert_sql(
942
+ self.insert_sql(exclude=exclude),
943
+ insert_only=update_exclude,
944
+ force_update=force_update,
945
+ ),
487
946
  )
488
947
  result = await connection_or_pool.fetchrow(*query)
489
948
  return result["xmax"] != 0
490
949
 
491
950
  @classmethod
492
951
  def delete_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
952
+ """Generate DELETE SQL for multiple records.
953
+
954
+ Args:
955
+ rows: Model instances to delete
956
+
957
+ Returns:
958
+ Fragment containing DELETE statement with UNNEST-based WHERE clause
959
+
960
+ Example:
961
+ >>> users = [user1, user2, user3]
962
+ >>> list(User.delete_multiple_sql(users))
963
+ ['DELETE FROM "users" WHERE ("id") IN (SELECT * FROM UNNEST($1::UUID[]))', (uuid1, uuid2, uuid3)]
964
+ """
493
965
  cached = cls._cached(
494
966
  ("delete_multiple_sql",),
495
967
  lambda: sql(
@@ -510,10 +982,34 @@ class ModelBase:
510
982
  async def delete_multiple(
511
983
  cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
512
984
  ) -> str:
985
+ """Delete multiple records from the database.
986
+
987
+ Args:
988
+ connection_or_pool: Database connection or pool
989
+ rows: Model instances to delete
990
+
991
+ Returns:
992
+ Result string from the database operation
993
+ """
513
994
  return await connection_or_pool.execute(*cls.delete_multiple_sql(rows))
514
995
 
515
996
  @classmethod
516
997
  def insert_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
998
+ """Generate bulk INSERT SQL using UNNEST.
999
+
1000
+ This is the most efficient method for bulk inserts in PostgreSQL.
1001
+
1002
+ Args:
1003
+ rows: Model instances to insert
1004
+
1005
+ Returns:
1006
+ Fragment containing INSERT ... SELECT FROM UNNEST statement
1007
+
1008
+ Example:
1009
+ >>> users = [User(name="Alice"), User(name="Bob")]
1010
+ >>> list(User.insert_multiple_sql(users))
1011
+ ['INSERT INTO "users" ("name", "email") SELECT * FROM UNNEST($1::TEXT[], $2::TEXT[])', ('Alice', 'Bob'), (None, None)]
1012
+ """
517
1013
  cached = cls._cached(
518
1014
  ("insert_multiple_sql",),
519
1015
  lambda: sql(
@@ -532,6 +1028,18 @@ class ModelBase:
532
1028
 
533
1029
  @classmethod
534
1030
  def insert_multiple_array_safe_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
1031
+ """Generate bulk INSERT SQL using VALUES syntax.
1032
+
1033
+ This method is required when your model contains array columns, because
1034
+ PostgreSQL doesn't support arrays-of-arrays (which UNNEST would require).
1035
+ Use this instead of the UNNEST method when you have array-typed fields.
1036
+
1037
+ Args:
1038
+ rows: Model instances to insert
1039
+
1040
+ Returns:
1041
+ Fragment containing INSERT ... VALUES statement
1042
+ """
535
1043
  return sql(
536
1044
  "INSERT INTO {table} ({fields}) VALUES {values}",
537
1045
  table=cls.table_name_sql(),
@@ -546,6 +1054,15 @@ class ModelBase:
546
1054
  def insert_multiple_executemany_chunk_sql(
547
1055
  cls: type[T], chunk_size: int
548
1056
  ) -> Fragment:
1057
+ """Generate INSERT SQL template for executemany with specific chunk size.
1058
+
1059
+ Args:
1060
+ chunk_size: Number of records per batch
1061
+
1062
+ Returns:
1063
+ Fragment containing INSERT statement with numbered placeholders
1064
+ """
1065
+
549
1066
  def generate() -> Fragment:
550
1067
  columns = len(cls.column_info())
551
1068
  values = ", ".join(
@@ -568,6 +1085,14 @@ class ModelBase:
568
1085
  async def insert_multiple_executemany(
569
1086
  cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
570
1087
  ) -> None:
1088
+ """Insert multiple records using asyncpg's executemany.
1089
+
1090
+ This is the most compatible but slowest bulk insert method.
1091
+
1092
+ Args:
1093
+ connection_or_pool: Database connection or pool
1094
+ rows: Model instances to insert
1095
+ """
571
1096
  args = [r.field_values() for r in rows]
572
1097
  query = cls.insert_multiple_executemany_chunk_sql(1).query()[0]
573
1098
  if args:
@@ -577,12 +1102,36 @@ class ModelBase:
577
1102
  async def insert_multiple_unnest(
578
1103
  cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
579
1104
  ) -> str:
1105
+ """Insert multiple records using PostgreSQL UNNEST.
1106
+
1107
+ This is the most efficient bulk insert method for PostgreSQL.
1108
+
1109
+ Args:
1110
+ connection_or_pool: Database connection or pool
1111
+ rows: Model instances to insert
1112
+
1113
+ Returns:
1114
+ Result string from the database operation
1115
+ """
580
1116
  return await connection_or_pool.execute(*cls.insert_multiple_sql(rows))
581
1117
 
582
1118
  @classmethod
583
1119
  async def insert_multiple_array_safe(
584
1120
  cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
585
1121
  ) -> str:
1122
+ """Insert multiple records using VALUES syntax with chunking.
1123
+
1124
+ This method is required when your model contains array columns, because
1125
+ PostgreSQL doesn't support arrays-of-arrays (which UNNEST would require).
1126
+ Data is processed in chunks to manage memory usage.
1127
+
1128
+ Args:
1129
+ connection_or_pool: Database connection or pool
1130
+ rows: Model instances to insert
1131
+
1132
+ Returns:
1133
+ Result string from the last chunk operation
1134
+ """
586
1135
  last = ""
587
1136
  for chunk in chunked(rows, 100):
588
1137
  last = await connection_or_pool.execute(
@@ -594,6 +1143,21 @@ class ModelBase:
594
1143
  async def insert_multiple(
595
1144
  cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
596
1145
  ) -> str:
1146
+ """Insert multiple records using the configured insert_multiple_mode.
1147
+
1148
+ Args:
1149
+ connection_or_pool: Database connection or pool
1150
+ rows: Model instances to insert
1151
+
1152
+ Returns:
1153
+ Result string from the database operation, or the literal 'INSERT' in 'executemany' mode
1154
+
1155
+ Note:
1156
+ The actual method used depends on the insert_multiple_mode setting:
1157
+ - 'unnest': Most efficient, uses UNNEST (default)
1158
+ - 'array_safe': Uses VALUES syntax; required when model has array columns
1159
+ - 'executemany': Uses asyncpg's executemany, slowest but most compatible
1160
+ """
597
1161
  if cls.insert_multiple_mode == "executemany":
598
1162
  await cls.insert_multiple_executemany(connection_or_pool, rows)
599
1163
  return "INSERT"
@@ -608,10 +1172,21 @@ class ModelBase:
608
1172
  connection_or_pool: Union[Connection, Pool],
609
1173
  rows: Iterable[T],
610
1174
  insert_only: FieldNamesSet = (),
1175
+ force_update: FieldNamesSet = (),
611
1176
  ) -> None:
1177
+ """Bulk upsert using asyncpg's executemany.
1178
+
1179
+ Args:
1180
+ connection_or_pool: Database connection or pool
1181
+ rows: Model instances to upsert
1182
+ insert_only: Field names that should only be set on INSERT, not UPDATE
1183
+ force_update: Field names to force include in UPDATE clause, overriding insert_only settings
1184
+ """
612
1185
  args = [r.field_values() for r in rows]
613
1186
  query = cls.upsert_sql(
614
- cls.insert_multiple_executemany_chunk_sql(1), exclude=insert_only
1187
+ cls.insert_multiple_executemany_chunk_sql(1),
1188
+ insert_only=insert_only,
1189
+ force_update=force_update,
615
1190
  ).query()[0]
616
1191
  if args:
617
1192
  await connection_or_pool.executemany(query, args)
@@ -622,9 +1197,25 @@ class ModelBase:
622
1197
  connection_or_pool: Union[Connection, Pool],
623
1198
  rows: Iterable[T],
624
1199
  insert_only: FieldNamesSet = (),
1200
+ force_update: FieldNamesSet = (),
625
1201
  ) -> str:
1202
+ """Bulk upsert using PostgreSQL UNNEST.
1203
+
1204
+ Args:
1205
+ connection_or_pool: Database connection or pool
1206
+ rows: Model instances to upsert
1207
+ insert_only: Field names that should only be set on INSERT, not UPDATE
1208
+ force_update: Field names to force include in UPDATE clause, overriding insert_only settings
1209
+
1210
+ Returns:
1211
+ Result string from the database operation
1212
+ """
626
1213
  return await connection_or_pool.execute(
627
- *cls.upsert_sql(cls.insert_multiple_sql(rows), exclude=insert_only)
1214
+ *cls.upsert_sql(
1215
+ cls.insert_multiple_sql(rows),
1216
+ insert_only=insert_only,
1217
+ force_update=force_update,
1218
+ )
628
1219
  )
629
1220
 
630
1221
  @classmethod
@@ -633,12 +1224,29 @@ class ModelBase:
633
1224
  connection_or_pool: Union[Connection, Pool],
634
1225
  rows: Iterable[T],
635
1226
  insert_only: FieldNamesSet = (),
1227
+ force_update: FieldNamesSet = (),
636
1228
  ) -> str:
1229
+ """Bulk upsert using VALUES syntax with chunking.
1230
+
1231
+ This method is required when your model contains array columns, because
1232
+ PostgreSQL doesn't support arrays-of-arrays (which UNNEST would require).
1233
+
1234
+ Args:
1235
+ connection_or_pool: Database connection or pool
1236
+ rows: Model instances to upsert
1237
+ insert_only: Field names that should only be set on INSERT, not UPDATE
1238
+ force_update: Field names to force include in UPDATE clause, overriding insert_only settings
1239
+
1240
+ Returns:
1241
+ Result string from the last chunk operation, or an empty string if rows is empty
1242
+ """
637
1243
  last = ""
638
1244
  for chunk in chunked(rows, 100):
639
1245
  last = await connection_or_pool.execute(
640
1246
  *cls.upsert_sql(
641
- cls.insert_multiple_array_safe_sql(chunk), exclude=insert_only
1247
+ cls.insert_multiple_array_safe_sql(chunk),
1248
+ insert_only=insert_only,
1249
+ force_update=force_update,
642
1250
  )
643
1251
  )
644
1252
  return last
@@ -649,25 +1257,66 @@ class ModelBase:
649
1257
  connection_or_pool: Union[Connection, Pool],
650
1258
  rows: Iterable[T],
651
1259
  insert_only: FieldNamesSet = (),
1260
+ force_update: FieldNamesSet = (),
652
1261
  ) -> str:
1262
+ """Bulk upsert (INSERT ... ON CONFLICT DO UPDATE) multiple records.
1263
+
1264
+ Args:
1265
+ connection_or_pool: Database connection or pool
1266
+ rows: Model instances to upsert
1267
+ insert_only: Field names that should only be set on INSERT, not UPDATE
1268
+ force_update: Field names to force include in UPDATE clause, overriding insert_only settings
1269
+
1270
+ Returns:
1271
+ Result string from the database operation, or the literal 'INSERT' in 'executemany' mode
1272
+
1273
+ Example:
1274
+ >>> await User.upsert_multiple(pool, users, insert_only={'created_at'})
1275
+ >>> await User.upsert_multiple(pool, users, force_update={'created_at'})
1276
+
1277
+ Note:
1278
+ Fields marked with ColumnInfo(insert_only=True) are automatically
1279
+ treated as insert-only and combined with the insert_only parameter,
1280
+ unless overridden by force_update.
1281
+ """
1282
+ # upsert_sql automatically handles insert_only fields from ColumnInfo
1283
+ # Pass manual insert_only and force_update parameters through to the specific implementations
1284
+
653
1285
  if cls.insert_multiple_mode == "executemany":
654
1286
  await cls.upsert_multiple_executemany(
655
- connection_or_pool, rows, insert_only=insert_only
1287
+ connection_or_pool,
1288
+ rows,
1289
+ insert_only=insert_only,
1290
+ force_update=force_update,
656
1291
  )
657
1292
  return "INSERT"
658
1293
  elif cls.insert_multiple_mode == "array_safe":
659
1294
  return await cls.upsert_multiple_array_safe(
660
- connection_or_pool, rows, insert_only=insert_only
1295
+ connection_or_pool,
1296
+ rows,
1297
+ insert_only=insert_only,
1298
+ force_update=force_update,
661
1299
  )
662
1300
  else:
663
1301
  return await cls.upsert_multiple_unnest(
664
- connection_or_pool, rows, insert_only=insert_only
1302
+ connection_or_pool,
1303
+ rows,
1304
+ insert_only=insert_only,
1305
+ force_update=force_update,
665
1306
  )
666
1307
 
667
1308
  @classmethod
668
1309
  def _get_equal_ignoring_fn(
669
1310
  cls: type[T], ignore: FieldNamesSet = ()
670
1311
  ) -> Callable[[T, T], bool]:
1312
+ """Generate optimized function to compare instances ignoring certain fields.
1313
+
1314
+ Args:
1315
+ ignore: Field names to ignore during comparison
1316
+
1317
+ Returns:
1318
+ Compiled function that compares two instances, returning True if equal
1319
+ """
671
1320
  env: dict[str, Any] = {}
672
1321
  func = ["def equal_ignoring(a, b):"]
673
1322
  for ci in cls.column_info().values():
@@ -686,8 +1335,38 @@ class ModelBase:
686
1335
  where: Where,
687
1336
  ignore: FieldNamesSet = (),
688
1337
  insert_only: FieldNamesSet = (),
1338
+ force_update: FieldNamesSet = (),
689
1339
  ) -> "ReplaceMultiplePlan[T]":
690
- ignore = sorted(set(ignore) | set(insert_only))
1340
+ """Plan a replace operation by comparing new data with existing records.
1341
+
1342
+ This method analyzes the differences between the provided rows and existing
1343
+ database records, determining which records need to be created, updated, or deleted.
1344
+
1345
+ Args:
1346
+ connection: Database connection (must support FOR UPDATE)
1347
+ rows: New data as model instances or mappings
1348
+ where: WHERE clause to limit which existing records to consider
1349
+ ignore: Field names to ignore when comparing records
1350
+ insert_only: Field names that should only be set on INSERT, not UPDATE
1351
+ force_update: Field names to force include in UPDATE clause, overriding insert_only settings
1352
+
1353
+ Returns:
1354
+ ReplaceMultiplePlan containing the planned operations
1355
+
1356
+ Example:
1357
+ >>> plan = await User.plan_replace_multiple(
1358
+ ... conn, new_users, where=sql("department_id = {}", dept_id)
1359
+ ... )
1360
+ >>> print(f"Will create {len(plan.created)}, update {len(plan.updated)}, delete {len(plan.deleted)}")
1361
+
1362
+ Note:
1363
+ Fields marked with ColumnInfo(insert_only=True) are automatically
1364
+ treated as insert-only and combined with the insert_only parameter,
1365
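+ unless overridden by force_update. The override applies to the UPDATE clause only: when comparing records, all insert-only fields are ignored regardless of force_update.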
1366
+ """
1367
+ # For comparison purposes, combine auto-detected insert_only fields with manual ones
1368
+ all_insert_only = cls.insert_only_field_names() | set(insert_only)
1369
+ ignore = sorted(set(ignore) | all_insert_only)
691
1370
  equal_ignoring = cls._cached(
692
1371
  ("equal_ignoring", tuple(ignore)),
693
1372
  lambda: cls._get_equal_ignoring_fn(ignore),
@@ -710,7 +1389,11 @@ class ModelBase:
710
1389
 
711
1390
  created = list(pending.values())
712
1391
 
713
- return ReplaceMultiplePlan(cls, insert_only, created, updated, deleted)
1392
+ # Pass only manual insert_only and force_update to the plan
1393
+ # since upsert_multiple handles auto-detected ones
1394
+ return ReplaceMultiplePlan(
1395
+ cls, insert_only, force_update, created, updated, deleted
1396
+ )
714
1397
 
715
1398
  @classmethod
716
1399
  async def replace_multiple(
@@ -721,9 +1404,42 @@ class ModelBase:
721
1404
  where: Where,
722
1405
  ignore: FieldNamesSet = (),
723
1406
  insert_only: FieldNamesSet = (),
1407
+ force_update: FieldNamesSet = (),
724
1408
  ) -> tuple[list[T], list[T], list[T]]:
1409
+ """Replace records in the database with the provided data.
1410
+
1411
+ This is a complete replace operation: records matching the WHERE clause
1412
+ that aren't in the new data will be deleted, new records will be inserted,
1413
+ and changed records will be updated.
1414
+
1415
+ Args:
1416
+ connection: Database connection (must support FOR UPDATE)
1417
+ rows: New data as model instances or mappings
1418
+ where: WHERE clause to limit which existing records to consider for replacement
1419
+ ignore: Field names to ignore when comparing records
1420
+ insert_only: Field names that should only be set on INSERT, not UPDATE
1421
+ force_update: Field names to force include in UPDATE clause, overriding insert_only settings
1422
+
1423
+ Returns:
1424
+ Tuple of (created_records, updated_records, deleted_records)
1425
+
1426
+ Example:
1427
+ >>> created, updated, deleted = await User.replace_multiple(
1428
+ ... conn, new_users, where=sql("department_id = {}", dept_id)
1429
+ ... )
1430
+
1431
+ Note:
1432
+ Fields marked with ColumnInfo(insert_only=True) are automatically
1433
+ treated as insert-only and combined with the insert_only parameter,
1434
+ unless overridden by force_update.
1435
+ """
725
1436
  plan = await cls.plan_replace_multiple(
726
- connection, rows, where=where, ignore=ignore, insert_only=insert_only
1437
+ connection,
1438
+ rows,
1439
+ where=where,
1440
+ ignore=ignore,
1441
+ insert_only=insert_only,
1442
+ force_update=force_update,
727
1443
  )
728
1444
  await plan.execute(connection)
729
1445
  return plan.cud
@@ -732,6 +1448,14 @@ class ModelBase:
732
1448
  def _get_differences_ignoring_fn(
733
1449
  cls: type[T], ignore: FieldNamesSet = ()
734
1450
  ) -> Callable[[T, T], list[str]]:
1451
+ """Generate optimized function to find field differences between instances.
1452
+
1453
+ Args:
1454
+ ignore: Field names to ignore during comparison
1455
+
1456
+ Returns:
1457
+ Compiled function that returns list of field names that differ
1458
+ """
735
1459
  env: dict[str, Any] = {}
736
1460
  func = [
737
1461
  "def differences_ignoring(a, b):",
@@ -755,8 +1479,40 @@ class ModelBase:
755
1479
  where: Where,
756
1480
  ignore: FieldNamesSet = (),
757
1481
  insert_only: FieldNamesSet = (),
1482
+ force_update: FieldNamesSet = (),
758
1483
  ) -> tuple[list[T], list[tuple[T, T, list[str]]], list[T]]:
759
- ignore = sorted(set(ignore) | set(insert_only))
1484
+ """Replace records and report the specific field differences for updates.
1485
+
1486
+ Like replace_multiple, but provides detailed information about which
1487
+ fields changed for each updated record.
1488
+
1489
+ Args:
1490
+ connection: Database connection (must support FOR UPDATE)
1491
+ rows: New data as model instances or mappings
1492
+ where: WHERE clause to limit which existing records to consider
1493
+ ignore: Field names to ignore when comparing records
1494
+ insert_only: Field names that should only be set on INSERT, not UPDATE
1495
+ force_update: Field names to force include in UPDATE clause, overriding insert_only settings
1496
+
1497
+ Returns:
1498
+ Tuple of (created_records, update_triples, deleted_records)
1499
+ where update_triples contains (old_record, new_record, changed_field_names)
1500
+
1501
+ Example:
1502
+ >>> created, updates, deleted = await User.replace_multiple_reporting_differences(
1503
+ ... conn, new_users, where=sql("department_id = {}", dept_id)
1504
+ ... )
1505
+ >>> for old, new, fields in updates:
1506
+ ... print(f"Updated {old.name}: changed {', '.join(fields)}")
1507
+
1508
+ Note:
1509
+ Fields marked with ColumnInfo(insert_only=True) are automatically
1510
+ treated as insert-only and combined with the insert_only parameter,
1511
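+ unless overridden by force_update. The override applies to the UPDATE clause only: when comparing records, all insert-only fields are ignored regardless of force_update.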
1512
+ """
1513
+ # For comparison purposes, combine auto-detected insert_only fields with manual ones
1514
+ all_insert_only = cls.insert_only_field_names() | set(insert_only)
1515
+ ignore = sorted(set(ignore) | all_insert_only)
760
1516
  differences_ignoring = cls._cached(
761
1517
  ("differences_ignoring", tuple(ignore)),
762
1518
  lambda: cls._get_differences_ignoring_fn(ignore),
@@ -786,6 +1542,7 @@ class ModelBase:
786
1542
  connection,
787
1543
  (*created, *(t[1] for t in updated_triples)),
788
1544
  insert_only=insert_only,
1545
+ force_update=force_update,
789
1546
  )
790
1547
  if deleted:
791
1548
  await cls.delete_multiple(connection, deleted)
@@ -797,30 +1554,67 @@ class ModelBase:
797
1554
  class ReplaceMultiplePlan(Generic[T]):
798
1555
  model_class: type[T]
799
1556
  insert_only: FieldNamesSet
1557
+ force_update: FieldNamesSet
800
1558
  created: list[T]
801
1559
  updated: list[T]
802
1560
  deleted: list[T]
803
1561
 
804
1562
  @property
805
1563
  def cud(self) -> tuple[list[T], list[T], list[T]]:
1564
+ """Get the create, update, delete lists as a tuple.
1565
+
1566
+ Returns:
1567
+ Tuple of (created, updated, deleted) record lists
1568
+ """
806
1569
  return (self.created, self.updated, self.deleted)
807
1570
 
808
1571
  async def execute_upserts(self, connection: Connection) -> None:
1572
+ """Execute the upsert operations (creates and updates).
1573
+
1574
+ Args:
1575
+ connection: Database connection
1576
+ """
809
1577
  if self.created or self.updated:
810
1578
  await self.model_class.upsert_multiple(
811
- connection, (*self.created, *self.updated), insert_only=self.insert_only
1579
+ connection,
1580
+ (*self.created, *self.updated),
1581
+ insert_only=self.insert_only,
1582
+ force_update=self.force_update,
812
1583
  )
813
1584
 
814
1585
  async def execute_deletes(self, connection: Connection) -> None:
1586
+ """Execute the delete operations.
1587
+
1588
+ Args:
1589
+ connection: Database connection
1590
+ """
815
1591
  if self.deleted:
816
1592
  await self.model_class.delete_multiple(connection, self.deleted)
817
1593
 
818
1594
  async def execute(self, connection: Connection) -> None:
1595
+ """Execute all planned operations (upserts then deletes).
1596
+
1597
+ Args:
1598
+ connection: Database connection
1599
+ """
819
1600
  await self.execute_upserts(connection)
820
1601
  await self.execute_deletes(connection)
821
1602
 
822
1603
 
823
1604
  def chunked(lst, n):
1605
+ """Split an iterable into chunks of size n.
1606
+
1607
+ Args:
1608
+ lst: Iterable to chunk
1609
+ n: Chunk size
1610
+
1611
+ Yields:
1612
+ Lists of up to n items from the input
1613
+
1614
+ Example:
1615
+ >>> list(chunked([1, 2, 3, 4, 5], 2))
1616
+ [[1, 2], [3, 4], [5]]
1617
+ """
824
1618
  if type(lst) is not list:
825
1619
  lst = list(lst)
826
1620
  for i in range(0, len(lst), n):