crisp-ase 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. CRISP/__init__.py +99 -0
  2. CRISP/_version.py +1 -0
  3. CRISP/cli.py +41 -0
  4. CRISP/data_analysis/__init__.py +38 -0
  5. CRISP/data_analysis/clustering.py +838 -0
  6. CRISP/data_analysis/contact_coordination.py +915 -0
  7. CRISP/data_analysis/h_bond.py +772 -0
  8. CRISP/data_analysis/msd.py +1199 -0
  9. CRISP/data_analysis/prdf.py +404 -0
  10. CRISP/data_analysis/volumetric_atomic_density.py +527 -0
  11. CRISP/py.typed +1 -0
  12. CRISP/simulation_utility/__init__.py +31 -0
  13. CRISP/simulation_utility/atomic_indices.py +155 -0
  14. CRISP/simulation_utility/atomic_traj_linemap.py +278 -0
  15. CRISP/simulation_utility/error_analysis.py +254 -0
  16. CRISP/simulation_utility/interatomic_distances.py +200 -0
  17. CRISP/simulation_utility/subsampling.py +241 -0
  18. CRISP/tests/DataAnalysis/__init__.py +1 -0
  19. CRISP/tests/DataAnalysis/test_clustering_extended.py +212 -0
  20. CRISP/tests/DataAnalysis/test_contact_coordination.py +184 -0
  21. CRISP/tests/DataAnalysis/test_contact_coordination_extended.py +465 -0
  22. CRISP/tests/DataAnalysis/test_h_bond_complete.py +326 -0
  23. CRISP/tests/DataAnalysis/test_h_bond_extended.py +322 -0
  24. CRISP/tests/DataAnalysis/test_msd_complete.py +305 -0
  25. CRISP/tests/DataAnalysis/test_msd_extended.py +522 -0
  26. CRISP/tests/DataAnalysis/test_prdf.py +206 -0
  27. CRISP/tests/DataAnalysis/test_volumetric_atomic_density.py +463 -0
  28. CRISP/tests/SimulationUtility/__init__.py +1 -0
  29. CRISP/tests/SimulationUtility/test_atomic_traj_linemap.py +101 -0
  30. CRISP/tests/SimulationUtility/test_atomic_traj_linemap_extended.py +469 -0
  31. CRISP/tests/SimulationUtility/test_error_analysis_extended.py +151 -0
  32. CRISP/tests/SimulationUtility/test_interatomic_distances.py +223 -0
  33. CRISP/tests/SimulationUtility/test_subsampling.py +365 -0
  34. CRISP/tests/__init__.py +1 -0
  35. CRISP/tests/test_CRISP.py +28 -0
  36. CRISP/tests/test_cli.py +87 -0
  37. CRISP/tests/test_crisp_comprehensive.py +679 -0
  38. crisp_ase-1.1.2.dist-info/METADATA +141 -0
  39. crisp_ase-1.1.2.dist-info/RECORD +42 -0
  40. crisp_ase-1.1.2.dist-info/WHEEL +5 -0
  41. crisp_ase-1.1.2.dist-info/entry_points.txt +2 -0
  42. crisp_ase-1.1.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,915 @@
1
+ """
2
+ CRISP/data_analysis/contact_coordination.py
3
+
4
+ This module performs contact and coordination analysis on molecular dynamics trajectory data.
5
+ """
6
+
7
+ from ase.io import read
8
+ import numpy as np
9
+ from joblib import Parallel, delayed
10
+ import pickle
11
+ from typing import Union, List, Dict, Tuple, Optional, Any
12
+ import os
13
+ import matplotlib.pyplot as plt
14
+ import itertools
15
+ from ase.data import vdw_radii, atomic_numbers, chemical_symbols
16
+ import seaborn as sns
17
+ import plotly.graph_objects as go
18
+ import plotly.io as pio
19
+ __all__ = ['indices', 'coordination_frame', 'coordination', 'contacts_frame', 'contacts']
20
+
21
def indices(atoms, ind: Union[str, List[Union[int, str]]]) -> np.ndarray:
    """
    Return array of atom indices from an ASE Atoms object based on the input specifier.

    Parameters
    ----------
    atoms : ase.Atoms
        ASE Atoms object containing atomic structure
    ind : Union[str, List[Union[int, str]]]
        Index specifier. Accepted forms:

        - ``None`` or ``"all"``: every atom, ``0 .. len(atoms) - 1``
        - a path ending in ``.npy``: the array stored in that file
        - an integer (Python or NumPy), interpreted as an atom index
        - a chemical symbol such as ``"O"``
        - a list, tuple or 1-D NumPy array of integers and/or symbols;
          entries are resolved in order and the results concatenated

    Returns
    -------
    np.ndarray
        Array of selected indices

    Raises
    ------
    ValueError
        If the specifier is empty or contains an entry that is neither an
        integer nor a string
    """
    # Check None/"all" via identity and isinstance so that an array-like
    # specifier is never compared element-wise against a string, which would
    # raise "truth value of an array is ambiguous".
    if ind is None or (isinstance(ind, str) and ind == "all"):
        return np.arange(len(atoms))
    if isinstance(ind, str) and ind.endswith(".npy"):
        # NOTE(review): allow_pickle=True lets a crafted .npy file execute
        # code on load; only pass index files from a trusted source.
        return np.load(ind, allow_pickle=True)

    if isinstance(ind, np.ndarray):
        ind = ind.tolist()  # 0-d arrays become a scalar, handled just below
    items = list(ind) if isinstance(ind, (list, tuple)) else [ind]

    if not items:
        raise ValueError("Invalid index type")

    # Fast path: a plain list of atom indices, returned in the given order.
    if all(isinstance(item, (int, np.integer)) for item in items):
        return np.array([int(item) for item in items], dtype=int)

    # Mixed or symbolic selection: resolve each entry in order.
    symbols = np.array(atoms.get_chemical_symbols())
    selected = []
    for item in items:
        if isinstance(item, (int, np.integer)):
            selected.append(np.array([int(item)], dtype=int))
        elif isinstance(item, str):
            selected.append(np.where(symbols == item)[0])
        else:
            raise ValueError("Invalid index type")
    return np.concatenate(selected)
58
+
59
+
60
def coordination_frame(atoms, central_atoms, target_atoms, custom_cutoffs=None, mic=True):
    """
    Count, for every selected central atom, the target atoms lying within a cutoff distance.

    By default the cutoff for a pair is 0.6 times the sum of the two atoms'
    van der Waals radii (``ase.data.vdw_radii``). Entries of ``custom_cutoffs``
    replace that default for matching element pairs.

    Parameters
    ----------
    atoms : ase.Atoms
        Structure to analyse
    central_atoms : Union[str, List[Union[int, str]]]
        Specifier for central atoms (resolved with ``indices``)
    target_atoms : Union[str, List[Union[int, str]]]
        Specifier for target atoms (resolved with ``indices``)
    custom_cutoffs : Optional[Dict[Tuple[str, str], float]]
        Cutoff distances keyed by pairs of chemical symbols. The order within
        a pair is irrelevant; if the same unordered pair is given twice, the
        first entry is used.
    mic : bool
        Whether to use the minimum image convention for distances

    Returns
    -------
    Dict[int, int]
        Maps each central atom index to its coordination number
    """
    central_idx = indices(atoms, central_atoms)
    target_idx = indices(atoms, target_atoms)

    distances = atoms.get_all_distances(mic=mic)
    # An atom selected as both central and target must not count itself.
    np.fill_diagonal(distances, np.inf)
    pair_distances = distances[np.ix_(central_idx, target_idx)]

    numbers = np.array(atoms.get_atomic_numbers())
    central_z = numbers[central_idx]
    target_z = numbers[target_idx]

    cutoffs = 0.6 * np.add.outer(vdw_radii[central_z], vdw_radii[target_z])

    if custom_cutoffs is not None:
        # Unordered (Z, Z) pair -> cutoff; setdefault keeps the first entry
        # for a repeated pair, the same result as a first-match search.
        override = {}
        for pair, value in custom_cutoffs.items():
            key = tuple(sorted(atomic_numbers[symbol] for symbol in pair))
            override.setdefault(key, value)

        for row, zc in enumerate(central_z):
            for col, zt in enumerate(target_z):
                key = tuple(sorted([zc, zt]))
                if key in override:
                    cutoffs[row, col] = override[key]

    within_cutoff = pair_distances < cutoffs
    return dict(zip(central_idx, within_cutoff.sum(axis=1)))
113
+
114
+
115
def get_avg_percentages(coordination_data, atom_type, plot_cn=False, output_dir="./"):
    """
    Compute, for every frame, the percentage of atoms having each coordination number.

    Parameters
    ----------
    coordination_data : Dict[int, List[int]]
        Maps each atom index to its coordination number in every frame. The
        length of the first list is taken as the number of frames.
    atom_type : Optional[str]
        Label used in the plot title, or None
    plot_cn : bool, optional
        If True, save a time-series plot to ``<output_dir>/CN_time_series.png``
        and show it
    output_dir : str, optional
        Directory where output files will be saved

    Returns
    -------
    Dict[int, List[float]]
        Maps each coordination number observed anywhere in the data (keys in
        ascending order) to a list with one percentage per frame. Returns an
        empty dict when there are no atoms or no frames.
    """
    from collections import Counter

    coord_types = sorted(set(itertools.chain.from_iterable(coordination_data.values())))
    if not coord_types:
        # No atoms or no frames: nothing to summarise (and nothing to plot).
        return {}

    per_atom = list(coordination_data.values())
    total_atoms = len(per_atom)
    num_frames = len(per_atom[0])

    avg_percentages = {coord_type: [] for coord_type in coord_types}
    for frame_idx in range(num_frames):
        # Tally the frame once instead of re-scanning it for every CN value.
        counts = Counter(values[frame_idx] for values in per_atom)
        for coord_type in coord_types:
            avg_percentages[coord_type].append(counts[coord_type] / total_atoms * 100)

    if plot_cn:
        frames = list(range(num_frames))
        colors = plt.get_cmap('tab10', len(coord_types)).colors
        markers = itertools.cycle(['o', 's', 'D', '^', 'v', 'p', '*', '+', 'x'])
        plt.figure(figsize=(10, 6))
        for i, coord_type in enumerate(coord_types):
            plt.plot(frames, avg_percentages[coord_type], label=f'CN={coord_type}',
                     color=colors[i], marker=next(markers), markevery=max(1, len(frames) // 20))
        for i, coord_type in enumerate(coord_types):
            mean_value = sum(avg_percentages[coord_type]) / len(avg_percentages[coord_type])
            plt.axhline(y=mean_value, color=colors[i], linestyle='--', alpha=0.7,
                        label=f'Mean CN={coord_type}: {mean_value:.1f}%')
        plt.xlabel('Frame Index', fontsize=12)
        plt.ylabel('Percentage of Atoms (%)', fontsize=12)
        if atom_type is not None:
            plt.title(f'Coordination Analysis: {atom_type} Atoms', fontsize=14)
        else:
            plt.title('Coordination Analysis', fontsize=14)
        plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, "CN_time_series.png"), dpi=300, bbox_inches='tight')
        plt.show()
        plt.close()

    return avg_percentages
177
+
178
+
179
def plot_coordination_distribution(avg_percentages, atom_type, plot_cn, output_dir="./", output_file="CN_distribution"):
    """
    Summarise coordination numbers as overall averages and optionally chart them.

    The averages come from ``plot_coordination_distribution_plotly``, which
    also writes the interactive HTML chart when ``plot_cn`` is set. In that
    case a static Matplotlib donut chart is additionally saved to
    ``<output_dir>/<output_file>.png`` and shown.

    Parameters
    ----------
    avg_percentages : Dict[int, List[float]]
        Per-frame percentages for each coordination number
    atom_type : Optional[str]
        Label used in the chart title
    plot_cn : bool
        Whether to produce the charts
    output_dir : str, optional
        Directory where output files will be saved
    output_file : str, optional
        File name (without extension) for the charts

    Returns
    -------
    Dict[int, float]
        Maps each coordination number to its mean percentage over all frames
    """
    summary = plot_coordination_distribution_plotly(avg_percentages, atom_type, plot_cn,
                                                    output_dir=output_dir,
                                                    output_file=output_file)
    if not plot_cn:
        return summary

    # Static version of the same donut chart.
    ordered = sorted(summary.items())
    shares = [pct for _, pct in ordered]
    palette = plt.cm.tab10.colors[:len(ordered)]

    fig, ax = plt.subplots(figsize=(10, 7))
    wedges, _ = ax.pie(shares,
                       wedgeprops=dict(width=0.5, edgecolor='w'),
                       startangle=90,
                       colors=palette)
    ax.legend(wedges,
              [f'CN={cn}: {pct:.1f}%' for cn, pct in ordered],
              title="Coordination Numbers",
              loc="center left",
              bbox_to_anchor=(1, 0.5),
              frameon=True,
              fancybox=True,
              shadow=True)
    ax.axis('equal')

    heading = (f'Average Distribution of {atom_type} Atom Coordination' if atom_type
               else 'Average Distribution of Atom Coordination')
    plt.title(heading, fontsize=14)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, f"{output_file}.png"), dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

    return summary
246
+
247
+
248
def plot_coordination_distribution_plotly(avg_percentages, atom_type, plot_cn, output_dir="./", output_file="CN_distribution"):
    """
    Average per-frame coordination percentages and optionally save an interactive donut chart.

    Parameters
    ----------
    avg_percentages : Dict[int, List[float]]
        Maps each coordination number to its per-frame percentages; every
        list must be non-empty
    atom_type : Optional[str]
        Label used in the chart title
    plot_cn : bool
        If True, write ``<output_dir>/<output_file>.html`` with Plotly
    output_dir : str, optional
        Directory where the HTML file is saved
    output_file : str, optional
        File name (without extension) of the chart

    Returns
    -------
    Dict[int, float]
        Maps each coordination number to the mean of its per-frame percentages
    """
    overall_avg_percentages = {coord_type: sum(percentages) / len(percentages)
                               for coord_type, percentages in avg_percentages.items()}

    if plot_cn:
        sorted_data = sorted(overall_avg_percentages.items())
        labels = [f"CN={coord_type}" for coord_type, _ in sorted_data]
        values = [pct for _, pct in sorted_data]
        hover_info = [f"CN={coord_type}: {pct:.2f}%" for coord_type, pct in sorted_data]

        # Integer lookups past the end of tab10 (10 colours) all return the
        # colormap's single "over" colour. Wrap the index so every wedge beyond
        # the tenth still gets a distinct, cycling colour.
        palette = plt.cm.tab10
        wedge_colors = [
            f'rgb{tuple(int(c * 255) for c in palette(i % palette.N)[:3])}'
            for i in range(len(labels))
        ]

        fig = go.Figure(data=[go.Pie(
            labels=labels,
            values=values,
            hole=0.4,
            textinfo='label+percent',
            hoverinfo='text',
            hovertext=hover_info,
            textfont=dict(size=12),
            marker=dict(
                colors=wedge_colors,
                line=dict(color='white', width=2)
            ),
        )])

        if atom_type:
            title_text = f'Average Distribution of {atom_type} Atom Coordination'
        else:
            title_text = 'Average Distribution of Atom Coordination'

        fig.update_layout(
            title=dict(
                text=title_text,
                font=dict(size=16)
            ),
            legend=dict(
                orientation='h',
                xanchor='center',
                x=0.5,
                y=-0.1
            ),
            height=600,
            width=800
        )

        html_path = os.path.join(output_dir, f"{output_file}.html")
        fig.write_html(html_path)
        print(f"Interactive coordination distribution chart saved to {html_path}")

    return overall_avg_percentages
319
+
320
+
321
def log_coordination_data(distribution, avg_percentages, atom_type, avg_cn=None, std_cn=None, output_dir="./"):
    """
    Write coordination statistics to a text file in ``output_dir``.

    The file is ``CN_<atom_type>_statistics.txt`` or, when ``atom_type`` is
    None, ``CN_statistics.txt``. It is written as UTF-8 because it contains "±".

    Parameters
    ----------
    distribution : Dict[int, float]
        Overall average percentage per coordination number
    avg_percentages : Dict[int, List[float]]
        Per-frame percentages per coordination number; must contain every key
        of ``distribution`` (used for the standard deviations)
    atom_type : Optional[str]
        Chemical symbol / label of the analysed atoms, or None
    avg_cn : float, optional
        Average coordination number
    std_cn : float, optional
        Standard deviation of the coordination number; the average line is
        only written when both ``avg_cn`` and ``std_cn`` are given
    output_dir : str, optional
        Directory where the statistics file will be saved

    Returns
    -------
    None
    """
    if atom_type is not None:
        stats_file = f"CN_{atom_type}_statistics.txt"
    else:
        stats_file = "CN_statistics.txt"
    stats_file = os.path.join(output_dir, stats_file)

    # Frame-to-frame spread of each coordination-number percentage.
    std_devs = {cn: np.std(values) for cn, values in avg_percentages.items()}

    # Explicit UTF-8: the report contains "±", which the platform default
    # encoding (e.g. ASCII under a C/POSIX locale) may not be able to encode.
    with open(stats_file, 'w', encoding='utf-8') as f:
        if atom_type is not None:
            f.write(f"Coordination Analysis for {atom_type} Atoms\n")
        else:
            f.write("Coordination Analysis\n")
        f.write("======================================\n\n")

        if avg_cn is not None and std_cn is not None:
            f.write(f"Average Coordination Number: {avg_cn:.2f} ± {std_cn:.2f}\n\n")

        f.write("Overall Average Percentages:\n")
        for coord_type, avg_percentage in sorted(distribution.items()):
            std_dev = std_devs[coord_type]
            f.write(f" CN={coord_type}: {avg_percentage:.2f}% ± {std_dev:.2f}%\n")

        # max() on an empty mapping would raise; just omit the line then.
        if distribution:
            most_common_cn = max(distribution.items(), key=lambda x: x[1])[0]
            f.write(f"\nMost common coordination number: {most_common_cn}\n")

    print(f"Coordination statistics saved to {stats_file}")
373
+
374
+
375
def coordination(traj_path, central_atoms, target_atoms, custom_cutoffs, frame_skip=10,
                 plot_cn=False, output_dir="./", mic=True):
    """
    Process a trajectory file to compute coordination numbers for each sampled frame.

    Every ``frame_skip``-th frame is read and analysed in parallel with
    ``coordination_frame``. Per-frame percentages, overall averages and a
    statistics file are then produced via ``get_avg_percentages``,
    ``plot_coordination_distribution`` and ``log_coordination_data``, and a
    summary is printed.

    Parameters
    ----------
    traj_path : str
        Path to the trajectory file
    central_atoms : Union[str, List[Union[int, str]]]
        Specifier for central atoms being analyzed
    target_atoms : Union[str, List[Union[int, str]]]
        Specifier for target atoms that interact with central atoms
    custom_cutoffs : Dict[Tuple[str, str], float]
        Dictionary with custom cutoff distances
    frame_skip : int, optional
        Interval for skipping frames (default: 10)
    plot_cn : bool, optional
        Boolean to indicate if plots should be generated (default: False)
    output_dir : str, optional
        Directory where output files will be saved (default: "./")
    mic : bool, optional
        Whether to use the minimum image convention (default: True)

    Returns
    -------
    List
        ``[coordination_dict, avg_percentages, distribution, atom_type, avg_cn, std_cn]``
        where ``coordination_dict`` maps each central atom index to its list of
        per-frame coordination numbers, and ``avg_cn``/``std_cn`` are the mean
        and standard deviation over all central atoms and frames

    Raises
    ------
    ValueError
        If no frames are selected from the trajectory
    """
    os.makedirs(output_dir, exist_ok=True)

    trajectory = read(traj_path, index=f"::{frame_skip}")
    if not trajectory:
        raise ValueError(f"No frames selected from {traj_path!r} with frame_skip={frame_skip}")

    results = Parallel(n_jobs=-1)(
        delayed(coordination_frame)(atoms, central_atoms, target_atoms, custom_cutoffs, mic=mic)
        for atoms in trajectory
    )

    coordination_dict = {}
    for frame_dict in results:
        for key, value in frame_dict.items():
            coordination_dict.setdefault(key, []).append(value)

    # Label used in titles and output file names.
    if isinstance(central_atoms, str):
        if central_atoms.endswith('.npy'):
            # Use only the file's base name, never a path.
            atom_type = os.path.splitext(os.path.basename(central_atoms))[0]
        else:
            atom_type = central_atoms
    elif isinstance(central_atoms, (int, np.integer)):
        # ``indices`` treats an integer as an atom *index*, so label it with
        # that atom's element rather than reading it as an atomic number.
        atom_type = trajectory[0].get_chemical_symbols()[int(central_atoms)]
    else:
        atom_type = None

    avg_percentages = get_avg_percentages(coordination_dict, atom_type, plot_cn, output_dir=output_dir)

    distribution = plot_coordination_distribution(avg_percentages, atom_type, plot_cn, output_dir=output_dir)

    all_cn_values = np.array(list(itertools.chain.from_iterable(coordination_dict.values())))
    avg_cn = np.mean(all_cn_values)
    std_cn = np.std(all_cn_values)

    log_coordination_data(distribution, avg_percentages, atom_type,
                          avg_cn=avg_cn, std_cn=std_cn, output_dir=output_dir)

    print("\nCoordination Statistics Summary:")
    print("=" * 40)

    std_devs = {cn: np.std(values) for cn, values in avg_percentages.items()}

    for coord_type, avg_percentage in sorted(distribution.items()):
        std_dev = std_devs[coord_type]
        print(f"CN={coord_type}: {avg_percentage:.2f}% ± {std_dev:.2f}%")

    # max() on an empty mapping would raise.
    if distribution:
        most_common_cn = max(distribution.items(), key=lambda x: x[1])[0]
        print(f"\nMost common coordination number: {most_common_cn}")

    print(f"\nAverage Coordination Number: {avg_cn:.2f} ± {std_cn:.2f}")

    return [coordination_dict, avg_percentages, distribution, atom_type, avg_cn, std_cn]
460
+
461
+
462
def contacts_frame(atoms, central_atoms, target_atoms, custom_cutoffs, mic=True):
    """
    Compute the central-target distance sub-matrix and the matching cutoff matrix for one frame.

    When the central and target selections share atoms, a pair can appear
    twice: once as (i, j) and once as (j, i). Such a pair is kept only in its
    i < j orientation, and the mirrored entry is set to ``np.inf`` so it can
    never count as a contact. Pairs that appear only once are always kept.

    Parameters
    ----------
    atoms : ase.Atoms
        ASE Atoms object containing atomic structure
    central_atoms : Union[str, List[Union[int, str]]]
        Selection criteria for central atoms being analyzed
    target_atoms : Union[str, List[Union[int, str]]]
        Selection criteria for target atoms that interact with central atoms
    custom_cutoffs : Dict[Tuple[str, str], float]
        Cutoffs keyed by pairs of chemical symbols (pair order irrelevant;
        the first entry wins for a repeated pair), or None. Pairs without an
        entry use 0.6 times the sum of the van der Waals radii.
    mic : bool, optional
        Whether to use minimum image convention (default: True)

    Returns
    -------
    Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
        ``(sub_dm, cutoff_matrix, indices_central, indices_target)``, where
        ``sub_dm[a, b]`` is the distance between atoms ``indices_central[a]``
        and ``indices_target[b]`` and ``cutoff_matrix`` has the same shape
    """
    indices_central = np.asarray(indices(atoms, central_atoms))
    indices_target = np.asarray(indices(atoms, target_atoms))

    dm = atoms.get_all_distances(mic=mic)
    np.fill_diagonal(dm, np.inf)  # Avoids self-interactions
    sub_dm = dm[np.ix_(indices_central, indices_target)]

    # Entry (i, j) duplicates entry (j, i) only if j is also a central atom
    # and i is also a target atom. Drop the i > j copy of those pairs, and only
    # those; all other pairs are unique and must be kept.
    central_is_target = np.isin(indices_central, indices_target)
    target_is_central = np.isin(indices_target, indices_central)
    mirrored_duplicate = (
        (indices_central[:, np.newaxis] > indices_target[np.newaxis, :])
        & central_is_target[:, np.newaxis]
        & target_is_central[np.newaxis, :]
    )
    sub_dm = np.where(mirrored_duplicate, np.inf, sub_dm)

    # Default cutoffs from van der Waals radii.
    numbers = np.array(atoms.get_atomic_numbers())
    central_atomic_numbers = numbers[indices_central]
    target_atomic_numbers = numbers[indices_target]

    central_vdw_radii = vdw_radii[central_atomic_numbers]
    target_vdw_radii = vdw_radii[target_atomic_numbers]

    cutoff_matrix = 0.6 * (central_vdw_radii[:, np.newaxis] + target_vdw_radii[np.newaxis, :])

    if custom_cutoffs is not None:
        # Unordered (Z, Z) pair -> cutoff; setdefault keeps the first entry.
        pair_cutoffs = {}
        for pair, value in custom_cutoffs.items():
            key = tuple(sorted(atomic_numbers[symbol] for symbol in pair))
            pair_cutoffs.setdefault(key, value)

        for a, zc in enumerate(central_atomic_numbers):
            for b, zt in enumerate(target_atomic_numbers):
                key = tuple(sorted([zc, zt]))
                if key in pair_cutoffs:
                    cutoff_matrix[a, b] = pair_cutoffs[key]

    return sub_dm, cutoff_matrix, indices_central, indices_target
531
+
532
+
533
def plot_contact_heatmap(contact_matrix, frame_skip, time_step, x_labels, y_labels, atom_type, output_dir="./"):
    """
    Write an interactive Plotly heatmap of the accumulated contact time of every central-target pair.

    A pair's time is (number of sampled frames in contact) * ``frame_skip``
    * ``time_step`` / 1000, and is labelled in ps.

    Parameters
    ----------
    contact_matrix : np.ndarray
        Boolean 3D array of contacts; axis 0 indexes frames
    frame_skip : int
        Number of frames skipped when processing the trajectory
    time_step : float
        Time step between frames (used to convert counts to time)
    x_labels : List[str]
        Labels for the target atoms (x-axis)
    y_labels : List[str]
        Labels for the central atoms (y-axis)
    atom_type : Optional[str]
        Prefix for the output file name, or None
    output_dir : str, optional
        Directory where the HTML file is written (default: "./")

    Returns
    -------
    None
    """
    frames_in_contact = np.sum(contact_matrix, axis=0)
    contact_times_matrix = frames_in_contact * frame_skip * time_step / 1000

    hover_text = [
        [
            f"Central: {central_label}<br>"
            f"Target: {target_label}<br>"
            f"Contact time: {contact_times_matrix[row, col]:.2f} ps"
            for col, target_label in enumerate(x_labels)
        ]
        for row, central_label in enumerate(y_labels)
    ]

    heatmap = go.Heatmap(
        z=contact_times_matrix,
        x=x_labels,
        y=y_labels,
        colorscale='Viridis',
        hoverinfo='text',
        text=hover_text,
        colorbar=dict(title='Contact Time (ps)'),
    )
    fig = go.Figure(data=heatmap)
    fig.update_layout(
        title=dict(text='Heatmap of contact times within cutoffs', font=dict(size=16)),
        xaxis=dict(title='Target Atoms', tickfont=dict(size=10)),
        yaxis=dict(title='Central Atoms', tickfont=dict(size=10)),
        width=900,
        height=700,
        autosize=True,
    )

    file_name = "heatmap_contacts.html" if atom_type is None else f"{atom_type}_heatmap_contacts.html"
    html_filename = os.path.join(output_dir, file_name)

    fig.write_html(html_filename)
    print(f"Interactive contact heatmap saved to {html_filename}")
595
+
596
+
597
def plot_distance_heatmap(sub_dm_total, x_labels, y_labels, atom_type, output_dir="./"):
    """
    Write an interactive Plotly heatmap of the mean central-target distance over all frames.

    Hovering a cell shows the mean, standard deviation, minimum and maximum
    of that pair's distance across frames.

    Parameters
    ----------
    sub_dm_total : np.ndarray
        3D array of sub-distance matrices; axis 0 indexes frames
    x_labels : List[str]
        Labels for the target atoms (x-axis)
    y_labels : List[str]
        Labels for the central atoms (y-axis)
    atom_type : Optional[str]
        Prefix for the output file name, or None
    output_dir : str, optional
        Directory where the HTML file is written (default: "./")

    Returns
    -------
    None
    """
    mean_dm = np.mean(sub_dm_total, axis=0)
    std_dm = np.std(sub_dm_total, axis=0)

    def _cell_text(row, col, central_label, target_label):
        # Hover text for one (central, target) cell.
        series = sub_dm_total[:, row, col]
        return (f"Central: {central_label}<br>"
                f"Target: {target_label}<br>"
                f"Avg distance: {mean_dm[row, col]:.3f} Å<br>"
                f"Std deviation: {std_dm[row, col]:.3f} Å<br>"
                f"Min: {np.min(series):.3f} Å<br>"
                f"Max: {np.max(series):.3f} Å")

    hover_text = [
        [_cell_text(row, col, central_label, target_label)
         for col, target_label in enumerate(x_labels)]
        for row, central_label in enumerate(y_labels)
    ]

    heatmap = go.Heatmap(
        z=mean_dm,
        x=x_labels,
        y=y_labels,
        colorscale='Viridis',
        hoverinfo='text',
        text=hover_text,
        colorbar=dict(title='Distance (Å)'),
    )
    fig = go.Figure(data=heatmap)
    fig.update_layout(
        title=dict(text='Distance matrix of selected atoms', font=dict(size=16)),
        xaxis=dict(title='Target Atoms', tickfont=dict(size=10)),
        yaxis=dict(title='Central Atoms', tickfont=dict(size=10)),
        width=900,
        height=700,
        autosize=True,
    )

    file_name = "heatmap_distance.html" if atom_type is None else f"{atom_type}_heatmap_distance.html"
    html_filename = os.path.join(output_dir, file_name)

    fig.write_html(html_filename)
    print(f"Interactive distance heatmap saved to {html_filename}")
661
+
662
+
663
def plot_contact_distance(sub_dm_total, contact_matrix, time_step, frame_skip, output_dir="./"):
    """
    Plot the average contact distance and the contacts per central atom over time.

    Writes an interactive Plotly chart to
    ``<output_dir>/average_contact_analysis.html`` and a static Matplotlib
    version to ``<output_dir>/average_contact_analysis.png``, which is also
    shown.

    For each frame, the average distance is taken over the pairs flagged in
    ``contact_matrix``. Frames without any contact have no such average; their
    value is linearly interpolated from frames that do. If no frame contains a
    contact, the distance curve and its mean are NaN instead of raising.

    Parameters
    ----------
    sub_dm_total : np.ndarray
        3D array of sub-distance matrices; axis 0 indexes frames, axis 1 the
        central atoms
    contact_matrix : np.ndarray
        Boolean array of the same shape indicating which distances are contacts
    time_step : float
        Time between frames; the x axis is ``frame * time_step * frame_skip / 1000``
        and is labelled ps (so presumably ``time_step`` is in fs)
    frame_skip : int
        Number of frames skipped when processing
    output_dir : str, optional
        Directory where the output files will be saved (default: "./")

    Returns
    -------
    None
    """
    # Contacts per central atom in each frame.
    contact_count = np.sum(contact_matrix, axis=(1, 2)) / np.shape(sub_dm_total)[1]
    x = np.arange(len(contact_count)) * time_step * frame_skip / 1000

    frames_with_contact = np.any(contact_matrix, axis=(1, 2))
    if frames_with_contact.any():
        contact_distance = np.where(contact_matrix, sub_dm_total, np.nan)
        # Average only over frames that have contacts (no all-NaN slices).
        frame_means = np.nanmean(contact_distance[frames_with_contact], axis=(1, 2))
        interpolated = np.interp(x, x[frames_with_contact], frame_means)
        mean_distance = np.mean(interpolated)
    else:
        # np.interp raises on an empty set of sample points; with no contacts
        # at all there is simply no distance to report.
        interpolated = np.full(len(x), np.nan)
        mean_distance = np.nan
    mean_count = np.mean(contact_count)

    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=x,
        y=interpolated,
        mode='lines+markers',
        name='Avg Dist',
        line=dict(color='blue'),
        yaxis='y'
    ))

    fig.add_trace(go.Scatter(
        x=[x[0], x[-1]],
        y=[mean_distance, mean_distance],
        mode='lines',
        name=f'Mean Dist: {mean_distance:.2f} Å',
        line=dict(color='blue', dash='dash'),
        yaxis='y'
    ))

    fig.add_trace(go.Scatter(
        x=x,
        y=contact_count,
        mode='lines+markers',
        name='Contact Count',
        line=dict(color='red'),
        yaxis='y2'
    ))

    fig.add_trace(go.Scatter(
        x=[x[0], x[-1]],
        y=[mean_count, mean_count],
        mode='lines',
        name=f'Mean Count: {mean_count:.1f}',
        line=dict(color='red', dash='dash'),
        yaxis='y2'
    ))

    fig.update_layout(
        title='Average Distance of Contacts & Contact Count',
        xaxis=dict(
            title='Time (ps)',
            showgrid=True,
            gridwidth=0.5
        ),
        yaxis=dict(
            title=dict(
                text='Distance (Å)',
                font=dict(color='blue')
            ),
            tickfont=dict(color='blue'),
            showgrid=True,
            gridwidth=0.5
        ),
        yaxis2=dict(
            title=dict(
                text='Contact count per atom',
                font=dict(color='red')
            ),
            tickfont=dict(color='red'),
            anchor="x",
            overlaying="y",
            side="right"
        ),
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        ),
        width=900,
        height=600
    )

    html_filename = os.path.join(output_dir, "average_contact_analysis.html")
    fig.write_html(html_filename)
    print(f"Interactive contact analysis chart saved to {html_filename}")

    # Static Matplotlib version with twin y axes.
    fig_mpl, ax1 = plt.subplots(figsize=(10, 6))
    ax2 = ax1.twinx()
    marker_step = max(1, len(x) // 20)

    line1 = ax1.plot(x, interpolated, 'o-', color='blue', label='Avg Dist',
                     markersize=4, markevery=marker_step)
    if np.isfinite(mean_distance):
        ax1.axhline(y=mean_distance, color='blue', linestyle='--',
                    label=f'Mean Distance: {mean_distance:.2f} Å')

    line2 = ax2.plot(x, contact_count, 'o-', color='red', label='Contact Count',
                     markersize=4, markevery=marker_step)
    ax2.axhline(y=mean_count, color='red', linestyle='--',
                label=f'Mean Count: {mean_count:.1f}')

    ax1.set_xlabel('Time (ps)', fontsize=12)
    ax1.set_ylabel('Distance (Å)', color='blue', fontsize=12)
    ax2.set_ylabel('Contact count per atom', color='red', fontsize=12)
    plt.title('Average Distance of Contacts & Contact Count', fontsize=14)

    ax1.tick_params(axis='y', labelcolor='blue')
    ax2.tick_params(axis='y', labelcolor='red')

    ax1.grid(True, alpha=0.3)

    # One shared legend covering both axes, with proxy handles for the means.
    all_lines = line1 + line2 + [
        plt.Line2D([0], [0], color='blue', linestyle='--'),
        plt.Line2D([0], [0], color='red', linestyle='--')
    ]
    all_labels = [l.get_label() for l in line1 + line2] + [
        f'Mean Dist: {mean_distance:.2f} Å',
        f'Mean Count: {mean_count:.1f}'
    ]

    fig_mpl.legend(all_lines, all_labels, loc='upper center',
                   bbox_to_anchor=(0.5, -0.02), ncol=2)

    plt.tight_layout(pad=1.2)

    png_filename = os.path.join(output_dir, "average_contact_analysis.png")
    plt.savefig(png_filename, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()

    print(f"Static contact analysis chart saved to {png_filename}")
822
+
823
+
824
def save_matrix_data(sub_dm_total, contact_matrix, output_dir="./"):
    """
    Store the distance and contact arrays as ``.npy`` files.

    Writes ``sub_dm_total.npy`` first, then ``contact_matrix.npy``, both
    inside ``output_dir``.

    Parameters
    ----------
    sub_dm_total : np.ndarray
        3D numpy array of sub-distance matrices
    contact_matrix : np.ndarray
        Boolean 3D numpy array of contact information
    output_dir : str, optional
        Directory where the output files will be saved (default: "./")

    Returns
    -------
    None
    """
    outputs = (
        ("sub_dm_total.npy", sub_dm_total),
        ("contact_matrix.npy", contact_matrix),
    )
    for file_name, array in outputs:
        np.save(os.path.join(output_dir, file_name), array)
843
+
844
+
845
def contacts(traj_path, central_atoms, target_atoms, custom_cutoffs, frame_skip=10,
             plot_distance_matrix=False, plot_contacts=False, time_step=None, save_data=False,
             output_dir="./", mic=True):
    """
    Process a trajectory file to compute contacts between central and target atoms.

    Every ``frame_skip``-th frame is processed in parallel with
    ``contacts_frame``. A pair is a contact when its distance is below the
    corresponding cutoff. Optional plots and ``.npy`` dumps are written to
    ``output_dir``.

    Parameters
    ----------
    traj_path : str
        Path to the trajectory file
    central_atoms : Union[str, List[Union[int, str]]]
        Criteria for selecting central atoms being analyzed
    target_atoms : Union[str, List[Union[int, str]]]
        Criteria for selecting target atoms that interact with central atoms
    custom_cutoffs : Dict[Tuple[str, str], float]
        Dictionary with custom cutoff values for specific atom pairs
    frame_skip : int, optional
        Number of frames to skip (default: 10)
    plot_distance_matrix : bool, optional
        Boolean flag to plot average distance heatmap (default: False)
    plot_contacts : bool, optional
        Boolean flag to plot contact times heatmap (default: False)
    time_step : float, optional
        Time between frames in fs; required for the contact plots. If
        ``plot_contacts`` is True but this is None, those plots are skipped
        with a warning.
    save_data : bool, optional
        Boolean flag to save matrices as npy files (default: False)
    output_dir : str, optional
        Directory where output files will be saved (default: "./")
    mic : bool, optional
        Whether to use minimum image convention (default: True)

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        ``(sub_dm_total, contact_matrix)``: distances with shape
        (frames, n_central, n_target) and the Boolean contact array of the same shape

    Raises
    ------
    ValueError
        If no frames are selected from the trajectory
    """
    import warnings

    os.makedirs(output_dir, exist_ok=True)

    trajectory = read(traj_path, index=f"::{frame_skip}")
    if not trajectory:
        raise ValueError(f"No frames selected from {traj_path!r} with frame_skip={frame_skip}")

    results = Parallel(n_jobs=-1)(
        delayed(contacts_frame)(atoms, central_atoms, target_atoms, custom_cutoffs, mic=mic)
        for atoms in trajectory
    )

    sub_dm_list, cutoff_matrices, indices_central, indices_target = zip(*results)
    # NOTE(review): the first frame's cutoffs, indices and symbols are reused
    # for all frames, which presumes the same atoms are selected in every frame.
    cutoff_matrix = cutoff_matrices[0]
    sub_dm_total = np.array(sub_dm_list)

    first_symbols = trajectory[0].get_chemical_symbols()

    # Label used as an output file-name prefix: must be a bare name, never a
    # path (an absolute path would make os.path.join discard output_dir).
    if isinstance(central_atoms, str):
        if central_atoms.endswith('.npy'):
            atom_type = os.path.splitext(os.path.basename(central_atoms))[0]
        else:
            atom_type = central_atoms
    elif isinstance(central_atoms, (int, np.integer)):
        # ``indices`` treats an integer as an atom *index*, so label it with
        # that atom's element rather than reading it as an atomic number.
        atom_type = first_symbols[int(central_atoms)]
    else:
        atom_type = None

    y_labels = [f"{first_symbols[i]}({i})" for i in indices_central[0]]
    x_labels = [f"{first_symbols[i]}({i})" for i in indices_target[0]]

    contact_matrix = sub_dm_total < cutoff_matrix

    if plot_contacts:
        if time_step is None:
            warnings.warn("plot_contacts=True requires time_step; contact plots were skipped.",
                          stacklevel=2)
        else:
            plot_contact_heatmap(contact_matrix, frame_skip, time_step, x_labels, y_labels, atom_type,
                                 output_dir=output_dir)
            plot_contact_distance(sub_dm_total, contact_matrix, time_step, frame_skip, output_dir=output_dir)

    if plot_distance_matrix:
        plot_distance_heatmap(sub_dm_total, x_labels, y_labels, atom_type, output_dir=output_dir)

    if save_data:
        save_matrix_data(sub_dm_total, contact_matrix, output_dir=output_dir)

    return sub_dm_total, contact_matrix