gpbench-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gp_agent_tool/compute_dataset_feature.py +67 -0
- gp_agent_tool/config.py +65 -0
- gp_agent_tool/experience/create_masked_dataset_summary.py +97 -0
- gp_agent_tool/experience/dataset_summary_info.py +13 -0
- gp_agent_tool/experience/experience_info.py +12 -0
- gp_agent_tool/experience/get_matched_experience.py +111 -0
- gp_agent_tool/llm_client.py +119 -0
- gp_agent_tool/logging_utils.py +24 -0
- gp_agent_tool/main.py +347 -0
- gp_agent_tool/read_agent/__init__.py +46 -0
- gp_agent_tool/read_agent/nodes.py +674 -0
- gp_agent_tool/read_agent/prompts.py +547 -0
- gp_agent_tool/read_agent/python_repl_tool.py +165 -0
- gp_agent_tool/read_agent/state.py +101 -0
- gp_agent_tool/read_agent/workflow.py +54 -0
- gpbench/__init__.py +25 -0
- gpbench/_selftest.py +104 -0
- gpbench/method_class/BayesA/BayesA_class.py +141 -0
- gpbench/method_class/BayesA/__init__.py +5 -0
- gpbench/method_class/BayesA/_bayesfromR.py +96 -0
- gpbench/method_class/BayesA/_param_free_base_model.py +84 -0
- gpbench/method_class/BayesA/bayesAfromR.py +16 -0
- gpbench/method_class/BayesB/BayesB_class.py +140 -0
- gpbench/method_class/BayesB/__init__.py +5 -0
- gpbench/method_class/BayesB/_bayesfromR.py +96 -0
- gpbench/method_class/BayesB/_param_free_base_model.py +84 -0
- gpbench/method_class/BayesB/bayesBfromR.py +16 -0
- gpbench/method_class/BayesC/BayesC_class.py +141 -0
- gpbench/method_class/BayesC/__init__.py +4 -0
- gpbench/method_class/BayesC/_bayesfromR.py +96 -0
- gpbench/method_class/BayesC/_param_free_base_model.py +84 -0
- gpbench/method_class/BayesC/bayesCfromR.py +16 -0
- gpbench/method_class/CropARNet/CropARNet_class.py +186 -0
- gpbench/method_class/CropARNet/CropARNet_he_class.py +154 -0
- gpbench/method_class/CropARNet/__init__.py +5 -0
- gpbench/method_class/CropARNet/base_CropARNet_class.py +178 -0
- gpbench/method_class/Cropformer/Cropformer_class.py +308 -0
- gpbench/method_class/Cropformer/__init__.py +5 -0
- gpbench/method_class/Cropformer/cropformer_he_class.py +221 -0
- gpbench/method_class/DL_GWAS/DL_GWAS_class.py +250 -0
- gpbench/method_class/DL_GWAS/DL_GWAS_he_class.py +169 -0
- gpbench/method_class/DL_GWAS/__init__.py +5 -0
- gpbench/method_class/DNNGP/DNNGP_class.py +163 -0
- gpbench/method_class/DNNGP/DNNGP_he_class.py +138 -0
- gpbench/method_class/DNNGP/__init__.py +5 -0
- gpbench/method_class/DNNGP/base_dnngp_class.py +116 -0
- gpbench/method_class/DeepCCR/DeepCCR_class.py +172 -0
- gpbench/method_class/DeepCCR/DeepCCR_he_class.py +161 -0
- gpbench/method_class/DeepCCR/__init__.py +5 -0
- gpbench/method_class/DeepCCR/base_DeepCCR_class.py +209 -0
- gpbench/method_class/DeepGS/DeepGS_class.py +184 -0
- gpbench/method_class/DeepGS/DeepGS_he_class.py +150 -0
- gpbench/method_class/DeepGS/__init__.py +5 -0
- gpbench/method_class/DeepGS/base_deepgs_class.py +153 -0
- gpbench/method_class/EIR/EIR_class.py +276 -0
- gpbench/method_class/EIR/EIR_he_class.py +184 -0
- gpbench/method_class/EIR/__init__.py +5 -0
- gpbench/method_class/EIR/utils/__init__.py +0 -0
- gpbench/method_class/EIR/utils/array_output_modules.py +97 -0
- gpbench/method_class/EIR/utils/common.py +65 -0
- gpbench/method_class/EIR/utils/lcl_layers.py +235 -0
- gpbench/method_class/EIR/utils/logging.py +59 -0
- gpbench/method_class/EIR/utils/mlp_layers.py +92 -0
- gpbench/method_class/EIR/utils/models_locally_connected.py +642 -0
- gpbench/method_class/EIR/utils/transformer_models.py +546 -0
- gpbench/method_class/ElasticNet/ElasticNet_class.py +133 -0
- gpbench/method_class/ElasticNet/ElasticNet_he_class.py +91 -0
- gpbench/method_class/ElasticNet/__init__.py +5 -0
- gpbench/method_class/G2PDeep/G2PDeep_he_class.py +217 -0
- gpbench/method_class/G2PDeep/G2Pdeep_class.py +205 -0
- gpbench/method_class/G2PDeep/__init__.py +5 -0
- gpbench/method_class/G2PDeep/base_G2PDeep_class.py +209 -0
- gpbench/method_class/GBLUP/GBLUP_class.py +183 -0
- gpbench/method_class/GBLUP/__init__.py +5 -0
- gpbench/method_class/GEFormer/GEFormer_class.py +169 -0
- gpbench/method_class/GEFormer/GEFormer_he_class.py +137 -0
- gpbench/method_class/GEFormer/__init__.py +5 -0
- gpbench/method_class/GEFormer/gMLP_class.py +357 -0
- gpbench/method_class/LightGBM/LightGBM_class.py +224 -0
- gpbench/method_class/LightGBM/LightGBM_he_class.py +121 -0
- gpbench/method_class/LightGBM/__init__.py +5 -0
- gpbench/method_class/RF/RF_GPU_class.py +165 -0
- gpbench/method_class/RF/RF_GPU_he_class.py +124 -0
- gpbench/method_class/RF/__init__.py +5 -0
- gpbench/method_class/SVC/SVC_GPU.py +181 -0
- gpbench/method_class/SVC/SVC_GPU_he.py +106 -0
- gpbench/method_class/SVC/__init__.py +5 -0
- gpbench/method_class/SoyDNGP/AlexNet_206_class.py +179 -0
- gpbench/method_class/SoyDNGP/SoyDNGP_class.py +189 -0
- gpbench/method_class/SoyDNGP/SoyDNGP_he_class.py +112 -0
- gpbench/method_class/SoyDNGP/__init__.py +5 -0
- gpbench/method_class/XGBoost/XGboost_GPU_class.py +198 -0
- gpbench/method_class/XGBoost/XGboost_GPU_he_class.py +178 -0
- gpbench/method_class/XGBoost/__init__.py +5 -0
- gpbench/method_class/__init__.py +52 -0
- gpbench/method_class/rrBLUP/__init__.py +5 -0
- gpbench/method_class/rrBLUP/rrBLUP_class.py +140 -0
- gpbench/method_reg/BayesA/BayesA.py +116 -0
- gpbench/method_reg/BayesA/__init__.py +5 -0
- gpbench/method_reg/BayesA/_bayesfromR.py +96 -0
- gpbench/method_reg/BayesA/_param_free_base_model.py +84 -0
- gpbench/method_reg/BayesA/bayesAfromR.py +16 -0
- gpbench/method_reg/BayesB/BayesB.py +117 -0
- gpbench/method_reg/BayesB/__init__.py +5 -0
- gpbench/method_reg/BayesB/_bayesfromR.py +96 -0
- gpbench/method_reg/BayesB/_param_free_base_model.py +84 -0
- gpbench/method_reg/BayesB/bayesBfromR.py +16 -0
- gpbench/method_reg/BayesC/BayesC.py +115 -0
- gpbench/method_reg/BayesC/__init__.py +5 -0
- gpbench/method_reg/BayesC/_bayesfromR.py +96 -0
- gpbench/method_reg/BayesC/_param_free_base_model.py +84 -0
- gpbench/method_reg/BayesC/bayesCfromR.py +16 -0
- gpbench/method_reg/CropARNet/CropARNet.py +159 -0
- gpbench/method_reg/CropARNet/CropARNet_Hyperparameters.py +109 -0
- gpbench/method_reg/CropARNet/__init__.py +5 -0
- gpbench/method_reg/CropARNet/base_CropARNet.py +137 -0
- gpbench/method_reg/Cropformer/Cropformer.py +313 -0
- gpbench/method_reg/Cropformer/Cropformer_Hyperparameters.py +250 -0
- gpbench/method_reg/Cropformer/__init__.py +5 -0
- gpbench/method_reg/DL_GWAS/DL_GWAS.py +186 -0
- gpbench/method_reg/DL_GWAS/DL_GWAS_Hyperparameters.py +125 -0
- gpbench/method_reg/DL_GWAS/__init__.py +5 -0
- gpbench/method_reg/DNNGP/DNNGP.py +157 -0
- gpbench/method_reg/DNNGP/DNNGP_Hyperparameters.py +118 -0
- gpbench/method_reg/DNNGP/__init__.py +5 -0
- gpbench/method_reg/DNNGP/base_dnngp.py +101 -0
- gpbench/method_reg/DeepCCR/DeepCCR.py +149 -0
- gpbench/method_reg/DeepCCR/DeepCCR_Hyperparameters.py +110 -0
- gpbench/method_reg/DeepCCR/__init__.py +5 -0
- gpbench/method_reg/DeepCCR/base_DeepCCR.py +171 -0
- gpbench/method_reg/DeepGS/DeepGS.py +165 -0
- gpbench/method_reg/DeepGS/DeepGS_Hyperparameters.py +114 -0
- gpbench/method_reg/DeepGS/__init__.py +5 -0
- gpbench/method_reg/DeepGS/base_deepgs.py +98 -0
- gpbench/method_reg/EIR/EIR.py +258 -0
- gpbench/method_reg/EIR/EIR_Hyperparameters.py +178 -0
- gpbench/method_reg/EIR/__init__.py +5 -0
- gpbench/method_reg/EIR/utils/__init__.py +0 -0
- gpbench/method_reg/EIR/utils/array_output_modules.py +97 -0
- gpbench/method_reg/EIR/utils/common.py +65 -0
- gpbench/method_reg/EIR/utils/lcl_layers.py +235 -0
- gpbench/method_reg/EIR/utils/logging.py +59 -0
- gpbench/method_reg/EIR/utils/mlp_layers.py +92 -0
- gpbench/method_reg/EIR/utils/models_locally_connected.py +642 -0
- gpbench/method_reg/EIR/utils/transformer_models.py +546 -0
- gpbench/method_reg/ElasticNet/ElasticNet.py +123 -0
- gpbench/method_reg/ElasticNet/ElasticNet_he.py +83 -0
- gpbench/method_reg/ElasticNet/__init__.py +5 -0
- gpbench/method_reg/G2PDeep/G2PDeep_Hyperparameters.py +107 -0
- gpbench/method_reg/G2PDeep/G2Pdeep.py +166 -0
- gpbench/method_reg/G2PDeep/__init__.py +5 -0
- gpbench/method_reg/G2PDeep/base_G2PDeep.py +209 -0
- gpbench/method_reg/GBLUP/GBLUP_R.py +182 -0
- gpbench/method_reg/GBLUP/__init__.py +5 -0
- gpbench/method_reg/GEFormer/GEFormer.py +164 -0
- gpbench/method_reg/GEFormer/GEFormer_Hyperparameters.py +106 -0
- gpbench/method_reg/GEFormer/__init__.py +5 -0
- gpbench/method_reg/GEFormer/gMLP.py +341 -0
- gpbench/method_reg/LightGBM/LightGBM.py +237 -0
- gpbench/method_reg/LightGBM/LightGBM_Hyperparameters.py +77 -0
- gpbench/method_reg/LightGBM/__init__.py +5 -0
- gpbench/method_reg/MVP/MVP.py +182 -0
- gpbench/method_reg/MVP/MVP_Hyperparameters.py +126 -0
- gpbench/method_reg/MVP/__init__.py +5 -0
- gpbench/method_reg/MVP/base_MVP.py +113 -0
- gpbench/method_reg/RF/RF_GPU.py +174 -0
- gpbench/method_reg/RF/RF_Hyperparameters.py +163 -0
- gpbench/method_reg/RF/__init__.py +5 -0
- gpbench/method_reg/SVC/SVC_GPU.py +194 -0
- gpbench/method_reg/SVC/SVC_Hyperparameters.py +107 -0
- gpbench/method_reg/SVC/__init__.py +5 -0
- gpbench/method_reg/SoyDNGP/AlexNet_206.py +185 -0
- gpbench/method_reg/SoyDNGP/SoyDNGP.py +179 -0
- gpbench/method_reg/SoyDNGP/SoyDNGP_Hyperparameters.py +105 -0
- gpbench/method_reg/SoyDNGP/__init__.py +5 -0
- gpbench/method_reg/XGBoost/XGboost_GPU.py +188 -0
- gpbench/method_reg/XGBoost/XGboost_Hyperparameters.py +167 -0
- gpbench/method_reg/XGBoost/__init__.py +5 -0
- gpbench/method_reg/__init__.py +55 -0
- gpbench/method_reg/rrBLUP/__init__.py +5 -0
- gpbench/method_reg/rrBLUP/rrBLUP.py +123 -0
- gpbench-1.0.0.dist-info/METADATA +379 -0
- gpbench-1.0.0.dist-info/RECORD +188 -0
- gpbench-1.0.0.dist-info/WHEEL +5 -0
- gpbench-1.0.0.dist-info/entry_points.txt +2 -0
- gpbench-1.0.0.dist-info/top_level.txt +3 -0
- tests/test_import.py +80 -0
- tests/test_method.py +232 -0
gp_agent_tool/main.py
ADDED
@@ -0,0 +1,347 @@
import argparse
import json
from pathlib import Path
from typing import Callable, Iterable, Optional, TypedDict

from experience.dataset_summary_info import dataset_summary_info
from experience.experience_info import experience_info
from experience.create_masked_dataset_summary import create_masked_dataset_summary
from compute_dataset_feature import process_one_phenotype
from experience.get_matched_experience import get_matched_experience
from llm_client import run_llm
from read_agent import run_read_agent


def _detect_language(text: str) -> str:
    """
    Detect the language from the share of English characters in the input text.
    Returns 'en' if English letters make up >= 50% of all letters, otherwise 'zh'.
    """
    if not text:
        return "zh"  # default to Chinese

    total_chars = 0
    english_chars = 0

    for char in text:
        if char.isalpha():  # count only alphabetic characters
            total_chars += 1
            if char.isascii() and char.isalpha():  # English letters
                english_chars += 1

    if total_chars == 0:
        return "zh"  # no alphabetic characters; default to Chinese

    english_ratio = english_chars / total_chars
    return "en" if english_ratio >= 0.5 else "zh"


def _build_similarity_prompt(query_dataset_summary: dict, language: str = "zh") -> str:
    """Build the similar-dataset query prompt."""
    if language == "en":
        return (
            "Based on the statistical information of the following dataset, "
            "find the datasets with the most similar distribution to this dataset, "
            "and provide detailed reasons.\n"
            "Please clearly list the names of these similar datasets in your answer, "
            "and each name must be in the format species/phenotype_name, "
            "for example human/bmi, mouse/height, etc."
            f"\nStatistical information: {query_dataset_summary}"
        )
    else:
        return (
            "根据以下数据集的统计信息,找出与该数据集分布最相似的几个数据集,并给出详细原因。\n"
            "请在回答中明确列出这些相似数据集的名称,且每个名称的格式必须为 species/phenotype_name,"
            "例如 human/bmi、mouse/height 等。"
            f"\n统计信息:{query_dataset_summary}"
        )


def _build_method_prompt(user_query: str, language: str = "zh") -> str:
    """Build the method-recommendation prompt."""
    if language == "en":
        return (
            "Based on user requirements and experimental performance of similar datasets, "
            "recommend suitable algorithms and provide detailed reasoning.\n"
            f"\nUser requirements: {user_query}"
        )
    else:
        return (
            "根据用户需求与相似数据集的实验表现,推荐适合的算法,并给出详细的推荐理由。\n"
            f"\n用户需求:{user_query}"
        )


def _build_file_overview(info: dict) -> str:
    """Format file metadata into an overview string for the read agent."""
    name = info.get("file_name", "")
    desc = info.get("description", "")
    path = info.get("file_path", "")
    preview = info.get("preview", "")
    return f"文件名: {name}\n描述: {desc}\n路径: {path}\n预览:\n{preview}"


def _call_read_agent(prompt: str, info: dict, language: str = "zh") -> str:
    """
    Invoke the local read_agent and return its final_answer text.
    Relies only on a single file's metadata, avoiding extra external arguments.
    """
    files = [
        {
            "file_name": info.get("file_name", ""),
            "file_path": info.get("file_path", ""),
            "description": info.get("description", ""),
            "preview": info.get("preview", ""),
        }
    ]
    overview = _build_file_overview(info)
    state = run_read_agent(
        user_query=prompt,
        files=files,
        file_overview=overview,
        language=language,
    )
    return state.get("final_answer", "") or ""


def _normalize_list_with_reason(
    task_prompt: str,
    raw_answer: str,
    *,
    language: str = "zh",
    node_name: str = "normalize_list_with_reason",
) -> Optional[dict]:
    """
    Use a second LLM call to normalize free text into an
    {"items": [...], "reason": "..."} structure; the reply may be raw JSON
    or a markdown JSON code block.
    """
    if language == "en":
        normalize_prompt = (
            "You are an assistant responsible for result normalization. "
            "Now there is an upstream LLM's answer that needs to be refined into structured JSON.\n\n"
            "[Task Description]\n"
            f"{task_prompt}\n\n"
            "[Upstream Answer]\n"
            f"{raw_answer}\n\n"
            "Please summarize from the upstream answer:\n"
            "1. A string list items, giving the recommended items in order (such as dataset names or method names);\n"
            "2. A string reason, briefly explaining the overall rationale.\n\n"
            "Return format requirement: Strictly output a JSON object, in the form:\n"
            '{"items": ["item1", "item2"], "reason": "overall rationale"}\n'
            "You can directly output JSON, or wrap JSON in a markdown code block ```json; do not output any additional text."
        )
    else:
        normalize_prompt = (
            "你是一个负责结果规范化的助手。现在有一个上游 LLM 的中文回答,需要你将其提炼为结构化 JSON。\n\n"
            "【任务描述】\n"
            f"{task_prompt}\n\n"
            "【上游回答】\n"
            f"{raw_answer}\n\n"
            "请根据上游回答,总结出:\n"
            "1. 一个字符串列表 items,依次给出推荐的项目(如数据集名称或方法名称);\n"
            "2. 一个字符串 reason,简要说明总体理由。\n\n"
            "返回格式要求:严格输出一个 JSON 对象,形如:\n"
            '{"items": ["item1", "item2"], "reason": "总体理由"}\n'
            "你可以直接输出 JSON,或用 markdown 代码块 ```json 包裹 JSON;不要输出任何额外文字。"
        )

    try:
        norm_text = run_llm(
            prompt=normalize_prompt,
            temperature=0.1,
            max_tokens=512,
            use_codegen=True,
            node_name=node_name,
        )
    except Exception:
        return None

    text = norm_text.strip()
    if text.startswith("```"):
        # Strip a surrounding markdown code fence, if present.
        lines = text.splitlines()
        if lines and lines[0].startswith("```"):
            lines = lines[1:]
        if lines and lines[-1].startswith("```"):
            lines = lines[:-1]
        text = "\n".join(lines).strip()

    try:
        data = json.loads(text)
    except Exception:
        return None

    if not isinstance(data, dict):
        return None

    items = data.get("items", [])
    reason = data.get("reason", "")
    if not isinstance(items, list):
        return None

    parsed_items = [str(x).strip() for x in items if str(x).strip()]
    return {"items": parsed_items, "reason": str(reason).strip()}


class _ParsedList(TypedDict):
    items: list[str]
    reason: str


def _read_agent_list(
    prompt: str,
    info: dict,
    fallback_parser: Optional[Callable[[str], Iterable[str]]] = None,
    language: str = "zh",
) -> _ParsedList:
    """
    Call the read_agent and parse its answer into a JSON list plus an overall
    reason; if normalization fails, fall back to simple splitting with an
    empty reason.
    """
    answer = _call_read_agent(prompt, info, language)

    # First normalize the free text into a JSON structure via a second LLM call.
    normalized = _normalize_list_with_reason(prompt, answer, language=language)
    if normalized and normalized.get("items"):
        return normalized  # type: ignore[return-value]

    # If normalization failed or items is empty, fall back to simple parsing;
    # default to line splitting when no fallback_parser is supplied.
    parser = fallback_parser or (
        lambda text: [line.strip() for line in text.splitlines() if line.strip()]
    )
    return {"items": list(parser(answer)), "reason": ""}


def get_recommend_method(
    query_dataset_path: Optional[str],
    user_query: str,
    masked_dataset_name: Optional[str] = None,
) -> dict:
    """
    Recommend suitable analysis methods, with reasons, based on the user query
    and the dataset path.

    Parameters
    ----------
    query_dataset_path : str | None
        Target dataset directory containing genetype.npz / phenotype.npz.
        If None, skip the similar-dataset lookup and recommend methods
        directly from the complete experience table.
    user_query : str
        Description of the user's analysis requirements.
    masked_dataset_name : str, optional
        If provided, this species_phenotype is filtered out of the reference library.

    Returns
    -------
    dict
        {
            "similar_datasets": {"items": [...], "reason": "..."},
            "methods": {"items": [...], "reason": "..."},
        }
    """
    # Detect the language up front.
    detected_language = _detect_language(user_query)

    # Without a dataset path, recommend methods from the complete experience table.
    if not query_dataset_path:
        matched_experience_info = experience_info.copy()
        method_prompt = _build_method_prompt(user_query, detected_language)
        method_result = _read_agent_list(method_prompt, matched_experience_info, language=detected_language)
        reason_msg = (
            "Dataset path not provided, recommending methods based on complete experience table only."
            if detected_language == "en"
            else "未提供数据集路径,仅基于完整经验表推荐方法。"
        )
        return {
            "similar_datasets": {
                "items": [],
                "reason": reason_msg,
            },
            "methods": method_result,
        }

    # 1) Compute summary statistics for the query dataset.
    query_dataset_summary = process_one_phenotype(query_dataset_path)

    # 2) Prepare the reference datasets (optionally masking a given dataset).
    if masked_dataset_name:
        ref_dataset_summary_path, ref_dataset_summary_preview = create_masked_dataset_summary(
            [masked_dataset_name]
        )
    else:
        ref_dataset_summary_path = dataset_summary_info["file_path"]
        ref_dataset_summary_preview = dataset_summary_info["preview"]

    ref_dataset_summary_info = dataset_summary_info.copy()
    ref_dataset_summary_info["file_path"] = ref_dataset_summary_path
    ref_dataset_summary_info["preview"] = ref_dataset_summary_preview

    # 3) Ask the read agent for the names of similar datasets.
    similarity_prompt = _build_similarity_prompt(query_dataset_summary, detected_language)
    similar_result = _read_agent_list(similarity_prompt, ref_dataset_summary_info, language=detected_language)
    similar_dataset_names = similar_result["items"]
    if not similar_dataset_names:
        return {
            "similar_datasets": {"items": [], "reason": similar_result.get("reason", "")},
            "methods": {"items": [], "reason": ""},
        }

    # 4) Filter the experience table down to matching entries.
    matched_experience_path, matched_experience_preview = get_matched_experience(
        similar_dataset_names, experience_info["file_path"]
    )
    matched_experience_info = experience_info.copy()
    matched_experience_info["file_path"] = matched_experience_path
    matched_experience_info["preview"] = matched_experience_preview

    # 5) Ask the LLM to recommend methods.
    method_prompt = _build_method_prompt(user_query, detected_language)
    method_result = _read_agent_list(method_prompt, matched_experience_info, language=detected_language)

    return {"similar_datasets": similar_result, "methods": method_result}


def _parse_args() -> argparse.Namespace:
    """Parse the CLI arguments."""
    parser = argparse.ArgumentParser(
        description="根据数据集统计信息与用户需求推荐分析方法",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-d",
        "--dataset",
        help="可选:待分析数据集目录,包含 genetype.npz 和 phenotype.npz;不提供则仅基于经验表推荐方法",
    )
    parser.add_argument(
        "-q",
        "--user-query",
        required=True,
        help="用户对分析需求的描述",
    )
    parser.add_argument(
        "-m",
        "--mask",
        dest="masked_dataset_name",
        help="可选:需要在参考库中屏蔽的 species_phenotype",
    )
    parser.add_argument(
        "-o",
        "--output",
        help="可选:将结果保存到 JSON 文件路径;不提供则直接打印",
    )
    return parser.parse_args()


def main() -> None:
    args = _parse_args()
    result = get_recommend_method(
        query_dataset_path=args.dataset,
        user_query=args.user_query,
        masked_dataset_name=args.masked_dataset_name,
    )

    if args.output:
        path = Path(args.output)
        path.write_text(json.dumps(result, ensure_ascii=False, indent=2))
        print(f"result saved to: {path}")
    else:
        print(json.dumps(result, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()
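For quick orientation, here is a minimal usage sketch of the module above. It is an illustration only: the dataset directory and query are hypothetical, and the import assumes the `gp_agent_tool` directory is on `sys.path`, matching the top-level sibling imports (`from llm_client import run_llm`, etc.) that `main.py` itself uses.

```python
# Usage sketch (hypothetical paths and query; not part of the package).
# CLI equivalent: python main.py -q "..." -d /data/maize/plant_height -o result.json
import sys

sys.path.insert(0, "gp_agent_tool")  # assumption: run from the repository root

from main import get_recommend_method  # main.py uses top-level sibling imports

result = get_recommend_method(
    query_dataset_path="/data/maize/plant_height",  # hypothetical dir with genetype.npz / phenotype.npz
    user_query="Recommend accurate yet fast genomic prediction methods.",
)

# Per the docstring, both keys hold {"items": [...], "reason": "..."}.
print(result["similar_datasets"]["items"])
print(result["methods"]["items"])
```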
gp_agent_tool/read_agent/__init__.py
ADDED
@@ -0,0 +1,46 @@
from typing import Optional

from .state import FileInfo, build_initial_state
from .workflow import compile_read_agent_workflow


def run_read_agent(
    user_query: str,
    files: list[FileInfo],
    file_overview: str,
    language: str = "zh",
    user_id: Optional[str] = None,
    project_id: Optional[str] = None,
) -> dict:
    """
    Run the Read Agent and return its final state.

    Args:
        user_query: The user's question or task description.
        files: List of file-info entries supplied by the caller; each element is a FileInfo:
            {
                "file_name": str,
                "file_path": str,
                "description": str,
                "preview": str,
            }
        file_overview: A file-overview string already built by the caller, used in the planning prompt.
        language: Language code, defaults to "zh".
        user_id: Optional user ID.
        project_id: Optional project ID.
    Returns:
        The final state dict, containing at least a "final_answer" field.
    """
    initial_state = build_initial_state(
        user_query=user_query,
        files=files,
        file_overview=file_overview,
        user_id=user_id,
        project_id=project_id,
        language=language,
    )
    app = compile_read_agent_workflow()
    final_state = app.invoke(initial_state)
    return final_state
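Similarly, a hypothetical sketch of calling `run_read_agent` directly. The file metadata below is invented for illustration; the dict shape follows the FileInfo structure documented in the docstring, and the overview string mirrors the "文件名/描述/路径/预览" format produced by `_build_file_overview` in `main.py`.

```python
# Illustrative call only - the file metadata below is hypothetical.
import sys

sys.path.insert(0, "gp_agent_tool")  # assumption: gp_agent_tool/ on sys.path, as in main.py

from read_agent import run_read_agent

files = [
    {
        "file_name": "experience.csv",        # hypothetical reference table
        "file_path": "/data/experience.csv",  # hypothetical path
        "description": "Benchmark results per dataset and method.",
        "preview": "dataset,method,pearson\nhuman/bmi,GBLUP,0.41",
    }
]
# Overview string in the same format as _build_file_overview.
overview = (
    "文件名: experience.csv\n描述: Benchmark results per dataset and method.\n"
    "路径: /data/experience.csv\n预览:\ndataset,method,pearson\nhuman/bmi,GBLUP,0.41"
)

state = run_read_agent(
    user_query="Which methods perform best on datasets similar to mine?",
    files=files,
    file_overview=overview,
    language="en",
)
print(state.get("final_answer", ""))
```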