gpbench 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gp_agent_tool/compute_dataset_feature.py +67 -0
- gp_agent_tool/config.py +65 -0
- gp_agent_tool/experience/create_masked_dataset_summary.py +97 -0
- gp_agent_tool/experience/dataset_summary_info.py +13 -0
- gp_agent_tool/experience/experience_info.py +12 -0
- gp_agent_tool/experience/get_matched_experience.py +111 -0
- gp_agent_tool/llm_client.py +119 -0
- gp_agent_tool/logging_utils.py +24 -0
- gp_agent_tool/main.py +347 -0
- gp_agent_tool/read_agent/__init__.py +46 -0
- gp_agent_tool/read_agent/nodes.py +674 -0
- gp_agent_tool/read_agent/prompts.py +547 -0
- gp_agent_tool/read_agent/python_repl_tool.py +165 -0
- gp_agent_tool/read_agent/state.py +101 -0
- gp_agent_tool/read_agent/workflow.py +54 -0
- gpbench/__init__.py +25 -0
- gpbench/_selftest.py +104 -0
- gpbench/method_class/BayesA/BayesA_class.py +141 -0
- gpbench/method_class/BayesA/__init__.py +5 -0
- gpbench/method_class/BayesA/_bayesfromR.py +96 -0
- gpbench/method_class/BayesA/_param_free_base_model.py +84 -0
- gpbench/method_class/BayesA/bayesAfromR.py +16 -0
- gpbench/method_class/BayesB/BayesB_class.py +140 -0
- gpbench/method_class/BayesB/__init__.py +5 -0
- gpbench/method_class/BayesB/_bayesfromR.py +96 -0
- gpbench/method_class/BayesB/_param_free_base_model.py +84 -0
- gpbench/method_class/BayesB/bayesBfromR.py +16 -0
- gpbench/method_class/BayesC/BayesC_class.py +141 -0
- gpbench/method_class/BayesC/__init__.py +4 -0
- gpbench/method_class/BayesC/_bayesfromR.py +96 -0
- gpbench/method_class/BayesC/_param_free_base_model.py +84 -0
- gpbench/method_class/BayesC/bayesCfromR.py +16 -0
- gpbench/method_class/CropARNet/CropARNet_class.py +186 -0
- gpbench/method_class/CropARNet/CropARNet_he_class.py +154 -0
- gpbench/method_class/CropARNet/__init__.py +5 -0
- gpbench/method_class/CropARNet/base_CropARNet_class.py +178 -0
- gpbench/method_class/Cropformer/Cropformer_class.py +308 -0
- gpbench/method_class/Cropformer/__init__.py +5 -0
- gpbench/method_class/Cropformer/cropformer_he_class.py +221 -0
- gpbench/method_class/DL_GWAS/DL_GWAS_class.py +250 -0
- gpbench/method_class/DL_GWAS/DL_GWAS_he_class.py +169 -0
- gpbench/method_class/DL_GWAS/__init__.py +5 -0
- gpbench/method_class/DNNGP/DNNGP_class.py +163 -0
- gpbench/method_class/DNNGP/DNNGP_he_class.py +138 -0
- gpbench/method_class/DNNGP/__init__.py +5 -0
- gpbench/method_class/DNNGP/base_dnngp_class.py +116 -0
- gpbench/method_class/DeepCCR/DeepCCR_class.py +172 -0
- gpbench/method_class/DeepCCR/DeepCCR_he_class.py +161 -0
- gpbench/method_class/DeepCCR/__init__.py +5 -0
- gpbench/method_class/DeepCCR/base_DeepCCR_class.py +209 -0
- gpbench/method_class/DeepGS/DeepGS_class.py +184 -0
- gpbench/method_class/DeepGS/DeepGS_he_class.py +150 -0
- gpbench/method_class/DeepGS/__init__.py +5 -0
- gpbench/method_class/DeepGS/base_deepgs_class.py +153 -0
- gpbench/method_class/EIR/EIR_class.py +276 -0
- gpbench/method_class/EIR/EIR_he_class.py +184 -0
- gpbench/method_class/EIR/__init__.py +5 -0
- gpbench/method_class/EIR/utils/__init__.py +0 -0
- gpbench/method_class/EIR/utils/array_output_modules.py +97 -0
- gpbench/method_class/EIR/utils/common.py +65 -0
- gpbench/method_class/EIR/utils/lcl_layers.py +235 -0
- gpbench/method_class/EIR/utils/logging.py +59 -0
- gpbench/method_class/EIR/utils/mlp_layers.py +92 -0
- gpbench/method_class/EIR/utils/models_locally_connected.py +642 -0
- gpbench/method_class/EIR/utils/transformer_models.py +546 -0
- gpbench/method_class/ElasticNet/ElasticNet_class.py +133 -0
- gpbench/method_class/ElasticNet/ElasticNet_he_class.py +91 -0
- gpbench/method_class/ElasticNet/__init__.py +5 -0
- gpbench/method_class/G2PDeep/G2PDeep_he_class.py +217 -0
- gpbench/method_class/G2PDeep/G2Pdeep_class.py +205 -0
- gpbench/method_class/G2PDeep/__init__.py +5 -0
- gpbench/method_class/G2PDeep/base_G2PDeep_class.py +209 -0
- gpbench/method_class/GBLUP/GBLUP_class.py +183 -0
- gpbench/method_class/GBLUP/__init__.py +5 -0
- gpbench/method_class/GEFormer/GEFormer_class.py +169 -0
- gpbench/method_class/GEFormer/GEFormer_he_class.py +137 -0
- gpbench/method_class/GEFormer/__init__.py +5 -0
- gpbench/method_class/GEFormer/gMLP_class.py +357 -0
- gpbench/method_class/LightGBM/LightGBM_class.py +224 -0
- gpbench/method_class/LightGBM/LightGBM_he_class.py +121 -0
- gpbench/method_class/LightGBM/__init__.py +5 -0
- gpbench/method_class/RF/RF_GPU_class.py +165 -0
- gpbench/method_class/RF/RF_GPU_he_class.py +124 -0
- gpbench/method_class/RF/__init__.py +5 -0
- gpbench/method_class/SVC/SVC_GPU.py +181 -0
- gpbench/method_class/SVC/SVC_GPU_he.py +106 -0
- gpbench/method_class/SVC/__init__.py +5 -0
- gpbench/method_class/SoyDNGP/AlexNet_206_class.py +179 -0
- gpbench/method_class/SoyDNGP/SoyDNGP_class.py +189 -0
- gpbench/method_class/SoyDNGP/SoyDNGP_he_class.py +112 -0
- gpbench/method_class/SoyDNGP/__init__.py +5 -0
- gpbench/method_class/XGBoost/XGboost_GPU_class.py +198 -0
- gpbench/method_class/XGBoost/XGboost_GPU_he_class.py +178 -0
- gpbench/method_class/XGBoost/__init__.py +5 -0
- gpbench/method_class/__init__.py +52 -0
- gpbench/method_class/rrBLUP/__init__.py +5 -0
- gpbench/method_class/rrBLUP/rrBLUP_class.py +140 -0
- gpbench/method_reg/BayesA/BayesA.py +116 -0
- gpbench/method_reg/BayesA/__init__.py +5 -0
- gpbench/method_reg/BayesA/_bayesfromR.py +96 -0
- gpbench/method_reg/BayesA/_param_free_base_model.py +84 -0
- gpbench/method_reg/BayesA/bayesAfromR.py +16 -0
- gpbench/method_reg/BayesB/BayesB.py +117 -0
- gpbench/method_reg/BayesB/__init__.py +5 -0
- gpbench/method_reg/BayesB/_bayesfromR.py +96 -0
- gpbench/method_reg/BayesB/_param_free_base_model.py +84 -0
- gpbench/method_reg/BayesB/bayesBfromR.py +16 -0
- gpbench/method_reg/BayesC/BayesC.py +115 -0
- gpbench/method_reg/BayesC/__init__.py +5 -0
- gpbench/method_reg/BayesC/_bayesfromR.py +96 -0
- gpbench/method_reg/BayesC/_param_free_base_model.py +84 -0
- gpbench/method_reg/BayesC/bayesCfromR.py +16 -0
- gpbench/method_reg/CropARNet/CropARNet.py +159 -0
- gpbench/method_reg/CropARNet/CropARNet_Hyperparameters.py +109 -0
- gpbench/method_reg/CropARNet/__init__.py +5 -0
- gpbench/method_reg/CropARNet/base_CropARNet.py +137 -0
- gpbench/method_reg/Cropformer/Cropformer.py +313 -0
- gpbench/method_reg/Cropformer/Cropformer_Hyperparameters.py +250 -0
- gpbench/method_reg/Cropformer/__init__.py +5 -0
- gpbench/method_reg/DL_GWAS/DL_GWAS.py +186 -0
- gpbench/method_reg/DL_GWAS/DL_GWAS_Hyperparameters.py +125 -0
- gpbench/method_reg/DL_GWAS/__init__.py +5 -0
- gpbench/method_reg/DNNGP/DNNGP.py +157 -0
- gpbench/method_reg/DNNGP/DNNGP_Hyperparameters.py +118 -0
- gpbench/method_reg/DNNGP/__init__.py +5 -0
- gpbench/method_reg/DNNGP/base_dnngp.py +101 -0
- gpbench/method_reg/DeepCCR/DeepCCR.py +149 -0
- gpbench/method_reg/DeepCCR/DeepCCR_Hyperparameters.py +110 -0
- gpbench/method_reg/DeepCCR/__init__.py +5 -0
- gpbench/method_reg/DeepCCR/base_DeepCCR.py +171 -0
- gpbench/method_reg/DeepGS/DeepGS.py +165 -0
- gpbench/method_reg/DeepGS/DeepGS_Hyperparameters.py +114 -0
- gpbench/method_reg/DeepGS/__init__.py +5 -0
- gpbench/method_reg/DeepGS/base_deepgs.py +98 -0
- gpbench/method_reg/EIR/EIR.py +258 -0
- gpbench/method_reg/EIR/EIR_Hyperparameters.py +178 -0
- gpbench/method_reg/EIR/__init__.py +5 -0
- gpbench/method_reg/EIR/utils/__init__.py +0 -0
- gpbench/method_reg/EIR/utils/array_output_modules.py +97 -0
- gpbench/method_reg/EIR/utils/common.py +65 -0
- gpbench/method_reg/EIR/utils/lcl_layers.py +235 -0
- gpbench/method_reg/EIR/utils/logging.py +59 -0
- gpbench/method_reg/EIR/utils/mlp_layers.py +92 -0
- gpbench/method_reg/EIR/utils/models_locally_connected.py +642 -0
- gpbench/method_reg/EIR/utils/transformer_models.py +546 -0
- gpbench/method_reg/ElasticNet/ElasticNet.py +123 -0
- gpbench/method_reg/ElasticNet/ElasticNet_he.py +83 -0
- gpbench/method_reg/ElasticNet/__init__.py +5 -0
- gpbench/method_reg/G2PDeep/G2PDeep_Hyperparameters.py +107 -0
- gpbench/method_reg/G2PDeep/G2Pdeep.py +166 -0
- gpbench/method_reg/G2PDeep/__init__.py +5 -0
- gpbench/method_reg/G2PDeep/base_G2PDeep.py +209 -0
- gpbench/method_reg/GBLUP/GBLUP_R.py +182 -0
- gpbench/method_reg/GBLUP/__init__.py +5 -0
- gpbench/method_reg/GEFormer/GEFormer.py +164 -0
- gpbench/method_reg/GEFormer/GEFormer_Hyperparameters.py +106 -0
- gpbench/method_reg/GEFormer/__init__.py +5 -0
- gpbench/method_reg/GEFormer/gMLP.py +341 -0
- gpbench/method_reg/LightGBM/LightGBM.py +237 -0
- gpbench/method_reg/LightGBM/LightGBM_Hyperparameters.py +77 -0
- gpbench/method_reg/LightGBM/__init__.py +5 -0
- gpbench/method_reg/MVP/MVP.py +182 -0
- gpbench/method_reg/MVP/MVP_Hyperparameters.py +126 -0
- gpbench/method_reg/MVP/__init__.py +5 -0
- gpbench/method_reg/MVP/base_MVP.py +113 -0
- gpbench/method_reg/RF/RF_GPU.py +174 -0
- gpbench/method_reg/RF/RF_Hyperparameters.py +163 -0
- gpbench/method_reg/RF/__init__.py +5 -0
- gpbench/method_reg/SVC/SVC_GPU.py +194 -0
- gpbench/method_reg/SVC/SVC_Hyperparameters.py +107 -0
- gpbench/method_reg/SVC/__init__.py +5 -0
- gpbench/method_reg/SoyDNGP/AlexNet_206.py +185 -0
- gpbench/method_reg/SoyDNGP/SoyDNGP.py +179 -0
- gpbench/method_reg/SoyDNGP/SoyDNGP_Hyperparameters.py +105 -0
- gpbench/method_reg/SoyDNGP/__init__.py +5 -0
- gpbench/method_reg/XGBoost/XGboost_GPU.py +188 -0
- gpbench/method_reg/XGBoost/XGboost_Hyperparameters.py +167 -0
- gpbench/method_reg/XGBoost/__init__.py +5 -0
- gpbench/method_reg/__init__.py +55 -0
- gpbench/method_reg/rrBLUP/__init__.py +5 -0
- gpbench/method_reg/rrBLUP/rrBLUP.py +123 -0
- gpbench-1.0.0.dist-info/METADATA +379 -0
- gpbench-1.0.0.dist-info/RECORD +188 -0
- gpbench-1.0.0.dist-info/WHEEL +5 -0
- gpbench-1.0.0.dist-info/entry_points.txt +2 -0
- gpbench-1.0.0.dist-info/top_level.txt +3 -0
- tests/test_import.py +80 -0
- tests/test_method.py +232 -0
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
"""
Simplified Python code-interpreter tool.

Modeled on the PythonREPL implementation in textMSA, keeping only the
"execute code" part:
- accept a string of Python code
- execute it in-process (both single expressions and multi-line scripts
  are supported)
- capture stdout / stderr
- return a structured execution-result object so callers can tell whether
  execution succeeded
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
import contextlib
|
|
14
|
+
import io
|
|
15
|
+
import time
|
|
16
|
+
from dataclasses import dataclass
|
|
17
|
+
from typing import Any, Optional
|
|
18
|
+
|
|
19
|
+
from logging_utils import get_logger
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Module-level logger obtained from the project's logging_utils helper.
logger = get_logger(__name__)


@dataclass
class PythonREPLExecutionResult:
    """Result of executing one code snippet.

    ``success`` is False for rejected input (empty / too long) and for code
    that raised; ``error`` then holds the exception when it is an
    ``Exception`` subclass, and stays ``None`` otherwise.
    """

    # Text captured from sys.stdout during execution.
    stdout: str
    # Text captured from sys.stderr, or a description of why execution failed.
    stderr: str
    # Wall-clock duration in seconds (0.0 for rejected input).
    execution_time: float
    # True when the snippet ran to completion without raising.
    success: bool
    # The raised exception, if any and if it is an Exception subclass.
    error: Optional[Exception] = None
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class PythonREPL:
    """
    Lightweight Python code executor.

    Design goals:
    - stay as close as possible to the interface of
      ``langchain_experimental.utilities.PythonREPL`` (provides ``run(code)``);
    - keep one global namespace across calls, so variables defined by one
      ``run`` are visible to later ones;
    - capture stdout / stderr for the calling layer.

    Security note: code runs in the current process through ``eval`` /
    ``exec`` with the normal builtins available. This is NOT a sandbox; only
    pass code you are willing to run with this process's privileges.
    """

    def __init__(self, max_code_length: int = 10_000) -> None:
        """
        Args:
            max_code_length: maximum accepted length (in characters) of the
                code passed to :meth:`run`; longer code is rejected.
        """
        self._max_code_length = max_code_length
        # Shared global namespace so variables persist between runs.
        self._exec_globals: dict[str, Any] = {}
        logger.info(
            "PythonREPL initialized",
            extra={"max_code_length": max_code_length},
        )

    def run(self, code: str) -> PythonREPLExecutionResult:
        """Execute a snippet of Python code and return a structured result.

        Source that compiles as a single expression is evaluated with
        ``eval`` and a non-``None`` value is appended to stdout (like an
        interactive prompt); anything else is executed with ``exec``.

        Args:
            code: Python source to run.

        Returns:
            PythonREPLExecutionResult. ``success`` is False when ``code`` is
            empty or too long, fails to compile, or raises while running
            (including ``SystemExit`` and ``KeyboardInterrupt``, which are
            caught so they cannot terminate the host process). On failure
            ``stderr`` ends with the exception type and message.
        """
        if not code:
            return PythonREPLExecutionResult(
                stdout="",
                stderr="代码为空",
                execution_time=0.0,
                success=False,
            )

        if len(code) > self._max_code_length:
            return PythonREPLExecutionResult(
                stdout="",
                stderr=f"代码长度超过限制 ({len(code)} > {self._max_code_length})",
                execution_time=0.0,
                success=False,
            )

        logger.info("Executing Python code", extra={"code_length": len(code)})
        logger.debug("Code to execute", extra={"code_preview": code[:500]})

        start_time = time.perf_counter()

        stdout_buf = io.StringIO()
        stderr_buf = io.StringIO()

        try:
            # Compile inside the try. When the snippet is neither a valid
            # expression nor a valid statement block, the SyntaxError from the
            # "exec" compile must become a failed result rather than escape
            # from run().
            try:
                code_obj = compile(code, "<python-repl>", "eval")
                use_eval = True
            except SyntaxError:
                code_obj = compile(code, "<python-repl>", "exec")
                use_eval = False

            with contextlib.redirect_stdout(stdout_buf), contextlib.redirect_stderr(
                stderr_buf
            ):
                if use_eval:
                    result = eval(code_obj, self._exec_globals)  # noqa: S307
                else:
                    exec(code_obj, self._exec_globals)  # noqa: S102
                    result = None

            execution_time = time.perf_counter() - start_time

            stdout = stdout_buf.getvalue()
            stderr = stderr_buf.getvalue()
            success = True

            # In eval mode, append a non-None result to stdout so it is visible.
            if use_eval and result is not None:
                result_str = result if isinstance(result, str) else str(result)
                if stdout and not stdout.endswith("\n"):
                    stdout += "\n"
                stdout += result_str

            logger.info(
                "Code execution completed",
                extra={
                    "execution_time": execution_time,
                    "stdout_length": len(stdout),
                    "stderr_length": len(stderr),
                    "success": success,
                },
            )
            if stdout:
                logger.info("Code execution stdout", extra={"stdout": stdout})
            if stderr:
                logger.warning("Code execution stderr", extra={"stderr": stderr})

            return PythonREPLExecutionResult(
                stdout=stdout,
                stderr=stderr,
                execution_time=execution_time,
                success=success,
            )

        except BaseException as exc:  # noqa: BLE001
            # Deliberately catches SystemExit / KeyboardInterrupt too, so
            # executed code cannot terminate the host process.
            import traceback

            execution_time = time.perf_counter() - start_time
            stdout = stdout_buf.getvalue()
            stderr = stderr_buf.getvalue()

            # Include the exception type. str(exc) alone is empty for e.g.
            # ``raise SystemExit`` and ambiguous for KeyError('x').
            exc_text = "".join(
                traceback.format_exception_only(type(exc), exc)
            ).rstrip()
            stderr = f"{stderr}\n{exc_text}" if stderr else exc_text

            logger.error(
                "Code execution failed",
                extra={
                    "execution_time": execution_time,
                    "error": stderr,
                },
                exc_info=True,
            )

            return PythonREPLExecutionResult(
                stdout=stdout,
                stderr=stderr,
                execution_time=execution_time,
                success=False,
                error=exc if isinstance(exc, Exception) else None,
            )


__all__ = ["PythonREPL", "PythonREPLExecutionResult"]
|
|
164
|
+
|
|
165
|
+
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
"""
Read Agent state definitions (independent of the textMSA project).
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from typing import Optional, TypedDict
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
from typing import NotRequired # type: ignore[attr-defined]
|
|
11
|
+
except Exception: # pragma: no cover
|
|
12
|
+
from typing_extensions import NotRequired
|
|
13
|
+
|
|
14
|
+
from logging_utils import get_logger
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Module-level logger obtained from the project's logging_utils helper.
logger = get_logger(__name__)


class FileInfo(TypedDict, total=False):
    """
    File information supplied by the caller.

    Note: the fields follow the agreed-upon user contract exactly; do not add
    extra fields. All keys are optional at the type level (``total=False``).
    """

    file_name: str
    file_path: str
    description: str
    preview: str
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class PlanHistory(TypedDict, total=False):
    """One entry in the plan history."""

    file_name: str  # file name
    file_path: str  # file path
    plan_detail: str  # plan details
    result: Optional[str]  # execution result
    # Reasoning behind the ordering. NotRequired is redundant under
    # total=False (every key is already optional) but harmless.
    order_reasoning: NotRequired[str]
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class ReadAgentState(TypedDict, total=False):
    """State of the Read Agent (simplified version).

    All keys are optional at the type level (``total=False``).
    """

    # User query
    user_query: str
    # File list (file info supplied by the caller)
    files: list[FileInfo]
    # File overview string (already formatted by the caller)
    file_overview: str
    # Language
    language: NotRequired[str]
    # History of plans
    history_plans: list[PlanHistory]
    # Index of the current plan
    current_plan_index: int
    # Final answer
    final_answer: NotRequired[Optional[str]]
    # Routing decision for the next step
    next_route: NotRequired[str]
    # User / project IDs (optional)
    user_id: NotRequired[str]
    project_id: NotRequired[str]
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def build_initial_state(
    user_query: str,
    files: list[FileInfo],
    file_overview: str,
    user_id: Optional[str] = None,
    project_id: Optional[str] = None,
    language: str = "zh",
) -> ReadAgentState:
    """Create the starting state for a Read Agent run.

    The state begins with an empty plan history and ``current_plan_index`` 0.
    ``user_id`` / ``project_id`` are stored only when truthy, so ``None`` and
    ``""`` leave those keys out. The new state is logged at INFO level.
    """
    optional_ids = {"user_id": user_id, "project_id": project_id}

    state: ReadAgentState = dict(
        user_query=user_query,
        files=files,
        file_overview=file_overview,
        language=language,
        history_plans=[],
        current_plan_index=0,
    )
    state.update({key: value for key, value in optional_ids.items() if value})

    log_fields = {
        "files_len": len(files),
        "user_id": user_id,
        "project_id": project_id,
        "language": language,
    }
    logger.info("Read Agent initial state ready", extra=log_fields)
    return state
|
|
100
|
+
|
|
101
|
+
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""
Read Agent workflow (standalone version).
"""
|
|
4
|
+
|
|
5
|
+
from langgraph.graph import END, StateGraph
|
|
6
|
+
|
|
7
|
+
from logging_utils import get_logger
|
|
8
|
+
from .nodes import plan_node, execute_plan_node, read_node, answer_node
|
|
9
|
+
from .state import ReadAgentState
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Module-level logger obtained from the project's logging_utils helper.
logger = get_logger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _route_after_execute(state: ReadAgentState) -> str:
|
|
16
|
+
"""路由函数:根据 next_route 决定下一步"""
|
|
17
|
+
next_route = state.get("next_route", "")
|
|
18
|
+
return next_route or "read"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def build_read_agent_workflow() -> StateGraph:
    """Assemble (without compiling) the Read Agent state graph.

    Flow: plan -> execute_plan, then ``_route_after_execute`` sends control
    either to ``read`` (which loops back to ``execute_plan``) or to
    ``answer``, which ends the graph.
    """
    graph = StateGraph(ReadAgentState)

    node_table = (
        ("plan", plan_node),
        ("execute_plan", execute_plan_node),
        ("read", read_node),
        ("answer", answer_node),
    )
    for node_name, node_fn in node_table:
        graph.add_node(node_name, node_fn)

    graph.set_entry_point("plan")
    graph.add_edge("plan", "execute_plan")

    # Router outputs map one-to-one onto node names.
    route_map = {route: route for route in ("read", "answer")}
    graph.add_conditional_edges("execute_plan", _route_after_execute, route_map)

    graph.add_edge("read", "execute_plan")
    graph.add_edge("answer", END)
    return graph
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def compile_read_agent_workflow():
    """Build the Read Agent graph and return its compiled, runnable form."""
    return build_read_agent_workflow().compile()
|
|
53
|
+
|
|
54
|
+
|
gpbench/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""
|
|
2
|
+
GPBench: A benchmarking toolkit for genomic prediction.
|
|
3
|
+
|
|
4
|
+
This package provides implementations of various genomic prediction methods
|
|
5
|
+
including classic linear statistical approaches and machine learning/deep learning methods.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
# Package metadata.
__version__ = "1.0.0"
__author__ = "GPBench Contributors"
__email__ = ""  # no contact address is published
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
from ._selftest import run_import_test
|
|
14
|
+
def test(verbose=True):
    """
    Run the GPBench import self-test.

    Args:
        verbose: when True, also print a line for every successful import.

    Returns:
        bool: whatever ``run_import_test`` reports (True when all imports
        succeeded).

    Usage:
        import gpbench
        gpbench.test()
    """
    outcome = run_import_test(verbose=verbose)
    return outcome


__all__ = ["test"]
|
|
25
|
+
|
gpbench/_selftest.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
# gpbench/_selftest.py
|
|
2
|
+
|
|
3
|
+
import importlib
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# All (module, attribute) import paths checked by the self-test: 41 entries,
# 21 regression methods under gpbench.method_reg and 20 classification
# methods under gpbench.method_class (MVP has no classification variant).
IMPORT_TESTS = [

    # =========================
    # method_reg methods
    # =========================
    ("gpbench.method_reg.BayesA", "BayesA_reg"),
    ("gpbench.method_reg.BayesB", "BayesB_reg"),
    ("gpbench.method_reg.BayesC", "BayesC_reg"),
    ("gpbench.method_reg.CropARNet", "CropARNet_reg"),
    ("gpbench.method_reg.Cropformer", "Cropformer_reg"),
    ("gpbench.method_reg.DeepCCR", "DeepCCR_reg"),
    ("gpbench.method_reg.DeepGS", "DeepGS_reg"),
    ("gpbench.method_reg.DL_GWAS", "DL_GWAS_reg"),
    ("gpbench.method_reg.DNNGP", "DNNGP_reg"),
    ("gpbench.method_reg.EIR", "EIR_reg"),
    ("gpbench.method_reg.ElasticNet", "ElasticNet_reg"),
    ("gpbench.method_reg.G2PDeep", "G2PDeep_reg"),
    ("gpbench.method_reg.GBLUP", "GBLUP_reg"),
    ("gpbench.method_reg.GEFormer", "GEFormer_reg"),
    ("gpbench.method_reg.LightGBM", "LightGBM_reg"),
    ("gpbench.method_reg.MVP", "MVP_reg"),
    ("gpbench.method_reg.RF", "RF_reg"),
    ("gpbench.method_reg.rrBLUP", "rrBLUP_reg"),
    ("gpbench.method_reg.SoyDNGP", "SoyDNGP_reg"),
    ("gpbench.method_reg.SVC", "SVC_reg"),
    ("gpbench.method_reg.XGBoost", "XGBoost_reg"),

    # =========================
    # method_class methods
    # =========================
    ("gpbench.method_class.BayesA", "BayesA_class"),
    ("gpbench.method_class.BayesB", "BayesB_class"),
    ("gpbench.method_class.BayesC", "BayesC_class"),
    ("gpbench.method_class.CropARNet", "CropARNet_class"),
    ("gpbench.method_class.Cropformer", "Cropformer_class"),
    ("gpbench.method_class.DeepCCR", "DeepCCR_class"),
    ("gpbench.method_class.DeepGS", "DeepGS_class"),
    ("gpbench.method_class.DL_GWAS", "DL_GWAS_class"),
    ("gpbench.method_class.DNNGP", "DNNGP_class"),
    ("gpbench.method_class.EIR", "EIR_class"),
    ("gpbench.method_class.ElasticNet", "ElasticNet_class"),
    ("gpbench.method_class.G2PDeep", "G2PDeep_class"),
    ("gpbench.method_class.GBLUP", "GBLUP_class"),
    ("gpbench.method_class.GEFormer", "GEFormer_class"),
    ("gpbench.method_class.LightGBM", "LightGBM_class"),
    ("gpbench.method_class.RF", "RF_class"),
    ("gpbench.method_class.rrBLUP", "rrBLUP_class"),
    ("gpbench.method_class.SoyDNGP", "SoyDNGP_class"),
    ("gpbench.method_class.SVC", "SVC_class"),
    ("gpbench.method_class.XGBoost", "XGBoost_class"),
]
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def run_import_test(verbose=True, tests=None):
    """
    Import every (module, attribute) pair and report which ones fail.

    Args:
        verbose: when True, also print a line for each successful import.
            Failures and the summary are always printed.
        tests: optional iterable of ``(module_name, attribute_name)`` pairs
            to check. Defaults to ``IMPORT_TESTS`` (all GPBench methods).

    Returns:
        True if every import passed, False otherwise.
    """
    # Resolve the default at call time so the signature does not depend on
    # the module-level constant being defined first.
    if tests is None:
        tests = IMPORT_TESTS
    tests = list(tests)

    print("\n==============================")
    print(" GPBench Import Self Test ")
    print("==============================\n")

    success = 0
    failed = []

    for module_name, obj_name in tests:
        try:
            module = importlib.import_module(module_name)
            # Raises AttributeError when the module does not export the name.
            getattr(module, obj_name)
        except Exception as e:
            failed.append((module_name, obj_name, str(e)))
            print(f"[FAIL] from {module_name} import {obj_name}")
            print(f" Error: {e}")
        else:
            success += 1
            if verbose:
                print(f"[OK] from {module_name} import {obj_name}")

    # Summary
    print("\n==============================")
    print(" Test Summary ")
    print("==============================")
    print(f"Total Methods Tested: {len(tests)}")
    print(f"Passed: {success}")
    print(f"Failed: {len(failed)}")

    if failed:
        print("\n❌ Failed Imports:")
        for mod, obj, err in failed:
            print(f" - {mod}.{obj}: {err}")

        print("\nSelf-test FAILED.\n")
        return False

    print("\n✅ All imports passed successfully!\n")
    return True
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import time
|
|
3
|
+
import psutil
|
|
4
|
+
import swanlab
|
|
5
|
+
import argparse
|
|
6
|
+
import random
|
|
7
|
+
import torch
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pandas as pd
|
|
10
|
+
import sys
|
|
11
|
+
from .bayesAfromR import BayesA
|
|
12
|
+
from sklearn.model_selection import StratifiedKFold
|
|
13
|
+
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
|
|
14
|
+
from sklearn.preprocessing import LabelEncoder
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def parse_args(argv=None):
    """Parse command-line options for the BayesA classification benchmark.

    Args:
        argv: optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``, matching the previous behaviour.
            Passing a list lets the function be called from code or tests
            without being disturbed by the host process's own arguments.

    Returns:
        argparse.Namespace with ``methods``, ``species``, ``phe``, ``task``,
        ``data_dir`` and ``result_dir`` attributes.
    """
    parser = argparse.ArgumentParser(description="Argument parser")
    parser.add_argument('--methods', type=str, default='BayesA/', help='Model name')
    parser.add_argument('--species', type=str, default='Human/', help='Species name')
    parser.add_argument('--phe', type=str, default='', help='Phenotype name')
    parser.add_argument('--task', type=str, default='classification', choices=['regression','classification'], help='Task: regression or classification')
    parser.add_argument('--data_dir', type=str, default='../../data/', help='Path to data directory')
    parser.add_argument('--result_dir', type=str, default='result/', help='Path to result directory')
    return parser.parse_args(argv)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def load_data(args):
    """Load the genotype and phenotype arrays for ``args.species``.

    Reads ``genotype.npz`` (array ``arr_0``) and ``phenotype.npz`` (arrays
    ``arr_0`` and ``arr_1``) from ``os.path.join(args.data_dir, args.species)``
    and prints the sample and SNP counts.

    Returns:
        (xData, yData, nsample, nsnp, names): the genotype array, the
        phenotype ``arr_0``, the first two dimensions of xData, and the
        phenotype ``arr_1``.
    """
    species_dir = os.path.join(args.data_dir, args.species)

    # np.load on an .npz returns a lazily-read NpzFile that keeps the file
    # handle open. Indexing materializes the array, so close it afterwards.
    with np.load(os.path.join(species_dir, 'genotype.npz')) as geno:
        xData = geno["arr_0"]
    # Open phenotype.npz once and read both arrays from it.
    with np.load(os.path.join(species_dir, 'phenotype.npz')) as pheno:
        yData = pheno["arr_0"]
        names = pheno["arr_1"]

    nsample = xData.shape[0]
    nsnp = xData.shape[1]
    print("Number of samples: ", nsample)
    print("Number of SNPs: ", nsnp)
    return xData, yData, nsample, nsnp, names
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs and
    switch cuDNN to deterministic, non-benchmarking mode."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def run_nested_cv(args, data, label):
    """Stratified 10-fold cross-validation of BayesA used as a classifier.

    Classification is done one-vs-rest. For each class present in the
    training fold, a ``BayesA(task="regression")`` model is fit on a 0/1
    indicator target. Each test sample is assigned the class whose model
    scores it highest.

    Writes the label encoding to ``label_mapping.npy`` and per-fold
    predictions (in original label values) to ``fold<k>.csv`` under
    ``<result_dir>/<methods><species>``. Per-fold and summary metrics are
    printed.

    Args:
        args: namespace with ``result_dir``, ``methods`` and ``species``.
        data: genotype matrix indexed by sample along the first axis.
        label: class label for each sample.
    """
    result_dir = os.path.join(args.result_dir, args.methods + args.species)
    os.makedirs(result_dir, exist_ok=True)
    print("Starting 10-fold cross-validation...")

    kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
    le = LabelEncoder()
    label_all = le.fit_transform(label)

    np.save(os.path.join(result_dir, 'label_mapping.npy'), le.classes_)

    all_acc, all_prec, all_rec, all_f1 = [], [], [], []
    start_time = time.time()
    process = psutil.Process(os.getpid())

    for fold, (train_index, test_index) in enumerate(kf.split(data, label_all)):
        fold_start = time.time()
        print(f"\n===== Fold {fold} =====")
        X_train, X_test = data[train_index], data[test_index]
        Y_train, Y_test = label_all[train_index], label_all[test_index]

        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()

        # One-vs-rest: one BayesA regressor per class present in this
        # training fold; row idx of `scores` belongs to classes[idx].
        classes = np.unique(Y_train)
        scores = np.zeros((len(classes), X_test.shape[0]))
        for idx, cls in enumerate(classes):
            y_train_bin = (Y_train == cls).astype(float)
            model_k = BayesA(task="regression")
            model_k.fit(X_train, y_train_bin)
            scores[idx, :] = model_k.predict(X_test)

        # argmax yields a row index into `classes`; map it back to the encoded
        # label. The two differ whenever some class is absent from the
        # training fold (e.g. a class with a single sample).
        Y_pred = classes[np.argmax(scores, axis=0)]

        acc = accuracy_score(Y_test, Y_pred)
        prec, rec, f1, _ = precision_recall_fscore_support(Y_test, Y_pred, average='macro', zero_division=0)

        all_acc.append(acc)
        all_prec.append(prec)
        all_rec.append(rec)
        all_f1.append(f1)

        fold_time = time.time() - fold_start
        fold_gpu_mem = torch.cuda.max_memory_allocated() / 1024**2 if torch.cuda.is_available() else 0
        fold_cpu_mem = process.memory_info().rss / 1024**2
        print(f'Fold {fold}: ACC={acc:.4f}, PREC={prec:.4f}, REC={rec:.4f}, F1={f1:.4f}, Time={fold_time:.2f}s, '
              f'GPU={fold_gpu_mem:.2f}MB, CPU={fold_cpu_mem:.2f}MB')

        # Save predictions in the original label values.
        Y_test_orig = le.inverse_transform(Y_test)
        Y_pred_orig = le.inverse_transform(Y_pred)
        results_df = pd.DataFrame({'Y_test': Y_test_orig, 'Y_pred': Y_pred_orig})
        results_df.to_csv(os.path.join(result_dir, f"fold{fold}.csv"), index=False)

    print("\n===== Cross-validation summary =====")
    print(f"Average ACC: {np.mean(all_acc):.4f} ± {np.std(all_acc):.4f}")
    print(f"Average PREC: {np.mean(all_prec):.4f} ± {np.std(all_prec):.4f}")
    print(f"Average REC: {np.mean(all_rec):.4f} ± {np.std(all_rec):.4f}")
    print(f"Average F1 : {np.mean(all_f1):.4f} ± {np.std(all_f1):.4f}")
    print(f"Total time : {time.time() - start_time:.2f}s")
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def BayesA_class():
    """Entry point: run the BayesA classification benchmark.

    Parses CLI arguments. For each species in the hard-coded ``all_species``
    list, it loads the data and takes the first phenotype column as the
    label. For the classification task it fills NaN labels with the most
    frequent label (0 if all are NaN). It then runs ``run_nested_cv``.

    NOTE(review): ``all_species`` overrides any ``--species`` given on the
    command line; confirm this is intended.
    """
    set_seed(42)
    torch.cuda.empty_cache()
    args = parse_args()
    all_species = ["Human/Sim/"]
    for species in all_species:
        args.species = species
        X, Y, nsamples, nsnp, names = load_data(args)
        args.phe = names
        print("Starting run " + args.methods + args.species)
        label = Y[:, 0]

        if args.task == 'classification':
            s = pd.Series(label)
            fill_val = s.mode().iloc[0] if not s.dropna().empty else 0
            label = np.nan_to_num(label, nan=fill_val)

        start_time = time.time()
        # Guarded: reset_peak_memory_stats() raises on machines without CUDA,
        # and BayesA itself runs in R, so CPU-only hosts must work.
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()

        run_nested_cv(args, data=X, label=label)

        elapsed_time = time.time() - start_time
        print(f"Total running time: {elapsed_time:.2f} s")
        print("Successfully finished!")


if __name__ == "__main__":
    BayesA_class()
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import rpy2
|
|
3
|
+
from rpy2.robjects import numpy2ri
|
|
4
|
+
rpy2.robjects.numpy2ri.activate()
|
|
5
|
+
import rpy2.robjects as robjects
|
|
6
|
+
from rpy2.robjects.packages import importr
|
|
7
|
+
from . import _param_free_base_model
|
|
8
|
+
from joblib import Parallel, delayed
|
|
9
|
+
|
|
10
|
+
class Bayes_R(_param_free_base_model.ParamFreeBaseModel):
    """
    Bayesian-alphabet model (BayesA, BayesB or BayesC) fitted with the R
    package BGLR through rpy2.

    *Attributes*

        *Inherited attributes*

        See :obj:`~easypheno.model._param_free_base_model.ParamFreeBaseModel` for more information on the attributes.

        *Additional attributes*

        - mu (*np.ndarray*): intercept (mean over chains of BGLR's ``mu``)
        - beta (*np.ndarray*): effect sizes (mean over chains of BGLR's ``ETA[[1]]$b``)
        - model_name (*str*): BGLR model to use (BayesA, BayesB or BayesC)
        - n_iter (*int*): iterations for sampling
        - burn_in (*int*): warmup/burnin for sampling
        - n_jobs (*int*): number of chains, each run as one joblib job (fixed to 1)
    """
    standard_encoding = '012'
    possible_encodings = ['101']

    def __init__(self, task: str, model_name: str, encoding: str = None, n_iter: int = 1000, burn_in: int = 200):
        super().__init__(task=task, encoding=encoding)
        self.model_name = model_name
        self.n_iter = n_iter
        self.burn_in = burn_in
        self.n_jobs = 1
        self.mu = None
        self.beta = None

    def _run_chain(self, chain_num: int, R_X, R_y):
        """
        Run one BGLR MCMC chain and return ``(beta, mu)`` as numpy arrays.

        ``chain_num`` only identifies the job; it does not affect sampling.
        """
        BGLR = importr('BGLR')

        # Single linear predictor using the configured model (BayesA/B/C).
        # NOTE(review): BGLR's ``saveAt`` argument defaults to the working
        # directory, so sampler output files may be written there; confirm
        # whether that is acceptable.
        ETA = robjects.r['list'](robjects.r['list'](X=R_X, model=self.model_name))
        fmBB = BGLR.BGLR(y=R_y, ETA=ETA, verbose=False, nIter=self.n_iter, burnIn=self.burn_in)

        beta_chain = np.asarray(fmBB.rx2('ETA').rx2(1).rx2('b'))
        mu_chain = np.asarray(fmBB.rx2('mu'))  # intercept for this chain
        return beta_chain, mu_chain

    def fit(self, X: np.ndarray, y: np.ndarray) -> np.ndarray:
        """
        Fit the Bayesian-alphabet model with BGLR and return in-sample predictions.

        Runs ``n_jobs`` chains and stores the across-chain means of the
        effect sizes and intercept in ``self.beta`` / ``self.mu``.

        See :obj:`~easypheno.model._param_free_base_model.ParamFreeBaseModel` for more information.
        """
        # Convert inputs to R objects (numpy2ri conversion is activated at import).
        R_X = robjects.r['matrix'](X, nrow=X.shape[0], ncol=X.shape[1])
        R_y = robjects.FloatVector(y)

        results = Parallel(n_jobs=self.n_jobs)(
            delayed(self._run_chain)(chain_num, R_X, R_y) for chain_num in range(self.n_jobs)
        )

        # Average beta and mu over all chains.
        beta_chains = [beta for beta, _ in results]
        mu_chains = [mu for _, mu in results]
        self.beta = np.mean(beta_chains, axis=0)
        self.mu = np.mean(mu_chains, axis=0)

        return self.predict(X_in=X)

    def predict(self, X_in: np.ndarray) -> np.ndarray:
        """
        Predict as ``mu + X_in @ beta`` using the fitted parameters.

        See :obj:`~easypheno.model._param_free_base_model.ParamFreeBaseModel` for more information.
        """
        return self.mu + np.matmul(X_in, self.beta)
|