crisp-ase 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. CRISP/__init__.py +99 -0
  2. CRISP/_version.py +1 -0
  3. CRISP/cli.py +41 -0
  4. CRISP/data_analysis/__init__.py +38 -0
  5. CRISP/data_analysis/clustering.py +838 -0
  6. CRISP/data_analysis/contact_coordination.py +915 -0
  7. CRISP/data_analysis/h_bond.py +772 -0
  8. CRISP/data_analysis/msd.py +1199 -0
  9. CRISP/data_analysis/prdf.py +404 -0
  10. CRISP/data_analysis/volumetric_atomic_density.py +527 -0
  11. CRISP/py.typed +1 -0
  12. CRISP/simulation_utility/__init__.py +31 -0
  13. CRISP/simulation_utility/atomic_indices.py +155 -0
  14. CRISP/simulation_utility/atomic_traj_linemap.py +278 -0
  15. CRISP/simulation_utility/error_analysis.py +254 -0
  16. CRISP/simulation_utility/interatomic_distances.py +200 -0
  17. CRISP/simulation_utility/subsampling.py +241 -0
  18. CRISP/tests/DataAnalysis/__init__.py +1 -0
  19. CRISP/tests/DataAnalysis/test_clustering_extended.py +212 -0
  20. CRISP/tests/DataAnalysis/test_contact_coordination.py +184 -0
  21. CRISP/tests/DataAnalysis/test_contact_coordination_extended.py +465 -0
  22. CRISP/tests/DataAnalysis/test_h_bond_complete.py +326 -0
  23. CRISP/tests/DataAnalysis/test_h_bond_extended.py +322 -0
  24. CRISP/tests/DataAnalysis/test_msd_complete.py +305 -0
  25. CRISP/tests/DataAnalysis/test_msd_extended.py +522 -0
  26. CRISP/tests/DataAnalysis/test_prdf.py +206 -0
  27. CRISP/tests/DataAnalysis/test_volumetric_atomic_density.py +463 -0
  28. CRISP/tests/SimulationUtility/__init__.py +1 -0
  29. CRISP/tests/SimulationUtility/test_atomic_traj_linemap.py +101 -0
  30. CRISP/tests/SimulationUtility/test_atomic_traj_linemap_extended.py +469 -0
  31. CRISP/tests/SimulationUtility/test_error_analysis_extended.py +151 -0
  32. CRISP/tests/SimulationUtility/test_interatomic_distances.py +223 -0
  33. CRISP/tests/SimulationUtility/test_subsampling.py +365 -0
  34. CRISP/tests/__init__.py +1 -0
  35. CRISP/tests/test_CRISP.py +28 -0
  36. CRISP/tests/test_cli.py +87 -0
  37. CRISP/tests/test_crisp_comprehensive.py +679 -0
  38. crisp_ase-1.1.2.dist-info/METADATA +141 -0
  39. crisp_ase-1.1.2.dist-info/RECORD +42 -0
  40. crisp_ase-1.1.2.dist-info/WHEEL +5 -0
  41. crisp_ase-1.1.2.dist-info/entry_points.txt +2 -0
  42. crisp_ase-1.1.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,155 @@
1
+ """
2
+ CRISP/simulation_utility/atomic_indices.py
3
+
4
+ This module extracts atomic indices from trajectory files
5
+ and identifies atom pairs within specified cutoff distances.
6
+ """
7
+
8
+ import os
9
+ import numpy as np
10
+ import ase.io
11
+ import csv
12
+ from typing import Dict, Tuple, List, Optional
13
+
14
+ __all__ = ['atom_indices', 'run_atom_indices']
15
+
16
def atom_indices(
    traj_path: str,
    frame_index: int = 0,
    custom_cutoffs: Optional[Dict[Tuple[str, str], float]] = None
) -> Tuple[Dict[str, List[int]], np.ndarray, Dict[Tuple[str, str], List[Tuple[int, int, float]]]]:
    """Extract atom indices by chemical symbol and find atom pairs within specified cutoffs.

    Parameters
    ----------
    traj_path : str
        Path to the trajectory file in any format supported by ASE
    frame_index : int, optional
        Index of the frame to analyze (default: 0)
    custom_cutoffs : dict, optional
        Dictionary with atom symbol pairs as keys and cutoff distances as values.
        Example: {('Si', 'O'): 2.0, ('Al', 'O'): 2.1}

    Returns
    -------
    indices_by_symbol : dict
        Chemical symbols (in order of first appearance in the structure) mapped
        to lists of atomic indices
    dist_matrix : numpy.ndarray
        Distance matrix between all atoms, computed with ``mic=True``
        (minimum-image convention for periodic cells)
    cutoff_indices : dict
        Symbol pairs mapped to lists of (idx1, idx2, distance) tuples for atom
        pairs strictly closer than the cutoff. An atom is never paired with
        itself, so same-element pairs such as ('O', 'O') contain no
        zero-distance self entries. Pairs whose symbols do not occur in the
        structure map to an empty list.

    Raises
    ------
    ValueError
        If the structure cannot be read or processed; the underlying
        exception is chained as ``__cause__``.
    """
    try:
        atoms = ase.io.read(traj_path, index=frame_index)
        dist_matrix = atoms.get_all_distances(mic=True)
        symbols = atoms.get_chemical_symbols()

        # Build the mapping in first-appearance order; iterating a set() made
        # the key order depend on string hash randomisation.
        indices_by_symbol: Dict[str, List[int]] = {}
        for idx, symbol in enumerate(symbols):
            indices_by_symbol.setdefault(symbol, []).append(idx)

        cutoff_indices: Dict[Tuple[str, str], List[Tuple[int, int, float]]] = {}

        if custom_cutoffs:
            for pair, cutoff in custom_cutoffs.items():
                symbol1, symbol2 = pair
                pair_indices_distances = []
                if symbol1 in indices_by_symbol and symbol2 in indices_by_symbol:
                    for idx1 in indices_by_symbol[symbol1]:
                        for idx2 in indices_by_symbol[symbol2]:
                            # Skip self-pairs: with symbol1 == symbol2 every atom
                            # would otherwise match itself at distance 0.
                            if idx1 == idx2:
                                continue
                            distance = dist_matrix[idx1, idx2]
                            if distance < cutoff:
                                pair_indices_distances.append((idx1, idx2, distance))
                cutoff_indices[pair] = pair_indices_distances

        return indices_by_symbol, dist_matrix, cutoff_indices

    except Exception as e:
        raise ValueError(
            f"Error processing atomic structure: {e}. Check if the format is supported by ASE."
        ) from e
74
+
75
+
76
def run_atom_indices(
    traj_path: str,
    output_dir: str,
    frame_index: int = 0,
    custom_cutoffs: Optional[Dict[Tuple[str, str], float]] = None
) -> None:
    """Run atom index extraction and save results to files.

    Parameters
    ----------
    traj_path : str
        Path to the trajectory file in any format supported by ASE
    output_dir : str
        Directory where output files will be saved (created if missing)
    frame_index : int, optional
        Index of the frame to analyze (default: 0). Must lie in
        ``0 .. n_frames - 1``; negative indices are rejected.
    custom_cutoffs : dict, optional
        Dictionary with atom symbol pairs as keys and cutoff distances as values

    Returns
    -------
    None
        Results are saved to the specified output directory:
        - lengths.npy: Dictionary of number of atoms per element (stored by
          ``np.save`` as a 0-d object array; load it with
          ``np.load(path, allow_pickle=True).item()``)
        - {symbol}_indices.npy: Numpy array of atom indices for each element
        - cutoff/{symbol1}-{symbol2}_cutoff.csv: CSV files with atom pairs within cutoff

    Raises
    ------
    ValueError
        If the trajectory cannot be read (the underlying exception is chained
        as ``__cause__``), or if ``frame_index`` is out of range.
    """
    try:
        try:
            frames = ase.io.read(traj_path, index=":")
            traj_length = len(frames) if isinstance(frames, list) else 1
            # Only the frame count is needed here; release the loaded trajectory.
            del frames
        except TypeError:
            # If reading with a slice index raises TypeError, fall back to a
            # plain single-structure read and treat the file as one frame.
            ase.io.read(traj_path)
            traj_length = 1
    except ValueError:
        # Already a ValueError: propagate it unchanged.
        raise
    except Exception as e:
        raise ValueError(
            f"Error reading trajectory: {e}. Check if the format is supported by ASE."
        ) from e

    # Check if frame_index is within valid range
    if frame_index < 0 or frame_index >= traj_length:
        raise ValueError(
            f"Error: Frame index {frame_index} is out of range. "
            f"Valid range is 0 to {traj_length-1}."
        )

    print(f"Analyzing frame with index {frame_index} (out of {traj_length} frames)")

    indices, _, cutoff_indices = atom_indices(traj_path, frame_index, custom_cutoffs)

    os.makedirs(output_dir, exist_ok=True)

    lengths = {symbol: len(symbol_indices) for symbol, symbol_indices in indices.items()}
    np.save(os.path.join(output_dir, "lengths.npy"), lengths)

    for symbol, data in indices.items():
        np.save(os.path.join(output_dir, f"{symbol}_indices.npy"), data)
        print(f"Length of {symbol} indices: {len(data)}")

    print("Outputs saved.")

    cutoff_folder = os.path.join(output_dir, "cutoff")
    os.makedirs(cutoff_folder, exist_ok=True)

    for (symbol1, symbol2), pair_indices_distances in cutoff_indices.items():
        filename = f"{symbol1}-{symbol2}_cutoff.csv"
        filepath = os.path.join(cutoff_folder, filename)
        with open(filepath, mode="w", newline="") as file:
            writer = csv.writer(file)
            writer.writerow([f"{symbol1} index", f"{symbol2} index", "distance"])
            writer.writerows(pair_indices_distances)
        print(f"Saved cutoff indices for {symbol1}-{symbol2} to {filepath}")
@@ -0,0 +1,278 @@
1
+ """
2
+ CRISP/simulation_utility/atomic_traj_linemap.py
3
+
4
+ This module provides functionality for visualizing atomic trajectories from molecular dynamics
5
+ simulations using interactive 3D plots.
6
+ """
7
+
8
+ import os
9
+ import numpy as np
10
+ from typing import List, Optional, Dict, Union
11
+ from ase.io import read
12
+ import plotly.graph_objects as go
13
+ import plotly.io as pio
14
+ pio.renderers.default = "notebook"
15
+ __all__ = ['plot_atomic_trajectory', 'VDW_RADII', 'ELEMENT_COLORS']
16
# Van der Waals radii in Ångström for every element from H (Z=1) through
# Og (Z=118), grouped by period. plot_atomic_trajectory scales marker sizes
# from these values and uses 1.0 for any symbol that is not listed.
VDW_RADII = dict(
    # Period 1
    H=1.20, He=1.40,
    # Period 2
    Li=1.82, Be=1.53, B=1.92, C=1.70, N=1.55, O=1.52, F=1.47, Ne=1.54,
    # Period 3
    Na=2.27, Mg=1.73, Al=1.84, Si=2.10, P=1.80, S=1.80, Cl=1.75, Ar=1.88,
    # Period 4
    K=2.75, Ca=2.31, Sc=2.30, Ti=2.15, V=2.05, Cr=2.05, Mn=2.05,
    Fe=2.05, Co=2.00, Ni=2.00, Cu=2.00, Zn=2.10, Ga=1.87, Ge=2.11,
    As=1.85, Se=1.90, Br=1.85, Kr=2.02,
    # Period 5
    Rb=3.03, Sr=2.49, Y=2.40, Zr=2.30, Nb=2.15, Mo=2.10, Tc=2.05,
    Ru=2.05, Rh=2.00, Pd=2.05, Ag=2.10, Cd=2.20, In=2.20, Sn=2.17,
    Sb=2.06, Te=2.06, I=1.98, Xe=2.16,
    # Period 6
    Cs=3.43, Ba=2.68, La=2.50, Ce=2.48, Pr=2.47, Nd=2.45, Pm=2.43,
    Sm=2.42, Eu=2.40, Gd=2.38, Tb=2.37, Dy=2.35, Ho=2.33, Er=2.32,
    Tm=2.30, Yb=2.28, Lu=2.27, Hf=2.25, Ta=2.20, W=2.10, Re=2.05,
    Os=2.00, Ir=2.00, Pt=2.05, Au=2.10, Hg=2.05, Tl=2.20, Pb=2.30,
    Bi=2.30, Po=2.00, At=2.00, Rn=2.00,
    # Period 7: Fr through Th have individual values; Pa through Og all use 2.30
    Fr=3.50, Ra=2.80, Ac=2.60, Th=2.40,
    **dict.fromkeys(
        ('Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', 'Cf', 'Es', 'Fm',
         'Md', 'No', 'Lr', 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds',
         'Rg', 'Cn', 'Nh', 'Fl', 'Mc', 'Lv', 'Ts', 'Og'),
        2.30,
    ),
)


# Marker colour names per element symbol; plot_atomic_trajectory draws any
# symbol missing from this table in 'gray'.
ELEMENT_COLORS = dict(
    # Common elements
    H='white', C='black', N='blue', O='red', F='green',
    Na='purple', Mg='pink', Al='gray', Si='yellow', P='orange',
    S='yellow', Cl='green', K='purple', Ca='gray', Fe='orange',
    Cu='orange', Zn='gray',
    # Additional elements
    Br='brown', I='purple', Li='purple', B='olive',
    He='cyan', Ne='cyan', Ar='cyan', Kr='cyan', Xe='cyan',
    Mn='gray', Co='blue', Ni='green', Pd='gray', Pt='gray',
    Au='gold', Hg='silver', Pb='darkgray', Ag='silver',
    Ti='gray', V='gray', Cr='gray', Zr='gray', Mo='gray',
    W='gray', U='green',
)
61
+
62
def plot_atomic_trajectory(
    traj_path: str,
    indices_path: Union[str, List[int]],
    output_dir: str,
    output_filename: str = "trajectory_plot.html",
    frame_skip: int = 100,
    plot_title: Optional[str] = None,
    show_plot: bool = False,
    atom_size_scale: float = 1.0,
    plot_lines: bool = False
):
    """
    Create a 3D visualization of atom trajectories with all atom types displayed.

    Parameters
    ----------
    traj_path : str
        Path to the ASE trajectory file (supports any ASE-readable format like XYZ)
    indices_path : str or List[int]
        Either a path to numpy file containing atom indices to plot trajectories for,
        or a direct list of atom indices
    output_dir : str
        Directory where the output visualization will be saved (created if missing)
    output_filename : str, optional
        Filename for the output visualization (default: "trajectory_plot.html")
    frame_skip : int, optional
        Use every nth frame from the trajectory (default: 100)
    plot_title : str, optional
        Custom title for the plot (default: auto-generated based on atom types)
    show_plot : bool, optional
        Whether to display the plot interactively (default: False)
    atom_size_scale : float, optional
        Scale factor for atom sizes in the visualization (default: 1.0)
    plot_lines : bool, optional
        Whether to connect trajectory points with lines (default: False)

    Returns
    -------
    plotly.graph_objects.Figure
        The generated plotly figure object that can be further customized

    Notes
    -----
    This function creates an interactive 3D visualization showing:
    1. All atoms from the first frame, colored by element
    2. Trajectory paths for selected atoms throughout all frames
    3. Annotations for the start and end positions of traced atoms

    A selected index that is not smaller than a frame's atom count is skipped
    for that frame; one warning per such index (with the number of affected
    frames) is printed. Repeated indices are sampled once per frame.

    The output is saved as an HTML file which can be opened in any web browser.
    """
    print(f"Loading trajectory from {traj_path} (using every {frame_skip}th frame)...")
    traj = read(traj_path, index=f'::{frame_skip}')

    # Convert to list if not already (happens with single frame)
    if not isinstance(traj, list):
        traj = [traj]

    print(f"Loaded {len(traj)} frames from trajectory")

    if isinstance(indices_path, str):
        selected_indices = np.load(indices_path)
        print(f"Loaded {len(selected_indices)} atoms for trajectory plotting from {indices_path}")
    else:
        selected_indices = np.array(indices_path)
        print(f"Using {len(selected_indices)} directly provided atom indices for trajectory plotting")

    first_frame = traj[0]
    box = first_frame.cell.lengths()
    print(f"Simulation box dimensions: {box} Å")

    # Atom indices run 0..len-1, so the maximum is len - 1; this also avoids
    # max() raising on an empty frame.
    max_index = len(first_frame) - 1
    print(f"Analyzing atom types in first frame (total atoms: {len(first_frame)}, max index: {max_index})...")

    # Group first-frame atom indices by element symbol, in order of first appearance.
    atom_types: Dict[str, List[int]] = {}
    for atom_idx, symbol in enumerate(first_frame.get_chemical_symbols()):
        atom_types.setdefault(symbol, []).append(atom_idx)

    print(f"Found {len(atom_types)} atom types: {', '.join(atom_types.keys())}")

    fig = go.Figure()

    use_same_color = len(selected_indices) > 5
    colors = ['blue'] * len(selected_indices) if use_same_color else [
        'blue', 'green', 'red', 'orange', 'purple'
    ][:len(selected_indices)]

    # Background: every atom of the first frame, one trace per element.
    for symbol, indices in atom_types.items():
        positions = first_frame.positions[indices]

        size = VDW_RADII.get(symbol, 1.0) * 3.0 * atom_size_scale
        color = ELEMENT_COLORS.get(symbol, 'gray')

        fig.add_trace(go.Scatter3d(
            x=positions[:, 0],
            y=positions[:, 1],
            z=positions[:, 2],
            mode='markers',
            name=f'{symbol} Atoms',
            marker=dict(
                size=size,
                color=color,
                symbol='circle',
                opacity=0.7,
                line=dict(color='black', width=0.5)
            )
        ))

    # Collect the position of each distinct selected index in every loaded frame.
    selected_positions = {idx: [] for idx in selected_indices}
    out_of_range_frames: Dict[int, int] = {}
    for atoms in traj:
        n_atoms = len(atoms)
        frame_positions = atoms.positions
        for idx in selected_positions:
            if idx < n_atoms:
                selected_positions[idx].append(frame_positions[idx])
            else:
                out_of_range_frames[idx] = out_of_range_frames.get(idx, 0) + 1

    # Report each bad index once instead of once per frame.
    for idx, n_missing in out_of_range_frames.items():
        print(f"Warning: Index {idx} is out of range in {n_missing} of {len(traj)} frames")

    annotations = []
    for i, idx in enumerate(selected_indices):
        if not selected_positions[idx]:
            continue

        pos = np.array(selected_positions[idx])
        color = colors[i % len(colors)]

        # Add annotations for first and last frames
        first_frame_pos = pos[0]
        last_frame_pos = pos[-1]

        annotations.extend([
            dict(
                x=first_frame_pos[0],
                y=first_frame_pos[1],
                z=first_frame_pos[2],
                text=f'Start {idx}',
                showarrow=True,
                arrowhead=2,
                ax=20,
                ay=-20,
                arrowcolor=color,
                font=dict(color=color, size=10)
            ),
            dict(
                x=last_frame_pos[0],
                y=last_frame_pos[1],
                z=last_frame_pos[2],
                text=f'End {idx}',
                showarrow=True,
                arrowhead=2,
                ax=-20,
                ay=-20,
                arrowcolor=color,
                font=dict(color=color, size=10)
            )
        ])

        if plot_lines:
            # Trajectory line plus a marker at every sampled frame.
            fig.add_trace(go.Scatter3d(
                x=pos[:, 0],
                y=pos[:, 1],
                z=pos[:, 2],
                mode='lines+markers',
                name=f'Atom {idx}',
                line=dict(width=3, color=color),
                marker=dict(size=4, color=color),
            ))
        else:
            # Scatter-only mode: just showing points for each frame
            fig.add_trace(go.Scatter3d(
                x=pos[:, 0],
                y=pos[:, 1],
                z=pos[:, 2],
                mode='markers',
                name=f'Atom {idx}',
                marker=dict(
                    size=5,
                    color=color,
                    symbol='circle',
                    opacity=0.8,
                    line=dict(color='black', width=0.5)
                )
            ))

    if not plot_title:
        atom_types_str = ', '.join(atom_types.keys())
        plot_title = f'Atomic Trajectories in {atom_types_str} System'

    fig.update_layout(
        title=plot_title,
        scene=dict(
            xaxis_title='X (Å)',
            yaxis_title='Y (Å)',
            zaxis_title='Z (Å)',
            xaxis=dict(range=[0, box[0]]),
            yaxis=dict(range=[0, box[1]]),
            zaxis=dict(range=[0, box[2]]),
            aspectmode='cube'
        ),
        margin=dict(l=0, r=0, b=0, t=40),
        scene_annotations=annotations
    )

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, output_filename)

    fig.write_html(output_path)
    print(f"Plot has been saved to {output_path}")

    if show_plot:
        fig.show()

    return fig
@@ -0,0 +1,254 @@
1
+ """
2
+ CRISP/simulation_utility/error_analysis.py
3
+
4
+ This module provides statistical error analysis tools for molecular dynamics simulations,
5
+ including autocorrelation and block averaging methods to estimate statistical errors.
6
+ """
7
+
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ from statsmodels.tsa.stattools import acf
11
+ import warnings
12
+ from typing import Dict, Optional, Any
13
+
14
# Public API of this module; includes the combined error_analysis entry point.
__all__ = ['optimal_lag', 'vector_acf', 'autocorrelation_analysis', 'block_analysis', 'error_analysis']
15
+
16
def optimal_lag(acf_values: np.ndarray, threshold: float = 0.05) -> int:
    """Return the first lag at which |ACF| drops below ``threshold``.

    Parameters
    ----------
    acf_values : numpy.ndarray
        Autocorrelation values indexed by lag (entry 0 is lag 0)
    threshold : float, optional
        Absolute correlation below which samples are treated as uncorrelated
        (default: 0.05)

    Returns
    -------
    int
        The smallest lag whose absolute autocorrelation is below ``threshold``.
        If no lag qualifies, the last available lag (``len(acf_values) - 1``)
        is returned instead.

    Warns
    -----
    UserWarning
        When no lag falls below ``threshold`` (the ACF did not converge).
    """
    hits = np.flatnonzero(np.abs(np.asarray(acf_values)) < threshold)
    if hits.size:
        return int(hits[0])

    last_lag = len(acf_values) - 1
    warnings.warn(
        "Autocorrelation function is not converged. "
        f"Consider increasing the 'max_lag' parameter (current: {last_lag}) "
        "or extending the simulation length."
    )
    return last_lag
48
+
49
+
50
def vector_acf(data, max_lag):
    """
    Calculate the normalized autocorrelation function of vector-valued data.

    For each lag ``tau`` the value is the mean dot product between centered
    samples ``tau`` frames apart, divided by the mean squared norm of the
    centered samples, so lag 0 equals 1 for non-constant data.

    Parameters
    ----------
    data : array_like
        Input data with shape (n_frames, n_dimensions). A 1-D array of length
        n_frames is treated as a single dimension.
    max_lag : int
        Maximum lag time to calculate autocorrelation for

    Returns
    -------
    numpy.ndarray
        Array of length ``max_lag + 1`` holding the values for lags
        0..max_lag. Lags with no overlapping frame pairs (``tau >= n_frames``)
        are NaN rather than raising a broadcasting error.
    """
    data = np.asarray(data)
    if data.ndim == 1:
        # Treat a plain series as one-dimensional vectors so axis=1 below is valid.
        data = data[:, np.newaxis]

    n_frames = data.shape[0]
    data_centered = data - np.mean(data, axis=0)
    norm0 = np.mean(np.sum(data_centered**2, axis=1))

    acf_vals = np.full(max_lag + 1, np.nan)
    # Only lags that still leave at least one (t, t + tau) pair are computable.
    for tau in range(min(max_lag, n_frames - 1) + 1):
        dots = np.sum(data_centered[:n_frames - tau] * data_centered[tau:], axis=1)
        acf_vals[tau] = np.mean(dots) / norm0

    return acf_vals
76
+
77
+
78
def autocorrelation_analysis(data, max_lag=None, threshold=0.05, plot_acf=False):
    """
    Perform autocorrelation analysis to estimate statistical errors.

    Parameters
    ----------
    data : array_like
        Input data: a 1-D series, or an array whose first axis is time
        (n_frames, ...) for vector-valued data
    max_lag : int, optional
        Maximum lag time to calculate autocorrelation for. Defaults to
        ``min(1000, N // 10)`` for 1-D data and ``min(1000, N // 20)`` for
        multi-dimensional data, where N is the number of frames.
    threshold : float, optional
        Correlation threshold below which data is considered uncorrelated (default: 0.05)
    plot_acf : bool, optional
        Whether to save an autocorrelation plot to "ACF_lag_analysis.png" in the
        current working directory (default: False)

    Returns
    -------
    dict
        Dictionary containing:
        - mean: Mean value of data (per component for multi-dimensional data)
        - acf_err: Error estimate, std * sqrt(2 * tau_int / N)
        - std: Sample standard deviation of data (ddof=1)
        - tau_int: Integrated autocorrelation time, 0.5 + sum of ACF over lags 1..optimal_lag
        - optimal_lag: Optimal lag time where autocorrelation drops below threshold
    """
    data = np.asarray(data)
    N = data.shape[0]

    if data.ndim == 1:
        mean_value = np.mean(data)
        std_value = np.std(data, ddof=1)
        if max_lag is None:
            max_lag = min(1000, N // 10)
        acf_values = acf(data - mean_value, nlags=max_lag, fft=True)
    else:
        mean_value = np.mean(data, axis=0)
        std_value = np.std(data, axis=0, ddof=1)
        if max_lag is None:
            # Cap at 1000 lags like the 1-D branch; `min` needs both operands
            # (a lone int argument raises TypeError).
            max_lag = min(1000, N // 20)
        acf_values = vector_acf(data, max_lag)

    opt_lag = optimal_lag(acf_values, threshold)

    tau_int = 0.5 + np.sum(acf_values[1:opt_lag + 1])
    autocorr_error = std_value * np.sqrt(2 * tau_int / N)

    if plot_acf:
        plt.figure(figsize=(8, 5))
        plt.plot(np.arange(len(acf_values)), acf_values, linestyle='-', linewidth=2, color="blue", label='ACF')
        plt.axhline(y=threshold, color='red', linestyle='--', linewidth=1.5, label='Threshold')
        plt.xlabel('Lag')
        plt.ylabel('Autocorrelation')
        plt.title('Autocorrelation Function (ACF)')
        plt.legend()
        plt.savefig("ACF_lag_analysis.png", dpi=300, bbox_inches='tight')

    return {"mean": mean_value, "acf_err": autocorr_error, "std": std_value, "tau_int": tau_int, "optimal_lag": opt_lag}
134
+
135
+
136
def block_analysis(data, convergence_tol=0.001, plot_blocks=False):
    """
    Perform block averaging analysis to estimate statistical errors.

    The series is split into M equal-length blocks for M = 2, 3, ...
    (trailing points that do not fill a block are dropped), and the standard
    error of the block means is recorded for each M. Iteration stops once the
    last five standard errors span less than ``convergence_tol``.

    Parameters
    ----------
    data : array_like
        1-D input data array (blocks are formed with ``reshape(M, block_length)``)
    convergence_tol : float, optional
        Tolerance for determining convergence of standard error (default: 0.001)
    plot_blocks : bool, optional
        Whether to save a block averaging plot to "block_averaging_convergence.png"
        in the current working directory (default: False)

    Returns
    -------
    dict
        Dictionary containing:
        - mean: Mean value of data
        - block_err: Standard error at convergence (or the last one computed)
        - std: Sample standard deviation of data (ddof=1)
        - converged_blocks: Number of blocks at convergence (or the largest tried)

    Raises
    ------
    ValueError
        If fewer than 6 data points are given; no standard error can be
        computed for such short series.

    Warns
    -----
    UserWarning
        If block averaging does not converge with the given tolerance
    """
    data = np.asarray(data)
    N = len(data)
    if N < 6:
        # Block counts tried are 2 .. N//2 - 1, so N < 6 yields no estimate at all
        # (previously this surfaced as an IndexError).
        raise ValueError(f"block_analysis requires at least 6 data points, got {N}.")

    mean_value = np.mean(data)
    std_value = np.std(data, ddof=1)

    # Candidate numbers of blocks M (M = 1 is skipped below: one block has no spread).
    block_sizes = np.arange(1, N // 2)
    standard_errors = []

    for M in block_sizes:
        block_length = N // M

        truncated_data = data[:block_length * M]
        blocks = truncated_data.reshape(M, block_length)
        block_means = np.mean(blocks, axis=1)

        if len(block_means) > 1:
            std_error = np.std(block_means, ddof=1) / np.sqrt(M)
        else:
            continue

        standard_errors.append(std_error)

        if len(standard_errors) > 5:
            recent_errors = standard_errors[-5:]
            if np.max(recent_errors) - np.min(recent_errors) < convergence_tol:
                converged_blocks = M
                final_error = std_error
                break
    else:
        converged_blocks = block_sizes[-1]
        final_error = standard_errors[-1]
        warnings.warn("Block averaging did not fully converge. Consider increasing data length or lowering tolerance.")

    if plot_blocks:
        plt.figure(figsize=(8, 5))
        plt.plot(block_sizes[:len(standard_errors)], standard_errors, color="blue", label='Standard Error')
        plt.xlabel('Number of Blocks')
        plt.ylabel('Standard Error')
        plt.title('Block Averaging Convergence')
        plt.savefig("block_averaging_convergence.png", dpi=300, bbox_inches='tight')

    return {
        "mean": mean_value,
        "block_err": final_error,
        "std": std_value,
        "converged_blocks": converged_blocks
    }
209
+
210
+
211
def error_analysis(data, max_lag=None, threshold=0.05, convergence_tol=0.001, plot=False):
    """
    Run both autocorrelation and block-averaging error estimates on ``data``.

    Parameters
    ----------
    data : array_like
        Input data array; converted with ``np.asarray`` before analysis
    max_lag : int, optional
        Maximum lag for the autocorrelation analysis; when None,
        autocorrelation_analysis chooses its own default
    threshold : float, optional
        Correlation threshold for autocorrelation analysis (default: 0.05)
    convergence_tol : float, optional
        Convergence tolerance for block averaging (default: 0.001)
    plot : bool, optional
        Forwarded as ``plot_acf`` and ``plot_blocks`` to generate diagnostic
        plots (default: False)

    Returns
    -------
    dict
        - mean: Mean value of data (taken from the autocorrelation results)
        - std: Standard deviation of data (taken from the autocorrelation results)
        - acf_results: Full results from autocorrelation analysis
        - block_results: Full results from block averaging analysis
    """
    values = np.asarray(data)

    acf_part = autocorrelation_analysis(values, max_lag, threshold, plot_acf=plot)
    block_part = block_analysis(values, convergence_tol, plot_blocks=plot)

    return dict(
        mean=acf_part["mean"],
        std=acf_part["std"],
        acf_results=acf_part,
        block_results=block_part,
    )
253
+
254
+