iceaxe-0.8.3-cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of iceaxe might be problematic.
- iceaxe/__init__.py +20 -0
- iceaxe/__tests__/__init__.py +0 -0
- iceaxe/__tests__/benchmarks/__init__.py +0 -0
- iceaxe/__tests__/benchmarks/test_bulk_insert.py +45 -0
- iceaxe/__tests__/benchmarks/test_select.py +114 -0
- iceaxe/__tests__/conf_models.py +133 -0
- iceaxe/__tests__/conftest.py +204 -0
- iceaxe/__tests__/docker_helpers.py +208 -0
- iceaxe/__tests__/helpers.py +268 -0
- iceaxe/__tests__/migrations/__init__.py +0 -0
- iceaxe/__tests__/migrations/conftest.py +36 -0
- iceaxe/__tests__/migrations/test_action_sorter.py +237 -0
- iceaxe/__tests__/migrations/test_generator.py +140 -0
- iceaxe/__tests__/migrations/test_generics.py +91 -0
- iceaxe/__tests__/mountaineer/__init__.py +0 -0
- iceaxe/__tests__/mountaineer/dependencies/__init__.py +0 -0
- iceaxe/__tests__/mountaineer/dependencies/test_core.py +76 -0
- iceaxe/__tests__/schemas/__init__.py +0 -0
- iceaxe/__tests__/schemas/test_actions.py +1265 -0
- iceaxe/__tests__/schemas/test_cli.py +25 -0
- iceaxe/__tests__/schemas/test_db_memory_serializer.py +1571 -0
- iceaxe/__tests__/schemas/test_db_serializer.py +435 -0
- iceaxe/__tests__/schemas/test_db_stubs.py +190 -0
- iceaxe/__tests__/test_alias.py +83 -0
- iceaxe/__tests__/test_base.py +52 -0
- iceaxe/__tests__/test_comparison.py +383 -0
- iceaxe/__tests__/test_field.py +11 -0
- iceaxe/__tests__/test_helpers.py +9 -0
- iceaxe/__tests__/test_modifications.py +151 -0
- iceaxe/__tests__/test_queries.py +764 -0
- iceaxe/__tests__/test_queries_str.py +173 -0
- iceaxe/__tests__/test_session.py +1511 -0
- iceaxe/__tests__/test_text_search.py +287 -0
- iceaxe/alias_values.py +67 -0
- iceaxe/base.py +351 -0
- iceaxe/comparison.py +560 -0
- iceaxe/field.py +263 -0
- iceaxe/functions.py +1432 -0
- iceaxe/generics.py +140 -0
- iceaxe/io.py +107 -0
- iceaxe/logging.py +91 -0
- iceaxe/migrations/__init__.py +5 -0
- iceaxe/migrations/action_sorter.py +98 -0
- iceaxe/migrations/cli.py +228 -0
- iceaxe/migrations/client_io.py +62 -0
- iceaxe/migrations/generator.py +404 -0
- iceaxe/migrations/migration.py +86 -0
- iceaxe/migrations/migrator.py +101 -0
- iceaxe/modifications.py +176 -0
- iceaxe/mountaineer/__init__.py +10 -0
- iceaxe/mountaineer/cli.py +74 -0
- iceaxe/mountaineer/config.py +46 -0
- iceaxe/mountaineer/dependencies/__init__.py +6 -0
- iceaxe/mountaineer/dependencies/core.py +67 -0
- iceaxe/postgres.py +133 -0
- iceaxe/py.typed +0 -0
- iceaxe/queries.py +1459 -0
- iceaxe/queries_str.py +294 -0
- iceaxe/schemas/__init__.py +0 -0
- iceaxe/schemas/actions.py +864 -0
- iceaxe/schemas/cli.py +30 -0
- iceaxe/schemas/db_memory_serializer.py +711 -0
- iceaxe/schemas/db_serializer.py +347 -0
- iceaxe/schemas/db_stubs.py +529 -0
- iceaxe/session.py +860 -0
- iceaxe/session_optimized.c +12207 -0
- iceaxe/session_optimized.cpython-313-darwin.so +0 -0
- iceaxe/session_optimized.pyx +212 -0
- iceaxe/sql_types.py +149 -0
- iceaxe/typing.py +73 -0
- iceaxe-0.8.3.dist-info/METADATA +262 -0
- iceaxe-0.8.3.dist-info/RECORD +75 -0
- iceaxe-0.8.3.dist-info/WHEEL +6 -0
- iceaxe-0.8.3.dist-info/licenses/LICENSE +21 -0
- iceaxe-0.8.3.dist-info/top_level.txt +1 -0
iceaxe/session.py
ADDED
@@ -0,0 +1,860 @@
from collections import defaultdict
from contextlib import asynccontextmanager
from json import loads as json_loads
from math import ceil
from typing import (
    Any,
    Literal,
    ParamSpec,
    Sequence,
    Type,
    TypeVar,
    cast,
    overload,
)

import asyncpg
from typing_extensions import TypeVarTuple

from iceaxe.base import DBFieldClassDefinition, TableBase
from iceaxe.logging import LOGGER
from iceaxe.modifications import ModificationTracker
from iceaxe.queries import (
    QueryBuilder,
    is_base_table,
    is_column,
    is_function_metadata,
)
from iceaxe.queries_str import QueryIdentifier
from iceaxe.session_optimized import optimize_exec_casting

P = ParamSpec("P")
T = TypeVar("T")
Ts = TypeVarTuple("Ts")

TableType = TypeVar("TableType", bound=TableBase)

# PostgreSQL has a limit of 32767 parameters per query (Short.MAX_VALUE)
PG_MAX_PARAMETERS = 32767

TYPE_CACHE = {}


class DBConnection:
    """
    Core class for all ORM actions against a PostgreSQL database. Provides high-level methods
    for executing queries and managing database transactions.

    The DBConnection wraps an asyncpg Connection and provides ORM functionality for:
    - Executing SELECT/INSERT/UPDATE/DELETE queries
    - Managing transactions
    - Inserting, updating, and deleting model instances
    - Refreshing model instances from the database

    ```python {{sticky: True}}
    # Create a connection
    conn = DBConnection(
        await asyncpg.connect(
            host="localhost",
            port=5432,
            user="db_user",
            password="yoursecretpassword",
            database="your_db",
        )
    )

    # Use with models
    class User(TableBase):
        id: int = Field(primary_key=True)
        name: str
        email: str

    # Insert data
    user = User(name="Alice", email="alice@example.com")
    await conn.insert([user])

    # Query data
    users = await conn.exec(
        select(User)
        .where(User.name == "Alice")
    )

    # Update data
    user.email = "newemail@example.com"
    await conn.update([user])
    ```
    """

    def __init__(
        self,
        conn: asyncpg.Connection,
        *,
        uncommitted_verbosity: Literal["ERROR", "WARNING", "INFO"] | None = None,
    ):
        """
        Initialize a new database connection wrapper.

        :param conn: An asyncpg Connection instance to wrap
        :param uncommitted_verbosity: The verbosity level if objects are modified but not committed when
            the session is closed, defaults to nothing

        """
        self.conn = conn
        self.obj_to_primary_key: dict[str, str | None] = {}
        self.in_transaction = False
        self.modification_tracker = ModificationTracker(uncommitted_verbosity)

    async def initialize_types(self, timeout: float = 60.0) -> None:
        """
        Introspect and register PostgreSQL type codecs on this connection,
        caching the result globally using the connection's DB URL as a key. These types
        are unlikely to change in the lifetime of a Python process, so this is typically
        safe to do automatically.

        This method should be called once per connection so we can leverage our own cache. If
        asyncpg is called directly on a new connection, it will result in its own duplicate
        type introspection call.

        """
        global TYPE_CACHE

        if not self.conn._protocol:
            LOGGER.warning(
                "No protocol found for connection during type introspection, will fall back to asyncpg"
            )
            return

        # Determine a unique key for this connection.
        db_url = self.get_dsn()

        # If we've already cached the type information for this DB URL, just register it.
        if db_url in TYPE_CACHE:
            self.conn._protocol.get_settings().register_data_types(TYPE_CACHE[db_url])
            return

        # Get the connection settings object (this is where type codecs are registered).
        settings = self.conn._protocol.get_settings()

        # Query PostgreSQL to get all type OIDs from non-system schemas.
        rows = await self.conn.fetch(
            """
            SELECT t.oid
            FROM pg_type t
            JOIN pg_namespace n ON t.typnamespace = n.oid
            WHERE n.nspname NOT IN ('pg_catalog', 'information_schema')
            """
        )
        # Build a set of type OIDs.
        typeoids = {row["oid"] for row in rows}

        # Introspect types – this call will recursively determine the PostgreSQL types needed.
        types, intro_stmt = await self.conn._introspect_types(typeoids, timeout)

        # Register the introspected types with the connection's settings.
        settings.register_data_types(types)

        # Cache the types globally so that future connections using the same DB URL
        # can simply register the cached codecs.
        TYPE_CACHE[db_url] = types

    def get_dsn(self) -> str:
        """
        Get the DSN (Data Source Name) string for this connection.

        :return: DSN string in the format 'postgresql://user:password@host:port/dbname'
        """
        params = self.conn._params
        addr = self.conn._addr

        # Build the DSN string with all available parameters
        dsn_parts = ["postgresql://"]

        # Add user/password if available
        if params.user:
            dsn_parts.append(params.user)
            if params.password:
                dsn_parts.append(f":{params.password}")
            dsn_parts.append("@")

        # Add host/port
        dsn_parts.append(addr[0])
        if addr[1]:
            dsn_parts.append(f":{addr[1]}")

        # Add database name
        if params.database:
            dsn_parts.append(f"/{params.database}")

        return "".join(dsn_parts)

    @asynccontextmanager
    async def transaction(self, *, ensure: bool = False):
        """
        Context manager for managing database transactions. Ensures that a series of database
        operations are executed atomically.

        :param ensure: If True and already in a transaction, the context manager will yield without creating a new transaction.
            If False (default) and already in a transaction, raises a RuntimeError.

        ```python {{sticky: True}}
        async with conn.transaction():
            # All operations here are executed in a transaction
            user = User(name="Alice", email="alice@example.com")
            await conn.insert([user])

            post = Post(title="Hello", user_id=user.id)
            await conn.insert([post])

            # If any operation fails, all changes are rolled back
        ```
        """
        # If ensure is True and we're already in a transaction, just yield
        if self.in_transaction:
            if ensure:
                yield
                return
            else:
                raise RuntimeError(
                    "Cannot start a new transaction while already in a transaction. Use ensure=True if this is intentional."
                )

        # Otherwise, start a new transaction
        self.in_transaction = True
        async with self.conn.transaction():
            try:
                yield
            finally:
                self.in_transaction = False

    @overload
    async def exec(self, query: QueryBuilder[T, Literal["SELECT"]]) -> list[T]: ...

    @overload
    async def exec(self, query: QueryBuilder[T, Literal["INSERT"]]) -> None: ...

    @overload
    async def exec(self, query: QueryBuilder[T, Literal["UPDATE"]]) -> None: ...

    @overload
    async def exec(self, query: QueryBuilder[T, Literal["DELETE"]]) -> None: ...

    async def exec(
        self,
        query: QueryBuilder[T, Literal["SELECT"]]
        | QueryBuilder[T, Literal["INSERT"]]
        | QueryBuilder[T, Literal["UPDATE"]]
        | QueryBuilder[T, Literal["DELETE"]],
    ) -> list[T] | None:
        """
        Execute a query built with QueryBuilder and return the results.

        ```python {{sticky: True}}
        # Select query
        users = await conn.exec(
            select(User)
            .where(User.age >= 18)
            .order_by(User.name)
        )

        # Select with joins and aggregates
        results = await conn.exec(
            select((User.name, func.count(Order.id)))
            .join(Order, Order.user_id == User.id)
            .group_by(User.name)
            .having(func.count(Order.id) > 5)
        )

        # Delete query
        await conn.exec(
            delete(User)
            .where(User.is_active == False)
        )
        ```

        :param query: A QueryBuilder instance representing the query to execute
        :return: For SELECT queries, returns a list of results. For other queries, returns None

        """
        sql_text, variables = query.build()
        LOGGER.debug(f"Executing query: {sql_text} with variables: {variables}")
        try:
            values = await self.conn.fetch(sql_text, *variables)
        except Exception as e:
            LOGGER.error(
                f"Error executing query: {sql_text} with variables: {variables}"
            )
            raise e

        if query._query_type == "SELECT":
            # Pre-cache the select types for better performance
            select_types = [
                (
                    is_base_table(select_raw),
                    is_column(select_raw),
                    is_function_metadata(select_raw),
                )
                for select_raw in query._select_raw
            ]

            result_all = optimize_exec_casting(values, query._select_raw, select_types)

            # Only loop through results if we have verbosity enabled, since this logic otherwise
            # is wasted if no content will eventually be logged
            if self.modification_tracker.verbosity:
                for row in result_all:
                    elements = row if isinstance(row, tuple) else (row,)
                    for element in elements:
                        if isinstance(element, TableBase):
                            element.register_modified_callback(
                                self.modification_tracker.track_modification
                            )

            return cast(list[T], result_all)

        return None

    async def insert(self, objects: Sequence[TableBase]):
        """
        Insert one or more model instances into the database. If the model has an auto-incrementing
        primary key, it will be populated on the instances after insertion.

        ```python {{sticky: True}}
        # Insert a single object
        user = User(name="Alice", email="alice@example.com")
        await conn.insert([user])
        print(user.id)  # Auto-populated primary key

        # Insert multiple objects
        users = [
            User(name="Bob", email="bob@example.com"),
            User(name="Charlie", email="charlie@example.com")
        ]
        await conn.insert(users)
        ```

        :param objects: A sequence of TableBase instances to insert

        """
        if not objects:
            return

        # Reuse a single transaction for all inserts
        async with self.transaction(ensure=True):
            for model, model_objects in self._aggregate_models_by_table(objects):
                # For each table, build batched insert queries
                table_name = QueryIdentifier(model.get_table_name())
                fields = {
                    field: info
                    for field, info in model.model_fields.items()
                    if (not info.exclude and not info.autoincrement)
                }
                primary_key = self._get_primary_key(model)
                field_names = list(fields.keys())
                field_identifiers = ", ".join(f'"{f}"' for f in field_names)

                # Build the base query
                if primary_key:
                    query = f"""
                        INSERT INTO {table_name} ({field_identifiers})
                        VALUES ({", ".join(f"${i}" for i in range(1, len(field_names) + 1))})
                        RETURNING {primary_key}
                    """
                else:
                    query = f"""
                        INSERT INTO {table_name} ({field_identifiers})
                        VALUES ({", ".join(f"${i}" for i in range(1, len(field_names) + 1))})
                    """

                for batch_objects, values_list in self._batch_objects_and_values(
                    model_objects, field_names, fields
                ):
                    # Insert them in one go
                    if primary_key:
                        # For returning queries, we can use fetchmany to get the primary keys
                        rows = await self.conn.fetchmany(query, values_list)
                        for obj, row in zip(batch_objects, rows):
                            setattr(obj, primary_key, row[primary_key])
                    else:
                        # For non-returning queries, we can use executemany
                        await self.conn.executemany(query, values_list)

                    # Mark as unmodified
                    for obj in batch_objects:
                        obj.clear_modified_attributes()

        # Register modification callbacks outside the main insert loop
        if self.modification_tracker.verbosity:
            for obj in objects:
                obj.register_modified_callback(
                    self.modification_tracker.track_modification
                )

        # Clear modification status
        self.modification_tracker.clear_status(objects)

    @overload
    async def upsert(
        self,
        objects: Sequence[TableBase],
        *,
        conflict_fields: tuple[Any, ...],
        update_fields: tuple[Any, ...] | None = None,
        returning_fields: tuple[T, *Ts] | None = None,
    ) -> list[tuple[T, *Ts]] | None: ...

    @overload
    async def upsert(
        self,
        objects: Sequence[TableBase],
        *,
        conflict_fields: tuple[Any, ...],
        update_fields: tuple[Any, ...] | None = None,
        returning_fields: None,
    ) -> None: ...

    async def upsert(
        self,
        objects: Sequence[TableBase],
        *,
        conflict_fields: tuple[Any, ...],
        update_fields: tuple[Any, ...] | None = None,
        returning_fields: tuple[T, *Ts] | None = None,
    ) -> list[tuple[T, *Ts]] | None:
        """
        Performs an upsert (INSERT ... ON CONFLICT DO UPDATE) operation for the given objects.
        This is useful when you want to insert records but update them if they already exist.

        ```python {{sticky: True}}
        # Simple upsert based on email
        users = [
            User(email="alice@example.com", name="Alice"),
            User(email="bob@example.com", name="Bob")
        ]
        await conn.upsert(
            users,
            conflict_fields=(User.email,),
            update_fields=(User.name,)
        )

        # Upsert with returning values
        results = await conn.upsert(
            users,
            conflict_fields=(User.email,),
            update_fields=(User.name,),
            returning_fields=(User.id, User.email)
        )
        for user_id, email in results:
            print(f"Upserted user {email} with ID {user_id}")
        ```

        :param objects: Sequence of TableBase objects to upsert
        :param conflict_fields: Fields to check for conflicts (ON CONFLICT)
        :param update_fields: Fields to update on conflict. If None, updates all non-excluded fields
        :param returning_fields: Fields to return after the operation. If None, returns nothing
        :return: List of tuples containing the returned fields if returning_fields is specified

        """
        if not objects:
            return None

        # Evaluate column types
        conflict_fields_cols: list[DBFieldClassDefinition] = []
        update_fields_cols: list[DBFieldClassDefinition] = []
        returning_fields_cols: list[DBFieldClassDefinition] = []

        # Explicitly validate types of all columns
        for field in conflict_fields:
            if is_column(field):
                conflict_fields_cols.append(field)
            else:
                raise ValueError(f"Field {field} is not a column")
        for field in update_fields or []:
            if is_column(field):
                update_fields_cols.append(field)
            else:
                raise ValueError(f"Field {field} is not a column")
        for field in returning_fields or []:
            if is_column(field):
                returning_fields_cols.append(field)
            else:
                raise ValueError(f"Field {field} is not a column")

        results: list[tuple[T, *Ts]] = []
        async with self.transaction(ensure=True):
            for model, model_objects in self._aggregate_models_by_table(objects):
                table_name = QueryIdentifier(model.get_table_name())
                fields = {
                    field: info
                    for field, info in model.model_fields.items()
                    if (not info.exclude and not info.autoincrement)
                }

                field_string = ", ".join(f'"{field}"' for field in fields)
                placeholders = ", ".join(f"${i}" for i in range(1, len(fields) + 1))
                query = (
                    f"INSERT INTO {table_name} ({field_string}) VALUES ({placeholders})"
                )
                if conflict_fields_cols:
                    conflict_field_string = ", ".join(
                        f'"{field.key}"' for field in conflict_fields_cols
                    )
                    query += f" ON CONFLICT ({conflict_field_string})"

                    if update_fields_cols:
                        set_values = ", ".join(
                            f'"{field.key}" = EXCLUDED."{field.key}"'
                            for field in update_fields_cols
                        )
                        query += f" DO UPDATE SET {set_values}"
                    else:
                        query += " DO NOTHING"

                if returning_fields_cols:
                    returning_string = ", ".join(
                        f'"{field.key}"' for field in returning_fields_cols
                    )
                    query += f" RETURNING {returning_string}"

                # Execute in batches
                for batch_objects, values_list in self._batch_objects_and_values(
                    model_objects, list(fields.keys()), fields
                ):
                    if returning_fields_cols:
                        # For returning queries, we need to use fetchmany to get all results
                        rows = await self.conn.fetchmany(query, values_list)
                        for row in rows:
                            if row:
                                # Process returned values, deserializing JSON if needed
                                processed_values = []
                                for field in returning_fields_cols:
                                    value = row[field.key]
                                    if (
                                        value is not None
                                        and field.root_model.model_fields[
                                            field.key
                                        ].is_json
                                    ):
                                        value = json_loads(value)
                                    processed_values.append(value)
                                results.append(tuple(processed_values))
                    else:
                        # For non-returning queries, we can use executemany
                        await self.conn.executemany(query, values_list)

                    # Clear modified state for successfully upserted objects
                    for obj in batch_objects:
                        obj.clear_modified_attributes()

        self.modification_tracker.clear_status(objects)

        return results if returning_fields_cols else None

    async def update(self, objects: Sequence[TableBase]):
        """
        Update one or more model instances in the database. Only modified attributes will be updated.
        Updates are batched together by grouping objects with the same modified fields, then using
        executemany() for efficiency.

        ```python {{sticky: True}}
        # Update a single object
        user = await conn.exec(select(User).where(User.id == 1))
        user.name = "New Name"
        await conn.update([user])

        # Update multiple objects
        users = await conn.exec(select(User).where(User.age < 18))
        for user in users:
            user.is_minor = True
        await conn.update(users)
        ```

        :param objects: A sequence of TableBase instances to update
        """
        if not objects:
            return

        async with self.transaction(ensure=True):
            for model, model_objects in self._aggregate_models_by_table(objects):
                table_name = QueryIdentifier(model.get_table_name())
                primary_key = self._get_primary_key(model)

                if not primary_key:
                    raise ValueError(
                        f"Model {model} has no primary key, required to UPDATE with ORM objects"
                    )

                primary_key_name = QueryIdentifier(primary_key)

                # Group objects by their modified fields to batch similar updates
                updates_by_fields: defaultdict[frozenset[str], list[TableBase]] = (
                    defaultdict(list)
                )
                for obj in model_objects:
                    modified_attrs = frozenset(
                        k
                        for k, v in obj.get_modified_attributes().items()
                        if not obj.__class__.model_fields[k].exclude
                    )
                    if modified_attrs:
                        updates_by_fields[modified_attrs].append(obj)

                # Process each group of objects with the same modified fields
                for modified_fields, group_objects in updates_by_fields.items():
                    if not modified_fields:
                        continue

                    # Build the UPDATE query for this group
                    field_names = list(modified_fields)
                    fields = {field: model.model_fields[field] for field in field_names}

                    # Build the UPDATE query - note we need one extra parameter per row for the WHERE clause
                    set_clause = ", ".join(
                        f"{QueryIdentifier(key)} = ${i + 2}"
                        for i, key in enumerate(field_names)
                    )
                    query = f"UPDATE {table_name} SET {set_clause} WHERE {primary_key_name} = $1"

                    for batch_objects, values_list in self._batch_objects_and_values(
                        group_objects,
                        field_names,
                        fields,
                        extra_params_per_row=1,  # For the WHERE primary_key parameter
                    ):
                        # Add primary key as first parameter for each row
                        for i, obj in enumerate(batch_objects):
                            values_list[i].insert(0, getattr(obj, primary_key))

                        # Execute the batch update
                        await self.conn.executemany(query, values_list)

                        # Clear modified state for successfully updated objects
                        for obj in batch_objects:
                            obj.clear_modified_attributes()

        self.modification_tracker.clear_status(objects)

    async def delete(self, objects: Sequence[TableBase]):
        """
        Delete one or more model instances from the database.

        ```python {{sticky: True}}
        # Delete a single object
        user = await conn.exec(select(User).where(User.id == 1))
        await conn.delete([user])

        # Delete multiple objects
        inactive_users = await conn.exec(
            select(User).where(User.last_login < datetime.now() - timedelta(days=90))
        )
        await conn.delete(inactive_users)
        ```

        :param objects: A sequence of TableBase instances to delete

        """
        async with self.transaction(ensure=True):
            for model, model_objects in self._aggregate_models_by_table(objects):
                table_name = QueryIdentifier(model.get_table_name())
                primary_key = self._get_primary_key(model)

                if not primary_key:
                    raise ValueError(
                        f"Model {model} has no primary key, required to UPDATE with ORM objects"
                    )

                primary_key_name = QueryIdentifier(primary_key)

                for obj in model_objects:
                    query = f"DELETE FROM {table_name} WHERE {primary_key_name} = $1"
                    await self.conn.execute(query, getattr(obj, primary_key))

        self.modification_tracker.clear_status(objects)

    async def refresh(self, objects: Sequence[TableBase]):
        """
        Refresh one or more model instances from the database, updating their attributes
        with the current database values.

        ```python {{sticky: True}}
        # Refresh a single object
        user = await conn.exec(select(User).where(User.id == 1))
        # ... some time passes, database might have changed
        await conn.refresh([user])  # User now has current database values

        # Refresh multiple objects
        users = await conn.exec(select(User).where(User.department == "Sales"))
        # ... after some time
        await conn.refresh(users)  # All users now have current values
        ```

        :param objects: A sequence of TableBase instances to refresh

        """
        for model, model_objects in self._aggregate_models_by_table(objects):
            table_name = QueryIdentifier(model.get_table_name())
            primary_key = self._get_primary_key(model)
            fields = [
                field for field, info in model.model_fields.items() if not info.exclude
            ]

            if not primary_key:
                raise ValueError(
                    f"Model {model} has no primary key, required to UPDATE with ORM objects"
                )

            primary_key_name = QueryIdentifier(primary_key)
            object_ids = {getattr(obj, primary_key) for obj in model_objects}

            query = f"SELECT * FROM {table_name} WHERE {primary_key_name} = ANY($1)"
            results = {
                result[primary_key]: result
                for result in await self.conn.fetch(query, list(object_ids))
            }

            # Update the objects in-place
            for obj in model_objects:
                obj_id = getattr(obj, primary_key)
                if obj_id in results:
                    # Update field-by-field
                    for field in fields:
                        setattr(obj, field, results[obj_id][field])
                else:
                    LOGGER.error(
                        f"Object {obj} with primary key {obj_id} not found in database"
                    )

        # When an object is refreshed, it's fully overwritten with the new data so by
        # definition it's no longer modified
        for obj in objects:
            obj.clear_modified_attributes()

        self.modification_tracker.clear_status(objects)

    async def get(
        self, model: Type[TableType], primary_key_value: Any
    ) -> TableType | None:
        """
        Retrieve a single model instance by its primary key value.

        This method provides a convenient way to fetch a single record from the database using its primary key.
        It automatically constructs and executes a SELECT query with a WHERE clause matching the primary key.

        ```python {{sticky: True}}
        class User(TableBase):
            id: int = Field(primary_key=True)
            name: str
            email: str

        # Fetch a user by ID
        user = await db_connection.get(User, 1)
        if user:
            print(f"Found user: {user.name}")
        else:
            print("User not found")
        ```

        :param model: The model class to query (must be a subclass of TableBase)
        :param primary_key_value: The value of the primary key to look up
        :return: The model instance if found, None if no record matches the primary key
        :raises ValueError: If the model has no primary key defined

        """
        primary_key = self._get_primary_key(model)
        if not primary_key:
            raise ValueError(
                f"Model {model} has no primary key, required to GET with ORM objects"
            )

        query_builder = QueryBuilder()
        query = query_builder.select(model).where(
            getattr(model, primary_key) == primary_key_value
        )
        results = await self.exec(query)
        return results[0] if results else None

    async def close(self):
        """
        Close the database connection.
        """
        await self.conn.close()
        self.modification_tracker.log()

    def _aggregate_models_by_table(self, objects: Sequence[TableBase]):
        """
        Group model instances by their table class for batch operations.

        :param objects: Sequence of TableBase instances to group
        :return: Iterator of (model_class, list_of_instances) pairs
        """
        objects_by_class: defaultdict[Type[TableBase], list[TableBase]] = defaultdict(
            list
        )
        for obj in objects:
            objects_by_class[obj.__class__].append(obj)

        return objects_by_class.items()

    def _get_primary_key(self, obj: Type[TableBase]) -> str | None:
        """
        Get the primary key field name for a model class, with caching.

        :param obj: The model class to get the primary key for
        :return: The name of the primary key field, or None if no primary key exists
        """
        table_name = obj.get_table_name()
        if table_name not in self.obj_to_primary_key:
            primary_key = [
                field for field, info in obj.model_fields.items() if info.primary_key
            ]
            self.obj_to_primary_key[table_name] = (
                primary_key[0] if primary_key else None
            )
        return self.obj_to_primary_key[table_name]

    def _batch_objects_and_values(
        self,
        objects: Sequence[TableBase],
        field_names: list[str],
        fields: dict[str, Any],
        *,
        extra_params_per_row: int = 0,
    ):
        """
        Helper function to batch objects and their values for database operations.
        Handles batching to stay under PostgreSQL's parameter limits.

        :param objects: Sequence of objects to batch
        :param field_names: List of field names to process
        :param fields: Dictionary of field info
        :param extra_params_per_row: Additional parameters per row beyond the field values
        :return: Generator of (batch_objects, values_list) tuples
        """
        # Calculate max batch size based on number of fields plus any extra parameters
        # Each row uses (len(fields) + extra_params_per_row) parameters
        params_per_row = len(field_names) + extra_params_per_row
        max_batch_size = PG_MAX_PARAMETERS // params_per_row
        # Cap at 5000 rows per batch to avoid excessive memory usage
        max_batch_size = min(max_batch_size, 5000)

        total = len(objects)
        num_batches = ceil(total / max_batch_size)

        for batch_idx in range(num_batches):
            start_idx = batch_idx * max_batch_size
            end_idx = (batch_idx + 1) * max_batch_size
            batch_objects = objects[start_idx:end_idx]

            if not batch_objects:
                continue

            # Convert objects to value lists
            values_list = []
            for obj in batch_objects:
                obj_values = obj.model_dump()
                row_values = []
                for field in field_names:
                    info = fields[field]
                    row_values.append(info.to_db_value(obj_values[field]))
                values_list.append(row_values)

            yield batch_objects, values_list
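For orientation, the sketch below strings the methods above together into a minimal end-to-end flow. It is not part of the released file: the `User` model and connection parameters are illustrative, and the assumption that `DBConnection`, `TableBase`, `Field`, and `select` are re-exported from the top-level `iceaxe` package is based on the docstring examples (the actual exports live in `iceaxe/__init__.py`, whose contents are not shown in this diff).

```python
# Hypothetical usage sketch of DBConnection; names marked "assumed" are not
# verified against this release.
import asyncio

import asyncpg

# Assumed top-level re-exports; see iceaxe/__init__.py in the manifest above.
from iceaxe import DBConnection, Field, TableBase, select


class User(TableBase):
    # Illustrative model, mirroring the docstring examples in session.py.
    id: int = Field(primary_key=True)
    name: str
    email: str


async def main() -> None:
    conn = DBConnection(
        await asyncpg.connect(
            host="localhost",  # placeholder connection parameters
            port=5432,
            user="db_user",
            password="yoursecretpassword",
            database="your_db",
        )
    )
    # Optional: introspect and register type codecs once; later connections to
    # the same DSN reuse the process-wide TYPE_CACHE instead of re-introspecting.
    await conn.initialize_types()

    async with conn.transaction():
        # insert() batches rows per table, keeping each statement under
        # PG_MAX_PARAMETERS (32767) and at most 5000 rows per executemany call.
        user = User(name="Alice", email="alice@example.com")
        await conn.insert([user])

        # exec() returns a list for SELECT queries and None for other query types.
        alices = await conn.exec(select(User).where(User.name == "Alice"))
        print(len(alices))

        # update() writes only the attributes modified since the object was
        # loaded or inserted; insert/update reuse this outer transaction via
        # transaction(ensure=True).
        user.email = "alice@newdomain.com"
        await conn.update([user])

    await conn.close()


asyncio.run(main())
```

As a concrete instance of the batching arithmetic in `_batch_objects_and_values`: a model with 8 insertable fields allows 32767 // 8 = 4095 rows per statement, which is below the 5000-row cap, so 10,000 such objects would be written in three batched statements.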