crisp-ase 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. CRISP/__init__.py +99 -0
  2. CRISP/_version.py +1 -0
  3. CRISP/cli.py +41 -0
  4. CRISP/data_analysis/__init__.py +38 -0
  5. CRISP/data_analysis/clustering.py +838 -0
  6. CRISP/data_analysis/contact_coordination.py +915 -0
  7. CRISP/data_analysis/h_bond.py +772 -0
  8. CRISP/data_analysis/msd.py +1199 -0
  9. CRISP/data_analysis/prdf.py +404 -0
  10. CRISP/data_analysis/volumetric_atomic_density.py +527 -0
  11. CRISP/py.typed +1 -0
  12. CRISP/simulation_utility/__init__.py +31 -0
  13. CRISP/simulation_utility/atomic_indices.py +155 -0
  14. CRISP/simulation_utility/atomic_traj_linemap.py +278 -0
  15. CRISP/simulation_utility/error_analysis.py +254 -0
  16. CRISP/simulation_utility/interatomic_distances.py +200 -0
  17. CRISP/simulation_utility/subsampling.py +241 -0
  18. CRISP/tests/DataAnalysis/__init__.py +1 -0
  19. CRISP/tests/DataAnalysis/test_clustering_extended.py +212 -0
  20. CRISP/tests/DataAnalysis/test_contact_coordination.py +184 -0
  21. CRISP/tests/DataAnalysis/test_contact_coordination_extended.py +465 -0
  22. CRISP/tests/DataAnalysis/test_h_bond_complete.py +326 -0
  23. CRISP/tests/DataAnalysis/test_h_bond_extended.py +322 -0
  24. CRISP/tests/DataAnalysis/test_msd_complete.py +305 -0
  25. CRISP/tests/DataAnalysis/test_msd_extended.py +522 -0
  26. CRISP/tests/DataAnalysis/test_prdf.py +206 -0
  27. CRISP/tests/DataAnalysis/test_volumetric_atomic_density.py +463 -0
  28. CRISP/tests/SimulationUtility/__init__.py +1 -0
  29. CRISP/tests/SimulationUtility/test_atomic_traj_linemap.py +101 -0
  30. CRISP/tests/SimulationUtility/test_atomic_traj_linemap_extended.py +469 -0
  31. CRISP/tests/SimulationUtility/test_error_analysis_extended.py +151 -0
  32. CRISP/tests/SimulationUtility/test_interatomic_distances.py +223 -0
  33. CRISP/tests/SimulationUtility/test_subsampling.py +365 -0
  34. CRISP/tests/__init__.py +1 -0
  35. CRISP/tests/test_CRISP.py +28 -0
  36. CRISP/tests/test_cli.py +87 -0
  37. CRISP/tests/test_crisp_comprehensive.py +679 -0
  38. crisp_ase-1.1.2.dist-info/METADATA +141 -0
  39. crisp_ase-1.1.2.dist-info/RECORD +42 -0
  40. crisp_ase-1.1.2.dist-info/WHEEL +5 -0
  41. crisp_ase-1.1.2.dist-info/entry_points.txt +2 -0
  42. crisp_ase-1.1.2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,838 @@
1
+ """
2
+ CRISP/data_analysis/clustering.py
3
+
4
+ This module performs cluster analysis on molecular dynamics trajectory data,
5
+ using DBSCAN algorithm to identify atom clusters and their properties.
6
+ """
7
+
8
+ import numpy as np
9
+ from ase.io import read
10
+ from sklearn.cluster import DBSCAN
11
+ from sklearn.metrics import silhouette_score
12
+ import plotly.graph_objects as go
13
+ import pickle
14
+ import csv
15
+ import matplotlib.pyplot as plt
16
+ import os
17
+ from typing import Union, List, Optional, Tuple, Dict, Any
18
+
19
+ __all__ = ['analyze_frame', 'analyze_trajectory']
20
+
21
+
22
class analyze_frame:
    """
    Analyze a single frame of an atomic trajectory with DBSCAN clustering.

    Parameters
    ----------
    traj_path : str
        Path to trajectory file (any ASE-readable format).
    atom_indices : str or array-like of int
        Indices of the atoms to cluster, or a path to a ``.npy`` file that
        contains them. Stored internally as a NumPy array.
    threshold : float
        DBSCAN ``eps`` parameter (distance threshold).
    min_samples : int
        DBSCAN ``min_samples`` parameter.
    metric : str, optional
        Distance metric passed to DBSCAN (default: 'precomputed'). With
        'precomputed', DBSCAN clusters the minimum-image distance matrix of
        the selected atoms. Any other metric is applied to their raw
        Cartesian positions, so periodic boundaries are then not taken into
        account by the clustering itself.
    custom_frame_index : int, optional
        Specific frame to analyze (default: None, uses last frame).
    """

    def __init__(
        self,
        traj_path: str,
        atom_indices: Union[str, np.ndarray],
        threshold: float,
        min_samples: int,
        metric: str = 'precomputed',
        custom_frame_index: Optional[int] = None
    ) -> None:
        self.traj_path = traj_path
        if isinstance(atom_indices, str):
            atom_indices = np.load(atom_indices)
        # Normalize lists/tuples to an array: downstream helpers rely on
        # NumPy fancy indexing such as ``atom_indices[labels == -1]``.
        self.atom_indices = np.asarray(atom_indices)
        self.threshold = threshold
        self.min_samples = min_samples
        self.metric = metric
        self.custom_frame_index = custom_frame_index
        self.labels = None
        self.distance_matrix = None
        # Positions of the selected atoms; filled by calculate_distance_matrix
        # and used as DBSCAN input when metric != 'precomputed'.
        self.positions = None

    def read_custom_frame(self):
        """
        Read the requested frame (or the last frame) from the trajectory.

        Returns
        -------
        ase.Atoms or None
            Atomic structure, or None if reading fails (the error is printed).
        """
        index = self.custom_frame_index if self.custom_frame_index is not None else -1
        try:
            return read(self.traj_path, index=index)
        except Exception as e:
            # Best-effort: callers treat None as "could not read the frame".
            print(f"Error reading trajectory: {e}")
            return None

    def calculate_distance_matrix(self, atoms):
        """
        Calculate the minimum-image distance matrix between the selected atoms.

        Parameters
        ----------
        atoms : ase.Atoms
            Atomic structure.

        Returns
        -------
        Tuple[np.ndarray, np.ndarray]
            ``(distance_matrix, positions)``. The matrix has shape
            ``(len(atom_indices), len(atom_indices))``. The positions are
            those of the selected atoms. Both are also stored on the
            instance.

        Raises
        ------
        ValueError
            If there are fewer selected atoms than ``min_samples``.
        """
        positions = atoms.positions[self.atom_indices]

        if len(self.atom_indices) < self.min_samples:
            raise ValueError(f"Not enough atoms ({len(self.atom_indices)}) to form clusters with min_samples={self.min_samples}")

        full_dm = atoms.get_all_distances(mic=True)
        # Extract the (selected x selected) sub-block in a single vectorized
        # indexing step instead of a Python double loop.
        self.distance_matrix = full_dm[np.ix_(self.atom_indices, self.atom_indices)]
        self.positions = positions

        return self.distance_matrix, positions

    def find_clusters(self):
        """
        Find clusters using the DBSCAN algorithm.

        Returns
        -------
        np.ndarray
            Cluster label for each selected atom (-1 marks outliers).

        Raises
        ------
        ValueError
            If the distance matrix has not been calculated, or if a
            non-precomputed metric is requested but no positions are available.
        """
        if self.distance_matrix is None:
            raise ValueError("Distance matrix must be calculated first")

        if self.metric == 'precomputed':
            features = self.distance_matrix
        else:
            # Non-precomputed metrics operate on coordinates, not distances.
            # getattr keeps instances built without __init__ working.
            features = getattr(self, 'positions', None)
            if features is None:
                raise ValueError("Atom positions must be calculated first for a non-precomputed metric")

        db = DBSCAN(
            eps=self.threshold,
            min_samples=self.min_samples,
            metric=self.metric
        ).fit(features)

        self.labels = db.labels_
        return self.labels

    def analyze_structure(self, save_html_path=None, output_dir=None):
        """
        Analyze the structure, find clusters and save all results.

        Writes an HTML visualization, a text report (``frame_data.txt``) and a
        pickle (``single_frame_analysis.pkl``) into ``output_dir``.

        Parameters
        ----------
        save_html_path : str, optional
            Path to save the HTML visualization. Defaults to
            ``<output_dir>/<trajectory basename>_clusters.html``.
        output_dir : str, optional
            Directory to save all results. Defaults to the directory of
            ``save_html_path``, or "clustering_results" if that is empty.

        Returns
        -------
        dict or None
            Dictionary with analysis results, or None if the frame could not be read.
        """
        frame = self.read_custom_frame()
        if frame is None:
            return None

        if output_dir is None and save_html_path is not None:
            output_dir = os.path.dirname(save_html_path)
        if not output_dir:
            output_dir = "clustering_results"

        os.makedirs(output_dir, exist_ok=True)
        print(f"\nSaving results to directory: {output_dir}")

        if save_html_path is None:
            base_name = os.path.splitext(os.path.basename(self.traj_path))[0]
            save_html_path = os.path.join(output_dir, f"{base_name}_clusters.html")

        _, positions = self.calculate_distance_matrix(frame)
        self.find_clusters()

        # Silhouette score (excluding outliers)
        silhouette_avg = calculate_silhouette_score(self.distance_matrix, self.labels)

        create_html_visualization(
            positions=positions,
            labels=self.labels,
            title='Interactive 3D Cluster Visualization',
            save_path=save_html_path,
            cell_dimensions=frame.cell.lengths()
        )

        cluster_info = extract_cluster_info(self.labels, self.atom_indices)
        num_clusters = cluster_info["num_clusters"]
        outlier_count = cluster_info["outlier_count"]
        avg_cluster_size = cluster_info["avg_cluster_size"]
        cluster_to_original = cluster_info["cluster_to_original"]

        print_cluster_summary(num_clusters, outlier_count, silhouette_avg, avg_cluster_size, cluster_to_original)

        frame_info_path = os.path.join(output_dir, "frame_data.txt")
        save_frame_info_to_file(
            frame_info_path,
            self.threshold,
            self.min_samples,
            num_clusters,
            outlier_count,
            silhouette_avg,
            avg_cluster_size,
            cluster_to_original,
            self.labels,
            self.atom_indices
        )

        pickle_path = os.path.join(output_dir, "single_frame_analysis.pkl")
        result_data = {
            "num_clusters": num_clusters,
            "outlier_count": outlier_count,
            "silhouette_avg": silhouette_avg,
            "avg_cluster_size": avg_cluster_size,
            "cluster_to_original": cluster_to_original,
            "labels": self.labels,
            "positions": positions,
            "parameters": {
                "threshold": self.threshold,
                "min_samples": self.min_samples,
                "trajectory": self.traj_path,
                "frame_index": self.custom_frame_index
            }
        }

        with open(pickle_path, 'wb') as f:
            pickle.dump(result_data, f)
        print(f"Full analysis data saved to: {pickle_path}")

        return result_data
231
+
232
+
233
def create_html_visualization(positions, labels, title, save_path, cell_dimensions=None):
    """
    Create and save an interactive 3D HTML visualization of clusters.

    Parameters
    ----------
    positions : np.ndarray
        Array of atom positions, one row of (x, y, z) per atom.
    labels : np.ndarray
        Cluster label for each row of ``positions`` (-1 marks outliers).
    title : str
        Title for the visualization.
    save_path : str
        Path to save the HTML file.
    cell_dimensions : array-like of float, optional
        Cell edge lengths [a, b, c]. When given and all positive, the axes
        are fixed to ``[0, a]``, ``[0, b]``, ``[0, c]`` and the cell edges are
        drawn as a wireframe box. The box is axis-aligned, so it is exact only
        for orthorhombic cells.
    """
    fig = go.Figure()

    # One trace per label so every cluster gets its own legend entry/colour.
    for label in np.unique(labels):
        cluster_points = positions[labels == label]
        fig.add_trace(
            go.Scatter3d(
                x=cluster_points[:, 0],
                y=cluster_points[:, 1],
                z=cluster_points[:, 2],
                mode='markers',
                marker=dict(
                    size=5,
                    color='gray' if label == -1 else None
                ),
                name="Outliers" if label == -1 else f'Cluster {label}'
            )
        )

    layout_args = {
        "title": title,
        "legend": dict(itemsizing='constant')
    }

    scene_dict = {
        "xaxis": dict(title='X'),
        "yaxis": dict(title='Y'),
        "zaxis": dict(title='Z')
    }

    # A zero-length cell (e.g. a non-periodic system) would collapse the axis
    # ranges to [0, 0] and hide every point, so only use a real cell.
    if cell_dimensions is not None and np.all(np.asarray(cell_dimensions) > 0):
        a, b, c = cell_dimensions[0], cell_dimensions[1], cell_dimensions[2]
        scene_dict["xaxis"]["range"] = [0, a]
        scene_dict["yaxis"]["range"] = [0, b]
        scene_dict["zaxis"]["range"] = [0, c]

        vertices = [
            (0, 0, 0), (a, 0, 0), (a, b, 0), (0, b, 0),
            (0, 0, c), (a, 0, c), (a, b, c), (0, b, c),
        ]
        # The 12 box edges as vertex-index pairs: bottom face, top face,
        # then the four vertical edges.
        edges = [
            (0, 1), (1, 2), (2, 3), (3, 0),
            (4, 5), (5, 6), (6, 7), (7, 4),
            (0, 4), (1, 5), (2, 6), (3, 7),
        ]
        # Each edge is its own two-point segment; the None after it breaks the
        # line so consecutive edges are not joined by spurious diagonals.
        box_x, box_y, box_z = [], [], []
        for start, end in edges:
            for coords, axis in ((box_x, 0), (box_y, 1), (box_z, 2)):
                coords.extend([vertices[start][axis], vertices[end][axis], None])

        fig.add_trace(go.Scatter3d(
            x=box_x,
            y=box_y,
            z=box_z,
            mode='lines',
            line=dict(color='black', width=2),
            name='Unit Cell',
            showlegend=False
        ))

    layout_args["scene"] = scene_dict
    fig.update_layout(**layout_args)

    fig.write_html(save_path)
    print(f"3D visualization saved to {save_path}")
325
+
326
+
327
def calculate_silhouette_score(distance_matrix, labels):
    """
    Compute the silhouette score of a clustering while ignoring outliers.

    Points labelled -1 are dropped from both the label array and the
    precomputed distance matrix before scoring.

    Parameters
    ----------
    distance_matrix : np.ndarray
        Square matrix of pairwise distances between the points.
    labels : np.ndarray
        Cluster label for each point (-1 marks outliers).

    Returns
    -------
    float
        The silhouette score. Returns 0 when at most one non-outlier point
        remains, or when the scorer raises ValueError (scikit-learn does so
        when the score is undefined, e.g. a single cluster label).
    """
    try:
        keep = labels != -1
        if np.sum(keep) <= 1:
            return 0
        sub_matrix = distance_matrix[np.ix_(keep, keep)]
        sub_labels = labels[keep]
        return silhouette_score(sub_matrix, sub_labels, metric='precomputed')
    except ValueError:
        return 0
353
+
354
+
355
def extract_cluster_info(labels, atom_indices):
    """
    Summarize DBSCAN labels into per-cluster membership and statistics.

    Parameters
    ----------
    labels : np.ndarray
        Cluster label per analyzed atom (-1 marks outliers).
    atom_indices : np.ndarray
        Original atom index for each entry of ``labels``.

    Returns
    -------
    dict
        ``num_clusters``: number of distinct non-outlier labels.
        ``outlier_count``: number of -1 labels.
        ``avg_cluster_size``: mean cluster size, or 0 if there are no clusters.
        ``cluster_sizes``: {label: size}.
        ``cluster_to_original``: {label: original atom indices in that cluster}.
    """
    # Positions (within ``labels``) of the members of each real cluster.
    members = {
        cid: np.where(labels == cid)[0]
        for cid in np.unique(labels)
        if cid != -1
    }
    sizes = {cid: len(rows) for cid, rows in members.items()}
    originals = {cid: atom_indices[rows] for cid, rows in members.items()}

    n_outliers = np.sum(labels == -1)
    n_clusters = len(members)
    mean_size = np.mean(list(sizes.values())) if sizes else 0

    return {
        "num_clusters": n_clusters,
        "outlier_count": n_outliers,
        "avg_cluster_size": mean_size,
        "cluster_sizes": sizes,
        "cluster_to_original": originals
    }
394
+
395
+
396
def print_cluster_summary(num_clusters, outlier_count, silhouette_avg, avg_cluster_size, cluster_to_original):
    """
    Print a human-readable summary of one clustering result to stdout.

    Parameters
    ----------
    num_clusters : int
        Number of clusters found.
    outlier_count : int
        Number of outliers.
    silhouette_avg : float
        Average silhouette score (printed with 4 decimals).
    avg_cluster_size : float
        Average cluster size (printed with 2 decimals).
    cluster_to_original : dict
        Mapping from cluster ID to the original atom indices in that cluster;
        only the size of each entry is printed.

    Returns
    -------
    None
    """
    def summary_lines():
        # Lazily produced so each line is printed before the next is formatted.
        yield f"\nNumber of Clusters: {num_clusters}"
        yield f"Number of Outliers: {outlier_count}"
        yield f"Silhouette Score: {silhouette_avg:.4f}"
        yield f"Average Cluster Size: {avg_cluster_size:.2f}"
        yield "Cluster Information:"
        for cid, members in cluster_to_original.items():
            yield f" Cluster {cid}: {len(members)} points"

    for line in summary_lines():
        print(line)
425
+
426
+
427
def save_frame_info_to_file(file_path, threshold, min_samples, num_clusters, outlier_count,
                            silhouette_avg, avg_cluster_size, cluster_to_original, labels, atom_indices):
    """
    Write a plain-text report of one frame's clustering result.

    Parameters
    ----------
    file_path : str
        Destination of the text report (overwritten if it exists).
    threshold : float
        DBSCAN eps parameter.
    min_samples : int
        DBSCAN min_samples parameter.
    num_clusters : int
        Number of clusters found.
    outlier_count : int
        Number of outliers.
    silhouette_avg : float
        Average silhouette score.
    avg_cluster_size : float
        Average cluster size.
    cluster_to_original : dict
        Mapping from cluster ID to an array of original atom indices
        (``.tolist()`` is called on each value).
    labels : np.ndarray
        Cluster label per analyzed atom (-1 marks outliers).
    atom_indices : np.ndarray
        Original atom index for each entry of ``labels``.

    Returns
    -------
    None
    """
    def report_chunks():
        # Yielded one piece at a time so the file receives each section in order.
        yield "DBSCAN Clustering Analysis Results\n"
        yield "================================\n\n"
        yield "Parameters:\n"
        yield f" Threshold (eps): {threshold}\n"
        yield f" Min Samples: {min_samples}\n\n"
        yield "Results:\n"
        yield f" Number of Clusters: {num_clusters}\n"
        yield f" Number of Outliers: {outlier_count}\n"
        yield f" Silhouette Score: {silhouette_avg:.4f}\n"
        yield f" Average Cluster Size: {avg_cluster_size:.2f}\n\n"

        yield "Detailed Cluster Information:\n"
        for cid, members in cluster_to_original.items():
            yield f" Cluster {cid}: {len(members)} points\n"
            yield f" Original atom indices: {members.tolist()}\n\n"

        yield "Outlier Information:\n"
        outliers = atom_indices[labels == -1]
        yield f" {len(outliers)} outliers\n"
        if len(outliers) > 0:
            yield f" Original atom indices: {outliers.tolist()}\n"

    with open(file_path, 'w') as report:
        report.writelines(report_chunks())

    print(f"Detailed frame data saved to: {file_path}")
483
+
484
+
485
def analyze_trajectory(traj_path, indices_path, threshold, min_samples, frame_skip=10,
                       output_dir="clustering_results", save_html_visualizations=True):
    """
    Analyze an entire trajectory with DBSCAN clustering.

    Every ``frame_skip``-th frame is read. For each frame, the minimum-image
    distance matrix between the selected atoms is computed and clustered with
    DBSCAN (``metric='precomputed'``).

    Parameters
    ----------
    traj_path : str
        Path to trajectory file (supports any ASE-readable format like XYZ).
    indices_path : Union[str, List[int], np.ndarray]
        Either a path to a numpy file containing the atom indices to analyze,
        or a direct list/array of atom indices.
    threshold : float
        DBSCAN eps parameter (distance threshold).
    min_samples : int
        DBSCAN min_samples parameter.
    frame_skip : int, optional
        Analyze every ``frame_skip``-th frame (default: 10).
    output_dir : str, optional
        Directory to save output files (default: "clustering_results").
    save_html_visualizations : bool, optional
        Whether to save HTML visualizations for the first and last analyzed
        frames (default: True).

    Returns
    -------
    list
        One ``[frame_number, num_clusters, outlier_count, silhouette_avg,
        avg_cluster_size]`` row per analyzed frame. A frame that fails to
        process (for example because an atom index lies outside the frame) is
        reported and recorded as ``[frame_number, 0, 0, 0.0, 0.0]``, so rows
        stay aligned with frames. Returns ``[]`` if the trajectory cannot be
        read or yields no frames.
    """
    if isinstance(indices_path, str):
        atom_indices = np.load(indices_path)
        print(f"Loaded {len(atom_indices)} atoms for clustering from {indices_path}")
    else:
        atom_indices = np.array(indices_path)
        print(f"Using {len(atom_indices)} directly provided atom indices for clustering")

    os.makedirs(output_dir, exist_ok=True)

    try:
        print(f"Loading trajectory from {traj_path} (using every {frame_skip}th frame)...")
        trajectory = read(traj_path, index=f'::{frame_skip}')
        if not isinstance(trajectory, list):
            trajectory = [trajectory]
        print(f"Loaded {len(trajectory)} frames from trajectory")
    except Exception as e:
        print(f"Error reading trajectory: {e}")
        return []

    results = []
    last_frame = len(trajectory) - 1

    print(f"Analyzing {len(trajectory)} frames...")

    for i, frame in enumerate(trajectory):
        # The trajectory was read with step ``frame_skip``, so the i-th loaded
        # frame is frame ``i * frame_skip`` of the original file.
        frame_number = i * frame_skip
        try:
            n_frame_atoms = len(frame)
            # Fail fast with one clear message. Leaving zero distances for
            # missing atoms would make them look adjacent to every other atom.
            if atom_indices.size and atom_indices.max() >= n_frame_atoms:
                raise IndexError(
                    f"Atom index {int(atom_indices.max())} out of range for frame with {n_frame_atoms} atoms"
                )

            full_dm = frame.get_all_distances(mic=True)
            # (selected x selected) sub-block of the full distance matrix.
            distance_matrix = full_dm[np.ix_(atom_indices, atom_indices)]

            db = DBSCAN(
                eps=threshold,
                min_samples=min_samples,
                metric='precomputed'
            ).fit(distance_matrix)
            labels = db.labels_

            positions = frame.positions[atom_indices]
            silhouette_avg = calculate_silhouette_score(distance_matrix, labels)

            cluster_info = extract_cluster_info(labels, atom_indices)
            num_clusters = cluster_info["num_clusters"]
            outlier_count = cluster_info["outlier_count"]
            avg_cluster_size = cluster_info["avg_cluster_size"]

            if save_html_visualizations and (i == 0 or i == last_frame):
                frame_prefix = "first" if i == 0 else "last"
                html_path = os.path.join(output_dir, f"{frame_prefix}_frame_clusters.html")
                create_html_visualization(
                    positions=positions,
                    labels=labels,
                    title=f"Frame {frame_number} Clusters",
                    save_path=html_path,
                    cell_dimensions=frame.cell.lengths()
                )

            results.append([frame_number, num_clusters, outlier_count, silhouette_avg, avg_cluster_size])

        except Exception as e:
            # Best-effort per frame: report and keep a placeholder row.
            print(f"Error processing frame {i}: {e}")
            results.append([frame_number, 0, 0, 0.0, 0.0])

    print(f"Trajectory analysis complete: {len(results)} frames processed")

    if not results:
        print("Warning: No results were generated from trajectory analysis")
        return []

    return results
598
+
599
+
600
def save_analysis_results(analysis_results, output_dir="clustering_results", output_prefix="clustering_results"):
    """
    Save trajectory analysis results as CSV, TXT and PKL files.

    Parameters
    ----------
    analysis_results : list
        One ``[frame_number, num_clusters, outlier_count, silhouette_avg,
        avg_cluster_size]`` row per frame.
    output_dir : str, optional
        Directory to save output files (default: "clustering_results").
    output_prefix : str, optional
        Prefix for output file names (default: "clustering_results").

    Returns
    -------
    str
        Path to the saved pickle file.

    Notes
    -----
    With no rows, the averages in the TXT summary are written as ``nan``.
    """
    os.makedirs(output_dir, exist_ok=True)

    output_csv_file = os.path.join(output_dir, f"{output_prefix}.csv")
    output_txt_file = os.path.join(output_dir, f"{output_prefix}.txt")
    output_pickle_file = os.path.join(output_dir, f"{output_prefix}.pkl")

    with open(output_csv_file, 'w', newline='') as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerow([
            "Frame Number",
            "Number of Clusters",
            "Number of Outliers",
            "Silhouette Score",
            "Average Cluster Size"
        ])
        csv_writer.writerows(analysis_results)

    # Per-metric columns used for the across-frame averages.
    cluster_counts = [row[1] for row in analysis_results]
    outlier_counts = [row[2] for row in analysis_results]
    silhouette_scores = [row[3] for row in analysis_results]
    cluster_sizes = [row[4] for row in analysis_results]

    def _mean(values):
        # np.mean([]) warns "Mean of empty slice"; report nan quietly instead.
        return float(np.mean(values)) if values else float("nan")

    avg_num_clusters = _mean(cluster_counts)
    avg_outlier_count = _mean(outlier_counts)
    avg_silhouette = _mean(silhouette_scores)
    avg_cluster_size = _mean(cluster_sizes)

    with open(output_txt_file, 'w') as f:
        f.write("DBSCAN Clustering Analysis Summary\n")
        f.write("================================\n\n")
        f.write("Average Values Across All Frames:\n")
        f.write(f" Average Number of Clusters: {avg_num_clusters:.2f}\n")
        f.write(f" Average Number of Outliers: {avg_outlier_count:.2f}\n")
        f.write(f" Average Silhouette Score: {avg_silhouette:.4f}\n")
        f.write(f" Average Cluster Size: {avg_cluster_size:.2f}\n\n")

        f.write("Analysis Results by Frame:\n")
        for frame_number, n_clusters, n_outliers, silhouette, mean_size in analysis_results:
            f.write(f"Frame {frame_number}:\n")
            f.write(f" Number of Clusters: {n_clusters}\n")
            f.write(f" Number of Outliers: {n_outliers}\n")
            f.write(f" Silhouette Score: {silhouette:.4f}\n")
            f.write(f" Average Cluster Size: {mean_size:.2f}\n\n")

    with open(output_pickle_file, 'wb') as picklefile:
        pickle.dump(analysis_results, picklefile)

    print(f"Analysis results saved to directory: {output_dir}")

    return output_pickle_file
672
+
673
+
674
def plot_analysis_results(pickle_file, output_dir=None):
    """
    Plot per-frame clustering metrics from a results pickle and save a PNG.

    Draws four stacked panels: average cluster size, number of clusters,
    number of outliers, and silhouette score. Each panel also has a dashed
    line at that metric's mean over all frames.

    Parameters
    ----------
    pickle_file : str
        Path to a pickle holding a list of ``[frame_number, num_clusters,
        outlier_count, silhouette_avg, avg_cluster_size]`` rows.
    output_dir : str, optional
        Directory for ``<pickle basename>_plot.png``. Defaults to the pickle's
        directory, or the current directory when the path has no directory part.

    Returns
    -------
    None
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only pass results files you created or otherwise trust.
    with open(pickle_file, 'rb') as f:
        analysis_results = pickle.load(f)

    frame_numbers = [row[0] for row in analysis_results]

    # (row column, y-axis label, line colour, mean-line colour, mean format)
    panels = [
        (4, 'Average Cluster Size', 'red', 'darkred', '.2f'),
        (1, 'Number of Clusters', 'blue', 'darkblue', '.2f'),
        (2, 'Number of Outliers', 'purple', 'darkviolet', '.2f'),
        (3, 'Silhouette Score', 'orange', 'darkorange', '.4f'),
    ]

    fig, axs = plt.subplots(len(panels), 1, figsize=(18, 16), sharex=True)

    for ax, (column, ylabel, color, mean_color, fmt) in zip(axs, panels):
        values = [row[column] for row in analysis_results]
        mean_value = np.mean(values)
        ax.plot(frame_numbers, values, color=color, linestyle='-', linewidth=2)
        ax.axhline(y=mean_value, color=mean_color, linestyle='--', alpha=0.7,
                   label=f'Average: {mean_value:{fmt}}')
        ax.set_ylabel(ylabel, color=color, fontsize=16)
        ax.tick_params(axis='y', labelcolor=color, labelsize=14)
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=12)
    axs[-1].set_xlabel('Frame Number', fontsize=16)

    plt.suptitle('Clustering Analysis Results', y=0.98, fontsize=20)

    plt.tight_layout()
    plt.subplots_adjust(top=0.95)

    if not output_dir:
        # dirname() is '' for a bare filename, and os.makedirs('') raises.
        output_dir = os.path.dirname(pickle_file) or "."

    os.makedirs(output_dir, exist_ok=True)

    output_base = os.path.splitext(os.path.basename(pickle_file))[0]
    plot_file = os.path.join(output_dir, f"{output_base}_plot.png")

    fig.savefig(plot_file, dpi=300, bbox_inches='tight')
    # Report the path before show(), which can block in interactive sessions.
    print(f"Analysis plot saved to: {plot_file}")

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
762
+
763
+
764
def cluster_analysis(traj_path, indices_path, threshold, min_samples=2,
                     mode='single', output_dir="clustering", custom_frame_index=None,
                     frame_skip=10, output_prefix="clustering_results"):
    """
    Analyze molecular structures with DBSCAN clustering.

    Parameters
    ----------
    traj_path : str
        Path to trajectory file (supports any ASE-readable format like XYZ).
    indices_path : Union[str, List[int], np.ndarray]
        Either a path to a numpy file containing the atom indices to analyze,
        or a direct list/array of atom indices.
    threshold : float
        DBSCAN clustering threshold (eps parameter).
    min_samples : int, optional
        Minimum number of samples in a cluster for DBSCAN (default: 2).
    mode : str, optional
        'single' analyzes one frame. Any other value (normally 'trajectory')
        analyzes the whole trajectory (default: 'single').
    output_dir : str, optional
        Directory to save output files (default: "clustering"). Results go
        into its "single_frame" or "trajectory" subdirectory.
    custom_frame_index : int, optional
        Specific frame number to analyze in 'single' mode. If None, the last
        frame is analyzed.
    frame_skip : int, optional
        Analyze every ``frame_skip``-th frame in trajectory mode (default: 10).
    output_prefix : str, optional
        Prefix of the CSV/TXT/PKL summary files written in trajectory mode
        (default: "clustering_results").

    Returns
    -------
    dict or list or None
        Single mode: the dict from ``analyze_frame.analyze_structure`` (None
        if the frame could not be read).
        Trajectory mode: the per-frame result rows from ``analyze_trajectory``.
        When non-empty, these are also saved via ``save_analysis_results``
        as ``<output_dir>/trajectory/<output_prefix>.{csv,txt,pkl}``.
    """
    os.makedirs(output_dir, exist_ok=True)
    print(f"Output files will be saved to: {output_dir}")

    if isinstance(indices_path, str):
        atom_indices = np.load(indices_path)
        print(f"Loaded {len(atom_indices)} atoms for clustering from {indices_path}")
    else:
        atom_indices = np.array(indices_path)
        print(f"Using {len(atom_indices)} directly provided atom indices for clustering")

    if mode == 'single':
        mode_dir = os.path.join(output_dir, "single_frame")
        os.makedirs(mode_dir, exist_ok=True)

        analyzer = analyze_frame(
            traj_path,
            atom_indices,
            threshold,
            min_samples,
            metric='precomputed',
            custom_frame_index=custom_frame_index
        )
        return analyzer.analyze_structure(output_dir=mode_dir)

    # Any mode other than 'single' runs the whole-trajectory analysis.
    mode_dir = os.path.join(output_dir, "trajectory")
    os.makedirs(mode_dir, exist_ok=True)

    analysis_results = analyze_trajectory(
        traj_path,
        atom_indices,
        threshold,
        min_samples,
        frame_skip,
        output_dir=mode_dir,
        save_html_visualizations=True
    )

    # Persist the per-frame summary (previously output_prefix was unused and
    # the results were silently dropped).
    if analysis_results:
        save_analysis_results(analysis_results, output_dir=mode_dir, output_prefix=output_prefix)

    return analysis_results
838
+