aitoolman 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aitoolman/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .app import LLMApplication
2
+ from .module import LLMModule, DefaultLLMModule
3
+ from .client import LLMClient, LLMLocalClient
4
+ from .channel import (TextFragmentOutput, Channel, TextChannel,
5
+ BaseXmlTagFilter, XmlTagToChannelFilter, ChannelEvent, collect_text_channels)
6
+ from .model import MediaContent, Message, ToolCall, LLMResponse, LLMRequest, FinishReason, LLMModuleResult
7
+ from .provider import LLMProviderManager
8
+ from .util import load_config, load_config_str
9
+ from . import postprocess
aitoolman/__main__.py ADDED
@@ -0,0 +1,5 @@
1
+ import aitoolman.cli
2
+
3
+
4
if __name__ == "__main__":
    # Running the package as a script (`python -m aitoolman`) hands control
    # to the command-line entry point.
    aitoolman.cli.main()
aitoolman/app.py ADDED
@@ -0,0 +1,202 @@
1
+ import functools
2
+ import logging
3
+ from typing import Any, Dict, Optional, Callable
4
+
5
+ import jinja2
6
+
7
+ from . import util
8
+ from . import postprocess
9
+ from . import client as _client
10
+ from . import channel as _channel
11
+ from .module import LLMModule, ModuleConfig, DefaultLLMModule
12
+
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class ConfigTemplateLoader(jinja2.BaseLoader):
    """Jinja2 loader that serves templates straight from the application config.

    Global templates (the ``template`` config section) are addressed by their
    plain name; module templates by ``module/<module_name>/<template_name>``,
    so global and module templates can include/extend each other.
    """

    def __init__(self, config_dict: Optional[Dict[str, Any]]):
        # Treat a missing config (None) and sections that are present but
        # empty (None, e.g. a bare ``template:`` in YAML) as "no templates".
        config_dict = config_dict or {}
        self.global_templates = config_dict.get('template') or {}
        self.modules = config_dict.get('module') or {}

    def get_source(self, environment, template):
        """Return ``(source, filename, uptodate)`` for ``template``.

        ``template`` is either a global template name such as ``'header'`` or a
        module template path such as ``'module/task_planner/user'``. Everything
        after the second ``/`` is the module template name, so names reported
        by :meth:`list_templates` always resolve.

        Raises:
            jinja2.TemplateNotFound: if no such template exists, including
                malformed ``module/...`` paths.
        """
        if template in self.global_templates:
            # Sources live in memory and never change: always up to date.
            return self.global_templates[template], template, lambda: True
        if template.startswith('module/'):
            parts = template.split('/', 2)
            if len(parts) == 3:
                _, module_name, template_name = parts
                module_config = self.modules.get(module_name) or {}
                module_templates = module_config.get('template') or {}
                if template_name in module_templates:
                    return module_templates[template_name], template, lambda: True
        raise jinja2.TemplateNotFound(template)

    def list_templates(self):
        """List every template name this loader can resolve (for debugging)."""
        templates = list(self.global_templates)
        for module_name, module_config in self.modules.items():
            for template_name in (module_config or {}).get('template') or {}:
                templates.append(f"module/{module_name}/{template_name}")
        return templates
52
+
53
+
54
class LLMApplication:
    """LLM application context shared by a set of LLM modules.

    Bundles the client, the configuration, a Jinja2 environment whose
    templates come from that configuration, post-processors, named text
    channels and template variables (``vars``). Modules declared in the
    ``module`` config section are created on first attribute access
    (``app.<module_name>``) or all at once via :meth:`init_all_modules`.
    """

    def __init__(
        self,
        client: _client.LLMClient,
        config_dict: Optional[Dict[str, Any]] = None,
        processors: Optional[Dict[str, Callable[[str], Any]]] = None,
        channels: Optional[Dict[str, _channel.TextChannel]] = None,
        context_id: Optional[str] = None
    ):
        """
        Args:
            client: client used to talk to the LLM backend.
            config_dict: parsed configuration. The top-level dict is copied;
                nested sections are shared with the caller.
            processors: extra post-processors added on top of
                ``postprocess.DEFAULT_PROCESSORS`` (same names override).
            channels: extra channels; ``stdin``, ``stdout`` and ``reasoning``
                are created when not supplied.
            context_id: identifier of this context; generated with
                ``util.get_id()`` when omitted.
        """
        self.client: _client.LLMClient = client
        self.context_id: str = context_id or util.get_id()
        # Variables merged into every render_template() call.
        self.vars: Dict[str, Any] = {}
        self.channels: Dict[str, _channel.TextChannel] = {}
        self.processors: Dict[str, Callable[[str], Any]] = postprocess.DEFAULT_PROCESSORS.copy()
        self.modules: Dict[str, LLMModule] = {}

        # Normalise the config: these sections always exist and are mappings,
        # even when missing or left empty (None) in the source document.
        self.config: Dict[str, Any] = (config_dict or {}).copy()
        for section in ('module', 'template', 'tools'):
            if self.config.get(section) is None:
                self.config[section] = {}

        # Global tool definitions; modules reference them with an empty dict.
        self.global_tools: Dict[str, Any] = self.config['tools']

        # Build the loader from the normalised config: the raw config_dict may
        # be None, which the loader cannot read.
        self.jinja_env: jinja2.Environment = jinja2.Environment(
            loader=ConfigTemplateLoader(self.config),
            autoescape=False,
            trim_blocks=True,
            lstrip_blocks=True
        )

        if processors:
            self.processors.update(processors)

        if channels:
            self.channels.update(channels)
        # Built-in channels, created only when the caller did not supply them.
        if 'stdin' not in self.channels:
            self.channels['stdin'] = _channel.TextChannel()
        if 'stdout' not in self.channels:
            self.channels['stdout'] = _channel.TextChannel(read_fragments=True)
        if 'reasoning' not in self.channels:
            self.channels['reasoning'] = _channel.TextChannel(read_fragments=True)

    def init_all_modules(self):
        """Create every module declared in the ``module`` config section."""
        for module_name, module_config in (self.config.get('module') or {}).items():
            self.init_module_from_config(module_name, module_config)

    def _resolve_channel(self, module_name: str, config: Dict[str, Any],
                         key: str) -> Optional[_channel.TextChannel]:
        """Return the channel named by ``config[key]``, or None when unset.

        Raises:
            KeyError: if the named channel is not registered.
        """
        channel_name = config.get(key)
        if not channel_name:
            return None
        try:
            return self.channels[channel_name]
        except KeyError:
            raise KeyError(
                f"Module '{module_name}' {key} refers to unknown channel '{channel_name}'"
            ) from None

    def init_module_from_config(self, module_name, module_config):
        """Create, register and return the module ``module_name``.

        ``module_config`` is layered over the ``module_default`` section. A
        tool whose config is an empty dict is replaced by the global tool of
        the same name.

        Raises:
            ValueError: if a module references an undefined global tool.
            KeyError: if a module names an unregistered channel.
        """
        config = dict(self.config.get('module_default') or {})
        config.update(module_config or {})

        resolved_tools = {}
        for tool_name, tool_config in (config.get('tools') or {}).items():
            if isinstance(tool_config, dict) and not tool_config:
                # Empty dict: reference to a globally defined tool.
                if tool_name not in self.global_tools:
                    raise ValueError(f"Module '{module_name}' referenced undefined global tool '{tool_name}'.")
                resolved_tools[tool_name] = self.global_tools[tool_name]
            else:
                # Module-specific definition (overrides any global one).
                resolved_tools[tool_name] = tool_config

        module_config_obj = ModuleConfig(
            name=module_name,
            model=config.get('model', ''),
            templates=config.get('template') or {},
            tools=resolved_tools,
            stream=config.get('stream', False),
            output_channel=self._resolve_channel(module_name, config, 'output_channel'),
            reasoning_channel=self._resolve_channel(module_name, config, 'reasoning_channel'),
            post_processor=config.get('post_processor'),
            save_context=config.get('save_context', False),
            options=config.get('options') or {}
        )

        module = DefaultLLMModule(self, module_config_obj)
        self.modules[module_name] = module
        return module

    def __getattr__(self, name: str) -> LLMModule:
        """Expose modules as attributes, creating configured ones lazily."""
        # __getattr__ only runs when normal lookup fails. Read state through
        # __dict__ so a partially initialised instance (e.g. created by
        # copy/pickle without __init__) raises AttributeError instead of
        # recursing forever, and never treat dunder lookups as module names.
        state = self.__dict__
        if (name.startswith('__') and name.endswith('__')) \
                or 'modules' not in state or 'config' not in state:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
        modules = state['modules']
        if name in modules:
            return modules[name]
        module_configs = state['config'].get('module') or {}
        if name in module_configs:
            return self.init_module_from_config(name, module_configs[name])
        raise AttributeError(f"No LLM module named '{name}'")

    def add_processor(self, name: str, processor: Callable):
        """Register (or replace) a post-processor under ``name``."""
        self.processors[name] = processor

    def get_processor(self, name: str) -> Optional[Callable]:
        """Return the post-processor registered as ``name``, or None."""
        return self.processors.get(name)

    def render_template(self, template_name: str, **kwargs) -> str:
        """Render a named template with ``self.vars`` overridden by ``kwargs``."""
        all_vars = {**self.vars, **kwargs}
        return self.jinja_env.get_template(template_name).render(**all_vars)

    def add_channel(self, name: str, channel: _channel.TextChannel):
        """Register (or replace) a channel under ``name``."""
        self.channels[name] = channel

    async def audit_event(self, event_type: str, **kwargs):
        """Forward a user-defined audit event to the client for this context."""
        await self.client.audit_event(self.context_id, event_type, **kwargs)

    @classmethod
    def factory(
        cls,
        client: _client.LLMClient,
        config_dict: Optional[Dict[str, Any]] = None,
        processors: Optional[Dict[str, Callable[[str], Any]]] = None,
        channels: Optional[Dict[str, _channel.TextChannel]] = None,
    ) -> Callable[..., 'LLMApplication']:
        """Return a callable that builds applications sharing these arguments."""
        return functools.partial(
            cls,
            client=client,
            config_dict=config_dict,
            processors=processors,
            channels=channels
        )
aitoolman/channel.py ADDED
@@ -0,0 +1,403 @@
1
+ import re
2
+ import time
3
+ import typing
4
+ import asyncio
5
+ import logging
6
+ from abc import ABC, abstractmethod
7
+ from typing import Any, Optional, Dict, Set
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
class TextFragmentInput(typing.Protocol):
    """Structural type for text sources readable as messages or fragments."""

    async def read_message(self) -> Optional[str]:
        """Return the next complete message."""

    async def read_fragment(self) -> Optional[str]:
        """Return the next fragment of the current message."""
18
+
19
+
20
class TextFragmentOutput(typing.Protocol):
    """Structural type for text sinks writable as messages or fragments."""

    async def write_message(self, message: Optional[str]):
        """Write one complete message."""

    async def write_fragment(self, text: str, end: bool = False):
        """Write one fragment; ``end=True`` marks the last fragment of a message."""
26
+
27
+
28
class Channel:
    """Asynchronous FIFO channel carrying complete messages.

    Writers enqueue with :meth:`write_message`; readers await
    :meth:`read_message`. Once :meth:`close` has been called, writes are
    rejected with ``OSError`` (a.k.a. ``IOError``); reads are unaffected.
    """

    def __init__(self):
        self.closed = False
        self._message_queue = asyncio.Queue()

    async def read_message(self) -> Any:
        """Wait for and return the oldest queued message."""
        queue = self._message_queue
        return await queue.get()

    async def write_message(self, message: Any):
        """Enqueue ``message``; raise ``OSError`` if the channel is closed."""
        if not self.closed:
            await self._message_queue.put(message)
            return
        raise OSError("Channel is closed")

    def close(self):
        """Mark the channel closed so that later writes fail."""
        self.closed = True
48
+
49
+
50
class TextChannel(Channel):
    """Text channel that accepts whole messages or streamed fragments.

    The consumption mode is fixed at construction:

    * ``read_fragments=False``: readers call :meth:`read_message`. Fragments
      written with :meth:`write_fragment` are joined into one message once a
      fragment with ``end=True`` arrives.
    * ``read_fragments=True``: readers call :meth:`read_fragment`. Every
      fragment is forwarded as it arrives and ``None`` marks the end of each
      message. :meth:`read_message` raises in this mode, so complete messages
      are not queued (nothing could ever drain them).
    """

    def __init__(self, read_fragments=False):
        super().__init__()
        # Fragments of the message currently being streamed (message mode).
        self._current_fragment_buffer = []
        # Fragment stream consumed by read_fragment(); None ends a message.
        self._fragment_queue = asyncio.Queue()
        self._read_fragments = read_fragments

    async def read_message(self) -> Optional[str]:
        """Return the next complete message.

        Raises:
            RuntimeError: if the channel was created with ``read_fragments=True``.
            IOError: if the channel is closed.
        """
        if self._read_fragments:
            raise RuntimeError("Cannot read message while reading fragments")
        elif self.closed:
            raise IOError("Channel is closed")
        return await self._message_queue.get()

    async def write_message(self, message: Optional[str]):
        """Write a complete message (as a single fragment in fragment mode).

        Raises:
            IOError: if the channel is closed.
        """
        if self.closed:
            raise IOError("Channel is closed")

        if self._read_fragments:
            # None is the end-of-message sentinel, so a None message must not
            # be forwarded as a fragment: it would end the *next* message too.
            if message is not None:
                await self._fragment_queue.put(message)
            await self._fragment_queue.put(None)  # end-of-message marker
        else:
            await self._message_queue.put(message)

    async def read_fragment(self) -> Optional[str]:
        """Return the next fragment; None marks the end of a message.

        Also returns None immediately when the channel is closed or was not
        created in fragment mode.
        """
        if self.closed or not self._read_fragments:
            return None

        return await self._fragment_queue.get()

    async def write_fragment(self, text: str, end: bool = False):
        """Write one fragment; ``end=True`` completes the current message.

        Raises:
            IOError: if the channel is closed.
        """
        if self.closed:
            raise IOError("Channel is closed")

        if self._read_fragments:
            await self._fragment_queue.put(text)
            if end:
                await self._fragment_queue.put(None)  # end-of-message marker
            return

        self._current_fragment_buffer.append(text)
        if end:
            complete_message = ''.join(self._current_fragment_buffer)
            # Clear before awaiting so a concurrent writer starts a new message.
            self._current_fragment_buffer.clear()
            await self._message_queue.put(complete_message)
101
+
102
+
103
class BaseXmlTagFilter(ABC):
    """Split text into segments according to a set of tracked XML-style tags.

    Text enclosed in a tracked tag (e.g. ``<think>...</think>``) is reported
    under that tag's name once the tag closes (or when input ends). All other
    text, including untracked tags, is reported with tag ``None``. Input is
    accepted as complete messages (:meth:`write_message`) or as streamed
    fragments (:meth:`write_fragment`). Tags with attributes and nested
    tracked tags are not recognised: while a tracked tag is open only its own
    closing tag is searched for.

    Subclasses implement :meth:`on_message_tag` and :meth:`on_fragment_tag`.
    """

    def __init__(self, tags: Set[str]):
        self.tags = tags
        self.current_tag: Optional[str] = None  # tracked tag currently open
        self.current_content: list[str] = []  # buffered content of current_tag
        self.pending_text: str = ""  # tail held back until the next fragment

        # Opening or closing tag without attributes, e.g. <think> or </ns:tag>.
        self.tag_pattern = re.compile(r'<(/?)([a-zA-Z_][\w.:-]*)>')
        # Closing tag of the currently open tracked tag.
        self.closing_tag_template = r'</%s>'
        # Trailing text that could still grow into a tag: '<', '</', '<na', '</na'.
        self._partial_tag_pattern = re.compile(r'</?(?:[a-zA-Z_][\w.:-]*)?')

    @abstractmethod
    async def on_message_tag(self, tag: Optional[str], message: str, end: bool):
        """Handle one segment of a complete message (tag is None for untagged text)."""
        pass

    @abstractmethod
    async def on_fragment_tag(self, tag: Optional[str], text: str, end: bool):
        """Handle one segment of a streamed message (tag is None for untagged text)."""
        pass

    async def _on_tag(self, tag: Optional[str], message: str, end: bool, is_fragment: bool):
        # Dispatch to the callback matching the input mode.
        if is_fragment:
            await self.on_fragment_tag(tag, message, end)
        else:
            await self.on_message_tag(tag, message, end)

    async def write_message(self, message: Optional[str]) -> None:
        """Parse one complete message and report its segments.

        ``None`` is treated as an empty message. Content of a tracked tag that
        is never closed is still reported under that tag.
        """
        self._reset_state()
        try:
            remaining = await self._parse_content(message or "", is_fragment=False, end=True)
            if self.current_tag is not None:
                # Unclosed tracked tag: deliver its content instead of dropping it.
                self.current_content.append(remaining)
                await self._emit_content(False)
            elif remaining:
                await self._on_tag(None, remaining, end=True, is_fragment=False)
        finally:
            self._reset_state()

    async def write_fragment(self, text: str, end: bool = False) -> None:
        """Parse the next streamed fragment; ``end=True`` marks the last one.

        On the final fragment the untagged (``None``) stream always receives
        exactly one call with ``end=True``, so fragment readers learn that the
        message is over even when it ended with a tag or an empty fragment.
        """
        if not text and not end:
            return

        full_text = self.pending_text + (text or "")
        self.pending_text = await self._parse_content(full_text, is_fragment=True, end=end)

        if end:
            await self._finalize_fragment()

    def _reset_state(self) -> None:
        """Reset all parsing state."""
        self.current_tag = None
        self.current_content = []
        self.pending_text = ""

    @staticmethod
    def _partial_closing_tag_len(text: str, closing_tag: str) -> int:
        """Length of the longest suffix of ``text`` that is a proper prefix of ``closing_tag``."""
        for size in range(min(len(closing_tag) - 1, len(text)), 0, -1):
            if text.endswith(closing_tag[:size]):
                return size
        return 0

    async def _parse_content(self, text: str, is_fragment: bool, end: bool) -> str:
        """Report every complete segment of ``text`` and return the held-back tail.

        The tail is either a possible partial tag (only when ``end`` is false)
        or, when ``end`` is true and parsing finished outside a tracked tag,
        the trailing untagged text, which the caller reports with ``end=True``.
        Implemented as a loop so long inputs with many tags cannot exhaust the
        recursion limit.
        """
        pos = 0
        while True:
            if self.current_tag is not None:
                # Inside a tracked tag: only its own closing tag matters.
                closing_tag = self.closing_tag_template % self.current_tag
                closing_pos = text.find(closing_tag, pos)
                if closing_pos == -1:
                    rest = text[pos:]
                    # Keep a possible split closing tag (e.g. '</th') for the
                    # next fragment instead of burying it in the content.
                    keep = 0 if end else self._partial_closing_tag_len(rest, closing_tag)
                    self.current_content.append(rest[:len(rest) - keep])
                    return rest[len(rest) - keep:]
                self.current_content.append(text[pos:closing_pos])
                await self._emit_content(is_fragment)
                pos = closing_pos + len(closing_tag)
                continue

            # Top level: report text and untracked tags, look for tracked tags.
            match = self.tag_pattern.search(text, pos)
            if match is None:
                return await self._handle_top_level_remaining(text[pos:], is_fragment, end)

            before_tag = text[pos:match.start()]
            if before_tag:
                await self._on_tag(None, before_tag, False, is_fragment)
            pos = match.end()

            if match.group(2) in self.tags and match.group(1) != '/':
                # Opening tracked tag: start collecting its content.
                self.current_tag = match.group(2)
                self.current_content = []
            else:
                # Untracked tag or stray closing tag: pass through as text.
                await self._on_tag(None, match.group(0), False, is_fragment)

    async def _handle_top_level_remaining(self, remaining: str, is_fragment: bool, end: bool) -> str:
        """Report trailing top-level text, returning the part that must wait.

        When ``end`` is true nothing more will arrive, so all of ``remaining``
        is returned for the caller to report with ``end=True``. Otherwise only
        a trailing fragment that could still become a tag (``<``, ``</``,
        ``<name``, ``</name``) is held back; a stray ``<`` that can no longer
        form a tag does not delay the text after it.
        """
        if end or not remaining:
            return remaining

        last_less_than = remaining.rfind('<')
        if last_less_than != -1 and self._partial_tag_pattern.fullmatch(remaining, last_less_than):
            complete_part, tail = remaining[:last_less_than], remaining[last_less_than:]
        else:
            complete_part, tail = remaining, ""
        if complete_part:
            await self._on_tag(None, complete_part, False, is_fragment)
        return tail

    async def _emit_content(self, is_fragment: bool) -> None:
        """Report the current tag's stripped content (if non-empty) and leave the tag."""
        content = ''.join(self.current_content).strip()
        if content:
            await self._on_tag(self.current_tag, content, True, is_fragment)
        self.current_tag = None
        self.current_content = []

    async def _finalize_fragment(self) -> None:
        """Flush state after the final fragment and end the untagged stream."""
        if self.current_tag is not None:
            # Tracked tag never closed: report what was collected.
            if self.pending_text:
                self.current_content.append(self.pending_text)
                self.pending_text = ""
            await self._emit_content(True)
        # Exactly one end marker for untagged text, carrying any remaining text.
        await self.on_fragment_tag(None, self.pending_text, True)
        self.pending_text = ""
257
+
258
+
259
class XmlTagToChannelFilter(BaseXmlTagFilter):
    """Route tagged text to per-tag channels and everything else to a default.

    Every key of ``channel_map`` is a tag to intercept; text inside that tag
    goes to the mapped channel. Untagged text goes to ``default_channel``.
    """

    def __init__(self, default_channel: 'TextChannel', channel_map: Dict[str, 'TextChannel']):
        super().__init__(set(channel_map))
        self.default_channel = default_channel
        self.channel_map = channel_map

    def _target_channel(self, tag: Optional[str]) -> 'TextChannel':
        # Untagged text and tags without a mapping fall back to the default.
        if not tag:
            return self.default_channel
        return self.channel_map.get(tag, self.default_channel)

    async def on_message_tag(self, tag: Optional[str], message: str, end: bool):
        """Forward a complete segment to its channel via write_message."""
        await self._target_channel(tag).write_message(message)

    async def on_fragment_tag(self, tag: Optional[str], text: str, end: bool):
        """Forward a streamed segment to its channel via write_fragment."""
        await self._target_channel(tag).write_fragment(text, end)
279
+
280
+
281
class ChannelEvent(typing.NamedTuple):
    """One item yielded by ``collect_text_channels``."""

    # Name of the channel the item came from (a key of the channels mapping).
    channel: str
    # Value read from the channel; None when the channel signalled its end.
    message: Any
    # True when read in fragment mode, False in whole-message mode.
    is_fragment: bool
    # True when this item marks the end of that channel's stream.
    is_end: bool
287
+
288
+
289
async def collect_text_channels(
    channels: Dict[str, TextChannel],
    read_fragments: bool = True,
    timeout: Optional[float] = None
) -> typing.AsyncGenerator[ChannelEvent, None]:
    """Listen to several TextChannels at once and yield their output as events.

    One read is kept outstanding per channel. The channel that produced the
    previous event gets a head start of up to one second before the others
    are considered, so a message streamed on one channel is not interleaved
    with output from the rest. A channel stops being read once it returns
    ``None`` (end marker) or its read raises; a failing read is logged and an
    end event is yielded for that channel.

    Args:
        channels: mapping of channel name to TextChannel.
        read_fragments: read with ``read_fragment()`` (True) or
            ``read_message()`` (False).
        timeout: maximum number of seconds to wait without any channel
            producing data; ``None`` waits indefinitely. Time the consumer
            spends handling yielded events is not counted.

    Yields:
        ChannelEvent: one per read result (channel name, message, fragment
        mode flag, end flag).

    Raises:
        TimeoutError: if no channel produces data within ``timeout`` seconds.
    """
    priority_window = 1.0  # head start (seconds) for the last active channel
    pending: Dict[asyncio.Future, str] = {}  # outstanding read task -> channel name
    last_output_channel: Optional[str] = None

    def schedule_read(channel_name: str) -> None:
        # Start the next read on ``channel_name`` in the selected mode.
        channel = channels[channel_name]
        coro = channel.read_fragment() if read_fragments else channel.read_message()
        pending[asyncio.create_task(coro)] = channel_name

    def time_left(idle_since: float) -> Optional[float]:
        # Remaining idle budget, or None when there is no timeout.
        if timeout is None:
            return None
        return max(0.0, timeout - (time.monotonic() - idle_since))

    for channel_name in channels:
        schedule_read(channel_name)

    try:
        idle_since = time.monotonic()
        while pending:
            preferred = [fut for fut, name in pending.items() if name == last_output_channel]
            if preferred:
                remaining = time_left(idle_since)
                window = priority_window if remaining is None else min(priority_window, remaining)
                await asyncio.wait(preferred, timeout=window, return_when=asyncio.FIRST_COMPLETED)

            ready = [fut for fut in pending if fut.done()]
            if not ready:
                # Head start expired (or no preferred channel): wait on all
                # reads fairly for whatever remains of the idle budget.
                done, _ = await asyncio.wait(
                    list(pending),
                    timeout=time_left(idle_since),
                    return_when=asyncio.FIRST_COMPLETED
                )
                if not done:
                    raise TimeoutError(
                        f"collect_text_channels timed out after {timeout} seconds")
                ready = [fut for fut in pending if fut.done()]
            # Results from the previously active channel go first (stable sort).
            ready.sort(key=lambda fut: pending[fut] != last_output_channel)

            for fut in ready:
                channel_name = pending.pop(fut)
                try:
                    result = fut.result()
                except Exception:
                    logger.exception("Failed to read from channel '%s'", channel_name)
                    # A failed read ends that channel; it is not read again.
                    yield ChannelEvent(
                        channel=channel_name,
                        message=None,
                        is_fragment=read_fragments,
                        is_end=True
                    )
                    continue

                is_end = result is None  # None is the channel's end marker
                yield ChannelEvent(
                    channel=channel_name,
                    message=result,
                    is_fragment=read_fragments,
                    is_end=is_end
                )
                last_output_channel = channel_name
                if not is_end:
                    schedule_read(channel_name)

            # Data arrived: restart the idle clock after the consumer is done.
            idle_since = time.monotonic()
    finally:
        # Cancel outstanding reads and wait for them so no task is leaked.
        for fut in pending:
            fut.cancel()
        await asyncio.gather(*pending, return_exceptions=True)