dycw-utilities 0.157.1__py3-none-any.whl → 0.158.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dycw_utilities-0.157.1.dist-info → dycw_utilities-0.158.0.dist-info}/METADATA +1 -1
- {dycw_utilities-0.157.1.dist-info → dycw_utilities-0.158.0.dist-info}/RECORD +11 -11
- utilities/__init__.py +1 -1
- utilities/fastapi.py +8 -3
- utilities/hypothesis.py +22 -0
- utilities/slack_sdk.py +8 -3
- utilities/sqlalchemy.py +186 -331
- utilities/sqlalchemy_polars.py +20 -44
- {dycw_utilities-0.157.1.dist-info → dycw_utilities-0.158.0.dist-info}/WHEEL +0 -0
- {dycw_utilities-0.157.1.dist-info → dycw_utilities-0.158.0.dist-info}/entry_points.txt +0 -0
- {dycw_utilities-0.157.1.dist-info → dycw_utilities-0.158.0.dist-info}/licenses/LICENSE +0 -0
{dycw_utilities-0.157.1.dist-info → dycw_utilities-0.158.0.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-utilities/__init__.py,sha256=
+utilities/__init__.py,sha256=dllUsKScS7RBnWFHME-sLeOoDAWms1u25frX9QMEP-4,60
 utilities/altair.py,sha256=92E2lCdyHY4Zb-vCw6rEJIsWdKipuu-Tu2ab1ufUfAk,9079
 utilities/asyncio.py,sha256=PUedzQ5deqlSECQ33sam9cRzI9TnygHz3FdOqWJWPTM,15288
 utilities/atomicwrites.py,sha256=tPo6r-Rypd9u99u66B9z86YBPpnLrlHtwox_8Z7T34Y,5790
@@ -14,7 +14,7 @@ utilities/dataclasses.py,sha256=MXrvIPSZHlpV4msRdVVDRZZo7MC3gX5C9jDUSoNOdpE,3247
 utilities/enum.py,sha256=5l6pwZD1cjSlVW4ss-zBPspWvrbrYrdtJWcg6f5_J5w,5781
 utilities/errors.py,sha256=mFlDGSM0LI1jZ1pbqwLAH3ttLZ2JVIxyZLojw8tGVZU,1479
 utilities/eventkit.py,sha256=ddoleSwW9zdc2tjX5Ge0pMKtYwV_JMxhHYOxnWX2AGM,12609
-utilities/fastapi.py,sha256=
+utilities/fastapi.py,sha256=TqyKvBjiMS594sXPjrz-KRTLMb3l3D3rZ1zAYV7GfOk,1454
 utilities/fpdf2.py,sha256=HgM8JSvoioDXrjC0UR3HVLjnMnnb_mML7nL2EmkTwGI,1854
 utilities/functions.py,sha256=RNVAoLeT_sl-gXaBv2VI_U_EB-d-nSVosYR4gTeeojE,28261
 utilities/functools.py,sha256=I00ru2gQPakZw2SHVeKIKXfTv741655s6HI0lUoE0D4,1552
@@ -22,7 +22,7 @@ utilities/getpass.py,sha256=DfN5UgMAtFCqS3dSfFHUfqIMZX2shXvwphOz_6J6f6A,103
 utilities/gzip.py,sha256=fkGP3KdsBfXlstodT4wtlp-PwNyUsogpbDCVVVGdsm4,781
 utilities/hashlib.py,sha256=SVTgtguur0P4elppvzOBbLEjVM3Pea0eWB61yg2ilxo,309
 utilities/http.py,sha256=TsavEfHlRtlLaeV21Z6KZh0qbPw-kvD1zsQdZ7Kep5Q,977
-utilities/hypothesis.py,sha256=
+utilities/hypothesis.py,sha256=OG8mxN6Y3fSEjRg4NiIjsO_JUHJBzh4g8fvpmKRoRU8,44370
 utilities/importlib.py,sha256=mV1xT_O_zt_GnZZ36tl3xOmMaN_3jErDWY54fX39F6Y,429
 utilities/inflect.py,sha256=v7YkOWSu8NAmVghPcf4F3YBZQoJCS47_DLf9jbfWIs0,581
 utilities/ipython.py,sha256=V2oMYHvEKvlNBzxDXdLvKi48oUq2SclRg5xasjaXStw,763
@@ -62,10 +62,10 @@ utilities/reprlib.py,sha256=ssYTcBW-TeRh3fhCJv57sopTZHF5FrPyyUg9yp5XBlo,3953
 utilities/scipy.py,sha256=wZJM7fEgBAkLSYYvSmsg5ac-QuwAI0BGqHVetw1_Hb0,947
 utilities/sentinel.py,sha256=A_p5jX2K0Yc5XBfoYHyBLqHsEWzE1ByOdDuzzA2pZnE,1434
 utilities/shelve.py,sha256=4OzjQI6kGuUbJciqf535rwnao-_IBv66gsT6tRGiUt0,759
-utilities/slack_sdk.py,sha256=
+utilities/slack_sdk.py,sha256=76-DYtcGiUhEvl-voMamc5OjfF7Y7nCq54Bys1arqzw,2233
 utilities/socket.py,sha256=K77vfREvzoVTrpYKo6MZakol0EYu2q1sWJnnZqL0So0,118
-utilities/sqlalchemy.py,sha256=
-utilities/sqlalchemy_polars.py,sha256=
+utilities/sqlalchemy.py,sha256=4rLjecgDe60MnuWxUqZVIQy3dKepNaemoGyrnuCdNZg,36347
+utilities/sqlalchemy_polars.py,sha256=JCGhB37raSR7fqeWV5dTsciRTMVzIdVT9YSqKT0piT0,13370
 utilities/statsmodels.py,sha256=koyiBHvpMcSiBfh99wFUfSggLNx7cuAw3rwyfAhoKpQ,3410
 utilities/string.py,sha256=shmBK87zZwzGyixuNuXCiUbqzfeZ9xlrFwz6JTaRvDk,582
 utilities/tempfile.py,sha256=HxB2BF28CcecDJLQ3Bx2Ej-Pb6RJc6W9ngSpB9CnP4k,2018
@@ -87,8 +87,8 @@ utilities/zoneinfo.py,sha256=FBMcUQ4662Aq8SsuCL1OAhDQiyANmVjtb-C30DRrWoE,1966
 utilities/pytest_plugins/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
 utilities/pytest_plugins/pytest_randomly.py,sha256=B1qYVlExGOxTywq2r1SMi5o7btHLk2PNdY_b1p98dkE,409
 utilities/pytest_plugins/pytest_regressions.py,sha256=9v8kAXDM2ycIXJBimoiF4EgrwbUvxTycFWJiGR_GHhM,1466
-dycw_utilities-0.
-dycw_utilities-0.
-dycw_utilities-0.
-dycw_utilities-0.
-dycw_utilities-0.
+dycw_utilities-0.158.0.dist-info/METADATA,sha256=-lYD9v1af9rVk_ibo4Z2cLp-i4JgpoR7qiHy1svZeDU,1643
+dycw_utilities-0.158.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dycw_utilities-0.158.0.dist-info/entry_points.txt,sha256=BOD_SoDxwsfJYOLxhrSXhHP_T7iw-HXI9f2WVkzYxvQ,135
+dycw_utilities-0.158.0.dist-info/licenses/LICENSE,sha256=gppZp16M6nSVpBbUBrNL6JuYfvKwZiKgV7XoKKsHzqo,1066
+dycw_utilities-0.158.0.dist-info/RECORD,,
utilities/__init__.py
CHANGED
utilities/fastapi.py
CHANGED
@@ -13,7 +13,7 @@ from utilities.whenever import get_now_local
 if TYPE_CHECKING:
     from collections.abc import AsyncIterator
 
-    from utilities.types import Delta
+    from utilities.types import Delta, MaybeType
 
 
 _TASKS: list[Task[None]] = []
@@ -35,14 +35,19 @@ class _PingerReceiverApp(FastAPI):
 
 @enhanced_async_context_manager
 async def yield_ping_receiver(
-    port: int,
+    port: int,
+    /,
+    *,
+    host: str = "localhost",
+    timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
 ) -> AsyncIterator[None]:
     """Yield the ping receiver."""
     app = _PingerReceiverApp()  # skipif-ci
     server = Server(Config(app, host=host, port=port))  # skipif-ci
     _TASKS.append(create_task(server.serve()))  # skipif-ci
     try:  # skipif-ci
-        async with timeout_td(timeout):
+        async with timeout_td(timeout, error=error):
             yield
     finally:  # skipif-ci
         await server.shutdown()
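The new keyword-only `error` parameter is threaded through to `timeout_td`, so a caller can surface a domain-specific exception when the receiver's window closes. A minimal sketch, assuming `Delta` accepts a `whenever.TimeDelta`; the port and the exception class are illustrative, not part of the package:

```python
from whenever import TimeDelta

from utilities.fastapi import yield_ping_receiver


class ReceiverWindowClosed(Exception):
    """Hypothetical exception for an expired ping window."""


async def run_receiver() -> None:
    # `error=...` makes the internal `timeout_td` raise this exception
    # instead of the default `TimeoutError` once `timeout` elapses.
    async with yield_ping_receiver(
        5465, timeout=TimeDelta(seconds=10), error=ReceiverWindowClosed
    ):
        ...  # serve pings for up to ten seconds
```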
utilities/hypothesis.py
CHANGED
@@ -897,6 +897,27 @@ def py_datetimes(
 ##
 
 
+def quadruples[T](
+    strategy: SearchStrategy[T],
+    /,
+    *,
+    unique: MaybeSearchStrategy[bool] = False,
+    sorted: MaybeSearchStrategy[bool] = False,  # noqa: A002
+) -> SearchStrategy[tuple[T, T, T, T]]:
+    """Strategy for generating quadruples of elements."""
+    return lists_fixed_length(strategy, 4, unique=unique, sorted=sorted).map(
+        _quadruples_map
+    )
+
+
+def _quadruples_map[T](elements: list[T], /) -> tuple[T, T, T, T]:
+    first, second, third, fourth = elements
+    return first, second, third, fourth
+
+
+##
+
+
 @composite
 def random_states(
     draw: DrawFn, /, *, seed: MaybeSearchStrategy[int | None] = None
@@ -1555,6 +1576,7 @@ __all__ = [
     "paths",
     "plain_date_times",
     "py_datetimes",
+    "quadruples",
     "random_states",
     "sentinels",
     "sets_fixed_length",
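`quadruples` follows the package's existing fixed-length helpers: it draws a four-element list via `lists_fixed_length` and unpacks it into a tuple. A usage sketch (the property being tested is illustrative):

```python
from hypothesis import given
from hypothesis.strategies import integers

from utilities.hypothesis import quadruples


@given(quadruples(integers(), unique=True, sorted=True))
def test_quadruples_are_strictly_increasing(q: tuple[int, int, int, int]) -> None:
    first, second, third, fourth = q
    # `unique=True` plus `sorted=True` should give a strictly increasing tuple.
    assert first < second < third < fourth
```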
utilities/slack_sdk.py
CHANGED
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
     from slack_sdk.webhook import WebhookResponse
     from whenever import TimeDelta
 
-    from utilities.types import Delta
+    from utilities.types import Delta, MaybeType
 
 
 _TIMEOUT: Delta = MINUTE
@@ -39,11 +39,16 @@ def _get_client(url: str, /, *, timeout: Delta = _TIMEOUT) -> WebhookClient:
 
 
 async def send_to_slack_async(
-    url: str,
+    url: str,
+    text: str,
+    /,
+    *,
+    timeout: TimeDelta = _TIMEOUT,
+    error: MaybeType[BaseException] = TimeoutError,
 ) -> None:
     """Send a message via Slack."""
     client = _get_async_client(url, timeout=timeout)
-    async with timeout_td(timeout):
+    async with timeout_td(timeout, error=error):
         response = await client.send(text=text)
     if response.status_code != HTTPStatus.OK:  # pragma: no cover
         raise SendToSlackError(text=text, response=response)
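`send_to_slack_async` now takes `text` positionally, with keyword-only `timeout` and `error`, mirroring the `utilities/fastapi.py` change above. A sketch with a placeholder webhook URL:

```python
from whenever import TimeDelta

from utilities.slack_sdk import send_to_slack_async


class SlackSendTimedOut(Exception):
    """Hypothetical exception for an expired send window."""


async def notify() -> None:
    await send_to_slack_async(
        "https://hooks.slack.com/services/...",  # placeholder URL
        "deploy finished",
        timeout=TimeDelta(seconds=30),
        error=SlackSendTimedOut,  # raised instead of TimeoutError on expiry
    )
```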
utilities/sqlalchemy.py
CHANGED
@@ -12,8 +12,8 @@ from collections.abc import (
 )
 from collections.abc import Set as AbstractSet
 from contextlib import asynccontextmanager
-from dataclasses import dataclass
-from functools import
+from dataclasses import dataclass
+from functools import reduce
 from itertools import chain
 from math import floor
 from operator import ge, le
@@ -37,7 +37,6 @@ from sqlalchemy import (
     Connection,
     Engine,
     Insert,
-    PrimaryKeyConstraint,
     Selectable,
     Table,
     and_,
@@ -49,20 +48,13 @@ from sqlalchemy import (
 from sqlalchemy.dialects.mssql import dialect as mssql_dialect
 from sqlalchemy.dialects.mysql import dialect as mysql_dialect
 from sqlalchemy.dialects.oracle import dialect as oracle_dialect
-from sqlalchemy.dialects.postgresql import Insert as postgresql_Insert
 from sqlalchemy.dialects.postgresql import dialect as postgresql_dialect
 from sqlalchemy.dialects.postgresql import insert as postgresql_insert
 from sqlalchemy.dialects.postgresql.asyncpg import PGDialect_asyncpg
 from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg
-from sqlalchemy.dialects.sqlite import Insert as sqlite_Insert
 from sqlalchemy.dialects.sqlite import dialect as sqlite_dialect
 from sqlalchemy.dialects.sqlite import insert as sqlite_insert
-from sqlalchemy.exc import (
-    ArgumentError,
-    DatabaseError,
-    OperationalError,
-    ProgrammingError,
-)
+from sqlalchemy.exc import ArgumentError, DatabaseError
 from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
 from sqlalchemy.orm import (
     DeclarativeBase,
@@ -129,24 +121,22 @@ def check_connect(engine: Engine, /) -> bool:
     try:
         with engine.connect() as conn:
             return bool(conn.execute(_SELECT).scalar_one())
-    except (gaierror, ConnectionRefusedError,
+    except (gaierror, ConnectionRefusedError, DatabaseError):  # pragma: no cover
         return False
 
 
 async def check_connect_async(
-    engine: AsyncEngine,
+    engine: AsyncEngine,
+    /,
+    *,
+    timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
 ) -> bool:
     """Check if an engine can connect."""
     try:
-        async with timeout_td(timeout), engine.connect() as conn:
+        async with timeout_td(timeout, error=error), engine.connect() as conn:
             return bool((await conn.execute(_SELECT)).scalar_one())
-    except (
-        gaierror,
-        ConnectionRefusedError,
-        OperationalError,
-        ProgrammingError,
-        TimeoutError,
-    ):
+    except (gaierror, ConnectionRefusedError, DatabaseError, TimeoutError):
         return False
 
 
@@ -158,7 +148,7 @@ async def check_engine(
     /,
     *,
     timeout: Delta | None = None,
-    error:
+    error: MaybeType[Exception] = TimeoutError,
     num_tables: int | tuple[int, float] | None = None,
 ) -> None:
     """Check that an engine can connect.
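`check_connect_async` gains the same keyword-only `timeout`/`error` pair, and the exception handling narrows to `DatabaseError` (SQLAlchemy's common base class) instead of enumerating `OperationalError` and `ProgrammingError`. Note that a custom `error` type is only reported as `False` if it falls under the caught types; anything else propagates. A sketch, with a placeholder connection URL:

```python
from sqlalchemy.ext.asyncio import create_async_engine
from whenever import TimeDelta

from utilities.sqlalchemy import check_connect_async


async def probe() -> bool:
    engine = create_async_engine(
        "postgresql+asyncpg://user:pass@localhost/db"  # placeholder URL
    )
    # With the default error=TimeoutError, an expired timeout is caught by
    # `except (gaierror, ConnectionRefusedError, DatabaseError, TimeoutError)`
    # and reported as False rather than raised.
    return await check_connect_async(engine, timeout=TimeDelta(seconds=5))
```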
@@ -332,9 +322,8 @@ async def ensure_database_created(super_: URL, database: str, /) -> None:
     async with engine.begin() as conn:
         try:
             _ = await conn.execute(text(f"CREATE DATABASE {database}"))
-        except
-
-            raise
+        except DatabaseError as error:
+            _ensure_tables_maybe_reraise(error, 'database ".*" already exists')
 
 
 async def ensure_database_dropped(super_: URL, database: str, /) -> None:
@@ -352,7 +341,7 @@ async def ensure_tables_created(
     /,
     *tables_or_orms: TableOrORMInstOrClass,
     timeout: Delta | None = None,
-    error:
+    error: MaybeType[Exception] = TimeoutError,
 ) -> None:
     """Ensure a table/set of tables is/are created."""
     tables = set(map(get_table, tables_or_orms))
@@ -381,7 +370,7 @@ async def ensure_tables_dropped(
     engine: AsyncEngine,
     *tables_or_orms: TableOrORMInstOrClass,
     timeout: Delta | None = None,
-    error:
+    error: MaybeType[Exception] = TimeoutError,
 ) -> None:
     """Ensure a table/set of tables is/are dropped."""
     tables = set(map(get_table, tables_or_orms))
@@ -602,12 +591,16 @@ type _InsertItem = (
     | Sequence[_PairOfTupleOrStrMappingAndTable]
     | Sequence[DeclarativeBase]
 )
+type _NormalizedItem = tuple[Table, StrMapping]
+type _InsertPair = tuple[Table, Sequence[StrMapping]]
+type _InsertTriple = tuple[Table, Insert, Sequence[StrMapping] | None]
 
 
 async def insert_items(
     engine: AsyncEngine,
     *items: _InsertItem,
     snake: bool = False,
+    is_upsert: bool = False,
     chunk_size_frac: float = CHUNK_SIZE_FRAC,
     assume_tables_exist: bool = False,
     timeout_create: Delta | None = None,
@@ -637,37 +630,181 @@ async def insert_items(
                                Obj(k1=v21, k2=v22, ...),
                                ...]
     """
-
-
-
-
-
-
-            return insert(table), values
-        case _:
-            return insert(table).values(list(values)), None
-
-    try:
-        prepared = _prepare_insert_or_upsert_items(
-            partial(_normalize_insert_item, snake=snake),
-            engine,
-            build_insert,
-            *items,
-            chunk_size_frac=chunk_size_frac,
-        )
-    except _PrepareInsertOrUpsertItemsError as error:
-        raise InsertItemsError(item=error.item) from None
+    normalized = chain.from_iterable(
+        _insert_items_yield_normalized(i, snake=snake) for i in items
+    )
+    triples = _insert_items_yield_triples(
+        engine, normalized, is_upsert=is_upsert, chunk_size_frac=chunk_size_frac
+    )
     if not assume_tables_exist:
+        triples = list(triples)
+        tables = {table for table, _, _ in triples}
         await ensure_tables_created(
-            engine, *
+            engine, *tables, timeout=timeout_create, error=error_create
         )
-    for ins, parameters in
+    for _, ins, parameters in triples:
         async with yield_connection(
             engine, timeout=timeout_insert, error=error_insert
         ) as conn:
             _ = await conn.execute(ins, parameters=parameters)
 
 
+def _insert_items_yield_normalized(
+    item: _InsertItem, /, *, snake: bool = False
+) -> Iterator[_NormalizedItem]:
+    if _is_pair_of_str_mapping_and_table(item):
+        mapping, table_or_orm = item
+        adjusted = _map_mapping_to_table(mapping, table_or_orm, snake=snake)
+        yield (get_table(table_or_orm), adjusted)
+        return
+    if _is_pair_of_tuple_and_table(item):
+        tuple_, table_or_orm = item
+        mapping = _tuple_to_mapping(tuple_, table_or_orm)
+        yield from _insert_items_yield_normalized((mapping, table_or_orm), snake=snake)
+        return
+    if _is_pair_of_sequence_of_tuple_or_string_mapping_and_table(item):
+        items, table_or_orm = item
+        pairs = [(i, table_or_orm) for i in items]
+        for p in pairs:
+            yield from _insert_items_yield_normalized(p, snake=snake)
+        return
+    if isinstance(item, DeclarativeBase):
+        mapping = _orm_inst_to_dict(item)
+        yield from _insert_items_yield_normalized((mapping, item), snake=snake)
+        return
+    try:
+        _ = iter(item)
+    except TypeError:
+        raise InsertItemsError(item=item) from None
+    if all(map(_is_pair_of_tuple_or_str_mapping_and_table, item)):
+        pairs = cast("Sequence[_PairOfTupleOrStrMappingAndTable]", item)
+        for p in pairs:
+            yield from _insert_items_yield_normalized(p, snake=snake)
+        return
+    if all(map(is_orm, item)):
+        classes = cast("Sequence[DeclarativeBase]", item)
+        for c in classes:
+            yield from _insert_items_yield_normalized(c, snake=snake)
+        return
+    raise InsertItemsError(item=item)
+
+
+def _insert_items_yield_triples(
+    engine: AsyncEngine,
+    items: Iterable[_NormalizedItem],
+    /,
+    *,
+    is_upsert: bool = False,
+    chunk_size_frac: float = CHUNK_SIZE_FRAC,
+) -> Iterable[_InsertTriple]:
+    pairs = _insert_items_yield_chunked_pairs(
+        engine, items, is_upsert=is_upsert, chunk_size_frac=chunk_size_frac
+    )
+    for table, mappings in pairs:
+        match is_upsert, _get_dialect(engine):
+            case False, "oracle":  # pragma: no cover
+                ins = insert(table)
+                parameters = mappings
+            case False, _:
+                ins = insert(table).values(mappings)
+                parameters = None
+            case True, _:
+                ins = _insert_items_build_insert_with_on_conflict_do_update(
+                    engine, table, mappings
+                )
+                parameters = None
+            case never:
+                assert_never(never)
+        yield table, ins, parameters
+
+
+def _insert_items_yield_chunked_pairs(
+    engine: AsyncEngine,
+    items: Iterable[_NormalizedItem],
+    /,
+    *,
+    is_upsert: bool = False,
+    chunk_size_frac: float = CHUNK_SIZE_FRAC,
+) -> Iterable[_InsertPair]:
+    for table, mappings in _insert_items_yield_raw_pairs(items, is_upsert=is_upsert):
+        chunk_size = get_chunk_size(engine, table, chunk_size_frac=chunk_size_frac)
+        for mappings_i in chunked(mappings, chunk_size):
+            yield table, list(mappings_i)
+
+
+def _insert_items_yield_raw_pairs(
+    items: Iterable[_NormalizedItem], /, *, is_upsert: bool = False
+) -> Iterable[_InsertPair]:
+    by_table: defaultdict[Table, list[StrMapping]] = defaultdict(list)
+    for table, mapping in items:
+        by_table[table].append(mapping)
+    for table, mappings in by_table.items():
+        yield from _insert_items_yield_raw_pairs_one(
+            table, mappings, is_upsert=is_upsert
+        )
+
+
+def _insert_items_yield_raw_pairs_one(
+    table: Table, mappings: Iterable[StrMapping], /, *, is_upsert: bool = False
+) -> Iterable[_InsertPair]:
+    merged = _insert_items_yield_merged_mappings(table, mappings)
+    match is_upsert:
+        case True:
+            by_keys: defaultdict[frozenset[str], list[StrMapping]] = defaultdict(list)
+            for mapping in merged:
+                non_null = {k: v for k, v in mapping.items() if v is not None}
+                by_keys[frozenset(non_null)].append(non_null)
+            for mappings_i in by_keys.values():
+                yield table, mappings_i
+        case False:
+            yield table, list(merged)
+        case never:
+            assert_never(never)
+
+
+def _insert_items_yield_merged_mappings(
+    table: Table, mappings: Iterable[StrMapping], /
+) -> Iterable[StrMapping]:
+    columns = list(yield_primary_key_columns(table))
+    col_names = [c.name for c in columns]
+    cols_auto = {c.name for c in columns if c.autoincrement in {True, "auto"}}
+    cols_non_auto = set(col_names) - cols_auto
+    by_key: defaultdict[tuple[Hashable, ...], list[StrMapping]] = defaultdict(list)
+    for mapping in mappings:
+        check_subset(cols_non_auto, mapping)
+        has_all_auto = set(cols_auto).issubset(mapping)
+        if has_all_auto:
+            pkey = tuple(mapping[k] for k in col_names)
+            rest: StrMapping = {k: v for k, v in mapping.items() if k not in col_names}
+            by_key[pkey].append(rest)
+        else:
+            yield mapping
+    for k, v in by_key.items():
+        head = dict(zip(col_names, k, strict=True))
+        yield merge_str_mappings(head, *v)
+
+
+def _insert_items_build_insert_with_on_conflict_do_update(
+    engine: AsyncEngine, table: Table, mappings: Iterable[StrMapping], /
+) -> Insert:
+    primary_key = cast("Any", table.primary_key)
+    mappings = list(mappings)
+    columns = merge_sets(*mappings)
+    match _get_dialect(engine):
+        case "postgresql":  # skipif-ci-and-not-linux
+            ins = postgresql_insert(table).values(mappings)
+            set_ = {c: getattr(ins.excluded, c) for c in columns}
+            return ins.on_conflict_do_update(constraint=primary_key, set_=set_)
+        case "sqlite":
+            ins = sqlite_insert(table).values(mappings)
+            set_ = {c: getattr(ins.excluded, c) for c in columns}
+            return ins.on_conflict_do_update(index_elements=primary_key, set_=set_)
+        case "mssql" | "mysql" | "oracle" as dialect:  # pragma: no cover
+            raise NotImplementedError(dialect)
+        case never:
+            assert_never(never)
+
+
 @dataclass(kw_only=True, slots=True)
 class InsertItemsError(Exception):
     item: _InsertItem
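This rewrite folds upserting into `insert_items` behind a single `is_upsert` flag: items are normalized to `(table, mapping)` pairs, merged on primary keys, chunked per `chunk_size_frac`, and, when upserting, compiled to `ON CONFLICT DO UPDATE` statements for PostgreSQL/SQLite (other dialects still raise `NotImplementedError`). A hedged sketch against a toy table:

```python
from sqlalchemy import Column, Integer, MetaData, String, Table
from sqlalchemy.ext.asyncio import create_async_engine

from utilities.sqlalchemy import insert_items

metadata = MetaData()
users = Table(  # toy table, for illustration only
    "users",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("name", String),
)


async def seed() -> None:
    engine = create_async_engine("sqlite+aiosqlite://")  # in-memory SQLite
    # Plain insert of two rows (pair of list-of-tuples and table):
    await insert_items(engine, ([(1, "Alice"), (2, "Bob")], users))
    # Re-running with is_upsert=True updates the conflicting row in place,
    # via the new on-conflict-do-update builder.
    await insert_items(engine, ({"id": 1, "name": "Alicia"}, users), is_upsert=True)
```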
@@ -738,80 +875,6 @@ async def migrate_data(
 ##
 
 
-def _normalize_insert_item(
-    item: _InsertItem, /, *, snake: bool = False
-) -> list[_NormalizedItem]:
-    """Normalize an insertion item."""
-    if _is_pair_of_str_mapping_and_table(item):
-        mapping, table_or_orm = item
-        adjusted = _map_mapping_to_table(mapping, table_or_orm, snake=snake)
-        normalized = _NormalizedItem(mapping=adjusted, table=get_table(table_or_orm))
-        return [normalized]
-    if _is_pair_of_tuple_and_table(item):
-        tuple_, table_or_orm = item
-        mapping = _tuple_to_mapping(tuple_, table_or_orm)
-        return _normalize_insert_item((mapping, table_or_orm), snake=snake)
-    if _is_pair_of_sequence_of_tuple_or_string_mapping_and_table(item):
-        items, table_or_orm = item
-        pairs = [(i, table_or_orm) for i in items]
-        normalized = (_normalize_insert_item(p, snake=snake) for p in pairs)
-        return list(chain.from_iterable(normalized))
-    if isinstance(item, DeclarativeBase):
-        mapping = _orm_inst_to_dict(item)
-        return _normalize_insert_item((mapping, item), snake=snake)
-    try:
-        _ = iter(item)
-    except TypeError:
-        raise _NormalizeInsertItemError(item=item) from None
-    if all(map(_is_pair_of_tuple_or_str_mapping_and_table, item)):
-        seq = cast("Sequence[_PairOfTupleOrStrMappingAndTable]", item)
-        normalized = (_normalize_insert_item(p, snake=snake) for p in seq)
-        return list(chain.from_iterable(normalized))
-    if all(map(is_orm, item)):
-        seq = cast("Sequence[DeclarativeBase]", item)
-        normalized = (_normalize_insert_item(p, snake=snake) for p in seq)
-        return list(chain.from_iterable(normalized))
-    raise _NormalizeInsertItemError(item=item)
-
-
-@dataclass(kw_only=True, slots=True)
-class _NormalizeInsertItemError(Exception):
-    item: _InsertItem
-
-    @override
-    def __str__(self) -> str:
-        return f"Item must be valid; got {self.item}"
-
-
-@dataclass(kw_only=True, slots=True)
-class _NormalizedItem:
-    mapping: StrMapping
-    table: Table
-
-
-def _normalize_upsert_item(
-    item: _InsertItem,
-    /,
-    *,
-    snake: bool = False,
-    selected_or_all: Literal["selected", "all"] = "selected",
-) -> Iterator[_NormalizedItem]:
-    """Normalize an upsert item."""
-    normalized = _normalize_insert_item(item, snake=snake)
-    match selected_or_all:
-        case "selected":
-            for norm in normalized:
-                values = {k: v for k, v in norm.mapping.items() if v is not None}
-                yield _NormalizedItem(mapping=values, table=norm.table)
-        case "all":
-            yield from normalized
-        case never:
-            assert_never(never)
-
-
-##
-
-
 def selectable_to_string(
     selectable: Selectable[Any], engine_or_conn: EngineOrConnectionOrAsync, /
 ) -> str:
@@ -836,134 +899,6 @@ class TablenameMixin:
 ##
 
 
-type _SelectedOrAll = Literal["selected", "all"]
-
-
-async def upsert_items(
-    engine: AsyncEngine,
-    /,
-    *items: _InsertItem,
-    snake: bool = False,
-    selected_or_all: _SelectedOrAll = "selected",
-    chunk_size_frac: float = CHUNK_SIZE_FRAC,
-    assume_tables_exist: bool = False,
-    timeout_create: Delta | None = None,
-    error_create: type[Exception] = TimeoutError,
-    timeout_insert: Delta | None = None,
-    error_insert: type[Exception] = TimeoutError,
-) -> None:
-    """Upsert a set of items into a database.
-
-    These can be one of the following:
-    - pair of dict & table/class: {k1=v1, k2=v2, ...), table_cls
-    - pair of list of dicts & table/class: [{k1=v11, k2=v12, ...},
-                                            {k1=v21, k2=v22, ...},
-                                            ...], table/class
-    - list of pairs of dict & table/class: [({k1=v11, k2=v12, ...}, table_cls1),
-                                            ({k1=v21, k2=v22, ...}, table_cls2),
-                                            ...]
-    - mapped class: Obj(k1=v1, k2=v2, ...)
-    - list of mapped classes: [Obj(k1=v11, k2=v12, ...),
-                               Obj(k1=v21, k2=v22, ...),
-                               ...]
-    """
-
-    def build_insert(
-        table: Table, values: Iterable[StrMapping], /
-    ) -> tuple[Insert, None]:
-        ups = _upsert_items_build(
-            engine, table, values, selected_or_all=selected_or_all
-        )
-        return ups, None
-
-    try:
-        prepared = _prepare_insert_or_upsert_items(
-            partial(
-                _normalize_upsert_item, snake=snake, selected_or_all=selected_or_all
-            ),
-            engine,
-            build_insert,
-            *items,
-            chunk_size_frac=chunk_size_frac,
-        )
-    except _PrepareInsertOrUpsertItemsError as error:
-        raise UpsertItemsError(item=error.item) from None
-    if not assume_tables_exist:
-        await ensure_tables_created(
-            engine, *prepared.tables, timeout=timeout_create, error=error_create
-        )
-    for ups, _ in prepared.yield_pairs():
-        async with yield_connection(
-            engine, timeout=timeout_insert, error=error_insert
-        ) as conn:
-            _ = await conn.execute(ups)
-
-
-def _upsert_items_build(
-    engine: AsyncEngine,
-    table: Table,
-    values: Iterable[StrMapping],
-    /,
-    *,
-    selected_or_all: Literal["selected", "all"] = "selected",
-) -> Insert:
-    values = list(values)
-    keys = merge_sets(*values)
-    dict_nones = dict.fromkeys(keys)
-    values = [{**dict_nones, **v} for v in values]
-    match _get_dialect(engine):
-        case "postgresql":  # skipif-ci-and-not-linux
-            insert = postgresql_insert
-        case "sqlite":
-            insert = sqlite_insert
-        case "mssql" | "mysql" | "oracle" as dialect:  # pragma: no cover
-            raise NotImplementedError(dialect)
-        case never:
-            assert_never(never)
-    ins = insert(table).values(values)
-    primary_key = cast("Any", table.primary_key)
-    return _upsert_items_apply_on_conflict_do_update(
-        values, ins, primary_key, selected_or_all=selected_or_all
-    )
-
-
-def _upsert_items_apply_on_conflict_do_update(
-    values: Iterable[StrMapping],
-    insert: postgresql_Insert | sqlite_Insert,
-    primary_key: PrimaryKeyConstraint,
-    /,
-    *,
-    selected_or_all: Literal["selected", "all"] = "selected",
-) -> Insert:
-    match selected_or_all:
-        case "selected":
-            columns = merge_sets(*values)
-        case "all":
-            columns = {c.name for c in insert.excluded}
-        case never:
-            assert_never(never)
-    set_ = {c: getattr(insert.excluded, c) for c in columns}
-    match insert:
-        case postgresql_Insert():  # skipif-ci
-            return insert.on_conflict_do_update(constraint=primary_key, set_=set_)
-        case sqlite_Insert():
-            return insert.on_conflict_do_update(index_elements=primary_key, set_=set_)
-        case never:
-            assert_never(never)
-
-
-@dataclass(kw_only=True, slots=True)
-class UpsertItemsError(Exception):
-    item: _InsertItem
-
-    @override
-    def __str__(self) -> str:
-        return f"Item must be valid; got {self.item}"
-
-
-##
-
-
 @asynccontextmanager
 async def yield_connection(
     engine: AsyncEngine,
@@ -1207,84 +1142,6 @@ def _orm_inst_to_dict_predicate(
 ##
 
 
-@dataclass(kw_only=True, slots=True)
-class _PrepareInsertOrUpsertItems:
-    mapping: dict[Table, list[StrMapping]] = field(default_factory=dict)
-    yield_pairs: Callable[[], Iterator[tuple[Insert, Any]]]
-
-    @property
-    def tables(self) -> Sequence[Table]:
-        return list(self.mapping)
-
-
-def _prepare_insert_or_upsert_items(
-    normalize_item: Callable[[_InsertItem], Iterable[_NormalizedItem]],
-    engine: AsyncEngine,
-    build_insert: Callable[[Table, Iterable[StrMapping]], tuple[Insert, Any]],
-    /,
-    *items: Any,
-    chunk_size_frac: float = CHUNK_SIZE_FRAC,
-) -> _PrepareInsertOrUpsertItems:
-    """Prepare a set of insert/upsert items."""
-    mapping: defaultdict[Table, list[StrMapping]] = defaultdict(list)
-    lengths: set[int] = set()
-    try:
-        for item in items:
-            for normed in normalize_item(item):
-                mapping[normed.table].append(normed.mapping)
-                lengths.add(len(normed.mapping))
-    except _NormalizeInsertItemError as error:
-        raise _PrepareInsertOrUpsertItemsError(item=error.item) from None
-    merged: dict[Table, list[StrMapping]] = {
-        table: _prepare_insert_or_upsert_items_merge_items(table, values)
-        for table, values in mapping.items()
-    }
-
-    def yield_pairs() -> Iterator[tuple[Insert, None]]:
-        for table, values in merged.items():
-            chunk_size = get_chunk_size(engine, table, chunk_size_frac=chunk_size_frac)
-            for chunk in chunked(values, chunk_size):
-                yield build_insert(table, chunk)
-
-    return _PrepareInsertOrUpsertItems(mapping=mapping, yield_pairs=yield_pairs)
-
-
-@dataclass(kw_only=True, slots=True)
-class _PrepareInsertOrUpsertItemsError(Exception):
-    item: Any
-
-    @override
-    def __str__(self) -> str:
-        return f"Item must be valid; got {self.item}"
-
-
-def _prepare_insert_or_upsert_items_merge_items(
-    table: Table, items: Iterable[StrMapping], /
-) -> list[StrMapping]:
-    columns = list(yield_primary_key_columns(table))
-    col_names = [c.name for c in columns]
-    cols_auto = {c.name for c in columns if c.autoincrement in {True, "auto"}}
-    cols_non_auto = set(col_names) - cols_auto
-    mapping: defaultdict[tuple[Hashable, ...], list[StrMapping]] = defaultdict(list)
-    unchanged: list[StrMapping] = []
-    for item in items:
-        check_subset(cols_non_auto, item)
-        has_all_auto = set(cols_auto).issubset(item)
-        if has_all_auto:
-            pkey = tuple(item[k] for k in col_names)
-            rest: StrMapping = {k: v for k, v in item.items() if k not in col_names}
-            mapping[pkey].append(rest)
-        else:
-            unchanged.append(item)
-    merged = {k: merge_str_mappings(*v) for k, v in mapping.items()}
-    return [
-        dict(zip(col_names, k, strict=True)) | dict(v) for k, v in merged.items()
-    ] + unchanged
-
-
-##
-
-
 def _tuple_to_mapping(
     values: tuple[Any, ...], table_or_orm: TableOrORMInstOrClass, /
 ) -> dict[str, Any]:
@@ -1303,7 +1160,6 @@ __all__ = [
     "GetTableError",
     "InsertItemsError",
     "TablenameMixin",
-    "UpsertItemsError",
     "check_connect",
     "check_connect_async",
     "check_engine",
@@ -1329,7 +1185,6 @@ __all__ = [
     "is_table_or_orm",
     "migrate_data",
     "selectable_to_string",
-    "upsert_items",
     "yield_connection",
    "yield_primary_key_columns",
 ]
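Net effect for callers: `upsert_items`, `UpsertItemsError`, and the `selected_or_all` knob are gone from the public API; upserts now always drop `None` values before building the statement (the old `"selected"` behaviour). A migration sketch:

```python
from collections.abc import Mapping
from typing import Any

from sqlalchemy import Table
from sqlalchemy.ext.asyncio import AsyncEngine

from utilities.sqlalchemy import insert_items


async def upsert_row(
    engine: AsyncEngine, row: Mapping[str, Any], table: Table
) -> None:
    # 0.157.1: await upsert_items(engine, (row, table), selected_or_all="selected")
    # 0.158.0: the same behaviour through the consolidated entry point:
    await insert_items(engine, (row, table), is_upsert=True)
```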
utilities/sqlalchemy_polars.py
CHANGED
@@ -4,7 +4,7 @@ import datetime as dt
 import decimal
 from contextlib import suppress
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any,
+from typing import TYPE_CHECKING, Any, cast, overload, override
 from uuid import UUID
 
 import polars as pl
@@ -44,7 +44,6 @@ from utilities.sqlalchemy import (
     get_chunk_size,
     get_columns,
     insert_items,
-    upsert_items,
 )
 from utilities.text import snake_case
 from utilities.typing import is_subclass_gen
@@ -75,9 +74,9 @@ async def insert_dataframe(
     /,
     *,
     snake: bool = False,
+    is_upsert: bool = False,
     chunk_size_frac: float = CHUNK_SIZE_FRAC,
     assume_tables_exist: bool = False,
-    upsert: Literal["selected", "all"] | None = None,
     timeout_create: TimeDelta | None = None,
     error_create: type[Exception] = TimeoutError,
     timeout_insert: TimeDelta | None = None,
@@ -87,43 +86,25 @@ async def insert_dataframe(
     mapping = _insert_dataframe_map_df_schema_to_table(
         df.schema, table_or_orm, snake=snake
     )
-    items = df.select(*mapping).rename(mapping).
+    items = df.select(*mapping).rename(mapping).rows(named=True)
     if len(items) == 0:
-        if not df.is_empty():
-            raise InsertDataFrameError(df=df)
         if not assume_tables_exist:
             await ensure_tables_created(
                 engine, table_or_orm, timeout=timeout_create, error=error_create
            )
        return
-
-
-
-
-
-
-
-
-
-
-
-            )
-        case "selected" | "all" as selected_or_all:  # skipif-ci-and-not-linux
-            await upsert_items(
-                engine,
-                (items, table_or_orm),
-                snake=snake,
-                chunk_size_frac=chunk_size_frac,
-                selected_or_all=selected_or_all,
-                assume_tables_exist=assume_tables_exist,
-                timeout_create=timeout_create,
-                error_create=error_create,
-                timeout_insert=timeout_insert,
-                error_insert=error_insert,
-            )
-        case never:
-            assert_never(never)
+    await insert_items(
+        engine,
+        (items, table_or_orm),
+        snake=snake,
+        is_upsert=is_upsert,
+        chunk_size_frac=chunk_size_frac,
+        assume_tables_exist=assume_tables_exist,
+        timeout_create=timeout_create,
+        error_create=error_create,
+        timeout_insert=timeout_insert,
+        error_insert=error_insert,
+    )
 
 
 def _insert_dataframe_map_df_schema_to_table(
@@ -207,15 +188,6 @@ def _insert_dataframe_check_df_and_db_types(
     )
 
 
-@dataclass(kw_only=True, slots=True)
-class InsertDataFrameError(Exception):
-    df: DataFrame
-
-    @override
-    def __str__(self) -> str:
-        return f"Non-empty DataFrame must resolve to at least 1 item\n\n{self.df}"
-
-
 @overload
 async def select_to_dataframe(
     sel: Select[Any],
@@ -229,6 +201,7 @@ async def select_to_dataframe(
     in_clauses_chunk_size: int | None = None,
     chunk_size_frac: float = CHUNK_SIZE_FRAC,
     timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
     **kwargs: Any,
 ) -> DataFrame: ...
 @overload
@@ -244,6 +217,7 @@ async def select_to_dataframe(
     in_clauses_chunk_size: int | None = None,
     chunk_size_frac: float = CHUNK_SIZE_FRAC,
     timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
     **kwargs: Any,
 ) -> Iterable[DataFrame]: ...
 @overload
@@ -259,6 +233,7 @@ async def select_to_dataframe(
     in_clauses_chunk_size: int | None = None,
     chunk_size_frac: float = CHUNK_SIZE_FRAC,
     timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
     **kwargs: Any,
 ) -> AsyncIterable[DataFrame]: ...
 @overload
@@ -274,6 +249,7 @@ async def select_to_dataframe(
     in_clauses_chunk_size: int | None = None,
     chunk_size_frac: float = CHUNK_SIZE_FRAC,
     timeout: Delta | None = None,
+    error: MaybeType[BaseException] = TimeoutError,
     **kwargs: Any,
 ) -> DataFrame | Iterable[DataFrame] | AsyncIterable[DataFrame]: ...
 async def select_to_dataframe(
@@ -439,4 +415,4 @@ def _select_to_dataframe_yield_selects_with_in_clauses(
     return (sel.where(in_col.in_(values)) for values in chunked(in_values, chunk_size))
 
 
-__all__ = ["
+__all__ = ["insert_dataframe", "select_to_dataframe"]
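`insert_dataframe` follows suit: the `upsert: Literal["selected", "all"] | None` parameter becomes a plain `is_upsert: bool` forwarded to `insert_items`, and `InsertDataFrameError` disappears (an empty row set is now simply a no-op after table creation). A sketch, assuming the positional order `(df, table_or_orm, engine)`:

```python
import polars as pl
from sqlalchemy import Column, Integer, MetaData, String, Table
from sqlalchemy.ext.asyncio import create_async_engine

from utilities.sqlalchemy_polars import insert_dataframe

metadata = MetaData()
users = Table(  # toy table, for illustration only
    "users",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("name", String),
)


async def load() -> None:
    engine = create_async_engine("sqlite+aiosqlite://")  # in-memory SQLite
    df = pl.DataFrame({"id": [1, 2], "name": ["Alice", "Bob"]})
    # 0.157.1: upsert="selected" or upsert="all"; 0.158.0: one boolean flag.
    await insert_dataframe(df, users, engine, is_upsert=True)
```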
{dycw_utilities-0.157.1.dist-info → dycw_utilities-0.158.0.dist-info}/WHEEL
File without changes
{dycw_utilities-0.157.1.dist-info → dycw_utilities-0.158.0.dist-info}/entry_points.txt
File without changes
{dycw_utilities-0.157.1.dist-info → dycw_utilities-0.158.0.dist-info}/licenses/LICENSE
File without changes