crisp-ase 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. CRISP/__init__.py +99 -0
  2. CRISP/_version.py +1 -0
  3. CRISP/cli.py +41 -0
  4. CRISP/data_analysis/__init__.py +38 -0
  5. CRISP/data_analysis/clustering.py +838 -0
  6. CRISP/data_analysis/contact_coordination.py +915 -0
  7. CRISP/data_analysis/h_bond.py +772 -0
  8. CRISP/data_analysis/msd.py +1199 -0
  9. CRISP/data_analysis/prdf.py +404 -0
  10. CRISP/data_analysis/volumetric_atomic_density.py +527 -0
  11. CRISP/py.typed +1 -0
  12. CRISP/simulation_utility/__init__.py +31 -0
  13. CRISP/simulation_utility/atomic_indices.py +155 -0
  14. CRISP/simulation_utility/atomic_traj_linemap.py +278 -0
  15. CRISP/simulation_utility/error_analysis.py +254 -0
  16. CRISP/simulation_utility/interatomic_distances.py +200 -0
  17. CRISP/simulation_utility/subsampling.py +241 -0
  18. CRISP/tests/DataAnalysis/__init__.py +1 -0
  19. CRISP/tests/DataAnalysis/test_clustering_extended.py +212 -0
  20. CRISP/tests/DataAnalysis/test_contact_coordination.py +184 -0
  21. CRISP/tests/DataAnalysis/test_contact_coordination_extended.py +465 -0
  22. CRISP/tests/DataAnalysis/test_h_bond_complete.py +326 -0
  23. CRISP/tests/DataAnalysis/test_h_bond_extended.py +322 -0
  24. CRISP/tests/DataAnalysis/test_msd_complete.py +305 -0
  25. CRISP/tests/DataAnalysis/test_msd_extended.py +522 -0
  26. CRISP/tests/DataAnalysis/test_prdf.py +206 -0
  27. CRISP/tests/DataAnalysis/test_volumetric_atomic_density.py +463 -0
  28. CRISP/tests/SimulationUtility/__init__.py +1 -0
  29. CRISP/tests/SimulationUtility/test_atomic_traj_linemap.py +101 -0
  30. CRISP/tests/SimulationUtility/test_atomic_traj_linemap_extended.py +469 -0
  31. CRISP/tests/SimulationUtility/test_error_analysis_extended.py +151 -0
  32. CRISP/tests/SimulationUtility/test_interatomic_distances.py +223 -0
  33. CRISP/tests/SimulationUtility/test_subsampling.py +365 -0
  34. CRISP/tests/__init__.py +1 -0
  35. CRISP/tests/test_CRISP.py +28 -0
  36. CRISP/tests/test_cli.py +87 -0
  37. CRISP/tests/test_crisp_comprehensive.py +679 -0
  38. crisp_ase-1.1.2.dist-info/METADATA +141 -0
  39. crisp_ase-1.1.2.dist-info/RECORD +42 -0
  40. crisp_ase-1.1.2.dist-info/WHEEL +5 -0
  41. crisp_ase-1.1.2.dist-info/entry_points.txt +2 -0
  42. crisp_ase-1.1.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,404 @@
1
+ """
2
+ CRISP/data_analysis/prdf.py
3
+
4
+ This script performs Radial Distribution Function analysis on molecular dynamics trajectory data.
5
+ """
6
+
7
+ import os
8
+ import numpy as np
9
+ import pickle
10
+ import math
11
+ from ase.io import read
12
+ from ase import Atoms
13
+ from typing import Optional, Union, Tuple, List
14
+ from joblib import Parallel, delayed
15
+ import matplotlib.pyplot as plt
16
+ from matplotlib.animation import FuncAnimation
17
+
18
+ __all__ = ['check_cell_and_r_max', 'compute_pairwise_rdf', 'plot_rdf', 'animate_rdf', 'analyze_rdf']
19
+
20
def check_cell_and_r_max(atoms: "Atoms", rmax: float):
    """
    Check that the cell is large enough to contain a sphere of radius rmax.

    The shortest cell vector must be at least ``2 * rmax`` long. Cell
    objects that provide ``lengths()`` are checked through it. Objects
    without ``lengths()`` are checked through their ``volume`` attribute
    when they have one, and otherwise (e.g. a plain 3x3 array) through the
    norms of their rows.

    Parameters
    ----------
    atoms : Atoms
        ASE Atoms object with cell information
    rmax : float
        Maximum radius to consider

    Raises
    ------
    ValueError
        If cell is not defined or too small for requested rmax
    """
    cell = atoms.cell
    if not cell.any():
        raise ValueError("RDF Error: The system's cell is not defined.")

    try:
        lengths = cell.lengths()
    except AttributeError:
        volume = getattr(cell, "volume", None)
        if volume is not None:
            # Legacy fallback: compare the cell volume with the sphere volume.
            required_volume = (4/3) * math.pi * rmax**3
            if volume < required_volume:
                raise ValueError(f"RDF Error: Cell volume {volume} is too small for rmax {rmax} (required >= {required_volume}).")
            return
        # A plain array-like cell has neither ``lengths()`` nor ``volume``;
        # the length of each cell vector is the norm of the matching row.
        lengths = np.linalg.norm(np.asarray(cell, dtype=float), axis=1)

    shortest = np.min(lengths)
    if shortest < 2 * rmax:
        raise ValueError(f"RDF Error: Cell length {shortest} is smaller than 2*rmax ({2*rmax}).")
49
+
50
def compute_pairwise_rdf(atoms: "Atoms",
                         ref_indices: List[int],
                         target_indices: List[int],
                         rmax: float, nbins: int,
                         volume: Optional[float] = None):
    """
    Compute pairwise radial distribution function between sets of atoms.

    Distances come from ``atoms.get_all_distances(mic=True)``. They are
    histogrammed on ``nbins`` equal bins over ``[0, rmax]`` and each bin is
    divided by ``norm * r**2``, where ``r`` is the bin center. Zero
    distances (self-pairs) are discarded.

    When both index lists contain the same set of atoms, each unordered
    pair is counted once and the normalization is halved accordingly. The
    order of ``target_indices`` does not affect the result in that case.
    Index lists are assumed to be free of duplicates.

    Parameters
    ----------
    atoms : Atoms
        ASE Atoms object
    ref_indices : List[int]
        Indices of reference atoms
    target_indices : List[int]
        Indices of target atoms
    rmax : float
        Maximum radius for RDF calculation
    nbins : int
        Number of bins for histogram
    volume : float, optional
        Custom normalization volume to use instead of cell volume.
        Useful for non-periodic systems or custom normalization.
        (default: None, uses atoms.get_volume())

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        RDF values and corresponding bin centers
    """
    dm = atoms.get_all_distances(mic=True)
    dr = float(rmax / nbins)
    volume = atoms.get_volume() if volume is None else volume

    if set(ref_indices) == set(target_indices):
        # Index both axes with ref_indices so the sub-matrix is symmetric;
        # its strict upper triangle then holds each unordered pair exactly
        # once, whatever order target_indices was given in.
        sub_dm = dm[np.ix_(ref_indices, ref_indices)]
        sub_dm = np.triu(sub_dm, k=1)
        distances = sub_dm[sub_dm > 0]

        N = len(ref_indices)

        # Division by 2 because each same-species pair is counted once.
        norm = (4 * math.pi * dr * (N / volume) * N) / 2
    else:
        sub_dm = dm[np.ix_(ref_indices, target_indices)]
        distances = sub_dm[sub_dm > 0]

        N_A = len(ref_indices)
        N_B = len(target_indices)

        norm = 4 * math.pi * dr * (N_A / volume) * N_B

    hist, bin_edges = np.histogram(distances, bins=nbins, range=(0, rmax))
    bin_centers = 0.5 * (bin_edges[1:] + bin_edges[:-1])
    rdf = hist / (norm * (bin_centers**2))

    return rdf, bin_centers
109
+
110
class Analysis:
    """
    Class for analyzing atomic trajectories and calculating RDFs.

    Parameters
    ----------
    images : List[Atoms]
        List of ASE Atoms objects representing trajectory frames
    """

    def __init__(self, images: List["Atoms"]):
        self.images = images

    def _get_slice(self, imageIdx: Optional[Union[int, slice]]):
        """
        Convert image index to slice for selecting trajectory frames.

        Parameters
        ----------
        imageIdx : Optional[Union[int, slice]]
            Index or slice to select images. ``None`` selects every frame;
            an integer selects exactly one frame (negative values count
            from the end).

        Returns
        -------
        slice
            Slice object for image selection

        Raises
        ------
        IndexError
            If an integer index is outside the range of stored images
        """
        if imageIdx is None:
            return slice(None)
        if isinstance(imageIdx, (int, np.integer)):
            n_images = len(self.images)
            if not -n_images <= imageIdx < n_images:
                raise IndexError(
                    f"Image index {imageIdx} is out of range for {n_images} images."
                )
            stop = imageIdx + 1
            # For imageIdx == -1 the stop would be 0 (an empty slice), so
            # leave the slice open-ended instead.
            return slice(imageIdx, stop if stop != 0 else None)
        return imageIdx

    def get_rdf(self,
                rmax: float,
                nbins: int = 100,
                imageIdx: Optional[Union[int, slice]] = None,
                atomic_indices: Optional[Tuple[List[int], List[int]]] = None,
                return_dists: bool = False):
        """
        Calculate radial distribution function for trajectory frames.

        Every selected frame is validated with ``check_cell_and_r_max`` and
        processed with ``compute_pairwise_rdf``; frames are dispatched
        through joblib with ``n_jobs=-1``.

        Parameters
        ----------
        rmax : float
            Maximum radius for RDF calculation
        nbins : int, optional
            Number of bins for histogram (default: 100)
        imageIdx : Optional[Union[int, slice]], optional
            Index or slice to select images (default: None, all images)
        atomic_indices : Optional[Tuple[List[int], List[int]]], optional
            Tuple of (reference_indices, target_indices) for partial RDF.
            When None, all atoms of each frame are used for both sets.
        return_dists : bool, optional
            Whether to return bin center distances (default: False)

        Returns
        -------
        List[Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]]
            List of RDF values or tuples (RDF, bin_centers) for each frame
        """
        images_to_process = self.images[self._get_slice(imageIdx)]

        # Unpack eagerly so a malformed atomic_indices fails before any
        # worker is started.
        if atomic_indices is not None:
            ref_indices, target_indices = atomic_indices

        def process_image(image: "Atoms"):
            check_cell_and_r_max(image, rmax)
            if atomic_indices is None:
                # Total RDF: every atom is both reference and target.
                ref = tgt = list(range(len(image)))
            else:
                ref, tgt = ref_indices, target_indices
            rdf, bin_centers = compute_pairwise_rdf(image, ref, tgt, rmax, nbins)
            return (rdf, bin_centers) if return_dists else rdf

        ls_rdf = Parallel(n_jobs=-1)(delayed(process_image)(image) for image in images_to_process)
        return ls_rdf
192
+
193
def plot_rdf(x_data_all, y_data_all, title=None, output_file=None):
    """
    Plot the average RDF with peak location marked by vertical line.

    Parameters
    ----------
    x_data_all : np.ndarray
        Distance values in Ångström
    y_data_all : List[np.ndarray]
        RDF values for each frame
    title : str, optional
        Custom title for the plot
    output_file : str, optional
        Path to save the plot (if None, just display)

    Returns
    -------
    None
    """
    # Average RDF across all frames
    y_data_avg = np.mean(y_data_all, axis=0)

    # Location and height of the maximum of the average RDF
    max_y_index = np.argmax(y_data_avg)
    max_y_x = x_data_all[max_y_index]
    max_y = y_data_avg[max_y_index]

    fig, ax = plt.subplots(figsize=(10, 6))

    ax.plot(x_data_all, y_data_avg, linewidth=2, label='Average RDF')
    ax.axvline(x=max_y_x, color='red', linestyle='--', label=f'Peak at {max_y_x:.2f} Å')

    ax.set_xlabel('Distance (Å)', fontsize=12)
    ax.set_ylabel('g(r)', fontsize=12)
    ax.set_title(title or 'Average Radial Distribution Function', fontsize=14)

    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=10)

    # Only pin the upper limit when the peak is a usable number: a NaN/Inf
    # limit is rejected by matplotlib and a zero peak gives a degenerate
    # (0, 0) range.
    if np.isfinite(max_y) and max_y > 0:
        ax.set_ylim(bottom=0, top=max_y * 1.2)
    else:
        ax.set_ylim(bottom=0)

    fig.tight_layout()

    if output_file:
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        plt.show()  # Display the plot in addition to saving it
        plt.close(fig)
    else:
        plt.show()
243
+
244
+
245
def animate_rdf(x_data_all, y_data_all, output_file=None):
    """
    Animate the per-frame RDF against the trajectory-averaged RDF.

    Each animation frame shows the RDF of one trajectory frame, the average
    RDF as a dashed line, and a vertical marker at the maximum of the
    current frame's RDF.

    Parameters
    ----------
    x_data_all : np.ndarray
        Distance values in Ångström
    y_data_all : List[np.ndarray]
        RDF values for each frame
    output_file : str, optional
        Path of the GIF to write. An interactive HTML version is written
        next to it with the same base name and an ``.html`` extension.
        If None, nothing is written.

    Returns
    -------
    FuncAnimation
        The animation object
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    y_data_avg = np.mean(y_data_all, axis=0)

    def label_axes():
        ax.set_xlabel('Distance (Å)', fontsize=12)
        ax.set_ylabel('g(r)', fontsize=12)

    # Label once up front so tight_layout reserves room for the labels.
    label_axes()

    # Common y-range for all frames so the axis does not jump.
    max_y = max([np.max(y) for y in y_data_all] + [np.max(y_data_avg)]) * 1.1

    def update(frame):
        ax.clear()  # Clears the previous frame (including axis labels)
        label_axes()  # ax.clear() wiped the labels, so restore them
        y = y_data_all[frame]

        ax.plot(x_data_all, y, linewidth=2, label='Current Frame')

        ax.plot(x_data_all, y_data_avg, linewidth=2, linestyle='--',
                color='purple', label='Average RDF')

        max_y_index = np.argmax(y)
        max_y_x = x_data_all[max_y_index]

        ax.axvline(x=max_y_x, color='red', linestyle='--', label=f'Peak at {max_y_x:.2f} Å')

        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.set_ylim(0, max_y)
        ax.set_title(f'Radial Distribution Function - Frame {frame}', fontsize=14)
        return ax,

    ani = FuncAnimation(fig, update, frames=range(len(y_data_all)),
                        interval=200, blit=False)

    # Apply the layout before anything is rendered to disk.
    fig.tight_layout()

    if output_file:
        html_file = os.path.splitext(output_file)[0] + ".html"
        html_code = ani.to_jshtml()
        with open(html_file, 'w', encoding='utf-8') as f:
            f.write(html_code)
        print(f"Interactive animation saved to '{html_file}'")

        # Saving the GIF is best-effort: the HTML version already exists.
        try:
            ani.save(output_file, writer='pillow', fps=5)
            print(f"GIF animation saved to '{output_file}'")
        except Exception as e:
            print(f"Warning: Could not save GIF animation: {e}")

        plt.close(fig)

    return ani
297
+
298
def analyze_rdf(use_prdf: bool,
                rmax: float,
                traj_path: str,
                nbins: int = 100,
                frame_skip: int = 10,
                output_filename: Optional[str] = None,
                atomic_indices: Optional[Tuple[List[int], List[int]]] = None,
                output_dir: str = 'custom_ase',
                create_plots: bool = False):
    """
    Analyze trajectory and calculate radial distribution functions.

    Parameters
    ----------
    use_prdf : bool
        Whether to calculate partial RDF (True) or total RDF (False)
    rmax : float
        Maximum radius for RDF calculation
    traj_path : str
        Path to trajectory file
    nbins : int, optional
        Number of bins for histogram (default: 100)
    frame_skip : int, optional
        Step used when reading frames; every ``frame_skip``-th frame is
        analyzed (default: 10)
    output_filename : Optional[str], optional
        Custom filename for output; its extension is dropped and used as
        the base name of all outputs (default: None, auto-generated)
    atomic_indices : Optional[Tuple[List[int], List[int]]], optional
        Tuple of (reference_indices, target_indices) for partial RDF
    output_dir : str, optional
        Directory to save output files (default: 'custom_ase'). If empty
        or None, no pickle file is written and plots, when requested, are
        written relative to the current directory.
    create_plots : bool, optional
        Whether to create plots and animations of the RDF data (default: False)

    Returns
    -------
    dict
        Dictionary containing x_data (bin centers) and y_data_all (RDF values for each frame)

    Raises
    ------
    ValueError
        If no images found in trajectory or if atomic_indices is missing for PRDF
    """
    # Fail fast on inconsistent arguments before reading the trajectory.
    if use_prdf and atomic_indices is None:
        raise ValueError("For partial RDF, atomic_indices must be provided.")

    images = read(traj_path, index=f'::{frame_skip}')
    if not isinstance(images, list):
        images = [images]

    if not images:
        raise ValueError("No images found in the trajectory.")

    # Check cell validity for the first image
    check_cell_and_r_max(images[0], rmax)

    analysis = Analysis(images)
    ls_rdf = analysis.get_rdf(
        rmax,
        nbins,
        atomic_indices=atomic_indices if use_prdf else None,
        return_dists=True,
    )

    # Bin centers are taken from the first frame's result.
    x_data_all = ls_rdf[0][1]
    y_data_all = [rdf for rdf, _ in ls_rdf]

    if output_filename:
        # splitext only strips a real extension, so names such as './out'
        # or 'a.b/c' keep their full stem.
        base_name = os.path.splitext(output_filename)[0]
    elif use_prdf:
        ref_str = f"{len(atomic_indices[0])}-atoms"
        target_str = f"{len(atomic_indices[1])}-atoms"
        base_name = f"prdf_{ref_str}_{target_str}"
    else:
        base_name = "rdf_total"

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

        pickle_file = os.path.join(output_dir, f"{base_name}.pkl")
        with open(pickle_file, 'wb') as f:
            pickle.dump({'x_data': x_data_all, 'y_data_all': y_data_all}, f)

        print(f"Data saved in '{pickle_file}'")

    if create_plots:
        if use_prdf:
            title = f"Partial RDF: {len(atomic_indices[0])} reference atoms, {len(atomic_indices[1])} target atoms"
        else:
            title = "Total Radial Distribution Function"

        # os.path.join rejects None, so an unset output_dir falls back to
        # paths relative to the current directory.
        plot_dir = output_dir or ""

        static_plot_file = os.path.join(plot_dir, f"{base_name}_plot.png")
        plot_rdf(x_data_all, y_data_all, title=title, output_file=static_plot_file)
        print(f"Static plot saved in '{static_plot_file}'")

        # An animation only makes sense with more than one frame.
        if len(y_data_all) > 1:
            animation_file = os.path.join(plot_dir, f"{base_name}_animation.gif")
            animate_rdf(x_data_all, y_data_all, output_file=animation_file)
            print(f"Animation saved in '{animation_file}'")

    return {'x_data': x_data_all, 'y_data_all': y_data_all}