dycw-utilities 0.157.2__py3-none-any.whl → 0.158.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- {dycw_utilities-0.157.2.dist-info → dycw_utilities-0.158.0.dist-info}/METADATA +1 -1
- {dycw_utilities-0.157.2.dist-info → dycw_utilities-0.158.0.dist-info}/RECORD +9 -9
- utilities/__init__.py +1 -1
- utilities/hypothesis.py +22 -0
- utilities/sqlalchemy.py +180 -329
- utilities/sqlalchemy_polars.py +16 -44
- {dycw_utilities-0.157.2.dist-info → dycw_utilities-0.158.0.dist-info}/WHEEL +0 -0
- {dycw_utilities-0.157.2.dist-info → dycw_utilities-0.158.0.dist-info}/entry_points.txt +0 -0
- {dycw_utilities-0.157.2.dist-info → dycw_utilities-0.158.0.dist-info}/licenses/LICENSE +0 -0
{dycw_utilities-0.157.2.dist-info → dycw_utilities-0.158.0.dist-info}/RECORD CHANGED

@@ -1,4 +1,4 @@
-utilities/__init__.py,sha256=
+utilities/__init__.py,sha256=dllUsKScS7RBnWFHME-sLeOoDAWms1u25frX9QMEP-4,60
 utilities/altair.py,sha256=92E2lCdyHY4Zb-vCw6rEJIsWdKipuu-Tu2ab1ufUfAk,9079
 utilities/asyncio.py,sha256=PUedzQ5deqlSECQ33sam9cRzI9TnygHz3FdOqWJWPTM,15288
 utilities/atomicwrites.py,sha256=tPo6r-Rypd9u99u66B9z86YBPpnLrlHtwox_8Z7T34Y,5790
@@ -22,7 +22,7 @@ utilities/getpass.py,sha256=DfN5UgMAtFCqS3dSfFHUfqIMZX2shXvwphOz_6J6f6A,103
 utilities/gzip.py,sha256=fkGP3KdsBfXlstodT4wtlp-PwNyUsogpbDCVVVGdsm4,781
 utilities/hashlib.py,sha256=SVTgtguur0P4elppvzOBbLEjVM3Pea0eWB61yg2ilxo,309
 utilities/http.py,sha256=TsavEfHlRtlLaeV21Z6KZh0qbPw-kvD1zsQdZ7Kep5Q,977
-utilities/hypothesis.py,sha256=
+utilities/hypothesis.py,sha256=OG8mxN6Y3fSEjRg4NiIjsO_JUHJBzh4g8fvpmKRoRU8,44370
 utilities/importlib.py,sha256=mV1xT_O_zt_GnZZ36tl3xOmMaN_3jErDWY54fX39F6Y,429
 utilities/inflect.py,sha256=v7YkOWSu8NAmVghPcf4F3YBZQoJCS47_DLf9jbfWIs0,581
 utilities/ipython.py,sha256=V2oMYHvEKvlNBzxDXdLvKi48oUq2SclRg5xasjaXStw,763
@@ -64,8 +64,8 @@ utilities/sentinel.py,sha256=A_p5jX2K0Yc5XBfoYHyBLqHsEWzE1ByOdDuzzA2pZnE,1434
 utilities/shelve.py,sha256=4OzjQI6kGuUbJciqf535rwnao-_IBv66gsT6tRGiUt0,759
 utilities/slack_sdk.py,sha256=76-DYtcGiUhEvl-voMamc5OjfF7Y7nCq54Bys1arqzw,2233
 utilities/socket.py,sha256=K77vfREvzoVTrpYKo6MZakol0EYu2q1sWJnnZqL0So0,118
-utilities/sqlalchemy.py,sha256=
-utilities/sqlalchemy_polars.py,sha256=
+utilities/sqlalchemy.py,sha256=4rLjecgDe60MnuWxUqZVIQy3dKepNaemoGyrnuCdNZg,36347
+utilities/sqlalchemy_polars.py,sha256=JCGhB37raSR7fqeWV5dTsciRTMVzIdVT9YSqKT0piT0,13370
 utilities/statsmodels.py,sha256=koyiBHvpMcSiBfh99wFUfSggLNx7cuAw3rwyfAhoKpQ,3410
 utilities/string.py,sha256=shmBK87zZwzGyixuNuXCiUbqzfeZ9xlrFwz6JTaRvDk,582
 utilities/tempfile.py,sha256=HxB2BF28CcecDJLQ3Bx2Ej-Pb6RJc6W9ngSpB9CnP4k,2018
@@ -87,8 +87,8 @@ utilities/zoneinfo.py,sha256=FBMcUQ4662Aq8SsuCL1OAhDQiyANmVjtb-C30DRrWoE,1966
 utilities/pytest_plugins/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
 utilities/pytest_plugins/pytest_randomly.py,sha256=B1qYVlExGOxTywq2r1SMi5o7btHLk2PNdY_b1p98dkE,409
 utilities/pytest_plugins/pytest_regressions.py,sha256=9v8kAXDM2ycIXJBimoiF4EgrwbUvxTycFWJiGR_GHhM,1466
-dycw_utilities-0.
-dycw_utilities-0.
-dycw_utilities-0.
-dycw_utilities-0.
-dycw_utilities-0.
+dycw_utilities-0.158.0.dist-info/METADATA,sha256=-lYD9v1af9rVk_ibo4Z2cLp-i4JgpoR7qiHy1svZeDU,1643
+dycw_utilities-0.158.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dycw_utilities-0.158.0.dist-info/entry_points.txt,sha256=BOD_SoDxwsfJYOLxhrSXhHP_T7iw-HXI9f2WVkzYxvQ,135
+dycw_utilities-0.158.0.dist-info/licenses/LICENSE,sha256=gppZp16M6nSVpBbUBrNL6JuYfvKwZiKgV7XoKKsHzqo,1066
+dycw_utilities-0.158.0.dist-info/RECORD,,
utilities/__init__.py CHANGED

utilities/hypothesis.py CHANGED
@@ -897,6 +897,27 @@ def py_datetimes(
 ##
 
 
+def quadruples[T](
+    strategy: SearchStrategy[T],
+    /,
+    *,
+    unique: MaybeSearchStrategy[bool] = False,
+    sorted: MaybeSearchStrategy[bool] = False,  # noqa: A002
+) -> SearchStrategy[tuple[T, T, T, T]]:
+    """Strategy for generating quadruples of elements."""
+    return lists_fixed_length(strategy, 4, unique=unique, sorted=sorted).map(
+        _quadruples_map
+    )
+
+
+def _quadruples_map[T](elements: list[T], /) -> tuple[T, T, T, T]:
+    first, second, third, fourth = elements
+    return first, second, third, fourth
+
+
+##
+
+
 @composite
 def random_states(
     draw: DrawFn, /, *, seed: MaybeSearchStrategy[int | None] = None
@@ -1555,6 +1576,7 @@ __all__ = [
     "paths",
     "plain_date_times",
     "py_datetimes",
+    "quadruples",
     "random_states",
    "sentinels",
     "sets_fixed_length",
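Note: quadruples composes the existing lists_fixed_length strategy with a tuple-converting map. A minimal usage sketch; the test and the integers() element strategy are illustrative, not part of the package:

    from hypothesis import given
    from hypothesis.strategies import integers

    from utilities.hypothesis import quadruples

    @given(quad=quadruples(integers(), unique=True, sorted=True))
    def test_quadruples(quad: tuple[int, int, int, int]) -> None:
        # With unique=True and sorted=True, the four drawn elements are
        # distinct and arrive in ascending order.
        first, second, third, fourth = quad
        assert first < second < third < fourth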
utilities/sqlalchemy.py CHANGED
@@ -12,8 +12,8 @@ from collections.abc import (
 )
 from collections.abc import Set as AbstractSet
 from contextlib import asynccontextmanager
-from dataclasses import dataclass
-from functools import
+from dataclasses import dataclass
+from functools import reduce
 from itertools import chain
 from math import floor
 from operator import ge, le
@@ -37,7 +37,6 @@ from sqlalchemy import (
     Connection,
     Engine,
     Insert,
-    PrimaryKeyConstraint,
     Selectable,
     Table,
     and_,
@@ -49,20 +48,13 @@ from sqlalchemy import (
 from sqlalchemy.dialects.mssql import dialect as mssql_dialect
 from sqlalchemy.dialects.mysql import dialect as mysql_dialect
 from sqlalchemy.dialects.oracle import dialect as oracle_dialect
-from sqlalchemy.dialects.postgresql import Insert as postgresql_Insert
 from sqlalchemy.dialects.postgresql import dialect as postgresql_dialect
 from sqlalchemy.dialects.postgresql import insert as postgresql_insert
 from sqlalchemy.dialects.postgresql.asyncpg import PGDialect_asyncpg
 from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg
-from sqlalchemy.dialects.sqlite import Insert as sqlite_Insert
 from sqlalchemy.dialects.sqlite import dialect as sqlite_dialect
 from sqlalchemy.dialects.sqlite import insert as sqlite_insert
-from sqlalchemy.exc import (
-    ArgumentError,
-    DatabaseError,
-    OperationalError,
-    ProgrammingError,
-)
+from sqlalchemy.exc import ArgumentError, DatabaseError
 from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
 from sqlalchemy.orm import (
     DeclarativeBase,
@@ -129,7 +121,7 @@ def check_connect(engine: Engine, /) -> bool:
     try:
         with engine.connect() as conn:
             return bool(conn.execute(_SELECT).scalar_one())
-    except (gaierror, ConnectionRefusedError,
+    except (gaierror, ConnectionRefusedError, DatabaseError):  # pragma: no cover
         return False
 
 
@@ -144,13 +136,7 @@ async def check_connect_async(
     try:
         async with timeout_td(timeout, error=error), engine.connect() as conn:
             return bool((await conn.execute(_SELECT)).scalar_one())
-    except (
-        gaierror,
-        ConnectionRefusedError,
-        OperationalError,
-        ProgrammingError,
-        TimeoutError,
-    ):
+    except (gaierror, ConnectionRefusedError, DatabaseError, TimeoutError):
         return False
 
 
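Note: the consolidated except clauses rely on SQLAlchemy's exception hierarchy, in which OperationalError and ProgrammingError are both subclasses of DatabaseError, so catching the single parent class covers everything the removed pair caught:

    from sqlalchemy.exc import DatabaseError, OperationalError, ProgrammingError

    # Both exception types listed in the old except clauses derive from
    # DatabaseError, so the new single-class clause is at least as broad.
    assert issubclass(OperationalError, DatabaseError)
    assert issubclass(ProgrammingError, DatabaseError)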
@@ -162,7 +148,7 @@ async def check_engine(
     /,
     *,
     timeout: Delta | None = None,
-    error:
+    error: MaybeType[Exception] = TimeoutError,
     num_tables: int | tuple[int, float] | None = None,
 ) -> None:
     """Check that an engine can connect.
@@ -336,9 +322,8 @@ async def ensure_database_created(super_: URL, database: str, /) -> None:
     async with engine.begin() as conn:
         try:
             _ = await conn.execute(text(f"CREATE DATABASE {database}"))
-        except
-
-            raise
+        except DatabaseError as error:
+            _ensure_tables_maybe_reraise(error, 'database ".*" already exists')
 
 
 async def ensure_database_dropped(super_: URL, database: str, /) -> None:
@@ -356,7 +341,7 @@ async def ensure_tables_created(
     /,
     *tables_or_orms: TableOrORMInstOrClass,
     timeout: Delta | None = None,
-    error:
+    error: MaybeType[Exception] = TimeoutError,
 ) -> None:
     """Ensure a table/set of tables is/are created."""
     tables = set(map(get_table, tables_or_orms))
@@ -385,7 +370,7 @@ async def ensure_tables_dropped(
     engine: AsyncEngine,
     *tables_or_orms: TableOrORMInstOrClass,
     timeout: Delta | None = None,
-    error:
+    error: MaybeType[Exception] = TimeoutError,
 ) -> None:
     """Ensure a table/set of tables is/are dropped."""
     tables = set(map(get_table, tables_or_orms))
@@ -606,12 +591,16 @@ type _InsertItem = (
     | Sequence[_PairOfTupleOrStrMappingAndTable]
     | Sequence[DeclarativeBase]
 )
+type _NormalizedItem = tuple[Table, StrMapping]
+type _InsertPair = tuple[Table, Sequence[StrMapping]]
+type _InsertTriple = tuple[Table, Insert, Sequence[StrMapping] | None]
 
 
 async def insert_items(
     engine: AsyncEngine,
     *items: _InsertItem,
     snake: bool = False,
+    is_upsert: bool = False,
     chunk_size_frac: float = CHUNK_SIZE_FRAC,
     assume_tables_exist: bool = False,
     timeout_create: Delta | None = None,
@@ -641,37 +630,181 @@ async def insert_items(
                                Obj(k1=v21, k2=v22, ...),
                                ...]
     """
-
-
-
-
-
-
-                return insert(table), values
-            case _:
-                return insert(table).values(list(values)), None
-
-    try:
-        prepared = _prepare_insert_or_upsert_items(
-            partial(_normalize_insert_item, snake=snake),
-            engine,
-            build_insert,
-            *items,
-            chunk_size_frac=chunk_size_frac,
-        )
-    except _PrepareInsertOrUpsertItemsError as error:
-        raise InsertItemsError(item=error.item) from None
+    normalized = chain.from_iterable(
+        _insert_items_yield_normalized(i, snake=snake) for i in items
+    )
+    triples = _insert_items_yield_triples(
+        engine, normalized, is_upsert=is_upsert, chunk_size_frac=chunk_size_frac
+    )
     if not assume_tables_exist:
+        triples = list(triples)
+        tables = {table for table, _, _ in triples}
         await ensure_tables_created(
-            engine, *
+            engine, *tables, timeout=timeout_create, error=error_create
         )
-    for ins, parameters in
+    for _, ins, parameters in triples:
         async with yield_connection(
             engine, timeout=timeout_insert, error=error_insert
         ) as conn:
             _ = await conn.execute(ins, parameters=parameters)
 
 
+def _insert_items_yield_normalized(
+    item: _InsertItem, /, *, snake: bool = False
+) -> Iterator[_NormalizedItem]:
+    if _is_pair_of_str_mapping_and_table(item):
+        mapping, table_or_orm = item
+        adjusted = _map_mapping_to_table(mapping, table_or_orm, snake=snake)
+        yield (get_table(table_or_orm), adjusted)
+        return
+    if _is_pair_of_tuple_and_table(item):
+        tuple_, table_or_orm = item
+        mapping = _tuple_to_mapping(tuple_, table_or_orm)
+        yield from _insert_items_yield_normalized((mapping, table_or_orm), snake=snake)
+        return
+    if _is_pair_of_sequence_of_tuple_or_string_mapping_and_table(item):
+        items, table_or_orm = item
+        pairs = [(i, table_or_orm) for i in items]
+        for p in pairs:
+            yield from _insert_items_yield_normalized(p, snake=snake)
+        return
+    if isinstance(item, DeclarativeBase):
+        mapping = _orm_inst_to_dict(item)
+        yield from _insert_items_yield_normalized((mapping, item), snake=snake)
+        return
+    try:
+        _ = iter(item)
+    except TypeError:
+        raise InsertItemsError(item=item) from None
+    if all(map(_is_pair_of_tuple_or_str_mapping_and_table, item)):
+        pairs = cast("Sequence[_PairOfTupleOrStrMappingAndTable]", item)
+        for p in pairs:
+            yield from _insert_items_yield_normalized(p, snake=snake)
+        return
+    if all(map(is_orm, item)):
+        classes = cast("Sequence[DeclarativeBase]", item)
+        for c in classes:
+            yield from _insert_items_yield_normalized(c, snake=snake)
+        return
+    raise InsertItemsError(item=item)
+
+
+def _insert_items_yield_triples(
+    engine: AsyncEngine,
+    items: Iterable[_NormalizedItem],
+    /,
+    *,
+    is_upsert: bool = False,
+    chunk_size_frac: float = CHUNK_SIZE_FRAC,
+) -> Iterable[_InsertTriple]:
+    pairs = _insert_items_yield_chunked_pairs(
+        engine, items, is_upsert=is_upsert, chunk_size_frac=chunk_size_frac
+    )
+    for table, mappings in pairs:
+        match is_upsert, _get_dialect(engine):
+            case False, "oracle":  # pragma: no cover
+                ins = insert(table)
+                parameters = mappings
+            case False, _:
+                ins = insert(table).values(mappings)
+                parameters = None
+            case True, _:
+                ins = _insert_items_build_insert_with_on_conflict_do_update(
+                    engine, table, mappings
+                )
+                parameters = None
+            case never:
+                assert_never(never)
+        yield table, ins, parameters
+
+
+def _insert_items_yield_chunked_pairs(
+    engine: AsyncEngine,
+    items: Iterable[_NormalizedItem],
+    /,
+    *,
+    is_upsert: bool = False,
+    chunk_size_frac: float = CHUNK_SIZE_FRAC,
+) -> Iterable[_InsertPair]:
+    for table, mappings in _insert_items_yield_raw_pairs(items, is_upsert=is_upsert):
+        chunk_size = get_chunk_size(engine, table, chunk_size_frac=chunk_size_frac)
+        for mappings_i in chunked(mappings, chunk_size):
+            yield table, list(mappings_i)
+
+
+def _insert_items_yield_raw_pairs(
+    items: Iterable[_NormalizedItem], /, *, is_upsert: bool = False
+) -> Iterable[_InsertPair]:
+    by_table: defaultdict[Table, list[StrMapping]] = defaultdict(list)
+    for table, mapping in items:
+        by_table[table].append(mapping)
+    for table, mappings in by_table.items():
+        yield from _insert_items_yield_raw_pairs_one(
+            table, mappings, is_upsert=is_upsert
+        )
+
+
+def _insert_items_yield_raw_pairs_one(
+    table: Table, mappings: Iterable[StrMapping], /, *, is_upsert: bool = False
+) -> Iterable[_InsertPair]:
+    merged = _insert_items_yield_merged_mappings(table, mappings)
+    match is_upsert:
+        case True:
+            by_keys: defaultdict[frozenset[str], list[StrMapping]] = defaultdict(list)
+            for mapping in merged:
+                non_null = {k: v for k, v in mapping.items() if v is not None}
+                by_keys[frozenset(non_null)].append(non_null)
+            for mappings_i in by_keys.values():
+                yield table, mappings_i
+        case False:
+            yield table, list(merged)
+        case never:
+            assert_never(never)
+
+
+def _insert_items_yield_merged_mappings(
+    table: Table, mappings: Iterable[StrMapping], /
+) -> Iterable[StrMapping]:
+    columns = list(yield_primary_key_columns(table))
+    col_names = [c.name for c in columns]
+    cols_auto = {c.name for c in columns if c.autoincrement in {True, "auto"}}
+    cols_non_auto = set(col_names) - cols_auto
+    by_key: defaultdict[tuple[Hashable, ...], list[StrMapping]] = defaultdict(list)
+    for mapping in mappings:
+        check_subset(cols_non_auto, mapping)
+        has_all_auto = set(cols_auto).issubset(mapping)
+        if has_all_auto:
+            pkey = tuple(mapping[k] for k in col_names)
+            rest: StrMapping = {k: v for k, v in mapping.items() if k not in col_names}
+            by_key[pkey].append(rest)
+        else:
+            yield mapping
+    for k, v in by_key.items():
+        head = dict(zip(col_names, k, strict=True))
+        yield merge_str_mappings(head, *v)
+
+
+def _insert_items_build_insert_with_on_conflict_do_update(
+    engine: AsyncEngine, table: Table, mappings: Iterable[StrMapping], /
+) -> Insert:
+    primary_key = cast("Any", table.primary_key)
+    mappings = list(mappings)
+    columns = merge_sets(*mappings)
+    match _get_dialect(engine):
+        case "postgresql":  # skipif-ci-and-not-linux
+            ins = postgresql_insert(table).values(mappings)
+            set_ = {c: getattr(ins.excluded, c) for c in columns}
+            return ins.on_conflict_do_update(constraint=primary_key, set_=set_)
+        case "sqlite":
+            ins = sqlite_insert(table).values(mappings)
+            set_ = {c: getattr(ins.excluded, c) for c in columns}
+            return ins.on_conflict_do_update(index_elements=primary_key, set_=set_)
+        case "mssql" | "mysql" | "oracle" as dialect:  # pragma: no cover
+            raise NotImplementedError(dialect)
+        case never:
+            assert_never(never)
+
+
 @dataclass(kw_only=True, slots=True)
 class InsertItemsError(Exception):
     item: _InsertItem
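Note: the new upsert branch builds a dialect-specific INSERT ... ON CONFLICT DO UPDATE statement. A standalone sketch of the SQLite flavour against a hypothetical table; the table and column names are illustrative, not from the package:

    from sqlalchemy import Column, Integer, MetaData, String, Table
    from sqlalchemy.dialects import sqlite
    from sqlalchemy.dialects.sqlite import insert as sqlite_insert

    metadata = MetaData()
    example = Table(  # hypothetical table, for illustration only
        "example",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("value", String),
    )

    ins = sqlite_insert(example).values([{"id": 1, "value": "a"}])
    upsert = ins.on_conflict_do_update(
        index_elements=["id"], set_={"value": ins.excluded.value}
    )
    # Compiles to roughly:
    # INSERT INTO example (id, value) VALUES (?, ?)
    # ON CONFLICT (id) DO UPDATE SET value = excluded.value
    print(upsert.compile(dialect=sqlite.dialect()))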
@@ -742,80 +875,6 @@ async def migrate_data(
 ##
 
 
-def _normalize_insert_item(
-    item: _InsertItem, /, *, snake: bool = False
-) -> list[_NormalizedItem]:
-    """Normalize an insertion item."""
-    if _is_pair_of_str_mapping_and_table(item):
-        mapping, table_or_orm = item
-        adjusted = _map_mapping_to_table(mapping, table_or_orm, snake=snake)
-        normalized = _NormalizedItem(mapping=adjusted, table=get_table(table_or_orm))
-        return [normalized]
-    if _is_pair_of_tuple_and_table(item):
-        tuple_, table_or_orm = item
-        mapping = _tuple_to_mapping(tuple_, table_or_orm)
-        return _normalize_insert_item((mapping, table_or_orm), snake=snake)
-    if _is_pair_of_sequence_of_tuple_or_string_mapping_and_table(item):
-        items, table_or_orm = item
-        pairs = [(i, table_or_orm) for i in items]
-        normalized = (_normalize_insert_item(p, snake=snake) for p in pairs)
-        return list(chain.from_iterable(normalized))
-    if isinstance(item, DeclarativeBase):
-        mapping = _orm_inst_to_dict(item)
-        return _normalize_insert_item((mapping, item), snake=snake)
-    try:
-        _ = iter(item)
-    except TypeError:
-        raise _NormalizeInsertItemError(item=item) from None
-    if all(map(_is_pair_of_tuple_or_str_mapping_and_table, item)):
-        seq = cast("Sequence[_PairOfTupleOrStrMappingAndTable]", item)
-        normalized = (_normalize_insert_item(p, snake=snake) for p in seq)
-        return list(chain.from_iterable(normalized))
-    if all(map(is_orm, item)):
-        seq = cast("Sequence[DeclarativeBase]", item)
-        normalized = (_normalize_insert_item(p, snake=snake) for p in seq)
-        return list(chain.from_iterable(normalized))
-    raise _NormalizeInsertItemError(item=item)
-
-
-@dataclass(kw_only=True, slots=True)
-class _NormalizeInsertItemError(Exception):
-    item: _InsertItem
-
-    @override
-    def __str__(self) -> str:
-        return f"Item must be valid; got {self.item}"
-
-
-@dataclass(kw_only=True, slots=True)
-class _NormalizedItem:
-    mapping: StrMapping
-    table: Table
-
-
-def _normalize_upsert_item(
-    item: _InsertItem,
-    /,
-    *,
-    snake: bool = False,
-    selected_or_all: Literal["selected", "all"] = "selected",
-) -> Iterator[_NormalizedItem]:
-    """Normalize an upsert item."""
-    normalized = _normalize_insert_item(item, snake=snake)
-    match selected_or_all:
-        case "selected":
-            for norm in normalized:
-                values = {k: v for k, v in norm.mapping.items() if v is not None}
-                yield _NormalizedItem(mapping=values, table=norm.table)
-        case "all":
-            yield from normalized
-        case never:
-            assert_never(never)
-
-
-##
-
-
 def selectable_to_string(
     selectable: Selectable[Any], engine_or_conn: EngineOrConnectionOrAsync, /
 ) -> str:
@@ -840,134 +899,6 @@ class TablenameMixin:
 ##
 
 
-type _SelectedOrAll = Literal["selected", "all"]
-
-
-async def upsert_items(
-    engine: AsyncEngine,
-    /,
-    *items: _InsertItem,
-    snake: bool = False,
-    selected_or_all: _SelectedOrAll = "selected",
-    chunk_size_frac: float = CHUNK_SIZE_FRAC,
-    assume_tables_exist: bool = False,
-    timeout_create: Delta | None = None,
-    error_create: type[Exception] = TimeoutError,
-    timeout_insert: Delta | None = None,
-    error_insert: type[Exception] = TimeoutError,
-) -> None:
-    """Upsert a set of items into a database.
-
-    These can be one of the following:
-     - pair of dict & table/class: {k1=v1, k2=v2, ...), table_cls
-     - pair of list of dicts & table/class: [{k1=v11, k2=v12, ...},
-                                             {k1=v21, k2=v22, ...},
-                                             ...], table/class
-     - list of pairs of dict & table/class: [({k1=v11, k2=v12, ...}, table_cls1),
-                                             ({k1=v21, k2=v22, ...}, table_cls2),
-                                             ...]
-     - mapped class: Obj(k1=v1, k2=v2, ...)
-     - list of mapped classes: [Obj(k1=v11, k2=v12, ...),
-                                Obj(k1=v21, k2=v22, ...),
-                                ...]
-    """
-
-    def build_insert(
-        table: Table, values: Iterable[StrMapping], /
-    ) -> tuple[Insert, None]:
-        ups = _upsert_items_build(
-            engine, table, values, selected_or_all=selected_or_all
-        )
-        return ups, None
-
-    try:
-        prepared = _prepare_insert_or_upsert_items(
-            partial(
-                _normalize_upsert_item, snake=snake, selected_or_all=selected_or_all
-            ),
-            engine,
-            build_insert,
-            *items,
-            chunk_size_frac=chunk_size_frac,
-        )
-    except _PrepareInsertOrUpsertItemsError as error:
-        raise UpsertItemsError(item=error.item) from None
-    if not assume_tables_exist:
-        await ensure_tables_created(
-            engine, *prepared.tables, timeout=timeout_create, error=error_create
-        )
-    for ups, _ in prepared.yield_pairs():
-        async with yield_connection(
-            engine, timeout=timeout_insert, error=error_insert
-        ) as conn:
-            _ = await conn.execute(ups)
-
-
-def _upsert_items_build(
-    engine: AsyncEngine,
-    table: Table,
-    values: Iterable[StrMapping],
-    /,
-    *,
-    selected_or_all: Literal["selected", "all"] = "selected",
-) -> Insert:
-    values = list(values)
-    keys = merge_sets(*values)
-    dict_nones = dict.fromkeys(keys)
-    values = [{**dict_nones, **v} for v in values]
-    match _get_dialect(engine):
-        case "postgresql":  # skipif-ci-and-not-linux
-            insert = postgresql_insert
-        case "sqlite":
-            insert = sqlite_insert
-        case "mssql" | "mysql" | "oracle" as dialect:  # pragma: no cover
-            raise NotImplementedError(dialect)
-        case never:
-            assert_never(never)
-    ins = insert(table).values(values)
-    primary_key = cast("Any", table.primary_key)
-    return _upsert_items_apply_on_conflict_do_update(
-        values, ins, primary_key, selected_or_all=selected_or_all
-    )
-
-
-def _upsert_items_apply_on_conflict_do_update(
-    values: Iterable[StrMapping],
-    insert: postgresql_Insert | sqlite_Insert,
-    primary_key: PrimaryKeyConstraint,
-    /,
-    *,
-    selected_or_all: Literal["selected", "all"] = "selected",
-) -> Insert:
-    match selected_or_all:
-        case "selected":
-            columns = merge_sets(*values)
-        case "all":
-            columns = {c.name for c in insert.excluded}
-        case never:
-            assert_never(never)
-    set_ = {c: getattr(insert.excluded, c) for c in columns}
-    match insert:
-        case postgresql_Insert():  # skipif-ci
-            return insert.on_conflict_do_update(constraint=primary_key, set_=set_)
-        case sqlite_Insert():
-            return insert.on_conflict_do_update(index_elements=primary_key, set_=set_)
-        case never:
-            assert_never(never)
-
-
-@dataclass(kw_only=True, slots=True)
-class UpsertItemsError(Exception):
-    item: _InsertItem
-
-    @override
-    def __str__(self) -> str:
-        return f"Item must be valid; got {self.item}"
-
-
-##
-
-
 @asynccontextmanager
 async def yield_connection(
     engine: AsyncEngine,
@@ -1211,84 +1142,6 @@ def _orm_inst_to_dict_predicate(
 ##
 
 
-@dataclass(kw_only=True, slots=True)
-class _PrepareInsertOrUpsertItems:
-    mapping: dict[Table, list[StrMapping]] = field(default_factory=dict)
-    yield_pairs: Callable[[], Iterator[tuple[Insert, Any]]]
-
-    @property
-    def tables(self) -> Sequence[Table]:
-        return list(self.mapping)
-
-
-def _prepare_insert_or_upsert_items(
-    normalize_item: Callable[[_InsertItem], Iterable[_NormalizedItem]],
-    engine: AsyncEngine,
-    build_insert: Callable[[Table, Iterable[StrMapping]], tuple[Insert, Any]],
-    /,
-    *items: Any,
-    chunk_size_frac: float = CHUNK_SIZE_FRAC,
-) -> _PrepareInsertOrUpsertItems:
-    """Prepare a set of insert/upsert items."""
-    mapping: defaultdict[Table, list[StrMapping]] = defaultdict(list)
-    lengths: set[int] = set()
-    try:
-        for item in items:
-            for normed in normalize_item(item):
-                mapping[normed.table].append(normed.mapping)
-                lengths.add(len(normed.mapping))
-    except _NormalizeInsertItemError as error:
-        raise _PrepareInsertOrUpsertItemsError(item=error.item) from None
-    merged: dict[Table, list[StrMapping]] = {
-        table: _prepare_insert_or_upsert_items_merge_items(table, values)
-        for table, values in mapping.items()
-    }
-
-    def yield_pairs() -> Iterator[tuple[Insert, None]]:
-        for table, values in merged.items():
-            chunk_size = get_chunk_size(engine, table, chunk_size_frac=chunk_size_frac)
-            for chunk in chunked(values, chunk_size):
-                yield build_insert(table, chunk)
-
-    return _PrepareInsertOrUpsertItems(mapping=mapping, yield_pairs=yield_pairs)
-
-
-@dataclass(kw_only=True, slots=True)
-class _PrepareInsertOrUpsertItemsError(Exception):
-    item: Any
-
-    @override
-    def __str__(self) -> str:
-        return f"Item must be valid; got {self.item}"
-
-
-def _prepare_insert_or_upsert_items_merge_items(
-    table: Table, items: Iterable[StrMapping], /
-) -> list[StrMapping]:
-    columns = list(yield_primary_key_columns(table))
-    col_names = [c.name for c in columns]
-    cols_auto = {c.name for c in columns if c.autoincrement in {True, "auto"}}
-    cols_non_auto = set(col_names) - cols_auto
-    mapping: defaultdict[tuple[Hashable, ...], list[StrMapping]] = defaultdict(list)
-    unchanged: list[StrMapping] = []
-    for item in items:
-        check_subset(cols_non_auto, item)
-        has_all_auto = set(cols_auto).issubset(item)
-        if has_all_auto:
-            pkey = tuple(item[k] for k in col_names)
-            rest: StrMapping = {k: v for k, v in item.items() if k not in col_names}
-            mapping[pkey].append(rest)
-        else:
-            unchanged.append(item)
-    merged = {k: merge_str_mappings(*v) for k, v in mapping.items()}
-    return [
-        dict(zip(col_names, k, strict=True)) | dict(v) for k, v in merged.items()
-    ] + unchanged
-
-
-##
-
-
 def _tuple_to_mapping(
     values: tuple[Any, ...], table_or_orm: TableOrORMInstOrClass, /
 ) -> dict[str, Any]:
@@ -1307,7 +1160,6 @@ __all__ = [
     "GetTableError",
     "InsertItemsError",
     "TablenameMixin",
-    "UpsertItemsError",
     "check_connect",
     "check_connect_async",
     "check_engine",
@@ -1333,7 +1185,6 @@ __all__ = [
     "is_table_or_orm",
     "migrate_data",
     "selectable_to_string",
-    "upsert_items",
     "yield_connection",
     "yield_primary_key_columns",
 ]
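Note: with upsert_items removed, upserts now go through insert_items via the new is_upsert flag. A minimal sketch under assumed names; the table, column names, and database URL are illustrative, not from the package:

    from sqlalchemy import Column, Integer, MetaData, String, Table
    from sqlalchemy.ext.asyncio import create_async_engine

    from utilities.sqlalchemy import insert_items

    metadata = MetaData()
    items_table = Table(  # hypothetical table, for illustration only
        "items",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("value", String),
    )

    async def sync_rows() -> None:
        engine = create_async_engine("sqlite+aiosqlite://")
        # Plain insert, as before.
        await insert_items(engine, ({"id": 1, "value": "old"}, items_table))
        # Upsert: the row with the conflicting primary key is updated in place.
        await insert_items(
            engine, ({"id": 1, "value": "new"}, items_table), is_upsert=True
        )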
utilities/sqlalchemy_polars.py CHANGED
@@ -4,7 +4,7 @@ import datetime as dt
 import decimal
 from contextlib import suppress
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any,
+from typing import TYPE_CHECKING, Any, cast, overload, override
 from uuid import UUID
 
 import polars as pl
@@ -44,7 +44,6 @@ from utilities.sqlalchemy import (
     get_chunk_size,
     get_columns,
     insert_items,
-    upsert_items,
 )
 from utilities.text import snake_case
 from utilities.typing import is_subclass_gen
@@ -75,9 +74,9 @@ async def insert_dataframe(
     /,
     *,
     snake: bool = False,
+    is_upsert: bool = False,
     chunk_size_frac: float = CHUNK_SIZE_FRAC,
     assume_tables_exist: bool = False,
-    upsert: Literal["selected", "all"] | None = None,
     timeout_create: TimeDelta | None = None,
     error_create: type[Exception] = TimeoutError,
     timeout_insert: TimeDelta | None = None,
@@ -87,43 +86,25 @@ async def insert_dataframe(
     mapping = _insert_dataframe_map_df_schema_to_table(
         df.schema, table_or_orm, snake=snake
     )
-    items = df.select(*mapping).rename(mapping).
+    items = df.select(*mapping).rename(mapping).rows(named=True)
     if len(items) == 0:
-        if not df.is_empty():
-            raise InsertDataFrameError(df=df)
         if not assume_tables_exist:
             await ensure_tables_created(
                 engine, table_or_orm, timeout=timeout_create, error=error_create
             )
         return
-
-
-
-
-
-
-
-
-
-
-
-
-            )
-        case "selected" | "all" as selected_or_all:  # skipif-ci-and-not-linux
-            await upsert_items(
-                engine,
-                (items, table_or_orm),
-                snake=snake,
-                chunk_size_frac=chunk_size_frac,
-                selected_or_all=selected_or_all,
-                assume_tables_exist=assume_tables_exist,
-                timeout_create=timeout_create,
-                error_create=error_create,
-                timeout_insert=timeout_insert,
-                error_insert=error_insert,
-            )
-        case never:
-            assert_never(never)
+    await insert_items(
+        engine,
+        (items, table_or_orm),
+        snake=snake,
+        is_upsert=is_upsert,
+        chunk_size_frac=chunk_size_frac,
+        assume_tables_exist=assume_tables_exist,
+        timeout_create=timeout_create,
+        error_create=error_create,
+        timeout_insert=timeout_insert,
+        error_insert=error_insert,
+    )
 
 
 def _insert_dataframe_map_df_schema_to_table(
@@ -207,15 +188,6 @@ def _insert_dataframe_check_df_and_db_types(
     )
 
 
-@dataclass(kw_only=True, slots=True)
-class InsertDataFrameError(Exception):
-    df: DataFrame
-
-    @override
-    def __str__(self) -> str:
-        return f"Non-empty DataFrame must resolve to at least 1 item\n\n{self.df}"
-
-
 @overload
 async def select_to_dataframe(
     sel: Select[Any],
@@ -443,4 +415,4 @@ def _select_to_dataframe_yield_selects_with_in_clauses(
     return (sel.where(in_col.in_(values)) for values in chunked(in_values, chunk_size))
 
 
-__all__ = ["
+__all__ = ["insert_dataframe", "select_to_dataframe"]
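Note: insert_dataframe now simply forwards to insert_items, so the same flag enables DataFrame upserts. A sketch assuming the positional order df, table_or_orm, engine seen in the function body; the frame contents and names are illustrative:

    import polars as pl

    from utilities.sqlalchemy_polars import insert_dataframe

    async def load(engine, items_table) -> None:
        df = pl.DataFrame({"id": [1, 2], "value": ["a", "b"]})
        # Rows whose primary keys already exist are updated in place.
        await insert_dataframe(df, items_table, engine, is_upsert=True)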
{dycw_utilities-0.157.2.dist-info → dycw_utilities-0.158.0.dist-info}/WHEEL: file without changes
{dycw_utilities-0.157.2.dist-info → dycw_utilities-0.158.0.dist-info}/entry_points.txt: file without changes
{dycw_utilities-0.157.2.dist-info → dycw_utilities-0.158.0.dist-info}/licenses/LICENSE: file without changes