maque 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- maque/__init__.py +30 -0
- maque/__main__.py +926 -0
- maque/ai_platform/__init__.py +0 -0
- maque/ai_platform/crawl.py +45 -0
- maque/ai_platform/metrics.py +258 -0
- maque/ai_platform/nlp_preprocess.py +67 -0
- maque/ai_platform/webpage_screen_shot.py +195 -0
- maque/algorithms/__init__.py +78 -0
- maque/algorithms/bezier.py +15 -0
- maque/algorithms/bktree.py +117 -0
- maque/algorithms/core.py +104 -0
- maque/algorithms/hilbert.py +16 -0
- maque/algorithms/rate_function.py +92 -0
- maque/algorithms/transform.py +27 -0
- maque/algorithms/trie.py +272 -0
- maque/algorithms/utils.py +63 -0
- maque/algorithms/video.py +587 -0
- maque/api/__init__.py +1 -0
- maque/api/common.py +110 -0
- maque/api/fetch.py +26 -0
- maque/api/static/icon.png +0 -0
- maque/api/static/redoc.standalone.js +1782 -0
- maque/api/static/swagger-ui-bundle.js +3 -0
- maque/api/static/swagger-ui.css +3 -0
- maque/cli/__init__.py +1 -0
- maque/cli/clean_invisible_chars.py +324 -0
- maque/cli/core.py +34 -0
- maque/cli/groups/__init__.py +26 -0
- maque/cli/groups/config.py +205 -0
- maque/cli/groups/data.py +615 -0
- maque/cli/groups/doctor.py +259 -0
- maque/cli/groups/embedding.py +222 -0
- maque/cli/groups/git.py +29 -0
- maque/cli/groups/help.py +410 -0
- maque/cli/groups/llm.py +223 -0
- maque/cli/groups/mcp.py +241 -0
- maque/cli/groups/mllm.py +1795 -0
- maque/cli/groups/mllm_simple.py +60 -0
- maque/cli/groups/quant.py +210 -0
- maque/cli/groups/service.py +490 -0
- maque/cli/groups/system.py +570 -0
- maque/cli/mllm_run.py +1451 -0
- maque/cli/script.py +52 -0
- maque/cli/tree.py +49 -0
- maque/clustering/__init__.py +52 -0
- maque/clustering/analyzer.py +347 -0
- maque/clustering/clusterers.py +464 -0
- maque/clustering/sampler.py +134 -0
- maque/clustering/visualizer.py +205 -0
- maque/constant.py +13 -0
- maque/core.py +133 -0
- maque/cv/__init__.py +1 -0
- maque/cv/image.py +219 -0
- maque/cv/utils.py +68 -0
- maque/cv/video/__init__.py +3 -0
- maque/cv/video/keyframe_extractor.py +368 -0
- maque/embedding/__init__.py +43 -0
- maque/embedding/base.py +56 -0
- maque/embedding/multimodal.py +308 -0
- maque/embedding/server.py +523 -0
- maque/embedding/text.py +311 -0
- maque/git/__init__.py +24 -0
- maque/git/pure_git.py +912 -0
- maque/io/__init__.py +29 -0
- maque/io/core.py +38 -0
- maque/io/ops.py +194 -0
- maque/llm/__init__.py +111 -0
- maque/llm/backend.py +416 -0
- maque/llm/base.py +411 -0
- maque/llm/server.py +366 -0
- maque/mcp_server.py +1096 -0
- maque/mllm_data_processor_pipeline/__init__.py +17 -0
- maque/mllm_data_processor_pipeline/core.py +341 -0
- maque/mllm_data_processor_pipeline/example.py +291 -0
- maque/mllm_data_processor_pipeline/steps/__init__.py +56 -0
- maque/mllm_data_processor_pipeline/steps/data_alignment.py +267 -0
- maque/mllm_data_processor_pipeline/steps/data_loader.py +172 -0
- maque/mllm_data_processor_pipeline/steps/data_validation.py +304 -0
- maque/mllm_data_processor_pipeline/steps/format_conversion.py +411 -0
- maque/mllm_data_processor_pipeline/steps/mllm_annotation.py +331 -0
- maque/mllm_data_processor_pipeline/steps/mllm_refinement.py +446 -0
- maque/mllm_data_processor_pipeline/steps/result_validation.py +501 -0
- maque/mllm_data_processor_pipeline/web_app.py +317 -0
- maque/nlp/__init__.py +14 -0
- maque/nlp/ngram.py +9 -0
- maque/nlp/parser.py +63 -0
- maque/nlp/risk_matcher.py +543 -0
- maque/nlp/sentence_splitter.py +202 -0
- maque/nlp/simple_tradition_cvt.py +31 -0
- maque/performance/__init__.py +21 -0
- maque/performance/_measure_time.py +70 -0
- maque/performance/_profiler.py +367 -0
- maque/performance/_stat_memory.py +51 -0
- maque/pipelines/__init__.py +15 -0
- maque/pipelines/clustering.py +252 -0
- maque/quantization/__init__.py +42 -0
- maque/quantization/auto_round.py +120 -0
- maque/quantization/base.py +145 -0
- maque/quantization/bitsandbytes.py +127 -0
- maque/quantization/llm_compressor.py +102 -0
- maque/retriever/__init__.py +35 -0
- maque/retriever/chroma.py +654 -0
- maque/retriever/document.py +140 -0
- maque/retriever/milvus.py +1140 -0
- maque/table_ops/__init__.py +1 -0
- maque/table_ops/core.py +133 -0
- maque/table_viewer/__init__.py +4 -0
- maque/table_viewer/download_assets.py +57 -0
- maque/table_viewer/server.py +698 -0
- maque/table_viewer/static/element-plus-icons.js +5791 -0
- maque/table_viewer/static/element-plus.css +1 -0
- maque/table_viewer/static/element-plus.js +65236 -0
- maque/table_viewer/static/main.css +268 -0
- maque/table_viewer/static/main.js +669 -0
- maque/table_viewer/static/vue.global.js +18227 -0
- maque/table_viewer/templates/index.html +401 -0
- maque/utils/__init__.py +56 -0
- maque/utils/color.py +68 -0
- maque/utils/color_string.py +45 -0
- maque/utils/compress.py +66 -0
- maque/utils/constant.py +183 -0
- maque/utils/core.py +261 -0
- maque/utils/cursor.py +143 -0
- maque/utils/distance.py +58 -0
- maque/utils/docker.py +96 -0
- maque/utils/downloads.py +51 -0
- maque/utils/excel_helper.py +542 -0
- maque/utils/helper_metrics.py +121 -0
- maque/utils/helper_parser.py +168 -0
- maque/utils/net.py +64 -0
- maque/utils/nvidia_stat.py +140 -0
- maque/utils/ops.py +53 -0
- maque/utils/packages.py +31 -0
- maque/utils/path.py +57 -0
- maque/utils/tar.py +260 -0
- maque/utils/untar.py +129 -0
- maque/web/__init__.py +0 -0
- maque/web/image_downloader.py +1410 -0
- maque-0.2.1.dist-info/METADATA +450 -0
- maque-0.2.1.dist-info/RECORD +143 -0
- maque-0.2.1.dist-info/WHEEL +4 -0
- maque-0.2.1.dist-info/entry_points.txt +3 -0
- maque-0.2.1.dist-info/licenses/LICENSE +21 -0
maque/cli/groups/mllm.py
ADDED
|
@@ -0,0 +1,1795 @@
|
|
|
1
|
+
"""MLLM (多模态大语言模型) 命令组"""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import sys
|
|
5
|
+
|
|
6
|
+
# 强制启用颜色支持
|
|
7
|
+
os.environ["FORCE_COLOR"] = "1"
|
|
8
|
+
if not os.environ.get("TERM"):
|
|
9
|
+
os.environ["TERM"] = "xterm-256color"
|
|
10
|
+
|
|
11
|
+
from rich.console import Console
|
|
12
|
+
from rich import print
|
|
13
|
+
from rich.markdown import Markdown
|
|
14
|
+
|
|
15
|
+
console = Console(
|
|
16
|
+
force_terminal=True,
|
|
17
|
+
width=100,
|
|
18
|
+
color_system="windows",
|
|
19
|
+
legacy_windows=True,
|
|
20
|
+
safe_box=True
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def safe_print(*args, **kwargs):
|
|
25
|
+
"""安全的打印函数,确保在所有终端中正确显示颜色"""
|
|
26
|
+
try:
|
|
27
|
+
console.print(*args, **kwargs)
|
|
28
|
+
except Exception:
|
|
29
|
+
# 降级到普通print,处理编码问题
|
|
30
|
+
import re
|
|
31
|
+
import sys
|
|
32
|
+
import builtins
|
|
33
|
+
|
|
34
|
+
clean_args = []
|
|
35
|
+
for arg in args:
|
|
36
|
+
if isinstance(arg, str):
|
|
37
|
+
# 去除rich markup
|
|
38
|
+
clean_arg = re.sub(r"\[/?[^\]]*\]", "", str(arg))
|
|
39
|
+
# 处理emoji和特殊字符
|
|
40
|
+
try:
|
|
41
|
+
# 尝试编码为gbk (Windows默认编码)
|
|
42
|
+
clean_arg.encode('gbk')
|
|
43
|
+
clean_args.append(clean_arg)
|
|
44
|
+
except UnicodeEncodeError:
|
|
45
|
+
# 如果包含无法编码的字符,替换emoji为文本描述
|
|
46
|
+
clean_arg = re.sub(r'❌', '[错误]', clean_arg)
|
|
47
|
+
clean_arg = re.sub(r'✅', '[成功]', clean_arg)
|
|
48
|
+
clean_arg = re.sub(r'💡', '[提示]', clean_arg)
|
|
49
|
+
clean_arg = re.sub(r'🚀', '[启动]', clean_arg)
|
|
50
|
+
clean_arg = re.sub(r'📦', '[模型]', clean_arg)
|
|
51
|
+
clean_arg = re.sub(r'🌐', '[服务器]', clean_arg)
|
|
52
|
+
clean_arg = re.sub(r'👋', '[再见]', clean_arg)
|
|
53
|
+
clean_arg = re.sub(r'📝', '[记录]', clean_arg)
|
|
54
|
+
clean_arg = re.sub(r'⚠️', '[警告]', clean_arg)
|
|
55
|
+
clean_arg = re.sub(r'🔍', '[搜索]', clean_arg)
|
|
56
|
+
clean_arg = re.sub(r'🤖', '[机器人]', clean_arg)
|
|
57
|
+
clean_arg = re.sub(r'📡', '[网络]', clean_arg)
|
|
58
|
+
clean_arg = re.sub(r'🔌', '[连接]', clean_arg)
|
|
59
|
+
clean_arg = re.sub(r'📋', '[配置]', clean_arg)
|
|
60
|
+
clean_arg = re.sub(r'📁', '[文件]', clean_arg)
|
|
61
|
+
clean_arg = re.sub(r'🔧', '[设置]', clean_arg)
|
|
62
|
+
clean_arg = re.sub(r'🎯', '[目标]', clean_arg)
|
|
63
|
+
clean_arg = re.sub(r'📊', '[统计]', clean_arg)
|
|
64
|
+
clean_arg = re.sub(r'🧠', '[思考]', clean_arg)
|
|
65
|
+
clean_arg = re.sub(r'💭', '[推理]', clean_arg)
|
|
66
|
+
clean_arg = re.sub(r'🔗', '[逻辑]', clean_arg)
|
|
67
|
+
# 移除其他无法显示的emoji
|
|
68
|
+
clean_arg = re.sub(r'[\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF\U00002600-\U000027BF\U0001F900-\U0001F9FF]', '', clean_arg)
|
|
69
|
+
clean_args.append(clean_arg)
|
|
70
|
+
else:
|
|
71
|
+
clean_args.append(str(arg))
|
|
72
|
+
|
|
73
|
+
# 使用内置print
|
|
74
|
+
try:
|
|
75
|
+
builtins.print(*clean_args, **kwargs)
|
|
76
|
+
except UnicodeEncodeError:
|
|
77
|
+
# 最后的降级:使用错误替换
|
|
78
|
+
safe_args = [arg.encode('gbk', errors='replace').decode('gbk') if isinstance(arg, str) else arg for arg in clean_args]
|
|
79
|
+
builtins.print(*safe_args, **kwargs)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def safe_print_stream(text, **kwargs):
|
|
83
|
+
"""安全的流式打印函数,用于流式输出
|
|
84
|
+
|
|
85
|
+
默认使用原生 print 实现真正的流式输出,避免 Rich console 的格式化干扰。
|
|
86
|
+
"""
|
|
87
|
+
import builtins
|
|
88
|
+
|
|
89
|
+
flush = kwargs.pop('flush', True) # 流式输出默认 flush
|
|
90
|
+
end = kwargs.pop('end', '') # 流式输出默认不换行
|
|
91
|
+
|
|
92
|
+
try:
|
|
93
|
+
builtins.print(text, end=end, flush=flush, **kwargs)
|
|
94
|
+
except UnicodeEncodeError:
|
|
95
|
+
# 编码失败时,尝试使用 stdout buffer
|
|
96
|
+
if hasattr(sys.stdout, 'buffer'):
|
|
97
|
+
sys.stdout.buffer.write(text.encode('utf-8', errors='replace'))
|
|
98
|
+
if flush:
|
|
99
|
+
sys.stdout.buffer.flush()
|
|
100
|
+
else:
|
|
101
|
+
# 最后的降级方案:替换无法编码的字符
|
|
102
|
+
safe_text = text.encode('gbk', errors='replace').decode('gbk')
|
|
103
|
+
builtins.print(safe_text, end=end, flush=flush, **kwargs)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def safe_print_markdown(content, **kwargs):
|
|
107
|
+
"""安全的Markdown渲染函数"""
|
|
108
|
+
try:
|
|
109
|
+
# 使用Rich的Markdown渲染
|
|
110
|
+
markdown = Markdown(content)
|
|
111
|
+
console.print(markdown, **kwargs)
|
|
112
|
+
except Exception:
|
|
113
|
+
# 降级到普通打印
|
|
114
|
+
safe_print(content, **kwargs)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class StreamingMarkdownRenderer:
|
|
118
|
+
"""流式Markdown渲染器 - 实时解析并渲染Markdown"""
|
|
119
|
+
|
|
120
|
+
def __init__(self):
|
|
121
|
+
self.buffer = ""
|
|
122
|
+
self.last_rendered_length = 0
|
|
123
|
+
self.in_code_block = False
|
|
124
|
+
self.code_block_lang = ""
|
|
125
|
+
|
|
126
|
+
def add_token(self, token):
|
|
127
|
+
"""添加新token并尝试渲染"""
|
|
128
|
+
self.buffer += token
|
|
129
|
+
self._try_render_incremental()
|
|
130
|
+
|
|
131
|
+
def _try_render_incremental(self):
|
|
132
|
+
"""尝试增量渲染Markdown"""
|
|
133
|
+
# 检测代码块
|
|
134
|
+
if "```" in self.buffer[self.last_rendered_length:]:
|
|
135
|
+
code_block_matches = self.buffer.count("```")
|
|
136
|
+
self.in_code_block = (code_block_matches % 2) == 1
|
|
137
|
+
|
|
138
|
+
# 如果在代码块中,直接输出原始文本
|
|
139
|
+
if self.in_code_block:
|
|
140
|
+
new_content = self.buffer[self.last_rendered_length:]
|
|
141
|
+
if new_content:
|
|
142
|
+
safe_print_stream(new_content, end="", flush=True)
|
|
143
|
+
self.last_rendered_length = len(self.buffer)
|
|
144
|
+
return
|
|
145
|
+
|
|
146
|
+
# 尝试找到可以安全渲染的边界(句子、段落等)
|
|
147
|
+
render_boundary = self._find_render_boundary()
|
|
148
|
+
if render_boundary > self.last_rendered_length:
|
|
149
|
+
content_to_render = self.buffer[self.last_rendered_length:render_boundary]
|
|
150
|
+
self._render_content(content_to_render)
|
|
151
|
+
self.last_rendered_length = render_boundary
|
|
152
|
+
|
|
153
|
+
def _find_render_boundary(self):
|
|
154
|
+
"""找到适合渲染的边界位置"""
|
|
155
|
+
content = self.buffer
|
|
156
|
+
|
|
157
|
+
# 寻找句子结束标记
|
|
158
|
+
for i in range(len(content) - 1, self.last_rendered_length - 1, -1):
|
|
159
|
+
char = content[i]
|
|
160
|
+
# 句子结束
|
|
161
|
+
if char in '.!?。!?':
|
|
162
|
+
# 确保后面有空格或换行,避免误判小数点等
|
|
163
|
+
if i + 1 < len(content) and content[i + 1] in ' \n\t':
|
|
164
|
+
return i + 1
|
|
165
|
+
# 段落结束
|
|
166
|
+
elif char == '\n' and (i + 1 >= len(content) or content[i + 1] == '\n'):
|
|
167
|
+
return i + 1
|
|
168
|
+
|
|
169
|
+
# 如果没有找到合适的边界,返回当前长度(不渲染)
|
|
170
|
+
return self.last_rendered_length
|
|
171
|
+
|
|
172
|
+
def _render_content(self, content):
|
|
173
|
+
"""渲染内容片段"""
|
|
174
|
+
if not content.strip():
|
|
175
|
+
safe_print_stream(content, end="", flush=True)
|
|
176
|
+
return
|
|
177
|
+
|
|
178
|
+
# 简单的行内Markdown渲染
|
|
179
|
+
try:
|
|
180
|
+
# 检查是否包含Markdown元素
|
|
181
|
+
if any(marker in content for marker in ['**', '*', '`', '#', '-', '1.']):
|
|
182
|
+
# 简单的实时渲染,只处理基本元素
|
|
183
|
+
rendered = self._simple_markdown_render(content)
|
|
184
|
+
safe_print_stream(rendered, end="", flush=True)
|
|
185
|
+
else:
|
|
186
|
+
# 纯文本直接输出
|
|
187
|
+
safe_print_stream(content, end="", flush=True)
|
|
188
|
+
except Exception:
|
|
189
|
+
# 出错时降级到原始文本
|
|
190
|
+
safe_print_stream(content, end="", flush=True)
|
|
191
|
+
|
|
192
|
+
def _simple_markdown_render(self, content):
|
|
193
|
+
"""简单的Markdown渲染 - 只处理基本格式"""
|
|
194
|
+
import re
|
|
195
|
+
|
|
196
|
+
# 粗体 **text**
|
|
197
|
+
content = re.sub(r'\*\*([^\*]+)\*\*', r'[bold]\1[/bold]', content)
|
|
198
|
+
# 斜体 *text*
|
|
199
|
+
content = re.sub(r'\*([^\*]+)\*', r'[italic]\1[/italic]', content)
|
|
200
|
+
# 行内代码 `code`
|
|
201
|
+
content = re.sub(r'`([^`]+)`', r'[code]\1[/code]', content)
|
|
202
|
+
|
|
203
|
+
return content
|
|
204
|
+
|
|
205
|
+
def finalize(self):
|
|
206
|
+
"""完成渲染,处理剩余内容"""
|
|
207
|
+
if self.last_rendered_length < len(self.buffer):
|
|
208
|
+
remaining = self.buffer[self.last_rendered_length:]
|
|
209
|
+
self._render_content(remaining)
|
|
210
|
+
|
|
211
|
+
safe_print_stream("", end="\n") # 换行
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def safe_print_stream_markdown(content, is_complete=False, **kwargs):
|
|
215
|
+
"""流式Markdown渲染函数,累积内容后渲染"""
|
|
216
|
+
if is_complete:
|
|
217
|
+
# 完整内容,进行Markdown渲染
|
|
218
|
+
try:
|
|
219
|
+
markdown = Markdown(content)
|
|
220
|
+
console.print(markdown, **kwargs)
|
|
221
|
+
except Exception:
|
|
222
|
+
safe_print_stream(content, **kwargs)
|
|
223
|
+
else:
|
|
224
|
+
# 流式输出,直接打印原始文本
|
|
225
|
+
safe_print_stream(content, **kwargs)
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def get_user_input(prompt_text="You"):
|
|
229
|
+
"""获取用户输入,支持Rich格式的提示"""
|
|
230
|
+
try:
|
|
231
|
+
# 使用console.input来支持Rich格式
|
|
232
|
+
return console.input(f"[bold yellow]{prompt_text}:[/bold yellow] ")
|
|
233
|
+
except Exception:
|
|
234
|
+
# 降级到普通input
|
|
235
|
+
return input(f"{prompt_text}: ")
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
class AdvancedInput:
|
|
239
|
+
"""高级输入处理器,支持多行输入(Alt+Enter 换行)"""
|
|
240
|
+
|
|
241
|
+
def __init__(self):
|
|
242
|
+
self._use_prompt_toolkit = False
|
|
243
|
+
self._bindings = None
|
|
244
|
+
self._init_prompt_toolkit()
|
|
245
|
+
|
|
246
|
+
def _init_prompt_toolkit(self):
|
|
247
|
+
"""初始化 prompt_toolkit 的键绑定"""
|
|
248
|
+
try:
|
|
249
|
+
from prompt_toolkit.key_binding import KeyBindings
|
|
250
|
+
from prompt_toolkit.keys import Keys
|
|
251
|
+
|
|
252
|
+
# 创建快捷键绑定
|
|
253
|
+
self._bindings = KeyBindings()
|
|
254
|
+
|
|
255
|
+
@self._bindings.add(Keys.Enter)
|
|
256
|
+
def _(event):
|
|
257
|
+
"""Enter 提交输入"""
|
|
258
|
+
event.current_buffer.validate_and_handle()
|
|
259
|
+
|
|
260
|
+
# Alt+Enter (Escape + Enter) 换行 - 最可靠的方式
|
|
261
|
+
@self._bindings.add('escape', 'enter')
|
|
262
|
+
def _(event):
|
|
263
|
+
"""Alt+Enter 换行"""
|
|
264
|
+
event.current_buffer.insert_text('\n')
|
|
265
|
+
|
|
266
|
+
self._use_prompt_toolkit = True
|
|
267
|
+
except ImportError:
|
|
268
|
+
self._use_prompt_toolkit = False
|
|
269
|
+
|
|
270
|
+
def _sync_prompt(self, prompt_text: str) -> str:
|
|
271
|
+
"""同步调用 prompt_toolkit(在单独线程中运行)"""
|
|
272
|
+
from prompt_toolkit import prompt as pt_prompt
|
|
273
|
+
return pt_prompt(
|
|
274
|
+
f"{prompt_text}: ",
|
|
275
|
+
key_bindings=self._bindings,
|
|
276
|
+
multiline=False,
|
|
277
|
+
)
|
|
278
|
+
|
|
279
|
+
def get_input(self, prompt_text="You") -> str:
|
|
280
|
+
"""获取用户输入,支持多行(同步版本)"""
|
|
281
|
+
if self._use_prompt_toolkit:
|
|
282
|
+
try:
|
|
283
|
+
return self._sync_prompt(prompt_text)
|
|
284
|
+
except (KeyboardInterrupt, EOFError):
|
|
285
|
+
raise
|
|
286
|
+
except Exception:
|
|
287
|
+
# 出错时降级到基本输入
|
|
288
|
+
self._use_prompt_toolkit = False
|
|
289
|
+
|
|
290
|
+
# Fallback 到基本输入
|
|
291
|
+
return get_user_input(prompt_text)
|
|
292
|
+
|
|
293
|
+
async def get_input_async(self, prompt_text="You") -> str:
|
|
294
|
+
"""获取用户输入,支持多行(异步版本,在单独线程中运行)"""
|
|
295
|
+
if self._use_prompt_toolkit:
|
|
296
|
+
try:
|
|
297
|
+
import asyncio
|
|
298
|
+
# 在单独线程中运行 prompt_toolkit,避免与 asyncio 冲突
|
|
299
|
+
return await asyncio.to_thread(self._sync_prompt, prompt_text)
|
|
300
|
+
except (KeyboardInterrupt, EOFError):
|
|
301
|
+
raise
|
|
302
|
+
except Exception:
|
|
303
|
+
# 出错时降级到基本输入
|
|
304
|
+
self._use_prompt_toolkit = False
|
|
305
|
+
|
|
306
|
+
# Fallback 到基本输入
|
|
307
|
+
return get_user_input(prompt_text)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
class ChatCommands:
|
|
311
|
+
"""聊天快捷命令处理器"""
|
|
312
|
+
|
|
313
|
+
COMMANDS = {
|
|
314
|
+
'/clear': '清空对话历史',
|
|
315
|
+
'/retry': '重新生成上一条回复',
|
|
316
|
+
'/save': '保存对话到文件 (用法: /save [文件名])',
|
|
317
|
+
'/model': '切换模型 (用法: /model [模型名])',
|
|
318
|
+
'/help': '显示帮助信息',
|
|
319
|
+
}
|
|
320
|
+
|
|
321
|
+
@classmethod
|
|
322
|
+
def is_command(cls, text: str) -> bool:
|
|
323
|
+
"""检查是否是命令"""
|
|
324
|
+
return text.strip().startswith('/')
|
|
325
|
+
|
|
326
|
+
@classmethod
|
|
327
|
+
def parse(cls, text: str) -> tuple:
|
|
328
|
+
"""解析命令,返回 (命令名, 参数列表)"""
|
|
329
|
+
parts = text.strip().split(maxsplit=1)
|
|
330
|
+
cmd = parts[0].lower()
|
|
331
|
+
args = parts[1] if len(parts) > 1 else ""
|
|
332
|
+
return cmd, args
|
|
333
|
+
|
|
334
|
+
@classmethod
|
|
335
|
+
def show_help(cls):
|
|
336
|
+
"""显示帮助信息"""
|
|
337
|
+
safe_print("\n[bold cyan]📋 可用命令:[/bold cyan]")
|
|
338
|
+
for cmd, desc in cls.COMMANDS.items():
|
|
339
|
+
safe_print(f" [green]{cmd:12}[/green] - {desc}")
|
|
340
|
+
safe_print("")
|
|
341
|
+
|
|
342
|
+
@classmethod
|
|
343
|
+
def handle_clear(cls, messages: list, system_prompt: str = None) -> list:
|
|
344
|
+
"""清空对话历史"""
|
|
345
|
+
new_messages = []
|
|
346
|
+
if system_prompt:
|
|
347
|
+
new_messages.append({"role": "system", "content": system_prompt})
|
|
348
|
+
safe_print("[dim]🗑️ 对话历史已清空[/dim]\n")
|
|
349
|
+
return new_messages
|
|
350
|
+
|
|
351
|
+
@classmethod
|
|
352
|
+
def handle_save(cls, messages: list, filename: str = None):
|
|
353
|
+
"""保存对话到文件"""
|
|
354
|
+
import json
|
|
355
|
+
from datetime import datetime
|
|
356
|
+
|
|
357
|
+
if not filename:
|
|
358
|
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
359
|
+
filename = f"chat_{timestamp}.json"
|
|
360
|
+
|
|
361
|
+
if not filename.endswith('.json'):
|
|
362
|
+
filename += '.json'
|
|
363
|
+
|
|
364
|
+
# 过滤掉系统消息,只保存用户和助手的对话
|
|
365
|
+
chat_history = [
|
|
366
|
+
msg for msg in messages
|
|
367
|
+
if msg.get('role') in ['user', 'assistant']
|
|
368
|
+
]
|
|
369
|
+
|
|
370
|
+
try:
|
|
371
|
+
with open(filename, 'w', encoding='utf-8') as f:
|
|
372
|
+
json.dump({
|
|
373
|
+
'saved_at': datetime.now().isoformat(),
|
|
374
|
+
'messages': chat_history
|
|
375
|
+
}, f, ensure_ascii=False, indent=2)
|
|
376
|
+
safe_print(f"[green]💾 对话已保存到: {filename}[/green]\n")
|
|
377
|
+
except Exception as e:
|
|
378
|
+
safe_print(f"[red]❌ 保存失败: {e}[/red]\n")
|
|
379
|
+
|
|
380
|
+
@classmethod
|
|
381
|
+
def handle_retry(cls, messages: list) -> tuple:
|
|
382
|
+
"""准备重试:移除最后一条助手回复,返回是否需要重试"""
|
|
383
|
+
if len(messages) < 2:
|
|
384
|
+
safe_print("[yellow]⚠️ 没有可以重试的回复[/yellow]\n")
|
|
385
|
+
return messages, False
|
|
386
|
+
|
|
387
|
+
# 找到最后一条助手消息并移除
|
|
388
|
+
if messages[-1].get('role') == 'assistant':
|
|
389
|
+
messages.pop()
|
|
390
|
+
safe_print("[dim]🔄 正在重新生成...[/dim]")
|
|
391
|
+
return messages, True
|
|
392
|
+
else:
|
|
393
|
+
safe_print("[yellow]⚠️ 最后一条不是助手回复,无法重试[/yellow]\n")
|
|
394
|
+
return messages, False
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
class MllmGroup:
|
|
398
|
+
"""MLLM命令组 - 统一管理多模态大语言模型相关功能"""
|
|
399
|
+
|
|
400
|
+
def __init__(self, cli_instance):
|
|
401
|
+
self.cli = cli_instance
|
|
402
|
+
|
|
403
|
+
def call_table(
|
|
404
|
+
self,
|
|
405
|
+
table_path: str,
|
|
406
|
+
model: str = None,
|
|
407
|
+
base_url: str = None,
|
|
408
|
+
api_key: str = None,
|
|
409
|
+
image_col: str = "image",
|
|
410
|
+
system_prompt: str = "你是一个专业的图像识别专家。",
|
|
411
|
+
text_prompt: str = "请描述这张图像。",
|
|
412
|
+
system_prompt_file: str = None,
|
|
413
|
+
text_prompt_file: str = None,
|
|
414
|
+
sheet_name: str = 0,
|
|
415
|
+
max_num=None,
|
|
416
|
+
output_file: str = "table_results.csv",
|
|
417
|
+
temperature: float = 0.1,
|
|
418
|
+
max_tokens: int = 2000,
|
|
419
|
+
concurrency_limit: int = 10,
|
|
420
|
+
max_qps: int = 50,
|
|
421
|
+
retry_times: int = 3,
|
|
422
|
+
skip_existing: bool = False,
|
|
423
|
+
**kwargs,
|
|
424
|
+
):
|
|
425
|
+
"""对表格中的图像列进行批量大模型识别和分析
|
|
426
|
+
|
|
427
|
+
Args:
|
|
428
|
+
table_path: 表格文件路径 (xlsx/csv)
|
|
429
|
+
model: 模型名称
|
|
430
|
+
base_url: API服务地址
|
|
431
|
+
api_key: API密钥
|
|
432
|
+
image_col: 图片列名
|
|
433
|
+
system_prompt: 系统提示词
|
|
434
|
+
text_prompt: 文本提示词
|
|
435
|
+
system_prompt_file: 系统提示词文件路径(优先于 system_prompt)
|
|
436
|
+
text_prompt_file: 文本提示词文件路径(优先于 text_prompt)
|
|
437
|
+
sheet_name: sheet名称
|
|
438
|
+
max_num: 最大处理数量
|
|
439
|
+
output_file: 输出文件路径
|
|
440
|
+
temperature: 温度参数
|
|
441
|
+
max_tokens: 最大token数
|
|
442
|
+
concurrency_limit: 并发限制
|
|
443
|
+
max_qps: 最大QPS
|
|
444
|
+
retry_times: 重试次数
|
|
445
|
+
skip_existing: 是否跳过已有结果的行(断点续传)
|
|
446
|
+
"""
|
|
447
|
+
import asyncio
|
|
448
|
+
import pandas as pd
|
|
449
|
+
import os
|
|
450
|
+
from flexllm.mllm_client import MllmClient
|
|
451
|
+
|
|
452
|
+
# 从配置文件获取默认值
|
|
453
|
+
mllm_config = self.cli.maque_config.get("mllm", {})
|
|
454
|
+
model = model or mllm_config.get("model", "gemma3:latest")
|
|
455
|
+
base_url = base_url or mllm_config.get("base_url", "http://localhost:11434/v1")
|
|
456
|
+
api_key = api_key or mllm_config.get("api_key", "EMPTY")
|
|
457
|
+
|
|
458
|
+
# 从文件读取 prompt(如果指定)
|
|
459
|
+
if system_prompt_file and os.path.exists(system_prompt_file):
|
|
460
|
+
with open(system_prompt_file, 'r', encoding='utf-8') as f:
|
|
461
|
+
system_prompt = f.read().strip()
|
|
462
|
+
safe_print(f"[dim]📄 从文件加载 system_prompt: {system_prompt_file}[/dim]")
|
|
463
|
+
|
|
464
|
+
if text_prompt_file and os.path.exists(text_prompt_file):
|
|
465
|
+
with open(text_prompt_file, 'r', encoding='utf-8') as f:
|
|
466
|
+
text_prompt = f.read().strip()
|
|
467
|
+
safe_print(f"[dim]📄 从文件加载 text_prompt: {text_prompt_file}[/dim]")
|
|
468
|
+
|
|
469
|
+
async def run_call_table():
|
|
470
|
+
try:
|
|
471
|
+
safe_print(f"\n[bold green]📊 开始批量处理表格[/bold green]")
|
|
472
|
+
safe_print(f"[cyan]📁 文件: {table_path}[/cyan]")
|
|
473
|
+
safe_print(f"[dim]🔧 模型: {model} | 并发: {concurrency_limit} | QPS: {max_qps}[/dim]")
|
|
474
|
+
|
|
475
|
+
# 初始化客户端
|
|
476
|
+
client = MllmClient(
|
|
477
|
+
model=model,
|
|
478
|
+
base_url=base_url,
|
|
479
|
+
api_key=api_key,
|
|
480
|
+
concurrency_limit=concurrency_limit,
|
|
481
|
+
max_qps=max_qps,
|
|
482
|
+
retry_times=retry_times,
|
|
483
|
+
**kwargs,
|
|
484
|
+
)
|
|
485
|
+
|
|
486
|
+
# 加载数据
|
|
487
|
+
if table_path.endswith(".xlsx"):
|
|
488
|
+
df = pd.read_excel(table_path, sheet_name=sheet_name)
|
|
489
|
+
else:
|
|
490
|
+
df = pd.read_csv(table_path)
|
|
491
|
+
|
|
492
|
+
total_rows = len(df)
|
|
493
|
+
if max_num:
|
|
494
|
+
df = df.head(max_num)
|
|
495
|
+
|
|
496
|
+
safe_print(f"[dim]📝 总行数: {total_rows}, 处理行数: {len(df)}[/dim]")
|
|
497
|
+
|
|
498
|
+
# 检查并创建结果列
|
|
499
|
+
result_col = "mllm_result"
|
|
500
|
+
if result_col not in df.columns:
|
|
501
|
+
df[result_col] = None
|
|
502
|
+
|
|
503
|
+
# 断点续传:过滤已有结果的行
|
|
504
|
+
if skip_existing and os.path.exists(output_file):
|
|
505
|
+
existing_df = pd.read_csv(output_file) if output_file.endswith('.csv') else pd.read_excel(output_file)
|
|
506
|
+
if result_col in existing_df.columns:
|
|
507
|
+
# 合并已有结果
|
|
508
|
+
df[result_col] = existing_df[result_col] if len(existing_df) == len(df) else df[result_col]
|
|
509
|
+
safe_print(f"[yellow]⏭️ 断点续传: 检测到已有结果文件[/yellow]")
|
|
510
|
+
|
|
511
|
+
# 找出需要处理的行
|
|
512
|
+
if skip_existing:
|
|
513
|
+
pending_mask = df[result_col].isna() | (df[result_col] == '') | (df[result_col] == 'None')
|
|
514
|
+
pending_indices = df[pending_mask].index.tolist()
|
|
515
|
+
else:
|
|
516
|
+
pending_indices = df.index.tolist()
|
|
517
|
+
|
|
518
|
+
if not pending_indices:
|
|
519
|
+
safe_print(f"[green]✅ 所有行已处理完成,无需重新处理[/green]")
|
|
520
|
+
return
|
|
521
|
+
|
|
522
|
+
safe_print(f"[cyan]🔄 待处理: {len(pending_indices)} 行[/cyan]")
|
|
523
|
+
|
|
524
|
+
# 构建待处理的 messages
|
|
525
|
+
messages_list = []
|
|
526
|
+
for idx in pending_indices:
|
|
527
|
+
row = df.loc[idx]
|
|
528
|
+
messages = []
|
|
529
|
+
if system_prompt:
|
|
530
|
+
messages.append({"role": "system", "content": system_prompt})
|
|
531
|
+
messages.append({
|
|
532
|
+
"role": "user",
|
|
533
|
+
"content": [
|
|
534
|
+
{"type": "text", "text": text_prompt},
|
|
535
|
+
{"type": "image_url", "image_url": {"url": str(row[image_col])}},
|
|
536
|
+
],
|
|
537
|
+
})
|
|
538
|
+
messages_list.append(messages)
|
|
539
|
+
|
|
540
|
+
# 调用 MLLM
|
|
541
|
+
results = await client.call_llm(
|
|
542
|
+
messages_list,
|
|
543
|
+
temperature=temperature,
|
|
544
|
+
max_tokens=max_tokens,
|
|
545
|
+
)
|
|
546
|
+
|
|
547
|
+
# 填充结果
|
|
548
|
+
for i, idx in enumerate(pending_indices):
|
|
549
|
+
df.at[idx, result_col] = results[i] if i < len(results) else None
|
|
550
|
+
|
|
551
|
+
# 保存结果
|
|
552
|
+
if output_file.endswith('.csv'):
|
|
553
|
+
df.to_csv(output_file, index=False, encoding='utf-8-sig')
|
|
554
|
+
else:
|
|
555
|
+
df.to_excel(output_file, index=False)
|
|
556
|
+
|
|
557
|
+
safe_print(f"\n[bold green]✅ 处理完成!结果已保存到: {output_file}[/bold green]")
|
|
558
|
+
|
|
559
|
+
# 统计
|
|
560
|
+
success_count = df[result_col].notna().sum()
|
|
561
|
+
safe_print(f"[dim]📊 成功: {success_count}/{len(df)}[/dim]")
|
|
562
|
+
|
|
563
|
+
except Exception as e:
|
|
564
|
+
safe_print(f"[red]❌ 处理失败: {e}[/red]")
|
|
565
|
+
import traceback
|
|
566
|
+
traceback.print_exc()
|
|
567
|
+
|
|
568
|
+
return asyncio.run(run_call_table())
|
|
569
|
+
|
|
570
|
+
def call_images(
|
|
571
|
+
self,
|
|
572
|
+
folder_path: str,
|
|
573
|
+
model: str = None,
|
|
574
|
+
base_url: str = None,
|
|
575
|
+
api_key: str = None,
|
|
576
|
+
system_prompt: str = "你是一个专业的图像识别专家。",
|
|
577
|
+
text_prompt: str = "请描述这张图像。",
|
|
578
|
+
system_prompt_file: str = None,
|
|
579
|
+
text_prompt_file: str = None,
|
|
580
|
+
recursive: bool = True,
|
|
581
|
+
max_num: int = None,
|
|
582
|
+
extensions: str = None,
|
|
583
|
+
output_file: str = "results.csv",
|
|
584
|
+
temperature: float = 0.1,
|
|
585
|
+
max_tokens: int = 2000,
|
|
586
|
+
concurrency_limit: int = 10,
|
|
587
|
+
max_qps: int = 50,
|
|
588
|
+
retry_times: int = 3,
|
|
589
|
+
skip_existing: bool = False,
|
|
590
|
+
**kwargs,
|
|
591
|
+
):
|
|
592
|
+
"""对文件夹中的图像进行批量大模型识别和分析
|
|
593
|
+
|
|
594
|
+
Args:
|
|
595
|
+
folder_path: 文件夹路径
|
|
596
|
+
model: 模型名称
|
|
597
|
+
base_url: API服务地址
|
|
598
|
+
api_key: API密钥
|
|
599
|
+
system_prompt: 系统提示词
|
|
600
|
+
text_prompt: 文本提示词
|
|
601
|
+
system_prompt_file: 系统提示词文件路径(优先于 system_prompt)
|
|
602
|
+
text_prompt_file: 文本提示词文件路径(优先于 text_prompt)
|
|
603
|
+
recursive: 是否递归扫描子文件夹
|
|
604
|
+
max_num: 最大处理数量
|
|
605
|
+
extensions: 支持的文件扩展名(逗号分隔,如 "jpg,png,webp")
|
|
606
|
+
output_file: 输出文件路径
|
|
607
|
+
temperature: 温度参数
|
|
608
|
+
max_tokens: 最大token数
|
|
609
|
+
concurrency_limit: 并发限制
|
|
610
|
+
max_qps: 最大QPS
|
|
611
|
+
retry_times: 重试次数
|
|
612
|
+
skip_existing: 是否跳过已处理的图片(断点续传)
|
|
613
|
+
"""
|
|
614
|
+
import asyncio
|
|
615
|
+
import pandas as pd
|
|
616
|
+
import os
|
|
617
|
+
from pathlib import Path
|
|
618
|
+
from flexllm.mllm_client import MllmClient
|
|
619
|
+
|
|
620
|
+
# 从配置文件获取默认值
|
|
621
|
+
mllm_config = self.cli.maque_config.get("mllm", {})
|
|
622
|
+
model = model or mllm_config.get("model", "gemma3:latest")
|
|
623
|
+
base_url = base_url or mllm_config.get("base_url", "http://localhost:11434/v1")
|
|
624
|
+
api_key = api_key or mllm_config.get("api_key", "EMPTY")
|
|
625
|
+
|
|
626
|
+
# 从文件读取 prompt(如果指定)
|
|
627
|
+
if system_prompt_file and os.path.exists(system_prompt_file):
|
|
628
|
+
with open(system_prompt_file, 'r', encoding='utf-8') as f:
|
|
629
|
+
system_prompt = f.read().strip()
|
|
630
|
+
safe_print(f"[dim]📄 从文件加载 system_prompt: {system_prompt_file}[/dim]")
|
|
631
|
+
|
|
632
|
+
if text_prompt_file and os.path.exists(text_prompt_file):
|
|
633
|
+
with open(text_prompt_file, 'r', encoding='utf-8') as f:
|
|
634
|
+
text_prompt = f.read().strip()
|
|
635
|
+
safe_print(f"[dim]📄 从文件加载 text_prompt: {text_prompt_file}[/dim]")
|
|
636
|
+
|
|
637
|
+
# 解析扩展名
|
|
638
|
+
ext_set = None
|
|
639
|
+
if extensions:
|
|
640
|
+
ext_set = {f".{ext.strip().lower().lstrip('.')}" for ext in extensions.split(',')}
|
|
641
|
+
|
|
642
|
+
async def run_call_images():
|
|
643
|
+
try:
|
|
644
|
+
safe_print(f"\n[bold green]📁 开始批量处理文件夹图片[/bold green]")
|
|
645
|
+
safe_print(f"[cyan]📂 路径: {folder_path}[/cyan]")
|
|
646
|
+
safe_print(f"[dim]🔧 模型: {model} | 并发: {concurrency_limit} | QPS: {max_qps}[/dim]")
|
|
647
|
+
|
|
648
|
+
# 初始化客户端
|
|
649
|
+
client = MllmClient(
|
|
650
|
+
model=model,
|
|
651
|
+
base_url=base_url,
|
|
652
|
+
api_key=api_key,
|
|
653
|
+
concurrency_limit=concurrency_limit,
|
|
654
|
+
max_qps=max_qps,
|
|
655
|
+
retry_times=retry_times,
|
|
656
|
+
**kwargs,
|
|
657
|
+
)
|
|
658
|
+
|
|
659
|
+
# 扫描图片文件
|
|
660
|
+
image_files = client.folder.scan_folder_images(
|
|
661
|
+
folder_path=folder_path,
|
|
662
|
+
recursive=recursive,
|
|
663
|
+
max_num=max_num,
|
|
664
|
+
extensions=ext_set,
|
|
665
|
+
)
|
|
666
|
+
|
|
667
|
+
if not image_files:
|
|
668
|
+
safe_print(f"[yellow]⚠️ 未找到图片文件[/yellow]")
|
|
669
|
+
return
|
|
670
|
+
|
|
671
|
+
# 创建结果 DataFrame
|
|
672
|
+
df = pd.DataFrame({'image_path': image_files})
|
|
673
|
+
result_col = "mllm_result"
|
|
674
|
+
df[result_col] = None
|
|
675
|
+
|
|
676
|
+
# 断点续传:加载已有结果
|
|
677
|
+
processed_paths = set()
|
|
678
|
+
if skip_existing and os.path.exists(output_file):
|
|
679
|
+
try:
|
|
680
|
+
existing_df = pd.read_csv(output_file) if output_file.endswith('.csv') else pd.read_excel(output_file)
|
|
681
|
+
if 'image_path' in existing_df.columns and result_col in existing_df.columns:
|
|
682
|
+
# 创建路径到结果的映射
|
|
683
|
+
for _, row in existing_df.iterrows():
|
|
684
|
+
path = row['image_path']
|
|
685
|
+
result = row[result_col]
|
|
686
|
+
if pd.notna(result) and result != '' and result != 'None':
|
|
687
|
+
processed_paths.add(path)
|
|
688
|
+
# 更新 df 中对应行的结果
|
|
689
|
+
mask = df['image_path'] == path
|
|
690
|
+
if mask.any():
|
|
691
|
+
df.loc[mask, result_col] = result
|
|
692
|
+
safe_print(f"[yellow]⏭️ 断点续传: 已处理 {len(processed_paths)} 个文件[/yellow]")
|
|
693
|
+
except Exception as e:
|
|
694
|
+
safe_print(f"[yellow]⚠️ 读取已有结果失败: {e}[/yellow]")
|
|
695
|
+
|
|
696
|
+
# 找出需要处理的文件
|
|
697
|
+
pending_indices = []
|
|
698
|
+
for idx, row in df.iterrows():
|
|
699
|
+
if row['image_path'] not in processed_paths:
|
|
700
|
+
pending_indices.append(idx)
|
|
701
|
+
|
|
702
|
+
if not pending_indices:
|
|
703
|
+
safe_print(f"[green]✅ 所有图片已处理完成,无需重新处理[/green]")
|
|
704
|
+
return
|
|
705
|
+
|
|
706
|
+
safe_print(f"[cyan]🔄 待处理: {len(pending_indices)} 个图片[/cyan]")
|
|
707
|
+
|
|
708
|
+
# 构建 messages
|
|
709
|
+
messages_list = []
|
|
710
|
+
pending_files = []
|
|
711
|
+
for idx in pending_indices:
|
|
712
|
+
image_path = df.loc[idx, 'image_path']
|
|
713
|
+
pending_files.append(image_path)
|
|
714
|
+
messages = []
|
|
715
|
+
if system_prompt:
|
|
716
|
+
messages.append({"role": "system", "content": system_prompt})
|
|
717
|
+
messages.append({
|
|
718
|
+
"role": "user",
|
|
719
|
+
"content": [
|
|
720
|
+
{"type": "text", "text": text_prompt},
|
|
721
|
+
{"type": "image_url", "image_url": {"url": f"file://{image_path}"}},
|
|
722
|
+
],
|
|
723
|
+
})
|
|
724
|
+
messages_list.append(messages)
|
|
725
|
+
|
|
726
|
+
# 调用 MLLM
|
|
727
|
+
results = await client.call_llm(
|
|
728
|
+
messages_list,
|
|
729
|
+
temperature=temperature,
|
|
730
|
+
max_tokens=max_tokens,
|
|
731
|
+
)
|
|
732
|
+
|
|
733
|
+
# 填充结果
|
|
734
|
+
for i, idx in enumerate(pending_indices):
|
|
735
|
+
df.at[idx, result_col] = results[i] if i < len(results) else None
|
|
736
|
+
|
|
737
|
+
# 保存结果
|
|
738
|
+
if output_file.endswith('.csv'):
|
|
739
|
+
df.to_csv(output_file, index=False, encoding='utf-8-sig')
|
|
740
|
+
else:
|
|
741
|
+
df.to_excel(output_file, index=False)
|
|
742
|
+
|
|
743
|
+
safe_print(f"\n[bold green]✅ 处理完成!结果已保存到: {output_file}[/bold green]")
|
|
744
|
+
|
|
745
|
+
# 统计
|
|
746
|
+
success_count = df[result_col].notna().sum()
|
|
747
|
+
safe_print(f"[dim]📊 成功: {success_count}/{len(df)}[/dim]")
|
|
748
|
+
|
|
749
|
+
except Exception as e:
|
|
750
|
+
safe_print(f"[red]❌ 处理失败: {e}[/red]")
|
|
751
|
+
import traceback
|
|
752
|
+
traceback.print_exc()
|
|
753
|
+
|
|
754
|
+
return asyncio.run(run_call_images())
|
|
755
|
+
|
|
756
|
+
def chat(
|
|
757
|
+
self,
|
|
758
|
+
message: str = None,
|
|
759
|
+
image: str = None,
|
|
760
|
+
model: str = None,
|
|
761
|
+
base_url: str = None,
|
|
762
|
+
api_key: str = None,
|
|
763
|
+
system_prompt: str = None,
|
|
764
|
+
temperature: float = 0.1,
|
|
765
|
+
max_tokens: int = 2000,
|
|
766
|
+
stream: bool = True,
|
|
767
|
+
**kwargs,
|
|
768
|
+
):
|
|
769
|
+
"""交互式多模态对话"""
|
|
770
|
+
# 同步版本,简化处理
|
|
771
|
+
import asyncio
|
|
772
|
+
from flexllm.mllm_client import MllmClient
|
|
773
|
+
|
|
774
|
+
# 从配置文件获取默认值
|
|
775
|
+
mllm_config = self.cli.maque_config.get("mllm", {})
|
|
776
|
+
|
|
777
|
+
if model is None:
|
|
778
|
+
model_name = mllm_config.get("model", "gemma3:latest")
|
|
779
|
+
else:
|
|
780
|
+
model_name = model
|
|
781
|
+
|
|
782
|
+
if base_url is None:
|
|
783
|
+
base_url_val = mllm_config.get("base_url", "http://localhost:11434/v1")
|
|
784
|
+
else:
|
|
785
|
+
base_url_val = base_url
|
|
786
|
+
|
|
787
|
+
if api_key is None:
|
|
788
|
+
api_key_val = mllm_config.get("api_key", "EMPTY")
|
|
789
|
+
else:
|
|
790
|
+
api_key_val = api_key
|
|
791
|
+
|
|
792
|
+
if message:
|
|
793
|
+
# 单次对话模式
|
|
794
|
+
def run_single_chat():
|
|
795
|
+
async def _single_chat():
|
|
796
|
+
try:
|
|
797
|
+
# 初始化客户端
|
|
798
|
+
client = MllmClient(
|
|
799
|
+
model=model_name,
|
|
800
|
+
base_url=base_url_val,
|
|
801
|
+
api_key=api_key_val,
|
|
802
|
+
**kwargs,
|
|
803
|
+
)
|
|
804
|
+
|
|
805
|
+
messages = [
|
|
806
|
+
{
|
|
807
|
+
"role": "user",
|
|
808
|
+
"content": message
|
|
809
|
+
if not image
|
|
810
|
+
else [
|
|
811
|
+
{"type": "text", "text": message},
|
|
812
|
+
{"type": "image_url", "image_url": {"url": image}},
|
|
813
|
+
],
|
|
814
|
+
}
|
|
815
|
+
]
|
|
816
|
+
|
|
817
|
+
if system_prompt:
|
|
818
|
+
messages.insert(
|
|
819
|
+
0, {"role": "system", "content": system_prompt}
|
|
820
|
+
)
|
|
821
|
+
|
|
822
|
+
if stream:
|
|
823
|
+
# 流式输出 - 使用优雅的Markdown渲染器
|
|
824
|
+
safe_print(f"[bold blue]Assistant:[/bold blue] ")
|
|
825
|
+
|
|
826
|
+
renderer = StreamingMarkdownRenderer()
|
|
827
|
+
try:
|
|
828
|
+
async for token in client.call_llm_stream(
|
|
829
|
+
messages=messages,
|
|
830
|
+
temperature=temperature,
|
|
831
|
+
max_tokens=max_tokens,
|
|
832
|
+
**kwargs,
|
|
833
|
+
):
|
|
834
|
+
renderer.add_token(token)
|
|
835
|
+
# 完成流式输出
|
|
836
|
+
renderer.finalize()
|
|
837
|
+
except KeyboardInterrupt:
|
|
838
|
+
safe_print_stream("\n")
|
|
839
|
+
safe_print("[dim]⏸️ 输出已中断[/dim]")
|
|
840
|
+
return renderer.buffer
|
|
841
|
+
else:
|
|
842
|
+
# 非流式输出,使用Markdown渲染
|
|
843
|
+
results = await client.call_llm(
|
|
844
|
+
messages_list=[messages], show_progress=False
|
|
845
|
+
)
|
|
846
|
+
response = (
|
|
847
|
+
results[0] if results and results[0] else "无响应"
|
|
848
|
+
)
|
|
849
|
+
safe_print(f"[bold blue]Assistant:[/bold blue]")
|
|
850
|
+
safe_print_markdown(response)
|
|
851
|
+
return response
|
|
852
|
+
except KeyboardInterrupt:
|
|
853
|
+
safe_print("\n[dim]👋 再见![/dim]")
|
|
854
|
+
return None
|
|
855
|
+
except Exception as e:
|
|
856
|
+
safe_print(f"[red]❌ 执行错误: {e}[/red]")
|
|
857
|
+
safe_print("[yellow]💡 请检查模型配置和网络连接[/yellow]")
|
|
858
|
+
return None
|
|
859
|
+
|
|
860
|
+
try:
|
|
861
|
+
return asyncio.run(_single_chat())
|
|
862
|
+
except KeyboardInterrupt:
|
|
863
|
+
safe_print("\n[dim]👋 再见![/dim]")
|
|
864
|
+
return None
|
|
865
|
+
|
|
866
|
+
return run_single_chat()
|
|
867
|
+
else:
|
|
868
|
+
# 多轮交互模式
|
|
869
|
+
def run_interactive_chat():
|
|
870
|
+
async def _interactive_chat():
|
|
871
|
+
try:
|
|
872
|
+
# 初始化客户端
|
|
873
|
+
client = MllmClient(
|
|
874
|
+
model=model_name,
|
|
875
|
+
base_url=base_url_val,
|
|
876
|
+
api_key=api_key_val,
|
|
877
|
+
**kwargs,
|
|
878
|
+
)
|
|
879
|
+
|
|
880
|
+
# 初始化对话历史
|
|
881
|
+
messages = []
|
|
882
|
+
if system_prompt:
|
|
883
|
+
messages.append(
|
|
884
|
+
{"role": "system", "content": system_prompt}
|
|
885
|
+
)
|
|
886
|
+
|
|
887
|
+
# 初始化高级输入处理器
|
|
888
|
+
advanced_input = AdvancedInput()
|
|
889
|
+
current_model = model_name # 用于支持 /model 切换
|
|
890
|
+
|
|
891
|
+
safe_print("\n[bold green]🚀 多轮对话模式启动[/bold green]")
|
|
892
|
+
safe_print(f"[cyan]📦 模型: [/cyan][bold]{current_model}[/bold]")
|
|
893
|
+
safe_print(f"[cyan]🌐 服务器: [/cyan][bold]{base_url_val}[/bold]")
|
|
894
|
+
safe_print(f"[dim]💡 输入 [bold]/help[/bold] 查看命令 | [bold]Ctrl+C[/bold] 退出 | [bold]Alt+Enter[/bold] 换行[/dim]")
|
|
895
|
+
safe_print(f"[dim]{'─' * 60}[/dim]\n")
|
|
896
|
+
|
|
897
|
+
while True:
|
|
898
|
+
try:
|
|
899
|
+
# 获取用户输入(支持多行,异步版本避免与 asyncio 冲突)
|
|
900
|
+
user_input = (await advanced_input.get_input_async("You")).strip()
|
|
901
|
+
|
|
902
|
+
# 检查退出命令
|
|
903
|
+
if user_input.lower() in ["quit", "exit", "q", "退出"]:
|
|
904
|
+
safe_print("[dim]👋 再见![/dim]")
|
|
905
|
+
break
|
|
906
|
+
|
|
907
|
+
if not user_input:
|
|
908
|
+
continue
|
|
909
|
+
|
|
910
|
+
# 处理快捷命令
|
|
911
|
+
if ChatCommands.is_command(user_input):
|
|
912
|
+
cmd, args = ChatCommands.parse(user_input)
|
|
913
|
+
|
|
914
|
+
if cmd == '/help':
|
|
915
|
+
ChatCommands.show_help()
|
|
916
|
+
continue
|
|
917
|
+
|
|
918
|
+
elif cmd == '/clear':
|
|
919
|
+
messages = ChatCommands.handle_clear(messages, system_prompt)
|
|
920
|
+
continue
|
|
921
|
+
|
|
922
|
+
elif cmd == '/save':
|
|
923
|
+
ChatCommands.handle_save(messages, args if args else None)
|
|
924
|
+
continue
|
|
925
|
+
|
|
926
|
+
elif cmd == '/model':
|
|
927
|
+
if args:
|
|
928
|
+
current_model = args.strip()
|
|
929
|
+
# 重新创建客户端
|
|
930
|
+
client = MllmClient(
|
|
931
|
+
model=current_model,
|
|
932
|
+
base_url=base_url_val,
|
|
933
|
+
api_key=api_key_val,
|
|
934
|
+
**kwargs,
|
|
935
|
+
)
|
|
936
|
+
safe_print(f"[green]✅ 模型已切换为: {current_model}[/green]\n")
|
|
937
|
+
else:
|
|
938
|
+
safe_print(f"[cyan]当前模型: {current_model}[/cyan]")
|
|
939
|
+
safe_print(f"[dim]用法: /model <模型名>[/dim]\n")
|
|
940
|
+
continue
|
|
941
|
+
|
|
942
|
+
elif cmd == '/retry':
|
|
943
|
+
messages, should_retry = ChatCommands.handle_retry(messages)
|
|
944
|
+
if not should_retry:
|
|
945
|
+
continue
|
|
946
|
+
# 继续执行下面的生成逻辑
|
|
947
|
+
else:
|
|
948
|
+
safe_print(f"[yellow]⚠️ 未知命令: {cmd}[/yellow]")
|
|
949
|
+
safe_print(f"[dim]输入 /help 查看可用命令[/dim]\n")
|
|
950
|
+
continue
|
|
951
|
+
|
|
952
|
+
# 检测是否包含图片路径或URL
|
|
953
|
+
import os
|
|
954
|
+
import re
|
|
955
|
+
|
|
956
|
+
# /retry 时不需要添加新消息,直接重新生成
|
|
957
|
+
is_retry = ChatCommands.is_command(user_input) and ChatCommands.parse(user_input)[0] == '/retry'
|
|
958
|
+
image_path = None
|
|
959
|
+
text_content = user_input
|
|
960
|
+
|
|
961
|
+
if not is_retry:
|
|
962
|
+
# 检查是否是URL
|
|
963
|
+
url_pattern = r'(https?://[^\s]+\.(?:jpg|jpeg|png|gif|bmp|webp)(?:\?[^\s]*)?)'
|
|
964
|
+
url_match = re.search(url_pattern, user_input, re.IGNORECASE)
|
|
965
|
+
|
|
966
|
+
if url_match:
|
|
967
|
+
image_path = url_match.group(1)
|
|
968
|
+
text_content = user_input.replace(image_path, "").strip()
|
|
969
|
+
if not text_content:
|
|
970
|
+
text_content = "请描述这张图片"
|
|
971
|
+
else:
|
|
972
|
+
# 检查是否包含本地文件路径
|
|
973
|
+
# 支持多种格式:绝对路径、相对路径、带引号的路径
|
|
974
|
+
path_patterns = [
|
|
975
|
+
r'"([^"]+\.(?:jpg|jpeg|png|gif|bmp|webp))"', # 双引号路径
|
|
976
|
+
r"'([^']+\.(?:jpg|jpeg|png|gif|bmp|webp))'", # 单引号路径
|
|
977
|
+
r'([^\s]+\.(?:jpg|jpeg|png|gif|bmp|webp))(?:\s|$)', # 无引号路径
|
|
978
|
+
]
|
|
979
|
+
|
|
980
|
+
for pattern in path_patterns:
|
|
981
|
+
match = re.search(pattern, user_input, re.IGNORECASE)
|
|
982
|
+
if match:
|
|
983
|
+
potential_path = match.group(1)
|
|
984
|
+
# 检查文件是否存在
|
|
985
|
+
if os.path.exists(potential_path):
|
|
986
|
+
image_path = os.path.abspath(potential_path)
|
|
987
|
+
text_content = user_input.replace(match.group(0), "").strip()
|
|
988
|
+
if not text_content:
|
|
989
|
+
text_content = "请描述这张图片"
|
|
990
|
+
break
|
|
991
|
+
# 尝试相对路径
|
|
992
|
+
elif os.path.exists(os.path.join(os.getcwd(), potential_path)):
|
|
993
|
+
image_path = os.path.abspath(os.path.join(os.getcwd(), potential_path))
|
|
994
|
+
text_content = user_input.replace(match.group(0), "").strip()
|
|
995
|
+
if not text_content:
|
|
996
|
+
text_content = "请描述这张图片"
|
|
997
|
+
break
|
|
998
|
+
|
|
999
|
+
# 构建消息内容
|
|
1000
|
+
if image_path:
|
|
1001
|
+
# 如果是本地文件,转换为file://格式
|
|
1002
|
+
if not image_path.startswith('http'):
|
|
1003
|
+
image_url = f"file://{image_path.replace(os.sep, '/')}"
|
|
1004
|
+
else:
|
|
1005
|
+
image_url = image_path
|
|
1006
|
+
|
|
1007
|
+
safe_print(f"[dim]📷 发送图片: {image_path}[/dim]")
|
|
1008
|
+
message_content = [
|
|
1009
|
+
{"type": "text", "text": text_content},
|
|
1010
|
+
{"type": "image_url", "image_url": {"url": image_url}}
|
|
1011
|
+
]
|
|
1012
|
+
else:
|
|
1013
|
+
message_content = user_input
|
|
1014
|
+
|
|
1015
|
+
# 添加用户消息到历史
|
|
1016
|
+
messages.append({"role": "user", "content": message_content})
|
|
1017
|
+
|
|
1018
|
+
if stream:
|
|
1019
|
+
# 流式输出 - 使用优雅的Markdown渲染器
|
|
1020
|
+
safe_print(f"[bold blue]Assistant:[/bold blue] ")
|
|
1021
|
+
|
|
1022
|
+
renderer = StreamingMarkdownRenderer()
|
|
1023
|
+
stream_interrupted = False
|
|
1024
|
+
try:
|
|
1025
|
+
async for token in client.call_llm_stream(
|
|
1026
|
+
messages=messages,
|
|
1027
|
+
temperature=temperature,
|
|
1028
|
+
max_tokens=max_tokens,
|
|
1029
|
+
**kwargs,
|
|
1030
|
+
):
|
|
1031
|
+
renderer.add_token(token)
|
|
1032
|
+
except KeyboardInterrupt:
|
|
1033
|
+
stream_interrupted = True
|
|
1034
|
+
safe_print_stream("\n")
|
|
1035
|
+
safe_print("[dim]⏸️ 输出已中断[/dim]")
|
|
1036
|
+
|
|
1037
|
+
# 完成流式输出
|
|
1038
|
+
if not stream_interrupted:
|
|
1039
|
+
renderer.finalize()
|
|
1040
|
+
full_response = renderer.buffer
|
|
1041
|
+
|
|
1042
|
+
# 添加助手响应到历史(即使中断也保存已获取的内容)
|
|
1043
|
+
if full_response:
|
|
1044
|
+
messages.append(
|
|
1045
|
+
{
|
|
1046
|
+
"role": "assistant",
|
|
1047
|
+
"content": full_response,
|
|
1048
|
+
}
|
|
1049
|
+
)
|
|
1050
|
+
else:
|
|
1051
|
+
# 非流式输出,使用Markdown渲染
|
|
1052
|
+
results = await client.call_llm(
|
|
1053
|
+
messages_list=[messages],
|
|
1054
|
+
show_progress=False,
|
|
1055
|
+
temperature=temperature,
|
|
1056
|
+
max_tokens=max_tokens,
|
|
1057
|
+
**kwargs,
|
|
1058
|
+
)
|
|
1059
|
+
response = (
|
|
1060
|
+
results[0]
|
|
1061
|
+
if results and results[0]
|
|
1062
|
+
else "无响应"
|
|
1063
|
+
)
|
|
1064
|
+
safe_print(f"[bold blue]Assistant:[/bold blue]")
|
|
1065
|
+
safe_print_markdown(response)
|
|
1066
|
+
|
|
1067
|
+
# 添加助手响应到历史
|
|
1068
|
+
if response and response != "无响应":
|
|
1069
|
+
messages.append(
|
|
1070
|
+
{"role": "assistant", "content": response}
|
|
1071
|
+
)
|
|
1072
|
+
|
|
1073
|
+
except KeyboardInterrupt:
|
|
1074
|
+
safe_print("\n[dim]👋 再见![/dim]")
|
|
1075
|
+
break
|
|
1076
|
+
except EOFError:
|
|
1077
|
+
safe_print("\n[dim]👋 再见![/dim]")
|
|
1078
|
+
break
|
|
1079
|
+
except Exception as e:
|
|
1080
|
+
safe_print(f"[red]❌ 处理错误: {e}[/red]")
|
|
1081
|
+
safe_print("[yellow]💡 请重试或输入 'quit' 退出[/yellow]")
|
|
1082
|
+
continue
|
|
1083
|
+
|
|
1084
|
+
except Exception as e:
|
|
1085
|
+
safe_print(f"[red]❌ 初始化错误: {e}[/red]")
|
|
1086
|
+
safe_print("[yellow]💡 请检查MLLM客户端配置或服务器连接[/yellow]")
|
|
1087
|
+
return None
|
|
1088
|
+
|
|
1089
|
+
try:
|
|
1090
|
+
return asyncio.run(_interactive_chat())
|
|
1091
|
+
except KeyboardInterrupt:
|
|
1092
|
+
safe_print("\n[dim]👋 再见![/dim]")
|
|
1093
|
+
return None
|
|
1094
|
+
|
|
1095
|
+
# 检查是否在交互环境中
|
|
1096
|
+
import sys
|
|
1097
|
+
|
|
1098
|
+
if not sys.stdin.isatty():
|
|
1099
|
+
safe_print("[red]❌ 错误: 交互模式需要在终端中运行[/red]")
|
|
1100
|
+
safe_print(
|
|
1101
|
+
"[yellow]💡 请在交互式终端中运行此命令,或提供具体的消息参数[/yellow]"
|
|
1102
|
+
)
|
|
1103
|
+
safe_print('[dim]📝 示例: [bold]maque mllm chat "你好"[/bold][/dim]')
|
|
1104
|
+
return
|
|
1105
|
+
|
|
1106
|
+
try:
|
|
1107
|
+
return run_interactive_chat()
|
|
1108
|
+
except KeyboardInterrupt:
|
|
1109
|
+
safe_print("\n[dim]👋 再见![/dim]")
|
|
1110
|
+
return None
|
|
1111
|
+
|
|
1112
|
+
def models(self, base_url: str = None, api_key: str = None):
|
|
1113
|
+
"""列出可用模型"""
|
|
1114
|
+
import requests
|
|
1115
|
+
|
|
1116
|
+
# 从配置获取默认值
|
|
1117
|
+
mllm_config = self.cli.maque_config.get("mllm", {})
|
|
1118
|
+
base_url = base_url or mllm_config.get("base_url", "http://localhost:11434/v1")
|
|
1119
|
+
api_key = api_key or mllm_config.get("api_key", "EMPTY")
|
|
1120
|
+
|
|
1121
|
+
try:
|
|
1122
|
+
headers = {"Authorization": f"Bearer {api_key}"}
|
|
1123
|
+
response = requests.get(
|
|
1124
|
+
f"{base_url.rstrip('/')}/models", headers=headers, timeout=10
|
|
1125
|
+
)
|
|
1126
|
+
|
|
1127
|
+
if response.status_code == 200:
|
|
1128
|
+
models_data = response.json()
|
|
1129
|
+
|
|
1130
|
+
safe_print(f"\n[bold blue]🤖 可用模型列表[/bold blue]")
|
|
1131
|
+
safe_print(f"[dim]📡 服务器: {base_url}[/dim]")
|
|
1132
|
+
safe_print(f"[dim]{'─' * 50}[/dim]")
|
|
1133
|
+
|
|
1134
|
+
if isinstance(models_data, dict) and "data" in models_data:
|
|
1135
|
+
models = models_data["data"]
|
|
1136
|
+
elif isinstance(models_data, list):
|
|
1137
|
+
models = models_data
|
|
1138
|
+
else:
|
|
1139
|
+
models = []
|
|
1140
|
+
|
|
1141
|
+
if models:
|
|
1142
|
+
for i, model in enumerate(models, 1):
|
|
1143
|
+
if isinstance(model, dict):
|
|
1144
|
+
model_id = model.get("id", model.get("name", "unknown"))
|
|
1145
|
+
safe_print(f"[green]{i:2d}. [/green][cyan]{model_id}[/cyan]")
|
|
1146
|
+
else:
|
|
1147
|
+
safe_print(f"[green]{i:2d}. [/green][cyan]{model}[/cyan]")
|
|
1148
|
+
safe_print(f"\n[dim]✅ 共找到 {len(models)} 个可用模型[/dim]")
|
|
1149
|
+
else:
|
|
1150
|
+
safe_print("[yellow]⚠️ 未找到可用模型[/yellow]")
|
|
1151
|
+
safe_print("[dim]💡 请检查服务器配置或网络连接[/dim]")
|
|
1152
|
+
|
|
1153
|
+
else:
|
|
1154
|
+
safe_print(f"[red]❌ 获取模型列表失败: HTTP {response.status_code}[/red]")
|
|
1155
|
+
safe_print(f"[yellow]💡 请检查服务器状态或API权限[/yellow]")
|
|
1156
|
+
|
|
1157
|
+
except requests.exceptions.RequestException as e:
|
|
1158
|
+
safe_print(f"[red]🔌 连接失败: {e}[/red]")
|
|
1159
|
+
safe_print(f"[yellow]💡 请检查服务地址: [bold]{base_url}[/bold][/yellow]")
|
|
1160
|
+
safe_print(f"[dim]提示: 确保服务器正在运行并且地址正确[/dim]")
|
|
1161
|
+
except Exception as e:
|
|
1162
|
+
safe_print(f"[red]❌ 未知错误: {e}[/red]")
|
|
1163
|
+
|
|
1164
|
+
def test(
|
|
1165
|
+
self,
|
|
1166
|
+
model: str = None,
|
|
1167
|
+
base_url: str = None,
|
|
1168
|
+
api_key: str = None,
|
|
1169
|
+
message: str = "Hello, please respond with 'OK' if you can see this message.",
|
|
1170
|
+
timeout: int = 30,
|
|
1171
|
+
):
|
|
1172
|
+
"""测试MLLM服务连接和配置
|
|
1173
|
+
|
|
1174
|
+
Args:
|
|
1175
|
+
model: 模型名称(可选,不指定则只测试连接)
|
|
1176
|
+
base_url: API服务地址
|
|
1177
|
+
api_key: API密钥
|
|
1178
|
+
message: 测试消息
|
|
1179
|
+
timeout: 超时时间(秒)
|
|
1180
|
+
"""
|
|
1181
|
+
import requests
|
|
1182
|
+
import time
|
|
1183
|
+
|
|
1184
|
+
# 从配置获取默认值
|
|
1185
|
+
mllm_config = self.cli.maque_config.get("mllm", {})
|
|
1186
|
+
base_url = base_url or mllm_config.get("base_url", "http://localhost:11434/v1")
|
|
1187
|
+
api_key = api_key or mllm_config.get("api_key", "EMPTY")
|
|
1188
|
+
model = model or mllm_config.get("model")
|
|
1189
|
+
|
|
1190
|
+
safe_print(f"\n[bold blue]🔍 MLLM 服务连接测试[/bold blue]")
|
|
1191
|
+
safe_print(f"[dim]{'─' * 50}[/dim]")
|
|
1192
|
+
|
|
1193
|
+
results = {
|
|
1194
|
+
"connection": False,
|
|
1195
|
+
"models_api": False,
|
|
1196
|
+
"chat_api": False,
|
|
1197
|
+
}
|
|
1198
|
+
|
|
1199
|
+
# 1. 测试基本连接
|
|
1200
|
+
safe_print(f"\n[cyan]1. 测试服务器连接...[/cyan]")
|
|
1201
|
+
safe_print(f" [dim]地址: {base_url}[/dim]")
|
|
1202
|
+
try:
|
|
1203
|
+
start_time = time.time()
|
|
1204
|
+
response = requests.get(
|
|
1205
|
+
f"{base_url.rstrip('/')}/models",
|
|
1206
|
+
headers={"Authorization": f"Bearer {api_key}"},
|
|
1207
|
+
timeout=timeout
|
|
1208
|
+
)
|
|
1209
|
+
elapsed = time.time() - start_time
|
|
1210
|
+
|
|
1211
|
+
if response.status_code == 200:
|
|
1212
|
+
safe_print(f" [green]✅ 连接成功[/green] [dim]({elapsed:.2f}s)[/dim]")
|
|
1213
|
+
results["connection"] = True
|
|
1214
|
+
results["models_api"] = True
|
|
1215
|
+
|
|
1216
|
+
# 解析模型列表
|
|
1217
|
+
models_data = response.json()
|
|
1218
|
+
if isinstance(models_data, dict) and "data" in models_data:
|
|
1219
|
+
models = models_data["data"]
|
|
1220
|
+
elif isinstance(models_data, list):
|
|
1221
|
+
models = models_data
|
|
1222
|
+
else:
|
|
1223
|
+
models = []
|
|
1224
|
+
|
|
1225
|
+
model_count = len(models)
|
|
1226
|
+
safe_print(f" [dim]可用模型数: {model_count}[/dim]")
|
|
1227
|
+
|
|
1228
|
+
elif response.status_code == 401:
|
|
1229
|
+
safe_print(f" [yellow]⚠️ 认证失败 (401)[/yellow]")
|
|
1230
|
+
safe_print(f" [dim]请检查 API Key 是否正确[/dim]")
|
|
1231
|
+
results["connection"] = True
|
|
1232
|
+
elif response.status_code == 404:
|
|
1233
|
+
safe_print(f" [yellow]⚠️ /models 端点不存在 (404)[/yellow]")
|
|
1234
|
+
safe_print(f" [dim]服务器可能不支持 OpenAI 兼容 API[/dim]")
|
|
1235
|
+
results["connection"] = True
|
|
1236
|
+
else:
|
|
1237
|
+
safe_print(f" [yellow]⚠️ HTTP {response.status_code}[/yellow]")
|
|
1238
|
+
results["connection"] = True
|
|
1239
|
+
|
|
1240
|
+
except requests.exceptions.ConnectionError:
|
|
1241
|
+
safe_print(f" [red]❌ 连接失败: 无法连接到服务器[/red]")
|
|
1242
|
+
safe_print(f" [dim]请检查服务器是否运行在 {base_url}[/dim]")
|
|
1243
|
+
except requests.exceptions.Timeout:
|
|
1244
|
+
safe_print(f" [red]❌ 连接超时 ({timeout}s)[/red]")
|
|
1245
|
+
except Exception as e:
|
|
1246
|
+
safe_print(f" [red]❌ 连接错误: {e}[/red]")
|
|
1247
|
+
|
|
1248
|
+
# 2. 测试 Chat API(如果指定了模型)
|
|
1249
|
+
if model and results["connection"]:
|
|
1250
|
+
safe_print(f"\n[cyan]2. 测试 Chat API...[/cyan]")
|
|
1251
|
+
safe_print(f" [dim]模型: {model}[/dim]")
|
|
1252
|
+
|
|
1253
|
+
try:
|
|
1254
|
+
start_time = time.time()
|
|
1255
|
+
response = requests.post(
|
|
1256
|
+
f"{base_url.rstrip('/')}/chat/completions",
|
|
1257
|
+
headers={
|
|
1258
|
+
"Authorization": f"Bearer {api_key}",
|
|
1259
|
+
"Content-Type": "application/json"
|
|
1260
|
+
},
|
|
1261
|
+
json={
|
|
1262
|
+
"model": model,
|
|
1263
|
+
"messages": [{"role": "user", "content": message}],
|
|
1264
|
+
"max_tokens": 50,
|
|
1265
|
+
"temperature": 0.1
|
|
1266
|
+
},
|
|
1267
|
+
timeout=timeout
|
|
1268
|
+
)
|
|
1269
|
+
elapsed = time.time() - start_time
|
|
1270
|
+
|
|
1271
|
+
if response.status_code == 200:
|
|
1272
|
+
data = response.json()
|
|
1273
|
+
content = ""
|
|
1274
|
+
if "choices" in data and data["choices"]:
|
|
1275
|
+
content = data["choices"][0].get("message", {}).get("content", "")
|
|
1276
|
+
|
|
1277
|
+
safe_print(f" [green]✅ Chat API 正常[/green] [dim]({elapsed:.2f}s)[/dim]")
|
|
1278
|
+
if content:
|
|
1279
|
+
# 截断过长的响应
|
|
1280
|
+
display_content = content[:100] + "..." if len(content) > 100 else content
|
|
1281
|
+
safe_print(f" [dim]响应: {display_content}[/dim]")
|
|
1282
|
+
results["chat_api"] = True
|
|
1283
|
+
|
|
1284
|
+
# 显示 token 使用情况
|
|
1285
|
+
usage = data.get("usage", {})
|
|
1286
|
+
if usage:
|
|
1287
|
+
safe_print(f" [dim]Token 使用: prompt={usage.get('prompt_tokens', 'N/A')}, "
|
|
1288
|
+
f"completion={usage.get('completion_tokens', 'N/A')}[/dim]")
|
|
1289
|
+
|
|
1290
|
+
elif response.status_code == 404:
|
|
1291
|
+
safe_print(f" [yellow]⚠️ 模型不存在或 API 端点不可用[/yellow]")
|
|
1292
|
+
safe_print(f" [dim]请检查模型名称: {model}[/dim]")
|
|
1293
|
+
elif response.status_code == 401:
|
|
1294
|
+
safe_print(f" [yellow]⚠️ 认证失败[/yellow]")
|
|
1295
|
+
else:
|
|
1296
|
+
safe_print(f" [yellow]⚠️ HTTP {response.status_code}[/yellow]")
|
|
1297
|
+
try:
|
|
1298
|
+
error_detail = response.json()
|
|
1299
|
+
safe_print(f" [dim]{error_detail}[/dim]")
|
|
1300
|
+
except:
|
|
1301
|
+
pass
|
|
1302
|
+
|
|
1303
|
+
except requests.exceptions.Timeout:
|
|
1304
|
+
safe_print(f" [yellow]⚠️ 请求超时 ({timeout}s)[/yellow]")
|
|
1305
|
+
safe_print(f" [dim]模型可能正在加载或服务器繁忙[/dim]")
|
|
1306
|
+
except Exception as e:
|
|
1307
|
+
safe_print(f" [red]❌ 请求失败: {e}[/red]")
|
|
1308
|
+
|
|
1309
|
+
# 3. 总结
|
|
1310
|
+
safe_print(f"\n[dim]{'─' * 50}[/dim]")
|
|
1311
|
+
safe_print(f"[bold]测试结果汇总:[/bold]")
|
|
1312
|
+
|
|
1313
|
+
status_icons = {True: "[green]✅[/green]", False: "[red]❌[/red]"}
|
|
1314
|
+
|
|
1315
|
+
safe_print(f" {status_icons[results['connection']]} 服务器连接")
|
|
1316
|
+
safe_print(f" {status_icons[results['models_api']]} Models API")
|
|
1317
|
+
if model:
|
|
1318
|
+
safe_print(f" {status_icons[results['chat_api']]} Chat API ({model})")
|
|
1319
|
+
|
|
1320
|
+
# 给出建议
|
|
1321
|
+
if all(results.values()) or (results["connection"] and results["models_api"] and not model):
|
|
1322
|
+
safe_print(f"\n[green]🎉 所有测试通过!MLLM 服务配置正确。[/green]")
|
|
1323
|
+
else:
|
|
1324
|
+
safe_print(f"\n[yellow]💡 建议:[/yellow]")
|
|
1325
|
+
if not results["connection"]:
|
|
1326
|
+
safe_print(f" - 检查服务器是否启动")
|
|
1327
|
+
safe_print(f" - 检查 base_url 配置是否正确")
|
|
1328
|
+
if results["connection"] and not results["models_api"]:
|
|
1329
|
+
safe_print(f" - 检查 API Key 是否正确")
|
|
1330
|
+
safe_print(f" - 确认服务器支持 OpenAI 兼容 API")
|
|
1331
|
+
if model and not results["chat_api"]:
|
|
1332
|
+
safe_print(f" - 检查模型名称是否正确")
|
|
1333
|
+
safe_print(f" - 使用 'mq mllm models' 查看可用模型")
|
|
1334
|
+
|
|
1335
|
+
return results
|
|
1336
|
+
|
|
1337
|
+
def chain_analysis(
|
|
1338
|
+
self,
|
|
1339
|
+
query: str,
|
|
1340
|
+
steps: int = 3,
|
|
1341
|
+
model: str = None,
|
|
1342
|
+
base_url: str = None,
|
|
1343
|
+
api_key: str = None,
|
|
1344
|
+
temperature: float = 0.1,
|
|
1345
|
+
max_tokens: int = 2000,
|
|
1346
|
+
show_details: bool = False,
|
|
1347
|
+
**kwargs,
|
|
1348
|
+
):
|
|
1349
|
+
"""使用Chain of Thought进行分析推理
|
|
1350
|
+
|
|
1351
|
+
Args:
|
|
1352
|
+
query: 要分析的问题或内容
|
|
1353
|
+
steps: 分析步骤数,默认3步
|
|
1354
|
+
model: 使用的模型
|
|
1355
|
+
base_url: API服务地址
|
|
1356
|
+
api_key: API密钥
|
|
1357
|
+
temperature: 温度参数
|
|
1358
|
+
max_tokens: 最大token数
|
|
1359
|
+
show_details: 是否显示每个步骤的详细信息
|
|
1360
|
+
"""
|
|
1361
|
+
import asyncio
|
|
1362
|
+
from flexllm.chain_of_thought_client import ChainOfThoughtClient, LinearStep, ExecutionConfig
|
|
1363
|
+
from flexllm.openaiclient import OpenAIClient
|
|
1364
|
+
|
|
1365
|
+
# 从配置获取默认值
|
|
1366
|
+
mllm_config = self.cli.maque_config.get("mllm", {})
|
|
1367
|
+
model = model or mllm_config.get("model", "gemma3:latest")
|
|
1368
|
+
base_url = base_url or mllm_config.get("base_url", "http://localhost:11434/v1")
|
|
1369
|
+
api_key = api_key or mllm_config.get("api_key", "EMPTY")
|
|
1370
|
+
|
|
1371
|
+
async def run_chain_analysis():
|
|
1372
|
+
try:
|
|
1373
|
+
safe_print(f"[bold green]🔍 开始Chain of Thought分析推理[/bold green]")
|
|
1374
|
+
safe_print(f"[cyan]📝 问题: {query}[/cyan]")
|
|
1375
|
+
safe_print(f"[dim]🔧 模型: {model}, 步骤数: {steps}[/dim]\n")
|
|
1376
|
+
|
|
1377
|
+
# 初始化客户端
|
|
1378
|
+
openai_client = OpenAIClient(model=model, base_url=base_url, api_key=api_key)
|
|
1379
|
+
|
|
1380
|
+
# 配置执行参数
|
|
1381
|
+
config = ExecutionConfig(
|
|
1382
|
+
enable_monitoring=True,
|
|
1383
|
+
enable_progress=show_details,
|
|
1384
|
+
log_level="INFO" if show_details else "WARNING"
|
|
1385
|
+
)
|
|
1386
|
+
|
|
1387
|
+
chain_client = ChainOfThoughtClient(openai_client, config)
|
|
1388
|
+
|
|
1389
|
+
# 定义分析步骤
|
|
1390
|
+
def create_analysis_step(step_num: int, step_name: str, prompt_template: str):
|
|
1391
|
+
def prepare_messages(context):
|
|
1392
|
+
previous_analysis = ""
|
|
1393
|
+
if context.history:
|
|
1394
|
+
previous_analysis = "\n\n".join([
|
|
1395
|
+
f"步骤{i+1}: {step.response}"
|
|
1396
|
+
for i, step in enumerate(context.history)
|
|
1397
|
+
])
|
|
1398
|
+
|
|
1399
|
+
system_prompt = f"""你是一个专业的分析师,正在进行第{step_num}步分析。
|
|
1400
|
+
请根据问题和之前的分析结果,{step_name}。
|
|
1401
|
+
保持逻辑清晰,分析深入。"""
|
|
1402
|
+
|
|
1403
|
+
user_prompt = prompt_template.format(
|
|
1404
|
+
query=context.query,
|
|
1405
|
+
previous_analysis=previous_analysis
|
|
1406
|
+
)
|
|
1407
|
+
|
|
1408
|
+
return [
|
|
1409
|
+
{"role": "system", "content": system_prompt},
|
|
1410
|
+
{"role": "user", "content": user_prompt}
|
|
1411
|
+
]
|
|
1412
|
+
|
|
1413
|
+
return LinearStep(
|
|
1414
|
+
name=f"analysis_step_{step_num}",
|
|
1415
|
+
prepare_messages_fn=prepare_messages,
|
|
1416
|
+
model_params={
|
|
1417
|
+
"temperature": temperature,
|
|
1418
|
+
"max_tokens": max_tokens,
|
|
1419
|
+
**kwargs
|
|
1420
|
+
}
|
|
1421
|
+
)
|
|
1422
|
+
|
|
1423
|
+
# 创建分析链条
|
|
1424
|
+
analysis_steps = []
|
|
1425
|
+
|
|
1426
|
+
if steps >= 1:
|
|
1427
|
+
analysis_steps.append(create_analysis_step(
|
|
1428
|
+
1, "理解和分解问题",
|
|
1429
|
+
"请仔细分析这个问题:\n{query}\n\n请分解这个问题的关键要素,明确分析的方向和重点。"
|
|
1430
|
+
))
|
|
1431
|
+
|
|
1432
|
+
if steps >= 2:
|
|
1433
|
+
analysis_steps.append(create_analysis_step(
|
|
1434
|
+
2, "深入分析各个方面",
|
|
1435
|
+
"基于第一步的分析:\n{previous_analysis}\n\n请从多个角度深入分析问题,探讨可能的解决方案或答案。"
|
|
1436
|
+
))
|
|
1437
|
+
|
|
1438
|
+
if steps >= 3:
|
|
1439
|
+
analysis_steps.append(create_analysis_step(
|
|
1440
|
+
3, "综合结论和建议",
|
|
1441
|
+
"基于前面的分析:\n{previous_analysis}\n\n请总结分析结果,给出明确的结论和实用的建议。"
|
|
1442
|
+
))
|
|
1443
|
+
|
|
1444
|
+
# 如果步骤超过3步,添加更多细化分析
|
|
1445
|
+
for i in range(4, steps + 1):
|
|
1446
|
+
analysis_steps.append(create_analysis_step(
|
|
1447
|
+
i, f"进一步细化分析第{i-3}个方面",
|
|
1448
|
+
"继续深化分析:\n{previous_analysis}\n\n请进一步细化和补充分析,提供更详细的见解。"
|
|
1449
|
+
))
|
|
1450
|
+
|
|
1451
|
+
# 创建线性链条
|
|
1452
|
+
first_step = chain_client.create_linear_chain(analysis_steps, "analysis_chain")
|
|
1453
|
+
|
|
1454
|
+
# 执行链条
|
|
1455
|
+
context = chain_client.create_context({"query": query})
|
|
1456
|
+
result_context = await chain_client.execute_chain(
|
|
1457
|
+
first_step, context, show_step_details=show_details
|
|
1458
|
+
)
|
|
1459
|
+
|
|
1460
|
+
# 显示结果
|
|
1461
|
+
if result_context.history:
|
|
1462
|
+
safe_print(f"\n[bold blue]🎯 Chain of Thought 分析结果[/bold blue]")
|
|
1463
|
+
safe_print(f"[dim]{'=' * 60}[/dim]")
|
|
1464
|
+
|
|
1465
|
+
for i, step_result in enumerate(result_context.history):
|
|
1466
|
+
step_title = f"步骤 {i+1}"
|
|
1467
|
+
if i == 0:
|
|
1468
|
+
step_title += " - 问题理解"
|
|
1469
|
+
elif i == 1:
|
|
1470
|
+
step_title += " - 深入分析"
|
|
1471
|
+
elif i == 2:
|
|
1472
|
+
step_title += " - 综合结论"
|
|
1473
|
+
else:
|
|
1474
|
+
step_title += f" - 细化分析 {i-2}"
|
|
1475
|
+
|
|
1476
|
+
safe_print(f"\n[bold cyan]{step_title}[/bold cyan]")
|
|
1477
|
+
safe_print(f"[green]{step_result.response}[/green]")
|
|
1478
|
+
|
|
1479
|
+
# 执行摘要
|
|
1480
|
+
summary = result_context.get_execution_summary()
|
|
1481
|
+
safe_print(f"\n[dim]📊 执行统计: {summary['total_steps']} 个步骤, "
|
|
1482
|
+
f"耗时 {summary['total_execution_time']:.2f}秒, "
|
|
1483
|
+
f"成功率 {summary['success_rate']*100:.1f}%[/dim]")
|
|
1484
|
+
else:
|
|
1485
|
+
safe_print("[red]❌ 分析执行失败,没有生成结果[/red]")
|
|
1486
|
+
|
|
1487
|
+
except Exception as e:
|
|
1488
|
+
safe_print(f"[red]❌ Chain of Thought分析执行失败: {e}[/red]")
|
|
1489
|
+
safe_print("[yellow]💡 请检查模型配置和网络连接[/yellow]")
|
|
1490
|
+
|
|
1491
|
+
return asyncio.run(run_chain_analysis())
|
|
1492
|
+
|
|
1493
|
+
def chain_reasoning(
|
|
1494
|
+
self,
|
|
1495
|
+
query: str,
|
|
1496
|
+
model: str = None,
|
|
1497
|
+
base_url: str = None,
|
|
1498
|
+
api_key: str = None,
|
|
1499
|
+
temperature: float = 0.1,
|
|
1500
|
+
max_tokens: int = 2000,
|
|
1501
|
+
show_details: bool = False,
|
|
1502
|
+
**kwargs,
|
|
1503
|
+
):
|
|
1504
|
+
"""使用Chain of Thought进行逻辑推理
|
|
1505
|
+
|
|
1506
|
+
Args:
|
|
1507
|
+
query: 需要推理的问题或情境
|
|
1508
|
+
model: 使用的模型
|
|
1509
|
+
base_url: API服务地址
|
|
1510
|
+
api_key: API密钥
|
|
1511
|
+
temperature: 温度参数
|
|
1512
|
+
max_tokens: 最大token数
|
|
1513
|
+
show_details: 是否显示每个步骤的详细信息
|
|
1514
|
+
"""
|
|
1515
|
+
import asyncio
|
|
1516
|
+
from flexllm.chain_of_thought_client import ChainOfThoughtClient, LinearStep, ExecutionConfig
|
|
1517
|
+
from flexllm.openaiclient import OpenAIClient
|
|
1518
|
+
|
|
1519
|
+
# 从配置获取默认值
|
|
1520
|
+
mllm_config = self.cli.maque_config.get("mllm", {})
|
|
1521
|
+
model = model or mllm_config.get("model", "gemma3:latest")
|
|
1522
|
+
base_url = base_url or mllm_config.get("base_url", "http://localhost:11434/v1")
|
|
1523
|
+
api_key = api_key or mllm_config.get("api_key", "EMPTY")
|
|
1524
|
+
|
|
1525
|
+
async def run_chain_reasoning():
|
|
1526
|
+
try:
|
|
1527
|
+
safe_print(f"[bold green]🧠 开始Chain of Thought逻辑推理[/bold green]")
|
|
1528
|
+
safe_print(f"[cyan]💭 推理问题: {query}[/cyan]")
|
|
1529
|
+
safe_print(f"[dim]🔧 模型: {model}[/dim]\n")
|
|
1530
|
+
|
|
1531
|
+
# 初始化客户端
|
|
1532
|
+
openai_client = OpenAIClient(model=model, base_url=base_url, api_key=api_key)
|
|
1533
|
+
|
|
1534
|
+
config = ExecutionConfig(
|
|
1535
|
+
enable_monitoring=True,
|
|
1536
|
+
enable_progress=show_details,
|
|
1537
|
+
log_level="INFO" if show_details else "WARNING"
|
|
1538
|
+
)
|
|
1539
|
+
|
|
1540
|
+
chain_client = ChainOfThoughtClient(openai_client, config)
|
|
1541
|
+
|
|
1542
|
+
# 定义推理步骤
|
|
1543
|
+
def create_reasoning_step(step_name: str, prompt_template: str):
|
|
1544
|
+
def prepare_messages(context):
|
|
1545
|
+
previous_reasoning = ""
|
|
1546
|
+
if context.history:
|
|
1547
|
+
previous_reasoning = "\n\n".join([
|
|
1548
|
+
f"[{step.step_name}]: {step.response}"
|
|
1549
|
+
for step in context.history
|
|
1550
|
+
])
|
|
1551
|
+
|
|
1552
|
+
return [
|
|
1553
|
+
{"role": "system", "content": "你是一个逻辑推理专家。请使用严谨的逻辑思维,一步一步地分析和推理。每一步都要有明确的逻辑依据。"},
|
|
1554
|
+
{"role": "user", "content": prompt_template.format(
|
|
1555
|
+
query=context.query,
|
|
1556
|
+
previous_reasoning=previous_reasoning
|
|
1557
|
+
)}
|
|
1558
|
+
]
|
|
1559
|
+
|
|
1560
|
+
return LinearStep(
|
|
1561
|
+
name=step_name,
|
|
1562
|
+
prepare_messages_fn=prepare_messages,
|
|
1563
|
+
model_params={
|
|
1564
|
+
"temperature": temperature,
|
|
1565
|
+
"max_tokens": max_tokens,
|
|
1566
|
+
**kwargs
|
|
1567
|
+
}
|
|
1568
|
+
)
|
|
1569
|
+
|
|
1570
|
+
# 创建推理链条
|
|
1571
|
+
reasoning_steps = [
|
|
1572
|
+
create_reasoning_step(
|
|
1573
|
+
"observation",
|
|
1574
|
+
"首先,让我观察和理解这个问题:\n{query}\n\n请仔细观察问题中的关键信息、已知条件和要求解答的内容。列出所有重要的事实和假设。"
|
|
1575
|
+
),
|
|
1576
|
+
create_reasoning_step(
|
|
1577
|
+
"hypothesis",
|
|
1578
|
+
"基于观察到的信息:\n{previous_reasoning}\n\n现在请提出可能的假设或解决方案。考虑多种可能性,并说明每种假设的依据。"
|
|
1579
|
+
),
|
|
1580
|
+
create_reasoning_step(
|
|
1581
|
+
"deduction",
|
|
1582
|
+
"基于前面的观察和假设:\n{previous_reasoning}\n\n现在进行逻辑推导。使用演绎推理,从已知条件推导出结论。确保每一步推理都有明确的逻辑关系。"
|
|
1583
|
+
),
|
|
1584
|
+
create_reasoning_step(
|
|
1585
|
+
"verification",
|
|
1586
|
+
"基于推理过程:\n{previous_reasoning}\n\n现在验证推理结果。检查逻辑是否一致,结论是否合理,是否遗漏了重要因素。如果发现问题,请指出并修正。"
|
|
1587
|
+
),
|
|
1588
|
+
create_reasoning_step(
|
|
1589
|
+
"conclusion",
|
|
1590
|
+
"综合整个推理过程:\n{previous_reasoning}\n\n请给出最终结论。总结推理的关键步骤,明确回答原始问题,并说明结论的可信度。"
|
|
1591
|
+
)
|
|
1592
|
+
]
|
|
1593
|
+
|
|
1594
|
+
# 创建和执行链条
|
|
1595
|
+
first_step = chain_client.create_linear_chain(reasoning_steps, "reasoning_chain")
|
|
1596
|
+
context = chain_client.create_context({"query": query})
|
|
1597
|
+
result_context = await chain_client.execute_chain(
|
|
1598
|
+
first_step, context, show_step_details=show_details
|
|
1599
|
+
)
|
|
1600
|
+
|
|
1601
|
+
# 显示推理结果
|
|
1602
|
+
if result_context.history:
|
|
1603
|
+
safe_print(f"\n[bold blue]🎯 Chain of Thought 推理结果[/bold blue]")
|
|
1604
|
+
safe_print(f"[dim]{'=' * 60}[/dim]")
|
|
1605
|
+
|
|
1606
|
+
step_names = {
|
|
1607
|
+
"observation": "🔍 观察分析",
|
|
1608
|
+
"hypothesis": "💡 假设提出",
|
|
1609
|
+
"deduction": "🔗 逻辑推导",
|
|
1610
|
+
"verification": "✅ 验证检查",
|
|
1611
|
+
"conclusion": "🎯 最终结论"
|
|
1612
|
+
}
|
|
1613
|
+
|
|
1614
|
+
for step_result in result_context.history:
|
|
1615
|
+
step_display = step_names.get(step_result.step_name, step_result.step_name)
|
|
1616
|
+
safe_print(f"\n[bold cyan]{step_display}[/bold cyan]")
|
|
1617
|
+
safe_print(f"[green]{step_result.response}[/green]")
|
|
1618
|
+
|
|
1619
|
+
# 执行摘要
|
|
1620
|
+
summary = result_context.get_execution_summary()
|
|
1621
|
+
safe_print(f"\n[dim]📊 推理统计: {summary['total_steps']} 个步骤, "
|
|
1622
|
+
f"耗时 {summary['total_execution_time']:.2f}秒, "
|
|
1623
|
+
f"成功率 {summary['success_rate']*100:.1f}%[/dim]")
|
|
1624
|
+
else:
|
|
1625
|
+
safe_print("[red]❌ 推理执行失败,没有生成结果[/red]")
|
|
1626
|
+
|
|
1627
|
+
except Exception as e:
|
|
1628
|
+
safe_print(f"[red]❌ Chain of Thought推理执行失败: {e}[/red]")
|
|
1629
|
+
safe_print("[yellow]💡 请检查模型配置和网络连接[/yellow]")
|
|
1630
|
+
|
|
1631
|
+
return asyncio.run(run_chain_reasoning())
|
|
1632
|
+
|
|
1633
|
+
def chain_run(
|
|
1634
|
+
self,
|
|
1635
|
+
config_file: str,
|
|
1636
|
+
input_data: str = None,
|
|
1637
|
+
model: str = None,
|
|
1638
|
+
base_url: str = None,
|
|
1639
|
+
api_key: str = None,
|
|
1640
|
+
show_details: bool = False,
|
|
1641
|
+
**kwargs,
|
|
1642
|
+
):
|
|
1643
|
+
"""运行自定义的Chain of Thought配置文件
|
|
1644
|
+
|
|
1645
|
+
Args:
|
|
1646
|
+
config_file: YAML格式的链条配置文件路径
|
|
1647
|
+
input_data: 输入数据,会作为query传入
|
|
1648
|
+
model: 使用的模型(覆盖配置文件中的设置)
|
|
1649
|
+
base_url: API服务地址
|
|
1650
|
+
api_key: API密钥
|
|
1651
|
+
show_details: 是否显示详细执行信息
|
|
1652
|
+
"""
|
|
1653
|
+
import asyncio
|
|
1654
|
+
import yaml
|
|
1655
|
+
import os
|
|
1656
|
+
from pathlib import Path
|
|
1657
|
+
from flexllm.chain_of_thought_client import ChainOfThoughtClient, LinearStep, ExecutionConfig
|
|
1658
|
+
from flexllm.openaiclient import OpenAIClient
|
|
1659
|
+
|
|
1660
|
+
async def run_chain_config():
|
|
1661
|
+
try:
|
|
1662
|
+
# 读取配置文件
|
|
1663
|
+
config_path = Path(config_file)
|
|
1664
|
+
if not config_path.exists():
|
|
1665
|
+
safe_print(f"[red]❌ 配置文件不存在: {config_file}[/red]")
|
|
1666
|
+
return
|
|
1667
|
+
|
|
1668
|
+
safe_print(f"[bold green]📋 运行Chain of Thought配置[/bold green]")
|
|
1669
|
+
safe_print(f"[cyan]📁 配置文件: {config_file}[/cyan]")
|
|
1670
|
+
|
|
1671
|
+
with open(config_path, 'r', encoding='utf-8') as f:
|
|
1672
|
+
config = yaml.safe_load(f)
|
|
1673
|
+
|
|
1674
|
+
# 从配置文件和命令行参数合并设置
|
|
1675
|
+
mllm_config = self.cli.maque_config.get("mllm", {})
|
|
1676
|
+
|
|
1677
|
+
# 模型配置优先级: 命令行 > 配置文件 > 全局配置
|
|
1678
|
+
final_model = model or config.get('model') or mllm_config.get("model", "gemma3:latest")
|
|
1679
|
+
final_base_url = base_url or config.get('base_url') or mllm_config.get("base_url", "http://localhost:11434/v1")
|
|
1680
|
+
final_api_key = api_key or config.get('api_key') or mllm_config.get("api_key", "EMPTY")
|
|
1681
|
+
|
|
1682
|
+
# 获取输入数据
|
|
1683
|
+
query = input_data or config.get('query', '')
|
|
1684
|
+
if not query:
|
|
1685
|
+
safe_print("[red]❌ 缺少输入数据,请通过 --input-data 参数或在配置文件中的 'query' 字段指定[/red]")
|
|
1686
|
+
return
|
|
1687
|
+
|
|
1688
|
+
safe_print(f"[cyan]📝 输入: {query}[/cyan]")
|
|
1689
|
+
safe_print(f"[dim]🔧 模型: {final_model}[/dim]\n")
|
|
1690
|
+
|
|
1691
|
+
# 初始化客户端
|
|
1692
|
+
openai_client = OpenAIClient(model=final_model, base_url=final_base_url, api_key=final_api_key)
|
|
1693
|
+
|
|
1694
|
+
# 执行配置
|
|
1695
|
+
exec_config = ExecutionConfig(
|
|
1696
|
+
enable_monitoring=config.get('enable_monitoring', True),
|
|
1697
|
+
enable_progress=show_details,
|
|
1698
|
+
log_level="INFO" if show_details else "WARNING",
|
|
1699
|
+
step_timeout=config.get('step_timeout'),
|
|
1700
|
+
chain_timeout=config.get('chain_timeout'),
|
|
1701
|
+
max_retries=config.get('max_retries', 0),
|
|
1702
|
+
retry_delay=config.get('retry_delay', 1.0)
|
|
1703
|
+
)
|
|
1704
|
+
|
|
1705
|
+
chain_client = ChainOfThoughtClient(openai_client, exec_config)
|
|
1706
|
+
|
|
1707
|
+
# 构建步骤
|
|
1708
|
+
steps = config.get('steps', [])
|
|
1709
|
+
if not steps:
|
|
1710
|
+
safe_print("[red]❌ 配置文件中没有定义步骤[/red]")
|
|
1711
|
+
return
|
|
1712
|
+
|
|
1713
|
+
def create_config_step(step_config):
|
|
1714
|
+
step_name = step_config['name']
|
|
1715
|
+
system_prompt = step_config.get('system_prompt', '')
|
|
1716
|
+
user_prompt = step_config.get('user_prompt', '')
|
|
1717
|
+
|
|
1718
|
+
def prepare_messages(context):
|
|
1719
|
+
# 处理模板变量
|
|
1720
|
+
template_vars = {
|
|
1721
|
+
'query': context.query,
|
|
1722
|
+
'previous_responses': '\n\n'.join([f"[{s.step_name}]: {s.response}" for s in context.history])
|
|
1723
|
+
}
|
|
1724
|
+
|
|
1725
|
+
# 添加自定义变量
|
|
1726
|
+
custom_vars = context.get_custom_data('template_vars', {})
|
|
1727
|
+
template_vars.update(custom_vars)
|
|
1728
|
+
|
|
1729
|
+
messages = []
|
|
1730
|
+
if system_prompt:
|
|
1731
|
+
messages.append({
|
|
1732
|
+
"role": "system",
|
|
1733
|
+
"content": system_prompt.format(**template_vars)
|
|
1734
|
+
})
|
|
1735
|
+
|
|
1736
|
+
messages.append({
|
|
1737
|
+
"role": "user",
|
|
1738
|
+
"content": user_prompt.format(**template_vars)
|
|
1739
|
+
})
|
|
1740
|
+
|
|
1741
|
+
return messages
|
|
1742
|
+
|
|
1743
|
+
# 获取模型参数
|
|
1744
|
+
model_params = step_config.get('model_params', {})
|
|
1745
|
+
model_params.update(kwargs) # 命令行参数覆盖
|
|
1746
|
+
|
|
1747
|
+
return LinearStep(
|
|
1748
|
+
name=step_name,
|
|
1749
|
+
prepare_messages_fn=prepare_messages,
|
|
1750
|
+
model_params=model_params
|
|
1751
|
+
)
|
|
1752
|
+
|
|
1753
|
+
# 创建所有步骤
|
|
1754
|
+
chain_steps = [create_config_step(step_config) for step_config in steps]
|
|
1755
|
+
|
|
1756
|
+
# 创建和执行链条
|
|
1757
|
+
chain_name = config.get('name', 'custom_chain')
|
|
1758
|
+
first_step = chain_client.create_linear_chain(chain_steps, chain_name)
|
|
1759
|
+
|
|
1760
|
+
# 添加自定义模板变量到上下文
|
|
1761
|
+
context = chain_client.create_context({"query": query})
|
|
1762
|
+
if config.get('template_vars'):
|
|
1763
|
+
context.add_custom_data('template_vars', config['template_vars'])
|
|
1764
|
+
|
|
1765
|
+
result_context = await chain_client.execute_chain(
|
|
1766
|
+
first_step, context, show_step_details=show_details
|
|
1767
|
+
)
|
|
1768
|
+
|
|
1769
|
+
# 显示结果
|
|
1770
|
+
if result_context.history:
|
|
1771
|
+
safe_print(f"\n[bold blue]🎯 {config.get('name', 'Chain')} 执行结果[/bold blue]")
|
|
1772
|
+
safe_print(f"[dim]{'=' * 60}[/dim]")
|
|
1773
|
+
|
|
1774
|
+
for step_result in result_context.history:
|
|
1775
|
+
step_display = step_result.step_name.replace('_', ' ').title()
|
|
1776
|
+
safe_print(f"\n[bold cyan]📝 {step_display}[/bold cyan]")
|
|
1777
|
+
safe_print(f"[green]{step_result.response}[/green]")
|
|
1778
|
+
|
|
1779
|
+
# 执行摘要
|
|
1780
|
+
summary = result_context.get_execution_summary()
|
|
1781
|
+
safe_print(f"\n[dim]📊 执行统计: {summary['total_steps']} 个步骤, "
|
|
1782
|
+
f"耗时 {summary['total_execution_time']:.2f}秒, "
|
|
1783
|
+
f"成功率 {summary['success_rate']*100:.1f}%[/dim]")
|
|
1784
|
+
else:
|
|
1785
|
+
safe_print("[red]❌ 链条执行失败,没有生成结果[/red]")
|
|
1786
|
+
|
|
1787
|
+
except yaml.YAMLError as e:
|
|
1788
|
+
safe_print(f"[red]❌ YAML配置文件解析错误: {e}[/red]")
|
|
1789
|
+
except FileNotFoundError as e:
|
|
1790
|
+
safe_print(f"[red]❌ 配置文件未找到: {e}[/red]")
|
|
1791
|
+
except Exception as e:
|
|
1792
|
+
safe_print(f"[red]❌ Chain执行失败: {e}[/red]")
|
|
1793
|
+
safe_print("[yellow]💡 请检查配置文件格式和模型连接[/yellow]")
|
|
1794
|
+
|
|
1795
|
+
return asyncio.run(run_chain_config())
|