crisp-ase 1.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- CRISP/__init__.py +99 -0
- CRISP/_version.py +1 -0
- CRISP/cli.py +41 -0
- CRISP/data_analysis/__init__.py +38 -0
- CRISP/data_analysis/clustering.py +838 -0
- CRISP/data_analysis/contact_coordination.py +915 -0
- CRISP/data_analysis/h_bond.py +772 -0
- CRISP/data_analysis/msd.py +1199 -0
- CRISP/data_analysis/prdf.py +404 -0
- CRISP/data_analysis/volumetric_atomic_density.py +527 -0
- CRISP/py.typed +1 -0
- CRISP/simulation_utility/__init__.py +31 -0
- CRISP/simulation_utility/atomic_indices.py +155 -0
- CRISP/simulation_utility/atomic_traj_linemap.py +278 -0
- CRISP/simulation_utility/error_analysis.py +254 -0
- CRISP/simulation_utility/interatomic_distances.py +200 -0
- CRISP/simulation_utility/subsampling.py +241 -0
- CRISP/tests/DataAnalysis/__init__.py +1 -0
- CRISP/tests/DataAnalysis/test_clustering_extended.py +212 -0
- CRISP/tests/DataAnalysis/test_contact_coordination.py +184 -0
- CRISP/tests/DataAnalysis/test_contact_coordination_extended.py +465 -0
- CRISP/tests/DataAnalysis/test_h_bond_complete.py +326 -0
- CRISP/tests/DataAnalysis/test_h_bond_extended.py +322 -0
- CRISP/tests/DataAnalysis/test_msd_complete.py +305 -0
- CRISP/tests/DataAnalysis/test_msd_extended.py +522 -0
- CRISP/tests/DataAnalysis/test_prdf.py +206 -0
- CRISP/tests/DataAnalysis/test_volumetric_atomic_density.py +463 -0
- CRISP/tests/SimulationUtility/__init__.py +1 -0
- CRISP/tests/SimulationUtility/test_atomic_traj_linemap.py +101 -0
- CRISP/tests/SimulationUtility/test_atomic_traj_linemap_extended.py +469 -0
- CRISP/tests/SimulationUtility/test_error_analysis_extended.py +151 -0
- CRISP/tests/SimulationUtility/test_interatomic_distances.py +223 -0
- CRISP/tests/SimulationUtility/test_subsampling.py +365 -0
- CRISP/tests/__init__.py +1 -0
- CRISP/tests/test_CRISP.py +28 -0
- CRISP/tests/test_cli.py +87 -0
- CRISP/tests/test_crisp_comprehensive.py +679 -0
- crisp_ase-1.1.2.dist-info/METADATA +141 -0
- crisp_ase-1.1.2.dist-info/RECORD +42 -0
- crisp_ase-1.1.2.dist-info/WHEEL +5 -0
- crisp_ase-1.1.2.dist-info/entry_points.txt +2 -0
- crisp_ase-1.1.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,838 @@
|
|
|
1
|
+
"""
|
|
2
|
+
CRISP/data_analysis/clustering.py
|
|
3
|
+
|
|
4
|
+
This module performs cluster analysis on molecular dynamics trajectory data,
|
|
5
|
+
using DBSCAN algorithm to identify atom clusters and their properties.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from ase.io import read
|
|
10
|
+
from sklearn.cluster import DBSCAN
|
|
11
|
+
from sklearn.metrics import silhouette_score
|
|
12
|
+
import plotly.graph_objects as go
|
|
13
|
+
import pickle
|
|
14
|
+
import csv
|
|
15
|
+
import matplotlib.pyplot as plt
|
|
16
|
+
import os
|
|
17
|
+
from typing import Union, List, Optional, Tuple, Dict, Any
|
|
18
|
+
|
|
19
|
+
__all__ = ['analyze_frame', 'analyze_trajectory']
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class analyze_frame:
|
|
23
|
+
"""
|
|
24
|
+
Analyze atomic structures using DBSCAN clustering algorithm.
|
|
25
|
+
|
|
26
|
+
Parameters
|
|
27
|
+
----------
|
|
28
|
+
traj_path : str
|
|
29
|
+
Path to trajectory file
|
|
30
|
+
atom_indices : np.ndarray
|
|
31
|
+
Array containing indices of atoms to analyze
|
|
32
|
+
threshold : float
|
|
33
|
+
DBSCAN eps parameter (distance threshold)
|
|
34
|
+
min_samples : int
|
|
35
|
+
DBSCAN min_samples parameter
|
|
36
|
+
metric : str, optional
|
|
37
|
+
Distance metric to use (default: 'precomputed')
|
|
38
|
+
custom_frame_index : int, optional
|
|
39
|
+
Specific frame to analyze (default: None, uses last frame)
|
|
40
|
+
"""
|
|
41
|
+
def __init__(
|
|
42
|
+
self,
|
|
43
|
+
traj_path: str,
|
|
44
|
+
atom_indices: Union[str, np.ndarray],
|
|
45
|
+
threshold: float,
|
|
46
|
+
min_samples: int,
|
|
47
|
+
metric: str = 'precomputed',
|
|
48
|
+
custom_frame_index: Optional[int] = None
|
|
49
|
+
) -> None:
|
|
50
|
+
self.traj_path = traj_path
|
|
51
|
+
if isinstance(atom_indices, str):
|
|
52
|
+
atom_indices = np.load(atom_indices)
|
|
53
|
+
self.atom_indices = atom_indices
|
|
54
|
+
self.threshold = threshold
|
|
55
|
+
self.min_samples = min_samples
|
|
56
|
+
self.metric = metric
|
|
57
|
+
self.custom_frame_index = custom_frame_index
|
|
58
|
+
self.labels = None
|
|
59
|
+
self.distance_matrix = None
|
|
60
|
+
|
|
61
|
+
def read_custom_frame(self):
|
|
62
|
+
"""
|
|
63
|
+
Read a specific frame or the last frame from the trajectory.
|
|
64
|
+
|
|
65
|
+
Returns
|
|
66
|
+
-------
|
|
67
|
+
ase.Atoms or None
|
|
68
|
+
Atomic structure or None if reading fails
|
|
69
|
+
"""
|
|
70
|
+
try:
|
|
71
|
+
if self.custom_frame_index is not None:
|
|
72
|
+
frame = read(self.traj_path, index=self.custom_frame_index)
|
|
73
|
+
else:
|
|
74
|
+
frame = read(self.traj_path, index='-1')
|
|
75
|
+
return frame
|
|
76
|
+
except Exception as e:
|
|
77
|
+
print(f"Error reading trajectory: {e}")
|
|
78
|
+
return None
|
|
79
|
+
|
|
80
|
+
def calculate_distance_matrix(self, atoms):
|
|
81
|
+
"""
|
|
82
|
+
Calculate a distance matrix with periodic boundary conditions.
|
|
83
|
+
|
|
84
|
+
Parameters
|
|
85
|
+
----------
|
|
86
|
+
atoms : ase.Atoms
|
|
87
|
+
Atomic structure
|
|
88
|
+
|
|
89
|
+
Returns
|
|
90
|
+
-------
|
|
91
|
+
Tuple[np.ndarray, np.ndarray]
|
|
92
|
+
Distance matrix and positions array
|
|
93
|
+
|
|
94
|
+
Raises
|
|
95
|
+
------
|
|
96
|
+
ValueError
|
|
97
|
+
If there are not enough atoms to form clusters
|
|
98
|
+
"""
|
|
99
|
+
positions = atoms.positions[self.atom_indices]
|
|
100
|
+
|
|
101
|
+
if len(self.atom_indices) < self.min_samples:
|
|
102
|
+
raise ValueError(f"Not enough atoms ({len(self.atom_indices)}) to form clusters with min_samples={self.min_samples}")
|
|
103
|
+
|
|
104
|
+
full_dm = atoms.get_all_distances(mic=True)
|
|
105
|
+
n_atoms = len(self.atom_indices)
|
|
106
|
+
self.distance_matrix = np.zeros((n_atoms, n_atoms))
|
|
107
|
+
|
|
108
|
+
for i, idx_i in enumerate(self.atom_indices):
|
|
109
|
+
for j, idx_j in enumerate(self.atom_indices):
|
|
110
|
+
self.distance_matrix[i, j] = full_dm[idx_i, idx_j]
|
|
111
|
+
|
|
112
|
+
return self.distance_matrix, positions
|
|
113
|
+
|
|
114
|
+
def find_clusters(self):
|
|
115
|
+
"""
|
|
116
|
+
Find clusters using DBSCAN algorithm.
|
|
117
|
+
|
|
118
|
+
Returns
|
|
119
|
+
-------
|
|
120
|
+
np.ndarray
|
|
121
|
+
Cluster labels for each input point
|
|
122
|
+
|
|
123
|
+
Raises
|
|
124
|
+
------
|
|
125
|
+
ValueError
|
|
126
|
+
If distance matrix has not been calculated
|
|
127
|
+
"""
|
|
128
|
+
if self.distance_matrix is None:
|
|
129
|
+
raise ValueError("Distance matrix must be calculated first")
|
|
130
|
+
|
|
131
|
+
db = DBSCAN(
|
|
132
|
+
eps=self.threshold,
|
|
133
|
+
min_samples=self.min_samples,
|
|
134
|
+
metric=self.metric
|
|
135
|
+
).fit(self.distance_matrix if self.metric == 'precomputed' else None)
|
|
136
|
+
|
|
137
|
+
self.labels = db.labels_
|
|
138
|
+
return self.labels
|
|
139
|
+
|
|
140
|
+
def analyze_structure(self, save_html_path=None, output_dir=None):
|
|
141
|
+
"""
|
|
142
|
+
Analyze the structure and find clusters.
|
|
143
|
+
|
|
144
|
+
Parameters
|
|
145
|
+
----------
|
|
146
|
+
save_html_path : str, optional
|
|
147
|
+
Path to save HTML visualization
|
|
148
|
+
output_dir : str, optional
|
|
149
|
+
Directory to save all results
|
|
150
|
+
|
|
151
|
+
Returns
|
|
152
|
+
-------
|
|
153
|
+
dict or None
|
|
154
|
+
Dictionary with analysis results or None if analysis fails
|
|
155
|
+
"""
|
|
156
|
+
frame = self.read_custom_frame()
|
|
157
|
+
if frame is None:
|
|
158
|
+
return None
|
|
159
|
+
|
|
160
|
+
if output_dir is None and save_html_path is not None:
|
|
161
|
+
output_dir = os.path.dirname(save_html_path)
|
|
162
|
+
if not output_dir:
|
|
163
|
+
output_dir = "clustering_results"
|
|
164
|
+
|
|
165
|
+
os.makedirs(output_dir, exist_ok=True)
|
|
166
|
+
print(f"\nSaving results to directory: {output_dir}")
|
|
167
|
+
|
|
168
|
+
if save_html_path is None:
|
|
169
|
+
base_name = os.path.splitext(os.path.basename(self.traj_path))[0]
|
|
170
|
+
save_html_path = os.path.join(output_dir, f"{base_name}_clusters.html")
|
|
171
|
+
|
|
172
|
+
distance_matrix, positions = self.calculate_distance_matrix(frame)
|
|
173
|
+
self.find_clusters()
|
|
174
|
+
|
|
175
|
+
# Silhouette score (excluding outliers)
|
|
176
|
+
silhouette_avg = calculate_silhouette_score(self.distance_matrix, self.labels)
|
|
177
|
+
|
|
178
|
+
create_html_visualization(
|
|
179
|
+
positions=positions,
|
|
180
|
+
labels=self.labels,
|
|
181
|
+
title='Interactive 3D Cluster Visualization',
|
|
182
|
+
save_path=save_html_path,
|
|
183
|
+
cell_dimensions=frame.cell.lengths()
|
|
184
|
+
)
|
|
185
|
+
|
|
186
|
+
# Extract cluster information
|
|
187
|
+
cluster_info = extract_cluster_info(self.labels, self.atom_indices)
|
|
188
|
+
num_clusters = cluster_info["num_clusters"]
|
|
189
|
+
outlier_count = cluster_info["outlier_count"]
|
|
190
|
+
avg_cluster_size = cluster_info["avg_cluster_size"]
|
|
191
|
+
cluster_to_original = cluster_info["cluster_to_original"]
|
|
192
|
+
|
|
193
|
+
print_cluster_summary(num_clusters, outlier_count, silhouette_avg, avg_cluster_size, cluster_to_original)
|
|
194
|
+
|
|
195
|
+
frame_info_path = os.path.join(output_dir, "frame_data.txt")
|
|
196
|
+
save_frame_info_to_file(
|
|
197
|
+
frame_info_path,
|
|
198
|
+
self.threshold,
|
|
199
|
+
self.min_samples,
|
|
200
|
+
num_clusters,
|
|
201
|
+
outlier_count,
|
|
202
|
+
silhouette_avg,
|
|
203
|
+
avg_cluster_size,
|
|
204
|
+
cluster_to_original,
|
|
205
|
+
self.labels,
|
|
206
|
+
self.atom_indices
|
|
207
|
+
)
|
|
208
|
+
|
|
209
|
+
pickle_path = os.path.join(output_dir, "single_frame_analysis.pkl")
|
|
210
|
+
result_data = {
|
|
211
|
+
"num_clusters": num_clusters,
|
|
212
|
+
"outlier_count": outlier_count,
|
|
213
|
+
"silhouette_avg": silhouette_avg,
|
|
214
|
+
"avg_cluster_size": avg_cluster_size,
|
|
215
|
+
"cluster_to_original": cluster_to_original,
|
|
216
|
+
"labels": self.labels,
|
|
217
|
+
"positions": positions,
|
|
218
|
+
"parameters": {
|
|
219
|
+
"threshold": self.threshold,
|
|
220
|
+
"min_samples": self.min_samples,
|
|
221
|
+
"trajectory": self.traj_path,
|
|
222
|
+
"frame_index": self.custom_frame_index
|
|
223
|
+
}
|
|
224
|
+
}
|
|
225
|
+
|
|
226
|
+
with open(pickle_path, 'wb') as f:
|
|
227
|
+
pickle.dump(result_data, f)
|
|
228
|
+
print(f"Full analysis data saved to: {pickle_path}")
|
|
229
|
+
|
|
230
|
+
return result_data
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def create_html_visualization(positions, labels, title, save_path, cell_dimensions=None):
|
|
234
|
+
"""
|
|
235
|
+
Create and save a 3D HTML visualization of clusters.
|
|
236
|
+
|
|
237
|
+
Parameters
|
|
238
|
+
----------
|
|
239
|
+
positions : np.ndarray
|
|
240
|
+
Array of atom positions
|
|
241
|
+
labels : np.ndarray
|
|
242
|
+
Array of cluster labels
|
|
243
|
+
title : str
|
|
244
|
+
Title for the visualization
|
|
245
|
+
save_path : str
|
|
246
|
+
Path to save the HTML file
|
|
247
|
+
cell_dimensions : np.ndarray, optional
|
|
248
|
+
Cell dimensions from the simulation [a, b, c]
|
|
249
|
+
"""
|
|
250
|
+
fig = go.Figure()
|
|
251
|
+
|
|
252
|
+
for label in np.unique(labels):
|
|
253
|
+
cluster_points = positions[labels == label]
|
|
254
|
+
label_name = "Outliers" if label == -1 else f'Cluster {label}'
|
|
255
|
+
marker_size = 5
|
|
256
|
+
marker_color = 'gray' if label == -1 else None
|
|
257
|
+
|
|
258
|
+
fig.add_trace(
|
|
259
|
+
go.Scatter3d(
|
|
260
|
+
x=cluster_points[:, 0],
|
|
261
|
+
y=cluster_points[:, 1],
|
|
262
|
+
z=cluster_points[:, 2],
|
|
263
|
+
mode='markers',
|
|
264
|
+
marker=dict(
|
|
265
|
+
size=marker_size,
|
|
266
|
+
color=marker_color
|
|
267
|
+
),
|
|
268
|
+
name=label_name
|
|
269
|
+
)
|
|
270
|
+
)
|
|
271
|
+
|
|
272
|
+
layout_args = {
|
|
273
|
+
"title": title,
|
|
274
|
+
"legend": dict(itemsizing='constant')
|
|
275
|
+
}
|
|
276
|
+
|
|
277
|
+
scene_dict = {
|
|
278
|
+
"xaxis": dict(title='X'),
|
|
279
|
+
"yaxis": dict(title='Y'),
|
|
280
|
+
"zaxis": dict(title='Z')
|
|
281
|
+
}
|
|
282
|
+
|
|
283
|
+
if cell_dimensions is not None:
|
|
284
|
+
scene_dict["xaxis"]["range"] = [0, cell_dimensions[0]]
|
|
285
|
+
scene_dict["yaxis"]["range"] = [0, cell_dimensions[1]]
|
|
286
|
+
scene_dict["zaxis"]["range"] = [0, cell_dimensions[2]]
|
|
287
|
+
|
|
288
|
+
vertices = [
|
|
289
|
+
[0, 0, 0], [cell_dimensions[0], 0, 0],
|
|
290
|
+
[cell_dimensions[0], cell_dimensions[1], 0], [0, cell_dimensions[1], 0],
|
|
291
|
+
[0, 0, cell_dimensions[2]], [cell_dimensions[0], 0, cell_dimensions[2]],
|
|
292
|
+
[cell_dimensions[0], cell_dimensions[1], cell_dimensions[2]], [0, cell_dimensions[1], cell_dimensions[2]]
|
|
293
|
+
]
|
|
294
|
+
|
|
295
|
+
# Define box edges
|
|
296
|
+
i, j, k = [], [], []
|
|
297
|
+
# Bottom face
|
|
298
|
+
i.extend([0, 1, 2, 3, 0])
|
|
299
|
+
j.extend([1, 2, 3, 0, 4])
|
|
300
|
+
k.extend([0, 0, 0, 0, 0])
|
|
301
|
+
# Top face
|
|
302
|
+
i.extend([4, 5, 6, 7, 4])
|
|
303
|
+
j.extend([5, 6, 7, 4, 0])
|
|
304
|
+
k.extend([0, 0, 0, 0, 0])
|
|
305
|
+
# Vertical edges
|
|
306
|
+
i.extend([1, 2, 3])
|
|
307
|
+
j.extend([5, 6, 7])
|
|
308
|
+
k.extend([0, 0, 0])
|
|
309
|
+
|
|
310
|
+
fig.add_trace(go.Scatter3d(
|
|
311
|
+
x=[vertices[idx][0] for idx in i],
|
|
312
|
+
y=[vertices[idx][1] for idx in j],
|
|
313
|
+
z=[vertices[idx][2] for idx in k],
|
|
314
|
+
mode='lines',
|
|
315
|
+
line=dict(color='black', width=2),
|
|
316
|
+
name='Unit Cell',
|
|
317
|
+
showlegend=False
|
|
318
|
+
))
|
|
319
|
+
|
|
320
|
+
layout_args["scene"] = scene_dict
|
|
321
|
+
fig.update_layout(**layout_args)
|
|
322
|
+
|
|
323
|
+
fig.write_html(save_path)
|
|
324
|
+
print(f"3D visualization saved to {save_path}")
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def calculate_silhouette_score(distance_matrix, labels):
|
|
328
|
+
"""
|
|
329
|
+
Calculate silhouette score, handling edge cases.
|
|
330
|
+
|
|
331
|
+
Parameters
|
|
332
|
+
----------
|
|
333
|
+
distance_matrix : np.ndarray
|
|
334
|
+
Distance matrix for points
|
|
335
|
+
labels : np.ndarray
|
|
336
|
+
Cluster labels
|
|
337
|
+
|
|
338
|
+
Returns
|
|
339
|
+
-------
|
|
340
|
+
float
|
|
341
|
+
Silhouette score or 0 if calculation fails
|
|
342
|
+
"""
|
|
343
|
+
try:
|
|
344
|
+
non_outlier_mask = labels != -1
|
|
345
|
+
if np.sum(non_outlier_mask) > 1:
|
|
346
|
+
# Extract the sub-matrix for non-outlier points
|
|
347
|
+
filtered_matrix = distance_matrix[np.ix_(non_outlier_mask, non_outlier_mask)]
|
|
348
|
+
filtered_labels = labels[non_outlier_mask]
|
|
349
|
+
return silhouette_score(filtered_matrix, filtered_labels, metric='precomputed')
|
|
350
|
+
return 0
|
|
351
|
+
except ValueError:
|
|
352
|
+
return 0
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def extract_cluster_info(labels, atom_indices):
|
|
356
|
+
"""
|
|
357
|
+
Extract cluster information from labels.
|
|
358
|
+
|
|
359
|
+
Parameters
|
|
360
|
+
----------
|
|
361
|
+
labels : np.ndarray
|
|
362
|
+
Cluster labels
|
|
363
|
+
atom_indices : np.ndarray
|
|
364
|
+
Original atom indices
|
|
365
|
+
|
|
366
|
+
Returns
|
|
367
|
+
-------
|
|
368
|
+
dict
|
|
369
|
+
Dictionary with cluster information
|
|
370
|
+
"""
|
|
371
|
+
cluster_indices = {}
|
|
372
|
+
cluster_sizes = {}
|
|
373
|
+
cluster_to_original = {}
|
|
374
|
+
|
|
375
|
+
for cluster_id in np.unique(labels):
|
|
376
|
+
if cluster_id != -1: # Only count actual clusters (not outliers)
|
|
377
|
+
cluster_indices[cluster_id] = np.where(labels == cluster_id)[0]
|
|
378
|
+
cluster_sizes[cluster_id] = len(cluster_indices[cluster_id])
|
|
379
|
+
cluster_to_original[cluster_id] = atom_indices[cluster_indices[cluster_id]]
|
|
380
|
+
|
|
381
|
+
outlier_count = np.sum(labels == -1)
|
|
382
|
+
num_clusters = len([label for label in np.unique(labels) if label != -1])
|
|
383
|
+
|
|
384
|
+
# Calculate average cluster size
|
|
385
|
+
avg_cluster_size = np.mean(list(cluster_sizes.values())) if cluster_sizes else 0
|
|
386
|
+
|
|
387
|
+
return {
|
|
388
|
+
"num_clusters": num_clusters,
|
|
389
|
+
"outlier_count": outlier_count,
|
|
390
|
+
"avg_cluster_size": avg_cluster_size,
|
|
391
|
+
"cluster_sizes": cluster_sizes,
|
|
392
|
+
"cluster_to_original": cluster_to_original
|
|
393
|
+
}
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
def print_cluster_summary(num_clusters, outlier_count, silhouette_avg, avg_cluster_size, cluster_to_original):
|
|
397
|
+
"""
|
|
398
|
+
Print a summary of clustering results.
|
|
399
|
+
|
|
400
|
+
Parameters
|
|
401
|
+
----------
|
|
402
|
+
num_clusters : int
|
|
403
|
+
Number of clusters found
|
|
404
|
+
outlier_count : int
|
|
405
|
+
Number of outliers
|
|
406
|
+
silhouette_avg : float
|
|
407
|
+
Average silhouette score
|
|
408
|
+
avg_cluster_size : float
|
|
409
|
+
Average cluster size
|
|
410
|
+
cluster_to_original : dict
|
|
411
|
+
Mapping from cluster IDs to original atom indices
|
|
412
|
+
|
|
413
|
+
Returns
|
|
414
|
+
-------
|
|
415
|
+
None
|
|
416
|
+
"""
|
|
417
|
+
print(f"\nNumber of Clusters: {num_clusters}")
|
|
418
|
+
print(f"Number of Outliers: {outlier_count}")
|
|
419
|
+
print(f"Silhouette Score: {silhouette_avg:.4f}")
|
|
420
|
+
print(f"Average Cluster Size: {avg_cluster_size:.2f}")
|
|
421
|
+
print("Cluster Information:")
|
|
422
|
+
|
|
423
|
+
for cluster_id, atoms in cluster_to_original.items():
|
|
424
|
+
print(f" Cluster {cluster_id}: {len(atoms)} points")
|
|
425
|
+
|
|
426
|
+
|
|
427
|
+
def save_frame_info_to_file(file_path, threshold, min_samples, num_clusters, outlier_count,
|
|
428
|
+
silhouette_avg, avg_cluster_size, cluster_to_original, labels, atom_indices):
|
|
429
|
+
"""
|
|
430
|
+
Save detailed frame information to a text file.
|
|
431
|
+
|
|
432
|
+
Parameters
|
|
433
|
+
----------
|
|
434
|
+
file_path : str
|
|
435
|
+
Path to save the text file
|
|
436
|
+
threshold : float
|
|
437
|
+
DBSCAN eps parameter
|
|
438
|
+
min_samples : int
|
|
439
|
+
DBSCAN min_samples parameter
|
|
440
|
+
num_clusters : int
|
|
441
|
+
Number of clusters found
|
|
442
|
+
outlier_count : int
|
|
443
|
+
Number of outliers
|
|
444
|
+
silhouette_avg : float
|
|
445
|
+
Average silhouette score
|
|
446
|
+
avg_cluster_size : float
|
|
447
|
+
Average cluster size
|
|
448
|
+
cluster_to_original : dict
|
|
449
|
+
Mapping from cluster IDs to original atom indices
|
|
450
|
+
labels : np.ndarray
|
|
451
|
+
Cluster labels
|
|
452
|
+
atom_indices : np.ndarray
|
|
453
|
+
Original atom indices
|
|
454
|
+
|
|
455
|
+
Returns
|
|
456
|
+
-------
|
|
457
|
+
None
|
|
458
|
+
"""
|
|
459
|
+
with open(file_path, 'w') as f:
|
|
460
|
+
f.write(f"DBSCAN Clustering Analysis Results\n")
|
|
461
|
+
f.write(f"================================\n\n")
|
|
462
|
+
f.write(f"Parameters:\n")
|
|
463
|
+
f.write(f" Threshold (eps): {threshold}\n")
|
|
464
|
+
f.write(f" Min Samples: {min_samples}\n\n")
|
|
465
|
+
f.write(f"Results:\n")
|
|
466
|
+
f.write(f" Number of Clusters: {num_clusters}\n")
|
|
467
|
+
f.write(f" Number of Outliers: {outlier_count}\n")
|
|
468
|
+
f.write(f" Silhouette Score: {silhouette_avg:.4f}\n")
|
|
469
|
+
f.write(f" Average Cluster Size: {avg_cluster_size:.2f}\n\n")
|
|
470
|
+
|
|
471
|
+
f.write(f"Detailed Cluster Information:\n")
|
|
472
|
+
for cluster_id, indices in cluster_to_original.items():
|
|
473
|
+
f.write(f" Cluster {cluster_id}: {len(indices)} points\n")
|
|
474
|
+
f.write(f" Original atom indices: {indices.tolist()}\n\n")
|
|
475
|
+
|
|
476
|
+
f.write(f"Outlier Information:\n")
|
|
477
|
+
outlier_indices = atom_indices[labels == -1]
|
|
478
|
+
f.write(f" {len(outlier_indices)} outliers\n")
|
|
479
|
+
if len(outlier_indices) > 0:
|
|
480
|
+
f.write(f" Original atom indices: {outlier_indices.tolist()}\n")
|
|
481
|
+
|
|
482
|
+
print(f"Detailed frame data saved to: {file_path}")
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
def analyze_trajectory(traj_path, indices_path, threshold, min_samples, frame_skip=10,
|
|
486
|
+
output_dir="clustering_results", save_html_visualizations=True):
|
|
487
|
+
"""
|
|
488
|
+
Analyze an entire trajectory with DBSCAN clustering.
|
|
489
|
+
|
|
490
|
+
Parameters
|
|
491
|
+
----------
|
|
492
|
+
traj_path : str
|
|
493
|
+
Path to trajectory file (supports any ASE-readable format like XYZ)
|
|
494
|
+
indices_path : Union[str, List[int], np.ndarray]
|
|
495
|
+
Either a path to numpy file containing atom indices to analyze,
|
|
496
|
+
or a direct list/array of atom indices
|
|
497
|
+
threshold : float
|
|
498
|
+
DBSCAN eps parameter (distance threshold)
|
|
499
|
+
min_samples : int
|
|
500
|
+
DBSCAN min_samples parameter
|
|
501
|
+
frame_skip : int, optional
|
|
502
|
+
Number of frames to skip (default: 10)
|
|
503
|
+
output_dir : str, optional
|
|
504
|
+
Directory to save output files (default: "clustering_results")
|
|
505
|
+
save_html_visualizations : bool, optional
|
|
506
|
+
Whether to save HTML visualizations for first and last frames (default: True)
|
|
507
|
+
|
|
508
|
+
Returns
|
|
509
|
+
-------
|
|
510
|
+
list
|
|
511
|
+
List of analysis results for each frame
|
|
512
|
+
"""
|
|
513
|
+
if isinstance(indices_path, str):
|
|
514
|
+
atom_indices = np.load(indices_path)
|
|
515
|
+
print(f"Loaded {len(atom_indices)} atoms for clustering from {indices_path}")
|
|
516
|
+
else:
|
|
517
|
+
atom_indices = np.array(indices_path)
|
|
518
|
+
print(f"Using {len(atom_indices)} directly provided atom indices for clustering")
|
|
519
|
+
|
|
520
|
+
os.makedirs(output_dir, exist_ok=True)
|
|
521
|
+
|
|
522
|
+
try:
|
|
523
|
+
print(f"Loading trajectory from {traj_path} (using every {frame_skip}th frame)...")
|
|
524
|
+
trajectory = read(traj_path, index=f'::{frame_skip}')
|
|
525
|
+
if not isinstance(trajectory, list):
|
|
526
|
+
trajectory = [trajectory]
|
|
527
|
+
print(f"Loaded {len(trajectory)} frames from trajectory")
|
|
528
|
+
except Exception as e:
|
|
529
|
+
print(f"Error reading trajectory: {e}")
|
|
530
|
+
return []
|
|
531
|
+
|
|
532
|
+
results = []
|
|
533
|
+
|
|
534
|
+
print(f"Analyzing {len(trajectory)} frames...")
|
|
535
|
+
|
|
536
|
+
for i, frame in enumerate(trajectory):
|
|
537
|
+
try:
|
|
538
|
+
frame_number = i * frame_skip
|
|
539
|
+
|
|
540
|
+
full_dm = frame.get_all_distances(mic=True)
|
|
541
|
+
n_atoms = len(atom_indices)
|
|
542
|
+
distance_matrix = np.zeros((n_atoms, n_atoms))
|
|
543
|
+
|
|
544
|
+
for i_local, idx_i in enumerate(atom_indices):
|
|
545
|
+
if idx_i >= len(frame):
|
|
546
|
+
print(f"Warning: Atom index {idx_i} out of range for frame with {len(frame)} atoms. Skipping.")
|
|
547
|
+
continue
|
|
548
|
+
for j_local, idx_j in enumerate(atom_indices):
|
|
549
|
+
if idx_j >= len(frame):
|
|
550
|
+
continue
|
|
551
|
+
distance_matrix[i_local, j_local] = full_dm[idx_i, idx_j]
|
|
552
|
+
|
|
553
|
+
db = DBSCAN(
|
|
554
|
+
eps=threshold,
|
|
555
|
+
min_samples=min_samples,
|
|
556
|
+
metric='precomputed'
|
|
557
|
+
).fit(distance_matrix)
|
|
558
|
+
|
|
559
|
+
labels = db.labels_
|
|
560
|
+
|
|
561
|
+
# Extract positions for visualization
|
|
562
|
+
positions = frame.positions[atom_indices]
|
|
563
|
+
|
|
564
|
+
# Calculate silhouette score
|
|
565
|
+
silhouette_avg = calculate_silhouette_score(distance_matrix, labels)
|
|
566
|
+
|
|
567
|
+
# Extract cluster information
|
|
568
|
+
cluster_info = extract_cluster_info(labels, atom_indices)
|
|
569
|
+
num_clusters = cluster_info["num_clusters"]
|
|
570
|
+
outlier_count = cluster_info["outlier_count"]
|
|
571
|
+
avg_cluster_size = cluster_info["avg_cluster_size"]
|
|
572
|
+
cluster_to_original = cluster_info["cluster_to_original"]
|
|
573
|
+
|
|
574
|
+
if save_html_visualizations and (i == 0 or i == len(trajectory) - 1):
|
|
575
|
+
frame_prefix = "first" if i == 0 else "last"
|
|
576
|
+
html_path = os.path.join(output_dir, f"{frame_prefix}_frame_clusters.html")
|
|
577
|
+
create_html_visualization(
|
|
578
|
+
positions=positions,
|
|
579
|
+
labels=labels,
|
|
580
|
+
title=f"Frame {frame_number} Clusters",
|
|
581
|
+
save_path=html_path,
|
|
582
|
+
cell_dimensions=frame.cell.lengths()
|
|
583
|
+
)
|
|
584
|
+
|
|
585
|
+
results.append([frame_number, num_clusters, outlier_count, silhouette_avg, avg_cluster_size])
|
|
586
|
+
|
|
587
|
+
except Exception as e:
|
|
588
|
+
print(f"Error processing frame {i}: {e}")
|
|
589
|
+
results.append([i * frame_skip, 0, 0, 0.0, 0.0])
|
|
590
|
+
|
|
591
|
+
print(f"Trajectory analysis complete: {len(results)} frames processed")
|
|
592
|
+
|
|
593
|
+
if not results:
|
|
594
|
+
print("Warning: No results were generated from trajectory analysis")
|
|
595
|
+
return []
|
|
596
|
+
|
|
597
|
+
return results
|
|
598
|
+
|
|
599
|
+
|
|
600
|
+
def save_analysis_results(analysis_results, output_dir="clustering_results", output_prefix="clustering_results"):
|
|
601
|
+
"""
|
|
602
|
+
Save analysis results to CSV, TXT, and PKL files in the specified output directory.
|
|
603
|
+
|
|
604
|
+
Parameters
|
|
605
|
+
----------
|
|
606
|
+
analysis_results : list
|
|
607
|
+
List of analysis results for each frame
|
|
608
|
+
output_dir : str, optional
|
|
609
|
+
Directory to save output files (default: "clustering_results")
|
|
610
|
+
output_prefix : str, optional
|
|
611
|
+
Prefix for output file names (default: "clustering_results")
|
|
612
|
+
|
|
613
|
+
Returns
|
|
614
|
+
-------
|
|
615
|
+
str
|
|
616
|
+
Path to the saved pickle file
|
|
617
|
+
"""
|
|
618
|
+
os.makedirs(output_dir, exist_ok=True)
|
|
619
|
+
|
|
620
|
+
output_csv_file = os.path.join(output_dir, f"{output_prefix}.csv")
|
|
621
|
+
output_txt_file = os.path.join(output_dir, f"{output_prefix}.txt")
|
|
622
|
+
output_pickle_file = os.path.join(output_dir, f"{output_prefix}.pkl")
|
|
623
|
+
|
|
624
|
+
with open(output_csv_file, 'w', newline='') as csvfile:
|
|
625
|
+
csv_writer = csv.writer(csvfile)
|
|
626
|
+
csv_writer.writerow([
|
|
627
|
+
"Frame Number",
|
|
628
|
+
"Number of Clusters",
|
|
629
|
+
"Number of Outliers",
|
|
630
|
+
"Silhouette Score",
|
|
631
|
+
"Average Cluster Size"
|
|
632
|
+
])
|
|
633
|
+
for result in analysis_results:
|
|
634
|
+
csv_writer.writerow(result)
|
|
635
|
+
|
|
636
|
+
with open(output_txt_file, 'w') as f:
|
|
637
|
+
# Calculate averages
|
|
638
|
+
frame_numbers = [result[0] for result in analysis_results]
|
|
639
|
+
num_clusters = [result[1] for result in analysis_results]
|
|
640
|
+
outlier_counts = [result[2] for result in analysis_results]
|
|
641
|
+
silhouette_scores = [result[3] for result in analysis_results]
|
|
642
|
+
avg_cluster_sizes = [result[4] for result in analysis_results]
|
|
643
|
+
|
|
644
|
+
avg_num_clusters = np.mean(num_clusters)
|
|
645
|
+
avg_outlier_count = np.mean(outlier_counts)
|
|
646
|
+
avg_silhouette = np.mean(silhouette_scores)
|
|
647
|
+
avg_cluster_size = np.mean(avg_cluster_sizes)
|
|
648
|
+
|
|
649
|
+
f.write(f"DBSCAN Clustering Analysis Summary\n")
|
|
650
|
+
f.write(f"================================\n\n")
|
|
651
|
+
f.write(f"Average Values Across All Frames:\n")
|
|
652
|
+
f.write(f" Average Number of Clusters: {avg_num_clusters:.2f}\n")
|
|
653
|
+
f.write(f" Average Number of Outliers: {avg_outlier_count:.2f}\n")
|
|
654
|
+
f.write(f" Average Silhouette Score: {avg_silhouette:.4f}\n")
|
|
655
|
+
f.write(f" Average Cluster Size: {avg_cluster_size:.2f}\n\n")
|
|
656
|
+
|
|
657
|
+
f.write(f"Analysis Results by Frame:\n")
|
|
658
|
+
for result in analysis_results:
|
|
659
|
+
frame_number, num_clusters, outlier_count, silhouette_avg, avg_cluster_size = result
|
|
660
|
+
f.write(f"Frame {frame_number}:\n")
|
|
661
|
+
f.write(f" Number of Clusters: {num_clusters}\n")
|
|
662
|
+
f.write(f" Number of Outliers: {outlier_count}\n")
|
|
663
|
+
f.write(f" Silhouette Score: {silhouette_avg:.4f}\n")
|
|
664
|
+
f.write(f" Average Cluster Size: {avg_cluster_size:.2f}\n\n")
|
|
665
|
+
|
|
666
|
+
with open(output_pickle_file, 'wb') as picklefile:
|
|
667
|
+
pickle.dump(analysis_results, picklefile)
|
|
668
|
+
|
|
669
|
+
print(f"Analysis results saved to directory: {output_dir}")
|
|
670
|
+
|
|
671
|
+
return output_pickle_file
|
|
672
|
+
|
|
673
|
+
|
|
674
|
+
def plot_analysis_results(pickle_file, output_dir=None):
|
|
675
|
+
"""
|
|
676
|
+
Plot analysis results from a pickle file and save to specified directory.
|
|
677
|
+
|
|
678
|
+
Parameters
|
|
679
|
+
----------
|
|
680
|
+
pickle_file : str
|
|
681
|
+
Path to pickle file containing analysis results
|
|
682
|
+
output_dir : str, optional
|
|
683
|
+
Directory to save output files
|
|
684
|
+
|
|
685
|
+
Returns
|
|
686
|
+
-------
|
|
687
|
+
None
|
|
688
|
+
"""
|
|
689
|
+
with open(pickle_file, 'rb') as f:
|
|
690
|
+
analysis_results = pickle.load(f)
|
|
691
|
+
|
|
692
|
+
# Extract data for plotting
|
|
693
|
+
frame_numbers = [result[0] for result in analysis_results]
|
|
694
|
+
num_clusters = [result[1] for result in analysis_results]
|
|
695
|
+
outlier_counts = [result[2] for result in analysis_results]
|
|
696
|
+
silhouette_scores = [result[3] for result in analysis_results]
|
|
697
|
+
avg_cluster_sizes = [result[4] for result in analysis_results]
|
|
698
|
+
|
|
699
|
+
# Calculate averages
|
|
700
|
+
avg_num_clusters = np.mean(num_clusters)
|
|
701
|
+
avg_outlier_count = np.mean(outlier_counts)
|
|
702
|
+
avg_silhouette = np.mean(silhouette_scores)
|
|
703
|
+
avg_avg_cluster_size = np.mean(avg_cluster_sizes)
|
|
704
|
+
|
|
705
|
+
fig, axs = plt.subplots(4, 1, figsize=(18, 16), sharex=True)
|
|
706
|
+
|
|
707
|
+
# Plot for avg_cluster_sizes
|
|
708
|
+
axs[0].plot(frame_numbers, avg_cluster_sizes, color='red', linestyle='-', linewidth=2)
|
|
709
|
+
axs[0].axhline(y=avg_avg_cluster_size, color='darkred', linestyle='--', alpha=0.7,
|
|
710
|
+
label=f'Average: {avg_avg_cluster_size:.2f}')
|
|
711
|
+
axs[0].set_ylabel('Average Cluster Size', color='red', fontsize=16)
|
|
712
|
+
axs[0].tick_params(axis='y', labelcolor='red', labelsize=14)
|
|
713
|
+
axs[0].grid(True, alpha=0.3)
|
|
714
|
+
axs[0].legend(fontsize=12)
|
|
715
|
+
|
|
716
|
+
# Plot for number of clusters
|
|
717
|
+
axs[1].plot(frame_numbers, num_clusters, color='blue', linestyle='-', linewidth=2)
|
|
718
|
+
axs[1].axhline(y=avg_num_clusters, color='darkblue', linestyle='--', alpha=0.7,
|
|
719
|
+
label=f'Average: {avg_num_clusters:.2f}')
|
|
720
|
+
axs[1].set_ylabel('Number of Clusters', color='blue', fontsize=16)
|
|
721
|
+
axs[1].tick_params(axis='y', labelcolor='blue', labelsize=14)
|
|
722
|
+
axs[1].grid(True, alpha=0.3)
|
|
723
|
+
axs[1].legend(fontsize=12)
|
|
724
|
+
|
|
725
|
+
# Plot for outlier counts
|
|
726
|
+
axs[2].plot(frame_numbers, outlier_counts, color='purple', linestyle='-', linewidth=2)
|
|
727
|
+
axs[2].axhline(y=avg_outlier_count, color='darkviolet', linestyle='--', alpha=0.7,
|
|
728
|
+
label=f'Average: {avg_outlier_count:.2f}')
|
|
729
|
+
axs[2].set_ylabel('Number of Outliers', color='purple', fontsize=16)
|
|
730
|
+
axs[2].tick_params(axis='y', labelcolor='purple', labelsize=14)
|
|
731
|
+
axs[2].grid(True, alpha=0.3)
|
|
732
|
+
axs[2].legend(fontsize=12)
|
|
733
|
+
|
|
734
|
+
# Plot for silhouette scores
|
|
735
|
+
axs[3].plot(frame_numbers, silhouette_scores, color='orange', linestyle='-', linewidth=2)
|
|
736
|
+
axs[3].axhline(y=avg_silhouette, color='darkorange', linestyle='--', alpha=0.7,
|
|
737
|
+
label=f'Average: {avg_silhouette:.4f}')
|
|
738
|
+
axs[3].set_ylabel('Silhouette Score', color='orange', fontsize=16)
|
|
739
|
+
axs[3].tick_params(axis='y', labelcolor='orange', labelsize=14)
|
|
740
|
+
axs[3].grid(True, alpha=0.3)
|
|
741
|
+
axs[3].legend(fontsize=12)
|
|
742
|
+
axs[3].set_xlabel('Frame Number', fontsize=16)
|
|
743
|
+
|
|
744
|
+
plt.suptitle('Clustering Analysis Results', y=0.98, fontsize=20)
|
|
745
|
+
|
|
746
|
+
plt.tight_layout()
|
|
747
|
+
plt.subplots_adjust(top=0.95)
|
|
748
|
+
|
|
749
|
+
if output_dir is None:
|
|
750
|
+
output_dir = os.path.dirname(pickle_file)
|
|
751
|
+
|
|
752
|
+
os.makedirs(output_dir, exist_ok=True)
|
|
753
|
+
|
|
754
|
+
output_base = os.path.splitext(os.path.basename(pickle_file))[0]
|
|
755
|
+
plot_file = os.path.join(output_dir, f"{output_base}_plot.png")
|
|
756
|
+
|
|
757
|
+
plt.savefig(plot_file, dpi=300, bbox_inches='tight')
|
|
758
|
+
|
|
759
|
+
plt.show()
|
|
760
|
+
|
|
761
|
+
print(f"Analysis plot saved to: {plot_file}")
|
|
762
|
+
|
|
763
|
+
|
|
764
|
+
def cluster_analysis(traj_path, indices_path, threshold, min_samples=2,
|
|
765
|
+
mode='single', output_dir="clustering", custom_frame_index=None,
|
|
766
|
+
frame_skip=10, output_prefix="clustering_results"):
|
|
767
|
+
"""
|
|
768
|
+
Analyze molecular structures with DBSCAN clustering.
|
|
769
|
+
|
|
770
|
+
Parameters
|
|
771
|
+
----------
|
|
772
|
+
traj_path : str
|
|
773
|
+
Path to trajectory file (supports any ASE-readable format like XYZ)
|
|
774
|
+
indices_path : Union[str, List[int], np.ndarray]
|
|
775
|
+
Either a path to numpy file containing atom indices to analyze,
|
|
776
|
+
or a direct list/array of atom indices
|
|
777
|
+
threshold : float
|
|
778
|
+
DBSCAN clustering threshold (eps parameter)
|
|
779
|
+
min_samples : int, optional
|
|
780
|
+
Minimum number of samples in a cluster for DBSCAN (default: 2)
|
|
781
|
+
mode : str, optional
|
|
782
|
+
Analysis mode: 'single' for single frame, 'trajectory' for whole trajectory (default: 'single')
|
|
783
|
+
output_dir : str, optional
|
|
784
|
+
Directory to save output files (default: "clustering")
|
|
785
|
+
custom_frame_index : int, optional
|
|
786
|
+
Specific frame number to analyze in 'single' mode. If None, the last frame is analyzed
|
|
787
|
+
frame_skip : int, optional
|
|
788
|
+
Skip frames in trajectory analysis (default: 10)
|
|
789
|
+
output_prefix : str, optional
|
|
790
|
+
Prefix for output file names in trajectory analysis (default: "clustering_results")
|
|
791
|
+
|
|
792
|
+
Returns
|
|
793
|
+
-------
|
|
794
|
+
dict or list
|
|
795
|
+
Analysis result (dict for single frame, list for trajectory)
|
|
796
|
+
"""
|
|
797
|
+
os.makedirs(output_dir, exist_ok=True)
|
|
798
|
+
print(f"Output files will be saved to: {output_dir}")
|
|
799
|
+
|
|
800
|
+
if isinstance(indices_path, str):
|
|
801
|
+
atom_indices = np.load(indices_path)
|
|
802
|
+
print(f"Loaded {len(atom_indices)} atoms for clustering from {indices_path}")
|
|
803
|
+
else:
|
|
804
|
+
atom_indices = np.array(indices_path)
|
|
805
|
+
print(f"Using {len(atom_indices)} directly provided atom indices for clustering")
|
|
806
|
+
|
|
807
|
+
if mode == 'single':
|
|
808
|
+
# Create a mode-specific subdirectory
|
|
809
|
+
mode_dir = os.path.join(output_dir, "single_frame")
|
|
810
|
+
os.makedirs(mode_dir, exist_ok=True)
|
|
811
|
+
|
|
812
|
+
# Analyze a single frame
|
|
813
|
+
analyzer = analyze_frame(
|
|
814
|
+
traj_path,
|
|
815
|
+
atom_indices,
|
|
816
|
+
threshold,
|
|
817
|
+
min_samples,
|
|
818
|
+
metric='precomputed',
|
|
819
|
+
custom_frame_index=custom_frame_index
|
|
820
|
+
)
|
|
821
|
+
|
|
822
|
+
analysis_result = analyzer.analyze_structure(output_dir=mode_dir)
|
|
823
|
+
return analysis_result
|
|
824
|
+
|
|
825
|
+
else:
|
|
826
|
+
mode_dir = os.path.join(output_dir, "trajectory")
|
|
827
|
+
os.makedirs(mode_dir, exist_ok=True)
|
|
828
|
+
|
|
829
|
+
analysis_results = analyze_trajectory(
|
|
830
|
+
traj_path,
|
|
831
|
+
atom_indices,
|
|
832
|
+
threshold,
|
|
833
|
+
min_samples,
|
|
834
|
+
frame_skip,
|
|
835
|
+
output_dir=mode_dir,
|
|
836
|
+
save_html_visualizations=True
|
|
837
|
+
)
|
|
838
|
+
|