entari-plugin-database 0.3.2__tar.gz → 0.3.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: entari-plugin-database
3
- Version: 0.3.2
3
+ Version: 0.3.3
4
4
  Summary: Entari plugin for SQLAlchemy ORM
5
5
  Author-Email: RF-Tar-Railt <rf_tar_railt@qq.com>
6
6
  License: MIT
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "entari-plugin-database"
3
- version = "0.3.2"
3
+ version = "0.3.3"
4
4
  description = "Entari plugin for SQLAlchemy ORM"
5
5
  authors = [
6
6
  { name = "RF-Tar-Railt", email = "rf_tar_railt@qq.com" },
@@ -24,7 +24,7 @@ plugin.metadata(
24
24
  "Database 服务",
25
25
  PluginRole.LIBRARY,
26
26
  [{"name": "RF-Tar-Railt", "email": "rf_tar_railt@qq.com"}],
27
- "0.3.2",
27
+ "0.3.3",
28
28
  description="基于 SQLAlchemy 的数据库服务插件",
29
29
  urls={
30
30
  "homepage": "https://github.com/ArcletProject/entari-plugin-database",
@@ -7,7 +7,6 @@ from collections.abc import Iterator, Sequence, AsyncIterator
7
7
  from sqlalchemy import Row, Result, ScalarResult, select
8
8
  from sqlalchemy.ext.asyncio import AsyncResult, AsyncScalarResult
9
9
  from sqlalchemy.sql.selectable import ExecutableReturnsRows
10
- from tarina import generic_issubclass
11
10
  from tarina.generic import origin_is_union, isclass
12
11
  from typing_extensions import Any
13
12
  from typing import cast, get_args, get_origin
@@ -23,6 +22,8 @@ from arclet.letoderea.scope import global_propagators
23
22
  from arclet.entari.plugin.model import Plugin
24
23
  from sqlalchemy.ext import asyncio as sa_async
25
24
 
25
+ from .utils import generic_issubclass
26
+
26
27
 
27
28
  class SessionProvider(Provider[sa_async.AsyncSession]):
28
29
  priority = 10
@@ -204,7 +205,7 @@ def SQLDepends(statement: ExecutableReturnsRows, option: Option = Option(), cach
204
205
 
205
206
 
206
207
  class ORMProviderFactory(ProviderFactory):
207
- priority = 10
208
+ priority = 20
208
209
 
209
210
  class _ModelProvider(Provider[Any]):
210
211
  def __init__(self, statement: ExecutableReturnsRows, option: Option):
@@ -231,8 +232,10 @@ class ORMProviderFactory(ProviderFactory):
231
232
  return result
232
233
 
233
234
  def validate(self, param: Param):
235
+ if param.providers: # skip if already has providers
236
+ return
234
237
  for pattern, option in PATTERNS.items():
235
- if models := cast("list[Any]", generic_issubclass(pattern, param.annotation, list_=True)):
238
+ if models := cast("list[Any]", generic_issubclass(pattern, param.annotation)):
236
239
  break
237
240
  else:
238
241
  models, option = [], Option()
@@ -0,0 +1,85 @@
1
+ import types
2
+ from itertools import repeat
3
+ from collections.abc import Iterable
4
+ from typing import Annotated, Any, Union, Literal, get_args, get_origin
5
+ from typing import Literal as LiteralExt
6
+
7
+
8
+ from arclet.entari import logger as log_m
9
+
10
# Module-level logger obtained from the entari logger's `wrapper` with the
# "[Database]" tag (presumably prefixed to every message it emits — confirm
# against arclet.entari.logger).
logger = log_m.log.wrapper("[Database]")
11
+
12
+
13
+ def origin_is_union(origin: type[Any] | None) -> bool:
14
+ return origin is Union or origin is types.UnionType
15
+
16
+
17
+ def origin_is_literal(origin: type[Any] | None) -> bool:
18
+ """判断是否是 Literal 类型"""
19
+ return origin is Literal or origin is LiteralExt
20
+
21
+
22
def generic_issubclass(scls: Any, cls: Any) -> bool | list[Any]:
    """Structurally compare ``scls`` against ``cls``, collecting ``Any`` matches.

    ``scls`` acts as a pattern: wherever it contains ``Any``, the corresponding
    part of ``cls`` is collected instead of being checked. Generic aliases,
    ``Annotated``, unions, ``Literal`` and fixed/variadic ``tuple`` forms are
    unpacked recursively.

    Args:
        scls: Pattern type, possibly containing ``Any`` placeholders.
        cls: Target type, or a tuple of alternative target types (each
            alternative is tried and the collected matches are concatenated).

    Returns:
        - A non-empty ``list`` of the parts of ``cls`` matched by ``Any``
          placeholders in ``scls``.
        - ``True`` when a plain ``issubclass`` check succeeds, when ``cls`` is
          ``Any`` (and ``scls`` is not), when ``scls``'s literals are a subset
          of ``cls``'s, or when ``cls`` is unparameterised and the origins are
          related.
        - ``False`` otherwise. Nested matches that only produce ``True`` are
          dropped by ``_map_generic_issubclass``, so e.g.
          ``generic_issubclass(list[int], list[int])`` is ``False``; treat the
          result as "list of matched types, or a bool".
    """
    # A tuple of targets means "any of these": try each one and merge matches.
    if isinstance(cls, tuple):
        return _map_generic_issubclass(repeat(scls), cls)

    # ``Any`` in the pattern captures whatever it is compared against.
    if scls is Any:
        return [cls]

    if cls is Any:
        return True

    # Fast path for plain classes; parameterised generics raise TypeError here.
    try:
        return issubclass(scls, cls)
    except TypeError:
        pass

    scls_origin, scls_args = get_origin(scls) or scls, get_args(scls)
    cls_origin, cls_args = get_origin(cls) or cls, get_args(cls)

    if scls_origin is tuple and cls_origin is tuple:
        scls_variadic = len(scls_args) == 2 and scls_args[1] is Ellipsis
        cls_variadic = len(cls_args) == 2 and cls_args[1] is Ellipsis
        if scls_variadic and cls_variadic:
            # tuple[X, ...] vs tuple[Y, ...]: compare element types only, so the
            # Ellipsis markers are never treated as (or collected as) types.
            return generic_issubclass(scls_args[0], cls_args[0])
        if scls_variadic:
            return generic_issubclass(scls_args[0], cls_args)
        if cls_variadic:
            return _map_generic_issubclass(scls_args, repeat(cls_args[0]), failfast=True)

    # Look through ``Annotated[T, ...]`` metadata on either side.
    if scls_origin is Annotated:
        return generic_issubclass(scls_args[0], cls)
    if cls_origin is Annotated:
        return generic_issubclass(scls, cls_args[0])

    # Each member of a pattern union is checked fail-fast; a target union is
    # treated as a tuple of alternatives.
    if origin_is_union(scls_origin):
        return _map_generic_issubclass(scls_args, repeat(cls), failfast=True)
    if origin_is_union(cls_origin):
        return generic_issubclass(scls, cls_args)

    if origin_is_literal(scls_origin) and origin_is_literal(cls_origin):
        # Key by (type, value) so Literal[1] and Literal[True] stay distinct
        # (1 == True would otherwise conflate them), matching typing's own
        # Literal equality.
        return {(type(arg), arg) for arg in scls_args} <= {(type(arg), arg) for arg in cls_args}

    # The pattern's origin must subclass the target's; non-class origins fail.
    try:
        if not issubclass(scls_origin, cls_origin):
            return False
    except TypeError:
        return False

    # An unparameterised target accepts any parameterisation of the pattern.
    if not cls_args:
        return True

    if len(scls_args) != len(cls_args):
        return False

    # Compare type arguments pairwise; any mismatch fails the whole check.
    return _map_generic_issubclass(scls_args, cls_args, failfast=True)
73
+
74
+
75
def _map_generic_issubclass(scls: Iterable[Any], cls: Iterable[Any], *, failfast: bool = False) -> bool | list[Any]:
    """Run ``generic_issubclass`` pairwise over ``zip(scls, cls)`` and merge results.

    Pairing stops at the shorter iterable (``zip`` semantics), which is what
    lets callers pass an endless ``itertools.repeat`` on one side.

    Args:
        scls: Pattern types, consumed in lockstep with ``cls``.
        cls: Target types.
        failfast: When true, the first falsy pairwise result makes the whole
            call return ``False`` immediately.

    Returns:
        The concatenation of every list result plus any non-bool, non-list
        result, or ``False`` when nothing was collected. Pairwise ``True``
        results contribute nothing, so an all-``True`` run also yields ``False``.
    """
    collected: list[Any] = []
    for pattern, target in zip(scls, cls):
        outcome = generic_issubclass(pattern, target)
        if not outcome and failfast:
            return False
        if isinstance(outcome, list):
            collected.extend(outcome)
        elif not isinstance(outcome, bool):
            collected.append(outcome)
    return collected if collected else False
@@ -1,3 +0,0 @@
1
- from arclet.entari import logger as log_m
2
-
3
- logger = log_m.log.wrapper("[Database]")