hypern 0.3.11__cp310-cp310-musllinux_1_2_armv7l.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hypern/__init__.py +24 -0
- hypern/application.py +495 -0
- hypern/args_parser.py +73 -0
- hypern/auth/__init__.py +0 -0
- hypern/auth/authorization.py +2 -0
- hypern/background.py +4 -0
- hypern/caching/__init__.py +6 -0
- hypern/caching/backend.py +31 -0
- hypern/caching/redis_backend.py +201 -0
- hypern/caching/strategies.py +208 -0
- hypern/cli/__init__.py +0 -0
- hypern/cli/commands.py +0 -0
- hypern/config.py +246 -0
- hypern/database/__init__.py +0 -0
- hypern/database/sqlalchemy/__init__.py +4 -0
- hypern/database/sqlalchemy/config.py +66 -0
- hypern/database/sqlalchemy/repository.py +290 -0
- hypern/database/sqlx/__init__.py +36 -0
- hypern/database/sqlx/field.py +246 -0
- hypern/database/sqlx/migrate.py +263 -0
- hypern/database/sqlx/model.py +117 -0
- hypern/database/sqlx/query.py +904 -0
- hypern/datastructures.py +40 -0
- hypern/enum.py +13 -0
- hypern/exceptions/__init__.py +34 -0
- hypern/exceptions/base.py +62 -0
- hypern/exceptions/common.py +12 -0
- hypern/exceptions/errors.py +15 -0
- hypern/exceptions/formatters.py +56 -0
- hypern/exceptions/http.py +76 -0
- hypern/gateway/__init__.py +6 -0
- hypern/gateway/aggregator.py +32 -0
- hypern/gateway/gateway.py +41 -0
- hypern/gateway/proxy.py +60 -0
- hypern/gateway/service.py +52 -0
- hypern/hypern.cpython-310-arm-linux-gnueabihf.so +0 -0
- hypern/hypern.pyi +333 -0
- hypern/i18n/__init__.py +0 -0
- hypern/logging/__init__.py +3 -0
- hypern/logging/logger.py +82 -0
- hypern/middleware/__init__.py +17 -0
- hypern/middleware/base.py +13 -0
- hypern/middleware/cache.py +177 -0
- hypern/middleware/compress.py +78 -0
- hypern/middleware/cors.py +41 -0
- hypern/middleware/i18n.py +1 -0
- hypern/middleware/limit.py +177 -0
- hypern/middleware/security.py +184 -0
- hypern/openapi/__init__.py +5 -0
- hypern/openapi/schemas.py +51 -0
- hypern/openapi/swagger.py +3 -0
- hypern/processpool.py +139 -0
- hypern/py.typed +0 -0
- hypern/reload.py +46 -0
- hypern/response/__init__.py +3 -0
- hypern/response/response.py +142 -0
- hypern/routing/__init__.py +5 -0
- hypern/routing/dispatcher.py +70 -0
- hypern/routing/endpoint.py +30 -0
- hypern/routing/parser.py +98 -0
- hypern/routing/queue.py +175 -0
- hypern/routing/route.py +280 -0
- hypern/scheduler.py +5 -0
- hypern/worker.py +274 -0
- hypern/ws/__init__.py +4 -0
- hypern/ws/channel.py +80 -0
- hypern/ws/heartbeat.py +74 -0
- hypern/ws/room.py +76 -0
- hypern/ws/route.py +26 -0
- hypern-0.3.11.dist-info/METADATA +134 -0
- hypern-0.3.11.dist-info/RECORD +74 -0
- hypern-0.3.11.dist-info/WHEEL +4 -0
- hypern-0.3.11.dist-info/licenses/LICENSE +24 -0
- hypern.libs/libgcc_s-5b5488a6.so.1 +0 -0
hypern/database/sqlalchemy/config.py
@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
import asyncio
import traceback
from contextlib import asynccontextmanager

from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_scoped_session
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.sql.expression import Delete, Insert, Update


class SqlalchemyConfig:
    def __init__(self, default_engine: AsyncEngine | None = None, reader_engine: AsyncEngine | None = None, writer_engine: AsyncEngine | None = None):
        """
        Initialize the SQL configuration.
        You can provide a default engine, a reader engine, and a writer engine.
        If only one engine is provided (default_engine), it will be used for both reading and writing.
        If both reader and writer engines are provided, they will be used for reading and writing respectively.
        Note: The reader and writer engines must be different.
        """

        assert default_engine or reader_engine or writer_engine, "At least one engine must be provided."
        assert not (reader_engine and writer_engine and id(reader_engine) == id(writer_engine)), "Reader and writer engines must be different."

        engines = {
            "writer": writer_engine or default_engine,
            "reader": reader_engine or default_engine,
        }

        class RoutingSession(Session):
            def get_bind(this, mapper=None, clause=None, **kwargs):
                if this._flushing or isinstance(clause, (Update, Delete, Insert)):
                    return engines["writer"].sync_engine
                return engines["reader"].sync_engine

        async_session_factory = sessionmaker(
            class_=AsyncSession,
            sync_session_class=RoutingSession,
            expire_on_commit=False,
        )

        session_scope: AsyncSession | async_scoped_session = async_scoped_session(
            session_factory=async_session_factory,
            scopefunc=asyncio.current_task,
        )

        @asynccontextmanager
        async def get_session():
            """
            Get the database session.
            This can be used for dependency injection.

            :return: The database session.
            """
            try:
                yield session_scope
            except Exception:
                traceback.print_exc()
                await session_scope.rollback()
            finally:
                await session_scope.remove()
                await session_scope.close()

        self.get_session = get_session

    def init_app(self, app):
        app.inject("get_session", self.get_session)
hypern/database/sqlalchemy/repository.py
@@ -0,0 +1,290 @@
# -*- coding: utf-8 -*-
from functools import reduce
from typing import Any, Dict, Generic, Optional, Type, TypeVar

from sqlalchemy import Select, and_, asc, between, desc, select
from sqlalchemy.ext.asyncio import (
    AsyncSession,
)
from sqlalchemy.orm import declarative_base
from sqlalchemy.sql import func

Base = declarative_base()


class Model(Base):  # type: ignore
    __abstract__ = True
    __table_args__ = {"extend_existing": True}

    @property
    def as_dict(self) -> Dict[str, Any]:
        return {c.name: getattr(self, c.name) for c in self.__table__.columns}


ModelType = TypeVar("ModelType", bound=Base)  # type: ignore


class PostgresRepository(Generic[ModelType]):
    """Base class for data repositories."""

    def __init__(self, model: Type[ModelType], db_session: AsyncSession):
        self.session = db_session  # type: ignore
        self.model_class: Type[ModelType] = model

    async def create(self, attributes: Optional[dict[str, Any]] = None) -> ModelType:
        """
        Creates the model instance.

        :param attributes: The attributes to create the model with.
        :return: The created model instance.
        """
        if attributes is None:
            attributes = {}
        model = self.model_class(**attributes)  # type: ignore
        self.session.add(model)
        return model

    async def get_all(
        self,
        skip: int = 0,
        limit: int = 100,
        join_: set[str] | None = None,
        where: Optional[dict] = None,
        order_by: tuple[str, str] | None = None,
    ) -> list[ModelType]:
        """
        Returns a list of model instances.

        :param skip: The number of records to skip.
        :param limit: The number of record to return.
        :param join_: The joins to make.
        :param where: The conditions for the WHERE clause.
        :return: A list of model instances.
        """
        query = self._query(join_)
        query = query.offset(skip).limit(limit)

        if where is not None:
            conditions = []
            for k, v in where.items():
                if isinstance(v, dict) and "$gt" in v and "$lt" in v:
                    conditions.append(between(getattr(self.model_class, k), v["$gt"], v["$lt"]))
                else:
                    conditions.append(getattr(self.model_class, k) == v)
            query = query.where(and_(*conditions))

        if order_by is not None:
            column, direction = order_by
            if direction.lower() == "desc":
                query = query.order_by(desc(getattr(self.model_class, column)))
            else:
                query = query.order_by(asc(getattr(self.model_class, column)))

        if join_ is not None:
            return await self.all_unique(query)
        return await self._all(query)

    async def get_by(
        self,
        field: str,
        value: Any,
        join_: set[str] | None = None,
        unique: bool = False,
    ) -> ModelType | list[ModelType] | None:
        """
        Returns the model instance matching the field and value.

        :param field: The field to match.
        :param value: The value to match.
        :param join_: The joins to make.
        :return: The model instance.
        """
        query = self._query(join_)
        query = await self._get_by(query, field, value)

        if join_ is not None:
            return await self.all_unique(query)
        if unique:
            return await self._one(query)

        return await self._all(query)

    async def delete(self, model: ModelType) -> None:
        """
        Deletes the model.

        :param model: The model to delete.
        :return: None
        """
        await self.session.delete(model)

    async def update(self, model: ModelType, attributes: dict[str, Any]) -> ModelType:
        """
        Updates the model.

        :param model: The model to update.
        :param attributes: The attributes to update the model with.
        :return: The updated model instance.
        """
        for key, value in attributes.items():
            if hasattr(model, key):
                setattr(model, key, value)

        self.session.add(model)
        await self.session.commit()

        return model

    def _query(
        self,
        join_: set[str] | None = None,
        order_: dict | None = None,
    ) -> Select:
        """
        Returns a callable that can be used to query the model.

        :param join_: The joins to make.
        :param order_: The order of the results. (e.g desc, asc)
        :return: A callable that can be used to query the model.
        """
        query = select(self.model_class)
        query = self._maybe_join(query, join_)
        query = self._maybe_ordered(query, order_)

        return query

    async def _all(self, query: Select) -> list[ModelType]:
        """
        Returns all results from the query.

        :param query: The query to execute.
        :return: A list of model instances.
        """
        query = await self.session.scalars(query)
        return query.all()

    async def all_unique(self, query: Select) -> list[ModelType]:
        result = await self.session.execute(query)
        return result.unique().scalars().all()

    async def _first(self, query: Select) -> ModelType | None:
        """
        Returns the first result from the query.

        :param query: The query to execute.
        :return: The first model instance.
        """
        query = await self.session.scalars(query)
        return query.first()

    async def _one_or_none(self, query: Select) -> ModelType | None:
        """Returns the first result from the query or None."""
        query = await self.session.scalars(query)
        return query.one_or_none()

    async def _one(self, query: Select) -> ModelType:
        """
        Returns the first result from the query or raises NoResultFound.

        :param query: The query to execute.
        :return: The first model instance.
        """
        query = await self.session.scalars(query)
        return query.one()

    async def _count(self, query: Select) -> int:
        """
        Returns the count of the records.

        :param query: The query to execute.
        """
        query = query.subquery()
        query = await self.session.scalars(select(func.count()).select_from(query))
        return query.one()

    async def _sort_by(
        self,
        query: Select,
        sort_by: str,
        order: str | None = "asc",
        model: Type[ModelType] | None = None,
        case_insensitive: bool = False,
    ) -> Select:
        """
        Returns the query sorted by the given column.

        :param query: The query to sort.
        :param sort_by: The column to sort by.
        :param order: The order to sort by.
        :param model: The model to sort.
        :param case_insensitive: Whether to sort case insensitively.
        :return: The sorted query.
        """
        model = model or self.model_class

        order_column = None

        if case_insensitive:
            order_column = func.lower(getattr(model, sort_by))
        else:
            order_column = getattr(model, sort_by)

        if order == "desc":
            return query.order_by(order_column.desc())

        return query.order_by(order_column.asc())

    async def _get_by(self, query: Select, field: str, value: Any) -> Select:
        """
        Returns the query filtered by the given column.

        :param query: The query to filter.
        :param field: The column to filter by.
        :param value: The value to filter by.
        :return: The filtered query.
        """
        return query.where(getattr(self.model_class, field) == value)

    def _maybe_join(self, query: Select, join_: set[str] | None = None) -> Select:
        """
        Returns the query with the given joins.

        :param query: The query to join.
        :param join_: The joins to make.
        :return: The query with the given joins.
        """
        if not join_:
            return query

        if not isinstance(join_, set):
            raise TypeError("join_ must be a set")

        return reduce(self._add_join_to_query, join_, query)  # type: ignore

    def _maybe_ordered(self, query: Select, order_: dict | None = None) -> Select:
        """
        Returns the query ordered by the given column.

        :param query: The query to order.
        :param order_: The order to make.
        :return: The query ordered by the given column.
        """
        if order_:
            if order_["asc"]:
                for order in order_["asc"]:
                    query = query.order_by(getattr(self.model_class, order).asc())
            else:
                for order in order_["desc"]:
                    query = query.order_by(getattr(self.model_class, order).desc())

        return query

    def _add_join_to_query(self, query: Select, join_: str) -> Select:
        """
        Returns the query with the given join.

        :param query: The query to join.
        :param join_: The join to make.
        :return: The query with the given join.
        """
        return getattr(self, "_join_" + join_)(query)
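The repository class above is generic over a declarative model and operates on an AsyncSession, with a simple equality filter plus a {"$gt": ..., "$lt": ...} range form in get_all(). A small sketch, not part of the diff; the User model, the session, and the calling function are hypothetical.

# Illustrative sketch: the User model and the session passed in are hypothetical.
from sqlalchemy import Column, Integer, String
from sqlalchemy.ext.asyncio import AsyncSession

from hypern.database.sqlalchemy.repository import Model, PostgresRepository


class User(Model):
    __tablename__ = "users"

    id = Column(Integer, primary_key=True)
    email = Column(String(255), unique=True)


class UserRepository(PostgresRepository[User]):
    pass


async def list_recent_users(session: AsyncSession) -> list[User]:
    repo = UserRepository(model=User, db_session=session)
    # order_by takes a (column, direction) tuple; where accepts equality or range filters.
    return await repo.get_all(limit=20, order_by=("id", "desc"))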
hypern/database/sqlx/__init__.py
@@ -0,0 +1,36 @@
# from .context import SqlConfig, DatabaseType
from .field import (
    CharField,
    IntegerField,
    TextField,
    FloatField,
    BooleanField,
    ForeignKeyField,
    DateTimeField,
    Field,
    JSONField,
    ArrayField,
    DecimalField,
    DateField,
)
from .model import Model
from .query import F, Q, QuerySet

__all__ = [
    "CharField",
    "IntegerField",
    "TextField",
    "FloatField",
    "BooleanField",
    "ForeignKeyField",
    "DateTimeField",
    "Field",
    "JSONField",
    "ArrayField",
    "DecimalField",
    "DateField",
    "Model",
    "Q",
    "F",
    "QuerySet",
]
hypern/database/sqlx/field.py
@@ -0,0 +1,246 @@
import json
from datetime import date, datetime
from decimal import Decimal, InvalidOperation
from typing import Any, Optional, Union

from hypern.exceptions import DBFieldValidationError


class Field:
    """Base field class for ORM-like field definitions."""

    def __init__(
        self,
        field_type: str,
        primary_key: bool = False,
        null: bool = True,
        default: Any = None,
        unique: bool = False,
        index: bool = False,
        validators: Optional[list] = None,
        auto_increment: bool = False,
    ):
        self.field_type = field_type
        self.primary_key = primary_key
        self.null = null
        self.default = default
        self.unique = unique
        self.index = index
        self.validators = validators or []
        self.name = None
        self.model = None
        self.auto_increment = auto_increment

    def validate(self, value: Any) -> None:
        if value is None:
            if not self.null:
                raise DBFieldValidationError(f"Field {self.name} cannot be null")
            return

        for validator in self.validators:
            try:
                validator(value)
            except Exception as e:
                raise DBFieldValidationError(f"Validation failed for {self.name}: {str(e)}")

    def sql_type(self) -> str:
        """Return SQL type definition for the field."""
        type_mapping = {
            "int": "INTEGER",
            "str": "VARCHAR(255)",
            "float": "FLOAT",
            "bool": "BOOLEAN",
            "datetime": "TIMESTAMP",
            "date": "DATE",
            "text": "TEXT",
            "json": "JSONB",
            "array": "ARRAY",
            "decimal": "DECIMAL",
        }
        return type_mapping.get(self.field_type, "VARCHAR(255)")


class CharField(Field):
    def __init__(self, max_length: int = 255, **kwargs):
        super().__init__(field_type="str", **kwargs)
        self.max_length = max_length

    def validate(self, value: Any) -> None:
        super().validate(value)
        if value is not None:
            if not isinstance(value, str):
                raise DBFieldValidationError(f"Field {self.name} must be a string")
            if len(value) > self.max_length:
                raise DBFieldValidationError(f"Field {self.name} cannot exceed {self.max_length} characters")

    def sql_type(self) -> str:
        return f"VARCHAR({self.max_length})"


class TextField(Field):
    def __init__(self, **kwargs):
        super().__init__(field_type="text", **kwargs)

    def validate(self, value: Any) -> None:
        super().validate(value)
        if value is not None and not isinstance(value, str):
            raise DBFieldValidationError(f"Field {self.name} must be a string")


class IntegerField(Field):
    def __init__(self, **kwargs):
        super().__init__(field_type="int", **kwargs)

    def validate(self, value: Any) -> None:
        super().validate(value)
        if value is not None:
            try:
                int(value)
            except (TypeError, ValueError):
                raise DBFieldValidationError(f"Field {self.name} must be an integer")


class FloatField(Field):
    def __init__(self, **kwargs):
        super().__init__(field_type="float", **kwargs)

    def validate(self, value: Any) -> None:
        super().validate(value)
        if value is not None:
            try:
                float(value)
            except (TypeError, ValueError):
                raise DBFieldValidationError(f"Field {self.name} must be a float")


class BooleanField(Field):
    def __init__(self, **kwargs):
        super().__init__(field_type="bool", **kwargs)

    def validate(self, value: Any) -> None:
        super().validate(value)
        if value is not None and not isinstance(value, bool):
            raise DBFieldValidationError(f"Field {self.name} must be a boolean")


class DateTimeField(Field):
    def __init__(self, auto_now: bool = False, auto_now_add: bool = False, **kwargs):
        super().__init__(field_type="datetime", **kwargs)
        self.auto_now = auto_now
        self.auto_now_add = auto_now_add

    def validate(self, value: Any) -> None:
        super().validate(value)
        if value is not None and not isinstance(value, datetime):
            raise DBFieldValidationError(f"Field {self.name} must be a datetime object")


class DateField(Field):
    def __init__(self, auto_now: bool = False, auto_now_add: bool = False, **kwargs):
        super().__init__(field_type="date", **kwargs)
        self.auto_now = auto_now
        self.auto_now_add = auto_now_add

    def validate(self, value: Any) -> None:
        super().validate(value)
        if value is not None and not isinstance(value, date):
            raise DBFieldValidationError(f"Field {self.name} must be a date object")


class JSONField(Field):
    def __init__(self, **kwargs):
        super().__init__(field_type="json", **kwargs)

    def validate(self, value: Any) -> None:
        super().validate(value)
        if value is not None:
            try:
                json.dumps(value)
            except (TypeError, ValueError):
                raise DBFieldValidationError(f"Field {self.name} must be JSON serializable")


class ArrayField(Field):
    def __init__(self, base_field: Field, **kwargs):
        super().__init__(field_type="array", **kwargs)
        self.base_field = base_field

    def validate(self, value: Any) -> None:
        super().validate(value)
        if value is not None:
            if not isinstance(value, (list, tuple)):
                raise DBFieldValidationError(f"Field {self.name} must be a list or tuple")
            for item in value:
                self.base_field.validate(item)

    def sql_type(self) -> str:
        return f"{self.base_field.sql_type()}[]"


class DecimalField(Field):
    def __init__(self, max_digits: int = 10, decimal_places: int = 2, **kwargs):
        super().__init__(field_type="decimal", **kwargs)
        self.max_digits = max_digits
        self.decimal_places = decimal_places

    def validate(self, value: Any) -> None:
        super().validate(value)
        if value is not None:
            try:
                decimal_value = Decimal(str(value))
                decimal_tuple = decimal_value.as_tuple()
                if len(decimal_tuple.digits) - (-decimal_tuple.exponent) > self.max_digits:
                    raise DBFieldValidationError(f"Field {self.name} exceeds maximum digits {self.max_digits}")
                if -decimal_tuple.exponent > self.decimal_places:
                    raise DBFieldValidationError(f"Field {self.name} exceeds maximum decimal places {self.decimal_places}")
            except InvalidOperation:
                raise DBFieldValidationError(f"Field {self.name} must be a valid decimal number")

    def sql_type(self) -> str:
        return f"DECIMAL({self.max_digits},{self.decimal_places})"


class ForeignKeyField(Field):
    """Field for foreign key relationships."""

    def __init__(
        self,
        to_model: Union[str, Any],
        related_field: str = "id",
        on_delete: str = "CASCADE",
        on_update: str = "CASCADE",
        related_name: Optional[str] = None,
        **kwargs,
    ):
        if isinstance(to_model, str):
            field_type = "int"
        else:
            related_field_obj = getattr(to_model, related_field, None)
            if related_field_obj is None:
                raise ValueError(f"Field {related_field} not found in model {to_model.__name__}")
            field_type = related_field_obj.field_type

        super().__init__(field_type=field_type, **kwargs)
        self.to_model = to_model
        self.related_field = related_field
        self.on_delete = on_delete.upper()
        self.on_update = on_update.upper()
        self.related_name = related_name

        valid_actions = {"CASCADE", "SET NULL", "RESTRICT", "NO ACTION"}
        if self.on_delete not in valid_actions:
            raise ValueError(f"Invalid on_delete action. Must be one of: {valid_actions}")
        if self.on_update not in valid_actions:
            raise ValueError(f"Invalid on_update action. Must be one of: {valid_actions}")

        if (self.on_delete == "SET NULL" or self.on_update == "SET NULL") and not kwargs.get("null", True):
            raise ValueError("Field must be nullable to use SET NULL referential action")

    def validate(self, value: Any) -> None:
        super().validate(value)
        if value is not None and not isinstance(self.to_model, str):
            related_field_obj = getattr(self.to_model, self.related_field)
            try:
                related_field_obj.validate(value)
            except DBFieldValidationError as e:
                raise DBFieldValidationError(f"Foreign key {self.name} validation failed: {str(e)}")
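The field classes above validate Python values and emit SQL type strings. A brief sketch exercising them directly, outside a sqlx Model (illustrative only; binding fields to a model happens in sqlx/model.py, which is not shown here, so field.name stays None and appears as "None" in the error messages).

# Illustrative sketch: fields used standalone, not attached to a Model class.
from hypern.database.sqlx import CharField, DecimalField
from hypern.exceptions import DBFieldValidationError

email = CharField(max_length=64, null=False)
price = DecimalField(max_digits=6, decimal_places=2)

email.validate("user@example.com")   # passes: string within max_length
price.validate("1234.56")            # passes: 6 total digits, 2 decimal places

try:
    price.validate("12345.678")      # rejected: 3 decimal places exceed decimal_places=2
except DBFieldValidationError as exc:
    print(exc)                       # "Field None exceeds maximum decimal places 2"

print(email.sql_type())              # VARCHAR(64)
print(price.sql_type())              # DECIMAL(6,2)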