crisp-ase 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. CRISP/__init__.py +99 -0
  2. CRISP/_version.py +1 -0
  3. CRISP/cli.py +41 -0
  4. CRISP/data_analysis/__init__.py +38 -0
  5. CRISP/data_analysis/clustering.py +838 -0
  6. CRISP/data_analysis/contact_coordination.py +915 -0
  7. CRISP/data_analysis/h_bond.py +772 -0
  8. CRISP/data_analysis/msd.py +1199 -0
  9. CRISP/data_analysis/prdf.py +404 -0
  10. CRISP/data_analysis/volumetric_atomic_density.py +527 -0
  11. CRISP/py.typed +1 -0
  12. CRISP/simulation_utility/__init__.py +31 -0
  13. CRISP/simulation_utility/atomic_indices.py +155 -0
  14. CRISP/simulation_utility/atomic_traj_linemap.py +278 -0
  15. CRISP/simulation_utility/error_analysis.py +254 -0
  16. CRISP/simulation_utility/interatomic_distances.py +200 -0
  17. CRISP/simulation_utility/subsampling.py +241 -0
  18. CRISP/tests/DataAnalysis/__init__.py +1 -0
  19. CRISP/tests/DataAnalysis/test_clustering_extended.py +212 -0
  20. CRISP/tests/DataAnalysis/test_contact_coordination.py +184 -0
  21. CRISP/tests/DataAnalysis/test_contact_coordination_extended.py +465 -0
  22. CRISP/tests/DataAnalysis/test_h_bond_complete.py +326 -0
  23. CRISP/tests/DataAnalysis/test_h_bond_extended.py +322 -0
  24. CRISP/tests/DataAnalysis/test_msd_complete.py +305 -0
  25. CRISP/tests/DataAnalysis/test_msd_extended.py +522 -0
  26. CRISP/tests/DataAnalysis/test_prdf.py +206 -0
  27. CRISP/tests/DataAnalysis/test_volumetric_atomic_density.py +463 -0
  28. CRISP/tests/SimulationUtility/__init__.py +1 -0
  29. CRISP/tests/SimulationUtility/test_atomic_traj_linemap.py +101 -0
  30. CRISP/tests/SimulationUtility/test_atomic_traj_linemap_extended.py +469 -0
  31. CRISP/tests/SimulationUtility/test_error_analysis_extended.py +151 -0
  32. CRISP/tests/SimulationUtility/test_interatomic_distances.py +223 -0
  33. CRISP/tests/SimulationUtility/test_subsampling.py +365 -0
  34. CRISP/tests/__init__.py +1 -0
  35. CRISP/tests/test_CRISP.py +28 -0
  36. CRISP/tests/test_cli.py +87 -0
  37. CRISP/tests/test_crisp_comprehensive.py +679 -0
  38. crisp_ase-1.1.2.dist-info/METADATA +141 -0
  39. crisp_ase-1.1.2.dist-info/RECORD +42 -0
  40. crisp_ase-1.1.2.dist-info/WHEEL +5 -0
  41. crisp_ase-1.1.2.dist-info/entry_points.txt +2 -0
  42. crisp_ase-1.1.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,200 @@
1
+ """
2
+ CRISP/simulation_utility/interatomic_distances.py
3
+
4
+ This module provides tools to calculate and analyze interatomic distances from
5
+ molecular dynamics trajectories.
6
+ """
7
+
8
+ import os
9
+ import numpy as np
10
+ import pickle
11
+ from ase.io import read
12
+ from ase import Atoms
13
+ from typing import Union, Tuple, List, Dict, Any, Optional
14
+ from joblib import Parallel, delayed
15
+
16
+ __all__ = ['indices', 'distance_calculation', 'save_distance_matrices', 'calculate_interatomic_distances']
17
+
18
+
19
def indices(atoms: "Atoms", ind: Union[str, List[Union[int, str]], None]) -> np.ndarray:
    """Resolve an index specification into an array of atom indices.

    Parameters
    ----------
    atoms : ase.Atoms
        Structure whose length and chemical symbols are used for selection
    ind : str, int, list, tuple, numpy.ndarray, or None
        Specification for which atoms to select:
        - "all" or None: all atoms
        - string ending with ".npy": load indices from NumPy file
        - integer or sequence of integers (Python or NumPy): direct atom indices
        - string or sequence of strings: chemical symbols to select

    Returns
    -------
    np.ndarray
        Array of atom indices. For chemical symbols the indices are grouped
        per symbol, in the order the symbols were given.

    Raises
    ------
    ValueError
        If the specification is empty, mixes integers with symbols, or
        contains items that are neither integers nor strings
    """
    # Only compare real strings against the "all" sentinel: `==` on a NumPy
    # array is element-wise and has no single truth value.
    if ind is None or (isinstance(ind, str) and ind == "all"):
        return np.arange(len(atoms))

    if isinstance(ind, str) and ind.endswith(".npy"):
        # NOTE(security): allow_pickle=True lets a crafted file execute
        # arbitrary code while loading; only use index files you trust.
        return np.load(ind, allow_pickle=True)

    # Normalise every accepted container (and bare scalars) to a plain list.
    if isinstance(ind, np.ndarray):
        items = list(ind.tolist()) if ind.ndim else [ind.item()]
    elif isinstance(ind, (list, tuple)):
        items = list(ind)
    else:
        items = [ind]

    if not items:
        raise ValueError("Invalid index type")

    # Require a homogeneous specification so a mixed list cannot silently
    # turn into a string array that fails later during indexing.
    if all(isinstance(item, (int, np.integer)) for item in items):
        return np.array(items)

    if all(isinstance(item, str) for item in items):
        # Build the symbol array once instead of once per requested symbol.
        symbols = np.array(atoms.get_chemical_symbols())
        return np.concatenate([np.where(symbols == symbol)[0] for symbol in items])

    raise ValueError("Invalid index type")
57
+
58
+
59
def distance_calculation(
    traj_path: str,
    frame_skip: int,
    index_type: Union[str, List[Union[int, str]]] = "all"
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Calculate distance matrices for multiple frames in a trajectory.

    Parameters
    ----------
    traj_path : str
        Path to the trajectory file in any format supported by ASE
    frame_skip : int
        Read every nth frame (n=frame_skip)
    index_type : str, list, or None, optional
        Specification for which atoms to select for sub-matrix (default: "all")

    Returns
    -------
    Tuple[List[np.ndarray], List[np.ndarray]]
        Two lists containing:
        1. Full distance matrices for all frames
        2. Sub-matrices for specified atoms

    Raises
    ------
    ValueError
        If no frames were found in the trajectory, or if reading/processing
        fails for any other reason (the original exception is attached as
        ``__cause__``)
    """
    try:
        # Let ASE auto-detect file format based on extension
        frames = read(traj_path, index=f"::{frame_skip}")

        # Handle the case when a single frame is returned (not a list)
        if not isinstance(frames, list):
            frames = [frames]

        if not frames:
            raise ValueError("No frames were found in the trajectory using the given frame_skip.")

        def process_frame(frame: Atoms) -> Tuple[np.ndarray, np.ndarray]:
            # mic=True requests minimum-image distances from ASE.
            dm = frame.get_all_distances(mic=True)
            # The selection is resolved per frame, then used for both the
            # rows and the columns of the sub-matrix.
            idx = indices(frame, index_type)
            sub_dm = dm[np.ix_(idx, idx)]
            return dm, sub_dm

        # n_jobs=-1: process the frames in parallel on all available cores.
        results = Parallel(n_jobs=-1)(delayed(process_frame)(frame) for frame in frames)
        full_dms, sub_dms = zip(*results)
        return list(full_dms), list(sub_dms)

    except ValueError:
        # Already the documented exception type: re-raise untouched so the
        # original traceback is preserved.
        raise
    except Exception as e:
        # Convert everything else to ValueError, but chain the cause so the
        # underlying failure stays visible to the caller.
        raise ValueError(
            f"Error processing trajectory: {e}. Check if the format is supported by ASE."
        ) from e
112
+
113
+
114
def save_distance_matrices(
    full_dms: List[np.ndarray],
    sub_dms: List[np.ndarray],
    index_type: Union[str, List[Union[int, str]]] = "all",
    output_dir: str = "distance_calculations"
) -> None:
    """Save distance matrices to pickle file.

    The file is written to ``<output_dir>/distance_matrices.pkl`` and holds a
    dictionary with the key ``"full_dms"`` and, when a specific atom selection
    was requested, the additional key ``"sub_dms"``.

    Parameters
    ----------
    full_dms : List[np.ndarray]
        List of full distance matrices
    sub_dms : List[np.ndarray]
        List of sub-matrices for specified atoms
    index_type : str, list, or None, optional
        Type of index selection used (default: "all"). Sub-matrices are only
        stored when this is neither "all" nor None.
    output_dir : str, optional
        Directory to save output file (default: "distance_calculations");
        created if it does not exist

    Returns
    -------
    None
        Saves results to disk
    """
    data = {"full_dms": full_dms}

    # Explicit checks instead of `index_type not in ["all", None]`: the
    # membership test compares with `==`, which is element-wise (and therefore
    # ambiguous) when the selection is a NumPy array.
    selects_all = index_type is None or (
        isinstance(index_type, str) and index_type == "all"
    )
    if not selects_all:
        data["sub_dms"] = sub_dms

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "distance_matrices.pkl")
    with open(output_path, "wb") as f:
        pickle.dump(data, f)
    print(f"Distance matrices saved in '{output_path}'")
147
+
148
+
149
def calculate_interatomic_distances(
    traj_path: str,
    frame_skip: int = 10,
    index_type: Union[str, List[Union[int, str]]] = "all",
    output_dir: str = "distance_calculations",
    save_results: bool = True
) -> Dict[str, List[np.ndarray]]:
    """
    Calculate interatomic distances for a trajectory and optionally save results.

    Progress information (input path, frame skip, selection, number of
    processed frames and matrix shapes) is printed to standard output.

    Parameters
    ----------
    traj_path : str
        Path to the trajectory file
    frame_skip : int, optional
        Read every nth frame (default: 10)
    index_type : str, list, or None, optional
        Specification for which atoms to select (default: "all")
    output_dir : str, optional
        Directory to save output file (default: "distance_calculations")
    save_results : bool, optional
        Whether to save results to disk (default: True)

    Returns
    -------
    Dict[str, List[np.ndarray]]
        Dictionary with the key "full_dms" and, when `index_type` is neither
        "all" nor None, the additional key "sub_dms"

    Examples
    --------
    >>> results = calculate_interatomic_distances("trajectory.traj")
    >>> first_frame_distances = results["full_dms"][0]
    >>> print(f"Distance matrix shape: {first_frame_distances.shape}")
    """
    print(f"Calculating interatomic distances from '{traj_path}'")
    print(f"Using frame skip: {frame_skip}")
    print(f"Index type: {index_type}")

    full_dms, sub_dms = distance_calculation(traj_path, frame_skip, index_type)

    print(f"Processed {len(full_dms)} frames")
    print(f"Full matrix shape: {full_dms[0].shape}")
    print(f"Sub-matrix shape: {sub_dms[0].shape}")

    results = {"full_dms": full_dms}

    # Explicit checks instead of `index_type not in ["all", None]`: the
    # membership test compares with `==`, which is element-wise (and therefore
    # ambiguous) when the selection is a NumPy array.
    selects_all = index_type is None or (
        isinstance(index_type, str) and index_type == "all"
    )
    if not selects_all:
        results["sub_dms"] = sub_dms

    if save_results:
        save_distance_matrices(full_dms, sub_dms, index_type, output_dir)

    return results
@@ -0,0 +1,241 @@
1
+ """
2
+ CRISP/simulation_utility/subsampling.py
3
+
4
+ This module provides functionality for structure subsampling from molecular dynamics
5
+ trajectories using Farthest Point Sampling (FPS) with SOAP descriptors.
6
+ """
7
+
8
+ import numpy as np
9
+ from ase.io import read, write
10
+ import fpsample
11
+ import glob
12
+ import os
13
+ from dscribe.descriptors import SOAP
14
+ import matplotlib.pyplot as plt
15
+ from joblib import Parallel, delayed
16
+ from typing import Union, List, Optional
17
+
18
+ __all__ = ['indices', 'compute_soap', 'create_repres', 'subsample']
19
+
20
+
21
def indices(atoms, ind: Union[str, List[Union[int, str]]]) -> np.ndarray:
    """
    Extract atom indices from an ASE Atoms object based on the input specifier.

    Parameters
    ----------
    atoms : ase.Atoms
        ASE Atoms object containing atomic structure
    ind : str, int, list, tuple, numpy.ndarray, or None
        Index specifier, can be:
        - "all" or None: all atoms
        - string ending with ".npy": load indices from NumPy file
        - integer or sequence of integers (Python or NumPy): direct atom indices
        - string or sequence of strings: chemical symbols to select

    Returns
    -------
    np.ndarray
        Array of selected indices. For chemical symbols the indices are
        grouped per symbol, in the order the symbols were given.

    Raises
    ------
    ValueError
        If the specifier is empty, mixes integers with symbols, or contains
        items that are neither integers nor strings
    """
    # Select all atoms. Only real strings are compared against the "all"
    # sentinel: `==` on a NumPy array is element-wise and has no single
    # truth value.
    if ind is None or (isinstance(ind, str) and ind == "all"):
        return np.arange(len(atoms))

    # Load from NumPy file
    if isinstance(ind, str) and ind.endswith(".npy"):
        # NOTE(security): allow_pickle=True lets a crafted file execute
        # arbitrary code while loading; only use index files you trust.
        return np.load(ind, allow_pickle=True)

    # Normalise every accepted container (and bare scalars) to a plain list
    if isinstance(ind, np.ndarray):
        items = list(ind.tolist()) if ind.ndim else [ind.item()]
    elif isinstance(ind, (list, tuple)):
        items = list(ind)
    else:
        items = [ind]

    if not items:
        raise ValueError("Invalid index type")

    # Handle integer indices directly. The specifier must be homogeneous so a
    # mixed list cannot silently become a string array.
    if all(isinstance(item, (int, np.integer)) for item in items):
        return np.array(items)

    # Handle chemical symbols; the symbol array is built once for all symbols
    if all(isinstance(item, str) for item in items):
        symbols = np.array(atoms.get_chemical_symbols())
        return np.concatenate([np.where(symbols == symbol)[0] for symbol in items])

    raise ValueError("Invalid index type")
70
+
71
+
72
def compute_soap(
    structure,
    all_spec: List[str],
    rcut: float,
    idx: np.ndarray
) -> np.ndarray:
    """Return the averaged SOAP descriptor of one structure.

    Parameters
    ----------
    structure : ase.Atoms
        Atomic structure to describe
    all_spec : list
        Chemical elements passed to the descriptor as its species
    rcut : float
        Cutoff radius handed to the SOAP descriptor (``r_cut``)
    idx : numpy.ndarray
        Indices of the atoms used as SOAP centers

    Returns
    -------
    numpy.ndarray
        Mean of the per-center descriptor output along its first axis
    """
    # The descriptor is set up as periodic exactly when the cell has a
    # positive volume; all remaining settings are fixed.
    descriptor = SOAP(
        species=all_spec,
        periodic=structure.cell.volume > 0,
        r_cut=rcut,
        n_max=8,
        l_max=6,
        sigma=0.5,
        sparse=False,
    )
    per_center = descriptor.create(structure, centers=idx)
    return np.mean(per_center, axis=0)
108
+
109
+
110
def create_repres(
    traj_path: List,
    rcut: float = 6,
    ind: Union[str, List[Union[int, str]]] = "all",
    n_jobs: int = -1
) -> np.ndarray:
    """Build one averaged SOAP vector per frame of a trajectory.

    Parameters
    ----------
    traj_path : list
        List of ase.Atoms objects (the frames themselves, not a file path)
    rcut : float, optional
        Cutoff radius for the SOAP descriptor in Angstroms (default: 6)
    ind : str, list, or None, optional
        Specification for which atoms to use as SOAP centers (default: "all")
    n_jobs : int, optional
        Number of parallel jobs to run; -1 uses all available cores (default: -1)

    Returns
    -------
    numpy.ndarray
        Array holding the descriptor of every frame, in frame order
    """
    # Species and center indices are both derived from the first frame only
    # and then reused unchanged for every frame.
    first_frame = traj_path[0]
    species = first_frame.get_chemical_symbols()
    centers = indices(first_frame, ind=ind)

    tasks = (
        delayed(compute_soap)(frame, species, rcut, centers)
        for frame in traj_path
    )
    per_frame = Parallel(n_jobs=n_jobs)(tasks)

    return np.array(per_frame)
142
+
143
+
144
def subsample(
    traj_path: str,
    n_samples: int = 50,
    index_type: Union[str, List[Union[int, str]]] = "all",
    rcut: float = 6.0,
    file_format: Optional[str] = None,
    plot_subsample: bool = False,
    frame_skip: int = 1,
    output_dir: str = "subsampled_structures"
) -> List:
    """Subsample a trajectory using Farthest Point Sampling with SOAP descriptors.

    Parameters
    ----------
    traj_path : str
        Path pattern to trajectory file(s); supports globbing. Matching files
        are processed in sorted order.
    n_samples : int, optional
        Number of frames to select (default: 50); capped at the number of
        frames that were read
    index_type : str, list, or None, optional
        Specification for which atoms to use for SOAP calculation (default: "all")
    rcut : float, optional
        Cutoff radius for SOAP in Angstroms (default: 6.0)
    file_format : str, optional
        File format for ASE I/O (default: None, auto-detect)
    plot_subsample : bool, optional
        Whether to generate a plot of FPS distances (default: False)
    frame_skip : int, optional
        Read every nth frame from the trajectory (default: 1)
    output_dir : str, optional
        Directory to save the subsampled structures (default: "subsampled_structures")

    Returns
    -------
    list
        List of selected ase.Atoms frames, in the order they were picked

    Raises
    ------
    ValueError
        If no file matches `traj_path`

    Notes
    -----
    The selected frames and plots are saved in the specified output directory.
    A failure while writing the selected frames is reported on standard output
    and does not prevent the frames from being returned.
    """
    # glob returns matches in arbitrary order; sorting keeps the frame order
    # (and with it the FPS start frame and the output file name) reproducible.
    traj_files = sorted(glob.glob(traj_path))

    # Check if any matching files were found
    if not traj_files:
        raise ValueError(f"No files found matching pattern: {traj_path}")

    # format=None makes ASE auto-detect, so no separate branch is needed.
    trajec = []
    for file in traj_files:
        trajec += read(file, index=f'::{frame_skip}', format=file_format)

    repres = create_repres(trajec, ind=index_type, rcut=rcut)

    # Ensure we don't request more samples than available frames
    n_samples = min(n_samples, len(trajec))

    # Sampling always starts from the first frame (start_idx=0).
    perm = fpsample.fps_sampling(repres, n_samples, start_idx=0)

    fps_frames = [trajec[frame] for frame in perm]

    os.makedirs(output_dir, exist_ok=True)

    if plot_subsample:
        # For every newly picked sample: its smallest Euclidean distance to
        # the samples picked before it.
        distance = [
            np.min(np.linalg.norm(repres[perm[:i]] - repres[perm[i]], axis=1))
            for i in range(1, len(perm))
        ]
        plot_path = os.path.join(output_dir, "subsampled_convergence.png")

        plt.figure(figsize=(8, 6))
        plt.plot(distance, c="blue", linewidth=2)
        # With a single sample there are no distances; max() of an empty
        # list would raise, so the axis limit is only set when data exists.
        if distance:
            plt.ylim([0, 1.1 * max(distance)])
        plt.xlabel("Number of subsampled structures")
        plt.ylabel("Euclidean distance")
        plt.title("FPS Subsampling")
        plt.savefig(plot_path, dpi=300)
        plt.show()
        plt.close()
        print(f"Saved convergence plot to {plot_path}")

    # Extract the base filename without path for output file using os.path for platform independence
    base_filename = os.path.basename(traj_files[0])
    output_file = os.path.join(output_dir, f"subsample_{base_filename}")

    # Writing is best-effort: report the problem but still return the frames.
    try:
        write(output_file, fps_frames, format=file_format)
        print(f"Saved {len(fps_frames)} subsampled structures to {output_file}")
    except Exception as e:
        print(f"Error saving subsampled structures: {e}")

    return fps_frames
@@ -0,0 +1 @@
1
+ """CRISP DataAnalysis tests package."""
@@ -0,0 +1,212 @@
1
+ """Extended tests for clustering module to improve coverage."""
2
+ import pytest
3
+ import numpy as np
4
+ import os
5
+ import tempfile
6
+ import shutil
7
+ from ase import Atoms
8
+ from ase.io import write
9
+
10
+ try:
11
+ from CRISP.data_analysis.clustering import (
12
+ analyze_frame,
13
+ analyze_trajectory,
14
+ )
15
+ ASE_AVAILABLE = True
16
+ except ImportError:
17
+ ASE_AVAILABLE = False
18
+
19
+
20
@pytest.mark.skipif(not ASE_AVAILABLE, reason="ASE not available")
class TestClusteringExtended:
    """Extended clustering tests for coverage."""

    # Positions for the three-atom 'H2O' frame and the six-atom 'H2OH2O' frame.
    _TRIMER_POSITIONS = [
        [0.0, 0.0, 0.0],
        [0.96, 0.0, 0.0],
        [0.24, 0.93, 0.0],
    ]
    _HEXAMER_POSITIONS = _TRIMER_POSITIONS + [
        [2.8, 0.0, 0.0],
        [3.76, 0.0, 0.0],
        [3.04, 0.93, 0.0],
    ]

    @staticmethod
    def _write_periodic_frame(directory, formula, positions):
        """Write a single periodic frame to 'test.traj' and return its path."""
        traj_file = os.path.join(directory, 'test.traj')
        atoms = Atoms(formula, positions=positions)
        atoms.set_cell([10, 10, 10])
        atoms.set_pbc([True, True, True])
        write(traj_file, atoms)
        return traj_file

    def test_analyze_frame_basic(self):
        """Test frame clustering analysis."""
        workdir = tempfile.mkdtemp()
        try:
            traj_file = self._write_periodic_frame(
                workdir, 'H2OH2O', self._HEXAMER_POSITIONS
            )
            analyzer = analyze_frame(
                traj_path=traj_file,
                atom_indices=np.array([0, 1, 2, 3, 4, 5]),
                threshold=2.5,
                min_samples=2,
            )
            assert analyzer is not None
        finally:
            shutil.rmtree(workdir)

    @pytest.mark.parametrize("threshold", [1.5, 2.0, 2.5, 3.0])
    def test_analyze_frame_different_cutoffs(self, threshold):
        """Test with different distance cutoffs."""
        workdir = tempfile.mkdtemp()
        try:
            traj_file = self._write_periodic_frame(
                workdir, 'H2OH2O', self._HEXAMER_POSITIONS
            )
            analyzer = analyze_frame(
                traj_path=traj_file,
                atom_indices=np.array([0, 1, 2]),
                threshold=threshold,
                min_samples=1,
            )
            assert analyzer is not None
        finally:
            shutil.rmtree(workdir)

    def test_analyze_frame_calculate_distance_matrix(self):
        """Test distance matrix calculation."""
        workdir = tempfile.mkdtemp()
        try:
            traj_file = self._write_periodic_frame(
                workdir, 'H2O', self._TRIMER_POSITIONS
            )
            analyzer = analyze_frame(
                traj_path=traj_file,
                atom_indices=np.array([0, 1, 2]),
                threshold=2.5,
                min_samples=2,
            )

            frame = analyzer.read_custom_frame()
            assert frame is not None
            dist_matrix, positions = analyzer.calculate_distance_matrix(frame)
            assert dist_matrix is not None
        finally:
            shutil.rmtree(workdir)

    def test_analyze_trajectory_basic(self):
        """Test trajectory clustering."""
        workdir = tempfile.mkdtemp()
        try:
            traj_file = self._write_periodic_frame(
                workdir, 'H2O', self._TRIMER_POSITIONS
            )
            results = analyze_trajectory(
                traj_path=traj_file,
                indices_path=np.array([0, 1, 2]),
                threshold=2.5,
                min_samples=2,
                frame_skip=1,
            )
            assert isinstance(results, list)
        finally:
            shutil.rmtree(workdir)
135
+
136
+
137
@pytest.mark.skipif(not ASE_AVAILABLE, reason="ASE not available")
class TestClusteringEdgeCases:
    """Test edge cases for clustering."""

    @staticmethod
    def _write_periodic_frame(directory, formula, positions):
        """Write a single periodic frame to 'test.traj' and return its path."""
        traj_file = os.path.join(directory, 'test.traj')
        atoms = Atoms(formula, positions=positions)
        atoms.set_cell([10, 10, 10])
        atoms.set_pbc([True, True, True])
        write(traj_file, atoms)
        return traj_file

    def test_clustering_min_atoms_validation(self):
        """Test minimum atoms validation."""
        workdir = tempfile.mkdtemp()
        try:
            # One atom only, while min_samples asks for five.
            traj_file = self._write_periodic_frame(
                workdir, 'H', [[0.0, 0.0, 0.0]]
            )
            analyzer = analyze_frame(
                traj_path=traj_file,
                atom_indices=np.array([0]),
                threshold=2.5,
                min_samples=5,
            )

            frame = analyzer.read_custom_frame()
            with pytest.raises(ValueError):
                analyzer.calculate_distance_matrix(frame)
        finally:
            shutil.rmtree(workdir)

    def test_clustering_invalid_trajectory(self):
        """Test handling of invalid trajectory file."""
        workdir = tempfile.mkdtemp()
        try:
            # The path points into the temp dir but the file is never created.
            missing_path = os.path.join(workdir, 'nonexistent.traj')
            analyzer = analyze_frame(
                traj_path=missing_path,
                atom_indices=np.array([0, 1, 2]),
                threshold=2.5,
                min_samples=2,
            )
            frame = analyzer.read_custom_frame()
            assert frame is None
        finally:
            shutil.rmtree(workdir)

    def test_clustering_indices_from_file(self):
        """Test loading indices from numpy file."""
        workdir = tempfile.mkdtemp()
        try:
            traj_file = self._write_periodic_frame(
                workdir,
                'H2O',
                [
                    [0.0, 0.0, 0.0],
                    [0.96, 0.0, 0.0],
                    [0.24, 0.93, 0.0],
                ],
            )

            # The indices are handed over as a path to a saved .npy file.
            indices_file = os.path.join(workdir, 'indices.npy')
            np.save(indices_file, np.array([0, 1, 2]))

            analyzer = analyze_frame(
                traj_path=traj_file,
                atom_indices=indices_file,
                threshold=2.5,
                min_samples=2,
            )
            assert analyzer is not None
        finally:
            shutil.rmtree(workdir)
209
+
210
+
211
# Allow running this test module directly: hand the file to pytest in
# verbose mode.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])