crisp-ase 1.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- CRISP/__init__.py +99 -0
- CRISP/_version.py +1 -0
- CRISP/cli.py +41 -0
- CRISP/data_analysis/__init__.py +38 -0
- CRISP/data_analysis/clustering.py +838 -0
- CRISP/data_analysis/contact_coordination.py +915 -0
- CRISP/data_analysis/h_bond.py +772 -0
- CRISP/data_analysis/msd.py +1199 -0
- CRISP/data_analysis/prdf.py +404 -0
- CRISP/data_analysis/volumetric_atomic_density.py +527 -0
- CRISP/py.typed +1 -0
- CRISP/simulation_utility/__init__.py +31 -0
- CRISP/simulation_utility/atomic_indices.py +155 -0
- CRISP/simulation_utility/atomic_traj_linemap.py +278 -0
- CRISP/simulation_utility/error_analysis.py +254 -0
- CRISP/simulation_utility/interatomic_distances.py +200 -0
- CRISP/simulation_utility/subsampling.py +241 -0
- CRISP/tests/DataAnalysis/__init__.py +1 -0
- CRISP/tests/DataAnalysis/test_clustering_extended.py +212 -0
- CRISP/tests/DataAnalysis/test_contact_coordination.py +184 -0
- CRISP/tests/DataAnalysis/test_contact_coordination_extended.py +465 -0
- CRISP/tests/DataAnalysis/test_h_bond_complete.py +326 -0
- CRISP/tests/DataAnalysis/test_h_bond_extended.py +322 -0
- CRISP/tests/DataAnalysis/test_msd_complete.py +305 -0
- CRISP/tests/DataAnalysis/test_msd_extended.py +522 -0
- CRISP/tests/DataAnalysis/test_prdf.py +206 -0
- CRISP/tests/DataAnalysis/test_volumetric_atomic_density.py +463 -0
- CRISP/tests/SimulationUtility/__init__.py +1 -0
- CRISP/tests/SimulationUtility/test_atomic_traj_linemap.py +101 -0
- CRISP/tests/SimulationUtility/test_atomic_traj_linemap_extended.py +469 -0
- CRISP/tests/SimulationUtility/test_error_analysis_extended.py +151 -0
- CRISP/tests/SimulationUtility/test_interatomic_distances.py +223 -0
- CRISP/tests/SimulationUtility/test_subsampling.py +365 -0
- CRISP/tests/__init__.py +1 -0
- CRISP/tests/test_CRISP.py +28 -0
- CRISP/tests/test_cli.py +87 -0
- CRISP/tests/test_crisp_comprehensive.py +679 -0
- crisp_ase-1.1.2.dist-info/METADATA +141 -0
- crisp_ase-1.1.2.dist-info/RECORD +42 -0
- crisp_ase-1.1.2.dist-info/WHEEL +5 -0
- crisp_ase-1.1.2.dist-info/entry_points.txt +2 -0
- crisp_ase-1.1.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,915 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CRISP/data_analysis/contact_coordination.py
|
|
3
|
+
|
|
4
|
+
This module performs contact and coordination analysis on molecular dynamics trajectory data.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from ase.io import read
|
|
8
|
+
import numpy as np
|
|
9
|
+
from joblib import Parallel, delayed
|
|
10
|
+
import pickle
|
|
11
|
+
from typing import Union, List, Dict, Tuple, Optional, Any
|
|
12
|
+
import os
|
|
13
|
+
import matplotlib.pyplot as plt
|
|
14
|
+
import itertools
|
|
15
|
+
from ase.data import vdw_radii, atomic_numbers, chemical_symbols
|
|
16
|
+
import seaborn as sns
|
|
17
|
+
import plotly.graph_objects as go
|
|
18
|
+
import plotly.io as pio
|
|
19
|
+
__all__ = ['indices', 'coordination_frame', 'coordination', 'contacts_frame', 'contacts']
|
|
20
|
+
|
|
21
|
+
def indices(atoms, ind: Union[str, List[Union[int, str]]]) -> np.ndarray:
    """
    Return array of atom indices from an ASE Atoms object based on the input specifier.

    Parameters
    ----------
    atoms : ase.Atoms
        ASE Atoms object containing atomic structure. Only ``len(atoms)`` and
        ``atoms.get_chemical_symbols()`` are used here.
    ind : Union[str, List[Union[int, str]]]
        Index specifier. Accepted forms:

        * ``"all"`` or ``None`` -- every atom;
        * a string ending in ``.npy`` -- the content of that file is loaded
          and returned unchanged;
        * an integer or a sequence (list, tuple, NumPy array) of integers,
          NumPy integer scalars included -- used directly as atom indices;
        * a chemical symbol or a sequence of chemical symbols -- the indices
          of all atoms carrying those symbols, grouped in the order in which
          the symbols are given.

    Returns
    -------
    np.ndarray
        Array of selected indices

    Raises
    ------
    ValueError
        If the index type is invalid
    """
    # Test for None / "all" without comparing an array to a string, which
    # would be evaluated element-wise for NumPy input.
    if ind is None or (isinstance(ind, str) and ind == "all"):
        return np.arange(len(atoms))

    if isinstance(ind, str) and ind.endswith(".npy"):
        # NOTE(security): allow_pickle=True lets np.load unpickle arbitrary
        # objects; only point this at index files from a trusted source.
        return np.load(ind, allow_pickle=True)

    # Normalise every accepted container (and bare scalars) to a plain list.
    if isinstance(ind, np.ndarray):
        ind = np.atleast_1d(ind).tolist()
    elif isinstance(ind, tuple):
        ind = list(ind)
    elif not isinstance(ind, list):
        ind = [ind]

    # np.integer is checked explicitly: NumPy integer scalars are not
    # instances of the built-in int.
    if any(isinstance(item, (int, np.integer)) for item in ind):
        return np.array(ind)

    if any(isinstance(item, str) for item in ind):
        # The symbol array does not depend on the requested symbol, so it is
        # built once rather than once per symbol.
        symbols = np.array(atoms.get_chemical_symbols())
        return np.concatenate([np.where(symbols == symbol)[0] for symbol in ind])

    raise ValueError("Invalid index type")
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def coordination_frame(atoms, central_atoms, target_atoms, custom_cutoffs=None, mic=True):
    """
    Calculate coordination numbers for central atoms based on interatomic distances and cutoff criteria.

    A target atom counts towards the coordination number of a central atom
    when their distance is strictly below the pair cutoff. The default cutoff
    is 0.6 times the sum of the two van der Waals radii taken from
    ``ase.data.vdw_radii``; entries in ``custom_cutoffs`` override it for the
    matching pair of elements.

    Parameters
    ----------
    atoms : ase.Atoms
        ASE Atoms object containing atomic structure
    central_atoms : Union[str, List[Union[int, str]]]
        Specifier for central atoms (resolved with ``indices``)
    target_atoms : Union[str, List[Union[int, str]]]
        Specifier for target atoms that interact with central atoms
        (resolved with ``indices``)
    custom_cutoffs : Optional[Dict[Tuple[str, str], float]]
        Dictionary with custom cutoff distances for atom pairs. The order of
        the two symbols in a key does not matter; if two keys describe the
        same pair, the one listed first is used.
    mic : bool
        Whether to use the minimum image convention

    Returns
    -------
    Dict[int, int]
        Dictionary mapping central atom indices to their coordination numbers
    """
    indices_central = indices(atoms, central_atoms)
    indices_target = indices(atoms, target_atoms)

    dm = atoms.get_all_distances(mic=mic)
    # An atom must never be counted as its own neighbour.
    np.fill_diagonal(dm, np.inf)

    sub_dm = dm[np.ix_(indices_central, indices_target)]

    all_atomic_numbers = np.array(atoms.get_atomic_numbers())
    central_atomic_numbers = all_atomic_numbers[indices_central]
    target_atomic_numbers = all_atomic_numbers[indices_target]

    central_vdw_radii = vdw_radii[central_atomic_numbers]
    target_vdw_radii = vdw_radii[target_atomic_numbers]

    # NOTE(review): if vdw_radii has no value (NaN) for an element, the cutoff
    # becomes NaN and the "<" test below is False for every pair involving it
    # -- confirm against the ASE table and supply custom_cutoffs if needed.
    cutoff_matrix = 0.6 * (central_vdw_radii[:, np.newaxis] + target_vdw_radii[np.newaxis, :])

    if custom_cutoffs is not None:
        # Key every custom cutoff by the sorted pair of atomic numbers so that
        # ("O", "H") and ("H", "O") address the same entry; setdefault keeps
        # the first occurrence when both spellings are present.
        pair_cutoffs = {}
        for pair, value in custom_cutoffs.items():
            key = tuple(sorted(atomic_numbers[symbol] for symbol in pair))
            pair_cutoffs.setdefault(key, value)

        central_column = central_atomic_numbers[:, np.newaxis]
        target_row = target_atomic_numbers[np.newaxis, :]

        # One vectorised assignment per element pair instead of a Python loop
        # over every (central, target) atom pair.
        for key, value in pair_cutoffs.items():
            if len(key) != 2:
                # Only two-element keys can describe an atom pair.
                continue
            low, high = key
            matches = ((central_column == low) & (target_row == high)) | \
                      ((central_column == high) & (target_row == low))
            cutoff_matrix[matches] = value

    coordination_numbers = np.sum(sub_dm < cutoff_matrix, axis=1)
    coordination_dict_frame = dict(zip(indices_central, coordination_numbers))
    return coordination_dict_frame
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def get_avg_percentages(coordination_data, atom_type, plot_cn=False, output_dir="./"):
    """
    Compute, for every frame, the percentage of atoms having each coordination number.

    Parameters
    ----------
    coordination_data : Dict[int, List[int]]
        Dictionary mapping atom indices to lists of coordination numbers
        (one entry per frame). The number of frames is taken from the first
        entry.
    atom_type : Optional[str]
        Chemical symbol of target atoms or None; only used in the plot title
    plot_cn : bool, optional
        Boolean to indicate if a time series plot should be generated
    output_dir : str, optional
        Directory where output files will be saved

    Returns
    -------
    Dict[int, List[float]]
        Dictionary mapping each coordination number to a list holding, for
        every frame, the percentage of atoms with that coordination number.
        An empty dictionary is returned when ``coordination_data`` contains
        no coordination numbers at all.
    """
    from collections import Counter

    per_atom_series = list(coordination_data.values())
    coord_types = sorted(set(itertools.chain.from_iterable(per_atom_series)))

    # No atoms, or no frames for any atom: there is nothing to average (and
    # nothing to plot), so return early instead of failing on an empty iterator.
    if not coord_types:
        return {}

    num_frames = len(per_atom_series[0])
    total_atoms = len(per_atom_series)
    avg_percentages = {coord_type: [] for coord_type in coord_types}

    for frame_idx in range(num_frames):
        # Tally the frame once rather than re-scanning it per coordination number.
        counts = Counter(values[frame_idx] for values in per_atom_series)
        for coord_type in coord_types:
            avg_percentage = counts[coord_type] / total_atoms * 100
            avg_percentages[coord_type].append(avg_percentage)

    if plot_cn:
        frames = list(range(num_frames))
        colors = plt.get_cmap('tab10', len(coord_types)).colors
        markers = itertools.cycle(['o', 's', 'D', '^', 'v', 'p', '*', '+', 'x'])
        plt.figure(figsize=(10, 6))
        # Per-frame percentage of every coordination number ...
        for i, coord_type in enumerate(coord_types):
            plt.plot(frames, avg_percentages[coord_type], label=f'CN={coord_type}',
                     color=colors[i], marker=next(markers), markevery=max(1, len(frames) // 20))
        # ... and its mean over all frames as a dashed horizontal line.
        for i, coord_type in enumerate(coord_types):
            mean_value = sum(avg_percentages[coord_type]) / len(avg_percentages[coord_type])
            plt.axhline(y=mean_value, color=colors[i], linestyle='--', alpha=0.7,
                        label=f'Mean CN={coord_type}: {mean_value:.1f}%')
        plt.xlabel('Frame Index', fontsize=12)
        plt.ylabel('Percentage of Atoms (%)', fontsize=12)
        if atom_type is not None:
            plt.title(f'Coordination Analysis: {atom_type} Atoms', fontsize=14)
        else:
            plt.title('Coordination Analysis', fontsize=14)
        plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, "CN_time_series.png"), dpi=300, bbox_inches='tight')
        plt.show()
        plt.close()

    return avg_percentages
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def plot_coordination_distribution(avg_percentages, atom_type, plot_cn, output_dir="./", output_file="CN_distribution"):
    """
    Produce the coordination-number distribution charts and return the overall averages.

    The interactive chart and the overall averages are delegated to
    ``plot_coordination_distribution_plotly``; when ``plot_cn`` is true a
    static Matplotlib ring chart is additionally written to
    ``<output_dir>/<output_file>.png``.

    Parameters
    ----------
    avg_percentages : Dict[int, List[float]]
        Dictionary of average percentages per coordination number
    atom_type : Optional[str]
        Chemical symbol for target atoms; used in the chart title when truthy
    plot_cn : bool
        Boolean to indicate if the plot should be generated
    output_dir : str, optional
        Directory where output files will be saved
    output_file : str, optional
        Filename (without extension) for saving the plot

    Returns
    -------
    Dict[int, float]
        Dictionary of overall average coordination percentages
    """
    overall_avg_percentages = plot_coordination_distribution_plotly(
        avg_percentages, atom_type, plot_cn,
        output_dir=output_dir,
        output_file=output_file,
    )

    if not plot_cn:
        return overall_avg_percentages

    # Static version of the same chart, slices ordered by coordination number.
    ordered = sorted(overall_avg_percentages.items())
    slice_sizes = [share for _, share in ordered]
    legend_labels = [f'CN={cn}: {share:.1f}%' for cn, share in ordered]
    palette = plt.cm.tab10.colors[:len(ordered)]

    _, ax = plt.subplots(figsize=(10, 7))

    # width=0.5 turns the pie into a ring.
    wedges, _ = ax.pie(slice_sizes,
                       wedgeprops=dict(width=0.5, edgecolor='w'),
                       startangle=90,
                       colors=palette)

    ax.legend(wedges, legend_labels,
              title="Coordination Numbers",
              loc="center left",
              bbox_to_anchor=(1, 0.5),
              frameon=True,
              fancybox=True,
              shadow=True)

    ax.axis('equal')

    if atom_type:
        chart_title = f'Average Distribution of {atom_type} Atom Coordination'
    else:
        chart_title = 'Average Distribution of Atom Coordination'
    plt.title(chart_title, fontsize=14)

    plt.tight_layout()
    png_path = os.path.join(output_dir, f"{output_file}.png")
    plt.savefig(png_path, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

    return overall_avg_percentages
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def plot_coordination_distribution_plotly(avg_percentages, atom_type, plot_cn, output_dir="./", output_file="CN_distribution"):
    """
    Plot an interactive pie chart showing the overall average distribution of coordination numbers.

    The overall averages are always computed and returned; the chart is only
    built and written to ``<output_dir>/<output_file>.html`` when ``plot_cn``
    is true.

    Parameters
    ----------
    avg_percentages : Dict[int, List[float]]
        Dictionary of average percentages per coordination number
    atom_type : Optional[str]
        Chemical symbol for target atoms; used in the chart title when truthy
    plot_cn : bool
        Boolean to indicate if the plot should be generated
    output_dir : str, optional
        Directory where output files will be saved
    output_file : str, optional
        Filename (without extension) for saving the plot

    Returns
    -------
    Dict[int, float]
        Dictionary of overall average coordination percentages
    """
    overall_avg_percentages = {coord_type: sum(percentages) / len(percentages)
                               for coord_type, percentages in avg_percentages.items()}

    if plot_cn:
        sorted_data = sorted(overall_avg_percentages.items())
        coord_types = [f"CN={item[0]}" for item in sorted_data]
        percentages = [item[1] for item in sorted_data]

        hover_info = [f"CN={coord_type}: {pct:.2f}%" for coord_type, pct in sorted_data]

        # Cycle through the palette: indexing the colormap with an integer
        # beyond its last entry returns the same colour for every extra slice.
        palette_size = plt.cm.tab10.N
        slice_colors = [
            f'rgb{tuple(int(c*255) for c in plt.cm.tab10(i % palette_size)[:3])}'
            for i in range(len(coord_types))
        ]

        fig = go.Figure(data=[go.Pie(
            labels=coord_types,
            values=percentages,
            hole=0.4,
            textinfo='label+percent',
            hoverinfo='text',
            hovertext=hover_info,
            textfont=dict(size=12),
            marker=dict(
                colors=slice_colors,
                line=dict(color='white', width=2)
            ),
        )])

        if atom_type:
            title_text = f'Average Distribution of {atom_type} Atom Coordination'
        else:
            title_text = 'Average Distribution of Atom Coordination'

        fig.update_layout(
            title=dict(
                text=title_text,
                font=dict(size=16)
            ),
            legend=dict(
                orientation='h',
                xanchor='center',
                x=0.5,
                y=-0.1
            ),
            height=600,
            width=800
        )

        html_path = os.path.join(output_dir, f"{output_file}.html")
        fig.write_html(html_path)
        print(f"Interactive coordination distribution chart saved to {html_path}")

    return overall_avg_percentages
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
def log_coordination_data(distribution, avg_percentages, atom_type, avg_cn=None, std_cn=None, output_dir="./"):
    """
    Log coordination analysis statistics to a text file.

    The file is named ``CN_<atom_type>_statistics.txt`` (or
    ``CN_statistics.txt`` when ``atom_type`` is None), is created inside
    ``output_dir`` and is written as UTF-8.

    Parameters
    ----------
    distribution : Dict[int, float]
        Dictionary of overall average coordination percentages
    avg_percentages : Dict[int, List[float]]
        Dictionary of percentage values per frame for each coordination
        number; its standard deviation is reported next to each average and
        it must contain every key of ``distribution``
    atom_type : Optional[str]
        Chemical symbol of target atoms or None
    avg_cn : float, optional
        Average coordination number
    std_cn : float, optional
        Standard deviation of coordination number
    output_dir : str, optional
        Directory where the statistics file will be saved

    Returns
    -------
    None
    """
    if atom_type is not None:
        stats_file = f"CN_{atom_type}_statistics.txt"
    else:
        stats_file = "CN_statistics.txt"

    stats_file = os.path.join(output_dir, stats_file)

    # Standard deviations for each coordination number
    std_devs = {cn: np.std(values) for cn, values in avg_percentages.items()}

    # The report contains the non-ASCII "±" sign, so the encoding is fixed
    # explicitly instead of being left to the platform's locale default.
    with open(stats_file, 'w', encoding='utf-8') as f:
        if atom_type is not None:
            f.write(f"Coordination Analysis for {atom_type} Atoms\n")
        else:
            f.write("Coordination Analysis\n")
        f.write("======================================\n\n")

        # The overall average is only reported when both values are supplied.
        if avg_cn is not None and std_cn is not None:
            f.write(f"Average Coordination Number: {avg_cn:.2f} ± {std_cn:.2f}\n\n")

        f.write("Overall Average Percentages:\n")

        for coord_type, avg_percentage in sorted(distribution.items()):
            std_dev = std_devs[coord_type]
            f.write(f" CN={coord_type}: {avg_percentage:.2f}% ± {std_dev:.2f}%\n")

        # Coordination number with the largest overall share.
        most_common_cn = max(distribution.items(), key=lambda x: x[1])[0]
        f.write(f"\nMost common coordination number: {most_common_cn}\n")

    print(f"Coordination statistics saved to {stats_file}")
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def coordination(traj_path, central_atoms, target_atoms, custom_cutoffs, frame_skip=10,
                 plot_cn=False, output_dir="./"):
    """
    Process a trajectory file to compute coordination numbers for each frame.

    Every ``frame_skip``-th frame is read, ``coordination_frame`` is run on
    the frames in parallel, and the per-frame results are merged, summarised,
    optionally plotted, written to a statistics file and printed.

    Parameters
    ----------
    traj_path : str
        Path to the trajectory file
    central_atoms : Union[str, List[Union[int, str]]]
        Specifier for central atoms being analyzed
    target_atoms : Union[str, List[Union[int, str]]]
        Specifier for target atoms that interact with central atoms
    custom_cutoffs : Dict[Tuple[str, str], float]
        Dictionary with custom cutoff distances
    frame_skip : int, optional
        Interval for skipping frames (default: 10)
    plot_cn : bool, optional
        Boolean to indicate if plots should be generated (default: False)
    output_dir : str, optional
        Directory where output files will be saved (default: "./"); it is
        created if it does not exist

    Returns
    -------
    List
        ``[coordination_dict, avg_percentages, distribution, atom_type,
        avg_cn, std_cn]``: per-atom coordination numbers over the frames,
        per-frame percentages per coordination number, overall average
        percentages, the label used for titles and file names, and the mean
        and standard deviation of all coordination numbers.
    """
    os.makedirs(output_dir, exist_ok=True)

    trajectory = read(traj_path, index=f"::{frame_skip}")

    results = Parallel(n_jobs=-1)(
        delayed(coordination_frame)(atoms, central_atoms, target_atoms, custom_cutoffs)
        for atoms in trajectory
    )

    # Merge the per-frame dictionaries into one series per central atom.
    coordination_dict = {}
    for frame_dict in results:
        for key, value in frame_dict.items():
            coordination_dict.setdefault(key, []).append(value)

    # Label used in plot titles and output file names.
    if isinstance(central_atoms, str):
        if central_atoms.endswith('.npy'):
            # Extract just the basename without extension for .npy files
            atom_type = os.path.splitext(os.path.basename(central_atoms))[0]
        else:
            atom_type = central_atoms
    elif isinstance(central_atoms, int):
        # An integer selects an atom by index (see ``indices``), so the label
        # is the symbol of that atom, not the element whose atomic number
        # happens to equal the index.
        atom_type = trajectory[0].get_chemical_symbols()[central_atoms]
    else:
        atom_type = None

    avg_percentages = get_avg_percentages(coordination_dict, atom_type, plot_cn, output_dir=output_dir)

    distribution = plot_coordination_distribution(avg_percentages, atom_type, plot_cn, output_dir=output_dir)

    # Statistics over every coordination number of every atom in every frame.
    all_cn_values = np.array(list(itertools.chain.from_iterable(coordination_dict.values())))
    avg_cn = np.mean(all_cn_values)
    std_cn = np.std(all_cn_values)

    log_coordination_data(distribution, avg_percentages, atom_type,
                          avg_cn=avg_cn, std_cn=std_cn, output_dir=output_dir)

    print("\nCoordination Statistics Summary:")
    print("=" * 40)

    std_devs = {cn: np.std(values) for cn, values in avg_percentages.items()}

    for coord_type, avg_percentage in sorted(distribution.items()):
        std_dev = std_devs[coord_type]
        print(f"CN={coord_type}: {avg_percentage:.2f}% ± {std_dev:.2f}%")

    most_common_cn = max(distribution.items(), key=lambda x: x[1])[0]
    print(f"\nMost common coordination number: {most_common_cn}")

    print(f"\nAverage Coordination Number: {avg_cn:.2f} ± {std_cn:.2f}")

    return [coordination_dict, avg_percentages, distribution, atom_type, avg_cn, std_cn]
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
def contacts_frame(atoms, central_atoms, target_atoms, custom_cutoffs, mic=True):
    """
    Processes a single atoms frame to compute the sub-distance matrix and the corresponding cutoff matrix.

    When the central and target selections share atoms, a pair that appears
    twice in the central x target matrix (once as (i, j) and once as (j, i))
    is kept only in the entry with i < j; the other entry is set to infinity.
    Pairs that appear only once are always kept. Self-distances are set to
    infinity.

    Parameters
    ----------
    atoms : ase.Atoms
        ASE Atoms object containing atomic structure
    central_atoms : Union[str, List[Union[int, str]]]
        Selection criteria for central atoms being analyzed
    target_atoms : Union[str, List[Union[int, str]]]
        Selection criteria for target atoms that interact with central atoms
    custom_cutoffs : Dict[Tuple[str, str], float]
        Dictionary mapping atom pairs to custom cutoff values, or None to use
        only the default of 0.6 times the sum of the van der Waals radii. The
        order of the two symbols in a key does not matter; if two keys
        describe the same pair, the one listed first is used.
    mic : bool, optional
        Whether to use minimum image convention (default: True)

    Returns
    -------
    Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
        Sub-distance matrix, cutoff matrix, central atom indices, and target atom indices
    """
    indices_central = indices(atoms, central_atoms)
    indices_target = indices(atoms, target_atoms)

    dm = atoms.get_all_distances(mic=mic)
    np.fill_diagonal(dm, np.inf)  # Avoids self-interactions

    sub_dm = dm[np.ix_(indices_central, indices_target)]

    # A pair (i, j) is duplicated in the sub-matrix only if its mirror (j, i)
    # is also present, i.e. i is a target atom too and j is a central atom
    # too. Only for those pairs the i > j entry is dropped; masking every
    # i > j entry would lose pairs that occur once (e.g. central O, target
    # O + H, with an H index below the O index).
    central_column = np.asarray(indices_central)[:, np.newaxis]
    target_row = np.asarray(indices_target)[np.newaxis, :]
    mirrored = np.isin(central_column, indices_target) & np.isin(target_row, indices_central)
    duplicate = mirrored & (central_column > target_row)
    if duplicate.any():
        sub_dm = np.where(duplicate, np.inf, sub_dm)

    # Get atomic numbers and van der Waals radii
    all_atomic_numbers = np.array(atoms.get_atomic_numbers())
    central_atomic_numbers = all_atomic_numbers[indices_central]
    target_atomic_numbers = all_atomic_numbers[indices_target]

    central_vdw_radii = vdw_radii[central_atomic_numbers]
    target_vdw_radii = vdw_radii[target_atomic_numbers]

    cutoff_matrix = 0.6 * (central_vdw_radii[:, np.newaxis] + target_vdw_radii[np.newaxis, :])

    if custom_cutoffs is not None:
        # Key every custom cutoff by the sorted pair of atomic numbers so the
        # symbol order in the key is irrelevant; setdefault keeps the first
        # occurrence when both spellings are present.
        pair_cutoffs = {}
        for pair, value in custom_cutoffs.items():
            key = tuple(sorted(atomic_numbers[symbol] for symbol in pair))
            pair_cutoffs.setdefault(key, value)

        central_numbers_column = central_atomic_numbers[:, np.newaxis]
        target_numbers_row = target_atomic_numbers[np.newaxis, :]

        # One vectorised assignment per element pair instead of a Python loop
        # over every (central, target) atom pair.
        for key, value in pair_cutoffs.items():
            if len(key) != 2:
                # Only two-element keys can describe an atom pair.
                continue
            low, high = key
            matches = ((central_numbers_column == low) & (target_numbers_row == high)) | \
                      ((central_numbers_column == high) & (target_numbers_row == low))
            cutoff_matrix[matches] = value

    return sub_dm, cutoff_matrix, indices_central, indices_target
|
|
531
|
+
|
|
532
|
+
|
|
533
|
+
def plot_contact_heatmap(contact_matrix, frame_skip, time_step, x_labels, y_labels, atom_type, output_dir="./"):
    """
    Build an interactive Plotly heatmap of contact times and save it as HTML.

    The contact time of a (central, target) pair is the number of frames in
    which the pair is flagged in ``contact_matrix`` multiplied by
    ``frame_skip * time_step / 1000``; it is labelled in ps.

    Parameters
    ----------
    contact_matrix : np.ndarray
        Boolean 3D array of contacts; it is summed over its first axis
    frame_skip : int
        Number of frames skipped when processing the trajectory
    time_step : float
        Time step between frames (used to convert counts to time)
        -- presumably in fs given the /1000 and the "ps" labels; confirm
    x_labels : List[str]
        Labels for the target atoms (x-axis)
    y_labels : List[str]
        Labels for the central atoms (y-axis)
    atom_type : Optional[str]
        Central atom type (used for naming the file)
    output_dir : str, optional
        Directory where the output file will be saved (default: "./")

    Returns
    -------
    None
    """
    frames_in_contact = np.sum(contact_matrix, axis=0)
    contact_times_matrix = frames_in_contact * frame_skip * time_step / 1000

    # One hover string per heatmap cell, rows following the central atoms.
    hover_text = [
        [
            f"Central: {central_label}<br>"
            f"Target: {target_label}<br>"
            f"Contact time: {contact_times_matrix[row, col]:.2f} ps"
            for col, target_label in enumerate(x_labels)
        ]
        for row, central_label in enumerate(y_labels)
    ]

    heatmap = go.Heatmap(
        z=contact_times_matrix,
        x=x_labels,
        y=y_labels,
        colorscale='Viridis',
        hoverinfo='text',
        text=hover_text,
        colorbar=dict(title='Contact Time (ps)'),
    )
    fig = go.Figure(data=heatmap)

    fig.update_layout(
        title=dict(text='Heatmap of contact times within cutoffs', font=dict(size=16)),
        xaxis=dict(title='Target Atoms', tickfont=dict(size=10)),
        yaxis=dict(title='Central Atoms', tickfont=dict(size=10)),
        width=900,
        height=700,
        autosize=True,
    )

    # The file name carries the atom type as a prefix when one is given.
    if atom_type is None:
        base_name = "heatmap_contacts.html"
    else:
        base_name = f"{atom_type}_heatmap_contacts.html"
    html_filename = os.path.join(output_dir, base_name)

    fig.write_html(html_filename)
    print(f"Interactive contact heatmap saved to {html_filename}")
|
|
595
|
+
|
|
596
|
+
|
|
597
|
+
def plot_distance_heatmap(sub_dm_total, x_labels, y_labels, atom_type, output_dir="./"):
    """
    Build an interactive Plotly heatmap of frame-averaged distances and save it as HTML.

    The colour of each cell is the mean of ``sub_dm_total`` over its first
    axis; the hover text additionally reports the standard deviation, the
    minimum and the maximum over that axis.

    Parameters
    ----------
    sub_dm_total : np.ndarray
        3D numpy array containing sub-distance matrices for each frame
    x_labels : List[str]
        Labels for the target atoms (x-axis)
    y_labels : List[str]
        Labels for the central atoms (y-axis)
    atom_type : Optional[str]
        Central atom type (used for naming the file)
    output_dir : str, optional
        Directory where the output file will be saved (default: "./")

    Returns
    -------
    None
    """
    # Per-pair statistics over the frame axis.
    average_distance_matrix = np.mean(sub_dm_total, axis=0)
    std_distance_matrix = np.std(sub_dm_total, axis=0)
    min_distance_matrix = np.min(sub_dm_total, axis=0)
    max_distance_matrix = np.max(sub_dm_total, axis=0)

    # One hover string per heatmap cell, rows following the central atoms.
    hover_text = [
        [
            f"Central: {central_label}<br>"
            f"Target: {target_label}<br>"
            f"Avg distance: {average_distance_matrix[row, col]:.3f} Å<br>"
            f"Std deviation: {std_distance_matrix[row, col]:.3f} Å<br>"
            f"Min: {min_distance_matrix[row, col]:.3f} Å<br>"
            f"Max: {max_distance_matrix[row, col]:.3f} Å"
            for col, target_label in enumerate(x_labels)
        ]
        for row, central_label in enumerate(y_labels)
    ]

    heatmap = go.Heatmap(
        z=average_distance_matrix,
        x=x_labels,
        y=y_labels,
        colorscale='Viridis',
        hoverinfo='text',
        text=hover_text,
        colorbar=dict(title='Distance (Å)'),
    )
    fig = go.Figure(data=heatmap)

    fig.update_layout(
        title=dict(text='Distance matrix of selected atoms', font=dict(size=16)),
        xaxis=dict(title='Target Atoms', tickfont=dict(size=10)),
        yaxis=dict(title='Central Atoms', tickfont=dict(size=10)),
        width=900,
        height=700,
        autosize=True,
    )

    # The file name carries the atom type as a prefix when one is given.
    if atom_type is None:
        base_name = "heatmap_distance.html"
    else:
        base_name = f"{atom_type}_heatmap_distance.html"
    html_filename = os.path.join(output_dir, base_name)

    fig.write_html(html_filename)
    print(f"Interactive distance heatmap saved to {html_filename}")
|
|
661
|
+
|
|
662
|
+
|
|
663
|
+
def plot_contact_distance(sub_dm_total, contact_matrix, time_step, frame_skip, output_dir="./"):
    """
    Plots and saves a time series of the average contact distance and the
    contact count over the trajectory.

    Two files are written to ``output_dir``: an interactive Plotly chart
    (``average_contact_analysis.html``) and a static Matplotlib figure
    (``average_contact_analysis.png``).

    Parameters
    ----------
    sub_dm_total : np.ndarray
        3D numpy array of sub-distance matrices for each frame
    contact_matrix : np.ndarray
        Boolean 3D numpy array indicating which distances are considered contacts
    time_step : float
        Time between frames; the x axis is ``frame index * time_step * frame_skip / 1000``
        and is labelled in ps
    frame_skip : int
        Number of frames skipped when processing
    output_dir : str, optional
        Directory where the output files will be saved (default: "./").
        It is created if it does not exist.

    Returns
    -------
    None

    Notes
    -----
    Frames without any contact have no defined average distance. Their value
    is filled by linear interpolation between neighbouring frames that do have
    contacts. If no frame has a contact, the distance curve and its mean are
    NaN instead of raising an error.
    """
    import warnings

    # An empty output_dir means "current directory"; os.makedirs("") would fail.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # Keep only the distances flagged as contacts; everything else becomes NaN
    # so that it is ignored by the NaN-aware mean below.
    contact_distance = np.where(contact_matrix, sub_dm_total, np.nan)
    # Contacts per frame, normalised by the size of the second axis of sub_dm_total.
    contact_count = np.sum(contact_matrix, axis=(1, 2)) / np.shape(sub_dm_total)[1]

    # A frame with no contacts is an all-NaN slice: nanmean returns NaN for it
    # and warns. That situation is expected here, so the warning is silenced.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        average_distance_contacts = np.nanmean(contact_distance, axis=(1, 2))

    x = np.arange(len(average_distance_contacts)) * time_step * frame_skip / 1000

    valid_indices = ~np.isnan(average_distance_contacts)
    if valid_indices.any():
        # Fill contact-free frames from their neighbours.
        interpolated = np.interp(x, x[valid_indices], average_distance_contacts[valid_indices])
        mean_distance = np.mean(interpolated)
    else:
        # np.interp cannot work without sample points; show an empty curve instead.
        interpolated = np.full(x.shape, np.nan, dtype=float)
        mean_distance = np.nan
    mean_count = np.mean(contact_count)

    # ---- Interactive Plotly figure -------------------------------------
    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=x,
        y=interpolated,
        mode='lines+markers',
        name='Avg Dist',
        line=dict(color='blue'),
        yaxis='y'
    ))

    fig.add_trace(go.Scatter(
        x=[x[0], x[-1]],
        y=[mean_distance, mean_distance],
        mode='lines',
        name=f'Mean Dist: {mean_distance:.2f} Å',
        line=dict(color='blue', dash='dash'),
        yaxis='y'
    ))

    fig.add_trace(go.Scatter(
        x=x,
        y=contact_count,
        mode='lines+markers',
        name='Contact Count',
        line=dict(color='red'),
        yaxis='y2'
    ))

    fig.add_trace(go.Scatter(
        x=[x[0], x[-1]],
        y=[mean_count, mean_count],
        mode='lines',
        name=f'Mean Count: {mean_count:.1f}',
        line=dict(color='red', dash='dash'),
        yaxis='y2'
    ))

    fig.update_layout(
        title='Average Distance of Contacts & Contact Count',
        xaxis=dict(
            title='Time (ps)',
            showgrid=True,
            gridwidth=0.5
        ),
        yaxis=dict(
            title=dict(
                text='Distance (Å)',
                font=dict(color='blue')
            ),
            tickfont=dict(color='blue'),
            showgrid=True,
            gridwidth=0.5
        ),
        # Second y axis on the right, sharing the x axis with the first one.
        yaxis2=dict(
            title=dict(
                text='Contact count per atom',
                font=dict(color='red')
            ),
            tickfont=dict(color='red'),
            anchor="x",
            overlaying="y",
            side="right"
        ),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ),
        width=900,
        height=600
    )

    html_filename = os.path.join(output_dir, "average_contact_analysis.html")
    fig.write_html(html_filename)
    print(f"Interactive contact analysis chart saved to {html_filename}")

    # ---- Static Matplotlib figure --------------------------------------
    fig_mpl, ax1 = plt.subplots(figsize=(10, 6))
    ax2 = ax1.twinx()

    # Limit the number of drawn markers to roughly 20 per curve.
    marker_stride = max(1, len(x) // 20)

    line1 = ax1.plot(x, interpolated, 'o-', color='blue', label='Avg Dist',
                     markersize=4, markevery=marker_stride)
    ax1.axhline(y=mean_distance, color='blue', linestyle='--',
                label=f'Mean Distance: {mean_distance:.2f} Å')

    line2 = ax2.plot(x, contact_count, 'o-', color='red', label='Contact Count',
                     markersize=4, markevery=marker_stride)
    ax2.axhline(y=mean_count, color='red', linestyle='--',
                label=f'Mean Count: {mean_count:.1f}')

    ax1.set_xlabel('Time (ps)', fontsize=12)
    ax1.set_ylabel('Distance (Å)', color='blue', fontsize=12)
    ax2.set_ylabel('Contact count per atom', color='red', fontsize=12)
    plt.title('Average Distance of Contacts & Contact Count', fontsize=14)

    ax1.tick_params(axis='y', labelcolor='blue')
    ax2.tick_params(axis='y', labelcolor='red')

    ax1.grid(True, alpha=0.3)

    # One combined legend for both axes: the two data curves plus proxy
    # handles for the dashed mean lines.
    all_lines = line1 + line2 + [
        plt.Line2D([0], [0], color='blue', linestyle='--'),
        plt.Line2D([0], [0], color='red', linestyle='--')
    ]
    all_labels = [l.get_label() for l in line1 + line2] + [
        f'Mean Dist: {mean_distance:.2f} Å',
        f'Mean Count: {mean_count:.1f}'
    ]

    fig_mpl.legend(all_lines, all_labels, loc='upper center',
                   bbox_to_anchor=(0.5, -0.02), ncol=2)

    plt.tight_layout(pad=1.2)

    png_filename = os.path.join(output_dir, "average_contact_analysis.png")
    plt.savefig(png_filename, dpi=300, bbox_inches='tight')
    plt.show()
    # Close this specific figure rather than whichever one happens to be current.
    plt.close(fig_mpl)

    print(f"Static contact analysis chart saved to {png_filename}")
|
|
822
|
+
|
|
823
|
+
|
|
824
|
+
def save_matrix_data(sub_dm_total, contact_matrix, output_dir="./"):
    """
    Saves the sub-distance matrices and contact matrix to npy files.

    The arrays are written as ``sub_dm_total.npy`` and ``contact_matrix.npy``
    inside ``output_dir``.

    Parameters
    ----------
    sub_dm_total : np.ndarray
        3D numpy array of sub-distance matrices
    contact_matrix : np.ndarray
        Boolean 3D numpy array of contact information
    output_dir : str, optional
        Directory where the output files will be saved (default: "./").
        It is created, including missing parents, if it does not exist.

    Returns
    -------
    None
    """
    # An empty string means "current directory"; os.makedirs("") would raise.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    np.save(os.path.join(output_dir, "sub_dm_total.npy"), sub_dm_total)
    np.save(os.path.join(output_dir, "contact_matrix.npy"), contact_matrix)
|
|
843
|
+
|
|
844
|
+
|
|
845
|
+
def contacts(traj_path, central_atoms, target_atoms, custom_cutoffs, frame_skip=10,
             plot_distance_matrix=False, plot_contacts=False, time_step=None, save_data=False,
             output_dir="./", mic=True):
    """
    Processes a molecular trajectory file to compute contacts between central and target atoms.

    Every ``frame_skip``-th frame is read, ``contacts_frame`` is evaluated for
    each of them in parallel, and a pair is flagged as a contact when its
    distance is below the cutoff matrix returned for the first frame.

    Parameters
    ----------
    traj_path : str
        Path to the trajectory file
    central_atoms : Union[str, List[Union[int, str]]]
        Criteria for selecting central atoms being analyzed
    target_atoms : Union[str, List[Union[int, str]]]
        Criteria for selecting target atoms that interact with central atoms
    custom_cutoffs : Dict[Tuple[str, str], float]
        Dictionary with custom cutoff values for specific atom pairs
    frame_skip : int, optional
        Number of frames to skip (default: 10)
    plot_distance_matrix : bool, optional
        Boolean flag to plot average distance heatmap (default: False)
    plot_contacts : bool, optional
        Boolean flag to plot the contact times heatmap and the contact
        distance time series (default: False). Requires ``time_step``.
    time_step : float, optional
        Time between frames in fs (required for the contact plots)
    save_data : bool, optional
        Boolean flag to save matrices as npy files (default: False)
    output_dir : str, optional
        Directory where output files will be saved (default: "./")
    mic : bool, optional
        Whether to use minimum image convention (default: True)

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        3D numpy array of sub-distance matrices and Boolean 3D numpy array of contacts

    Raises
    ------
    ValueError
        If no frame is selected from the trajectory.

    Warns
    -----
    UserWarning
        If ``plot_contacts`` is True but ``time_step`` is None; the contact
        plots are skipped in that case.
    """
    import warnings

    os.makedirs(output_dir, exist_ok=True)

    trajectory = read(traj_path, index=f"::{frame_skip}")
    if len(trajectory) == 0:
        # Without this check the unpacking of the results below fails with an
        # unrelated-looking "not enough values to unpack" message.
        raise ValueError(
            f"No frames were read from '{traj_path}' with frame_skip={frame_skip}."
        )

    results = Parallel(n_jobs=-1)(
        delayed(contacts_frame)(atoms, central_atoms, target_atoms, custom_cutoffs, mic=mic)
        for atoms in trajectory
    )

    # Transpose the per-frame result tuples into one sequence per returned item.
    sub_dm_list, cutoff_matrices, indices_central, indices_target = zip(*results)
    # The first frame's cutoffs and index selections are used for all frames.
    cutoff_matrix = cutoff_matrices[0]
    sub_dm_total = np.array(sub_dm_list)

    atom_type = (central_atoms if isinstance(central_atoms, str)
                 else chemical_symbols[central_atoms] if isinstance(central_atoms, int)
                 else None)

    # Look the symbols up once instead of once per label.
    symbols = trajectory[0].get_chemical_symbols()
    y_labels = [f"{symbols[i]}({i})" for i in indices_central[0]]
    x_labels = [f"{symbols[i]}({i})" for i in indices_target[0]]

    # cutoff_matrix is broadcast against every frame of sub_dm_total.
    contact_matrix = sub_dm_total < cutoff_matrix

    if plot_contacts:
        if time_step is None:
            warnings.warn(
                "plot_contacts=True requires time_step; skipping contact plots.",
                stacklevel=2,
            )
        else:
            plot_contact_heatmap(contact_matrix, frame_skip, time_step, x_labels, y_labels, atom_type,
                                 output_dir=output_dir)
            plot_contact_distance(sub_dm_total, contact_matrix, time_step, frame_skip, output_dir=output_dir)

    if plot_distance_matrix:
        plot_distance_heatmap(sub_dm_total, x_labels, y_labels, atom_type, output_dir=output_dir)

    if save_data:
        save_matrix_data(sub_dm_total, contact_matrix, output_dir=output_dir)

    return sub_dm_total, contact_matrix
|