dycw-utilities 0.157.1__py3-none-any.whl → 0.158.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dycw-utilities
3
- Version: 0.157.1
3
+ Version: 0.158.0
4
4
  Author-email: Derek Wan <d.wan@icloud.com>
5
5
  License-File: LICENSE
6
6
  Requires-Python: >=3.12
@@ -1,4 +1,4 @@
1
- utilities/__init__.py,sha256=aCEk399BQ2RGpGQ17jdGxjzeW08BgR-YI0pg_VvNlKM,60
1
+ utilities/__init__.py,sha256=dllUsKScS7RBnWFHME-sLeOoDAWms1u25frX9QMEP-4,60
2
2
  utilities/altair.py,sha256=92E2lCdyHY4Zb-vCw6rEJIsWdKipuu-Tu2ab1ufUfAk,9079
3
3
  utilities/asyncio.py,sha256=PUedzQ5deqlSECQ33sam9cRzI9TnygHz3FdOqWJWPTM,15288
4
4
  utilities/atomicwrites.py,sha256=tPo6r-Rypd9u99u66B9z86YBPpnLrlHtwox_8Z7T34Y,5790
@@ -14,7 +14,7 @@ utilities/dataclasses.py,sha256=MXrvIPSZHlpV4msRdVVDRZZo7MC3gX5C9jDUSoNOdpE,3247
14
14
  utilities/enum.py,sha256=5l6pwZD1cjSlVW4ss-zBPspWvrbrYrdtJWcg6f5_J5w,5781
15
15
  utilities/errors.py,sha256=mFlDGSM0LI1jZ1pbqwLAH3ttLZ2JVIxyZLojw8tGVZU,1479
16
16
  utilities/eventkit.py,sha256=ddoleSwW9zdc2tjX5Ge0pMKtYwV_JMxhHYOxnWX2AGM,12609
17
- utilities/fastapi.py,sha256=3wpd63Tw9paSyy7STpAD7GGe8fLkLaRC6TPCwIGm1BU,1361
17
+ utilities/fastapi.py,sha256=TqyKvBjiMS594sXPjrz-KRTLMb3l3D3rZ1zAYV7GfOk,1454
18
18
  utilities/fpdf2.py,sha256=HgM8JSvoioDXrjC0UR3HVLjnMnnb_mML7nL2EmkTwGI,1854
19
19
  utilities/functions.py,sha256=RNVAoLeT_sl-gXaBv2VI_U_EB-d-nSVosYR4gTeeojE,28261
20
20
  utilities/functools.py,sha256=I00ru2gQPakZw2SHVeKIKXfTv741655s6HI0lUoE0D4,1552
@@ -22,7 +22,7 @@ utilities/getpass.py,sha256=DfN5UgMAtFCqS3dSfFHUfqIMZX2shXvwphOz_6J6f6A,103
22
22
  utilities/gzip.py,sha256=fkGP3KdsBfXlstodT4wtlp-PwNyUsogpbDCVVVGdsm4,781
23
23
  utilities/hashlib.py,sha256=SVTgtguur0P4elppvzOBbLEjVM3Pea0eWB61yg2ilxo,309
24
24
  utilities/http.py,sha256=TsavEfHlRtlLaeV21Z6KZh0qbPw-kvD1zsQdZ7Kep5Q,977
25
- utilities/hypothesis.py,sha256=qqV0O2ynV73tT-bAPXxlfWdR1X0iZgrjn7sK-eJOqBM,43812
25
+ utilities/hypothesis.py,sha256=OG8mxN6Y3fSEjRg4NiIjsO_JUHJBzh4g8fvpmKRoRU8,44370
26
26
  utilities/importlib.py,sha256=mV1xT_O_zt_GnZZ36tl3xOmMaN_3jErDWY54fX39F6Y,429
27
27
  utilities/inflect.py,sha256=v7YkOWSu8NAmVghPcf4F3YBZQoJCS47_DLf9jbfWIs0,581
28
28
  utilities/ipython.py,sha256=V2oMYHvEKvlNBzxDXdLvKi48oUq2SclRg5xasjaXStw,763
@@ -62,10 +62,10 @@ utilities/reprlib.py,sha256=ssYTcBW-TeRh3fhCJv57sopTZHF5FrPyyUg9yp5XBlo,3953
62
62
  utilities/scipy.py,sha256=wZJM7fEgBAkLSYYvSmsg5ac-QuwAI0BGqHVetw1_Hb0,947
63
63
  utilities/sentinel.py,sha256=A_p5jX2K0Yc5XBfoYHyBLqHsEWzE1ByOdDuzzA2pZnE,1434
64
64
  utilities/shelve.py,sha256=4OzjQI6kGuUbJciqf535rwnao-_IBv66gsT6tRGiUt0,759
65
- utilities/slack_sdk.py,sha256=ppFBvKgfg5IRWiIoKPtpTyzBtBF4XmwEvU3I5wLJikM,2140
65
+ utilities/slack_sdk.py,sha256=76-DYtcGiUhEvl-voMamc5OjfF7Y7nCq54Bys1arqzw,2233
66
66
  utilities/socket.py,sha256=K77vfREvzoVTrpYKo6MZakol0EYu2q1sWJnnZqL0So0,118
67
- utilities/sqlalchemy.py,sha256=wtbIp6XDjKwfrvl-wfoY4FQXo_a9vSoHq5K_dYeBBeY,40541
68
- utilities/sqlalchemy_polars.py,sha256=5Q9HReETYg0qB6E6WQhFh4QAZlKE-IWlogj2BVif_-w,14246
67
+ utilities/sqlalchemy.py,sha256=4rLjecgDe60MnuWxUqZVIQy3dKepNaemoGyrnuCdNZg,36347
68
+ utilities/sqlalchemy_polars.py,sha256=JCGhB37raSR7fqeWV5dTsciRTMVzIdVT9YSqKT0piT0,13370
69
69
  utilities/statsmodels.py,sha256=koyiBHvpMcSiBfh99wFUfSggLNx7cuAw3rwyfAhoKpQ,3410
70
70
  utilities/string.py,sha256=shmBK87zZwzGyixuNuXCiUbqzfeZ9xlrFwz6JTaRvDk,582
71
71
  utilities/tempfile.py,sha256=HxB2BF28CcecDJLQ3Bx2Ej-Pb6RJc6W9ngSpB9CnP4k,2018
@@ -87,8 +87,8 @@ utilities/zoneinfo.py,sha256=FBMcUQ4662Aq8SsuCL1OAhDQiyANmVjtb-C30DRrWoE,1966
87
87
  utilities/pytest_plugins/__init__.py,sha256=U4S_2y3zgLZVfMenHRaJFBW8yqh2mUBuI291LGQVOJ8,35
88
88
  utilities/pytest_plugins/pytest_randomly.py,sha256=B1qYVlExGOxTywq2r1SMi5o7btHLk2PNdY_b1p98dkE,409
89
89
  utilities/pytest_plugins/pytest_regressions.py,sha256=9v8kAXDM2ycIXJBimoiF4EgrwbUvxTycFWJiGR_GHhM,1466
90
- dycw_utilities-0.157.1.dist-info/METADATA,sha256=OvC7ucuMA0zBo4CHiaNRLyMYPjJgGjD7meIO_vUXdG4,1643
91
- dycw_utilities-0.157.1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
92
- dycw_utilities-0.157.1.dist-info/entry_points.txt,sha256=BOD_SoDxwsfJYOLxhrSXhHP_T7iw-HXI9f2WVkzYxvQ,135
93
- dycw_utilities-0.157.1.dist-info/licenses/LICENSE,sha256=gppZp16M6nSVpBbUBrNL6JuYfvKwZiKgV7XoKKsHzqo,1066
94
- dycw_utilities-0.157.1.dist-info/RECORD,,
90
+ dycw_utilities-0.158.0.dist-info/METADATA,sha256=-lYD9v1af9rVk_ibo4Z2cLp-i4JgpoR7qiHy1svZeDU,1643
91
+ dycw_utilities-0.158.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
92
+ dycw_utilities-0.158.0.dist-info/entry_points.txt,sha256=BOD_SoDxwsfJYOLxhrSXhHP_T7iw-HXI9f2WVkzYxvQ,135
93
+ dycw_utilities-0.158.0.dist-info/licenses/LICENSE,sha256=gppZp16M6nSVpBbUBrNL6JuYfvKwZiKgV7XoKKsHzqo,1066
94
+ dycw_utilities-0.158.0.dist-info/RECORD,,
utilities/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
1
  from __future__ import annotations
2
2
 
3
- __version__ = "0.157.1"
3
+ __version__ = "0.158.0"
utilities/fastapi.py CHANGED
@@ -13,7 +13,7 @@ from utilities.whenever import get_now_local
13
13
  if TYPE_CHECKING:
14
14
  from collections.abc import AsyncIterator
15
15
 
16
- from utilities.types import Delta
16
+ from utilities.types import Delta, MaybeType
17
17
 
18
18
 
19
19
  _TASKS: list[Task[None]] = []
@@ -35,14 +35,19 @@ class _PingerReceiverApp(FastAPI):
35
35
 
36
36
  @enhanced_async_context_manager
37
37
  async def yield_ping_receiver(
38
- port: int, /, *, host: str = "localhost", timeout: Delta | None = None
38
+ port: int,
39
+ /,
40
+ *,
41
+ host: str = "localhost",
42
+ timeout: Delta | None = None,
43
+ error: MaybeType[BaseException] = TimeoutError,
39
44
  ) -> AsyncIterator[None]:
40
45
  """Yield the ping receiver."""
41
46
  app = _PingerReceiverApp() # skipif-ci
42
47
  server = Server(Config(app, host=host, port=port)) # skipif-ci
43
48
  _TASKS.append(create_task(server.serve())) # skipif-ci
44
49
  try: # skipif-ci
45
- async with timeout_td(timeout):
50
+ async with timeout_td(timeout, error=error):
46
51
  yield
47
52
  finally: # skipif-ci
48
53
  await server.shutdown()
utilities/hypothesis.py CHANGED
@@ -897,6 +897,27 @@ def py_datetimes(
897
897
  ##
898
898
 
899
899
 
900
+ def quadruples[T](
901
+ strategy: SearchStrategy[T],
902
+ /,
903
+ *,
904
+ unique: MaybeSearchStrategy[bool] = False,
905
+ sorted: MaybeSearchStrategy[bool] = False, # noqa: A002
906
+ ) -> SearchStrategy[tuple[T, T, T, T]]:
907
+ """Strategy for generating quadruples of elements."""
908
+ return lists_fixed_length(strategy, 4, unique=unique, sorted=sorted).map(
909
+ _quadruples_map
910
+ )
911
+
912
+
913
+ def _quadruples_map[T](elements: list[T], /) -> tuple[T, T, T, T]:
914
+ first, second, third, fourth = elements
915
+ return first, second, third, fourth
916
+
917
+
918
+ ##
919
+
920
+
900
921
  @composite
901
922
  def random_states(
902
923
  draw: DrawFn, /, *, seed: MaybeSearchStrategy[int | None] = None
@@ -1555,6 +1576,7 @@ __all__ = [
1555
1576
  "paths",
1556
1577
  "plain_date_times",
1557
1578
  "py_datetimes",
1579
+ "quadruples",
1558
1580
  "random_states",
1559
1581
  "sentinels",
1560
1582
  "sets_fixed_length",
utilities/slack_sdk.py CHANGED
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
15
15
  from slack_sdk.webhook import WebhookResponse
16
16
  from whenever import TimeDelta
17
17
 
18
- from utilities.types import Delta
18
+ from utilities.types import Delta, MaybeType
19
19
 
20
20
 
21
21
  _TIMEOUT: Delta = MINUTE
@@ -39,11 +39,16 @@ def _get_client(url: str, /, *, timeout: Delta = _TIMEOUT) -> WebhookClient:
39
39
 
40
40
 
41
41
  async def send_to_slack_async(
42
- url: str, text: str, /, *, timeout: TimeDelta = _TIMEOUT
42
+ url: str,
43
+ text: str,
44
+ /,
45
+ *,
46
+ timeout: TimeDelta = _TIMEOUT,
47
+ error: MaybeType[BaseException] = TimeoutError,
43
48
  ) -> None:
44
49
  """Send a message via Slack."""
45
50
  client = _get_async_client(url, timeout=timeout)
46
- async with timeout_td(timeout):
51
+ async with timeout_td(timeout, error=error):
47
52
  response = await client.send(text=text)
48
53
  if response.status_code != HTTPStatus.OK: # pragma: no cover
49
54
  raise SendToSlackError(text=text, response=response)
utilities/sqlalchemy.py CHANGED
@@ -12,8 +12,8 @@ from collections.abc import (
12
12
  )
13
13
  from collections.abc import Set as AbstractSet
14
14
  from contextlib import asynccontextmanager
15
- from dataclasses import dataclass, field
16
- from functools import partial, reduce
15
+ from dataclasses import dataclass
16
+ from functools import reduce
17
17
  from itertools import chain
18
18
  from math import floor
19
19
  from operator import ge, le
@@ -37,7 +37,6 @@ from sqlalchemy import (
37
37
  Connection,
38
38
  Engine,
39
39
  Insert,
40
- PrimaryKeyConstraint,
41
40
  Selectable,
42
41
  Table,
43
42
  and_,
@@ -49,20 +48,13 @@ from sqlalchemy import (
49
48
  from sqlalchemy.dialects.mssql import dialect as mssql_dialect
50
49
  from sqlalchemy.dialects.mysql import dialect as mysql_dialect
51
50
  from sqlalchemy.dialects.oracle import dialect as oracle_dialect
52
- from sqlalchemy.dialects.postgresql import Insert as postgresql_Insert
53
51
  from sqlalchemy.dialects.postgresql import dialect as postgresql_dialect
54
52
  from sqlalchemy.dialects.postgresql import insert as postgresql_insert
55
53
  from sqlalchemy.dialects.postgresql.asyncpg import PGDialect_asyncpg
56
54
  from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg
57
- from sqlalchemy.dialects.sqlite import Insert as sqlite_Insert
58
55
  from sqlalchemy.dialects.sqlite import dialect as sqlite_dialect
59
56
  from sqlalchemy.dialects.sqlite import insert as sqlite_insert
60
- from sqlalchemy.exc import (
61
- ArgumentError,
62
- DatabaseError,
63
- OperationalError,
64
- ProgrammingError,
65
- )
57
+ from sqlalchemy.exc import ArgumentError, DatabaseError
66
58
  from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
67
59
  from sqlalchemy.orm import (
68
60
  DeclarativeBase,
@@ -129,24 +121,22 @@ def check_connect(engine: Engine, /) -> bool:
129
121
  try:
130
122
  with engine.connect() as conn:
131
123
  return bool(conn.execute(_SELECT).scalar_one())
132
- except (gaierror, ConnectionRefusedError, OperationalError, ProgrammingError):
124
+ except (gaierror, ConnectionRefusedError, DatabaseError): # pragma: no cover
133
125
  return False
134
126
 
135
127
 
136
128
  async def check_connect_async(
137
- engine: AsyncEngine, /, *, timeout: Delta | None = None
129
+ engine: AsyncEngine,
130
+ /,
131
+ *,
132
+ timeout: Delta | None = None,
133
+ error: MaybeType[BaseException] = TimeoutError,
138
134
  ) -> bool:
139
135
  """Check if an engine can connect."""
140
136
  try:
141
- async with timeout_td(timeout), engine.connect() as conn:
137
+ async with timeout_td(timeout, error=error), engine.connect() as conn:
142
138
  return bool((await conn.execute(_SELECT)).scalar_one())
143
- except (
144
- gaierror,
145
- ConnectionRefusedError,
146
- OperationalError,
147
- ProgrammingError,
148
- TimeoutError,
149
- ):
139
+ except (gaierror, ConnectionRefusedError, DatabaseError, TimeoutError):
150
140
  return False
151
141
 
152
142
 
@@ -158,7 +148,7 @@ async def check_engine(
158
148
  /,
159
149
  *,
160
150
  timeout: Delta | None = None,
161
- error: type[Exception] = TimeoutError,
151
+ error: MaybeType[Exception] = TimeoutError,
162
152
  num_tables: int | tuple[int, float] | None = None,
163
153
  ) -> None:
164
154
  """Check that an engine can connect.
@@ -332,9 +322,8 @@ async def ensure_database_created(super_: URL, database: str, /) -> None:
332
322
  async with engine.begin() as conn:
333
323
  try:
334
324
  _ = await conn.execute(text(f"CREATE DATABASE {database}"))
335
- except (OperationalError, ProgrammingError) as error:
336
- if not search('database ".*" already exists', ensure_str(one(error.args))):
337
- raise
325
+ except DatabaseError as error:
326
+ _ensure_tables_maybe_reraise(error, 'database ".*" already exists')
338
327
 
339
328
 
340
329
  async def ensure_database_dropped(super_: URL, database: str, /) -> None:
@@ -352,7 +341,7 @@ async def ensure_tables_created(
352
341
  /,
353
342
  *tables_or_orms: TableOrORMInstOrClass,
354
343
  timeout: Delta | None = None,
355
- error: type[Exception] = TimeoutError,
344
+ error: MaybeType[Exception] = TimeoutError,
356
345
  ) -> None:
357
346
  """Ensure a table/set of tables is/are created."""
358
347
  tables = set(map(get_table, tables_or_orms))
@@ -381,7 +370,7 @@ async def ensure_tables_dropped(
381
370
  engine: AsyncEngine,
382
371
  *tables_or_orms: TableOrORMInstOrClass,
383
372
  timeout: Delta | None = None,
384
- error: type[Exception] = TimeoutError,
373
+ error: MaybeType[Exception] = TimeoutError,
385
374
  ) -> None:
386
375
  """Ensure a table/set of tables is/are dropped."""
387
376
  tables = set(map(get_table, tables_or_orms))
@@ -602,12 +591,16 @@ type _InsertItem = (
602
591
  | Sequence[_PairOfTupleOrStrMappingAndTable]
603
592
  | Sequence[DeclarativeBase]
604
593
  )
594
+ type _NormalizedItem = tuple[Table, StrMapping]
595
+ type _InsertPair = tuple[Table, Sequence[StrMapping]]
596
+ type _InsertTriple = tuple[Table, Insert, Sequence[StrMapping] | None]
605
597
 
606
598
 
607
599
  async def insert_items(
608
600
  engine: AsyncEngine,
609
601
  *items: _InsertItem,
610
602
  snake: bool = False,
603
+ is_upsert: bool = False,
611
604
  chunk_size_frac: float = CHUNK_SIZE_FRAC,
612
605
  assume_tables_exist: bool = False,
613
606
  timeout_create: Delta | None = None,
@@ -637,37 +630,181 @@ async def insert_items(
637
630
  Obj(k1=v21, k2=v22, ...),
638
631
  ...]
639
632
  """
640
-
641
- def build_insert(
642
- table: Table, values: Iterable[StrMapping], /
643
- ) -> tuple[Insert, Any]:
644
- match _get_dialect(engine):
645
- case "oracle": # pragma: no cover
646
- return insert(table), values
647
- case _:
648
- return insert(table).values(list(values)), None
649
-
650
- try:
651
- prepared = _prepare_insert_or_upsert_items(
652
- partial(_normalize_insert_item, snake=snake),
653
- engine,
654
- build_insert,
655
- *items,
656
- chunk_size_frac=chunk_size_frac,
657
- )
658
- except _PrepareInsertOrUpsertItemsError as error:
659
- raise InsertItemsError(item=error.item) from None
633
+ normalized = chain.from_iterable(
634
+ _insert_items_yield_normalized(i, snake=snake) for i in items
635
+ )
636
+ triples = _insert_items_yield_triples(
637
+ engine, normalized, is_upsert=is_upsert, chunk_size_frac=chunk_size_frac
638
+ )
660
639
  if not assume_tables_exist:
640
+ triples = list(triples)
641
+ tables = {table for table, _, _ in triples}
661
642
  await ensure_tables_created(
662
- engine, *prepared.tables, timeout=timeout_create, error=error_create
643
+ engine, *tables, timeout=timeout_create, error=error_create
663
644
  )
664
- for ins, parameters in prepared.yield_pairs():
645
+ for _, ins, parameters in triples:
665
646
  async with yield_connection(
666
647
  engine, timeout=timeout_insert, error=error_insert
667
648
  ) as conn:
668
649
  _ = await conn.execute(ins, parameters=parameters)
669
650
 
670
651
 
652
+ def _insert_items_yield_normalized(
653
+ item: _InsertItem, /, *, snake: bool = False
654
+ ) -> Iterator[_NormalizedItem]:
655
+ if _is_pair_of_str_mapping_and_table(item):
656
+ mapping, table_or_orm = item
657
+ adjusted = _map_mapping_to_table(mapping, table_or_orm, snake=snake)
658
+ yield (get_table(table_or_orm), adjusted)
659
+ return
660
+ if _is_pair_of_tuple_and_table(item):
661
+ tuple_, table_or_orm = item
662
+ mapping = _tuple_to_mapping(tuple_, table_or_orm)
663
+ yield from _insert_items_yield_normalized((mapping, table_or_orm), snake=snake)
664
+ return
665
+ if _is_pair_of_sequence_of_tuple_or_string_mapping_and_table(item):
666
+ items, table_or_orm = item
667
+ pairs = [(i, table_or_orm) for i in items]
668
+ for p in pairs:
669
+ yield from _insert_items_yield_normalized(p, snake=snake)
670
+ return
671
+ if isinstance(item, DeclarativeBase):
672
+ mapping = _orm_inst_to_dict(item)
673
+ yield from _insert_items_yield_normalized((mapping, item), snake=snake)
674
+ return
675
+ try:
676
+ _ = iter(item)
677
+ except TypeError:
678
+ raise InsertItemsError(item=item) from None
679
+ if all(map(_is_pair_of_tuple_or_str_mapping_and_table, item)):
680
+ pairs = cast("Sequence[_PairOfTupleOrStrMappingAndTable]", item)
681
+ for p in pairs:
682
+ yield from _insert_items_yield_normalized(p, snake=snake)
683
+ return
684
+ if all(map(is_orm, item)):
685
+ classes = cast("Sequence[DeclarativeBase]", item)
686
+ for c in classes:
687
+ yield from _insert_items_yield_normalized(c, snake=snake)
688
+ return
689
+ raise InsertItemsError(item=item)
690
+
691
+
692
+ def _insert_items_yield_triples(
693
+ engine: AsyncEngine,
694
+ items: Iterable[_NormalizedItem],
695
+ /,
696
+ *,
697
+ is_upsert: bool = False,
698
+ chunk_size_frac: float = CHUNK_SIZE_FRAC,
699
+ ) -> Iterable[_InsertTriple]:
700
+ pairs = _insert_items_yield_chunked_pairs(
701
+ engine, items, is_upsert=is_upsert, chunk_size_frac=chunk_size_frac
702
+ )
703
+ for table, mappings in pairs:
704
+ match is_upsert, _get_dialect(engine):
705
+ case False, "oracle": # pragma: no cover
706
+ ins = insert(table)
707
+ parameters = mappings
708
+ case False, _:
709
+ ins = insert(table).values(mappings)
710
+ parameters = None
711
+ case True, _:
712
+ ins = _insert_items_build_insert_with_on_conflict_do_update(
713
+ engine, table, mappings
714
+ )
715
+ parameters = None
716
+ case never:
717
+ assert_never(never)
718
+ yield table, ins, parameters
719
+
720
+
721
+ def _insert_items_yield_chunked_pairs(
722
+ engine: AsyncEngine,
723
+ items: Iterable[_NormalizedItem],
724
+ /,
725
+ *,
726
+ is_upsert: bool = False,
727
+ chunk_size_frac: float = CHUNK_SIZE_FRAC,
728
+ ) -> Iterable[_InsertPair]:
729
+ for table, mappings in _insert_items_yield_raw_pairs(items, is_upsert=is_upsert):
730
+ chunk_size = get_chunk_size(engine, table, chunk_size_frac=chunk_size_frac)
731
+ for mappings_i in chunked(mappings, chunk_size):
732
+ yield table, list(mappings_i)
733
+
734
+
735
+ def _insert_items_yield_raw_pairs(
736
+ items: Iterable[_NormalizedItem], /, *, is_upsert: bool = False
737
+ ) -> Iterable[_InsertPair]:
738
+ by_table: defaultdict[Table, list[StrMapping]] = defaultdict(list)
739
+ for table, mapping in items:
740
+ by_table[table].append(mapping)
741
+ for table, mappings in by_table.items():
742
+ yield from _insert_items_yield_raw_pairs_one(
743
+ table, mappings, is_upsert=is_upsert
744
+ )
745
+
746
+
747
+ def _insert_items_yield_raw_pairs_one(
748
+ table: Table, mappings: Iterable[StrMapping], /, *, is_upsert: bool = False
749
+ ) -> Iterable[_InsertPair]:
750
+ merged = _insert_items_yield_merged_mappings(table, mappings)
751
+ match is_upsert:
752
+ case True:
753
+ by_keys: defaultdict[frozenset[str], list[StrMapping]] = defaultdict(list)
754
+ for mapping in merged:
755
+ non_null = {k: v for k, v in mapping.items() if v is not None}
756
+ by_keys[frozenset(non_null)].append(non_null)
757
+ for mappings_i in by_keys.values():
758
+ yield table, mappings_i
759
+ case False:
760
+ yield table, list(merged)
761
+ case never:
762
+ assert_never(never)
763
+
764
+
765
+ def _insert_items_yield_merged_mappings(
766
+ table: Table, mappings: Iterable[StrMapping], /
767
+ ) -> Iterable[StrMapping]:
768
+ columns = list(yield_primary_key_columns(table))
769
+ col_names = [c.name for c in columns]
770
+ cols_auto = {c.name for c in columns if c.autoincrement in {True, "auto"}}
771
+ cols_non_auto = set(col_names) - cols_auto
772
+ by_key: defaultdict[tuple[Hashable, ...], list[StrMapping]] = defaultdict(list)
773
+ for mapping in mappings:
774
+ check_subset(cols_non_auto, mapping)
775
+ has_all_auto = set(cols_auto).issubset(mapping)
776
+ if has_all_auto:
777
+ pkey = tuple(mapping[k] for k in col_names)
778
+ rest: StrMapping = {k: v for k, v in mapping.items() if k not in col_names}
779
+ by_key[pkey].append(rest)
780
+ else:
781
+ yield mapping
782
+ for k, v in by_key.items():
783
+ head = dict(zip(col_names, k, strict=True))
784
+ yield merge_str_mappings(head, *v)
785
+
786
+
787
+ def _insert_items_build_insert_with_on_conflict_do_update(
788
+ engine: AsyncEngine, table: Table, mappings: Iterable[StrMapping], /
789
+ ) -> Insert:
790
+ primary_key = cast("Any", table.primary_key)
791
+ mappings = list(mappings)
792
+ columns = merge_sets(*mappings)
793
+ match _get_dialect(engine):
794
+ case "postgresql": # skipif-ci-and-not-linux
795
+ ins = postgresql_insert(table).values(mappings)
796
+ set_ = {c: getattr(ins.excluded, c) for c in columns}
797
+ return ins.on_conflict_do_update(constraint=primary_key, set_=set_)
798
+ case "sqlite":
799
+ ins = sqlite_insert(table).values(mappings)
800
+ set_ = {c: getattr(ins.excluded, c) for c in columns}
801
+ return ins.on_conflict_do_update(index_elements=primary_key, set_=set_)
802
+ case "mssql" | "mysql" | "oracle" as dialect: # pragma: no cover
803
+ raise NotImplementedError(dialect)
804
+ case never:
805
+ assert_never(never)
806
+
807
+
671
808
  @dataclass(kw_only=True, slots=True)
672
809
  class InsertItemsError(Exception):
673
810
  item: _InsertItem
@@ -738,80 +875,6 @@ async def migrate_data(
738
875
  ##
739
876
 
740
877
 
741
- def _normalize_insert_item(
742
- item: _InsertItem, /, *, snake: bool = False
743
- ) -> list[_NormalizedItem]:
744
- """Normalize an insertion item."""
745
- if _is_pair_of_str_mapping_and_table(item):
746
- mapping, table_or_orm = item
747
- adjusted = _map_mapping_to_table(mapping, table_or_orm, snake=snake)
748
- normalized = _NormalizedItem(mapping=adjusted, table=get_table(table_or_orm))
749
- return [normalized]
750
- if _is_pair_of_tuple_and_table(item):
751
- tuple_, table_or_orm = item
752
- mapping = _tuple_to_mapping(tuple_, table_or_orm)
753
- return _normalize_insert_item((mapping, table_or_orm), snake=snake)
754
- if _is_pair_of_sequence_of_tuple_or_string_mapping_and_table(item):
755
- items, table_or_orm = item
756
- pairs = [(i, table_or_orm) for i in items]
757
- normalized = (_normalize_insert_item(p, snake=snake) for p in pairs)
758
- return list(chain.from_iterable(normalized))
759
- if isinstance(item, DeclarativeBase):
760
- mapping = _orm_inst_to_dict(item)
761
- return _normalize_insert_item((mapping, item), snake=snake)
762
- try:
763
- _ = iter(item)
764
- except TypeError:
765
- raise _NormalizeInsertItemError(item=item) from None
766
- if all(map(_is_pair_of_tuple_or_str_mapping_and_table, item)):
767
- seq = cast("Sequence[_PairOfTupleOrStrMappingAndTable]", item)
768
- normalized = (_normalize_insert_item(p, snake=snake) for p in seq)
769
- return list(chain.from_iterable(normalized))
770
- if all(map(is_orm, item)):
771
- seq = cast("Sequence[DeclarativeBase]", item)
772
- normalized = (_normalize_insert_item(p, snake=snake) for p in seq)
773
- return list(chain.from_iterable(normalized))
774
- raise _NormalizeInsertItemError(item=item)
775
-
776
-
777
- @dataclass(kw_only=True, slots=True)
778
- class _NormalizeInsertItemError(Exception):
779
- item: _InsertItem
780
-
781
- @override
782
- def __str__(self) -> str:
783
- return f"Item must be valid; got {self.item}"
784
-
785
-
786
- @dataclass(kw_only=True, slots=True)
787
- class _NormalizedItem:
788
- mapping: StrMapping
789
- table: Table
790
-
791
-
792
- def _normalize_upsert_item(
793
- item: _InsertItem,
794
- /,
795
- *,
796
- snake: bool = False,
797
- selected_or_all: Literal["selected", "all"] = "selected",
798
- ) -> Iterator[_NormalizedItem]:
799
- """Normalize an upsert item."""
800
- normalized = _normalize_insert_item(item, snake=snake)
801
- match selected_or_all:
802
- case "selected":
803
- for norm in normalized:
804
- values = {k: v for k, v in norm.mapping.items() if v is not None}
805
- yield _NormalizedItem(mapping=values, table=norm.table)
806
- case "all":
807
- yield from normalized
808
- case never:
809
- assert_never(never)
810
-
811
-
812
- ##
813
-
814
-
815
878
  def selectable_to_string(
816
879
  selectable: Selectable[Any], engine_or_conn: EngineOrConnectionOrAsync, /
817
880
  ) -> str:
@@ -836,134 +899,6 @@ class TablenameMixin:
836
899
  ##
837
900
 
838
901
 
839
- type _SelectedOrAll = Literal["selected", "all"]
840
-
841
-
842
- async def upsert_items(
843
- engine: AsyncEngine,
844
- /,
845
- *items: _InsertItem,
846
- snake: bool = False,
847
- selected_or_all: _SelectedOrAll = "selected",
848
- chunk_size_frac: float = CHUNK_SIZE_FRAC,
849
- assume_tables_exist: bool = False,
850
- timeout_create: Delta | None = None,
851
- error_create: type[Exception] = TimeoutError,
852
- timeout_insert: Delta | None = None,
853
- error_insert: type[Exception] = TimeoutError,
854
- ) -> None:
855
- """Upsert a set of items into a database.
856
-
857
- These can be one of the following:
858
- - pair of dict & table/class: {k1=v1, k2=v2, ...), table_cls
859
- - pair of list of dicts & table/class: [{k1=v11, k2=v12, ...},
860
- {k1=v21, k2=v22, ...},
861
- ...], table/class
862
- - list of pairs of dict & table/class: [({k1=v11, k2=v12, ...}, table_cls1),
863
- ({k1=v21, k2=v22, ...}, table_cls2),
864
- ...]
865
- - mapped class: Obj(k1=v1, k2=v2, ...)
866
- - list of mapped classes: [Obj(k1=v11, k2=v12, ...),
867
- Obj(k1=v21, k2=v22, ...),
868
- ...]
869
- """
870
-
871
- def build_insert(
872
- table: Table, values: Iterable[StrMapping], /
873
- ) -> tuple[Insert, None]:
874
- ups = _upsert_items_build(
875
- engine, table, values, selected_or_all=selected_or_all
876
- )
877
- return ups, None
878
-
879
- try:
880
- prepared = _prepare_insert_or_upsert_items(
881
- partial(
882
- _normalize_upsert_item, snake=snake, selected_or_all=selected_or_all
883
- ),
884
- engine,
885
- build_insert,
886
- *items,
887
- chunk_size_frac=chunk_size_frac,
888
- )
889
- except _PrepareInsertOrUpsertItemsError as error:
890
- raise UpsertItemsError(item=error.item) from None
891
- if not assume_tables_exist:
892
- await ensure_tables_created(
893
- engine, *prepared.tables, timeout=timeout_create, error=error_create
894
- )
895
- for ups, _ in prepared.yield_pairs():
896
- async with yield_connection(
897
- engine, timeout=timeout_insert, error=error_insert
898
- ) as conn:
899
- _ = await conn.execute(ups)
900
-
901
-
902
- def _upsert_items_build(
903
- engine: AsyncEngine,
904
- table: Table,
905
- values: Iterable[StrMapping],
906
- /,
907
- *,
908
- selected_or_all: Literal["selected", "all"] = "selected",
909
- ) -> Insert:
910
- values = list(values)
911
- keys = merge_sets(*values)
912
- dict_nones = dict.fromkeys(keys)
913
- values = [{**dict_nones, **v} for v in values]
914
- match _get_dialect(engine):
915
- case "postgresql": # skipif-ci-and-not-linux
916
- insert = postgresql_insert
917
- case "sqlite":
918
- insert = sqlite_insert
919
- case "mssql" | "mysql" | "oracle" as dialect: # pragma: no cover
920
- raise NotImplementedError(dialect)
921
- case never:
922
- assert_never(never)
923
- ins = insert(table).values(values)
924
- primary_key = cast("Any", table.primary_key)
925
- return _upsert_items_apply_on_conflict_do_update(
926
- values, ins, primary_key, selected_or_all=selected_or_all
927
- )
928
-
929
-
930
- def _upsert_items_apply_on_conflict_do_update(
931
- values: Iterable[StrMapping],
932
- insert: postgresql_Insert | sqlite_Insert,
933
- primary_key: PrimaryKeyConstraint,
934
- /,
935
- *,
936
- selected_or_all: Literal["selected", "all"] = "selected",
937
- ) -> Insert:
938
- match selected_or_all:
939
- case "selected":
940
- columns = merge_sets(*values)
941
- case "all":
942
- columns = {c.name for c in insert.excluded}
943
- case never:
944
- assert_never(never)
945
- set_ = {c: getattr(insert.excluded, c) for c in columns}
946
- match insert:
947
- case postgresql_Insert(): # skipif-ci
948
- return insert.on_conflict_do_update(constraint=primary_key, set_=set_)
949
- case sqlite_Insert():
950
- return insert.on_conflict_do_update(index_elements=primary_key, set_=set_)
951
- case never:
952
- assert_never(never)
953
-
954
-
955
- @dataclass(kw_only=True, slots=True)
956
- class UpsertItemsError(Exception):
957
- item: _InsertItem
958
-
959
- @override
960
- def __str__(self) -> str:
961
- return f"Item must be valid; got {self.item}"
962
-
963
-
964
- ##
965
-
966
-
967
902
  @asynccontextmanager
968
903
  async def yield_connection(
969
904
  engine: AsyncEngine,
@@ -1207,84 +1142,6 @@ def _orm_inst_to_dict_predicate(
1207
1142
  ##
1208
1143
 
1209
1144
 
1210
- @dataclass(kw_only=True, slots=True)
1211
- class _PrepareInsertOrUpsertItems:
1212
- mapping: dict[Table, list[StrMapping]] = field(default_factory=dict)
1213
- yield_pairs: Callable[[], Iterator[tuple[Insert, Any]]]
1214
-
1215
- @property
1216
- def tables(self) -> Sequence[Table]:
1217
- return list(self.mapping)
1218
-
1219
-
1220
- def _prepare_insert_or_upsert_items(
1221
- normalize_item: Callable[[_InsertItem], Iterable[_NormalizedItem]],
1222
- engine: AsyncEngine,
1223
- build_insert: Callable[[Table, Iterable[StrMapping]], tuple[Insert, Any]],
1224
- /,
1225
- *items: Any,
1226
- chunk_size_frac: float = CHUNK_SIZE_FRAC,
1227
- ) -> _PrepareInsertOrUpsertItems:
1228
- """Prepare a set of insert/upsert items."""
1229
- mapping: defaultdict[Table, list[StrMapping]] = defaultdict(list)
1230
- lengths: set[int] = set()
1231
- try:
1232
- for item in items:
1233
- for normed in normalize_item(item):
1234
- mapping[normed.table].append(normed.mapping)
1235
- lengths.add(len(normed.mapping))
1236
- except _NormalizeInsertItemError as error:
1237
- raise _PrepareInsertOrUpsertItemsError(item=error.item) from None
1238
- merged: dict[Table, list[StrMapping]] = {
1239
- table: _prepare_insert_or_upsert_items_merge_items(table, values)
1240
- for table, values in mapping.items()
1241
- }
1242
-
1243
- def yield_pairs() -> Iterator[tuple[Insert, None]]:
1244
- for table, values in merged.items():
1245
- chunk_size = get_chunk_size(engine, table, chunk_size_frac=chunk_size_frac)
1246
- for chunk in chunked(values, chunk_size):
1247
- yield build_insert(table, chunk)
1248
-
1249
- return _PrepareInsertOrUpsertItems(mapping=mapping, yield_pairs=yield_pairs)
1250
-
1251
-
1252
- @dataclass(kw_only=True, slots=True)
1253
- class _PrepareInsertOrUpsertItemsError(Exception):
1254
- item: Any
1255
-
1256
- @override
1257
- def __str__(self) -> str:
1258
- return f"Item must be valid; got {self.item}"
1259
-
1260
-
1261
- def _prepare_insert_or_upsert_items_merge_items(
1262
- table: Table, items: Iterable[StrMapping], /
1263
- ) -> list[StrMapping]:
1264
- columns = list(yield_primary_key_columns(table))
1265
- col_names = [c.name for c in columns]
1266
- cols_auto = {c.name for c in columns if c.autoincrement in {True, "auto"}}
1267
- cols_non_auto = set(col_names) - cols_auto
1268
- mapping: defaultdict[tuple[Hashable, ...], list[StrMapping]] = defaultdict(list)
1269
- unchanged: list[StrMapping] = []
1270
- for item in items:
1271
- check_subset(cols_non_auto, item)
1272
- has_all_auto = set(cols_auto).issubset(item)
1273
- if has_all_auto:
1274
- pkey = tuple(item[k] for k in col_names)
1275
- rest: StrMapping = {k: v for k, v in item.items() if k not in col_names}
1276
- mapping[pkey].append(rest)
1277
- else:
1278
- unchanged.append(item)
1279
- merged = {k: merge_str_mappings(*v) for k, v in mapping.items()}
1280
- return [
1281
- dict(zip(col_names, k, strict=True)) | dict(v) for k, v in merged.items()
1282
- ] + unchanged
1283
-
1284
-
1285
- ##
1286
-
1287
-
1288
1145
  def _tuple_to_mapping(
1289
1146
  values: tuple[Any, ...], table_or_orm: TableOrORMInstOrClass, /
1290
1147
  ) -> dict[str, Any]:
@@ -1303,7 +1160,6 @@ __all__ = [
1303
1160
  "GetTableError",
1304
1161
  "InsertItemsError",
1305
1162
  "TablenameMixin",
1306
- "UpsertItemsError",
1307
1163
  "check_connect",
1308
1164
  "check_connect_async",
1309
1165
  "check_engine",
@@ -1329,7 +1185,6 @@ __all__ = [
1329
1185
  "is_table_or_orm",
1330
1186
  "migrate_data",
1331
1187
  "selectable_to_string",
1332
- "upsert_items",
1333
1188
  "yield_connection",
1334
1189
  "yield_primary_key_columns",
1335
1190
  ]
@@ -4,7 +4,7 @@ import datetime as dt
4
4
  import decimal
5
5
  from contextlib import suppress
6
6
  from dataclasses import dataclass
7
- from typing import TYPE_CHECKING, Any, Literal, assert_never, cast, overload, override
7
+ from typing import TYPE_CHECKING, Any, cast, overload, override
8
8
  from uuid import UUID
9
9
 
10
10
  import polars as pl
@@ -44,7 +44,6 @@ from utilities.sqlalchemy import (
44
44
  get_chunk_size,
45
45
  get_columns,
46
46
  insert_items,
47
- upsert_items,
48
47
  )
49
48
  from utilities.text import snake_case
50
49
  from utilities.typing import is_subclass_gen
@@ -75,9 +74,9 @@ async def insert_dataframe(
75
74
  /,
76
75
  *,
77
76
  snake: bool = False,
77
+ is_upsert: bool = False,
78
78
  chunk_size_frac: float = CHUNK_SIZE_FRAC,
79
79
  assume_tables_exist: bool = False,
80
- upsert: Literal["selected", "all"] | None = None,
81
80
  timeout_create: TimeDelta | None = None,
82
81
  error_create: type[Exception] = TimeoutError,
83
82
  timeout_insert: TimeDelta | None = None,
@@ -87,43 +86,25 @@ async def insert_dataframe(
87
86
  mapping = _insert_dataframe_map_df_schema_to_table(
88
87
  df.schema, table_or_orm, snake=snake
89
88
  )
90
- items = df.select(*mapping).rename(mapping).to_dicts()
89
+ items = df.select(*mapping).rename(mapping).rows(named=True)
91
90
  if len(items) == 0:
92
- if not df.is_empty():
93
- raise InsertDataFrameError(df=df)
94
91
  if not assume_tables_exist:
95
92
  await ensure_tables_created(
96
93
  engine, table_or_orm, timeout=timeout_create, error=error_create
97
94
  )
98
95
  return
99
- match upsert:
100
- case None:
101
- await insert_items(
102
- engine,
103
- (items, table_or_orm),
104
- snake=snake,
105
- chunk_size_frac=chunk_size_frac,
106
- assume_tables_exist=assume_tables_exist,
107
- timeout_create=timeout_create,
108
- error_create=error_create,
109
- timeout_insert=timeout_insert,
110
- error_insert=error_insert,
111
- )
112
- case "selected" | "all" as selected_or_all: # skipif-ci-and-not-linux
113
- await upsert_items(
114
- engine,
115
- (items, table_or_orm),
116
- snake=snake,
117
- chunk_size_frac=chunk_size_frac,
118
- selected_or_all=selected_or_all,
119
- assume_tables_exist=assume_tables_exist,
120
- timeout_create=timeout_create,
121
- error_create=error_create,
122
- timeout_insert=timeout_insert,
123
- error_insert=error_insert,
124
- )
125
- case never:
126
- assert_never(never)
96
+ await insert_items(
97
+ engine,
98
+ (items, table_or_orm),
99
+ snake=snake,
100
+ is_upsert=is_upsert,
101
+ chunk_size_frac=chunk_size_frac,
102
+ assume_tables_exist=assume_tables_exist,
103
+ timeout_create=timeout_create,
104
+ error_create=error_create,
105
+ timeout_insert=timeout_insert,
106
+ error_insert=error_insert,
107
+ )
127
108
 
128
109
 
129
110
  def _insert_dataframe_map_df_schema_to_table(
@@ -207,15 +188,6 @@ def _insert_dataframe_check_df_and_db_types(
207
188
  )
208
189
 
209
190
 
210
- @dataclass(kw_only=True, slots=True)
211
- class InsertDataFrameError(Exception):
212
- df: DataFrame
213
-
214
- @override
215
- def __str__(self) -> str:
216
- return f"Non-empty DataFrame must resolve to at least 1 item\n\n{self.df}"
217
-
218
-
219
191
  @overload
220
192
  async def select_to_dataframe(
221
193
  sel: Select[Any],
@@ -229,6 +201,7 @@ async def select_to_dataframe(
229
201
  in_clauses_chunk_size: int | None = None,
230
202
  chunk_size_frac: float = CHUNK_SIZE_FRAC,
231
203
  timeout: Delta | None = None,
204
+ error: MaybeType[BaseException] = TimeoutError,
232
205
  **kwargs: Any,
233
206
  ) -> DataFrame: ...
234
207
  @overload
@@ -244,6 +217,7 @@ async def select_to_dataframe(
244
217
  in_clauses_chunk_size: int | None = None,
245
218
  chunk_size_frac: float = CHUNK_SIZE_FRAC,
246
219
  timeout: Delta | None = None,
220
+ error: MaybeType[BaseException] = TimeoutError,
247
221
  **kwargs: Any,
248
222
  ) -> Iterable[DataFrame]: ...
249
223
  @overload
@@ -259,6 +233,7 @@ async def select_to_dataframe(
259
233
  in_clauses_chunk_size: int | None = None,
260
234
  chunk_size_frac: float = CHUNK_SIZE_FRAC,
261
235
  timeout: Delta | None = None,
236
+ error: MaybeType[BaseException] = TimeoutError,
262
237
  **kwargs: Any,
263
238
  ) -> AsyncIterable[DataFrame]: ...
264
239
  @overload
@@ -274,6 +249,7 @@ async def select_to_dataframe(
274
249
  in_clauses_chunk_size: int | None = None,
275
250
  chunk_size_frac: float = CHUNK_SIZE_FRAC,
276
251
  timeout: Delta | None = None,
252
+ error: MaybeType[BaseException] = TimeoutError,
277
253
  **kwargs: Any,
278
254
  ) -> DataFrame | Iterable[DataFrame] | AsyncIterable[DataFrame]: ...
279
255
  async def select_to_dataframe(
@@ -439,4 +415,4 @@ def _select_to_dataframe_yield_selects_with_in_clauses(
439
415
  return (sel.where(in_col.in_(values)) for values in chunked(in_values, chunk_size))
440
416
 
441
417
 
442
- __all__ = ["InsertDataFrameError", "insert_dataframe", "select_to_dataframe"]
418
+ __all__ = ["insert_dataframe", "select_to_dataframe"]