crisp-ase 1.0.0.post0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of crisp-ase might be problematic. Click here for more details.

@@ -0,0 +1,527 @@
1
+ """
2
+ CRISP/data_analysis/volumetric_density.py
3
+
4
+ This module performs volumetric density distribution analysis on specific atoms
5
+ in molecular dynamics trajectory data, creating 3D visualizations of atom density maps.
6
+ """
7
+
8
+ import numpy as np
9
+ import plotly.graph_objects as go
10
+ from ase.io import read
11
+ from typing import Optional, Dict, List, Union
12
+ import os
13
+
14
# Van der Waals radii in Ångström, keyed by element symbol (H through Og).
VDW_RADII = {
    # Period 1
    "H": 1.20, "He": 1.40,
    # Period 2
    "Li": 1.82, "Be": 1.53, "B": 1.92, "C": 1.70,
    "N": 1.55, "O": 1.52, "F": 1.47, "Ne": 1.54,
    # Period 3
    "Na": 2.27, "Mg": 1.73, "Al": 1.84, "Si": 2.10,
    "P": 1.80, "S": 1.80, "Cl": 1.75, "Ar": 1.88,
    # Period 4
    "K": 2.75, "Ca": 2.31, "Sc": 2.30, "Ti": 2.15, "V": 2.05, "Cr": 2.05,
    "Mn": 2.05, "Fe": 2.05, "Co": 2.00, "Ni": 2.00, "Cu": 2.00, "Zn": 2.10,
    "Ga": 1.87, "Ge": 2.11, "As": 1.85, "Se": 1.90, "Br": 1.85, "Kr": 2.02,
    # Period 5
    "Rb": 3.03, "Sr": 2.49, "Y": 2.40, "Zr": 2.30, "Nb": 2.15, "Mo": 2.10,
    "Tc": 2.05, "Ru": 2.05, "Rh": 2.00, "Pd": 2.05, "Ag": 2.10, "Cd": 2.20,
    "In": 2.20, "Sn": 2.17, "Sb": 2.06, "Te": 2.06, "I": 1.98, "Xe": 2.16,
    # Period 6
    "Cs": 3.43, "Ba": 2.68, "La": 2.50, "Ce": 2.48, "Pr": 2.47, "Nd": 2.45,
    "Pm": 2.43, "Sm": 2.42, "Eu": 2.40, "Gd": 2.38, "Tb": 2.37, "Dy": 2.35,
    "Ho": 2.33, "Er": 2.32, "Tm": 2.30, "Yb": 2.28, "Lu": 2.27, "Hf": 2.25,
    "Ta": 2.20, "W": 2.10, "Re": 2.05, "Os": 2.00, "Ir": 2.00, "Pt": 2.05,
    "Au": 2.10, "Hg": 2.05, "Tl": 2.20, "Pb": 2.30, "Bi": 2.30, "Po": 2.00,
    "At": 2.00, "Rn": 2.00,
    # Period 7
    "Fr": 3.50, "Ra": 2.80, "Ac": 2.60, "Th": 2.40, "Pa": 2.30, "U": 2.30,
    "Np": 2.30, "Pu": 2.30, "Am": 2.30, "Cm": 2.30, "Bk": 2.30, "Cf": 2.30,
    "Es": 2.30, "Fm": 2.30, "Md": 2.30, "No": 2.30, "Lr": 2.30, "Rf": 2.30,
    "Db": 2.30, "Sg": 2.30, "Bh": 2.30, "Hs": 2.30, "Mt": 2.30, "Ds": 2.30,
    "Rg": 2.30, "Cn": 2.30, "Nh": 2.30, "Fl": 2.30, "Mc": 2.30, "Lv": 2.30,
    "Ts": 2.30, "Og": 2.30,
}
43
+
44
# Default marker colour per element symbol; symbols missing here are expected
# to be given a fallback colour by the code that looks them up.
ELEMENT_COLORS = {
    # Common elements
    "H": "white",
    "C": "black",
    "N": "blue",
    "O": "red",
    "F": "green",
    "Na": "purple",
    "Mg": "pink",
    "Al": "gray",
    "Si": "yellow",
    "P": "orange",
    "S": "yellow",
    "Cl": "green",
    "K": "purple",
    "Ca": "gray",
    "Fe": "orange",
    "Cu": "orange",
    "Zn": "gray",
    # Additional common elements with colors
    "Br": "brown",
    "I": "purple",
    "Li": "purple",
    "B": "olive",
    "He": "cyan",
    "Ne": "cyan",
    "Ar": "cyan",
    "Kr": "cyan",
    "Xe": "cyan",
    "Mn": "gray",
    "Co": "blue",
    "Ni": "green",
    "Pd": "gray",
    "Pt": "gray",
    "Au": "gold",
    "Hg": "silver",
    "Pb": "darkgray",
    "Ag": "silver",
    "Ti": "gray",
    "V": "gray",
    "Cr": "gray",
    "Zr": "gray",
    "Mo": "gray",
    "W": "gray",
    "U": "green",
}
59
+
60
+
61
def _projection_hover_text(row_label, row_centers, col_label, col_centers, values):
    """Return the nested list of hover labels for one 2D density projection."""
    return [
        [
            f'{row_label}: {row:.2f}<br>{col_label}: {col:.2f}<br>Density: {values[i, j]:.2f}'
            for j, col in enumerate(col_centers)
        ]
        for i, row in enumerate(row_centers)
    ]


def _add_projection_surface(fig, name, x, y, z, values, hover_text, colorscale, opacity):
    """Add one flat surface trace coloured by a 2D density projection."""
    fig.add_trace(go.Surface(
        z=z,
        x=x,
        y=y,
        surfacecolor=values,
        colorscale=colorscale,
        opacity=opacity,
        showscale=False,
        text=hover_text,
        hoverinfo='text',
        name=name
    ))


def _save_projection_image(values, h_centers, v_centers, h_title, v_title,
                           title, path, colorscale):
    """Write one 2D projection to ``path`` as a static heatmap image."""
    heatmap_fig = go.Figure(data=go.Heatmap(
        # Heatmap rows run along the vertical axis, hence the transpose.
        z=values.T,
        x=h_centers,
        y=v_centers,
        colorscale=colorscale,
        colorbar=dict(
            title='Density',
            thickness=20
        )
    ))
    heatmap_fig.update_layout(
        title=title,
        xaxis_title=h_title,
        yaxis_title=v_title,
        width=800,
        height=700
    )
    heatmap_fig.write_image(path, scale=1, width=800, height=700, engine="kaleido")


def _add_static_atoms(fig, static_frame, indices_to_omit, atom_size_scale):
    """Add one marker trace per element for the atoms of the reference frame."""
    symbols = static_frame.get_chemical_symbols()
    coordinates = static_frame.positions

    # A set makes the per-atom membership test O(1) instead of scanning an array.
    omitted = set(np.asarray(indices_to_omit).ravel().tolist())

    # Group the kept atoms by element, in order of first appearance so the
    # legend order is reproducible between runs.
    kept_by_element = {}
    for atom_index, symbol in enumerate(symbols):
        if atom_index not in omitted:
            kept_by_element.setdefault(symbol, []).append(atom_index)

    for element, element_indices in kept_by_element.items():
        element_positions = coordinates[element_indices]

        # Marker size follows the van der Waals radius, scaled by the caller.
        size = VDW_RADII.get(element, 1.0) * atom_size_scale

        fig.add_trace(go.Scatter3d(
            x=element_positions[:, 0],
            y=element_positions[:, 1],
            z=element_positions[:, 2],
            mode='markers',
            marker=dict(
                size=[size] * len(element_indices),
                color=ELEMENT_COLORS.get(element, 'gray'),
                opacity=0.8,
                line=dict(color='black', width=0.5)
            ),
            name=element
        ))


def _add_cell_box(fig, cell):
    """Draw the twelve edges of the unit cell as black line traces."""
    corners = np.array([
        [0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0],
        [0, 0, 1], [1, 0, 1], [1, 1, 1], [0, 1, 1]
    ])
    # Fractional corners -> Cartesian coordinates.
    box_coords = np.dot(corners, cell)

    box_edges = [(0, 1), (1, 2), (2, 3), (3, 0),
                 (4, 5), (5, 6), (6, 7), (7, 4),
                 (0, 4), (1, 5), (2, 6), (3, 7)]

    for start, end in box_edges:
        fig.add_trace(go.Scatter3d(
            x=[box_coords[start, 0], box_coords[end, 0]],
            y=[box_coords[start, 1], box_coords[end, 1]],
            z=[box_coords[start, 2], box_coords[end, 2]],
            mode='lines',
            line=dict(color='black', width=2),
            showlegend=False
        ))


def _projection_buttons(n_traces):
    """Build the visibility toggle buttons for the volume and its projections."""
    # Traces 0-3 are volume, XY, YZ and XZ. Any further traces (the cell box
    # edges) get an explicit True so they stay visible, instead of depending on
    # how a too-short ``visible`` list is expanded across the remaining traces.
    always_on = [True] * max(0, n_traces - 4)
    presets = [
        ("All", [True, True, True, True]),
        ("Volume Only", [True, False, False, False]),
        ("XY Projection", [True, True, False, False]),
        ("YZ Projection", [True, False, True, False]),
        ("XZ Projection", [True, False, False, True]),
    ]
    return [
        dict(args=[{"visible": flags + always_on}], label=label, method="update")
        for label, flags in presets
    ]


def create_density_map(
    traj_path: str,
    indices_path: str,
    frame_skip: int = 100,
    threshold: float = 0.05,
    absolute_threshold: bool = False,
    opacity: float = 0.2,
    atom_size_scale: float = 3.0,
    output_dir: str = ".",
    output_file: Optional[str] = None,
    colorscale: str = 'Plasma',
    plot_title: str = 'Density Distribution of Selected Atoms',
    nbins: int = 50,
    omit_static_indices: Optional[Union[str, List[int], np.ndarray]] = None,
    save_density: bool = False,
    density_output_file: Optional[str] = None,
    show_projections: bool = False,
    projection_opacity: float = 0.7,
    projection_offset: float = 2.0,
    save_projection_images: bool = False,
    projection_image_dpi: int = 300
) -> go.Figure:
    """
    Create a 3D visualization of atom density with molecular structure and save to HTML.

    Positions of the selected atoms are collected from every ``frame_skip``-th
    frame, binned into a 3D histogram spanning the cell of the first frame read,
    and rendered as a Plotly volume. Depending on ``show_projections`` the figure
    additionally shows either the static atoms of that first frame or three 2D
    projections of the density.

    Parameters
    ----------
    traj_path : str
        Path to the ASE trajectory file
    indices_path : str
        Path to numpy file containing atom indices to analyze
    frame_skip : int, optional
        Only read every nth frame from trajectory (default: 100)
    threshold : float, optional
        Density value threshold (default: 0.05); negative values are clamped to 0
    absolute_threshold : bool, optional
        If True, threshold is an absolute count value;
        If False (default), threshold is relative to maximum (0-1 scale)
    opacity : float, optional
        Transparency of the density visualization (default: 0.2); clamped to 0.0-1.0
    atom_size_scale : float, optional
        Scale factor applied to the van der Waals radius to get marker sizes
        (default: 3.0)
    output_dir : str, optional
        Directory to save output files (default: current directory); created
        if it does not exist
    output_file : str, optional
        Filename for HTML output (default: auto-generated from trajectory name)
    colorscale : str, optional
        Colorscale for density plot ('Plasma', 'Viridis', 'Blues', etc.)
    plot_title : str, optional
        Title for the visualization
    nbins : int, optional
        Number of bins for density grid in each dimension (default: 50).
        ``nbins + 1`` evenly spaced edges are placed along each axis.
    omit_static_indices : str, list or np.ndarray, optional
        Indices of atoms to omit from the static structure visualization.
        Can be a path to a numpy file (like indices_path) or an array of indices.
        If None, the selected indices (from indices_path) are omitted.
    save_density : bool, optional
        If True, saves the density data to an ``.npz`` file (default: False)
    density_output_file : str, optional
        Filename for saving density data (default: auto-generated from trajectory name)
    show_projections : bool, optional
        Whether to show 2D projections of the density (default: False).
        When enabled, the static atoms are not drawn.
    projection_opacity : float, optional
        Opacity of projection surfaces (default: 0.7)
    projection_offset : float, optional
        Distance to offset projections from the main volume (default: 2.0)
    save_projection_images : bool, optional
        Whether to save 2D projections as separate PNG files (default: False).
        Only takes effect together with ``show_projections``.
    projection_image_dpi : int, optional
        Accepted for interface compatibility; the images are currently always
        written at 800x700 pixels with ``scale=1`` and this value is not used.

    Returns
    -------
    plotly.graph_objects.Figure
        The generated figure object

    Raises
    ------
    ValueError
        If ``nbins`` is smaller than 1.

    Notes
    -----
    The histogram grid runs from 0 to the diagonal cell entries, i.e. it assumes
    an orthogonal cell. Frames are wrapped into the cell in place before their
    positions are sampled, so the static structure is drawn from the wrapped
    first frame.
    """
    if nbins < 1:
        raise ValueError(f"nbins must be at least 1, got {nbins}")

    os.makedirs(output_dir, exist_ok=True)

    traj_basename = os.path.splitext(os.path.basename(traj_path))[0]
    if output_file is None:
        output_file = f"{traj_basename}_density_map.html"

    output_path = os.path.join(output_dir, output_file)

    threshold = max(0.0, threshold)  # Ensures threshold is at least 0
    opacity = max(0.0, min(1.0, opacity))  # Ensures opacity is between 0 and 1

    print(f"Loading trajectory from {traj_path} (using every {frame_skip}th frame)...")
    trajectory = read(traj_path, index=f"::{frame_skip}")
    selected_indices = np.load(indices_path)
    print(f"Loaded {len(trajectory)} frames, {len(selected_indices)} selected indices")

    # Get a reference frame for the static structure
    static_frame = trajectory[0]
    cell = static_frame.get_cell()

    # Define cell boundaries (assuming orthogonal cell)
    xmin, ymin, zmin = 0.0, 0.0, 0.0
    xmax, ymax, zmax = cell[0, 0], cell[1, 1], cell[2, 2]

    print("Extracting selected atom positions from trajectory...")
    atom_selection = np.asarray(selected_indices, dtype=int).ravel()
    per_frame_positions = []
    for frame in trajectory:
        frame.wrap()
        per_frame_positions.append(frame.positions[atom_selection])
    # Flatten (frames, atoms, 3) into one (samples, 3) table; the reshape also
    # keeps an empty selection well-formed for histogramdd.
    sampled_positions = np.array(per_frame_positions).reshape(-1, 3)

    # Create a 3D grid using np.histogramdd
    print("Creating density grid...")
    # n bins need n + 1 edges.
    edges_x = np.linspace(xmin, xmax, nbins + 1)
    edges_y = np.linspace(ymin, ymax, nbins + 1)
    edges_z = np.linspace(zmin, zmax, nbins + 1)
    H, hist_edges = np.histogramdd(sampled_positions, bins=[edges_x, edges_y, edges_z])

    if save_density:
        if density_output_file is None:
            density_output_file = f"{traj_basename}_density_data.npz"

        density_path = os.path.join(output_dir, density_output_file)

        np.savez(
            density_path,
            density=H,                          # Raw count histogram
            edges=hist_edges,                   # Bin edges
            cell=cell,                          # Unit cell
            nbins=nbins,                        # Number of bins
            selected_indices=selected_indices   # Which atoms were used
        )

        print(f"Density data saved to: {density_path}")

    x_centers = (hist_edges[0][:-1] + hist_edges[0][1:]) / 2
    y_centers = (hist_edges[1][:-1] + hist_edges[1][1:]) / 2
    z_centers = (hist_edges[2][:-1] + hist_edges[2][1:]) / 2
    X, Y, Z = np.meshgrid(x_centers, y_centers, z_centers, indexing='ij')

    # Choose whether to use absolute or relative thresholds
    peak_count = H.max()
    if absolute_threshold:
        vol = H
        isomin = threshold   # absolute count threshold
        isomax = peak_count  # maximum raw count
        threshold_type = "absolute"
    else:
        # Normalize for relative thresholds; an all-zero histogram is left as
        # is because dividing by a zero peak would fill the volume with NaN.
        vol = H / peak_count if peak_count > 0 else H
        isomin = threshold  # relative threshold (0-1)
        isomax = 1.0        # maximum normalized value
        threshold_type = "relative"

    print(f"Creating visualization with {threshold_type} threshold={threshold}, opacity={opacity}")
    print(f"Density range: {H.min()} to {H.max()} counts")

    fig = go.Figure()

    # Add volume trace for density visualization
    fig.add_trace(go.Volume(
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        value=vol.flatten(),
        isomin=isomin,
        isomax=isomax,
        opacity=opacity,
        surface_count=20,
        colorscale=colorscale,
        caps=dict(x_show=False, y_show=False, z_show=False),
        name='Density Volume'
    ))

    if show_projections:
        print("Adding 2D projections of density data...")

        # Calculate projection data by summing along each axis
        xy_projection = np.sum(H, axis=2)  # Sum along z-axis
        xz_projection = np.sum(H, axis=1)  # Sum along y-axis
        yz_projection = np.sum(H, axis=0)  # Sum along x-axis

        # Relative mode: scale all three by their common maximum so that the
        # projections remain comparable with each other.
        if not absolute_threshold:
            max_val = max(xy_projection.max(), xz_projection.max(), yz_projection.max())
            if max_val > 0:
                xy_projection = xy_projection / max_val
                xz_projection = xz_projection / max_val
                yz_projection = yz_projection / max_val

        # XY projection (floor), placed below the volume
        xx, yy = np.meshgrid(x_centers, y_centers, indexing='ij')
        _add_projection_surface(
            fig, 'XY Projection (Floor)',
            x=xx, y=yy,
            z=np.ones(xx.shape) * (np.min(z_centers) - projection_offset),
            values=xy_projection,
            hover_text=_projection_hover_text('x', x_centers, 'y', y_centers, xy_projection),
            colorscale=colorscale, opacity=projection_opacity
        )

        # YZ projection (left wall), placed before the low-x face
        yy, zz = np.meshgrid(y_centers, z_centers, indexing='ij')
        _add_projection_surface(
            fig, 'YZ Projection (Left Wall)',
            x=np.ones(yy.shape) * (np.min(x_centers) - projection_offset),
            y=yy, z=zz,
            values=yz_projection,
            hover_text=_projection_hover_text('y', y_centers, 'z', z_centers, yz_projection),
            colorscale=colorscale, opacity=projection_opacity
        )

        # XZ projection (back wall), placed beyond the high-y face
        xx, zz = np.meshgrid(x_centers, z_centers, indexing='ij')
        _add_projection_surface(
            fig, 'XZ Projection (Back Wall)',
            x=xx,
            y=np.ones(xx.shape) * (np.max(y_centers) + projection_offset),
            z=zz,
            values=xz_projection,
            hover_text=_projection_hover_text('x', x_centers, 'z', z_centers, xz_projection),
            colorscale=colorscale, opacity=projection_opacity
        )

        if save_projection_images:
            output_basename = os.path.splitext(output_file)[0]

            # Save XY projection (top-down view)
            xy_path = os.path.join(output_dir, f"{output_basename}_xy_projection.png")
            _save_projection_image(
                xy_projection, x_centers, y_centers, 'X (Å)', 'Y (Å)',
                f"{plot_title} - XY Projection (Top View)", xy_path, colorscale
            )
            print(f"XY projection saved to: {xy_path}")

            # Save YZ projection (side view)
            yz_path = os.path.join(output_dir, f"{output_basename}_yz_projection.png")
            _save_projection_image(
                yz_projection, y_centers, z_centers, 'Y (Å)', 'Z (Å)',
                f"{plot_title} - YZ Projection (Side View)", yz_path, colorscale
            )
            print(f"YZ projection saved to: {yz_path}")

            # Save XZ projection (front view)
            xz_path = os.path.join(output_dir, f"{output_basename}_xz_projection.png")
            _save_projection_image(
                xz_projection, x_centers, z_centers, 'X (Å)', 'Z (Å)',
                f"{plot_title} - XZ Projection (Front View)", xz_path, colorscale
            )
            print(f"XZ projection saved to: {xz_path}")

    # Only add atomic structure visualization if projections are not enabled
    if not show_projections:
        # By default hide the analysed atoms themselves from the static view.
        indices_to_omit = selected_indices

        if omit_static_indices is not None:
            if isinstance(omit_static_indices, str):
                # Treat as path to numpy file
                indices_to_omit = np.load(omit_static_indices)
            else:
                indices_to_omit = np.array(omit_static_indices)
            print(f"Omitting {len(indices_to_omit)} custom indices from static structure visualization")

        _add_static_atoms(fig, static_frame, indices_to_omit, atom_size_scale)
    else:
        print("Skipping atom visualization since projections are enabled")

    # The cell outline is only drawn for fully periodic systems.
    if np.all(static_frame.get_pbc()):
        _add_cell_box(fig, static_frame.get_cell())

    threshold_info = f"(Threshold: {threshold} {'counts' if absolute_threshold else 'relative'})"
    full_title = f"{plot_title} {threshold_info}"

    fig.update_layout(
        title=full_title,
        scene=dict(
            xaxis=dict(title='X (Å)'),
            yaxis=dict(title='Y (Å)'),
            zaxis=dict(title='Z (Å)'),
            aspectmode='data' if not show_projections else 'cube'
        ),
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01
        ),
        margin=dict(l=0, r=0, b=0, t=30)
    )

    if show_projections:
        fig.update_layout(
            updatemenus=[
                dict(
                    type="buttons",
                    direction="right",
                    buttons=_projection_buttons(len(fig.data)),
                    pad={"r": 10, "t": 10},
                    showactive=True,
                    x=0.1,
                    xanchor="left",
                    y=1.1,
                    yanchor="top"
                ),
            ]
        )

    fig.write_html(output_path)
    print(f"Visualization saved as HTML file: {output_path}")

    return fig
CRISP/py.typed ADDED
@@ -0,0 +1 @@
1
+ # PEP 561 marker file. See https://peps.python.org/pep-0561/
@@ -0,0 +1,9 @@
1
+ # CRISP/simulation_utility/__init__.py
2
+
3
+ from .atomic_indices import *
4
+ from .atomic_traj_linemap import *
5
+ from .error_analysis import *
6
+ from .interatomic_distances import *
7
+ from .subsampling import *
8
+
9
+
@@ -0,0 +1,144 @@
1
+ """
2
+ CRISP/simulation_utility/atomic_indices.py
3
+
4
+ This module extracts atomic indices from trajectory files
5
+ and identifies atom pairs within specified cutoff distances.
6
+ """
7
+
8
+ import os
9
+ import numpy as np
10
+ import ase.io
11
+ import csv
12
+
13
+
14
def atom_indices(traj_path, frame_index=0, custom_cutoffs=None):
    """Extract atom indices by chemical symbol and find atom pairs within specified cutoffs.

    Parameters
    ----------
    traj_path : str
        Path to the trajectory file in any format supported by ASE
    frame_index : int, optional
        Index of the frame to analyze (default: 0)
    custom_cutoffs : dict, optional
        Dictionary with atom symbol pairs as keys and cutoff distances as values
        Example: {('Si', 'O'): 2.0, ('Al', 'O'): 2.1}

    Returns
    -------
    indices_by_symbol : dict
        Dictionary with chemical symbols as keys and lists of atomic indices as
        values. Symbols appear in the order they are first met in the frame.
    dist_matrix : numpy.ndarray
        Distance matrix between all atoms, computed with the minimum-image
        convention (``mic=True``)
    cutoff_indices : dict
        Dictionary with atom symbol pairs as keys and lists of (idx1, idx2, distance)
        tuples for atoms closer than the cutoff (strictly less than). An atom is
        never paired with itself. For a pair of identical symbols both
        (i, j, d) and (j, i, d) are reported.

    Raises
    ------
    ValueError
        If anything fails while reading or processing the structure; the
        original exception is attached as the cause.
    """
    try:
        atoms = ase.io.read(traj_path, index=frame_index)
        dist_matrix = atoms.get_all_distances(mic=True)

        # Group atom indices by element, keeping first-appearance order so the
        # result does not depend on set iteration order.
        indices_by_symbol = {}
        for idx, symbol in enumerate(atoms.get_chemical_symbols()):
            indices_by_symbol.setdefault(symbol, []).append(idx)

        cutoff_indices = {}

        for pair, cutoff in (custom_cutoffs or {}).items():
            symbol1, symbol2 = pair
            pair_indices_distances = []
            if symbol1 in indices_by_symbol and symbol2 in indices_by_symbol:
                for idx1 in indices_by_symbol[symbol1]:
                    for idx2 in indices_by_symbol[symbol2]:
                        # Skip the zero-distance "pair" of an atom with itself,
                        # which would otherwise always pass the cutoff test
                        # when both symbols are the same.
                        if idx1 == idx2:
                            continue
                        distance = dist_matrix[idx1, idx2]
                        if distance < cutoff:
                            pair_indices_distances.append((idx1, idx2, distance))
            # Pairs naming a symbol absent from the frame still get an (empty) entry.
            cutoff_indices[pair] = pair_indices_distances

        return indices_by_symbol, dist_matrix, cutoff_indices

    except Exception as e:
        raise ValueError(
            f"Error processing atomic structure: {e}. Check if the format is supported by ASE."
        ) from e
68
+
69
+
70
def run_atom_indices(traj_path, output_dir, frame_index=0, custom_cutoffs=None):
    """Run atom index extraction and save results to files.

    The trajectory is first read in full only to count its frames and validate
    ``frame_index``; the analysis itself is delegated to ``atom_indices``.

    Parameters
    ----------
    traj_path : str
        Path to the trajectory file in any format supported by ASE
    output_dir : str
        Directory where output files will be saved (created if missing)
    frame_index : int, optional
        Index of the frame to analyze (default: 0); negative values are rejected
    custom_cutoffs : dict, optional
        Dictionary with atom symbol pairs as keys and cutoff distances as values

    Returns
    -------
    None
        Results are saved to the specified output directory:
        - lengths.npy: Dictionary of number of atoms per element. Being a dict,
          it is stored as a pickled object array, so reading it back needs
          ``np.load(..., allow_pickle=True)``.
        - {symbol}_indices.npy: Numpy array of atom indices for each element
        - cutoff/{symbol1}-{symbol2}_cutoff.csv: CSV files with atom pairs within cutoff
          (the ``cutoff`` folder is created even when there are no pairs)

    Raises
    ------
    ValueError
        If ``frame_index`` is out of range, or if the trajectory cannot be read.
    """
    try:
        try:
            traj = ase.io.read(traj_path, index=":")
            traj_length = len(traj) if isinstance(traj, list) else 1
        except TypeError:
            # Fallback: retry as a plain single-structure read and treat the
            # file as one frame.
            ase.io.read(traj_path)
            traj_length = 1

        # Check if frame_index is within valid range
        if frame_index < 0 or frame_index >= traj_length:
            raise ValueError(
                f"Error: Frame index {frame_index} is out of range. "
                f"Valid range is 0 to {traj_length-1}."
            )

        print(f"Analyzing frame with index {frame_index} (out of {traj_length} frames)")

    except ValueError:
        # Already the documented exception type: propagate with its traceback intact.
        raise

    except Exception as e:
        raise ValueError(
            f"Error reading trajectory: {e}. Check if the format is supported by ASE."
        ) from e

    indices, dist_matrix, cutoff_indices = atom_indices(traj_path, frame_index, custom_cutoffs)

    # exist_ok avoids the check-then-create race of testing os.path.exists first.
    os.makedirs(output_dir, exist_ok=True)

    lengths = {symbol: len(atom_list) for symbol, atom_list in indices.items()}
    np.save(os.path.join(output_dir, "lengths.npy"), lengths)

    for symbol, data in indices.items():
        np.save(os.path.join(output_dir, f"{symbol}_indices.npy"), data)
        print(f"Length of {symbol} indices: {len(data)}")

    print("Outputs saved.")

    cutoff_folder = os.path.join(output_dir, "cutoff")
    os.makedirs(cutoff_folder, exist_ok=True)

    for pair, pair_indices_distances in cutoff_indices.items():
        symbol1, symbol2 = pair
        filename = f"{symbol1}-{symbol2}_cutoff.csv"
        filepath = os.path.join(cutoff_folder, filename)
        # newline="" lets the csv module control line endings itself.
        with open(filepath, mode="w", newline="", encoding="utf-8") as file:
            writer = csv.writer(file)
            writer.writerow([f"{symbol1} index", f"{symbol2} index", "distance"])
            writer.writerows(pair_indices_distances)
        print(f"Saved cutoff indices for {symbol1}-{symbol2} to {filepath}")