crisp-ase 1.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- CRISP/__init__.py +99 -0
- CRISP/_version.py +1 -0
- CRISP/cli.py +41 -0
- CRISP/data_analysis/__init__.py +38 -0
- CRISP/data_analysis/clustering.py +838 -0
- CRISP/data_analysis/contact_coordination.py +915 -0
- CRISP/data_analysis/h_bond.py +772 -0
- CRISP/data_analysis/msd.py +1199 -0
- CRISP/data_analysis/prdf.py +404 -0
- CRISP/data_analysis/volumetric_atomic_density.py +527 -0
- CRISP/py.typed +1 -0
- CRISP/simulation_utility/__init__.py +31 -0
- CRISP/simulation_utility/atomic_indices.py +155 -0
- CRISP/simulation_utility/atomic_traj_linemap.py +278 -0
- CRISP/simulation_utility/error_analysis.py +254 -0
- CRISP/simulation_utility/interatomic_distances.py +200 -0
- CRISP/simulation_utility/subsampling.py +241 -0
- CRISP/tests/DataAnalysis/__init__.py +1 -0
- CRISP/tests/DataAnalysis/test_clustering_extended.py +212 -0
- CRISP/tests/DataAnalysis/test_contact_coordination.py +184 -0
- CRISP/tests/DataAnalysis/test_contact_coordination_extended.py +465 -0
- CRISP/tests/DataAnalysis/test_h_bond_complete.py +326 -0
- CRISP/tests/DataAnalysis/test_h_bond_extended.py +322 -0
- CRISP/tests/DataAnalysis/test_msd_complete.py +305 -0
- CRISP/tests/DataAnalysis/test_msd_extended.py +522 -0
- CRISP/tests/DataAnalysis/test_prdf.py +206 -0
- CRISP/tests/DataAnalysis/test_volumetric_atomic_density.py +463 -0
- CRISP/tests/SimulationUtility/__init__.py +1 -0
- CRISP/tests/SimulationUtility/test_atomic_traj_linemap.py +101 -0
- CRISP/tests/SimulationUtility/test_atomic_traj_linemap_extended.py +469 -0
- CRISP/tests/SimulationUtility/test_error_analysis_extended.py +151 -0
- CRISP/tests/SimulationUtility/test_interatomic_distances.py +223 -0
- CRISP/tests/SimulationUtility/test_subsampling.py +365 -0
- CRISP/tests/__init__.py +1 -0
- CRISP/tests/test_CRISP.py +28 -0
- CRISP/tests/test_cli.py +87 -0
- CRISP/tests/test_crisp_comprehensive.py +679 -0
- crisp_ase-1.1.2.dist-info/METADATA +141 -0
- crisp_ase-1.1.2.dist-info/RECORD +42 -0
- crisp_ase-1.1.2.dist-info/WHEEL +5 -0
- crisp_ase-1.1.2.dist-info/entry_points.txt +2 -0
- crisp_ase-1.1.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CRISP/simulation_utility/interatomic_distances.py
|
|
3
|
+
|
|
4
|
+
This module provides tools to calculate and analyze interatomic distances from
|
|
5
|
+
molecular dynamics trajectories.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import numpy as np
|
|
10
|
+
import pickle
|
|
11
|
+
from ase.io import read
|
|
12
|
+
from ase import Atoms
|
|
13
|
+
from typing import Union, Tuple, List, Dict, Any, Optional
|
|
14
|
+
from joblib import Parallel, delayed
|
|
15
|
+
|
|
16
|
+
__all__ = ['indices', 'distance_calculation', 'save_distance_matrices', 'calculate_interatomic_distances']
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def indices(
    atoms: Atoms,
    ind: Union[str, int, List[Union[int, str]], np.ndarray, None],
) -> np.ndarray:
    """Extract atom indices from various input types.

    Parameters
    ----------
    atoms : ase.Atoms
        Atoms object containing atomic coordinates and elements
    ind : str, int, list, tuple, numpy.ndarray, or None
        Specification for which atoms to select:
        - "all" or None: all atoms
        - string ending with ".npy": load indices from NumPy file
        - integer, or list/tuple/array of integers: direct atom indices
        - string, or list/tuple/array of strings: chemical symbols to select

    Returns
    -------
    np.ndarray
        Array of atom indices. For chemical symbols the indices are grouped
        in the order in which the symbols were given.

    Raises
    ------
    ValueError
        If the index type is not recognized, the selection is empty, or
        integers and chemical symbols are mixed in one selection
    """
    # Check for None/str explicitly: comparing a NumPy array with "all" is an
    # element-wise operation whose truth value is ambiguous.
    if ind is None or (isinstance(ind, str) and ind == "all"):
        return np.arange(len(atoms))
    if isinstance(ind, str) and ind.endswith(".npy"):
        # NOTE(security): allow_pickle=True can execute arbitrary code if the
        # file comes from an untrusted source; only load trusted index files.
        return np.load(ind, allow_pickle=True)

    # Normalise every accepted container (and bare scalars) to a plain list
    # so the type checks below treat them alike.
    if isinstance(ind, np.ndarray):
        ind = np.atleast_1d(ind).tolist()
    elif isinstance(ind, tuple):
        ind = list(ind)
    elif not isinstance(ind, list):
        ind = [ind]

    # Require a homogeneous selection: a mixed list such as [0, "O"] would
    # otherwise be coerced into an unusable string array.
    if ind and all(isinstance(item, (int, np.integer)) for item in ind):
        return np.array(ind)
    if ind and all(isinstance(item, str) for item in ind):
        # Build the symbol array once instead of once per requested symbol.
        symbols = np.array(atoms.get_chemical_symbols())
        return np.concatenate([np.where(symbols == symbol)[0] for symbol in ind])
    raise ValueError("Invalid index type")
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def distance_calculation(
    traj_path: str,
    frame_skip: int,
    index_type: Union[str, List[Union[int, str]]] = "all"
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Calculate distance matrices for multiple frames in a trajectory.

    Parameters
    ----------
    traj_path : str
        Path to the trajectory file in any format supported by ASE
    frame_skip : int
        Read every nth frame (n=frame_skip); must not be zero
    index_type : str, list, or None, optional
        Specification for which atoms to select for sub-matrix (default: "all")

    Returns
    -------
    Tuple[List[np.ndarray], List[np.ndarray]]
        Two lists containing:
        1. Full distance matrices for all frames
        2. Sub-matrices for specified atoms

    Raises
    ------
    ValueError
        If frame_skip is zero, if no frames were found in the trajectory, or
        if reading/processing the trajectory failed (the original exception
        is attached as ``__cause__``)
    """
    # A zero step can never select a frame; reject it up front with a clear
    # message instead of surfacing a slice error from the reader.
    if frame_skip == 0:
        raise ValueError("frame_skip must be a non-zero integer.")

    def process_frame(frame: Atoms) -> Tuple[np.ndarray, np.ndarray]:
        # Full minimum-image distance matrix, plus the block restricted to
        # the selected atoms (rows and columns).
        dm = frame.get_all_distances(mic=True)
        idx = indices(frame, index_type)
        sub_dm = dm[np.ix_(idx, idx)]
        return dm, sub_dm

    try:
        # Let ASE auto-detect file format based on extension
        frames = read(traj_path, index=f"::{frame_skip}")

        # Handle the case when a single frame is returned (not a list)
        if not isinstance(frames, list):
            frames = [frames]

        if not frames:
            raise ValueError("No frames were found in the trajectory using the given frame_skip.")

        results = Parallel(n_jobs=-1)(delayed(process_frame)(frame) for frame in frames)
        full_dms, sub_dms = zip(*results)
        return list(full_dms), list(sub_dms)

    except ValueError:
        # Already the documented exception type: re-raise with the original
        # traceback untouched.
        raise
    except Exception as e:
        # Callers only need to catch ValueError; chain the cause so the real
        # failure stays visible in the traceback.
        raise ValueError(
            f"Error processing trajectory: {e}. Check if the format is supported by ASE."
        ) from e
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def save_distance_matrices(
    full_dms: List[np.ndarray],
    sub_dms: List[np.ndarray],
    index_type: Union[str, List[Union[int, str]]] = "all",
    output_dir: str = "distance_calculations"
) -> None:
    """Save distance matrices to pickle file.

    The data is written to ``<output_dir>/distance_matrices.pkl`` as a
    dictionary with the key ``"full_dms"`` and, when a subset of atoms was
    selected, the additional key ``"sub_dms"``.

    Parameters
    ----------
    full_dms : List[np.ndarray]
        List of full distance matrices
    sub_dms : List[np.ndarray]
        List of sub-matrices for specified atoms
    index_type : str, list, or None, optional
        Type of index selection used (default: "all")
    output_dir : str, optional
        Directory to save output file (default: "distance_calculations");
        created if it does not exist

    Returns
    -------
    None
        Saves results to disk
    """
    data = {"full_dms": full_dms}

    # Explicit None/str test: ``index_type not in ["all", None]`` compares
    # element-wise when index_type is a NumPy array and raises on the
    # ambiguous truth value.
    selects_all_atoms = index_type is None or (
        isinstance(index_type, str) and index_type == "all"
    )
    if not selects_all_atoms:
        data["sub_dms"] = sub_dms

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "distance_matrices.pkl")
    with open(output_path, "wb") as f:
        pickle.dump(data, f)
    print(f"Distance matrices saved in '{output_path}'")
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def calculate_interatomic_distances(
    traj_path: str,
    frame_skip: int = 10,
    index_type: Union[str, List[Union[int, str]]] = "all",
    output_dir: str = "distance_calculations",
    save_results: bool = True
) -> Dict[str, List[np.ndarray]]:
    """
    Calculate interatomic distances for a trajectory and optionally save results.

    Parameters
    ----------
    traj_path : str
        Path to the trajectory file
    frame_skip : int, optional
        Read every nth frame (default: 10)
    index_type : str, list, or None, optional
        Specification for which atoms to select (default: "all")
    output_dir : str, optional
        Directory to save output file (default: "distance_calculations")
    save_results : bool, optional
        Whether to save results to disk (default: True)

    Returns
    -------
    Dict[str, List[np.ndarray]]
        Dictionary with the key "full_dms" and, when a subset of atoms was
        selected, the additional key "sub_dms"

    Examples
    --------
    >>> results = calculate_interatomic_distances("trajectory.traj")
    >>> first_frame_distances = results["full_dms"][0]
    >>> print(f"Distance matrix shape: {first_frame_distances.shape}")
    """
    print(f"Calculating interatomic distances from '{traj_path}'")
    print(f"Using frame skip: {frame_skip}")
    print(f"Index type: {index_type}")

    full_dms, sub_dms = distance_calculation(traj_path, frame_skip, index_type)

    print(f"Processed {len(full_dms)} frames")
    print(f"Full matrix shape: {full_dms[0].shape}")
    print(f"Sub-matrix shape: {sub_dms[0].shape}")

    results = {"full_dms": full_dms}

    # Explicit None/str test: ``index_type not in ["all", None]`` compares
    # element-wise when index_type is a NumPy array and raises on the
    # ambiguous truth value.
    selects_all_atoms = index_type is None or (
        isinstance(index_type, str) and index_type == "all"
    )
    if not selects_all_atoms:
        results["sub_dms"] = sub_dms

    if save_results:
        save_distance_matrices(full_dms, sub_dms, index_type, output_dir)

    return results
|
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CRISP/simulation_utility/subsampling.py
|
|
3
|
+
|
|
4
|
+
This module provides functionality for structure subsampling from molecular dynamics
|
|
5
|
+
trajectories using Farthest Point Sampling (FPS) with SOAP descriptors.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from ase.io import read, write
|
|
10
|
+
import fpsample
|
|
11
|
+
import glob
|
|
12
|
+
import os
|
|
13
|
+
from dscribe.descriptors import SOAP
|
|
14
|
+
import matplotlib.pyplot as plt
|
|
15
|
+
from joblib import Parallel, delayed
|
|
16
|
+
from typing import Union, List, Optional
|
|
17
|
+
|
|
18
|
+
__all__ = ['indices', 'compute_soap', 'create_repres', 'subsample']
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def indices(
    atoms,
    ind: Union[str, int, List[Union[int, str]], np.ndarray, None],
) -> np.ndarray:
    """
    Extract atom indices from an ASE Atoms object based on the input specifier.

    Parameters
    ----------
    atoms : ase.Atoms
        ASE Atoms object containing atomic structure
    ind : str, int, list, tuple, numpy.ndarray, or None
        Index specifier, can be:
        - "all" or None: all atoms
        - string ending with ".npy": load indices from NumPy file
        - integer, or list/tuple/array of integers: direct atom indices
        - string, or list/tuple/array of strings: chemical symbols to select

    Returns
    -------
    np.ndarray
        Array of selected indices. For chemical symbols the indices are
        grouped in the order in which the symbols were given.

    Raises
    ------
    ValueError
        If the index type is invalid, the selection is empty, or integers
        and chemical symbols are mixed in one selection
    """
    # Select all atoms. Check for None/str explicitly: comparing a NumPy
    # array with "all" is element-wise and has an ambiguous truth value.
    if ind is None or (isinstance(ind, str) and ind == "all"):
        return np.arange(len(atoms))

    # Load from NumPy file
    if isinstance(ind, str) and ind.endswith(".npy"):
        # NOTE(security): allow_pickle=True can execute arbitrary code if the
        # file comes from an untrusted source; only load trusted index files.
        return np.load(ind, allow_pickle=True)

    # Normalise arrays, tuples and single items to a plain list
    if isinstance(ind, np.ndarray):
        ind = np.atleast_1d(ind).tolist()
    elif isinstance(ind, tuple):
        ind = list(ind)
    elif not isinstance(ind, list):
        ind = [ind]

    # Handle integer indices directly; the selection must be homogeneous,
    # otherwise [0, "O"] would be coerced into an unusable string array.
    if ind and all(isinstance(item, (int, np.integer)) for item in ind):
        return np.array(ind)

    # Handle chemical symbols
    if ind and all(isinstance(item, str) for item in ind):
        # Build the symbol array once instead of once per requested symbol.
        symbols = np.array(atoms.get_chemical_symbols())
        return np.concatenate([np.where(symbols == symbol)[0] for symbol in ind])

    raise ValueError("Invalid index type")
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def compute_soap(
    structure,
    all_spec: List[str],
    rcut: float,
    idx: np.ndarray
) -> np.ndarray:
    """Return the centre-averaged SOAP descriptor of one structure.

    Parameters
    ----------
    structure : ase.Atoms
        Atomic structure for which to compute SOAP descriptors
    all_spec : list
        Chemical elements passed to the descriptor as its species
    rcut : float
        Cutoff radius handed to the SOAP descriptor
    idx : numpy.ndarray
        Indices of atoms to use as centers for SOAP calculation

    Returns
    -------
    numpy.ndarray
        Mean over the requested centers of the per-center SOAP output
    """
    # The descriptor is treated as periodic exactly when the cell has a
    # positive volume; the basis settings are fixed for the whole module.
    soap_settings = dict(
        species=all_spec,
        periodic=structure.cell.volume > 0,
        r_cut=rcut,
        n_max=8,
        l_max=6,
        sigma=0.5,
        sparse=False,
    )
    descriptor = SOAP(**soap_settings)

    per_center = descriptor.create(structure, centers=idx)
    # Collapse the leading (per-center) axis to one vector per structure.
    return np.mean(per_center, axis=0)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def create_repres(
    traj_path: List,
    rcut: float = 6,
    ind: Union[str, List[Union[int, str]]] = "all",
    n_jobs: int = -1
) -> np.ndarray:
    """Create SOAP representation vectors for a trajectory.

    Parameters
    ----------
    traj_path : list
        List of ase.Atoms objects representing a trajectory (despite the
        name, this is the loaded frames, not a file path)
    rcut : float, optional
        Cutoff radius for the SOAP descriptor in Angstroms (default: 6)
    ind : str, list, or None, optional
        Specification for which atoms to use as SOAP centers (default: "all")
    n_jobs : int, optional
        Number of parallel jobs to run; -1 uses all available cores (default: -1)

    Returns
    -------
    numpy.ndarray
        Array of SOAP descriptors for each frame in the trajectory

    Raises
    ------
    ValueError
        If the trajectory contains no frames

    Notes
    -----
    The species list and the center indices are taken from the first frame
    and reused for every frame.
    """
    if len(traj_path) == 0:
        raise ValueError("Cannot create SOAP representations from an empty trajectory.")

    first_frame = traj_path[0]
    # Keep each element once (first-seen order): the per-atom symbol list is
    # shipped to every parallel job, so duplicates are pure overhead.
    all_spec = list(dict.fromkeys(first_frame.get_chemical_symbols()))
    idx = indices(first_frame, ind=ind)

    repres = Parallel(n_jobs=n_jobs)(
        delayed(compute_soap)(structure, all_spec, rcut, idx) for structure in traj_path
    )

    return np.array(repres)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def subsample(
    traj_path: str,
    n_samples: int = 50,
    index_type: Union[str, List[Union[int, str]]] = "all",
    rcut: float = 6.0,
    file_format: Optional[str] = None,
    plot_subsample: bool = False,
    frame_skip: int = 1,
    output_dir: str = "subsampled_structures"
) -> List:
    """Subsample a trajectory using Farthest Point Sampling with SOAP descriptors.

    Parameters
    ----------
    traj_path : str
        Path pattern to trajectory file(s); supports globbing. Matching
        files are processed in sorted order.
    n_samples : int, optional
        Number of frames to select (default: 50); capped at the number of
        frames that were read
    index_type : str, list, or None, optional
        Specification for which atoms to use for SOAP calculation (default: "all")
    rcut : float, optional
        Cutoff radius for SOAP in Angstroms (default: 6.0)
    file_format : str, optional
        File format for ASE I/O (default: None, auto-detect)
    plot_subsample : bool, optional
        Whether to generate a plot of FPS distances (default: False)
    frame_skip : int, optional
        Read every nth frame from the trajectory (default: 1); must not be zero
    output_dir : str, optional
        Directory to save the subsampled structures (default: "subsampled_structures")

    Returns
    -------
    list
        List of selected ase.Atoms frames

    Raises
    ------
    ValueError
        If frame_skip is zero, no file matches ``traj_path``, or the matched
        files contain no frames

    Notes
    -----
    The selected frames and plots are saved in the specified output directory.
    A failure while writing the selected frames is reported on stdout and
    does not prevent the frames from being returned.
    """
    if frame_skip == 0:
        raise ValueError("frame_skip must be a non-zero integer.")

    # glob returns matches in arbitrary order; sort so the frame order, the
    # FPS start frame and the output file name are reproducible.
    traj_files = sorted(glob.glob(traj_path))

    # Check if any matching files were found
    if not traj_files:
        raise ValueError(f"No files found matching pattern: {traj_path}")

    trajec = []
    for file in traj_files:
        frames = read(file, index=f'::{frame_skip}', format=file_format)
        # Guard against a reader handing back a single Atoms object, which
        # would otherwise be spliced in atom by atom.
        trajec.extend(frames if isinstance(frames, list) else [frames])

    if not trajec:
        raise ValueError(f"No frames could be read from files matching: {traj_path}")

    repres = create_repres(trajec, ind=index_type, rcut=rcut)

    # Ensure we don't request more samples than available frames
    n_samples = min(n_samples, len(trajec))

    perm = fpsample.fps_sampling(repres, n_samples, start_idx=0)

    fps_frames = [trajec[frame] for frame in perm]

    os.makedirs(output_dir, exist_ok=True)

    if plot_subsample:
        # For each newly picked frame, its distance to the closest frame
        # among those picked before it.
        distance = [
            np.min(np.linalg.norm(repres[perm[:i]] - repres[perm[i]], axis=1))
            for i in range(1, len(perm))
        ]
        plot_path = os.path.join(output_dir, "subsampled_convergence.png")

        plt.figure(figsize=(8, 6))
        plt.plot(distance, c="blue", linewidth=2)
        # With a single selected frame there is nothing to scale against;
        # max() of an empty list would raise.
        if distance and max(distance) > 0:
            plt.ylim([0, 1.1 * max(distance)])
        plt.xlabel("Number of subsampled structures")
        plt.ylabel("Euclidean distance")
        plt.title("FPS Subsampling")
        plt.savefig(plot_path, dpi=300)
        plt.show()
        plt.close()
        print(f"Saved convergence plot to {plot_path}")

    # Extract the base filename without path for output file using os.path for platform independence
    base_filename = os.path.basename(traj_files[0])
    output_file = os.path.join(output_dir, f"subsample_{base_filename}")

    try:
        write(output_file, fps_frames, format=file_format)
        print(f"Saved {len(fps_frames)} subsampled structures to {output_file}")
    except Exception as e:
        # Deliberately best-effort: the caller still receives the frames.
        print(f"Error saving subsampled structures: {e}")

    return fps_frames
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""CRISP DataAnalysis tests package."""
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
"""Extended tests for clustering module to improve coverage."""
|
|
2
|
+
import pytest
|
|
3
|
+
import numpy as np
|
|
4
|
+
import os
|
|
5
|
+
import tempfile
|
|
6
|
+
import shutil
|
|
7
|
+
from ase import Atoms
|
|
8
|
+
from ase.io import write
|
|
9
|
+
|
|
10
|
+
try:
|
|
11
|
+
from CRISP.data_analysis.clustering import (
|
|
12
|
+
analyze_frame,
|
|
13
|
+
analyze_trajectory,
|
|
14
|
+
)
|
|
15
|
+
ASE_AVAILABLE = True
|
|
16
|
+
except ImportError:
|
|
17
|
+
ASE_AVAILABLE = False
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@pytest.mark.skipif(not ASE_AVAILABLE, reason="ASE not available")
class TestClusteringExtended:
    """Coverage-oriented checks of the frame and trajectory clustering calls."""

    # One three-atom set of positions, and the same set followed by a copy
    # displaced by 2.8 along x.
    _MONOMER_XYZ = [
        [0.0, 0.0, 0.0],
        [0.96, 0.0, 0.0],
        [0.24, 0.93, 0.0],
    ]
    _DIMER_XYZ = _MONOMER_XYZ + [
        [2.8, 0.0, 0.0],
        [3.76, 0.0, 0.0],
        [3.04, 0.93, 0.0],
    ]

    @staticmethod
    def _write_periodic(path, formula, xyz):
        """Write *formula* at *xyz* in a fully periodic 10x10x10 cell to *path*."""
        structure = Atoms(formula, positions=xyz)
        structure.set_cell([10, 10, 10])
        structure.set_pbc([True, True, True])
        write(path, structure)

    def test_analyze_frame_basic(self):
        """analyze_frame returns an object when all six atoms are selected."""
        workdir = tempfile.mkdtemp()
        try:
            traj = os.path.join(workdir, 'test.traj')
            self._write_periodic(traj, 'H2OH2O', self._DIMER_XYZ)

            selected = np.array([0, 1, 2, 3, 4, 5])
            result = analyze_frame(
                traj_path=traj,
                atom_indices=selected,
                threshold=2.5,
                min_samples=2
            )
            assert result is not None
        finally:
            shutil.rmtree(workdir)

    @pytest.mark.parametrize("threshold", [1.5, 2.0, 2.5, 3.0])
    def test_analyze_frame_different_cutoffs(self, threshold):
        """analyze_frame returns an object for each distance threshold."""
        workdir = tempfile.mkdtemp()
        try:
            traj = os.path.join(workdir, 'test.traj')
            self._write_periodic(traj, 'H2OH2O', self._DIMER_XYZ)

            selected = np.array([0, 1, 2])
            result = analyze_frame(
                traj_path=traj,
                atom_indices=selected,
                threshold=threshold,
                min_samples=1
            )
            assert result is not None
        finally:
            shutil.rmtree(workdir)

    def test_analyze_frame_calculate_distance_matrix(self):
        """A frame can be read back and turned into a distance matrix."""
        workdir = tempfile.mkdtemp()
        try:
            traj = os.path.join(workdir, 'test.traj')
            self._write_periodic(traj, 'H2O', self._MONOMER_XYZ)

            selected = np.array([0, 1, 2])
            result = analyze_frame(
                traj_path=traj,
                atom_indices=selected,
                threshold=2.5,
                min_samples=2
            )

            loaded = result.read_custom_frame()
            assert loaded is not None
            dist_matrix, positions = result.calculate_distance_matrix(loaded)
            assert dist_matrix is not None
        finally:
            shutil.rmtree(workdir)

    def test_analyze_trajectory_basic(self):
        """analyze_trajectory returns a list for a one-frame trajectory."""
        workdir = tempfile.mkdtemp()
        try:
            traj = os.path.join(workdir, 'test.traj')
            self._write_periodic(traj, 'H2O', self._MONOMER_XYZ)

            selected = np.array([0, 1, 2])
            outcome = analyze_trajectory(
                traj_path=traj,
                indices_path=selected,
                threshold=2.5,
                min_samples=2,
                frame_skip=1
            )
            assert isinstance(outcome, list)
        finally:
            shutil.rmtree(workdir)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
@pytest.mark.skipif(not ASE_AVAILABLE, reason="ASE not available")
class TestClusteringEdgeCases:
    """Boundary and failure-path checks for the clustering analyzer."""

    @staticmethod
    def _write_periodic(path, formula, xyz):
        """Write *formula* at *xyz* in a fully periodic 10x10x10 cell to *path*."""
        structure = Atoms(formula, positions=xyz)
        structure.set_cell([10, 10, 10])
        structure.set_pbc([True, True, True])
        write(path, structure)

    def test_clustering_min_atoms_validation(self):
        """A single selected atom makes calculate_distance_matrix raise ValueError."""
        workdir = tempfile.mkdtemp()
        try:
            traj = os.path.join(workdir, 'test.traj')
            self._write_periodic(traj, 'H', [[0.0, 0.0, 0.0]])

            selected = np.array([0])
            result = analyze_frame(
                traj_path=traj,
                atom_indices=selected,
                threshold=2.5,
                min_samples=5
            )

            loaded = result.read_custom_frame()
            with pytest.raises(ValueError):
                result.calculate_distance_matrix(loaded)
        finally:
            shutil.rmtree(workdir)

    def test_clustering_invalid_trajectory(self):
        """read_custom_frame yields None when the trajectory file is missing."""
        workdir = tempfile.mkdtemp()
        try:
            # The path points inside a fresh directory, so nothing exists there.
            missing = os.path.join(workdir, 'nonexistent.traj')
            result = analyze_frame(
                traj_path=missing,
                atom_indices=np.array([0, 1, 2]),
                threshold=2.5,
                min_samples=2
            )
            loaded = result.read_custom_frame()
            assert loaded is None
        finally:
            shutil.rmtree(workdir)

    def test_clustering_indices_from_file(self):
        """analyze_frame accepts a path to a saved .npy index array."""
        workdir = tempfile.mkdtemp()
        try:
            traj = os.path.join(workdir, 'test.traj')
            npy_path = os.path.join(workdir, 'indices.npy')

            self._write_periodic(traj, 'H2O', [
                [0.0, 0.0, 0.0],
                [0.96, 0.0, 0.0],
                [0.24, 0.93, 0.0]
            ])

            np.save(npy_path, np.array([0, 1, 2]))

            result = analyze_frame(
                traj_path=traj,
                atom_indices=npy_path,
                threshold=2.5,
                min_samples=2
            )
            assert result is not None
        finally:
            shutil.rmtree(workdir)
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
# Allow running this test module directly as a script; pytest is invoked on
# this file only, in verbose mode.
if __name__ == '__main__':
    pytest.main([__file__, '-v'])
|