crisp-ase 1.0.0.post0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of crisp-ase has been flagged as potentially problematic.

"""
CRISP/data_analysis/clustering.py

This module performs cluster analysis on molecular dynamics trajectory data,
using the DBSCAN algorithm to identify atom clusters and their properties.
"""

import numpy as np
from ase.io import read
from sklearn.cluster import DBSCAN
from sklearn.metrics import silhouette_score
import plotly.graph_objects as go
import pickle
import csv
import matplotlib.pyplot as plt
import os
from typing import Union, List


class analyze_frame:
    """
    Analyze atomic structures using the DBSCAN clustering algorithm.

    Parameters
    ----------
    traj_path : str
        Path to trajectory file
    atom_indices : np.ndarray or str
        Array of atom indices to analyze, or a path to a .npy file containing them
    threshold : float
        DBSCAN eps parameter (distance threshold)
    min_samples : int
        DBSCAN min_samples parameter
    metric : str, optional
        Distance metric to use (default: 'precomputed')
    custom_frame_index : int, optional
        Specific frame to analyze (default: None, uses last frame)
    """
    def __init__(self, traj_path, atom_indices, threshold, min_samples, metric='precomputed', custom_frame_index=None):
        self.traj_path = traj_path
        if isinstance(atom_indices, str):
            atom_indices = np.load(atom_indices)
        self.atom_indices = atom_indices
        self.threshold = threshold
        self.min_samples = min_samples
        self.metric = metric
        self.custom_frame_index = custom_frame_index
        self.labels = None
        self.distance_matrix = None

    def read_custom_frame(self):
        """
        Read a specific frame or the last frame from the trajectory.

        Returns
        -------
        ase.Atoms or None
            Atomic structure or None if reading fails
        """
        try:
            if self.custom_frame_index is not None:
                frame = read(self.traj_path, index=self.custom_frame_index)
            else:
                frame = read(self.traj_path, index='-1')
            return frame
        except Exception as e:
            print(f"Error reading trajectory: {e}")
            return None

    def calculate_distance_matrix(self, atoms):
        """
        Calculate a distance matrix with periodic boundary conditions.

        Parameters
        ----------
        atoms : ase.Atoms
            Atomic structure

        Returns
        -------
        Tuple[np.ndarray, np.ndarray]
            Distance matrix and positions array

        Raises
        ------
        ValueError
            If there are not enough atoms to form clusters
        """
        positions = atoms.positions[self.atom_indices]

        if len(self.atom_indices) < self.min_samples:
            raise ValueError(f"Not enough atoms ({len(self.atom_indices)}) to form clusters with min_samples={self.min_samples}")

        full_dm = atoms.get_all_distances(mic=True)
        n_atoms = len(self.atom_indices)
        self.distance_matrix = np.zeros((n_atoms, n_atoms))

        for i, idx_i in enumerate(self.atom_indices):
            for j, idx_j in enumerate(self.atom_indices):
                self.distance_matrix[i, j] = full_dm[idx_i, idx_j]

        return self.distance_matrix, positions

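    # Note: the nested loop above can equivalently be written with NumPy fancy
    # indexing, which selects the same sub-matrix of minimum-image distances in
    # one step (sketch):
    #     self.distance_matrix = full_dm[np.ix_(self.atom_indices, self.atom_indices)]
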
    def find_clusters(self):
        """
        Find clusters using the DBSCAN algorithm.

        Returns
        -------
        np.ndarray
            Cluster labels for each input point

        Raises
        ------
        ValueError
            If distance matrix has not been calculated
        """
        if self.distance_matrix is None:
            raise ValueError("Distance matrix must be calculated first")

        # Clustering is performed on the precomputed, PBC-aware distance matrix.
        db = DBSCAN(
            eps=self.threshold,
            min_samples=self.min_samples,
            metric=self.metric
        ).fit(self.distance_matrix)

        self.labels = db.labels_
        return self.labels

    def analyze_structure(self, save_html_path=None, output_dir=None):
        """
        Analyze the structure and find clusters.

        Parameters
        ----------
        save_html_path : str, optional
            Path to save HTML visualization
        output_dir : str, optional
            Directory to save all results

        Returns
        -------
        dict or None
            Dictionary with analysis results or None if analysis fails
        """
        frame = self.read_custom_frame()
        if frame is None:
            return None

        if output_dir is None and save_html_path is not None:
            output_dir = os.path.dirname(save_html_path)
        if not output_dir:
            output_dir = "clustering_results"

        os.makedirs(output_dir, exist_ok=True)
        print(f"\nSaving results to directory: {output_dir}")

        if save_html_path is None:
            base_name = os.path.splitext(os.path.basename(self.traj_path))[0]
            save_html_path = os.path.join(output_dir, f"{base_name}_clusters.html")

        distance_matrix, positions = self.calculate_distance_matrix(frame)
        self.find_clusters()

        # Silhouette score (excluding outliers)
        silhouette_avg = calculate_silhouette_score(self.distance_matrix, self.labels)

        create_html_visualization(
            positions=positions,
            labels=self.labels,
            title='Interactive 3D Cluster Visualization',
            save_path=save_html_path,
            cell_dimensions=frame.cell.lengths()
        )

        # Extract cluster information
        cluster_info = extract_cluster_info(self.labels, self.atom_indices)
        num_clusters = cluster_info["num_clusters"]
        outlier_count = cluster_info["outlier_count"]
        avg_cluster_size = cluster_info["avg_cluster_size"]
        cluster_to_original = cluster_info["cluster_to_original"]

        print_cluster_summary(num_clusters, outlier_count, silhouette_avg, avg_cluster_size, cluster_to_original)

        frame_info_path = os.path.join(output_dir, "frame_data.txt")
        save_frame_info_to_file(
            frame_info_path,
            self.threshold,
            self.min_samples,
            num_clusters,
            outlier_count,
            silhouette_avg,
            avg_cluster_size,
            cluster_to_original,
            self.labels,
            self.atom_indices
        )

        pickle_path = os.path.join(output_dir, "single_frame_analysis.pkl")
        result_data = {
            "num_clusters": num_clusters,
            "outlier_count": outlier_count,
            "silhouette_avg": silhouette_avg,
            "avg_cluster_size": avg_cluster_size,
            "cluster_to_original": cluster_to_original,
            "labels": self.labels,
            "positions": positions,
            "parameters": {
                "threshold": self.threshold,
                "min_samples": self.min_samples,
                "trajectory": self.traj_path,
                "frame_index": self.custom_frame_index
            }
        }

        with open(pickle_path, 'wb') as f:
            pickle.dump(result_data, f)
        print(f"Full analysis data saved to: {pickle_path}")

        return result_data

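# A minimal usage sketch for the class above (file names are placeholders):
#
#     analyzer = analyze_frame("md_run.xyz", np.array([0, 1, 2, 3]),
#                              threshold=3.0, min_samples=2)
#     result = analyzer.analyze_structure(output_dir="clustering_results")
#     print(result["num_clusters"], result["silhouette_avg"])
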
def create_html_visualization(positions, labels, title, save_path, cell_dimensions=None):
    """
    Create and save a 3D HTML visualization of clusters.

    Parameters
    ----------
    positions : np.ndarray
        Array of atom positions
    labels : np.ndarray
        Array of cluster labels
    title : str
        Title for the visualization
    save_path : str
        Path to save the HTML file
    cell_dimensions : np.ndarray, optional
        Cell dimensions from the simulation [a, b, c]
    """
    fig = go.Figure()

    for label in np.unique(labels):
        cluster_points = positions[labels == label]
        label_name = "Outliers" if label == -1 else f'Cluster {label}'
        marker_size = 5
        marker_color = 'gray' if label == -1 else None

        fig.add_trace(
            go.Scatter3d(
                x=cluster_points[:, 0],
                y=cluster_points[:, 1],
                z=cluster_points[:, 2],
                mode='markers',
                marker=dict(
                    size=marker_size,
                    color=marker_color
                ),
                name=label_name
            )
        )

    layout_args = {
        "title": title,
        "legend": dict(itemsizing='constant')
    }

    scene_dict = {
        "xaxis": dict(title='X'),
        "yaxis": dict(title='Y'),
        "zaxis": dict(title='Z')
    }

    if cell_dimensions is not None:
        scene_dict["xaxis"]["range"] = [0, cell_dimensions[0]]
        scene_dict["yaxis"]["range"] = [0, cell_dimensions[1]]
        scene_dict["zaxis"]["range"] = [0, cell_dimensions[2]]

        vertices = [
            [0, 0, 0], [cell_dimensions[0], 0, 0],
            [cell_dimensions[0], cell_dimensions[1], 0], [0, cell_dimensions[1], 0],
            [0, 0, cell_dimensions[2]], [cell_dimensions[0], 0, cell_dimensions[2]],
            [cell_dimensions[0], cell_dimensions[1], cell_dimensions[2]], [0, cell_dimensions[1], cell_dimensions[2]]
        ]

        # Trace the twelve box edges as a single line path; None entries break
        # the line between edges that do not share a vertex.
        edge_path = [0, 1, 2, 3, 0, 4, 5, 6, 7, 4, None,
                     1, 5, None, 2, 6, None, 3, 7]

        fig.add_trace(go.Scatter3d(
            x=[vertices[p][0] if p is not None else None for p in edge_path],
            y=[vertices[p][1] if p is not None else None for p in edge_path],
            z=[vertices[p][2] if p is not None else None for p in edge_path],
            mode='lines',
            line=dict(color='black', width=2),
            name='Unit Cell',
            showlegend=False
        ))

    layout_args["scene"] = scene_dict
    fig.update_layout(**layout_args)

    fig.write_html(save_path)
    print(f"3D visualization saved to {save_path}")


def calculate_silhouette_score(distance_matrix, labels):
    """
    Calculate silhouette score, handling edge cases.

    Parameters
    ----------
    distance_matrix : np.ndarray
        Distance matrix for points
    labels : np.ndarray
        Cluster labels

    Returns
    -------
    float
        Silhouette score or 0 if calculation fails
    """
    try:
        non_outlier_mask = labels != -1
        if np.sum(non_outlier_mask) > 1:
            # Extract the sub-matrix for non-outlier points
            filtered_matrix = distance_matrix[np.ix_(non_outlier_mask, non_outlier_mask)]
            filtered_labels = labels[non_outlier_mask]
            return silhouette_score(filtered_matrix, filtered_labels, metric='precomputed')
        return 0
    except ValueError:
        return 0

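# Example (a minimal sketch): for a toy 4x4 distance matrix with two tight
# pairs far apart, DBSCAN labels the points [0, 0, 1, 1] and the helper above
# returns their silhouette score; if every point were labelled -1 (all noise),
# it would fall back to 0.
#
#     d = np.array([[0., 1., 9., 9.],
#                   [1., 0., 9., 9.],
#                   [9., 9., 0., 1.],
#                   [9., 9., 1., 0.]])
#     labels = DBSCAN(eps=2, min_samples=2, metric='precomputed').fit(d).labels_
#     score = calculate_silhouette_score(d, labels)
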
def extract_cluster_info(labels, atom_indices):
    """
    Extract cluster information from labels.

    Parameters
    ----------
    labels : np.ndarray
        Cluster labels
    atom_indices : np.ndarray
        Original atom indices

    Returns
    -------
    dict
        Dictionary with cluster information
    """
    cluster_indices = {}
    cluster_sizes = {}
    cluster_to_original = {}

    for cluster_id in np.unique(labels):
        if cluster_id != -1:  # Only count actual clusters (not outliers)
            cluster_indices[cluster_id] = np.where(labels == cluster_id)[0]
            cluster_sizes[cluster_id] = len(cluster_indices[cluster_id])
            cluster_to_original[cluster_id] = atom_indices[cluster_indices[cluster_id]]

    outlier_count = np.sum(labels == -1)
    num_clusters = len([label for label in np.unique(labels) if label != -1])

    # Calculate average cluster size
    avg_cluster_size = np.mean(list(cluster_sizes.values())) if cluster_sizes else 0

    return {
        "num_clusters": num_clusters,
        "outlier_count": outlier_count,
        "avg_cluster_size": avg_cluster_size,
        "cluster_sizes": cluster_sizes,
        "cluster_to_original": cluster_to_original
    }

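# Example (a minimal sketch): for labels = [0, 0, 1, -1] and
# atom_indices = [10, 11, 12, 13], the helper above reports num_clusters = 2,
# outlier_count = 1, avg_cluster_size = 1.5, and
# cluster_to_original = {0: [10, 11], 1: [12]}.
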

def print_cluster_summary(num_clusters, outlier_count, silhouette_avg, avg_cluster_size, cluster_to_original):
    """
    Print a summary of clustering results.

    Parameters
    ----------
    num_clusters : int
        Number of clusters found
    outlier_count : int
        Number of outliers
    silhouette_avg : float
        Average silhouette score
    avg_cluster_size : float
        Average cluster size
    cluster_to_original : dict
        Mapping from cluster IDs to original atom indices

    Returns
    -------
    None
    """
    print(f"\nNumber of Clusters: {num_clusters}")
    print(f"Number of Outliers: {outlier_count}")
    print(f"Silhouette Score: {silhouette_avg:.4f}")
    print(f"Average Cluster Size: {avg_cluster_size:.2f}")
    print("Cluster Information:")

    for cluster_id, atoms in cluster_to_original.items():
        print(f"  Cluster {cluster_id}: {len(atoms)} points")


def save_frame_info_to_file(file_path, threshold, min_samples, num_clusters, outlier_count,
                            silhouette_avg, avg_cluster_size, cluster_to_original, labels, atom_indices):
    """
    Save detailed frame information to a text file.

    Parameters
    ----------
    file_path : str
        Path to save the text file
    threshold : float
        DBSCAN eps parameter
    min_samples : int
        DBSCAN min_samples parameter
    num_clusters : int
        Number of clusters found
    outlier_count : int
        Number of outliers
    silhouette_avg : float
        Average silhouette score
    avg_cluster_size : float
        Average cluster size
    cluster_to_original : dict
        Mapping from cluster IDs to original atom indices
    labels : np.ndarray
        Cluster labels
    atom_indices : np.ndarray
        Original atom indices

    Returns
    -------
    None
    """
    with open(file_path, 'w') as f:
        f.write("DBSCAN Clustering Analysis Results\n")
        f.write("================================\n\n")
        f.write("Parameters:\n")
        f.write(f"  Threshold (eps): {threshold}\n")
        f.write(f"  Min Samples: {min_samples}\n\n")
        f.write("Results:\n")
        f.write(f"  Number of Clusters: {num_clusters}\n")
        f.write(f"  Number of Outliers: {outlier_count}\n")
        f.write(f"  Silhouette Score: {silhouette_avg:.4f}\n")
        f.write(f"  Average Cluster Size: {avg_cluster_size:.2f}\n\n")

        f.write("Detailed Cluster Information:\n")
        for cluster_id, indices in cluster_to_original.items():
            f.write(f"  Cluster {cluster_id}: {len(indices)} points\n")
            f.write(f"    Original atom indices: {indices.tolist()}\n\n")

        f.write("Outlier Information:\n")
        outlier_indices = atom_indices[labels == -1]
        f.write(f"  {len(outlier_indices)} outliers\n")
        if len(outlier_indices) > 0:
            f.write(f"  Original atom indices: {outlier_indices.tolist()}\n")

    print(f"Detailed frame data saved to: {file_path}")


def analyze_trajectory(traj_path, indices_path, threshold, min_samples, frame_skip=10,
                       output_dir="clustering_results", save_html_visualizations=True):
    """
    Analyze an entire trajectory with DBSCAN clustering.

    Parameters
    ----------
    traj_path : str
        Path to trajectory file (supports any ASE-readable format like XYZ)
    indices_path : Union[str, List[int], np.ndarray]
        Either a path to a numpy file containing atom indices to analyze,
        or a direct list/array of atom indices
    threshold : float
        DBSCAN eps parameter (distance threshold)
    min_samples : int
        DBSCAN min_samples parameter
    frame_skip : int, optional
        Step between analyzed frames; every frame_skip-th frame is read (default: 10)
    output_dir : str, optional
        Directory to save output files (default: "clustering_results")
    save_html_visualizations : bool, optional
        Whether to save HTML visualizations for first and last frames (default: True)

    Returns
    -------
    list
        List of analysis results for each frame
    """
    if isinstance(indices_path, str):
        atom_indices = np.load(indices_path)
        print(f"Loaded {len(atom_indices)} atoms for clustering from {indices_path}")
    else:
        atom_indices = np.array(indices_path)
        print(f"Using {len(atom_indices)} directly provided atom indices for clustering")

    os.makedirs(output_dir, exist_ok=True)

    try:
        print(f"Loading trajectory from {traj_path} (using every {frame_skip}th frame)...")
        trajectory = read(traj_path, index=f'::{frame_skip}')
        if not isinstance(trajectory, list):
            trajectory = [trajectory]
        print(f"Loaded {len(trajectory)} frames from trajectory")
    except Exception as e:
        print(f"Error reading trajectory: {e}")
        return []

    results = []

    print(f"Analyzing {len(trajectory)} frames...")

    for i, frame in enumerate(trajectory):
        try:
            frame_number = i * frame_skip

            full_dm = frame.get_all_distances(mic=True)
            n_atoms = len(atom_indices)
            distance_matrix = np.zeros((n_atoms, n_atoms))

            for i_local, idx_i in enumerate(atom_indices):
                if idx_i >= len(frame):
                    print(f"Warning: Atom index {idx_i} out of range for frame with {len(frame)} atoms. Skipping.")
                    continue
                for j_local, idx_j in enumerate(atom_indices):
                    if idx_j >= len(frame):
                        continue
                    distance_matrix[i_local, j_local] = full_dm[idx_i, idx_j]

            db = DBSCAN(
                eps=threshold,
                min_samples=min_samples,
                metric='precomputed'
            ).fit(distance_matrix)

            labels = db.labels_

            # Extract positions for visualization
            positions = frame.positions[atom_indices]

            # Calculate silhouette score
            silhouette_avg = calculate_silhouette_score(distance_matrix, labels)

            # Extract cluster information
            cluster_info = extract_cluster_info(labels, atom_indices)
            num_clusters = cluster_info["num_clusters"]
            outlier_count = cluster_info["outlier_count"]
            avg_cluster_size = cluster_info["avg_cluster_size"]
            cluster_to_original = cluster_info["cluster_to_original"]

            if save_html_visualizations and (i == 0 or i == len(trajectory) - 1):
                frame_prefix = "first" if i == 0 else "last"
                html_path = os.path.join(output_dir, f"{frame_prefix}_frame_clusters.html")
                create_html_visualization(
                    positions=positions,
                    labels=labels,
                    title=f"Frame {frame_number} Clusters",
                    save_path=html_path,
                    cell_dimensions=frame.cell.lengths()
                )

            results.append([frame_number, num_clusters, outlier_count, silhouette_avg, avg_cluster_size])

        except Exception as e:
            print(f"Error processing frame {i}: {e}")
            results.append([i * frame_skip, 0, 0, 0.0, 0.0])

    print(f"Trajectory analysis complete: {len(results)} frames processed")

    if not results:
        print("Warning: No results were generated from trajectory analysis")
        return []

    return results

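# Each entry appended above has the shape
#     [frame_number, num_clusters, outlier_count, silhouette_avg, avg_cluster_size]
# which is the row format expected by save_analysis_results and
# plot_analysis_results below. A minimal sketch of chaining them (file names
# are placeholders):
#
#     results = analyze_trajectory("md_run.xyz", "oxygen_indices.npy",
#                                  threshold=3.0, min_samples=2, frame_skip=10)
#     pkl = save_analysis_results(results, output_dir="clustering_results")
#     plot_analysis_results(pkl)
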
def save_analysis_results(analysis_results, output_dir="clustering_results", output_prefix="clustering_results"):
    """
    Save analysis results to CSV, TXT, and PKL files in the specified output directory.

    Parameters
    ----------
    analysis_results : list
        List of analysis results for each frame
    output_dir : str, optional
        Directory to save output files (default: "clustering_results")
    output_prefix : str, optional
        Prefix for output file names (default: "clustering_results")

    Returns
    -------
    str
        Path to the saved pickle file
    """
    os.makedirs(output_dir, exist_ok=True)

    output_csv_file = os.path.join(output_dir, f"{output_prefix}.csv")
    output_txt_file = os.path.join(output_dir, f"{output_prefix}.txt")
    output_pickle_file = os.path.join(output_dir, f"{output_prefix}.pkl")

    with open(output_csv_file, 'w', newline='') as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerow([
            "Frame Number",
            "Number of Clusters",
            "Number of Outliers",
            "Silhouette Score",
            "Average Cluster Size"
        ])
        for result in analysis_results:
            csv_writer.writerow(result)

    with open(output_txt_file, 'w') as f:
        # Calculate averages
        frame_numbers = [result[0] for result in analysis_results]
        num_clusters = [result[1] for result in analysis_results]
        outlier_counts = [result[2] for result in analysis_results]
        silhouette_scores = [result[3] for result in analysis_results]
        avg_cluster_sizes = [result[4] for result in analysis_results]

        avg_num_clusters = np.mean(num_clusters)
        avg_outlier_count = np.mean(outlier_counts)
        avg_silhouette = np.mean(silhouette_scores)
        avg_cluster_size = np.mean(avg_cluster_sizes)

        f.write("DBSCAN Clustering Analysis Summary\n")
        f.write("================================\n\n")
        f.write("Average Values Across All Frames:\n")
        f.write(f"  Average Number of Clusters: {avg_num_clusters:.2f}\n")
        f.write(f"  Average Number of Outliers: {avg_outlier_count:.2f}\n")
        f.write(f"  Average Silhouette Score: {avg_silhouette:.4f}\n")
        f.write(f"  Average Cluster Size: {avg_cluster_size:.2f}\n\n")

        f.write("Analysis Results by Frame:\n")
        for result in analysis_results:
            frame_number, num_clusters, outlier_count, silhouette_avg, avg_cluster_size = result
            f.write(f"Frame {frame_number}:\n")
            f.write(f"  Number of Clusters: {num_clusters}\n")
            f.write(f"  Number of Outliers: {outlier_count}\n")
            f.write(f"  Silhouette Score: {silhouette_avg:.4f}\n")
            f.write(f"  Average Cluster Size: {avg_cluster_size:.2f}\n\n")

    with open(output_pickle_file, 'wb') as picklefile:
        pickle.dump(analysis_results, picklefile)

    print(f"Analysis results saved to directory: {output_dir}")

    return output_pickle_file


def plot_analysis_results(pickle_file, output_dir=None):
    """
    Plot analysis results from a pickle file and save to specified directory.

    Parameters
    ----------
    pickle_file : str
        Path to pickle file containing analysis results
    output_dir : str, optional
        Directory to save output files

    Returns
    -------
    None
    """
    with open(pickle_file, 'rb') as f:
        analysis_results = pickle.load(f)

    # Extract data for plotting
    frame_numbers = [result[0] for result in analysis_results]
    num_clusters = [result[1] for result in analysis_results]
    outlier_counts = [result[2] for result in analysis_results]
    silhouette_scores = [result[3] for result in analysis_results]
    avg_cluster_sizes = [result[4] for result in analysis_results]

    # Calculate averages
    avg_num_clusters = np.mean(num_clusters)
    avg_outlier_count = np.mean(outlier_counts)
    avg_silhouette = np.mean(silhouette_scores)
    avg_avg_cluster_size = np.mean(avg_cluster_sizes)

    fig, axs = plt.subplots(4, 1, figsize=(18, 16), sharex=True)

    # Plot for avg_cluster_sizes
    axs[0].plot(frame_numbers, avg_cluster_sizes, color='red', linestyle='-', linewidth=2)
    axs[0].axhline(y=avg_avg_cluster_size, color='darkred', linestyle='--', alpha=0.7,
                   label=f'Average: {avg_avg_cluster_size:.2f}')
    axs[0].set_ylabel('Average Cluster Size', color='red', fontsize=16)
    axs[0].tick_params(axis='y', labelcolor='red', labelsize=14)
    axs[0].grid(True, alpha=0.3)
    axs[0].legend(fontsize=12)

    # Plot for number of clusters
    axs[1].plot(frame_numbers, num_clusters, color='blue', linestyle='-', linewidth=2)
    axs[1].axhline(y=avg_num_clusters, color='darkblue', linestyle='--', alpha=0.7,
                   label=f'Average: {avg_num_clusters:.2f}')
    axs[1].set_ylabel('Number of Clusters', color='blue', fontsize=16)
    axs[1].tick_params(axis='y', labelcolor='blue', labelsize=14)
    axs[1].grid(True, alpha=0.3)
    axs[1].legend(fontsize=12)

    # Plot for outlier counts
    axs[2].plot(frame_numbers, outlier_counts, color='purple', linestyle='-', linewidth=2)
    axs[2].axhline(y=avg_outlier_count, color='darkviolet', linestyle='--', alpha=0.7,
                   label=f'Average: {avg_outlier_count:.2f}')
    axs[2].set_ylabel('Number of Outliers', color='purple', fontsize=16)
    axs[2].tick_params(axis='y', labelcolor='purple', labelsize=14)
    axs[2].grid(True, alpha=0.3)
    axs[2].legend(fontsize=12)

    # Plot for silhouette scores
    axs[3].plot(frame_numbers, silhouette_scores, color='orange', linestyle='-', linewidth=2)
    axs[3].axhline(y=avg_silhouette, color='darkorange', linestyle='--', alpha=0.7,
                   label=f'Average: {avg_silhouette:.4f}')
    axs[3].set_ylabel('Silhouette Score', color='orange', fontsize=16)
    axs[3].tick_params(axis='y', labelcolor='orange', labelsize=14)
    axs[3].grid(True, alpha=0.3)
    axs[3].legend(fontsize=12)
    axs[3].set_xlabel('Frame Number', fontsize=16)

    plt.suptitle('Clustering Analysis Results', y=0.98, fontsize=20)

    plt.tight_layout()
    plt.subplots_adjust(top=0.95)

    if output_dir is None:
        # Fall back to the pickle file's directory, or the current directory if the path has no directory part
        output_dir = os.path.dirname(pickle_file) or "."

    os.makedirs(output_dir, exist_ok=True)

    output_base = os.path.splitext(os.path.basename(pickle_file))[0]
    plot_file = os.path.join(output_dir, f"{output_base}_plot.png")

    plt.savefig(plot_file, dpi=300, bbox_inches='tight')

    plt.show()

    print(f"Analysis plot saved to: {plot_file}")


def cluster_analysis(traj_path, indices_path, threshold, min_samples=2,
                     mode='single', output_dir="clustering", custom_frame_index=None,
                     frame_skip=10, output_prefix="clustering_results"):
    """
    Analyze molecular structures with DBSCAN clustering.

    Parameters
    ----------
    traj_path : str
        Path to trajectory file (supports any ASE-readable format like XYZ)
    indices_path : Union[str, List[int], np.ndarray]
        Either a path to a numpy file containing atom indices to analyze,
        or a direct list/array of atom indices
    threshold : float
        DBSCAN clustering threshold (eps parameter)
    min_samples : int, optional
        Minimum number of samples in a cluster for DBSCAN (default: 2)
    mode : str, optional
        Analysis mode: 'single' for a single frame, 'trajectory' for the whole trajectory (default: 'single')
    output_dir : str, optional
        Directory to save output files (default: "clustering")
    custom_frame_index : int, optional
        Specific frame number to analyze in 'single' mode. If None, the last frame is analyzed
    frame_skip : int, optional
        Step between analyzed frames in trajectory analysis (default: 10)
    output_prefix : str, optional
        Prefix for output file names in trajectory analysis (default: "clustering_results")

    Returns
    -------
    dict or list
        Analysis result (dict for single frame, list for trajectory)
    """
    os.makedirs(output_dir, exist_ok=True)
    print(f"Output files will be saved to: {output_dir}")

    if isinstance(indices_path, str):
        atom_indices = np.load(indices_path)
        print(f"Loaded {len(atom_indices)} atoms for clustering from {indices_path}")
    else:
        atom_indices = np.array(indices_path)
        print(f"Using {len(atom_indices)} directly provided atom indices for clustering")

    if mode == 'single':
        # Create a mode-specific subdirectory
        mode_dir = os.path.join(output_dir, "single_frame")
        os.makedirs(mode_dir, exist_ok=True)

        # Analyze a single frame
        analyzer = analyze_frame(
            traj_path,
            atom_indices,
            threshold,
            min_samples,
            metric='precomputed',
            custom_frame_index=custom_frame_index
        )

        analysis_result = analyzer.analyze_structure(output_dir=mode_dir)
        return analysis_result

    else:
        mode_dir = os.path.join(output_dir, "trajectory")
        os.makedirs(mode_dir, exist_ok=True)

        analysis_results = analyze_trajectory(
            traj_path,
            atom_indices,
            threshold,
            min_samples,
            frame_skip,
            output_dir=mode_dir,
            save_html_visualizations=True
        )

        return analysis_results

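
# A minimal usage sketch, assuming a trajectory file "md_run.xyz" and an index
# file "oxygen_indices.npy" (both placeholders) in the working directory:
if __name__ == "__main__":
    # Single-frame analysis of the last frame in the trajectory.
    single_result = cluster_analysis(
        "md_run.xyz",
        "oxygen_indices.npy",
        threshold=3.0,
        min_samples=2,
        mode='single',
        output_dir="clustering"
    )

    # Whole-trajectory analysis, then save and plot the per-frame summary.
    traj_results = cluster_analysis(
        "md_run.xyz",
        "oxygen_indices.npy",
        threshold=3.0,
        min_samples=2,
        mode='trajectory',
        frame_skip=10,
        output_dir="clustering"
    )
    pickle_path = save_analysis_results(traj_results, output_dir="clustering/trajectory")
    plot_analysis_results(pickle_path)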