crisp-ase 1.0.0.post0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of crisp-ase might be problematic. Click here for more details.

@@ -0,0 +1,403 @@
1
+ """
2
+ CRISP/data_analysis/prdf.py
3
+
4
+ This script performs Radial Distribution Function analysis on molecular dynamics trajectory data.
5
+ """
6
+
7
+
8
+ import os
9
+ import numpy as np
10
+ import pickle
11
+ import math
12
+ from ase.io import read
13
+ from ase import Atoms
14
+ from typing import Optional, Union, Tuple, List
15
+ from joblib import Parallel, delayed
16
+ import matplotlib.pyplot as plt
17
+ from matplotlib.animation import FuncAnimation
18
+
19
def check_cell_and_r_max(atoms: Atoms, rmax: float):
    """
    Validate that the simulation cell can hold a sphere of radius rmax.

    Parameters
    ----------
    atoms : Atoms
        ASE Atoms object whose ``cell`` attribute is inspected
    rmax : float
        Largest pair distance that is going to be sampled

    Raises
    ------
    ValueError
        If the cell has no non-zero entry, if its shortest edge is smaller
        than ``2 * rmax``, or (when the cell object has no ``lengths()``
        method) if its volume is smaller than a sphere of radius ``rmax``
    """
    box = atoms.cell
    if not box.any():
        raise ValueError("RDF Error: The system's cell is not defined.")

    try:
        shortest_edge = np.min(box.lengths())
    except AttributeError:
        # The cell object offers no ``lengths()``: compare volumes instead.
        cell_volume = box.volume
        sphere_volume = (4/3) * math.pi * rmax**3
        if cell_volume < sphere_volume:
            raise ValueError(f"RDF Error: Cell volume {cell_volume} is too small for rmax {rmax} (required >= {sphere_volume}).")
    else:
        if shortest_edge < 2 * rmax:
            raise ValueError(f"RDF Error: Cell length {shortest_edge} is smaller than 2*rmax ({2*rmax}).")
48
+
49
def compute_pairwise_rdf(atoms: Atoms,
                         ref_indices: List[int],
                         target_indices: List[int],
                         rmax: float, nbins: int,
                         volume: Optional[float] = None):
    """
    Compute pairwise radial distribution function between sets of atoms.

    Parameters
    ----------
    atoms : Atoms
        ASE Atoms object
    ref_indices : List[int]
        Indices of reference atoms
    target_indices : List[int]
        Indices of target atoms
    rmax : float
        Maximum radius for RDF calculation (must be positive)
    nbins : int
        Number of bins for histogram (must be positive)
    volume : float, optional
        Custom normalization volume to use instead of cell volume.
        Useful for non-periodic systems or custom normalization.
        (default: None, uses atoms.get_volume())

    Returns
    -------
    Tuple[np.ndarray, np.ndarray]
        RDF values and corresponding bin centers

    Raises
    ------
    ValueError
        If ``nbins`` or ``rmax`` is not positive

    Notes
    -----
    When both index lists describe the same set of atoms, each unordered
    pair is counted exactly once, independent of the order (or repetition)
    of the indices in the two lists.
    """
    if nbins <= 0:
        raise ValueError(f"RDF Error: nbins must be a positive integer, got {nbins}.")
    if rmax <= 0:
        raise ValueError(f"RDF Error: rmax must be positive, got {rmax}.")

    dm = atoms.get_all_distances(mic=True)
    dr = float(rmax / nbins)
    volume = atoms.get_volume() if volume is None else volume

    ref_set = set(ref_indices)
    if ref_set == set(target_indices):
        # Index the distance matrix with ONE canonical ordering on both axes.
        # Using the two caller-supplied lists directly would make the upper
        # triangle pick wrong pairs whenever their orderings differ.
        members = sorted(ref_set)
        sub_dm = dm[np.ix_(members, members)]
        unique_pairs = sub_dm[np.triu_indices(len(members), k=1)]
        distances = unique_pairs[unique_pairs > 0]

        N = len(members)

        # Division by 2 for same-species pairs (each pair is counted once)
        norm = (4 * math.pi * dr * (N / volume) * N) / 2
    else:
        sub_dm = dm[np.ix_(ref_indices, target_indices)]
        # Zero distances (an atom paired with itself) are discarded
        distances = sub_dm[sub_dm > 0]

        N_A = len(ref_indices)
        N_B = len(target_indices)

        norm = 4 * math.pi * dr * (N_A / volume) * N_B

    hist, bin_edges = np.histogram(distances, bins=nbins, range=(0, rmax))
    bin_centers = 0.5 * (bin_edges[1:] + bin_edges[:-1])
    # Shell volume grows with r**2, so divide it out per bin
    rdf = hist / (norm * (bin_centers**2))

    return rdf, bin_centers
108
+
109
class Analysis:
    """
    Class for analyzing atomic trajectories and calculating RDFs.

    Parameters
    ----------
    images : List[Atoms]
        List of ASE Atoms objects representing trajectory frames
    """

    def __init__(self, images: List[Atoms]):
        self.images = images

    def _get_slice(self, imageIdx: Optional[Union[int, slice]]):
        """
        Convert image index to slice for selecting trajectory frames.

        Parameters
        ----------
        imageIdx : Optional[Union[int, slice]]
            Index or slice to select images

        Returns
        -------
        slice
            Slice object for image selection. ``None`` selects every frame;
            an integer ``i`` becomes a one-frame slice so that indexing
            ``self.images`` always yields a list of frames.

        Raises
        ------
        IndexError
            If an integer index is outside the range of stored frames
        """
        if imageIdx is None:
            return slice(None)
        if isinstance(imageIdx, (int, np.integer)):
            n_images = len(self.images)
            if not -n_images <= imageIdx < n_images:
                raise IndexError(
                    f"Image index {imageIdx} is out of range for {n_images} images."
                )
            start = int(imageIdx)
            stop = start + 1
            # For start == -1 the stop would be 0 (an empty slice); use an
            # open end instead so the last frame is selected.
            return slice(start, stop if stop != 0 else None)
        return imageIdx

    def get_rdf(self,
                rmax: float,
                nbins: int = 100,
                imageIdx: Optional[Union[int, slice]] = None,
                atomic_indices: Optional[Tuple[List[int], List[int]]] = None,
                return_dists: bool = False,
                n_jobs: int = -1):
        """
        Calculate radial distribution function for trajectory frames.

        Parameters
        ----------
        rmax : float
            Maximum radius for RDF calculation
        nbins : int, optional
            Number of bins for histogram (default: 100)
        imageIdx : Optional[Union[int, slice]], optional
            Index or slice to select images (default: None, all images)
        atomic_indices : Optional[Tuple[List[int], List[int]]], optional
            Tuple of (reference_indices, target_indices) for partial RDF.
            When None, every atom of each frame is used for both sets.
        return_dists : bool, optional
            Whether to return bin center distances (default: False)
        n_jobs : int, optional
            Number of joblib workers (default: -1, as previously hard-coded)

        Returns
        -------
        List[Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]]
            List of RDF values or tuples (RDF, bin_centers) for each frame
        """
        frames = self.images[self._get_slice(imageIdx)]

        if atomic_indices is None:
            ref_indices = target_indices = None
        else:
            ref_indices, target_indices = atomic_indices

        def process_image(image: Atoms):
            check_cell_and_r_max(image, rmax)
            if ref_indices is None:
                # Total RDF: the frame's full atom list plays both roles.
                refs = targets = list(range(len(image)))
            else:
                refs, targets = ref_indices, target_indices
            rdf, bin_centers = compute_pairwise_rdf(image, refs, targets, rmax, nbins)
            return (rdf, bin_centers) if return_dists else rdf

        return Parallel(n_jobs=n_jobs)(delayed(process_image)(image) for image in frames)
191
+
192
def plot_rdf(x_data_all, y_data_all, title=None, output_file=None):
    """
    Plot the average RDF with peak location marked by vertical line.

    Parameters
    ----------
    x_data_all : np.ndarray
        Distance values in Ångström
    y_data_all : List[np.ndarray]
        RDF values for each frame
    title : str, optional
        Custom title for the plot
    output_file : str, optional
        Path to save the plot (if None, just display)

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``y_data_all`` contains no frames
    """
    # len() rather than truthiness: y_data_all may be a NumPy array
    if len(y_data_all) == 0:
        raise ValueError("plot_rdf requires at least one frame of RDF data.")

    # Average RDF across all frames
    y_data_avg = np.mean(y_data_all, axis=0)

    # Index of the maximum y value in the average RDF
    max_y_index = np.argmax(y_data_avg)
    max_y_x = x_data_all[max_y_index]
    max_y = y_data_avg[max_y_index]

    fig, ax = plt.subplots(figsize=(10, 6))

    ax.plot(x_data_all, y_data_avg, linewidth=2, label='Average RDF')
    ax.axvline(x=max_y_x, color='red', linestyle='--', label=f'Peak at {max_y_x:.2f} Å')

    ax.set_xlabel('Distance (Å)', fontsize=12)
    ax.set_ylabel('g(r)', fontsize=12)
    ax.set_title(title or 'Average Radial Distribution Function', fontsize=14)

    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=10)

    # Matplotlib rejects NaN/Inf axis limits and warns on identical ones, so
    # only pin the upper limit when the peak is a usable positive number.
    if np.isfinite(max_y) and max_y > 0:
        ax.set_ylim(bottom=0, top=max_y * 1.2)
    else:
        ax.set_ylim(bottom=0)

    fig.tight_layout()

    if output_file:
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        plt.show()  # Display the plot in addition to saving it
        plt.close(fig)
    else:
        plt.show()
242
+
243
+
244
def animate_rdf(x_data_all, y_data_all, output_file=None):
    """
    Animate the per-frame RDF together with the trajectory average.

    Parameters
    ----------
    x_data_all : np.ndarray
        Distance values in Ångström
    y_data_all : List[np.ndarray]
        RDF values for each frame (one animation frame per entry)
    output_file : str, optional
        Path of the GIF to write. An interactive HTML version is written
        next to it with the extension replaced by ``.html``. If None,
        nothing is written and the figure is left open.

    Returns
    -------
    FuncAnimation
        The animation object (keep a reference to it to keep it alive)

    Raises
    ------
    ValueError
        If ``y_data_all`` contains no frames
    """
    # len() rather than truthiness: y_data_all may be a NumPy array
    if len(y_data_all) == 0:
        raise ValueError("animate_rdf requires at least one frame of RDF data.")

    fig, ax = plt.subplots(figsize=(10, 6))

    y_data_avg = np.mean(y_data_all, axis=0)

    max_y = max([np.max(y) for y in y_data_all] + [np.max(y_data_avg)]) * 1.1
    # Matplotlib rejects NaN/Inf limits; fall back to autoscaling the top.
    y_top = max_y if np.isfinite(max_y) and max_y > 0 else None

    def update(frame):
        ax.clear()  # Clears the previous frame (including axis labels)
        y = y_data_all[frame]

        ax.plot(x_data_all, y, linewidth=2, label='Current Frame')

        ax.plot(x_data_all, y_data_avg, linewidth=2, linestyle='--',
                color='purple', label='Average RDF')

        max_y_index = np.argmax(y)
        max_y_x = x_data_all[max_y_index]

        ax.axvline(x=max_y_x, color='red', linestyle='--', label=f'Peak at {max_y_x:.2f} Å')

        # Labels must be re-applied every frame because ax.clear() drops them.
        ax.set_xlabel('Distance (Å)', fontsize=12)
        ax.set_ylabel('g(r)', fontsize=12)

        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.set_ylim(0, y_top)
        ax.set_title(f'Radial Distribution Function - Frame {frame}', fontsize=14)
        return ax,

    # Draw the first frame once so the layout is computed with labels and
    # title present, and before anything is written to disk.
    update(0)
    fig.tight_layout()

    ani = FuncAnimation(fig, update, frames=range(len(y_data_all)),
                        interval=200, blit=False)

    if output_file:
        html_file = os.path.splitext(output_file)[0] + ".html"
        html_code = ani.to_jshtml()
        with open(html_file, 'w', encoding='utf-8') as f:
            f.write(html_code)
        print(f"Interactive animation saved to '{html_file}'")

        # GIF export is best-effort: it needs the optional 'pillow' writer.
        try:
            ani.save(output_file, writer='pillow', fps=5)
            print(f"GIF animation saved to '{output_file}'")
        except Exception as e:
            print(f"Warning: Could not save GIF animation: {e}")

        plt.close(fig)

    return ani
296
+
297
def analyze_rdf(use_prdf: bool,
                rmax: float,
                traj_path: str,
                nbins: int = 100,
                frame_skip: int = 10,
                output_filename: Optional[str] = None,
                atomic_indices: Optional[Tuple[List[int], List[int]]] = None,
                output_dir: str = 'custom_ase',
                create_plots: bool = False):
    """
    Analyze trajectory and calculate radial distribution functions.

    Parameters
    ----------
    use_prdf : bool
        Whether to calculate partial RDF (True) or total RDF (False)
    rmax : float
        Maximum radius for RDF calculation
    traj_path : str
        Path to trajectory file
    nbins : int, optional
        Number of bins for histogram (default: 100)
    frame_skip : int, optional
        Step used when reading frames, i.e. every ``frame_skip``-th frame
        is analyzed (default: 10)
    output_filename : Optional[str], optional
        Custom filename for output; its extension is ignored
        (default: None, auto-generated)
    atomic_indices : Optional[Tuple[List[int], List[int]]], optional
        Tuple of (reference_indices, target_indices) for partial RDF
    output_dir : str, optional
        Directory to save output files (default: 'custom_ase'). If empty or
        None, no pickle file is written and any requested plots go to the
        current working directory.
    create_plots : bool, optional
        Whether to create plots and animations of the RDF data (default: False)

    Returns
    -------
    dict
        Dictionary containing x_data (bin centers) and y_data_all (RDF values for each frame)

    Raises
    ------
    ValueError
        If frame_skip is zero, if no images are found in the trajectory, or
        if atomic_indices is missing for PRDF
    """
    if frame_skip == 0:
        raise ValueError("frame_skip must be non-zero.")
    if use_prdf and atomic_indices is None:
        raise ValueError("For partial RDF, atomic_indices must be provided.")

    images = read(traj_path, index=f'::{frame_skip}')
    if not isinstance(images, list):
        images = [images]

    if not images:
        raise ValueError("No images found in the trajectory.")

    # Check cell validity for the first image to fail before the parallel run
    check_cell_and_r_max(images[0], rmax)

    analysis = Analysis(images)
    ls_rdf = analysis.get_rdf(
        rmax,
        nbins,
        atomic_indices=atomic_indices if use_prdf else None,
        return_dists=True,
    )

    # Bin centers are taken from the first frame's result
    x_data_all = ls_rdf[0][1]
    y_data_all = [rdf for rdf, _ in ls_rdf]

    if output_filename:
        # splitext only strips a real extension of the last path component,
        # so names such as "./out/rdf" keep their full stem.
        base_name = os.path.splitext(output_filename)[0]
    elif use_prdf:
        base_name = f"prdf_{len(atomic_indices[0])}-atoms_{len(atomic_indices[1])}-atoms"
    else:
        base_name = "rdf_total"

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

        pickle_file = os.path.join(output_dir, f"{base_name}.pkl")
        with open(pickle_file, 'wb') as f:
            pickle.dump({'x_data': x_data_all, 'y_data_all': y_data_all}, f)

        print(f"Data saved in '{pickle_file}'")

    if create_plots:
        if use_prdf:
            title = f"Partial RDF: {len(atomic_indices[0])} reference atoms, {len(atomic_indices[1])} target atoms"
        else:
            title = "Total Radial Distribution Function"

        # An empty/None output_dir must not reach os.path.join as None.
        plot_dir = output_dir or ''

        static_plot_file = os.path.join(plot_dir, f"{base_name}_plot.png")
        plot_rdf(x_data_all, y_data_all, title=title, output_file=static_plot_file)
        print(f"Static plot saved in '{static_plot_file}'")

        # An animation only makes sense with more than one frame
        if len(y_data_all) > 1:
            animation_file = os.path.join(plot_dir, f"{base_name}_animation.gif")
            animate_rdf(x_data_all, y_data_all, output_file=animation_file)
            print(f"Animation saved in '{animation_file}'")

    return {'x_data': x_data_all, 'y_data_all': y_data_all}