iceaxe 0.8.3__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of iceaxe might be problematic. Click here for more details.

Files changed (75)
  1. iceaxe/__init__.py +20 -0
  2. iceaxe/__tests__/__init__.py +0 -0
  3. iceaxe/__tests__/benchmarks/__init__.py +0 -0
  4. iceaxe/__tests__/benchmarks/test_bulk_insert.py +45 -0
  5. iceaxe/__tests__/benchmarks/test_select.py +114 -0
  6. iceaxe/__tests__/conf_models.py +133 -0
  7. iceaxe/__tests__/conftest.py +204 -0
  8. iceaxe/__tests__/docker_helpers.py +208 -0
  9. iceaxe/__tests__/helpers.py +268 -0
  10. iceaxe/__tests__/migrations/__init__.py +0 -0
  11. iceaxe/__tests__/migrations/conftest.py +36 -0
  12. iceaxe/__tests__/migrations/test_action_sorter.py +237 -0
  13. iceaxe/__tests__/migrations/test_generator.py +140 -0
  14. iceaxe/__tests__/migrations/test_generics.py +91 -0
  15. iceaxe/__tests__/mountaineer/__init__.py +0 -0
  16. iceaxe/__tests__/mountaineer/dependencies/__init__.py +0 -0
  17. iceaxe/__tests__/mountaineer/dependencies/test_core.py +76 -0
  18. iceaxe/__tests__/schemas/__init__.py +0 -0
  19. iceaxe/__tests__/schemas/test_actions.py +1265 -0
  20. iceaxe/__tests__/schemas/test_cli.py +25 -0
  21. iceaxe/__tests__/schemas/test_db_memory_serializer.py +1571 -0
  22. iceaxe/__tests__/schemas/test_db_serializer.py +435 -0
  23. iceaxe/__tests__/schemas/test_db_stubs.py +190 -0
  24. iceaxe/__tests__/test_alias.py +83 -0
  25. iceaxe/__tests__/test_base.py +52 -0
  26. iceaxe/__tests__/test_comparison.py +383 -0
  27. iceaxe/__tests__/test_field.py +11 -0
  28. iceaxe/__tests__/test_helpers.py +9 -0
  29. iceaxe/__tests__/test_modifications.py +151 -0
  30. iceaxe/__tests__/test_queries.py +764 -0
  31. iceaxe/__tests__/test_queries_str.py +173 -0
  32. iceaxe/__tests__/test_session.py +1511 -0
  33. iceaxe/__tests__/test_text_search.py +287 -0
  34. iceaxe/alias_values.py +67 -0
  35. iceaxe/base.py +351 -0
  36. iceaxe/comparison.py +560 -0
  37. iceaxe/field.py +263 -0
  38. iceaxe/functions.py +1432 -0
  39. iceaxe/generics.py +140 -0
  40. iceaxe/io.py +107 -0
  41. iceaxe/logging.py +91 -0
  42. iceaxe/migrations/__init__.py +5 -0
  43. iceaxe/migrations/action_sorter.py +98 -0
  44. iceaxe/migrations/cli.py +228 -0
  45. iceaxe/migrations/client_io.py +62 -0
  46. iceaxe/migrations/generator.py +404 -0
  47. iceaxe/migrations/migration.py +86 -0
  48. iceaxe/migrations/migrator.py +101 -0
  49. iceaxe/modifications.py +176 -0
  50. iceaxe/mountaineer/__init__.py +10 -0
  51. iceaxe/mountaineer/cli.py +74 -0
  52. iceaxe/mountaineer/config.py +46 -0
  53. iceaxe/mountaineer/dependencies/__init__.py +6 -0
  54. iceaxe/mountaineer/dependencies/core.py +67 -0
  55. iceaxe/postgres.py +133 -0
  56. iceaxe/py.typed +0 -0
  57. iceaxe/queries.py +1459 -0
  58. iceaxe/queries_str.py +294 -0
  59. iceaxe/schemas/__init__.py +0 -0
  60. iceaxe/schemas/actions.py +864 -0
  61. iceaxe/schemas/cli.py +30 -0
  62. iceaxe/schemas/db_memory_serializer.py +711 -0
  63. iceaxe/schemas/db_serializer.py +347 -0
  64. iceaxe/schemas/db_stubs.py +529 -0
  65. iceaxe/session.py +860 -0
  66. iceaxe/session_optimized.c +12207 -0
  67. iceaxe/session_optimized.cpython-313-darwin.so +0 -0
  68. iceaxe/session_optimized.pyx +212 -0
  69. iceaxe/sql_types.py +149 -0
  70. iceaxe/typing.py +73 -0
  71. iceaxe-0.8.3.dist-info/METADATA +262 -0
  72. iceaxe-0.8.3.dist-info/RECORD +75 -0
  73. iceaxe-0.8.3.dist-info/WHEEL +6 -0
  74. iceaxe-0.8.3.dist-info/licenses/LICENSE +21 -0
  75. iceaxe-0.8.3.dist-info/top_level.txt +1 -0
iceaxe/__tests__/test_text_search.py ADDED
@@ -0,0 +1,287 @@
1
+ from typing import Optional, TypeVar, cast
2
+
3
+ import pytest
4
+
5
+ from iceaxe import Field, TableBase, alias, func, select
6
+ from iceaxe.field import DBFieldInfo
7
+ from iceaxe.postgres import LexemePriority, PostgresFullText
8
+ from iceaxe.session import DBConnection
9
+
10
# NOTE(review): T is not referenced anywhere else in this test module.
T = TypeVar("T")
11
+
12
+
13
class Article(TableBase):
    """
    Test model for full-text search.

    Each text column carries a PostgresFullText config using the "english"
    language; the columns differ only in the weight they are given.
    """

    id: int = Field(primary_key=True)
    # Full-text column with weight "A".
    title: str = Field(postgres_config=PostgresFullText(language="english", weight="A"))
    # Full-text column with weight "B".
    content: str = Field(
        postgres_config=PostgresFullText(language="english", weight="B")
    )
    # Optional full-text column (defaults to None) with weight "C".
    summary: Optional[str] = Field(
        default=None, postgres_config=PostgresFullText(language="english", weight="C")
    )
24
+
25
+
26
@pytest.mark.asyncio
async def test_basic_text_search(indexed_db_connection: DBConnection):
    """
    Exercise single-term full-text matching through the query builder.

    The term "python" appears in exactly one title but in every content
    column, so the same query yields one row against the title vector and
    three rows against the content vector.
    """
    seed_rows = [
        Article(
            id=1, title="Python Programming", content="Learn Python programming basics"
        ),
        Article(
            id=2, title="Database Design", content="Python and database design patterns"
        ),
        Article(id=3, title="Web Development", content="Building web apps with Python"),
    ]
    await indexed_db_connection.insert(seed_rows)

    # Title-only search: article 1 is the only one with "Python" in its title.
    title_vector = func.to_tsvector("english", Article.title)
    python_query = func.to_tsquery("english", "python")
    title_matches = await indexed_db_connection.exec(
        select(Article).where(title_vector.matches(python_query))
    )
    assert len(title_matches) == 1
    assert title_matches[0].id == 1

    # Content-only search: every seeded article mentions Python in its content.
    content_vector = func.to_tsvector("english", Article.content)
    content_matches = await indexed_db_connection.exec(
        select(Article).where(content_vector.matches(python_query))
    )
    assert len(content_matches) == 3
58
+
59
+
60
@pytest.mark.asyncio
async def test_complex_text_search(indexed_db_connection: DBConnection):
    """
    Exercise the boolean tsquery operators (&, |, !) against the title column.
    """
    await indexed_db_connection.insert(
        [
            Article(
                id=1, title="Python Programming", content="Learn programming basics"
            ),
            Article(
                id=2, title="Python Advanced", content="Advanced programming concepts"
            ),
            Article(
                id=3,
                title="JavaScript Basics",
                content="Learn programming with JavaScript",
            ),
        ]
    )

    title_vector = func.to_tsvector("english", Article.title)

    async def search_titles(expression: str):
        # Run one tsquery expression against the shared title vector.
        ts_query = func.to_tsquery("english", expression)
        return await indexed_db_connection.exec(
            select(Article).where(title_vector.matches(ts_query))
        )

    # AND: only article 1 has both terms in its title.
    both_terms = await search_titles("python & programming")
    assert len(both_terms) == 1
    assert both_terms[0].id == 1

    # OR: every title contains at least one of the two terms.
    either_term = await search_titles("python | javascript")
    assert len(either_term) == 3
    assert {row.id for row in either_term} == {1, 2, 3}

    # NOT: no title has "programming" without also having "python".
    without_python = await search_titles("programming & !python")
    assert len(without_python) == 0
96
+
97
+
98
@pytest.mark.asyncio
async def test_combined_field_search(indexed_db_connection: DBConnection):
    """
    Search across title, content and summary at once.

    Covers both ways of building the combined vector: passing a list of
    columns to to_tsvector, and chaining per-column vectors with concat().
    Both must return the same single article.
    """
    guide_article = Article(
        id=1,
        title="Python Guide",
        content="Learn programming basics",
        summary="A beginner's guide to Python",
    )
    tips_article = Article(
        id=2,
        title="Programming Tips",
        content="Python best practices",
        summary="Advanced Python concepts",
    )
    await indexed_db_connection.insert([guide_article, tips_article])

    # Variant 1: list syntax.
    list_vector = func.to_tsvector(
        "english", [Article.title, Article.content, Article.summary]
    )
    list_query = func.to_tsquery("english", "python & guide")
    list_matches = await indexed_db_connection.exec(
        select(Article).where(list_vector.matches(list_query))
    )
    # Only article 1 contains both "python" and "guide".
    assert len(list_matches) == 1
    assert list_matches[0].id == 1

    # Variant 2: the original concatenation syntax, built up column by column.
    chained_vector = func.to_tsvector("english", Article.title)
    for column in (Article.content, Article.summary):
        chained_vector = chained_vector.concat(func.to_tsvector("english", column))
    chained_query = func.to_tsquery("english", "python & guide")
    chained_matches = await indexed_db_connection.exec(
        select(Article).where(chained_vector.matches(chained_query))
    )
    # Same outcome as the list-based vector.
    assert len(chained_matches) == 1
    assert chained_matches[0].id == 1
143
+
144
+
145
@pytest.mark.asyncio
async def test_weighted_text_search(indexed_db_connection: DBConnection):
    """
    Rank matches with per-column weights given as string literals.

    Both articles match "python & guide", but article 1 has the phrase in its
    title (weight "A") and is therefore expected to rank first.
    """
    await indexed_db_connection.insert(
        [
            Article(
                id=1,
                title="Python Guide",
                content="Basic Python",
                summary="Python tutorial",
            ),
            Article(
                id=2,
                title="Programming",
                content="Python Guide",
                summary="Guide to programming",
            ),
        ]
    )

    # title -> "A", content -> "B", summary -> "C", concatenated in that order.
    weighted_vector = func.setweight(func.to_tsvector("english", Article.title), "A")
    for column, weight in ((Article.content, "B"), (Article.summary, "C")):
        weighted_vector = weighted_vector.concat(
            func.setweight(func.to_tsvector("english", column), weight)
        )
    ts_query = func.to_tsquery("english", "python & guide")

    ranked_query = (
        select((Article, alias("ts_rank", func.ts_rank(weighted_vector, ts_query))))
        .where(weighted_vector.matches(ts_query))
        .order_by("ts_rank", direction="DESC")
    )
    ranked_rows = await indexed_db_connection.exec(ranked_query)

    assert len(ranked_rows) == 2
    best, runner_up = ranked_rows[0], ranked_rows[1]
    # Each row is (article, rank); the title hit must come out on top.
    assert best[0].id == 1
    assert runner_up[0].id == 2
    assert best[1] > runner_up[1]
183
+
184
+
185
@pytest.mark.asyncio
async def test_weight_priority_variants(indexed_db_connection: DBConnection):
    """
    Check that LexemePriority enum members and string literals are
    interchangeable as weights, both in field configuration and in setweight().
    """
    seed_rows = [
        Article(
            id=1,
            title="Python Guide",
            content="Basic Python",
            summary="Python tutorial",
        ),
        Article(
            id=2,
            title="Programming",
            content="Python Guide",
            summary="Guide to programming",
        ),
    ]

    # Same shape as Article, but every weight is spelled with the enum.
    class ArticleWithEnum(TableBase):
        id: int = Field(primary_key=True)
        title: str = Field(
            postgres_config=PostgresFullText(
                language="english", weight=LexemePriority.HIGHEST
            )
        )
        content: str = Field(
            postgres_config=PostgresFullText(
                language="english", weight=LexemePriority.HIGH
            )
        )
        summary: Optional[str] = Field(
            default=None,
            postgres_config=PostgresFullText(
                language="english", weight=LexemePriority.LOW
            ),
        )

    def configured_weight(model, column: str):
        # Pull the full-text weight configured on model.<column>.
        field_info = cast(DBFieldInfo, model.model_fields[column])
        return cast(PostgresFullText, field_info.postgres_config).weight

    # The literal-based and enum-based models must agree on every weight.
    for column, expected in (("title", "A"), ("content", "B"), ("summary", "C")):
        assert (
            configured_weight(Article, column)
            == configured_weight(ArticleWithEnum, column)
            == expected
        )

    await indexed_db_connection.insert(seed_rows)

    # Build the weighted vector with enum members instead of string literals.
    weighted_vector = func.setweight(
        func.to_tsvector("english", Article.title), LexemePriority.HIGHEST
    )
    for column, priority in (
        (Article.content, LexemePriority.HIGH),
        (Article.summary, LexemePriority.LOW),
    ):
        weighted_vector = weighted_vector.concat(
            func.setweight(func.to_tsvector("english", column), priority)
        )
    ts_query = func.to_tsquery("english", "python & guide")

    ranked_query = (
        select((Article, alias("ts_rank", func.ts_rank(weighted_vector, ts_query))))
        .where(weighted_vector.matches(ts_query))
        .order_by("ts_rank", direction="DESC")
    )
    ranked_rows = await indexed_db_connection.exec(ranked_query)

    assert len(ranked_rows) == 2
    best, runner_up = ranked_rows[0], ranked_rows[1]
    # Article 1 has "Python Guide" in its title (HIGHEST weight), so it ranks first.
    assert best[0].id == 1
    assert runner_up[0].id == 2
    assert best[1] > runner_up[1]
iceaxe/alias_values.py ADDED
@@ -0,0 +1,67 @@
1
+ from dataclasses import dataclass
2
+ from typing import Generic, TypeVar, cast
3
+
4
# Type of the value wrapped by Alias and passed through alias().
T = TypeVar("T")
5
+
6
+
7
@dataclass(frozen=True, slots=True)
class Alias(Generic[T]):
    """
    Immutable pairing of an alias name with the value it stands for.

    Converting an instance to a string yields only the alias name.
    """

    # Name of the alias.
    name: str
    # Whatever was supplied for the alias (see alias()).
    value: T

    def __str__(self) -> str:
        """Return the alias name."""
        return self.name
14
+
15
+
16
def alias(name: str, type: T) -> T:
    """
    Build an alias for a selected column so its result can be mapped in a
    type-safe way. Two situations call for it:

    1. Raw SQL queries whose columns are aliased:
    ```python
    # Map a COUNT(*) result to an integer
    query = select(alias("user_count", int)).text(
        "SELECT COUNT(*) AS user_count FROM users"
    )

    # Map multiple aliased columns with different types
    query = select((
        alias("full_name", str),
        alias("order_count", int),
        alias("total_spent", float)
    )).text(
        '''
        SELECT
            concat(first_name, ' ', last_name) AS full_name,
            COUNT(orders.id) AS order_count,
            SUM(orders.amount) AS total_spent
        FROM users
        LEFT JOIN orders ON users.id = orders.user_id
        GROUP BY users.id
        '''
    )
    ```

    2. ORM models selected next to function results:
    ```python
    # Select a model alongside a function result
    query = select((
        User,
        alias("name_length", func.length(User.name)),
        alias("upper_name", func.upper(User.name))
    ))

    # Use with aggregation functions
    query = select((
        User,
        alias("total_orders", func.count(Order.id))
    )).join(Order, User.id == Order.user_id).group_by(User.id)
    ```

    :param name: The alias name as written in the SQL query's AS clause
    :param type: Either a Python type for the result (e.g., int, str, float) or
        a function metadata object (e.g., from func.length())
    :return: An Alias wrapping ``name`` and ``type``, typed as ``T`` so it can be
        used in select() statements
    """
    wrapped = Alias(name, type)
    # cast() only changes the static type; the object returned is the Alias.
    return cast(T, wrapped)
iceaxe/base.py ADDED
@@ -0,0 +1,351 @@
1
+ from typing import (
2
+ TYPE_CHECKING,
3
+ Any,
4
+ Callable,
5
+ ClassVar,
6
+ Self,
7
+ Type,
8
+ dataclass_transform,
9
+ )
10
+
11
+ from pydantic import BaseModel, Field as PydanticField
12
+ from pydantic.main import _model_construction
13
+ from pydantic_core import PydanticUndefined
14
+
15
+ from iceaxe.field import DBFieldClassDefinition, DBFieldInfo, Field
16
+
17
+
18
@dataclass_transform(kw_only_default=True, field_specifiers=(PydanticField,))
class DBModelMetaclass(_model_construction.ModelMetaclass):
    """
    Metaclass for database model classes that provides automatic field tracking and SQL query generation.
    Extends Pydantic's model metaclass to add database-specific functionality.

    This metaclass provides:
    - Automatic field to SQL column mapping
    - Dynamic field access that returns query-compatible field definitions
    - Registry tracking of all database model classes
    - Support for generic model instantiation

    ```python {{sticky: True}}
    class User(TableBase):  # Uses DBModelMetaclass
        id: int = Field(primary_key=True)
        name: str
        email: str | None

    # Fields can be accessed for queries
    User.id  # Returns DBFieldClassDefinition
    User.name  # Returns DBFieldClassDefinition

    # Metaclass handles model registration
    registered_models = DBModelMetaclass.get_registry()
    ```
    """

    # Every model class created with autodetect enabled, in creation order.
    _registry: list[Type["TableBase"]] = []
    # Class-creation kwargs per model, looked up again for generic instantiations.
    _cached_args: dict[Type["TableBase"], dict[str, Any]] = {}
    # True only while a class is being built by __new__; consulted by __getattr__.
    is_constructing: bool = False

    def __new__(mcs, name, bases, namespace, **kwargs):
        """
        Create a new database model class with proper field tracking.
        Handles registration of the model and processes any table-specific arguments.

        :param name: Name of the class being created
        :param bases: Base classes of the class being created
        :param namespace: Class body namespace
        :param kwargs: Class keyword arguments; ``autodetect`` (default True) is
            consumed here and controls whether the class is added to the registry
        :return: The newly created model class
        """
        # Snapshot before _extract_kwarg pops anything, so the cache keeps "autodetect".
        raw_kwargs = {**kwargs}

        mcs.is_constructing = True
        try:
            autodetect = mcs._extract_kwarg(kwargs, "autodetect", True)
            cls = super().__new__(mcs, name, bases, namespace, **kwargs)
        finally:
            # Always reset the flag: if construction raised and it stayed True,
            # __getattr__ would skip its field fallback from then on.
            mcs.is_constructing = False

        # Allow future calls to subclasses / generic instantiations to reference the same
        # kwargs as the base class
        mcs._cached_args[cls] = raw_kwargs

        # If we have already set the class's fields, we should wrap them
        if hasattr(cls, "__pydantic_fields__"):
            cls.__pydantic_fields__ = {
                field: info
                if isinstance(info, DBFieldInfo)
                else DBFieldInfo.extend_field(
                    info,
                    primary_key=False,
                    postgres_config=None,
                    foreign_key=None,
                    unique=False,
                    index=False,
                    check_expression=None,
                    is_json=False,
                    explicit_type=None,
                )
                for field, info in cls.model_fields.items()
            }

        # Avoid registering the base classes themselves (TableBase / BaseModel),
        # as well as any class that opted out with autodetect=False
        if cls.__name__ not in {"TableBase", "BaseModel"} and autodetect:
            DBModelMetaclass._registry.append(cls)

        return cls

    def __getattr__(self, key: str) -> Any:
        """
        Provides dynamic access to model fields as query-compatible definitions.
        When accessing an undefined attribute, checks if it's a model field and returns
        a DBFieldClassDefinition if it is. While a class is still being constructed,
        the lookup is delegated to the parent metaclass without this fallback.

        :param key: The attribute name to access
        :return: Field definition or raises AttributeError
        :raises AttributeError: If the attribute doesn't exist and isn't a model field
        """
        if self.is_constructing:
            return super().__getattr__(key)  # type: ignore

        try:
            return super().__getattr__(key)  # type: ignore
        except AttributeError:
            # Determine if this field is defined within the spec
            # If so, return it
            if key in self.model_fields:
                return DBFieldClassDefinition(
                    root_model=self,  # type: ignore
                    key=key,
                    field_definition=self.model_fields[key],
                )
            raise

    @classmethod
    def get_registry(cls) -> list[Type["TableBase"]]:
        """
        Get the list of all registered database model classes.

        :return: The registry list itself (not a copy) of registered TableBase classes
        """
        return cls._registry

    @classmethod
    def _extract_kwarg(
        cls, kwargs: dict[str, Any], key: str, default: Any = None
    ) -> Any:
        """
        Extract a keyword argument from either standard kwargs or pydantic generic metadata.
        Handles both normal instantiation and pydantic's generic model instantiation.

        When the key is present in ``kwargs`` it is removed from the dictionary.

        :param kwargs: Dictionary of keyword arguments
        :param key: Key to extract
        :param default: Default value if key not found
        :return: Extracted value or default
        """
        if key in kwargs:
            return kwargs.pop(key)

        # Generic instantiation: fall back to the kwargs cached for the origin model
        if "__pydantic_generic_metadata__" in kwargs:
            origin_model = kwargs["__pydantic_generic_metadata__"]["origin"]
            if origin_model in cls._cached_args:
                return cls._cached_args[origin_model].get(key, default)

        return default

    @property
    def model_fields(self) -> dict[str, DBFieldInfo]:  # type: ignore
        """
        Get the dictionary of model fields and their definitions.
        Overrides the ClassVar typehint from TableBase for proper typing.

        :return: Dictionary of field names to field definitions; empty if the
            class has no ``__pydantic_fields__`` yet
        """
        return getattr(self, "__pydantic_fields__", {})  # type: ignore
157
+
158
+
159
class UniqueConstraint(BaseModel):
    """
    Represents a UNIQUE constraint in a database table.
    Ensures that the specified combination of columns contains unique values across all rows.
    It is declared through a model's ``table_args`` list, as shown below.

    ```python {{sticky: True}}
    class User(TableBase):
        email: str
        tenant_id: int

        table_args = [
            UniqueConstraint(columns=["email", "tenant_id"])
        ]
    ```
    """

    # Column names that together make up the unique key.
    columns: list[str]
    """
    List of column names that should have unique values
    """
179
+
180
+
181
class IndexConstraint(BaseModel):
    """
    Represents an INDEX on one or more columns in a database table.
    Improves query performance for the specified columns.
    It is declared through a model's ``table_args`` list, as shown below.

    ```python {{sticky: True}}
    class User(TableBase):
        email: str
        last_login: datetime

        table_args = [
            IndexConstraint(columns=["last_login"])
        ]
    ```
    """

    # Column names covered by the index.
    columns: list[str]
    """
    List of column names to create an index on
    """
201
+
202
+
203
# Bookkeeping fields declared on TableBase itself; they are filtered out of
# get_client_fields() and ignored by TableBase.__eq__.
INTERNAL_TABLE_FIELDS = ["modified_attrs", "modified_attrs_callbacks"]
204
+
205
+
206
class TableBase(BaseModel, metaclass=DBModelMetaclass):
    """
    Base class for all database table models.
    Provides the foundation for defining database tables using Python classes with
    type hints and field definitions.

    Features:
    - Automatic table name generation from class name
    - Support for custom table names
    - Tracking of modified fields for efficient updates
    - Support for unique constraints and indexes
    - Integration with Pydantic for validation

    ```python {{sticky: True}}
    class User(TableBase):
        # Custom table name (optional)
        table_name = "users"

        # Fields with types and constraints
        id: int = Field(primary_key=True)
        email: str = Field(unique=True)
        name: str
        is_active: bool = Field(default=True)

        # Table-level constraints
        table_args = [
            UniqueConstraint(columns=["email"]),
            IndexConstraint(columns=["name"])
        ]

    # Usage in queries
    query = select(User).where(User.is_active == True)
    users = await conn.exec(query)
    ```
    """

    if TYPE_CHECKING:
        model_fields: ClassVar[dict[str, DBFieldInfo]]  # type: ignore

    table_name: ClassVar[str] = PydanticUndefined  # type: ignore
    """
    Optional custom name for the table
    """

    table_args: ClassVar[list[UniqueConstraint | IndexConstraint]] = PydanticUndefined  # type: ignore
    """
    Table constraints and indexes
    """

    # Internal bookkeeping fields (listed in INTERNAL_TABLE_FIELDS)
    modified_attrs: dict[str, Any] = Field(default_factory=dict, exclude=True)
    """
    Dictionary of modified field values since instantiation or the last clear_modified_attributes() call.
    Used to construct differential update queries.
    """

    modified_attrs_callbacks: list[Callable[[Self], None]] = Field(
        default_factory=list, exclude=True
    )
    """
    List of callbacks to be called when the model is modified.
    """

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Track modified attributes when fields are updated.
        This allows for efficient database updates by only updating changed fields.

        :param name: Attribute name
        :param value: New value
        """
        if name in self.__class__.model_fields:
            # NOTE: the change is recorded and the callbacks run before the
            # value is actually assigned by super().__setattr__ below.
            self.modified_attrs[name] = value
            for callback in self.modified_attrs_callbacks:
                callback(self)
        super().__setattr__(name, value)

    def get_modified_attributes(self) -> dict[str, Any]:
        """
        Get the dictionary of attributes that have been modified since instantiation
        or the last clear_modified_attributes() call.

        :return: The tracking dictionary itself (not a copy) of modified attribute
            names and their values
        """
        return self.modified_attrs

    def clear_modified_attributes(self) -> None:
        """
        Clear the tracking of modified attributes.
        Typically called after successfully saving changes to the database.
        """
        self.modified_attrs.clear()

    @classmethod
    def get_table_name(cls) -> str:
        """
        Get the table name for this model.
        Uses the custom table_name if set, otherwise converts the class name to lowercase.

        :return: Table name to use in SQL queries
        """
        # PydanticUndefined is a sentinel, so test for it by identity.
        if cls.table_name is PydanticUndefined:
            return cls.__name__.lower()
        return cls.table_name

    @classmethod
    def get_client_fields(cls) -> dict[str, DBFieldInfo]:
        """
        Get all fields that should be exposed to clients.
        Excludes internal fields used for model functionality.

        :return: Dictionary of field names to field definitions
        """
        return {
            field: info
            for field, info in cls.model_fields.items()
            if field not in INTERNAL_TABLE_FIELDS
        }

    def register_modified_callback(self, callback: Callable[[Self], None]) -> None:
        """
        Register a callback to be called when the model is modified.

        :param callback: Callable that receives this instance
        """
        self.modified_attrs_callbacks.append(callback)

    def __eq__(self, other: Any) -> bool:
        """
        Compare two model instances, ignoring the internal bookkeeping fields
        (modified_attrs and modified_attrs_callbacks).
        This ensures that two models with the same data but different callbacks
        or modification history are considered equal.

        :param other: Object to compare against
        :return: True when ``other`` is an instance of this instance's class and
            all non-internal entries of both ``__dict__`` mappings are equal
        """
        if not isinstance(other, self.__class__):
            return False

        # Get all fields except the ones listed in INTERNAL_TABLE_FIELDS
        fields = {
            field: value
            for field, value in self.__dict__.items()
            if field not in INTERNAL_TABLE_FIELDS
        }
        other_fields = {
            field: value
            for field, value in other.__dict__.items()
            if field not in INTERNAL_TABLE_FIELDS
        }

        return fields == other_fields