ns-orm 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ns_orm/__init__.py +96 -0
- ns_orm/cli.py +174 -0
- ns_orm/database.py +292 -0
- ns_orm/dialects.py +290 -0
- ns_orm/exceptions.py +26 -0
- ns_orm/expressions.py +108 -0
- ns_orm/fields.py +313 -0
- ns_orm/manager.py +72 -0
- ns_orm/migrations/__init__.py +3 -0
- ns_orm/migrations/autodetector.py +159 -0
- ns_orm/migrations/executor.py +150 -0
- ns_orm/migrations/loader.py +53 -0
- ns_orm/migrations/migration.py +14 -0
- ns_orm/migrations/operations.py +93 -0
- ns_orm/migrations/state.py +42 -0
- ns_orm/migrations/writer.py +79 -0
- ns_orm/model.py +151 -0
- ns_orm/query.py +659 -0
- ns_orm/schema.py +131 -0
- ns_orm/typing.py +39 -0
- ns_orm/utils.py +58 -0
- ns_orm-0.0.0.dist-info/METADATA +289 -0
- ns_orm-0.0.0.dist-info/RECORD +27 -0
- ns_orm-0.0.0.dist-info/WHEEL +5 -0
- ns_orm-0.0.0.dist-info/entry_points.txt +2 -0
- ns_orm-0.0.0.dist-info/licenses/LICENSE +201 -0
- ns_orm-0.0.0.dist-info/top_level.txt +1 -0
ns_orm/__init__.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
from ns_orm.database import (
|
|
2
|
+
AsyncDatabase,
|
|
3
|
+
Database,
|
|
4
|
+
get_connection,
|
|
5
|
+
register_connection,
|
|
6
|
+
set_connection_provider,
|
|
7
|
+
)
|
|
8
|
+
from ns_orm.expressions import Q
|
|
9
|
+
from ns_orm.fields import (
|
|
10
|
+
JSON,
|
|
11
|
+
UUID,
|
|
12
|
+
Array,
|
|
13
|
+
BigInt,
|
|
14
|
+
Binary,
|
|
15
|
+
Boolean,
|
|
16
|
+
Char,
|
|
17
|
+
Date,
|
|
18
|
+
Date32,
|
|
19
|
+
DateTime,
|
|
20
|
+
DateTime64,
|
|
21
|
+
Decimal,
|
|
22
|
+
Double,
|
|
23
|
+
Enum8,
|
|
24
|
+
Enum16,
|
|
25
|
+
FixedString,
|
|
26
|
+
Float,
|
|
27
|
+
ForeignKey,
|
|
28
|
+
Int,
|
|
29
|
+
IPv4,
|
|
30
|
+
IPv6,
|
|
31
|
+
LowCardinality,
|
|
32
|
+
ManyToMany,
|
|
33
|
+
Map,
|
|
34
|
+
Nullable,
|
|
35
|
+
OneToOne,
|
|
36
|
+
Real,
|
|
37
|
+
SmallInt,
|
|
38
|
+
String,
|
|
39
|
+
Text,
|
|
40
|
+
TinyInt,
|
|
41
|
+
TupleType,
|
|
42
|
+
UBigInt,
|
|
43
|
+
UInt,
|
|
44
|
+
USmallInt,
|
|
45
|
+
UTinyInt,
|
|
46
|
+
)
|
|
47
|
+
from ns_orm.model import Model
|
|
48
|
+
from ns_orm.schema import acreate_all, create_all
|
|
49
|
+
|
|
50
|
+
# Names re-exported by ``from ns_orm import *``: field types, the Model base
# class, Q, the database wrappers, the connection registry helpers and the
# schema creation helpers imported above. Order is kept as originally listed.
__all__ = [
    # field types (first group)
    "Array", "BigInt", "Binary", "Boolean", "Char",
    # database wrappers
    "AsyncDatabase", "Database",
    # field types (continued)
    "Date", "Date32", "DateTime", "DateTime64", "Decimal", "Double",
    "Enum16", "Enum8", "FixedString", "Float", "ForeignKey",
    "IPv4", "IPv6", "Int", "JSON", "LowCardinality", "Map", "ManyToMany",
    # model base class
    "Model",
    "Nullable", "OneToOne",
    # query expressions
    "Q",
    "Real", "SmallInt", "String", "Text", "TinyInt", "TupleType",
    "UBigInt", "USmallInt", "UTinyInt", "UInt", "UUID",
    # schema helpers
    "acreate_all", "create_all",
    # connection registry
    "get_connection", "register_connection", "set_connection_provider",
]
|
ns_orm/cli.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
import importlib
|
|
5
|
+
import sys
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, List, Optional
|
|
8
|
+
|
|
9
|
+
from ns_orm.dialects import Dialect, dialect_from_url
|
|
10
|
+
from ns_orm.migrations.autodetector import autodetect_migration
|
|
11
|
+
from ns_orm.migrations.executor import MigrationExecutor
|
|
12
|
+
from ns_orm.migrations.loader import MigrationLoader
|
|
13
|
+
from ns_orm.migrations.writer import MigrationWriter
|
|
14
|
+
from ns_orm.model import Model
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _import_string(path: str) -> Any:
    """Import and return the object referenced by *path*.

    The documented format is ``"pkg.module:attr"`` (see ``--db-factory``).
    As a fallback, the dotted form ``"pkg.module.attr"`` is also accepted.
    The attribute part may itself be dotted, e.g. ``"pkg.module:Class.method"``.

    Args:
        path: Import path of the object.

    Returns:
        The resolved object.

    Raises:
        ValueError: if *path* does not name both a module and an attribute.
        ImportError: if the module cannot be imported.
        AttributeError: if the attribute chain cannot be resolved.
    """
    path = path.strip()
    if ":" in path:
        module_name, _, attr_path = path.rpartition(":")
    else:
        # No colon: treat the last dotted component as the attribute name.
        module_name, _, attr_path = path.rpartition(".")
    if not module_name or not attr_path:
        raise ValueError(
            f"Invalid import path {path!r}; expected 'pkg.module:attr'"
        )

    obj: Any = importlib.import_module(module_name)
    for part in attr_path.split("."):
        obj = getattr(obj, part)
    return obj
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _load_models(modules: list[str]) -> list[type[Model]]:
    """Import *modules* and return the registered Model subclasses they define.

    Each module is imported first; presumably this registers its Model
    subclasses in ``Model._registry`` (verify against ns_orm.model). Then every
    registry value that is a Model subclass is kept when its ``__module__``
    equals one of *modules* or is a sub-module of one. Duplicates are removed,
    keeping first-seen order.
    """
    for module_name in modules:
        importlib.import_module(module_name)

    def _belongs(owner: str) -> bool:
        # Exact module match, or a sub-module of a requested package.
        return any(owner == prefix or owner.startswith(prefix + ".") for prefix in modules)

    selected = [
        candidate
        for candidate in Model._registry.values()
        if isinstance(candidate, type)
        and issubclass(candidate, Model)
        and _belongs(getattr(candidate, "__module__", ""))
    ]
    return list(dict.fromkeys(selected))
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _dialect_from_arg(value: str) -> Dialect:
    """Resolve a ``--dialect`` CLI value to a Dialect.

    *value* may be a bare scheme name (``postgres``) or a URL-like string
    (``postgresql://...``). Surrounding whitespace is ignored. Only the text
    before ``://`` is passed to ``dialect_from_url``.
    """
    scheme, _sep, _rest = value.strip().partition("://")
    return dialect_from_url(scheme)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _parse_args(argv: list[str]) -> argparse.Namespace:
    """Parse ns-orm CLI arguments.

    Two sub-commands are defined, and one of them is required:

    - ``makemigrations``: ``--models``, ``--dialect``, ``--name``, ``--migrations-dir``
    - ``migrate``: ``--models``, ``--dialect``, ``--db-factory``,
      ``--migrations-dir``, ``--plan``, ``--async`` (stored as ``is_async``)

    The chosen sub-command name is stored in ``namespace.command``.
    """
    models_help = "模型模块路径,逗号分隔,例如: app.models,core.models"
    dialect_help = "数据库方言或 URL scheme,例如: postgres 或 postgresql://"

    parser = argparse.ArgumentParser(prog="ns-orm")
    commands = parser.add_subparsers(dest="command", required=True)

    make = commands.add_parser("makemigrations")
    make.add_argument("--models", required=True, help=models_help)
    make.add_argument("--dialect", required=True, help=dialect_help)
    make.add_argument("--name", default="auto", help="迁移名称后缀,例如: add_user_email")
    make.add_argument("--migrations-dir", default="migrations", help="迁移文件输出目录")

    apply = commands.add_parser("migrate")
    apply.add_argument("--models", required=True, help=models_help)
    apply.add_argument("--dialect", required=True, help=dialect_help)
    apply.add_argument(
        "--db-factory",
        required=True,
        help="返回 ns_orm.Database/AsyncDatabase 的工厂函数路径,格式: pkg.module:func",
    )
    apply.add_argument("--migrations-dir", default="migrations", help="迁移文件目录")
    apply.add_argument("--plan", action="store_true", help="仅输出计划,不执行")
    apply.add_argument("--async", dest="is_async", action="store_true", help="使用异步执行器")

    return parser.parse_args(argv)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def main(argv: Optional[List[str]] = None) -> None:
    """Entry point of the ``ns-orm`` command line tool.

    Parses *argv* (``sys.argv[1:]`` when None) and resolves the options shared
    by both sub-commands. It then runs ``makemigrations`` or ``migrate``.

    Raises:
        SystemExit: if ``--models`` holds no non-blank module name, or the
            sub-command is not recognised.
    """
    args = _parse_args(argv if argv is not None else sys.argv[1:])

    module_names: list[str] = []
    for raw in args.models.split(","):
        name = raw.strip()
        if name:
            module_names.append(name)
    if not module_names:
        raise SystemExit("--models 不能为空")

    dialect = _dialect_from_arg(args.dialect)
    migrations_dir = Path(args.migrations_dir).resolve()

    if args.command == "makemigrations":
        _cmd_makemigrations(
            models=module_names,
            dialect=dialect,
            migrations_dir=migrations_dir,
            name=args.name,
        )
    elif args.command == "migrate":
        _cmd_migrate(
            models=module_names,
            dialect=dialect,
            migrations_dir=migrations_dir,
            db_factory=args.db_factory,
            plan=bool(args.plan),
            is_async=bool(args.is_async),
        )
    else:
        raise SystemExit(f"未知命令: {args.command}")
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _cmd_makemigrations(
    *,
    models: list[str],
    dialect: Dialect,
    migrations_dir: Path,
    name: str,
) -> None:
    """Generate a migration for the changes in *models*.

    Loads the model modules, then reads the latest state recorded in
    *migrations_dir*. It asks the autodetector for a migration whose name ends
    with *name*. If one is produced it is written to disk and its path is
    printed; otherwise "No changes detected" is printed.
    """
    discovered = _load_models(models)
    previous_state = MigrationLoader(migrations_dir=migrations_dir).latest_state()

    migration = autodetect_migration(
        models=discovered,
        dialect=dialect,
        previous_state=previous_state,
        name_suffix=name,
    )
    if migration is None:
        print("No changes detected")
        return

    written_to = MigrationWriter(migrations_dir=migrations_dir).write(migration)
    print(str(written_to))
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _cmd_migrate(
    *,
    models: list[str],
    dialect: Dialect,
    migrations_dir: Path,
    db_factory: str,
    plan: bool,
    is_async: bool,
) -> None:
    """Run (or just list) migrations from *migrations_dir*.

    The model modules are imported first. The database comes from calling the
    factory named by *db_factory* with *dialect*.

    - With *plan*, the name of every migration from ``executor.plan()`` is
      printed and nothing is executed.
    - With *is_async*, ``executor.amigrate()`` is run under ``asyncio.run``.
    - Otherwise ``executor.migrate()`` is called.
    """
    _load_models(models)  # imported for their side effects only
    loader = MigrationLoader(migrations_dir=migrations_dir)
    make_db = _import_string(db_factory)
    database = make_db(dialect)
    executor = MigrationExecutor(
        db=database, dialect=dialect, loader=loader, is_async=is_async
    )

    if plan:
        for migration in executor.plan():
            print(f"{migration.name}")
    elif is_async:
        import asyncio

        asyncio.run(executor.amigrate())
    else:
        executor.migrate()
|
ns_orm/database.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
1
|
+
from __future__ import annotations

import contextlib
import inspect
from typing import (
    Any,
    AsyncGenerator,
    Callable,
    Dict,
    Generator,
    Optional,
    Union,
)

from ns_orm.dialects import Dialect, PreparedSQL, dialect_from_url
from ns_orm.exceptions import ConfigurationError, IntegrityError
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _is_awaitable(v: Any) -> bool:
    """Return True if *v* can be used in an ``await`` expression.

    Delegates to ``inspect.isawaitable``. That check recognises:

    - native coroutines;
    - generator-based coroutines decorated with ``types.coroutine``;
    - instances whose *type* defines ``__await__``.

    The previous ``hasattr(v, "__await__")`` test missed generator-based
    coroutines. It also misreported classes that define ``__await__``
    (rather than their instances) as awaitable.
    """
    return inspect.isawaitable(v)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _require_callable(obj: Any, name: str) -> Callable[..., Any]:
    """Return the callable attribute *name* of *obj*.

    Raises:
        ConfigurationError: if the attribute is missing or not callable.
    """
    candidate = getattr(obj, name, None)
    if callable(candidate):
        return candidate
    raise ConfigurationError(f"Database executor missing method: {name}")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class LesscodeDatabaseExecutorAdapter:
    """Adapt a lesscode-database ``connect`` object to the executor protocol.

    Database and AsyncDatabase call ``execute``, ``fetch_all``, ``fetch_one``
    and optionally ``transaction`` on their executor. This adapter maps each
    of those onto the first matching callable attribute of the wrapped
    connect object, trying several common spellings. Return values are passed
    through unchanged, so awaitables from async connects reach the caller
    as-is.
    """

    def __init__(self, connect: Any):
        # The wrapped lesscode-database connect object.
        self._connect = connect

    def _call(self, names: list[str], *args: Any, **kwargs: Any) -> Any:
        """Call the first callable attribute of the connect object listed in *names*.

        Raises:
            ConfigurationError: if none of *names* is a callable attribute.
        """
        candidates = (getattr(self._connect, n, None) for n in names)
        target = next((fn for fn in candidates if callable(fn)), None)
        if target is None:
            raise ConfigurationError(
                "lesscode-database connect object missing method: {}".format(
                    ", ".join(names)
                )
            )
        return target(*args, **kwargs)

    def execute(self, sql: str, params: Any = None) -> Any:
        """Run *sql* via ``execute``/``exec``/``run``; *params* is omitted when None."""
        call_args = (sql,) if params is None else (sql, params)
        return self._call(["execute", "exec", "run"], *call_args)

    def fetch_all(self, sql: str, params: Any = None) -> Any:
        """Fetch all rows via the first available query-style method."""
        call_args = (sql,) if params is None else (sql, params)
        return self._call(
            ["fetch_all", "fetchall", "query_all", "query", "select"], *call_args
        )

    def fetch_one(self, sql: str, params: Any = None) -> Any:
        """Fetch a single row.

        Uses ``fetch_one``/``fetchone`` when the connect object has them.
        Otherwise falls back to :meth:`fetch_all` and returns the first row,
        or None when there are no rows. If that fallback result is awaitable,
        a coroutine producing the first row is returned instead.
        """
        call_args = (sql,) if params is None else (sql, params)
        for attr in ("fetch_one", "fetchone"):
            direct = getattr(self._connect, attr, None)
            if callable(direct):
                return direct(*call_args)

        rows = self.fetch_all(sql, params)
        if not _is_awaitable(rows):
            return rows[0] if rows else None

        async def _first_row():
            data = await rows
            return data[0] if data else None

        return _first_row()

    def transaction(self) -> Any:
        """Return the connect object's ``transaction()``, or a no-op context."""
        begin = getattr(self._connect, "transaction", None)
        return begin() if callable(begin) else contextlib.nullcontext()
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class Database:
    """Synchronous wrapper around a caller-supplied database executor.

    ns-orm does not manage connections or connection pools. An external
    library (for example lesscode-database) must create the executor and pass
    it in.

    The executor must provide at least these methods (duck-typing):

    - ``execute(sql, params=None) -> int | Any``
    - ``fetch_all(sql, params=None) -> list[dict]``
    - ``fetch_one(sql, params=None) -> dict | None``

    Optional:

    - ``transaction() -> context manager``
    - ``executemany(sql, params_list)`` (used by :meth:`executemany` when present)

    SQL and parameters are passed through ``dialect.prepare`` before reaching
    the executor. Errors raised by the executor while writing are re-raised
    as IntegrityError. ConfigurationError (e.g. a missing executor method)
    propagates unchanged.
    """

    def __init__(self, *, executor: Any, dialect: Dialect):
        self.executor = executor
        self.dialect = dialect
        # Number of currently open transaction() blocks on this instance.
        self._txn_depth: int = 0

    @contextlib.contextmanager
    def transaction(self) -> Generator[None, None, None]:
        """Run the body inside the executor's ``transaction()`` when available.

        If the executor has no callable ``transaction``, the body runs without
        one. ``_txn_depth`` is incremented for the duration of the block.
        """
        self._txn_depth += 1
        try:
            txn = getattr(self.executor, "transaction", None)
            if callable(txn):
                with txn():
                    yield
                return
            yield
        finally:
            self._txn_depth = max(0, self._txn_depth - 1)

    def _prepare(self, sql: str, params: Optional[Dict[str, Any]]) -> PreparedSQL:
        """Let the dialect rewrite *sql*/*params*; None params become ``{}``."""
        params = params or {}
        return self.dialect.prepare(sql, params)

    def execute(self, sql: str, params: Optional[Dict[str, Any]] = None) -> int:
        """Execute a statement.

        Returns:
            The executor's result when it is an int, otherwise 0.

        Raises:
            ConfigurationError: if the executor has no callable ``execute``.
            IntegrityError: wrapping any other error raised by the executor.
        """
        prepared = self._prepare(sql, params)
        fn = _require_callable(self.executor, "execute")
        try:
            result = fn(prepared.sql, prepared.params)
        except (ConfigurationError, IntegrityError):
            # Already ns-orm errors: don't relabel or double-wrap them.
            raise
        except Exception as e:
            raise IntegrityError(str(e)) from e
        return result if isinstance(result, int) else 0

    def executemany(self, sql: str, params_list: list[dict[str, Any]]) -> int:
        """Execute *sql* once per parameter set in *params_list*.

        Uses the executor's ``executemany`` when it has one. Otherwise falls
        back to one :meth:`execute` call per parameter set and sums their
        results.

        Returns:
            0 for an empty *params_list*. Otherwise the executor's
            ``executemany`` result when it is an int (else 0), or the summed
            per-row results on the fallback path.

        Raises:
            ConfigurationError: if the dialect produces different SQL for
                different parameter sets, or the executor lacks ``execute``
                on the fallback path.
            IntegrityError: wrapping any other error raised by the executor.
        """
        if not params_list:
            return 0

        fn = getattr(self.executor, "executemany", None)
        if not callable(fn):
            # execute() already maps executor errors, so no extra wrapping here.
            return sum(self.execute(sql, p) for p in params_list)

        prepared_list = [self._prepare(sql, p) for p in params_list]
        if len({p.sql for p in prepared_list}) != 1:
            # A single batched call needs one statement shared by every row.
            raise ConfigurationError(
                "executemany requires stable SQL after parameter preparation"
            )
        try:
            result = fn(prepared_list[0].sql, [p.params for p in prepared_list])
        except (ConfigurationError, IntegrityError):
            raise
        except Exception as e:
            raise IntegrityError(str(e)) from e
        return int(result) if isinstance(result, int) else 0

    def fetch_all(
        self, sql: str, params: Optional[Dict[str, Any]] = None
    ) -> list[dict[str, Any]]:
        """Return every row as a dict; a None result from the executor becomes ``[]``."""
        prepared = self._prepare(sql, params)
        fn = _require_callable(self.executor, "fetch_all")
        rows = fn(prepared.sql, prepared.params)
        if rows is None:
            return []
        return [dict(r) for r in rows]

    def fetch_one(
        self, sql: str, params: Optional[Dict[str, Any]] = None
    ) -> Optional[dict[str, Any]]:
        """Return a single row as a dict, or None when the executor returns None."""
        prepared = self._prepare(sql, params)
        fn = _require_callable(self.executor, "fetch_one")
        row = fn(prepared.sql, prepared.params)
        return None if row is None else dict(row)
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
class AsyncDatabase:
    """Asynchronous wrapper around a caller-supplied database executor.

    ns-orm does not manage connections or connection pools. An external
    library (for example lesscode-database) must create the executor and pass
    it in.

    The executor must provide at least these coroutine methods (duck-typing):

    - ``execute(sql, params=None) -> int | Any``
    - ``fetch_all(sql, params=None) -> list[dict]``
    - ``fetch_one(sql, params=None) -> dict | None``

    Optional:

    - ``transaction() -> async context manager``

    A method whose result is not awaitable is used directly, so synchronous
    executor methods are tolerated. Errors raised by the executor while
    writing are re-raised as IntegrityError. ConfigurationError (e.g. a
    missing executor method) propagates unchanged.
    """

    def __init__(self, *, executor: Any, dialect: Dialect):
        self.executor = executor
        self.dialect = dialect
        # Number of currently open transaction() blocks on this instance.
        self._txn_depth: int = 0

    @contextlib.asynccontextmanager
    async def transaction(self) -> AsyncGenerator[None, None]:
        """Run the body inside ``async with executor.transaction()`` when available.

        If the executor has no callable ``transaction``, the body runs without
        one. ``_txn_depth`` is incremented for the duration of the block.
        """
        self._txn_depth += 1
        try:
            txn = getattr(self.executor, "transaction", None)
            if callable(txn):
                async with txn():
                    yield
                return
            yield
        finally:
            self._txn_depth = max(0, self._txn_depth - 1)

    def _prepare(self, sql: str, params: Optional[Dict[str, Any]]) -> PreparedSQL:
        """Let the dialect rewrite *sql*/*params*; None params become ``{}``."""
        params = params or {}
        return self.dialect.prepare(sql, params)

    async def execute(self, sql: str, params: Optional[Dict[str, Any]] = None) -> int:
        """Execute a statement, awaiting the executor's result if needed.

        Returns:
            The executor's result when it is an int, otherwise 0.

        Raises:
            ConfigurationError: if the executor has no callable ``execute``.
            IntegrityError: wrapping any other error raised by the executor.
        """
        prepared = self._prepare(sql, params)
        fn = _require_callable(self.executor, "execute")
        try:
            result = fn(prepared.sql, prepared.params)
            if _is_awaitable(result):
                result = await result
        except (ConfigurationError, IntegrityError):
            # Already ns-orm errors: don't relabel or double-wrap them.
            raise
        except Exception as e:
            raise IntegrityError(str(e)) from e
        return int(result) if isinstance(result, int) else 0

    async def fetch_all(
        self, sql: str, params: Optional[Dict[str, Any]] = None
    ) -> list[dict[str, Any]]:
        """Return every row as a dict; a None result from the executor becomes ``[]``."""
        prepared = self._prepare(sql, params)
        fn = _require_callable(self.executor, "fetch_all")
        rows = fn(prepared.sql, prepared.params)
        if _is_awaitable(rows):
            rows = await rows
        if rows is None:
            return []
        return [dict(r) for r in rows]

    async def fetch_one(
        self, sql: str, params: Optional[Dict[str, Any]] = None
    ) -> Optional[dict[str, Any]]:
        """Return a single row as a dict, or None when the executor returns None."""
        prepared = self._prepare(sql, params)
        fn = _require_callable(self.executor, "fetch_one")
        row = fn(prepared.sql, prepared.params)
        if _is_awaitable(row):
            row = await row
        return None if row is None else dict(row)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
# Database wrappers keyed by connection name. Filled by register_connection()
# and by get_connection() when it caches a provider-built wrapper.
_connections: dict[str, Any] = {}
# Optional factory that builds a wrapper for a connection name. Installed via
# set_connection_provider(), or lazily defaulted by _ensure_default_provider().
_connection_provider: Optional[Callable[[str], Any]] = None
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def register_connection(name: str, db: Union[Database, AsyncDatabase]) -> None:
    """Register *db* under *name* so that ``get_connection(name)`` returns it.

    Any wrapper previously stored under the same name is replaced.
    """
    _connections.update({name: db})
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def set_connection_provider(provider: Callable[[str], Any]) -> None:
    """Install *provider* as the factory for unregistered connection names.

    get_connection() calls ``provider(name)`` for names that are not yet in
    the registry and caches what it builds. Replaces any previously installed
    provider.
    """
    global _connection_provider
    _connection_provider = provider
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def _default_lesscode_database_provider(connect_name: str) -> Any:
    """Build a Database/AsyncDatabase for *connect_name* from lesscode-database.

    Steps:

    1. Read ``(connect, connect_info)`` from the attribute *connect_name* of
       ``lesscode_database.db_options.db_options``.
    2. Derive the dialect from ``connect_info.dialect`` (``""`` when absent).
    3. Wrap ``connect`` in LesscodeDatabaseExecutorAdapter.

    Returns AsyncDatabase when ``connect_info.async_enable`` is truthy,
    otherwise Database.

    Raises:
        ConfigurationError: if lesscode-database cannot be imported, if
            *connect_name* is empty, or if it does not name a
            ``(connect, connect_info)`` pair on ``db_options``.
    """
    try:
        from lesscode_database.db_options import db_options  # type: ignore
    except Exception as e:
        raise ConfigurationError("lesscode-database is required") from e

    if not connect_name:
        raise ConfigurationError("connect_name is empty")

    try:
        entry = getattr(db_options, connect_name)
        connect, connect_info = entry
    except Exception as e:
        raise ConfigurationError(f"Unknown connect_name: {connect_name}") from e

    dialect = dialect_from_url(getattr(connect_info, "dialect", ""))
    adapter = LesscodeDatabaseExecutorAdapter(connect)

    wrapper_cls = (
        AsyncDatabase if getattr(connect_info, "async_enable", False) else Database
    )
    return wrapper_cls(executor=adapter, dialect=dialect)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def _ensure_default_provider() -> None:
    """Fall back to the lesscode-database provider when none has been installed."""
    global _connection_provider
    if _connection_provider is not None:
        return
    _connection_provider = _default_lesscode_database_provider
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def get_connection(name: str) -> Any:
    """Return the database wrapper for connection *name*.

    Lookup order:

    1. A wrapper already stored in the registry, either via
       register_connection() or cached by an earlier call.
    2. Otherwise the connection provider builds one. The lesscode-database
       provider is installed when none has been set. The result is cached.

    Raises:
        ConfigurationError: if the provider returns None for *name*. Errors
            raised by the provider itself propagate unchanged.
    """
    if name in _connections:
        return _connections[name]

    _ensure_default_provider()
    provider = _connection_provider
    db = provider(name) if provider is not None else None
    if db is None:
        # Don't cache a miss: a later register_connection() or a different
        # provider may still supply this name.
        raise ConfigurationError(f"Connection not found: {name}")
    _connections[name] = db
    return db
|