fs-pyutils 1.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fs_pyutils-1.1.0/PKG-INFO +12 -0
- fs_pyutils-1.1.0/README.md +3 -0
- fs_pyutils-1.1.0/pyproject.toml +29 -0
- fs_pyutils-1.1.0/src/fs_pyutils/__init__.py +0 -0
- fs_pyutils-1.1.0/src/fs_pyutils/audio.py +45 -0
- fs_pyutils-1.1.0/src/fs_pyutils/gunicorn_logger.py +73 -0
- fs_pyutils-1.1.0/src/fs_pyutils/import.py +56 -0
- fs_pyutils-1.1.0/src/fs_pyutils/log_builder.py +183 -0
- fs_pyutils-1.1.0/src/fs_pyutils/py.typed +0 -0
- fs_pyutils-1.1.0/src/fs_pyutils/systemd_notifier.py +104 -0
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: fs-pyutils
|
|
3
|
+
Version: 1.1.0
|
|
4
|
+
Summary: Python utils for fseasy scope.
|
|
5
|
+
Author: fseasy
|
|
6
|
+
Author-email: fseasy <xuwei.fs@outlook.com>
|
|
7
|
+
Requires-Python: >=3.12
|
|
8
|
+
Description-Content-Type: text/markdown
|
|
9
|
+
|
|
10
|
+
## FS pyutils
|
|
11
|
+
|
|
12
|
+
Python Utils for fseasy scope.
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "fs-pyutils"
|
|
3
|
+
version = "1.1.0"
|
|
4
|
+
description = "Python utils for fseasy scope."
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
authors = [{ name = "fseasy", email = "xuwei.fs@outlook.com" }]
|
|
7
|
+
requires-python = ">=3.12"
|
|
8
|
+
dependencies = []
|
|
9
|
+
|
|
10
|
+
[build-system]
|
|
11
|
+
requires = ["uv_build>=0.9.25,<0.10.0"]
|
|
12
|
+
build-backend = "uv_build"
|
|
13
|
+
|
|
14
|
+
[tool.hatch.build.targets.wheel]
|
|
15
|
+
packages = ["src/fs_pyutils"]
|
|
16
|
+
|
|
17
|
+
[tool.ruff]
|
|
18
|
+
line-length = 120
|
|
19
|
+
indent-width = 2
|
|
20
|
+
exclude = []
|
|
21
|
+
|
|
22
|
+
[tool.ruff.lint]
|
|
23
|
+
# 确保 "I" (isort) 规则在启用列表中
|
|
24
|
+
select = ["E", "F", "B", "UP", "I"]
|
|
25
|
+
|
|
26
|
+
[tool.ruff.lint.isort]
|
|
27
|
+
# 这里可以像 isort 一样自定义配置
|
|
28
|
+
force-single-line = false # 是否强制单行导入
|
|
29
|
+
combine-as-imports = true # 是否合并 import as
|
|
File without changes
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import shutil
|
|
2
|
+
import subprocess
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def audio_to_mp3_bytes(
|
|
6
|
+
audio_bytes: bytes,
|
|
7
|
+
bitrate: str = "128k", #
|
|
8
|
+
sample_rate: int | None = None,
|
|
9
|
+
channels: int | None = None,
|
|
10
|
+
ffmpeg_bin_path: str | None = None,
|
|
11
|
+
) -> bytes:
|
|
12
|
+
"""Any audio bytes to mp3 bytes
|
|
13
|
+
Args:
|
|
14
|
+
- bitrate: like 64k 96k 128k 192k
|
|
15
|
+
- sample_rate: like 16000 22050 24000 44100
|
|
16
|
+
- channels: like 1, 2
|
|
17
|
+
- ffmpeg_bin_path: default = "ffmpeg"
|
|
18
|
+
"""
|
|
19
|
+
ffmpeg_bin_path = ffmpeg_bin_path or "ffmpeg"
|
|
20
|
+
if not shutil.which(ffmpeg_bin_path):
|
|
21
|
+
raise RuntimeError("No ffmpeg found in env")
|
|
22
|
+
|
|
23
|
+
cmd = [ffmpeg_bin_path, "-i", "pipe:0", "-codec:a", "libmp3lame", "-b:a", bitrate]
|
|
24
|
+
|
|
25
|
+
if sample_rate:
|
|
26
|
+
cmd += ["-ar", str(sample_rate)]
|
|
27
|
+
|
|
28
|
+
if channels:
|
|
29
|
+
cmd += ["-ac", str(channels)]
|
|
30
|
+
|
|
31
|
+
cmd += ["-f", "mp3", "pipe:1"]
|
|
32
|
+
|
|
33
|
+
proc = subprocess.Popen(
|
|
34
|
+
cmd,
|
|
35
|
+
stdin=subprocess.PIPE,
|
|
36
|
+
stdout=subprocess.PIPE,
|
|
37
|
+
stderr=subprocess.PIPE,
|
|
38
|
+
)
|
|
39
|
+
|
|
40
|
+
out, err = proc.communicate(audio_bytes)
|
|
41
|
+
|
|
42
|
+
if proc.returncode != 0:
|
|
43
|
+
raise RuntimeError(err.decode())
|
|
44
|
+
|
|
45
|
+
return out
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
#! note1: this module must be used in a Python environment that has gunicorn installed (a reasonable requirement)
|
|
6
|
+
from gunicorn.config import Config
|
|
7
|
+
from gunicorn.glogging import Logger as GLogger
|
|
8
|
+
|
|
9
|
+
from .log_builder import JsonSyslogFormatter, NginxAlignedSyslogHandler
|
|
10
|
+
|
|
11
|
+
#! note2: you must set `SYSLOG_ADDRESS` in ENV, in format `ip:port`, like: "127.0.0.1:5140"
|
|
12
|
+
g_syslog_addr_str = os.environ["SYSLOG_ADDRESS"]
|
|
13
|
+
#! it's better to also set the hostname (via the `HOSTNAME` env var).
|
|
14
|
+
g_hostname = os.getenv("HOSTNAME", "gunicorn-default")
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class GunicornSyslogLogger(GLogger):
  """Gunicorn logger class that also ships error and access logs to syslog as JSON.

  The syslog target is read from the module-level ``g_syslog_addr_str``
  ("ip:port") and the reported host name from ``g_hostname``.
  """

  def setup(self, cfg: Config) -> None:
    """Run gunicorn's own logging setup, then attach the syslog handler.

    Args:
      cfg: the gunicorn configuration; ``cfg.loglevel`` selects the level.
    """
    super().setup(cfg)

    host, port_text = g_syslog_addr_str.split(":")
    target = (host, int(port_text))

    # Inherit gunicorn's log level: cfg.loglevel is a string, handlers need the int.
    numeric_level = self.LOG_LEVELS.get(cfg.loglevel.lower(), logging.INFO)

    # --- Build the handler and its formatter ---
    shipper = NginxAlignedSyslogHandler(
      address=target,
      hostname=g_hostname,
      facility=23,  # Local7
    )
    shipper.setFormatter(JsonSyslogFormatter(g_hostname))
    shipper.setLevel(numeric_level)

    # --- Bind the handler to gunicorn's two core loggers ---
    # 1. Error log (gunicorn.error): the default handlers stay, ours is appended.
    error_logger = self.error_log
    error_logger.addHandler(shipper)
    error_logger.setLevel(numeric_level)

    # 2. Access log (gunicorn.access): drop the default handlers first so the
    #    access lines are not printed to stdout a second time.
    access_logger = self.access_log
    access_logger.handlers = []
    access_logger.addHandler(shipper)
    access_logger.setLevel(numeric_level)
    access_logger.propagate = False

  def access(self, resp, req, environ, request_time) -> None:  # type: ignore
    """Log one access line, passing the gunicorn atoms through ``extra``.

    This overrides gunicorn's default behaviour: the atoms dict that would
    normally be used as formatting args is handed over as ``extra`` instead,
    so each atom becomes an attribute of the log record and can be emitted as
    a separate JSON field by the formatter.
    """
    # The message text is a fixed "access"; the request details travel in the atoms.
    request_atoms = self.atoms(resp, req, environ, request_time)
    self.access_log.info("access", extra=request_atoms)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class GunicornJsonFormatter(JsonSyslogFormatter):
  """JSON formatter that renames gunicorn's one-letter access atoms to readable keys."""

  def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]:
    """Build the base log dict, then replace short atom keys with long names.

    Only keys that are present are renamed; all other entries are untouched.
    """
    payload = super()._build_log_dict(record)

    # short atom key -> readable field name
    readable_names = {"h": "remote_ip", "m": "method", "U": "path", "s": "status", "M": "duration_ms", "a": "user_agent"}

    for atom_key, field_name in readable_names.items():
      if atom_key not in payload:
        continue
      payload[field_name] = payload.pop(atom_key)

    return payload
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
from importlib.util import module_from_spec, spec_from_file_location
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from types import ModuleType
|
|
5
|
+
from typing import Union
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def import_module_from_path(
  module_name: str,
  file_path: Union[str, Path],
  *,
  register: bool = True,
) -> ModuleType:
  """
  Dynamically import a Python module from a file path.

  Parameters
  ----------
  module_name:
      Name to assign to the module (used in sys.modules).

  file_path:
      Path to the .py file.

  register:
      Whether to register the module in sys.modules. The module is registered
      before its code runs; if executing the module raises, the entry is
      removed again so no half-initialised module is left behind.

  Returns
  -------
  ModuleType
      The imported module.

  Raises
  ------
  FileNotFoundError
      If ``file_path`` does not exist.
  ImportError
      If no import spec / loader can be created for the path.
  Exception
      Whatever the module's own top-level code raises is propagated unchanged.
  """

  path = Path(file_path)

  if not path.exists():
    raise FileNotFoundError(path)

  spec = spec_from_file_location(module_name, path)

  if spec is None or spec.loader is None:
    raise ImportError(f"Cannot load module from {path}")

  module = module_from_spec(spec)

  if register:
    sys.modules[module_name] = module

  try:
    spec.loader.exec_module(module)
  except BaseException:
    # Mirror the regular import machinery: a module whose body failed must not
    # stay importable. Only drop the entry if it is still the one we added.
    if register and sys.modules.get(module_name) is module:
      del sys.modules[module_name]
    raise

  return module
|
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
import socket
|
|
4
|
+
import sys
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from logging.handlers import SysLogHandler
|
|
7
|
+
|
|
8
|
+
# import traceback
|
|
9
|
+
from typing import Any
|
|
10
|
+
from urllib.parse import urlparse
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def build_logger(
  name: str, level: int, syslog_address: tuple[str, int] | None = None, domain: str | None = None
) -> logging.Logger:
  """Return the named logger with a stderr handler and, optionally, a syslog handler.

  Args:
    name: logger name handed to ``logging.getLogger``.
    level: level applied to the logger and to every handler added here.
    syslog_address: used for syslog (you can setup a grafana alloy), will send a json log.
      When falsy, no syslog handler is added.
    domain: used in the json logger; it is converted to the host name reported to syslog.

  Returns:
    The configured logger. A failure while setting up syslog is logged as a
    warning and does not prevent the logger from being returned.
  """
  logger = logging.getLogger(name)

  # Console output: one line per record on stderr.
  console = logging.StreamHandler(stream=sys.stderr)
  console.setFormatter(SingleLineFormatter("%(asctime)s/%(name)s/%(levelname)s/%(filename)s:%(lineno)d> %(message)s"))
  console.setLevel(level)

  logger.addHandler(console)
  logger.setLevel(level)

  if not syslog_address:
    return logger

  # Syslog output: JSON payload inside an RFC 3164 frame.
  try:
    hostname = _domain2hostname(domain)
    shipper = NginxAlignedSyslogHandler(
      address=syslog_address,
      hostname=hostname,
      facility=SysLogHandler.LOG_LOCAL7,
    )
    shipper.setFormatter(JsonSyslogFormatter(hostname))
    shipper.setLevel(level)
    logger.addHandler(shipper)
  except Exception as e:
    logger.warning(f"Failed to add syslog, err={e}")
  return logger
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class SingleLineFormatter(logging.Formatter):
  """Formatter that keeps every record on one line and appends ``extra`` fields."""

  def format(self, record: logging.LogRecord) -> str:
    """Format the record, replace newlines with " ↵ " and append extras as JSON."""
    flattened = super().format(record).replace("\n", " ↵ ")

    # Naively append whatever was passed through extra={...}.
    extras = _get_extra_kv(record)
    if not extras:
      return flattened

    extras_json = json.dumps(extras, ensure_ascii=False, default=str)
    return f"{flattened} extra={extras_json}"
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class JsonSyslogFormatter(logging.Formatter):
  """A formatter for syslog that renders each record as one JSON object."""

  def __init__(self, host: str):
    """Remember the host name that is written into every log entry."""
    super().__init__(fmt="%(message)s")
    self._host = host

  def _build_log_dict(self, record: logging.LogRecord) -> dict[str, Any]:
    """Collect the JSON fields for ``record``.

    Fixed fields come first, then the traceback when exception info is
    attached, and finally everything passed through ``extra={...}`` (which is
    merged last and therefore overrides same-named fields).
    """
    created_at = datetime.fromtimestamp(record.created).astimezone()

    payload: dict[str, Any] = {
      "host": self._host,
      "time": created_at.isoformat(timespec="seconds"),  # to align the loki receiver
      "level": record.levelname,
      "logger": record.name,
      "file": f"{record.filename}:{record.lineno}",
      "msg": record.getMessage(),
    }
    if record.exc_info:
      payload["traceback"] = self.formatException(record.exc_info)

    payload.update(_get_extra_kv(record))
    return payload

  def format(self, record: logging.LogRecord) -> str:
    """Serialise the log dict to JSON; non-JSON values are rendered with ``str``."""
    return json.dumps(self._build_log_dict(record), default=str)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class NginxAlignedSyslogHandler(logging.Handler):
  """Send log records as RFC 3164 syslog datagrams over UDP, laid out like Nginx's.

  Each datagram has the form ``<PRI>Mmm dd hh:mm:ss HOSTNAME TAG: MSG``. The
  handler builds the header itself so that single-digit days are padded with a
  space rather than a zero, which the stock Python ``SysLogHandler`` setup did
  not produce in the layout the receiver expects.
  """

  # Python logging level -> syslog severity.
  _SEVERITY_BY_LEVEL = {
    logging.DEBUG: 7,
    logging.INFO: 6,
    logging.WARNING: 4,
    logging.ERROR: 3,
    logging.CRITICAL: 2,
  }
  # Severity used for levels that are not in the table above (informational).
  _DEFAULT_SEVERITY = 6
  # RFC 3164 requires English month abbreviations; strftime("%b") follows the
  # process locale, so the names are spelled out here.
  _MONTH_ABBR = ("Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec")

  def __init__(self, address: tuple[str, int], hostname: str, facility: int = 23):
    """Create the handler and its UDP socket.

    Args:
      address: ``(host, port)`` of the syslog receiver.
      hostname: value written into the HOSTNAME field; also the base of the TAG.
      facility: syslog facility number; Nginx mostly uses Local7 (23).
    """
    super().__init__()
    self.address = address
    # Replace characters that are not valid in a TAG, keeping it as clean as Nginx's (e.g. site_docgate).
    self.app_name = f"{hostname.replace('.', '_').replace('-', '_')}_fastapi"
    self.hostname = hostname
    self.facility = facility
    self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

  def emit(self, record: logging.LogRecord) -> None:
    """Format ``record`` and send it as one UDP datagram.

    Any failure is routed to ``handleError`` instead of being raised.
    """
    try:
      # 1. The formatted message; leading whitespace is stripped so the header layout stays intact.
      msg = self.format(record).lstrip()

      # 2. PRI = facility * 8 + severity
      severity = self._SEVERITY_BY_LEVEL.get(record.levelno, self._DEFAULT_SEVERITY)
      pri = (self.facility * 8) + severity

      # 3. RFC 3164 timestamp "Mmm dd hh:mm:ss", taken from the moment the
      #    record was created. A single-digit day is padded with a space
      #    ("Feb  3"), never with a zero ("Feb 03").
      event_time = datetime.fromtimestamp(record.created)
      month = self._MONTH_ABBR[event_time.month - 1]
      day_str = f"{event_time.day:>2}"
      time_str = event_time.strftime("%H:%M:%S")
      timestamp = f"{month} {day_str} {time_str}"

      # 4. Assemble exactly like Nginx: <PRI>TIMESTAMP HOSTNAME TAG: MSG
      #    (the TAG is followed directly by a colon and one space).
      syslog_msg = f"<{pri}>{timestamp} {self.hostname} {self.app_name}: {msg}\n"

      # 5. Send the UDP datagram.
      self.sock.sendto(syslog_msg.encode("utf-8"), self.address)
    except Exception:
      self.handleError(record)

  def close(self) -> None:
    """Close the UDP socket, then run the base handler's close."""
    try:
      self.sock.close()
    finally:
      super().close()
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
# Attribute names that the logging machinery itself sets on a LogRecord; anything
# else found on a record is treated as user-supplied ``extra`` data.
_RESERVED_LOG_RECORD_ATTRS = frozenset(
  {
    "args",
    "asctime",
    "created",
    "exc_info",
    "exc_text",
    "filename",
    "funcName",
    "levelname",
    "levelno",
    "lineno",
    "message",
    "module",
    "msecs",
    "msg",
    "name",
    "pathname",
    "process",
    "processName",
    "relativeCreated",
    "stack_info",
    "thread",
    "threadName",
    "taskName",
  }
)


def _get_extra_kv(record: logging.LogRecord) -> dict[str, Any]:
  """Return the record attributes that are not in the reserved-name set.

  Args:
    record: the log record to inspect.

  Returns:
    A new dict of every non-reserved attribute on the record (typically what
    was passed via ``extra={...}``), empty when there is none.
  """
  return {key: value for key, value in vars(record).items() if key not in _RESERVED_LOG_RECORD_ATTRS}
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _domain2hostname(domain: str | None) -> str:
|
|
177
|
+
_DEFAULT_HOST = "default_app"
|
|
178
|
+
if not domain:
|
|
179
|
+
return _DEFAULT_HOST
|
|
180
|
+
host = urlparse(domain if "://" in domain else "https://" + domain).hostname
|
|
181
|
+
if not host:
|
|
182
|
+
host = _DEFAULT_HOST
|
|
183
|
+
return host
|
|
File without changes
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
"""Systemd notifier, should be used with fastapi.
|
|
2
|
+
1. intercept the gunicorn default `READY` signal (by pop env var: NOTIFY_SOCKET)
|
|
3
|
+
2. export a lifespan manager with following features:
|
|
4
|
+
- notify READY/STOPPING
|
|
5
|
+
- db watchdog for systemd.
|
|
6
|
+
|
|
7
|
+
To enable systemd watchdog:
|
|
8
|
+
1. in systemd service. set:
|
|
9
|
+
Type=notify
|
|
10
|
+
|
|
11
|
+
WatchdogSec=45(or any other value)
|
|
12
|
+
Restart=always
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
import asyncio
|
|
16
|
+
import logging
|
|
17
|
+
import os
|
|
18
|
+
import socket
|
|
19
|
+
from collections.abc import AsyncGenerator
|
|
20
|
+
from contextlib import asynccontextmanager
|
|
21
|
+
from typing import Any
|
|
22
|
+
|
|
23
|
+
from fastapi import FastAPI
|
|
24
|
+
from sqlalchemy.ext.asyncio import AsyncEngine
|
|
25
|
+
|
|
26
|
+
_CUSTOM_NOTIFY_SOCKET_VAR_NAME = "CUSTOM_NOTIFY_SOCKET"


def intercept_server_ready_signal() -> None:
  """Move ``NOTIFY_SOCKET`` out of the environment into our own variable.

  Per the module notes, gunicorn sends READY=1 on its own once it sees
  ``NOTIFY_SOCKET``; calling this before gunicorn initialises hides the
  variable from it while keeping the address available to this module.
  """
  notify_socket = os.environ.pop("NOTIFY_SOCKET", None)
  if not notify_socket:
    return
  # Keep the address under a name gunicorn does not look at but this module does.
  os.environ[_CUSTOM_NOTIFY_SOCKET_VAR_NAME] = notify_socket
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@asynccontextmanager
async def systemd_notifier_lifespan(
  app: FastAPI, async_db_engine: AsyncEngine, logger: logging.Logger | None = None
) -> AsyncGenerator[Any, None]:
  """Lifespan manager that reports READY/STOPPING to systemd and runs the db watchdog.

  Please call this after you do anything else! because it'll send `READY/STOPPING` signal.

  Args:
    app: the FastAPI application (unused).
    async_db_engine: engine the watchdog task uses for its health check.
    logger: optional logger for progress and error messages.
  """
  del app
  # Start the watchdog background task.
  task = asyncio.create_task(_db_watchdog_task(async_db_engine, logger=logger))
  # Report successful start-up.
  _send_sd_notify("READY=1", logger=logger)
  try:
    yield
  finally:
    _send_sd_notify("STOPPING=1", logger=logger)

    # Task.cancel() only requests cancellation and never raises; the task must
    # be awaited so it has actually finished (and its result is retrieved)
    # before we report it as stopped.
    task.cancel()
    try:
      await task
    except asyncio.CancelledError:
      pass  # the expected outcome of cancelling the task
    except Exception as e:
      # The task may already have died on its own; do not let that break shutdown.
      if logger:
        logger.warning(f"systemd: db watchdog exited with error: {e}")
    if logger:
      logger.info("systemd: db watchdog stopped.")
    _send_sd_notify("STATUS=db watchdog stopped", logger=logger)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
async def _db_watchdog_task(async_db_engine: AsyncEngine, logger: logging.Logger | None = None) -> None:
  """Background task: periodically run a trivial SQL query and feed the systemd watchdog.

  The watchdog window is read from the ``WATCHDOG_USEC`` environment variable
  (microseconds). When it is missing or not a number the task logs a warning
  and returns without doing anything. Otherwise it loops forever, sending
  WATCHDOG=1 after every successful ``SELECT 1``; when the query fails the
  watchdog is deliberately not fed.

  Args:
    async_db_engine: engine used to open a connection for the health check.
    logger: optional logger for warnings and health-check failures.
  """
  interval_usec = os.getenv("WATCHDOG_USEC")
  if not interval_usec:
    if logger:
      logger.warning("db watchdog isn't set in systemd. skip watchdog task")
    return
  try:
    watchdog_sec = int(interval_usec) / 1_000_000
  except ValueError:
    if logger:
      logger.warning(f"invalid WATCHDOG_USEC={interval_usec!r}. skip watchdog task")
    return

  # Try to report 3 times in one watch duration. The interval stays fractional:
  # rounding to whole seconds gives 0 for short windows (a busy loop against
  # the DB) and can round up, shrinking the safety margin. The floor guards
  # against a zero or negative window.
  min_sleep_sec = 0.1
  watch_sleep_sec = max(watchdog_sec / 3, min_sleep_sec)
  _send_sd_notify(
    f"STATUS=prepared to create systemd watchdog task, report interval={watch_sleep_sec:g}s", logger=logger
  )
  while True:
    try:
      # 1. Health check: open a connection and run a trivial statement.
      async with async_db_engine.begin() as conn:
        await conn.exec_driver_sql("SELECT 1")

      # 2. Reaching this point means the event loop is alive and the DB
      #    answered, so notify systemd.
      _send_sd_notify("WATCHDOG=1", logger=logger)
      _send_sd_notify("STATUS=db watchdog send", logger=logger)  # test

    except Exception as e:
      # The DB check failed: deliberately skip feeding the watchdog so that
      # systemd restarts the service once the window expires.
      if logger:
        logger.exception(f"db watchdog health check failed: {e}")
      _send_sd_notify(f"STATUS=db watchdog health check failed, e={e}", logger=logger)

    await asyncio.sleep(watch_sleep_sec)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _send_sd_notify(data: str, logger: logging.Logger | None = None):
  """Send one notification string to the intercepted systemd notify socket.

  Does nothing when the custom socket variable is not set. Any failure while
  sending is logged as a warning (when a logger is given) and never raised.

  Args:
    data: the notification payload, e.g. "READY=1".
    logger: optional logger for send failures.
  """
  # Address that was stashed away when NOTIFY_SOCKET was intercepted.
  raw_addr = os.getenv(_CUSTOM_NOTIFY_SOCKET_VAR_NAME)
  if not raw_addr:
    return

  # A leading "@" is swapped for a NUL byte before connecting.
  target = "\0" + raw_addr[1:] if raw_addr.startswith("@") else raw_addr

  try:
    with socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) as notify_sock:
      notify_sock.connect(target)
      notify_sock.sendall(data.encode())
  except Exception as e:
    if logger:
      logger.warning(f"Failed to send watchdog signal to systemd: {e}")
|