iceaxe 0.7.1__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of iceaxe might be problematic. Click here for more details.

Files changed (75) hide show
  1. iceaxe/__init__.py +20 -0
  2. iceaxe/__tests__/__init__.py +0 -0
  3. iceaxe/__tests__/benchmarks/__init__.py +0 -0
  4. iceaxe/__tests__/benchmarks/test_bulk_insert.py +45 -0
  5. iceaxe/__tests__/benchmarks/test_select.py +114 -0
  6. iceaxe/__tests__/conf_models.py +133 -0
  7. iceaxe/__tests__/conftest.py +204 -0
  8. iceaxe/__tests__/docker_helpers.py +208 -0
  9. iceaxe/__tests__/helpers.py +268 -0
  10. iceaxe/__tests__/migrations/__init__.py +0 -0
  11. iceaxe/__tests__/migrations/conftest.py +36 -0
  12. iceaxe/__tests__/migrations/test_action_sorter.py +237 -0
  13. iceaxe/__tests__/migrations/test_generator.py +140 -0
  14. iceaxe/__tests__/migrations/test_generics.py +91 -0
  15. iceaxe/__tests__/mountaineer/__init__.py +0 -0
  16. iceaxe/__tests__/mountaineer/dependencies/__init__.py +0 -0
  17. iceaxe/__tests__/mountaineer/dependencies/test_core.py +76 -0
  18. iceaxe/__tests__/schemas/__init__.py +0 -0
  19. iceaxe/__tests__/schemas/test_actions.py +1264 -0
  20. iceaxe/__tests__/schemas/test_cli.py +25 -0
  21. iceaxe/__tests__/schemas/test_db_memory_serializer.py +1525 -0
  22. iceaxe/__tests__/schemas/test_db_serializer.py +398 -0
  23. iceaxe/__tests__/schemas/test_db_stubs.py +190 -0
  24. iceaxe/__tests__/test_alias.py +83 -0
  25. iceaxe/__tests__/test_base.py +52 -0
  26. iceaxe/__tests__/test_comparison.py +383 -0
  27. iceaxe/__tests__/test_field.py +11 -0
  28. iceaxe/__tests__/test_helpers.py +9 -0
  29. iceaxe/__tests__/test_modifications.py +151 -0
  30. iceaxe/__tests__/test_queries.py +605 -0
  31. iceaxe/__tests__/test_queries_str.py +173 -0
  32. iceaxe/__tests__/test_session.py +1511 -0
  33. iceaxe/__tests__/test_text_search.py +287 -0
  34. iceaxe/alias_values.py +67 -0
  35. iceaxe/base.py +350 -0
  36. iceaxe/comparison.py +560 -0
  37. iceaxe/field.py +250 -0
  38. iceaxe/functions.py +906 -0
  39. iceaxe/generics.py +140 -0
  40. iceaxe/io.py +107 -0
  41. iceaxe/logging.py +91 -0
  42. iceaxe/migrations/__init__.py +5 -0
  43. iceaxe/migrations/action_sorter.py +98 -0
  44. iceaxe/migrations/cli.py +228 -0
  45. iceaxe/migrations/client_io.py +62 -0
  46. iceaxe/migrations/generator.py +404 -0
  47. iceaxe/migrations/migration.py +86 -0
  48. iceaxe/migrations/migrator.py +101 -0
  49. iceaxe/modifications.py +176 -0
  50. iceaxe/mountaineer/__init__.py +10 -0
  51. iceaxe/mountaineer/cli.py +74 -0
  52. iceaxe/mountaineer/config.py +46 -0
  53. iceaxe/mountaineer/dependencies/__init__.py +6 -0
  54. iceaxe/mountaineer/dependencies/core.py +67 -0
  55. iceaxe/postgres.py +133 -0
  56. iceaxe/py.typed +0 -0
  57. iceaxe/queries.py +1455 -0
  58. iceaxe/queries_str.py +294 -0
  59. iceaxe/schemas/__init__.py +0 -0
  60. iceaxe/schemas/actions.py +864 -0
  61. iceaxe/schemas/cli.py +30 -0
  62. iceaxe/schemas/db_memory_serializer.py +705 -0
  63. iceaxe/schemas/db_serializer.py +346 -0
  64. iceaxe/schemas/db_stubs.py +525 -0
  65. iceaxe/session.py +860 -0
  66. iceaxe/session_optimized.c +12035 -0
  67. iceaxe/session_optimized.cpython-313-darwin.so +0 -0
  68. iceaxe/session_optimized.pyx +212 -0
  69. iceaxe/sql_types.py +148 -0
  70. iceaxe/typing.py +73 -0
  71. iceaxe-0.7.1.dist-info/METADATA +261 -0
  72. iceaxe-0.7.1.dist-info/RECORD +75 -0
  73. iceaxe-0.7.1.dist-info/WHEEL +6 -0
  74. iceaxe-0.7.1.dist-info/licenses/LICENSE +21 -0
  75. iceaxe-0.7.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,287 @@
1
+ from typing import Optional, TypeVar, cast
2
+
3
+ import pytest
4
+
5
+ from iceaxe import Field, TableBase, alias, func, select
6
+ from iceaxe.field import DBFieldInfo
7
+ from iceaxe.postgres import LexemePriority, PostgresFullText
8
+ from iceaxe.session import DBConnection
9
+
10
# NOTE(review): this TypeVar is not referenced anywhere else in this test module.
T = TypeVar("T")
11
+
12
+
13
class Article(TableBase):
    """Article table used by the full-text search tests.

    ``title``, ``content`` and ``summary`` each carry an English
    ``PostgresFullText`` configuration, weighted "A", "B" and "C"
    respectively. ``summary`` is optional and defaults to ``None``.
    """

    id: int = Field(primary_key=True)
    title: str = Field(
        postgres_config=PostgresFullText(language="english", weight="A"),
    )
    content: str = Field(
        postgres_config=PostgresFullText(language="english", weight="B"),
    )
    summary: Optional[str] = Field(
        postgres_config=PostgresFullText(language="english", weight="C"),
        default=None,
    )
24
+
25
+
26
@pytest.mark.asyncio
async def test_basic_text_search(indexed_db_connection: DBConnection):
    """Match a single-term query against the title and content vectors separately."""
    conn = indexed_db_connection

    # (id, title, content) for each seeded article.
    seed_rows = (
        (1, "Python Programming", "Learn Python programming basics"),
        (2, "Database Design", "Python and database design patterns"),
        (3, "Web Development", "Building web apps with Python"),
    )
    await conn.insert(
        [
            Article(id=article_id, title=title, content=content)
            for article_id, title, content in seed_rows
        ]
    )

    python_query = func.to_tsquery("english", "python")

    # Title-only search: just the first article has the term in its title.
    title_matches = func.to_tsvector("english", Article.title).matches(python_query)
    title_hits = await conn.exec(select(Article).where(title_matches))
    assert len(title_hits) == 1
    assert title_hits[0].id == 1

    # Content-only search: every seeded article mentions the term in its content.
    content_matches = func.to_tsvector("english", Article.content).matches(
        python_query
    )
    content_hits = await conn.exec(select(Article).where(content_matches))
    assert len(content_hits) == 3
58
+
59
+
60
@pytest.mark.asyncio
async def test_complex_text_search(indexed_db_connection: DBConnection):
    """Exercise the AND, OR and NOT tsquery operators against article titles."""
    conn = indexed_db_connection
    await conn.insert(
        [
            Article(
                id=1, title="Python Programming", content="Learn programming basics"
            ),
            Article(
                id=2, title="Python Advanced", content="Advanced programming concepts"
            ),
            Article(
                id=3,
                title="JavaScript Basics",
                content="Learn programming with JavaScript",
            ),
        ]
    )

    title_vector = func.to_tsvector("english", Article.title)

    async def search_titles(expression: str):
        """Run ``expression`` as an English tsquery over the title vector."""
        ts_query = func.to_tsquery("english", expression)
        return await conn.exec(select(Article).where(title_vector.matches(ts_query)))

    # AND: both terms must appear in the title.
    both_terms = await search_titles("python & programming")
    assert len(both_terms) == 1
    assert both_terms[0].id == 1

    # OR: either term is enough.
    either_term = await search_titles("python | javascript")
    assert len(either_term) == 3
    assert {row.id for row in either_term} == {1, 2, 3}

    # NOT: no title contains "programming" without also containing "python".
    without_python = await search_titles("programming & !python")
    assert len(without_python) == 0
96
+
97
+
98
@pytest.mark.asyncio
async def test_combined_field_search(indexed_db_connection: DBConnection):
    """Search title, content and summary together via list and concat syntax."""
    conn = indexed_db_connection

    guide_article = Article(
        id=1,
        title="Python Guide",
        content="Learn programming basics",
        summary="A beginner's guide to Python",
    )
    tips_article = Article(
        id=2,
        title="Programming Tips",
        content="Python best practices",
        summary="Advanced Python concepts",
    )
    await conn.insert([guide_article, tips_article])

    ts_query = func.to_tsquery("english", "python & guide")

    # List syntax: one to_tsvector call over all three columns.
    list_vector = func.to_tsvector(
        "english", [Article.title, Article.content, Article.summary]
    )
    list_hits = await conn.exec(select(Article).where(list_vector.matches(ts_query)))
    assert len(list_hits) == 1
    # Only the first article contains both "python" and "guide".
    assert list_hits[0].id == 1

    # Concatenation syntax: one vector per column, chained with .concat().
    chained_vector = func.to_tsvector("english", Article.title)
    for column in (Article.content, Article.summary):
        chained_vector = chained_vector.concat(func.to_tsvector("english", column))

    concat_hits = await conn.exec(
        select(Article).where(chained_vector.matches(ts_query))
    )
    # Both approaches are expected to return the same single row.
    assert len(concat_hits) == 1
    assert concat_hits[0].id == 1
143
+
144
+
145
@pytest.mark.asyncio
async def test_weighted_text_search(indexed_db_connection: DBConnection):
    """Rank matches using per-column weights given as string literals."""
    conn = indexed_db_connection
    await conn.insert(
        [
            # Term pair sits in the title, which is weighted "A" below.
            Article(
                id=1,
                title="Python Guide",
                content="Basic Python",
                summary="Python tutorial",
            ),
            # Term pair sits in the content, which is weighted "B" below.
            Article(
                id=2,
                title="Programming",
                content="Python Guide",
                summary="Guide to programming",
            ),
        ]
    )

    weighted_title = func.setweight(func.to_tsvector("english", Article.title), "A")
    weighted_content = func.setweight(
        func.to_tsvector("english", Article.content), "B"
    )
    weighted_summary = func.setweight(
        func.to_tsvector("english", Article.summary), "C"
    )
    combined_vector = weighted_title.concat(weighted_content).concat(weighted_summary)
    ts_query = func.to_tsquery("english", "python & guide")

    ranked_select = (
        select((Article, alias("ts_rank", func.ts_rank(combined_vector, ts_query))))
        .where(combined_vector.matches(ts_query))
        .order_by("ts_rank", direction="DESC")
    )
    ranked = await conn.exec(ranked_select)

    assert len(ranked) == 2
    best, runner_up = ranked[0], ranked[1]
    # The title match (weight A) is expected to outrank the content match.
    assert best[0].id == 1
    assert runner_up[0].id == 2
    assert best[1] > runner_up[1]
183
+
184
+
185
@pytest.mark.asyncio
async def test_weight_priority_variants(indexed_db_connection: DBConnection):
    """Check that ``LexemePriority`` members and string weights are interchangeable."""
    conn = indexed_db_connection

    seeded_articles = [
        Article(
            id=1,
            title="Python Guide",
            content="Basic Python",
            summary="Python tutorial",
        ),
        Article(
            id=2,
            title="Programming",
            content="Python Guide",
            summary="Guide to programming",
        ),
    ]

    # Same columns as Article, but weights are given as enum members.
    class ArticleWithEnum(TableBase):
        id: int = Field(primary_key=True)
        title: str = Field(
            postgres_config=PostgresFullText(
                language="english", weight=LexemePriority.HIGHEST
            )
        )
        content: str = Field(
            postgres_config=PostgresFullText(
                language="english", weight=LexemePriority.HIGH
            )
        )
        summary: Optional[str] = Field(
            default=None,
            postgres_config=PostgresFullText(
                language="english", weight=LexemePriority.LOW
            ),
        )

    def configured_weight(model, column: str):
        """Return the full-text weight configured on ``model``'s ``column``."""
        field_info = cast(DBFieldInfo, model.model_fields[column])
        return cast(PostgresFullText, field_info.postgres_config).weight

    # String-literal and enum declarations must resolve to the same weight.
    for column, expected_weight in (("title", "A"), ("content", "B"), ("summary", "C")):
        assert (
            configured_weight(Article, column)
            == configured_weight(ArticleWithEnum, column)
            == expected_weight
        )

    await conn.insert(seeded_articles)

    weighted_title = func.setweight(
        func.to_tsvector("english", Article.title), LexemePriority.HIGHEST
    )
    weighted_content = func.setweight(
        func.to_tsvector("english", Article.content), LexemePriority.HIGH
    )
    weighted_summary = func.setweight(
        func.to_tsvector("english", Article.summary), LexemePriority.LOW
    )
    combined_vector = weighted_title.concat(weighted_content).concat(weighted_summary)
    ts_query = func.to_tsquery("english", "python & guide")

    ranked_select = (
        select((Article, alias("ts_rank", func.ts_rank(combined_vector, ts_query))))
        .where(combined_vector.matches(ts_query))
        .order_by("ts_rank", direction="DESC")
    )
    ranked = await conn.exec(ranked_select)

    assert len(ranked) == 2
    best, runner_up = ranked[0], ranked[1]
    # The title match (HIGHEST weight) is expected to outrank the content match.
    assert best[0].id == 1
    assert runner_up[0].id == 2
    assert best[1] > runner_up[1]
iceaxe/alias_values.py ADDED
@@ -0,0 +1,67 @@
1
+ from dataclasses import dataclass
2
+ from typing import Generic, TypeVar, cast
3
+
4
# Type parameter shared by ``Alias`` (type of the wrapped value) and
# ``alias`` (which declares that same type as its return type).
T = TypeVar("T")
5
+
6
+
7
@dataclass(frozen=True, slots=True)
class Alias(Generic[T]):
    """Immutable pairing of an alias name with the value it labels.

    Converting an instance to ``str`` yields only the alias name.
    """

    name: str
    value: T

    def __str__(self) -> str:
        alias_name = self.name
        return alias_name
14
+
15
+
16
def alias(name: str, type: T) -> T:
    """
    Label a selected value with a SQL alias while keeping its static type.

    The function wraps ``type`` in an ``Alias`` carrying ``name`` and returns it
    cast to ``T``, so type checkers see the original type rather than the wrapper.
    It covers two use cases:

    1. Typing aliased columns of a raw SQL query:
    ```python
    # Map a COUNT(*) result to an integer
    query = select(alias("user_count", int)).text(
        "SELECT COUNT(*) AS user_count FROM users"
    )

    # Map multiple aliased columns with different types
    query = select((
        alias("full_name", str),
        alias("order_count", int),
        alias("total_spent", float)
    )).text(
        '''
        SELECT
            concat(first_name, ' ', last_name) AS full_name,
            COUNT(orders.id) AS order_count,
            SUM(orders.amount) AS total_spent
        FROM users
        LEFT JOIN orders ON users.id = orders.user_id
        GROUP BY users.id
        '''
    )
    ```

    2. Selecting function results next to ORM models:
    ```python
    # Select a model alongside a function result
    query = select((
        User,
        alias("name_length", func.length(User.name)),
        alias("upper_name", func.upper(User.name))
    ))

    # Use with aggregation functions
    query = select((
        User,
        alias("total_orders", func.count(Order.id))
    )).join(Order, User.id == Order.user_id).group_by(User.id)
    ```

    :param name: The alias name, as it appears in the SQL query's AS clause
    :param type: Either a Python type for the result (e.g., int, str, float) or
        a function metadata object (e.g., from func.length())
    :return: The ``Alias`` wrapper, typed as ``T`` for use in select() statements
    """
    wrapped = Alias(name, type)
    return cast(T, wrapped)
iceaxe/base.py ADDED
@@ -0,0 +1,350 @@
1
+ from typing import (
2
+ TYPE_CHECKING,
3
+ Any,
4
+ Callable,
5
+ ClassVar,
6
+ Self,
7
+ Type,
8
+ dataclass_transform,
9
+ )
10
+
11
+ from pydantic import BaseModel, Field as PydanticField
12
+ from pydantic.main import _model_construction
13
+ from pydantic_core import PydanticUndefined
14
+
15
+ from iceaxe.field import DBFieldClassDefinition, DBFieldInfo, Field
16
+
17
+
18
@dataclass_transform(kw_only_default=True, field_specifiers=(PydanticField,))
class DBModelMetaclass(_model_construction.ModelMetaclass):
    """
    Metaclass for database model classes that provides automatic field tracking and SQL query generation.
    Extends Pydantic's model metaclass to add database-specific functionality.

    This metaclass provides:
    - Automatic field to SQL column mapping
    - Dynamic field access that returns query-compatible field definitions
    - Registry tracking of all database model classes
    - Support for generic model instantiation

    ```python {{sticky: True}}
    class User(TableBase):  # Uses DBModelMetaclass
        id: int = Field(primary_key=True)
        name: str
        email: str | None

    # Fields can be accessed for queries
    User.id  # Returns DBFieldClassDefinition
    User.name  # Returns DBFieldClassDefinition

    # Metaclass handles model registration
    registered_models = DBModelMetaclass.get_registry()
    ```
    """

    # Every class created through this metaclass with autodetect enabled,
    # in creation order (TableBase / BaseModel themselves are skipped).
    _registry: list[Type["TableBase"]] = []
    # Class-creation kwargs keyed by the created class, so that generic
    # instantiations can look up the kwargs of their origin model.
    _cached_args: dict[Type["TableBase"], dict[str, Any]] = {}
    # True while a class is being built; __getattr__ skips its field
    # fallback during that window.
    is_constructing: bool = False

    def __new__(mcs, name, bases, namespace, **kwargs):
        """
        Create a new database model class with proper field tracking.
        Handles registration of the model and processes any table-specific arguments.

        :param name: Name of the class being created
        :param bases: Base classes of the new class
        :param namespace: Class body namespace
        :param kwargs: Class-creation keyword arguments; ``autodetect`` (default True)
            is consumed here and controls whether the class is added to the registry
        :return: The newly created model class
        """
        # Snapshot before _extract_kwarg pops "autodetect" out of kwargs.
        raw_kwargs = {**kwargs}

        mcs.is_constructing = True
        try:
            autodetect = mcs._extract_kwarg(kwargs, "autodetect", True)
            cls = super().__new__(mcs, name, bases, namespace, **kwargs)
        finally:
            # Always reset the flag: if class construction raises, leaving it set
            # would disable the field fallback in __getattr__ for every model.
            mcs.is_constructing = False

        # Allow future calls to subclasses / generic instantiations to reference the same
        # kwargs as the base class
        mcs._cached_args[cls] = raw_kwargs

        # If we have already set the class's fields, we should wrap them
        if hasattr(cls, "__pydantic_fields__"):
            cls.__pydantic_fields__ = {
                field: info
                if isinstance(info, DBFieldInfo)
                else DBFieldInfo.extend_field(
                    info,
                    primary_key=False,
                    postgres_config=None,
                    foreign_key=None,
                    unique=False,
                    index=False,
                    check_expression=None,
                    is_json=False,
                )
                for field, info in cls.model_fields.items()
            }

        # Avoid registering the base classes (TableBase / BaseModel) themselves
        if cls.__name__ not in {"TableBase", "BaseModel"} and autodetect:
            DBModelMetaclass._registry.append(cls)

        return cls

    def __getattr__(self, key: str) -> Any:
        """
        Provides dynamic access to model fields as query-compatible definitions.
        When accessing an undefined attribute, checks if it's a model field and returns
        a DBFieldClassDefinition if it is.

        :param key: The attribute name to access
        :return: Field definition or raises AttributeError
        :raises AttributeError: If the attribute doesn't exist and isn't a model field
        """
        # While a class is being built, defer entirely to the parent metaclass.
        if self.is_constructing:
            return super().__getattr__(key)  # type: ignore

        try:
            return super().__getattr__(key)  # type: ignore
        except AttributeError:
            # Determine if this field is defined within the spec
            # If so, return it
            if key in self.model_fields:
                return DBFieldClassDefinition(
                    root_model=self,  # type: ignore
                    key=key,
                    field_definition=self.model_fields[key],
                )
            raise

    @classmethod
    def get_registry(cls) -> list[Type["TableBase"]]:
        """
        Get the list of all registered database model classes.

        The returned object is the shared registry list itself, not a copy.

        :return: List of registered TableBase classes
        """
        return cls._registry

    @classmethod
    def _extract_kwarg(
        cls, kwargs: dict[str, Any], key: str, default: Any = None
    ) -> Any:
        """
        Extract a keyword argument from either standard kwargs or pydantic generic metadata.
        Handles both normal instantiation and pydantic's generic model instantiation.

        When ``key`` is present in ``kwargs`` it is removed from the dictionary
        (mutating the caller's dict). Otherwise, if generic metadata is present, the
        value is looked up in the kwargs cached for the origin model.

        :param kwargs: Dictionary of keyword arguments
        :param key: Key to extract
        :param default: Default value if key not found
        :return: Extracted value or default
        """
        if key in kwargs:
            return kwargs.pop(key)

        if "__pydantic_generic_metadata__" in kwargs:
            origin_model = kwargs["__pydantic_generic_metadata__"]["origin"]
            if origin_model in cls._cached_args:
                return cls._cached_args[origin_model].get(key, default)

        return default

    @property
    def model_fields(self) -> dict[str, DBFieldInfo]:  # type: ignore
        """
        Get the dictionary of model fields and their definitions.
        Overrides the ClassVar typehint from TableBase for proper typing.

        Falls back to an empty dictionary when ``__pydantic_fields__`` is not set.

        :return: Dictionary of field names to field definitions
        """
        return getattr(self, "__pydantic_fields__", {})  # type: ignore
156
+
157
+
158
# Plain pydantic data holder: it only records the column names. No SQL is
# produced in this class.
class UniqueConstraint(BaseModel):
    """
    Represents a UNIQUE constraint in a database table.
    Ensures that the specified combination of columns contains unique values across all rows.

    ```python {{sticky: True}}
    class User(TableBase):
        email: str
        tenant_id: int

        table_args = [
            UniqueConstraint(columns=["email", "tenant_id"])
        ]
    ```
    """

    columns: list[str]
    """
    List of column names that should have unique values
    """
178
+
179
+
180
# Plain pydantic data holder: it only records the column names. No SQL is
# produced in this class.
class IndexConstraint(BaseModel):
    """
    Represents an INDEX on one or more columns in a database table.
    Improves query performance for the specified columns.

    ```python {{sticky: True}}
    class User(TableBase):
        email: str
        last_login: datetime

        table_args = [
            IndexConstraint(columns=["last_login"])
        ]
    ```
    """

    columns: list[str]
    """
    List of column names to create an index on
    """
200
+
201
+
202
# Names of the bookkeeping fields declared on TableBase. They are filtered out
# by TableBase.get_client_fields() and ignored by TableBase.__eq__().
INTERNAL_TABLE_FIELDS = ["modified_attrs", "modified_attrs_callbacks"]
203
+
204
+
205
class TableBase(BaseModel, metaclass=DBModelMetaclass):
    """
    Common base class for models that map to a database table.
    Subclasses declare their columns as annotated attributes, optionally
    configured through ``Field``.

    What the base class provides:
    - A table name derived from the lowercased class name, unless ``table_name`` is set
    - Optional table-level ``table_args`` (unique constraints and indexes)
    - A record of which fields were assigned after instantiation
    - Callbacks fired when a tracked field is assigned
    - Pydantic validation inherited from ``BaseModel``

    ```python {{sticky: True}}
    class User(TableBase):
        # Custom table name (optional)
        table_name = "users"

        # Fields with types and constraints
        id: int = Field(primary_key=True)
        email: str = Field(unique=True)
        name: str
        is_active: bool = Field(default=True)

        # Table-level constraints
        table_args = [
            UniqueConstraint(columns=["email"]),
            IndexConstraint(columns=["name"])
        ]

    # Usage in queries
    query = select(User).where(User.is_active == True)
    users = await conn.exec(query)
    ```
    """

    if TYPE_CHECKING:
        model_fields: ClassVar[dict[str, DBFieldInfo]]  # type: ignore

    table_name: ClassVar[str] = PydanticUndefined  # type: ignore
    """
    Optional custom name for the table
    """

    table_args: ClassVar[list[UniqueConstraint | IndexConstraint]] = PydanticUndefined  # type: ignore
    """
    Table constraints and indexes
    """

    # Internal bookkeeping fields (both declared with exclude=True)
    modified_attrs: dict[str, Any] = Field(default_factory=dict, exclude=True)
    """
    Dictionary of modified field values since instantiation or the last clear_modified_attributes() call.
    Used to construct differential update queries.
    """

    modified_attrs_callbacks: list[Callable[[Self], None]] = Field(
        default_factory=list, exclude=True
    )
    """
    List of callbacks to be called when the model is modified.
    """

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Record assignments to model fields before delegating to pydantic.

        For a declared field, the new value is stored in ``modified_attrs`` and
        every registered callback is invoked with this instance. Both happen
        before the attribute itself is set.

        :param name: Attribute name
        :param value: New value
        """
        is_tracked_field = name in self.__class__.model_fields
        if is_tracked_field:
            self.modified_attrs[name] = value
            for notify in self.modified_attrs_callbacks:
                notify(self)

        super().__setattr__(name, value)

    def get_modified_attributes(self) -> dict[str, Any]:
        """
        Return the fields assigned since instantiation or since the last
        clear_modified_attributes() call.

        The live tracking dictionary is returned, not a copy.

        :return: Dictionary of modified attribute names and their values
        """
        tracked_changes = self.modified_attrs
        return tracked_changes

    def clear_modified_attributes(self) -> None:
        """
        Forget all recorded modifications.
        The tracking dictionary is emptied in place.
        """
        self.modified_attrs.clear()

    @classmethod
    def get_table_name(cls) -> str:
        """
        Resolve the table name for this model.
        Falls back to the lowercased class name when no ``table_name`` was set.

        :return: Table name to use in SQL queries
        """
        configured_name = cls.table_name
        return (
            cls.__name__.lower()
            if configured_name == PydanticUndefined
            else configured_name
        )

    @classmethod
    def get_client_fields(cls) -> dict[str, DBFieldInfo]:
        """
        Collect the model's fields, leaving out the internal bookkeeping ones
        listed in ``INTERNAL_TABLE_FIELDS``.

        :return: Dictionary of field names to field definitions
        """
        exposed: dict[str, DBFieldInfo] = {}
        for field_name, field_info in cls.model_fields.items():
            if field_name in INTERNAL_TABLE_FIELDS:
                continue
            exposed[field_name] = field_info
        return exposed

    def register_modified_callback(self, callback: Callable[[Self], None]) -> None:
        """
        Add a callback that is invoked with this instance whenever a model
        field is assigned.

        :param callback: Callable receiving the modified instance
        """
        self.modified_attrs_callbacks.append(callback)

    def __eq__(self, other: Any) -> bool:
        """
        Compare two model instances by their instance ``__dict__``, ignoring the
        internal bookkeeping fields (``modified_attrs`` and ``modified_attrs_callbacks``).
        Two models with the same data but different tracking state are therefore equal.

        :param other: Object to compare against
        :return: True when ``other`` is an instance of this class with equal data
        """
        if not isinstance(other, self.__class__):
            return False

        def comparable_state(instance: Any) -> dict[str, Any]:
            # Instance attributes minus the internal bookkeeping fields.
            return {
                attr: attr_value
                for attr, attr_value in instance.__dict__.items()
                if attr not in INTERNAL_TABLE_FIELDS
            }

        return comparable_state(self) == comparable_state(other)