toms-fast 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- toms_fast-0.2.1.dist-info/METADATA +467 -0
- toms_fast-0.2.1.dist-info/RECORD +60 -0
- toms_fast-0.2.1.dist-info/WHEEL +4 -0
- toms_fast-0.2.1.dist-info/entry_points.txt +2 -0
- tomskit/__init__.py +0 -0
- tomskit/celery/README.md +693 -0
- tomskit/celery/__init__.py +4 -0
- tomskit/celery/celery.py +306 -0
- tomskit/celery/config.py +377 -0
- tomskit/cli/__init__.py +207 -0
- tomskit/cli/__main__.py +8 -0
- tomskit/cli/scaffold.py +123 -0
- tomskit/cli/templates/__init__.py +42 -0
- tomskit/cli/templates/base.py +348 -0
- tomskit/cli/templates/celery.py +101 -0
- tomskit/cli/templates/extensions.py +213 -0
- tomskit/cli/templates/fastapi.py +400 -0
- tomskit/cli/templates/migrations.py +281 -0
- tomskit/cli/templates_config.py +122 -0
- tomskit/logger/README.md +466 -0
- tomskit/logger/__init__.py +4 -0
- tomskit/logger/config.py +106 -0
- tomskit/logger/logger.py +290 -0
- tomskit/py.typed +0 -0
- tomskit/redis/README.md +462 -0
- tomskit/redis/__init__.py +6 -0
- tomskit/redis/config.py +85 -0
- tomskit/redis/redis_pool.py +87 -0
- tomskit/redis/redis_sync.py +66 -0
- tomskit/server/__init__.py +47 -0
- tomskit/server/config.py +117 -0
- tomskit/server/exceptions.py +412 -0
- tomskit/server/middleware.py +371 -0
- tomskit/server/parser.py +312 -0
- tomskit/server/resource.py +464 -0
- tomskit/server/server.py +276 -0
- tomskit/server/type.py +263 -0
- tomskit/sqlalchemy/README.md +590 -0
- tomskit/sqlalchemy/__init__.py +20 -0
- tomskit/sqlalchemy/config.py +125 -0
- tomskit/sqlalchemy/database.py +125 -0
- tomskit/sqlalchemy/pagination.py +359 -0
- tomskit/sqlalchemy/property.py +19 -0
- tomskit/sqlalchemy/sqlalchemy.py +131 -0
- tomskit/sqlalchemy/types.py +32 -0
- tomskit/task/README.md +67 -0
- tomskit/task/__init__.py +4 -0
- tomskit/task/task_manager.py +124 -0
- tomskit/tools/README.md +63 -0
- tomskit/tools/__init__.py +18 -0
- tomskit/tools/config.py +70 -0
- tomskit/tools/warnings.py +37 -0
- tomskit/tools/woker.py +81 -0
- tomskit/utils/README.md +666 -0
- tomskit/utils/README_SERIALIZER.md +644 -0
- tomskit/utils/__init__.py +35 -0
- tomskit/utils/fields.py +434 -0
- tomskit/utils/marshal_utils.py +137 -0
- tomskit/utils/response_utils.py +13 -0
- tomskit/utils/serializers.py +447 -0
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from sqlalchemy.orm.relationships import _RelationshipDeclared
|
|
4
|
+
from typing import Any, Optional
|
|
5
|
+
|
|
6
|
+
from sqlalchemy import CHAR as sa_CHAR
|
|
7
|
+
from sqlalchemy import JSON as sa_JSON
|
|
8
|
+
from sqlalchemy import BigInteger as sa_BigInteger
|
|
9
|
+
from sqlalchemy import Boolean as sa_Boolean
|
|
10
|
+
from sqlalchemy import Column as sa_Column
|
|
11
|
+
from sqlalchemy import DateTime as sa_DateTime
|
|
12
|
+
from sqlalchemy import Float as sa_Float
|
|
13
|
+
from sqlalchemy import ForeignKey as sa_ForeignKey
|
|
14
|
+
from sqlalchemy import Index as sa_Index
|
|
15
|
+
from sqlalchemy import Integer as sa_Integer
|
|
16
|
+
from sqlalchemy import LargeBinary as sa_LargeBinary
|
|
17
|
+
from sqlalchemy import MetaData as sa_MetaData
|
|
18
|
+
from sqlalchemy import Numeric as sa_Numeric
|
|
19
|
+
from sqlalchemy import PickleType as sa_PickleType
|
|
20
|
+
from sqlalchemy import PrimaryKeyConstraint as sa_PrimaryKeyConstraint
|
|
21
|
+
from sqlalchemy import Sequence as sa_Sequence
|
|
22
|
+
from sqlalchemy import String as sa_String
|
|
23
|
+
from sqlalchemy import Text as sa_Text
|
|
24
|
+
from sqlalchemy import UniqueConstraint as sa_UniqueConstraint
|
|
25
|
+
from sqlalchemy import and_ as sa_and_
|
|
26
|
+
from sqlalchemy import delete as sa_delete
|
|
27
|
+
from sqlalchemy import func as sa_func
|
|
28
|
+
from sqlalchemy import insert as sa_insert
|
|
29
|
+
from sqlalchemy import select as sa_select
|
|
30
|
+
from sqlalchemy import text as sa_text
|
|
31
|
+
from sqlalchemy import update as sa_update
|
|
32
|
+
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
|
33
|
+
from sqlalchemy.orm import DeclarativeBase
|
|
34
|
+
from sqlalchemy.orm import relationship as sa_relationship
|
|
35
|
+
from sqlalchemy.ext.asyncio import AsyncAttrs
|
|
36
|
+
from sqlalchemy.sql import Select
|
|
37
|
+
|
|
38
|
+
from tomskit.sqlalchemy.pagination import Pagination, SelectPagination
|
|
39
|
+
|
|
40
|
+
__all__ = [
    "SQLAlchemy",
    "Pagination",
    "SelectPagination",
]

# Naming convention applied to indexes and constraints created through
# ``metadata``.  The generated names (``<table>_pkey``, ``<table>_<col>_key``,
# ``<table>_<col>_fkey``, ``..._check``, ``..._idx``) resemble PostgreSQL's
# default constraint names; SQLAlchemy applies the convention on every dialect.
DB_INDEXES_NAMING_CONVENTION = dict(
    ix="%(column_0_label)s_idx",
    uq="%(table_name)s_%(column_0_name)s_key",
    ck="%(table_name)s_%(constraint_name)s_check",
    fk="%(table_name)s_%(column_0_name)s_fkey",
    pk="%(table_name)s_pkey",
)

# Shared MetaData instance carrying the naming convention; SQLAlchemy.Model
# declares its tables on it.
metadata = sa_MetaData(naming_convention=DB_INDEXES_NAMING_CONVENTION)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class SQLAlchemy(ABC):
    """Abstract async SQLAlchemy facade.

    Re-exports common SQLAlchemy constructs (column types, constraints and
    statement builders) as class attributes, so models and queries can be
    written as ``db.Column(db.String(36))`` or ``db.select(...)``.  It also
    declares the session / engine lifecycle that concrete subclasses implement.
    """

    class Model(AsyncAttrs, DeclarativeBase):
        """Declarative base for all models; bound to the module-level ``metadata``."""

        metadata = metadata

    # Column and SQL types.
    Column = sa_Column
    CHAR = sa_CHAR
    BigInteger = sa_BigInteger
    Boolean = sa_Boolean
    DateTime = sa_DateTime
    Float = sa_Float
    Integer = sa_Integer
    JSON = sa_JSON
    LargeBinary = sa_LargeBinary
    Numeric = sa_Numeric
    PickleType = sa_PickleType
    Sequence = sa_Sequence
    String = sa_String
    Text = sa_Text
    text = staticmethod(sa_text)
    # Keys, indexes and constraints.
    ForeignKey = sa_ForeignKey
    Index = sa_Index
    uuid = sa_CHAR(36)  # CHAR(36) type instance intended for UUID-string columns
    PrimaryKeyConstraint = sa_PrimaryKeyConstraint
    UniqueConstraint = sa_UniqueConstraint
    # Statement builders and helpers.
    select = staticmethod(sa_select)
    delete = staticmethod(sa_delete)
    update = staticmethod(sa_update)
    insert = staticmethod(sa_insert)
    func = sa_func
    relationship = staticmethod[..., _RelationshipDeclared[Any]](sa_relationship)
    and_ = staticmethod(sa_and_)

    def __init__(self) -> None:
        # Engine and session factory start unset; presumably populated by a
        # subclass's initialize_session_pool() — confirm in the implementation.
        self._engine: Optional[AsyncEngine] = None
        self._SessionLocal: Optional[async_sessionmaker] = None

    @abstractmethod
    async def paginate(self,
                       select: Select[Any],
                       *,
                       page: int | None = None,
                       per_page: int | None = None,
                       max_per_page: int | None = None,
                       error_out: bool = True,
                       count: bool = True
                       ) -> Pagination:
        """Execute ``select`` and return one page of results as a :class:`Pagination`.

        NOTE(review): the keyword arguments presumably mirror Flask-SQLAlchemy's
        ``paginate`` (page number, page size, page-size cap, raise on invalid
        page, run a COUNT query) — confirm against the concrete subclass.
        """
        raise NotImplementedError

    @property
    def session(self) -> AsyncSession:
        """The session for the current context; subclasses must override."""
        raise NotImplementedError

    @abstractmethod
    def create_session(self) -> AsyncSession:
        """Create and return a new :class:`AsyncSession`."""
        raise NotImplementedError

    @abstractmethod
    async def close_session(self, session: Optional[AsyncSession] = None):
        """Close a session.

        Args:
            session: The session to close.  Callers such as the task manager
                pass the session they obtained from :meth:`create_session`;
                it is optional so that implementations that close an
                implicit/current session keep a compatible signature.
        """
        raise NotImplementedError

    @abstractmethod
    def initialize_session_pool(self, db_url: str):
        """Create the engine / session factory for the database at ``db_url``."""
        raise NotImplementedError

    @abstractmethod
    async def close_session_pool(self):
        """Dispose of the engine / session factory created by initialize_session_pool."""
        raise NotImplementedError
|
|
130
|
+
|
|
131
|
+
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
import uuid
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import CHAR, TypeDecorator
|
|
4
|
+
from sqlalchemy.dialects.postgresql import UUID
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class StringUUID(TypeDecorator):
    """Dialect-aware UUID column type whose Python-side values are ``str``.

    PostgreSQL uses the native ``UUID`` type; every other dialect stores the
    value in a ``CHAR(36)`` column.  Values read back are returned as ``str``.
    """

    impl = CHAR
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Convert a Python value into the form sent to the database.

        ``None`` passes through.  PostgreSQL receives ``str(value)``.  Other
        dialects receive the 32-character hex form (no hyphens); ``str`` inputs
        are parsed with :class:`uuid.UUID` first, so both ``uuid.UUID`` objects
        and UUID strings (hyphenated or hex) are accepted.

        Raises:
            ValueError: if a string value is not a valid UUID (non-PostgreSQL).
        """
        if value is None:
            return value
        if dialect.name == "postgresql":
            return str(value)
        if isinstance(value, str):
            # A plain str has no ``.hex`` attribute; normalize it to a UUID first.
            value = uuid.UUID(value)
        return value.hex

    def load_dialect_impl(self, dialect):
        """Use native ``UUID`` on PostgreSQL and ``CHAR(36)`` elsewhere."""
        if dialect.name == "postgresql":
            return dialect.type_descriptor(UUID())
        else:
            return dialect.type_descriptor(CHAR(36))

    def process_result_value(self, value, dialect):
        """Return database values as ``str`` (``None`` passes through).

        NOTE(review): on non-PostgreSQL dialects this is the 32-character hex
        form written by process_bind_param, whereas PostgreSQL values are
        stringified UUIDs (hyphenated) — normalize if comparing across dialects.
        """
        if value is None:
            return value
        return str(value)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def uuid_generate_v4():
    """Return a fresh random (version 4) UUID as 32 lowercase hex digits, no hyphens."""
    generated = uuid.uuid4()
    # Same text as ``generated.hex``: the 128-bit integer, zero-padded hex.
    return format(generated.int, "032x")
|
tomskit/task/README.md
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
# AsyncTaskManager 使用指南
|
|
2
|
+
|
|
3
|
+
`AsyncTaskManager` 是一个基于 `asyncio.TaskGroup` 的简单异步任务管理器,支持批量和单次任务执行,并提供可选的数据库会话支持和丰富的日志功能。
|
|
4
|
+
|
|
5
|
+
## 功能概述
|
|
6
|
+
|
|
7
|
+
- **异步任务管理**:支持添加和并发执行多个异步任务。
|
|
8
|
+
- **数据库会话支持**:可选的数据库会话管理,适用于需要数据库交互的任务。
|
|
9
|
+
- **调试日志**:提供详细的任务执行日志,便于调试和性能分析。
|
|
10
|
+
|
|
11
|
+
## 类定义
|
|
12
|
+
|
|
13
|
+
### `AsyncTaskManager`
|
|
14
|
+
|
|
15
|
+
#### 初始化
|
|
16
|
+
|
|
17
|
+
```python
|
|
18
|
+
def __init__(self, task_name: str = __name__, db: bool = False, debug: bool = False)
|
|
19
|
+
```
|
|
20
|
+
|
|
21
|
+
- `task_name`:任务管理器的名称,用于日志记录。
|
|
22
|
+
- `db`:是否启用数据库会话。
|
|
23
|
+
- `debug`:是否启用调试日志。
|
|
24
|
+
|
|
25
|
+
#### 方法
|
|
26
|
+
|
|
27
|
+
- `add_task(target: TaskTarget, args: tuple = (), kwargs: Optional[dict[str, Any]] = None) -> None`
|
|
28
|
+
- 添加一个异步任务。
|
|
29
|
+
- `target`:异步任务函数。
|
|
30
|
+
- `args`:传递给任务函数的位置参数。
|
|
31
|
+
- `kwargs`:传递给任务函数的关键字参数。
|
|
32
|
+
|
|
33
|
+
- `add_tasks(targets: list[tuple[TaskTarget, tuple, Optional[dict[str, Any]]]]) -> None`
|
|
34
|
+
- 批量添加多个异步任务。
|
|
35
|
+
|
|
36
|
+
- `async def run_all() -> None`
|
|
37
|
+
- 并发执行所有已添加的任务;成功任务的返回值按完成顺序追加到 `results`,任务抛出的异常收集到 `exceptions` 中,不会向外抛出。
|
|
38
|
+
|
|
39
|
+
- `async def run_task(target: TaskTarget, args: tuple = (), kwargs: Optional[dict[str, Any]] = None) -> Any`
|
|
40
|
+
- 快捷单任务调用:执行单个任务并返回其结果;若任务抛出异常,则将该异常重新抛出。
|
|
41
|
+
|
|
42
|
+
## 使用示例
|
|
43
|
+
|
|
44
|
+
```python
|
|
45
|
+
import asyncio
|
|
46
|
+
|
|
47
|
+
async def sample_task(x):
|
|
48
|
+
await asyncio.sleep(1)
|
|
49
|
+
return x * 2
|
|
50
|
+
|
|
51
|
+
async def main():
|
|
52
|
+
manager = AsyncTaskManager(debug=True)
|
|
53
|
+
manager.add_task(sample_task, args=(5,))
|
|
54
|
+
await manager.run_all()
|
|
55
|
+
print(manager.results) # 输出: [10]
|
|
56
|
+
|
|
57
|
+
asyncio.run(main())
|
|
58
|
+
```
|
|
59
|
+
|
|
60
|
+
## 日志
|
|
61
|
+
|
|
62
|
+
- 启用 `debug` 模式后,任务的开始和结束时间将被记录。
|
|
63
|
+
- 异常将被记录并可供后续分析。
|
|
64
|
+
|
|
65
|
+
## 数据库支持
|
|
66
|
+
|
|
67
|
+
- 如果 `db` 参数为 `True`,每个任务执行前会自动创建数据库会话,执行结束后自动关闭;该会话不会作为参数传给任务函数。
|
tomskit/task/__init__.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
import time
|
|
4
|
+
import uuid
|
|
5
|
+
from typing import Any, Awaitable, Callable, Optional
|
|
6
|
+
from tomskit.sqlalchemy.database import db as db_instance
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
# 异步任务目标类型
|
|
11
|
+
TaskTarget = Callable[..., Awaitable[Any]]
|
|
12
|
+
|
|
13
|
+
class AsyncTaskManager:
|
|
14
|
+
"""
|
|
15
|
+
基于 asyncio.TaskGroup 的简单异步任务管理器,支持批量/单次任务执行,
|
|
16
|
+
可选 db 会话,丰富日志
|
|
17
|
+
"""
|
|
18
|
+
def __init__(
|
|
19
|
+
self,
|
|
20
|
+
task_name: str = __name__,
|
|
21
|
+
db: bool = False,
|
|
22
|
+
debug: bool = False,
|
|
23
|
+
):
|
|
24
|
+
self.task_name = task_name
|
|
25
|
+
self.db = db
|
|
26
|
+
self.debug = debug
|
|
27
|
+
# 存放任务生成器
|
|
28
|
+
self.tasks: list[Callable[[], Awaitable[Any]]] = []
|
|
29
|
+
# 存放结果和异常
|
|
30
|
+
self.results: list[Any] = []
|
|
31
|
+
self.exceptions: list[Exception] = []
|
|
32
|
+
|
|
33
|
+
def add_task(
|
|
34
|
+
self,
|
|
35
|
+
target: TaskTarget,
|
|
36
|
+
args: tuple = (),
|
|
37
|
+
kwargs: Optional[dict[str, Any]] = None,
|
|
38
|
+
) -> None:
|
|
39
|
+
"""
|
|
40
|
+
添加一个异步任务,参数同 target 函数签名
|
|
41
|
+
"""
|
|
42
|
+
if kwargs is None:
|
|
43
|
+
kwargs = {}
|
|
44
|
+
func_name = getattr(target, '__name__', repr(target))
|
|
45
|
+
task_id = uuid.uuid4().hex[:8]
|
|
46
|
+
|
|
47
|
+
async def wrapper() -> Any:
|
|
48
|
+
session = None
|
|
49
|
+
try:
|
|
50
|
+
# 建会话
|
|
51
|
+
if self.db:
|
|
52
|
+
session = db_instance.create_session()
|
|
53
|
+
logger.debug(f"[{self.task_name}] db session connected, task {task_id}")
|
|
54
|
+
|
|
55
|
+
# 调试日志: 开始
|
|
56
|
+
if self.debug:
|
|
57
|
+
logger.info(f"[{self.task_name}] Start {func_name}({task_id})")
|
|
58
|
+
start = time.perf_counter()
|
|
59
|
+
|
|
60
|
+
# 执行目标协程
|
|
61
|
+
result = await target(*args, **kwargs)
|
|
62
|
+
self.results.append(result)
|
|
63
|
+
|
|
64
|
+
# 调试日志: 完成
|
|
65
|
+
if self.debug:
|
|
66
|
+
elapsed = time.perf_counter() - start
|
|
67
|
+
logger.info(f"[{self.task_name}] Finish {func_name}({task_id}) in {elapsed:.3f}s")
|
|
68
|
+
|
|
69
|
+
except Exception:
|
|
70
|
+
# 只日志,不立即收集,留给外层 TaskGroup 收集
|
|
71
|
+
logger.exception(f"[{self.task_name}] Error in {func_name}({task_id})")
|
|
72
|
+
raise
|
|
73
|
+
|
|
74
|
+
finally:
|
|
75
|
+
# 关闭会话
|
|
76
|
+
if self.db and session is not None:
|
|
77
|
+
await db_instance.close_session(session)
|
|
78
|
+
logger.debug(f"[{self.task_name}] db session closed, task {task_id}")
|
|
79
|
+
|
|
80
|
+
self.tasks.append(wrapper)
|
|
81
|
+
|
|
82
|
+
def add_tasks(
|
|
83
|
+
self,
|
|
84
|
+
targets: list[tuple[TaskTarget, tuple, Optional[dict[str, Any]]]]
|
|
85
|
+
) -> None:
|
|
86
|
+
"""
|
|
87
|
+
批量添加任务: [(func, args, kwargs), ...]
|
|
88
|
+
"""
|
|
89
|
+
for target, args, kwargs in targets:
|
|
90
|
+
self.add_task(target, args=args, kwargs=kwargs)
|
|
91
|
+
|
|
92
|
+
async def run_all(self) -> None:
|
|
93
|
+
"""
|
|
94
|
+
并发执行所有添加的任务,使用 asyncio.TaskGroup,
|
|
95
|
+
捕获异常并收集,不向外抛出。
|
|
96
|
+
"""
|
|
97
|
+
self.results.clear()
|
|
98
|
+
self.exceptions.clear()
|
|
99
|
+
try:
|
|
100
|
+
async with asyncio.TaskGroup() as tg:
|
|
101
|
+
for wrapper in self.tasks:
|
|
102
|
+
tg.create_task(wrapper()) # type: ignore
|
|
103
|
+
except* Exception as eg:
|
|
104
|
+
for e in eg.exceptions:
|
|
105
|
+
if not isinstance(e, asyncio.CancelledError):
|
|
106
|
+
self.exceptions.append(e)
|
|
107
|
+
|
|
108
|
+
async def run_task(
|
|
109
|
+
self,
|
|
110
|
+
target: TaskTarget,
|
|
111
|
+
args: tuple = (),
|
|
112
|
+
kwargs: Optional[dict[str, Any]] = None,
|
|
113
|
+
) -> Any:
|
|
114
|
+
"""
|
|
115
|
+
快捷单任务调用
|
|
116
|
+
"""
|
|
117
|
+
self.tasks.clear()
|
|
118
|
+
self.results.clear()
|
|
119
|
+
self.exceptions.clear()
|
|
120
|
+
self.add_task(target, args=args, kwargs=kwargs)
|
|
121
|
+
await self.run_all()
|
|
122
|
+
if self.exceptions:
|
|
123
|
+
raise self.exceptions[0]
|
|
124
|
+
return self.results[0]
|
tomskit/tools/README.md
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
# FastAI Toolkit - Worker 模块
|
|
2
|
+
|
|
3
|
+
该模块提供了与 Redis 交互的功能,用于管理和监控 `uvicorn.workers.UvicornWorker` 的进程信息。以下是模块中可用的函数及其用途:
|
|
4
|
+
|
|
5
|
+
## 函数
|
|
6
|
+
|
|
7
|
+
### `worker_register_to_redis(redis: Redis, hostname: str, pid: int)`
|
|
8
|
+
|
|
9
|
+
- **描述**:
|
|
10
|
+
- 在 `gunicorn` 启动时,将 `uvicorn.workers.UvicornWorker` 的进程信息注册到 Redis。
|
|
11
|
+
|
|
12
|
+
- **参数**:
|
|
13
|
+
- `redis`: Redis 客户端实例。
|
|
14
|
+
- `hostname`: 主机名。
|
|
15
|
+
- `pid`: 进程 ID。
|
|
16
|
+
|
|
17
|
+
- **功能**:
|
|
18
|
+
- 将进程信息存储在 Redis 中,以便后续管理和监控。
|
|
19
|
+
|
|
20
|
+
### `worker_delete_from_redis(redis: Redis, hostname: str, pid: int)`
|
|
21
|
+
|
|
22
|
+
- **描述**:
|
|
23
|
+
- 在 `gunicorn` 关闭时,从 Redis 中删除 `uvicorn.workers.UvicornWorker` 的进程信息。
|
|
24
|
+
|
|
25
|
+
- **参数**:
|
|
26
|
+
- `redis`: Redis 客户端实例。
|
|
27
|
+
- `hostname`: 主机名。
|
|
28
|
+
- `pid`: 进程 ID。
|
|
29
|
+
|
|
30
|
+
- **功能**:
|
|
31
|
+
- 从 Redis 中移除进程信息,释放资源。
|
|
32
|
+
|
|
33
|
+
### `async worker_update_to_redis(hostname: str, pid: int, update_info: dict)`
|
|
34
|
+
|
|
35
|
+
- **描述**:
|
|
36
|
+
- 在 `uvicorn.workers.UvicornWorker` 中,更新进程信息到 Redis。
|
|
37
|
+
|
|
38
|
+
- **参数**:
|
|
39
|
+
- `hostname`: 主机名。
|
|
40
|
+
- `pid`: 进程 ID。
|
|
41
|
+
- `update_info`: 包含更新信息的字典,例如请求计数、异常计数等。
|
|
42
|
+
|
|
43
|
+
- **功能**:
|
|
44
|
+
- 更新 Redis 中的进程信息,保持数据的实时性;若该进程尚未注册(Redis 中不存在对应记录),则不做任何更新。
|
|
45
|
+
|
|
46
|
+
### `async get_all_worker_info_from_redis(hostname: str) -> dict`
|
|
47
|
+
|
|
48
|
+
- **描述**:
|
|
49
|
+
- 从 Redis 中获取所有的进程信息,并返回一个字典。
|
|
50
|
+
|
|
51
|
+
- **参数**:
|
|
52
|
+
- `hostname`: 主机名。
|
|
53
|
+
|
|
54
|
+
- **返回**:
|
|
55
|
+
- 包含所有进程信息的字典。
|
|
56
|
+
|
|
57
|
+
- **功能**:
|
|
58
|
+
- 提供对所有注册进程的全面监控。
|
|
59
|
+
|
|
60
|
+
## 注意事项
|
|
61
|
+
|
|
62
|
+
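- 确保在使用这些函数前,已正确初始化 Redis 客户端:`worker_update_to_redis` 与 `get_all_worker_info_from_redis` 使用全局 `redis_client`,未初始化时会抛出 `RuntimeError`;另外两个函数使用调用方传入的 Redis 客户端。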
|
|
63
|
+
- 这些函数主要用于管理和监控 `uvicorn` 工作进程,适用于需要实时更新和获取进程状态的场景。
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
|
|
2
|
+
from tomskit.tools.config import GunicornSettings
|
|
3
|
+
from tomskit.tools.warnings import enable_unawaited_warning
|
|
4
|
+
from tomskit.tools.woker import (
|
|
5
|
+
worker_register_to_redis,
|
|
6
|
+
worker_delete_from_redis,
|
|
7
|
+
worker_update_to_redis,
|
|
8
|
+
get_all_worker_info_from_redis )
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
# Public API of tomskit.tools.
__all__ = [
    # worker bookkeeping in Redis
    "worker_register_to_redis",
    "worker_delete_from_redis",
    "worker_update_to_redis",
    "get_all_worker_info_from_redis",
    # configuration
    "GunicornSettings",
    # diagnostics
    "enable_unawaited_warning",
]
|
tomskit/tools/config.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import List, Optional
|
|
3
|
+
from pydantic import Field, field_validator
|
|
4
|
+
from pydantic_settings import BaseSettings
|
|
5
|
+
|
|
6
|
+
class GunicornSettings(BaseSettings):
    """Configuration settings for Gunicorn.

    Values are loaded through ``pydantic_settings.BaseSettings`` (environment
    variables by default).
    """

    bind: str = Field(
        default="0.0.0.0:5001",
        description="Bind address and port to listen on"
    )

    pidfile: Optional[str] = Field(
        default=None,
        description="File to write the PID of the main process"
    )

    proc_name: Optional[str] = Field(
        default=None,
        description="Name of the process"
    )

    workers: int = Field(
        default=0,
        # pydantic does not run validators on defaults unless asked; without
        # this, the default 0 would never be resolved to the CPU count.
        validate_default=True,
        description="Number of workers to run (0 = auto by CPU cores)"
    )

    cpu_affinity: List[int] = Field(
        default_factory=list,
        description="List of CPU core IDs to bind workers to"
    )

    daemon: bool = Field(
        default=False,
        description="Run as a daemon"
    )

    @field_validator("workers", mode="before")
    @classmethod
    def default_workers(cls, v):
        """Resolve an empty, zero or negative worker count to the CPU core count.

        Returns ``os.cpu_count()`` (or 1 when it is unknown) in that case,
        otherwise ``int(v)``.
        """
        if not v or int(v) <= 0:
            return os.cpu_count() or 1
        return int(v)

    @field_validator("cpu_affinity", mode="before")
    @classmethod
    def parse_cpu_affinity(cls, v):
        """Accept a comma-separated string (``"0,1,3"``) or a list/tuple of core ids.

        ``None``, a blank string or any other unsupported type yields ``[]``.
        Element types are validated as ``int`` by pydantic after this hook.

        NOTE(review): pydantic-settings may JSON-decode list-typed values read
        from the environment before this "before" validator runs, in which case
        a plain ``0,1`` env value could be rejected at the source level —
        confirm against the installed pydantic-settings version (``NoDecode``).

        Raises:
            ValueError: if a string contains a non-integer entry.
        """
        if v is None or (isinstance(v, str) and not v.strip()):
            return []
        if isinstance(v, str):
            try:
                return [int(x.strip()) for x in v.split(",") if x.strip()]
            except ValueError as e:
                raise ValueError(f"Invalid CPU_AFFINITY string: {v}") from e
        if isinstance(v, list):
            return v
        if isinstance(v, tuple):
            # Tuples are valid List[int] input for pydantic; keep their items.
            return list(v)
        return []

    @field_validator("pidfile", mode="after")
    @classmethod
    def ensure_pidfile_dir_exists(cls, v: Optional[str]) -> Optional[str]:
        """Normalize an empty pidfile to ``None`` and create its parent directory.

        Note: this touches the filesystem while the settings object is built.
        """
        if not v:
            return None
        dir_path = os.path.dirname(v)
        if dir_path and not os.path.exists(dir_path):
            os.makedirs(dir_path, exist_ok=True)
        return v
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
import warnings
|
|
5
|
+
import logging
|
|
6
|
+
|
|
7
|
+
logger = logging.getLogger(__name__)


def enable_unawaited_warning(depth: int = 5) -> None:
    """Route RuntimeWarnings (e.g. "coroutine ... was never awaited") to the logger.

    Installs a process-wide ``warnings.showwarning`` hook.  Every
    ``RuntimeWarning`` (including subclasses, not only un-awaited coroutines)
    is logged with ``logger.critical`` together with its file name and line
    number.  All other categories are forwarded to whichever ``showwarning``
    was active when this function was called.

    It also enables coroutine origin tracking, which records where each
    coroutine was created (``cr_origin``).  It sets the ``RuntimeWarning``
    filter to ``"always"`` so repeats from the same location are not
    suppressed.

    Args:
        depth: Number of frames recorded by coroutine origin tracking, passed
            to ``sys.set_coroutine_origin_tracking_depth`` (available since
            Python 3.7).  Defaults to 5.

    Raises:
        ValueError: if ``depth`` is negative (raised by ``sys``).
    """
    # Keep the current hook so non-RuntimeWarning categories behave as before.
    _orig_showwarning = warnings.showwarning

    def _custom_warning_handler(
        message,
        category,
        filename,
        lineno,
        file=None,
        line=None
    ):
        if issubclass(category, RuntimeWarning):
            logger.critical("(File: %s, Line: %s) %s", filename, lineno, message)
        else:
            _orig_showwarning(message, category, filename, lineno, file, line)

    sys.set_coroutine_origin_tracking_depth(depth)
    warnings.simplefilter("always", RuntimeWarning)
    warnings.showwarning = _custom_warning_handler
|
tomskit/tools/woker.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
from redis import Redis
|
|
4
|
+
from tomskit.redis import redis_client
|
|
5
|
+
|
|
6
|
+
def worker_register_to_redis(redis: Redis, hostname: str, pid: int):
    """Register a UvicornWorker process in Redis when gunicorn starts it.

    Adds the pid to the ``{hostname}:workers`` set and writes the initial
    status hash under ``{hostname}:worker:{pid}``.

    Args:
        redis: Synchronous Redis client.
        hostname: Host name used as the key prefix.
        pid: Worker process id.
    """
    pid_str = str(pid)

    def _epoch_seconds() -> int:
        return int(datetime.now().timestamp())

    process_info: dict = dict(
        pid=pid_str,
        start_at=_epoch_seconds(),
        uptime=0,
        request_count=0,
        exception_count=0,
        server_error_count=0,
        status="starting",
        last_update_by="gunicorn",
        last_update_at=_epoch_seconds(),
    )
    workers_key = f"{hostname}:workers"
    info_key = f"{hostname}:worker:{pid_str}"
    redis.sadd(workers_key, pid_str)
    redis.hset(info_key, mapping=process_info)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def worker_delete_from_redis(redis: Redis, hostname: str, pid: int):
    """Remove a UvicornWorker's registration from Redis when gunicorn shuts it down.

    Removes the pid from the ``{hostname}:workers`` set and deletes its
    ``{hostname}:worker:{pid}`` status hash.
    """
    member = str(pid)
    workers_key = f"{hostname}:workers"
    info_key = f"{hostname}:worker:{member}"
    redis.srem(workers_key, member)
    redis.delete(info_key)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
async def worker_update_to_redis(hostname: str, pid: int, update_info: dict):
    """Refresh a UvicornWorker's status hash in Redis from inside the worker.

    Only registered workers (whose ``{hostname}:worker:{pid}`` hash exists)
    are updated; otherwise this is a no-op.  The hash receives these fields:

    - ``uptime`` (seconds since ``start_at``), when ``start_at`` is present;
    - ``last_update_at``;
    - ``status="running"``;
    - ``last_update_by``;
    - every entry of ``update_info``, which wins on key collisions.

    Only these fields are written, so other fields of the hash are not
    re-written with possibly stale copies.

    Args:
        hostname: Host name used as the key prefix.
        pid: Worker process id.
        update_info: Extra fields to store, e.g.
            ``{"request_count": int, "exception_count": int, "server_error_count": int}``.

    Raises:
        RuntimeError: if the shared ``redis_client`` has not been initialized.
    """
    redis = redis_client._client
    if redis is None:
        raise RuntimeError("Redis client is not initialized. Call initialize first.")

    worker_key = f"{hostname}:worker:{str(pid)}"

    current_info = await redis.hgetall(worker_key)  # type: ignore
    if not current_info:
        # Not registered (or already removed): do not recreate a partial hash.
        return

    # Read the clock once so uptime and last_update_at agree.
    now = int(datetime.now().timestamp())
    changes: dict = {
        "last_update_at": now,
        "status": "running",
        "last_update_by": "uvicorn.tomskit",
    }
    # NOTE(review): assumes the client decodes responses to str keys.
    start_at = current_info.get("start_at")
    if start_at is not None:
        changes["uptime"] = now - int(start_at)
    changes.update(update_info)

    await redis.hset(worker_key, mapping=changes)  # type: ignore
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
async def get_all_worker_info_from_redis(hostname: str) -> dict:
    """Collect the status hash of every worker registered for ``hostname``.

    Returns:
        A dict mapping each member of the ``{hostname}:workers`` set (as
        returned by Redis) to a dict copy of its ``{hostname}:worker:{pid}``
        hash.  Members whose hash is missing or empty are omitted.

    Raises:
        RuntimeError: if the shared ``redis_client`` has not been initialized.
    """
    client = redis_client._client
    if client is None:
        raise RuntimeError("Redis client is not initialized. Call initialize first.")

    members = await client.smembers(f"{hostname}:workers")  # type: ignore

    workers: dict = {}
    for member in members:
        info = await client.hgetall(f"{hostname}:worker:{member}")  # type: ignore
        if not info:
            continue
        workers[member] = dict(info)
    return workers
|