crisp-ase 1.0.0.post0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of crisp-ase might be problematic. Click here for more details.

@@ -0,0 +1,198 @@
1
+ """
2
+ CRISP/simulation_utility/interatomic_distances.py
3
+
4
+ This module provides tools to calculate and analyze interatomic distances from
5
+ molecular dynamics trajectories.
6
+ """
7
+
8
+ import os
9
+ import numpy as np
10
+ import pickle
11
+ from ase.io import read
12
+ from ase import Atoms
13
+ from typing import Union, Tuple, List, Dict, Any
14
+ from joblib import Parallel, delayed
15
+
16
+
17
def indices(atoms: "Atoms", ind: Union[str, List[Union[int, str]], None]) -> np.ndarray:
    """Extract atom indices from various input types.

    Parameters
    ----------
    atoms : ase.Atoms
        Atoms object containing atomic coordinates and elements
    ind : str, int, list, tuple, np.ndarray, or None
        Specification for which atoms to select:
        - "all" or None: all atoms
        - string ending with ".npy": load indices from NumPy file
        - integer or list/tuple/array of integers (Python or NumPy
          integers): direct atom indices
        - string or list/tuple/array of strings: chemical symbols to select

    Returns
    -------
    np.ndarray
        Array of atom indices. For chemical symbols the indices are grouped
        per symbol, in the order the symbols were given.

    Raises
    ------
    ValueError
        If the index type is not recognized (empty selection, or a mix of
        integers and symbols)
    """
    # Only compare against "all" when ind really is a string; comparing an
    # array with ``==`` is element-wise and has no single truth value.
    if ind is None or (isinstance(ind, str) and ind == "all"):
        return np.arange(len(atoms))

    if isinstance(ind, str) and ind.endswith(".npy"):
        # NOTE(security): allow_pickle=True lets a crafted .npy file execute
        # arbitrary code on load. Kept for backward compatibility; only load
        # index files from trusted sources.
        return np.load(ind, allow_pickle=True)

    # Normalise every accepted container (or a bare scalar) to a plain list.
    if isinstance(ind, np.ndarray):
        items = np.atleast_1d(ind).tolist()
    elif isinstance(ind, (list, tuple)):
        items = list(ind)
    else:
        items = [ind]

    if not items:
        raise ValueError("Invalid index type")

    if all(isinstance(item, (int, np.integer)) for item in items):
        return np.array(items)

    if all(isinstance(item, str) for item in items):
        # Build the symbol array once instead of once per requested symbol.
        symbols = np.array(atoms.get_chemical_symbols())
        return np.concatenate([np.where(symbols == symbol)[0] for symbol in items])

    raise ValueError("Invalid index type")
55
+
56
+
57
def distance_calculation(
    traj_path: str,
    frame_skip: int,
    index_type: Union[str, List[Union[int, str]]] = "all"
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Calculate distance matrices for multiple frames in a trajectory.

    Parameters
    ----------
    traj_path : str
        Path to the trajectory file in any format supported by ASE
    frame_skip : int
        Read every nth frame (n=frame_skip)
    index_type : str, list, or None, optional
        Specification for which atoms to select for sub-matrix (default: "all")

    Returns
    -------
    Tuple[List[np.ndarray], List[np.ndarray]]
        Two lists containing:
        1. Full distance matrices for all frames
        2. Sub-matrices for specified atoms

    Raises
    ------
    ValueError
        If no frames were found in the trajectory, or if reading or
        processing the trajectory fails for any other reason (the original
        exception is attached as ``__cause__``)
    """
    try:
        # Let ASE auto-detect file format based on extension
        frames = read(traj_path, index=f"::{frame_skip}")

        # Handle the case when a single frame is returned (not a list)
        if not isinstance(frames, list):
            frames = [frames]

        if not frames:
            raise ValueError("No frames were found in the trajectory using the given frame_skip.")

        def process_frame(frame: Atoms) -> Tuple[np.ndarray, np.ndarray]:
            # mic=True asks ASE for minimum-image-convention distances.
            dm = frame.get_all_distances(mic=True)
            # The selection is resolved per frame, then used to slice both
            # rows and columns of the full matrix.
            idx = indices(frame, index_type)
            sub_dm = dm[np.ix_(idx, idx)]
            return dm, sub_dm

        results = Parallel(n_jobs=-1)(delayed(process_frame)(frame) for frame in frames)
        full_dms, sub_dms = zip(*results)
        return list(full_dms), list(sub_dms)

    except ValueError:
        # Already the documented exception type: re-raise untouched so the
        # original traceback is preserved.
        raise
    except Exception as e:
        # Keep the documented ValueError contract but chain the real cause so
        # it remains visible when debugging.
        raise ValueError(
            f"Error processing trajectory: {e}. Check if the format is supported by ASE."
        ) from e
110
+
111
+
112
def save_distance_matrices(
    full_dms: List[np.ndarray],
    sub_dms: List[np.ndarray],
    index_type: Union[str, List[Union[int, str]]] = "all",
    output_dir: str = "distance_calculations"
) -> None:
    """Save distance matrices to pickle file.

    Parameters
    ----------
    full_dms : List[np.ndarray]
        List of full distance matrices
    sub_dms : List[np.ndarray]
        List of sub-matrices for specified atoms
    index_type : str, list, or None, optional
        Type of index selection used (default: "all"). The sub-matrices are
        stored only when this is something other than "all" or None.
    output_dir : str, optional
        Directory to save output file (default: "distance_calculations");
        created if it does not exist

    Returns
    -------
    None
        Writes ``distance_matrices.pkl`` into ``output_dir`` and prints the
        path
    """
    # Explicit type test instead of ``index_type in ["all", None]``: the
    # membership test compares with ``==``, which is element-wise (and has no
    # single truth value) when index_type is a NumPy array.
    selects_all = index_type is None or (
        isinstance(index_type, str) and index_type == "all"
    )

    data = {"full_dms": full_dms}
    if not selects_all:
        data["sub_dms"] = sub_dms

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "distance_matrices.pkl")
    with open(output_path, "wb") as f:
        pickle.dump(data, f)
    print(f"Distance matrices saved in '{output_path}'")
145
+
146
+
147
def calculate_interatomic_distances(
    traj_path: str,
    frame_skip: int = 10,
    index_type: Union[str, List[Union[int, str]]] = "all",
    output_dir: str = "distance_calculations",
    save_results: bool = True
) -> Dict[str, List[np.ndarray]]:
    """
    Calculate interatomic distances for a trajectory and optionally save results.

    Parameters
    ----------
    traj_path : str
        Path to the trajectory file
    frame_skip : int, optional
        Read every nth frame (default: 10)
    index_type : str, list, or None, optional
        Specification for which atoms to select (default: "all")
    output_dir : str, optional
        Directory to save output file (default: "distance_calculations")
    save_results : bool, optional
        Whether to save results to disk (default: True)

    Returns
    -------
    Dict[str, List[np.ndarray]]
        Dictionary with key "full_dms" (full distance matrices) and, when
        index_type is neither "all" nor None, key "sub_dms" (sub-matrices)

    Examples
    --------
    >>> results = calculate_interatomic_distances("trajectory.traj")
    >>> first_frame_distances = results["full_dms"][0]
    >>> print(f"Distance matrix shape: {first_frame_distances.shape}")
    """
    print(f"Calculating interatomic distances from '{traj_path}'")
    print(f"Using frame skip: {frame_skip}")
    print(f"Index type: {index_type}")

    full_dms, sub_dms = distance_calculation(traj_path, frame_skip, index_type)

    # distance_calculation raises when no frame is found, so element 0 exists.
    print(f"Processed {len(full_dms)} frames")
    print(f"Full matrix shape: {full_dms[0].shape}")
    print(f"Sub-matrix shape: {sub_dms[0].shape}")

    # Explicit type test instead of ``index_type in ["all", None]``: the
    # membership test compares with ``==``, which has no single truth value
    # when index_type is array-like.
    selects_all = index_type is None or (
        isinstance(index_type, str) and index_type == "all"
    )

    results = {"full_dms": full_dms}
    if not selects_all:
        results["sub_dms"] = sub_dms

    if save_results:
        save_distance_matrices(full_dms, sub_dms, index_type, output_dir)

    return results
@@ -0,0 +1,221 @@
1
+ """
2
+ CRISP/simulation_utility/subsampling.py
3
+
4
+ This module provides functionality for structure subsampling from molecular dynamics
5
+ trajectories using Farthest Point Sampling (FPS) with SOAP descriptors.
6
+ """
7
+
8
+ import numpy as np
9
+ from ase.io import read, write
10
+ import fpsample
11
+ import glob
12
+ import os
13
+ from dscribe.descriptors import SOAP
14
+ import matplotlib.pyplot as plt
15
+ from joblib import Parallel, delayed
16
+ from typing import Union, List
17
+
18
+
19
def indices(atoms, ind: Union[str, List[Union[int, str]]]) -> np.ndarray:
    """
    Extract atom indices from an ASE Atoms object based on the input specifier.

    Parameters
    ----------
    atoms : ase.Atoms
        ASE Atoms object containing atomic structure
    ind : Union[str, List[Union[int, str]]]
        Index specifier, can be:
        - "all" or None: all atoms
        - string ending with ".npy": load indices from NumPy file
        - integer or list/tuple/array of integers (Python or NumPy
          integers): direct atom indices
        - string or list/tuple/array of strings: chemical symbols to select

    Returns
    -------
    np.ndarray
        Array of selected indices. For chemical symbols the indices are
        grouped per symbol, in the order the symbols were given.

    Raises
    ------
    ValueError
        If the index type is invalid (empty selection, or a mix of integers
        and symbols)
    """
    # Select all atoms. The string check guards the comparison: ``==`` on an
    # array is element-wise and has no single truth value.
    if ind is None or (isinstance(ind, str) and ind == "all"):
        return np.arange(len(atoms))

    # Load from NumPy file
    if isinstance(ind, str) and ind.endswith(".npy"):
        # NOTE(security): allow_pickle=True lets a crafted .npy file execute
        # arbitrary code on load. Kept for backward compatibility; only load
        # index files from trusted sources.
        return np.load(ind, allow_pickle=True)

    # Normalise containers and single items to a plain list
    if isinstance(ind, np.ndarray):
        entries = np.atleast_1d(ind).tolist()
    elif isinstance(ind, (list, tuple)):
        entries = list(ind)
    else:
        entries = [ind]

    if not entries:
        raise ValueError("Invalid index type")

    # Handle integer indices directly
    if all(isinstance(entry, (int, np.integer)) for entry in entries):
        return np.array(entries)

    # Handle chemical symbols; the symbol array is built once, not per symbol
    if all(isinstance(entry, str) for entry in entries):
        symbols = np.array(atoms.get_chemical_symbols())
        return np.concatenate([np.where(symbols == entry)[0] for entry in entries])

    raise ValueError("Invalid index type")
68
+
69
+
70
def compute_soap(structure, all_spec, rcut, idx):
    """Return the centre-averaged SOAP descriptor of one structure.

    Parameters
    ----------
    structure : ase.Atoms
        Atomic structure for which to compute SOAP descriptors
    all_spec : list
        Chemical elements passed to the descriptor as its ``species``
    rcut : float
        Cutoff radius handed to the SOAP descriptor (``r_cut``)
    idx : numpy.ndarray
        Indices of atoms to use as centers for SOAP calculation

    Returns
    -------
    numpy.ndarray
        Mean over the first axis of the descriptor array returned for the
        requested centers
    """
    # The descriptor is treated as periodic whenever the cell has a
    # positive volume.
    descriptor = SOAP(
        species=all_spec,
        periodic=structure.cell.volume > 0,
        r_cut=rcut,
        n_max=8,
        l_max=6,
        sigma=0.5,
        sparse=False,
    )
    per_center = descriptor.create(structure, centers=idx)
    averaged = np.mean(per_center, axis=0)
    return averaged
101
+
102
+
103
def create_repres(traj_path, rcut=6, ind="all", n_jobs=-1):
    """Create SOAP representation vectors for a trajectory.

    Parameters
    ----------
    traj_path : list
        List of ase.Atoms objects representing a trajectory
    rcut : float, optional
        Cutoff radius for the SOAP descriptor in Angstroms (default: 6)
    ind : str, list, or None, optional
        Specification for which atoms to use as SOAP centers (default: "all")
    n_jobs : int, optional
        Number of parallel jobs to run; -1 uses all available cores (default: -1)

    Returns
    -------
    numpy.ndarray
        Array of SOAP descriptors for each frame in the trajectory

    Notes
    -----
    The center selection is resolved on the first frame only and reused for
    every frame.
    """
    # SOAP must be told every element that can occur in any frame, so take
    # the union over the whole trajectory rather than the first frame's
    # per-atom symbol list. Sorting keeps the species order deterministic.
    all_spec = sorted(
        {symbol for frame in traj_path for symbol in frame.get_chemical_symbols()}
    )
    idx = indices(traj_path[0], ind=ind)

    repres = Parallel(n_jobs=n_jobs)(
        delayed(compute_soap)(structure, all_spec, rcut, idx) for structure in traj_path
    )

    return np.array(repres)
130
+
131
+
132
def _save_fps_convergence_plot(repres, perm, output_dir):
    """Plot, for each FPS pick, its distance to the closest earlier pick."""
    distance = [
        np.min(np.linalg.norm(repres[perm[:i]] - repres[perm[i]], axis=1))
        for i in range(1, len(perm))
    ]
    plot_path = os.path.join(output_dir, "subsampled_convergence.png")

    plt.figure(figsize=(8, 6))
    plt.plot(distance, c="blue", linewidth=2)
    # With a single selected frame there is no distance to plot; max() on an
    # empty list would raise, and a zero upper limit is a degenerate axis.
    if distance and max(distance) > 0:
        plt.ylim([0, 1.1 * max(distance)])
    plt.xlabel("Number of subsampled structures")
    plt.ylabel("Euclidean distance")
    plt.title("FPS Subsampling")
    plt.savefig(plot_path, dpi=300)
    plt.show()
    plt.close()
    print(f"Saved convergence plot to {plot_path}")


def subsample(traj_path, n_samples=50, index_type="all", rcut=6.0, file_format=None,
              plot_subsample=False, frame_skip=1, output_dir="subsampled_structures"):
    """Subsample a trajectory using Farthest Point Sampling with SOAP descriptors.

    Parameters
    ----------
    traj_path : str
        Path pattern to trajectory file(s); supports globbing
    n_samples : int, optional
        Number of frames to select (default: 50); capped at the number of
        frames that were read
    index_type : str, list, or None, optional
        Specification for which atoms to use for SOAP calculation (default: "all")
    rcut : float, optional
        Cutoff radius for SOAP in Angstroms (default: 6.0)
    file_format : str, optional
        File format for ASE I/O (default: None, auto-detect)
    plot_subsample : bool, optional
        Whether to generate a plot of FPS distances (default: False)
    frame_skip : int, optional
        Read every nth frame from the trajectory (default: 1)
    output_dir : str, optional
        Directory to save the subsampled structures (default: "subsampled_structures")

    Returns
    -------
    list
        List of selected ase.Atoms frames

    Raises
    ------
    ValueError
        If no file matches ``traj_path`` or the matched files contain no frames

    Notes
    -----
    The selected frames and plots are saved in the specified output directory.
    Matched files are processed in sorted order, and the output file is named
    after the first of them. A failure while writing the selected frames is
    reported on stdout and does not prevent the frames from being returned.
    """
    # glob returns matches in arbitrary order; sort so that the frame order
    # and the output file name are reproducible.
    traj_files = sorted(glob.glob(traj_path))

    # Check if any matching files were found
    if not traj_files:
        raise ValueError(f"No files found matching pattern: {traj_path}")

    trajec = []
    for file in traj_files:
        # format=None is ASE's auto-detect default, so no branching is needed.
        trajec.extend(read(file, index=f'::{frame_skip}', format=file_format))

    if not trajec:
        raise ValueError(f"No frames found in files matching pattern: {traj_path}")

    repres = create_repres(trajec, ind=index_type, rcut=rcut)

    # Ensure we don't request more samples than available frames
    n_samples = min(n_samples, len(trajec))

    perm = fpsample.fps_sampling(repres, n_samples, start_idx=0)
    fps_frames = [trajec[frame] for frame in perm]

    os.makedirs(output_dir, exist_ok=True)

    if plot_subsample:
        _save_fps_convergence_plot(repres, perm, output_dir)

    # Extract the base filename without path for output file using os.path for platform independence
    base_filename = os.path.basename(traj_files[0])
    output_file = os.path.join(output_dir, f"subsample_{base_filename}")

    # Saving is best-effort: the selection is still returned if writing fails.
    try:
        write(output_file, fps_frames, format=file_format)
        print(f"Saved {len(fps_frames)} subsampled structures to {output_file}")
    except Exception as e:
        print(f"Error saving subsampled structures: {e}")

    return fps_frames
@@ -0,0 +1,3 @@
1
+ """
2
+ Empty init file, kept in case a test runner other than pytest (such as nose) needs it to discover this package.
3
+ """
@@ -0,0 +1,28 @@
1
+ """Basic CRISP package tests."""
2
+
3
+ import pytest
4
+ import CRISP
5
+
6
+
7
def test_crisp_import():
    """Check that the module-level CRISP import produced an object."""
    package = CRISP
    assert package is not None
10
+
11
+
12
def test_crisp_version():
    """Test that CRISP's version, when defined, is a string.

    The test is reported as skipped (rather than silently passing) when the
    package does not define ``__version__``.
    """
    if not hasattr(CRISP, "__version__"):
        pytest.skip("CRISP does not define __version__")
    assert isinstance(CRISP.__version__, str)
19
+
20
+
21
def test_crisp_modules():
    """Fail with a readable message if a main CRISP module cannot be imported."""
    try:
        from CRISP.simulation_utility import atomic_indices
        from CRISP.data_analysis import msd
    except ImportError as exc:
        pytest.fail(f"Failed to import CRISP modules: {exc}")
    else:
        # Both imports succeeded; nothing further to check.
        assert True