EntDetect 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45) hide show
  1. EntDetect/Jwalk/GridTools.py +567 -0
  2. EntDetect/Jwalk/PDBTools.py +532 -0
  3. EntDetect/Jwalk/SASDTools.py +543 -0
  4. EntDetect/Jwalk/SurfaceTools.py +150 -0
  5. EntDetect/Jwalk/__init__.py +19 -0
  6. EntDetect/Jwalk/naccess.config.txt +255 -0
  7. EntDetect/__init__.py +10 -0
  8. EntDetect/_logging.py +71 -0
  9. EntDetect/change_resolution.py +2361 -0
  10. EntDetect/clustering.py +2626 -0
  11. EntDetect/compare_sim2exp.py +1927 -0
  12. EntDetect/entanglement_features.py +478 -0
  13. EntDetect/gaussian_entanglement.py +2067 -0
  14. EntDetect/order_params.py +1048 -0
  15. EntDetect/resources/__init__.py +11 -0
  16. EntDetect/resources/__pycache__/__init__.cpython-311.pyc +0 -0
  17. EntDetect/resources/calc_K.pl +712 -0
  18. EntDetect/resources/calc_Q.pl +962 -0
  19. EntDetect/resources/pulchra +0 -0
  20. EntDetect/resources/shared_files/__init__.py +2 -0
  21. EntDetect/resources/shared_files/bt_contact_potential.dat +22 -0
  22. EntDetect/resources/shared_files/karanicolas_dihe_parm.dat +1600 -0
  23. EntDetect/resources/shared_files/kgs_contact_potential.dat +22 -0
  24. EntDetect/resources/shared_files/mj_contact_potential.dat +22 -0
  25. EntDetect/resources/stride +0 -0
  26. EntDetect/statistics.py +1344 -0
  27. EntDetect/utilities.py +201 -0
  28. entdetect-1.2.0.dist-info/METADATA +26 -0
  29. entdetect-1.2.0.dist-info/RECORD +45 -0
  30. entdetect-1.2.0.dist-info/WHEEL +5 -0
  31. entdetect-1.2.0.dist-info/entry_points.txt +11 -0
  32. entdetect-1.2.0.dist-info/licenses/LICENSE +674 -0
  33. entdetect-1.2.0.dist-info/top_level.txt +2 -0
  34. scripts/__init__.py +5 -0
  35. scripts/convert_cor_psf_to_pdb.py +103 -0
  36. scripts/run_Foldingpathway.py +162 -0
  37. scripts/run_MSM.py +152 -0
  38. scripts/run_OP_on_simulation_traj.py +194 -0
  39. scripts/run_change_resolution.py +63 -0
  40. scripts/run_compare_sim2exp.py +215 -0
  41. scripts/run_montecarlo.py +158 -0
  42. scripts/run_nativeNCLE.py +179 -0
  43. scripts/run_nonnative_entanglement_clustering.py +110 -0
  44. scripts/run_population_modeling.py +117 -0
  45. scripts/run_workflow4_nativeNCLE_batch.py +412 -0
@@ -0,0 +1,1048 @@
1
+ #!/usr/bin/env python3
2
+ import requests, logging, os, sys
3
+ from EntDetect._logging import setup_logger
4
+ import time
5
+ import argparse
6
+ import pandas as pd
7
+ import numpy as np
8
+ import glob
9
+ import MDAnalysis as mda
10
+ import mdtraj as md
11
+ import freesasa
12
+ from scipy.spatial.distance import pdist, squareform
13
+ from topoly import lasso_type # used pip
14
+ import itertools
15
+ import concurrent.futures
16
+ from EntDetect.gaussian_entanglement import GaussianEntanglement
17
+ from EntDetect.clustering import ClusterNativeEntanglements
18
+ from EntDetect.Jwalk import PDBTools, GridTools, SurfaceTools, SASDTools
19
+ from importlib.resources import files
20
+ import subprocess
21
+ import pathlib
22
+ from multiprocessing import cpu_count
23
+ from scipy.stats import norm
24
+
25
+ pd.set_option('display.max_rows', 5000)
26
+
27
+ class CalculateOP:
28
+ """
29
+ A class to handle the analysis of a C-alpha CG trajectory.
30
+ Current analysis available:
31
+ (1) - Fraction of native contacts (Q)
32
+ (2) - Fraction of native contacts with a change in entanglement (G)
33
+ (3) - Solvent Accessible Surface Area (SASA)
34
+ (4) - Mirror symmetry order parameter (K)
35
+ (5) - Cross linking probability score (XP)
36
+ use_traj=False (default): single static PDB
37
+ use_traj=True: per-frame from DCD, respects self.start/end/stride
38
+ (6) - Jwalk SASD
39
+ """
40
+ #######################################################################################
41
def __init__(self, outdir:str='./', ID:str='', Traj:int=1, psf:str='', cor:str='', dcd:str='', sec_elements:str='', domain:str='', start:int=0, end:int=99999999999999, stride:int=1, ent_detection_method:int=2, log_level:int=logging.INFO, logdir:str=None):
    """
    Store the analysis inputs, create the logger and load the MDAnalysis universes.

    Parameters:
        outdir: directory under which each analysis creates its own output sub-directory
        ID: base name used when naming output files
        Traj: trajectory index used when naming output files
        psf: topology file handed to MDAnalysis together with `cor` and `dcd`
        cor: reference coordinates; the reference universe is only built when this path ends with '.cor'
        dcd: trajectory file loaded as the trajectory universe
        sec_elements: path to the secondary structure elements file
        domain: path to the domain definition file
        start: first frame of the slice applied when iterating the trajectory
        end: end of that slice (the default is simply a very large number, i.e. "to the last frame")
        stride: step of that slice
        ent_detection_method: forwarded to GaussianEntanglement by G()
        log_level: level passed to setup_logger
        logdir: directory handed to setup_logger; `outdir` is used when this is None
    """
    # the log goes to logdir when one was given, otherwise next to the results
    log_destination = outdir if logdir is None else logdir
    self.logger = setup_logger('CalculateOP', outdir=log_destination, ID=ID, log_level=log_level)

    # plain inputs: store each one and echo it to the debug log, in a fixed order
    plain_inputs = (
        ('outdir', outdir),
        ('ID', ID),
        ('Traj', Traj),
        ('psf', psf),
        ('sec_elements', sec_elements),
        ('domain', domain),
        ('cor', cor),
        ('dcd', dcd),
    )
    for attr_name, attr_value in plain_inputs:
        setattr(self, attr_name, attr_value)
        self.logger.debug(f'{attr_name}: {attr_value}')

    # the reference universe only exists when a CHARMM .cor file was supplied
    if self.cor not in (None, '') and self.cor.endswith('.cor'):
        self.ref_universe = mda.Universe(self.psf, self.cor, format='CRD')
        self.logger.debug(f'ref_universe: {self.ref_universe}')

    self.traj_universe = mda.Universe(self.psf, self.dcd, format='DCD')
    self.logger.debug(f'traj_universe: {self.traj_universe}')

    # frame selection shared by the per-frame analyses
    self.start, self.end, self.stride = start, end, stride
    self.logger.debug(f'START: {self.start} | END: {self.end} | STRIDE: {self.stride}')

    self.ent_detection_method = ent_detection_method
    self.logger.debug(f'ent_detection_method: {self.ent_detection_method}')

    # three-letter -> one-letter residue codes (used when scoring cross-links)
    standard_codes = {
        "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D",
        "CYS": "C", "GLN": "Q", "GLU": "E", "GLY": "G",
        "HIS": "H", "ILE": "I", "LEU": "L", "LYS": "K",
        "MET": "M", "PHE": "F", "PRO": "P", "SER": "S",
        "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V",
    }
    special_codes = {
        "SEC": "U",  # Selenocysteine
        "PYL": "O",  # Pyrrolysine
        "ASX": "B",  # Asp or Asn (ambiguous)
        "GLX": "Z",  # Glu or Gln (ambiguous)
        "XAA": "X",  # Any/unknown amino acid
        "TER": "*",  # Stop codon
    }
    self.three_to_one = {**standard_codes, **special_codes}
112
+ #######################################################################################
113
+
114
+ #######################################################################################
115
def Qpy(self, ):
    """
    Calculate the fraction of native contacts (Q) in each analysed frame of the trajectory.

    A native contact is a residue pair (i, j) where
      * both residues belong to a secondary structure element listed in `self.sec_elements`,
      * j - i >= 4 (so for i = 1 the closest possible partner is j = 5), and
      * the distance in the reference structure is non-zero and at most 8 A.
    The contact is counted as formed in a frame when the frame distance is
    at most 1.2 * the reference distance.

    Frames are taken from `self.traj_universe.trajectory[self.start:self.end:self.stride]`.

    Returns:
        dict with
          'outfile': path of the CSV written to <outdir>/Q/<ID>.Q
          'result' : DataFrame with columns Time(ns), Frame, FrameNumNativeContacts, Q
    """
    self.logger.info(f'Calculating the fraction of native contacts (Q)')

    # make directory for Q data if it doesnt exist
    self.Qpath = os.path.join(self.outdir, 'Q')
    if not os.path.exists(self.Qpath):
        os.makedirs(self.Qpath)
        self.logger.info(f'Made directory: {self.Qpath}')

    # Step 0: load the reference coordinates
    ref_coor = self.ref_universe.atoms.positions
    n_residues = len(ref_coor)

    # Step 1: Get the secondary structure information
    self.logger.info(f'Step 1: Get the secondary structure information')
    # ndmin=2 keeps a file with a single element as one row instead of a flat vector
    sec_rows = np.loadtxt(self.sec_elements, dtype=int, ndmin=2)
    # columns 1 and 2 of every row are the first and last residue ID of the element (inclusive)
    resid_in_sec_elements = np.hstack([np.arange(row[1], row[2] + 1) for row in sec_rows])
    # residue IDs (1-based) that are NOT part of any secondary structure element;
    # setdiff1d stays an integer array even when the result is empty, so it is always a valid index
    resid_not_in_sec_elements = np.setdiff1d(np.arange(1, n_residues + 1), resid_in_sec_elements)

    # Step 2: Get the native distance map for the native state cordinates
    self.logger.info(f'Step 2: Get the native distance map for the native state cordinates')
    # keep only pairs with j - i >= 4 ...
    ref_distances = np.triu(squareform(pdist(ref_coor)), k=4)
    # ... drop every pair involving a residue outside the secondary structure elements ...
    ref_distances[resid_not_in_sec_elements - 1, :] = 0
    ref_distances[:, resid_not_in_sec_elements - 1] = 0
    # ... and drop every pair further apart than 8 A
    ref_distances[ref_distances > 8] = 0

    # both are frame-invariant, so build them once
    native_mask = ref_distances != 0
    formed_cutoff = 1.2 * ref_distances

    NumNativeContacts = int(np.count_nonzero(native_mask))
    self.logger.debug(f'NumNativeContacts: {NumNativeContacts}')
    if NumNativeContacts == 0:
        self.logger.warning('No native contacts were found in the reference structure; Q will be NaN for every frame')

    # Step 3: Analyze each frame of the traj_universe and get the distance map
    self.logger.info(f'Step 3: Analyze each frame of the traj_universe and calc Q')
    Qoutput = {'Time(ns)':[], 'Frame':[], 'FrameNumNativeContacts':[], 'Q':[]}
    for ts in self.traj_universe.trajectory[self.start:self.end:self.stride]:
        frame_distances = squareform(pdist(self.traj_universe.atoms.positions))

        # only native pairs can count, so the frame map needs no masking of its own
        FrameNumNativeContacts = int(np.count_nonzero((frame_distances <= formed_cutoff) & native_mask))
        Q = FrameNumNativeContacts / NumNativeContacts if NumNativeContacts else float('nan')

        Qoutput['Time(ns)'].append(ts.time/1000)
        Qoutput['Frame'].append(ts.frame)
        Qoutput['FrameNumNativeContacts'].append(FrameNumNativeContacts)
        Qoutput['Q'].append(Q)

    # Step 4: save Q output
    self.logger.info(f'Step 4: save Q output')
    Qoutput = pd.DataFrame(Qoutput)
    Qoutfile = os.path.join(self.Qpath, f'{self.ID}.Q')
    Qoutput.to_csv(Qoutfile, index=False)
    self.logger.info(f'SAVED: {Qoutfile}')
    return {'outfile':Qoutfile, 'result':Qoutput}
190
+ #######################################################################################
191
+
192
+ #######################################################################################
193
def Q(self,):
    """
    Calculate the fraction of native contacts (Q) with the bundled Perl script calc_Q.pl.

    Unlike Qpy, the Perl script also receives the domain definitions, so besides the
    overall Q it reports Q within each domain and between domains.

    If <outdir>/Q/<ID>_Traj<Traj>.Q already exists it is loaded and returned as is.
    Otherwise the script is run, its output (Q_<dcdname>.dat) is renamed to that path,
    a 'Frame' column is added and the table is rewritten as CSV.

    Side effect: a negative `self.start` is replaced by the equivalent non-negative frame index.

    Returns:
        dict with 'outfile' (path of the .Q file) and 'result' (DataFrame)

    Raises:
        RuntimeError: if perl cannot be started or the script exits with a non-zero status
    """
    # make directory for Q data if it doesnt exist
    self.Qpath = os.path.join(self.outdir, 'Q')
    if not os.path.exists(self.Qpath):
        os.makedirs(self.Qpath)
        self.logger.info(f'Made directory: {self.Qpath}')

    # the Perl script names its output after the DCD file: everything before the first '.'
    dcdname = self.dcd.split('/')[-1].split('.')[0]
    self.logger.debug(f'dcdname: {dcdname}')
    outfilename = os.path.join(self.Qpath, f'Q_{dcdname}.dat')
    self.logger.debug(f'outfilename: {outfilename}')
    # standard OP file name used once the Frame column has been added
    renamed_outfile = os.path.join(self.Qpath, f'{self.ID}_Traj{self.Traj}.Q')

    u = mda.Universe(self.psf, self.dcd)
    self.logger.debug(u)
    # frame numbers are the 0-based positions in the trajectory, so there is no need to read every frame
    frames = list(range(len(u.trajectory)))
    if self.start < 0:
        self.start = frames[self.start]
    self.logger.debug(f'START: {self.start}')

    if os.path.exists(renamed_outfile):
        self.logger.info(f'Q outfile exists: {renamed_outfile}')
        Qoutput = pd.read_csv(renamed_outfile, sep = ',')

    else:
        script_path = files('EntDetect.resources').joinpath('calc_Q.pl')
        self.logger.debug(f'script_path: {script_path}')

        # argument list instead of a shell string: paths with spaces or shell
        # metacharacters are passed through untouched.
        # -b is start + 1 because the script counts frames from 1.
        cmd = [
            'perl', str(script_path),
            '-i', str(self.cor),
            '-t', str(self.dcd),
            '-d', str(self.domain),
            '-s', str(self.sec_elements),
            '-b', str(self.start + 1),
            '-e', str(self.end),
            '-o', str(self.Qpath),
        ]
        self.logger.debug(f'cmd: {" ".join(cmd)}')

        try:
            result = subprocess.run(cmd, capture_output=True, text=True)
        except FileNotFoundError as err:
            # perl itself is missing; keep the exception type callers already handle
            raise RuntimeError(f"Perl script failed:\n{err}") from err
        if result.returncode != 0:
            raise RuntimeError(f"Perl script failed:\n{result.stderr}")
        self.logger.debug(result.stdout)

        ## rename the file to match the standard OP file format {ID}_Traj{traj}.Q
        os.rename(outfilename, renamed_outfile)
        self.logger.debug(f'Renamed: {outfilename} -> {renamed_outfile}')

        ## read the whitespace-separated Q file back in and add the Frame column
        Qoutput = pd.read_csv(renamed_outfile, sep=r'\s+')
        # NOTE(review): self.stride is not applied here; the script is only given -b/-e
        Qoutput['Frame'] = frames[self.start:self.end]

        Qoutput.to_csv(renamed_outfile, index=False, sep = ',')
        self.logger.info(f'SAVED: {renamed_outfile}')

    return {'outfile':renamed_outfile, 'result':Qoutput}
254
+ #######################################################################################
255
+
256
+ #######################################################################################
257
def G(self, topoly:bool=True, Calpha:bool=True, CG:bool=True, nproc: int = 10, chunk_frames:int=None, chunk_suffix:str='_chunk', g_threshold:float=0.6, density:float=1.0) -> dict:
    """
    Calculate the G entanglement order parameter for the frames of the DCD.

    Pipeline (all outputs live under <outdir>/G):
      1. native entanglements of `self.cor`            -> Native_GE/
      2. clustering of the native entanglements        -> Native_clustered_GE/
      3. entanglements of the trajectory frames        -> Traj_GE/
      4. combined native + trajectory data (gives G)   -> Combined_GE/
    The G table is finally written to <outdir>/G/<ID>_Traj<Traj>.G

    Parameters:
        topoly: forwarded to the native and trajectory entanglement calculations
        Calpha, CG, nproc: forwarded to GaussianEntanglement
        chunk_frames, chunk_suffix: forwarded to GaussianEntanglement.combine_ref_traj_GE
        g_threshold: forwarded to GaussianEntanglement (default 0.6, the value that used to be hard-coded)
        density: forwarded to GaussianEntanglement (default 1.0, the value that used to be hard-coded)

    Returns:
        dict with 'outfile' (path of the .G file) and 'result' (the 'G' entry of the combined data)
    """
    self.logger.info(f'Calculating the G entanglement order parameter')

    # make directory for G data if it doesnt exist
    self.Gpath = os.path.join(self.outdir, 'G')
    if not os.path.exists(self.Gpath):
        os.makedirs(self.Gpath)
        self.logger.info(f'Made directory: {self.Gpath}')

    self.logger.debug(f'g_threshold: {g_threshold}')
    self.logger.debug(f'density: {density}')
    self.logger.debug(f'Calpha: {Calpha}')
    self.logger.debug(f'CG: {CG}')
    self.logger.debug(f'nproc: {nproc}')
    self.logger.debug(f'chunk_frames: {chunk_frames}')
    self.logger.debug(f'chunk_suffix: {chunk_suffix}')

    ## initialize the entanglement object
    ge = GaussianEntanglement(
        g_threshold=g_threshold,
        density=density,
        Calpha=Calpha,
        CG=CG,
        nproc=nproc,
        ent_detection_method=self.ent_detection_method,
    )
    self.logger.debug(ge)

    ## initialize the clustering object
    clustering = ClusterNativeEntanglements(organism='Ecoli')
    self.logger.debug(clustering)

    ## Get the native entanglements from the reference structure
    self.logger.info(f'Calculating the native entanglements...')
    NativeEnt = ge.calculate_native_entanglements(self.cor, outdir=os.path.join(self.Gpath,'Native_GE/'), ID=f'{self.ID}_native', topoly=topoly)

    ## Cluster the native entanglements (the return value is not needed here; the call is kept as before)
    self.logger.info(f'Clustering the native entanglements...')
    clustering.Cluster_NativeEntanglements(NativeEnt['outfile'], outdir=os.path.join(self.Gpath,'Native_clustered_GE/'), outfile=f'{self.ID}_NativeEntClusters.txt')

    ## Get the trajectory entanglements
    # NOTE(review): only start/stop are forwarded; self.stride is not used by this analysis
    self.logger.info(f'Calculating the trajectory entanglements...')
    TrajEnt = ge.calculate_traj_entanglements(
        self.dcd,
        self.psf,
        outdir=os.path.join(self.Gpath, 'Traj_GE/'),
        ID=f'{self.ID}_traj{self.Traj}',
        start=self.start,
        stop=self.end,
        topoly=topoly,
        ref_contact_file=NativeEnt['outfile'],
    )

    ## Create the combined .pkl file required for clustering non-native entanglements
    ## Will also calculate G at the same time
    self.logger.info(f'Creating the combined .pkl file required for clustering non-native entanglements...')
    Combined_data = ge.combine_ref_traj_GE(NativeEnt['outfile'], TrajEnt['outfile'], outdir=os.path.join(self.Gpath,'Combined_GE/'), ID=f'{self.ID}_traj{self.Traj}', chunk_frames=chunk_frames, chunk_suffix=chunk_suffix)
    G = Combined_data['G']
    Goutfile = os.path.join(self.Gpath, f'{self.ID}_Traj{self.Traj}.G')
    G.to_csv(Goutfile, index=False)
    self.logger.info(f'SAVED: {Goutfile}')
    return {'outfile':Goutfile, 'result':G}
329
+ #######################################################################################
330
+
331
+ #######################################################################################
332
def SASA(self,) -> dict:
    """
    Calculate the per-residue solvent accessible surface area (SASA) of trajectory frames with freesasa.

    The DCD is loaded with mdtraj and the frames [self.start : self.end : self.stride]
    are processed one by one: each frame is written to a temporary PDB, handed to
    freesasa, and the per-residue totals are converted from A^2 to nm^2.

    A frame that contains NaN coordinates re-uses the SASA values of the last valid
    frame; if there is no valid frame yet, the frame is skipped with an error log.

    The 'Frame' column holds the index of the frame in the trajectory, so it
    stays correct when stride > 1 or start < 0.

    Returns:
        dict with
          'outfile': path of the CSV written to <outdir>/SASA/<ID>_Traj<Traj>.SASA
          'result' : DataFrame with columns Time(ns), Frame, resid, SASA(nm^2)
    """
    import tempfile

    # make directory for SASA data if it doesnt exist
    self.SASAPATH = os.path.join(self.outdir, 'SASA')
    if not os.path.exists(self.SASAPATH):
        os.makedirs(self.SASAPATH)
        self.logger.info(f'Made directory: {self.SASAPATH}')

    # residue IDs of the trajectory universe, used to label the freesasa results by position
    resids = self.traj_universe.atoms.residues.resids

    # Step 0: load the dcd and psf into a mdtraj trajectory for frame iteration
    traj = md.load(self.dcd, top=self.psf)
    total_frames = len(traj)
    self.logger.debug(f'total_frames: {total_frames}')

    # trajectory indices of the selected frames; slicing a range follows the same
    # rules as slicing the trajectory, so the two stay aligned for any start/end/stride
    frame_numbers = range(total_frames)[self.start:self.end:self.stride]
    if len(frame_numbers) > 0:
        self.logger.debug(f'frame_number: {frame_numbers[0]}')

    SASAoutput = {'Time(ns)':[], 'Frame':[], 'resid':[], 'SASA(nm^2)':[]}

    def _store(frame_time, frame_number, sasa_values):
        # append one row per residue for this frame
        for resididx, res_sasa in enumerate(sasa_values):
            SASAoutput['Time(ns)'].append(frame_time)
            SASAoutput['Frame'].append(frame_number)
            SASAoutput['resid'].append(resids[resididx])
            SASAoutput['SASA(nm^2)'].append(res_sasa)

    # Step 1: loop through the trajectory and calculate the SASA for each frame using freesasa
    self.logger.info(f'Step 1: loop through the trajectory and calculate the SASA for each frame using freesasa')
    last_valid_sasa = None  # most recent successful result, used as fallback for NaN frames

    for frame_number, ts in zip(frame_numbers, traj[self.start:self.end:self.stride]):
        frame_time = ts.time[0]/1000

        # ts is a single-frame trajectory, so xyz[0] is the (n_atoms, 3) coordinate array
        positions = ts.xyz[0]
        nan_per_atom = np.isnan(positions).any(axis=1)
        if nan_per_atom.any():
            self.logger.warning(f'Frame {frame_number} has NaN coordinates in {int(nan_per_atom.sum())} atoms. Using SASA from previous frame.')
            if last_valid_sasa is not None:
                _store(frame_time, frame_number, last_valid_sasa)
            else:
                self.logger.error(f'Frame {frame_number} has NaN coordinates but no previous valid SASA to fall back to. Skipping this frame.')
            continue

        # freesasa reads from a file, so write the frame to a temporary PDB
        with tempfile.NamedTemporaryFile(suffix='.pdb', delete=False) as tmp:
            tmp_pdb = tmp.name
        try:
            ts.save_pdb(tmp_pdb)
            structure = freesasa.Structure(tmp_pdb)
            result = freesasa.calc(structure)
        finally:
            # Clean up temporary file
            if os.path.exists(tmp_pdb):
                os.remove(tmp_pdb)

        # residueAreas() returns nested dict: {chain: {res_num: ResidueArea_object}}
        res_areas = result.residueAreas()
        # collect the total area of every residue, chains in sorted order and
        # residues in numeric order within each chain
        sasa_per_residue = [
            res_areas[chain][res_num].total
            for chain in sorted(res_areas.keys())
            for res_num in sorted(res_areas[chain].keys(), key=lambda x: int(x))
        ]

        # Convert from Angstroms^2 to nm^2 (1 nm^2 = 100 Angstroms^2)
        sasa_per_residue = np.array(sasa_per_residue) / 100.0
        last_valid_sasa = sasa_per_residue

        _store(frame_time, frame_number, sasa_per_residue)

    # Step 2: save the SASA output
    self.logger.info(f'Step 2: save the SASA output')
    SASAoutput = pd.DataFrame(SASAoutput)
    self.logger.info(f'SASAoutput:\n{SASAoutput}')
    SASAoutfile = os.path.join(self.SASAPATH, f'{self.ID}_Traj{self.Traj}.SASA')
    SASAoutput.to_csv(SASAoutfile, index=False)
    self.logger.info(f'SAVED: {SASAoutfile}')

    return {'outfile':SASAoutfile, 'result':SASAoutput}
443
+ #######################################################################################
444
+
445
+ #######################################################################################
446
def K(self,) -> dict:
    """
    Calculate the mirror symmetry order parameter K with the bundled Perl script calc_K.pl.

    The script writes K_<dcdname>.dat into <outdir>/K, where <dcdname> is the DCD
    file name up to its first '.'. That file is read back as a whitespace-separated table.

    Returns:
        dict with 'outfile' (path of the .dat file) and 'result' (DataFrame)

    Raises:
        RuntimeError: if perl cannot be started or the script exits with a non-zero status
        FileNotFoundError: if the script ran but the expected output file is missing
    """
    # make directory for K data if it doesnt exist
    self.KPATH = os.path.join(self.outdir, 'K')
    if not os.path.exists(self.KPATH):
        os.makedirs(self.KPATH)
        self.logger.info(f'Made directory: {self.KPATH}')

    script_path = files('EntDetect.resources').joinpath('calc_K.pl')

    # argument list instead of a shell string: paths with spaces or shell
    # metacharacters are passed through untouched.
    # NOTE(review): Q() passes start + 1 as -b to calc_Q.pl while this passes start
    # unchanged — confirm which frame numbering calc_K.pl expects.
    cmd = [
        'perl', str(script_path),
        '-i', str(self.cor),
        '-t', str(self.dcd),
        '-d', str(self.domain),
        '-s', str(self.sec_elements),
        '-b', str(self.start),
        '-e', str(self.end),
        '-o', str(self.KPATH),
    ]

    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except FileNotFoundError as err:
        # perl itself is missing; keep the exception type callers already handle
        raise RuntimeError(f"Perl script failed:\n{err}") from err
    if result.returncode != 0:
        raise RuntimeError(f"Perl script failed:\n{result.stderr}")

    ## outfile will follow the following format K_{name}.dat where name is the name of the DCD read in
    dcdname = self.dcd.split('/')[-1].split('.')[0]
    outfilename = os.path.join(self.KPATH, f'K_{dcdname}.dat')

    if not os.path.exists(outfilename):
        self.logger.info(f'K outfile does not exist: {outfilename}')
        raise FileNotFoundError(f'K outfile does not exist: {outfilename}')

    self.logger.info(f'K outfile exists: {outfilename}')
    Koutput = pd.read_csv(outfilename, sep=r'\s+')
    self.logger.info(f'Koutput:\n{Koutput}')
    return {'outfile':outfilename, 'result':Koutput}
481
+ #######################################################################################
482
+
483
+ #######################################################################################
484
def XP(self, pdb:str='None', use_traj:bool=False, nproc:int=1) -> dict:
    """
    Calculate the cross-linking probability (XP) for the residue pairs written by find_residue_pairs.

    use_traj=False (default):
        Runs Jwalk on the single static PDB supplied as `pdb` (skipped when the Jwalk
        output already exists), scores every cross-link and rewrites the Jwalk table,
        tab-separated, with an added 'XP' column.
        Output: XP/Jwalk_results_{ID}_Traj{N}/Jwalk_results/{stem}_crosslink_list.txt

    use_traj=True:
        Iterates over the DCD frames [self.start : self.end : self.stride]. For each frame a
        per-frame PDB is written, Jwalk is run, XP is scored, and the per-frame PDB is deleted.
        A frame with NaN coordinates re-uses the table of the last valid frame (or is skipped
        when there is none yet). Produces a single combined tab-separated file
            XP/{ID}_Traj{N}.XP
        with columns: Frame | Index | Model | Atom1 | Atom2 | SASD | Euclidean Distance | XP
        If that file already exists it is loaded and returned without recomputing.

    Parameters:
        pdb: PDB file used to find the residue pairs (and, in static mode, the structure analysed)
        use_traj: select trajectory mode
        nproc: passed to Jwalk as `ncpus`. Frame-level threading is controlled by the
            internal `traj_nproc` constant, which is fixed to 1.

    Returns:
        dict with 'outfile' (path of the written table) and 'result' (DataFrame)

    Raises:
        SystemExit: static mode, when `pdb` does not exist (exit status 2)
        ValueError: trajectory mode, when no frame produced a result
    """
    # number of worker threads for trajectory mode; kept at 1 because Jwalk is not thread safe
    traj_nproc = 1

    # make output directory
    self.XPpath = os.path.join(self.outdir, 'XP')
    if not os.path.exists(self.XPpath):
        os.makedirs(self.XPpath)
        self.logger.info(f'Made directory: {self.XPpath}')

    col_names = ["Index", "Model", "Atom1", "Atom2", "SASD", "Euclidean Distance"]

    def _read_jwalk(path):
        # the header row is skipped and replaced by col_names because
        # "Euclidean Distance" contains a space and would split into two columns
        return pd.read_csv(path, sep=r'\s+', names=col_names,
                           skiprows=1, engine='python', index_col=False)

    def _score_rows(jwalk_df):
        # one XP score per cross-link; the residue name is the first three
        # characters of the part of Atom1/Atom2 before the first '-'
        return [
            self.score_XL(
                (self.three_to_one[row['Atom1'].split('-')[0][0:3]],
                 self.three_to_one[row['Atom2'].split('-')[0][0:3]]),
                row['SASD'],
            )
            for _, row in jwalk_df.iterrows()
        ]

    if not use_traj:
        # ── single-PDB path ───────────────────────────────────────────────
        # validate the input first so a missing file gives this message rather
        # than an error from inside find_residue_pairs
        pdbObj = pathlib.Path(pdb)
        if not pdbObj.exists():
            self.logger.error(f'ERROR: The input file supplied cannot be found. Please enter a .pdb file type')
            sys.exit(2)

        xl_list = os.path.join(self.XPpath, f'{self.ID}_Traj{self.Traj}_XLresidue_pairs.txt')
        self.find_residue_pairs(pdb, output_file=xl_list)

        jwalk_results_dir = os.path.join(self.XPpath, f'Jwalk_results_{self.ID}_Traj{self.Traj}')
        Jwalk_outfile = os.path.join(jwalk_results_dir, 'Jwalk_results', f'{pdbObj.stem}_crosslink_list.txt')
        if os.path.exists(Jwalk_outfile):
            self.logger.info(f'Jwalk outfile exists: {Jwalk_outfile}')
        else:
            self.runJwalk(pdb, xl_list=xl_list, max_dist=50.0,
                          jwalk_results_dir=jwalk_results_dir, vox=1, ncpus=nproc)
            self.logger.debug('Jwalk calculated')

        Jwalk_df = _read_jwalk(Jwalk_outfile)
        Jwalk_df['XP'] = _score_rows(Jwalk_df)
        Jwalk_df.to_csv(Jwalk_outfile, index=False, sep='\t')
        self.logger.info(f'SAVED: {Jwalk_outfile}')
        return {'outfile': Jwalk_outfile, 'result': Jwalk_df}

    # ── trajectory mode ───────────────────────────────────────────────────

    # skip-if-exists guard
    combined_outfile = os.path.join(self.XPpath, f'{self.ID}_Traj{self.Traj}.XP')
    if os.path.exists(combined_outfile):
        self.logger.info(f'XP outfile exists, loading: {combined_outfile}')
        return {'outfile': combined_outfile,
                'result': pd.read_csv(combined_outfile, sep='\t')}

    # compute residue pairs once from the reference PDB (topology is frame-invariant)
    xl_list = os.path.join(self.XPpath, f'{self.ID}_Traj{self.Traj}_XLresidue_pairs.txt')
    self.find_residue_pairs(pdb, output_file=xl_list)

    # directory for the short-lived per-frame PDB files
    frames_dir = os.path.join(self.XPpath, f'frames_Traj{self.Traj}')
    os.makedirs(frames_dir, exist_ok=True)

    # parent directory for per-frame Jwalk outputs
    # runJwalk uses os.mkdir so the parent must already exist
    jwalk_base_dir = os.path.join(self.XPpath, f'Jwalk_results_{self.ID}_Traj{self.Traj}')
    os.makedirs(jwalk_base_dir, exist_ok=True)

    def _run_frame(task):
        # per-frame worker: run Jwalk (unless its output exists), score XP, delete the frame PDB
        frame_idx, frame_pdb, frame_jwalk_dir = task
        pdb_stem = pathlib.Path(frame_pdb).stem
        jwalk_outfile = os.path.join(frame_jwalk_dir, 'Jwalk_results',
                                     f'{pdb_stem}_crosslink_list.txt')
        self.logger.info(f'Processing frame {frame_idx} | PDB: {frame_pdb} | Jwalk out: {jwalk_outfile}')

        if not os.path.exists(jwalk_outfile):
            self.runJwalk(frame_pdb, xl_list=xl_list, max_dist=50.0,
                          jwalk_results_dir=frame_jwalk_dir, vox=1, ncpus=nproc)
        self.logger.info(f'Jwalk completed for frame {frame_idx}')

        frame_df = _read_jwalk(jwalk_outfile)
        self.logger.info(f'Jwalk results loaded for frame {frame_idx}, calculating XP...')

        frame_df['XP'] = _score_rows(frame_df)
        frame_df['Frame'] = frame_idx
        # delete per-frame PDB immediately after use
        if os.path.exists(frame_pdb):
            os.remove(frame_pdb)
        self.logger.debug(f'Frame {frame_idx}: XP computed, per-frame PDB removed')
        return frame_df

    frame_tasks = []
    results = []
    last_valid_frame_result = None  # most recent scored frame, used as fallback for NaN frames

    for ts in self.traj_universe.trajectory[self.start:self.end:self.stride]:
        frame_idx = ts.frame

        # Check for NaN coordinates before attempting to write PDB
        positions = self.traj_universe.atoms.positions
        if np.any(np.isnan(positions)):
            nan_atoms = np.where(np.any(np.isnan(positions), axis=1))[0]
            self.logger.warning(f'Frame {frame_idx} has NaN coordinates in {len(nan_atoms)} atoms. Using XP from previous frame.')

            if last_valid_frame_result is not None:
                # Copy previous frame's results but update frame number
                fallback_frame_result = last_valid_frame_result.copy()
                fallback_frame_result['Frame'] = frame_idx
                results.append(fallback_frame_result)
                self.logger.debug(f'Frame {frame_idx}: Used fallback XP from previous frame')
            else:
                self.logger.error(f'Frame {frame_idx} has NaN coordinates but no previous valid XP to fall back to. Skipping this frame.')
            continue

        frame_pdb = os.path.join(frames_dir, f'frame_{frame_idx}.pdb')
        with mda.Writer(frame_pdb, self.traj_universe.atoms.n_atoms) as W:
            W.write(self.traj_universe.atoms)

        task = (frame_idx, frame_pdb, os.path.join(jwalk_base_dir, f'frame_{frame_idx}'))
        if traj_nproc > 1:
            # parallel: write all frame PDBs first, then process with ThreadPoolExecutor
            # (trajectory iteration must be sequential; Jwalk runs are independent).
            # No frame is scored inside this loop, so NaN frames have nothing to fall back on.
            frame_tasks.append(task)
        else:
            # sequential: write PDB → run Jwalk → score → delete, one frame at a time
            frame_df = _run_frame(task)
            results.append(frame_df)
            last_valid_frame_result = frame_df

    if traj_nproc > 1:
        self.logger.info(f'Running frame-level Jwalk in parallel with {traj_nproc} workers...')
        with concurrent.futures.ThreadPoolExecutor(max_workers=traj_nproc) as executor:
            # extend (not replace) so anything already collected in the loop is kept
            results.extend(executor.map(_run_frame, frame_tasks))
    else:
        self.logger.info(f'Jwalk run completed for all frames in sequential mode.')

    if not results:
        # pd.concat would raise an opaque ValueError on an empty list
        raise ValueError(f'No frames were processed for XP (start={self.start}, end={self.end}, stride={self.stride})')

    combined_df = pd.concat(results, ignore_index=True)
    # put the Frame column first
    combined_df = combined_df[['Frame'] + [c for c in combined_df.columns if c != 'Frame']]
    combined_df.to_csv(combined_outfile, index=False, sep='\t')
    self.logger.info(f'SAVED: {combined_outfile}')
    return {'outfile': combined_outfile, 'result': combined_df}
652
+ #######################################################################################
653
+
654
+ #######################################################################################
655
def find_residue_pairs(self, pdb_path, output_file="XLresidue_pairs.txt"):
    """
    Enumerate every unordered pair of LYS/SER/THR/TYR/MET residues in a PDB file.

    Two files are written:

    * ``output_file`` -- one line per pair, formatted as
      ``resnum1|chain1|resnum2|chain2|``.
    * a companion CSV whose name is ``output_file`` with its extension
      replaced by ``_Full.csv`` -- the same pairs plus the residue names
      (columns resid1, resname1, chain1, resid2, resname2, chain2).

    Parameters
    ----------
    pdb_path : path of the structure file, loaded with ``mda.Universe``.
    output_file : destination of the pipe-delimited pair list.

    Returns
    -------
    None
    """
    u = mda.Universe(pdb_path)

    # Three-letter residue names eligible for pairing (K, S, T, Y, M).
    aa_of_interest = ('LYS', 'SER', 'THR', 'TYR', 'MET')

    # Select relevant residues
    resname_clause = " or ".join(f"resname {aa}" for aa in aa_of_interest)
    residues = u.select_atoms(f"protein and ({resname_clause})").residues

    # Create all unique, unordered pairs (no double-counting)
    pairs = list(itertools.combinations(residues, 2))

    def _chain_of(res):
        # Prefer the segment id; fall back to the chain attribute when it is empty.
        return res.segid or res.chain

    rows = []
    with open(output_file, "w") as f:
        for res1, res2 in pairs:
            chain1, chain2 = _chain_of(res1), _chain_of(res2)
            f.write(f"{res1.resid}|{chain1}|{res2.resid}|{chain2}|\n")
            rows.append((res1.resid, res1.resname, chain1,
                         res2.resid, res2.resname, chain2))

    full_pairs_df = pd.DataFrame(
        rows,
        columns=['resid1', 'resname1', 'chain1', 'resid2', 'resname2', 'chain2'],
    )

    # Derive the CSV name from the extension only. A plain str.replace('.txt', ...)
    # leaves names without '.txt' unchanged, so the CSV would overwrite the pair
    # list written above, and it also rewrites any '.txt' earlier in the path.
    full_csv = os.path.splitext(output_file)[0] + '_Full.csv'
    full_pairs_df.to_csv(full_csv, index=False)
    self.logger.info(f"Residue pairs saved to '{full_csv}'")

    self.logger.info(f"Found {len(pairs)} residue pairs and wrote to '{output_file}'")
692
+ #######################################################################################
693
+
694
+ #######################################################################################
695
def score_XL(self, pair_AA, JWalk_dist, XL_offset:float=1.1):
    """
    Score a candidate cross-link from its Jwalk distance and residue types.

    A Lys-Lys reference normal distribution (mean 18.6, sigma 6.0, cutoff 33)
    is first shifted/widened by ``XL_offset`` and then adapted to the residue
    pair using the per-residue side-chain lengths tabulated below.

    Parameters
    ----------
    pair_AA : indexable of two one-letter codes taken from K, S, T, Y, M
        (e.g. ``'KS'``); any other code raises ``KeyError``.
    JWalk_dist : Jwalk distance; ``-1`` is treated as "no path".
    XL_offset : offset added to the reference mean and cutoff.

    Returns
    -------
    ``0`` when the distance is ``-1`` or is not within the pair-specific
    cutoff, otherwise the value of the normal pdf at ``JWalk_dist``.
    """
    side_chain = {'K': 6.3, 'S': 2.5, 'T': 2.5, 'Y': 6.5, 'M': 1.5}

    # Lys-Lys reference parameters after applying the offset
    ref_mu = 18.6 + XL_offset
    ref_sigma = (XL_offset + 3*6.0) / 3
    ref_cutoff = 33 + XL_offset

    # Pair-specific parameters: shift the mean by the side-chain length
    # difference relative to Lys-Lys while keeping the lower 3-sigma bound fixed.
    first, second = pair_AA[0], pair_AA[1]
    mu = ref_mu + (side_chain[first] + side_chain[second]) - 2*side_chain['K']
    sigma = (mu - (ref_mu - 3*ref_sigma)) / 3
    threshold = ref_cutoff + mu - ref_mu

    if JWalk_dist == -1 or not (JWalk_dist <= threshold):
        return 0
    return norm(mu, sigma).pdf(JWalk_dist)
726
+ #######################################################################################
727
+
728
+ #######################################################################################
729
def runJwalk(self, pdb, xl_list:str='NULL', aa1:str='LYS', aa2:str='LYS', max_dist:float=50.0, jwalk_results_dir:str='./', vox:int=1, ncpus:int=1):
    """
    Execute Jwalk with processed command line options

    pdb: Input path to .pdb file
    xl_list: OPTIONAL - Input path to crosslink list (default: 'NULL', i.e. find all aa1-to-aa2 crosslinks)
    aa1: OPTIONAL - Specify inital crosslink amino acid three letter code (default: LYS)
    aa2: OPTIONAL - Specify ending crosslink amino acid three letter code (default: LYS)
    max_dist: OPTIONAL - Specify maximum crosslink distance cutoff in Angstroms (default: 50.0 Angstroms)
    jwalk_results_dir: OPTIONAL - Output path for Jwalk results (default: "./", the current working directory);
                       created, including missing parent directories, when it does not exist
    vox: OPTIONAL - Specify voxel resolution to use in Angstrom (default: 1 Angstrom)
    ncpus: OPTIONAL - Specify number of cpus to use (default: 1)

    Exits the interpreter with status 2 (sys.exit) when the pdb file is missing
    or not a .pdb file, when aa1/aa2 is not a known three letter code, or when
    xl_list is given but is not an existing file.

    J.Bullock, J. Schwab, K. Thalassinos, M. Topf (2016)
    The importance of non-accessible crosslinks and solvent accessible surface distance
    in modelling proteins with restraints from crosslinking mass spectrometry.
    Molecular and Cellular Proteomics (15) pp.2491–2500
    """
    self.logger.info("Running Jwalk with the following parameters:")
    self.logger.debug(f"pdb: {pdb}")
    self.logger.debug(f"xl_list: {xl_list}")
    self.logger.debug(f"aa1: {aa1}")
    self.logger.debug(f"aa2: {aa2}")
    self.logger.debug(f"max_dist: {max_dist}")
    self.logger.debug(f"jwalk_results_dir: {jwalk_results_dir}")
    self.logger.debug(f"vox: {vox}")
    self.logger.debug(f"ncpus: {ncpus}")

    amino_acids = {"LYS":"lysines", "CYS":"cysteines", "ASP":"aspartates", "GLU":"glutamates",
                   "VAL":"valines", "ILE":"isoleucines", "LEU":"leucines", "ARG":"arginines",
                   "PRO":"prolines", "GLY":"glycines", "ALA":"alanines", "TRP":"tryptophans",
                   "PHE":"phenylalanines", "SER":"serines", "GLN":"glutamines", "HIS":"histidines",
                   "MET":"methionines", "THR":"threonines", "ASN":"asparagines", "TYR":"tyrosines"}

    # checking if pdb file supplied exists and is of type .pdb
    if not os.path.exists(pdb):
        self.logger.error("ERROR: The input file supplied cannot be found. Please enter a .pdb file type")
        sys.exit(2)
    if not pdb.endswith(".pdb"):
        self.logger.error("ERROR: The input file supplied is not supported. Please enter a .pdb file type")
        sys.exit(2)
    self.logger.info("PDB file supplied is valid")

    # creating result output directory (defaulting to the working directory)
    if os.path.isdir(jwalk_results_dir):
        self.logger.warning(f"WARNING: {jwalk_results_dir} already exists. Overwriting directory")
    else:
        self.logger.warning(f"WARNING: {jwalk_results_dir} not found. Creating directory {jwalk_results_dir}")
        # makedirs: nested output paths are supported, and exist_ok avoids a
        # failure when another worker creates the directory concurrently.
        os.makedirs(jwalk_results_dir, exist_ok=True)

    # checking if an xl_list was provided
    # if none is provided use the aa1 and aa2 inputs (default is LYS-LYS crosslinks)
    if os.path.normpath(xl_list) == "NULL":
        aa1 = aa1.upper()
        aa2 = aa2.upper()
        xl_list = "NULL"

        if aa1 not in amino_acids or aa2 not in amino_acids:
            self.logger.error("ERROR: Please type amino acid in three letter code format")
            self.logger.debug(amino_acids.keys())
            sys.exit(2)
        self.logger.info("Calculating all {}-to-{} crosslinks".format(aa1,aa2))
    # accepting xl_list
    elif os.path.isfile(xl_list):
        self.logger.info(f"Calculating all crosslinks found in {xl_list}")
        aa1 = "NULL"
        aa2 = "NULL"
    else:
        # Fail here with a clear message instead of letting the grid tools
        # stumble over a crosslink list that does not exist.
        self.logger.error(f"ERROR: The crosslink list supplied cannot be found: {xl_list}")
        sys.exit(2)

    # load pdb into Jwalk
    structure_instance = PDBTools.read_PDB_file(pdb)

    # generate grid of voxel size (vox) that encapsulates pdb
    print(f"Generating grid for PDB: {pdb}")
    grid = GridTools.makeGrid(structure_instance, vox)

    # mark C-alpha positions on grid
    print(f"Marking C-alpha positions on grid for PDB: {pdb}")
    if xl_list != "NULL": # if specific crosslinks need to be calculated
        crosslink_pairs, aa1_CA, aa2_CA = GridTools.mark_CAlphas_pairs(grid, structure_instance, xl_list)
    else:
        crosslink_pairs = [] # na if searching every combination between residue types
        aa1_CA, aa2_CA = GridTools.mark_CAlphas(grid, structure_instance, aa1, aa2)

    # check more rigorously if residues are solvent accessible or not
    print(f"Checking solvent accessibility for C-alpha positions for PDB: {pdb}")
    aa1_CA, aa2_CA = SurfaceTools.check_solvent_accessibility_freesasa_both(
        pdb, aa1_CA, aa2_CA, xl_list, amino_acids, ncpus
    )

    dens_map = GridTools.generate_solvent_accessible_surface(grid, structure_instance, aa1_CA, aa2_CA)
    # identify which residues are on the surface
    aa1_voxels, remove_aa1 = GridTools.find_surface_voxels(aa1_CA, dens_map, xl_list)
    aa2_voxels, remove_aa2 = GridTools.find_surface_voxels(aa2_CA, dens_map, xl_list)

    crosslink_pairs = SurfaceTools.update_crosslink_pairs(crosslink_pairs, aa1_CA, aa2_CA, remove_aa1, remove_aa2)

    # calculate sasds
    print(f"Calculating SASDs for PDB: {pdb} with len(crosslink_pairs): {len(crosslink_pairs)}")
    sasds = SASDTools.parallel_BFS(aa1_voxels, aa2_voxels, dens_map, aa1_CA, aa2_CA, crosslink_pairs, max_dist, vox, ncpus, xl_list)

    # remove duplicates
    print(f"Removing duplicate SASDs for PDB: {pdb}")
    sasds = GridTools.remove_duplicates(sasds)
    sasds = SASDTools.get_euclidean_distances(sasds, pdb, aa1, aa2)

    # output sasds to .txt file (the .pdb visualisation file is skipped — the
    # chain-counter in write_sasd_to_pdb overflows for large residue-pair sets)
    PDBTools.write_sasd_to_txt(sasds, pdb,jwalk_results_dir)
    self.logger.info(f"{len(sasds)} SASDs calculated")
    print(f"Jwalk completed for PDB: {pdb} | Results saved to: {jwalk_results_dir}")
851
+ #######################################################################################
852
+
853
+
854
+ #########################################################################################
855
class CollectOP:
    """
    Aggregate per-trajectory CalculateOP outputs into the single .npy arrays
    expected by MassSpec (compare_sim2exp).

    Reads
    -----
    {sasa_dir}/{ID}_Traj{N}.SASA  – CSV written by CalculateOP.SASA()
    {xp_dir}/{ID}_Traj{N}.XP      – TSV written by CalculateOP.XP()

    Writes
    ------
    SASA.npy  : float64 array (n_traj, n_frames, prot_len)  units Ų
    Jwalk.npy : object array  (n_traj, n_frames)
                each element is a dict
                { 'RESNUM|CHAIN-RESNUM|CHAIN' : {'Euclidean': float, 'Jwalk': float} }

    Trajectories whose output file is missing (or whose frame/residue counts
    do not match the expected shape) are filled with NaN (SASA) or left as
    None (Jwalk) so that MassSpec can skip them via its existing NaN
    filtering logic.
    """

    def __init__(self, sasa_dir: str, xp_dir: str, outdir: str, ID: str,
                 n_traj: int, n_frames: int, prot_len: int):
        """
        Parameters
        ----------
        sasa_dir  : directory containing {ID}_Traj{N}.SASA files
        xp_dir    : directory containing {ID}_Traj{N}.XP files
        outdir    : directory where SASA.npy and Jwalk.npy are written
                    (created if it does not exist)
        ID        : protein ID used in file naming (e.g. '1ZMR')
        n_traj    : total number of trajectories (files named 1 … n_traj)
        n_frames  : number of frames per trajectory stored in each file
        prot_len  : number of residues in the protein
        """
        self.sasa_dir = sasa_dir
        self.xp_dir = xp_dir
        self.outdir = outdir
        self.ID = ID
        self.n_traj = n_traj
        self.n_frames = n_frames
        self.prot_len = prot_len
        os.makedirs(outdir, exist_ok=True)
        self.logger = setup_logger('CollectOP', outdir)

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _atom_to_key_part(atom_str: str) -> str:
        """Parse Jwalk atom string to key fragment.

        'MET-1-A-CA'   → '1|A'
        'MET--1-A-CA'  → '-1|A'   (negative residue number)
        """
        parts = atom_str.split('-')
        if len(parts) > 3 and parts[1] == '':
            # A negative residue number contributes its own '-', which the
            # split turns into an empty field followed by the digits.
            return f'-{parts[2]}|{parts[3]}'
        return f'{parts[1]}|{parts[2]}'

    # ------------------------------------------------------------------
    def collect_SASA(self, outfile: str = 'SASA.npy') -> str:
        """Read all {ID}_Traj{N}.SASA CSVs, convert nm² → Ų (×100), pivot
        each to (n_frames, prot_len), stack into (n_traj, n_frames, prot_len)
        and save.  Missing trajectory files, and files whose pivoted shape is
        not (n_frames, prot_len), are left as NaN.

        Returns the path of the saved .npy file (``outdir`` joined with
        ``outfile``).
        """
        out_path = os.path.join(self.outdir, outfile)
        sasa_arr = np.full(
            (self.n_traj, self.n_frames, self.prot_len),
            np.nan,
            dtype=np.float64,
        )

        for traj_num in range(1, self.n_traj + 1):
            fpath = os.path.join(self.sasa_dir, f'{self.ID}_Traj{traj_num}.SASA')
            if not os.path.isfile(fpath):
                self.logger.warning(f'Missing SASA file: {fpath}')
                continue

            df = pd.read_csv(fpath)

            # pivot to (n_frames, prot_len): rows = frames (sorted), cols = resids (sorted)
            pivot = (
                df.pivot_table(index='Frame', columns='resid',
                               values='SASA(nm^2)', aggfunc='first')
                  .sort_index()            # ascending frame order
                  .sort_index(axis=1)      # ascending resid order
            )
            arr = pivot.values             # shape (n_frames, prot_len)

            if arr.shape != (self.n_frames, self.prot_len):
                self.logger.warning(
                    f'Traj {traj_num}: unexpected shape {arr.shape}, '
                    f'expected ({self.n_frames}, {self.prot_len}) – skipping'
                )
                continue

            sasa_arr[traj_num - 1] = arr * 100.0   # nm² → Ų
            self.logger.info(f'Collected SASA: Traj {traj_num}')

        np.save(out_path, sasa_arr)
        self.logger.info(f'SAVED: {out_path}  shape={sasa_arr.shape}')
        return out_path

    # ------------------------------------------------------------------
    def collect_Jwalk(self, outfile: str = 'Jwalk.npy') -> str:
        """Read all {ID}_Traj{N}.XP TSVs and reconstruct the per-frame dict
        structure used by MassSpec.  Save an object array of shape
        (n_traj, n_frames).

        Each array element is a dict::

            { 'RESNUM|CHAIN-RESNUM|CHAIN' : {'Euclidean': float, 'Jwalk': float} }

        The 'SASD' column from the XP file maps to the 'Jwalk' key; the
        'Euclidean Distance' column maps to 'Euclidean'.
        Missing trajectory files, and files whose number of distinct frames
        differs from n_frames, leave the corresponding row as None entries.

        Returns the path of the saved .npy file (``outdir`` joined with
        ``outfile``).
        """
        out_path = os.path.join(self.outdir, outfile)
        jwalk_arr = np.empty((self.n_traj, self.n_frames), dtype=object)

        for traj_num in range(1, self.n_traj + 1):
            fpath = os.path.join(self.xp_dir, f'{self.ID}_Traj{traj_num}.XP')
            if not os.path.isfile(fpath):
                self.logger.warning(f'Missing XP file: {fpath}')
                continue

            df = pd.read_csv(
                fpath,
                sep='\t',
                usecols=['Frame', 'Atom1', 'Atom2', 'Euclidean Distance', 'SASD'],
                dtype={
                    'Frame': np.int32,
                    'Euclidean Distance': np.float32,
                    'SASD': np.float32,
                    'Atom1': 'string',
                    'Atom2': 'string',
                },
            )
            n_found = df['Frame'].nunique()

            if n_found != self.n_frames:
                self.logger.warning(
                    f'Traj {traj_num}: found {n_found} frames, '
                    f'expected {self.n_frames} – skipping'
                )
                continue

            # Frames are stored by their rank in ascending frame order, not by
            # the raw frame number found in the file.
            for frame_idx, (_, fdf) in enumerate(df.groupby('Frame', sort=True)):
                # Column-wise iteration avoids building a Series per row
                # (iterrows) while yielding the same values in the same order.
                jwalk_arr[traj_num - 1, frame_idx] = {
                    f'{self._atom_to_key_part(atom1)}-{self._atom_to_key_part(atom2)}': {
                        'Euclidean': float(euclid),
                        'Jwalk': float(sasd),
                    }
                    for atom1, atom2, euclid, sasd in zip(
                        fdf['Atom1'], fdf['Atom2'],
                        fdf['Euclidean Distance'], fdf['SASD'],
                    )
                }

            # release the table before loading the next trajectory
            del df

            self.logger.info(f'Collected Jwalk/XP: Traj {traj_num}')

        np.save(out_path, jwalk_arr, allow_pickle=True)
        self.logger.info(f'SAVED: {out_path}  shape={jwalk_arr.shape}')
        return out_path
1022
+ #########################################################################################
1023
+
1024
+
1025
+ ## Round GaussLink values
1026
def custom_round(number):
    """Round a scalar to a whole number using a 0.6 threshold.

    The magnitude is rounded away from zero when its fractional part is at
    least 0.6 and toward zero otherwise, so 1.7 → 2.0, 1.5 → 1.0,
    -1.7 → -2.0 and -1.5 → -1.0.  The result comes from ``np.ceil`` /
    ``np.floor``.
    """
    # Take the fractional part of the magnitude: the modulus of a negative
    # number would otherwise give the complement (e.g. -1.7 % 1 == 0.3).
    away_from_zero = abs(number) % 1 >= 0.6

    if number >= 0:
        if away_from_zero:
            return np.ceil(number)
        return np.floor(number)

    if away_from_zero:
        return np.floor(number)
    return np.ceil(number)
1034
+
1035
def process_frame(frame_data):
    """Run the linking analysis for one frame and compare it with the reference.

    Parameters
    ----------
    frame_data : 8-item sequence, unpacked in this order:
        frame_coor     – passed as the positional argument of ``GaussLink``
        nc_list        – passed to ``GaussLink`` as ``contact_mask``
        ref_nc_gdict   – reference result handed to ``GetLinkChanges``
        frame_time     – forwarded unchanged to ``GetLinkChanges``
        frame          – forwarded unchanged to ``GetLinkChanges``
        GaussLink      – callable ``GaussLink(frame_coor, contact_mask=nc_list)``
        GetLinkChanges – callable ``GetLinkChanges(ref, current, frame_time, frame, Nnative)``
        Nnative        – forwarded unchanged to ``GetLinkChanges``

    Returns
    -------
    (change_info, count_info) : the two values returned by ``GetLinkChanges``.
    """
    # Everything arrives in one tuple so the function can be mapped over
    # frames with a single argument per task.
    frame_coor, nc_list, ref_nc_gdict, frame_time, frame, GaussLink, GetLinkChanges, Nnative = frame_data

    # Per-frame result for the same contact mask as the reference
    frame_nc_gdict = GaussLink(frame_coor, contact_mask=nc_list)

    # Changes of this frame relative to the reference
    change_info, count_info = GetLinkChanges(ref_nc_gdict, frame_nc_gdict, frame_time, frame, Nnative)
    return change_info, count_info
1047
+
1048
+