RNApolis 0.9.2__py3-none-any.whl → 0.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rnapolis/distiller.py ADDED
@@ -0,0 +1,1119 @@
+ import argparse
+ import hashlib
+ import itertools
+ import json
+ import os
+ import sys
+ import time
+ from concurrent.futures import ProcessPoolExecutor
+ from pathlib import Path
+ from typing import Dict, List, Optional, Tuple
+
+ import faiss
+ import numpy as np
+ from scipy.cluster.hierarchy import dendrogram, fcluster, linkage
+ from scipy.optimize import curve_fit
+ from scipy.spatial.distance import squareform
+ from sklearn.decomposition import PCA
+ from tqdm import tqdm
+
+ from rnapolis.parser_v2 import parse_cif_atoms, parse_pdb_atoms
+ from rnapolis.tertiary_v2 import (
+     Structure,
+     calculate_torsion_angle,
+     nrmsd_qcp_residues,
+     nrmsd_quaternions_residues,
+     nrmsd_svd_residues,
+     nrmsd_validate_residues,
+ )
+
+
+ def parse_arguments():
+     """Parse command line arguments."""
+     parser = argparse.ArgumentParser(
+         description="Find clusters of almost identical RNA structures from mmCIF or PDB files"
+     )
+
+     parser.add_argument(
+         "files", nargs="+", type=Path, help="Input mmCIF or PDB files to analyze"
+     )
+
+     parser.add_argument(
+         "--visualize",
+         action="store_true",
+         help="Show dendrogram visualization of clustering (exact mode only)",
+     )
+
+     parser.add_argument(
+         "--output-json",
+         type=str,
+         help="Output JSON file to save clustering results (available in both modes)",
+     )
+
+     parser.add_argument(
+         "--rmsd-method",
+         type=str,
+         choices=["quaternions", "svd", "qcp", "validate"],
+         default="quaternions",
+         help="RMSD calculation method (default: quaternions). Use 'validate' to check all methods agree. (exact mode only)",
+     )
+
+     parser.add_argument(
+         "--threshold",
+         type=float,
+         default=None,
+         help="nRMSD threshold for clustering (default: auto-detect from exponential decay inflection point) (exact mode only)",
+     )
+
+     parser.add_argument(
+         "--cache-file",
+         type=str,
+         default="nrmsd_cache.json",
+         help="Cache file for storing computed nRMSD values (exact mode only, default: nrmsd_cache.json)",
+     )
+
+     parser.add_argument(
+         "--cache-save-interval",
+         type=int,
+         default=100,
+         help="Save cache to disk every N computations (exact mode only, default: 100)",
+     )
+
+     parser.add_argument(
+         "--mode",
+         choices=["exact", "approximate"],
+         default="exact",
+         help="Clustering mode switch: --mode exact (default) performs rigorous nRMSD clustering, "
+         "--mode approximate performs faster feature-based PCA + FAISS clustering",
+     )
+
+     parser.add_argument(
+         "--radius",
+         type=float,
+         default=10.0,
+         help="Radius in PCA-reduced space for redundancy detection (approximate mode only, default: 10.0)",
+     )
+
+     return parser.parse_args()
+
+
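A quick sketch of the resulting CLI surface, driving the parser programmatically rather than from a shell; the model file names are hypothetical and only `parse_arguments` from this module is used.

import sys
from rnapolis.distiller import parse_arguments

# Hypothetical invocation: three models, approximate mode, custom radius
sys.argv = ["distiller", "model1.cif", "model2.cif", "model3.cif",
            "--mode", "approximate", "--radius", "5.0"]
args = parse_arguments()
print(args.mode, args.radius, [p.name for p in args.files])
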
+ class NRMSDCache:
+     """Cache for storing computed nRMSD values with file metadata."""
+
+     def __init__(self, cache_file: str, save_interval: int = 100):
+         self.cache_file = cache_file
+         self.save_interval = save_interval
+         self.cache: Dict[str, float] = {}
+         self.computation_count = 0
+         self.load_cache()
+
+     def _get_file_key(self, file_path: Path) -> str:
+         """Generate a unique key for a file based on path and modification time."""
+         stat = file_path.stat()
+         return f"{file_path.absolute()}:{stat.st_mtime}:{stat.st_size}"
+
+     def _get_pair_key(self, file1: Path, file2: Path, rmsd_method: str) -> str:
+         """Generate a unique key for a file pair and method."""
+         key1 = self._get_file_key(file1)
+         key2 = self._get_file_key(file2)
+         # Ensure consistent ordering
+         if key1 > key2:
+             key1, key2 = key2, key1
+         combined = f"{key1}|{key2}|{rmsd_method}"
+         # Use hash to keep keys manageable
+         return hashlib.md5(combined.encode()).hexdigest()
+
+     def load_cache(self):
+         """Load cache from disk if it exists."""
+         if os.path.exists(self.cache_file):
+             try:
+                 with open(self.cache_file, "r") as f:
+                     self.cache = json.load(f)
+                 print(
+                     f"Loaded {len(self.cache)} cached nRMSD values from {self.cache_file}"
+                 )
+             except Exception as e:
+                 print(
+                     f"Warning: Could not load cache file {self.cache_file}: {e}",
+                     file=sys.stderr,
+                 )
+                 self.cache = {}
+         else:
+             print(f"No existing cache file found at {self.cache_file}")
+
+     def save_cache(self, silent: bool = False):
+         """Save cache to disk."""
+         try:
+             with open(self.cache_file, "w") as f:
+                 json.dump(self.cache, f, indent=2)
+             if not silent:
+                 print(f"Saved {len(self.cache)} cached values to {self.cache_file}")
+         except Exception as e:
+             print(
+                 f"Warning: Could not save cache file {self.cache_file}: {e}",
+                 file=sys.stderr,
+             )
+
+     def get(self, file1: Path, file2: Path, rmsd_method: str) -> Optional[float]:
+         """Get cached nRMSD value if available."""
+         key = self._get_pair_key(file1, file2, rmsd_method)
+         return self.cache.get(key)
+
+     def set(self, file1: Path, file2: Path, rmsd_method: str, value: float):
+         """Store nRMSD value in cache."""
+         key = self._get_pair_key(file1, file2, rmsd_method)
+         self.cache[key] = value
+         self.computation_count += 1
+
+         # Save periodically (silently to avoid disrupting progress bar)
+         if self.computation_count % self.save_interval == 0:
+             self.save_cache(silent=True)
+
+
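A minimal sketch (not part of the package code above) of why get and set are order-independent: the two per-file keys are sorted before hashing, so the pair (file1, file2) and the pair (file2, file1) map to the same cache entry. The file keys below are made up.

import hashlib

def pair_key(key1: str, key2: str, method: str) -> str:
    # Mirror of NRMSDCache._get_pair_key, applied to plain strings
    if key1 > key2:
        key1, key2 = key2, key1
    return hashlib.md5(f"{key1}|{key2}|{method}".encode()).hexdigest()

k1 = pair_key("/data/a.cif:1700000000.0:1234", "/data/b.cif:1700000001.0:5678", "quaternions")
k2 = pair_key("/data/b.cif:1700000001.0:5678", "/data/a.cif:1700000000.0:1234", "quaternions")
assert k1 == k2
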
+ def validate_input_files(files: List[Path]) -> List[Path]:
+     """Validate that input files exist and have appropriate extensions."""
+     valid_files = []
+     valid_extensions = {".pdb", ".cif", ".mmcif"}
+
+     for file_path in files:
+         if not file_path.exists():
+             print(
+                 f"Warning: File {file_path} does not exist, skipping", file=sys.stderr
+             )
+             continue
+
+         if file_path.suffix.lower() not in valid_extensions:
+             print(
+                 f"Warning: File {file_path} does not have a recognized extension (.pdb, .cif, .mmcif), skipping",
+                 file=sys.stderr,
+             )
+             continue
+
+         valid_files.append(file_path)
+
+     return valid_files
+
+
+ def parse_structure_file(file_path: Path) -> Structure:
+     """
+     Parse a structure file (PDB or mmCIF) into a Structure object.
+
+     Parameters:
+     -----------
+     file_path : Path
+         Path to the structure file
+
+     Returns:
+     --------
+     Structure
+         Parsed structure object
+     """
+     try:
+         with open(file_path, "r") as f:
+             content = f.read()
+
+         # Determine file type and parse accordingly
+         if file_path.suffix.lower() == ".pdb":
+             atoms_df = parse_pdb_atoms(content)
+         else:  # .cif or .mmcif
+             atoms_df = parse_cif_atoms(content)
+
+         return Structure(atoms_df)
+
+     except Exception as e:
+         print(f"Error parsing {file_path}: {e}", file=sys.stderr)
+         raise
+
+
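A short usage sketch of the helper above; the input file name is hypothetical, and only names defined in this module are used.

from pathlib import Path

from rnapolis.distiller import parse_structure_file

structure = parse_structure_file(Path("example.cif"))  # hypothetical input file
nucleotides = [r for r in structure.residues if r.is_nucleotide]
print(f"example.cif contains {len(nucleotides)} nucleotide residues")
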
+ def validate_nucleotide_counts(
+     structures: List[Structure], file_paths: List[Path]
+ ) -> None:
+     """
+     Validate that all structures have the same number of nucleotides.
+
+     Parameters:
+     -----------
+     structures : List[Structure]
+         List of parsed structures
+     file_paths : List[Path]
+         Corresponding file paths for error reporting
+
+     Raises:
+     -------
+     SystemExit
+         If structures have different numbers of nucleotides
+     """
+     nucleotide_counts = []
+
+     for structure, file_path in zip(structures, file_paths):
+         nucleotide_residues = [
+             residue for residue in structure.residues if residue.is_nucleotide
+         ]
+         nucleotide_counts.append((len(nucleotide_residues), file_path))
+
+     if not nucleotide_counts:
+         print("Error: No structures with nucleotides found", file=sys.stderr)
+         sys.exit(1)
+
+     # Check if all counts are the same
+     first_count = nucleotide_counts[0][0]
+     mismatched = [
+         (count, path) for count, path in nucleotide_counts if count != first_count
+     ]
+
+     if mismatched:
+         print(
+             "Error: Structures have different numbers of nucleotides:", file=sys.stderr
+         )
+         print(
+             f"Expected: {first_count} nucleotides (from {nucleotide_counts[0][1]})",
+             file=sys.stderr,
+         )
+         for count, path in mismatched:
+             print(f"Found: {count} nucleotides in {path}", file=sys.stderr)
+         sys.exit(1)
+
+     print(f"All structures have {first_count} nucleotides")
+
+
+ # ----------------------------------------------------------------------
+
+
+ def run_exact(structures: List[Structure], valid_files: List[Path], args) -> None:
+     """
+     Exact mode: nRMSD-based clustering workflow (previously in main).
+     Produces the same outputs/visualisations as before.
+     """
+     # Initialize cache
+     print("\nInitializing nRMSD cache...")
+     cache = NRMSDCache(args.cache_file, args.cache_save_interval)
+
+     # Compute distance matrix
+     print("\nComputing distance matrix...")
+     distance_matrix = find_structure_clusters(
+         structures, valid_files, cache, args.visualize, args.rmsd_method
+     )
+
+     # Build linkage matrix
+     linkage_matrix = linkage(squareform(distance_matrix), method="complete")
+
+     # Determine threshold
+     if args.threshold is None:
+         optimal_threshold = determine_optimal_threshold(distance_matrix, linkage_matrix)
+     else:
+         optimal_threshold = args.threshold
+         print(f"Using user-specified threshold: {optimal_threshold}")
+
+     # Collect threshold data
+     all_threshold_data = find_all_thresholds_and_clusters(
+         distance_matrix, linkage_matrix, valid_files
+     )
+     threshold_clustering = get_clustering_at_threshold(
+         linkage_matrix, distance_matrix, valid_files, optimal_threshold
+     )
+
+     # Visualisation (re-uses the earlier logic)
+     if args.visualize:
+         try:
+             import matplotlib.pyplot as plt
+
+             fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
+
+             dendrogram(
+                 linkage_matrix,
+                 labels=[f"Structure {i}" for i in range(len(structures))],
+                 ax=ax1,
+                 color_threshold=optimal_threshold,
+             )
+             ax1.axhline(
+                 y=optimal_threshold,
+                 color="red",
+                 linestyle="--",
+                 linewidth=2,
+                 label=f"Threshold = {optimal_threshold:.6f}",
+             )
+             ax1.set_title("Hierarchical Clustering Dendrogram")
+             ax1.set_xlabel("Structure Index")
+             ax1.set_ylabel("nRMSD Distance")
+             ax1.legend()
+
+             thresholds = np.array(
+                 [entry["nrmsd_threshold"] for entry in all_threshold_data]
+             )
+             cluster_counts = np.array(
+                 [len(entry["clusters"]) for entry in all_threshold_data]
+             )
+
+             x_smooth, y_smooth, inflection_x = fit_exponential_decay(
+                 thresholds, cluster_counts
+             )
+
+             ax2.scatter(
+                 thresholds, cluster_counts, alpha=0.7, s=30, label="Data points"
+             )
+             if len(x_smooth) > 0:
+                 ax2.plot(
+                     x_smooth,
+                     y_smooth,
+                     "b-",
+                     linewidth=2,
+                     alpha=0.8,
+                     label="Exponential decay fit",
+                 )
+
+             if len(inflection_x) > 0 and len(x_smooth) > 0:
+                 inflection_y = np.interp(inflection_x, x_smooth, y_smooth)
+                 ax2.scatter(
+                     inflection_x,
+                     inflection_y,
+                     color="orange",
+                     s=100,
+                     marker="*",
+                     zorder=6,
+                     label=f"Key points ({len(inflection_x)})",
+                 )
+
+             ax2.axvline(
+                 x=optimal_threshold,
+                 color="red",
+                 linestyle="--",
+                 linewidth=2,
+                 label=f"Threshold = {optimal_threshold:.6f}",
+             )
+             ax2.scatter(
+                 [optimal_threshold],
+                 [threshold_clustering["n_clusters"]],
+                 color="red",
+                 s=100,
+                 zorder=5,
+                 label=f"Selected ({threshold_clustering['n_clusters']} clusters)",
+             )
+
+             ax2.set_xlabel("nRMSD Threshold")
+             ax2.set_ylabel("Number of Clusters")
+             ax2.set_title("Threshold vs Cluster Count with Exponential Decay Fit")
+             ax2.grid(True, alpha=0.3)
+             ax2.legend()
+
+             plt.tight_layout()
+             plt.savefig("clustering_analysis.png", dpi=300, bbox_inches="tight")
+             print("Clustering analysis plots saved to clustering_analysis.png")
+
+             try:
+                 plt.show()
+             except Exception:
+                 print("Note: Could not display plot interactively, but saved to file")
+         except ImportError:
+             print(
+                 "Warning: matplotlib not available, skipping visualization",
+                 file=sys.stderr,
+             )
+
+     # Summary
+     print(f"\nFound {len(all_threshold_data)} different clustering configurations")
+     print(
+         f"Threshold range: {all_threshold_data[0]['nrmsd_threshold']:.6f} to {all_threshold_data[-1]['nrmsd_threshold']:.6f}"
+     )
+     print(
+         f"Cluster count range: {len(all_threshold_data[-1]['clusters'])} to {len(all_threshold_data[0]['clusters'])}"
+     )
+
+     print(f"\nClustering at threshold {optimal_threshold:.6f}:")
+     print(f" Number of clusters: {threshold_clustering['n_clusters']}")
+     print(f" Cluster sizes: {threshold_clustering['cluster_sizes']}")
+     for i, cluster in enumerate(threshold_clustering["clusters"]):
+         print(
+             f" Cluster {i + 1}: {cluster['representative']} + {len(cluster['members'])} members"
+         )
+
+     if args.output_json:
+         output_data = {
+             "all_thresholds": all_threshold_data,
+             "threshold_clustering": threshold_clustering,
+             "parameters": {
+                 "threshold": optimal_threshold,
+                 "threshold_source": "user-specified"
+                 if args.threshold is not None
+                 else "auto-detected",
+                 "rmsd_method": args.rmsd_method,
+             },
+         }
+         with open(args.output_json, "w") as f:
+             json.dump(output_data, f, indent=2)
+         print(f"\nComprehensive clustering results saved to {args.output_json}")
+
+
+ # Approximate mode helper functions and workflow
+ # ----------------------------------------------------------------------
+ def _select_base_atoms(residue) -> List[Optional[np.ndarray]]:
+     """
+     Select four canonical base atoms for a nucleotide residue.
+
+     Purines (A/G/DA/DG): N9, N3, N1, C5
+     Pyrimidines (C/U/DC/DT): N1, O2, N3, C5
+
+     If residue name is unknown, we try purine mapping first and, if incomplete,
+     fall back to pyrimidine mapping. Returned list always has length 4 and may
+     contain ``None`` when coordinates are missing.
+     """
+     purines = {"A", "G", "DA", "DG"}
+     pyrimidines = {"C", "U", "DC", "DT"}
+
+     def _coords_for(names: List[str]) -> List[Optional[np.ndarray]]:
+         """Helper to fetch coordinates for a list of atom names."""
+         return [
+             (atom.coordinates if (atom := residue.find_atom(n)) is not None else None)
+             for n in names
+         ]
+
+     if residue.residue_name in purines:
+         return _coords_for(["N9", "N3", "N1", "C5"])
+
+     if residue.residue_name in pyrimidines:
+         return _coords_for(["N1", "O2", "N3", "C5"])
+
+     # Unknown residue – attempt purine rule first, then pyrimidine
+     coords = _coords_for(["N9", "N3", "N1", "C5"])
+     if all(c is not None for c in coords):
+         return coords
+     return _coords_for(["N1", "O2", "N3", "C5"])
+
+
+ def featurize_structure(structure: Structure) -> np.ndarray:
+     """
+     Convert a Structure into a fixed-length feature vector.
+     For n residues the length is 34 * n * (n-1) / 2.
+     """
+     residues = [r for r in structure.residues if r.is_nucleotide]
+     n = len(residues)
+     if n < 2:
+         return np.zeros(0, dtype=np.float32)
+
+     base_coords = [_select_base_atoms(r) for r in residues]
+     feats: List[float] = []
+
+     for i in range(n):
+         ai = base_coords[i]
+         for j in range(i + 1, n):
+             aj = base_coords[j]
+
+             # 16 distances
+             for ci in ai:
+                 for cj in aj:
+                     if ci is None or cj is None:
+                         dist = 0.0
+                     else:
+                         dist = float(np.linalg.norm(ci - cj))
+                     feats.append(dist)
+
+             # 18 torsion features (sin, cos over 9 angles)
+             a1 = ai[0]
+             a4 = aj[0]
+             for idx2 in range(1, 4):
+                 for idx3 in range(1, 4):
+                     a2, a3 = ai[idx2], aj[idx3]
+                     if any(x is None for x in (a1, a2, a3, a4)):
+                         feats.extend([0.0, 1.0])
+                     else:
+                         angle = calculate_torsion_angle(a1, a2, a3, a4)
+                         feats.extend([float(np.sin(angle)), float(np.cos(angle))])
+
+     return np.asarray(feats, dtype=np.float32)
+
+
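A quick worked check of the docstring's length formula: each residue pair contributes 16 pairwise atom distances plus 9 torsion angles encoded as (sin, cos), i.e. 34 values, over n(n-1)/2 pairs. This is a sketch, not part of the diff.

# 34 features per residue pair: 16 distances + 9 torsions * (sin, cos)
def expected_feature_length(n_residues: int) -> int:
    return 34 * n_residues * (n_residues - 1) // 2

assert expected_feature_length(2) == 34
assert expected_feature_length(3) == 102
assert expected_feature_length(10) == 1530
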
+ def run_approximate(structures: List[Structure], file_paths: List[Path], args) -> None:
+     """
+     Approximate mode: features → PCA → FAISS radius clustering.
+     """
+     print("\nRunning approximate mode (feature-based PCA + FAISS)")
+
+     feature_vectors = [featurize_structure(s) for s in structures]
+     feature_lengths = {len(v) for v in feature_vectors}
+     if len(feature_lengths) != 1:
+         print("Error: Inconsistent feature lengths among structures", file=sys.stderr)
+         sys.exit(1)
+
+     X = np.stack(feature_vectors).astype(np.float32)
+     print(f"Feature matrix shape: {X.shape}")
+
+     pca = PCA(n_components=0.95, svd_solver="full", random_state=0)
+     X_red = pca.fit_transform(X).astype(np.float32)
+     d = X_red.shape[1]
+     print(f"PCA reduced to {d} dimensions (95 % variance)")
+
+     index = faiss.IndexFlatL2(d)
+     index.add(X_red)
+     radius_sq = args.radius**2
+
+     visited: set[int] = set()
+     clusters: List[List[int]] = []
+
+     for idx in range(len(structures)):
+         if idx in visited:
+             continue
+         D, I = index.search(X_red[idx : idx + 1], len(structures))
+         cluster = [int(i) for dist, i in zip(D[0], I[0]) if dist <= radius_sq]
+         clusters.append(cluster)
+         visited.update(cluster)
+
+     print(f"\nIdentified {len(clusters)} representatives with radius {args.radius}")
+     for cluster in clusters:
+         rep = cluster[0]
+         redundants = cluster[1:]
+         print(f"Representative: {file_paths[rep]}")
+         for r in redundants:
+             print(f" Redundant: {file_paths[r]}")
+
+     if args.output_json:
+         out = {
+             "parameters": {"mode": "approximate", "radius": args.radius},
+             "clusters": [
+                 {
+                     "representative": str(file_paths[c[0]]),
+                     "members": [str(file_paths[m]) for m in c[1:]],
+                 }
+                 for c in clusters
+             ],
+         }
+         with open(args.output_json, "w") as f:
+             json.dump(out, f, indent=2)
+         print(f"\nApproximate clustering saved to {args.output_json}")
+
+     return
+
+
+ # ----------------------------------------------------------------------
+
+
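A minimal, self-contained sketch of the greedy radius grouping used above, with random vectors standing in for the PCA-reduced features. IndexFlatL2 returns squared L2 distances, which is why both the function above and this sketch compare against radius**2.

import faiss
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 8)).astype(np.float32)  # stand-in for X_red
radius = 3.0

index = faiss.IndexFlatL2(X.shape[1])
index.add(X)

visited, groups = set(), []
for idx in range(len(X)):
    if idx in visited:
        continue
    D, I = index.search(X[idx : idx + 1], len(X))  # D holds squared L2 distances
    group = [int(i) for dist, i in zip(D[0], I[0]) if dist <= radius**2]
    groups.append(group)
    visited.update(group)

print(f"{len(groups)} representatives for {len(X)} points")
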
+ def find_all_thresholds_and_clusters(
+     distance_matrix: np.ndarray, linkage_matrix: np.ndarray, file_paths: List[Path]
+ ) -> List[dict]:
+     """
+     Find all threshold values where cluster assignments change and generate cluster data.
+
+     Parameters:
+     -----------
+     distance_matrix : np.ndarray
+         Square distance matrix
+     linkage_matrix : np.ndarray
+         Linkage matrix from hierarchical clustering
+     file_paths : List[Path]
+         List of file paths corresponding to structures
+
+     Returns:
+     --------
+     List[dict]
+         List of threshold cluster data
+     """
+     print("Finding all threshold values where cluster assignments change...")
+
+     # Extract merge distances from linkage matrix (column 2)
+     # These are the exact thresholds where cluster assignments change
+     merge_distances = linkage_matrix[:, 2]
+
+     # Sort thresholds in ascending order (all thresholds, no range filtering)
+     valid_thresholds = np.sort(merge_distances)
+
+     print(f"Testing {len(valid_thresholds)} threshold values where clustering changes:")
+
+     threshold_data = []
+
+     for threshold in valid_thresholds:
+         labels = fcluster(linkage_matrix, threshold, criterion="distance")
+         n_clusters = len(np.unique(labels))
+
+         # Group structure indices by cluster
+         clusters = {}
+         for i, label in enumerate(labels):
+             if label not in clusters:
+                 clusters[label] = []
+             clusters[label].append(i)
+
+         cluster_sizes = [len(cluster) for cluster in clusters.values()]
+         cluster_sizes.sort(reverse=True)  # Sort by size, largest first
+
+         print(
+             f" Threshold {threshold:.6f}: {n_clusters} clusters, sizes: {cluster_sizes}"
+         )
+
+         # Find medoids for each cluster
+         medoids = find_cluster_medoids(list(clusters.values()), distance_matrix)
+
+         # Create threshold data entry
+         threshold_entry = {"nrmsd_threshold": float(threshold), "clusters": []}
+
+         for cluster_indices, medoid_idx in zip(clusters.values(), medoids):
+             representative = str(file_paths[medoid_idx])
+             members = [
+                 str(file_paths[idx]) for idx in cluster_indices if idx != medoid_idx
+             ]
+
+             threshold_entry["clusters"].append(
+                 {"representative": representative, "members": members}
+             )
+
+         threshold_data.append(threshold_entry)
+
+     return threshold_data
+
+
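The key fact this function relies on is that column 2 of a SciPy linkage matrix holds the merge heights, which are exactly the thresholds at which the number of clusters changes. A small sketch with a made-up 4-structure distance matrix:

import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.spatial.distance import squareform

# Toy symmetric distance matrix for 4 structures
D = np.array([
    [0.0, 0.1, 0.8, 0.9],
    [0.1, 0.0, 0.7, 0.8],
    [0.8, 0.7, 0.0, 0.2],
    [0.9, 0.8, 0.2, 0.0],
])
Z = linkage(squareform(D), method="complete")
for t in np.sort(Z[:, 2]):  # merge heights: the only thresholds that matter
    labels = fcluster(Z, t, criterion="distance")
    print(f"threshold {t:.2f} -> {len(np.unique(labels))} clusters")
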
+ def find_structure_clusters(
+     structures: List[Structure],
+     file_paths: List[Path],
+     cache: NRMSDCache,
+     visualize: bool = False,
+     rmsd_method: str = "quaternions",
+ ) -> np.ndarray:
+     """
+     Compute the pairwise nRMSD distance matrix used for hierarchical clustering.
+
+     Parameters:
+     -----------
+     structures : List[Structure]
+         List of parsed structures to analyze
+     file_paths : List[Path]
+         File paths corresponding to the structures (used as cache keys)
+     cache : NRMSDCache
+         Cache of previously computed nRMSD values
+     visualize : bool
+         Whether to show dendrogram and scatter plot visualization
+     rmsd_method : str
+         RMSD calculation method ("quaternions", "svd", "qcp", or "validate")
+
+     Returns:
+     --------
+     np.ndarray
+         Distance matrix between all structures
+     """
+     n_structures = len(structures)
+
+     if n_structures == 1:
+         return np.zeros((1, 1))
+
+     # Get nucleotide residues for each structure
+     nucleotide_lists = []
+     for structure in structures:
+         nucleotide_lists.append(
+             [residue for residue in structure.residues if residue.is_nucleotide]
+         )
+
+     # Select nRMSD function based on method
+     if rmsd_method == "quaternions":
+         nrmsd_func = nrmsd_quaternions_residues
+         print("Computing pairwise nRMSD distances using quaternion method...")
+     elif rmsd_method == "svd":
+         nrmsd_func = nrmsd_svd_residues
+         print("Computing pairwise nRMSD distances using SVD method...")
+     elif rmsd_method == "qcp":
+         nrmsd_func = nrmsd_qcp_residues
+         print("Computing pairwise nRMSD distances using QCP method...")
+     elif rmsd_method == "validate":
+         nrmsd_func = nrmsd_validate_residues
+         print(
+             "Computing pairwise nRMSD distances using validation mode (all methods)..."
+         )
+     else:
+         raise ValueError(f"Unknown RMSD method: {rmsd_method}")
+
+     distance_matrix = np.zeros((n_structures, n_structures))
+
+     # Prepare all pairs, checking cache first
+     cached_pairs = []
+     compute_pairs = []
+
+     for i, j in itertools.combinations(range(n_structures), 2):
+         cached_value = cache.get(file_paths[i], file_paths[j], rmsd_method)
+         if cached_value is not None:
+             cached_pairs.append((i, j, cached_value))
+         else:
+             compute_pairs.append((i, j, nucleotide_lists[i], nucleotide_lists[j]))
+
+     print(
+         f"Found {len(cached_pairs)} cached values, computing {len(compute_pairs)} new values"
+     )
+
+     # Fill distance matrix with cached values
+     for i, j, nrmsd_value in cached_pairs:
+         distance_matrix[i, j] = nrmsd_value
+         distance_matrix[j, i] = nrmsd_value
+
+     # Process remaining pairs with progress bar and timing
+     if compute_pairs:
+         start_time = time.time()
+         with ProcessPoolExecutor() as executor:
+             futures_dict = {
+                 executor.submit(nrmsd_func, nucleotides_i, nucleotides_j): (i, j)
+                 for i, j, nucleotides_i, nucleotides_j in compute_pairs
+             }
+             results = []
+             for future in tqdm(
+                 futures_dict,
+                 total=len(futures_dict),
+                 desc="Computing nRMSD",
+                 unit="pair",
+             ):
+                 i, j = futures_dict[future]
+                 nrmsd_value = future.result()
+                 results.append((i, j, nrmsd_value))
+
+                 # Cache the computed value
+                 cache.set(file_paths[i], file_paths[j], rmsd_method, nrmsd_value)
+
+         end_time = time.time()
+         computation_time = end_time - start_time
+
+         print(f"RMSD computation completed in {computation_time:.2f} seconds")
+         if rmsd_method == "validate":
+             print(
+                 "Note: Validation mode tests all methods, so timing includes overhead from multiple calculations"
+             )
+
+         # Fill the distance matrix with computed values
+         for i, j, nrmsd in results:
+             distance_matrix[i, j] = nrmsd
+             distance_matrix[j, i] = nrmsd
+
+     # Save cache after all computations
+     cache.save_cache()
+
+     # Return the distance matrix; callers build the linkage themselves
+     return distance_matrix
+
+
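The parallel loop above keys every Future by its (i, j) index pair so each result can be written back into both halves of the symmetric matrix. A stripped-down sketch of the same pattern, with a trivial stand-in for the nrmsd_* functions:

import itertools
from concurrent.futures import ProcessPoolExecutor

import numpy as np

def toy_distance(a: float, b: float) -> float:
    # Stand-in for nrmsd_quaternions_residues and friends
    return abs(a - b) / 10.0

if __name__ == "__main__":
    values = [1.0, 2.0, 5.0, 7.0]
    matrix = np.zeros((len(values), len(values)))

    with ProcessPoolExecutor() as executor:
        futures = {
            executor.submit(toy_distance, values[i], values[j]): (i, j)
            for i, j in itertools.combinations(range(len(values)), 2)
        }
        for future, (i, j) in futures.items():
            matrix[i, j] = matrix[j, i] = future.result()

    print(matrix)
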
+ def get_clustering_at_threshold(
+     linkage_matrix: np.ndarray,
+     distance_matrix: np.ndarray,
+     file_paths: List[Path],
+     threshold: float,
+ ) -> dict:
+     """
+     Get clustering results at a specific threshold.
+
+     Parameters:
+     -----------
+     linkage_matrix : np.ndarray
+         Linkage matrix from hierarchical clustering
+     distance_matrix : np.ndarray
+         Square distance matrix
+     file_paths : List[Path]
+         List of file paths corresponding to structures
+     threshold : float
+         nRMSD threshold for clustering
+
+     Returns:
+     --------
+     dict
+         Dictionary with clustering information at the given threshold
+     """
+     # Get cluster assignments at this threshold
+     labels = fcluster(linkage_matrix, threshold, criterion="distance")
+     n_clusters = len(np.unique(labels))
+
+     # Group structure indices by cluster
+     clusters = {}
+     for i, label in enumerate(labels):
+         if label not in clusters:
+             clusters[label] = []
+         clusters[label].append(i)
+
+     cluster_sizes = [len(cluster) for cluster in clusters.values()]
+     cluster_sizes.sort(reverse=True)
+
+     # Find medoids for each cluster
+     medoids = find_cluster_medoids(list(clusters.values()), distance_matrix)
+
+     # Create result data
+     result = {
+         "nrmsd_threshold": float(threshold),
+         "n_clusters": n_clusters,
+         "cluster_sizes": cluster_sizes,
+         "clusters": [],
+     }
+
+     for cluster_indices, medoid_idx in zip(clusters.values(), medoids):
+         representative = str(file_paths[medoid_idx])
+         members = [str(file_paths[idx]) for idx in cluster_indices if idx != medoid_idx]
+
+         result["clusters"].append(
+             {"representative": representative, "members": members}
+         )
+
+     return result
+
+
+ def exponential_decay(x: np.ndarray, a: float, b: float, c: float) -> np.ndarray:
+     """
+     Exponential decay function: y = a * exp(-b * x) + c
+
+     Parameters:
+     -----------
+     x : np.ndarray
+         Input values
+     a : float
+         Amplitude parameter
+     b : float
+         Decay rate parameter
+     c : float
+         Offset parameter
+
+     Returns:
+     --------
+     np.ndarray
+         Function values
+     """
+     return a * np.exp(-b * x) + c
+
+
+ def fit_exponential_decay(
+     x: np.ndarray, y: np.ndarray
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+     """
+     Fit exponential decay function to data and find inflection points.
+
+     Parameters:
+     -----------
+     x : np.ndarray
+         X coordinates (thresholds)
+     y : np.ndarray
+         Y coordinates (cluster counts)
+
+     Returns:
+     --------
+     Tuple[np.ndarray, np.ndarray, np.ndarray]
+         Tuple of (x_smooth, y_smooth, inflection_x) where:
+         - x_smooth: smooth x values for plotting the fitted curve
+         - y_smooth: smooth y values for plotting the fitted curve
+         - inflection_x: x coordinates of inflection points
+     """
+     if len(x) < 4:
+         return x, y, np.array([])
+
+     # Sort data by x values
+     sort_idx = np.argsort(x)
+     x_sorted = x[sort_idx]
+     y_sorted = y[sort_idx]
+
+     try:
+         # Initial parameter guess
+         a_guess = y_sorted.max() - y_sorted.min()
+         b_guess = 1.0
+         c_guess = y_sorted.min()
+
+         # Fit exponential decay
+         popt, _ = curve_fit(
+             exponential_decay,
+             x_sorted,
+             y_sorted,
+             p0=[a_guess, b_guess, c_guess],
+             maxfev=5000,
+         )
+
+         a_fit, b_fit, c_fit = popt
+
+         # Generate smooth curve for plotting
+         x_smooth = np.linspace(x_sorted.min(), x_sorted.max(), 200)
+         y_smooth = exponential_decay(x_smooth, a_fit, b_fit, c_fit)
+
+         # For exponential decay y = a*exp(-b*x) + c, the second derivative is:
+         # y'' = a*b^2*exp(-b*x)
+         # Since a > 0 and b > 0 for decay, y'' > 0 always, so no inflection points
+         # However, we can find the point of maximum curvature (steepest decline)
+         # This occurs where the first derivative is most negative
+         # y' = -a*b*exp(-b*x), which is most negative at x = 0
+         # But we'll look for the point where the rate of change is fastest within our data range
+
+         # For exponential decay, identify the "knee" point where the curve
+         # transitions from steep to gradual decline
+         # This is often around x = 1/b in the exponential decay
+         knee_x = 1.0 / b_fit if b_fit > 0 else None
+
+         # Also find the point where the second derivative is maximum
+         # (point of maximum curvature, excluding edges)
+         second_deriv_vals = a_fit * (b_fit**2) * np.exp(-b_fit * x_smooth)
+
+         # Exclude edge points (first and last 10% of data)
+         edge_margin = int(0.1 * len(x_smooth))
+         if edge_margin < 1:
+             edge_margin = 1
+
+         # Find maximum curvature point excluding edges
+         max_curvature_idx = (
+             np.argmax(second_deriv_vals[edge_margin:-edge_margin]) + edge_margin
+         )
+         max_curvature_x = x_smooth[max_curvature_idx]
+
+         inflection_points = []
+
+         # Add knee point if it's within data range and not at edges
+         if knee_x is not None and x_sorted.min() + 0.1 * (
+             x_sorted.max() - x_sorted.min()
+         ) <= knee_x <= x_sorted.max() - 0.1 * (x_sorted.max() - x_sorted.min()):
+             inflection_points.append(knee_x)
+
+         # Add maximum curvature point if it's meaningful and different from knee
+         if x_sorted.min() + 0.1 * (
+             x_sorted.max() - x_sorted.min()
+         ) <= max_curvature_x <= x_sorted.max() - 0.1 * (
+             x_sorted.max() - x_sorted.min()
+         ) and (
+             not inflection_points
+             or abs(max_curvature_x - inflection_points[0])
+             > 0.05 * (x_sorted.max() - x_sorted.min())
+         ):
+             inflection_points.append(max_curvature_x)
+
+         inflection_x = np.array(inflection_points)
+
+         print(
+             f"Exponential decay fit: y = {a_fit:.3f} * exp(-{b_fit:.3f} * x) + {c_fit:.3f}"
+         )
+
+         return x_smooth, y_smooth, inflection_x
+
+     except Exception as e:
+         print(f"Warning: Exponential decay fitting failed: {e}")
+         return x, y, np.array([])
+
+
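A compact sketch of the threshold auto-detection idea on synthetic data, assuming nothing beyond SciPy: fit y = a * exp(-b * x) + c and read off the knee near x = 1/b, where the cluster-count curve turns from steep to flat.

import numpy as np
from scipy.optimize import curve_fit

def exp_decay(x, a, b, c):
    return a * np.exp(-b * x) + c

rng = np.random.default_rng(1)
x = np.linspace(0.0, 1.0, 40)
y = exp_decay(x, 20.0, 8.0, 2.0) + rng.normal(scale=0.3, size=x.size)

(a, b, c), _ = curve_fit(exp_decay, x, y, p0=[y.max() - y.min(), 1.0, y.min()], maxfev=5000)
knee = 1.0 / b  # roughly 0.125 for b near 8: the curve flattens past this threshold
print(f"fit: a={a:.2f}, b={b:.2f}, c={c:.2f}, knee~{knee:.3f}")
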
+ def determine_optimal_threshold(
+     distance_matrix: np.ndarray, linkage_matrix: np.ndarray
+ ) -> float:
+     """
+     Determine optimal threshold from exponential decay inflection point.
+
+     Parameters:
+     -----------
+     distance_matrix : np.ndarray
+         Square distance matrix
+     linkage_matrix : np.ndarray
+         Linkage matrix from hierarchical clustering
+
+     Returns:
+     --------
+     float
+         Optimal threshold value
+     """
+     # Extract merge distances from linkage matrix
+     merge_distances = linkage_matrix[:, 2]
+     valid_thresholds = np.sort(merge_distances)
+
+     # Calculate cluster counts for each threshold
+     cluster_counts = []
+     for threshold in valid_thresholds:
+         labels = fcluster(linkage_matrix, threshold, criterion="distance")
+         n_clusters = len(np.unique(labels))
+         cluster_counts.append(n_clusters)
+
+     thresholds = np.array(valid_thresholds)
+     cluster_counts = np.array(cluster_counts)
+
+     # Fit exponential decay and find inflection points
+     x_smooth, y_smooth, inflection_x = fit_exponential_decay(thresholds, cluster_counts)
+
+     if len(inflection_x) > 0:
+         # Use the first inflection point as the optimal threshold
+         optimal_threshold = inflection_x[0]
+         print(f"Auto-detected optimal threshold: {optimal_threshold:.6f}")
+         return optimal_threshold
+     else:
+         # Fallback to a reasonable default if no inflection points found
+         fallback_threshold = 0.1
+         print(
+             f"No inflection points found, using fallback threshold: {fallback_threshold}"
+         )
+         return fallback_threshold
+
+
+ def find_cluster_medoids(
+     clusters: List[List[int]], distance_matrix: np.ndarray
+ ) -> List[int]:
+     """
+     Find the medoid (representative) for each cluster.
+
+     Parameters:
+     -----------
+     clusters : List[List[int]]
+         List of clusters, where each cluster is a list of structure indices
+     distance_matrix : np.ndarray
+         Square distance matrix between all structures
+
+     Returns:
+     --------
+     List[int]
+         List of medoid indices, one for each cluster
+     """
+     medoids = []
+
+     for cluster in clusters:
+         if len(cluster) == 1:
+             # Single element cluster - it's its own medoid
+             medoids.append(cluster[0])
+         else:
+             # Find the element with minimum sum of distances to all other elements in cluster
+             min_sum_distance = float("inf")
+             medoid = cluster[0]
+
+             for candidate in cluster:
+                 sum_distance = sum(
+                     distance_matrix[candidate, other]
+                     for other in cluster
+                     if other != candidate
+                 )
+                 if sum_distance < min_sum_distance:
+                     min_sum_distance = sum_distance
+                     medoid = candidate
+
+             medoids.append(medoid)
+
+     return medoids
+
+
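A tiny worked example of the medoid rule: the member with the smallest summed distance to the rest of its cluster becomes the representative. The 3x3 distance matrix below is made up.

import numpy as np
from rnapolis.distiller import find_cluster_medoids

# Hypothetical distances between three structures in one cluster
D = np.array([
    [0.0, 0.2, 0.6],
    [0.2, 0.0, 0.3],
    [0.6, 0.3, 0.0],
])
# Row sums of off-diagonal distances: 0 -> 0.8, 1 -> 0.5, 2 -> 0.9, so index 1 wins
print(find_cluster_medoids([[0, 1, 2]], D))  # [1]
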
+ def main():
+     """Main entry point for the distiller CLI tool."""
+     args = parse_arguments()
+
+     # Validate input files
+     valid_files = validate_input_files(args.files)
+
+     if not valid_files:
+         print("Error: No valid input files found", file=sys.stderr)
+         sys.exit(1)
+
+     print(f"Processing {len(valid_files)} files")
+
+     # Parse all structure files
+     print("Parsing structure files...")
+     structures = []
+     parsed_files = []
+     for file_path in valid_files:
+         try:
+             structure = parse_structure_file(file_path)
+             structures.append(structure)
+             parsed_files.append(file_path)
+             print(f" Parsed {file_path}")
+         except Exception:
+             print(f" Failed to parse {file_path}, skipping", file=sys.stderr)
+             continue
+
+     if not structures:
+         print("Error: No structures could be parsed", file=sys.stderr)
+         sys.exit(1)
+
+     # Keep only the files that were successfully parsed, in matching order
+     valid_files = parsed_files
+
+     # Validate nucleotide counts
+     print("\nValidating nucleotide counts...")
+     validate_nucleotide_counts(structures, valid_files)
+
+     # Switch workflow based on requested mode
+     if args.mode == "approximate":
+         run_approximate(structures, valid_files, args)
+         return
+     else:
+         run_exact(structures, valid_files, args)
+         return
+
+
+ if __name__ == "__main__":
+     main()