dycw-utilities 0.129.10__py3-none-any.whl → 0.175.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (103)
  1. dycw_utilities-0.175.17.dist-info/METADATA +34 -0
  2. dycw_utilities-0.175.17.dist-info/RECORD +103 -0
  3. dycw_utilities-0.175.17.dist-info/WHEEL +4 -0
  4. dycw_utilities-0.175.17.dist-info/entry_points.txt +4 -0
  5. utilities/__init__.py +1 -1
  6. utilities/altair.py +14 -14
  7. utilities/asyncio.py +350 -819
  8. utilities/atomicwrites.py +18 -6
  9. utilities/atools.py +77 -22
  10. utilities/cachetools.py +24 -29
  11. utilities/click.py +393 -237
  12. utilities/concurrent.py +8 -11
  13. utilities/contextlib.py +216 -17
  14. utilities/contextvars.py +20 -1
  15. utilities/cryptography.py +3 -3
  16. utilities/dataclasses.py +83 -118
  17. utilities/docker.py +293 -0
  18. utilities/enum.py +26 -23
  19. utilities/errors.py +17 -3
  20. utilities/fastapi.py +29 -65
  21. utilities/fpdf2.py +3 -3
  22. utilities/functions.py +169 -416
  23. utilities/functools.py +18 -19
  24. utilities/git.py +9 -30
  25. utilities/grp.py +28 -0
  26. utilities/gzip.py +31 -0
  27. utilities/http.py +3 -2
  28. utilities/hypothesis.py +738 -589
  29. utilities/importlib.py +17 -1
  30. utilities/inflect.py +25 -0
  31. utilities/iterables.py +194 -262
  32. utilities/jinja2.py +148 -0
  33. utilities/json.py +70 -0
  34. utilities/libcst.py +38 -17
  35. utilities/lightweight_charts.py +5 -9
  36. utilities/logging.py +345 -543
  37. utilities/math.py +18 -13
  38. utilities/memory_profiler.py +11 -15
  39. utilities/more_itertools.py +200 -131
  40. utilities/operator.py +33 -29
  41. utilities/optuna.py +6 -6
  42. utilities/orjson.py +272 -137
  43. utilities/os.py +61 -4
  44. utilities/parse.py +59 -61
  45. utilities/pathlib.py +281 -40
  46. utilities/permissions.py +298 -0
  47. utilities/pickle.py +2 -2
  48. utilities/platform.py +24 -5
  49. utilities/polars.py +1214 -430
  50. utilities/polars_ols.py +1 -1
  51. utilities/postgres.py +408 -0
  52. utilities/pottery.py +113 -26
  53. utilities/pqdm.py +10 -11
  54. utilities/psutil.py +6 -57
  55. utilities/pwd.py +28 -0
  56. utilities/pydantic.py +4 -54
  57. utilities/pydantic_settings.py +240 -0
  58. utilities/pydantic_settings_sops.py +76 -0
  59. utilities/pyinstrument.py +8 -10
  60. utilities/pytest.py +227 -121
  61. utilities/pytest_plugins/__init__.py +1 -0
  62. utilities/pytest_plugins/pytest_randomly.py +23 -0
  63. utilities/pytest_plugins/pytest_regressions.py +56 -0
  64. utilities/pytest_regressions.py +26 -46
  65. utilities/random.py +13 -9
  66. utilities/re.py +58 -28
  67. utilities/redis.py +401 -550
  68. utilities/scipy.py +1 -1
  69. utilities/sentinel.py +10 -0
  70. utilities/shelve.py +4 -1
  71. utilities/shutil.py +25 -0
  72. utilities/slack_sdk.py +36 -106
  73. utilities/sqlalchemy.py +502 -473
  74. utilities/sqlalchemy_polars.py +38 -94
  75. utilities/string.py +2 -3
  76. utilities/subprocess.py +1572 -0
  77. utilities/tempfile.py +86 -4
  78. utilities/testbook.py +50 -0
  79. utilities/text.py +165 -42
  80. utilities/timer.py +37 -65
  81. utilities/traceback.py +158 -929
  82. utilities/types.py +146 -116
  83. utilities/typing.py +531 -71
  84. utilities/tzdata.py +1 -53
  85. utilities/tzlocal.py +6 -23
  86. utilities/uuid.py +43 -5
  87. utilities/version.py +27 -26
  88. utilities/whenever.py +1776 -386
  89. utilities/zoneinfo.py +84 -22
  90. dycw_utilities-0.129.10.dist-info/METADATA +0 -241
  91. dycw_utilities-0.129.10.dist-info/RECORD +0 -96
  92. dycw_utilities-0.129.10.dist-info/WHEEL +0 -4
  93. dycw_utilities-0.129.10.dist-info/licenses/LICENSE +0 -21
  94. utilities/datetime.py +0 -1409
  95. utilities/eventkit.py +0 -402
  96. utilities/loguru.py +0 -144
  97. utilities/luigi.py +0 -228
  98. utilities/period.py +0 -324
  99. utilities/pyrsistent.py +0 -89
  100. utilities/python_dotenv.py +0 -105
  101. utilities/streamlit.py +0 -105
  102. utilities/sys.py +0 -87
  103. utilities/tenacity.py +0 -145
utilities/sqlalchemy.py CHANGED
@@ -12,21 +12,31 @@ from collections.abc import (
 )
 from collections.abc import Set as AbstractSet
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field
-from functools import partial, reduce
+from dataclasses import dataclass
+from functools import reduce
 from itertools import chain
 from math import floor
 from operator import ge, le
 from re import search
-from typing import Any, Literal, TypeGuard, TypeVar, assert_never, cast, override
+from socket import gaierror
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Literal,
+    TypeGuard,
+    assert_never,
+    cast,
+    overload,
+    override,
+)
 
+import sqlalchemy
 from sqlalchemy import (
     URL,
     Column,
     Connection,
     Engine,
     Insert,
-    PrimaryKeyConstraint,
     Selectable,
     Table,
     and_,
@@ -38,16 +48,14 @@ from sqlalchemy import (
 from sqlalchemy.dialects.mssql import dialect as mssql_dialect
 from sqlalchemy.dialects.mysql import dialect as mysql_dialect
 from sqlalchemy.dialects.oracle import dialect as oracle_dialect
-from sqlalchemy.dialects.postgresql import Insert as postgresql_Insert
 from sqlalchemy.dialects.postgresql import dialect as postgresql_dialect
 from sqlalchemy.dialects.postgresql import insert as postgresql_insert
 from sqlalchemy.dialects.postgresql.asyncpg import PGDialect_asyncpg
-from sqlalchemy.dialects.sqlite import Insert as sqlite_Insert
+from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg
 from sqlalchemy.dialects.sqlite import dialect as sqlite_dialect
 from sqlalchemy.dialects.sqlite import insert as sqlite_insert
 from sqlalchemy.exc import ArgumentError, DatabaseError
-from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
-from sqlalchemy.ext.asyncio import create_async_engine as _create_async_engine
+from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
 from sqlalchemy.orm import (
     DeclarativeBase,
     InstrumentedAttribute,
@@ -57,17 +65,8 @@ from sqlalchemy.orm import (
 from sqlalchemy.orm.exc import UnmappedClassError
 from sqlalchemy.pool import NullPool, Pool
 
-from utilities.asyncio import Looper, timeout_dur
-from utilities.contextlib import suppress_super_object_attribute_error
-from utilities.datetime import SECOND
-from utilities.functions import (
-    ensure_str,
-    get_class_name,
-    is_sequence_of_tuple_or_str_mapping,
-    is_string_mapping,
-    is_tuple,
-    is_tuple_or_str_mapping,
-)
+from utilities.asyncio import timeout_td
+from utilities.functions import ensure_str, get_class_name, yield_object_attributes
 from utilities.iterables import (
     CheckLengthError,
     CheckSubSetError,
@@ -80,16 +79,63 @@ from utilities.iterables import (
     merge_str_mappings,
     one,
 )
+from utilities.os import is_pytest
 from utilities.reprlib import get_repr
-from utilities.text import snake_case
-from utilities.types import Duration, MaybeIterable, StrMapping, TupleOrStrMapping
+from utilities.text import secret_str, snake_case
+from utilities.types import (
+    Delta,
+    MaybeIterable,
+    MaybeType,
+    StrMapping,
+    TupleOrStrMapping,
+)
+from utilities.typing import (
+    is_sequence_of_tuple_or_str_mapping,
+    is_string_mapping,
+    is_tuple,
+    is_tuple_or_str_mapping,
+)
+
+if TYPE_CHECKING:
+    from enum import Enum, StrEnum
+
 
-_T = TypeVar("_T")
-type _EngineOrConnectionOrAsync = Engine | Connection | AsyncEngine | AsyncConnection
+type EngineOrConnectionOrAsync = Engine | Connection | AsyncEngine | AsyncConnection
 type Dialect = Literal["mssql", "mysql", "oracle", "postgresql", "sqlite"]
+type DialectOrEngineOrConnectionOrAsync = Dialect | EngineOrConnectionOrAsync
 type ORMInstOrClass = DeclarativeBase | type[DeclarativeBase]
 type TableOrORMInstOrClass = Table | ORMInstOrClass
-CHUNK_SIZE_FRAC = 0.95
+CHUNK_SIZE_FRAC = 0.8
+
+
+##
+
+
+_SELECT = text("SELECT 1")
+
+
+def check_connect(engine: Engine, /) -> bool:
+    """Check if an engine can connect."""
+    try:
+        with engine.connect() as conn:
+            return bool(conn.execute(_SELECT).scalar_one())
+    except (gaierror, ConnectionRefusedError, DatabaseError):  # pragma: no cover
+        return False
+
+
+async def check_connect_async(
+    engine: AsyncEngine,
+    /,
+    *,
+    timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
+) -> bool:
+    """Check if an engine can connect."""
+    try:
+        async with timeout_td(timeout, error=error), engine.connect() as conn:
+            return bool((await conn.execute(_SELECT)).scalar_one())
+    except (gaierror, ConnectionRefusedError, DatabaseError, TimeoutError):
+        return False
 
 
 ##
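The new check_connect/check_connect_async helpers probe an engine with SELECT 1 and report reachability as a bool rather than raising. A minimal usage sketch (not part of the diff; the in-memory SQLite engine is illustrative):

    import sqlalchemy

    from utilities.sqlalchemy import check_connect

    # An in-memory SQLite engine is always reachable, so this returns True;
    # DNS failures, refused connections and database errors yield False.
    engine = sqlalchemy.create_engine("sqlite://")
    assert check_connect(engine)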
@@ -99,8 +145,8 @@ async def check_engine(
     engine: AsyncEngine,
     /,
     *,
-    timeout: Duration | None = None,
-    error: type[Exception] = TimeoutError,
+    timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
     num_tables: int | tuple[int, float] | None = None,
 ) -> None:
     """Check that an engine can connect.
@@ -115,7 +161,7 @@
             query = "select * from all_objects"
         case "sqlite":
             query = "select * from sqlite_master where type='table'"
-        case _ as never:
+        case never:
             assert_never(never)
     statement = text(query)
     async with yield_connection(engine, timeout=timeout, error=error) as conn:
@@ -183,7 +229,36 @@ def _columnwise_minmax(*columns: Any, op: Callable[[Any, Any], Any]) -> Any:
 ##
 
 
-def create_async_engine(
+@overload
+def create_engine(
+    drivername: str,
+    /,
+    *,
+    username: str | None = None,
+    password: str | None = None,
+    host: str | None = None,
+    port: int | None = None,
+    database: str | None = None,
+    query: StrMapping | None = None,
+    poolclass: type[Pool] | None = NullPool,
+    async_: Literal[True],
+) -> AsyncEngine: ...
+@overload
+def create_engine(
+    drivername: str,
+    /,
+    *,
+    username: str | None = None,
+    password: str | None = None,
+    host: str | None = None,
+    port: int | None = None,
+    database: str | None = None,
+    query: StrMapping | None = None,
+    poolclass: type[Pool] | None = NullPool,
+    async_: Literal[False] = False,
+) -> Engine: ...
+@overload
+def create_engine(
     drivername: str,
     /,
     *,
@@ -194,7 +269,21 @@ def create_async_engine(
     database: str | None = None,
     query: StrMapping | None = None,
     poolclass: type[Pool] | None = NullPool,
-) -> AsyncEngine:
+    async_: bool = False,
+) -> Engine | AsyncEngine: ...
+def create_engine(
+    drivername: str,
+    /,
+    *,
+    username: str | None = None,
+    password: str | None = None,
+    host: str | None = None,
+    port: int | None = None,
+    database: str | None = None,
+    query: StrMapping | None = None,
+    poolclass: type[Pool] | None = NullPool,
+    async_: bool = False,
+) -> Engine | AsyncEngine:
     """Create a SQLAlchemy engine."""
     if query is None:
         kwargs = {}
@@ -213,7 +302,47 @@
         database=database,
         **kwargs,
     )
-    return _create_async_engine(url, poolclass=poolclass)
+    match async_:
+        case False:
+            return sqlalchemy.create_engine(url, poolclass=poolclass)
+        case True:
+            return create_async_engine(url, poolclass=poolclass)
+        case never:
+            assert_never(never)
+
+
+##
+
+
+async def ensure_database_created(super_: URL, database: str, /) -> None:
+    """Ensure a database is created."""
+    engine = create_async_engine(super_, isolation_level="AUTOCOMMIT")
+    async with engine.begin() as conn:
+        try:
+            _ = await conn.execute(text(f"CREATE DATABASE {database}"))
+        except DatabaseError as error:
+            _ensure_tables_maybe_reraise(error, 'database ".*" already exists')
+
+
+async def ensure_database_dropped(super_: URL, database: str, /) -> None:
+    """Ensure a database is dropped."""
+    engine = create_async_engine(super_, isolation_level="AUTOCOMMIT")
+    async with engine.begin() as conn:
+        _ = await conn.execute(text(f"DROP DATABASE IF EXISTS {database}"))
+
+
+async def ensure_database_users_disconnected(super_: URL, database: str, /) -> None:
+    """Ensure a databases' users are disconnected."""
+    engine = create_async_engine(super_, isolation_level="AUTOCOMMIT")
+    match dialect := _get_dialect(engine):
+        case "postgresql":  # skipif-ci-and-not-linux
+            query = f"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = {database!r} AND pid <> pg_backend_pid()"  # noqa: S608
+        case "mssql" | "mysql" | "oracle" | "sqlite":  # pragma: no cover
+            raise NotImplementedError(dialect)
+        case never:
+            assert_never(never)
+    async with engine.begin() as conn:
+        _ = await conn.execute(text(query))
 
 
 ##
@@ -223,8 +352,8 @@ async def ensure_tables_created(
     engine: AsyncEngine,
     /,
     *tables_or_orms: TableOrORMInstOrClass,
-    timeout: Duration | None = None,
-    error: type[Exception] = TimeoutError,
+    timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
 ) -> None:
     """Ensure a table/set of tables is/are created."""
     tables = set(map(get_table, tables_or_orms))
@@ -239,7 +368,7 @@
             match = "ORA-00955: name is already used by an existing object"
         case "sqlite":
             match = "table .* already exists"
-        case _ as never:
+        case never:
             assert_never(never)
     async with yield_connection(engine, timeout=timeout, error=error) as conn:
         for table in tables:
@@ -252,8 +381,8 @@
 async def ensure_tables_dropped(
     engine: AsyncEngine,
     *tables_or_orms: TableOrORMInstOrClass,
-    timeout: Duration | None = None,
-    error: type[Exception] = TimeoutError,
+    timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
 ) -> None:
     """Ensure a table/set of tables is/are dropped."""
     tables = set(map(get_table, tables_or_orms))
@@ -268,7 +397,7 @@
             match = "ORA-00942: table or view does not exist"
         case "sqlite":
             match = "no such table"
-        case _ as never:
+        case never:
             assert_never(never)
     async with yield_connection(engine, timeout=timeout, error=error) as conn:
         for table in tables:
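Throughout the module, the timeout/error parameters move from Duration/type[Exception] to Delta/MaybeType[BaseException]. Judging by the timeout_dur -> timeout_td rename, Delta is a whenever-style delta; a sketch under that assumption (the aiosqlite driver and the table definition are illustrative):

    from sqlalchemy import Column, Integer, MetaData, Table
    from whenever import TimeDelta

    from utilities.sqlalchemy import create_engine, ensure_tables_created

    table = Table(
        "example",
        MetaData(),
        Column("id", Integer, primary_key=True),
        Column("value", Integer),
    )

    async def main() -> None:
        engine = create_engine("sqlite+aiosqlite", async_=True)
        # Give up after ten seconds; `error` may now be an exception
        # instance as well as a class (MaybeType[BaseException]).
        await ensure_tables_created(engine, table, timeout=TimeDelta(seconds=10))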
@@ -281,16 +410,120 @@
 ##
 
 
+def enum_name(enum: type[Enum], /) -> str:
+    """Get the name of an Enum."""
+    return f"{snake_case(get_class_name(enum))}_enum"
+
+
+##
+
+
+def enum_values(enum: type[StrEnum], /) -> list[str]:
+    """Get the values of a StrEnum."""
+    return [e.value for e in enum]
+
+
+##
+
+
+@dataclass(kw_only=True, slots=True)
+class ExtractURLOutput:
+    username: str
+    password: secret_str
+    host: str
+    port: int
+    database: str
+
+
+def extract_url(url: URL, /) -> ExtractURLOutput:
+    """Extract the database, host & port from a URL."""
+    if url.username is None:
+        raise _ExtractURLUsernameError(url=url)
+    if url.password is None:
+        raise _ExtractURLPasswordError(url=url)
+    if url.host is None:
+        raise _ExtractURLHostError(url=url)
+    if url.port is None:
+        raise _ExtractURLPortError(url=url)
+    if url.database is None:
+        raise _ExtractURLDatabaseError(url=url)
+    return ExtractURLOutput(
+        username=url.username,
+        password=secret_str(url.password),
+        host=url.host,
+        port=url.port,
+        database=url.database,
+    )
+
+
+@dataclass(kw_only=True, slots=True)
+class ExtractURLError(Exception):
+    url: URL
+
+
+@dataclass(kw_only=True, slots=True)
+class _ExtractURLUsernameError(ExtractURLError):
+    @override
+    def __str__(self) -> str:
+        return f"Expected URL to contain a user name; got {self.url}"
+
+
+@dataclass(kw_only=True, slots=True)
+class _ExtractURLPasswordError(ExtractURLError):
+    @override
+    def __str__(self) -> str:
+        return f"Expected URL to contain a password; got {self.url}"
+
+
+@dataclass(kw_only=True, slots=True)
+class _ExtractURLHostError(ExtractURLError):
+    @override
+    def __str__(self) -> str:
+        return f"Expected URL to contain a host; got {self.url}"
+
+
+@dataclass(kw_only=True, slots=True)
+class _ExtractURLPortError(ExtractURLError):
+    @override
+    def __str__(self) -> str:
+        return f"Expected URL to contain a port; got {self.url}"
+
+
+@dataclass(kw_only=True, slots=True)
+class _ExtractURLDatabaseError(ExtractURLError):
+    @override
+    def __str__(self) -> str:
+        return f"Expected URL to contain a database; got {self.url}"
+
+
+##
+
+
 def get_chunk_size(
-    engine_or_conn: _EngineOrConnectionOrAsync,
+    dialect_or_engine_or_conn: DialectOrEngineOrConnectionOrAsync,
+    table_or_orm_or_num_cols: TableOrORMInstOrClass | Sized | int,
     /,
     *,
     chunk_size_frac: float = CHUNK_SIZE_FRAC,
-    scaling: float = 1.0,
 ) -> int:
     """Get the maximum chunk size for an engine."""
-    max_params = _get_dialect_max_params(engine_or_conn)
-    return max(floor(chunk_size_frac * max_params / scaling), 1)
+    max_params = _get_dialect_max_params(dialect_or_engine_or_conn)
+    match table_or_orm_or_num_cols:
+        case Table() | DeclarativeBase() | type() as table_or_orm:
+            return get_chunk_size(
+                dialect_or_engine_or_conn,
+                get_columns(table_or_orm),
+                chunk_size_frac=chunk_size_frac,
+            )
+        case Sized() as sized:
+            return get_chunk_size(
+                dialect_or_engine_or_conn, len(sized), chunk_size_frac=chunk_size_frac
+            )
+        case int() as num_cols:
+            size = floor(chunk_size_frac * max_params / num_cols)
+            return max(size, 1)
+        case never:
+            assert_never(never)
 
 
 ##
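extract_url refuses URLs with any missing component and wraps the password in secret_str; get_chunk_size now divides the dialect's parameter cap by an explicit column count (derived from a table/ORM or sized collection when one is passed instead). A usage sketch with placeholder credentials:

    from sqlalchemy import URL

    from utilities.sqlalchemy import extract_url, get_chunk_size

    url = URL.create(
        "postgresql",
        username="user",  # placeholders; a missing component raises the
        password="password",  # matching ExtractURLError subclass
        host="localhost",
        port=5432,
        database="db",
    )
    out = extract_url(url)
    assert out.port == 5432

    # Chunk size for a 10-column insert against SQLite's parameter cap;
    # the result is floored at 1.
    chunk_size = get_chunk_size("sqlite", 10)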
@@ -370,18 +603,22 @@ type _InsertItem = (
     | Sequence[_PairOfTupleOrStrMappingAndTable]
     | Sequence[DeclarativeBase]
 )
+type _NormalizedItem = tuple[Table, StrMapping]
+type _InsertPair = tuple[Table, Sequence[StrMapping]]
+type _InsertTriple = tuple[Table, Insert, Sequence[StrMapping] | None]
 
 
 async def insert_items(
     engine: AsyncEngine,
     *items: _InsertItem,
     snake: bool = False,
+    is_upsert: bool = False,
     chunk_size_frac: float = CHUNK_SIZE_FRAC,
     assume_tables_exist: bool = False,
-    timeout_create: Duration | None = None,
-    error_create: type[Exception] = TimeoutError,
-    timeout_insert: Duration | None = None,
-    error_insert: type[Exception] = TimeoutError,
+    timeout_create: Delta | None = None,
+    error_create: MaybeType[BaseException] = TimeoutError,
+    timeout_insert: Delta | None = None,
+    error_insert: MaybeType[BaseException] = TimeoutError,
 ) -> None:
     """Insert a set of items into a database.
 
@@ -405,37 +642,181 @@
                                Obj(k1=v21, k2=v22, ...),
                                ...]
     """
-
-    def build_insert(
-        table: Table, values: Iterable[StrMapping], /
-    ) -> tuple[Insert, Any]:
-        match _get_dialect(engine):
-            case "oracle":  # pragma: no cover
-                return insert(table), values
-            case _:
-                return insert(table).values(list(values)), None
-
-    try:
-        prepared = _prepare_insert_or_upsert_items(
-            partial(_normalize_insert_item, snake=snake),
-            engine,
-            build_insert,
-            *items,
-            chunk_size_frac=chunk_size_frac,
-        )
-    except _PrepareInsertOrUpsertItemsError as error:
-        raise InsertItemsError(item=error.item) from None
+    normalized = chain.from_iterable(
+        _insert_items_yield_normalized(i, snake=snake) for i in items
+    )
+    triples = _insert_items_yield_triples(
+        engine, normalized, is_upsert=is_upsert, chunk_size_frac=chunk_size_frac
+    )
     if not assume_tables_exist:
+        triples = list(triples)
+        tables = {table for table, _, _ in triples}
         await ensure_tables_created(
-            engine, *prepared.tables, timeout=timeout_create, error=error_create
+            engine, *tables, timeout=timeout_create, error=error_create
         )
-    for ins, parameters in prepared.yield_pairs():
+    for _, ins, parameters in triples:
         async with yield_connection(
             engine, timeout=timeout_insert, error=error_insert
         ) as conn:
             _ = await conn.execute(ins, parameters=parameters)
 
 
+def _insert_items_yield_normalized(
+    item: _InsertItem, /, *, snake: bool = False
+) -> Iterator[_NormalizedItem]:
+    if _is_pair_of_str_mapping_and_table(item):
+        mapping, table_or_orm = item
+        adjusted = _map_mapping_to_table(mapping, table_or_orm, snake=snake)
+        yield (get_table(table_or_orm), adjusted)
+        return
+    if _is_pair_of_tuple_and_table(item):
+        tuple_, table_or_orm = item
+        mapping = _tuple_to_mapping(tuple_, table_or_orm)
+        yield from _insert_items_yield_normalized((mapping, table_or_orm), snake=snake)
+        return
+    if _is_pair_of_sequence_of_tuple_or_string_mapping_and_table(item):
+        items, table_or_orm = item
+        pairs = [(i, table_or_orm) for i in items]
+        for p in pairs:
+            yield from _insert_items_yield_normalized(p, snake=snake)
+        return
+    if isinstance(item, DeclarativeBase):
+        mapping = _orm_inst_to_dict(item)
+        yield from _insert_items_yield_normalized((mapping, item), snake=snake)
+        return
+    try:
+        _ = iter(item)
+    except TypeError:
+        raise InsertItemsError(item=item) from None
+    if all(map(_is_pair_of_tuple_or_str_mapping_and_table, item)):
+        pairs = cast("Sequence[_PairOfTupleOrStrMappingAndTable]", item)
+        for p in pairs:
+            yield from _insert_items_yield_normalized(p, snake=snake)
+        return
+    if all(map(is_orm, item)):
+        classes = cast("Sequence[DeclarativeBase]", item)
+        for c in classes:
+            yield from _insert_items_yield_normalized(c, snake=snake)
+        return
+    raise InsertItemsError(item=item)
+
+
+def _insert_items_yield_triples(
+    engine: AsyncEngine,
+    items: Iterable[_NormalizedItem],
+    /,
+    *,
+    is_upsert: bool = False,
+    chunk_size_frac: float = CHUNK_SIZE_FRAC,
+) -> Iterable[_InsertTriple]:
+    pairs = _insert_items_yield_chunked_pairs(
+        engine, items, is_upsert=is_upsert, chunk_size_frac=chunk_size_frac
+    )
+    for table, mappings in pairs:
+        match is_upsert, _get_dialect(engine):
+            case False, "oracle":  # pragma: no cover
+                ins = insert(table)
+                parameters = mappings
+            case False, _:
+                ins = insert(table).values(mappings)
+                parameters = None
+            case True, _:
+                ins = _insert_items_build_insert_with_on_conflict_do_update(
+                    engine, table, mappings
+                )
+                parameters = None
+            case never:
+                assert_never(never)
+        yield table, ins, parameters
+
+
+def _insert_items_yield_chunked_pairs(
+    engine: AsyncEngine,
+    items: Iterable[_NormalizedItem],
+    /,
+    *,
+    is_upsert: bool = False,
+    chunk_size_frac: float = CHUNK_SIZE_FRAC,
+) -> Iterable[_InsertPair]:
+    for table, mappings in _insert_items_yield_raw_pairs(items, is_upsert=is_upsert):
+        chunk_size = get_chunk_size(engine, table, chunk_size_frac=chunk_size_frac)
+        for mappings_i in chunked(mappings, chunk_size):
+            yield table, list(mappings_i)
+
+
+def _insert_items_yield_raw_pairs(
+    items: Iterable[_NormalizedItem], /, *, is_upsert: bool = False
+) -> Iterable[_InsertPair]:
+    by_table: defaultdict[Table, list[StrMapping]] = defaultdict(list)
+    for table, mapping in items:
+        by_table[table].append(mapping)
+    for table, mappings in by_table.items():
+        yield from _insert_items_yield_raw_pairs_one(
+            table, mappings, is_upsert=is_upsert
+        )
+
+
+def _insert_items_yield_raw_pairs_one(
+    table: Table, mappings: Iterable[StrMapping], /, *, is_upsert: bool = False
+) -> Iterable[_InsertPair]:
+    merged = _insert_items_yield_merged_mappings(table, mappings)
+    match is_upsert:
+        case True:
+            by_keys: defaultdict[frozenset[str], list[StrMapping]] = defaultdict(list)
+            for mapping in merged:
+                non_null = {k: v for k, v in mapping.items() if v is not None}
+                by_keys[frozenset(non_null)].append(non_null)
+            for mappings_i in by_keys.values():
+                yield table, mappings_i
+        case False:
+            yield table, list(merged)
+        case never:
+            assert_never(never)
+
+
+def _insert_items_yield_merged_mappings(
+    table: Table, mappings: Iterable[StrMapping], /
+) -> Iterable[StrMapping]:
+    columns = list(yield_primary_key_columns(table))
+    col_names = [c.name for c in columns]
+    cols_auto = {c.name for c in columns if c.autoincrement in {True, "auto"}}
+    cols_non_auto = set(col_names) - cols_auto
+    by_key: defaultdict[tuple[Hashable, ...], list[StrMapping]] = defaultdict(list)
+    for mapping in mappings:
+        check_subset(cols_non_auto, mapping)
+        has_all_auto = set(cols_auto).issubset(mapping)
+        if has_all_auto:
+            pkey = tuple(mapping[k] for k in col_names)
+            rest: StrMapping = {k: v for k, v in mapping.items() if k not in col_names}
+            by_key[pkey].append(rest)
+        else:
+            yield mapping
+    for k, v in by_key.items():
+        head = dict(zip(col_names, k, strict=True))
+        yield merge_str_mappings(head, *v)
+
+
+def _insert_items_build_insert_with_on_conflict_do_update(
+    engine: AsyncEngine, table: Table, mappings: Iterable[StrMapping], /
+) -> Insert:
+    primary_key = cast("Any", table.primary_key)
+    mappings = list(mappings)
+    columns = merge_sets(*mappings)
+    match _get_dialect(engine):
+        case "postgresql":  # skipif-ci-and-not-linux
+            ins = postgresql_insert(table).values(mappings)
+            set_ = {c: getattr(ins.excluded, c) for c in columns}
+            return ins.on_conflict_do_update(constraint=primary_key, set_=set_)
+        case "sqlite":
+            ins = sqlite_insert(table).values(mappings)
+            set_ = {c: getattr(ins.excluded, c) for c in columns}
+            return ins.on_conflict_do_update(index_elements=primary_key, set_=set_)
+        case "mssql" | "mysql" | "oracle" as dialect:  # pragma: no cover
+            raise NotImplementedError(dialect)
+        case never:
+            assert_never(never)
+
+
 @dataclass(kw_only=True, slots=True)
 class InsertItemsError(Exception):
     item: _InsertItem
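upsert_items (removed further down, along with UpsertService and UpsertServiceMixin) is folded into insert_items as the is_upsert flag: non-None columns are grouped by key set and written via ON CONFLICT DO UPDATE on PostgreSQL/SQLite, while the other dialects raise NotImplementedError. A sketch of the replacement call, reusing the engine and table placeholders from the earlier sketch:

    from utilities.sqlalchemy import insert_items

    async def upsert_row() -> None:
        # Previously: await upsert_items(engine, ({"id": 1, "value": 2}, table))
        await insert_items(engine, ({"id": 1, "value": 2}, table), is_upsert=True)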
@@ -479,10 +860,10 @@ async def migrate_data(
     table_or_orm_to: TableOrORMInstOrClass | None = None,
     chunk_size_frac: float = CHUNK_SIZE_FRAC,
     assume_tables_exist: bool = False,
-    timeout_create: Duration | None = None,
-    error_create: type[Exception] = TimeoutError,
-    timeout_insert: Duration | None = None,
-    error_insert: type[Exception] = TimeoutError,
+    timeout_create: Delta | None = None,
+    error_create: MaybeType[BaseException] = TimeoutError,
+    timeout_insert: Delta | None = None,
+    error_insert: MaybeType[BaseException] = TimeoutError,
 ) -> None:
     """Migrate the contents of a table from one database to another."""
     table_from = get_table(table_or_orm_from)
@@ -506,82 +887,8 @@
 ##
 
 
-def _normalize_insert_item(
-    item: _InsertItem, /, *, snake: bool = False
-) -> list[_NormalizedItem]:
-    """Normalize an insertion item."""
-    if _is_pair_of_str_mapping_and_table(item):
-        mapping, table_or_orm = item
-        adjusted = _map_mapping_to_table(mapping, table_or_orm, snake=snake)
-        normalized = _NormalizedItem(mapping=adjusted, table=get_table(table_or_orm))
-        return [normalized]
-    if _is_pair_of_tuple_and_table(item):
-        tuple_, table_or_orm = item
-        mapping = _tuple_to_mapping(tuple_, table_or_orm)
-        return _normalize_insert_item((mapping, table_or_orm), snake=snake)
-    if _is_pair_of_sequence_of_tuple_or_string_mapping_and_table(item):
-        items, table_or_orm = item
-        pairs = [(i, table_or_orm) for i in items]
-        normalized = (_normalize_insert_item(p, snake=snake) for p in pairs)
-        return list(chain.from_iterable(normalized))
-    if isinstance(item, DeclarativeBase):
-        mapping = _orm_inst_to_dict(item)
-        return _normalize_insert_item((mapping, item), snake=snake)
-    try:
-        _ = iter(item)
-    except TypeError:
-        raise _NormalizeInsertItemError(item=item) from None
-    if all(map(_is_pair_of_tuple_or_str_mapping_and_table, item)):
-        seq = cast("Sequence[_PairOfTupleOrStrMappingAndTable]", item)
-        normalized = (_normalize_insert_item(p, snake=snake) for p in seq)
-        return list(chain.from_iterable(normalized))
-    if all(map(is_orm, item)):
-        seq = cast("Sequence[DeclarativeBase]", item)
-        normalized = (_normalize_insert_item(p, snake=snake) for p in seq)
-        return list(chain.from_iterable(normalized))
-    raise _NormalizeInsertItemError(item=item)
-
-
-@dataclass(kw_only=True, slots=True)
-class _NormalizeInsertItemError(Exception):
-    item: _InsertItem
-
-    @override
-    def __str__(self) -> str:
-        return f"Item must be valid; got {self.item}"
-
-
-@dataclass(kw_only=True, slots=True)
-class _NormalizedItem:
-    mapping: StrMapping
-    table: Table
-
-
-def _normalize_upsert_item(
-    item: _InsertItem,
-    /,
-    *,
-    snake: bool = False,
-    selected_or_all: Literal["selected", "all"] = "selected",
-) -> Iterator[_NormalizedItem]:
-    """Normalize an upsert item."""
-    normalized = _normalize_insert_item(item, snake=snake)
-    match selected_or_all:
-        case "selected":
-            for norm in normalized:
-                values = {k: v for k, v in norm.mapping.items() if v is not None}
-                yield _NormalizedItem(mapping=values, table=norm.table)
-        case "all":
-            yield from normalized
-        case _ as never:
-            assert_never(never)
-
-
-##
-
-
 def selectable_to_string(
-    selectable: Selectable[Any], engine_or_conn: _EngineOrConnectionOrAsync, /
+    selectable: Selectable[Any], engine_or_conn: EngineOrConnectionOrAsync, /
 ) -> str:
     """Convert a selectable into a string."""
     com = selectable.compile(
@@ -604,237 +911,22 @@ class TablenameMixin:
 ##
 
 
-@dataclass(kw_only=True)
-class UpsertService(Looper[_InsertItem]):
-    """Service to upsert items to a database."""
-
-    # base
-    freq: Duration = field(default=SECOND, repr=False)
-    backoff: Duration = field(default=SECOND, repr=False)
-    empty_upon_exit: bool = field(default=True, repr=False)
-    # self
-    engine: AsyncEngine
-    snake: bool = False
-    selected_or_all: _SelectedOrAll = "selected"
-    chunk_size_frac: float = CHUNK_SIZE_FRAC
-    assume_tables_exist: bool = False
-    timeout_create: Duration | None = None
-    error_create: type[Exception] = TimeoutError
-    timeout_insert: Duration | None = None
-    error_insert: type[Exception] = TimeoutError
-
-    @override
-    async def core(self) -> None:
-        await super().core()
-        await upsert_items(
-            self.engine,
-            *self.get_all_nowait(),
-            snake=self.snake,
-            selected_or_all=self.selected_or_all,
-            chunk_size_frac=self.chunk_size_frac,
-            assume_tables_exist=self.assume_tables_exist,
-            timeout_create=self.timeout_create,
-            error_create=self.error_create,
-            timeout_insert=self.timeout_insert,
-            error_insert=self.error_insert,
-        )
-
-
-@dataclass(kw_only=True)
-class UpsertServiceMixin:
-    """Mix-in for the upsert service."""
-
-    # base - looper
-    upsert_service_freq: Duration = field(default=SECOND, repr=False)
-    upsert_service_backoff: Duration = field(default=SECOND, repr=False)
-    upsert_service_empty_upon_exit: bool = field(default=False, repr=False)
-    upsert_service_logger: str | None = field(default=None, repr=False)
-    upsert_service_timeout: Duration | None = field(default=None, repr=False)
-    upsert_service_debug: bool = field(default=False, repr=False)
-    # base - upsert service
-    upsert_service_database: AsyncEngine
-    upsert_service_snake: bool = False
-    upsert_service_selected_or_all: _SelectedOrAll = "selected"
-    upsert_service_chunk_size_frac: float = CHUNK_SIZE_FRAC
-    upsert_service_assume_tables_exist: bool = False
-    upsert_service_timeout_create: Duration | None = None
-    upsert_service_error_create: type[Exception] = TimeoutError
-    upsert_service_timeout_insert: Duration | None = None
-    upsert_service_error_insert: type[Exception] = TimeoutError
-    # self
-    _upsert_service: UpsertService = field(init=False, repr=False)
-
-    def __post_init__(self) -> None:
-        with suppress_super_object_attribute_error():
-            super().__post_init__()  # pyright: ignore[reportAttributeAccessIssue]
-        self._upsert_service = UpsertService(
-            # looper
-            freq=self.upsert_service_freq,
-            backoff=self.upsert_service_backoff,
-            empty_upon_exit=self.upsert_service_empty_upon_exit,
-            logger=self.upsert_service_logger,
-            timeout=self.upsert_service_timeout,
-            _debug=self.upsert_service_debug,
-            # upsert service
-            engine=self.upsert_service_database,
-            snake=self.upsert_service_snake,
-            selected_or_all=self.upsert_service_selected_or_all,
-            chunk_size_frac=self.upsert_service_chunk_size_frac,
-            assume_tables_exist=self.upsert_service_assume_tables_exist,
-            timeout_create=self.upsert_service_timeout_create,
-            error_create=self.upsert_service_error_create,
-            timeout_insert=self.upsert_service_timeout_insert,
-            error_insert=self.upsert_service_error_insert,
-        )
-
-    def _yield_sub_loopers(self) -> Iterator[Looper[Any]]:
-        with suppress_super_object_attribute_error():
-            yield from super()._yield_sub_loopers()  # pyright: ignore[reportAttributeAccessIssue]
-        yield self._upsert_service
-
-
-##
-
-
-type _SelectedOrAll = Literal["selected", "all"]
-
-
-async def upsert_items(
-    engine: AsyncEngine,
-    /,
-    *items: _InsertItem,
-    snake: bool = False,
-    selected_or_all: _SelectedOrAll = "selected",
-    chunk_size_frac: float = CHUNK_SIZE_FRAC,
-    assume_tables_exist: bool = False,
-    timeout_create: Duration | None = None,
-    error_create: type[Exception] = TimeoutError,
-    timeout_insert: Duration | None = None,
-    error_insert: type[Exception] = TimeoutError,
-) -> None:
-    """Upsert a set of items into a database.
-
-    These can be one of the following:
-    - pair of dict & table/class: {k1=v1, k2=v2, ...), table_cls
-    - pair of list of dicts & table/class: [{k1=v11, k2=v12, ...},
-                                            {k1=v21, k2=v22, ...},
-                                            ...], table/class
-    - list of pairs of dict & table/class: [({k1=v11, k2=v12, ...}, table_cls1),
-                                            ({k1=v21, k2=v22, ...}, table_cls2),
-                                            ...]
-    - mapped class: Obj(k1=v1, k2=v2, ...)
-    - list of mapped classes: [Obj(k1=v11, k2=v12, ...),
-                               Obj(k1=v21, k2=v22, ...),
-                               ...]
-    """
-
-    def build_insert(
-        table: Table, values: Iterable[StrMapping], /
-    ) -> tuple[Insert, None]:
-        ups = _upsert_items_build(
-            engine, table, values, selected_or_all=selected_or_all
-        )
-        return ups, None
-
-    try:
-        prepared = _prepare_insert_or_upsert_items(
-            partial(
-                _normalize_upsert_item, snake=snake, selected_or_all=selected_or_all
-            ),
-            engine,
-            build_insert,
-            *items,
-            chunk_size_frac=chunk_size_frac,
-        )
-    except _PrepareInsertOrUpsertItemsError as error:
-        raise UpsertItemsError(item=error.item) from None
-    if not assume_tables_exist:
-        await ensure_tables_created(
-            engine, *prepared.tables, timeout=timeout_create, error=error_create
-        )
-    for ups, _ in prepared.yield_pairs():
-        async with yield_connection(
-            engine, timeout=timeout_insert, error=error_insert
-        ) as conn:
-            _ = await conn.execute(ups)
-
-
-def _upsert_items_build(
-    engine: AsyncEngine,
-    table: Table,
-    values: Iterable[StrMapping],
-    /,
-    *,
-    selected_or_all: Literal["selected", "all"] = "selected",
-) -> Insert:
-    values = list(values)
-    keys = merge_sets(*values)
-    dict_nones = dict.fromkeys(keys)
-    values = [{**dict_nones, **v} for v in values]
-    match _get_dialect(engine):
-        case "postgresql":  # skipif-ci-and-not-linux
-            insert = postgresql_insert
-        case "sqlite":
-            insert = sqlite_insert
-        case "mssql" | "mysql" | "oracle" as dialect:  # pragma: no cover
-            raise NotImplementedError(dialect)
-        case _ as never:
-            assert_never(never)
-    ins = insert(table).values(values)
-    primary_key = cast("Any", table.primary_key)
-    return _upsert_items_apply_on_conflict_do_update(
-        values, ins, primary_key, selected_or_all=selected_or_all
-    )
-
-
-def _upsert_items_apply_on_conflict_do_update(
-    values: Iterable[StrMapping],
-    insert: postgresql_Insert | sqlite_Insert,
-    primary_key: PrimaryKeyConstraint,
-    /,
-    *,
-    selected_or_all: Literal["selected", "all"] = "selected",
-) -> Insert:
-    match selected_or_all:
-        case "selected":
-            columns = merge_sets(*values)
-        case "all":
-            columns = {c.name for c in insert.excluded}
-        case _ as never:
-            assert_never(never)
-    set_ = {c: getattr(insert.excluded, c) for c in columns}
-    match insert:
-        case postgresql_Insert():  # skipif-ci
-            return insert.on_conflict_do_update(constraint=primary_key, set_=set_)
-        case sqlite_Insert():
-            return insert.on_conflict_do_update(index_elements=primary_key, set_=set_)
-        case _ as never:
-            assert_never(never)
-
-
-@dataclass(kw_only=True, slots=True)
-class UpsertItemsError(Exception):
-    item: _InsertItem
-
-    @override
-    def __str__(self) -> str:
-        return f"Item must be valid; got {self.item}"
-
-
-##
-
-
 @asynccontextmanager
 async def yield_connection(
     engine: AsyncEngine,
     /,
     *,
-    timeout: Duration | None = None,
-    error: type[Exception] = TimeoutError,
+    timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
 ) -> AsyncIterator[AsyncConnection]:
     """Yield an async connection."""
-    async with timeout_dur(duration=timeout, error=error), engine.begin() as conn:
-        yield conn
+    try:
+        async with timeout_td(timeout, error=error), engine.begin() as conn:
+            yield conn
+    except GeneratorExit:  # pragma: no cover
+        if not is_pytest():
+            raise
+        return
 
 
 ##
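yield_connection keeps its shape but adopts the same Delta/MaybeType parameters and, under pytest, swallows GeneratorExit on teardown. A usage sketch under the same whenever-delta assumption as above:

    from sqlalchemy import text
    from whenever import TimeDelta

    from utilities.sqlalchemy import yield_connection

    async def count_rows() -> int:
        # `engine` as in the earlier sketches.
        async with yield_connection(engine, timeout=TimeDelta(seconds=5)) as conn:
            result = await conn.execute(text("SELECT COUNT(*) FROM example"))
            return result.scalar_one()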
@@ -865,7 +957,7 @@ def _ensure_tables_maybe_reraise(error: DatabaseError, match: str, /) -> None:
 ##
 
 
-def _get_dialect(engine_or_conn: _EngineOrConnectionOrAsync, /) -> Dialect:
+def _get_dialect(engine_or_conn: EngineOrConnectionOrAsync, /) -> Dialect:
     """Get the dialect of a database."""
     dialect = engine_or_conn.dialect
     if isinstance(dialect, mssql_dialect):  # pragma: no cover
@@ -875,7 +967,7 @@ def _get_dialect(engine_or_conn: _EngineOrConnectionOrAsync, /) -> Dialect:
     if isinstance(dialect, oracle_dialect):  # pragma: no cover
         return "oracle"
     if isinstance(  # skipif-ci-and-not-linux
-        dialect, postgresql_dialect | PGDialect_asyncpg
+        dialect, (postgresql_dialect, PGDialect_asyncpg, PGDialect_psycopg)
     ):
         return "postgresql"
     if isinstance(dialect, sqlite_dialect):
@@ -888,7 +980,7 @@ def _get_dialect(engine_or_conn: _EngineOrConnectionOrAsync, /) -> Dialect:
 
 
 def _get_dialect_max_params(
-    dialect_or_engine_or_conn: Dialect | _EngineOrConnectionOrAsync, /
+    dialect_or_engine_or_conn: DialectOrEngineOrConnectionOrAsync, /
 ) -> int:
     """Get the max number of parameters of a dialect."""
     match dialect_or_engine_or_conn:
@@ -910,7 +1002,7 @@
         ):
             dialect = _get_dialect(engine_or_conn)
             return _get_dialect_max_params(dialect)
-        case _ as never:
+        case never:
             assert_never(never)
 
 
@@ -943,9 +1035,9 @@ def _is_pair_of_tuple_or_str_mapping_and_table(
     return _is_pair_with_predicate_and_table(obj, is_tuple_or_str_mapping)
 
 
-def _is_pair_with_predicate_and_table(
-    obj: Any, predicate: Callable[[Any], TypeGuard[_T]], /
-) -> TypeGuard[tuple[_T, TableOrORMInstOrClass]]:
+def _is_pair_with_predicate_and_table[T](
+    obj: Any, predicate: Callable[[Any], TypeGuard[T]], /
+) -> TypeGuard[tuple[T, TableOrORMInstOrClass]]:
     """Check if an object is pair and a table."""
     return (
         isinstance(obj, tuple)
@@ -995,7 +1087,7 @@
 @dataclass(kw_only=True, slots=True)
 class _MapMappingToTableError(Exception):
     mapping: StrMapping
-    columns: Sequence[str]
+    columns: list[str]
 
 
 @dataclass(kw_only=True, slots=True)
@@ -1032,102 +1124,31 @@ class _MapMappingToTableSnakeMapNonUniqueError(_MapMappingToTableError):
 
 def _orm_inst_to_dict(obj: DeclarativeBase, /) -> StrMapping:
     """Map an ORM instance to a dictionary."""
-    cls = type(obj)
-
-    def is_attr(attr: str, key: str, /) -> str | None:
-        if isinstance(value := getattr(cls, attr), InstrumentedAttribute) and (
-            value.name == key
-        ):
-            return attr
-        return None
-
-    def yield_items() -> Iterator[tuple[str, Any]]:
-        for key in get_column_names(cls):
-            attr = one(attr for attr in dir(cls) if is_attr(attr, key) is not None)
-            yield key, getattr(obj, attr)
-
-    return dict(yield_items())
-
-
-##
-
-
-@dataclass(kw_only=True, slots=True)
-class _PrepareInsertOrUpsertItems:
-    mapping: dict[Table, list[StrMapping]] = field(default_factory=dict)
-    yield_pairs: Callable[[], Iterator[tuple[Insert, Any]]]
-
-    @property
-    def tables(self) -> Sequence[Table]:
-        return list(self.mapping)
-
-
-def _prepare_insert_or_upsert_items(
-    normalize_item: Callable[[_InsertItem], Iterable[_NormalizedItem]],
-    engine: AsyncEngine,
-    build_insert: Callable[[Table, Iterable[StrMapping]], tuple[Insert, Any]],
-    /,
-    *items: Any,
-    chunk_size_frac: float = CHUNK_SIZE_FRAC,
-) -> _PrepareInsertOrUpsertItems:
-    """Prepare a set of insert/upsert items."""
-    mapping: defaultdict[Table, list[StrMapping]] = defaultdict(list)
-    lengths: set[int] = set()
-    try:
-        for item in items:
-            for normed in normalize_item(item):
-                mapping[normed.table].append(normed.mapping)
-                lengths.add(len(normed.mapping))
-    except _NormalizeInsertItemError as error:
-        raise _PrepareInsertOrUpsertItemsError(item=error.item) from None
-    merged = {
-        table: _prepare_insert_or_upsert_items_merge_items(table, values)
-        for table, values in mapping.items()
+    attrs = {
+        k for k, _ in yield_object_attributes(obj, static_type=InstrumentedAttribute)
+    }
+    return {
+        name: _orm_inst_to_dict_one(obj, attrs, name) for name in get_column_names(obj)
     }
-    max_length = max(lengths, default=1)
-    chunk_size = get_chunk_size(
-        engine, chunk_size_frac=chunk_size_frac, scaling=max_length
-    )
-
-    def yield_pairs() -> Iterator[tuple[Insert, None]]:
-        for table, values in merged.items():
-            for chunk in chunked(values, chunk_size):
-                yield build_insert(table, chunk)
-
-    return _PrepareInsertOrUpsertItems(mapping=mapping, yield_pairs=yield_pairs)
 
 
-@dataclass(kw_only=True, slots=True)
-class _PrepareInsertOrUpsertItemsError(Exception):
-    item: Any
-
-    @override
-    def __str__(self) -> str:
-        return f"Item must be valid; got {self.item}"
+def _orm_inst_to_dict_one(
+    obj: DeclarativeBase, attrs: AbstractSet[str], name: str, /
+) -> Any:
+    attr = one(
+        attr for attr in attrs if _orm_inst_to_dict_predicate(type(obj), attr, name)
+    )
+    return getattr(obj, attr)
 
 
-def _prepare_insert_or_upsert_items_merge_items(
-    table: Table, items: Iterable[StrMapping], /
-) -> list[StrMapping]:
-    columns = list(yield_primary_key_columns(table))
-    col_names = [c.name for c in columns]
-    cols_auto = {c.name for c in columns if c.autoincrement in {True, "auto"}}
-    cols_non_auto = set(col_names) - cols_auto
-    mapping: defaultdict[tuple[Hashable, ...], list[StrMapping]] = defaultdict(list)
-    unchanged: list[StrMapping] = []
-    for item in items:
-        check_subset(cols_non_auto, item)
-        has_all_auto = set(cols_auto).issubset(item)
-        if has_all_auto:
-            pkey = tuple(item[k] for k in col_names)
-            rest: StrMapping = {k: v for k, v in item.items() if k not in col_names}
-            mapping[pkey].append(rest)
-        else:
-            unchanged.append(item)
-    merged = {k: merge_str_mappings(*v) for k, v in mapping.items()}
-    return [
-        dict(zip(col_names, k, strict=True)) | dict(v) for k, v in merged.items()
-    ] + unchanged
+def _orm_inst_to_dict_predicate(
+    cls: type[DeclarativeBase], attr: str, name: str, /
+) -> bool:
+    cls_attr = getattr(cls, attr)
+    try:
+        return cls_attr.name == name
+    except AttributeError:
+        return False
 
 
 ##
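_orm_inst_to_dict is rebuilt on yield_object_attributes, but its contract is unchanged: an ORM instance becomes a mapping keyed by column name, which is what lets insert_items accept DeclarativeBase instances directly. A sketch with a hypothetical mapped class:

    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

    class Base(DeclarativeBase): ...

    class Example(Base):  # hypothetical model, not from the diff
        __tablename__ = "example"
        id: Mapped[int] = mapped_column(primary_key=True)

    # insert_items(engine, Example(id=1)) normalizes the instance to the
    # pair ({"id": 1}, Example) internally before building the INSERT.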
@@ -1144,18 +1165,27 @@ def _tuple_to_mapping(
 __all__ = [
     "CHUNK_SIZE_FRAC",
     "CheckEngineError",
+    "DialectOrEngineOrConnectionOrAsync",
+    "EngineOrConnectionOrAsync",
+    "ExtractURLError",
+    "ExtractURLOutput",
     "GetTableError",
     "InsertItemsError",
     "TablenameMixin",
-    "UpsertItemsError",
-    "UpsertService",
-    "UpsertServiceMixin",
+    "check_connect",
+    "check_connect_async",
     "check_engine",
     "columnwise_max",
     "columnwise_min",
-    "create_async_engine",
+    "create_engine",
+    "ensure_database_created",
+    "ensure_database_dropped",
+    "ensure_database_users_disconnected",
     "ensure_tables_created",
     "ensure_tables_dropped",
+    "enum_name",
+    "enum_values",
+    "extract_url",
     "get_chunk_size",
     "get_column_names",
     "get_columns",
@@ -1168,7 +1198,6 @@ __all__ = [
     "is_table_or_orm",
     "migrate_data",
     "selectable_to_string",
-    "upsert_items",
    "yield_connection",
    "yield_primary_key_columns",
]