crisp-ase 1.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- CRISP/__init__.py +99 -0
- CRISP/_version.py +1 -0
- CRISP/cli.py +41 -0
- CRISP/data_analysis/__init__.py +38 -0
- CRISP/data_analysis/clustering.py +838 -0
- CRISP/data_analysis/contact_coordination.py +915 -0
- CRISP/data_analysis/h_bond.py +772 -0
- CRISP/data_analysis/msd.py +1199 -0
- CRISP/data_analysis/prdf.py +404 -0
- CRISP/data_analysis/volumetric_atomic_density.py +527 -0
- CRISP/py.typed +1 -0
- CRISP/simulation_utility/__init__.py +31 -0
- CRISP/simulation_utility/atomic_indices.py +155 -0
- CRISP/simulation_utility/atomic_traj_linemap.py +278 -0
- CRISP/simulation_utility/error_analysis.py +254 -0
- CRISP/simulation_utility/interatomic_distances.py +200 -0
- CRISP/simulation_utility/subsampling.py +241 -0
- CRISP/tests/DataAnalysis/__init__.py +1 -0
- CRISP/tests/DataAnalysis/test_clustering_extended.py +212 -0
- CRISP/tests/DataAnalysis/test_contact_coordination.py +184 -0
- CRISP/tests/DataAnalysis/test_contact_coordination_extended.py +465 -0
- CRISP/tests/DataAnalysis/test_h_bond_complete.py +326 -0
- CRISP/tests/DataAnalysis/test_h_bond_extended.py +322 -0
- CRISP/tests/DataAnalysis/test_msd_complete.py +305 -0
- CRISP/tests/DataAnalysis/test_msd_extended.py +522 -0
- CRISP/tests/DataAnalysis/test_prdf.py +206 -0
- CRISP/tests/DataAnalysis/test_volumetric_atomic_density.py +463 -0
- CRISP/tests/SimulationUtility/__init__.py +1 -0
- CRISP/tests/SimulationUtility/test_atomic_traj_linemap.py +101 -0
- CRISP/tests/SimulationUtility/test_atomic_traj_linemap_extended.py +469 -0
- CRISP/tests/SimulationUtility/test_error_analysis_extended.py +151 -0
- CRISP/tests/SimulationUtility/test_interatomic_distances.py +223 -0
- CRISP/tests/SimulationUtility/test_subsampling.py +365 -0
- CRISP/tests/__init__.py +1 -0
- CRISP/tests/test_CRISP.py +28 -0
- CRISP/tests/test_cli.py +87 -0
- CRISP/tests/test_crisp_comprehensive.py +679 -0
- crisp_ase-1.1.2.dist-info/METADATA +141 -0
- crisp_ase-1.1.2.dist-info/RECORD +42 -0
- crisp_ase-1.1.2.dist-info/WHEEL +5 -0
- crisp_ase-1.1.2.dist-info/entry_points.txt +2 -0
- crisp_ase-1.1.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,155 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CRISP/simulation_utility/atomic_indices.py
|
|
3
|
+
|
|
4
|
+
This module extracts atomic indices from trajectory files
|
|
5
|
+
and identifies atom pairs within specified cutoff distances.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import numpy as np
|
|
10
|
+
import ase.io
|
|
11
|
+
import csv
|
|
12
|
+
from typing import Dict, Tuple, List, Optional
|
|
13
|
+
|
|
14
|
+
__all__ = ['atom_indices', 'run_atom_indices']
|
|
15
|
+
|
|
16
|
+
def atom_indices(
    traj_path: str,
    frame_index: int = 0,
    custom_cutoffs: Optional[Dict[Tuple[str, str], float]] = None
) -> Tuple[Dict[str, List[int]], np.ndarray, Dict[Tuple[str, str], List[Tuple[int, int, float]]]]:
    """Extract atom indices by chemical symbol and find atom pairs within specified cutoffs.

    Parameters
    ----------
    traj_path : str
        Path to the trajectory file in any format supported by ASE
    frame_index : int, optional
        Index of the frame to analyze (default: 0)
    custom_cutoffs : dict, optional
        Dictionary with atom symbol pairs as keys and cutoff distances as values
        Example: {('Si', 'O'): 2.0, ('Al', 'O'): 2.1}

    Returns
    -------
    indices_by_symbol : dict
        Dictionary with chemical symbols as keys and lists of atomic indices as values.
        Keys appear in order of first occurrence in the frame.
    dist_matrix : numpy.ndarray
        Distance matrix between all atoms, accounting for periodic boundary conditions
        (computed with ``get_all_distances(mic=True)``)
    cutoff_indices : dict
        Dictionary with atom symbol pairs as keys and lists of (idx1, idx2, distance) tuples
        for atoms whose distance is strictly below the specified cutoff. A pair whose
        symbols are not both present in the frame maps to an empty list.

    Raises
    ------
    ValueError
        If anything fails while reading or processing the structure; the original
        exception is attached as ``__cause__``.

    Notes
    -----
    For a same-element pair such as ``('O', 'O')`` every atom is compared with every
    atom of that element, so each atom is reported with itself (distance 0) and each
    distinct pair is reported in both orders.
    """
    try:
        frame = ase.io.read(traj_path, index=frame_index)
        dist_matrix = frame.get_all_distances(mic=True)

        # Group indices by element in order of first appearance, so the result
        # does not depend on set/hash iteration order.
        indices_by_symbol: Dict[str, List[int]] = {}
        for idx, symbol in enumerate(frame.get_chemical_symbols()):
            indices_by_symbol.setdefault(symbol, []).append(idx)

        cutoff_indices = {}
        for pair, cutoff in (custom_cutoffs or {}).items():
            symbol1, symbol2 = pair
            pair_indices_distances = []
            if symbol1 in indices_by_symbol and symbol2 in indices_by_symbol:
                rows = indices_by_symbol[symbol1]
                cols = indices_by_symbol[symbol2]
                # Sub-matrix of distances between the two element groups;
                # np.nonzero walks it row-major, i.e. in the same order as a
                # nested "for idx1 ... for idx2 ..." loop.
                sub_distances = dist_matrix[np.ix_(rows, cols)]
                for r, c in zip(*np.nonzero(sub_distances < cutoff)):
                    pair_indices_distances.append(
                        (rows[r], cols[c], sub_distances[r, c])
                    )
            cutoff_indices[pair] = pair_indices_distances

        return indices_by_symbol, dist_matrix, cutoff_indices

    except Exception as e:
        # Keep the original traceback attached so the root cause stays visible.
        raise ValueError(
            f"Error processing atomic structure: {e}. Check if the format is supported by ASE."
        ) from e
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def run_atom_indices(
    traj_path: str,
    output_dir: str,
    frame_index: int = 0,
    custom_cutoffs: Optional[Dict[Tuple[str, str], float]] = None
) -> None:
    """Run atom index extraction and save results to files.

    Parameters
    ----------
    traj_path : str
        Path to the trajectory file in any format supported by ASE
    output_dir : str
        Directory where output files will be saved (created if missing)
    frame_index : int, optional
        Index of the frame to analyze (default: 0). Must satisfy
        ``0 <= frame_index < number of frames``; negative indices are rejected.
    custom_cutoffs : dict, optional
        Dictionary with atom symbol pairs as keys and cutoff distances as values

    Returns
    -------
    None
        Results are saved to the specified output directory:
        - lengths.npy: Dictionary of number of atoms per element (a Python dict
          passed to ``np.save``, so loading it back needs ``allow_pickle=True``)
        - {symbol}_indices.npy: Numpy array of atom indices for each element
        - cutoff/{symbol1}-{symbol2}_cutoff.csv: CSV files with atom pairs within cutoff

    Raises
    ------
    ValueError
        If ``frame_index`` is out of range, or if the trajectory cannot be read
        (the underlying exception is attached as ``__cause__``).
    """
    try:
        try:
            traj = ase.io.read(traj_path, index=":")
            traj_length = len(traj) if isinstance(traj, list) else 1
        except TypeError:
            # Fall back to a plain read when the ":" slice is rejected; the result
            # is only used to confirm the file is readable as a single frame.
            ase.io.read(traj_path)
            traj_length = 1

        # Check if frame_index is within valid range
        if frame_index < 0 or frame_index >= traj_length:
            raise ValueError(
                f"Error: Frame index {frame_index} is out of range. "
                f"Valid range is 0 to {traj_length-1}."
            )

        print(f"Analyzing frame with index {frame_index} (out of {traj_length} frames)")

    except ValueError:
        # Propagate range errors (and reader ValueErrors) untouched.
        raise

    except Exception as e:
        raise ValueError(
            f"Error reading trajectory: {e}. Check if the format is supported by ASE."
        ) from e

    # The distance matrix is not written out, only the indices and cutoff pairs.
    indices, _, cutoff_indices = atom_indices(traj_path, frame_index, custom_cutoffs)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)

    lengths = {symbol: len(symbol_indices) for symbol, symbol_indices in indices.items()}
    np.save(os.path.join(output_dir, "lengths.npy"), lengths)

    for symbol, data in indices.items():
        np.save(os.path.join(output_dir, f"{symbol}_indices.npy"), data)
        print(f"Length of {symbol} indices: {len(data)}")

    print("Outputs saved.")

    cutoff_folder = os.path.join(output_dir, "cutoff")
    os.makedirs(cutoff_folder, exist_ok=True)

    for pair, pair_indices_distances in cutoff_indices.items():
        symbol1, symbol2 = pair
        filename = f"{symbol1}-{symbol2}_cutoff.csv"
        filepath = os.path.join(cutoff_folder, filename)
        with open(filepath, mode="w", newline="") as file:
            writer = csv.writer(file)
            writer.writerow([f"{symbol1} index", f"{symbol2} index", "distance"])
            writer.writerows(pair_indices_distances)
        print(f"Saved cutoff indices for {symbol1}-{symbol2} to {filepath}")
|
|
@@ -0,0 +1,278 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CRISP/simulation_utility/atomic_traj_linemap.py
|
|
3
|
+
|
|
4
|
+
This module provides functionality for visualizing atomic trajectories from molecular dynamics
|
|
5
|
+
simulations using interactive 3D plots.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import numpy as np
|
|
10
|
+
from typing import List, Optional, Dict, Union
|
|
11
|
+
from ase.io import read
|
|
12
|
+
import plotly.graph_objects as go
|
|
13
|
+
import plotly.io as pio
|
|
14
|
+
pio.renderers.default = "notebook"
|
|
15
|
+
__all__ = ['plot_atomic_trajectory', 'VDW_RADII', 'ELEMENT_COLORS']
|
|
16
|
+
# Dictionary of van der Waals radii for all elements in Ångström
|
|
17
|
+
VDW_RADII = {
|
|
18
|
+
# Period 1
|
|
19
|
+
'H': 1.20, 'He': 1.40,
|
|
20
|
+
# Period 2
|
|
21
|
+
'Li': 1.82, 'Be': 1.53, 'B': 1.92, 'C': 1.70, 'N': 1.55, 'O': 1.52, 'F': 1.47, 'Ne': 1.54,
|
|
22
|
+
# Period 3
|
|
23
|
+
'Na': 2.27, 'Mg': 1.73, 'Al': 1.84, 'Si': 2.10, 'P': 1.80, 'S': 1.80, 'Cl': 1.75, 'Ar': 1.88,
|
|
24
|
+
# Period 4
|
|
25
|
+
'K': 2.75, 'Ca': 2.31, 'Sc': 2.30, 'Ti': 2.15, 'V': 2.05, 'Cr': 2.05, 'Mn': 2.05,
|
|
26
|
+
'Fe': 2.05, 'Co': 2.00, 'Ni': 2.00, 'Cu': 2.00, 'Zn': 2.10, 'Ga': 1.87, 'Ge': 2.11,
|
|
27
|
+
'As': 1.85, 'Se': 1.90, 'Br': 1.85, 'Kr': 2.02,
|
|
28
|
+
# Period 5
|
|
29
|
+
'Rb': 3.03, 'Sr': 2.49, 'Y': 2.40, 'Zr': 2.30, 'Nb': 2.15, 'Mo': 2.10, 'Tc': 2.05,
|
|
30
|
+
'Ru': 2.05, 'Rh': 2.00, 'Pd': 2.05, 'Ag': 2.10, 'Cd': 2.20, 'In': 2.20, 'Sn': 2.17,
|
|
31
|
+
'Sb': 2.06, 'Te': 2.06, 'I': 1.98, 'Xe': 2.16,
|
|
32
|
+
# Period 6
|
|
33
|
+
'Cs': 3.43, 'Ba': 2.68, 'La': 2.50, 'Ce': 2.48, 'Pr': 2.47, 'Nd': 2.45, 'Pm': 2.43,
|
|
34
|
+
'Sm': 2.42, 'Eu': 2.40, 'Gd': 2.38, 'Tb': 2.37, 'Dy': 2.35, 'Ho': 2.33, 'Er': 2.32,
|
|
35
|
+
'Tm': 2.30, 'Yb': 2.28, 'Lu': 2.27, 'Hf': 2.25, 'Ta': 2.20, 'W': 2.10, 'Re': 2.05,
|
|
36
|
+
'Os': 2.00, 'Ir': 2.00, 'Pt': 2.05, 'Au': 2.10, 'Hg': 2.05, 'Tl': 2.20, 'Pb': 2.30,
|
|
37
|
+
'Bi': 2.30, 'Po': 2.00, 'At': 2.00, 'Rn': 2.00,
|
|
38
|
+
# Period 7
|
|
39
|
+
'Fr': 3.50, 'Ra': 2.80, 'Ac': 2.60, 'Th': 2.40, 'Pa': 2.30, 'U': 2.30, 'Np': 2.30,
|
|
40
|
+
'Pu': 2.30, 'Am': 2.30, 'Cm': 2.30, 'Bk': 2.30, 'Cf': 2.30, 'Es': 2.30, 'Fm': 2.30,
|
|
41
|
+
'Md': 2.30, 'No': 2.30, 'Lr': 2.30, 'Rf': 2.30, 'Db': 2.30, 'Sg': 2.30, 'Bh': 2.30,
|
|
42
|
+
'Hs': 2.30, 'Mt': 2.30, 'Ds': 2.30, 'Rg': 2.30, 'Cn': 2.30, 'Nh': 2.30, 'Fl': 2.30,
|
|
43
|
+
'Mc': 2.30, 'Lv': 2.30, 'Ts': 2.30, 'Og': 2.30
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
ELEMENT_COLORS = {
|
|
48
|
+
# Common elements
|
|
49
|
+
'H': 'white', 'C': 'black', 'N': 'blue', 'O': 'red', 'F': 'green',
|
|
50
|
+
'Na': 'purple', 'Mg': 'pink', 'Al': 'gray', 'Si': 'yellow', 'P': 'orange',
|
|
51
|
+
'S': 'yellow', 'Cl': 'green', 'K': 'purple', 'Ca': 'gray', 'Fe': 'orange',
|
|
52
|
+
'Cu': 'orange', 'Zn': 'gray',
|
|
53
|
+
# Additional common elements with colors
|
|
54
|
+
'Br': 'brown', 'I': 'purple', 'Li': 'purple', 'B': 'olive',
|
|
55
|
+
'He': 'cyan', 'Ne': 'cyan', 'Ar': 'cyan', 'Kr': 'cyan', 'Xe': 'cyan',
|
|
56
|
+
'Mn': 'gray', 'Co': 'blue', 'Ni': 'green', 'Pd': 'gray', 'Pt': 'gray',
|
|
57
|
+
'Au': 'gold', 'Hg': 'silver', 'Pb': 'darkgray', 'Ag': 'silver',
|
|
58
|
+
'Ti': 'gray', 'V': 'gray', 'Cr': 'gray', 'Zr': 'gray', 'Mo': 'gray',
|
|
59
|
+
'W': 'gray', 'U': 'green'
|
|
60
|
+
}
|
|
61
|
+
|
|
62
|
+
def _endpoint_annotation(position, text, ax, color):
    """Build one 3D scene annotation dict labelling a trajectory end point."""
    return dict(
        x=position[0],
        y=position[1],
        z=position[2],
        text=text,
        showarrow=True,
        arrowhead=2,
        ax=ax,
        ay=-20,
        arrowcolor=color,
        font=dict(color=color, size=10)
    )


def plot_atomic_trajectory(
    traj_path: str,
    indices_path: Union[str, List[int]],
    output_dir: str,
    output_filename: str = "trajectory_plot.html",
    frame_skip: int = 100,
    plot_title: Optional[str] = None,
    show_plot: bool = False,
    atom_size_scale: float = 1.0,
    plot_lines: bool = False
):
    """
    Create a 3D visualization of atom trajectories with all atom types displayed.

    Parameters
    ----------
    traj_path : str
        Path to the ASE trajectory file (supports any ASE-readable format like XYZ)
    indices_path : str or List[int]
        Either a path to numpy file containing atom indices to plot trajectories for,
        or a direct list of atom indices
    output_dir : str
        Directory where the output visualization will be saved (created if missing)
    output_filename : str, optional
        Filename for the output visualization (default: "trajectory_plot.html")
    frame_skip : int, optional
        Use every nth frame from the trajectory (default: 100)
    plot_title : str, optional
        Custom title for the plot (default: auto-generated based on atom types)
    show_plot : bool, optional
        Whether to display the plot interactively (default: False)
    atom_size_scale : float, optional
        Scale factor for atom sizes in the visualization (default: 1.0)
    plot_lines : bool, optional
        Whether to connect trajectory points with lines (default: False)

    Returns
    -------
    plotly.graph_objects.Figure
        The generated plotly figure object that can be further customized

    Notes
    -----
    This function creates an interactive 3D visualization showing:
    1. All atoms from the first loaded frame, colored by element
    2. Trajectory positions for selected atoms across the loaded frames
    3. Annotations for the start and end positions of traced atoms

    Axis ranges run from 0 to the cell lengths of the first loaded frame.
    With more than five selected atoms every trajectory is drawn in blue;
    otherwise each gets its own color. An index that is out of range is
    reported once and skipped.

    The output is saved as an HTML file which can be opened in any web browser.
    """
    print(f"Loading trajectory from {traj_path} (using every {frame_skip}th frame)...")
    traj = read(traj_path, index=f'::{frame_skip}')

    # Convert to list if not already (happens with single frame)
    if not isinstance(traj, list):
        traj = [traj]

    print(f"Loaded {len(traj)} frames from trajectory")

    if isinstance(indices_path, str):
        selected_indices = np.load(indices_path)
        print(f"Loaded {len(selected_indices)} atoms for trajectory plotting from {indices_path}")
    else:
        selected_indices = np.array(indices_path)
        print(f"Using {len(selected_indices)} directly provided atom indices for trajectory plotting")

    first_frame = traj[0]
    box = first_frame.cell.lengths()
    print(f"Simulation box dimensions: {box} Å")

    max_index = max(atom.index for atom in first_frame)
    print(f"Analyzing atom types in first frame (total atoms: {len(first_frame)}, max index: {max_index})...")

    # Group atom indices of the first frame by element symbol.
    atom_types = {}
    for atom in first_frame:
        atom_types.setdefault(atom.symbol, []).append(atom.index)

    print(f"Found {len(atom_types)} atom types: {', '.join(atom_types.keys())}")

    fig = go.Figure()

    use_same_color = len(selected_indices) > 5
    colors = ['blue'] * len(selected_indices) if use_same_color else [
        'blue', 'green', 'red', 'orange', 'purple'
    ][:len(selected_indices)]

    # One marker trace per element; every group holds at least one atom, so no
    # emptiness check is needed here.
    for symbol, indices in atom_types.items():
        positions = first_frame.positions[indices]

        size = VDW_RADII.get(symbol, 1.0) * 3.0 * atom_size_scale
        color = ELEMENT_COLORS.get(symbol, 'gray')

        fig.add_trace(go.Scatter3d(
            x=positions[:, 0],
            y=positions[:, 1],
            z=positions[:, 2],
            mode='markers',
            name=f'{symbol} Atoms',
            marker=dict(
                size=size,
                color=color,
                symbol='circle',
                opacity=0.7,
                line=dict(color='black', width=0.5)
            )
        ))

    # Collect the per-frame positions of every selected atom.
    selected_positions = {idx: [] for idx in selected_indices}
    reported_out_of_range = set()
    for atoms in traj:
        n_atoms = len(atoms)
        for idx in selected_indices:
            if idx < n_atoms:
                selected_positions[idx].append(atoms.positions[idx])
            elif idx not in reported_out_of_range:
                # Warn once per index instead of once per frame.
                reported_out_of_range.add(idx)
                print(f"Warning: Index {idx} is out of range")

    annotations = []
    for i, idx in enumerate(selected_indices):
        if not selected_positions[idx]:
            continue

        pos = np.array(selected_positions[idx])
        color = colors[i % len(colors)]

        # Add annotations for first and last frames
        annotations.append(_endpoint_annotation(pos[0], f'Start {idx}', 20, color))
        annotations.append(_endpoint_annotation(pos[-1], f'End {idx}', -20, color))

        if plot_lines:
            # Trajectory drawn as a connected line with small markers
            trace_style = dict(
                mode='lines+markers',
                line=dict(width=3, color=color),
                marker=dict(size=4, color=color),
            )
        else:
            # Scatter-only mode: just showing points for each frame
            trace_style = dict(
                mode='markers',
                marker=dict(
                    size=5,
                    color=color,
                    symbol='circle',
                    opacity=0.8,
                    line=dict(color='black', width=0.5)
                ),
            )

        fig.add_trace(go.Scatter3d(
            x=pos[:, 0],
            y=pos[:, 1],
            z=pos[:, 2],
            name=f'Atom {idx}',
            **trace_style
        ))

    if not plot_title:
        atom_types_str = ', '.join(atom_types.keys())
        plot_title = f'Atomic Trajectories in {atom_types_str} System'

    fig.update_layout(
        title=plot_title,
        scene=dict(
            xaxis_title='X (Å)',
            yaxis_title='Y (Å)',
            zaxis_title='Z (Å)',
            xaxis=dict(range=[0, box[0]]),
            yaxis=dict(range=[0, box[1]]),
            zaxis=dict(range=[0, box[2]]),
            aspectmode='cube'
        ),
        margin=dict(l=0, r=0, b=0, t=40),
        scene_annotations=annotations
    )

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, output_filename)

    fig.write_html(output_path)
    print(f"Plot has been saved to {output_path}")

    if show_plot:
        fig.show()

    return fig
|
|
@@ -0,0 +1,254 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CRISP/simulation_utility/error_analysis.py
|
|
3
|
+
|
|
4
|
+
This module provides statistical error analysis tools for molecular dynamics simulations,
|
|
5
|
+
including autocorrelation and block averaging methods to estimate statistical errors.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import matplotlib.pyplot as plt
|
|
10
|
+
from statsmodels.tsa.stattools import acf
|
|
11
|
+
import warnings
|
|
12
|
+
from typing import Dict, Optional, Any
|
|
13
|
+
|
|
14
|
+
__all__ = ['optimal_lag', 'vector_acf', 'autocorrelation_analysis', 'block_analysis']
|
|
15
|
+
|
|
16
|
+
def optimal_lag(acf_values: np.ndarray, threshold: float = 0.05) -> int:
    """Find the optimal lag time at which autocorrelation drops below threshold.

    Parameters
    ----------
    acf_values : numpy.ndarray
        Array of autocorrelation function values, indexed by lag
    threshold : float, optional
        Correlation threshold below which data is considered uncorrelated (default: 0.05)

    Returns
    -------
    int
        First lag whose absolute autocorrelation is strictly below ``threshold``.
        If no lag qualifies, the last available lag (``len(acf_values) - 1``).

    Warns
    -----
    UserWarning
        If autocorrelation function does not converge within available data
    """
    # Indices of all lags already below the threshold; the first one wins.
    below_threshold = np.flatnonzero(np.abs(np.asarray(acf_values)) < threshold)
    if below_threshold.size:
        return int(below_threshold[0])

    last_lag = len(acf_values) - 1
    acf_not_converged = (
        "Autocorrelation function is not converged. "
        f"Consider increasing the 'max_lag' parameter (current: {last_lag}) "
        "or extending the simulation length."
    )
    # stacklevel=2 attributes the warning to the caller rather than this helper.
    warnings.warn(acf_not_converged, UserWarning, stacklevel=2)

    return last_lag
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def vector_acf(data, max_lag):
    """
    Calculate autocorrelation function for vector data.

    Parameters
    ----------
    data : numpy.ndarray
        Input vector data with shape (n_frames, n_dimensions). A 1D series is
        treated as (n_frames, 1); extra trailing axes are flattened into one
        vector per frame.
    max_lag : int
        Maximum lag time to calculate autocorrelation for; must satisfy
        ``0 <= max_lag < n_frames``

    Returns
    -------
    numpy.ndarray
        Array of autocorrelation values from lag 0 to max_lag, normalized by the
        mean squared norm of the centered data (so lag 0 equals 1). If the data
        has zero variance the normalization is 0 and the values are NaN.

    Raises
    ------
    ValueError
        If ``data`` is a scalar or ``max_lag`` is outside ``[0, n_frames - 1]``.
    """
    data = np.asarray(data)
    if data.ndim == 0:
        raise ValueError("data must have at least one dimension (n_frames, ...).")

    n_frames = data.shape[0]
    # A lag equal to or beyond the series length has no overlapping samples.
    if not 0 <= max_lag < n_frames:
        raise ValueError(
            f"max_lag must be between 0 and n_frames - 1 ({n_frames - 1}); got {max_lag}."
        )

    # One flat vector per frame, whatever the original trailing shape was.
    vectors = data.reshape(n_frames, -1)
    centered = vectors - np.mean(vectors, axis=0)
    norm0 = np.mean(np.sum(centered**2, axis=1))

    acf_vals = np.zeros(max_lag + 1)
    for tau in range(max_lag + 1):
        # Dot product of each frame with the frame tau steps later.
        dots = np.sum(centered[:n_frames - tau] * centered[tau:], axis=1)
        acf_vals[tau] = np.mean(dots) / norm0

    return acf_vals
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def autocorrelation_analysis(data, max_lag=None, threshold=0.05, plot_acf=False):
    """
    Perform autocorrelation analysis to estimate statistical errors.

    Parameters
    ----------
    data : numpy.ndarray
        Input data (1D array or multi-dimensional array with frames along axis 0)
    max_lag : int, optional
        Maximum lag time to calculate autocorrelation for
        (default: min(1000, N/10) for 1D data, min(1000, N/20) otherwise)
    threshold : float, optional
        Correlation threshold below which data is considered uncorrelated (default: 0.05)
    plot_acf : bool, optional
        Whether to generate an autocorrelation plot, saved as
        "ACF_lag_analysis.png" in the current directory (default: False)

    Returns
    -------
    dict
        Dictionary containing:
        - mean: Mean value of data (per component for multi-dimensional data)
        - acf_err: Error estimate from autocorrelation analysis
        - std: Standard deviation of data (ddof=1)
        - tau_int: Integrated autocorrelation time
        - optimal_lag: Optimal lag time where autocorrelation drops below threshold
    """
    data = np.asarray(data)  # accept lists/tuples as well as arrays
    N = data.shape[0]

    if data.ndim == 1:
        mean_value = np.mean(data)
        std_value = np.std(data, ddof=1)
        if max_lag is None:
            max_lag = min(1000, N // 10)
        acf_values = acf(data - mean_value, nlags=max_lag, fft=True)
    else:
        mean_value = np.mean(data, axis=0)
        std_value = np.std(data, axis=0, ddof=1)
        if max_lag is None:
            # Was `min(N // 20)`, which raises TypeError (min() of a single int);
            # capped at 1000 like the 1D branch.
            max_lag = min(1000, N // 20)
        acf_values = vector_acf(data, max_lag)

    opt_lag = optimal_lag(acf_values, threshold)

    # Integrate the ACF up to the decorrelation lag (lag 0 contributes the 0.5).
    tau_int = 0.5 + np.sum(acf_values[1:opt_lag + 1])
    autocorr_error = std_value * np.sqrt(2 * tau_int / N)

    if plot_acf:
        plt.figure(figsize=(8, 5))
        plt.plot(np.arange(len(acf_values)), acf_values, linestyle='-', linewidth=2, color="blue", label='ACF')
        plt.axhline(y=threshold, color='red', linestyle='--', linewidth=1.5, label='Threshold')
        plt.xlabel('Lag')
        plt.ylabel('Autocorrelation')
        plt.title('Autocorrelation Function (ACF)')
        plt.legend()
        plt.savefig("ACF_lag_analysis.png", dpi=300, bbox_inches='tight')

    return {"mean": mean_value, "acf_err": autocorr_error, "std": std_value, "tau_int": tau_int, "optimal_lag": opt_lag}
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def block_analysis(data, convergence_tol=0.001, plot_blocks=False):
    """
    Perform block averaging analysis to estimate statistical errors.

    The series is split into M equal blocks for M = 2, 3, ..., N//2 - 1; for each
    M the standard error of the block means is computed. The scan stops as soon
    as the last five standard errors span less than ``convergence_tol``.

    Parameters
    ----------
    data : numpy.ndarray
        Input data array (1D; lists are converted). Needs at least 6 samples.
    convergence_tol : float, optional
        Tolerance for determining convergence of standard error (default: 0.001)
    plot_blocks : bool, optional
        Whether to generate a block averaging plot, saved as
        "block_averaging_convergence.png" in the current directory (default: False)

    Returns
    -------
    dict
        Dictionary containing:
        - mean: Mean value of data
        - block_err: Error estimate from block averaging
        - std: Standard deviation of data (ddof=1)
        - converged_blocks: Number of blocks at convergence (the largest block
          count tried when convergence was not reached)

    Raises
    ------
    ValueError
        If the series is too short to form at least two blocks for any block count.

    Warns
    -----
    UserWarning
        If block averaging does not converge with the given tolerance
    """
    data = np.asarray(data)  # lists have no .reshape
    N = len(data)
    mean_value = np.mean(data)
    std_value = np.std(data, ddof=1)

    # Block counts and their standard errors are recorded together so the plot
    # x-axis always matches the plotted values.
    block_counts = []
    standard_errors = []
    converged = False

    # A single block has no spread between block means, so start at two blocks.
    for M in range(2, N // 2):
        block_length = N // M

        # Drop the tail that does not fill a complete block.
        blocks = data[:block_length * M].reshape(M, block_length)
        block_means = np.mean(blocks, axis=1)

        std_error = np.std(block_means, ddof=1) / np.sqrt(M)
        block_counts.append(M)
        standard_errors.append(std_error)

        if len(standard_errors) > 5:
            recent_errors = standard_errors[-5:]
            if np.max(recent_errors) - np.min(recent_errors) < convergence_tol:
                converged = True
                break

    if not standard_errors:
        raise ValueError(
            f"block_analysis needs at least 6 data points to form blocks; got {N}."
        )

    # On convergence these are the values at the stopping point; otherwise the
    # values for the largest block count tried.
    converged_blocks = block_counts[-1]
    final_error = standard_errors[-1]
    if not converged:
        warnings.warn("Block averaging did not fully converge. Consider increasing data length or lowering tolerance.")

    if plot_blocks:
        plt.figure(figsize=(8, 5))
        plt.plot(block_counts, standard_errors, color="blue", label='Standard Error')
        plt.xlabel('Number of Blocks')
        plt.ylabel('Standard Error')
        plt.title('Block Averaging Convergence')
        plt.savefig("block_averaging_convergence.png", dpi=300, bbox_inches='tight')

    return {
        "mean": mean_value,
        "block_err": final_error,
        "std": std_value,
        "converged_blocks": converged_blocks
    }
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def error_analysis(data, max_lag=None, threshold=0.05, convergence_tol=0.001, plot=False):
    """
    Run autocorrelation and block-averaging error estimates on the same data.

    Parameters
    ----------
    data : array-like
        Input data; converted with ``numpy.asarray`` before analysis
    max_lag : int, optional
        Maximum lag forwarded to ``autocorrelation_analysis``
        (default: None, letting that function choose)
    threshold : float, optional
        Correlation threshold forwarded to ``autocorrelation_analysis`` (default: 0.05)
    convergence_tol : float, optional
        Convergence tolerance forwarded to ``block_analysis`` (default: 0.001)
    plot : bool, optional
        Whether both analyses should generate their diagnostic plots (default: False)

    Returns
    -------
    dict
        Dictionary containing:
        - mean: Mean value reported by the autocorrelation analysis
        - std: Standard deviation reported by the autocorrelation analysis
        - acf_results: Full result dict of ``autocorrelation_analysis``
        - block_results: Full result dict of ``block_analysis``
    """
    samples = np.asarray(data)

    # Autocorrelation first, then block averaging, on the same array.
    acf_summary = autocorrelation_analysis(samples, max_lag, threshold, plot_acf=plot)
    block_summary = block_analysis(samples, convergence_tol, plot_blocks=plot)

    return {
        "mean": acf_summary["mean"],
        "std": acf_summary["std"],
        "acf_results": acf_summary,
        "block_results": block_summary,
    }
|
|
253
|
+
|
|
254
|
+
|