EntDetect 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. EntDetect/Jwalk/GridTools.py +567 -0
  2. EntDetect/Jwalk/PDBTools.py +532 -0
  3. EntDetect/Jwalk/SASDTools.py +543 -0
  4. EntDetect/Jwalk/SurfaceTools.py +150 -0
  5. EntDetect/Jwalk/__init__.py +19 -0
  6. EntDetect/Jwalk/naccess.config.txt +255 -0
  7. EntDetect/__init__.py +10 -0
  8. EntDetect/_logging.py +71 -0
  9. EntDetect/change_resolution.py +2361 -0
  10. EntDetect/clustering.py +2626 -0
  11. EntDetect/compare_sim2exp.py +1927 -0
  12. EntDetect/entanglement_features.py +478 -0
  13. EntDetect/gaussian_entanglement.py +2067 -0
  14. EntDetect/order_params.py +1048 -0
  15. EntDetect/resources/__init__.py +11 -0
  16. EntDetect/resources/__pycache__/__init__.cpython-311.pyc +0 -0
  17. EntDetect/resources/calc_K.pl +712 -0
  18. EntDetect/resources/calc_Q.pl +962 -0
  19. EntDetect/resources/pulchra +0 -0
  20. EntDetect/resources/shared_files/__init__.py +2 -0
  21. EntDetect/resources/shared_files/bt_contact_potential.dat +22 -0
  22. EntDetect/resources/shared_files/karanicolas_dihe_parm.dat +1600 -0
  23. EntDetect/resources/shared_files/kgs_contact_potential.dat +22 -0
  24. EntDetect/resources/shared_files/mj_contact_potential.dat +22 -0
  25. EntDetect/resources/stride +0 -0
  26. EntDetect/statistics.py +1344 -0
  27. EntDetect/utilities.py +201 -0
  28. entdetect-1.2.0.dist-info/METADATA +26 -0
  29. entdetect-1.2.0.dist-info/RECORD +45 -0
  30. entdetect-1.2.0.dist-info/WHEEL +5 -0
  31. entdetect-1.2.0.dist-info/entry_points.txt +11 -0
  32. entdetect-1.2.0.dist-info/licenses/LICENSE +674 -0
  33. entdetect-1.2.0.dist-info/top_level.txt +2 -0
  34. scripts/__init__.py +5 -0
  35. scripts/convert_cor_psf_to_pdb.py +103 -0
  36. scripts/run_Foldingpathway.py +162 -0
  37. scripts/run_MSM.py +152 -0
  38. scripts/run_OP_on_simulation_traj.py +194 -0
  39. scripts/run_change_resolution.py +63 -0
  40. scripts/run_compare_sim2exp.py +215 -0
  41. scripts/run_montecarlo.py +158 -0
  42. scripts/run_nativeNCLE.py +179 -0
  43. scripts/run_nonnative_entanglement_clustering.py +110 -0
  44. scripts/run_population_modeling.py +117 -0
  45. scripts/run_workflow4_nativeNCLE_batch.py +412 -0
@@ -0,0 +1,2067 @@
1
+ ###################################################################################################
2
+ ## Multiprocessing helper function(s)
3
def process_frame(args):
    """
    Compute Gaussian entanglements for a single trajectory frame.

    Worker helper: all inputs arrive packed in one tuple, presumably so the
    function can be mapped over a multiprocessing pool.

    Parameters
    ----------
    args : tuple
        (frame_idx, dcd, PSF, chain, chain_res, resids, atom_names, topoly, ID,
         Calpha, CG, g_threshold, density, ent_detection_method)

    Returns
    -------
    list of tuple
        One 13-field row per entangled contact:
        (ID, chain, frame_idx, i, j, crossingsN, crossingsC, gN, gC,
         GLNn, GLNc, TLNn, TLNc).
        When no entanglement is found, a single placeholder row of the same
        width is returned, with empty strings after ``frame_idx``.
    """
    (frame_idx, dcd, PSF, chain, chain_res, resids, atom_names, topoly, ID,
     Calpha, CG, g_threshold, density, ent_detection_method) = args
    import MDAnalysis as mda
    from EntDetect.gaussian_entanglement import GaussianEntanglement

    univ = mda.Universe(PSF, dcd)
    chain_atoms = univ.select_atoms(f"segid {chain}")
    # Make frame_idx the current frame so .positions reflects it.
    univ.trajectory[frame_idx]
    coor = chain_atoms.positions

    ge = GaussianEntanglement(g_threshold=g_threshold, density=density, Calpha=Calpha, CG=CG, ent_detection_method=ent_detection_method)
    ent_result = ge.get_traj_entanglements(coor, chain_res, resids, atom_names, topoly=topoly)

    result_rows = []
    if not ent_result:
        # Placeholder row padded to the same 13-field width as populated rows
        # (3 identifier fields + 10 empty data fields) so rows stay rectangular.
        result_rows.append((ID, chain, frame_idx) + ('',) * 10)
        n_contacts = 0
    else:
        for ij_gN_gC, (crossingsN, crossingsC) in ent_result.items():
            i, j = ij_gN_gC[0], ij_gN_gC[1]
            result_rows.append((ID, chain, frame_idx, i, j,
                                ','.join(crossingsN), ','.join(crossingsC),
                                f'{ij_gN_gC[2]: .5f}', f'{ij_gN_gC[3]: .5f}',
                                ij_gN_gC[4], ij_gN_gC[5], ij_gN_gC[6], ij_gN_gC[7]))
        n_contacts = len(result_rows)
    # Report the real number of contacts (the placeholder row is not a contact).
    logging.getLogger('EntDetect.GaussianEntanglement').info(f'Frame idx: {frame_idx} processed with {n_contacts} contact(s) found.')
    return result_rows
28
+ ###################################################################################################
29
+
30
+ ###################################################################################################
31
+ ## load in necessary packages
32
+ import logging
33
+ import itertools
34
+ import math
35
+ import sys
36
+ from Bio.PDB import PDBParser
37
+ import os
38
+ from operator import itemgetter
39
+ from warnings import filterwarnings
40
+ import numpy as np
41
+ import MDAnalysis as mda
42
+ import pandas as pd
43
+ from scipy.spatial.distance import pdist, squareform
44
+ from topoly import lasso_type # used pip
45
+ import re
46
+ import pickle
47
+ import json
48
+ from collections import defaultdict
49
+ import time
50
+ import multiprocessing as mp
51
+ from EntDetect._logging import setup_logger
52
+ from typing import Optional
53
+ filterwarnings("ignore")
54
+
55
+ class GaussianEntanglement:
56
+ """
57
+ Gaussian Entanglement Class for calculating entanglements in protein structures derived from both experiments and AlphaFold.
58
+ """
59
def __init__(self, g_threshold: float = 0.6, density: float = 0.0, Calpha: bool = False, CG: bool = False, nproc: int = 10, ent_detection_method: int = 2, log_level: int = logging.INFO, logdir: Optional[str] = None) -> None:
    """
    Constructor for GaussianEntanglement class.

    Parameters
    ----------
    g_threshold : float, optional
        Fractional cut-off used by ``point_rounding`` when rounding Gaussian
        linking numbers to integers, by default 0.6
    density : float, optional
        Density for triangulation of minimal loop surface; passed to Topoly's
        ``lasso_type``, by default 0.0
    Calpha : bool, optional
        Whether to use C-alpha atoms (8 Å contact cut-off) or heavy atoms
        (4.5 Å cut-off), by default False
    CG : bool, optional
        Whether the CG model was used to generate the simulations or structures
    nproc : int, optional
        Stored as ``self.nproc``; presumably the number of worker processes
        for parallel calculations (not used in the methods visible here),
        by default 10
    ent_detection_method : int, optional
        Method to define ENT status of a raw NCLE:
            1 = any nonzero GLN for either termini
            2 = any nonzero TLN for either termini (default)
            3 = both GLN and TLN must have nonzero for same termini
    log_level : int, optional
        Logging level passed to ``setup_logger``, by default logging.INFO
    logdir : str or None, optional
        Output directory passed to ``setup_logger`` as ``outdir``,
        by default None
    """

    self.logger = setup_logger('GaussianEntanglement', outdir=logdir, log_level=log_level)
    self.g_threshold = g_threshold
    self.density = density
    self.Calpha = Calpha
    self.CG = CG
    self.nproc = nproc
    self.ent_detection_method = ent_detection_method
    self.logger.debug(f'GaussianEntanglement initialized with g_threshold: {g_threshold}, density: {density}, Calpha: {Calpha}, CG: {CG}, nproc: {nproc}, ent_detection_method: {ent_detection_method}')

    # Human-readable descriptions of change codes:
    # L = linking number (- loss, + gain, # unchanged),
    # C = linking chirality (~ switched, # unchanged).
    self.change_codes = {'L-C~': 'loss of linking number & switched linking chirality',
                         'L-C#': 'loss of linking number & no change of linking chirality',
                         'L+C~': 'gain of linking number & switched linking chirality',
                         'L+C#': 'gain of linking number & no change of linking chirality',
                         'L#C~': 'no change of linking number & switched linking chirality',
                         'L#C#': 'no change'}
96
+ ##########################################################################################################################################################
97
+
98
+ ##########################################################################################################################################################
99
def determine_ent_status(self, gln_n: float, gln_c: float, tln_n: int, tln_c: int) -> bool:
    """
    Decide whether a native contact counts as entangled.

    The rule is selected by ``self.ent_detection_method``:

    * 1 -- either terminus has a nonzero GLN;
    * 2 -- either terminus has a nonzero TLN (default);
    * 3 -- at least one terminus has both a nonzero GLN and a nonzero TLN.

    Parameters
    ----------
    gln_n, gln_c : float
        Gaussian linking numbers for the N- and C-terminal segments.
    tln_n, tln_c : int
        Topological linking numbers for the N- and C-terminal segments.

    Returns
    -------
    bool
        Entanglement status under the selected rule.

    Raises
    ------
    ValueError
        If ``self.ent_detection_method`` is not 1, 2 or 3.
    """
    method = self.ent_detection_method
    if method not in (1, 2, 3):
        raise ValueError(f"Invalid ent_detection_method: {self.ent_detection_method}. Must be 1, 2, or 3.")

    def _nonzero(value):
        return value != 0

    if method == 1:
        return _nonzero(gln_n) or _nonzero(gln_c)
    if method == 2:
        return _nonzero(tln_n) or _nonzero(tln_c)

    # method 3: the same terminus must carry both a GLN and a TLN signal
    n_terminus = _nonzero(gln_n) and _nonzero(tln_n)
    c_terminus = _nonzero(gln_c) and _nonzero(tln_c)
    return n_terminus or c_terminus
133
+ ##########################################################################################################################################################
134
+
135
+ ##########################################################################################################################################################
136
def helper_dot(self, Runit: np.ndarray, dR_cross: np.ndarray) -> list:
    """
    Row-wise dot products of two stacks of vectors.

    This is plain NumPy (no Numba/GPU). The per-row products are computed in
    one vectorized ``np.einsum`` call instead of a Python-level loop of
    ``np.dot`` calls, which matters because callers pass (l-1)**2 rows.

    Parameters
    ----------
    Runit : np.ndarray
        Array of shape (n, d).
    dR_cross : np.ndarray
        Array of shape (m, d).

    Returns
    -------
    list
        ``[dot(Runit[k], dR_cross[k]) for k in range(min(n, m))]`` as NumPy
        scalars (float32 in, float32 out). As with ``zip``, the shorter input
        determines the length.
    """
    a = np.asarray(Runit)
    b = np.asarray(dR_cross)
    n = min(len(a), len(b))  # mirror zip()'s truncation semantics
    return list(np.einsum('ij,ij->i', a[:n], b[:n]))
145
+ ##########################################################################################################################################################
146
+
147
+ ##########################################################################################################################################################
148
def point_rounding(self, num: float) -> float:
    """
    Round ``num`` to an integer using ``self.g_threshold`` as the cut-off.

    - NaN is returned unchanged.
    - num >= 0: round up if the fractional part is >= g_threshold,
      otherwise round down.
    - num < 0: round away from zero (more negative) if |fractional part| is
      >= g_threshold, otherwise round toward zero.

    Parameters
    ----------
    num : float
        Value to round (Python float or NumPy floating scalar).

    Returns
    -------
    int or float
        The rounded integer, or ``num`` itself when it is NaN.
    """
    if isinstance(num, np.floating):
        # NumPy scalars such as np.float32 are not subclasses of float, so
        # they would skip the NaN check below; handle them explicitly.
        if np.isnan(num):
            return num
        num = float(num)
    elif isinstance(num, float) and math.isnan(num):
        return num

    int_part = math.trunc(num)  # truncate toward zero
    frac = num - int_part
    if num >= 0:
        return int_part + (1 if frac >= self.g_threshold else 0)
    return int_part - (1 if abs(frac) >= self.g_threshold else 0)
166
+ ##########################################################################################################################################################
167
+
168
+ ##########################################################################################################################################################
169
def get_entanglements(self, coor: np.ndarray, l: int, pdb_file: str, resids: np.ndarray,
                      resnames: np.ndarray, resid_index_to_ref_allatoms_idx: dict, ca_coor: np.ndarray, resid_index_to_resid: dict,
                      termini_threshold: "list | tuple" = (5, 5), loop_thread_thresh: "list | tuple" = (4, 4), topoly: bool = True) -> tuple:
    """
    Find non-covalent lasso entanglements in one chain.

    Entanglements are composed of loops (closed by native contacts) and the
    crossing residue(s) that thread through them.

    Parameters
    ----------
    coor : np.ndarray
        Coordinates of the selected atoms used to build the contact map,
        one row per atom (presumably (n_atoms, 3) -- TODO confirm).
    l : int
        Number of rows of ``ca_coor`` used for the Gaussian linking number.
    pdb_file : str
        Structure name; only used in the residue-count assertion message.
    resids : np.ndarray
        Residue numbers, indexed by residue index.
    resnames : np.ndarray
        Not used here; kept for interface compatibility.
    resid_index_to_ref_allatoms_idx : dict
        Residue index -> list of row indices into ``coor`` for that residue.
    ca_coor : np.ndarray
        C-alpha (or CG bead) coordinates.
    resid_index_to_resid : dict
        Residue index -> residue number.
    termini_threshold : sequence of two ints
        (N, C) residues excluded at the termini when summing GLNs.
    loop_thread_thresh : sequence of two ints
        (N, C) minimum gap between the loop and the threading segment.
    topoly : bool
        If True, Topoly locates crossings and TLNs are computed; otherwise
        TLNs are NaN and crossing arrays are empty.

    Returns
    -------
    tuple
        (entanglements, missing_residues). ``entanglements`` maps
        (resid_i, resid_j, gN, gC, GLNn, GLNc, TLNn, TLNc) ->
        [N-terminal crossings, C-terminal crossings].
    """
    Nterm_thresh, Cterm_thresh = termini_threshold[0], termini_threshold[1]
    loop_Nthread_thresh, loop_Cthread_thresh = loop_thread_thresh[0], loop_thread_thresh[1]

    # ---- atom-level contact map ------------------------------------------
    dist_matrix = squareform(pdist(coor))
    cutoff = 8 if self.Calpha else 4.5
    # Keep only entries at least 4 atoms above the diagonal.
    native_cmap = np.triu(np.where(dist_matrix <= cutoff, 1, 0), k=4)

    num_res = len(resid_index_to_ref_allatoms_idx.keys())
    assert num_res == len(resids), f"Something's wrong with {pdb_file} residues {num_res} != {len(resids)}"

    # ---- collapse to a residue-level contact map -------------------------
    # Each residue's atom rows are treated as the contiguous span
    # [min, max]; compute it once instead of once per residue pair.
    atom_spans = {idx: (min(atoms), max(atoms) + 1)
                  for idx, atoms in resid_index_to_ref_allatoms_idx.items()}
    res_ncmap = np.zeros((num_res, num_res))
    for a in range(num_res):
        for b in range(num_res):
            # residues must be more than 4 apart in residue numbering
            if abs(resid_index_to_resid[b] - resid_index_to_resid[a]) <= 4:
                continue
            if a not in atom_spans or b not in atom_spans:
                continue
            a0, a1 = atom_spans[a]
            b0, b1 = atom_spans[b]
            if native_cmap[a0:a1, b0:b1].any():
                res_ncmap[a, b] = 1
    del native_cmap, dist_matrix
    native_cmap = res_ncmap

    nc_indexs = np.stack(np.nonzero(native_cmap)).T  # (n_contacts, 2) residue-index pairs

    # ---- Gaussian linking-number integrand over segment pairs ------------
    range_l = np.arange(0, l - 1)
    range_next_l = np.arange(1, l)

    ca_coor = ca_coor.astype(np.float32)
    R = 0.5 * (ca_coor[range_l] + ca_coor[range_next_l])  # segment midpoints
    dR = ca_coor[range_next_l] - ca_coor[range_l]         # segment vectors
    n_seg = len(dR)

    # All ordered segment pairs (a, b), a-major -- the same row order that
    # itertools.product(dR, dR) would give, without building Python tuples.
    dR_cross = np.cross(np.repeat(dR, n_seg, axis=0), np.tile(dR, (n_seg, 1)))

    diff = (R[:, None, :] - R[None, :, :]).reshape(-1, R.shape[1]).astype(np.float32)
    # Coincident midpoints (a == b) give 0/0 -> NaN; those are dropped below.
    with np.errstate(divide='ignore', invalid='ignore'):
        Runit = diff / np.linalg.norm(diff, axis=1)[:, None] ** 3
    Runit = Runit.astype(np.float32)

    dot_matrix = np.asarray(self.helper_dot(Runit, dR_cross)).reshape((l - 1, l - 1))

    def block_linking_sum(rows, cols):
        # Sum dot_matrix over the rows x cols block (row-major, i.e. the order
        # itertools.product(rows, cols) would give), ignoring NaNs, / (4*pi).
        if rows.size == 0 or cols.size == 0:
            return 0
        vals = dot_matrix[np.ix_(rows, cols)].ravel()
        vals = vals[~np.isnan(vals)]
        return np.sum(vals) / (4.0 * np.pi)

    nc_gdict = {}
    for i, j in nc_indexs:
        loop_range = np.arange(i, j)
        nterm_range = np.arange(Nterm_thresh, i - loop_Nthread_thresh - 1)
        cterm_range = np.arange(j + loop_Cthread_thresh + 1, l - (Cterm_thresh + 1))

        gn_val = block_linking_sum(nterm_range, loop_range)
        gc_val = block_linking_sum(loop_range, cterm_range)

        rounded_gc_val = self.point_rounding(np.float64(gc_val))
        rounded_gn_val = self.point_rounding(np.float64(gn_val))
        nc_gdict[(int(i), int(j))] = (gn_val, gc_val, rounded_gn_val, rounded_gc_val)

    missing_residues = self.find_missing_residues(resids)
    # Loops spanning structural gaps are set to None by loop_filter.
    filtered_nc_gdict = self.loop_filter(nc_gdict, resids, missing_residues)

    if topoly:
        entangled_res = self.find_crossing(ca_coor.tolist(), filtered_nc_gdict, resids)
        filtered_entangled_res = self.crossing_filter(entangled_res, missing_residues)
    else:
        filtered_entangled_res = {}
        for (i, j), values in filtered_nc_gdict.items():
            if values is None:
                # rejected by loop_filter; the topoly path skips these as well
                continue
            gn, gc, GLNn, GLNc = values[0], values[1], values[2], values[3]
            TLNn = np.nan
            TLNc = np.nan
            filtered_entangled_res[(resids[i], resids[j], gn, gc, GLNn, GLNc, TLNn, TLNc)] = [np.array([]), np.array([])]

    return filtered_entangled_res, missing_residues
323
+ ##########################################################################################################################################################
324
+
325
+ ##########################################################################################################################################################
326
def find_missing_residues(self, resids: np.ndarray) -> np.ndarray:
    """
    Return residue numbers absent from ``resids`` within its numbering span.

    The span runs from the smallest to the largest residue number present,
    so the result does not depend on ``resids`` being sorted.

    Parameters
    ----------
    resids : np.ndarray
        Residue numbers present in the structure (any array-like of ints).

    Returns
    -------
    np.ndarray
        Sorted missing residue numbers; empty if there are no gaps or no
        residues at all.
    """
    resids = np.asarray(resids)
    if resids.size == 0:
        return np.array([], dtype=int)
    full_span = np.arange(resids.min(), resids.max() + 1)
    return np.setdiff1d(full_span, resids)
338
+ ##########################################################################################################################################################
339
+
340
+ ##########################################################################################################################################################
341
def loop_filter(self, native_contacts: dict, resids: np.ndarray, missing_res: np.ndarray) -> dict:
    """
    Invalidate loops that span gaps in the structure.

    A loop (i, j) is rejected -- its value set to None in place -- when the
    residue-number range resids[i]..resids[j] contains three or more
    consecutive missing residues, or when the number of missing residues in
    that range exceeds 5% of its length. Entries that are already None are
    left untouched.

    Parameters
    ----------
    native_contacts : dict
        (i, j) residue-index pair -> contact data (or None).
    resids : np.ndarray
        Residue numbers, indexed by residue index.
    missing_res : np.ndarray
        Missing residue numbers.

    Returns
    -------
    dict
        The same ``native_contacts`` object, modified in place.
    """
    for ij, values in native_contacts.items():
        if values is None:
            continue  # already rejected

        loop = np.arange(resids[ij[0]], resids[ij[1]] + 1)
        missing_in_loop = np.intersect1d(loop, missing_res)  # sorted, unique
        if missing_in_loop.size == 0:
            continue

        # Within a run of consecutive numbers, (position - value) is constant,
        # so groupby on that key yields the runs of consecutive missing residues.
        longest_run = max(
            len(list(run))
            for _, run in itertools.groupby(enumerate(missing_in_loop), key=lambda ix: ix[0] - ix[1])
        )
        if longest_run >= 3 or missing_in_loop.size > 0.05 * loop.size:
            # Re-assigning an existing key is safe while iterating items().
            native_contacts[ij] = None

    return native_contacts
374
+ ##########################################################################################################################################################
375
+
376
+
377
+ ##########################################################################################################################################################
378
def find_crossing(self, coor: np.ndarray, nc_data: dict, resids: np.ndarray, terminal_buff:int=5, loop_buff:int=4) -> dict:

    """
    Use Topoly to find crossing(s) based on partial linking number.

    For every loop (i, j) in ``nc_data`` whose value is not None, Topoly's
    ``lasso_type`` is run. The signed threading positions it reports
    (``beforeN`` / ``beforeC``) are then processed per terminus:
      1. converted from indices to residue numbers;
      2. filtered by ``terminal_buff``;
      3. thinned to crossings at least 10 residues apart;
      4. summed by sign into a topological linking number (TLN).

    Parameters
    ----------
    coor : np.ndarray
        Backbone coordinates handed to ``lasso_type``
        (``get_entanglements`` passes ``ca_coor.tolist()``).
    nc_data : dict
        (i, j) residue-index pair -> (gN, gC, GLNn, GLNc), or None for loops
        rejected by ``loop_filter``.
    resids : np.ndarray
        Residue numbers, indexed by residue index.
    terminal_buff : int
        Passed to ``lasso_type`` in ``min_dist``. It is also the Python-side
        buffer from both the termini and the loop ends when filtering crossings.
    loop_buff : int
        Passed to ``lasso_type`` as the loop component of ``min_dist`` only.

    Returns
    -------
    dict
        (resid_i, resid_j, gN, gC, GLNn, GLNc, TLNn, TLNc) ->
        [np.unique(N crossings), np.unique(C crossings)]. Each crossing is a
        signed residue-number string such as '+42' or '-17'. np.unique sorts
        these strings lexicographically, not by position.
    """

    entangled_res = {}

    # Only loops that survived loop_filter (non-None values) are analysed.
    native_contacts = [[ij[0], ij[1]] for ij, values in nc_data.items() if values is not None]

    data = lasso_type(coor, loop_indices=native_contacts, more_info=True, density=self.density, min_dist=[10, loop_buff, terminal_buff])
    for native_contact in native_contacts:

        crossings = []  # NOTE(review): accumulated below but never used

        native_contact = tuple(native_contact)
        i, j = resids[native_contact[0]], resids[native_contact[1]]

        ## Parse the N terminal crossings
        crossingN = []
        # Map each Topoly index to a residue number, keeping its sign char.
        # Entries containing '*' are skipped (see Topoly docs for their meaning -- TODO confirm).
        beforeN = [f"{cr[0]}{resids[int(cr[1:])]}" for cr in data[native_contact]["beforeN"] if '*' not in cr]
        if beforeN:
            # Sort by residue number, descending (nearest to the loop first).
            # NOTE(review): sign and residue number share one int here, so
            # residue numbers <= 0 would be misread.
            Ncrossing_resids = sorted([int(cr) for cr in beforeN], key=lambda x: abs(x), reverse=True)
            filtered_crossingN_resids = []
            for c in Ncrossing_resids:
                # Keep crossings more than terminal_buff before the loop start.
                # NOTE(review): the N-terminal lower bound is an absolute
                # residue number (assumes numbering starts near 1), whereas
                # the C-terminal check below is relative to resids[-1].
                if abs(c) < (i - terminal_buff) and abs(c) > terminal_buff:
                    filtered_crossingN_resids.append(c)

            if len(filtered_crossingN_resids) > 1:
                # Greedy thinning: walking away from the loop, keep a crossing
                # only if it is >= 10 residues from the last kept one. The
                # crossing nearest the loop is always kept; later crossings
                # too close to it are dropped.
                results = [filtered_crossingN_resids[0]]
                for k in range(1, len(filtered_crossingN_resids)):
                    if abs(abs(filtered_crossingN_resids[k]) - abs(results[-1])) >= 10:
                        results.append(filtered_crossingN_resids[k])
                filtered_crossingN_resids = results

            # Re-encode as signed strings, e.g. '+42' / '-17'.
            crossingN = []
            for c in filtered_crossingN_resids:
                if c > 0:
                    crossingN.append(f'+{c}')
                else:
                    crossingN.append(f'{c}')
            crossings += crossingN


        ## Parse the C terminal crossings
        crossingC = []
        beforeC = [f"{cr[0]}{resids[int(cr[1:])]}" for cr in data[native_contact]["beforeC"] if '*' not in cr]
        if beforeC:
            # Sort by residue number, ascending (nearest to the loop first).
            Ccrossing_resids = sorted([int(cr) for cr in beforeC], key=lambda x: abs(x))
            filtered_crossingC_resids = []
            for c in Ccrossing_resids:
                # Keep crossings more than terminal_buff after the loop end
                # and more than terminal_buff before the last residue.
                if abs(c) > (j + terminal_buff) and abs(c) < (resids[-1] - terminal_buff):
                    filtered_crossingC_resids.append(c)

            if len(filtered_crossingC_resids) > 1:
                # Same greedy >= 10-residue thinning as on the N side.
                results = [filtered_crossingC_resids[0]]
                for k in range(1, len(filtered_crossingC_resids)):
                    if abs(abs(filtered_crossingC_resids[k]) - abs(results[-1])) >= 10:
                        results.append(filtered_crossingC_resids[k])
                filtered_crossingC_resids = results

            crossingC = []
            for c in filtered_crossingC_resids:
                if c > 0:
                    crossingC.append(f'+{c}')
                else:
                    crossingC.append(f'{c}')

            crossings += crossingC

        # nc_data values are (gN, gC, GLNn, GLNc).
        gn = nc_data[native_contact][0]
        GLNn = nc_data[native_contact][2]

        gc = nc_data[native_contact][1]
        GLNc = nc_data[native_contact][3]

        # TLN = (# '+' crossings) - (# '-' crossings) on each side. The
        # position tests (< i, > j) repeat the buffer filters above and are
        # redundant when terminal_buff >= 0.
        if crossingN:
            TLNn_signs = [c[0] for c in crossingN if int(c[1:]) < i]
            TLNn = [1 for s in TLNn_signs if s == '+'] + [-1 for s in TLNn_signs if s == '-']
            TLNn = sum(TLNn)
        else:
            TLNn = 0

        if crossingC:
            TLNc_signs = [c[0] for c in crossingC if int(c[1:]) > j]
            TLNc = [1 for s in TLNc_signs if s == '+'] + [-1 for s in TLNc_signs if s == '-']
            TLNc = sum(TLNc)
        else:
            TLNc = 0

        ij_gN_gC = (resids[native_contact[0]], resids[native_contact[1]]) + (gn, gc) + (GLNn, GLNc) + (TLNn, TLNc)

        entangled_res[ij_gN_gC] = [np.unique(crossingN), np.unique(crossingC)]

    return entangled_res
494
+ ##########################################################################################################################################################
495
+
496
+
497
+ ##########################################################################################################################################################
498
def crossing_filter(self, entanglements: dict, missing_res: np.ndarray) -> dict:
    """
    Drop crossings that lie within 10 residues of a structural gap.

    Each crossing is a signed residue-number string (e.g. '+42', '-17').
    A crossing is discarded when any residue in ``missing_res`` falls
    within [resid - 10, resid + 10]. Every entanglement key is kept, even if
    all of its crossings are removed.

    Parameters
    ----------
    entanglements : dict
        Entanglement key -> [N-terminal crossings, C-terminal crossings].
    missing_res : np.ndarray
        Missing residue numbers.

    Returns
    -------
    dict
        New dict with the same keys, each mapped to
        [np.array(kept N crossings), np.array(kept C crossings)].
    """
    def _near_gap(crossing):
        # Strip the leading chirality sign to get the residue number.
        resid = int(re.split("\\+|-", crossing, maxsplit=1)[1])
        window = np.arange(resid - 10, resid + 11)
        return np.intersect1d(window, missing_res).size > 0

    kept = {}
    for key, (n_side, c_side) in entanglements.items():
        kept[key] = [np.array([cr for cr in side if not _near_gap(cr)])
                     for side in (n_side, c_side)]
    return kept
544
+ ##########################################################################################################################################################
545
+
546
+
547
+ ##########################################################################################################################################################
548
def check_disulfideBonds(self, pdb_file):
    """
    Detect disulfide bonds from cysteine SG-SG distances.

    Two CYS residues whose SG atoms are closer than 2.2 Å are reported as
    bonded. Only residues within the same model are compared, so a
    multi-model file (e.g. NMR) never pairs a cysteine with its own copy in
    another model.

    Parameters
    ----------
    pdb_file : str
        Path to a structure file readable by ``Bio.PDB.PDBParser``.

    Returns
    -------
    list of tuple
        (resnum_i, resnum_j) pairs, each unordered pair listed once, in order
        of first detection. Chain IDs are not recorded.
    """
    parser = PDBParser()
    structure = parser.get_structure('protein', pdb_file)

    disulfide_bonds = []
    for model in structure:
        # Collect (residue, SG atom) once per model for every CYS that has SG.
        cys_sg = [(residue, residue['SG'])
                  for chain in model
                  for residue in chain
                  if residue.get_resname() == 'CYS' and 'SG' in residue]

        # Each unordered pair is examined exactly once.
        for (res_a, sg_a), (res_b, sg_b) in itertools.combinations(cys_sg, 2):
            distance = sg_a - sg_b  # Bio.PDB Atom subtraction gives the distance
            if distance < 2.2:
                self.logger.info(f"Disulfide bond between {res_a} and {res_b} at distance {distance:.2f} Å")
                i, j = res_a.get_id()[1], res_b.get_id()[1]
                if (i, j) not in disulfide_bonds and (j, i) not in disulfide_bonds:
                    disulfide_bonds.append((i, j))
    return disulfide_bonds
581
+ ##########################################################################################################################################################
582
+
583
+
584
+ ##########################################################################################################################################################
585
+ def calculate_native_entanglements(self, pdb_file: str, outdir: str, ID: str='', chain: str=None, topoly: bool = True) -> dict:
586
+
587
+ """
588
+ Driver function that outputs native lasso-like self entanglements and missing residues for pdb and all of its chains if any
589
+ """
590
+
591
+ ## set up the outdir for this calculation
592
+ #outdir = f"{os.getcwd()}/{outdir}"
593
+ if not os.path.isdir(outdir):
594
+ os.mkdir(f"{outdir}")
595
+ self.logger.info(f"Creating directory: {outdir}")
596
+
597
+ if not os.path.isdir(f"{outdir}/unmapped_missing_residues"):
598
+ os.mkdir(f"{outdir}/unmapped_missing_residues")
599
+ self.logger.info(f"Creating directory: {outdir}/unmapped_missing_residues")
600
+
601
+
602
+ ## get the PDB file name
603
+ pdb = pdb_file.split('/')[-1].split(".")[0]
604
+ if ID == '':
605
+ ID = pdb
606
+ self.logger.info(f"\n{'#'*100}\nCOMPUTING ENTANGLEMENTS FOR \033[4m{pdb}\033[0m with ID {ID}")
607
+
608
+ ## Define the outfile and check if it exists. If so load it else create it
609
+ outfile = os.path.join(f'{outdir}', f'{ID}_GE.csv')
610
+ if os.path.exists(outfile):
611
+ self.logger.info(f'{outfile} ALREADY EXISTS AND WILL BE LOADED')
612
+ outdf = pd.read_csv(outfile, sep='|', dtype={'c': str})
613
+ return {'outfile':outfile, 'ent_result':outdf}
614
+
615
+ ## Load the reference universe and use the MDA parser for it with the special excapetion for our CG files that end in .cor
616
+ if pdb_file.endswith('.cor'):
617
+ ref_univ = mda.Universe(f'{pdb_file}', format='CRD')
618
+ else:
619
+ ref_univ = mda.Universe(f"{pdb_file}")
620
+ self.logger.debug(f'ref_univ: {ref_univ}')
621
+
622
+
623
+ ### Find any disulfide bonds
624
+ self.logger.info(f'Checking for disulfide bonds')
625
+ disulfide_bonds = self.check_disulfideBonds(pdb_file)
626
+ self.disulfide_bonds = disulfide_bonds
627
+ self.logger.debug(f'disulfide_bonds: {disulfide_bonds}')
628
+
629
+
630
+ ## Get only the heavy atoms or CA atoms depending on what type of contact we are looking for
631
+ if self.Calpha == False and self.CG == False:
632
+ self.logger.info('All-atom model and contacts: Selecting all heavy atoms (no hydrogens) for entanglement calculations')
633
+ ref_allatoms_dups = ref_univ.select_atoms("not name H* and protein")
634
+ elif self.Calpha == True and self.CG == False:
635
+ self.logger.info('All-atom model and Calpha contacts: Selecting only CA atoms for entanglement calculations')
636
+ ref_allatoms_dups = ref_univ.select_atoms("name CA and protein")
637
+ elif self.CG == True:
638
+ self.logger.info('Coarse-grained model: Selecting all atoms for entanglement calculations')
639
+ ref_allatoms_dups = ref_univ.select_atoms("all")
640
+ #print(f'ref_allatoms_dups: {ref_allatoms_dups} {len(ref_allatoms_dups)}')
641
+
642
+ chains_to_analyze = set(ref_univ.segments.segids)
643
+ if chain is not None:
644
+ chains_to_analyze = {chain} if chain in chains_to_analyze else set()
645
+ if not chains_to_analyze:
646
+ raise ValueError(f"Chain {chain} not found in structure. Available chains: {set(ref_univ.segments.segids)}")
647
+
648
+ for chain in chains_to_analyze:
649
+
650
+ ## Check for duplicate residues
651
+ atom_data = []
652
+ check = set()
653
+
654
+ for atom in ref_allatoms_dups.select_atoms(f"segid {chain}").atoms:
655
+
656
+ atom_data.append((atom.resid, atom.name))
657
+ check.add((atom.resid, atom.name))
658
+
659
+ temp_df = pd.DataFrame(atom_data, columns=["resid", "name"])
660
+
661
+ unique_rows = temp_df.drop_duplicates()
662
+ unique_indices = unique_rows.index.tolist()
663
+
664
+ assert len(check) == len(unique_indices), "You did not remove dup atoms!"
665
+
666
+ ref_allatoms_unique = ref_allatoms_dups.select_atoms(f"segid {chain}")[unique_indices]
667
+ #print(f'ref_allatoms_unique: {ref_allatoms_unique} {len(ref_allatoms_unique)}')
668
+
669
+ ## select only those unique residue alpha carbons
670
+ if self.CG == False:
671
+ ref_ca_unique = ref_allatoms_unique.select_atoms("name CA")
672
+ else:
673
+ ref_ca_unique = ref_allatoms_unique.select_atoms("all")
674
+ #print(f'ref_ca_unique: {ref_ca_unique} {len(ref_ca_unique)}')
675
+
676
+
677
+ resid_index_to_ref_allatoms_idx = {}
678
+ resid_index_to_resid = {}
679
+ ref_allatoms_idx_to_resid = {}
680
+ atom_ix = 0
681
+ res_ix = 0
682
+ PDB_resids = ref_ca_unique.resids
683
+ #print(f'PDB_resids: {PDB_resids}')
684
+ new_atm_idx = []
685
+
686
+ ## QC if the chain has only one alpha carbon or none
687
+ if len(PDB_resids) == 0 or len(PDB_resids) == 1:
688
+ raise ValueError(f"Skipping over chain {chain} for \033[4m{pdb}\033[0m since chain has only one alpha carbon or none")
689
+
690
+
691
+ for atom in ref_allatoms_unique.atoms:
692
+ new_atm_idx.append(atom_ix)
693
+ ref_allatoms_idx_to_resid[atom_ix] = [atom.resid]
694
+
695
+ if atom_ix == new_atm_idx[0]:
696
+ resid = atom.resid
697
+ resid_index_to_ref_allatoms_idx[res_ix] = [atom_ix]
698
+ resid_index_to_resid[res_ix] = resid
699
+ atom_ix += 1
700
+ else:
701
+ if atom.resid == resid:
702
+ resid_index_to_ref_allatoms_idx[res_ix] += [atom_ix]
703
+ resid_index_to_resid[res_ix] = resid
704
+ resid = atom.resid
705
+ atom_ix += 1
706
+ else:
707
+ res_ix += 1
708
+ resid_index_to_ref_allatoms_idx[res_ix] = [atom_ix]
709
+ resid_index_to_resid[res_ix] = resid
710
+ atom_ix += 1
711
+ resid = atom.resid
712
+
713
+ ## Quality check that if Calpha is True there is 1-to-1 mapping of resid index to allatom indexs
714
+ if self.Calpha == True:
715
+ for k,v in resid_index_to_ref_allatoms_idx.items():
716
+ if len(v) != 1:
717
+ raise ValueError(f'When Calpha is specified there should only be one resid index for each all atom index: resid index {k} has {v}')
718
+
719
+ assert len(new_atm_idx) == np.concatenate(list(resid_index_to_ref_allatoms_idx.values())).size, f"Not enough atom indicies! {pdb_file}"
720
+
721
+ # x y z cooridnates of chain
722
+ coor = ref_allatoms_unique.atoms.positions[new_atm_idx]
723
+
724
+ for resid_idx, all_atom_idx in resid_index_to_ref_allatoms_idx.items():
725
+
726
+ resid = PDB_resids[resid_idx]
727
+ #print(resid_idx, resid, all_atom_idx)
728
+
729
+ check_coor = coor[all_atom_idx]
730
+ #print(f'check_coor:\n{check_coor}')
731
+
732
+ structure_coor = ref_allatoms_unique.select_atoms(f"resid {resid}").positions
733
+ #print(f'structure_coor:\n{structure_coor}')
734
+
735
+ try:
736
+ np.all(check_coor == structure_coor)
737
+ except:
738
+ raise ValueError(f'Error in checking residue coordinates: most likely caused by resides with letters after them or specifying CG=True when it is infact an allatom model. Check resid: {resid} in PDB')
739
+
740
+ if not np.all(check_coor == structure_coor):
741
+ raise ValueError(f"Coordinates do not match up! Resid {resid} {pdb_file}")
742
+
743
+
744
+ ca_coor = ref_ca_unique.positions
745
+
746
+ resnames = ref_ca_unique.resnames
747
+
748
+ chain_res = PDB_resids.size
749
+
750
+ if PDB_resids.size:
751
+
752
+ ent_result, missing_residues = self.get_entanglements(coor, chain_res, pdb_file, PDB_resids, resnames, resid_index_to_ref_allatoms_idx, ca_coor, resid_index_to_resid, topoly=topoly)
753
+ # print(f'Number of ENTs found: {len(ent_result)}\n{ent_result}')
754
+
755
+
756
+ ## If there is Native entanglement then save a file
757
+ if ent_result:
758
+ outfile = os.path.join(f'{outdir}', f'{ID}_GE.csv')
759
+ #print(f'WRITING: {outfile}')
760
+ outdf = {'ID':[], 'chain':[], 'i':[], 'j': [], 'crossingsN': [], 'crossingsC': [], 'gn':[], 'gc':[], 'GLNn':[], 'GLNc':[], 'TLNn':[], 'TLNc':[], 'CCbond':[]}
761
+
762
+ for ij_gN_gC, crossings in ent_result.items():
763
+ # print(f'ij_gN_gC: {ij_gN_gC} with crossings: {crossings}')
764
+ crossingsN, crossingsC = crossings
765
+ crossingsN = ','.join(crossingsN)
766
+ crossingsC = ','.join(crossingsC)
767
+ i, j, gn, gc, GLNn, GLNc, TLNn, TLNc = ij_gN_gC
768
+
769
+ if (i,j) in disulfide_bonds or (j,i) in disulfide_bonds:
770
+ CCbond = True
771
+ else:
772
+ CCbond = False
773
+
774
+ # print(f'Contact: ({i}, {j}) with GLNn: {GLNn} and GLNc: {GLNc} has crossings: {crossingsN} {crossingsC} and CCbond: {CCbond}')
775
+
776
+ outdf['ID'] += [ID]
777
+ outdf['chain'] += [chain]
778
+ outdf['i'] += [i]
779
+ outdf['j'] += [j]
780
+ outdf['crossingsN'] += [crossingsN]
781
+ outdf['crossingsC'] += [crossingsC]
782
+ outdf['gn'] += [f'{gn: .5f}']
783
+ outdf['gc'] += [f'{gc: .5f}']
784
+ outdf['GLNn'] += [GLNn]
785
+ outdf['GLNc'] += [GLNc]
786
+ outdf['TLNn'] += [TLNn]
787
+ outdf['TLNc'] += [TLNc]
788
+ outdf['CCbond'] += [CCbond]
789
+
790
+ outdf = pd.DataFrame(outdf)
791
+ outdf['ENT'] = outdf.apply(lambda row: self.determine_ent_status(row['GLNn'], row['GLNc'], row['TLNn'], row['TLNc']), axis=1)
792
+ outdf.to_csv(outfile, sep='|', index=False)
793
+ self.logger.info(f'SAVED: {outfile}')
794
+ ent_result = pd.read_csv(outfile, sep='|', dtype={'crossingsN': str, 'crossingsC': str})
795
+ else:
796
+ self.logger.info(f'NO CONTACTS DETECTED for {pdb}')
797
+ ent_result = pd.DataFrame({'ID':[], 'chain':[], 'i':[], 'j': [], 'crossingsN': [], 'crossingsC': [], 'gn':[], 'gc':[], 'GLNn':[], 'GLNc':[], 'TLNn':[], 'TLNc':[], 'CCbond':[], 'ENT':[]})
798
+ ent_result.to_csv(outfile, sep='|', index=False)
799
+ self.logger.info(f'SAVED: {outfile}')
800
+ ent_result = pd.read_csv(outfile, sep='|', dtype={'c': str})
801
+
802
+ if len(missing_residues):
803
+ self.logger.info(f'WRITING: {pdb}_M.txt')
804
+ with open(f"{outdir}/unmapped_missing_residues/{pdb}_M.txt", "a") as f:
805
+ f.write(f"Chain {chain}: ")
806
+ for m_res in missing_residues:
807
+ f.write(f"{m_res} ")
808
+ f.write("\n")
809
+
810
+ ## Return a dictionary with the outfile and the results
811
+ return {'outfile':outfile, 'ent_result':ent_result}
812
+
813
+
814
+ ##########################################################################################################################################################
815
def calculate_traj_entanglements(
    self,
    dcd: str,
    PSF: str,
    outdir: str = './',
    ID: str = '',
    topoly: bool = True,
    start: int = 0,
    stop: int = 999999999,
    stride: int = 1,
    ref_contact_file: Optional[str] = None,
) -> dict:
    """
    Driver function that takes a CA coarse grained MD trajectory and looks for entanglements.

    Every segment (chain) of the PSF/DCD universe is analyzed over the frames
    ``univ.trajectory[start:stop:stride]``. Each (chain, frame) task is handed to the
    module-level ``process_frame`` worker through a single spawn-based multiprocessing
    pool with ``self.nproc`` processes.

    Args:
        dcd: Path to the trajectory file.
        PSF: Path to the topology file.
        outdir: Output directory; created (including parents) if it does not exist.
        ID: Output file prefix. Defaults to the DCD basename up to its first '.'.
        topoly: Forwarded to ``process_frame``.
        start, stop, stride: Frame slice applied to the trajectory.
        ref_contact_file: Optional '|'-separated file with ``chain``, ``i`` and ``j``
            columns. When given, only rows whose (chain, unordered i/j pair) appears
            in that file are kept. A failure while filtering is logged as a warning
            and leaves the results unfiltered.

    Returns:
        ``{'outfile': <path to {ID}_GE.csv>, 'ent_result': <DataFrame>}``. If the
        outfile already exists it is loaded (and optionally re-filtered) instead of
        being recomputed.
    """
    import multiprocessing as mp

    def _contact_keys(df):
        """Build 'chain:lo-hi' keys that do not depend on the i/j order of a contact."""
        i = df['i'].astype(int)
        j = df['j'].astype(int)
        lo = np.minimum(i, j).astype(str)
        hi = np.maximum(i, j).astype(str)
        # pandas string concatenation works on every numpy version (numpy<2 has no
        # `+` ufunc loop for fixed-width unicode arrays).
        return (df['chain'].astype(str) + ':' + lo + '-' + hi).reset_index(drop=True)

    def _filter_to_ref_contacts(df):
        """Keep only the rows of df whose contact key also appears in ref_contact_file."""
        ref_df = pd.read_csv(ref_contact_file, sep='|', usecols=['chain', 'i', 'j'])
        ref_keys = set(_contact_keys(ref_df))
        mask = _contact_keys(df).isin(ref_keys).to_numpy()
        return df.loc[mask].reset_index(drop=True)

    ## set up the outdir for this calculation
    if not os.path.isdir(outdir):
        os.makedirs(outdir, exist_ok=True)
        self.logger.info(f"Creating directory: {outdir}")

    ## get the DCD file name (basename up to the first '.')
    dcd_name = os.path.basename(dcd).split(".")[0]
    if ID == '':
        ID = dcd_name
    self.logger.info(f"\n{'#'*100}\nCOMPUTING ENTANGLEMENTS FOR \033[4m{dcd_name}\033[0m with ID {ID}")
    self.logger.debug(f'Topoly: {topoly} {type(topoly)}')

    ## Define the outfile and check if it exists. If so load it else create it
    outfile = os.path.join(f'{outdir}', f'{ID}_GE.csv')
    if os.path.exists(outfile):
        self.logger.info(f'{outfile} ALREADY EXISTS AND WILL BE LOADED')
        # Read crossings as strings: a lone crossing such as '+12' would otherwise become an int.
        outdf = pd.read_csv(outfile, sep='|', dtype={'crossingsN': str, 'crossingsC': str})
        if ref_contact_file is not None:
            try:
                filtered = _filter_to_ref_contacts(outdf)
                if len(filtered) != len(outdf):
                    self.logger.info(
                        f'Filtered existing Traj_GE output to reference contacts: '
                        f'{len(outdf)} -> {len(filtered)} rows (ref: {ref_contact_file})'
                    )
                    filtered.to_csv(outfile, sep='|', index=False)
                outdf = filtered
            except Exception as e:
                self.logger.warning(f'WARNING: Failed to filter Traj_GE output using ref_contact_file={ref_contact_file}: {e}')

        return {'outfile': outfile, 'ent_result': outdf}

    ## Else analyze the traj and create the outfile
    univ = mda.Universe(PSF, dcd)
    self.logger.debug(f'univ: {univ}')

    # The frame selection does not depend on the chain, so resolve it once.
    frame_indices = [ts.frame for ts in univ.trajectory[start:stop:stride]]
    self.logger.info(f'Analyzing frames from {start} to {stop} with stride {stride}')
    self.logger.info(f'Total frames to analyze: {len(frame_indices)}')
    self.logger.debug(f'frame_indices: {frame_indices[:10]} ... {frame_indices[-10:]}')

    # sorted() makes the chain order (and therefore the row order) reproducible across runs.
    chains_to_analyze = sorted(set(univ.segments.segids))

    pool_args = []
    for chain in chains_to_analyze:
        self.logger.info(f'Analyzing chain {chain}')

        # Get the atoms of the chain
        chain_atoms = univ.select_atoms(f"segid {chain}")

        resids = chain_atoms.resids
        self.logger.debug(f'resids: {resids}')

        resnames = chain_atoms.resnames
        self.logger.debug(f'resnames: {resnames}')

        chain_res = resids.size
        self.logger.debug(f'chain_res: {chain_res}')

        atom_names = chain_atoms.names
        self.logger.debug(f'atom_names: {atom_names}')

        pool_args.extend(
            (frame_idx, dcd, PSF, chain, chain_res, resids, atom_names, topoly, ID,
             self.Calpha, self.CG, self.g_threshold, self.density, self.ent_detection_method)
            for frame_idx in frame_indices
        )

    # One pool for all chains; pool.map preserves task order (chain by chain, frame by frame).
    rows = []
    if pool_args:
        with mp.get_context("spawn").Pool(processes=self.nproc) as pool:
            for frame_rows in pool.map(process_frame, pool_args):
                rows.extend(frame_rows)

    columns = ['ID', 'chain', 'frame', 'i', 'j', 'crossingsN', 'crossingsC', 'gn', 'gc', 'GLNn', 'GLNc', 'TLNn', 'TLNc']
    outdf = pd.DataFrame(rows, columns=columns)
    outdf['frame'] = pd.to_numeric(outdf['frame'], errors='coerce')
    # A stable sort keeps the chain/contact order within each frame deterministic.
    outdf = outdf.sort_values(by='frame', ascending=True, kind='stable').reset_index(drop=True)

    # A comprehension (unlike DataFrame.apply(axis=1)) also works when no rows were produced.
    outdf['ENT'] = [
        self.determine_ent_status(GLNn, GLNc, TLNn, TLNc)
        for GLNn, GLNc, TLNn, TLNc in zip(outdf['GLNn'], outdf['GLNc'], outdf['TLNn'], outdf['TLNc'])
    ]

    if ref_contact_file is not None:
        try:
            before = len(outdf)
            outdf = _filter_to_ref_contacts(outdf)
            self.logger.info(f'Filtered Traj_GE to reference contacts: {before} -> {len(outdf)} rows (ref: {ref_contact_file})')
        except Exception as e:
            self.logger.warning(f'WARNING: Failed to filter Traj_GE output using ref_contact_file={ref_contact_file}: {e}')

    self.logger.info(f'outdf:\n{outdf.head()}\n{outdf.tail()}')
    outdf.to_csv(outfile, sep='|', index=False)
    self.logger.info(f'SAVED: {outfile}')

    # Multiprocessing disables per-frame timing, so frame_times / mean_time are not reported.
    return {'outfile': outfile, 'ent_result': outdf}
980
+ ##########################################################################################################################################################
981
+
982
+ ##########################################################################################################################################################
983
+
984
+ ##########################################################################################################################################################
985
def get_traj_entanglements(self, coor: np.ndarray, l: int, resids: np.ndarray, atom_names: np.ndarray, topoly: bool = True, dist_cutoff: float = 8.0, termini_threshold: tuple = (5, 5), loop_thread_thresh: tuple = (4, 4)) -> dict:
    """
    Find proteins containing non-covalent lasso entanglements.

    Entanglements are composed of loops (defined by native contacts) and crossing residue(s).

    For every native contact (i, j) reported by ``self.processes_coor``, the partial Gaussian
    linking values between the loop i..j and the N-terminal tail (gn) and the C-terminal
    tail (gc) are computed. These are sums of ``self.helper_dot(Runit, dR_cross)`` over the
    corresponding bond-segment pairs, divided by 4*pi; NaN terms are discarded.

    Args:
        coor: Coordinates handed to ``self.processes_coor``.
        l: Ignored; the chain length is taken from the backbone returned by processes_coor.
        resids: Residue ids handed to processes_coor / find_crossing and used as output keys.
        atom_names: Atom names handed to processes_coor.
        topoly: If True, crossings are resolved by ``self.find_crossing``. Otherwise
            TLN values are NaN and the crossing lists are empty.
        dist_cutoff: Currently unused (processes_coor applies its own cutoffs).
        termini_threshold: (N, C) residues excluded at each chain terminus.
        loop_thread_thresh: (N, C) residues excluded between the loop and each tail.

    Returns:
        With topoly, whatever ``self.find_crossing`` returns. Without topoly, a dict
        ``{(resid_i, resid_j, gn, gc, GLNn, GLNc, nan, nan): [[], []]}``. An empty dict
        is returned when there are no native contacts.

    Raises:
        ValueError: if topoly is False but ent_detection_method (2 or 3) requires TLN.
    """
    ## Check that if topoly is False then ent_detection_method is not 2 or 3 since those require TLN which requires topoly
    if not topoly and self.ent_detection_method in (2, 3):
        self.logger.warning(f'topoly=False but ent_detection_method={self.ent_detection_method} requires TLN — no NCLEs will be detected. Use ent_detection_method=1 (GLN) when topoly is disabled.')
        raise ValueError(f'topoly=False but ent_detection_method={self.ent_detection_method} requires TLN — no NCLEs will be detected. Use ent_detection_method=1 (GLN) when topoly is disabled.')

    Nterm_thresh, Cterm_thresh = termini_threshold[0], termini_threshold[1]
    loop_Nthread_thresh, loop_Cthread_thresh = loop_thread_thresh[0], loop_thread_thresh[1]
    self.logger.debug(f'Finding entanglements with Nterm_thresh: {Nterm_thresh}, Cterm_thresh: {Cterm_thresh}, loop_Nthread_thresh: {loop_Nthread_thresh}, loop_Cthread_thresh: {loop_Cthread_thresh}')

    # make native contact contact map
    native_cmap, bb_coor, _dist_matrix = self.processes_coor(coor, resids, atom_names, CG=self.CG, Calpha=self.Calpha)
    l = len(bb_coor)

    nc_indexs = np.stack(np.nonzero(native_cmap)).T  # (n_contacts, 2) index pairs
    # Every native contact gets an nc_gdict entry below, so no contacts means no entanglements.
    # Returning early also avoids building pair matrices for chains too short to hold a contact.
    if len(nc_indexs) == 0:
        return {}

    # Bond-segment midpoints R and bond vectors dR (length l-1)
    bb_coor = np.asarray(bb_coor).astype(np.float32)
    R = 0.5 * (bb_coor[:-1] + bb_coor[1:])
    dR = bb_coor[1:] - bb_coor[:-1]
    n_seg = l - 1

    # All ordered segment pairs (a, b), flattened row-major (a outer, b inner). This is the
    # ordering itertools.product(dR, dR) produced, without materialising Python tuples.
    dR_cross = np.cross(np.repeat(dR, n_seg, axis=0), np.tile(dR, (n_seg, 1)))
    diff = (np.repeat(R, n_seg, axis=0) - np.tile(R, (n_seg, 1))).astype(np.float32)
    with np.errstate(divide='ignore', invalid='ignore'):
        # The a == b diagonal is 0/0 -> NaN; those terms are dropped when summing.
        Runit = (diff / np.linalg.norm(diff, axis=1)[:, None] ** 3).astype(np.float32)

    # make final dot matrix: dot_matrix[a, b] pairs segment a with segment b
    dot_matrix = np.asarray(self.helper_dot(Runit, dR_cross)).reshape((n_seg, n_seg))

    def _partial_linking(rows, cols):
        """NaN-free sum of dot_matrix over rows x cols divided by 4*pi (0 for an empty product)."""
        if rows.size == 0 or cols.size == 0:
            return 0
        vals = dot_matrix[np.ix_(rows, cols)].ravel()
        vals = vals[~np.isnan(vals)]
        return np.sum(vals) / (4.0 * np.pi)

    nc_gdict = {}
    for i, j in nc_indexs:
        loop_range = np.arange(i, j)
        nterm_range = np.arange(Nterm_thresh, i - loop_Nthread_thresh - 1)
        cterm_range = np.arange(j + loop_Cthread_thresh + 1, l - (Cterm_thresh + 1))

        gn_val = _partial_linking(nterm_range, loop_range)
        gc_val = _partial_linking(loop_range, cterm_range)

        rounded_gc_val = self.point_rounding(np.float64(gc_val))
        rounded_gn_val = self.point_rounding(np.float64(gn_val))
        nc_gdict[(int(i), int(j))] = (gn_val, gc_val, rounded_gn_val, rounded_gc_val)

    ## resolve crossings with topoly when requested
    if topoly:
        return self.find_crossing(bb_coor.tolist(), nc_gdict, resids)

    # NOTE(review): resids[i] uses a backbone (bb_coor) index into `resids`. That is a residue
    # lookup only when resids has one entry per backbone bead (e.g. CG input) -- confirm for
    # all-atom input, where processes_coor keeps only the CA subset as the backbone.
    entangled_res = {}
    for (i, j), (gn, gc, GLNn, GLNc) in nc_gdict.items():
        entangled_res[(resids[i], resids[j], gn, gc, GLNn, GLNc, np.nan, np.nan)] = [[], []]
    return entangled_res
1110
+ ##########################################################################################################################################################
1111
+
1112
+ ##########################################################################################################################################################
1113
def processes_coor(self, coor: np.ndarray, resids: np.ndarray, atom_names: np.ndarray, CG: bool = False, Calpha: bool = False, heavy_atom_names: Optional[list] = None) -> tuple:
    """
    Processes the coordinates of the protein to create a residue level contact map and backbone coordinates.

    If CG is True, it uses the full coor array as the backbone coordinates and the value of Calpha does not matter.
    If CG is False and Calpha is True, it uses the coordinates of the alpha carbons to determine if two residues are in contact.
    If CG is False and Calpha is False, it uses the coordinates of the heavy atoms to determine if two residues are in contact.
    In the non-CG cases the backbone coordinates are the alpha-carbon ('CA') coordinates.

    Contacts are:
      * CG / Calpha: backbone beads within 8.0 Å, with at least 4 positions of sequence separation.
      * heavy atoms: any pair of heavy atoms within 4.5 Å, between residues at least 4 apart
        in the sorted list of unique resids.
    All maps are upper-triangular: only [i, j] with j >= i + 4 is set.

    Args:
        coor: Per-atom (or per-bead) coordinates.
        resids: Per-atom residue ids (used by the heavy-atom branch).
        atom_names: Per-atom names; 'CA' selects the backbone in non-CG mode.
        CG: Treat every entry of coor as one backbone bead.
        Calpha: In non-CG mode, build contacts from CA atoms only.
        heavy_atom_names: Atom names counted as heavy atoms in the heavy-atom branch.
            Defaults to the historical list below.

    Returns:
        (native_cmap, bb_coor, dist_matrix). dist_matrix is the bead/CA distance matrix for
        CG/Calpha, and the heavy-atom (atom-level) distance matrix otherwise.
    """
    # NOTE(review): this default omits branched side-chain names (CG1, CG2, CD1, CD2, OG, OG1,
    # OD1, OD2, OE1, OE2, ND1, ND2, NE, NE1, NE2, NH1, NH2, NZ, OH, ...), so contacts formed only
    # by those atoms are not detected -- confirm whether this restriction is intentional.
    default_heavy_atom_names = ['C', 'O', 'N', 'CA', 'CB', 'CG', 'CD', 'CE', 'CZ', 'SD', 'SG']

    def _bead_contact_map(points):
        """8 Å contact map over backbone beads, upper triangle from the 4th diagonal."""
        dist_matrix = squareform(pdist(points))
        native_cmap = np.where(dist_matrix <= 8.0, 1, 0)
        native_cmap = np.triu(native_cmap, k=4)
        return native_cmap, points, dist_matrix

    if CG:
        return _bead_contact_map(coor)

    CA_idx = [idx for idx, name in enumerate(atom_names) if name == 'CA']
    bb_coor = coor[CA_idx]

    if Calpha:
        return _bead_contact_map(bb_coor)

    # Select heavy atoms
    names = default_heavy_atom_names if heavy_atom_names is None else list(heavy_atom_names)
    heavy_mask = np.isin(atom_names, names)
    heavy_coor = coor[heavy_mask]
    heavy_resids = np.asarray(resids)[heavy_mask]
    n_heavy = len(heavy_coor)

    # Distance matrix between heavy atoms. squareform() returns a 1x1 matrix for zero atoms,
    # so the contact test is restricted to the real atom count.
    heavy_dist_matrix = squareform(pdist(heavy_coor))
    atom_a, atom_b = np.nonzero(heavy_dist_matrix[:n_heavy, :n_heavy] <= 4.5)

    # Residue-level map indexed by position in the sorted unique resids.
    # NOTE(review): bb_coor is indexed in CA atom order; the two index spaces agree only when
    # resids are increasing and every residue has exactly one CA -- confirm for the input.
    unique_resids = np.unique(resids)
    n = len(unique_resids)
    native_cmap = np.zeros((n, n), dtype=int)

    # Map every heavy atom to its residue index, then lift atom contacts to residue contacts.
    atom_res_idx = np.searchsorted(unique_resids, heavy_resids)
    res_a = atom_res_idx[atom_a]
    res_b = atom_res_idx[atom_b]
    # Upper triangle only (skip residues closer than 4 in sequence), matching the CG/Calpha maps.
    keep = (res_b - res_a) >= 4
    native_cmap[res_a[keep], res_b[keep]] = 1

    return native_cmap, bb_coor, heavy_dist_matrix
1194
+ ##########################################################################################################################################################
1195
+
1196
+ ##########################################################################################################################################################
1197
def combine_ref_traj_GE(self, RefFile: str, TrajFile: str, outdir: str = './', ID: str = '', chunk_frames: Optional[int] = None, chunk_suffix: str = '_chunk'):
    """
    Combines reference and trajectory entanglements into pickle files.

    If chunk_frames is None (default): creates single {ID}_GE.pkl with all frames (backward compatible)
    If chunk_frames > 0: creates multiple chunk files {ID}{chunk_suffix}_0000.pkl, etc., each with ref data,
    plus {ID}_chunk_metadata.json describing the chunks.

    Each chunk pickle contains: {'ref': ref_dict, frame_num: frame_dict, ...}

    For each frame, every contact that also exists in the reference gets a change fingerprint
    from ``self.get_chg_ent_fingerprint``. Both codes of that fingerprint are counted in the
    frame's G_dict. G is the number of codes other than 'L#C#' divided by twice the number
    of reference rows. G is 0.0 when the reference file has no rows.

    Args:
        RefFile: Path to the '|'-separated reference GE file.
        TrajFile: Path to the '|'-separated trajectory GE file (with a 'frame' column).
        outdir: Output directory; created (including parents) if missing.
        ID: Output file prefix.
        chunk_frames: Frames per chunk file, or None for a single file.
        chunk_suffix: Inserted between ID and the chunk index in chunk file names.

    Returns:
        dict with 'outfile', 'Combined_ref_traj_dict' (None in chunked mode), 'G'
        (per-frame DataFrame) and, in chunked mode, 'chunk_info'.

    Raises:
        ValueError: if chunk_frames is given but is smaller than 1.
    """
    if chunk_frames is not None and chunk_frames < 1:
        raise ValueError(f'chunk_frames must be a positive integer or None, got {chunk_frames}')

    ## set up the outdir for this calculation
    if not os.path.isdir(outdir):
        os.makedirs(outdir, exist_ok=True)
        self.logger.info(f"Creating directory: {outdir}")

    ## Load reference and trajectory data (crossings kept as strings)
    crossing_dtypes = {'crossingsN': str, 'crossingsC': str}
    Ref = pd.read_csv(RefFile, sep='|', dtype=crossing_dtypes)
    Traj = pd.read_csv(TrajFile, sep='|', dtype=crossing_dtypes)
    self.logger.info(f'Ref {RefFile}')
    self.logger.info(f'Traj {TrajFile}')

    def _fingerprint(row):
        """Convert one GE row into ((i, j), fingerprint dict)."""
        i = int(row['i'])
        j = int(row['j'])
        crossing_resid, crossing_pattern = self.processes_crossings(row)
        gn = float(row['gn'])
        gc = float(row['gc'])
        GLNn = int(row['GLNn'])
        GLNc = int(row['GLNc'])
        TLNn = np.nan if pd.isna(row['TLNn']) else int(row['TLNn'])
        TLNc = np.nan if pd.isna(row['TLNc']) else int(row['TLNc'])
        value = {'linking_value': [gn, gc], 'crossing_resid': crossing_resid, 'crossing_pattern': crossing_pattern,
                 'gauss_linking_number': [GLNn, GLNc], 'topoly_linking_number': [TLNn, TLNc], 'native_contact': [i, j]}
        return (i, j), value

    ##########################################################################################
    ## Parse the reference entanglements into ref_dict
    ref_dict = {'ent_fingerprint': {}, 'chg_ent_fingerprint': None, 'G_dict': None, 'G': None}
    Num_native_contacts = len(Ref)
    self.logger.debug(f'Num_native_contacts: {Num_native_contacts}')

    for row in Ref.to_dict('records'):
        key, value = _fingerprint(row)
        ref_dict['ent_fingerprint'][key] = value

    ##########################################################################################
    ## Process trajectory frames
    codes = ['L-C~', 'L-C#', 'L+C~', 'L+C#', 'L#C~', 'L#C#']
    change_codes = codes[:5]  # every code except 'L#C#' (no change) contributes to G
    Gdf = {'Frame': [], **{code: [] for code in codes}, 'G': []}

    # Collect all frames first to allow chunking
    frames_data = {}
    for frame, frame_df in Traj.groupby('frame'):
        frame_dict = {'ent_fingerprint': {},
                      'chg_ent_fingerprint': {},
                      'G_dict': dict.fromkeys(codes, 0),
                      'G': None}

        for row in frame_df.to_dict('records'):
            key, value = _fingerprint(row)
            frame_dict['ent_fingerprint'][key] = value

            ## Change fingerprint only for contacts present in the reference
            ref_value = ref_dict['ent_fingerprint'].get(key)
            if ref_value is not None:
                chg_ent_fingerprint = self.get_chg_ent_fingerprint(ref=ref_value, frame=value)
                frame_dict['chg_ent_fingerprint'][key] = chg_ent_fingerprint
                frame_dict['G_dict'][chg_ent_fingerprint['code'][0]] += 1
                frame_dict['G_dict'][chg_ent_fingerprint['code'][1]] += 1

        ## Calculate G for this frame (two termini per native contact)
        n_changed = sum(frame_dict['G_dict'][code] for code in change_codes)
        G = n_changed / (Num_native_contacts * 2) if Num_native_contacts else 0.0

        Gdf['Frame'].append(frame)
        for code in codes:
            Gdf[code].append(frame_dict['G_dict'][code])
        Gdf['G'].append(G)
        frame_dict['G'] = G
        frames_data[frame] = frame_dict

    ##########################################################################################
    ## Save output based on chunking mode
    if chunk_frames is None:
        # Backward-compatible mode: single file with all frames
        Master = {'ref': ref_dict}
        Master.update(frames_data)
        outfile = os.path.join(f'{outdir}', f'{ID}_GE.pkl')
        with open(outfile, 'wb') as fw:
            pickle.dump(Master, fw)
        self.logger.info(f'SAVED: {outfile}')
        return {'outfile': outfile, 'Combined_ref_traj_dict': Master, 'G': pd.DataFrame(Gdf)}

    # Chunking mode: split frames into chunks, each with ref data
    sorted_frames = sorted(frames_data.keys())
    total_frames = len(sorted_frames)
    num_chunks = (total_frames + chunk_frames - 1) // chunk_frames  # ceiling division

    chunk_metadata = {
        'ID': ID,
        'total_frames': total_frames,
        'chunk_size': chunk_frames,
        'num_chunks': num_chunks,
        'chunks': []
    }

    first_chunk_file = None
    for chunk_idx in range(num_chunks):
        chunk_frame_nums = sorted_frames[chunk_idx * chunk_frames:(chunk_idx + 1) * chunk_frames]

        # Create chunk dict with ref and frame data
        chunk_dict = {'ref': ref_dict}
        for frame_num in chunk_frame_nums:
            chunk_dict[frame_num] = frames_data[frame_num]

        # Save chunk
        chunk_filename = f'{ID}{chunk_suffix}_{chunk_idx:04d}.pkl'
        chunk_filepath = os.path.join(outdir, chunk_filename)
        with open(chunk_filepath, 'wb') as fw:
            pickle.dump(chunk_dict, fw)
        self.logger.info(f'SAVED: {chunk_filepath}')

        if first_chunk_file is None:
            first_chunk_file = chunk_filepath

        # Record metadata for this chunk
        chunk_metadata['chunks'].append({
            'chunk_index': chunk_idx,
            'filename': chunk_filename,
            'frame_range': [int(chunk_frame_nums[0]), int(chunk_frame_nums[-1])],
            'num_frames': len(chunk_frame_nums)
        })

    # Save metadata file
    metadata_filepath = os.path.join(outdir, f'{ID}_chunk_metadata.json')
    with open(metadata_filepath, 'w') as fw:
        json.dump(chunk_metadata, fw, indent=2)
    self.logger.info(f'SAVED: {metadata_filepath}')

    # Return with outfile pointing to first chunk for backward compatibility
    return {
        'outfile': first_chunk_file,
        'Combined_ref_traj_dict': None,  # Not applicable for chunked mode
        'G': pd.DataFrame(Gdf),
        'chunk_info': chunk_metadata
    }
1352
+
1353
+ ##########################################################################################################################################################
1354
+
1355
+ ##########################################################################################################################################################
1356
def processes_crossings(self, row: pd.Series) -> tuple:
    """
    Parse the N- and C-terminal crossing strings of one GE row.

    ``row['crossingsN']`` and ``row['crossingsC']`` hold comma-separated tokens. Each token
    is either '?' (crossing residue unknown) or a sign character followed by a residue
    number, e.g. '+15' or '-10'. A non-string value (such as NaN from an empty CSV field)
    means there are no crossings on that side. Empty tokens are ignored.

    Example:
        crossingsN = '+15,-10' and crossingsC = '?,+108'
        -> crossing_resid = [[15, 10], [108]] and crossing_pattern = ['+-', '?+']

    Args:
        row: Any mapping (pandas Series or dict) with 'crossingsN' and 'crossingsC' keys.

    Returns:
        (crossing_resid, crossing_pattern). crossing_resid is [N residues, C residues] as
        ints ('?' tokens contribute no residue). crossing_pattern is [N signs, C signs] as
        strings, with '?' kept in place for unknown crossings.
    """
    crossing_resid = [[], []]
    crossing_pattern = ['', '']

    for side, column in enumerate(('crossingsN', 'crossingsC')):
        value = row[column]
        if not isinstance(value, str):
            continue

        signs = []
        for token in value.split(','):
            token = token.strip()
            if not token:
                # tolerate '' / trailing commas instead of failing on token[0]
                continue
            if token == '?':
                signs.append('?')
            else:
                signs.append(token[0])
                crossing_resid[side].append(int(token[1:]))
        crossing_pattern[side] = ''.join(signs)

    return crossing_resid, crossing_pattern
1403
+ ##########################################################################################################################################################
1404
+
1405
+ ##########################################################################################################################################################
1406
def get_chg_ent_fingerprint(self, ref: dict, frame: dict) -> dict:
    """
    Classify how the entanglement of one native contact changed between a reference and a frame.

    For each terminus (N, C) a change code ``'L<link>C<chiral>'`` is built:
      - ``<link>`` compares |linking number| of frame vs ref:
        '-' = loss (smaller), '+' = gain (larger), '#' = no change.
      - ``<chiral>`` compares signs: '~' = switched (frame * ref < 0),
        '#' = same sign or either value is zero.

    The possible codes are 'L-C~', 'L-C#', 'L+C~', 'L+C#', 'L#C~' and 'L#C#'.
    Each code is translated to a description through ``self.change_codes``.

    The linking number used depends on ``self.ent_detection_method``:
      1 -> ``'gauss_linking_number'`` (GLN)
      2 -> ``'topoly_linking_number'`` (TLN)
      3 -> both. A terminus gets the shared code when GLN and TLN agree;
           otherwise it is reported as 'L#C#' (no change).

    Args:
        ref: reference record. Keys read: 'native_contact', 'linking_value',
            'crossing_resid', 'crossing_pattern', 'gauss_linking_number' and
            'topoly_linking_number'. The last two are indexed as [N, C].
        frame: frame record with the same keys.

    Returns:
        dict with:
          - 'type' and 'code', each an [N, C] list;
          - the frame's values under their own key names;
          - the reference's values under 'ref_'-prefixed names;
          - 'ent_detection_method'.

    Raises:
        ValueError: if ``self.ent_detection_method`` is not 1, 2 or 3.
    """
    def _change_code(ref_val, frame_val):
        # Link part: compare magnitudes; anything not strictly smaller/larger (incl. NaN) is '#'.
        if abs(frame_val) < abs(ref_val):
            link = '-'
        elif abs(frame_val) > abs(ref_val):
            link = '+'
        else:
            link = '#'
        # Chirality part: a negative product means the sign flipped.
        chiral = '~' if frame_val * ref_val < 0 else '#'
        return f'L{link}C{chiral}'

    method = self.ent_detection_method
    if method not in (1, 2, 3):
        raise ValueError(f'Unsupported ent_detection_method: {method!r} (expected 1, 2 or 3)')

    if method in (1, 3):
        gln_ref = ref['gauss_linking_number']
        gln_frame = frame['gauss_linking_number']
        gln_codes = [_change_code(gln_ref[0], gln_frame[0]), _change_code(gln_ref[1], gln_frame[1])]
        self.logger.debug('GLN ref N/C: %s/%s, frame N/C: %s/%s -> codes %s',
                          gln_ref[0], gln_ref[1], gln_frame[0], gln_frame[1], gln_codes)

    if method in (2, 3):
        tln_ref = ref['topoly_linking_number']
        tln_frame = frame['topoly_linking_number']
        tln_codes = [_change_code(tln_ref[0], tln_frame[0]), _change_code(tln_ref[1], tln_frame[1])]
        self.logger.debug('TLN ref N/C: %s/%s, frame N/C: %s/%s -> codes %s',
                          tln_ref[0], tln_ref[1], tln_frame[0], tln_frame[1], tln_codes)

    if method == 1:
        codes = gln_codes
    elif method == 2:
        codes = tln_codes
    else:
        # NOTE(review): when GLN and TLN disagree the terminus is reported as 'L#C#' (no change).
        # An earlier comment claimed the GLN code is used instead; confirm which is intended.
        codes = [gln if gln == tln else 'L#C#' for gln, tln in zip(gln_codes, tln_codes)]

    types = [self.change_codes[code] for code in codes]

    chg_ent_fingerprint = {'type': types,
                           'code': codes,
                           'native_contact': frame['native_contact'],
                           'linking_value': frame['linking_value'],
                           'crossing_resid': frame['crossing_resid'],
                           'crossing_pattern': frame['crossing_pattern'],
                           'gauss_linking_number': frame['gauss_linking_number'],
                           'topoly_linking_number': frame['topoly_linking_number'],
                           'ref_native_contact': ref['native_contact'],
                           'ref_linking_value': ref['linking_value'],
                           'ref_crossing_resid': ref['crossing_resid'],
                           'ref_crossing_pattern': ref['crossing_pattern'],
                           'ref_gauss_linking_number': ref['gauss_linking_number'],
                           'ref_topoly_linking_number': ref['topoly_linking_number'],
                           'ent_detection_method': self.ent_detection_method}

    return chg_ent_fingerprint
1585
+ ##########################################################################################################################################################
1586
+
1587
+ ##########################################################################################################################################################
1588
def select_high_quality_entanglements(self, GE_filepath: str, pdb: str, outdir: str='./', ID: str='', model: str='EXP', mapping: str='None', chain: str=None) -> dict:
    """
    Filter a GE file down to high-quality entanglements and save the result.

    Steps:
      1. Read the '|'-separated GE file and keep only rows with ENT == True.
      2. If model == 'EXP':
         - run ``self.remove_slipknots``;
         - if a mapping file is given, also keep only the entanglements that
           ``self.mappingPDB2Uniprot`` accepts.
      3. If model == 'AF': grade the rows by pLDDT with
         ``self.remove_low_quality_AF_entanglements`` using ``pdb``.
      Any other model applies only step 1.

    Args:
        GE_filepath: path to the '|'-separated GE file.
        pdb: structure file; only used when model == 'AF'.
        outdir: output directory, created (including parents) if missing.
        ID: basename of the output file ``<outdir>/<ID>.csv``.
        model: 'EXP' or 'AF'.
        mapping: path to a PDB->UniProt mapping file, or 'None' / None for no mapping.
        chain: currently unused.

    Returns:
        {'outfile': path of the saved '|'-separated file, 'GE_data': that file re-read as a DataFrame}
    """
    if not os.path.isdir(outdir):
        os.makedirs(outdir, exist_ok=True)
        self.logger.info(f"Creating directory: {outdir}")

    # Keep crossings as strings; otherwise a single crossing such as '+10' is parsed as the int 10.
    crossing_dtypes = {'crossingsN': str, 'crossingsC': str}

    GE_data = pd.read_csv(GE_filepath, sep='|', dtype=crossing_dtypes)
    GE_data = GE_data[GE_data['ENT'] == True].reset_index(drop=True)  # noqa: E712
    self.logger.info(f'GE FILE: {GE_filepath}')

    if model == 'EXP':
        GE_data = self.remove_slipknots(GE_data)
        if mapping not in (None, 'None'):
            GE_data = self.mappingPDB2Uniprot(GE_data, mapping)
    elif model == 'AF':
        GE_data = self.remove_low_quality_AF_entanglements(GE_data, pdb)

    outfile = os.path.join(outdir, f'{ID}.csv')
    GE_data.to_csv(outfile, index=False, sep='|')
    self.logger.info(f'SAVED: {outfile}')
    # Re-read with the same crossing dtypes so the returned frame matches the saved file's contents.
    GE_data = pd.read_csv(outfile, sep='|', dtype=crossing_dtypes)
    return {'outfile': outfile, 'GE_data': GE_data}
1634
+ ##########################################################################################################################################################
1635
+
1636
+ ##########################################################################################################################################################
1637
def remove_slipknots(self, df):
    """
    Flag raw entanglements whose terminal crossings form a slipknot.

    A terminus is flagged when it has more than one crossing on its own side of the
    loop and the crossing signs (+1 / -1) sum to 0. The sides are: N terminus,
    resid < i; C terminus, resid > j. '?' and empty tokens are ignored.

    No rows are removed. Every input row yields one output row carrying the
    Slipknot_N / Slipknot_C flags.

    Args:
        df: DataFrame with columns ID, chain, i, j, crossingsN, crossingsC, gn, gc,
            GLNn, GLNc, TLNn, TLNc, CCbond, ENT.

    Returns:
        A new DataFrame with the input columns plus Slipknot_N and Slipknot_C.
        Missing (NaN) crossing fields are written as ''.

    Raises:
        ValueError: a terminal crossing has a sign other than '+' / '-', or a terminus
            lists the same crossing residue more than once.
    """
    self.logger.info(f'\n{"#"*50}\nRemoving slipknots...')
    columns = ['ID', 'chain', 'i', 'j', 'crossingsN', 'crossingsC', 'gn', 'gc', 'GLNn', 'GLNc',
               'TLNn', 'TLNc', 'CCbond', 'ENT', 'Slipknot_N', 'Slipknot_C']
    passthrough = ['ID', 'chain', 'i', 'j', 'gn', 'gc', 'GLNn', 'GLNc', 'TLNn', 'TLNc', 'CCbond', 'ENT']

    def _is_slipknot(crossing_str, termini, i, j):
        # True when the crossings on this terminus' side number > 1 and their signs cancel.
        resids = []
        signs = []
        for token in crossing_str.split(','):
            token = token.strip()
            if token in ('', '?'):
                continue  # no sign information
            resid = int(token[1:])
            on_terminus = resid < i if termini == 'N' else resid > j
            if not on_terminus:
                continue
            if token[0] == '+':
                signs.append(1)
            elif token[0] == '-':
                signs.append(-1)
            else:
                raise ValueError(f'The crossing sign was not + or - {token}')
            resids.append(resid)
        if len(resids) != len(set(resids)):
            raise ValueError(f'Duplicate crossings found in {termini} terminus: {resids}')
        return len(signs) > 1 and sum(signs) == 0

    records = []
    for _, row in df.iterrows():
        i = row['i']
        j = row['j']
        crossingsN = '' if pd.isna(row['crossingsN']) else str(row['crossingsN'])
        crossingsC = '' if pd.isna(row['crossingsC']) else str(row['crossingsC'])

        record = {col: row[col] for col in passthrough}
        record['crossingsN'] = crossingsN
        record['crossingsC'] = crossingsC
        record['Slipknot_N'] = _is_slipknot(crossingsN, 'N', i, j)
        record['Slipknot_C'] = _is_slipknot(crossingsC, 'C', i, j)
        records.append(record)

    return pd.DataFrame(records, columns=columns)
1723
+ ##########################################################################################################################################################
1724
+
1725
+ ##########################################################################################################################################################
1726
def mappingPDB2Uniprot(self, df, mapping):
    """
    Keep only entanglements whose key residues appear in a PDB->UniProt mapping file.

    The mapping file is read with ``np.loadtxt`` as whitespace-separated rows of:
    label, PDB resid, UniProt resid. Only rows whose label contains 'Mapped',
    'Modifed_Residue' (sic), 'Modified_Residue' or 'Missense' count as mapped.

    An entanglement is kept only if i, j and every N/C crossing residue are mapped
    PDB resids. Residue numbers are NOT renumbered; this method only filters.

    Args:
        df: DataFrame with columns ID, chain, i, j, crossingsN, crossingsC, gn, gc,
            GLNn, GLNc, TLNn, TLNc, CCbond, ENT.
        mapping: path to the mapping file.

    Returns:
        A new DataFrame with the columns listed above. Missing, '' or '?' crossing
        fields are written as ''.

    Raises:
        ValueError: if the mapping file does not exist.
    """
    if not os.path.exists(mapping):
        raise ValueError(f'Mapping file {mapping} could not be found!')

    # 'Modifed_Residue' (sic) is kept for existing mapping files; the correct spelling is also accepted.
    mapped_labels = ('Mapped', 'Modifed_Residue', 'Modified_Residue', 'Missense')
    # ndmin=2 keeps a one-line mapping file as a 2-D table instead of a flat row.
    mapping_rows = np.loadtxt(mapping, dtype='O', ndmin=2)
    mapping_pdb2uniprot = {
        int(entry[1]): int(entry[2])
        for entry in mapping_rows
        if any(label in entry[0] for label in mapped_labels)
    }
    if not mapping_pdb2uniprot:
        self.logger.warning(f'No mapped residues found in {mapping}; all entanglements will be discarded')

    columns = ['ID', 'chain', 'i', 'j', 'crossingsN', 'crossingsC', 'gn', 'gc', 'GLNn', 'GLNc',
               'TLNn', 'TLNc', 'CCbond', 'ENT']

    def _clean(value):
        # Missing, empty and '?' fields all mean "no crossings".
        return value if pd.notna(value) and value != '' and value != '?' else ''

    def _crossing_resids(field):
        # Residue numbers of the signed tokens in a cleaned field, skipping '' and '?' tokens.
        if not field:
            return []
        return [int(tok.strip()[1:]) for tok in field.split(',') if tok.strip() not in ('', '?')]

    records = []
    for rowi, row in df.iterrows():
        crossingsN = _clean(row['crossingsN'])
        crossingsC = _clean(row['crossingsC'])
        key_res = [row['i'], row['j']] + _crossing_resids(crossingsN) + _crossing_resids(crossingsC)

        unmapped = [res for res in key_res if res not in mapping_pdb2uniprot]
        for res in unmapped:
            self.logger.debug(f'Res: {res} not mapped! this entanglement will be discarded from {row["ID"]}')
        if unmapped:
            self.logger.info(f'Entanglement {rowi} was not mapped and will be discarded')
            continue

        record = {col: row[col] for col in columns}
        record['crossingsN'] = crossingsN
        record['crossingsC'] = crossingsC
        records.append(record)

    return pd.DataFrame(records, columns=columns)
1793
+ ##########################################################################################################################################################
1794
+
1795
+ ##########################################################################################################################################################
1796
def remove_low_quality_AF_entanglements(self, df, pdb):
    """
    Grade AlphaFold entanglements by pLDDT, taken from ``self.average_pLDDT(pdb)``.

    Every input row yields exactly one output row, with Quality and Reason set as
    follows.

    (1) Native-contact check:
        - If any CA pLDDT of i/j is < 70: Quality='Low', Reason='NC pLDDT < 70'.
        - If both are >= 70 and ENT is False: Quality='High', Reason='NC pLDDT >= 70'.

    (2) Otherwise, the crossings of each terminus are walked outward from the loop
        base: N crossings (resid < i) in descending resid order, C crossings
        (resid > j) in ascending order. Crossings are kept up to the first one whose
        pLDDT is < 70 or missing. A terminus is then emptied if its kept signs sum
        to 0 (slipknot).
        - If any crossings survive, only those are written, with Quality='High'.
        - Otherwise the original crossings are written with Quality='Low',
          Reason='No HQ crossings'.

    Terminal crossings are selected by ``self.parse_crossings``, and duplicates are
    handled by ``self.remove_duplicates``.

    Raises:
        ValueError: if ``self.average_pLDDT`` finds no CA atoms (returns None).
    """
    self.logger.info(f'\n{"#"*50}\nRemoving low quality AF entanglements...')
    pLDDT_result = self.average_pLDDT(pdb)
    if pLDDT_result is None:
        raise ValueError(f'No CA atoms found in {pdb}; cannot read pLDDT values')
    _, pLDDT_df = pLDDT_result
    # One score per resid (lowest if a resid occurs more than once, e.g. several chains),
    # looked up by resid so the scores stay aligned with the crossing order used below.
    pLDDT_by_resid = pLDDT_df.groupby('resid')['pLDDT'].min().to_dict()

    columns = ['ID', 'chain', 'i', 'j', 'crossingsN', 'crossingsC', 'gn', 'gc', 'GLNn', 'GLNc',
               'TLNn', 'TLNc', 'CCbond', 'ENT', 'Quality', 'Reason']
    passthrough = ['ID', 'chain', 'i', 'j', 'gn', 'gc', 'GLNn', 'GLNc', 'TLNn', 'TLNc', 'CCbond', 'ENT']

    def _clean(value):
        # Missing, empty and '?' fields all mean "no crossings".
        return value if pd.notna(value) and value != '' and value != '?' else ''

    def _tokens(field):
        # Signed crossing tokens of a cleaned field, skipping '' and '?' tokens.
        return [tok for tok in field.split(',') if tok.strip() not in ('', '?')] if field else []

    def _record(row, crossingsN, crossingsC, quality, reason):
        # One output row: input columns plus the (possibly filtered) crossings and the grade.
        record = {col: row[col] for col in passthrough}
        record.update(crossingsN=crossingsN, crossingsC=crossingsC, Quality=quality, Reason=reason)
        return record

    def _hq_crossings(r, i, j, term):
        # High-quality crossing tokens of one terminus, walked outward from the loop base.
        resids, crossings, signs = self.parse_crossings(r, i=i, j=j, term=term)
        resids, crossings, signs, _ = self.remove_duplicates(resids, crossings, signs)
        order = sorted(range(len(resids)), key=lambda k: resids[k], reverse=(term == 'N'))
        hq_crossings, hq_signs = [], []
        for k in order:
            # Missing or NaN pLDDT counts as low quality and stops the walk.
            if pLDDT_by_resid.get(resids[k], float('-inf')) >= 70:
                hq_crossings.append(crossings[k])
                hq_signs.append(signs[k])
            else:
                break
        if sum(hq_signs) == 0:  # slipknot (or nothing kept)
            return []
        return hq_crossings

    records = []
    for _, row in df.iterrows():
        i = row['i']
        j = row['j']
        ENT = row['ENT']
        crossingsN = _clean(row['crossingsN'])
        crossingsC = _clean(row['crossingsC'])

        # (1) both loop-closing residues must have pLDDT >= 70
        NC_pLDDT = pLDDT_df[pLDDT_df['resid'].isin([i, j])]['pLDDT'].values
        if not all(NC_pLDDT >= 70):
            records.append(_record(row, crossingsN, crossingsC, 'Low', 'NC pLDDT < 70'))
            continue
        if ENT == False:  # noqa: E712 - no entanglement, but the contact itself is high quality
            records.append(_record(row, crossingsN, crossingsC, 'High', 'NC pLDDT >= 70'))
            continue

        # (2) keep the ordered run of high-pLDDT crossings on each terminus
        r = _tokens(crossingsN) + _tokens(crossingsC)
        HQ_Ncrossings = _hq_crossings(r, i, j, 'N')
        HQ_Ccrossings = _hq_crossings(r, i, j, 'C')

        if HQ_Ncrossings or HQ_Ccrossings:
            records.append(_record(row, ','.join(HQ_Ncrossings), ','.join(HQ_Ccrossings),
                                   'High', 'NC and Crossings pLDDT >= 70'))
        else:
            records.append(_record(row, crossingsN, crossingsC, 'Low', 'No HQ crossings'))

    return pd.DataFrame(records, columns=columns)
1981
+ ##########################################################################################################################################################
1982
+
1983
+ ##########################################################################################################################################################
1984
def parse_crossings(self, r: list, i: int = None, j: int = None, term: str = 'N'):
    """
    Select the crossings belonging to one terminus from a list of signed crossing tokens.

    Args:
        r: tokens such as '+12' or '-140' (sign character followed by a residue number).
            Empty and '?' tokens are ignored.
        i: loop-closing residue; N-terminal crossings have resid < i.
        j: loop-closing residue; C-terminal crossings have resid > j.
        term: 'N' or 'C'.

    Returns:
        (crossings_resids, crossings, crossings_signs) for the selected crossings, in
        input order. The sign is +1 for '+', -1 for '-', and 0 for any other leading
        character.

    Raises:
        ValueError: if term is not 'N' or 'C'.
    """
    # The previous signature used `i=int, j=int`, which made the builtin type the default value.
    if term not in ('N', 'C'):
        raise ValueError(f"term must be 'N' or 'C', got {term!r}")

    sign_of = {'+': 1, '-': -1}
    crossings_resids, crossings, crossings_signs = [], [], []
    for cross in r:
        if cross.strip() in ('', '?'):
            continue  # no residue number to parse
        cross_resid = int(cross[1:])
        on_terminus = cross_resid < i if term == 'N' else cross_resid > j
        if on_terminus:
            crossings_resids.append(cross_resid)
            crossings.append(cross)
            crossings_signs.append(sign_of.get(cross[0], 0))

    return crossings_resids, crossings, crossings_signs
2009
+ ##########################################################################################################################################################
2010
+
2011
+ ##########################################################################################################################################################
2012
def remove_duplicates(self, crossings_resids, crossings, crossings_signs):
    """
    Discard every crossing of a terminus when any crossing residue is repeated.

    The three arguments are parallel lists with one entry per crossing. A repeated
    residue means the crossings of this terminus cannot be trusted, so nothing is
    kept in that case.

    Returns:
        (resids, crossings, signs, dup_flag):
          - If all residues are unique: new lists holding the first
            len(crossings_resids) entries of each input, and dup_flag False.
          - Otherwise: three empty lists and dup_flag True.
    """
    if len(set(crossings_resids)) != len(crossings_resids):
        return [], [], [], True

    n_crossings = len(crossings_resids)
    kept_resids = list(crossings_resids)
    kept_crossings = [crossings[k] for k in range(n_crossings)]
    kept_signs = [crossings_signs[k] for k in range(n_crossings)]
    return kept_resids, kept_crossings, kept_signs, False
2036
+ ##########################################################################################################################################################
2037
+
2038
+ ##########################################################################################################################################################
2039
def average_pLDDT(self, pdb_filename):
    """
    Read per-residue pLDDT values from the CA B-factors of a structure file.

    Every atom named 'CA' in every model and chain contributes one
    (residue number, B-factor) entry, in file order.

    Args:
        pdb_filename: path to a PDB file readable by Biopython's PDBParser.

    Returns:
        (avg_pLDDT, pLDDT_df) where avg_pLDDT is the mean CA B-factor and pLDDT_df is
        a DataFrame with columns 'resid' and 'pLDDT'. Returns None (not a tuple) when
        no CA atom is found; callers that unpack the result must check for this.
    """
    structure = PDBParser(QUIET=True).get_structure('PDB_structure', pdb_filename)

    ca_entries = [
        (residue.get_id()[1], atom.get_bfactor())
        for model in structure
        for chain in model
        for residue in chain
        for atom in residue
        if atom.get_name() == 'CA'
    ]
    if not ca_entries:
        return None

    resids = [resid for resid, _ in ca_entries]
    scores = [score for _, score in ca_entries]
    avg_pLDDT = sum(scores) / len(scores)
    return avg_pLDDT, pd.DataFrame({'resid': resids, 'pLDDT': scores})
2067
+ ##########################################################################################################################################################