RNApolis 0.9.2__py3-none-any.whl → 0.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rnapolis/distiller.py +1119 -0
- rnapolis/parser.py +7 -0
- rnapolis/tertiary_v2.py +482 -18
- {rnapolis-0.9.2.dist-info → rnapolis-0.10.1.dist-info}/METADATA +4 -1
- {rnapolis-0.9.2.dist-info → rnapolis-0.10.1.dist-info}/RECORD +9 -8
- {rnapolis-0.9.2.dist-info → rnapolis-0.10.1.dist-info}/entry_points.txt +1 -0
- {rnapolis-0.9.2.dist-info → rnapolis-0.10.1.dist-info}/WHEEL +0 -0
- {rnapolis-0.9.2.dist-info → rnapolis-0.10.1.dist-info}/licenses/LICENSE +0 -0
- {rnapolis-0.9.2.dist-info → rnapolis-0.10.1.dist-info}/top_level.txt +0 -0
rnapolis/distiller.py
ADDED
@@ -0,0 +1,1119 @@
import argparse
import hashlib
import itertools
import json
import os
import sys
import time
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import faiss
import numpy as np
from scipy.cluster.hierarchy import dendrogram, fcluster, linkage
from scipy.optimize import curve_fit
from scipy.spatial.distance import squareform
from sklearn.decomposition import PCA
from tqdm import tqdm

from rnapolis.parser_v2 import parse_cif_atoms, parse_pdb_atoms
from rnapolis.tertiary_v2 import (
    Structure,
    calculate_torsion_angle,
    nrmsd_qcp_residues,
    nrmsd_quaternions_residues,
    nrmsd_svd_residues,
    nrmsd_validate_residues,
)


def parse_arguments():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Find clusters of almost identical RNA structures from mmCIF or PDB files"
    )

    parser.add_argument(
        "files", nargs="+", type=Path, help="Input mmCIF or PDB files to analyze"
    )

    parser.add_argument(
        "--visualize",
        action="store_true",
        help="Show dendrogram visualization of clustering (exact mode only)",
    )

    parser.add_argument(
        "--output-json",
        type=str,
        help="Output JSON file to save clustering results (available in both modes)",
    )

    parser.add_argument(
        "--rmsd-method",
        type=str,
        choices=["quaternions", "svd", "qcp", "validate"],
        default="quaternions",
        help="RMSD calculation method (default: quaternions). Use 'validate' to check all methods agree. (exact mode only)",
    )

    parser.add_argument(
        "--threshold",
        type=float,
        default=None,
        help="nRMSD threshold for clustering (default: auto-detect from exponential decay inflection point) (exact mode only)",
    )

    parser.add_argument(
        "--cache-file",
        type=str,
        default="nrmsd_cache.json",
        help="Cache file for storing computed nRMSD values (exact mode only, default: nrmsd_cache.json)",
    )

    parser.add_argument(
        "--cache-save-interval",
        type=int,
        default=100,
        help="Save cache to disk every N computations (exact mode only, default: 100)",
    )

    parser.add_argument(
        "--mode",
        choices=["exact", "approximate"],
        default="exact",
        help="Clustering mode switch: --mode exact (default) performs rigorous nRMSD clustering, "
        "--mode approximate performs faster feature-based PCA + FAISS clustering",
    )

    parser.add_argument(
        "--radius",
        type=float,
        default=10.0,
        help="Radius in PCA-reduced space for redundancy detection (approximate mode only, default: 10.0)",
    )

    return parser.parse_args()
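
# Example invocations (illustrative; this assumes the console script added to
# entry_points.txt in this release is named "distiller", which is a guess):
#
#   distiller --mode exact --rmsd-method qcp --output-json clusters.json *.cif
#   distiller --mode approximate --radius 10.0 models/*.pdb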


class NRMSDCache:
    """Cache for storing computed nRMSD values with file metadata."""

    def __init__(self, cache_file: str, save_interval: int = 100):
        self.cache_file = cache_file
        self.save_interval = save_interval
        self.cache: Dict[str, float] = {}
        self.computation_count = 0
        self.load_cache()

    def _get_file_key(self, file_path: Path) -> str:
        """Generate a unique key for a file based on path and modification time."""
        stat = file_path.stat()
        return f"{file_path.absolute()}:{stat.st_mtime}:{stat.st_size}"

    def _get_pair_key(self, file1: Path, file2: Path, rmsd_method: str) -> str:
        """Generate a unique key for a file pair and method."""
        key1 = self._get_file_key(file1)
        key2 = self._get_file_key(file2)
        # Ensure consistent ordering
        if key1 > key2:
            key1, key2 = key2, key1
        combined = f"{key1}|{key2}|{rmsd_method}"
        # Use hash to keep keys manageable
        return hashlib.md5(combined.encode()).hexdigest()

    def load_cache(self):
        """Load cache from disk if it exists."""
        if os.path.exists(self.cache_file):
            try:
                with open(self.cache_file, "r") as f:
                    self.cache = json.load(f)
                print(
                    f"Loaded {len(self.cache)} cached nRMSD values from {self.cache_file}"
                )
            except Exception as e:
                print(
                    f"Warning: Could not load cache file {self.cache_file}: {e}",
                    file=sys.stderr,
                )
                self.cache = {}
        else:
            print(f"No existing cache file found at {self.cache_file}")

    def save_cache(self, silent: bool = False):
        """Save cache to disk."""
        try:
            with open(self.cache_file, "w") as f:
                json.dump(self.cache, f, indent=2)
            if not silent:
                print(f"Saved {len(self.cache)} cached values to {self.cache_file}")
        except Exception as e:
            print(
                f"Warning: Could not save cache file {self.cache_file}: {e}",
                file=sys.stderr,
            )

    def get(self, file1: Path, file2: Path, rmsd_method: str) -> Optional[float]:
        """Get cached nRMSD value if available."""
        key = self._get_pair_key(file1, file2, rmsd_method)
        return self.cache.get(key)

    def set(self, file1: Path, file2: Path, rmsd_method: str, value: float):
        """Store nRMSD value in cache."""
        key = self._get_pair_key(file1, file2, rmsd_method)
        self.cache[key] = value
        self.computation_count += 1

        # Save periodically (silently to avoid disrupting progress bar)
        if self.computation_count % self.save_interval == 0:
            self.save_cache(silent=True)
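
# Cache keys are content-addressed: for a pair of files the key is
# md5("<abs path>:<mtime>:<size>|<abs path>:<mtime>:<size>|<method>") with the
# two file keys sorted first, so get(a, b) and get(b, a) hit the same entry,
# and touching either file (new mtime or size) invalidates the old value.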


def validate_input_files(files: List[Path]) -> List[Path]:
    """Validate that input files exist and have appropriate extensions."""
    valid_files = []
    valid_extensions = {".pdb", ".cif", ".mmcif"}

    for file_path in files:
        if not file_path.exists():
            print(
                f"Warning: File {file_path} does not exist, skipping", file=sys.stderr
            )
            continue

        if file_path.suffix.lower() not in valid_extensions:
            print(
                f"Warning: File {file_path} does not have a recognized extension (.pdb, .cif, .mmcif), skipping",
                file=sys.stderr,
            )
            continue

        valid_files.append(file_path)

    return valid_files


def parse_structure_file(file_path: Path) -> Structure:
    """
    Parse a structure file (PDB or mmCIF) into a Structure object.

    Parameters:
    -----------
    file_path : Path
        Path to the structure file

    Returns:
    --------
    Structure
        Parsed structure object
    """
    try:
        with open(file_path, "r") as f:
            content = f.read()

        # Determine file type and parse accordingly
        if file_path.suffix.lower() == ".pdb":
            atoms_df = parse_pdb_atoms(content)
        else:  # .cif or .mmcif
            atoms_df = parse_cif_atoms(content)

        return Structure(atoms_df)

    except Exception as e:
        print(f"Error parsing {file_path}: {e}", file=sys.stderr)
        raise


def validate_nucleotide_counts(
    structures: List[Structure], file_paths: List[Path]
) -> None:
    """
    Validate that all structures have the same number of nucleotides.

    Parameters:
    -----------
    structures : List[Structure]
        List of parsed structures
    file_paths : List[Path]
        Corresponding file paths for error reporting

    Raises:
    -------
    SystemExit
        If structures have different numbers of nucleotides
    """
    nucleotide_counts = []

    for structure, file_path in zip(structures, file_paths):
        nucleotide_residues = [
            residue for residue in structure.residues if residue.is_nucleotide
        ]
        nucleotide_counts.append((len(nucleotide_residues), file_path))

    if not nucleotide_counts:
        print("Error: No structures with nucleotides found", file=sys.stderr)
        sys.exit(1)

    # Check if all counts are the same
    first_count = nucleotide_counts[0][0]
    mismatched = [
        (count, path) for count, path in nucleotide_counts if count != first_count
    ]

    if mismatched:
        print(
            "Error: Structures have different numbers of nucleotides:", file=sys.stderr
        )
        print(
            f"Expected: {first_count} nucleotides (from {nucleotide_counts[0][1]})",
            file=sys.stderr,
        )
        for count, path in mismatched:
            print(f"Found: {count} nucleotides in {path}", file=sys.stderr)
        sys.exit(1)

    print(f"All structures have {first_count} nucleotides")


# ----------------------------------------------------------------------
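
# Minimal usage sketch for the helpers above (file name hypothetical):
#
#   structure = parse_structure_file(Path("model.cif"))
#   n_nt = sum(1 for r in structure.residues if r.is_nucleotide)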


def run_exact(structures: List[Structure], valid_files: List[Path], args) -> None:
    """
    Exact mode: nRMSD-based clustering workflow (previously in main).
    Produces the same outputs/visualisations as before.
    """
    # Initialize cache
    print("\nInitializing nRMSD cache...")
    cache = NRMSDCache(args.cache_file, args.cache_save_interval)

    # Compute distance matrix
    print("\nComputing distance matrix...")
    distance_matrix = find_structure_clusters(
        structures, valid_files, cache, args.visualize, args.rmsd_method
    )

    # Build linkage matrix
    linkage_matrix = linkage(squareform(distance_matrix), method="complete")

    # Determine threshold
    if args.threshold is None:
        optimal_threshold = determine_optimal_threshold(distance_matrix, linkage_matrix)
    else:
        optimal_threshold = args.threshold
        print(f"Using user-specified threshold: {optimal_threshold}")

    # Collect threshold data
    all_threshold_data = find_all_thresholds_and_clusters(
        distance_matrix, linkage_matrix, valid_files
    )
    threshold_clustering = get_clustering_at_threshold(
        linkage_matrix, distance_matrix, valid_files, optimal_threshold
    )

    # Visualisation (re-uses the earlier logic)
    if args.visualize:
        try:
            import matplotlib.pyplot as plt

            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

            dendrogram(
                linkage_matrix,
                labels=[f"Structure {i}" for i in range(len(structures))],
                ax=ax1,
                color_threshold=optimal_threshold,
            )
            ax1.axhline(
                y=optimal_threshold,
                color="red",
                linestyle="--",
                linewidth=2,
                label=f"Threshold = {optimal_threshold:.6f}",
            )
            ax1.set_title("Hierarchical Clustering Dendrogram")
            ax1.set_xlabel("Structure Index")
            ax1.set_ylabel("nRMSD Distance")
            ax1.legend()

            thresholds = np.array(
                [entry["nrmsd_threshold"] for entry in all_threshold_data]
            )
            cluster_counts = np.array(
                [len(entry["clusters"]) for entry in all_threshold_data]
            )

            x_smooth, y_smooth, inflection_x = fit_exponential_decay(
                thresholds, cluster_counts
            )

            ax2.scatter(
                thresholds, cluster_counts, alpha=0.7, s=30, label="Data points"
            )
            if len(x_smooth) > 0:
                ax2.plot(
                    x_smooth,
                    y_smooth,
                    "b-",
                    linewidth=2,
                    alpha=0.8,
                    label="Exponential decay fit",
                )

            if len(inflection_x) > 0 and len(x_smooth) > 0:
                inflection_y = np.interp(inflection_x, x_smooth, y_smooth)
                ax2.scatter(
                    inflection_x,
                    inflection_y,
                    color="orange",
                    s=100,
                    marker="*",
                    zorder=6,
                    label=f"Key points ({len(inflection_x)})",
                )

            ax2.axvline(
                x=optimal_threshold,
                color="red",
                linestyle="--",
                linewidth=2,
                label=f"Threshold = {optimal_threshold:.6f}",
            )
            ax2.scatter(
                [optimal_threshold],
                [threshold_clustering["n_clusters"]],
                color="red",
                s=100,
                zorder=5,
                label=f"Selected ({threshold_clustering['n_clusters']} clusters)",
            )

            ax2.set_xlabel("nRMSD Threshold")
            ax2.set_ylabel("Number of Clusters")
            ax2.set_title("Threshold vs Cluster Count with Exponential Decay Fit")
            ax2.grid(True, alpha=0.3)
            ax2.legend()

            plt.tight_layout()
            plt.savefig("clustering_analysis.png", dpi=300, bbox_inches="tight")
            print("Clustering analysis plots saved to clustering_analysis.png")

            try:
                plt.show()
            except Exception:
                print("Note: Could not display plot interactively, but saved to file")
        except ImportError:
            print(
                "Warning: matplotlib not available, skipping visualization",
                file=sys.stderr,
            )

    # Summary
    print(f"\nFound {len(all_threshold_data)} different clustering configurations")
    print(
        f"Threshold range: {all_threshold_data[0]['nrmsd_threshold']:.6f} to {all_threshold_data[-1]['nrmsd_threshold']:.6f}"
    )
    print(
        f"Cluster count range: {len(all_threshold_data[-1]['clusters'])} to {len(all_threshold_data[0]['clusters'])}"
    )

    print(f"\nClustering at threshold {optimal_threshold:.6f}:")
    print(f"  Number of clusters: {threshold_clustering['n_clusters']}")
    print(f"  Cluster sizes: {threshold_clustering['cluster_sizes']}")
    for i, cluster in enumerate(threshold_clustering["clusters"]):
        print(
            f"  Cluster {i + 1}: {cluster['representative']} + {len(cluster['members'])} members"
        )

    if args.output_json:
        output_data = {
            "all_thresholds": all_threshold_data,
            "threshold_clustering": threshold_clustering,
            "parameters": {
                "threshold": optimal_threshold,
                "threshold_source": "user-specified"
                if args.threshold is not None
                else "auto-detected",
                "rmsd_method": args.rmsd_method,
            },
        }
        with open(args.output_json, "w") as f:
            json.dump(output_data, f, indent=2)
        print(f"\nComprehensive clustering results saved to {args.output_json}")
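
# Shape of the exact-mode --output-json file, reconstructed from the code
# above (values illustrative):
#
#   {
#     "all_thresholds": [{"nrmsd_threshold": 0.012345, "clusters": [...]}, ...],
#     "threshold_clustering": {"nrmsd_threshold": ..., "n_clusters": 3,
#                              "cluster_sizes": [5, 2, 1], "clusters": [...]},
#     "parameters": {"threshold": 0.05, "threshold_source": "auto-detected",
#                    "rmsd_method": "quaternions"}
#   }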


# Approximate mode helper functions and workflow
# ----------------------------------------------------------------------
def _select_base_atoms(residue) -> List[Optional[np.ndarray]]:
    """
    Select four canonical base atoms for a nucleotide residue.

    Purines (A/G/DA/DG): N9, N3, N1, C5
    Pyrimidines (C/U/DC/DT): N1, O2, N3, C5

    If residue name is unknown, we try purine mapping first and, if incomplete,
    fall back to pyrimidine mapping. Returned list always has length 4 and may
    contain ``None`` when coordinates are missing.
    """
    purines = {"A", "G", "DA", "DG"}
    pyrimidines = {"C", "U", "DC", "DT"}

    def _coords_for(names: List[str]) -> List[Optional[np.ndarray]]:
        """Helper to fetch coordinates for a list of atom names."""
        return [
            (atom.coordinates if (atom := residue.find_atom(n)) is not None else None)
            for n in names
        ]

    if residue.residue_name in purines:
        return _coords_for(["N9", "N3", "N1", "C5"])

    if residue.residue_name in pyrimidines:
        return _coords_for(["N1", "O2", "N3", "C5"])

    # Unknown residue – attempt purine rule first, then pyrimidine
    coords = _coords_for(["N9", "N3", "N1", "C5"])
    if all(c is not None for c in coords):
        return coords
    return _coords_for(["N1", "O2", "N3", "C5"])


def featurize_structure(structure: Structure) -> np.ndarray:
    """
    Convert a Structure into a fixed-length feature vector.
    For n residues the length is 34 * n * (n - 1) / 2.
    """
    residues = [r for r in structure.residues if r.is_nucleotide]
    n = len(residues)
    if n < 2:
        return np.zeros(0, dtype=np.float32)

    base_coords = [_select_base_atoms(r) for r in residues]
    feats: List[float] = []

    for i in range(n):
        ai = base_coords[i]
        for j in range(i + 1, n):
            aj = base_coords[j]

            # 16 distances
            for ci in ai:
                for cj in aj:
                    if ci is None or cj is None:
                        dist = 0.0
                    else:
                        dist = float(np.linalg.norm(ci - cj))
                    feats.append(dist)

            # 18 torsion features (sin, cos over 9 angles)
            a1 = ai[0]
            a4 = aj[0]
            for idx2 in range(1, 4):
                for idx3 in range(1, 4):
                    a2, a3 = ai[idx2], aj[idx3]
                    if any(x is None for x in (a1, a2, a3, a4)):
                        feats.extend([0.0, 1.0])
                    else:
                        angle = calculate_torsion_angle(a1, a2, a3, a4)
                        feats.extend([float(np.sin(angle)), float(np.cos(angle))])

    return np.asarray(feats, dtype=np.float32)
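
# Feature-count check: each residue pair contributes 16 distances plus 9
# torsion angles encoded as (sin, cos), i.e. 16 + 18 = 34 values, so a
# structure with n = 20 nucleotides yields 34 * 20 * 19 / 2 = 6460 features.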


def run_approximate(structures: List[Structure], file_paths: List[Path], args) -> None:
    """
    Approximate mode: features → PCA → FAISS radius clustering.
    """
    print("\nRunning approximate mode (feature-based PCA + FAISS)")

    feature_vectors = [featurize_structure(s) for s in structures]
    feature_lengths = {len(v) for v in feature_vectors}
    if len(feature_lengths) != 1:
        print("Error: Inconsistent feature lengths among structures", file=sys.stderr)
        sys.exit(1)

    X = np.stack(feature_vectors).astype(np.float32)
    print(f"Feature matrix shape: {X.shape}")

    pca = PCA(n_components=0.95, svd_solver="full", random_state=0)
    X_red = pca.fit_transform(X).astype(np.float32)
    d = X_red.shape[1]
    print(f"PCA reduced to {d} dimensions (95% variance)")

    index = faiss.IndexFlatL2(d)
    index.add(X_red)
    # IndexFlatL2 reports *squared* L2 distances, so compare against radius**2
    radius_sq = args.radius**2

    visited: set[int] = set()
    clusters: List[List[int]] = []

    for idx in range(len(structures)):
        if idx in visited:
            continue
        D, I = index.search(X_red[idx : idx + 1], len(structures))
        # Results come back sorted by distance, so cluster[0] is idx itself;
        # structures already assigned to an earlier cluster are skipped so
        # that each structure appears in exactly one cluster
        cluster = [
            int(i)
            for dist, i in zip(D[0], I[0])
            if dist <= radius_sq and int(i) not in visited
        ]
        clusters.append(cluster)
        visited.update(cluster)

    print(f"\nIdentified {len(clusters)} representatives with radius {args.radius}")
    for cluster in clusters:
        rep = cluster[0]
        redundants = cluster[1:]
        print(f"Representative: {file_paths[rep]}")
        for r in redundants:
            print(f"  Redundant: {file_paths[r]}")

    if args.output_json:
        out = {
            "parameters": {"mode": "approximate", "radius": args.radius},
            "clusters": [
                {
                    "representative": str(file_paths[c[0]]),
                    "members": [str(file_paths[m]) for m in c[1:]],
                }
                for c in clusters
            ],
        }
        with open(args.output_json, "w") as f:
            json.dump(out, f, indent=2)
        print(f"\nApproximate clustering saved to {args.output_json}")


# ----------------------------------------------------------------------
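
# Shape of the approximate-mode --output-json file (values illustrative,
# file names hypothetical):
#
#   {
#     "parameters": {"mode": "approximate", "radius": 10.0},
#     "clusters": [{"representative": "model_01.cif",
#                   "members": ["model_07.cif", "model_12.cif"]}]
#   }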


def find_all_thresholds_and_clusters(
    distance_matrix: np.ndarray, linkage_matrix: np.ndarray, file_paths: List[Path]
) -> List[dict]:
    """
    Find all threshold values where cluster assignments change and generate cluster data.

    Parameters:
    -----------
    distance_matrix : np.ndarray
        Square distance matrix
    linkage_matrix : np.ndarray
        Linkage matrix from hierarchical clustering
    file_paths : List[Path]
        List of file paths corresponding to structures

    Returns:
    --------
    List[dict]
        List of threshold cluster data
    """
    print("Finding all threshold values where cluster assignments change...")

    # Extract merge distances from linkage matrix (column 2)
    # These are the exact thresholds where cluster assignments change
    merge_distances = linkage_matrix[:, 2]

    # Sort thresholds in ascending order (all thresholds, no range filtering)
    valid_thresholds = np.sort(merge_distances)

    print(f"Testing {len(valid_thresholds)} threshold values where clustering changes:")

    threshold_data = []

    for threshold in valid_thresholds:
        labels = fcluster(linkage_matrix, threshold, criterion="distance")
        n_clusters = len(np.unique(labels))

        # Group structure indices by cluster
        clusters = {}
        for i, label in enumerate(labels):
            if label not in clusters:
                clusters[label] = []
            clusters[label].append(i)

        cluster_sizes = [len(cluster) for cluster in clusters.values()]
        cluster_sizes.sort(reverse=True)  # Sort by size, largest first

        print(
            f"  Threshold {threshold:.6f}: {n_clusters} clusters, sizes: {cluster_sizes}"
        )

        # Find medoids for each cluster
        medoids = find_cluster_medoids(list(clusters.values()), distance_matrix)

        # Create threshold data entry
        threshold_entry = {"nrmsd_threshold": float(threshold), "clusters": []}

        for cluster_indices, medoid_idx in zip(clusters.values(), medoids):
            representative = str(file_paths[medoid_idx])
            members = [
                str(file_paths[idx]) for idx in cluster_indices if idx != medoid_idx
            ]

            threshold_entry["clusters"].append(
                {"representative": representative, "members": members}
            )

        threshold_data.append(threshold_entry)

    return threshold_data
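
# Why the merge heights suffice: complete linkage is monotone (no inversions),
# so cutting the dendrogram exactly at each height in linkage_matrix[:, 2]
# enumerates every distinct flat clustering without scanning a threshold grid.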


def find_structure_clusters(
    structures: List[Structure],
    file_paths: List[Path],
    cache: NRMSDCache,
    visualize: bool = False,
    rmsd_method: str = "quaternions",
) -> np.ndarray:
    """
    Compute the pairwise nRMSD distance matrix used for hierarchical clustering.

    Parameters:
    -----------
    structures : List[Structure]
        List of parsed structures to analyze
    file_paths : List[Path]
        File paths corresponding to structures (used as cache keys)
    cache : NRMSDCache
        Cache of previously computed nRMSD values
    visualize : bool
        Accepted for API symmetry; plotting itself happens in run_exact
    rmsd_method : str
        RMSD calculation method ("quaternions", "svd", "qcp", or "validate")

    Returns:
    --------
    np.ndarray
        Distance matrix between all structures
    """
    n_structures = len(structures)

    if n_structures == 1:
        return np.zeros((1, 1))

    # Get nucleotide residues for each structure
    nucleotide_lists = []
    for structure in structures:
        nucleotide_lists.append(
            [residue for residue in structure.residues if residue.is_nucleotide]
        )

    # Select nRMSD function based on method
    if rmsd_method == "quaternions":
        nrmsd_func = nrmsd_quaternions_residues
        print("Computing pairwise nRMSD distances using quaternion method...")
    elif rmsd_method == "svd":
        nrmsd_func = nrmsd_svd_residues
        print("Computing pairwise nRMSD distances using SVD method...")
    elif rmsd_method == "qcp":
        nrmsd_func = nrmsd_qcp_residues
        print("Computing pairwise nRMSD distances using QCP method...")
    elif rmsd_method == "validate":
        nrmsd_func = nrmsd_validate_residues
        print(
            "Computing pairwise nRMSD distances using validation mode (all methods)..."
        )
    else:
        raise ValueError(f"Unknown RMSD method: {rmsd_method}")

    distance_matrix = np.zeros((n_structures, n_structures))

    # Prepare all pairs, checking cache first
    cached_pairs = []
    compute_pairs = []

    for i, j in itertools.combinations(range(n_structures), 2):
        cached_value = cache.get(file_paths[i], file_paths[j], rmsd_method)
        if cached_value is not None:
            cached_pairs.append((i, j, cached_value))
        else:
            compute_pairs.append((i, j, nucleotide_lists[i], nucleotide_lists[j]))

    print(
        f"Found {len(cached_pairs)} cached values, computing {len(compute_pairs)} new values"
    )

    # Fill distance matrix with cached values
    for i, j, nrmsd_value in cached_pairs:
        distance_matrix[i, j] = nrmsd_value
        distance_matrix[j, i] = nrmsd_value

    # Process remaining pairs with progress bar and timing
    if compute_pairs:
        start_time = time.time()
        with ProcessPoolExecutor() as executor:
            futures_dict = {
                executor.submit(nrmsd_func, nucleotides_i, nucleotides_j): (i, j)
                for i, j, nucleotides_i, nucleotides_j in compute_pairs
            }
            results = []
            for future in tqdm(
                futures_dict,
                total=len(futures_dict),
                desc="Computing nRMSD",
                unit="pair",
            ):
                i, j = futures_dict[future]
                nrmsd_value = future.result()
                results.append((i, j, nrmsd_value))

                # Cache the computed value
                cache.set(file_paths[i], file_paths[j], rmsd_method, nrmsd_value)

        end_time = time.time()
        computation_time = end_time - start_time

        print(f"RMSD computation completed in {computation_time:.2f} seconds")
        if rmsd_method == "validate":
            print(
                "Note: Validation mode tests all methods, so timing includes overhead from multiple calculations"
            )

        # Fill the distance matrix with computed values
        for i, j, nrmsd in results:
            distance_matrix[i, j] = nrmsd
            distance_matrix[j, i] = nrmsd

    # Save cache after all computations
    cache.save_cache()

    # The caller builds the linkage matrix from this distance matrix
    return distance_matrix
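
# Implementation note: iterating over futures_dict yields futures in
# submission order, so the progress bar advances as results are collected;
# concurrent.futures.as_completed(futures_dict) would advance it as pairs
# actually finish, at the cost of non-deterministic ordering.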


def get_clustering_at_threshold(
    linkage_matrix: np.ndarray,
    distance_matrix: np.ndarray,
    file_paths: List[Path],
    threshold: float,
) -> dict:
    """
    Get clustering results at a specific threshold.

    Parameters:
    -----------
    linkage_matrix : np.ndarray
        Linkage matrix from hierarchical clustering
    distance_matrix : np.ndarray
        Square distance matrix
    file_paths : List[Path]
        List of file paths corresponding to structures
    threshold : float
        nRMSD threshold for clustering

    Returns:
    --------
    dict
        Dictionary with clustering information at the given threshold
    """
    # Get cluster assignments at this threshold
    labels = fcluster(linkage_matrix, threshold, criterion="distance")
    n_clusters = len(np.unique(labels))

    # Group structure indices by cluster
    clusters = {}
    for i, label in enumerate(labels):
        if label not in clusters:
            clusters[label] = []
        clusters[label].append(i)

    cluster_sizes = [len(cluster) for cluster in clusters.values()]
    cluster_sizes.sort(reverse=True)

    # Find medoids for each cluster
    medoids = find_cluster_medoids(list(clusters.values()), distance_matrix)

    # Create result data
    result = {
        "nrmsd_threshold": float(threshold),
        "n_clusters": n_clusters,
        "cluster_sizes": cluster_sizes,
        "clusters": [],
    }

    for cluster_indices, medoid_idx in zip(clusters.values(), medoids):
        representative = str(file_paths[medoid_idx])
        members = [str(file_paths[idx]) for idx in cluster_indices if idx != medoid_idx]

        result["clusters"].append(
            {"representative": representative, "members": members}
        )

    return result


def exponential_decay(x: np.ndarray, a: float, b: float, c: float) -> np.ndarray:
    """
    Exponential decay function: y = a * exp(-b * x) + c

    Parameters:
    -----------
    x : np.ndarray
        Input values
    a : float
        Amplitude parameter
    b : float
        Decay rate parameter
    c : float
        Offset parameter

    Returns:
    --------
    np.ndarray
        Function values
    """
    return a * np.exp(-b * x) + c


def fit_exponential_decay(
    x: np.ndarray, y: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Fit exponential decay function to data and find inflection points.

    Parameters:
    -----------
    x : np.ndarray
        X coordinates (thresholds)
    y : np.ndarray
        Y coordinates (cluster counts)

    Returns:
    --------
    Tuple[np.ndarray, np.ndarray, np.ndarray]
        Tuple of (x_smooth, y_smooth, inflection_x) where:
        - x_smooth: smooth x values for plotting the fitted curve
        - y_smooth: smooth y values for plotting the fitted curve
        - inflection_x: x coordinates of inflection points
    """
    if len(x) < 4:
        return x, y, np.array([])

    # Sort data by x values
    sort_idx = np.argsort(x)
    x_sorted = x[sort_idx]
    y_sorted = y[sort_idx]

    try:
        # Initial parameter guess
        a_guess = y_sorted.max() - y_sorted.min()
        b_guess = 1.0
        c_guess = y_sorted.min()

        # Fit exponential decay
        popt, _ = curve_fit(
            exponential_decay,
            x_sorted,
            y_sorted,
            p0=[a_guess, b_guess, c_guess],
            maxfev=5000,
        )

        a_fit, b_fit, c_fit = popt

        # Generate smooth curve for plotting
        x_smooth = np.linspace(x_sorted.min(), x_sorted.max(), 200)
        y_smooth = exponential_decay(x_smooth, a_fit, b_fit, c_fit)

        # For exponential decay y = a*exp(-b*x) + c the second derivative is
        # y'' = a*b^2*exp(-b*x), which is strictly positive when a > 0 and
        # b > 0, so the fitted curve has no true inflection point. Instead we
        # report two heuristic "key points":
        #   1. the knee at x = 1/b, where the curve transitions from steep to
        #      gradual decline;
        #   2. the maximum of y'' within the data range, excluding the edges.
        knee_x = 1.0 / b_fit if b_fit > 0 else None

        second_deriv_vals = a_fit * (b_fit**2) * np.exp(-b_fit * x_smooth)

        # Exclude edge points (first and last 10% of data)
        edge_margin = int(0.1 * len(x_smooth))
        if edge_margin < 1:
            edge_margin = 1

        # y'' decreases monotonically in x, so this argmax lands at the left
        # edge of the allowed window
        max_curvature_idx = (
            np.argmax(second_deriv_vals[edge_margin:-edge_margin]) + edge_margin
        )
        max_curvature_x = x_smooth[max_curvature_idx]

        inflection_points = []

        # Add knee point if it's within data range and not at edges
        if knee_x is not None and x_sorted.min() + 0.1 * (
            x_sorted.max() - x_sorted.min()
        ) <= knee_x <= x_sorted.max() - 0.1 * (x_sorted.max() - x_sorted.min()):
            inflection_points.append(knee_x)

        # Add maximum curvature point if it's meaningful and different from knee
        if x_sorted.min() + 0.1 * (
            x_sorted.max() - x_sorted.min()
        ) <= max_curvature_x <= x_sorted.max() - 0.1 * (
            x_sorted.max() - x_sorted.min()
        ) and (
            not inflection_points
            or abs(max_curvature_x - inflection_points[0])
            > 0.05 * (x_sorted.max() - x_sorted.min())
        ):
            inflection_points.append(max_curvature_x)

        inflection_x = np.array(inflection_points)

        print(
            f"Exponential decay fit: y = {a_fit:.3f} * exp(-{b_fit:.3f} * x) + {c_fit:.3f}"
        )

        return x_smooth, y_smooth, inflection_x

    except Exception as e:
        print(f"Warning: Exponential decay fitting failed: {e}")
        return x, y, np.array([])
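
# Worked example (hypothetical fit): with a = 40, b = 12, c = 3 the model is
# y = 40 * exp(-12 * x) + 3; the knee heuristic picks x = 1/b ~ 0.083, where
# the cluster count has dropped from 43 at x = 0 to 40/e + 3 ~ 17.7.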


def determine_optimal_threshold(
    distance_matrix: np.ndarray, linkage_matrix: np.ndarray
) -> float:
    """
    Determine optimal threshold from exponential decay inflection point.

    Parameters:
    -----------
    distance_matrix : np.ndarray
        Square distance matrix
    linkage_matrix : np.ndarray
        Linkage matrix from hierarchical clustering

    Returns:
    --------
    float
        Optimal threshold value
    """
    # Extract merge distances from linkage matrix
    merge_distances = linkage_matrix[:, 2]
    valid_thresholds = np.sort(merge_distances)

    # Calculate cluster counts for each threshold
    cluster_counts = []
    for threshold in valid_thresholds:
        labels = fcluster(linkage_matrix, threshold, criterion="distance")
        n_clusters = len(np.unique(labels))
        cluster_counts.append(n_clusters)

    thresholds = np.array(valid_thresholds)
    cluster_counts = np.array(cluster_counts)

    # Fit exponential decay and find inflection points
    x_smooth, y_smooth, inflection_x = fit_exponential_decay(thresholds, cluster_counts)

    if len(inflection_x) > 0:
        # Use the first inflection point as the optimal threshold
        optimal_threshold = inflection_x[0]
        print(f"Auto-detected optimal threshold: {optimal_threshold:.6f}")
        return optimal_threshold
    else:
        # Fallback to a reasonable default if no inflection points found
        fallback_threshold = 0.1
        print(
            f"No inflection points found, using fallback threshold: {fallback_threshold}"
        )
        return fallback_threshold


def find_cluster_medoids(
    clusters: List[List[int]], distance_matrix: np.ndarray
) -> List[int]:
    """
    Find the medoid (representative) for each cluster.

    Parameters:
    -----------
    clusters : List[List[int]]
        List of clusters, where each cluster is a list of structure indices
    distance_matrix : np.ndarray
        Square distance matrix between all structures

    Returns:
    --------
    List[int]
        List of medoid indices, one for each cluster
    """
    medoids = []

    for cluster in clusters:
        if len(cluster) == 1:
            # Single-element cluster: it is its own medoid
            medoids.append(cluster[0])
        else:
            # Find the element with the minimum sum of distances to all other
            # elements in the cluster
            min_sum_distance = float("inf")
            medoid = cluster[0]

            for candidate in cluster:
                sum_distance = sum(
                    distance_matrix[candidate, other]
                    for other in cluster
                    if other != candidate
                )
                if sum_distance < min_sum_distance:
                    min_sum_distance = sum_distance
                    medoid = candidate

            medoids.append(medoid)

    return medoids
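
# Medoid sanity check (toy numbers): for a cluster [0, 1, 2] with
# d(0,1) = 1, d(0,2) = 4, d(1,2) = 2 the distance sums are 5, 3 and 6,
# so structure 1 is chosen as the representative.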


def main():
    """Main entry point for the distiller CLI tool."""
    args = parse_arguments()

    # Validate input files
    valid_files = validate_input_files(args.files)

    if not valid_files:
        print("Error: No valid input files found", file=sys.stderr)
        sys.exit(1)

    print(f"Processing {len(valid_files)} files")

    # Parse all structure files, keeping the path list aligned with the
    # structures that actually parsed (a mid-list failure must not shift
    # the pairing between paths and structures)
    print("Parsing structure files...")
    structures = []
    parsed_files = []
    for file_path in valid_files:
        try:
            structure = parse_structure_file(file_path)
            structures.append(structure)
            parsed_files.append(file_path)
            print(f"  Parsed {file_path}")
        except Exception:
            print(f"  Failed to parse {file_path}, skipping", file=sys.stderr)
            continue

    if not structures:
        print("Error: No structures could be parsed", file=sys.stderr)
        sys.exit(1)

    valid_files = parsed_files

    # Validate nucleotide counts
    print("\nValidating nucleotide counts...")
    validate_nucleotide_counts(structures, valid_files)

    # Switch workflow based on requested mode
    if args.mode == "approximate":
        run_approximate(structures, valid_files, args)
    else:
        run_exact(structures, valid_files, args)


if __name__ == "__main__":
    main()