crisp-ase 1.0.0.post0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of crisp-ase might be problematic. Click here for more details.

@@ -0,0 +1,278 @@
1
+ """
2
+ CRISP/simulation_utility/atomic_traj_linemap.py
3
+
4
+ This module provides functionality for visualizing atomic trajectories from molecular dynamics
5
+ simulations using interactive 3D plots.
6
+ """
7
+
8
+ import os
9
+ import numpy as np
10
+ from typing import List, Optional, Dict, Union
11
+ from ase.io import read
12
+ import plotly.graph_objects as go
13
+ import plotly.io as pio
14
+ pio.renderers.default = "notebook"
15
+
16
# Van der Waals radii in Ångström, keyed by element symbol (Z = 1 .. 118).
VDW_RADII = {
    # Period 1
    "H": 1.20, "He": 1.40,
    # Period 2
    "Li": 1.82, "Be": 1.53, "B": 1.92, "C": 1.70,
    "N": 1.55, "O": 1.52, "F": 1.47, "Ne": 1.54,
    # Period 3
    "Na": 2.27, "Mg": 1.73, "Al": 1.84, "Si": 2.10,
    "P": 1.80, "S": 1.80, "Cl": 1.75, "Ar": 1.88,
    # Period 4
    "K": 2.75, "Ca": 2.31, "Sc": 2.30, "Ti": 2.15, "V": 2.05, "Cr": 2.05,
    "Mn": 2.05, "Fe": 2.05, "Co": 2.00, "Ni": 2.00, "Cu": 2.00, "Zn": 2.10,
    "Ga": 1.87, "Ge": 2.11, "As": 1.85, "Se": 1.90, "Br": 1.85, "Kr": 2.02,
    # Period 5
    "Rb": 3.03, "Sr": 2.49, "Y": 2.40, "Zr": 2.30, "Nb": 2.15, "Mo": 2.10,
    "Tc": 2.05, "Ru": 2.05, "Rh": 2.00, "Pd": 2.05, "Ag": 2.10, "Cd": 2.20,
    "In": 2.20, "Sn": 2.17, "Sb": 2.06, "Te": 2.06, "I": 1.98, "Xe": 2.16,
    # Period 6
    "Cs": 3.43, "Ba": 2.68, "La": 2.50, "Ce": 2.48, "Pr": 2.47, "Nd": 2.45,
    "Pm": 2.43, "Sm": 2.42, "Eu": 2.40, "Gd": 2.38, "Tb": 2.37, "Dy": 2.35,
    "Ho": 2.33, "Er": 2.32, "Tm": 2.30, "Yb": 2.28, "Lu": 2.27, "Hf": 2.25,
    "Ta": 2.20, "W": 2.10, "Re": 2.05, "Os": 2.00, "Ir": 2.00, "Pt": 2.05,
    "Au": 2.10, "Hg": 2.05, "Tl": 2.20, "Pb": 2.30, "Bi": 2.30, "Po": 2.00,
    "At": 2.00, "Rn": 2.00,
    # Period 7
    "Fr": 3.50, "Ra": 2.80, "Ac": 2.60, "Th": 2.40,
    # Every remaining period-7 element shares the same 2.30 value.
    **dict.fromkeys(
        (
            "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm",
            "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds",
            "Rg", "Cn", "Nh", "Fl", "Mc", "Lv", "Ts", "Og",
        ),
        2.30,
    ),
}
45
+
46
+
47
# Named display colour for each element symbol that has an explicit entry.
ELEMENT_COLORS = {
    # Common elements
    "H": "white", "C": "black", "N": "blue", "O": "red", "F": "green",
    "Na": "purple", "Mg": "pink", "Al": "gray", "Si": "yellow",
    "P": "orange", "S": "yellow", "Cl": "green", "K": "purple",
    "Ca": "gray", "Fe": "orange", "Cu": "orange", "Zn": "gray",
    # Halogens, alkali and boron additions
    "Br": "brown", "I": "purple", "Li": "purple", "B": "olive",
    # Noble gases
    "He": "cyan", "Ne": "cyan", "Ar": "cyan", "Kr": "cyan", "Xe": "cyan",
    # Transition and heavy metals
    "Mn": "gray", "Co": "blue", "Ni": "green", "Pd": "gray", "Pt": "gray",
    "Au": "gold", "Hg": "silver", "Pb": "darkgray", "Ag": "silver",
    "Ti": "gray", "V": "gray", "Cr": "gray", "Zr": "gray", "Mo": "gray",
    "W": "gray", "U": "green",
}
61
+
62
def plot_atomic_trajectory(
    traj_path: str,
    indices_path: Union[str, List[int]],
    output_dir: str,
    output_filename: str = "trajectory_plot.html",
    frame_skip: int = 100,
    plot_title: Optional[str] = None,
    show_plot: bool = False,
    atom_size_scale: float = 1.0,
    plot_lines: bool = False
) -> "go.Figure":
    """
    Create a 3D visualization of atom trajectories with all atom types displayed.

    Parameters
    ----------
    traj_path : str
        Path to the trajectory file (any format readable by ``ase.io.read``)
    indices_path : str or List[int]
        Either a path to a numpy file containing the atom indices to trace,
        or the indices themselves. Scalars and nested arrays are flattened and
        repeated indices are traced only once.
    output_dir : str
        Directory where the output visualization will be saved (created if missing)
    output_filename : str, optional
        Filename for the output visualization (default: "trajectory_plot.html")
    frame_skip : int, optional
        Use every nth frame from the trajectory (default: 100)
    plot_title : str, optional
        Custom title for the plot (default: auto-generated from the atom types)
    show_plot : bool, optional
        Whether to display the plot interactively (default: False)
    atom_size_scale : float, optional
        Scale factor for atom sizes in the visualization (default: 1.0)
    plot_lines : bool, optional
        Whether to connect trajectory points with lines (default: False)

    Returns
    -------
    plotly.graph_objects.Figure
        The generated plotly figure object that can be further customized

    Raises
    ------
    ValueError
        If ``frame_skip`` is smaller than 1 or the trajectory contains no frames

    Notes
    -----
    The figure shows:
    1. All atoms from the first loaded frame, one trace per element
    2. The positions of every selected atom over all loaded frames
    3. Annotations for the start and end positions of traced atoms

    The axes are clamped to ``[0, cell length]`` only when the first frame
    carries a cell with three non-zero lengths; otherwise plotly auto-scales.
    The output is written as an HTML file.
    """
    # A step of zero (or a negative one) is not a meaningful "every nth frame".
    if frame_skip < 1:
        raise ValueError(f"frame_skip must be a positive integer, got {frame_skip}")

    print(f"Loading trajectory from {traj_path} (using every {frame_skip}th frame)...")
    traj = read(traj_path, index=f'::{frame_skip}')

    # Convert to list if not already (happens with single frame)
    if not isinstance(traj, list):
        traj = [traj]
    if not traj:
        raise ValueError(f"No frames could be read from {traj_path}")

    print(f"Loaded {len(traj)} frames from trajectory")

    from_file = isinstance(indices_path, str)
    raw_indices = np.load(indices_path) if from_file else np.array(indices_path)
    # Flatten scalars / nested arrays and drop repeats (keeping first-seen order)
    # so that every atom gets exactly one trace and one pair of annotations.
    selected_indices = list(dict.fromkeys(np.atleast_1d(raw_indices).ravel().tolist()))
    if from_file:
        print(f"Loaded {len(selected_indices)} atoms for trajectory plotting from {indices_path}")
    else:
        print(f"Using {len(selected_indices)} directly provided atom indices for trajectory plotting")

    first_frame = traj[0]
    box = first_frame.cell.lengths()
    print(f"Simulation box dimensions: {box} Å")

    n_atoms = len(first_frame)
    print(f"Analyzing atom types in first frame (total atoms: {n_atoms}, max index: {n_atoms - 1})...")

    # Group the atom indices of the first frame by chemical symbol.
    atom_types: Dict[str, List[int]] = {}
    for atom_index, symbol in enumerate(first_frame.get_chemical_symbols()):
        atom_types.setdefault(symbol, []).append(atom_index)

    print(f"Found {len(atom_types)} atom types: {', '.join(atom_types.keys())}")

    fig = go.Figure()

    # Up to five traced atoms get distinct colours; beyond that everything is blue.
    use_same_color = len(selected_indices) > 5
    colors = ['blue'] * len(selected_indices) if use_same_color else [
        'blue', 'green', 'red', 'orange', 'purple'
    ][:len(selected_indices)]

    # Static background: every atom of the first frame, one trace per element.
    for symbol, indices in atom_types.items():
        positions = first_frame.positions[indices]

        fig.add_trace(go.Scatter3d(
            x=positions[:, 0],
            y=positions[:, 1],
            z=positions[:, 2],
            mode='markers',
            name=f'{symbol} Atoms',
            marker=dict(
                size=VDW_RADII.get(symbol, 1.0) * 3.0 * atom_size_scale,
                color=ELEMENT_COLORS.get(symbol, 'gray'),
                symbol='circle',
                opacity=0.7,
                line=dict(color='black', width=0.5)
            )
        ))

    annotations = []
    for i, idx in enumerate(selected_indices):
        # Collect this atom's position in every frame that actually contains it.
        track = [
            atoms.positions[idx]
            for atoms in traj
            if -len(atoms) <= idx < len(atoms)
        ]
        missing = len(traj) - len(track)
        if missing:
            # Reported once per atom rather than once per frame.
            print(f"Warning: Index {idx} is out of range in {missing} of {len(traj)} frames")
        if not track:
            continue

        pos = np.array(track)
        color = colors[i % len(colors)]

        # Annotate the first and last recorded positions of the atom.
        for label, point, arrow_dx in (('Start', pos[0], 20), ('End', pos[-1], -20)):
            annotations.append(dict(
                x=point[0],
                y=point[1],
                z=point[2],
                text=f'{label} {idx}',
                showarrow=True,
                arrowhead=2,
                ax=arrow_dx,
                ay=-20,
                arrowcolor=color,
                font=dict(color=color, size=10)
            ))

        if plot_lines:
            # Connected path with small markers at every frame.
            trace_style = dict(
                mode='lines+markers',
                line=dict(width=3, color=color),
                marker=dict(size=4, color=color),
            )
        else:
            # Scatter-only mode: one outlined point per frame.
            trace_style = dict(
                mode='markers',
                marker=dict(
                    size=5,
                    color=color,
                    symbol='circle',
                    opacity=0.8,
                    line=dict(color='black', width=0.5)
                ),
            )

        fig.add_trace(go.Scatter3d(
            x=pos[:, 0],
            y=pos[:, 1],
            z=pos[:, 2],
            name=f'Atom {idx}',
            **trace_style
        ))

    if not plot_title:
        atom_types_str = ', '.join(atom_types.keys())
        plot_title = f'Atomic Trajectories in {atom_types_str} System'

    scene = dict(
        xaxis_title='X (Å)',
        yaxis_title='Y (Å)',
        zaxis_title='Z (Å)',
        aspectmode='cube'
    )
    if np.all(box > 0):
        scene.update(
            xaxis=dict(range=[0, box[0]]),
            yaxis=dict(range=[0, box[1]]),
            zaxis=dict(range=[0, box[2]]),
        )
    else:
        # A missing cell reports zero lengths; a [0, 0] range would hide every
        # atom, so leave the axis limits to plotly in that case.
        print("Warning: first frame has no valid cell; axis ranges are auto-scaled")

    fig.update_layout(
        title=plot_title,
        scene=scene,
        margin=dict(l=0, r=0, b=0, t=40),
        scene_annotations=annotations
    )

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, output_filename)

    fig.write_html(output_path)
    print(f"Plot has been saved to {output_path}")

    if show_plot:
        fig.show()

    return fig
@@ -0,0 +1,253 @@
1
+ """
2
+ CRISP/simulation_utility/error_analysis.py
3
+
4
+ This module provides statistical error analysis tools for molecular dynamics simulations,
5
+ including autocorrelation and block averaging methods to estimate statistical errors.
6
+ """
7
+
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ from statsmodels.tsa.stattools import acf
11
+ import warnings
12
+
13
+
14
def optimal_lag(acf_values, threshold=0.05):
    """
    Find the first lag at which the autocorrelation magnitude drops below threshold.

    Parameters
    ----------
    acf_values : array-like
        Autocorrelation function values, indexed by lag
    threshold : float, optional
        Correlation threshold below which data is considered uncorrelated (default: 0.05)

    Returns
    -------
    int
        Smallest lag whose absolute autocorrelation is below ``threshold``.
        If no lag qualifies, the last available lag (0 for empty input).

    Warns
    -----
    UserWarning
        If the autocorrelation never drops below ``threshold`` within the
        supplied values
    """
    magnitudes = np.abs(np.asarray(acf_values, dtype=float))
    # NaN entries compare False here, so they never count as "converged".
    below = np.flatnonzero(magnitudes < threshold)
    if below.size:
        return int(below[0])

    # Clamp at zero so empty input cannot yield the invalid lag -1.
    last_lag = max(magnitudes.size - 1, 0)
    acf_not_converged = (
        "Autocorrelation function is not converged. "
        f"Consider increasing the 'max_lag' parameter (current: {last_lag}) "
        "or extending the simulation length."
    )
    # stacklevel=2 attributes the warning to the caller instead of this helper.
    warnings.warn(acf_not_converged, UserWarning, stacklevel=2)

    return last_lag
47
+
48
+
49
def vector_acf(data, max_lag):
    """
    Calculate the normalised autocorrelation function of vector-valued data.

    For each lag the dot product between the mean-centred vectors at frame
    ``t`` and frame ``t + lag`` is averaged over all available ``t`` and
    divided by the mean squared norm of the centred data.

    Parameters
    ----------
    data : array-like
        Input data with shape (n_frames, n_dimensions). A 1-D sequence is
        treated as (n_frames, 1).
    max_lag : int
        Maximum lag to calculate the autocorrelation for

    Returns
    -------
    numpy.ndarray
        Array of length ``max_lag + 1`` with the autocorrelation for lags
        0 .. max_lag. Lags that cannot be evaluated (lag >= n_frames, or data
        without any variance) are reported as NaN.
    """
    data = np.asarray(data, dtype=float)
    if data.ndim == 1:
        # Promote scalar series to single-component vectors.
        data = data[:, np.newaxis]

    n_frames = data.shape[0]
    data_centered = data - np.mean(data, axis=0)
    norm0 = np.mean(np.sum(data_centered**2, axis=1))

    acf_vals = np.full(max_lag + 1, np.nan)
    if norm0 == 0:
        # Zero variance: the normalised ACF is 0/0, i.e. undefined at every lag.
        return acf_vals

    # Only lags strictly shorter than the series have overlapping frame pairs.
    last_valid_lag = min(max_lag, n_frames - 1)
    for tau in range(last_valid_lag + 1):
        dots = np.sum(data_centered[:n_frames - tau] * data_centered[tau:], axis=1)
        acf_vals[tau] = np.mean(dots) / norm0

    return acf_vals
75
+
76
+
77
def autocorrelation_analysis(data, max_lag=None, threshold=0.05, plot_acf=False):
    """
    Perform autocorrelation analysis to estimate statistical errors.

    Parameters
    ----------
    data : array-like
        Input data. A 1-D series is analysed with ``statsmodels`` ``acf``;
        anything else is treated as (n_frames, ...) vector data and analysed
        with ``vector_acf``.
    max_lag : int, optional
        Maximum lag to calculate the autocorrelation for
        (default: min(1000, N // 10) for 1-D data, min(1000, N // 20) otherwise)
    threshold : float, optional
        Correlation threshold below which data is considered uncorrelated (default: 0.05)
    plot_acf : bool, optional
        Whether to save an autocorrelation plot to "ACF_lag_analysis.png"
        in the current working directory (default: False)

    Returns
    -------
    dict
        Dictionary containing:
        - mean: Mean value of data (per component for vector data)
        - acf_err: Error estimate, std * sqrt(2 * tau_int / N)
        - std: Sample standard deviation of data (ddof=1)
        - tau_int: Integrated autocorrelation time, 0.5 + sum of the ACF over
          lags 1 .. optimal_lag
        - optimal_lag: First lag where the autocorrelation drops below threshold
    """
    # Accept plain sequences as well as arrays; `.ndim` needs an ndarray.
    data = np.asarray(data)
    N = data.shape[0]

    if data.ndim == 1:
        mean_value = np.mean(data)
        std_value = np.std(data, ddof=1)
        if max_lag is None:
            max_lag = min(1000, N // 10)
        acf_values = acf(data - mean_value, nlags=max_lag, fft=True)
    else:
        mean_value = np.mean(data, axis=0)
        std_value = np.std(data, axis=0, ddof=1)
        if max_lag is None:
            # The original `min(N // 20)` raised TypeError (min() of a single
            # int); cap at 1000 like the 1-D branch does.
            max_lag = min(1000, N // 20)
        acf_values = vector_acf(data, max_lag)

    opt_lag = optimal_lag(acf_values, threshold)

    tau_int = 0.5 + np.sum(acf_values[1:opt_lag + 1])
    autocorr_error = std_value * np.sqrt(2 * tau_int / N)

    if plot_acf:
        plt.figure(figsize=(8, 5))
        plt.plot(np.arange(len(acf_values)), acf_values, linestyle='-', linewidth=2, color="blue", label='ACF')
        plt.axhline(y=threshold, color='red', linestyle='--', linewidth=1.5, label='Threshold')
        plt.xlabel('Lag')
        plt.ylabel('Autocorrelation')
        plt.title('Autocorrelation Function (ACF)')
        plt.legend()
        plt.savefig("ACF_lag_analysis.png", dpi=300, bbox_inches='tight')

    return {"mean": mean_value, "acf_err": autocorr_error, "std": std_value, "tau_int": tau_int, "optimal_lag": opt_lag}
133
+
134
+
135
def block_analysis(data, convergence_tol=0.001, plot_blocks=False):
    """
    Perform block averaging analysis to estimate statistical errors.

    The series is split into M equally long blocks for M = 2, 3, ...,
    N // 2 - 1 (trailing samples that do not fill a block are dropped). For
    each M the standard error of the block means is computed. The scan stops
    as soon as the last five estimates differ by less than ``convergence_tol``.

    Parameters
    ----------
    data : array-like
        One-dimensional input data
    convergence_tol : float, optional
        Maximum spread allowed among the five most recent standard errors
        for the estimate to count as converged (default: 0.001)
    plot_blocks : bool, optional
        Whether to save a convergence plot to "block_averaging_convergence.png"
        in the current working directory (default: False)

    Returns
    -------
    dict
        Dictionary containing:
        - mean: Mean value of data
        - block_err: Error estimate from block averaging
        - std: Sample standard deviation of data (ddof=1)
        - converged_blocks: Number of blocks at which the scan stopped

    Raises
    ------
    ValueError
        If fewer than two samples are supplied

    Warns
    -----
    UserWarning
        If block averaging does not converge with the given tolerance. When
        the series is too short for any block estimate, the uncorrelated
        standard error std / sqrt(N) is returned instead.
    """
    # Accept plain sequences; slicing + reshape below need an ndarray.
    data = np.asarray(data)
    N = len(data)
    if N < 2:
        raise ValueError(f"block_analysis needs at least two samples, got {N}")

    mean_value = np.mean(data)
    std_value = np.std(data, ddof=1)

    # Track the block count of every estimate so the plot's x values line up
    # with the y values (a single block has no spread and is never evaluated).
    block_counts = []
    standard_errors = []
    converged_blocks = None
    final_error = None

    for M in range(2, N // 2):
        block_length = N // M
        blocks = data[:block_length * M].reshape(M, block_length)
        std_error = np.std(np.mean(blocks, axis=1), ddof=1) / np.sqrt(M)

        block_counts.append(M)
        standard_errors.append(std_error)

        if len(standard_errors) > 5:
            recent_errors = standard_errors[-5:]
            if np.max(recent_errors) - np.min(recent_errors) < convergence_tol:
                converged_blocks = M
                final_error = std_error
                break

    if converged_blocks is None:
        if standard_errors:
            converged_blocks = block_counts[-1]
            final_error = standard_errors[-1]
        else:
            # Too few samples for even one block estimate: fall back to the
            # estimate for N blocks of length one instead of raising IndexError.
            converged_blocks = N
            final_error = std_value / np.sqrt(N)
        warnings.warn(
            "Block averaging did not fully converge. Consider increasing data length or lowering tolerance.",
            UserWarning,
            stacklevel=2,
        )

    if plot_blocks:
        plt.figure(figsize=(8, 5))
        plt.plot(block_counts, standard_errors, color="blue", label='Standard Error')
        plt.xlabel('Number of Blocks')
        plt.ylabel('Standard Error')
        plt.title('Block Averaging Convergence')
        plt.savefig("block_averaging_convergence.png", dpi=300, bbox_inches='tight')

    return {
        "mean": mean_value,
        "block_err": final_error,
        "std": std_value,
        "converged_blocks": int(converged_blocks)
    }
208
+
209
+
210
def error_analysis(data, max_lag=None, threshold=0.05, convergence_tol=0.001, plot=False):
    """
    Run both the autocorrelation and the block-averaging error estimate on one data set.

    Parameters
    ----------
    data : array-like
        Input data; converted with ``numpy.asarray`` before analysis
    max_lag : int, optional
        Maximum lag, forwarded to ``autocorrelation_analysis`` (None lets that
        function pick its own default)
    threshold : float, optional
        Correlation threshold, forwarded to ``autocorrelation_analysis`` (default: 0.05)
    convergence_tol : float, optional
        Convergence tolerance, forwarded to ``block_analysis`` (default: 0.001)
    plot : bool, optional
        Passed to both analyses as their plotting switch (default: False)

    Returns
    -------
    dict
        Dictionary containing:
        - mean: "mean" entry of the autocorrelation result
        - std: "std" entry of the autocorrelation result
        - acf_results: Full dictionary returned by ``autocorrelation_analysis``
        - block_results: Full dictionary returned by ``block_analysis``
    """
    samples = np.asarray(data)

    # Autocorrelation first, then block averaging, on the same array.
    autocorr = autocorrelation_analysis(samples, max_lag, threshold, plot_acf=plot)
    blocking = block_analysis(samples, convergence_tol, plot_blocks=plot)

    return {
        "mean": autocorr["mean"],
        "std": autocorr["std"],
        "acf_results": autocorr,
        "block_results": blocking,
    }
252
+
253
+