speconsense 0.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,461 @@
+ """Base classes and protocols for scalability features."""
+
+ from abc import ABC, abstractmethod
+ from typing import Callable, Dict, List, Set, Tuple, Optional, Protocol, runtime_checkable
+ import logging
+
+ from tqdm import tqdm
+
+ from .config import ScalabilityConfig
+
+
+ @runtime_checkable
+ class CandidateFinder(Protocol):
+     """Protocol for fast approximate candidate finding.
+
+     Implementations must be able to:
+     1. Build an index from a set of sequences
+     2. Find candidate matches for query sequences
+     3. Clean up resources when done
+     """
+
+     @property
+     def name(self) -> str:
+         """Human-readable name of this backend."""
+         ...
+
+     @property
+     def is_available(self) -> bool:
+         """Check if this backend is available (e.g., tool installed)."""
+         ...
+
+     def build_index(self,
+                     sequences: Dict[str, str],
+                     output_dir: str) -> None:
+         """Build search index from sequences.
+
+         Args:
+             sequences: Dict mapping sequence_id -> sequence_string
+             output_dir: Directory for any cache/index files
+         """
+         ...
+
+     def find_candidates(self,
+                         query_ids: List[str],
+                         sequences: Dict[str, str],
+                         min_identity: float,
+                         max_candidates: int) -> Dict[str, List[str]]:
+         """Find candidate matches for query sequences.
+
+         Args:
+             query_ids: List of sequence IDs to query
+             sequences: Dict mapping sequence_id -> sequence_string
+             min_identity: Minimum identity threshold (0.0-1.0)
+             max_candidates: Maximum candidates to return per query
+
+         Returns:
+             Dict mapping query_id -> list of candidate target_ids
+         """
+         ...
+
+     def cleanup(self) -> None:
+         """Clean up any temporary files or resources."""
+         ...
+
+
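+ # Illustrative sketch (not part of the released file): a minimal in-memory
+ # backend showing what satisfying the CandidateFinder protocol looks like.
+ # It does no real indexing and simply proposes every other sequence as a
+ # candidate (min_identity is ignored because nothing is aligned at this
+ # stage), so it is only sensible for tiny inputs.
+ class EverythingFinder:
+     """Toy CandidateFinder that returns all other sequences as candidates."""
+
+     @property
+     def name(self) -> str:
+         return "everything"
+
+     @property
+     def is_available(self) -> bool:
+         return True
+
+     def build_index(self, sequences: Dict[str, str], output_dir: str) -> None:
+         self._ids = sorted(sequences)
+
+     def find_candidates(self, query_ids: List[str], sequences: Dict[str, str],
+                         min_identity: float, max_candidates: int) -> Dict[str, List[str]]:
+         return {q: [t for t in self._ids if t != q][:max_candidates]
+                 for q in query_ids}
+
+     def cleanup(self) -> None:
+         pass
+
+ # Since CandidateFinder is @runtime_checkable,
+ # isinstance(EverythingFinder(), CandidateFinder) evaluates to True.
+
+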
+ class ScalablePairwiseOperation:
+     """Generic scalable pairwise operation using candidate pre-filtering.
+
+     This class encapsulates the two-stage pattern:
+     1. Fast candidate finding using CandidateFinder (e.g., vsearch)
+     2. Exact scoring using provided scoring function
+     """
+
+     def __init__(self,
+                  candidate_finder: Optional[CandidateFinder],
+                  scoring_function: Callable[[str, str, str, str], float],
+                  config: ScalabilityConfig):
+         """Initialize scalable pairwise operation.
+
+         Args:
+             candidate_finder: Backend for fast candidate finding (None = brute force only)
+             scoring_function: Function(seq1, seq2, id1, id2) -> similarity_score (0.0-1.0)
+             config: Scalability configuration
+         """
+         self.candidate_finder = candidate_finder
+         self.scoring_function = scoring_function
+         self.config = config
+
+     def compute_top_k_neighbors(self,
+                                 sequences: Dict[str, str],
+                                 k: int,
+                                 min_identity: float,
+                                 output_dir: str,
+                                 min_edges_per_node: int = 3) -> Dict[str, List[Tuple[str, float]]]:
+         """Compute top-k nearest neighbors for all sequences.
+
+         Args:
+             sequences: Dict mapping sequence_id -> sequence_string
+             k: Number of neighbors to find per sequence
+             min_identity: Minimum identity threshold for neighbors
+             output_dir: Directory for temporary files
+             min_edges_per_node: Minimum edges to ensure connectivity
+
+         Returns:
+             Dict mapping sequence_id -> list of (neighbor_id, similarity) tuples
+         """
+         n = len(sequences)
+
+         # Decide whether to use scalable or brute-force approach
+         use_scalable = (
+             self.config.enabled and
+             self.candidate_finder is not None and
+             self.candidate_finder.is_available and
+             n >= self.config.activation_threshold
+         )
+
+         if use_scalable:
+             return self._compute_knn_scalable(sequences, k, min_identity, output_dir, min_edges_per_node)
+         else:
+             return self._compute_knn_brute_force(sequences, k, min_identity, min_edges_per_node)
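+
+     # Illustrative usage sketch with toy data (not part of the released
+     # file). With candidate_finder=None the brute-force path is always taken:
+     #
+     #     def toy_score(seq1, seq2, id1, id2):
+     #         same = sum(a == b for a, b in zip(seq1, seq2))
+     #         return same / max(len(seq1), len(seq2))
+     #
+     #     op = ScalablePairwiseOperation(None, toy_score, ScalabilityConfig())
+     #     op.compute_top_k_neighbors(
+     #         {"a": "ACGT", "b": "ACGA", "c": "TTTT"},
+     #         k=2, min_identity=0.5, output_dir="/tmp")
+     #     # -> {'a': [('b', 0.75)], 'b': [], 'c': []}
+     #
+     # 'b' comes back empty even though score('a', 'b') == 0.75: the
+     # brute-force path's asymmetric similarity dict (documented below) only
+     # records neighbors that sort after a given ID.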
121
+
122
+ def _compute_knn_scalable(self,
123
+ sequences: Dict[str, str],
124
+ k: int,
125
+ min_identity: float,
126
+ output_dir: str,
127
+ min_edges_per_node: int) -> Dict[str, List[Tuple[str, float]]]:
128
+ """Two-stage scalable K-NN computation."""
129
+ logging.debug(f"Using {self.candidate_finder.name}-based scalable K-NN computation")
130
+
131
+ # Build index
132
+ self.candidate_finder.build_index(sequences, output_dir)
133
+
134
+ # Find candidates with oversampling and relaxed threshold
135
+ candidate_count = k * self.config.oversampling_factor
136
+ relaxed_threshold = min_identity * self.config.relaxed_identity_factor
137
+
138
+ seq_ids = sorted(sequences.keys())
139
+ candidates = self.candidate_finder.find_candidates(
140
+ seq_ids, sequences, relaxed_threshold, candidate_count
141
+ )
142
+
143
+ # Refine with exact scoring
144
+ results: Dict[str, List[Tuple[str, float]]] = {}
145
+
146
+ with tqdm(total=len(seq_ids), desc="Refining K-NN with exact scoring") as pbar:
147
+ for seq_id in seq_ids:
148
+ seq_candidates = candidates.get(seq_id, [])
149
+
150
+ # Score all candidates
151
+ scored = []
152
+ for cand_id in seq_candidates:
153
+ if cand_id != seq_id:
154
+ score = self.scoring_function(sequences[seq_id], sequences[cand_id], seq_id, cand_id)
155
+ scored.append((cand_id, score))
156
+
157
+ # Sort by score descending
158
+ scored.sort(key=lambda x: x[1], reverse=True)
159
+
160
+ # Take top k meeting threshold
161
+ top_k = [(cid, score) for cid, score in scored[:k] if score >= min_identity]
162
+
163
+ # Ensure minimum connectivity
164
+ if len(top_k) < min_edges_per_node and len(scored) >= min_edges_per_node:
165
+ for cid, score in scored[len(top_k):]:
166
+ if score >= min_identity * self.config.relaxed_identity_factor:
167
+ top_k.append((cid, score))
168
+ if len(top_k) >= min_edges_per_node:
169
+ break
170
+
171
+ results[seq_id] = top_k
172
+ pbar.update(1)
173
+
174
+ return results
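+
+     # Worked example of the candidate-stage arithmetic above, using
+     # ScalabilityConfig defaults: with k=5, oversampling_factor=10 and
+     # min_identity=0.90, stage 1 requests up to 5 * 10 = 50 candidates per
+     # query at the relaxed threshold 0.90 * 0.9 = 0.81; stage 2 then
+     # rescores those candidates exactly and keeps the top 5 scoring >= 0.90.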
+
+     def _compute_knn_brute_force(self,
+                                  sequences: Dict[str, str],
+                                  k: int,
+                                  min_identity: float,
+                                  min_edges_per_node: int) -> Dict[str, List[Tuple[str, float]]]:
+         """Standard O(n^2) brute-force K-NN computation."""
+         logging.debug("Using brute-force K-NN computation")
+
+         seq_ids = sorted(sequences.keys())
+         n = len(seq_ids)
+
+         # Compute all pairwise similarities
+         # IMPORTANT: This matches main branch's asymmetric dict structure exactly.
+         # Main branch creates similarities[id1] = {} inside the loop, which overwrites
+         # any entries added via setdefault(). The result is that similarities[id1]
+         # only contains entries for id2 > id1 (lexically). This affects tie-breaking
+         # when selecting top-k neighbors.
+         similarities: Dict[str, Dict[str, float]] = {}
+
+         total = (n * (n - 1)) // 2
+         with tqdm(total=total, desc="Computing pairwise similarities") as pbar:
+             for id1 in seq_ids:
+                 similarities[id1] = {}
+                 for id2 in seq_ids:
+                     if id1 >= id2:  # Only calculate upper triangle (id2 > id1)
+                         continue
+                     score = self.scoring_function(sequences[id1], sequences[id2], id1, id2)
+                     similarities[id1][id2] = score
+                     similarities.setdefault(id2, {})[id1] = score  # Mirror for lookup
+                     pbar.update(1)
+
+         # Extract top-k for each sequence
+         results: Dict[str, List[Tuple[str, float]]] = {}
+
+         for seq_id in seq_ids:
+             neighbors = sorted(
+                 similarities[seq_id].items(),
+                 key=lambda x: x[1],
+                 reverse=True
+             )
+
+             top_k = [(nid, score) for nid, score in neighbors[:k] if score >= min_identity]
+
+             # Ensure minimum connectivity
+             if len(top_k) < min_edges_per_node and len(neighbors) >= min_edges_per_node:
+                 # Resume right after the accepted prefix (mirrors the scalable path)
+                 for nid, score in neighbors[len(top_k):]:
+                     if score >= min_identity * 0.9:
+                         top_k.append((nid, score))
+                     if len(top_k) >= min_edges_per_node:
+                         break
+
+             results[seq_id] = top_k
+
+         return results
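+
+     # Illustrative trace of the asymmetric structure above for seq_ids
+     # ['a', 'b', 'c'] (toy data, not part of the released file):
+     #
+     #     similarities == {
+     #         'a': {'b': s_ab, 'c': s_ac},  # all IDs sorting after 'a'
+     #         'b': {'c': s_bc},             # mirror entry for 'a' was wiped
+     #         'c': {},                      # reset to {} on its own turn
+     #     }
+     #
+     # Each unordered pair therefore contributes one directed edge during
+     # top-k extraction.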
+
+     def compute_distance_matrix(self,
+                                 sequences: Dict[str, str],
+                                 output_dir: str,
+                                 min_identity: float = 0.9) -> Dict[Tuple[str, str], float]:
+         """Compute pairwise distance matrix (for HAC clustering).
+
+         Args:
+             sequences: Dict mapping sequence_id -> sequence_string
+             output_dir: Directory for temporary files
+             min_identity: Identity threshold for clustering (used to filter candidates)
+
+         Returns:
+             Dict mapping (id1, id2) -> distance, symmetric
+         """
+         n = len(sequences)
+         logging.debug(f"compute_distance_matrix called with {n} sequences")
+
+         # For small sets or when scalability disabled, use brute force
+         use_scalable = (
+             self.config.enabled and
+             self.candidate_finder is not None and
+             self.candidate_finder.is_available and
+             n >= self.config.activation_threshold and
+             n > 50  # Only worthwhile for larger sets
+         )
+
+         logging.debug(f"use_scalable={use_scalable} (enabled={self.config.enabled}, "
+                       f"finder={self.candidate_finder is not None}, "
+                       f"available={self.candidate_finder.is_available if self.candidate_finder else 'N/A'}, "
+                       f"threshold={self.config.activation_threshold})")
+
+         if use_scalable:
+             return self._compute_distance_matrix_scalable(sequences, output_dir, min_identity)
+         else:
+             return self._compute_distance_matrix_brute_force(sequences)
+
+     def _compute_distance_matrix_brute_force(self,
+                                              sequences: Dict[str, str]) -> Dict[Tuple[str, str], float]:
+         """Brute-force distance matrix computation."""
+         seq_ids = sorted(sequences.keys())
+         distances: Dict[Tuple[str, str], float] = {}
+
+         total = (len(seq_ids) * (len(seq_ids) - 1)) // 2
+         with tqdm(total=total, desc="Computing pairwise distances") as pbar:
+             for i, id1 in enumerate(seq_ids):
+                 for id2 in seq_ids[i + 1:]:
+                     score = self.scoring_function(sequences[id1], sequences[id2], id1, id2)
+                     distance = 1.0 - score  # Convert similarity to distance
+                     distances[(id1, id2)] = distance
+                     distances[(id2, id1)] = distance
+                     pbar.update(1)
+
+         return distances
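+
+     # For example, a similarity score of 0.97 (97% identity) is stored as
+     # distance 1.0 - 0.97 = 0.03, under both (id1, id2) and (id2, id1).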
+
+     def _compute_distance_matrix_scalable(self,
+                                           sequences: Dict[str, str],
+                                           output_dir: str,
+                                           min_identity: float) -> Dict[Tuple[str, str], float]:
+         """Scalable distance matrix using candidates to reduce comparisons."""
+         logging.debug(f"Using {self.candidate_finder.name}-based scalable distance matrix")
+
+         # Build index
+         self.candidate_finder.build_index(sequences, output_dir)
+
+         seq_ids = sorted(sequences.keys())
+         n = len(seq_ids)
+
+         # Use same safety factors as K-NN computation
+         relaxed_threshold = min_identity * self.config.relaxed_identity_factor
+         max_candidates = 500
+
+         logging.debug(f"Finding candidates: identity>={relaxed_threshold:.2f}, max_candidates={max_candidates}")
+         all_candidates = self.candidate_finder.find_candidates(
+             seq_ids, sequences, relaxed_threshold, max_candidates
+         )
+
+         distances: Dict[Tuple[str, str], float] = {}
+         computed_pairs: Set[Tuple[str, str]] = set()
+
+         with tqdm(total=len(seq_ids), desc="Computing distances for candidates") as pbar:
+             for id1 in seq_ids:
+                 for id2 in all_candidates.get(id1, []):
+                     pair = (min(id1, id2), max(id1, id2))
+                     if pair not in computed_pairs and id1 != id2:
+                         score = self.scoring_function(sequences[id1], sequences[id2], id1, id2)
+                         distance = 1.0 - score
+                         distances[(id1, id2)] = distance
+                         distances[(id2, id1)] = distance
+                         computed_pairs.add(pair)
+                 pbar.update(1)
+
+         # Return sparse matrix - missing pairs are treated as distance 1.0 by consumers
+         logging.debug(f"Computed {len(computed_pairs)} distance pairs (sparse matrix)")
+
+         return distances
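+
+     # Illustrative consumer-side lookup for the sparse matrix returned above
+     # (hypothetical helper, not part of the released file). Computed pairs
+     # are stored in both orders, so a plain .get() with the 1.0 fallback
+     # recovers the dense semantics:
+     #
+     #     def pair_distance(distances, a, b):
+     #         return 0.0 if a == b else distances.get((a, b), 1.0)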
+
+     def compute_equivalence_groups(self,
+                                    sequences: Dict[str, str],
+                                    equivalence_fn: Callable[[str, str], bool],
+                                    output_dir: str,
+                                    min_candidate_identity: float = 0.95) -> List[List[str]]:
+         """Compute groups of equivalent sequences using candidate pre-filtering.
+
+         This is useful for merging clusters whose consensus sequences are
+         identical or homopolymer-equivalent. Uses union-find for transitive grouping.
+
+         Args:
+             sequences: Dict mapping sequence_id -> sequence_string
+             equivalence_fn: Function(seq1, seq2) -> bool for exact equivalence check
+             output_dir: Directory for temporary files
+             min_candidate_identity: Min identity threshold for candidates (default 0.95)
+
+         Returns:
+             List of groups, where each group is a list of equivalent sequence IDs
+         """
+         n = len(sequences)
+
+         # For small sets or when scalability disabled, use brute force
+         use_scalable = (
+             self.config.enabled and
+             self.candidate_finder is not None and
+             self.candidate_finder.is_available and
+             n >= self.config.activation_threshold and
+             n > 50  # Only worthwhile for larger sets
+         )
+
+         if use_scalable:
+             return self._compute_equivalence_groups_scalable(
+                 sequences, equivalence_fn, output_dir, min_candidate_identity
+             )
+         else:
+             return self._compute_equivalence_groups_brute_force(sequences, equivalence_fn)
+
+     def _compute_equivalence_groups_brute_force(self,
+                                                 sequences: Dict[str, str],
+                                                 equivalence_fn: Callable[[str, str], bool]) -> List[List[str]]:
+         """Brute-force O(n²) equivalence group computation."""
+         seq_ids = sorted(sequences.keys())
+
+         # Union-find data structure
+         parent = {sid: sid for sid in seq_ids}
+
+         def find(x: str) -> str:
+             if parent[x] != x:
+                 parent[x] = find(parent[x])
+             return parent[x]
+
+         def union(x: str, y: str) -> None:
+             px, py = find(x), find(y)
+             if px != py:
+                 parent[px] = py
+
+         # Check all pairs
+         total = (len(seq_ids) * (len(seq_ids) - 1)) // 2
+         with tqdm(total=total, desc="Finding equivalent pairs") as pbar:
+             for i, id1 in enumerate(seq_ids):
+                 for id2 in seq_ids[i + 1:]:
+                     if equivalence_fn(sequences[id1], sequences[id2]):
+                         union(id1, id2)
+                     pbar.update(1)
+
+         # Collect groups
+         groups: Dict[str, List[str]] = {}
+         for sid in seq_ids:
+             root = find(sid)
+             if root not in groups:
+                 groups[root] = []
+             groups[root].append(sid)
+
+         return list(groups.values())
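+
+     # Toy example of the union-find grouping above (not part of the released
+     # file). With an equivalence_fn that collapses homopolymer runs before
+     # comparing:
+     #
+     #     def hp_equal(s1, s2):
+     #         def collapse(s):
+     #             return "".join(c for i, c in enumerate(s)
+     #                            if i == 0 or c != s[i - 1])
+     #         return collapse(s1) == collapse(s2)
+     #
+     # the sequences {'w': 'TTTT', 'x': 'AACGT', 'y': 'ACGT', 'z': 'ACGTT'}
+     # collapse to 'T', 'ACGT', 'ACGT', 'ACGT', so the pairwise checks find
+     # x~y, x~z and y~z, and union-find returns the groups ['x', 'y', 'z']
+     # and ['w'].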
+
+     def _compute_equivalence_groups_scalable(self,
+                                              sequences: Dict[str, str],
+                                              equivalence_fn: Callable[[str, str], bool],
+                                              output_dir: str,
+                                              min_candidate_identity: float) -> List[List[str]]:
+         """Scalable equivalence group computation using candidate pre-filtering."""
+         logging.debug(f"Using {self.candidate_finder.name}-based equivalence grouping")
+
+         # Build index
+         self.candidate_finder.build_index(sequences, output_dir)
+
+         seq_ids = sorted(sequences.keys())
+         n = len(seq_ids)
+
+         # Find candidates with high identity (likely equivalent sequences)
+         # Use a reasonable max_candidates to limit work while still finding all equivalents
+         max_candidates = min(n, 100)
+         all_candidates = self.candidate_finder.find_candidates(
+             seq_ids, sequences, min_candidate_identity, max_candidates
+         )
+
+         # Union-find data structure
+         parent = {sid: sid for sid in seq_ids}
+
+         def find(x: str) -> str:
+             if parent[x] != x:
+                 parent[x] = find(parent[x])
+             return parent[x]
+
+         def union(x: str, y: str) -> None:
+             px, py = find(x), find(y)
+             if px != py:
+                 parent[px] = py
+
+         # Check only candidate pairs
+         checked_pairs: Set[Tuple[str, str]] = set()
+         equivalent_count = 0
+
+         with tqdm(total=len(seq_ids), desc="Checking candidate equivalences") as pbar:
+             for id1 in seq_ids:
+                 for id2 in all_candidates.get(id1, []):
+                     pair = (min(id1, id2), max(id1, id2))
+                     if pair not in checked_pairs and id1 != id2:
+                         if equivalence_fn(sequences[id1], sequences[id2]):
+                             union(id1, id2)
+                             equivalent_count += 1
+                         checked_pairs.add(pair)
+                 pbar.update(1)
+
+         logging.debug(f"Found {equivalent_count} equivalent pairs from {len(checked_pairs)} candidates")
+
+         # Collect groups
+         groups: Dict[str, List[str]] = {}
+         for sid in seq_ids:
+             root = find(sid)
+             if root not in groups:
+                 groups[root] = []
+             groups[root].append(sid)
+
+         return list(groups.values())
@@ -0,0 +1,42 @@
+ """Configuration for scalability features."""
+
+ from dataclasses import dataclass
+ from typing import Optional
+
+
+ @dataclass
+ class ScalabilityConfig:
+     """Configuration for scalability features.
+
+     Attributes:
+         enabled: Whether scalability mode is active
+         activation_threshold: Minimum sequence count to activate scalability
+         max_threads: Max threads for internal parallelism (default: 1 for backward compatibility)
+         backend: Which backend to use (default: 'vsearch')
+         oversampling_factor: Multiplier for candidate count in K-NN (default: 10)
+         relaxed_identity_factor: Factor to relax identity threshold for candidates (default: 0.9)
+         batch_size: Number of sequences per batch for vsearch queries (default: 1000)
+     """
+     enabled: bool = False
+     activation_threshold: int = 0
+     max_threads: int = 1
+     backend: str = 'vsearch'
+     oversampling_factor: int = 10
+     relaxed_identity_factor: float = 0.9
+     batch_size: int = 1000
+
+     @classmethod
+     def from_args(cls, args) -> 'ScalabilityConfig':
+         """Create config from command-line arguments.
+
+         The scale_threshold arg controls scalability:
+         - 0: disabled
+         - N > 0: enabled for datasets >= N sequences (default: 1001)
+         """
+         threshold = getattr(args, 'scale_threshold', 1001)
+         return cls(
+             enabled=threshold > 0,
+             activation_threshold=threshold,
+             max_threads=getattr(args, 'threads', 1),
+             backend=getattr(args, 'scalability_backend', 'vsearch'),
+         )
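+
+
+ # Illustrative sketch (not part of the released file): a quick demonstration
+ # of from_args() using a bare argparse.Namespace in place of real parsed CLI
+ # arguments; the attribute names follow the getattr calls above.
+ if __name__ == "__main__":
+     from argparse import Namespace
+
+     cfg = ScalabilityConfig.from_args(Namespace(scale_threshold=2000, threads=4))
+     print(cfg)  # enabled=True, activation_threshold=2000, max_threads=4, ...
+     # Missing attributes fall back to defaults, e.g. backend='vsearch'.
+     print(ScalabilityConfig.from_args(Namespace(scale_threshold=0)).enabled)  # False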