dycw-utilities 0.157.2__py3-none-any.whl → 0.158.1__py3-none-any.whl

This diff compares the contents of publicly available package versions as published to their public registries. It is provided for informational purposes only.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dycw-utilities
- Version: 0.157.2
+ Version: 0.158.1
  Author-email: Derek Wan <d.wan@icloud.com>
  License-File: LICENSE
  Requires-Python: >=3.12
@@ -1,4 +1,4 @@
- utilities/__init__.py,sha256=4r65bgykd0Tf_s-OhMo8cbKy-69wDBJKBbpSfBl7DX0,60
+ utilities/__init__.py,sha256=SivYUEYUR51QHFgXpC5YqYC_WiHCwRe1sU5wvBGCfCY,60
  utilities/altair.py,sha256=92E2lCdyHY4Zb-vCw6rEJIsWdKipuu-Tu2ab1ufUfAk,9079
  utilities/asyncio.py,sha256=PUedzQ5deqlSECQ33sam9cRzI9TnygHz3FdOqWJWPTM,15288
  utilities/atomicwrites.py,sha256=tPo6r-Rypd9u99u66B9z86YBPpnLrlHtwox_8Z7T34Y,5790
@@ -10,7 +10,7 @@ utilities/contextlib.py,sha256=m2D5bwvtCZLJcJ3IwVqyErYODuwJ1gLrT2UfATAQl-w,7435
  utilities/contextvars.py,sha256=J8OhC7jqozAGYOCe2KUWysbPXNGe5JYz3HfaY_mIs08,883
  utilities/cryptography.py,sha256=5PFrzsNUGHay91dFgYnDKwYprXxahrBqztmUqViRzBk,956
  utilities/cvxpy.py,sha256=Rv1-fD-XYerosCavRF8Pohop2DBkU3AlFaGTfD8AEAA,13776
- utilities/dataclasses.py,sha256=MXrvIPSZHlpV4msRdVVDRZZo7MC3gX5C9jDUSoNOdpE,32478
+ utilities/dataclasses.py,sha256=wGfQtopYZCwvUxNxLOlVNFlhK8Q0_aQbKJ0rWfT-fEc,32482
  utilities/enum.py,sha256=5l6pwZD1cjSlVW4ss-zBPspWvrbrYrdtJWcg6f5_J5w,5781
  utilities/errors.py,sha256=mFlDGSM0LI1jZ1pbqwLAH3ttLZ2JVIxyZLojw8tGVZU,1479
  utilities/eventkit.py,sha256=ddoleSwW9zdc2tjX5Ge0pMKtYwV_JMxhHYOxnWX2AGM,12609
@@ -22,7 +22,7 @@ utilities/getpass.py,sha256=DfN5UgMAtFCqS3dSfFHUfqIMZX2shXvwphOz_6J6f6A,103
  utilities/gzip.py,sha256=fkGP3KdsBfXlstodT4wtlp-PwNyUsogpbDCVVVGdsm4,781
  utilities/hashlib.py,sha256=SVTgtguur0P4elppvzOBbLEjVM3Pea0eWB61yg2ilxo,309
  utilities/http.py,sha256=TsavEfHlRtlLaeV21Z6KZh0qbPw-kvD1zsQdZ7Kep5Q,977
- utilities/hypothesis.py,sha256=qqV0O2ynV73tT-bAPXxlfWdR1X0iZgrjn7sK-eJOqBM,43812
+ utilities/hypothesis.py,sha256=OG8mxN6Y3fSEjRg4NiIjsO_JUHJBzh4g8fvpmKRoRU8,44370
  utilities/importlib.py,sha256=mV1xT_O_zt_GnZZ36tl3xOmMaN_3jErDWY54fX39F6Y,429
  utilities/inflect.py,sha256=v7YkOWSu8NAmVghPcf4F3YBZQoJCS47_DLf9jbfWIs0,581
  utilities/ipython.py,sha256=V2oMYHvEKvlNBzxDXdLvKi48oUq2SclRg5xasjaXStw,763
@@ -64,8 +64,8 @@ utilities/sentinel.py,sha256=A_p5jX2K0Yc5XBfoYHyBLqHsEWzE1ByOdDuzzA2pZnE,1434
  utilities/shelve.py,sha256=4OzjQI6kGuUbJciqf535rwnao-_IBv66gsT6tRGiUt0,759
  utilities/slack_sdk.py,sha256=76-DYtcGiUhEvl-voMamc5OjfF7Y7nCq54Bys1arqzw,2233
  utilities/socket.py,sha256=K77vfREvzoVTrpYKo6MZakol0EYu2q1sWJnnZqL0So0,118
- utilities/sqlalchemy.py,sha256=S3r_sc04wxpP9gCsSMsPZ8GlC_edATlwZd_WLa1Bl14,40619
- utilities/sqlalchemy_polars.py,sha256=uaqUrFQQhHjhQWa1t03aNtvqeO8MsUZ9zchCQNF1Pq8,14454
+ utilities/sqlalchemy.py,sha256=qBB6N2wVjplvI4xFgbzzNmsW5zLvnXszmxyV9AsrmbA,36367
+ utilities/sqlalchemy_polars.py,sha256=JCGhB37raSR7fqeWV5dTsciRTMVzIdVT9YSqKT0piT0,13370
  utilities/statsmodels.py,sha256=koyiBHvpMcSiBfh99wFUfSggLNx7cuAw3rwyfAhoKpQ,3410
  utilities/string.py,sha256=shmBK87zZwzGyixuNuXCiUbqzfeZ9xlrFwz6JTaRvDk,582
  utilities/tempfile.py,sha256=HxB2BF28CcecDJLQ3Bx2Ej-Pb6RJc6W9ngSpB9CnP4k,2018
@@ -87,8 +87,8 @@ utilities/zoneinfo.py,sha256=FBMcUQ4662Aq8SsuCL1OAhDQiyANmVjtb-C30DRrWoE,1966
  utilities/pytest_plugins/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
  utilities/pytest_plugins/pytest_randomly.py,sha256=B1qYVlExGOxTywq2r1SMi5o7btHLk2PNdY_b1p98dkE,409
  utilities/pytest_plugins/pytest_regressions.py,sha256=9v8kAXDM2ycIXJBimoiF4EgrwbUvxTycFWJiGR_GHhM,1466
- dycw_utilities-0.157.2.dist-info/METADATA,sha256=H3bdmvntWyzOrtEtLSw_OxEMIe1vcs21otdOrEst-tA,1643
- dycw_utilities-0.157.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- dycw_utilities-0.157.2.dist-info/entry_points.txt,sha256=BOD_SoDxwsfJYOLxhrSXhHP_T7iw-HXI9f2WVkzYxvQ,135
- dycw_utilities-0.157.2.dist-info/licenses/LICENSE,sha256=gppZp16M6nSVpBbUBrNL6JuYfvKwZiKgV7XoKKsHzqo,1066
- dycw_utilities-0.157.2.dist-info/RECORD,,
+ dycw_utilities-0.158.1.dist-info/METADATA,sha256=G1L-4vHRNDR2xpzQMCKJ3uLTRd1nCtaEtaboy9jMzio,1643
+ dycw_utilities-0.158.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ dycw_utilities-0.158.1.dist-info/entry_points.txt,sha256=BOD_SoDxwsfJYOLxhrSXhHP_T7iw-HXI9f2WVkzYxvQ,135
+ dycw_utilities-0.158.1.dist-info/licenses/LICENSE,sha256=gppZp16M6nSVpBbUBrNL6JuYfvKwZiKgV7XoKKsHzqo,1066
+ dycw_utilities-0.158.1.dist-info/RECORD,,
utilities/__init__.py CHANGED
@@ -1,3 +1,3 @@
  from __future__ import annotations

- __version__ = "0.157.2"
+ __version__ = "0.158.1"
utilities/dataclasses.py CHANGED
@@ -34,7 +34,7 @@ from utilities.text import (
      _SplitKeyValuePairsSplitError,
      split_key_value_pairs,
  )
- from utilities.types import SupportsLT
+ from utilities.types import MaybeType, SupportsLT
  from utilities.typing import get_type_hints

  if TYPE_CHECKING:
@@ -830,7 +830,7 @@ def yield_fields(
      warn_name_errors: bool = False,
  ) -> Iterator[_YieldFieldsClass[Any]]: ...
  def yield_fields(
-     obj: Dataclass | type[Dataclass],
+     obj: MaybeType[Dataclass],
      /,
      *,
      globalns: StrMapping | None = None,
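
Note on the signature change: `yield_fields` now spells the instance-or-class union as `MaybeType[Dataclass]`. A minimal sketch of the pattern, assuming `MaybeType[T]` is the `utilities.types` alias for `T | type[T]` (the alias definition below is an assumption, not copied from the package):

    # Sketch only; MaybeType's definition is assumed to be `T | type[T]`.
    from dataclasses import dataclass, fields

    type MaybeType[T] = T | type[T]  # assumed alias, mirroring utilities.types

    @dataclass
    class Point:
        x: int = 0
        y: int = 0

    def field_names(obj: MaybeType[Point], /) -> list[str]:
        # dataclasses.fields accepts an instance or the class itself,
        # which is exactly the flexibility MaybeType encodes.
        return [f.name for f in fields(obj)]

    assert field_names(Point()) == field_names(Point) == ["x", "y"]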
utilities/hypothesis.py CHANGED
@@ -897,6 +897,27 @@ def py_datetimes(
  ##


+ def quadruples[T](
+     strategy: SearchStrategy[T],
+     /,
+     *,
+     unique: MaybeSearchStrategy[bool] = False,
+     sorted: MaybeSearchStrategy[bool] = False,  # noqa: A002
+ ) -> SearchStrategy[tuple[T, T, T, T]]:
+     """Strategy for generating quadruples of elements."""
+     return lists_fixed_length(strategy, 4, unique=unique, sorted=sorted).map(
+         _quadruples_map
+     )
+
+
+ def _quadruples_map[T](elements: list[T], /) -> tuple[T, T, T, T]:
+     first, second, third, fourth = elements
+     return first, second, third, fourth
+
+
+ ##
+
+
  @composite
  def random_states(
      draw: DrawFn, /, *, seed: MaybeSearchStrategy[int | None] = None
@@ -1555,6 +1576,7 @@ __all__ = [
      "paths",
      "plain_date_times",
      "py_datetimes",
+     "quadruples",
      "random_states",
      "sentinels",
      "sets_fixed_length",
utilities/sqlalchemy.py CHANGED
@@ -12,8 +12,8 @@ from collections.abc import (
  )
  from collections.abc import Set as AbstractSet
  from contextlib import asynccontextmanager
- from dataclasses import dataclass, field
- from functools import partial, reduce
+ from dataclasses import dataclass
+ from functools import reduce
  from itertools import chain
  from math import floor
  from operator import ge, le
@@ -37,7 +37,6 @@ from sqlalchemy import (
      Connection,
      Engine,
      Insert,
-     PrimaryKeyConstraint,
      Selectable,
      Table,
      and_,
@@ -49,20 +48,13 @@ from sqlalchemy import (
  from sqlalchemy.dialects.mssql import dialect as mssql_dialect
  from sqlalchemy.dialects.mysql import dialect as mysql_dialect
  from sqlalchemy.dialects.oracle import dialect as oracle_dialect
- from sqlalchemy.dialects.postgresql import Insert as postgresql_Insert
  from sqlalchemy.dialects.postgresql import dialect as postgresql_dialect
  from sqlalchemy.dialects.postgresql import insert as postgresql_insert
  from sqlalchemy.dialects.postgresql.asyncpg import PGDialect_asyncpg
  from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg
- from sqlalchemy.dialects.sqlite import Insert as sqlite_Insert
  from sqlalchemy.dialects.sqlite import dialect as sqlite_dialect
  from sqlalchemy.dialects.sqlite import insert as sqlite_insert
- from sqlalchemy.exc import (
-     ArgumentError,
-     DatabaseError,
-     OperationalError,
-     ProgrammingError,
- )
+ from sqlalchemy.exc import ArgumentError, DatabaseError
  from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
  from sqlalchemy.orm import (
      DeclarativeBase,
@@ -129,7 +121,7 @@ def check_connect(engine: Engine, /) -> bool:
      try:
          with engine.connect() as conn:
              return bool(conn.execute(_SELECT).scalar_one())
-     except (gaierror, ConnectionRefusedError, OperationalError, ProgrammingError):
+     except (gaierror, ConnectionRefusedError, DatabaseError):  # pragma: no cover
          return False


@@ -144,13 +136,7 @@ async def check_connect_async(
      try:
          async with timeout_td(timeout, error=error), engine.connect() as conn:
              return bool((await conn.execute(_SELECT)).scalar_one())
-     except (
-         gaierror,
-         ConnectionRefusedError,
-         OperationalError,
-         ProgrammingError,
-         TimeoutError,
-     ):
+     except (gaierror, ConnectionRefusedError, DatabaseError, TimeoutError):
          return False


@@ -162,7 +148,7 @@ async def check_engine(
      /,
      *,
      timeout: Delta | None = None,
-     error: type[Exception] = TimeoutError,
+     error: MaybeType[Exception] = TimeoutError,
      num_tables: int | tuple[int, float] | None = None,
  ) -> None:
      """Check that an engine can connect.
@@ -336,9 +322,8 @@ async def ensure_database_created(super_: URL, database: str, /) -> None:
      async with engine.begin() as conn:
          try:
              _ = await conn.execute(text(f"CREATE DATABASE {database}"))
-         except (OperationalError, ProgrammingError) as error:
-             if not search('database ".*" already exists', ensure_str(one(error.args))):
-                 raise
+         except DatabaseError as error:
+             _ensure_tables_maybe_reraise(error, 'database ".*" already exists')


  async def ensure_database_dropped(super_: URL, database: str, /) -> None:
@@ -356,7 +341,7 @@ async def ensure_tables_created(
      /,
      *tables_or_orms: TableOrORMInstOrClass,
      timeout: Delta | None = None,
-     error: type[Exception] = TimeoutError,
+     error: MaybeType[Exception] = TimeoutError,
  ) -> None:
      """Ensure a table/set of tables is/are created."""
      tables = set(map(get_table, tables_or_orms))
@@ -385,7 +370,7 @@ async def ensure_tables_dropped(
      engine: AsyncEngine,
      *tables_or_orms: TableOrORMInstOrClass,
      timeout: Delta | None = None,
-     error: type[Exception] = TimeoutError,
+     error: MaybeType[Exception] = TimeoutError,
  ) -> None:
      """Ensure a table/set of tables is/are dropped."""
      tables = set(map(get_table, tables_or_orms))
@@ -606,18 +591,22 @@ type _InsertItem = (
      | Sequence[_PairOfTupleOrStrMappingAndTable]
      | Sequence[DeclarativeBase]
  )
+ type _NormalizedItem = tuple[Table, StrMapping]
+ type _InsertPair = tuple[Table, Sequence[StrMapping]]
+ type _InsertTriple = tuple[Table, Insert, Sequence[StrMapping] | None]


  async def insert_items(
      engine: AsyncEngine,
      *items: _InsertItem,
      snake: bool = False,
+     is_upsert: bool = False,
      chunk_size_frac: float = CHUNK_SIZE_FRAC,
      assume_tables_exist: bool = False,
      timeout_create: Delta | None = None,
-     error_create: type[Exception] = TimeoutError,
+     error_create: MaybeType[Exception] = TimeoutError,
      timeout_insert: Delta | None = None,
-     error_insert: type[Exception] = TimeoutError,
+     error_insert: MaybeType[Exception] = TimeoutError,
  ) -> None:
      """Insert a set of items into a database.

@@ -641,37 +630,181 @@
                                 Obj(k1=v21, k2=v22, ...),
                                 ...]
      """
-
-     def build_insert(
-         table: Table, values: Iterable[StrMapping], /
-     ) -> tuple[Insert, Any]:
-         match _get_dialect(engine):
-             case "oracle":  # pragma: no cover
-                 return insert(table), values
-             case _:
-                 return insert(table).values(list(values)), None
-
-     try:
-         prepared = _prepare_insert_or_upsert_items(
-             partial(_normalize_insert_item, snake=snake),
-             engine,
-             build_insert,
-             *items,
-             chunk_size_frac=chunk_size_frac,
-         )
-     except _PrepareInsertOrUpsertItemsError as error:
-         raise InsertItemsError(item=error.item) from None
+     normalized = chain.from_iterable(
+         _insert_items_yield_normalized(i, snake=snake) for i in items
+     )
+     triples = _insert_items_yield_triples(
+         engine, normalized, is_upsert=is_upsert, chunk_size_frac=chunk_size_frac
+     )
      if not assume_tables_exist:
+         triples = list(triples)
+         tables = {table for table, _, _ in triples}
          await ensure_tables_created(
-             engine, *prepared.tables, timeout=timeout_create, error=error_create
+             engine, *tables, timeout=timeout_create, error=error_create
          )
-     for ins, parameters in prepared.yield_pairs():
+     for _, ins, parameters in triples:
          async with yield_connection(
              engine, timeout=timeout_insert, error=error_insert
          ) as conn:
              _ = await conn.execute(ins, parameters=parameters)


+ def _insert_items_yield_normalized(
+     item: _InsertItem, /, *, snake: bool = False
+ ) -> Iterator[_NormalizedItem]:
+     if _is_pair_of_str_mapping_and_table(item):
+         mapping, table_or_orm = item
+         adjusted = _map_mapping_to_table(mapping, table_or_orm, snake=snake)
+         yield (get_table(table_or_orm), adjusted)
+         return
+     if _is_pair_of_tuple_and_table(item):
+         tuple_, table_or_orm = item
+         mapping = _tuple_to_mapping(tuple_, table_or_orm)
+         yield from _insert_items_yield_normalized((mapping, table_or_orm), snake=snake)
+         return
+     if _is_pair_of_sequence_of_tuple_or_string_mapping_and_table(item):
+         items, table_or_orm = item
+         pairs = [(i, table_or_orm) for i in items]
+         for p in pairs:
+             yield from _insert_items_yield_normalized(p, snake=snake)
+         return
+     if isinstance(item, DeclarativeBase):
+         mapping = _orm_inst_to_dict(item)
+         yield from _insert_items_yield_normalized((mapping, item), snake=snake)
+         return
+     try:
+         _ = iter(item)
+     except TypeError:
+         raise InsertItemsError(item=item) from None
+     if all(map(_is_pair_of_tuple_or_str_mapping_and_table, item)):
+         pairs = cast("Sequence[_PairOfTupleOrStrMappingAndTable]", item)
+         for p in pairs:
+             yield from _insert_items_yield_normalized(p, snake=snake)
+         return
+     if all(map(is_orm, item)):
+         classes = cast("Sequence[DeclarativeBase]", item)
+         for c in classes:
+             yield from _insert_items_yield_normalized(c, snake=snake)
+         return
+     raise InsertItemsError(item=item)
+
+
+ def _insert_items_yield_triples(
+     engine: AsyncEngine,
+     items: Iterable[_NormalizedItem],
+     /,
+     *,
+     is_upsert: bool = False,
+     chunk_size_frac: float = CHUNK_SIZE_FRAC,
+ ) -> Iterable[_InsertTriple]:
+     pairs = _insert_items_yield_chunked_pairs(
+         engine, items, is_upsert=is_upsert, chunk_size_frac=chunk_size_frac
+     )
+     for table, mappings in pairs:
+         match is_upsert, _get_dialect(engine):
+             case False, "oracle":  # pragma: no cover
+                 ins = insert(table)
+                 parameters = mappings
+             case False, _:
+                 ins = insert(table).values(mappings)
+                 parameters = None
+             case True, _:
+                 ins = _insert_items_build_insert_with_on_conflict_do_update(
+                     engine, table, mappings
+                 )
+                 parameters = None
+             case never:
+                 assert_never(never)
+         yield table, ins, parameters
+
+
+ def _insert_items_yield_chunked_pairs(
+     engine: AsyncEngine,
+     items: Iterable[_NormalizedItem],
+     /,
+     *,
+     is_upsert: bool = False,
+     chunk_size_frac: float = CHUNK_SIZE_FRAC,
+ ) -> Iterable[_InsertPair]:
+     for table, mappings in _insert_items_yield_raw_pairs(items, is_upsert=is_upsert):
+         chunk_size = get_chunk_size(engine, table, chunk_size_frac=chunk_size_frac)
+         for mappings_i in chunked(mappings, chunk_size):
+             yield table, list(mappings_i)
+
+
+ def _insert_items_yield_raw_pairs(
+     items: Iterable[_NormalizedItem], /, *, is_upsert: bool = False
+ ) -> Iterable[_InsertPair]:
+     by_table: defaultdict[Table, list[StrMapping]] = defaultdict(list)
+     for table, mapping in items:
+         by_table[table].append(mapping)
+     for table, mappings in by_table.items():
+         yield from _insert_items_yield_raw_pairs_one(
+             table, mappings, is_upsert=is_upsert
+         )
+
+
+ def _insert_items_yield_raw_pairs_one(
+     table: Table, mappings: Iterable[StrMapping], /, *, is_upsert: bool = False
+ ) -> Iterable[_InsertPair]:
+     merged = _insert_items_yield_merged_mappings(table, mappings)
+     match is_upsert:
+         case True:
+             by_keys: defaultdict[frozenset[str], list[StrMapping]] = defaultdict(list)
+             for mapping in merged:
+                 non_null = {k: v for k, v in mapping.items() if v is not None}
+                 by_keys[frozenset(non_null)].append(non_null)
+             for mappings_i in by_keys.values():
+                 yield table, mappings_i
+         case False:
+             yield table, list(merged)
+         case never:
+             assert_never(never)
+
+
+ def _insert_items_yield_merged_mappings(
+     table: Table, mappings: Iterable[StrMapping], /
+ ) -> Iterable[StrMapping]:
+     columns = list(yield_primary_key_columns(table))
+     col_names = [c.name for c in columns]
+     cols_auto = {c.name for c in columns if c.autoincrement in {True, "auto"}}
+     cols_non_auto = set(col_names) - cols_auto
+     by_key: defaultdict[tuple[Hashable, ...], list[StrMapping]] = defaultdict(list)
+     for mapping in mappings:
+         check_subset(cols_non_auto, mapping)
+         has_all_auto = set(cols_auto).issubset(mapping)
+         if has_all_auto:
+             pkey = tuple(mapping[k] for k in col_names)
+             rest: StrMapping = {k: v for k, v in mapping.items() if k not in col_names}
+             by_key[pkey].append(rest)
+         else:
+             yield mapping
+     for k, v in by_key.items():
+         head = dict(zip(col_names, k, strict=True))
+         yield merge_str_mappings(head, *v)
+
+
+ def _insert_items_build_insert_with_on_conflict_do_update(
+     engine: AsyncEngine, table: Table, mappings: Iterable[StrMapping], /
+ ) -> Insert:
+     primary_key = cast("Any", table.primary_key)
+     mappings = list(mappings)
+     columns = merge_sets(*mappings)
+     match _get_dialect(engine):
+         case "postgresql":  # skipif-ci-and-not-linux
+             ins = postgresql_insert(table).values(mappings)
+             set_ = {c: getattr(ins.excluded, c) for c in columns}
+             return ins.on_conflict_do_update(constraint=primary_key, set_=set_)
+         case "sqlite":
+             ins = sqlite_insert(table).values(mappings)
+             set_ = {c: getattr(ins.excluded, c) for c in columns}
+             return ins.on_conflict_do_update(index_elements=primary_key, set_=set_)
+         case "mssql" | "mysql" | "oracle" as dialect:  # pragma: no cover
+             raise NotImplementedError(dialect)
+         case never:
+             assert_never(never)
+
+
  @dataclass(kw_only=True, slots=True)
  class InsertItemsError(Exception):
      item: _InsertItem
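
One behavioural detail worth calling out in the hunk above: for upserts, `_insert_items_yield_raw_pairs_one` strips `None` values and then groups the remaining mappings by their key sets, so each generated statement updates one consistent set of columns. A standalone sketch of just that grouping step (the data is invented for illustration, and this is not the library helper itself):

    # Illustration of the grouping idea only.
    from collections import defaultdict

    rows = [
        {"id": 1, "name": "Alice", "age": None},
        {"id": 2, "name": "Bob", "age": 30},
    ]
    by_keys: defaultdict[frozenset[str], list[dict]] = defaultdict(list)
    for row in rows:
        non_null = {k: v for k, v in row.items() if v is not None}
        by_keys[frozenset(non_null)].append(non_null)
    # Two groups result: {"id", "name"} and {"id", "name", "age"}; each
    # becomes its own ON CONFLICT DO UPDATE statement over those columns.
    assert len(by_keys) == 2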
@@ -716,9 +849,9 @@ async def migrate_data(
      chunk_size_frac: float = CHUNK_SIZE_FRAC,
      assume_tables_exist: bool = False,
      timeout_create: Delta | None = None,
-     error_create: type[Exception] = TimeoutError,
+     error_create: MaybeType[Exception] = TimeoutError,
      timeout_insert: Delta | None = None,
-     error_insert: type[Exception] = TimeoutError,
+     error_insert: MaybeType[Exception] = TimeoutError,
  ) -> None:
      """Migrate the contents of a table from one database to another."""
      table_from = get_table(table_or_orm_from)
@@ -742,80 +875,6 @@
  ##


- def _normalize_insert_item(
-     item: _InsertItem, /, *, snake: bool = False
- ) -> list[_NormalizedItem]:
-     """Normalize an insertion item."""
-     if _is_pair_of_str_mapping_and_table(item):
-         mapping, table_or_orm = item
-         adjusted = _map_mapping_to_table(mapping, table_or_orm, snake=snake)
-         normalized = _NormalizedItem(mapping=adjusted, table=get_table(table_or_orm))
-         return [normalized]
-     if _is_pair_of_tuple_and_table(item):
-         tuple_, table_or_orm = item
-         mapping = _tuple_to_mapping(tuple_, table_or_orm)
-         return _normalize_insert_item((mapping, table_or_orm), snake=snake)
-     if _is_pair_of_sequence_of_tuple_or_string_mapping_and_table(item):
-         items, table_or_orm = item
-         pairs = [(i, table_or_orm) for i in items]
-         normalized = (_normalize_insert_item(p, snake=snake) for p in pairs)
-         return list(chain.from_iterable(normalized))
-     if isinstance(item, DeclarativeBase):
-         mapping = _orm_inst_to_dict(item)
-         return _normalize_insert_item((mapping, item), snake=snake)
-     try:
-         _ = iter(item)
-     except TypeError:
-         raise _NormalizeInsertItemError(item=item) from None
-     if all(map(_is_pair_of_tuple_or_str_mapping_and_table, item)):
-         seq = cast("Sequence[_PairOfTupleOrStrMappingAndTable]", item)
-         normalized = (_normalize_insert_item(p, snake=snake) for p in seq)
-         return list(chain.from_iterable(normalized))
-     if all(map(is_orm, item)):
-         seq = cast("Sequence[DeclarativeBase]", item)
-         normalized = (_normalize_insert_item(p, snake=snake) for p in seq)
-         return list(chain.from_iterable(normalized))
-     raise _NormalizeInsertItemError(item=item)
-
-
- @dataclass(kw_only=True, slots=True)
- class _NormalizeInsertItemError(Exception):
-     item: _InsertItem
-
-     @override
-     def __str__(self) -> str:
-         return f"Item must be valid; got {self.item}"
-
-
- @dataclass(kw_only=True, slots=True)
- class _NormalizedItem:
-     mapping: StrMapping
-     table: Table
-
-
- def _normalize_upsert_item(
-     item: _InsertItem,
-     /,
-     *,
-     snake: bool = False,
-     selected_or_all: Literal["selected", "all"] = "selected",
- ) -> Iterator[_NormalizedItem]:
-     """Normalize an upsert item."""
-     normalized = _normalize_insert_item(item, snake=snake)
-     match selected_or_all:
-         case "selected":
-             for norm in normalized:
-                 values = {k: v for k, v in norm.mapping.items() if v is not None}
-                 yield _NormalizedItem(mapping=values, table=norm.table)
-         case "all":
-             yield from normalized
-         case never:
-             assert_never(never)
-
-
- ##
-
-
  def selectable_to_string(
      selectable: Selectable[Any], engine_or_conn: EngineOrConnectionOrAsync, /
  ) -> str:
@@ -840,134 +899,6 @@ class TablenameMixin:
  ##


- type _SelectedOrAll = Literal["selected", "all"]
-
-
- async def upsert_items(
-     engine: AsyncEngine,
-     /,
-     *items: _InsertItem,
-     snake: bool = False,
-     selected_or_all: _SelectedOrAll = "selected",
-     chunk_size_frac: float = CHUNK_SIZE_FRAC,
-     assume_tables_exist: bool = False,
-     timeout_create: Delta | None = None,
-     error_create: type[Exception] = TimeoutError,
-     timeout_insert: Delta | None = None,
-     error_insert: type[Exception] = TimeoutError,
- ) -> None:
-     """Upsert a set of items into a database.
-
-     These can be one of the following:
-     - pair of dict & table/class: {k1=v1, k2=v2, ...), table_cls
-     - pair of list of dicts & table/class: [{k1=v11, k2=v12, ...},
-                                             {k1=v21, k2=v22, ...},
-                                             ...], table/class
-     - list of pairs of dict & table/class: [({k1=v11, k2=v12, ...}, table_cls1),
-                                             ({k1=v21, k2=v22, ...}, table_cls2),
-                                             ...]
-     - mapped class: Obj(k1=v1, k2=v2, ...)
-     - list of mapped classes: [Obj(k1=v11, k2=v12, ...),
-                                Obj(k1=v21, k2=v22, ...),
-                                ...]
-     """
-
-     def build_insert(
-         table: Table, values: Iterable[StrMapping], /
-     ) -> tuple[Insert, None]:
-         ups = _upsert_items_build(
-             engine, table, values, selected_or_all=selected_or_all
-         )
-         return ups, None
-
-     try:
-         prepared = _prepare_insert_or_upsert_items(
-             partial(
-                 _normalize_upsert_item, snake=snake, selected_or_all=selected_or_all
-             ),
-             engine,
-             build_insert,
-             *items,
-             chunk_size_frac=chunk_size_frac,
-         )
-     except _PrepareInsertOrUpsertItemsError as error:
-         raise UpsertItemsError(item=error.item) from None
-     if not assume_tables_exist:
-         await ensure_tables_created(
-             engine, *prepared.tables, timeout=timeout_create, error=error_create
-         )
-     for ups, _ in prepared.yield_pairs():
-         async with yield_connection(
-             engine, timeout=timeout_insert, error=error_insert
-         ) as conn:
-             _ = await conn.execute(ups)
-
-
- def _upsert_items_build(
-     engine: AsyncEngine,
-     table: Table,
-     values: Iterable[StrMapping],
-     /,
-     *,
-     selected_or_all: Literal["selected", "all"] = "selected",
- ) -> Insert:
-     values = list(values)
-     keys = merge_sets(*values)
-     dict_nones = dict.fromkeys(keys)
-     values = [{**dict_nones, **v} for v in values]
-     match _get_dialect(engine):
-         case "postgresql":  # skipif-ci-and-not-linux
-             insert = postgresql_insert
-         case "sqlite":
-             insert = sqlite_insert
-         case "mssql" | "mysql" | "oracle" as dialect:  # pragma: no cover
-             raise NotImplementedError(dialect)
-         case never:
-             assert_never(never)
-     ins = insert(table).values(values)
-     primary_key = cast("Any", table.primary_key)
-     return _upsert_items_apply_on_conflict_do_update(
-         values, ins, primary_key, selected_or_all=selected_or_all
-     )
-
-
- def _upsert_items_apply_on_conflict_do_update(
-     values: Iterable[StrMapping],
-     insert: postgresql_Insert | sqlite_Insert,
-     primary_key: PrimaryKeyConstraint,
-     /,
-     *,
-     selected_or_all: Literal["selected", "all"] = "selected",
- ) -> Insert:
-     match selected_or_all:
-         case "selected":
-             columns = merge_sets(*values)
-         case "all":
-             columns = {c.name for c in insert.excluded}
-         case never:
-             assert_never(never)
-     set_ = {c: getattr(insert.excluded, c) for c in columns}
-     match insert:
-         case postgresql_Insert():  # skipif-ci
-             return insert.on_conflict_do_update(constraint=primary_key, set_=set_)
-         case sqlite_Insert():
-             return insert.on_conflict_do_update(index_elements=primary_key, set_=set_)
-         case never:
-             assert_never(never)
-
-
- @dataclass(kw_only=True, slots=True)
- class UpsertItemsError(Exception):
-     item: _InsertItem
-
-     @override
-     def __str__(self) -> str:
-         return f"Item must be valid; got {self.item}"
-
-
- ##
-
-
  @asynccontextmanager
  async def yield_connection(
      engine: AsyncEngine,
@@ -1211,84 +1142,6 @@ def _orm_inst_to_dict_predicate(
  ##


- @dataclass(kw_only=True, slots=True)
- class _PrepareInsertOrUpsertItems:
-     mapping: dict[Table, list[StrMapping]] = field(default_factory=dict)
-     yield_pairs: Callable[[], Iterator[tuple[Insert, Any]]]
-
-     @property
-     def tables(self) -> Sequence[Table]:
-         return list(self.mapping)
-
-
- def _prepare_insert_or_upsert_items(
-     normalize_item: Callable[[_InsertItem], Iterable[_NormalizedItem]],
-     engine: AsyncEngine,
-     build_insert: Callable[[Table, Iterable[StrMapping]], tuple[Insert, Any]],
-     /,
-     *items: Any,
-     chunk_size_frac: float = CHUNK_SIZE_FRAC,
- ) -> _PrepareInsertOrUpsertItems:
-     """Prepare a set of insert/upsert items."""
-     mapping: defaultdict[Table, list[StrMapping]] = defaultdict(list)
-     lengths: set[int] = set()
-     try:
-         for item in items:
-             for normed in normalize_item(item):
-                 mapping[normed.table].append(normed.mapping)
-                 lengths.add(len(normed.mapping))
-     except _NormalizeInsertItemError as error:
-         raise _PrepareInsertOrUpsertItemsError(item=error.item) from None
-     merged: dict[Table, list[StrMapping]] = {
-         table: _prepare_insert_or_upsert_items_merge_items(table, values)
-         for table, values in mapping.items()
-     }
-
-     def yield_pairs() -> Iterator[tuple[Insert, None]]:
-         for table, values in merged.items():
-             chunk_size = get_chunk_size(engine, table, chunk_size_frac=chunk_size_frac)
-             for chunk in chunked(values, chunk_size):
-                 yield build_insert(table, chunk)
-
-     return _PrepareInsertOrUpsertItems(mapping=mapping, yield_pairs=yield_pairs)
-
-
- @dataclass(kw_only=True, slots=True)
- class _PrepareInsertOrUpsertItemsError(Exception):
-     item: Any
-
-     @override
-     def __str__(self) -> str:
-         return f"Item must be valid; got {self.item}"
-
-
- def _prepare_insert_or_upsert_items_merge_items(
-     table: Table, items: Iterable[StrMapping], /
- ) -> list[StrMapping]:
-     columns = list(yield_primary_key_columns(table))
-     col_names = [c.name for c in columns]
-     cols_auto = {c.name for c in columns if c.autoincrement in {True, "auto"}}
-     cols_non_auto = set(col_names) - cols_auto
-     mapping: defaultdict[tuple[Hashable, ...], list[StrMapping]] = defaultdict(list)
-     unchanged: list[StrMapping] = []
-     for item in items:
-         check_subset(cols_non_auto, item)
-         has_all_auto = set(cols_auto).issubset(item)
-         if has_all_auto:
-             pkey = tuple(item[k] for k in col_names)
-             rest: StrMapping = {k: v for k, v in item.items() if k not in col_names}
-             mapping[pkey].append(rest)
-         else:
-             unchanged.append(item)
-     merged = {k: merge_str_mappings(*v) for k, v in mapping.items()}
-     return [
-         dict(zip(col_names, k, strict=True)) | dict(v) for k, v in merged.items()
-     ] + unchanged
-
-
- ##
-
-
  def _tuple_to_mapping(
      values: tuple[Any, ...], table_or_orm: TableOrORMInstOrClass, /
  ) -> dict[str, Any]:
@@ -1307,7 +1160,6 @@ __all__ = [
      "GetTableError",
      "InsertItemsError",
      "TablenameMixin",
-     "UpsertItemsError",
      "check_connect",
      "check_connect_async",
      "check_engine",
@@ -1333,7 +1185,6 @@ __all__ = [
      "is_table_or_orm",
      "migrate_data",
      "selectable_to_string",
-     "upsert_items",
      "yield_connection",
      "yield_primary_key_columns",
  ]
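
Net effect for callers of this module: `upsert_items`, `UpsertItemsError`, and the `selected_or_all` option are gone, and upserts now go through `insert_items(..., is_upsert=True)`, which updates the non-null columns of each row (the behaviour of the old "selected" mode). A migration sketch, with the table definition and in-memory SQLite URL invented for illustration (aiosqlite assumed installed):

    # Sketch only; the `user` table and engine URL are hypothetical.
    from sqlalchemy import Column, Integer, MetaData, String, Table
    from sqlalchemy.ext.asyncio import create_async_engine

    from utilities.sqlalchemy import insert_items

    user = Table(
        "user",
        MetaData(),
        Column("id", Integer, primary_key=True),
        Column("name", String),
    )

    async def main() -> None:
        engine = create_async_engine("sqlite+aiosqlite://")
        # Before: await upsert_items(engine, ({"id": 1, "name": "Bob"}, user))
        # After: the same call site, routed through insert_items.
        await insert_items(engine, ({"id": 1, "name": "Bob"}, user), is_upsert=True)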
utilities/sqlalchemy_polars.py CHANGED
@@ -4,7 +4,7 @@ import datetime as dt
  import decimal
  from contextlib import suppress
  from dataclasses import dataclass
- from typing import TYPE_CHECKING, Any, Literal, assert_never, cast, overload, override
+ from typing import TYPE_CHECKING, Any, cast, overload, override
  from uuid import UUID

  import polars as pl
@@ -44,7 +44,6 @@ from utilities.sqlalchemy import (
      get_chunk_size,
      get_columns,
      insert_items,
-     upsert_items,
  )
  from utilities.text import snake_case
  from utilities.typing import is_subclass_gen
@@ -75,9 +74,9 @@ async def insert_dataframe(
      /,
      *,
      snake: bool = False,
+     is_upsert: bool = False,
      chunk_size_frac: float = CHUNK_SIZE_FRAC,
      assume_tables_exist: bool = False,
-     upsert: Literal["selected", "all"] | None = None,
      timeout_create: TimeDelta | None = None,
      error_create: type[Exception] = TimeoutError,
      timeout_insert: TimeDelta | None = None,
@@ -87,43 +86,25 @@
      mapping = _insert_dataframe_map_df_schema_to_table(
          df.schema, table_or_orm, snake=snake
      )
-     items = df.select(*mapping).rename(mapping).to_dicts()
+     items = df.select(*mapping).rename(mapping).rows(named=True)
      if len(items) == 0:
-         if not df.is_empty():
-             raise InsertDataFrameError(df=df)
          if not assume_tables_exist:
              await ensure_tables_created(
                  engine, table_or_orm, timeout=timeout_create, error=error_create
              )
          return
-     match upsert:
-         case None:
-             await insert_items(
-                 engine,
-                 (items, table_or_orm),
-                 snake=snake,
-                 chunk_size_frac=chunk_size_frac,
-                 assume_tables_exist=assume_tables_exist,
-                 timeout_create=timeout_create,
-                 error_create=error_create,
-                 timeout_insert=timeout_insert,
-                 error_insert=error_insert,
-             )
-         case "selected" | "all" as selected_or_all:  # skipif-ci-and-not-linux
-             await upsert_items(
-                 engine,
-                 (items, table_or_orm),
-                 snake=snake,
-                 chunk_size_frac=chunk_size_frac,
-                 selected_or_all=selected_or_all,
-                 assume_tables_exist=assume_tables_exist,
-                 timeout_create=timeout_create,
-                 error_create=error_create,
-                 timeout_insert=timeout_insert,
-                 error_insert=error_insert,
-             )
-         case never:
-             assert_never(never)
+     await insert_items(
+         engine,
+         (items, table_or_orm),
+         snake=snake,
+         is_upsert=is_upsert,
+         chunk_size_frac=chunk_size_frac,
+         assume_tables_exist=assume_tables_exist,
+         timeout_create=timeout_create,
+         error_create=error_create,
+         timeout_insert=timeout_insert,
+         error_insert=error_insert,
+     )


  def _insert_dataframe_map_df_schema_to_table(
@@ -207,15 +188,6 @@ def _insert_dataframe_check_df_and_db_types(
      )


- @dataclass(kw_only=True, slots=True)
- class InsertDataFrameError(Exception):
-     df: DataFrame
-
-     @override
-     def __str__(self) -> str:
-         return f"Non-empty DataFrame must resolve to at least 1 item\n\n{self.df}"
-
-
  @overload
  async def select_to_dataframe(
      sel: Select[Any],
@@ -443,4 +415,4 @@
      return (sel.where(in_col.in_(values)) for values in chunked(in_values, chunk_size))


- __all__ = ["InsertDataFrameError", "insert_dataframe", "select_to_dataframe"]
+ __all__ = ["insert_dataframe", "select_to_dataframe"]
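
`insert_dataframe` follows suit: the `upsert: Literal["selected", "all"] | None` keyword becomes a plain `is_upsert: bool` forwarded to `insert_items`, and `InsertDataFrameError` disappears because a DataFrame that resolves to zero items is no longer an error. A usage sketch; the positional argument order `(df, table_or_orm, engine)` is an assumption here, since the positional parameters sit outside the hunks shown:

    # Sketch only; the table, engine URL, and positional order are assumptions.
    import polars as pl
    from sqlalchemy import Column, Integer, MetaData, String, Table
    from sqlalchemy.ext.asyncio import create_async_engine

    from utilities.sqlalchemy_polars import insert_dataframe

    user = Table(
        "user",
        MetaData(),
        Column("id", Integer, primary_key=True),
        Column("name", String),
    )

    async def main() -> None:
        engine = create_async_engine("sqlite+aiosqlite://")
        df = pl.DataFrame({"id": [1, 2], "name": ["Alice", "Bob"]})
        await insert_dataframe(df, user, engine)                  # plain insert
        await insert_dataframe(df, user, engine, is_upsert=True)  # upsert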