iceaxe 0.8.3__cp313-cp313-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of iceaxe might be problematic. Click here for more details.
- iceaxe/__init__.py +20 -0
- iceaxe/__tests__/__init__.py +0 -0
- iceaxe/__tests__/benchmarks/__init__.py +0 -0
- iceaxe/__tests__/benchmarks/test_bulk_insert.py +45 -0
- iceaxe/__tests__/benchmarks/test_select.py +114 -0
- iceaxe/__tests__/conf_models.py +133 -0
- iceaxe/__tests__/conftest.py +204 -0
- iceaxe/__tests__/docker_helpers.py +208 -0
- iceaxe/__tests__/helpers.py +268 -0
- iceaxe/__tests__/migrations/__init__.py +0 -0
- iceaxe/__tests__/migrations/conftest.py +36 -0
- iceaxe/__tests__/migrations/test_action_sorter.py +237 -0
- iceaxe/__tests__/migrations/test_generator.py +140 -0
- iceaxe/__tests__/migrations/test_generics.py +91 -0
- iceaxe/__tests__/mountaineer/__init__.py +0 -0
- iceaxe/__tests__/mountaineer/dependencies/__init__.py +0 -0
- iceaxe/__tests__/mountaineer/dependencies/test_core.py +76 -0
- iceaxe/__tests__/schemas/__init__.py +0 -0
- iceaxe/__tests__/schemas/test_actions.py +1265 -0
- iceaxe/__tests__/schemas/test_cli.py +25 -0
- iceaxe/__tests__/schemas/test_db_memory_serializer.py +1571 -0
- iceaxe/__tests__/schemas/test_db_serializer.py +435 -0
- iceaxe/__tests__/schemas/test_db_stubs.py +190 -0
- iceaxe/__tests__/test_alias.py +83 -0
- iceaxe/__tests__/test_base.py +52 -0
- iceaxe/__tests__/test_comparison.py +383 -0
- iceaxe/__tests__/test_field.py +11 -0
- iceaxe/__tests__/test_helpers.py +9 -0
- iceaxe/__tests__/test_modifications.py +151 -0
- iceaxe/__tests__/test_queries.py +764 -0
- iceaxe/__tests__/test_queries_str.py +173 -0
- iceaxe/__tests__/test_session.py +1511 -0
- iceaxe/__tests__/test_text_search.py +287 -0
- iceaxe/alias_values.py +67 -0
- iceaxe/base.py +351 -0
- iceaxe/comparison.py +560 -0
- iceaxe/field.py +263 -0
- iceaxe/functions.py +1432 -0
- iceaxe/generics.py +140 -0
- iceaxe/io.py +107 -0
- iceaxe/logging.py +91 -0
- iceaxe/migrations/__init__.py +5 -0
- iceaxe/migrations/action_sorter.py +98 -0
- iceaxe/migrations/cli.py +228 -0
- iceaxe/migrations/client_io.py +62 -0
- iceaxe/migrations/generator.py +404 -0
- iceaxe/migrations/migration.py +86 -0
- iceaxe/migrations/migrator.py +101 -0
- iceaxe/modifications.py +176 -0
- iceaxe/mountaineer/__init__.py +10 -0
- iceaxe/mountaineer/cli.py +74 -0
- iceaxe/mountaineer/config.py +46 -0
- iceaxe/mountaineer/dependencies/__init__.py +6 -0
- iceaxe/mountaineer/dependencies/core.py +67 -0
- iceaxe/postgres.py +133 -0
- iceaxe/py.typed +0 -0
- iceaxe/queries.py +1459 -0
- iceaxe/queries_str.py +294 -0
- iceaxe/schemas/__init__.py +0 -0
- iceaxe/schemas/actions.py +864 -0
- iceaxe/schemas/cli.py +30 -0
- iceaxe/schemas/db_memory_serializer.py +711 -0
- iceaxe/schemas/db_serializer.py +347 -0
- iceaxe/schemas/db_stubs.py +529 -0
- iceaxe/session.py +860 -0
- iceaxe/session_optimized.c +12207 -0
- iceaxe/session_optimized.cpython-313-darwin.so +0 -0
- iceaxe/session_optimized.pyx +212 -0
- iceaxe/sql_types.py +149 -0
- iceaxe/typing.py +73 -0
- iceaxe-0.8.3.dist-info/METADATA +262 -0
- iceaxe-0.8.3.dist-info/RECORD +75 -0
- iceaxe-0.8.3.dist-info/WHEEL +6 -0
- iceaxe-0.8.3.dist-info/licenses/LICENSE +21 -0
- iceaxe-0.8.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,287 @@
|
|
|
1
|
+
from typing import Optional, TypeVar, cast
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
from iceaxe import Field, TableBase, alias, func, select
|
|
6
|
+
from iceaxe.field import DBFieldInfo
|
|
7
|
+
from iceaxe.postgres import LexemePriority, PostgresFullText
|
|
8
|
+
from iceaxe.session import DBConnection
|
|
9
|
+
|
|
10
|
+
# Module-level type variable. Nothing else in this test module references it.
T = TypeVar("T")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Article(TableBase):
    """
    Full-text-searchable article model shared by the tests in this module.

    Every text column carries its own PostgresFullText configuration so the
    search tests can exercise per-column weights: title is weighted "A",
    content "B" and summary "C".
    """

    id: int = Field(primary_key=True)

    # Headline text, weight "A".
    title: str = Field(
        postgres_config=PostgresFullText(weight="A", language="english"),
    )

    # Body text, weight "B".
    content: str = Field(
        postgres_config=PostgresFullText(weight="B", language="english"),
    )

    # Optional short description, weight "C".
    summary: Optional[str] = Field(
        postgres_config=PostgresFullText(weight="C", language="english"),
        default=None,
    )
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@pytest.mark.asyncio
async def test_basic_text_search(indexed_db_connection: DBConnection):
    """Single-column full-text matching: first against titles, then contents."""
    await indexed_db_connection.insert(
        [
            Article(
                id=1,
                title="Python Programming",
                content="Learn Python programming basics",
            ),
            Article(
                id=2,
                title="Database Design",
                content="Python and database design patterns",
            ),
            Article(
                id=3,
                title="Web Development",
                content="Building web apps with Python",
            ),
        ]
    )

    title_vector = func.to_tsvector("english", Article.title)
    python_query = func.to_tsquery("english", "python")

    # Only the first article has "Python" in its title.
    title_matches = await indexed_db_connection.exec(
        select(Article).where(title_vector.matches(python_query))
    )
    assert [article.id for article in title_matches] == [1]

    # Every article mentions Python somewhere in its content.
    content_vector = func.to_tsvector("english", Article.content)
    content_matches = await indexed_db_connection.exec(
        select(Article).where(content_vector.matches(python_query))
    )
    assert len(content_matches) == 3
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@pytest.mark.asyncio
async def test_complex_text_search(indexed_db_connection: DBConnection):
    """Boolean tsquery operators (&, |, !) evaluated against article titles."""
    await indexed_db_connection.insert(
        [
            Article(id=1, title="Python Programming", content="Learn programming basics"),
            Article(id=2, title="Python Advanced", content="Advanced programming concepts"),
            Article(
                id=3,
                title="JavaScript Basics",
                content="Learn programming with JavaScript",
            ),
        ]
    )

    title_vector = func.to_tsvector("english", Article.title)

    async def search(expression: str):
        # Run one tsquery expression against the shared title vector.
        tsquery = func.to_tsquery("english", expression)
        return await indexed_db_connection.exec(
            select(Article).where(title_vector.matches(tsquery))
        )

    # AND: only "Python Programming" carries both terms in its title.
    both_terms = await search("python & programming")
    assert [article.id for article in both_terms] == [1]

    # OR: every title mentions either Python or JavaScript.
    either_term = await search("python | javascript")
    assert len(either_term) == 3
    assert {article.id for article in either_term} == {1, 2, 3}

    # NOT: the only title with "programming" also contains "python".
    excluded = await search("programming & !python")
    assert len(excluded) == 0
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@pytest.mark.asyncio
async def test_combined_field_search(indexed_db_connection: DBConnection):
    """One tsquery across title/content/summary, with the vector built two ways."""
    await indexed_db_connection.insert(
        [
            Article(
                id=1,
                title="Python Guide",
                content="Learn programming basics",
                summary="A beginner's guide to Python",
            ),
            Article(
                id=2,
                title="Programming Tips",
                content="Python best practices",
                summary="Advanced Python concepts",
            ),
        ]
    )

    # Approach 1: hand to_tsvector a list of columns.
    list_vector = func.to_tsvector(
        "english", [Article.title, Article.content, Article.summary]
    )
    list_query = func.to_tsquery("english", "python & guide")
    list_matches = await indexed_db_connection.exec(
        select(Article).where(list_vector.matches(list_query))
    )
    # Only article 1 contains both "python" and "guide".
    assert [article.id for article in list_matches] == [1]

    # Approach 2: concatenate one tsvector per column.
    concat_vector = func.to_tsvector("english", Article.title)
    for column in (Article.content, Article.summary):
        concat_vector = concat_vector.concat(func.to_tsvector("english", column))
    concat_query = func.to_tsquery("english", "python & guide")
    concat_matches = await indexed_db_connection.exec(
        select(Article).where(concat_vector.matches(concat_query))
    )
    # Both constructions must select the same rows.
    assert [article.id for article in concat_matches] == [1]
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
@pytest.mark.asyncio
async def test_weighted_text_search(indexed_db_connection: DBConnection):
    """Rank matches with ts_rank over a setweight-ed, concatenated tsvector."""
    await indexed_db_connection.insert(
        [
            Article(
                id=1,
                title="Python Guide",  # weighted A below
                content="Basic Python",  # weighted B below
                summary="Python tutorial",  # weighted C below
            ),
            Article(
                id=2,
                title="Programming",
                content="Python Guide",
                summary="Guide to programming",
            ),
        ]
    )

    # title -> A, then content -> B, then summary -> C, concatenated in that order.
    weighted_vector = func.setweight(func.to_tsvector("english", Article.title), "A")
    for column, weight in ((Article.content, "B"), (Article.summary, "C")):
        weighted_vector = weighted_vector.concat(
            func.setweight(func.to_tsvector("english", column), weight)
        )
    search_query = func.to_tsquery("english", "python & guide")

    ranked = await indexed_db_connection.exec(
        select((Article, alias("ts_rank", func.ts_rank(weighted_vector, search_query))))
        .where(weighted_vector.matches(search_query))
        .order_by("ts_rank", direction="DESC"),
    )

    assert len(ranked) == 2
    top, runner_up = ranked[0], ranked[1]
    # "Python Guide" in the A-weighted title must outrank it in B-weighted content.
    assert top[0].id == 1
    assert runner_up[0].id == 2
    assert top[1] > runner_up[1]
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
@pytest.mark.asyncio
async def test_weight_priority_variants(indexed_db_connection: DBConnection):
    """
    Test text search using both string literals and LexemePriority enum for weights.

    Article spells its column weights as string literals ("A", "B", "C"), and
    ArticleWithEnum spells the same weights with LexemePriority members. Both
    must resolve to the same configured weights, and ranking with enum weights
    must order results like ranking with string weights.
    """
    articles = [
        Article(
            id=1,
            title="Python Guide",  # Article.title is weighted "A"
            content="Basic Python",  # Article.content is weighted "B"
            summary="Python tutorial",  # Article.summary is weighted "C"
        ),
        Article(
            id=2,
            title="Programming",
            content="Python Guide",
            summary="Guide to programming",
        ),
    ]

    # Same columns as Article, but with weights given as LexemePriority members.
    # autodetect=False keeps this throwaway, test-local model out of the global
    # DBModelMetaclass registry, so defining it here cannot leak a table that
    # does not exist into other consumers of the registry.
    class ArticleWithEnum(TableBase, autodetect=False):
        id: int = Field(primary_key=True)
        title: str = Field(
            postgres_config=PostgresFullText(
                language="english", weight=LexemePriority.HIGHEST
            )
        )
        content: str = Field(
            postgres_config=PostgresFullText(
                language="english", weight=LexemePriority.HIGH
            )
        )
        summary: Optional[str] = Field(
            default=None,
            postgres_config=PostgresFullText(
                language="english", weight=LexemePriority.LOW
            ),
        )

    def full_text_weight(model: type[TableBase], field_name: str):
        """Return the PostgresFullText weight configured on ``model.field_name``."""
        field_info = cast(DBFieldInfo, model.model_fields[field_name])
        return cast(PostgresFullText, field_info.postgres_config).weight

    # Both spellings must configure identical weights.
    for field_name, expected_weight in (("title", "A"), ("content", "B"), ("summary", "C")):
        assert (
            full_text_weight(Article, field_name)
            == full_text_weight(ArticleWithEnum, field_name)
            == expected_weight
        )

    await indexed_db_connection.insert(articles)

    vector = (
        func.setweight(
            func.to_tsvector("english", Article.title), LexemePriority.HIGHEST
        )
        .concat(
            func.setweight(
                func.to_tsvector("english", Article.content), LexemePriority.HIGH
            )
        )
        .concat(
            func.setweight(
                func.to_tsvector("english", Article.summary), LexemePriority.LOW
            )
        )
    )
    query = func.to_tsquery("english", "python & guide")

    results = await indexed_db_connection.exec(
        select((Article, alias("ts_rank", func.ts_rank(vector, query))))
        .where(vector.matches(query))
        .order_by("ts_rank", direction="DESC"),
    )

    assert len(results) == 2
    # First article should rank higher because "Python Guide" is in title (HIGHEST weight)
    assert results[0][0].id == 1
    assert results[1][0].id == 2
    assert results[0][1] > results[1][1]
|
iceaxe/alias_values.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Generic, TypeVar, cast
|
|
3
|
+
|
|
4
|
+
T = TypeVar("T")
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass(frozen=True, slots=True)
|
|
8
|
+
class Alias(Generic[T]):
|
|
9
|
+
name: str
|
|
10
|
+
value: T
|
|
11
|
+
|
|
12
|
+
def __str__(self):
|
|
13
|
+
return self.name
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def alias(name: str, type: T) -> T:
|
|
17
|
+
"""
|
|
18
|
+
Creates an alias for a field in raw SQL queries, allowing for type-safe mapping of raw SQL results.
|
|
19
|
+
This is particularly useful in two main scenarios:
|
|
20
|
+
|
|
21
|
+
1. When using raw SQL queries with aliased columns:
|
|
22
|
+
```python
|
|
23
|
+
# Map a COUNT(*) result to an integer
|
|
24
|
+
query = select(alias("user_count", int)).text(
|
|
25
|
+
"SELECT COUNT(*) AS user_count FROM users"
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
# Map multiple aliased columns with different types
|
|
29
|
+
query = select((
|
|
30
|
+
alias("full_name", str),
|
|
31
|
+
alias("order_count", int),
|
|
32
|
+
alias("total_spent", float)
|
|
33
|
+
)).text(
|
|
34
|
+
'''
|
|
35
|
+
SELECT
|
|
36
|
+
concat(first_name, ' ', last_name) AS full_name,
|
|
37
|
+
COUNT(orders.id) AS order_count,
|
|
38
|
+
SUM(orders.amount) AS total_spent
|
|
39
|
+
FROM users
|
|
40
|
+
LEFT JOIN orders ON users.id = orders.user_id
|
|
41
|
+
GROUP BY users.id
|
|
42
|
+
'''
|
|
43
|
+
)
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
2. When combining ORM models with function results:
|
|
47
|
+
```python
|
|
48
|
+
# Select a model alongside a function result
|
|
49
|
+
query = select((
|
|
50
|
+
User,
|
|
51
|
+
alias("name_length", func.length(User.name)),
|
|
52
|
+
alias("upper_name", func.upper(User.name))
|
|
53
|
+
))
|
|
54
|
+
|
|
55
|
+
# Use with aggregation functions
|
|
56
|
+
query = select((
|
|
57
|
+
User,
|
|
58
|
+
alias("total_orders", func.count(Order.id))
|
|
59
|
+
)).join(Order, User.id == Order.user_id).group_by(User.id)
|
|
60
|
+
```
|
|
61
|
+
|
|
62
|
+
:param name: The name of the alias as it appears in the SQL query's AS clause
|
|
63
|
+
:param type: Either a Python type to cast the result to (e.g., int, str, float) or
|
|
64
|
+
a function metadata object (e.g., from func.length())
|
|
65
|
+
:return: A type-safe alias that can be used in select() statements
|
|
66
|
+
"""
|
|
67
|
+
return cast(T, Alias(name, type))
|
iceaxe/base.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
1
|
+
from typing import (
|
|
2
|
+
TYPE_CHECKING,
|
|
3
|
+
Any,
|
|
4
|
+
Callable,
|
|
5
|
+
ClassVar,
|
|
6
|
+
Self,
|
|
7
|
+
Type,
|
|
8
|
+
dataclass_transform,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
from pydantic import BaseModel, Field as PydanticField
|
|
12
|
+
from pydantic.main import _model_construction
|
|
13
|
+
from pydantic_core import PydanticUndefined
|
|
14
|
+
|
|
15
|
+
from iceaxe.field import DBFieldClassDefinition, DBFieldInfo, Field
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass_transform(kw_only_default=True, field_specifiers=(PydanticField,))
class DBModelMetaclass(_model_construction.ModelMetaclass):
    """
    Metaclass for database model classes that provides automatic field tracking and SQL query generation.
    Extends Pydantic's model metaclass to add database-specific functionality.

    This metaclass provides:
    - Automatic field to SQL column mapping
    - Dynamic field access that returns query-compatible field definitions
    - Registry tracking of all database model classes
    - Support for generic model instantiation

    ```python {{sticky: True}}
    class User(TableBase):  # Uses DBModelMetaclass
        id: int = Field(primary_key=True)
        name: str
        email: str | None

    # Fields can be accessed for queries
    User.id  # Returns DBFieldClassDefinition
    User.name  # Returns DBFieldClassDefinition

    # Metaclass handles model registration
    registered_models = DBModelMetaclass.get_registry()
    ```
    """

    _registry: list[Type["TableBase"]] = []
    _cached_args: dict[Type["TableBase"], dict[str, Any]] = {}
    is_constructing: bool = False

    def __new__(mcs, name, bases, namespace, **kwargs):
        """
        Create a new database model class with proper field tracking.
        Handles registration of the model and processes any table-specific arguments.

        Accepts an extra ``autodetect`` class keyword (default True). When it is
        False, the new class is not appended to the model registry.
        """
        raw_kwargs = {**kwargs}

        # While pydantic builds the class, __getattr__ must fall back to the
        # default lookup instead of producing field definitions. Save the prior
        # value so a re-entrant construction (another model class built while this
        # one is) does not clear the flag for the outer class. Restore it in
        # `finally` so a failed class definition cannot leave the flag stuck on,
        # which would break field access on every model afterwards.
        previous_constructing = mcs.is_constructing
        mcs.is_constructing = True
        try:
            autodetect = mcs._extract_kwarg(kwargs, "autodetect", True)
            cls = super().__new__(mcs, name, bases, namespace, **kwargs)
        finally:
            mcs.is_constructing = previous_constructing

        # Allow future calls to subclasses / generic instantiations to reference the same
        # kwargs as the base class
        mcs._cached_args[cls] = raw_kwargs

        # If we have already set the class's fields, we should wrap them
        if hasattr(cls, "__pydantic_fields__"):
            cls.__pydantic_fields__ = {
                field: info
                if isinstance(info, DBFieldInfo)
                else DBFieldInfo.extend_field(
                    info,
                    primary_key=False,
                    postgres_config=None,
                    foreign_key=None,
                    unique=False,
                    index=False,
                    check_expression=None,
                    is_json=False,
                    explicit_type=None,
                )
                for field, info in cls.model_fields.items()
            }

        # Don't register the library's own base classes; models declared with
        # autodetect=False opt out of registration as well.
        if cls.__name__ not in {"TableBase", "BaseModel"} and autodetect:
            DBModelMetaclass._registry.append(cls)

        return cls

    def __getattr__(self, key: str) -> Any:
        """
        Provides dynamic access to model fields as query-compatible definitions.
        When accessing an undefined attribute, checks if it's a model field and returns
        a DBFieldClassDefinition if it is.

        :param key: The attribute name to access
        :return: Field definition or raises AttributeError
        :raises AttributeError: If the attribute doesn't exist and isn't a model field
        """
        if self.is_constructing:
            return super().__getattr__(key)  # type: ignore

        try:
            return super().__getattr__(key)  # type: ignore
        except AttributeError:
            # Determine if this field is defined within the spec
            # If so, return it
            if key in self.model_fields:
                return DBFieldClassDefinition(
                    root_model=self,  # type: ignore
                    key=key,
                    field_definition=self.model_fields[key],
                )
            raise

    @classmethod
    def get_registry(cls) -> list[Type["TableBase"]]:
        """
        Get the list of all registered database model classes, in definition order.

        :return: The live registry list (not a copy) of registered TableBase classes
        """
        return cls._registry

    @classmethod
    def _extract_kwarg(
        cls, kwargs: dict[str, Any], key: str, default: Any = None
    ) -> Any:
        """
        Extract a keyword argument from either standard kwargs or pydantic generic metadata.
        Handles both normal instantiation and pydantic's generic model instantiation.

        :param kwargs: Dictionary of keyword arguments (the key is popped if present)
        :param key: Key to extract
        :param default: Default value if key not found
        :return: Extracted value or default
        """
        if key in kwargs:
            return kwargs.pop(key)

        if "__pydantic_generic_metadata__" in kwargs:
            # Parametrized generics inherit the class kwargs of their origin model.
            origin_model = kwargs["__pydantic_generic_metadata__"]["origin"]
            if origin_model in cls._cached_args:
                return cls._cached_args[origin_model].get(key, default)

        return default

    @property
    def model_fields(self) -> dict[str, DBFieldInfo]:  # type: ignore
        """
        Get the dictionary of model fields and their definitions.
        Overrides the ClassVar typehint from TableBase for proper typing.

        :return: Dictionary of field names to field definitions
        """
        return getattr(self, "__pydantic_fields__", {})  # type: ignore
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class UniqueConstraint(BaseModel):
    """
    Table-level UNIQUE constraint over one or more columns.

    No two rows may share the same combination of values across ``columns``.
    Declare it in a model's ``table_args``:

    ```python {{sticky: True}}
    class User(TableBase):
        email: str
        tenant_id: int

        table_args = [
            UniqueConstraint(columns=["email", "tenant_id"])
        ]
    ```
    """

    # Column names whose combined values must be unique.
    columns: list[str]
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
class IndexConstraint(BaseModel):
    """
    Table-level INDEX over one or more columns, to speed up queries filtering on them.

    Declare it in a model's ``table_args``:

    ```python {{sticky: True}}
    class User(TableBase):
        email: str
        last_login: datetime

        table_args = [
            IndexConstraint(columns=["last_login"])
        ]
    ```
    """

    # Column names the index is built on.
    columns: list[str]
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
# Names of TableBase's own bookkeeping fields (modified-attribute tracking and
# its callbacks). TableBase.get_client_fields() and TableBase.__eq__ skip them.
INTERNAL_TABLE_FIELDS = ["modified_attrs", "modified_attrs_callbacks"]
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class TableBase(BaseModel, metaclass=DBModelMetaclass):
    """
    Base class for all database table models.
    Provides the foundation for defining database tables using Python classes with
    type hints and field definitions.

    Features:
    - Automatic table name generation from class name
    - Support for custom table names
    - Tracking of modified fields for efficient updates
    - Support for unique constraints and indexes
    - Integration with Pydantic for validation

    ```python {{sticky: True}}
    class User(TableBase):
        # Custom table name (optional)
        table_name = "users"

        # Fields with types and constraints
        id: int = Field(primary_key=True)
        email: str = Field(unique=True)
        name: str
        is_active: bool = Field(default=True)

        # Table-level constraints
        table_args = [
            UniqueConstraint(columns=["email"]),
            IndexConstraint(columns=["name"])
        ]

    # Usage in queries
    query = select(User).where(User.is_active == True)
    users = await conn.exec(query)
    ```
    """

    if TYPE_CHECKING:
        model_fields: ClassVar[dict[str, DBFieldInfo]]  # type: ignore

    table_name: ClassVar[str] = PydanticUndefined  # type: ignore
    """
    Optional custom name for the table
    """

    table_args: ClassVar[list[UniqueConstraint | IndexConstraint]] = PydanticUndefined  # type: ignore
    """
    Table constraints and indexes
    """

    # Internal bookkeeping fields (excluded from serialization)
    modified_attrs: dict[str, Any] = Field(default_factory=dict, exclude=True)
    """
    Dictionary of modified field values since instantiation or the last clear_modified_attributes() call.
    Used to construct differential update queries.
    """

    modified_attrs_callbacks: list[Callable[[Self], None]] = Field(
        default_factory=list, exclude=True
    )
    """
    List of callbacks to be called when the model is modified.
    """

    def __setattr__(self, name: str, value: Any) -> None:
        """
        Track modified attributes when fields are updated.
        This allows for efficient database updates by only updating changed fields.

        Assignments to the internal bookkeeping fields (INTERNAL_TABLE_FIELDS) are
        not recorded. They are not client columns, so recording them would put
        them into the differential update. Registered callbacks run before the
        new value is stored on the instance.

        :param name: Attribute name
        :param value: New value
        """
        if (
            name in self.__class__.model_fields
            and name not in INTERNAL_TABLE_FIELDS
        ):
            self.modified_attrs[name] = value
            for callback in self.modified_attrs_callbacks:
                callback(self)
        super().__setattr__(name, value)

    def get_modified_attributes(self) -> dict[str, Any]:
        """
        Get the dictionary of attributes that have been modified since instantiation
        or the last clear_modified_attributes() call.

        :return: The live tracking dict (not a copy) of modified attribute names and values
        """
        return self.modified_attrs

    def clear_modified_attributes(self) -> None:
        """
        Clear the tracking of modified attributes.
        Typically called after successfully saving changes to the database.
        """
        self.modified_attrs.clear()

    @classmethod
    def get_table_name(cls) -> str:
        """
        Get the table name for this model.
        Uses the custom table_name if set, otherwise converts the class name to lowercase.

        :return: Table name to use in SQL queries
        """
        # PydanticUndefined is a singleton sentinel meaning "no custom name set".
        if cls.table_name is PydanticUndefined:
            return cls.__name__.lower()
        return cls.table_name

    @classmethod
    def get_client_fields(cls) -> dict[str, DBFieldInfo]:
        """
        Get all fields that should be exposed to clients.
        Excludes internal fields used for model functionality.

        :return: Dictionary of field names to field definitions
        """
        return {
            field: info
            for field, info in cls.model_fields.items()
            if field not in INTERNAL_TABLE_FIELDS
        }

    def register_modified_callback(self, callback: Callable[[Self], None]) -> None:
        """
        Register a callback to be called when the model is modified.
        """
        self.modified_attrs_callbacks.append(callback)

    def __eq__(self, other: Any) -> bool:
        """
        Compare two model instances, ignoring the internal bookkeeping fields
        (modified_attrs and modified_attrs_callbacks).
        This ensures that two models with the same data but different change-tracking
        state or callbacks are considered equal.
        """
        if not isinstance(other, self.__class__):
            return False

        # Get all fields except the internal bookkeeping ones
        fields = {
            field: value
            for field, value in self.__dict__.items()
            if field not in INTERNAL_TABLE_FIELDS
        }
        other_fields = {
            field: value
            for field, value in other.__dict__.items()
            if field not in INTERNAL_TABLE_FIELDS
        }

        return fields == other_fields
|