EntDetect 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- EntDetect/Jwalk/GridTools.py +567 -0
- EntDetect/Jwalk/PDBTools.py +532 -0
- EntDetect/Jwalk/SASDTools.py +543 -0
- EntDetect/Jwalk/SurfaceTools.py +150 -0
- EntDetect/Jwalk/__init__.py +19 -0
- EntDetect/Jwalk/naccess.config.txt +255 -0
- EntDetect/__init__.py +10 -0
- EntDetect/_logging.py +71 -0
- EntDetect/change_resolution.py +2361 -0
- EntDetect/clustering.py +2626 -0
- EntDetect/compare_sim2exp.py +1927 -0
- EntDetect/entanglement_features.py +478 -0
- EntDetect/gaussian_entanglement.py +2067 -0
- EntDetect/order_params.py +1048 -0
- EntDetect/resources/__init__.py +11 -0
- EntDetect/resources/__pycache__/__init__.cpython-311.pyc +0 -0
- EntDetect/resources/calc_K.pl +712 -0
- EntDetect/resources/calc_Q.pl +962 -0
- EntDetect/resources/pulchra +0 -0
- EntDetect/resources/shared_files/__init__.py +2 -0
- EntDetect/resources/shared_files/bt_contact_potential.dat +22 -0
- EntDetect/resources/shared_files/karanicolas_dihe_parm.dat +1600 -0
- EntDetect/resources/shared_files/kgs_contact_potential.dat +22 -0
- EntDetect/resources/shared_files/mj_contact_potential.dat +22 -0
- EntDetect/resources/stride +0 -0
- EntDetect/statistics.py +1344 -0
- EntDetect/utilities.py +201 -0
- entdetect-1.2.0.dist-info/METADATA +26 -0
- entdetect-1.2.0.dist-info/RECORD +45 -0
- entdetect-1.2.0.dist-info/WHEEL +5 -0
- entdetect-1.2.0.dist-info/entry_points.txt +11 -0
- entdetect-1.2.0.dist-info/licenses/LICENSE +674 -0
- entdetect-1.2.0.dist-info/top_level.txt +2 -0
- scripts/__init__.py +5 -0
- scripts/convert_cor_psf_to_pdb.py +103 -0
- scripts/run_Foldingpathway.py +162 -0
- scripts/run_MSM.py +152 -0
- scripts/run_OP_on_simulation_traj.py +194 -0
- scripts/run_change_resolution.py +63 -0
- scripts/run_compare_sim2exp.py +215 -0
- scripts/run_montecarlo.py +158 -0
- scripts/run_nativeNCLE.py +179 -0
- scripts/run_nonnative_entanglement_clustering.py +110 -0
- scripts/run_population_modeling.py +117 -0
- scripts/run_workflow4_nativeNCLE_batch.py +412 -0
EntDetect/clustering.py
ADDED
|
@@ -0,0 +1,2626 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
from collections import defaultdict
|
|
3
|
+
import numpy as np
|
|
4
|
+
import itertools
|
|
5
|
+
from geom_median.numpy import compute_geometric_median
|
|
6
|
+
from scipy.spatial.distance import cdist, squareform
|
|
7
|
+
from functools import cache
|
|
8
|
+
import re
|
|
9
|
+
import random
|
|
10
|
+
import pandas as pd
|
|
11
|
+
import pickle
|
|
12
|
+
import logging
|
|
13
|
+
import sys, getopt, math, os, time, traceback, glob, copy
|
|
14
|
+
import threading
|
|
15
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
16
|
+
from EntDetect._logging import setup_logger
|
|
17
|
+
from scipy.cluster.hierarchy import fcluster, linkage, cophenet
|
|
18
|
+
try:
|
|
19
|
+
import parmed as pmd
|
|
20
|
+
import mdtraj as mdt
|
|
21
|
+
except ImportError:
|
|
22
|
+
pmd = None
|
|
23
|
+
mdt = None
|
|
24
|
+
import matplotlib
|
|
25
|
+
import matplotlib.pyplot as plt
|
|
26
|
+
try:
|
|
27
|
+
import pyemma as pem
|
|
28
|
+
import deeptime
|
|
29
|
+
except ImportError:
|
|
30
|
+
pem = None
|
|
31
|
+
deeptime = None
|
|
32
|
+
from matplotlib.cm import get_cmap
|
|
33
|
+
from matplotlib.colors import ListedColormap, BoundaryNorm
|
|
34
|
+
import matplotlib.colors as mcolors
|
|
35
|
+
import seaborn as sns
|
|
36
|
+
from scipy.stats import mode
|
|
37
|
+
import pathlib
|
|
38
|
+
from dataclasses import dataclass
|
|
39
|
+
import json
|
|
40
|
+
|
|
41
|
+
# Module-level configuration applied at import time: select matplotlib's
# non-interactive "Agg" backend and raise the number of rows pandas prints
# when a DataFrame is rendered (e.g. in debug log output).
matplotlib.use("Agg")
pd.options.display.max_rows = 500
|
|
43
|
+
|
|
44
|
+
class ClusterNativeEntanglements:
|
|
45
|
+
"""
|
|
46
|
+
Class to calculate native entanglements given either a file path to an entanglement file or an entanglement object
|
|
47
|
+
"""
|
|
48
|
+
|
|
49
|
+
##########################################################################################################################################################
|
|
50
|
+
def __init__(self, organism: str = 'Ecoli', cut_off: int = None, outdir: str = None, log_level: int = logging.INFO, logdir: str = None) -> None:
    """
    Constructor for ClusterNativeEntanglements class.

    Parameters
    ----------
    organism : str, optional
        Organism whose built-in distance cutoff is used when ``cut_off`` is None.
        Built-in defaults: 'Human' -> 52, 'Ecoli' -> 57, 'Yeast' -> 49.
        Default is 'Ecoli'.
    cut_off : int, optional
        Explicit distance cutoff. When not None it overrides the organism
        default. The cutoff is compared against ``loop_distance`` values when
        grouping entanglements in the clustering step.
    outdir : str, optional
        Directory handed to the logger when ``logdir`` is None.
    log_level : int, optional
        Logging level for this instance's logger (default ``logging.INFO``).
    logdir : str, optional
        Directory handed to the logger; takes precedence over ``outdir``.

    Raises
    ------
    ValueError
        If ``organism`` has no built-in default and ``cut_off`` is None.
    """
    default_cut_offs = {'Human': 52, 'Ecoli': 57, 'Yeast': 49}

    if cut_off is not None:
        self.cut_off = cut_off
    elif organism in default_cut_offs:
        self.cut_off = default_cut_offs[organism]
    else:
        # Previously self.cut_off was silently left unset here, which only
        # surfaced later as an AttributeError inside the clustering routine.
        raise ValueError(
            f"No default cut_off for organism {organism!r}; expected one of "
            f"{sorted(default_cut_offs)} or pass cut_off explicitly."
        )

    self.organism = organism
    self.logger = setup_logger('ClusterNativeEntanglements', outdir=logdir if logdir is not None else outdir, log_level=log_level)
|
|
69
|
+
##########################################################################################################################################################
|
|
70
|
+
|
|
71
|
+
##########################################################################################################################################################
|
|
72
|
+
def loop_distance(self, entangled_A: tuple, entangled_B: tuple):
    """
    Euclidean distance between two entanglements with chirality stripped.

    Each entanglement is expected in the layout built by the clustering
    pipeline: ``(count, i, j, *crossings, loops_string)``. The leading count and
    the trailing string are ignored. The coordinates compared are the native
    contact residues ``i`` and ``j`` followed by each crossing residue number,
    where a crossing is a string whose first character (the chirality sign) is
    dropped before converting the remainder to ``int``.

    Parameters
    ----------
    entangled_A, entangled_B : tuple
        Entanglement tuples in the layout described above.

    Returns
    -------
    float
        ``math.dist`` between the two coordinate vectors. ``math.dist`` raises
        ValueError if the entanglements have different numbers of crossings.
    """
    def _as_point(ent):
        # (i, j, r1, r2, ...) with the chirality character removed from each crossing
        return (ent[1], ent[2], *(int(crossing[1:]) for crossing in ent[3:-1]))

    point_a = _as_point(entangled_A)
    point_b = _as_point(entangled_B)
    return math.dist(point_a, point_b)
|
|
82
|
+
##########################################################################################################################################################
|
|
83
|
+
|
|
84
|
+
##########################################################################################################################################################
|
|
85
|
+
def check_step_ij_kl_range(self, ent1: tuple, ent2: tuple):
    """
    Return True if the native-contact loops of two entanglements overlap.

    With ``(i, j) = ent1[1:3]`` and ``(k, l) = ent2[1:3]``, this checks whether
    ``i`` or ``j`` lies in the inclusive range ``[k, l]``, or whether ``k`` or
    ``l`` lies in the inclusive range ``[i, j]``.

    Parameters
    ----------
    ent1, ent2 : tuple
        Entanglement tuples whose elements 1 and 2 are the native contact
        residue indices (assumed integer-valued).

    Returns
    -------
    bool
        True if either loop has an endpoint inside the other loop's range.
    """
    start1, end1 = ent1[1], ent1[2]
    start2, end2 = ent2[1], ent2[2]

    # Inclusive chained comparisons are equivalent to the former
    # ``x in np.arange(start, end + 1)`` membership tests for integer residue
    # indices. This includes the start > end case: the arange is empty and the
    # comparison is False. The comparisons avoid allocating and scanning an
    # array proportional to the loop length on every pairwise call.
    if start2 <= start1 <= end2 or start2 <= end1 <= end2:
        return True
    return bool(start1 <= start2 <= end1 or start1 <= end2 <= end1)
|
|
103
|
+
##########################################################################################################################################################
|
|
104
|
+
|
|
105
|
+
##########################################################################################################################################################
|
|
106
|
+
#@cache
|
|
107
|
+
def Cluster_NativeEntanglements(self, GE_filepath: str, outdir: str='./', outfile: str='Cluster_NativeEntanglements.txt', chain: str=None):
|
|
108
|
+
|
|
109
|
+
"""
|
|
110
|
+
PARAMS:
|
|
111
|
+
GE_file: str
|
|
112
|
+
cut_off: int
|
|
113
|
+
|
|
114
|
+
1. Identify all unique "residue crossing set and chiralites"
|
|
115
|
+
|
|
116
|
+
1b. sort the residues along with the chiralities
|
|
117
|
+
|
|
118
|
+
2. Find the minimal loop encompassing a given "residue crossing set and chiralites"
|
|
119
|
+
|
|
120
|
+
2b
|
|
121
|
+
|
|
122
|
+
i. Identify entanglements that have any crossing residues between them that are
|
|
123
|
+
less than or equal to 3 residues apart and have the same chirality.
|
|
124
|
+
|
|
125
|
+
ii. Then check if i or j of (i,j) reside within the range (inclusive) of (k,l), or vice versa;
|
|
126
|
+
|
|
127
|
+
iii. If yes, then check if any crossing residues are in the range of min(i,j,k,l) to max(i,j,k,l);
|
|
128
|
+
if yes, skip rest of 2
|
|
129
|
+
|
|
130
|
+
iv. If no, then check if the number of crossing residues, in each residue set, are different;
|
|
131
|
+
|
|
132
|
+
v. All crossing residue(s) in the entanglement with the fewer crossings need to have a distance <= 20
|
|
133
|
+
with the crossings in the other entanglement. Do this by the "brute force" approach and
|
|
134
|
+
the true distance formula. This means, calculate the distances and take the minimal distance
|
|
135
|
+
as the distance you check that is less than or equal to 20.
|
|
136
|
+
|
|
137
|
+
If yes, then keep the {i,j} {r} that have the greatest number of crossing residues;
|
|
138
|
+
If not, then keep the two entanglements separate.
|
|
139
|
+
|
|
140
|
+
3. For at least two entanglements each with 1 or more crossings.
|
|
141
|
+
Loop over the entanglments two at time (avoid double counting)
|
|
142
|
+
Check if i or j of (i,j) reside within the range (inclusive) of (k,l), or vice versa;
|
|
143
|
+
If yes, check if number of crossing residues is the same (and it is 1 or more)
|
|
144
|
+
If yes, calculate the distances between all crossing residues
|
|
145
|
+
and if both have the same chiralities.
|
|
146
|
+
(Do NOT use brute force, just compare 1-to-1 index of crossing residues).
|
|
147
|
+
If all the distances are less than or equal to 20, then determine which
|
|
148
|
+
entanglement has the smaller loop, remove the entanglement with the larger loop
|
|
149
|
+
|
|
150
|
+
4. Spatially cluster those outputs that have (i) the same number of crossings and (ii) the same chiralities
|
|
151
|
+
|
|
152
|
+
"""
|
|
153
|
+
self.logger.info(f'Clustering {self.organism} Native Entanglements with dist_cutoff: {self.cut_off}')
|
|
154
|
+
GE_file = GE_filepath.split('/')[-1]
|
|
155
|
+
self.logger.debug(f'{GE_file} cut_off={self.cut_off} outdir={outdir}')
|
|
156
|
+
|
|
157
|
+
full_entanglement_data = defaultdict(list)
|
|
158
|
+
|
|
159
|
+
ent_data = defaultdict(list)
|
|
160
|
+
|
|
161
|
+
rep_ID_ent = defaultdict(list)
|
|
162
|
+
|
|
163
|
+
grouped_entanglement_data = defaultdict(list)
|
|
164
|
+
|
|
165
|
+
Before_cr_dist = defaultdict(list)
|
|
166
|
+
|
|
167
|
+
After_cr_dist = defaultdict(list)
|
|
168
|
+
|
|
169
|
+
entanglement_partial_g_data = {}
|
|
170
|
+
|
|
171
|
+
## Check if the clustering file is already made and if so use it
|
|
172
|
+
outfilepath = os.path.join(f'{outdir}', f'{outfile}')
|
|
173
|
+
if os.path.exists(outfilepath):
|
|
174
|
+
self.logger.info(f'{outfilepath} ALREADY EXISTS AND WILL BE LOADED')
|
|
175
|
+
outdf = pd.read_csv(outfilepath, sep='|')
|
|
176
|
+
return {'outfile':outfilepath, 'ent_result':outdf}
|
|
177
|
+
|
|
178
|
+
self.logger.info(f'Loading {GE_filepath}')
|
|
179
|
+
GE_data = pd.read_csv(GE_filepath, sep='|', dtype={'crossingsN': str, 'crossingsC': str})
|
|
180
|
+
GE_data = GE_data[GE_data['ENT'] == True].reset_index(drop=True)
|
|
181
|
+
# if Quality is a column name then only get the High Quality raw entanglements
|
|
182
|
+
if 'Quality' in GE_data.keys():
|
|
183
|
+
GE_data = GE_data[GE_data['Quality'] == 'High'].reset_index(drop=True)
|
|
184
|
+
|
|
185
|
+
GE_data = GE_data.replace(np.nan, '', regex=True)
|
|
186
|
+
self.logger.debug(GE_data)
|
|
187
|
+
|
|
188
|
+
### STEP 1 INITAL LOADING AND MERGING ################################################################################################################
|
|
189
|
+
############################################################################################
|
|
190
|
+
## parse the entanglement file and
|
|
191
|
+
## get those native contacts that are disulfide bonds
|
|
192
|
+
self.logger.info(f'# Step 1')
|
|
193
|
+
CCBonds = []
|
|
194
|
+
num_raw_ents = {}
|
|
195
|
+
chain_info = {} # Track chain for each ID
|
|
196
|
+
for rowi, row in GE_data.iterrows():
|
|
197
|
+
# print(row)
|
|
198
|
+
|
|
199
|
+
ID = row['ID']
|
|
200
|
+
chain = row['chain'] if 'chain' in row else 'A' # Default to 'A' if chain not present
|
|
201
|
+
chain_info[ID] = chain
|
|
202
|
+
native_contact_i, native_contact_j = row['i'], row['j']
|
|
203
|
+
|
|
204
|
+
# Store separate N and C crossings
|
|
205
|
+
crossingsN = row['crossingsN'] if pd.notna(row['crossingsN']) and row['crossingsN'] != '' else ''
|
|
206
|
+
crossingsC = row['crossingsC'] if pd.notna(row['crossingsC']) and row['crossingsC'] != '' else ''
|
|
207
|
+
|
|
208
|
+
crossing_res = [cr for cr in [crossingsN, crossingsC] if cr != '']
|
|
209
|
+
crossing_res = ','.join(crossing_res)
|
|
210
|
+
self.logger.debug(f'{native_contact_i} {native_contact_j} {crossing_res}')
|
|
211
|
+
|
|
212
|
+
gn, gc = row['gn'], row['gc']
|
|
213
|
+
GLNn, GLNc = row['GLNn'], row['GLNc']
|
|
214
|
+
TLNn, TLNc = row['TLNn'], row['TLNc']
|
|
215
|
+
# Handle empty strings from NaN replacement: convert to np.nan
|
|
216
|
+
TLNn = np.nan if (isinstance(TLNn, str) and TLNn == '') else TLNn
|
|
217
|
+
TLNc = np.nan if (isinstance(TLNc, str) and TLNc == '') else TLNc
|
|
218
|
+
CCbond = row['CCbond']
|
|
219
|
+
|
|
220
|
+
# keep track of number of raw ents for QC purposes
|
|
221
|
+
if ID not in num_raw_ents:
|
|
222
|
+
num_raw_ents[ID] = 1
|
|
223
|
+
else:
|
|
224
|
+
num_raw_ents[ID] += 1
|
|
225
|
+
|
|
226
|
+
#native_contact_i, native_contact_j, crossing_res = line[1], line[2], line[3]
|
|
227
|
+
#native_contact_i = int(native_contact_i)
|
|
228
|
+
#native_contact_j = int(native_contact_j)
|
|
229
|
+
|
|
230
|
+
reformat_cr = crossing_res.split(',') if crossing_res else []
|
|
231
|
+
|
|
232
|
+
# Filter out any empty strings from the split
|
|
233
|
+
reformat_cr = [cr for cr in reformat_cr if cr]
|
|
234
|
+
|
|
235
|
+
if reformat_cr:
|
|
236
|
+
reformat_cr = sorted(reformat_cr, key = lambda x: int(re.split("\\+|-", x, maxsplit= 1)[1]))
|
|
237
|
+
#print(native_contact_i, native_contact_j, reformat_cr)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
# Step 1 and 1b
|
|
241
|
+
grouped_entanglement_data[(ID, *reformat_cr)].append((native_contact_i, native_contact_j))
|
|
242
|
+
|
|
243
|
+
entanglement_partial_g_data[(native_contact_i, native_contact_j, *reformat_cr)] = (gn, gc, GLNn, GLNc, TLNn, TLNc, crossingsN, crossingsC)
|
|
244
|
+
|
|
245
|
+
#print(f'CCbond: {CCbond}')
|
|
246
|
+
if CCbond == True:
|
|
247
|
+
CCBonds += [(native_contact_i, native_contact_j)]
|
|
248
|
+
|
|
249
|
+
#print(f'num_raw_ents: {num_raw_ents}')
|
|
250
|
+
|
|
251
|
+
#print(f'Step 1 results')
|
|
252
|
+
Step1_QC_counter = 0
|
|
253
|
+
for k,v in grouped_entanglement_data.items():
|
|
254
|
+
#print(k,v)
|
|
255
|
+
Step1_QC_counter += len(v)
|
|
256
|
+
|
|
257
|
+
# STEP 1 SUMMARY
|
|
258
|
+
self.logger.info(f'\n{"="*100}')
|
|
259
|
+
self.logger.info(f'STEP 1 SUMMARY: Data Loading and Grouping by Unique Crossing Sets')
|
|
260
|
+
self.logger.info(f'{"="*100}')
|
|
261
|
+
self.logger.info(f'Total raw entanglements loaded: {Step1_QC_counter}')
|
|
262
|
+
self.logger.info(f'Number of protein IDs: {len(num_raw_ents)}')
|
|
263
|
+
for prot_id, count in sorted(num_raw_ents.items()):
|
|
264
|
+
self.logger.info(f' - {prot_id}: {count} raw entanglements')
|
|
265
|
+
self.logger.info(f'Unique crossing patterns identified: {len(grouped_entanglement_data)}')
|
|
266
|
+
self.logger.info(f'Disulfide bonds found: {len(CCBonds)}')
|
|
267
|
+
if CCBonds:
|
|
268
|
+
self.logger.info(f' - Disulfide bond pairs: {CCBonds}')
|
|
269
|
+
|
|
270
|
+
### STEP 2 ################################################################################################################
|
|
271
|
+
############################################################################################
|
|
272
|
+
# Step 2 Get the minimal loop encompassing each set of unique crossings
|
|
273
|
+
self.logger.info(f'\n# Step 2a')
|
|
274
|
+
for ID_cr, loops in grouped_entanglement_data.items():
|
|
275
|
+
|
|
276
|
+
ID = ID_cr[0]
|
|
277
|
+
|
|
278
|
+
crossings = np.asarray(list(ID_cr[1:]))
|
|
279
|
+
|
|
280
|
+
loop_lengths = [nc[1] - nc[0] for nc in loops]
|
|
281
|
+
|
|
282
|
+
minimum_loop_length = min(loop_lengths)
|
|
283
|
+
|
|
284
|
+
minimum_loop_length_index = loop_lengths.index(minimum_loop_length)
|
|
285
|
+
|
|
286
|
+
minimum_loop_nc_i, minimum_loop_nc_j = loops[minimum_loop_length_index]
|
|
287
|
+
|
|
288
|
+
ent_data[ID].append((len(loops), minimum_loop_nc_i, minimum_loop_nc_j, *crossings, ';'.join(['-'.join([str(loop[0]), str(loop[1])]) for loop in loops])))
|
|
289
|
+
|
|
290
|
+
# STEP 2a SUMMARY
|
|
291
|
+
self.logger.info(f'{"="*100}')
|
|
292
|
+
self.logger.info(f'STEP 2A SUMMARY: Minimal Loop Identification for Unique Crossing Sets')
|
|
293
|
+
self.logger.info(f'{"="*100}')
|
|
294
|
+
for ID, ents in ent_data.items():
|
|
295
|
+
Step2a_QC_counter = 0
|
|
296
|
+
for ent_i, ent in enumerate(ents):
|
|
297
|
+
#print(ID, ent_i, ent)
|
|
298
|
+
Step2a_QC_counter += ent[0]
|
|
299
|
+
|
|
300
|
+
self.logger.info(f'{ID}: {Step2a_QC_counter} raw entanglements grouped into {len(ents)} representative loops')
|
|
301
|
+
for ent_i, ent in enumerate(ents):
|
|
302
|
+
num_loops = ent[0]
|
|
303
|
+
loop_i, loop_j = ent[1], ent[2]
|
|
304
|
+
crossings = [str(c) for c in ent[3:-1]]
|
|
305
|
+
self.logger.info(f' Representative {ent_i+1}: Loop ({loop_i}, {loop_j}), ' +
|
|
306
|
+
f'Crossings={crossings if crossings else "none"}, ' +
|
|
307
|
+
f'Represents {num_loops} raw entanglement(s)')
|
|
308
|
+
|
|
309
|
+
## QC that the number of tracked entanglements after step 2a is still valid
|
|
310
|
+
#print(f'Step2a_QC_counter: {Step2a_QC_counter} should = {num_raw_ents[ID]}')
|
|
311
|
+
if Step2a_QC_counter != num_raw_ents[ID]:
|
|
312
|
+
raise ValueError(f'The number of tracked entaglements after Step 2a {Step2a_QC_counter} != {num_raw_ents[ID]}')
|
|
313
|
+
|
|
314
|
+
############################################################################################
|
|
315
|
+
# Step 2b:
|
|
316
|
+
self.logger.info(f'{"="*100}')
|
|
317
|
+
self.logger.info(f'STEP 2B: Merging of Entanglements Based on Crossing Proximity')
|
|
318
|
+
self.logger.info(f'{"="*100}')
|
|
319
|
+
merged_ents = []
|
|
320
|
+
for ID, ents in ent_data.items():
|
|
321
|
+
orig_ents = ents.copy()
|
|
322
|
+
comb_ents = itertools.combinations(ents, 2)
|
|
323
|
+
|
|
324
|
+
# for each pair of ents
|
|
325
|
+
for each_ent_pair in comb_ents:
|
|
326
|
+
#print(f'\nAnalyzing pair: {each_ent_pair}')
|
|
327
|
+
if each_ent_pair[0] == each_ent_pair[1]:
|
|
328
|
+
self.logger.info(f'Ents are the same: {each_ent_pair}')
|
|
329
|
+
continue
|
|
330
|
+
|
|
331
|
+
distance_thresholds = []
|
|
332
|
+
|
|
333
|
+
ent1 = each_ent_pair[0]
|
|
334
|
+
ent2 = each_ent_pair[1]
|
|
335
|
+
|
|
336
|
+
# get crossings from ent pair without chiralities
|
|
337
|
+
cr1 = set([int(ent_cr_1[1:]) for ent_cr_1 in list(ent1[3:-1])])
|
|
338
|
+
cr2 = set([int(ent_cr_2[1:]) for ent_cr_2 in list(ent2[3:-1])])
|
|
339
|
+
#print(cr1, cr2)
|
|
340
|
+
|
|
341
|
+
# get all possible pairs of the crossings
|
|
342
|
+
all_cr_pairs = itertools.product(ent1[3:-1], ent2[3:-1])
|
|
343
|
+
|
|
344
|
+
# get the distances between all pairs of crossings
|
|
345
|
+
cr_dist_same_chiral = np.abs([int(pr[0][1:]) - int(pr[1][1:]) for pr in all_cr_pairs if pr[0][0] == pr[1][0]])
|
|
346
|
+
#print(cr_dist_same_chiral)
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
# if any of those distances is less than 3 and the number of crossings is not the same and ij in range of kl
|
|
350
|
+
dist_check = np.any(cr_dist_same_chiral <= 3)
|
|
351
|
+
loop_check = self.check_step_ij_kl_range(ent1, ent2)
|
|
352
|
+
diff_cross_size_check = len(cr1) != len(cr2)
|
|
353
|
+
#print(dist_check, loop_check, diff_cross_size_check)
|
|
354
|
+
|
|
355
|
+
if np.any(cr_dist_same_chiral <= 3) and len(cr1) != len(cr2) and self.check_step_ij_kl_range(ent1, ent2):
|
|
356
|
+
#print(f'\nAnalyzing pair: {each_ent_pair}')
|
|
357
|
+
#print(f'step 2b conditions met')
|
|
358
|
+
|
|
359
|
+
minumum_loop_base = min(ent1[1], ent1[2], ent2[1], ent2[2])
|
|
360
|
+
|
|
361
|
+
maximum_loop_base = max(ent1[1], ent1[2], ent2[1], ent2[2])
|
|
362
|
+
|
|
363
|
+
all_crossings = cr1.union(cr2)
|
|
364
|
+
|
|
365
|
+
min_max_loop_base_range = set(range(minumum_loop_base, maximum_loop_base + 1))
|
|
366
|
+
|
|
367
|
+
# if the crossings are not within the min max loop range covering both entanglements
|
|
368
|
+
if not min_max_loop_base_range.intersection(all_crossings):
|
|
369
|
+
|
|
370
|
+
fewer_cr = min(cr1, cr2, key = len)
|
|
371
|
+
more_cr = max(cr1, cr2, key = len)
|
|
372
|
+
|
|
373
|
+
distributive_product = list(itertools.product(fewer_cr, more_cr))
|
|
374
|
+
|
|
375
|
+
slices = itertools.islice(distributive_product, 0, None, len(more_cr))
|
|
376
|
+
|
|
377
|
+
groupings = []
|
|
378
|
+
|
|
379
|
+
for end_point in slices:
|
|
380
|
+
|
|
381
|
+
first_index = distributive_product.index(end_point)
|
|
382
|
+
|
|
383
|
+
groupings.append(distributive_product[first_index:len(more_cr) + first_index])
|
|
384
|
+
|
|
385
|
+
if len(groupings) != 1:
|
|
386
|
+
|
|
387
|
+
all_pair_products = itertools.product(*groupings)
|
|
388
|
+
all_pair_groupings = set()
|
|
389
|
+
|
|
390
|
+
for pairs in all_pair_products:
|
|
391
|
+
|
|
392
|
+
flag = True
|
|
393
|
+
|
|
394
|
+
# check common elements column wise
|
|
395
|
+
stacked_pairs = np.stack(pairs)
|
|
396
|
+
|
|
397
|
+
for col in range(stacked_pairs.shape[1]):
|
|
398
|
+
|
|
399
|
+
if stacked_pairs[:, col].size != len(set(stacked_pairs[:, col])):
|
|
400
|
+
|
|
401
|
+
flag = False
|
|
402
|
+
break
|
|
403
|
+
|
|
404
|
+
if flag:
|
|
405
|
+
|
|
406
|
+
all_pair_groupings.add(pairs)
|
|
407
|
+
|
|
408
|
+
else:
|
|
409
|
+
|
|
410
|
+
all_pair_groupings = groupings[0]
|
|
411
|
+
|
|
412
|
+
for condensed_pair in all_pair_groupings:
|
|
413
|
+
|
|
414
|
+
if isinstance(condensed_pair[0], int):
|
|
415
|
+
|
|
416
|
+
# when dealing with ent with one crossing
|
|
417
|
+
|
|
418
|
+
condensed_pair = [condensed_pair]
|
|
419
|
+
|
|
420
|
+
dist = np.sqrt(sum([(each_ele[0] - each_ele[1]) ** 2 for each_ele in condensed_pair]))
|
|
421
|
+
|
|
422
|
+
distance_thresholds.append(dist)
|
|
423
|
+
|
|
424
|
+
# all_pair_groupings and distance thresholds have the same size
|
|
425
|
+
if min(distance_thresholds) <= 20:
|
|
426
|
+
|
|
427
|
+
min_ent = min(ent1, ent2, key = len)
|
|
428
|
+
max_ent = max(ent1, ent2, key = len)
|
|
429
|
+
if min_ent == max_ent:
|
|
430
|
+
self.logger.warning(f'WARNING: Ents are the same. Setting min_ent = ent1 and max_ent = ent2')
|
|
431
|
+
min_ent = ent1
|
|
432
|
+
max_ent = ent2
|
|
433
|
+
#print(f'ent1: {ent1} | ent2: {ent2}')
|
|
434
|
+
#print(f'min_ent: {min_ent}')
|
|
435
|
+
#print(f'max_ent: {max_ent}')
|
|
436
|
+
|
|
437
|
+
if min_ent in ents and len(ents) > 1:
|
|
438
|
+
#if min_ent in ents and max_ent in ents and len(ents) > 1:
|
|
439
|
+
#min_ent_num_ncs = min_ent[0]
|
|
440
|
+
|
|
441
|
+
#print(f'Removing: min_ent {min_ent} at index {ents.index(min_ent)}')
|
|
442
|
+
del ents[ents.index(min_ent)]
|
|
443
|
+
#del ents[ents.index(max_ent)]
|
|
444
|
+
|
|
445
|
+
if max_ent == min_ent:
|
|
446
|
+
raise ValueError(f'WARNING: Ents are the same\n{min_ent} == {max_ent}')
|
|
447
|
+
else:
|
|
448
|
+
merged_ents += [(max_ent, min_ent)]
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
#print(f'\nStep 2b results')
|
|
452
|
+
# results foor the end of step 2
|
|
453
|
+
for ID, ents in ent_data.items():
|
|
454
|
+
ent_dict = {ent_idx:[ent] for ent_idx, ent in enumerate(ents)}
|
|
455
|
+
#for ent_idx, ent in ent_dict.items():
|
|
456
|
+
# print(ent_idx, ent)
|
|
457
|
+
|
|
458
|
+
### Update entanglement list with those that got merged
|
|
459
|
+
self.logger.info(f'\n Processing {ID}: Analyzing {len(ents)} representative entanglements for merging...')
|
|
460
|
+
self.logger.info(f' Before merge: {len(ents)} representatives')
|
|
461
|
+
merge_count = 0
|
|
462
|
+
while len(merged_ents) != 0:
|
|
463
|
+
pre_num_merged = len(merged_ents)
|
|
464
|
+
#print(f'# merged_ents: {pre_num_merged}')
|
|
465
|
+
for m_ent in merged_ents:
|
|
466
|
+
#print(f'\n{m_ent}')
|
|
467
|
+
for ent_idx, ent in ent_dict.copy().items():
|
|
468
|
+
#print(ent_idx, ent)
|
|
469
|
+
if m_ent[0] in ent:
|
|
470
|
+
#print(f'FOUND MATCH for kept ent {ent_idx}')
|
|
471
|
+
ent_dict[ent_idx] += [m_ent[1]]
|
|
472
|
+
merged_ents.remove(m_ent)
|
|
473
|
+
merge_count += 1
|
|
474
|
+
|
|
475
|
+
#print(f'# merged_ents: {len(merged_ents)}')
|
|
476
|
+
|
|
477
|
+
# QC to ensure you dont enter an infinite loop
|
|
478
|
+
if len(merged_ents) == pre_num_merged:
|
|
479
|
+
raise ValueError('Failed to find a match this cycle and entering infi loop')
|
|
480
|
+
|
|
481
|
+
self.logger.info(f' After merge: {merge_count} merges completed')
|
|
482
|
+
self.logger.info(f' Reformatting entanglement data...')
|
|
483
|
+
updated_ents = []
|
|
484
|
+
for ent_idx, ent in ent_dict.items():
|
|
485
|
+
#print(ent_idx, ent)
|
|
486
|
+
if len(ent) > 1:
|
|
487
|
+
num_loops = np.sum([e[0] for e in ent])
|
|
488
|
+
NCs = ';'.join([e[-1] for e in ent])
|
|
489
|
+
#print(ent, num_loops, NCs)
|
|
490
|
+
ent = (num_loops, *ent[0][1:-1], NCs)
|
|
491
|
+
updated_ents += [ent]
|
|
492
|
+
else:
|
|
493
|
+
updated_ents += ent
|
|
494
|
+
|
|
495
|
+
#print(f'Results after adding those that got merged to each representative entanglement')
|
|
496
|
+
Step2b_QC_counter = 0
|
|
497
|
+
for uent in updated_ents:
|
|
498
|
+
#print(uent)
|
|
499
|
+
Step2b_QC_counter += uent[0]
|
|
500
|
+
|
|
501
|
+
self.logger.info(f' Result: {len(updated_ents)} representative entanglements after merging (tracking {Step2b_QC_counter} total raw)')
|
|
502
|
+
self.logger.debug(updated_ents)
|
|
503
|
+
## QC that the number of tracked entanglements after step 2a is still valid
|
|
504
|
+
if Step2b_QC_counter != num_raw_ents[ID]:
|
|
505
|
+
raise ValueError(f'The number of tracked entaglements after Step 2b {Step2b_QC_counter} != {num_raw_ents[ID]}')
|
|
506
|
+
|
|
507
|
+
ent_data[ID] = updated_ents
|
|
508
|
+
|
|
509
|
+
### STEP 3 ################################################################################################################
|
|
510
|
+
# Step 3
|
|
511
|
+
self.logger.info(f'{"="*100}')
|
|
512
|
+
self.logger.info(f'STEP 3: Removing Duplicate Entanglements (Same Crossings, Different Loop Sizes)')
|
|
513
|
+
self.logger.info(f'{"="*100}')
|
|
514
|
+
for ID, processed_ents in ent_data.items():
|
|
515
|
+
self.logger.info(f' {ID}: Checking {len(processed_ents)} entanglements for duplicates...')
|
|
516
|
+
|
|
517
|
+
comb_processed_ents = itertools.combinations(processed_ents, 2)
|
|
518
|
+
|
|
519
|
+
keep_track_of_larger_proc_ent = []
|
|
520
|
+
keep_track_of_shorter_proc_ent = []
|
|
521
|
+
removal_count = 0
|
|
522
|
+
|
|
523
|
+
for each_processed_ent_pair in comb_processed_ents:
|
|
524
|
+
#print(f'\npair of ents: {each_processed_ent_pair}')
|
|
525
|
+
|
|
526
|
+
proc_ent1 = each_processed_ent_pair[0]
|
|
527
|
+
proc_ent2 = each_processed_ent_pair[1]
|
|
528
|
+
|
|
529
|
+
proc_ent1_ijr = proc_ent1[1:-1]
|
|
530
|
+
proc_ent2_ijr = proc_ent2[1:-1]
|
|
531
|
+
#print(proc_ent1_ijr, proc_ent2_ijr)
|
|
532
|
+
|
|
533
|
+
if proc_ent1_ijr not in keep_track_of_larger_proc_ent and proc_ent2_ijr not in keep_track_of_larger_proc_ent:
|
|
534
|
+
|
|
535
|
+
# without chiralites
|
|
536
|
+
proc_cr1 = np.asarray([int(ent_cr_1[1:]) for ent_cr_1 in list(proc_ent1[3:-1])])
|
|
537
|
+
proc_cr2 = np.asarray([int(ent_cr_2[1:]) for ent_cr_2 in list(proc_ent2[3:-1])])
|
|
538
|
+
|
|
539
|
+
if len(proc_ent1[3:-1]) == len(proc_ent2[3:-1]):
|
|
540
|
+
|
|
541
|
+
chirality1 = [chir1[0] for chir1 in proc_ent1[3:-1]]
|
|
542
|
+
chirality2 = [chir2[0] for chir2 in proc_ent2[3:-1]]
|
|
543
|
+
|
|
544
|
+
if chirality1 == chirality2 and self.check_step_ij_kl_range(proc_ent1, proc_ent2) and np.all(np.abs(proc_cr1 - proc_cr2) <= 20):
|
|
545
|
+
#print(proc_ent1, proc_ent2)
|
|
546
|
+
|
|
547
|
+
loop1_length = proc_ent1[2] - proc_ent1[1]
|
|
548
|
+
loop2_length = proc_ent2[2] - proc_ent2[1]
|
|
549
|
+
#print(loop1_length, loop2_length)
|
|
550
|
+
|
|
551
|
+
check = [loop1_length, loop2_length]
|
|
552
|
+
|
|
553
|
+
maximum_loop_length = max(loop1_length, loop2_length)
|
|
554
|
+
minimum_loop_length = min(loop1_length, loop2_length)
|
|
555
|
+
|
|
556
|
+
if maximum_loop_length == minimum_loop_length:
|
|
557
|
+
longer_loop_ent = proc_ent1
|
|
558
|
+
shorter_loop_ent = proc_ent2
|
|
559
|
+
else:
|
|
560
|
+
longer_loop_ent = each_processed_ent_pair[check.index(maximum_loop_length)]
|
|
561
|
+
shorter_loop_ent = each_processed_ent_pair[check.index(minimum_loop_length)]
|
|
562
|
+
longer_loop_ent_ijr = longer_loop_ent[1:-1]
|
|
563
|
+
shorter_loop_ent_ijr = shorter_loop_ent[1:-1]
|
|
564
|
+
|
|
565
|
+
if len(processed_ents) > 1:
|
|
566
|
+
|
|
567
|
+
for long_proc_ent_index, long_proc_ent in enumerate(processed_ents):
|
|
568
|
+
long_proc_ent_ijr = long_proc_ent[1:-1]
|
|
569
|
+
#print(long_proc_ent_index, long_proc_ent, long_proc_ent_ijr)
|
|
570
|
+
if long_proc_ent_ijr == longer_loop_ent_ijr:
|
|
571
|
+
break
|
|
572
|
+
del processed_ents[long_proc_ent_index]
|
|
573
|
+
# remove the one with larger loop
|
|
574
|
+
|
|
575
|
+
# find the shorter loop and remove it
|
|
576
|
+
for short_proc_ent_index, short_proc_ent in enumerate(processed_ents):
|
|
577
|
+
short_proc_ent_ijr = short_proc_ent[1:-1]
|
|
578
|
+
if short_proc_ent_ijr == shorter_loop_ent_ijr:
|
|
579
|
+
break
|
|
580
|
+
del processed_ents[short_proc_ent_index]
|
|
581
|
+
|
|
582
|
+
updated_ent = (short_proc_ent[0] + long_proc_ent[0], *short_proc_ent[1:-1], short_proc_ent[-1] + ';' + long_proc_ent[-1])
|
|
583
|
+
processed_ents += [updated_ent]
|
|
584
|
+
|
|
585
|
+
keep_track_of_larger_proc_ent.append(longer_loop_ent_ijr)
|
|
586
|
+
keep_track_of_shorter_proc_ent.append(shorter_loop_ent_ijr)
|
|
587
|
+
|
|
588
|
+
# STEP 3 FINAL SUMMARY
|
|
589
|
+
self.logger.info(f'\nSTEP 3 RESULTS:')
|
|
590
|
+
for ID, ents in ent_data.items():
|
|
591
|
+
Step3_QC_counter = 0
|
|
592
|
+
for ent in ents:
|
|
593
|
+
#print(ent)
|
|
594
|
+
Step3_QC_counter += ent[0]
|
|
595
|
+
|
|
596
|
+
self.logger.info(f' {ID}: {len(ents)} representative entanglements remaining (tracking {Step3_QC_counter} raw)')
|
|
597
|
+
# QC to ensure number of raw ents was preserved after step 3
|
|
598
|
+
if Step3_QC_counter != num_raw_ents[ID]:
|
|
599
|
+
raise ValueError(f'The number of tracked entaglements after Step 3 {Step3_QC_counter} != {num_raw_ents[ID]}')
|
|
600
|
+
|
|
601
|
+
### STEP 4 SPATIAL CLUSTERING ################################################################################################################
|
|
602
|
+
# Step 4 prep
|
|
603
|
+
self.logger.info(f'{"="*100}')
|
|
604
|
+
self.logger.info(f'STEP 4 PREP: Grouping Entanglements by Number and Chirality of Crossings')
|
|
605
|
+
self.logger.info(f'{"="*100}')
|
|
606
|
+
for ID, new_ents in ent_data.items():
|
|
607
|
+
|
|
608
|
+
for ent in new_ents:
|
|
609
|
+
|
|
610
|
+
number_of_crossings = len(ent[3:-1])
|
|
611
|
+
|
|
612
|
+
chiralites = [each_cr[0] for each_cr in ent[3:-1]]
|
|
613
|
+
|
|
614
|
+
ID_num_chirality_key = f"{ID}_{number_of_crossings}_{chiralites}"
|
|
615
|
+
|
|
616
|
+
full_entanglement_data[ID_num_chirality_key].append(ent)
|
|
617
|
+
|
|
618
|
+
self.logger.info(f'\nGrouping Summary:')
|
|
619
|
+
for group_key in sorted(full_entanglement_data.keys()):
|
|
620
|
+
ents = full_entanglement_data[group_key]
|
|
621
|
+
self.logger.info(f' {group_key}: {len(ents)} entanglements')
|
|
622
|
+
|
|
623
|
+
reset_counter = []
|
|
624
|
+
|
|
625
|
+
# Step 4
|
|
626
|
+
self.logger.info(f'{"="*100}')
|
|
627
|
+
self.logger.info(f'STEP 4: Primary Structure Clustering Within Each Group')
|
|
628
|
+
self.logger.info(f'{"="*100}')
|
|
629
|
+
for ID_num_chiral in full_entanglement_data.keys():
|
|
630
|
+
#print(ID_num_chiral)
|
|
631
|
+
#ID = ID_num_chiral.split("_")[0]
|
|
632
|
+
|
|
633
|
+
if ID not in reset_counter:
|
|
634
|
+
|
|
635
|
+
reset_counter.append(ID)
|
|
636
|
+
|
|
637
|
+
split_cluster_counter = 0
|
|
638
|
+
|
|
639
|
+
length_key = defaultdict(list)
|
|
640
|
+
loop_dist = defaultdict(list)
|
|
641
|
+
dups = []
|
|
642
|
+
clusters = {}
|
|
643
|
+
cluster_count = 0
|
|
644
|
+
|
|
645
|
+
pairwise_entanglements = list(itertools.combinations(full_entanglement_data[ID_num_chiral], 2))
|
|
646
|
+
|
|
647
|
+
self.logger.info(f'\n Group: {ID_num_chiral}')
|
|
648
|
+
self.logger.info(f' Total entanglements: {len(full_entanglement_data[ID_num_chiral])}')
|
|
649
|
+
self.logger.info(f' Pairwise comparisons: {len(pairwise_entanglements)}')
|
|
650
|
+
|
|
651
|
+
if pairwise_entanglements:
|
|
652
|
+
|
|
653
|
+
for i, pairwise_ent in enumerate(pairwise_entanglements):
|
|
654
|
+
|
|
655
|
+
dist = self.loop_distance(pairwise_ent[0], pairwise_ent[1])
|
|
656
|
+
|
|
657
|
+
if dist <= self.cut_off and pairwise_ent[0] not in dups and pairwise_ent[1] not in dups:
|
|
658
|
+
# 1. pair must be <= self.cut_off
|
|
659
|
+
# 2. the neighbor cannot be the next key and it cannot be captured by another key
|
|
660
|
+
|
|
661
|
+
loop_dist[pairwise_ent[0]].append(pairwise_ent[1])
|
|
662
|
+
dups.append(pairwise_ent[1])
|
|
663
|
+
|
|
664
|
+
key_list = list(loop_dist.keys())
|
|
665
|
+
|
|
666
|
+
for key in key_list:
|
|
667
|
+
|
|
668
|
+
length_key[len(loop_dist[key])].append(key)
|
|
669
|
+
|
|
670
|
+
# create clusters
|
|
671
|
+
|
|
672
|
+
while len(length_key.values()) > 0:
|
|
673
|
+
|
|
674
|
+
max_neighbor = max(length_key.keys())
|
|
675
|
+
|
|
676
|
+
selected_ent = random.choice(length_key[max_neighbor])
|
|
677
|
+
|
|
678
|
+
cluster = copy.deepcopy(loop_dist[selected_ent])
|
|
679
|
+
cluster.append(selected_ent)
|
|
680
|
+
|
|
681
|
+
clusters[cluster_count] = cluster
|
|
682
|
+
cluster_count += 1
|
|
683
|
+
|
|
684
|
+
length_key[max_neighbor].remove(selected_ent)
|
|
685
|
+
|
|
686
|
+
if len(length_key[max_neighbor]) == 0:
|
|
687
|
+
length_key.pop(max_neighbor)
|
|
688
|
+
|
|
689
|
+
# create single clusters
|
|
690
|
+
if clusters:
|
|
691
|
+
clusters_ijr_values = list(itertools.chain.from_iterable(list(clusters.values())))
|
|
692
|
+
else:
|
|
693
|
+
clusters_ijr_values = []
|
|
694
|
+
|
|
695
|
+
full_ent_values = np.asarray(full_entanglement_data[ID_num_chiral], dtype=object)
|
|
696
|
+
|
|
697
|
+
difference_ent = np.zeros(len(full_ent_values), dtype=bool)
|
|
698
|
+
|
|
699
|
+
for k, ijr in enumerate(full_ent_values):
|
|
700
|
+
|
|
701
|
+
if tuple(ijr) in clusters_ijr_values:
|
|
702
|
+
difference_ent[k] = True
|
|
703
|
+
else:
|
|
704
|
+
difference_ent[k] = False
|
|
705
|
+
|
|
706
|
+
i = np.unique(np.where(difference_ent == False)[0])
|
|
707
|
+
|
|
708
|
+
next_cluster_count = cluster_count
|
|
709
|
+
|
|
710
|
+
for single_cluster in full_ent_values[i]:
|
|
711
|
+
|
|
712
|
+
single_cluster_list = []
|
|
713
|
+
single_cluster_list.append(tuple(single_cluster))
|
|
714
|
+
|
|
715
|
+
clusters[next_cluster_count] = single_cluster_list
|
|
716
|
+
|
|
717
|
+
next_cluster_count += 1
|
|
718
|
+
|
|
719
|
+
# pick representative entanglement per cluster
|
|
720
|
+
self.logger.info(f' Primary structure clusters formed: {len(clusters)}')
|
|
721
|
+
for counter, ijr_values in clusters.items():
|
|
722
|
+
#print(f'\nCluster {counter} {ijr_values}')
|
|
723
|
+
|
|
724
|
+
# clusters contain many entanglements
|
|
725
|
+
if len(ijr_values) > 1:
|
|
726
|
+
|
|
727
|
+
ijr = np.asarray(ijr_values)
|
|
728
|
+
#print(f'cluster ijr:\n{ijr}')
|
|
729
|
+
|
|
730
|
+
cr_values = np.asarray([[int(r_value[0][1:])] for r_value in ijr[:, 3:-1]])
|
|
731
|
+
#print(f'cr_values: {cr_values}')
|
|
732
|
+
|
|
733
|
+
median_cr = compute_geometric_median(cr_values).median
|
|
734
|
+
#print(f'median_cr: {median_cr}')
|
|
735
|
+
|
|
736
|
+
distances = cdist(cr_values, [median_cr])
|
|
737
|
+
|
|
738
|
+
minimum_distances_i = np.where(distances == min(distances))[0]
|
|
739
|
+
#print(f'minimum_distances_i: {minimum_distances_i}')
|
|
740
|
+
|
|
741
|
+
possible_cand = ijr[minimum_distances_i]
|
|
742
|
+
#print(f'possible_cand:\n{possible_cand}')
|
|
743
|
+
|
|
744
|
+
loop_lengths = np.abs(possible_cand[:, 1].astype(int) - possible_cand[:, 2].astype(int))
|
|
745
|
+
#print(f'loop_lengths: {loop_lengths}')
|
|
746
|
+
|
|
747
|
+
smallest_loop_length = min(loop_lengths)
|
|
748
|
+
#print(f'smallest_loop_length: {smallest_loop_length}')
|
|
749
|
+
|
|
750
|
+
num_nc_in_cluster = np.sum([int(n[0]) for n in ijr_values])
|
|
751
|
+
#print(f'num_nc_in_cluster: {num_nc_in_cluster}')
|
|
752
|
+
|
|
753
|
+
loops = ';'.join([n[-1] for n in ijr_values])
|
|
754
|
+
#print(f'loops: {loops}')
|
|
755
|
+
|
|
756
|
+
#rep_entanglement = possible_cand[random.choice(np.where(smallest_loop_length == loop_lengths)[0])]
|
|
757
|
+
rep_entanglement = possible_cand[random.choice(np.where(smallest_loop_length == loop_lengths))[0]]
|
|
758
|
+
rep_entanglement = [str(num_nc_in_cluster), *rep_entanglement[1:-1], loops]
|
|
759
|
+
#rep_ID_ent[f"{ID}_{split_cluster_counter}"].append(rep_entanglement)
|
|
760
|
+
rep_ID_ent[(ID, split_cluster_counter)].append(rep_entanglement)
|
|
761
|
+
|
|
762
|
+
# clusters with a single entnalgement
|
|
763
|
+
else:
|
|
764
|
+
#rep_ID_ent[f"{ID}_{split_cluster_counter}"].append(ijr_values[0])
|
|
765
|
+
rep_ID_ent[(ID, split_cluster_counter)].append(ijr_values[0])
|
|
766
|
+
if counter == list(clusters.keys())[-1]: # Print only for last cluster to avoid clutter
|
|
767
|
+
num_single = sum(1 for c_vals in clusters.values() if len(c_vals) == 1)
|
|
768
|
+
self.logger.info(f' Single-entanglement clusters: {num_single}')
|
|
769
|
+
|
|
770
|
+
split_cluster_counter += 1
|
|
771
|
+
|
|
772
|
+
## QC Step 4 results
|
|
773
|
+
self.logger.info(f'\n{"="*100}')
|
|
774
|
+
self.logger.info(f'STEP 4 FINAL RESULTS: Primary Structure Clustering Summary')
|
|
775
|
+
self.logger.info(f'{"="*100}')
|
|
776
|
+
num_raw_ents_FINAL = {}
|
|
777
|
+
for ID_counter, ijrs in rep_ID_ent.items():
|
|
778
|
+
#print(ID_counter, ijrs)
|
|
779
|
+
ID, counter = ID_counter
|
|
780
|
+
#print(ID_counter, ID, counter, ijrs)
|
|
781
|
+
|
|
782
|
+
if ID not in num_raw_ents_FINAL:
|
|
783
|
+
num_raw_ents_FINAL[ID] = 0
|
|
784
|
+
|
|
785
|
+
for ijr in ijrs:
|
|
786
|
+
num_nc = int(ijr[0])
|
|
787
|
+
num_raw_ents_FINAL[ID] += num_nc
|
|
788
|
+
|
|
789
|
+
## check the final tracking of raw ents
|
|
790
|
+
for ID, count in num_raw_ents.items():
|
|
791
|
+
final_count = num_raw_ents_FINAL[ID]
|
|
792
|
+
num_clusters = len([ijrs for (c_id, _), ijrs in rep_ID_ent.items() if c_id == ID])
|
|
793
|
+
self.logger.info(f'{ID}: {count} raw → {final_count} raw in {num_clusters} final clusters')
|
|
794
|
+
if count != final_count:
|
|
795
|
+
raise ValueError(f'The FINAL # of raw ents {final_count} != the starting {count} for ID {ID}')
|
|
796
|
+
|
|
797
|
+
### STEP 5 OUTPUT FILE ################################################################################################################
|
|
798
|
+
# Step 5
|
|
799
|
+
self.logger.info(f'{"="*100}')
|
|
800
|
+
self.logger.info(f'STEP 5: Writing Output File')
|
|
801
|
+
self.logger.info(f'{"="*100}')
|
|
802
|
+
|
|
803
|
+
## set up the outdir for this calculation
|
|
804
|
+
#outdir = f"{os.getcwd()}/{outdir}"
|
|
805
|
+
if not os.path.isdir(outdir):
|
|
806
|
+
os.mkdir(f"{outdir}")
|
|
807
|
+
self.logger.info(f"Creating directory: {outdir}")
|
|
808
|
+
|
|
809
|
+
outfilepath = os.path.join(f'{outdir}', f'{outfile}')
|
|
810
|
+
|
|
811
|
+
with open(outfilepath, "w") as f:
|
|
812
|
+
|
|
813
|
+
f.write(f'ID|chain|i|j|crossingsN|crossingsC|gn|gc|GLNn|GLNc|TLNn|TLNc|num_contacts|contacts|CCBond\n')
|
|
814
|
+
for ID_counter, ijrs in rep_ID_ent.items():
|
|
815
|
+
|
|
816
|
+
ID, counter = ID_counter
|
|
817
|
+
chain = chain_info.get(ID, 'A') # Get chain for this ID, default to 'A'
|
|
818
|
+
|
|
819
|
+
for ijr in ijrs:
|
|
820
|
+
|
|
821
|
+
new_ijr = (int(ijr[1]), int(ijr[2]), *list(ijr[3:-1]))
|
|
822
|
+
|
|
823
|
+
num_nc = int(ijr[0])
|
|
824
|
+
|
|
825
|
+
gn, gc, GLNn, GLNc, TLNn, TLNc, crossingsN_stored, crossingsC_stored = entanglement_partial_g_data[new_ijr]
|
|
826
|
+
gn = float(gn)
|
|
827
|
+
gc = float(gc)
|
|
828
|
+
GLNn = int(GLNn)
|
|
829
|
+
GLNc = int(GLNc)
|
|
830
|
+
# Handle NaN/empty TLN values: convert to int only if not NaN
|
|
831
|
+
TLNn = np.nan if pd.isna(TLNn) else int(TLNn)
|
|
832
|
+
TLNc = np.nan if pd.isna(TLNc) else int(TLNc)
|
|
833
|
+
|
|
834
|
+
# Separate crossings into N and C terminal
|
|
835
|
+
all_crossings = list(ijr[3:-1])
|
|
836
|
+
crossingsN = []
|
|
837
|
+
crossingsC = []
|
|
838
|
+
i_val = int(ijr[1])
|
|
839
|
+
j_val = int(ijr[2])
|
|
840
|
+
for cross in all_crossings:
|
|
841
|
+
cross_resid = int(cross[1:])
|
|
842
|
+
if cross_resid < i_val:
|
|
843
|
+
crossingsN.append(cross)
|
|
844
|
+
elif cross_resid > j_val:
|
|
845
|
+
crossingsC.append(cross)
|
|
846
|
+
|
|
847
|
+
crossingsN_str = ','.join(crossingsN) if crossingsN else ''
|
|
848
|
+
crossingsC_str = ','.join(crossingsC) if crossingsC else ''
|
|
849
|
+
|
|
850
|
+
## check for disulfide bonds
|
|
851
|
+
CCBond_flag = False
|
|
852
|
+
for CCBond in CCBonds:
|
|
853
|
+
check1 = f'{CCBond[0]}-{CCBond[1]}'
|
|
854
|
+
check2 = f'{CCBond[1]}-{CCBond[0]}'
|
|
855
|
+
if check1 in ijr[-1] or check2 in ijr[-1]:
|
|
856
|
+
CCBond_flag = True
|
|
857
|
+
|
|
858
|
+
line = f"{ID}|{chain}|{int(ijr[1])}|{int(ijr[2])}|{crossingsN_str}|{crossingsC_str}|{gn:.5f}|{gc:.5f}|{GLNn}|{GLNc}|{TLNn}|{TLNc}|{num_nc}|{ijr[-1]}|{CCBond_flag}"
|
|
859
|
+
#print(line)
|
|
860
|
+
f.write(f"{line}\n")
|
|
861
|
+
self.logger.info(f'SAVED: {outfilepath}')
|
|
862
|
+
outdf = pd.read_csv(outfilepath, sep='|')
|
|
863
|
+
|
|
864
|
+
# FINAL CLUSTERING SUMMARY
|
|
865
|
+
self.logger.info(f'\n{"="*100}')
|
|
866
|
+
self.logger.info(f'CLUSTERING COMPLETE: Final Summary')
|
|
867
|
+
self.logger.info(f'{"="*100}')
|
|
868
|
+
self.logger.info(f'Total raw entanglements processed: {sum(num_raw_ents.values())}')
|
|
869
|
+
self.logger.info(f'Total final representative entanglements: {len(outdf)}')
|
|
870
|
+
# print(f'Compression ratio: {sum(num_raw_ents.values())/len(outdf):.2f}x (raw → final)')
|
|
871
|
+
self.logger.info(f'Clustering by organism: {self.organism}')
|
|
872
|
+
self.logger.info(f'Spatial distance cutoff: {self.cut_off}')
|
|
873
|
+
self.logger.info(f'Output file: {outfilepath}')
|
|
874
|
+
self.logger.info(f'{"="*100}\n')
|
|
875
|
+
|
|
876
|
+
return {'outfile':outfilepath, 'ent_result':outdf}
|
|
877
|
+
##########################################################################################################################################################
|
|
878
|
+
##########################################################################################################################################################
|
|
879
|
+
##########################################################################################################################################################
|
|
880
|
+
|
|
881
|
+
|
|
882
|
+
##########################################################################################################################################################
|
|
883
|
+
##########################################################################################################################################################
|
|
884
|
+
class ClusterNonNativeEntanglements:
|
|
885
|
+
"""
|
|
886
|
+
Class to calculate Non-native entanglements given either a file path to an entanglement file or an entanglement object
|
|
887
|
+
"""
|
|
888
|
+
|
|
889
|
+
##########################################################################################################################################################
|
|
890
|
+
def __init__(self, trajnum2pklfile_path:str, traj_dir_prefix:str='./', outdir:str='./ClusterNonNativeEntanglements/', log_level:int=logging.INFO, logdir:str=None, nproc:int=1) -> None:
    """
    Constructor for the ClusterNonNativeEntanglements class.

    Sets the clustering hyper-parameters, sets process-wide matplotlib
    rcParams, creates the logger, reads the trajectory manifest, verifies that
    every listed pickle file exists, and creates the output directory.

    Parameters
    ----------
    trajnum2pklfile_path : str
        Path to a CSV manifest. It must contain a ``pklfile`` column listing
        the entanglement pickle files to cluster (``extract_traj_number`` also
        reads a ``trajnum`` column from it).
    traj_dir_prefix : str, optional
        Stored as ``self.traj_dir_prefix``; not used by the constructor itself.
    outdir : str, optional
        Output directory; created (including missing parent directories) when
        it does not exist.
    log_level : int, optional
        Logging level passed to ``setup_logger``.
    logdir : str, optional
        Directory passed to ``setup_logger``; ``outdir`` is used when None.
    nproc : int, optional
        Number of parallel workers; values below 1 are clamped to 1.

    Notes
    -----
    If the manifest has no ``pklfile`` column, or any listed pickle file is
    missing, the error is logged and the process exits with status 1.
    """
    # Clustering hyper-parameters. cluster_chg_ent uses index 0 for the two
    # crossing-residue stages and index 1 for the loop (native contact) stage;
    # index 2 is presumably the cross-contamination stage - confirm.
    self.classify_key = ['topoly_linking_number']
    self.cluster_method = ['average', 'average', 'average']
    # A last-stage cutoff of 0.6 allows contamination; 0.1 disallows it.
    self.cluster_dist_cutoff = [20, 1.0, 0.1] # No contamination
    # Compared against reduced_data.nbytes ** 2 in do_clustering to decide chunking.
    self.memory_cutoff = 6.4e10 # 64 Gb
    self.max_plot_samples = 1000

    # Plot styling; note these assignments change matplotlib's global rcParams.
    matplotlib.rcParams['axes.labelsize'] = 'small'
    matplotlib.rcParams['axes.linewidth'] = 1
    matplotlib.rcParams['lines.markersize'] = 4
    matplotlib.rcParams['xtick.major.width'] = 1
    matplotlib.rcParams['ytick.major.width'] = 1
    matplotlib.rcParams['xtick.labelsize'] = 'x-small'
    matplotlib.rcParams['ytick.labelsize'] = 'x-small'
    matplotlib.rcParams['legend.fontsize'] = 'x-small'
    matplotlib.rcParams['figure.dpi'] = 600

    self.nproc = max(1, int(nproc))
    self.logger = setup_logger('ClusterNonNativeEntanglements', outdir=logdir if logdir is not None else outdir, log_level=log_level)

    self.traj_dir_prefix = traj_dir_prefix

    ## Load the dataframe that maps trajectory numbers to pkl file paths
    self.trajnum2pklfile_path = trajnum2pklfile_path
    self.trajnum2pklfile = pd.read_csv(self.trajnum2pklfile_path)

    ## Extract pkl file paths from the manifest (source of truth)
    if 'pklfile' not in self.trajnum2pklfile.columns:
        self.logger.error('Error: trajnum2pklfile CSV must contain a "pklfile" column')
        # Non-zero status so callers/schedulers see the failure (bare sys.exit() reports success).
        sys.exit(1)

    self.ent_data_file_list = self.trajnum2pklfile['pklfile'].tolist()

    # Verify all pkl files exist
    missing_files = [f for f in self.ent_data_file_list if not os.path.isfile(f)]
    if missing_files:
        self.logger.error(f'Error: {len(missing_files)} pkl files not found:')
        for f in missing_files:
            self.logger.error(f'  {f}')
        sys.exit(1)

    self.logger.info(f'FOUND {len(self.ent_data_file_list)} .pkl files to cluster from manifest')

    ## Set up the outdir for this calculation
    self.outdir = outdir
    if not os.path.isdir(self.outdir):
        # makedirs also creates missing parents (os.mkdir would raise for nested paths).
        os.makedirs(self.outdir, exist_ok=True)
        self.logger.info(f"Creating directory: {self.outdir}")
|
|
947
|
+
##########################################################################################################################################################
|
|
948
|
+
|
|
949
|
+
##########################################################################################################################################################
|
|
950
|
+
def save_pickle(self, filename, mode, data, protocol=4):
    """Serialize ``data`` into ``filename`` with ``pickle.dump``.

    Parameters
    ----------
    filename : str
        Destination path.
    mode : str
        File mode handed unchanged to ``open``; it must be a binary mode
        because pickle writes bytes.
    data : object
        Any picklable object.
    protocol : int, optional
        Pickle protocol version (default 4).
    """
    with open(filename, mode) as out_stream:
        pickle.dump(obj=data, file=out_stream, protocol=protocol)
|
|
953
|
+
##########################################################################################################################################################
|
|
954
|
+
|
|
955
|
+
##########################################################################################################################################################
|
|
956
|
+
def load_pickle(self, filename, start_frame=None, end_frame=None):
    """Load a pickle file made of consecutive dict chunks and optionally filter frames.

    The file is read with repeated ``pickle.load`` calls until end of file and
    every chunk is merged into one dict (later chunks overwrite duplicate keys).

    Parameters
    ----------
    filename : str
        Path to pickle file
    start_frame : int, optional
        Minimum frame index to keep (inclusive)
    end_frame : int, optional
        Maximum frame index to keep (inclusive)

    Returns
    -------
    dict
        Dictionary with frame keys and 'ref' key. Filtered to frame range if
        specified; the 'ref' entry is always kept.

    Notes
    -----
    Large unfiltered chunks are explicitly deleted and ``gc.collect()`` is
    called before returning to ensure timely garbage collection, especially
    important when multiple threads load in parallel.

    Security: ``pickle.load`` can execute arbitrary code; only load files from
    trusted sources.
    """
    import gc

    filter_frames = start_frame is not None or end_frame is not None

    def _keep(frame_key):
        # 'ref' is always kept; other keys are assumed to be comparable frame
        # indices (ints) - TODO confirm against the writer of these files.
        if frame_key == "ref":
            return True
        return ((start_frame is None or frame_key >= start_frame)
                and (end_frame is None or frame_key <= end_frame))

    data_dict = {}
    self.logger.debug(f'Loading pickle file: {filename}')
    with open(filename, 'rb') as fr:
        while True:
            try:
                chunk = pickle.load(fr)
            except EOFError:
                break  # no more chunks in the file
            self.logger.debug(f"Loaded chunk with size: {len(chunk)}")
            if filter_frames:
                self.logger.debug(f"Filtering chunk for frames between {start_frame} and {end_frame}")
                filtered_chunk = {k: v for k, v in chunk.items() if _keep(k)}
                del chunk  # CRITICAL: Explicitly free large unfiltered dict before update
                chunk = filtered_chunk
                del filtered_chunk
                self.logger.debug(f"Loaded filtered chunk with size: {len(chunk)}")
            data_dict.update(chunk)
            del chunk  # Free memory after update

    # Exclude the 'ref' key from the frame count only when it is actually present.
    n_frames = len(data_dict) - (1 if "ref" in data_dict else 0)
    self.logger.debug(f'Total frames loaded: {n_frames}')

    # Force garbage collection to prevent memory accumulation in parallel loads
    gc.collect()
    return data_dict
|
|
1010
|
+
##########################################################################################################################################################
|
|
1011
|
+
|
|
1012
|
+
##########################################################################################################################################################
|
|
1013
|
+
def extract_traj_number(self, f):
    """Return the ``trajnum`` recorded for pickle path ``f`` in the manifest.

    Parameters
    ----------
    f : str
        Pickle path exactly as it appears in the manifest's ``pklfile`` column.

    Returns
    -------
    The ``trajnum`` entry of the first matching manifest row (taken from
    ``.values``, so a NumPy scalar).

    Raises
    ------
    ValueError
        If no manifest row has ``pklfile == f``; the error is also logged.
    """
    manifest = self.trajnum2pklfile
    matching_rows = manifest[manifest['pklfile'] == f]
    if not matching_rows.empty:
        return matching_rows['trajnum'].values[0]
    message = f'Error: {f} not found in {self.trajnum2pklfile_path}'
    self.logger.error(message)
    raise ValueError(message)
|
|
1020
|
+
##########################################################################################################################################################
|
|
1021
|
+
|
|
1022
|
+
##########################################################################################################################################################
|
|
1023
|
+
def pdist_loop_overlap(self, data_array_1, data_array_2):
    """Pairwise loop-overlap distance between two 2-D arrays of loops.

    Row ``a`` of ``data_array_1`` is paired with row ``b`` of
    ``data_array_2``. Columns 0-1 of the pair are read as (start, end) of the
    first loop and columns 2-3 as (start, end) of the second, so both inputs
    are expected to have exactly two columns.

    distance = (largest endpoint - smallest endpoint + 1) / (len_1 + len_2)

    Pairs where one loop contains the other, i.e.
    (start_2 - start_1) * (end_1 - end_2) >= 0, get the fixed value 0.5.

    Returns
    -------
    numpy.ndarray
        Float array of shape (len(data_array_1), len(data_array_2)).
    """
    n_left, n_right = data_array_1.shape[0], data_array_2.shape[0]
    # pairs[a, b] is row a of data_array_1 followed by row b of data_array_2.
    pairs = np.concatenate(
        (
            np.broadcast_to(data_array_1[:, np.newaxis, :], (n_left, n_right, data_array_1.shape[1])),
            np.broadcast_to(data_array_2[np.newaxis, :, :], (n_left, n_right, data_array_2.shape[1])),
        ),
        axis=-1,
    )
    start_1, end_1, start_2, end_2 = (pairs[..., col] for col in range(4))

    covered_span = pairs.max(axis=-1) - pairs.min(axis=-1) + 1
    overlap_dist = covered_span / (end_1 - start_1 + end_2 - start_2)

    # One loop nested inside the other -> fixed distance
    overlap_dist[(start_2 - start_1) * (end_1 - end_2) >= 0] = 0.5
    return overlap_dist
|
|
1032
|
+
##########################################################################################################################################################
|
|
1033
|
+
|
|
1034
|
+
##########################################################################################################################################################
|
|
1035
|
+
def pdist_thread_deviation(self, data_array_1, data_array_2):
    """Pairwise crossing-position deviation between two 2-D arrays.

    do_clustering builds each row as [reference-crossing median, crossing
    median] for one terminus, using -1 when there is no crossing. For each
    pair of rows and each of the two columns, the deviation is the absolute
    difference. It is replaced by 10 when the product of the two values is
    negative, i.e. one side is -1 and the other is a positive residue. The
    pair distance is the larger of the two column deviations.

    Returns
    -------
    numpy.ndarray
        Array of shape (len(data_array_1), len(data_array_2)).
    """
    n_left, n_right = data_array_1.shape[0], data_array_2.shape[0]
    pairs = np.concatenate(
        (
            np.broadcast_to(data_array_1[:, None, :], (n_left, n_right, data_array_1.shape[1])),
            np.broadcast_to(data_array_2[None, :, :], (n_left, n_right, data_array_2.shape[1])),
        ),
        axis=-1,
    )
    left_vals = pairs[..., 0:2]
    right_vals = pairs[..., 2:4]

    deviation = np.abs(right_vals - left_vals)
    # Crossing present on one side only -> fixed deviation of 10
    deviation[right_vals * left_vals < 0] = 10
    return np.max(deviation, axis=-1)
|
|
1045
|
+
##########################################################################################################################################################
|
|
1046
|
+
|
|
1047
|
+
##########################################################################################################################################################
|
|
1048
|
+
def pdist_cross_contamination(self, data_array_1, data_array_2):
    """Pairwise cross-contamination distance between two sets of entanglements.

    Each row looks like [nc_1, nc_2, cross_N1, cross_N2, ..., cross_C1, cross_C2, ...]:
    the loop (nc_1, nc_2) followed by crossing residues padded with -1.

    For a loop (i, j) and a crossing c the contamination score is
    min(j - c, c - i) / (j - i). It is positive only when i < c < j; any score
    <= 0 or >= 1 is set to 0. The distance for a pair of rows is the maximum
    score over both directions: crossings of row b inside the loop of row a,
    and crossings of row a inside the loop of row b.

    Returns
    -------
    numpy.ndarray
        Array of shape (len(data_array_1), len(data_array_2)).
    """
    def _score(loop_start, loop_end, crossings):
        # Relative distance of each crossing from the nearer loop end
        score = np.minimum(loop_end - crossings, crossings - loop_start) / (loop_end - loop_start)
        score[score <= 0] = 0
        score[score >= 1] = 0
        return score

    # Crossings of set 2 contaminating loops of set 1 -> (n1, n2, n_cross_2)
    score_into_1 = _score(
        data_array_1[:, 0, None, None],
        data_array_1[:, 1, None, None],
        data_array_2[None, :, 2:],
    )
    # Crossings of set 1 contaminating loops of set 2 -> (n1, n2, n_cross_1)
    score_into_2 = _score(
        data_array_2[None, :, 0, None],
        data_array_2[None, :, 1, None],
        data_array_1[:, None, 2:],
    )
    return np.concatenate((score_into_1, score_into_2), axis=-1).max(axis=-1)
|
|
1084
|
+
##########################################################################################################################################################
|
|
1085
|
+
|
|
1086
|
+
##########################################################################################################################################################
|
|
1087
|
+
def agglomerative_clustering(self, dist, cluster_method, cluster_dist_cutoff, num_perm):
    """Cluster items from a square distance matrix with scipy hierarchical linkage.

    If every distance is zero, the matrix is linked once in its original
    order. Otherwise it is re-ordered with ``max(1, num_perm)`` random
    permutations drawn from the global ``np.random`` state. The linkage tree
    whose cophenetic distances best reproduce the permuted condensed
    distances (smallest normalized sum of squared differences) is kept.

    Parameters
    ----------
    dist : numpy.ndarray
        (n, n) distance matrix. It is expected to be symmetric, because
        ``squareform(..., checks=False)`` reads only the upper triangle of
        each (permuted) matrix.
    cluster_method : str
        ``scipy.cluster.hierarchy.linkage`` method.
    cluster_dist_cutoff : float
        Threshold for ``fcluster(..., criterion='distance')``.
    num_perm : int
        Number of random orderings to try.

    Returns
    -------
    numpy.ndarray
        1-based cluster labels, one per item, in the original item order.
    """
    n_items = dist.shape[0]
    condensed = squareform(dist, checks=False)

    if np.sum(condensed**2) == 0:
        best_tree = linkage(condensed, method=cluster_method)
        best_order = np.arange(n_items)
    else:
        # permuCLUSTER: keep the ordering whose tree best preserves the input distances
        best_tree, best_order = None, None
        lowest_error = np.inf
        for _ in range(max(1, num_perm)):
            order = np.random.permutation(np.arange(n_items))
            condensed = squareform(dist[order, :][:, order], checks=False)
            tree = linkage(condensed, method=cluster_method)
            coph = cophenet(tree)
            error = np.sum((condensed - coph)**2) / np.sum(condensed**2)
            if error < lowest_error:
                lowest_error, best_tree, best_order = error, tree, order

    labels_in_order = fcluster(best_tree, cluster_dist_cutoff, criterion='distance')
    # best_order[pos] is the original index placed at position pos; argsort inverts it.
    return labels_in_order[np.argsort(best_order)]
|
|
1113
|
+
##########################################################################################################################################################
|
|
1114
|
+
|
|
1115
|
+
##########################################################################################################################################################
|
|
1116
|
+
def do_clustering(self, map_list, chg_ent_fingerprint_list, key, pdist_fun, cluster_method, cluster_dist_cutoff, num_perm=100):
    """Cluster the entanglements referenced by ``map_list`` using one fingerprint feature.

    Parameters
    ----------
    map_list : list
        Index records. The fingerprint of record ``m`` is
        ``chg_ent_fingerprint_list[m[0]][m[1]][tuple(m[2:4])]``.
    chg_ent_fingerprint_list : sequence
        Nested container of fingerprint dicts, indexed as described above.
    key : str
        Selects the feature:
        - contains ``'crossing_resid'``: the key ends in ``_<terminal index>``,
          e.g. ``'crossing_resid_0'``. Uses the median reference and median
          current crossing residue of that terminal, or -1 when there is none.
        - contains ``'native_contact'``: uses ``fingerprint['native_contact']``.
        - contains ``'cross_contamination'``: uses the native contact followed
          by all crossing residues, padded with -1. The matrix is first
          divisively split, then the groups are clustered agglomeratively.
    pdist_fun : callable
        ``pdist_fun(a, b)`` returning the ``(len(a), len(b))`` distance matrix.
    cluster_method : str
        Linkage method ('single', 'complete', 'average', ...).
    cluster_dist_cutoff : float
        Distance threshold for cutting the tree (and for the divisive split).
    num_perm : int, optional
        Number of random orderings tried by ``agglomerative_clustering``.

    Returns
    -------
    list of list
        One list of ``map_list`` records per cluster. The input list itself is
        returned as the single cluster when all feature rows are identical, and
        ``[]`` is returned for an empty ``map_list``.

    Raises
    ------
    ValueError
        If ``key`` names none of the supported features.
    """
    # Validate the key up front; the precedence matches the feature-building order below.
    if 'crossing_resid' in key:
        feature = 'crossing_resid'
    elif 'native_contact' in key:
        feature = 'native_contact'
    elif 'cross_contamination' in key:
        feature = 'cross_contamination'
    else:
        message = 'Error: Unknown key specified for do_clustering(), %s'%(key)
        self.logger.error(message)
        raise ValueError(message)
    is_cross_contamination = 'cross_contamination' in key

    if len(map_list) == 0:
        return []

    fingerprints = [chg_ent_fingerprint_list[m[0]][m[1]][tuple(m[2:4])] for m in map_list]

    # Longest crossing list, so every row can be padded to the same width
    max_n_cross = 1
    if is_cross_contamination:
        for fingerprint in fingerprints:
            for c in fingerprint['crossing_resid']:
                max_n_cross = max(max_n_cross, len(c))

    # Prepare clustering data for distance calculation
    ter_idx = int(key.split('_')[-1]) if feature == 'crossing_resid' else None
    data = []
    for fingerprint in fingerprints:
        if feature == 'crossing_resid':
            sides = (fingerprint['ref_crossing_resid'][ter_idx], fingerprint['crossing_resid'][ter_idx])
            data.append([-1 if len(c) == 0 else np.median(c) for c in sides])
        elif feature == 'native_contact':
            data.append(fingerprint['native_contact'])
        else:
            padded = []
            for c in fingerprint['crossing_resid']:
                padded.extend(c[k] if k < len(c) else -1 for k in range(max_n_cross))
            data.append(fingerprint['native_contact'] + padded)
    data = np.array(data)

    # Reduce data size (saving memory usage) by combining duplicated data points
    reduced_data = np.unique(data, axis=0)

    # If all data are the same, group them into a single cluster and return
    if len(reduced_data) == 1:
        return [map_list]

    # For each unique row, the ascending list of original row indices with that value
    rows_by_value = {}
    for row_idx, row in enumerate(data.tolist()):
        rows_by_value.setdefault(tuple(row), []).append(row_idx)
    reduced_data_map = [rows_by_value[tuple(row)] for row in reduced_data.tolist()]

    # Chunk the distance computation when this heuristic memory estimate reaches memory_cutoff
    n_rows = len(reduced_data)
    if reduced_data.nbytes ** 2 >= self.memory_cutoff:
        n_chunk = int(np.ceil(reduced_data.nbytes / np.sqrt(self.memory_cutoff/2)))
        len_chunk = int(np.ceil(n_rows / n_chunk))
        dist = np.zeros((n_rows, n_rows))
        for i in range(n_chunk):
            i_1, i_2 = i*len_chunk, min((i+1)*len_chunk, n_rows)
            for j in range(n_chunk):
                j_1, j_2 = j*len_chunk, min((j+1)*len_chunk, n_rows)
                dist[i_1:i_2, j_1:j_2] = pdist_fun(reduced_data[i_1:i_2], reduced_data[j_1:j_2])
    else:
        dist = pdist_fun(reduced_data, reduced_data)

    if is_cross_contamination:
        # Divisive pre-split: repeatedly move members that are at/above the cutoff
        # from some later member of the current group into a new group.
        groups = [list(np.arange(dist.shape[0]))]
        while True:
            current = groups[-1]
            snapshot = list(current)
            peeled = []
            for pos in range(len(snapshot) - 1):
                if np.any(dist[snapshot[pos], snapshot[pos+1:]] >= cluster_dist_cutoff):
                    current.remove(snapshot[pos])
                    peeled.append(snapshot[pos])
            if not peeled:
                break
            groups.append(peeled)

        # Agglomerative clustering of the groups
        group_dist_fun = {'single': np.min, 'complete': np.max}.get(cluster_method, np.mean)
        if len(groups) > 1:
            group_dist = np.zeros((len(groups), len(groups)))
            for gi in range(len(groups) - 1):
                for gj in range(gi + 1, len(groups)):
                    value = group_dist_fun(dist[groups[gi], :][:, groups[gj]])
                    # Fill both triangles: agglomerative_clustering permutes rows/columns,
                    # so an upper-triangular matrix would feed spurious zeros to linkage.
                    group_dist[gi, gj] = value
                    group_dist[gj, gi] = value
            group_labels = self.agglomerative_clustering(group_dist, cluster_method, cluster_dist_cutoff, num_perm)
        else:
            group_labels = np.array([1], dtype=int)
        cluster_id_list = np.zeros(dist.shape[0], dtype=int)
        for group_idx, members in enumerate(groups):
            cluster_id_list[members] = group_labels[group_idx]
    else:
        # Do agglomerative clustering
        cluster_id_list = self.agglomerative_clustering(dist, cluster_method, cluster_dist_cutoff, num_perm)
    n_cluster = np.max(cluster_id_list)

    # Back-map unique-row clusters to the original map_list records
    map_array = np.array(map_list)
    cluster_data = []
    for cluster_id in range(1, n_cluster + 1):
        members = np.where(cluster_id_list == cluster_id)[0]
        original_rows = sorted(row for k in members for row in reduced_data_map[k])
        cluster_data.append(map_array[original_rows].tolist())
    return cluster_data
|
|
1235
|
+
##########################################################################################################################################################
|
|
1236
|
+
|
|
1237
|
+
##########################################################################################################################################################
|
|
1238
|
+
def cluster_chg_ent(self, chg_ent_keyword_dict, chg_ent_fingerprint_list,
                    cluster_method=('average', 'average', 'average'),
                    cluster_dist_cutoff=(20, 1.0, 0.1)):
    """Hierarchically cluster the changes of entanglement of every keyword.

    The entries of each keyword are clustered in four nested stages with
    ``self.do_clustering``:

    1. ``'crossing_resid_0'`` with ``self.pdist_thread_deviation``,
    2. ``'crossing_resid_1'`` with ``self.pdist_thread_deviation``,
    3. ``'native_contact'`` with ``self.pdist_loop_overlap``,
    4. ``'cross_contamination'`` with ``self.pdist_cross_contamination``.

    Stages 1 and 2 both use ``cluster_method[0]`` / ``cluster_dist_cutoff[0]``,
    stage 3 uses index 1 and stage 4 index 2. Keywords are processed
    concurrently in a thread pool of ``min(self.nproc, n_keywords)`` workers.

    Parameters
    ----------
    chg_ent_keyword_dict : dict
        Keyword string -> list of entries handed to ``do_clustering``.
    chg_ent_fingerprint_list : list
        Forwarded unchanged to every ``do_clustering`` call.
    cluster_method, cluster_dist_cutoff : sequence of length 3
        Linkage method / distance cutoff per stage as described above.

    Returns
    -------
    tuple(dict, dict)
        ``cluster_data[key]``: list of final (stage 4) clusters of ``key``.
        ``cluster_tree[key]``: one list per stage 1-3; each holds one group per
        cluster of that stage, containing the indices (into
        ``cluster_data[key]``) of the final clusters that descend from it.
        An empty keyword result gives ``[[], [], []]``.
    """
    cluster_data_keys = sorted(chg_ent_keyword_dict.keys())
    cluster_data = {key: [] for key in cluster_data_keys}
    cluster_tree = {key: [] for key in cluster_data_keys}
    n_tree_levels = 3  # stages 1-3 are recorded in the tree

    def _process_key(key):
        """Run the 4-level hierarchical clustering pipeline for one keyword.

        All ``do_clustering`` calls are independent across keywords, so this
        function is run in a ThreadPoolExecutor. numpy/scipy release the GIL
        during heavy array operations, which gives real parallelism.

        NOTE: per the original author, ``agglomerative_clustering`` (not shown
        here) draws from the global numpy random state when ``self.nproc > 1``,
        so results may differ between runs with multiple workers.
        """
        map_list = chg_ent_keyword_dict[key]
        local_cluster_data = []
        # For every final cluster: [stage-1 group, stage-2 group, stage-3 group]
        backtrace_idx_list = []
        idx_1 = idx_2 = idx_3 = 0

        # First clustering based on N-ter crossing residues
        N_cr_cluster_data = self.do_clustering(
            map_list, chg_ent_fingerprint_list, 'crossing_resid_0',
            self.pdist_thread_deviation, cluster_method[0], cluster_dist_cutoff[0])

        for map_list_N_cr in N_cr_cluster_data:
            # Second clustering based on C-ter crossing residues
            C_cr_cluster_data = self.do_clustering(
                map_list_N_cr, chg_ent_fingerprint_list, 'crossing_resid_1',
                self.pdist_thread_deviation, cluster_method[0], cluster_dist_cutoff[0])

            for map_list_C_cr in C_cr_cluster_data:
                # Third clustering based on loop
                nc_cluster_data = self.do_clustering(
                    map_list_C_cr, chg_ent_fingerprint_list, 'native_contact',
                    self.pdist_loop_overlap, cluster_method[1], cluster_dist_cutoff[1])

                for map_list_nc in nc_cluster_data:
                    # Fourth clustering based on cross contamination
                    final_cluster_data = self.do_clustering(
                        map_list_nc, chg_ent_fingerprint_list, 'cross_contamination',
                        self.pdist_cross_contamination, cluster_method[2], cluster_dist_cutoff[2])
                    for final_cluster in final_cluster_data:
                        local_cluster_data.append(final_cluster)
                        backtrace_idx_list.append([idx_1, idx_2, idx_3])
                    idx_3 += 1
                idx_2 += 1
            idx_1 += 1

        if backtrace_idx_list:
            backtrace = np.array(backtrace_idx_list, dtype=int)
            local_tree = [
                [np.where(backtrace[:, level] == group)[0].tolist()
                 for group in range(backtrace[:, level].max() + 1)]
                for level in range(backtrace.shape[1])
            ]
        else:
            # No final cluster: a 1-D empty array has no second axis, so build
            # the (empty) tree explicitly instead of indexing shape[1].
            local_tree = [[] for _ in range(n_tree_levels)]

        self.logger.info('Found %d cluster(s) for %s', len(local_cluster_data), key)
        return key, local_cluster_data, local_tree

    n_workers = max(1, min(self.nproc, len(cluster_data_keys))) if cluster_data_keys else 1
    self.logger.info(
        f'Clustering {len(cluster_data_keys)} keyword(s) '
        f'(nproc={n_workers})...')
    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        for key, key_clusters, key_tree in executor.map(_process_key, cluster_data_keys):
            cluster_data[key] = key_clusters
            cluster_tree[key] = key_tree

    return (cluster_data, cluster_tree)
|
|
1311
|
+
##########################################################################################################################################################
|
|
1312
|
+
|
|
1313
|
+
##########################################################################################################################################################
|
|
1314
|
+
def find_representative_entanglement(self, cluster_data, ent_cluster_idx_map):
    """Pick one representative change of entanglement for every cluster.

    For each ``[key, idx]`` in ``ent_cluster_idx_map`` the cluster
    ``cluster_data[key][idx]`` is converted to a 2-D array. Columns 0/1 are
    used elsewhere as trajectory index / frame, and columns 2 and 3 are read
    here as the two ends of the loop.

    The loop midpoints are histogrammed with 5-residue bins starting at the
    smallest loop start. When that yields a single edge, 1-residue bins are
    used instead. Among the entries whose midpoint falls into the most
    populated bin, the one with the shortest loop is chosen (the first one
    on ties). If the histogram is empty (loop span <= 1) or the selected bin
    matches no entry, the shortest loop of the whole cluster is used.

    Returns
    -------
    list of numpy.ndarray
        One row of the corresponding cluster per entry of ``ent_cluster_idx_map``.
    """
    rep_ent_list = []
    for key, cluster_idx in ent_cluster_idx_map:
        cluster = np.array(cluster_data[key][cluster_idx])
        loop_midpoint_list = np.mean(cluster[:, 2:4], axis=1)
        loop_len = cluster[:, 3] - cluster[:, 2]
        loop_min = np.min(cluster[:, 2])
        loop_max = np.max(cluster[:, 3])

        # mode of the loop midpoints
        bins = np.arange(loop_min, loop_max, 5)
        if len(bins) == 1:
            bins = np.arange(loop_min, loop_max, 1)

        candidates = np.array([], dtype=int)
        if len(bins) >= 2:
            hist, edges = np.histogram(loop_midpoint_list, bins=bins)
            best_bin = int(np.argmax(hist))
            in_bin = loop_midpoint_list >= edges[best_bin]
            if best_bin == len(hist) - 1:
                # np.histogram closes the last bin on the right; membership must
                # match, otherwise values on the final edge are counted but never selected.
                in_bin &= loop_midpoint_list <= edges[best_bin + 1]
            else:
                in_bin &= loop_midpoint_list < edges[best_bin + 1]
            candidates = np.flatnonzero(in_bin)

        if len(candidates) == 0:
            # Never reuse the previous cluster's representative: fall back to all members.
            candidates = np.arange(len(cluster))

        # argmin returns the first minimum, i.e. the earliest shortest loop
        best_row = candidates[np.argmin(loop_len[candidates])]
        rep_ent_list.append(cluster[best_row])
    return rep_ent_list
|
|
1337
|
+
##########################################################################################################################################################
|
|
1338
|
+
|
|
1339
|
+
##########################################################################################################################################################
|
|
1340
|
+
def _process_traj_file(self, traj_idx, ent_data_file, start_frame, end_frame):
    """Load and pre-process one trajectory pkl file for clustering.

    Called in parallel (one call per trajectory) by ``cluster()``.

    Parameters
    ----------
    traj_idx : int
        Position of ``ent_data_file`` in ``self.ent_data_file_list``.
    ent_data_file : str
        pkl file read with ``self.load_pickle``.
    start_frame, end_frame : int
        Frame window forwarded to ``self.load_pickle``.

    Returns
    -------
    tuple : (traj_idx, fingerprint_dict, Q_dict, frame_list, traj_file, keyword_entries)
        * fingerprint_dict – {frame: {nc: fingerprint}}, only fingerprints whose
          ``type`` is not ``['no change', 'no change']``
        * Q_dict – {frame: sum(G_dict values) / n_ref_contacts / 2}
        * frame_list – sorted list of in-range frame indices
        * traj_file – path of the single ``<traj>_*.dcd`` in ``self.traj_dir_prefix``
        * keyword_entries – list of (keyword_str, [traj_idx, frame, *nc]) pairs
          for building chg_ent_keyword_dict after all trajectories are loaded

    Raises
    ------
    ValueError
        If zero or more than one matching .dcd file is found.
    """
    traj = self.extract_traj_number(ent_data_file)
    self.logger.debug(
        f'Processing {ent_data_file} {traj} ({traj_idx + 1} / {len(self.ent_data_file_list)})...')

    # Load pickle with frame filtering applied; only the 'ref' entry must be excluded
    ent_data = self.load_pickle(ent_data_file, start_frame, end_frame)
    frame_list = sorted(frame for frame in ent_data.keys() if frame != 'ref')
    self.logger.debug(f'frame_list: {frame_list} {len(frame_list)}')

    # Locate the matching trajectory DCD file
    matches = glob.glob(os.path.join(self.traj_dir_prefix, f'{traj}_*.dcd'))
    if len(matches) == 0:
        raise ValueError(f'No trajectory file found for {ent_data_file}.')
    if len(matches) > 1:
        raise ValueError(f'More than 1 trajectory file found for {ent_data_file}.')
    traj_file = matches[0]

    fingerprint_dict = {}
    Q_dict = {}
    keyword_entries = []  # collected as (keyword_str, entry) to merge after parallel load

    # Number of reference native contacts; identical for every frame, so it is
    # computed once (and only when there is a frame that needs it).
    n_ref_contacts = len(ent_data['ref']['ent_fingerprint'].keys()) if frame_list else 0

    for frame in frame_list:
        frame_data = ent_data[frame]
        frame_fingerprints = {}
        fingerprint_dict[frame] = frame_fingerprints
        Q_dict[frame] = np.sum(list(frame_data['G_dict'].values())) / n_ref_contacts / 2

        for nc, fingerprint in frame_data['chg_ent_fingerprint'].items():
            # Skip if no change of entanglement
            if fingerprint['type'] == ['no change', 'no change']:
                continue
            frame_fingerprints[nc] = fingerprint

            # keyword = change code followed by the classify_key values
            chg_ent_keyword = fingerprint['code'].copy()
            for ck in self.classify_key:
                value = fingerprint[ck]
                if isinstance(value, list):
                    chg_ent_keyword += value
                else:
                    chg_ent_keyword += [value]
            keyword_entries.append((str(chg_ent_keyword), [traj_idx, frame] + list(nc)))

    # Drop the reference to the full deserialized pkl dict as early as possible;
    # cluster() holds a load semaphore until this function returns.
    del ent_data

    return traj_idx, fingerprint_dict, Q_dict, frame_list, traj_file, keyword_entries
|
|
1405
|
+
##########################################################################################################################################################
|
|
1406
|
+
|
|
1407
|
+
##########################################################################################################################################################
|
|
1408
|
+
def cluster(self, start_frame: int = 0, end_frame: int = 9999999):
    """Cluster changes of entanglement over all trajectories and write the results.

    All outputs go to ``self.outdir`` and carry ``tag = "_".join(self.classify_key)``.

    1. Unless ``cluster_data_<tag>.npz`` exists, load every file of
       ``self.ent_data_file_list`` with ``_process_traj_file`` (thread pool),
       group the changes of entanglement by keyword, cluster them with
       ``cluster_chg_ent`` and checkpoint to that npz. Otherwise reload it.
    2. Write the cluster tree to ``cluster_tree_<tag>.dat``.
    3. Pick one representative change of entanglement per cluster (reusing
       ``rep_chg_ent_list_<tag>.pkl`` when it and ``rep_chg_ent_<tag>.csv``
       both exist) and write ``rep_chg_ent_<tag>.csv``.
    4. Plot loops/crossings of every cluster to ``chg_ent_<tag>_distribution.pdf``.
    5. Label every frame with the sorted set of entanglement-cluster ids it
       contains, group frames by label, pick the max-Q frame of each group and
       write ``chg_ent_struct_<tag>.csv``.
    6. Overwrite the npz with all intermediate and final data.

    Parameters
    ----------
    start_frame, end_frame : int
        Frame window forwarded to ``load_pickle``; only used when the npz
        checkpoint does not exist yet.
    """
    import ast  # safe parser for the internally generated "[id, id, ...]" keywords

    tag = '_'.join(self.classify_key)

    def _as_saveable(seq):
        """Return ``seq`` as an array that np.savez can store.

        NumPy >= 1.24 refuses to build an array from ragged nested lists unless
        ``dtype=object`` is given, so saving e.g. ``idx2frame`` failed whenever
        trajectories had different frame counts. Regular input is stored as
        before; ragged input becomes a 1-D object array of lists.
        """
        try:
            return np.asarray(seq)
        except ValueError:
            arr = np.empty(len(seq), dtype=object)
            for i, item in enumerate(seq):
                arr[i] = item
            return arr

    def _label_crossings(patterns, resids):
        """Prefix every crossing residue with its crossing-pattern string."""
        return [[patterns[i][j] + str(resid) for j, resid in enumerate(row)]
                for i, row in enumerate(resids)]

    ## Define the .npz file name
    npz_data_file = os.path.join(self.outdir, f'cluster_data_{tag}.npz')
    self.logger.info(f'Checking for {npz_data_file}')

    if not os.path.exists(npz_data_file):
        # Classify changes of entanglement based on the keyword
        # "[change_code, classify_key_1_N, classify_key_1_C, classify_key_2_N, classify_key_2_C, ...]"
        self.logger.debug('Reading pkl data and classify changes of entanglement...')
        n_files = len(self.ent_data_file_list)
        chg_ent_fingerprint_list = [None] * n_files
        Q_list = [None] * n_files
        idx2frame = [None] * n_files
        idx2trajfile = [None] * n_files
        dtrajs = [None] * n_files
        chg_ent_keyword_dict = {}

        n_workers = max(1, min(self.nproc, n_files))
        # Each ~1 GB pkl file inflates to ~10 GB of live Python objects during
        # deserialization. Use memory_cutoff as a proxy for available RAM to
        # derive a safe concurrency limit: memory_cutoff / 1e10 (10 GB/file).
        n_load_workers = min(n_workers, max(1, int(self.memory_cutoff // 1e10)))
        self.logger.info(
            f'Loading {n_files} pkl files '
            f'(nproc={n_workers}, concurrent_loads={n_load_workers})...')
        # Caps how many threads may simultaneously hold a deserialized pkl.
        load_sem = threading.Semaphore(n_load_workers)

        def _throttled_process(ti, f):
            with load_sem:
                return self._process_traj_file(ti, f, start_frame, end_frame)

        with ThreadPoolExecutor(max_workers=n_workers) as executor:
            # executor.map yields results in submission (= trajectory) order, so
            # every keyword's entry list is built in trajectory/frame order no
            # matter which file finishes first. This keeps the input of
            # cluster_chg_ent, and hence the clusters, reproducible.
            results = executor.map(_throttled_process, range(n_files), self.ent_data_file_list)
            for ti, fingerprint_dict, Q_dict, frame_list, traj_file, keyword_entries in results:
                chg_ent_fingerprint_list[ti] = fingerprint_dict
                Q_list[ti] = Q_dict
                idx2frame[ti] = frame_list
                idx2trajfile[ti] = traj_file
                dtrajs[ti] = [[] for _ in frame_list]
                for kw, entry in keyword_entries:
                    chg_ent_keyword_dict.setdefault(kw, []).append(entry)
        chg_ent_keyword_list = sorted(chg_ent_keyword_dict)

        self.logger.info('%d data files have been read.' % n_files)

        # cluster changes of entanglements found in the trajectories
        self.logger.info('Clustering changes of entanglement for %d keywords...' % len(chg_ent_keyword_list))
        ent_cluster_data, ent_cluster_tree = self.cluster_chg_ent(
            chg_ent_keyword_dict, chg_ent_fingerprint_list,
            cluster_method=self.cluster_method, cluster_dist_cutoff=self.cluster_dist_cutoff)

        # Save calculated data in case the job is unexpectedly terminated
        np.savez(npz_data_file,
                 chg_ent_fingerprint_list=chg_ent_fingerprint_list,
                 Q_list=Q_list,
                 chg_ent_keyword_dict=chg_ent_keyword_dict,
                 chg_ent_keyword_list=chg_ent_keyword_list,
                 idx2trajfile=idx2trajfile,
                 idx2frame=_as_saveable(idx2frame),
                 ent_cluster_data=ent_cluster_data,
                 ent_cluster_tree=ent_cluster_tree)
        self.logger.info(f'SAVED: {npz_data_file}')

    else:
        self.logger.info(f'Reading clustering data from {npz_data_file}...')
        npz_data = np.load(npz_data_file, allow_pickle=True)
        chg_ent_fingerprint_list = npz_data['chg_ent_fingerprint_list'].tolist()
        Q_list = npz_data['Q_list'].tolist()
        chg_ent_keyword_dict = npz_data['chg_ent_keyword_dict'].item()
        chg_ent_keyword_list = npz_data['chg_ent_keyword_list'].tolist()
        idx2frame = npz_data['idx2frame'].tolist()
        idx2trajfile = npz_data['idx2trajfile'].tolist()
        dtrajs = [[[] for _ in fingerprints.keys()] for fingerprints in chg_ent_fingerprint_list]
        ent_cluster_data = npz_data['ent_cluster_data'].item()
        ent_cluster_tree = npz_data['ent_cluster_tree'].item()

    # Global (0-based) cluster ids: keyword by keyword, then position within the keyword.
    ent_cluster_idx_map = [[ent_keyword, i]
                           for ent_keyword in chg_ent_keyword_list
                           for i in range(len(ent_cluster_data[ent_keyword]))]
    # O(1) replacement for ent_cluster_idx_map.index([keyword, i])
    cluster_id_lookup = {(kw, i): cid for cid, (kw, i) in enumerate(ent_cluster_idx_map)}

    ## Print and save cluster tree (ids are written 1-based)
    cluster_headers = ['After clustering on N crossing', 'After clustering on C crossing', 'After clustering on loop']
    cluster_tree_file = os.path.join(self.outdir, f'cluster_tree_{tag}.dat')
    self.logger.info(f'Making {cluster_tree_file}')
    with open(cluster_tree_file, 'w') as f:
        for ent_keyword in chg_ent_keyword_list:
            f.write(ent_keyword + '\n')
            for level, groups in enumerate(ent_cluster_tree[ent_keyword]):
                f.write(' ' * 4 + cluster_headers[level] + ':\n')
                for group in groups:
                    ids = ', '.join('%d' % (cluster_id_lookup[(ent_keyword, c)] + 1) for c in group)
                    f.write(' ' * 8 + '[' + ids + ']\n')
            f.write('\n')
    self.logger.info(f'SAVED: {cluster_tree_file}')

    # Find representative changes of entanglement in each cluster
    rep_chg_ent_list_file = os.path.join(self.outdir, f'rep_chg_ent_list_{tag}.pkl')
    rep_chg_ent_data_file = os.path.join(self.outdir, f'rep_chg_ent_{tag}.csv')
    if os.path.exists(rep_chg_ent_list_file) and os.path.exists(rep_chg_ent_data_file):
        self.logger.debug('Reading representative changes of entanglement...')
        with open(rep_chg_ent_list_file, 'rb') as f:
            rep_chg_ent_list = pickle.load(f)
        self.logger.debug(f'Loaded: {rep_chg_ent_list_file} into rep_chg_ent_list')
    else:
        self.logger.debug('Finding representative changes of entanglement...')
        rep_chg_ent_list = self.find_representative_entanglement(ent_cluster_data, ent_cluster_idx_map)
        with open(rep_chg_ent_list_file, 'wb') as f:  # save the list as a pickle file
            pickle.dump(rep_chg_ent_list, f)
        self.logger.info(f'SAVED: {rep_chg_ent_list_file}')

    # Create dataframe and save data
    column_list = ['Keywords', 'Trajectory', 'Frame', 'Native Contact (Residue Index)',
                   'Ref N-ter Crossing', 'Ref C-ter Crossing', 'N-ter Crossing', 'C-ter Crossing',
                   'Ref N-ter GLN', 'Ref C-ter GLN', 'N-ter GLN', 'C-ter GLN',
                   'Ref N-ter Linking Number', 'Ref C-ter Linking Number', 'N-ter Linking Number', 'C-ter Linking Number']
    data = []
    for state_id, rep_chg_ent in enumerate(rep_chg_ent_list):
        # Plain ints so the CSV never shows numpy scalar reprs (e.g. "np.int64(3)")
        traj_idx, frame_idx = (int(v) for v in rep_chg_ent[:2])
        nc = tuple(int(v) for v in rep_chg_ent[2:])
        keyword = ent_cluster_idx_map[state_id][0]
        fingerprint = chg_ent_fingerprint_list[traj_idx][frame_idx][nc]
        cross = _label_crossings(fingerprint['crossing_pattern'], fingerprint['crossing_resid'])
        ref_cross = _label_crossings(fingerprint['ref_crossing_pattern'], fingerprint['ref_crossing_resid'])
        GLN = fingerprint['linking_value']
        ref_GLN = fingerprint['ref_linking_value']
        LN = fingerprint['topoly_linking_number']
        ref_LN = fingerprint['ref_topoly_linking_number']
        data.append([keyword, idx2trajfile[traj_idx], frame_idx, nc,
                     ref_cross[0], ref_cross[1], cross[0], cross[1],
                     ref_GLN[0], ref_GLN[1], GLN[0], GLN[1],
                     ref_LN[0], ref_LN[1], LN[0], LN[1]])
    df = pd.DataFrame(data, columns=column_list, index=list(range(1, len(data) + 1)))
    df.to_csv(rep_chg_ent_data_file, index_label='State ID')
    self.logger.info(f'SAVED: {rep_chg_ent_data_file}')

    # plot entanglement distribution: loops (tomato), reference crossings
    # (green) and crossings (blue) of up to max_plot_samples members per cluster
    n_cluster = len(ent_cluster_idx_map)
    fig = plt.figure(figsize=(np.max([6, 0.3 * n_cluster]), 5))
    ax = fig.add_subplot(1, 1, 1)
    window_width = 0.8
    for state_id, (key, cluster_idx) in enumerate(ent_cluster_idx_map):
        cluster = ent_cluster_data[key][cluster_idx]
        # members ordered by loop length, then loop start (stable sort)
        sort_index = np.array(sorted(range(len(cluster)),
                                     key=lambda i: (cluster[i][3] - cluster[i][2], cluster[i][2])))
        if len(sort_index) <= self.max_plot_samples:
            plot_idx = np.arange(0, len(sort_index), 1, dtype=int)
        else:
            plot_idx = np.linspace(0, len(sort_index) - 1, self.max_plot_samples, dtype=int)
        n_plot = len(plot_idx)
        for idx, ci in enumerate(sort_index[plot_idx]):
            c = cluster[ci]
            nc = c[2:4]
            fingerprint = chg_ent_fingerprint_list[c[0]][c[1]][tuple(nc)]
            x = state_id + 1 - window_width / 2 + (idx + 1) * window_width / (n_plot + 1)
            # plot loop
            ax.plot([x, x], nc, '-', color='tomato', linewidth=0.5, alpha=0.4)
            # plot crossings
            for crossing_sets, color in ((fingerprint['ref_crossing_resid'], 'green'),
                                         (fingerprint['crossing_resid'], 'blue')):
                for ccr in crossing_sets:
                    for cc in ccr:
                        ax.plot([x, x], [cc - 0.5, cc + 0.5], '-', color=color, linewidth=0.5, alpha=0.4)
    ax.set_xticks(np.arange(1, n_cluster + 1, 1), np.arange(1, n_cluster + 1, 1))
    ax.set_xlim([0, n_cluster + 1])
    ax.set_xlabel('Cluster')
    ax.set_ylabel('Residue index')
    chg_dist_data_file = os.path.join(self.outdir, f'chg_ent_{tag}_distribution.pdf')
    fig.savefig(chg_dist_data_file, bbox_inches='tight')
    self.logger.info(f'SAVED: {chg_dist_data_file}')
    plt.close(fig)  # `del fig` alone leaves the figure registered in pyplot

    self.logger.debug('Clustering structures with unique combinations of changes of entanglements...')
    # Assign entanglement clusters (list of cluster ids) in discrete trajectories.
    # frame_pos replaces the per-entry idx2frame[traj_idx].index(frame) scan, and
    # member_ncs records, in cluster-member order, the native contacts of every
    # (traj_idx, frame, cluster_id) for the representative search below.
    frame_pos = [{frame: pos for pos, frame in enumerate(frames)} for frames in idx2frame]
    member_ncs = {}
    for key, ent_clusters in ent_cluster_data.items():
        for i, ent_cluster in enumerate(ent_clusters):
            cluster_id = cluster_id_lookup[(key, i)]
            for element in ent_cluster:
                traj_idx, frame = element[:2]
                dtrajs[traj_idx][frame_pos[traj_idx][frame]].append(cluster_id)
                member_ncs.setdefault((traj_idx, frame, cluster_id), []).append(tuple(element[2:]))
    self.logger.info('Assigned entanglement clusters in discrete trajectories')

    # Strip same cluster ids in each frame
    structure_keywords = set()
    for dtraj in dtrajs:
        for i, cluster_id_list in enumerate(dtraj):
            dtraj[i] = sorted(set(cluster_id_list))
            structure_keywords.add(str(dtraj[i]))
    chg_ent_structure_keyword_list = sorted(structure_keywords)
    self.logger.info('Stripped same cluster ids in each frame')

    # cluster trajectory frames with different combinations of changes in entanglement
    chg_ent_structure_cluster_data = {kw: [] for kw in chg_ent_structure_keyword_list}
    for traj_idx, dtraj in enumerate(dtrajs):
        for frame_idx, cluster_id_list in enumerate(dtraj):
            chg_ent_structure_cluster_data[str(cluster_id_list)].append([traj_idx, idx2frame[traj_idx][frame_idx]])
    Num_struct_list = [len(chg_ent_structure_cluster_data[kw]) for kw in chg_ent_structure_keyword_list]
    # Most populated first; the stable sort keeps ties in keyword order.
    sort_idx = np.argsort(-np.array(Num_struct_list, dtype=int), kind='stable')
    sorted_chg_ent_structure_keyword_list = [chg_ent_structure_keyword_list[i] for i in sort_idx]
    self.logger.info('Cluster trajectory frames with different combinations of changes in entanglement')

    self.logger.debug('Find representative combinations of changes of entanglements in structures...')
    # Representative change of entanglement (minimal loop) of each cluster in each frame
    rep_chg_ent_dtrajs = []
    for traj_idx, dtraj in enumerate(dtrajs):
        traj_reps = []
        for frame_idx, cluster_id_list in enumerate(dtraj):
            frame = idx2frame[traj_idx][frame_idx]
            frame_reps = {}
            for cluster_id in cluster_id_list:
                # min() keeps the first of equally short loops
                rep_nc = min(member_ncs[(traj_idx, frame, cluster_id)], key=lambda nc: nc[1] - nc[0])
                frame_reps[cluster_id] = chg_ent_fingerprint_list[traj_idx][frame][rep_nc]
            traj_reps.append(frame_reps)
        rep_chg_ent_dtrajs.append(traj_reps)

    # Find representative structures (max Q) for each combination
    self.logger.info('Find representative structures (max Q) for each combination')
    rep_struct_data = {}
    for keyword in sorted_chg_ent_structure_keyword_list:
        members = chg_ent_structure_cluster_data[keyword]
        best_member, best_Q = members[0], 0
        for traj_idx, frame in members:
            Q = Q_list[traj_idx][frame]
            if Q > best_Q:
                best_member, best_Q = [traj_idx, frame], Q
        rep_struct_data[keyword] = best_member
    self.logger.debug('Found representative structures (max Q) for each combination')

    # Create dataframe and save data
    column_list = ['Rep_chg_ents', 'Num of structures', 'Probability',
                   'Rep trajectory', 'Rep frame', 'Max Q', 'Median Q']
    tot_num_frames = sum(len(dtraj) for dtraj in dtrajs)
    data = []
    for keyword in sorted_chg_ent_structure_keyword_list:
        members = chg_ent_structure_cluster_data[keyword]
        # keyword is str(list of 0-based cluster ids); report them 1-based as plain ints
        rep_chg_ents = str([cid + 1 for cid in ast.literal_eval(keyword)])
        Q_values = [Q_list[traj_idx][frame] for traj_idx, frame in members]
        rep_traj_idx, rep_frame = rep_struct_data[keyword]
        data.append([rep_chg_ents, len(members), len(members) / tot_num_frames,
                     idx2trajfile[rep_traj_idx], rep_frame, np.max(Q_values), np.median(Q_values)])
    df = pd.DataFrame(data, columns=column_list, index=list(range(1, len(data) + 1)))
    chg_ent_data_file = os.path.join(self.outdir, f'chg_ent_struct_{tag}.csv')
    df.to_csv(chg_ent_data_file, index_label='State ID')
    self.logger.info(f'SAVED: {chg_ent_data_file}')

    ## determine if there is any issue with the item shapes before saving
    save_items = {
        "chg_ent_fingerprint_list": chg_ent_fingerprint_list,
        "Q_list": Q_list,
        "chg_ent_keyword_dict": chg_ent_keyword_dict,
        "chg_ent_keyword_list": chg_ent_keyword_list,
        "idx2trajfile": idx2trajfile,
        "idx2frame": idx2frame,
        "ent_cluster_data": ent_cluster_data,
        "ent_cluster_tree": ent_cluster_tree,
        "rep_chg_ent_list": rep_chg_ent_list,
        "dtrajs": dtrajs,
        "rep_chg_ent_dtrajs": rep_chg_ent_dtrajs,
        "sorted_chg_ent_structure_keyword_list": sorted_chg_ent_structure_keyword_list,
        "chg_ent_structure_cluster_data": chg_ent_structure_cluster_data,
        "rep_struct_data": rep_struct_data}
    for key, value in save_items.items():
        try:
            shape_info = f", shape={np.shape(value)}"
        except Exception as e:  # ragged data has no shape; this is diagnostics only
            shape_info = f", shape=unavailable ({e})"
        self.logger.info(f"{key}: type={type(value)}{shape_info}")

    # Save data
    np.savez(npz_data_file,
             chg_ent_fingerprint_list=chg_ent_fingerprint_list,
             Q_list=Q_list,
             chg_ent_keyword_dict=chg_ent_keyword_dict,
             chg_ent_keyword_list=chg_ent_keyword_list,
             idx2trajfile=idx2trajfile,
             idx2frame=_as_saveable(idx2frame),
             ent_cluster_data=ent_cluster_data,
             ent_cluster_tree=ent_cluster_tree,
             rep_chg_ent_list=rep_chg_ent_list,
             dtrajs=np.array(dtrajs, dtype=object),
             rep_chg_ent_dtrajs=np.array(rep_chg_ent_dtrajs, dtype=object),
             sorted_chg_ent_structure_keyword_list=sorted_chg_ent_structure_keyword_list,
             chg_ent_structure_cluster_data=chg_ent_structure_cluster_data,
             rep_struct_data=rep_struct_data)
    self.logger.info(f'SAVED: {npz_data_file}')
    self.logger.info('Clustering Complete')
|
|
1765
|
+
##########################################################################################################################################################
|
|
1766
|
+
|
|
1767
|
+
##########################################################################################################################################################
|
|
1768
|
+
def viz_rep_changes(self, psf_file=None, native_AA_pdb=None, if_viz=1, if_backmap=0,
                    pulchra_only=False, top_struct=0.01):
    """Write visualizations of the representative changes of entanglement.

    All clustering results are read from ``<outdir>/cluster_data_<tag>.npz``
    (``tag = "_".join(self.classify_key)``), which ``cluster()`` writes at the
    end of a complete run. Call ``cluster()`` first.

    Two directories are created in the current working directory:

    * ``viz_rep_chg_ent_<tag>``: one state per entanglement cluster, built from
      that cluster's representative change of entanglement;
    * ``viz_chg_ent_struct_<tag>_<top_struct>``: one state per combination of
      entanglement clusters (most populated first), built from the max-Q frame
      of that combination. If ``top_struct >= 1`` it is the number of
      combinations rendered; otherwise rendering stops at the first combination
      whose population fraction is below ``top_struct``.

    Parameters
    ----------
    psf_file : str
        Topology used to read the .dcd frames and passed to
        ``gen_state_visualizion``. Required.
    native_AA_pdb : str, optional
        Passed to ``gen_state_visualizion``. Required when ``if_backmap`` is true.
    if_viz : bool
        When false nothing is done.
    if_backmap, pulchra_only : bool
        Forwarded to ``gen_state_visualizion``.
    top_struct : float
        Selection of structural combinations, see above.

    Raises
    ------
    ValueError
        Missing ``psf_file`` / ``native_AA_pdb``, or the npz only holds the
        first-stage checkpoint.
    FileNotFoundError
        The npz file does not exist.
    """
    if not if_viz:
        return
    if psf_file is None:
        raise ValueError('viz_rep_changes requires psf_file (topology of the .dcd trajectories)')
    if if_backmap and native_AA_pdb is None:
        raise ValueError('viz_rep_changes requires native_AA_pdb when if_backmap is set')

    # gen_state_visualizion runs inside the output directories, so relative
    # input paths have to be resolved before changing directory.
    psf_file = os.path.abspath(psf_file)
    if native_AA_pdb is not None:
        native_AA_pdb = os.path.abspath(native_AA_pdb)

    tag = '_'.join(self.classify_key)
    npz_data_file = os.path.join(self.outdir, f'cluster_data_{tag}.npz')
    if not os.path.exists(npz_data_file):
        raise FileNotFoundError(f'{npz_data_file} not found; run cluster() first')
    npz_data = np.load(npz_data_file, allow_pickle=True)
    if 'rep_struct_data' not in npz_data.files:
        raise ValueError(f'{npz_data_file} only holds the first-stage checkpoint; rerun cluster() to completion')

    rep_chg_ent_list = npz_data['rep_chg_ent_list'].tolist()
    idx2trajfile = npz_data['idx2trajfile'].tolist()
    idx2frame = npz_data['idx2frame'].tolist()
    chg_ent_fingerprint_list = npz_data['chg_ent_fingerprint_list'].tolist()
    sorted_keywords = npz_data['sorted_chg_ent_structure_keyword_list'].tolist()
    chg_ent_structure_cluster_data = npz_data['chg_ent_structure_cluster_data'].item()
    rep_struct_data = npz_data['rep_struct_data'].item()
    rep_chg_ent_dtrajs = npz_data['rep_chg_ent_dtrajs'].tolist()
    dtrajs = npz_data['dtrajs'].tolist()

    def _load_coordinates(traj_idx, frame):
        """Centered coordinates of one frame, scaled by 10 as in the original
        (mdtraj stores nm; presumably Angstrom is expected downstream)."""
        traj = mdt.load_frame(idx2trajfile[traj_idx], frame, top=psf_file)
        return traj.center_coordinates().xyz * 10

    def _render(directory, state_id, state_cor, rep_ent_dict):
        """Run gen_state_visualizion inside ``directory``; always restore the cwd."""
        cwd = os.getcwd()
        os.chdir(directory)
        try:
            self.gen_state_visualizion(state_id, psf_file, state_cor, native_AA_pdb, rep_ent_dict,
                                       if_backmap=if_backmap, pulchra_only=pulchra_only)
        finally:
            os.chdir(cwd)

    ## Generate visualization for representative changes of entanglement in each cluster
    self.logger.debug('Generate visualization for representative changes of entanglement...')
    rep_dir = 'viz_rep_chg_ent_%s' % tag
    os.makedirs(rep_dir, exist_ok=True)
    for state_id, rep_chg_ent in enumerate(rep_chg_ent_list):
        traj_idx, frame = int(rep_chg_ent[0]), int(rep_chg_ent[1])
        nc = (int(rep_chg_ent[2]), int(rep_chg_ent[3]))
        fingerprint = chg_ent_fingerprint_list[traj_idx][frame][nc]
        _render(rep_dir, state_id + 1, _load_coordinates(traj_idx, frame),
                {tuple(fingerprint['code']): [fingerprint]})

    # Generate visualization for representative changes of entanglements in each structural cluster
    self.logger.debug('Generate visualization for unique entangled structures...')
    if top_struct >= 1:
        viz_dir = 'viz_chg_ent_struct_%s_%d' % (tag, top_struct)
    else:
        viz_dir = 'viz_chg_ent_struct_%s_%.4f' % (tag, top_struct)
    os.makedirs(viz_dir, exist_ok=True)
    tot_num_frames = sum(len(dtraj) for dtraj in dtrajs)
    for state_id, keyword in enumerate(sorted_keywords):
        if top_struct >= 1:
            if state_id >= top_struct:
                break
        elif len(chg_ent_structure_cluster_data[keyword]) / tot_num_frames < top_struct:
            break
        traj_idx, frame = rep_struct_data[keyword]
        frame_pos = idx2frame[traj_idx].index(frame)
        # group the frame's representative fingerprints by change code
        rep_ent_dict = {}
        for fingerprint in rep_chg_ent_dtrajs[traj_idx][frame_pos].values():
            rep_ent_dict.setdefault(tuple(fingerprint['code']), []).append(fingerprint)
        _render(viz_dir, state_id + 1, _load_coordinates(traj_idx, frame), rep_ent_dict)
|
|
1811
|
+
##########################################################################################################################################################
|
|
1812
|
+
|
|
1813
|
+
##########################################################################################################################################################
|
|
1814
|
+
def gen_state_visualizion(self, state_id, psf, state_cor, native_AA_pdb, rep_ent_dict, if_backmap=True, pulchra_only=False):
|
|
1815
|
+
def idx2sel(idx_list):
|
|
1816
|
+
if len(idx_list) == 0:
|
|
1817
|
+
return ''
|
|
1818
|
+
else:
|
|
1819
|
+
sel = 'index'
|
|
1820
|
+
idx_0 = idx_list[0]
|
|
1821
|
+
idx_1 = idx_list[0]
|
|
1822
|
+
sel_0 = ' %d'%idx_0
|
|
1823
|
+
for i in range(1, len(idx_list)):
|
|
1824
|
+
if idx_list[i] == idx_list[i-1] + 1:
|
|
1825
|
+
idx_1 = idx_list[i]
|
|
1826
|
+
else:
|
|
1827
|
+
if idx_1 > idx_0:
|
|
1828
|
+
sel_0 += ' to %d'%idx_1
|
|
1829
|
+
sel += sel_0
|
|
1830
|
+
idx_0 = idx_list[i]
|
|
1831
|
+
idx_1 = idx_list[i]
|
|
1832
|
+
sel_0 = ' %d'%idx_0
|
|
1833
|
+
if idx_1 > idx_0:
|
|
1834
|
+
sel_0 += ' to %d'%idx_1
|
|
1835
|
+
sel += sel_0
|
|
1836
|
+
return sel
|
|
1837
|
+
|
|
1838
|
+
AA_name_list = ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
|
|
1839
|
+
'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL',
|
|
1840
|
+
'HIE', 'HID', 'HIP']
|
|
1841
|
+
protein_colorid_list = [6, 6]
|
|
1842
|
+
loop_colorid_list = [1, 1]
|
|
1843
|
+
thread_colorid_list = [0, 0]
|
|
1844
|
+
nc_colorid_list = [3, 3]
|
|
1845
|
+
crossing_colorid_list = [8, 8]
|
|
1846
|
+
thread_cutoff=3
|
|
1847
|
+
terminal_cutoff=3
|
|
1848
|
+
|
|
1849
|
+
self.logger.info('Generate visualization of state %d'%(state_id))
|
|
1850
|
+
|
|
1851
|
+
struct = pmd.load_file(psf)
|
|
1852
|
+
struct.coordinates = state_cor
|
|
1853
|
+
|
|
1854
|
+
# backmap
|
|
1855
|
+
if if_backmap:
|
|
1856
|
+
if pulchra_only:
|
|
1857
|
+
pulchra_only = '1'
|
|
1858
|
+
else:
|
|
1859
|
+
pulchra_only = '0'
|
|
1860
|
+
struct.save('tmp.pdb', overwrite=True)
|
|
1861
|
+
os.system('backmap.py -i '+native_AA_pdb+' -c tmp.pdb -p '+pulchra_only)
|
|
1862
|
+
os.system('mv tmp_rebuilt.pdb state_%d.pdb'%state_id)
|
|
1863
|
+
os.system('rm -f tmp.pdb')
|
|
1864
|
+
os.system('rm -rf ./rebuild_tmp/')
|
|
1865
|
+
else:
|
|
1866
|
+
struct.save('state_%d.pdb'%state_id, overwrite=True)
|
|
1867
|
+
|
|
1868
|
+
ref_struct = pmd.load_file(native_AA_pdb)
|
|
1869
|
+
current_struct = pmd.load_file('state_%d.pdb'%state_id)
|
|
1870
|
+
|
|
1871
|
+
if len(list(rep_ent_dict.keys())) == 0:
|
|
1872
|
+
# no change of entaglement
|
|
1873
|
+
f = open('vmd_s%d_none.tcl'%(state_id), 'w')
|
|
1874
|
+
f.write('# Entanglement type: no change\n')
|
|
1875
|
+
f.write('''display rendermode GLSL
|
|
1876
|
+
axes location off
|
|
1877
|
+
|
|
1878
|
+
color Display {Background} white
|
|
1879
|
+
|
|
1880
|
+
mol new ./'''+('state_%d.pdb'%state_id)+''' type pdb first 0 last -1 step 1 filebonds 1 autobonds 1 waitfor all
|
|
1881
|
+
mol delrep 0 top
|
|
1882
|
+
mol representation NewCartoon 0.300000 10.000000 4.100000 0
|
|
1883
|
+
mol color ColorID '''+str(protein_colorid_list[1])+'''
|
|
1884
|
+
mol selection {all}
|
|
1885
|
+
mol material AOChalky
|
|
1886
|
+
mol addrep top
|
|
1887
|
+
''')
|
|
1888
|
+
f.close()
|
|
1889
|
+
|
|
1890
|
+
# Create vmd script for each type of change
|
|
1891
|
+
for ent_code, rep_ent_list in rep_ent_dict.items():
|
|
1892
|
+
pmd_struct_list = [ref_struct, current_struct]
|
|
1893
|
+
struct_dir_list = [native_AA_pdb, './state_%d.pdb'%state_id]
|
|
1894
|
+
key_prefix_list = ['ref_', '']
|
|
1895
|
+
repres_list = ['', '']
|
|
1896
|
+
align_sel_list = ['', '']
|
|
1897
|
+
|
|
1898
|
+
vmd_script = '''# Entanglement type: '''+str(rep_ent_list[0]['type'])+'''
|
|
1899
|
+
package require topotools
|
|
1900
|
+
display rendermode GLSL
|
|
1901
|
+
axes location off
|
|
1902
|
+
|
|
1903
|
+
color Display {Background} white
|
|
1904
|
+
|
|
1905
|
+
'''
|
|
1906
|
+
for struct_idx, pmd_struct in enumerate(pmd_struct_list):
|
|
1907
|
+
struct_dir = struct_dir_list[struct_idx]
|
|
1908
|
+
protein_colorid = protein_colorid_list[struct_idx]
|
|
1909
|
+
loop_colorid = loop_colorid_list[struct_idx]
|
|
1910
|
+
thread_colorid = thread_colorid_list[struct_idx]
|
|
1911
|
+
nc_colorid = nc_colorid_list[struct_idx]
|
|
1912
|
+
crossing_colorid = crossing_colorid_list[struct_idx]
|
|
1913
|
+
key_prefix = key_prefix_list[struct_idx]
|
|
1914
|
+
|
|
1915
|
+
# Clean ligands
|
|
1916
|
+
clean_sel_idx = np.zeros(len(pmd_struct.atoms))
|
|
1917
|
+
for res in pmd_struct.residues:
|
|
1918
|
+
if res.name in AA_name_list:
|
|
1919
|
+
for atm in res.atoms:
|
|
1920
|
+
clean_sel_idx[atm.idx] = 1
|
|
1921
|
+
pmd_clean_struct = pmd_struct[clean_sel_idx]
|
|
1922
|
+
clean_idx_to_idx = np.where(clean_sel_idx == 1)[0]
|
|
1923
|
+
|
|
1924
|
+
# vmd selection string for protein
|
|
1925
|
+
idx_list = []
|
|
1926
|
+
for res in pmd_struct.residues:
|
|
1927
|
+
if res.name in AA_name_list:
|
|
1928
|
+
idx_list += [atm.idx for atm in res.atoms]
|
|
1929
|
+
vmd_sel = idx2sel(idx_list)
|
|
1930
|
+
|
|
1931
|
+
repres = '''mol new '''+struct_dir+''' type pdb first 0 last -1 step 1 filebonds 1 autobonds 1 waitfor all
|
|
1932
|
+
mol delrep 0 top
|
|
1933
|
+
mol representation NewCartoon 0.300000 10.000000 4.100000 0
|
|
1934
|
+
mol color ColorID '''+str(protein_colorid)+'''
|
|
1935
|
+
mol selection {'''+vmd_sel+'''}
|
|
1936
|
+
mol material AOChalky
|
|
1937
|
+
mol addrep top
|
|
1938
|
+
'''
|
|
1939
|
+
align_sel = vmd_sel
|
|
1940
|
+
for chg_ent_fingerprint in rep_ent_list:
|
|
1941
|
+
nc = chg_ent_fingerprint[key_prefix+'native_contact']
|
|
1942
|
+
|
|
1943
|
+
idx_list = []
|
|
1944
|
+
for res in pmd_clean_struct.residues:
|
|
1945
|
+
if res.idx in nc:
|
|
1946
|
+
idx_list += [atm.idx for atm in res.atoms if atm.name == 'CA']
|
|
1947
|
+
nc_sel = idx2sel(clean_idx_to_idx[idx_list])
|
|
1948
|
+
|
|
1949
|
+
idx_list = []
|
|
1950
|
+
for res in pmd_clean_struct.residues:
|
|
1951
|
+
if res.idx >= nc[0] and res.idx <= nc[1]:
|
|
1952
|
+
idx_list += [atm.idx for atm in res.atoms]
|
|
1953
|
+
loop_sel = idx2sel(clean_idx_to_idx[idx_list])
|
|
1954
|
+
|
|
1955
|
+
align_sel += ' and not (%s)'%loop_sel
|
|
1956
|
+
ref_coss_resid = chg_ent_fingerprint['ref_crossing_resid']
|
|
1957
|
+
cross_resid = chg_ent_fingerprint['crossing_resid']
|
|
1958
|
+
thread = []
|
|
1959
|
+
thread_sel_list = []
|
|
1960
|
+
for ter_idx in range(len(ref_coss_resid)):
|
|
1961
|
+
thread_0 = []
|
|
1962
|
+
resid_list = ref_coss_resid[ter_idx] + cross_resid[ter_idx]
|
|
1963
|
+
if len(resid_list) > 0:
|
|
1964
|
+
thread_0 = [np.min(resid_list)-5, np.max(resid_list)+5]
|
|
1965
|
+
if ter_idx == 0:
|
|
1966
|
+
thread_0[0] = np.max([thread_0[0], terminal_cutoff])
|
|
1967
|
+
thread_0[1] = np.min([thread_0[1], nc[0]-thread_cutoff])
|
|
1968
|
+
else:
|
|
1969
|
+
thread_0[0] = np.max([thread_0[0], nc[1]+thread_cutoff])
|
|
1970
|
+
thread_0[1] = np.min([thread_0[1], len(struct.atoms)-1-terminal_cutoff])
|
|
1971
|
+
idx_list = []
|
|
1972
|
+
for res in pmd_clean_struct.residues:
|
|
1973
|
+
if res.idx >= thread_0[0] and res.idx <= thread_0[1]:
|
|
1974
|
+
idx_list += [atm.idx for atm in res.atoms]
|
|
1975
|
+
thread_0_sel = idx2sel(clean_idx_to_idx[idx_list])
|
|
1976
|
+
thread_sel_list.append(thread_0_sel)
|
|
1977
|
+
align_sel += ' and not (%s)'%thread_0_sel
|
|
1978
|
+
else:
|
|
1979
|
+
thread_sel_list.append('')
|
|
1980
|
+
thread.append(thread_0)
|
|
1981
|
+
|
|
1982
|
+
ln = chg_ent_fingerprint[key_prefix+'topoly_linking_number']
|
|
1983
|
+
cross = []
|
|
1984
|
+
for i in range(len(chg_ent_fingerprint[key_prefix+'crossing_resid'])):
|
|
1985
|
+
cross.append([])
|
|
1986
|
+
for j in range(len(chg_ent_fingerprint[key_prefix+'crossing_resid'][i])):
|
|
1987
|
+
cross[-1].append(chg_ent_fingerprint[key_prefix+'crossing_pattern'][i][j]+str(chg_ent_fingerprint[key_prefix+'crossing_resid'][i][j]))
|
|
1988
|
+
repres += '# idx: native contact %s, linking number %s, crossings %s.\n'%(str(nc), str(ln), str(cross))
|
|
1989
|
+
repres +='''mol representation NewCartoon 0.350000 10.000000 4.100000 0
|
|
1990
|
+
mol color ColorID '''+str(loop_colorid)+'''
|
|
1991
|
+
mol selection {'''+loop_sel+'''}
|
|
1992
|
+
mol material Opaque
|
|
1993
|
+
mol addrep top
|
|
1994
|
+
mol representation VDW 1.000000 12.000000
|
|
1995
|
+
mol color ColorID '''+str(nc_colorid)+'''
|
|
1996
|
+
mol selection {'''+nc_sel+'''}
|
|
1997
|
+
mol material Opaque
|
|
1998
|
+
mol addrep top
|
|
1999
|
+
set sel [atomselect top "'''+nc_sel+'''"]
|
|
2000
|
+
set idx [$sel get index]
|
|
2001
|
+
topo addbond [lindex $idx 0] [lindex $idx 1]
|
|
2002
|
+
mol representation Bonds 0.300000 12.000000
|
|
2003
|
+
mol color ColorID '''+str(nc_colorid)+'''
|
|
2004
|
+
mol selection {'''+nc_sel+'''}
|
|
2005
|
+
mol material Opaque
|
|
2006
|
+
mol addrep top
|
|
2007
|
+
'''
|
|
2008
|
+
for ter_idx, thread_resid in enumerate(thread):
|
|
2009
|
+
if len(thread_resid) == 0:
|
|
2010
|
+
continue
|
|
2011
|
+
repres += '''mol representation NewCartoon 0.350000 10.000000 4.100000 0
|
|
2012
|
+
mol color ColorID '''+str(thread_colorid)+'''
|
|
2013
|
+
mol selection {'''+thread_sel_list[ter_idx]+'''}
|
|
2014
|
+
mol material Opaque
|
|
2015
|
+
mol addrep top
|
|
2016
|
+
'''
|
|
2017
|
+
if len(chg_ent_fingerprint[key_prefix+'crossing_resid'][ter_idx]) > 0:
|
|
2018
|
+
idx_list = []
|
|
2019
|
+
for res in pmd_clean_struct.residues:
|
|
2020
|
+
if res.idx in chg_ent_fingerprint[key_prefix+'crossing_resid'][ter_idx]:
|
|
2021
|
+
idx_list += [atm.idx for atm in res.atoms if atm.name == 'CA']
|
|
2022
|
+
crossing_sel = idx2sel(clean_idx_to_idx[idx_list])
|
|
2023
|
+
repres += '''mol representation VDW 1.000000 12.000000
|
|
2024
|
+
mol color ColorID '''+str(crossing_colorid)+'''
|
|
2025
|
+
mol selection {'''+crossing_sel+'''}
|
|
2026
|
+
mol material Opaque
|
|
2027
|
+
mol addrep top
|
|
2028
|
+
'''
|
|
2029
|
+
|
|
2030
|
+
if struct_idx == 0:
|
|
2031
|
+
repres += '''mol representation VDW 1.000000 12.000000
|
|
2032
|
+
mol color Name
|
|
2033
|
+
mol selection {not ('''+vmd_sel+''') and not water}
|
|
2034
|
+
mol material Opaque
|
|
2035
|
+
mol addrep top
|
|
2036
|
+
'''
|
|
2037
|
+
repres_list[struct_idx] = repres
|
|
2038
|
+
align_sel_list[struct_idx] = align_sel
|
|
2039
|
+
|
|
2040
|
+
vmd_script += '\n'.join(repres_list)
|
|
2041
|
+
vmd_script += '''
|
|
2042
|
+
set sel1 [atomselect 0 "'''+align_sel_list[0]+''' and name CA"]
|
|
2043
|
+
set sel2 [atomselect 1 "'''+align_sel_list[1]+''' and name CA"]
|
|
2044
|
+
set trans_mat [measure fit $sel1 $sel2]
|
|
2045
|
+
set move_sel [atomselect 0 "all"]
|
|
2046
|
+
$move_sel move $trans_mat
|
|
2047
|
+
'''
|
|
2048
|
+
f = open('vmd_s%d_n%s_c%s.tcl'%(state_id, ent_code[0], ent_code[1]), 'w')
|
|
2049
|
+
f.write(vmd_script)
|
|
2050
|
+
f.close()
|
|
2051
|
+
##########################################################################################################################################################
|
|
2052
|
+
|
|
2053
|
+
##########################################################################################################################################################
|
|
2054
|
+
##########################################################################################################################################################
|
|
2055
|
+
|
|
2056
|
+
##########################################################################################################################################################
|
|
2057
|
+
class MSMNonNativeEntanglementClustering:
|
|
2058
|
+
"""
|
|
2059
|
+
Build a markov state model across an ensemble of protein structures with non-native entanglements to identify metastable states
|
|
2060
|
+
"""
|
|
2061
|
+
|
|
2062
|
+
#######################################################################################
|
|
2063
|
+
def __init__(self, OPpath:str = './', outdir:str = './', ID:str = '',
|
|
2064
|
+
start:int= 0 , end:int = 99999999999, stride:int = 1, ITS:str = 'False', lagtime:int = 1,
|
|
2065
|
+
n_cluster:int = 400, kmean_stride:int = 2, n_small_states:int = 1, n_large_states:int = 10, dt:float = 0.015/1000, rm_traj_list:list = [], log_level:int = logging.INFO, logdir:str = None):
|
|
2066
|
+
|
|
2067
|
+
"""
|
|
2068
|
+
Initializes the DataAnalysis class with necessary paths and parameters.
|
|
2069
|
+
|
|
2070
|
+
Parameters:
|
|
2071
|
+
("--outdir", type=str, required=True, help="Path to output directory")
|
|
2072
|
+
("--OPpath", type=str, required=True, help="Path to directory containing G and Q directories created by GQ.py")
|
|
2073
|
+
("--ID", type=str, required=True, help="base name for output files")
|
|
2074
|
+
("--start", type=int, required=False, help="First frame to analyze 0 indexed", default=0)
|
|
2075
|
+
("--end", type=int, required=False, help="Last frame to analyze 0 indexed", default=-1)
|
|
2076
|
+
("--stride", type=int, required=False, help="Frame stride", default=1)
|
|
2077
|
+
("--ITS", type=str, required=False, help="Find optimal lag time with ITS", default='False')
|
|
2078
|
+
("--lagtime", type=int, required=False, help="lagtime to build the model", default=1)
|
|
2079
|
+
("--n_cluster", type=int, required=False, help="Number of k-means clusters to group. Default is 400.")
|
|
2080
|
+
("--kmean_stride", type=int, required=False, help="Stride of reading trajectory frame when clustring by k-means. Default is 2.")
|
|
2081
|
+
("--n_small_states", type=int, required=False, help="Number of clusters for the inactive microstates after MSM modeling to be clustered into", default=1)
|
|
2082
|
+
("--n_large_states", type=int, required=False, help="Number of clusters for the active microstates after MSM modeling to be clustered into", default=10)
|
|
2083
|
+
("--dt", type=float, required=False, help="timestep used in MD simulations in ns", default=0.015/1000)
|
|
2084
|
+
("--rm_traj_list", type=str, nargs='+', required=False, help="List of trajectory numbers to remove from analysis", default=[])
|
|
2085
|
+
"""
|
|
2086
|
+
|
|
2087
|
+
# parse the parameters
|
|
2088
|
+
self.outdir = outdir
|
|
2089
|
+
self.ID = ID
|
|
2090
|
+
self.logger = setup_logger('MSMNonNativeEntanglementClustering', outdir=logdir if logdir is not None else outdir, ID=ID, log_level=log_level)
|
|
2091
|
+
|
|
2092
|
+
self.OPpath = OPpath
|
|
2093
|
+
self.logger.debug(f'OPpath: {self.OPpath}')
|
|
2094
|
+
self.logger.debug(f'outdir: {self.outdir}')
|
|
2095
|
+
self.logger.debug(f'ID: {self.ID}')
|
|
2096
|
+
|
|
2097
|
+
self.ITS = ITS
|
|
2098
|
+
self.logger.debug(f'ITS: {ITS}')
|
|
2099
|
+
|
|
2100
|
+
self.lagtime = lagtime
|
|
2101
|
+
self.logger.debug(f'lagtime: {lagtime}')
|
|
2102
|
+
|
|
2103
|
+
#self.dcds = args.dcds
|
|
2104
|
+
#print(f'dcds: {self.dcds}')
|
|
2105
|
+
|
|
2106
|
+
self.start = start
|
|
2107
|
+
self.end = end
|
|
2108
|
+
self.stride = stride
|
|
2109
|
+
self.logger.debug(f'START: {self.start} | END: {self.end} | STRIDE: {self.stride}')
|
|
2110
|
+
|
|
2111
|
+
self.n_cluster = n_cluster # Number of k-means clusters to group. Default is 400.
|
|
2112
|
+
self.kmean_stride = kmean_stride # Stride of reading trajectory frame when clustring by k-means.
|
|
2113
|
+
self.n_small_states = n_small_states # Number of clusters for the inactive microstates after MSM modeling to be clustered into
|
|
2114
|
+
self.n_large_states = n_large_states # Adjust based on your system
|
|
2115
|
+
self.dt = dt # timestep used in MD simulations in ns
|
|
2116
|
+
|
|
2117
|
+
self.rm_traj_list = rm_traj_list
|
|
2118
|
+
self.logger.info(f'Trajectories to ignore: {self.rm_traj_list}')
|
|
2119
|
+
|
|
2120
|
+
#######################################################################################
|
|
2121
|
+
|
|
2122
|
+
#######################################################################################
|
|
2123
|
+
def load_OP(self,):
|
|
2124
|
+
"""
|
|
2125
|
+
Loads the GQ values of each trajectory into a 2D array and then appends it to a list
|
|
2126
|
+
The list should have Nt = number of trajectories and each array should be n x 2 where n is the number of frames
|
|
2127
|
+
"""
|
|
2128
|
+
self.logger.info(f'Loading G and Q order parameters...')
|
|
2129
|
+
cor_list = []
|
|
2130
|
+
cor_list_idx_2_traj = {}
|
|
2131
|
+
Qfiles = glob.glob(os.path.join(self.OPpath, 'Q/*.Q'))
|
|
2132
|
+
QTrajs = [int(pathlib.Path(Qf).stem.split('Traj')[-1]) for Qf in Qfiles]
|
|
2133
|
+
|
|
2134
|
+
Gfiles = glob.glob(os.path.join(self.OPpath, 'G/*.G'))
|
|
2135
|
+
GTrajs = [int(pathlib.Path(Gf).stem.split('Traj')[-1]) for Gf in Gfiles]
|
|
2136
|
+
|
|
2137
|
+
shared_Trajs = set(QTrajs).intersection(GTrajs)
|
|
2138
|
+
#print(f'Shared Traj between Q and G: {shared_Trajs} {len(shared_Trajs)}')
|
|
2139
|
+
self.logger.info(f'Number of Q files found: {len(Qfiles)} | Number of G files found: {len(Gfiles)}')
|
|
2140
|
+
self.logger.info(f'Number of shared Traj between Q and G: {len(shared_Trajs)}')
|
|
2141
|
+
|
|
2142
|
+
|
|
2143
|
+
## remove trajectories that are in the rm_traj_list
|
|
2144
|
+
if len(self.rm_traj_list) > 0:
|
|
2145
|
+
self.logger.info(f'Removing trajectories: {self.rm_traj_list}')
|
|
2146
|
+
shared_Trajs = [traj for traj in shared_Trajs if traj not in self.rm_traj_list]
|
|
2147
|
+
self.logger.info(f'Number of shared Traj after removing: {len(shared_Trajs)}')
|
|
2148
|
+
|
|
2149
|
+
|
|
2150
|
+
# loop through the Qfiles and find matching Gfile
|
|
2151
|
+
# then load the Q and G time series into a 2D array
|
|
2152
|
+
idx = 0
|
|
2153
|
+
QFrames = {}
|
|
2154
|
+
GFrames = {}
|
|
2155
|
+
for traj in shared_Trajs:
|
|
2156
|
+
#print(f'Traj: {traj}')
|
|
2157
|
+
|
|
2158
|
+
# get the cooresponding G and Q file
|
|
2159
|
+
Qf = [f for f in Qfiles if f.endswith(f'Traj{traj}.Q')]
|
|
2160
|
+
Gf = [f for f in Gfiles if f.endswith(f'Traj{traj}.G')]
|
|
2161
|
+
self.logger.debug(f'Qf: {Qf}')
|
|
2162
|
+
self.logger.debug(f'Gf: {Gf}')
|
|
2163
|
+
|
|
2164
|
+
## Quality check to assert that only a single G and Q file were found
|
|
2165
|
+
assert len(Qf) == 1, f"the number of Q files {len(Qf)} should equal 1 for Traj {traj}"
|
|
2166
|
+
assert len(Gf) == 1, f"the number of G files {len(Gf)} should equal 1 for Traj {traj}"
|
|
2167
|
+
|
|
2168
|
+
# load the G Q data and extract only the time series column
|
|
2169
|
+
Qdata = pd.read_csv(Qf[0], sep=',')
|
|
2170
|
+
if self.start < 0:
|
|
2171
|
+
Qdata = Qdata.iloc[self.start:self.end + 1]
|
|
2172
|
+
else:
|
|
2173
|
+
Qdata = Qdata[(Qdata['Frame'] >= self.start) & (Qdata['Frame'] <= self.end)]
|
|
2174
|
+
#print(Qdata)
|
|
2175
|
+
QFrames[traj] = Qdata['Frame'].values
|
|
2176
|
+
Qdata = Qdata['total'].values.astype(float)
|
|
2177
|
+
|
|
2178
|
+
|
|
2179
|
+
|
|
2180
|
+
Gdata = pd.read_csv(Gf[0])
|
|
2181
|
+
if self.start < 0:
|
|
2182
|
+
Gdata = Gdata.iloc[self.start:self.end + 1]
|
|
2183
|
+
else:
|
|
2184
|
+
Gdata = Gdata[(Gdata['Frame'] >= self.start) & (Gdata['Frame'] <= self.end)]
|
|
2185
|
+
#print(Gdata)
|
|
2186
|
+
GFrames[traj] = Gdata['Frame'].values
|
|
2187
|
+
Gdata = Gdata['G'].values.astype(float)
|
|
2188
|
+
#print(f'Shape of OP: Q {Qdata.shape} G {Gdata.shape}')
|
|
2189
|
+
|
|
2190
|
+
## Quality check that QFrames == GFrames.
|
|
2191
|
+
if set(QFrames[traj]) != set(GFrames[traj]):
|
|
2192
|
+
raise ValueError(f'The frames in Q {QFrames[traj]} do not match the frames in G {GFrames[traj]} for Traj {traj}. Please check your data files.')
|
|
2193
|
+
|
|
2194
|
+
## Quality check that the G and Q data has the same number of frames
|
|
2195
|
+
if Qdata.shape != Gdata.shape:
|
|
2196
|
+
self.logger.warning(f"WARNING: The number of frames in Q {Qdata.shape} should equal the number of frames in G {Gdata.shape} in Traj {traj}")
|
|
2197
|
+
continue
|
|
2198
|
+
|
|
2199
|
+
## Check and ensure that Qdata or Gdata has no nan values
|
|
2200
|
+
if np.isnan(Qdata).any():
|
|
2201
|
+
raise ValueError(f'There is a NaN value in this Qdata')
|
|
2202
|
+
|
|
2203
|
+
if np.isnan(Gdata).any():
|
|
2204
|
+
raise ValueError(f'There is a NaN value in this Gdata')
|
|
2205
|
+
|
|
2206
|
+
data = np.stack((Qdata, Gdata)).T
|
|
2207
|
+
data = data.astype(float)
|
|
2208
|
+
cor_list.append(data)
|
|
2209
|
+
|
|
2210
|
+
cor_list_idx_2_traj[idx] = int(traj)
|
|
2211
|
+
idx += 1
|
|
2212
|
+
|
|
2213
|
+
self.logger.info(f'Number of trajecotry OP coordinate loaded: {len(cor_list)}')
|
|
2214
|
+
self.cor_list = cor_list
|
|
2215
|
+
self.QFrames = QFrames
|
|
2216
|
+
self.GFrames = GFrames
|
|
2217
|
+
|
|
2218
|
+
## Quality check that the number of trajectories loaded is equal to the number of Q and G files
|
|
2219
|
+
assert len(cor_list) == len(shared_Trajs), f"The # of coordinates loaded {len(cor_list)} does not equal the number of Q and G files with shared traj after removal of mirror images {len(shared_Trajs)}"
|
|
2220
|
+
|
|
2221
|
+
self.logger.info(f'Mapping of cor_list index to trajID in file names: {cor_list_idx_2_traj}')
|
|
2222
|
+
self.cor_list_idx_2_traj = cor_list_idx_2_traj
|
|
2223
|
+
#######################################################################################
|
|
2224
|
+
|
|
2225
|
+
#######################################################################################
|
|
2226
|
+
def standardize(self,):
|
|
2227
|
+
"""
|
|
2228
|
+
Standardizes your OP by taking the mean and std Q and G across all traj data and rescaling each trajectorys data by
|
|
2229
|
+
Z = (d - mean)/std
|
|
2230
|
+
"""
|
|
2231
|
+
data_con = self.cor_list[0]
|
|
2232
|
+
for i in range(1, len(self.cor_list)):
|
|
2233
|
+
data_con = np.vstack((data_con, self.cor_list[i]))
|
|
2234
|
+
self.data_mean = np.mean(data_con, axis=0)
|
|
2235
|
+
self.data_std = np.std(data_con, axis=0)
|
|
2236
|
+
self.standard_cor_list = [(d - self.data_mean) / self.data_std for d in self.cor_list]
|
|
2237
|
+
#######################################################################################
|
|
2238
|
+
|
|
2239
|
+
#######################################################################################
|
|
2240
|
+
def unstandardize(self, data):
|
|
2241
|
+
"""
|
|
2242
|
+
Unstandardizes your OP by taking the mean and std Q and G across all traj data and rescaling each trajectorys data by
|
|
2243
|
+
Z*std + mean = d
|
|
2244
|
+
"""
|
|
2245
|
+
return data*self.data_std + self.data_mean
|
|
2246
|
+
#######################################################################################
|
|
2247
|
+
|
|
2248
|
+
#######################################################################################
|
|
2249
|
+
def cluster(self,):
|
|
2250
|
+
"""
|
|
2251
|
+
Cluster the GQ data across all trajectories using kmeans.
|
|
2252
|
+
dtrajs contains the resulting kmeans cluster labels for each trajectory time series
|
|
2253
|
+
centers contains the standardized GQ coordinates of the cluster centers
|
|
2254
|
+
|
|
2255
|
+
if the number of unique centers found is not equal to self.n_cluster then adjust it to reflect the number found. This can happen if you have data that has a narrow distribution.
|
|
2256
|
+
"""
|
|
2257
|
+
self.clusters = pem.coordinates.cluster_kmeans(self.standard_cor_list, k=self.n_cluster, max_iter=5000, stride=self.kmean_stride)
|
|
2258
|
+
|
|
2259
|
+
# Get the microstate tagged trajectories and their state counts
|
|
2260
|
+
self.dtrajs = self.clusters.dtrajs
|
|
2261
|
+
self.logger.debug(f'dtrajs: {len(self.dtrajs)} {self.dtrajs[0].shape}\n{self.dtrajs[0][:10]}')
|
|
2262
|
+
clusterIDs, counts = np.unique(self.dtrajs, return_counts=True)
|
|
2263
|
+
self.logger.info(f'Number of unique microstate IDs: {len(clusterIDs)} {clusterIDs}')
|
|
2264
|
+
|
|
2265
|
+
state_counts = {}
|
|
2266
|
+
for i,c in zip(clusterIDs, counts):
|
|
2267
|
+
state_counts[i] = c
|
|
2268
|
+
self.logger.debug(f'state_counts: {state_counts}')
|
|
2269
|
+
|
|
2270
|
+
# Quality check that all microstate ids are assigned
|
|
2271
|
+
# If not renumber from 0
|
|
2272
|
+
if len(clusterIDs) != self.n_cluster:
|
|
2273
|
+
self.logger.info(f'The number of microstate IDs assigned does not match the number specified: {len(clusterIDs)} != {self.n_cluster}')
|
|
2274
|
+
|
|
2275
|
+
mapping_dict = {}
|
|
2276
|
+
for new,old in enumerate(clusterIDs):
|
|
2277
|
+
mapping_dict[old] = new
|
|
2278
|
+
self.logger.debug(f'mapping_dict: {mapping_dict}')
|
|
2279
|
+
|
|
2280
|
+
# Convert the dictionary to a numpy array for efficient mapping
|
|
2281
|
+
max_key = max(mapping_dict.keys())
|
|
2282
|
+
mapping_array = np.zeros(max_key + 1, dtype=int)
|
|
2283
|
+
for key, value in mapping_dict.items():
|
|
2284
|
+
mapping_array[key] = value
|
|
2285
|
+
|
|
2286
|
+
# Map the arrays using the mapping array
|
|
2287
|
+
self.dtrajs = [mapping_array[arr] for arr in self.dtrajs]
|
|
2288
|
+
|
|
2289
|
+
clusterIDs, counts = np.unique(self.dtrajs, return_counts=True)
|
|
2290
|
+
self.logger.info(f'Number of unique microstate IDs after mapping: {len(clusterIDs)} {clusterIDs}')
|
|
2291
|
+
state_counts = {}
|
|
2292
|
+
for i,c in zip(clusterIDs, counts):
|
|
2293
|
+
state_counts[i] = c
|
|
2294
|
+
self.logger.debug(f'state_counts: {state_counts}')
|
|
2295
|
+
|
|
2296
|
+
self.n_cluster = len(clusterIDs)
|
|
2297
|
+
|
|
2298
|
+
|
|
2299
|
+
standard_centers = self.clusters.clustercenters
|
|
2300
|
+
unstandard_centers = self.unstandardize(standard_centers)
|
|
2301
|
+
self.logger.info(f'unstandard_centers:\n{unstandard_centers} {unstandard_centers.shape}')
|
|
2302
|
+
self.logger.info(f'self.n_cluster: {self.n_cluster}')
|
|
2303
|
+
|
|
2304
|
+
#######################################################################################
|
|
2305
|
+
|
|
2306
|
+
#######################################################################################
|
|
2307
|
+
def build_msm(self, lagtime=1):
|
|
2308
|
+
self.logger.info(f'Building MSM model with a lag time of {lagtime}')
|
|
2309
|
+
|
|
2310
|
+
# Get count matrix and connective groups of microstates
|
|
2311
|
+
c_matrix = deeptime.markov.tools.estimation.count_matrix(self.dtrajs, lagtime).toarray()
|
|
2312
|
+
self.logger.info(f'c_matrix:\n{c_matrix} {c_matrix.shape}')
|
|
2313
|
+
|
|
2314
|
+
sub_groups = deeptime.markov.tools.estimation.connected_sets(c_matrix)
|
|
2315
|
+
self.logger.info(f'Total number of sub_groups: {len(sub_groups)}\n{sub_groups}')
|
|
2316
|
+
|
|
2317
|
+
# Build the MSM models for any connected sets that have more than 1 microstate
|
|
2318
|
+
msm_list = []
|
|
2319
|
+
for sg in sub_groups:
|
|
2320
|
+
cm = deeptime.markov.tools.estimation.largest_connected_submatrix(c_matrix, lcc=sg)
|
|
2321
|
+
self.logger.info(f'For sub_group: {sg}')
|
|
2322
|
+
if len(cm) == 1:
|
|
2323
|
+
msm = None
|
|
2324
|
+
else:
|
|
2325
|
+
self.logger.info(f'Building Transition matrix and MSM model')
|
|
2326
|
+
T = deeptime.markov.tools.estimation.transition_matrix(cm, reversible=True)
|
|
2327
|
+
msm = pem.msm.markov_model(T, dt_model=str(self.dt)+' ns')
|
|
2328
|
+
msm_list.append(msm)
|
|
2329
|
+
self.logger.info(f'Number of models: {len(msm_list)}')
|
|
2330
|
+
|
|
2331
|
+
# Coarse grain out the metastable macrostates in the models
|
|
2332
|
+
self.logger.info(f'Coarse grain out the metastable macrostates in the models')
|
|
2333
|
+
meta_dist = []
|
|
2334
|
+
meta_set = []
|
|
2335
|
+
eigenvalues_list = []
|
|
2336
|
+
for idx_msm, msm in enumerate(msm_list):
|
|
2337
|
+
|
|
2338
|
+
# the first model should contain the largest connected state so use the largest number of metastable states
|
|
2339
|
+
# for every other subgroup use the smallest
|
|
2340
|
+
if idx_msm == 0:
|
|
2341
|
+
n_states = self.n_large_states
|
|
2342
|
+
else:
|
|
2343
|
+
n_states = self.n_small_states
|
|
2344
|
+
|
|
2345
|
+
if msm == None:
|
|
2346
|
+
eigenvalues_list.append(None)
|
|
2347
|
+
dist = np.zeros(self.n_cluster)
|
|
2348
|
+
iidx = sub_groups[idx_msm][0]
|
|
2349
|
+
dist[iidx] = 1.0
|
|
2350
|
+
meta_dist.append(dist)
|
|
2351
|
+
meta_set.append(sub_groups[idx_msm])
|
|
2352
|
+
|
|
2353
|
+
else:
|
|
2354
|
+
eigenvalues_list.append(msm.eigenvalues())
|
|
2355
|
+
# coarse-graining
|
|
2356
|
+
while n_states > 1:
|
|
2357
|
+
tag_empty = False
|
|
2358
|
+
pcca = msm.pcca(n_states)
|
|
2359
|
+
for ms in msm.metastable_sets:
|
|
2360
|
+
if ms.size == 0:
|
|
2361
|
+
tag_empty = True
|
|
2362
|
+
break
|
|
2363
|
+
if not tag_empty:
|
|
2364
|
+
break
|
|
2365
|
+
else:
|
|
2366
|
+
n_states -= 1
|
|
2367
|
+
self.logger.info('Reduced number of states to %d for active group %d'%(n_states, idx_msm+1))
|
|
2368
|
+
if n_states == 1:
|
|
2369
|
+
# use observation prob distribution for non-active set
|
|
2370
|
+
dist = np.zeros(self.n_cluster)
|
|
2371
|
+
for nas in sub_groups[idx_msm]:
|
|
2372
|
+
for dtraj in dtrajs:
|
|
2373
|
+
dist[nas] += np.count_nonzero(dtraj == nas)
|
|
2374
|
+
dist /= np.sum(dist)
|
|
2375
|
+
meta_dist.append(dist)
|
|
2376
|
+
meta_set.append(sub_groups[idx_msm])
|
|
2377
|
+
else:
|
|
2378
|
+
for i, md in enumerate(msm.metastable_distributions):
|
|
2379
|
+
dist = np.zeros(self.n_cluster)
|
|
2380
|
+
s = np.sum(md[msm.metastable_sets[i]])
|
|
2381
|
+
set_0 = []
|
|
2382
|
+
for idx in msm.metastable_sets[i]:
|
|
2383
|
+
iidx = sub_groups[idx_msm][idx]
|
|
2384
|
+
dist[iidx] = md[idx]
|
|
2385
|
+
set_0.append(iidx)
|
|
2386
|
+
dist = dist / s
|
|
2387
|
+
meta_dist.append(dist)
|
|
2388
|
+
meta_set.append(set_0)
|
|
2389
|
+
meta_dist = np.array(meta_dist)
|
|
2390
|
+
self.logger.debug(f'meta_dist: {len(meta_dist)} {meta_dist.shape}')
|
|
2391
|
+
meta_dist_outfile = os.path.join(self.outdir, f'{self.ID}_meta_dist.npy')
|
|
2392
|
+
np.save(meta_dist_outfile, meta_dist, allow_pickle=True)
|
|
2393
|
+
self.logger.info(f'SAVED: {meta_dist_outfile}')
|
|
2394
|
+
|
|
2395
|
+
# print(f'meta_set: {meta_set}')
|
|
2396
|
+
meta_set_df = {'metastable_state':[], 'microstates':[]}
|
|
2397
|
+
for i, ms in enumerate(meta_set):
|
|
2398
|
+
# print(f'Metastable state {i}: {ms} with {len(ms)} microstates')
|
|
2399
|
+
for m in ms:
|
|
2400
|
+
meta_set_df['metastable_state'].append(i)
|
|
2401
|
+
meta_set_df['microstates'].append(m)
|
|
2402
|
+
|
|
2403
|
+
meta_set_df = pd.DataFrame(meta_set_df)
|
|
2404
|
+
self.logger.info(f'Meta set DataFrame:\n{meta_set_df}')
|
|
2405
|
+
meta_set_outfile = os.path.join(self.outdir, f'{self.ID}_meta_set.csv')
|
|
2406
|
+
meta_set_df.to_csv(meta_set_outfile, index=False)
|
|
2407
|
+
self.logger.info(f'SAVED: {meta_set_outfile}')
|
|
2408
|
+
|
|
2409
|
+
|
|
2410
|
+
## make microstate to metastable state mapping object
|
|
2411
|
+
self.logger.info(f'\nMetastable state assignment')
|
|
2412
|
+
meta_mapping = {}
|
|
2413
|
+
for metaID, microstates in enumerate(meta_set):
|
|
2414
|
+
#print(metaID, microstates)
|
|
2415
|
+
for m in microstates:
|
|
2416
|
+
if m not in meta_mapping:
|
|
2417
|
+
meta_mapping[m] = metaID
|
|
2418
|
+
else:
|
|
2419
|
+
raise ValueError(f'Microstate {m} already in a metastable state!')
|
|
2420
|
+
self.logger.debug(f'meta_mapping: {meta_mapping} {len(meta_mapping)}')
|
|
2421
|
+
|
|
2422
|
+
# map those microstate states to the metastable state
|
|
2423
|
+
metastable_dtraj = []
|
|
2424
|
+
for dtraj_idx, dtraj in enumerate(self.dtrajs):
|
|
2425
|
+
mapped_dtraj = []
|
|
2426
|
+
for d in dtraj:
|
|
2427
|
+
mapped_dtraj.append(meta_mapping[d])
|
|
2428
|
+
|
|
2429
|
+
#rint(mapped_dtraj)
|
|
2430
|
+
metastable_dtraj += [np.asarray(mapped_dtraj)]
|
|
2431
|
+
|
|
2432
|
+
self.logger.info(f'Metastable state mapping:')
|
|
2433
|
+
for dtraj_idx, dtraj in enumerate(metastable_dtraj):
|
|
2434
|
+
self.logger.debug(f'dtraj_idx={dtraj_idx} dtrajs[:10]={self.dtrajs[dtraj_idx][:10]} dtraj[:10]={dtraj[:10]} shape={dtraj.shape}')
|
|
2435
|
+
|
|
2436
|
+
|
|
2437
|
+
## get samples of metastable states by most populated microstates
|
|
2438
|
+
self.logger.debug(f'num_dtrajs={len(self.dtrajs)} dtrajs[0].shape={self.dtrajs[0].shape}')
|
|
2439
|
+
cluster_indexes = deeptime.markov.sample.compute_index_states(self.dtrajs)
|
|
2440
|
+
self.logger.debug(f'cluster_indexes: {len(cluster_indexes)}')
|
|
2441
|
+
|
|
2442
|
+
|
|
2443
|
+
samples = deeptime.markov.sample.indices_by_distribution(cluster_indexes, meta_dist, 5)
|
|
2444
|
+
self.logger.debug(f'samples: {samples} {len(samples)}')
|
|
2445
|
+
|
|
2446
|
+
## Make the output dataframe that has assignments for each frame of each traj
|
|
2447
|
+
df = {'traj':[], 'frame':[], 'microstate':[], 'metastablestate':[], 'Q':[], 'G':[], 'StateSample':[]}
|
|
2448
|
+
self.logger.info(f'Active & inactive metastable state mapping')
|
|
2449
|
+
for k,v in enumerate(metastable_dtraj):
|
|
2450
|
+
traj = self.cor_list_idx_2_traj[k]
|
|
2451
|
+
#print(k, traj, v[:10])
|
|
2452
|
+
for frame, macrostate in enumerate(v):
|
|
2453
|
+
microstate = self.dtrajs[k][frame]
|
|
2454
|
+
Q = self.cor_list[k][frame, 0]
|
|
2455
|
+
G = self.cor_list[k][frame, 1]
|
|
2456
|
+
|
|
2457
|
+
if [k, frame] in samples[macrostate].tolist():
|
|
2458
|
+
StateSample = True
|
|
2459
|
+
else:
|
|
2460
|
+
StateSample = False
|
|
2461
|
+
#print(k, frame, microstate, macrostate, StateSample)
|
|
2462
|
+
df['traj'] += [traj]
|
|
2463
|
+
df['frame'] += [self.QFrames[traj][frame]]
|
|
2464
|
+
df['microstate'] += [microstate]
|
|
2465
|
+
df['metastablestate'] += [macrostate]
|
|
2466
|
+
df['Q'] += [Q]
|
|
2467
|
+
df['G'] += [G]
|
|
2468
|
+
df['StateSample'] += [StateSample]
|
|
2469
|
+
|
|
2470
|
+
df = pd.DataFrame(df)
|
|
2471
|
+
#df['frame'] = self.start # correct the frame index to start from the start specified by the user as this frame index starts from 0
|
|
2472
|
+
self.logger.info(f'Final MSM mapping DF:\n{df}')
|
|
2473
|
+
df_outfile = os.path.join(self.outdir, f'{self.ID}_MSMmapping.csv')
|
|
2474
|
+
df.to_csv(df_outfile, index=False)
|
|
2475
|
+
self.logger.info(f'SAVED: {df_outfile}')
|
|
2476
|
+
|
|
2477
|
+
# Plot the metastable state membership and free energy surface
|
|
2478
|
+
xall = np.hstack([dtraj[:, 0] for dtraj in self.cor_list])
|
|
2479
|
+
yall = np.hstack([dtraj[:, 1] for dtraj in self.cor_list])
|
|
2480
|
+
states = np.hstack(metastable_dtraj)
|
|
2481
|
+
self.logger.debug(f'xall: {xall} {xall.shape}')
|
|
2482
|
+
self.logger.debug(f'yall: {yall} {yall.shape}')
|
|
2483
|
+
self.logger.debug(f'states: {states} {states.shape}')
|
|
2484
|
+
|
|
2485
|
+
stateplot_outfile = os.path.join(self.outdir, f'{self.ID}_StateAndFEplot.png')
|
|
2486
|
+
self.plot_state_map_and_FE(xall, yall, states, stateplot_outfile)
|
|
2487
|
+
#######################################################################################
|
|
2488
|
+
|
|
2489
|
+
#######################################################################################
|
|
2490
|
+
def plot_state_map_and_FE(self, x, y, states, outfile, cmap='viridis', point_size=50, alpha=0.85, title='State Map'):
    """
    Plot a 2D free-energy surface next to a state map and save the figure.

    Left panel: free energy computed as -log10 of the normalized 2D histogram of (x, y)
    (20 x 20 bins), drawn as a filled contour plot at the bin centres.
    Right panel: scatter of (x, y) coloured by each point's state assignment, using one
    discrete colour per unique state and a colorbar tick centred on each colour.

    Parameters:
        x (array-like): x-coordinates of the points (axis labelled Q, x-limits fixed to [0, 1]).
        y (array-like): y-coordinates of the points (axis labelled G).
        states (array-like): state assignment for each point; same length as x and y.
        outfile (str): path the figure is written to.
        cmap (str or Colormap): accepted for backward compatibility but not used; the
            free-energy panel uses 'gist_ncar' and the state panel a discrete 'tab20' palette.
        point_size (int): marker size of the state-map scatter points (default 50).
        alpha (float): transparency of the state-map scatter points (default 0.85).
        title (str): figure-level title (default 'State Map').

    Raises:
        ValueError: if there are no points to plot.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    states = np.asarray(states).ravel()
    if x.size == 0:
        # Without this guard the free-energy step fails with an opaque reduction error.
        raise ValueError('plot_state_map_and_FE: no points to plot (x is empty)')

    fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)

    #############################################################################################
    ### Left panel: free-energy surface
    num_bins = 20
    counts, xedges, yedges = np.histogram2d(x, y, bins=num_bins)
    probability = counts / np.sum(counts)

    # Empty bins give log10(0) = -inf; silence that warning and patch the values below.
    with np.errstate(divide='ignore'):
        free_energy = -np.log10(probability)
    # Fill empty bins (inf/NaN) with the highest finite free energy so the contour has no holes.
    free_energy[~np.isfinite(free_energy)] = np.max(free_energy[np.isfinite(free_energy)])
    self.logger.debug(f'free_energy: {free_energy} {free_energy.shape} {np.unique(free_energy)}')

    # Evaluate the surface at bin centres rather than bin left edges.
    xcenters = 0.5 * (xedges[:-1] + xedges[1:])
    ycenters = 0.5 * (yedges[:-1] + yedges[1:])
    X, Y = np.meshgrid(xcenters, ycenters)
    self.logger.debug(f'X: {X.shape}\nY: {Y.shape}')

    # histogram2d returns (nx, ny) while meshgrid ('xy' indexing) is (ny, nx): transpose.
    contour = axes[0].contourf(X, Y, free_energy.T, levels=100, cmap='gist_ncar')
    fig.colorbar(contour, ax=axes[0], label='Free Energy (-log10 Probability)')
    axes[0].set_xlabel('Q')
    axes[0].set_ylabel('G')
    axes[0].set_title('2D Free Energy Contour Plot')
    axes[0].set_xlim(0, 1)

    #############################################################################################
    ### Right panel: state map
    # color_indices[i] is the position of states[i] within the sorted unique states.
    unique_states, color_indices = np.unique(states, return_inverse=True)
    n_states = len(unique_states)
    self.logger.debug(f'unique_states: {unique_states} {n_states}')

    # One discrete colour per state. pyplot.get_cmap is used because matplotlib.cm.get_cmap
    # was removed in matplotlib 3.9.
    palette = plt.get_cmap('tab20', n_states)
    state_cmap = ListedColormap([palette(i) for i in range(n_states)])

    # vmin/vmax of -0.5 and n_states - 0.5 centre colour band i on integer i, so the
    # colorbar ticks at 0..n_states-1 each sit in the middle of their own colour.
    scatter = axes[1].scatter(
        x, y, c=color_indices, cmap=state_cmap, vmin=-0.5, vmax=n_states - 0.5,
        s=point_size, alpha=alpha, edgecolor='k',
    )
    cbar = fig.colorbar(scatter, ax=axes[1], ticks=np.arange(n_states), label=f'Metastable States')
    cbar.ax.set_yticklabels([str(state) for state in unique_states])

    axes[1].set_xlabel('Q')
    axes[1].set_ylabel('G')
    axes[1].set_title('2D state map')
    axes[1].set_xlim(0, 1)

    fig.suptitle(title)
    fig.savefig(outfile)
    self.logger.info(f'SAVED: {outfile}')
    # Close (not just clear) the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
|
|
2579
|
+
|
|
2580
|
+
#######################################################################################
|
|
2581
|
+
|
|
2582
|
+
#######################################################################################
|
|
2583
|
+
def plot_implied_timescales(self, lag_times=None):
    """
    Plot the implied timescales (ITS) of ``self.dtrajs`` and save the figure.

    Run this before building the model to find a lag time at which the timescales
    (derived from the eigenvalues of the transition matrix) no longer depend on it.
    Look for the point or range where the implied timescales stop changing significantly
    with increasing lag time. That lag time is generally a good choice for building the
    MSM, as it suggests the dynamics are captured without undue dependence on the
    initial conditions.

    Parameters:
        lag_times (array-like of int, optional): lag times to evaluate. Defaults to
            np.arange(1, 100, 10), i.e. 1, 11, ..., 91; adjust to your system.

    Output:
        Writes <self.outdir>/<self.ID>_ITS.png.
    """
    if lag_times is None:
        lag_times = np.arange(1, 100, 10)

    its = pem.msm.its(self.dtrajs, lags=lag_times, errors='bayes')
    pem.plots.plot_implied_timescales(its)

    # Save the current pyplot figure, then close it so repeated calls do not leak figures.
    fig = plt.gcf()
    ITS_outfile = os.path.join(self.outdir, f'{self.ID}_ITS.png')
    fig.savefig(ITS_outfile)
    self.logger.info(f'SAVED: {ITS_outfile}')
    plt.close(fig)
|
|
2597
|
+
#######################################################################################
|
|
2598
|
+
|
|
2599
|
+
#######################################################################################
|
|
2600
|
+
def run(self):
    """
    Run the MSM analysis pipeline.

    Steps:
      1. Create ``self.outdir`` if it does not exist.
      2. Load the G and Q order parameters (``self.load_OP``).
      3. Standardize them (``self.standardize``).
      4. Cluster the standardized data (``self.cluster``).
      5. If ``self.ITS`` is True (bool) or 'True' (str), plot the implied timescales and
         return without building a model, so a lag time can be chosen from the figure.
         Otherwise build the MSM with ``self.lagtime``.
    """
    ## make output folder
    if not os.path.exists(self.outdir):
        # exist_ok avoids a crash if the directory appears between the check and this call.
        os.makedirs(self.outdir, exist_ok=True)
        self.logger.info(f'Made directory: {self.outdir}')

    # load the G and Q data
    self.load_OP()

    # apply the standard scalar transformation to the data
    self.standardize()

    # cluster the standardized data
    self.cluster()

    # Implied-timescales mode: generate the ITS plot to choose a suitable lag time, then stop.
    # This should be done before building the model.
    if self.ITS is True or self.ITS == 'True':
        self.plot_implied_timescales()
        self.logger.info(f'Analysis terminated since ITS was selected. Check the figure and choose an approrate lagtime')
        # Return instead of quit(): quit() is a site-module helper meant for the interactive
        # shell, and it would kill the caller's whole process.
        return

    # Build the MSM model with the chosen lagtime
    self.build_msm(lagtime=self.lagtime)
|
|
2625
|
+
#######################################################################################
|
|
2626
|
+
|