fastapi-toolsets 3.0.3__tar.gz → 3.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/PKG-INFO +1 -1
  2. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/pyproject.toml +1 -1
  3. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/__init__.py +1 -1
  4. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/db.py +145 -3
  5. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/fixtures/__init__.py +7 -1
  6. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/fixtures/utils.py +93 -29
  7. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/pytest/plugin.py +58 -3
  8. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/LICENSE +0 -0
  9. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/README.md +0 -0
  10. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/_imports.py +0 -0
  11. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/cli/__init__.py +0 -0
  12. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/cli/app.py +0 -0
  13. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/cli/commands/__init__.py +0 -0
  14. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/cli/commands/fixtures.py +0 -0
  15. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/cli/config.py +0 -0
  16. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/cli/pyproject.py +0 -0
  17. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/cli/utils.py +0 -0
  18. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/crud/__init__.py +0 -0
  19. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/crud/factory.py +0 -0
  20. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/crud/search.py +0 -0
  21. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/dependencies.py +0 -0
  22. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/exceptions/__init__.py +0 -0
  23. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/exceptions/exceptions.py +0 -0
  24. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/exceptions/handler.py +0 -0
  25. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/fixtures/enum.py +0 -0
  26. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/fixtures/registry.py +0 -0
  27. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/logger.py +0 -0
  28. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/metrics/__init__.py +0 -0
  29. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/metrics/handler.py +0 -0
  30. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/metrics/registry.py +0 -0
  31. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/models/__init__.py +0 -0
  32. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/models/columns.py +0 -0
  33. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/models/watched.py +0 -0
  34. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/py.typed +0 -0
  35. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/pytest/__init__.py +0 -0
  36. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/pytest/utils.py +0 -0
  37. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/schemas.py +0 -0
  38. {fastapi_toolsets-3.0.3 → fastapi_toolsets-3.1.1}/src/fastapi_toolsets/types.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: fastapi-toolsets
3
- Version: 3.0.3
3
+ Version: 3.1.1
4
4
  Summary: Production-ready utilities for FastAPI applications
5
5
  Keywords: fastapi,sqlalchemy,postgresql
6
6
  Author: d3vyce
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "fastapi-toolsets"
3
- version = "3.0.3"
3
+ version = "3.1.1"
4
4
  description = "Production-ready utilities for FastAPI applications"
5
5
  readme = "README.md"
6
6
  license = "MIT"
@@ -21,4 +21,4 @@ Example usage:
21
21
  return Response(data={"user": user.username}, message="Success")
22
22
  """
23
23
 
24
- __version__ = "3.0.3"
24
+ __version__ = "3.1.1"
@@ -4,11 +4,13 @@ import asyncio
4
4
  from collections.abc import AsyncGenerator, Callable
5
5
  from contextlib import AbstractAsyncContextManager, asynccontextmanager
6
6
  from enum import Enum
7
- from typing import Any, TypeVar
7
+ from typing import Any, TypeVar, cast
8
8
 
9
- from sqlalchemy import text
9
+ from sqlalchemy import Table, delete, text, tuple_
10
+ from sqlalchemy.dialects.postgresql import insert as pg_insert
10
11
  from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
11
- from sqlalchemy.orm import DeclarativeBase
12
+ from sqlalchemy.orm import DeclarativeBase, QueryableAttribute
13
+ from sqlalchemy.orm.relationships import RelationshipProperty
12
14
 
13
15
  from .exceptions import NotFoundError
14
16
 
@@ -20,6 +22,9 @@ __all__ = [
20
22
  "create_db_dependency",
21
23
  "get_transaction",
22
24
  "lock_tables",
25
+ "m2m_add",
26
+ "m2m_remove",
27
+ "m2m_set",
23
28
  "wait_for_row_change",
24
29
  ]
25
30
 
@@ -339,3 +344,140 @@ async def wait_for_row_change(
339
344
  current = {col: getattr(instance, col) for col in watch_cols}
340
345
  if current != initial:
341
346
  return instance
347
+
348
+
349
def _m2m_prop(rel_attr: QueryableAttribute) -> RelationshipProperty:  # type: ignore[type-arg]
    """Resolve *rel_attr* to its RelationshipProperty, insisting it is Many-to-Many.

    Here a relationship counts as Many-to-Many when its ``.property`` is a
    ``RelationshipProperty`` configured with a ``secondary`` (association) table.

    Raises:
        TypeError: If *rel_attr* is not a relationship, or has no secondary table.
    """
    candidate = rel_attr.property
    if isinstance(candidate, RelationshipProperty) and candidate.secondary is not None:
        return candidate
    raise TypeError(
        f"m2m helpers require a Many-to-Many relationship attribute, "
        f"got {rel_attr!r}. Use a relationship with a secondary table."
    )
361
+
362
+
363
async def m2m_add(
    session: AsyncSession,
    instance: DeclarativeBase,
    rel_attr: QueryableAttribute,
    *related: DeclarativeBase,
    ignore_conflicts: bool = False,
) -> None:
    """Insert rows into a Many-to-Many association table without loading the ORM collection.

    A single multi-row ``INSERT`` is executed against the relationship's secondary
    table. The ORM collection on ``instance`` is never touched by this function, so
    a copy of it that is already loaded in the session is not updated here.

    Key values are read from the instances as they currently are; an instance whose
    key attribute is still ``None`` (e.g. not flushed yet) contributes ``NULL``.

    Args:
        session: DB async session.
        instance: The "owner" side model instance (e.g. the ``A`` in ``A.b_list``).
        rel_attr: The M2M relationship attribute on the model class (e.g. ``A.b_list``).
        *related: One or more related instances to associate with ``instance``.
            Nothing is executed when none are given.
        ignore_conflicts: When ``True``, silently skip rows that already exist
            in the association table (``ON CONFLICT DO NOTHING``).

    Raises:
        TypeError: If ``rel_attr`` is not a Many-to-Many relationship.
    """
    prop = _m2m_prop(rel_attr)
    if not related:
        return

    secondary = cast(Table, prop.secondary)
    assert secondary is not None  # guaranteed by _m2m_prop
    sync_pairs = prop.secondary_synchronize_pairs
    assert sync_pairs is not None  # set whenever secondary is set

    # synchronize_pairs:           [(parent_col, assoc_col), ...]
    # secondary_synchronize_pairs: [(related_col, assoc_col), ...]
    #
    # Column.key is the key inside the Table, which can differ from the mapped
    # attribute name (e.g. ``user_id = mapped_column("id", ...)``). Resolve the
    # attribute name through each side's mapper before calling getattr().
    # Insert dicts are keyed by the association column's *key*, which is what
    # Insert.values() matches string keys against.
    parent_values: dict[str, Any] = {
        cast(str, assoc_col.key): getattr(
            instance, prop.parent.get_property_by_column(parent_col).key
        )
        for parent_col, assoc_col in prop.synchronize_pairs
    }
    related_keys = [
        (cast(str, assoc_col.key), prop.mapper.get_property_by_column(related_col).key)
        for related_col, assoc_col in sync_pairs
    ]

    rows: list[dict[str, Any]] = []
    for rel_instance in related:
        row = dict(parent_values)  # parent side is identical for every row
        for assoc_key, attr_name in related_keys:
            row[assoc_key] = getattr(rel_instance, attr_name)
        rows.append(row)

    stmt = pg_insert(secondary).values(rows)
    if ignore_conflicts:
        stmt = stmt.on_conflict_do_nothing()
    await session.execute(stmt)
407
+
408
+
409
async def m2m_remove(
    session: AsyncSession,
    instance: DeclarativeBase,
    rel_attr: QueryableAttribute,
    *related: DeclarativeBase,
) -> None:
    """Remove rows from a Many-to-Many association table without loading the ORM collection.

    A single ``DELETE`` is executed against the relationship's secondary table,
    restricted to rows of ``instance`` whose related-side key matches one of
    ``related``. The ORM collection on ``instance`` is not touched by this function.

    Args:
        session: DB async session.
        instance: The "owner" side model instance (e.g. the ``A`` in ``A.b_list``).
        rel_attr: The M2M relationship attribute on the model class (e.g. ``A.b_list``).
        *related: One or more related instances to disassociate from ``instance``.
            Nothing is executed when none are given.

    Raises:
        TypeError: If ``rel_attr`` is not a Many-to-Many relationship.
    """
    prop = _m2m_prop(rel_attr)
    if not related:
        return

    secondary = cast(Table, prop.secondary)
    assert secondary is not None  # guaranteed by _m2m_prop
    related_pairs = prop.secondary_synchronize_pairs
    assert related_pairs is not None  # set whenever secondary is set

    # Column.key is the key inside the Table, which can differ from the mapped
    # attribute name; resolve attribute names through each side's mapper.
    parent_where = [
        assoc_col
        == getattr(instance, prop.parent.get_property_by_column(parent_col).key)
        for parent_col, assoc_col in prop.synchronize_pairs
    ]

    assoc_cols = [assoc_col for _, assoc_col in related_pairs]
    related_attrs = [
        prop.mapper.get_property_by_column(related_col).key
        for related_col, _ in related_pairs
    ]

    if len(related_pairs) == 1:
        attr_name = related_attrs[0]
        related_where = assoc_cols[0].in_([getattr(r, attr_name) for r in related])
    else:
        # Composite key on the related side: match the column tuple as a whole.
        related_where = tuple_(*assoc_cols).in_(
            [tuple(getattr(r, name) for name in related_attrs) for r in related]
        )

    await session.execute(delete(secondary).where(*parent_where, related_where))
453
+
454
+
455
async def m2m_set(
    session: AsyncSession,
    instance: DeclarativeBase,
    rel_attr: QueryableAttribute,
    *related: DeclarativeBase,
) -> None:
    """Replace the entire Many-to-Many association set of ``instance``.

    Executes a ``DELETE`` of every association row belonging to ``instance`` and,
    when ``related`` is non-empty, inserts the new rows via ``m2m_add``. These are
    two separate statements on ``session``; the replacement is atomic only as part
    of the caller's enclosing transaction (no savepoint is opened here).

    Args:
        session: DB async session.
        instance: The "owner" side model instance (e.g. the ``A`` in ``A.b_list``).
        rel_attr: The M2M relationship attribute on the model class (e.g. ``A.b_list``).
        *related: The new complete set of related instances. Passing none clears
            the association set.

    Raises:
        TypeError: If ``rel_attr`` is not a Many-to-Many relationship.
    """
    prop = _m2m_prop(rel_attr)
    secondary = cast(Table, prop.secondary)
    assert secondary is not None  # guaranteed by _m2m_prop

    # Column.key is the key inside the Table, which can differ from the mapped
    # attribute name; resolve the attribute name through the parent mapper.
    parent_where = [
        assoc_col
        == getattr(instance, prop.parent.get_property_by_column(parent_col).key)
        for parent_col, assoc_col in prop.synchronize_pairs
    ]
    await session.execute(delete(secondary).where(*parent_where))

    if related:
        await m2m_add(session, instance, rel_attr, *related)
@@ -2,12 +2,18 @@
2
2
 
3
3
  from .enum import LoadStrategy
4
4
  from .registry import Context, FixtureRegistry
5
- from .utils import get_obj_by_attr, load_fixtures, load_fixtures_by_context
5
+ from .utils import (
6
+ get_field_by_attr,
7
+ get_obj_by_attr,
8
+ load_fixtures,
9
+ load_fixtures_by_context,
10
+ )
6
11
 
7
12
  __all__ = [
8
13
  "Context",
9
14
  "FixtureRegistry",
10
15
  "LoadStrategy",
16
+ "get_field_by_attr",
11
17
  "get_obj_by_attr",
12
18
  "load_fixtures",
13
19
  "load_fixtures_by_context",
@@ -40,6 +40,32 @@ def _instance_to_dict(instance: DeclarativeBase) -> dict[str, Any]:
40
40
  return result
41
41
 
42
42
 
43
def _get_table_chain(model_cls: type[DeclarativeBase]) -> list[type[DeclarativeBase]]:
    """Return [root, ..., model_cls] for joined-table inheritance, or [model_cls].

    The mapper ``inherits`` chain is walked from *model_cls* up to its root and
    collected root-first; this is the order callers use when writing per-table rows
    (presumably so parent-table rows exist before child rows that reference them).
    Classes whose ``__table__`` is already in the chain are dropped, so classes that
    share a table with an ancestor (as single-table-inheritance subclasses do)
    collapse to the first class mapped to that table.
    """
    lineage: list[type[DeclarativeBase]] = []
    mapper = sa_inspect(model_cls)
    while mapper is not None:
        lineage.insert(0, mapper.class_)  # prepend: ends up root-first
        mapper = mapper.inherits

    tables_seen: set[int] = set()
    chain: list[type[DeclarativeBase]] = []
    for klass in lineage:
        table_id = id(klass.__table__)
        if table_id not in tables_seen:  # pragma: no branch
            tables_seen.add(table_id)
            chain.append(klass)
    return chain
59
+
60
+
61
def _instance_to_dict_for_cls(
    instance: DeclarativeBase, cls: type[DeclarativeBase]
) -> dict[str, Any]:
    """Return the part of ``_instance_to_dict(instance)`` stored in *cls*'s own table.

    Entries are kept only when their key equals the ``key`` of a column in
    ``cls.__table__``; the original key order is preserved.
    """
    table_column_keys = {column.key for column in cls.__table__.columns}
    values = _instance_to_dict(instance)
    return {key: values[key] for key in values if key in table_column_keys}
67
+
68
+
43
69
  def _group_by_type(
44
70
  instances: list[DeclarativeBase],
45
71
  ) -> list[tuple[type[DeclarativeBase], list[DeclarativeBase]]]:
@@ -73,9 +99,11 @@ async def _batch_insert(
73
99
  instances: list[DeclarativeBase],
74
100
  ) -> None:
75
101
  """INSERT all instances — raises on conflict (no duplicate handling)."""
76
- dicts = [_instance_to_dict(i) for i in instances]
77
- for group_dicts, _ in _group_by_column_set(dicts, instances):
78
- await session.execute(pg_insert(model_cls).values(group_dicts))
102
+ for cls in _get_table_chain(model_cls):
103
+ dicts = [_instance_to_dict_for_cls(i, cls) for i in instances]
104
+ for group_dicts, _ in _group_by_column_set(dicts, instances):
105
+ if group_dicts and group_dicts[0]: # pragma: no branch
106
+ await session.execute(pg_insert(cls).values(group_dicts))
79
107
 
80
108
 
81
109
  async def _batch_merge(
@@ -84,31 +112,30 @@ async def _batch_merge(
84
112
  instances: list[DeclarativeBase],
85
113
  ) -> None:
86
114
  """UPSERT: insert new rows, update existing ones with the provided values."""
87
- mapper = model_cls.__mapper__
88
- pk_names = [col.name for col in mapper.primary_key]
89
- pk_names_set = set(pk_names)
90
- non_pk_cols = [
91
- prop.key
92
- for prop in mapper.column_attrs
93
- if not any(col.name in pk_names_set for col in prop.columns)
94
- ]
95
-
96
- dicts = [_instance_to_dict(i) for i in instances]
97
- for group_dicts, _ in _group_by_column_set(dicts, instances):
98
- stmt = pg_insert(model_cls).values(group_dicts)
99
-
100
- inserted_keys = set(group_dicts[0])
101
- update_cols = [col for col in non_pk_cols if col in inserted_keys]
102
-
103
- if update_cols:
104
- stmt = stmt.on_conflict_do_update(
105
- index_elements=pk_names,
106
- set_={col: stmt.excluded[col] for col in update_cols},
107
- )
108
- else:
109
- stmt = stmt.on_conflict_do_nothing(index_elements=pk_names)
115
+ for cls in _get_table_chain(model_cls):
116
+ pk_names = [col.name for col in cls.__table__.primary_key]
117
+ pk_names_set = set(pk_names)
118
+ own_col_keys = {col.key for col in cls.__table__.columns}
119
+ non_pk_cols = [k for k in own_col_keys if k not in pk_names_set]
120
+
121
+ dicts = [_instance_to_dict_for_cls(i, cls) for i in instances]
122
+ for group_dicts, _ in _group_by_column_set(dicts, instances):
123
+ if not group_dicts or not group_dicts[0]: # pragma: no cover
124
+ continue
125
+ stmt = pg_insert(cls).values(group_dicts)
110
126
 
111
- await session.execute(stmt)
127
+ inserted_keys = set(group_dicts[0])
128
+ update_cols = [col for col in non_pk_cols if col in inserted_keys]
129
+
130
+ if update_cols:
131
+ stmt = stmt.on_conflict_do_update(
132
+ index_elements=pk_names,
133
+ set_={col: stmt.excluded[col] for col in update_cols},
134
+ )
135
+ else:
136
+ stmt = stmt.on_conflict_do_nothing(index_elements=pk_names)
137
+
138
+ await session.execute(stmt)
112
139
 
113
140
 
114
141
  async def _batch_skip_existing(
@@ -117,6 +144,16 @@ async def _batch_skip_existing(
117
144
  instances: list[DeclarativeBase],
118
145
  ) -> list[DeclarativeBase]:
119
146
  """INSERT only rows that do not already exist; return the inserted ones."""
147
+ if len(_get_table_chain(model_cls)) > 1:
148
+ loaded: list[DeclarativeBase] = []
149
+ for inst in instances:
150
+ pk = _get_primary_key(inst)
151
+ if pk is None or not await session.get(model_cls, pk):
152
+ session.add(inst)
153
+ loaded.append(inst)
154
+ await session.flush()
155
+ return loaded
156
+
120
157
  mapper = model_cls.__mapper__
121
158
  pk_names = [col.name for col in mapper.primary_key]
122
159
 
@@ -129,7 +166,7 @@ async def _batch_skip_existing(
129
166
  else:
130
167
  with_pk_pairs.append((inst, pk))
131
168
 
132
- loaded: list[DeclarativeBase] = list(no_pk)
169
+ loaded = list(no_pk)
133
170
  if no_pk:
134
171
  no_pk_dicts = [_instance_to_dict(i) for i in no_pk]
135
172
  for group_dicts, _ in _group_by_column_set(no_pk_dicts, no_pk):
@@ -179,7 +216,7 @@ async def _load_ordered(
179
216
  if contexts is not None and not variants:
180
217
  variants = registry.get_variants(name)
181
218
 
182
- if not variants:
219
+ if not variants: # pragma: no cover
183
220
  results[name] = []
184
221
  continue
185
222
 
@@ -204,6 +241,8 @@ async def _load_ordered(
204
241
  case LoadStrategy.SKIP_EXISTING:
205
242
  inserted = await _batch_skip_existing(session, model_cls, group)
206
243
  loaded.extend(inserted)
244
+ case _: # pragma: no cover
245
+ pass
207
246
 
208
247
  results[name] = loaded
209
248
  logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
@@ -250,6 +289,31 @@ def get_obj_by_attr(
250
289
  ) from None
251
290
 
252
291
 
292
def get_field_by_attr(
    fixtures: Callable[[], Sequence[ModelType]],
    attr_name: str,
    value: Any,
    *,
    field: str = "id",
) -> Any:
    """Look up a fixture object by one attribute and return another of its fields.

    This is a thin convenience over ``get_obj_by_attr``: the first object whose
    ``attr_name`` equals ``value`` is located, then its ``field`` attribute is read.

    Args:
        fixtures: A fixture function registered via ``@registry.register``
            that returns a sequence of SQLAlchemy model instances.
        attr_name: Name of the attribute to match against.
        value: Value to match.
        field: Attribute name to return from the matched object (default: ``"id"``).

    Returns:
        The value of ``field`` on the first matching model instance.

    Raises:
        StopIteration: If no matching object is found in the fixture group
            (propagated from ``get_obj_by_attr``).
        AttributeError: If the matched object has no attribute named ``field``.
    """
    matched = get_obj_by_attr(fixtures, attr_name, value)
    return getattr(matched, field)
315
+
316
+
253
317
  async def load_fixtures(
254
318
  session: AsyncSession,
255
319
  registry: FixtureRegistry,
@@ -1,11 +1,13 @@
1
1
  """Pytest plugin for using FixtureRegistry fixtures in tests."""
2
2
 
3
3
  from collections.abc import Callable, Sequence
4
- from typing import Any
4
+ from typing import Any, cast
5
5
 
6
6
  import pytest
7
+ from sqlalchemy import select
7
8
  from sqlalchemy.ext.asyncio import AsyncSession
8
- from sqlalchemy.orm import DeclarativeBase
9
+ from sqlalchemy.orm import DeclarativeBase, selectinload
10
+ from sqlalchemy.orm.interfaces import ExecutableOption, ORMOption
9
11
 
10
12
  from ..db import get_transaction
11
13
  from ..fixtures import FixtureRegistry, LoadStrategy
@@ -112,7 +114,7 @@ def _create_fixture_function(
112
114
  elif strategy == LoadStrategy.MERGE:
113
115
  merged = await session.merge(instance)
114
116
  loaded.append(merged)
115
- elif strategy == LoadStrategy.SKIP_EXISTING:
117
+ elif strategy == LoadStrategy.SKIP_EXISTING: # pragma: no branch
116
118
  pk = _get_primary_key(instance)
117
119
  if pk is not None:
118
120
  existing = await session.get(type(instance), pk)
@@ -125,6 +127,11 @@ def _create_fixture_function(
125
127
  session.add(instance)
126
128
  loaded.append(instance)
127
129
 
130
+ if loaded: # pragma: no branch
131
+ load_options = _relationship_load_options(type(loaded[0]))
132
+ if load_options:
133
+ return await _reload_with_relationships(session, loaded, load_options)
134
+
128
135
  return loaded
129
136
 
130
137
  # Update function signature to include dependencies
@@ -141,6 +148,54 @@ def _create_fixture_function(
141
148
  return created_func
142
149
 
143
150
 
151
def _relationship_load_options(model: type[DeclarativeBase]) -> list[ExecutableOption]:
    """Return one ``selectinload`` option per relationship on *model*'s mapper.

    Only one level is covered: each option targets a relationship attribute of
    *model* itself, so relationships of the related models are not eager-loaded.
    """
    options: list[ExecutableOption] = []
    for relationship in model.__mapper__.relationships:
        attribute = getattr(model, relationship.key)
        options.append(selectinload(attribute))
    return options
156
+
157
+
158
async def _reload_with_relationships(
    session: AsyncSession,
    instances: list[DeclarativeBase],
    load_options: list[ExecutableOption],
) -> list[DeclarativeBase]:
    """Reload instances with their relationships eager-loaded, preserving order.

    For a single-column primary key this issues one ``SELECT … WHERE pk IN (…)``
    so ``selectinload`` can batch all relationship queries — 1 + N_relationships
    round-trips regardless of how many instances there are, instead of one
    ``session.get()`` per instance. Composite primary keys fall back to one
    ``session.get()`` per instance. Both paths use ``populate_existing`` so objects
    already in the identity map are refreshed from the database.

    Args:
        session: DB async session holding the instances.
        instances: Non-empty list of instances; the class of ``instances[0]`` is
            used as the query model for all of them.
        load_options: Loader options (e.g. ``selectinload``) applied to the reload.

    Returns:
        The reloaded instances, in the same order as ``instances``. On the
        composite-key path, instances whose row is not found are omitted.

    Raises:
        KeyError: On the single-key path, if an instance's row is not found.
    """
    # Flush first: primary keys of freshly added instances (e.g. server-generated
    # ids) are read below, *before* the SELECT's autoflush would assign them —
    # and autoflush may be disabled altogether.
    await session.flush()

    model = type(instances[0])
    mapper = model.__mapper__
    pk_cols = mapper.primary_key

    if len(pk_cols) == 1:
        # Column.key may differ from the mapped attribute name; ask the mapper.
        pk_key = mapper.get_property_by_column(pk_cols[0]).key
        pk_attr = getattr(model, pk_key)
        pks = [getattr(inst, pk_key) for inst in instances]
        result = await session.execute(
            select(model)
            .where(pk_attr.in_(pks))
            .options(*load_options)
            .execution_options(populate_existing=True)
        )
        by_pk = {getattr(row, pk_key): row for row in result.unique().scalars()}
        return [by_pk[pk] for pk in pks]

    # Composite PK: fall back to per-instance reload
    reloaded: list[DeclarativeBase] = []
    for instance in instances:
        pk = _get_primary_key(instance)
        refreshed = await session.get(
            model,
            pk,
            options=cast(list[ORMOption], load_options),
            populate_existing=True,
        )
        if refreshed is not None:  # pragma: no branch
            reloaded.append(refreshed)
    return reloaded
197
+
198
+
144
199
  def _get_primary_key(instance: DeclarativeBase) -> Any | None:
145
200
  """Get the primary key value of a model instance."""
146
201
  mapper = instance.__class__.__mapper__