entari-plugin-database 0.2.4__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: entari-plugin-database
3
- Version: 0.2.4
3
+ Version: 0.3.0
4
4
  Summary: Entari plugin for SQLAlchemy ORM
5
5
  Author-Email: RF-Tar-Railt <rf_tar_railt@qq.com>
6
6
  License: MIT
@@ -11,7 +11,7 @@ Requires-Dist: graia-amnesia>=0.11.4
11
11
  Requires-Dist: tarina<0.8.0,>=0.7.1
12
12
  Requires-Dist: arclet-letoderea>=0.19.5
13
13
  Requires-Dist: alembic>=1.16.5
14
- Requires-Dist: arclet-entari<0.18.0,>=0.17.0
14
+ Requires-Dist: arclet-entari<0.19.0,>=0.18.0rc1
15
15
  Description-Content-Type: text/markdown
16
16
 
17
17
  # entari-plugin-database
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "entari-plugin-database"
3
- version = "0.2.4"
3
+ version = "0.3.0"
4
4
  description = "Entari plugin for SQLAlchemy ORM"
5
5
  authors = [
6
6
  { name = "RF-Tar-Railt", email = "rf_tar_railt@qq.com" },
@@ -12,7 +12,7 @@ dependencies = [
12
12
  "tarina<0.8.0,>=0.7.1",
13
13
  "arclet-letoderea>=0.19.5",
14
14
  "alembic>=1.16.5",
15
- "arclet-entari<0.18.0,>=0.17.0",
15
+ "arclet-entari<0.19.0,>=0.18.0rc1",
16
16
  ]
17
17
  requires-python = ">=3.10"
18
18
  readme = "README.md"
@@ -29,11 +29,6 @@ build-backend = "pdm.backend"
29
29
  [tool.pdm]
30
30
  distribution = true
31
31
 
32
- [tool.pdm.dev-dependencies]
33
- dev = [
34
- "arclet-entari[full]>=0.17.0",
35
- ]
36
-
37
32
  [tool.ruff]
38
33
  line-length = 120
39
34
  target-version = "py310"
@@ -61,3 +56,13 @@ ignore = [
61
56
  "PYI055",
62
57
  "UP038",
63
58
  ]
59
+
60
+ [tool.pyright]
61
+ pythonVersion = "3.10"
62
+ pythonPlatform = "All"
63
+ typeCheckingMode = "basic"
64
+
65
+ [dependency-groups]
66
+ dev = [
67
+ "arclet-entari[full]>=0.17.0",
68
+ ]
@@ -1,8 +1,6 @@
1
1
  from sqlalchemy.ext.asyncio import create_async_engine
2
- from arclet.letoderea.provider import global_providers
3
- from arclet.letoderea.scope import global_propagators
4
- from arclet.letoderea.core import add_task
5
- from arclet.entari import plugin
2
+ from arclet.letoderea.utils import add_task
3
+ from arclet.entari import Plugin, plugin
6
4
  from arclet.entari.config import config_model_validate
7
5
  from arclet.entari.event.config import ConfigReload
8
6
  from graia.amnesia.builtins.sqla import SqlalchemyService
@@ -14,7 +12,6 @@ from sqlalchemy.ext import asyncio as sa_async
14
12
  from sqlalchemy.orm import Mapped as Mapped, instrumentation
15
13
  from sqlalchemy.orm import mapped_column as mapped_column
16
14
 
17
- from .param import db_supplier, sess_provider, orm_factory
18
15
  from .param import SQLDepends as SQLDepends
19
16
  from .utils import logger
20
17
  from .migration import run_migration, register_custom_migration
@@ -25,7 +22,7 @@ plugin.declare_static()
25
22
  plugin.metadata(
26
23
  "Database 服务",
27
24
  [{"name": "RF-Tar-Railt", "email": "rf_tar_railt@qq.com"}],
28
- "0.2.4",
25
+ "0.3.0",
29
26
  description="基于 SQLAlchemy 的数据库服务插件",
30
27
  urls={
31
28
  "homepage": "https://github.com/ArcletProject/entari-plugin-database",
@@ -33,11 +30,6 @@ plugin.metadata(
33
30
  config=Config,
34
31
  readme="README.md",
35
32
  )
36
- plugin.collect_disposes(
37
- lambda: global_propagators.remove(db_supplier),
38
- lambda: global_providers.remove(sess_provider),
39
- lambda: global_providers.remove(orm_factory),
40
- )
41
33
 
42
34
  _config = plugin.get_config(Config)
43
35
 
@@ -133,12 +125,17 @@ def migration_callback(cls: type[Base], kwargs: dict):
133
125
  task.add_done_callback(_PENDING_TASKS.discard)
134
126
 
135
127
 
136
- register_callback(_setup_tablename)
137
- register_callback(_clean_exist)
138
- register_callback(migration_callback, after=True)
139
- plugin.collect_disposes(lambda: remove_callback(_clean_exist))
140
- plugin.collect_disposes(lambda: remove_callback(_setup_tablename))
141
- plugin.collect_disposes(lambda: remove_callback(migration_callback))
128
+ def _register_callbacks():
129
+ register_callback(_setup_tablename)
130
+ register_callback(_clean_exist)
131
+ register_callback(migration_callback, after=True)
132
+ yield lambda: remove_callback(_clean_exist)
133
+ yield lambda: remove_callback(_setup_tablename)
134
+ yield lambda: remove_callback(migration_callback)
135
+
136
+
137
+ plg = Plugin.current()
138
+ plg.effect(_register_callbacks, "database model callbacks")
142
139
 
143
140
 
144
141
  BaseOrm = Base
@@ -20,22 +20,44 @@ from arclet.letoderea import Propagator, Contexts, STACK, Provider, ProviderFact
20
20
  from arclet.letoderea.ref import Deref, generate
21
21
  from arclet.letoderea.provider import global_providers
22
22
  from arclet.letoderea.scope import global_propagators
23
+ from arclet.entari.plugin.model import Plugin
23
24
  from sqlalchemy.ext import asyncio as sa_async
24
25
 
25
26
 
27
+ class SessionProvider(Provider[sa_async.AsyncSession]):
28
+ priority = 10
29
+
30
+ async def __call__(self, context: Contexts):
31
+ if "$db_session" in context:
32
+ return context["$db_session"]
33
+ try:
34
+ db = it(Launart).get_component(SqlalchemyService)
35
+ stack = context[STACK]
36
+ sess = await stack.enter_async_context(db.get_session())
37
+ context["$db_session"] = sess
38
+ return sess
39
+ except ValueError:
40
+ return
41
+
42
+
26
43
  class DatabasePropagator(Propagator):
27
44
  def validate(self, subscriber: Subscriber):
28
45
  params = subscriber.params
29
46
  if any((p.depend and isinstance(p.depend, SQLDepend)) for p in params):
30
47
  return True
48
+ if any(isinstance(p.annotation, sa_async.AsyncSession) for p in params):
49
+ return True
31
50
  if any(
32
- isinstance(prod, (SessionProvider, ORMProviderFactory._ModelProvider))
51
+ isinstance(prod, ORMProviderFactory._ModelProvider)
33
52
  for p in params
34
53
  for prod in p.providers
35
54
  ): # noqa: E501,UP038
36
55
  return True
37
56
  return False
38
57
 
58
+ def providers(self):
59
+ return [SessionProvider()]
60
+
39
61
  async def supply(self, ctx: Contexts, serv: SqlalchemyService | None = None):
40
62
  if serv is None:
41
63
  return
@@ -48,22 +70,6 @@ class DatabasePropagator(Propagator):
48
70
  yield self.supply, True, 20
49
71
 
50
72
 
51
- class SessionProvider(Provider[sa_async.AsyncSession]):
52
- priority = 10
53
-
54
- async def __call__(self, context: Contexts):
55
- if "$db_session" in context:
56
- return context["$db_session"]
57
- try:
58
- db = it(Launart).get_component(SqlalchemyService)
59
- stack = context[STACK]
60
- sess = await stack.enter_async_context(db.get_session())
61
- context["$db_session"] = sess
62
- return sess
63
- except ValueError:
64
- return
65
-
66
-
67
73
  @dataclass(unsafe_hash=True)
68
74
  class Option:
69
75
  stream: bool = True
@@ -254,6 +260,14 @@ class ORMProviderFactory(ProviderFactory):
254
260
  return self._ModelProvider(statement, option)
255
261
 
256
262
 
257
- global_propagators.append(db_supplier := DatabasePropagator())
258
- global_providers.append(sess_provider := SessionProvider())
259
- global_providers.append(orm_factory := ORMProviderFactory())
263
+ plg = Plugin.current()
264
+
265
+
266
+ def _provides():
267
+ global_propagators.append(db_supplier := DatabasePropagator())
268
+ global_providers.append(orm_factory := ORMProviderFactory())
269
+ yield lambda: global_propagators.remove(db_supplier)
270
+ yield lambda: global_providers.remove(orm_factory)
271
+
272
+
273
+ plg.effect(_provides, "database providers")