dycw-utilities 0.129.10__py3-none-any.whl → 0.175.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dycw_utilities-0.175.17.dist-info/METADATA +34 -0
- dycw_utilities-0.175.17.dist-info/RECORD +103 -0
- dycw_utilities-0.175.17.dist-info/WHEEL +4 -0
- dycw_utilities-0.175.17.dist-info/entry_points.txt +4 -0
- utilities/__init__.py +1 -1
- utilities/altair.py +14 -14
- utilities/asyncio.py +350 -819
- utilities/atomicwrites.py +18 -6
- utilities/atools.py +77 -22
- utilities/cachetools.py +24 -29
- utilities/click.py +393 -237
- utilities/concurrent.py +8 -11
- utilities/contextlib.py +216 -17
- utilities/contextvars.py +20 -1
- utilities/cryptography.py +3 -3
- utilities/dataclasses.py +83 -118
- utilities/docker.py +293 -0
- utilities/enum.py +26 -23
- utilities/errors.py +17 -3
- utilities/fastapi.py +29 -65
- utilities/fpdf2.py +3 -3
- utilities/functions.py +169 -416
- utilities/functools.py +18 -19
- utilities/git.py +9 -30
- utilities/grp.py +28 -0
- utilities/gzip.py +31 -0
- utilities/http.py +3 -2
- utilities/hypothesis.py +738 -589
- utilities/importlib.py +17 -1
- utilities/inflect.py +25 -0
- utilities/iterables.py +194 -262
- utilities/jinja2.py +148 -0
- utilities/json.py +70 -0
- utilities/libcst.py +38 -17
- utilities/lightweight_charts.py +5 -9
- utilities/logging.py +345 -543
- utilities/math.py +18 -13
- utilities/memory_profiler.py +11 -15
- utilities/more_itertools.py +200 -131
- utilities/operator.py +33 -29
- utilities/optuna.py +6 -6
- utilities/orjson.py +272 -137
- utilities/os.py +61 -4
- utilities/parse.py +59 -61
- utilities/pathlib.py +281 -40
- utilities/permissions.py +298 -0
- utilities/pickle.py +2 -2
- utilities/platform.py +24 -5
- utilities/polars.py +1214 -430
- utilities/polars_ols.py +1 -1
- utilities/postgres.py +408 -0
- utilities/pottery.py +113 -26
- utilities/pqdm.py +10 -11
- utilities/psutil.py +6 -57
- utilities/pwd.py +28 -0
- utilities/pydantic.py +4 -54
- utilities/pydantic_settings.py +240 -0
- utilities/pydantic_settings_sops.py +76 -0
- utilities/pyinstrument.py +8 -10
- utilities/pytest.py +227 -121
- utilities/pytest_plugins/__init__.py +1 -0
- utilities/pytest_plugins/pytest_randomly.py +23 -0
- utilities/pytest_plugins/pytest_regressions.py +56 -0
- utilities/pytest_regressions.py +26 -46
- utilities/random.py +13 -9
- utilities/re.py +58 -28
- utilities/redis.py +401 -550
- utilities/scipy.py +1 -1
- utilities/sentinel.py +10 -0
- utilities/shelve.py +4 -1
- utilities/shutil.py +25 -0
- utilities/slack_sdk.py +36 -106
- utilities/sqlalchemy.py +502 -473
- utilities/sqlalchemy_polars.py +38 -94
- utilities/string.py +2 -3
- utilities/subprocess.py +1572 -0
- utilities/tempfile.py +86 -4
- utilities/testbook.py +50 -0
- utilities/text.py +165 -42
- utilities/timer.py +37 -65
- utilities/traceback.py +158 -929
- utilities/types.py +146 -116
- utilities/typing.py +531 -71
- utilities/tzdata.py +1 -53
- utilities/tzlocal.py +6 -23
- utilities/uuid.py +43 -5
- utilities/version.py +27 -26
- utilities/whenever.py +1776 -386
- utilities/zoneinfo.py +84 -22
- dycw_utilities-0.129.10.dist-info/METADATA +0 -241
- dycw_utilities-0.129.10.dist-info/RECORD +0 -96
- dycw_utilities-0.129.10.dist-info/WHEEL +0 -4
- dycw_utilities-0.129.10.dist-info/licenses/LICENSE +0 -21
- utilities/datetime.py +0 -1409
- utilities/eventkit.py +0 -402
- utilities/loguru.py +0 -144
- utilities/luigi.py +0 -228
- utilities/period.py +0 -324
- utilities/pyrsistent.py +0 -89
- utilities/python_dotenv.py +0 -105
- utilities/streamlit.py +0 -105
- utilities/sys.py +0 -87
- utilities/tenacity.py +0 -145
utilities/sqlalchemy.py
CHANGED
@@ -12,21 +12,31 @@ from collections.abc import (
 )
 from collections.abc import Set as AbstractSet
 from contextlib import asynccontextmanager
-from dataclasses import dataclass
-from functools import
+from dataclasses import dataclass
+from functools import reduce
 from itertools import chain
 from math import floor
 from operator import ge, le
 from re import search
-from
+from socket import gaierror
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Literal,
+    TypeGuard,
+    assert_never,
+    cast,
+    overload,
+    override,
+)
 
+import sqlalchemy
 from sqlalchemy import (
     URL,
     Column,
     Connection,
     Engine,
     Insert,
-    PrimaryKeyConstraint,
     Selectable,
     Table,
     and_,
@@ -38,16 +48,14 @@ from sqlalchemy import (
 from sqlalchemy.dialects.mssql import dialect as mssql_dialect
 from sqlalchemy.dialects.mysql import dialect as mysql_dialect
 from sqlalchemy.dialects.oracle import dialect as oracle_dialect
-from sqlalchemy.dialects.postgresql import Insert as postgresql_Insert
 from sqlalchemy.dialects.postgresql import dialect as postgresql_dialect
 from sqlalchemy.dialects.postgresql import insert as postgresql_insert
 from sqlalchemy.dialects.postgresql.asyncpg import PGDialect_asyncpg
-from sqlalchemy.dialects.sqlite import Insert as sqlite_Insert
+from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg
 from sqlalchemy.dialects.sqlite import dialect as sqlite_dialect
 from sqlalchemy.dialects.sqlite import insert as sqlite_insert
 from sqlalchemy.exc import ArgumentError, DatabaseError
-from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
-from sqlalchemy.ext.asyncio import create_async_engine as _create_async_engine
+from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
 from sqlalchemy.orm import (
     DeclarativeBase,
     InstrumentedAttribute,
@@ -57,17 +65,8 @@ from sqlalchemy.orm import (
 from sqlalchemy.orm.exc import UnmappedClassError
 from sqlalchemy.pool import NullPool, Pool
 
-from utilities.asyncio import
-from utilities.
-from utilities.datetime import SECOND
-from utilities.functions import (
-    ensure_str,
-    get_class_name,
-    is_sequence_of_tuple_or_str_mapping,
-    is_string_mapping,
-    is_tuple,
-    is_tuple_or_str_mapping,
-)
+from utilities.asyncio import timeout_td
+from utilities.functions import ensure_str, get_class_name, yield_object_attributes
 from utilities.iterables import (
     CheckLengthError,
     CheckSubSetError,
@@ -80,16 +79,63 @@ from utilities.iterables import (
     merge_str_mappings,
     one,
 )
+from utilities.os import is_pytest
 from utilities.reprlib import get_repr
-from utilities.text import snake_case
-from utilities.types import
+from utilities.text import secret_str, snake_case
+from utilities.types import (
+    Delta,
+    MaybeIterable,
+    MaybeType,
+    StrMapping,
+    TupleOrStrMapping,
+)
+from utilities.typing import (
+    is_sequence_of_tuple_or_str_mapping,
+    is_string_mapping,
+    is_tuple,
+    is_tuple_or_str_mapping,
+)
+
+if TYPE_CHECKING:
+    from enum import Enum, StrEnum
+
 
-
-type _EngineOrConnectionOrAsync = Engine | Connection | AsyncEngine | AsyncConnection
+type EngineOrConnectionOrAsync = Engine | Connection | AsyncEngine | AsyncConnection
 type Dialect = Literal["mssql", "mysql", "oracle", "postgresql", "sqlite"]
+type DialectOrEngineOrConnectionOrAsync = Dialect | EngineOrConnectionOrAsync
 type ORMInstOrClass = DeclarativeBase | type[DeclarativeBase]
 type TableOrORMInstOrClass = Table | ORMInstOrClass
-CHUNK_SIZE_FRAC = 0.
+CHUNK_SIZE_FRAC = 0.8
+
+
+##
+
+
+_SELECT = text("SELECT 1")
+
+
+def check_connect(engine: Engine, /) -> bool:
+    """Check if an engine can connect."""
+    try:
+        with engine.connect() as conn:
+            return bool(conn.execute(_SELECT).scalar_one())
+    except (gaierror, ConnectionRefusedError, DatabaseError):  # pragma: no cover
+        return False
+
+
+async def check_connect_async(
+    engine: AsyncEngine,
+    /,
+    *,
+    timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
+) -> bool:
+    """Check if an engine can connect."""
+    try:
+        async with timeout_td(timeout, error=error), engine.connect() as conn:
+            return bool((await conn.execute(_SELECT)).scalar_one())
+    except (gaierror, ConnectionRefusedError, DatabaseError, TimeoutError):
+        return False
 
 
 ##
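For orientation, a minimal usage sketch of the new connectivity helpers added above; the driver name, host, and credentials below are hypothetical.

```python
import asyncio

from utilities.sqlalchemy import check_connect_async, create_engine

# Hypothetical Postgres URL parts; an unreachable host simply yields False.
engine = create_engine(
    "postgresql+asyncpg",  # driver name is an assumption
    username="user",
    password="password",
    host="localhost",
    port=5432,
    database="example",
    async_=True,
)


async def main() -> None:
    ok = await check_connect_async(engine)  # issues `SELECT 1`
    print("reachable:", ok)


asyncio.run(main())
```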
@@ -99,8 +145,8 @@ async def check_engine(
     engine: AsyncEngine,
     /,
     *,
-    timeout: Duration | None = None,
-    error: type[Exception] = TimeoutError,
+    timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
     num_tables: int | tuple[int, float] | None = None,
 ) -> None:
     """Check that an engine can connect.
@@ -115,7 +161,7 @@ async def check_engine(
             query = "select * from all_objects"
         case "sqlite":
             query = "select * from sqlite_master where type='table'"
-        case _ as never:
+        case never:
             assert_never(never)
     statement = text(query)
     async with yield_connection(engine, timeout=timeout, error=error) as conn:
@@ -183,7 +229,36 @@ def _columnwise_minmax(*columns: Any, op: Callable[[Any, Any], Any]) -> Any:
 ##
 
 
-def create_async_engine(
+@overload
+def create_engine(
+    drivername: str,
+    /,
+    *,
+    username: str | None = None,
+    password: str | None = None,
+    host: str | None = None,
+    port: int | None = None,
+    database: str | None = None,
+    query: StrMapping | None = None,
+    poolclass: type[Pool] | None = NullPool,
+    async_: Literal[True],
+) -> AsyncEngine: ...
+@overload
+def create_engine(
+    drivername: str,
+    /,
+    *,
+    username: str | None = None,
+    password: str | None = None,
+    host: str | None = None,
+    port: int | None = None,
+    database: str | None = None,
+    query: StrMapping | None = None,
+    poolclass: type[Pool] | None = NullPool,
+    async_: Literal[False] = False,
+) -> Engine: ...
+@overload
+def create_engine(
     drivername: str,
     /,
     *,
@@ -194,7 +269,21 @@ def create_async_engine(
     database: str | None = None,
     query: StrMapping | None = None,
     poolclass: type[Pool] | None = NullPool,
-) -> AsyncEngine:
+    async_: bool = False,
+) -> Engine | AsyncEngine: ...
+def create_engine(
+    drivername: str,
+    /,
+    *,
+    username: str | None = None,
+    password: str | None = None,
+    host: str | None = None,
+    port: int | None = None,
+    database: str | None = None,
+    query: StrMapping | None = None,
+    poolclass: type[Pool] | None = NullPool,
+    async_: bool = False,
+) -> Engine | AsyncEngine:
     """Create a SQLAlchemy engine."""
     if query is None:
         kwargs = {}
@@ -213,7 +302,47 @@ def create_async_engine(
         database=database,
         **kwargs,
     )
-    return _create_async_engine(url, poolclass=poolclass)
+    match async_:
+        case False:
+            return sqlalchemy.create_engine(url, poolclass=poolclass)
+        case True:
+            return create_async_engine(url, poolclass=poolclass)
+        case never:
+            assert_never(never)
+
+
+##
+
+
+async def ensure_database_created(super_: URL, database: str, /) -> None:
+    """Ensure a database is created."""
+    engine = create_async_engine(super_, isolation_level="AUTOCOMMIT")
+    async with engine.begin() as conn:
+        try:
+            _ = await conn.execute(text(f"CREATE DATABASE {database}"))
+        except DatabaseError as error:
+            _ensure_tables_maybe_reraise(error, 'database ".*" already exists')
+
+
+async def ensure_database_dropped(super_: URL, database: str, /) -> None:
+    """Ensure a database is dropped."""
+    engine = create_async_engine(super_, isolation_level="AUTOCOMMIT")
+    async with engine.begin() as conn:
+        _ = await conn.execute(text(f"DROP DATABASE IF EXISTS {database}"))
+
+
+async def ensure_database_users_disconnected(super_: URL, database: str, /) -> None:
+    """Ensure a databases' users are disconnected."""
+    engine = create_async_engine(super_, isolation_level="AUTOCOMMIT")
+    match dialect := _get_dialect(engine):
+        case "postgresql":  # skipif-ci-and-not-linux
+            query = f"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = {database!r} AND pid <> pg_backend_pid()"  # noqa: S608
+        case "mssql" | "mysql" | "oracle" | "sqlite":  # pragma: no cover
+            raise NotImplementedError(dialect)
+        case never:
+            assert_never(never)
+    async with engine.begin() as conn:
+        _ = await conn.execute(text(query))
 
 
 ##
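A sketch of how the new factory and database-level helpers compose: `create_engine` dispatches on `async_`, and `ensure_database_created` swallows the "already exists" error. The superuser URL below is hypothetical.

```python
import asyncio

from utilities.sqlalchemy import create_engine, ensure_database_created

super_engine = create_engine(
    "postgresql+asyncpg",  # driver name is an assumption
    username="postgres",
    password="postgres",  # hypothetical superuser credentials
    host="localhost",
    port=5432,
    database="postgres",
    async_=True,
)

# Idempotent: re-running after the database exists is a no-op.
asyncio.run(ensure_database_created(super_engine.url, "example"))
```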
@@ -223,8 +352,8 @@ async def ensure_tables_created(
     engine: AsyncEngine,
     /,
     *tables_or_orms: TableOrORMInstOrClass,
-    timeout: Duration | None = None,
-    error: type[Exception] = TimeoutError,
+    timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
 ) -> None:
     """Ensure a table/set of tables is/are created."""
     tables = set(map(get_table, tables_or_orms))
@@ -239,7 +368,7 @@ async def ensure_tables_created(
             match = "ORA-00955: name is already used by an existing object"
         case "sqlite":
             match = "table .* already exists"
-        case _ as never:
+        case never:
             assert_never(never)
     async with yield_connection(engine, timeout=timeout, error=error) as conn:
         for table in tables:
@@ -252,8 +381,8 @@ async def ensure_tables_created(
 async def ensure_tables_dropped(
     engine: AsyncEngine,
     *tables_or_orms: TableOrORMInstOrClass,
-    timeout: Duration | None = None,
-    error: type[Exception] = TimeoutError,
+    timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
 ) -> None:
     """Ensure a table/set of tables is/are dropped."""
     tables = set(map(get_table, tables_or_orms))
@@ -268,7 +397,7 @@ async def ensure_tables_dropped(
             match = "ORA-00942: table or view does not exist"
         case "sqlite":
             match = "no such table"
-        case _ as never:
+        case never:
             assert_never(never)
     async with yield_connection(engine, timeout=timeout, error=error) as conn:
         for table in tables:
@@ -281,16 +410,120 @@ async def ensure_tables_dropped(
 ##
 
 
+def enum_name(enum: type[Enum], /) -> str:
+    """Get the name of an Enum."""
+    return f"{snake_case(get_class_name(enum))}_enum"
+
+
+##
+
+
+def enum_values(enum: type[StrEnum], /) -> list[str]:
+    """Get the values of a StrEnum."""
+    return [e.value for e in enum]
+
+
+##
+
+
+@dataclass(kw_only=True, slots=True)
+class ExtractURLOutput:
+    username: str
+    password: secret_str
+    host: str
+    port: int
+    database: str
+
+
+def extract_url(url: URL, /) -> ExtractURLOutput:
+    """Extract the database, host & port from a URL."""
+    if url.username is None:
+        raise _ExtractURLUsernameError(url=url)
+    if url.password is None:
+        raise _ExtractURLPasswordError(url=url)
+    if url.host is None:
+        raise _ExtractURLHostError(url=url)
+    if url.port is None:
+        raise _ExtractURLPortError(url=url)
+    if url.database is None:
+        raise _ExtractURLDatabaseError(url=url)
+    return ExtractURLOutput(
+        username=url.username,
+        password=secret_str(url.password),
+        host=url.host,
+        port=url.port,
+        database=url.database,
+    )
+
+
+@dataclass(kw_only=True, slots=True)
+class ExtractURLError(Exception):
+    url: URL
+
+
+@dataclass(kw_only=True, slots=True)
+class _ExtractURLUsernameError(ExtractURLError):
+    @override
+    def __str__(self) -> str:
+        return f"Expected URL to contain a user name; got {self.url}"
+
+
+@dataclass(kw_only=True, slots=True)
+class _ExtractURLPasswordError(ExtractURLError):
+    @override
+    def __str__(self) -> str:
+        return f"Expected URL to contain a password; got {self.url}"
+
+
+@dataclass(kw_only=True, slots=True)
+class _ExtractURLHostError(ExtractURLError):
+    @override
+    def __str__(self) -> str:
+        return f"Expected URL to contain a host; got {self.url}"
+
+
+@dataclass(kw_only=True, slots=True)
+class _ExtractURLPortError(ExtractURLError):
+    @override
+    def __str__(self) -> str:
+        return f"Expected URL to contain a port; got {self.url}"
+
+
+@dataclass(kw_only=True, slots=True)
+class _ExtractURLDatabaseError(ExtractURLError):
+    @override
+    def __str__(self) -> str:
+        return f"Expected URL to contain a database; got {self.url}"
+
+
+##
+
+
 def get_chunk_size(
-
+    dialect_or_engine_or_conn: DialectOrEngineOrConnectionOrAsync,
+    table_or_orm_or_num_cols: TableOrORMInstOrClass | Sized | int,
     /,
     *,
     chunk_size_frac: float = CHUNK_SIZE_FRAC,
-    scaling: float = 1.0,
 ) -> int:
     """Get the maximum chunk size for an engine."""
-    max_params = _get_dialect_max_params(
-
+    max_params = _get_dialect_max_params(dialect_or_engine_or_conn)
+    match table_or_orm_or_num_cols:
+        case Table() | DeclarativeBase() | type() as table_or_orm:
+            return get_chunk_size(
+                dialect_or_engine_or_conn,
+                get_columns(table_or_orm),
+                chunk_size_frac=chunk_size_frac,
+            )
+        case Sized() as sized:
+            return get_chunk_size(
+                dialect_or_engine_or_conn, len(sized), chunk_size_frac=chunk_size_frac
+            )
+        case int() as num_cols:
+            size = floor(chunk_size_frac * max_params / num_cols)
+            return max(size, 1)
+        case never:
+            assert_never(never)
 
 
 ##
@@ -370,18 +603,22 @@ type _InsertItem = (
     | Sequence[_PairOfTupleOrStrMappingAndTable]
     | Sequence[DeclarativeBase]
 )
+type _NormalizedItem = tuple[Table, StrMapping]
+type _InsertPair = tuple[Table, Sequence[StrMapping]]
+type _InsertTriple = tuple[Table, Insert, Sequence[StrMapping] | None]
 
 
 async def insert_items(
     engine: AsyncEngine,
     *items: _InsertItem,
     snake: bool = False,
+    is_upsert: bool = False,
     chunk_size_frac: float = CHUNK_SIZE_FRAC,
     assume_tables_exist: bool = False,
-    timeout_create: Duration | None = None,
-    error_create: type[Exception] = TimeoutError,
-    timeout_insert: Duration | None = None,
-    error_insert: type[Exception] = TimeoutError,
+    timeout_create: Delta | None = None,
+    error_create: MaybeType[BaseException] = TimeoutError,
+    timeout_insert: Delta | None = None,
+    error_insert: MaybeType[BaseException] = TimeoutError,
 ) -> None:
     """Insert a set of items into a database.
 
@@ -405,37 +642,181 @@ async def insert_items(
                                Obj(k1=v21, k2=v22, ...),
                                ...]
     """
-
-
-
-
-
-
-                return insert(table), values
-            case _:
-                return insert(table).values(list(values)), None
-
-    try:
-        prepared = _prepare_insert_or_upsert_items(
-            partial(_normalize_insert_item, snake=snake),
-            engine,
-            build_insert,
-            *items,
-            chunk_size_frac=chunk_size_frac,
-        )
-    except _PrepareInsertOrUpsertItemsError as error:
-        raise InsertItemsError(item=error.item) from None
+    normalized = chain.from_iterable(
+        _insert_items_yield_normalized(i, snake=snake) for i in items
+    )
+    triples = _insert_items_yield_triples(
+        engine, normalized, is_upsert=is_upsert, chunk_size_frac=chunk_size_frac
+    )
     if not assume_tables_exist:
+        triples = list(triples)
+        tables = {table for table, _, _ in triples}
         await ensure_tables_created(
-            engine, *prepared.tables, timeout=timeout_create, error=error_create
+            engine, *tables, timeout=timeout_create, error=error_create
         )
-    for ins, parameters in prepared.yield_pairs():
+    for _, ins, parameters in triples:
         async with yield_connection(
             engine, timeout=timeout_insert, error=error_insert
         ) as conn:
             _ = await conn.execute(ins, parameters=parameters)
 
 
+def _insert_items_yield_normalized(
+    item: _InsertItem, /, *, snake: bool = False
+) -> Iterator[_NormalizedItem]:
+    if _is_pair_of_str_mapping_and_table(item):
+        mapping, table_or_orm = item
+        adjusted = _map_mapping_to_table(mapping, table_or_orm, snake=snake)
+        yield (get_table(table_or_orm), adjusted)
+        return
+    if _is_pair_of_tuple_and_table(item):
+        tuple_, table_or_orm = item
+        mapping = _tuple_to_mapping(tuple_, table_or_orm)
+        yield from _insert_items_yield_normalized((mapping, table_or_orm), snake=snake)
+        return
+    if _is_pair_of_sequence_of_tuple_or_string_mapping_and_table(item):
+        items, table_or_orm = item
+        pairs = [(i, table_or_orm) for i in items]
+        for p in pairs:
+            yield from _insert_items_yield_normalized(p, snake=snake)
+        return
+    if isinstance(item, DeclarativeBase):
+        mapping = _orm_inst_to_dict(item)
+        yield from _insert_items_yield_normalized((mapping, item), snake=snake)
+        return
+    try:
+        _ = iter(item)
+    except TypeError:
+        raise InsertItemsError(item=item) from None
+    if all(map(_is_pair_of_tuple_or_str_mapping_and_table, item)):
+        pairs = cast("Sequence[_PairOfTupleOrStrMappingAndTable]", item)
+        for p in pairs:
+            yield from _insert_items_yield_normalized(p, snake=snake)
+        return
+    if all(map(is_orm, item)):
+        classes = cast("Sequence[DeclarativeBase]", item)
+        for c in classes:
+            yield from _insert_items_yield_normalized(c, snake=snake)
+        return
+    raise InsertItemsError(item=item)
+
+
+def _insert_items_yield_triples(
+    engine: AsyncEngine,
+    items: Iterable[_NormalizedItem],
+    /,
+    *,
+    is_upsert: bool = False,
+    chunk_size_frac: float = CHUNK_SIZE_FRAC,
+) -> Iterable[_InsertTriple]:
+    pairs = _insert_items_yield_chunked_pairs(
+        engine, items, is_upsert=is_upsert, chunk_size_frac=chunk_size_frac
+    )
+    for table, mappings in pairs:
+        match is_upsert, _get_dialect(engine):
+            case False, "oracle":  # pragma: no cover
+                ins = insert(table)
+                parameters = mappings
+            case False, _:
+                ins = insert(table).values(mappings)
+                parameters = None
+            case True, _:
+                ins = _insert_items_build_insert_with_on_conflict_do_update(
+                    engine, table, mappings
+                )
+                parameters = None
+            case never:
+                assert_never(never)
+        yield table, ins, parameters
+
+
+def _insert_items_yield_chunked_pairs(
+    engine: AsyncEngine,
+    items: Iterable[_NormalizedItem],
+    /,
+    *,
+    is_upsert: bool = False,
+    chunk_size_frac: float = CHUNK_SIZE_FRAC,
+) -> Iterable[_InsertPair]:
+    for table, mappings in _insert_items_yield_raw_pairs(items, is_upsert=is_upsert):
+        chunk_size = get_chunk_size(engine, table, chunk_size_frac=chunk_size_frac)
+        for mappings_i in chunked(mappings, chunk_size):
+            yield table, list(mappings_i)
+
+
+def _insert_items_yield_raw_pairs(
+    items: Iterable[_NormalizedItem], /, *, is_upsert: bool = False
+) -> Iterable[_InsertPair]:
+    by_table: defaultdict[Table, list[StrMapping]] = defaultdict(list)
+    for table, mapping in items:
+        by_table[table].append(mapping)
+    for table, mappings in by_table.items():
+        yield from _insert_items_yield_raw_pairs_one(
+            table, mappings, is_upsert=is_upsert
+        )
+
+
+def _insert_items_yield_raw_pairs_one(
+    table: Table, mappings: Iterable[StrMapping], /, *, is_upsert: bool = False
+) -> Iterable[_InsertPair]:
+    merged = _insert_items_yield_merged_mappings(table, mappings)
+    match is_upsert:
+        case True:
+            by_keys: defaultdict[frozenset[str], list[StrMapping]] = defaultdict(list)
+            for mapping in merged:
+                non_null = {k: v for k, v in mapping.items() if v is not None}
+                by_keys[frozenset(non_null)].append(non_null)
+            for mappings_i in by_keys.values():
+                yield table, mappings_i
+        case False:
+            yield table, list(merged)
+        case never:
+            assert_never(never)
+
+
+def _insert_items_yield_merged_mappings(
+    table: Table, mappings: Iterable[StrMapping], /
+) -> Iterable[StrMapping]:
+    columns = list(yield_primary_key_columns(table))
+    col_names = [c.name for c in columns]
+    cols_auto = {c.name for c in columns if c.autoincrement in {True, "auto"}}
+    cols_non_auto = set(col_names) - cols_auto
+    by_key: defaultdict[tuple[Hashable, ...], list[StrMapping]] = defaultdict(list)
+    for mapping in mappings:
+        check_subset(cols_non_auto, mapping)
+        has_all_auto = set(cols_auto).issubset(mapping)
+        if has_all_auto:
+            pkey = tuple(mapping[k] for k in col_names)
+            rest: StrMapping = {k: v for k, v in mapping.items() if k not in col_names}
+            by_key[pkey].append(rest)
+        else:
+            yield mapping
+    for k, v in by_key.items():
+        head = dict(zip(col_names, k, strict=True))
+        yield merge_str_mappings(head, *v)
+
+
+def _insert_items_build_insert_with_on_conflict_do_update(
+    engine: AsyncEngine, table: Table, mappings: Iterable[StrMapping], /
+) -> Insert:
+    primary_key = cast("Any", table.primary_key)
+    mappings = list(mappings)
+    columns = merge_sets(*mappings)
+    match _get_dialect(engine):
+        case "postgresql":  # skipif-ci-and-not-linux
+            ins = postgresql_insert(table).values(mappings)
+            set_ = {c: getattr(ins.excluded, c) for c in columns}
+            return ins.on_conflict_do_update(constraint=primary_key, set_=set_)
+        case "sqlite":
+            ins = sqlite_insert(table).values(mappings)
+            set_ = {c: getattr(ins.excluded, c) for c in columns}
+            return ins.on_conflict_do_update(index_elements=primary_key, set_=set_)
+        case "mssql" | "mysql" | "oracle" as dialect:  # pragma: no cover
+            raise NotImplementedError(dialect)
+        case never:
+            assert_never(never)
+
+
 @dataclass(kw_only=True, slots=True)
 class InsertItemsError(Exception):
     item: _InsertItem
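A self-contained sketch of the consolidated `insert_items` API, where upserts are now the `is_upsert` flag rather than the removed `upsert_items`; the table, file path, and `aiosqlite` driver name are assumptions.

```python
import asyncio

from sqlalchemy import Column, Integer, MetaData, String, Table

from utilities.sqlalchemy import create_engine, insert_items

table = Table(
    "items",
    MetaData(),
    Column("id", Integer, primary_key=True),
    Column("name", String),
)
engine = create_engine(
    "sqlite+aiosqlite", database="example.db", async_=True  # hypothetical file
)


async def main() -> None:
    # Plain insert; the table is created first since assume_tables_exist=False.
    await insert_items(engine, ({"id": 1, "name": "a"}, table))
    # Same primary key again, as an upsert (ON CONFLICT DO UPDATE on sqlite/postgres).
    await insert_items(engine, ({"id": 1, "name": "b"}, table), is_upsert=True)


asyncio.run(main())
```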
@@ -479,10 +860,10 @@ async def migrate_data(
     table_or_orm_to: TableOrORMInstOrClass | None = None,
     chunk_size_frac: float = CHUNK_SIZE_FRAC,
     assume_tables_exist: bool = False,
-    timeout_create: Duration | None = None,
-    error_create: type[Exception] = TimeoutError,
-    timeout_insert: Duration | None = None,
-    error_insert: type[Exception] = TimeoutError,
+    timeout_create: Delta | None = None,
+    error_create: MaybeType[BaseException] = TimeoutError,
+    timeout_insert: Delta | None = None,
+    error_insert: MaybeType[BaseException] = TimeoutError,
 ) -> None:
     """Migrate the contents of a table from one database to another."""
     table_from = get_table(table_or_orm_from)
@@ -506,82 +887,8 @@ async def migrate_data(
 ##
 
 
-def _normalize_insert_item(
-    item: _InsertItem, /, *, snake: bool = False
-) -> list[_NormalizedItem]:
-    """Normalize an insertion item."""
-    if _is_pair_of_str_mapping_and_table(item):
-        mapping, table_or_orm = item
-        adjusted = _map_mapping_to_table(mapping, table_or_orm, snake=snake)
-        normalized = _NormalizedItem(mapping=adjusted, table=get_table(table_or_orm))
-        return [normalized]
-    if _is_pair_of_tuple_and_table(item):
-        tuple_, table_or_orm = item
-        mapping = _tuple_to_mapping(tuple_, table_or_orm)
-        return _normalize_insert_item((mapping, table_or_orm), snake=snake)
-    if _is_pair_of_sequence_of_tuple_or_string_mapping_and_table(item):
-        items, table_or_orm = item
-        pairs = [(i, table_or_orm) for i in items]
-        normalized = (_normalize_insert_item(p, snake=snake) for p in pairs)
-        return list(chain.from_iterable(normalized))
-    if isinstance(item, DeclarativeBase):
-        mapping = _orm_inst_to_dict(item)
-        return _normalize_insert_item((mapping, item), snake=snake)
-    try:
-        _ = iter(item)
-    except TypeError:
-        raise _NormalizeInsertItemError(item=item) from None
-    if all(map(_is_pair_of_tuple_or_str_mapping_and_table, item)):
-        seq = cast("Sequence[_PairOfTupleOrStrMappingAndTable]", item)
-        normalized = (_normalize_insert_item(p, snake=snake) for p in seq)
-        return list(chain.from_iterable(normalized))
-    if all(map(is_orm, item)):
-        seq = cast("Sequence[DeclarativeBase]", item)
-        normalized = (_normalize_insert_item(p, snake=snake) for p in seq)
-        return list(chain.from_iterable(normalized))
-    raise _NormalizeInsertItemError(item=item)
-
-
-@dataclass(kw_only=True, slots=True)
-class _NormalizeInsertItemError(Exception):
-    item: _InsertItem
-
-    @override
-    def __str__(self) -> str:
-        return f"Item must be valid; got {self.item}"
-
-
-@dataclass(kw_only=True, slots=True)
-class _NormalizedItem:
-    mapping: StrMapping
-    table: Table
-
-
-def _normalize_upsert_item(
-    item: _InsertItem,
-    /,
-    *,
-    snake: bool = False,
-    selected_or_all: Literal["selected", "all"] = "selected",
-) -> Iterator[_NormalizedItem]:
-    """Normalize an upsert item."""
-    normalized = _normalize_insert_item(item, snake=snake)
-    match selected_or_all:
-        case "selected":
-            for norm in normalized:
-                values = {k: v for k, v in norm.mapping.items() if v is not None}
-                yield _NormalizedItem(mapping=values, table=norm.table)
-        case "all":
-            yield from normalized
-        case _ as never:
-            assert_never(never)
-
-
-##
-
-
 def selectable_to_string(
-    selectable: Selectable[Any], engine_or_conn: _EngineOrConnectionOrAsync, /
+    selectable: Selectable[Any], engine_or_conn: EngineOrConnectionOrAsync, /
 ) -> str:
     """Convert a selectable into a string."""
     com = selectable.compile(
@@ -604,237 +911,22 @@ class TablenameMixin:
 ##
 
 
-@dataclass(kw_only=True)
-class UpsertService(Looper[_InsertItem]):
-    """Service to upsert items to a database."""
-
-    # base
-    freq: Duration = field(default=SECOND, repr=False)
-    backoff: Duration = field(default=SECOND, repr=False)
-    empty_upon_exit: bool = field(default=True, repr=False)
-    # self
-    engine: AsyncEngine
-    snake: bool = False
-    selected_or_all: _SelectedOrAll = "selected"
-    chunk_size_frac: float = CHUNK_SIZE_FRAC
-    assume_tables_exist: bool = False
-    timeout_create: Duration | None = None
-    error_create: type[Exception] = TimeoutError
-    timeout_insert: Duration | None = None
-    error_insert: type[Exception] = TimeoutError
-
-    @override
-    async def core(self) -> None:
-        await super().core()
-        await upsert_items(
-            self.engine,
-            *self.get_all_nowait(),
-            snake=self.snake,
-            selected_or_all=self.selected_or_all,
-            chunk_size_frac=self.chunk_size_frac,
-            assume_tables_exist=self.assume_tables_exist,
-            timeout_create=self.timeout_create,
-            error_create=self.error_create,
-            timeout_insert=self.timeout_insert,
-            error_insert=self.error_insert,
-        )
-
-
-@dataclass(kw_only=True)
-class UpsertServiceMixin:
-    """Mix-in for the upsert service."""
-
-    # base - looper
-    upsert_service_freq: Duration = field(default=SECOND, repr=False)
-    upsert_service_backoff: Duration = field(default=SECOND, repr=False)
-    upsert_service_empty_upon_exit: bool = field(default=False, repr=False)
-    upsert_service_logger: str | None = field(default=None, repr=False)
-    upsert_service_timeout: Duration | None = field(default=None, repr=False)
-    upsert_service_debug: bool = field(default=False, repr=False)
-    # base - upsert service
-    upsert_service_database: AsyncEngine
-    upsert_service_snake: bool = False
-    upsert_service_selected_or_all: _SelectedOrAll = "selected"
-    upsert_service_chunk_size_frac: float = CHUNK_SIZE_FRAC
-    upsert_service_assume_tables_exist: bool = False
-    upsert_service_timeout_create: Duration | None = None
-    upsert_service_error_create: type[Exception] = TimeoutError
-    upsert_service_timeout_insert: Duration | None = None
-    upsert_service_error_insert: type[Exception] = TimeoutError
-    # self
-    _upsert_service: UpsertService = field(init=False, repr=False)
-
-    def __post_init__(self) -> None:
-        with suppress_super_object_attribute_error():
-            super().__post_init__()  # pyright: ignore[reportAttributeAccessIssue]
-        self._upsert_service = UpsertService(
-            # looper
-            freq=self.upsert_service_freq,
-            backoff=self.upsert_service_backoff,
-            empty_upon_exit=self.upsert_service_empty_upon_exit,
-            logger=self.upsert_service_logger,
-            timeout=self.upsert_service_timeout,
-            _debug=self.upsert_service_debug,
-            # upsert service
-            engine=self.upsert_service_database,
-            snake=self.upsert_service_snake,
-            selected_or_all=self.upsert_service_selected_or_all,
-            chunk_size_frac=self.upsert_service_chunk_size_frac,
-            assume_tables_exist=self.upsert_service_assume_tables_exist,
-            timeout_create=self.upsert_service_timeout_create,
-            error_create=self.upsert_service_error_create,
-            timeout_insert=self.upsert_service_timeout_insert,
-            error_insert=self.upsert_service_error_insert,
-        )
-
-    def _yield_sub_loopers(self) -> Iterator[Looper[Any]]:
-        with suppress_super_object_attribute_error():
-            yield from super()._yield_sub_loopers()  # pyright: ignore[reportAttributeAccessIssue]
-        yield self._upsert_service
-
-
-##
-
-
-type _SelectedOrAll = Literal["selected", "all"]
-
-
-async def upsert_items(
-    engine: AsyncEngine,
-    /,
-    *items: _InsertItem,
-    snake: bool = False,
-    selected_or_all: _SelectedOrAll = "selected",
-    chunk_size_frac: float = CHUNK_SIZE_FRAC,
-    assume_tables_exist: bool = False,
-    timeout_create: Duration | None = None,
-    error_create: type[Exception] = TimeoutError,
-    timeout_insert: Duration | None = None,
-    error_insert: type[Exception] = TimeoutError,
-) -> None:
-    """Upsert a set of items into a database.
-
-    These can be one of the following:
-     - pair of dict & table/class: {k1=v1, k2=v2, ...), table_cls
-     - pair of list of dicts & table/class: [{k1=v11, k2=v12, ...},
-                                             {k1=v21, k2=v22, ...},
-                                             ...], table/class
-     - list of pairs of dict & table/class: [({k1=v11, k2=v12, ...}, table_cls1),
-                                             ({k1=v21, k2=v22, ...}, table_cls2),
-                                             ...]
-     - mapped class: Obj(k1=v1, k2=v2, ...)
-     - list of mapped classes: [Obj(k1=v11, k2=v12, ...),
-                                Obj(k1=v21, k2=v22, ...),
-                                ...]
-    """
-
-    def build_insert(
-        table: Table, values: Iterable[StrMapping], /
-    ) -> tuple[Insert, None]:
-        ups = _upsert_items_build(
-            engine, table, values, selected_or_all=selected_or_all
-        )
-        return ups, None
-
-    try:
-        prepared = _prepare_insert_or_upsert_items(
-            partial(
-                _normalize_upsert_item, snake=snake, selected_or_all=selected_or_all
-            ),
-            engine,
-            build_insert,
-            *items,
-            chunk_size_frac=chunk_size_frac,
-        )
-    except _PrepareInsertOrUpsertItemsError as error:
-        raise UpsertItemsError(item=error.item) from None
-    if not assume_tables_exist:
-        await ensure_tables_created(
-            engine, *prepared.tables, timeout=timeout_create, error=error_create
-        )
-    for ups, _ in prepared.yield_pairs():
-        async with yield_connection(
-            engine, timeout=timeout_insert, error=error_insert
-        ) as conn:
-            _ = await conn.execute(ups)
-
-
-def _upsert_items_build(
-    engine: AsyncEngine,
-    table: Table,
-    values: Iterable[StrMapping],
-    /,
-    *,
-    selected_or_all: Literal["selected", "all"] = "selected",
-) -> Insert:
-    values = list(values)
-    keys = merge_sets(*values)
-    dict_nones = dict.fromkeys(keys)
-    values = [{**dict_nones, **v} for v in values]
-    match _get_dialect(engine):
-        case "postgresql":  # skipif-ci-and-not-linux
-            insert = postgresql_insert
-        case "sqlite":
-            insert = sqlite_insert
-        case "mssql" | "mysql" | "oracle" as dialect:  # pragma: no cover
-            raise NotImplementedError(dialect)
-        case _ as never:
-            assert_never(never)
-    ins = insert(table).values(values)
-    primary_key = cast("Any", table.primary_key)
-    return _upsert_items_apply_on_conflict_do_update(
-        values, ins, primary_key, selected_or_all=selected_or_all
-    )
-
-
-def _upsert_items_apply_on_conflict_do_update(
-    values: Iterable[StrMapping],
-    insert: postgresql_Insert | sqlite_Insert,
-    primary_key: PrimaryKeyConstraint,
-    /,
-    *,
-    selected_or_all: Literal["selected", "all"] = "selected",
-) -> Insert:
-    match selected_or_all:
-        case "selected":
-            columns = merge_sets(*values)
-        case "all":
-            columns = {c.name for c in insert.excluded}
-        case _ as never:
-            assert_never(never)
-    set_ = {c: getattr(insert.excluded, c) for c in columns}
-    match insert:
-        case postgresql_Insert():  # skipif-ci
-            return insert.on_conflict_do_update(constraint=primary_key, set_=set_)
-        case sqlite_Insert():
-            return insert.on_conflict_do_update(index_elements=primary_key, set_=set_)
-        case _ as never:
-            assert_never(never)
-
-
-@dataclass(kw_only=True, slots=True)
-class UpsertItemsError(Exception):
-    item: _InsertItem
-
-    @override
-    def __str__(self) -> str:
-        return f"Item must be valid; got {self.item}"
-
-
-##
-
-
 @asynccontextmanager
 async def yield_connection(
     engine: AsyncEngine,
     /,
     *,
-    timeout: Duration | None = None,
-    error: type[Exception] = TimeoutError,
+    timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
 ) -> AsyncIterator[AsyncConnection]:
     """Yield an async connection."""
-
-
+    try:
+        async with timeout_td(timeout, error=error), engine.begin() as conn:
+            yield conn
+    except GeneratorExit:  # pragma: no cover
+        if not is_pytest():
+            raise
+        return
 
 
 ##
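A short sketch of the reworked `yield_connection`, which now wraps `engine.begin()` in the `timeout_td` guard; the driver name is an assumption.

```python
import asyncio

from sqlalchemy import text

from utilities.sqlalchemy import create_engine, yield_connection

engine = create_engine("sqlite+aiosqlite", async_=True)  # hypothetical driver


async def main() -> None:
    # timeout=None means no time limit; pass a Delta to bound the block.
    async with yield_connection(engine, timeout=None) as conn:
        print((await conn.execute(text("SELECT 1"))).scalar_one())


asyncio.run(main())
```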
@@ -865,7 +957,7 @@ def _ensure_tables_maybe_reraise(error: DatabaseError, match: str, /) -> None:
 ##
 
 
-def _get_dialect(engine_or_conn: _EngineOrConnectionOrAsync, /) -> Dialect:
+def _get_dialect(engine_or_conn: EngineOrConnectionOrAsync, /) -> Dialect:
     """Get the dialect of a database."""
     dialect = engine_or_conn.dialect
     if isinstance(dialect, mssql_dialect):  # pragma: no cover
@@ -875,7 +967,7 @@ def _get_dialect(engine_or_conn: _EngineOrConnectionOrAsync, /) -> Dialect:
     if isinstance(dialect, oracle_dialect):  # pragma: no cover
         return "oracle"
     if isinstance(  # skipif-ci-and-not-linux
-        dialect, postgresql_dialect
+        dialect, (postgresql_dialect, PGDialect_asyncpg, PGDialect_psycopg)
     ):
         return "postgresql"
     if isinstance(dialect, sqlite_dialect):
@@ -888,7 +980,7 @@ def _get_dialect(engine_or_conn: _EngineOrConnectionOrAsync, /) -> Dialect:
 
 
 def _get_dialect_max_params(
-    dialect_or_engine_or_conn:
+    dialect_or_engine_or_conn: DialectOrEngineOrConnectionOrAsync, /
 ) -> int:
     """Get the max number of parameters of a dialect."""
     match dialect_or_engine_or_conn:
@@ -910,7 +1002,7 @@ def _get_dialect_max_params(
         ):
             dialect = _get_dialect(engine_or_conn)
             return _get_dialect_max_params(dialect)
-        case _ as never:
+        case never:
             assert_never(never)
 
 
@@ -943,9 +1035,9 @@ def _is_pair_of_tuple_or_str_mapping_and_table(
     return _is_pair_with_predicate_and_table(obj, is_tuple_or_str_mapping)
 
 
-def _is_pair_with_predicate_and_table(
-    obj: Any, predicate: Callable[[Any], TypeGuard[
-) -> TypeGuard[tuple[
+def _is_pair_with_predicate_and_table[T](
+    obj: Any, predicate: Callable[[Any], TypeGuard[T]], /
+) -> TypeGuard[tuple[T, TableOrORMInstOrClass]]:
     """Check if an object is pair and a table."""
     return (
         isinstance(obj, tuple)
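The helper above now takes a PEP 695 type parameter, so the `TypeGuard` narrows the pair's first element to whatever the supplied predicate guards; a standalone sketch of the same pattern (all names illustrative):

```python
from collections.abc import Callable
from typing import Any, TypeGuard


def is_int(obj: Any) -> TypeGuard[int]:
    return isinstance(obj, int)


def is_pair_with_predicate_and_str[T](
    obj: Any, predicate: Callable[[Any], TypeGuard[T]], /
) -> TypeGuard[tuple[T, str]]:
    # True only for 2-tuples whose head satisfies the predicate.
    return (
        isinstance(obj, tuple)
        and len(obj) == 2
        and predicate(obj[0])
        and isinstance(obj[1], str)
    )


print(is_pair_with_predicate_and_str((1, "a"), is_int))   # True
print(is_pair_with_predicate_and_str(("x", "a"), is_int))  # False
```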
@@ -995,7 +1087,7 @@ def _map_mapping_to_table(
 @dataclass(kw_only=True, slots=True)
 class _MapMappingToTableError(Exception):
     mapping: StrMapping
-    columns:
+    columns: list[str]
 
 
 @dataclass(kw_only=True, slots=True)
|
|
|
1032
1124
|
|
|
1033
1125
|
def _orm_inst_to_dict(obj: DeclarativeBase, /) -> StrMapping:
|
|
1034
1126
|
"""Map an ORM instance to a dictionary."""
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
):
|
|
1041
|
-
return attr
|
|
1042
|
-
return None
|
|
1043
|
-
|
|
1044
|
-
def yield_items() -> Iterator[tuple[str, Any]]:
|
|
1045
|
-
for key in get_column_names(cls):
|
|
1046
|
-
attr = one(attr for attr in dir(cls) if is_attr(attr, key) is not None)
|
|
1047
|
-
yield key, getattr(obj, attr)
|
|
1048
|
-
|
|
1049
|
-
return dict(yield_items())
|
|
1050
|
-
|
|
1051
|
-
|
|
1052
|
-
##
|
|
1053
|
-
|
|
1054
|
-
|
|
1055
|
-
@dataclass(kw_only=True, slots=True)
|
|
1056
|
-
class _PrepareInsertOrUpsertItems:
|
|
1057
|
-
mapping: dict[Table, list[StrMapping]] = field(default_factory=dict)
|
|
1058
|
-
yield_pairs: Callable[[], Iterator[tuple[Insert, Any]]]
|
|
1059
|
-
|
|
1060
|
-
@property
|
|
1061
|
-
def tables(self) -> Sequence[Table]:
|
|
1062
|
-
return list(self.mapping)
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
def _prepare_insert_or_upsert_items(
|
|
1066
|
-
normalize_item: Callable[[_InsertItem], Iterable[_NormalizedItem]],
|
|
1067
|
-
engine: AsyncEngine,
|
|
1068
|
-
build_insert: Callable[[Table, Iterable[StrMapping]], tuple[Insert, Any]],
|
|
1069
|
-
/,
|
|
1070
|
-
*items: Any,
|
|
1071
|
-
chunk_size_frac: float = CHUNK_SIZE_FRAC,
|
|
1072
|
-
) -> _PrepareInsertOrUpsertItems:
|
|
1073
|
-
"""Prepare a set of insert/upsert items."""
|
|
1074
|
-
mapping: defaultdict[Table, list[StrMapping]] = defaultdict(list)
|
|
1075
|
-
lengths: set[int] = set()
|
|
1076
|
-
try:
|
|
1077
|
-
for item in items:
|
|
1078
|
-
for normed in normalize_item(item):
|
|
1079
|
-
mapping[normed.table].append(normed.mapping)
|
|
1080
|
-
lengths.add(len(normed.mapping))
|
|
1081
|
-
except _NormalizeInsertItemError as error:
|
|
1082
|
-
raise _PrepareInsertOrUpsertItemsError(item=error.item) from None
|
|
1083
|
-
merged = {
|
|
1084
|
-
table: _prepare_insert_or_upsert_items_merge_items(table, values)
|
|
1085
|
-
for table, values in mapping.items()
|
|
1127
|
+
attrs = {
|
|
1128
|
+
k for k, _ in yield_object_attributes(obj, static_type=InstrumentedAttribute)
|
|
1129
|
+
}
|
|
1130
|
+
return {
|
|
1131
|
+
name: _orm_inst_to_dict_one(obj, attrs, name) for name in get_column_names(obj)
|
|
1086
1132
|
}
|
|
1087
|
-
max_length = max(lengths, default=1)
|
|
1088
|
-
chunk_size = get_chunk_size(
|
|
1089
|
-
engine, chunk_size_frac=chunk_size_frac, scaling=max_length
|
|
1090
|
-
)
|
|
1091
|
-
|
|
1092
|
-
def yield_pairs() -> Iterator[tuple[Insert, None]]:
|
|
1093
|
-
for table, values in merged.items():
|
|
1094
|
-
for chunk in chunked(values, chunk_size):
|
|
1095
|
-
yield build_insert(table, chunk)
|
|
1096
|
-
|
|
1097
|
-
return _PrepareInsertOrUpsertItems(mapping=mapping, yield_pairs=yield_pairs)
|
|
1098
1133
|
|
|
1099
1134
|
|
|
1100
|
-
|
|
1101
|
-
|
|
1102
|
-
|
|
1103
|
-
|
|
1104
|
-
|
|
1105
|
-
|
|
1106
|
-
|
|
1135
|
+
def _orm_inst_to_dict_one(
|
|
1136
|
+
obj: DeclarativeBase, attrs: AbstractSet[str], name: str, /
|
|
1137
|
+
) -> Any:
|
|
1138
|
+
attr = one(
|
|
1139
|
+
attr for attr in attrs if _orm_inst_to_dict_predicate(type(obj), attr, name)
|
|
1140
|
+
)
|
|
1141
|
+
return getattr(obj, attr)
|
|
1107
1142
|
|
|
1108
1143
|
|
|
1109
|
-
def
|
|
1110
|
-
|
|
1111
|
-
) ->
|
|
1112
|
-
|
|
1113
|
-
|
|
1114
|
-
|
|
1115
|
-
|
|
1116
|
-
|
|
1117
|
-
unchanged: list[StrMapping] = []
|
|
1118
|
-
for item in items:
|
|
1119
|
-
check_subset(cols_non_auto, item)
|
|
1120
|
-
has_all_auto = set(cols_auto).issubset(item)
|
|
1121
|
-
if has_all_auto:
|
|
1122
|
-
pkey = tuple(item[k] for k in col_names)
|
|
1123
|
-
rest: StrMapping = {k: v for k, v in item.items() if k not in col_names}
|
|
1124
|
-
mapping[pkey].append(rest)
|
|
1125
|
-
else:
|
|
1126
|
-
unchanged.append(item)
|
|
1127
|
-
merged = {k: merge_str_mappings(*v) for k, v in mapping.items()}
|
|
1128
|
-
return [
|
|
1129
|
-
dict(zip(col_names, k, strict=True)) | dict(v) for k, v in merged.items()
|
|
1130
|
-
] + unchanged
|
|
1144
|
+
def _orm_inst_to_dict_predicate(
|
|
1145
|
+
cls: type[DeclarativeBase], attr: str, name: str, /
|
|
1146
|
+
) -> bool:
|
|
1147
|
+
cls_attr = getattr(cls, attr)
|
|
1148
|
+
try:
|
|
1149
|
+
return cls_attr.name == name
|
|
1150
|
+
except AttributeError:
|
|
1151
|
+
return False
|
|
1131
1152
|
|
|
1132
1153
|
|
|
1133
1154
|
##
|
|
@@ -1144,18 +1165,27 @@ def _tuple_to_mapping(
 __all__ = [
     "CHUNK_SIZE_FRAC",
     "CheckEngineError",
+    "DialectOrEngineOrConnectionOrAsync",
+    "EngineOrConnectionOrAsync",
+    "ExtractURLError",
+    "ExtractURLOutput",
     "GetTableError",
     "InsertItemsError",
     "TablenameMixin",
-    "UpsertItemsError",
-    "UpsertService",
-    "UpsertServiceMixin",
+    "check_connect",
+    "check_connect_async",
     "check_engine",
     "columnwise_max",
     "columnwise_min",
-    "create_async_engine",
+    "create_engine",
+    "ensure_database_created",
+    "ensure_database_dropped",
+    "ensure_database_users_disconnected",
     "ensure_tables_created",
     "ensure_tables_dropped",
+    "enum_name",
+    "enum_values",
+    "extract_url",
     "get_chunk_size",
     "get_column_names",
     "get_columns",
@@ -1168,7 +1198,6 @@ __all__ = [
     "is_table_or_orm",
     "migrate_data",
     "selectable_to_string",
-    "upsert_items",
     "yield_connection",
     "yield_primary_key_columns",
 ]