dclassql 0.3.1__tar.gz → 0.4.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dclassql-0.3.1 → dclassql-0.4.1}/PKG-INFO +5 -2
- {dclassql-0.3.1 → dclassql-0.4.1}/README.md +4 -1
- {dclassql-0.3.1 → dclassql-0.4.1}/pyproject.toml +1 -1
- dclassql-0.4.1/src/dclassql/__init__.py +21 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/cli.py +75 -36
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/codegen.py +136 -35
- dclassql-0.4.1/src/dclassql/db_pool.py +89 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/model_inspector.py +98 -25
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/push/__init__.py +14 -21
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/push/base.py +20 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/push/sqlite.py +36 -1
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/runtime/backends/base.py +3 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/runtime/backends/lazy.py +21 -5
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/runtime/datasource.py +2 -5
- dclassql-0.4.1/src/dclassql/runtime/json_value.py +104 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/templates/partials/client_class.jinja +15 -22
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/templates/partials/imports.jinja +3 -4
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/templates/partials/model_section.jinja +18 -4
- dclassql-0.3.1/src/dclassql/__init__.py +0 -34
- dclassql-0.3.1/src/dclassql/asdict.pyi +0 -57
- dclassql-0.3.1/src/dclassql/client.py +0 -1397
- dclassql-0.3.1/src/dclassql/db_pool.py +0 -76
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/.gitignore +0 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/asdict.py +0 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/runtime/backends/__init__.py +0 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/runtime/backends/metadata.py +0 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/runtime/backends/protocols.py +0 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/runtime/backends/sqlite.py +0 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/runtime/backends/where_compiler.py +0 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/runtime/sql_recorder.py +0 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/runtime/sqlite_adapters.py +0 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/table_spec.py +0 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/templates/__init__.py +0 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/templates/asdict_stub.pyi.jinja +0 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/templates/client_module.py.jinja +0 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/templates/partials/exports.jinja +0 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/templates/partials/macros.jinja +0 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/templates/partials/scalar_filters.jinja +0 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/typing.py +0 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/unwarp.py +0 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/utils/__init__.py +0 -0
- {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/utils/ensure.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: dclassql
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.4.1
|
|
4
4
|
Summary: A type-safe ORM generator for Python, creating fully type-hinted database clients from plain dataclass definitions.
|
|
5
5
|
Keywords: orm,codegen,sqlite,dataclass,typed
|
|
6
6
|
Author: myuanz
|
|
@@ -62,11 +62,14 @@ class User:
|
|
|
62
62
|
写出如下代码时:
|
|
63
63
|
|
|
64
64
|
```python
|
|
65
|
-
from dclassql import
|
|
65
|
+
from dclassql import Client
|
|
66
|
+
|
|
67
|
+
client = Client()
|
|
66
68
|
|
|
67
69
|
client.user.insert({
|
|
68
70
|
"name": "Alice",
|
|
69
71
|
"email": "test@example.com",
|
|
72
|
+
# 这里缺少 last_login
|
|
70
73
|
})
|
|
71
74
|
```
|
|
72
75
|
|
|
@@ -41,11 +41,14 @@ class User:
|
|
|
41
41
|
写出如下代码时:
|
|
42
42
|
|
|
43
43
|
```python
|
|
44
|
-
from dclassql import
|
|
44
|
+
from dclassql import Client
|
|
45
|
+
|
|
46
|
+
client = Client()
|
|
45
47
|
|
|
46
48
|
client.user.insert({
|
|
47
49
|
"name": "Alice",
|
|
48
50
|
"email": "test@example.com",
|
|
51
|
+
# 这里缺少 last_login
|
|
49
52
|
})
|
|
50
53
|
```
|
|
51
54
|
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from .asdict import asdict
|
|
2
|
+
from .db_pool import BaseDBPool, save_local
|
|
3
|
+
from .model_inspector import DataSourceConfig
|
|
4
|
+
from .push import db_push
|
|
5
|
+
from .runtime.backends.lazy import eager
|
|
6
|
+
from .runtime.sql_recorder import record_sql
|
|
7
|
+
from .unwarp import unwarp, unwarp_or, unwarp_or_raise
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
'db_push',
|
|
12
|
+
'eager',
|
|
13
|
+
'asdict',
|
|
14
|
+
'unwarp',
|
|
15
|
+
'unwarp_or',
|
|
16
|
+
'unwarp_or_raise',
|
|
17
|
+
'BaseDBPool',
|
|
18
|
+
'save_local',
|
|
19
|
+
'DataSourceConfig',
|
|
20
|
+
'record_sql',
|
|
21
|
+
]
|
|
@@ -18,6 +18,7 @@ from .runtime.datasource import open_sqlite_connection
|
|
|
18
18
|
DEFAULT_MODEL_FILE = "model.py"
|
|
19
19
|
GENERATED_CLIENT_FILENAME = "client.py"
|
|
20
20
|
|
|
21
|
+
GenerateTarget = Literal["model-dir", "package"]
|
|
21
22
|
ConfirmRebuildMode = Literal["auto", "prompt"]
|
|
22
23
|
ConfirmCallback = Callable[
|
|
23
24
|
[ModelInfo, SchemaPlan, tuple[ExistingColumn, ...] | None, SchemaDiff],
|
|
@@ -51,11 +52,17 @@ def load_module(module_path: Path) -> ModuleType:
|
|
|
51
52
|
raise ImportError(f"Unable to load module from '{module_path}'")
|
|
52
53
|
module = importlib.util.module_from_spec(spec)
|
|
53
54
|
sys.modules[module_name] = module
|
|
54
|
-
sys.path
|
|
55
|
+
original_sys_path = list(sys.path)
|
|
56
|
+
search_paths = [str(module_path.parent)]
|
|
57
|
+
cwd = str(Path.cwd())
|
|
58
|
+
if cwd not in search_paths:
|
|
59
|
+
search_paths.append(cwd)
|
|
60
|
+
for path in reversed(search_paths):
|
|
61
|
+
sys.path.insert(0, path)
|
|
55
62
|
try:
|
|
56
63
|
spec.loader.exec_module(module)
|
|
57
64
|
finally:
|
|
58
|
-
sys.path
|
|
65
|
+
sys.path[:] = original_sys_path
|
|
59
66
|
return module
|
|
60
67
|
|
|
61
68
|
|
|
@@ -66,29 +73,61 @@ def _find_package_directory() -> Path:
|
|
|
66
73
|
return Path(next(iter(spec.submodule_search_locations))).resolve()
|
|
67
74
|
|
|
68
75
|
|
|
69
|
-
def
|
|
70
|
-
|
|
76
|
+
def resolve_client_package_name(module_path: Path) -> str:
|
|
77
|
+
stem = re.sub(r"[^0-9a-zA-Z_]+", "_", module_path.stem) or "_"
|
|
78
|
+
return f"{stem}_client"
|
|
71
79
|
|
|
72
80
|
|
|
73
|
-
def
|
|
74
|
-
|
|
81
|
+
def resolve_client_class_name(module_path: Path) -> str:
|
|
82
|
+
package_name = resolve_client_package_name(module_path)
|
|
83
|
+
return "".join(part.capitalize() for part in package_name.split("_") if part)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def resolve_generated_package_dir(module_path: Path, target: GenerateTarget = "model-dir") -> Path:
|
|
87
|
+
package_name = resolve_client_package_name(module_path)
|
|
88
|
+
if target == "model-dir":
|
|
89
|
+
return module_path.resolve().parent / package_name
|
|
90
|
+
return _find_package_directory() / package_name
|
|
75
91
|
|
|
76
92
|
|
|
77
93
|
def collect_models(module: ModuleType) -> list[type[Any]]:
|
|
78
94
|
from dataclasses import is_dataclass
|
|
79
95
|
|
|
96
|
+
excluded_names = _collect_excluded_model_names(module)
|
|
80
97
|
models: list[type[Any]] = []
|
|
81
98
|
for value in vars(module).values():
|
|
82
|
-
if
|
|
99
|
+
if (
|
|
100
|
+
isinstance(value, type)
|
|
101
|
+
and is_dataclass(value)
|
|
102
|
+
and value.__module__ == module.__name__
|
|
103
|
+
and value.__name__ not in excluded_names
|
|
104
|
+
):
|
|
83
105
|
models.append(value)
|
|
84
106
|
if not models:
|
|
85
107
|
raise ValueError("No dataclass models were found in the provided module")
|
|
86
108
|
return models
|
|
87
109
|
|
|
88
110
|
|
|
111
|
+
def _collect_excluded_model_names(module: ModuleType) -> set[str]:
|
|
112
|
+
raw = getattr(module, "__exclude__", ())
|
|
113
|
+
if raw is None:
|
|
114
|
+
return set()
|
|
115
|
+
if isinstance(raw, str):
|
|
116
|
+
return {raw}
|
|
117
|
+
names: set[str] = set()
|
|
118
|
+
for item in raw:
|
|
119
|
+
if isinstance(item, str):
|
|
120
|
+
names.add(item)
|
|
121
|
+
continue
|
|
122
|
+
if isinstance(item, type):
|
|
123
|
+
names.add(item.__name__)
|
|
124
|
+
continue
|
|
125
|
+
raise TypeError("__exclude__ entries must be dataclass classes or class names")
|
|
126
|
+
return names
|
|
127
|
+
|
|
128
|
+
|
|
89
129
|
def _describe_schema_diff(info: ModelInfo, diff: SchemaDiff) -> str:
|
|
90
|
-
|
|
91
|
-
prefix = f"[{datasource_key}] " if datasource_key else ""
|
|
130
|
+
prefix = f"[{info.datasource.identity}] "
|
|
92
131
|
parts: list[str] = [f"{prefix}模型 {info.model.__name__} 需要重建表"]
|
|
93
132
|
if diff.added:
|
|
94
133
|
added = ", ".join(f"+{column.name}:{column.type_sql}" for column in diff.added)
|
|
@@ -134,44 +173,38 @@ def push_database(
|
|
|
134
173
|
confirm_mode: ConfirmRebuildMode | None = None,
|
|
135
174
|
) -> None:
|
|
136
175
|
model_infos = inspect_models(models)
|
|
137
|
-
|
|
138
|
-
|
|
176
|
+
datasource_configs = {info.datasource for info in model_infos.values()}
|
|
177
|
+
if len(datasource_configs) != 1:
|
|
178
|
+
labels = ", ".join(sorted(config.identity for config in datasource_configs))
|
|
179
|
+
raise ValueError(f"push-db only supports one datasource, got: {labels}")
|
|
180
|
+
datasource = next(iter(datasource_configs))
|
|
181
|
+
if datasource.provider != "sqlite":
|
|
182
|
+
raise ValueError(f"Unsupported provider '{datasource.provider}'")
|
|
139
183
|
confirm_callback = _build_confirm_callback(confirm_mode) if confirm_mode else None
|
|
184
|
+
connection = open_sqlite_connection(datasource.url)
|
|
140
185
|
try:
|
|
141
|
-
for info in model_infos.values():
|
|
142
|
-
config = info.datasource
|
|
143
|
-
key = config.key
|
|
144
|
-
if key in connections:
|
|
145
|
-
continue
|
|
146
|
-
if config.provider != "sqlite":
|
|
147
|
-
raise ValueError(f"Unsupported provider '{config.provider}'")
|
|
148
|
-
connection = open_sqlite_connection(config.url)
|
|
149
|
-
connections[key] = connection
|
|
150
|
-
opened.append(connection)
|
|
151
186
|
db_push(
|
|
152
187
|
models,
|
|
153
|
-
|
|
188
|
+
connection,
|
|
154
189
|
sync_indexes=sync_indexes,
|
|
155
190
|
confirm_rebuild=confirm_callback,
|
|
156
191
|
)
|
|
157
192
|
finally:
|
|
158
|
-
|
|
159
|
-
try:
|
|
160
|
-
conn.close()
|
|
161
|
-
except Exception:
|
|
162
|
-
pass
|
|
193
|
+
connection.close()
|
|
163
194
|
|
|
164
195
|
|
|
165
|
-
def command_generate(module_path: Path) -> None:
|
|
196
|
+
def command_generate(module_path: Path, *, target: GenerateTarget = "model-dir") -> None:
|
|
166
197
|
module = load_module(module_path)
|
|
167
198
|
models = collect_models(module)
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
199
|
+
client_class_name = resolve_client_class_name(module_path)
|
|
200
|
+
generated = generate_client(models, client_class_name=client_class_name)
|
|
201
|
+
output_dir = resolve_generated_package_dir(module_path, target)
|
|
202
|
+
output_dir.mkdir(parents=True, exist_ok=True)
|
|
203
|
+
(output_dir / "__init__.py").write_text(generated.init_code, encoding="utf-8")
|
|
204
|
+
(output_dir / "__init__.pyi").write_text(generated.init_stub, encoding="utf-8")
|
|
205
|
+
(output_dir / GENERATED_CLIENT_FILENAME).write_text(generated.code, encoding="utf-8")
|
|
206
|
+
(output_dir / "asdict.pyi").write_text(generated.asdict_stub, encoding="utf-8")
|
|
207
|
+
sys.stdout.write(f"Client package written to {output_dir}\n")
|
|
175
208
|
|
|
176
209
|
|
|
177
210
|
def command_push_db(
|
|
@@ -197,7 +230,13 @@ def build_parser() -> argparse.ArgumentParser:
|
|
|
197
230
|
subparsers = parser.add_subparsers(dest="command", required=True)
|
|
198
231
|
|
|
199
232
|
generate_parser = subparsers.add_parser("generate", help="Generate client code for given models")
|
|
200
|
-
generate_parser.
|
|
233
|
+
generate_parser.add_argument(
|
|
234
|
+
"--target",
|
|
235
|
+
choices=("model-dir", "package"),
|
|
236
|
+
default="model-dir",
|
|
237
|
+
help="生成 client 的位置: model-dir 写到模型文件同目录; package 写到 dclassql 包内",
|
|
238
|
+
)
|
|
239
|
+
generate_parser.set_defaults(handler=lambda args: command_generate(args.module, target=args.target))
|
|
201
240
|
|
|
202
241
|
push_parser = subparsers.add_parser("push-db", help="Apply schema and indexes to configured databases")
|
|
203
242
|
push_parser.add_argument(
|
|
@@ -17,7 +17,10 @@ from .model_inspector import ColumnInfo, ModelInfo, inspect_models, DataSourceCo
|
|
|
17
17
|
class GeneratedModule:
|
|
18
18
|
code: str
|
|
19
19
|
asdict_stub: str
|
|
20
|
+
init_code: str
|
|
21
|
+
init_stub: str
|
|
20
22
|
model_names: tuple[str, ...]
|
|
23
|
+
client_class_name: str
|
|
21
24
|
|
|
22
25
|
|
|
23
26
|
@dataclass(slots=True)
|
|
@@ -48,14 +51,34 @@ class WhereFieldSpec:
|
|
|
48
51
|
@dataclass(slots=True)
|
|
49
52
|
class ColumnSpecRender:
|
|
50
53
|
name: str
|
|
54
|
+
'''数据库列名'''
|
|
55
|
+
|
|
51
56
|
name_repr: str
|
|
57
|
+
'''列名的 Python 字符串字面量形式, 用于生成代码里的 dict key.'''
|
|
58
|
+
|
|
52
59
|
optional: bool
|
|
60
|
+
'''插入/更新时是否允许 None 或缺省, 来自 Optional/default/factory 判断.'''
|
|
61
|
+
|
|
53
62
|
auto_increment: bool
|
|
63
|
+
'''是否自增,给主键用的'''
|
|
64
|
+
|
|
54
65
|
has_default: bool
|
|
66
|
+
'''原 dataclass 字段是否有 default'''
|
|
67
|
+
|
|
55
68
|
has_default_factory: bool
|
|
69
|
+
'''原 dataclass 字段是否有 default_factory'''
|
|
70
|
+
|
|
71
|
+
returned_field: bool
|
|
72
|
+
'''是否会进入返回的原 dataclass 对象. 隐式 id 为 False.'''
|
|
73
|
+
|
|
56
74
|
mapping_value_expr: str
|
|
75
|
+
'''Mapping payload 转数据库值的生成表达式. 例如 `data['open_order_id']`.'''
|
|
76
|
+
|
|
57
77
|
insert_value_expr: str
|
|
78
|
+
'''Insert dataclass 或原模型实例转数据库值的生成表达式. 隐式 id 用 getattr 默认 None.'''
|
|
79
|
+
|
|
58
80
|
is_enum: bool
|
|
81
|
+
'''是否是 Enum 列'''
|
|
59
82
|
|
|
60
83
|
|
|
61
84
|
@dataclass(slots=True)
|
|
@@ -125,6 +148,7 @@ class ModelRenderContext:
|
|
|
125
148
|
indexes_literal: str
|
|
126
149
|
unique_indexes_literal: str
|
|
127
150
|
primary_value_types: tuple[str, ...]
|
|
151
|
+
primary_key_on_model: bool
|
|
128
152
|
row_assignments: tuple[RowAssignmentRender, ...]
|
|
129
153
|
default_factories: tuple[DefaultFactoryRender, ...]
|
|
130
154
|
model_info: ModelInfo
|
|
@@ -132,9 +156,6 @@ class ModelRenderContext:
|
|
|
132
156
|
|
|
133
157
|
@dataclass(slots=True)
|
|
134
158
|
class ClientDataSourceContext:
|
|
135
|
-
key: str
|
|
136
|
-
key_repr: str
|
|
137
|
-
provider_repr: str
|
|
138
159
|
url_repr: str
|
|
139
160
|
name_repr: str
|
|
140
161
|
|
|
@@ -147,6 +168,7 @@ class ClientModelBindingContext:
|
|
|
147
168
|
|
|
148
169
|
@dataclass(slots=True)
|
|
149
170
|
class ClientContext:
|
|
171
|
+
class_name: str
|
|
150
172
|
datasource: ClientDataSourceContext
|
|
151
173
|
model_bindings: tuple[ClientModelBindingContext, ...]
|
|
152
174
|
|
|
@@ -167,7 +189,7 @@ def _get_environment() -> Environment:
|
|
|
167
189
|
return _ENVIRONMENT
|
|
168
190
|
|
|
169
191
|
|
|
170
|
-
def generate_client(models: Sequence[type[Any]]) -> GeneratedModule:
|
|
192
|
+
def generate_client(models: Sequence[type[Any]], *, client_class_name: str = "GeneratedClient") -> GeneratedModule:
|
|
171
193
|
model_infos = inspect_models(models)
|
|
172
194
|
renderer = _TypeRenderer({info.model: name for name, info in model_infos.items()})
|
|
173
195
|
filter_registry = _ScalarFilterRegistry(renderer)
|
|
@@ -192,8 +214,8 @@ def generate_client(models: Sequence[type[Any]]) -> GeneratedModule:
|
|
|
192
214
|
for module, names in sorted(combined_imports.items())
|
|
193
215
|
]
|
|
194
216
|
|
|
195
|
-
client_context = _build_client_context(model_infos)
|
|
196
|
-
exports = _collect_exports(model_contexts)
|
|
217
|
+
client_context = _build_client_context(model_infos, client_class_name)
|
|
218
|
+
exports = _collect_exports(model_contexts, client_class_name)
|
|
197
219
|
scalar_filters = filter_registry.render_definitions()
|
|
198
220
|
|
|
199
221
|
template = _get_environment().get_template(_TEMPLATE_NAME)
|
|
@@ -207,7 +229,16 @@ def generate_client(models: Sequence[type[Any]]) -> GeneratedModule:
|
|
|
207
229
|
if not code.endswith("\n"):
|
|
208
230
|
code += "\n"
|
|
209
231
|
asdict_stub = _render_asdict_stub(model_contexts)
|
|
210
|
-
|
|
232
|
+
init_code = _render_init_code(client_class_name)
|
|
233
|
+
init_stub = _render_init_stub(client_class_name)
|
|
234
|
+
return GeneratedModule(
|
|
235
|
+
code=code,
|
|
236
|
+
asdict_stub=asdict_stub,
|
|
237
|
+
init_code=init_code,
|
|
238
|
+
init_stub=init_stub,
|
|
239
|
+
model_names=tuple(sorted(model_infos.keys())),
|
|
240
|
+
client_class_name=client_class_name,
|
|
241
|
+
)
|
|
211
242
|
|
|
212
243
|
|
|
213
244
|
def _build_model_context(
|
|
@@ -224,10 +255,13 @@ def _build_model_context(
|
|
|
224
255
|
upsert_where_dicts: list[UpsertWhereRender] = []
|
|
225
256
|
dict_field_map: dict[str, str] = {}
|
|
226
257
|
enum_type_map: dict[str, type[Enum] | None] = {}
|
|
227
|
-
|
|
228
|
-
|
|
258
|
+
model_column_names = {col.name for col in info.columns}
|
|
259
|
+
db_columns = _build_db_columns(info)
|
|
260
|
+
column_lookup: dict[str, ColumnInfo] = {col.name: col for col in db_columns}
|
|
261
|
+
primary_key_on_model = all(column_name in model_column_names for column_name in info.primary_key)
|
|
262
|
+
for col in db_columns:
|
|
229
263
|
annotation = _format_insert_annotation(col, renderer)
|
|
230
|
-
default_fragment = _render_default_fragment(
|
|
264
|
+
default_fragment = _render_default_fragment(info.model, col)
|
|
231
265
|
if default_fragment is not None:
|
|
232
266
|
default_expr = default_fragment
|
|
233
267
|
elif col.auto_increment:
|
|
@@ -254,8 +288,8 @@ def _build_model_context(
|
|
|
254
288
|
if info.primary_key:
|
|
255
289
|
pk_fields: list[TypedDictFieldSpec] = []
|
|
256
290
|
for pk_col in info.primary_key:
|
|
257
|
-
col_info = column_lookup
|
|
258
|
-
annotation = renderer.render(col_info.python_type)
|
|
291
|
+
col_info = column_lookup[pk_col]
|
|
292
|
+
annotation = renderer.render(col_info.python_type)
|
|
259
293
|
pk_fields.append(TypedDictFieldSpec(name=pk_col, annotation=annotation))
|
|
260
294
|
upsert_where_dicts.append(UpsertWhereRender(name=f"{name}UpsertWherePK", fields=tuple(pk_fields)))
|
|
261
295
|
|
|
@@ -270,8 +304,11 @@ def _build_model_context(
|
|
|
270
304
|
UpsertWhereRender(name=f"{name}UpsertWhereUnique{idx}", fields=tuple(unique_fields))
|
|
271
305
|
)
|
|
272
306
|
|
|
307
|
+
if not upsert_where_dicts:
|
|
308
|
+
renderer.require_typing("Never")
|
|
309
|
+
|
|
273
310
|
where_fields: list[WhereFieldSpec] = []
|
|
274
|
-
for col in
|
|
311
|
+
for col in db_columns:
|
|
275
312
|
annotation = renderer.render(col.python_type)
|
|
276
313
|
if "None" not in annotation:
|
|
277
314
|
annotation = f"{annotation} | None"
|
|
@@ -316,11 +353,16 @@ def _build_model_context(
|
|
|
316
353
|
auto_increment=column.auto_increment,
|
|
317
354
|
has_default=column.has_default,
|
|
318
355
|
has_default_factory=column.has_default_factory,
|
|
356
|
+
returned_field=column.name in model_column_names,
|
|
319
357
|
mapping_value_expr=_format_mapping_value_expr(column, enum_type_map.get(column.name)),
|
|
320
|
-
insert_value_expr=_format_insert_value_expr(
|
|
358
|
+
insert_value_expr=_format_insert_value_expr(
|
|
359
|
+
column,
|
|
360
|
+
enum_type_map.get(column.name),
|
|
361
|
+
returned_field=column.name in model_column_names,
|
|
362
|
+
),
|
|
321
363
|
is_enum=enum_type_map.get(column.name) is not None,
|
|
322
364
|
)
|
|
323
|
-
for column in
|
|
365
|
+
for column in db_columns
|
|
324
366
|
]
|
|
325
367
|
|
|
326
368
|
foreign_keys = [
|
|
@@ -347,7 +389,7 @@ def _build_model_context(
|
|
|
347
389
|
|
|
348
390
|
datasource_values = info.datasource
|
|
349
391
|
datasource_expr = (
|
|
350
|
-
f"DataSourceConfig(
|
|
392
|
+
f"DataSourceConfig(url={repr(datasource_values.url)}, name={repr(datasource_values.name)})"
|
|
351
393
|
)
|
|
352
394
|
|
|
353
395
|
indexes_literal = _tuple_literal(tuple(tuple(idx) for idx in info.indexes)) if info.indexes else "()"
|
|
@@ -355,13 +397,10 @@ def _build_model_context(
|
|
|
355
397
|
_tuple_literal(tuple(tuple(idx) for idx in info.unique_indexes)) if info.unique_indexes else "()"
|
|
356
398
|
)
|
|
357
399
|
|
|
358
|
-
row_assignments, default_factories = _build_row_assignment_context(info, enum_type_map)
|
|
400
|
+
row_assignments, default_factories = _build_row_assignment_context(info, enum_type_map, renderer)
|
|
359
401
|
primary_value_types: list[str] = []
|
|
360
402
|
for column_name in info.primary_key:
|
|
361
|
-
column = column_lookup
|
|
362
|
-
if column is None:
|
|
363
|
-
primary_value_types.append("object")
|
|
364
|
-
continue
|
|
403
|
+
column = column_lookup[column_name]
|
|
365
404
|
primary_value_types.append(renderer.render(column.python_type))
|
|
366
405
|
|
|
367
406
|
relation_lookup = {relation.name: relation for relation in info.relations}
|
|
@@ -403,24 +442,39 @@ def _build_model_context(
|
|
|
403
442
|
indexes_literal=indexes_literal,
|
|
404
443
|
unique_indexes_literal=unique_indexes_literal,
|
|
405
444
|
primary_value_types=tuple(primary_value_types),
|
|
445
|
+
primary_key_on_model=primary_key_on_model,
|
|
406
446
|
row_assignments=tuple(row_assignments),
|
|
407
447
|
default_factories=tuple(default_factories),
|
|
408
448
|
model_info=info,
|
|
409
449
|
)
|
|
410
450
|
|
|
411
451
|
|
|
412
|
-
def
|
|
452
|
+
def _build_db_columns(info: ModelInfo) -> tuple[ColumnInfo, ...]:
|
|
453
|
+
if info.primary_key == ("id",) and all(column.name != "id" for column in info.columns):
|
|
454
|
+
implicit_id = ColumnInfo(
|
|
455
|
+
name="id",
|
|
456
|
+
python_type=int,
|
|
457
|
+
optional=False,
|
|
458
|
+
auto_increment=True,
|
|
459
|
+
storage_kind="scalar",
|
|
460
|
+
has_default=False,
|
|
461
|
+
default_value=None,
|
|
462
|
+
has_default_factory=False,
|
|
463
|
+
default_factory=None,
|
|
464
|
+
)
|
|
465
|
+
return (implicit_id, *info.columns)
|
|
466
|
+
return tuple(info.columns)
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
def _build_client_context(model_infos: Mapping[str, ModelInfo], client_class_name: str) -> ClientContext:
|
|
413
470
|
datasource_configs = {info.datasource for info in model_infos.values()}
|
|
414
471
|
if len(datasource_configs) != 1:
|
|
415
472
|
labels = ", ".join(
|
|
416
|
-
f"{ds.
|
|
473
|
+
f"{ds.identity}({ds.url!r})" for ds in sorted(datasource_configs, key=lambda item: item.identity)
|
|
417
474
|
)
|
|
418
475
|
raise ValueError(f"Generated Client can only use one datasource, got: {labels}")
|
|
419
476
|
datasource = next(iter(datasource_configs))
|
|
420
477
|
datasource_item = ClientDataSourceContext(
|
|
421
|
-
key=datasource.key,
|
|
422
|
-
key_repr=repr(datasource.key),
|
|
423
|
-
provider_repr=repr(datasource.provider),
|
|
424
478
|
url_repr=repr(datasource.url),
|
|
425
479
|
name_repr=repr(datasource.name),
|
|
426
480
|
)
|
|
@@ -434,13 +488,14 @@ def _build_client_context(model_infos: Mapping[str, ModelInfo]) -> ClientContext
|
|
|
434
488
|
]
|
|
435
489
|
|
|
436
490
|
return ClientContext(
|
|
491
|
+
class_name=client_class_name,
|
|
437
492
|
datasource=datasource_item,
|
|
438
493
|
model_bindings=tuple(model_bindings),
|
|
439
494
|
)
|
|
440
495
|
|
|
441
496
|
|
|
442
|
-
def _collect_exports(model_contexts: Sequence[ModelRenderContext]) -> list[str]:
|
|
443
|
-
exports: list[str] = ["DataSourceConfig", "ForeignKeySpec",
|
|
497
|
+
def _collect_exports(model_contexts: Sequence[ModelRenderContext], client_class_name: str) -> list[str]:
|
|
498
|
+
exports: list[str] = ["DataSourceConfig", "ForeignKeySpec", client_class_name]
|
|
444
499
|
for context in model_contexts:
|
|
445
500
|
name = context.name
|
|
446
501
|
exports.extend(
|
|
@@ -472,6 +527,24 @@ def _render_asdict_stub(model_contexts: Sequence[ModelRenderContext]) -> str:
|
|
|
472
527
|
return code
|
|
473
528
|
|
|
474
529
|
|
|
530
|
+
def _render_init_code(client_class_name: str) -> str:
|
|
531
|
+
code = (
|
|
532
|
+
"from dclassql.asdict import asdict as asdict\n"
|
|
533
|
+
f"from .client import {client_class_name} as {client_class_name}\n\n"
|
|
534
|
+
f"__all__ = ['{client_class_name}', 'asdict']\n"
|
|
535
|
+
)
|
|
536
|
+
return code
|
|
537
|
+
|
|
538
|
+
|
|
539
|
+
def _render_init_stub(client_class_name: str) -> str:
|
|
540
|
+
code = (
|
|
541
|
+
"from .asdict import asdict as asdict\n"
|
|
542
|
+
f"from .client import {client_class_name} as {client_class_name}\n\n"
|
|
543
|
+
"__all__: list[str]\n"
|
|
544
|
+
)
|
|
545
|
+
return code
|
|
546
|
+
|
|
547
|
+
|
|
475
548
|
def _build_relation_entries(info: ModelInfo, model_infos: Mapping[str, ModelInfo]) -> list[dict[str, Any]]:
|
|
476
549
|
entries: list[dict[str, Any]] = []
|
|
477
550
|
if not info.relations:
|
|
@@ -488,7 +561,7 @@ def _build_relation_entries(info: ModelInfo, model_infos: Mapping[str, ModelInfo
|
|
|
488
561
|
mapping: tuple[tuple[str, str], ...] | None = None
|
|
489
562
|
if not relation.many:
|
|
490
563
|
for fk in info.foreign_keys:
|
|
491
|
-
if fk.remote_model is target_model:
|
|
564
|
+
if fk.remote_model is target_model and fk.relation_attribute == relation.name:
|
|
492
565
|
mapping = tuple((local, remote) for local, remote in zip(fk.local_columns, fk.remote_columns))
|
|
493
566
|
break
|
|
494
567
|
if mapping is None:
|
|
@@ -524,6 +597,7 @@ def _build_relation_entries(info: ModelInfo, model_infos: Mapping[str, ModelInfo
|
|
|
524
597
|
def _build_row_assignment_context(
|
|
525
598
|
info: ModelInfo,
|
|
526
599
|
enum_type_map: Mapping[str, type[Enum] | None],
|
|
600
|
+
renderer: "_TypeRenderer",
|
|
527
601
|
) -> tuple[list[RowAssignmentRender], list[DefaultFactoryRender]]:
|
|
528
602
|
dataclass_fields = fields(info.model)
|
|
529
603
|
column_map = {column.name: column for column in info.columns}
|
|
@@ -539,6 +613,7 @@ def _build_row_assignment_context(
|
|
|
539
613
|
column_map,
|
|
540
614
|
enum_type_map,
|
|
541
615
|
relation_defaults,
|
|
616
|
+
renderer,
|
|
542
617
|
)
|
|
543
618
|
assignments.append(RowAssignmentRender(field_name=field_obj.name, value_expr=assignment_expr))
|
|
544
619
|
if default_factory is not None:
|
|
@@ -552,12 +627,13 @@ def _resolve_row_assignment(
|
|
|
552
627
|
column_map: Mapping[str, ColumnInfo],
|
|
553
628
|
enum_type_map: Mapping[str, type[Enum] | None],
|
|
554
629
|
relation_defaults: Mapping[str, str],
|
|
630
|
+
renderer: "_TypeRenderer",
|
|
555
631
|
) -> tuple[str, DefaultFactoryRender | None]:
|
|
556
632
|
name = field_obj.name
|
|
557
633
|
column_info = column_map.get(name)
|
|
558
634
|
if column_info is not None:
|
|
559
635
|
enum_type = enum_type_map.get(name)
|
|
560
|
-
return _column_value_expression(column_info, enum_type), None
|
|
636
|
+
return _column_value_expression(column_info, enum_type, renderer), None
|
|
561
637
|
if field_obj.default is not MISSING:
|
|
562
638
|
return f"{model_cls.__name__}.__dataclass_fields__[{name!r}].default", None
|
|
563
639
|
if field_obj.default_factory is not MISSING:
|
|
@@ -570,8 +646,14 @@ def _resolve_row_assignment(
|
|
|
570
646
|
return _infer_field_fallback(field_obj.type), None
|
|
571
647
|
|
|
572
648
|
|
|
573
|
-
def _column_value_expression(
|
|
649
|
+
def _column_value_expression(
|
|
650
|
+
column: ColumnInfo,
|
|
651
|
+
enum_type: type[Enum] | None,
|
|
652
|
+
renderer: "_TypeRenderer",
|
|
653
|
+
) -> str:
|
|
574
654
|
base_expr = f"row[{column.name!r}]"
|
|
655
|
+
if column.storage_kind == "json":
|
|
656
|
+
return f"deserialize_json_value({base_expr}, {renderer.render(column.python_type)})"
|
|
575
657
|
if enum_type is None:
|
|
576
658
|
return base_expr
|
|
577
659
|
converter = enum_type.__name__
|
|
@@ -581,6 +663,8 @@ def _column_value_expression(column: ColumnInfo, enum_type: type[Enum] | None) -
|
|
|
581
663
|
|
|
582
664
|
|
|
583
665
|
def _format_mapping_value_expr(column: ColumnInfo, enum_type: type[Enum] | None) -> str:
|
|
666
|
+
if column.storage_kind == "json":
|
|
667
|
+
return f"serialize_json_value(data[{column.name!r}])"
|
|
584
668
|
if enum_type is None:
|
|
585
669
|
return f"data[{column.name!r}]"
|
|
586
670
|
value_expr = f"data[{column.name!r}]"
|
|
@@ -589,7 +673,16 @@ def _format_mapping_value_expr(column: ColumnInfo, enum_type: type[Enum] | None)
|
|
|
589
673
|
return f"{value_expr}.value"
|
|
590
674
|
|
|
591
675
|
|
|
592
|
-
def _format_insert_value_expr(
|
|
676
|
+
def _format_insert_value_expr(
|
|
677
|
+
column: ColumnInfo,
|
|
678
|
+
enum_type: type[Enum] | None,
|
|
679
|
+
*,
|
|
680
|
+
returned_field: bool = True,
|
|
681
|
+
) -> str:
|
|
682
|
+
if not returned_field:
|
|
683
|
+
return f"getattr(data, {column.name!r}, None)"
|
|
684
|
+
if column.storage_kind == "json":
|
|
685
|
+
return f"serialize_json_value(data.{column.name})"
|
|
593
686
|
if enum_type is None:
|
|
594
687
|
return f"data.{column.name}"
|
|
595
688
|
value_expr = f"data.{column.name}"
|
|
@@ -641,12 +734,12 @@ def _format_insert_annotation(col: ColumnInfo, renderer: "_TypeRenderer") -> str
|
|
|
641
734
|
return annotation
|
|
642
735
|
|
|
643
736
|
|
|
644
|
-
def _render_default_fragment(
|
|
737
|
+
def _render_default_fragment(model_cls: type[Any], col: ColumnInfo) -> str | None:
|
|
645
738
|
if col.has_default_factory and col.default_factory is not None:
|
|
646
|
-
factory_expr = f"{
|
|
739
|
+
factory_expr = f"{model_cls.__name__}.__dataclass_fields__['{col.name}'].default_factory"
|
|
647
740
|
return f"field(default_factory={factory_expr})"
|
|
648
741
|
if col.has_default:
|
|
649
|
-
return
|
|
742
|
+
return f"{model_cls.__name__}.__dataclass_fields__['{col.name}'].default"
|
|
650
743
|
return None
|
|
651
744
|
|
|
652
745
|
|
|
@@ -875,6 +968,14 @@ class _TypeRenderer:
|
|
|
875
968
|
self._typing_imports: set[str] = set()
|
|
876
969
|
|
|
877
970
|
def render(self, tp: Any) -> str:
|
|
971
|
+
alias_value = getattr(tp, "__value__", None)
|
|
972
|
+
if alias_value is not None:
|
|
973
|
+
alias_name = getattr(tp, "__name__", None)
|
|
974
|
+
alias_module = getattr(tp, "__module__", None)
|
|
975
|
+
if isinstance(alias_name, str) and isinstance(alias_module, str):
|
|
976
|
+
self._module_imports[alias_module].add(alias_name)
|
|
977
|
+
return alias_name
|
|
978
|
+
return self.render(alias_value)
|
|
878
979
|
if tp is Any:
|
|
879
980
|
return "Any"
|
|
880
981
|
if tp is type(None):
|