dclassql 0.3.1__tar.gz → 0.4.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. {dclassql-0.3.1 → dclassql-0.4.1}/PKG-INFO +5 -2
  2. {dclassql-0.3.1 → dclassql-0.4.1}/README.md +4 -1
  3. {dclassql-0.3.1 → dclassql-0.4.1}/pyproject.toml +1 -1
  4. dclassql-0.4.1/src/dclassql/__init__.py +21 -0
  5. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/cli.py +75 -36
  6. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/codegen.py +136 -35
  7. dclassql-0.4.1/src/dclassql/db_pool.py +89 -0
  8. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/model_inspector.py +98 -25
  9. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/push/__init__.py +14 -21
  10. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/push/base.py +20 -0
  11. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/push/sqlite.py +36 -1
  12. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/runtime/backends/base.py +3 -0
  13. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/runtime/backends/lazy.py +21 -5
  14. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/runtime/datasource.py +2 -5
  15. dclassql-0.4.1/src/dclassql/runtime/json_value.py +104 -0
  16. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/templates/partials/client_class.jinja +15 -22
  17. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/templates/partials/imports.jinja +3 -4
  18. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/templates/partials/model_section.jinja +18 -4
  19. dclassql-0.3.1/src/dclassql/__init__.py +0 -34
  20. dclassql-0.3.1/src/dclassql/asdict.pyi +0 -57
  21. dclassql-0.3.1/src/dclassql/client.py +0 -1397
  22. dclassql-0.3.1/src/dclassql/db_pool.py +0 -76
  23. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/.gitignore +0 -0
  24. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/asdict.py +0 -0
  25. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/runtime/backends/__init__.py +0 -0
  26. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/runtime/backends/metadata.py +0 -0
  27. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/runtime/backends/protocols.py +0 -0
  28. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/runtime/backends/sqlite.py +0 -0
  29. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/runtime/backends/where_compiler.py +0 -0
  30. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/runtime/sql_recorder.py +0 -0
  31. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/runtime/sqlite_adapters.py +0 -0
  32. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/table_spec.py +0 -0
  33. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/templates/__init__.py +0 -0
  34. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/templates/asdict_stub.pyi.jinja +0 -0
  35. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/templates/client_module.py.jinja +0 -0
  36. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/templates/partials/exports.jinja +0 -0
  37. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/templates/partials/macros.jinja +0 -0
  38. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/templates/partials/scalar_filters.jinja +0 -0
  39. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/typing.py +0 -0
  40. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/unwarp.py +0 -0
  41. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/utils/__init__.py +0 -0
  42. {dclassql-0.3.1 → dclassql-0.4.1}/src/dclassql/utils/ensure.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: dclassql
3
- Version: 0.3.1
3
+ Version: 0.4.1
4
4
  Summary: A type-safe ORM generator for Python, creating fully type-hinted database clients from plain dataclass definitions.
5
5
  Keywords: orm,codegen,sqlite,dataclass,typed
6
6
  Author: myuanz
@@ -62,11 +62,14 @@ class User:
62
62
  写出如下代码时:
63
63
 
64
64
  ```python
65
- from dclassql import client
65
+ from dclassql import Client
66
+
67
+ client = Client()
66
68
 
67
69
  client.user.insert({
68
70
  "name": "Alice",
69
71
  "email": "test@example.com",
72
+ # 这里缺少 last_login
70
73
  })
71
74
  ```
72
75
 
@@ -41,11 +41,14 @@ class User:
41
41
  写出如下代码时:
42
42
 
43
43
  ```python
44
- from dclassql import client
44
+ from dclassql import Client
45
+
46
+ client = Client()
45
47
 
46
48
  client.user.insert({
47
49
  "name": "Alice",
48
50
  "email": "test@example.com",
51
+ # 这里缺少 last_login
49
52
  })
50
53
  ```
51
54
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "dclassql"
3
- version = "0.3.1"
3
+ version = "0.4.1"
4
4
  description = "A type-safe ORM generator for Python, creating fully type-hinted database clients from plain dataclass definitions."
5
5
  readme = "README.md"
6
6
  authors = [
@@ -0,0 +1,21 @@
1
+ from .asdict import asdict
2
+ from .db_pool import BaseDBPool, save_local
3
+ from .model_inspector import DataSourceConfig
4
+ from .push import db_push
5
+ from .runtime.backends.lazy import eager
6
+ from .runtime.sql_recorder import record_sql
7
+ from .unwarp import unwarp, unwarp_or, unwarp_or_raise
8
+
9
+
10
+ __all__ = [
11
+ 'db_push',
12
+ 'eager',
13
+ 'asdict',
14
+ 'unwarp',
15
+ 'unwarp_or',
16
+ 'unwarp_or_raise',
17
+ 'BaseDBPool',
18
+ 'save_local',
19
+ 'DataSourceConfig',
20
+ 'record_sql',
21
+ ]
@@ -18,6 +18,7 @@ from .runtime.datasource import open_sqlite_connection
18
18
  DEFAULT_MODEL_FILE = "model.py"
19
19
  GENERATED_CLIENT_FILENAME = "client.py"
20
20
 
21
+ GenerateTarget = Literal["model-dir", "package"]
21
22
  ConfirmRebuildMode = Literal["auto", "prompt"]
22
23
  ConfirmCallback = Callable[
23
24
  [ModelInfo, SchemaPlan, tuple[ExistingColumn, ...] | None, SchemaDiff],
@@ -51,11 +52,17 @@ def load_module(module_path: Path) -> ModuleType:
51
52
  raise ImportError(f"Unable to load module from '{module_path}'")
52
53
  module = importlib.util.module_from_spec(spec)
53
54
  sys.modules[module_name] = module
54
- sys.path.insert(0, str(module_path.parent))
55
+ original_sys_path = list(sys.path)
56
+ search_paths = [str(module_path.parent)]
57
+ cwd = str(Path.cwd())
58
+ if cwd not in search_paths:
59
+ search_paths.append(cwd)
60
+ for path in reversed(search_paths):
61
+ sys.path.insert(0, path)
55
62
  try:
56
63
  spec.loader.exec_module(module)
57
64
  finally:
58
- sys.path.pop(0)
65
+ sys.path[:] = original_sys_path
59
66
  return module
60
67
 
61
68
 
@@ -66,29 +73,61 @@ def _find_package_directory() -> Path:
66
73
  return Path(next(iter(spec.submodule_search_locations))).resolve()
67
74
 
68
75
 
69
- def resolve_generated_path() -> Path:
70
- return _find_package_directory() / GENERATED_CLIENT_FILENAME
76
+ def resolve_client_package_name(module_path: Path) -> str:
77
+ stem = re.sub(r"[^0-9a-zA-Z_]+", "_", module_path.stem) or "_"
78
+ return f"{stem}_client"
71
79
 
72
80
 
73
- def resolve_asdict_stub_path() -> Path:
74
- return _find_package_directory() / "asdict.pyi"
81
+ def resolve_client_class_name(module_path: Path) -> str:
82
+ package_name = resolve_client_package_name(module_path)
83
+ return "".join(part.capitalize() for part in package_name.split("_") if part)
84
+
85
+
86
+ def resolve_generated_package_dir(module_path: Path, target: GenerateTarget = "model-dir") -> Path:
87
+ package_name = resolve_client_package_name(module_path)
88
+ if target == "model-dir":
89
+ return module_path.resolve().parent / package_name
90
+ return _find_package_directory() / package_name
75
91
 
76
92
 
77
93
  def collect_models(module: ModuleType) -> list[type[Any]]:
78
94
  from dataclasses import is_dataclass
79
95
 
96
+ excluded_names = _collect_excluded_model_names(module)
80
97
  models: list[type[Any]] = []
81
98
  for value in vars(module).values():
82
- if isinstance(value, type) and is_dataclass(value) and value.__module__ == module.__name__:
99
+ if (
100
+ isinstance(value, type)
101
+ and is_dataclass(value)
102
+ and value.__module__ == module.__name__
103
+ and value.__name__ not in excluded_names
104
+ ):
83
105
  models.append(value)
84
106
  if not models:
85
107
  raise ValueError("No dataclass models were found in the provided module")
86
108
  return models
87
109
 
88
110
 
111
+ def _collect_excluded_model_names(module: ModuleType) -> set[str]:
112
+ raw = getattr(module, "__exclude__", ())
113
+ if raw is None:
114
+ return set()
115
+ if isinstance(raw, str):
116
+ return {raw}
117
+ names: set[str] = set()
118
+ for item in raw:
119
+ if isinstance(item, str):
120
+ names.add(item)
121
+ continue
122
+ if isinstance(item, type):
123
+ names.add(item.__name__)
124
+ continue
125
+ raise TypeError("__exclude__ entries must be dataclass classes or class names")
126
+ return names
127
+
128
+
89
129
  def _describe_schema_diff(info: ModelInfo, diff: SchemaDiff) -> str:
90
- datasource_key = getattr(info.datasource, "key", None)
91
- prefix = f"[{datasource_key}] " if datasource_key else ""
130
+ prefix = f"[{info.datasource.identity}] "
92
131
  parts: list[str] = [f"{prefix}模型 {info.model.__name__} 需要重建表"]
93
132
  if diff.added:
94
133
  added = ", ".join(f"+{column.name}:{column.type_sql}" for column in diff.added)
@@ -134,44 +173,38 @@ def push_database(
134
173
  confirm_mode: ConfirmRebuildMode | None = None,
135
174
  ) -> None:
136
175
  model_infos = inspect_models(models)
137
- connections: dict[str, Any] = {}
138
- opened: list[Any] = []
176
+ datasource_configs = {info.datasource for info in model_infos.values()}
177
+ if len(datasource_configs) != 1:
178
+ labels = ", ".join(sorted(config.identity for config in datasource_configs))
179
+ raise ValueError(f"push-db only supports one datasource, got: {labels}")
180
+ datasource = next(iter(datasource_configs))
181
+ if datasource.provider != "sqlite":
182
+ raise ValueError(f"Unsupported provider '{datasource.provider}'")
139
183
  confirm_callback = _build_confirm_callback(confirm_mode) if confirm_mode else None
184
+ connection = open_sqlite_connection(datasource.url)
140
185
  try:
141
- for info in model_infos.values():
142
- config = info.datasource
143
- key = config.key
144
- if key in connections:
145
- continue
146
- if config.provider != "sqlite":
147
- raise ValueError(f"Unsupported provider '{config.provider}'")
148
- connection = open_sqlite_connection(config.url)
149
- connections[key] = connection
150
- opened.append(connection)
151
186
  db_push(
152
187
  models,
153
- connections,
188
+ connection,
154
189
  sync_indexes=sync_indexes,
155
190
  confirm_rebuild=confirm_callback,
156
191
  )
157
192
  finally:
158
- for conn in opened:
159
- try:
160
- conn.close()
161
- except Exception:
162
- pass
193
+ connection.close()
163
194
 
164
195
 
165
- def command_generate(module_path: Path) -> None:
196
+ def command_generate(module_path: Path, *, target: GenerateTarget = "model-dir") -> None:
166
197
  module = load_module(module_path)
167
198
  models = collect_models(module)
168
- generated = generate_client(models)
169
- output_path = resolve_generated_path()
170
- output_path.parent.mkdir(parents=True, exist_ok=True)
171
- output_path.write_text(generated.code, encoding="utf-8")
172
- asdict_stub_path = resolve_asdict_stub_path()
173
- asdict_stub_path.write_text(generated.asdict_stub, encoding="utf-8")
174
- sys.stdout.write(f"Client written to {output_path}\n")
199
+ client_class_name = resolve_client_class_name(module_path)
200
+ generated = generate_client(models, client_class_name=client_class_name)
201
+ output_dir = resolve_generated_package_dir(module_path, target)
202
+ output_dir.mkdir(parents=True, exist_ok=True)
203
+ (output_dir / "__init__.py").write_text(generated.init_code, encoding="utf-8")
204
+ (output_dir / "__init__.pyi").write_text(generated.init_stub, encoding="utf-8")
205
+ (output_dir / GENERATED_CLIENT_FILENAME).write_text(generated.code, encoding="utf-8")
206
+ (output_dir / "asdict.pyi").write_text(generated.asdict_stub, encoding="utf-8")
207
+ sys.stdout.write(f"Client package written to {output_dir}\n")
175
208
 
176
209
 
177
210
  def command_push_db(
@@ -197,7 +230,13 @@ def build_parser() -> argparse.ArgumentParser:
197
230
  subparsers = parser.add_subparsers(dest="command", required=True)
198
231
 
199
232
  generate_parser = subparsers.add_parser("generate", help="Generate client code for given models")
200
- generate_parser.set_defaults(handler=lambda args: command_generate(args.module))
233
+ generate_parser.add_argument(
234
+ "--target",
235
+ choices=("model-dir", "package"),
236
+ default="model-dir",
237
+ help="生成 client 的位置: model-dir 写到模型文件同目录; package 写到 dclassql 包内",
238
+ )
239
+ generate_parser.set_defaults(handler=lambda args: command_generate(args.module, target=args.target))
201
240
 
202
241
  push_parser = subparsers.add_parser("push-db", help="Apply schema and indexes to configured databases")
203
242
  push_parser.add_argument(
@@ -17,7 +17,10 @@ from .model_inspector import ColumnInfo, ModelInfo, inspect_models, DataSourceCo
17
17
  class GeneratedModule:
18
18
  code: str
19
19
  asdict_stub: str
20
+ init_code: str
21
+ init_stub: str
20
22
  model_names: tuple[str, ...]
23
+ client_class_name: str
21
24
 
22
25
 
23
26
  @dataclass(slots=True)
@@ -48,14 +51,34 @@ class WhereFieldSpec:
48
51
  @dataclass(slots=True)
49
52
  class ColumnSpecRender:
50
53
  name: str
54
+ '''数据库列名'''
55
+
51
56
  name_repr: str
57
+ '''列名的 Python 字符串字面量形式, 用于生成代码里的 dict key.'''
58
+
52
59
  optional: bool
60
+ '''插入/更新时是否允许 None 或缺省, 来自 Optional/default/factory 判断.'''
61
+
53
62
  auto_increment: bool
63
+ '''是否自增,给主键用的'''
64
+
54
65
  has_default: bool
66
+ '''原 dataclass 字段是否有 default'''
67
+
55
68
  has_default_factory: bool
69
+ '''原 dataclass 字段是否有 default_factory'''
70
+
71
+ returned_field: bool
72
+ '''是否会进入返回的原 dataclass 对象. 隐式 id 为 False.'''
73
+
56
74
  mapping_value_expr: str
75
+ '''Mapping payload 转数据库值的生成表达式. 例如 `data['open_order_id']`.'''
76
+
57
77
  insert_value_expr: str
78
+ '''Insert dataclass 或原模型实例转数据库值的生成表达式. 隐式 id 用 getattr 默认 None.'''
79
+
58
80
  is_enum: bool
81
+ '''是否是 Enum 列'''
59
82
 
60
83
 
61
84
  @dataclass(slots=True)
@@ -125,6 +148,7 @@ class ModelRenderContext:
125
148
  indexes_literal: str
126
149
  unique_indexes_literal: str
127
150
  primary_value_types: tuple[str, ...]
151
+ primary_key_on_model: bool
128
152
  row_assignments: tuple[RowAssignmentRender, ...]
129
153
  default_factories: tuple[DefaultFactoryRender, ...]
130
154
  model_info: ModelInfo
@@ -132,9 +156,6 @@ class ModelRenderContext:
132
156
 
133
157
  @dataclass(slots=True)
134
158
  class ClientDataSourceContext:
135
- key: str
136
- key_repr: str
137
- provider_repr: str
138
159
  url_repr: str
139
160
  name_repr: str
140
161
 
@@ -147,6 +168,7 @@ class ClientModelBindingContext:
147
168
 
148
169
  @dataclass(slots=True)
149
170
  class ClientContext:
171
+ class_name: str
150
172
  datasource: ClientDataSourceContext
151
173
  model_bindings: tuple[ClientModelBindingContext, ...]
152
174
 
@@ -167,7 +189,7 @@ def _get_environment() -> Environment:
167
189
  return _ENVIRONMENT
168
190
 
169
191
 
170
- def generate_client(models: Sequence[type[Any]]) -> GeneratedModule:
192
+ def generate_client(models: Sequence[type[Any]], *, client_class_name: str = "GeneratedClient") -> GeneratedModule:
171
193
  model_infos = inspect_models(models)
172
194
  renderer = _TypeRenderer({info.model: name for name, info in model_infos.items()})
173
195
  filter_registry = _ScalarFilterRegistry(renderer)
@@ -192,8 +214,8 @@ def generate_client(models: Sequence[type[Any]]) -> GeneratedModule:
192
214
  for module, names in sorted(combined_imports.items())
193
215
  ]
194
216
 
195
- client_context = _build_client_context(model_infos)
196
- exports = _collect_exports(model_contexts)
217
+ client_context = _build_client_context(model_infos, client_class_name)
218
+ exports = _collect_exports(model_contexts, client_class_name)
197
219
  scalar_filters = filter_registry.render_definitions()
198
220
 
199
221
  template = _get_environment().get_template(_TEMPLATE_NAME)
@@ -207,7 +229,16 @@ def generate_client(models: Sequence[type[Any]]) -> GeneratedModule:
207
229
  if not code.endswith("\n"):
208
230
  code += "\n"
209
231
  asdict_stub = _render_asdict_stub(model_contexts)
210
- return GeneratedModule(code=code, asdict_stub=asdict_stub, model_names=tuple(sorted(model_infos.keys())))
232
+ init_code = _render_init_code(client_class_name)
233
+ init_stub = _render_init_stub(client_class_name)
234
+ return GeneratedModule(
235
+ code=code,
236
+ asdict_stub=asdict_stub,
237
+ init_code=init_code,
238
+ init_stub=init_stub,
239
+ model_names=tuple(sorted(model_infos.keys())),
240
+ client_class_name=client_class_name,
241
+ )
211
242
 
212
243
 
213
244
  def _build_model_context(
@@ -224,10 +255,13 @@ def _build_model_context(
224
255
  upsert_where_dicts: list[UpsertWhereRender] = []
225
256
  dict_field_map: dict[str, str] = {}
226
257
  enum_type_map: dict[str, type[Enum] | None] = {}
227
- column_lookup: dict[str, ColumnInfo] = {col.name: col for col in info.columns}
228
- for col in info.columns:
258
+ model_column_names = {col.name for col in info.columns}
259
+ db_columns = _build_db_columns(info)
260
+ column_lookup: dict[str, ColumnInfo] = {col.name: col for col in db_columns}
261
+ primary_key_on_model = all(column_name in model_column_names for column_name in info.primary_key)
262
+ for col in db_columns:
229
263
  annotation = _format_insert_annotation(col, renderer)
230
- default_fragment = _render_default_fragment(name, col)
264
+ default_fragment = _render_default_fragment(info.model, col)
231
265
  if default_fragment is not None:
232
266
  default_expr = default_fragment
233
267
  elif col.auto_increment:
@@ -254,8 +288,8 @@ def _build_model_context(
254
288
  if info.primary_key:
255
289
  pk_fields: list[TypedDictFieldSpec] = []
256
290
  for pk_col in info.primary_key:
257
- col_info = column_lookup.get(pk_col)
258
- annotation = renderer.render(col_info.python_type) if col_info else "object"
291
+ col_info = column_lookup[pk_col]
292
+ annotation = renderer.render(col_info.python_type)
259
293
  pk_fields.append(TypedDictFieldSpec(name=pk_col, annotation=annotation))
260
294
  upsert_where_dicts.append(UpsertWhereRender(name=f"{name}UpsertWherePK", fields=tuple(pk_fields)))
261
295
 
@@ -270,8 +304,11 @@ def _build_model_context(
270
304
  UpsertWhereRender(name=f"{name}UpsertWhereUnique{idx}", fields=tuple(unique_fields))
271
305
  )
272
306
 
307
+ if not upsert_where_dicts:
308
+ renderer.require_typing("Never")
309
+
273
310
  where_fields: list[WhereFieldSpec] = []
274
- for col in info.columns:
311
+ for col in db_columns:
275
312
  annotation = renderer.render(col.python_type)
276
313
  if "None" not in annotation:
277
314
  annotation = f"{annotation} | None"
@@ -316,11 +353,16 @@ def _build_model_context(
316
353
  auto_increment=column.auto_increment,
317
354
  has_default=column.has_default,
318
355
  has_default_factory=column.has_default_factory,
356
+ returned_field=column.name in model_column_names,
319
357
  mapping_value_expr=_format_mapping_value_expr(column, enum_type_map.get(column.name)),
320
- insert_value_expr=_format_insert_value_expr(column, enum_type_map.get(column.name)),
358
+ insert_value_expr=_format_insert_value_expr(
359
+ column,
360
+ enum_type_map.get(column.name),
361
+ returned_field=column.name in model_column_names,
362
+ ),
321
363
  is_enum=enum_type_map.get(column.name) is not None,
322
364
  )
323
- for column in info.columns
365
+ for column in db_columns
324
366
  ]
325
367
 
326
368
  foreign_keys = [
@@ -347,7 +389,7 @@ def _build_model_context(
347
389
 
348
390
  datasource_values = info.datasource
349
391
  datasource_expr = (
350
- f"DataSourceConfig(provider={datasource_values.provider!r}, url={repr(datasource_values.url)}, name={repr(datasource_values.name)})"
392
+ f"DataSourceConfig(url={repr(datasource_values.url)}, name={repr(datasource_values.name)})"
351
393
  )
352
394
 
353
395
  indexes_literal = _tuple_literal(tuple(tuple(idx) for idx in info.indexes)) if info.indexes else "()"
@@ -355,13 +397,10 @@ def _build_model_context(
355
397
  _tuple_literal(tuple(tuple(idx) for idx in info.unique_indexes)) if info.unique_indexes else "()"
356
398
  )
357
399
 
358
- row_assignments, default_factories = _build_row_assignment_context(info, enum_type_map)
400
+ row_assignments, default_factories = _build_row_assignment_context(info, enum_type_map, renderer)
359
401
  primary_value_types: list[str] = []
360
402
  for column_name in info.primary_key:
361
- column = column_lookup.get(column_name)
362
- if column is None:
363
- primary_value_types.append("object")
364
- continue
403
+ column = column_lookup[column_name]
365
404
  primary_value_types.append(renderer.render(column.python_type))
366
405
 
367
406
  relation_lookup = {relation.name: relation for relation in info.relations}
@@ -403,24 +442,39 @@ def _build_model_context(
403
442
  indexes_literal=indexes_literal,
404
443
  unique_indexes_literal=unique_indexes_literal,
405
444
  primary_value_types=tuple(primary_value_types),
445
+ primary_key_on_model=primary_key_on_model,
406
446
  row_assignments=tuple(row_assignments),
407
447
  default_factories=tuple(default_factories),
408
448
  model_info=info,
409
449
  )
410
450
 
411
451
 
412
- def _build_client_context(model_infos: Mapping[str, ModelInfo]) -> ClientContext:
452
+ def _build_db_columns(info: ModelInfo) -> tuple[ColumnInfo, ...]:
453
+ if info.primary_key == ("id",) and all(column.name != "id" for column in info.columns):
454
+ implicit_id = ColumnInfo(
455
+ name="id",
456
+ python_type=int,
457
+ optional=False,
458
+ auto_increment=True,
459
+ storage_kind="scalar",
460
+ has_default=False,
461
+ default_value=None,
462
+ has_default_factory=False,
463
+ default_factory=None,
464
+ )
465
+ return (implicit_id, *info.columns)
466
+ return tuple(info.columns)
467
+
468
+
469
+ def _build_client_context(model_infos: Mapping[str, ModelInfo], client_class_name: str) -> ClientContext:
413
470
  datasource_configs = {info.datasource for info in model_infos.values()}
414
471
  if len(datasource_configs) != 1:
415
472
  labels = ", ".join(
416
- f"{ds.key}({ds.provider}, {ds.url!r})" for ds in sorted(datasource_configs, key=lambda item: item.key)
473
+ f"{ds.identity}({ds.url!r})" for ds in sorted(datasource_configs, key=lambda item: item.identity)
417
474
  )
418
475
  raise ValueError(f"Generated Client can only use one datasource, got: {labels}")
419
476
  datasource = next(iter(datasource_configs))
420
477
  datasource_item = ClientDataSourceContext(
421
- key=datasource.key,
422
- key_repr=repr(datasource.key),
423
- provider_repr=repr(datasource.provider),
424
478
  url_repr=repr(datasource.url),
425
479
  name_repr=repr(datasource.name),
426
480
  )
@@ -434,13 +488,14 @@ def _build_client_context(model_infos: Mapping[str, ModelInfo]) -> ClientContext
434
488
  ]
435
489
 
436
490
  return ClientContext(
491
+ class_name=client_class_name,
437
492
  datasource=datasource_item,
438
493
  model_bindings=tuple(model_bindings),
439
494
  )
440
495
 
441
496
 
442
- def _collect_exports(model_contexts: Sequence[ModelRenderContext]) -> list[str]:
443
- exports: list[str] = ["DataSourceConfig", "ForeignKeySpec", "Client"]
497
+ def _collect_exports(model_contexts: Sequence[ModelRenderContext], client_class_name: str) -> list[str]:
498
+ exports: list[str] = ["DataSourceConfig", "ForeignKeySpec", client_class_name]
444
499
  for context in model_contexts:
445
500
  name = context.name
446
501
  exports.extend(
@@ -472,6 +527,24 @@ def _render_asdict_stub(model_contexts: Sequence[ModelRenderContext]) -> str:
472
527
  return code
473
528
 
474
529
 
530
+ def _render_init_code(client_class_name: str) -> str:
531
+ code = (
532
+ "from dclassql.asdict import asdict as asdict\n"
533
+ f"from .client import {client_class_name} as {client_class_name}\n\n"
534
+ f"__all__ = ['{client_class_name}', 'asdict']\n"
535
+ )
536
+ return code
537
+
538
+
539
+ def _render_init_stub(client_class_name: str) -> str:
540
+ code = (
541
+ "from .asdict import asdict as asdict\n"
542
+ f"from .client import {client_class_name} as {client_class_name}\n\n"
543
+ "__all__: list[str]\n"
544
+ )
545
+ return code
546
+
547
+
475
548
  def _build_relation_entries(info: ModelInfo, model_infos: Mapping[str, ModelInfo]) -> list[dict[str, Any]]:
476
549
  entries: list[dict[str, Any]] = []
477
550
  if not info.relations:
@@ -488,7 +561,7 @@ def _build_relation_entries(info: ModelInfo, model_infos: Mapping[str, ModelInfo
488
561
  mapping: tuple[tuple[str, str], ...] | None = None
489
562
  if not relation.many:
490
563
  for fk in info.foreign_keys:
491
- if fk.remote_model is target_model:
564
+ if fk.remote_model is target_model and fk.relation_attribute == relation.name:
492
565
  mapping = tuple((local, remote) for local, remote in zip(fk.local_columns, fk.remote_columns))
493
566
  break
494
567
  if mapping is None:
@@ -524,6 +597,7 @@ def _build_relation_entries(info: ModelInfo, model_infos: Mapping[str, ModelInfo
524
597
  def _build_row_assignment_context(
525
598
  info: ModelInfo,
526
599
  enum_type_map: Mapping[str, type[Enum] | None],
600
+ renderer: "_TypeRenderer",
527
601
  ) -> tuple[list[RowAssignmentRender], list[DefaultFactoryRender]]:
528
602
  dataclass_fields = fields(info.model)
529
603
  column_map = {column.name: column for column in info.columns}
@@ -539,6 +613,7 @@ def _build_row_assignment_context(
539
613
  column_map,
540
614
  enum_type_map,
541
615
  relation_defaults,
616
+ renderer,
542
617
  )
543
618
  assignments.append(RowAssignmentRender(field_name=field_obj.name, value_expr=assignment_expr))
544
619
  if default_factory is not None:
@@ -552,12 +627,13 @@ def _resolve_row_assignment(
552
627
  column_map: Mapping[str, ColumnInfo],
553
628
  enum_type_map: Mapping[str, type[Enum] | None],
554
629
  relation_defaults: Mapping[str, str],
630
+ renderer: "_TypeRenderer",
555
631
  ) -> tuple[str, DefaultFactoryRender | None]:
556
632
  name = field_obj.name
557
633
  column_info = column_map.get(name)
558
634
  if column_info is not None:
559
635
  enum_type = enum_type_map.get(name)
560
- return _column_value_expression(column_info, enum_type), None
636
+ return _column_value_expression(column_info, enum_type, renderer), None
561
637
  if field_obj.default is not MISSING:
562
638
  return f"{model_cls.__name__}.__dataclass_fields__[{name!r}].default", None
563
639
  if field_obj.default_factory is not MISSING:
@@ -570,8 +646,14 @@ def _resolve_row_assignment(
570
646
  return _infer_field_fallback(field_obj.type), None
571
647
 
572
648
 
573
- def _column_value_expression(column: ColumnInfo, enum_type: type[Enum] | None) -> str:
649
+ def _column_value_expression(
650
+ column: ColumnInfo,
651
+ enum_type: type[Enum] | None,
652
+ renderer: "_TypeRenderer",
653
+ ) -> str:
574
654
  base_expr = f"row[{column.name!r}]"
655
+ if column.storage_kind == "json":
656
+ return f"deserialize_json_value({base_expr}, {renderer.render(column.python_type)})"
575
657
  if enum_type is None:
576
658
  return base_expr
577
659
  converter = enum_type.__name__
@@ -581,6 +663,8 @@ def _column_value_expression(column: ColumnInfo, enum_type: type[Enum] | None) -
581
663
 
582
664
 
583
665
  def _format_mapping_value_expr(column: ColumnInfo, enum_type: type[Enum] | None) -> str:
666
+ if column.storage_kind == "json":
667
+ return f"serialize_json_value(data[{column.name!r}])"
584
668
  if enum_type is None:
585
669
  return f"data[{column.name!r}]"
586
670
  value_expr = f"data[{column.name!r}]"
@@ -589,7 +673,16 @@ def _format_mapping_value_expr(column: ColumnInfo, enum_type: type[Enum] | None)
589
673
  return f"{value_expr}.value"
590
674
 
591
675
 
592
- def _format_insert_value_expr(column: ColumnInfo, enum_type: type[Enum] | None) -> str:
676
+ def _format_insert_value_expr(
677
+ column: ColumnInfo,
678
+ enum_type: type[Enum] | None,
679
+ *,
680
+ returned_field: bool = True,
681
+ ) -> str:
682
+ if not returned_field:
683
+ return f"getattr(data, {column.name!r}, None)"
684
+ if column.storage_kind == "json":
685
+ return f"serialize_json_value(data.{column.name})"
593
686
  if enum_type is None:
594
687
  return f"data.{column.name}"
595
688
  value_expr = f"data.{column.name}"
@@ -641,12 +734,12 @@ def _format_insert_annotation(col: ColumnInfo, renderer: "_TypeRenderer") -> str
641
734
  return annotation
642
735
 
643
736
 
644
- def _render_default_fragment(model_name: str, col: ColumnInfo) -> str | None:
737
+ def _render_default_fragment(model_cls: type[Any], col: ColumnInfo) -> str | None:
645
738
  if col.has_default_factory and col.default_factory is not None:
646
- factory_expr = f"{model_name}.__dataclass_fields__['{col.name}'].default_factory"
739
+ factory_expr = f"{model_cls.__name__}.__dataclass_fields__['{col.name}'].default_factory"
647
740
  return f"field(default_factory={factory_expr})"
648
741
  if col.has_default:
649
- return repr(col.default_value)
742
+ return f"{model_cls.__name__}.__dataclass_fields__['{col.name}'].default"
650
743
  return None
651
744
 
652
745
 
@@ -875,6 +968,14 @@ class _TypeRenderer:
875
968
  self._typing_imports: set[str] = set()
876
969
 
877
970
  def render(self, tp: Any) -> str:
971
+ alias_value = getattr(tp, "__value__", None)
972
+ if alias_value is not None:
973
+ alias_name = getattr(tp, "__name__", None)
974
+ alias_module = getattr(tp, "__module__", None)
975
+ if isinstance(alias_name, str) and isinstance(alias_module, str):
976
+ self._module_imports[alias_module].add(alias_name)
977
+ return alias_name
978
+ return self.render(alias_value)
878
979
  if tp is Any:
879
980
  return "Any"
880
981
  if tp is type(None):