dclassql 0.4.0__tar.gz → 0.4.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {dclassql-0.4.0 → dclassql-0.4.1}/PKG-INFO +5 -2
  2. {dclassql-0.4.0 → dclassql-0.4.1}/README.md +4 -1
  3. {dclassql-0.4.0 → dclassql-0.4.1}/pyproject.toml +1 -1
  4. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/cli.py +44 -23
  5. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/codegen.py +97 -28
  6. dclassql-0.4.1/src/dclassql/db_pool.py +89 -0
  7. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/model_inspector.py +98 -25
  8. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/push/__init__.py +14 -21
  9. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/push/base.py +20 -0
  10. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/push/sqlite.py +36 -1
  11. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/runtime/backends/base.py +3 -0
  12. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/runtime/backends/lazy.py +21 -5
  13. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/runtime/datasource.py +2 -5
  14. dclassql-0.4.1/src/dclassql/runtime/json_value.py +104 -0
  15. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/templates/partials/client_class.jinja +13 -14
  16. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/templates/partials/imports.jinja +3 -2
  17. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/templates/partials/model_section.jinja +18 -4
  18. dclassql-0.4.0/src/dclassql/db_pool.py +0 -76
  19. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/.gitignore +0 -0
  20. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/__init__.py +0 -0
  21. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/asdict.py +0 -0
  22. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/runtime/backends/__init__.py +0 -0
  23. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/runtime/backends/metadata.py +0 -0
  24. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/runtime/backends/protocols.py +0 -0
  25. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/runtime/backends/sqlite.py +0 -0
  26. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/runtime/backends/where_compiler.py +0 -0
  27. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/runtime/sql_recorder.py +0 -0
  28. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/runtime/sqlite_adapters.py +0 -0
  29. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/table_spec.py +0 -0
  30. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/templates/__init__.py +0 -0
  31. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/templates/asdict_stub.pyi.jinja +0 -0
  32. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/templates/client_module.py.jinja +0 -0
  33. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/templates/partials/exports.jinja +0 -0
  34. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/templates/partials/macros.jinja +0 -0
  35. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/templates/partials/scalar_filters.jinja +0 -0
  36. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/typing.py +0 -0
  37. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/unwarp.py +0 -0
  38. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/utils/__init__.py +0 -0
  39. {dclassql-0.4.0 → dclassql-0.4.1}/src/dclassql/utils/ensure.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: dclassql
3
- Version: 0.4.0
3
+ Version: 0.4.1
4
4
  Summary: A type-safe ORM generator for Python, creating fully type-hinted database clients from plain dataclass definitions.
5
5
  Keywords: orm,codegen,sqlite,dataclass,typed
6
6
  Author: myuanz
@@ -62,11 +62,14 @@ class User:
62
62
  写出如下代码时:
63
63
 
64
64
  ```python
65
- from dclassql import client
65
+ from dclassql import Client
66
+
67
+ client = Client()
66
68
 
67
69
  client.user.insert({
68
70
  "name": "Alice",
69
71
  "email": "test@example.com",
72
+ # 这里缺少 last_login
70
73
  })
71
74
  ```
72
75
 
@@ -41,11 +41,14 @@ class User:
41
41
  写出如下代码时:
42
42
 
43
43
  ```python
44
- from dclassql import client
44
+ from dclassql import Client
45
+
46
+ client = Client()
45
47
 
46
48
  client.user.insert({
47
49
  "name": "Alice",
48
50
  "email": "test@example.com",
51
+ # 这里缺少 last_login
49
52
  })
50
53
  ```
51
54
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "dclassql"
3
- version = "0.4.0"
3
+ version = "0.4.1"
4
4
  description = "A type-safe ORM generator for Python, creating fully type-hinted database clients from plain dataclass definitions."
5
5
  readme = "README.md"
6
6
  authors = [
@@ -52,11 +52,17 @@ def load_module(module_path: Path) -> ModuleType:
52
52
  raise ImportError(f"Unable to load module from '{module_path}'")
53
53
  module = importlib.util.module_from_spec(spec)
54
54
  sys.modules[module_name] = module
55
- sys.path.insert(0, str(module_path.parent))
55
+ original_sys_path = list(sys.path)
56
+ search_paths = [str(module_path.parent)]
57
+ cwd = str(Path.cwd())
58
+ if cwd not in search_paths:
59
+ search_paths.append(cwd)
60
+ for path in reversed(search_paths):
61
+ sys.path.insert(0, path)
56
62
  try:
57
63
  spec.loader.exec_module(module)
58
64
  finally:
59
- sys.path.pop(0)
65
+ sys.path[:] = original_sys_path
60
66
  return module
61
67
 
62
68
 
@@ -87,18 +93,41 @@ def resolve_generated_package_dir(module_path: Path, target: GenerateTarget = "m
87
93
  def collect_models(module: ModuleType) -> list[type[Any]]:
88
94
  from dataclasses import is_dataclass
89
95
 
96
+ excluded_names = _collect_excluded_model_names(module)
90
97
  models: list[type[Any]] = []
91
98
  for value in vars(module).values():
92
- if isinstance(value, type) and is_dataclass(value) and value.__module__ == module.__name__:
99
+ if (
100
+ isinstance(value, type)
101
+ and is_dataclass(value)
102
+ and value.__module__ == module.__name__
103
+ and value.__name__ not in excluded_names
104
+ ):
93
105
  models.append(value)
94
106
  if not models:
95
107
  raise ValueError("No dataclass models were found in the provided module")
96
108
  return models
97
109
 
98
110
 
111
+ def _collect_excluded_model_names(module: ModuleType) -> set[str]:
112
+ raw = getattr(module, "__exclude__", ())
113
+ if raw is None:
114
+ return set()
115
+ if isinstance(raw, str):
116
+ return {raw}
117
+ names: set[str] = set()
118
+ for item in raw:
119
+ if isinstance(item, str):
120
+ names.add(item)
121
+ continue
122
+ if isinstance(item, type):
123
+ names.add(item.__name__)
124
+ continue
125
+ raise TypeError("__exclude__ entries must be dataclass classes or class names")
126
+ return names
127
+
128
+
99
129
  def _describe_schema_diff(info: ModelInfo, diff: SchemaDiff) -> str:
100
- datasource_key = getattr(info.datasource, "key", None)
101
- prefix = f"[{datasource_key}] " if datasource_key else ""
130
+ prefix = f"[{info.datasource.identity}] "
102
131
  parts: list[str] = [f"{prefix}模型 {info.model.__name__} 需要重建表"]
103
132
  if diff.added:
104
133
  added = ", ".join(f"+{column.name}:{column.type_sql}" for column in diff.added)
@@ -144,32 +173,24 @@ def push_database(
144
173
  confirm_mode: ConfirmRebuildMode | None = None,
145
174
  ) -> None:
146
175
  model_infos = inspect_models(models)
147
- connections: dict[str, Any] = {}
148
- opened: list[Any] = []
176
+ datasource_configs = {info.datasource for info in model_infos.values()}
177
+ if len(datasource_configs) != 1:
178
+ labels = ", ".join(sorted(config.identity for config in datasource_configs))
179
+ raise ValueError(f"push-db only supports one datasource, got: {labels}")
180
+ datasource = next(iter(datasource_configs))
181
+ if datasource.provider != "sqlite":
182
+ raise ValueError(f"Unsupported provider '{datasource.provider}'")
149
183
  confirm_callback = _build_confirm_callback(confirm_mode) if confirm_mode else None
184
+ connection = open_sqlite_connection(datasource.url)
150
185
  try:
151
- for info in model_infos.values():
152
- config = info.datasource
153
- key = config.key
154
- if key in connections:
155
- continue
156
- if config.provider != "sqlite":
157
- raise ValueError(f"Unsupported provider '{config.provider}'")
158
- connection = open_sqlite_connection(config.url)
159
- connections[key] = connection
160
- opened.append(connection)
161
186
  db_push(
162
187
  models,
163
- connections,
188
+ connection,
164
189
  sync_indexes=sync_indexes,
165
190
  confirm_rebuild=confirm_callback,
166
191
  )
167
192
  finally:
168
- for conn in opened:
169
- try:
170
- conn.close()
171
- except Exception:
172
- pass
193
+ connection.close()
173
194
 
174
195
 
175
196
  def command_generate(module_path: Path, *, target: GenerateTarget = "model-dir") -> None:
@@ -51,14 +51,34 @@ class WhereFieldSpec:
51
51
  @dataclass(slots=True)
52
52
  class ColumnSpecRender:
53
53
  name: str
54
+ '''数据库列名'''
55
+
54
56
  name_repr: str
57
+ '''列名的 Python 字符串字面量形式, 用于生成代码里的 dict key.'''
58
+
55
59
  optional: bool
60
+ '''插入/更新时是否允许 None 或缺省, 来自 Optional/default/factory 判断.'''
61
+
56
62
  auto_increment: bool
63
+ '''是否自增,给主键用的'''
64
+
57
65
  has_default: bool
66
+ '''原 dataclass 字段是否有 default'''
67
+
58
68
  has_default_factory: bool
69
+ '''原 dataclass 字段是否有 default_factory'''
70
+
71
+ returned_field: bool
72
+ '''是否会进入返回的原 dataclass 对象. 隐式 id 为 False.'''
73
+
59
74
  mapping_value_expr: str
75
+ '''Mapping payload 转数据库值的生成表达式. 例如 `data['open_order_id']`.'''
76
+
60
77
  insert_value_expr: str
78
+ '''Insert dataclass 或原模型实例转数据库值的生成表达式. 隐式 id 用 getattr 默认 None.'''
79
+
61
80
  is_enum: bool
81
+ '''是否是 Enum 列'''
62
82
 
63
83
 
64
84
  @dataclass(slots=True)
@@ -128,6 +148,7 @@ class ModelRenderContext:
128
148
  indexes_literal: str
129
149
  unique_indexes_literal: str
130
150
  primary_value_types: tuple[str, ...]
151
+ primary_key_on_model: bool
131
152
  row_assignments: tuple[RowAssignmentRender, ...]
132
153
  default_factories: tuple[DefaultFactoryRender, ...]
133
154
  model_info: ModelInfo
@@ -135,9 +156,6 @@ class ModelRenderContext:
135
156
 
136
157
  @dataclass(slots=True)
137
158
  class ClientDataSourceContext:
138
- key: str
139
- key_repr: str
140
- provider_repr: str
141
159
  url_repr: str
142
160
  name_repr: str
143
161
 
@@ -237,10 +255,13 @@ def _build_model_context(
237
255
  upsert_where_dicts: list[UpsertWhereRender] = []
238
256
  dict_field_map: dict[str, str] = {}
239
257
  enum_type_map: dict[str, type[Enum] | None] = {}
240
- column_lookup: dict[str, ColumnInfo] = {col.name: col for col in info.columns}
241
- for col in info.columns:
258
+ model_column_names = {col.name for col in info.columns}
259
+ db_columns = _build_db_columns(info)
260
+ column_lookup: dict[str, ColumnInfo] = {col.name: col for col in db_columns}
261
+ primary_key_on_model = all(column_name in model_column_names for column_name in info.primary_key)
262
+ for col in db_columns:
242
263
  annotation = _format_insert_annotation(col, renderer)
243
- default_fragment = _render_default_fragment(name, col)
264
+ default_fragment = _render_default_fragment(info.model, col)
244
265
  if default_fragment is not None:
245
266
  default_expr = default_fragment
246
267
  elif col.auto_increment:
@@ -267,8 +288,8 @@ def _build_model_context(
267
288
  if info.primary_key:
268
289
  pk_fields: list[TypedDictFieldSpec] = []
269
290
  for pk_col in info.primary_key:
270
- col_info = column_lookup.get(pk_col)
271
- annotation = renderer.render(col_info.python_type) if col_info else "object"
291
+ col_info = column_lookup[pk_col]
292
+ annotation = renderer.render(col_info.python_type)
272
293
  pk_fields.append(TypedDictFieldSpec(name=pk_col, annotation=annotation))
273
294
  upsert_where_dicts.append(UpsertWhereRender(name=f"{name}UpsertWherePK", fields=tuple(pk_fields)))
274
295
 
@@ -283,8 +304,11 @@ def _build_model_context(
283
304
  UpsertWhereRender(name=f"{name}UpsertWhereUnique{idx}", fields=tuple(unique_fields))
284
305
  )
285
306
 
307
+ if not upsert_where_dicts:
308
+ renderer.require_typing("Never")
309
+
286
310
  where_fields: list[WhereFieldSpec] = []
287
- for col in info.columns:
311
+ for col in db_columns:
288
312
  annotation = renderer.render(col.python_type)
289
313
  if "None" not in annotation:
290
314
  annotation = f"{annotation} | None"
@@ -329,11 +353,16 @@ def _build_model_context(
329
353
  auto_increment=column.auto_increment,
330
354
  has_default=column.has_default,
331
355
  has_default_factory=column.has_default_factory,
356
+ returned_field=column.name in model_column_names,
332
357
  mapping_value_expr=_format_mapping_value_expr(column, enum_type_map.get(column.name)),
333
- insert_value_expr=_format_insert_value_expr(column, enum_type_map.get(column.name)),
358
+ insert_value_expr=_format_insert_value_expr(
359
+ column,
360
+ enum_type_map.get(column.name),
361
+ returned_field=column.name in model_column_names,
362
+ ),
334
363
  is_enum=enum_type_map.get(column.name) is not None,
335
364
  )
336
- for column in info.columns
365
+ for column in db_columns
337
366
  ]
338
367
 
339
368
  foreign_keys = [
@@ -360,7 +389,7 @@ def _build_model_context(
360
389
 
361
390
  datasource_values = info.datasource
362
391
  datasource_expr = (
363
- f"DataSourceConfig(provider={datasource_values.provider!r}, url={repr(datasource_values.url)}, name={repr(datasource_values.name)})"
392
+ f"DataSourceConfig(url={repr(datasource_values.url)}, name={repr(datasource_values.name)})"
364
393
  )
365
394
 
366
395
  indexes_literal = _tuple_literal(tuple(tuple(idx) for idx in info.indexes)) if info.indexes else "()"
@@ -368,13 +397,10 @@ def _build_model_context(
368
397
  _tuple_literal(tuple(tuple(idx) for idx in info.unique_indexes)) if info.unique_indexes else "()"
369
398
  )
370
399
 
371
- row_assignments, default_factories = _build_row_assignment_context(info, enum_type_map)
400
+ row_assignments, default_factories = _build_row_assignment_context(info, enum_type_map, renderer)
372
401
  primary_value_types: list[str] = []
373
402
  for column_name in info.primary_key:
374
- column = column_lookup.get(column_name)
375
- if column is None:
376
- primary_value_types.append("object")
377
- continue
403
+ column = column_lookup[column_name]
378
404
  primary_value_types.append(renderer.render(column.python_type))
379
405
 
380
406
  relation_lookup = {relation.name: relation for relation in info.relations}
@@ -416,24 +442,39 @@ def _build_model_context(
416
442
  indexes_literal=indexes_literal,
417
443
  unique_indexes_literal=unique_indexes_literal,
418
444
  primary_value_types=tuple(primary_value_types),
445
+ primary_key_on_model=primary_key_on_model,
419
446
  row_assignments=tuple(row_assignments),
420
447
  default_factories=tuple(default_factories),
421
448
  model_info=info,
422
449
  )
423
450
 
424
451
 
452
def _build_db_columns(info: ModelInfo) -> tuple[ColumnInfo, ...]:
    '''Return the columns that are physically stored in the model's table.

    If the primary key is exactly ``("id",)`` and the dataclass declares no
    ``id`` field, an implicit auto-increment ``int`` ``id`` column is put in
    front of the declared columns. Otherwise the declared columns are returned
    as a tuple, unchanged.
    '''
    needs_implicit_id = info.primary_key == ("id",) and not any(
        column.name == "id" for column in info.columns
    )
    if not needs_implicit_id:
        return tuple(info.columns)

    implicit_id = ColumnInfo(
        name="id",
        python_type=int,
        storage_kind="scalar",
        optional=False,
        auto_increment=True,
        has_default=False,
        default_value=None,
        has_default_factory=False,
        default_factory=None,
    )
    return (implicit_id, *info.columns)
467
+
468
+
425
469
  def _build_client_context(model_infos: Mapping[str, ModelInfo], client_class_name: str) -> ClientContext:
426
470
  datasource_configs = {info.datasource for info in model_infos.values()}
427
471
  if len(datasource_configs) != 1:
428
472
  labels = ", ".join(
429
- f"{ds.key}({ds.provider}, {ds.url!r})" for ds in sorted(datasource_configs, key=lambda item: item.key)
473
+ f"{ds.identity}({ds.url!r})" for ds in sorted(datasource_configs, key=lambda item: item.identity)
430
474
  )
431
475
  raise ValueError(f"Generated Client can only use one datasource, got: {labels}")
432
476
  datasource = next(iter(datasource_configs))
433
477
  datasource_item = ClientDataSourceContext(
434
- key=datasource.key,
435
- key_repr=repr(datasource.key),
436
- provider_repr=repr(datasource.provider),
437
478
  url_repr=repr(datasource.url),
438
479
  name_repr=repr(datasource.name),
439
480
  )
@@ -520,7 +561,7 @@ def _build_relation_entries(info: ModelInfo, model_infos: Mapping[str, ModelInfo
520
561
  mapping: tuple[tuple[str, str], ...] | None = None
521
562
  if not relation.many:
522
563
  for fk in info.foreign_keys:
523
- if fk.remote_model is target_model:
564
+ if fk.remote_model is target_model and fk.relation_attribute == relation.name:
524
565
  mapping = tuple((local, remote) for local, remote in zip(fk.local_columns, fk.remote_columns))
525
566
  break
526
567
  if mapping is None:
@@ -556,6 +597,7 @@ def _build_relation_entries(info: ModelInfo, model_infos: Mapping[str, ModelInfo
556
597
  def _build_row_assignment_context(
557
598
  info: ModelInfo,
558
599
  enum_type_map: Mapping[str, type[Enum] | None],
600
+ renderer: "_TypeRenderer",
559
601
  ) -> tuple[list[RowAssignmentRender], list[DefaultFactoryRender]]:
560
602
  dataclass_fields = fields(info.model)
561
603
  column_map = {column.name: column for column in info.columns}
@@ -571,6 +613,7 @@ def _build_row_assignment_context(
571
613
  column_map,
572
614
  enum_type_map,
573
615
  relation_defaults,
616
+ renderer,
574
617
  )
575
618
  assignments.append(RowAssignmentRender(field_name=field_obj.name, value_expr=assignment_expr))
576
619
  if default_factory is not None:
@@ -584,12 +627,13 @@ def _resolve_row_assignment(
584
627
  column_map: Mapping[str, ColumnInfo],
585
628
  enum_type_map: Mapping[str, type[Enum] | None],
586
629
  relation_defaults: Mapping[str, str],
630
+ renderer: "_TypeRenderer",
587
631
  ) -> tuple[str, DefaultFactoryRender | None]:
588
632
  name = field_obj.name
589
633
  column_info = column_map.get(name)
590
634
  if column_info is not None:
591
635
  enum_type = enum_type_map.get(name)
592
- return _column_value_expression(column_info, enum_type), None
636
+ return _column_value_expression(column_info, enum_type, renderer), None
593
637
  if field_obj.default is not MISSING:
594
638
  return f"{model_cls.__name__}.__dataclass_fields__[{name!r}].default", None
595
639
  if field_obj.default_factory is not MISSING:
@@ -602,8 +646,14 @@ def _resolve_row_assignment(
602
646
  return _infer_field_fallback(field_obj.type), None
603
647
 
604
648
 
605
- def _column_value_expression(column: ColumnInfo, enum_type: type[Enum] | None) -> str:
649
+ def _column_value_expression(
650
+ column: ColumnInfo,
651
+ enum_type: type[Enum] | None,
652
+ renderer: "_TypeRenderer",
653
+ ) -> str:
606
654
  base_expr = f"row[{column.name!r}]"
655
+ if column.storage_kind == "json":
656
+ return f"deserialize_json_value({base_expr}, {renderer.render(column.python_type)})"
607
657
  if enum_type is None:
608
658
  return base_expr
609
659
  converter = enum_type.__name__
@@ -613,6 +663,8 @@ def _column_value_expression(column: ColumnInfo, enum_type: type[Enum] | None) -
613
663
 
614
664
 
615
665
  def _format_mapping_value_expr(column: ColumnInfo, enum_type: type[Enum] | None) -> str:
666
+ if column.storage_kind == "json":
667
+ return f"serialize_json_value(data[{column.name!r}])"
616
668
  if enum_type is None:
617
669
  return f"data[{column.name!r}]"
618
670
  value_expr = f"data[{column.name!r}]"
@@ -621,7 +673,16 @@ def _format_mapping_value_expr(column: ColumnInfo, enum_type: type[Enum] | None)
621
673
  return f"{value_expr}.value"
622
674
 
623
675
 
624
- def _format_insert_value_expr(column: ColumnInfo, enum_type: type[Enum] | None) -> str:
676
+ def _format_insert_value_expr(
677
+ column: ColumnInfo,
678
+ enum_type: type[Enum] | None,
679
+ *,
680
+ returned_field: bool = True,
681
+ ) -> str:
682
+ if not returned_field:
683
+ return f"getattr(data, {column.name!r}, None)"
684
+ if column.storage_kind == "json":
685
+ return f"serialize_json_value(data.{column.name})"
625
686
  if enum_type is None:
626
687
  return f"data.{column.name}"
627
688
  value_expr = f"data.{column.name}"
@@ -673,12 +734,12 @@ def _format_insert_annotation(col: ColumnInfo, renderer: "_TypeRenderer") -> str
673
734
  return annotation
674
735
 
675
736
 
676
- def _render_default_fragment(model_name: str, col: ColumnInfo) -> str | None:
737
+ def _render_default_fragment(model_cls: type[Any], col: ColumnInfo) -> str | None:
677
738
  if col.has_default_factory and col.default_factory is not None:
678
- factory_expr = f"{model_name}.__dataclass_fields__['{col.name}'].default_factory"
739
+ factory_expr = f"{model_cls.__name__}.__dataclass_fields__['{col.name}'].default_factory"
679
740
  return f"field(default_factory={factory_expr})"
680
741
  if col.has_default:
681
- return repr(col.default_value)
742
+ return f"{model_cls.__name__}.__dataclass_fields__['{col.name}'].default"
682
743
  return None
683
744
 
684
745
 
@@ -907,6 +968,14 @@ class _TypeRenderer:
907
968
  self._typing_imports: set[str] = set()
908
969
 
909
970
  def render(self, tp: Any) -> str:
971
+ alias_value = getattr(tp, "__value__", None)
972
+ if alias_value is not None:
973
+ alias_name = getattr(tp, "__name__", None)
974
+ alias_module = getattr(tp, "__module__", None)
975
+ if isinstance(alias_name, str) and isinstance(alias_module, str):
976
+ self._module_imports[alias_module].add(alias_name)
977
+ return alias_name
978
+ return self.render(alias_value)
910
979
  if tp is Any:
911
980
  return "Any"
912
981
  if tp is type(None):
@@ -0,0 +1,89 @@
1
+ import functools
2
+ import sqlite3
3
+ import threading
4
+ from typing import Any, Callable, Concatenate, Protocol
5
+
6
+
7
class HasLocalClass(Protocol):
    '''Structural type for objects that expose a ``threading.local`` as ``_local``.

    ``save_local`` keeps its per-thread cache on this attribute.
    '''

    _local: threading.local
9
+
10
+
11
+ def save_local[C: HasLocalClass, **P, T](
12
+ func: Callable[Concatenate[C, P], T] | None = None,
13
+ *,
14
+ key: Callable[[Any, Callable[..., object]], object] | None = None,
15
+ ) -> Callable[[Callable[Concatenate[C, P], T]], Callable[Concatenate[C, P], T]] | Callable[Concatenate[C, P], T]:
16
+ def decorator(func: Callable[Concatenate[C, P], T]) -> Callable[Concatenate[C, P], T]:
17
+ @functools.wraps(func)
18
+ def wrapper(self: C, *args: P.args, **kwargs: P.kwargs) -> T:
19
+ cache = getattr(self._local, "_dclassql_cache", None)
20
+ if cache is None:
21
+ cache = {}
22
+ self._local._dclassql_cache = cache
23
+
24
+ cache_key = key(self, func) if key is not None else func.__name__
25
+ if cache_key in cache:
26
+ return cache[cache_key]
27
+
28
+ value = func(self, *args, **kwargs)
29
+ cache[cache_key] = value
30
+ return value
31
+
32
+ return wrapper
33
+
34
+ if func is not None:
35
+ return decorator(func)
36
+ return decorator
37
+
38
+
39
class BaseDBPool:
    ''' Thread-level database pool base class. Methods decorated with `@save_local` are cached in `threading.local()`. Usage example:
    ```python
    class ExampleDBPool(BaseDBPool):
        sqlite_db_path = 'data/news.db'
        visitor_sqlite_db_path = 'data/visitors.db'

        @save_local
        def sqlite_conn(self) -> sqlite3.Connection:
            conn = sqlite3.connect(self.sqlite_db_path, check_same_thread=False)
            self._setup_sqlite_db(conn)
            return conn

        @save_local
        def fastlite_conn(self):
            from fastlite import database
            fastlite_db = database(self.sqlite_db_path)
            return fastlite_db

        @save_local
        def fastlite_conn_visitor(self):
            from fastlite import database
            fastlite_db_visitor = database(self.visitor_sqlite_db_path)
            self._setup_sqlite_db(fastlite_db_visitor.conn)
            return fastlite_db_visitor
    ```

    Note: ``_local`` is a single class attribute, shared by ``BaseDBPool`` and
    every subclass. All pool classes therefore keep their cached objects in the
    same per-thread dict, and ``close_all`` only sees the calling thread's objects.
    '''

    _local = threading.local()

    @classmethod
    def close_all(cls, verbose: bool = False):
        '''Close and evict every closable object cached for the current thread.

        Entries without a callable ``close`` attribute are left in the cache.
        Every closable entry is closed and removed, even if an earlier ``close()``
        raised, so one broken handle cannot leak the rest. The first such
        exception is re-raised once all entries have been processed.

        Args:
            verbose: Print ``Closing <key>`` before each close.
        '''
        cache = getattr(cls._local, "_dclassql_cache", None)
        if cache is None:
            return
        first_error: Exception | None = None
        # Iterate over a snapshot because entries are deleted while looping.
        for key, obj in list(cache.items()):
            closer = getattr(obj, 'close', None)
            if not callable(closer):
                continue
            if verbose:
                print(f'Closing {key!r}')
            try:
                closer()
            except Exception as exc:
                if first_error is None:
                    first_error = exc
            # Evict even on failure: a handle whose close() raised is not reusable.
            del cache[key]
        if first_error is not None:
            raise first_error

    @classmethod
    def _setup_sqlite_db(cls, conn: sqlite3.Connection):
        '''Apply the pool's standard PRAGMAs to a freshly opened SQLite connection.

        The order matters:

        * ``busy_timeout`` comes first, so the later PRAGMAs (notably the WAL
          switch, which needs a lock) wait instead of failing with SQLITE_BUSY.
        * ``page_size`` must come before ``journal_mode = WAL``. Switching to WAL
          initialises a new database file, and SQLite cannot change the page
          size once a database is in WAL mode. In the old order the 32768 page
          size was silently ignored.
        '''
        conn.execute("PRAGMA busy_timeout = 3000;")
        conn.execute('pragma page_size = 32768;')
        conn.execute('PRAGMA journal_mode = WAL;')
        conn.execute('PRAGMA synchronous = NORMAL;')
        conn.execute('pragma temp_store = memory;')
        conn.execute('PRAGMA journal_size_limit=104857600;')