crisp-ase 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. CRISP/__init__.py +99 -0
  2. CRISP/_version.py +1 -0
  3. CRISP/cli.py +41 -0
  4. CRISP/data_analysis/__init__.py +38 -0
  5. CRISP/data_analysis/clustering.py +838 -0
  6. CRISP/data_analysis/contact_coordination.py +915 -0
  7. CRISP/data_analysis/h_bond.py +772 -0
  8. CRISP/data_analysis/msd.py +1199 -0
  9. CRISP/data_analysis/prdf.py +404 -0
  10. CRISP/data_analysis/volumetric_atomic_density.py +527 -0
  11. CRISP/py.typed +1 -0
  12. CRISP/simulation_utility/__init__.py +31 -0
  13. CRISP/simulation_utility/atomic_indices.py +155 -0
  14. CRISP/simulation_utility/atomic_traj_linemap.py +278 -0
  15. CRISP/simulation_utility/error_analysis.py +254 -0
  16. CRISP/simulation_utility/interatomic_distances.py +200 -0
  17. CRISP/simulation_utility/subsampling.py +241 -0
  18. CRISP/tests/DataAnalysis/__init__.py +1 -0
  19. CRISP/tests/DataAnalysis/test_clustering_extended.py +212 -0
  20. CRISP/tests/DataAnalysis/test_contact_coordination.py +184 -0
  21. CRISP/tests/DataAnalysis/test_contact_coordination_extended.py +465 -0
  22. CRISP/tests/DataAnalysis/test_h_bond_complete.py +326 -0
  23. CRISP/tests/DataAnalysis/test_h_bond_extended.py +322 -0
  24. CRISP/tests/DataAnalysis/test_msd_complete.py +305 -0
  25. CRISP/tests/DataAnalysis/test_msd_extended.py +522 -0
  26. CRISP/tests/DataAnalysis/test_prdf.py +206 -0
  27. CRISP/tests/DataAnalysis/test_volumetric_atomic_density.py +463 -0
  28. CRISP/tests/SimulationUtility/__init__.py +1 -0
  29. CRISP/tests/SimulationUtility/test_atomic_traj_linemap.py +101 -0
  30. CRISP/tests/SimulationUtility/test_atomic_traj_linemap_extended.py +469 -0
  31. CRISP/tests/SimulationUtility/test_error_analysis_extended.py +151 -0
  32. CRISP/tests/SimulationUtility/test_interatomic_distances.py +223 -0
  33. CRISP/tests/SimulationUtility/test_subsampling.py +365 -0
  34. CRISP/tests/__init__.py +1 -0
  35. CRISP/tests/test_CRISP.py +28 -0
  36. CRISP/tests/test_cli.py +87 -0
  37. CRISP/tests/test_crisp_comprehensive.py +679 -0
  38. crisp_ase-1.1.2.dist-info/METADATA +141 -0
  39. crisp_ase-1.1.2.dist-info/RECORD +42 -0
  40. crisp_ase-1.1.2.dist-info/WHEEL +5 -0
  41. crisp_ase-1.1.2.dist-info/entry_points.txt +2 -0
  42. crisp_ase-1.1.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1199 @@
1
+ """
2
+ CRISP/data_analysis/msd.py
3
+
4
+ This module performs mean square displacement (MSD) analysis on molecular dynamics
5
+ trajectory data for diffusion coefficient calculations.
6
+ """
7
+
8
+ import ase.io
9
+ import numpy as np
10
+ import matplotlib.pyplot as plt
11
+ import csv
12
+ import sys
13
+ from ase.units import fs
14
+ from ase.units import fs as fs_conversion
15
+ from ase.data import chemical_symbols
16
+ from scipy.optimize import curve_fit
17
+ import pandas as pd
18
+ import os
19
+ import traceback
20
+ from typing import List, Tuple, Union, Optional, Dict, Any
21
+ from joblib import Parallel, delayed, cpu_count
22
+
23
+ __all__ = ['read_trajectory_chunk', 'calculate_frame_msd', 'calculate_msd', 'calculate_save_msd', 'msd_analysis']
24
+
25
+
26
def read_trajectory_chunk(
    traj_path: str,
    index_slice: str,
    frame_skip: int = 1
) -> List:
    """Load one slice of a trajectory and thin it by ``frame_skip``.

    Parameters
    ----------
    traj_path : str
        Trajectory file to read; any format ``ase.io.read`` accepts.
    index_slice : str
        ASE index expression selecting which frames to read.
    frame_skip : int, optional
        Keep every ``frame_skip``-th frame of the loaded chunk (default: 1,
        keep all frames).

    Returns
    -------
    list
        ASE Atoms objects for the chunk. If reading or slicing raises, the
        error is printed and an empty list is returned instead (best effort).
    """
    try:
        loaded = ase.io.read(traj_path, index=index_slice)
        # ase.io.read may hand back a single Atoms rather than a list; normalise.
        chunk = loaded if isinstance(loaded, list) else [loaded]
        return chunk[::frame_skip]
    except Exception as exc:
        print(f"Error reading trajectory chunk {index_slice}: {exc}")
        return []
57
+
58
+
59
def calculate_frame_msd(
    frame_idx: int,
    current_frame,
    reference_frame,
    atom_indices: List[int],
    msd_direction: bool = False
) -> Union[Tuple[int, float], Tuple[int, float, float, float]]:
    """
    Mean square displacement of selected atoms between two frames.

    Parameters
    ----------
    frame_idx : int
        Identifier of the current frame; echoed back as the first element of
        the result so parallel results can be re-ordered.
    current_frame : ase.Atoms
        Frame whose ``positions`` are compared with the reference.
    reference_frame : ase.Atoms
        Frame providing the reference ``positions``.
    atom_indices : list
        Indices of the atoms that contribute to the average.
    msd_direction : bool, optional
        If true, return one MSD per Cartesian axis instead of the total
        (default: False).

    Returns
    -------
    tuple
        ``(frame_idx, msd)`` when ``msd_direction`` is false, otherwise
        ``(frame_idx, msd_x, msd_y, msd_z)``. Each value is the summed squared
        displacement divided by the number of selected atoms.
    """
    n_selected = len(atom_indices)
    delta = current_frame.positions[atom_indices] - reference_frame.positions[atom_indices]

    if msd_direction:
        # Separate average for each Cartesian component (x, y, z).
        per_axis = tuple(np.sum(delta[:, axis] ** 2) / n_selected for axis in range(3))
        return (frame_idx,) + per_axis

    # Squared displacements summed over atoms and axes, averaged per atom.
    return frame_idx, np.sum(np.square(delta)) / n_selected
103
+
104
def _frame_msd_series(traj, reference_frame, indices, directional, ignore_n_images, n_jobs):
    """Run ``calculate_frame_msd`` for every analysed frame in parallel, sorted by frame index."""
    results = Parallel(n_jobs=n_jobs)(
        delayed(calculate_frame_msd)(
            i - ignore_n_images,
            traj[i],
            reference_frame,
            indices,
            directional,
        )
        for i in range(ignore_n_images, len(traj))
    )
    results.sort(key=lambda r: r[0])
    return results


def calculate_msd(traj, timestep, atom_indices=None, ignore_n_images=0, n_jobs=-1,
                  msd_direction=False, msd_direction_atom=None):
    """
    Calculate Mean Square Displacement (MSD) vs time using parallel processing.

    Every analysed frame is compared against a single reference frame: the
    first frame after the ``ignore_n_images`` skipped ones.

    NOTE(review): positions are used exactly as stored in the frames. If the
    trajectory wraps atoms back into the periodic cell, the MSD will saturate;
    confirm that trajectories are unwrapped.

    Parameters
    ----------
    traj : list of ase.Atoms
        Trajectory data.
    timestep : float
        Time between consecutive frames of ``traj`` in ASE time units. It is
        divided by ``ase.units.fs`` to obtain femtoseconds.
    atom_indices : numpy.ndarray, optional
        Indices of atoms to analyse (default: None, analyse every species).
    ignore_n_images : int, optional
        Number of initial images to ignore (default: 0).
    n_jobs : int, optional
        Number of parallel jobs to run (default: -1, use all available cores).
    msd_direction : bool, optional
        Whether to calculate directional MSD (default: False). If True and
        ``atom_indices`` is given, directional MSD is computed for those atoms.
    msd_direction_atom : str or int, optional
        Atom symbol or atomic number selecting the species for directional MSD
        when ``atom_indices`` is None (default: None).

    Returns
    -------
    tuple or dict
        The ``msd_times`` array (fs) has one entry per analysed frame
        (``len(traj) - ignore_n_images``), the same length as every MSD array.

        - ``atom_indices`` given, ``msd_direction`` False:
          ``(msd_values, msd_times)``.
        - ``atom_indices`` given, ``msd_direction`` True:
          ``(msd_times, msd_x, msd_y, msd_z)`` (note the different order).
        - ``atom_indices`` None: a dict with ``'overall'`` and one key per
          chemical symbol, each mapping to ``(values, msd_times)``. For the
          directional species, ``'<sym>'`` holds x+y+z and ``'<sym>_x'``,
          ``'<sym>_y'``, ``'<sym>_z'`` hold the components.
    """
    n_frames = len(traj) - ignore_n_images
    # Analysed frame k sits at k * timestep; exactly one time value per MSD value.
    msd_times = np.arange(n_frames) * timestep / fs_conversion  # convert to fs

    reference_frame = traj[ignore_n_images]

    if n_jobs == -1:
        n_jobs = cpu_count()

    # Informational only: report how many atoms the directional species selects.
    if msd_direction and msd_direction_atom is not None:
        first = traj[0]
        if isinstance(msd_direction_atom, str):
            count = sum(1 for s in first.get_chemical_symbols() if s == msd_direction_atom)
            print(f"Calculating directional MSD for {count} {msd_direction_atom} atoms")
        elif isinstance(msd_direction_atom, int):
            count = sum(1 for z in first.get_atomic_numbers() if z == msd_direction_atom)
            symbol = chemical_symbols[msd_direction_atom]
            print(f"Calculating directional MSD for {count} {symbol} atoms (Z={msd_direction_atom})")

    # Explicit atom selection
    if atom_indices is not None:
        results = _frame_msd_series(traj, reference_frame, atom_indices, msd_direction,
                                    ignore_n_images, n_jobs)
        if msd_direction:
            print(f"Calculating directional MSD for {len(atom_indices)} specified atoms")
            msd_x = np.array([r[1] for r in results])
            msd_y = np.array([r[2] for r in results])
            msd_z = np.array([r[3] for r in results])
            return msd_times, msd_x, msd_y, msd_z
        msd_values = np.array([r[1] for r in results])
        return msd_values, msd_times

    # MSD per atom type
    atoms = traj[0]
    symbols = atoms.get_chemical_symbols()
    atomic_numbers = atoms.get_atomic_numbers()

    # Group indices by symbol, keeping first-appearance order so output order is deterministic.
    symbol_indices = {}
    for i, s in enumerate(symbols):
        symbol_indices.setdefault(s, []).append(i)

    overall_results = _frame_msd_series(traj, reference_frame, list(range(len(atoms))), False,
                                        ignore_n_images, n_jobs)
    result = {'overall': (np.array([r[1] for r in overall_results]), msd_times)}

    for symbol, indices in symbol_indices.items():
        print(f"Calculating MSD for {symbol} atoms...")
        calc_direction = bool(msd_direction) and (
            (isinstance(msd_direction_atom, str) and symbol == msd_direction_atom)
            or (isinstance(msd_direction_atom, int)
                and atomic_numbers[indices[0]] == msd_direction_atom)
        )

        symbol_results = _frame_msd_series(traj, reference_frame, indices, calc_direction,
                                           ignore_n_images, n_jobs)

        if calc_direction:
            msd_x = np.array([r[1] for r in symbol_results])
            msd_y = np.array([r[2] for r in symbol_results])
            msd_z = np.array([r[3] for r in symbol_results])

            # Store the total alongside the components, all as (values, times) tuples.
            result[symbol] = (msd_x + msd_y + msd_z, msd_times)
            result[f'{symbol}_x'] = (msd_x, msd_times)
            result[f'{symbol}_y'] = (msd_y, msd_times)
            result[f'{symbol}_z'] = (msd_z, msd_times)

            print(f"Saved directional MSD data for {symbol} atoms")
        else:
            result[symbol] = (np.array([r[1] for r in symbol_results]), msd_times)

    return result
263
+
264
def _write_msd_csv(path, times, values):
    """Write a two-column ``Time (fs), MSD`` CSV file at ``path`` (one row per time point)."""
    with open(path, 'w', newline='') as handle:
        writer = csv.writer(handle)
        writer.writerow(['Time (fs)', 'MSD'])
        writer.writerows(zip(times, values))


def save_msd_data(msd_data, csv_file_path, output_dir="traj_csv_detailed"):
    """
    Write MSD results to one or more CSV files inside ``output_dir``.

    Only the basename of ``csv_file_path`` is used; ``output_dir`` is created
    if needed. The layout depends on the shape of ``msd_data``:

    - ``(values, times)``: a single file named after ``csv_file_path``.
    - ``(times, msd_x, msd_y, msd_z)``: the x+y+z total in the base file, plus
      ``<stem>_x``, ``<stem>_y`` and ``<stem>_z`` files for the components.
    - ``dict`` of ``(values, times)``: ``<stem>_overall`` for the ``'overall'``
      entry (if present), then ``<stem>_<key>`` for every other key.

    Any other input writes nothing.

    Parameters
    ----------
    msd_data : tuple or dict
        MSD data to be saved.
    csv_file_path : str
        Path whose basename determines the output file names.
    output_dir : str, optional
        Directory to save CSV files (default: "traj_csv_detailed").

    Returns
    -------
    list
        Paths of the files written, in the order they were written.
    """
    os.makedirs(output_dir, exist_ok=True)

    base_filename = os.path.basename(csv_file_path)
    stem, ext = os.path.splitext(base_filename)
    written = []

    if isinstance(msd_data, dict):
        if 'overall' in msd_data:
            overall_path = os.path.join(output_dir, f"{stem}_overall{ext}")
            values, times = msd_data['overall']
            _write_msd_csv(overall_path, times, values)
            print(f"Overall MSD data has been saved to {overall_path}")
            written.append(overall_path)

        for symbol, series in msd_data.items():
            if symbol == 'overall':
                continue
            symbol_path = os.path.join(output_dir, f"{stem}_{symbol}{ext}")
            values, times = series
            _write_msd_csv(symbol_path, times, values)
            print(f"MSD data for {symbol} atoms has been saved to {symbol_path}")
            written.append(symbol_path)

    elif isinstance(msd_data, tuple):
        if len(msd_data) == 2:
            values, times = msd_data
            target = os.path.join(output_dir, base_filename)
            _write_msd_csv(target, times, values)
            print(f"MSD data has been saved to {target}")
            written.append(target)

        elif len(msd_data) == 4:
            times, msd_x, msd_y, msd_z = msd_data

            # The base file receives the summed (x+y+z) MSD.
            total_values = msd_x + msd_y + msd_z
            total_path = os.path.join(output_dir, base_filename)
            _write_msd_csv(total_path, times, total_values)
            print(f"Total MSD data has been saved to {total_path}")
            written.append(total_path)

            for axis, component in (("x", msd_x), ("y", msd_y), ("z", msd_z)):
                axis_path = os.path.join(output_dir, f"{stem}_{axis}{ext}")
                _write_msd_csv(axis_path, times, component)
                print(f"{axis.upper()}-direction MSD data has been saved to {axis_path}")
                written.append(axis_path)

    return written
382
+
383
def calculate_diffusion_coefficient(msd_times, msd_values, start_index=None, end_index=None,
                                    with_intercept=False, plot_msd=False, dimension=3):
    """
    Calculate a diffusion coefficient from MSD data in 1D, 2D or 3D.

    A straight line is fitted to ``msd_values`` vs ``msd_times`` over
    ``[start_index, end_index)`` and converted with D = slope / (2 * dimension).
    The fit's R² is printed.

    Parameters
    ----------
    msd_times : array-like
        Time values in femtoseconds.
    msd_values : array-like
        Mean square displacement values (assumed Ų, which the unit
        conversion below relies on).
    start_index : int, optional
        Starting index for the fit (default: one third of the data length).
    end_index : int, optional
        Ending index (exclusive) for the fit (default: end of the data).
    with_intercept : bool, optional
        Whether to fit with an intercept (default: False).
    plot_msd : bool, optional
        Whether to plot the data and fit; the plot is saved to
        ``msd_analysis.png`` in the working directory (default: False).
    dimension : int, optional
        Dimensionality of the system: 1, 2 or 3 (default: 3).

    Returns
    -------
    tuple
        ``(D, error)`` in cm²/s. ``error`` is the one-sigma uncertainty of the
        fitted slope from ``curve_fit``, scaled in the same way as D.

    Raises
    ------
    ValueError
        If the indices are out of bounds or ``start_index >= end_index``.
    """
    # Accept lists / pandas Series as well as arrays, so slicing and arithmetic behave uniformly.
    msd_times = np.asarray(msd_times, dtype=float)
    msd_values = np.asarray(msd_values, dtype=float)

    if start_index is None:
        start_index = len(msd_times) // 3
    if end_index is None:
        end_index = len(msd_times)
    if start_index < 0 or end_index > len(msd_times):
        raise ValueError("Indices are out of bounds.")
    if start_index >= end_index:
        raise ValueError("Start index must be less than end index.")

    x_fit = msd_times[start_index:end_index]
    y_fit = msd_values[start_index:end_index]

    def linear_no_intercept(x, m):
        return m * x

    def linear_with_intercept(x, m, c):
        return m * x + c

    if with_intercept:
        params, covariance = curve_fit(linear_with_intercept, x_fit, y_fit)
        slope, intercept = params
        fit_func = lambda x: linear_with_intercept(x, slope, intercept)
    else:
        params, covariance = curve_fit(linear_no_intercept, x_fit, y_fit)
        slope = params[0]
        fit_func = lambda x: linear_no_intercept(x, slope)

    # The slope is always the first fitted parameter.
    std_err = np.sqrt(np.diag(covariance))[0]

    # (Ų/fs) * (1e-16 cm²/Ų) / (1e-15 s/fs) = 1e-1 cm²/s
    conversion_angstrom2_fs_to_cm2_s = 0.1
    D = slope / (2 * dimension) * conversion_angstrom2_fs_to_cm2_s
    error = std_err / (2 * dimension) * conversion_angstrom2_fs_to_cm2_s

    # Goodness of fit R². It is undefined when the fitted MSD window is constant.
    y_model = fit_func(x_fit)
    ss_res = np.sum((y_fit - y_model) ** 2)
    ss_tot = np.sum((y_fit - np.mean(y_fit)) ** 2)
    r2 = 1 - ss_res / ss_tot if ss_tot > 0 else float('nan')

    if plot_msd:
        fig = plt.figure(figsize=(10, 6))
        # Convert time from fs to ps for plotting
        plt.scatter(msd_times / 1000, msd_values, s=10, alpha=0.5, label='MSD data')
        plt.plot(x_fit / 1000, fit_func(x_fit), 'r-', linewidth=2,
                 label=f'D = {D:.2e} cm²/s')
        plt.xlabel('Time (ps)')
        plt.ylabel('MSD (Ų)')
        plt.title('Mean Square Displacement vs Time')
        plt.grid(True, alpha=0.3)
        plt.legend()
        plt.tight_layout()
        plt.savefig('msd_analysis.png', dpi=300, bbox_inches='tight')
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    print(f"R² = {r2:.4f}")

    return D, error
473
+
474
def plot_diffusion_time_series(msd_times, msd_values, min_window=10, with_intercept=False, csv_file=None, dimension=3):
    """
    Plot the diffusion coefficient obtained from growing fit windows.

    For every window ``[0, end)``, with ``end`` running from ``min_window + 1``
    up to and including the full series length, a straight line is fitted to
    MSD vs time and converted to D = slope / (2 * dimension) in Ų/ps. The
    resulting D curve is plotted against the window end time. When more than
    one window succeeds, the mean D and a ±1 standard-deviation band are drawn
    as well. The figure is saved as a PNG and shown.

    Parameters
    ----------
    msd_times : array-like
        Time values in femtoseconds.
    msd_values : array-like
        Mean square displacement values in Ų.
    min_window : int, optional
        Windows start at ``min_window + 1`` points (default: 10).
    with_intercept : bool, optional
        Whether to fit with an intercept (default: False).
    csv_file : str, optional
        If given, the plot is saved next to this file as
        ``<stem>_<dimension>D_diffusion_evolution.png``; otherwise it is saved
        as ``diffusion_time_series.png`` (default: None).
    dimension : int, optional
        Dimensionality of the system: 1, 2 or 3 (default: 3).

    Returns
    -------
    None
    """
    msd_times = np.asarray(msd_times, dtype=float)
    msd_values = np.asarray(msd_values, dtype=float)

    def linear_no_intercept(x, m):
        return m * x

    def linear_with_intercept(x, m, c):
        return m * x + c

    model = linear_with_intercept if with_intercept else linear_no_intercept

    # Slope is in Ų/fs; multiplying by 1000 fs/ps gives Ų/ps.
    conversion_fs_to_ps = 1000.0

    diffusion_coeffs = []
    window_ends = []

    # The upper bound is len + 1 so that the last window covers the whole series.
    for end_idx in range(min_window + 1, len(msd_times) + 1):
        x_fit = msd_times[:end_idx]
        y_fit = msd_values[:end_idx]

        try:
            params, _ = curve_fit(model, x_fit, y_fit)
        except Exception:
            # Best effort: skip windows where curve_fit raises, without swallowing
            # KeyboardInterrupt/SystemExit the way a bare except would.
            continue

        diffusion_coeffs.append(params[0] / (2 * dimension) * conversion_fs_to_ps)
        window_ends.append(msd_times[end_idx - 1])

    fig = plt.figure(figsize=(10, 6))

    window_ends_ps = np.array(window_ends) / 1000.0

    # Plot diffusion coefficient vs. time
    plt.plot(window_ends_ps, diffusion_coeffs, 'b-', linewidth=2, label='Diffusion Coefficient')

    if len(diffusion_coeffs) > 1:
        avg_diffusion = np.mean(diffusion_coeffs)
        std_diffusion = np.std(diffusion_coeffs, ddof=1)
        plt.axhline(y=avg_diffusion, color='r', linestyle='--',
                    label=f'Average D = {avg_diffusion:.2e} ± {std_diffusion:.2e} Ų/ps')

        plt.axhspan(avg_diffusion - std_diffusion, avg_diffusion + std_diffusion,
                    color='r', alpha=0.2)

    plt.xlabel('Time Window End (ps)')
    plt.ylabel('Diffusion Coefficient (Ų/ps)')
    plt.title(f'{dimension}D Diffusion Coefficient Evolution Over Time')
    plt.grid(True, alpha=0.3)
    plt.legend()

    output_file = 'diffusion_time_series.png'
    if csv_file:
        dir_name = os.path.dirname(csv_file)
        base_name = os.path.splitext(os.path.basename(csv_file))[0]
        output_file = os.path.join(dir_name, f"{base_name}_{dimension}D_diffusion_evolution.png")

    plt.tight_layout()
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

    print(f"Diffusion coefficient evolution plot saved to: {output_file}")
563
+
564
def calculate_save_msd(traj_path, timestep_fs, indices_path=None,
                       ignore_n_images=0, output_file="msd_results.csv",
                       frame_skip=1, n_jobs=-1, output_dir="traj_csv_detailed",
                       msd_direction=False, msd_direction_atom=None,
                       use_windowed=True, lag_times_fs=None):
    """
    Calculate MSD data from a trajectory file and save it to CSV.

    The trajectory is read with every ``frame_skip``-th frame. The centre of
    mass of each frame is subtracted to remove drift. MSD is then computed with
    either the windowed (all time origins) or the single-reference approach,
    and the result is written via ``save_msd_data``.

    Parameters
    ----------
    traj_path : str
        Path to the ASE trajectory file.
    timestep_fs : float
        Time between consecutive frames stored in the trajectory file, in fs.
        Because only every ``frame_skip``-th frame is analysed, the spacing
        used for the MSD time axis is ``timestep_fs * frame_skip``.
    indices_path : str, optional
        Path to a ``.npy`` file containing atom indices (default: None).
    ignore_n_images : int, optional
        Number of initial frames to ignore, counted after ``frame_skip``
        subsampling (default: 0).
    output_file : str, optional
        Output CSV file name; only its basename is used (default: "msd_results.csv").
    frame_skip : int, optional
        Read every ``frame_skip``-th frame (default: 1).
    n_jobs : int, optional
        Number of parallel jobs to run (default: -1, use all available cores).
    output_dir : str, optional
        Directory to save CSV files (default: "traj_csv_detailed").
    msd_direction : bool, optional
        Whether to calculate directional MSD (default: False).
    msd_direction_atom : str or int, optional
        Atom symbol or atomic number for directional analysis. It is forwarded
        only when no ``indices_path`` is given (default: None).
    use_windowed : bool, optional
        Whether to use the windowed approach for more robust statistics (default: True).
    lag_times_fs : list of float, optional
        Lag times (in fs) for which to compute MSD (default: None, use all possible lags).

    Returns
    -------
    tuple or dict
        The MSD data as returned by the underlying MSD routine, or
        ``(None, None)`` if the trajectory or index file cannot be loaded.

    Raises
    ------
    FileNotFoundError
        If ``traj_path`` or a given ``indices_path`` does not exist.
    """
    if not os.path.exists(traj_path):
        raise FileNotFoundError(f"Trajectory file not found: {traj_path}")

    if indices_path is not None and not os.path.exists(indices_path):
        raise FileNotFoundError(f"Indices file not found: {indices_path}")

    try:
        traj = ase.io.read(traj_path, index=f'::{frame_skip}')
        if not isinstance(traj, list):
            traj = [traj]

        print(f"Loaded {len(traj)} frames after applying frame_skip={frame_skip}")

        # Subtract each frame's centre of mass to remove overall drift.
        traj_nodrift = []
        for frame in traj:
            shifted = frame.copy()
            shifted.set_positions(shifted.get_positions() - shifted.get_center_of_mass())
            traj_nodrift.append(shifted)
        traj = traj_nodrift

    except Exception as e:
        print(f"Error loading trajectory file: {e}")
        return None, None

    atom_indices = None
    if indices_path:
        try:
            atom_indices = np.load(indices_path)
            print(f"Loaded {len(atom_indices)} atom indices")
        except Exception as e:
            print(f"Error loading atom indices: {e}")
            return None, None

    # Consecutive loaded frames are frame_skip stored frames apart, so the MSD
    # time axis must use timestep_fs * frame_skip; otherwise D is off by frame_skip.
    effective_timestep_fs = timestep_fs * frame_skip
    timestep = effective_timestep_fs * fs
    print(f"Using timestep: {timestep_fs} fs")
    if frame_skip != 1:
        print(f"Effective time between analysed frames: {effective_timestep_fs} fs "
              f"(frame_skip={frame_skip})")

    # msd_direction_atom is forwarded only when atoms are not selected via an index file.
    species_kwargs = {} if indices_path else {'msd_direction_atom': msd_direction_atom}

    print("Calculating MSD using parallel processing...")
    if use_windowed:
        print("Using windowed approach for MSD calculation (averaging over all time origins)")
        msd_data = calculate_msd_windowed(
            traj=traj,
            timestep=timestep,
            atom_indices=atom_indices,
            ignore_n_images=ignore_n_images,
            n_jobs=n_jobs,
            msd_direction=msd_direction,
            lag_times_fs=lag_times_fs,
            **species_kwargs
        )
    else:
        print("Using single reference frame approach for MSD calculation")
        msd_data = calculate_msd(
            traj=traj,
            timestep=timestep,
            atom_indices=atom_indices,
            ignore_n_images=ignore_n_images,
            n_jobs=n_jobs,
            msd_direction=msd_direction,
            **species_kwargs
        )

    save_msd_data(
        msd_data=msd_data,
        csv_file_path=output_file,
        output_dir=output_dir
    )

    return msd_data
695
+
696
def analyze_from_csv(csv_file="msd_results.csv", fit_start=None, fit_end=None, dimension=3,
                     with_intercept=False, plot_msd=False, plot_diffusion=False,
                     use_block_averaging=False, n_blocks=10):
    """
    Estimate the diffusion coefficient from an MSD CSV file.

    The CSV must contain ``Time (fs)`` and ``MSD`` columns, as written by
    ``save_msd_data``. By default a single linear fit is performed with
    ``calculate_diffusion_coefficient``. Set ``use_block_averaging=True`` to
    obtain D and its uncertainty from ``block_averaging_error`` instead.

    Parameters
    ----------
    csv_file : str, optional
        Path to the CSV file containing MSD data (default: "msd_results.csv").
    fit_start : int, optional
        Start index for fitting. With block averaging it also sets the start
        of the plotted fit line (default: None).
    fit_end : int, optional
        End index (exclusive) for fitting. With block averaging it also sets
        the end of the plotted fit line (default: None).
    dimension : int, optional
        Dimensionality of the system: 1, 2 or 3 (default: 3).
    with_intercept : bool, optional
        Whether to fit with an intercept (default: False).
    plot_msd : bool, optional
        Whether to plot MSD vs time (default: False).
    plot_diffusion : bool, optional
        Whether to plot the diffusion coefficient as a time series (default: False).
    use_block_averaging : bool, optional
        Whether to use block averaging for D and its error (default: False).
    n_blocks : int, optional
        Number of blocks for block averaging (default: 10).

    Returns
    -------
    tuple
        ``(D, error)`` in cm²/s, or ``(None, None)`` if any step raises. In
        that case the error and traceback are printed.
    """
    try:
        df = pd.read_csv(csv_file)
        print(f"Loaded MSD data from {csv_file}")

        msd_times = df['Time (fs)'].values
        msd_values = df['MSD'].values

        if use_block_averaging:
            # Block averaging uses the whole series unless a fit zone is given.
            fit_start_idx = 0 if fit_start is None else fit_start
            fit_end_idx = len(msd_times) if fit_end is None else fit_end

            D, error = block_averaging_error(
                msd_times=msd_times[fit_start_idx:fit_end_idx],
                msd_values=msd_values[fit_start_idx:fit_end_idx],
                n_blocks=n_blocks,
                dimension=dimension,
                with_intercept=with_intercept
            )

            print(f"Using block averaging method with {n_blocks} blocks")

            if plot_msd:
                visualization_start = fit_start if fit_start is not None else int(len(msd_times) * 0.3)
                visualization_end = fit_end if fit_end is not None else int(len(msd_times) * 0.8)

                fig = plt.figure(figsize=(10, 6))
                plt.scatter(msd_times / 1000, msd_values, s=10, alpha=0.5, label='MSD data')

                x_fit = msd_times[visualization_start:visualization_end]
                # Invert D = slope / (2 * dimension) * 0.1 to recover the slope in Ų/fs.
                slope = 2 * dimension * D / 0.1
                if with_intercept:
                    y_fit = msd_values[visualization_start:visualization_end]
                    intercept = np.mean(y_fit - slope * x_fit)
                    fit_line = slope * x_fit + intercept
                else:
                    fit_line = slope * x_fit

                plt.plot(x_fit / 1000, fit_line, 'r-', linewidth=2,
                         label=f'Block Avg: D = ({D:.2e} ± {error:.2e}) cm²/s')
                plt.xlabel('Time (ps)')
                plt.ylabel('MSD (Ų)')
                plt.title('Mean Square Displacement vs Time')
                plt.grid(True, alpha=0.3)
                plt.legend()

                dir_name = os.path.dirname(csv_file)
                base_name = os.path.splitext(os.path.basename(csv_file))[0]
                output_file = os.path.join(dir_name, f"{base_name}_msd_block_avg.png")
                plt.tight_layout()
                plt.savefig(output_file, dpi=300, bbox_inches='tight')
                plt.show()
                # Release the figure so repeated calls do not accumulate open figures.
                plt.close(fig)
                print(f"MSD plot saved to: {output_file}")
        else:
            D, error = calculate_diffusion_coefficient(
                msd_times=msd_times,
                msd_values=msd_values,
                start_index=fit_start,
                end_index=fit_end,
                with_intercept=with_intercept,
                plot_msd=plot_msd,
                dimension=dimension
            )

        if plot_diffusion:
            plot_diffusion_time_series(msd_times, msd_values, 10, with_intercept, csv_file, dimension)

        method = "Block Averaging" if use_block_averaging else "Standard Fit"
        print(f"\nMSD Analysis Results ({method}):")
        if use_block_averaging:
            # Guard the relative error so D == 0 does not discard an already computed result.
            relative = f"{100 * error / D:.1f}%" if D else "n/a"
            print(f"Diffusion Coefficient: D = {D:.4e} ± {error:.4e} cm²/s ({relative})")
        else:
            print(f"Diffusion Coefficient: D = {D:.4e} cm²/s")

        return D, error

    except Exception as e:
        print(f"Error analyzing MSD data: {e}")
        traceback.print_exc()
        return None, None
816
+
817
def msd_analysis(traj_path, timestep_fs, indices_path=None, ignore_n_images=0,
                 output_dir=None, frame_skip=10, fit_start=None, fit_end=None,
                 with_intercept=False, plot_msd=False, save_csvs_in_subdir=False,
                 msd_direction=False, msd_direction_atom=None, dimension=3,
                 use_windowed=True, lag_times_fs=None):
    """
    Run the MSD workflow: compute and save MSD data, optionally fit D.

    The MSD itself is produced (and written to disk) by ``calculate_save_msd``.
    A diffusion coefficient is fitted only when ``fit_start``, ``fit_end`` or
    ``plot_msd`` asks for it.

    Parameters
    ----------
    traj_path : str
        Path to the ASE trajectory file.
    timestep_fs : float
        Simulation timestep in femtoseconds (fs).
    indices_path : str, optional
        Path to a file containing atom indices (default: None).
    ignore_n_images : int, optional
        Number of initial images to discard (default: 0).
    output_dir : str, optional
        Output directory. When None, ``msd_<trajectory stem>`` is used.
    frame_skip : int, optional
        Number of frames to skip between samples (default: 10).
    fit_start, fit_end : int, optional
        Index window for the diffusion-coefficient fit (default: None).
    with_intercept : bool, optional
        Fit the MSD line with an intercept (default: False).
    plot_msd : bool, optional
        Plot the fit; also triggers the fit on its own (default: False).
    save_csvs_in_subdir : bool, optional
        Write per-type CSVs into ``<output_dir>/csv_data`` (default: False).
    msd_direction : bool, optional
        Compute directional (x/y/z) MSD (default: False).
    msd_direction_atom : str or int, optional
        Element symbol or atomic number for directional analysis (default: None).
    dimension : int, optional
        System dimensionality (1, 2 or 3) used in the fit (default: 3).
    use_windowed : bool, optional
        Use the windowed (all time origins) estimator (default: True).
    lag_times_fs : list of float, optional
        Specific lag times in fs (default: None, i.e. all lags).

    Returns
    -------
    dict
        On success it contains ``msd_values``, ``msd_times`` and ``output_dir``.
        If a fit was requested and produced a value, it also contains
        ``diffusion_coefficient`` and ``error`` (the fit uncertainty).
        If the MSD calculation failed, it is ``{"error": <message>}``.
        Note that the ``"error"`` key therefore means different things in
        the two cases.
    """
    if output_dir is None:
        stem = os.path.splitext(os.path.basename(traj_path))[0]
        output_dir = f"msd_{stem}"
        print(f"Using trajectory-based output directory: {output_dir}")

    os.makedirs(output_dir, exist_ok=True)

    # Per-type CSVs may go to a dedicated subfolder; the summary CSV always
    # lives directly in output_dir.
    msd_data = calculate_save_msd(
        traj_path=traj_path,
        timestep_fs=timestep_fs,
        indices_path=indices_path,
        ignore_n_images=ignore_n_images,
        output_file=os.path.join(output_dir, "msd_data.csv"),
        frame_skip=frame_skip,
        output_dir=(os.path.join(output_dir, "csv_data")
                    if save_csvs_in_subdir else output_dir),
        msd_direction=msd_direction,
        msd_direction_atom=msd_direction_atom,
        use_windowed=use_windowed,
        lag_times_fs=lag_times_fs,
    )

    # A tuple whose first item is None signals failure.
    if isinstance(msd_data, tuple) and msd_data[0] is None:
        print("Error: Failed to calculate MSD data")
        return {"error": "Failed to calculate MSD data"}

    # Three possible shapes: a per-type dict (prefer 'overall', else the
    # first entry), a directional 4-tuple (times, x, y, z) whose total is
    # used downstream, or a plain (values, times) pair.
    if isinstance(msd_data, dict):
        key = 'overall' if 'overall' in msd_data else next(iter(msd_data))
        msd_values, msd_times = msd_data[key]
    elif len(msd_data) == 4:
        msd_times, msd_x, msd_y, msd_z = msd_data
        msd_values = msd_x + msd_y + msd_z
    else:
        msd_values, msd_times = msd_data

    summary = {
        "msd_values": msd_values,
        "msd_times": msd_times,
        "output_dir": output_dir,
    }

    fit_requested = fit_start is not None or fit_end is not None or plot_msd
    if not fit_requested:
        return summary

    # The fit is best-effort: failures are reported, but the MSD data is
    # still returned.
    try:
        print("Calculating diffusion coefficient...")
        D, error = calculate_diffusion_coefficient(
            msd_times=msd_times,
            msd_values=msd_values,
            start_index=fit_start,
            end_index=fit_end,
            with_intercept=with_intercept,
            plot_msd=plot_msd,
            dimension=dimension,
        )
        if D is not None:
            summary["diffusion_coefficient"] = D
            summary["error"] = error
            print(f"Calculated diffusion coefficient: {D:.2e} cm²/s")
    except Exception as exc:
        print(f"Error calculating diffusion coefficient: {exc}")

    return summary
938
+
939
def block_averaging_error(msd_times, msd_values, n_blocks=5, dimension=3, **kwargs):
    """
    Calculate diffusion coefficient error using block averaging.

    The MSD curve is split into ``n_blocks`` consecutive chunks. A diffusion
    coefficient is fitted to each chunk with ``calculate_diffusion_coefficient``,
    and the spread of the per-block values gives the standard error. Chunks
    with fewer than 10 points are skipped. The last block absorbs any
    remainder points.

    Parameters
    ----------
    msd_times : numpy.ndarray
        Time values in femtoseconds.
    msd_values : numpy.ndarray
        MSD values, aligned with ``msd_times``.
    n_blocks : int
        Number of blocks to divide the data into; must be >= 1.
    dimension : int
        Dimensionality of the system (default: 3).
    **kwargs
        Forwarded to ``calculate_diffusion_coefficient``. Any fit-window keys
        (``start_index``, ``end_index``, ``fit_start``, ``fit_end``) are
        dropped, because each block is fitted over its whole extent.

    Returns
    -------
    tuple
        (mean_D, std_error_D): the mean diffusion coefficient and its
        standard error.
        - If no block could be fitted, both are NaN.
        - If only one block was fitted, std_error_D is NaN, because the
          sample standard deviation is undefined for a single value.

    Raises
    ------
    ValueError
        If ``n_blocks`` is less than 1.
    """
    if n_blocks < 1:
        raise ValueError(f"n_blocks must be a positive integer, got {n_blocks}")

    n_points = len(msd_times)
    block_size = n_points // n_blocks
    if block_size < 10:
        print(f"Warning: Block size is small ({block_size} points). Consider using fewer blocks.")

    # Loop-invariant: strip fit-window arguments once, not per block.
    block_kwargs = {k: v for k, v in kwargs.items()
                    if k not in ('start_index', 'end_index', 'fit_start', 'fit_end')}

    D_values = []
    for i in range(n_blocks):
        start_idx = i * block_size
        # The last block takes the remainder so no points are dropped.
        end_idx = (i + 1) * block_size if i < n_blocks - 1 else n_points

        if end_idx - start_idx < 10:
            continue

        try:
            D, _ = calculate_diffusion_coefficient(
                msd_times[start_idx:end_idx],
                msd_values[start_idx:end_idx],
                dimension=dimension,
                **block_kwargs
            )
        except Exception as e:
            print(f"Warning: Failed to fit block {i}: {e}")
            continue

        # A None result would turn the array into dtype=object and make
        # np.mean raise, so it is skipped like any other failed block.
        if D is None:
            print(f"Warning: Failed to fit block {i}: no diffusion coefficient returned")
            continue
        D_values.append(D)

    # Handle the degenerate cases explicitly instead of letting numpy emit
    # "mean of empty slice" / "degrees of freedom <= 0" RuntimeWarnings.
    if not D_values:
        print("Warning: No block could be fitted; returning NaN.")
        return np.nan, np.nan

    D_array = np.asarray(D_values, dtype=float)
    mean_D = np.mean(D_array)
    if D_array.size < 2:
        return mean_D, np.nan

    std_D = np.std(D_array, ddof=1)
    std_error_D = std_D / np.sqrt(D_array.size)
    return mean_D, std_error_D
994
+
995
def calculate_frame_msd_windowed(positions_i, positions_j, atom_indices, msd_direction=False):
    """
    Mean squared displacement of selected atoms between two frames.

    This is the single-origin kernel used by the windowed MSD estimator.
    The displacement of each selected atom is ``positions_j - positions_i``.

    Parameters
    ----------
    positions_i : numpy.ndarray
        Positions at the earlier time, indexed per atom (``positions_i[k]``).
    positions_j : numpy.ndarray
        Positions at the later time, same layout as ``positions_i``.
    atom_indices : list
        Indices of the atoms to include in the average.
    msd_direction : bool, optional
        If True, return per-axis values instead of the total (default: False).

    Returns
    -------
    float or tuple
        With ``msd_direction`` False: squared displacements summed over all
        components, divided by the number of selected atoms.
        With ``msd_direction`` True: ``(msd_x, msd_y, msd_z)``, each the
        squared displacement along columns 0, 1 and 2 summed over atoms and
        divided by the number of selected atoms.
    """
    delta = positions_j[atom_indices] - positions_i[atom_indices]
    n_selected = len(atom_indices)

    if msd_direction:
        return tuple(np.sum(delta[:, axis]**2) / n_selected for axis in range(3))

    return np.sum(np.square(delta)) / n_selected
1031
+
1032
def _origin_averaged_msd(parallel, positions, indices, lag_frames, directional):
    """Average the single-origin MSD over every valid time origin, per lag.

    Parameters
    ----------
    parallel : joblib.Parallel
        An already-entered ``Parallel`` instance. It is reused across all
        lags so one worker pool serves the whole calculation.
    positions : numpy.ndarray
        Stacked frame positions, indexed as ``positions[frame]``.
    indices : sequence of int
        Atom indices included in the average.
    lag_frames : list of int
        Lags in frames; each must satisfy ``1 <= lag < len(positions)``.
    directional : bool
        If True, compute per-axis MSDs instead of the total.

    Returns
    -------
    numpy.ndarray or tuple of numpy.ndarray
        ``msd_values`` (one entry per lag), or ``(msd_x, msd_y, msd_z)``
        when ``directional`` is True.
    """
    n_frames = len(positions)
    n_components = 3 if directional else 1
    table = np.zeros((n_components, len(lag_frames)))

    for col, lag in enumerate(lag_frames):
        # Every origin i with i + lag still inside the trajectory contributes.
        results = parallel(
            delayed(calculate_frame_msd_windowed)(
                positions[i], positions[i + lag], indices, directional
            )
            for i in range(n_frames - lag)
        )
        if directional:
            for axis in range(3):
                table[axis, col] = np.mean([r[axis] for r in results])
        else:
            table[0, col] = np.mean(results)

    if directional:
        return table[0], table[1], table[2]
    return table[0]


def calculate_msd_windowed(
    traj, timestep, atom_indices=None, ignore_n_images=0, n_jobs=-1,
    msd_direction=False, msd_direction_atom=None, lag_times_fs=None
):
    """
    Calculate Mean Square Displacement (MSD) vs time using the windowed approach,
    averaging over all possible time origins.

    Parameters
    ----------
    traj : list of ase.Atoms
        Trajectory data.
    timestep : float
        Simulation timestep. It is divided by the module-level ``fs`` constant
        to obtain femtoseconds (presumably ``ase.units.fs``, so ``timestep``
        is expected in ASE time units -- confirm against callers).
    atom_indices : numpy.ndarray, optional
        Indices of atoms to analyze (default: analyze every atom type).
    ignore_n_images : int, optional
        Number of initial images to ignore (default: 0).
    n_jobs : int, optional
        Number of parallel jobs to run (default: -1, use all available cores).
    msd_direction : bool, optional
        Whether to calculate directional MSD (default: False).
    msd_direction_atom : str or int, optional
        Element symbol or atomic number selecting which atom type gets the
        directional breakdown when ``atom_indices`` is None (default: None).
    lag_times_fs : list of float, optional
        Lag times (in fs) to evaluate. Each is rounded to the nearest whole
        number of frames; non-positive or out-of-range lags are dropped.
        Default: None, meaning every lag from 1 to n_frames - 1.

    Returns
    -------
    tuple or dict
        If ``atom_indices`` is given:
            ``(msd_values, msd_times)``, or
            ``(msd_times, msd_x, msd_y, msd_z)`` when ``msd_direction`` is True.
        If ``atom_indices`` is None, a dict of ``(values, times)`` pairs:
            - ``'overall'`` (all atoms);
            - one entry per element symbol, in order of first appearance in
              ``traj[0]``;
            - ``'<sym>_x'``, ``'<sym>_y'`` and ``'<sym>_z'`` additionally for
              the type selected by ``msd_direction_atom``.
        Times are in fs. Each dict entry holds its own copy of the times
        array.
    """
    n_frames = len(traj) - ignore_n_images
    timestep_fs = timestep / fs  # Convert timestep to fs
    # NOTE(review): raw ``Atoms.positions`` are used, so the trajectory is
    # presumably unwrapped; wrapped PBC coordinates would corrupt the MSD.
    positions = np.array([traj[i].positions for i in range(ignore_n_images, len(traj))])

    # Determine lag times (in frames)
    if lag_times_fs is not None:
        lag_frames = [int(round(lag_fs / timestep_fs)) for lag_fs in lag_times_fs if lag_fs > 0]
        lag_frames = [lf for lf in lag_frames if 1 <= lf < n_frames]
    else:
        lag_frames = list(range(1, n_frames))

    msd_times = np.array(lag_frames) * timestep_fs  # in fs

    # One pool for the whole calculation instead of a fresh one per lag
    # (joblib's documented context-manager pattern for repeated calls).
    with Parallel(n_jobs=n_jobs) as parallel:
        # MSD for specific atoms
        if atom_indices is not None:
            if msd_direction:
                msd_x, msd_y, msd_z = _origin_averaged_msd(
                    parallel, positions, atom_indices, lag_frames, True
                )
                return msd_times, msd_x, msd_y, msd_z
            msd_values = _origin_averaged_msd(
                parallel, positions, atom_indices, lag_frames, False
            )
            return msd_values, msd_times

        # MSD per atom type
        atoms = traj[0]
        symbols = atoms.get_chemical_symbols()
        atomic_numbers = atoms.get_atomic_numbers()

        # Group indices by symbol in first-appearance order. Iterating a set
        # of strings would make the order vary between interpreter runs
        # (hash randomization).
        symbol_indices = {}
        for i, s in enumerate(symbols):
            symbol_indices.setdefault(s, []).append(i)

        overall_msd = _origin_averaged_msd(
            parallel, positions, list(range(len(atoms))), lag_frames, False
        )
        # ``msd_times[:]`` would only be a view, so store real copies.
        result = {'overall': (overall_msd, msd_times.copy())}

        for symbol, indices in symbol_indices.items():
            print(f"Calculating MSD for {symbol} atoms...")
            calc_direction = msd_direction and (
                (isinstance(msd_direction_atom, str) and symbol == msd_direction_atom)
                or (isinstance(msd_direction_atom, (int, np.integer))
                    and atomic_numbers[indices[0]] == msd_direction_atom)
            )

            if calc_direction:
                msd_x, msd_y, msd_z = _origin_averaged_msd(
                    parallel, positions, indices, lag_frames, True
                )
                result[symbol] = (msd_x + msd_y + msd_z, msd_times.copy())
                result[f'{symbol}_x'] = (msd_x, msd_times.copy())
                result[f'{symbol}_y'] = (msd_y, msd_times.copy())
                result[f'{symbol}_z'] = (msd_z, msd_times.copy())
                print(f"Saved directional MSD data for {symbol} atoms")
            else:
                msd_values = _origin_averaged_msd(
                    parallel, positions, indices, lag_frames, False
                )
                result[symbol] = (msd_values, msd_times.copy())

    return result