shardproxy 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,59 @@
1
+ Metadata-Version: 2.3
2
+ Name: shardproxy
3
+ Version: 1.0.0
4
+ Summary: Sharding with SQLAlchemy
5
+ Classifier: License :: OSI Approved :: ISC License (ISCL)
6
+ Requires-Dist: sqlalchemy[asyncio]>=2.0
7
+ Requires-Python: >=3.12
8
+ Description-Content-Type: text/x-rst
9
+
10
+ shardproxy
11
+ ==========
12
+
13
+ The module consists of two layers:
14
+
15
+ - ShardProxy class that just manages connections.
16
+ - ops module that implements helper classes for parallelism and result processing.
17
+
18
+ ShardProxy
19
+ ----------
20
+
21
+ Initialize:
22
+
23
+ .. code:: python
24
+
25
+ SHARD_URLS = [
26
+ "postgresql+asyncpg://server/shard0",
27
+ "postgresql+asyncpg://server/shard1",
28
+ ]
29
+ proxy = ShardProxy(SHARD_URLS)
30
+
31
+ Process request:
32
+
33
+ .. code:: python
34
+
35
+ async def fetch_user(user_id: UUID):
36
+ async with proxy.connect(user_id) as conn:
37
+ async with conn.begin():
38
+ stmt = select(User).where(User.id == user_id)
39
+ res = await conn.execute(stmt)
40
+ return res.all()
41
+
42
+ Operations
43
+ ----------
44
+
45
+ Select a bunch of users by their IDs:
46
+
47
+ .. code:: python
48
+
49
+ UserRow = Row[tuple[User]]
50
+
51
+ class SelectManyUsers(RunOnKeysOperation[UserRow]):
52
+ async def process(
53
+ self, conn: AsyncConnection, keys: list[UUID]
54
+ ) -> ResultRows[UserRow]:
55
+ stmt = select(User).where(User.id.in_(keys))
56
+ res = await conn.execute(stmt)
57
+ return res.all()
58
+
59
+ users = await SelectManyUsers().run(proxy, user_ids)
@@ -0,0 +1,50 @@
1
+ shardproxy
2
+ ==========
3
+
4
+ The module consists of two layers:
5
+
6
+ - ShardProxy class that just manages connections.
7
+ - ops module that implements helper classes for parallelism and result processing.
8
+
9
+ ShardProxy
10
+ ----------
11
+
12
+ Initialize:
13
+
14
+ .. code:: python
15
+
16
+ SHARD_URLS = [
17
+ "postgresql+asyncpg://server/shard0",
18
+ "postgresql+asyncpg://server/shard1",
19
+ ]
20
+ proxy = ShardProxy(SHARD_URLS)
21
+
22
+ Process request:
23
+
24
+ .. code:: python
25
+
26
+ async def fetch_user(user_id: UUID):
27
+ async with proxy.connect(user_id) as conn:
28
+ async with conn.begin():
29
+ stmt = select(User).where(User.id == user_id)
30
+ res = await conn.execute(stmt)
31
+ return res.all()
32
+
33
+ Operations
34
+ ----------
35
+
36
+ Select a bunch of users by their IDs:
37
+
38
+ .. code:: python
39
+
40
+ UserRow = Row[tuple[User]]
41
+
42
+ class SelectManyUsers(RunOnKeysOperation[UserRow]):
43
+ async def process(
44
+ self, conn: AsyncConnection, keys: list[UUID]
45
+ ) -> ResultRows[UserRow]:
46
+ stmt = select(User).where(User.id.in_(keys))
47
+ res = await conn.execute(stmt)
48
+ return res.all()
49
+
50
+ users = await SelectManyUsers().run(proxy, user_ids)
@@ -0,0 +1,53 @@
1
+ [project]
2
+ name = "shardproxy"
3
+ version = "1.0.0"
4
+ description = "Sharding with SQLAlchemy"
5
+ readme = "README.rst"
6
+ classifiers = ["License :: OSI Approved :: ISC License (ISCL)"]
7
+ requires-python = ">=3.12"
8
+ dependencies = ["sqlalchemy[asyncio]>=2.0"]
9
+
10
+ [dependency-groups]
11
+ dev = [
12
+ {include-group = "docs"},
13
+ {include-group = "lint"},
14
+ {include-group = "test"},
15
+ ]
16
+ docs = ["sphinx"]
17
+ lint = ["ruff", "isort", "mypy"]
18
+ test = [
19
+ "asyncpg",
20
+ "psycopg[binary,pool]",
21
+ "pytest",
22
+ "pytest-asyncio",
23
+ "pytest-cov",
24
+ "coverage",
25
+ ]
26
+
27
+ [build-system]
28
+ requires = ["uv_build>=0.11.1,<0.12.0"]
29
+ build-backend = "uv_build"
30
+
31
+ [tool.mypy]
32
+ python_version = "3.12"
33
+ strict = true
34
+ disallow_any_decorated = true
35
+ disallow_any_unimported = true
36
+ #disallow_any_explicit = true
37
+ #disallow_any_expr = true
38
+
39
+ [tool.isort]
40
+ known_first_party = ["shardproxy"]
41
+ known_local_folder = ["helpers"]
42
+ multi_line_output = 5
43
+ balanced_wrapping = true
44
+ include_trailing_comma = true
45
+ atomic = true
46
+ quiet = true
47
+
48
+ [[tool.uv.index]]
49
+ name = "testpypi"
50
+ url = "https://test.pypi.org/simple/"
51
+ publish-url = "https://test.pypi.org/legacy/"
52
+ explicit = true
53
+
@@ -0,0 +1,24 @@
1
+ """Sharding with SQLAlchemy."""
2
+
3
+ from .ops import (
4
+ RunOnAllOperation,
5
+ RunOnAnyOperation,
6
+ RunOnArgsOperation,
7
+ RunOnKeysOperation,
8
+ RunOnOneOperation,
9
+ )
10
+ from .proxy import ShardProxy
11
+ from .types import AsyncConnection, AsyncEngine, DBTask, ResultRows
12
+
13
+ __all__ = (
14
+ "ShardProxy",
15
+ "AsyncEngine",
16
+ "AsyncConnection",
17
+ "ResultRows",
18
+ "DBTask",
19
+ "RunOnAllOperation",
20
+ "RunOnKeysOperation",
21
+ "RunOnArgsOperation",
22
+ "RunOnAnyOperation",
23
+ "RunOnOneOperation",
24
+ )
@@ -0,0 +1,234 @@
1
+ """Helper classes for multi-shard operations."""
2
+
3
+ import asyncio
4
+ from abc import ABC, abstractmethod
5
+ from collections.abc import Sequence
6
+ from uuid import UUID
7
+
8
+ from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
9
+
10
+ from .proxy import ShardProxy
11
+ from .types import DBTask, ResultRows
12
+
13
+ __all__ = (
14
+ "RunOnKeysOperation",
15
+ "RunOnArgsOperation",
16
+ "RunOnAllOperation",
17
+ "RunOnOneOperation",
18
+ "RunOnAnyOperation",
19
+ )
20
+
21
+
22
+ class RunOnKeysOperation[TRow](ABC):
23
+ """Runs operation over given keys.
24
+
25
+ Usage::
26
+
27
+ class MyOp(RunOnKeysOperation[MyRow]):
28
+ async def process(
29
+ self, conn: AsyncConnection, keys: list[UUID]
30
+ ) -> ResultRows[MyRow]:
31
+ return ...
32
+
33
+ """
34
+
35
+ #: Whether connection is in transaction context when given to .process().
36
+ USE_TRANSACTION = True
37
+
38
+ @abstractmethod
39
+ async def process(
40
+ self, conn: AsyncConnection, keys: list[UUID]
41
+ ) -> ResultRows[TRow]:
42
+ """Business logic to execute on each shard."""
43
+ raise NotImplementedError
44
+
45
+ async def run(
46
+ self, proxy: ShardProxy, shard_keys: Sequence[UUID]
47
+ ) -> ResultRows[TRow]:
48
+ """Split keys for each shard and run .process() on them."""
49
+
50
+ tasks: list[DBTask[ResultRows[TRow]]] = []
51
+ for engine, keys in proxy.spread_keys(shard_keys):
52
+ task = asyncio.create_task(self._handle_process(engine, keys))
53
+ tasks.append(task)
54
+
55
+ return await self._collect(tasks)
56
+
57
+ async def _handle_process(
58
+ self, engine: AsyncEngine, keys: list[UUID]
59
+ ) -> ResultRows[TRow]:
60
+ async with engine.connect() as conn:
61
+ if self.USE_TRANSACTION:
62
+ async with conn.begin():
63
+ return await self.process(conn, keys)
64
+ else:
65
+ return await self.process(conn, keys)
66
+
67
+ async def _collect(self, tasks: list[DBTask[ResultRows[TRow]]]) -> ResultRows[TRow]:
68
+ shard_rows = await asyncio.gather(*tasks)
69
+ res: list[TRow] = []
70
+ for rows in shard_rows:
71
+ res.extend(rows)
72
+ return res
73
+
74
+
75
+ class RunOnArgsOperation[TRow, TArg](ABC):
76
+ """Runs operation over given arguments.
77
+
78
+ Usage::
79
+
80
+ class MyOp(RunOnArgsOperation[MyRow, MyArg]):
81
+ async def process(
82
+ self, conn: AsyncConnection, args: dict[UUID, MyArg]
83
+ ) -> ResultRows[MyRow]:
84
+ return ...
85
+
86
+ """
87
+
88
+ #: Whether connection is in transaction context when given to .process().
89
+ USE_TRANSACTION = True
90
+
91
+ @abstractmethod
92
+ async def process(
93
+ self, conn: AsyncConnection, args: dict[UUID, TArg]
94
+ ) -> ResultRows[TRow]:
95
+ """Business logic to execute on each shard."""
96
+ raise NotImplementedError
97
+
98
+ async def run(self, proxy: ShardProxy, args: dict[UUID, TArg]) -> ResultRows[TRow]:
99
+ """Split args for each shard and run .process() with them."""
100
+
101
+ tasks: list[DBTask[ResultRows[TRow]]] = []
102
+ for engine, shard_args in proxy.spread_args(args):
103
+ task = asyncio.create_task(self._handle_process(engine, shard_args))
104
+ tasks.append(task)
105
+
106
+ return await self._collect(tasks)
107
+
108
+ async def _handle_process(
109
+ self, engine: AsyncEngine, args: dict[UUID, TArg]
110
+ ) -> ResultRows[TRow]:
111
+ async with engine.connect() as conn:
112
+ if self.USE_TRANSACTION:
113
+ async with conn.begin():
114
+ return await self.process(conn, args)
115
+ else:
116
+ return await self.process(conn, args)
117
+
118
+ async def _collect(self, tasks: list[DBTask[ResultRows[TRow]]]) -> ResultRows[TRow]:
119
+ shard_rows = await asyncio.gather(*tasks)
120
+ res: list[TRow] = []
121
+ for rows in shard_rows:
122
+ res.extend(rows)
123
+ return res
124
+
125
+
126
+ class RunOnAllOperation[TRow](ABC):
127
+ """Runs operation over all shards.
128
+
129
+ Usage::
130
+
131
+ class MyOp(RunOnAllOperation[MyRow]):
132
+ async def process(
133
+ self, conn: AsyncConnection
134
+ ) -> ResultRows[MyRow]:
135
+ return ...
136
+
137
+ """
138
+
139
+ #: Whether connection is in transaction context when given to .process().
140
+ USE_TRANSACTION = True
141
+
142
+ @abstractmethod
143
+ async def process(self, conn: AsyncConnection) -> ResultRows[TRow]:
144
+ """Business logic to execute on each shard."""
145
+ raise NotImplementedError
146
+
147
+ async def run(self, proxy: ShardProxy) -> ResultRows[TRow]:
148
+ """Loop over all shards, execute .process() on each."""
149
+
150
+ tasks: list[DBTask[ResultRows[TRow]]] = []
151
+ for engine in proxy.shards:
152
+ task = asyncio.create_task(self._handle_process(engine))
153
+ tasks.append(task)
154
+
155
+ return await self._collect(tasks)
156
+
157
+ async def _handle_process(self, engine: AsyncEngine) -> ResultRows[TRow]:
158
+ async with engine.connect() as conn:
159
+ if self.USE_TRANSACTION:
160
+ async with conn.begin():
161
+ return await self.process(conn)
162
+ else:
163
+ return await self.process(conn)
164
+
165
+ async def _collect(self, tasks: list[DBTask[ResultRows[TRow]]]) -> ResultRows[TRow]:
166
+ shard_rows = await asyncio.gather(*tasks)
167
+ res: list[TRow] = []
168
+ for rows in shard_rows:
169
+ res.extend(rows)
170
+ return res
171
+
172
+
173
+ class RunOnOneOperation[TRow](ABC):
174
+ """Runs operation in single shard.
175
+
176
+ Usage::
177
+
178
+ class MyOp(RunOnOneOperation[MyRow]):
179
+ async def process(
180
+ self, conn: AsyncConnection, shard_key: UUID
181
+ ) -> ResultRows[MyRow]:
182
+ return ...
183
+
184
+ """
185
+
186
+ #: Whether connection is in transaction context when given to .process().
187
+ USE_TRANSACTION = True
188
+
189
+ @abstractmethod
190
+ async def process(self, conn: AsyncConnection, shard_key: UUID) -> ResultRows[TRow]:
191
+ """Business logic to execute on shard."""
192
+ raise NotImplementedError
193
+
194
+ async def run(self, proxy: ShardProxy, shard_key: UUID) -> ResultRows[TRow]:
195
+ """Pick shard and run .process() on it."""
196
+
197
+ async with proxy.connect(shard_key) as conn:
198
+ if self.USE_TRANSACTION:
199
+ async with conn.begin():
200
+ return await self.process(conn, shard_key)
201
+ else:
202
+ return await self.process(conn, shard_key)
203
+
204
+
205
+ class RunOnAnyOperation[TRow](ABC):
206
+ """Runs operation in single random shard.
207
+
208
+ Usage::
209
+
210
+ class MyOp(RunOnAnyOperation[MyRow]):
211
+ async def process(
212
+ self, conn: AsyncConnection
213
+ ) -> ResultRows[MyRow]:
214
+ return ...
215
+
216
+ """
217
+
218
+ #: Whether connection is in transaction context when given to .process().
219
+ USE_TRANSACTION = True
220
+
221
+ @abstractmethod
222
+ async def process(self, conn: AsyncConnection) -> ResultRows[TRow]:
223
+ """Business logic to execute on shard."""
224
+ raise NotImplementedError
225
+
226
+ async def run(self, proxy: ShardProxy) -> ResultRows[TRow]:
227
+ """Pick random shard and run .process() on it."""
228
+
229
+ async with proxy.connect_any() as conn:
230
+ if self.USE_TRANSACTION:
231
+ async with conn.begin():
232
+ return await self.process(conn)
233
+ else:
234
+ return await self.process(conn)
@@ -0,0 +1,136 @@
1
+ """Connection manager for sharding."""
2
+
3
+ from collections.abc import Mapping, Sequence
4
+ from random import choice
5
+ from typing import Unpack
6
+ from uuid import UUID
7
+
8
+ from sqlalchemy import URL
9
+ from sqlalchemy.ext.asyncio import (
10
+ AsyncConnection,
11
+ AsyncEngine,
12
+ create_async_engine,
13
+ )
14
+
15
+ from .types import AsyncEngineParams
16
+
17
+ __all__ = ("ShardProxy",)
18
+
19
+
20
+ class ShardProxy:
21
+ """Manages AsyncEngine mapping for shards.
22
+
23
+ Takes list of shard URLs and generic create_async_engine()/create_engine()
24
+ keyword arguments.
25
+ """
26
+
27
+ _shards: list[AsyncEngine]
28
+
29
+ def __init__(
30
+ self, shard_urls: Sequence[str | URL], **kwargs: Unpack[AsyncEngineParams]
31
+ ) -> None:
32
+ self._shards = []
33
+ for url in shard_urls:
34
+ engine = create_async_engine(url, **kwargs)
35
+ self._shards.append(engine)
36
+
37
+ def _get_shard(self, shard_key: UUID) -> int:
38
+ return shard_key.node % len(self._shards)
39
+
40
+ @property
41
+ def shards(self) -> Sequence[AsyncEngine]:
42
+ """Read-only list of engines per shard.
43
+
44
+ Usage::
45
+
46
+ # run-on-all: serial mode
47
+ for engine in proxy.shards:
48
+ async with engine.connect() as conn:
49
+ async with conn.begin():
50
+ await conn.execute()
51
+
52
+ # run-on-all: parallel mode
53
+ async with asyncio.TaskGroup() as tg:
54
+ for engine in proxy.shards:
55
+ tg.create_task(process(engine, items))
56
+ """
57
+ return self._shards
58
+
59
+ def connect(self, shard_key: UUID) -> AsyncConnection:
60
+ """Returns connection for specific shard as context manager.
61
+
62
+ Usage::
63
+
64
+ async with proxy.connect(key) as conn:
65
+ async with conn.begin():
66
+ ...
67
+ """
68
+ idx = self._get_shard(shard_key)
69
+ return self._shards[idx].connect()
70
+
71
+ def connect_any[T](self) -> AsyncConnection:
72
+ """Returns connection to random shard as context manager.
73
+
74
+ Usage::
75
+
76
+ async with proxy.connect_any() as conn:
77
+ async with conn.begin():
78
+ ...
79
+ """
80
+ engine = choice(self._shards)
81
+ return engine.connect()
82
+
83
+ def spread_args[A](
84
+ self, args: Mapping[UUID, A]
85
+ ) -> Sequence[tuple[AsyncEngine, dict[UUID, A]]]:
86
+ """Spread keys with arguments over shards.
87
+
88
+ Usage::
89
+
90
+ # serial mode
91
+ for engine, shard_items in proxy.split_args(all_items):
92
+ async with engine.connect() as conn:
93
+ async with conn.begin():
94
+ await conn.execute()
95
+
96
+ # parallel mode
97
+ async with asyncio.TaskGroup() as tg:
98
+ for engine, items in proxy.split_args(all_items):
99
+ tg.create_task(process(engine, items))
100
+ """
101
+ split: dict[int, dict[UUID, A]] = {}
102
+ for key, arg in args.items():
103
+ nr = self._get_shard(key)
104
+ target = split.get(nr)
105
+ if target is None:
106
+ split[nr] = {key: arg}
107
+ else:
108
+ target[key] = arg
109
+ return [(self._shards[nr], target) for nr, target in split.items()]
110
+
111
+ def spread_keys(
112
+ self, keys: Sequence[UUID]
113
+ ) -> Sequence[tuple[AsyncEngine, list[UUID]]]:
114
+ """Spread keys over shards.
115
+
116
+ Usage::
117
+
118
+ # serial mode
119
+ for engine, shard_keys in proxy.split_keys(all_keys):
120
+ async with engine.connect() as conn:
121
+ async with conn.begin():
122
+ await conn.execute()
123
+
124
+ # parallel mode
125
+ async with asyncio.TaskGroup() as tg:
126
+ for engine, shard_keys in proxy.split_keys(all_keys):
127
+ tg.create_task(process(engine, shard_keys))
128
+ """
129
+ split: dict[int, set[UUID]] = {}
130
+ for key in keys:
131
+ nr = self._get_shard(key)
132
+ target = split.get(nr)
133
+ if target is None:
134
+ target = split[nr] = set()
135
+ target.add(key)
136
+ return [(self._shards[nr], list(target)) for nr, target in split.items()]
File without changes
@@ -0,0 +1,68 @@
1
+ """Common types."""
2
+
3
+ import asyncio
4
+ from collections.abc import Sequence
5
+ from typing import Any, Callable, Literal, Type, TypedDict
6
+
7
+ from sqlalchemy.engine.interfaces import (
8
+ CoreExecuteOptionsParameter,
9
+ IsolationLevel,
10
+ )
11
+ from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
12
+ from sqlalchemy.pool import Pool
13
+
14
+ __all__ = (
15
+ "AsyncConnection",
16
+ "AsyncEngine",
17
+ "AsyncEngineParams",
18
+ "ResultRows",
19
+ "DBTask",
20
+ )
21
+
22
+
23
+ type ResultRows[TRow] = Sequence[TRow]
24
+
25
+ type DBTask[TResult] = asyncio.Task[TResult]
26
+
27
+
28
+ class AsyncEngineParams(TypedDict, total=False):
29
+ """Helper struct to describe common parameters for create_async_engine.
30
+
31
+ It may be incomplete - arguments from here are
32
+ passed to AsyncEngine, Dialect and Pool.
33
+
34
+ Some values are deliberately hidden as they will not work with ShardProxy.
35
+ """
36
+
37
+ # async_creator - does not work with sharding
38
+ # connect_args - does not work with sharding
39
+ echo: bool | Literal["debug"] | None
40
+ echo_pool: bool | Literal["debug"] | None
41
+ enable_from_linting: bool
42
+ execution_options: CoreExecuteOptionsParameter
43
+ hide_parameters: bool
44
+ insertmanyvalues_page_size: int
45
+ isolation_level: IsolationLevel
46
+ json_deserializer: Callable[..., Any]
47
+ json_serializer: Callable[..., Any]
48
+ label_length: int | None
49
+ logging_name: str
50
+ max_identifier_length: int
51
+ max_overflow: int
52
+ # module - does not work with sharding
53
+ paramstyle: Literal[
54
+ "qmark", "numeric", "named", "format", "pyformat", "numeric_dollar"
55
+ ]
56
+ # pool - does not work with sharding
57
+ poolclass: Type[Pool]
58
+ pool_logging_name: str
59
+ pool_pre_ping: bool
60
+ pool_size: int
61
+ pool_recycle: int
62
+ pool_reset_on_return: Literal["rollback", "commit"] | None
63
+ pool_timeout: int
64
+ pool_use_iifo: bool
65
+ plugins: list[str]
66
+ query_cache_size: int
67
+ skip_autocommit_rollback: bool
68
+ use_insertmanyvalues: bool