sql-athame 0.4.0a11__py3-none-any.whl → 0.4.0a13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sql_athame/dataclasses.py CHANGED
@@ -33,6 +33,35 @@ Pool: TypeAlias = Any
33
33
 
34
34
  @dataclass
35
35
  class ColumnInfo:
36
+ """Column metadata for dataclass fields.
37
+
38
+ This class specifies SQL column properties that can be applied to dataclass fields
39
+ to control how they are mapped to database columns.
40
+
41
+ Attributes:
42
+ type: SQL type name for query parameters (e.g., 'TEXT', 'INTEGER')
43
+ create_type: SQL type for CREATE TABLE statements (defaults to type if not specified)
44
+ nullable: Whether the column allows NULL values (inferred from Optional types if not specified)
45
+ constraints: Additional SQL constraints (e.g., 'UNIQUE', 'CHECK (value > 0)')
46
+ serialize: Function to transform Python values before database storage
47
+ deserialize: Function to transform database values back to Python objects
48
+ insert_only: Whether this field should only be set on INSERT and excluded from the UPDATE clause in upsert operations
49
+
50
+ Example:
51
+ >>> from dataclasses import dataclass
52
+ >>> from typing import Annotated
53
+ >>> from sql_athame import ModelBase, ColumnInfo
54
+ >>> import json
+ >>> from datetime import datetime
55
+ >>>
56
+ >>> @dataclass
57
+ ... class Product(ModelBase, table_name="products", primary_key="id"):
58
+ ... id: int
59
+ ... name: str
60
+ ... price: Annotated[float, ColumnInfo(constraints="CHECK (price > 0)")]
61
+ ... tags: Annotated[list, ColumnInfo(type="JSONB", serialize=json.dumps, deserialize=json.loads)]
62
+ ... created_at: Annotated[datetime, ColumnInfo(insert_only=True)]
63
+ """
64
+
36
65
  type: Optional[str] = None
37
66
  create_type: Optional[str] = None
38
67
  nullable: Optional[bool] = None
@@ -42,6 +71,7 @@ class ColumnInfo:
42
71
 
43
72
  serialize: Optional[Callable[[Any], Any]] = None
44
73
  deserialize: Optional[Callable[[Any], Any]] = None
74
+ insert_only: Optional[bool] = None
45
75
 
46
76
  def __post_init__(self, constraints: Union[str, Iterable[str], None]) -> None:
47
77
  if constraints is not None:
@@ -51,6 +81,15 @@ class ColumnInfo:
51
81
 
52
82
  @staticmethod
53
83
  def merge(a: "ColumnInfo", b: "ColumnInfo") -> "ColumnInfo":
84
+ """Merge two ColumnInfo instances, with b taking precedence over a.
85
+
86
+ Args:
87
+ a: Base ColumnInfo
88
+ b: ColumnInfo to overlay on top of a
89
+
90
+ Returns:
91
+ New ColumnInfo with b's non-None values overriding a's values
92
+ """
54
93
  return ColumnInfo(
55
94
  type=b.type if b.type is not None else a.type,
56
95
  create_type=b.create_type if b.create_type is not None else a.create_type,
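from_column_info below folds any number of ColumnInfo annotations on a single field through this merge. A minimal sketch of the precedence rules, with purely illustrative values:

    from sql_athame import ColumnInfo

    base = ColumnInfo(type="TEXT", constraints="UNIQUE")
    override = ColumnInfo(nullable=False, insert_only=True)
    merged = ColumnInfo.merge(base, override)
    # Non-None values from the second argument win: merged.type == "TEXT",
    # merged.nullable is False, merged.insert_only is True.  Constraints from
    # both sides are concatenated rather than replaced.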
@@ -58,11 +97,29 @@ class ColumnInfo:
58
97
  _constraints=(*a._constraints, *b._constraints),
59
98
  serialize=b.serialize if b.serialize is not None else a.serialize,
60
99
  deserialize=b.deserialize if b.deserialize is not None else a.deserialize,
100
+ insert_only=b.insert_only if b.insert_only is not None else a.insert_only,
61
101
  )
62
102
 
63
103
 
64
104
  @dataclass
65
105
  class ConcreteColumnInfo:
106
+ """Resolved column information for a specific dataclass field.
107
+
108
+ This is the final, computed column metadata after resolving type hints,
109
+ merging ColumnInfo instances, and applying defaults.
110
+
111
+ Attributes:
112
+ field: The dataclass Field object
113
+ type_hint: The resolved Python type hint
114
+ type: SQL type for query parameters
115
+ create_type: SQL type for CREATE TABLE statements
116
+ nullable: Whether the column allows NULL values
117
+ constraints: Tuple of SQL constraint strings
118
+ serialize: Optional serialization function
119
+ deserialize: Optional deserialization function
120
+ insert_only: Whether this field should only be set on INSERT, not UPDATE
121
+ """
122
+
66
123
  field: Field
67
124
  type_hint: type
68
125
  type: str
@@ -71,11 +128,25 @@ class ConcreteColumnInfo:
71
128
  constraints: tuple[str, ...]
72
129
  serialize: Optional[Callable[[Any], Any]] = None
73
130
  deserialize: Optional[Callable[[Any], Any]] = None
131
+ insert_only: bool = False
74
132
 
75
133
  @staticmethod
76
134
  def from_column_info(
77
135
  field: Field, type_hint: Any, *args: ColumnInfo
78
136
  ) -> "ConcreteColumnInfo":
137
+ """Create ConcreteColumnInfo from a field and its ColumnInfo metadata.
138
+
139
+ Args:
140
+ field: The dataclass Field
141
+ type_hint: The resolved type hint for the field
142
+ *args: ColumnInfo instances to merge (later ones take precedence)
143
+
144
+ Returns:
145
+ ConcreteColumnInfo with all metadata resolved
146
+
147
+ Raises:
148
+ ValueError: If no SQL type can be determined for the field
149
+ """
79
150
  info = functools.reduce(ColumnInfo.merge, args, ColumnInfo())
80
151
  if info.create_type is None and info.type is not None:
81
152
  info.create_type = info.type
@@ -91,9 +162,15 @@ class ConcreteColumnInfo:
91
162
  constraints=info._constraints,
92
163
  serialize=info.serialize,
93
164
  deserialize=info.deserialize,
165
+ insert_only=bool(info.insert_only),
94
166
  )
95
167
 
96
168
  def create_table_string(self) -> str:
169
+ """Generate the SQL column definition for CREATE TABLE statements.
170
+
171
+ Returns:
172
+ SQL string like "TEXT NOT NULL CHECK (length > 0)"
173
+ """
97
174
  parts = (
98
175
  self.create_type,
99
176
  *(() if self.nullable else ("NOT NULL",)),
@@ -102,6 +179,14 @@ class ConcreteColumnInfo:
102
179
  return " ".join(parts)
103
180
 
104
181
  def maybe_serialize(self, value: Any) -> Any:
182
+ """Apply serialization function if configured, otherwise return value unchanged.
183
+
184
+ Args:
185
+ value: The Python value to potentially serialize
186
+
187
+ Returns:
188
+ Serialized value if serialize function is configured, otherwise original value
189
+ """
105
190
  if self.serialize:
106
191
  return self.serialize(value)
107
192
  return value
@@ -179,6 +264,15 @@ class ModelBase:
179
264
 
180
265
  @classmethod
181
266
  def _cached(cls, key: tuple, thunk: Callable[[], U]) -> U:
267
+ """Cache computation results by key.
268
+
269
+ Args:
270
+ key: Cache key tuple
271
+ thunk: Function to compute the value if not cached
272
+
273
+ Returns:
274
+ Cached or computed value
275
+ """
182
276
  try:
183
277
  return cls._cache[key]
184
278
  except KeyError:
@@ -187,6 +281,18 @@ class ModelBase:
187
281
 
188
282
  @classmethod
189
283
  def column_info_for_field(cls, field: Field, type_hint: type) -> ConcreteColumnInfo:
284
+ """Generate ConcreteColumnInfo for a dataclass field.
285
+
286
+ Analyzes the field's type hint and metadata to determine SQL column properties.
287
+ Looks for ColumnInfo in the field's metadata and merges it with type-based defaults.
288
+
289
+ Args:
290
+ field: The dataclass Field object
291
+ type_hint: The resolved type hint for the field
292
+
293
+ Returns:
294
+ ConcreteColumnInfo with all column metadata resolved
295
+ """
190
296
  base_type = type_hint
191
297
  metadata = []
192
298
  if get_origin(type_hint) is Annotated:
@@ -202,6 +308,14 @@ class ModelBase:
202
308
 
203
309
  @classmethod
204
310
  def column_info(cls) -> dict[str, ConcreteColumnInfo]:
311
+ """Get column information for all fields in this model.
312
+
313
+ Returns a cached mapping of field names to their resolved column information.
314
+ This is computed once per class and cached for performance.
315
+
316
+ Returns:
317
+ Dictionary mapping field names to ConcreteColumnInfo objects
318
+ """
205
319
  try:
206
320
  return cls._column_info
207
321
  except AttributeError:
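The resolved metadata is easy to inspect. A small sketch reusing the Product model from the ColumnInfo docstring above; the printed SQL type depends on the library's default mapping for float, so treat it as illustrative:

    info = Product.column_info()

    print(info["price"].create_table_string())
    # something like 'DOUBLE PRECISION NOT NULL CHECK (price > 0)'
    print(info["created_at"].insert_only)
    # True, because of ColumnInfo(insert_only=True) on that field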
@@ -214,35 +328,131 @@ class ModelBase:
214
328
 
215
329
  @classmethod
216
330
  def table_name_sql(cls, *, prefix: Optional[str] = None) -> Fragment:
331
+ """Generate SQL fragment for the table name.
332
+
333
+ Args:
334
+ prefix: Optional schema or alias prefix
335
+
336
+ Returns:
337
+ Fragment containing the properly quoted table identifier
338
+
339
+ Example:
340
+ >>> list(User.table_name_sql())
341
+ ['"users"']
342
+ >>> list(User.table_name_sql(prefix="public"))
343
+ ['"public"."users"']
344
+ """
217
345
  return sql.identifier(cls.table_name, prefix=prefix)
218
346
 
219
347
  @classmethod
220
348
  def primary_key_names_sql(cls, *, prefix: Optional[str] = None) -> list[Fragment]:
349
+ """Generate SQL fragments for primary key column names.
350
+
351
+ Args:
352
+ prefix: Optional table alias prefix
353
+
354
+ Returns:
355
+ List of Fragment objects for each primary key column
356
+ """
221
357
  return [sql.identifier(pk, prefix=prefix) for pk in cls.primary_key_names]
222
358
 
223
359
  @classmethod
224
360
  def field_names(cls, *, exclude: FieldNamesSet = ()) -> list[str]:
361
+ """Get list of field names for this model.
362
+
363
+ Args:
364
+ exclude: Field names to exclude from the result
365
+
366
+ Returns:
367
+ List of field names as strings
368
+ """
225
369
  return [
226
370
  ci.field.name
227
371
  for ci in cls.column_info().values()
228
372
  if ci.field.name not in exclude
229
373
  ]
230
374
 
375
+ @classmethod
376
+ def insert_only_field_names(cls) -> set[str]:
377
+ """Get set of field names marked as insert_only in ColumnInfo.
378
+
379
+ Returns:
380
+ Set of field names that should only be set on INSERT, not UPDATE
381
+ """
382
+ return cls._cached(
383
+ ("insert_only_field_names",),
384
+ lambda: {
385
+ ci.field.name for ci in cls.column_info().values() if ci.insert_only
386
+ },
387
+ )
388
+
231
389
  @classmethod
232
390
  def field_names_sql(
233
- cls, *, prefix: Optional[str] = None, exclude: FieldNamesSet = ()
391
+ cls,
392
+ *,
393
+ prefix: Optional[str] = None,
394
+ exclude: FieldNamesSet = (),
395
+ as_prepended: Optional[str] = None,
234
396
  ) -> list[Fragment]:
397
+ """Generate SQL fragments for field names.
398
+
399
+ Args:
400
+ prefix: Optional table alias prefix for column names
401
+ exclude: Field names to exclude from the result
402
+ as_prepended: If provided, generate "column AS prepended_column" aliases
403
+
404
+ Returns:
405
+ List of Fragment objects for each field
406
+
407
+ Example:
408
+ >>> list(sql.list(User.field_names_sql()))
409
+ ['"id", "name", "email"']
410
+ >>> list(sql.list(User.field_names_sql(prefix="u")))
411
+ ['"u"."id", "u"."name", "u"."email"']
412
+ >>> list(sql.list(User.field_names_sql(as_prepended="user_")))
413
+ ['"id" AS "user_id", "name" AS "user_name", "email" AS "user_email"']
414
+ """
415
+ if as_prepended:
416
+ return [
417
+ sql(
418
+ "{} AS {}",
419
+ sql.identifier(f, prefix=prefix),
420
+ sql.identifier(f"{as_prepended}{f}"),
421
+ )
422
+ for f in cls.field_names(exclude=exclude)
423
+ ]
235
424
  return [
236
425
  sql.identifier(f, prefix=prefix) for f in cls.field_names(exclude=exclude)
237
426
  ]
238
427
 
239
428
  def primary_key(self) -> tuple:
429
+ """Get the primary key value(s) for this instance.
430
+
431
+ Returns:
432
+ Tuple containing the primary key field values
433
+
434
+ Example:
435
+ >>> user = User(id=UUID(...), name="Alice")
436
+ >>> user.primary_key()
437
+ (UUID('...'),)
438
+ """
240
439
  return tuple(getattr(self, pk) for pk in self.primary_key_names)
241
440
 
242
441
  @classmethod
243
442
  def _get_field_values_fn(
244
443
  cls: type[T], exclude: FieldNamesSet = ()
245
444
  ) -> Callable[[T], list[Any]]:
445
+ """Generate optimized function to extract field values from instances.
446
+
447
+ This method generates and compiles a function that efficiently extracts
448
+ field values from model instances, applying serialization where needed.
449
+
450
+ Args:
451
+ exclude: Field names to exclude from value extraction
452
+
453
+ Returns:
454
+ Compiled function that takes an instance and returns field values
455
+ """
246
456
  env: dict[str, Any] = {}
247
457
  func = ["def get_field_values(self): return ["]
248
458
  for ci in cls.column_info().values():
@@ -257,6 +467,17 @@ class ModelBase:
257
467
  return env["get_field_values"]
258
468
 
259
469
  def field_values(self, *, exclude: FieldNamesSet = ()) -> list[Any]:
470
+ """Get field values for this instance, with serialization applied.
471
+
472
+ Args:
473
+ exclude: Field names to exclude from the result
474
+
475
+ Returns:
476
+ List of field values in the same order as field_names()
477
+
478
+ Note:
479
+ This method applies any configured serialize functions to the values.
480
+ """
260
481
  get_field_values = self._cached(
261
482
  ("get_field_values", tuple(sorted(exclude))),
262
483
  lambda: self._get_field_values_fn(exclude),
@@ -266,6 +487,15 @@ class ModelBase:
266
487
  def field_values_sql(
267
488
  self, *, exclude: FieldNamesSet = (), default_none: bool = False
268
489
  ) -> list[Fragment]:
490
+ """Generate SQL fragments for field values.
491
+
492
+ Args:
493
+ exclude: Field names to exclude
494
+ default_none: If True, None values become DEFAULT literals instead of NULL
495
+
496
+ Returns:
497
+ List of Fragment objects containing value placeholders or DEFAULT
498
+ """
269
499
  if default_none:
270
500
  return [
271
501
  sql.literal("DEFAULT") if value is None else sql.value(value)
@@ -276,6 +506,15 @@ class ModelBase:
276
506
 
277
507
  @classmethod
278
508
  def _get_from_mapping_fn(cls: type[T]) -> Callable[[Mapping[str, Any]], T]:
509
+ """Generate optimized function to create instances from mappings.
510
+
511
+ This method generates and compiles a function that efficiently creates
512
+ model instances from dictionary-like mappings, applying deserialization
513
+ where needed.
514
+
515
+ Returns:
516
+ Compiled function that takes a mapping and returns a model instance
517
+ """
279
518
  env: dict[str, Any] = {"cls": cls}
280
519
  func = ["def from_mapping(mapping):"]
281
520
  if not any(ci.deserialize for ci in cls.column_info().values()):
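The serialize/deserialize hooks show up at exactly these two points: field_values() applies serialize on the way out, and from_mapping() applies deserialize on the way in. A rough sketch, again reusing the Product model from the ColumnInfo docstring; the row dict is made up to look like a fetched record:

    from datetime import datetime

    product = Product(id=1, name="Widget", price=9.99,
                      tags=["new", "sale"], created_at=datetime.now())
    values = product.field_values()
    # tags went through json.dumps, so values holds the string '["new", "sale"]'

    row = {"id": 1, "name": "Widget", "price": 9.99,
           "tags": '["new", "sale"]', "created_at": datetime.now()}
    restored = Product.from_mapping(row)
    # tags came back through json.loads, so restored.tags == ["new", "sale"]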
@@ -295,19 +534,77 @@ class ModelBase:
295
534
 
296
535
  @classmethod
297
536
  def from_mapping(cls: type[T], mapping: Mapping[str, Any], /) -> T:
537
+ """Create a model instance from a dictionary-like mapping.
538
+
539
+ This method applies any configured deserialize functions to the values
540
+ before creating the instance.
541
+
542
+ Args:
543
+ mapping: Dictionary-like object with field names as keys
544
+
545
+ Returns:
546
+ New instance of this model class
547
+
548
+ Example:
549
+ >>> row = {"id": UUID(...), "name": "Alice", "email": None}
550
+ >>> user = User.from_mapping(row)
551
+ """
298
552
  # KLUDGE nasty but... efficient?
299
553
  from_mapping_fn = cls._get_from_mapping_fn()
300
554
  cls.from_mapping = from_mapping_fn # type: ignore
301
555
  return from_mapping_fn(mapping)
302
556
 
557
+ @classmethod
558
+ def from_prepended_mapping(
559
+ cls: type[T], mapping: Mapping[str, Any], prepend: str
560
+ ) -> T:
561
+ """Create a model instance from a mapping with prefixed keys.
562
+
563
+ Useful for creating instances from JOIN query results where columns
564
+ are prefixed to avoid name conflicts.
565
+
566
+ Args:
567
+ mapping: Dictionary with prefixed keys
568
+ prepend: Prefix to strip from keys
569
+
570
+ Returns:
571
+ New instance of this model class
572
+
573
+ Example:
574
+ >>> row = {"user_id": UUID(...), "user_name": "Alice", "user_email": None}
575
+ >>> user = User.from_prepended_mapping(row, "user_")
576
+ """
577
+ filtered_dict: dict[str, Any] = {}
578
+ for k, v in mapping.items():
579
+ if k.startswith(prepend):
580
+ filtered_dict[k[len(prepend) :]] = v
581
+ return cls.from_mapping(filtered_dict)
582
+
303
583
  @classmethod
304
584
  def ensure_model(cls: type[T], row: Union[T, Mapping[str, Any]]) -> T:
585
+ """Ensure the input is a model instance, converting from mapping if needed.
586
+
587
+ Args:
588
+ row: Either a model instance or a mapping to convert
589
+
590
+ Returns:
591
+ Model instance
592
+ """
305
593
  if isinstance(row, cls):
306
594
  return row
307
595
  return cls.from_mapping(row) # type: ignore
308
596
 
309
597
  @classmethod
310
598
  def create_table_sql(cls) -> Fragment:
599
+ """Generate CREATE TABLE SQL for this model.
600
+
601
+ Returns:
602
+ Fragment containing CREATE TABLE IF NOT EXISTS statement
603
+
604
+ Example:
605
+ >>> list(User.create_table_sql())
606
+ ['CREATE TABLE IF NOT EXISTS "users" ("id" UUID NOT NULL, "name" TEXT NOT NULL, "email" TEXT, PRIMARY KEY ("id"))']
607
+ """
311
608
  entries = [
312
609
  sql(
313
610
  "{} {}",
@@ -331,6 +628,20 @@ class ModelBase:
331
628
  order_by: Union[FieldNames, str] = (),
332
629
  for_update: bool = False,
333
630
  ) -> Fragment:
631
+ """Generate SELECT SQL for this model.
632
+
633
+ Args:
634
+ where: WHERE conditions as Fragment or iterable of Fragments
635
+ order_by: ORDER BY field names
636
+ for_update: Whether to add FOR UPDATE clause
637
+
638
+ Returns:
639
+ Fragment containing SELECT statement
640
+
641
+ Example:
642
+ >>> list(User.select_sql(where=sql("name = {}", "Alice")))
643
+ ['SELECT "id", "name", "email" FROM "users" WHERE name = $1', 'Alice']
644
+ """
334
645
  if isinstance(order_by, str):
335
646
  order_by = (order_by,)
336
647
  if not isinstance(where, Fragment):
@@ -360,6 +671,16 @@ class ModelBase:
360
671
  query: Fragment,
361
672
  prefetch: int = 1000,
362
673
  ) -> AsyncGenerator[T, None]:
674
+ """Create an async generator from a query result.
675
+
676
+ Args:
677
+ connection: Database connection
678
+ query: SQL query Fragment
679
+ prefetch: Number of rows to prefetch
680
+
681
+ Yields:
682
+ Model instances from the query results
683
+ """
363
684
  async for row in connection.cursor(*query, prefetch=prefetch):
364
685
  yield cls.from_mapping(row)
365
686
 
@@ -372,6 +693,22 @@ class ModelBase:
372
693
  where: Where = (),
373
694
  prefetch: int = 1000,
374
695
  ) -> AsyncGenerator[T, None]:
696
+ """Create an async generator for SELECT results.
697
+
698
+ Args:
699
+ connection: Database connection
700
+ order_by: ORDER BY field names
701
+ for_update: Whether to add FOR UPDATE clause
702
+ where: WHERE conditions
703
+ prefetch: Number of rows to prefetch
704
+
705
+ Yields:
706
+ Model instances from the SELECT results
707
+
708
+ Example:
709
+ >>> async for user in User.select_cursor(conn, where=sql("active = {}", True)):
710
+ ... print(user.name)
711
+ """
375
712
  return cls.cursor_from(
376
713
  connection,
377
714
  cls.select_sql(order_by=order_by, for_update=for_update, where=where),
@@ -384,6 +721,15 @@ class ModelBase:
384
721
  connection_or_pool: Union[Connection, Pool],
385
722
  query: Fragment,
386
723
  ) -> list[T]:
724
+ """Execute a query and return model instances.
725
+
726
+ Args:
727
+ connection_or_pool: Database connection or pool
728
+ query: SQL query Fragment
729
+
730
+ Returns:
731
+ List of model instances from the query results
732
+ """
387
733
  return [cls.from_mapping(row) for row in await connection_or_pool.fetch(*query)]
388
734
 
389
735
  @classmethod
@@ -394,6 +740,20 @@ class ModelBase:
394
740
  for_update: bool = False,
395
741
  where: Where = (),
396
742
  ) -> list[T]:
743
+ """Execute a SELECT query and return model instances.
744
+
745
+ Args:
746
+ connection_or_pool: Database connection or pool
747
+ order_by: ORDER BY field names
748
+ for_update: Whether to add FOR UPDATE clause
749
+ where: WHERE conditions
750
+
751
+ Returns:
752
+ List of model instances from the SELECT results
753
+
754
+ Example:
755
+ >>> users = await User.select(pool, where=sql("active = {}", True))
756
+ """
397
757
  return await cls.fetch_from(
398
758
  connection_or_pool,
399
759
  cls.select_sql(order_by=order_by, for_update=for_update, where=where),
@@ -401,6 +761,18 @@ class ModelBase:
401
761
 
402
762
  @classmethod
403
763
  def create_sql(cls: type[T], **kwargs: Any) -> Fragment:
764
+ """Generate INSERT SQL for creating a new record with RETURNING clause.
765
+
766
+ Args:
767
+ **kwargs: Field values for the new record
768
+
769
+ Returns:
770
+ Fragment containing INSERT ... RETURNING statement
771
+
772
+ Example:
773
+ >>> list(User.create_sql(name="Alice", email="alice@example.com"))
774
+ ['INSERT INTO "users" ("name", "email") VALUES ($1, $2) RETURNING "id", "name", "email"', 'Alice', 'alice@example.com']
775
+ """
404
776
  column_info = cls.column_info()
405
777
  return sql(
406
778
  "INSERT INTO {table} ({fields}) VALUES ({values}) RETURNING {out_fields}",
@@ -416,10 +788,35 @@ class ModelBase:
416
788
  async def create(
417
789
  cls: type[T], connection_or_pool: Union[Connection, Pool], **kwargs: Any
418
790
  ) -> T:
791
+ """Create a new record in the database.
792
+
793
+ Args:
794
+ connection_or_pool: Database connection or pool
795
+ **kwargs: Field values for the new record
796
+
797
+ Returns:
798
+ Model instance representing the created record
799
+
800
+ Example:
801
+ >>> user = await User.create(pool, name="Alice", email="alice@example.com")
802
+ """
419
803
  row = await connection_or_pool.fetchrow(*cls.create_sql(**kwargs))
420
804
  return cls.from_mapping(row)
421
805
 
422
806
  def insert_sql(self, exclude: FieldNamesSet = ()) -> Fragment:
807
+ """Generate INSERT SQL for this instance.
808
+
809
+ Args:
810
+ exclude: Field names to exclude from the INSERT
811
+
812
+ Returns:
813
+ Fragment containing INSERT statement
814
+
815
+ Example:
816
+ >>> user = User(name="Alice", email="alice@example.com")
817
+ >>> list(user.insert_sql())
818
+ ['INSERT INTO "users" ("name", "email") VALUES ($1, $2)', 'Alice', 'alice@example.com']
819
+ """
423
820
  cached = self._cached(
424
821
  ("insert_sql", tuple(sorted(exclude))),
425
822
  lambda: sql(
@@ -435,38 +832,136 @@ class ModelBase:
435
832
  async def insert(
436
833
  self, connection_or_pool: Union[Connection, Pool], exclude: FieldNamesSet = ()
437
834
  ) -> str:
835
+ """Insert this instance into the database.
836
+
837
+ Args:
838
+ connection_or_pool: Database connection or pool
839
+ exclude: Field names to exclude from the INSERT
840
+
841
+ Returns:
842
+ Result string from the database operation
843
+ """
438
844
  return await connection_or_pool.execute(*self.insert_sql(exclude))
439
845
 
440
846
  @classmethod
441
- def upsert_sql(cls, insert_sql: Fragment, exclude: FieldNamesSet = ()) -> Fragment:
442
- cached = cls._cached(
443
- ("upsert_sql", tuple(sorted(exclude))),
444
- lambda: sql(
445
- " ON CONFLICT ({pks}) DO UPDATE SET {assignments}",
847
+ def upsert_sql(
848
+ cls,
849
+ insert_sql: Fragment,
850
+ insert_only: FieldNamesSet = (),
851
+ force_update: FieldNamesSet = (),
852
+ ) -> Fragment:
853
+ """Generate UPSERT (INSERT ... ON CONFLICT DO UPDATE) SQL.
854
+
855
+ Args:
856
+ insert_sql: Base INSERT statement Fragment
857
+ insert_only: Field names to exclude from the UPDATE clause
858
+ force_update: Field names to force include in UPDATE clause, overriding insert_only settings
859
+
860
+ Returns:
861
+ Fragment containing INSERT ... ON CONFLICT DO UPDATE statement
862
+
863
+ Example:
864
+ >>> insert = user.insert_sql()
865
+ >>> list(User.upsert_sql(insert))
866
+ ['INSERT INTO "users" ("name", "email") VALUES ($1, $2) ON CONFLICT ("id") DO UPDATE SET "name"=EXCLUDED."name", "email"=EXCLUDED."email"', 'Alice', 'alice@example.com']
867
+
868
+ Note:
869
+ Fields marked with ColumnInfo(insert_only=True) are automatically
870
+ excluded from the UPDATE clause, unless overridden by force_update.
871
+ """
872
+ # Combine insert_only parameter with auto-detected insert_only fields, but remove force_update fields
873
+ auto_insert_only = cls.insert_only_field_names() - set(force_update)
874
+ manual_insert_only = set(insert_only) - set(
875
+ force_update
876
+ ) # Remove force_update from manual insert_only too
877
+ all_insert_only = manual_insert_only | auto_insert_only
878
+
879
+ def generate_upsert_fragment():
880
+ updatable_fields = cls.field_names(
881
+ exclude=(*cls.primary_key_names, *all_insert_only)
882
+ )
883
+ return sql(
884
+ " ON CONFLICT ({pks}) DO {action}",
446
885
  insert_sql=insert_sql,
447
886
  pks=sql.list(cls.primary_key_names_sql()),
448
- assignments=sql.list(
449
- sql("{field}=EXCLUDED.{field}", field=x)
450
- for x in cls.field_names_sql(
451
- exclude=(*cls.primary_key_names, *exclude)
887
+ action=(
888
+ sql(
889
+ "UPDATE SET {assignments}",
890
+ assignments=sql.list(
891
+ sql("{field}=EXCLUDED.{field}", field=sql.identifier(field))
892
+ for field in updatable_fields
893
+ ),
452
894
  )
895
+ if updatable_fields
896
+ else sql.literal("NOTHING")
453
897
  ),
454
- ).flatten(),
898
+ ).flatten()
899
+
900
+ cached = cls._cached(
901
+ ("upsert_sql", tuple(sorted(all_insert_only))),
902
+ generate_upsert_fragment,
455
903
  )
456
904
  return Fragment([insert_sql, cached])
457
905
 
458
906
  async def upsert(
459
- self, connection_or_pool: Union[Connection, Pool], exclude: FieldNamesSet = ()
907
+ self,
908
+ connection_or_pool: Union[Connection, Pool],
909
+ exclude: FieldNamesSet = (),
910
+ insert_only: FieldNamesSet = (),
911
+ force_update: FieldNamesSet = (),
460
912
  ) -> bool:
913
+ """Insert or update this instance in the database.
914
+
915
+ Args:
916
+ connection_or_pool: Database connection or pool
917
+ exclude: Field names to exclude from INSERT and UPDATE
918
+ insert_only: Field names that should only be set on INSERT, not UPDATE
919
+ force_update: Field names to force include in UPDATE clause, overriding insert_only settings
920
+
921
+ Returns:
922
+ True if the record was updated, False if it was inserted
923
+
924
+ Example:
925
+ >>> user = User(id=1, name="Alice", created_at=datetime.now())
926
+ >>> # Only set created_at on INSERT, not UPDATE
927
+ >>> was_updated = await user.upsert(pool, insert_only={'created_at'})
928
+ >>> # Force update created_at even if it's marked insert_only in ColumnInfo
929
+ >>> was_updated = await user.upsert(pool, force_update={'created_at'})
930
+
931
+ Note:
932
+ Fields marked with ColumnInfo(insert_only=True) are automatically
933
+ treated as insert-only and combined with the insert_only parameter,
934
+ unless overridden by force_update.
935
+ """
936
+ # upsert_sql automatically handles insert_only fields from ColumnInfo
937
+ # We only need to combine manual insert_only with exclude for the UPDATE clause
938
+ update_exclude = set(exclude) | set(insert_only)
461
939
  query = sql(
462
940
  "{} RETURNING xmax",
463
- self.upsert_sql(self.insert_sql(exclude=exclude), exclude=exclude),
941
+ self.upsert_sql(
942
+ self.insert_sql(exclude=exclude),
943
+ insert_only=update_exclude,
944
+ force_update=force_update,
945
+ ),
464
946
  )
465
947
  result = await connection_or_pool.fetchrow(*query)
466
948
  return result["xmax"] != 0
467
949
 
468
950
  @classmethod
469
951
  def delete_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
952
+ """Generate DELETE SQL for multiple records.
953
+
954
+ Args:
955
+ rows: Model instances to delete
956
+
957
+ Returns:
958
+ Fragment containing DELETE statement with UNNEST-based WHERE clause
959
+
960
+ Example:
961
+ >>> users = [user1, user2, user3]
962
+ >>> list(User.delete_multiple_sql(users))
963
+ ['DELETE FROM "users" WHERE ("id") IN (SELECT * FROM UNNEST($1::UUID[]))', (uuid1, uuid2, uuid3)]
964
+ """
470
965
  cached = cls._cached(
471
966
  ("delete_multiple_sql",),
472
967
  lambda: sql(
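Putting the insert_only/force_update machinery above together, a sketch based on the Product model from the ColumnInfo docstring, where created_at is annotated with ColumnInfo(insert_only=True); pool is an assumed asyncpg pool:

    product = Product(id=1, name="Widget", price=9.99,
                      tags=["new"], created_at=datetime.now())

    # created_at is written on first insert but left alone when the row already
    # exists and the statement falls through to ON CONFLICT ... DO UPDATE.
    await product.upsert(pool)

    # Per-call override: update created_at this time despite the column setting.
    await product.upsert(pool, force_update={"created_at"})

If every non-key column ends up insert-only, the generated statement falls back to ON CONFLICT ... DO NOTHING, as the action branch above shows.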
@@ -487,10 +982,34 @@ class ModelBase:
487
982
  async def delete_multiple(
488
983
  cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
489
984
  ) -> str:
985
+ """Delete multiple records from the database.
986
+
987
+ Args:
988
+ connection_or_pool: Database connection or pool
989
+ rows: Model instances to delete
990
+
991
+ Returns:
992
+ Result string from the database operation
993
+ """
490
994
  return await connection_or_pool.execute(*cls.delete_multiple_sql(rows))
491
995
 
492
996
  @classmethod
493
997
  def insert_multiple_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
998
+ """Generate bulk INSERT SQL using UNNEST.
999
+
1000
+ This is the most efficient method for bulk inserts in PostgreSQL.
1001
+
1002
+ Args:
1003
+ rows: Model instances to insert
1004
+
1005
+ Returns:
1006
+ Fragment containing INSERT ... SELECT FROM UNNEST statement
1007
+
1008
+ Example:
1009
+ >>> users = [User(name="Alice"), User(name="Bob")]
1010
+ >>> list(User.insert_multiple_sql(users))
1011
+ ['INSERT INTO "users" ("name", "email") SELECT * FROM UNNEST($1::TEXT[], $2::TEXT[])', ('Alice', 'Bob'), (None, None)]
1012
+ """
494
1013
  cached = cls._cached(
495
1014
  ("insert_multiple_sql",),
496
1015
  lambda: sql(
@@ -509,6 +1028,18 @@ class ModelBase:
509
1028
 
510
1029
  @classmethod
511
1030
  def insert_multiple_array_safe_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
1031
+ """Generate bulk INSERT SQL using VALUES syntax.
1032
+
1033
+ This method is required when your model contains array columns, because
1034
+ PostgreSQL doesn't support arrays-of-arrays (which UNNEST would require).
1035
+ Use this instead of the UNNEST method when you have array-typed fields.
1036
+
1037
+ Args:
1038
+ rows: Model instances to insert
1039
+
1040
+ Returns:
1041
+ Fragment containing INSERT ... VALUES statement
1042
+ """
512
1043
  return sql(
513
1044
  "INSERT INTO {table} ({fields}) VALUES {values}",
514
1045
  table=cls.table_name_sql(),
@@ -523,6 +1054,15 @@ class ModelBase:
523
1054
  def insert_multiple_executemany_chunk_sql(
524
1055
  cls: type[T], chunk_size: int
525
1056
  ) -> Fragment:
1057
+ """Generate INSERT SQL template for executemany with specific chunk size.
1058
+
1059
+ Args:
1060
+ chunk_size: Number of records per batch
1061
+
1062
+ Returns:
1063
+ Fragment containing INSERT statement with numbered placeholders
1064
+ """
1065
+
526
1066
  def generate() -> Fragment:
527
1067
  columns = len(cls.column_info())
528
1068
  values = ", ".join(
@@ -545,6 +1085,14 @@ class ModelBase:
545
1085
  async def insert_multiple_executemany(
546
1086
  cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
547
1087
  ) -> None:
1088
+ """Insert multiple records using asyncpg's executemany.
1089
+
1090
+ This is the most compatible but slowest bulk insert method.
1091
+
1092
+ Args:
1093
+ connection_or_pool: Database connection or pool
1094
+ rows: Model instances to insert
1095
+ """
548
1096
  args = [r.field_values() for r in rows]
549
1097
  query = cls.insert_multiple_executemany_chunk_sql(1).query()[0]
550
1098
  if args:
@@ -554,12 +1102,36 @@ class ModelBase:
554
1102
  async def insert_multiple_unnest(
555
1103
  cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
556
1104
  ) -> str:
1105
+ """Insert multiple records using PostgreSQL UNNEST.
1106
+
1107
+ This is the most efficient bulk insert method for PostgreSQL.
1108
+
1109
+ Args:
1110
+ connection_or_pool: Database connection or pool
1111
+ rows: Model instances to insert
1112
+
1113
+ Returns:
1114
+ Result string from the database operation
1115
+ """
557
1116
  return await connection_or_pool.execute(*cls.insert_multiple_sql(rows))
558
1117
 
559
1118
  @classmethod
560
1119
  async def insert_multiple_array_safe(
561
1120
  cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
562
1121
  ) -> str:
1122
+ """Insert multiple records using VALUES syntax with chunking.
1123
+
1124
+ This method is required when your model contains array columns, because
1125
+ PostgreSQL doesn't support arrays-of-arrays (which UNNEST would require).
1126
+ Data is processed in chunks to manage memory usage.
1127
+
1128
+ Args:
1129
+ connection_or_pool: Database connection or pool
1130
+ rows: Model instances to insert
1131
+
1132
+ Returns:
1133
+ Result string from the last chunk operation
1134
+ """
563
1135
  last = ""
564
1136
  for chunk in chunked(rows, 100):
565
1137
  last = await connection_or_pool.execute(
@@ -571,6 +1143,21 @@ class ModelBase:
571
1143
  async def insert_multiple(
572
1144
  cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
573
1145
  ) -> str:
1146
+ """Insert multiple records using the configured insert_multiple_mode.
1147
+
1148
+ Args:
1149
+ connection_or_pool: Database connection or pool
1150
+ rows: Model instances to insert
1151
+
1152
+ Returns:
1153
+ Result string from the database operation
1154
+
1155
+ Note:
1156
+ The actual method used depends on the insert_multiple_mode setting:
1157
+ - 'unnest': Most efficient, uses UNNEST (default)
1158
+ - 'array_safe': Uses VALUES syntax; required when model has array columns
1159
+ - 'executemany': Uses asyncpg's executemany, slowest but most compatible
1160
+ """
574
1161
  if cls.insert_multiple_mode == "executemany":
575
1162
  await cls.insert_multiple_executemany(connection_or_pool, rows)
576
1163
  return "INSERT"
@@ -585,10 +1172,21 @@ class ModelBase:
585
1172
  connection_or_pool: Union[Connection, Pool],
586
1173
  rows: Iterable[T],
587
1174
  insert_only: FieldNamesSet = (),
1175
+ force_update: FieldNamesSet = (),
588
1176
  ) -> None:
1177
+ """Bulk upsert using asyncpg's executemany.
1178
+
1179
+ Args:
1180
+ connection_or_pool: Database connection or pool
1181
+ rows: Model instances to upsert
1182
+ insert_only: Field names that should only be set on INSERT, not UPDATE
1183
+ force_update: Field names to force include in UPDATE clause, overriding insert_only settings
1184
+ """
589
1185
  args = [r.field_values() for r in rows]
590
1186
  query = cls.upsert_sql(
591
- cls.insert_multiple_executemany_chunk_sql(1), exclude=insert_only
1187
+ cls.insert_multiple_executemany_chunk_sql(1),
1188
+ insert_only=insert_only,
1189
+ force_update=force_update,
592
1190
  ).query()[0]
593
1191
  if args:
594
1192
  await connection_or_pool.executemany(query, args)
@@ -599,9 +1197,25 @@ class ModelBase:
599
1197
  connection_or_pool: Union[Connection, Pool],
600
1198
  rows: Iterable[T],
601
1199
  insert_only: FieldNamesSet = (),
1200
+ force_update: FieldNamesSet = (),
602
1201
  ) -> str:
1202
+ """Bulk upsert using PostgreSQL UNNEST.
1203
+
1204
+ Args:
1205
+ connection_or_pool: Database connection or pool
1206
+ rows: Model instances to upsert
1207
+ insert_only: Field names that should only be set on INSERT, not UPDATE
1208
+ force_update: Field names to force include in UPDATE clause, overriding insert_only settings
1209
+
1210
+ Returns:
1211
+ Result string from the database operation
1212
+ """
603
1213
  return await connection_or_pool.execute(
604
- *cls.upsert_sql(cls.insert_multiple_sql(rows), exclude=insert_only)
1214
+ *cls.upsert_sql(
1215
+ cls.insert_multiple_sql(rows),
1216
+ insert_only=insert_only,
1217
+ force_update=force_update,
1218
+ )
605
1219
  )
606
1220
 
607
1221
  @classmethod
@@ -610,12 +1224,29 @@ class ModelBase:
610
1224
  connection_or_pool: Union[Connection, Pool],
611
1225
  rows: Iterable[T],
612
1226
  insert_only: FieldNamesSet = (),
1227
+ force_update: FieldNamesSet = (),
613
1228
  ) -> str:
1229
+ """Bulk upsert using VALUES syntax with chunking.
1230
+
1231
+ This method is required when your model contains array columns, because
1232
+ PostgreSQL doesn't support arrays-of-arrays (which UNNEST would require).
1233
+
1234
+ Args:
1235
+ connection_or_pool: Database connection or pool
1236
+ rows: Model instances to upsert
1237
+ insert_only: Field names that should only be set on INSERT, not UPDATE
1238
+ force_update: Field names to force include in UPDATE clause, overriding insert_only settings
1239
+
1240
+ Returns:
1241
+ Result string from the last chunk operation
1242
+ """
614
1243
  last = ""
615
1244
  for chunk in chunked(rows, 100):
616
1245
  last = await connection_or_pool.execute(
617
1246
  *cls.upsert_sql(
618
- cls.insert_multiple_array_safe_sql(chunk), exclude=insert_only
1247
+ cls.insert_multiple_array_safe_sql(chunk),
1248
+ insert_only=insert_only,
1249
+ force_update=force_update,
619
1250
  )
620
1251
  )
621
1252
  return last
@@ -626,25 +1257,66 @@ class ModelBase:
626
1257
  connection_or_pool: Union[Connection, Pool],
627
1258
  rows: Iterable[T],
628
1259
  insert_only: FieldNamesSet = (),
1260
+ force_update: FieldNamesSet = (),
629
1261
  ) -> str:
1262
+ """Bulk upsert (INSERT ... ON CONFLICT DO UPDATE) multiple records.
1263
+
1264
+ Args:
1265
+ connection_or_pool: Database connection or pool
1266
+ rows: Model instances to upsert
1267
+ insert_only: Field names that should only be set on INSERT, not UPDATE
1268
+ force_update: Field names to force include in UPDATE clause, overriding insert_only settings
1269
+
1270
+ Returns:
1271
+ Result string from the database operation
1272
+
1273
+ Example:
1274
+ >>> await User.upsert_multiple(pool, users, insert_only={'created_at'})
1275
+ >>> await User.upsert_multiple(pool, users, force_update={'created_at'})
1276
+
1277
+ Note:
1278
+ Fields marked with ColumnInfo(insert_only=True) are automatically
1279
+ treated as insert-only and combined with the insert_only parameter,
1280
+ unless overridden by force_update.
1281
+ """
1282
+ # upsert_sql automatically handles insert_only fields from ColumnInfo
1283
+ # Pass manual insert_only parameter through to the specific implementations
1284
+
630
1285
  if cls.insert_multiple_mode == "executemany":
631
1286
  await cls.upsert_multiple_executemany(
632
- connection_or_pool, rows, insert_only=insert_only
1287
+ connection_or_pool,
1288
+ rows,
1289
+ insert_only=insert_only,
1290
+ force_update=force_update,
633
1291
  )
634
1292
  return "INSERT"
635
1293
  elif cls.insert_multiple_mode == "array_safe":
636
1294
  return await cls.upsert_multiple_array_safe(
637
- connection_or_pool, rows, insert_only=insert_only
1295
+ connection_or_pool,
1296
+ rows,
1297
+ insert_only=insert_only,
1298
+ force_update=force_update,
638
1299
  )
639
1300
  else:
640
1301
  return await cls.upsert_multiple_unnest(
641
- connection_or_pool, rows, insert_only=insert_only
1302
+ connection_or_pool,
1303
+ rows,
1304
+ insert_only=insert_only,
1305
+ force_update=force_update,
642
1306
  )
643
1307
 
644
1308
  @classmethod
645
1309
  def _get_equal_ignoring_fn(
646
1310
  cls: type[T], ignore: FieldNamesSet = ()
647
1311
  ) -> Callable[[T, T], bool]:
1312
+ """Generate optimized function to compare instances ignoring certain fields.
1313
+
1314
+ Args:
1315
+ ignore: Field names to ignore during comparison
1316
+
1317
+ Returns:
1318
+ Compiled function that compares two instances, returning True if equal
1319
+ """
648
1320
  env: dict[str, Any] = {}
649
1321
  func = ["def equal_ignoring(a, b):"]
650
1322
  for ci in cls.column_info().values():
@@ -663,8 +1335,38 @@ class ModelBase:
663
1335
  where: Where,
664
1336
  ignore: FieldNamesSet = (),
665
1337
  insert_only: FieldNamesSet = (),
1338
+ force_update: FieldNamesSet = (),
666
1339
  ) -> "ReplaceMultiplePlan[T]":
667
- ignore = sorted(set(ignore) | set(insert_only))
1340
+ """Plan a replace operation by comparing new data with existing records.
1341
+
1342
+ This method analyzes the differences between the provided rows and existing
1343
+ database records, determining which records need to be created, updated, or deleted.
1344
+
1345
+ Args:
1346
+ connection: Database connection (must support FOR UPDATE)
1347
+ rows: New data as model instances or mappings
1348
+ where: WHERE clause to limit which existing records to consider
1349
+ ignore: Field names to ignore when comparing records
1350
+ insert_only: Field names that should only be set on INSERT, not UPDATE
1351
+ force_update: Field names to force include in UPDATE clause, overriding insert_only settings
1352
+
1353
+ Returns:
1354
+ ReplaceMultiplePlan containing the planned operations
1355
+
1356
+ Example:
1357
+ >>> plan = await User.plan_replace_multiple(
1358
+ ... conn, new_users, where=sql("department_id = {}", dept_id)
1359
+ ... )
1360
+ >>> print(f"Will create {len(plan.created)}, update {len(plan.updated)}, delete {len(plan.deleted)}")
1361
+
1362
+ Note:
1363
+ Fields marked with ColumnInfo(insert_only=True) are automatically
1364
+ treated as insert-only and combined with the insert_only parameter,
1365
+ unless overridden by force_update.
1366
+ """
1367
+ # For comparison purposes, combine auto-detected insert_only fields with manual ones
1368
+ all_insert_only = cls.insert_only_field_names() | set(insert_only)
1369
+ ignore = sorted(set(ignore) | all_insert_only)
668
1370
  equal_ignoring = cls._cached(
669
1371
  ("equal_ignoring", tuple(ignore)),
670
1372
  lambda: cls._get_equal_ignoring_fn(ignore),
@@ -687,7 +1389,11 @@ class ModelBase:
687
1389
 
688
1390
  created = list(pending.values())
689
1391
 
690
- return ReplaceMultiplePlan(cls, insert_only, created, updated, deleted)
1392
+ # Pass only manual insert_only and force_update to the plan
1393
+ # since upsert_multiple handles auto-detected ones
1394
+ return ReplaceMultiplePlan(
1395
+ cls, insert_only, force_update, created, updated, deleted
1396
+ )
691
1397
 
692
1398
  @classmethod
693
1399
  async def replace_multiple(
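The plan returned above can also be executed in stages rather than all at once. A sketch assuming the hypothetical User model, a list new_users, an open transaction on conn, and dept_id / allow_deletes defined elsewhere:

    from sql_athame import sql

    plan = await User.plan_replace_multiple(
        conn, new_users, where=sql("department_id = {}", dept_id)
    )
    # created/updated/deleted can be inspected or logged before anything runs.
    created, updated, deleted = plan.cud

    # Apply inserts/updates now; defer the deletes behind an extra check.
    await plan.execute_upserts(conn)
    if allow_deletes:
        await plan.execute_deletes(conn)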
@@ -698,9 +1404,42 @@ class ModelBase:
698
1404
  where: Where,
699
1405
  ignore: FieldNamesSet = (),
700
1406
  insert_only: FieldNamesSet = (),
1407
+ force_update: FieldNamesSet = (),
701
1408
  ) -> tuple[list[T], list[T], list[T]]:
1409
+ """Replace records in the database with the provided data.
1410
+
1411
+ This is a complete replace operation: records matching the WHERE clause
1412
+ that aren't in the new data will be deleted, new records will be inserted,
1413
+ and changed records will be updated.
1414
+
1415
+ Args:
1416
+ connection: Database connection (must support FOR UPDATE)
1417
+ rows: New data as model instances or mappings
1418
+ where: WHERE clause to limit which existing records to consider for replacement
1419
+ ignore: Field names to ignore when comparing records
1420
+ insert_only: Field names that should only be set on INSERT, not UPDATE
1421
+ force_update: Field names to force include in UPDATE clause, overriding insert_only settings
1422
+
1423
+ Returns:
1424
+ Tuple of (created_records, updated_records, deleted_records)
1425
+
1426
+ Example:
1427
+ >>> created, updated, deleted = await User.replace_multiple(
1428
+ ... conn, new_users, where=sql("department_id = {}", dept_id)
1429
+ ... )
1430
+
1431
+ Note:
1432
+ Fields marked with ColumnInfo(insert_only=True) are automatically
1433
+ treated as insert-only and combined with the insert_only parameter,
1434
+ unless overridden by force_update.
1435
+ """
702
1436
  plan = await cls.plan_replace_multiple(
703
- connection, rows, where=where, ignore=ignore, insert_only=insert_only
1437
+ connection,
1438
+ rows,
1439
+ where=where,
1440
+ ignore=ignore,
1441
+ insert_only=insert_only,
1442
+ force_update=force_update,
704
1443
  )
705
1444
  await plan.execute(connection)
706
1445
  return plan.cud
@@ -709,6 +1448,14 @@ class ModelBase:
709
1448
  def _get_differences_ignoring_fn(
710
1449
  cls: type[T], ignore: FieldNamesSet = ()
711
1450
  ) -> Callable[[T, T], list[str]]:
1451
+ """Generate optimized function to find field differences between instances.
1452
+
1453
+ Args:
1454
+ ignore: Field names to ignore during comparison
1455
+
1456
+ Returns:
1457
+ Compiled function that returns list of field names that differ
1458
+ """
712
1459
  env: dict[str, Any] = {}
713
1460
  func = [
714
1461
  "def differences_ignoring(a, b):",
@@ -732,8 +1479,40 @@ class ModelBase:
732
1479
  where: Where,
733
1480
  ignore: FieldNamesSet = (),
734
1481
  insert_only: FieldNamesSet = (),
1482
+ force_update: FieldNamesSet = (),
735
1483
  ) -> tuple[list[T], list[tuple[T, T, list[str]]], list[T]]:
736
- ignore = sorted(set(ignore) | set(insert_only))
1484
+ """Replace records and report the specific field differences for updates.
1485
+
1486
+ Like replace_multiple, but provides detailed information about which
1487
+ fields changed for each updated record.
1488
+
1489
+ Args:
1490
+ connection: Database connection (must support FOR UPDATE)
1491
+ rows: New data as model instances or mappings
1492
+ where: WHERE clause to limit which existing records to consider
1493
+ ignore: Field names to ignore when comparing records
1494
+ insert_only: Field names that should only be set on INSERT, not UPDATE
1495
+ force_update: Field names to force include in UPDATE clause, overriding insert_only settings
1496
+
1497
+ Returns:
1498
+ Tuple of (created_records, update_triples, deleted_records)
1499
+ where update_triples contains (old_record, new_record, changed_field_names)
1500
+
1501
+ Example:
1502
+ >>> created, updates, deleted = await User.replace_multiple_reporting_differences(
1503
+ ... conn, new_users, where=sql("department_id = {}", dept_id)
1504
+ ... )
1505
+ >>> for old, new, fields in updates:
1506
+ ... print(f"Updated {old.name}: changed {', '.join(fields)}")
1507
+
1508
+ Note:
1509
+ Fields marked with ColumnInfo(insert_only=True) are automatically
1510
+ treated as insert-only and combined with the insert_only parameter,
1511
+ unless overridden by force_update.
1512
+ """
1513
+ # For comparison purposes, combine auto-detected insert_only fields with manual ones
1514
+ all_insert_only = cls.insert_only_field_names() | set(insert_only)
1515
+ ignore = sorted(set(ignore) | all_insert_only)
737
1516
  differences_ignoring = cls._cached(
738
1517
  ("differences_ignoring", tuple(ignore)),
739
1518
  lambda: cls._get_differences_ignoring_fn(ignore),
@@ -763,6 +1542,7 @@ class ModelBase:
763
1542
  connection,
764
1543
  (*created, *(t[1] for t in updated_triples)),
765
1544
  insert_only=insert_only,
1545
+ force_update=force_update,
766
1546
  )
767
1547
  if deleted:
768
1548
  await cls.delete_multiple(connection, deleted)
@@ -774,30 +1554,67 @@ class ModelBase:
774
1554
  class ReplaceMultiplePlan(Generic[T]):
775
1555
  model_class: type[T]
776
1556
  insert_only: FieldNamesSet
1557
+ force_update: FieldNamesSet
777
1558
  created: list[T]
778
1559
  updated: list[T]
779
1560
  deleted: list[T]
780
1561
 
781
1562
  @property
782
1563
  def cud(self) -> tuple[list[T], list[T], list[T]]:
1564
+ """Get the create, update, delete lists as a tuple.
1565
+
1566
+ Returns:
1567
+ Tuple of (created, updated, deleted) record lists
1568
+ """
783
1569
  return (self.created, self.updated, self.deleted)
784
1570
 
785
1571
  async def execute_upserts(self, connection: Connection) -> None:
1572
+ """Execute the upsert operations (creates and updates).
1573
+
1574
+ Args:
1575
+ connection: Database connection
1576
+ """
786
1577
  if self.created or self.updated:
787
1578
  await self.model_class.upsert_multiple(
788
- connection, (*self.created, *self.updated), insert_only=self.insert_only
1579
+ connection,
1580
+ (*self.created, *self.updated),
1581
+ insert_only=self.insert_only,
1582
+ force_update=self.force_update,
789
1583
  )
790
1584
 
791
1585
  async def execute_deletes(self, connection: Connection) -> None:
1586
+ """Execute the delete operations.
1587
+
1588
+ Args:
1589
+ connection: Database connection
1590
+ """
792
1591
  if self.deleted:
793
1592
  await self.model_class.delete_multiple(connection, self.deleted)
794
1593
 
795
1594
  async def execute(self, connection: Connection) -> None:
1595
+ """Execute all planned operations (upserts then deletes).
1596
+
1597
+ Args:
1598
+ connection: Database connection
1599
+ """
796
1600
  await self.execute_upserts(connection)
797
1601
  await self.execute_deletes(connection)
798
1602
 
799
1603
 
800
1604
  def chunked(lst, n):
1605
+ """Split an iterable into chunks of size n.
1606
+
1607
+ Args:
1608
+ lst: Iterable to chunk
1609
+ n: Chunk size
1610
+
1611
+ Yields:
1612
+ Lists of up to n items from the input
1613
+
1614
+ Example:
1615
+ >>> list(chunked([1, 2, 3, 4, 5], 2))
1616
+ [[1, 2], [3, 4], [5]]
1617
+ """
801
1618
  if type(lst) is not list:
802
1619
  lst = list(lst)
803
1620
  for i in range(0, len(lst), n):