ns-orm 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ns_orm/__init__.py +96 -0
- ns_orm/cli.py +174 -0
- ns_orm/database.py +292 -0
- ns_orm/dialects.py +290 -0
- ns_orm/exceptions.py +26 -0
- ns_orm/expressions.py +108 -0
- ns_orm/fields.py +313 -0
- ns_orm/manager.py +72 -0
- ns_orm/migrations/__init__.py +3 -0
- ns_orm/migrations/autodetector.py +159 -0
- ns_orm/migrations/executor.py +150 -0
- ns_orm/migrations/loader.py +53 -0
- ns_orm/migrations/migration.py +14 -0
- ns_orm/migrations/operations.py +93 -0
- ns_orm/migrations/state.py +42 -0
- ns_orm/migrations/writer.py +79 -0
- ns_orm/model.py +151 -0
- ns_orm/query.py +659 -0
- ns_orm/schema.py +131 -0
- ns_orm/typing.py +39 -0
- ns_orm/utils.py +58 -0
- ns_orm-0.0.0.dist-info/METADATA +289 -0
- ns_orm-0.0.0.dist-info/RECORD +27 -0
- ns_orm-0.0.0.dist-info/WHEEL +5 -0
- ns_orm-0.0.0.dist-info/entry_points.txt +2 -0
- ns_orm-0.0.0.dist-info/licenses/LICENSE +201 -0
- ns_orm-0.0.0.dist-info/top_level.txt +1 -0
ns_orm/fields.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any, Dict, List, Optional, Type, Union
|
|
5
|
+
|
|
6
|
+
from typing_extensions import Literal
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass(frozen=True)
class FieldDef:
    """Immutable bundle of column options shared by field definitions.

    Every option has a default, so ``FieldDef()`` is a valid instance.
    """

    primary_key: bool = False
    nullable: Optional[bool] = None
    unique: bool = False
    index: bool = False
    default: Any = None
    server_default: Any = None
    autoincrement: Optional[bool] = None

    def ddl(self) -> dict[str, Any]:
        """Return the seven base options as a plain dict keyed by option name.

        Only the options declared on ``FieldDef`` are reported; fields added
        by subclasses are not part of the result.
        """
        option_names = (
            "primary_key",
            "nullable",
            "unique",
            "index",
            "default",
            "server_default",
            "autoincrement",
        )
        return {option: getattr(self, option) for option in option_names}
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@dataclass(frozen=True)
class ColumnField(FieldDef):
    """Field definition that additionally carries a column type in ``sa_type``."""

    # Type marker for the column; the concrete subclasses in this module
    # override it with a type-name string. Defaults to None here.
    sa_type: Any = None
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass(frozen=True)
class TinyInt(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"TINYINT"``."""

    sa_type: Any = "TINYINT"
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass(frozen=True)
class SmallInt(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"SMALLINT"``."""

    sa_type: Any = "SMALLINT"
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass(frozen=True)
class Int(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"INTEGER"``."""

    sa_type: Any = "INTEGER"
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass(frozen=True)
class BigInt(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"BIGINT"``."""

    sa_type: Any = "BIGINT"
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@dataclass(frozen=True)
class UTinyInt(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"UTINYINT"``."""

    sa_type: Any = "UTINYINT"
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
@dataclass(frozen=True)
class USmallInt(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"USMALLINT"``."""

    sa_type: Any = "USMALLINT"
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@dataclass(frozen=True)
class UInt(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"UINTEGER"``."""

    sa_type: Any = "UINTEGER"
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclass(frozen=True)
class UBigInt(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"UBIGINT"``."""

    sa_type: Any = "UBIGINT"
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@dataclass(frozen=True)
class Float(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"FLOAT"``."""

    sa_type: Any = "FLOAT"
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@dataclass(frozen=True)
class Real(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"REAL"``."""

    sa_type: Any = "REAL"
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
@dataclass(frozen=True)
class Double(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"DOUBLE"``."""

    sa_type: Any = "DOUBLE"
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@dataclass(frozen=True)
class Decimal(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"DECIMAL"``.

    Carries ``precision`` and ``scale`` settings; this class does not render
    them itself.
    """

    precision: int = 18  # default precision
    scale: int = 6  # default scale
    sa_type: Any = "DECIMAL"
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@dataclass(frozen=True)
class Boolean(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"BOOLEAN"``."""

    sa_type: Any = "BOOLEAN"
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
@dataclass(frozen=True)
class Char(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"CHAR"``."""

    length: int = 1  # default length
    sa_type: Any = "CHAR"
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
@dataclass(frozen=True)
class String(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"VARCHAR"``."""

    max_length: int = 255  # default maximum length
    sa_type: Any = "VARCHAR"
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@dataclass(frozen=True)
class Text(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"TEXT"``."""

    sa_type: Any = "TEXT"
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@dataclass(frozen=True)
class Binary(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"BINARY"``."""

    length: Optional[int] = None  # optional length; None means unspecified
    sa_type: Any = "BINARY"
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@dataclass(frozen=True)
class DateTime(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"DATETIME"``."""

    timezone: bool = False  # timezone flag; off by default
    sa_type: Any = "DATETIME"
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
@dataclass(frozen=True)
class Date(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"DATE"``."""

    sa_type: Any = "DATE"
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
@dataclass(frozen=True)
class JSON(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"JSON"``."""

    sa_type: Any = "JSON"
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
@dataclass(frozen=True)
class UUID(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"UUID"``."""

    sa_type: Any = "UUID"
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
@dataclass(frozen=True)
class IPv4(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"IPv4"``."""

    sa_type: Any = "IPv4"

    def render_type(self, dialect: Any) -> str:
        """Return the SQL type text for ``dialect``.

        A dialect whose ``name`` attribute is ``"clickhouse"`` gets ``IPv4``;
        any other dialect (including one without a ``name``) gets
        ``VARCHAR(15)``.
        """
        on_clickhouse = getattr(dialect, "name", "") == "clickhouse"
        return "IPv4" if on_clickhouse else "VARCHAR(15)"
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
@dataclass(frozen=True)
class IPv6(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"IPv6"``."""

    sa_type: Any = "IPv6"

    def render_type(self, dialect: Any) -> str:
        """Return the SQL type text for ``dialect``.

        A dialect whose ``name`` attribute is ``"clickhouse"`` gets ``IPv6``;
        any other dialect (including one without a ``name``) gets
        ``VARCHAR(39)``.
        """
        on_clickhouse = getattr(dialect, "name", "") == "clickhouse"
        return "IPv6" if on_clickhouse else "VARCHAR(39)"
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
@dataclass(frozen=True)
class Date32(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"Date32"``."""

    sa_type: Any = "Date32"

    def render_type(self, dialect: Any) -> str:
        """Return ``Date32`` for a dialect named ``clickhouse``, else ``DATE``."""
        on_clickhouse = getattr(dialect, "name", "") == "clickhouse"
        return "Date32" if on_clickhouse else "DATE"
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
@dataclass(frozen=True)
class DateTime64(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"DateTime64"``."""

    precision: int = 3  # rendered as the first DateTime64 argument
    timezone: Optional[str] = None  # rendered as a quoted second argument when truthy
    sa_type: Any = "DateTime64"

    def render_type(self, dialect: Any) -> str:
        """Return the SQL type text for ``dialect``.

        Dialects not named ``clickhouse`` get ``TIMESTAMP``. For
        ``clickhouse`` the result is ``DateTime64(<precision>)``, with the
        timezone appended as a single-quoted second argument when one is set.
        """
        if getattr(dialect, "name", "") != "clickhouse":
            return "TIMESTAMP"
        if not self.timezone:
            return f"DateTime64({int(self.precision)})"
        # The timezone is interpolated verbatim between single quotes.
        return f"DateTime64({int(self.precision)}, '{self.timezone}')"
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
@dataclass(frozen=True)
class FixedString(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"FixedString"``."""

    length: int = 16  # rendered as the type's length argument
    sa_type: Any = "FixedString"

    def render_type(self, dialect: Any) -> str:
        """Return ``FixedString(n)`` for ``clickhouse``, else ``VARCHAR(n)``.

        ``n`` is ``int(self.length)`` in both cases.
        """
        on_clickhouse = getattr(dialect, "name", "") == "clickhouse"
        type_name = "FixedString" if on_clickhouse else "VARCHAR"
        return f"{type_name}({int(self.length)})"
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
@dataclass(frozen=True)
class Enum8(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"Enum8"``."""

    # Mapping of enum label -> integer value, rendered in insertion order.
    values: Dict[str, int] = field(default_factory=dict)
    sa_type: Any = "Enum8"

    def render_type(self, dialect: Any) -> str:
        """Return the SQL type text for ``dialect``.

        Dialects not named ``clickhouse`` get ``VARCHAR(255)``. For
        ``clickhouse`` the result is ``Enum8('label' = number, ...)``.

        Labels are emitted as single-quoted literals, so backslashes and
        single quotes inside a label are backslash-escaped; otherwise a label
        such as ``it's`` would terminate the literal early and produce
        invalid DDL. Labels without those characters render unchanged.
        """
        if getattr(dialect, "name", "") != "clickhouse":
            return "VARCHAR(255)"
        rendered = []
        for label, number in self.values.items():
            # Escape the backslash first so the quote escape is not doubled.
            safe_label = str(label).replace("\\", "\\\\").replace("'", "\\'")
            rendered.append(f"'{safe_label}' = {int(number)}")
        items = ", ".join(rendered)
        return f"Enum8({items})"
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
@dataclass(frozen=True)
class Enum16(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"Enum16"``."""

    # Mapping of enum label -> integer value, rendered in insertion order.
    values: Dict[str, int] = field(default_factory=dict)
    sa_type: Any = "Enum16"

    def render_type(self, dialect: Any) -> str:
        """Return the SQL type text for ``dialect``.

        Dialects not named ``clickhouse`` get ``VARCHAR(255)``. For
        ``clickhouse`` the result is ``Enum16('label' = number, ...)``.

        Labels are emitted as single-quoted literals, so backslashes and
        single quotes inside a label are backslash-escaped; otherwise a label
        such as ``it's`` would terminate the literal early and produce
        invalid DDL. Labels without those characters render unchanged.
        """
        if getattr(dialect, "name", "") != "clickhouse":
            return "VARCHAR(255)"
        rendered = []
        for label, number in self.values.items():
            # Escape the backslash first so the quote escape is not doubled.
            safe_label = str(label).replace("\\", "\\\\").replace("'", "\\'")
            rendered.append(f"'{safe_label}' = {int(number)}")
        items = ", ".join(rendered)
        return f"Enum16({items})"
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
@dataclass(frozen=True)
class Array(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"Array"``."""

    item: Any = None  # element field; None renders as an array of String
    sa_type: Any = "Array"

    def render_type(self, dialect: Any) -> str:
        """Return the SQL type text for ``dialect``.

        Dialects not named ``clickhouse`` always get ``TEXT``. For
        ``clickhouse`` the result is ``Array(<inner>)``, where ``<inner>`` is
        ``String`` when no item is set and ``dialect.type_sql(item)``
        otherwise.
        """
        if self.item is None:
            on_clickhouse = getattr(dialect, "name", "") == "clickhouse"
            return "Array(String)" if on_clickhouse else "TEXT"
        # The element type is resolved before the dialect check, so
        # ``type_sql`` runs for every dialect even when the result is unused.
        element_sql = dialect.type_sql(self.item)
        if getattr(dialect, "name", "") != "clickhouse":
            return "TEXT"
        return f"Array({element_sql})"
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
@dataclass(frozen=True)
class Map(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"Map"``."""

    key: Any = None  # key field; None renders as String
    value: Any = None  # value field; None renders as String
    sa_type: Any = "Map"

    def render_type(self, dialect: Any) -> str:
        """Return ``Map(<key>, <value>)`` for ``clickhouse``, else ``JSON``.

        Key and value types are resolved through ``dialect.type_sql`` only on
        the ``clickhouse`` path.
        """
        if getattr(dialect, "name", "") != "clickhouse":
            return "JSON"
        key_sql = dialect.type_sql(self.key) if self.key is not None else "String"
        value_sql = (
            dialect.type_sql(self.value) if self.value is not None else "String"
        )
        return f"Map({key_sql}, {value_sql})"
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
@dataclass(frozen=True)
class TupleType(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"Tuple"``."""

    items: List[Any] = field(default_factory=list)  # element fields, in order
    sa_type: Any = "Tuple"

    def render_type(self, dialect: Any) -> str:
        """Return ``Tuple(<types>)`` for ``clickhouse``, else ``TEXT``.

        Each element is rendered with ``dialect.type_sql``; when the joined
        element text is empty the tuple falls back to a single ``String``.
        """
        if getattr(dialect, "name", "") != "clickhouse":
            return "TEXT"
        element_sql = ", ".join(dialect.type_sql(element) for element in self.items)
        if not element_sql:
            element_sql = "String"
        return f"Tuple({element_sql})"
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
@dataclass(frozen=True)
class LowCardinality(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"LowCardinality"``."""

    inner: Any = None  # wrapped field; None renders as String
    sa_type: Any = "LowCardinality"

    def render_type(self, dialect: Any) -> str:
        """Return the wrapped type, inside ``LowCardinality(...)`` on ClickHouse.

        Dialects not named ``clickhouse`` get the inner type text unwrapped.
        """
        if self.inner is None:
            wrapped_sql = "String"
        else:
            wrapped_sql = dialect.type_sql(self.inner)
        if getattr(dialect, "name", "") != "clickhouse":
            return wrapped_sql
        return f"LowCardinality({wrapped_sql})"
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
@dataclass(frozen=True)
class Nullable(ColumnField):
    """Column field whose ``sa_type`` defaults to ``"Nullable"``."""

    inner: Any = None  # wrapped field; None renders as String
    sa_type: Any = "Nullable"

    def render_type(self, dialect: Any) -> str:
        """Return the wrapped type, inside ``Nullable(...)`` on ClickHouse.

        Dialects not named ``clickhouse`` get the inner type text unwrapped.
        """
        if self.inner is None:
            wrapped_sql = "String"
        else:
            wrapped_sql = dialect.type_sql(self.inner)
        if getattr(dialect, "name", "") != "clickhouse":
            return wrapped_sql
        return f"Nullable({wrapped_sql})"
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
@dataclass(frozen=True)
class ForeignKey(FieldDef):
    """Field definition that references another model.

    Raises:
        TypeError: if ``to`` is not supplied.
    """

    # Referenced model: a name or a class. ``to`` is required, but every
    # inherited FieldDef field has a default, and dataclasses reject a
    # non-default field after defaulted ones (the class statement itself
    # raises TypeError). It therefore carries a ``None`` placeholder default
    # and the requirement is enforced in ``__post_init__``.
    to: Union[str, Type[Any], None] = None
    to_field: str = "id"  # referenced column on the target
    on_delete: Literal["RESTRICT", "CASCADE", "SET NULL"] = "RESTRICT"

    def __post_init__(self) -> None:
        """Reject instances created without a target model."""
        if self.to is None:
            raise TypeError(
                f"{type(self).__name__}() missing required argument: 'to'"
            )
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
@dataclass(frozen=True)
class OneToOne(ForeignKey):
    """Foreign key whose ``unique`` option defaults to ``True``."""

    unique: bool = True
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
@dataclass(frozen=True)
class ManyToMany:
    """Immutable description of a many-to-many relation.

    Only ``to`` is required; every other setting is optional.
    """

    to: Union[str, Type[Any]]  # related model: a name or a class
    through: Optional[str] = None  # through-table name override
    from_field: Optional[str] = None  # column name override for the owning side
    to_field: Optional[str] = None  # column name override for the related side
    related_name: Optional[str] = None
    # Free-form extra settings; each instance gets its own dict.
    extra: dict[str, Any] = field(default_factory=dict)
|
ns_orm/manager.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Generic, TypeVar
|
|
5
|
+
|
|
6
|
+
from ns_orm.database import AsyncDatabase, Database, get_connection
|
|
7
|
+
from ns_orm.query import AsyncQuerySet, QuerySet
|
|
8
|
+
|
|
9
|
+
TModel = TypeVar("TModel")


@dataclass
class Manager(Generic[TModel]):
    """Model-bound entry point that builds query sets against a database.

    ``db`` may be set explicitly; when it is ``None`` the connection is looked
    up from the model's ``_meta.connect_name``.
    """

    model: type[TModel]
    db: Any = None

    def using(self, db: Any) -> Manager[TModel]:
        """Return a copy of this manager bound to ``db``.

        A string is treated as a connection name and resolved through
        ``get_connection``; any other value is used as the database as-is.
        The copy keeps the concrete manager class (and any extra dataclass
        fields a subclass declares) rather than always being a plain
        ``Manager``.
        """
        from dataclasses import replace

        target = get_connection(db) if isinstance(db, str) else db
        return replace(self, db=target)

    def _get_db(self) -> Any:
        """Return the database to run against.

        Raises:
            ValueError: if no database is set and the model has no
                ``_meta.connect_name``.
        """
        if self.db is not None:
            return self.db
        connect_name = getattr(getattr(self.model, "_meta", None), "connect_name", None)
        if not connect_name:
            raise ValueError(
                "Database is not set and Model.Meta.connect_name is missing"
            )
        return get_connection(connect_name)

    def qs(self) -> QuerySet[TModel]:
        """Return a sync ``QuerySet``; raises ``ValueError`` unless db is a ``Database``."""
        db = self._get_db()
        if not isinstance(db, Database):
            raise ValueError("Database instance required for sync operations")
        return QuerySet(self.model, db=db)

    def aqs(self) -> AsyncQuerySet[TModel]:
        """Return an ``AsyncQuerySet``; raises ``ValueError`` unless db is an ``AsyncDatabase``."""
        db = self._get_db()
        if not isinstance(db, AsyncDatabase):
            raise ValueError("AsyncDatabase instance required for async operations")
        return AsyncQuerySet(self.model, db=db)

    def all(self) -> list[TModel]:
        """Return ``qs().all()``."""
        return self.qs().all()

    async def aall(self) -> list[TModel]:
        """Return the awaited ``aqs().all()``."""
        return await self.aqs().all()

    def filter(self, **lookups: Any) -> Any:
        """Filter on the async query set for an ``AsyncDatabase``, else the sync one."""
        if isinstance(self._get_db(), AsyncDatabase):
            return self.aqs().filter(**lookups)
        return self.qs().filter(**lookups)

    def exclude(self, **lookups: Any) -> Any:
        """Exclude on the async query set for an ``AsyncDatabase``, else the sync one."""
        if isinstance(self._get_db(), AsyncDatabase):
            return self.aqs().exclude(**lookups)
        return self.qs().exclude(**lookups)

    def get(self, **lookups: Any) -> TModel:
        """Return ``qs().get(**lookups)``."""
        return self.qs().get(**lookups)

    async def aget(self, **lookups: Any) -> TModel:
        """Return the awaited ``aqs().get(**lookups)``."""
        return await self.aqs().get(**lookups)

    def create(self, **data: Any) -> TModel:
        """Return ``qs().create(**data)``."""
        return self.qs().create(**data)

    async def acreate(self, **data: Any) -> TModel:
        """Return the awaited ``aqs().create(**data)``."""
        return await self.aqs().create(**data)
|
|
ns_orm/migrations/autodetector.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import Any, Optional
|
|
6
|
+
|
|
7
|
+
from ns_orm.dialects import Dialect
|
|
8
|
+
from ns_orm.migrations.migration import Migration
|
|
9
|
+
from ns_orm.migrations.operations import (
|
|
10
|
+
AddColumn,
|
|
11
|
+
CreateTable,
|
|
12
|
+
DropColumn,
|
|
13
|
+
DropTable,
|
|
14
|
+
Operation,
|
|
15
|
+
)
|
|
16
|
+
from ns_orm.migrations.state import project_state
|
|
17
|
+
from ns_orm.model import Model
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _slug(value: str) -> str:
|
|
21
|
+
value = value.strip().lower()
|
|
22
|
+
value = re.sub(r"[^a-z0-9_]+", "_", value)
|
|
23
|
+
value = re.sub(r"_+", "_", value).strip("_")
|
|
24
|
+
return value or "auto"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def autodetect_migration(
    *,
    models: list[type[Any]],
    dialect: Dialect,
    previous_state: Optional[dict[str, Any]],
    name_suffix: str,
) -> Optional[Migration]:
    """Build a migration from the difference between two project states.

    Args:
        models: model classes; their current state comes from ``project_state``.
        dialect: accepted for the caller's convenience; not used here.
        previous_state: state recorded by the last migration, or ``None``/empty
            when there is none.
        name_suffix: free text, slugified and appended to the timestamp.

    Returns:
        A ``Migration`` named ``<UTC timestamp>_<slug>`` carrying the detected
        operations and the current state, or ``None`` when nothing changed.
    """
    # Function-scope import: the module only imports ``datetime`` itself.
    from datetime import timezone

    prev = previous_state or {}
    curr = project_state(models)

    ops: list[Operation] = []
    for model_name, mstate in curr.items():
        if model_name in prev:
            ops.extend(_ops_alter_model(prev[model_name], mstate))
        else:
            ops.extend(_ops_create_model(mstate))

    # Models that existed before but are gone now have their tables dropped.
    ops.extend(
        DropTable(table_name=mstate["table_name"])
        for model_name, mstate in prev.items()
        if model_name not in curr
    )

    if not ops:
        return None

    # ``datetime.utcnow()`` is deprecated; an aware UTC "now" formats to the
    # same ``YYYYmmddHHMMSS`` text.
    ts = datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S")
    name = f"{ts}_{_slug(name_suffix)}"
    return Migration(name=name, operations=ops, state=curr, dependencies=[])
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _ops_create_model(mstate: dict[str, Any]) -> list[Operation]:
    """Return the operations that create one model's table.

    Foreign-key style fields (``ForeignKey`` / ``OneToOne``) are stored as an
    ``Int`` column plus a foreign-key tuple
    ``(column, target_table, target_field, on_delete)``. All other fields are
    passed through unchanged. One extra batch of operations is appended per
    entry of the model's ``m2m`` mapping.
    """
    relation_classes = {"ForeignKey", "OneToOne"}
    table_columns: list[tuple[str, Any]] = []
    foreign_keys: list[tuple[str, str, str, str]] = []

    for col, fdef in mstate["fields"].items():
        if fdef["__class__"] not in relation_classes:
            table_columns.append((col, fdef))
            continue

        target_table = _resolve_table_name(fdef["to"])
        foreign_keys.append(
            (
                col,
                target_table,
                fdef.get("to_field", "id"),
                fdef.get("on_delete", "RESTRICT"),
            )
        )
        # Only nullable/unique/index are carried over to the backing column.
        backing_column = {
            "__class__": "Int",
            "nullable": fdef.get("nullable"),
            "unique": fdef.get("unique", False),
            "index": fdef.get("index", False),
        }
        table_columns.append((col, backing_column))

    create_table = CreateTable(
        table_name=mstate["table_name"],
        columns=table_columns,
        pk_name=mstate["pk_name"],
        fks=foreign_keys,
        unique_together=[],
    )
    ops: list[Operation] = [create_table]

    for rel in mstate.get("m2m", {}).values():
        ops.extend(_ops_create_m2m(mstate, rel))
    return ops
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def _ops_alter_model(prev: dict[str, Any], curr: dict[str, Any]) -> list[Operation]:
    """Return add/drop column operations between two states of one model.

    Columns present only in ``curr`` are added first (in ``curr`` order),
    then columns present only in ``prev`` are dropped. Columns present in
    both are left alone, even if their definitions differ.
    """
    prev_fields: dict[str, Any] = prev.get("fields", {})
    curr_fields: dict[str, Any] = curr.get("fields", {})
    table = curr["table_name"]

    additions: list[Operation] = [
        AddColumn(table_name=table, column_name=col, column_def=fdef)
        for col, fdef in curr_fields.items()
        if col not in prev_fields
    ]
    removals: list[Operation] = [
        DropColumn(table_name=table, column_name=col)
        for col in prev_fields
        if col not in curr_fields
    ]
    return additions + removals
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def _ops_create_m2m(mstate: dict[str, Any], rel: dict[str, Any]) -> list[Operation]:
    """Return the operation that creates a many-to-many through table.

    The table gets two non-null ``Int`` columns, each with a ``CASCADE``
    foreign key: one to the owning model's primary key and one to the related
    table's ``id`` column. Table and column names come from ``rel`` when set
    (``through``, ``from_field``, ``to_field``) and are derived from the two
    table names otherwise.
    """
    owner_table = mstate["table_name"]
    owner_pk = mstate["pk_name"]
    related_table = _resolve_table_name(rel["to"])
    related_pk = "id"  # the related side's key column is fixed to "id"

    through = rel.get("through") or f"{owner_table}_{related_table}".lower()
    from_col = rel.get("from_field") or f"{owner_table}_{owner_pk}"
    to_col = rel.get("to_field") or f"{related_table.lower()}_{related_pk}"

    link_column = {"__class__": "Int", "nullable": False}
    # NOTE(review): pk_name is the owning-side column alone. If CreateTable
    # renders that as a single-column primary key, each owner row could only
    # be linked once -- confirm against CreateTable.
    through_table = CreateTable(
        table_name=through,
        columns=[(from_col, dict(link_column)), (to_col, dict(link_column))],
        pk_name=from_col,
        fks=[
            (from_col, owner_table, owner_pk, "CASCADE"),
            (to_col, related_table, related_pk, "CASCADE"),
        ],
        unique_together=[(from_col, to_col)],
    )
    return [through_table]
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _resolve_table_name(ref: Any) -> str:
|
|
151
|
+
if isinstance(ref, str):
|
|
152
|
+
try:
|
|
153
|
+
m = Model._resolve_model(ref)
|
|
154
|
+
return m.table_name()
|
|
155
|
+
except Exception:
|
|
156
|
+
return ref.lower()
|
|
157
|
+
if hasattr(ref, "table_name") and callable(ref.table_name):
|
|
158
|
+
return ref.table_name()
|
|
159
|
+
return str(ref).lower()
|
|
ns_orm/migrations/executor.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from ns_orm.database import AsyncDatabase
|
|
8
|
+
from ns_orm.dialects import Dialect
|
|
9
|
+
from ns_orm.migrations.loader import MigrationLoader
|
|
10
|
+
from ns_orm.migrations.migration import Migration
|
|
11
|
+
|
|
12
|
+
_MIGRATIONS_TABLE = "ns_orm_migrations"


@dataclass
class MigrationExecutor:
    """Applies loaded migrations and records them in a bookkeeping table.

    The bookkeeping table (``ns_orm_migrations``) has a ``name`` primary key
    and an ``applied_at`` timestamp string. Sync and async databases use the
    same SQL; only the execution calls differ.
    """

    db: Any
    dialect: Dialect
    loader: MigrationLoader
    is_async: bool = False

    def plan(self) -> list[Migration]:
        """Return the loaded migrations that are not recorded as applied.

        The async path is taken when ``is_async`` is set or when ``db`` is an
        ``AsyncDatabase`` (the same test ``migrate``/``amigrate`` use), so an
        async database is never driven through the sync calls by mistake.
        The async path uses ``asyncio.run``.
        """
        migrations = self.loader.load_all()
        if self.is_async or isinstance(self.db, AsyncDatabase):
            import asyncio

            applied = set(asyncio.run(self._applied_async()))
        else:
            applied = set(self._applied_sync())
        return [m for m in migrations if m.name not in applied]

    def migrate(self) -> None:
        """Apply every unapplied migration on a sync database.

        Raises:
            RuntimeError: if ``db`` is an ``AsyncDatabase``.
        """
        if isinstance(self.db, AsyncDatabase):
            raise RuntimeError("Use amigrate() for AsyncDatabase")
        self._ensure_table_sync()
        applied = set(self._applied_sync())
        for mig in self.loader.load_all():
            if mig.name in applied:
                continue
            self._apply_migration_sync(mig)
            self._record_sync(mig.name)

    async def amigrate(self) -> None:
        """Apply every unapplied migration on an async database.

        Raises:
            RuntimeError: if ``db`` is not an ``AsyncDatabase``.
        """
        if not isinstance(self.db, AsyncDatabase):
            raise RuntimeError("Use migrate() for sync Database")
        await self._ensure_table_async()
        applied = set(await self._applied_async())
        for mig in self.loader.load_all():
            if mig.name in applied:
                continue
            await self._apply_migration_async(mig)
            await self._record_async(mig.name)

    def _ensure_table_sql(self) -> str:
        """Return the ``CREATE TABLE IF NOT EXISTS`` statement for the bookkeeping table."""
        table = self.dialect.quote_ident(_MIGRATIONS_TABLE)
        col_name = self.dialect.quote_ident("name")
        col_applied_at = self.dialect.quote_ident("applied_at")
        return (
            f"CREATE TABLE IF NOT EXISTS {table} ("
            f"{col_name} VARCHAR(255) PRIMARY KEY, "
            f"{col_applied_at} VARCHAR(32) NOT NULL"
            ")"
        )

    def _applied_sql(self) -> str:
        """Return the query that lists recorded migration names as ``name``."""
        table = self.dialect.quote_ident(_MIGRATIONS_TABLE)
        col_name = self.dialect.quote_ident("name")
        return f"SELECT {col_name} AS name FROM {table}"

    def _record_statement(self, name: str) -> tuple[str, dict[str, str]]:
        """Return the insert statement and parameters that record ``name``."""
        # Function-scope import: the module only imports ``datetime`` itself.
        from datetime import timezone

        table = self.dialect.quote_ident(_MIGRATIONS_TABLE)
        col_name = self.dialect.quote_ident("name")
        col_applied_at = self.dialect.quote_ident("applied_at")
        sql = f"INSERT INTO {table} ({col_name}, {col_applied_at}) VALUES (:p1, :p2)"
        # ``datetime.utcnow()`` is deprecated; an aware UTC "now" formats to
        # the same ``YYYYmmddHHMMSS`` text.
        applied_at = datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S")
        return sql, {"p1": name, "p2": applied_at}

    def _ensure_table_sync(self) -> None:
        """Create the bookkeeping table if needed (sync)."""
        self.db.execute(self._ensure_table_sql())

    async def _ensure_table_async(self) -> None:
        """Create the bookkeeping table if needed (async)."""
        await self.db.execute(self._ensure_table_sql())

    def _applied_sync(self) -> list[str]:
        """Return recorded migration names; any query failure yields ``[]``."""
        try:
            rows = self.db.fetch_all(self._applied_sql())
        except Exception:
            # Best effort: presumably the table does not exist yet.
            # NOTE(review): this also hides other failures (e.g. a dropped
            # connection), which would make every migration look unapplied.
            return []
        return [r["name"] for r in rows]

    async def _applied_async(self) -> list[str]:
        """Return recorded migration names; any query failure yields ``[]``."""
        try:
            rows = await self.db.fetch_all(self._applied_sql())
        except Exception:
            # Best effort: presumably the table does not exist yet.
            # NOTE(review): this also hides other failures (e.g. a dropped
            # connection), which would make every migration look unapplied.
            return []
        return [r["name"] for r in rows]

    def _apply_migration_sync(self, mig: Migration) -> None:
        """Execute all SQL of ``mig`` inside one ``db.transaction()`` block."""
        with self.db.transaction():
            for op in mig.operations:
                for sql in op.sql(self.dialect):
                    self.db.execute(sql)

    async def _apply_migration_async(self, mig: Migration) -> None:
        """Execute all SQL of ``mig`` inside one async ``db.transaction()`` block."""
        async with self.db.transaction():
            for op in mig.operations:
                for sql in op.sql(self.dialect):
                    await self.db.execute(sql)

    def _record_sync(self, name: str) -> None:
        """Insert a bookkeeping row for ``name`` (sync)."""
        sql, params = self._record_statement(name)
        self.db.execute(sql, params)

    async def _record_async(self, name: str) -> None:
        """Insert a bookkeeping row for ``name`` (async)."""
        sql, params = self._record_statement(name)
        await self.db.execute(sql, params)
|