dycw-utilities 0.135.0__py3-none-any.whl → 0.178.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of dycw-utilities might be problematic.

Files changed (97)
  1. dycw_utilities-0.178.1.dist-info/METADATA +34 -0
  2. dycw_utilities-0.178.1.dist-info/RECORD +105 -0
  3. dycw_utilities-0.178.1.dist-info/WHEEL +4 -0
  4. dycw_utilities-0.178.1.dist-info/entry_points.txt +4 -0
  5. utilities/__init__.py +1 -1
  6. utilities/altair.py +13 -10
  7. utilities/asyncio.py +312 -787
  8. utilities/atomicwrites.py +18 -6
  9. utilities/atools.py +64 -4
  10. utilities/cachetools.py +9 -6
  11. utilities/click.py +195 -77
  12. utilities/concurrent.py +1 -1
  13. utilities/contextlib.py +216 -17
  14. utilities/contextvars.py +20 -1
  15. utilities/cryptography.py +3 -3
  16. utilities/dataclasses.py +15 -28
  17. utilities/docker.py +387 -0
  18. utilities/enum.py +2 -2
  19. utilities/errors.py +17 -3
  20. utilities/fastapi.py +28 -59
  21. utilities/fpdf2.py +2 -2
  22. utilities/functions.py +24 -269
  23. utilities/git.py +9 -30
  24. utilities/grp.py +28 -0
  25. utilities/gzip.py +31 -0
  26. utilities/http.py +3 -2
  27. utilities/hypothesis.py +513 -159
  28. utilities/importlib.py +17 -1
  29. utilities/inflect.py +12 -4
  30. utilities/iterables.py +33 -58
  31. utilities/jinja2.py +148 -0
  32. utilities/json.py +70 -0
  33. utilities/libcst.py +38 -17
  34. utilities/lightweight_charts.py +4 -7
  35. utilities/logging.py +136 -93
  36. utilities/math.py +8 -4
  37. utilities/more_itertools.py +43 -45
  38. utilities/operator.py +27 -27
  39. utilities/orjson.py +189 -36
  40. utilities/os.py +61 -4
  41. utilities/packaging.py +115 -0
  42. utilities/parse.py +8 -5
  43. utilities/pathlib.py +269 -40
  44. utilities/permissions.py +298 -0
  45. utilities/platform.py +7 -6
  46. utilities/polars.py +1205 -413
  47. utilities/polars_ols.py +1 -1
  48. utilities/postgres.py +408 -0
  49. utilities/pottery.py +43 -19
  50. utilities/pqdm.py +3 -3
  51. utilities/psutil.py +5 -57
  52. utilities/pwd.py +28 -0
  53. utilities/pydantic.py +4 -52
  54. utilities/pydantic_settings.py +240 -0
  55. utilities/pydantic_settings_sops.py +76 -0
  56. utilities/pyinstrument.py +7 -7
  57. utilities/pytest.py +104 -143
  58. utilities/pytest_plugins/__init__.py +1 -0
  59. utilities/pytest_plugins/pytest_randomly.py +23 -0
  60. utilities/pytest_plugins/pytest_regressions.py +56 -0
  61. utilities/pytest_regressions.py +26 -46
  62. utilities/random.py +11 -6
  63. utilities/re.py +1 -1
  64. utilities/redis.py +220 -343
  65. utilities/sentinel.py +10 -0
  66. utilities/shelve.py +4 -1
  67. utilities/shutil.py +25 -0
  68. utilities/slack_sdk.py +35 -104
  69. utilities/sqlalchemy.py +496 -471
  70. utilities/sqlalchemy_polars.py +29 -54
  71. utilities/string.py +2 -3
  72. utilities/subprocess.py +1977 -0
  73. utilities/tempfile.py +112 -4
  74. utilities/testbook.py +50 -0
  75. utilities/text.py +174 -42
  76. utilities/throttle.py +158 -0
  77. utilities/timer.py +2 -2
  78. utilities/traceback.py +70 -35
  79. utilities/types.py +102 -30
  80. utilities/typing.py +479 -19
  81. utilities/uuid.py +42 -5
  82. utilities/version.py +27 -26
  83. utilities/whenever.py +1559 -361
  84. utilities/zoneinfo.py +80 -22
  85. dycw_utilities-0.135.0.dist-info/METADATA +0 -39
  86. dycw_utilities-0.135.0.dist-info/RECORD +0 -96
  87. dycw_utilities-0.135.0.dist-info/WHEEL +0 -4
  88. dycw_utilities-0.135.0.dist-info/licenses/LICENSE +0 -21
  89. utilities/aiolimiter.py +0 -25
  90. utilities/arq.py +0 -216
  91. utilities/eventkit.py +0 -388
  92. utilities/luigi.py +0 -183
  93. utilities/period.py +0 -152
  94. utilities/pudb.py +0 -62
  95. utilities/python_dotenv.py +0 -101
  96. utilities/streamlit.py +0 -105
  97. utilities/typed_settings.py +0 -123
utilities/sqlalchemy.py CHANGED
@@ -12,21 +12,31 @@ from collections.abc import (
 )
 from collections.abc import Set as AbstractSet
 from contextlib import asynccontextmanager
-from dataclasses import dataclass, field
-from functools import partial, reduce
+from dataclasses import dataclass
+from functools import reduce
 from itertools import chain
 from math import floor
 from operator import ge, le
 from re import search
-from typing import TYPE_CHECKING, Any, Literal, TypeGuard, assert_never, cast, override
+from socket import gaierror
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Literal,
+    TypeGuard,
+    assert_never,
+    cast,
+    overload,
+    override,
+)
 
+import sqlalchemy
 from sqlalchemy import (
     URL,
     Column,
     Connection,
     Engine,
     Insert,
-    PrimaryKeyConstraint,
     Selectable,
     Table,
     and_,
@@ -38,16 +48,14 @@ from sqlalchemy import (
 from sqlalchemy.dialects.mssql import dialect as mssql_dialect
 from sqlalchemy.dialects.mysql import dialect as mysql_dialect
 from sqlalchemy.dialects.oracle import dialect as oracle_dialect
-from sqlalchemy.dialects.postgresql import Insert as postgresql_Insert
 from sqlalchemy.dialects.postgresql import dialect as postgresql_dialect
 from sqlalchemy.dialects.postgresql import insert as postgresql_insert
 from sqlalchemy.dialects.postgresql.asyncpg import PGDialect_asyncpg
-from sqlalchemy.dialects.sqlite import Insert as sqlite_Insert
+from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg
 from sqlalchemy.dialects.sqlite import dialect as sqlite_dialect
 from sqlalchemy.dialects.sqlite import insert as sqlite_insert
 from sqlalchemy.exc import ArgumentError, DatabaseError
-from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
-from sqlalchemy.ext.asyncio import create_async_engine as _create_async_engine
+from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
 from sqlalchemy.orm import (
     DeclarativeBase,
     InstrumentedAttribute,
@@ -57,16 +65,8 @@ from sqlalchemy.orm import (
 from sqlalchemy.orm.exc import UnmappedClassError
 from sqlalchemy.pool import NullPool, Pool
 
-from utilities.asyncio import Looper, timeout_td
-from utilities.contextlib import suppress_super_object_attribute_error
-from utilities.functions import (
-    ensure_str,
-    get_class_name,
-    is_sequence_of_tuple_or_str_mapping,
-    is_string_mapping,
-    is_tuple,
-    is_tuple_or_str_mapping,
-)
+from utilities.asyncio import timeout_td
+from utilities.functions import ensure_str, get_class_name, yield_object_attributes
 from utilities.iterables import (
     CheckLengthError,
     CheckSubSetError,
@@ -79,19 +79,63 @@ from utilities.iterables import (
     merge_str_mappings,
     one,
 )
+from utilities.os import is_pytest
 from utilities.reprlib import get_repr
-from utilities.text import snake_case
-from utilities.types import MaybeIterable, MaybeType, StrMapping, TupleOrStrMapping
-from utilities.whenever import SECOND
+from utilities.text import secret_str, snake_case
+from utilities.types import (
+    Delta,
+    MaybeIterable,
+    MaybeType,
+    StrMapping,
+    TupleOrStrMapping,
+)
+from utilities.typing import (
+    is_sequence_of_tuple_or_str_mapping,
+    is_string_mapping,
+    is_tuple,
+    is_tuple_or_str_mapping,
+)
 
 if TYPE_CHECKING:
-    from whenever import TimeDelta
+    from enum import Enum, StrEnum
+
 
-type _EngineOrConnectionOrAsync = Engine | Connection | AsyncEngine | AsyncConnection
+type EngineOrConnectionOrAsync = Engine | Connection | AsyncEngine | AsyncConnection
 type Dialect = Literal["mssql", "mysql", "oracle", "postgresql", "sqlite"]
+type DialectOrEngineOrConnectionOrAsync = Dialect | EngineOrConnectionOrAsync
 type ORMInstOrClass = DeclarativeBase | type[DeclarativeBase]
 type TableOrORMInstOrClass = Table | ORMInstOrClass
-CHUNK_SIZE_FRAC = 0.95
+CHUNK_SIZE_FRAC = 0.8
+
+
+##
+
+
+_SELECT = text("SELECT 1")
+
+
+def check_connect(engine: Engine, /) -> bool:
+    """Check if an engine can connect."""
+    try:
+        with engine.connect() as conn:
+            return bool(conn.execute(_SELECT).scalar_one())
+    except (gaierror, ConnectionRefusedError, DatabaseError):  # pragma: no cover
+        return False
+
+
+async def check_connect_async(
+    engine: AsyncEngine,
+    /,
+    *,
+    timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
+) -> bool:
+    """Check if an engine can connect."""
+    try:
+        async with timeout_td(timeout, error=error), engine.connect() as conn:
+            return bool((await conn.execute(_SELECT)).scalar_one())
+    except (gaierror, ConnectionRefusedError, DatabaseError, TimeoutError):
+        return False
 
 
 ##
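The new check_connect/check_connect_async helpers turn connection probing into a boolean: DNS failures (gaierror), refused connections and DatabaseError (plus TimeoutError in the async case) are swallowed rather than raised. A minimal usage sketch, with the in-memory SQLite URLs and the aiosqlite driver as illustrative assumptions:

import asyncio

from sqlalchemy import create_engine as sa_create_engine
from sqlalchemy.ext.asyncio import create_async_engine as sa_create_async_engine

from utilities.sqlalchemy import check_connect, check_connect_async

# Synchronous probe: True if "SELECT 1" round-trips, False on failure.
print(check_connect(sa_create_engine("sqlite://")))

async def main() -> None:
    # Asynchronous probe: also returns False if the optional timeout fires.
    engine = sa_create_async_engine("sqlite+aiosqlite://")
    print(await check_connect_async(engine))

asyncio.run(main())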
@@ -101,8 +145,8 @@ async def check_engine(
     engine: AsyncEngine,
     /,
     *,
-    timeout: TimeDelta | None = None,
-    error: type[Exception] = TimeoutError,
+    timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
     num_tables: int | tuple[int, float] | None = None,
 ) -> None:
     """Check that an engine can connect.
@@ -117,7 +161,7 @@ async def check_engine(
             query = "select * from all_objects"
         case "sqlite":
             query = "select * from sqlite_master where type='table'"
-        case _ as never:
+        case never:
             assert_never(never)
     statement = text(query)
     async with yield_connection(engine, timeout=timeout, error=error) as conn:
@@ -185,7 +229,8 @@ def _columnwise_minmax(*columns: Any, op: Callable[[Any, Any], Any]) -> Any:
 ##
 
 
-def create_async_engine(
+@overload
+def create_engine(
     drivername: str,
     /,
     *,
@@ -196,7 +241,49 @@ def create_async_engine(
     database: str | None = None,
     query: StrMapping | None = None,
     poolclass: type[Pool] | None = NullPool,
-) -> AsyncEngine:
+    async_: Literal[True],
+) -> AsyncEngine: ...
+@overload
+def create_engine(
+    drivername: str,
+    /,
+    *,
+    username: str | None = None,
+    password: str | None = None,
+    host: str | None = None,
+    port: int | None = None,
+    database: str | None = None,
+    query: StrMapping | None = None,
+    poolclass: type[Pool] | None = NullPool,
+    async_: Literal[False] = False,
+) -> Engine: ...
+@overload
+def create_engine(
+    drivername: str,
+    /,
+    *,
+    username: str | None = None,
+    password: str | None = None,
+    host: str | None = None,
+    port: int | None = None,
+    database: str | None = None,
+    query: StrMapping | None = None,
+    poolclass: type[Pool] | None = NullPool,
+    async_: bool = False,
+) -> Engine | AsyncEngine: ...
+def create_engine(
+    drivername: str,
+    /,
+    *,
+    username: str | None = None,
+    password: str | None = None,
+    host: str | None = None,
+    port: int | None = None,
+    database: str | None = None,
+    query: StrMapping | None = None,
+    poolclass: type[Pool] | None = NullPool,
+    async_: bool = False,
+) -> Engine | AsyncEngine:
     """Create a SQLAlchemy engine."""
     if query is None:
         kwargs = {}
@@ -215,7 +302,47 @@ def create_async_engine(
         database=database,
         **kwargs,
     )
-    return _create_async_engine(url, poolclass=poolclass)
+    match async_:
+        case False:
+            return sqlalchemy.create_engine(url, poolclass=poolclass)
+        case True:
+            return create_async_engine(url, poolclass=poolclass)
+        case never:
+            assert_never(never)
+
+
+##
+
+
+async def ensure_database_created(super_: URL, database: str, /) -> None:
+    """Ensure a database is created."""
+    engine = create_async_engine(super_, isolation_level="AUTOCOMMIT")
+    async with engine.begin() as conn:
+        try:
+            _ = await conn.execute(text(f"CREATE DATABASE {database}"))
+        except DatabaseError as error:
+            _ensure_tables_maybe_reraise(error, 'database ".*" already exists')
+
+
+async def ensure_database_dropped(super_: URL, database: str, /) -> None:
+    """Ensure a database is dropped."""
+    engine = create_async_engine(super_, isolation_level="AUTOCOMMIT")
+    async with engine.begin() as conn:
+        _ = await conn.execute(text(f"DROP DATABASE IF EXISTS {database}"))
+
+
+async def ensure_database_users_disconnected(super_: URL, database: str, /) -> None:
+    """Ensure a database's users are disconnected."""
+    engine = create_async_engine(super_, isolation_level="AUTOCOMMIT")
+    match dialect := _get_dialect(engine):
+        case "postgresql":  # skipif-ci-and-not-linux
+            query = f"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = {database!r} AND pid <> pg_backend_pid()"  # noqa: S608
+        case "mssql" | "mysql" | "oracle" | "sqlite":  # pragma: no cover
+            raise NotImplementedError(dialect)
+        case never:
+            assert_never(never)
+    async with engine.begin() as conn:
+        _ = await conn.execute(text(query))
 
 
 ##
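The old create_async_engine wrapper is folded into a single create_engine whose async_ flag selects the engine flavour; the three @overload stubs above let a type checker resolve Engine vs AsyncEngine from a literal argument. A sketch of the calling convention (the driver names are illustrative, not mandated by the package):

from utilities.sqlalchemy import create_engine

engine = create_engine("sqlite", database=":memory:")  # typed as Engine
async_engine = create_engine(
    "sqlite+aiosqlite", database=":memory:", async_=True
)  # typed as AsyncEngine

The new ensure_database_* helpers follow the same idiom as the existing ensure_tables_* pair: they take a superuser URL plus a database name, run with AUTOCOMMIT isolation, and suppress the "already exists" error on creation.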
@@ -225,8 +352,8 @@ async def ensure_tables_created(
     engine: AsyncEngine,
     /,
     *tables_or_orms: TableOrORMInstOrClass,
-    timeout: TimeDelta | None = None,
-    error: type[Exception] = TimeoutError,
+    timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
 ) -> None:
     """Ensure a table/set of tables is/are created."""
     tables = set(map(get_table, tables_or_orms))
@@ -241,7 +368,7 @@ async def ensure_tables_created(
             match = "ORA-00955: name is already used by an existing object"
         case "sqlite":
             match = "table .* already exists"
-        case _ as never:
+        case never:
             assert_never(never)
     async with yield_connection(engine, timeout=timeout, error=error) as conn:
         for table in tables:
@@ -254,8 +381,8 @@ async def ensure_tables_created(
 async def ensure_tables_dropped(
     engine: AsyncEngine,
     *tables_or_orms: TableOrORMInstOrClass,
-    timeout: TimeDelta | None = None,
-    error: type[Exception] = TimeoutError,
+    timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
 ) -> None:
     """Ensure a table/set of tables is/are dropped."""
     tables = set(map(get_table, tables_or_orms))
@@ -270,7 +397,7 @@ async def ensure_tables_dropped(
             match = "ORA-00942: table or view does not exist"
         case "sqlite":
             match = "no such table"
-        case _ as never:
+        case never:
             assert_never(never)
     async with yield_connection(engine, timeout=timeout, error=error) as conn:
         for table in tables:
@@ -283,18 +410,120 @@ async def ensure_tables_dropped(
 ##
 
 
+def enum_name(enum: type[Enum], /) -> str:
+    """Get the name of an Enum."""
+    return f"{snake_case(get_class_name(enum))}_enum"
+
+
+##
+
+
+def enum_values(enum: type[StrEnum], /) -> list[str]:
+    """Get the values of a StrEnum."""
+    return [e.value for e in enum]
+
+
+##
+
+
+@dataclass(kw_only=True, slots=True)
+class ExtractURLOutput:
+    username: str
+    password: secret_str
+    host: str
+    port: int
+    database: str
+
+
+def extract_url(url: URL, /) -> ExtractURLOutput:
+    """Extract the username, password, host, port & database from a URL."""
+    if url.username is None:
+        raise _ExtractURLUsernameError(url=url)
+    if url.password is None:
+        raise _ExtractURLPasswordError(url=url)
+    if url.host is None:
+        raise _ExtractURLHostError(url=url)
+    if url.port is None:
+        raise _ExtractURLPortError(url=url)
+    if url.database is None:
+        raise _ExtractURLDatabaseError(url=url)
+    return ExtractURLOutput(
+        username=url.username,
+        password=secret_str(url.password),
+        host=url.host,
+        port=url.port,
+        database=url.database,
+    )
+
+
+@dataclass(kw_only=True, slots=True)
+class ExtractURLError(Exception):
+    url: URL
+
+
+@dataclass(kw_only=True, slots=True)
+class _ExtractURLUsernameError(ExtractURLError):
+    @override
+    def __str__(self) -> str:
+        return f"Expected URL to contain a user name; got {self.url}"
+
+
+@dataclass(kw_only=True, slots=True)
+class _ExtractURLPasswordError(ExtractURLError):
+    @override
+    def __str__(self) -> str:
+        return f"Expected URL to contain a password; got {self.url}"
+
+
+@dataclass(kw_only=True, slots=True)
+class _ExtractURLHostError(ExtractURLError):
+    @override
+    def __str__(self) -> str:
+        return f"Expected URL to contain a host; got {self.url}"
+
+
+@dataclass(kw_only=True, slots=True)
+class _ExtractURLPortError(ExtractURLError):
+    @override
+    def __str__(self) -> str:
+        return f"Expected URL to contain a port; got {self.url}"
+
+
+@dataclass(kw_only=True, slots=True)
+class _ExtractURLDatabaseError(ExtractURLError):
+    @override
+    def __str__(self) -> str:
+        return f"Expected URL to contain a database; got {self.url}"
+
+
+##
+
+
 def get_chunk_size(
-    engine_or_conn: _EngineOrConnectionOrAsync,
+    dialect_or_engine_or_conn: DialectOrEngineOrConnectionOrAsync,
+    table_or_orm_or_num_cols: TableOrORMInstOrClass | Sized | int,
     /,
     *,
     chunk_size_frac: float = CHUNK_SIZE_FRAC,
-    max_length: int = 1,
 ) -> int:
     """Get the maximum chunk size for an engine."""
-    max_params = _get_dialect_max_params(engine_or_conn)
-    scaling = max(max_length, 1)
-    size = floor(chunk_size_frac * max_params / scaling)
-    return max(size, 1)
+    max_params = _get_dialect_max_params(dialect_or_engine_or_conn)
+    match table_or_orm_or_num_cols:
+        case Table() | DeclarativeBase() | type() as table_or_orm:
+            return get_chunk_size(
+                dialect_or_engine_or_conn,
+                get_columns(table_or_orm),
+                chunk_size_frac=chunk_size_frac,
+            )
+        case Sized() as sized:
+            return get_chunk_size(
+                dialect_or_engine_or_conn, len(sized), chunk_size_frac=chunk_size_frac
+            )
+        case int() as num_cols:
+            size = floor(chunk_size_frac * max_params / num_cols)
+            return max(size, 1)
+        case never:
+            assert_never(never)
 
 
 ##
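get_chunk_size now takes the column count explicitly, in any of three forms (table/ORM, a Sized collection of columns, or a plain int), instead of the old max_length keyword, and CHUNK_SIZE_FRAC drops from 0.95 to 0.8. The arithmetic itself is unchanged: floor(chunk_size_frac * max_params / num_cols), floored at 1. Worked numbers under an assumed parameter cap (the real figure is dialect-dependent and comes from _get_dialect_max_params):

from math import floor

max_params = 32_766  # assumption for illustration only
num_cols = 10        # e.g. a ten-column table
chunk_size = max(floor(0.8 * max_params / num_cols), 1)
assert chunk_size == 2_621  # rows per INSERT statement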
@@ -374,18 +603,22 @@ type _InsertItem = (
     | Sequence[_PairOfTupleOrStrMappingAndTable]
     | Sequence[DeclarativeBase]
 )
+type _NormalizedItem = tuple[Table, StrMapping]
+type _InsertPair = tuple[Table, Sequence[StrMapping]]
+type _InsertTriple = tuple[Table, Insert, Sequence[StrMapping] | None]
 
 
 async def insert_items(
     engine: AsyncEngine,
     *items: _InsertItem,
     snake: bool = False,
+    is_upsert: bool = False,
     chunk_size_frac: float = CHUNK_SIZE_FRAC,
     assume_tables_exist: bool = False,
-    timeout_create: TimeDelta | None = None,
-    error_create: type[Exception] = TimeoutError,
-    timeout_insert: TimeDelta | None = None,
-    error_insert: type[Exception] = TimeoutError,
+    timeout_create: Delta | None = None,
+    error_create: MaybeType[BaseException] = TimeoutError,
+    timeout_insert: Delta | None = None,
+    error_insert: MaybeType[BaseException] = TimeoutError,
 ) -> None:
     """Insert a set of items into a database.
 
@@ -409,37 +642,181 @@ async def insert_items(
                                Obj(k1=v21, k2=v22, ...),
                                ...]
     """
-
-    def build_insert(
-        table: Table, values: Iterable[StrMapping], /
-    ) -> tuple[Insert, Any]:
-        match _get_dialect(engine):
-            case "oracle":  # pragma: no cover
-                return insert(table), values
-            case _:
-                return insert(table).values(list(values)), None
-
-    try:
-        prepared = _prepare_insert_or_upsert_items(
-            partial(_normalize_insert_item, snake=snake),
-            engine,
-            build_insert,
-            *items,
-            chunk_size_frac=chunk_size_frac,
-        )
-    except _PrepareInsertOrUpsertItemsError as error:
-        raise InsertItemsError(item=error.item) from None
+    normalized = chain.from_iterable(
+        _insert_items_yield_normalized(i, snake=snake) for i in items
+    )
+    triples = _insert_items_yield_triples(
+        engine, normalized, is_upsert=is_upsert, chunk_size_frac=chunk_size_frac
+    )
     if not assume_tables_exist:
+        triples = list(triples)
+        tables = {table for table, _, _ in triples}
         await ensure_tables_created(
-            engine, *prepared.tables, timeout=timeout_create, error=error_create
+            engine, *tables, timeout=timeout_create, error=error_create
         )
-    for ins, parameters in prepared.yield_pairs():
+    for _, ins, parameters in triples:
         async with yield_connection(
             engine, timeout=timeout_insert, error=error_insert
         ) as conn:
             _ = await conn.execute(ins, parameters=parameters)
 
 
+def _insert_items_yield_normalized(
+    item: _InsertItem, /, *, snake: bool = False
+) -> Iterator[_NormalizedItem]:
+    if _is_pair_of_str_mapping_and_table(item):
+        mapping, table_or_orm = item
+        adjusted = _map_mapping_to_table(mapping, table_or_orm, snake=snake)
+        yield (get_table(table_or_orm), adjusted)
+        return
+    if _is_pair_of_tuple_and_table(item):
+        tuple_, table_or_orm = item
+        mapping = _tuple_to_mapping(tuple_, table_or_orm)
+        yield from _insert_items_yield_normalized((mapping, table_or_orm), snake=snake)
+        return
+    if _is_pair_of_sequence_of_tuple_or_string_mapping_and_table(item):
+        items, table_or_orm = item
+        pairs = [(i, table_or_orm) for i in items]
+        for p in pairs:
+            yield from _insert_items_yield_normalized(p, snake=snake)
+        return
+    if isinstance(item, DeclarativeBase):
+        mapping = _orm_inst_to_dict(item)
+        yield from _insert_items_yield_normalized((mapping, item), snake=snake)
+        return
+    try:
+        _ = iter(item)
+    except TypeError:
+        raise InsertItemsError(item=item) from None
+    if all(map(_is_pair_of_tuple_or_str_mapping_and_table, item)):
+        pairs = cast("Sequence[_PairOfTupleOrStrMappingAndTable]", item)
+        for p in pairs:
+            yield from _insert_items_yield_normalized(p, snake=snake)
+        return
+    if all(map(is_orm, item)):
+        classes = cast("Sequence[DeclarativeBase]", item)
+        for c in classes:
+            yield from _insert_items_yield_normalized(c, snake=snake)
+        return
+    raise InsertItemsError(item=item)
+
+
+def _insert_items_yield_triples(
+    engine: AsyncEngine,
+    items: Iterable[_NormalizedItem],
+    /,
+    *,
+    is_upsert: bool = False,
+    chunk_size_frac: float = CHUNK_SIZE_FRAC,
+) -> Iterable[_InsertTriple]:
+    pairs = _insert_items_yield_chunked_pairs(
+        engine, items, is_upsert=is_upsert, chunk_size_frac=chunk_size_frac
+    )
+    for table, mappings in pairs:
+        match is_upsert, _get_dialect(engine):
+            case False, "oracle":  # pragma: no cover
+                ins = insert(table)
+                parameters = mappings
+            case False, _:
+                ins = insert(table).values(mappings)
+                parameters = None
+            case True, _:
+                ins = _insert_items_build_insert_with_on_conflict_do_update(
+                    engine, table, mappings
+                )
+                parameters = None
+            case never:
+                assert_never(never)
+        yield table, ins, parameters
+
+
+def _insert_items_yield_chunked_pairs(
+    engine: AsyncEngine,
+    items: Iterable[_NormalizedItem],
+    /,
+    *,
+    is_upsert: bool = False,
+    chunk_size_frac: float = CHUNK_SIZE_FRAC,
+) -> Iterable[_InsertPair]:
+    for table, mappings in _insert_items_yield_raw_pairs(items, is_upsert=is_upsert):
+        chunk_size = get_chunk_size(engine, table, chunk_size_frac=chunk_size_frac)
+        for mappings_i in chunked(mappings, chunk_size):
+            yield table, list(mappings_i)
+
+
+def _insert_items_yield_raw_pairs(
+    items: Iterable[_NormalizedItem], /, *, is_upsert: bool = False
+) -> Iterable[_InsertPair]:
+    by_table: defaultdict[Table, list[StrMapping]] = defaultdict(list)
+    for table, mapping in items:
+        by_table[table].append(mapping)
+    for table, mappings in by_table.items():
+        yield from _insert_items_yield_raw_pairs_one(
+            table, mappings, is_upsert=is_upsert
+        )
+
+
+def _insert_items_yield_raw_pairs_one(
+    table: Table, mappings: Iterable[StrMapping], /, *, is_upsert: bool = False
+) -> Iterable[_InsertPair]:
+    merged = _insert_items_yield_merged_mappings(table, mappings)
+    match is_upsert:
+        case True:
+            by_keys: defaultdict[frozenset[str], list[StrMapping]] = defaultdict(list)
+            for mapping in merged:
+                non_null = {k: v for k, v in mapping.items() if v is not None}
+                by_keys[frozenset(non_null)].append(non_null)
+            for mappings_i in by_keys.values():
+                yield table, mappings_i
+        case False:
+            yield table, list(merged)
+        case never:
+            assert_never(never)
+
+
+def _insert_items_yield_merged_mappings(
+    table: Table, mappings: Iterable[StrMapping], /
+) -> Iterable[StrMapping]:
+    columns = list(yield_primary_key_columns(table))
+    col_names = [c.name for c in columns]
+    cols_auto = {c.name for c in columns if c.autoincrement in {True, "auto"}}
+    cols_non_auto = set(col_names) - cols_auto
+    by_key: defaultdict[tuple[Hashable, ...], list[StrMapping]] = defaultdict(list)
+    for mapping in mappings:
+        check_subset(cols_non_auto, mapping)
+        has_all_auto = set(cols_auto).issubset(mapping)
+        if has_all_auto:
+            pkey = tuple(mapping[k] for k in col_names)
+            rest: StrMapping = {k: v for k, v in mapping.items() if k not in col_names}
+            by_key[pkey].append(rest)
+        else:
+            yield mapping
+    for k, v in by_key.items():
+        head = dict(zip(col_names, k, strict=True))
+        yield merge_str_mappings(head, *v)
+
+
+def _insert_items_build_insert_with_on_conflict_do_update(
+    engine: AsyncEngine, table: Table, mappings: Iterable[StrMapping], /
+) -> Insert:
+    primary_key = cast("Any", table.primary_key)
+    mappings = list(mappings)
+    columns = merge_sets(*mappings)
+    match _get_dialect(engine):
+        case "postgresql":  # skipif-ci-and-not-linux
+            ins = postgresql_insert(table).values(mappings)
+            set_ = {c: getattr(ins.excluded, c) for c in columns}
+            return ins.on_conflict_do_update(constraint=primary_key, set_=set_)
+        case "sqlite":
+            ins = sqlite_insert(table).values(mappings)
+            set_ = {c: getattr(ins.excluded, c) for c in columns}
+            return ins.on_conflict_do_update(index_elements=primary_key, set_=set_)
+        case "mssql" | "mysql" | "oracle" as dialect:  # pragma: no cover
+            raise NotImplementedError(dialect)
+        case never:
+            assert_never(never)
+
+
 @dataclass(kw_only=True, slots=True)
 class InsertItemsError(Exception):
     item: _InsertItem
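With upsert_items removed (see the hunks below), upserting becomes a flag on insert_items: None-valued columns are stripped and the mappings are grouped by their non-null key sets before an ON CONFLICT DO UPDATE statement is built per group, mirroring the old selected_or_all="selected" behaviour. A migration sketch, assuming a toy declarative model invented for illustration:

import asyncio

from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from utilities.sqlalchemy import create_engine, insert_items

class Base(DeclarativeBase): ...

class User(Base):
    __tablename__ = "user"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str | None]

async def main() -> None:
    # File-backed DB: an in-memory SQLite DB would not survive NullPool reconnects.
    engine = create_engine("sqlite+aiosqlite", database="demo.db", async_=True)
    await insert_items(engine, User(id=1, name="alice"))  # plain insert
    # Previously: await upsert_items(engine, User(id=1, name="bob"))
    await insert_items(engine, User(id=1, name="bob"), is_upsert=True)

asyncio.run(main())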
@@ -483,10 +860,10 @@ async def migrate_data(
     table_or_orm_to: TableOrORMInstOrClass | None = None,
     chunk_size_frac: float = CHUNK_SIZE_FRAC,
     assume_tables_exist: bool = False,
-    timeout_create: TimeDelta | None = None,
-    error_create: type[Exception] = TimeoutError,
-    timeout_insert: TimeDelta | None = None,
-    error_insert: type[Exception] = TimeoutError,
+    timeout_create: Delta | None = None,
+    error_create: MaybeType[BaseException] = TimeoutError,
+    timeout_insert: Delta | None = None,
+    error_insert: MaybeType[BaseException] = TimeoutError,
 ) -> None:
     """Migrate the contents of a table from one database to another."""
     table_from = get_table(table_or_orm_from)
@@ -510,82 +887,8 @@ async def migrate_data(
 ##
 
 
-def _normalize_insert_item(
-    item: _InsertItem, /, *, snake: bool = False
-) -> list[_NormalizedItem]:
-    """Normalize an insertion item."""
-    if _is_pair_of_str_mapping_and_table(item):
-        mapping, table_or_orm = item
-        adjusted = _map_mapping_to_table(mapping, table_or_orm, snake=snake)
-        normalized = _NormalizedItem(mapping=adjusted, table=get_table(table_or_orm))
-        return [normalized]
-    if _is_pair_of_tuple_and_table(item):
-        tuple_, table_or_orm = item
-        mapping = _tuple_to_mapping(tuple_, table_or_orm)
-        return _normalize_insert_item((mapping, table_or_orm), snake=snake)
-    if _is_pair_of_sequence_of_tuple_or_string_mapping_and_table(item):
-        items, table_or_orm = item
-        pairs = [(i, table_or_orm) for i in items]
-        normalized = (_normalize_insert_item(p, snake=snake) for p in pairs)
-        return list(chain.from_iterable(normalized))
-    if isinstance(item, DeclarativeBase):
-        mapping = _orm_inst_to_dict(item)
-        return _normalize_insert_item((mapping, item), snake=snake)
-    try:
-        _ = iter(item)
-    except TypeError:
-        raise _NormalizeInsertItemError(item=item) from None
-    if all(map(_is_pair_of_tuple_or_str_mapping_and_table, item)):
-        seq = cast("Sequence[_PairOfTupleOrStrMappingAndTable]", item)
-        normalized = (_normalize_insert_item(p, snake=snake) for p in seq)
-        return list(chain.from_iterable(normalized))
-    if all(map(is_orm, item)):
-        seq = cast("Sequence[DeclarativeBase]", item)
-        normalized = (_normalize_insert_item(p, snake=snake) for p in seq)
-        return list(chain.from_iterable(normalized))
-    raise _NormalizeInsertItemError(item=item)
-
-
-@dataclass(kw_only=True, slots=True)
-class _NormalizeInsertItemError(Exception):
-    item: _InsertItem
-
-    @override
-    def __str__(self) -> str:
-        return f"Item must be valid; got {self.item}"
-
-
-@dataclass(kw_only=True, slots=True)
-class _NormalizedItem:
-    mapping: StrMapping
-    table: Table
-
-
-def _normalize_upsert_item(
-    item: _InsertItem,
-    /,
-    *,
-    snake: bool = False,
-    selected_or_all: Literal["selected", "all"] = "selected",
-) -> Iterator[_NormalizedItem]:
-    """Normalize an upsert item."""
-    normalized = _normalize_insert_item(item, snake=snake)
-    match selected_or_all:
-        case "selected":
-            for norm in normalized:
-                values = {k: v for k, v in norm.mapping.items() if v is not None}
-                yield _NormalizedItem(mapping=values, table=norm.table)
-        case "all":
-            yield from normalized
-        case _ as never:
-            assert_never(never)
-
-
-##
-
-
 def selectable_to_string(
-    selectable: Selectable[Any], engine_or_conn: _EngineOrConnectionOrAsync, /
+    selectable: Selectable[Any], engine_or_conn: EngineOrConnectionOrAsync, /
 ) -> str:
     """Convert a selectable into a string."""
     com = selectable.compile(
@@ -608,237 +911,22 @@ class TablenameMixin:
 ##
 
 
-@dataclass(kw_only=True)
-class UpsertService(Looper[_InsertItem]):
-    """Service to upsert items to a database."""
-
-    # base
-    freq: TimeDelta = field(default=SECOND, repr=False)
-    backoff: TimeDelta = field(default=SECOND, repr=False)
-    empty_upon_exit: bool = field(default=True, repr=False)
-    # self
-    engine: AsyncEngine
-    snake: bool = False
-    selected_or_all: _SelectedOrAll = "selected"
-    chunk_size_frac: float = CHUNK_SIZE_FRAC
-    assume_tables_exist: bool = False
-    timeout_create: TimeDelta | None = None
-    error_create: type[Exception] = TimeoutError
-    timeout_insert: TimeDelta | None = None
-    error_insert: type[Exception] = TimeoutError
-
-    @override
-    async def core(self) -> None:
-        await super().core()
-        await upsert_items(
-            self.engine,
-            *self.get_all_nowait(),
-            snake=self.snake,
-            selected_or_all=self.selected_or_all,
-            chunk_size_frac=self.chunk_size_frac,
-            assume_tables_exist=self.assume_tables_exist,
-            timeout_create=self.timeout_create,
-            error_create=self.error_create,
-            timeout_insert=self.timeout_insert,
-            error_insert=self.error_insert,
-        )
-
-
-@dataclass(kw_only=True)
-class UpsertServiceMixin:
-    """Mix-in for the upsert service."""
-
-    # base - looper
-    upsert_service_freq: TimeDelta = field(default=SECOND, repr=False)
-    upsert_service_backoff: TimeDelta = field(default=SECOND, repr=False)
-    upsert_service_empty_upon_exit: bool = field(default=False, repr=False)
-    upsert_service_logger: str | None = field(default=None, repr=False)
-    upsert_service_timeout: TimeDelta | None = field(default=None, repr=False)
-    upsert_service_debug: bool = field(default=False, repr=False)
-    # base - upsert service
-    upsert_service_database: AsyncEngine
-    upsert_service_snake: bool = False
-    upsert_service_selected_or_all: _SelectedOrAll = "selected"
-    upsert_service_chunk_size_frac: float = CHUNK_SIZE_FRAC
-    upsert_service_assume_tables_exist: bool = False
-    upsert_service_timeout_create: TimeDelta | None = None
-    upsert_service_error_create: type[Exception] = TimeoutError
-    upsert_service_timeout_insert: TimeDelta | None = None
-    upsert_service_error_insert: type[Exception] = TimeoutError
-    # self
-    _upsert_service: UpsertService = field(init=False, repr=False)
-
-    def __post_init__(self) -> None:
-        with suppress_super_object_attribute_error():
-            super().__post_init__()  # pyright: ignore[reportAttributeAccessIssue]
-        self._upsert_service = UpsertService(
-            # looper
-            freq=self.upsert_service_freq,
-            backoff=self.upsert_service_backoff,
-            empty_upon_exit=self.upsert_service_empty_upon_exit,
-            logger=self.upsert_service_logger,
-            timeout=self.upsert_service_timeout,
-            _debug=self.upsert_service_debug,
-            # upsert service
-            engine=self.upsert_service_database,
-            snake=self.upsert_service_snake,
-            selected_or_all=self.upsert_service_selected_or_all,
-            chunk_size_frac=self.upsert_service_chunk_size_frac,
-            assume_tables_exist=self.upsert_service_assume_tables_exist,
-            timeout_create=self.upsert_service_timeout_create,
-            error_create=self.upsert_service_error_create,
-            timeout_insert=self.upsert_service_timeout_insert,
-            error_insert=self.upsert_service_error_insert,
-        )
-
-    def _yield_sub_loopers(self) -> Iterator[Looper[Any]]:
-        with suppress_super_object_attribute_error():
-            yield from super()._yield_sub_loopers()  # pyright: ignore[reportAttributeAccessIssue]
-        yield self._upsert_service
-
-
-##
-
-
-type _SelectedOrAll = Literal["selected", "all"]
-
-
-async def upsert_items(
-    engine: AsyncEngine,
-    /,
-    *items: _InsertItem,
-    snake: bool = False,
-    selected_or_all: _SelectedOrAll = "selected",
-    chunk_size_frac: float = CHUNK_SIZE_FRAC,
-    assume_tables_exist: bool = False,
-    timeout_create: TimeDelta | None = None,
-    error_create: type[Exception] = TimeoutError,
-    timeout_insert: TimeDelta | None = None,
-    error_insert: type[Exception] = TimeoutError,
-) -> None:
-    """Upsert a set of items into a database.
-
-    These can be one of the following:
-    - pair of dict & table/class: {k1=v1, k2=v2, ...}, table_cls
-    - pair of list of dicts & table/class: [{k1=v11, k2=v12, ...},
-                                            {k1=v21, k2=v22, ...},
-                                            ...], table/class
-    - list of pairs of dict & table/class: [({k1=v11, k2=v12, ...}, table_cls1),
-                                            ({k1=v21, k2=v22, ...}, table_cls2),
-                                            ...]
-    - mapped class: Obj(k1=v1, k2=v2, ...)
-    - list of mapped classes: [Obj(k1=v11, k2=v12, ...),
-                               Obj(k1=v21, k2=v22, ...),
-                               ...]
-    """
-
-    def build_insert(
-        table: Table, values: Iterable[StrMapping], /
-    ) -> tuple[Insert, None]:
-        ups = _upsert_items_build(
-            engine, table, values, selected_or_all=selected_or_all
-        )
-        return ups, None
-
-    try:
-        prepared = _prepare_insert_or_upsert_items(
-            partial(
-                _normalize_upsert_item, snake=snake, selected_or_all=selected_or_all
-            ),
-            engine,
-            build_insert,
-            *items,
-            chunk_size_frac=chunk_size_frac,
-        )
-    except _PrepareInsertOrUpsertItemsError as error:
-        raise UpsertItemsError(item=error.item) from None
-    if not assume_tables_exist:
-        await ensure_tables_created(
-            engine, *prepared.tables, timeout=timeout_create, error=error_create
-        )
-    for ups, _ in prepared.yield_pairs():
-        async with yield_connection(
-            engine, timeout=timeout_insert, error=error_insert
-        ) as conn:
-            _ = await conn.execute(ups)
-
-
-def _upsert_items_build(
-    engine: AsyncEngine,
-    table: Table,
-    values: Iterable[StrMapping],
-    /,
-    *,
-    selected_or_all: Literal["selected", "all"] = "selected",
-) -> Insert:
-    values = list(values)
-    keys = merge_sets(*values)
-    dict_nones = dict.fromkeys(keys)
-    values = [{**dict_nones, **v} for v in values]
-    match _get_dialect(engine):
-        case "postgresql":  # skipif-ci-and-not-linux
-            insert = postgresql_insert
-        case "sqlite":
-            insert = sqlite_insert
-        case "mssql" | "mysql" | "oracle" as dialect:  # pragma: no cover
-            raise NotImplementedError(dialect)
-        case _ as never:
-            assert_never(never)
-    ins = insert(table).values(values)
-    primary_key = cast("Any", table.primary_key)
-    return _upsert_items_apply_on_conflict_do_update(
-        values, ins, primary_key, selected_or_all=selected_or_all
-    )
-
-
-def _upsert_items_apply_on_conflict_do_update(
-    values: Iterable[StrMapping],
-    insert: postgresql_Insert | sqlite_Insert,
-    primary_key: PrimaryKeyConstraint,
-    /,
-    *,
-    selected_or_all: Literal["selected", "all"] = "selected",
-) -> Insert:
-    match selected_or_all:
-        case "selected":
-            columns = merge_sets(*values)
-        case "all":
-            columns = {c.name for c in insert.excluded}
-        case _ as never:
-            assert_never(never)
-    set_ = {c: getattr(insert.excluded, c) for c in columns}
-    match insert:
-        case postgresql_Insert():  # skipif-ci
-            return insert.on_conflict_do_update(constraint=primary_key, set_=set_)
-        case sqlite_Insert():
-            return insert.on_conflict_do_update(index_elements=primary_key, set_=set_)
-        case _ as never:
-            assert_never(never)
-
-
-@dataclass(kw_only=True, slots=True)
-class UpsertItemsError(Exception):
-    item: _InsertItem
-
-    @override
-    def __str__(self) -> str:
-        return f"Item must be valid; got {self.item}"
-
-
-##
-
-
 @asynccontextmanager
 async def yield_connection(
     engine: AsyncEngine,
     /,
     *,
-    timeout: TimeDelta | None = None,
+    timeout: Delta | None = None,
     error: MaybeType[BaseException] = TimeoutError,
 ) -> AsyncIterator[AsyncConnection]:
     """Yield an async connection."""
-    async with timeout_td(timeout, error=error), engine.begin() as conn:
-        yield conn
+    try:
+        async with timeout_td(timeout, error=error), engine.begin() as conn:
+            yield conn
+    except GeneratorExit:  # pragma: no cover
+        if not is_pytest():
+            raise
+        return
 
 
 ##
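yield_connection keeps its signature (modulo the Delta/MaybeType types) but now swallows GeneratorExit when running under pytest, via the new is_pytest, so abruptly-closed async fixtures no longer surface spurious teardown errors; outside pytest the exception still propagates. Call sites are unaffected; a sketch:

from sqlalchemy import text

from utilities.sqlalchemy import create_engine, yield_connection

async def ping() -> None:
    engine = create_engine("sqlite+aiosqlite", database=":memory:", async_=True)
    async with yield_connection(engine) as conn:  # timeout=None: no deadline
        _ = await conn.execute(text("SELECT 1"))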
@@ -869,7 +957,7 @@ def _ensure_tables_maybe_reraise(error: DatabaseError, match: str, /) -> None:
 ##
 
 
-def _get_dialect(engine_or_conn: _EngineOrConnectionOrAsync, /) -> Dialect:
+def _get_dialect(engine_or_conn: EngineOrConnectionOrAsync, /) -> Dialect:
     """Get the dialect of a database."""
     dialect = engine_or_conn.dialect
     if isinstance(dialect, mssql_dialect):  # pragma: no cover
@@ -879,7 +967,7 @@ def _get_dialect(engine_or_conn: _EngineOrConnectionOrAsync, /) -> Dialect:
     if isinstance(dialect, oracle_dialect):  # pragma: no cover
         return "oracle"
     if isinstance(  # skipif-ci-and-not-linux
-        dialect, postgresql_dialect | PGDialect_asyncpg
+        dialect, (postgresql_dialect, PGDialect_asyncpg, PGDialect_psycopg)
     ):
         return "postgresql"
     if isinstance(dialect, sqlite_dialect):
@@ -892,7 +980,7 @@ def _get_dialect(engine_or_conn: _EngineOrConnectionOrAsync, /) -> Dialect:
 
 
 def _get_dialect_max_params(
-    dialect_or_engine_or_conn: Dialect | _EngineOrConnectionOrAsync, /
+    dialect_or_engine_or_conn: DialectOrEngineOrConnectionOrAsync, /
 ) -> int:
     """Get the max number of parameters of a dialect."""
     match dialect_or_engine_or_conn:
@@ -914,7 +1002,7 @@ def _get_dialect_max_params(
         ):
             dialect = _get_dialect(engine_or_conn)
             return _get_dialect_max_params(dialect)
-        case _ as never:
+        case never:
             assert_never(never)
 
 
@@ -999,7 +1087,7 @@ def _map_mapping_to_table(
 @dataclass(kw_only=True, slots=True)
 class _MapMappingToTableError(Exception):
     mapping: StrMapping
-    columns: Sequence[str]
+    columns: list[str]
 
 
 @dataclass(kw_only=True, slots=True)
@@ -1036,102 +1124,31 @@ class _MapMappingToTableSnakeMapNonUniqueError(_MapMappingToTableError):
 
 def _orm_inst_to_dict(obj: DeclarativeBase, /) -> StrMapping:
     """Map an ORM instance to a dictionary."""
-    cls = type(obj)
-
-    def is_attr(attr: str, key: str, /) -> str | None:
-        if isinstance(value := getattr(cls, attr), InstrumentedAttribute) and (
-            value.name == key
-        ):
-            return attr
-        return None
-
-    def yield_items() -> Iterator[tuple[str, Any]]:
-        for key in get_column_names(cls):
-            attr = one(attr for attr in dir(cls) if is_attr(attr, key) is not None)
-            yield key, getattr(obj, attr)
-
-    return dict(yield_items())
-
-
-##
-
-
-@dataclass(kw_only=True, slots=True)
-class _PrepareInsertOrUpsertItems:
-    mapping: dict[Table, list[StrMapping]] = field(default_factory=dict)
-    yield_pairs: Callable[[], Iterator[tuple[Insert, Any]]]
-
-    @property
-    def tables(self) -> Sequence[Table]:
-        return list(self.mapping)
-
-
-def _prepare_insert_or_upsert_items(
-    normalize_item: Callable[[_InsertItem], Iterable[_NormalizedItem]],
-    engine: AsyncEngine,
-    build_insert: Callable[[Table, Iterable[StrMapping]], tuple[Insert, Any]],
-    /,
-    *items: Any,
-    chunk_size_frac: float = CHUNK_SIZE_FRAC,
-) -> _PrepareInsertOrUpsertItems:
-    """Prepare a set of insert/upsert items."""
-    mapping: defaultdict[Table, list[StrMapping]] = defaultdict(list)
-    lengths: set[int] = set()
-    try:
-        for item in items:
-            for normed in normalize_item(item):
-                mapping[normed.table].append(normed.mapping)
-                lengths.add(len(normed.mapping))
-    except _NormalizeInsertItemError as error:
-        raise _PrepareInsertOrUpsertItemsError(item=error.item) from None
-    merged = {
-        table: _prepare_insert_or_upsert_items_merge_items(table, values)
-        for table, values in mapping.items()
+    attrs = {
+        k for k, _ in yield_object_attributes(obj, static_type=InstrumentedAttribute)
+    }
+    return {
+        name: _orm_inst_to_dict_one(obj, attrs, name) for name in get_column_names(obj)
     }
-    max_length = max(lengths, default=1)
-    chunk_size = get_chunk_size(
-        engine, chunk_size_frac=chunk_size_frac, max_length=max_length
-    )
-
-    def yield_pairs() -> Iterator[tuple[Insert, None]]:
-        for table, values in merged.items():
-            for chunk in chunked(values, chunk_size):
-                yield build_insert(table, chunk)
-
-    return _PrepareInsertOrUpsertItems(mapping=mapping, yield_pairs=yield_pairs)
 
 
-@dataclass(kw_only=True, slots=True)
-class _PrepareInsertOrUpsertItemsError(Exception):
-    item: Any
-
-    @override
-    def __str__(self) -> str:
-        return f"Item must be valid; got {self.item}"
+def _orm_inst_to_dict_one(
+    obj: DeclarativeBase, attrs: AbstractSet[str], name: str, /
+) -> Any:
+    attr = one(
+        attr for attr in attrs if _orm_inst_to_dict_predicate(type(obj), attr, name)
+    )
+    return getattr(obj, attr)
 
 
-def _prepare_insert_or_upsert_items_merge_items(
-    table: Table, items: Iterable[StrMapping], /
-) -> list[StrMapping]:
-    columns = list(yield_primary_key_columns(table))
-    col_names = [c.name for c in columns]
-    cols_auto = {c.name for c in columns if c.autoincrement in {True, "auto"}}
-    cols_non_auto = set(col_names) - cols_auto
-    mapping: defaultdict[tuple[Hashable, ...], list[StrMapping]] = defaultdict(list)
-    unchanged: list[StrMapping] = []
-    for item in items:
-        check_subset(cols_non_auto, item)
-        has_all_auto = set(cols_auto).issubset(item)
-        if has_all_auto:
-            pkey = tuple(item[k] for k in col_names)
-            rest: StrMapping = {k: v for k, v in item.items() if k not in col_names}
-            mapping[pkey].append(rest)
-        else:
-            unchanged.append(item)
-    merged = {k: merge_str_mappings(*v) for k, v in mapping.items()}
-    return [
-        dict(zip(col_names, k, strict=True)) | dict(v) for k, v in merged.items()
-    ] + unchanged
+def _orm_inst_to_dict_predicate(
+    cls: type[DeclarativeBase], attr: str, name: str, /
+) -> bool:
+    cls_attr = getattr(cls, attr)
+    try:
+        return cls_attr.name == name
+    except AttributeError:
+        return False
 
 
 ##
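The dir() scan with nested closures in _orm_inst_to_dict gives way to yield_object_attributes plus two module-level helpers; the observable contract should be unchanged: keys are column names, and each value is read through whichever class attribute's InstrumentedAttribute carries that column name. A sketch of that contract, using a hypothetical model whose attribute name differs from its column name:

from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase): ...

class Item(Base):
    __tablename__ = "item"
    id: Mapped[int] = mapped_column(primary_key=True)
    label: Mapped[str] = mapped_column("label_col")  # attribute != column name

# Expected: _orm_inst_to_dict(Item(id=1, label="x")) == {"id": 1, "label_col": "x"}
# get_column_names supplies the keys; the predicate finds the attribute ("label")
# whose InstrumentedAttribute reports .name == "label_col".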
@@ -1148,18 +1165,27 @@ def _tuple_to_mapping(
 __all__ = [
     "CHUNK_SIZE_FRAC",
     "CheckEngineError",
+    "DialectOrEngineOrConnectionOrAsync",
+    "EngineOrConnectionOrAsync",
+    "ExtractURLError",
+    "ExtractURLOutput",
     "GetTableError",
     "InsertItemsError",
     "TablenameMixin",
-    "UpsertItemsError",
-    "UpsertService",
-    "UpsertServiceMixin",
+    "check_connect",
+    "check_connect_async",
     "check_engine",
     "columnwise_max",
     "columnwise_min",
-    "create_async_engine",
+    "create_engine",
+    "ensure_database_created",
+    "ensure_database_dropped",
+    "ensure_database_users_disconnected",
     "ensure_tables_created",
     "ensure_tables_dropped",
+    "enum_name",
+    "enum_values",
+    "extract_url",
     "get_chunk_size",
     "get_column_names",
     "get_columns",
@@ -1172,7 +1198,6 @@ __all__ = [
     "is_table_or_orm",
     "migrate_data",
     "selectable_to_string",
-    "upsert_items",
     "yield_connection",
     "yield_primary_key_columns",
 ]