dycw-utilities 0.135.0__py3-none-any.whl → 0.178.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dycw-utilities might be problematic.
- dycw_utilities-0.178.1.dist-info/METADATA +34 -0
- dycw_utilities-0.178.1.dist-info/RECORD +105 -0
- dycw_utilities-0.178.1.dist-info/WHEEL +4 -0
- dycw_utilities-0.178.1.dist-info/entry_points.txt +4 -0
- utilities/__init__.py +1 -1
- utilities/altair.py +13 -10
- utilities/asyncio.py +312 -787
- utilities/atomicwrites.py +18 -6
- utilities/atools.py +64 -4
- utilities/cachetools.py +9 -6
- utilities/click.py +195 -77
- utilities/concurrent.py +1 -1
- utilities/contextlib.py +216 -17
- utilities/contextvars.py +20 -1
- utilities/cryptography.py +3 -3
- utilities/dataclasses.py +15 -28
- utilities/docker.py +387 -0
- utilities/enum.py +2 -2
- utilities/errors.py +17 -3
- utilities/fastapi.py +28 -59
- utilities/fpdf2.py +2 -2
- utilities/functions.py +24 -269
- utilities/git.py +9 -30
- utilities/grp.py +28 -0
- utilities/gzip.py +31 -0
- utilities/http.py +3 -2
- utilities/hypothesis.py +513 -159
- utilities/importlib.py +17 -1
- utilities/inflect.py +12 -4
- utilities/iterables.py +33 -58
- utilities/jinja2.py +148 -0
- utilities/json.py +70 -0
- utilities/libcst.py +38 -17
- utilities/lightweight_charts.py +4 -7
- utilities/logging.py +136 -93
- utilities/math.py +8 -4
- utilities/more_itertools.py +43 -45
- utilities/operator.py +27 -27
- utilities/orjson.py +189 -36
- utilities/os.py +61 -4
- utilities/packaging.py +115 -0
- utilities/parse.py +8 -5
- utilities/pathlib.py +269 -40
- utilities/permissions.py +298 -0
- utilities/platform.py +7 -6
- utilities/polars.py +1205 -413
- utilities/polars_ols.py +1 -1
- utilities/postgres.py +408 -0
- utilities/pottery.py +43 -19
- utilities/pqdm.py +3 -3
- utilities/psutil.py +5 -57
- utilities/pwd.py +28 -0
- utilities/pydantic.py +4 -52
- utilities/pydantic_settings.py +240 -0
- utilities/pydantic_settings_sops.py +76 -0
- utilities/pyinstrument.py +7 -7
- utilities/pytest.py +104 -143
- utilities/pytest_plugins/__init__.py +1 -0
- utilities/pytest_plugins/pytest_randomly.py +23 -0
- utilities/pytest_plugins/pytest_regressions.py +56 -0
- utilities/pytest_regressions.py +26 -46
- utilities/random.py +11 -6
- utilities/re.py +1 -1
- utilities/redis.py +220 -343
- utilities/sentinel.py +10 -0
- utilities/shelve.py +4 -1
- utilities/shutil.py +25 -0
- utilities/slack_sdk.py +35 -104
- utilities/sqlalchemy.py +496 -471
- utilities/sqlalchemy_polars.py +29 -54
- utilities/string.py +2 -3
- utilities/subprocess.py +1977 -0
- utilities/tempfile.py +112 -4
- utilities/testbook.py +50 -0
- utilities/text.py +174 -42
- utilities/throttle.py +158 -0
- utilities/timer.py +2 -2
- utilities/traceback.py +70 -35
- utilities/types.py +102 -30
- utilities/typing.py +479 -19
- utilities/uuid.py +42 -5
- utilities/version.py +27 -26
- utilities/whenever.py +1559 -361
- utilities/zoneinfo.py +80 -22
- dycw_utilities-0.135.0.dist-info/METADATA +0 -39
- dycw_utilities-0.135.0.dist-info/RECORD +0 -96
- dycw_utilities-0.135.0.dist-info/WHEEL +0 -4
- dycw_utilities-0.135.0.dist-info/licenses/LICENSE +0 -21
- utilities/aiolimiter.py +0 -25
- utilities/arq.py +0 -216
- utilities/eventkit.py +0 -388
- utilities/luigi.py +0 -183
- utilities/period.py +0 -152
- utilities/pudb.py +0 -62
- utilities/python_dotenv.py +0 -101
- utilities/streamlit.py +0 -105
- utilities/typed_settings.py +0 -123
utilities/sqlalchemy.py
CHANGED
@@ -12,21 +12,31 @@ from collections.abc import (
 )
 from collections.abc import Set as AbstractSet
 from contextlib import asynccontextmanager
-from dataclasses import dataclass
-from functools import
+from dataclasses import dataclass
+from functools import reduce
 from itertools import chain
 from math import floor
 from operator import ge, le
 from re import search
-from
+from socket import gaierror
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Literal,
+    TypeGuard,
+    assert_never,
+    cast,
+    overload,
+    override,
+)
 
+import sqlalchemy
 from sqlalchemy import (
     URL,
     Column,
     Connection,
     Engine,
     Insert,
-    PrimaryKeyConstraint,
     Selectable,
     Table,
     and_,
@@ -38,16 +48,14 @@ from sqlalchemy import (
 from sqlalchemy.dialects.mssql import dialect as mssql_dialect
 from sqlalchemy.dialects.mysql import dialect as mysql_dialect
 from sqlalchemy.dialects.oracle import dialect as oracle_dialect
-from sqlalchemy.dialects.postgresql import Insert as postgresql_Insert
 from sqlalchemy.dialects.postgresql import dialect as postgresql_dialect
 from sqlalchemy.dialects.postgresql import insert as postgresql_insert
 from sqlalchemy.dialects.postgresql.asyncpg import PGDialect_asyncpg
-from sqlalchemy.dialects.sqlite import Insert as sqlite_Insert
+from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg
 from sqlalchemy.dialects.sqlite import dialect as sqlite_dialect
 from sqlalchemy.dialects.sqlite import insert as sqlite_insert
 from sqlalchemy.exc import ArgumentError, DatabaseError
-from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
-from sqlalchemy.ext.asyncio import create_async_engine as _create_async_engine
+from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
 from sqlalchemy.orm import (
     DeclarativeBase,
     InstrumentedAttribute,
@@ -57,16 +65,8 @@ from sqlalchemy.orm import (
 from sqlalchemy.orm.exc import UnmappedClassError
 from sqlalchemy.pool import NullPool, Pool
 
-from utilities.asyncio import
-from utilities.
-from utilities.functions import (
-    ensure_str,
-    get_class_name,
-    is_sequence_of_tuple_or_str_mapping,
-    is_string_mapping,
-    is_tuple,
-    is_tuple_or_str_mapping,
-)
+from utilities.asyncio import timeout_td
+from utilities.functions import ensure_str, get_class_name, yield_object_attributes
 from utilities.iterables import (
     CheckLengthError,
     CheckSubSetError,
@@ -79,19 +79,63 @@ from utilities.iterables import (
     merge_str_mappings,
     one,
 )
+from utilities.os import is_pytest
 from utilities.reprlib import get_repr
-from utilities.text import snake_case
-from utilities.types import
-
+from utilities.text import secret_str, snake_case
+from utilities.types import (
+    Delta,
+    MaybeIterable,
+    MaybeType,
+    StrMapping,
+    TupleOrStrMapping,
+)
+from utilities.typing import (
+    is_sequence_of_tuple_or_str_mapping,
+    is_string_mapping,
+    is_tuple,
+    is_tuple_or_str_mapping,
+)
 
 if TYPE_CHECKING:
-    from
+    from enum import Enum, StrEnum
+
 
-type _EngineOrConnectionOrAsync = Engine | Connection | AsyncEngine | AsyncConnection
+type EngineOrConnectionOrAsync = Engine | Connection | AsyncEngine | AsyncConnection
 type Dialect = Literal["mssql", "mysql", "oracle", "postgresql", "sqlite"]
+type DialectOrEngineOrConnectionOrAsync = Dialect | EngineOrConnectionOrAsync
 type ORMInstOrClass = DeclarativeBase | type[DeclarativeBase]
 type TableOrORMInstOrClass = Table | ORMInstOrClass
-CHUNK_SIZE_FRAC = 0.
+CHUNK_SIZE_FRAC = 0.8
+
+
+##
+
+
+_SELECT = text("SELECT 1")
+
+
+def check_connect(engine: Engine, /) -> bool:
+    """Check if an engine can connect."""
+    try:
+        with engine.connect() as conn:
+            return bool(conn.execute(_SELECT).scalar_one())
+    except (gaierror, ConnectionRefusedError, DatabaseError):  # pragma: no cover
+        return False
+
+
+async def check_connect_async(
+    engine: AsyncEngine,
+    /,
+    *,
+    timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
+) -> bool:
+    """Check if an engine can connect."""
+    try:
+        async with timeout_td(timeout, error=error), engine.connect() as conn:
+            return bool((await conn.execute(_SELECT)).scalar_one())
+    except (gaierror, ConnectionRefusedError, DatabaseError, TimeoutError):
+        return False
 
 
 ##
@@ -101,8 +145,8 @@ async def check_engine(
     engine: AsyncEngine,
     /,
     *,
-    timeout: TimeDelta | None = None,
-    error: type[Exception] = TimeoutError,
+    timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
     num_tables: int | tuple[int, float] | None = None,
 ) -> None:
     """Check that an engine can connect.
@@ -117,7 +161,7 @@ async def check_engine(
             query = "select * from all_objects"
         case "sqlite":
             query = "select * from sqlite_master where type='table'"
-        case _ as never:
+        case never:
             assert_never(never)
     statement = text(query)
     async with yield_connection(engine, timeout=timeout, error=error) as conn:
@@ -185,7 +229,8 @@ def _columnwise_minmax(*columns: Any, op: Callable[[Any, Any], Any]) -> Any:
 ##
 
 
-def create_async_engine(
+@overload
+def create_engine(
     drivername: str,
     /,
     *,
@@ -196,7 +241,49 @@ def create_async_engine(
     database: str | None = None,
     query: StrMapping | None = None,
     poolclass: type[Pool] | None = NullPool,
-) -> AsyncEngine:
+    async_: Literal[True],
+) -> AsyncEngine: ...
+@overload
+def create_engine(
+    drivername: str,
+    /,
+    *,
+    username: str | None = None,
+    password: str | None = None,
+    host: str | None = None,
+    port: int | None = None,
+    database: str | None = None,
+    query: StrMapping | None = None,
+    poolclass: type[Pool] | None = NullPool,
+    async_: Literal[False] = False,
+) -> Engine: ...
+@overload
+def create_engine(
+    drivername: str,
+    /,
+    *,
+    username: str | None = None,
+    password: str | None = None,
+    host: str | None = None,
+    port: int | None = None,
+    database: str | None = None,
+    query: StrMapping | None = None,
+    poolclass: type[Pool] | None = NullPool,
+    async_: bool = False,
+) -> Engine | AsyncEngine: ...
+def create_engine(
+    drivername: str,
+    /,
+    *,
+    username: str | None = None,
+    password: str | None = None,
+    host: str | None = None,
+    port: int | None = None,
+    database: str | None = None,
+    query: StrMapping | None = None,
+    poolclass: type[Pool] | None = NullPool,
+    async_: bool = False,
+) -> Engine | AsyncEngine:
     """Create a SQLAlchemy engine."""
     if query is None:
         kwargs = {}
@@ -215,7 +302,47 @@ def create_async_engine(
         database=database,
         **kwargs,
     )
-    return _create_async_engine(url, poolclass=poolclass)
+    match async_:
+        case False:
+            return sqlalchemy.create_engine(url, poolclass=poolclass)
+        case True:
+            return create_async_engine(url, poolclass=poolclass)
+        case never:
+            assert_never(never)
+
+
+##
+
+
+async def ensure_database_created(super_: URL, database: str, /) -> None:
+    """Ensure a database is created."""
+    engine = create_async_engine(super_, isolation_level="AUTOCOMMIT")
+    async with engine.begin() as conn:
+        try:
+            _ = await conn.execute(text(f"CREATE DATABASE {database}"))
+        except DatabaseError as error:
+            _ensure_tables_maybe_reraise(error, 'database ".*" already exists')
+
+
+async def ensure_database_dropped(super_: URL, database: str, /) -> None:
+    """Ensure a database is dropped."""
+    engine = create_async_engine(super_, isolation_level="AUTOCOMMIT")
+    async with engine.begin() as conn:
+        _ = await conn.execute(text(f"DROP DATABASE IF EXISTS {database}"))
+
+
+async def ensure_database_users_disconnected(super_: URL, database: str, /) -> None:
+    """Ensure a databases' users are disconnected."""
+    engine = create_async_engine(super_, isolation_level="AUTOCOMMIT")
+    match dialect := _get_dialect(engine):
+        case "postgresql":  # skipif-ci-and-not-linux
+            query = f"SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = {database!r} AND pid <> pg_backend_pid()"  # noqa: S608
+        case "mssql" | "mysql" | "oracle" | "sqlite":  # pragma: no cover
+            raise NotImplementedError(dialect)
+        case never:
+            assert_never(never)
+    async with engine.begin() as conn:
+        _ = await conn.execute(text(query))
 
 
 ##
@@ -225,8 +352,8 @@ async def ensure_tables_created(
     engine: AsyncEngine,
     /,
     *tables_or_orms: TableOrORMInstOrClass,
-    timeout: TimeDelta | None = None,
-    error: type[Exception] = TimeoutError,
+    timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
 ) -> None:
     """Ensure a table/set of tables is/are created."""
     tables = set(map(get_table, tables_or_orms))
@@ -241,7 +368,7 @@ async def ensure_tables_created(
             match = "ORA-00955: name is already used by an existing object"
         case "sqlite":
             match = "table .* already exists"
-        case _ as never:
+        case never:
             assert_never(never)
     async with yield_connection(engine, timeout=timeout, error=error) as conn:
         for table in tables:
@@ -254,8 +381,8 @@ async def ensure_tables_created(
 async def ensure_tables_dropped(
     engine: AsyncEngine,
     *tables_or_orms: TableOrORMInstOrClass,
-    timeout: TimeDelta | None = None,
-    error: type[Exception] = TimeoutError,
+    timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
 ) -> None:
     """Ensure a table/set of tables is/are dropped."""
     tables = set(map(get_table, tables_or_orms))
@@ -270,7 +397,7 @@ async def ensure_tables_dropped(
             match = "ORA-00942: table or view does not exist"
         case "sqlite":
             match = "no such table"
-        case _ as never:
+        case never:
            assert_never(never)
     async with yield_connection(engine, timeout=timeout, error=error) as conn:
         for table in tables:
@@ -283,18 +410,120 @@ async def ensure_tables_dropped(
 ##
 
 
+def enum_name(enum: type[Enum], /) -> str:
+    """Get the name of an Enum."""
+    return f"{snake_case(get_class_name(enum))}_enum"
+
+
+##
+
+
+def enum_values(enum: type[StrEnum], /) -> list[str]:
+    """Get the values of a StrEnum."""
+    return [e.value for e in enum]
+
+
+##
+
+
+@dataclass(kw_only=True, slots=True)
+class ExtractURLOutput:
+    username: str
+    password: secret_str
+    host: str
+    port: int
+    database: str
+
+
+def extract_url(url: URL, /) -> ExtractURLOutput:
+    """Extract the database, host & port from a URL."""
+    if url.username is None:
+        raise _ExtractURLUsernameError(url=url)
+    if url.password is None:
+        raise _ExtractURLPasswordError(url=url)
+    if url.host is None:
+        raise _ExtractURLHostError(url=url)
+    if url.port is None:
+        raise _ExtractURLPortError(url=url)
+    if url.database is None:
+        raise _ExtractURLDatabaseError(url=url)
+    return ExtractURLOutput(
+        username=url.username,
+        password=secret_str(url.password),
+        host=url.host,
+        port=url.port,
+        database=url.database,
+    )
+
+
+@dataclass(kw_only=True, slots=True)
+class ExtractURLError(Exception):
+    url: URL
+
+
+@dataclass(kw_only=True, slots=True)
+class _ExtractURLUsernameError(ExtractURLError):
+    @override
+    def __str__(self) -> str:
+        return f"Expected URL to contain a user name; got {self.url}"
+
+
+@dataclass(kw_only=True, slots=True)
+class _ExtractURLPasswordError(ExtractURLError):
+    @override
+    def __str__(self) -> str:
+        return f"Expected URL to contain a password; got {self.url}"
+
+
+@dataclass(kw_only=True, slots=True)
+class _ExtractURLHostError(ExtractURLError):
+    @override
+    def __str__(self) -> str:
+        return f"Expected URL to contain a host; got {self.url}"
+
+
+@dataclass(kw_only=True, slots=True)
+class _ExtractURLPortError(ExtractURLError):
+    @override
+    def __str__(self) -> str:
+        return f"Expected URL to contain a port; got {self.url}"
+
+
+@dataclass(kw_only=True, slots=True)
+class _ExtractURLDatabaseError(ExtractURLError):
+    @override
+    def __str__(self) -> str:
+        return f"Expected URL to contain a database; got {self.url}"
+
+
+##
+
+
 def get_chunk_size(
-
+    dialect_or_engine_or_conn: DialectOrEngineOrConnectionOrAsync,
+    table_or_orm_or_num_cols: TableOrORMInstOrClass | Sized | int,
     /,
     *,
     chunk_size_frac: float = CHUNK_SIZE_FRAC,
-    max_length: int = 1,
 ) -> int:
     """Get the maximum chunk size for an engine."""
-    max_params = _get_dialect_max_params(
-
-
-
+    max_params = _get_dialect_max_params(dialect_or_engine_or_conn)
+    match table_or_orm_or_num_cols:
+        case Table() | DeclarativeBase() | type() as table_or_orm:
+            return get_chunk_size(
+                dialect_or_engine_or_conn,
+                get_columns(table_or_orm),
+                chunk_size_frac=chunk_size_frac,
+            )
+        case Sized() as sized:
+            return get_chunk_size(
+                dialect_or_engine_or_conn, len(sized), chunk_size_frac=chunk_size_frac
+            )
+        case int() as num_cols:
+            size = floor(chunk_size_frac * max_params / num_cols)
+            return max(size, 1)
+        case never:
+            assert_never(never)
 
 
 ##
@@ -374,18 +603,22 @@ type _InsertItem = (
     | Sequence[_PairOfTupleOrStrMappingAndTable]
     | Sequence[DeclarativeBase]
 )
+type _NormalizedItem = tuple[Table, StrMapping]
+type _InsertPair = tuple[Table, Sequence[StrMapping]]
+type _InsertTriple = tuple[Table, Insert, Sequence[StrMapping] | None]
 
 
 async def insert_items(
     engine: AsyncEngine,
     *items: _InsertItem,
     snake: bool = False,
+    is_upsert: bool = False,
     chunk_size_frac: float = CHUNK_SIZE_FRAC,
     assume_tables_exist: bool = False,
-    timeout_create: TimeDelta | None = None,
-    error_create: type[Exception] = TimeoutError,
-    timeout_insert: TimeDelta | None = None,
-    error_insert: type[Exception] = TimeoutError,
+    timeout_create: Delta | None = None,
+    error_create: MaybeType[BaseException] = TimeoutError,
+    timeout_insert: Delta | None = None,
+    error_insert: MaybeType[BaseException] = TimeoutError,
 ) -> None:
     """Insert a set of items into a database.
 
@@ -409,37 +642,181 @@ async def insert_items(
              Obj(k1=v21, k2=v22, ...),
              ...]
     """
-
-
-
-
-
-
-            return insert(table), values
-        case _:
-            return insert(table).values(list(values)), None
-
-    try:
-        prepared = _prepare_insert_or_upsert_items(
-            partial(_normalize_insert_item, snake=snake),
-            engine,
-            build_insert,
-            *items,
-            chunk_size_frac=chunk_size_frac,
-        )
-    except _PrepareInsertOrUpsertItemsError as error:
-        raise InsertItemsError(item=error.item) from None
+    normalized = chain.from_iterable(
+        _insert_items_yield_normalized(i, snake=snake) for i in items
+    )
+    triples = _insert_items_yield_triples(
+        engine, normalized, is_upsert=is_upsert, chunk_size_frac=chunk_size_frac
+    )
     if not assume_tables_exist:
+        triples = list(triples)
+        tables = {table for table, _, _ in triples}
         await ensure_tables_created(
-            engine, *
+            engine, *tables, timeout=timeout_create, error=error_create
         )
-    for ins, parameters in
+    for _, ins, parameters in triples:
         async with yield_connection(
             engine, timeout=timeout_insert, error=error_insert
         ) as conn:
             _ = await conn.execute(ins, parameters=parameters)
 
 
+def _insert_items_yield_normalized(
+    item: _InsertItem, /, *, snake: bool = False
+) -> Iterator[_NormalizedItem]:
+    if _is_pair_of_str_mapping_and_table(item):
+        mapping, table_or_orm = item
+        adjusted = _map_mapping_to_table(mapping, table_or_orm, snake=snake)
+        yield (get_table(table_or_orm), adjusted)
+        return
+    if _is_pair_of_tuple_and_table(item):
+        tuple_, table_or_orm = item
+        mapping = _tuple_to_mapping(tuple_, table_or_orm)
+        yield from _insert_items_yield_normalized((mapping, table_or_orm), snake=snake)
+        return
+    if _is_pair_of_sequence_of_tuple_or_string_mapping_and_table(item):
+        items, table_or_orm = item
+        pairs = [(i, table_or_orm) for i in items]
+        for p in pairs:
+            yield from _insert_items_yield_normalized(p, snake=snake)
+        return
+    if isinstance(item, DeclarativeBase):
+        mapping = _orm_inst_to_dict(item)
+        yield from _insert_items_yield_normalized((mapping, item), snake=snake)
+        return
+    try:
+        _ = iter(item)
+    except TypeError:
+        raise InsertItemsError(item=item) from None
+    if all(map(_is_pair_of_tuple_or_str_mapping_and_table, item)):
+        pairs = cast("Sequence[_PairOfTupleOrStrMappingAndTable]", item)
+        for p in pairs:
+            yield from _insert_items_yield_normalized(p, snake=snake)
+        return
+    if all(map(is_orm, item)):
+        classes = cast("Sequence[DeclarativeBase]", item)
+        for c in classes:
+            yield from _insert_items_yield_normalized(c, snake=snake)
+        return
+    raise InsertItemsError(item=item)
+
+
+def _insert_items_yield_triples(
+    engine: AsyncEngine,
+    items: Iterable[_NormalizedItem],
+    /,
+    *,
+    is_upsert: bool = False,
+    chunk_size_frac: float = CHUNK_SIZE_FRAC,
+) -> Iterable[_InsertTriple]:
+    pairs = _insert_items_yield_chunked_pairs(
+        engine, items, is_upsert=is_upsert, chunk_size_frac=chunk_size_frac
+    )
+    for table, mappings in pairs:
+        match is_upsert, _get_dialect(engine):
+            case False, "oracle":  # pragma: no cover
+                ins = insert(table)
+                parameters = mappings
+            case False, _:
+                ins = insert(table).values(mappings)
+                parameters = None
+            case True, _:
+                ins = _insert_items_build_insert_with_on_conflict_do_update(
+                    engine, table, mappings
+                )
+                parameters = None
+            case never:
+                assert_never(never)
+        yield table, ins, parameters
+
+
+def _insert_items_yield_chunked_pairs(
+    engine: AsyncEngine,
+    items: Iterable[_NormalizedItem],
+    /,
+    *,
+    is_upsert: bool = False,
+    chunk_size_frac: float = CHUNK_SIZE_FRAC,
+) -> Iterable[_InsertPair]:
+    for table, mappings in _insert_items_yield_raw_pairs(items, is_upsert=is_upsert):
+        chunk_size = get_chunk_size(engine, table, chunk_size_frac=chunk_size_frac)
+        for mappings_i in chunked(mappings, chunk_size):
+            yield table, list(mappings_i)
+
+
+def _insert_items_yield_raw_pairs(
+    items: Iterable[_NormalizedItem], /, *, is_upsert: bool = False
+) -> Iterable[_InsertPair]:
+    by_table: defaultdict[Table, list[StrMapping]] = defaultdict(list)
+    for table, mapping in items:
+        by_table[table].append(mapping)
+    for table, mappings in by_table.items():
+        yield from _insert_items_yield_raw_pairs_one(
+            table, mappings, is_upsert=is_upsert
+        )
+
+
+def _insert_items_yield_raw_pairs_one(
+    table: Table, mappings: Iterable[StrMapping], /, *, is_upsert: bool = False
+) -> Iterable[_InsertPair]:
+    merged = _insert_items_yield_merged_mappings(table, mappings)
+    match is_upsert:
+        case True:
+            by_keys: defaultdict[frozenset[str], list[StrMapping]] = defaultdict(list)
+            for mapping in merged:
+                non_null = {k: v for k, v in mapping.items() if v is not None}
+                by_keys[frozenset(non_null)].append(non_null)
+            for mappings_i in by_keys.values():
+                yield table, mappings_i
+        case False:
+            yield table, list(merged)
+        case never:
+            assert_never(never)
+
+
+def _insert_items_yield_merged_mappings(
+    table: Table, mappings: Iterable[StrMapping], /
+) -> Iterable[StrMapping]:
+    columns = list(yield_primary_key_columns(table))
+    col_names = [c.name for c in columns]
+    cols_auto = {c.name for c in columns if c.autoincrement in {True, "auto"}}
+    cols_non_auto = set(col_names) - cols_auto
+    by_key: defaultdict[tuple[Hashable, ...], list[StrMapping]] = defaultdict(list)
+    for mapping in mappings:
+        check_subset(cols_non_auto, mapping)
+        has_all_auto = set(cols_auto).issubset(mapping)
+        if has_all_auto:
+            pkey = tuple(mapping[k] for k in col_names)
+            rest: StrMapping = {k: v for k, v in mapping.items() if k not in col_names}
+            by_key[pkey].append(rest)
+        else:
+            yield mapping
+    for k, v in by_key.items():
+        head = dict(zip(col_names, k, strict=True))
+        yield merge_str_mappings(head, *v)
+
+
+def _insert_items_build_insert_with_on_conflict_do_update(
+    engine: AsyncEngine, table: Table, mappings: Iterable[StrMapping], /
+) -> Insert:
+    primary_key = cast("Any", table.primary_key)
+    mappings = list(mappings)
+    columns = merge_sets(*mappings)
+    match _get_dialect(engine):
+        case "postgresql":  # skipif-ci-and-not-linux
+            ins = postgresql_insert(table).values(mappings)
+            set_ = {c: getattr(ins.excluded, c) for c in columns}
+            return ins.on_conflict_do_update(constraint=primary_key, set_=set_)
+        case "sqlite":
+            ins = sqlite_insert(table).values(mappings)
+            set_ = {c: getattr(ins.excluded, c) for c in columns}
+            return ins.on_conflict_do_update(index_elements=primary_key, set_=set_)
+        case "mssql" | "mysql" | "oracle" as dialect:  # pragma: no cover
+            raise NotImplementedError(dialect)
+        case never:
+            assert_never(never)
+
+
 @dataclass(kw_only=True, slots=True)
 class InsertItemsError(Exception):
     item: _InsertItem
@@ -483,10 +860,10 @@ async def migrate_data(
     table_or_orm_to: TableOrORMInstOrClass | None = None,
     chunk_size_frac: float = CHUNK_SIZE_FRAC,
     assume_tables_exist: bool = False,
-    timeout_create: TimeDelta | None = None,
-    error_create: type[Exception] = TimeoutError,
-    timeout_insert: TimeDelta | None = None,
-    error_insert: type[Exception] = TimeoutError,
+    timeout_create: Delta | None = None,
+    error_create: MaybeType[BaseException] = TimeoutError,
+    timeout_insert: Delta | None = None,
+    error_insert: MaybeType[BaseException] = TimeoutError,
 ) -> None:
     """Migrate the contents of a table from one database to another."""
     table_from = get_table(table_or_orm_from)
@@ -510,82 +887,8 @@ async def migrate_data(
 ##
 
 
-def _normalize_insert_item(
-    item: _InsertItem, /, *, snake: bool = False
-) -> list[_NormalizedItem]:
-    """Normalize an insertion item."""
-    if _is_pair_of_str_mapping_and_table(item):
-        mapping, table_or_orm = item
-        adjusted = _map_mapping_to_table(mapping, table_or_orm, snake=snake)
-        normalized = _NormalizedItem(mapping=adjusted, table=get_table(table_or_orm))
-        return [normalized]
-    if _is_pair_of_tuple_and_table(item):
-        tuple_, table_or_orm = item
-        mapping = _tuple_to_mapping(tuple_, table_or_orm)
-        return _normalize_insert_item((mapping, table_or_orm), snake=snake)
-    if _is_pair_of_sequence_of_tuple_or_string_mapping_and_table(item):
-        items, table_or_orm = item
-        pairs = [(i, table_or_orm) for i in items]
-        normalized = (_normalize_insert_item(p, snake=snake) for p in pairs)
-        return list(chain.from_iterable(normalized))
-    if isinstance(item, DeclarativeBase):
-        mapping = _orm_inst_to_dict(item)
-        return _normalize_insert_item((mapping, item), snake=snake)
-    try:
-        _ = iter(item)
-    except TypeError:
-        raise _NormalizeInsertItemError(item=item) from None
-    if all(map(_is_pair_of_tuple_or_str_mapping_and_table, item)):
-        seq = cast("Sequence[_PairOfTupleOrStrMappingAndTable]", item)
-        normalized = (_normalize_insert_item(p, snake=snake) for p in seq)
-        return list(chain.from_iterable(normalized))
-    if all(map(is_orm, item)):
-        seq = cast("Sequence[DeclarativeBase]", item)
-        normalized = (_normalize_insert_item(p, snake=snake) for p in seq)
-        return list(chain.from_iterable(normalized))
-    raise _NormalizeInsertItemError(item=item)
-
-
-@dataclass(kw_only=True, slots=True)
-class _NormalizeInsertItemError(Exception):
-    item: _InsertItem
-
-    @override
-    def __str__(self) -> str:
-        return f"Item must be valid; got {self.item}"
-
-
-@dataclass(kw_only=True, slots=True)
-class _NormalizedItem:
-    mapping: StrMapping
-    table: Table
-
-
-def _normalize_upsert_item(
-    item: _InsertItem,
-    /,
-    *,
-    snake: bool = False,
-    selected_or_all: Literal["selected", "all"] = "selected",
-) -> Iterator[_NormalizedItem]:
-    """Normalize an upsert item."""
-    normalized = _normalize_insert_item(item, snake=snake)
-    match selected_or_all:
-        case "selected":
-            for norm in normalized:
-                values = {k: v for k, v in norm.mapping.items() if v is not None}
-                yield _NormalizedItem(mapping=values, table=norm.table)
-        case "all":
-            yield from normalized
-        case _ as never:
-            assert_never(never)
-
-
-##
-
-
 def selectable_to_string(
-    selectable: Selectable[Any], engine_or_conn: _EngineOrConnectionOrAsync, /
+    selectable: Selectable[Any], engine_or_conn: EngineOrConnectionOrAsync, /
 ) -> str:
     """Convert a selectable into a string."""
     com = selectable.compile(
@@ -608,237 +911,22 @@ class TablenameMixin:
 ##
 
 
-@dataclass(kw_only=True)
-class UpsertService(Looper[_InsertItem]):
-    """Service to upsert items to a database."""
-
-    # base
-    freq: TimeDelta = field(default=SECOND, repr=False)
-    backoff: TimeDelta = field(default=SECOND, repr=False)
-    empty_upon_exit: bool = field(default=True, repr=False)
-    # self
-    engine: AsyncEngine
-    snake: bool = False
-    selected_or_all: _SelectedOrAll = "selected"
-    chunk_size_frac: float = CHUNK_SIZE_FRAC
-    assume_tables_exist: bool = False
-    timeout_create: TimeDelta | None = None
-    error_create: type[Exception] = TimeoutError
-    timeout_insert: TimeDelta | None = None
-    error_insert: type[Exception] = TimeoutError
-
-    @override
-    async def core(self) -> None:
-        await super().core()
-        await upsert_items(
-            self.engine,
-            *self.get_all_nowait(),
-            snake=self.snake,
-            selected_or_all=self.selected_or_all,
-            chunk_size_frac=self.chunk_size_frac,
-            assume_tables_exist=self.assume_tables_exist,
-            timeout_create=self.timeout_create,
-            error_create=self.error_create,
-            timeout_insert=self.timeout_insert,
-            error_insert=self.error_insert,
-        )
-
-
-@dataclass(kw_only=True)
-class UpsertServiceMixin:
-    """Mix-in for the upsert service."""
-
-    # base - looper
-    upsert_service_freq: TimeDelta = field(default=SECOND, repr=False)
-    upsert_service_backoff: TimeDelta = field(default=SECOND, repr=False)
-    upsert_service_empty_upon_exit: bool = field(default=False, repr=False)
-    upsert_service_logger: str | None = field(default=None, repr=False)
-    upsert_service_timeout: TimeDelta | None = field(default=None, repr=False)
-    upsert_service_debug: bool = field(default=False, repr=False)
-    # base - upsert service
-    upsert_service_database: AsyncEngine
-    upsert_service_snake: bool = False
-    upsert_service_selected_or_all: _SelectedOrAll = "selected"
-    upsert_service_chunk_size_frac: float = CHUNK_SIZE_FRAC
-    upsert_service_assume_tables_exist: bool = False
-    upsert_service_timeout_create: TimeDelta | None = None
-    upsert_service_error_create: type[Exception] = TimeoutError
-    upsert_service_timeout_insert: TimeDelta | None = None
-    upsert_service_error_insert: type[Exception] = TimeoutError
-    # self
-    _upsert_service: UpsertService = field(init=False, repr=False)
-
-    def __post_init__(self) -> None:
-        with suppress_super_object_attribute_error():
-            super().__post_init__()  # pyright: ignore[reportAttributeAccessIssue]
-        self._upsert_service = UpsertService(
-            # looper
-            freq=self.upsert_service_freq,
-            backoff=self.upsert_service_backoff,
-            empty_upon_exit=self.upsert_service_empty_upon_exit,
-            logger=self.upsert_service_logger,
-            timeout=self.upsert_service_timeout,
-            _debug=self.upsert_service_debug,
-            # upsert service
-            engine=self.upsert_service_database,
-            snake=self.upsert_service_snake,
-            selected_or_all=self.upsert_service_selected_or_all,
-            chunk_size_frac=self.upsert_service_chunk_size_frac,
-            assume_tables_exist=self.upsert_service_assume_tables_exist,
-            timeout_create=self.upsert_service_timeout_create,
-            error_create=self.upsert_service_error_create,
-            timeout_insert=self.upsert_service_timeout_insert,
-            error_insert=self.upsert_service_error_insert,
-        )
-
-    def _yield_sub_loopers(self) -> Iterator[Looper[Any]]:
-        with suppress_super_object_attribute_error():
-            yield from super()._yield_sub_loopers()  # pyright: ignore[reportAttributeAccessIssue]
-        yield self._upsert_service
-
-
-##
-
-
-type _SelectedOrAll = Literal["selected", "all"]
-
-
-async def upsert_items(
-    engine: AsyncEngine,
-    /,
-    *items: _InsertItem,
-    snake: bool = False,
-    selected_or_all: _SelectedOrAll = "selected",
-    chunk_size_frac: float = CHUNK_SIZE_FRAC,
-    assume_tables_exist: bool = False,
-    timeout_create: TimeDelta | None = None,
-    error_create: type[Exception] = TimeoutError,
-    timeout_insert: TimeDelta | None = None,
-    error_insert: type[Exception] = TimeoutError,
-) -> None:
-    """Upsert a set of items into a database.
-
-    These can be one of the following:
-    - pair of dict & table/class: {k1=v1, k2=v2, ...), table_cls
-    - pair of list of dicts & table/class: [{k1=v11, k2=v12, ...},
-                                            {k1=v21, k2=v22, ...},
-                                            ...], table/class
-    - list of pairs of dict & table/class: [({k1=v11, k2=v12, ...}, table_cls1),
-                                            ({k1=v21, k2=v22, ...}, table_cls2),
-                                            ...]
-    - mapped class: Obj(k1=v1, k2=v2, ...)
-    - list of mapped classes: [Obj(k1=v11, k2=v12, ...),
-                               Obj(k1=v21, k2=v22, ...),
-                               ...]
-    """
-
-    def build_insert(
-        table: Table, values: Iterable[StrMapping], /
-    ) -> tuple[Insert, None]:
-        ups = _upsert_items_build(
-            engine, table, values, selected_or_all=selected_or_all
-        )
-        return ups, None
-
-    try:
-        prepared = _prepare_insert_or_upsert_items(
-            partial(
-                _normalize_upsert_item, snake=snake, selected_or_all=selected_or_all
-            ),
-            engine,
-            build_insert,
-            *items,
-            chunk_size_frac=chunk_size_frac,
-        )
-    except _PrepareInsertOrUpsertItemsError as error:
-        raise UpsertItemsError(item=error.item) from None
-    if not assume_tables_exist:
-        await ensure_tables_created(
-            engine, *prepared.tables, timeout=timeout_create, error=error_create
-        )
-    for ups, _ in prepared.yield_pairs():
-        async with yield_connection(
-            engine, timeout=timeout_insert, error=error_insert
-        ) as conn:
-            _ = await conn.execute(ups)
-
-
-def _upsert_items_build(
-    engine: AsyncEngine,
-    table: Table,
-    values: Iterable[StrMapping],
-    /,
-    *,
-    selected_or_all: Literal["selected", "all"] = "selected",
-) -> Insert:
-    values = list(values)
-    keys = merge_sets(*values)
-    dict_nones = dict.fromkeys(keys)
-    values = [{**dict_nones, **v} for v in values]
-    match _get_dialect(engine):
-        case "postgresql":  # skipif-ci-and-not-linux
-            insert = postgresql_insert
-        case "sqlite":
-            insert = sqlite_insert
-        case "mssql" | "mysql" | "oracle" as dialect:  # pragma: no cover
-            raise NotImplementedError(dialect)
-        case _ as never:
-            assert_never(never)
-    ins = insert(table).values(values)
-    primary_key = cast("Any", table.primary_key)
-    return _upsert_items_apply_on_conflict_do_update(
-        values, ins, primary_key, selected_or_all=selected_or_all
-    )
-
-
-def _upsert_items_apply_on_conflict_do_update(
-    values: Iterable[StrMapping],
-    insert: postgresql_Insert | sqlite_Insert,
-    primary_key: PrimaryKeyConstraint,
-    /,
-    *,
-    selected_or_all: Literal["selected", "all"] = "selected",
-) -> Insert:
-    match selected_or_all:
-        case "selected":
-            columns = merge_sets(*values)
-        case "all":
-            columns = {c.name for c in insert.excluded}
-        case _ as never:
-            assert_never(never)
-    set_ = {c: getattr(insert.excluded, c) for c in columns}
-    match insert:
-        case postgresql_Insert():  # skipif-ci
-            return insert.on_conflict_do_update(constraint=primary_key, set_=set_)
-        case sqlite_Insert():
-            return insert.on_conflict_do_update(index_elements=primary_key, set_=set_)
-        case _ as never:
-            assert_never(never)
-
-
-@dataclass(kw_only=True, slots=True)
-class UpsertItemsError(Exception):
-    item: _InsertItem
-
-    @override
-    def __str__(self) -> str:
-        return f"Item must be valid; got {self.item}"
-
-
-##
-
-
 @asynccontextmanager
 async def yield_connection(
     engine: AsyncEngine,
     /,
     *,
-    timeout: TimeDelta | None = None,
+    timeout: Delta | None = None,
     error: MaybeType[BaseException] = TimeoutError,
 ) -> AsyncIterator[AsyncConnection]:
     """Yield an async connection."""
-
-
+    try:
+        async with timeout_td(timeout, error=error), engine.begin() as conn:
+            yield conn
+    except GeneratorExit:  # pragma: no cover
+        if not is_pytest():
+            raise
+        return
 
 
 ##
@@ -869,7 +957,7 @@ def _ensure_tables_maybe_reraise(error: DatabaseError, match: str, /) -> None:
 ##
 
 
-def _get_dialect(engine_or_conn: _EngineOrConnectionOrAsync, /) -> Dialect:
+def _get_dialect(engine_or_conn: EngineOrConnectionOrAsync, /) -> Dialect:
     """Get the dialect of a database."""
     dialect = engine_or_conn.dialect
     if isinstance(dialect, mssql_dialect):  # pragma: no cover
@@ -879,7 +967,7 @@ def _get_dialect(engine_or_conn: _EngineOrConnectionOrAsync, /) -> Dialect:
     if isinstance(dialect, oracle_dialect):  # pragma: no cover
         return "oracle"
     if isinstance(  # skipif-ci-and-not-linux
-        dialect, postgresql_dialect
+        dialect, (postgresql_dialect, PGDialect_asyncpg, PGDialect_psycopg)
     ):
         return "postgresql"
     if isinstance(dialect, sqlite_dialect):
@@ -892,7 +980,7 @@ def _get_dialect(engine_or_conn: _EngineOrConnectionOrAsync, /) -> Dialect:
 
 
 def _get_dialect_max_params(
-    dialect_or_engine_or_conn:
+    dialect_or_engine_or_conn: DialectOrEngineOrConnectionOrAsync, /
 ) -> int:
     """Get the max number of parameters of a dialect."""
     match dialect_or_engine_or_conn:
@@ -914,7 +1002,7 @@ def _get_dialect_max_params(
         ):
             dialect = _get_dialect(engine_or_conn)
             return _get_dialect_max_params(dialect)
-        case _ as never:
+        case never:
             assert_never(never)
 
 
@@ -999,7 +1087,7 @@ def _map_mapping_to_table(
 @dataclass(kw_only=True, slots=True)
 class _MapMappingToTableError(Exception):
     mapping: StrMapping
-    columns:
+    columns: list[str]
 
 
 @dataclass(kw_only=True, slots=True)
@@ -1036,102 +1124,31 @@ class _MapMappingToTableSnakeMapNonUniqueError(_MapMappingToTableError):
 
 def _orm_inst_to_dict(obj: DeclarativeBase, /) -> StrMapping:
     """Map an ORM instance to a dictionary."""
-
-
-
-
-
-        ):
-            return attr
-        return None
-
-    def yield_items() -> Iterator[tuple[str, Any]]:
-        for key in get_column_names(cls):
-            attr = one(attr for attr in dir(cls) if is_attr(attr, key) is not None)
-            yield key, getattr(obj, attr)
-
-    return dict(yield_items())
-
-
-##
-
-
-@dataclass(kw_only=True, slots=True)
-class _PrepareInsertOrUpsertItems:
-    mapping: dict[Table, list[StrMapping]] = field(default_factory=dict)
-    yield_pairs: Callable[[], Iterator[tuple[Insert, Any]]]
-
-    @property
-    def tables(self) -> Sequence[Table]:
-        return list(self.mapping)
-
-
-def _prepare_insert_or_upsert_items(
-    normalize_item: Callable[[_InsertItem], Iterable[_NormalizedItem]],
-    engine: AsyncEngine,
-    build_insert: Callable[[Table, Iterable[StrMapping]], tuple[Insert, Any]],
-    /,
-    *items: Any,
-    chunk_size_frac: float = CHUNK_SIZE_FRAC,
-) -> _PrepareInsertOrUpsertItems:
-    """Prepare a set of insert/upsert items."""
-    mapping: defaultdict[Table, list[StrMapping]] = defaultdict(list)
-    lengths: set[int] = set()
-    try:
-        for item in items:
-            for normed in normalize_item(item):
-                mapping[normed.table].append(normed.mapping)
-                lengths.add(len(normed.mapping))
-    except _NormalizeInsertItemError as error:
-        raise _PrepareInsertOrUpsertItemsError(item=error.item) from None
-    merged = {
-        table: _prepare_insert_or_upsert_items_merge_items(table, values)
-        for table, values in mapping.items()
+    attrs = {
+        k for k, _ in yield_object_attributes(obj, static_type=InstrumentedAttribute)
+    }
+    return {
+        name: _orm_inst_to_dict_one(obj, attrs, name) for name in get_column_names(obj)
     }
-    max_length = max(lengths, default=1)
-    chunk_size = get_chunk_size(
-        engine, chunk_size_frac=chunk_size_frac, max_length=max_length
-    )
-
-    def yield_pairs() -> Iterator[tuple[Insert, None]]:
-        for table, values in merged.items():
-            for chunk in chunked(values, chunk_size):
-                yield build_insert(table, chunk)
-
-    return _PrepareInsertOrUpsertItems(mapping=mapping, yield_pairs=yield_pairs)
 
 
-
-
-
-
-
-
-
+def _orm_inst_to_dict_one(
+    obj: DeclarativeBase, attrs: AbstractSet[str], name: str, /
+) -> Any:
+    attr = one(
+        attr for attr in attrs if _orm_inst_to_dict_predicate(type(obj), attr, name)
+    )
+    return getattr(obj, attr)
 
 
-def _prepare_insert_or_upsert_items_merge_items(
-
-) ->
-
-
-
-
-
-    unchanged: list[StrMapping] = []
-    for item in items:
-        check_subset(cols_non_auto, item)
-        has_all_auto = set(cols_auto).issubset(item)
-        if has_all_auto:
-            pkey = tuple(item[k] for k in col_names)
-            rest: StrMapping = {k: v for k, v in item.items() if k not in col_names}
-            mapping[pkey].append(rest)
-        else:
-            unchanged.append(item)
-    merged = {k: merge_str_mappings(*v) for k, v in mapping.items()}
-    return [
-        dict(zip(col_names, k, strict=True)) | dict(v) for k, v in merged.items()
-    ] + unchanged
+def _orm_inst_to_dict_predicate(
+    cls: type[DeclarativeBase], attr: str, name: str, /
+) -> bool:
+    cls_attr = getattr(cls, attr)
+    try:
+        return cls_attr.name == name
+    except AttributeError:
+        return False
 
 
 ##
@@ -1148,18 +1165,27 @@ def _tuple_to_mapping(
 __all__ = [
     "CHUNK_SIZE_FRAC",
     "CheckEngineError",
+    "DialectOrEngineOrConnectionOrAsync",
+    "EngineOrConnectionOrAsync",
+    "ExtractURLError",
+    "ExtractURLOutput",
     "GetTableError",
     "InsertItemsError",
     "TablenameMixin",
-    "UpsertItemsError",
-    "UpsertService",
-    "UpsertServiceMixin",
+    "check_connect",
+    "check_connect_async",
     "check_engine",
     "columnwise_max",
     "columnwise_min",
-    "create_async_engine",
+    "create_engine",
+    "ensure_database_created",
+    "ensure_database_dropped",
+    "ensure_database_users_disconnected",
    "ensure_tables_created",
     "ensure_tables_dropped",
+    "enum_name",
+    "enum_values",
+    "extract_url",
     "get_chunk_size",
     "get_column_names",
     "get_columns",
@@ -1172,7 +1198,6 @@ __all__ = [
     "is_table_or_orm",
     "migrate_data",
     "selectable_to_string",
-    "upsert_items",
     "yield_connection",
     "yield_primary_key_columns",
 ]
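For orientation, here is a minimal usage sketch of the reworked API, based only on the signatures visible in the diff above: `create_engine` gains an `async_` overload, `check_connect_async` probes connectivity, and `insert_items(..., is_upsert=True)` replaces the removed `upsert_items` entry point. The driver string, table, and values are hypothetical, and the sketch assumes the `aiosqlite` driver is installed; it is not taken from the package's own documentation.

import asyncio

from sqlalchemy import Column, Integer, MetaData, String, Table

from utilities.sqlalchemy import check_connect_async, create_engine, insert_items

# Hypothetical table for illustration; not part of the package.
example = Table(
    "example",
    MetaData(),
    Column("id", Integer, primary_key=True),
    Column("name", String),
)


async def main() -> None:
    # async_=True selects the AsyncEngine overload of create_engine.
    engine = create_engine("sqlite+aiosqlite", database="example.db", async_=True)
    if not await check_connect_async(engine):
        return
    # is_upsert=True routes through the on-conflict-do-update path that
    # replaced the removed upsert_items entry point; the table is created
    # automatically unless assume_tables_exist=True is passed.
    await insert_items(engine, ({"id": 1, "name": "a"}, example), is_upsert=True)
    await insert_items(engine, ({"id": 1, "name": "b"}, example), is_upsert=True)


asyncio.run(main())

Per the `create_engine` overloads in the diff, omitting `async_` (or passing `async_=False`) returns a plain `Engine` instead, which pairs with the synchronous `check_connect`.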