sycommon-python-lib 0.1.55__py3-none-any.whl → 0.1.55a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,7 +15,7 @@ from kafka import KafkaProducer
15
15
  from loguru import logger
16
16
  import loguru
17
17
  from sycommon.config.Config import Config, SingletonMeta
18
- from sycommon.middleware.context import current_trace_id, current_headers
18
+ from sycommon.middleware.context import current_trace_id
19
19
  from sycommon.tools.snowflake import Snowflake
20
20
 
21
21
  # 配置Loguru的颜色方案
@@ -114,7 +114,7 @@ class KafkaLogger(metaclass=SingletonMeta):
114
114
  trace_id = None
115
115
 
116
116
  if not trace_id:
117
- trace_id = SYLogger.get_trace_id() or Snowflake.id
117
+ trace_id = SYLogger.get_trace_id() or Snowflake.next_id()
118
118
 
119
119
  # 获取线程/协程信息
120
120
  thread_info = SYLogger._get_execution_context()
@@ -173,7 +173,7 @@ class KafkaLogger(metaclass=SingletonMeta):
173
173
  "className": "",
174
174
  "sqlCost": 0,
175
175
  "size": len(str(message)),
176
- "uid": int(Snowflake.id) # 独立新的id
176
+ "uid": int(Snowflake.next_id()) # 独立新的id
177
177
  }
178
178
 
179
179
  # 智能队列管理
@@ -212,7 +212,7 @@ class KafkaLogger(metaclass=SingletonMeta):
212
212
  return
213
213
 
214
214
  # 获取当前的trace_id
215
- trace_id = SYLogger.get_trace_id() or Snowflake.id
215
+ trace_id = SYLogger.get_trace_id() or Snowflake.next_id()
216
216
 
217
217
  # 构建错误日志
218
218
  error_log = {
@@ -441,18 +441,6 @@ class SYLogger:
441
441
  """重置当前的 trace_id"""
442
442
  current_trace_id.reset(token)
443
443
 
444
- @staticmethod
445
- def get_headers():
446
- return current_headers.get()
447
-
448
- @staticmethod
449
- def set_headers(headers: list[tuple[str, str]]):
450
- return current_headers.set(headers)
451
-
452
- @staticmethod
453
- def reset_headers(token):
454
- current_headers.reset(token)
455
-
456
444
  @staticmethod
457
445
  def _get_execution_context() -> str:
458
446
  """获取当前执行上下文的线程或协程信息,返回格式化字符串"""
@@ -471,7 +459,7 @@ class SYLogger:
471
459
 
472
460
  @staticmethod
473
461
  def _log(msg: any, level: str = "INFO"):
474
- trace_id = SYLogger.get_trace_id() or Snowflake.id
462
+ trace_id = SYLogger.get_trace_id() or Snowflake.next_id()
475
463
 
476
464
  if isinstance(msg, dict) or isinstance(msg, list):
477
465
  msg_str = json.dumps(msg, ensure_ascii=False)
@@ -485,7 +473,7 @@ class SYLogger:
485
473
  request_log = {}
486
474
  if level == "ERROR":
487
475
  request_log = {
488
- "trace_id": str(trace_id) if trace_id else Snowflake.id,
476
+ "trace_id": str(trace_id) if trace_id else Snowflake.next_id(),
489
477
  "message": msg_str,
490
478
  "traceback": traceback.format_exc(),
491
479
  "level": level,
@@ -493,7 +481,7 @@ class SYLogger:
493
481
  }
494
482
  else:
495
483
  request_log = {
496
- "trace_id": str(trace_id) if trace_id else Snowflake.id,
484
+ "trace_id": str(trace_id) if trace_id else Snowflake.next_id(),
497
485
  "message": msg_str,
498
486
  "level": level,
499
487
  "threadName": thread_info
@@ -533,7 +521,7 @@ class SYLogger:
533
521
  @staticmethod
534
522
  def exception(msg: any, *args, **kwargs):
535
523
  """记录异常信息,包括完整堆栈"""
536
- trace_id = SYLogger.get_trace_id() or Snowflake.id
524
+ trace_id = SYLogger.get_trace_id() or Snowflake.next_id()
537
525
 
538
526
  if isinstance(msg, dict) or isinstance(msg, list):
539
527
  msg_str = json.dumps(msg, ensure_ascii=False)
@@ -545,7 +533,7 @@ class SYLogger:
545
533
 
546
534
  # 构建包含异常堆栈的日志
547
535
  request_log = {
548
- "trace_id": str(trace_id) if trace_id else Snowflake.id,
536
+ "trace_id": str(trace_id) if trace_id else Snowflake.next_id(),
549
537
  "message": msg_str,
550
538
  "level": "ERROR",
551
539
  "threadName": thread_info
@@ -1,5 +1,3 @@
1
1
  import contextvars
2
2
 
3
3
  current_trace_id = contextvars.ContextVar("trace_id", default=None)
4
-
5
- current_headers = contextvars.ContextVar("headers", default=None)
@@ -3,53 +3,33 @@ import re
3
3
  from typing import Dict, Any
4
4
  from fastapi import Request, Response
5
5
  from sycommon.logging.kafka_log import SYLogger
6
- from sycommon.tools.merge_headers import merge_headers
7
6
  from sycommon.tools.snowflake import Snowflake
8
7
 
9
8
 
10
9
  def setup_trace_id_handler(app):
11
10
  @app.middleware("http")
12
11
  async def trace_id_and_log_middleware(request: Request, call_next):
13
- # ========== 1. 请求阶段:确保获取/生成 x-traceId-header ==========
14
- # 优先从请求头读取(兼容任意大小写)
15
- trace_id = request.headers.get(
16
- "x-traceId-header") or request.headers.get("x-traceid-header")
17
- # 无则生成雪花ID
12
+ # 生成或获取 traceId
13
+ trace_id = request.headers.get("x-traceId-header")
18
14
  if not trace_id:
19
- trace_id = Snowflake.id
15
+ trace_id = Snowflake.next_id()
20
16
 
21
- # 设置 trace_id 到日志上下文
17
+ # 设置 trace_id 上下文
22
18
  token = SYLogger.set_trace_id(trace_id)
23
- header_token = SYLogger.set_headers(request.headers.raw)
24
19
 
25
20
  # 获取请求参数
26
21
  query_params = dict(request.query_params)
27
22
  request_body: Dict[str, Any] = {}
28
23
  files_info: Dict[str, str] = {}
29
24
 
30
- json_content_types = [
31
- "application/json",
32
- "text/plain;charset=utf-8",
33
- "text/plain"
34
- ]
25
+ # 检测请求内容类型
35
26
  content_type = request.headers.get("content-type", "").lower()
36
- is_json_content = any(ct in content_type for ct in json_content_types)
37
27
 
38
- if is_json_content and request.method in ["POST", "PUT", "PATCH"]:
28
+ if "application/json" in content_type and request.method in ["POST", "PUT", "PATCH"]:
39
29
  try:
40
- # 兼容纯文本格式的 JSON(先读文本再解析)
41
- if "text/plain" in content_type:
42
- raw_text = await request.text(encoding="utf-8")
43
- request_body = json.loads(raw_text)
44
- else:
45
- # application/json 直接解析
46
- request_body = await request.json()
30
+ request_body = await request.json()
47
31
  except Exception as e:
48
- try:
49
- request_body = await request.json()
50
- except Exception as e:
51
- # 精准捕获 JSON 解析错误(而非泛 Exception)
52
- request_body = {"error": f"JSON parse failed: {str(e)}"}
32
+ request_body = {"error": f"Failed to parse JSON: {str(e)}"}
53
33
 
54
34
  elif "multipart/form-data" in content_type and request.method in ["POST", "PUT"]:
55
35
  try:
@@ -82,9 +62,8 @@ def setup_trace_id_handler(app):
82
62
  request_body = {
83
63
  "error": f"Failed to process form data: {str(e)}"}
84
64
 
85
- # 构建请求日志(包含 traceId)
65
+ # 构建请求日志信息
86
66
  request_message = {
87
- "traceId": trace_id, # 请求日志中加入 traceId
88
67
  "method": request.method,
89
68
  "url": str(request.url),
90
69
  "query_params": query_params,
@@ -98,159 +77,68 @@ def setup_trace_id_handler(app):
98
77
  # 处理请求
99
78
  response = await call_next(request)
100
79
 
101
- # 获取响应Content-Type(统一小写)
102
- content_type = response.headers.get("content-type", "").lower()
80
+ content_type = response.headers.get("Content-Type", "")
103
81
 
104
- # ========== 2. SSE 响应:仅设置 x-traceId-header,不修改其他头 ==========
82
+ # 处理 SSE 响应
105
83
  if "text/event-stream" in content_type:
106
- try:
107
- # 强制写入 x-traceId-header 到响应头
108
- response.headers["x-traceId-header"] = trace_id
109
- # 确保前端能读取(仅补充暴露头,不覆盖原有值)
110
- expose_headers = response.headers.get(
111
- "access-control-expose-headers", "")
112
- if expose_headers:
113
- if "x-traceId-header" not in expose_headers.lower():
114
- response.headers[
115
- "access-control-expose-headers"] = f"{expose_headers}, x-traceId-header"
116
- else:
117
- response.headers["access-control-expose-headers"] = "x-traceId-header"
118
- # SSE 必须移除 Content-Length(仅这一个额外操作)
119
- headers_lower = {
120
- k.lower(): k for k in response.headers.keys()}
121
- if "content-length" in headers_lower:
122
- del response.headers[headers_lower["content-length"]]
123
- except AttributeError:
124
- # 流式响应头只读:初始化时仅加入 traceId 和必要暴露头
125
- new_headers = dict(response.headers) if hasattr(
126
- response.headers, 'items') else {}
127
- new_headers["x-traceId-header"] = trace_id # 强制加入
128
- # 保留原有暴露头,补充 traceId
129
- if "access-control-expose-headers" in new_headers:
130
- if "x-traceId-header" not in new_headers["access-control-expose-headers"].lower():
131
- new_headers["access-control-expose-headers"] += ", x-traceId-header"
132
- else:
133
- new_headers["access-control-expose-headers"] = "x-traceId-header"
134
- # 移除 Content-Length
135
- new_headers.pop("content-length", None)
136
- response.init_headers(new_headers)
84
+ # 流式响应不能有Content-Length,移除它
85
+ if "Content-Length" in response.headers:
86
+ del response.headers["Content-Length"]
87
+ response.headers["x-traceId-header"] = trace_id
137
88
  return response
138
89
 
139
- # ========== 3. 非 SSE 响应:强制写入 x-traceId-header,保留 CORS ==========
140
- # 备份 CORS 头(防止丢失)
141
- cors_headers = {}
142
- cors_header_keys = [
143
- "access-control-allow-origin",
144
- "access-control-allow-methods",
145
- "access-control-allow-headers",
146
- "access-control-expose-headers",
147
- "access-control-allow-credentials",
148
- "access-control-max-age"
149
- ]
150
- for key in cors_header_keys:
151
- for k in response.headers.keys():
152
- if k.lower() == key:
153
- cors_headers[key] = response.headers[k]
154
- break
155
-
156
- # 合并 headers(非 SSE 场景)
157
- merged_headers = merge_headers(
158
- source_headers=request.headers,
159
- target_headers=response.headers,
160
- keep_keys=None,
161
- delete_keys={'content-length', 'accept', 'content-type'}
162
- )
163
-
164
- # 强制加入 x-traceId-header(优先级最高)
165
- merged_headers["x-traceId-header"] = trace_id
166
- # 恢复 CORS 头 + 补充 traceId 到暴露头
167
- merged_headers.update(cors_headers)
168
- expose_headers = merged_headers.get(
169
- "access-control-expose-headers", "")
170
- if expose_headers:
171
- if "x-traceId-header" not in expose_headers.lower():
172
- merged_headers["access-control-expose-headers"] = f"{expose_headers}, x-traceId-header"
173
- else:
174
- merged_headers["access-control-expose-headers"] = "x-traceId-header"
175
-
176
- # 更新响应头
177
- if hasattr(response.headers, 'clear'):
178
- response.headers.clear()
179
- for k, v in merged_headers.items():
180
- response.headers[k] = v
181
- elif hasattr(response, "init_headers"):
182
- response.init_headers(merged_headers)
183
- else:
184
- for k, v in merged_headers.items():
185
- try:
186
- response.headers[k] = v
187
- except (AttributeError, KeyError):
188
- pass
189
-
190
- # 处理普通响应体(JSON 加入 traceId)
90
+ # 处理普通响应
191
91
  response_body = b""
192
92
  try:
93
+ # 收集所有响应块
193
94
  async for chunk in response.body_iterator:
194
95
  response_body += chunk
195
96
 
196
- # 获取 Content-Disposition(统一小写)
197
97
  content_disposition = response.headers.get(
198
- "content-disposition", "").lower()
98
+ "Content-Disposition", "")
199
99
 
200
- # JSON 响应体加入 traceId
100
+ # 判断是否能添加 trace_id
201
101
  if "application/json" in content_type and not content_disposition.startswith("attachment"):
202
102
  try:
203
103
  data = json.loads(response_body)
204
- new_body = response_body
205
- if data:
206
- data["traceId"] = trace_id # 响应体也加入
207
- new_body = json.dumps(
208
- data, ensure_ascii=False).encode()
104
+ data["traceId"] = trace_id
105
+ new_body = json.dumps(
106
+ data, ensure_ascii=False).encode()
209
107
 
210
- # 重建响应,确保 header 包含 x-traceId-header
108
+ # 创建新响应,确保Content-Length正确
211
109
  response = Response(
212
110
  content=new_body,
213
111
  status_code=response.status_code,
214
112
  headers=dict(response.headers),
215
113
  media_type=response.media_type
216
114
  )
217
- response.headers["content-length"] = str(len(new_body))
218
- response.headers["x-traceId-header"] = trace_id # 再次兜底
219
- # 恢复 CORS 头
220
- for k, v in cors_headers.items():
221
- response.headers[k] = v
115
+ # 显式设置正确的Content-Length
116
+ response.headers["Content-Length"] = str(len(new_body))
222
117
  except json.JSONDecodeError:
223
- # JSON 响应:仅更新长度,强制加入 traceId
118
+ # 如果不是JSON,恢复原始响应体并更新长度
224
119
  response = Response(
225
120
  content=response_body,
226
121
  status_code=response.status_code,
227
122
  headers=dict(response.headers),
228
123
  media_type=response.media_type
229
124
  )
230
- response.headers["content-length"] = str(
125
+ response.headers["Content-Length"] = str(
231
126
  len(response_body))
232
- response.headers["x-traceId-header"] = trace_id # 强制加入
233
- for k, v in cors_headers.items():
234
- response.headers[k] = v
235
127
  else:
236
- # 非 JSON 响应:强制加入 traceId
128
+ # 非JSON响应,恢复原始响应体
237
129
  response = Response(
238
130
  content=response_body,
239
131
  status_code=response.status_code,
240
132
  headers=dict(response.headers),
241
133
  media_type=response.media_type
242
134
  )
243
- response.headers["content-length"] = str(
135
+ response.headers["Content-Length"] = str(
244
136
  len(response_body))
245
- response.headers["x-traceId-header"] = trace_id # 强制加入
246
- for k, v in cors_headers.items():
247
- response.headers[k] = v
248
137
  except StopAsyncIteration:
249
138
  pass
250
139
 
251
- # 构建响应日志(包含 traceId)
140
+ # 构建响应日志信息
252
141
  response_message = {
253
- "traceId": trace_id, # 响应日志加入 traceId
254
142
  "status_code": response.status_code,
255
143
  "response_body": response_body.decode('utf-8', errors='ignore'),
256
144
  }
@@ -258,21 +146,11 @@ def setup_trace_id_handler(app):
258
146
  response_message, ensure_ascii=False)
259
147
  SYLogger.info(response_message_str)
260
148
 
261
- # ========== 最终兜底:确保响应头必有 x-traceId-header ==========
262
- try:
263
- response.headers["x-traceId-header"] = trace_id
264
- except AttributeError:
265
- new_headers = dict(response.headers) if hasattr(
266
- response.headers, 'items') else {}
267
- new_headers["x-traceId-header"] = trace_id
268
- if hasattr(response, "init_headers"):
269
- response.init_headers(new_headers)
149
+ response.headers["x-traceId-header"] = trace_id
270
150
 
271
151
  return response
272
152
  except Exception as e:
273
- # 异常日志也加入 traceId
274
153
  error_message = {
275
- "traceId": trace_id,
276
154
  "error": str(e),
277
155
  "query_params": query_params,
278
156
  "request_body": request_body,
@@ -282,8 +160,7 @@ def setup_trace_id_handler(app):
282
160
  SYLogger.error(error_message_str)
283
161
  raise
284
162
  finally:
285
- # 清理上下文变量
163
+ # 清理上下文变量,防止泄漏
286
164
  SYLogger.reset_trace_id(token)
287
- SYLogger.reset_headers(header_token)
288
165
 
289
166
  return app
sycommon/services.py CHANGED
@@ -7,7 +7,6 @@ from fastapi import FastAPI, applications
7
7
  from pydantic import BaseModel
8
8
  from typing import Any, Callable, Dict, List, Tuple, Union, Optional, AsyncGenerator
9
9
  from sycommon.config.Config import SingletonMeta
10
- from sycommon.logging.logger_levels import setup_logger_levels
11
10
  from sycommon.models.mqlistener_config import RabbitMQListenerConfig
12
11
  from sycommon.models.mqsend_config import RabbitMQSendConfig
13
12
  from sycommon.rabbitmq.rabbitmq_service import RabbitMQService
@@ -24,9 +23,6 @@ class Services(metaclass=SingletonMeta):
24
23
  _user_lifespan: Optional[Callable] = None
25
24
  _shutdown_lock: asyncio.Lock = asyncio.Lock()
26
25
 
27
- # 用于存储待执行的异步数据库初始化任务
28
- _pending_async_db_setup: List[Tuple[Callable, str]] = []
29
-
30
26
  def __init__(self, config: dict, app: FastAPI):
31
27
  if not Services._config:
32
28
  Services._config = config
@@ -52,25 +48,25 @@ class Services(metaclass=SingletonMeta):
52
48
  nacos_service: Optional[Callable[[dict], None]] = None,
53
49
  logging_service: Optional[Callable[[dict], None]] = None,
54
50
  database_service: Optional[Union[
55
- Tuple[Callable, str],
56
- List[Tuple[Callable, str]]
51
+ Tuple[Callable[[dict, str], None], str],
52
+ List[Tuple[Callable[[dict, str], None], str]]
57
53
  ]] = None,
58
54
  rabbitmq_listeners: Optional[List[RabbitMQListenerConfig]] = None,
59
55
  rabbitmq_senders: Optional[List[RabbitMQSendConfig]] = None
60
56
  ) -> FastAPI:
61
57
  load_dotenv()
62
- setup_logger_levels()
58
+ # 保存应用实例和配置
63
59
  cls._app = app
64
60
  cls._config = config
65
61
  cls._user_lifespan = app.router.lifespan_context
66
-
62
+ # 设置文档
67
63
  applications.get_swagger_ui_html = custom_swagger_ui_html
68
64
  applications.get_redoc_html = custom_redoc_html
69
-
65
+ # 设置app.state host, port
70
66
  if not cls._config:
71
67
  config = yaml.safe_load(open('app.yaml', 'r', encoding='utf-8'))
72
68
  cls._config = config
73
-
69
+ # 使用config
74
70
  app.state.config = {
75
71
  "host": cls._config.get('Host', '0.0.0.0'),
76
72
  "port": cls._config.get('Port', 8080),
@@ -78,6 +74,7 @@ class Services(metaclass=SingletonMeta):
78
74
  "h11_max_incomplete_event_size": cls._config.get('H11MaxIncompleteEventSize', 1024 * 1024 * 10)
79
75
  }
80
76
 
77
+ # 立即配置非异步服务(在应用启动前)
81
78
  if middleware:
82
79
  middleware(app, config)
83
80
 
@@ -87,29 +84,8 @@ class Services(metaclass=SingletonMeta):
87
84
  if logging_service:
88
85
  logging_service(config)
89
86
 
90
- # ========== 处理数据库服务 ==========
91
- # 清空之前的待执行列表(防止热重载时重复)
92
- cls._pending_async_db_setup = []
93
-
94
87
  if database_service:
95
- # 解析配置并区分同步/异步
96
- items = [database_service] if isinstance(
97
- database_service, tuple) else database_service
98
- for item in items:
99
- db_setup_func, db_name = item
100
- if asyncio.iscoroutinefunction(db_setup_func):
101
- # 如果是异步函数,加入待执行列表
102
- logging.info(f"检测到异步数据库服务: {db_name},将在应用启动时初始化")
103
- cls._pending_async_db_setup.append(item)
104
- else:
105
- # 如果是同步函数,立即执行
106
- logging.info(f"执行同步数据库服务: {db_name}")
107
- try:
108
- db_setup_func(config, db_name)
109
- except Exception as e:
110
- logging.error(
111
- f"同步数据库服务 {db_name} 初始化失败: {e}", exc_info=True)
112
- raise
88
+ cls._setup_database_static(database_service, config)
113
89
 
114
90
  # 创建组合生命周期管理器
115
91
  @asynccontextmanager
@@ -117,25 +93,14 @@ class Services(metaclass=SingletonMeta):
117
93
  # 1. 执行Services自身的初始化
118
94
  instance = cls(config, app)
119
95
 
120
- # ========== 执行挂起的异步数据库初始化 ==========
121
- if cls._pending_async_db_setup:
122
- logging.info("开始执行异步数据库初始化...")
123
- for db_setup_func, db_name in cls._pending_async_db_setup:
124
- try:
125
- await db_setup_func(config, db_name)
126
- logging.info(f"异步数据库服务 {db_name} 初始化成功")
127
- except Exception as e:
128
- logging.error(
129
- f"异步数据库服务 {db_name} 初始化失败: {e}", exc_info=True)
130
- raise
131
-
132
- # ========== 初始化 MQ ==========
96
+ # 明确判断是否有有效的监听器/发送器配置
133
97
  has_valid_listeners = bool(
134
98
  rabbitmq_listeners and len(rabbitmq_listeners) > 0)
135
99
  has_valid_senders = bool(
136
100
  rabbitmq_senders and len(rabbitmq_senders) > 0)
137
101
 
138
102
  try:
103
+ # 只有存在监听器或发送器时才初始化RabbitMQService
139
104
  if has_valid_listeners or has_valid_senders:
140
105
  await instance._setup_mq_async(
141
106
  rabbitmq_listeners=rabbitmq_listeners if has_valid_listeners else None,
@@ -154,18 +119,28 @@ class Services(metaclass=SingletonMeta):
154
119
  # 2. 执行用户定义的生命周期
155
120
  if cls._user_lifespan:
156
121
  async with cls._user_lifespan(app):
157
- yield
122
+ yield # 应用运行阶段
158
123
  else:
159
- yield
124
+ yield # 没有用户生命周期时直接 yield
160
125
 
161
126
  # 3. 执行Services的关闭逻辑
162
127
  await cls.shutdown()
163
128
  logging.info("Services已关闭")
164
129
 
130
+ # 设置组合生命周期
165
131
  app.router.lifespan_context = combined_lifespan
132
+
166
133
  return app
167
134
 
168
- # 移除了 _setup_database_static,因为逻辑已内联到 plugins 中
135
+ @staticmethod
136
+ def _setup_database_static(database_service, config):
137
+ """静态方法:设置数据库服务"""
138
+ if isinstance(database_service, tuple):
139
+ db_setup, db_name = database_service
140
+ db_setup(config, db_name)
141
+ elif isinstance(database_service, list):
142
+ for db_setup, db_name in database_service:
143
+ db_setup(config, db_name)
169
144
 
170
145
  async def _setup_mq_async(
171
146
  self,
@@ -174,13 +149,16 @@ class Services(metaclass=SingletonMeta):
174
149
  has_listeners: bool = False,
175
150
  has_senders: bool = False,
176
151
  ):
177
- """异步设置MQ相关服务"""
152
+ """异步设置MQ相关服务(适配单通道RabbitMQService)"""
153
+ # ========== 只有需要使用MQ时才初始化 ==========
178
154
  if not (has_listeners or has_senders):
179
155
  logging.info("无RabbitMQ监听器/发送器配置,跳过RabbitMQService初始化")
180
156
  return
181
157
 
158
+ # 仅当有监听器或发送器时,才执行RabbitMQService初始化
182
159
  RabbitMQService.init(self._config, has_listeners, has_senders)
183
160
 
161
+ # 优化:等待连接池“存在且初始化完成”(避免提前执行后续逻辑)
184
162
  start_time = asyncio.get_event_loop().time()
185
163
  while not (RabbitMQService._connection_pool and RabbitMQService._connection_pool._initialized) and not RabbitMQService._is_shutdown:
186
164
  if asyncio.get_event_loop().time() - start_time > 30:
@@ -188,7 +166,10 @@ class Services(metaclass=SingletonMeta):
188
166
  logging.info("等待RabbitMQ连接池初始化...")
189
167
  await asyncio.sleep(0.5)
190
168
 
169
+ # ========== 保留原有严格的发送器/监听器初始化判断 ==========
170
+ # 只有配置了发送器才执行发送器初始化
191
171
  if has_senders and rabbitmq_senders:
172
+ # 判断是否有监听器,如果有遍历监听器列表,队列名一样将prefetch_count属性设置到发送器对象中
192
173
  if has_listeners and rabbitmq_listeners:
193
174
  for sender in rabbitmq_senders:
194
175
  for listener in rabbitmq_listeners:
@@ -196,25 +177,31 @@ class Services(metaclass=SingletonMeta):
196
177
  sender.prefetch_count = listener.prefetch_count
197
178
  await self._setup_senders_async(rabbitmq_senders, has_listeners)
198
179
 
180
+ # 只有配置了监听器才执行监听器初始化
199
181
  if has_listeners and rabbitmq_listeners:
200
182
  await self._setup_listeners_async(rabbitmq_listeners, has_senders)
201
183
 
184
+ # 验证初始化结果
202
185
  if has_listeners:
186
+ # 异步获取客户端数量(适配新的RabbitMQService)
203
187
  listener_count = len(RabbitMQService._consumer_tasks)
204
188
  logging.info(f"监听器初始化完成,共启动 {listener_count} 个消费者")
205
189
  if listener_count == 0:
206
190
  logging.warning("未成功初始化任何监听器,请检查配置或MQ服务状态")
207
191
 
208
192
  async def _setup_senders_async(self, rabbitmq_senders, has_listeners: bool):
209
- """设置发送器"""
193
+ """设置发送器(适配新的RabbitMQService异步方法)"""
210
194
  Services._registered_senders = [
211
195
  sender.queue_name for sender in rabbitmq_senders]
196
+
197
+ # 将是否有监听器的信息传递给RabbitMQService(异步调用)
212
198
  await RabbitMQService.setup_senders(rabbitmq_senders, has_listeners)
199
+ # 更新已注册的发送器(从RabbitMQService获取实际注册的名称)
213
200
  Services._registered_senders = RabbitMQService._sender_client_names
214
201
  logging.info(f"已注册的RabbitMQ发送器: {Services._registered_senders}")
215
202
 
216
203
  async def _setup_listeners_async(self, rabbitmq_listeners, has_senders: bool):
217
- """设置监听器"""
204
+ """设置监听器(适配新的RabbitMQService异步方法)"""
218
205
  await RabbitMQService.setup_listeners(rabbitmq_listeners, has_senders)
219
206
 
220
207
  @classmethod
@@ -225,7 +212,7 @@ class Services(metaclass=SingletonMeta):
225
212
  max_retries: int = 3,
226
213
  retry_delay: float = 1.0, **kwargs
227
214
  ) -> None:
228
- """发送消息"""
215
+ """发送消息,添加重试机制(适配单通道RabbitMQService)"""
229
216
  if not cls._initialized or not cls._loop:
230
217
  logging.error("Services not properly initialized!")
231
218
  raise ValueError("服务未正确初始化")
@@ -236,15 +223,18 @@ class Services(metaclass=SingletonMeta):
236
223
 
237
224
  for attempt in range(max_retries):
238
225
  try:
226
+ # 验证发送器是否注册
239
227
  if queue_name not in cls._registered_senders:
240
228
  cls._registered_senders = RabbitMQService._sender_client_names
241
229
  if queue_name not in cls._registered_senders:
242
230
  raise ValueError(f"发送器 {queue_name} 未注册")
243
231
 
232
+ # 获取发送器(适配新的异步get_sender方法)
244
233
  sender = await RabbitMQService.get_sender(queue_name)
245
234
  if not sender:
246
235
  raise ValueError(f"发送器 '{queue_name}' 不存在或连接无效")
247
236
 
237
+ # 发送消息(调用RabbitMQService的异步send_message)
248
238
  await RabbitMQService.send_message(data, queue_name, **kwargs)
249
239
  logging.info(f"消息发送成功(尝试 {attempt+1}/{max_retries})")
250
240
  return
@@ -254,18 +244,25 @@ class Services(metaclass=SingletonMeta):
254
244
  logging.error(
255
245
  f"消息发送失败(已尝试 {max_retries} 次): {str(e)}", exc_info=True)
256
246
  raise
247
+
257
248
  logging.warning(
258
- f"消息发送失败(尝试 {attempt+1}/{max_retries}): {str(e)},{retry_delay}秒后重试...")
249
+ f"消息发送失败(尝试 {attempt+1}/{max_retries}): {str(e)},"
250
+ f"{retry_delay}秒后重试..."
251
+ )
259
252
  await asyncio.sleep(retry_delay)
260
253
 
261
254
  @classmethod
262
255
  async def shutdown(cls):
263
- """关闭所有服务"""
256
+ """关闭所有服务(适配单通道RabbitMQService关闭逻辑)"""
264
257
  async with cls._shutdown_lock:
265
258
  if RabbitMQService._is_shutdown:
266
259
  logging.info("RabbitMQService已关闭,无需重复操作")
267
260
  return
261
+
262
+ # 关闭RabbitMQ服务(异步调用,内部会关闭所有客户端+消费任务)
268
263
  await RabbitMQService.shutdown()
264
+
265
+ # 清理全局状态
269
266
  cls._initialized = False
270
267
  cls._registered_senders.clear()
271
268
  logging.info("所有服务已关闭")