iceaxe 0.8.3__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (75)
  1. iceaxe/__init__.py +20 -0
  2. iceaxe/__tests__/__init__.py +0 -0
  3. iceaxe/__tests__/benchmarks/__init__.py +0 -0
  4. iceaxe/__tests__/benchmarks/test_bulk_insert.py +45 -0
  5. iceaxe/__tests__/benchmarks/test_select.py +114 -0
  6. iceaxe/__tests__/conf_models.py +133 -0
  7. iceaxe/__tests__/conftest.py +204 -0
  8. iceaxe/__tests__/docker_helpers.py +208 -0
  9. iceaxe/__tests__/helpers.py +268 -0
  10. iceaxe/__tests__/migrations/__init__.py +0 -0
  11. iceaxe/__tests__/migrations/conftest.py +36 -0
  12. iceaxe/__tests__/migrations/test_action_sorter.py +237 -0
  13. iceaxe/__tests__/migrations/test_generator.py +140 -0
  14. iceaxe/__tests__/migrations/test_generics.py +91 -0
  15. iceaxe/__tests__/mountaineer/__init__.py +0 -0
  16. iceaxe/__tests__/mountaineer/dependencies/__init__.py +0 -0
  17. iceaxe/__tests__/mountaineer/dependencies/test_core.py +76 -0
  18. iceaxe/__tests__/schemas/__init__.py +0 -0
  19. iceaxe/__tests__/schemas/test_actions.py +1265 -0
  20. iceaxe/__tests__/schemas/test_cli.py +25 -0
  21. iceaxe/__tests__/schemas/test_db_memory_serializer.py +1571 -0
  22. iceaxe/__tests__/schemas/test_db_serializer.py +435 -0
  23. iceaxe/__tests__/schemas/test_db_stubs.py +190 -0
  24. iceaxe/__tests__/test_alias.py +83 -0
  25. iceaxe/__tests__/test_base.py +52 -0
  26. iceaxe/__tests__/test_comparison.py +383 -0
  27. iceaxe/__tests__/test_field.py +11 -0
  28. iceaxe/__tests__/test_helpers.py +9 -0
  29. iceaxe/__tests__/test_modifications.py +151 -0
  30. iceaxe/__tests__/test_queries.py +764 -0
  31. iceaxe/__tests__/test_queries_str.py +173 -0
  32. iceaxe/__tests__/test_session.py +1511 -0
  33. iceaxe/__tests__/test_text_search.py +287 -0
  34. iceaxe/alias_values.py +67 -0
  35. iceaxe/base.py +351 -0
  36. iceaxe/comparison.py +560 -0
  37. iceaxe/field.py +263 -0
  38. iceaxe/functions.py +1432 -0
  39. iceaxe/generics.py +140 -0
  40. iceaxe/io.py +107 -0
  41. iceaxe/logging.py +91 -0
  42. iceaxe/migrations/__init__.py +5 -0
  43. iceaxe/migrations/action_sorter.py +98 -0
  44. iceaxe/migrations/cli.py +228 -0
  45. iceaxe/migrations/client_io.py +62 -0
  46. iceaxe/migrations/generator.py +404 -0
  47. iceaxe/migrations/migration.py +86 -0
  48. iceaxe/migrations/migrator.py +101 -0
  49. iceaxe/modifications.py +176 -0
  50. iceaxe/mountaineer/__init__.py +10 -0
  51. iceaxe/mountaineer/cli.py +74 -0
  52. iceaxe/mountaineer/config.py +46 -0
  53. iceaxe/mountaineer/dependencies/__init__.py +6 -0
  54. iceaxe/mountaineer/dependencies/core.py +67 -0
  55. iceaxe/postgres.py +133 -0
  56. iceaxe/py.typed +0 -0
  57. iceaxe/queries.py +1459 -0
  58. iceaxe/queries_str.py +294 -0
  59. iceaxe/schemas/__init__.py +0 -0
  60. iceaxe/schemas/actions.py +864 -0
  61. iceaxe/schemas/cli.py +30 -0
  62. iceaxe/schemas/db_memory_serializer.py +711 -0
  63. iceaxe/schemas/db_serializer.py +347 -0
  64. iceaxe/schemas/db_stubs.py +529 -0
  65. iceaxe/session.py +860 -0
  66. iceaxe/session_optimized.c +12207 -0
  67. iceaxe/session_optimized.cpython-313-darwin.so +0 -0
  68. iceaxe/session_optimized.pyx +212 -0
  69. iceaxe/sql_types.py +149 -0
  70. iceaxe/typing.py +73 -0
  71. iceaxe-0.8.3.dist-info/METADATA +262 -0
  72. iceaxe-0.8.3.dist-info/RECORD +75 -0
  73. iceaxe-0.8.3.dist-info/WHEEL +6 -0
  74. iceaxe-0.8.3.dist-info/licenses/LICENSE +21 -0
  75. iceaxe-0.8.3.dist-info/top_level.txt +1 -0
iceaxe/session.py ADDED
@@ -0,0 +1,860 @@
+ from collections import defaultdict
+ from contextlib import asynccontextmanager
+ from json import loads as json_loads
+ from math import ceil
+ from typing import (
+     Any,
+     Literal,
+     ParamSpec,
+     Sequence,
+     Type,
+     TypeVar,
+     cast,
+     overload,
+ )
+
+ import asyncpg
+ from typing_extensions import TypeVarTuple
+
+ from iceaxe.base import DBFieldClassDefinition, TableBase
+ from iceaxe.logging import LOGGER
+ from iceaxe.modifications import ModificationTracker
+ from iceaxe.queries import (
+     QueryBuilder,
+     is_base_table,
+     is_column,
+     is_function_metadata,
+ )
+ from iceaxe.queries_str import QueryIdentifier
+ from iceaxe.session_optimized import optimize_exec_casting
+
+ P = ParamSpec("P")
+ T = TypeVar("T")
+ Ts = TypeVarTuple("Ts")
+
+ TableType = TypeVar("TableType", bound=TableBase)
+
+ # PostgreSQL has a limit of 32767 parameters per query (Short.MAX_VALUE)
+ PG_MAX_PARAMETERS = 32767
+
+ # Maps connection DSN -> introspected type codecs, shared by all connections
+ # for the lifetime of the process (see initialize_types below)
+ TYPE_CACHE = {}
+
+
+ class DBConnection:
+     """
+     Core class for all ORM actions against a PostgreSQL database. Provides high-level methods
+     for executing queries and managing database transactions.
+
+     The DBConnection wraps an asyncpg Connection and provides ORM functionality for:
+     - Executing SELECT/INSERT/UPDATE/DELETE queries
+     - Managing transactions
+     - Inserting, updating, and deleting model instances
+     - Refreshing model instances from the database
+
+     ```python {{sticky: True}}
+     # Create a connection
+     conn = DBConnection(
+         await asyncpg.connect(
+             host="localhost",
+             port=5432,
+             user="db_user",
+             password="yoursecretpassword",
+             database="your_db",
+         )
+     )
+
+     # Use with models
+     class User(TableBase):
+         id: int = Field(primary_key=True)
+         name: str
+         email: str
+
+     # Insert data
+     user = User(name="Alice", email="alice@example.com")
+     await conn.insert([user])
+
+     # Query data
+     users = await conn.exec(
+         select(User)
+         .where(User.name == "Alice")
+     )
+
+     # Update data
+     user.email = "newemail@example.com"
+     await conn.update([user])
+     ```
+     """
+
+     def __init__(
+         self,
+         conn: asyncpg.Connection,
+         *,
+         uncommitted_verbosity: Literal["ERROR", "WARNING", "INFO"] | None = None,
+     ):
+         """
+         Initialize a new database connection wrapper.
+
+         :param conn: An asyncpg Connection instance to wrap
+         :param uncommitted_verbosity: The log level used to report objects that were
+             modified but not committed by the time the session is closed; defaults to None (no logging)
+
+         """
+         self.conn = conn
+         self.obj_to_primary_key: dict[str, str | None] = {}
+         self.in_transaction = False
+         self.modification_tracker = ModificationTracker(uncommitted_verbosity)
+
+     async def initialize_types(self, timeout: float = 60.0) -> None:
+         """
+         Introspect and register PostgreSQL type codecs on this connection,
+         caching the result globally using the connection's DB URL as a key. These types
+         are unlikely to change in the lifetime of a Python process, so this is typically
+         safe to do automatically.
+
+         This method should be called once per connection so we can leverage our own cache. If
+         asyncpg is called directly on a new connection, it will result in its own duplicate
+         type introspection call.
+
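+         A typical setup, sketched for illustration (the connection parameters are
+         placeholders):
+
+         ```python {{sticky: True}}
+         conn = DBConnection(await asyncpg.connect(host="localhost", database="your_db"))
+         # The first connection per DSN pays the introspection query; later
+         # connections to the same database reuse the cached codecs
+         await conn.initialize_types()
+         ```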
+         """
+         global TYPE_CACHE
+
+         if not self.conn._protocol:
+             LOGGER.warning(
+                 "No protocol found for connection during type introspection, will fall back to asyncpg"
+             )
+             return
+
+         # Determine a unique key for this connection.
+         db_url = self.get_dsn()
+
+         # If we've already cached the type information for this DB URL, just register it.
+         if db_url in TYPE_CACHE:
+             self.conn._protocol.get_settings().register_data_types(TYPE_CACHE[db_url])
+             return
+
+         # Get the connection settings object (this is where type codecs are registered).
+         settings = self.conn._protocol.get_settings()
+
+         # Query PostgreSQL to get all type OIDs from non-system schemas.
+         rows = await self.conn.fetch(
+             """
+             SELECT t.oid
+             FROM pg_type t
+             JOIN pg_namespace n ON t.typnamespace = n.oid
+             WHERE n.nspname NOT IN ('pg_catalog', 'information_schema')
+             """
+         )
+         # Build a set of type OIDs.
+         typeoids = {row["oid"] for row in rows}
+
+         # Introspect types – this call will recursively determine the PostgreSQL types needed.
+         types, intro_stmt = await self.conn._introspect_types(typeoids, timeout)
+
+         # Register the introspected types with the connection's settings.
+         settings.register_data_types(types)
+
+         # Cache the types globally so that future connections using the same DB URL
+         # can simply register the cached codecs.
+         TYPE_CACHE[db_url] = types
+
+     def get_dsn(self) -> str:
+         """
+         Get the DSN (Data Source Name) string for this connection.
+
+         :return: DSN string in the format 'postgresql://user:password@host:port/dbname'
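+             (for a local connection with no password set, this renders as
+             'postgresql://db_user@localhost:5432/your_db')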
+         """
+         params = self.conn._params
+         addr = self.conn._addr
+
+         # Build the DSN string with all available parameters
+         dsn_parts = ["postgresql://"]
+
+         # Add user/password if available
+         if params.user:
+             dsn_parts.append(params.user)
+             if params.password:
+                 dsn_parts.append(f":{params.password}")
+             dsn_parts.append("@")
+
+         # Add host/port
+         dsn_parts.append(addr[0])
+         if addr[1]:
+             dsn_parts.append(f":{addr[1]}")
+
+         # Add database name
+         if params.database:
+             dsn_parts.append(f"/{params.database}")
+
+         return "".join(dsn_parts)
+
+     @asynccontextmanager
+     async def transaction(self, *, ensure: bool = False):
+         """
+         Context manager for managing database transactions. Ensures that a series of database
+         operations are executed atomically.
+
+         :param ensure: If True and already in a transaction, the context manager will yield without creating a new transaction.
+             If False (default) and already in a transaction, raises a RuntimeError.
+
+         ```python {{sticky: True}}
+         async with conn.transaction():
+             # All operations here are executed in a transaction
+             user = User(name="Alice", email="alice@example.com")
+             await conn.insert([user])
+
+             post = Post(title="Hello", user_id=user.id)
+             await conn.insert([post])
+
+             # If any operation fails, all changes are rolled back
+         ```
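+
+         Passing ensure=True joins an in-flight transaction instead of raising;
+         the bulk helpers (insert, update, upsert, delete) rely on this
+         internally. An illustrative sketch:
+
+         ```python {{sticky: True}}
+         async with conn.transaction():
+             user = User(name="Alice", email="alice@example.com")
+             # insert() itself calls transaction(ensure=True), so it joins
+             # this outer transaction instead of raising a RuntimeError
+             await conn.insert([user])
+             async with conn.transaction(ensure=True):
+                 user.email = "alice@corp.example.com"
+                 await conn.update([user])
+         ```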
+         """
+         # If ensure is True and we're already in a transaction, just yield
+         if self.in_transaction:
+             if ensure:
+                 yield
+                 return
+             else:
+                 raise RuntimeError(
+                     "Cannot start a new transaction while already in a transaction. Use ensure=True if this is intentional."
+                 )
+
+         # Otherwise, start a new transaction
+         self.in_transaction = True
+         async with self.conn.transaction():
+             try:
+                 yield
+             finally:
+                 self.in_transaction = False
+
+     @overload
+     async def exec(self, query: QueryBuilder[T, Literal["SELECT"]]) -> list[T]: ...
+
+     @overload
+     async def exec(self, query: QueryBuilder[T, Literal["INSERT"]]) -> None: ...
+
+     @overload
+     async def exec(self, query: QueryBuilder[T, Literal["UPDATE"]]) -> None: ...
+
+     @overload
+     async def exec(self, query: QueryBuilder[T, Literal["DELETE"]]) -> None: ...
+
+     async def exec(
+         self,
+         query: QueryBuilder[T, Literal["SELECT"]]
+         | QueryBuilder[T, Literal["INSERT"]]
+         | QueryBuilder[T, Literal["UPDATE"]]
+         | QueryBuilder[T, Literal["DELETE"]],
+     ) -> list[T] | None:
+         """
+         Execute a query built with QueryBuilder and return the results.
+
+         ```python {{sticky: True}}
+         # Select query
+         users = await conn.exec(
+             select(User)
+             .where(User.age >= 18)
+             .order_by(User.name)
+         )
+
+         # Select with joins and aggregates
+         results = await conn.exec(
+             select((User.name, func.count(Order.id)))
+             .join(Order, Order.user_id == User.id)
+             .group_by(User.name)
+             .having(func.count(Order.id) > 5)
+         )
+
+         # Delete query
+         await conn.exec(
+             delete(User)
+             .where(User.is_active == False)
+         )
+         ```
+
+         :param query: A QueryBuilder instance representing the query to execute
+         :return: For SELECT queries, returns a list of results. For other queries, returns None
+
+         """
+         sql_text, variables = query.build()
+         LOGGER.debug(f"Executing query: {sql_text} with variables: {variables}")
+         try:
+             values = await self.conn.fetch(sql_text, *variables)
+         except Exception:
+             LOGGER.error(
+                 f"Error executing query: {sql_text} with variables: {variables}"
+             )
+             raise
+
+         if query._query_type == "SELECT":
+             # Pre-cache the select types for better performance
+             select_types = [
+                 (
+                     is_base_table(select_raw),
+                     is_column(select_raw),
+                     is_function_metadata(select_raw),
+                 )
+                 for select_raw in query._select_raw
+             ]
+
+             result_all = optimize_exec_casting(values, query._select_raw, select_types)
+
+             # Only loop through results if we have verbosity enabled, since this logic otherwise
+             # is wasted if no content will eventually be logged
+             if self.modification_tracker.verbosity:
+                 for row in result_all:
+                     elements = row if isinstance(row, tuple) else (row,)
+                     for element in elements:
+                         if isinstance(element, TableBase):
+                             element.register_modified_callback(
+                                 self.modification_tracker.track_modification
+                             )
+
+             return cast(list[T], result_all)
+
+         return None
+
+     async def insert(self, objects: Sequence[TableBase]):
+         """
+         Insert one or more model instances into the database. If the model has an auto-incrementing
+         primary key, it will be populated on the instances after insertion.
+
+         ```python {{sticky: True}}
+         # Insert a single object
+         user = User(name="Alice", email="alice@example.com")
+         await conn.insert([user])
+         print(user.id)  # Auto-populated primary key
+
+         # Insert multiple objects
+         users = [
+             User(name="Bob", email="bob@example.com"),
+             User(name="Charlie", email="charlie@example.com")
+         ]
+         await conn.insert(users)
+         ```
+
+         :param objects: A sequence of TableBase instances to insert
+
+         """
+         if not objects:
+             return
+
+         # Reuse a single transaction for all inserts
+         async with self.transaction(ensure=True):
+             for model, model_objects in self._aggregate_models_by_table(objects):
+                 # For each table, build batched insert queries
+                 table_name = QueryIdentifier(model.get_table_name())
+                 fields = {
+                     field: info
+                     for field, info in model.model_fields.items()
+                     if (not info.exclude and not info.autoincrement)
+                 }
+                 primary_key = self._get_primary_key(model)
+                 field_names = list(fields.keys())
+                 field_identifiers = ", ".join(f'"{f}"' for f in field_names)
+
+                 # Build the base query
+                 if primary_key:
+                     query = f"""
+                         INSERT INTO {table_name} ({field_identifiers})
+                         VALUES ({", ".join(f"${i}" for i in range(1, len(field_names) + 1))})
+                         RETURNING {primary_key}
+                     """
+                 else:
+                     query = f"""
+                         INSERT INTO {table_name} ({field_identifiers})
+                         VALUES ({", ".join(f"${i}" for i in range(1, len(field_names) + 1))})
+                     """
+
+                 for batch_objects, values_list in self._batch_objects_and_values(
+                     model_objects, field_names, fields
+                 ):
+                     # Insert them in one go
+                     if primary_key:
+                         # For returning queries, we can use fetchmany to get the primary keys
+                         rows = await self.conn.fetchmany(query, values_list)
+                         for obj, row in zip(batch_objects, rows):
+                             setattr(obj, primary_key, row[primary_key])
+                     else:
+                         # For non-returning queries, we can use executemany
+                         await self.conn.executemany(query, values_list)
+
+                     # Mark as unmodified
+                     for obj in batch_objects:
+                         obj.clear_modified_attributes()
+
+         # Register modification callbacks outside the main insert loop
+         if self.modification_tracker.verbosity:
+             for obj in objects:
+                 obj.register_modified_callback(
+                     self.modification_tracker.track_modification
+                 )
+
+         # Clear modification status
+         self.modification_tracker.clear_status(objects)
+
+     @overload
+     async def upsert(
+         self,
+         objects: Sequence[TableBase],
+         *,
+         conflict_fields: tuple[Any, ...],
+         update_fields: tuple[Any, ...] | None = None,
+         returning_fields: None = None,
+     ) -> None: ...
+
+     @overload
+     async def upsert(
+         self,
+         objects: Sequence[TableBase],
+         *,
+         conflict_fields: tuple[Any, ...],
+         update_fields: tuple[Any, ...] | None = None,
+         returning_fields: tuple[T, *Ts],
+     ) -> list[tuple[T, *Ts]]: ...
+
+     async def upsert(
+         self,
+         objects: Sequence[TableBase],
+         *,
+         conflict_fields: tuple[Any, ...],
+         update_fields: tuple[Any, ...] | None = None,
+         returning_fields: tuple[T, *Ts] | None = None,
+     ) -> list[tuple[T, *Ts]] | None:
+         """
+         Performs an upsert (INSERT ... ON CONFLICT DO UPDATE) operation for the given objects.
+         This is useful when you want to insert records but update them if they already exist.
+
+         ```python {{sticky: True}}
+         # Simple upsert based on email
+         users = [
+             User(email="alice@example.com", name="Alice"),
+             User(email="bob@example.com", name="Bob")
+         ]
+         await conn.upsert(
+             users,
+             conflict_fields=(User.email,),
+             update_fields=(User.name,)
+         )
+
+         # Upsert with returning values
+         results = await conn.upsert(
+             users,
+             conflict_fields=(User.email,),
+             update_fields=(User.name,),
+             returning_fields=(User.id, User.email)
+         )
+         for user_id, email in results:
+             print(f"Upserted user {email} with ID {user_id}")
+         ```
+
+         :param objects: Sequence of TableBase objects to upsert
+         :param conflict_fields: Fields to check for conflicts (ON CONFLICT)
+         :param update_fields: Fields to update on conflict. If None, updates all non-excluded fields
+         :param returning_fields: Fields to return after the operation. If None, returns nothing
+         :return: List of tuples containing the returned fields if returning_fields is specified
+
+         """
+         if not objects:
+             return None
+
+         # Evaluate column types
+         conflict_fields_cols: list[DBFieldClassDefinition] = []
+         update_fields_cols: list[DBFieldClassDefinition] = []
+         returning_fields_cols: list[DBFieldClassDefinition] = []
+
+         # Explicitly validate types of all columns
+         for field in conflict_fields:
+             if is_column(field):
+                 conflict_fields_cols.append(field)
+             else:
+                 raise ValueError(f"Field {field} is not a column")
+         for field in update_fields or []:
+             if is_column(field):
+                 update_fields_cols.append(field)
+             else:
+                 raise ValueError(f"Field {field} is not a column")
+         for field in returning_fields or []:
+             if is_column(field):
+                 returning_fields_cols.append(field)
+             else:
+                 raise ValueError(f"Field {field} is not a column")
+
+         results: list[tuple[T, *Ts]] = []
+         async with self.transaction(ensure=True):
+             for model, model_objects in self._aggregate_models_by_table(objects):
+                 table_name = QueryIdentifier(model.get_table_name())
+                 fields = {
+                     field: info
+                     for field, info in model.model_fields.items()
+                     if (not info.exclude and not info.autoincrement)
+                 }
+
+                 field_string = ", ".join(f'"{field}"' for field in fields)
+                 placeholders = ", ".join(f"${i}" for i in range(1, len(fields) + 1))
+                 query = (
+                     f"INSERT INTO {table_name} ({field_string}) VALUES ({placeholders})"
+                 )
+                 if conflict_fields_cols:
+                     conflict_field_string = ", ".join(
+                         f'"{field.key}"' for field in conflict_fields_cols
+                     )
+                     query += f" ON CONFLICT ({conflict_field_string})"
+
+                     if update_fields_cols:
+                         set_values = ", ".join(
+                             f'"{field.key}" = EXCLUDED."{field.key}"'
+                             for field in update_fields_cols
+                         )
+                         query += f" DO UPDATE SET {set_values}"
+                     else:
+                         query += " DO NOTHING"
+
+                 if returning_fields_cols:
+                     returning_string = ", ".join(
+                         f'"{field.key}"' for field in returning_fields_cols
+                     )
+                     query += f" RETURNING {returning_string}"
+
+                 # Execute in batches
+                 for batch_objects, values_list in self._batch_objects_and_values(
+                     model_objects, list(fields.keys()), fields
+                 ):
+                     if returning_fields_cols:
+                         # For returning queries, we need to use fetchmany to get all results
+                         rows = await self.conn.fetchmany(query, values_list)
+                         for row in rows:
+                             if row:
+                                 # Process returned values, deserializing JSON if needed
+                                 processed_values = []
+                                 for field in returning_fields_cols:
+                                     value = row[field.key]
+                                     if (
+                                         value is not None
+                                         and field.root_model.model_fields[
+                                             field.key
+                                         ].is_json
+                                     ):
+                                         value = json_loads(value)
+                                     processed_values.append(value)
+                                 results.append(tuple(processed_values))
+                     else:
+                         # For non-returning queries, we can use executemany
+                         await self.conn.executemany(query, values_list)
+
+                     # Clear modified state for successfully upserted objects
+                     for obj in batch_objects:
+                         obj.clear_modified_attributes()
+
+         self.modification_tracker.clear_status(objects)
+
+         return results if returning_fields_cols else None
+
+     async def update(self, objects: Sequence[TableBase]):
+         """
+         Update one or more model instances in the database. Only modified attributes will be updated.
+         Updates are batched together by grouping objects with the same modified fields, then using
+         executemany() for efficiency.
+
+         ```python {{sticky: True}}
+         # Update a single object (exec returns a list, so unpack the single row)
+         [user] = await conn.exec(select(User).where(User.id == 1))
+         user.name = "New Name"
+         await conn.update([user])
+
+         # Update multiple objects
+         users = await conn.exec(select(User).where(User.age < 18))
+         for user in users:
+             user.is_minor = True
+         await conn.update(users)
+         ```
+
+         :param objects: A sequence of TableBase instances to update
+         """
+         if not objects:
+             return
+
+         async with self.transaction(ensure=True):
+             for model, model_objects in self._aggregate_models_by_table(objects):
+                 table_name = QueryIdentifier(model.get_table_name())
+                 primary_key = self._get_primary_key(model)
+
+                 if not primary_key:
+                     raise ValueError(
+                         f"Model {model} has no primary key, required to UPDATE with ORM objects"
+                     )
+
+                 primary_key_name = QueryIdentifier(primary_key)
+
+                 # Group objects by their modified fields to batch similar updates
+                 updates_by_fields: defaultdict[frozenset[str], list[TableBase]] = (
+                     defaultdict(list)
+                 )
+                 for obj in model_objects:
+                     modified_attrs = frozenset(
+                         k
+                         for k, v in obj.get_modified_attributes().items()
+                         if not obj.__class__.model_fields[k].exclude
+                     )
+                     if modified_attrs:
+                         updates_by_fields[modified_attrs].append(obj)
+
+                 # Process each group of objects with the same modified fields
+                 for modified_fields, group_objects in updates_by_fields.items():
+                     if not modified_fields:
+                         continue
+
+                     # Build the UPDATE query for this group
+                     field_names = list(modified_fields)
+                     fields = {field: model.model_fields[field] for field in field_names}
+
+                     # Build the UPDATE query - note we need one extra parameter per row for the WHERE clause
+                     set_clause = ", ".join(
+                         f"{QueryIdentifier(key)} = ${i + 2}"
+                         for i, key in enumerate(field_names)
+                     )
+                     query = f"UPDATE {table_name} SET {set_clause} WHERE {primary_key_name} = $1"
+
+                     for batch_objects, values_list in self._batch_objects_and_values(
+                         group_objects,
+                         field_names,
+                         fields,
+                         extra_params_per_row=1,  # For the WHERE primary_key parameter
+                     ):
+                         # Add primary key as first parameter for each row
+                         for i, obj in enumerate(batch_objects):
+                             values_list[i].insert(0, getattr(obj, primary_key))
+
+                         # Execute the batch update
+                         await self.conn.executemany(query, values_list)
+
+                         # Clear modified state for successfully updated objects
+                         for obj in batch_objects:
+                             obj.clear_modified_attributes()
+
+         self.modification_tracker.clear_status(objects)
+
+     async def delete(self, objects: Sequence[TableBase]):
+         """
+         Delete one or more model instances from the database.
+
+         ```python {{sticky: True}}
+         # Delete a single object (exec returns a list, so unpack the single row)
+         [user] = await conn.exec(select(User).where(User.id == 1))
+         await conn.delete([user])
+
+         # Delete multiple objects
+         inactive_users = await conn.exec(
+             select(User).where(User.last_login < datetime.now() - timedelta(days=90))
+         )
+         await conn.delete(inactive_users)
+         ```
+
+         :param objects: A sequence of TableBase instances to delete
+
+         """
+         async with self.transaction(ensure=True):
+             for model, model_objects in self._aggregate_models_by_table(objects):
+                 table_name = QueryIdentifier(model.get_table_name())
+                 primary_key = self._get_primary_key(model)
+
+                 if not primary_key:
+                     raise ValueError(
+                         f"Model {model} has no primary key, required to DELETE with ORM objects"
+                     )
+
+                 primary_key_name = QueryIdentifier(primary_key)
+
+                 for obj in model_objects:
+                     query = f"DELETE FROM {table_name} WHERE {primary_key_name} = $1"
+                     await self.conn.execute(query, getattr(obj, primary_key))
+
+         self.modification_tracker.clear_status(objects)
+
+     async def refresh(self, objects: Sequence[TableBase]):
+         """
+         Refresh one or more model instances from the database, updating their attributes
+         with the current database values.
+
+         ```python {{sticky: True}}
+         # Refresh a single object (exec returns a list, so unpack the single row)
+         [user] = await conn.exec(select(User).where(User.id == 1))
+         # ... some time passes, database might have changed
+         await conn.refresh([user])  # User now has current database values
+
+         # Refresh multiple objects
+         users = await conn.exec(select(User).where(User.department == "Sales"))
+         # ... after some time
+         await conn.refresh(users)  # All users now have current values
+         ```
+
+         :param objects: A sequence of TableBase instances to refresh
+
+         """
+         for model, model_objects in self._aggregate_models_by_table(objects):
+             table_name = QueryIdentifier(model.get_table_name())
+             primary_key = self._get_primary_key(model)
+             fields = [
+                 field for field, info in model.model_fields.items() if not info.exclude
+             ]
+
+             if not primary_key:
+                 raise ValueError(
+                     f"Model {model} has no primary key, required to REFRESH with ORM objects"
+                 )
+
+             primary_key_name = QueryIdentifier(primary_key)
+             object_ids = {getattr(obj, primary_key) for obj in model_objects}
+
+             query = f"SELECT * FROM {table_name} WHERE {primary_key_name} = ANY($1)"
+             results = {
+                 result[primary_key]: result
+                 for result in await self.conn.fetch(query, list(object_ids))
+             }
+
+             # Update the objects in-place
+             for obj in model_objects:
+                 obj_id = getattr(obj, primary_key)
+                 if obj_id in results:
+                     # Update field-by-field
+                     for field in fields:
+                         setattr(obj, field, results[obj_id][field])
+                 else:
+                     LOGGER.error(
+                         f"Object {obj} with primary key {obj_id} not found in database"
+                     )
+
+         # When an object is refreshed, it's fully overwritten with the new data so by
+         # definition it's no longer modified
+         for obj in objects:
+             obj.clear_modified_attributes()
+
+         self.modification_tracker.clear_status(objects)
+
+     async def get(
+         self, model: Type[TableType], primary_key_value: Any
+     ) -> TableType | None:
+         """
+         Retrieve a single model instance by its primary key value.
+
+         This method provides a convenient way to fetch a single record from the database using its primary key.
+         It automatically constructs and executes a SELECT query with a WHERE clause matching the primary key.
+
+         ```python {{sticky: True}}
+         class User(TableBase):
+             id: int = Field(primary_key=True)
+             name: str
+             email: str
+
+         # Fetch a user by ID
+         user = await db_connection.get(User, 1)
+         if user:
+             print(f"Found user: {user.name}")
+         else:
+             print("User not found")
+         ```
+
+         :param model: The model class to query (must be a subclass of TableBase)
+         :param primary_key_value: The value of the primary key to look up
+         :return: The model instance if found, None if no record matches the primary key
+         :raises ValueError: If the model has no primary key defined
+
+         """
+         primary_key = self._get_primary_key(model)
+         if not primary_key:
+             raise ValueError(
+                 f"Model {model} has no primary key, required to GET with ORM objects"
+             )
+
+         query_builder = QueryBuilder()
+         query = query_builder.select(model).where(
+             getattr(model, primary_key) == primary_key_value
+         )
+         results = await self.exec(query)
+         return results[0] if results else None
+
+     async def close(self):
+         """
+         Close the database connection and log any objects that were modified
+         but never committed back to the database.
+         """
+         await self.conn.close()
+         self.modification_tracker.log()
+
+     def _aggregate_models_by_table(self, objects: Sequence[TableBase]):
+         """
+         Group model instances by their table class for batch operations.
+
+         :param objects: Sequence of TableBase instances to group
+         :return: Iterator of (model_class, list_of_instances) pairs
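+             (e.g. [user_a, post_a, user_b] yields User -> [user_a, user_b] and
+             Post -> [post_a])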
+         """
+         objects_by_class: defaultdict[Type[TableBase], list[TableBase]] = defaultdict(
+             list
+         )
+         for obj in objects:
+             objects_by_class[obj.__class__].append(obj)
+
+         return objects_by_class.items()
+
+     def _get_primary_key(self, obj: Type[TableBase]) -> str | None:
+         """
+         Get the primary key field name for a model class, with caching.
+
+         :param obj: The model class to get the primary key for
+         :return: The name of the primary key field, or None if no primary key exists
+         """
+         table_name = obj.get_table_name()
+         if table_name not in self.obj_to_primary_key:
+             primary_key = [
+                 field for field, info in obj.model_fields.items() if info.primary_key
+             ]
+             self.obj_to_primary_key[table_name] = (
+                 primary_key[0] if primary_key else None
+             )
+         return self.obj_to_primary_key[table_name]
+
+     def _batch_objects_and_values(
+         self,
+         objects: Sequence[TableBase],
+         field_names: list[str],
+         fields: dict[str, Any],
+         *,
+         extra_params_per_row: int = 0,
+     ):
+         """
+         Helper function to batch objects and their values for database operations.
+         Handles batching to stay under PostgreSQL's parameter limits.
+
+         :param objects: Sequence of objects to batch
+         :param field_names: List of field names to process
+         :param fields: Dictionary of field info
+         :param extra_params_per_row: Additional parameters per row beyond the field values
+         :return: Generator of (batch_objects, values_list) tuples
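+
+         For example, updating rows with 6 modified fields plus 1 extra WHERE
+         parameter per row uses 7 parameters per row, so the batch size works
+         out to min(32767 // 7, 5000) = 4681 rows.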
+         """
+         # Calculate max batch size based on number of fields plus any extra parameters
+         # Each row uses (len(fields) + extra_params_per_row) parameters
+         params_per_row = len(field_names) + extra_params_per_row
+         max_batch_size = PG_MAX_PARAMETERS // params_per_row
+         # Cap at 5000 rows per batch to avoid excessive memory usage
+         max_batch_size = min(max_batch_size, 5000)
+
+         total = len(objects)
+         num_batches = ceil(total / max_batch_size)
+
+         for batch_idx in range(num_batches):
+             start_idx = batch_idx * max_batch_size
+             end_idx = (batch_idx + 1) * max_batch_size
+             batch_objects = objects[start_idx:end_idx]
+
+             if not batch_objects:
+                 continue
+
+             # Convert objects to value lists
+             values_list = []
+             for obj in batch_objects:
+                 obj_values = obj.model_dump()
+                 row_values = []
+                 for field in field_names:
+                     info = fields[field]
+                     row_values.append(info.to_db_value(obj_values[field]))
+                 values_list.append(row_values)
+
+             yield batch_objects, values_list