maque 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- maque/__init__.py +30 -0
- maque/__main__.py +926 -0
- maque/ai_platform/__init__.py +0 -0
- maque/ai_platform/crawl.py +45 -0
- maque/ai_platform/metrics.py +258 -0
- maque/ai_platform/nlp_preprocess.py +67 -0
- maque/ai_platform/webpage_screen_shot.py +195 -0
- maque/algorithms/__init__.py +78 -0
- maque/algorithms/bezier.py +15 -0
- maque/algorithms/bktree.py +117 -0
- maque/algorithms/core.py +104 -0
- maque/algorithms/hilbert.py +16 -0
- maque/algorithms/rate_function.py +92 -0
- maque/algorithms/transform.py +27 -0
- maque/algorithms/trie.py +272 -0
- maque/algorithms/utils.py +63 -0
- maque/algorithms/video.py +587 -0
- maque/api/__init__.py +1 -0
- maque/api/common.py +110 -0
- maque/api/fetch.py +26 -0
- maque/api/static/icon.png +0 -0
- maque/api/static/redoc.standalone.js +1782 -0
- maque/api/static/swagger-ui-bundle.js +3 -0
- maque/api/static/swagger-ui.css +3 -0
- maque/cli/__init__.py +1 -0
- maque/cli/clean_invisible_chars.py +324 -0
- maque/cli/core.py +34 -0
- maque/cli/groups/__init__.py +26 -0
- maque/cli/groups/config.py +205 -0
- maque/cli/groups/data.py +615 -0
- maque/cli/groups/doctor.py +259 -0
- maque/cli/groups/embedding.py +222 -0
- maque/cli/groups/git.py +29 -0
- maque/cli/groups/help.py +410 -0
- maque/cli/groups/llm.py +223 -0
- maque/cli/groups/mcp.py +241 -0
- maque/cli/groups/mllm.py +1795 -0
- maque/cli/groups/mllm_simple.py +60 -0
- maque/cli/groups/quant.py +210 -0
- maque/cli/groups/service.py +490 -0
- maque/cli/groups/system.py +570 -0
- maque/cli/mllm_run.py +1451 -0
- maque/cli/script.py +52 -0
- maque/cli/tree.py +49 -0
- maque/clustering/__init__.py +52 -0
- maque/clustering/analyzer.py +347 -0
- maque/clustering/clusterers.py +464 -0
- maque/clustering/sampler.py +134 -0
- maque/clustering/visualizer.py +205 -0
- maque/constant.py +13 -0
- maque/core.py +133 -0
- maque/cv/__init__.py +1 -0
- maque/cv/image.py +219 -0
- maque/cv/utils.py +68 -0
- maque/cv/video/__init__.py +3 -0
- maque/cv/video/keyframe_extractor.py +368 -0
- maque/embedding/__init__.py +43 -0
- maque/embedding/base.py +56 -0
- maque/embedding/multimodal.py +308 -0
- maque/embedding/server.py +523 -0
- maque/embedding/text.py +311 -0
- maque/git/__init__.py +24 -0
- maque/git/pure_git.py +912 -0
- maque/io/__init__.py +29 -0
- maque/io/core.py +38 -0
- maque/io/ops.py +194 -0
- maque/llm/__init__.py +111 -0
- maque/llm/backend.py +416 -0
- maque/llm/base.py +411 -0
- maque/llm/server.py +366 -0
- maque/mcp_server.py +1096 -0
- maque/mllm_data_processor_pipeline/__init__.py +17 -0
- maque/mllm_data_processor_pipeline/core.py +341 -0
- maque/mllm_data_processor_pipeline/example.py +291 -0
- maque/mllm_data_processor_pipeline/steps/__init__.py +56 -0
- maque/mllm_data_processor_pipeline/steps/data_alignment.py +267 -0
- maque/mllm_data_processor_pipeline/steps/data_loader.py +172 -0
- maque/mllm_data_processor_pipeline/steps/data_validation.py +304 -0
- maque/mllm_data_processor_pipeline/steps/format_conversion.py +411 -0
- maque/mllm_data_processor_pipeline/steps/mllm_annotation.py +331 -0
- maque/mllm_data_processor_pipeline/steps/mllm_refinement.py +446 -0
- maque/mllm_data_processor_pipeline/steps/result_validation.py +501 -0
- maque/mllm_data_processor_pipeline/web_app.py +317 -0
- maque/nlp/__init__.py +14 -0
- maque/nlp/ngram.py +9 -0
- maque/nlp/parser.py +63 -0
- maque/nlp/risk_matcher.py +543 -0
- maque/nlp/sentence_splitter.py +202 -0
- maque/nlp/simple_tradition_cvt.py +31 -0
- maque/performance/__init__.py +21 -0
- maque/performance/_measure_time.py +70 -0
- maque/performance/_profiler.py +367 -0
- maque/performance/_stat_memory.py +51 -0
- maque/pipelines/__init__.py +15 -0
- maque/pipelines/clustering.py +252 -0
- maque/quantization/__init__.py +42 -0
- maque/quantization/auto_round.py +120 -0
- maque/quantization/base.py +145 -0
- maque/quantization/bitsandbytes.py +127 -0
- maque/quantization/llm_compressor.py +102 -0
- maque/retriever/__init__.py +35 -0
- maque/retriever/chroma.py +654 -0
- maque/retriever/document.py +140 -0
- maque/retriever/milvus.py +1140 -0
- maque/table_ops/__init__.py +1 -0
- maque/table_ops/core.py +133 -0
- maque/table_viewer/__init__.py +4 -0
- maque/table_viewer/download_assets.py +57 -0
- maque/table_viewer/server.py +698 -0
- maque/table_viewer/static/element-plus-icons.js +5791 -0
- maque/table_viewer/static/element-plus.css +1 -0
- maque/table_viewer/static/element-plus.js +65236 -0
- maque/table_viewer/static/main.css +268 -0
- maque/table_viewer/static/main.js +669 -0
- maque/table_viewer/static/vue.global.js +18227 -0
- maque/table_viewer/templates/index.html +401 -0
- maque/utils/__init__.py +56 -0
- maque/utils/color.py +68 -0
- maque/utils/color_string.py +45 -0
- maque/utils/compress.py +66 -0
- maque/utils/constant.py +183 -0
- maque/utils/core.py +261 -0
- maque/utils/cursor.py +143 -0
- maque/utils/distance.py +58 -0
- maque/utils/docker.py +96 -0
- maque/utils/downloads.py +51 -0
- maque/utils/excel_helper.py +542 -0
- maque/utils/helper_metrics.py +121 -0
- maque/utils/helper_parser.py +168 -0
- maque/utils/net.py +64 -0
- maque/utils/nvidia_stat.py +140 -0
- maque/utils/ops.py +53 -0
- maque/utils/packages.py +31 -0
- maque/utils/path.py +57 -0
- maque/utils/tar.py +260 -0
- maque/utils/untar.py +129 -0
- maque/web/__init__.py +0 -0
- maque/web/image_downloader.py +1410 -0
- maque-0.2.1.dist-info/METADATA +450 -0
- maque-0.2.1.dist-info/RECORD +143 -0
- maque-0.2.1.dist-info/WHEEL +4 -0
- maque-0.2.1.dist-info/entry_points.txt +3 -0
- maque-0.2.1.dist-info/licenses/LICENSE +21 -0
maque/cli/mllm_run.py
ADDED
|
@@ -0,0 +1,1451 @@
|
|
|
1
|
+
import pandas as pd
|
|
2
|
+
import asyncio
|
|
3
|
+
import os
|
|
4
|
+
import json
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from rich import print
|
|
7
|
+
from maque.utils.helper_parser import split_image_paths
|
|
8
|
+
from maque.utils.helper_metrics import calc_binary_metrics
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Optional, List, Dict, Any
|
|
11
|
+
from flexllm.mllm_client import MllmClient
|
|
12
|
+
|
|
13
|
+
# API key for the Qianfan endpoint, taken from the `qianfan_apikey`
# environment variable; "sk-123" is only a placeholder fallback.
qianfan_apikey = os.getenv("qianfan_apikey", "sk-123")
if qianfan_apikey == "sk-123":
    # Warn at import time that the placeholder key is in use.
    print("[yellow]未在环境变量中设置`qianfan_apikey`,现使用默认sk-123作为API key[/yellow]")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ConfigManager:
    """Persist and describe the batch-processor settings between sessions.

    Settings are kept in a JSON file (``CONFIG_FILE``, resolved relative to
    the current working directory) with two top-level keys:

    * ``last_config``  - the most recently saved configuration plus a
      ``timestamp`` string;
    * ``recent_files`` - most-recent-first list of ``file_path`` values,
      capped at 10 entries.
    """

    CONFIG_FILE = "batch_processor_config.json"

    # Every setting the user may revisit, keyed by its menu number.
    MODIFIABLE_ITEMS = {
        1: {"name": "文件选择", "key": "file_path", "display": "文件"},
        2: {"name": "文本列", "key": "text_col", "display": "文本列"},
        3: {"name": "图像列", "key": "image_col", "display": "图像列"},
        4: {"name": "数据筛选", "key": "filter_config", "display": "筛选条件"},
        5: {"name": "模型选择", "key": "model_info", "display": "模型"},
        6: {"name": "提示模板", "key": "custom_prompt", "display": "Prompt"},
        7: {"name": "行数范围", "key": "rows_range", "display": "处理行数"},
        8: {"name": "预处理图像", "key": "preprocess_msg", "display": "预处理"},
        9: {"name": "并发数量", "key": "concurrency_limit", "display": "并发数"},
        10: {"name": "QPS限制", "key": "max_qps", "display": "QPS"},
        11: {"name": "结果解析", "key": "parse_config", "display": "解析设置"},
        12: {"name": "分类模式", "key": "use_cls", "display": "分类模式"},
        13: {"name": "系统提示", "key": "system_prompt", "display": "系统提示"},
    }

    def __init__(self):
        self.config_path = self.CONFIG_FILE

    def _read_store(self) -> Dict[str, Any]:
        """Return the parsed config file, or ``{}`` when it does not exist.

        Raises:
            OSError: the file exists but cannot be read.
            ValueError: the file is not valid JSON (``json.JSONDecodeError``
                is a ``ValueError``) or its root is not a JSON object.
        """
        if not os.path.exists(self.config_path):
            return {}
        with open(self.config_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, dict):
            raise ValueError("config root must be a JSON object")
        return data

    def load_last_config(self) -> Optional[Dict[str, Any]]:
        """Return the last saved configuration, or ``None`` if unavailable.

        Read/parse failures are reported on the console and yield ``None``.
        """
        try:
            return self._read_store().get("last_config")
        except (OSError, ValueError) as e:
            print(f"[red]加载配置失败: {e}[/red]")
            return None

    def save_config(self, config: Dict[str, Any]):
        """Store *config* as the last configuration and update recent files.

        A ``timestamp`` is added to the stored copy (the caller's dict is not
        modified). Failures are reported on the console, never raised.
        """
        try:
            try:
                store = self._read_store()
            except (OSError, ValueError) as e:
                # An unreadable store must not block saving forever:
                # report it and start over with an empty one.
                print(f"[yellow]现有配置文件无法读取,将重新创建: {e}[/yellow]")
                store = {}

            snapshot = dict(config)
            snapshot["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            store["last_config"] = snapshot

            if "file_path" in config:
                current = config["file_path"]
                previous = store.get("recent_files", [])
                if not isinstance(previous, list):
                    previous = []
                # Move the current file to the front, keep the 10 most recent.
                store["recent_files"] = (
                    [current] + [p for p in previous if p != current]
                )[:10]

            # Serialise before opening the file so that an unserialisable
            # value cannot leave a truncated, corrupt file behind.
            payload = json.dumps(store, ensure_ascii=False, indent=2)
            with open(self.config_path, "w", encoding="utf-8") as f:
                f.write(payload)
        except (OSError, TypeError, ValueError) as e:
            print(f"[red]保存配置失败: {e}[/red]")

    def get_recent_files(self) -> List[str]:
        """Return the recently used file paths (most recent first).

        Best effort by design: any read/parse problem yields an empty list
        without console output.
        """
        try:
            recent = self._read_store().get("recent_files", [])
        except (OSError, ValueError):
            return []
        return recent if isinstance(recent, list) else []

    def display_config_preview(self, config: Dict[str, Any]) -> str:
        """Build a multi-line, human-readable summary of *config*.

        Returns a fixed "no history" message when *config* is empty/None.
        Missing keys fall back to the defaults shown in the code below.
        """
        if not config:
            return "无历史配置"

        lines = ["上次配置预览:", "-" * 40]

        # File
        if "file_path" in config:
            lines.append(f"文件: {config['file_path']}")

        # Columns
        text_col = config.get("text_col", "无")
        image_col = config.get("image_col", "无")
        lines.append(f"文本列: {text_col} | 图像列: {image_col}")

        # Filter: a (column, values) pair; only the first 3 values are shown.
        filter_config = config.get("filter_config")
        if filter_config:
            filter_col, filter_values = filter_config
            ellipsis = "..." if len(filter_values) > 3 else ""
            lines.append(f"筛选: {filter_col} = {filter_values[:3]}{ellipsis}")
        else:
            lines.append("筛选: 无")

        # Model: `or {}` also covers an explicit null stored in the file.
        model_info = config.get("model_info") or {}
        model_name = model_info.get("name", model_info.get("model", "未知"))
        lines.append(f"模型: {model_name}")

        # Prompt
        lines.append(f"Prompt: {'自定义' if config.get('custom_prompt') else '默认'}")

        # Row range: start is 0-based, end is exclusive.
        start_row, end_row = config.get("rows_range") or (0, 0)
        if start_row == 0 and end_row == 0:
            lines.append("行数范围: 未知")
        elif start_row == 0:
            lines.append(f"行数范围: 前{end_row}行")
        else:
            lines.append(f"行数范围: 第{start_row + 1}-{end_row}行")

        # Processing parameters
        preprocess = "是" if config.get("preprocess_msg", False) else "否"
        concurrency = config.get("concurrency_limit", 100)
        qps = config.get("max_qps", 25)
        use_cls = "是" if config.get("use_cls", False) else "否"
        system_prompt = "是" if config.get("system_prompt") else "否"
        lines.append(f"预处理: {preprocess} | 并发: {concurrency} | QPS: {qps}")
        lines.append(f"分类模式: {use_cls} | 系统提示: {system_prompt}")

        # Timestamp
        if "timestamp" in config:
            lines.append(f"最后使用: {config['timestamp']}")

        return "\n".join(lines)

    def get_config_differences(self, config: Dict[str, Any]) -> List[int]:
        """Return the menu numbers of the items that may be modified.

        Simplified placeholder: *config* is not inspected and every entry of
        ``MODIFIABLE_ITEMS`` is returned, in ascending order.
        """
        return sorted(self.MODIFIABLE_ITEMS)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
class BatchProcessor:
    """Send table rows to a multimodal chat endpoint and collect the replies.

    Each row becomes one chat request: a text part built from a prompt
    template plus zero or more ``image_url`` parts. Requests are dispatched
    through an ``MllmClient``.
    """

    def __init__(
        self,
        model: str = "vkk8o2py_wenxiaoyan_mllm_bj",
        base_url: Optional[str] = None,
        concurrency_limit: int = 100,
        max_qps: int = 25,
    ):
        """Create the processor and its client.

        Args:
            model: model identifier passed to the client.
            base_url: endpoint URL; defaults to the Qianfan v2 endpoint.
            concurrency_limit: forwarded to ``MllmClient``.
            max_qps: forwarded to ``MllmClient``.
        """
        self.model = model
        self.base_url = base_url or "https://qianfan.baidubce.com/v2"
        # Instance-level prompt template; None means "use the built-in one".
        self.custom_prompt: Optional[str] = None
        self.client = MllmClient(
            base_url=self.base_url,
            api_key=qianfan_apikey,
            model=model,
            concurrency_limit=concurrency_limit,
            max_qps=max_qps,
        )

    def set_model(self, model: str, base_url: Optional[str] = None):
        """Switch the model and, when *base_url* is given, the endpoint.

        A new ``base_url`` rebuilds the client, re-reading the API key from
        the environment and carrying over the current client's
        ``concurrency_limit`` / ``max_qps``. Otherwise only the client's
        ``model`` attribute is updated.
        """
        self.model = model
        if base_url:
            self.base_url = base_url
            self.client = MllmClient(
                base_url=self.base_url,
                api_key=os.environ.get("qianfan_apikey", "sk-123"),
                model=model,
                concurrency_limit=self.client.concurrency_limit,
                max_qps=self.client.max_qps,
            )
        else:
            self.client.model = model

    def set_custom_prompt(self, prompt_template: str):
        """Set the instance-level prompt template.

        The template may contain ``{text_content}`` as the text placeholder.
        """
        self.custom_prompt = prompt_template

    @staticmethod
    def _is_missing(value: Any) -> bool:
        """Return True for None and for scalar NaN/NA/NaT cell values."""
        if value is None:
            return True
        # pd.isna returns an array for list-likes, so test scalars only.
        return bool(pd.api.types.is_scalar(value) and pd.isna(value))

    @staticmethod
    def _render_prompt(template: str, text_content: str) -> str:
        """Fill ``{text_content}`` into *template*.

        ``str.format`` is tried first so existing templates (including
        ``{{``/``}}`` escapes) behave as before. Templates with other literal
        braces, e.g. a JSON example, make ``format`` raise; those fall back
        to a plain placeholder replacement.
        """
        try:
            return template.format(text_content=text_content)
        except (KeyError, IndexError, ValueError, AttributeError):
            return template.replace("{text_content}", text_content)

    def create_messages_for_row(
        self,
        row_data: Dict[str, Any],
        text_col: Optional[str],
        image_col: Optional[str] = None,
        custom_prompt: Optional[str] = None,
        use_cls: bool = False,
        system_prompt: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """Build the chat ``messages`` list for one row.

        Args:
            row_data: the row as a column-name -> value mapping.
            text_col: column holding the text; ``None`` means no text column.
            image_col: column holding one or more image paths/URLs (split
                with ``split_image_paths``); ``None`` means no images.
            custom_prompt: per-call template; takes precedence over the
                instance-level template and the built-in prompts.
            use_cls: choose the classification wording of the built-in prompt.
            system_prompt: optional system message placed first.

        Returns:
            ``[system?, user]`` where the user content is a list of
            ``text`` / ``image_url`` parts.

        Raises:
            AssertionError: *image_col* is given but absent from *row_data*.
        """
        if text_col is None:
            text_content = "无"
        else:
            raw_text = row_data.get(text_col, "")
            # Empty cells arrive as NaN/None; do not send "nan"/"None".
            text_content = "" if self._is_missing(raw_text) else str(raw_text)

        template = custom_prompt or getattr(self, "custom_prompt", None)
        if template:
            text_prompt = self._render_prompt(template, text_content)
        elif "xiaoyan" in self.model:
            # Built-in wording depends on the model family.
            if use_cls:
                text_prompt = f"用户query文本: {text_content}\n\n 请判断以上图文内容的分类标签是?"
            else:
                text_prompt = f"用户query文本: {text_content}\n\n 请审核以上图文内容。"
        elif use_cls:
            text_prompt = f"文本: {text_content}\n\n 请判断以上图文内容的风险类型是什么"
        else:
            text_prompt = f"文本: {text_content}\n\n 请审核以上图文内容。"

        content = [{"type": "text", "text": text_prompt}]

        if image_col is not None and image_col not in row_data:
            # Raised explicitly (not via `assert`) so it survives `python -O`.
            raise AssertionError(f"{image_col} not found in row_data: {row_data}")

        if image_col:
            raw_paths = row_data.get(image_col)
            # Check missing-ness on the raw value: str(NaN) would be "nan".
            if not self._is_missing(raw_paths) and raw_paths:
                for path in split_image_paths(str(raw_paths)):
                    if path.strip():
                        content.append(
                            {"type": "image_url", "image_url": {"url": path}}
                        )

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": content})
        return messages

    async def process_table(
        self,
        table_path: str,
        text_col: Optional[str],
        image_col: Optional[str] = None,
        sheet_name: Optional[str] = None,
        preprocess_msg: bool = False,
        custom_prompt: Optional[str] = None,
        use_cls: bool = False,
        system_prompt: Optional[str] = None,
    ) -> pd.DataFrame:
        """Load a table, query the model for every row and save the result.

        Returns the DataFrame with a ``response`` column added; the same data
        is written to ``<stem>_result.xlsx`` in the working directory.
        """
        df = self._load_dataframe(table_path, sheet_name)
        messages_list = self._create_messages_list(
            df, text_col, image_col, custom_prompt, use_cls, system_prompt
        )
        results = await self._call_llm_batch(messages_list, preprocess_msg)
        df = self._add_results_to_dataframe(df, results)
        output_path = self._save_results(df, table_path)

        print(f"结果已保存到: {output_path}")
        return df

    def _load_dataframe(
        self, table_path: str, sheet_name: Optional[str] = None
    ) -> pd.DataFrame:
        """Read *table_path* as Excel or CSV.

        A given *sheet_name* always selects the Excel reader. Otherwise the
        extension decides (case-insensitive); anything that is not an Excel
        extension is read as CSV.
        """
        if sheet_name is not None:
            return pd.read_excel(table_path, sheet_name=sheet_name)
        if str(table_path).lower().endswith((".xlsx", ".xlsm", ".xls")):
            return pd.read_excel(table_path)
        return pd.read_csv(table_path)

    def _create_messages_list(
        self,
        df: pd.DataFrame,
        text_col: Optional[str],
        image_col: Optional[str],
        custom_prompt: Optional[str] = None,
        use_cls: bool = False,
        system_prompt: Optional[str] = None,
    ) -> List[List[Dict[str, Any]]]:
        """Build one ``messages`` list per row of *df*, in row order."""
        return [
            self.create_messages_for_row(
                row.to_dict(),
                text_col,
                image_col,
                custom_prompt,
                use_cls,
                system_prompt,
            )
            for _, row in df.iterrows()
        ]

    async def _call_llm_batch(
        self, messages_list: List[List[Dict[str, Any]]], preprocess_msg: bool
    ) -> List[str]:
        """Send all requests through the client and return its results.

        The ``safety`` argument is passed through as-is with both input
        levels set to "none".
        """
        return await self.client.call_llm(
            messages_list=messages_list,
            preprocess_msg=preprocess_msg,
            safety={"input_level": "none", "input_image_level": "none"},
        )

    def _add_results_to_dataframe(
        self, df: pd.DataFrame, results: List[str]
    ) -> pd.DataFrame:
        """Store *results* in a ``response`` column (modifies *df* in place).

        An existing ``response`` column is kept under ``response_original``;
        if that name is taken too, a numeric suffix is appended so column
        names stay unique.
        """
        if "response" in df.columns:
            backup = "response_original"
            suffix = 1
            while backup in df.columns:
                suffix += 1
                backup = f"response_original_{suffix}"
            df.rename(columns={"response": backup}, inplace=True)
        df["response"] = results
        return df

    def _save_results(self, df: pd.DataFrame, table_path: str) -> str:
        """Write *df* to ``<input stem>_result.xlsx`` and return that path.

        Only the stem of *table_path* is used, so the file is created
        relative to the current working directory.
        """
        output_path = Path(table_path).stem + "_result.xlsx"
        df.to_excel(output_path, index=False, engine="openpyxl")
        return output_path

    def calculate_metrics(
        self,
        df: pd.DataFrame,
        response_col: str = "response",
        label_col: Optional[str] = None,
        parse_response_to_pred: bool = True,
        pred_parsed_tag: str = "一级标签",
        record_root_dir: str = "record",
    ):
        """Run ``calc_binary_metrics`` on *df* with the given options.

        All arguments are forwarded unchanged. Whatever the helper returns
        is passed back to the caller (NOTE(review): its return value is not
        visible here - confirm before relying on it).
        """
        return calc_binary_metrics(
            df,
            response_col=response_col,
            label_col=label_col,
            parse_response_to_pred=parse_response_to_pred,
            pred_parsed_tag=pred_parsed_tag,
            record_root_dir=record_root_dir,
        )
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
async def example_usage():
    """Usage example: process one spreadsheet with a default processor,
    report the row count, then compute metrics on the result."""
    # Processing parameters
    options = {
        "table_path": "多模态输入流-剩余未标注数据.xlsx",
        "text_col": "feed_content",
        "image_col": "image_src",
        "sheet_name": None,
        "preprocess_msg": False,
    }

    runner = BatchProcessor()

    # Process the table
    scored = await runner.process_table(**options)
    print(f"处理完成,共处理 {len(scored)} 行数据")

    # Compute metrics
    runner.calculate_metrics(scored)
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
class InteractiveRunner:
|
|
399
|
+
def __init__(self):
    """Set up an interactive session.

    No processor exists yet; the model catalogue is loaded via
    ``load_models_config`` and a ``ConfigManager`` is created for
    persisted settings.
    """
    self.processor = None
    self.current_config = dict()  # settings chosen during this session
    self.models_config = self.load_models_config()
    self.config_manager = ConfigManager()
|
|
404
|
+
|
|
405
|
+
def load_models_config(self) -> dict:
    """Load the model catalogue.

    ``models_config.json`` in the current working directory is preferred;
    otherwise the file of the same name next to this module is used. When
    the chosen file is missing or is not valid JSON, a message is printed
    and the built-in default catalogue is returned.
    """
    builtin_config = {
        "models": [
            {
                "name": "文小言 (wenxiaoyan)",
                "model": "vkk8o2py_wenxiaoyan_mllm_bj",
                "base_url": "https://qianfan.baidubce.com/v2",
                "description": "百度文小言多模态大模型",
                "parse_response_to_pred": True,
                "pred_parsed_tag": "一级标签",
            }
        ]
    }

    # Config file in the current directory
    current_config_path = "models_config.json"
    # Packaged default next to this module
    default_config_path = Path(__file__).parent / "models_config.json"

    config_path = (
        current_config_path
        if os.path.exists(current_config_path)
        else default_config_path
    )

    try:
        with open(config_path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"警告: 找不到配置文件 {config_path},使用内置默认配置")
    except json.JSONDecodeError as e:
        # Re-reading the same broken file (as a recursive call would)
        # can never succeed; fall back to the built-in catalogue.
        print(f"错误: 配置文件格式错误 {e},使用内置默认配置")
    return builtin_config
|
|
438
|
+
|
|
439
|
+
def scan_files(self) -> List[str]:
    """List the ``.xlsx`` / ``.csv`` files in the current directory.

    The result is sorted so the numbered menu built from it is stable
    between runs. Directories and Excel lock files (``~$...``, created
    while a workbook is open) are skipped because they cannot be loaded.
    """
    return sorted(
        entry
        for entry in os.listdir(".")
        if entry.endswith((".xlsx", ".csv"))
        and not entry.startswith("~$")
        and os.path.isfile(entry)
    )
|
|
446
|
+
|
|
447
|
+
def select_file(self) -> str:
    """Ask the user to pick one of the table files found by ``scan_files``.

    Returns the chosen file name, or ``None`` (after printing a notice)
    when no candidate file exists. Keeps asking until a valid number is
    entered.
    """
    candidates = self.scan_files()
    if not candidates:
        print("当前目录下没有找到.xlsx或.csv文件")
        return None

    print("\n请选择要处理的文件:")
    for number, name in enumerate(candidates, start=1):
        print(f"{number}. {name}")

    upper = len(candidates)
    while True:
        try:
            picked = int(input("请输入文件编号: "))
        except ValueError:
            print("请输入有效数字")
            continue
        if picked < 1 or picked > upper:
            print(f"请输入1-{upper}之间的数字")
            continue
        return candidates[picked - 1]
|
|
467
|
+
|
|
468
|
+
def select_columns(self, df: pd.DataFrame) -> tuple:
    """Ask which columns hold the text and the image paths.

    Prints every column with a sample taken from the first row, then asks
    for the text column and the image column in turn. Entering 0 means
    "no such column". Returns ``(text_col, image_col)``; either may be
    ``None``.
    """
    names = list(df.columns)
    total = len(names)
    has_rows = len(df) > 0

    def show_menu(header, none_label):
        # Numbered column list followed by the "0 = none" entry.
        print(header)
        for number, name in enumerate(names, start=1):
            print(f"{number}. {name}")
        print(none_label)

    def ask(prompt):
        # Repeat until 0 (-> None) or a valid column number is entered.
        while True:
            try:
                number = int(input(prompt))
            except ValueError:
                print("请输入有效数字")
                continue
            if number == 0:
                return None
            if 1 <= number <= total:
                return names[number - 1]
            print(f"请输入0-{total}之间的数字")

    print("\n表格列信息:")
    for number, name in enumerate(names, start=1):
        # Sample is the first row's value, cut to 50 characters.
        preview = str(df[name].iloc[0])[:50] if has_rows else "无数据"
        print(f"{number}. {name} (示例: {preview})")

    # Text column (optional)
    show_menu("\n请选择文本列 (输入0表示无文本列):", "0. 无文本列")
    text_col = ask("请选择文本列编号: ")

    # Image column (optional)
    show_menu("\n请选择图像列 (输入0表示无图像列):", "0. 无图像列")
    image_col = ask("请选择图像列编号: ")

    return text_col, image_col
|
|
517
|
+
|
|
518
|
+
def select_filter_column(self, df: pd.DataFrame) -> Optional[tuple]:
    """Interactively choose a column and the values to keep.

    The user picks a column (0 = no filtering), sees its most frequent
    values (at most 50 are listed), chooses include or exclude mode and
    enters value numbers separated by commas.

    Returns:
        ``None`` when filtering is declined, otherwise
        ``(filter_col, values_to_keep)``. In exclude mode the kept values
        are all distinct non-null values of the column except the chosen
        ones, including values beyond the 50 that were listed. Values are
        plain Python objects.
    """
    columns = list(df.columns)
    print("\n数据筛选 (可选):")
    print("0. 不筛选数据")
    for i, col in enumerate(columns, 1):
        print(f"{i}. 筛选列 '{col}'")

    while True:
        try:
            choice = int(input("请选择是否筛选数据: "))
        except ValueError:
            print("请输入有效数字")
            continue
        if choice == 0:
            return None
        if 1 <= choice <= len(columns):
            filter_col = columns[choice - 1]
            break
        print(f"请输入0-{len(columns)}之间的数字")

    # Value distribution: count everything, but list only the top 50.
    all_counts = df[filter_col].value_counts()
    all_values = all_counts.index.tolist()  # native Python scalars
    shown_values = all_values[:50]
    print(f"\n列 '{filter_col}' 的值分布:")
    print("-" * 40)
    for i, (value, count) in enumerate(zip(shown_values, all_counts.tolist()), 1):
        print(f"{i:2d}. {value} (出现{count}次)")
    if len(all_values) > len(shown_values):
        print(f"(仅显示前{len(shown_values)}个取值,共 {len(all_values)} 个)")

    print("\n筛选方式:")
    print("1. 正选 - 保留选中的值")
    print("2. 反选 - 排除选中的值,保留其他值")

    while True:
        try:
            filter_mode = int(input("请选择筛选方式 (1/2): "))
        except ValueError:
            print("请输入有效数字")
            continue
        if filter_mode in (1, 2):
            break
        print("请输入1或2")

    mode_text = "保留" if filter_mode == 1 else "排除"
    print(f"\n请选择要{mode_text}的值 (可多选,用逗号分隔编号,如: 1,3,5):")

    while True:
        # Accept the fullwidth comma as a separator too.
        choices_input = input("筛选值编号: ").strip().replace(",", ",")
        if not choices_input:
            print("请输入至少一个编号")
            continue
        try:
            choices = [int(x.strip()) for x in choices_input.split(",")]
        except ValueError:
            print("请输入有效的数字编号")
            continue

        out_of_range = [c for c in choices if not 1 <= c <= len(shown_values)]
        if out_of_range:
            print(f"编号 {out_of_range[0]} 超出范围,请重新输入")
            continue

        # De-duplicate while keeping the order the user typed.
        picked_indices = list(dict.fromkeys(c - 1 for c in choices))
        picked_values = [shown_values[i] for i in picked_indices]

        if filter_mode == 1:
            # Include mode: keep exactly the chosen values.
            selected_values = picked_values
            print(f"已选择保留值: {selected_values}")
        else:
            # Exclude mode: keep every other value of the column, not
            # just the other listed ones.
            excluded = set(picked_indices)
            selected_values = [
                value for i, value in enumerate(all_values) if i not in excluded
            ]
            print(f"已排除值: {picked_values}")
            print(f"将保留值: {selected_values}")

        return filter_col, selected_values
|
|
606
|
+
|
|
607
|
+
def apply_filter(
    self, df: pd.DataFrame, filter_config: Optional[tuple]
) -> pd.DataFrame:
    """Keep only the rows whose filter column holds an allowed value.

    *filter_config* is ``(column, allowed_values)``; ``None`` returns *df*
    itself, untouched. Otherwise a filtered copy is returned and the row
    counts before/after are printed.
    """
    if filter_config is None:
        return df

    column, allowed = filter_config
    keep_mask = df[column].isin(allowed)
    subset = df.loc[keep_mask].copy()
    print(f"筛选后数据: {len(subset)} 行 (原始: {len(df)} 行)")
    return subset
|
|
618
|
+
|
|
619
|
+
def select_model(self) -> dict:
    """Ask the user to choose a model from ``self.models_config``.

    Returns the chosen catalogue entry (a dict). An entry whose ``model``
    is ``"custom"`` triggers prompts for ``base_url`` and model name and
    yields a fresh dict with response parsing disabled. When the catalogue
    has no entries at all, the custom prompts are used directly, since
    there is nothing to pick from.
    """
    models = self.models_config.get("models", [])

    def ask_custom_model() -> dict:
        # Custom model: only basic info here; parsing is configured later.
        base_url = input("请输入base_url: ")
        model = input("请输入model名称: ")
        return {
            "model": model,
            "base_url": base_url,
            "name": "自定义模型",
            "description": "用户自定义",
            "parse_response_to_pred": False,
            "pred_parsed_tag": None,
        }

    if not models:
        # Without this guard the number prompt could never be satisfied.
        print("未找到可用的模型配置,请手动输入自定义模型")
        return ask_custom_model()

    print("\n请选择模型:")
    for i, model_info in enumerate(models, 1):
        # User-edited catalogues may omit the optional display fields.
        name = model_info.get("name", model_info.get("model", "未知"))
        description = model_info.get("description")
        print(f"{i}. {name} - {description}" if description else f"{i}. {name}")

    while True:
        try:
            choice = int(input("请选择模型编号: "))
        except ValueError:
            print("请输入有效数字")
            continue
        if not 1 <= choice <= len(models):
            print(f"请输入1-{len(models)}之间的数字")
            continue

        selected = models[choice - 1]
        if selected.get("model") == "custom":
            return ask_custom_model()
        return selected
|
|
652
|
+
|
|
653
|
+
def select_prompt(self) -> Optional[str]:
    """Ask whether to use the default prompt or a custom template.

    Returns ``None`` for the default, otherwise the template text exactly
    as typed. Keeps asking until 1 or 2 is entered.
    """
    for line in ("\n请选择提示模板:", "1. 使用默认提示模板", "2. 自定义提示模板"):
        print(line)

    while True:
        try:
            option = int(input("请选择: "))
        except ValueError:
            print("请输入有效数字")
            continue
        if option == 1:
            return None
        if option == 2:
            print("\n请输入自定义提示模板 (使用{text_content}作为文本占位符):")
            return input("提示模板: ")
        print("请输入1或2")
|
|
671
|
+
|
|
672
|
+
def select_rows(self, df: pd.DataFrame) -> tuple:
    """Ask which rows of *df* to process.

    Returns ``(start, end)`` where ``start`` is a 0-based index and ``end``
    is exclusive, i.e. suitable for ``df.iloc[start:end]``:

    * option 1 - all rows: ``(0, len(df))``;
    * option 2 - the first N rows: ``(0, N)``;
    * option 3 - rows n..m as 1-based inclusive numbers: ``(n - 1, m)``.

    An empty table returns ``(0, 0)`` immediately, because no answer to
    options 2 or 3 could ever be valid.
    """
    total_rows = len(df)
    if total_rows == 0:
        print("\n数据为空,没有可选择的行")
        return 0, 0

    print(f"\n行数选择 (总共 {total_rows} 行):")
    print("1. 处理所有行")
    print("2. 指定行数 (从第1行开始)")
    print("3. 指定行数范围 (从第n行到第m行)")

    def read_int(prompt):
        # Repeat the same question until an integer is entered.
        while True:
            try:
                return int(input(prompt))
            except ValueError:
                print("请输入有效数字")

    while True:
        choice = read_int("请选择: ")
        if choice == 1:
            return 0, total_rows
        if choice == 2:
            while True:
                count = read_int(f"请输入要处理的行数 (1-{total_rows}): ")
                if 1 <= count <= total_rows:
                    return 0, count
                print(f"行数必须在1-{total_rows}之间")
        if choice == 3:
            while True:
                first = read_int(f"起始行号 (1-{total_rows}): ")
                # Validate the start row before asking for the end row.
                if not 1 <= first <= total_rows:
                    print(f"起始行号必须在1-{total_rows}之间")
                    continue
                # The end row is inclusive and may equal the start row
                # (a single-row range), so the prompt starts at `first`.
                last = read_int(f"结束行号 ({first}-{total_rows}): ")
                if first <= last <= total_rows:
                    return first - 1, last  # convert to 0-based start
                print("请确保结束行号不小于起始行号,且在有效范围内")
        print("请输入1、2或3")
|
|
712
|
+
|
|
713
|
+
def select_config(self) -> dict:
    """Interactively collect the runtime settings for a batch run.

    Returns:
        A dict with the keys ``preprocess_msg`` (bool), ``concurrency_limit``
        (int > 0), ``max_qps`` (int > 0), ``use_cls`` (bool) and
        ``system_prompt`` (non-empty str, or None when no system prompt is
        wanted).
    """

    def ask_yes_no(prompt, default=""):
        """Repeat `prompt` until a y/yes/n/no answer is given; empty input uses `default`."""
        while True:
            answer = input(prompt).lower() or default
            if answer in ("y", "yes"):
                return True
            if answer in ("n", "no"):
                return False
            print("请输入y或n")

    def ask_positive_int(prompt, default, too_small_message):
        """Repeat `prompt` until a positive integer is given; empty input uses `default`."""
        while True:
            try:
                value = int(input(prompt) or default)
            except ValueError:
                print("请输入有效数字")
                continue
            if value > 0:
                return value
            print(too_small_message)

    print("\n配置选项:")

    # Image preprocessing has no default: an empty answer is asked again.
    preprocess_msg = ask_yes_no("是否预处理图像? (y/n): ")
    concurrency = ask_positive_int("并发数量 (默认100): ", "100", "并发数必须大于0")
    qps = ask_positive_int("最大QPS (默认25): ", "25", "QPS必须大于0")
    use_cls = ask_yes_no("是否使用分类模式? (y/n, 默认n): ", default="n")

    # Only a non-empty answer is ever stored, so changing one's mind after a
    # blank entry ("y", blank, then "n") yields None rather than "".
    system_prompt = None
    while ask_yes_no("是否添加系统提示? (y/n, 默认n): ", default="n"):
        entered = input("请输入系统提示内容: ").strip()
        if entered:
            system_prompt = entered
            break
        print("系统提示不能为空")

    return {
        "preprocess_msg": preprocess_msg,
        "concurrency_limit": concurrency,
        "max_qps": qps,
        "use_cls": use_cls,
        "system_prompt": system_prompt,
    }
|
|
775
|
+
|
|
776
|
+
def select_metrics_config(
    self, df: pd.DataFrame, model_config: dict
) -> Optional[dict]:
    """Ask whether results should be analysed and, if so, how.

    Args:
        df: Result frame; its column names are offered as label columns.
        model_config: Mapping consulted for the model's default values of
            ``parse_response_to_pred`` and ``pred_parsed_tag``.

    Returns:
        None when the user declines analysis, otherwise a dict with the keys
        ``label_col``, ``parse_response_to_pred`` and ``pred_parsed_tag``.
    """
    # Step 1: does the user want any analysis at all?
    while True:
        wants_analysis = input("\n是否需要分析结果? (y/n): ").lower()
        if wants_analysis in ("n", "no"):
            return None
        if wants_analysis in ("y", "yes"):
            break
        print("请输入y或n")

    # Step 2: pick the label column (0 means "no label column").
    columns = list(df.columns)
    print("\n选择标签列 (输入0表示无标签列):")
    for number, name in enumerate(columns, 1):
        print(f"{number}. {name}")
    print("0. 无标签列")

    while True:
        try:
            picked = int(input("请选择标签列编号: "))
        except ValueError:
            print("请输入有效数字")
            continue
        if 0 <= picked <= len(columns):
            label_col = columns[picked - 1] if picked else None
            break
        print(f"请输入0-{len(columns)}之间的数字")

    # Step 3: response parsing, either the model's defaults or custom values.
    default_parse = model_config.get("parse_response_to_pred", False)
    default_tag = model_config.get("pred_parsed_tag", None)

    print(f"\n解析响应设置 (模型默认: {default_parse}):")
    print("1. 使用模型默认配置")
    print("2. 自定义配置")

    while True:
        try:
            parse_choice = int(input("请选择: "))
        except ValueError:
            print("请输入有效数字")
            continue

        if parse_choice == 1:
            parse_response_to_pred, pred_parsed_tag = default_parse, default_tag
            break

        if parse_choice == 2:
            while True:
                reply = input("是否解析响应为预测结果? (y/n): ").lower()
                if reply in ("y", "yes"):
                    parse_response_to_pred = True
                    pred_parsed_tag = input("请输入解析标签 (pred_parsed_tag): ")
                    break
                if reply in ("n", "no"):
                    parse_response_to_pred, pred_parsed_tag = False, None
                    break
                print("请输入y或n")
            break

        print("请输入1或2")

    return {
        "label_col": label_col,
        "parse_response_to_pred": parse_response_to_pred,
        "pred_parsed_tag": pred_parsed_tag,
    }
|
|
852
|
+
|
|
853
|
+
def select_config_mode(self) -> str:
    """Choose how this session is configured.

    Loads the last saved configuration through the config manager. When one
    exists the user picks between reusing it, modifying it, or starting
    over; the first two store a shallow copy of it in
    ``self.current_config``.

    Returns:
        ``"quick_start"``, ``"modify_config"`` or ``"fresh_start"``.
    """
    previous = self.config_manager.load_last_config()

    print("=== 批量处理工具 ===")

    if not previous:
        # Nothing saved yet: the full configuration flow is the only option.
        print("\n首次使用,开始配置...")
        return "fresh_start"

    for menu_line in (
        "\n检测到历史配置!",
        "1. 快速开始 (使用上次配置)",
        "2. 修改配置 (基于上次配置修改)",
        "3. 从头配置",
    ):
        print(menu_line)

    # Both of these modes start from a copy of the previous configuration.
    reuse_modes = {1: "quick_start", 2: "modify_config"}
    while True:
        try:
            picked = int(input("\n请选择模式 (1/2/3): "))
        except ValueError:
            print("请输入有效数字")
            continue

        if picked in reuse_modes:
            self.current_config = previous.copy()
            return reuse_modes[picked]
        if picked == 3:
            return "fresh_start"
        print("请输入1、2或3")
|
|
883
|
+
|
|
884
|
+
def select_modifications(self) -> List[int]:
    """Ask which configuration items should be modified.

    Prints a preview of ``self.current_config`` followed by the numbered
    entries of ``config_manager.MODIFIABLE_ITEMS``, then reads a
    comma-separated list of item numbers.

    Returns:
        The chosen item ids in the order typed, without duplicates. An empty
        list means "change nothing" (the user entered ``0`` or nothing).
    """
    print("\n" + self.config_manager.display_config_preview(self.current_config))

    print("\n请选择要修改的项目 (可多选,用逗号分隔编号):")
    items = self.config_manager.MODIFIABLE_ITEMS
    for item_id, item_info in items.items():
        print(f"{item_id:2d}. {item_info['display']}")
    print(" 0. 不修改,直接使用上述配置")

    while True:
        raw = input("\n修改项编号: ").strip()
        if raw == "0" or not raw:
            return []

        try:
            # Accept the full-width comma as well as the ASCII one.
            requested = [int(part.strip()) for part in raw.replace(",", ",").split(",")]
        except ValueError:
            print("请输入有效的数字编号")
            continue

        # Validate against the actual keys: callers index MODIFIABLE_ITEMS by
        # these ids, so a range check alone could admit a missing key.
        unknown = next((number for number in requested if number not in items), None)
        if unknown is not None:
            print(f"编号 {unknown} 超出范围,请重新输入")
            continue

        # Drop repeats while keeping the order the user typed.
        return list(dict.fromkeys(requested))
|
|
913
|
+
|
|
914
|
+
async def run(self):
    """Run the interactive flow for the configuration mode the user picks.

    ``quick_start`` reuses the stored configuration; ``modify_config`` asks
    which items to change and falls back to the stored configuration when
    none are selected; any other mode runs the full configuration flow.
    """
    mode = self.select_config_mode()

    if mode not in ("quick_start", "modify_config"):
        # Fresh start: walk through every configuration step.
        await self.run_full_config()
        return

    # Only the modify mode asks which items to change.
    chosen_items = self.select_modifications() if mode == "modify_config" else []
    if chosen_items:
        await self.run_with_selective_config(chosen_items)
    else:
        await self.run_with_config()
|
|
934
|
+
|
|
935
|
+
async def run_full_config(self):
    """Run the from-scratch flow: pick a file, then configure everything else.

    Does nothing further when no file is selected.
    """
    print("\n=== 完整配置模式 ===")

    chosen_path = self.select_file()
    if chosen_path:
        self.current_config["file_path"] = chosen_path
        # The remaining steps (columns, filter, model, ...) live in the shared flow.
        await self._complete_config_flow(chosen_path)
|
|
947
|
+
|
|
948
|
+
async def _complete_config_flow(self, file_path: str):
|
|
949
|
+
"""完成配置流程并执行处理"""
|
|
950
|
+
# 2. 加载并选择列
|
|
951
|
+
df_preview = (
|
|
952
|
+
pd.read_excel(file_path)
|
|
953
|
+
if file_path.endswith(".xlsx")
|
|
954
|
+
else pd.read_csv(file_path)
|
|
955
|
+
)
|
|
956
|
+
text_col, image_col = self.select_columns(df_preview)
|
|
957
|
+
|
|
958
|
+
# 检查是否至少有一列被选择
|
|
959
|
+
if text_col is None and image_col is None:
|
|
960
|
+
print("错误:必须至少选择一个文本列或图像列")
|
|
961
|
+
return
|
|
962
|
+
|
|
963
|
+
# 更新配置
|
|
964
|
+
self.current_config.update({"text_col": text_col, "image_col": image_col})
|
|
965
|
+
|
|
966
|
+
# 3. 数据筛选
|
|
967
|
+
filter_config = self.select_filter_column(df_preview)
|
|
968
|
+
df_filtered = self.apply_filter(df_preview, filter_config)
|
|
969
|
+
self.current_config["filter_config"] = filter_config
|
|
970
|
+
|
|
971
|
+
# 4. 选择模型
|
|
972
|
+
model_info = self.select_model()
|
|
973
|
+
self.current_config["model_info"] = model_info
|
|
974
|
+
|
|
975
|
+
# 5. 选择提示模板
|
|
976
|
+
custom_prompt = self.select_prompt()
|
|
977
|
+
self.current_config["custom_prompt"] = custom_prompt
|
|
978
|
+
|
|
979
|
+
# 6. 选择行数范围
|
|
980
|
+
start_row, end_row = self.select_rows(df_filtered)
|
|
981
|
+
df_to_process = df_filtered.iloc[start_row:end_row].copy()
|
|
982
|
+
self.current_config["rows_range"] = (start_row, end_row)
|
|
983
|
+
|
|
984
|
+
# 7. 配置参数
|
|
985
|
+
config = self.select_config()
|
|
986
|
+
self.current_config.update(
|
|
987
|
+
{
|
|
988
|
+
"preprocess_msg": config["preprocess_msg"],
|
|
989
|
+
"concurrency_limit": config["concurrency_limit"],
|
|
990
|
+
"max_qps": config["max_qps"],
|
|
991
|
+
"use_cls": config["use_cls"],
|
|
992
|
+
"system_prompt": config["system_prompt"],
|
|
993
|
+
}
|
|
994
|
+
)
|
|
995
|
+
|
|
996
|
+
# 保存配置
|
|
997
|
+
self.config_manager.save_config(self.current_config)
|
|
998
|
+
|
|
999
|
+
# 执行处理
|
|
1000
|
+
model_name = model_info["model"]
|
|
1001
|
+
base_url = model_info["base_url"]
|
|
1002
|
+
await self._execute_processing(
|
|
1003
|
+
df_to_process,
|
|
1004
|
+
model_name,
|
|
1005
|
+
base_url,
|
|
1006
|
+
model_info,
|
|
1007
|
+
custom_prompt,
|
|
1008
|
+
config,
|
|
1009
|
+
start_row,
|
|
1010
|
+
end_row,
|
|
1011
|
+
)
|
|
1012
|
+
|
|
1013
|
+
async def run_with_config(self):
    """Process using the configuration already held in ``self.current_config``.

    If the configured file path is missing or no longer exists, the user is
    asked to pick a file again; declining aborts the run.
    """
    print("\n=== 快速开始模式 ===")
    print("使用历史配置直接处理...")

    configured_path = self.current_config.get("file_path")
    if not (configured_path and os.path.exists(configured_path)):
        print(f"错误:配置中的文件 {configured_path} 不存在")
        replacement = self.select_file()
        if not replacement:
            return
        self.current_config["file_path"] = replacement

    # Everything else is taken from the stored configuration.
    await self._execute_from_config()
|
|
1030
|
+
|
|
1031
|
+
async def run_with_selective_config(self, modifications: List[int]):
    """Re-ask only the selected configuration items, save, then process.

    Args:
        modifications: Item ids (keys of ``config_manager.MODIFIABLE_ITEMS``)
            in the order they should be handled.

    Item ids handled here: 1 file, 2/3 text and image columns, 4 row filter,
    5 model, 6 prompt template, 7 row range, 8/9/10/12/13 runtime settings.
    Ids 2/3 share one column prompt and ids 8/9/10/12/13 share one settings
    questionnaire, so each of those prompts runs at most once per call even
    when several ids of the group are selected. Aborts without saving when
    file selection is cancelled.
    """
    print("\n=== 修改配置模式 ===")
    print(
        f"需要修改的配置项: {[self.config_manager.MODIFIABLE_ITEMS[i]['display'] for i in modifications]}"
    )

    column_items = (2, 3)
    settings_items = (8, 9, 10, 12, 13)
    handled_groups = set()
    file_path = self.current_config.get("file_path")

    def load_preview(error_message):
        """Read the current file, or print `error_message` and return None if it is unusable."""
        if not file_path or not os.path.exists(file_path):
            print(error_message)
            return None
        reader = pd.read_excel if file_path.endswith(".xlsx") else pd.read_csv
        return reader(file_path)

    for item_id in modifications:
        if item_id in column_items:
            group = "columns"
        elif item_id in settings_items:
            group = "settings"
        else:
            group = None

        if group is not None and group in handled_groups:
            # The shared prompt for this group has already been answered.
            continue

        item_info = self.config_manager.MODIFIABLE_ITEMS[item_id]
        print(f"\n正在修改: {item_info['display']}")

        if item_id == 1:  # file
            file_path = self.select_file()
            if not file_path:
                return
            self.current_config["file_path"] = file_path
            # Columns chosen for the previous file may not exist in this one.
            handled_groups.discard("columns")

        elif group == "columns":
            df_preview = load_preview("错误:需要先选择有效的文件")
            if df_preview is None:
                continue
            text_col, image_col = self.select_columns(df_preview)
            self.current_config.update(
                {"text_col": text_col, "image_col": image_col}
            )
            handled_groups.add(group)

        elif item_id == 4:  # row filter
            df_preview = load_preview("错误:需要先选择文件")
            if df_preview is None:
                continue
            self.current_config["filter_config"] = self.select_filter_column(df_preview)

        elif item_id == 5:  # model
            self.current_config["model_info"] = self.select_model()

        elif item_id == 6:  # prompt template
            self.current_config["custom_prompt"] = self.select_prompt()

        elif item_id == 7:  # row range
            df_preview = load_preview("错误:需要先选择文件")
            if df_preview is None:
                continue
            # The range is chosen against the filtered frame.
            df_filtered = self.apply_filter(
                df_preview, self.current_config.get("filter_config")
            )
            start_row, end_row = self.select_rows(df_filtered)
            self.current_config["rows_range"] = (start_row, end_row)

        elif group == "settings":
            config = self.select_config()
            self.current_config.update(
                {
                    "preprocess_msg": config["preprocess_msg"],
                    "concurrency_limit": config["concurrency_limit"],
                    "max_qps": config["max_qps"],
                    "use_cls": config["use_cls"],
                    "system_prompt": config["system_prompt"],
                }
            )
            handled_groups.add(group)

    # Persist the updated configuration before running with it.
    self.config_manager.save_config(self.current_config)

    await self._execute_from_config()
|
|
1117
|
+
|
|
1118
|
+
async def _execute_from_config(self):
|
|
1119
|
+
"""从配置中执行处理"""
|
|
1120
|
+
file_path = self.current_config["file_path"]
|
|
1121
|
+
|
|
1122
|
+
# 加载数据
|
|
1123
|
+
df_preview = (
|
|
1124
|
+
pd.read_excel(file_path)
|
|
1125
|
+
if file_path.endswith(".xlsx")
|
|
1126
|
+
else pd.read_csv(file_path)
|
|
1127
|
+
)
|
|
1128
|
+
|
|
1129
|
+
# 应用筛选
|
|
1130
|
+
filter_config = self.current_config.get("filter_config")
|
|
1131
|
+
df_filtered = self.apply_filter(df_preview, filter_config)
|
|
1132
|
+
|
|
1133
|
+
# 获取行数范围
|
|
1134
|
+
rows_range = self.current_config.get("rows_range", (0, len(df_filtered)))
|
|
1135
|
+
start_row, end_row = rows_range
|
|
1136
|
+
df_to_process = df_filtered.iloc[start_row:end_row].copy()
|
|
1137
|
+
|
|
1138
|
+
# 获取模型信息
|
|
1139
|
+
model_info = self.current_config["model_info"]
|
|
1140
|
+
model_name = model_info["model"]
|
|
1141
|
+
base_url = model_info["base_url"]
|
|
1142
|
+
|
|
1143
|
+
# 获取其他配置
|
|
1144
|
+
custom_prompt = self.current_config.get("custom_prompt")
|
|
1145
|
+
config = {
|
|
1146
|
+
"preprocess_msg": self.current_config.get("preprocess_msg", False),
|
|
1147
|
+
"concurrency_limit": self.current_config.get("concurrency_limit", 100),
|
|
1148
|
+
"max_qps": self.current_config.get("max_qps", 25),
|
|
1149
|
+
"use_cls": self.current_config.get("use_cls", False),
|
|
1150
|
+
"system_prompt": self.current_config.get("system_prompt"),
|
|
1151
|
+
}
|
|
1152
|
+
|
|
1153
|
+
# 执行处理
|
|
1154
|
+
await self._execute_processing(
|
|
1155
|
+
df_to_process,
|
|
1156
|
+
model_name,
|
|
1157
|
+
base_url,
|
|
1158
|
+
model_info,
|
|
1159
|
+
custom_prompt,
|
|
1160
|
+
config,
|
|
1161
|
+
start_row,
|
|
1162
|
+
end_row,
|
|
1163
|
+
)
|
|
1164
|
+
|
|
1165
|
+
async def _execute_processing(
    self,
    df_to_process,
    model_name,
    base_url,
    model_info,
    custom_prompt,
    config,
    start_row,
    end_row,
):
    """Create the processor, preview the request, run the batch, optionally analyse.

    The text/image columns and the source file path are read from
    ``self.current_config``. The new ``BatchProcessor`` is stored on
    ``self.processor``. After processing, the user is asked whether metrics
    should be calculated on the result frame.
    """
    settings = self.current_config
    text_col = settings.get("text_col")
    image_col = settings.get("image_col")
    file_path = settings["file_path"]

    self.processor = BatchProcessor(
        model=model_name,
        base_url=base_url,
        concurrency_limit=config["concurrency_limit"],
        max_qps=config["max_qps"],
    )

    # Show what is about to be sent and let the user confirm.
    self._show_request_preview(
        df_to_process,
        text_col,
        image_col,
        custom_prompt,
        config,
        start_row,
        end_row,
        model_name,
        base_url,
    )

    for status_line in (
        "\n开始调用模型...",
        f"模型信息: {model_name}",
        f"API地址: {base_url}",
        "正在处理...",
    ):
        print(status_line)

    result_df = await self._process_selected_rows(
        file_path,
        df_to_process,
        text_col,
        image_col,
        config["preprocess_msg"],
        custom_prompt,
        start_row,
    )

    print(f"\n处理完成,共处理 {len(result_df)} 行数据")

    # Optional analysis step; a falsy answer means the user declined.
    metrics_config = self.select_metrics_config(result_df, model_info)
    if not metrics_config:
        return

    print("\n开始分析结果...")
    self.processor.calculate_metrics(
        result_df,
        label_col=metrics_config["label_col"],
        parse_response_to_pred=metrics_config["parse_response_to_pred"],
        pred_parsed_tag=metrics_config["pred_parsed_tag"],
    )
    print("分析完成!")
|
|
1230
|
+
|
|
1231
|
+
async def _process_selected_rows(
    self,
    file_path: str,
    df_to_process: pd.DataFrame,
    text_col: Optional[str],
    image_col: Optional[str],
    preprocess_msg: bool,
    custom_prompt: Optional[str],
    start_row: int,
) -> pd.DataFrame:
    """Send the selected rows to the model and save the responses to Excel.

    Builds one message list per row through ``self.processor``, calls the
    model in batch, and stores the results in a ``response`` column of a
    copy of `df_to_process` (an existing ``response`` column is renamed to
    ``response_original`` first). The output file name is derived from the
    stem of `file_path` plus a 1-based row-range suffix for partial runs,
    and is a relative path.

    Returns:
        The copy of `df_to_process` with the ``response`` column added.
    """
    # These two settings are not passed in; they come from the shared config.
    use_cls = self.current_config.get("use_cls", False)
    system_prompt = self.current_config.get("system_prompt")

    messages_list = self.processor._create_messages_list(
        df_to_process, text_col, image_col, custom_prompt, use_cls, system_prompt
    )
    results = await self.processor._call_llm_batch(messages_list, preprocess_msg)

    df_result = df_to_process.copy()
    if "response" in df_result.columns:
        # Keep any earlier responses instead of overwriting them.
        df_result = df_result.rename(columns={"response": "response_original"})
    df_result["response"] = results

    row_count = len(df_to_process)
    reader = pd.read_excel if file_path.endswith(".xlsx") else pd.read_csv
    if start_row > 0:
        # Explicit range, shown 1-based.
        range_suffix = f"_rows{start_row + 1}-{start_row + row_count}"
    elif start_row == 0 and row_count < len(reader(file_path)):
        # Started at the first row but covers fewer rows than the source file.
        range_suffix = f"_rows1-{row_count}"
    else:
        range_suffix = ""
    output_path = f"{Path(file_path).stem}_result{range_suffix}.xlsx"

    df_result.to_excel(output_path, index=False, engine="openpyxl")
    print(f"结果已保存到: {output_path}")

    return df_result
|
|
1280
|
+
|
|
1281
|
+
def _show_request_preview(
    self,
    df: pd.DataFrame,
    text_col: Optional[str],
    image_col: Optional[str],
    custom_prompt: Optional[str],
    config: dict,
    start_row: int,
    end_row: int,
    model_name: str,
    base_url: str,
):
    """Print a summary of the upcoming request and ask for confirmation.

    Shows the model, row range, columns and runtime settings, then — when
    `df` has at least one row — the messages a temporary ``BatchProcessor``
    builds for the first row, as JSON. Finally asks the user to confirm.

    Raises:
        SystemExit: With code 0 when the user declines to continue.
    """
    import json

    divider = "=" * 50
    print("\n" + divider)
    print("请求预览信息")
    print(divider)

    # Summary of the run.
    print(f"模型: {model_name}")
    print(f"Base URL: {base_url}")
    print(f"处理范围: 第{start_row + 1}行到第{end_row}行 (共{len(df)}行)")
    print(f"文本列: {text_col if text_col else '无'}")
    print(f"图像列: {image_col if image_col else '无'}")
    print(f"预处理图像: {'是' if config['preprocess_msg'] else '否'}")
    print(f"并发数: {config['concurrency_limit']}")
    print(f"QPS限制: {config['max_qps']}")
    print(f"分类模式: {'是' if config.get('use_cls', False) else '否'}")
    system_prompt = config.get("system_prompt")
    if system_prompt:
        # Long system prompts are cut to 50 characters in the preview.
        ellipsis = "..." if len(system_prompt) > 50 else ""
        print(f"系统提示: {system_prompt[:50]}{ellipsis}")

    # A throwaway processor is used only to build the sample messages.
    temp_processor = BatchProcessor(model=model_name, base_url=base_url)
    if custom_prompt:
        temp_processor.set_custom_prompt(custom_prompt)

    if len(df) > 0:
        sample_messages = temp_processor.create_messages_for_row(
            row_data=df.iloc[0].to_dict(),
            text_col=text_col,
            image_col=image_col,
            custom_prompt=custom_prompt,
            use_cls=config.get("use_cls", False),
            system_prompt=config.get("system_prompt"),
        )

        print(f"\n第{start_row + 1}行请求示例 (原始messages格式):")
        print("-" * 50)
        print(json.dumps(sample_messages, indent=2, ensure_ascii=False))
        print("-" * 50)

    while True:
        confirm = input(f"\n确认开始处理 {len(df)} 行数据? (y/n): ").lower()
        if confirm in ("y", "yes"):
            return
        if confirm in ("n", "no"):
            print("已取消处理")
            # Raise SystemExit directly: the exit() helper is injected by the
            # site module for interactive use and is not guaranteed to exist.
            raise SystemExit(0)
        print("请输入y或n")
|
|
1347
|
+
|
|
1348
|
+
|
|
1349
|
+
def parse_existing_results(file_path: str):
    """Load an existing result workbook and run the metrics step on it.

    The file is read with ``pd.read_excel``. A load failure is reported and
    the function returns; so does declining the analysis prompt.
    """
    print("=== 结果文件解析工具 ===")
    print(f"文件路径: {file_path}")

    try:
        df = pd.read_excel(file_path)
        print(f"成功加载数据,共 {len(df)} 行")
    except Exception as load_error:
        print(f"错误:无法加载文件 {file_path}: {load_error}")
        return

    # No model is involved here, so the parsing defaults are generic:
    # no response parsing and no tag.
    generic_defaults = {"parse_response_to_pred": False, "pred_parsed_tag": None}

    # A runner instance is created only to reuse its interactive metrics prompt.
    metrics_config = InteractiveRunner().select_metrics_config(df, generic_defaults)
    if metrics_config is None:
        print("已取消解析")
        return

    processor = BatchProcessor()
    print("\n开始解析结果...")
    processor.calculate_metrics(
        df,
        label_col=metrics_config["label_col"],
        parse_response_to_pred=metrics_config["parse_response_to_pred"],
        pred_parsed_tag=metrics_config["pred_parsed_tag"],
    )
    print("解析完成!")
|
|
1383
|
+
|
|
1384
|
+
|
|
1385
|
+
def choose_interface():
    """Ask which interface to start.

    Returns:
        ``"gui"``, ``"cli"`` or ``"exit"``. Ctrl-C and end-of-input on stdin
        both count as choosing to exit.
    """
    for banner_line in (
        "🤖 MLLM Judge - 多模态内容审核系统",
        "=" * 50,
        "选择界面模式:",
        "1. 📱 图形界面 (GUI) - 推荐",
        "2. 💻 命令行界面 (CLI)",
        "3. ❌ 退出",
    ):
        print(banner_line)

    modes = {"1": "gui", "2": "cli", "3": "exit"}
    while True:
        try:
            choice = input("\n请选择 (1-3): ").strip()
            if choice in modes:
                return modes[choice]
            print("请输入有效选择 (1-3)")
        except (KeyboardInterrupt, EOFError):
            # EOFError: stdin was closed or exhausted (e.g. piped input), so
            # no further answer can arrive; treat it like an interrupt.
            print("\n👋 已退出")
            return "exit"
|
|
1408
|
+
|
|
1409
|
+
|
|
1410
|
+
async def run_cli():
    """Start the command-line flow with a fresh ``InteractiveRunner``."""
    await InteractiveRunner().run()
|
|
1414
|
+
|
|
1415
|
+
|
|
1416
|
+
def run_gui():
    """Try to start the GUI application.

    Returns:
        True when the GUI ran and returned normally; False when importing
        it failed or it raised, after printing a notice that the CLI will
        be started instead.
    """
    cli_fallback_notice = "\n🔄 正在启动CLI模式..."

    try:
        # Imported lazily so a missing GUI dependency only affects GUI mode.
        from run_gui import MLLMJudgeApp

        MLLMJudgeApp().run()
    except ImportError:
        print("❌ GUI模式需要安装Textual库:")
        print(" pip install textual textual-dev")
        print(cli_fallback_notice)
        return False
    except Exception as gui_error:
        print(f"❌ GUI启动失败: {gui_error}")
        print(cli_fallback_notice)
        return False
    else:
        return True
|
|
1433
|
+
|
|
1434
|
+
|
|
1435
|
+
async def main():
    """Entry point: ask for the interface mode and start it.

    A GUI start that reports failure falls back to the CLI.
    """
    mode = choose_interface()

    if mode == "cli":
        await run_cli()
    elif mode == "gui":
        if not run_gui():
            # GUI unavailable: print a separator and continue in the CLI.
            print("\n" + "=" * 50)
            await run_cli()
    else:
        print("👋 再见!")
|
|
1448
|
+
|
|
1449
|
+
|
|
1450
|
+
if __name__ == "__main__":
    # Script entry point: drive the async main() to completion.
    asyncio.run(main())
|