python-katlas 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- katlas/__init__.py +1 -0
- katlas/_modidx.py +110 -0
- katlas/core.py +769 -0
- katlas/dl.py +355 -0
- katlas/feature.py +290 -0
- katlas/imports.py +7 -0
- katlas/plot.py +663 -0
- katlas/train.py +231 -0
- python_katlas-0.0.1.dist-info/LICENSE +201 -0
- python_katlas-0.0.1.dist-info/METADATA +402 -0
- python_katlas-0.0.1.dist-info/RECORD +14 -0
- python_katlas-0.0.1.dist-info/WHEEL +5 -0
- python_katlas-0.0.1.dist-info/entry_points.txt +2 -0
- python_katlas-0.0.1.dist-info/top_level.txt +1 -0
katlas/dl.py
ADDED
|
@@ -0,0 +1,355 @@
|
|
|
1
|
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/04_DL.ipynb.
|
|
2
|
+
|
|
3
|
+
# %% auto 0
|
|
4
|
+
__all__ = ['def_device', 'seed_everything', 'GeneralDataset', 'get_sampler', 'MLP_1', 'CNN1D_1', 'init_weights', 'lin_wn',
|
|
5
|
+
'conv_wn', 'CNN1D_2', 'train_dl', 'train_dl_cv', 'predict_dl']
|
|
6
|
+
|
|
7
|
+
# %% ../nbs/04_DL.ipynb 4
|
|
8
|
+
from fastbook import *
|
|
9
|
+
import fastcore.all as fc,torch.nn.init as init
|
|
10
|
+
from fastai.callback.training import GradientClip
|
|
11
|
+
from torch.utils.data import WeightedRandomSampler
|
|
12
|
+
|
|
13
|
+
# katlas
|
|
14
|
+
from .core import Data
|
|
15
|
+
from .feature import *
|
|
16
|
+
from .train import *
|
|
17
|
+
|
|
18
|
+
# sklearn
|
|
19
|
+
from sklearn.model_selection import *
|
|
20
|
+
from sklearn.metrics import mean_squared_error
|
|
21
|
+
from scipy.stats import spearmanr,pearsonr
|
|
22
|
+
|
|
23
|
+
# %% ../nbs/04_DL.ipynb 6
|
|
24
|
+
def seed_everything(seed=123):
    """Seed every random number generator used for training.

    Seeds Python's `random`, NumPy, and torch (CPU and the current CUDA
    device), exports `PYTHONHASHSEED`, and switches cuDNN to deterministic,
    non-benchmarking mode.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Every seeding function takes the same single integer argument.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    # Trade cuDNN autotuning for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
|
32
|
+
|
|
33
|
+
# %% ../nbs/04_DL.ipynb 8
|
|
34
|
+
# Default device name: prefer Apple MPS, then CUDA, and fall back to the CPU.
if torch.backends.mps.is_available():
    def_device = 'mps'
elif torch.cuda.is_available():
    def_device = 'cuda'
else:
    def_device = 'cpu'
|
|
35
|
+
|
|
36
|
+
# %% ../nbs/04_DL.ipynb 13
|
|
37
|
+
class GeneralDataset:
    """Map-style dataset over the rows of a dataframe.

    Each item is the float32 feature tensor of one row and, when
    `target_col` is given, the float32 target tensor of the same row.
    Without `target_col` the dataset is in test mode and items are the
    feature tensor only.
    """

    def __init__(self,
                 df, # a dataframe of values
                 feat_col, # feature columns
                 target_col=None # Will return test set for prediction if target col is None
                 ):
        "A general dataset that can be applied to any dataframe"

        # Test mode (no targets available) is signalled by the absent target column.
        self.test = target_col is None

        self.X = df[feat_col].values
        self.y = None if self.test else df[target_col].values

        self.len = df.shape[0]

    @staticmethod
    def _as_float_tensor(values):
        "Copy `values` (a numpy scalar or array) into a new float32 tensor with at least one dimension."
        # A single column *name* (not a list) makes the indexed item a numpy
        # scalar, which `torch.Tensor(...)` cannot convert; promote it to 1-D.
        return torch.tensor(np.atleast_1d(values), dtype=torch.float32)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        X = self._as_float_tensor(self.X[index])
        if self.test:
            return X
        return X, self._as_float_tensor(self.y[index])
|
|
62
|
+
|
|
63
|
+
# %% ../nbs/04_DL.ipynb 17
|
|
64
|
+
def get_sampler(info,col):
    """Build a weighted sampler that favours under-represented groups.

    Every row of `info` is weighted by the inverse of the number of rows
    sharing its value in `col` (floored at 0.01). The returned
    `WeightedRandomSampler` draws `len(info)` indices with replacement.
    """
    # How many rows fall into each distinct value of `col`.
    counts_per_group = info[col].value_counts()

    # Look the count up for every row, then invert it: rare groups weigh more.
    per_row_count = counts_per_group[info[col]]
    inverse_freq = (1. / per_row_count).to_numpy()

    # Keep a minimum weight so very large groups are still sampled.
    weight_tensor = torch.clamp_min(torch.from_numpy(inverse_freq), 0.01)

    return WeightedRandomSampler(weight_tensor, len(weight_tensor), replacement=True)
|
|
82
|
+
|
|
83
|
+
# %% ../nbs/04_DL.ipynb 23
|
|
84
|
+
def MLP_1(num_features,
          num_targets,
          hidden_units = [512, 218],
          dp = 0.2):
    """Build a multi-layer perceptron as an `nn.Sequential`.

    Every entry of `hidden_units` adds a Linear -> BatchNorm1d -> Dropout(dp)
    -> PReLU block; a final Linear maps the last hidden width to
    `num_targets`. `hidden_units` is only read, never modified.
    """
    # Input width of each hidden block: the raw features first, then the
    # width of the preceding hidden block.
    input_widths = [num_features, *hidden_units[:-1]]

    blocks = []
    for n_in, n_out in zip(input_widths, hidden_units):
        blocks += [
            nn.Linear(n_in, n_out),
            nn.BatchNorm1d(n_out),
            nn.Dropout(dp),
            nn.PReLU(),
        ]

    # Output layer (no activation).
    blocks.append(nn.Linear(hidden_units[-1], num_targets))

    return nn.Sequential(*blocks)
|
|
112
|
+
|
|
113
|
+
# %% ../nbs/04_DL.ipynb 29
|
|
114
|
+
class CNN1D_1(Module):
    """Small 1D CNN: two conv + max-pool stages followed by two linear layers.

    The input is a `(bs, num_features)` tensor that is treated as a
    single-channel sequence of length `num_features`.
    """
    # NOTE(review): there is no explicit `super().__init__()` call here;
    # presumably `Module` is the fastai base class that performs it
    # automatically - confirm before switching the base to `nn.Module`.

    def __init__(self,
                 num_features, # length of the input feature vector; determines the size of the first linear layer
                 num_targets):

        self.conv1 = nn.Conv1d(in_channels=1, out_channels=3, kernel_size=3, dilation=1, padding=1, stride=1)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(in_channels=3, out_channels=8, kernel_size=3, dilation=1, padding=1, stride=1)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.flatten = Flatten()

        # The convolutions keep the sequence length (kernel 3, padding 1) and
        # each pooling layer halves it rounding down, so the flattened size is
        # 8 channels times the twice-floored length. The previous
        # `int(8 * num_features/4)` was only right for multiples of 4.
        pooled_len = (num_features // 2) // 2
        self.fc1 = nn.Linear(in_features=8 * pooled_len, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=num_targets)

    def forward(self, x):
        x = x.unsqueeze(1) # need shape (bs, 1, num_features) for CNN
        x = self.pool1(nn.functional.relu(self.conv1(x)))
        x = self.pool2(nn.functional.relu(self.conv2(x)))
        x = self.flatten(x)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x
|
|
137
|
+
|
|
138
|
+
# %% ../nbs/04_DL.ipynb 33
|
|
139
|
+
def init_weights(m, leaky=0.):
    """Apply Kaiming-normal initialisation to the weight of a conv layer.

    Modules other than `nn.Conv1d`/`nn.Conv2d`/`nn.Conv3d` are left
    untouched. `leaky` is forwarded as the `a` argument of
    `kaiming_normal_`.
    """
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
    if not isinstance(m, conv_types):
        return
    init.kaiming_normal_(m.weight, a=leaky)
|
|
142
|
+
|
|
143
|
+
# %% ../nbs/04_DL.ipynb 34
|
|
144
|
+
def lin_wn(ni,nf,dp=0.1,act=nn.SiLU):
    """Weight-normalised linear block.

    Returns `nn.Sequential(BatchNorm1d(ni), Dropout(dp), weight_norm(Linear(ni, nf)))`
    followed by an instance of `act` when `act` is truthy.
    """
    parts = [
        nn.BatchNorm1d(ni),
        nn.Dropout(dp),
        nn.utils.weight_norm(nn.Linear(ni, nf)),
    ]
    # `act` is a module class (or None to skip the activation).
    if act:
        parts.append(act())
    return nn.Sequential(*parts)
|
|
152
|
+
|
|
153
|
+
# %% ../nbs/04_DL.ipynb 35
|
|
154
|
+
def conv_wn(ni, nf, ks=3, stride=1, padding=1, dp=0.1,act=nn.ReLU):
    """Weight-normalised 1D convolution block.

    Returns `nn.Sequential(BatchNorm1d(ni), Dropout(dp),
    weight_norm(Conv1d(ni, nf, ks, stride, padding)))` followed by an
    instance of `act` when `act` is truthy.
    """
    parts = [
        nn.BatchNorm1d(ni),
        nn.Dropout(dp),
        nn.utils.weight_norm(nn.Conv1d(ni, nf, ks, stride, padding)),
    ]
    # `act` is a module class (or None to skip the activation).
    if act:
        parts.append(act())
    return nn.Sequential(*parts)
|
|
162
|
+
|
|
163
|
+
# %% ../nbs/04_DL.ipynb 36
|
|
164
|
+
class CNN1D_2(nn.Module):
    """1D CNN that first widens the input with a linear layer.

    The `ni` input features are expanded to `256 * amp_scale` values,
    reshaped to `(bs, 256, amp_scale)`, passed through two convolutional
    stages joined by a multiplicative skip connection, and reduced to `nf`
    outputs by a pooled linear head.
    """

    def __init__(self, ni, nf, amp_scale = 16):
        super().__init__()

        c_in, c_mid, c_out = 256, 512, 512
        widened = c_in * amp_scale                    # width after the first linear layer
        pooled_len = widened // (c_in * 2)            # sequence length after adaptive pooling
        flat_size = (widened // (c_in * 4)) * c_out   # flattened width entering the head

        self.lin = lin_wn(ni, widened)

        # (bs, 256, amp_scale); i.e. (bs, 256, 16) with the default amp_scale
        self.view = View(-1, c_in, amp_scale)

        self.conv1 = nn.Sequential(
            conv_wn(c_in, c_mid, ks=5, stride=1, padding=2, dp=0.1),
            nn.AdaptiveAvgPool1d(output_size = pooled_len),
            conv_wn(c_mid, c_mid, ks=3, stride=1, padding=1, dp=0.1))

        self.conv2 = nn.Sequential(
            conv_wn(c_mid, c_mid, ks=3, stride=1, padding=1, dp=0.3),
            conv_wn(c_mid, c_out, ks=5, stride=1, padding=2, dp=0.2))

        self.head = nn.Sequential(
            nn.MaxPool1d(kernel_size=4, stride=2, padding=1),
            nn.Flatten(),
            lin_wn(flat_size, nf, act=None) )

    def forward(self, x):
        # Widen the features (4096 values with the default amp_scale) and
        # reshape them into a multi-channel sequence for the convolutions.
        x = self.view(self.lin(x))

        # Output of the first stage doubles as the skip branch.
        skip = self.conv1(x)

        # Multiplicative skip connection around the second stage.
        x = self.conv2(skip) * skip

        # Final block: pool, flatten and project to the targets.
        return self.head(x)
|
|
212
|
+
|
|
213
|
+
# %% ../nbs/04_DL.ipynb 40
|
|
214
|
+
def train_dl(df,
             feat_col,
             target_col,
             split, # tuple of numpy array for split index
             model_func, # function to get pytorch model
             n_epoch = 4, # number of epochs
             bs = 32, # batch size
             lr = 1e-2, # will be useless if lr_find is True
             loss = mse, # loss function
             save = None, # models/{save}.pth
             sampler = None,
             lr_find=False, # if true, will use lr from lr_find
             ):
    """Train a model on one train/validation split and predict the validation rows.

    `split[0]` / `split[1]` are index labels of `df` (used with `df.loc`) for
    the training and validation rows. `model_func` is called without
    arguments to create the model. Returns `(target, pred)`: two dataframes
    indexed like the validation rows with `target_col` as columns.
    """

    train = df.loc[split[0]]
    valid = df.loc[split[1]]

    train_ds = GeneralDataset(train, feat_col, target_col)
    valid_ds = GeneralDataset(valid, feat_col, target_col)

    n_workers = fc.defaults.cpus

    if sampler is not None:
        # NOTE(review): `DataLoader` is whatever `from fastbook import *`
        # exports - confirm that it honours the `sampler=` argument.
        train_loader = DataLoader(train_ds, batch_size=bs, sampler=sampler,num_workers=n_workers)
        # The sampler is built for (and sized to) the training rows, so the
        # validation loader must not use it: it iterates the validation set
        # once, in order, which is also what the `valid.index` alignment of
        # the predictions below relies on.
        valid_loader = DataLoader(valid_ds, batch_size=bs, num_workers=n_workers)

        dls = DataLoaders(train_loader, valid_loader)
    else:
        dls = DataLoaders.from_dsets(train_ds, valid_ds, bs=bs, num_workers=n_workers)

    model = model_func()

    learn = Learner(dls.to(def_device), model.to(def_device), loss,
                    metrics= [PearsonCorrCoef(),SpearmanCorrCoef()],
                    cbs = [GradientClip(1.0)] # clip gradients to stabilise training
                    )

    if lr_find:
        # get learning rate
        suggestion = learn.lr_find()
        plt.show()
        plt.close()
        print(suggestion)
        # `lr_find` returns a named tuple of suggestions rather than a bare
        # number; take the `valley` suggestion when present, otherwise the
        # first entry, so that a plain float reaches `fit_one_cycle`.
        if isinstance(suggestion, tuple):
            lr = getattr(suggestion, 'valley', suggestion[0])
        else:
            lr = suggestion

    print('lr in training is', lr)
    learn.fit_one_cycle(n_epoch,lr)

    if save is not None:
        learn.save(save)

    pred,target = learn.get_preds()

    pred = pd.DataFrame(pred.detach().cpu().numpy(),index=valid.index,columns=target_col)
    target = pd.DataFrame(target.detach().cpu().numpy(),index=valid.index,columns=target_col)

    return target, pred
|
|
275
|
+
|
|
276
|
+
# %% ../nbs/04_DL.ipynb 45
|
|
277
|
+
@fc.delegates(train_dl)
def train_dl_cv(df,
                feat_col,
                target_col,
                splits, # list of tuples
                model_func, # zero-argument callable returning a new model, e.g. lambda: MLP_1(num_feat, num_target)
                save:str=None,
                **kwargs
                ):
    """Cross-validated training: run `train_dl` once per split.

    Returns `(oof, metrics)`: the out-of-fold predictions of all folds
    concatenated and sorted by index, and a dataframe with one row per fold
    holding `fold`, `mse` and `pearson_avg`. When `save` is given, each
    fold's model is saved as `{save}_fold{fold}`. Extra keyword arguments
    are forwarded to `train_dl`.
    """

    fold_predictions = []
    fold_rows = []

    for fold,split in enumerate(splits):

        print(f'------fold{fold}------')

        # One checkpoint name per fold, or no saving at all.
        fname = f'{save}_fold{fold}' if save is not None else None

        # train model
        target, pred = train_dl(df,feat_col,target_col, split, model_func ,save=fname,**kwargs)

        # Score this fold; the third value returned by `score_each` is not used.
        fold_mse, fold_pearson, _ = score_each(target,pred)

        fold_rows.append({
            'fold': fold,
            'mse': fold_mse,
            'pearson_avg': fold_pearson
        })

        fold_predictions.append(pred)

    # Stitch the out-of-fold predictions back together in index order.
    oof = pd.concat(fold_predictions).sort_index()

    metrics = pd.DataFrame(fold_rows)

    return oof, metrics
|
|
325
|
+
|
|
326
|
+
# %% ../nbs/04_DL.ipynb 53
|
|
327
|
+
def predict_dl(df,
               feat_col,
               target_col,
               model, # model architecture
               model_pth, # only name, not with .pth
               ):
    """Predict dataframe given a deep learning model

    Loads the weights saved under `model_pth` into `model`, runs the model
    in eval mode over the `feat_col` columns of `df` in batches of 512, and
    returns the outputs as a dataframe indexed like `df` with `target_col`
    as columns.
    """

    test_dset = GeneralDataset(df,feat_col)
    test_dl = DataLoader(test_dset,bs=512)

    # The Learner is used only to load the saved weights, hence the
    # placeholder dataloaders and loss function.
    learn = Learner(None, model.to(def_device), loss_func=1)
    learn.load(model_pth)

    learn.model.eval()

    preds = []
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for data in test_dl:
            inputs = data.to(def_device)
            outputs = learn.model(inputs)

            preds.append(outputs.detach().cpu().numpy())

    preds = np.concatenate(preds)
    preds = pd.DataFrame(preds,index=df.index,columns=target_col)

    return preds
|
katlas/feature.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
1
|
+
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_feature.ipynb.
|
|
2
|
+
|
|
3
|
+
# %% auto 0
|
|
4
|
+
__all__ = ['get_rdkit', 'get_morgan', 'get_esm', 'get_t5', 'get_t5_bfd', 'reduce_feature', 'remove_hi_corr', 'preprocess']
|
|
5
|
+
|
|
6
|
+
# %% ../nbs/01_feature.ipynb 4
|
|
7
|
+
from fastbook import *
|
|
8
|
+
import torch,re,joblib,gc,esm
|
|
9
|
+
from tqdm.notebook import tqdm; tqdm.pandas()
|
|
10
|
+
from .core import Data
|
|
11
|
+
|
|
12
|
+
# Rdkit
|
|
13
|
+
from rdkit import Chem
|
|
14
|
+
from rdkit.ML.Descriptors import MoleculeDescriptors
|
|
15
|
+
from rdkit.Chem import Draw,Descriptors,AllChem
|
|
16
|
+
|
|
17
|
+
# Models
|
|
18
|
+
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
|
|
19
|
+
from fairscale.nn.wrap import enable_wrap, wrap
|
|
20
|
+
from transformers import T5Tokenizer, T5EncoderModel, T5Model
|
|
21
|
+
|
|
22
|
+
# Dimension Reduction
|
|
23
|
+
from sklearn import set_config
|
|
24
|
+
from sklearn.decomposition import PCA
|
|
25
|
+
from sklearn.manifold import TSNE
|
|
26
|
+
from sklearn.preprocessing import StandardScaler
|
|
27
|
+
from umap.umap_ import UMAP
|
|
28
|
+
|
|
29
|
+
set_config(transform_output="pandas")
|
|
30
|
+
|
|
31
|
+
# %% ../nbs/01_feature.ipynb 7
|
|
32
|
+
def get_rdkit(df: pd.DataFrame, # a dataframe that contains smiles
              col:str = "SMILES", # colname of smile
              normalize: bool = True, # normalize features using StandardScaler()
              ):
    """Compute every descriptor in `rdkit.Chem.Descriptors.descList` for the SMILES in `df[col]`.

    Returns a dataframe indexed like `df` with one column per descriptor;
    when `normalize` is true the result of `StandardScaler().fit_transform`
    on that dataframe is returned instead.
    """
    molecules = [Chem.MolFromSmiles(smiles) for smiles in df[col]]

    # descList holds (name, function) pairs; only the names are needed here.
    names = [entry[0] for entry in Descriptors.descList]
    calculator = MoleculeDescriptors.MolecularDescriptorCalculator(names)

    rows = [calculator.CalcDescriptors(molecule) for molecule in molecules]
    features = pd.DataFrame(np.stack(rows), index=df.index, columns=names)

    if not normalize:
        return features
    return StandardScaler().fit_transform(features)
|
|
49
|
+
|
|
50
|
+
# %% ../nbs/01_feature.ipynb 11
|
|
51
|
+
def get_morgan(df: pd.DataFrame, # a dataframe that contains smiles
               col: str = "SMILES", # colname of smile
               radius=3
               ):
    """Compute 2048-bit Morgan fingerprints for the SMILES in `df[col]`.

    Returns a dataframe indexed like `df` whose columns are named
    `morgan_<position>`.
    """
    molecules = [Chem.MolFromSmiles(smiles) for smiles in df[col]]
    fingerprints = [
        AllChem.GetMorganFingerprintAsBitVect(molecule, radius=radius, nBits=2048)
        for molecule in molecules
    ]

    bits = pd.DataFrame(np.array(fingerprints), index=df.index)
    # Turn the positional column labels into "morgan_<position>" strings.
    return bits.add_prefix("morgan_")
|
|
61
|
+
|
|
62
|
+
# %% ../nbs/01_feature.ipynb 15
|
|
63
|
+
def get_esm(df:pd.DataFrame, # a dataframe that contains amino acid sequence
            col: str = 'sequence', # colname of amino acid sequence
            model_name: str = "esm2_t33_650M_UR50D", # Name of the ESM model to use for the embeddings.
            ):
    """Extract ESM-2 embeddings from protein sequences in a dataframe.

    Each sequence in `df[col]` is embedded with the model `model_name`; the
    representation of the layer whose number is encoded in the model name
    (`_t<N>_`) is averaged over the sequence positions. Returns a dataframe
    indexed like `df` with columns named `esm_<i>`.

    Raises `ValueError` when the layer number cannot be read from
    `model_name`. The tokens are moved with `.cuda()`, so a CUDA device is
    required.
    """

    # Initialize distributed world with world_size 1
    if not torch.distributed.is_initialized():
        url = "tcp://localhost:23456"
        torch.distributed.init_process_group(backend="nccl", init_method=url, world_size=1, rank=0)

    # get number of repr layers
    match = re.search(r'_t(\d+)_', model_name)
    if match is None:
        # Fail with a clear message instead of an AttributeError on `None.group`.
        raise ValueError(
            f"Cannot read the number of layers from model name {model_name!r}; "
            "expected a name such as 'esm2_t33_650M_UR50D'."
        )
    number = int(match.group(1))
    print(f"repr_layers number for model {model_name} is {number}.")
    print("You can also choose other esm2 models:",
          "\nesm2_t48_15B_UR50D\nesm2_t36_3B_UR50D\nesm2_t33_650M_UR50D\nesm2_t30_150M_UR50D\nesm2_t12_35M_UR50D\nesm2_t6_8M_UR50D\n")

    # Download model data from the hub
    model_data, regression_data = esm.pretrained._download_model_and_regression_data(model_name)

    # Initialize the model with FSDP wrapper
    fsdp_params = dict(
        mixed_precision=True,
        flatten_parameters=True,
        state_dict_device=torch.device("cpu"),  # reduce GPU mem usage
        cpu_offload=True,  # enable cpu offloading
    )

    with enable_wrap(wrapper_cls=FSDP, **fsdp_params):
        model, vocab = esm.pretrained.load_model_and_alphabet_core(
            model_name, model_data, regression_data
        )
        batch_converter = vocab.get_batch_converter()
        model.eval()

        # Wrap each layer in FSDP separately
        for name, child in model.named_children():
            if name == "layers":
                for layer_name, layer in child.named_children():
                    wrapped_layer = wrap(layer)
                    setattr(child, layer_name, wrapped_layer)
        model = wrap(model)

    # Define the feature extraction function
    def esm_embeddings(r, colname=col):
        "Embed the sequence of one dataframe row and return its mean representation as a numpy array."
        data = [('protein', r[colname])]
        labels, strs, tokens = batch_converter(data)
        with torch.no_grad():
            results = model(tokens.cuda(), repr_layers=[number], return_contacts=False)
        rpr = results["representations"][number].squeeze()
        # Skip position 0 and average over the next len(sequence) positions.
        rpr = rpr[1 : len(r[colname]) + 1].mean(0).detach().cpu().numpy()

        del results, labels, strs, tokens, data  # especially need to delete those on cuda: tokens, results
        gc.collect()

        return rpr

    # Apply the feature extraction function to each row in the DataFrame
    series = df.progress_apply(esm_embeddings, axis=1)
    df_feature = pd.DataFrame(series.tolist(), index=df.index)
    df_feature.columns = 'esm_' + df_feature.columns.astype(str)

    return df_feature
|
|
128
|
+
|
|
129
|
+
# %% ../nbs/01_feature.ipynb 19
|
|
130
|
+
def get_t5(df: pd.DataFrame,
           col: str = 'sequence'
           ):
    """Extract ProtT5-XL-uniref50 embeddings from protein sequence in a dataframe

    Each sequence in `df[col]` is encoded with the half-precision encoder
    checkpoint on CUDA and the last hidden state is averaged over the first
    `len(sequence)` positions. Returns a dataframe indexed like `df` with
    columns named `T5_<i>`.
    """
    # Reference: https://github.com/agemagician/ProtTrans/tree/master/Embedding/PyTorch/Advanced
    checkpoint = 'Rostlab/prot_t5_xl_half_uniref50-enc'
    tokenizer = T5Tokenizer.from_pretrained(checkpoint, do_lower_case=False)
    model = T5EncoderModel.from_pretrained(checkpoint).to('cuda')

    # Run the encoder in half precision.
    model.half()

    def T5_embeddings(sequence):
        "Return the mean encoder embedding of one sequence as a numpy array."
        n_residues = len(sequence)

        # Map U, Z, O and B to X and separate the residues with spaces.
        spaced = " ".join(list(re.sub(r"[UZOB]", "X", sequence)))

        encoded = tokenizer.batch_encode_plus([spaced], add_special_tokens=True, padding="longest")
        input_ids = torch.tensor(encoded['input_ids']).to('cuda')
        attention_mask = torch.tensor(encoded['attention_mask']).to('cuda')

        with torch.no_grad():
            output = model(input_ids=input_ids, attention_mask=attention_mask)

        # Keep only the residue positions, then average them.
        per_residue = output.last_hidden_state[0][:n_residues].detach().cpu().numpy()
        return per_residue.mean(axis=0)

    embeddings = df[col].progress_apply(T5_embeddings)

    T5_feature = pd.DataFrame(embeddings.tolist(),index=df.index)
    T5_feature.columns = 'T5_' + T5_feature.columns.astype(str)

    return T5_feature
|
|
170
|
+
|
|
171
|
+
# %% ../nbs/01_feature.ipynb 22
|
|
172
|
+
def get_t5_bfd(df:pd.DataFrame,
               col: str = 'sequence'
               ):
    """Extract ProtT5-XL-BFD embeddings from protein sequence in a dataframe

    Each sequence in `df[col]` is passed through the full T5 model on CUDA
    (the input ids are also used as decoder input ids) and the returned last
    hidden state is averaged over the first `len(sequence)` positions.
    Returns a dataframe indexed like `df` with columns named `T5bfd_<i>`.
    """
    # Reference: https://github.com/agemagician/ProtTrans/tree/master/Embedding/PyTorch/Advanced
    checkpoint = 'Rostlab/prot_t5_xl_bfd'
    tokenizer = T5Tokenizer.from_pretrained(checkpoint, do_lower_case=False)
    model = T5Model.from_pretrained(checkpoint).to('cuda')

    model.eval()

    def T5_embeddings_bfd(sequence, device = 'cuda'):
        "Return the mean embedding of one sequence as a numpy array."
        n_residues = len(sequence)

        # Map U, Z, O and B to X and separate the residues with spaces.
        spaced = " ".join(list(re.sub(r"[UZOB]", "X", sequence)))

        encoded = tokenizer.batch_encode_plus([spaced], add_special_tokens=True, padding="longest")
        input_ids = torch.tensor(encoded['input_ids']).to(device)
        attention_mask = torch.tensor(encoded['attention_mask']).to(device)

        with torch.no_grad():
            output = model(input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids = input_ids)

        # Keep only the residue positions, then average them.
        per_residue = output.last_hidden_state[0][:n_residues].detach().cpu().numpy()
        return per_residue.mean(axis=0)

    embeddings = df[col].progress_apply(T5_embeddings_bfd)

    T5_feature = pd.DataFrame(embeddings.tolist(),index=df.index)
    T5_feature.columns = 'T5bfd_' + T5_feature.columns.astype(str)

    return T5_feature
|
|
212
|
+
|
|
213
|
+
# %% ../nbs/01_feature.ipynb 26
|
|
214
|
+
def reduce_feature(df: pd.DataFrame,
                   method: str='pca', # dimensionality reduction method, accept both capital and lower case
                   complexity: int=20, # unused for PCA; perplexity for TSNE, recommend: 30; n_neighbors for UMAP, recommend: 15
                   n: int=2, # n_components
                   load: str=None, # load a previous model, e.g. model.pkl
                   save: str=None, # pkl file to be saved, e.g. pca_model.pkl
                   seed: int=123, # seed for random_state
                   **kwargs, # arguments from PCA, TSNE, or UMAP depends on which method to use
                   ):
    """Reduce the dimensionality given a dataframe of values

    Returns a dataframe indexed like `df` with columns `<METHOD>1`,
    `<METHOD>2`, ... When `load` is given, the stored reducer is applied to
    `df` without re-fitting if it offers `transform`; otherwise a new
    reducer is fitted. When `save` is given, the reducer is dumped there
    with joblib and missing parent directories are created.
    """

    method = method.lower()
    assert method in ['pca','tsne','umap'], "Please choose a method among PCA, TSNE, and UMAP"

    if load is not None:
        # NOTE: joblib files are pickle based - only load files you trust.
        reducer = joblib.load(load)
        if hasattr(reducer, 'transform'):
            # Reuse the previously fitted projection; re-fitting here would
            # discard exactly what was loaded.
            proj = reducer.transform(df)
        else:
            # Reducers without `transform` can only embed by fitting.
            proj = reducer.fit_transform(df)
    else:
        if method == 'pca':
            reducer = PCA(n_components=n, random_state=seed,**kwargs)
        elif method == 'tsne':
            reducer = TSNE(n_components=n,
                           random_state=seed,
                           perplexity = complexity, # default from official is 30
                           **kwargs)
        elif method == 'umap':
            reducer = UMAP(n_components=n,
                           random_state=seed,
                           n_neighbors=complexity, # default from official is 15, try 15-200
                           **kwargs)
        else:
            # Only reachable when assertions are disabled (python -O).
            raise ValueError('Invalid method specified')

        proj = reducer.fit_transform(df)

    embedding_df = pd.DataFrame(proj).set_index(df.index)
    # Name the columns after the actual output width, which for a loaded
    # reducer need not equal `n`.
    embedding_df.columns = [f"{method.upper()}{i}" for i in range(1, embedding_df.shape[1] + 1)]

    if save is not None:
        path = Path(save)
        # parents=True so nested output folders are created as well.
        path.parent.mkdir(parents=True, exist_ok=True)

        joblib.dump(reducer, save)

    return embedding_df
|
|
258
|
+
|
|
259
|
+
# %% ../nbs/01_feature.ipynb 29
|
|
260
|
+
def remove_hi_corr(df: pd.DataFrame,
                   thr: float=0.98 # threshold
                   ):
    """Drop columns that are highly correlated with an earlier column.

    A column is removed when its absolute Pearson correlation with any
    column to its left exceeds `thr`. Returns a new dataframe; `df` itself
    is not modified.
    """
    abs_corr = df.corr().abs()

    # Keep only the entries strictly above the diagonal so that every pair
    # of columns is inspected once and always charged to the later column.
    above_diagonal = np.triu(np.ones(abs_corr.shape), k=1).astype(bool)
    upper_only = abs_corr.where(above_diagonal)

    redundant = []
    for name in upper_only.columns:
        if any(upper_only[name] > thr):
            redundant.append(name)

    return df.drop(columns=redundant)
|
|
278
|
+
|
|
279
|
+
# %% ../nbs/01_feature.ipynb 33
|
|
280
|
+
def preprocess(df: pd.DataFrame,
               thr: float=0.98):
    """Remove features with no variance, and highly correlated features based on threshold

    Columns whose standard deviation equals 0 are dropped first, then
    `remove_hi_corr` is applied with `thr`. The names of all removed
    columns are printed. Returns the reduced dataframe.
    """
    # Only the column names are needed for the report, so remember those
    # instead of copying the whole dataframe.
    original_cols = set(df.columns)

    df = df.loc[:, df.std() != 0]
    df = remove_hi_corr(df, thr)

    dropping_col = original_cols - set(df.columns)
    print(f'removing columns: {dropping_col}')
    return df
|