gpbench-1.0.0-py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- gp_agent_tool/compute_dataset_feature.py +67 -0
- gp_agent_tool/config.py +65 -0
- gp_agent_tool/experience/create_masked_dataset_summary.py +97 -0
- gp_agent_tool/experience/dataset_summary_info.py +13 -0
- gp_agent_tool/experience/experience_info.py +12 -0
- gp_agent_tool/experience/get_matched_experience.py +111 -0
- gp_agent_tool/llm_client.py +119 -0
- gp_agent_tool/logging_utils.py +24 -0
- gp_agent_tool/main.py +347 -0
- gp_agent_tool/read_agent/__init__.py +46 -0
- gp_agent_tool/read_agent/nodes.py +674 -0
- gp_agent_tool/read_agent/prompts.py +547 -0
- gp_agent_tool/read_agent/python_repl_tool.py +165 -0
- gp_agent_tool/read_agent/state.py +101 -0
- gp_agent_tool/read_agent/workflow.py +54 -0
- gpbench/__init__.py +25 -0
- gpbench/_selftest.py +104 -0
- gpbench/method_class/BayesA/BayesA_class.py +141 -0
- gpbench/method_class/BayesA/__init__.py +5 -0
- gpbench/method_class/BayesA/_bayesfromR.py +96 -0
- gpbench/method_class/BayesA/_param_free_base_model.py +84 -0
- gpbench/method_class/BayesA/bayesAfromR.py +16 -0
- gpbench/method_class/BayesB/BayesB_class.py +140 -0
- gpbench/method_class/BayesB/__init__.py +5 -0
- gpbench/method_class/BayesB/_bayesfromR.py +96 -0
- gpbench/method_class/BayesB/_param_free_base_model.py +84 -0
- gpbench/method_class/BayesB/bayesBfromR.py +16 -0
- gpbench/method_class/BayesC/BayesC_class.py +141 -0
- gpbench/method_class/BayesC/__init__.py +4 -0
- gpbench/method_class/BayesC/_bayesfromR.py +96 -0
- gpbench/method_class/BayesC/_param_free_base_model.py +84 -0
- gpbench/method_class/BayesC/bayesCfromR.py +16 -0
- gpbench/method_class/CropARNet/CropARNet_class.py +186 -0
- gpbench/method_class/CropARNet/CropARNet_he_class.py +154 -0
- gpbench/method_class/CropARNet/__init__.py +5 -0
- gpbench/method_class/CropARNet/base_CropARNet_class.py +178 -0
- gpbench/method_class/Cropformer/Cropformer_class.py +308 -0
- gpbench/method_class/Cropformer/__init__.py +5 -0
- gpbench/method_class/Cropformer/cropformer_he_class.py +221 -0
- gpbench/method_class/DL_GWAS/DL_GWAS_class.py +250 -0
- gpbench/method_class/DL_GWAS/DL_GWAS_he_class.py +169 -0
- gpbench/method_class/DL_GWAS/__init__.py +5 -0
- gpbench/method_class/DNNGP/DNNGP_class.py +163 -0
- gpbench/method_class/DNNGP/DNNGP_he_class.py +138 -0
- gpbench/method_class/DNNGP/__init__.py +5 -0
- gpbench/method_class/DNNGP/base_dnngp_class.py +116 -0
- gpbench/method_class/DeepCCR/DeepCCR_class.py +172 -0
- gpbench/method_class/DeepCCR/DeepCCR_he_class.py +161 -0
- gpbench/method_class/DeepCCR/__init__.py +5 -0
- gpbench/method_class/DeepCCR/base_DeepCCR_class.py +209 -0
- gpbench/method_class/DeepGS/DeepGS_class.py +184 -0
- gpbench/method_class/DeepGS/DeepGS_he_class.py +150 -0
- gpbench/method_class/DeepGS/__init__.py +5 -0
- gpbench/method_class/DeepGS/base_deepgs_class.py +153 -0
- gpbench/method_class/EIR/EIR_class.py +276 -0
- gpbench/method_class/EIR/EIR_he_class.py +184 -0
- gpbench/method_class/EIR/__init__.py +5 -0
- gpbench/method_class/EIR/utils/__init__.py +0 -0
- gpbench/method_class/EIR/utils/array_output_modules.py +97 -0
- gpbench/method_class/EIR/utils/common.py +65 -0
- gpbench/method_class/EIR/utils/lcl_layers.py +235 -0
- gpbench/method_class/EIR/utils/logging.py +59 -0
- gpbench/method_class/EIR/utils/mlp_layers.py +92 -0
- gpbench/method_class/EIR/utils/models_locally_connected.py +642 -0
- gpbench/method_class/EIR/utils/transformer_models.py +546 -0
- gpbench/method_class/ElasticNet/ElasticNet_class.py +133 -0
- gpbench/method_class/ElasticNet/ElasticNet_he_class.py +91 -0
- gpbench/method_class/ElasticNet/__init__.py +5 -0
- gpbench/method_class/G2PDeep/G2PDeep_he_class.py +217 -0
- gpbench/method_class/G2PDeep/G2Pdeep_class.py +205 -0
- gpbench/method_class/G2PDeep/__init__.py +5 -0
- gpbench/method_class/G2PDeep/base_G2PDeep_class.py +209 -0
- gpbench/method_class/GBLUP/GBLUP_class.py +183 -0
- gpbench/method_class/GBLUP/__init__.py +5 -0
- gpbench/method_class/GEFormer/GEFormer_class.py +169 -0
- gpbench/method_class/GEFormer/GEFormer_he_class.py +137 -0
- gpbench/method_class/GEFormer/__init__.py +5 -0
- gpbench/method_class/GEFormer/gMLP_class.py +357 -0
- gpbench/method_class/LightGBM/LightGBM_class.py +224 -0
- gpbench/method_class/LightGBM/LightGBM_he_class.py +121 -0
- gpbench/method_class/LightGBM/__init__.py +5 -0
- gpbench/method_class/RF/RF_GPU_class.py +165 -0
- gpbench/method_class/RF/RF_GPU_he_class.py +124 -0
- gpbench/method_class/RF/__init__.py +5 -0
- gpbench/method_class/SVC/SVC_GPU.py +181 -0
- gpbench/method_class/SVC/SVC_GPU_he.py +106 -0
- gpbench/method_class/SVC/__init__.py +5 -0
- gpbench/method_class/SoyDNGP/AlexNet_206_class.py +179 -0
- gpbench/method_class/SoyDNGP/SoyDNGP_class.py +189 -0
- gpbench/method_class/SoyDNGP/SoyDNGP_he_class.py +112 -0
- gpbench/method_class/SoyDNGP/__init__.py +5 -0
- gpbench/method_class/XGBoost/XGboost_GPU_class.py +198 -0
- gpbench/method_class/XGBoost/XGboost_GPU_he_class.py +178 -0
- gpbench/method_class/XGBoost/__init__.py +5 -0
- gpbench/method_class/__init__.py +52 -0
- gpbench/method_class/rrBLUP/__init__.py +5 -0
- gpbench/method_class/rrBLUP/rrBLUP_class.py +140 -0
- gpbench/method_reg/BayesA/BayesA.py +116 -0
- gpbench/method_reg/BayesA/__init__.py +5 -0
- gpbench/method_reg/BayesA/_bayesfromR.py +96 -0
- gpbench/method_reg/BayesA/_param_free_base_model.py +84 -0
- gpbench/method_reg/BayesA/bayesAfromR.py +16 -0
- gpbench/method_reg/BayesB/BayesB.py +117 -0
- gpbench/method_reg/BayesB/__init__.py +5 -0
- gpbench/method_reg/BayesB/_bayesfromR.py +96 -0
- gpbench/method_reg/BayesB/_param_free_base_model.py +84 -0
- gpbench/method_reg/BayesB/bayesBfromR.py +16 -0
- gpbench/method_reg/BayesC/BayesC.py +115 -0
- gpbench/method_reg/BayesC/__init__.py +5 -0
- gpbench/method_reg/BayesC/_bayesfromR.py +96 -0
- gpbench/method_reg/BayesC/_param_free_base_model.py +84 -0
- gpbench/method_reg/BayesC/bayesCfromR.py +16 -0
- gpbench/method_reg/CropARNet/CropARNet.py +159 -0
- gpbench/method_reg/CropARNet/CropARNet_Hyperparameters.py +109 -0
- gpbench/method_reg/CropARNet/__init__.py +5 -0
- gpbench/method_reg/CropARNet/base_CropARNet.py +137 -0
- gpbench/method_reg/Cropformer/Cropformer.py +313 -0
- gpbench/method_reg/Cropformer/Cropformer_Hyperparameters.py +250 -0
- gpbench/method_reg/Cropformer/__init__.py +5 -0
- gpbench/method_reg/DL_GWAS/DL_GWAS.py +186 -0
- gpbench/method_reg/DL_GWAS/DL_GWAS_Hyperparameters.py +125 -0
- gpbench/method_reg/DL_GWAS/__init__.py +5 -0
- gpbench/method_reg/DNNGP/DNNGP.py +157 -0
- gpbench/method_reg/DNNGP/DNNGP_Hyperparameters.py +118 -0
- gpbench/method_reg/DNNGP/__init__.py +5 -0
- gpbench/method_reg/DNNGP/base_dnngp.py +101 -0
- gpbench/method_reg/DeepCCR/DeepCCR.py +149 -0
- gpbench/method_reg/DeepCCR/DeepCCR_Hyperparameters.py +110 -0
- gpbench/method_reg/DeepCCR/__init__.py +5 -0
- gpbench/method_reg/DeepCCR/base_DeepCCR.py +171 -0
- gpbench/method_reg/DeepGS/DeepGS.py +165 -0
- gpbench/method_reg/DeepGS/DeepGS_Hyperparameters.py +114 -0
- gpbench/method_reg/DeepGS/__init__.py +5 -0
- gpbench/method_reg/DeepGS/base_deepgs.py +98 -0
- gpbench/method_reg/EIR/EIR.py +258 -0
- gpbench/method_reg/EIR/EIR_Hyperparameters.py +178 -0
- gpbench/method_reg/EIR/__init__.py +5 -0
- gpbench/method_reg/EIR/utils/__init__.py +0 -0
- gpbench/method_reg/EIR/utils/array_output_modules.py +97 -0
- gpbench/method_reg/EIR/utils/common.py +65 -0
- gpbench/method_reg/EIR/utils/lcl_layers.py +235 -0
- gpbench/method_reg/EIR/utils/logging.py +59 -0
- gpbench/method_reg/EIR/utils/mlp_layers.py +92 -0
- gpbench/method_reg/EIR/utils/models_locally_connected.py +642 -0
- gpbench/method_reg/EIR/utils/transformer_models.py +546 -0
- gpbench/method_reg/ElasticNet/ElasticNet.py +123 -0
- gpbench/method_reg/ElasticNet/ElasticNet_he.py +83 -0
- gpbench/method_reg/ElasticNet/__init__.py +5 -0
- gpbench/method_reg/G2PDeep/G2PDeep_Hyperparameters.py +107 -0
- gpbench/method_reg/G2PDeep/G2Pdeep.py +166 -0
- gpbench/method_reg/G2PDeep/__init__.py +5 -0
- gpbench/method_reg/G2PDeep/base_G2PDeep.py +209 -0
- gpbench/method_reg/GBLUP/GBLUP_R.py +182 -0
- gpbench/method_reg/GBLUP/__init__.py +5 -0
- gpbench/method_reg/GEFormer/GEFormer.py +164 -0
- gpbench/method_reg/GEFormer/GEFormer_Hyperparameters.py +106 -0
- gpbench/method_reg/GEFormer/__init__.py +5 -0
- gpbench/method_reg/GEFormer/gMLP.py +341 -0
- gpbench/method_reg/LightGBM/LightGBM.py +237 -0
- gpbench/method_reg/LightGBM/LightGBM_Hyperparameters.py +77 -0
- gpbench/method_reg/LightGBM/__init__.py +5 -0
- gpbench/method_reg/MVP/MVP.py +182 -0
- gpbench/method_reg/MVP/MVP_Hyperparameters.py +126 -0
- gpbench/method_reg/MVP/__init__.py +5 -0
- gpbench/method_reg/MVP/base_MVP.py +113 -0
- gpbench/method_reg/RF/RF_GPU.py +174 -0
- gpbench/method_reg/RF/RF_Hyperparameters.py +163 -0
- gpbench/method_reg/RF/__init__.py +5 -0
- gpbench/method_reg/SVC/SVC_GPU.py +194 -0
- gpbench/method_reg/SVC/SVC_Hyperparameters.py +107 -0
- gpbench/method_reg/SVC/__init__.py +5 -0
- gpbench/method_reg/SoyDNGP/AlexNet_206.py +185 -0
- gpbench/method_reg/SoyDNGP/SoyDNGP.py +179 -0
- gpbench/method_reg/SoyDNGP/SoyDNGP_Hyperparameters.py +105 -0
- gpbench/method_reg/SoyDNGP/__init__.py +5 -0
- gpbench/method_reg/XGBoost/XGboost_GPU.py +188 -0
- gpbench/method_reg/XGBoost/XGboost_Hyperparameters.py +167 -0
- gpbench/method_reg/XGBoost/__init__.py +5 -0
- gpbench/method_reg/__init__.py +55 -0
- gpbench/method_reg/rrBLUP/__init__.py +5 -0
- gpbench/method_reg/rrBLUP/rrBLUP.py +123 -0
- gpbench-1.0.0.dist-info/METADATA +379 -0
- gpbench-1.0.0.dist-info/RECORD +188 -0
- gpbench-1.0.0.dist-info/WHEEL +5 -0
- gpbench-1.0.0.dist-info/entry_points.txt +2 -0
- gpbench-1.0.0.dist-info/top_level.txt +3 -0
- tests/test_import.py +80 -0
- tests/test_method.py +232 -0

EIR/utils/transformer_models.py (the same utils tree ships under both gpbench/method_class/ and gpbench/method_reg/):

@@ -0,0 +1,546 @@
import inspect
import math
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, Literal, Optional, Sequence, Tuple, Type, Union

import torch
from torch import Tensor, nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.nn.functional import pad
from transformers import PretrainedConfig, PreTrainedModel

from .lcl_layers import _find_lcl_padding_needed
from .logging import get_logger

logger = get_logger(name=__name__, tqdm_compatible=True)

@dataclass
class SequenceModelConfig:
    """
    :param model_init_config:
        Configuration / arguments used to initialise the model.

    :param model_type:
        Which type of sequence model to use.

    :param embedding_dim:
        Which dimension to use for the embeddings. If ``None``, will automatically
        set this value based on the number of tokens and attention heads.

    :param position:
        Whether to encode the token position or use learnable position embeddings.

    :param position_dropout:
        Dropout for the positional encoding / embedding.

    :param window_size:
        If set to more than 0, will apply a sliding window of feature
        extraction over the input, meaning the model (e.g. transformer) will only
        see a part of the input at a time. Can be useful to avoid the O(n²)
        complexity of transformers, as it becomes O(window_size² * n_windows)
        instead.

    :param pool:
        Whether and how to pool (max / avg) the final feature maps before they are
        passed to the final fusion module / predictor. This pools over the
        sequence (i.e. time) dimension, so the resulting dimension is embedding_dim
        instead of sequence_length * embedding_dim. If using windowed / conv
        transformers, this becomes embedding_dim * number_of_chunks.

    :param pretrained_model:
        Specify whether the model type is assumed to be pretrained and from the
        PyTorch Image Models repository.

    :param freeze_pretrained_model:
        Whether to freeze the pretrained model weights.
    """

    model_init_config: Union["BasicTransformerFeatureExtractorModelConfig", Dict]

    # model_type: Literal["sequence-default"] | str = "sequence-default"
    model_type: Union[Literal["sequence-default"], str] = "sequence-default"

    embedding_dim: int = 64

    position: Literal["encode", "embed"] = "encode"
    position_dropout: float = 0.10
    window_size: int = 0
    pool: Union[Literal["avg"], Literal["max"], None] = None

    pretrained_model: bool = False
    freeze_pretrained_model: bool = False

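
# Worked example of the window/pool arithmetic documented above (a sketch
# with illustrative numbers, not values shipped in the package): pooling
# collapses the sequence dimension, so the flattened feature size per sample
# drops from (padded) sequence_length * embedding_dim to embedding_dim per
# window.
import math

_max_length, _window_size, _embedding_dim = 1000, 128, 64
_num_chunks = math.ceil(_max_length / _window_size)      # 8 windows
_padded = _num_chunks * _window_size                     # 1024 after padding
assert _padded * _embedding_dim == 65536                 # pool=None
assert _num_chunks * _embedding_dim == 512               # pool="avg" / "max"
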
class TransformerWrapperModel(nn.Module):
    def __init__(
        self,
        feature_extractor: Union["TransformerFeatureExtractor", nn.Module],
        model_config: SequenceModelConfig,
        embedding_dim: int,
        num_tokens: int,
        max_length: int,
        external_feature_extractor: bool,
        device: str,
        embeddings: Optional[nn.Embedding] = None,
        pre_computed_num_out_features: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.model_config = model_config
        self.embedding_dim = embedding_dim
        self.num_tokens = num_tokens
        self.max_length = max_length
        self.external_feature_extractor = external_feature_extractor
        self.pre_computed_num_out_features = pre_computed_num_out_features

        pos_repr_class = get_positional_representation_class(
            position_model_config=self.model_config.position
        )
        self.pos_representation = pos_repr_class(
            embedding_dim=self.embedding_dim,
            dropout=self.model_config.position_dropout,
            max_length=self.max_length,
        )

        # Reuse externally supplied embeddings if given, otherwise create
        # (and initialise) a fresh embedding table.
        if embeddings:
            self.embedding = embeddings
        else:
            self.embedding = nn.Embedding(
                num_embeddings=self.num_tokens, embedding_dim=self.embedding_dim
            )

        self.feature_extractor = feature_extractor

        if not embeddings:
            self.init_embedding_weights()

        (
            self.dynamic_extras,
            self.extract_features,
        ) = _get_transformer_wrapper_feature_extractor(
            feature_extractor=self.feature_extractor,
            window_size=self.model_config.window_size,
            max_length=self.max_length,
            embedding_dim=self.embedding_dim,
            pool=self.model_config.pool,
            device=device,
            external_feature_extractor=self.external_feature_extractor,
        )

    @property
    def num_out_features(self) -> int:
        if self.pre_computed_num_out_features:
            return self.pre_computed_num_out_features

        padding = self.dynamic_extras.get("padding", 0)
        length_with_padding = self.max_length + padding

        if self.model_config.pool in ("avg", "max"):
            num_chunks = 1
            if self.model_config.window_size:
                num_chunks = length_with_padding // self.model_config.window_size

            return self.embedding_dim * num_chunks

        return length_with_padding * self.embedding_dim

    def script_submodules_for_tracing(self):
        self.embedding = torch.jit.script(self.embedding)

    def init_embedding_weights(self) -> None:
        init_range = 0.1
        self.embedding.weight.data.uniform_(-init_range, init_range)

    def forward(self, input: Tensor) -> Tensor:
        # Scale the embedded input by sqrt(d) before adding positional
        # information, as in the original transformer formulation.
        out = input * math.sqrt(self.embedding_dim)
        out = self.pos_representation(out)
        out = self.extract_features(out)

        return out

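
# Minimal usage sketch for TransformerWrapperModel (illustrative sizes; this
# relies on TransformerFeatureExtractor and the config classes defined later
# in this file, so it assumes the module has been fully imported):
_basic_cfg = BasicTransformerFeatureExtractorModelConfig(num_heads=4, num_layers=1)
_seq_cfg = SequenceModelConfig(model_init_config=_basic_cfg, embedding_dim=32)
_extractor = TransformerFeatureExtractor(
    model_config=_basic_cfg, embedding_dim=32, num_tokens=100, max_length=64
)
_wrapper = TransformerWrapperModel(
    feature_extractor=_extractor,
    model_config=_seq_cfg,
    embedding_dim=32,
    num_tokens=100,
    max_length=64,
    external_feature_extractor=False,
    device="cpu",
)
_tokens = torch.randint(0, 100, (2, 64))                  # (batch, sequence)
_features = _wrapper(_wrapper.embedding(_tokens))         # embed, then wrap
assert _features.shape == (2, _wrapper.num_out_features)  # (2, 64 * 32)
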
def get_embedding_dim_for_sequence_model(
    embedding_dim: Union[None, int], num_tokens: int, num_heads: int
) -> int:
    if embedding_dim is None:
        auto_emb_dim = math.ceil(int(num_tokens**0.25) / num_heads) * num_heads
        logger.info(
            "Setting up automatic embedding dimension of %d based on %d "
            "tokens and %d attention heads.",
            auto_emb_dim,
            num_tokens,
            num_heads,
        )
        embedding_dim = auto_emb_dim

    return embedding_dim

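
# Worked example of the "auto" embedding dimension above (illustrative
# numbers): the fourth root of the vocabulary size, rounded up to the
# nearest multiple of the attention-head count.
assert math.ceil(int(10_000 ** 0.25) / 8) * 8 == 16      # 10k tokens, 8 heads
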
def _get_transformer_wrapper_feature_extractor(
    feature_extractor: Union["TransformerFeatureExtractor", nn.Module],
    external_feature_extractor: bool,
    window_size: int,
    embedding_dim: int,
    max_length: int,
    device: str,
    pool: Union[Literal["avg"], Literal["max"], None] = None,
) -> Tuple[Dict[str, int], Callable[[torch.Tensor], torch.Tensor]]:
    dynamic_extras = {"padding": 0}

    feature_extractor_forward = _get_feature_extractor_forward(
        is_hf_model=external_feature_extractor,
        feature_extractor=feature_extractor,
        input_length=window_size if window_size else max_length,
        embedding_size=embedding_dim,
        device=device,
        pool=pool,
    )

    if not window_size:
        extractor = partial(
            feature_extractor_forward, feature_extractor=feature_extractor
        )
    else:
        num_chunks = int(math.ceil(max_length / window_size))
        logger.debug(
            "Setting num chunks to %d as window size of %d and maximum sequence "
            "length of %d were passed in.",
            num_chunks,
            window_size,
            max_length,
        )

        padding = _find_lcl_padding_needed(
            input_size=max_length,
            kernel_size=window_size,
            num_chunks=num_chunks,
        )
        dynamic_extras["padding"] = padding

        extractor = partial(
            _conv_transformer_forward,
            feature_extractor=feature_extractor,
            feature_extractor_forward_callable=feature_extractor_forward,
            max_length=max_length,
            window_size=window_size,
            padding=padding,
        )

    return dynamic_extras, extractor


def _get_feature_extractor_forward(
    is_hf_model: bool,
    feature_extractor: Union[nn.Module, PreTrainedModel],
    input_length: int,
    embedding_size: int,
    device: str,
    pool: Union[Literal["avg"], Literal["max"], None] = None,
) -> Callable[
    [torch.Tensor, Union["TransformerFeatureExtractor", nn.Module]], torch.Tensor
]:
    if is_hf_model:
        return get_hf_transformer_forward(
            feature_extractor_=feature_extractor,
            input_length=input_length,
            embedding_dim=embedding_size,
            device=device,
            pool=pool,
        )
    return _get_simple_transformer_forward(pool=pool)


def _get_simple_transformer_forward(
    pool: Union[Literal["avg"], Literal["max"], None] = None,
):
    pooling_func = _get_sequence_pooling_func(pool=pool)

    def _simple_transformer_forward(
        input: torch.Tensor,
        feature_extractor: "TransformerFeatureExtractor",
    ) -> torch.Tensor:
        tensor_out = feature_extractor(input)
        tensor_pooled = pooling_func(input=tensor_out)
        final_out = tensor_pooled.flatten(1)

        return final_out

    return _simple_transformer_forward


def get_hf_transformer_forward(
    feature_extractor_: PreTrainedModel,
    input_length: int,
    embedding_dim: int,
    device: str,
    pool: Union[Literal["avg"], Literal["max"], None] = None,
):
    forward_argnames = inspect.getfullargspec(feature_extractor_.forward)[0]

    bound_kwargs = _build_transformer_forward_kwargs(
        forward_argnames=forward_argnames,
        config=feature_extractor_.config,
        input_length=input_length,
        embedding_dim=embedding_dim,
        device=device,
    )

    pooling_func = _get_sequence_pooling_func(pool=pool)

    def _hf_transformer_forward(
        input: torch.Tensor,
        feature_extractor: nn.Module,
        key: str = "last_hidden_state",
    ) -> torch.Tensor:
        hf_transformer_out = feature_extractor(inputs_embeds=input, **bound_kwargs)
        tensor_out = getattr(hf_transformer_out, key)
        tensor_pooled = pooling_func(input=tensor_out)
        final_out = tensor_pooled.flatten(1)
        return final_out

    return _hf_transformer_forward

def _get_sequence_pooling_func(
    pool: Union[Literal["avg"], Literal["max"], None]
) -> Callable:
    def _identity(input: torch.Tensor) -> torch.Tensor:
        return input

    def _max(input: torch.Tensor) -> torch.Tensor:
        return input.max(dim=1)[0]

    def _avg(input: torch.Tensor) -> torch.Tensor:
        return input.mean(dim=1)

    if pool is None:
        return _identity
    elif pool == "max":
        return _max
    elif pool == "avg":
        return _avg

    raise ValueError(f"Unknown pooling option: {pool}.")

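
# Quick shape check for the pooling helpers above (illustrative tensor):
# "max" and "avg" both reduce over the sequence dimension (dim=1), while
# pool=None passes the tensor through unchanged.
_x = torch.arange(12.0).reshape(2, 3, 2)                       # (batch, seq, dim)
assert _get_sequence_pooling_func(pool=None)(input=_x).shape == (2, 3, 2)
assert _get_sequence_pooling_func(pool="max")(input=_x).shape == (2, 2)
assert _get_sequence_pooling_func(pool="avg")(input=_x).shape == (2, 2)
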
def _build_transformer_forward_kwargs(
    forward_argnames: Sequence[str],
    config: PretrainedConfig,
    input_length: int,
    embedding_dim: int,
    device: str,
) -> Dict:
    """
    TODO: Deprecate.
    """
    kwargs = {}

    if "decoder_inputs_embeds" in forward_argnames:
        kwargs["decoder_inputs_embeds"] = torch.randn(1, input_length, embedding_dim)

    return kwargs

def _conv_transformer_forward(
    input: torch.Tensor,
    feature_extractor: "TransformerFeatureExtractor",
    feature_extractor_forward_callable: Callable[
        [torch.Tensor, Callable], torch.Tensor
    ],
    max_length: int,
    window_size: int,
    padding: int,
) -> torch.Tensor:
    # Pad the sequence dimension (from the left) so it splits evenly into
    # windows, then run the extractor window by window.
    out = pad(input=input, pad=[0, 0, padding, 0])
    total_length = max_length + padding

    aggregated_out = None
    for lower_index in range(0, total_length, window_size):
        upper_index = lower_index + window_size

        cur_input = out[:, lower_index:upper_index, :]
        cur_out = feature_extractor_forward_callable(cur_input, feature_extractor)

        if aggregated_out is None:
            aggregated_out = cur_out
        else:
            aggregated_out = torch.cat((aggregated_out, cur_out), dim=1)

    assert aggregated_out is not None
    return aggregated_out

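
# Rough bookkeeping sketch for the windowed forward above, assuming
# _find_lcl_padding_needed (imported from lcl_layers) pads the sequence up
# to num_chunks * window_size; numbers are illustrative:
_max_len, _win = 100, 16
_chunks = int(math.ceil(_max_len / _win))     # 7 windows
_pad = _chunks * _win - _max_len              # 12, under the assumption above
# _conv_transformer_forward then runs the extractor on each of the 7
# length-16 windows and concatenates the per-window outputs along dim=1.
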
@dataclass
class BasicTransformerFeatureExtractorModelConfig:
    """
    :param num_heads:
        The number of heads in the multi-head attention models.

    :param num_layers:
        The number of encoder blocks in the transformer model.

    :param dim_feedforward:
        The dimension of the feedforward network model.

    :param dropout:
        Dropout value to use in the encoder layers.
    """

    num_heads: int = 8
    num_layers: int = 2
    dim_feedforward: Union[int, Literal["auto"]] = "auto"
    dropout: float = 0.10


class TransformerFeatureExtractor(nn.Module):
    def __init__(
        self,
        model_config: BasicTransformerFeatureExtractorModelConfig,
        embedding_dim: int,
        num_tokens: int,
        max_length: int,
    ) -> None:
        super().__init__()

        self.model_config = model_config
        self.embedding_dim = embedding_dim
        self.num_tokens = num_tokens
        self.max_length = max_length

        dim_feed_forward = parse_dim_feedforward(
            dim_feedforward=self.model_config.dim_feedforward,
            embedding_dim=self.embedding_dim,
        )

        encoder_layer_base = TransformerEncoderLayer(
            d_model=self.embedding_dim,
            nhead=self.model_config.num_heads,
            dim_feedforward=dim_feed_forward,
            dropout=self.model_config.dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.transformer_encoder = TransformerEncoder(
            encoder_layer=encoder_layer_base, num_layers=self.model_config.num_layers
        )

    @property
    def num_out_features(self) -> int:
        return self.max_length * self.embedding_dim

    def forward(self, input: Tensor) -> Tensor:
        out = self.transformer_encoder(input)
        return out


def parse_dim_feedforward(
    dim_feedforward: Union[int, Literal["auto"]], embedding_dim: int
) -> int:
    if dim_feedforward == "auto":
        dim_feedforward = embedding_dim * 4
        logger.info(
            "Setting dim_feedforward to %d based on embedding_dim=%d and the "
            "'auto' option.",
            dim_feedforward,
            embedding_dim,
        )

    return dim_feedforward

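
# With the "auto" option above, the feed-forward width follows the common
# 4x convention (illustrative check):
assert parse_dim_feedforward(dim_feedforward="auto", embedding_dim=64) == 256
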
class SequenceOutputTransformerFeatureExtractor(TransformerFeatureExtractor):
    def __init__(
        self,
        model_config: BasicTransformerFeatureExtractorModelConfig,
        embedding_dim: int,
        num_tokens: int,
        max_length: int,
    ):
        super().__init__(
            model_config=model_config,
            embedding_dim=embedding_dim,
            num_tokens=num_tokens,
            max_length=max_length,
        )

        # Additive causal mask: -inf above the diagonal blocks attention to
        # future positions.
        mask = torch.triu(
            torch.ones(self.max_length, self.max_length) * float("-inf"), diagonal=1
        )
        self.register_buffer("mask", mask)

    def forward(self, input: Tensor) -> Tensor:
        out = self.transformer_encoder(input, mask=self.mask)
        return out

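
# The buffer registered above is the standard additive causal mask; for
# max_length=4 it looks like this (0 = may attend, -inf = blocked), so
# position i can only attend to positions 0..i:
_mask = torch.triu(torch.ones(4, 4) * float("-inf"), diagonal=1)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])
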
def get_positional_representation_class(
    position_model_config: Literal["encode", "embed"]
) -> Union[Type["PositionalEncoding"], Type["PositionalEmbedding"]]:
    if position_model_config == "encode":
        return PositionalEncoding
    elif position_model_config == "embed":
        return PositionalEmbedding
    raise ValueError(
        "Unknown value for positional representation. "
        f"Expected 'encode' or 'embed' but got '{position_model_config}'."
    )

class PositionalEncoding(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        max_length: int,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.max_length = max_length

        # Fixed sinusoidal position encodings, precomputed for max_length
        # positions and registered as a (non-trainable) buffer.
        position = torch.arange(max_length).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embedding_dim, 2) * (-math.log(10000.0) / embedding_dim)
        )
        pe: torch.Tensor = torch.zeros(1, max_length, embedding_dim)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

        self.pe: Tensor

    def forward(self, x: Tensor) -> Tensor:
        x = x + self.pe[:, : self.max_length, :]
        return self.dropout(x)

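
# PositionalEncoding above implements the fixed sinusoidal scheme from
# "Attention Is All You Need": PE(pos, 2i) = sin(pos / 10000^(2i/d)) and
# PE(pos, 2i+1) = cos(pos / 10000^(2i/d)). A small spot check (illustrative):
_enc = PositionalEncoding(embedding_dim=8, max_length=16, dropout=0.0)
assert torch.isclose(
    _enc.pe[0, 3, 4], torch.sin(torch.tensor(3 * 10000.0 ** (-4 / 8)))
)
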
class PositionalEmbedding(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        max_length: int,
        dropout: float = 0.1,
        zero_init: bool = False,
    ) -> None:
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.max_length = max_length

        if zero_init:
            self.embedding = torch.nn.Parameter(
                data=torch.zeros(1, max_length, embedding_dim),
                requires_grad=True,
            )
        else:
            self.embedding = torch.nn.Parameter(
                data=torch.randn(1, max_length, embedding_dim),
                requires_grad=True,
            )

    def forward(self, x: Tensor) -> Tensor:
        x = x + self.embedding[:, : self.max_length, :]
        return self.dropout(x)

gpbench/method_reg/ElasticNet/ElasticNet.py:

@@ -0,0 +1,123 @@
import os
import time
import psutil
import argparse
import random
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold
from sklearn.linear_model import ElasticNet
from scipy.stats import pearsonr
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from . import ElasticNet_he


def parse_args():
    parser = argparse.ArgumentParser(description="Argument parser")
    parser.add_argument('--methods', type=str, default='ElasticNet/', help='Method name')
    parser.add_argument('--species', type=str, default='', help='Dataset name')
    parser.add_argument('--phe', type=str, default='', help='Phenotype name')
    parser.add_argument('--data_dir', type=str, default='../../data/')
    parser.add_argument('--result_dir', type=str, default='result/')

    parser.add_argument('--alpha', type=float, default=0.5, help='Regularization strength')
    parser.add_argument('--l1_ratio', type=float, default=0.5, help='L1 ratio (0=Ridge, 1=Lasso)')
    args = parser.parse_args()
    return args

def load_data(args):
    # genotype.npz / phenotype.npz hold positional arrays (arr_0, arr_1);
    # phenotype.npz additionally carries the trait names.
    xData = np.load(os.path.join(args.data_dir, args.species, 'genotype.npz'))["arr_0"]
    phe = np.load(os.path.join(args.data_dir, args.species, 'phenotype.npz'))
    yData = phe["arr_0"]
    names = phe["arr_1"]

    nsample = xData.shape[0]
    nsnp = xData.shape[1]
    print("Number of samples: ", nsample)
    print("Number of SNPs: ", nsnp)
    return xData, yData, nsample, nsnp, names

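
# The loader above implies this on-disk layout per species directory (a
# hedged sketch: the arr_0/arr_1 keys come from positional np.savez
# arguments; shapes and trait names here are purely illustrative):
_geno = np.zeros((500, 10_000), dtype=np.float32)     # samples x SNPs
_phe_vals = np.zeros((500, 2))                        # samples x traits
_names = np.array(["Trait1/", "Trait2/"])             # hypothetical trait names
np.savez("genotype.npz", _geno)                       # stored as arr_0
np.savez("phenotype.npz", _phe_vals, _names)          # stored as arr_0, arr_1
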
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)


def run_nested_cv(args, data, label):
    result_dir = os.path.join(args.result_dir, args.methods + args.species + args.phe)
    os.makedirs(result_dir, exist_ok=True)
    print("Starting 10-fold cross-validation with ElasticNet (sklearn)...")

    kf = KFold(n_splits=10, shuffle=True, random_state=42)

    all_mse, all_mae, all_r2, all_pcc = [], [], [], []
    time_start = time.time()

    for fold, (train_index, test_index) in enumerate(kf.split(data)):
        print(f"Running fold {fold}...")
        process = psutil.Process(os.getpid())
        fold_start_time = time.time()

        x_train = data[train_index]
        x_test = data[test_index]
        y_train = label[train_index]
        y_test = label[test_index]

        model = ElasticNet(alpha=args.alpha, l1_ratio=args.l1_ratio, max_iter=1000, random_state=42)
        model.fit(x_train, y_train)
        y_test_preds = model.predict(x_test)

        # pearsonr returns (correlation, p-value); only the correlation is kept.
        pcc, _ = pearsonr(y_test, y_test_preds)
        mse = mean_squared_error(y_test, y_test_preds)
        r2 = r2_score(y_test, y_test_preds)
        mae = mean_absolute_error(y_test, y_test_preds)

        all_mse.append(mse)
        all_r2.append(r2)
        all_mae.append(mae)
        all_pcc.append(pcc)

        fold_time = time.time() - fold_start_time
        fold_cpu_mem = process.memory_info().rss / 1024**2

        print(f'Fold {fold}: Corr={pcc:.4f}, MAE={mae:.4f}, MSE={mse:.4f}, R2={r2:.4f}, '
              f'Time={fold_time:.2f}s, CPU={fold_cpu_mem:.2f}MB')

        results_df = pd.DataFrame({'Y_test': y_test, 'Y_pred': y_test_preds})
        results_df.to_csv(os.path.join(result_dir, f"fold{fold}.csv"), index=False)

        del model, y_test_preds, x_train, x_test, y_train, y_test

    print("\n===== Cross-validation summary =====")
    print("Using sklearn ElasticNet")
    print(f"Average PCC: {np.mean(all_pcc):.4f} ± {np.std(all_pcc):.4f}")
    print(f"Average MAE: {np.mean(all_mae):.4f} ± {np.std(all_mae):.4f}")
    print(f"Average MSE: {np.mean(all_mse):.4f} ± {np.std(all_mse):.4f}")
    print(f"Average R2 : {np.mean(all_r2):.4f} ± {np.std(all_r2):.4f}")
    print(f"Total Time: {time.time() - time_start:.2f}s")

def ElasticNet_reg():
    set_seed(42)
    args = parse_args()
    all_species = ['Cotton/']

    for i in range(len(all_species)):
        args.species = all_species[i]
        X, Y, nsamples, nsnp, names = load_data(args)
        for j in range(len(names)):
            args.phe = names[j]
            print("starting run " + args.methods + args.species + args.phe)
            label = Y[:, j]
            # Impute missing phenotype values with the trait mean.
            label = np.nan_to_num(label, nan=np.nanmean(label))

            # Tune alpha / l1_ratio first, then evaluate with 10-fold CV.
            best_params = ElasticNet_he.Hyperparameter(X, label)
            args.alpha = best_params['alpha']
            args.l1_ratio = best_params['l1_ratio']

            start_time = time.time()
            run_nested_cv(args, data=X, label=label)
            elapsed_time = time.time() - start_time
            print(f"running time: {elapsed_time:.2f} s")
            print("finished successfully")


if __name__ == "__main__":
    ElasticNet_reg()
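
As a quick way to exercise run_nested_cv without the on-disk dataset, a minimal sketch that bypasses argparse with a plain namespace (all values illustrative; assumes the package installs with the layout listed above):

from argparse import Namespace
import numpy as np
from gpbench.method_reg.ElasticNet.ElasticNet import run_nested_cv

args = Namespace(methods="ElasticNet/", species="Demo/", phe="Trait",
                 data_dir="unused/", result_dir="result/",
                 alpha=0.1, l1_ratio=0.5)
X = np.random.randn(100, 200)
y = 0.5 * X[:, 0] + 0.1 * np.random.randn(100)
run_nested_cv(args, data=X, label=y)   # prints per-fold and summary metrics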