EntDetect 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- EntDetect/Jwalk/GridTools.py +567 -0
- EntDetect/Jwalk/PDBTools.py +532 -0
- EntDetect/Jwalk/SASDTools.py +543 -0
- EntDetect/Jwalk/SurfaceTools.py +150 -0
- EntDetect/Jwalk/__init__.py +19 -0
- EntDetect/Jwalk/naccess.config.txt +255 -0
- EntDetect/__init__.py +10 -0
- EntDetect/_logging.py +71 -0
- EntDetect/change_resolution.py +2361 -0
- EntDetect/clustering.py +2626 -0
- EntDetect/compare_sim2exp.py +1927 -0
- EntDetect/entanglement_features.py +478 -0
- EntDetect/gaussian_entanglement.py +2067 -0
- EntDetect/order_params.py +1048 -0
- EntDetect/resources/__init__.py +11 -0
- EntDetect/resources/__pycache__/__init__.cpython-311.pyc +0 -0
- EntDetect/resources/calc_K.pl +712 -0
- EntDetect/resources/calc_Q.pl +962 -0
- EntDetect/resources/pulchra +0 -0
- EntDetect/resources/shared_files/__init__.py +2 -0
- EntDetect/resources/shared_files/bt_contact_potential.dat +22 -0
- EntDetect/resources/shared_files/karanicolas_dihe_parm.dat +1600 -0
- EntDetect/resources/shared_files/kgs_contact_potential.dat +22 -0
- EntDetect/resources/shared_files/mj_contact_potential.dat +22 -0
- EntDetect/resources/stride +0 -0
- EntDetect/statistics.py +1344 -0
- EntDetect/utilities.py +201 -0
- entdetect-1.2.0.dist-info/METADATA +26 -0
- entdetect-1.2.0.dist-info/RECORD +45 -0
- entdetect-1.2.0.dist-info/WHEEL +5 -0
- entdetect-1.2.0.dist-info/entry_points.txt +11 -0
- entdetect-1.2.0.dist-info/licenses/LICENSE +674 -0
- entdetect-1.2.0.dist-info/top_level.txt +2 -0
- scripts/__init__.py +5 -0
- scripts/convert_cor_psf_to_pdb.py +103 -0
- scripts/run_Foldingpathway.py +162 -0
- scripts/run_MSM.py +152 -0
- scripts/run_OP_on_simulation_traj.py +194 -0
- scripts/run_change_resolution.py +63 -0
- scripts/run_compare_sim2exp.py +215 -0
- scripts/run_montecarlo.py +158 -0
- scripts/run_nativeNCLE.py +179 -0
- scripts/run_nonnative_entanglement_clustering.py +110 -0
- scripts/run_population_modeling.py +117 -0
- scripts/run_workflow4_nativeNCLE_batch.py +412 -0
|
@@ -0,0 +1,1927 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
import sys, getopt, math, os, time, traceback, glob, multiprocessing, copy
|
|
3
|
+
import numpy as np
|
|
4
|
+
from scipy.stats import norm
|
|
5
|
+
import pandas as pd
|
|
6
|
+
from statsmodels.stats.multitest import multipletests
|
|
7
|
+
import parmed as pmd
|
|
8
|
+
import mdtraj as mdt
|
|
9
|
+
import pathlib, re
|
|
10
|
+
import logging
|
|
11
|
+
from EntDetect._logging import setup_logger
|
|
12
|
+
|
|
13
|
+
class MassSpec:
|
|
14
|
+
"""
|
|
15
|
+
Compare ensembles of protein structures with LiPMS and XLMS experimental data.
|
|
16
|
+
Primary author: Yang Jiang
|
|
17
|
+
Secondary author: Ian Sitarik
|
|
18
|
+
"""
|
|
19
|
+
#############################################################################################################
|
|
20
|
+
def __init__(self, msm_data_file:str, meta_dist_file:str, LiPMS_exp_file:str, sasa_data_file:str, XLMS_exp_file:str, dist_data_file:str,
             cluster_data_file:str, OPpath:str, AAdcd_dir:str, native_AA_pdb:str, native_state_idx:int, state_idx_list:list, prot_len:int, last_num_frames:int,
             rm_traj_list:list=None, outdir:str='./', ID:str='', xp_dir:str=None, resid2residueidx_map:dict=None,
             start:int=0, end:int=999999999999, stride:int=1, verbose:bool=False, num_perm:int=10000, n_boot:int=10000, lag_frame:int=1, nproc:int=1, log_level:int=logging.INFO, logdir:str=None):
    """
    Store the analysis configuration, set up the logger and create the output directory.

    All arguments are stored on the instance under the same name. Notes on the
    ones with non-trivial handling here:

    rm_traj_list          -- trajectories to exclude; None (default) means an empty list.
                             A fresh list is created per instance so instances never share it.
    resid2residueidx_map  -- mapping from resid to residue index; when None or empty an
                             identity mapping {i + 1: i for i in range(prot_len)} is used.
    outdir                -- output directory, created if it does not exist.
    logdir                -- directory handed to setup_logger; falls back to outdir when None.
    log_level             -- logging level handed to setup_logger.
    """
    self.msm_data_file = msm_data_file
    self.meta_dist_file = meta_dist_file
    self.LiPMS_exp_file = LiPMS_exp_file
    self.sasa_data_file = sasa_data_file
    self.XLMS_exp_file = XLMS_exp_file
    self.dist_data_file = dist_data_file
    self.xp_dir = xp_dir
    self.cluster_data_file = cluster_data_file
    self.OPpath = OPpath
    self.AAdcd_dir = AAdcd_dir
    self.native_AA_pdb = native_AA_pdb
    self.native_state_idx = native_state_idx
    self.state_idx_list = state_idx_list
    # A None default avoids the shared-mutable-default pitfall of `rm_traj_list=[]`.
    self.rm_traj_list = [] if rm_traj_list is None else rm_traj_list
    self.outdir = outdir
    self.ID = ID

    # Initialize logging before any log calls in constructor.
    self.logger = setup_logger('MassSpec', outdir=logdir if logdir is not None else self.outdir, ID=self.ID, log_level=log_level)

    if resid2residueidx_map:
        self.resid2residueidx_map = resid2residueidx_map
    else:
        # No (or an empty) mapping supplied: resid i + 1 maps to residue index i.
        self.resid2residueidx_map = {i + 1:i for i in range(prot_len)}
        self.logger.info(f'No resid2residueidx_map provided, using identity mapping for protein length {prot_len} with an offset of -1')
    self.start = start
    self.end = end
    self.stride = stride
    self.verbose = verbose

    self.res_buffer = 5  # residues added on each side of a cut site (see load_LiPMS_data)
    self.num_perm = num_perm
    self.prot_len = prot_len
    self.nproc = nproc

    self.if_calc_M = 1

    self.last_num_frames = last_num_frames # last X ns
    self.lag_frame = lag_frame # down sample trajectories at each #lag_frame frame
    self.n_boot = n_boot

    # Make the outdir if it does not exist; exist_ok guards against another
    # process creating it between the check and the call.
    if not os.path.exists(self.outdir):
        os.makedirs(self.outdir, exist_ok=True)
        self.logger.debug(f'Creating directory: {self.outdir}')
|
|
71
|
+
##############################################################################
|
|
72
|
+
|
|
73
|
+
##############################################################################
|
|
74
|
+
def load_LiPMS_data(self, file_path):
    """
    Read the LiP-MS spreadsheet and return the consistently changing cut sites.

    The sheet must provide the columns 'Cut Site' and 'Log2 FC'. Entries whose
    'Cut Site' contains '-' are ignored. For every remaining site the result holds
      'peptide_range' -- set of 0-based residue indices within self.res_buffer
                         residues of the site, clipped to [0, self.prot_len)
      'qual_change'   -- -1, 0 or 1, the sign of 'Log2 FC'
    Sites whose rows do not all agree on a single sign are logged and dropped.
    The returned dict is ordered by the residue number in the site label.
    """
    table = pd.read_excel(file_path)
    sites = {}
    for _, record in table.iterrows():
        peptide = record['Cut Site']
        if '-' in peptide:
            continue

        # Site labels look like one letter followed by a 1-based residue number.
        resid = int(peptide.strip()[1:])-1
        lower = np.max([0, resid-self.res_buffer])
        upper = np.min([self.prot_len, resid+self.res_buffer+1])
        peptide_range = set(list(np.arange(lower, upper)))

        if peptide not in sites:
            sites[peptide] = {'peptide_range': [], 'qual_change': []}
        entry = sites[peptide]
        entry['peptide_range'] = peptide_range

        fold_change = record['Log2 FC']
        if fold_change < 0:
            entry['qual_change'].append(-1)
        elif fold_change > 0:
            entry['qual_change'].append(1)
        elif fold_change == 0:
            entry['qual_change'].append(0)

    # Collapse each site to a single sign; remember the ones that disagree.
    inconsistent = []
    for site, info in sites.items():
        distinct = list(set(info['qual_change']))
        if len(distinct) != 1:
            self.logger.info('Site %s has inconsistent changes in abundance: %s'%(site, str(info['qual_change'])))
            inconsistent.append(site)
            continue
        info['qual_change'] = distinct[0]
    for site in inconsistent:
        sites.pop(site)

    ordered = sorted(sites.items(), key=lambda item: int(item[0].strip()[1:]))
    return dict(ordered)
|
|
105
|
+
##############################################################################
|
|
106
|
+
|
|
107
|
+
##############################################################################
|
|
108
|
+
def load_XLMS_data(self, file_path):
    """
    Read the XL-MS spreadsheet and return the consistently changing cross-link pairs.

    The sheet must provide the columns 'Pairs' (two site labels joined by '-')
    and 'log2(heavy/light)'. For every pair the result holds
      'qual_change' -- -1, 0 or 1, the sign of 'log2(heavy/light)'
      'res_pair'    -- the two 0-based residue indices parsed from the labels
    Pairs whose rows do not all agree on a single sign are logged and dropped.
    The returned dict is ordered by (first residue number, second residue number).
    """
    table = pd.read_excel(file_path)
    pairs = {}
    for _, record in table.iterrows():
        pair_key = record['Pairs'].strip()
        labels = pair_key.split('-')
        if pair_key not in pairs.keys():
            pairs[pair_key] = {'qual_change': [], 'res_pair': []}
        entry = pairs[pair_key]
        # Labels look like one letter followed by a 1-based residue number.
        entry['res_pair'] = [int(labels[0][1:])-1, int(labels[1][1:])-1]

        ratio = record['log2(heavy/light)']
        if ratio < 0:
            entry['qual_change'].append(-1)
        elif ratio > 0:
            entry['qual_change'].append(1)
        elif ratio == 0:
            entry['qual_change'].append(0)

    # Collapse each pair to a single sign; remember the ones that disagree.
    inconsistent = []
    for pair_key, info in pairs.items():
        distinct = list(set(info['qual_change']))
        if len(distinct) != 1:
            self.logger.info('Sites %s has inconsistent changes in abundance: %s'%(pair_key, str(info['qual_change'])))
            inconsistent.append(pair_key)
            continue
        info['qual_change'] = distinct[0]
    for pair_key in inconsistent:
        pairs.pop(pair_key)

    def residue_numbers(item):
        first, second = item[0].split('-')[0], item[0].split('-')[1]
        return (int(first[1:]), int(second[1:]))

    return dict(sorted(pairs.items(), key=residue_numbers))
|
|
137
|
+
##############################################################################
|
|
138
|
+
|
|
139
|
+
##############################################################################
|
|
140
|
+
def score_XL(self, pair_AA, JWalk_dist):
    """
    Score a cross-link distance with a residue-pair specific normal density.

    pair_AA    -- the two one-letter residue codes of the pair; each must be one
                  of K, S, T, Y, M (anything else raises KeyError)
    JWalk_dist -- distance to score; -1 is treated as "no score"

    The Lys-Lys reference (mean 18.6, sigma 6.0, cut-off 33) is first shifted by
    a fixed offset of 1.1; mean and cut-off are then moved by the difference in
    side-chain lengths relative to two lysines, and sigma is rescaled so that
    mean - 3*sigma stays where it is for Lys-Lys.

    Returns the density at JWalk_dist when the distance is not -1 and does not
    exceed the cut-off, otherwise 0.
    """
    linker_offset = 1.1
    side_chain = {
        'K': 6.3,
        'S': 2.5,
        'T': 2.5,
        'Y': 6.5,
        'M': 1.5,
    }

    # Offset-corrected Lys-Lys reference parameters.
    kk_mean = 18.6 + linker_offset
    kk_sigma = (linker_offset + 3*6.0) / 3
    kk_cutoff = 33 + linker_offset

    # Parameters for the requested residue pair.
    mu = kk_mean + (side_chain[pair_AA[0]] + side_chain[pair_AA[1]]) - 2*side_chain['K']
    sigma = (mu - (kk_mean - 3*kk_sigma)) / 3
    threshold = kk_cutoff + mu - kk_mean

    density = norm(mu, sigma)

    if JWalk_dist != -1 and JWalk_dist <= threshold:
        return density.pdf(JWalk_dist)
    return 0
|
|
168
|
+
##############################################################################
|
|
169
|
+
|
|
170
|
+
##############################################################################
|
|
171
|
+
def perm_fun(self, perm_idx_list, combined_data, length_1):
    """
    Evaluate the test statistic for one (re)labelling of the pooled data.

    The first length_1 entries of perm_idx_list select group 1 from
    combined_data, the remaining entries select group 2. Returns
    self.statistic_fun(group_1, group_2, 0).
    """
    group_1_idx, group_2_idx = perm_idx_list[:length_1], perm_idx_list[length_1:]
    return self.statistic_fun(combined_data[group_1_idx], combined_data[group_2_idx], 0)
|
|
175
|
+
##############################################################################
|
|
176
|
+
|
|
177
|
+
##############################################################################
|
|
178
|
+
def permutation_test(self, perm_stat_fun, data_1, data_2, num_perm, side='!='):
    """
    Permutation test for a difference between two samples.

    perm_stat_fun -- callable (index_array, pooled_data, len(data_1)) -> statistic;
                     it is evaluated once on the original ordering and once per
                     permutation (it must be picklable when self.nproc != 1)
    data_1, data_2 -- the two samples
    num_perm      -- number of random permutations
    side          -- '!=' (two-sided, compares absolute values), '>' or '<'

    Returns p = (hits + 1) / (num_perm + 1), where hits is the number of
    permutations whose statistic is at least as extreme as the observed one.
    An invalid `side` is logged and terminates via SystemExit with a non-zero status.
    """
    if side not in ('!=', '>', '<'):
        msg = 'side parameter is wrong for function "permutation_test". It must be "!=", ">", or "<".'
        self.logger.error(msg)
        # Passing the message makes the exit status non-zero.
        sys.exit(msg)

    combined_data = np.array(list(data_1) + list(data_2))
    n_first = len(data_1)
    identity_idx = np.arange(len(combined_data))
    t0 = perm_stat_fun(identity_idx, combined_data, n_first)

    start_time = time.time()
    self.logger.debug('start permutation test')
    # Permutations are always drawn in this process, so the random stream does
    # not depend on how the statistics are evaluated.
    shuffles = [np.random.permutation(identity_idx) for _ in range(num_perm)]
    if self.nproc == 1:
        # A single worker gains nothing from a process pool; evaluate in place.
        t_dist = [perm_stat_fun(shuffle, combined_data, n_first) for shuffle in shuffles]
    else:
        # The context manager releases the workers even if a task raises.
        with multiprocessing.Pool(self.nproc) as pool:
            pending = [pool.apply_async(perm_stat_fun, (shuffle, combined_data, n_first,)) for shuffle in shuffles]
            pool.close()
            pool.join()
            t_dist = [job.get() for job in pending]

    t_dist = np.asarray(t_dist)
    if side == '!=':
        hits = np.count_nonzero(np.abs(t_dist) >= np.abs(t0))
    elif side == '>':
        hits = np.count_nonzero(t_dist >= t0)
    else:
        hits = np.count_nonzero(t_dist <= t0)
    p = (int(hits)+1)/(num_perm+1)

    used_time = time.time() - start_time
    self.logger.info('%.2fs'%used_time)
    return p
|
|
208
|
+
##############################################################################
|
|
209
|
+
|
|
210
|
+
##############################################################################
|
|
211
|
+
def bootstrap(self, boot_fun, data, n_time):
    """
    Bootstrap the statistic boot_fun over the first axis of data.

    boot_fun -- callable applied to a 1-D resample and returning the statistic
    data     -- 1-D or 2-D array-like; for 2-D input the rows are resampled
                jointly and boot_fun is applied to every column separately
    n_time   -- number of bootstrap resamples

    Returns an array of shape (n_time,) for 1-D data or (n_time, n_columns)
    for 2-D data. Any other dimensionality is logged and terminates via
    SystemExit with a non-zero status.
    """
    # Accept lists/tuples as well as arrays.
    data = np.asarray(data)
    if data.ndim not in (1, 2):
        msg = 'bootstrap: Can only handle 1 or 2 dimentional data'
        self.logger.error(msg)
        sys.exit(msg)

    n_samples = len(data)
    idx_list = np.arange(n_samples)
    boot_stat = []
    for _ in range(n_time):
        # Resample with replacement, same size as the original sample.
        resampled = data[np.random.choice(idx_list, n_samples)]
        if data.ndim == 1:
            boot_stat.append(boot_fun(resampled))
        else:
            boot_stat.append(np.array([boot_fun(resampled[:, j]) for j in range(data.shape[1])]))
    return np.array(boot_stat)
|
|
249
|
+
##############################################################################
|
|
250
|
+
|
|
251
|
+
##############################################################################
|
|
252
|
+
def remove_traj_from_frame_list(self, rm_traj_list, frame_list, traj_axis):
    """
    Drop the entries of a 2-D frame list that belong to excluded trajectories.

    rm_traj_list -- trajectory ids to remove
    frame_list   -- 2-D array; an empty frame_list is returned unchanged
    traj_axis    -- 0 when the trajectory ids are in the first row (columns are
                    filtered), 1 when they are in the first column (rows are filtered)

    Returns the filtered array.
    Raises ValueError for any other traj_axis on a non-empty frame_list
    (previously this surfaced as an UnboundLocalError).
    """
    if len(frame_list) == 0:
        return frame_list

    if traj_axis == 0: # traj ids in 1st row
        keep = [col for col, traj in enumerate(frame_list[0,:]) if traj not in rm_traj_list]
        return frame_list[:,keep]
    if traj_axis == 1: # traj ids in 1st column
        keep = [row for row, traj in enumerate(frame_list[:,0]) if traj not in rm_traj_list]
        return frame_list[keep,:]
    raise ValueError(f'traj_axis must be 0 or 1, got {traj_axis!r}')
|
|
269
|
+
##############################################################################
|
|
270
|
+
|
|
271
|
+
##############################################################################
|
|
272
|
+
def statistic_fun(self, data_1, data_2, ref):
    """
    Standardised difference of means between two samples.

    Computes ((mean(data_1) - mean(data_2)) - ref) divided by
    sqrt(std(data_1)**2/len(data_1) + std(data_2)**2/len(data_2)), using
    numpy's default (population) standard deviation.

    When both the numerator and the denominator are exactly zero the
    statistic is defined as 1.0.
    """
    mean_gap = (np.mean(data_1) - np.mean(data_2)) - ref
    spread_1 = np.std(data_1)**2/len(data_1)
    spread_2 = np.std(data_2)**2/len(data_2)
    std_err = (spread_1 + spread_2)**0.5
    if mean_gap == 0 and std_err == 0:
        return 1.0
    return mean_gap/std_err
|
|
280
|
+
##############################################################################
|
|
281
|
+
|
|
282
|
+
##############################################################################
|
|
283
|
+
def bootstrap_test(self, data_1, data_2, statistic_fun, n_time, side='!='):
    """
    Bootstrap hypothesis test for a difference between two samples.

    data_1, data_2 -- arrays supporting integer-array indexing
    statistic_fun  -- callable (sample_1, sample_2, ref) -> statistic
    n_time         -- number of bootstrap resamples
    side           -- '!=' (two-sided), '>' or '<'

    Each resample draws both samples with replacement and evaluates
    statistic_fun with ref = mean(data_1) - mean(data_2); the observed
    statistic uses ref = 0.

    Returns (p, boot_stat): the p-value (capped at 1.0) and the array of
    bootstrap statistics. An invalid `side` is logged and ends in sys.exit().
    """
    if side not in ['!=', '>', '<']:
        self.logger.info('side parameter is wrong for function "bootstrap_test". It must be "!=", ">", or "<".')
        sys.exit()

    size_1, size_2 = len(data_1), len(data_2)
    pool_1 = np.arange(size_1)
    pool_2 = np.arange(size_2)
    observed_gap = np.mean(data_1) - np.mean(data_2)

    draws = []
    for _ in range(n_time):
        pick_1 = np.random.choice(pool_1, size_1)
        pick_2 = np.random.choice(pool_2, size_2)
        draws.append(statistic_fun(data_1[pick_1], data_2[pick_2], observed_gap))
    boot_stat = np.array(draws)

    observed = statistic_fun(data_1, data_2, 0)

    def count_ge():
        return len(np.where(boot_stat >= observed)[0])

    def count_le():
        return len(np.where(boot_stat <= observed)[0])

    if side == '!=':
        # Twice the smaller tail.
        p = (np.min([count_ge(), count_le()])+1) / (n_time+1) * 2
    elif side == '>':
        p = (count_ge()+1) / (n_time+1)
    elif side == '<':
        p = (count_le()+1) / (n_time+1)

    p = 1.0 if p > 1 else p

    return (p, boot_stat)
|
|
312
|
+
##############################################################################
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
##############################################################################
|
|
316
|
+
def LiP_XL_MS_ConsistencyTest(self,):
|
|
317
|
+
self.logger.info(f'Comparing simulation to experimental data...')
|
|
318
|
+
xlsx_outfile = os.path.join(self.outdir, f'LiPMS_XLMS_consist_pvalues_metastates_v11_down_sample_lag{self.lag_frame}.xlsx')
|
|
319
|
+
npz_outfile = os.path.join(self.outdir, 'LiPMS_XLMS_consist_data_v9.npz')
|
|
320
|
+
self.logger.debug(f'xlsx_outfile: {xlsx_outfile}')
|
|
321
|
+
self.logger.debug(f'npz_outfile: {npz_outfile}')
|
|
322
|
+
|
|
323
|
+
if os.path.exists(npz_outfile) and os.path.exists(xlsx_outfile):
|
|
324
|
+
self.logger.info(f'npz_outfile EXISTS: Loading...')
|
|
325
|
+
self.logger.info(f'xlsx_outfile EXISTS: Loading...')
|
|
326
|
+
M_data = np.load(npz_outfile, allow_pickle=True)
|
|
327
|
+
XLSX_df = pd.read_excel(xlsx_outfile)
|
|
328
|
+
#print(f'XLSX_df:\n{XLSX_df}')
|
|
329
|
+
|
|
330
|
+
else:
|
|
331
|
+
|
|
332
|
+
#################################################################
|
|
333
|
+
# Load MSM data
|
|
334
|
+
MSM_data = pd.read_csv(self.msm_data_file)
|
|
335
|
+
self.logger.info(f'MSM_data\n{MSM_data}')
|
|
336
|
+
meta_states = MSM_data['metastablestate'].unique()
|
|
337
|
+
meta_states = np.array(meta_states, dtype=int)
|
|
338
|
+
self.logger.debug(f'meta_states: {meta_states}')
|
|
339
|
+
num_meta_states = len(meta_states)
|
|
340
|
+
self.logger.debug(f'num_meta_states: {num_meta_states}')
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
meta_dtrajs_last = []
|
|
344
|
+
traj_idx_to_trajnum = {} # mapping traj_idx to traj number
|
|
345
|
+
for traj_idx, (traj, traj_df) in enumerate(MSM_data.groupby('traj')):
|
|
346
|
+
traj_len = len(traj_df)
|
|
347
|
+
self.logger.debug(f'traj: {traj}, traj_len: {traj_len}\n{traj_df.head()}')
|
|
348
|
+
|
|
349
|
+
last = traj_df.iloc[-self.last_num_frames:,:]
|
|
350
|
+
last = last.reset_index(drop=True)
|
|
351
|
+
last = last['metastablestate'].values
|
|
352
|
+
#print(f'last: {last}')
|
|
353
|
+
meta_dtrajs_last.append(last)
|
|
354
|
+
|
|
355
|
+
self.logger.debug(f'traj_idx: {traj_idx}, traj: {traj}, traj_len: {traj_len}, last_num_frames: {self.last_num_frames}, last: {last} {len(last)}')
|
|
356
|
+
traj_idx_to_trajnum[traj_idx] = traj
|
|
357
|
+
|
|
358
|
+
meta_dtrajs_last = np.array(meta_dtrajs_last)
|
|
359
|
+
self.logger.info(f'meta_dtrajs_last.shape: {meta_dtrajs_last.shape}')
|
|
360
|
+
self.logger.info(f'meta_dtrajs_last\n{meta_dtrajs_last} {meta_dtrajs_last.shape}')
|
|
361
|
+
self.logger.debug(np.unique(meta_dtrajs_last))
|
|
362
|
+
|
|
363
|
+
# Keep MSM trajectory indexing aligned with downstream OP arrays.
|
|
364
|
+
rm_traj_set = set(int(t) for t in self.rm_traj_list)
|
|
365
|
+
keep_traj_idx = [
|
|
366
|
+
idx for idx in range(meta_dtrajs_last.shape[0])
|
|
367
|
+
if int(traj_idx_to_trajnum[idx]) not in rm_traj_set
|
|
368
|
+
]
|
|
369
|
+
meta_dtrajs_last = meta_dtrajs_last[keep_traj_idx, :]
|
|
370
|
+
traj_idx_to_trajnum = {
|
|
371
|
+
new_idx: int(traj_idx_to_trajnum[old_idx])
|
|
372
|
+
for new_idx, old_idx in enumerate(keep_traj_idx)
|
|
373
|
+
}
|
|
374
|
+
self.logger.info(
|
|
375
|
+
f'meta_dtrajs_last.shape after mirror-image removal: {meta_dtrajs_last.shape}'
|
|
376
|
+
)
|
|
377
|
+
#################################################################
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
#################################################################
|
|
381
|
+
# Create frame_list for each state
|
|
382
|
+
# The result is a list where each element is a 2D array with the first column being the trajectory index and the second column being the frame index
|
|
383
|
+
sel_frame_idx = np.arange(0, self.last_num_frames, self.lag_frame)
|
|
384
|
+
self.logger.info(f'sel_frame_idx:\n{sel_frame_idx}')
|
|
385
|
+
|
|
386
|
+
frame_list = []
|
|
387
|
+
empty_states = []
|
|
388
|
+
for state_idx in self.state_idx_list:
|
|
389
|
+
self.logger.info(f'\nGetting frames for state {state_idx}...')
|
|
390
|
+
frame_list_0 = np.array(np.where(meta_dtrajs_last[:, sel_frame_idx] == state_idx)).T
|
|
391
|
+
frame_list_0[:,1] = sel_frame_idx[frame_list_0[:,1]]
|
|
392
|
+
utrajs = np.unique(frame_list_0[:,0])
|
|
393
|
+
self.logger.debug(frame_list_0)
|
|
394
|
+
self.logger.debug(f'utrajs: {utrajs}')
|
|
395
|
+
|
|
396
|
+
#frame_list_0 = self.remove_traj_from_frame_list(self.rm_traj_list, frame_list_0, 1)
|
|
397
|
+
if len(frame_list_0) == 0:
|
|
398
|
+
self.logger.info(f'No frames for state {state_idx} in the last {self.last_num_frames} frames. Exitting. The state maybe made entirely of mirror traj and in the self.rm_traj_list!')
|
|
399
|
+
empty_states.append(state_idx)
|
|
400
|
+
continue
|
|
401
|
+
|
|
402
|
+
frame_list.append(frame_list_0)
|
|
403
|
+
|
|
404
|
+
## adjust the frame_list and the state_idx_list depending on what states are populated
|
|
405
|
+
self.logger.debug(f'empty_states: {empty_states}')
|
|
406
|
+
self.state_idx_list = [idx for idx in self.state_idx_list if idx not in empty_states]
|
|
407
|
+
self.logger.info(f'Updated state_idx_list: {self.state_idx_list}')
|
|
408
|
+
for idx, arr in enumerate(frame_list):
|
|
409
|
+
state = self.state_idx_list[idx]
|
|
410
|
+
self.logger.info(f'state_list_idx={idx}, state={state}, shape={arr.shape}')
|
|
411
|
+
|
|
412
|
+
native_sel = np.array(np.where(meta_dtrajs_last[:, sel_frame_idx] == self.native_state_idx)).T
|
|
413
|
+
native_sel[:,1] = sel_frame_idx[native_sel[:,1]]
|
|
414
|
+
self.logger.info(f'native_sel:\n{native_sel} {native_sel.shape}')
|
|
415
|
+
|
|
416
|
+
#native_sel = self.remove_traj_from_frame_list(self.rm_traj_list, native_sel, 1)
|
|
417
|
+
#print(f'native_sel:\n{native_sel} {native_sel.shape}')
|
|
418
|
+
#################################################################
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
#################################################################
|
|
422
|
+
# Load SASA data
|
|
423
|
+
sasa_traj_list = np.load(self.sasa_data_file, allow_pickle=True)[:,-self.last_num_frames:,:]
|
|
424
|
+
self.logger.info(f'sasa_traj_list.shape: {sasa_traj_list.shape}')
|
|
425
|
+
|
|
426
|
+
# Apply the same trajectory filtering used for MSM indexing.
|
|
427
|
+
sasa_traj_list = sasa_traj_list[keep_traj_idx, :, :]
|
|
428
|
+
self.logger.info(f'sasa_traj_list.shape after mirror-image removal: {sasa_traj_list.shape}')
|
|
429
|
+
#################################################################
|
|
430
|
+
|
|
431
|
+
|
|
432
|
+
#################################################################
|
|
433
|
+
# XL-MS scores are computed either from a pre-built Jwalk.npy
|
|
434
|
+
# (legacy mode) or directly from per-trajectory XP files
|
|
435
|
+
# (memory-friendly streaming mode).
|
|
436
|
+
dist_traj_list = None
|
|
437
|
+
#################################################################
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
#################################################################
|
|
441
|
+
# Select the frames without nan residual SASA
|
|
442
|
+
# NAN SASA indicates a bad backmapped structure and I want to skip them
|
|
443
|
+
nan_frame_sel = np.where(np.isnan(sasa_traj_list))
|
|
444
|
+
nan_frame_sel = np.array([nan_frame_sel[0], nan_frame_sel[1]], dtype=int).T
|
|
445
|
+
self.logger.debug(nan_frame_sel)
|
|
446
|
+
frame_list = [np.array([sel for sel in frame if not sel.tolist() in nan_frame_sel.tolist()]) for frame in frame_list]
|
|
447
|
+
self.logger.info(f'frame_list:')
|
|
448
|
+
for idx, arr in enumerate(frame_list):
|
|
449
|
+
state = self.state_idx_list[idx]
|
|
450
|
+
self.logger.info(f'state_list_idx={idx}, state={state}, shape={arr.shape}')
|
|
451
|
+
native_sel = np.array([sel for sel in native_sel if not sel.tolist() in nan_frame_sel.tolist()])
|
|
452
|
+
self.logger.info(f'native_sel:\n{native_sel} {native_sel.shape}')
|
|
453
|
+
#################################################################
|
|
454
|
+
|
|
455
|
+
|
|
456
|
+
#################################################################
|
|
457
|
+
# Load LiPMS experimental data
|
|
458
|
+
LiPMS_sig_data = self.load_LiPMS_data(self.LiPMS_exp_file)
|
|
459
|
+
self.logger.info(f'Loaded LiPMS experimental data: {len(LiPMS_sig_data)} peptides')
|
|
460
|
+
|
|
461
|
+
# load XLMS experimental data
|
|
462
|
+
XLMS_sig_data = self.load_XLMS_data(self.XLMS_exp_file)
|
|
463
|
+
self.logger.info(f'Loaded XLMS experimental data: {len(XLMS_sig_data)} pairs')
|
|
464
|
+
#################################################################
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
#################################################################
|
|
468
|
+
# Calculate or load metric matrix
|
|
469
|
+
if self.if_calc_M == 1:
|
|
470
|
+
self.logger.debug('Calculating metric matrix...')
|
|
471
|
+
# calculate metric matrix
|
|
472
|
+
M_LiPMS = np.zeros((*meta_dtrajs_last.shape, len(LiPMS_sig_data)))
|
|
473
|
+
for idx, peptide in enumerate(LiPMS_sig_data.keys()):
|
|
474
|
+
sel = list(LiPMS_sig_data[peptide]['peptide_range'])
|
|
475
|
+
SA = np.sum(sasa_traj_list[:,:,sel], axis=-1)
|
|
476
|
+
M_LiPMS[:,:,idx] = SA
|
|
477
|
+
|
|
478
|
+
xlms_targets = []
|
|
479
|
+
for idx, key in enumerate(XLMS_sig_data.keys()):
|
|
480
|
+
pair_AA = [k[0] for k in key.split('-')]
|
|
481
|
+
key_0 = '-'.join([k[1:]+'|A' for k in key.split('-')])
|
|
482
|
+
xlms_targets.append((idx, key_0, pair_AA))
|
|
483
|
+
|
|
484
|
+
M_XLMS = np.zeros((*meta_dtrajs_last.shape, len(XLMS_sig_data)))
|
|
485
|
+
if self.dist_data_file is not None and os.path.exists(self.dist_data_file):
|
|
486
|
+
dist_traj_list = np.load(self.dist_data_file, allow_pickle=True)[:,-self.last_num_frames:]
|
|
487
|
+
self.logger.info(f'dist_traj_list.shape: {dist_traj_list.shape}')
|
|
488
|
+
|
|
489
|
+
# Apply the same trajectory filtering used for MSM indexing.
|
|
490
|
+
dist_traj_list = dist_traj_list[keep_traj_idx, :]
|
|
491
|
+
self.logger.info(f'dist_traj_list.shape after mirror-image removal: {dist_traj_list.shape}')
|
|
492
|
+
|
|
493
|
+
n_traj = min(meta_dtrajs_last.shape[0], dist_traj_list.shape[0])
|
|
494
|
+
n_frame = min(meta_dtrajs_last.shape[1], dist_traj_list.shape[1])
|
|
495
|
+
for i in range(n_traj):
|
|
496
|
+
for j in range(n_frame):
|
|
497
|
+
frame_data = dist_traj_list[i, j]
|
|
498
|
+
if frame_data is None:
|
|
499
|
+
continue
|
|
500
|
+
for idx, key_0, pair_AA in xlms_targets:
|
|
501
|
+
if key_0 not in frame_data:
|
|
502
|
+
continue
|
|
503
|
+
JWalk_dist = frame_data[key_0].get('Jwalk', -1)
|
|
504
|
+
M_XLMS[i, j, idx] = self.score_XL(pair_AA, JWalk_dist)
|
|
505
|
+
elif self.xp_dir is not None and os.path.isdir(self.xp_dir):
|
|
506
|
+
self.logger.info(f'Computing XL-MS scores in streaming mode from XP files in: {self.xp_dir}')
|
|
507
|
+
target_lookup = {key_0: (idx, pair_AA) for idx, key_0, pair_AA in xlms_targets}
|
|
508
|
+
|
|
509
|
+
for traj_idx in range(meta_dtrajs_last.shape[0]):
|
|
510
|
+
traj_num = int(traj_idx_to_trajnum[traj_idx])
|
|
511
|
+
if traj_num in rm_traj_set:
|
|
512
|
+
continue
|
|
513
|
+
|
|
514
|
+
fpath = os.path.join(self.xp_dir, f'{self.ID}_Traj{traj_num}.XP')
|
|
515
|
+
if not os.path.exists(fpath):
|
|
516
|
+
self.logger.warning(f'Missing XP file for streaming XL-MS scoring: {fpath}')
|
|
517
|
+
continue
|
|
518
|
+
|
|
519
|
+
df = pd.read_csv(
|
|
520
|
+
fpath,
|
|
521
|
+
sep='\t',
|
|
522
|
+
usecols=['Frame', 'Atom1', 'Atom2', 'SASD'],
|
|
523
|
+
dtype={'Frame': np.int32, 'SASD': np.float32, 'Atom1': 'string', 'Atom2': 'string'},
|
|
524
|
+
)
|
|
525
|
+
if df.empty:
|
|
526
|
+
continue
|
|
527
|
+
|
|
528
|
+
frame_values = np.sort(df['Frame'].unique())
|
|
529
|
+
frame_to_idx = {int(f): idx for idx, f in enumerate(frame_values)}
|
|
530
|
+
|
|
531
|
+
atom1_parts = df['Atom1'].str.split('-', expand=True)
|
|
532
|
+
atom2_parts = df['Atom2'].str.split('-', expand=True)
|
|
533
|
+
if atom1_parts.shape[1] < 3 or atom2_parts.shape[1] < 3:
|
|
534
|
+
self.logger.warning(f'Malformed Atom fields in {fpath}; skipping trajectory {traj_num}')
|
|
535
|
+
continue
|
|
536
|
+
|
|
537
|
+
frame_key_df = pd.DataFrame({
|
|
538
|
+
'Frame': df['Frame'].values,
|
|
539
|
+
'SASD': df['SASD'].values,
|
|
540
|
+
'key': (atom1_parts[1].astype(str) + '|'+ atom1_parts[2].astype(str) + '-' +
|
|
541
|
+
atom2_parts[1].astype(str) + '|' + atom2_parts[2].astype(str)).values,
|
|
542
|
+
})
|
|
543
|
+
|
|
544
|
+
frame_key_df = frame_key_df[frame_key_df['key'].isin(target_lookup.keys())]
|
|
545
|
+
frame_key_df = frame_key_df.drop_duplicates(subset=['Frame', 'key'], keep='first')
|
|
546
|
+
if frame_key_df.empty:
|
|
547
|
+
continue
|
|
548
|
+
|
|
549
|
+
for key_0, key_df in frame_key_df.groupby('key', sort=False):
|
|
550
|
+
idx, pair_AA = target_lookup[key_0]
|
|
551
|
+
for frame_num, sasd in zip(key_df['Frame'].values, key_df['SASD'].values):
|
|
552
|
+
frame_idx = frame_to_idx.get(int(frame_num), None)
|
|
553
|
+
if frame_idx is None or frame_idx >= meta_dtrajs_last.shape[1]:
|
|
554
|
+
continue
|
|
555
|
+
M_XLMS[traj_idx, frame_idx, idx] = self.score_XL(pair_AA, float(sasd))
|
|
556
|
+
|
|
557
|
+
if (traj_idx + 1) % 50 == 0:
|
|
558
|
+
self.logger.info(f'Streamed XP scoring progress: {traj_idx + 1}/{meta_dtrajs_last.shape[0]} trajectories')
|
|
559
|
+
else:
|
|
560
|
+
raise ValueError(
|
|
561
|
+
'XL-MS scoring requires either dist_data_file (Jwalk.npy) or xp_dir (per-trajectory XP files).'
|
|
562
|
+
)
|
|
563
|
+
|
|
564
|
+
# Save data
|
|
565
|
+
np.savez(npz_outfile,
|
|
566
|
+
M_LiPMS = M_LiPMS,
|
|
567
|
+
M_XLMS = M_XLMS)
|
|
568
|
+
self.logger.info(f'Saved metric matrices to {npz_outfile}')
|
|
569
|
+
|
|
570
|
+
else:
|
|
571
|
+
self.logger.debug('Loading metric matrix...')
|
|
572
|
+
M_data = np.load(npz_outfile, allow_pickle=True)
|
|
573
|
+
M_LiPMS = M_data['M_LiPMS'][:,-self.last_num_frames:,:]
|
|
574
|
+
M_XLMS = M_data['M_XLMS'][:,-self.last_num_frames:,:]
|
|
575
|
+
#################################################################
|
|
576
|
+
|
|
577
|
+
|
|
578
|
+
#################################################################
|
|
579
|
+
# Consistency tests
|
|
580
|
+
M_str_list = ['LiPMS', 'XLMS']
|
|
581
|
+
exp_data_list = [LiPMS_sig_data, XLMS_sig_data]
|
|
582
|
+
df_list = []
|
|
583
|
+
df_all_state_list = []
|
|
584
|
+
p_list = []
|
|
585
|
+
p_all_state_list = []
|
|
586
|
+
for idx, M in enumerate([M_LiPMS, M_XLMS]):
|
|
587
|
+
self.logger.info(M_str_list[idx]+':')
|
|
588
|
+
exp_data = exp_data_list[idx]
|
|
589
|
+
|
|
590
|
+
df_list_0 = []
|
|
591
|
+
for idx_0 in range(len(exp_data)):
|
|
592
|
+
self.logger.info('%s:'%(list(exp_data.keys())[idx_0]))
|
|
593
|
+
M_0 = M[:,:,idx_0]
|
|
594
|
+
index_str = list(exp_data.keys())[idx_0]
|
|
595
|
+
header = ['Near-native state', 'Sample size', '<M>']
|
|
596
|
+
header += ['p (!=)', 'Adjusted p (!=)']
|
|
597
|
+
if exp_data[list(exp_data.keys())[idx_0]]['qual_change'] > 0:
|
|
598
|
+
header += ['p (>)', 'Adjusted p (>)']
|
|
599
|
+
test_side = '>'
|
|
600
|
+
else:
|
|
601
|
+
header += ['p (<)', 'Adjusted p (<)']
|
|
602
|
+
test_side = '<'
|
|
603
|
+
|
|
604
|
+
M_0_native = M_0[native_sel[:,0], native_sel[:,1]]
|
|
605
|
+
# bootstrapping to get 95%CI
|
|
606
|
+
boot_stat_native = self.bootstrap(np.mean, M_0_native, self.n_boot)
|
|
607
|
+
lb_native = np.percentile(boot_stat_native, 2.5)
|
|
608
|
+
ub_native = np.percentile(boot_stat_native, 97.5)
|
|
609
|
+
|
|
610
|
+
df_data = []
|
|
611
|
+
df_all_state_data = []
|
|
612
|
+
# for near-native states
|
|
613
|
+
self.logger.info(f'len(self.state_idx_list): {self.state_idx_list} {len(self.state_idx_list)}')
|
|
614
|
+
for idx_1, state_id in enumerate(self.state_idx_list):
|
|
615
|
+
self.logger.info(f'state_list_idx={idx_1}, state={state_id}')
|
|
616
|
+
near_native_sel = np.array(frame_list[idx_1], dtype=int)
|
|
617
|
+
self.logger.debug(near_native_sel)
|
|
618
|
+
M_0_near_native = M_0[near_native_sel[:,0], near_native_sel[:,1]]
|
|
619
|
+
# bootstrapping to get 95%CI
|
|
620
|
+
boot_stat_near_native = self.bootstrap(np.mean, M_0_near_native, self.n_boot)
|
|
621
|
+
lb_near_native = np.percentile(boot_stat_near_native, 2.5)
|
|
622
|
+
ub_near_native = np.percentile(boot_stat_near_native, 97.5)
|
|
623
|
+
self.logger.info('Near-native state %d vs. Native state %d:'%(state_id+1, num_meta_states))
|
|
624
|
+
self.logger.info(' Sample size: %d vs. %d'%(len(M_0_near_native), len(M_0_native)))
|
|
625
|
+
self.logger.info(' <M>: %.4f [%.4f, %.4f] vs. %.4f [%.4f, %.4f]'%(np.mean(M_0_near_native), lb_near_native, ub_near_native, np.mean(M_0_native), lb_native, ub_native))
|
|
626
|
+
p_value_list_0 = []
|
|
627
|
+
for ts in ['!=', test_side]:
|
|
628
|
+
p = self.permutation_test(self.perm_fun, M_0_near_native, M_0_native, self.num_perm, side=ts)
|
|
629
|
+
# p, _ = bootstrap_test(M_0_near_native, M_0_native, statistic_fun, n_boot, side=ts)
|
|
630
|
+
self.logger.info(' p-value ("%s") = %.4f'%(ts, p))
|
|
631
|
+
p_value_list_0.append(p)
|
|
632
|
+
p_list.append(p_value_list_0)
|
|
633
|
+
df_data.append([state_id+1, '%d vs. %d'%(len(M_0_near_native), len(M_0_native)), '%.4f [%.4f, %.4f] vs. %.4f [%.4f, %.4f]'%(np.mean(M_0_near_native), lb_near_native, ub_near_native, np.mean(M_0_native), lb_native, ub_native), p_value_list_0[0], 0, p_value_list_0[1], 0])
|
|
634
|
+
|
|
635
|
+
df = pd.DataFrame(df_data, columns=header, index=[index_str]*len(df_data))
|
|
636
|
+
df_list_0.append(df)
|
|
637
|
+
|
|
638
|
+
# for all states
|
|
639
|
+
frame_list_1 = np.array(np.where(meta_dtrajs_last < num_meta_states)).T
|
|
640
|
+
#near_native_sel = self.remove_traj_from_frame_list(self.rm_traj_list, frame_list_1, 1)
|
|
641
|
+
near_native_sel = np.array([sel for sel in near_native_sel if not sel in nan_frame_sel]) # remove nan SASA
|
|
642
|
+
|
|
643
|
+
# Select only frames seperated by #lag_frame for each state in the trajectory
|
|
644
|
+
near_native_sel_idx = []
|
|
645
|
+
for i in np.unique(near_native_sel[:,0]):
|
|
646
|
+
idx = np.where(near_native_sel[:,0] == i)[0]
|
|
647
|
+
near_native_sel_idx.append(idx[0])
|
|
648
|
+
for iidx in idx[1:]:
|
|
649
|
+
if near_native_sel[iidx,1] - near_native_sel[near_native_sel_idx[-1],1] >= self.lag_frame:
|
|
650
|
+
near_native_sel_idx.append(iidx)
|
|
651
|
+
near_native_sel = near_native_sel[near_native_sel_idx,:]
|
|
652
|
+
|
|
653
|
+
M_0_near_native = M_0[near_native_sel[:,0], near_native_sel[:,1]]
|
|
654
|
+
# bootstrapping to get 95%CI
|
|
655
|
+
boot_stat_near_native = self.bootstrap(np.mean, M_0_near_native, self.n_boot)
|
|
656
|
+
lb_near_native = np.percentile(boot_stat_near_native, 2.5)
|
|
657
|
+
ub_near_native = np.percentile(boot_stat_near_native, 97.5)
|
|
658
|
+
self.logger.info('All states vs. Native state %d:'%(num_meta_states))
|
|
659
|
+
self.logger.info(' Sample size: %d vs. %d'%(len(M_0_near_native), len(M_0_native)))
|
|
660
|
+
self.logger.info(' <M>: %.4f [%.4f, %.4f] vs. %.4f [%.4f, %.4f]'%(np.mean(M_0_near_native), lb_near_native, ub_near_native, np.mean(M_0_native), lb_native, ub_native))
|
|
661
|
+
p_value_list_0 = []
|
|
662
|
+
for ts in ['!=', test_side]:
|
|
663
|
+
p = self.permutation_test(self.perm_fun, M_0_near_native, M_0_native, self.num_perm, side=ts)
|
|
664
|
+
# p, _ = bootstrap_test(M_0_near_native, M_0_native, statistic_fun, n_boot, side=ts)
|
|
665
|
+
self.logger.info(' p-value ("%s") = %.4f'%(ts, p))
|
|
666
|
+
p_value_list_0.append(p)
|
|
667
|
+
p_all_state_list.append(p_value_list_0)
|
|
668
|
+
df_all_state_data.append(['All states', '%d vs. %d'%(len(M_0_near_native), len(M_0_native)), '%.4f [%.4f, %.4f] vs. %.4f [%.4f, %.4f]'%(np.mean(M_0_near_native), lb_near_native, ub_near_native, np.mean(M_0_native), lb_native, ub_native), p_value_list_0[0], 0, p_value_list_0[1], 0])
|
|
669
|
+
df = pd.DataFrame(df_all_state_data, columns=header, index=[index_str]*len(df_all_state_data))
|
|
670
|
+
df_all_state_list.append(df)
|
|
671
|
+
df_list.append(df_list_0)
|
|
672
|
+
#################################################################
|
|
673
|
+
|
|
674
|
+
|
|
675
|
+
#################################################################
|
|
676
|
+
# Correct p-value
|
|
677
|
+
p_list = np.array(p_list)
|
|
678
|
+
p_all_state_list = np.array(p_all_state_list)
|
|
679
|
+
adjusted_p_list = []
|
|
680
|
+
for pi in range(p_list.shape[1]):
|
|
681
|
+
pl = p_list[:,pi]
|
|
682
|
+
results = multipletests(pl, alpha=0.05, method='fdr_bh', is_sorted=False, returnsorted=False)
|
|
683
|
+
adjusted_p_list.append(results[1])
|
|
684
|
+
adjusted_p_list = np.array(adjusted_p_list).T
|
|
685
|
+
adjusted_p_all_state_list = []
|
|
686
|
+
for pi in range(p_all_state_list.shape[1]):
|
|
687
|
+
pl = p_all_state_list[:,pi]
|
|
688
|
+
results = multipletests(pl, alpha=0.05, method='fdr_bh', is_sorted=False, returnsorted=False)
|
|
689
|
+
adjusted_p_all_state_list.append(results[1])
|
|
690
|
+
adjusted_p_all_state_list = np.array(adjusted_p_all_state_list).T
|
|
691
|
+
#################################################################
|
|
692
|
+
|
|
693
|
+
#################################################################
|
|
694
|
+
# Update dataframes
|
|
695
|
+
pi = 0
|
|
696
|
+
for idx in range(len(df_list)):
|
|
697
|
+
for idx_0 in range(len(df_list[idx])):
|
|
698
|
+
header = list(df_list[idx][idx_0].columns)
|
|
699
|
+
for i in range(adjusted_p_list.shape[1]):
|
|
700
|
+
ii = -adjusted_p_list.shape[1]*2 + i*2 + 1
|
|
701
|
+
df_list[idx][idx_0][header[ii]] = adjusted_p_list[pi:pi+len(self.state_idx_list),i]
|
|
702
|
+
pi += len(df_list[idx][idx_0])
|
|
703
|
+
|
|
704
|
+
pi = 0
|
|
705
|
+
for df in df_all_state_list:
|
|
706
|
+
header = list(df.columns)
|
|
707
|
+
for i in range(adjusted_p_all_state_list.shape[1]):
|
|
708
|
+
ii = -adjusted_p_all_state_list.shape[1]*2 + i*2 + 1
|
|
709
|
+
df[header[ii]] = adjusted_p_all_state_list[pi:pi+len(df),i]
|
|
710
|
+
pi += len(df)
|
|
711
|
+
|
|
712
|
+
df_list.append(df_all_state_list)
|
|
713
|
+
M_str_list.append('All states')
|
|
714
|
+
#################################################################
|
|
715
|
+
|
|
716
|
+
#################################################################
|
|
717
|
+
# Creating Excel Writer Object from Pandas
|
|
718
|
+
with pd.ExcelWriter(xlsx_outfile, engine="openpyxl") as writer:
|
|
719
|
+
workbook=writer.book
|
|
720
|
+
for idx_0, df_list_0 in enumerate(df_list):
|
|
721
|
+
worksheet=workbook.create_sheet(M_str_list[idx_0])
|
|
722
|
+
writer.sheets[M_str_list[idx_0]] = worksheet
|
|
723
|
+
start_row = 0
|
|
724
|
+
for idx_1, df in enumerate(df_list_0):
|
|
725
|
+
df.to_excel(writer, sheet_name=M_str_list[idx_0], index=True, float_format="%.4f", startrow=start_row , startcol=0)
|
|
726
|
+
start_row += len(df) + 2
|
|
727
|
+
self.logger.debug(f'SAVED: {xlsx_outfile}')
|
|
728
|
+
#################################################################
|
|
729
|
+
|
|
730
|
+
self.logger.info(f'LiP_XL_MS_ConsistencyTest DONE!')
|
|
731
|
+
return npz_outfile, xlsx_outfile
|
|
732
|
+
##############################################################################
|
|
733
|
+
|
|
734
|
+
#######################################################################################
|
|
735
|
+
def load_OP(self, start:int=0, end:int=99999999999):
    """
    Load the Q and G order-parameter time series of every trajectory that has both a Q and a G file.

    Files are globbed from '<self.OPpath>/Q/*.Q' and '<self.OPpath>/G/*.G'. The trajectory number of a
    file is the integer that follows the last 'Traj' in its stem. Trajectories listed in
    self.rm_traj_list are dropped. The Q series is the 'total' column of the Q file and the G series
    is the 'G' column of the G file.

    Parameters
    ----------
    start : int
        If negative, the last -start rows of each file are kept and `end` is ignored.
        Otherwise only rows with start <= Frame <= end are kept.
    end : int
        Inclusive upper bound on the 'Frame' column (used only when start >= 0).

    Returns
    -------
    Q_data, G_data : np.ndarray
        np.asarray of the per-trajectory float series, in ascending trajectory-number order.
        A trajectory whose Q and G series have different shapes is skipped with a warning.

    Raises
    ------
    AssertionError
        If the number of Q files differs from the number of G files, or a trajectory number
        does not map to exactly one Q file and one G file.
    ValueError
        If a selected Q or G series contains NaN.
    """
    self.logger.info(f'Loading G and Q order parameters...')

    def group_by_traj(files):
        # Index the files by their parsed trajectory number so the lookup below does not depend
        # on how the number is formatted in the file name (e.g. 'Traj01' vs. 'Traj1').
        grouped = {}
        for fname in files:
            traj_num = int(pathlib.Path(fname).stem.split('Traj')[-1])
            grouped.setdefault(traj_num, []).append(fname)
        return grouped

    Qfiles = glob.glob(os.path.join(self.OPpath, 'Q/*.Q'))
    Gfiles = glob.glob(os.path.join(self.OPpath, 'G/*.G'))
    Q_by_traj = group_by_traj(Qfiles)
    G_by_traj = group_by_traj(Gfiles)

    shared_Trajs = sorted(set(Q_by_traj).intersection(G_by_traj))
    self.logger.debug(f'shared_Trajs: {shared_Trajs}')
    self.logger.info(f'Number of Q files found: {len(Qfiles)} | Number of G files found: {len(Gfiles)}')

    # Raised explicitly (instead of using `assert`) so the check survives `python -O`
    # while callers still see the same exception type as before.
    if len(Qfiles) != len(Gfiles):
        raise AssertionError(f"The # of Q and G files {len(Qfiles)} and {len(Gfiles)} is not equal")

    ## remove trajectories that are in the rm_traj_list
    if len(self.rm_traj_list) > 0:
        self.logger.info(f'Removing trajectories: {self.rm_traj_list}')
        shared_Trajs = [traj for traj in shared_Trajs if traj not in self.rm_traj_list]
        self.logger.info(f'Number of shared Traj after removing: {len(shared_Trajs)}')

    # load the Q and G time series of every shared trajectory
    Q_data = []
    G_data = []
    for traj in shared_Trajs:

        # get the corresponding G and Q file
        Qf = Q_by_traj[traj]
        Gf = G_by_traj[traj]

        ## Quality check that only a single G and Q file were found
        if len(Qf) != 1:
            raise AssertionError(f"the number of Q files {len(Qf)} should equal 1 for Traj {traj}")
        if len(Gf) != 1:
            raise AssertionError(f"the number of G files {len(Gf)} should equal 1 for Traj {traj}")

        # load the G Q data and extract only the time series column
        Qframe = pd.read_csv(Qf[0], sep=',')
        Gframe = pd.read_csv(Gf[0])
        if start < 0: # a negative start is taken as slicing from the end of the array
            Qdata = Qframe['total'].values[start:].astype(float)
            Gdata = Gframe['G'].values[start:].astype(float)
        else:
            Qkeep = (Qframe['Frame'] >= start) & (Qframe['Frame'] <= end)
            Qdata = Qframe[Qkeep]['total'].values.astype(float)

            Gkeep = (Gframe['Frame'] >= start) & (Gframe['Frame'] <= end)
            Gdata = Gframe[Gkeep]['G'].values.astype(float)

        ## Quality check that the G and Q data has the same number of frames
        if Qdata.shape != Gdata.shape:
            self.logger.warning(f"WARNING: The number of frames in Q {Qdata.shape} should equal the number of frames in G {Gdata.shape} in Traj {traj}")
            continue

        ## Check and ensure that Qdata or Gdata has no nan values
        if np.isnan(Qdata).any():
            raise ValueError(f'There is a NaN value in the Q data of Traj {traj} ({Qf[0]})')

        if np.isnan(Gdata).any():
            raise ValueError(f'There is a NaN value in the G data of Traj {traj} ({Gf[0]})')

        Q_data.append(Qdata)
        G_data.append(Gdata)

    Q_data = np.asarray(Q_data)
    G_data = np.asarray(G_data)
    self.logger.debug(f'Q_data: {Q_data.shape}')
    self.logger.debug(f'G_data: {G_data.shape}')
    return Q_data, G_data
|
|
813
|
+
##############################################################################
|
|
814
|
+
|
|
815
|
+
##############################################################################
|
|
816
|
+
def select_rep_structs(self, consist_data_file:str, consist_result_file:str, total_traj_num_frames:int, last_num_frames:int):
|
|
817
|
+
"""
|
|
818
|
+
After performing the consistency test select representative structures with high consistency
|
|
819
|
+
"""
|
|
820
|
+
self.logger.info(f'Selecting representative structure')
|
|
821
|
+
|
|
822
|
+
############### Functions ###############
|
|
823
|
+
##############################################################################
|
|
824
|
+
def parse_consistant_results(consist_result_file, sheet_name, test_type='two-tailed', significant_level=0.05):
    """
    Collect the significant rows of one sheet of the consistency-test workbook.

    The sheet is read with pd.read_excel and walked row by row, using cell positions only:
    cell 0 is the key, cell 1 the (1-based) state id or 'All states', cell 3 the
    '<mean> [<lb>, <ub>] vs. <mean> [<lb>, <ub>]' summary text, and cell `t_idx` the p-value
    that is compared against `significant_level` (-3 for 'two-tailed', -1 for 'one-tailed').

    Row handling:
      * last cell is NaN            -> row is skipped
      * first cell is NaN           -> row is treated as a repeated header; the test sign is
                                       re-read from the text in parentheses of cell `t_idx`
      * cell 1 equals 'All states'  -> row is skipped
      * otherwise                   -> data row

    Returns
    -------
    dict
        key -> [sign, [state_id - 1, ...], mean, [lb, ub], key_position] where mean/lb/ub are
        parsed from the part of the summary text after 'vs.' and key_position is the order in
        which the key first appeared among the data rows (significant or not).

    Raises
    ------
    ValueError
        If `test_type` is neither 'two-tailed' nor 'one-tailed'.
    """
    if test_type == 'two-tailed':
        t_idx = -3
    elif test_type == 'one-tailed':
        t_idx = -1
    else:
        message = 'Error: Wrong test_type = %s; can be either "two-tailed" or "one-tailed"'%test_type
        self.logger.error(message)
        # Raise instead of sys.exit(): a bare sys.exit() ends the process with status 0.
        raise ValueError(message)

    consist_data = pd.read_excel(consist_result_file, sheet_name=sheet_name)
    sign = consist_data.columns[t_idx].split()[-1][1:-1]

    data = {}
    key_position = {}  # key -> order of first appearance among the data rows
    # Plain tuples give true positional access; integer keys on a string-labelled
    # Series (row[0], row[-1]) are deprecated in pandas 2.x and removed in pandas 3.
    for cells in consist_data.itertuples(index=False, name=None):
        if pd.isna(cells[-1]):
            continue
        if pd.isna(cells[0]):
            sign = cells[t_idx].split()[-1][1:-1]
            continue
        if cells[1] == 'All states':
            continue

        key = cells[0]
        key_position.setdefault(key, len(key_position))
        if not cells[t_idx] < significant_level:
            continue

        if key in data:
            data[key][1].append(cells[1]-1)
            continue

        # Summary text after 'vs.' looks like '<mean> [<lb>, <ub>]'
        words = cells[3].strip().split('vs.')[-1].split()
        mean = float(words[0])
        lb = float(words[1][1:-1])
        ub = float(words[2][:-1])
        data[key] = [sign, [cells[1]-1], mean, [lb, ub], key_position[key]]

    return data
|
|
862
|
+
##############################################################################
|
|
863
|
+
|
|
864
|
+
##############################################################################
|
|
865
|
+
def calc_rel_change(traj_idx, frame_idx):
    """
    Return |relative change| for every consistency signal stored in dtrajs_MS[traj_idx, frame_idx].

    A signal whose interior (first and last character excluded) contains '-' is looked up in
    XLMS_consist_data / M_XLMS; any other signal is looked up in LIPMS_consist_data / M_LiPMS.
    For each signal the reference mean is entry[2] and the observed value is
    M[traj_idx, frame_idx, entry[4]]. The result is |observed/mean - 1|, or |observed/1e-5|
    when the reference mean is exactly 0.
    """
    changes = []
    for signal in dtrajs_MS[traj_idx, frame_idx]:
        if '-' not in signal[1:-1]:
            table, metric = LIPMS_consist_data, M_LiPMS
        else:
            table, metric = XLMS_consist_data, M_XLMS

        reference = table[signal][2]
        observed = metric[traj_idx, frame_idx, table[signal][4]]

        # A zero reference mean cannot be used as a divisor; 1e-5 stands in for it.
        if reference != 0:
            changes.append(np.abs(observed/reference-1))
        else:
            changes.append(np.abs(observed/1e-5))
    return changes
|
|
882
|
+
##############################################################################
|
|
883
|
+
|
|
884
|
+
|
|
885
|
+
################### MAIN ###############
|
|
886
|
+
##############################################################################
|
|
887
|
+
self.logger.info(f'Loading consistency test data from {consist_result_file}')
|
|
888
|
+
self.logger.info(f'Loading consistency test data from {consist_data_file}')
|
|
889
|
+
|
|
890
|
+
if_backmap = 0
|
|
891
|
+
pulchra_only = True
|
|
892
|
+
significant_level = 0.05
|
|
893
|
+
##############################################################################
|
|
894
|
+
|
|
895
|
+
##############################################################################
|
|
896
|
+
# Load MSM data
|
|
897
|
+
MSM_data = pd.read_csv(self.msm_data_file)
|
|
898
|
+
self.logger.info(f'MSM_data\n{MSM_data}')
|
|
899
|
+
meta_states = MSM_data['metastablestate'].unique()
|
|
900
|
+
meta_states = np.array(meta_states, dtype=int)
|
|
901
|
+
self.logger.debug(f'meta_states: {meta_states}')
|
|
902
|
+
num_meta_states = len(meta_states)
|
|
903
|
+
self.logger.debug(f'num_meta_states: {num_meta_states}')
|
|
904
|
+
|
|
905
|
+
rm_traj_set = set(int(t) for t in self.rm_traj_list)
|
|
906
|
+
meta_dtrajs_last = []
|
|
907
|
+
micro_dtrajs_last = []
|
|
908
|
+
MSM_traj_idx_to_trajnum = {} # mapping traj_idx to traj number (after rm_traj_list filtering)
|
|
909
|
+
for traj, traj_df in MSM_data.groupby('traj'):
|
|
910
|
+
traj = int(traj)
|
|
911
|
+
if traj in rm_traj_set:
|
|
912
|
+
continue
|
|
913
|
+
traj_len = len(traj_df)
|
|
914
|
+
#print(f'traj: {traj}, traj_len: {traj_len}\n{traj_df.head()}')
|
|
915
|
+
|
|
916
|
+
last = traj_df.iloc[-last_num_frames:,:]
|
|
917
|
+
last = last.reset_index(drop=True)
|
|
918
|
+
meta_last = last['metastablestate'].values
|
|
919
|
+
micro_last = last['microstate'].values
|
|
920
|
+
#print(f'last: {last}')
|
|
921
|
+
meta_dtrajs_last.append(meta_last)
|
|
922
|
+
micro_dtrajs_last.append(micro_last)
|
|
923
|
+
|
|
924
|
+
MSM_traj_idx_to_trajnum[len(MSM_traj_idx_to_trajnum)] = traj
|
|
925
|
+
|
|
926
|
+
meta_dtrajs_last = np.array(meta_dtrajs_last)
|
|
927
|
+
micro_dtrajs_last = np.array(micro_dtrajs_last)
|
|
928
|
+
self.logger.info(f'meta_dtrajs_last\n{meta_dtrajs_last} {meta_dtrajs_last.shape}')
|
|
929
|
+
self.logger.debug(np.unique(meta_dtrajs_last))
|
|
930
|
+
self.logger.info(f'micro_dtrajs_last\n{micro_dtrajs_last} {micro_dtrajs_last.shape}')
|
|
931
|
+
self.logger.debug(np.unique(micro_dtrajs_last))
|
|
932
|
+
self.logger.debug(f'MSM_traj_idx_to_trajnum: {MSM_traj_idx_to_trajnum}')
|
|
933
|
+
|
|
934
|
+
## load the meta_dist data
|
|
935
|
+
meta_dist = np.load(self.meta_dist_file, allow_pickle=True)
|
|
936
|
+
self.logger.info(f'meta_dist:\n{meta_dist} {meta_dist.shape}')
|
|
937
|
+
##############################################################################
|
|
938
|
+
|
|
939
|
+
|
|
940
|
+
##############################################################################
|
|
941
|
+
# Load Consistency Metrics
|
|
942
|
+
consist_data = np.load(consist_data_file, allow_pickle=True)
|
|
943
|
+
M_LiPMS = consist_data['M_LiPMS'][:,-last_num_frames:,:]
|
|
944
|
+
M_XLMS = consist_data['M_XLMS'][:,-last_num_frames:,:]
|
|
945
|
+
self.logger.info(f'Loaded consistency metrics from {consist_data_file}')
|
|
946
|
+
self.logger.debug(f'M_LiPMS: {M_LiPMS.shape}')
|
|
947
|
+
self.logger.debug(f'M_XLMS: {M_XLMS.shape}')
|
|
948
|
+
##############################################################################
|
|
949
|
+
|
|
950
|
+
##############################################################################
|
|
951
|
+
# Load consistency test results
|
|
952
|
+
LIPMS_consist_data = parse_consistant_results(consist_result_file, sheet_name='LiPMS', test_type='two-tailed', significant_level=significant_level)
|
|
953
|
+
XLMS_consist_data = parse_consistant_results(consist_result_file, sheet_name='XLMS', test_type='two-tailed', significant_level=significant_level)
|
|
954
|
+
self.logger.info(f'Loaded consistency test results')
|
|
955
|
+
##############################################################################
|
|
956
|
+
|
|
957
|
+
|
|
958
|
+
##############################################################################
|
|
959
|
+
# Load cluster data
|
|
960
|
+
# Beware the order of traj in the cluster_data may not be the same as the order in the MSM_data
|
|
961
|
+
# we will correct that here so it is consistent moving forward with the analysis
|
|
962
|
+
cluster_data = np.load(self.cluster_data_file, allow_pickle=True)
|
|
963
|
+
idx2trajfile = cluster_data['idx2trajfile'].tolist()
|
|
964
|
+
idx2traj = np.asarray([int(f.strip().split('/')[-1].split('_')[0]) for f in idx2trajfile])
|
|
965
|
+
# print(f'idx2traj: {idx2traj} {len(idx2traj)}')
|
|
966
|
+
|
|
967
|
+
# get the ordering array and reorder all the cluster_data arrays
|
|
968
|
+
# also remove those trajectories that are in the rm_traj_list
|
|
969
|
+
order = np.argsort(idx2traj)
|
|
970
|
+
# print(f'order: {order} {len(order)}')
|
|
971
|
+
idx2traj = idx2traj[order]
|
|
972
|
+
# print(f'idx2traj: {idx2traj} {len(idx2traj)}')
|
|
973
|
+
|
|
974
|
+
idx_2_keep = [idx for idx, traj in enumerate(idx2traj) if traj not in self.rm_traj_list]
|
|
975
|
+
# print(f'idx_2_keep: {idx_2_keep} {len(idx_2_keep)}')
|
|
976
|
+
|
|
977
|
+
# reorder idx2trajfile
|
|
978
|
+
idx2traj = idx2traj[idx_2_keep]
|
|
979
|
+
# print(f'idx2traj after removal of mirror images: {idx2traj} {len(idx2traj)}')
|
|
980
|
+
|
|
981
|
+
# QC to ensure idx2traj matches MSM_traj_idx_to_trajnum
|
|
982
|
+
assert np.array_equal(idx2traj, np.asarray(list(MSM_traj_idx_to_trajnum.values()))), f"Error: idx2traj does not match MSM_traj_idx_to_trajnum: {idx2traj} vs. {np.asarray(list(MSM_traj_idx_to_trajnum.values()))}"
|
|
983
|
+
|
|
984
|
+
# load the dtraj data
|
|
985
|
+
dtrajs = cluster_data['dtrajs']
|
|
986
|
+
dtrajs = dtrajs[order]
|
|
987
|
+
dtrajs = dtrajs[idx_2_keep]
|
|
988
|
+
self.logger.info(f'dtrajs after removal of mirror images: {dtrajs.shape}')
|
|
989
|
+
|
|
990
|
+
# load the rep_chg_ent_dtrajs
|
|
991
|
+
rep_chg_ent_dtrajs = cluster_data['rep_chg_ent_dtrajs']
|
|
992
|
+
# Map resid to residue idx using resid2residueidx_map
|
|
993
|
+
self.logger.info(f'Mapping of rep_chg_ent_dtrajs resid to residu idx using resid2residueidx_map: {self.resid2residueidx_map}')
|
|
994
|
+
for traj_idx, traj_data in enumerate(rep_chg_ent_dtrajs):
|
|
995
|
+
for frame_idx, frame_data in enumerate(traj_data):
|
|
996
|
+
for fingerprint_id, fingerprint in frame_data.items():
|
|
997
|
+
# print(f'\n{"#"*50}\nTraj idx: {traj_idx} | Frame idx {frame_idx} | Fingerprint ID: {fingerprint_id}')
|
|
998
|
+
|
|
999
|
+
for fingerprint_key, new_key in {'crossing_resid':'crossing_residx', 'ref_crossing_resid':'ref_crossing_residx'}.items():
|
|
1000
|
+
residx_arr = fingerprint[fingerprint_key]
|
|
1001
|
+
residx_arr = [[self.resid2residueidx_map[x] for x in sublist] for sublist in residx_arr]
|
|
1002
|
+
fingerprint[new_key] = residx_arr
|
|
1003
|
+
# print(f' mapped {new_key}: {residx_arr}')
|
|
1004
|
+
|
|
1005
|
+
for fingerprint_key, new_key in {'native_contact':'native_contact_residx', 'ref_native_contact':'ref_native_contact_residx'}.items():
|
|
1006
|
+
residx_arr = fingerprint[fingerprint_key]
|
|
1007
|
+
residx_arr = [self.resid2residueidx_map[x] for x in residx_arr]
|
|
1008
|
+
fingerprint[new_key] = residx_arr
|
|
1009
|
+
# print(f' mapped {new_key}: {residx_arr}')
|
|
1010
|
+
|
|
1011
|
+
# for k,v in fingerprint.items():
|
|
1012
|
+
# print(f' {k}: {v}')
|
|
1013
|
+
|
|
1014
|
+
rep_chg_ent_dtrajs = rep_chg_ent_dtrajs[order]
|
|
1015
|
+
rep_chg_ent_dtrajs = rep_chg_ent_dtrajs[idx_2_keep]
|
|
1016
|
+
self.logger.info(f'rep_chg_ent_dtrajs after removal of mirror images: {rep_chg_ent_dtrajs.shape}')
|
|
1017
|
+
|
|
1018
|
+
sorted_chg_ent_structure_keyword_list = cluster_data['sorted_chg_ent_structure_keyword_list'].tolist()
|
|
1019
|
+
self.logger.debug(f'sorted_chg_ent_structure_keyword_list: {len(sorted_chg_ent_structure_keyword_list)} {sorted_chg_ent_structure_keyword_list[:10]}')
|
|
1020
|
+
|
|
1021
|
+
dtrajs_cluster_idx = np.array([[sorted_chg_ent_structure_keyword_list.index(str(dd)) for dd in d] for d in dtrajs])
|
|
1022
|
+
self.logger.debug(f'dtrajs_cluster_idx: {dtrajs_cluster_idx.shape}')
|
|
1023
|
+
self.logger.info(f'Loaded cluster_data_file: {self.cluster_data_file}')
|
|
1024
|
+
##############################################################################
|
|
1025
|
+
|
|
1026
|
+
##############################################################################
|
|
1027
|
+
# Load SASA data
|
|
1028
|
+
sasa_traj_list = np.load(self.sasa_data_file, allow_pickle=True)[:,-last_num_frames:,:]
|
|
1029
|
+
self.logger.info(f'Loaded SASA data: {self.sasa_data_file}')
|
|
1030
|
+
|
|
1031
|
+
# remove trajectories that are in the rm_traj_list - 1
|
|
1032
|
+
sasa_traj_list = [v for i, v in enumerate(sasa_traj_list) if i not in np.asarray(self.rm_traj_list) - 1]
|
|
1033
|
+
sasa_traj_list = np.array(sasa_traj_list)
|
|
1034
|
+
self.logger.info(f'sasa_traj_list.shape after removal of mirror images: {sasa_traj_list.shape}')
|
|
1035
|
+
##############################################################################
|
|
1036
|
+
|
|
1037
|
+
|
|
1038
|
+
##############################################################################
|
|
1039
|
+
# Change the ub and lb in LIPMS_consist_data and XLMS_consist_data
|
|
1040
|
+
# (1) Get the native state index (defined by the MSM indexing which is sorted in order of the trajectory number)
|
|
1041
|
+
native_sel = np.where(np.isin(meta_dtrajs_last, self.native_state_idx))
|
|
1042
|
+
native_sel = np.array([native_sel[0], native_sel[1]], dtype=int).T
|
|
1043
|
+
|
|
1044
|
+
# (2) remove any frames with NAN residual SASA
|
|
1045
|
+
nan_frame_sel = np.where(np.isnan(sasa_traj_list))
|
|
1046
|
+
nan_frame_sel = np.array([nan_frame_sel[0], nan_frame_sel[1]], dtype=int).T
|
|
1047
|
+
native_sel = np.array([sel for sel in native_sel if not sel.tolist() in nan_frame_sel.tolist()])
|
|
1048
|
+
self.logger.debug(f'native_sel: {native_sel} {native_sel.shape}')
|
|
1049
|
+
|
|
1050
|
+
# (3) Get the native M_LiPMS and M_XLMS
|
|
1051
|
+
native_M_LiPMS = M_LiPMS[native_sel[:,0], native_sel[:,1], :]
|
|
1052
|
+
native_M_XLMS = M_XLMS[native_sel[:,0], native_sel[:,1], :]
|
|
1053
|
+
native_M_data_outfile = os.path.join(self.outdir, 'native_M_data.npz')
|
|
1054
|
+
np.savez(native_M_data_outfile,
|
|
1055
|
+
M_LiPMS = native_M_LiPMS,
|
|
1056
|
+
M_XLMS = native_M_XLMS,)
|
|
1057
|
+
for key in LIPMS_consist_data:
|
|
1058
|
+
LIPMS_consist_data[key][3][0] = np.percentile(native_M_LiPMS[:, LIPMS_consist_data[key][4]], 2.5)
|
|
1059
|
+
LIPMS_consist_data[key][3][1] = np.percentile(native_M_LiPMS[:, LIPMS_consist_data[key][4]], 97.5)
|
|
1060
|
+
for key in XLMS_consist_data:
|
|
1061
|
+
XLMS_consist_data[key][3][0] = np.percentile(native_M_XLMS[:, XLMS_consist_data[key][4]], 2.5)
|
|
1062
|
+
XLMS_consist_data[key][3][1] = np.percentile(native_M_XLMS[:, XLMS_consist_data[key][4]], 97.5)
|
|
1063
|
+
self.logger.debug(f'SAVED: {native_M_data_outfile}')
|
|
1064
|
+
##############################################################################
|
|
1065
|
+
|
|
1066
|
+
##############################################################################
|
|
1067
|
+
# Go through LiPMS data
|
|
1068
|
+
LIPMS_struct_data = {}
|
|
1069
|
+
dtrajs_MS = np.empty(meta_dtrajs_last.shape, dtype=object)
|
|
1070
|
+
|
|
1071
|
+
# Initialize dtrajs_MS with empty lists
|
|
1072
|
+
for i in range(len(dtrajs_MS)):
|
|
1073
|
+
for j in range(len(dtrajs_MS[i])):
|
|
1074
|
+
dtrajs_MS[i,j] = []
|
|
1075
|
+
|
|
1076
|
+
# Go through each PK site in the LiPMS consistency data
|
|
1077
|
+
self.logger.info(f'\nProcessing LiPMS consistency data')
|
|
1078
|
+
for pk, data in LIPMS_consist_data.items():
|
|
1079
|
+
# print(pk, data)
|
|
1080
|
+
|
|
1081
|
+
LIPMS_struct_data[pk] = {}
|
|
1082
|
+
sign = data[0]
|
|
1083
|
+
state_list = data[1]
|
|
1084
|
+
sasa_ub = data[3][1]
|
|
1085
|
+
sasa_lb = data[3][0]
|
|
1086
|
+
idx_pk = data[4]
|
|
1087
|
+
|
|
1088
|
+
for idx_0, state_idx in enumerate(state_list):
|
|
1089
|
+
LIPMS_struct_data[pk][state_idx] = {}
|
|
1090
|
+
|
|
1091
|
+
idx_list = np.array(np.where(meta_dtrajs_last == state_idx)).T
|
|
1092
|
+
idx_list = np.array([sel for sel in idx_list if not sel.tolist() in nan_frame_sel.tolist()]) # Skip frames with NAN residual SASA (bad backmapped structure)
|
|
1093
|
+
|
|
1094
|
+
sasa_list = M_LiPMS[idx_list[:,0],idx_list[:,1], idx_pk]
|
|
1095
|
+
|
|
1096
|
+
for cluster_idx, keyword in enumerate(sorted_chg_ent_structure_keyword_list):
|
|
1097
|
+
idx_list_1 = np.where(dtrajs_cluster_idx[idx_list[:,0],idx_list[:,1]] == cluster_idx)[0]
|
|
1098
|
+
|
|
1099
|
+
if len(idx_list_1) == 0:
|
|
1100
|
+
continue
|
|
1101
|
+
if sign == '>':
|
|
1102
|
+
idx_list_2 = np.where(sasa_list[idx_list_1] > sasa_ub)[0]
|
|
1103
|
+
idx_rep = np.argmax(sasa_list[idx_list_1])
|
|
1104
|
+
elif sign == '<':
|
|
1105
|
+
idx_list_2 = np.where(sasa_list[idx_list_1] < sasa_lb)[0]
|
|
1106
|
+
idx_rep = np.argmin(sasa_list[idx_list_1])
|
|
1107
|
+
else:
|
|
1108
|
+
idx_list_2 = np.where(np.any([sasa_list[idx_list_1] > sasa_ub, sasa_list[idx_list_1] < sasa_lb], axis=0))[0]
|
|
1109
|
+
idx_rep = np.argmax(np.max([sasa_list[idx_list_1]-sasa_ub, sasa_lb-sasa_list[idx_list_1]], axis=0))
|
|
1110
|
+
if len(idx_list_2) == 0:
|
|
1111
|
+
continue
|
|
1112
|
+
|
|
1113
|
+
consist_idx_list = idx_list[idx_list_1[idx_list_2], :]
|
|
1114
|
+
rep_idx = idx_list[idx_list_1[idx_rep], :]
|
|
1115
|
+
LIPMS_struct_data[pk][state_idx][cluster_idx] = [rep_idx, consist_idx_list]
|
|
1116
|
+
# print(f'PK: {pk}, state_idx: {state_idx}, cluster_idx: {cluster_idx}, rep_idx: {rep_idx}, consist_idx_list: {len(consist_idx_list)}')
|
|
1117
|
+
|
|
1118
|
+
for idx in consist_idx_list:
|
|
1119
|
+
dtrajs_MS[idx[0],idx[1]].append(pk)
|
|
1120
|
+
##############################################################################
|
|
1121
|
+
|
|
1122
|
+
##############################################################################
|
|
1123
|
+
# Go through XLMS data
|
|
1124
|
+
XLMS_struct_data = {}
|
|
1125
|
+
self.logger.info(f'\nProcessing XLMS consistency data')
|
|
1126
|
+
for pair, data in XLMS_consist_data.items():
|
|
1127
|
+
# print(pair, data)
|
|
1128
|
+
|
|
1129
|
+
# Initialize the structure data for each pair
|
|
1130
|
+
XLMS_struct_data[pair] = {}
|
|
1131
|
+
sign = data[0]
|
|
1132
|
+
state_list = data[1]
|
|
1133
|
+
dist_ub = data[3][1]
|
|
1134
|
+
dist_lb = data[3][0]
|
|
1135
|
+
idx_pair = data[4]
|
|
1136
|
+
|
|
1137
|
+
for state_idx in state_list:
|
|
1138
|
+
XLMS_struct_data[pair][state_idx] = {}
|
|
1139
|
+
|
|
1140
|
+
idx_list = np.array(np.where(meta_dtrajs_last == state_idx)).T
|
|
1141
|
+
idx_list = np.array([sel for sel in idx_list if not sel.tolist() in nan_frame_sel.tolist()])
|
|
1142
|
+
|
|
1143
|
+
dist_list = M_XLMS[idx_list[:,0],idx_list[:,1],idx_pair]
|
|
1144
|
+
for cluster_idx, keyword in enumerate(sorted_chg_ent_structure_keyword_list):
|
|
1145
|
+
idx_list_1 = np.where(dtrajs_cluster_idx[idx_list[:,0],idx_list[:,1]] == cluster_idx)[0]
|
|
1146
|
+
if len(idx_list_1) == 0:
|
|
1147
|
+
continue
|
|
1148
|
+
if sign == '>':
|
|
1149
|
+
idx_list_2 = np.where(dist_list[idx_list_1] > dist_ub)[0]
|
|
1150
|
+
idx_rep = np.argmax(dist_list[idx_list_1])
|
|
1151
|
+
elif sign == '<':
|
|
1152
|
+
idx_list_2 = np.where(dist_list[idx_list_1] < dist_lb)[0]
|
|
1153
|
+
idx_rep = np.argmin(dist_list[idx_list_1])
|
|
1154
|
+
else:
|
|
1155
|
+
idx_list_2 = np.where(np.any([dist_list[idx_list_1] > dist_ub, dist_list[idx_list_1] < dist_lb], axis=0))[0]
|
|
1156
|
+
idx_rep = np.argmax(np.max([dist_list[idx_list_1]-dist_ub, dist_lb-dist_list[idx_list_1]], axis=0))
|
|
1157
|
+
if len(idx_list_2) == 0:
|
|
1158
|
+
continue
|
|
1159
|
+
|
|
1160
|
+
consist_idx_list = idx_list[idx_list_1[idx_list_2], :]
|
|
1161
|
+
rep_idx = idx_list[idx_list_1[idx_rep], :]
|
|
1162
|
+
XLMS_struct_data[pair][state_idx][cluster_idx] = [rep_idx, consist_idx_list]
|
|
1163
|
+
#print(f'Pair: {pair}, state_idx: {state_idx}, cluster_idx: {cluster_idx}, rep_idx: {rep_idx}, consist_idx_list: {len(consist_idx_list)}')
|
|
1164
|
+
|
|
1165
|
+
for idx in consist_idx_list:
|
|
1166
|
+
dtrajs_MS[idx[0],idx[1]].append(pair)
|
|
1167
|
+
##############################################################################
|
|
1168
|
+
self.logger.debug(f'dtrajs_MS: {dtrajs_MS.shape}')
|
|
1169
|
+
|
|
1170
|
+
|
|
1171
|
+
##############################################################################
|
|
1172
|
+
# Group consistency data
|
|
1173
|
+
self.logger.info(f'\nGrouping consistency data based on consistency signals')
|
|
1174
|
+
consist_signal_dict = {}
|
|
1175
|
+
for i, d in enumerate(dtrajs_MS): # list of lists of consistency signals in each frame
|
|
1176
|
+
|
|
1177
|
+
for j, dd in enumerate(d): # list of consistency signals in frame idx j
|
|
1178
|
+
|
|
1179
|
+
# print(f'MSM traj idx: {i}, frame idx: {j}, consistency signal: {dd}')
|
|
1180
|
+
cluster_idxs = dtrajs_cluster_idx[i,j]
|
|
1181
|
+
# print(f'Cluster idx: {cluster_idxs}')
|
|
1182
|
+
|
|
1183
|
+
if str(dd) not in consist_signal_dict.keys():
|
|
1184
|
+
consist_signal_dict[str(dd)] = {cluster_idxs: [[i,j]]}
|
|
1185
|
+
|
|
1186
|
+
elif cluster_idxs not in consist_signal_dict[str(dd)].keys():
|
|
1187
|
+
consist_signal_dict[str(dd)][cluster_idxs] = [[i,j]]
|
|
1188
|
+
|
|
1189
|
+
else:
|
|
1190
|
+
consist_signal_dict[str(dd)][cluster_idxs].append([i,j])
|
|
1191
|
+
|
|
1192
|
+
Num_struct_list = [np.sum(np.array([len(vv) for vv in v.values()])) for k,v in consist_signal_dict.items()]
|
|
1193
|
+
sort_idx = np.argsort(-np.array(Num_struct_list, dtype=int))
|
|
1194
|
+
sorted_consist_signal_list = [list(consist_signal_dict.keys())[idx] for idx in sort_idx]
|
|
1195
|
+
# print(f'sorted_consist_signal_list: {sorted_consist_signal_list}')
|
|
1196
|
+
sorted_consist_signal_dict = {}
|
|
1197
|
+
for k in sorted_consist_signal_list:
|
|
1198
|
+
kk_list = sorted(list(consist_signal_dict[k].keys()))
|
|
1199
|
+
sorted_consist_signal_dict[k] = {kk: consist_signal_dict[k][kk] for kk in kk_list}
|
|
1200
|
+
|
|
1201
|
+
# for a, b in sorted_consist_signal_dict[k].items():
|
|
1202
|
+
# print(f'\n{k}, {a}, {b}')
|
|
1203
|
+
##############################################################################
|
|
1204
|
+
|
|
1205
|
+
|
|
1206
|
+
|
|
1207
|
+
##############################################################################
|
|
1208
|
+
# Group based on metastable states, then consistency
|
|
1209
|
+
self.logger.info(f'\nGrouping consistency data based on metastable states and consistency signals')
|
|
1210
|
+
group_dict = {}
|
|
1211
|
+
for state_id in self.state_idx_list:
|
|
1212
|
+
# print(f'Processing state {state_id}...')
|
|
1213
|
+
group_dict[state_id] = {}
|
|
1214
|
+
|
|
1215
|
+
idx_list = np.array(np.where(meta_dtrajs_last == state_id)).T
|
|
1216
|
+
|
|
1217
|
+
for idx_list_0 in idx_list:
|
|
1218
|
+
|
|
1219
|
+
[i, j] = list(idx_list_0)
|
|
1220
|
+
dd = dtrajs_MS[i, j]
|
|
1221
|
+
# print(f'Processing idx: {i}, {j}, dd: {dd}')
|
|
1222
|
+
cluster_idxs = dtrajs_cluster_idx[i,j]
|
|
1223
|
+
# print(f'Cluster idx: {cluster_idxs}')
|
|
1224
|
+
|
|
1225
|
+
if str(dd) not in group_dict[state_id].keys():
|
|
1226
|
+
group_dict[state_id][str(dd)] = {cluster_idxs: [[i,j]]}
|
|
1227
|
+
|
|
1228
|
+
elif cluster_idxs not in group_dict[state_id][str(dd)].keys():
|
|
1229
|
+
group_dict[state_id][str(dd)][cluster_idxs] = [[i,j]]
|
|
1230
|
+
|
|
1231
|
+
else:
|
|
1232
|
+
group_dict[state_id][str(dd)][cluster_idxs].append([i,j])
|
|
1233
|
+
|
|
1234
|
+
# sort based on population
|
|
1235
|
+
Num_struct_list = [np.sum(np.array([len(vv) for vv in v.values()])) for k,v in group_dict[state_id].items()]
|
|
1236
|
+
sort_idx = np.argsort(-np.array(Num_struct_list, dtype=int))
|
|
1237
|
+
sorted_list = [list(group_dict[state_id].keys())[idx] for idx in sort_idx]
|
|
1238
|
+
new_dict = {}
|
|
1239
|
+
for k in sorted_list:
|
|
1240
|
+
kk_list = sorted(list(group_dict[state_id][k].keys()))
|
|
1241
|
+
new_dict[k] = {kk: group_dict[state_id][k][kk] for kk in kk_list}
|
|
1242
|
+
group_dict[state_id] = new_dict
|
|
1243
|
+
|
|
1244
|
+
# print(f'\nGrouped consistency data for state {state_id}:')
|
|
1245
|
+
# for k,v in group_dict[state_id].items():
|
|
1246
|
+
# print(f'{k}: {v}')
|
|
1247
|
+
##############################################################################
|
|
1248
|
+
|
|
1249
|
+
##############################################################################
|
|
1250
|
+
# Load Q list
|
|
1251
|
+
self.logger.info(f'Loading G and Q data from {self.OPpath}')
|
|
1252
|
+
Q_list, G_list = self.load_OP(start=-last_num_frames)
|
|
1253
|
+
self.logger.info(f'Loaded G and Q data from {self.OPpath}')
|
|
1254
|
+
##############################################################################
|
|
1255
|
+
|
|
1256
|
+
##############################################################################
|
|
1257
|
+
# Get representative structures
|
|
1258
|
+
self.logger.info(f'\nGetting representative structures for each group...')
|
|
1259
|
+
rep_group_dict = {}
|
|
1260
|
+
for state_id in self.state_idx_list:
|
|
1261
|
+
# print(f'Processing state {state_id}...')
|
|
1262
|
+
|
|
1263
|
+
rep_group_dict[state_id] = {}
|
|
1264
|
+
for k in group_dict[state_id].keys():
|
|
1265
|
+
|
|
1266
|
+
if k == '[]':
|
|
1267
|
+
continue
|
|
1268
|
+
|
|
1269
|
+
# print(f' Processing consistent signal k={k}...')
|
|
1270
|
+
rep_group_dict[state_id][k] = {}
|
|
1271
|
+
for kk in group_dict[state_id][k].keys():
|
|
1272
|
+
# print(f' Processing entanglement ID kk={kk}...')
|
|
1273
|
+
idx_list = np.array(group_dict[state_id][k][kk])
|
|
1274
|
+
# print(f' idx_list: {idx_list}')
|
|
1275
|
+
|
|
1276
|
+
# Max micro-states probability
|
|
1277
|
+
micro_prob = meta_dist[state_id][micro_dtrajs_last[idx_list[:,0], idx_list[:,1]]]
|
|
1278
|
+
max_idx = np.where(micro_prob == np.max(micro_prob))[0]
|
|
1279
|
+
max_idx_list = idx_list[max_idx,:]
|
|
1280
|
+
# print(f' max_idx_list: {max_idx_list}')
|
|
1281
|
+
|
|
1282
|
+
# Max Q
|
|
1283
|
+
Q_list_0 = Q_list[max_idx_list[:,0], max_idx_list[:,1]]
|
|
1284
|
+
max_idx = np.where(Q_list_0 == np.max(Q_list_0))[0]
|
|
1285
|
+
max_idx_list = max_idx_list[max_idx,:]
|
|
1286
|
+
# print(f' max_idx_list after Q: {max_idx_list}')
|
|
1287
|
+
|
|
1288
|
+
# Max G
|
|
1289
|
+
G_list_0 = G_list[max_idx_list[:,0], max_idx_list[:,1]]
|
|
1290
|
+
max_idx = np.where(G_list_0 == np.max(G_list_0))[0]
|
|
1291
|
+
max_idx_list = max_idx_list[max_idx,:]
|
|
1292
|
+
# print(f' max_idx_list after G: {max_idx_list}')
|
|
1293
|
+
|
|
1294
|
+
[rep_traj_idx, rep_frame_idx] = max_idx_list[0,:]
|
|
1295
|
+
# print(f' Representative structure: MSM traj idx {rep_traj_idx} -> {idx2traj[rep_traj_idx]}, frame {rep_frame_idx}')
|
|
1296
|
+
rep_group_dict[state_id][k][kk] = [rep_traj_idx, rep_frame_idx]
|
|
1297
|
+
##############################################################################
|
|
1298
|
+
|
|
1299
|
+
|
|
1300
|
+
##############################################################################
|
|
1301
|
+
# Save data
|
|
1302
|
+
consist_signal_struct_data_outfile = os.path.join(self.outdir, 'consist_signal_struct_data.npz')
|
|
1303
|
+
np.savez(consist_signal_struct_data_outfile,
|
|
1304
|
+
last_num_frames = last_num_frames,
|
|
1305
|
+
total_traj_num_frames = total_traj_num_frames,
|
|
1306
|
+
LIPMS_consist_data=LIPMS_consist_data,
|
|
1307
|
+
XLMS_consist_data=XLMS_consist_data,
|
|
1308
|
+
LIPMS_struct_data=LIPMS_struct_data,
|
|
1309
|
+
XLMS_struct_data=XLMS_struct_data,
|
|
1310
|
+
dtrajs_MS=dtrajs_MS,
|
|
1311
|
+
sorted_consist_signal_dict=sorted_consist_signal_dict,
|
|
1312
|
+
group_dict=group_dict,
|
|
1313
|
+
rep_group_dict=rep_group_dict,)
|
|
1314
|
+
self.logger.debug(f'SAVED: {consist_signal_struct_data_outfile}')
|
|
1315
|
+
|
|
1316
|
+
# Save info to excel
|
|
1317
|
+
self.logger.info(f'Saving info to excel...')
|
|
1318
|
+
df_list = []
|
|
1319
|
+
sheet_name_list = []
|
|
1320
|
+
df_data = []
|
|
1321
|
+
for k in sorted_consist_signal_dict.keys():
|
|
1322
|
+
k_list = eval(k)
|
|
1323
|
+
if k == '[]':
|
|
1324
|
+
continue
|
|
1325
|
+
k_str = ', '.join(k_list)
|
|
1326
|
+
|
|
1327
|
+
for kk in sorted_consist_signal_dict[k].keys():
|
|
1328
|
+
kk_list = eval(sorted_chg_ent_structure_keyword_list[kk])
|
|
1329
|
+
if len(kk_list) == 0:
|
|
1330
|
+
continue
|
|
1331
|
+
kk_str = ', '.join([str(i+1) for i in kk_list])
|
|
1332
|
+
num = len(sorted_consist_signal_dict[k][kk])
|
|
1333
|
+
df_data.append([k_str, kk_str, len(k_list), num])
|
|
1334
|
+
df = pd.DataFrame(df_data, columns=['Consistent signals', 'IDs of Changes in Entanglements', 'Number of consistent signals', 'Number of Structures'])
|
|
1335
|
+
df_sorted = df.sort_values(by=['Number of consistent signals', 'Consistent signals', 'Number of Structures'], ascending=[False, True, False])
|
|
1336
|
+
df_list.append(df_sorted)
|
|
1337
|
+
sheet_name_list.append('Total')
|
|
1338
|
+
|
|
1339
|
+
for state_id in self.state_idx_list:
|
|
1340
|
+
sheet_name_list.append('State %d'%(state_id+1))
|
|
1341
|
+
df_data = []
|
|
1342
|
+
|
|
1343
|
+
for k in rep_group_dict[state_id].keys():
|
|
1344
|
+
k_list = eval(k)
|
|
1345
|
+
k_str = ', '.join(k_list)
|
|
1346
|
+
#print(f'\nProcessing state {state_id}, consistent signal {k_str}')
|
|
1347
|
+
|
|
1348
|
+
for kk in rep_group_dict[state_id][k].keys():
|
|
1349
|
+
|
|
1350
|
+
kk_list = eval(sorted_chg_ent_structure_keyword_list[kk])
|
|
1351
|
+
if len(kk_list) == 0:
|
|
1352
|
+
continue
|
|
1353
|
+
kk_str = ', '.join([str(i+1) for i in kk_list])
|
|
1354
|
+
[traj_idx, frame_idx] = rep_group_dict[state_id][k][kk]
|
|
1355
|
+
#print(f'traj_idx: {traj_idx} | frame_idx: {frame_idx}')
|
|
1356
|
+
|
|
1357
|
+
traj_frame_idx = total_traj_num_frames - last_num_frames + frame_idx
|
|
1358
|
+
#print(f'Processing state {state_id}, consistent signal {k_str}, entanglement ID {kk_str}, traj {traj_idx+1}, frame {traj_frame_idx}')
|
|
1359
|
+
|
|
1360
|
+
num = len(group_dict[state_id][k][kk])
|
|
1361
|
+
micro_prob = meta_dist[state_id][micro_dtrajs_last[traj_idx, frame_idx]]
|
|
1362
|
+
# rel_change_list = calc_rel_change(traj_idx, frame_idx)
|
|
1363
|
+
traj = idx2traj[traj_idx]
|
|
1364
|
+
df_data.append([k_str, kk_str, len(k_list), num, str([traj, traj_frame_idx+1]), micro_prob, Q_list[traj_idx, frame_idx], G_list[traj_idx, frame_idx] ])
|
|
1365
|
+
|
|
1366
|
+
df = pd.DataFrame(df_data, columns=['Consistent signals', 'IDs of Changes in Entanglements', 'Number of consistent signals', 'Number of Structures', 'Representative Structure (Traj #, Frame #)', 'Prob', 'Q', 'G'])
|
|
1367
|
+
df_sorted = df.sort_values(by=['Number of consistent signals', 'Prob', 'Q', 'G'], ascending=[False, False, False, False])
|
|
1368
|
+
df_list.append(df_sorted)
|
|
1369
|
+
|
|
1370
|
+
# Creating Excel Writer Object from Pandas
|
|
1371
|
+
Consistent_structures_v8_outfile = os.path.join(self.outdir, 'Consistent_structures_v8.xlsx')
|
|
1372
|
+
with pd.ExcelWriter(Consistent_structures_v8_outfile, engine="openpyxl") as writer:
|
|
1373
|
+
workbook=writer.book
|
|
1374
|
+
for idx_0, df in enumerate(df_list):
|
|
1375
|
+
worksheet=workbook.create_sheet(sheet_name_list[idx_0])
|
|
1376
|
+
writer.sheets[sheet_name_list[idx_0]] = worksheet
|
|
1377
|
+
df.to_excel(writer, sheet_name=sheet_name_list[idx_0], index=False)
|
|
1378
|
+
self.logger.debug(f'SAVED: {Consistent_structures_v8_outfile}')
|
|
1379
|
+
##############################################################################
|
|
1380
|
+
|
|
1381
|
+
|
|
1382
|
+
##############################################################################
|
|
1383
|
+
# Create visualization
|
|
1384
|
+
self.logger.info(f'Create visualizations...')
|
|
1385
|
+
if self.dist_data_file is not None and os.path.exists(self.dist_data_file):
|
|
1386
|
+
self.logger.info(f'Loading dist_traj_list: {self.dist_data_file}')
|
|
1387
|
+
dist_traj_list = np.load(self.dist_data_file, allow_pickle=True)[:,-last_num_frames:]
|
|
1388
|
+
self.logger.info(f'Loaded distance data: {dist_traj_list.shape}')
|
|
1389
|
+
dist_traj_list = [v for i, v in enumerate(dist_traj_list) if i not in np.asarray(self.rm_traj_list) - 1]
|
|
1390
|
+
dist_traj_list = np.array(dist_traj_list)
|
|
1391
|
+
self.logger.info(f'dist_traj_list after removal of mirror images: {dist_traj_list.shape}')
|
|
1392
|
+
elif self.xp_dir is not None and os.path.isdir(self.xp_dir):
|
|
1393
|
+
self.logger.info(f'Building sparse dist_traj_list from XP files in: {self.xp_dir}')
|
|
1394
|
+
dist_traj_list = np.empty((len(idx2traj), last_num_frames), dtype=object)
|
|
1395
|
+
dist_traj_list[:] = None
|
|
1396
|
+
|
|
1397
|
+
needed_frames_by_traj = {}
|
|
1398
|
+
for state_id in self.state_idx_list:
|
|
1399
|
+
for k in rep_group_dict[state_id].keys():
|
|
1400
|
+
for kk in rep_group_dict[state_id][k].keys():
|
|
1401
|
+
traj_idx, frame_idx = rep_group_dict[state_id][k][kk]
|
|
1402
|
+
needed_frames_by_traj.setdefault(int(traj_idx), set()).add(int(frame_idx))
|
|
1403
|
+
|
|
1404
|
+
for traj_idx, needed_frame_idx in needed_frames_by_traj.items():
|
|
1405
|
+
traj_num = int(idx2traj[traj_idx])
|
|
1406
|
+
fpath = os.path.join(self.xp_dir, f'{self.ID}_Traj{traj_num}.XP')
|
|
1407
|
+
if not os.path.exists(fpath):
|
|
1408
|
+
self.logger.warning(f'Missing XP file for sparse visualization load: {fpath}')
|
|
1409
|
+
continue
|
|
1410
|
+
|
|
1411
|
+
df = pd.read_csv(
|
|
1412
|
+
fpath,
|
|
1413
|
+
sep='\t',
|
|
1414
|
+
usecols=['Frame', 'Atom1', 'Atom2', 'Euclidean Distance', 'SASD'],
|
|
1415
|
+
dtype={
|
|
1416
|
+
'Frame': np.int32,
|
|
1417
|
+
'SASD': np.float32,
|
|
1418
|
+
'Euclidean Distance': np.float32,
|
|
1419
|
+
'Atom1': 'string',
|
|
1420
|
+
'Atom2': 'string',
|
|
1421
|
+
},
|
|
1422
|
+
)
|
|
1423
|
+
if df.empty:
|
|
1424
|
+
continue
|
|
1425
|
+
|
|
1426
|
+
frame_values = np.sort(df['Frame'].unique())
|
|
1427
|
+
frame_to_idx = {int(f): idx for idx, f in enumerate(frame_values)}
|
|
1428
|
+
df['frame_idx'] = df['Frame'].map(frame_to_idx)
|
|
1429
|
+
df = df[df['frame_idx'].isin(needed_frame_idx)]
|
|
1430
|
+
if df.empty:
|
|
1431
|
+
continue
|
|
1432
|
+
|
|
1433
|
+
atom1_parts = df['Atom1'].str.split('-', expand=True)
|
|
1434
|
+
atom2_parts = df['Atom2'].str.split('-', expand=True)
|
|
1435
|
+
if atom1_parts.shape[1] < 3 or atom2_parts.shape[1] < 3:
|
|
1436
|
+
self.logger.warning(f'Malformed Atom fields in {fpath}; skipping sparse load for traj {traj_num}')
|
|
1437
|
+
continue
|
|
1438
|
+
|
|
1439
|
+
df['key'] = (
|
|
1440
|
+
atom1_parts[1].astype(str) + '|' + atom1_parts[2].astype(str) + '-' +
|
|
1441
|
+
atom2_parts[1].astype(str) + '|' + atom2_parts[2].astype(str)
|
|
1442
|
+
)
|
|
1443
|
+
|
|
1444
|
+
for frame_idx, frame_df in df.groupby('frame_idx', sort=False):
|
|
1445
|
+
frame_dict = {}
|
|
1446
|
+
for _, row in frame_df.iterrows():
|
|
1447
|
+
frame_dict[row['key']] = {
|
|
1448
|
+
'Euclidean': float(row['Euclidean Distance']),
|
|
1449
|
+
'Jwalk': float(row['SASD']),
|
|
1450
|
+
}
|
|
1451
|
+
dist_traj_list[traj_idx, int(frame_idx)] = frame_dict
|
|
1452
|
+
else:
|
|
1453
|
+
raise ValueError(
|
|
1454
|
+
'Representative structure visualization requires either dist_data_file (Jwalk.npy) or xp_dir.'
|
|
1455
|
+
)
|
|
1456
|
+
|
|
1457
|
+
|
|
1458
|
+
# Check if the viz_rep_struct path exists. if so remove it and make a fresh one
|
|
1459
|
+
if os.path.exists('viz_rep_struct'):
|
|
1460
|
+
self.logger.info(f'viz_rep_struct exists and will be removed')
|
|
1461
|
+
os.system('rm -rf viz_rep_struct/')
|
|
1462
|
+
os.system('mkdir viz_rep_struct/')
|
|
1463
|
+
os.chdir('viz_rep_struct/')
|
|
1464
|
+
|
|
1465
|
+
if os.path.isdir(self.AAdcd_dir):
|
|
1466
|
+
AAtraj_files = glob.glob(os.path.join(self.AAdcd_dir, '*.dcd'))
|
|
1467
|
+
else:
|
|
1468
|
+
AAtraj_files = glob.glob(self.AAdcd_dir)
|
|
1469
|
+
if len(AAtraj_files) == 0:
|
|
1470
|
+
raise ValueError(
|
|
1471
|
+
f'No AA trajectory files found. AAdcd_dir={self.AAdcd_dir} '
|
|
1472
|
+
f'(resolved count={len(AAtraj_files)})'
|
|
1473
|
+
)
|
|
1474
|
+
self.logger.info(f'AAtraj_files:\n{AAtraj_files[:10]}')
|
|
1475
|
+
|
|
1476
|
+
wd = os.getcwd()
|
|
1477
|
+
for state_id in self.state_idx_list:
|
|
1478
|
+
state_dir = os.path.join(wd, 'State_%d'%(state_id+1))
|
|
1479
|
+
os.makedirs(state_dir, exist_ok=True)
|
|
1480
|
+
self.logger.info(f'Made {state_dir}')
|
|
1481
|
+
# os.chdir(state_dir)
|
|
1482
|
+
self.logger.info(f'Length of rep_group_dict[state_id]: {len(rep_group_dict[state_id])}')
|
|
1483
|
+
args_list = [
|
|
1484
|
+
(state_dir, state_id, k, rep_group_dict, sorted_chg_ent_structure_keyword_list, last_num_frames, total_traj_num_frames,
|
|
1485
|
+
idx2traj, AAtraj_files, self.native_AA_pdb, rep_chg_ent_dtrajs, Q_list, G_list,
|
|
1486
|
+
LIPMS_consist_data, M_LiPMS, XLMS_consist_data, M_XLMS, dist_traj_list, if_backmap, pulchra_only, self.logger.name)
|
|
1487
|
+
for k in rep_group_dict[state_id].keys()
|
|
1488
|
+
]
|
|
1489
|
+
self.logger.info(f'Processing {len(args_list)} consistent signal groups for state {state_id}...')
|
|
1490
|
+
if len(args_list) == 0:
|
|
1491
|
+
self.logger.info(f'No consistent signal groups for state {state_id}, skipping...')
|
|
1492
|
+
continue
|
|
1493
|
+
|
|
1494
|
+
# process_k(args_list[0])
|
|
1495
|
+
# print('Testing done for one process_k, exiting...')
|
|
1496
|
+
# quit()
|
|
1497
|
+
# nproc = 10
|
|
1498
|
+
with multiprocessing.Pool(processes=self.nproc) as pool:
|
|
1499
|
+
pool.map(process_k, args_list)
|
|
1500
|
+
self.logger.debug('Completion of selecting rep structure')
|
|
1501
|
+
|
|
1502
|
+
|
|
1503
|
+
##################################################################################################
|
|
1504
|
+
import mdtraj as mdt # at the top of your file
|
|
1505
|
+
def process_k(args):
    """Build visualizations and ``info.txt`` files for one consistent-signal group.

    Designed as a ``multiprocessing.Pool.map`` worker, hence the single packed
    argument.

    Parameters
    ----------
    args : tuple
        Packed positional arguments, in this order: ``state_dir``,
        ``state_id``, ``k``, ``rep_group_dict``,
        ``sorted_chg_ent_structure_keyword_list``, ``last_num_frames``,
        ``total_traj_num_frames``, ``idx2traj``, ``AAtraj_files``,
        ``native_AA_pdb``, ``rep_chg_ent_dtrajs``, ``Q_list``, ``G_list``,
        ``LIPMS_consist_data``, ``M_LiPMS``, ``XLMS_consist_data``,
        ``M_XLMS``, ``dist_traj_list``, ``if_backmap``, ``pulchra_only``,
        ``logger_name``.

        ``k`` is the string form of a list of signal names; it keys
        ``rep_group_dict[state_id]``. Each ``rep_group_dict[state_id][k][kk]``
        is a ``[traj_idx, frame_idx]`` pair where ``frame_idx`` is relative to
        the last ``last_num_frames`` frames of the trajectory.

    Raises
    ------
    ValueError
        If ``match_pattern`` does not resolve exactly one all-atom trajectory
        file for the representative structure's trajectory.

    Side effects
    ------------
    Creates ``<state_dir>/<signals>/<cluster ids>/`` directories, calls
    ``gen_state_visualizion`` for each representative structure, writes an
    ``info.txt`` into each cluster directory, and adds a ``'chg_index'`` entry
    to every entanglement-change dict of the processed frames.
    """
    (state_dir, state_id, k, rep_group_dict, sorted_chg_ent_structure_keyword_list, last_num_frames, total_traj_num_frames,
     idx2traj, AAtraj_files, native_AA_pdb, rep_chg_ent_dtrajs, Q_list, G_list,
     LIPMS_consist_data, M_LiPMS, XLMS_consist_data, M_XLMS, dist_traj_list, if_backmap, pulchra_only, logger_name) = args

    logger = logging.getLogger(logger_name)

    # NOTE(review): eval() is kept because the keys are produced by str() on
    # internally built lists; if they are confirmed to be plain literals,
    # ast.literal_eval would be the safer choice.
    k_list = eval(k)
    k_str = '_'.join(k_list)
    k_str_dir = os.path.join(state_dir, k_str)
    os.makedirs(k_str_dir, exist_ok=True)
    logger.info(f'Made {k_str_dir}')
    key_order = ['type', 'code', 'native_contact', 'native_contact_residx', 'linking_value', 'crossing_resid', 'crossing_residx', 'crossing_pattern', 'gauss_linking_number', 'topoly_linking_number',
                 'ref_native_contact', 'ref_native_contact_residx', 'ref_linking_value', 'ref_crossing_resid', 'ref_crossing_residx', 'ref_crossing_pattern', 'ref_gauss_linking_number', 'ref_topoly_linking_number']
    separator = '%s\n' % ('-' * 64)

    for kk, rep_idx in rep_group_dict[state_id][k].items():
        # NOTE(review): same eval() caveat as for `k` above.
        kk_list = eval(sorted_chg_ent_structure_keyword_list[kk])
        if len(kk_list) == 0:
            # Nothing to visualize for a cluster without entanglement changes.
            continue

        # Directory names use 1-based cluster IDs.
        kk_str = '_'.join([str(i+1) for i in kk_list])
        kk_str_dir = os.path.join(k_str_dir, kk_str)
        os.makedirs(kk_str_dir, exist_ok=True)
        logger.info(f'Made {kk_str_dir}')

        [traj_idx, frame_idx] = rep_idx
        traj = idx2traj[traj_idx]
        # Convert the index within the analysed tail window into an index
        # within the full trajectory.
        traj_frame_idx = total_traj_num_frames - last_num_frames + frame_idx

        AAtraj_file = match_pattern(AAtraj_files, f'{traj}')
        if len(AAtraj_file) != 1:
            raise ValueError(f'Found {len(AAtraj_file)} AA traj files for traj {traj}, expected 1.')

        # Read only the required frame instead of loading the whole trajectory
        # into memory in every worker. Coordinates are scaled by 10
        # (mdtraj reports nm; presumably downstream expects Angstrom).
        rep_frame = mdt.load_frame(AAtraj_file[0], int(traj_frame_idx), top=native_AA_pdb)
        state_cor = rep_frame.center_coordinates().xyz * 10
        rep_chg_ent_dict = rep_chg_ent_dtrajs[traj_idx][frame_idx]

        # Group the entanglement changes of this frame by their 'code'.
        rep_ent_dict = {}
        for kkk, v in rep_chg_ent_dict.items():
            v['chg_index'] = kkk
            rep_ent_dict.setdefault(tuple(v['code']), []).append(v)
        gen_state_visualizion(state_id, kk_str, kk_str_dir, native_AA_pdb, state_cor, native_AA_pdb, rep_ent_dict,
                              logger,
                              if_backmap=if_backmap, pulchra_only=pulchra_only, exp_signal_str=k_str)

        Q = Q_list[traj_idx, frame_idx]
        G = G_list[traj_idx, frame_idx]
        info_file = os.path.join(kk_str_dir, 'info.txt')
        with open(info_file, 'w') as f:
            f.write('State #%d\n' % (state_id + 1))
            f.write('Q: %f\n' % (Q))
            f.write('G: %f\n' % (G))
            f.write('Consistent experiment signals: %s\n' % k)
            f.write('Changes in entanglement cluster IDs: %s\n' % [i+1 for i in kk_list])
            f.write('Trajectory number: %d\n' % (traj))
            f.write('Frame number: %d\n' % (traj_frame_idx + 1))
            f.write(separator)
            for signal in k_list:
                if signal in LIPMS_consist_data:
                    M = M_LiPMS[traj_idx, frame_idx, LIPMS_consist_data[signal][4]]
                    f.write('%s: M: %.4f\n' % (signal, M))
                elif signal in XLMS_consist_data:
                    M = M_XLMS[traj_idx, frame_idx, XLMS_consist_data[signal][4]]
                    # Drop the leading character of each half of the signal
                    # name and look the pair up on chain A.
                    signal_parts = signal.split('-')
                    requested_key = '%s|A-%s|A' % (signal_parts[0][1:], signal_parts[1][1:])
                    frame_data = dist_traj_list[traj_idx, frame_idx]
                    if frame_data is None or requested_key not in frame_data:
                        logger.warning(
                            f'Missing XL-MS key for signal {signal}. '
                            f'Traj {idx2traj[traj_idx]}, frame {frame_idx}, '
                            f'MSM indices [{traj_idx}, {frame_idx}]. '
                            f'Requested key: {requested_key}. '
                            f'Available keys in frame_dict: {list(frame_data.keys()) if frame_data is not None else "None"}. '
                            f'Skipping this signal in info file.'
                        )
                        f.write('%s: M: %.4f Jwalk: N/A Euclidean: N/A (key not in XP data)\n' % (signal, M))
                    else:
                        d_data = frame_data[requested_key]
                        f.write('%s: M: %.4f Jwalk: %.4f Euclidean: %.4f\n' % (signal, M, d_data['Jwalk'], d_data['Euclidean']))
            f.write(separator)
            for kkk, v in rep_chg_ent_dict.items():
                f.write('%d:\n' % (kkk + 1))
                for kkkk in key_order:
                    f.write(' ' * 4 + '%s: %s\n' % (kkkk, v[kkkk]))
        logger.debug(f'SAVED: {info_file}')
|
|
1595
|
+
##############################################################################
|
|
1596
|
+
|
|
1597
|
+
##############################################################################
|
|
1598
|
+
def match_pattern(strings, user_substring):
    """Return the entries of *strings* containing ``<user_substring>_`` as a whole number token.

    A match requires that the substring is immediately followed by an
    underscore and is NOT immediately preceded by a digit, so that searching
    for ``12`` does not pick up ``112_``. Unlike a leading ``\\D``, the
    negative lookbehind also accepts a string that *starts* with the
    substring (e.g. a bare file name such as ``12_aa.dcd``).

    Parameters
    ----------
    strings : iterable of str
        Candidate strings (file paths in the visible caller).
    user_substring : str or int
        Token to look for; it is converted with ``str()`` and regex-escaped.

    Returns
    -------
    list of str
        The matching entries, in their original order.
    """
    # (?<!\d): previous character, if any, must not be a digit.
    pattern = re.compile(rf"(?<!\d){re.escape(str(user_substring))}_")

    # Keep only the strings in which the pattern occurs anywhere.
    return [s for s in strings if pattern.search(s)]
|
|
1605
|
+
##############################################################################
|
|
1606
|
+
|
|
1607
|
+
##############################################################################
|
|
1608
|
+
def gen_state_visualizion(state_id, ent_id, kk_str_dir, psf, state_cor, native_AA_pdb, rep_ent_dict, logger, if_backmap=True, pulchra_only=False, exp_signal_str=''):
|
|
1609
|
+
def idx2sel(idx_list):
|
|
1610
|
+
if len(idx_list) == 0:
|
|
1611
|
+
return ''
|
|
1612
|
+
else:
|
|
1613
|
+
sel = 'index'
|
|
1614
|
+
idx_0 = idx_list[0]
|
|
1615
|
+
idx_1 = idx_list[0]
|
|
1616
|
+
sel_0 = ' %d'%idx_0
|
|
1617
|
+
for i in range(1, len(idx_list)):
|
|
1618
|
+
if idx_list[i] == idx_list[i-1] + 1:
|
|
1619
|
+
idx_1 = idx_list[i]
|
|
1620
|
+
else:
|
|
1621
|
+
if idx_1 > idx_0:
|
|
1622
|
+
sel_0 += ' to %d'%idx_1
|
|
1623
|
+
sel += sel_0
|
|
1624
|
+
idx_0 = idx_list[i]
|
|
1625
|
+
idx_1 = idx_list[i]
|
|
1626
|
+
sel_0 = ' %d'%idx_0
|
|
1627
|
+
if idx_1 > idx_0:
|
|
1628
|
+
sel_0 += ' to %d'%idx_1
|
|
1629
|
+
sel += sel_0
|
|
1630
|
+
return sel
|
|
1631
|
+
|
|
1632
|
+
AA_name_list = ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
|
|
1633
|
+
'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL',
|
|
1634
|
+
'HIE', 'HID', 'HIP']
|
|
1635
|
+
protein_colorid_list = [12, 15]
|
|
1636
|
+
loop_colorid_list = [4, 1]
|
|
1637
|
+
thread_colorid_list = [7, 0]
|
|
1638
|
+
nc_colorid_list = [3, 3]
|
|
1639
|
+
crossing_colorid_list = [8, 8]
|
|
1640
|
+
LIP_colorid_list = [13, 10]
|
|
1641
|
+
XL_colorid_list = [13, 10]
|
|
1642
|
+
thread_cutoff=3
|
|
1643
|
+
terminal_cutoff=3
|
|
1644
|
+
|
|
1645
|
+
logger.info('Generate visualization of state %d'%(state_id + 1))
|
|
1646
|
+
logger.debug(f'ent_id: {ent_id}')
|
|
1647
|
+
logger.debug(f'kk_str_dir: {kk_str_dir}')
|
|
1648
|
+
logger.debug(f'psf: {psf}')
|
|
1649
|
+
logger.debug(f'state_cor: {state_cor}')
|
|
1650
|
+
logger.debug(f'native_AA_pdb: {native_AA_pdb}')
|
|
1651
|
+
logger.debug(f'if_backmap: {if_backmap}')
|
|
1652
|
+
logger.debug(f'pulchra_only: {pulchra_only}')
|
|
1653
|
+
logger.debug(f'exp_signal_str: {exp_signal_str}')
|
|
1654
|
+
logger.info(f'rep_ent_dict:')
|
|
1655
|
+
for k, v in rep_ent_dict.items():
|
|
1656
|
+
logger.info(f'k:\n{k}')
|
|
1657
|
+
logger.info(f'v:\n{v} {len(v)}')
|
|
1658
|
+
|
|
1659
|
+
struct = pmd.load_file(psf)
|
|
1660
|
+
struct.coordinates = state_cor
|
|
1661
|
+
|
|
1662
|
+
# backmap
|
|
1663
|
+
pdb_path = os.path.join(kk_str_dir, f'state_{state_id + 1}.pdb')
|
|
1664
|
+
if if_backmap:
|
|
1665
|
+
if pulchra_only:
|
|
1666
|
+
pulchra_only = '1'
|
|
1667
|
+
else:
|
|
1668
|
+
pulchra_only = '0'
|
|
1669
|
+
temp_pdb_path = os.path.join(kk_str_dir, 'tmp.pdb')
|
|
1670
|
+
struct.save(temp_pdb_path, overwrite=True)
|
|
1671
|
+
os.system('backmap.py -i '+native_AA_pdb+' -c '+temp_pdb_path+' -p '+pulchra_only)
|
|
1672
|
+
os.system('mv tmp_rebuilt.pdb '+pdb_path)
|
|
1673
|
+
os.system('rm -f '+temp_pdb_path)
|
|
1674
|
+
os.system('rm -rf ./rebuild_tmp/')
|
|
1675
|
+
else:
|
|
1676
|
+
struct.save(pdb_path, overwrite=True)
|
|
1677
|
+
|
|
1678
|
+
ref_struct = pmd.load_file(native_AA_pdb)
|
|
1679
|
+
current_struct = pmd.load_file(pdb_path)
|
|
1680
|
+
|
|
1681
|
+
# parse exp_signal_str
|
|
1682
|
+
if exp_signal_str == '':
|
|
1683
|
+
exp_signal_list = []
|
|
1684
|
+
else:
|
|
1685
|
+
exp_signal_list = exp_signal_str.strip().split('_')
|
|
1686
|
+
exp_signal_list = [es.strip().split('-') for es in exp_signal_list]
|
|
1687
|
+
exp_signal_list = [[int(ees[1:]) for ees in es] for es in exp_signal_list]
|
|
1688
|
+
|
|
1689
|
+
##############################################
|
|
1690
|
+
## no change of entanglement
|
|
1691
|
+
if len(list(rep_ent_dict.keys())) == 0:
|
|
1692
|
+
vmd_outfile = os.path.join(kk_str_dir, f'vmd_s{state_id}_none.tcl')
|
|
1693
|
+
f = open(vmd_outfile, 'w')
|
|
1694
|
+
f.write('# Entanglement type: no change\n')
|
|
1695
|
+
f.write('''display rendermode GLSL
|
|
1696
|
+
display projection Orthographic
|
|
1697
|
+
axes location off
|
|
1698
|
+
|
|
1699
|
+
color Display {Background} white
|
|
1700
|
+
|
|
1701
|
+
mol new ./'''+(pdb_path)+''' type pdb first 0 last -1 step 1 filebonds 1 autobonds 1 waitfor all
|
|
1702
|
+
mol delrep 0 top
|
|
1703
|
+
mol representation NewCartoon 0.300000 10.000000 4.100000 0
|
|
1704
|
+
mol color ColorID '''+str(protein_colorid_list[1])+'''
|
|
1705
|
+
mol selection {all}
|
|
1706
|
+
mol material AOChalky
|
|
1707
|
+
mol addrep top
|
|
1708
|
+
''')
|
|
1709
|
+
f.close()
|
|
1710
|
+
logger.debug(f'SAVED: {vmd_outfile}')
|
|
1711
|
+
|
|
1712
|
+
##############################################
|
|
1713
|
+
## Create vmd script for each type of change
|
|
1714
|
+
for ent_code, rep_ent_list in rep_ent_dict.items():
|
|
1715
|
+
# print(f'ent_code:\n{ent_code}')
|
|
1716
|
+
# print(f'rep_ent_list:\n{rep_ent_list}')
|
|
1717
|
+
pmd_struct_list = [ref_struct, current_struct]
|
|
1718
|
+
struct_dir_list = [native_AA_pdb, pdb_path]
|
|
1719
|
+
key_prefix_list = ['ref_', '']
|
|
1720
|
+
repres_list = ['', '']
|
|
1721
|
+
align_sel_list = ['', '']
|
|
1722
|
+
|
|
1723
|
+
|
|
1724
|
+
for chg_ent_fingerprint_idx, chg_ent_fingerprint in enumerate(rep_ent_list):
|
|
1725
|
+
# print(f'\nProcessing change of entanglement fingerprint {chg_ent_fingerprint_idx}:\n{chg_ent_fingerprint}')
|
|
1726
|
+
vmd_script = '''# Entanglement type: '''+str(chg_ent_fingerprint['type'])+'''
|
|
1727
|
+
package require topotools
|
|
1728
|
+
display rendermode GLSL
|
|
1729
|
+
display projection Orthographic
|
|
1730
|
+
axes location off
|
|
1731
|
+
|
|
1732
|
+
color Display {Background} white
|
|
1733
|
+
color Labels {Bonds} black
|
|
1734
|
+
|
|
1735
|
+
label textsize 0.000001
|
|
1736
|
+
|
|
1737
|
+
'''
|
|
1738
|
+
for struct_idx, pmd_struct in enumerate(pmd_struct_list):
|
|
1739
|
+
struct_dir = struct_dir_list[struct_idx]
|
|
1740
|
+
protein_colorid = protein_colorid_list[struct_idx]
|
|
1741
|
+
loop_colorid = loop_colorid_list[struct_idx]
|
|
1742
|
+
thread_colorid = thread_colorid_list[struct_idx]
|
|
1743
|
+
nc_colorid = nc_colorid_list[struct_idx]
|
|
1744
|
+
crossing_colorid = crossing_colorid_list[struct_idx]
|
|
1745
|
+
LIP_colorid = LIP_colorid_list[struct_idx]
|
|
1746
|
+
XL_colorid = XL_colorid_list[struct_idx]
|
|
1747
|
+
key_prefix = key_prefix_list[struct_idx]
|
|
1748
|
+
# print(f'Processing structure {struct_idx}: {struct_dir}')
|
|
1749
|
+
# print(f' Protein color ID: {protein_colorid}')
|
|
1750
|
+
# print(f' Loop color ID: {loop_colorid}')
|
|
1751
|
+
# print(f' Thread color ID: {thread_colorid}')
|
|
1752
|
+
# print(f' NC color ID: {nc_colorid}')
|
|
1753
|
+
# print(f' Crossing color ID: {crossing_colorid}')
|
|
1754
|
+
# print(f' LIP color ID: {LIP_colorid}')
|
|
1755
|
+
# print(f' XL color ID: {XL_colorid}')
|
|
1756
|
+
|
|
1757
|
+
# Clean ligands
|
|
1758
|
+
clean_sel_idx = np.zeros(len(pmd_struct.atoms))
|
|
1759
|
+
for res in pmd_struct.residues:
|
|
1760
|
+
if res.name in AA_name_list:
|
|
1761
|
+
for atm in res.atoms:
|
|
1762
|
+
clean_sel_idx[atm.idx] = 1
|
|
1763
|
+
pmd_clean_struct = pmd_struct[clean_sel_idx]
|
|
1764
|
+
clean_idx_to_idx = np.where(clean_sel_idx == 1)[0]
|
|
1765
|
+
|
|
1766
|
+
# vmd selection string for protein
|
|
1767
|
+
idx_list = []
|
|
1768
|
+
for res in pmd_struct.residues:
|
|
1769
|
+
if res.name in AA_name_list:
|
|
1770
|
+
idx_list += [atm.idx for atm in res.atoms]
|
|
1771
|
+
vmd_sel = idx2sel(idx_list)
|
|
1772
|
+
|
|
1773
|
+
repres = '''mol new '''+struct_dir+''' type pdb first 0 last -1 step 1 filebonds 1 autobonds 1 waitfor all
|
|
1774
|
+
mol delrep 0 top
|
|
1775
|
+
mol representation NewCartoon 0.300000 10.000000 4.100000 0
|
|
1776
|
+
mol color ColorID '''+str(protein_colorid)+'''
|
|
1777
|
+
mol selection {'''+vmd_sel+'''}
|
|
1778
|
+
mol material GlassBubble
|
|
1779
|
+
mol addrep top
|
|
1780
|
+
'''
|
|
1781
|
+
align_sel = vmd_sel
|
|
1782
|
+
|
|
1783
|
+
nc = chg_ent_fingerprint[key_prefix+'native_contact_residx']
|
|
1784
|
+
chg_idx = chg_ent_fingerprint['chg_index']
|
|
1785
|
+
|
|
1786
|
+
idx_list = []
|
|
1787
|
+
for res in pmd_clean_struct.residues:
|
|
1788
|
+
if res.idx in nc:
|
|
1789
|
+
idx_list += [atm.idx for atm in res.atoms if atm.name == 'CA']
|
|
1790
|
+
nc_sel = idx2sel(clean_idx_to_idx[idx_list])
|
|
1791
|
+
# print(f' Native contact residx: {nc}, selection: {nc_sel}')
|
|
1792
|
+
|
|
1793
|
+
idx_list = []
|
|
1794
|
+
for res in pmd_clean_struct.residues:
|
|
1795
|
+
if res.idx >= nc[0] and res.idx <= nc[1]:
|
|
1796
|
+
idx_list += [atm.idx for atm in res.atoms]
|
|
1797
|
+
loop_sel = idx2sel(clean_idx_to_idx[idx_list])
|
|
1798
|
+
# print(f' Loop residx: {list(range(nc[0], nc[1]+1))}, selection: {loop_sel}')
|
|
1799
|
+
|
|
1800
|
+
|
|
1801
|
+
align_sel += ' and not (%s)'%loop_sel
|
|
1802
|
+
ref_coss_resid = chg_ent_fingerprint['ref_crossing_residx']
|
|
1803
|
+
cross_resid = chg_ent_fingerprint['crossing_residx']
|
|
1804
|
+
thread = []
|
|
1805
|
+
thread_sel_list = []
|
|
1806
|
+
for ter_idx in range(len(ref_coss_resid)):
|
|
1807
|
+
thread_0 = []
|
|
1808
|
+
resid_list = ref_coss_resid[ter_idx] + cross_resid[ter_idx]
|
|
1809
|
+
if len(resid_list) > 0:
|
|
1810
|
+
thread_0 = [np.min(resid_list)-5, np.max(resid_list)+5]
|
|
1811
|
+
if ter_idx == 0:
|
|
1812
|
+
thread_0[0] = np.max([thread_0[0], terminal_cutoff])
|
|
1813
|
+
thread_0[1] = np.min([thread_0[1], nc[0]-thread_cutoff])
|
|
1814
|
+
else:
|
|
1815
|
+
thread_0[0] = np.max([thread_0[0], nc[1]+thread_cutoff])
|
|
1816
|
+
thread_0[1] = np.min([thread_0[1], len(struct.atoms)-1-terminal_cutoff])
|
|
1817
|
+
idx_list = []
|
|
1818
|
+
for res in pmd_clean_struct.residues:
|
|
1819
|
+
if res.idx >= thread_0[0] and res.idx <= thread_0[1]:
|
|
1820
|
+
idx_list += [atm.idx for atm in res.atoms]
|
|
1821
|
+
thread_0_sel = idx2sel(clean_idx_to_idx[idx_list])
|
|
1822
|
+
thread_sel_list.append(thread_0_sel)
|
|
1823
|
+
align_sel += ' and not (%s)'%thread_0_sel
|
|
1824
|
+
else:
|
|
1825
|
+
thread_sel_list.append('')
|
|
1826
|
+
thread.append(thread_0)
|
|
1827
|
+
|
|
1828
|
+
ln = chg_ent_fingerprint[key_prefix+'topoly_linking_number']
|
|
1829
|
+
cross = []
|
|
1830
|
+
for i in range(len(chg_ent_fingerprint[key_prefix+'crossing_residx'])):
|
|
1831
|
+
cross.append([])
|
|
1832
|
+
for j in range(len(chg_ent_fingerprint[key_prefix+'crossing_residx'][i])):
|
|
1833
|
+
cross[-1].append(chg_ent_fingerprint[key_prefix+'crossing_pattern'][i][j]+str(chg_ent_fingerprint[key_prefix+'crossing_residx'][i][j]))
|
|
1834
|
+
repres += '# idx: native_contact_residx %s, topoly_linking_number %s, crossing_residx %s.\n'%(str(nc), str(ln), str(cross))
|
|
1835
|
+
repres +=''' mol representation NewCartoon 0.350000 10.000000 4.100000 0
|
|
1836
|
+
mol color ColorID '''+str(loop_colorid)+'''
|
|
1837
|
+
mol selection {'''+loop_sel+'''}
|
|
1838
|
+
mol material Opaque
|
|
1839
|
+
mol addrep top
|
|
1840
|
+
mol representation VDW 1.000000 12.000000
|
|
1841
|
+
mol color ColorID '''+str(nc_colorid)+'''
|
|
1842
|
+
mol selection {'''+nc_sel+'''}
|
|
1843
|
+
mol material Opaque
|
|
1844
|
+
mol addrep top
|
|
1845
|
+
set sel [atomselect top "'''+nc_sel+'''"]
|
|
1846
|
+
set idx [$sel get index]
|
|
1847
|
+
topo addbond [lindex $idx 0] [lindex $idx 1]
|
|
1848
|
+
mol representation Bonds 0.300000 12.000000
|
|
1849
|
+
mol color ColorID '''+str(nc_colorid)+'''
|
|
1850
|
+
mol selection {'''+nc_sel+'''}
|
|
1851
|
+
mol material Opaque
|
|
1852
|
+
mol addrep top
|
|
1853
|
+
'''
|
|
1854
|
+
for ter_idx, thread_resid in enumerate(thread):
|
|
1855
|
+
if len(thread_resid) == 0:
|
|
1856
|
+
continue
|
|
1857
|
+
repres += '''mol representation NewCartoon 0.350000 10.000000 4.100000 0
|
|
1858
|
+
mol color ColorID '''+str(thread_colorid)+'''
|
|
1859
|
+
mol selection {'''+thread_sel_list[ter_idx]+'''}
|
|
1860
|
+
mol material Opaque
|
|
1861
|
+
mol addrep top
|
|
1862
|
+
'''
|
|
1863
|
+
if len(chg_ent_fingerprint[key_prefix+'crossing_residx'][ter_idx]) > 0:
|
|
1864
|
+
idx_list = []
|
|
1865
|
+
for res in pmd_clean_struct.residues:
|
|
1866
|
+
if res.idx in chg_ent_fingerprint[key_prefix+'crossing_residx'][ter_idx]:
|
|
1867
|
+
idx_list += [atm.idx for atm in res.atoms if atm.name == 'CA']
|
|
1868
|
+
crossing_sel = idx2sel(clean_idx_to_idx[idx_list])
|
|
1869
|
+
repres += '''mol representation VDW 1.000000 12.000000
|
|
1870
|
+
mol color ColorID '''+str(crossing_colorid)+'''
|
|
1871
|
+
mol selection {'''+crossing_sel+'''}
|
|
1872
|
+
mol material Opaque
|
|
1873
|
+
mol addrep top
|
|
1874
|
+
'''
|
|
1875
|
+
#######################################
|
|
1876
|
+
## showing experimental signal residues
|
|
1877
|
+
for es in exp_signal_list:
|
|
1878
|
+
idx_list = []
|
|
1879
|
+
for res in pmd_clean_struct.residues:
|
|
1880
|
+
if res.idx+1 in es:
|
|
1881
|
+
idx_list += [atm.idx for atm in res.atoms if atm.name == 'CA']
|
|
1882
|
+
es_sel = idx2sel(clean_idx_to_idx[idx_list])
|
|
1883
|
+
if len(es) == 1: # LiP-MS signal
|
|
1884
|
+
repres += '''# LiP-MS PK site '''+str(es[0])+'''
|
|
1885
|
+
mol representation VDW 1.000000 12.000000
|
|
1886
|
+
mol color ColorID '''+str(LIP_colorid)+'''
|
|
1887
|
+
mol selection {'''+es_sel+'''}
|
|
1888
|
+
mol material Opaque
|
|
1889
|
+
mol addrep top
|
|
1890
|
+
'''
|
|
1891
|
+
elif len(es) == 2: # XL-MS signal
|
|
1892
|
+
repres += '''# XL-MS pair ('''+str(es[0])+''', '''+str(es[1])+''')
|
|
1893
|
+
mol representation VDW 1.000000 12.000000
|
|
1894
|
+
mol color ColorID '''+str(XL_colorid)+'''
|
|
1895
|
+
mol selection {'''+es_sel+'''}
|
|
1896
|
+
mol material Opaque
|
|
1897
|
+
mol addrep top
|
|
1898
|
+
label add Bonds '''+('%d/%d %d/%d'%(struct_idx, idx_list[0], struct_idx, idx_list[1]))+'''
|
|
1899
|
+
'''
|
|
1900
|
+
|
|
1901
|
+
if struct_idx == 0:
|
|
1902
|
+
repres += '''mol representation VDW 1.000000 12.000000
|
|
1903
|
+
mol color Name
|
|
1904
|
+
mol selection {not ('''+vmd_sel+''') and not water}
|
|
1905
|
+
mol material Opaque
|
|
1906
|
+
mol addrep top
|
|
1907
|
+
'''
|
|
1908
|
+
repres_list[struct_idx] = repres
|
|
1909
|
+
# print(f'representation:\n{repres}')
|
|
1910
|
+
align_sel_list[struct_idx] = align_sel
|
|
1911
|
+
# print(f'align selection:\n{align_sel}')
|
|
1912
|
+
|
|
1913
|
+
vmd_script += '\n'.join(repres_list)
|
|
1914
|
+
vmd_script += '''
|
|
1915
|
+
set sel1 [atomselect 0 "'''+align_sel_list[0]+''' and name CA"]
|
|
1916
|
+
set sel2 [atomselect 1 "'''+align_sel_list[1]+''' and name CA"]
|
|
1917
|
+
set trans_mat [measure fit $sel1 $sel2]
|
|
1918
|
+
set move_sel [atomselect 0 "all"]
|
|
1919
|
+
$move_sel move $trans_mat
|
|
1920
|
+
'''
|
|
1921
|
+
vmd_outfile = os.path.join(kk_str_dir, f'vmd_s{state_id + 1}_e{chg_idx + 1}_n{ent_code[0]}_c{ent_code[1]}.tcl')
|
|
1922
|
+
f = open(vmd_outfile, 'w')
|
|
1923
|
+
f.write(vmd_script)
|
|
1924
|
+
f.close()
|
|
1925
|
+
logger.debug(f'SAVED: {vmd_outfile}')
|
|
1926
|
+
|
|
1927
|
+
##############################################################################
|