dycw-utilities 0.148.5__py3-none-any.whl → 0.175.31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dycw-utilities might be problematic. Click here for more details.

Files changed (84) hide show
  1. dycw_utilities-0.175.31.dist-info/METADATA +34 -0
  2. dycw_utilities-0.175.31.dist-info/RECORD +103 -0
  3. dycw_utilities-0.175.31.dist-info/WHEEL +4 -0
  4. {dycw_utilities-0.148.5.dist-info → dycw_utilities-0.175.31.dist-info}/entry_points.txt +1 -0
  5. utilities/__init__.py +1 -1
  6. utilities/altair.py +10 -7
  7. utilities/asyncio.py +113 -64
  8. utilities/atomicwrites.py +1 -1
  9. utilities/atools.py +64 -4
  10. utilities/cachetools.py +9 -6
  11. utilities/click.py +144 -49
  12. utilities/concurrent.py +1 -1
  13. utilities/contextlib.py +4 -2
  14. utilities/contextvars.py +20 -1
  15. utilities/cryptography.py +3 -3
  16. utilities/dataclasses.py +15 -28
  17. utilities/docker.py +381 -0
  18. utilities/enum.py +2 -2
  19. utilities/errors.py +1 -1
  20. utilities/fastapi.py +8 -3
  21. utilities/fpdf2.py +2 -2
  22. utilities/functions.py +20 -297
  23. utilities/git.py +19 -0
  24. utilities/grp.py +28 -0
  25. utilities/hypothesis.py +361 -79
  26. utilities/importlib.py +17 -1
  27. utilities/inflect.py +1 -1
  28. utilities/iterables.py +12 -58
  29. utilities/jinja2.py +148 -0
  30. utilities/json.py +1 -1
  31. utilities/libcst.py +7 -7
  32. utilities/logging.py +74 -85
  33. utilities/math.py +8 -4
  34. utilities/more_itertools.py +4 -6
  35. utilities/operator.py +1 -1
  36. utilities/orjson.py +86 -34
  37. utilities/os.py +49 -2
  38. utilities/parse.py +2 -2
  39. utilities/pathlib.py +66 -34
  40. utilities/permissions.py +298 -0
  41. utilities/platform.py +4 -4
  42. utilities/polars.py +934 -420
  43. utilities/polars_ols.py +1 -1
  44. utilities/postgres.py +296 -174
  45. utilities/pottery.py +8 -73
  46. utilities/pqdm.py +3 -3
  47. utilities/pwd.py +28 -0
  48. utilities/pydantic.py +11 -0
  49. utilities/pydantic_settings.py +240 -0
  50. utilities/pydantic_settings_sops.py +76 -0
  51. utilities/pyinstrument.py +5 -5
  52. utilities/pytest.py +155 -46
  53. utilities/pytest_plugins/pytest_randomly.py +1 -1
  54. utilities/pytest_plugins/pytest_regressions.py +7 -3
  55. utilities/pytest_regressions.py +27 -8
  56. utilities/random.py +11 -6
  57. utilities/re.py +1 -1
  58. utilities/redis.py +101 -64
  59. utilities/sentinel.py +10 -0
  60. utilities/shelve.py +4 -1
  61. utilities/shutil.py +25 -0
  62. utilities/slack_sdk.py +8 -3
  63. utilities/sqlalchemy.py +422 -352
  64. utilities/sqlalchemy_polars.py +28 -52
  65. utilities/string.py +1 -1
  66. utilities/subprocess.py +1947 -0
  67. utilities/tempfile.py +95 -4
  68. utilities/testbook.py +50 -0
  69. utilities/text.py +165 -42
  70. utilities/timer.py +2 -2
  71. utilities/traceback.py +46 -36
  72. utilities/types.py +62 -23
  73. utilities/typing.py +479 -19
  74. utilities/uuid.py +42 -5
  75. utilities/version.py +27 -26
  76. utilities/whenever.py +661 -151
  77. utilities/zoneinfo.py +80 -22
  78. dycw_utilities-0.148.5.dist-info/METADATA +0 -41
  79. dycw_utilities-0.148.5.dist-info/RECORD +0 -95
  80. dycw_utilities-0.148.5.dist-info/WHEEL +0 -4
  81. dycw_utilities-0.148.5.dist-info/licenses/LICENSE +0 -21
  82. utilities/eventkit.py +0 -388
  83. utilities/period.py +0 -237
  84. utilities/typed_settings.py +0 -144
utilities/sqlalchemy.py CHANGED
@@ -12,21 +12,31 @@ from collections.abc import (
12
12
  )
13
13
  from collections.abc import Set as AbstractSet
14
14
  from contextlib import asynccontextmanager
15
- from dataclasses import dataclass, field
16
- from functools import partial, reduce
15
+ from dataclasses import dataclass
16
+ from functools import reduce
17
17
  from itertools import chain
18
18
  from math import floor
19
19
  from operator import ge, le
20
20
  from re import search
21
- from typing import TYPE_CHECKING, Any, Literal, TypeGuard, assert_never, cast, override
21
+ from socket import gaierror
22
+ from typing import (
23
+ TYPE_CHECKING,
24
+ Any,
25
+ Literal,
26
+ TypeGuard,
27
+ assert_never,
28
+ cast,
29
+ overload,
30
+ override,
31
+ )
22
32
 
33
+ import sqlalchemy
23
34
  from sqlalchemy import (
24
35
  URL,
25
36
  Column,
26
37
  Connection,
27
38
  Engine,
28
39
  Insert,
29
- PrimaryKeyConstraint,
30
40
  Selectable,
31
41
  Table,
32
42
  and_,
@@ -38,16 +48,14 @@ from sqlalchemy import (
38
48
  from sqlalchemy.dialects.mssql import dialect as mssql_dialect
39
49
  from sqlalchemy.dialects.mysql import dialect as mysql_dialect
40
50
  from sqlalchemy.dialects.oracle import dialect as oracle_dialect
41
- from sqlalchemy.dialects.postgresql import Insert as postgresql_Insert
42
51
  from sqlalchemy.dialects.postgresql import dialect as postgresql_dialect
43
52
  from sqlalchemy.dialects.postgresql import insert as postgresql_insert
44
53
  from sqlalchemy.dialects.postgresql.asyncpg import PGDialect_asyncpg
45
- from sqlalchemy.dialects.sqlite import Insert as sqlite_Insert
54
+ from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg
46
55
  from sqlalchemy.dialects.sqlite import dialect as sqlite_dialect
47
56
  from sqlalchemy.dialects.sqlite import insert as sqlite_insert
48
57
  from sqlalchemy.exc import ArgumentError, DatabaseError
49
- from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
50
- from sqlalchemy.ext.asyncio import create_async_engine as _create_async_engine
58
+ from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
51
59
  from sqlalchemy.orm import (
52
60
  DeclarativeBase,
53
61
  InstrumentedAttribute,
@@ -58,15 +66,7 @@ from sqlalchemy.orm.exc import UnmappedClassError
58
66
  from sqlalchemy.pool import NullPool, Pool
59
67
 
60
68
  from utilities.asyncio import timeout_td
61
- from utilities.functions import (
62
- ensure_str,
63
- get_class_name,
64
- is_sequence_of_tuple_or_str_mapping,
65
- is_string_mapping,
66
- is_tuple,
67
- is_tuple_or_str_mapping,
68
- yield_object_attributes,
69
- )
69
+ from utilities.functions import ensure_str, get_class_name, yield_object_attributes
70
70
  from utilities.iterables import (
71
71
  CheckLengthError,
72
72
  CheckSubSetError,
@@ -79,14 +79,26 @@ from utilities.iterables import (
79
79
  merge_str_mappings,
80
80
  one,
81
81
  )
82
+ from utilities.os import is_pytest
82
83
  from utilities.reprlib import get_repr
83
- from utilities.text import snake_case
84
- from utilities.types import MaybeIterable, MaybeType, StrMapping, TupleOrStrMapping
84
+ from utilities.text import secret_str, snake_case
85
+ from utilities.types import (
86
+ Delta,
87
+ MaybeIterable,
88
+ MaybeType,
89
+ StrMapping,
90
+ TupleOrStrMapping,
91
+ )
92
+ from utilities.typing import (
93
+ is_sequence_of_tuple_or_str_mapping,
94
+ is_string_mapping,
95
+ is_tuple,
96
+ is_tuple_or_str_mapping,
97
+ )
85
98
 
86
99
  if TYPE_CHECKING:
87
100
  from enum import Enum, StrEnum
88
101
 
89
- from whenever import TimeDelta
90
102
 
91
103
  type EngineOrConnectionOrAsync = Engine | Connection | AsyncEngine | AsyncConnection
92
104
  type Dialect = Literal["mssql", "mysql", "oracle", "postgresql", "sqlite"]
@@ -99,12 +111,42 @@ CHUNK_SIZE_FRAC = 0.8
99
111
  ##
100
112
 
101
113
 
114
+ _SELECT = text("SELECT 1")
115
+
116
+
117
+ def check_connect(engine: Engine, /) -> bool:
118
+ """Check if an engine can connect."""
119
+ try:
120
+ with engine.connect() as conn:
121
+ return bool(conn.execute(_SELECT).scalar_one())
122
+ except (gaierror, ConnectionRefusedError, DatabaseError): # pragma: no cover
123
+ return False
124
+
125
+
126
+ async def check_connect_async(
127
+ engine: AsyncEngine,
128
+ /,
129
+ *,
130
+ timeout: Delta | None = None,
131
+ error: MaybeType[BaseException] = TimeoutError,
132
+ ) -> bool:
133
+ """Check if an engine can connect."""
134
+ try:
135
+ async with timeout_td(timeout, error=error), engine.connect() as conn:
136
+ return bool((await conn.execute(_SELECT)).scalar_one())
137
+ except (gaierror, ConnectionRefusedError, DatabaseError, TimeoutError):
138
+ return False
139
+
140
+
141
+ ##
142
+
143
+
102
144
  async def check_engine(
103
145
  engine: AsyncEngine,
104
146
  /,
105
147
  *,
106
- timeout: TimeDelta | None = None,
107
- error: type[Exception] = TimeoutError,
148
+ timeout: Delta | None = None,
149
+ error: MaybeType[BaseException] = TimeoutError,
108
150
  num_tables: int | tuple[int, float] | None = None,
109
151
  ) -> None:
110
152
  """Check that an engine can connect.
@@ -119,7 +161,7 @@ async def check_engine(
119
161
  query = "select * from all_objects"
120
162
  case "sqlite":
121
163
  query = "select * from sqlite_master where type='table'"
122
- case _ as never:
164
+ case never:
123
165
  assert_never(never)
124
166
  statement = text(query)
125
167
  async with yield_connection(engine, timeout=timeout, error=error) as conn:
@@ -187,7 +229,49 @@ def _columnwise_minmax(*columns: Any, op: Callable[[Any, Any], Any]) -> Any:
187
229
  ##
188
230
 
189
231
 
190
- def create_async_engine(
232
+ @overload
233
+ def create_engine(
234
+ drivername: str,
235
+ /,
236
+ *,
237
+ username: str | None = None,
238
+ password: str | None = None,
239
+ host: str | None = None,
240
+ port: int | None = None,
241
+ database: str | None = None,
242
+ query: StrMapping | None = None,
243
+ poolclass: type[Pool] | None = NullPool,
244
+ async_: Literal[True],
245
+ ) -> AsyncEngine: ...
246
+ @overload
247
+ def create_engine(
248
+ drivername: str,
249
+ /,
250
+ *,
251
+ username: str | None = None,
252
+ password: str | None = None,
253
+ host: str | None = None,
254
+ port: int | None = None,
255
+ database: str | None = None,
256
+ query: StrMapping | None = None,
257
+ poolclass: type[Pool] | None = NullPool,
258
+ async_: Literal[False] = False,
259
+ ) -> Engine: ...
260
+ @overload
261
+ def create_engine(
262
+ drivername: str,
263
+ /,
264
+ *,
265
+ username: str | None = None,
266
+ password: str | None = None,
267
+ host: str | None = None,
268
+ port: int | None = None,
269
+ database: str | None = None,
270
+ query: StrMapping | None = None,
271
+ poolclass: type[Pool] | None = NullPool,
272
+ async_: bool = False,
273
+ ) -> Engine | AsyncEngine: ...
274
+ def create_engine(
191
275
  drivername: str,
192
276
  /,
193
277
  *,
@@ -198,7 +282,8 @@ def create_async_engine(
198
282
  database: str | None = None,
199
283
  query: StrMapping | None = None,
200
284
  poolclass: type[Pool] | None = NullPool,
201
- ) -> AsyncEngine:
285
+ async_: bool = False,
286
+ ) -> Engine | AsyncEngine:
202
287
  """Create a SQLAlchemy engine."""
203
288
  if query is None:
204
289
  kwargs = {}
@@ -217,7 +302,47 @@ def create_async_engine(
217
302
  database=database,
218
303
  **kwargs,
219
304
  )
220
- return _create_async_engine(url, poolclass=poolclass)
305
+ match async_:
306
+ case False:
307
+ return sqlalchemy.create_engine(url, poolclass=poolclass)
308
+ case True:
309
+ return create_async_engine(url, poolclass=poolclass)
310
+ case never:
311
+ assert_never(never)
312
+
313
+
314
+ ##
315
+
316
+
317
+ async def ensure_database_created(super_: URL, database: str, /) -> None:
318
+ """Ensure a database is created."""
319
+ engine = create_async_engine(super_, isolation_level="AUTOCOMMIT")
320
+ async with engine.begin() as conn:
321
+ try:
322
+ _ = await conn.execute(text(f"CREATE DATABASE {database}"))
323
+ except DatabaseError as error:
324
+ _ensure_tables_maybe_reraise(error, 'database ".*" already exists')
325
+
326
+
327
+ async def ensure_database_dropped(super_: URL, database: str, /) -> None:
328
+ """Ensure a database is dropped."""
329
+ engine = create_async_engine(super_, isolation_level="AUTOCOMMIT")
330
+ async with engine.begin() as conn:
331
+ _ = await conn.execute(text(f"DROP DATABASE IF EXISTS {database}"))
332
+
333
+
334
+ async def ensure_database_users_disconnected(super_: URL, database: str, /) -> None:
335
 + """Ensure a database's users are disconnected."""
336
+ engine = create_async_engine(super_, isolation_level="AUTOCOMMIT")
337
+ match dialect := _get_dialect(engine):
338
+ case "postgresql": # skipif-ci-and-not-linux
339
+ query = f"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = {database!r} AND pid <> pg_backend_pid()" # noqa: S608
340
+ case "mssql" | "mysql" | "oracle" | "sqlite": # pragma: no cover
341
+ raise NotImplementedError(dialect)
342
+ case never:
343
+ assert_never(never)
344
+ async with engine.begin() as conn:
345
+ _ = await conn.execute(text(query))
221
346
 
222
347
 
223
348
  ##
@@ -227,8 +352,8 @@ async def ensure_tables_created(
227
352
  engine: AsyncEngine,
228
353
  /,
229
354
  *tables_or_orms: TableOrORMInstOrClass,
230
- timeout: TimeDelta | None = None,
231
- error: type[Exception] = TimeoutError,
355
+ timeout: Delta | None = None,
356
+ error: MaybeType[BaseException] = TimeoutError,
232
357
  ) -> None:
233
358
  """Ensure a table/set of tables is/are created."""
234
359
  tables = set(map(get_table, tables_or_orms))
@@ -243,7 +368,7 @@ async def ensure_tables_created(
243
368
  match = "ORA-00955: name is already used by an existing object"
244
369
  case "sqlite":
245
370
  match = "table .* already exists"
246
- case _ as never:
371
+ case never:
247
372
  assert_never(never)
248
373
  async with yield_connection(engine, timeout=timeout, error=error) as conn:
249
374
  for table in tables:
@@ -256,8 +381,8 @@ async def ensure_tables_created(
256
381
  async def ensure_tables_dropped(
257
382
  engine: AsyncEngine,
258
383
  *tables_or_orms: TableOrORMInstOrClass,
259
- timeout: TimeDelta | None = None,
260
- error: type[Exception] = TimeoutError,
384
+ timeout: Delta | None = None,
385
+ error: MaybeType[BaseException] = TimeoutError,
261
386
  ) -> None:
262
387
  """Ensure a table/set of tables is/are dropped."""
263
388
  tables = set(map(get_table, tables_or_orms))
@@ -272,7 +397,7 @@ async def ensure_tables_dropped(
272
397
  match = "ORA-00942: table or view does not exist"
273
398
  case "sqlite":
274
399
  match = "no such table"
275
- case _ as never:
400
+ case never:
276
401
  assert_never(never)
277
402
  async with yield_connection(engine, timeout=timeout, error=error) as conn:
278
403
  for table in tables:
@@ -301,6 +426,79 @@ def enum_values(enum: type[StrEnum], /) -> list[str]:
301
426
  ##
302
427
 
303
428
 
429
+ @dataclass(kw_only=True, slots=True)
430
+ class ExtractURLOutput:
431
+ username: str
432
+ password: secret_str
433
+ host: str
434
+ port: int
435
+ database: str
436
+
437
+
438
+ def extract_url(url: URL, /) -> ExtractURLOutput:
439
 + """Extract the username, password, host, port & database from a URL."""
440
+ if url.username is None:
441
+ raise _ExtractURLUsernameError(url=url)
442
+ if url.password is None:
443
+ raise _ExtractURLPasswordError(url=url)
444
+ if url.host is None:
445
+ raise _ExtractURLHostError(url=url)
446
+ if url.port is None:
447
+ raise _ExtractURLPortError(url=url)
448
+ if url.database is None:
449
+ raise _ExtractURLDatabaseError(url=url)
450
+ return ExtractURLOutput(
451
+ username=url.username,
452
+ password=secret_str(url.password),
453
+ host=url.host,
454
+ port=url.port,
455
+ database=url.database,
456
+ )
457
+
458
+
459
+ @dataclass(kw_only=True, slots=True)
460
+ class ExtractURLError(Exception):
461
+ url: URL
462
+
463
+
464
+ @dataclass(kw_only=True, slots=True)
465
+ class _ExtractURLUsernameError(ExtractURLError):
466
+ @override
467
+ def __str__(self) -> str:
468
+ return f"Expected URL to contain a user name; got {self.url}"
469
+
470
+
471
+ @dataclass(kw_only=True, slots=True)
472
+ class _ExtractURLPasswordError(ExtractURLError):
473
+ @override
474
+ def __str__(self) -> str:
475
+ return f"Expected URL to contain a password; got {self.url}"
476
+
477
+
478
+ @dataclass(kw_only=True, slots=True)
479
+ class _ExtractURLHostError(ExtractURLError):
480
+ @override
481
+ def __str__(self) -> str:
482
+ return f"Expected URL to contain a host; got {self.url}"
483
+
484
+
485
+ @dataclass(kw_only=True, slots=True)
486
+ class _ExtractURLPortError(ExtractURLError):
487
+ @override
488
+ def __str__(self) -> str:
489
+ return f"Expected URL to contain a port; got {self.url}"
490
+
491
+
492
+ @dataclass(kw_only=True, slots=True)
493
+ class _ExtractURLDatabaseError(ExtractURLError):
494
+ @override
495
+ def __str__(self) -> str:
496
+ return f"Expected URL to contain a database; got {self.url}"
497
+
498
+
499
+ ##
500
+
501
+
304
502
  def get_chunk_size(
305
503
  dialect_or_engine_or_conn: DialectOrEngineOrConnectionOrAsync,
306
504
  table_or_orm_or_num_cols: TableOrORMInstOrClass | Sized | int,
@@ -324,7 +522,7 @@ def get_chunk_size(
324
522
  case int() as num_cols:
325
523
  size = floor(chunk_size_frac * max_params / num_cols)
326
524
  return max(size, 1)
327
- case _ as never:
525
+ case never:
328
526
  assert_never(never)
329
527
 
330
528
 
@@ -405,18 +603,22 @@ type _InsertItem = (
405
603
  | Sequence[_PairOfTupleOrStrMappingAndTable]
406
604
  | Sequence[DeclarativeBase]
407
605
  )
606
+ type _NormalizedItem = tuple[Table, StrMapping]
607
+ type _InsertPair = tuple[Table, Sequence[StrMapping]]
608
+ type _InsertTriple = tuple[Table, Insert, Sequence[StrMapping] | None]
408
609
 
409
610
 
410
611
  async def insert_items(
411
612
  engine: AsyncEngine,
412
613
  *items: _InsertItem,
413
614
  snake: bool = False,
615
+ is_upsert: bool = False,
414
616
  chunk_size_frac: float = CHUNK_SIZE_FRAC,
415
617
  assume_tables_exist: bool = False,
416
- timeout_create: TimeDelta | None = None,
417
- error_create: type[Exception] = TimeoutError,
418
- timeout_insert: TimeDelta | None = None,
419
- error_insert: type[Exception] = TimeoutError,
618
+ timeout_create: Delta | None = None,
619
+ error_create: MaybeType[BaseException] = TimeoutError,
620
+ timeout_insert: Delta | None = None,
621
+ error_insert: MaybeType[BaseException] = TimeoutError,
420
622
  ) -> None:
421
623
  """Insert a set of items into a database.
422
624
 
@@ -440,37 +642,181 @@ async def insert_items(
440
642
  Obj(k1=v21, k2=v22, ...),
441
643
  ...]
442
644
  """
443
-
444
- def build_insert(
445
- table: Table, values: Iterable[StrMapping], /
446
- ) -> tuple[Insert, Any]:
447
- match _get_dialect(engine):
448
- case "oracle": # pragma: no cover
449
- return insert(table), values
450
- case _:
451
- return insert(table).values(list(values)), None
452
-
453
- try:
454
- prepared = _prepare_insert_or_upsert_items(
455
- partial(_normalize_insert_item, snake=snake),
456
- engine,
457
- build_insert,
458
- *items,
459
- chunk_size_frac=chunk_size_frac,
460
- )
461
- except _PrepareInsertOrUpsertItemsError as error:
462
- raise InsertItemsError(item=error.item) from None
645
+ normalized = chain.from_iterable(
646
+ _insert_items_yield_normalized(i, snake=snake) for i in items
647
+ )
648
+ triples = _insert_items_yield_triples(
649
+ engine, normalized, is_upsert=is_upsert, chunk_size_frac=chunk_size_frac
650
+ )
463
651
  if not assume_tables_exist:
652
+ triples = list(triples)
653
+ tables = {table for table, _, _ in triples}
464
654
  await ensure_tables_created(
465
- engine, *prepared.tables, timeout=timeout_create, error=error_create
655
+ engine, *tables, timeout=timeout_create, error=error_create
466
656
  )
467
- for ins, parameters in prepared.yield_pairs():
657
+ for _, ins, parameters in triples:
468
658
  async with yield_connection(
469
659
  engine, timeout=timeout_insert, error=error_insert
470
660
  ) as conn:
471
661
  _ = await conn.execute(ins, parameters=parameters)
472
662
 
473
663
 
664
+ def _insert_items_yield_normalized(
665
+ item: _InsertItem, /, *, snake: bool = False
666
+ ) -> Iterator[_NormalizedItem]:
667
+ if _is_pair_of_str_mapping_and_table(item):
668
+ mapping, table_or_orm = item
669
+ adjusted = _map_mapping_to_table(mapping, table_or_orm, snake=snake)
670
+ yield (get_table(table_or_orm), adjusted)
671
+ return
672
+ if _is_pair_of_tuple_and_table(item):
673
+ tuple_, table_or_orm = item
674
+ mapping = _tuple_to_mapping(tuple_, table_or_orm)
675
+ yield from _insert_items_yield_normalized((mapping, table_or_orm), snake=snake)
676
+ return
677
+ if _is_pair_of_sequence_of_tuple_or_string_mapping_and_table(item):
678
+ items, table_or_orm = item
679
+ pairs = [(i, table_or_orm) for i in items]
680
+ for p in pairs:
681
+ yield from _insert_items_yield_normalized(p, snake=snake)
682
+ return
683
+ if isinstance(item, DeclarativeBase):
684
+ mapping = _orm_inst_to_dict(item)
685
+ yield from _insert_items_yield_normalized((mapping, item), snake=snake)
686
+ return
687
+ try:
688
+ _ = iter(item)
689
+ except TypeError:
690
+ raise InsertItemsError(item=item) from None
691
+ if all(map(_is_pair_of_tuple_or_str_mapping_and_table, item)):
692
+ pairs = cast("Sequence[_PairOfTupleOrStrMappingAndTable]", item)
693
+ for p in pairs:
694
+ yield from _insert_items_yield_normalized(p, snake=snake)
695
+ return
696
+ if all(map(is_orm, item)):
697
+ classes = cast("Sequence[DeclarativeBase]", item)
698
+ for c in classes:
699
+ yield from _insert_items_yield_normalized(c, snake=snake)
700
+ return
701
+ raise InsertItemsError(item=item)
702
+
703
+
704
+ def _insert_items_yield_triples(
705
+ engine: AsyncEngine,
706
+ items: Iterable[_NormalizedItem],
707
+ /,
708
+ *,
709
+ is_upsert: bool = False,
710
+ chunk_size_frac: float = CHUNK_SIZE_FRAC,
711
+ ) -> Iterable[_InsertTriple]:
712
+ pairs = _insert_items_yield_chunked_pairs(
713
+ engine, items, is_upsert=is_upsert, chunk_size_frac=chunk_size_frac
714
+ )
715
+ for table, mappings in pairs:
716
+ match is_upsert, _get_dialect(engine):
717
+ case False, "oracle": # pragma: no cover
718
+ ins = insert(table)
719
+ parameters = mappings
720
+ case False, _:
721
+ ins = insert(table).values(mappings)
722
+ parameters = None
723
+ case True, _:
724
+ ins = _insert_items_build_insert_with_on_conflict_do_update(
725
+ engine, table, mappings
726
+ )
727
+ parameters = None
728
+ case never:
729
+ assert_never(never)
730
+ yield table, ins, parameters
731
+
732
+
733
+ def _insert_items_yield_chunked_pairs(
734
+ engine: AsyncEngine,
735
+ items: Iterable[_NormalizedItem],
736
+ /,
737
+ *,
738
+ is_upsert: bool = False,
739
+ chunk_size_frac: float = CHUNK_SIZE_FRAC,
740
+ ) -> Iterable[_InsertPair]:
741
+ for table, mappings in _insert_items_yield_raw_pairs(items, is_upsert=is_upsert):
742
+ chunk_size = get_chunk_size(engine, table, chunk_size_frac=chunk_size_frac)
743
+ for mappings_i in chunked(mappings, chunk_size):
744
+ yield table, list(mappings_i)
745
+
746
+
747
+ def _insert_items_yield_raw_pairs(
748
+ items: Iterable[_NormalizedItem], /, *, is_upsert: bool = False
749
+ ) -> Iterable[_InsertPair]:
750
+ by_table: defaultdict[Table, list[StrMapping]] = defaultdict(list)
751
+ for table, mapping in items:
752
+ by_table[table].append(mapping)
753
+ for table, mappings in by_table.items():
754
+ yield from _insert_items_yield_raw_pairs_one(
755
+ table, mappings, is_upsert=is_upsert
756
+ )
757
+
758
+
759
+ def _insert_items_yield_raw_pairs_one(
760
+ table: Table, mappings: Iterable[StrMapping], /, *, is_upsert: bool = False
761
+ ) -> Iterable[_InsertPair]:
762
+ merged = _insert_items_yield_merged_mappings(table, mappings)
763
+ match is_upsert:
764
+ case True:
765
+ by_keys: defaultdict[frozenset[str], list[StrMapping]] = defaultdict(list)
766
+ for mapping in merged:
767
+ non_null = {k: v for k, v in mapping.items() if v is not None}
768
+ by_keys[frozenset(non_null)].append(non_null)
769
+ for mappings_i in by_keys.values():
770
+ yield table, mappings_i
771
+ case False:
772
+ yield table, list(merged)
773
+ case never:
774
+ assert_never(never)
775
+
776
+
777
+ def _insert_items_yield_merged_mappings(
778
+ table: Table, mappings: Iterable[StrMapping], /
779
+ ) -> Iterable[StrMapping]:
780
+ columns = list(yield_primary_key_columns(table))
781
+ col_names = [c.name for c in columns]
782
+ cols_auto = {c.name for c in columns if c.autoincrement in {True, "auto"}}
783
+ cols_non_auto = set(col_names) - cols_auto
784
+ by_key: defaultdict[tuple[Hashable, ...], list[StrMapping]] = defaultdict(list)
785
+ for mapping in mappings:
786
+ check_subset(cols_non_auto, mapping)
787
+ has_all_auto = set(cols_auto).issubset(mapping)
788
+ if has_all_auto:
789
+ pkey = tuple(mapping[k] for k in col_names)
790
+ rest: StrMapping = {k: v for k, v in mapping.items() if k not in col_names}
791
+ by_key[pkey].append(rest)
792
+ else:
793
+ yield mapping
794
+ for k, v in by_key.items():
795
+ head = dict(zip(col_names, k, strict=True))
796
+ yield merge_str_mappings(head, *v)
797
+
798
+
799
+ def _insert_items_build_insert_with_on_conflict_do_update(
800
+ engine: AsyncEngine, table: Table, mappings: Iterable[StrMapping], /
801
+ ) -> Insert:
802
+ primary_key = cast("Any", table.primary_key)
803
+ mappings = list(mappings)
804
+ columns = merge_sets(*mappings)
805
+ match _get_dialect(engine):
806
+ case "postgresql": # skipif-ci-and-not-linux
807
+ ins = postgresql_insert(table).values(mappings)
808
+ set_ = {c: getattr(ins.excluded, c) for c in columns}
809
+ return ins.on_conflict_do_update(constraint=primary_key, set_=set_)
810
+ case "sqlite":
811
+ ins = sqlite_insert(table).values(mappings)
812
+ set_ = {c: getattr(ins.excluded, c) for c in columns}
813
+ return ins.on_conflict_do_update(index_elements=primary_key, set_=set_)
814
+ case "mssql" | "mysql" | "oracle" as dialect: # pragma: no cover
815
+ raise NotImplementedError(dialect)
816
+ case never:
817
+ assert_never(never)
818
+
819
+
474
820
  @dataclass(kw_only=True, slots=True)
475
821
  class InsertItemsError(Exception):
476
822
  item: _InsertItem
@@ -514,10 +860,10 @@ async def migrate_data(
514
860
  table_or_orm_to: TableOrORMInstOrClass | None = None,
515
861
  chunk_size_frac: float = CHUNK_SIZE_FRAC,
516
862
  assume_tables_exist: bool = False,
517
- timeout_create: TimeDelta | None = None,
518
- error_create: type[Exception] = TimeoutError,
519
- timeout_insert: TimeDelta | None = None,
520
- error_insert: type[Exception] = TimeoutError,
863
+ timeout_create: Delta | None = None,
864
+ error_create: MaybeType[BaseException] = TimeoutError,
865
+ timeout_insert: Delta | None = None,
866
+ error_insert: MaybeType[BaseException] = TimeoutError,
521
867
  ) -> None:
522
868
  """Migrate the contents of a table from one database to another."""
523
869
  table_from = get_table(table_or_orm_from)
@@ -541,80 +887,6 @@ async def migrate_data(
541
887
  ##
542
888
 
543
889
 
544
- def _normalize_insert_item(
545
- item: _InsertItem, /, *, snake: bool = False
546
- ) -> list[_NormalizedItem]:
547
- """Normalize an insertion item."""
548
- if _is_pair_of_str_mapping_and_table(item):
549
- mapping, table_or_orm = item
550
- adjusted = _map_mapping_to_table(mapping, table_or_orm, snake=snake)
551
- normalized = _NormalizedItem(mapping=adjusted, table=get_table(table_or_orm))
552
- return [normalized]
553
- if _is_pair_of_tuple_and_table(item):
554
- tuple_, table_or_orm = item
555
- mapping = _tuple_to_mapping(tuple_, table_or_orm)
556
- return _normalize_insert_item((mapping, table_or_orm), snake=snake)
557
- if _is_pair_of_sequence_of_tuple_or_string_mapping_and_table(item):
558
- items, table_or_orm = item
559
- pairs = [(i, table_or_orm) for i in items]
560
- normalized = (_normalize_insert_item(p, snake=snake) for p in pairs)
561
- return list(chain.from_iterable(normalized))
562
- if isinstance(item, DeclarativeBase):
563
- mapping = _orm_inst_to_dict(item)
564
- return _normalize_insert_item((mapping, item), snake=snake)
565
- try:
566
- _ = iter(item)
567
- except TypeError:
568
- raise _NormalizeInsertItemError(item=item) from None
569
- if all(map(_is_pair_of_tuple_or_str_mapping_and_table, item)):
570
- seq = cast("Sequence[_PairOfTupleOrStrMappingAndTable]", item)
571
- normalized = (_normalize_insert_item(p, snake=snake) for p in seq)
572
- return list(chain.from_iterable(normalized))
573
- if all(map(is_orm, item)):
574
- seq = cast("Sequence[DeclarativeBase]", item)
575
- normalized = (_normalize_insert_item(p, snake=snake) for p in seq)
576
- return list(chain.from_iterable(normalized))
577
- raise _NormalizeInsertItemError(item=item)
578
-
579
-
580
- @dataclass(kw_only=True, slots=True)
581
- class _NormalizeInsertItemError(Exception):
582
- item: _InsertItem
583
-
584
- @override
585
- def __str__(self) -> str:
586
- return f"Item must be valid; got {self.item}"
587
-
588
-
589
- @dataclass(kw_only=True, slots=True)
590
- class _NormalizedItem:
591
- mapping: StrMapping
592
- table: Table
593
-
594
-
595
- def _normalize_upsert_item(
596
- item: _InsertItem,
597
- /,
598
- *,
599
- snake: bool = False,
600
- selected_or_all: Literal["selected", "all"] = "selected",
601
- ) -> Iterator[_NormalizedItem]:
602
- """Normalize an upsert item."""
603
- normalized = _normalize_insert_item(item, snake=snake)
604
- match selected_or_all:
605
- case "selected":
606
- for norm in normalized:
607
- values = {k: v for k, v in norm.mapping.items() if v is not None}
608
- yield _NormalizedItem(mapping=values, table=norm.table)
609
- case "all":
610
- yield from normalized
611
- case _ as never:
612
- assert_never(never)
613
-
614
-
615
- ##
616
-
617
-
618
890
  def selectable_to_string(
619
891
  selectable: Selectable[Any], engine_or_conn: EngineOrConnectionOrAsync, /
620
892
  ) -> str:
@@ -639,140 +911,12 @@ class TablenameMixin:
639
911
  ##
640
912
 
641
913
 
642
- type _SelectedOrAll = Literal["selected", "all"]
643
-
644
-
645
- async def upsert_items(
646
- engine: AsyncEngine,
647
- /,
648
- *items: _InsertItem,
649
- snake: bool = False,
650
- selected_or_all: _SelectedOrAll = "selected",
651
- chunk_size_frac: float = CHUNK_SIZE_FRAC,
652
- assume_tables_exist: bool = False,
653
- timeout_create: TimeDelta | None = None,
654
- error_create: type[Exception] = TimeoutError,
655
- timeout_insert: TimeDelta | None = None,
656
- error_insert: type[Exception] = TimeoutError,
657
- ) -> None:
658
- """Upsert a set of items into a database.
659
-
660
- These can be one of the following:
661
- - pair of dict & table/class: {k1=v1, k2=v2, ...), table_cls
662
- - pair of list of dicts & table/class: [{k1=v11, k2=v12, ...},
663
- {k1=v21, k2=v22, ...},
664
- ...], table/class
665
- - list of pairs of dict & table/class: [({k1=v11, k2=v12, ...}, table_cls1),
666
- ({k1=v21, k2=v22, ...}, table_cls2),
667
- ...]
668
- - mapped class: Obj(k1=v1, k2=v2, ...)
669
- - list of mapped classes: [Obj(k1=v11, k2=v12, ...),
670
- Obj(k1=v21, k2=v22, ...),
671
- ...]
672
- """
673
-
674
- def build_insert(
675
- table: Table, values: Iterable[StrMapping], /
676
- ) -> tuple[Insert, None]:
677
- ups = _upsert_items_build(
678
- engine, table, values, selected_or_all=selected_or_all
679
- )
680
- return ups, None
681
-
682
- try:
683
- prepared = _prepare_insert_or_upsert_items(
684
- partial(
685
- _normalize_upsert_item, snake=snake, selected_or_all=selected_or_all
686
- ),
687
- engine,
688
- build_insert,
689
- *items,
690
- chunk_size_frac=chunk_size_frac,
691
- )
692
- except _PrepareInsertOrUpsertItemsError as error:
693
- raise UpsertItemsError(item=error.item) from None
694
- if not assume_tables_exist:
695
- await ensure_tables_created(
696
- engine, *prepared.tables, timeout=timeout_create, error=error_create
697
- )
698
- for ups, _ in prepared.yield_pairs():
699
- async with yield_connection(
700
- engine, timeout=timeout_insert, error=error_insert
701
- ) as conn:
702
- _ = await conn.execute(ups)
703
-
704
-
705
- def _upsert_items_build(
706
- engine: AsyncEngine,
707
- table: Table,
708
- values: Iterable[StrMapping],
709
- /,
710
- *,
711
- selected_or_all: Literal["selected", "all"] = "selected",
712
- ) -> Insert:
713
- values = list(values)
714
- keys = merge_sets(*values)
715
- dict_nones = dict.fromkeys(keys)
716
- values = [{**dict_nones, **v} for v in values]
717
- match _get_dialect(engine):
718
- case "postgresql": # skipif-ci-and-not-linux
719
- insert = postgresql_insert
720
- case "sqlite":
721
- insert = sqlite_insert
722
- case "mssql" | "mysql" | "oracle" as dialect: # pragma: no cover
723
- raise NotImplementedError(dialect)
724
- case _ as never:
725
- assert_never(never)
726
- ins = insert(table).values(values)
727
- primary_key = cast("Any", table.primary_key)
728
- return _upsert_items_apply_on_conflict_do_update(
729
- values, ins, primary_key, selected_or_all=selected_or_all
730
- )
731
-
732
-
733
- def _upsert_items_apply_on_conflict_do_update(
734
- values: Iterable[StrMapping],
735
- insert: postgresql_Insert | sqlite_Insert,
736
- primary_key: PrimaryKeyConstraint,
737
- /,
738
- *,
739
- selected_or_all: Literal["selected", "all"] = "selected",
740
- ) -> Insert:
741
- match selected_or_all:
742
- case "selected":
743
- columns = merge_sets(*values)
744
- case "all":
745
- columns = {c.name for c in insert.excluded}
746
- case _ as never:
747
- assert_never(never)
748
- set_ = {c: getattr(insert.excluded, c) for c in columns}
749
- match insert:
750
- case postgresql_Insert(): # skipif-ci
751
- return insert.on_conflict_do_update(constraint=primary_key, set_=set_)
752
- case sqlite_Insert():
753
- return insert.on_conflict_do_update(index_elements=primary_key, set_=set_)
754
- case _ as never:
755
- assert_never(never)
756
-
757
-
758
- @dataclass(kw_only=True, slots=True)
759
- class UpsertItemsError(Exception):
760
- item: _InsertItem
761
-
762
- @override
763
- def __str__(self) -> str:
764
- return f"Item must be valid; got {self.item}"
765
-
766
-
767
- ##
768
-
769
-
770
914
  @asynccontextmanager
771
915
  async def yield_connection(
772
916
  engine: AsyncEngine,
773
917
  /,
774
918
  *,
775
- timeout: TimeDelta | None = None,
919
+ timeout: Delta | None = None,
776
920
  error: MaybeType[BaseException] = TimeoutError,
777
921
  ) -> AsyncIterator[AsyncConnection]:
778
922
  """Yield an async connection."""
@@ -780,8 +924,6 @@ async def yield_connection(
780
924
  async with timeout_td(timeout, error=error), engine.begin() as conn:
781
925
  yield conn
782
926
  except GeneratorExit: # pragma: no cover
783
- from utilities.pytest import is_pytest
784
-
785
927
  if not is_pytest():
786
928
  raise
787
929
  return
@@ -825,7 +967,7 @@ def _get_dialect(engine_or_conn: EngineOrConnectionOrAsync, /) -> Dialect:
825
967
  if isinstance(dialect, oracle_dialect): # pragma: no cover
826
968
  return "oracle"
827
969
  if isinstance( # skipif-ci-and-not-linux
828
- dialect, postgresql_dialect | PGDialect_asyncpg
970
+ dialect, (postgresql_dialect, PGDialect_asyncpg, PGDialect_psycopg)
829
971
  ):
830
972
  return "postgresql"
831
973
  if isinstance(dialect, sqlite_dialect):
@@ -860,7 +1002,7 @@ def _get_dialect_max_params(
860
1002
  ):
861
1003
  dialect = _get_dialect(engine_or_conn)
862
1004
  return _get_dialect_max_params(dialect)
863
- case _ as never:
1005
+ case never:
864
1006
  assert_never(never)
865
1007
 
866
1008
 
@@ -945,7 +1087,7 @@ def _map_mapping_to_table(
945
1087
  @dataclass(kw_only=True, slots=True)
946
1088
  class _MapMappingToTableError(Exception):
947
1089
  mapping: StrMapping
948
- columns: Sequence[str]
1090
+ columns: list[str]
949
1091
 
950
1092
 
951
1093
  @dataclass(kw_only=True, slots=True)
@@ -1012,84 +1154,6 @@ def _orm_inst_to_dict_predicate(
1012
1154
  ##
1013
1155
 
1014
1156
 
1015
- @dataclass(kw_only=True, slots=True)
1016
- class _PrepareInsertOrUpsertItems:
1017
- mapping: dict[Table, list[StrMapping]] = field(default_factory=dict)
1018
- yield_pairs: Callable[[], Iterator[tuple[Insert, Any]]]
1019
-
1020
- @property
1021
- def tables(self) -> Sequence[Table]:
1022
- return list(self.mapping)
1023
-
1024
-
1025
- def _prepare_insert_or_upsert_items(
1026
- normalize_item: Callable[[_InsertItem], Iterable[_NormalizedItem]],
1027
- engine: AsyncEngine,
1028
- build_insert: Callable[[Table, Iterable[StrMapping]], tuple[Insert, Any]],
1029
- /,
1030
- *items: Any,
1031
- chunk_size_frac: float = CHUNK_SIZE_FRAC,
1032
- ) -> _PrepareInsertOrUpsertItems:
1033
- """Prepare a set of insert/upsert items."""
1034
- mapping: defaultdict[Table, list[StrMapping]] = defaultdict(list)
1035
- lengths: set[int] = set()
1036
- try:
1037
- for item in items:
1038
- for normed in normalize_item(item):
1039
- mapping[normed.table].append(normed.mapping)
1040
- lengths.add(len(normed.mapping))
1041
- except _NormalizeInsertItemError as error:
1042
- raise _PrepareInsertOrUpsertItemsError(item=error.item) from None
1043
- merged: dict[Table, list[StrMapping]] = {
1044
- table: _prepare_insert_or_upsert_items_merge_items(table, values)
1045
- for table, values in mapping.items()
1046
- }
1047
-
1048
- def yield_pairs() -> Iterator[tuple[Insert, None]]:
1049
- for table, values in merged.items():
1050
- chunk_size = get_chunk_size(engine, table, chunk_size_frac=chunk_size_frac)
1051
- for chunk in chunked(values, chunk_size):
1052
- yield build_insert(table, chunk)
1053
-
1054
- return _PrepareInsertOrUpsertItems(mapping=mapping, yield_pairs=yield_pairs)
1055
-
1056
-
1057
- @dataclass(kw_only=True, slots=True)
1058
- class _PrepareInsertOrUpsertItemsError(Exception):
1059
- item: Any
1060
-
1061
- @override
1062
- def __str__(self) -> str:
1063
- return f"Item must be valid; got {self.item}"
1064
-
1065
-
1066
- def _prepare_insert_or_upsert_items_merge_items(
1067
- table: Table, items: Iterable[StrMapping], /
1068
- ) -> list[StrMapping]:
1069
- columns = list(yield_primary_key_columns(table))
1070
- col_names = [c.name for c in columns]
1071
- cols_auto = {c.name for c in columns if c.autoincrement in {True, "auto"}}
1072
- cols_non_auto = set(col_names) - cols_auto
1073
- mapping: defaultdict[tuple[Hashable, ...], list[StrMapping]] = defaultdict(list)
1074
- unchanged: list[StrMapping] = []
1075
- for item in items:
1076
- check_subset(cols_non_auto, item)
1077
- has_all_auto = set(cols_auto).issubset(item)
1078
- if has_all_auto:
1079
- pkey = tuple(item[k] for k in col_names)
1080
- rest: StrMapping = {k: v for k, v in item.items() if k not in col_names}
1081
- mapping[pkey].append(rest)
1082
- else:
1083
- unchanged.append(item)
1084
- merged = {k: merge_str_mappings(*v) for k, v in mapping.items()}
1085
- return [
1086
- dict(zip(col_names, k, strict=True)) | dict(v) for k, v in merged.items()
1087
- ] + unchanged
1088
-
1089
-
1090
- ##
1091
-
1092
-
1093
1157
  def _tuple_to_mapping(
1094
1158
  values: tuple[Any, ...], table_or_orm: TableOrORMInstOrClass, /
1095
1159
  ) -> dict[str, Any]:
@@ -1103,18 +1167,25 @@ __all__ = [
1103
1167
  "CheckEngineError",
1104
1168
  "DialectOrEngineOrConnectionOrAsync",
1105
1169
  "EngineOrConnectionOrAsync",
1170
+ "ExtractURLError",
1171
+ "ExtractURLOutput",
1106
1172
  "GetTableError",
1107
1173
  "InsertItemsError",
1108
1174
  "TablenameMixin",
1109
- "UpsertItemsError",
1175
+ "check_connect",
1176
+ "check_connect_async",
1110
1177
  "check_engine",
1111
1178
  "columnwise_max",
1112
1179
  "columnwise_min",
1113
- "create_async_engine",
1180
+ "create_engine",
1181
+ "ensure_database_created",
1182
+ "ensure_database_dropped",
1183
+ "ensure_database_users_disconnected",
1114
1184
  "ensure_tables_created",
1115
1185
  "ensure_tables_dropped",
1116
1186
  "enum_name",
1117
1187
  "enum_values",
1188
+ "extract_url",
1118
1189
  "get_chunk_size",
1119
1190
  "get_column_names",
1120
1191
  "get_columns",
@@ -1127,7 +1198,6 @@ __all__ = [
1127
1198
  "is_table_or_orm",
1128
1199
  "migrate_data",
1129
1200
  "selectable_to_string",
1130
- "upsert_items",
1131
1201
  "yield_connection",
1132
1202
  "yield_primary_key_columns",
1133
1203
  ]