crisp-ase 1.0.0.post0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of crisp-ase might be problematic. Click here for more details.
- CRISP/__init__.py +19 -0
- CRISP/_version.py +1 -0
- CRISP/cli.py +31 -0
- CRISP/data_analysis/__init__.py +9 -0
- CRISP/data_analysis/clustering.py +828 -0
- CRISP/data_analysis/contact_coordination.py +915 -0
- CRISP/data_analysis/h_bond.py +716 -0
- CRISP/data_analysis/msd.py +1179 -0
- CRISP/data_analysis/prdf.py +403 -0
- CRISP/data_analysis/volumetric_atomic_density.py +527 -0
- CRISP/py.typed +1 -0
- CRISP/simulation_utility/__init__.py +9 -0
- CRISP/simulation_utility/atomic_indices.py +144 -0
- CRISP/simulation_utility/atomic_traj_linemap.py +278 -0
- CRISP/simulation_utility/error_analysis.py +253 -0
- CRISP/simulation_utility/interatomic_distances.py +198 -0
- CRISP/simulation_utility/subsampling.py +221 -0
- CRISP/tests/__init__.py +3 -0
- CRISP/tests/test_CRISP.py +28 -0
- CRISP/tests/test_crisp_comprehensive.py +677 -0
- crisp_ase-1.0.0.post0.dev0.dist-info/METADATA +116 -0
- crisp_ase-1.0.0.post0.dev0.dist-info/RECORD +25 -0
- crisp_ase-1.0.0.post0.dev0.dist-info/WHEEL +5 -0
- crisp_ase-1.0.0.post0.dev0.dist-info/entry_points.txt +2 -0
- crisp_ase-1.0.0.post0.dev0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,527 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CRISP/data_analysis/volumetric_atomic_density.py
|
|
3
|
+
|
|
4
|
+
This module performs volumetric density distribution analysis on specific atoms
|
|
5
|
+
in molecular dynamics trajectory data, creating 3D visualizations of atom density maps.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import plotly.graph_objects as go
|
|
10
|
+
from ase.io import read
|
|
11
|
+
from typing import Optional, Dict, List, Union
|
|
12
|
+
import os
|
|
13
|
+
|
|
14
|
+
# Van der Waals radii in Ångström, keyed by element symbol (H through Og).
# Used to scale the atom markers in the density-map plot. From Pa onward,
# every element shares the same 2.30 Å value.
VDW_RADII = dict(
    # Period 1
    H=1.20, He=1.40,
    # Period 2
    Li=1.82, Be=1.53, B=1.92, C=1.70, N=1.55, O=1.52, F=1.47, Ne=1.54,
    # Period 3
    Na=2.27, Mg=1.73, Al=1.84, Si=2.10, P=1.80, S=1.80, Cl=1.75, Ar=1.88,
    # Period 4
    K=2.75, Ca=2.31, Sc=2.30, Ti=2.15, V=2.05, Cr=2.05, Mn=2.05,
    Fe=2.05, Co=2.00, Ni=2.00, Cu=2.00, Zn=2.10, Ga=1.87, Ge=2.11,
    As=1.85, Se=1.90, Br=1.85, Kr=2.02,
    # Period 5
    Rb=3.03, Sr=2.49, Y=2.40, Zr=2.30, Nb=2.15, Mo=2.10, Tc=2.05,
    Ru=2.05, Rh=2.00, Pd=2.05, Ag=2.10, Cd=2.20, In=2.20, Sn=2.17,
    Sb=2.06, Te=2.06, I=1.98, Xe=2.16,
    # Period 6
    Cs=3.43, Ba=2.68, La=2.50, Ce=2.48, Pr=2.47, Nd=2.45, Pm=2.43,
    Sm=2.42, Eu=2.40, Gd=2.38, Tb=2.37, Dy=2.35, Ho=2.33, Er=2.32,
    Tm=2.30, Yb=2.28, Lu=2.27, Hf=2.25, Ta=2.20, W=2.10, Re=2.05,
    Os=2.00, Ir=2.00, Pt=2.05, Au=2.10, Hg=2.05, Tl=2.20, Pb=2.30,
    Bi=2.30, Po=2.00, At=2.00, Rn=2.00,
    # Period 7
    Fr=3.50, Ra=2.80, Ac=2.60, Th=2.40, Pa=2.30, U=2.30, Np=2.30,
    Pu=2.30, Am=2.30, Cm=2.30, Bk=2.30, Cf=2.30, Es=2.30, Fm=2.30,
    Md=2.30, No=2.30, Lr=2.30, Rf=2.30, Db=2.30, Sg=2.30, Bh=2.30,
    Hs=2.30, Mt=2.30, Ds=2.30, Rg=2.30, Cn=2.30, Nh=2.30, Fl=2.30,
    Mc=2.30, Lv=2.30, Ts=2.30, Og=2.30,
)
|
|
43
|
+
|
|
44
|
+
# Named marker colours for atom types in the density-map plot. Elements not
# listed here are drawn in 'gray' by the plotting code.
ELEMENT_COLORS = dict(
    # Common elements
    H='white', C='black', N='blue', O='red', F='green',
    Na='purple', Mg='pink', Al='gray', Si='yellow', P='orange',
    S='yellow', Cl='green', K='purple', Ca='gray', Fe='orange',
    Cu='orange', Zn='gray',
    # Additional common elements with colors
    Br='brown', I='purple', Li='purple', B='olive',
    He='cyan', Ne='cyan', Ar='cyan', Kr='cyan', Xe='cyan',
    Mn='gray', Co='blue', Ni='green', Pd='gray', Pt='gray',
    Au='gold', Hg='silver', Pb='darkgray', Ag='silver',
    Ti='gray', V='gray', Cr='gray', Zr='gray', Mo='gray',
    W='gray', U='green',
)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _extract_selected_positions(trajectory, selected_indices):
    """Collect the wrapped positions of the selected atoms from every frame.

    Each frame is wrapped into its cell in place (``Atoms.wrap``) before its
    selected rows are taken, so the frames in ``trajectory`` are modified.
    Rows are ordered frame by frame. The result is always a 2-D ``(n, 3)``
    array, even when the trajectory or the selection is empty, so that it can
    be passed straight to ``np.histogramdd``.
    """
    chunks = []
    for frame in trajectory:
        frame.wrap()
        chunks.append(frame.positions[selected_indices])
    if not chunks:
        return np.empty((0, 3))
    return np.concatenate(chunks, axis=0).reshape(-1, 3)


def _add_projection_surfaces(fig, H, x_centers, y_centers, z_centers, normalize,
                             colorscale, projection_opacity, projection_offset):
    """Add XY (floor), YZ (left wall) and XZ (back wall) projection surfaces.

    Each projection is the histogram ``H`` summed along the remaining axis.
    When ``normalize`` is True, all three are divided by their common maximum
    (only if that maximum is positive). Surfaces are appended to ``fig`` in the
    order XY, YZ, XZ. Returns the projections as ``(xy, yz, xz)``.
    """
    xy_projection = np.sum(H, axis=2)  # Sum along z-axis
    xz_projection = np.sum(H, axis=1)  # Sum along y-axis
    yz_projection = np.sum(H, axis=0)  # Sum along x-axis

    if normalize:
        max_val = max(xy_projection.max(), xz_projection.max(), yz_projection.max())
        if max_val > 0:
            xy_projection = xy_projection / max_val
            xz_projection = xz_projection / max_val
            yz_projection = yz_projection / max_val

    # Place each projection plane just outside the density volume.
    z_floor = np.full((len(x_centers), len(y_centers)), np.min(z_centers) - projection_offset)
    x_wall = np.full((len(y_centers), len(z_centers)), np.min(x_centers) - projection_offset)
    y_wall = np.full((len(x_centers), len(z_centers)), np.max(y_centers) + projection_offset)

    xy_text = [[f'x: {x_centers[i]:.2f}<br>y: {y_centers[j]:.2f}<br>Density: {xy_projection[i,j]:.2f}'
                for j in range(len(y_centers))] for i in range(len(x_centers))]
    yz_text = [[f'y: {y_centers[i]:.2f}<br>z: {z_centers[j]:.2f}<br>Density: {yz_projection[i,j]:.2f}'
                for j in range(len(z_centers))] for i in range(len(y_centers))]
    xz_text = [[f'x: {x_centers[i]:.2f}<br>z: {z_centers[j]:.2f}<br>Density: {xz_projection[i,j]:.2f}'
                for j in range(len(z_centers))] for i in range(len(x_centers))]

    # XY projection (floor)
    xx, yy = np.meshgrid(x_centers, y_centers, indexing='ij')
    fig.add_trace(go.Surface(
        z=z_floor,
        x=xx,
        y=yy,
        surfacecolor=xy_projection,
        colorscale=colorscale,
        opacity=projection_opacity,
        showscale=False,
        text=xy_text,
        hoverinfo='text',
        name='XY Projection (Floor)'
    ))

    # YZ projection (left wall)
    yy, zz = np.meshgrid(y_centers, z_centers, indexing='ij')
    fig.add_trace(go.Surface(
        z=zz,
        x=x_wall,
        y=yy,
        surfacecolor=yz_projection,
        colorscale=colorscale,
        opacity=projection_opacity,
        showscale=False,
        text=yz_text,
        hoverinfo='text',
        name='YZ Projection (Left Wall)'
    ))

    # XZ projection (back wall)
    xx, zz = np.meshgrid(x_centers, z_centers, indexing='ij')
    fig.add_trace(go.Surface(
        z=zz,
        x=xx,
        y=y_wall,
        surfacecolor=xz_projection,
        colorscale=colorscale,
        opacity=projection_opacity,
        showscale=False,
        text=xz_text,
        hoverinfo='text',
        name='XZ Projection (Back Wall)'
    ))

    return xy_projection, yz_projection, xz_projection


def _save_projection_images(projections, x_centers, y_centers, z_centers, plot_title,
                            colorscale, output_dir, output_basename):
    """Write each 2-D projection as a standalone 800x700 PNG heatmap via kaleido."""
    xy_projection, yz_projection, xz_projection = projections
    views = [
        ("XY", "xy", xy_projection, x_centers, y_centers, 'X (Å)', 'Y (Å)', "Top View"),
        ("YZ", "yz", yz_projection, y_centers, z_centers, 'Y (Å)', 'Z (Å)', "Side View"),
        ("XZ", "xz", xz_projection, x_centers, z_centers, 'X (Å)', 'Z (Å)', "Front View"),
    ]
    for label, tag, projection, horiz, vert, x_title, y_title, view_name in views:
        heatmap_fig = go.Figure(data=go.Heatmap(
            z=projection.T,
            x=horiz,
            y=vert,
            colorscale=colorscale,
            colorbar=dict(
                title='Density',
                thickness=20
            )
        ))
        heatmap_fig.update_layout(
            title=f"{plot_title} - {label} Projection ({view_name})",
            xaxis_title=x_title,
            yaxis_title=y_title,
            width=800,
            height=700
        )
        image_path = os.path.join(output_dir, f"{output_basename}_{tag}_projection.png")
        heatmap_fig.write_image(image_path, scale=1, width=800, height=700, engine="kaleido")
        print(f"{label} projection saved to: {image_path}")


def _add_static_structure(fig, static_frame, indices_to_omit, atom_size_scale):
    """Add one Scatter3d marker trace per element for the reference frame's atoms.

    Atoms whose index appears in ``indices_to_omit`` are skipped. Marker colour
    comes from ELEMENT_COLORS (fallback 'gray'). Marker size is the VDW_RADII
    entry (fallback 1.0) times ``atom_size_scale``.
    """
    symbols = static_frame.get_chemical_symbols()
    positions = static_frame.positions
    # A set gives O(1) membership tests instead of scanning the array per atom.
    omit = set(np.ravel(indices_to_omit).tolist())

    # dict.fromkeys keeps first-appearance order, so the legend order is
    # reproducible (iteration order of a set of strings is not).
    for element in dict.fromkeys(symbols):
        element_indices = [i for i, symbol in enumerate(symbols)
                           if symbol == element and i not in omit]
        if not element_indices:
            continue

        element_positions = positions[element_indices]
        color = ELEMENT_COLORS.get(element, 'gray')
        size = VDW_RADII.get(element, 1.0) * atom_size_scale

        fig.add_trace(go.Scatter3d(
            x=element_positions[:, 0],
            y=element_positions[:, 1],
            z=element_positions[:, 2],
            mode='markers',
            marker=dict(
                size=[size] * len(element_indices),  # uniform size for this element
                color=color,
                opacity=0.8,
                line=dict(color='black', width=0.5)
            ),
            name=element
        ))


def _add_cell_box(fig, cell):
    """Draw the 12 edges of the unit cell as black line traces without legend entries."""
    corners = np.array([
        [0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0],
        [0, 0, 1], [1, 0, 1], [1, 1, 1], [0, 1, 1]
    ])
    box_coords = np.dot(corners, np.asarray(cell))
    box_edges = [(0, 1), (1, 2), (2, 3), (3, 0),
                 (4, 5), (5, 6), (6, 7), (7, 4),
                 (0, 4), (1, 5), (2, 6), (3, 7)]
    for i, j in box_edges:
        fig.add_trace(go.Scatter3d(
            x=[box_coords[i, 0], box_coords[j, 0]],
            y=[box_coords[i, 1], box_coords[j, 1]],
            z=[box_coords[i, 2], box_coords[j, 2]],
            mode='lines',
            line=dict(color='black', width=2),
            showlegend=False
        ))


def _projection_menu(n_traces):
    """Build the button menu that toggles the projection surfaces.

    Assumes trace 0 is the density volume and traces 1-3 are the XY, YZ and XZ
    surfaces. Any further traces (the cell-box edges) stay visible in every
    view, so each ``visible`` mask covers all ``n_traces`` traces.
    """
    extra = [True] * max(0, n_traces - 4)
    views = [
        ("All", [True, True, True, True]),
        ("Volume Only", [True, False, False, False]),
        ("XY Projection", [True, True, False, False]),
        ("YZ Projection", [True, False, True, False]),
        ("XZ Projection", [True, False, False, True]),
    ]
    buttons = [dict(args=[{"visible": mask + extra}], label=label, method="update")
               for label, mask in views]
    return dict(
        type="buttons",
        direction="right",
        buttons=buttons,
        pad={"r": 10, "t": 10},
        showactive=True,
        x=0.1,
        xanchor="left",
        y=1.1,
        yanchor="top"
    )


def create_density_map(
    traj_path: str,
    indices_path: str,
    frame_skip: int = 100,
    threshold: float = 0.05,
    absolute_threshold: bool = False,
    opacity: float = 0.2,
    atom_size_scale: float = 3.0,
    output_dir: str = ".",
    output_file: Optional[str] = None,
    colorscale: str = 'Plasma',
    plot_title: str = 'Density Distribution of Selected Atoms',
    nbins: int = 50,
    omit_static_indices: Optional[Union[str, List[int], np.ndarray]] = None,
    save_density: bool = False,
    density_output_file: Optional[str] = None,
    show_projections: bool = False,
    projection_opacity: float = 0.7,
    projection_offset: float = 2.0,
    save_projection_images: bool = False,
    projection_image_dpi: int = 300
) -> go.Figure:
    """
    Create a 3D visualization of atom density with molecular structure and save to HTML.

    This function analyzes the spatial distribution of selected atoms across a
    trajectory, creating a volumetric density map showing where specified atoms
    tend to be located over time.

    Parameters
    ----------
    traj_path : str
        Path to the ASE trajectory file
    indices_path : str
        Path to numpy file containing atom indices to analyze (flattened to 1-D)
    frame_skip : int, optional
        Only read every nth frame from trajectory (default: 100)
    threshold : float, optional
        Density value threshold, clamped to be >= 0 (default: 0.05)
    absolute_threshold : bool, optional
        If True, threshold is an absolute count value;
        If False (default), threshold is relative to maximum (0-1 scale)
    opacity : float, optional
        Transparency of the density visualization, clamped to [0, 1] (default: 0.2)
    atom_size_scale : float, optional
        Scale factor for atom sizes (default: 3.0)
    output_dir : str, optional
        Directory to save output files; created if missing (default: current directory)
    output_file : str, optional
        Filename for HTML output (default: "<trajectory name>_density_map.html")
    colorscale : str, optional
        Colorscale for density plot ('Plasma', 'Viridis', 'Blues', etc.)
    plot_title : str, optional
        Title for the visualization
    nbins : int, optional
        Number of bins for the density grid in each dimension (default: 50)
    omit_static_indices: str, list or np.ndarray, optional
        Indices of atoms to omit from the static structure visualization.
        Can be a path to a numpy file (like indices_path) or an array of indices.
        If None, the selected_indices (from indices_path) will be used.
    save_density : bool, optional
        If True, saves the density histogram, bin edges, cell, nbins and
        selected indices to an .npz file (default: False)
    density_output_file : str, optional
        Filename for saving density data (default: "<trajectory name>_density_data.npz")
    show_projections : bool, optional
        Whether to show 2D projections of the density (default: False).
        When enabled, the static atomic structure is not drawn.
    projection_opacity : float, optional
        Opacity of projection surfaces (default: 0.7)
    projection_offset : float, optional
        Distance to offset projections from the main volume (default: 2.0)
    save_projection_images : bool, optional
        Whether to save 2D projections as separate PNG files; only used when
        show_projections is True and requires the kaleido engine (default: False)
    projection_image_dpi : int, optional
        Currently not applied: projection PNGs are written at 800x700 px with
        scale=1. Kept for API compatibility (default: 300).

    Returns
    -------
    plotly.graph_objects.Figure
        The generated figure object

    Raises
    ------
    ValueError
        If nbins is smaller than 1.

    Notes
    -----
    The function creates a 3D histogram of selected atom positions across the
    trajectory and visualizes it as an isosurface volume plot, overlaid with
    the reference molecular structure and unit cell boundaries. The grid spans
    [0, a_xx] x [0, b_yy] x [0, c_zz], so it covers the cell exactly only for
    orthogonal cells; a warning is printed otherwise.
    """
    if nbins < 1:
        raise ValueError(f"nbins must be at least 1, got {nbins}")

    os.makedirs(output_dir, exist_ok=True)

    traj_basename = os.path.splitext(os.path.basename(traj_path))[0]
    if output_file is None:
        output_file = f"{traj_basename}_density_map.html"
    output_path = os.path.join(output_dir, output_file)

    threshold = max(0.0, threshold)  # Ensures threshold is at least 0
    opacity = max(0.0, min(1.0, opacity))  # Ensures opacity is between 0 and 1

    print(f"Loading trajectory from {traj_path} (using every {frame_skip}th frame)...")
    trajectory = read(traj_path, index=f"::{frame_skip}")
    # A flat integer array can be used directly for fancy indexing of positions.
    selected_indices = np.asarray(np.load(indices_path), dtype=int).ravel()
    print(f"Loaded {len(trajectory)} frames, {len(selected_indices)} selected indices")

    # Reference frame for the static structure and the cell geometry.
    static_frame = trajectory[0]
    cell = static_frame.get_cell()
    cell_matrix = np.asarray(cell)

    if not np.allclose(cell_matrix, np.diag(np.diag(cell_matrix))):
        print("Warning: cell is not orthogonal; the density grid only spans the diagonal "
              "cell lengths, so some atom positions may fall outside it.")
    xmin, ymin, zmin = 0.0, 0.0, 0.0
    xmax, ymax, zmax = cell_matrix[0, 0], cell_matrix[1, 1], cell_matrix[2, 2]

    print("Extracting selected atom positions from trajectory...")
    positions = _extract_selected_positions(trajectory, selected_indices)

    print("Creating density grid...")
    # nbins + 1 edges give exactly nbins bins per axis, as documented.
    edges_x = np.linspace(xmin, xmax, nbins + 1)
    edges_y = np.linspace(ymin, ymax, nbins + 1)
    edges_z = np.linspace(zmin, zmax, nbins + 1)
    H, edges = np.histogramdd(positions, bins=[edges_x, edges_y, edges_z])

    if save_density:
        if density_output_file is None:
            density_output_file = f"{traj_basename}_density_data.npz"
        density_path = os.path.join(output_dir, density_output_file)
        np.savez(
            density_path,
            density=H,                         # Actual density histogram
            edges=edges,                       # Bin edges
            cell=cell,                         # Unit cell
            nbins=nbins,                       # Number of bins
            selected_indices=selected_indices  # Which atoms were used
        )
        print(f"Density data saved to: {density_path}")

    x_centers = (edges[0][:-1] + edges[0][1:]) / 2
    y_centers = (edges[1][:-1] + edges[1][1:]) / 2
    z_centers = (edges[2][:-1] + edges[2][1:]) / 2
    X, Y, Z = np.meshgrid(x_centers, y_centers, z_centers, indexing='ij')

    h_max = H.max()
    if h_max == 0:
        print("Warning: no selected atom positions fall inside the density grid; "
              "the density volume will be empty.")

    # Choose whether to use absolute or relative thresholds
    if absolute_threshold:
        vol = H
        isomin = threshold  # absolute count threshold
        isomax = h_max      # maximum raw count
        threshold_type = "absolute"
    else:
        # Normalize to [0, 1]; an all-zero grid stays zero instead of becoming NaN.
        vol = H / h_max if h_max > 0 else np.zeros_like(H)
        isomin = threshold  # relative threshold (0-1)
        isomax = 1.0        # maximum normalized value
        threshold_type = "relative"

    print(f"Creating visualization with {threshold_type} threshold={threshold}, opacity={opacity}")
    print(f"Density range: {H.min()} to {H.max()} counts")

    fig = go.Figure()

    # Trace 0: the density volume (the projection menu relies on this position).
    fig.add_trace(go.Volume(
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        value=vol.flatten(),
        isomin=isomin,
        isomax=isomax,
        opacity=opacity,
        surface_count=20,
        colorscale=colorscale,
        caps=dict(x_show=False, y_show=False, z_show=False),
        name='Density Volume'
    ))

    if show_projections:
        print("Adding 2D projections of density data...")
        projections = _add_projection_surfaces(
            fig, H, x_centers, y_centers, z_centers,
            normalize=not absolute_threshold,
            colorscale=colorscale,
            projection_opacity=projection_opacity,
            projection_offset=projection_offset,
        )
        if save_projection_images:
            output_basename = os.path.splitext(output_file)[0]
            _save_projection_images(projections, x_centers, y_centers, z_centers,
                                    plot_title, colorscale, output_dir, output_basename)
        print("Skipping atom visualization since projections are enabled")
    else:
        indices_to_omit = selected_indices
        if omit_static_indices is not None:
            if isinstance(omit_static_indices, str):
                # Treat as path to numpy file
                indices_to_omit = np.load(omit_static_indices)
            else:
                indices_to_omit = np.array(omit_static_indices)
            print(f"Omitting {len(indices_to_omit)} custom indices from static structure visualization")
        _add_static_structure(fig, static_frame, indices_to_omit, atom_size_scale)

    if np.all(static_frame.get_pbc()):
        _add_cell_box(fig, cell)

    threshold_info = f"(Threshold: {threshold} {'counts' if absolute_threshold else 'relative'})"
    full_title = f"{plot_title} {threshold_info}"

    fig.update_layout(
        title=full_title,
        scene=dict(
            xaxis=dict(title='X (Å)'),
            yaxis=dict(title='Y (Å)'),
            zaxis=dict(title='Z (Å)'),
            aspectmode='data' if not show_projections else 'cube'
        ),
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01
        ),
        margin=dict(l=0, r=0, b=0, t=30)
    )

    if show_projections:
        # Masks are sized to every trace so the cell-box edges remain visible.
        fig.update_layout(updatemenus=[_projection_menu(len(fig.data))])

    fig.write_html(output_path)
    print(f"Visualization saved as HTML file: {output_path}")

    return fig
|
CRISP/py.typed
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# PEP 561 marker file. See https://peps.python.org/pep-0561/
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CRISP/simulation_utility/atomic_indices.py
|
|
3
|
+
|
|
4
|
+
This module extracts atomic indices from trajectory files
|
|
5
|
+
and identifies atom pairs within specified cutoff distances.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import os
|
|
9
|
+
import numpy as np
|
|
10
|
+
import ase.io
|
|
11
|
+
import csv
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def atom_indices(traj_path, frame_index=0, custom_cutoffs=None):
    """Extract atom indices by chemical symbol and find atom pairs within specified cutoffs.

    Parameters
    ----------
    traj_path : str
        Path to the trajectory file in any format supported by ASE
    frame_index : int, optional
        Index of the frame to analyze (default: 0)
    custom_cutoffs : dict, optional
        Dictionary with atom symbol pairs as keys and cutoff distances as values
        Example: {('Si', 'O'): 2.0, ('Al', 'O'): 2.1}

    Returns
    -------
    indices_by_symbol : dict
        Chemical symbol -> ascending list of atomic indices. Symbols appear in
        order of their first occurrence in the frame.
    dist_matrix : numpy.ndarray
        Distance matrix between all atoms from ``get_all_distances(mic=True)``
        (minimum-image convention)
    cutoff_indices : dict
        Symbol pair -> list of (idx1, idx2, distance) tuples for two *distinct*
        atoms closer than the cutoff (strict ``<``), ordered by idx1 then idx2.
        A pair whose symbols are not both present maps to an empty list.

    Raises
    ------
    ValueError
        If reading the frame or any processing step fails; the underlying
        exception is chained as ``__cause__``.
    """
    try:
        atoms = ase.io.read(traj_path, index=frame_index)
        dist_matrix = atoms.get_all_distances(mic=True)
        symbols = atoms.get_chemical_symbols()

        # dict.fromkeys keeps first-appearance order, so output is reproducible
        # across runs (iteration order of a set of strings is not).
        indices_by_symbol = {symbol: [] for symbol in dict.fromkeys(symbols)}
        for idx, symbol in enumerate(symbols):
            indices_by_symbol[symbol].append(idx)

        cutoff_indices = {}

        if custom_cutoffs:
            for pair, cutoff in custom_cutoffs.items():
                symbol1, symbol2 = pair
                pair_indices_distances = []
                if symbol1 in indices_by_symbol and symbol2 in indices_by_symbol:
                    rows = np.asarray(indices_by_symbol[symbol1])
                    cols = np.asarray(indices_by_symbol[symbol2])
                    sub = dist_matrix[np.ix_(rows, cols)]
                    # For same-symbol pairs, drop each atom paired with itself
                    # (distance 0 is always below a positive cutoff).
                    within = (sub < cutoff) & (rows[:, None] != cols[None, :])
                    # np.nonzero walks row-major: ascending idx1, then idx2.
                    for a, b in zip(*np.nonzero(within)):
                        pair_indices_distances.append(
                            (int(rows[a]), int(cols[b]), sub[a, b])
                        )
                cutoff_indices[pair] = pair_indices_distances

        return indices_by_symbol, dist_matrix, cutoff_indices

    except Exception as e:
        raise ValueError(
            f"Error processing atomic structure: {e}. Check if the format is supported by ASE."
        ) from e
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def run_atom_indices(traj_path, output_dir, frame_index=0, custom_cutoffs=None):
    """Run atom index extraction and save results to files.

    Parameters
    ----------
    traj_path : str
        Path to the trajectory file in any format supported by ASE
    output_dir : str
        Directory where output files will be saved (created if missing)
    frame_index : int, optional
        Index of the frame to analyze (default: 0). Must satisfy
        0 <= frame_index < number of frames.
    custom_cutoffs : dict, optional
        Dictionary with atom symbol pairs as keys and cutoff distances as values

    Returns
    -------
    None
        Results are saved to the specified output directory:
        - lengths.npy: Dictionary of number of atoms per element
        - {symbol}_indices.npy: Numpy array of atom indices for each element
        - cutoff/{symbol1}-{symbol2}_cutoff.csv: CSV files with atom pairs within cutoff

    Raises
    ------
    ValueError
        If frame_index is out of range or the trajectory cannot be read
        (other read errors are wrapped and chained as ``__cause__``).
    """
    try:
        try:
            frames = ase.io.read(traj_path, index=":")
            traj_length = len(frames) if isinstance(frames, list) else 1
            # Only the frame count is needed here; release the frames before
            # atom_indices re-reads the file.
            del frames
        except TypeError:
            # NOTE(review): presumably for readers that reject a slice index;
            # treat the file as a single frame.
            ase.io.read(traj_path)
            traj_length = 1

        if frame_index < 0 or frame_index >= traj_length:
            raise ValueError(
                f"Error: Frame index {frame_index} is out of range. "
                f"Valid range is 0 to {traj_length-1}."
            )

        print(f"Analyzing frame with index {frame_index} (out of {traj_length} frames)")

    except ValueError:
        raise

    except Exception as e:
        raise ValueError(
            f"Error reading trajectory: {e}. Check if the format is supported by ASE."
        ) from e

    indices, _dist_matrix, cutoff_indices = atom_indices(traj_path, frame_index, custom_cutoffs)

    os.makedirs(output_dir, exist_ok=True)

    lengths = {symbol: len(symbol_indices) for symbol, symbol_indices in indices.items()}
    np.save(os.path.join(output_dir, "lengths.npy"), lengths)

    for symbol, data in indices.items():
        np.save(os.path.join(output_dir, f"{symbol}_indices.npy"), data)
        print(f"Length of {symbol} indices: {len(data)}")

    print("Outputs saved.")

    cutoff_folder = os.path.join(output_dir, "cutoff")
    os.makedirs(cutoff_folder, exist_ok=True)

    for (symbol1, symbol2), pair_indices_distances in cutoff_indices.items():
        filepath = os.path.join(cutoff_folder, f"{symbol1}-{symbol2}_cutoff.csv")
        with open(filepath, mode="w", newline="", encoding="utf-8") as file:
            writer = csv.writer(file)
            writer.writerow([f"{symbol1} index", f"{symbol2} index", "distance"])
            writer.writerows(pair_indices_distances)
        print(f"Saved cutoff indices for {symbol1}-{symbol2} to {filepath}")
|