shardproxy 1.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shardproxy-1.0.0/PKG-INFO +59 -0
- shardproxy-1.0.0/README.rst +50 -0
- shardproxy-1.0.0/pyproject.toml +53 -0
- shardproxy-1.0.0/src/shardproxy/__init__.py +24 -0
- shardproxy-1.0.0/src/shardproxy/ops.py +234 -0
- shardproxy-1.0.0/src/shardproxy/proxy.py +136 -0
- shardproxy-1.0.0/src/shardproxy/py.typed +0 -0
- shardproxy-1.0.0/src/shardproxy/types.py +68 -0
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: shardproxy
|
|
3
|
+
Version: 1.0.0
|
|
4
|
+
Summary: Sharding with SQLAlchemy
|
|
5
|
+
Classifier: License :: OSI Approved :: ISC License (ISCL)
|
|
6
|
+
Requires-Dist: sqlalchemy[asyncio]>=2.0
|
|
7
|
+
Requires-Python: >=3.12
|
|
8
|
+
Description-Content-Type: text/x-rst
|
|
9
|
+
|
|
10
|
+
shardproxy
|
|
11
|
+
==========
|
|
12
|
+
|
|
13
|
+
The module consists of two layers:
|
|
14
|
+
|
|
15
|
+
- ShardProxy class that just manages connections.
|
|
16
|
+
- ops module that implements helper classes for parallelism and result processing.
|
|
17
|
+
|
|
18
|
+
ShardProxy
|
|
19
|
+
----------
|
|
20
|
+
|
|
21
|
+
Initialize:
|
|
22
|
+
|
|
23
|
+
.. code:: python
|
|
24
|
+
|
|
25
|
+
SHARD_URLS = [
|
|
26
|
+
"postgresql+asyncpg://server/shard0",
|
|
27
|
+
"postgresql+asyncpg://server/shard1",
|
|
28
|
+
]
|
|
29
|
+
proxy = ShardProxy(SHARD_URLS)
|
|
30
|
+
|
|
31
|
+
Process request:
|
|
32
|
+
|
|
33
|
+
.. code:: python
|
|
34
|
+
|
|
35
|
+
async def fetch_user(user_id: UUID):
|
|
36
|
+
async with proxy.connect(user_id) as conn:
|
|
37
|
+
async with conn.begin():
|
|
38
|
+
stmt = select(User).where(User.id == user_id)
|
|
39
|
+
res = await conn.execute(stmt)
|
|
40
|
+
return res.all()
|
|
41
|
+
|
|
42
|
+
Operations
|
|
43
|
+
----------
|
|
44
|
+
|
|
45
|
+
Select a bunch of IDs:
|
|
46
|
+
|
|
47
|
+
.. code:: python
|
|
48
|
+
|
|
49
|
+
UserRow = Row[tuple[User]]
|
|
50
|
+
|
|
51
|
+
class SelectManyUsers(RunOnKeysOperation[UserRow]):
|
|
52
|
+
async def process(
|
|
53
|
+
self, conn: AsyncConnection, keys: list[UUID]
|
|
54
|
+
) -> ResultRows[UserRow]:
|
|
55
|
+
stmt = select(User).where(User.id.in_(keys))
|
|
56
|
+
res = await conn.execute(stmt)
|
|
57
|
+
return res.all()
|
|
58
|
+
|
|
59
|
+
users = await SelectManyUsers().run(proxy, user_ids)
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
shardproxy
|
|
2
|
+
==========
|
|
3
|
+
|
|
4
|
+
The module consists of two layers:
|
|
5
|
+
|
|
6
|
+
- ShardProxy class that just manages connections.
|
|
7
|
+
- ops module that implements helper classes for parallelism and result processing.
|
|
8
|
+
|
|
9
|
+
ShardProxy
|
|
10
|
+
----------
|
|
11
|
+
|
|
12
|
+
Initialize:
|
|
13
|
+
|
|
14
|
+
.. code:: python
|
|
15
|
+
|
|
16
|
+
SHARD_URLS = [
|
|
17
|
+
"postgresql+asyncpg://server/shard0",
|
|
18
|
+
"postgresql+asyncpg://server/shard1",
|
|
19
|
+
]
|
|
20
|
+
proxy = ShardProxy(SHARD_URLS)
|
|
21
|
+
|
|
22
|
+
Process request:
|
|
23
|
+
|
|
24
|
+
.. code:: python
|
|
25
|
+
|
|
26
|
+
async def fetch_user(user_id: UUID):
|
|
27
|
+
async with proxy.connect(user_id) as conn:
|
|
28
|
+
async with conn.begin():
|
|
29
|
+
stmt = select(User).where(User.id == user_id)
|
|
30
|
+
res = await conn.execute(stmt)
|
|
31
|
+
return res.all()
|
|
32
|
+
|
|
33
|
+
Operations
|
|
34
|
+
----------
|
|
35
|
+
|
|
36
|
+
Select a bunch of IDs:
|
|
37
|
+
|
|
38
|
+
.. code:: python
|
|
39
|
+
|
|
40
|
+
UserRow = Row[tuple[User]]
|
|
41
|
+
|
|
42
|
+
class SelectManyUsers(RunOnKeysOperation[UserRow]):
|
|
43
|
+
async def process(
|
|
44
|
+
self, conn: AsyncConnection, keys: list[UUID]
|
|
45
|
+
) -> ResultRows[UserRow]:
|
|
46
|
+
stmt = select(User).where(User.id.in_(keys))
|
|
47
|
+
res = await conn.execute(stmt)
|
|
48
|
+
return res.all()
|
|
49
|
+
|
|
50
|
+
users = await SelectManyUsers().run(proxy, user_ids)
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "shardproxy"
|
|
3
|
+
version = "1.0.0"
|
|
4
|
+
description = "Sharding with SQLAlchemy"
|
|
5
|
+
readme = "README.rst"
|
|
6
|
+
classifiers = ["License :: OSI Approved :: ISC License (ISCL)"]
|
|
7
|
+
requires-python = ">=3.12"
|
|
8
|
+
dependencies = ["sqlalchemy[asyncio]>=2.0"]
|
|
9
|
+
|
|
10
|
+
[dependency-groups]
|
|
11
|
+
dev = [
|
|
12
|
+
{include-group = "docs"},
|
|
13
|
+
{include-group = "lint"},
|
|
14
|
+
{include-group = "test"},
|
|
15
|
+
]
|
|
16
|
+
docs = ["sphinx"]
|
|
17
|
+
lint = ["ruff", "isort", "mypy"]
|
|
18
|
+
test = [
|
|
19
|
+
"asyncpg",
|
|
20
|
+
"psycopg[binary,pool]",
|
|
21
|
+
"pytest",
|
|
22
|
+
"pytest-asyncio",
|
|
23
|
+
"pytest-cov",
|
|
24
|
+
"coverage",
|
|
25
|
+
]
|
|
26
|
+
|
|
27
|
+
[build-system]
|
|
28
|
+
requires = ["uv_build>=0.11.1,<0.12.0"]
|
|
29
|
+
build-backend = "uv_build"
|
|
30
|
+
|
|
31
|
+
[tool.mypy]
|
|
32
|
+
python_version = "3.12"
|
|
33
|
+
strict = true
|
|
34
|
+
disallow_any_decorated = true
|
|
35
|
+
disallow_any_unimported = true
|
|
36
|
+
#disallow_any_explicit = true
|
|
37
|
+
#disallow_any_expr = true
|
|
38
|
+
|
|
39
|
+
[tool.isort]
|
|
40
|
+
known_first_party = ["shardproxy"]
|
|
41
|
+
known_local_folder = ["helpers"]
|
|
42
|
+
multi_line_output = 5
|
|
43
|
+
balanced_wrapping = true
|
|
44
|
+
include_trailing_comma = true
|
|
45
|
+
atomic = true
|
|
46
|
+
quiet = true
|
|
47
|
+
|
|
48
|
+
[[tool.uv.index]]
|
|
49
|
+
name = "testpypi"
|
|
50
|
+
url = "https://test.pypi.org/simple/"
|
|
51
|
+
publish-url = "https://test.pypi.org/legacy/"
|
|
52
|
+
explicit = true
|
|
53
|
+
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""Sharding with SQLAlchemy."""
|
|
2
|
+
|
|
3
|
+
from .ops import (
|
|
4
|
+
RunOnAllOperation,
|
|
5
|
+
RunOnAnyOperation,
|
|
6
|
+
RunOnArgsOperation,
|
|
7
|
+
RunOnKeysOperation,
|
|
8
|
+
RunOnOneOperation,
|
|
9
|
+
)
|
|
10
|
+
from .proxy import ShardProxy
|
|
11
|
+
from .types import AsyncConnection, AsyncEngine, DBTask, ResultRows
|
|
12
|
+
|
|
13
|
+
__all__ = (
|
|
14
|
+
"ShardProxy",
|
|
15
|
+
"AsyncEngine",
|
|
16
|
+
"AsyncConnection",
|
|
17
|
+
"ResultRows",
|
|
18
|
+
"DBTask",
|
|
19
|
+
"RunOnAllOperation",
|
|
20
|
+
"RunOnKeysOperation",
|
|
21
|
+
"RunOnArgsOperation",
|
|
22
|
+
"RunOnAnyOperation",
|
|
23
|
+
"RunOnOneOperation",
|
|
24
|
+
)
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
"""Helper classes for multi-shard operations."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from collections.abc import Sequence
|
|
6
|
+
from uuid import UUID
|
|
7
|
+
|
|
8
|
+
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
|
|
9
|
+
|
|
10
|
+
from .proxy import ShardProxy
|
|
11
|
+
from .types import DBTask, ResultRows
|
|
12
|
+
|
|
13
|
+
__all__ = (
|
|
14
|
+
"RunOnKeysOperation",
|
|
15
|
+
"RunOnArgsOperation",
|
|
16
|
+
"RunOnAllOperation",
|
|
17
|
+
"RunOnOneOperation",
|
|
18
|
+
"RunOnAnyOperation",
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class RunOnKeysOperation[TRow](ABC):
|
|
23
|
+
"""Runs operation over given keys.
|
|
24
|
+
|
|
25
|
+
Usage::
|
|
26
|
+
|
|
27
|
+
class MyOp(RunOnKeysOperation[MyRow]):
|
|
28
|
+
async def process(
|
|
29
|
+
self, conn: AsyncConnection, keys: list[UUID]
|
|
30
|
+
) -> ResultRows[MyRow]:
|
|
31
|
+
return ...
|
|
32
|
+
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
#: Whether connection is in transaction context when given to .process().
|
|
36
|
+
USE_TRANSACTION = True
|
|
37
|
+
|
|
38
|
+
@abstractmethod
|
|
39
|
+
async def process(
|
|
40
|
+
self, conn: AsyncConnection, keys: list[UUID]
|
|
41
|
+
) -> ResultRows[TRow]:
|
|
42
|
+
"""Business logic to execute on each shard."""
|
|
43
|
+
raise NotImplementedError
|
|
44
|
+
|
|
45
|
+
async def run(
|
|
46
|
+
self, proxy: ShardProxy, shard_keys: Sequence[UUID]
|
|
47
|
+
) -> ResultRows[TRow]:
|
|
48
|
+
"""Split keys for each shard and run .process() on them."""
|
|
49
|
+
|
|
50
|
+
tasks: list[DBTask[ResultRows[TRow]]] = []
|
|
51
|
+
for engine, keys in proxy.spread_keys(shard_keys):
|
|
52
|
+
task = asyncio.create_task(self._handle_process(engine, keys))
|
|
53
|
+
tasks.append(task)
|
|
54
|
+
|
|
55
|
+
return await self._collect(tasks)
|
|
56
|
+
|
|
57
|
+
async def _handle_process(
|
|
58
|
+
self, engine: AsyncEngine, keys: list[UUID]
|
|
59
|
+
) -> ResultRows[TRow]:
|
|
60
|
+
async with engine.connect() as conn:
|
|
61
|
+
if self.USE_TRANSACTION:
|
|
62
|
+
async with conn.begin():
|
|
63
|
+
return await self.process(conn, keys)
|
|
64
|
+
else:
|
|
65
|
+
return await self.process(conn, keys)
|
|
66
|
+
|
|
67
|
+
async def _collect(self, tasks: list[DBTask[ResultRows[TRow]]]) -> ResultRows[TRow]:
|
|
68
|
+
shard_rows = await asyncio.gather(*tasks)
|
|
69
|
+
res: list[TRow] = []
|
|
70
|
+
for rows in shard_rows:
|
|
71
|
+
res.extend(rows)
|
|
72
|
+
return res
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class RunOnArgsOperation[TRow, TArg](ABC):
|
|
76
|
+
"""Runs operation over given arguments.
|
|
77
|
+
|
|
78
|
+
Usage::
|
|
79
|
+
|
|
80
|
+
class MyOp(RunOnArgsOperation[MyRow, MyArg]):
|
|
81
|
+
async def process(
|
|
82
|
+
self, conn: AsyncConnection, args: dict[UUID, MyArg]
|
|
83
|
+
) -> ResultRows[MyRow]:
|
|
84
|
+
return ...
|
|
85
|
+
|
|
86
|
+
"""
|
|
87
|
+
|
|
88
|
+
#: Whether connection is in transaction context when given to .process().
|
|
89
|
+
USE_TRANSACTION = True
|
|
90
|
+
|
|
91
|
+
@abstractmethod
|
|
92
|
+
async def process(
|
|
93
|
+
self, conn: AsyncConnection, args: dict[UUID, TArg]
|
|
94
|
+
) -> ResultRows[TRow]:
|
|
95
|
+
"""Business logic to execute on each shard."""
|
|
96
|
+
raise NotImplementedError
|
|
97
|
+
|
|
98
|
+
async def run(self, proxy: ShardProxy, args: dict[UUID, TArg]) -> ResultRows[TRow]:
|
|
99
|
+
"""Split args for each shard and run .process() with them."""
|
|
100
|
+
|
|
101
|
+
tasks: list[DBTask[ResultRows[TRow]]] = []
|
|
102
|
+
for engine, shard_args in proxy.spread_args(args):
|
|
103
|
+
task = asyncio.create_task(self._handle_process(engine, shard_args))
|
|
104
|
+
tasks.append(task)
|
|
105
|
+
|
|
106
|
+
return await self._collect(tasks)
|
|
107
|
+
|
|
108
|
+
async def _handle_process(
|
|
109
|
+
self, engine: AsyncEngine, args: dict[UUID, TArg]
|
|
110
|
+
) -> ResultRows[TRow]:
|
|
111
|
+
async with engine.connect() as conn:
|
|
112
|
+
if self.USE_TRANSACTION:
|
|
113
|
+
async with conn.begin():
|
|
114
|
+
return await self.process(conn, args)
|
|
115
|
+
else:
|
|
116
|
+
return await self.process(conn, args)
|
|
117
|
+
|
|
118
|
+
async def _collect(self, tasks: list[DBTask[ResultRows[TRow]]]) -> ResultRows[TRow]:
|
|
119
|
+
shard_rows = await asyncio.gather(*tasks)
|
|
120
|
+
res: list[TRow] = []
|
|
121
|
+
for rows in shard_rows:
|
|
122
|
+
res.extend(rows)
|
|
123
|
+
return res
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class RunOnAllOperation[TRow](ABC):
|
|
127
|
+
"""Runs operation over all shards.
|
|
128
|
+
|
|
129
|
+
Usage::
|
|
130
|
+
|
|
131
|
+
class MyOp(RunOnAllOperation[MyRow]):
|
|
132
|
+
async def process(
|
|
133
|
+
self, conn: AsyncConnection
|
|
134
|
+
) -> ResultRows[MyRow]:
|
|
135
|
+
return ...
|
|
136
|
+
|
|
137
|
+
"""
|
|
138
|
+
|
|
139
|
+
#: Whether connection is in transaction context when given to .process().
|
|
140
|
+
USE_TRANSACTION = True
|
|
141
|
+
|
|
142
|
+
@abstractmethod
|
|
143
|
+
async def process(self, conn: AsyncConnection) -> ResultRows[TRow]:
|
|
144
|
+
"""Business logic to execute on each shard."""
|
|
145
|
+
raise NotImplementedError
|
|
146
|
+
|
|
147
|
+
async def run(self, proxy: ShardProxy) -> ResultRows[TRow]:
|
|
148
|
+
"""Loop over all shards, execute .process() on each."""
|
|
149
|
+
|
|
150
|
+
tasks: list[DBTask[ResultRows[TRow]]] = []
|
|
151
|
+
for engine in proxy.shards:
|
|
152
|
+
task = asyncio.create_task(self._handle_process(engine))
|
|
153
|
+
tasks.append(task)
|
|
154
|
+
|
|
155
|
+
return await self._collect(tasks)
|
|
156
|
+
|
|
157
|
+
async def _handle_process(self, engine: AsyncEngine) -> ResultRows[TRow]:
|
|
158
|
+
async with engine.connect() as conn:
|
|
159
|
+
if self.USE_TRANSACTION:
|
|
160
|
+
async with conn.begin():
|
|
161
|
+
return await self.process(conn)
|
|
162
|
+
else:
|
|
163
|
+
return await self.process(conn)
|
|
164
|
+
|
|
165
|
+
async def _collect(self, tasks: list[DBTask[ResultRows[TRow]]]) -> ResultRows[TRow]:
|
|
166
|
+
shard_rows = await asyncio.gather(*tasks)
|
|
167
|
+
res: list[TRow] = []
|
|
168
|
+
for rows in shard_rows:
|
|
169
|
+
res.extend(rows)
|
|
170
|
+
return res
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class RunOnOneOperation[TRow](ABC):
|
|
174
|
+
"""Runs operation in single shard.
|
|
175
|
+
|
|
176
|
+
Usage::
|
|
177
|
+
|
|
178
|
+
class MyOp(RunOnOneOperation[MyRow]):
|
|
179
|
+
async def process(
|
|
180
|
+
self, conn: AsyncConnection, shard_key: UUID
|
|
181
|
+
) -> ResultRows[MyRow]:
|
|
182
|
+
return ...
|
|
183
|
+
|
|
184
|
+
"""
|
|
185
|
+
|
|
186
|
+
#: Whether connection is in transaction context when given to .process().
|
|
187
|
+
USE_TRANSACTION = True
|
|
188
|
+
|
|
189
|
+
@abstractmethod
|
|
190
|
+
async def process(self, conn: AsyncConnection, shard_key: UUID) -> ResultRows[TRow]:
|
|
191
|
+
"""Business logic to execute on shard."""
|
|
192
|
+
raise NotImplementedError
|
|
193
|
+
|
|
194
|
+
async def run(self, proxy: ShardProxy, shard_key: UUID) -> ResultRows[TRow]:
|
|
195
|
+
"""Pick shard and run .process() on it."""
|
|
196
|
+
|
|
197
|
+
async with proxy.connect(shard_key) as conn:
|
|
198
|
+
if self.USE_TRANSACTION:
|
|
199
|
+
async with conn.begin():
|
|
200
|
+
return await self.process(conn, shard_key)
|
|
201
|
+
else:
|
|
202
|
+
return await self.process(conn, shard_key)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
class RunOnAnyOperation[TRow](ABC):
|
|
206
|
+
"""Runs operation in single random shard.
|
|
207
|
+
|
|
208
|
+
Usage::
|
|
209
|
+
|
|
210
|
+
class MyOp(RunOnAnyOperation[MyRow]):
|
|
211
|
+
async def process(
|
|
212
|
+
self, conn: AsyncConnection
|
|
213
|
+
) -> ResultRows[MyRow]:
|
|
214
|
+
return ...
|
|
215
|
+
|
|
216
|
+
"""
|
|
217
|
+
|
|
218
|
+
#: Whether connection is in transaction context when given to .process().
|
|
219
|
+
USE_TRANSACTION = True
|
|
220
|
+
|
|
221
|
+
@abstractmethod
|
|
222
|
+
async def process(self, conn: AsyncConnection) -> ResultRows[TRow]:
|
|
223
|
+
"""Business logic to execute on shard."""
|
|
224
|
+
raise NotImplementedError
|
|
225
|
+
|
|
226
|
+
async def run(self, proxy: ShardProxy) -> ResultRows[TRow]:
|
|
227
|
+
"""Pick random shard and run .process() on it."""
|
|
228
|
+
|
|
229
|
+
async with proxy.connect_any() as conn:
|
|
230
|
+
if self.USE_TRANSACTION:
|
|
231
|
+
async with conn.begin():
|
|
232
|
+
return await self.process(conn)
|
|
233
|
+
else:
|
|
234
|
+
return await self.process(conn)
|
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
"""Connection manager for sharding."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Mapping, Sequence
|
|
4
|
+
from random import choice
|
|
5
|
+
from typing import Unpack
|
|
6
|
+
from uuid import UUID
|
|
7
|
+
|
|
8
|
+
from sqlalchemy import URL
|
|
9
|
+
from sqlalchemy.ext.asyncio import (
|
|
10
|
+
AsyncConnection,
|
|
11
|
+
AsyncEngine,
|
|
12
|
+
create_async_engine,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from .types import AsyncEngineParams
|
|
16
|
+
|
|
17
|
+
__all__ = ("ShardProxy",)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ShardProxy:
|
|
21
|
+
"""Manages AsyncEngine mapping for shards.
|
|
22
|
+
|
|
23
|
+
Takes list of shard URLs and generic create_async_engine()/create_engine()
|
|
24
|
+
keyword arguments.
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
_shards: list[AsyncEngine]
|
|
28
|
+
|
|
29
|
+
def __init__(
|
|
30
|
+
self, shard_urls: Sequence[str | URL], **kwargs: Unpack[AsyncEngineParams]
|
|
31
|
+
) -> None:
|
|
32
|
+
self._shards = []
|
|
33
|
+
for url in shard_urls:
|
|
34
|
+
engine = create_async_engine(url, **kwargs)
|
|
35
|
+
self._shards.append(engine)
|
|
36
|
+
|
|
37
|
+
def _get_shard(self, shard_key: UUID) -> int:
|
|
38
|
+
return shard_key.node % len(self._shards)
|
|
39
|
+
|
|
40
|
+
@property
|
|
41
|
+
def shards(self) -> Sequence[AsyncEngine]:
|
|
42
|
+
"""Read-only list of engines per shard.
|
|
43
|
+
|
|
44
|
+
Usage::
|
|
45
|
+
|
|
46
|
+
# run-on-all: serial mode
|
|
47
|
+
for engine in proxy.shards:
|
|
48
|
+
async with engine.connect() as conn:
|
|
49
|
+
async with conn.begin():
|
|
50
|
+
await conn.execute()
|
|
51
|
+
|
|
52
|
+
# run-on-all: parallel mode
|
|
53
|
+
async with asyncio.TaskGroup() as tg:
|
|
54
|
+
for engine in proxy.shards:
|
|
55
|
+
tg.create_task(process(engine, items))
|
|
56
|
+
"""
|
|
57
|
+
return self._shards
|
|
58
|
+
|
|
59
|
+
def connect(self, shard_key: UUID) -> AsyncConnection:
|
|
60
|
+
"""Returns connection for specific shard as context manager.
|
|
61
|
+
|
|
62
|
+
Usage::
|
|
63
|
+
|
|
64
|
+
async with proxy.connect(key) as conn:
|
|
65
|
+
async with conn.begin():
|
|
66
|
+
...
|
|
67
|
+
"""
|
|
68
|
+
idx = self._get_shard(shard_key)
|
|
69
|
+
return self._shards[idx].connect()
|
|
70
|
+
|
|
71
|
+
def connect_any[T](self) -> AsyncConnection:
|
|
72
|
+
"""Returns connection to random shard as context manager.
|
|
73
|
+
|
|
74
|
+
Usage::
|
|
75
|
+
|
|
76
|
+
async with proxy.connect_any() as conn:
|
|
77
|
+
async with conn.begin():
|
|
78
|
+
...
|
|
79
|
+
"""
|
|
80
|
+
engine = choice(self._shards)
|
|
81
|
+
return engine.connect()
|
|
82
|
+
|
|
83
|
+
def spread_args[A](
|
|
84
|
+
self, args: Mapping[UUID, A]
|
|
85
|
+
) -> Sequence[tuple[AsyncEngine, dict[UUID, A]]]:
|
|
86
|
+
"""Spread keys with arguments over shards.
|
|
87
|
+
|
|
88
|
+
Usage::
|
|
89
|
+
|
|
90
|
+
# serial mode
|
|
91
|
+
for engine, shard_items in proxy.split_args(all_items):
|
|
92
|
+
async with engine.connect() as conn:
|
|
93
|
+
async with conn.begin():
|
|
94
|
+
await conn.execute()
|
|
95
|
+
|
|
96
|
+
# parallel mode
|
|
97
|
+
async with asyncio.TaskGroup() as tg:
|
|
98
|
+
for engine, items in proxy.split_args(all_items):
|
|
99
|
+
tg.create_task(process(engine, items))
|
|
100
|
+
"""
|
|
101
|
+
split: dict[int, dict[UUID, A]] = {}
|
|
102
|
+
for key, arg in args.items():
|
|
103
|
+
nr = self._get_shard(key)
|
|
104
|
+
target = split.get(nr)
|
|
105
|
+
if target is None:
|
|
106
|
+
split[nr] = {key: arg}
|
|
107
|
+
else:
|
|
108
|
+
target[key] = arg
|
|
109
|
+
return [(self._shards[nr], target) for nr, target in split.items()]
|
|
110
|
+
|
|
111
|
+
def spread_keys(
|
|
112
|
+
self, keys: Sequence[UUID]
|
|
113
|
+
) -> Sequence[tuple[AsyncEngine, list[UUID]]]:
|
|
114
|
+
"""Spread keys over shards.
|
|
115
|
+
|
|
116
|
+
Usage::
|
|
117
|
+
|
|
118
|
+
# serial mode
|
|
119
|
+
for engine, shard_keys in proxy.split_keys(all_keys):
|
|
120
|
+
async with engine.connect() as conn:
|
|
121
|
+
async with conn.begin():
|
|
122
|
+
await conn.execute()
|
|
123
|
+
|
|
124
|
+
# parallel mode
|
|
125
|
+
async with asyncio.TaskGroup() as tg:
|
|
126
|
+
for engine, shard_keys in proxy.split_keys(all_keys):
|
|
127
|
+
tg.create_task(process(engine, shard_keys))
|
|
128
|
+
"""
|
|
129
|
+
split: dict[int, set[UUID]] = {}
|
|
130
|
+
for key in keys:
|
|
131
|
+
nr = self._get_shard(key)
|
|
132
|
+
target = split.get(nr)
|
|
133
|
+
if target is None:
|
|
134
|
+
target = split[nr] = set()
|
|
135
|
+
target.add(key)
|
|
136
|
+
return [(self._shards[nr], list(target)) for nr, target in split.items()]
|
|
File without changes
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
"""Common types"""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from collections.abc import Sequence
|
|
5
|
+
from typing import Any, Callable, Literal, Type, TypedDict
|
|
6
|
+
|
|
7
|
+
from sqlalchemy.engine.interfaces import (
|
|
8
|
+
CoreExecuteOptionsParameter,
|
|
9
|
+
IsolationLevel,
|
|
10
|
+
)
|
|
11
|
+
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
|
|
12
|
+
from sqlalchemy.pool import Pool
|
|
13
|
+
|
|
14
|
+
__all__ = (
|
|
15
|
+
"AsyncConnection",
|
|
16
|
+
"AsyncEngine",
|
|
17
|
+
"AsyncEngineParams",
|
|
18
|
+
"ResultRows",
|
|
19
|
+
"DBTask",
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# Rows produced by an operation's .process()/.run(): any sequence of TRow.
type ResultRows[TRow] = Sequence[TRow]
|
|
24
|
+
|
|
25
|
+
# Alias for an asyncio.Task that yields TResult.
type DBTask[TResult] = asyncio.Task[TResult]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class AsyncEngineParams(TypedDict, total=False):
    """Helper struct to describe common parameters for create_async_engine.

    It may be incomplete - arguments from here are
    passed to AsyncEngine, Dialect and Pool.

    Some values are deliberately hidden as they will not work with ShardProxy.
    All keys are optional (``total=False``).
    """

    # async_creator - does not work with sharding
    # connect_args - does not work with sharding
    echo: bool | Literal["debug"] | None
    echo_pool: bool | Literal["debug"] | None
    enable_from_linting: bool
    execution_options: CoreExecuteOptionsParameter
    hide_parameters: bool
    insertmanyvalues_page_size: int
    isolation_level: IsolationLevel
    json_deserializer: Callable[..., Any]
    json_serializer: Callable[..., Any]
    label_length: int | None
    logging_name: str
    max_identifier_length: int
    max_overflow: int
    # module - does not work with sharding
    paramstyle: Literal[
        "qmark", "numeric", "named", "format", "pyformat", "numeric_dollar"
    ]
    # pool - does not work with sharding
    poolclass: Type[Pool]
    pool_logging_name: str
    pool_pre_ping: bool
    pool_size: int
    pool_recycle: int
    pool_reset_on_return: Literal["rollback", "commit"] | None
    # SQLAlchemy accepts fractional seconds here; int callers remain valid.
    pool_timeout: float
    # Previously misspelled "pool_use_iifo", which is not a create_engine()
    # argument; the SQLAlchemy parameter is "pool_use_lifo".
    pool_use_lifo: bool
    plugins: list[str]
    query_cache_size: int
    skip_autocommit_rollback: bool
    use_insertmanyvalues: bool
|