EntDetect 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. EntDetect/Jwalk/GridTools.py +567 -0
  2. EntDetect/Jwalk/PDBTools.py +532 -0
  3. EntDetect/Jwalk/SASDTools.py +543 -0
  4. EntDetect/Jwalk/SurfaceTools.py +150 -0
  5. EntDetect/Jwalk/__init__.py +19 -0
  6. EntDetect/Jwalk/naccess.config.txt +255 -0
  7. EntDetect/__init__.py +10 -0
  8. EntDetect/_logging.py +71 -0
  9. EntDetect/change_resolution.py +2361 -0
  10. EntDetect/clustering.py +2626 -0
  11. EntDetect/compare_sim2exp.py +1927 -0
  12. EntDetect/entanglement_features.py +478 -0
  13. EntDetect/gaussian_entanglement.py +2067 -0
  14. EntDetect/order_params.py +1048 -0
  15. EntDetect/resources/__init__.py +11 -0
  16. EntDetect/resources/__pycache__/__init__.cpython-311.pyc +0 -0
  17. EntDetect/resources/calc_K.pl +712 -0
  18. EntDetect/resources/calc_Q.pl +962 -0
  19. EntDetect/resources/pulchra +0 -0
  20. EntDetect/resources/shared_files/__init__.py +2 -0
  21. EntDetect/resources/shared_files/bt_contact_potential.dat +22 -0
  22. EntDetect/resources/shared_files/karanicolas_dihe_parm.dat +1600 -0
  23. EntDetect/resources/shared_files/kgs_contact_potential.dat +22 -0
  24. EntDetect/resources/shared_files/mj_contact_potential.dat +22 -0
  25. EntDetect/resources/stride +0 -0
  26. EntDetect/statistics.py +1344 -0
  27. EntDetect/utilities.py +201 -0
  28. entdetect-1.2.0.dist-info/METADATA +26 -0
  29. entdetect-1.2.0.dist-info/RECORD +45 -0
  30. entdetect-1.2.0.dist-info/WHEEL +5 -0
  31. entdetect-1.2.0.dist-info/entry_points.txt +11 -0
  32. entdetect-1.2.0.dist-info/licenses/LICENSE +674 -0
  33. entdetect-1.2.0.dist-info/top_level.txt +2 -0
  34. scripts/__init__.py +5 -0
  35. scripts/convert_cor_psf_to_pdb.py +103 -0
  36. scripts/run_Foldingpathway.py +162 -0
  37. scripts/run_MSM.py +152 -0
  38. scripts/run_OP_on_simulation_traj.py +194 -0
  39. scripts/run_change_resolution.py +63 -0
  40. scripts/run_compare_sim2exp.py +215 -0
  41. scripts/run_montecarlo.py +158 -0
  42. scripts/run_nativeNCLE.py +179 -0
  43. scripts/run_nonnative_entanglement_clustering.py +110 -0
  44. scripts/run_population_modeling.py +117 -0
  45. scripts/run_workflow4_nativeNCLE_batch.py +412 -0
@@ -0,0 +1,2626 @@
1
+ #!/usr/bin/env python3
2
+ from collections import defaultdict
3
+ import numpy as np
4
+ import itertools
5
+ from geom_median.numpy import compute_geometric_median
6
+ from scipy.spatial.distance import cdist, squareform
7
+ from functools import cache
8
+ import re
9
+ import random
10
+ import pandas as pd
11
+ import pickle
12
+ import logging
13
+ import sys, getopt, math, os, time, traceback, glob, copy
14
+ import threading
15
+ from concurrent.futures import ThreadPoolExecutor, as_completed
16
+ from EntDetect._logging import setup_logger
17
+ from scipy.cluster.hierarchy import fcluster, linkage, cophenet
18
+ try:
19
+ import parmed as pmd
20
+ import mdtraj as mdt
21
+ except ImportError:
22
+ pmd = None
23
+ mdt = None
24
+ import matplotlib
25
+ import matplotlib.pyplot as plt
26
+ try:
27
+ import pyemma as pem
28
+ import deeptime
29
+ except ImportError:
30
+ pem = None
31
+ deeptime = None
32
+ from matplotlib.cm import get_cmap
33
+ from matplotlib.colors import ListedColormap, BoundaryNorm
34
+ import matplotlib.colors as mcolors
35
+ import seaborn as sns
36
+ from scipy.stats import mode
37
+ import pathlib
38
+ from dataclasses import dataclass
39
+ import json
40
+
41
+ matplotlib.use('Agg')
42
+ pd.set_option('display.max_rows', 500)
43
+
44
+ class ClusterNativeEntanglements:
45
+ """
46
+ Class to calculate native entanglements given either a file path to an entanglement file or an entanglement object
47
+ """
48
+
49
+ ##########################################################################################################################################################
50
def __init__(self, organism: str = 'Ecoli', cut_off: int = None, outdir: str = None, log_level: int = logging.INFO, logdir: str = None) -> None:
    """
    Constructor for the ClusterNativeEntanglements class.

    Parameters
    ----------
    organism : str
        Selects the built-in clustering distance cutoff:
        'Human' -> 52, 'Ecoli' -> 57, 'Yeast' -> 49.
    cut_off : int, optional
        Explicit distance cutoff. When given it overrides the organism default.
    outdir : str, optional
        Forwarded to setup_logger as its ``outdir`` when ``logdir`` is None.
    log_level : int
        Forwarded to setup_logger as ``log_level``.
    logdir : str, optional
        Forwarded to setup_logger as its ``outdir``; takes precedence over ``outdir``.

    Raises
    ------
    ValueError
        If ``cut_off`` is None and ``organism`` has no built-in default.
        Without this check ``self.cut_off`` would be left undefined and the
        failure would only appear later as an AttributeError.
    """
    default_cut_offs = {'Human': 52, 'Ecoli': 57, 'Yeast': 49}

    # An explicit cut_off always wins over the organism default.
    if cut_off is not None:
        self.cut_off = cut_off
    elif organism in default_cut_offs:
        self.cut_off = default_cut_offs[organism]
    else:
        raise ValueError(
            f'Unknown organism {organism!r}: expected one of {sorted(default_cut_offs)} '
            f'or an explicit cut_off'
        )

    self.organism = organism
    # logdir, when provided, is used as the logger output directory instead of outdir.
    self.logger = setup_logger('ClusterNativeEntanglements', outdir=logdir if logdir is not None else outdir, log_level=log_level)
69
+ ##########################################################################################################################################################
70
+
71
+ ##########################################################################################################################################################
72
def loop_distance(self, entangled_A: tuple, entangled_B: tuple):
    """
    Euclidean distance between two entanglement tuples.

    For each tuple, elements 1 and 2 are used as-is and every element in
    ``[3:-1]`` is a crossing string whose first character (the chirality
    sign) is dropped before the remainder is converted to int. Element 0 and
    the last element do not take part in the distance.
    """

    def _coordinates(entangled: tuple) -> tuple:
        # Strip the leading chirality character from each crossing.
        crossing_positions = (int(crossing[1:]) for crossing in entangled[3:-1])
        return (entangled[1], entangled[2], *crossing_positions)

    point_A = _coordinates(entangled_A)
    point_B = _coordinates(entangled_B)
    return math.dist(point_A, point_B)
82
+ ##########################################################################################################################################################
83
+
84
+ ##########################################################################################################################################################
85
def check_step_ij_kl_range(self, ent1: tuple, ent2: tuple):
    """
    Check if i or j of (i,j) reside within the range (inclusive) of (k,l), and vice versa.

    Parameters
    ----------
    ent1, ent2 : tuple
        Entanglement tuples; elements 1 and 2 are read as the loop
        end points (i, j) and (k, l) respectively.

    Returns
    -------
    bool
        True if either end point of one loop lies inside the closed interval
        spanned by the other loop, False otherwise.
    """
    i, j = ent1[1], ent1[2]
    k, l = ent2[1], ent2[2]

    # Direct comparisons replace membership tests against np.arange(lo, hi + 1).
    # For integer residue indices the result is the same, without allocating
    # two arrays on every call (this runs once per entanglement pair).
    i_or_j_inside_kl = (k <= i <= l) or (k <= j <= l)
    k_or_l_inside_ij = (i <= k <= j) or (i <= l <= j)

    return bool(i_or_j_inside_kl or k_or_l_inside_ij)
103
+ ##########################################################################################################################################################
104
+
105
+ ##########################################################################################################################################################
106
+ #@cache
107
+ def Cluster_NativeEntanglements(self, GE_filepath: str, outdir: str='./', outfile: str='Cluster_NativeEntanglements.txt', chain: str=None):
108
+
109
+ """
110
+ PARAMS:
111
+ GE_file: str
112
+ cut_off: int
113
+
114
+ 1. Identify all unique "residue crossing set and chiralites"
115
+
116
+ 1b. sort the residues along with the chiralities
117
+
118
+ 2. Find the minimal loop encompassing a given "residue crossing set and chiralites"
119
+
120
+ 2b
121
+
122
+ i. Identify entanglements that have any crossing residues between them that are
123
+ less than or equal to 3 residues apart and have the same chirality.
124
+
125
+ ii. Then check if i or j of (i,j) reside within the range (inclusive) of (k,l), or vice versa;
126
+
127
+ iii. If yes, then check if any crossing residues are in the range of min(i,j,k,l) to max(i,j,k,l);
128
+ if yes, skip rest of 2
129
+
130
+ iv. If no, then check if the number of crossing residues, in each residue set, are different;
131
+
132
+ v. All crossing residue(s) in the entanglement with the fewer crossings need to have a distance <= 20
133
+ with the crossings in the other entanglement. Do this by the "brute force" approach and
134
+ the true distance formula. This means, calculate the distances and take the minimal distance
135
+ as the distance you check that is less than or equal to 20.
136
+
137
+ If yes, then keep the {i,j} {r} that have the greatest number of crossing residues;
138
+ If not, then keep the two entanglements separate.
139
+
140
+ 3. For at least two entanglements each with 1 or more crossings.
141
+ Loop over the entanglments two at time (avoid double counting)
142
+ Check if i or j of (i,j) reside within the range (inclusive) of (k,l), or vice versa;
143
+ If yes, check if number of crossing residues is the same (and it is 1 or more)
144
+ If yes, calculate the distances between all crossing residues
145
+ and if both have the same chiralities.
146
+ (Do NOT use brute force, just compare 1-to-1 index of crossing residues).
147
+ If all the distances are less than or equal to 20, then determine which
148
+ entanglement has the smaller loop, remove the entanglement with the larger loop
149
+
150
+ 4. Spatially cluster those outputs that have (i) the same number of crossings and (ii) the same chiralities
151
+
152
+ """
153
+ self.logger.info(f'Clustering {self.organism} Native Entanglements with dist_cutoff: {self.cut_off}')
154
+ GE_file = GE_filepath.split('/')[-1]
155
+ self.logger.debug(f'{GE_file} cut_off={self.cut_off} outdir={outdir}')
156
+
157
+ full_entanglement_data = defaultdict(list)
158
+
159
+ ent_data = defaultdict(list)
160
+
161
+ rep_ID_ent = defaultdict(list)
162
+
163
+ grouped_entanglement_data = defaultdict(list)
164
+
165
+ Before_cr_dist = defaultdict(list)
166
+
167
+ After_cr_dist = defaultdict(list)
168
+
169
+ entanglement_partial_g_data = {}
170
+
171
+ ## Check if the clustering file is already made and if so use it
172
+ outfilepath = os.path.join(f'{outdir}', f'{outfile}')
173
+ if os.path.exists(outfilepath):
174
+ self.logger.info(f'{outfilepath} ALREADY EXISTS AND WILL BE LOADED')
175
+ outdf = pd.read_csv(outfilepath, sep='|')
176
+ return {'outfile':outfilepath, 'ent_result':outdf}
177
+
178
+ self.logger.info(f'Loading {GE_filepath}')
179
+ GE_data = pd.read_csv(GE_filepath, sep='|', dtype={'crossingsN': str, 'crossingsC': str})
180
+ GE_data = GE_data[GE_data['ENT'] == True].reset_index(drop=True)
181
+ # if Quality is a column name then only get the High Quality raw entanglements
182
+ if 'Quality' in GE_data.keys():
183
+ GE_data = GE_data[GE_data['Quality'] == 'High'].reset_index(drop=True)
184
+
185
+ GE_data = GE_data.replace(np.nan, '', regex=True)
186
+ self.logger.debug(GE_data)
187
+
188
+ ### STEP 1 INITAL LOADING AND MERGING ################################################################################################################
189
+ ############################################################################################
190
+ ## parse the entanglement file and
191
+ ## get those native contacts that are disulfide bonds
192
+ self.logger.info(f'# Step 1')
193
+ CCBonds = []
194
+ num_raw_ents = {}
195
+ chain_info = {} # Track chain for each ID
196
+ for rowi, row in GE_data.iterrows():
197
+ # print(row)
198
+
199
+ ID = row['ID']
200
+ chain = row['chain'] if 'chain' in row else 'A' # Default to 'A' if chain not present
201
+ chain_info[ID] = chain
202
+ native_contact_i, native_contact_j = row['i'], row['j']
203
+
204
+ # Store separate N and C crossings
205
+ crossingsN = row['crossingsN'] if pd.notna(row['crossingsN']) and row['crossingsN'] != '' else ''
206
+ crossingsC = row['crossingsC'] if pd.notna(row['crossingsC']) and row['crossingsC'] != '' else ''
207
+
208
+ crossing_res = [cr for cr in [crossingsN, crossingsC] if cr != '']
209
+ crossing_res = ','.join(crossing_res)
210
+ self.logger.debug(f'{native_contact_i} {native_contact_j} {crossing_res}')
211
+
212
+ gn, gc = row['gn'], row['gc']
213
+ GLNn, GLNc = row['GLNn'], row['GLNc']
214
+ TLNn, TLNc = row['TLNn'], row['TLNc']
215
+ # Handle empty strings from NaN replacement: convert to np.nan
216
+ TLNn = np.nan if (isinstance(TLNn, str) and TLNn == '') else TLNn
217
+ TLNc = np.nan if (isinstance(TLNc, str) and TLNc == '') else TLNc
218
+ CCbond = row['CCbond']
219
+
220
+ # keep track of number of raw ents for QC purposes
221
+ if ID not in num_raw_ents:
222
+ num_raw_ents[ID] = 1
223
+ else:
224
+ num_raw_ents[ID] += 1
225
+
226
+ #native_contact_i, native_contact_j, crossing_res = line[1], line[2], line[3]
227
+ #native_contact_i = int(native_contact_i)
228
+ #native_contact_j = int(native_contact_j)
229
+
230
+ reformat_cr = crossing_res.split(',') if crossing_res else []
231
+
232
+ # Filter out any empty strings from the split
233
+ reformat_cr = [cr for cr in reformat_cr if cr]
234
+
235
+ if reformat_cr:
236
+ reformat_cr = sorted(reformat_cr, key = lambda x: int(re.split("\\+|-", x, maxsplit= 1)[1]))
237
+ #print(native_contact_i, native_contact_j, reformat_cr)
238
+
239
+
240
+ # Step 1 and 1b
241
+ grouped_entanglement_data[(ID, *reformat_cr)].append((native_contact_i, native_contact_j))
242
+
243
+ entanglement_partial_g_data[(native_contact_i, native_contact_j, *reformat_cr)] = (gn, gc, GLNn, GLNc, TLNn, TLNc, crossingsN, crossingsC)
244
+
245
+ #print(f'CCbond: {CCbond}')
246
+ if CCbond == True:
247
+ CCBonds += [(native_contact_i, native_contact_j)]
248
+
249
+ #print(f'num_raw_ents: {num_raw_ents}')
250
+
251
+ #print(f'Step 1 results')
252
+ Step1_QC_counter = 0
253
+ for k,v in grouped_entanglement_data.items():
254
+ #print(k,v)
255
+ Step1_QC_counter += len(v)
256
+
257
+ # STEP 1 SUMMARY
258
+ self.logger.info(f'\n{"="*100}')
259
+ self.logger.info(f'STEP 1 SUMMARY: Data Loading and Grouping by Unique Crossing Sets')
260
+ self.logger.info(f'{"="*100}')
261
+ self.logger.info(f'Total raw entanglements loaded: {Step1_QC_counter}')
262
+ self.logger.info(f'Number of protein IDs: {len(num_raw_ents)}')
263
+ for prot_id, count in sorted(num_raw_ents.items()):
264
+ self.logger.info(f' - {prot_id}: {count} raw entanglements')
265
+ self.logger.info(f'Unique crossing patterns identified: {len(grouped_entanglement_data)}')
266
+ self.logger.info(f'Disulfide bonds found: {len(CCBonds)}')
267
+ if CCBonds:
268
+ self.logger.info(f' - Disulfide bond pairs: {CCBonds}')
269
+
270
+ ### STEP 2 ################################################################################################################
271
+ ############################################################################################
272
+ # Step 2 Get the minimal loop encompassing each set of unique crossings
273
+ self.logger.info(f'\n# Step 2a')
274
+ for ID_cr, loops in grouped_entanglement_data.items():
275
+
276
+ ID = ID_cr[0]
277
+
278
+ crossings = np.asarray(list(ID_cr[1:]))
279
+
280
+ loop_lengths = [nc[1] - nc[0] for nc in loops]
281
+
282
+ minimum_loop_length = min(loop_lengths)
283
+
284
+ minimum_loop_length_index = loop_lengths.index(minimum_loop_length)
285
+
286
+ minimum_loop_nc_i, minimum_loop_nc_j = loops[minimum_loop_length_index]
287
+
288
+ ent_data[ID].append((len(loops), minimum_loop_nc_i, minimum_loop_nc_j, *crossings, ';'.join(['-'.join([str(loop[0]), str(loop[1])]) for loop in loops])))
289
+
290
+ # STEP 2a SUMMARY
291
+ self.logger.info(f'{"="*100}')
292
+ self.logger.info(f'STEP 2A SUMMARY: Minimal Loop Identification for Unique Crossing Sets')
293
+ self.logger.info(f'{"="*100}')
294
+ for ID, ents in ent_data.items():
295
+ Step2a_QC_counter = 0
296
+ for ent_i, ent in enumerate(ents):
297
+ #print(ID, ent_i, ent)
298
+ Step2a_QC_counter += ent[0]
299
+
300
+ self.logger.info(f'{ID}: {Step2a_QC_counter} raw entanglements grouped into {len(ents)} representative loops')
301
+ for ent_i, ent in enumerate(ents):
302
+ num_loops = ent[0]
303
+ loop_i, loop_j = ent[1], ent[2]
304
+ crossings = [str(c) for c in ent[3:-1]]
305
+ self.logger.info(f' Representative {ent_i+1}: Loop ({loop_i}, {loop_j}), ' +
306
+ f'Crossings={crossings if crossings else "none"}, ' +
307
+ f'Represents {num_loops} raw entanglement(s)')
308
+
309
+ ## QC that the number of tracked entanglements after step 2a is still valid
310
+ #print(f'Step2a_QC_counter: {Step2a_QC_counter} should = {num_raw_ents[ID]}')
311
+ if Step2a_QC_counter != num_raw_ents[ID]:
312
+ raise ValueError(f'The number of tracked entaglements after Step 2a {Step2a_QC_counter} != {num_raw_ents[ID]}')
313
+
314
+ ############################################################################################
315
+ # Step 2b:
316
+ self.logger.info(f'{"="*100}')
317
+ self.logger.info(f'STEP 2B: Merging of Entanglements Based on Crossing Proximity')
318
+ self.logger.info(f'{"="*100}')
319
+ merged_ents = []
320
+ for ID, ents in ent_data.items():
321
+ orig_ents = ents.copy()
322
+ comb_ents = itertools.combinations(ents, 2)
323
+
324
+ # for each pair of ents
325
+ for each_ent_pair in comb_ents:
326
+ #print(f'\nAnalyzing pair: {each_ent_pair}')
327
+ if each_ent_pair[0] == each_ent_pair[1]:
328
+ self.logger.info(f'Ents are the same: {each_ent_pair}')
329
+ continue
330
+
331
+ distance_thresholds = []
332
+
333
+ ent1 = each_ent_pair[0]
334
+ ent2 = each_ent_pair[1]
335
+
336
+ # get crossings from ent pair without chiralities
337
+ cr1 = set([int(ent_cr_1[1:]) for ent_cr_1 in list(ent1[3:-1])])
338
+ cr2 = set([int(ent_cr_2[1:]) for ent_cr_2 in list(ent2[3:-1])])
339
+ #print(cr1, cr2)
340
+
341
+ # get all possible pairs of the crossings
342
+ all_cr_pairs = itertools.product(ent1[3:-1], ent2[3:-1])
343
+
344
+ # get the distances between all pairs of crossings
345
+ cr_dist_same_chiral = np.abs([int(pr[0][1:]) - int(pr[1][1:]) for pr in all_cr_pairs if pr[0][0] == pr[1][0]])
346
+ #print(cr_dist_same_chiral)
347
+
348
+
349
+ # if any of those distances is less than 3 and the number of crossings is not the same and ij in range of kl
350
+ dist_check = np.any(cr_dist_same_chiral <= 3)
351
+ loop_check = self.check_step_ij_kl_range(ent1, ent2)
352
+ diff_cross_size_check = len(cr1) != len(cr2)
353
+ #print(dist_check, loop_check, diff_cross_size_check)
354
+
355
+ if np.any(cr_dist_same_chiral <= 3) and len(cr1) != len(cr2) and self.check_step_ij_kl_range(ent1, ent2):
356
+ #print(f'\nAnalyzing pair: {each_ent_pair}')
357
+ #print(f'step 2b conditions met')
358
+
359
+ minumum_loop_base = min(ent1[1], ent1[2], ent2[1], ent2[2])
360
+
361
+ maximum_loop_base = max(ent1[1], ent1[2], ent2[1], ent2[2])
362
+
363
+ all_crossings = cr1.union(cr2)
364
+
365
+ min_max_loop_base_range = set(range(minumum_loop_base, maximum_loop_base + 1))
366
+
367
+ # if the crossings are not within the min max loop range covering both entanglements
368
+ if not min_max_loop_base_range.intersection(all_crossings):
369
+
370
+ fewer_cr = min(cr1, cr2, key = len)
371
+ more_cr = max(cr1, cr2, key = len)
372
+
373
+ distributive_product = list(itertools.product(fewer_cr, more_cr))
374
+
375
+ slices = itertools.islice(distributive_product, 0, None, len(more_cr))
376
+
377
+ groupings = []
378
+
379
+ for end_point in slices:
380
+
381
+ first_index = distributive_product.index(end_point)
382
+
383
+ groupings.append(distributive_product[first_index:len(more_cr) + first_index])
384
+
385
+ if len(groupings) != 1:
386
+
387
+ all_pair_products = itertools.product(*groupings)
388
+ all_pair_groupings = set()
389
+
390
+ for pairs in all_pair_products:
391
+
392
+ flag = True
393
+
394
+ # check common elements column wise
395
+ stacked_pairs = np.stack(pairs)
396
+
397
+ for col in range(stacked_pairs.shape[1]):
398
+
399
+ if stacked_pairs[:, col].size != len(set(stacked_pairs[:, col])):
400
+
401
+ flag = False
402
+ break
403
+
404
+ if flag:
405
+
406
+ all_pair_groupings.add(pairs)
407
+
408
+ else:
409
+
410
+ all_pair_groupings = groupings[0]
411
+
412
+ for condensed_pair in all_pair_groupings:
413
+
414
+ if isinstance(condensed_pair[0], int):
415
+
416
+ # when dealing with ent with one crossing
417
+
418
+ condensed_pair = [condensed_pair]
419
+
420
+ dist = np.sqrt(sum([(each_ele[0] - each_ele[1]) ** 2 for each_ele in condensed_pair]))
421
+
422
+ distance_thresholds.append(dist)
423
+
424
+ # all_pair_groupings and distance thresholds have the same size
425
+ if min(distance_thresholds) <= 20:
426
+
427
+ min_ent = min(ent1, ent2, key = len)
428
+ max_ent = max(ent1, ent2, key = len)
429
+ if min_ent == max_ent:
430
+ self.logger.warning(f'WARNING: Ents are the same. Setting min_ent = ent1 and max_ent = ent2')
431
+ min_ent = ent1
432
+ max_ent = ent2
433
+ #print(f'ent1: {ent1} | ent2: {ent2}')
434
+ #print(f'min_ent: {min_ent}')
435
+ #print(f'max_ent: {max_ent}')
436
+
437
+ if min_ent in ents and len(ents) > 1:
438
+ #if min_ent in ents and max_ent in ents and len(ents) > 1:
439
+ #min_ent_num_ncs = min_ent[0]
440
+
441
+ #print(f'Removing: min_ent {min_ent} at index {ents.index(min_ent)}')
442
+ del ents[ents.index(min_ent)]
443
+ #del ents[ents.index(max_ent)]
444
+
445
+ if max_ent == min_ent:
446
+ raise ValueError(f'WARNING: Ents are the same\n{min_ent} == {max_ent}')
447
+ else:
448
+ merged_ents += [(max_ent, min_ent)]
449
+
450
+
451
+ #print(f'\nStep 2b results')
452
+ # results foor the end of step 2
453
+ for ID, ents in ent_data.items():
454
+ ent_dict = {ent_idx:[ent] for ent_idx, ent in enumerate(ents)}
455
+ #for ent_idx, ent in ent_dict.items():
456
+ # print(ent_idx, ent)
457
+
458
+ ### Update entanglement list with those that got merged
459
+ self.logger.info(f'\n Processing {ID}: Analyzing {len(ents)} representative entanglements for merging...')
460
+ self.logger.info(f' Before merge: {len(ents)} representatives')
461
+ merge_count = 0
462
+ while len(merged_ents) != 0:
463
+ pre_num_merged = len(merged_ents)
464
+ #print(f'# merged_ents: {pre_num_merged}')
465
+ for m_ent in merged_ents:
466
+ #print(f'\n{m_ent}')
467
+ for ent_idx, ent in ent_dict.copy().items():
468
+ #print(ent_idx, ent)
469
+ if m_ent[0] in ent:
470
+ #print(f'FOUND MATCH for kept ent {ent_idx}')
471
+ ent_dict[ent_idx] += [m_ent[1]]
472
+ merged_ents.remove(m_ent)
473
+ merge_count += 1
474
+
475
+ #print(f'# merged_ents: {len(merged_ents)}')
476
+
477
+ # QC to ensure you dont enter an infinite loop
478
+ if len(merged_ents) == pre_num_merged:
479
+ raise ValueError('Failed to find a match this cycle and entering infi loop')
480
+
481
+ self.logger.info(f' After merge: {merge_count} merges completed')
482
+ self.logger.info(f' Reformatting entanglement data...')
483
+ updated_ents = []
484
+ for ent_idx, ent in ent_dict.items():
485
+ #print(ent_idx, ent)
486
+ if len(ent) > 1:
487
+ num_loops = np.sum([e[0] for e in ent])
488
+ NCs = ';'.join([e[-1] for e in ent])
489
+ #print(ent, num_loops, NCs)
490
+ ent = (num_loops, *ent[0][1:-1], NCs)
491
+ updated_ents += [ent]
492
+ else:
493
+ updated_ents += ent
494
+
495
+ #print(f'Results after adding those that got merged to each representative entanglement')
496
+ Step2b_QC_counter = 0
497
+ for uent in updated_ents:
498
+ #print(uent)
499
+ Step2b_QC_counter += uent[0]
500
+
501
+ self.logger.info(f' Result: {len(updated_ents)} representative entanglements after merging (tracking {Step2b_QC_counter} total raw)')
502
+ self.logger.debug(updated_ents)
503
+ ## QC that the number of tracked entanglements after step 2a is still valid
504
+ if Step2b_QC_counter != num_raw_ents[ID]:
505
+ raise ValueError(f'The number of tracked entaglements after Step 2b {Step2b_QC_counter} != {num_raw_ents[ID]}')
506
+
507
+ ent_data[ID] = updated_ents
508
+
509
+ ### STEP 3 ################################################################################################################
510
+ # Step 3
511
+ self.logger.info(f'{"="*100}')
512
+ self.logger.info(f'STEP 3: Removing Duplicate Entanglements (Same Crossings, Different Loop Sizes)')
513
+ self.logger.info(f'{"="*100}')
514
+ for ID, processed_ents in ent_data.items():
515
+ self.logger.info(f' {ID}: Checking {len(processed_ents)} entanglements for duplicates...')
516
+
517
+ comb_processed_ents = itertools.combinations(processed_ents, 2)
518
+
519
+ keep_track_of_larger_proc_ent = []
520
+ keep_track_of_shorter_proc_ent = []
521
+ removal_count = 0
522
+
523
+ for each_processed_ent_pair in comb_processed_ents:
524
+ #print(f'\npair of ents: {each_processed_ent_pair}')
525
+
526
+ proc_ent1 = each_processed_ent_pair[0]
527
+ proc_ent2 = each_processed_ent_pair[1]
528
+
529
+ proc_ent1_ijr = proc_ent1[1:-1]
530
+ proc_ent2_ijr = proc_ent2[1:-1]
531
+ #print(proc_ent1_ijr, proc_ent2_ijr)
532
+
533
+ if proc_ent1_ijr not in keep_track_of_larger_proc_ent and proc_ent2_ijr not in keep_track_of_larger_proc_ent:
534
+
535
+ # without chiralites
536
+ proc_cr1 = np.asarray([int(ent_cr_1[1:]) for ent_cr_1 in list(proc_ent1[3:-1])])
537
+ proc_cr2 = np.asarray([int(ent_cr_2[1:]) for ent_cr_2 in list(proc_ent2[3:-1])])
538
+
539
+ if len(proc_ent1[3:-1]) == len(proc_ent2[3:-1]):
540
+
541
+ chirality1 = [chir1[0] for chir1 in proc_ent1[3:-1]]
542
+ chirality2 = [chir2[0] for chir2 in proc_ent2[3:-1]]
543
+
544
+ if chirality1 == chirality2 and self.check_step_ij_kl_range(proc_ent1, proc_ent2) and np.all(np.abs(proc_cr1 - proc_cr2) <= 20):
545
+ #print(proc_ent1, proc_ent2)
546
+
547
+ loop1_length = proc_ent1[2] - proc_ent1[1]
548
+ loop2_length = proc_ent2[2] - proc_ent2[1]
549
+ #print(loop1_length, loop2_length)
550
+
551
+ check = [loop1_length, loop2_length]
552
+
553
+ maximum_loop_length = max(loop1_length, loop2_length)
554
+ minimum_loop_length = min(loop1_length, loop2_length)
555
+
556
+ if maximum_loop_length == minimum_loop_length:
557
+ longer_loop_ent = proc_ent1
558
+ shorter_loop_ent = proc_ent2
559
+ else:
560
+ longer_loop_ent = each_processed_ent_pair[check.index(maximum_loop_length)]
561
+ shorter_loop_ent = each_processed_ent_pair[check.index(minimum_loop_length)]
562
+ longer_loop_ent_ijr = longer_loop_ent[1:-1]
563
+ shorter_loop_ent_ijr = shorter_loop_ent[1:-1]
564
+
565
+ if len(processed_ents) > 1:
566
+
567
+ for long_proc_ent_index, long_proc_ent in enumerate(processed_ents):
568
+ long_proc_ent_ijr = long_proc_ent[1:-1]
569
+ #print(long_proc_ent_index, long_proc_ent, long_proc_ent_ijr)
570
+ if long_proc_ent_ijr == longer_loop_ent_ijr:
571
+ break
572
+ del processed_ents[long_proc_ent_index]
573
+ # remove the one with larger loop
574
+
575
+ # find the shorter loop and remove it
576
+ for short_proc_ent_index, short_proc_ent in enumerate(processed_ents):
577
+ short_proc_ent_ijr = short_proc_ent[1:-1]
578
+ if short_proc_ent_ijr == shorter_loop_ent_ijr:
579
+ break
580
+ del processed_ents[short_proc_ent_index]
581
+
582
+ updated_ent = (short_proc_ent[0] + long_proc_ent[0], *short_proc_ent[1:-1], short_proc_ent[-1] + ';' + long_proc_ent[-1])
583
+ processed_ents += [updated_ent]
584
+
585
+ keep_track_of_larger_proc_ent.append(longer_loop_ent_ijr)
586
+ keep_track_of_shorter_proc_ent.append(shorter_loop_ent_ijr)
587
+
588
+ # STEP 3 FINAL SUMMARY
589
+ self.logger.info(f'\nSTEP 3 RESULTS:')
590
+ for ID, ents in ent_data.items():
591
+ Step3_QC_counter = 0
592
+ for ent in ents:
593
+ #print(ent)
594
+ Step3_QC_counter += ent[0]
595
+
596
+ self.logger.info(f' {ID}: {len(ents)} representative entanglements remaining (tracking {Step3_QC_counter} raw)')
597
+ # QC to ensure number of raw ents was preserved after step 3
598
+ if Step3_QC_counter != num_raw_ents[ID]:
599
+ raise ValueError(f'The number of tracked entaglements after Step 3 {Step3_QC_counter} != {num_raw_ents[ID]}')
600
+
601
+ ### STEP 4 SPATIAL CLUSTERING ################################################################################################################
602
+ # Step 4 prep
603
+ self.logger.info(f'{"="*100}')
604
+ self.logger.info(f'STEP 4 PREP: Grouping Entanglements by Number and Chirality of Crossings')
605
+ self.logger.info(f'{"="*100}')
606
+ for ID, new_ents in ent_data.items():
607
+
608
+ for ent in new_ents:
609
+
610
+ number_of_crossings = len(ent[3:-1])
611
+
612
+ chiralites = [each_cr[0] for each_cr in ent[3:-1]]
613
+
614
+ ID_num_chirality_key = f"{ID}_{number_of_crossings}_{chiralites}"
615
+
616
+ full_entanglement_data[ID_num_chirality_key].append(ent)
617
+
618
+ self.logger.info(f'\nGrouping Summary:')
619
+ for group_key in sorted(full_entanglement_data.keys()):
620
+ ents = full_entanglement_data[group_key]
621
+ self.logger.info(f' {group_key}: {len(ents)} entanglements')
622
+
623
+ reset_counter = []
624
+
625
+ # Step 4
626
+ self.logger.info(f'{"="*100}')
627
+ self.logger.info(f'STEP 4: Primary Structure Clustering Within Each Group')
628
+ self.logger.info(f'{"="*100}')
629
+ for ID_num_chiral in full_entanglement_data.keys():
630
+ #print(ID_num_chiral)
631
+ #ID = ID_num_chiral.split("_")[0]
632
+
633
+ if ID not in reset_counter:
634
+
635
+ reset_counter.append(ID)
636
+
637
+ split_cluster_counter = 0
638
+
639
+ length_key = defaultdict(list)
640
+ loop_dist = defaultdict(list)
641
+ dups = []
642
+ clusters = {}
643
+ cluster_count = 0
644
+
645
+ pairwise_entanglements = list(itertools.combinations(full_entanglement_data[ID_num_chiral], 2))
646
+
647
+ self.logger.info(f'\n Group: {ID_num_chiral}')
648
+ self.logger.info(f' Total entanglements: {len(full_entanglement_data[ID_num_chiral])}')
649
+ self.logger.info(f' Pairwise comparisons: {len(pairwise_entanglements)}')
650
+
651
+ if pairwise_entanglements:
652
+
653
+ for i, pairwise_ent in enumerate(pairwise_entanglements):
654
+
655
+ dist = self.loop_distance(pairwise_ent[0], pairwise_ent[1])
656
+
657
+ if dist <= self.cut_off and pairwise_ent[0] not in dups and pairwise_ent[1] not in dups:
658
+ # 1. pair must be <= self.cut_off
659
+ # 2. the neighbor cannot be the next key and it cannot be captured by another key
660
+
661
+ loop_dist[pairwise_ent[0]].append(pairwise_ent[1])
662
+ dups.append(pairwise_ent[1])
663
+
664
+ key_list = list(loop_dist.keys())
665
+
666
+ for key in key_list:
667
+
668
+ length_key[len(loop_dist[key])].append(key)
669
+
670
+ # create clusters
671
+
672
+ while len(length_key.values()) > 0:
673
+
674
+ max_neighbor = max(length_key.keys())
675
+
676
+ selected_ent = random.choice(length_key[max_neighbor])
677
+
678
+ cluster = copy.deepcopy(loop_dist[selected_ent])
679
+ cluster.append(selected_ent)
680
+
681
+ clusters[cluster_count] = cluster
682
+ cluster_count += 1
683
+
684
+ length_key[max_neighbor].remove(selected_ent)
685
+
686
+ if len(length_key[max_neighbor]) == 0:
687
+ length_key.pop(max_neighbor)
688
+
689
+ # create single clusters
690
+ if clusters:
691
+ clusters_ijr_values = list(itertools.chain.from_iterable(list(clusters.values())))
692
+ else:
693
+ clusters_ijr_values = []
694
+
695
+ full_ent_values = np.asarray(full_entanglement_data[ID_num_chiral], dtype=object)
696
+
697
+ difference_ent = np.zeros(len(full_ent_values), dtype=bool)
698
+
699
+ for k, ijr in enumerate(full_ent_values):
700
+
701
+ if tuple(ijr) in clusters_ijr_values:
702
+ difference_ent[k] = True
703
+ else:
704
+ difference_ent[k] = False
705
+
706
+ i = np.unique(np.where(difference_ent == False)[0])
707
+
708
+ next_cluster_count = cluster_count
709
+
710
+ for single_cluster in full_ent_values[i]:
711
+
712
+ single_cluster_list = []
713
+ single_cluster_list.append(tuple(single_cluster))
714
+
715
+ clusters[next_cluster_count] = single_cluster_list
716
+
717
+ next_cluster_count += 1
718
+
719
+ # pick representative entanglement per cluster
720
+ self.logger.info(f' Primary structure clusters formed: {len(clusters)}')
721
+ for counter, ijr_values in clusters.items():
722
+ #print(f'\nCluster {counter} {ijr_values}')
723
+
724
+ # clusters contain many entanglements
725
+ if len(ijr_values) > 1:
726
+
727
+ ijr = np.asarray(ijr_values)
728
+ #print(f'cluster ijr:\n{ijr}')
729
+
730
+ cr_values = np.asarray([[int(r_value[0][1:])] for r_value in ijr[:, 3:-1]])
731
+ #print(f'cr_values: {cr_values}')
732
+
733
+ median_cr = compute_geometric_median(cr_values).median
734
+ #print(f'median_cr: {median_cr}')
735
+
736
+ distances = cdist(cr_values, [median_cr])
737
+
738
+ minimum_distances_i = np.where(distances == min(distances))[0]
739
+ #print(f'minimum_distances_i: {minimum_distances_i}')
740
+
741
+ possible_cand = ijr[minimum_distances_i]
742
+ #print(f'possible_cand:\n{possible_cand}')
743
+
744
+ loop_lengths = np.abs(possible_cand[:, 1].astype(int) - possible_cand[:, 2].astype(int))
745
+ #print(f'loop_lengths: {loop_lengths}')
746
+
747
+ smallest_loop_length = min(loop_lengths)
748
+ #print(f'smallest_loop_length: {smallest_loop_length}')
749
+
750
+ num_nc_in_cluster = np.sum([int(n[0]) for n in ijr_values])
751
+ #print(f'num_nc_in_cluster: {num_nc_in_cluster}')
752
+
753
+ loops = ';'.join([n[-1] for n in ijr_values])
754
+ #print(f'loops: {loops}')
755
+
756
+ #rep_entanglement = possible_cand[random.choice(np.where(smallest_loop_length == loop_lengths)[0])]
757
+ rep_entanglement = possible_cand[random.choice(np.where(smallest_loop_length == loop_lengths))[0]]
758
+ rep_entanglement = [str(num_nc_in_cluster), *rep_entanglement[1:-1], loops]
759
+ #rep_ID_ent[f"{ID}_{split_cluster_counter}"].append(rep_entanglement)
760
+ rep_ID_ent[(ID, split_cluster_counter)].append(rep_entanglement)
761
+
762
+ # clusters with a single entanglement
763
+ else:
764
+ #rep_ID_ent[f"{ID}_{split_cluster_counter}"].append(ijr_values[0])
765
+ rep_ID_ent[(ID, split_cluster_counter)].append(ijr_values[0])
766
+ if counter == list(clusters.keys())[-1]: # Print only for last cluster to avoid clutter
767
+ num_single = sum(1 for c_vals in clusters.values() if len(c_vals) == 1)
768
+ self.logger.info(f' Single-entanglement clusters: {num_single}')
769
+
770
+ split_cluster_counter += 1
771
+
772
+ ## QC Step 4 results
773
+ self.logger.info(f'\n{"="*100}')
774
+ self.logger.info(f'STEP 4 FINAL RESULTS: Primary Structure Clustering Summary')
775
+ self.logger.info(f'{"="*100}')
776
+ num_raw_ents_FINAL = {}
777
+ for ID_counter, ijrs in rep_ID_ent.items():
778
+ #print(ID_counter, ijrs)
779
+ ID, counter = ID_counter
780
+ #print(ID_counter, ID, counter, ijrs)
781
+
782
+ if ID not in num_raw_ents_FINAL:
783
+ num_raw_ents_FINAL[ID] = 0
784
+
785
+ for ijr in ijrs:
786
+ num_nc = int(ijr[0])
787
+ num_raw_ents_FINAL[ID] += num_nc
788
+
789
+ ## check the final tracking of raw ents
790
+ for ID, count in num_raw_ents.items():
791
+ final_count = num_raw_ents_FINAL[ID]
792
+ num_clusters = len([ijrs for (c_id, _), ijrs in rep_ID_ent.items() if c_id == ID])
793
+ self.logger.info(f'{ID}: {count} raw → {final_count} raw in {num_clusters} final clusters')
794
+ if count != final_count:
795
+ raise ValueError(f'The FINAL # of raw ents {final_count} != the starting {count} for ID {ID}')
796
+
797
+ ### STEP 5 OUTPUT FILE ################################################################################################################
798
+ # Step 5
799
+ self.logger.info(f'{"="*100}')
800
+ self.logger.info(f'STEP 5: Writing Output File')
801
+ self.logger.info(f'{"="*100}')
802
+
803
+ ## set up the outdir for this calculation
804
+ #outdir = f"{os.getcwd()}/{outdir}"
805
+ if not os.path.isdir(outdir):
806
+ os.mkdir(f"{outdir}")
807
+ self.logger.info(f"Creating directory: {outdir}")
808
+
809
+ outfilepath = os.path.join(f'{outdir}', f'{outfile}')
810
+
811
+ with open(outfilepath, "w") as f:
812
+
813
+ f.write(f'ID|chain|i|j|crossingsN|crossingsC|gn|gc|GLNn|GLNc|TLNn|TLNc|num_contacts|contacts|CCBond\n')
814
+ for ID_counter, ijrs in rep_ID_ent.items():
815
+
816
+ ID, counter = ID_counter
817
+ chain = chain_info.get(ID, 'A') # Get chain for this ID, default to 'A'
818
+
819
+ for ijr in ijrs:
820
+
821
+ new_ijr = (int(ijr[1]), int(ijr[2]), *list(ijr[3:-1]))
822
+
823
+ num_nc = int(ijr[0])
824
+
825
+ gn, gc, GLNn, GLNc, TLNn, TLNc, crossingsN_stored, crossingsC_stored = entanglement_partial_g_data[new_ijr]
826
+ gn = float(gn)
827
+ gc = float(gc)
828
+ GLNn = int(GLNn)
829
+ GLNc = int(GLNc)
830
+ # Handle NaN/empty TLN values: convert to int only if not NaN
831
+ TLNn = np.nan if pd.isna(TLNn) else int(TLNn)
832
+ TLNc = np.nan if pd.isna(TLNc) else int(TLNc)
833
+
834
+ # Separate crossings into N and C terminal
835
+ all_crossings = list(ijr[3:-1])
836
+ crossingsN = []
837
+ crossingsC = []
838
+ i_val = int(ijr[1])
839
+ j_val = int(ijr[2])
840
+ for cross in all_crossings:
841
+ cross_resid = int(cross[1:])
842
+ if cross_resid < i_val:
843
+ crossingsN.append(cross)
844
+ elif cross_resid > j_val:
845
+ crossingsC.append(cross)
846
+
847
+ crossingsN_str = ','.join(crossingsN) if crossingsN else ''
848
+ crossingsC_str = ','.join(crossingsC) if crossingsC else ''
849
+
850
+ ## check for disulfide bonds
851
+ CCBond_flag = False
852
+ for CCBond in CCBonds:
853
+ check1 = f'{CCBond[0]}-{CCBond[1]}'
854
+ check2 = f'{CCBond[1]}-{CCBond[0]}'
855
+ if check1 in ijr[-1] or check2 in ijr[-1]:
856
+ CCBond_flag = True
857
+
858
+ line = f"{ID}|{chain}|{int(ijr[1])}|{int(ijr[2])}|{crossingsN_str}|{crossingsC_str}|{gn:.5f}|{gc:.5f}|{GLNn}|{GLNc}|{TLNn}|{TLNc}|{num_nc}|{ijr[-1]}|{CCBond_flag}"
859
+ #print(line)
860
+ f.write(f"{line}\n")
861
+ self.logger.info(f'SAVED: {outfilepath}')
862
+ outdf = pd.read_csv(outfilepath, sep='|')
863
+
864
+ # FINAL CLUSTERING SUMMARY
865
+ self.logger.info(f'\n{"="*100}')
866
+ self.logger.info(f'CLUSTERING COMPLETE: Final Summary')
867
+ self.logger.info(f'{"="*100}')
868
+ self.logger.info(f'Total raw entanglements processed: {sum(num_raw_ents.values())}')
869
+ self.logger.info(f'Total final representative entanglements: {len(outdf)}')
870
+ # print(f'Compression ratio: {sum(num_raw_ents.values())/len(outdf):.2f}x (raw → final)')
871
+ self.logger.info(f'Clustering by organism: {self.organism}')
872
+ self.logger.info(f'Spatial distance cutoff: {self.cut_off}')
873
+ self.logger.info(f'Output file: {outfilepath}')
874
+ self.logger.info(f'{"="*100}\n')
875
+
876
+ return {'outfile':outfilepath, 'ent_result':outdf}
877
+ ##########################################################################################################################################################
878
+ ##########################################################################################################################################################
879
+ ##########################################################################################################################################################
880
+
881
+
882
+ ##########################################################################################################################################################
883
+ ##########################################################################################################################################################
884
+ class ClusterNonNativeEntanglements:
885
+ """
886
+ Class to calculate Non-native entanglements given either a file path to an entanglement file or an entanglement object
887
+ """
888
+
889
+ ##########################################################################################################################################################
890
def __init__(self, trajnum2pklfile_path:str, traj_dir_prefix:str='./', outdir:str='./ClusterNonNativeEntanglements/', log_level:int=logging.INFO, logdir:str=None, nproc:int=1) -> None:
    """
    Constructor for the ClusterNonNativeEntanglements class.

    Sets the clustering defaults, applies plotting defaults to the global
    ``matplotlib.rcParams``, creates the logger, loads the trajectory
    manifest and makes sure the output directory exists.

    Parameters
    ----------
    trajnum2pklfile_path : str
        Path to a CSV manifest read with ``pandas.read_csv``. It must contain
        a "pklfile" column, and every path listed there must be an existing file.
    traj_dir_prefix : str
        Stored as ``self.traj_dir_prefix`` for later use by other methods.
    outdir : str
        Output directory; created (including missing parents) when absent.
    log_level : int
        Logging level forwarded to ``setup_logger``.
    logdir : str, optional
        Directory forwarded to ``setup_logger``; ``outdir`` is used when None.
    nproc : int
        Worker count; coerced with ``int()`` and clamped to at least 1.

    Raises
    ------
    SystemExit
        With exit status 1 when the manifest has no "pklfile" column or when
        any listed pkl file is missing.
    """
    self.classify_key = ['topoly_linking_number']
    self.cluster_method = ['average', 'average', 'average']
    # cluster_dist_cutoff = [20, 1.0, 0.6] # Allow contamination
    self.cluster_dist_cutoff = [20, 1.0, 0.1] # No contamination
    self.memory_cutoff = 6.4e10 # 64 Gb
    self.max_plot_samples = 1000

    # NOTE: these assignments change process-wide matplotlib defaults.
    # matplotlib.rcParams['mathtext.fontset'] = 'stix'
    # matplotlib.rcParams['font.sans-serif'] = ['Arial']
    matplotlib.rcParams['axes.labelsize'] = 'small'
    matplotlib.rcParams['axes.linewidth'] = 1
    matplotlib.rcParams['lines.markersize'] = 4
    matplotlib.rcParams['xtick.major.width'] = 1
    matplotlib.rcParams['ytick.major.width'] = 1
    matplotlib.rcParams['xtick.labelsize'] = 'x-small'
    matplotlib.rcParams['ytick.labelsize'] = 'x-small'
    matplotlib.rcParams['legend.fontsize'] = 'x-small'
    matplotlib.rcParams['figure.dpi'] = 600

    self.nproc = max(1, int(nproc))
    self.logger = setup_logger('ClusterNonNativeEntanglements', outdir=logdir if logdir is not None else outdir, log_level=log_level)

    self.traj_dir_prefix = traj_dir_prefix

    ## Load the dataframe that maps trajectory numbers to pkl file paths
    self.trajnum2pklfile_path = trajnum2pklfile_path
    self.trajnum2pklfile = pd.read_csv(self.trajnum2pklfile_path)

    ## Extract pkl file paths from the manifest (source of truth)
    if 'pklfile' not in self.trajnum2pklfile.columns:
        self.logger.error('Error: trajnum2pklfile CSV must contain a "pklfile" column')
        # Non-zero status: a bare sys.exit() would report success to the caller/shell.
        sys.exit(1)

    self.ent_data_file_list = self.trajnum2pklfile['pklfile'].tolist()

    # Verify all pkl files exist
    missing_files = [f for f in self.ent_data_file_list if not os.path.isfile(f)]
    if missing_files:
        self.logger.error(f'Error: {len(missing_files)} pkl files not found:')
        for f in missing_files:
            self.logger.error(f'  {f}')
        sys.exit(1)

    self.logger.info(f'FOUND {len(self.ent_data_file_list)} .pkl files to cluster from manifest')

    ## Set up the outdir for this calculation
    self.outdir = outdir
    if not os.path.isdir(self.outdir):
        # makedirs handles nested paths and tolerates a concurrent creation.
        os.makedirs(self.outdir, exist_ok=True)
        self.logger.info(f"Creating directory: {self.outdir}")
947
+ ##########################################################################################################################################################
948
+
949
+ ##########################################################################################################################################################
950
def save_pickle(self, filename, mode, data, protocol=4):
    """Pickle ``data`` into ``filename``.

    Parameters
    ----------
    filename : str
        Path of the file to open.
    mode : str
        Mode string handed to ``open`` (e.g. 'wb' to overwrite, 'ab' to append).
    data : object
        Object to serialize.
    protocol : int
        Pickle protocol forwarded to ``pickle.dump``.
    """
    with open(filename, mode) as out_handle:
        pickle.dump(data, out_handle, protocol=protocol)
953
+ ##########################################################################################################################################################
954
+
955
+ ##########################################################################################################################################################
956
def load_pickle(self, filename, start_frame=None, end_frame=None):
    """Load pickle file and optionally filter frames.

    The file may hold several pickled dictionaries written back to back;
    they are read one by one until end-of-file and merged into one dict.

    Parameters
    ----------
    filename : str
        Path to pickle file
    start_frame : int, optional
        Minimum frame index to keep (inclusive)
    end_frame : int, optional
        Maximum frame index to keep (inclusive)

    Returns
    -------
    dict
        Merged dictionary. When a frame bound is given, only the "ref" key
        and keys inside the requested range are kept.

    Note: each unfiltered chunk is dropped as soon as it has been filtered and
    a garbage collection is forced before returning, to limit memory build-up
    when several files are loaded in parallel.
    """
    import gc

    # NOTE(security): pickle.load can execute arbitrary code; only open files
    # that come from a trusted source.
    filtering = start_frame is not None or end_frame is not None
    data_dict = {}
    self.logger.debug(f'Loading pickle file: {filename}')
    with open(filename, 'rb') as fr:
        while True:
            try:
                chunk = pickle.load(fr)
            except EOFError:
                # End of the stream of concatenated pickles.
                break
            self.logger.debug(f"Loaded chunk with size: {len(chunk)}")
            if filtering:
                self.logger.debug(f"Filtering chunk for frames between {start_frame} and {end_frame}")
                # Rebinding `chunk` releases the unfiltered dict immediately.
                chunk = {
                    k: v for k, v in chunk.items()
                    if k == "ref"
                    or ((start_frame is None or k >= start_frame)
                        and (end_frame is None or k <= end_frame))
                }
                self.logger.debug(f"Loaded filtered chunk with size: {len(chunk)}")
            data_dict.update(chunk)
            del chunk  # Free memory after update

    # Count frames explicitly: subtracting 1 is wrong when no 'ref' key was loaded.
    n_frames = sum(1 for k in data_dict if k != 'ref')
    self.logger.debug(f'Total frames loaded: {n_frames}')

    # Force garbage collection to prevent memory accumulation in parallel loads
    gc.collect()
    return data_dict
1010
+ ##########################################################################################################################################################
1011
+
1012
+ ##########################################################################################################################################################
1013
def extract_traj_number(self, f):
    """Look up the trajectory number recorded for a pkl file.

    Parameters
    ----------
    f : str
        Value to match against the "pklfile" column of ``self.trajnum2pklfile``.

    Returns
    -------
    The "trajnum" value of the first matching manifest row.

    Raises
    ------
    ValueError
        If no manifest row has ``pklfile == f`` (the error is also logged).
    """
    manifest = self.trajnum2pklfile
    matching_rows = manifest[manifest['pklfile'] == f]
    if matching_rows.empty:
        self.logger.error(f'Error: {f} not found in {self.trajnum2pklfile_path}')
        raise ValueError(f'Error: {f} not found in {self.trajnum2pklfile_path}')
    return matching_rows['trajnum'].values[0]
1020
+ ##########################################################################################################################################################
1021
+
1022
+ ##########################################################################################################################################################
1023
def pdist_loop_overlap(self, data_array_1, data_array_2):
    """Pairwise loop-overlap distance between the rows of two 2-D arrays.

    For every pair (row a of ``data_array_1``, row b of ``data_array_2``) the
    two rows are concatenated and the distance is

        (max - min + 1) / ((a[1] - a[0]) + (b[1] - b[0]))

    Pairs for which ``(b[0] - a[0]) * (a[1] - b[1]) >= 0`` (one interval
    contains the other) are assigned the fixed value 0.5.

    NOTE(review): rows are presumably ``[loop_start, loop_end]`` residue
    indices with two columns each -- confirm against the caller.

    Returns
    -------
    numpy.ndarray of shape (len(data_array_1), len(data_array_2))
    """
    n_first = data_array_1.shape[0]
    n_second = data_array_2.shape[0]
    # Expand both inputs to (n_first, n_second, width) and join them per pair.
    expanded_first = np.repeat(data_array_1[:, np.newaxis, :], n_second, axis=1)
    expanded_second = np.repeat(data_array_2[np.newaxis, :, :], n_first, axis=0)
    paired = np.concatenate((expanded_first, expanded_second), axis=-1)
    del expanded_first, expanded_second

    start_1 = paired[:, :, 0]
    end_1 = paired[:, :, 1]
    start_2 = paired[:, :, 2]
    end_2 = paired[:, :, 3]

    covered_span = np.max(paired, axis=-1) - np.min(paired, axis=-1) + 1
    overlap_dist = covered_span / (end_1 - start_1 + end_2 - start_2)
    # Make distance between inclusive loops to be minimum
    is_inclusive = (start_2 - start_1) * (end_1 - end_2) >= 0
    overlap_dist[is_inclusive] = 0.5
    return overlap_dist
1032
+ ##########################################################################################################################################################
1033
+
1034
+ ##########################################################################################################################################################
1035
def pdist_thread_deviation(self, data_array_1, data_array_2):
    """Pairwise deviation between the rows of two 2-D arrays.

    For every pair (row a of ``data_array_1``, row b of ``data_array_2``) the
    element-wise absolute difference of the first two columns is taken.
    Wherever the two compared values have opposite signs the difference is
    replaced by the constant 10. The result for the pair is the larger of
    the two column values.

    NOTE(review): a negative entry presumably marks "no crossing" (the caller
    uses -1 as placeholder) -- confirm.

    Returns
    -------
    numpy.ndarray of shape (len(data_array_1), len(data_array_2))
    """
    n_first = data_array_1.shape[0]
    n_second = data_array_2.shape[0]
    expanded_first = np.repeat(data_array_1[:, np.newaxis, :], n_second, axis=1)
    expanded_second = np.repeat(data_array_2[np.newaxis, :, :], n_first, axis=0)
    paired = np.concatenate((expanded_first, expanded_second), axis=-1)
    del expanded_first, expanded_second

    left_values = paired[:, :, 0:2]
    right_values = paired[:, :, 2:4]
    deviation = np.abs(right_values - left_values)
    # Make distance between no crossing and crossings to be small
    deviation[right_values * left_values < 0] = 10
    del paired
    return np.max(deviation, axis=-1)
1045
+ ##########################################################################################################################################################
1046
+
1047
+ ##########################################################################################################################################################
1048
def pdist_cross_contamination(self, data_array_1, data_array_2):
    """Pairwise "cross contamination" distance between rows of two 2-D arrays.

    Each row looks like ``[nc_1, nc_2, cross_N1, cross_N2, ..., cross_C1, cross_C2, ...]``.
    For a pair (row a, row b) every crossing value ``c`` of one row is scored
    against the ``[nc_1, nc_2]`` interval of the other row as

        min(nc_2 - c, c - nc_1) / (nc_2 - nc_1)

    Scores ``<= 0`` or ``>= 1`` are reset to 0. The distance of the pair is
    the maximum score over both directions (crossings of b against the loop
    of a, and crossings of a against the loop of b).

    Returns
    -------
    numpy.ndarray of shape (len(data_array_1), len(data_array_2))
    """
    def _relative_depth(loops, crossings):
        # loops: (..., 2) interval bounds; crossings: (..., k) values to score.
        loop_start = loops[..., 0:1]
        loop_end = loops[..., 1:2]
        depth = np.minimum(loop_end - crossings, crossings - loop_start) / (loop_end - loop_start)
        depth[depth <= 0] = 0
        depth[depth >= 1] = 0
        return depth

    # Broadcasting (n1, 1, w1) against (1, n2, w2) enumerates every row pair.
    rows_1 = data_array_1[:, np.newaxis, :]
    rows_2 = data_array_2[np.newaxis, :, :]

    # Distance for cross_2 contaminate loop_1
    depth_2_in_1 = _relative_depth(rows_1[..., :2], rows_2[..., 2:])
    # Distance for cross_1 contaminate loop_2
    depth_1_in_2 = _relative_depth(rows_2[..., :2], rows_1[..., 2:])

    return np.max(np.concatenate((depth_2_in_1, depth_1_in_2), axis=-1), axis=-1)
1084
+ ##########################################################################################################################################################
1085
+
1086
+ ##########################################################################################################################################################
1087
def agglomerative_clustering(self, dist, cluster_method, cluster_dist_cutoff, num_perm):
    """Hierarchical clustering of a square distance matrix.

    When all distances are zero a single linkage is built directly.
    Otherwise the points are randomly permuted ``max(1, num_perm)`` times;
    each permutation is linked with ``cluster_method`` and scored by the
    normalized squared difference between the input distances and the
    cophenetic distances of the tree. The lowest-scoring tree is kept.

    Parameters
    ----------
    dist : numpy.ndarray
        Square pairwise distance matrix.
    cluster_method : str
        ``method`` argument for ``linkage``.
    cluster_dist_cutoff : float
        Threshold for ``fcluster(..., criterion='distance')``.
    num_perm : int
        Number of random permutations to try.

    Returns
    -------
    numpy.ndarray
        Flat cluster id of every point, in the original row order of ``dist``.
    """
    n_points = dist.shape[0]
    lowest_score = np.inf
    best_tree = None
    best_order = None
    condensed = squareform(dist, checks=False)
    if np.sum(condensed**2) == 0:
        best_tree = linkage(condensed, method=cluster_method)
        best_order = np.arange(n_points)
    else:
        # permuCLUSTER
        for _ in range(np.max([1, num_perm])):
            order = np.random.permutation(np.arange(n_points))
            shuffled = squareform(dist[order,:][:,order], checks=False)
            tree = linkage(shuffled, method=cluster_method)
            cophenetic = cophenet(tree)
            score = np.sum((shuffled - cophenetic)**2)/np.sum(shuffled**2)
            if score < lowest_score:
                lowest_score, best_tree, best_order = score, tree, order
    labels = fcluster(best_tree, cluster_dist_cutoff, criterion='distance')
    # Invert the permutation so labels line up with the rows of `dist`.
    inverse_order = np.zeros(len(best_order), dtype=int)
    inverse_order[best_order] = np.arange(len(best_order))
    return labels[inverse_order]
1113
+ ##########################################################################################################################################################
1114
+
1115
+ ##########################################################################################################################################################
1116
def do_clustering(self, map_list, chg_ent_fingerprint_list, key, pdist_fun, cluster_method, cluster_dist_cutoff, num_perm=100):
    """Cluster the entries of ``map_list`` on one feature of their fingerprints.

    Parameters
    ----------
    map_list : list
        Index entries. For an entry ``m`` the fingerprint is looked up as
        ``chg_ent_fingerprint_list[m[0]][m[1]][tuple(m[2:4])]``.
    chg_ent_fingerprint_list : sequence
        Nested container of fingerprint dicts (see ``map_list``).
    key : str
        Feature selector. Must contain one of
        * 'crossing_resid' -- the trailing ``_<n>`` is parsed as the index
          into 'ref_crossing_resid' / 'crossing_resid'; each is reduced to
          its median, or -1 when empty,
        * 'native_contact' -- the 'native_contact' value is used as is,
        * 'cross_contamination' -- 'native_contact' followed by all
          'crossing_resid' lists padded with -1 to a common length.
    pdist_fun : callable
        ``pdist_fun(A, B)`` returning the pairwise distance matrix between
        the rows of two 2-D arrays.
    cluster_method : str
        Linkage method; for 'cross_contamination' also selects how distances
        between groups are aggregated ('single' -> min, 'complete' -> max,
        anything else -> mean).
    cluster_dist_cutoff : float
        Distance threshold for splitting / flat clustering.
    num_perm : int
        Forwarded to ``agglomerative_clustering``.

    Returns
    -------
    list
        One list of ``map_list`` entries per cluster. When every entry has
        identical feature data, ``[map_list]`` is returned unchanged.

    Raises
    ------
    SystemExit
        With status 1 when ``key`` matches none of the known selectors.
    """
    data = []
    # Get max number of crossings
    max_n_cross = 1
    if 'cross_contamination' in key:
        for map_idx in map_list:
            fingerprint = chg_ent_fingerprint_list[map_idx[0]][map_idx[1]][tuple(map_idx[2:4])]
            for c in fingerprint['crossing_resid']:
                if len(c) > max_n_cross:
                    max_n_cross = len(c)
    # Prepare clustering data for distance calculation
    for map_idx in map_list:
        fingerprint = chg_ent_fingerprint_list[map_idx[0]][map_idx[1]][tuple(map_idx[2:4])]
        if 'crossing_resid' in key:
            ter_idx = int(key.split('_')[-1])
            mc = []
            cr = fingerprint['crossing_resid']
            ref_cr = fingerprint['ref_crossing_resid']
            for c in [ref_cr[ter_idx], cr[ter_idx]]:
                if len(c) == 0:
                    mc.append(-1)
                else:
                    mc.append(np.median(c))
            data.append(mc)
        elif 'native_contact' in key:
            nc = fingerprint['native_contact']
            data.append(nc)
        elif 'cross_contamination' in key:
            nc = fingerprint['native_contact']
            mc = []
            cr = fingerprint['crossing_resid']
            for c in cr:
                for cii in range(max_n_cross):
                    if cii >= len(c):
                        mc.append(-1)
                    else:
                        mc.append(c[cii])
            data.append(nc + mc)
        else:
            self.logger.error('Error: Unknown key specified for do_clustering(), %s'%(key))
            # Non-zero status so the failure is visible to the calling process.
            sys.exit(1)
    data = np.array(data)
    # Reduce data size (saving memory usage) by combining duplicated data points
    reduced_data = np.unique(data, axis=0)
    reduced_data_map = np.array([np.all(data == d, axis=1).nonzero()[0].tolist() for d in reduced_data], dtype=object)

    # If all data are the same, group them into a single cluster and return
    if len(reduced_data) == 1:
        cluster_data = [map_list]
        return cluster_data

    # Chunk data to reduce memory usage if the expanded matrix occupy >= memory_cutoff
    if reduced_data.nbytes ** 2 >= self.memory_cutoff:
        n_chunk = int(np.ceil(reduced_data.nbytes / np.sqrt(self.memory_cutoff/2)))
        len_chunk = int(np.ceil(len(reduced_data) / n_chunk))
        dist = np.zeros((len(reduced_data), len(reduced_data)))
        for i in range(n_chunk):
            i_1 = i*len_chunk
            i_2 = np.min([(i+1)*len_chunk,len(reduced_data)])
            for j in range(n_chunk):
                j_1 = j*len_chunk
                j_2 = np.min([(j+1)*len_chunk,len(reduced_data)])
                dist[i_1:i_2, j_1:j_2] = pdist_fun(reduced_data[i_1:i_2], reduced_data[j_1:j_2])
    else:
        dist = pdist_fun(reduced_data, reduced_data)

    if 'cross_contamination' in key:
        # Do divisive clustering: repeatedly move every point that is too far
        # from any later point of the current group into a new group.
        cluster_idx_mapping = [list(np.arange(dist.shape[0]))]
        while True:
            cluster_1 = cluster_idx_mapping[-1]
            cluster_2 = []
            cluster_0 = copy.deepcopy(cluster_1)
            for i in range(len(cluster_0)-1):
                rm_idx_list = np.where(dist[cluster_0[i],cluster_0[i+1:]] >= cluster_dist_cutoff)[0]
                if len(rm_idx_list) > 0:
                    cluster_1.remove(cluster_0[i])
                    cluster_2.append(cluster_0[i])
            if len(cluster_2) > 0:
                cluster_idx_mapping.append(cluster_2)
            else:
                break
        # Do agglomerative clustering
        if cluster_method == 'single':
            dist_fun = np.min
        elif cluster_method == 'complete':
            dist_fun = np.max
        elif cluster_method == 'average':
            dist_fun = np.mean
        else:
            dist_fun = np.mean
        if len(cluster_idx_mapping) > 1:
            dist_0 = np.zeros((len(cluster_idx_mapping),len(cluster_idx_mapping)))
            for i in range(len(dist_0)-1):
                for j in range(i+1, len(dist_0)):
                    # Fill BOTH triangles: agglomerative_clustering permutes the
                    # rows/columns before squareform(), which only reads the
                    # upper triangle; a one-sided matrix would lose distances.
                    dist_0[i,j] = dist_0[j,i] = dist_fun(dist[cluster_idx_mapping[i],:][:,cluster_idx_mapping[j]])
            cluster_id_list_0 = self.agglomerative_clustering(dist_0, cluster_method, cluster_dist_cutoff, num_perm)
        else:
            cluster_id_list_0 = np.array([1], dtype=int)
        cluster_id_list = np.zeros(dist.shape[0], dtype=int)
        for cluster_idx, mapping in enumerate(cluster_idx_mapping):
            cluster_id_list[mapping] = cluster_id_list_0[cluster_idx]
    else:
        # Do agglomerative clustering
        cluster_id_list = self.agglomerative_clustering(dist, cluster_method, cluster_dist_cutoff, num_perm)
    n_cluster = np.max(cluster_id_list)

    # Back-Mapping indices
    cluster_data = []
    for cluster_id in range(n_cluster):
        idx = np.where(cluster_id_list == cluster_id+1)[0]
        idx_list = reduced_data_map[idx].tolist()
        idx_list_1 = []
        for i in idx_list:
            idx_list_1 += i
        idx_list_1 = sorted(idx_list_1)
        cluster_data.append(np.array(map_list)[idx_list_1].tolist())
    return cluster_data
1235
+ ##########################################################################################################################################################
1236
+
1237
+ ##########################################################################################################################################################
1238
def cluster_chg_ent(self, chg_ent_keyword_dict, chg_ent_fingerprint_list, cluster_method=['average', 'average', 'average'], cluster_dist_cutoff=[20, 1.0, 0.1]):
    """Cluster the entries of every keyword through four nested clustering levels.

    Parameters
    ----------
    chg_ent_keyword_dict : dict
        Maps a keyword to the list of index entries that share it.
    chg_ent_fingerprint_list : sequence
        Fingerprint container forwarded to ``do_clustering``.
    cluster_method : list of str
        Linkage methods: element 0 for both crossing-residue levels,
        element 1 for the loop level, element 2 for cross contamination.
    cluster_dist_cutoff : list of float
        Distance cutoffs, indexed like ``cluster_method``.

    Returns
    -------
    tuple
        ``(cluster_data, cluster_tree)``, both dicts keyed by keyword.
        ``cluster_data[key]`` lists the final clusters. ``cluster_tree[key]``
        holds, for each of the first three levels, the positions of the final
        clusters that descend from each group of that level.
    """
    cluster_data_keys = sorted(list(chg_ent_keyword_dict.keys()))
    cluster_data = {keyword: [] for keyword in cluster_data_keys}
    cluster_tree = {keyword: [] for keyword in cluster_data_keys}

    def _process_key(key):
        """Run the four clustering levels for one keyword.

        Keywords are processed independently of each other, which is why
        this helper is dispatched through a ThreadPoolExecutor below.

        Note: ``agglomerative_clustering`` draws permutations from the shared
        global numpy random state, so results may differ between runs when
        several workers are used.
        """
        members = chg_ent_keyword_dict[key]
        leaves = []
        lineage = []
        n_ter_id = 0
        c_ter_id = 0
        loop_id = 0

        # First clustering based on N-ter crossing residues
        n_ter_groups = self.do_clustering(
            members, chg_ent_fingerprint_list, 'crossing_resid_0',
            self.pdist_thread_deviation, cluster_method[0], cluster_dist_cutoff[0])

        for n_ter_group in n_ter_groups:
            # Second clustering based on C-ter crossing residues
            c_ter_groups = self.do_clustering(
                n_ter_group, chg_ent_fingerprint_list, 'crossing_resid_1',
                self.pdist_thread_deviation, cluster_method[0], cluster_dist_cutoff[0])

            for c_ter_group in c_ter_groups:
                # Third clustering based on loop
                loop_groups = self.do_clustering(
                    c_ter_group, chg_ent_fingerprint_list, 'native_contact',
                    self.pdist_loop_overlap, cluster_method[1], cluster_dist_cutoff[1])

                for loop_group in loop_groups:
                    # Fourth clustering based on cross contamination
                    leaf_groups = self.do_clustering(
                        loop_group, chg_ent_fingerprint_list, 'cross_contamination',
                        self.pdist_cross_contamination, cluster_method[2], cluster_dist_cutoff[2])
                    for leaf in leaf_groups:
                        leaves.append(leaf)
                        # Remember which group of each upper level this leaf came from.
                        lineage.append([n_ter_id, c_ter_id, loop_id])
                    loop_id += 1
                c_ter_id += 1
            n_ter_id += 1

        lineage = np.array(lineage, dtype=int)
        local_tree = []
        for level in range(lineage.shape[1]):
            level_ids = lineage[:, level]
            local_tree.append(
                [np.where(level_ids == group)[0].tolist()
                 for group in range(level_ids.max() + 1)])
        n_cluster = len(leaves)
        self.logger.info('Found %d cluster(s) for %s' % (n_cluster, key))
        return key, leaves, local_tree

    if cluster_data_keys:
        n_workers = min(self.nproc, len(cluster_data_keys))
    else:
        n_workers = 1
    self.logger.info(
        f'Clustering {len(cluster_data_keys)} keyword(s) '
        f'(nproc={n_workers})...')
    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        for keyword, leaves, tree in executor.map(_process_key, cluster_data_keys):
            cluster_data[keyword] = leaves
            cluster_tree[keyword] = tree

    return (cluster_data, cluster_tree)
1311
+ ##########################################################################################################################################################
1312
+
1313
+ ##########################################################################################################################################################
1314
def find_representative_entanglement(self, cluster_data, ent_cluster_idx_map):
    """Pick one representative entry for each requested cluster.

    For every ``(key, idx)`` pair the cluster ``cluster_data[key][idx]`` is
    turned into an array whose columns 2 and 3 are treated as loop start and
    end. The loop midpoints are histogrammed (bin width 5, or 1 when that
    yields a single edge); among the entries in the most populated bin the
    one with the shortest loop (first one on ties) is the representative.

    Bin membership follows ``numpy.histogram``: bins are half-open except
    the last one, which also includes its right edge. If no entry can be
    assigned to the modal bin (or no valid bins exist), the shortest loop of
    the whole cluster is used instead of reusing a stale value.

    Parameters
    ----------
    cluster_data : dict
        Maps key -> list of clusters; each cluster is a list of rows.
    ent_cluster_idx_map : iterable
        ``[key, idx]`` pairs selecting clusters from ``cluster_data``.

    Returns
    -------
    list of numpy.ndarray
        One representative row per entry of ``ent_cluster_idx_map``.
    """
    # most probable loop midpoint
    rep_ent_list = []
    for key, cluster_idx in ent_cluster_idx_map:
        cluster = np.array(cluster_data[key][cluster_idx])
        loop_midpoint_list = np.mean(cluster[:,2:4], axis=1)
        loop_max = np.max(cluster[:,3])
        loop_min = np.min(cluster[:,2])
        # mode
        bins = np.arange(loop_min, loop_max, 5)
        if len(bins) == 1:
            bins = np.arange(loop_min, loop_max, 1)

        # Default: consider every entry of the cluster.
        candidates = np.arange(len(cluster))
        if len(bins) >= 2:
            hist, edges = np.histogram(loop_midpoint_list, bins=bins)
            mode_bin = np.argmax(hist)
            in_mode_bin = loop_midpoint_list >= edges[mode_bin]
            if mode_bin == len(hist) - 1:
                # np.histogram closes the last bin on the right.
                in_mode_bin &= loop_midpoint_list <= edges[mode_bin + 1]
            else:
                in_mode_bin &= loop_midpoint_list < edges[mode_bin + 1]
            if np.any(in_mode_bin):
                candidates = np.where(in_mode_bin)[0]

        loop_lengths = cluster[candidates,3] - cluster[candidates,2]
        # argmin returns the first minimum, matching a strict "<" scan.
        rep_ent_list.append(cluster[candidates[np.argmin(loop_lengths)]])
    return rep_ent_list
1337
+ ##########################################################################################################################################################
1338
+
1339
+ ##########################################################################################################################################################
1340
+ def _process_traj_file(self, traj_idx, ent_data_file, start_frame, end_frame):
1341
+ """Load and pre-process one trajectory pkl file for clustering.
1342
+
1343
+ Called in parallel (one call per trajectory) by ``cluster()``.
1344
+
1345
+ Returns
1346
+ -------
1347
+ tuple : (traj_idx, fingerprint_dict, Q_dict, frame_list, traj_file, keyword_entries)
1348
+ * fingerprint_dict – {frame: {nc: fingerprint}}
1349
+ * Q_dict – {frame: Q_value}
1350
+ * frame_list – sorted list of in-range frame indices
1351
+ * traj_file – resolved path to the matching .dcd file
1352
+ * keyword_entries – list of (keyword_str, entry) pairs for building
1353
+ chg_ent_keyword_dict after all trajectories are loaded
1354
+ """
1355
+ traj = self.extract_traj_number(ent_data_file)
1356
+ self.logger.debug(
1357
+ f'Processing {ent_data_file} {traj} ({traj_idx + 1} / {len(self.ent_data_file_list)})...')
1358
+
1359
+ # Load pickle with frame filtering applied
1360
+ ent_data = self.load_pickle(ent_data_file, start_frame, end_frame)
1361
+ # Frame list is now already filtered, just exclude 'ref'
1362
+ frame_list = sorted([frame for frame in ent_data.keys() if frame != 'ref'])
1363
+ self.logger.debug(f'frame_list: {frame_list} {len(frame_list)}')
1364
+
1365
+ # Locate the matching trajectory DCD file
1366
+ traj_file = os.path.join(self.traj_dir_prefix, f'{traj}_*.dcd')
1367
+ traj_file = glob.glob(traj_file)
1368
+ if len(traj_file) == 0:
1369
+ raise ValueError(f'No trajectory file found for {ent_data_file}.')
1370
+ elif len(traj_file) > 1:
1371
+ raise ValueError(f'More than 1 trajectory file found for {ent_data_file}.')
1372
+ else:
1373
+ traj_file = traj_file[0]
1374
+
1375
+ fingerprint_dict = {}
1376
+ Q_dict = {}
1377
+ keyword_entries = [] # collected as (keyword_str, entry) to merge after parallel load
1378
+
1379
+ for frame in frame_list:
1380
+ fingerprint_dict[frame] = {}
1381
+ Q_dict[frame] = (
1382
+ np.sum(list(ent_data[frame]['G_dict'].values()))
1383
+ / len(list(ent_data['ref']['ent_fingerprint'].keys())) / 2
1384
+ )
1385
+ for nc, fingerprint in ent_data[frame]['chg_ent_fingerprint'].items():
1386
+ # Skip if no change of entanglement
1387
+ if fingerprint['type'] == ['no change', 'no change']:
1388
+ continue
1389
+ fingerprint_dict[frame][nc] = fingerprint
1390
+ chg_ent_keyword = fingerprint['code'].copy()
1391
+ for ck in self.classify_key:
1392
+ if type(fingerprint[ck]) == list:
1393
+ chg_ent_keyword += fingerprint[ck]
1394
+ else:
1395
+ chg_ent_keyword += [fingerprint[ck]]
1396
+ chg_ent_keyword = str(chg_ent_keyword)
1397
+ keyword_entries.append((chg_ent_keyword, [traj_idx, frame] + list(nc)))
1398
+
1399
+ # Explicitly free the full deserialized pkl dict — it can be ~10 GB and the
1400
+ # semaphore in cluster() is held until this function returns, so freeing it
1401
+ # here lets CPython recycle the memory before the next load starts.
1402
+ del ent_data
1403
+
1404
+ return traj_idx, fingerprint_dict, Q_dict, frame_list, traj_file, keyword_entries
1405
+ ##########################################################################################################################################################
1406
+
1407
+ ##########################################################################################################################################################
1408
def cluster(self, start_frame:int=0, end_frame:int=9999999):
    """
    Cluster changes of entanglement found in all trajectories and save the results.

    Steps performed (all output files are written to ``self.outdir`` and carry the
    joined ``self.classify_key`` in their names):

    1. Load every file in ``self.ent_data_file_list`` through
       ``self._process_traj_file`` (threaded, with a semaphore limiting how many
       files are deserialized at once), or reuse ``cluster_data_<keys>.npz`` when
       it already exists.
    2. Cluster the changes of entanglement with ``self.cluster_chg_ent`` and write
       the cluster tree to ``cluster_tree_<keys>.dat``.
    3. Pick a representative change of entanglement per cluster
       (``rep_chg_ent_list_<keys>.pkl`` / ``rep_chg_ent_<keys>.csv``).
    4. Plot loops and crossings per cluster (``chg_ent_<keys>_distribution.pdf``).
    5. Build per-frame lists of cluster ids (discrete trajectories), group frames
       by their unique combination of cluster ids and write
       ``chg_ent_struct_<keys>.csv``.
    6. Save everything to ``cluster_data_<keys>.npz``.

    Parameters
    ----------
    start_frame : int
        Forwarded to ``self._process_traj_file`` when pkl files are read.
    end_frame : int
        Forwarded to ``self._process_traj_file`` when pkl files are read.

    Notes
    -----
    When the npz cache already exists, the pkl files are not re-read and
    ``start_frame`` / ``end_frame`` are therefore not applied.
    """
    import ast  # function-scope import: only used to parse keyword strings below

    def _savable(seq):
        """Return ``seq`` as an array np.savez accepts, even if it is ragged."""
        try:
            return np.asarray(seq)
        except ValueError:
            # NumPy >= 1.24 refuses to build an array from ragged nested lists
            # (e.g. trajectories with different numbers of frames); box each
            # item so that .tolist() on load still returns the original lists.
            boxed = np.empty(len(seq), dtype=object)
            for pos, item in enumerate(seq):
                boxed[pos] = item
            return boxed

    def _crossing_labels(fingerprint, prefix=''):
        """Join crossing pattern and residue id, e.g. '+' and 45 -> '+45'."""
        resids = fingerprint[prefix + 'crossing_resid']
        patterns = fingerprint[prefix + 'crossing_pattern']
        return [[patterns[i][j] + str(resids[i][j]) for j in range(len(resids[i]))]
                for i in range(len(resids))]

    key_tag = "_".join(self.classify_key)

    ## Define the .npz file name
    npz_data_file = os.path.join(self.outdir, f'cluster_data_{key_tag}.npz')
    self.logger.info(f'Checking for {npz_data_file}')

    if not os.path.exists(npz_data_file):
        # Classify changes of entanglement based on the keyword
        # "[change_code, classify_key_1_N, classify_key_1_C, classify_key_2_N, classify_key_2_C, ...]"
        self.logger.debug('Reading pkl data and classify changes of entanglement...')
        n_files = len(self.ent_data_file_list)
        chg_ent_fingerprint_list = [None] * n_files
        Q_list = [None] * n_files
        idx2frame = [None] * n_files
        idx2trajfile = [None] * n_files
        dtrajs = [None] * n_files
        keyword_entries_by_traj = [None] * n_files

        # max(1, ...) keeps ThreadPoolExecutor/Semaphore valid for an empty file list
        n_workers = max(1, min(self.nproc, n_files))
        # Each ~1 GB pkl file inflates to ~10 GB of live Python objects during
        # deserialization. Use memory_cutoff as a proxy for available RAM to
        # derive a safe concurrency limit: memory_cutoff / 1e10 (10 GB/file).
        n_load_workers = min(n_workers, max(1, int(self.memory_cutoff // 1e10)))
        self.logger.info(
            f'Loading {n_files} pkl files '
            f'(nproc={n_workers}, concurrent_loads={n_load_workers})...')
        # Semaphore caps how many threads may simultaneously hold a deserialized
        # pkl in memory.
        _load_sem = threading.Semaphore(n_load_workers)

        def _throttled_process(ti, f):
            with _load_sem:
                return self._process_traj_file(ti, f, start_frame, end_frame)

        with ThreadPoolExecutor(max_workers=n_workers) as executor:
            futures = {
                executor.submit(_throttled_process, ti, f): ti
                for ti, f in enumerate(self.ent_data_file_list)
            }
            for fut in as_completed(futures):
                ti, fingerprint_dict, Q_dict, frame_list, traj_file, keyword_entries = fut.result()
                chg_ent_fingerprint_list[ti] = fingerprint_dict
                Q_list[ti] = Q_dict
                idx2frame[ti] = frame_list
                idx2trajfile[ti] = traj_file
                dtrajs[ti] = [[] for _ in frame_list]
                keyword_entries_by_traj[ti] = keyword_entries

        # Merge keyword entries in trajectory order (not thread completion order)
        # so that the clustering input, and hence the cluster numbering, is
        # reproducible from run to run.
        chg_ent_keyword_dict = {}
        for keyword_entries in keyword_entries_by_traj:
            for kw, entry in keyword_entries:
                chg_ent_keyword_dict.setdefault(kw, []).append(entry)
        del keyword_entries_by_traj

        self.logger.info('%d data files have been read.' % n_files)

        # cluster changes of entanglements found in the trajectories
        self.logger.info('Clustering changes of entanglement for %d keywords...'%(len(chg_ent_keyword_dict)))
        ent_cluster_data, ent_cluster_tree = self.cluster_chg_ent(chg_ent_keyword_dict, chg_ent_fingerprint_list, cluster_method=self.cluster_method, cluster_dist_cutoff=self.cluster_dist_cutoff)
        chg_ent_keyword_list = sorted(chg_ent_keyword_dict)
        # Save calculated data in case job is unexpectedly terminated
        np.savez(npz_data_file,
                 chg_ent_fingerprint_list=chg_ent_fingerprint_list,
                 Q_list=Q_list,
                 chg_ent_keyword_dict=chg_ent_keyword_dict,
                 chg_ent_keyword_list=chg_ent_keyword_list,
                 idx2trajfile=idx2trajfile,
                 idx2frame=_savable(idx2frame),
                 ent_cluster_data=ent_cluster_data,
                 ent_cluster_tree=ent_cluster_tree)
        self.logger.info(f'SAVED: {npz_data_file}')

    else:
        self.logger.info(f'Reading clustering data from {npz_data_file}...')
        # NOTE: allow_pickle=True executes pickled objects; only load npz files
        # written by this pipeline.
        npz_data = np.load(npz_data_file, allow_pickle=True)
        chg_ent_fingerprint_list = npz_data['chg_ent_fingerprint_list'].tolist()
        Q_list = npz_data['Q_list'].tolist()
        chg_ent_keyword_dict = npz_data['chg_ent_keyword_dict'].item()
        chg_ent_keyword_list = npz_data['chg_ent_keyword_list'].tolist()
        idx2frame = npz_data['idx2frame'].tolist()
        idx2trajfile = npz_data['idx2trajfile'].tolist()
        dtrajs = [[[] for frame in chg_ent_fingerprint.keys()] for chg_ent_fingerprint in chg_ent_fingerprint_list]
        ent_cluster_data = npz_data['ent_cluster_data'].item()
        ent_cluster_tree = npz_data['ent_cluster_tree'].item()

    # Global (0-based) cluster id -> [keyword, index of the cluster within that keyword]
    ent_cluster_idx_map = []
    for ent_keyword in chg_ent_keyword_list:
        for i in range(len(ent_cluster_data[ent_keyword])):
            ent_cluster_idx_map.append([ent_keyword, i])
    # Reverse lookup; replaces repeated O(n) list.index() scans
    cluster_id_lookup = {(kw, i): cid for cid, (kw, i) in enumerate(ent_cluster_idx_map)}

    ## Print and save cluster tree
    cluster_headers = ['After clustering on N crossing', 'After clustering on C crossing', 'After clustering on loop']
    cluster_tree_file = os.path.join(self.outdir, f'cluster_tree_{key_tag}.dat')
    self.logger.info(f'Making {cluster_tree_file}')
    with open(cluster_tree_file, 'w') as f:
        for ent_keyword in chg_ent_keyword_list:
            f.write(ent_keyword+'\n')
            clusters = ent_cluster_tree[ent_keyword]
            for i in range(len(clusters)):
                f.write(' '*4 + cluster_headers[i] + ':\n')
                for cluster in clusters[i]:
                    # cluster ids are written 1-based
                    id_text = ', '.join('%d' % (cluster_id_lookup[(ent_keyword, c)] + 1) for c in cluster)
                    f.write(' '*8 + '[' + id_text + ']\n')
            f.write('\n')
    self.logger.info(f'SAVED: {cluster_tree_file}')

    # Find representative changes of entanglement in each cluster
    rep_chg_ent_list_file = os.path.join(self.outdir, f'rep_chg_ent_list_{key_tag}.pkl')
    rep_chg_ent_data_file = os.path.join(self.outdir, f'rep_chg_ent_{key_tag}.csv')
    if os.path.exists(rep_chg_ent_list_file) and os.path.exists(rep_chg_ent_data_file):
        self.logger.debug('Reading representative changes of entanglement...')
        # NOTE: pickle.load executes arbitrary code; this file is a cache written below.
        with open(rep_chg_ent_list_file, 'rb') as f:
            rep_chg_ent_list = pickle.load(f)
        self.logger.debug(f'Loaded: {rep_chg_ent_list_file} into rep_chg_ent_list')

    else:
        self.logger.debug('Finding representative changes of entanglement...')
        rep_chg_ent_list = self.find_representative_entanglement(ent_cluster_data, ent_cluster_idx_map)
        with open(rep_chg_ent_list_file, 'wb') as f:  # save the list as a pickle file
            pickle.dump(rep_chg_ent_list, f)
        self.logger.info(f'SAVED: {rep_chg_ent_list_file}')

        # Create dataframe and save data
        data = []
        column_list = ['Keywords', 'Trajectory', 'Frame', 'Native Contact (Residue Index)',
                       'Ref N-ter Crossing', 'Ref C-ter Crossing', 'N-ter Crossing', 'C-ter Crossing',
                       'Ref N-ter GLN', 'Ref C-ter GLN', 'N-ter GLN', 'C-ter GLN',
                       'Ref N-ter Linking Number', 'Ref C-ter Linking Number', 'N-ter Linking Number', 'C-ter Linking Number']
        index_list = []
        for state_id, rep_chg_ent in enumerate(rep_chg_ent_list):
            index_list.append(state_id+1)
            [traj_idx, frame_idx] = rep_chg_ent[:2]
            nc = tuple(rep_chg_ent[2:])
            keyword = ent_cluster_idx_map[state_id][0]
            chg_ent_fingerprint = chg_ent_fingerprint_list[traj_idx][frame_idx][nc]
            cross = _crossing_labels(chg_ent_fingerprint)
            ref_cross = _crossing_labels(chg_ent_fingerprint, 'ref_')
            GLN = chg_ent_fingerprint['linking_value']
            ref_GLN = chg_ent_fingerprint['ref_linking_value']
            LN = chg_ent_fingerprint['topoly_linking_number']
            ref_LN = chg_ent_fingerprint['ref_topoly_linking_number']

            data.append([keyword, idx2trajfile[traj_idx], frame_idx, nc,
                         ref_cross[0], ref_cross[1], cross[0], cross[1],
                         ref_GLN[0], ref_GLN[1], GLN[0], GLN[1],
                         ref_LN[0], ref_LN[1], LN[0], LN[1]])

        df = pd.DataFrame(data, columns=column_list, index=index_list)
        df.to_csv(rep_chg_ent_data_file, index_label='State ID')
        self.logger.info(f'SAVED: {rep_chg_ent_data_file}')

    # plot entanglement distribution
    n_cluster = len(ent_cluster_idx_map)
    fig = plt.figure(figsize=(np.max([6, 0.3*n_cluster]),5))
    ax = fig.add_subplot(1,1,1)
    window_width = 0.8
    for state_id, (key, cluster_idx) in enumerate(ent_cluster_idx_map):
        cluster = ent_cluster_data[key][cluster_idx]
        nc_list = [c[2:4] for c in cluster]
        # order members by loop length, then by loop start
        sort_index = np.array(
            sorted(range(len(nc_list)), key=lambda i: (nc_list[i][1]-nc_list[i][0], nc_list[i][0])),
            dtype=int)
        if len(sort_index) <= self.max_plot_samples:
            plot_idx = np.arange(0, len(sort_index), 1, dtype=int)
        else:
            # sub-sample evenly when the cluster has too many members to draw
            plot_idx = np.linspace(0, len(sort_index)-1, self.max_plot_samples, dtype=int)
        for idx, ci in enumerate(sort_index[plot_idx]):
            c = cluster[ci]
            nc = c[2:4]
            traj_idx = c[0]
            frame_idx = c[1]
            fingerprint = chg_ent_fingerprint_list[traj_idx][frame_idx][tuple(nc)]
            crossings = fingerprint['crossing_resid']
            ref_crossings = fingerprint['ref_crossing_resid']
            # horizontal position of this member inside the cluster's column
            x = state_id+1-window_width/2 + (idx+1)*window_width/(len(plot_idx)+1)
            # plot loop
            ax.plot([x,x], nc, '-', color='tomato', linewidth=0.5, alpha=0.4)
            # plot crossings
            for ccr in ref_crossings:
                for cc in ccr:
                    ax.plot([x, x], [cc-0.5, cc+0.5], '-', color='green', linewidth=0.5, alpha=0.4)
            for ccr in crossings:
                for cc in ccr:
                    ax.plot([x, x], [cc-0.5, cc+0.5], '-', color='blue', linewidth=0.5, alpha=0.4)
    ax.set_xticks(np.arange(1,n_cluster+1,1), np.arange(1,n_cluster+1,1))
    ax.set_xlim([0, n_cluster+1])
    ax.set_xlabel('Cluster')
    ax.set_ylabel('Residue index')
    chg_dist_data_file = os.path.join(self.outdir, f'chg_ent_{key_tag}_distribution.pdf')
    fig.savefig(chg_dist_data_file, bbox_inches='tight')
    self.logger.info(f'SAVED: {chg_dist_data_file}')
    # pyplot keeps its own reference to the figure; close it to release the memory
    plt.close(fig)

    self.logger.debug('Clustering structures with unique combinations of changes of entanglements...')
    # Assign entanglement clusters (list of ent_cluster_idx) in discrete trajectories
    frame_pos = [{frame: pos for pos, frame in enumerate(frames)} for frames in idx2frame]
    for key, ent_clusters in ent_cluster_data.items():
        for i, ent_cluster in enumerate(ent_clusters):
            cluster_id = cluster_id_lookup[(key, i)]
            for chg_ent_keyword in ent_cluster:
                traj_idx, frame = chg_ent_keyword[:2]
                dtrajs[traj_idx][frame_pos[traj_idx][frame]].append(cluster_id)
    self.logger.info(f'Assigned entanglement clusters in discrete trajectories')

    # Strip same cluster ids in each frame
    unique_structure_keywords = set()
    for dtraj in dtrajs:
        for i, cluster_id_list in enumerate(dtraj):
            dtraj[i] = sorted(set(cluster_id_list))
            unique_structure_keywords.add(str(dtraj[i]))
    chg_ent_structure_keyword_list = sorted(unique_structure_keywords)
    self.logger.info(f'Stripped same cluster ids in each frame')

    # cluster trajectory frames with different combinations of changes in entanglement
    chg_ent_structure_cluster_data = {chg_ent_structure_keyword: [] for chg_ent_structure_keyword in chg_ent_structure_keyword_list}
    for traj_idx, dtraj in enumerate(dtrajs):
        for frame_idx, cluster_id_list in enumerate(dtraj):
            chg_ent_structure_cluster_data[str(cluster_id_list)].append([traj_idx, idx2frame[traj_idx][frame_idx]])
    Num_struct_list = [len(chg_ent_structure_cluster_data[keyword]) for keyword in chg_ent_structure_keyword_list]
    # most populated combination first
    sort_idx = np.argsort(-np.array(Num_struct_list, dtype=int))
    sorted_chg_ent_structure_keyword_list = [chg_ent_structure_keyword_list[idx] for idx in sort_idx]
    self.logger.info(f'Cluster trajectory frames with different combinations of changes in entanglement')

    self.logger.debug('Find representative combinations of changes of entanglements in structures...')
    # Find representative changes of entanglement (minimal loop) in each frame
    rep_chg_ent_dtrajs = []
    for traj_idx, dtraj in enumerate(dtrajs):
        traj_reps = []
        for frame_idx, cluster_id_list in enumerate(dtraj):
            frame_idx_0 = idx2frame[traj_idx][frame_idx]
            frame_reps = {}
            for cluster_id in cluster_id_list:
                [ent_keyword, idx] = ent_cluster_idx_map[cluster_id]
                nc_list = [tuple(element[2:])
                           for element in ent_cluster_data[ent_keyword][idx]
                           if traj_idx == element[0] and frame_idx_0 == element[1]]
                # shortest loop wins; min() keeps the first one on ties
                rep_nc = min(nc_list, key=lambda nc: nc[1]-nc[0])
                frame_reps[cluster_id] = chg_ent_fingerprint_list[traj_idx][frame_idx_0][rep_nc]
            traj_reps.append(frame_reps)
        rep_chg_ent_dtrajs.append(traj_reps)

    # Find representative structures (max Q) for each combination
    self.logger.info(f'Find representative structures (max Q) for each combination')
    rep_struct_data = {}
    for keyword in sorted_chg_ent_structure_keyword_list:
        Q_0 = 0
        rep_struct_data[keyword] = chg_ent_structure_cluster_data[keyword][0]
        for [traj_idx, frame_idx] in chg_ent_structure_cluster_data[keyword]:
            Q = Q_list[traj_idx][frame_idx]
            if Q > Q_0:
                Q_0 = Q
                rep_struct_data[keyword] = [traj_idx, frame_idx]
    self.logger.debug('Found representative structures (max Q) for each combination')

    # Create dataframe and save data
    data = []
    column_list = ['Rep_chg_ents', 'Num of structures', 'Probability',
                   'Rep trajectory', 'Rep frame', 'Max Q', 'Median Q']
    index_list = []
    tot_num_frames = sum(len(dtraj) for dtraj in dtrajs)
    for state_id, keyword in enumerate(sorted_chg_ent_structure_keyword_list):
        index_list.append(state_id+1)
        # keyword is str(list of 0-based cluster ids); report them 1-based.
        # Plain Python ints keep the text as '[1, 2]' on every NumPy version.
        Rep_chg_ents = str([cluster_id + 1 for cluster_id in ast.literal_eval(keyword)])
        Num = len(chg_ent_structure_cluster_data[keyword])
        Prob = Num / tot_num_frames
        Q_0_list = [Q_list[cd[0]][cd[1]] for cd in chg_ent_structure_cluster_data[keyword]]
        max_Q = np.max(Q_0_list)
        median_Q = np.median(Q_0_list)
        data.append([Rep_chg_ents, Num, Prob, idx2trajfile[rep_struct_data[keyword][0]], rep_struct_data[keyword][1], max_Q, median_Q])
    df = pd.DataFrame(data, columns=column_list, index=index_list)
    chg_ent_data_file = os.path.join(self.outdir, f'chg_ent_struct_{key_tag}.csv')
    df.to_csv(chg_ent_data_file, index_label='State ID')
    self.logger.info(f'SAVED: {chg_ent_data_file}')

    ## determine if there is any issue with the item shapes before saving
    save_items = {
        "chg_ent_fingerprint_list": chg_ent_fingerprint_list,
        "Q_list": Q_list,
        "chg_ent_keyword_dict": chg_ent_keyword_dict,
        "chg_ent_keyword_list": chg_ent_keyword_list,
        "idx2trajfile": idx2trajfile,
        "idx2frame": idx2frame,
        "ent_cluster_data": ent_cluster_data,
        "ent_cluster_tree": ent_cluster_tree,
        "rep_chg_ent_list": rep_chg_ent_list,
        "dtrajs": dtrajs,
        "rep_chg_ent_dtrajs": rep_chg_ent_dtrajs,
        "sorted_chg_ent_structure_keyword_list": sorted_chg_ent_structure_keyword_list,
        "chg_ent_structure_cluster_data": chg_ent_structure_cluster_data,
        "rep_struct_data": rep_struct_data}

    for key, value in save_items.items():
        try:
            shape_info = f", shape={np.shape(value)}"
        except Exception as e:
            shape_info = f", shape=unavailable ({e})"
        self.logger.info(f"{key}: type={type(value)}{shape_info}")

    # Save data
    np.savez(npz_data_file,
             chg_ent_fingerprint_list=chg_ent_fingerprint_list,
             Q_list=Q_list,
             chg_ent_keyword_dict=chg_ent_keyword_dict,
             chg_ent_keyword_list=chg_ent_keyword_list,
             idx2trajfile=idx2trajfile,
             idx2frame=_savable(idx2frame),
             ent_cluster_data=ent_cluster_data,
             ent_cluster_tree=ent_cluster_tree,
             rep_chg_ent_list=rep_chg_ent_list,
             dtrajs=np.array(dtrajs, dtype=object),
             rep_chg_ent_dtrajs=np.array(rep_chg_ent_dtrajs, dtype=object),
             sorted_chg_ent_structure_keyword_list=sorted_chg_ent_structure_keyword_list,
             chg_ent_structure_cluster_data=chg_ent_structure_cluster_data,
             rep_struct_data=rep_struct_data)
    self.logger.info(f'SAVED: {npz_data_file}')
    self.logger.info(f'Clustering Complete')
1765
+ ##########################################################################################################################################################
1766
+
1767
+ ##########################################################################################################################################################
1768
def viz_rep_changes(self, psf_file=None, native_AA_pdb=None, if_viz=True,
                    if_backmap=False, pulchra_only=False, top_struct=0.01):
    """
    Write VMD visualization scripts for the clustering results saved by ``cluster()``.

    The clustering data are read from ``cluster_data_<keys>.npz`` in
    ``self.outdir``. Two directories are created in the current working
    directory:

    * ``viz_rep_chg_ent_<keys>`` - one state per representative change of
      entanglement of each cluster.
    * ``viz_chg_ent_struct_<keys>_<top_struct>`` - one state per unique
      combination of changes of entanglement, most populated first.

    Parameters
    ----------
    psf_file : str
        Topology file passed to ``mdt.load`` and to ``gen_state_visualizion``.
    native_AA_pdb : str
        Reference PDB passed to ``gen_state_visualizion``.
    if_viz : bool
        When false the method returns immediately.
    if_backmap, pulchra_only : bool
        Forwarded to ``gen_state_visualizion``.
    top_struct : int or float
        ``>= 1``: number of most populated structural states to visualize.
        ``< 1``: minimum population fraction a structural state needs.

    Raises
    ------
    ValueError
        If ``psf_file`` or ``native_AA_pdb`` is not given.
    FileNotFoundError
        If the npz file written by ``cluster()`` does not exist.
    KeyError
        If the npz file lacks the entries written at the end of ``cluster()``.
    """
    if not if_viz:
        return
    if psf_file is None or native_AA_pdb is None:
        raise ValueError('viz_rep_changes requires both psf_file and native_AA_pdb.')
    # The working directory is changed while scripts are written, so make the
    # input paths independent of it.
    psf_file = os.path.abspath(psf_file)
    native_AA_pdb = os.path.abspath(native_AA_pdb)

    key_tag = '_'.join(self.classify_key)
    npz_data_file = os.path.join(self.outdir, f'cluster_data_{key_tag}.npz')
    if not os.path.exists(npz_data_file):
        raise FileNotFoundError(f'{npz_data_file} not found; run cluster() first.')
    # NOTE: allow_pickle=True executes pickled objects; only load npz files
    # written by this pipeline.
    npz_data = np.load(npz_data_file, allow_pickle=True)
    required_keys = ['rep_chg_ent_list', 'idx2trajfile', 'idx2frame', 'chg_ent_fingerprint_list',
                     'sorted_chg_ent_structure_keyword_list', 'chg_ent_structure_cluster_data',
                     'rep_struct_data', 'rep_chg_ent_dtrajs']
    missing_keys = [key for key in required_keys if key not in npz_data.files]
    if missing_keys:
        raise KeyError(f'{npz_data_file} is missing {missing_keys}; cluster() did not finish.')

    rep_chg_ent_list = npz_data['rep_chg_ent_list'].tolist()
    idx2trajfile = npz_data['idx2trajfile'].tolist()
    idx2frame = npz_data['idx2frame'].tolist()
    chg_ent_fingerprint_list = npz_data['chg_ent_fingerprint_list'].tolist()
    sorted_chg_ent_structure_keyword_list = npz_data['sorted_chg_ent_structure_keyword_list'].tolist()
    chg_ent_structure_cluster_data = npz_data['chg_ent_structure_cluster_data'].item()
    rep_struct_data = npz_data['rep_struct_data'].item()
    rep_chg_ent_dtrajs = npz_data['rep_chg_ent_dtrajs'].tolist()

    # Population of every structural state and total number of analysed frames
    sorted_Num_struct_list = [len(chg_ent_structure_cluster_data[keyword])
                              for keyword in sorted_chg_ent_structure_keyword_list]
    tot_num_frames = sum(len(frames) for frames in idx2frame)

    def _visualize_in(directory, state_number, state_cor, rep_ent_dict):
        """Run gen_state_visualizion inside ``directory`` and always come back."""
        start_dir = os.getcwd()
        os.chdir(directory)
        try:
            self.gen_state_visualizion(state_number, psf_file, state_cor, native_AA_pdb, rep_ent_dict,
                                       if_backmap=if_backmap, pulchra_only=pulchra_only)
        finally:
            os.chdir(start_dir)

    ## Generate visualization for representative changes of entanglement in each cluster
    self.logger.debug('Generate visualization for representative changes of entanglement...')
    rep_viz_dir = 'viz_rep_chg_ent_%s'%key_tag
    os.makedirs(rep_viz_dir, exist_ok=True)
    for state_id, rep_chg_ent in enumerate(rep_chg_ent_list):
        # *10: mdtraj coordinates (nm) to Angstrom
        state_cor = mdt.load(idx2trajfile[rep_chg_ent[0]], top=psf_file)[rep_chg_ent[1]].center_coordinates().xyz*10
        nc = (rep_chg_ent[2], rep_chg_ent[3])
        chg_ent_fingerprint = chg_ent_fingerprint_list[rep_chg_ent[0]][rep_chg_ent[1]][nc]
        rep_ent_dict = {tuple(chg_ent_fingerprint['code']): [chg_ent_fingerprint]}
        _visualize_in(rep_viz_dir, state_id+1, state_cor, rep_ent_dict)

    # Generate visualization for representative changes of entanglements in each structural cluster
    self.logger.debug('Generate visualization for unique entangled structures...')
    if top_struct >= 1:
        viz_dir = 'viz_chg_ent_struct_%s_%d'%(key_tag, top_struct)
    else:
        viz_dir = 'viz_chg_ent_struct_%s_%.4f'%(key_tag, top_struct)
    os.makedirs(viz_dir, exist_ok=True)
    for state_id, keyword in enumerate(sorted_chg_ent_structure_keyword_list):
        # states are sorted by population, so the first miss ends the loop
        if top_struct >= 1 and state_id >= top_struct:
            break
        elif top_struct < 1 and sorted_Num_struct_list[state_id]/tot_num_frames < top_struct:
            break
        [traj_idx, frame_idx] = rep_struct_data[keyword]
        state_cor = mdt.load(idx2trajfile[traj_idx], top=psf_file)[frame_idx].center_coordinates().xyz*10
        frame_idx_0 = idx2frame[traj_idx].index(frame_idx)
        rep_chg_ent_dict = rep_chg_ent_dtrajs[traj_idx][frame_idx_0]
        # group the frame's representative fingerprints by their change code
        rep_ent_dict = {}
        for v in rep_chg_ent_dict.values():
            rep_ent_dict.setdefault(tuple(v['code']), []).append(v)
        _visualize_in(viz_dir, state_id+1, state_cor, rep_ent_dict)
1811
+ ##########################################################################################################################################################
1812
+
1813
+ ##########################################################################################################################################################
1814
def gen_state_visualizion(self, state_id, psf, state_cor, native_AA_pdb, rep_ent_dict, if_backmap=True, pulchra_only=False):
    """
    Write the structure of one state and VMD (.tcl) scripts that highlight its
    changes of entanglement. All files are written to the current working directory.

    Parameters
    ----------
    state_id : int
        Number used in the output names (``state_<id>.pdb``, ``vmd_s<id>_*.tcl``).
    psf : str
        Topology file loaded with ``pmd.load_file``.
    state_cor : array
        Coordinates assigned to the loaded topology.
    native_AA_pdb : str
        Reference structure; loaded with ``pmd.load_file`` and referenced by
        path in the generated scripts.
    rep_ent_dict : dict
        ``{change_code_tuple: [fingerprint, ...]}``. One script is written per
        key; an empty dict produces a single "no change" script.
    if_backmap : bool
        When true the structure is rebuilt with the external ``backmap.py``
        command before being saved.
    pulchra_only : bool
        Passed to ``backmap.py`` as ``-p 1`` / ``-p 0``.
    """
    import shutil      # function-scope imports: only needed for the backmap step
    import subprocess

    def idx2sel(idx_list):
        """Turn a sorted index list into a VMD selection, e.g. 'index 1 to 3 7'."""
        if len(idx_list) == 0:
            return ''
        sel = 'index'
        idx_0 = idx_list[0]   # first index of the current consecutive run
        idx_1 = idx_list[0]   # last index of the current consecutive run
        sel_0 = ' %d'%idx_0
        for i in range(1, len(idx_list)):
            if idx_list[i] == idx_list[i-1] + 1:
                idx_1 = idx_list[i]
            else:
                if idx_1 > idx_0:
                    sel_0 += ' to %d'%idx_1
                sel += sel_0
                idx_0 = idx_list[i]
                idx_1 = idx_list[i]
                sel_0 = ' %d'%idx_0
        if idx_1 > idx_0:
            sel_0 += ' to %d'%idx_1
        sel += sel_0
        return sel

    AA_name_list = ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
                    'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL',
                    'HIE', 'HID', 'HIP']
    # VMD ColorID per structure: [reference, current state]
    protein_colorid_list = [6, 6]
    loop_colorid_list = [1, 1]
    thread_colorid_list = [0, 0]
    nc_colorid_list = [3, 3]
    crossing_colorid_list = [8, 8]
    thread_cutoff=3
    terminal_cutoff=3

    self.logger.info('Generate visualization of state %d'%(state_id))

    struct = pmd.load_file(psf)
    struct.coordinates = state_cor
    state_pdb = 'state_%d.pdb'%state_id

    # backmap
    if if_backmap:
        pulchra_flag = '1' if pulchra_only else '0'
        struct.save('tmp.pdb', overwrite=True)
        # Argument list instead of a shell string: paths with spaces or shell
        # metacharacters are passed through unchanged, and a failing backmap
        # is reported here instead of surfacing later as a missing file.
        subprocess.run(['backmap.py', '-i', native_AA_pdb, '-c', 'tmp.pdb', '-p', pulchra_flag], check=True)
        os.replace('tmp_rebuilt.pdb', state_pdb)
        if os.path.exists('tmp.pdb'):
            os.remove('tmp.pdb')
        shutil.rmtree('./rebuild_tmp/', ignore_errors=True)
    else:
        struct.save(state_pdb, overwrite=True)

    ref_struct = pmd.load_file(native_AA_pdb)
    current_struct = pmd.load_file(state_pdb)

    if not rep_ent_dict:
        # no change of entanglement
        with open('vmd_s%d_none.tcl'%(state_id), 'w') as f:
            f.write('# Entanglement type: no change\n')
            f.write('''display rendermode GLSL
axes location off

color Display {Background} white

mol new ./'''+state_pdb+''' type pdb first 0 last -1 step 1 filebonds 1 autobonds 1 waitfor all
mol delrep 0 top
mol representation NewCartoon 0.300000 10.000000 4.100000 0
mol color ColorID '''+str(protein_colorid_list[1])+'''
mol selection {all}
mol material AOChalky
mol addrep top
''')

    # Create vmd script for each type of change
    for ent_code, rep_ent_list in rep_ent_dict.items():
        # index 0: reference structure, index 1: current state
        pmd_struct_list = [ref_struct, current_struct]
        struct_dir_list = [native_AA_pdb, './state_%d.pdb'%state_id]
        key_prefix_list = ['ref_', '']
        repres_list = ['', '']
        align_sel_list = ['', '']

        vmd_script = '''# Entanglement type: '''+str(rep_ent_list[0]['type'])+'''
package require topotools
display rendermode GLSL
axes location off

color Display {Background} white

'''
        for struct_idx, pmd_struct in enumerate(pmd_struct_list):
            struct_dir = struct_dir_list[struct_idx]
            protein_colorid = protein_colorid_list[struct_idx]
            loop_colorid = loop_colorid_list[struct_idx]
            thread_colorid = thread_colorid_list[struct_idx]
            nc_colorid = nc_colorid_list[struct_idx]
            crossing_colorid = crossing_colorid_list[struct_idx]
            key_prefix = key_prefix_list[struct_idx]

            # Clean ligands: keep only atoms of residues named in AA_name_list
            clean_sel_idx = np.zeros(len(pmd_struct.atoms))
            for res in pmd_struct.residues:
                if res.name in AA_name_list:
                    for atm in res.atoms:
                        clean_sel_idx[atm.idx] = 1
            pmd_clean_struct = pmd_struct[clean_sel_idx]
            # maps an atom index of the cleaned structure back to the full structure
            clean_idx_to_idx = np.where(clean_sel_idx == 1)[0]

            # vmd selection string for protein
            idx_list = []
            for res in pmd_struct.residues:
                if res.name in AA_name_list:
                    idx_list += [atm.idx for atm in res.atoms]
            vmd_sel = idx2sel(idx_list)

            repres = '''mol new '''+struct_dir+''' type pdb first 0 last -1 step 1 filebonds 1 autobonds 1 waitfor all
mol delrep 0 top
mol representation NewCartoon 0.300000 10.000000 4.100000 0
mol color ColorID '''+str(protein_colorid)+'''
mol selection {'''+vmd_sel+'''}
mol material AOChalky
mol addrep top
'''
            # alignment uses the protein minus every loop and thread segment
            align_sel = vmd_sel
            for chg_ent_fingerprint in rep_ent_list:
                nc = chg_ent_fingerprint[key_prefix+'native_contact']

                idx_list = []
                for res in pmd_clean_struct.residues:
                    if res.idx in nc:
                        idx_list += [atm.idx for atm in res.atoms if atm.name == 'CA']
                nc_sel = idx2sel(clean_idx_to_idx[idx_list])

                idx_list = []
                for res in pmd_clean_struct.residues:
                    if res.idx >= nc[0] and res.idx <= nc[1]:
                        idx_list += [atm.idx for atm in res.atoms]
                loop_sel = idx2sel(clean_idx_to_idx[idx_list])

                # an empty selection would produce the invalid text 'and not ()'
                if loop_sel:
                    align_sel += ' and not (%s)'%loop_sel
                ref_coss_resid = chg_ent_fingerprint['ref_crossing_resid']
                cross_resid = chg_ent_fingerprint['crossing_resid']
                thread = []
                thread_sel_list = []
                for ter_idx in range(len(ref_coss_resid)):
                    thread_0 = []
                    resid_list = ref_coss_resid[ter_idx] + cross_resid[ter_idx]
                    if len(resid_list) > 0:
                        # segment around all crossings, padded by 5 residues
                        thread_0 = [np.min(resid_list)-5, np.max(resid_list)+5]
                        if ter_idx == 0:
                            # first entry: keep the segment between the terminal cutoff and the loop start
                            thread_0[0] = np.max([thread_0[0], terminal_cutoff])
                            thread_0[1] = np.min([thread_0[1], nc[0]-thread_cutoff])
                        else:
                            # other entry: keep the segment between the loop end and the terminal cutoff
                            thread_0[0] = np.max([thread_0[0], nc[1]+thread_cutoff])
                            thread_0[1] = np.min([thread_0[1], len(struct.atoms)-1-terminal_cutoff])
                        idx_list = []
                        for res in pmd_clean_struct.residues:
                            if res.idx >= thread_0[0] and res.idx <= thread_0[1]:
                                idx_list += [atm.idx for atm in res.atoms]
                        thread_0_sel = idx2sel(clean_idx_to_idx[idx_list])
                        thread_sel_list.append(thread_0_sel)
                        if thread_0_sel:
                            align_sel += ' and not (%s)'%thread_0_sel
                    else:
                        thread_sel_list.append('')
                    thread.append(thread_0)

                ln = chg_ent_fingerprint[key_prefix+'topoly_linking_number']
                cross = []
                for i in range(len(chg_ent_fingerprint[key_prefix+'crossing_resid'])):
                    cross.append([])
                    for j in range(len(chg_ent_fingerprint[key_prefix+'crossing_resid'][i])):
                        cross[-1].append(chg_ent_fingerprint[key_prefix+'crossing_pattern'][i][j]+str(chg_ent_fingerprint[key_prefix+'crossing_resid'][i][j]))
                repres += '# idx: native contact %s, linking number %s, crossings %s.\n'%(str(nc), str(ln), str(cross))
                repres +='''mol representation NewCartoon 0.350000 10.000000 4.100000 0
mol color ColorID '''+str(loop_colorid)+'''
mol selection {'''+loop_sel+'''}
mol material Opaque
mol addrep top
mol representation VDW 1.000000 12.000000
mol color ColorID '''+str(nc_colorid)+'''
mol selection {'''+nc_sel+'''}
mol material Opaque
mol addrep top
set sel [atomselect top "'''+nc_sel+'''"]
set idx [$sel get index]
topo addbond [lindex $idx 0] [lindex $idx 1]
mol representation Bonds 0.300000 12.000000
mol color ColorID '''+str(nc_colorid)+'''
mol selection {'''+nc_sel+'''}
mol material Opaque
mol addrep top
'''
                for ter_idx, thread_resid in enumerate(thread):
                    if len(thread_resid) == 0:
                        continue
                    repres += '''mol representation NewCartoon 0.350000 10.000000 4.100000 0
mol color ColorID '''+str(thread_colorid)+'''
mol selection {'''+thread_sel_list[ter_idx]+'''}
mol material Opaque
mol addrep top
'''
                    if len(chg_ent_fingerprint[key_prefix+'crossing_resid'][ter_idx]) > 0:
                        idx_list = []
                        for res in pmd_clean_struct.residues:
                            if res.idx in chg_ent_fingerprint[key_prefix+'crossing_resid'][ter_idx]:
                                idx_list += [atm.idx for atm in res.atoms if atm.name == 'CA']
                        crossing_sel = idx2sel(clean_idx_to_idx[idx_list])
                        repres += '''mol representation VDW 1.000000 12.000000
mol color ColorID '''+str(crossing_colorid)+'''
mol selection {'''+crossing_sel+'''}
mol material Opaque
mol addrep top
'''

            if struct_idx == 0:
                # reference structure only: show non-protein, non-water atoms
                repres += '''mol representation VDW 1.000000 12.000000
mol color Name
mol selection {not ('''+vmd_sel+''') and not water}
mol material Opaque
mol addrep top
'''
            repres_list[struct_idx] = repres
            align_sel_list[struct_idx] = align_sel

        vmd_script += '\n'.join(repres_list)
        # fit molecule 0 (reference) onto molecule 1 (state) using the CA atoms
        # outside loops and threads
        vmd_script += '''
set sel1 [atomselect 0 "'''+align_sel_list[0]+''' and name CA"]
set sel2 [atomselect 1 "'''+align_sel_list[1]+''' and name CA"]
set trans_mat [measure fit $sel1 $sel2]
set move_sel [atomselect 0 "all"]
$move_sel move $trans_mat
'''
        with open('vmd_s%d_n%s_c%s.tcl'%(state_id, ent_code[0], ent_code[1]), 'w') as f:
            f.write(vmd_script)
2051
+ ##########################################################################################################################################################
2052
+
2053
+ ##########################################################################################################################################################
2054
+ ##########################################################################################################################################################
2055
+
2056
+ ##########################################################################################################################################################
2057
+ class MSMNonNativeEntanglementClustering:
2058
+ """
2059
+ Build a markov state model across an ensemble of protein structures with non-native entanglements to identify metastable states
2060
+ """
2061
+
2062
+ #######################################################################################
2063
def __init__(self, OPpath:str = './', outdir:str = './', ID:str = '',
             start:int = 0, end:int = 99999999999, stride:int = 1, ITS:str = 'False', lagtime:int = 1,
             n_cluster:int = 400, kmean_stride:int = 2, n_small_states:int = 1, n_large_states:int = 10,
             dt:float = 0.015/1000, rm_traj_list:list = None, log_level:int = logging.INFO, logdir:str = None):
    """
    Store the analysis settings and create the logger for this MSM clustering run.

    Parameters:
        OPpath (str): Directory searched by load_OP for 'Q/*.Q' and 'G/*.G' order-parameter files.
        outdir (str): Directory where result files are written (created by run() if missing).
        ID (str): Base name used as prefix for every output file.
        start (int): First frame to analyze. A negative value makes load_OP slice rows
            positionally instead of filtering on the 'Frame' column.
        end (int): Last frame to analyze (inclusive). The very large default keeps every frame.
        stride (int): Frame stride. It is stored on the instance; the loading code in this class
            does not apply it.
        ITS (str): The literal string 'True' makes run() plot implied timescales and exit.
        lagtime (int): Lag time passed to build_msm by run().
        n_cluster (int): Number of k-means clusters (microstates) requested.
        kmean_stride (int): Stride handed to the k-means clustering call.
        n_small_states (int): Number of metastable states requested for every connected set
            except the first one.
        n_large_states (int): Number of metastable states requested for the first connected set.
        dt (float): Time step label handed to the Markov model as '<dt> ns'.
        rm_traj_list (list): Trajectory numbers to leave out of the analysis. None means "none".
        log_level (int): Logging level given to setup_logger.
        logdir (str): Directory for the log file; falls back to outdir when None.
    """

    # parse the parameters
    self.outdir = outdir
    self.ID = ID
    log_target = logdir if logdir is not None else outdir
    self.logger = setup_logger('MSMNonNativeEntanglementClustering', outdir=log_target, ID=ID, log_level=log_level)

    self.OPpath = OPpath
    self.logger.debug(f'OPpath: {self.OPpath}')
    self.logger.debug(f'outdir: {self.outdir}')
    self.logger.debug(f'ID: {self.ID}')

    self.ITS = ITS
    self.logger.debug(f'ITS: {ITS}')

    self.lagtime = lagtime
    self.logger.debug(f'lagtime: {lagtime}')

    self.start = start
    self.end = end
    self.stride = stride
    self.logger.debug(f'START: {self.start} | END: {self.end} | STRIDE: {self.stride}')

    self.n_cluster = n_cluster            # number of k-means clusters to group
    self.kmean_stride = kmean_stride      # stride of reading trajectory frames during k-means
    self.n_small_states = n_small_states  # metastable states for every connected set but the first
    self.n_large_states = n_large_states  # metastable states for the first connected set
    self.dt = dt                          # label used as the model time step, in ns

    # Copy the caller's list so the instance never aliases a shared default
    # or a list the caller may keep mutating.
    self.rm_traj_list = [] if rm_traj_list is None else list(rm_traj_list)
    self.logger.info(f'Trajectories to ignore: {self.rm_traj_list}')
2119
+
2120
+ #######################################################################################
2121
+
2122
+ #######################################################################################
2123
def load_OP(self,):
    """
    Load the Q and G time series of every trajectory that has both a Q file and a G file.

    Files are discovered as '<OPpath>/Q/*.Q' and '<OPpath>/G/*.G'; the trajectory number is the
    integer that follows the last 'Traj' in the file stem. Trajectories listed in
    self.rm_traj_list are dropped (entries given as strings, e.g. from the command line, are
    converted to integers before matching).

    Frames are restricted to [self.start, self.end] using the 'Frame' column, or sliced
    positionally when self.start is negative. self.stride is not applied here.

    Sets on the instance:
        cor_list (list of ndarray): one (n_frames, 2) float array per trajectory, columns Q, G.
        QFrames / GFrames (dict): trajectory number -> frame numbers kept from each file.
        cor_list_idx_2_traj (dict): index in cor_list -> trajectory number.

    Raises:
        ValueError: if the kept Q and G frame numbers differ, or if either series contains NaN.
        AssertionError: if a trajectory does not match exactly one Q and one G file.
    """
    self.logger.info(f'Loading G and Q order parameters...')
    cor_list = []
    cor_list_idx_2_traj = {}
    Qfiles = glob.glob(os.path.join(self.OPpath, 'Q/*.Q'))
    QTrajs = [int(pathlib.Path(Qf).stem.split('Traj')[-1]) for Qf in Qfiles]

    Gfiles = glob.glob(os.path.join(self.OPpath, 'G/*.G'))
    GTrajs = [int(pathlib.Path(Gf).stem.split('Traj')[-1]) for Gf in Gfiles]

    # Sort so the cor_list index -> trajectory mapping is reproducible between runs.
    shared_Trajs = sorted(set(QTrajs).intersection(GTrajs))
    self.logger.info(f'Number of Q files found: {len(Qfiles)} | Number of G files found: {len(Gfiles)}')
    self.logger.info(f'Number of shared Traj between Q and G: {len(shared_Trajs)}')

    ## remove trajectories that are in the rm_traj_list
    rm_list = self.rm_traj_list if self.rm_traj_list is not None else []
    if len(rm_list) > 0:
        self.logger.info(f'Removing trajectories: {self.rm_traj_list}')
        excluded = []
        for entry in rm_list:
            if isinstance(entry, str):
                try:
                    entry = int(entry)
                except ValueError:
                    # A non-numeric label can never equal an integer trajectory ID;
                    # keep it unchanged so it is simply never matched.
                    pass
            excluded.append(entry)
        shared_Trajs = [traj for traj in shared_Trajs if traj not in excluded]
        self.logger.info(f'Number of shared Traj after removing: {len(shared_Trajs)}')

    def _read_window(path, column):
        """Return (frame numbers, float values of `column`) for the requested frame window."""
        table = pd.read_csv(path, sep=',')
        if self.start < 0:
            # negative start: positional slice counted from the end of the file
            table = table.iloc[self.start:self.end + 1]
        else:
            table = table[(table['Frame'] >= self.start) & (table['Frame'] <= self.end)]
        return table['Frame'].values, table[column].values.astype(float)

    # loop through the shared trajectories, find the matching Q and G file
    # and load both time series into one 2D array
    QFrames = {}
    GFrames = {}
    skipped = []
    for traj in shared_Trajs:

        # get the corresponding G and Q file
        Qf = [f for f in Qfiles if f.endswith(f'Traj{traj}.Q')]
        Gf = [f for f in Gfiles if f.endswith(f'Traj{traj}.G')]
        self.logger.debug(f'Qf: {Qf}')
        self.logger.debug(f'Gf: {Gf}')

        ## Quality check to assert that only a single G and Q file were found
        assert len(Qf) == 1, f"the number of Q files {len(Qf)} should equal 1 for Traj {traj}"
        assert len(Gf) == 1, f"the number of G files {len(Gf)} should equal 1 for Traj {traj}"

        # load the G Q data and extract only the time series column
        QFrames[traj], Qdata = _read_window(Qf[0], 'total')
        GFrames[traj], Gdata = _read_window(Gf[0], 'G')

        ## Quality check that QFrames == GFrames.
        if set(QFrames[traj]) != set(GFrames[traj]):
            raise ValueError(f'The frames in Q {QFrames[traj]} do not match the frames in G {GFrames[traj]} for Traj {traj}. Please check your data files.')

        ## Quality check that the G and Q data has the same number of frames
        if Qdata.shape != Gdata.shape:
            self.logger.warning(f"WARNING: The number of frames in Q {Qdata.shape} should equal the number of frames in G {Gdata.shape} in Traj {traj}")
            skipped.append(traj)
            continue

        ## Check and ensure that Qdata or Gdata has no nan values
        if np.isnan(Qdata).any():
            raise ValueError(f'There is a NaN value in this Qdata')

        if np.isnan(Gdata).any():
            raise ValueError(f'There is a NaN value in this Gdata')

        # columns: 0 = Q, 1 = G
        cor_list_idx_2_traj[len(cor_list)] = int(traj)
        cor_list.append(np.column_stack((Qdata, Gdata)).astype(float))

    self.logger.info(f'Number of trajecotry OP coordinate loaded: {len(cor_list)}')
    if skipped:
        # These trajectories were deliberately left out above; report them instead of
        # failing a count check that could never pass once anything is skipped.
        self.logger.warning(f'Skipped {len(skipped)} trajectories with mismatched Q and G lengths: {skipped}')
    self.cor_list = cor_list
    self.QFrames = QFrames
    self.GFrames = GFrames

    self.logger.info(f'Mapping of cor_list index to trajID in file names: {cor_list_idx_2_traj}')
    self.cor_list_idx_2_traj = cor_list_idx_2_traj
2223
+ #######################################################################################
2224
+
2225
+ #######################################################################################
2226
def standardize(self,):
    """
    Standardize the order parameters with the mean and std taken over all trajectories.

    Every trajectory array d in self.cor_list is rescaled column-wise as
        Z = (d - mean) / std
    and the results are stored in self.standard_cor_list. self.data_mean and self.data_std hold
    the per-column statistics used by unstandardize().

    A column with zero spread (all frames share one value) is divided by 1 instead of 0, so it
    becomes all zeros rather than NaN/inf; self.data_std still reports the true value 0.
    """
    # one vstack over the whole list instead of re-copying the growing array per trajectory
    stacked = np.vstack(self.cor_list)
    self.data_mean = np.mean(stacked, axis=0)
    self.data_std = np.std(stacked, axis=0)
    safe_scale = np.where(self.data_std > 0, self.data_std, 1.0)
    self.standard_cor_list = [(d - self.data_mean) / safe_scale for d in self.cor_list]
2237
+ #######################################################################################
2238
+
2239
+ #######################################################################################
2240
def unstandardize(self, data):
    """
    Map standardized values back to the original order-parameter scale.

    Inverts the transform applied by standardize() using the stored statistics:
        d = Z*std + mean

    Parameters:
        data: standardized values, broadcastable against self.data_std and self.data_mean.

    Returns:
        The rescaled values.
    """
    rescaled = self.data_std * data
    return self.data_mean + rescaled
2246
+ #######################################################################################
2247
+
2248
+ #######################################################################################
2249
def cluster(self,):
    """
    Cluster the standardized GQ data of all trajectories into microstates with k-means.

    Sets on the instance:
        clusters: the clustering object returned by pem.coordinates.cluster_kmeans.
        dtrajs (list of ndarray): microstate label of every frame, one array per trajectory.
        n_cluster (int): lowered to the number of labels actually used when some clusters
            received no frames (this can happen if the data has a narrow distribution).

    When fewer than self.n_cluster labels are used, the labels are renumbered to the contiguous
    range 0..n_used-1 (order preserved). The cluster centers are only logged here and keep the
    row order of the original clustering.
    """
    self.clusters = pem.coordinates.cluster_kmeans(self.standard_cor_list, k=self.n_cluster, max_iter=5000, stride=self.kmean_stride)

    # Get the microstate tagged trajectories and their state counts
    self.dtrajs = self.clusters.dtrajs
    self.logger.debug(f'dtrajs: {len(self.dtrajs)} {self.dtrajs[0].shape}\n{self.dtrajs[0][:10]}')

    # Concatenate first: trajectories may differ in length, and a ragged list of arrays
    # cannot be converted to one array by np.unique.
    clusterIDs, counts = np.unique(np.concatenate(self.dtrajs), return_counts=True)
    self.logger.info(f'Number of unique microstate IDs: {len(clusterIDs)} {clusterIDs}')

    state_counts = dict(zip(clusterIDs, counts))
    self.logger.debug(f'state_counts: {state_counts}')

    # Quality check that all microstate ids are assigned
    # If not renumber from 0
    if len(clusterIDs) != self.n_cluster:
        self.logger.info(f'The number of microstate IDs assigned does not match the number specified: {len(clusterIDs)} != {self.n_cluster}')

        mapping_dict = {old: new for new, old in enumerate(clusterIDs)}
        self.logger.debug(f'mapping_dict: {mapping_dict}')

        # Lookup table old label -> new label for vectorized relabeling
        mapping_array = np.zeros(max(mapping_dict.keys()) + 1, dtype=int)
        for old_label, new_label in mapping_dict.items():
            mapping_array[old_label] = new_label

        # Map the arrays using the mapping array
        self.dtrajs = [mapping_array[arr] for arr in self.dtrajs]

        clusterIDs, counts = np.unique(np.concatenate(self.dtrajs), return_counts=True)
        self.logger.info(f'Number of unique microstate IDs after mapping: {len(clusterIDs)} {clusterIDs}')
        state_counts = dict(zip(clusterIDs, counts))
        self.logger.debug(f'state_counts: {state_counts}')

        self.n_cluster = len(clusterIDs)

    standard_centers = self.clusters.clustercenters
    unstandard_centers = self.unstandardize(standard_centers)
    self.logger.info(f'unstandard_centers:\n{unstandard_centers} {unstandard_centers.shape}')
    self.logger.info(f'self.n_cluster: {self.n_cluster}')
2303
+
2304
+ #######################################################################################
2305
+
2306
+ #######################################################################################
2307
def build_msm(self, lagtime=1):
    """
    Build Markov state models on the microstate trajectories and coarse-grain them.

    Steps:
      1. Count transitions at `lagtime` and split the microstates into connected sets.
      2. Estimate a reversible transition matrix and an MSM for every connected set that holds
         more than one microstate (single-microstate sets get no model).
      3. Coarse-grain each model with PCCA: the first connected set asks for
         self.n_large_states metastable states, all others for self.n_small_states. The number
         is lowered until no metastable set is empty. A set reduced to one state uses the
         observed frame counts of its microstates as its distribution.
      4. Map every frame to its metastable state, mark sampled representative frames and write
         the results.

    Parameters:
        lagtime (int): lag time, in frames, used for the count matrix.

    Writes into self.outdir (all prefixed with self.ID):
        _meta_dist.npy, _meta_set.csv, _MSMmapping.csv and, via plot_state_map_and_FE,
        _StateAndFEplot.png.

    Raises:
        ValueError: if a microstate ends up in more than one metastable state.
    """
    self.logger.info(f'Building MSM model with a lag time of {lagtime}')

    # Get count matrix and connective groups of microstates
    c_matrix = deeptime.markov.tools.estimation.count_matrix(self.dtrajs, lagtime).toarray()
    self.logger.info(f'c_matrix:\n{c_matrix} {c_matrix.shape}')

    sub_groups = deeptime.markov.tools.estimation.connected_sets(c_matrix)
    self.logger.info(f'Total number of sub_groups: {len(sub_groups)}\n{sub_groups}')

    # Build the MSM models for any connected sets that have more than 1 microstate
    msm_list = []
    for sg in sub_groups:
        cm = deeptime.markov.tools.estimation.largest_connected_submatrix(c_matrix, lcc=sg)
        self.logger.info(f'For sub_group: {sg}')
        if len(cm) == 1:
            msm = None
        else:
            self.logger.info(f'Building Transition matrix and MSM model')
            T = deeptime.markov.tools.estimation.transition_matrix(cm, reversible=True)
            msm = pem.msm.markov_model(T, dt_model=str(self.dt)+' ns')
        msm_list.append(msm)
    self.logger.info(f'Number of models: {len(msm_list)}')

    # Coarse grain out the metastable macrostates in the models
    self.logger.info(f'Coarse grain out the metastable macrostates in the models')
    meta_dist = []
    meta_set = []
    eigenvalues_list = []
    for idx_msm, msm in enumerate(msm_list):
        group = sub_groups[idx_msm]

        # the first model should contain the largest connected state so use the largest number of metastable states
        # for every other subgroup use the smallest
        n_states = self.n_large_states if idx_msm == 0 else self.n_small_states

        if msm is None:
            # isolated microstate: it forms a metastable state on its own
            eigenvalues_list.append(None)
            dist = np.zeros(self.n_cluster)
            dist[group[0]] = 1.0
            meta_dist.append(dist)
            meta_set.append(group)
            continue

        eigenvalues_list.append(msm.eigenvalues())

        # coarse-graining: lower n_states until PCCA yields no empty metastable set
        while n_states > 1:
            msm.pcca(n_states)
            if not any(ms.size == 0 for ms in msm.metastable_sets):
                break
            n_states -= 1
            self.logger.info('Reduced number of states to %d for active group %d'%(n_states, idx_msm+1))

        if n_states == 1:
            # use observation prob distribution for non-active set
            dist = np.zeros(self.n_cluster)
            for nas in group:
                for dtraj in self.dtrajs:
                    dist[nas] += np.count_nonzero(dtraj == nas)
            dist /= np.sum(dist)
            meta_dist.append(dist)
            meta_set.append(group)
        else:
            for i, md in enumerate(msm.metastable_distributions):
                members = msm.metastable_sets[i]
                dist = np.zeros(self.n_cluster)
                norm = np.sum(md[members])
                global_members = []
                for idx in members:
                    # translate the model-local state index to the global microstate ID
                    iidx = group[idx]
                    dist[iidx] = md[idx]
                    global_members.append(iidx)
                meta_dist.append(dist / norm)
                meta_set.append(global_members)

    meta_dist = np.array(meta_dist)
    self.logger.debug(f'meta_dist: {len(meta_dist)} {meta_dist.shape}')
    meta_dist_outfile = os.path.join(self.outdir, f'{self.ID}_meta_dist.npy')
    np.save(meta_dist_outfile, meta_dist, allow_pickle=True)
    self.logger.info(f'SAVED: {meta_dist_outfile}')

    # one row per (metastable state, microstate) pair
    meta_set_df = {'metastable_state':[], 'microstates':[]}
    for i, ms in enumerate(meta_set):
        for m in ms:
            meta_set_df['metastable_state'].append(i)
            meta_set_df['microstates'].append(m)

    meta_set_df = pd.DataFrame(meta_set_df)
    self.logger.info(f'Meta set DataFrame:\n{meta_set_df}')
    meta_set_outfile = os.path.join(self.outdir, f'{self.ID}_meta_set.csv')
    meta_set_df.to_csv(meta_set_outfile, index=False)
    self.logger.info(f'SAVED: {meta_set_outfile}')

    ## make microstate to metastable state mapping object
    self.logger.info(f'\nMetastable state assignment')
    meta_mapping = {}
    for metaID, microstates in enumerate(meta_set):
        for m in microstates:
            if m in meta_mapping:
                raise ValueError(f'Microstate {m} already in a metastable state!')
            meta_mapping[m] = metaID
    self.logger.debug(f'meta_mapping: {meta_mapping} {len(meta_mapping)}')

    # map those microstate states to the metastable state
    metastable_dtraj = [np.asarray([meta_mapping[d] for d in dtraj]) for dtraj in self.dtrajs]

    self.logger.info(f'Metastable state mapping:')
    for dtraj_idx, dtraj in enumerate(metastable_dtraj):
        self.logger.debug(f'dtraj_idx={dtraj_idx} dtrajs[:10]={self.dtrajs[dtraj_idx][:10]} dtraj[:10]={dtraj[:10]} shape={dtraj.shape}')

    ## get samples of metastable states by most populated microstates
    self.logger.debug(f'num_dtrajs={len(self.dtrajs)} dtrajs[0].shape={self.dtrajs[0].shape}')
    cluster_indexes = deeptime.markov.sample.compute_index_states(self.dtrajs)
    self.logger.debug(f'cluster_indexes: {len(cluster_indexes)}')

    samples = deeptime.markov.sample.indices_by_distribution(cluster_indexes, meta_dist, 5)
    self.logger.debug(f'samples: {samples} {len(samples)}')

    # Build the (trajectory index, frame index) lookup once per metastable state instead of
    # converting the sample array to a list for every single frame.
    sampled_frames = [{tuple(pair) for pair in s.tolist()} for s in samples]

    ## Make the output dataframe that has assignments for each frame of each traj
    df = {'traj':[], 'frame':[], 'microstate':[], 'metastablestate':[], 'Q':[], 'G':[], 'StateSample':[]}
    self.logger.info(f'Active & inactive metastable state mapping')
    for k, v in enumerate(metastable_dtraj):
        traj = self.cor_list_idx_2_traj[k]
        for frame, macrostate in enumerate(v):
            df['traj'].append(traj)
            # report the frame number from the input file, not the position in the array
            df['frame'].append(self.QFrames[traj][frame])
            df['microstate'].append(self.dtrajs[k][frame])
            df['metastablestate'].append(macrostate)
            df['Q'].append(self.cor_list[k][frame, 0])
            df['G'].append(self.cor_list[k][frame, 1])
            df['StateSample'].append((k, frame) in sampled_frames[macrostate])

    df = pd.DataFrame(df)
    self.logger.info(f'Final MSM mapping DF:\n{df}')
    df_outfile = os.path.join(self.outdir, f'{self.ID}_MSMmapping.csv')
    df.to_csv(df_outfile, index=False)
    self.logger.info(f'SAVED: {df_outfile}')

    # Plot the metastable state membership and free energy surface
    xall = np.hstack([dtraj[:, 0] for dtraj in self.cor_list])
    yall = np.hstack([dtraj[:, 1] for dtraj in self.cor_list])
    states = np.hstack(metastable_dtraj)
    self.logger.debug(f'xall: {xall} {xall.shape}')
    self.logger.debug(f'yall: {yall} {yall.shape}')
    self.logger.debug(f'states: {states} {states.shape}')

    stateplot_outfile = os.path.join(self.outdir, f'{self.ID}_StateAndFEplot.png')
    self.plot_state_map_and_FE(xall, yall, states, stateplot_outfile)
2487
+ #######################################################################################
2488
+
2489
+ #######################################################################################
2490
def plot_state_map_and_FE(self, x, y, states, outfile, cmap='viridis', point_size=50, alpha=0.85, title='State Map'):
    """
    Save a two-panel figure: free-energy surface (left) and metastable state map (right).

    Parameters:
        x (array-like): The x-coordinates of the points (plotted on the axis labelled 'Q').
        y (array-like): The y-coordinates of the points (plotted on the axis labelled 'G').
        states (array-like): The state assignment for each point.
        outfile (str): Path of the image file to write.
        cmap: Accepted for backward compatibility; not used. The free-energy panel always uses
            'gist_ncar' and the state map uses a discrete 'tab20' palette.
        point_size (int): Size of the scatter plot points (default is 50).
        alpha (float): Transparency of the scatter plot points (default is 0.85).
        title (str): Accepted for backward compatibility; not used.

    The free energy is -log10 of the probability from a 20x20 2D histogram; empty bins are
    set to the largest finite free energy so the contour plot has no holes.
    """
    # Create a figure and subplots with 1 row and 2 columns
    fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)

    ### plot FE surface on left plot
    num_bins = 20
    hist, xedges, yedges = np.histogram2d(x, y, bins=num_bins, density=True)

    # Calculate the probability as the histogram values
    probability = hist / np.sum(hist)

    # Compute the free energy as -log10(probability)
    with np.errstate(divide='ignore'):  # empty bins give log10(0); handled just below
        free_energy = -np.log10(probability)
    free_energy[~np.isfinite(free_energy)] = np.nanmax(free_energy[np.isfinite(free_energy)])
    self.logger.debug(f'free_energy: {free_energy} {free_energy.shape} {np.unique(free_energy)}')

    # Grid for the contour plot, built from the lower edge of every bin
    X, Y = np.meshgrid(xedges[:-1], yedges[:-1])
    self.logger.debug(f'X: {X.shape}\nY: {Y.shape}')

    fe_cmap = plt.cm.gist_ncar

    # Transpose so histogram axis 0 (x) runs along the horizontal plot axis
    contour = axes[0].contourf(X, Y, free_energy.T, levels=100, cmap=fe_cmap)
    fig.colorbar(contour, ax=axes[0], label='Free Energy (-log10 Probability)')
    axes[0].set_xlabel('Q')
    axes[0].set_ylabel('G')
    axes[0].set_title('2D Free Energy Contour Plot')
    axes[0].set_xlim(0,1)

    ## Plot state map
    # Step 1: Identify unique states
    unique_states = np.unique(states)
    n_states = len(unique_states)
    self.logger.debug(f'unique_states: {unique_states} {n_states}')

    # Step 2: one color per unique state (pyplot.get_cmap works across Matplotlib versions)
    palette = plt.get_cmap('tab20', n_states)
    state_cmap = ListedColormap([palette(i) for i in range(n_states)])

    # Step 3: Map states to color indices 0..n_states-1
    state_to_index = {state: i for i, state in enumerate(unique_states)}
    color_indices = np.vectorize(state_to_index.get)(states)

    # Step 4: scatter plot. Limits of -0.5 .. n_states-0.5 give every index a unit-wide color
    # band centred on the integer, so the colorbar ticks below line up with the colors.
    scatter = axes[1].scatter(x, y, c=color_indices, cmap=state_cmap, s=point_size, alpha=alpha,
                              edgecolor='k', vmin=-0.5, vmax=n_states - 0.5)

    # Step 5: Add a colorbar labelled with the unique state values
    cbar = fig.colorbar(scatter, ax=axes[1], ticks=np.arange(n_states), label=f'Metastable States')
    cbar.ax.set_yticklabels(unique_states)

    axes[1].set_xlabel('Q')
    axes[1].set_ylabel('G')
    axes[1].set_title('2D state map')
    axes[1].set_xlim(0,1)

    fig.savefig(outfile)
    self.logger.info(f'SAVED: {outfile}')
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
2579
+
2580
+ #######################################################################################
2581
+
2582
+ #######################################################################################
2583
def plot_implied_timescales(self,):
    """
    Plot the implied timescales of the microstate trajectories and save the figure.

    Should be done before building the model, to find a lag time beyond which the timescales
    (from the eigenvalues of the transition matrix) no longer depend on the lag time.
    Look for the point or range where the implied timescales stop changing significantly with
    increasing lag time; that lag time is generally a good choice for building the MSM.

    Lag times 1, 11, ..., 91 are evaluated with Bayesian error estimates. The figure is written
    to '<outdir>/<ID>_ITS.png'.
    """
    lag_times = np.arange(1, 100, 10)  # adjust the range based on your system
    its = pem.msm.its(self.dtrajs, lags=lag_times, errors='bayes')
    pem.plots.plot_implied_timescales(its)
    ITS_outfile = os.path.join(self.outdir, f'{self.ID}_ITS.png')
    plt.savefig(ITS_outfile)
    self.logger.info(f'SAVED: {ITS_outfile}')
2597
+ #######################################################################################
2598
+
2599
+ #######################################################################################
2600
def run(self, ):
    """
    Run the full workflow: load the order parameters, standardize, cluster, build the MSM.

    When self.ITS is the string 'True', only the implied-timescales plot is produced and the
    process is then terminated (SystemExit) so a lag time can be chosen from the figure before
    the model is built.

    Raises:
        SystemExit: after the implied-timescales plot when self.ITS == 'True'.
    """
    ## make output folder
    if not os.path.exists(self.outdir):
        # exist_ok guards against another process creating the folder in between
        os.makedirs(self.outdir, exist_ok=True)
        self.logger.info(f'Made directory: {self.outdir}')

    # load the G and Q data
    self.load_OP()

    # apply the standard scalar transformation to the data
    self.standardize()

    # cluster the standardized data using kmeans clustering
    self.cluster()

    # generate the implied timescales plot to check for a suitable lag time
    # Should be done first before building the model to choose a suitable lag time
    if self.ITS == 'True':
        self.plot_implied_timescales()
        self.logger.info(f'Analysis terminated since ITS was selected. Check the figure and choose an approrate lagtime')
        # Same effect as the interactive quit() helper, without depending on the site module.
        raise SystemExit(0)

    # Build the MSM model with the chosen lagtime
    self.build_msm(lagtime=self.lagtime)
2625
+ #######################################################################################
2626
+