EntDetect 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. EntDetect/Jwalk/GridTools.py +567 -0
  2. EntDetect/Jwalk/PDBTools.py +532 -0
  3. EntDetect/Jwalk/SASDTools.py +543 -0
  4. EntDetect/Jwalk/SurfaceTools.py +150 -0
  5. EntDetect/Jwalk/__init__.py +19 -0
  6. EntDetect/Jwalk/naccess.config.txt +255 -0
  7. EntDetect/__init__.py +10 -0
  8. EntDetect/_logging.py +71 -0
  9. EntDetect/change_resolution.py +2361 -0
  10. EntDetect/clustering.py +2626 -0
  11. EntDetect/compare_sim2exp.py +1927 -0
  12. EntDetect/entanglement_features.py +478 -0
  13. EntDetect/gaussian_entanglement.py +2067 -0
  14. EntDetect/order_params.py +1048 -0
  15. EntDetect/resources/__init__.py +11 -0
  16. EntDetect/resources/__pycache__/__init__.cpython-311.pyc +0 -0
  17. EntDetect/resources/calc_K.pl +712 -0
  18. EntDetect/resources/calc_Q.pl +962 -0
  19. EntDetect/resources/pulchra +0 -0
  20. EntDetect/resources/shared_files/__init__.py +2 -0
  21. EntDetect/resources/shared_files/bt_contact_potential.dat +22 -0
  22. EntDetect/resources/shared_files/karanicolas_dihe_parm.dat +1600 -0
  23. EntDetect/resources/shared_files/kgs_contact_potential.dat +22 -0
  24. EntDetect/resources/shared_files/mj_contact_potential.dat +22 -0
  25. EntDetect/resources/stride +0 -0
  26. EntDetect/statistics.py +1344 -0
  27. EntDetect/utilities.py +201 -0
  28. entdetect-1.2.0.dist-info/METADATA +26 -0
  29. entdetect-1.2.0.dist-info/RECORD +45 -0
  30. entdetect-1.2.0.dist-info/WHEEL +5 -0
  31. entdetect-1.2.0.dist-info/entry_points.txt +11 -0
  32. entdetect-1.2.0.dist-info/licenses/LICENSE +674 -0
  33. entdetect-1.2.0.dist-info/top_level.txt +2 -0
  34. scripts/__init__.py +5 -0
  35. scripts/convert_cor_psf_to_pdb.py +103 -0
  36. scripts/run_Foldingpathway.py +162 -0
  37. scripts/run_MSM.py +152 -0
  38. scripts/run_OP_on_simulation_traj.py +194 -0
  39. scripts/run_change_resolution.py +63 -0
  40. scripts/run_compare_sim2exp.py +215 -0
  41. scripts/run_montecarlo.py +158 -0
  42. scripts/run_nativeNCLE.py +179 -0
  43. scripts/run_nonnative_entanglement_clustering.py +110 -0
  44. scripts/run_population_modeling.py +117 -0
  45. scripts/run_workflow4_nativeNCLE_batch.py +412 -0
@@ -0,0 +1,1927 @@
1
+ #!/usr/bin/env python3
2
+ import sys, getopt, math, os, time, traceback, glob, multiprocessing, copy
3
+ import numpy as np
4
+ from scipy.stats import norm
5
+ import pandas as pd
6
+ from statsmodels.stats.multitest import multipletests
7
+ import parmed as pmd
8
+ import mdtraj as mdt
9
+ import pathlib, re
10
+ import logging
11
+ from EntDetect._logging import setup_logger
12
+
13
+ class MassSpec:
14
+ """
15
+ Compare ensembles of protein structures with LiPMS and XLMS experimental data.
16
+ Primary author: Yang Jiang
17
+ Secondary author: Ian Sitarik
18
+ """
19
+ #############################################################################################################
20
def __init__(self, msm_data_file:str, meta_dist_file:str, LiPMS_exp_file:str, sasa_data_file:str, XLMS_exp_file:str, dist_data_file:str,
             cluster_data_file:str, OPpath:str, AAdcd_dir:str, native_AA_pdb:str, native_state_idx:int, state_idx_list:list, prot_len:int, last_num_frames:int,
             rm_traj_list:list=[], outdir:str='./', ID:str='', xp_dir:str=None, resid2residueidx_map:dict={},
             start:int=0, end:int=999999999999, stride:int=1, verbose:bool=False, num_perm:int=10000, n_boot:int=10000, lag_frame:int=1, nproc:int=1, log_level:int=logging.INFO, logdir:str=None):
    """
    Store the input file locations and analysis settings, set up the logger
    and make sure the output directory exists.

    All path-like arguments (``msm_data_file``, ``meta_dist_file``,
    ``LiPMS_exp_file``, ``sasa_data_file``, ``XLMS_exp_file``,
    ``dist_data_file``, ``cluster_data_file``, ``OPpath``, ``AAdcd_dir``,
    ``native_AA_pdb``, ``xp_dir``) are stored unchanged on the instance.

    Parameters handled with extra logic:
        rm_traj_list: trajectories to exclude. A private copy is stored so
            the shared default list (or the caller's list) is never aliased.
            ``None`` is treated as "nothing to remove".
        resid2residueidx_map: mapping stored on the instance. When empty or
            ``None`` it is replaced by ``{i + 1: i for i in range(prot_len)}``.
        outdir: output directory; created if it does not exist.
        logdir: directory passed to ``setup_logger``; falls back to ``outdir``
            when ``None``.
        log_level: logging level passed to ``setup_logger``.

    The remaining arguments (``native_state_idx``, ``state_idx_list``,
    ``prot_len``, ``last_num_frames``, ``start``, ``end``, ``stride``,
    ``verbose``, ``num_perm``, ``n_boot``, ``lag_frame``, ``nproc``, ``ID``)
    are stored as attributes of the same name.
    """
    self.msm_data_file = msm_data_file
    self.meta_dist_file = meta_dist_file
    self.LiPMS_exp_file = LiPMS_exp_file
    self.sasa_data_file = sasa_data_file
    self.XLMS_exp_file = XLMS_exp_file
    self.dist_data_file = dist_data_file
    self.xp_dir = xp_dir
    self.cluster_data_file = cluster_data_file
    self.OPpath = OPpath
    self.AAdcd_dir = AAdcd_dir
    self.native_AA_pdb = native_AA_pdb
    self.native_state_idx = native_state_idx
    self.state_idx_list = state_idx_list
    # Copy so that neither the shared default list nor the caller's list is
    # aliased by this instance.
    self.rm_traj_list = [] if rm_traj_list is None else list(rm_traj_list)
    self.outdir = outdir
    self.ID = ID

    # Initialize logging before any log calls in constructor.
    self.logger = setup_logger('MassSpec', outdir=logdir if logdir is not None else self.outdir, ID=self.ID, log_level=log_level)

    # Copy for the same aliasing reason as rm_traj_list; None behaves like {}.
    self.resid2residueidx_map = dict(resid2residueidx_map) if resid2residueidx_map else {}
    if not self.resid2residueidx_map:
        self.resid2residueidx_map = {i + 1: i for i in range(prot_len)}
        self.logger.info(f'No resid2residueidx_map provided, using identity mapping for protein length {prot_len} with an offset of -1')
    self.start = start
    self.end = end
    self.stride = stride
    self.verbose = verbose

    # Number of residues on each side of a cut site used by load_LiPMS_data.
    self.res_buffer = 5
    self.num_perm = num_perm
    self.prot_len = prot_len
    self.nproc = nproc

    # 1 -> compute the metric matrices; otherwise load them from disk.
    self.if_calc_M = 1

    self.last_num_frames = last_num_frames  # last X ns
    self.lag_frame = lag_frame  # down sample trajectories at each #lag_frame frame
    self.n_boot = n_boot

    # Make the outdir if it doesn't exist. exist_ok guards against another
    # process creating it between the check and the call.
    if not os.path.exists(self.outdir):
        os.makedirs(self.outdir, exist_ok=True)
        self.logger.debug(f'Creating directory: {self.outdir}')
71
+ ##############################################################################
72
+
73
+ ##############################################################################
74
def load_LiPMS_data(self, file_path):
    """
    Load the LiP-MS cut-site table and reduce it to one entry per site.

    Parameters:
        file_path: path readable by ``pandas.read_excel``, or an already
            loaded ``pandas.DataFrame``. The table must provide the columns
            ``'Cut Site'`` and ``'Log2 FC'``.

    Returns:
        dict keyed by the (whitespace-stripped) cut-site label, ordered by
        residue number. Each value is a dict with
        ``'peptide_range'`` - set of 0-based residue indices within
        ``self.res_buffer`` of the site, clipped to ``[0, self.prot_len)``;
        ``'qual_change'`` - ``-1``, ``0`` or ``1`` (sign of ``'Log2 FC'``).

    Rows whose label contains ``'-'`` are ignored. Sites whose rows do not
    agree on a single sign are logged and dropped.
    """
    df = file_path if isinstance(file_path, pd.DataFrame) else pd.read_excel(file_path)

    LiPMS_sig_data = {}
    for _, row in df.iterrows():
        raw_site = row['Cut Site']
        if not isinstance(raw_site, str):
            # A blank cell (NaN) cannot be parsed as a site label.
            self.logger.warning('Skipping row with non-text cut site: %r' % (raw_site,))
            continue

        # Use the stripped label both as key and for parsing so that
        # ' K9' and 'K9' are treated as the same site.
        peptide = raw_site.strip()
        if '-' in peptide:
            continue

        resid = int(peptide[1:]) - 1  # label is <one letter><1-based number>
        lower = max(0, resid - self.res_buffer)
        upper = min(self.prot_len, resid + self.res_buffer + 1)

        entry = LiPMS_sig_data.setdefault(peptide, {'peptide_range': [], 'qual_change': []})
        entry['peptide_range'] = set(np.arange(lower, upper))

        fold_change = row['Log2 FC']
        if fold_change < 0:
            entry['qual_change'].append(-1)
        elif fold_change > 0:
            entry['qual_change'].append(1)
        elif fold_change == 0:
            entry['qual_change'].append(0)
        # NaN matches none of the branches and contributes no sign.

    # Keep only the sites whose rows all agree on one sign.
    removed_key_list = []
    for site, entry in LiPMS_sig_data.items():
        distinct = list(set(entry['qual_change']))
        if len(distinct) == 1:
            entry['qual_change'] = distinct[0]
        else:
            self.logger.info('Site %s has inconsistent changes in abundance: %s'%(site, str(entry['qual_change'])))
            removed_key_list.append(site)
    for site in removed_key_list:
        LiPMS_sig_data.pop(site)

    return dict(sorted(LiPMS_sig_data.items(), key=lambda item: int(item[0][1:])))
105
+ ##############################################################################
106
+
107
+ ##############################################################################
108
def load_XLMS_data(self, file_path):
    """
    Load the XL-MS cross-link table and reduce it to one entry per pair.

    Parameters:
        file_path: path readable by ``pandas.read_excel``, or an already
            loaded ``pandas.DataFrame``. The table must provide the columns
            ``'Pairs'`` and ``'log2(heavy/light)'``.

    Returns:
        dict keyed by the pair label (``'<site>-<site>'`` with surrounding
        whitespace removed from each site), ordered by the two residue
        numbers. Each value is a dict with
        ``'res_pair'`` - list of the two 0-based residue indices;
        ``'qual_change'`` - ``-1``, ``0`` or ``1`` (sign of the log2 ratio).

    Pairs whose rows do not agree on a single sign are logged and dropped.

    Raises:
        ValueError: if a ``'Pairs'`` cell does not contain two
            ``'-'``-separated sites.
    """
    df = file_path if isinstance(file_path, pd.DataFrame) else pd.read_excel(file_path)

    XLMS_sig_data = {}
    for _, row in df.iterrows():
        # Strip every site so that 'K5 - S9' parses like 'K5-S9'.
        sites = [site.strip() for site in row['Pairs'].split('-')]
        if len(sites) < 2:
            raise ValueError('Cross-link pair %r must contain two sites separated by "-"' % (row['Pairs'],))
        XL_site_key = '-'.join(sites)

        entry = XLMS_sig_data.setdefault(XL_site_key, {'qual_change': [], 'res_pair': []})
        # Each site label is <one letter><1-based number>.
        entry['res_pair'] = [int(sites[0][1:]) - 1, int(sites[1][1:]) - 1]

        ratio = row['log2(heavy/light)']
        if ratio < 0:
            entry['qual_change'].append(-1)
        elif ratio > 0:
            entry['qual_change'].append(1)
        elif ratio == 0:
            entry['qual_change'].append(0)
        # NaN matches none of the branches and contributes no sign.

    # Keep only the pairs whose rows all agree on one sign.
    removed_key_list = []
    for pair, entry in XLMS_sig_data.items():
        distinct = list(set(entry['qual_change']))
        if len(distinct) == 1:
            entry['qual_change'] = distinct[0]
        else:
            self.logger.info('Sites %s has inconsistent changes in abundance: %s'%(pair, str(entry['qual_change'])))
            removed_key_list.append(pair)
    for pair in removed_key_list:
        XLMS_sig_data.pop(pair)

    # res_pair holds the residue numbers minus one, so it sorts in the same
    # order as re-parsing the key would.
    return dict(sorted(XLMS_sig_data.items(), key=lambda item: tuple(item[1]['res_pair'])))
137
+ ##############################################################################
138
+
139
+ ##############################################################################
140
def score_XL(self, pair_AA, JWalk_dist):
    """
    Score a cross-link from the distance between the two linked residues.

    The score is the value of a normal density at ``JWalk_dist``. The
    density's mean, width and cut-off start from the K-K parameters below
    and are shifted according to the side-chain lengths of the two residue
    types in ``pair_AA``.

    Parameters:
        pair_AA: sequence whose first two items are one-letter residue
            codes; each must be one of ``K``, ``S``, ``T``, ``Y``, ``M``.
        JWalk_dist: distance to score; ``-1`` marks "no distance available".

    Returns:
        ``0`` when ``JWalk_dist`` is ``-1`` or is not ``<=`` the cut-off
        (which includes NaN); otherwise the density value.

    Raises:
        KeyError: if a residue code is not in the side-chain table.
    """
    XL_offset = 1.1
    sc_length = {'K': 6.3,
                 'S': 2.5,
                 'T': 2.5,
                 'Y': 6.5,
                 'M': 1.5,}

    # K-K reference parameters, each widened by the offset.
    KK_mu = 18.6 + XL_offset
    KK_sigma = (XL_offset + 3*6.0) / 3
    KK_threshold = 33 + XL_offset

    # Shift the reference by the difference in side-chain length.
    mu = KK_mu + (sc_length[pair_AA[0]] + sc_length[pair_AA[1]]) - 2*sc_length['K']
    sigma = (mu - (KK_mu - 3*KK_sigma)) / 3
    threshold = KK_threshold + mu - KK_mu

    if JWalk_dist == -1:
        return 0
    if JWalk_dist <= threshold:
        # Evaluate the density directly; building a frozen distribution
        # object on every call is needless overhead.
        return norm.pdf(JWalk_dist, mu, sigma)
    return 0
168
+ ##############################################################################
169
+
170
+ ##############################################################################
171
def perm_fun(self, perm_idx_list, combined_data, length_1):
    """
    Split ``combined_data`` according to ``perm_idx_list`` and return the
    test statistic of the two resulting groups.

    The first ``length_1`` indices select group one, the remaining indices
    select group two. The statistic is ``self.statistic_fun`` evaluated with
    a reference value of 0.
    """
    first_idx = perm_idx_list[:length_1]
    second_idx = perm_idx_list[length_1:]
    group_one = combined_data[first_idx]
    group_two = combined_data[second_idx]
    return self.statistic_fun(group_one, group_two, 0)
175
+ ##############################################################################
176
+
177
+ ##############################################################################
178
def permutation_test(self, perm_stat_fun, data_1, data_2, num_perm, side='!='):
    """
    Permutation test on the statistic returned by ``perm_stat_fun``.

    Parameters:
        perm_stat_fun: callable ``(index_array, combined_data, length_1)``
            returning the statistic for the split described by
            ``index_array``. ``None`` falls back to ``self.perm_fun``.
        data_1, data_2: the two samples; they are concatenated in this order.
        num_perm: number of random permutations to evaluate.
        side: ``'!='`` (compare absolute values), ``'>'`` or ``'<'``.

    Returns:
        ``(k + 1) / (num_perm + 1)`` where ``k`` is the number of permuted
        statistics at least as extreme as the un-permuted one.

    All permutations are drawn in the calling process with
    ``np.random.permutation``, so seeding ``np.random`` fixes the result.
    With ``self.nproc == 1`` the statistics are evaluated in-process;
    otherwise a ``multiprocessing.Pool(self.nproc)`` is used, which requires
    ``perm_stat_fun`` to be picklable.

    Exits the interpreter with status 1 if ``side`` is invalid.
    """
    if side not in ('!=', '>', '<'):
        self.logger.error('side parameter is wrong for function "permutation_test". It must be "!=", ">", or "<".')
        sys.exit(1)

    stat_fun = perm_stat_fun if perm_stat_fun is not None else self.perm_fun

    combined_data = np.array(list(data_1) + list(data_2))
    length_1 = len(data_1)
    perm_idx_list_0 = np.arange(len(combined_data))
    t0 = stat_fun(perm_idx_list_0, combined_data, length_1)

    start_time = time.time()
    self.logger.debug('start permutation test')

    # Draw every permutation up front, in the parent process.
    permutations = [np.random.permutation(perm_idx_list_0) for _ in range(num_perm)]

    if self.nproc == 1:
        # No point paying for process start-up and pickling with one worker.
        t_dist = [stat_fun(perm_idx, combined_data, length_1) for perm_idx in permutations]
    else:
        # The context manager tears the pool down even if a worker raises.
        with multiprocessing.Pool(self.nproc) as pool:
            t_dist = pool.starmap(
                stat_fun,
                [(perm_idx, combined_data, length_1) for perm_idx in permutations],
            )

    t_dist = np.asarray(t_dist)
    if side == '!=':
        n_extreme = np.count_nonzero(np.abs(t_dist) >= np.abs(t0))
    elif side == '>':
        n_extreme = np.count_nonzero(t_dist >= t0)
    else:
        n_extreme = np.count_nonzero(t_dist <= t0)

    # The +1 counts the observed arrangement itself, so p is never zero.
    p = (n_extreme + 1) / (num_perm + 1)

    used_time = time.time() - start_time
    self.logger.info('%.2fs'%used_time)
    return p
208
+ ##############################################################################
209
+
210
+ ##############################################################################
211
def bootstrap(self, boot_fun, data, n_time):
    """
    Bootstrap the statistic ``boot_fun`` over the rows of ``data``.

    Parameters:
        boot_fun: callable applied to a resampled 1-D array.
        data: 1-D or 2-D array-like. Rows (axis 0) are resampled with
            replacement; for 2-D input ``boot_fun`` is applied to every
            column of the resampled array separately.
        n_time: number of bootstrap replicates.

    Returns:
        ``np.ndarray`` holding one entry per replicate: the value of
        ``boot_fun`` for 1-D input, or an array with one value per column
        for 2-D input.

    Resampling uses ``np.random.choice``, so seeding ``np.random`` fixes the
    result. Exits the interpreter with status 1 if ``data`` is not 1-D or
    2-D.
    """
    data = np.asarray(data)  # lets plain lists be used as well
    if data.ndim not in (1, 2):
        self.logger.error('bootstrap: Can only handle 1 or 2 dimentional data')
        sys.exit(1)

    idx_list = np.arange(len(data))
    n_rows = len(idx_list)

    boot_stat = []
    for _ in range(n_time):
        resampled = data[np.random.choice(idx_list, n_rows)]
        if data.ndim == 1:
            boot_stat.append(boot_fun(resampled))
        else:
            boot_stat.append(np.array([boot_fun(resampled[:, col]) for col in range(data.shape[1])]))
    return np.array(boot_stat)
249
+ ##############################################################################
250
+
251
+ ##############################################################################
252
def remove_traj_from_frame_list(self, rm_traj_list, frame_list, traj_axis):
    """
    Drop the entries of ``frame_list`` that belong to unwanted trajectories.

    Parameters:
        rm_traj_list: iterable of trajectory ids to remove.
        frame_list: 2-D ``np.ndarray``; an empty ``frame_list`` is returned
            unchanged.
        traj_axis: ``0`` if the trajectory ids are in the first row (columns
            are filtered), ``1`` if they are in the first column (rows are
            filtered).

    Returns:
        ``frame_list`` without the columns/rows whose trajectory id is in
        ``rm_traj_list``.

    Raises:
        ValueError: if ``traj_axis`` is neither 0 nor 1.
    """
    if len(frame_list) == 0:
        return frame_list

    # list() so that sets and other iterables are compared element-wise.
    unwanted = list(rm_traj_list)
    if traj_axis == 0:  # traj ids in 1st row
        keep = ~np.isin(frame_list[0, :], unwanted)
        return frame_list[:, keep]
    if traj_axis == 1:  # traj ids in 1st column
        keep = ~np.isin(frame_list[:, 0], unwanted)
        return frame_list[keep, :]
    raise ValueError('traj_axis must be 0 or 1, got %r' % (traj_axis,))
269
+ ##############################################################################
270
+
271
+ ##############################################################################
272
def statistic_fun(self, data_1, data_2, ref):
    """
    Standardised difference of the means of two samples.

    Numerator: ``mean(data_1) - mean(data_2) - ref``.
    Denominator: ``sqrt(std(data_1)**2 / n1 + std(data_2)**2 / n2)`` using
    ``np.std`` with its default settings.

    Returns ``1.0`` when numerator and denominator are both zero, otherwise
    their ratio.
    """
    mean_gap = (np.mean(data_1) - np.mean(data_2)) - ref
    spread_1 = np.std(data_1)**2/len(data_1)
    spread_2 = np.std(data_2)**2/len(data_2)
    std_err = (spread_1 + spread_2)**0.5

    # 0/0 is defined as 1.0 here rather than left as NaN.
    if mean_gap == 0 and std_err == 0:
        return 1.0
    return mean_gap/std_err
280
+ ##############################################################################
281
+
282
+ ##############################################################################
283
def bootstrap_test(self, data_1, data_2, statistic_fun, n_time, side='!='):
    """
    Bootstrap hypothesis test for the difference between two samples.

    Each replicate resamples ``data_1`` and ``data_2`` independently with
    replacement and evaluates ``statistic_fun(sample_1, sample_2, ref)``
    with ``ref`` set to the observed difference of the means. The observed
    statistic is ``statistic_fun(data_1, data_2, 0)``.

    Parameters:
        data_1, data_2: arrays that support integer-array indexing.
        statistic_fun: callable ``(sample_1, sample_2, ref)``.
        n_time: number of bootstrap replicates.
        side: ``'!='``, ``'>'`` or ``'<'``.

    Returns:
        ``(p, boot_stat)`` - the p-value (capped at 1.0) and the array of
        replicate statistics.

    Logs a message and calls ``sys.exit()`` if ``side`` is invalid.
    """
    if side != '!=' and side != '>' and side != '<':
        self.logger.info('side parameter is wrong for function "bootstrap_test". It must be "!=", ">", or "<".')
        sys.exit()

    pool_1 = np.arange(len(data_1))
    pool_2 = np.arange(len(data_2))
    observed_gap = np.mean(data_1) - np.mean(data_2)

    # Arguments are evaluated left to right, so sample 1 is always drawn
    # before sample 2 within a replicate.
    boot_stat = np.array([
        statistic_fun(
            data_1[np.random.choice(pool_1, len(pool_1))],
            data_2[np.random.choice(pool_2, len(pool_2))],
            observed_gap,
        )
        for _ in range(n_time)
    ])

    observed_stat = statistic_fun(data_1, data_2, 0)
    n_at_least = np.count_nonzero(boot_stat >= observed_stat)
    n_at_most = np.count_nonzero(boot_stat <= observed_stat)

    if side == '!=':
        # Two-sided: twice the smaller tail.
        p = (np.min([n_at_least, n_at_most])+1) / (n_time+1) * 2
    elif side == '>':
        p = (n_at_least+1) / (n_time+1)
    else:
        p = (n_at_most+1) / (n_time+1)

    if p > 1:
        p = 1.0

    return (p, boot_stat)
312
+ ##############################################################################
313
+
314
+
315
+ ##############################################################################
316
+ def LiP_XL_MS_ConsistencyTest(self,):
317
+ self.logger.info(f'Comparing simulation to experimental data...')
318
+ xlsx_outfile = os.path.join(self.outdir, f'LiPMS_XLMS_consist_pvalues_metastates_v11_down_sample_lag{self.lag_frame}.xlsx')
319
+ npz_outfile = os.path.join(self.outdir, 'LiPMS_XLMS_consist_data_v9.npz')
320
+ self.logger.debug(f'xlsx_outfile: {xlsx_outfile}')
321
+ self.logger.debug(f'npz_outfile: {npz_outfile}')
322
+
323
+ if os.path.exists(npz_outfile) and os.path.exists(xlsx_outfile):
324
+ self.logger.info(f'npz_outfile EXISTS: Loading...')
325
+ self.logger.info(f'xlsx_outfile EXISTS: Loading...')
326
+ M_data = np.load(npz_outfile, allow_pickle=True)
327
+ XLSX_df = pd.read_excel(xlsx_outfile)
328
+ #print(f'XLSX_df:\n{XLSX_df}')
329
+
330
+ else:
331
+
332
+ #################################################################
333
+ # Load MSM data
334
+ MSM_data = pd.read_csv(self.msm_data_file)
335
+ self.logger.info(f'MSM_data\n{MSM_data}')
336
+ meta_states = MSM_data['metastablestate'].unique()
337
+ meta_states = np.array(meta_states, dtype=int)
338
+ self.logger.debug(f'meta_states: {meta_states}')
339
+ num_meta_states = len(meta_states)
340
+ self.logger.debug(f'num_meta_states: {num_meta_states}')
341
+
342
+
343
+ meta_dtrajs_last = []
344
+ traj_idx_to_trajnum = {} # mapping traj_idx to traj number
345
+ for traj_idx, (traj, traj_df) in enumerate(MSM_data.groupby('traj')):
346
+ traj_len = len(traj_df)
347
+ self.logger.debug(f'traj: {traj}, traj_len: {traj_len}\n{traj_df.head()}')
348
+
349
+ last = traj_df.iloc[-self.last_num_frames:,:]
350
+ last = last.reset_index(drop=True)
351
+ last = last['metastablestate'].values
352
+ #print(f'last: {last}')
353
+ meta_dtrajs_last.append(last)
354
+
355
+ self.logger.debug(f'traj_idx: {traj_idx}, traj: {traj}, traj_len: {traj_len}, last_num_frames: {self.last_num_frames}, last: {last} {len(last)}')
356
+ traj_idx_to_trajnum[traj_idx] = traj
357
+
358
+ meta_dtrajs_last = np.array(meta_dtrajs_last)
359
+ self.logger.info(f'meta_dtrajs_last.shape: {meta_dtrajs_last.shape}')
360
+ self.logger.info(f'meta_dtrajs_last\n{meta_dtrajs_last} {meta_dtrajs_last.shape}')
361
+ self.logger.debug(np.unique(meta_dtrajs_last))
362
+
363
+ # Keep MSM trajectory indexing aligned with downstream OP arrays.
364
+ rm_traj_set = set(int(t) for t in self.rm_traj_list)
365
+ keep_traj_idx = [
366
+ idx for idx in range(meta_dtrajs_last.shape[0])
367
+ if int(traj_idx_to_trajnum[idx]) not in rm_traj_set
368
+ ]
369
+ meta_dtrajs_last = meta_dtrajs_last[keep_traj_idx, :]
370
+ traj_idx_to_trajnum = {
371
+ new_idx: int(traj_idx_to_trajnum[old_idx])
372
+ for new_idx, old_idx in enumerate(keep_traj_idx)
373
+ }
374
+ self.logger.info(
375
+ f'meta_dtrajs_last.shape after mirror-image removal: {meta_dtrajs_last.shape}'
376
+ )
377
+ #################################################################
378
+
379
+
380
+ #################################################################
381
+ # Create frame_list for each state
382
+ # The result is a list where each element is a 2D array with the first column being the trajectory index and the second column being the frame index
383
+ sel_frame_idx = np.arange(0, self.last_num_frames, self.lag_frame)
384
+ self.logger.info(f'sel_frame_idx:\n{sel_frame_idx}')
385
+
386
+ frame_list = []
387
+ empty_states = []
388
+ for state_idx in self.state_idx_list:
389
+ self.logger.info(f'\nGetting frames for state {state_idx}...')
390
+ frame_list_0 = np.array(np.where(meta_dtrajs_last[:, sel_frame_idx] == state_idx)).T
391
+ frame_list_0[:,1] = sel_frame_idx[frame_list_0[:,1]]
392
+ utrajs = np.unique(frame_list_0[:,0])
393
+ self.logger.debug(frame_list_0)
394
+ self.logger.debug(f'utrajs: {utrajs}')
395
+
396
+ #frame_list_0 = self.remove_traj_from_frame_list(self.rm_traj_list, frame_list_0, 1)
397
+ if len(frame_list_0) == 0:
398
+ self.logger.info(f'No frames for state {state_idx} in the last {self.last_num_frames} frames. Exitting. The state maybe made entirely of mirror traj and in the self.rm_traj_list!')
399
+ empty_states.append(state_idx)
400
+ continue
401
+
402
+ frame_list.append(frame_list_0)
403
+
404
+ ## adjust the frame_list and the state_idx_list depending on what states are populated
405
+ self.logger.debug(f'empty_states: {empty_states}')
406
+ self.state_idx_list = [idx for idx in self.state_idx_list if idx not in empty_states]
407
+ self.logger.info(f'Updated state_idx_list: {self.state_idx_list}')
408
+ for idx, arr in enumerate(frame_list):
409
+ state = self.state_idx_list[idx]
410
+ self.logger.info(f'state_list_idx={idx}, state={state}, shape={arr.shape}')
411
+
412
+ native_sel = np.array(np.where(meta_dtrajs_last[:, sel_frame_idx] == self.native_state_idx)).T
413
+ native_sel[:,1] = sel_frame_idx[native_sel[:,1]]
414
+ self.logger.info(f'native_sel:\n{native_sel} {native_sel.shape}')
415
+
416
+ #native_sel = self.remove_traj_from_frame_list(self.rm_traj_list, native_sel, 1)
417
+ #print(f'native_sel:\n{native_sel} {native_sel.shape}')
418
+ #################################################################
419
+
420
+
421
+ #################################################################
422
+ # Load SASA data
423
+ sasa_traj_list = np.load(self.sasa_data_file, allow_pickle=True)[:,-self.last_num_frames:,:]
424
+ self.logger.info(f'sasa_traj_list.shape: {sasa_traj_list.shape}')
425
+
426
+ # Apply the same trajectory filtering used for MSM indexing.
427
+ sasa_traj_list = sasa_traj_list[keep_traj_idx, :, :]
428
+ self.logger.info(f'sasa_traj_list.shape after mirror-image removal: {sasa_traj_list.shape}')
429
+ #################################################################
430
+
431
+
432
+ #################################################################
433
+ # XL-MS scores are computed either from a pre-built Jwalk.npy
434
+ # (legacy mode) or directly from per-trajectory XP files
435
+ # (memory-friendly streaming mode).
436
+ dist_traj_list = None
437
+ #################################################################
438
+
439
+
440
+ #################################################################
441
+ # Select the frames without nan residual SASA
442
+ # NAN SASA indicates a bad backmapped structure and I want to skip them
443
+ nan_frame_sel = np.where(np.isnan(sasa_traj_list))
444
+ nan_frame_sel = np.array([nan_frame_sel[0], nan_frame_sel[1]], dtype=int).T
445
+ self.logger.debug(nan_frame_sel)
446
+ frame_list = [np.array([sel for sel in frame if not sel.tolist() in nan_frame_sel.tolist()]) for frame in frame_list]
447
+ self.logger.info(f'frame_list:')
448
+ for idx, arr in enumerate(frame_list):
449
+ state = self.state_idx_list[idx]
450
+ self.logger.info(f'state_list_idx={idx}, state={state}, shape={arr.shape}')
451
+ native_sel = np.array([sel for sel in native_sel if not sel.tolist() in nan_frame_sel.tolist()])
452
+ self.logger.info(f'native_sel:\n{native_sel} {native_sel.shape}')
453
+ #################################################################
454
+
455
+
456
+ #################################################################
457
+ # Load LiPMS experimental data
458
+ LiPMS_sig_data = self.load_LiPMS_data(self.LiPMS_exp_file)
459
+ self.logger.info(f'Loaded LiPMS experimental data: {len(LiPMS_sig_data)} peptides')
460
+
461
+ # load XLMS experimental data
462
+ XLMS_sig_data = self.load_XLMS_data(self.XLMS_exp_file)
463
+ self.logger.info(f'Loaded XLMS experimental data: {len(XLMS_sig_data)} pairs')
464
+ #################################################################
465
+
466
+
467
+ #################################################################
468
+ # Calculate or load metric matrix
469
+ if self.if_calc_M == 1:
470
+ self.logger.debug('Calculating metric matrix...')
471
+ # calculate metric matrix
472
+ M_LiPMS = np.zeros((*meta_dtrajs_last.shape, len(LiPMS_sig_data)))
473
+ for idx, peptide in enumerate(LiPMS_sig_data.keys()):
474
+ sel = list(LiPMS_sig_data[peptide]['peptide_range'])
475
+ SA = np.sum(sasa_traj_list[:,:,sel], axis=-1)
476
+ M_LiPMS[:,:,idx] = SA
477
+
478
+ xlms_targets = []
479
+ for idx, key in enumerate(XLMS_sig_data.keys()):
480
+ pair_AA = [k[0] for k in key.split('-')]
481
+ key_0 = '-'.join([k[1:]+'|A' for k in key.split('-')])
482
+ xlms_targets.append((idx, key_0, pair_AA))
483
+
484
+ M_XLMS = np.zeros((*meta_dtrajs_last.shape, len(XLMS_sig_data)))
485
+ if self.dist_data_file is not None and os.path.exists(self.dist_data_file):
486
+ dist_traj_list = np.load(self.dist_data_file, allow_pickle=True)[:,-self.last_num_frames:]
487
+ self.logger.info(f'dist_traj_list.shape: {dist_traj_list.shape}')
488
+
489
+ # Apply the same trajectory filtering used for MSM indexing.
490
+ dist_traj_list = dist_traj_list[keep_traj_idx, :]
491
+ self.logger.info(f'dist_traj_list.shape after mirror-image removal: {dist_traj_list.shape}')
492
+
493
+ n_traj = min(meta_dtrajs_last.shape[0], dist_traj_list.shape[0])
494
+ n_frame = min(meta_dtrajs_last.shape[1], dist_traj_list.shape[1])
495
+ for i in range(n_traj):
496
+ for j in range(n_frame):
497
+ frame_data = dist_traj_list[i, j]
498
+ if frame_data is None:
499
+ continue
500
+ for idx, key_0, pair_AA in xlms_targets:
501
+ if key_0 not in frame_data:
502
+ continue
503
+ JWalk_dist = frame_data[key_0].get('Jwalk', -1)
504
+ M_XLMS[i, j, idx] = self.score_XL(pair_AA, JWalk_dist)
505
+ elif self.xp_dir is not None and os.path.isdir(self.xp_dir):
506
+ self.logger.info(f'Computing XL-MS scores in streaming mode from XP files in: {self.xp_dir}')
507
+ target_lookup = {key_0: (idx, pair_AA) for idx, key_0, pair_AA in xlms_targets}
508
+
509
+ for traj_idx in range(meta_dtrajs_last.shape[0]):
510
+ traj_num = int(traj_idx_to_trajnum[traj_idx])
511
+ if traj_num in rm_traj_set:
512
+ continue
513
+
514
+ fpath = os.path.join(self.xp_dir, f'{self.ID}_Traj{traj_num}.XP')
515
+ if not os.path.exists(fpath):
516
+ self.logger.warning(f'Missing XP file for streaming XL-MS scoring: {fpath}')
517
+ continue
518
+
519
+ df = pd.read_csv(
520
+ fpath,
521
+ sep='\t',
522
+ usecols=['Frame', 'Atom1', 'Atom2', 'SASD'],
523
+ dtype={'Frame': np.int32, 'SASD': np.float32, 'Atom1': 'string', 'Atom2': 'string'},
524
+ )
525
+ if df.empty:
526
+ continue
527
+
528
+ frame_values = np.sort(df['Frame'].unique())
529
+ frame_to_idx = {int(f): idx for idx, f in enumerate(frame_values)}
530
+
531
+ atom1_parts = df['Atom1'].str.split('-', expand=True)
532
+ atom2_parts = df['Atom2'].str.split('-', expand=True)
533
+ if atom1_parts.shape[1] < 3 or atom2_parts.shape[1] < 3:
534
+ self.logger.warning(f'Malformed Atom fields in {fpath}; skipping trajectory {traj_num}')
535
+ continue
536
+
537
+ frame_key_df = pd.DataFrame({
538
+ 'Frame': df['Frame'].values,
539
+ 'SASD': df['SASD'].values,
540
+ 'key': (atom1_parts[1].astype(str) + '|'+ atom1_parts[2].astype(str) + '-' +
541
+ atom2_parts[1].astype(str) + '|' + atom2_parts[2].astype(str)).values,
542
+ })
543
+
544
+ frame_key_df = frame_key_df[frame_key_df['key'].isin(target_lookup.keys())]
545
+ frame_key_df = frame_key_df.drop_duplicates(subset=['Frame', 'key'], keep='first')
546
+ if frame_key_df.empty:
547
+ continue
548
+
549
+ for key_0, key_df in frame_key_df.groupby('key', sort=False):
550
+ idx, pair_AA = target_lookup[key_0]
551
+ for frame_num, sasd in zip(key_df['Frame'].values, key_df['SASD'].values):
552
+ frame_idx = frame_to_idx.get(int(frame_num), None)
553
+ if frame_idx is None or frame_idx >= meta_dtrajs_last.shape[1]:
554
+ continue
555
+ M_XLMS[traj_idx, frame_idx, idx] = self.score_XL(pair_AA, float(sasd))
556
+
557
+ if (traj_idx + 1) % 50 == 0:
558
+ self.logger.info(f'Streamed XP scoring progress: {traj_idx + 1}/{meta_dtrajs_last.shape[0]} trajectories')
559
+ else:
560
+ raise ValueError(
561
+ 'XL-MS scoring requires either dist_data_file (Jwalk.npy) or xp_dir (per-trajectory XP files).'
562
+ )
563
+
564
+ # Save data
565
+ np.savez(npz_outfile,
566
+ M_LiPMS = M_LiPMS,
567
+ M_XLMS = M_XLMS)
568
+ self.logger.info(f'Saved metric matrices to {npz_outfile}')
569
+
570
+ else:
571
+ self.logger.debug('Loading metric matrix...')
572
+ M_data = np.load(npz_outfile, allow_pickle=True)
573
+ M_LiPMS = M_data['M_LiPMS'][:,-self.last_num_frames:,:]
574
+ M_XLMS = M_data['M_XLMS'][:,-self.last_num_frames:,:]
575
+ #################################################################
576
+
577
+
578
+ #################################################################
579
+ # Consistency tests
580
+ M_str_list = ['LiPMS', 'XLMS']
581
+ exp_data_list = [LiPMS_sig_data, XLMS_sig_data]
582
+ df_list = []
583
+ df_all_state_list = []
584
+ p_list = []
585
+ p_all_state_list = []
586
+ for idx, M in enumerate([M_LiPMS, M_XLMS]):
587
+ self.logger.info(M_str_list[idx]+':')
588
+ exp_data = exp_data_list[idx]
589
+
590
+ df_list_0 = []
591
+ for idx_0 in range(len(exp_data)):
592
+ self.logger.info('%s:'%(list(exp_data.keys())[idx_0]))
593
+ M_0 = M[:,:,idx_0]
594
+ index_str = list(exp_data.keys())[idx_0]
595
+ header = ['Near-native state', 'Sample size', '<M>']
596
+ header += ['p (!=)', 'Adjusted p (!=)']
597
+ if exp_data[list(exp_data.keys())[idx_0]]['qual_change'] > 0:
598
+ header += ['p (>)', 'Adjusted p (>)']
599
+ test_side = '>'
600
+ else:
601
+ header += ['p (<)', 'Adjusted p (<)']
602
+ test_side = '<'
603
+
604
+ M_0_native = M_0[native_sel[:,0], native_sel[:,1]]
605
+ # bootstrapping to get 95%CI
606
+ boot_stat_native = self.bootstrap(np.mean, M_0_native, self.n_boot)
607
+ lb_native = np.percentile(boot_stat_native, 2.5)
608
+ ub_native = np.percentile(boot_stat_native, 97.5)
609
+
610
+ df_data = []
611
+ df_all_state_data = []
612
+ # for near-native states
613
+ self.logger.info(f'len(self.state_idx_list): {self.state_idx_list} {len(self.state_idx_list)}')
614
+ for idx_1, state_id in enumerate(self.state_idx_list):
615
+ self.logger.info(f'state_list_idx={idx_1}, state={state_id}')
616
+ near_native_sel = np.array(frame_list[idx_1], dtype=int)
617
+ self.logger.debug(near_native_sel)
618
+ M_0_near_native = M_0[near_native_sel[:,0], near_native_sel[:,1]]
619
+ # bootstrapping to get 95%CI
620
+ boot_stat_near_native = self.bootstrap(np.mean, M_0_near_native, self.n_boot)
621
+ lb_near_native = np.percentile(boot_stat_near_native, 2.5)
622
+ ub_near_native = np.percentile(boot_stat_near_native, 97.5)
623
+ self.logger.info('Near-native state %d vs. Native state %d:'%(state_id+1, num_meta_states))
624
+ self.logger.info(' Sample size: %d vs. %d'%(len(M_0_near_native), len(M_0_native)))
625
+ self.logger.info(' <M>: %.4f [%.4f, %.4f] vs. %.4f [%.4f, %.4f]'%(np.mean(M_0_near_native), lb_near_native, ub_near_native, np.mean(M_0_native), lb_native, ub_native))
626
+ p_value_list_0 = []
627
+ for ts in ['!=', test_side]:
628
+ p = self.permutation_test(self.perm_fun, M_0_near_native, M_0_native, self.num_perm, side=ts)
629
+ # p, _ = bootstrap_test(M_0_near_native, M_0_native, statistic_fun, n_boot, side=ts)
630
+ self.logger.info(' p-value ("%s") = %.4f'%(ts, p))
631
+ p_value_list_0.append(p)
632
+ p_list.append(p_value_list_0)
633
+ df_data.append([state_id+1, '%d vs. %d'%(len(M_0_near_native), len(M_0_native)), '%.4f [%.4f, %.4f] vs. %.4f [%.4f, %.4f]'%(np.mean(M_0_near_native), lb_near_native, ub_near_native, np.mean(M_0_native), lb_native, ub_native), p_value_list_0[0], 0, p_value_list_0[1], 0])
634
+
635
+ df = pd.DataFrame(df_data, columns=header, index=[index_str]*len(df_data))
636
+ df_list_0.append(df)
637
+
638
+ # for all states
639
+ frame_list_1 = np.array(np.where(meta_dtrajs_last < num_meta_states)).T
640
+ #near_native_sel = self.remove_traj_from_frame_list(self.rm_traj_list, frame_list_1, 1)
641
+ near_native_sel = np.array([sel for sel in near_native_sel if not sel in nan_frame_sel]) # remove nan SASA
642
+
643
+ # Select only frames seperated by #lag_frame for each state in the trajectory
644
+ near_native_sel_idx = []
645
+ for i in np.unique(near_native_sel[:,0]):
646
+ idx = np.where(near_native_sel[:,0] == i)[0]
647
+ near_native_sel_idx.append(idx[0])
648
+ for iidx in idx[1:]:
649
+ if near_native_sel[iidx,1] - near_native_sel[near_native_sel_idx[-1],1] >= self.lag_frame:
650
+ near_native_sel_idx.append(iidx)
651
+ near_native_sel = near_native_sel[near_native_sel_idx,:]
652
+
653
+ M_0_near_native = M_0[near_native_sel[:,0], near_native_sel[:,1]]
654
+ # bootstrapping to get 95%CI
655
+ boot_stat_near_native = self.bootstrap(np.mean, M_0_near_native, self.n_boot)
656
+ lb_near_native = np.percentile(boot_stat_near_native, 2.5)
657
+ ub_near_native = np.percentile(boot_stat_near_native, 97.5)
658
+ self.logger.info('All states vs. Native state %d:'%(num_meta_states))
659
+ self.logger.info(' Sample size: %d vs. %d'%(len(M_0_near_native), len(M_0_native)))
660
+ self.logger.info(' <M>: %.4f [%.4f, %.4f] vs. %.4f [%.4f, %.4f]'%(np.mean(M_0_near_native), lb_near_native, ub_near_native, np.mean(M_0_native), lb_native, ub_native))
661
+ p_value_list_0 = []
662
+ for ts in ['!=', test_side]:
663
+ p = self.permutation_test(self.perm_fun, M_0_near_native, M_0_native, self.num_perm, side=ts)
664
+ # p, _ = bootstrap_test(M_0_near_native, M_0_native, statistic_fun, n_boot, side=ts)
665
+ self.logger.info(' p-value ("%s") = %.4f'%(ts, p))
666
+ p_value_list_0.append(p)
667
+ p_all_state_list.append(p_value_list_0)
668
+ df_all_state_data.append(['All states', '%d vs. %d'%(len(M_0_near_native), len(M_0_native)), '%.4f [%.4f, %.4f] vs. %.4f [%.4f, %.4f]'%(np.mean(M_0_near_native), lb_near_native, ub_near_native, np.mean(M_0_native), lb_native, ub_native), p_value_list_0[0], 0, p_value_list_0[1], 0])
669
+ df = pd.DataFrame(df_all_state_data, columns=header, index=[index_str]*len(df_all_state_data))
670
+ df_all_state_list.append(df)
671
+ df_list.append(df_list_0)
672
+ #################################################################
673
+
674
+
675
+ #################################################################
676
+ # Correct p-value
677
+ p_list = np.array(p_list)
678
+ p_all_state_list = np.array(p_all_state_list)
679
+ adjusted_p_list = []
680
+ for pi in range(p_list.shape[1]):
681
+ pl = p_list[:,pi]
682
+ results = multipletests(pl, alpha=0.05, method='fdr_bh', is_sorted=False, returnsorted=False)
683
+ adjusted_p_list.append(results[1])
684
+ adjusted_p_list = np.array(adjusted_p_list).T
685
+ adjusted_p_all_state_list = []
686
+ for pi in range(p_all_state_list.shape[1]):
687
+ pl = p_all_state_list[:,pi]
688
+ results = multipletests(pl, alpha=0.05, method='fdr_bh', is_sorted=False, returnsorted=False)
689
+ adjusted_p_all_state_list.append(results[1])
690
+ adjusted_p_all_state_list = np.array(adjusted_p_all_state_list).T
691
+ #################################################################
692
+
693
+ #################################################################
694
+ # Update dataframes
695
+ pi = 0
696
+ for idx in range(len(df_list)):
697
+ for idx_0 in range(len(df_list[idx])):
698
+ header = list(df_list[idx][idx_0].columns)
699
+ for i in range(adjusted_p_list.shape[1]):
700
+ ii = -adjusted_p_list.shape[1]*2 + i*2 + 1
701
+ df_list[idx][idx_0][header[ii]] = adjusted_p_list[pi:pi+len(self.state_idx_list),i]
702
+ pi += len(df_list[idx][idx_0])
703
+
704
+ pi = 0
705
+ for df in df_all_state_list:
706
+ header = list(df.columns)
707
+ for i in range(adjusted_p_all_state_list.shape[1]):
708
+ ii = -adjusted_p_all_state_list.shape[1]*2 + i*2 + 1
709
+ df[header[ii]] = adjusted_p_all_state_list[pi:pi+len(df),i]
710
+ pi += len(df)
711
+
712
+ df_list.append(df_all_state_list)
713
+ M_str_list.append('All states')
714
+ #################################################################
715
+
716
+ #################################################################
717
+ # Creating Excel Writer Object from Pandas
718
+ with pd.ExcelWriter(xlsx_outfile, engine="openpyxl") as writer:
719
+ workbook=writer.book
720
+ for idx_0, df_list_0 in enumerate(df_list):
721
+ worksheet=workbook.create_sheet(M_str_list[idx_0])
722
+ writer.sheets[M_str_list[idx_0]] = worksheet
723
+ start_row = 0
724
+ for idx_1, df in enumerate(df_list_0):
725
+ df.to_excel(writer, sheet_name=M_str_list[idx_0], index=True, float_format="%.4f", startrow=start_row , startcol=0)
726
+ start_row += len(df) + 2
727
+ self.logger.debug(f'SAVED: {xlsx_outfile}')
728
+ #################################################################
729
+
730
+ self.logger.info(f'LiP_XL_MS_ConsistencyTest DONE!')
731
+ return npz_outfile, xlsx_outfile
732
+ ##############################################################################
733
+
734
+ #######################################################################################
735
def load_OP(self, start:int=0, end:int=99999999999):
    """
    Load the Q and G order-parameter time series of every trajectory.

    Files are discovered with the globs ``<self.OPpath>/Q/*.Q`` and
    ``<self.OPpath>/G/*.G``. The trajectory number of a file is the integer
    that follows the last ``Traj`` in its stem (so ``x_Traj7.Q`` and
    ``x_Traj07.Q`` both belong to trajectory 7). Only trajectories that have
    both a Q and a G file, and that are not listed in ``self.rm_traj_list``,
    are loaded, in ascending trajectory-number order.

    Each Q file is read as CSV and its ``total`` column is used; each G file
    is read as CSV and its ``G`` column is used.

    Parameters
    ----------
    start : int
        If negative, the last ``abs(start)`` rows of each file are kept and
        ``end`` is ignored. Otherwise rows whose ``Frame`` column is
        ``>= start`` are kept.
    end : int
        Inclusive upper bound on the ``Frame`` column (only used when
        ``start >= 0``).

    Returns
    -------
    (Q_data, G_data) : tuple of numpy.ndarray
        One row per loaded trajectory. Trajectories whose Q and G series
        differ in length are skipped with a warning.

    Raises
    ------
    AssertionError
        If the number of Q and G files differs, or if a trajectory does not
        map to exactly one Q file and one G file.
    ValueError
        If a loaded Q or G series contains NaN.
    """
    self.logger.info('Loading G and Q order parameters...')

    def _index_by_traj(paths):
        # Group file paths by the trajectory number parsed from the stem.
        # Looking files up by parsed number (instead of rebuilding the name
        # from the integer) keeps zero-padded names such as Traj01 usable.
        by_traj = {}
        for path in paths:
            traj_num = int(pathlib.Path(path).stem.split('Traj')[-1])
            by_traj.setdefault(traj_num, []).append(path)
        return by_traj

    Qfiles = glob.glob(os.path.join(self.OPpath, 'Q/*.Q'))
    Q_by_traj = _index_by_traj(Qfiles)

    Gfiles = glob.glob(os.path.join(self.OPpath, 'G/*.G'))
    G_by_traj = _index_by_traj(Gfiles)

    shared_Trajs = sorted(set(Q_by_traj).intersection(G_by_traj))
    self.logger.debug(f'shared_Trajs: {shared_Trajs}')
    self.logger.info(f'Number of Q files found: {len(Qfiles)} | Number of G files found: {len(Gfiles)}')

    assert len(Qfiles) == len(Gfiles), f"The # of Q and G files {len(Qfiles)} and {len(Gfiles)} is not equal"

    ## remove trajectories that are in the rm_traj_list
    if len(self.rm_traj_list) > 0:
        self.logger.info(f'Removing trajectories: {self.rm_traj_list}')
        # normalise to ints so the membership test matches the parsed traj numbers
        rm_traj_set = {int(t) for t in self.rm_traj_list}
        shared_Trajs = [traj for traj in shared_Trajs if traj not in rm_traj_set]
        self.logger.info(f'Number of shared Traj after removing: {len(shared_Trajs)}')

    # load the Q and G time series of every shared trajectory
    Q_data = []
    G_data = []
    for traj in shared_Trajs:

        # get the corresponding G and Q file
        Qf = Q_by_traj[traj]
        Gf = G_by_traj[traj]

        ## Quality check to assert that only a single G and Q file were found
        assert len(Qf) == 1, f"the number of Q files {len(Qf)} should equal 1 for Traj {traj}"
        assert len(Gf) == 1, f"the number of G files {len(Gf)} should equal 1 for Traj {traj}"

        # load the G Q data and extract only the time series column
        Qframe = pd.read_csv(Qf[0], sep=',')
        Gframe = pd.read_csv(Gf[0])
        if start < 0:
            # a negative start is taken as slicing the end of the array
            Qdata = Qframe['total'].values[start:].astype(float)
            Gdata = Gframe['G'].values[start:].astype(float)
        else:
            Qframe = Qframe[(Qframe['Frame'] >= start) & (Qframe['Frame'] <= end)]
            Qdata = Qframe['total'].values.astype(float)

            Gframe = Gframe[(Gframe['Frame'] >= start) & (Gframe['Frame'] <= end)]
            Gdata = Gframe['G'].values.astype(float)

        ## Quality check that the G and Q data has the same number of frames
        if Qdata.shape != Gdata.shape:
            self.logger.warning(f"WARNING: The number of frames in Q {Qdata.shape} should equal the number of frames in G {Gdata.shape} in Traj {traj}")
            continue

        ## Check and ensure that Qdata or Gdata has no nan values
        if np.isnan(Qdata).any():
            raise ValueError('There is a NaN value in this Qdata')

        if np.isnan(Gdata).any():
            raise ValueError('There is a NaN value in this Gdata')

        Q_data.append(Qdata)
        G_data.append(Gdata)

    Q_data = np.asarray(Q_data)
    G_data = np.asarray(G_data)
    self.logger.debug(f'Q_data: {Q_data.shape}')
    self.logger.debug(f'G_data: {G_data.shape}')
    return Q_data, G_data
813
+ ##############################################################################
814
+
815
+ ##############################################################################
816
+ def select_rep_structs(self, consist_data_file:str, consist_result_file:str, total_traj_num_frames:int, last_num_frames:int):
817
+ """
818
+ After performing the consistency test select representative structures with high consistency
819
+ """
820
+ self.logger.info(f'Selecting representative structure')
821
+
822
+ ############### Functions ###############
823
+ ##############################################################################
824
def parse_consistant_results(consist_result_file, sheet_name, test_type='two-tailed', significant_level=0.05):
    """
    Parse one sheet of the consistency-test workbook into a per-signal dict.

    The sheet is read with ``pd.read_excel`` and walked row by row using
    positional column access:

    * a row whose last cell is NaN is skipped (blank separator row);
    * a row whose first cell is NaN is treated as a repeated header row and
      the comparison sign is re-read from its ``t_idx`` cell;
    * a row whose second cell is ``'All states'`` is skipped;
    * any other row is a data row keyed by its first cell.

    ``t_idx`` is -3 for ``test_type='two-tailed'`` and -1 for
    ``'one-tailed'``; any other value logs an error and exits the process.
    The sign is the last whitespace-separated token of the header text with
    its first and last character removed.

    For every data row whose ``t_idx`` cell is below ``significant_level``
    the returned dict holds
    ``key -> [sign, [state - 1, ...], mean, [lb, ub], key_position]`` where
    ``mean``, ``lb`` and ``ub`` are parsed from the text after the last
    ``'vs.'`` in the fourth cell (expected form ``"<mean> [<lb>, <ub>]"``)
    and ``key_position`` is the order in which the key first appeared among
    all data rows of the sheet.
    """
    data = {}
    consist_data = pd.read_excel(consist_result_file, sheet_name=sheet_name)
    num_row = consist_data.index.size
    all_key_list = []
    if test_type == 'two-tailed':
        t_idx = -3
    elif test_type == 'one-tailed':
        t_idx = -1
    else:
        self.logger.error('Error: Wrong test_type = %s; can be either "two-tailed" or "one-tailed"'%test_type)
        sys.exit()

    def _is_nan_cell(cell):
        # isinstance also accepts numpy float scalars, which subclass float
        return isinstance(cell, float) and np.isnan(cell)

    sign = consist_data.columns[t_idx].split()[-1][1:-1]
    for i in range(num_row):
        # Positional access throughout: integer keys on a label-indexed
        # Series are deprecated in pandas >= 2.1, so use .iloc explicitly.
        row = consist_data.iloc[i]
        if _is_nan_cell(row.iloc[-1]):
            continue
        if _is_nan_cell(row.iloc[0]):
            # repeated header row: pick up the sign for the rows that follow
            sign = row.iloc[t_idx].split()[-1][1:-1]
            continue
        if row.iloc[1] == 'All states':
            continue

        key = row.iloc[0]
        if key not in all_key_list:
            all_key_list.append(key)
        if row.iloc[t_idx] < significant_level:
            if key not in data:
                # text after the last 'vs.' looks like "<mean> [<lb>, <ub>]"
                words = row.iloc[3].strip().split('vs.')[-1].split()
                mean = float(words[0])
                lb = float(words[1][1:-1])
                ub = float(words[2][:-1])
                data[key] = [sign, [row.iloc[1]-1], mean, [lb, ub]]
            else:
                data[key][1].append(row.iloc[1]-1)

    for key in data:
        data[key].append(all_key_list.index(key))

    return data
862
+ ##############################################################################
863
+
864
+ ##############################################################################
865
def calc_rel_change(traj_idx, frame_idx):
    """
    Return the relative change of every consistency signal of one frame.

    For each signal stored in ``dtrajs_MS[traj_idx, frame_idx]`` the value
    of the matching metric in that frame is compared with the reference
    mean held at position 2 of the signal's consistency entry. Signals with
    a ``'-'`` anywhere except their first/last character are looked up in
    the XL-MS tables, all others in the LiP-MS tables. The result is
    ``|v / mean - 1|``, or ``|v / 1e-5|`` when the reference mean is zero.
    """
    rel_change_list = []
    for signal in dtrajs_MS[traj_idx, frame_idx]:
        is_pair = '-' in signal[1:-1]
        table = XLMS_consist_data if is_pair else LIPMS_consist_data
        metric = M_XLMS if is_pair else M_LiPMS

        ref_mean = table[signal][2]
        value = metric[traj_idx, frame_idx, table[signal][4]]
        # a zero reference mean would divide by zero; scale by 1e-5 instead
        change = value/1e-5 if ref_mean == 0 else value/ref_mean-1
        rel_change_list.append(np.abs(change))
    return rel_change_list
882
+ ##############################################################################
883
+
884
+
885
+ ################### MAIN ###############
886
+ ##############################################################################
887
+ self.logger.info(f'Loading consistency test data from {consist_result_file}')
888
+ self.logger.info(f'Loading consistency test data from {consist_data_file}')
889
+
890
+ if_backmap = 0
891
+ pulchra_only = True
892
+ significant_level = 0.05
893
+ ##############################################################################
894
+
895
+ ##############################################################################
896
+ # Load MSM data
897
+ MSM_data = pd.read_csv(self.msm_data_file)
898
+ self.logger.info(f'MSM_data\n{MSM_data}')
899
+ meta_states = MSM_data['metastablestate'].unique()
900
+ meta_states = np.array(meta_states, dtype=int)
901
+ self.logger.debug(f'meta_states: {meta_states}')
902
+ num_meta_states = len(meta_states)
903
+ self.logger.debug(f'num_meta_states: {num_meta_states}')
904
+
905
+ rm_traj_set = set(int(t) for t in self.rm_traj_list)
906
+ meta_dtrajs_last = []
907
+ micro_dtrajs_last = []
908
+ MSM_traj_idx_to_trajnum = {} # mapping traj_idx to traj number (after rm_traj_list filtering)
909
+ for traj, traj_df in MSM_data.groupby('traj'):
910
+ traj = int(traj)
911
+ if traj in rm_traj_set:
912
+ continue
913
+ traj_len = len(traj_df)
914
+ #print(f'traj: {traj}, traj_len: {traj_len}\n{traj_df.head()}')
915
+
916
+ last = traj_df.iloc[-last_num_frames:,:]
917
+ last = last.reset_index(drop=True)
918
+ meta_last = last['metastablestate'].values
919
+ micro_last = last['microstate'].values
920
+ #print(f'last: {last}')
921
+ meta_dtrajs_last.append(meta_last)
922
+ micro_dtrajs_last.append(micro_last)
923
+
924
+ MSM_traj_idx_to_trajnum[len(MSM_traj_idx_to_trajnum)] = traj
925
+
926
+ meta_dtrajs_last = np.array(meta_dtrajs_last)
927
+ micro_dtrajs_last = np.array(micro_dtrajs_last)
928
+ self.logger.info(f'meta_dtrajs_last\n{meta_dtrajs_last} {meta_dtrajs_last.shape}')
929
+ self.logger.debug(np.unique(meta_dtrajs_last))
930
+ self.logger.info(f'micro_dtrajs_last\n{micro_dtrajs_last} {micro_dtrajs_last.shape}')
931
+ self.logger.debug(np.unique(micro_dtrajs_last))
932
+ self.logger.debug(f'MSM_traj_idx_to_trajnum: {MSM_traj_idx_to_trajnum}')
933
+
934
+ ## load the meta_dist data
935
+ meta_dist = np.load(self.meta_dist_file, allow_pickle=True)
936
+ self.logger.info(f'meta_dist:\n{meta_dist} {meta_dist.shape}')
937
+ ##############################################################################
938
+
939
+
940
+ ##############################################################################
941
+ # Load Consistency Metrics
942
+ consist_data = np.load(consist_data_file, allow_pickle=True)
943
+ M_LiPMS = consist_data['M_LiPMS'][:,-last_num_frames:,:]
944
+ M_XLMS = consist_data['M_XLMS'][:,-last_num_frames:,:]
945
+ self.logger.info(f'Loaded consistency metrics from {consist_data_file}')
946
+ self.logger.debug(f'M_LiPMS: {M_LiPMS.shape}')
947
+ self.logger.debug(f'M_XLMS: {M_XLMS.shape}')
948
+ ##############################################################################
949
+
950
+ ##############################################################################
951
+ # Load consistency test results
952
+ LIPMS_consist_data = parse_consistant_results(consist_result_file, sheet_name='LiPMS', test_type='two-tailed', significant_level=significant_level)
953
+ XLMS_consist_data = parse_consistant_results(consist_result_file, sheet_name='XLMS', test_type='two-tailed', significant_level=significant_level)
954
+ self.logger.info(f'Loaded consistency test results')
955
+ ##############################################################################
956
+
957
+
958
+ ##############################################################################
959
+ # Load cluster data
960
+ # Beware the order of traj in the cluster_data may not be the same as the order in the MSM_data
961
+ # we will correct that here so it is consistent moving forward with the analysis
962
+ cluster_data = np.load(self.cluster_data_file, allow_pickle=True)
963
+ idx2trajfile = cluster_data['idx2trajfile'].tolist()
964
+ idx2traj = np.asarray([int(f.strip().split('/')[-1].split('_')[0]) for f in idx2trajfile])
965
+ # print(f'idx2traj: {idx2traj} {len(idx2traj)}')
966
+
967
+ # get the ordering array and reorder all the cluster_data arrays
968
+ # also remove those trajectories that are in the rm_traj_list
969
+ order = np.argsort(idx2traj)
970
+ # print(f'order: {order} {len(order)}')
971
+ idx2traj = idx2traj[order]
972
+ # print(f'idx2traj: {idx2traj} {len(idx2traj)}')
973
+
974
+ idx_2_keep = [idx for idx, traj in enumerate(idx2traj) if traj not in self.rm_traj_list]
975
+ # print(f'idx_2_keep: {idx_2_keep} {len(idx_2_keep)}')
976
+
977
+ # reorder idx2trajfile
978
+ idx2traj = idx2traj[idx_2_keep]
979
+ # print(f'idx2traj after removal of mirror images: {idx2traj} {len(idx2traj)}')
980
+
981
+ # QC to ensure idx2traj matches MSM_traj_idx_to_trajnum
982
+ assert np.array_equal(idx2traj, np.asarray(list(MSM_traj_idx_to_trajnum.values()))), f"Error: idx2traj does not match MSM_traj_idx_to_trajnum: {idx2traj} vs. {np.asarray(list(MSM_traj_idx_to_trajnum.values()))}"
983
+
984
+ # load the dtraj data
985
+ dtrajs = cluster_data['dtrajs']
986
+ dtrajs = dtrajs[order]
987
+ dtrajs = dtrajs[idx_2_keep]
988
+ self.logger.info(f'dtrajs after removal of mirror images: {dtrajs.shape}')
989
+
990
+ # load the rep_chg_ent_dtrajs
991
+ rep_chg_ent_dtrajs = cluster_data['rep_chg_ent_dtrajs']
992
+ # Map resid to residue idx using resid2residueidx_map
993
+ self.logger.info(f'Mapping of rep_chg_ent_dtrajs resid to residu idx using resid2residueidx_map: {self.resid2residueidx_map}')
994
+ for traj_idx, traj_data in enumerate(rep_chg_ent_dtrajs):
995
+ for frame_idx, frame_data in enumerate(traj_data):
996
+ for fingerprint_id, fingerprint in frame_data.items():
997
+ # print(f'\n{"#"*50}\nTraj idx: {traj_idx} | Frame idx {frame_idx} | Fingerprint ID: {fingerprint_id}')
998
+
999
+ for fingerprint_key, new_key in {'crossing_resid':'crossing_residx', 'ref_crossing_resid':'ref_crossing_residx'}.items():
1000
+ residx_arr = fingerprint[fingerprint_key]
1001
+ residx_arr = [[self.resid2residueidx_map[x] for x in sublist] for sublist in residx_arr]
1002
+ fingerprint[new_key] = residx_arr
1003
+ # print(f' mapped {new_key}: {residx_arr}')
1004
+
1005
+ for fingerprint_key, new_key in {'native_contact':'native_contact_residx', 'ref_native_contact':'ref_native_contact_residx'}.items():
1006
+ residx_arr = fingerprint[fingerprint_key]
1007
+ residx_arr = [self.resid2residueidx_map[x] for x in residx_arr]
1008
+ fingerprint[new_key] = residx_arr
1009
+ # print(f' mapped {new_key}: {residx_arr}')
1010
+
1011
+ # for k,v in fingerprint.items():
1012
+ # print(f' {k}: {v}')
1013
+
1014
+ rep_chg_ent_dtrajs = rep_chg_ent_dtrajs[order]
1015
+ rep_chg_ent_dtrajs = rep_chg_ent_dtrajs[idx_2_keep]
1016
+ self.logger.info(f'rep_chg_ent_dtrajs after removal of mirror images: {rep_chg_ent_dtrajs.shape}')
1017
+
1018
+ sorted_chg_ent_structure_keyword_list = cluster_data['sorted_chg_ent_structure_keyword_list'].tolist()
1019
+ self.logger.debug(f'sorted_chg_ent_structure_keyword_list: {len(sorted_chg_ent_structure_keyword_list)} {sorted_chg_ent_structure_keyword_list[:10]}')
1020
+
1021
+ dtrajs_cluster_idx = np.array([[sorted_chg_ent_structure_keyword_list.index(str(dd)) for dd in d] for d in dtrajs])
1022
+ self.logger.debug(f'dtrajs_cluster_idx: {dtrajs_cluster_idx.shape}')
1023
+ self.logger.info(f'Loaded cluster_data_file: {self.cluster_data_file}')
1024
+ ##############################################################################
1025
+
1026
+ ##############################################################################
1027
+ # Load SASA data
1028
+ sasa_traj_list = np.load(self.sasa_data_file, allow_pickle=True)[:,-last_num_frames:,:]
1029
+ self.logger.info(f'Loaded SASA data: {self.sasa_data_file}')
1030
+
1031
+ # remove trajectories that are in the rm_traj_list - 1
1032
+ sasa_traj_list = [v for i, v in enumerate(sasa_traj_list) if i not in np.asarray(self.rm_traj_list) - 1]
1033
+ sasa_traj_list = np.array(sasa_traj_list)
1034
+ self.logger.info(f'sasa_traj_list.shape after removal of mirror images: {sasa_traj_list.shape}')
1035
+ ##############################################################################
1036
+
1037
+
1038
+ ##############################################################################
1039
+ # Change the ub and lb in LIPMS_consist_data and XLMS_consist_data
1040
+ # (1) Get the native state index (defined by the MSM indexing which is sorted in order of the trajectory number)
1041
+ native_sel = np.where(np.isin(meta_dtrajs_last, self.native_state_idx))
1042
+ native_sel = np.array([native_sel[0], native_sel[1]], dtype=int).T
1043
+
1044
+ # (2) remove any frames with NAN residual SASA
1045
+ nan_frame_sel = np.where(np.isnan(sasa_traj_list))
1046
+ nan_frame_sel = np.array([nan_frame_sel[0], nan_frame_sel[1]], dtype=int).T
1047
+ native_sel = np.array([sel for sel in native_sel if not sel.tolist() in nan_frame_sel.tolist()])
1048
+ self.logger.debug(f'native_sel: {native_sel} {native_sel.shape}')
1049
+
1050
+ # (3) Get the native M_LiPMS and M_XLMS
1051
+ native_M_LiPMS = M_LiPMS[native_sel[:,0], native_sel[:,1], :]
1052
+ native_M_XLMS = M_XLMS[native_sel[:,0], native_sel[:,1], :]
1053
+ native_M_data_outfile = os.path.join(self.outdir, 'native_M_data.npz')
1054
+ np.savez(native_M_data_outfile,
1055
+ M_LiPMS = native_M_LiPMS,
1056
+ M_XLMS = native_M_XLMS,)
1057
+ for key in LIPMS_consist_data:
1058
+ LIPMS_consist_data[key][3][0] = np.percentile(native_M_LiPMS[:, LIPMS_consist_data[key][4]], 2.5)
1059
+ LIPMS_consist_data[key][3][1] = np.percentile(native_M_LiPMS[:, LIPMS_consist_data[key][4]], 97.5)
1060
+ for key in XLMS_consist_data:
1061
+ XLMS_consist_data[key][3][0] = np.percentile(native_M_XLMS[:, XLMS_consist_data[key][4]], 2.5)
1062
+ XLMS_consist_data[key][3][1] = np.percentile(native_M_XLMS[:, XLMS_consist_data[key][4]], 97.5)
1063
+ self.logger.debug(f'SAVED: {native_M_data_outfile}')
1064
+ ##############################################################################
1065
+
1066
+ ##############################################################################
1067
+ # Go through LiPMS data
1068
+ LIPMS_struct_data = {}
1069
+ dtrajs_MS = np.empty(meta_dtrajs_last.shape, dtype=object)
1070
+
1071
+ # Initialize dtrajs_MS with empty lists
1072
+ for i in range(len(dtrajs_MS)):
1073
+ for j in range(len(dtrajs_MS[i])):
1074
+ dtrajs_MS[i,j] = []
1075
+
1076
+ # Go through each PK site in the LiPMS consistency data
1077
+ self.logger.info(f'\nProcessing LiPMS consistency data')
1078
+ for pk, data in LIPMS_consist_data.items():
1079
+ # print(pk, data)
1080
+
1081
+ LIPMS_struct_data[pk] = {}
1082
+ sign = data[0]
1083
+ state_list = data[1]
1084
+ sasa_ub = data[3][1]
1085
+ sasa_lb = data[3][0]
1086
+ idx_pk = data[4]
1087
+
1088
+ for idx_0, state_idx in enumerate(state_list):
1089
+ LIPMS_struct_data[pk][state_idx] = {}
1090
+
1091
+ idx_list = np.array(np.where(meta_dtrajs_last == state_idx)).T
1092
+ idx_list = np.array([sel for sel in idx_list if not sel.tolist() in nan_frame_sel.tolist()]) # Skip frames with NAN residual SASA (bad backmapped structure)
1093
+
1094
+ sasa_list = M_LiPMS[idx_list[:,0],idx_list[:,1], idx_pk]
1095
+
1096
+ for cluster_idx, keyword in enumerate(sorted_chg_ent_structure_keyword_list):
1097
+ idx_list_1 = np.where(dtrajs_cluster_idx[idx_list[:,0],idx_list[:,1]] == cluster_idx)[0]
1098
+
1099
+ if len(idx_list_1) == 0:
1100
+ continue
1101
+ if sign == '>':
1102
+ idx_list_2 = np.where(sasa_list[idx_list_1] > sasa_ub)[0]
1103
+ idx_rep = np.argmax(sasa_list[idx_list_1])
1104
+ elif sign == '<':
1105
+ idx_list_2 = np.where(sasa_list[idx_list_1] < sasa_lb)[0]
1106
+ idx_rep = np.argmin(sasa_list[idx_list_1])
1107
+ else:
1108
+ idx_list_2 = np.where(np.any([sasa_list[idx_list_1] > sasa_ub, sasa_list[idx_list_1] < sasa_lb], axis=0))[0]
1109
+ idx_rep = np.argmax(np.max([sasa_list[idx_list_1]-sasa_ub, sasa_lb-sasa_list[idx_list_1]], axis=0))
1110
+ if len(idx_list_2) == 0:
1111
+ continue
1112
+
1113
+ consist_idx_list = idx_list[idx_list_1[idx_list_2], :]
1114
+ rep_idx = idx_list[idx_list_1[idx_rep], :]
1115
+ LIPMS_struct_data[pk][state_idx][cluster_idx] = [rep_idx, consist_idx_list]
1116
+ # print(f'PK: {pk}, state_idx: {state_idx}, cluster_idx: {cluster_idx}, rep_idx: {rep_idx}, consist_idx_list: {len(consist_idx_list)}')
1117
+
1118
+ for idx in consist_idx_list:
1119
+ dtrajs_MS[idx[0],idx[1]].append(pk)
1120
+ ##############################################################################
1121
+
1122
+ ##############################################################################
1123
+ # Go through XLMS data
1124
+ XLMS_struct_data = {}
1125
+ self.logger.info(f'\nProcessing XLMS consistency data')
1126
+ for pair, data in XLMS_consist_data.items():
1127
+ # print(pair, data)
1128
+
1129
+ # Initialize the structure data for each pair
1130
+ XLMS_struct_data[pair] = {}
1131
+ sign = data[0]
1132
+ state_list = data[1]
1133
+ dist_ub = data[3][1]
1134
+ dist_lb = data[3][0]
1135
+ idx_pair = data[4]
1136
+
1137
+ for state_idx in state_list:
1138
+ XLMS_struct_data[pair][state_idx] = {}
1139
+
1140
+ idx_list = np.array(np.where(meta_dtrajs_last == state_idx)).T
1141
+ idx_list = np.array([sel for sel in idx_list if not sel.tolist() in nan_frame_sel.tolist()])
1142
+
1143
+ dist_list = M_XLMS[idx_list[:,0],idx_list[:,1],idx_pair]
1144
+ for cluster_idx, keyword in enumerate(sorted_chg_ent_structure_keyword_list):
1145
+ idx_list_1 = np.where(dtrajs_cluster_idx[idx_list[:,0],idx_list[:,1]] == cluster_idx)[0]
1146
+ if len(idx_list_1) == 0:
1147
+ continue
1148
+ if sign == '>':
1149
+ idx_list_2 = np.where(dist_list[idx_list_1] > dist_ub)[0]
1150
+ idx_rep = np.argmax(dist_list[idx_list_1])
1151
+ elif sign == '<':
1152
+ idx_list_2 = np.where(dist_list[idx_list_1] < dist_lb)[0]
1153
+ idx_rep = np.argmin(dist_list[idx_list_1])
1154
+ else:
1155
+ idx_list_2 = np.where(np.any([dist_list[idx_list_1] > dist_ub, dist_list[idx_list_1] < dist_lb], axis=0))[0]
1156
+ idx_rep = np.argmax(np.max([dist_list[idx_list_1]-dist_ub, dist_lb-dist_list[idx_list_1]], axis=0))
1157
+ if len(idx_list_2) == 0:
1158
+ continue
1159
+
1160
+ consist_idx_list = idx_list[idx_list_1[idx_list_2], :]
1161
+ rep_idx = idx_list[idx_list_1[idx_rep], :]
1162
+ XLMS_struct_data[pair][state_idx][cluster_idx] = [rep_idx, consist_idx_list]
1163
+ #print(f'Pair: {pair}, state_idx: {state_idx}, cluster_idx: {cluster_idx}, rep_idx: {rep_idx}, consist_idx_list: {len(consist_idx_list)}')
1164
+
1165
+ for idx in consist_idx_list:
1166
+ dtrajs_MS[idx[0],idx[1]].append(pair)
1167
+ ##############################################################################
1168
+ self.logger.debug(f'dtrajs_MS: {dtrajs_MS.shape}')
1169
+
1170
+
1171
+ ##############################################################################
1172
+ # Group consistency data
1173
+ self.logger.info(f'\nGrouping consistency data based on consistency signals')
1174
+ consist_signal_dict = {}
1175
+ for i, d in enumerate(dtrajs_MS): # list of lists of consistency signals in each frame
1176
+
1177
+ for j, dd in enumerate(d): # list of consistency signals in frame idx j
1178
+
1179
+ # print(f'MSM traj idx: {i}, frame idx: {j}, consistency signal: {dd}')
1180
+ cluster_idxs = dtrajs_cluster_idx[i,j]
1181
+ # print(f'Cluster idx: {cluster_idxs}')
1182
+
1183
+ if str(dd) not in consist_signal_dict.keys():
1184
+ consist_signal_dict[str(dd)] = {cluster_idxs: [[i,j]]}
1185
+
1186
+ elif cluster_idxs not in consist_signal_dict[str(dd)].keys():
1187
+ consist_signal_dict[str(dd)][cluster_idxs] = [[i,j]]
1188
+
1189
+ else:
1190
+ consist_signal_dict[str(dd)][cluster_idxs].append([i,j])
1191
+
1192
+ Num_struct_list = [np.sum(np.array([len(vv) for vv in v.values()])) for k,v in consist_signal_dict.items()]
1193
+ sort_idx = np.argsort(-np.array(Num_struct_list, dtype=int))
1194
+ sorted_consist_signal_list = [list(consist_signal_dict.keys())[idx] for idx in sort_idx]
1195
+ # print(f'sorted_consist_signal_list: {sorted_consist_signal_list}')
1196
+ sorted_consist_signal_dict = {}
1197
+ for k in sorted_consist_signal_list:
1198
+ kk_list = sorted(list(consist_signal_dict[k].keys()))
1199
+ sorted_consist_signal_dict[k] = {kk: consist_signal_dict[k][kk] for kk in kk_list}
1200
+
1201
+ # for a, b in sorted_consist_signal_dict[k].items():
1202
+ # print(f'\n{k}, {a}, {b}')
1203
+ ##############################################################################
1204
+
1205
+
1206
+
1207
+ ##############################################################################
1208
+ # Group by metastable state, then by consistency signal
1209
+ self.logger.info(f'\nGrouping consistency data based on metastable states and consistency signals')
1210
+ group_dict = {}
1211
+ for state_id in self.state_idx_list:
1212
+ # print(f'Processing state {state_id}...')
1213
+ group_dict[state_id] = {}
1214
+
1215
+ idx_list = np.array(np.where(meta_dtrajs_last == state_id)).T
1216
+
1217
+ for idx_list_0 in idx_list:
1218
+
1219
+ [i, j] = list(idx_list_0)
1220
+ dd = dtrajs_MS[i, j]
1221
+ # print(f'Processing idx: {i}, {j}, dd: {dd}')
1222
+ cluster_idxs = dtrajs_cluster_idx[i,j]
1223
+ # print(f'Cluster idx: {cluster_idxs}')
1224
+
1225
+ if str(dd) not in group_dict[state_id].keys():
1226
+ group_dict[state_id][str(dd)] = {cluster_idxs: [[i,j]]}
1227
+
1228
+ elif cluster_idxs not in group_dict[state_id][str(dd)].keys():
1229
+ group_dict[state_id][str(dd)][cluster_idxs] = [[i,j]]
1230
+
1231
+ else:
1232
+ group_dict[state_id][str(dd)][cluster_idxs].append([i,j])
1233
+
1234
+ # sort based on population
1235
+ Num_struct_list = [np.sum(np.array([len(vv) for vv in v.values()])) for k,v in group_dict[state_id].items()]
1236
+ sort_idx = np.argsort(-np.array(Num_struct_list, dtype=int))
1237
+ sorted_list = [list(group_dict[state_id].keys())[idx] for idx in sort_idx]
1238
+ new_dict = {}
1239
+ for k in sorted_list:
1240
+ kk_list = sorted(list(group_dict[state_id][k].keys()))
1241
+ new_dict[k] = {kk: group_dict[state_id][k][kk] for kk in kk_list}
1242
+ group_dict[state_id] = new_dict
1243
+
1244
+ # print(f'\nGrouped consistency data for state {state_id}:')
1245
+ # for k,v in group_dict[state_id].items():
1246
+ # print(f'{k}: {v}')
1247
+ ##############################################################################
1248
+
1249
+ ##############################################################################
1250
+ # Load Q list
1251
+ self.logger.info(f'Loading G and Q data from {self.OPpath}')
1252
+ Q_list, G_list = self.load_OP(start=-last_num_frames)
1253
+ self.logger.info(f'Loaded G and Q data from {self.OPpath}')
1254
+ ##############################################################################
1255
+
1256
+ ##############################################################################
1257
+ # Get representative structures
1258
+ self.logger.info(f'\nGetting representative structures for each group...')
1259
+ rep_group_dict = {}
1260
+ for state_id in self.state_idx_list:
1261
+ # print(f'Processing state {state_id}...')
1262
+
1263
+ rep_group_dict[state_id] = {}
1264
+ for k in group_dict[state_id].keys():
1265
+
1266
+ if k == '[]':
1267
+ continue
1268
+
1269
+ # print(f' Processing consistent signal k={k}...')
1270
+ rep_group_dict[state_id][k] = {}
1271
+ for kk in group_dict[state_id][k].keys():
1272
+ # print(f' Processing entanglement ID kk={kk}...')
1273
+ idx_list = np.array(group_dict[state_id][k][kk])
1274
+ # print(f' idx_list: {idx_list}')
1275
+
1276
+ # Max micro-states probability
1277
+ micro_prob = meta_dist[state_id][micro_dtrajs_last[idx_list[:,0], idx_list[:,1]]]
1278
+ max_idx = np.where(micro_prob == np.max(micro_prob))[0]
1279
+ max_idx_list = idx_list[max_idx,:]
1280
+ # print(f' max_idx_list: {max_idx_list}')
1281
+
1282
+ # Max Q
1283
+ Q_list_0 = Q_list[max_idx_list[:,0], max_idx_list[:,1]]
1284
+ max_idx = np.where(Q_list_0 == np.max(Q_list_0))[0]
1285
+ max_idx_list = max_idx_list[max_idx,:]
1286
+ # print(f' max_idx_list after Q: {max_idx_list}')
1287
+
1288
+ # Max G
1289
+ G_list_0 = G_list[max_idx_list[:,0], max_idx_list[:,1]]
1290
+ max_idx = np.where(G_list_0 == np.max(G_list_0))[0]
1291
+ max_idx_list = max_idx_list[max_idx,:]
1292
+ # print(f' max_idx_list after G: {max_idx_list}')
1293
+
1294
+ [rep_traj_idx, rep_frame_idx] = max_idx_list[0,:]
1295
+ # print(f' Representative structure: MSM traj idx {rep_traj_idx} -> {idx2traj[rep_traj_idx]}, frame {rep_frame_idx}')
1296
+ rep_group_dict[state_id][k][kk] = [rep_traj_idx, rep_frame_idx]
1297
+ ##############################################################################
1298
+
1299
+
1300
+ ##############################################################################
1301
+ # Save data
1302
+ consist_signal_struct_data_outfile = os.path.join(self.outdir, 'consist_signal_struct_data.npz')
1303
+ np.savez(consist_signal_struct_data_outfile,
1304
+ last_num_frames = last_num_frames,
1305
+ total_traj_num_frames = total_traj_num_frames,
1306
+ LIPMS_consist_data=LIPMS_consist_data,
1307
+ XLMS_consist_data=XLMS_consist_data,
1308
+ LIPMS_struct_data=LIPMS_struct_data,
1309
+ XLMS_struct_data=XLMS_struct_data,
1310
+ dtrajs_MS=dtrajs_MS,
1311
+ sorted_consist_signal_dict=sorted_consist_signal_dict,
1312
+ group_dict=group_dict,
1313
+ rep_group_dict=rep_group_dict,)
1314
+ self.logger.debug(f'SAVED: {consist_signal_struct_data_outfile}')
1315
+
1316
+ # Save info to excel
1317
+ self.logger.info(f'Saving info to excel...')
1318
+ df_list = []
1319
+ sheet_name_list = []
1320
+ df_data = []
1321
+ for k in sorted_consist_signal_dict.keys():
1322
+ k_list = eval(k)
1323
+ if k == '[]':
1324
+ continue
1325
+ k_str = ', '.join(k_list)
1326
+
1327
+ for kk in sorted_consist_signal_dict[k].keys():
1328
+ kk_list = eval(sorted_chg_ent_structure_keyword_list[kk])
1329
+ if len(kk_list) == 0:
1330
+ continue
1331
+ kk_str = ', '.join([str(i+1) for i in kk_list])
1332
+ num = len(sorted_consist_signal_dict[k][kk])
1333
+ df_data.append([k_str, kk_str, len(k_list), num])
1334
+ df = pd.DataFrame(df_data, columns=['Consistent signals', 'IDs of Changes in Entanglements', 'Number of consistent signals', 'Number of Structures'])
1335
+ df_sorted = df.sort_values(by=['Number of consistent signals', 'Consistent signals', 'Number of Structures'], ascending=[False, True, False])
1336
+ df_list.append(df_sorted)
1337
+ sheet_name_list.append('Total')
1338
+
1339
+ for state_id in self.state_idx_list:
1340
+ sheet_name_list.append('State %d'%(state_id+1))
1341
+ df_data = []
1342
+
1343
+ for k in rep_group_dict[state_id].keys():
1344
+ k_list = eval(k)
1345
+ k_str = ', '.join(k_list)
1346
+ #print(f'\nProcessing state {state_id}, consistent signal {k_str}')
1347
+
1348
+ for kk in rep_group_dict[state_id][k].keys():
1349
+
1350
+ kk_list = eval(sorted_chg_ent_structure_keyword_list[kk])
1351
+ if len(kk_list) == 0:
1352
+ continue
1353
+ kk_str = ', '.join([str(i+1) for i in kk_list])
1354
+ [traj_idx, frame_idx] = rep_group_dict[state_id][k][kk]
1355
+ #print(f'traj_idx: {traj_idx} | frame_idx: {frame_idx}')
1356
+
1357
+ traj_frame_idx = total_traj_num_frames - last_num_frames + frame_idx
1358
+ #print(f'Processing state {state_id}, consistent signal {k_str}, entanglement ID {kk_str}, traj {traj_idx+1}, frame {traj_frame_idx}')
1359
+
1360
+ num = len(group_dict[state_id][k][kk])
1361
+ micro_prob = meta_dist[state_id][micro_dtrajs_last[traj_idx, frame_idx]]
1362
+ # rel_change_list = calc_rel_change(traj_idx, frame_idx)
1363
+ traj = idx2traj[traj_idx]
1364
+ df_data.append([k_str, kk_str, len(k_list), num, str([traj, traj_frame_idx+1]), micro_prob, Q_list[traj_idx, frame_idx], G_list[traj_idx, frame_idx] ])
1365
+
1366
+ df = pd.DataFrame(df_data, columns=['Consistent signals', 'IDs of Changes in Entanglements', 'Number of consistent signals', 'Number of Structures', 'Representative Structure (Traj #, Frame #)', 'Prob', 'Q', 'G'])
1367
+ df_sorted = df.sort_values(by=['Number of consistent signals', 'Prob', 'Q', 'G'], ascending=[False, False, False, False])
1368
+ df_list.append(df_sorted)
1369
+
1370
+ # Creating Excel Writer Object from Pandas
1371
+ Consistent_structures_v8_outfile = os.path.join(self.outdir, 'Consistent_structures_v8.xlsx')
1372
+ with pd.ExcelWriter(Consistent_structures_v8_outfile, engine="openpyxl") as writer:
1373
+ workbook=writer.book
1374
+ for idx_0, df in enumerate(df_list):
1375
+ worksheet=workbook.create_sheet(sheet_name_list[idx_0])
1376
+ writer.sheets[sheet_name_list[idx_0]] = worksheet
1377
+ df.to_excel(writer, sheet_name=sheet_name_list[idx_0], index=False)
1378
+ self.logger.debug(f'SAVED: {Consistent_structures_v8_outfile}')
1379
+ ##############################################################################
1380
+
1381
+
1382
+ ##############################################################################
1383
+ # Create visualization
1384
+ self.logger.info(f'Create visualizations...')
1385
+ if self.dist_data_file is not None and os.path.exists(self.dist_data_file):
1386
+ self.logger.info(f'Loading dist_traj_list: {self.dist_data_file}')
1387
+ dist_traj_list = np.load(self.dist_data_file, allow_pickle=True)[:,-last_num_frames:]
1388
+ self.logger.info(f'Loaded distance data: {dist_traj_list.shape}')
1389
+ dist_traj_list = [v for i, v in enumerate(dist_traj_list) if i not in np.asarray(self.rm_traj_list) - 1]
1390
+ dist_traj_list = np.array(dist_traj_list)
1391
+ self.logger.info(f'dist_traj_list after removal of mirror images: {dist_traj_list.shape}')
1392
+ elif self.xp_dir is not None and os.path.isdir(self.xp_dir):
1393
+ self.logger.info(f'Building sparse dist_traj_list from XP files in: {self.xp_dir}')
1394
+ dist_traj_list = np.empty((len(idx2traj), last_num_frames), dtype=object)
1395
+ dist_traj_list[:] = None
1396
+
1397
+ needed_frames_by_traj = {}
1398
+ for state_id in self.state_idx_list:
1399
+ for k in rep_group_dict[state_id].keys():
1400
+ for kk in rep_group_dict[state_id][k].keys():
1401
+ traj_idx, frame_idx = rep_group_dict[state_id][k][kk]
1402
+ needed_frames_by_traj.setdefault(int(traj_idx), set()).add(int(frame_idx))
1403
+
1404
+ for traj_idx, needed_frame_idx in needed_frames_by_traj.items():
1405
+ traj_num = int(idx2traj[traj_idx])
1406
+ fpath = os.path.join(self.xp_dir, f'{self.ID}_Traj{traj_num}.XP')
1407
+ if not os.path.exists(fpath):
1408
+ self.logger.warning(f'Missing XP file for sparse visualization load: {fpath}')
1409
+ continue
1410
+
1411
+ df = pd.read_csv(
1412
+ fpath,
1413
+ sep='\t',
1414
+ usecols=['Frame', 'Atom1', 'Atom2', 'Euclidean Distance', 'SASD'],
1415
+ dtype={
1416
+ 'Frame': np.int32,
1417
+ 'SASD': np.float32,
1418
+ 'Euclidean Distance': np.float32,
1419
+ 'Atom1': 'string',
1420
+ 'Atom2': 'string',
1421
+ },
1422
+ )
1423
+ if df.empty:
1424
+ continue
1425
+
1426
+ frame_values = np.sort(df['Frame'].unique())
1427
+ frame_to_idx = {int(f): idx for idx, f in enumerate(frame_values)}
1428
+ df['frame_idx'] = df['Frame'].map(frame_to_idx)
1429
+ df = df[df['frame_idx'].isin(needed_frame_idx)]
1430
+ if df.empty:
1431
+ continue
1432
+
1433
+ atom1_parts = df['Atom1'].str.split('-', expand=True)
1434
+ atom2_parts = df['Atom2'].str.split('-', expand=True)
1435
+ if atom1_parts.shape[1] < 3 or atom2_parts.shape[1] < 3:
1436
+ self.logger.warning(f'Malformed Atom fields in {fpath}; skipping sparse load for traj {traj_num}')
1437
+ continue
1438
+
1439
+ df['key'] = (
1440
+ atom1_parts[1].astype(str) + '|' + atom1_parts[2].astype(str) + '-' +
1441
+ atom2_parts[1].astype(str) + '|' + atom2_parts[2].astype(str)
1442
+ )
1443
+
1444
+ for frame_idx, frame_df in df.groupby('frame_idx', sort=False):
1445
+ frame_dict = {}
1446
+ for _, row in frame_df.iterrows():
1447
+ frame_dict[row['key']] = {
1448
+ 'Euclidean': float(row['Euclidean Distance']),
1449
+ 'Jwalk': float(row['SASD']),
1450
+ }
1451
+ dist_traj_list[traj_idx, int(frame_idx)] = frame_dict
1452
+ else:
1453
+ raise ValueError(
1454
+ 'Representative structure visualization requires either dist_data_file (Jwalk.npy) or xp_dir.'
1455
+ )
1456
+
1457
+
1458
+ # Check if the viz_rep_struct path exists. if so remove it and make a fresh one
1459
+ if os.path.exists('viz_rep_struct'):
1460
+ self.logger.info(f'viz_rep_struct exists and will be removed')
1461
+ os.system('rm -rf viz_rep_struct/')
1462
+ os.system('mkdir viz_rep_struct/')
1463
+ os.chdir('viz_rep_struct/')
1464
+
1465
+ if os.path.isdir(self.AAdcd_dir):
1466
+ AAtraj_files = glob.glob(os.path.join(self.AAdcd_dir, '*.dcd'))
1467
+ else:
1468
+ AAtraj_files = glob.glob(self.AAdcd_dir)
1469
+ if len(AAtraj_files) == 0:
1470
+ raise ValueError(
1471
+ f'No AA trajectory files found. AAdcd_dir={self.AAdcd_dir} '
1472
+ f'(resolved count={len(AAtraj_files)})'
1473
+ )
1474
+ self.logger.info(f'AAtraj_files:\n{AAtraj_files[:10]}')
1475
+
1476
+ wd = os.getcwd()
1477
+ for state_id in self.state_idx_list:
1478
+ state_dir = os.path.join(wd, 'State_%d'%(state_id+1))
1479
+ os.makedirs(state_dir, exist_ok=True)
1480
+ self.logger.info(f'Made {state_dir}')
1481
+ # os.chdir(state_dir)
1482
+ self.logger.info(f'Length of rep_group_dict[state_id]: {len(rep_group_dict[state_id])}')
1483
+ args_list = [
1484
+ (state_dir, state_id, k, rep_group_dict, sorted_chg_ent_structure_keyword_list, last_num_frames, total_traj_num_frames,
1485
+ idx2traj, AAtraj_files, self.native_AA_pdb, rep_chg_ent_dtrajs, Q_list, G_list,
1486
+ LIPMS_consist_data, M_LiPMS, XLMS_consist_data, M_XLMS, dist_traj_list, if_backmap, pulchra_only, self.logger.name)
1487
+ for k in rep_group_dict[state_id].keys()
1488
+ ]
1489
+ self.logger.info(f'Processing {len(args_list)} consistent signal groups for state {state_id}...')
1490
+ if len(args_list) == 0:
1491
+ self.logger.info(f'No consistent signal groups for state {state_id}, skipping...')
1492
+ continue
1493
+
1494
+ # process_k(args_list[0])
1495
+ # print('Testing done for one process_k, exiting...')
1496
+ # quit()
1497
+ # nproc = 10
1498
+ with multiprocessing.Pool(processes=self.nproc) as pool:
1499
+ pool.map(process_k, args_list)
1500
+ self.logger.debug('Completion of selecting rep structure')
1501
+
1502
+
1503
+ ##################################################################################################
1504
+ import mdtraj as mdt # at the top of your file
1505
def process_k(args):
    """Write visualisation files for every representative structure of one signal group.

    Designed to be mapped over a ``multiprocessing.Pool``; all inputs arrive in
    one tuple so the call can be pickled.

    Parameters
    ----------
    args : tuple
        ``(state_dir, state_id, k, rep_group_dict,
        sorted_chg_ent_structure_keyword_list, last_num_frames,
        total_traj_num_frames, idx2traj, AAtraj_files, native_AA_pdb,
        rep_chg_ent_dtrajs, Q_list, G_list, LIPMS_consist_data, M_LiPMS,
        XLMS_consist_data, M_XLMS, dist_traj_list, if_backmap, pulchra_only,
        logger_name)``.

        ``k`` is the string form of a list of consistent-signal names and
        selects the entry ``rep_group_dict[state_id][k]``, which maps an
        entanglement-change key ``kk`` to ``[traj_idx, frame_idx]``.

    Side effects
    ------------
    Creates ``<state_dir>/<signals>/<entanglement ids>/`` and, for each
    representative, calls ``gen_state_visualizion`` and writes ``info.txt``
    into that directory.

    Raises
    ------
    ValueError
        If the trajectory number does not match exactly one file in
        ``AAtraj_files``.
    """
    (state_dir, state_id, k, rep_group_dict, sorted_chg_ent_structure_keyword_list, last_num_frames, total_traj_num_frames,
     idx2traj, AAtraj_files, native_AA_pdb, rep_chg_ent_dtrajs, Q_list, G_list,
     LIPMS_consist_data, M_LiPMS, XLMS_consist_data, M_XLMS, dist_traj_list, if_backmap, pulchra_only, logger_name) = args

    logger = logging.getLogger(logger_name)

    # NOTE(review): the keys are evaluated with eval(); this is only acceptable
    # while they are produced by this package itself. Never feed external
    # strings through this path.
    k_list = eval(k)
    k_str = '_'.join(k_list)
    k_str_dir = os.path.join(state_dir, k_str)
    os.makedirs(k_str_dir, exist_ok=True)
    logger.info(f'Made {k_str_dir}')
    # Order in which the fields of each entanglement change are written to info.txt.
    key_order = ['type', 'code', 'native_contact', 'native_contact_residx', 'linking_value', 'crossing_resid', 'crossing_residx', 'crossing_pattern', 'gauss_linking_number', 'topoly_linking_number',
                 'ref_native_contact', 'ref_native_contact_residx', 'ref_linking_value', 'ref_crossing_resid', 'ref_crossing_residx', 'ref_crossing_pattern', 'ref_gauss_linking_number', 'ref_topoly_linking_number']

    for kk in rep_group_dict[state_id][k].keys():
        kk_list = eval(sorted_chg_ent_structure_keyword_list[kk])
        if len(kk_list) == 0:
            continue

        # Directory names use 1-based entanglement-change IDs.
        kk_str = '_'.join([str(i+1) for i in kk_list])
        kk_str_dir = os.path.join(k_str_dir, kk_str)
        os.makedirs(kk_str_dir, exist_ok=True)
        logger.info(f'Made {kk_str_dir}')

        [traj_idx, frame_idx] = rep_group_dict[state_id][k][kk]
        traj = idx2traj[traj_idx]
        # frame_idx counts within the last `last_num_frames` frames; shift it
        # to an index into the full trajectory.
        traj_frame_idx = total_traj_num_frames - last_num_frames + frame_idx

        AAtraj_file = match_pattern(AAtraj_files, f'{traj}')
        if len(AAtraj_file) != 1:
            raise ValueError(f'Found {len(AAtraj_file)} AA traj files for traj {traj}, expected 1.')

        # Read only the required frame instead of loading the whole trajectory
        # and slicing it. The factor 10 presumably converts mdtraj's nm to
        # Angstrom for the PDB written downstream - TODO confirm.
        rep_frame = mdt.load_frame(AAtraj_file[0], int(traj_frame_idx), top=native_AA_pdb)
        state_cor = rep_frame.center_coordinates().xyz * 10
        rep_chg_ent_dict = rep_chg_ent_dtrajs[traj_idx][frame_idx]

        # Group the entanglement changes of this frame by their 'code' and tag
        # each with its own index (read back by gen_state_visualizion).
        rep_ent_dict = {}
        for kkk, v in rep_chg_ent_dict.items():
            v['chg_index'] = kkk
            rep_ent_dict.setdefault(tuple(v['code']), []).append(v)
        gen_state_visualizion(state_id, kk_str, kk_str_dir, native_AA_pdb, state_cor, native_AA_pdb, rep_ent_dict,
                              logger,
                              if_backmap=if_backmap, pulchra_only=pulchra_only, exp_signal_str=k_str)

        Q = Q_list[traj_idx, frame_idx]
        G = G_list[traj_idx, frame_idx]
        info_file = os.path.join(kk_str_dir, 'info.txt')
        with open(info_file, 'w') as f:
            f.write('State #%d\n' % (state_id + 1))
            f.write('Q: %f\n' % (Q))
            f.write('G: %f\n' % (G))
            f.write('Consistent experiment signals: %s\n' % k)
            f.write('Changes in entanglement cluster IDs: %s\n' % [i+1 for i in kk_list])
            f.write('Trajectory number: %d\n' % (traj))
            f.write('Frame number: %d\n' % (traj_frame_idx + 1))
            f.write('%s\n' % ('-' * 64))
            for signal in k_list:
                if signal in LIPMS_consist_data.keys():
                    M = M_LiPMS[traj_idx, frame_idx, LIPMS_consist_data[signal][4]]
                    f.write('%s: M: %.4f\n' % (signal, M))
                elif signal in XLMS_consist_data.keys():
                    M = M_XLMS[traj_idx, frame_idx, XLMS_consist_data[signal][4]]
                    # A signal such as 'K12-K45' is looked up as '12|A-45|A'.
                    requested_key = '%s|A-%s|A' % (signal.split('-')[0][1:], signal.split('-')[1][1:])
                    frame_data = dist_traj_list[traj_idx, frame_idx]
                    if frame_data is None or requested_key not in frame_data:
                        logger.warning(
                            f'Missing XL-MS key for signal {signal}. '
                            f'Traj {idx2traj[traj_idx]}, frame {frame_idx}, '
                            f'MSM indices [{traj_idx}, {frame_idx}]. '
                            f'Requested key: {requested_key}. '
                            f'Available keys in frame_dict: {list(frame_data.keys()) if frame_data is not None else "None"}. '
                            f'Skipping this signal in info file.'
                        )
                        f.write('%s: M: %.4f Jwalk: N/A Euclidean: N/A (key not in XP data)\n' % (signal, M))
                    else:
                        d_data = frame_data[requested_key]
                        f.write('%s: M: %.4f Jwalk: %.4f Euclidean: %.4f\n' % (signal, M, d_data['Jwalk'], d_data['Euclidean']))
            f.write('%s\n' % ('-' * 64))
            for kkk, v in rep_chg_ent_dict.items():
                f.write('%d:\n' % (kkk + 1))
                for kkkk in key_order:
                    f.write(' ' * 4 + '%s: %s\n' % (kkkk, v[kkkk]))
        logger.debug(f'SAVED: {info_file}')
1595
+ ##############################################################################
1596
+
1597
+ ##############################################################################
1598
def match_pattern(strings, user_substring):
    """Return the entries of *strings* that contain ``<non-digit><user_substring>_``.

    The substring is matched literally (regex metacharacters are escaped).
    Requiring a non-digit before it and an underscore after it keeps a
    query such as ``'1'`` from matching ``'11_'`` or ``'21_'``.

    Parameters
    ----------
    strings : iterable of str
        Candidate strings, e.g. file paths.
    user_substring : str
        Text that must appear between a non-digit character and ``_``.

    Returns
    -------
    list of str
        Matching entries, in their original order.
    """
    regex = re.compile(r"\D" + re.escape(user_substring) + "_")

    selected = []
    for candidate in strings:
        if regex.search(candidate) is not None:
            selected.append(candidate)
    return selected
1605
+ ##############################################################################
1606
+
1607
+ ##############################################################################
1608
def gen_state_visualizion(state_id, ent_id, kk_str_dir, psf, state_cor, native_AA_pdb, rep_ent_dict, logger, if_backmap=True, pulchra_only=False, exp_signal_str=''):
    """Write the structure PDB and the VMD Tcl scripts for one representative structure.

    Parameters
    ----------
    state_id : int
        0-based state index; all file names and log messages use ``state_id + 1``.
    ent_id : str
        Identifier of the entanglement-change group (only logged here).
    kk_str_dir : str
        Output directory for the PDB and Tcl files.
    psf : str
        Structure/topology file loaded with ``pmd.load_file``; its coordinates
        are replaced by ``state_cor``.
    state_cor : array-like
        Coordinates assigned to the loaded structure.
    native_AA_pdb : str
        Reference (native) all-atom PDB, shown as molecule 0 in the scripts.
    rep_ent_dict : dict
        Maps an entanglement ``code`` tuple to a list of change-of-entanglement
        dicts. Each dict must carry ``'chg_index'`` and the ``ref_``-prefixed
        and unprefixed fingerprint keys used below.
    logger : logging.Logger
        Logger used for progress messages.
    if_backmap : bool, optional
        If true, rebuild the structure with the external ``backmap.py`` tool;
        otherwise save the structure directly.
    pulchra_only : bool, optional
        Forwarded to ``backmap.py`` as ``-p 1`` / ``-p 0``.
    exp_signal_str : str, optional
        ``'_'``-joined experimental signals, e.g. ``'K12_K30-K45'``; single
        residues are treated as LiP-MS sites and pairs as XL-MS cross-links.

    Raises
    ------
    FileNotFoundError
        If back-mapping was requested but no rebuilt PDB was produced.
    """
    def idx2sel(idx_list):
        """Compress sorted atom indices into a VMD ``index a to b c ...`` selection."""
        if len(idx_list) == 0:
            return ''
        else:
            sel = 'index'
            idx_0 = idx_list[0]
            idx_1 = idx_list[0]
            sel_0 = ' %d'%idx_0
            for i in range(1, len(idx_list)):
                if idx_list[i] == idx_list[i-1] + 1:
                    # still inside a run of consecutive indices
                    idx_1 = idx_list[i]
                else:
                    # close the current run and start a new one
                    if idx_1 > idx_0:
                        sel_0 += ' to %d'%idx_1
                    sel += sel_0
                    idx_0 = idx_list[i]
                    idx_1 = idx_list[i]
                    sel_0 = ' %d'%idx_0
            if idx_1 > idx_0:
                sel_0 += ' to %d'%idx_1
            sel += sel_0
            return sel

    AA_name_list = ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
                    'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL',
                    'HIE', 'HID', 'HIP']
    # VMD ColorIDs; element 0 is used for the reference structure, element 1
    # for the current (representative) structure.
    protein_colorid_list = [12, 15]
    loop_colorid_list = [4, 1]
    thread_colorid_list = [7, 0]
    nc_colorid_list = [3, 3]
    crossing_colorid_list = [8, 8]
    LIP_colorid_list = [13, 10]
    XL_colorid_list = [13, 10]
    thread_cutoff=3
    terminal_cutoff=3

    logger.info('Generate visualization of state %d'%(state_id + 1))
    logger.debug(f'ent_id: {ent_id}')
    logger.debug(f'kk_str_dir: {kk_str_dir}')
    logger.debug(f'psf: {psf}')
    logger.debug(f'state_cor: {state_cor}')
    logger.debug(f'native_AA_pdb: {native_AA_pdb}')
    logger.debug(f'if_backmap: {if_backmap}')
    logger.debug(f'pulchra_only: {pulchra_only}')
    logger.debug(f'exp_signal_str: {exp_signal_str}')
    logger.info(f'rep_ent_dict:')
    for k, v in rep_ent_dict.items():
        logger.info(f'k:\n{k}')
        logger.info(f'v:\n{v} {len(v)}')

    struct = pmd.load_file(psf)
    struct.coordinates = state_cor

    # backmap
    pdb_path = os.path.join(kk_str_dir, f'state_{state_id + 1}.pdb')
    if if_backmap:
        import shutil
        import subprocess

        pulchra_flag = '1' if pulchra_only else '0'
        # Run the tool inside the output directory: this function is executed
        # by several pool workers at once, and the tool's outputs
        # (tmp_rebuilt.pdb, rebuild_tmp/) would otherwise collide in the
        # shared working directory.
        # NOTE(review): assumes backmap.py writes its outputs to its cwd, as
        # the previous `mv tmp_rebuilt.pdb` implied - confirm.
        work_dir = os.path.abspath(kk_str_dir)
        temp_pdb_path = os.path.join(work_dir, 'tmp.pdb')
        rebuilt_pdb_path = os.path.join(work_dir, 'tmp_rebuilt.pdb')
        struct.save(temp_pdb_path, overwrite=True)
        subprocess.run(
            ['backmap.py', '-i', os.path.abspath(native_AA_pdb), '-c', temp_pdb_path, '-p', pulchra_flag],
            cwd=work_dir,
            check=False,
        )
        if os.path.exists(temp_pdb_path):
            os.remove(temp_pdb_path)
        shutil.rmtree(os.path.join(work_dir, 'rebuild_tmp'), ignore_errors=True)
        if not os.path.exists(rebuilt_pdb_path):
            raise FileNotFoundError(f'backmap.py did not produce {rebuilt_pdb_path}')
        shutil.move(rebuilt_pdb_path, pdb_path)
    else:
        struct.save(pdb_path, overwrite=True)

    ref_struct = pmd.load_file(native_AA_pdb)
    current_struct = pmd.load_file(pdb_path)

    # parse exp_signal_str: 'K12_K30-K45' -> [[12], [30, 45]]
    if exp_signal_str == '':
        exp_signal_list = []
    else:
        exp_signal_list = exp_signal_str.strip().split('_')
        exp_signal_list = [es.strip().split('-') for es in exp_signal_list]
        exp_signal_list = [[int(ees[1:]) for ees in es] for es in exp_signal_list]

    ##############################################
    ## no change of entanglement
    if len(list(rep_ent_dict.keys())) == 0:
        # 1-based state number, consistent with the PDB and the other scripts.
        vmd_outfile = os.path.join(kk_str_dir, f'vmd_s{state_id + 1}_none.tcl')
        with open(vmd_outfile, 'w') as f:
            f.write('# Entanglement type: no change\n')
            # pdb_path is used as-is: prefixing './' corrupts an absolute path.
            f.write('''display rendermode GLSL
display projection Orthographic
axes location off

color Display {Background} white

mol new '''+(pdb_path)+''' type pdb first 0 last -1 step 1 filebonds 1 autobonds 1 waitfor all
mol delrep 0 top
mol representation NewCartoon 0.300000 10.000000 4.100000 0
mol color ColorID '''+str(protein_colorid_list[1])+'''
mol selection {all}
mol material AOChalky
mol addrep top
''')
        logger.debug(f'SAVED: {vmd_outfile}')

    ##############################################
    ## Create vmd script for each type of change
    for ent_code, rep_ent_list in rep_ent_dict.items():
        pmd_struct_list = [ref_struct, current_struct]
        struct_dir_list = [native_AA_pdb, pdb_path]
        key_prefix_list = ['ref_', '']
        repres_list = ['', '']
        align_sel_list = ['', '']

        for chg_ent_fingerprint_idx, chg_ent_fingerprint in enumerate(rep_ent_list):
            vmd_script = '''# Entanglement type: '''+str(chg_ent_fingerprint['type'])+'''
package require topotools
display rendermode GLSL
display projection Orthographic
axes location off

color Display {Background} white
color Labels {Bonds} black

label textsize 0.000001

'''
            for struct_idx, pmd_struct in enumerate(pmd_struct_list):
                struct_dir = struct_dir_list[struct_idx]
                protein_colorid = protein_colorid_list[struct_idx]
                loop_colorid = loop_colorid_list[struct_idx]
                thread_colorid = thread_colorid_list[struct_idx]
                nc_colorid = nc_colorid_list[struct_idx]
                crossing_colorid = crossing_colorid_list[struct_idx]
                LIP_colorid = LIP_colorid_list[struct_idx]
                XL_colorid = XL_colorid_list[struct_idx]
                key_prefix = key_prefix_list[struct_idx]

                # Clean ligands: keep only atoms of standard amino-acid residues.
                clean_sel_idx = np.zeros(len(pmd_struct.atoms))
                for res in pmd_struct.residues:
                    if res.name in AA_name_list:
                        for atm in res.atoms:
                            clean_sel_idx[atm.idx] = 1
                pmd_clean_struct = pmd_struct[clean_sel_idx]
                # Maps an atom index of the cleaned structure back to its
                # index in the full structure (the one VMD loads).
                clean_idx_to_idx = np.where(clean_sel_idx == 1)[0]

                # vmd selection string for protein
                idx_list = []
                for res in pmd_struct.residues:
                    if res.name in AA_name_list:
                        idx_list += [atm.idx for atm in res.atoms]
                vmd_sel = idx2sel(idx_list)

                repres = '''mol new '''+struct_dir+''' type pdb first 0 last -1 step 1 filebonds 1 autobonds 1 waitfor all
mol delrep 0 top
mol representation NewCartoon 0.300000 10.000000 4.100000 0
mol color ColorID '''+str(protein_colorid)+'''
mol selection {'''+vmd_sel+'''}
mol material GlassBubble
mol addrep top
'''
                align_sel = vmd_sel

                nc = chg_ent_fingerprint[key_prefix+'native_contact_residx']
                chg_idx = chg_ent_fingerprint['chg_index']

                # CA atoms of the two native-contact residues
                idx_list = []
                for res in pmd_clean_struct.residues:
                    if res.idx in nc:
                        idx_list += [atm.idx for atm in res.atoms if atm.name == 'CA']
                nc_sel = idx2sel(clean_idx_to_idx[idx_list])

                # all atoms of the loop closed by the native contact
                idx_list = []
                for res in pmd_clean_struct.residues:
                    if res.idx >= nc[0] and res.idx <= nc[1]:
                        idx_list += [atm.idx for atm in res.atoms]
                loop_sel = idx2sel(clean_idx_to_idx[idx_list])

                # The alignment excludes loop and thread atoms. An empty
                # selection is skipped because 'not ()' is not valid VMD syntax.
                if loop_sel:
                    align_sel += ' and not (%s)'%loop_sel
                ref_coss_resid = chg_ent_fingerprint['ref_crossing_residx']
                cross_resid = chg_ent_fingerprint['crossing_residx']
                thread = []
                thread_sel_list = []
                for ter_idx in range(len(ref_coss_resid)):
                    thread_0 = []
                    resid_list = ref_coss_resid[ter_idx] + cross_resid[ter_idx]
                    if len(resid_list) > 0:
                        # +/- 5 residues around the crossings, clamped so the
                        # thread stays off the loop and the chain end.
                        thread_0 = [np.min(resid_list)-5, np.max(resid_list)+5]
                        if ter_idx == 0:
                            thread_0[0] = np.max([thread_0[0], terminal_cutoff])
                            thread_0[1] = np.min([thread_0[1], nc[0]-thread_cutoff])
                        else:
                            thread_0[0] = np.max([thread_0[0], nc[1]+thread_cutoff])
                            # NOTE(review): the upper bound uses the atom count
                            # of `struct`, not a residue count; for an all-atom
                            # structure this clamp is very loose - confirm intent.
                            thread_0[1] = np.min([thread_0[1], len(struct.atoms)-1-terminal_cutoff])
                        idx_list = []
                        for res in pmd_clean_struct.residues:
                            if res.idx >= thread_0[0] and res.idx <= thread_0[1]:
                                idx_list += [atm.idx for atm in res.atoms]
                        thread_0_sel = idx2sel(clean_idx_to_idx[idx_list])
                        thread_sel_list.append(thread_0_sel)
                        if thread_0_sel:
                            align_sel += ' and not (%s)'%thread_0_sel
                    else:
                        thread_sel_list.append('')
                    thread.append(thread_0)

                ln = chg_ent_fingerprint[key_prefix+'topoly_linking_number']
                # Pair each crossing pattern with its residue index for the script comment.
                cross = []
                for i in range(len(chg_ent_fingerprint[key_prefix+'crossing_residx'])):
                    cross.append([])
                    for j in range(len(chg_ent_fingerprint[key_prefix+'crossing_residx'][i])):
                        cross[-1].append(chg_ent_fingerprint[key_prefix+'crossing_pattern'][i][j]+str(chg_ent_fingerprint[key_prefix+'crossing_residx'][i][j]))
                repres += '# idx: native_contact_residx %s, topoly_linking_number %s, crossing_residx %s.\n'%(str(nc), str(ln), str(cross))
                repres +=''' mol representation NewCartoon 0.350000 10.000000 4.100000 0
mol color ColorID '''+str(loop_colorid)+'''
mol selection {'''+loop_sel+'''}
mol material Opaque
mol addrep top
mol representation VDW 1.000000 12.000000
mol color ColorID '''+str(nc_colorid)+'''
mol selection {'''+nc_sel+'''}
mol material Opaque
mol addrep top
set sel [atomselect top "'''+nc_sel+'''"]
set idx [$sel get index]
topo addbond [lindex $idx 0] [lindex $idx 1]
mol representation Bonds 0.300000 12.000000
mol color ColorID '''+str(nc_colorid)+'''
mol selection {'''+nc_sel+'''}
mol material Opaque
mol addrep top
'''
                for ter_idx, thread_resid in enumerate(thread):
                    if len(thread_resid) == 0:
                        continue
                    repres += '''mol representation NewCartoon 0.350000 10.000000 4.100000 0
mol color ColorID '''+str(thread_colorid)+'''
mol selection {'''+thread_sel_list[ter_idx]+'''}
mol material Opaque
mol addrep top
'''
                    if len(chg_ent_fingerprint[key_prefix+'crossing_residx'][ter_idx]) > 0:
                        idx_list = []
                        for res in pmd_clean_struct.residues:
                            if res.idx in chg_ent_fingerprint[key_prefix+'crossing_residx'][ter_idx]:
                                idx_list += [atm.idx for atm in res.atoms if atm.name == 'CA']
                        crossing_sel = idx2sel(clean_idx_to_idx[idx_list])
                        repres += '''mol representation VDW 1.000000 12.000000
mol color ColorID '''+str(crossing_colorid)+'''
mol selection {'''+crossing_sel+'''}
mol material Opaque
mol addrep top
'''
                #######################################
                ## showing experimental signal residues
                for es in exp_signal_list:
                    idx_list = []
                    for res in pmd_clean_struct.residues:
                        # experimental residue numbers are 1-based
                        if res.idx+1 in es:
                            idx_list += [atm.idx for atm in res.atoms if atm.name == 'CA']
                    es_sel = idx2sel(clean_idx_to_idx[idx_list])
                    if len(es) == 1: # LiP-MS signal
                        repres += '''# LiP-MS PK site '''+str(es[0])+'''
mol representation VDW 1.000000 12.000000
mol color ColorID '''+str(LIP_colorid)+'''
mol selection {'''+es_sel+'''}
mol material Opaque
mol addrep top
'''
                    elif len(es) == 2: # XL-MS signal
                        if len(idx_list) >= 2:
                            # The label needs atom indices of the loaded (full)
                            # molecule, so map the cleaned indices back.
                            label_cmd = 'label add Bonds %d/%d %d/%d\n' % (
                                struct_idx, clean_idx_to_idx[idx_list[0]],
                                struct_idx, clean_idx_to_idx[idx_list[1]])
                        else:
                            logger.warning(
                                f'XL-MS pair {es}: found {len(idx_list)} CA atom(s) in {struct_dir}; '
                                f'distance label omitted.'
                            )
                            label_cmd = ''
                        repres += '''# XL-MS pair ('''+str(es[0])+''', '''+str(es[1])+''')
mol representation VDW 1.000000 12.000000
mol color ColorID '''+str(XL_colorid)+'''
mol selection {'''+es_sel+'''}
mol material Opaque
mol addrep top
''' + label_cmd

                if struct_idx == 0:
                    # show non-protein, non-water atoms of the reference structure
                    repres += '''mol representation VDW 1.000000 12.000000
mol color Name
mol selection {not ('''+vmd_sel+''') and not water}
mol material Opaque
mol addrep top
'''
                repres_list[struct_idx] = repres
                align_sel_list[struct_idx] = align_sel

            # Fit the reference (molecule 0) onto the current structure
            # (molecule 1) using CA atoms outside the loop and threads.
            vmd_script += '\n'.join(repres_list)
            vmd_script += '''
set sel1 [atomselect 0 "'''+align_sel_list[0]+''' and name CA"]
set sel2 [atomselect 1 "'''+align_sel_list[1]+''' and name CA"]
set trans_mat [measure fit $sel1 $sel2]
set move_sel [atomselect 0 "all"]
$move_sel move $trans_mat
'''
            vmd_outfile = os.path.join(kk_str_dir, f'vmd_s{state_id + 1}_e{chg_idx + 1}_n{ent_code[0]}_c{ent_code[1]}.tcl')
            with open(vmd_outfile, 'w') as f:
                f.write(vmd_script)
            logger.debug(f'SAVED: {vmd_outfile}')
1926
+
1927
+ ##############################################################################