gpbench 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. gp_agent_tool/compute_dataset_feature.py +67 -0
  2. gp_agent_tool/config.py +65 -0
  3. gp_agent_tool/experience/create_masked_dataset_summary.py +97 -0
  4. gp_agent_tool/experience/dataset_summary_info.py +13 -0
  5. gp_agent_tool/experience/experience_info.py +12 -0
  6. gp_agent_tool/experience/get_matched_experience.py +111 -0
  7. gp_agent_tool/llm_client.py +119 -0
  8. gp_agent_tool/logging_utils.py +24 -0
  9. gp_agent_tool/main.py +347 -0
  10. gp_agent_tool/read_agent/__init__.py +46 -0
  11. gp_agent_tool/read_agent/nodes.py +674 -0
  12. gp_agent_tool/read_agent/prompts.py +547 -0
  13. gp_agent_tool/read_agent/python_repl_tool.py +165 -0
  14. gp_agent_tool/read_agent/state.py +101 -0
  15. gp_agent_tool/read_agent/workflow.py +54 -0
  16. gpbench/__init__.py +25 -0
  17. gpbench/_selftest.py +104 -0
  18. gpbench/method_class/BayesA/BayesA_class.py +141 -0
  19. gpbench/method_class/BayesA/__init__.py +5 -0
  20. gpbench/method_class/BayesA/_bayesfromR.py +96 -0
  21. gpbench/method_class/BayesA/_param_free_base_model.py +84 -0
  22. gpbench/method_class/BayesA/bayesAfromR.py +16 -0
  23. gpbench/method_class/BayesB/BayesB_class.py +140 -0
  24. gpbench/method_class/BayesB/__init__.py +5 -0
  25. gpbench/method_class/BayesB/_bayesfromR.py +96 -0
  26. gpbench/method_class/BayesB/_param_free_base_model.py +84 -0
  27. gpbench/method_class/BayesB/bayesBfromR.py +16 -0
  28. gpbench/method_class/BayesC/BayesC_class.py +141 -0
  29. gpbench/method_class/BayesC/__init__.py +4 -0
  30. gpbench/method_class/BayesC/_bayesfromR.py +96 -0
  31. gpbench/method_class/BayesC/_param_free_base_model.py +84 -0
  32. gpbench/method_class/BayesC/bayesCfromR.py +16 -0
  33. gpbench/method_class/CropARNet/CropARNet_class.py +186 -0
  34. gpbench/method_class/CropARNet/CropARNet_he_class.py +154 -0
  35. gpbench/method_class/CropARNet/__init__.py +5 -0
  36. gpbench/method_class/CropARNet/base_CropARNet_class.py +178 -0
  37. gpbench/method_class/Cropformer/Cropformer_class.py +308 -0
  38. gpbench/method_class/Cropformer/__init__.py +5 -0
  39. gpbench/method_class/Cropformer/cropformer_he_class.py +221 -0
  40. gpbench/method_class/DL_GWAS/DL_GWAS_class.py +250 -0
  41. gpbench/method_class/DL_GWAS/DL_GWAS_he_class.py +169 -0
  42. gpbench/method_class/DL_GWAS/__init__.py +5 -0
  43. gpbench/method_class/DNNGP/DNNGP_class.py +163 -0
  44. gpbench/method_class/DNNGP/DNNGP_he_class.py +138 -0
  45. gpbench/method_class/DNNGP/__init__.py +5 -0
  46. gpbench/method_class/DNNGP/base_dnngp_class.py +116 -0
  47. gpbench/method_class/DeepCCR/DeepCCR_class.py +172 -0
  48. gpbench/method_class/DeepCCR/DeepCCR_he_class.py +161 -0
  49. gpbench/method_class/DeepCCR/__init__.py +5 -0
  50. gpbench/method_class/DeepCCR/base_DeepCCR_class.py +209 -0
  51. gpbench/method_class/DeepGS/DeepGS_class.py +184 -0
  52. gpbench/method_class/DeepGS/DeepGS_he_class.py +150 -0
  53. gpbench/method_class/DeepGS/__init__.py +5 -0
  54. gpbench/method_class/DeepGS/base_deepgs_class.py +153 -0
  55. gpbench/method_class/EIR/EIR_class.py +276 -0
  56. gpbench/method_class/EIR/EIR_he_class.py +184 -0
  57. gpbench/method_class/EIR/__init__.py +5 -0
  58. gpbench/method_class/EIR/utils/__init__.py +0 -0
  59. gpbench/method_class/EIR/utils/array_output_modules.py +97 -0
  60. gpbench/method_class/EIR/utils/common.py +65 -0
  61. gpbench/method_class/EIR/utils/lcl_layers.py +235 -0
  62. gpbench/method_class/EIR/utils/logging.py +59 -0
  63. gpbench/method_class/EIR/utils/mlp_layers.py +92 -0
  64. gpbench/method_class/EIR/utils/models_locally_connected.py +642 -0
  65. gpbench/method_class/EIR/utils/transformer_models.py +546 -0
  66. gpbench/method_class/ElasticNet/ElasticNet_class.py +133 -0
  67. gpbench/method_class/ElasticNet/ElasticNet_he_class.py +91 -0
  68. gpbench/method_class/ElasticNet/__init__.py +5 -0
  69. gpbench/method_class/G2PDeep/G2PDeep_he_class.py +217 -0
  70. gpbench/method_class/G2PDeep/G2Pdeep_class.py +205 -0
  71. gpbench/method_class/G2PDeep/__init__.py +5 -0
  72. gpbench/method_class/G2PDeep/base_G2PDeep_class.py +209 -0
  73. gpbench/method_class/GBLUP/GBLUP_class.py +183 -0
  74. gpbench/method_class/GBLUP/__init__.py +5 -0
  75. gpbench/method_class/GEFormer/GEFormer_class.py +169 -0
  76. gpbench/method_class/GEFormer/GEFormer_he_class.py +137 -0
  77. gpbench/method_class/GEFormer/__init__.py +5 -0
  78. gpbench/method_class/GEFormer/gMLP_class.py +357 -0
  79. gpbench/method_class/LightGBM/LightGBM_class.py +224 -0
  80. gpbench/method_class/LightGBM/LightGBM_he_class.py +121 -0
  81. gpbench/method_class/LightGBM/__init__.py +5 -0
  82. gpbench/method_class/RF/RF_GPU_class.py +165 -0
  83. gpbench/method_class/RF/RF_GPU_he_class.py +124 -0
  84. gpbench/method_class/RF/__init__.py +5 -0
  85. gpbench/method_class/SVC/SVC_GPU.py +181 -0
  86. gpbench/method_class/SVC/SVC_GPU_he.py +106 -0
  87. gpbench/method_class/SVC/__init__.py +5 -0
  88. gpbench/method_class/SoyDNGP/AlexNet_206_class.py +179 -0
  89. gpbench/method_class/SoyDNGP/SoyDNGP_class.py +189 -0
  90. gpbench/method_class/SoyDNGP/SoyDNGP_he_class.py +112 -0
  91. gpbench/method_class/SoyDNGP/__init__.py +5 -0
  92. gpbench/method_class/XGBoost/XGboost_GPU_class.py +198 -0
  93. gpbench/method_class/XGBoost/XGboost_GPU_he_class.py +178 -0
  94. gpbench/method_class/XGBoost/__init__.py +5 -0
  95. gpbench/method_class/__init__.py +52 -0
  96. gpbench/method_class/rrBLUP/__init__.py +5 -0
  97. gpbench/method_class/rrBLUP/rrBLUP_class.py +140 -0
  98. gpbench/method_reg/BayesA/BayesA.py +116 -0
  99. gpbench/method_reg/BayesA/__init__.py +5 -0
  100. gpbench/method_reg/BayesA/_bayesfromR.py +96 -0
  101. gpbench/method_reg/BayesA/_param_free_base_model.py +84 -0
  102. gpbench/method_reg/BayesA/bayesAfromR.py +16 -0
  103. gpbench/method_reg/BayesB/BayesB.py +117 -0
  104. gpbench/method_reg/BayesB/__init__.py +5 -0
  105. gpbench/method_reg/BayesB/_bayesfromR.py +96 -0
  106. gpbench/method_reg/BayesB/_param_free_base_model.py +84 -0
  107. gpbench/method_reg/BayesB/bayesBfromR.py +16 -0
  108. gpbench/method_reg/BayesC/BayesC.py +115 -0
  109. gpbench/method_reg/BayesC/__init__.py +5 -0
  110. gpbench/method_reg/BayesC/_bayesfromR.py +96 -0
  111. gpbench/method_reg/BayesC/_param_free_base_model.py +84 -0
  112. gpbench/method_reg/BayesC/bayesCfromR.py +16 -0
  113. gpbench/method_reg/CropARNet/CropARNet.py +159 -0
  114. gpbench/method_reg/CropARNet/CropARNet_Hyperparameters.py +109 -0
  115. gpbench/method_reg/CropARNet/__init__.py +5 -0
  116. gpbench/method_reg/CropARNet/base_CropARNet.py +137 -0
  117. gpbench/method_reg/Cropformer/Cropformer.py +313 -0
  118. gpbench/method_reg/Cropformer/Cropformer_Hyperparameters.py +250 -0
  119. gpbench/method_reg/Cropformer/__init__.py +5 -0
  120. gpbench/method_reg/DL_GWAS/DL_GWAS.py +186 -0
  121. gpbench/method_reg/DL_GWAS/DL_GWAS_Hyperparameters.py +125 -0
  122. gpbench/method_reg/DL_GWAS/__init__.py +5 -0
  123. gpbench/method_reg/DNNGP/DNNGP.py +157 -0
  124. gpbench/method_reg/DNNGP/DNNGP_Hyperparameters.py +118 -0
  125. gpbench/method_reg/DNNGP/__init__.py +5 -0
  126. gpbench/method_reg/DNNGP/base_dnngp.py +101 -0
  127. gpbench/method_reg/DeepCCR/DeepCCR.py +149 -0
  128. gpbench/method_reg/DeepCCR/DeepCCR_Hyperparameters.py +110 -0
  129. gpbench/method_reg/DeepCCR/__init__.py +5 -0
  130. gpbench/method_reg/DeepCCR/base_DeepCCR.py +171 -0
  131. gpbench/method_reg/DeepGS/DeepGS.py +165 -0
  132. gpbench/method_reg/DeepGS/DeepGS_Hyperparameters.py +114 -0
  133. gpbench/method_reg/DeepGS/__init__.py +5 -0
  134. gpbench/method_reg/DeepGS/base_deepgs.py +98 -0
  135. gpbench/method_reg/EIR/EIR.py +258 -0
  136. gpbench/method_reg/EIR/EIR_Hyperparameters.py +178 -0
  137. gpbench/method_reg/EIR/__init__.py +5 -0
  138. gpbench/method_reg/EIR/utils/__init__.py +0 -0
  139. gpbench/method_reg/EIR/utils/array_output_modules.py +97 -0
  140. gpbench/method_reg/EIR/utils/common.py +65 -0
  141. gpbench/method_reg/EIR/utils/lcl_layers.py +235 -0
  142. gpbench/method_reg/EIR/utils/logging.py +59 -0
  143. gpbench/method_reg/EIR/utils/mlp_layers.py +92 -0
  144. gpbench/method_reg/EIR/utils/models_locally_connected.py +642 -0
  145. gpbench/method_reg/EIR/utils/transformer_models.py +546 -0
  146. gpbench/method_reg/ElasticNet/ElasticNet.py +123 -0
  147. gpbench/method_reg/ElasticNet/ElasticNet_he.py +83 -0
  148. gpbench/method_reg/ElasticNet/__init__.py +5 -0
  149. gpbench/method_reg/G2PDeep/G2PDeep_Hyperparameters.py +107 -0
  150. gpbench/method_reg/G2PDeep/G2Pdeep.py +166 -0
  151. gpbench/method_reg/G2PDeep/__init__.py +5 -0
  152. gpbench/method_reg/G2PDeep/base_G2PDeep.py +209 -0
  153. gpbench/method_reg/GBLUP/GBLUP_R.py +182 -0
  154. gpbench/method_reg/GBLUP/__init__.py +5 -0
  155. gpbench/method_reg/GEFormer/GEFormer.py +164 -0
  156. gpbench/method_reg/GEFormer/GEFormer_Hyperparameters.py +106 -0
  157. gpbench/method_reg/GEFormer/__init__.py +5 -0
  158. gpbench/method_reg/GEFormer/gMLP.py +341 -0
  159. gpbench/method_reg/LightGBM/LightGBM.py +237 -0
  160. gpbench/method_reg/LightGBM/LightGBM_Hyperparameters.py +77 -0
  161. gpbench/method_reg/LightGBM/__init__.py +5 -0
  162. gpbench/method_reg/MVP/MVP.py +182 -0
  163. gpbench/method_reg/MVP/MVP_Hyperparameters.py +126 -0
  164. gpbench/method_reg/MVP/__init__.py +5 -0
  165. gpbench/method_reg/MVP/base_MVP.py +113 -0
  166. gpbench/method_reg/RF/RF_GPU.py +174 -0
  167. gpbench/method_reg/RF/RF_Hyperparameters.py +163 -0
  168. gpbench/method_reg/RF/__init__.py +5 -0
  169. gpbench/method_reg/SVC/SVC_GPU.py +194 -0
  170. gpbench/method_reg/SVC/SVC_Hyperparameters.py +107 -0
  171. gpbench/method_reg/SVC/__init__.py +5 -0
  172. gpbench/method_reg/SoyDNGP/AlexNet_206.py +185 -0
  173. gpbench/method_reg/SoyDNGP/SoyDNGP.py +179 -0
  174. gpbench/method_reg/SoyDNGP/SoyDNGP_Hyperparameters.py +105 -0
  175. gpbench/method_reg/SoyDNGP/__init__.py +5 -0
  176. gpbench/method_reg/XGBoost/XGboost_GPU.py +188 -0
  177. gpbench/method_reg/XGBoost/XGboost_Hyperparameters.py +167 -0
  178. gpbench/method_reg/XGBoost/__init__.py +5 -0
  179. gpbench/method_reg/__init__.py +55 -0
  180. gpbench/method_reg/rrBLUP/__init__.py +5 -0
  181. gpbench/method_reg/rrBLUP/rrBLUP.py +123 -0
  182. gpbench-1.0.0.dist-info/METADATA +379 -0
  183. gpbench-1.0.0.dist-info/RECORD +188 -0
  184. gpbench-1.0.0.dist-info/WHEEL +5 -0
  185. gpbench-1.0.0.dist-info/entry_points.txt +2 -0
  186. gpbench-1.0.0.dist-info/top_level.txt +3 -0
  187. tests/test_import.py +80 -0
  188. tests/test_method.py +232 -0
@@ -0,0 +1,674 @@
1
+ """
2
+ Read Agent 各个节点的实现(独立版本,不依赖 textMSA)。
3
+
4
+ 注意:
5
+ - 不在内部执行任何文件系统读取,所有文件信息由外部通过 FileInfo 传入。
6
+ - 对于文本文件,直接使用 preview 作为内容来源。
7
+ - 对于数据/图像文件,仍允许在生成的代码或多模态模型中使用 file_path,但本模块自身不做 I/O。
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import json
13
+ from pathlib import Path
14
+ from typing import Any, Optional
15
+
16
+ from llm_client import run_llm, run_multimodal_llm
17
+ from logging_utils import get_logger
18
+ from .python_repl_tool import PythonREPL
19
+ from .prompts import (
20
+ format_answer_prompt,
21
+ format_code_generation_prompt,
22
+ format_code_retry_prompt,
23
+ format_data_preview_analysis_prompt,
24
+ format_plan_prompt,
25
+ format_text_summary_prompt,
26
+ )
27
+ from .state import FileInfo, PlanHistory, ReadAgentState
28
+
29
+
30
+ logger = get_logger(__name__)
31
+
32
+
33
+ def _normalize_language(language: Optional[str]) -> str:
34
+ if not language:
35
+ return "en"
36
+ lower = language.lower()
37
+ if lower.startswith("zh"):
38
+ return "zh"
39
+ if lower.startswith("en"):
40
+ return "en"
41
+ return "en"
42
+
43
+
44
def _get_localized_message(messages: dict[str, str], language: Optional[str]) -> str:
    """Return the message for *language* (normalized), defaulting to English.

    Falls back to "" when neither the normalized language nor "en" is present.
    """
    normalized = _normalize_language(language)
    if normalized in messages:
        return messages[normalized]
    return messages.get("en", "")
47
+
48
+
49
+ def _is_image_file(path_str: str) -> bool:
50
+ path = Path(path_str)
51
+ return path.suffix.lower() in {
52
+ ".png",
53
+ ".jpg",
54
+ ".jpeg",
55
+ ".bmp",
56
+ ".gif",
57
+ ".tiff",
58
+ ".webp",
59
+ }
60
+
61
+
62
+ def _is_data_file(filename: str) -> bool:
63
+ data_extensions = {
64
+ ".csv",
65
+ ".h5ad",
66
+ ".json",
67
+ ".parquet",
68
+ ".xlsx",
69
+ ".xls",
70
+ ".h5",
71
+ ".hdf5",
72
+ ".feather",
73
+ ".pkl",
74
+ ".pickle",
75
+ }
76
+ return Path(filename).suffix.lower() in data_extensions
77
+
78
+
79
+ def _strip_if_main_block(code: str) -> str:
80
+ """
81
+ 修复代码中的 if __name__ == "__main__": 块
82
+
83
+ 在 REPL 环境中,__name__ 可能不是 "__main__",导致代码不执行。
84
+ 此函数将 if __name__ == "__main__": 块中的内容提取出来,直接执行。
85
+ """
86
+ import re
87
+
88
+ # 检查是否包含 if __name__ == "__main__": 模式
89
+ pattern = r'if\s+__name__\s*==\s*["\']__main__["\']\s*:'
90
+
91
+ if not re.search(pattern, code):
92
+ return code # 没有 main 块,直接返回
93
+
94
+ lines = code.split('\n')
95
+ result_lines = []
96
+ non_main_lines = []
97
+ main_content_lines = []
98
+ in_main_block = False
99
+ main_block_indent = None
100
+
101
+ for line in lines:
102
+ # 检查是否是 if __name__ == "__main__": 行
103
+ if re.match(r'\s*if\s+__name__\s*==\s*["\']__main__["\']\s*:', line):
104
+ in_main_block = True
105
+ main_block_indent = len(line) - len(line.lstrip())
106
+ continue
107
+
108
+ if in_main_block:
109
+ # 在 main 块中
110
+ if not line.strip(): # 空行
111
+ main_content_lines.append('')
112
+ continue
113
+
114
+ current_indent = len(line) - len(line.lstrip())
115
+
116
+ # 如果缩进小于等于 main_block_indent,说明 main 块结束了
117
+ if current_indent <= main_block_indent:
118
+ in_main_block = False
119
+ # 这一行不属于 main 块,添加到非 main 块
120
+ non_main_lines.append(line)
121
+ continue
122
+
123
+ # 提取 main 块中的内容,去除缩进(main_block_indent + 4)
124
+ indent_to_remove = main_block_indent + 4
125
+ if len(line) >= indent_to_remove:
126
+ main_content_lines.append(line[indent_to_remove:])
127
+ else:
128
+ # 如果缩进不够,可能是使用了 tab 或其他缩进方式,直接去除所有前导空白
129
+ main_content_lines.append(line.lstrip())
130
+ else:
131
+ # 不在 main 块中,保留原行
132
+ non_main_lines.append(line)
133
+
134
+ # 组合代码:非 main 块 + main 块内容
135
+ fixed_code = '\n'.join(non_main_lines)
136
+ if main_content_lines:
137
+ if fixed_code and not fixed_code.endswith('\n'):
138
+ fixed_code += '\n'
139
+ fixed_code += '\n'.join(main_content_lines)
140
+
141
+ if main_content_lines:
142
+ logger.info(
143
+ "Read Agent - 修复了 if __name__ == '__main__': 块",
144
+ extra={
145
+ "original_code_length": len(code),
146
+ "fixed_code_length": len(fixed_code),
147
+ "main_block_lines": len(main_content_lines),
148
+ },
149
+ )
150
+
151
+ return fixed_code
152
+
153
+
154
+ def _parse_json_response(response_content: str) -> dict[str, Any]:
155
+ """从 LLM 响应中解析 JSON,失败时返回空 dict。
156
+
157
+ 兼容以下几种常见格式:
158
+ - 直接输出 JSON:{"plans": [...], "reasoning": "..."}
159
+ - 使用 ```json / ``` 包裹的代码块:
160
+ ```json
161
+ { ... }
162
+ ```
163
+ """
164
+ cleaned = response_content.strip()
165
+
166
+ # 先尝试直接解析整段内容
167
+ try:
168
+ return json.loads(cleaned)
169
+ except Exception:
170
+ pass
171
+
172
+ # 再尝试处理被 ``` / ```json / ```python 等代码块包裹的情况
173
+ if cleaned.startswith("```"):
174
+ lines = cleaned.splitlines()
175
+ # 去掉第一行 ``` / ```json / ```python
176
+ if lines and lines[0].lstrip().startswith("```"):
177
+ lines = lines[1:]
178
+ # 去掉最后一行 ```
179
+ if lines and lines[-1].lstrip().startswith("```"):
180
+ lines = lines[:-1]
181
+ cleaned_block = "\n".join(lines).strip()
182
+
183
+ # 1)优先尝试直接把代码块内容当 JSON 解析
184
+ try:
185
+ return json.loads(cleaned_block)
186
+ except Exception:
187
+ pass
188
+
189
+ # 2)如果是 ```python 之类的代码块,可能包含变量赋值,尝试从中抽取第一个 {...} 结构
190
+ brace_start = cleaned_block.find("{")
191
+ brace_end = cleaned_block.rfind("}")
192
+ if brace_start != -1 and brace_end != -1 and brace_end > brace_start:
193
+ candidate = cleaned_block[brace_start : brace_end + 1].strip()
194
+ try:
195
+ return json.loads(candidate)
196
+ except Exception:
197
+ return {}
198
+
199
+ return {}
200
+
201
+ # 最后一次兜底:从任意文本中提取第一个 {...} 结构再尝试解析
202
+ brace_start = cleaned.find("{")
203
+ brace_end = cleaned.rfind("}")
204
+ if brace_start != -1 and brace_end != -1 and brace_end > brace_start:
205
+ candidate = cleaned[brace_start : brace_end + 1].strip()
206
+ try:
207
+ return json.loads(candidate)
208
+ except Exception:
209
+ pass
210
+
211
+ return {}
212
+
213
+
214
def plan_node(state: ReadAgentState):
    """Generate the sequential execution plan from the user query and file overview."""
    user_query = state["user_query"]
    file_overview = state["file_overview"]
    files = state["files"]
    language = state.get("language", "en")

    response_text = run_llm(
        format_plan_prompt(
            user_query=user_query,
            file_overview=file_overview,
            language=language,
        ),
        temperature=0.1,
        max_tokens=8000,
        node_name="plan_node",
    )
    parsed = _parse_json_response(response_text)
    plans = parsed.get("plans", [])
    reasoning = parsed.get("reasoning", "")

    if not plans:
        logger.warning(
            _get_localized_message(
                {
                    "zh": "计划生成失败,返回空计划列表",
                    "en": "Plan generation failed, returning empty plan list",
                },
                language,
            )
        )
        plans = []

    # Align each plan with the externally supplied file info, matched on
    # (file_name, file_path); currently used for logging/consistency checks.
    file_index = {(f.get("file_name"), f.get("file_path")): f for f in files}

    history_plans: list[PlanHistory] = []
    for plan in plans:
        name = plan.get("file_name", "")
        path = plan.get("file_path", "")
        matched: Optional[FileInfo] = file_index.get((name, path))
        if not matched:
            logger.warning(
                "Plan references file not found in input files",
                extra={"file_name": name, "file_path": path},
            )

        entry: PlanHistory = {
            "file_name": name,
            "file_path": path,
            "plan_detail": plan.get("plan_detail", ""),
            "result": None,
        }
        order_reasoning = plan.get("order_reasoning", "")
        if order_reasoning:
            entry["order_reasoning"] = order_reasoning
        history_plans.append(entry)

    preview = reasoning[:200] + ("..." if len(reasoning) > 200 else "")
    logger.info(
        "Plan node completed",
        extra={
            "plan_count": len(history_plans),
            "reasoning_preview": preview,
        },
    )

    return {
        "current_plan_index": 0,
        "history_plans": history_plans,
    }
286
+
287
+
288
def execute_plan_node(state: ReadAgentState):
    """Router node: decide whether another plan item remains to execute.

    Routes to "read" while plan items are left, otherwise to "answer".
    """
    index = state.get("current_plan_index", 0)
    plans = state.get("history_plans", [])
    route = "read" if index < len(plans) else "answer"
    return {"next_route": route}
300
+
301
+
302
def read_node(state: ReadAgentState):
    """Execute the current plan item and record its result in ``history_plans``.

    Dispatch is by file extension only; this module performs no disk I/O:
    - image files: analyzed by the multimodal LLM from the file path
    - data files: the LLM generates Python analysis code which is executed in
      a local REPL, with up to ``max_attempts`` tries on failure
    - anything else: the externally supplied ``preview`` text is summarized

    Returns the updated ``history_plans`` and the advanced
    ``current_plan_index``.

    Fixes: the retry loop previously ran ``range(5)`` while the
    "final attempt" branches tested ``attempt == 2`` and the fallback message
    said "retried 3 times" — attempts 4 and 5 could run after the final error
    had already been recorded. The counts are now consistent (3 attempts).
    ``user_query`` is also hoisted out of the guidance ``try`` block because
    the retry prompt later depends on it.
    """
    files = state["files"]
    history_plans = state.get("history_plans", [])
    current_plan_index = state.get("current_plan_index", 0)
    language = state.get("language", "en")
    # Hoisted: used by both the preview-guidance prompt and the retry prompt.
    user_query = state.get("user_query", "")

    if current_plan_index >= len(history_plans):
        warning_msg = _get_localized_message(
            {
                "zh": "current_plan_index 超出范围",
                "en": "current_plan_index out of range",
            },
            language,
        )
        logger.warning(warning_msg)
        return {}

    # Collect results from the plan items already executed, so subsequent
    # prompts can build on earlier findings.
    previous_results_list = []
    for i in range(current_plan_index):
        prev_plan = history_plans[i]
        prev_result = prev_plan.get("result")
        if prev_result:
            previous_results_list.append(
                {
                    "file_name": prev_plan.get("file_name", ""),
                    "file_path": prev_plan.get("file_path", ""),
                    "plan_detail": prev_plan.get("plan_detail", ""),
                    "result": prev_result,
                }
            )

    if previous_results_list:
        previous_results_str = json.dumps(
            previous_results_list,
            ensure_ascii=False,
            indent=2,
        )
    else:
        previous_results_str = _get_localized_message(
            {
                "zh": "尚未读取任何文件。",
                "en": "No previous files have been read yet.",
            },
            language,
        )

    current_plan = history_plans[current_plan_index]
    file_name = current_plan.get("file_name", "")
    file_path = current_plan.get("file_path", "")
    plan_detail = current_plan.get("plan_detail", "")

    # Locate the FileInfo matching this plan item (by file_name + file_path).
    file_info: Optional[FileInfo] = None
    for f in files:
        if f.get("file_name") == file_name and f.get("file_path") == file_path:
            file_info = f
            break

    if not file_info:
        warning_msg = _get_localized_message(
            {
                "zh": f"文件信息不存在: {file_name} ({file_path})",
                "en": f"File info not found: {file_name} ({file_path})",
            },
            language,
        )
        logger.warning(warning_msg)
        history_plans[current_plan_index]["result"] = warning_msg
        return {
            "history_plans": history_plans,
            "current_plan_index": current_plan_index + 1,
        }

    result = ""

    if _is_image_file(file_path):
        # Image file: delegate analysis to the multimodal model.
        if language == "zh":
            text_content = f"文件: {file_name}\n路径: {file_path}\n类型: 图像文件\n\n请分析图像内容并回答: {plan_detail}"
        else:
            text_content = f"File: {file_name}\nPath: {file_path}\nType: Image file\n\nPlease analyze the image content and answer: {plan_detail}"

        content_payload = [{"image": f"file://{file_path}"}, {"text": text_content}]
        try:
            result = run_multimodal_llm(content_payload, node_name="read_node_image")
        except Exception as exc:  # noqa: BLE001
            error_msg = _get_localized_message(
                {
                    "zh": f"[错误] 图像分析失败: {exc}",
                    "en": f"[Error] Image analysis failed: {exc}",
                },
                language,
            )
            logger.error(error_msg, exc_info=True)
            result = error_msg

    elif _is_data_file(file_name):
        # Data file: generate and run analysis code; no direct reads here.
        repl = PythonREPL()
        code = ""
        execution_result = None
        execution_success = False
        analysis_guidance = ""

        # Best-effort: ask the LLM for analysis guidance from the preview.
        try:
            preview_analysis_prompt = format_data_preview_analysis_prompt(
                user_query=user_query,
                file_info={
                    "file_name": file_name,
                    "file_path": file_path,
                    "preview": file_info.get("preview", ""),
                    "description": file_info.get("description", ""),
                },
                previous_results=previous_results_str,
                language=language,
            )
            guidance_response = run_llm(
                preview_analysis_prompt,
                temperature=0.1,
                max_tokens=2000,
                use_codegen=False,
                node_name="read_node_preview_analysis",
            )
            parsed = _parse_json_response(guidance_response)
            analysis_guidance = parsed.get("guidance", "") or guidance_response.strip()
        except Exception as exc:  # noqa: BLE001
            logger.warning("数据预览分析失败: %s", exc, exc_info=True)

        max_attempts = 3
        for attempt in range(max_attempts):
            is_last_attempt = attempt == max_attempts - 1
            try:
                if attempt == 0:
                    prompt = format_code_generation_prompt(
                        instruction=plan_detail,
                        file_info={
                            "file_name": file_name,
                            "file_path": file_path,
                            "preview": file_info.get("preview", ""),
                            "description": file_info.get("description", ""),
                        },
                        previous_results=previous_results_str,
                        analysis_guidance=analysis_guidance,
                        language=language,
                    )
                else:
                    # A retry only makes sense with a previous failure to
                    # show the model; bail out otherwise.
                    if execution_result is None or not hasattr(
                        execution_result, "stderr"
                    ):
                        break
                    prompt = format_code_retry_prompt(
                        user_query=user_query,
                        instruction=plan_detail,
                        file_info={
                            "file_name": file_name,
                            "file_path": file_path,
                            "preview": file_info.get("preview", ""),
                            "description": file_info.get("description", ""),
                        },
                        previous_code=code,
                        error_message=getattr(execution_result, "stderr", "") or "",
                        previous_results=previous_results_str,
                        language=language,
                    )

                response_text = run_llm(
                    prompt,
                    temperature=0.1,
                    max_tokens=5000,
                    use_codegen=True,
                    node_name=f"read_node_codegen_attempt_{attempt + 1}",
                )

                # The codegen node is asked to return JSON {"code": "..."},
                # but in practice it often replies with a bare ```python
                # block.  Try JSON first, then fall back to treating the
                # whole reply as raw code.
                parsed = _parse_json_response(response_text)
                parsed_code = parsed.get("code", "")
                if not parsed_code:
                    text = response_text.strip()
                    if text.startswith("```"):
                        lines = text.splitlines()
                        if lines and lines[0].lstrip().startswith("```"):
                            lines = lines[1:]  # opening ``` / ```python
                        if lines and lines[-1].lstrip().startswith("```"):
                            lines = lines[:-1]  # closing ```
                        parsed_code = "\n".join(lines)
                    else:
                        parsed_code = text

                code = (parsed_code or "").strip()
                # Unwrap a top-level if __name__ == "__main__": guard so the
                # code actually runs inside the REPL.
                code = _strip_if_main_block(code)
                if not code:
                    result = _get_localized_message(
                        {
                            "zh": "[错误] 代码生成失败",
                            "en": "[Error] Code generation failed",
                        },
                        language,
                    )
                    break

                execution_result = repl.run(code)
                stdout = getattr(execution_result, "stdout", "") or ""
                stderr = getattr(execution_result, "stderr", "") or ""

                if stderr.strip():
                    logger.warning(
                        "代码执行 stderr 非空(attempt=%s): %s",
                        attempt + 1,
                        stderr[:500],
                    )
                    if is_last_attempt:
                        result = _get_localized_message(
                            {
                                "zh": f"[错误] 代码执行失败: {stderr}",
                                "en": f"[Error] Code execution failed: {stderr}",
                            },
                            language,
                        )
                else:
                    result = stdout or _get_localized_message(
                        {
                            "zh": "[成功] 代码执行完成,但无输出",
                            "en": "[Success] Code execution completed, but no output",
                        },
                        language,
                    )
                    execution_success = True
                    break
            except Exception as exc:  # noqa: BLE001
                logger.error("代码生成或执行异常: %s", exc, exc_info=True)
                if is_last_attempt:
                    result = _get_localized_message(
                        {
                            "zh": f"[错误] 代码生成或执行异常: {exc}",
                            "en": f"[Error] Code generation or execution exception: {exc}",
                        },
                        language,
                    )

        if not execution_success and not result:
            result = _get_localized_message(
                {
                    "zh": "[错误] 代码执行失败,已重试3次",
                    "en": "[Error] Code execution failed, retried 3 times",
                },
                language,
            )

    else:
        # Text file: summarize the externally provided preview (no disk reads).
        preview = file_info.get("preview", "")
        if not preview:
            result = _get_localized_message(
                {
                    "zh": "[错误] 该文本文件未提供预览内容,无法分析",
                    "en": "[Error] No preview content provided for this text file",
                },
                language,
            )
        else:
            summary_prompt = format_text_summary_prompt(
                instruction=plan_detail,
                file_content=preview,
                previous_results=previous_results_str,
                language=language,
            )
            result = run_llm(
                summary_prompt,
                temperature=0.1,
                max_tokens=4000,
                node_name="read_node_text_summary",
            )

    history_plans[current_plan_index]["result"] = result

    logger.info(
        "Read node completed",
        extra={
            "file_name": file_name,
            "file_path": file_path,
        },
    )

    return {
        "history_plans": history_plans,
        "current_plan_index": current_plan_index + 1,
    }
596
+
597
+
598
def answer_node(state: ReadAgentState):
    """Aggregate all plan results and produce the final answer.

    Asks the LLM for a JSON reply containing "final_answer"; when that is
    missing or unparseable, falls back to a plain per-file status summary.
    """
    user_query = state["user_query"]
    history_plans = state.get("history_plans", [])
    language = state.get("language", "en")

    def _looks_like_error(text) -> bool:
        # Results produced by read_node prefix failures with "[错误]"/"[Error]".
        return bool(text and (text.startswith("[错误]") or text.startswith("[Error]")))

    execution_results = [
        {
            "file_name": plan.get("file_name", ""),
            "file_path": plan.get("file_path", ""),
            "plan_detail": plan.get("plan_detail", ""),
            "result": plan.get("result", ""),
            "success": not _looks_like_error(plan.get("result", "")),
        }
        for plan in history_plans
    ]

    response_text = run_llm(
        format_answer_prompt(
            user_query=user_query,
            execution_results=execution_results,
            language=language,
        ),
        temperature=0.1,
        max_tokens=8000,
        node_name="answer_node",
    )
    final_answer = _parse_json_response(response_text).get("final_answer", "")

    if not final_answer:
        # Fallback: a simple per-file success/failure listing.
        header = _get_localized_message(
            {"zh": "执行完成。\n\n", "en": "Execution completed.\n\n"},
            language,
        )
        label_ok = _get_localized_message({"zh": "成功", "en": "Success"}, language)
        label_fail = _get_localized_message({"zh": "失败", "en": "Failed"}, language)
        file_prefix = _get_localized_message(
            {"zh": "- 文件", "en": "- File"},
            language,
        )

        parts = [header]
        for plan in history_plans:
            result = plan.get("result", "")
            ok = bool(result) and not _looks_like_error(result)
            status = label_ok if ok else label_fail
            parts.append(f"{file_prefix}: {plan.get('file_name', '')} - {status}\n")
        final_answer = "".join(parts)

    return {
        "final_answer": final_answer,
    }
673
+
674
+