dycw-utilities 0.148.5__py3-none-any.whl → 0.175.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dycw-utilities might be problematic. Click here for more details.
- dycw_utilities-0.175.31.dist-info/METADATA +34 -0
- dycw_utilities-0.175.31.dist-info/RECORD +103 -0
- dycw_utilities-0.175.31.dist-info/WHEEL +4 -0
- {dycw_utilities-0.148.5.dist-info → dycw_utilities-0.175.31.dist-info}/entry_points.txt +1 -0
- utilities/__init__.py +1 -1
- utilities/altair.py +10 -7
- utilities/asyncio.py +113 -64
- utilities/atomicwrites.py +1 -1
- utilities/atools.py +64 -4
- utilities/cachetools.py +9 -6
- utilities/click.py +144 -49
- utilities/concurrent.py +1 -1
- utilities/contextlib.py +4 -2
- utilities/contextvars.py +20 -1
- utilities/cryptography.py +3 -3
- utilities/dataclasses.py +15 -28
- utilities/docker.py +381 -0
- utilities/enum.py +2 -2
- utilities/errors.py +1 -1
- utilities/fastapi.py +8 -3
- utilities/fpdf2.py +2 -2
- utilities/functions.py +20 -297
- utilities/git.py +19 -0
- utilities/grp.py +28 -0
- utilities/hypothesis.py +361 -79
- utilities/importlib.py +17 -1
- utilities/inflect.py +1 -1
- utilities/iterables.py +12 -58
- utilities/jinja2.py +148 -0
- utilities/json.py +1 -1
- utilities/libcst.py +7 -7
- utilities/logging.py +74 -85
- utilities/math.py +8 -4
- utilities/more_itertools.py +4 -6
- utilities/operator.py +1 -1
- utilities/orjson.py +86 -34
- utilities/os.py +49 -2
- utilities/parse.py +2 -2
- utilities/pathlib.py +66 -34
- utilities/permissions.py +298 -0
- utilities/platform.py +4 -4
- utilities/polars.py +934 -420
- utilities/polars_ols.py +1 -1
- utilities/postgres.py +296 -174
- utilities/pottery.py +8 -73
- utilities/pqdm.py +3 -3
- utilities/pwd.py +28 -0
- utilities/pydantic.py +11 -0
- utilities/pydantic_settings.py +240 -0
- utilities/pydantic_settings_sops.py +76 -0
- utilities/pyinstrument.py +5 -5
- utilities/pytest.py +155 -46
- utilities/pytest_plugins/pytest_randomly.py +1 -1
- utilities/pytest_plugins/pytest_regressions.py +7 -3
- utilities/pytest_regressions.py +27 -8
- utilities/random.py +11 -6
- utilities/re.py +1 -1
- utilities/redis.py +101 -64
- utilities/sentinel.py +10 -0
- utilities/shelve.py +4 -1
- utilities/shutil.py +25 -0
- utilities/slack_sdk.py +8 -3
- utilities/sqlalchemy.py +422 -352
- utilities/sqlalchemy_polars.py +28 -52
- utilities/string.py +1 -1
- utilities/subprocess.py +1947 -0
- utilities/tempfile.py +95 -4
- utilities/testbook.py +50 -0
- utilities/text.py +165 -42
- utilities/timer.py +2 -2
- utilities/traceback.py +46 -36
- utilities/types.py +62 -23
- utilities/typing.py +479 -19
- utilities/uuid.py +42 -5
- utilities/version.py +27 -26
- utilities/whenever.py +661 -151
- utilities/zoneinfo.py +80 -22
- dycw_utilities-0.148.5.dist-info/METADATA +0 -41
- dycw_utilities-0.148.5.dist-info/RECORD +0 -95
- dycw_utilities-0.148.5.dist-info/WHEEL +0 -4
- dycw_utilities-0.148.5.dist-info/licenses/LICENSE +0 -21
- utilities/eventkit.py +0 -388
- utilities/period.py +0 -237
- utilities/typed_settings.py +0 -144
utilities/sqlalchemy.py
CHANGED
|
@@ -12,21 +12,31 @@ from collections.abc import (
|
|
|
12
12
|
)
|
|
13
13
|
from collections.abc import Set as AbstractSet
|
|
14
14
|
from contextlib import asynccontextmanager
|
|
15
|
-
from dataclasses import dataclass
|
|
16
|
-
from functools import
|
|
15
|
+
from dataclasses import dataclass
|
|
16
|
+
from functools import reduce
|
|
17
17
|
from itertools import chain
|
|
18
18
|
from math import floor
|
|
19
19
|
from operator import ge, le
|
|
20
20
|
from re import search
|
|
21
|
-
from
|
|
21
|
+
from socket import gaierror
|
|
22
|
+
from typing import (
|
|
23
|
+
TYPE_CHECKING,
|
|
24
|
+
Any,
|
|
25
|
+
Literal,
|
|
26
|
+
TypeGuard,
|
|
27
|
+
assert_never,
|
|
28
|
+
cast,
|
|
29
|
+
overload,
|
|
30
|
+
override,
|
|
31
|
+
)
|
|
22
32
|
|
|
33
|
+
import sqlalchemy
|
|
23
34
|
from sqlalchemy import (
|
|
24
35
|
URL,
|
|
25
36
|
Column,
|
|
26
37
|
Connection,
|
|
27
38
|
Engine,
|
|
28
39
|
Insert,
|
|
29
|
-
PrimaryKeyConstraint,
|
|
30
40
|
Selectable,
|
|
31
41
|
Table,
|
|
32
42
|
and_,
|
|
@@ -38,16 +48,14 @@ from sqlalchemy import (
|
|
|
38
48
|
from sqlalchemy.dialects.mssql import dialect as mssql_dialect
|
|
39
49
|
from sqlalchemy.dialects.mysql import dialect as mysql_dialect
|
|
40
50
|
from sqlalchemy.dialects.oracle import dialect as oracle_dialect
|
|
41
|
-
from sqlalchemy.dialects.postgresql import Insert as postgresql_Insert
|
|
42
51
|
from sqlalchemy.dialects.postgresql import dialect as postgresql_dialect
|
|
43
52
|
from sqlalchemy.dialects.postgresql import insert as postgresql_insert
|
|
44
53
|
from sqlalchemy.dialects.postgresql.asyncpg import PGDialect_asyncpg
|
|
45
|
-
from sqlalchemy.dialects.
|
|
54
|
+
from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg
|
|
46
55
|
from sqlalchemy.dialects.sqlite import dialect as sqlite_dialect
|
|
47
56
|
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
|
48
57
|
from sqlalchemy.exc import ArgumentError, DatabaseError
|
|
49
|
-
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
|
|
50
|
-
from sqlalchemy.ext.asyncio import create_async_engine as _create_async_engine
|
|
58
|
+
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
|
|
51
59
|
from sqlalchemy.orm import (
|
|
52
60
|
DeclarativeBase,
|
|
53
61
|
InstrumentedAttribute,
|
|
@@ -58,15 +66,7 @@ from sqlalchemy.orm.exc import UnmappedClassError
|
|
|
58
66
|
from sqlalchemy.pool import NullPool, Pool
|
|
59
67
|
|
|
60
68
|
from utilities.asyncio import timeout_td
|
|
61
|
-
from utilities.functions import
|
|
62
|
-
ensure_str,
|
|
63
|
-
get_class_name,
|
|
64
|
-
is_sequence_of_tuple_or_str_mapping,
|
|
65
|
-
is_string_mapping,
|
|
66
|
-
is_tuple,
|
|
67
|
-
is_tuple_or_str_mapping,
|
|
68
|
-
yield_object_attributes,
|
|
69
|
-
)
|
|
69
|
+
from utilities.functions import ensure_str, get_class_name, yield_object_attributes
|
|
70
70
|
from utilities.iterables import (
|
|
71
71
|
CheckLengthError,
|
|
72
72
|
CheckSubSetError,
|
|
@@ -79,14 +79,26 @@ from utilities.iterables import (
|
|
|
79
79
|
merge_str_mappings,
|
|
80
80
|
one,
|
|
81
81
|
)
|
|
82
|
+
from utilities.os import is_pytest
|
|
82
83
|
from utilities.reprlib import get_repr
|
|
83
|
-
from utilities.text import snake_case
|
|
84
|
-
from utilities.types import
|
|
84
|
+
from utilities.text import secret_str, snake_case
|
|
85
|
+
from utilities.types import (
|
|
86
|
+
Delta,
|
|
87
|
+
MaybeIterable,
|
|
88
|
+
MaybeType,
|
|
89
|
+
StrMapping,
|
|
90
|
+
TupleOrStrMapping,
|
|
91
|
+
)
|
|
92
|
+
from utilities.typing import (
|
|
93
|
+
is_sequence_of_tuple_or_str_mapping,
|
|
94
|
+
is_string_mapping,
|
|
95
|
+
is_tuple,
|
|
96
|
+
is_tuple_or_str_mapping,
|
|
97
|
+
)
|
|
85
98
|
|
|
86
99
|
if TYPE_CHECKING:
|
|
87
100
|
from enum import Enum, StrEnum
|
|
88
101
|
|
|
89
|
-
from whenever import TimeDelta
|
|
90
102
|
|
|
91
103
|
type EngineOrConnectionOrAsync = Engine | Connection | AsyncEngine | AsyncConnection
|
|
92
104
|
type Dialect = Literal["mssql", "mysql", "oracle", "postgresql", "sqlite"]
|
|
@@ -99,12 +111,42 @@ CHUNK_SIZE_FRAC = 0.8
|
|
|
99
111
|
##
|
|
100
112
|
|
|
101
113
|
|
|
114
|
+
_SELECT = text("SELECT 1")
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def check_connect(engine: Engine, /) -> bool:
|
|
118
|
+
"""Check if an engine can connect."""
|
|
119
|
+
try:
|
|
120
|
+
with engine.connect() as conn:
|
|
121
|
+
return bool(conn.execute(_SELECT).scalar_one())
|
|
122
|
+
except (gaierror, ConnectionRefusedError, DatabaseError): # pragma: no cover
|
|
123
|
+
return False
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
async def check_connect_async(
|
|
127
|
+
engine: AsyncEngine,
|
|
128
|
+
/,
|
|
129
|
+
*,
|
|
130
|
+
timeout: Delta | None = None,
|
|
131
|
+
error: MaybeType[BaseException] = TimeoutError,
|
|
132
|
+
) -> bool:
|
|
133
|
+
"""Check if an engine can connect."""
|
|
134
|
+
try:
|
|
135
|
+
async with timeout_td(timeout, error=error), engine.connect() as conn:
|
|
136
|
+
return bool((await conn.execute(_SELECT)).scalar_one())
|
|
137
|
+
except (gaierror, ConnectionRefusedError, DatabaseError, TimeoutError):
|
|
138
|
+
return False
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
##
|
|
142
|
+
|
|
143
|
+
|
|
102
144
|
async def check_engine(
|
|
103
145
|
engine: AsyncEngine,
|
|
104
146
|
/,
|
|
105
147
|
*,
|
|
106
|
-
timeout:
|
|
107
|
-
error:
|
|
148
|
+
timeout: Delta | None = None,
|
|
149
|
+
error: MaybeType[BaseException] = TimeoutError,
|
|
108
150
|
num_tables: int | tuple[int, float] | None = None,
|
|
109
151
|
) -> None:
|
|
110
152
|
"""Check that an engine can connect.
|
|
@@ -119,7 +161,7 @@ async def check_engine(
|
|
|
119
161
|
query = "select * from all_objects"
|
|
120
162
|
case "sqlite":
|
|
121
163
|
query = "select * from sqlite_master where type='table'"
|
|
122
|
-
case
|
|
164
|
+
case never:
|
|
123
165
|
assert_never(never)
|
|
124
166
|
statement = text(query)
|
|
125
167
|
async with yield_connection(engine, timeout=timeout, error=error) as conn:
|
|
@@ -187,7 +229,49 @@ def _columnwise_minmax(*columns: Any, op: Callable[[Any, Any], Any]) -> Any:
|
|
|
187
229
|
##
|
|
188
230
|
|
|
189
231
|
|
|
190
|
-
|
|
232
|
+
@overload
|
|
233
|
+
def create_engine(
|
|
234
|
+
drivername: str,
|
|
235
|
+
/,
|
|
236
|
+
*,
|
|
237
|
+
username: str | None = None,
|
|
238
|
+
password: str | None = None,
|
|
239
|
+
host: str | None = None,
|
|
240
|
+
port: int | None = None,
|
|
241
|
+
database: str | None = None,
|
|
242
|
+
query: StrMapping | None = None,
|
|
243
|
+
poolclass: type[Pool] | None = NullPool,
|
|
244
|
+
async_: Literal[True],
|
|
245
|
+
) -> AsyncEngine: ...
|
|
246
|
+
@overload
|
|
247
|
+
def create_engine(
|
|
248
|
+
drivername: str,
|
|
249
|
+
/,
|
|
250
|
+
*,
|
|
251
|
+
username: str | None = None,
|
|
252
|
+
password: str | None = None,
|
|
253
|
+
host: str | None = None,
|
|
254
|
+
port: int | None = None,
|
|
255
|
+
database: str | None = None,
|
|
256
|
+
query: StrMapping | None = None,
|
|
257
|
+
poolclass: type[Pool] | None = NullPool,
|
|
258
|
+
async_: Literal[False] = False,
|
|
259
|
+
) -> Engine: ...
|
|
260
|
+
@overload
|
|
261
|
+
def create_engine(
|
|
262
|
+
drivername: str,
|
|
263
|
+
/,
|
|
264
|
+
*,
|
|
265
|
+
username: str | None = None,
|
|
266
|
+
password: str | None = None,
|
|
267
|
+
host: str | None = None,
|
|
268
|
+
port: int | None = None,
|
|
269
|
+
database: str | None = None,
|
|
270
|
+
query: StrMapping | None = None,
|
|
271
|
+
poolclass: type[Pool] | None = NullPool,
|
|
272
|
+
async_: bool = False,
|
|
273
|
+
) -> Engine | AsyncEngine: ...
|
|
274
|
+
def create_engine(
|
|
191
275
|
drivername: str,
|
|
192
276
|
/,
|
|
193
277
|
*,
|
|
@@ -198,7 +282,8 @@ def create_async_engine(
|
|
|
198
282
|
database: str | None = None,
|
|
199
283
|
query: StrMapping | None = None,
|
|
200
284
|
poolclass: type[Pool] | None = NullPool,
|
|
201
|
-
|
|
285
|
+
async_: bool = False,
|
|
286
|
+
) -> Engine | AsyncEngine:
|
|
202
287
|
"""Create a SQLAlchemy engine."""
|
|
203
288
|
if query is None:
|
|
204
289
|
kwargs = {}
|
|
@@ -217,7 +302,47 @@ def create_async_engine(
|
|
|
217
302
|
database=database,
|
|
218
303
|
**kwargs,
|
|
219
304
|
)
|
|
220
|
-
|
|
305
|
+
match async_:
|
|
306
|
+
case False:
|
|
307
|
+
return sqlalchemy.create_engine(url, poolclass=poolclass)
|
|
308
|
+
case True:
|
|
309
|
+
return create_async_engine(url, poolclass=poolclass)
|
|
310
|
+
case never:
|
|
311
|
+
assert_never(never)
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
##
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
async def ensure_database_created(super_: URL, database: str, /) -> None:
|
|
318
|
+
"""Ensure a database is created."""
|
|
319
|
+
engine = create_async_engine(super_, isolation_level="AUTOCOMMIT")
|
|
320
|
+
async with engine.begin() as conn:
|
|
321
|
+
try:
|
|
322
|
+
_ = await conn.execute(text(f"CREATE DATABASE {database}"))
|
|
323
|
+
except DatabaseError as error:
|
|
324
|
+
_ensure_tables_maybe_reraise(error, 'database ".*" already exists')
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
async def ensure_database_dropped(super_: URL, database: str, /) -> None:
|
|
328
|
+
"""Ensure a database is dropped."""
|
|
329
|
+
engine = create_async_engine(super_, isolation_level="AUTOCOMMIT")
|
|
330
|
+
async with engine.begin() as conn:
|
|
331
|
+
_ = await conn.execute(text(f"DROP DATABASE IF EXISTS {database}"))
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
async def ensure_database_users_disconnected(super_: URL, database: str, /) -> None:
|
|
335
|
+
"""Ensure a databases' users are disconnected."""
|
|
336
|
+
engine = create_async_engine(super_, isolation_level="AUTOCOMMIT")
|
|
337
|
+
match dialect := _get_dialect(engine):
|
|
338
|
+
case "postgresql": # skipif-ci-and-not-linux
|
|
339
|
+
query = f"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = {database!r} AND pid <> pg_backend_pid()" # noqa: S608
|
|
340
|
+
case "mssql" | "mysql" | "oracle" | "sqlite": # pragma: no cover
|
|
341
|
+
raise NotImplementedError(dialect)
|
|
342
|
+
case never:
|
|
343
|
+
assert_never(never)
|
|
344
|
+
async with engine.begin() as conn:
|
|
345
|
+
_ = await conn.execute(text(query))
|
|
221
346
|
|
|
222
347
|
|
|
223
348
|
##
|
|
@@ -227,8 +352,8 @@ async def ensure_tables_created(
|
|
|
227
352
|
engine: AsyncEngine,
|
|
228
353
|
/,
|
|
229
354
|
*tables_or_orms: TableOrORMInstOrClass,
|
|
230
|
-
timeout:
|
|
231
|
-
error:
|
|
355
|
+
timeout: Delta | None = None,
|
|
356
|
+
error: MaybeType[BaseException] = TimeoutError,
|
|
232
357
|
) -> None:
|
|
233
358
|
"""Ensure a table/set of tables is/are created."""
|
|
234
359
|
tables = set(map(get_table, tables_or_orms))
|
|
@@ -243,7 +368,7 @@ async def ensure_tables_created(
|
|
|
243
368
|
match = "ORA-00955: name is already used by an existing object"
|
|
244
369
|
case "sqlite":
|
|
245
370
|
match = "table .* already exists"
|
|
246
|
-
case
|
|
371
|
+
case never:
|
|
247
372
|
assert_never(never)
|
|
248
373
|
async with yield_connection(engine, timeout=timeout, error=error) as conn:
|
|
249
374
|
for table in tables:
|
|
@@ -256,8 +381,8 @@ async def ensure_tables_created(
|
|
|
256
381
|
async def ensure_tables_dropped(
|
|
257
382
|
engine: AsyncEngine,
|
|
258
383
|
*tables_or_orms: TableOrORMInstOrClass,
|
|
259
|
-
timeout:
|
|
260
|
-
error:
|
|
384
|
+
timeout: Delta | None = None,
|
|
385
|
+
error: MaybeType[BaseException] = TimeoutError,
|
|
261
386
|
) -> None:
|
|
262
387
|
"""Ensure a table/set of tables is/are dropped."""
|
|
263
388
|
tables = set(map(get_table, tables_or_orms))
|
|
@@ -272,7 +397,7 @@ async def ensure_tables_dropped(
|
|
|
272
397
|
match = "ORA-00942: table or view does not exist"
|
|
273
398
|
case "sqlite":
|
|
274
399
|
match = "no such table"
|
|
275
|
-
case
|
|
400
|
+
case never:
|
|
276
401
|
assert_never(never)
|
|
277
402
|
async with yield_connection(engine, timeout=timeout, error=error) as conn:
|
|
278
403
|
for table in tables:
|
|
@@ -301,6 +426,79 @@ def enum_values(enum: type[StrEnum], /) -> list[str]:
|
|
|
301
426
|
##
|
|
302
427
|
|
|
303
428
|
|
|
429
|
+
@dataclass(kw_only=True, slots=True)
|
|
430
|
+
class ExtractURLOutput:
|
|
431
|
+
username: str
|
|
432
|
+
password: secret_str
|
|
433
|
+
host: str
|
|
434
|
+
port: int
|
|
435
|
+
database: str
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
def extract_url(url: URL, /) -> ExtractURLOutput:
|
|
439
|
+
"""Extract the database, host & port from a URL."""
|
|
440
|
+
if url.username is None:
|
|
441
|
+
raise _ExtractURLUsernameError(url=url)
|
|
442
|
+
if url.password is None:
|
|
443
|
+
raise _ExtractURLPasswordError(url=url)
|
|
444
|
+
if url.host is None:
|
|
445
|
+
raise _ExtractURLHostError(url=url)
|
|
446
|
+
if url.port is None:
|
|
447
|
+
raise _ExtractURLPortError(url=url)
|
|
448
|
+
if url.database is None:
|
|
449
|
+
raise _ExtractURLDatabaseError(url=url)
|
|
450
|
+
return ExtractURLOutput(
|
|
451
|
+
username=url.username,
|
|
452
|
+
password=secret_str(url.password),
|
|
453
|
+
host=url.host,
|
|
454
|
+
port=url.port,
|
|
455
|
+
database=url.database,
|
|
456
|
+
)
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
@dataclass(kw_only=True, slots=True)
|
|
460
|
+
class ExtractURLError(Exception):
|
|
461
|
+
url: URL
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
@dataclass(kw_only=True, slots=True)
|
|
465
|
+
class _ExtractURLUsernameError(ExtractURLError):
|
|
466
|
+
@override
|
|
467
|
+
def __str__(self) -> str:
|
|
468
|
+
return f"Expected URL to contain a user name; got {self.url}"
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
@dataclass(kw_only=True, slots=True)
|
|
472
|
+
class _ExtractURLPasswordError(ExtractURLError):
|
|
473
|
+
@override
|
|
474
|
+
def __str__(self) -> str:
|
|
475
|
+
return f"Expected URL to contain a password; got {self.url}"
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
@dataclass(kw_only=True, slots=True)
|
|
479
|
+
class _ExtractURLHostError(ExtractURLError):
|
|
480
|
+
@override
|
|
481
|
+
def __str__(self) -> str:
|
|
482
|
+
return f"Expected URL to contain a host; got {self.url}"
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
@dataclass(kw_only=True, slots=True)
|
|
486
|
+
class _ExtractURLPortError(ExtractURLError):
|
|
487
|
+
@override
|
|
488
|
+
def __str__(self) -> str:
|
|
489
|
+
return f"Expected URL to contain a port; got {self.url}"
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
@dataclass(kw_only=True, slots=True)
|
|
493
|
+
class _ExtractURLDatabaseError(ExtractURLError):
|
|
494
|
+
@override
|
|
495
|
+
def __str__(self) -> str:
|
|
496
|
+
return f"Expected URL to contain a database; got {self.url}"
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
##
|
|
500
|
+
|
|
501
|
+
|
|
304
502
|
def get_chunk_size(
|
|
305
503
|
dialect_or_engine_or_conn: DialectOrEngineOrConnectionOrAsync,
|
|
306
504
|
table_or_orm_or_num_cols: TableOrORMInstOrClass | Sized | int,
|
|
@@ -324,7 +522,7 @@ def get_chunk_size(
|
|
|
324
522
|
case int() as num_cols:
|
|
325
523
|
size = floor(chunk_size_frac * max_params / num_cols)
|
|
326
524
|
return max(size, 1)
|
|
327
|
-
case
|
|
525
|
+
case never:
|
|
328
526
|
assert_never(never)
|
|
329
527
|
|
|
330
528
|
|
|
@@ -405,18 +603,22 @@ type _InsertItem = (
|
|
|
405
603
|
| Sequence[_PairOfTupleOrStrMappingAndTable]
|
|
406
604
|
| Sequence[DeclarativeBase]
|
|
407
605
|
)
|
|
606
|
+
type _NormalizedItem = tuple[Table, StrMapping]
|
|
607
|
+
type _InsertPair = tuple[Table, Sequence[StrMapping]]
|
|
608
|
+
type _InsertTriple = tuple[Table, Insert, Sequence[StrMapping] | None]
|
|
408
609
|
|
|
409
610
|
|
|
410
611
|
async def insert_items(
|
|
411
612
|
engine: AsyncEngine,
|
|
412
613
|
*items: _InsertItem,
|
|
413
614
|
snake: bool = False,
|
|
615
|
+
is_upsert: bool = False,
|
|
414
616
|
chunk_size_frac: float = CHUNK_SIZE_FRAC,
|
|
415
617
|
assume_tables_exist: bool = False,
|
|
416
|
-
timeout_create:
|
|
417
|
-
error_create:
|
|
418
|
-
timeout_insert:
|
|
419
|
-
error_insert:
|
|
618
|
+
timeout_create: Delta | None = None,
|
|
619
|
+
error_create: MaybeType[BaseException] = TimeoutError,
|
|
620
|
+
timeout_insert: Delta | None = None,
|
|
621
|
+
error_insert: MaybeType[BaseException] = TimeoutError,
|
|
420
622
|
) -> None:
|
|
421
623
|
"""Insert a set of items into a database.
|
|
422
624
|
|
|
@@ -440,37 +642,181 @@ async def insert_items(
|
|
|
440
642
|
Obj(k1=v21, k2=v22, ...),
|
|
441
643
|
...]
|
|
442
644
|
"""
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
return insert(table), values
|
|
450
|
-
case _:
|
|
451
|
-
return insert(table).values(list(values)), None
|
|
452
|
-
|
|
453
|
-
try:
|
|
454
|
-
prepared = _prepare_insert_or_upsert_items(
|
|
455
|
-
partial(_normalize_insert_item, snake=snake),
|
|
456
|
-
engine,
|
|
457
|
-
build_insert,
|
|
458
|
-
*items,
|
|
459
|
-
chunk_size_frac=chunk_size_frac,
|
|
460
|
-
)
|
|
461
|
-
except _PrepareInsertOrUpsertItemsError as error:
|
|
462
|
-
raise InsertItemsError(item=error.item) from None
|
|
645
|
+
normalized = chain.from_iterable(
|
|
646
|
+
_insert_items_yield_normalized(i, snake=snake) for i in items
|
|
647
|
+
)
|
|
648
|
+
triples = _insert_items_yield_triples(
|
|
649
|
+
engine, normalized, is_upsert=is_upsert, chunk_size_frac=chunk_size_frac
|
|
650
|
+
)
|
|
463
651
|
if not assume_tables_exist:
|
|
652
|
+
triples = list(triples)
|
|
653
|
+
tables = {table for table, _, _ in triples}
|
|
464
654
|
await ensure_tables_created(
|
|
465
|
-
engine, *
|
|
655
|
+
engine, *tables, timeout=timeout_create, error=error_create
|
|
466
656
|
)
|
|
467
|
-
for ins, parameters in
|
|
657
|
+
for _, ins, parameters in triples:
|
|
468
658
|
async with yield_connection(
|
|
469
659
|
engine, timeout=timeout_insert, error=error_insert
|
|
470
660
|
) as conn:
|
|
471
661
|
_ = await conn.execute(ins, parameters=parameters)
|
|
472
662
|
|
|
473
663
|
|
|
664
|
+
def _insert_items_yield_normalized(
|
|
665
|
+
item: _InsertItem, /, *, snake: bool = False
|
|
666
|
+
) -> Iterator[_NormalizedItem]:
|
|
667
|
+
if _is_pair_of_str_mapping_and_table(item):
|
|
668
|
+
mapping, table_or_orm = item
|
|
669
|
+
adjusted = _map_mapping_to_table(mapping, table_or_orm, snake=snake)
|
|
670
|
+
yield (get_table(table_or_orm), adjusted)
|
|
671
|
+
return
|
|
672
|
+
if _is_pair_of_tuple_and_table(item):
|
|
673
|
+
tuple_, table_or_orm = item
|
|
674
|
+
mapping = _tuple_to_mapping(tuple_, table_or_orm)
|
|
675
|
+
yield from _insert_items_yield_normalized((mapping, table_or_orm), snake=snake)
|
|
676
|
+
return
|
|
677
|
+
if _is_pair_of_sequence_of_tuple_or_string_mapping_and_table(item):
|
|
678
|
+
items, table_or_orm = item
|
|
679
|
+
pairs = [(i, table_or_orm) for i in items]
|
|
680
|
+
for p in pairs:
|
|
681
|
+
yield from _insert_items_yield_normalized(p, snake=snake)
|
|
682
|
+
return
|
|
683
|
+
if isinstance(item, DeclarativeBase):
|
|
684
|
+
mapping = _orm_inst_to_dict(item)
|
|
685
|
+
yield from _insert_items_yield_normalized((mapping, item), snake=snake)
|
|
686
|
+
return
|
|
687
|
+
try:
|
|
688
|
+
_ = iter(item)
|
|
689
|
+
except TypeError:
|
|
690
|
+
raise InsertItemsError(item=item) from None
|
|
691
|
+
if all(map(_is_pair_of_tuple_or_str_mapping_and_table, item)):
|
|
692
|
+
pairs = cast("Sequence[_PairOfTupleOrStrMappingAndTable]", item)
|
|
693
|
+
for p in pairs:
|
|
694
|
+
yield from _insert_items_yield_normalized(p, snake=snake)
|
|
695
|
+
return
|
|
696
|
+
if all(map(is_orm, item)):
|
|
697
|
+
classes = cast("Sequence[DeclarativeBase]", item)
|
|
698
|
+
for c in classes:
|
|
699
|
+
yield from _insert_items_yield_normalized(c, snake=snake)
|
|
700
|
+
return
|
|
701
|
+
raise InsertItemsError(item=item)
|
|
702
|
+
|
|
703
|
+
|
|
704
|
+
def _insert_items_yield_triples(
|
|
705
|
+
engine: AsyncEngine,
|
|
706
|
+
items: Iterable[_NormalizedItem],
|
|
707
|
+
/,
|
|
708
|
+
*,
|
|
709
|
+
is_upsert: bool = False,
|
|
710
|
+
chunk_size_frac: float = CHUNK_SIZE_FRAC,
|
|
711
|
+
) -> Iterable[_InsertTriple]:
|
|
712
|
+
pairs = _insert_items_yield_chunked_pairs(
|
|
713
|
+
engine, items, is_upsert=is_upsert, chunk_size_frac=chunk_size_frac
|
|
714
|
+
)
|
|
715
|
+
for table, mappings in pairs:
|
|
716
|
+
match is_upsert, _get_dialect(engine):
|
|
717
|
+
case False, "oracle": # pragma: no cover
|
|
718
|
+
ins = insert(table)
|
|
719
|
+
parameters = mappings
|
|
720
|
+
case False, _:
|
|
721
|
+
ins = insert(table).values(mappings)
|
|
722
|
+
parameters = None
|
|
723
|
+
case True, _:
|
|
724
|
+
ins = _insert_items_build_insert_with_on_conflict_do_update(
|
|
725
|
+
engine, table, mappings
|
|
726
|
+
)
|
|
727
|
+
parameters = None
|
|
728
|
+
case never:
|
|
729
|
+
assert_never(never)
|
|
730
|
+
yield table, ins, parameters
|
|
731
|
+
|
|
732
|
+
|
|
733
|
+
def _insert_items_yield_chunked_pairs(
|
|
734
|
+
engine: AsyncEngine,
|
|
735
|
+
items: Iterable[_NormalizedItem],
|
|
736
|
+
/,
|
|
737
|
+
*,
|
|
738
|
+
is_upsert: bool = False,
|
|
739
|
+
chunk_size_frac: float = CHUNK_SIZE_FRAC,
|
|
740
|
+
) -> Iterable[_InsertPair]:
|
|
741
|
+
for table, mappings in _insert_items_yield_raw_pairs(items, is_upsert=is_upsert):
|
|
742
|
+
chunk_size = get_chunk_size(engine, table, chunk_size_frac=chunk_size_frac)
|
|
743
|
+
for mappings_i in chunked(mappings, chunk_size):
|
|
744
|
+
yield table, list(mappings_i)
|
|
745
|
+
|
|
746
|
+
|
|
747
|
+
def _insert_items_yield_raw_pairs(
|
|
748
|
+
items: Iterable[_NormalizedItem], /, *, is_upsert: bool = False
|
|
749
|
+
) -> Iterable[_InsertPair]:
|
|
750
|
+
by_table: defaultdict[Table, list[StrMapping]] = defaultdict(list)
|
|
751
|
+
for table, mapping in items:
|
|
752
|
+
by_table[table].append(mapping)
|
|
753
|
+
for table, mappings in by_table.items():
|
|
754
|
+
yield from _insert_items_yield_raw_pairs_one(
|
|
755
|
+
table, mappings, is_upsert=is_upsert
|
|
756
|
+
)
|
|
757
|
+
|
|
758
|
+
|
|
759
|
+
def _insert_items_yield_raw_pairs_one(
|
|
760
|
+
table: Table, mappings: Iterable[StrMapping], /, *, is_upsert: bool = False
|
|
761
|
+
) -> Iterable[_InsertPair]:
|
|
762
|
+
merged = _insert_items_yield_merged_mappings(table, mappings)
|
|
763
|
+
match is_upsert:
|
|
764
|
+
case True:
|
|
765
|
+
by_keys: defaultdict[frozenset[str], list[StrMapping]] = defaultdict(list)
|
|
766
|
+
for mapping in merged:
|
|
767
|
+
non_null = {k: v for k, v in mapping.items() if v is not None}
|
|
768
|
+
by_keys[frozenset(non_null)].append(non_null)
|
|
769
|
+
for mappings_i in by_keys.values():
|
|
770
|
+
yield table, mappings_i
|
|
771
|
+
case False:
|
|
772
|
+
yield table, list(merged)
|
|
773
|
+
case never:
|
|
774
|
+
assert_never(never)
|
|
775
|
+
|
|
776
|
+
|
|
777
|
+
def _insert_items_yield_merged_mappings(
|
|
778
|
+
table: Table, mappings: Iterable[StrMapping], /
|
|
779
|
+
) -> Iterable[StrMapping]:
|
|
780
|
+
columns = list(yield_primary_key_columns(table))
|
|
781
|
+
col_names = [c.name for c in columns]
|
|
782
|
+
cols_auto = {c.name for c in columns if c.autoincrement in {True, "auto"}}
|
|
783
|
+
cols_non_auto = set(col_names) - cols_auto
|
|
784
|
+
by_key: defaultdict[tuple[Hashable, ...], list[StrMapping]] = defaultdict(list)
|
|
785
|
+
for mapping in mappings:
|
|
786
|
+
check_subset(cols_non_auto, mapping)
|
|
787
|
+
has_all_auto = set(cols_auto).issubset(mapping)
|
|
788
|
+
if has_all_auto:
|
|
789
|
+
pkey = tuple(mapping[k] for k in col_names)
|
|
790
|
+
rest: StrMapping = {k: v for k, v in mapping.items() if k not in col_names}
|
|
791
|
+
by_key[pkey].append(rest)
|
|
792
|
+
else:
|
|
793
|
+
yield mapping
|
|
794
|
+
for k, v in by_key.items():
|
|
795
|
+
head = dict(zip(col_names, k, strict=True))
|
|
796
|
+
yield merge_str_mappings(head, *v)
|
|
797
|
+
|
|
798
|
+
|
|
799
|
+
def _insert_items_build_insert_with_on_conflict_do_update(
|
|
800
|
+
engine: AsyncEngine, table: Table, mappings: Iterable[StrMapping], /
|
|
801
|
+
) -> Insert:
|
|
802
|
+
primary_key = cast("Any", table.primary_key)
|
|
803
|
+
mappings = list(mappings)
|
|
804
|
+
columns = merge_sets(*mappings)
|
|
805
|
+
match _get_dialect(engine):
|
|
806
|
+
case "postgresql": # skipif-ci-and-not-linux
|
|
807
|
+
ins = postgresql_insert(table).values(mappings)
|
|
808
|
+
set_ = {c: getattr(ins.excluded, c) for c in columns}
|
|
809
|
+
return ins.on_conflict_do_update(constraint=primary_key, set_=set_)
|
|
810
|
+
case "sqlite":
|
|
811
|
+
ins = sqlite_insert(table).values(mappings)
|
|
812
|
+
set_ = {c: getattr(ins.excluded, c) for c in columns}
|
|
813
|
+
return ins.on_conflict_do_update(index_elements=primary_key, set_=set_)
|
|
814
|
+
case "mssql" | "mysql" | "oracle" as dialect: # pragma: no cover
|
|
815
|
+
raise NotImplementedError(dialect)
|
|
816
|
+
case never:
|
|
817
|
+
assert_never(never)
|
|
818
|
+
|
|
819
|
+
|
|
474
820
|
@dataclass(kw_only=True, slots=True)
|
|
475
821
|
class InsertItemsError(Exception):
|
|
476
822
|
item: _InsertItem
|
|
@@ -514,10 +860,10 @@ async def migrate_data(
|
|
|
514
860
|
table_or_orm_to: TableOrORMInstOrClass | None = None,
|
|
515
861
|
chunk_size_frac: float = CHUNK_SIZE_FRAC,
|
|
516
862
|
assume_tables_exist: bool = False,
|
|
517
|
-
timeout_create:
|
|
518
|
-
error_create:
|
|
519
|
-
timeout_insert:
|
|
520
|
-
error_insert:
|
|
863
|
+
timeout_create: Delta | None = None,
|
|
864
|
+
error_create: MaybeType[BaseException] = TimeoutError,
|
|
865
|
+
timeout_insert: Delta | None = None,
|
|
866
|
+
error_insert: MaybeType[BaseException] = TimeoutError,
|
|
521
867
|
) -> None:
|
|
522
868
|
"""Migrate the contents of a table from one database to another."""
|
|
523
869
|
table_from = get_table(table_or_orm_from)
|
|
@@ -541,80 +887,6 @@ async def migrate_data(
|
|
|
541
887
|
##
|
|
542
888
|
|
|
543
889
|
|
|
544
|
-
def _normalize_insert_item(
|
|
545
|
-
item: _InsertItem, /, *, snake: bool = False
|
|
546
|
-
) -> list[_NormalizedItem]:
|
|
547
|
-
"""Normalize an insertion item."""
|
|
548
|
-
if _is_pair_of_str_mapping_and_table(item):
|
|
549
|
-
mapping, table_or_orm = item
|
|
550
|
-
adjusted = _map_mapping_to_table(mapping, table_or_orm, snake=snake)
|
|
551
|
-
normalized = _NormalizedItem(mapping=adjusted, table=get_table(table_or_orm))
|
|
552
|
-
return [normalized]
|
|
553
|
-
if _is_pair_of_tuple_and_table(item):
|
|
554
|
-
tuple_, table_or_orm = item
|
|
555
|
-
mapping = _tuple_to_mapping(tuple_, table_or_orm)
|
|
556
|
-
return _normalize_insert_item((mapping, table_or_orm), snake=snake)
|
|
557
|
-
if _is_pair_of_sequence_of_tuple_or_string_mapping_and_table(item):
|
|
558
|
-
items, table_or_orm = item
|
|
559
|
-
pairs = [(i, table_or_orm) for i in items]
|
|
560
|
-
normalized = (_normalize_insert_item(p, snake=snake) for p in pairs)
|
|
561
|
-
return list(chain.from_iterable(normalized))
|
|
562
|
-
if isinstance(item, DeclarativeBase):
|
|
563
|
-
mapping = _orm_inst_to_dict(item)
|
|
564
|
-
return _normalize_insert_item((mapping, item), snake=snake)
|
|
565
|
-
try:
|
|
566
|
-
_ = iter(item)
|
|
567
|
-
except TypeError:
|
|
568
|
-
raise _NormalizeInsertItemError(item=item) from None
|
|
569
|
-
if all(map(_is_pair_of_tuple_or_str_mapping_and_table, item)):
|
|
570
|
-
seq = cast("Sequence[_PairOfTupleOrStrMappingAndTable]", item)
|
|
571
|
-
normalized = (_normalize_insert_item(p, snake=snake) for p in seq)
|
|
572
|
-
return list(chain.from_iterable(normalized))
|
|
573
|
-
if all(map(is_orm, item)):
|
|
574
|
-
seq = cast("Sequence[DeclarativeBase]", item)
|
|
575
|
-
normalized = (_normalize_insert_item(p, snake=snake) for p in seq)
|
|
576
|
-
return list(chain.from_iterable(normalized))
|
|
577
|
-
raise _NormalizeInsertItemError(item=item)
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
@dataclass(kw_only=True, slots=True)
|
|
581
|
-
class _NormalizeInsertItemError(Exception):
|
|
582
|
-
item: _InsertItem
|
|
583
|
-
|
|
584
|
-
@override
|
|
585
|
-
def __str__(self) -> str:
|
|
586
|
-
return f"Item must be valid; got {self.item}"
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
@dataclass(kw_only=True, slots=True)
|
|
590
|
-
class _NormalizedItem:
|
|
591
|
-
mapping: StrMapping
|
|
592
|
-
table: Table
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
def _normalize_upsert_item(
|
|
596
|
-
item: _InsertItem,
|
|
597
|
-
/,
|
|
598
|
-
*,
|
|
599
|
-
snake: bool = False,
|
|
600
|
-
selected_or_all: Literal["selected", "all"] = "selected",
|
|
601
|
-
) -> Iterator[_NormalizedItem]:
|
|
602
|
-
"""Normalize an upsert item."""
|
|
603
|
-
normalized = _normalize_insert_item(item, snake=snake)
|
|
604
|
-
match selected_or_all:
|
|
605
|
-
case "selected":
|
|
606
|
-
for norm in normalized:
|
|
607
|
-
values = {k: v for k, v in norm.mapping.items() if v is not None}
|
|
608
|
-
yield _NormalizedItem(mapping=values, table=norm.table)
|
|
609
|
-
case "all":
|
|
610
|
-
yield from normalized
|
|
611
|
-
case _ as never:
|
|
612
|
-
assert_never(never)
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
##
|
|
616
|
-
|
|
617
|
-
|
|
618
890
|
def selectable_to_string(
|
|
619
891
|
selectable: Selectable[Any], engine_or_conn: EngineOrConnectionOrAsync, /
|
|
620
892
|
) -> str:
|
|
@@ -639,140 +911,12 @@ class TablenameMixin:
|
|
|
639
911
|
##
|
|
640
912
|
|
|
641
913
|
|
|
642
|
-
type _SelectedOrAll = Literal["selected", "all"]
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
async def upsert_items(
|
|
646
|
-
engine: AsyncEngine,
|
|
647
|
-
/,
|
|
648
|
-
*items: _InsertItem,
|
|
649
|
-
snake: bool = False,
|
|
650
|
-
selected_or_all: _SelectedOrAll = "selected",
|
|
651
|
-
chunk_size_frac: float = CHUNK_SIZE_FRAC,
|
|
652
|
-
assume_tables_exist: bool = False,
|
|
653
|
-
timeout_create: TimeDelta | None = None,
|
|
654
|
-
error_create: type[Exception] = TimeoutError,
|
|
655
|
-
timeout_insert: TimeDelta | None = None,
|
|
656
|
-
error_insert: type[Exception] = TimeoutError,
|
|
657
|
-
) -> None:
|
|
658
|
-
"""Upsert a set of items into a database.
|
|
659
|
-
|
|
660
|
-
These can be one of the following:
|
|
661
|
-
- pair of dict & table/class: {k1=v1, k2=v2, ...), table_cls
|
|
662
|
-
- pair of list of dicts & table/class: [{k1=v11, k2=v12, ...},
|
|
663
|
-
{k1=v21, k2=v22, ...},
|
|
664
|
-
...], table/class
|
|
665
|
-
- list of pairs of dict & table/class: [({k1=v11, k2=v12, ...}, table_cls1),
|
|
666
|
-
({k1=v21, k2=v22, ...}, table_cls2),
|
|
667
|
-
...]
|
|
668
|
-
- mapped class: Obj(k1=v1, k2=v2, ...)
|
|
669
|
-
- list of mapped classes: [Obj(k1=v11, k2=v12, ...),
|
|
670
|
-
Obj(k1=v21, k2=v22, ...),
|
|
671
|
-
...]
|
|
672
|
-
"""
|
|
673
|
-
|
|
674
|
-
def build_insert(
|
|
675
|
-
table: Table, values: Iterable[StrMapping], /
|
|
676
|
-
) -> tuple[Insert, None]:
|
|
677
|
-
ups = _upsert_items_build(
|
|
678
|
-
engine, table, values, selected_or_all=selected_or_all
|
|
679
|
-
)
|
|
680
|
-
return ups, None
|
|
681
|
-
|
|
682
|
-
try:
|
|
683
|
-
prepared = _prepare_insert_or_upsert_items(
|
|
684
|
-
partial(
|
|
685
|
-
_normalize_upsert_item, snake=snake, selected_or_all=selected_or_all
|
|
686
|
-
),
|
|
687
|
-
engine,
|
|
688
|
-
build_insert,
|
|
689
|
-
*items,
|
|
690
|
-
chunk_size_frac=chunk_size_frac,
|
|
691
|
-
)
|
|
692
|
-
except _PrepareInsertOrUpsertItemsError as error:
|
|
693
|
-
raise UpsertItemsError(item=error.item) from None
|
|
694
|
-
if not assume_tables_exist:
|
|
695
|
-
await ensure_tables_created(
|
|
696
|
-
engine, *prepared.tables, timeout=timeout_create, error=error_create
|
|
697
|
-
)
|
|
698
|
-
for ups, _ in prepared.yield_pairs():
|
|
699
|
-
async with yield_connection(
|
|
700
|
-
engine, timeout=timeout_insert, error=error_insert
|
|
701
|
-
) as conn:
|
|
702
|
-
_ = await conn.execute(ups)
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
def _upsert_items_build(
|
|
706
|
-
engine: AsyncEngine,
|
|
707
|
-
table: Table,
|
|
708
|
-
values: Iterable[StrMapping],
|
|
709
|
-
/,
|
|
710
|
-
*,
|
|
711
|
-
selected_or_all: Literal["selected", "all"] = "selected",
|
|
712
|
-
) -> Insert:
|
|
713
|
-
values = list(values)
|
|
714
|
-
keys = merge_sets(*values)
|
|
715
|
-
dict_nones = dict.fromkeys(keys)
|
|
716
|
-
values = [{**dict_nones, **v} for v in values]
|
|
717
|
-
match _get_dialect(engine):
|
|
718
|
-
case "postgresql": # skipif-ci-and-not-linux
|
|
719
|
-
insert = postgresql_insert
|
|
720
|
-
case "sqlite":
|
|
721
|
-
insert = sqlite_insert
|
|
722
|
-
case "mssql" | "mysql" | "oracle" as dialect: # pragma: no cover
|
|
723
|
-
raise NotImplementedError(dialect)
|
|
724
|
-
case _ as never:
|
|
725
|
-
assert_never(never)
|
|
726
|
-
ins = insert(table).values(values)
|
|
727
|
-
primary_key = cast("Any", table.primary_key)
|
|
728
|
-
return _upsert_items_apply_on_conflict_do_update(
|
|
729
|
-
values, ins, primary_key, selected_or_all=selected_or_all
|
|
730
|
-
)
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
def _upsert_items_apply_on_conflict_do_update(
|
|
734
|
-
values: Iterable[StrMapping],
|
|
735
|
-
insert: postgresql_Insert | sqlite_Insert,
|
|
736
|
-
primary_key: PrimaryKeyConstraint,
|
|
737
|
-
/,
|
|
738
|
-
*,
|
|
739
|
-
selected_or_all: Literal["selected", "all"] = "selected",
|
|
740
|
-
) -> Insert:
|
|
741
|
-
match selected_or_all:
|
|
742
|
-
case "selected":
|
|
743
|
-
columns = merge_sets(*values)
|
|
744
|
-
case "all":
|
|
745
|
-
columns = {c.name for c in insert.excluded}
|
|
746
|
-
case _ as never:
|
|
747
|
-
assert_never(never)
|
|
748
|
-
set_ = {c: getattr(insert.excluded, c) for c in columns}
|
|
749
|
-
match insert:
|
|
750
|
-
case postgresql_Insert(): # skipif-ci
|
|
751
|
-
return insert.on_conflict_do_update(constraint=primary_key, set_=set_)
|
|
752
|
-
case sqlite_Insert():
|
|
753
|
-
return insert.on_conflict_do_update(index_elements=primary_key, set_=set_)
|
|
754
|
-
case _ as never:
|
|
755
|
-
assert_never(never)
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
@dataclass(kw_only=True, slots=True)
|
|
759
|
-
class UpsertItemsError(Exception):
|
|
760
|
-
item: _InsertItem
|
|
761
|
-
|
|
762
|
-
@override
|
|
763
|
-
def __str__(self) -> str:
|
|
764
|
-
return f"Item must be valid; got {self.item}"
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
##
|
|
768
|
-
|
|
769
|
-
|
|
770
914
|
@asynccontextmanager
|
|
771
915
|
async def yield_connection(
|
|
772
916
|
engine: AsyncEngine,
|
|
773
917
|
/,
|
|
774
918
|
*,
|
|
775
|
-
timeout:
|
|
919
|
+
timeout: Delta | None = None,
|
|
776
920
|
error: MaybeType[BaseException] = TimeoutError,
|
|
777
921
|
) -> AsyncIterator[AsyncConnection]:
|
|
778
922
|
"""Yield an async connection."""
|
|
@@ -780,8 +924,6 @@ async def yield_connection(
|
|
|
780
924
|
async with timeout_td(timeout, error=error), engine.begin() as conn:
|
|
781
925
|
yield conn
|
|
782
926
|
except GeneratorExit: # pragma: no cover
|
|
783
|
-
from utilities.pytest import is_pytest
|
|
784
|
-
|
|
785
927
|
if not is_pytest():
|
|
786
928
|
raise
|
|
787
929
|
return
|
|
@@ -825,7 +967,7 @@ def _get_dialect(engine_or_conn: EngineOrConnectionOrAsync, /) -> Dialect:
|
|
|
825
967
|
if isinstance(dialect, oracle_dialect): # pragma: no cover
|
|
826
968
|
return "oracle"
|
|
827
969
|
if isinstance( # skipif-ci-and-not-linux
|
|
828
|
-
dialect, postgresql_dialect
|
|
970
|
+
dialect, (postgresql_dialect, PGDialect_asyncpg, PGDialect_psycopg)
|
|
829
971
|
):
|
|
830
972
|
return "postgresql"
|
|
831
973
|
if isinstance(dialect, sqlite_dialect):
|
|
@@ -860,7 +1002,7 @@ def _get_dialect_max_params(
|
|
|
860
1002
|
):
|
|
861
1003
|
dialect = _get_dialect(engine_or_conn)
|
|
862
1004
|
return _get_dialect_max_params(dialect)
|
|
863
|
-
case
|
|
1005
|
+
case never:
|
|
864
1006
|
assert_never(never)
|
|
865
1007
|
|
|
866
1008
|
|
|
@@ -945,7 +1087,7 @@ def _map_mapping_to_table(
|
|
|
945
1087
|
@dataclass(kw_only=True, slots=True)
|
|
946
1088
|
class _MapMappingToTableError(Exception):
|
|
947
1089
|
mapping: StrMapping
|
|
948
|
-
columns:
|
|
1090
|
+
columns: list[str]
|
|
949
1091
|
|
|
950
1092
|
|
|
951
1093
|
@dataclass(kw_only=True, slots=True)
|
|
@@ -1012,84 +1154,6 @@ def _orm_inst_to_dict_predicate(
|
|
|
1012
1154
|
##
|
|
1013
1155
|
|
|
1014
1156
|
|
|
1015
|
-
@dataclass(kw_only=True, slots=True)
|
|
1016
|
-
class _PrepareInsertOrUpsertItems:
|
|
1017
|
-
mapping: dict[Table, list[StrMapping]] = field(default_factory=dict)
|
|
1018
|
-
yield_pairs: Callable[[], Iterator[tuple[Insert, Any]]]
|
|
1019
|
-
|
|
1020
|
-
@property
|
|
1021
|
-
def tables(self) -> Sequence[Table]:
|
|
1022
|
-
return list(self.mapping)
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
def _prepare_insert_or_upsert_items(
|
|
1026
|
-
normalize_item: Callable[[_InsertItem], Iterable[_NormalizedItem]],
|
|
1027
|
-
engine: AsyncEngine,
|
|
1028
|
-
build_insert: Callable[[Table, Iterable[StrMapping]], tuple[Insert, Any]],
|
|
1029
|
-
/,
|
|
1030
|
-
*items: Any,
|
|
1031
|
-
chunk_size_frac: float = CHUNK_SIZE_FRAC,
|
|
1032
|
-
) -> _PrepareInsertOrUpsertItems:
|
|
1033
|
-
"""Prepare a set of insert/upsert items."""
|
|
1034
|
-
mapping: defaultdict[Table, list[StrMapping]] = defaultdict(list)
|
|
1035
|
-
lengths: set[int] = set()
|
|
1036
|
-
try:
|
|
1037
|
-
for item in items:
|
|
1038
|
-
for normed in normalize_item(item):
|
|
1039
|
-
mapping[normed.table].append(normed.mapping)
|
|
1040
|
-
lengths.add(len(normed.mapping))
|
|
1041
|
-
except _NormalizeInsertItemError as error:
|
|
1042
|
-
raise _PrepareInsertOrUpsertItemsError(item=error.item) from None
|
|
1043
|
-
merged: dict[Table, list[StrMapping]] = {
|
|
1044
|
-
table: _prepare_insert_or_upsert_items_merge_items(table, values)
|
|
1045
|
-
for table, values in mapping.items()
|
|
1046
|
-
}
|
|
1047
|
-
|
|
1048
|
-
def yield_pairs() -> Iterator[tuple[Insert, None]]:
|
|
1049
|
-
for table, values in merged.items():
|
|
1050
|
-
chunk_size = get_chunk_size(engine, table, chunk_size_frac=chunk_size_frac)
|
|
1051
|
-
for chunk in chunked(values, chunk_size):
|
|
1052
|
-
yield build_insert(table, chunk)
|
|
1053
|
-
|
|
1054
|
-
return _PrepareInsertOrUpsertItems(mapping=mapping, yield_pairs=yield_pairs)
|
|
1055
|
-
|
|
1056
|
-
|
|
1057
|
-
@dataclass(kw_only=True, slots=True)
|
|
1058
|
-
class _PrepareInsertOrUpsertItemsError(Exception):
|
|
1059
|
-
item: Any
|
|
1060
|
-
|
|
1061
|
-
@override
|
|
1062
|
-
def __str__(self) -> str:
|
|
1063
|
-
return f"Item must be valid; got {self.item}"
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
def _prepare_insert_or_upsert_items_merge_items(
|
|
1067
|
-
table: Table, items: Iterable[StrMapping], /
|
|
1068
|
-
) -> list[StrMapping]:
|
|
1069
|
-
columns = list(yield_primary_key_columns(table))
|
|
1070
|
-
col_names = [c.name for c in columns]
|
|
1071
|
-
cols_auto = {c.name for c in columns if c.autoincrement in {True, "auto"}}
|
|
1072
|
-
cols_non_auto = set(col_names) - cols_auto
|
|
1073
|
-
mapping: defaultdict[tuple[Hashable, ...], list[StrMapping]] = defaultdict(list)
|
|
1074
|
-
unchanged: list[StrMapping] = []
|
|
1075
|
-
for item in items:
|
|
1076
|
-
check_subset(cols_non_auto, item)
|
|
1077
|
-
has_all_auto = set(cols_auto).issubset(item)
|
|
1078
|
-
if has_all_auto:
|
|
1079
|
-
pkey = tuple(item[k] for k in col_names)
|
|
1080
|
-
rest: StrMapping = {k: v for k, v in item.items() if k not in col_names}
|
|
1081
|
-
mapping[pkey].append(rest)
|
|
1082
|
-
else:
|
|
1083
|
-
unchanged.append(item)
|
|
1084
|
-
merged = {k: merge_str_mappings(*v) for k, v in mapping.items()}
|
|
1085
|
-
return [
|
|
1086
|
-
dict(zip(col_names, k, strict=True)) | dict(v) for k, v in merged.items()
|
|
1087
|
-
] + unchanged
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
##
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
1157
|
def _tuple_to_mapping(
|
|
1094
1158
|
values: tuple[Any, ...], table_or_orm: TableOrORMInstOrClass, /
|
|
1095
1159
|
) -> dict[str, Any]:
|
|
@@ -1103,18 +1167,25 @@ __all__ = [
|
|
|
1103
1167
|
"CheckEngineError",
|
|
1104
1168
|
"DialectOrEngineOrConnectionOrAsync",
|
|
1105
1169
|
"EngineOrConnectionOrAsync",
|
|
1170
|
+
"ExtractURLError",
|
|
1171
|
+
"ExtractURLOutput",
|
|
1106
1172
|
"GetTableError",
|
|
1107
1173
|
"InsertItemsError",
|
|
1108
1174
|
"TablenameMixin",
|
|
1109
|
-
"
|
|
1175
|
+
"check_connect",
|
|
1176
|
+
"check_connect_async",
|
|
1110
1177
|
"check_engine",
|
|
1111
1178
|
"columnwise_max",
|
|
1112
1179
|
"columnwise_min",
|
|
1113
|
-
"
|
|
1180
|
+
"create_engine",
|
|
1181
|
+
"ensure_database_created",
|
|
1182
|
+
"ensure_database_dropped",
|
|
1183
|
+
"ensure_database_users_disconnected",
|
|
1114
1184
|
"ensure_tables_created",
|
|
1115
1185
|
"ensure_tables_dropped",
|
|
1116
1186
|
"enum_name",
|
|
1117
1187
|
"enum_values",
|
|
1188
|
+
"extract_url",
|
|
1118
1189
|
"get_chunk_size",
|
|
1119
1190
|
"get_column_names",
|
|
1120
1191
|
"get_columns",
|
|
@@ -1127,7 +1198,6 @@ __all__ = [
|
|
|
1127
1198
|
"is_table_or_orm",
|
|
1128
1199
|
"migrate_data",
|
|
1129
1200
|
"selectable_to_string",
|
|
1130
|
-
"upsert_items",
|
|
1131
1201
|
"yield_connection",
|
|
1132
1202
|
"yield_primary_key_columns",
|
|
1133
1203
|
]
|