toms-fast 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60) hide show
  1. toms_fast-0.2.1.dist-info/METADATA +467 -0
  2. toms_fast-0.2.1.dist-info/RECORD +60 -0
  3. toms_fast-0.2.1.dist-info/WHEEL +4 -0
  4. toms_fast-0.2.1.dist-info/entry_points.txt +2 -0
  5. tomskit/__init__.py +0 -0
  6. tomskit/celery/README.md +693 -0
  7. tomskit/celery/__init__.py +4 -0
  8. tomskit/celery/celery.py +306 -0
  9. tomskit/celery/config.py +377 -0
  10. tomskit/cli/__init__.py +207 -0
  11. tomskit/cli/__main__.py +8 -0
  12. tomskit/cli/scaffold.py +123 -0
  13. tomskit/cli/templates/__init__.py +42 -0
  14. tomskit/cli/templates/base.py +348 -0
  15. tomskit/cli/templates/celery.py +101 -0
  16. tomskit/cli/templates/extensions.py +213 -0
  17. tomskit/cli/templates/fastapi.py +400 -0
  18. tomskit/cli/templates/migrations.py +281 -0
  19. tomskit/cli/templates_config.py +122 -0
  20. tomskit/logger/README.md +466 -0
  21. tomskit/logger/__init__.py +4 -0
  22. tomskit/logger/config.py +106 -0
  23. tomskit/logger/logger.py +290 -0
  24. tomskit/py.typed +0 -0
  25. tomskit/redis/README.md +462 -0
  26. tomskit/redis/__init__.py +6 -0
  27. tomskit/redis/config.py +85 -0
  28. tomskit/redis/redis_pool.py +87 -0
  29. tomskit/redis/redis_sync.py +66 -0
  30. tomskit/server/__init__.py +47 -0
  31. tomskit/server/config.py +117 -0
  32. tomskit/server/exceptions.py +412 -0
  33. tomskit/server/middleware.py +371 -0
  34. tomskit/server/parser.py +312 -0
  35. tomskit/server/resource.py +464 -0
  36. tomskit/server/server.py +276 -0
  37. tomskit/server/type.py +263 -0
  38. tomskit/sqlalchemy/README.md +590 -0
  39. tomskit/sqlalchemy/__init__.py +20 -0
  40. tomskit/sqlalchemy/config.py +125 -0
  41. tomskit/sqlalchemy/database.py +125 -0
  42. tomskit/sqlalchemy/pagination.py +359 -0
  43. tomskit/sqlalchemy/property.py +19 -0
  44. tomskit/sqlalchemy/sqlalchemy.py +131 -0
  45. tomskit/sqlalchemy/types.py +32 -0
  46. tomskit/task/README.md +67 -0
  47. tomskit/task/__init__.py +4 -0
  48. tomskit/task/task_manager.py +124 -0
  49. tomskit/tools/README.md +63 -0
  50. tomskit/tools/__init__.py +18 -0
  51. tomskit/tools/config.py +70 -0
  52. tomskit/tools/warnings.py +37 -0
  53. tomskit/tools/woker.py +81 -0
  54. tomskit/utils/README.md +666 -0
  55. tomskit/utils/README_SERIALIZER.md +644 -0
  56. tomskit/utils/__init__.py +35 -0
  57. tomskit/utils/fields.py +434 -0
  58. tomskit/utils/marshal_utils.py +137 -0
  59. tomskit/utils/response_utils.py +13 -0
  60. tomskit/utils/serializers.py +447 -0
@@ -0,0 +1,131 @@
1
+ from __future__ import annotations
2
+ from abc import ABC, abstractmethod
3
+ from sqlalchemy.orm.relationships import _RelationshipDeclared
4
+ from typing import Any, Optional
5
+
6
+ from sqlalchemy import CHAR as sa_CHAR
7
+ from sqlalchemy import JSON as sa_JSON
8
+ from sqlalchemy import BigInteger as sa_BigInteger
9
+ from sqlalchemy import Boolean as sa_Boolean
10
+ from sqlalchemy import Column as sa_Column
11
+ from sqlalchemy import DateTime as sa_DateTime
12
+ from sqlalchemy import Float as sa_Float
13
+ from sqlalchemy import ForeignKey as sa_ForeignKey
14
+ from sqlalchemy import Index as sa_Index
15
+ from sqlalchemy import Integer as sa_Integer
16
+ from sqlalchemy import LargeBinary as sa_LargeBinary
17
+ from sqlalchemy import MetaData as sa_MetaData
18
+ from sqlalchemy import Numeric as sa_Numeric
19
+ from sqlalchemy import PickleType as sa_PickleType
20
+ from sqlalchemy import PrimaryKeyConstraint as sa_PrimaryKeyConstraint
21
+ from sqlalchemy import Sequence as sa_Sequence
22
+ from sqlalchemy import String as sa_String
23
+ from sqlalchemy import Text as sa_Text
24
+ from sqlalchemy import UniqueConstraint as sa_UniqueConstraint
25
+ from sqlalchemy import and_ as sa_and_
26
+ from sqlalchemy import delete as sa_delete
27
+ from sqlalchemy import func as sa_func
28
+ from sqlalchemy import insert as sa_insert
29
+ from sqlalchemy import select as sa_select
30
+ from sqlalchemy import text as sa_text
31
+ from sqlalchemy import update as sa_update
32
+ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
33
+ from sqlalchemy.orm import DeclarativeBase
34
+ from sqlalchemy.orm import relationship as sa_relationship
35
+ from sqlalchemy.ext.asyncio import AsyncAttrs
36
+ from sqlalchemy.sql import Select
37
+
38
+ from tomskit.sqlalchemy.pagination import Pagination, SelectPagination
39
+
40
__all__ = [
    "SQLAlchemy",
    "Pagination",
    "SelectPagination"
]

# Naming convention applied to every index/constraint so that generated DDL
# (and migration tools such as Alembic) produce deterministic names.
DB_INDEXES_NAMING_CONVENTION = {
    "ix": "%(column_0_label)s_idx",
    "uq": "%(table_name)s_%(column_0_name)s_key",
    "ck": "%(table_name)s_%(constraint_name)s_check",
    "fk": "%(table_name)s_%(column_0_name)s_fkey",
    "pk": "%(table_name)s_pkey",
}

# Shared MetaData carrying the naming convention; bound to the declarative
# base (SQLAlchemy.Model) defined below.
metadata = sa_MetaData(naming_convention=DB_INDEXES_NAMING_CONVENTION)
60
+
61
# Abstract application-facing facade over SQLAlchemy: re-exports the common
# constructs as class attributes and declares the async session lifecycle
# contract that concrete backends implement.
class SQLAlchemy(ABC):
    """Abstract async SQLAlchemy facade.

    Exposes the frequently used SQLAlchemy constructs (``Column``,
    ``Integer``, ``select`` ...) as attributes so application code can be
    written against a single ``db`` object, and defines the engine/session
    lifecycle API concrete subclasses must provide.
    """

    class Model(AsyncAttrs, DeclarativeBase):
        # All models share the module-level metadata (and therefore the
        # index/constraint naming convention defined above). AsyncAttrs
        # provides ``instance.awaitable_attrs`` for lazy-loaded attributes.
        metadata = metadata

    # --- re-exported column/type/DDL constructs --------------------------
    Column = sa_Column
    CHAR = sa_CHAR
    BigInteger = sa_BigInteger
    Boolean = sa_Boolean
    DateTime = sa_DateTime
    Float = sa_Float
    Integer = sa_Integer
    JSON = sa_JSON
    LargeBinary = sa_LargeBinary
    Numeric = sa_Numeric
    PickleType = sa_PickleType
    Sequence = sa_Sequence
    String = sa_String
    Text = sa_Text
    text = staticmethod(sa_text)
    ForeignKey = sa_ForeignKey
    Index = sa_Index
    uuid = sa_CHAR(36)  # conventional 36-char column for string UUIDs
    PrimaryKeyConstraint = sa_PrimaryKeyConstraint
    UniqueConstraint = sa_UniqueConstraint

    # --- DML / query helpers ---------------------------------------------
    # Wrapped in staticmethod so that instance access (``db.select``) does
    # not bind ``self``.
    select = staticmethod(sa_select)
    delete = staticmethod(sa_delete)
    update = staticmethod(sa_update)
    insert = staticmethod(sa_insert)
    func = sa_func
    # Fix: was ``staticmethod[..., _RelationshipDeclared[Any]](sa_relationship)``.
    # Calling a *subscripted* staticmethod only works by accident of
    # ``GenericAlias.__call__`` and depends on a private SQLAlchemy type;
    # the plain wrapper yields the same runtime object.
    relationship = staticmethod(sa_relationship)
    and_ = staticmethod(sa_and_)

    def __init__(self) -> None:
        # Engine and session factory are created lazily by
        # ``initialize_session_pool`` in concrete subclasses.
        self._engine: Optional[AsyncEngine] = None
        self._SessionLocal: Optional[async_sessionmaker] = None

    @abstractmethod
    async def paginate(self,
                       select: Select[Any],
                       *,
                       page: int | None = None,
                       per_page: int | None = None,
                       max_per_page: int | None = None,
                       error_out: bool = True,
                       count: bool = True
                       ) -> Pagination:
        """Execute *select* and return a single page of results."""
        raise NotImplementedError

    @property
    def session(self) -> AsyncSession:
        # Deliberately not @abstractmethod (matching the original contract);
        # subclasses are still expected to override — the base always raises.
        raise NotImplementedError

    @abstractmethod
    def create_session(self) -> AsyncSession:
        """Create and return a new AsyncSession."""
        raise NotImplementedError

    @abstractmethod
    async def close_session(self) -> None:
        """Close the currently active session."""
        raise NotImplementedError

    @abstractmethod
    def initialize_session_pool(self, db_url: str) -> None:
        """Create the async engine and session factory for *db_url*."""
        raise NotImplementedError

    @abstractmethod
    async def close_session_pool(self) -> None:
        """Dispose of the engine and session factory."""
        raise NotImplementedError
@@ -0,0 +1,32 @@
1
+ import uuid
2
+
3
+ from sqlalchemy import CHAR, TypeDecorator
4
+ from sqlalchemy.dialects.postgresql import UUID
5
+
6
+
7
class StringUUID(TypeDecorator):
    """Dialect-aware UUID column.

    Stored as a native ``UUID`` on PostgreSQL and as ``CHAR(36)`` elsewhere;
    values are always returned to Python as ``str``.
    """

    impl = CHAR
    cache_ok = True  # safe to cache: behavior depends only on the dialect

    def process_bind_param(self, value, dialect):
        """Normalize *value* for storage.

        PostgreSQL gets the string form; other dialects get the 32-char hex.
        NOTE(review): the hex form is 32 chars while the fallback column is
        CHAR(36) — preserved as-is, but verify the intended storage format.
        """
        if value is None:
            return value
        if dialect.name == "postgresql":
            return str(value)
        # Fix: the original assumed a uuid.UUID instance and raised
        # AttributeError when a plain string was bound; accept both.
        if isinstance(value, uuid.UUID):
            return value.hex
        return uuid.UUID(str(value)).hex

    def load_dialect_impl(self, dialect):
        """Select the concrete column type for the active dialect."""
        if dialect.name == "postgresql":
            return dialect.type_descriptor(UUID())
        return dialect.type_descriptor(CHAR(36))

    def process_result_value(self, value, dialect):
        """Return database values to the application as ``str``."""
        if value is None:
            return value
        return str(value)
29
+
30
+
31
def uuid_generate_v4():
    """Generate a random version-4 UUID as a 32-character hex string."""
    generated = uuid.uuid4()
    return generated.hex
tomskit/task/README.md ADDED
@@ -0,0 +1,67 @@
1
+ # AsyncTaskManager 使用指南
2
+
3
+ `AsyncTaskManager` 是一个基于 `asyncio.TaskGroup` 的简单异步任务管理器,支持批量和单次任务执行,并提供可选的数据库会话支持和丰富的日志功能。
4
+
5
+ ## 功能概述
6
+
7
+ - **异步任务管理**:支持添加和并发执行多个异步任务。
8
+ - **数据库会话支持**:可选的数据库会话管理,适用于需要数据库交互的任务。
9
+ - **调试日志**:提供详细的任务执行日志,便于调试和性能分析。
10
+
11
+ ## 类定义
12
+
13
+ ### `AsyncTaskManager`
14
+
15
+ #### 初始化
16
+
17
+ ```python
18
+ def __init__(self, task_name: str = __name__, db: bool = False, debug: bool = False)
19
+ ```
20
+
21
+ - `task_name`:任务管理器的名称,用于日志记录。
22
+ - `db`:是否启用数据库会话。
23
+ - `debug`:是否启用调试日志。
24
+
25
+ #### 方法
26
+
27
+ - `add_task(target: TaskTarget, args: tuple = (), kwargs: Optional[dict[str, Any]] = None) -> None`
28
+ - 添加一个异步任务。
29
+ - `target`:异步任务函数。
30
+ - `args`:传递给任务函数的位置参数。
31
+ - `kwargs`:传递给任务函数的关键字参数。
32
+
33
+ - `add_tasks(targets: list[tuple[TaskTarget, tuple, Optional[dict[str, Any]]]]) -> None`
34
+ - 批量添加多个异步任务。
35
+
36
+ - `async def run_all() -> None`
37
+ - 并发执行所有添加的任务,捕获并收集异常。
38
+
39
+ - `async def run_task(target: TaskTarget, args: tuple = (), kwargs: Optional[dict[str, Any]] = None) -> Any`
40
+ - 快捷单任务调用,执行单个任务并返回结果。
41
+
42
+ ## 使用示例
43
+
44
+ ```python
45
+ import asyncio
46
+
47
+ async def sample_task(x):
48
+ await asyncio.sleep(1)
49
+ return x * 2
50
+
51
+ async def main():
52
+ manager = AsyncTaskManager(debug=True)
53
+ manager.add_task(sample_task, args=(5,))
54
+ await manager.run_all()
55
+ print(manager.results) # 输出: [10]
56
+
57
+ asyncio.run(main())
58
+ ```
59
+
60
+ ## 日志
61
+
62
+ - 启用 `debug` 模式后,任务的开始和结束时间将被记录。
63
+ - 异常将被记录并可供后续分析。
64
+
65
+ ## 数据库支持
66
+
67
+ - 如果 `db` 参数为 `True`,任务执行时将自动创建和关闭数据库会话。
@@ -0,0 +1,4 @@
1
+ from .task_manager import AsyncTaskManager
2
+
3
+
4
+ __all__ = ["AsyncTaskManager"]
@@ -0,0 +1,124 @@
1
+ import asyncio
2
+ import logging
3
+ import time
4
+ import uuid
5
+ from typing import Any, Awaitable, Callable, Optional
6
+ from tomskit.sqlalchemy.database import db as db_instance
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ # 异步任务目标类型
11
+ TaskTarget = Callable[..., Awaitable[Any]]
12
+
13
class AsyncTaskManager:
    """
    Simple async task manager built on ``asyncio.TaskGroup`` (Python 3.11+).

    Supports batch and single-shot execution, an optional per-task database
    session, and verbose timing logs. Results land in ``self.results`` in
    completion order; captured failures land in ``self.exceptions``.
    """
    def __init__(
        self,
        task_name: str = __name__,
        db: bool = False,
        debug: bool = False,
    ):
        # Label used in log lines; defaults to this module's __name__.
        self.task_name = task_name
        # When True, a DB session is opened/closed around every task.
        self.db = db
        # When True, start/finish and elapsed time are logged at INFO.
        self.debug = debug
        # Queued zero-argument coroutine factories for run_all().
        self.tasks: list[Callable[[], Awaitable[Any]]] = []
        # Results (completion order) and collected exceptions.
        self.results: list[Any] = []
        self.exceptions: list[Exception] = []

    def add_task(
        self,
        target: TaskTarget,
        args: tuple = (),
        kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        """
        Queue one async task; *args*/*kwargs* are forwarded to *target*.
        """
        if kwargs is None:
            kwargs = {}
        func_name = getattr(target, '__name__', repr(target))
        task_id = uuid.uuid4().hex[:8]

        async def wrapper() -> Any:
            session = None
            try:
                # Optionally open a DB session for the duration of the task.
                if self.db:
                    session = db_instance.create_session()
                    logger.debug(f"[{self.task_name}] db session connected, task {task_id}")

                # Debug log: task start.
                if self.debug:
                    logger.info(f"[{self.task_name}] Start {func_name}({task_id})")
                    start = time.perf_counter()

                # Execute the target coroutine.
                result = await target(*args, **kwargs)
                self.results.append(result)

                # Debug log: task finished.
                if self.debug:
                    elapsed = time.perf_counter() - start
                    logger.info(f"[{self.task_name}] Finish {func_name}({task_id}) in {elapsed:.3f}s")

            except Exception:
                # Log only; re-raise so the outer TaskGroup collects it.
                logger.exception(f"[{self.task_name}] Error in {func_name}({task_id})")
                raise

            finally:
                # Close the session if one was opened.
                # NOTE(review): passes the session to close_session(), while
                # the abstract SQLAlchemy.close_session() takes no argument —
                # verify against the concrete db implementation.
                if self.db and session is not None:
                    await db_instance.close_session(session)
                    logger.debug(f"[{self.task_name}] db session closed, task {task_id}")

        self.tasks.append(wrapper)

    def add_tasks(
        self,
        targets: list[tuple[TaskTarget, tuple, Optional[dict[str, Any]]]]
    ) -> None:
        """
        Queue several tasks at once: ``[(func, args, kwargs), ...]``.
        """
        for target, args, kwargs in targets:
            self.add_task(target, args=args, kwargs=kwargs)

    async def run_all(self) -> None:
        """
        Run every queued task concurrently in an ``asyncio.TaskGroup``.

        Exceptions are captured into ``self.exceptions`` instead of being
        propagated; CancelledError from sibling cancellation is dropped.
        """
        self.results.clear()
        self.exceptions.clear()
        try:
            async with asyncio.TaskGroup() as tg:
                for wrapper in self.tasks:
                    tg.create_task(wrapper())  # type: ignore
        except* Exception as eg:
            for e in eg.exceptions:
                if not isinstance(e, asyncio.CancelledError):
                    self.exceptions.append(e)

    async def run_task(
        self,
        target: TaskTarget,
        args: tuple = (),
        kwargs: Optional[dict[str, Any]] = None,
    ) -> Any:
        """
        Convenience single-task call: run *target*, return its result, and
        re-raise the first captured exception on failure.
        """
        self.tasks.clear()
        self.results.clear()
        self.exceptions.clear()
        self.add_task(target, args=args, kwargs=kwargs)
        await self.run_all()
        if self.exceptions:
            raise self.exceptions[0]
        return self.results[0]
@@ -0,0 +1,63 @@
1
+ # FastAI Toolkit - Worker 模块
2
+
3
+ 该模块提供了与 Redis 交互的功能,用于管理和监控 `uvicorn.workers.UvicornWorker` 的进程信息。以下是模块中可用的函数及其用途:
4
+
5
+ ## 函数
6
+
7
+ ### `worker_register_to_redis(redis: Redis, hostname: str, pid: int)`
8
+
9
+ - **描述**:
10
+ - 在 `gunicorn` 启动时,将 `uvicorn.workers.UvicornWorker` 的进程信息注册到 Redis。
11
+
12
+ - **参数**:
13
+ - `redis`: Redis 客户端实例。
14
+ - `hostname`: 主机名。
15
+ - `pid`: 进程 ID。
16
+
17
+ - **功能**:
18
+ - 将进程信息存储在 Redis 中,以便后续管理和监控。
19
+
20
+ ### `worker_delete_from_redis(redis: Redis, hostname: str, pid: int)`
21
+
22
+ - **描述**:
23
+ - 在 `gunicorn` 关闭时,从 Redis 中删除 `uvicorn.workers.UvicornWorker` 的进程信息。
24
+
25
+ - **参数**:
26
+ - `redis`: Redis 客户端实例。
27
+ - `hostname`: 主机名。
28
+ - `pid`: 进程 ID。
29
+
30
+ - **功能**:
31
+ - 从 Redis 中移除进程信息,释放资源。
32
+
33
+ ### `async worker_update_to_redis(hostname: str, pid: int, update_info: dict)`
34
+
35
+ - **描述**:
36
+ - 在 `uvicorn.workers.UvicornWorker` 中,更新进程信息到 Redis。
37
+
38
+ - **参数**:
39
+ - `hostname`: 主机名。
40
+ - `pid`: 进程 ID。
41
+ - `update_info`: 包含更新信息的字典,例如请求计数、异常计数等。
42
+
43
+ - **功能**:
44
+ - 更新 Redis 中的进程信息,保持数据的实时性。
45
+
46
+ ### `async get_all_worker_info_from_redis(hostname: str) -> dict`
47
+
48
+ - **描述**:
49
+ - 从 Redis 中获取所有的进程信息,并返回一个字典。
50
+
51
+ - **参数**:
52
+ - `hostname`: 主机名。
53
+
54
+ - **返回**:
55
+ - 包含所有进程信息的字典。
56
+
57
+ - **功能**:
58
+ - 提供对所有注册进程的全面监控。
59
+
60
+ ## 注意事项
61
+
62
+ - 确保在使用这些函数前,已正确初始化 Redis 客户端。
63
+ - 这些函数主要用于管理和监控 `uvicorn` 工作进程,适用于需要实时更新和获取进程状态的场景。
@@ -0,0 +1,18 @@
1
+
2
+ from tomskit.tools.config import GunicornSettings
3
+ from tomskit.tools.warnings import enable_unawaited_warning
4
+ from tomskit.tools.woker import (
5
+ worker_register_to_redis,
6
+ worker_delete_from_redis,
7
+ worker_update_to_redis,
8
+ get_all_worker_info_from_redis )
9
+
10
+
11
+ __all__ = [
12
+ "worker_register_to_redis",
13
+ "worker_delete_from_redis",
14
+ "worker_update_to_redis",
15
+ "get_all_worker_info_from_redis",
16
+ "GunicornSettings",
17
+ "enable_unawaited_warning"
18
+ ]
@@ -0,0 +1,70 @@
1
+ import os
2
+ from typing import List, Optional
3
+ from pydantic import Field, field_validator
4
+ from pydantic_settings import BaseSettings
5
+
6
class GunicornSettings(BaseSettings):
    """
    Configuration settings for Gunicorn.

    Values are loaded from the environment by pydantic-settings; the
    validators below normalize the raw (possibly string) inputs.
    """

    bind: str = Field(
        default="0.0.0.0:5001",
        description="Bind address and port to listen on"
    )

    pidfile: Optional[str] = Field(
        default=None,
        description="File to write the PID of the main process"
    )

    proc_name: Optional[str] = Field(
        default=None,
        description="Name of the process"
    )

    workers: int = Field(
        default=0,
        description="Number of workers to run (0 = auto by CPU cores)"
    )

    cpu_affinity: List[int] = Field(
        default_factory=list,
        description="List of CPU core IDs to bind workers to"
    )

    daemon: bool = Field(
        default=False,
        description="Run as a daemon"
    )

    @field_validator("workers", mode="before")
    @classmethod
    def default_workers(cls, v) -> int:
        """Fall back to the machine's CPU count when unset or <= 0."""
        if not v or int(v) <= 0:
            return os.cpu_count() or 1
        return int(v)

    @field_validator("cpu_affinity", mode="before")
    @classmethod
    def parse_cpu_affinity(cls, v) -> List[int]:
        """Accept a comma-separated string (e.g. "0,1,2") or a list of ints.

        Raises ValueError when the string contains non-integer items.
        """
        if v is None or (isinstance(v, str) and not v.strip()):
            return []
        if isinstance(v, str):
            try:
                return [int(x.strip()) for x in v.split(",") if x.strip()]
            except Exception as e:
                raise ValueError(f"Invalid CPU_AFFINITY string: {v}") from e
        if isinstance(v, list):
            return v
        return []

    @field_validator("pidfile", mode="after")
    @classmethod
    def ensure_pidfile_dir_exists(cls, v: Optional[str]) -> Optional[str]:
        """Create the pidfile's parent directory if missing; '' becomes None."""
        # Annotation widened to Optional[str]: the field defaults to None and
        # a mode="after" validator receives the parsed Optional value.
        if not v:
            return None
        dir_path = os.path.dirname(v)
        if dir_path and not os.path.exists(dir_path):
            os.makedirs(dir_path, exist_ok=True)
        return v
@@ -0,0 +1,37 @@
1
+
2
+
3
+ import sys
4
+ import warnings
5
+ import logging
6
+
7
logger = logging.getLogger(__name__)

def enable_unawaited_warning():
    """
    Route RuntimeWarnings (notably "coroutine ... was never awaited") to
    ``logger.critical`` with the originating file name and line number.

    Process-wide side effects:
      * enables coroutine origin tracking (depth 5),
      * forces RuntimeWarning to always be emitted,
      * replaces ``warnings.showwarning`` globally.
    """
    # Previous handler, kept so non-RuntimeWarning categories retain their
    # original behavior.
    _orig_showwarning = warnings.showwarning

    def _custom_warning_handler(
        message,
        category,
        filename,
        lineno,
        file=None,
        line=None
    ):
        if issubclass(category, RuntimeWarning):
            # Fix: the original logged a literal placeholder instead of the
            # warning's file name; include the real location as the
            # docstring promises.
            logger.critical(f"(File: {filename}, Line: {lineno}) {message}")
        else:
            # Delegate everything else to the previous handler.
            _orig_showwarning(message, category, filename, lineno, file, line)

    # Track coroutine creation sites so "never awaited" warnings point at
    # the origin (API available since Python 3.7).
    sys.set_coroutine_origin_tracking_depth(5)
    # Always surface RuntimeWarning, even repeats.
    warnings.simplefilter("always", RuntimeWarning)
    # Install the handler globally.
    warnings.showwarning = _custom_warning_handler
tomskit/tools/woker.py ADDED
@@ -0,0 +1,81 @@
1
+
2
+ from datetime import datetime
3
+ from redis import Redis
4
+ from tomskit.redis import redis_client
5
+
6
def worker_register_to_redis(redis: Redis, hostname: str, pid: int):
    """
    Register a uvicorn worker process in Redis when gunicorn starts it.

    Adds the pid to the ``{hostname}:workers`` set and writes an initial
    info hash under ``{hostname}:worker:{pid}``.
    """
    registered_at = int(datetime.now().timestamp())
    info: dict = {
        "pid": str(pid),
        "start_at": registered_at,
        "uptime": 0,
        "request_count": 0,
        "exception_count": 0,
        "server_error_count": 0,
        "status": "starting",
        "last_update_by": "gunicorn",
        "last_update_at": int(datetime.now().timestamp()),
    }
    workers_set_key = f"{hostname}:workers"
    worker_hash_key = f"{hostname}:worker:{str(pid)}"
    redis.sadd(workers_set_key, str(pid))
    redis.hset(worker_hash_key, mapping=info)
23
+
24
+
25
def worker_delete_from_redis(redis: Redis, hostname: str, pid: int):
    """
    Remove a uvicorn worker's registration from Redis on gunicorn shutdown.

    Drops the pid from the ``{hostname}:workers`` set and deletes the
    worker's info hash.
    """
    member = str(pid)
    redis.srem(f"{hostname}:workers", member)
    redis.delete(f"{hostname}:worker:{member}")
31
+
32
+
33
async def worker_update_to_redis(hostname: str, pid: int, update_info: dict):
    """
    Refresh a worker's info hash in Redis from inside the uvicorn worker.

    update_info: dict = {
        "request_count": int,
        "exception_count": int,
        "server_error_count": int,
    }

    Raises RuntimeError when the shared redis client was never initialized.
    """

    # NOTE(review): reaches into the private ``_client`` attribute of the
    # shared redis_client; a public accessor would be safer — verify.
    redis = redis_client._client
    if redis is None:
        raise RuntimeError("Redis client is not initialized. Call initialize first.")

    worker_key = f"{hostname}:worker:{str(pid)}"

    # Fetch the currently stored info for this worker.
    current_info = await redis.hgetall(worker_key)  # type: ignore

    # Merge and write back; silently a no-op if the worker was never
    # registered (empty hash).
    if current_info:
        # assumes hgetall returns str keys/values (decode_responses=True);
        # with a bytes-returning client these lookups would fail — TODO confirm
        current_info["uptime"] = int(datetime.now().timestamp()) - int(current_info["start_at"])
        current_info["last_update_at"] = int(datetime.now().timestamp())
        current_info["status"] = "running"
        current_info["last_update_by"] = "uvicorn.tomskit"
        current_info.update(update_info)
        # Persist the merged hash.
        await redis.hset(worker_key, mapping=current_info)  # type: ignore
61
+
62
+
63
async def get_all_worker_info_from_redis(hostname: str) -> dict:
    """
    Fetch the info hashes of every worker registered under *hostname*.

    Returns a dict mapping pid -> worker info hash; pids whose hash has
    already disappeared are skipped.

    Raises RuntimeError when the shared redis client was never initialized.
    """
    # NOTE(review): relies on the private ``_client`` attribute, mirroring
    # worker_update_to_redis — verify a public accessor is not available.
    redis = redis_client._client
    if redis is None:
        raise RuntimeError("Redis client is not initialized. Call initialize first.")

    # All pids registered under this hostname.
    worker_pids = await redis.smembers(f"{hostname}:workers")  # type: ignore

    all_process_info = {}
    for pid in worker_pids:
        worker_key = f"{hostname}:worker:{pid}"
        process_info = await redis.hgetall(worker_key)  # type: ignore
        if process_info:
            # dict() copy replaces the original identity dict comprehension.
            all_process_info[pid] = dict(process_info)

    return all_process_info