EntDetect 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. EntDetect/Jwalk/GridTools.py +567 -0
  2. EntDetect/Jwalk/PDBTools.py +532 -0
  3. EntDetect/Jwalk/SASDTools.py +543 -0
  4. EntDetect/Jwalk/SurfaceTools.py +150 -0
  5. EntDetect/Jwalk/__init__.py +19 -0
  6. EntDetect/Jwalk/naccess.config.txt +255 -0
  7. EntDetect/__init__.py +10 -0
  8. EntDetect/_logging.py +71 -0
  9. EntDetect/change_resolution.py +2361 -0
  10. EntDetect/clustering.py +2626 -0
  11. EntDetect/compare_sim2exp.py +1927 -0
  12. EntDetect/entanglement_features.py +478 -0
  13. EntDetect/gaussian_entanglement.py +2067 -0
  14. EntDetect/order_params.py +1048 -0
  15. EntDetect/resources/__init__.py +11 -0
  16. EntDetect/resources/__pycache__/__init__.cpython-311.pyc +0 -0
  17. EntDetect/resources/calc_K.pl +712 -0
  18. EntDetect/resources/calc_Q.pl +962 -0
  19. EntDetect/resources/pulchra +0 -0
  20. EntDetect/resources/shared_files/__init__.py +2 -0
  21. EntDetect/resources/shared_files/bt_contact_potential.dat +22 -0
  22. EntDetect/resources/shared_files/karanicolas_dihe_parm.dat +1600 -0
  23. EntDetect/resources/shared_files/kgs_contact_potential.dat +22 -0
  24. EntDetect/resources/shared_files/mj_contact_potential.dat +22 -0
  25. EntDetect/resources/stride +0 -0
  26. EntDetect/statistics.py +1344 -0
  27. EntDetect/utilities.py +201 -0
  28. entdetect-1.2.0.dist-info/METADATA +26 -0
  29. entdetect-1.2.0.dist-info/RECORD +45 -0
  30. entdetect-1.2.0.dist-info/WHEEL +5 -0
  31. entdetect-1.2.0.dist-info/entry_points.txt +11 -0
  32. entdetect-1.2.0.dist-info/licenses/LICENSE +674 -0
  33. entdetect-1.2.0.dist-info/top_level.txt +2 -0
  34. scripts/__init__.py +5 -0
  35. scripts/convert_cor_psf_to_pdb.py +103 -0
  36. scripts/run_Foldingpathway.py +162 -0
  37. scripts/run_MSM.py +152 -0
  38. scripts/run_OP_on_simulation_traj.py +194 -0
  39. scripts/run_change_resolution.py +63 -0
  40. scripts/run_compare_sim2exp.py +215 -0
  41. scripts/run_montecarlo.py +158 -0
  42. scripts/run_nativeNCLE.py +179 -0
  43. scripts/run_nonnative_entanglement_clustering.py +110 -0
  44. scripts/run_population_modeling.py +117 -0
  45. scripts/run_workflow4_nativeNCLE_batch.py +412 -0
@@ -0,0 +1,478 @@
1
+ import os
2
+ import math
3
+ import re
4
+ import logging
5
+ import argparse
6
+ import numpy as np
7
+ import pandas as pd
8
+ from glob import glob
9
+ import mdtraj as md
10
+ from Bio.PDB import PDBParser, is_aa
11
+ from Bio import PDB
12
+ from scipy.spatial.distance import pdist, squareform
13
+ import MDAnalysis as mda
14
+ import requests, sys
15
+ from EntDetect._logging import setup_logger
16
+ np.set_printoptions(linewidth=np.inf, precision=4)
17
+ pd.set_option('display.max_rows', None)
18
+
19
class FeatureGen:
    """
    Computes per-entanglement structural features from a PDB file and a
    '|'-separated clustered-entanglement file.

    The PDB file is loaded with MDTraj (for neighbor searches) and parsed with
    Biopython (for the residue-number -> amino-acid map). The canonical protein
    length is fetched from the UniProt REST API.
    """

    # Column order of the feature table written by `get_uent_features`.
    _UENT_COLUMNS = (
        'gene', 'PDB', 'chain', 'ENT-ID',
        'gn', 'N_term_thread', 'gc', 'C_term_thread',
        'i', 'j',
        'NC', 'NC_wbuff', 'NC_region',
        'crossingsN', 'crossingsC',
        'crossingsN_wbuff', 'crossingsC_wbuff',
        'crossingsN_region', 'crossingsC_region',
        'ent_region', 'loopsize', 'num_zipper_nc', 'perc_bb_loop',
        'num_loop_contacting_res', 'num_cross_nearest_neighbors', 'ent_coverage',
        'min_N_prot_depth_left', 'min_N_thread_depth_left', 'min_N_thread_slippage_left',
        'min_C_prot_depth_right', 'min_C_thread_depth_right', 'min_C_thread_slippage_right',
        'prot_size', 'ACO', 'RCO', 'CCBond',
    )

    # One contact "<resid>-<resid>" where either residue number may carry a sign,
    # e.g. "12-45", "-3-45", "12--5", "-3--5".
    _CONTACT_RE = re.compile(r'\s*([+-]?\d+)\s*-\s*([+-]?\d+)\s*')

    #############################################################################################################
    def __init__(self, PDBfile:str, outdir:str='./', cluster_file:str='None', log_level:int=logging.INFO, logdir:str=None):
        """
        Parameters
        ----------
        PDBfile : str
            Path to the structure file; loaded with `md.load`.
        outdir : str
            Directory for output files; created if it does not exist.
        cluster_file : str
            Path to the '|'-separated entanglement file read into `self.GE_data`.
        log_level : int
            Level passed to `setup_logger`.
        logdir : str, optional
            Directory handed to `setup_logger`; falls back to `outdir` when None.

        Raises
        ------
        ValueError
            If `cluster_file` does not exist.
        """
        self.PDBfile = PDBfile
        self.outdir = outdir
        self.cluster_file = cluster_file

        # Validate the cheap input first so a bad path fails before any log
        # file / directory is created or the structure is loaded.
        # (The original raised AttributeError here: self.cluster_file was never set.)
        if not os.path.exists(cluster_file):
            raise ValueError(f"{cluster_file} does not exist")

        self.logger = setup_logger('FeatureGen', outdir=logdir if logdir is not None else outdir, log_level=log_level)

        if not os.path.exists(self.outdir):
            # exist_ok guards against another process creating it in between
            os.makedirs(self.outdir, exist_ok=True)
            self.logger.debug(f'Made directory: {self.outdir}')

        self.traj = md.load(PDBfile)

        ## entanglement table: native contacts, crossings, ...
        # 'c' and the crossing columns are kept as strings so signed tokens such as "+109" survive
        self.GE_data = pd.read_csv(cluster_file, sep='|', dtype={'c': str, 'crossingsN': str, 'crossingsC': str})
    #############################################################################################################

    #############################################################################################################
    def get_AA(self, pdb_file, gene):
        """
        Build the PDB resid -> one-letter amino-acid map and fetch the UniProt sequence length.

        Sets `self.resid2AA` (residues whose name is not in the lookup table map to 'NC')
        and `self.prot_size` (length of the FASTA sequence returned by UniProt).

        Parameters
        ----------
        pdb_file : str
            PDB file parsed with Biopython's `PDBParser`.
        gene : str
            Identifier inserted into the UniProt REST URL `uniprotkb/{gene}.fasta`.

        Raises
        ------
        ValueError
            If the UniProt request does not return HTTP 200.
        SystemExit
            If UniProt returns an empty sequence.
        """
        three_to_one_letter = {
            'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C',
            'GLU': 'E', 'GLN': 'Q', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
            'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'MSE': 'M', 'PHE': 'F',
            'PRO': 'P', 'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y',
            'VAL': 'V'}

        structure = PDBParser(QUIET=True).get_structure("protein", pdb_file)

        # Keys are bare residue numbers, so a later model/chain that reuses a
        # number overwrites the earlier entry.
        resid2AA = {}
        for model in structure:
            for chain in model:
                for residue in chain:
                    if is_aa(residue):
                        resid = residue.get_id()[1]
                        resid2AA[resid] = three_to_one_letter.get(residue.get_resname(), 'NC')
        self.resid2AA = resid2AA

        ## get the canonical uniprot sequence length
        url = f"https://rest.uniprot.org/uniprotkb/{gene}.fasta"

        # A timeout keeps a stalled connection from hanging the run forever.
        response = requests.get(url, timeout=60)
        if response.status_code != 200:
            # (The original referenced an undefined name here and raised NameError.)
            raise ValueError(f"Error: Could not retrieve data for UniProt ID {gene}.")

        # FASTA: first line is the header, the remaining lines are the sequence
        sequence = ''.join(response.text.splitlines()[1:])
        self.prot_size = len(sequence)
        if self.prot_size == 0:
            self.logger.error(f'The size of the protein in Uniprot is {self.prot_size} == 0. This likely means this uniprot ID no longer exists. No entanglement features will be calculated')
            # sys.exit() instead of quit(): quit is only injected by the site module
            sys.exit()
    #############################################################################################################

    #############################################################################################################
    def split_on_nth_char(self, s, char, n):
        """
        Split `s` around the n-th occurrence of `char`.

        Returns `(before, after)` with the separator removed, or `(s, "")` when
        `char` occurs fewer than `n` times.
        """
        seen = 0
        for pos, current in enumerate(s):
            if current != char:
                continue
            seen += 1
            if seen == n:
                return s[:pos], s[pos+1:]
        return s, ""
    #############################################################################################################

    #############################################################################################################
    @staticmethod
    def _parse_crossings(value):
        """Turn a crossing string such as "+92,+93,-94" into a list of residue numbers.

        Missing values (None / NaN / empty string) give an empty list. A single
        leading '+' or '-' is dropped when present, and anything after a '.' is
        discarded (float artifacts like "+92.0").
        """
        if value is None or (not isinstance(value, str) and pd.isna(value)):
            return []
        resids = []
        for token in str(value).split(','):
            token = token.strip()
            if not token:
                continue
            # Strip the sign only if there is one, so an unsigned "109" is not
            # truncated to "09".
            if token[0] in '+-':
                token = token[1:]
            resids.append(int(token.split('.')[0]))
        return resids

    @classmethod
    def _parse_contact_pairs(cls, contacts):
        """Parse "i-j;i-j;..." into a list of (i, j) int tuples; empty tokens are skipped.

        Raises ValueError for a token that is not "<resid>-<resid>".
        """
        pairs = []
        for token in contacts.split(';'):
            if not token.strip():
                continue
            match = cls._CONTACT_RE.fullmatch(token)
            if match is None:
                raise ValueError(f'Could not parse contact "{token}"; expected "<resid>-<resid>"')
            pairs.append((int(match.group(1)), int(match.group(2))))
        return pairs

    @staticmethod
    def _with_buffer(resids, rbuffer):
        """Expand every residue number to the window [resid - rbuffer, resid + rbuffer]."""
        return [r for resid in resids
                for r in range(int(resid) - rbuffer, int(resid) + rbuffer + 1)]

    @staticmethod
    def _join_resids(values):
        """Comma-join values for storage in a single table cell."""
        return ",".join(str(v) for v in values)
    #############################################################################################################

    #############################################################################################################
    def get_uent_features(self, gene:str, pdbid:str, chain:str='A'):
        """
        Get the features for each unique entanglement in `self.GE_data` and write
        them to `{outdir}/{gene}_{pdbid}_{chain}_uent_features.csv` ('|'-separated).

        Parameters
        ----------
        gene : str
            UniProt identifier; used for the sequence-length lookup and the output name.
        pdbid : str
            PDB identifier; only used for labelling.
        chain : str
            Chain ID that must exist in the loaded structure.

        Returns
        -------
        dict
            {'outfile': path of the written file, 'results': pandas.DataFrame of features}

        Raises
        ------
        ValueError
            If `chain` is not in the structure or a contact entry cannot be parsed.
        """
        topology = self.traj.topology

        # get mapping of chain letters to chain index
        chain_ids = {c.chain_id: c.index for c in topology.chains}
        self.logger.debug(f'chain_ids: {chain_ids}')
        if chain not in chain_ids:
            raise ValueError(f'chain {chain} not in PDB file')

        # Get the protein size from uniprot and a dictionary that maps resid to amino acid (one letter)
        self.get_AA(self.PDBfile, gene)
        self.logger.debug(f'gene: {gene}, chain: {chain}, pdbid: {pdbid}, prot_size: {self.prot_size}')

        rbuffer = 3  # residues added on each side of a contact / crossing along the sequence
        records = []

        for ent_id, row in self.GE_data.iterrows():
            nc_i = row['i']
            nc_j = row['j']

            # Crossing residues, kept separately for the N- and C-terminal columns
            crossings_N = self._parse_crossings(row.get('crossingsN'))
            crossings_C = self._parse_crossings(row.get('crossingsC'))
            crossings_core = crossings_N + crossings_C

            #########################################################################
            ## absolute and relative contact order over all contacts of this entanglement
            loops = self._parse_contact_pairs(row['contacts'])
            loop_sizes = [end - start for start, end in loops]
            ACO = np.sum(loop_sizes)/len(loop_sizes) if loop_sizes else np.nan
            RCO = ACO/self.prot_size

            #########################################################################
            ## native contact residues, +/- rbuffer, and their spatial neighbors
            pdb_NC = self._with_buffer([nc_i, nc_j], rbuffer)
            NC_region = self.find_neighboring_residues(self.traj, pdb_NC)

            loopsize = nc_j - nc_i
            loop_resids = np.arange(nc_i, nc_j + 1)
            loop_contacting_res = self.find_neighboring_residues(self.traj, loop_resids)

            #########################################################################
            ## crossings +/- rbuffer and their spatial neighbors
            crossings_N_wbuff = self._with_buffer(crossings_N, rbuffer)
            crossings_C_wbuff = self._with_buffer(crossings_C, rbuffer)
            crossing_region_N = self.find_neighboring_residues(self.traj, crossings_N_wbuff)
            crossing_region_C = self.find_neighboring_residues(self.traj, crossings_C_wbuff)

            #########################################################################
            ## number of threads on each side of the loop and how deep they sit
            N_term_thread = [c for c in crossings_core if c < nc_i]
            C_term_thread = [c for c in crossings_core if c > nc_j]

            if N_term_thread:
                min_N_thread_slippage_left = min(N_term_thread)
                min_N_thread_depth_left = min_N_thread_slippage_left / nc_i
                min_N_prot_depth_left = min_N_thread_slippage_left / self.prot_size
            else:
                min_N_thread_slippage_left = np.nan
                min_N_thread_depth_left = np.nan
                min_N_prot_depth_left = np.nan

            if C_term_thread:
                min_C_thread_slippage_right = self.prot_size - max(C_term_thread)
                min_C_thread_depth_right = min_C_thread_slippage_right / (self.prot_size - nc_j)
                min_C_prot_depth_right = min_C_thread_slippage_right / self.prot_size
            else:
                min_C_thread_slippage_right = np.nan
                min_C_thread_depth_right = np.nan
                min_C_prot_depth_right = np.nan

            #########################################################################
            ### entangled residues = NC_region U crossing regions U buffered NC U buffered crossings
            ent_region = set(NC_region)
            ent_region.update(crossing_region_N, crossing_region_C, pdb_NC,
                              crossings_N_wbuff, crossings_C_wbuff)
            # sorted so the stored string does not depend on set iteration order
            ent_region = sorted(ent_region)

            records.append({
                'gene': gene,
                'PDB': pdbid,
                'chain': chain,
                'ENT-ID': ent_id,
                'gn': row['gn'],
                'N_term_thread': len(N_term_thread),
                'gc': row['gc'],
                'C_term_thread': len(C_term_thread),
                'i': nc_i,
                'j': nc_j,
                'NC': self._join_resids([nc_i, nc_j]),
                'NC_wbuff': self._join_resids(pdb_NC),
                'NC_region': self._join_resids(NC_region),
                'crossingsN': self._join_resids(crossings_N),
                'crossingsC': self._join_resids(crossings_C),
                'crossingsN_wbuff': self._join_resids(crossings_N_wbuff),
                'crossingsC_wbuff': self._join_resids(crossings_C_wbuff),
                'crossingsN_region': self._join_resids(crossing_region_N),
                'crossingsC_region': self._join_resids(crossing_region_C),
                'ent_region': self._join_resids(ent_region),
                'loopsize': loopsize,
                'num_zipper_nc': row['num_contacts'],
                'perc_bb_loop': loopsize/self.prot_size,
                'num_loop_contacting_res': len(loop_contacting_res),
                'num_cross_nearest_neighbors': len(crossing_region_N) + len(crossing_region_C),
                'ent_coverage': len(ent_region)/self.prot_size,
                'min_N_prot_depth_left': min_N_prot_depth_left,
                'min_N_thread_depth_left': min_N_thread_depth_left,
                'min_N_thread_slippage_left': min_N_thread_slippage_left,
                'min_C_prot_depth_right': min_C_prot_depth_right,
                'min_C_thread_depth_right': min_C_thread_depth_right,
                'min_C_thread_slippage_right': min_C_thread_slippage_right,
                'prot_size': self.prot_size,
                'ACO': ACO,
                'RCO': RCO,
                'CCBond': row['CCBond'],
            })

        ### save file for unique entanglement features
        # columns= fixes the column order (and keeps the header when there are no rows)
        uent_df = pd.DataFrame(records, columns=list(self._UENT_COLUMNS))
        outfile = os.path.join(self.outdir, f'{gene}_{pdbid}_{chain}_uent_features.csv')
        uent_df.to_csv(outfile, index=False, sep='|')
        self.logger.info(f'Unique entanglement features saved to {outfile}')

        return {'outfile':outfile, 'results': uent_df}
    ########################################################################################################################

    ########################################################################################################################
    def find_neighboring_residues(self, traj, target_resids, cutoff=0.45):
        """
        Find all residues whose side-chain heavy atoms are within `cutoff`
        of any side-chain heavy atom of the residues in `target_resids`.

        "Side-chain heavy atom" here means any non-hydrogen atom that is not
        named N, CA, C or O. Only the first frame of `traj` is used.

        Parameters
        ----------
        traj : md.Trajectory
            The trajectory (or single-frame PDB) to search.
        target_resids : iterable of int
            PDB residue numbers (residue.resSeq) to probe around.
            NOTE(review): matching is by resSeq only, so in a multi-chain
            structure the same number in another chain also matches.
        cutoff : float, optional
            Distance cutoff in nm, the length unit MDTraj uses
            (default 0.45 nm = 4.5 Å).

        Returns
        -------
        neighbors : list of int
            Sorted PDB residue numbers (residue.resSeq) with a side-chain heavy
            atom within `cutoff` of the target side-chain atoms.
            (Target residues themselves are excluded from the result.)
        """
        targets = {int(r) for r in target_resids}
        if not targets:
            return []

        backbone_names = ('N', 'CA', 'C', 'O')

        # --- 1. Collect side-chain heavy atoms: all of them (haystack) and those of the targets (query) ---
        query_atoms = []
        haystack_atoms = []
        for atom in traj.topology.atoms:
            if atom.element.symbol == 'H' or atom.name in backbone_names:
                continue
            haystack_atoms.append(atom.index)
            # Select by resSeq: the result and the exclusion below are in resSeq
            # numbering too. (The original selected by the 0-based residue.index.)
            if atom.residue.resSeq in targets:
                query_atoms.append(atom.index)

        if not query_atoms:
            return []

        # --- 2. MDTraj neighbor search: one array of atom indices per frame ---
        neighbors_per_frame = md.compute_neighbors(
            traj, cutoff=cutoff,
            query_indices=query_atoms,
            haystack_indices=haystack_atoms
        )

        # --- 3. Map first-frame atom indices to residue numbers ---
        neighbor_resids = {
            traj.topology.atom(atom_idx).residue.resSeq
            for atom_idx in neighbors_per_frame[0]
        }

        # --- 4. Exclude the original target residues ---
        return sorted(neighbor_resids - targets)
    ########################################################################################################################