yamcot 1.0.0__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
yamcot/comparison.py ADDED
@@ -0,0 +1,1066 @@
1
+ """
2
+ comparison
3
+ ==========
4
+
5
+ Implementations of motif comparison metrics. Comparing motifs is
6
+ useful for identifying similar patterns discovered in different datasets
7
+ or cross‑validation folds. This module defines a common
8
+ interface for comparison algorithms and several concrete
9
+ implementations.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import os
15
+ import tempfile
16
+ from abc import ABC, abstractmethod
17
+ from typing import Optional
18
+
19
+ import numpy as np
20
+ from joblib import Parallel, delayed
21
+ from scipy.ndimage import convolve1d
22
+
23
+ from yamcot.execute import run_motali
24
+ from yamcot.functions import (
25
+ _fast_cj_kernel_numba,
26
+ _fast_overlap_kernel_numba,
27
+ _fast_pearson_kernel,
28
+ pfm_to_pwm,
29
+ scores_to_frequencies,
30
+ )
31
+ from yamcot.io import write_fasta
32
+ from yamcot.models import MotifModel, RaggedScores
33
+ from yamcot.ragged import RaggedData, ragged_from_list
34
+
35
+
36
class GeneralMotifComparator(ABC):
    """
    Abstract base class for motif comparators.

    This class defines the common interface for all motif comparison algorithms.
    Concrete implementations should inherit from this class and implement
    the compare method.
    """

    def __init__(self, name: str) -> None:
        """
        Initialize the comparator.

        Parameters
        ----------
        name : str
            Name of the comparator instance.
        """
        self.name = name

    @abstractmethod
    def compare(
        self,
        motif_1: MotifModel,
        motif_2: MotifModel,
        sequences: RaggedData | None = None,
    ) -> dict:
        """
        Compare two motif models pairwise.

        This is an abstract method that must be implemented by subclasses.

        Parameters
        ----------
        motif_1 : MotifModel
            First motif to compare.
        motif_2 : MotifModel
            Second motif to compare.
        sequences : RaggedData or None
            Sequences for frequency calculation (if needed by the implementation).

        Returns
        -------
        dict
            Dictionary containing comparison results.
        """
        raise NotImplementedError
83
+
84
+
85
class TomtomComparator(GeneralMotifComparator):
    """
    Comparator for motifs using Euclidean Distance (ED), Pearson Correlation
    (PCC), or Cosine Similarity, with Monte Carlo p-value estimation.
    """

    def __init__(
        self,
        metric: str = "pcc",
        n_permutations: int = 1000,
        permute_rows: bool = False,
        pfm_mode: bool = False,
        n_jobs: int = 1,
        seed: Optional[int] = None,
    ):
        """
        Initialize comparator.

        Parameters
        ----------
        metric : str
            'pcc', 'ed', or 'cosine'.
        n_permutations : int
            Number of Monte Carlo permutations for p-value calculation.
            Set to 0 to disable.
        permute_rows : bool
            If True, additionally permutes the alphabet axes (destroys
            nucleotide structure). If False, only shuffles columns (positions).
        pfm_mode : bool
            If True, compare PWMs derived from sequence-based PFMs instead of
            the raw motif matrices (requires ``sequences`` in :meth:`compare`).
        n_jobs : int
            Number of parallel jobs. -1 to use all cores.
        seed : int, optional
            Random seed for reproducibility.
        """
        super().__init__(name=f"TomtomComparator_{metric.upper()}")
        self.metric = metric.lower()
        if self.metric not in ["pcc", "ed", "cosine"]:
            raise ValueError(f"Unsupported metric: {metric}. Use 'pcc', 'ed', or 'cosine'.")

        self.n_permutations = n_permutations
        self.permute_rows = permute_rows
        self.pfm_mode = pfm_mode
        self.n_jobs = n_jobs
        self.seed = seed

    def _prepare_matrix(self, matrix: np.ndarray):
        """
        Drop 'N' rows, build the reverse complement for k-mers, and flatten
        the matrix to shape (4**k, L).

        Returns
        -------
        tuple of np.ndarray
            (flat_matrix, flat_rc_matrix), each of shape (4**k, L).
        """
        # Determine the axis layout: the position axis L is assumed to be last.
        # If L comes first (L, 4, 4, ...), move it to the end for convenience.
        if matrix.shape[0] > 5:
            matrix = np.moveaxis(matrix, 0, -1)

        k = matrix.ndim - 1  # k-mer order (1 = mono, 2 = di, 3 = tri)
        L = matrix.shape[-1]

        # Remove 'N' (index 4) along every nucleotide axis.
        # Build a dynamic slice: (slice(0, 4), slice(0, 4), ..., slice(None))
        clean_slice = tuple(slice(0, 4) if i < k else slice(None) for i in range(matrix.ndim))
        matrix = matrix[clean_slice]

        # Build the reverse complement (RC).
        # Complement: invert each nucleotide axis (A<->T, C<->G);
        # in the [A, C, G, T] alphabet this is a simple flip along the axis.
        rc_matrix = matrix.copy()
        for axis in range(k):
            rc_matrix = np.flip(rc_matrix, axis=axis)

        # Reverse: flip along the position axis.
        rc_matrix = np.flip(rc_matrix, axis=-1)

        flat_matrix = matrix.reshape(-1, L)
        flat_rc_matrix = rc_matrix.reshape(-1, L)

        return flat_matrix, flat_rc_matrix

    def _randomize_matrix(self, matrix: np.ndarray, rng: np.random.Generator):
        """
        Shuffle columns and optionally rows (values) in the original multidimensional matrix.

        This function implements a surrogate generation procedure where the nucleotide
        structure can be partially or completely destroyed depending on the permute_rows setting.

        Parameters
        ----------
        matrix : np.ndarray
            Input matrix to randomize.
        rng : np.random.Generator
            Random number generator instance.

        Returns
        -------
        np.ndarray
            Randomized matrix with shuffled columns and optionally rows.
        """
        # Work with a copy of the full dimensionality
        shuffled = matrix.copy()

        # 1. Shuffle columns (positions) along the last axis
        pos_indices = np.arange(shuffled.shape[-1])
        rng.shuffle(pos_indices)
        shuffled = shuffled[..., pos_indices]

        if self.permute_rows:
            # Determine the alphabet size (usually 4 for A, C, G, T)
            alphabet_size = shuffled.shape[0]
            # Draw one shared permutation to use for every axis
            perm = rng.permutation(alphabet_size)

            # Apply the permutation to every axis except the last (positions).
            # This preserves the dependency structure (e.g. AA maps to GG).
            for axis in range(shuffled.ndim - 1):
                shuffled = np.take(shuffled, perm, axis=axis)

        return shuffled

    def _vectorized_pcc(self, m1: np.ndarray, m2: np.ndarray):
        """
        Compute vectorized Pearson Correlation Coefficient between columns of m1 and m2.

        Parameters
        ----------
        m1 : np.ndarray
            Shape (N, L) matrix representing the first motif.
        m2 : np.ndarray
            Shape (N, L) matrix representing the second motif.

        Returns
        -------
        correlations : np.ndarray
            Array of correlations between corresponding columns.
        """
        # Center both matrices by subtracting column means
        m1_centered = m1 - np.mean(m1, axis=0, keepdims=True)
        m2_centered = m2 - np.mean(m2, axis=0, keepdims=True)

        # Compute standard deviations for normalization
        m1_stds = np.sqrt(np.sum(m1_centered**2, axis=0))
        m2_stds = np.sqrt(np.sum(m2_centered**2, axis=0))

        # Handle zero-variance columns by setting std to 1 (will result in 0 correlation)
        m1_stds = np.where(m1_stds == 0, 1, m1_stds)
        m2_stds = np.where(m2_stds == 0, 1, m2_stds)

        # Compute dot product between centered matrices
        numerator = np.sum(m1_centered * m2_centered, axis=0)

        # Compute correlations
        denominators = m1_stds * m2_stds
        correlations = np.where(denominators != 0, numerator / denominators, 0.0)

        return correlations

    def _vectorized_cosine(self, m1: np.ndarray, m2: np.ndarray):
        """
        Compute vectorized Cosine Similarity between columns of m1 and m2.

        Parameters
        ----------
        m1 : np.ndarray
            Shape (N, L) matrix (N = 4 for mono-, 16 for dinucleotides).
        m2 : np.ndarray
            Shape (N, L) matrix.

        Returns
        -------
        similarities : np.ndarray
            Array of cosine similarities between corresponding columns.
        """
        # 1. Dot product (numerator): sum over the nucleotide (row) axis
        numerator = np.sum(m1 * m2, axis=0)

        # 2. L2 norms (vector lengths) for each column
        norm1 = np.sqrt(np.sum(m1**2, axis=0))
        norm2 = np.sqrt(np.sum(m2**2, axis=0))

        # 3. Guard against zero vectors (avoid division by zero):
        # a zero norm means every weight in the column is 0.
        denominators = norm1 * norm2

        # 4. Similarity: divide where the denominator is positive, else 0.0
        similarities = np.where(denominators > 1e-9, numerator / denominators, 0.0)

        return similarities

    def _align_motifs(self, m1: np.ndarray, m2: np.ndarray):
        """
        Align two motifs by sliding one along the other and computing the best score.

        Parameters
        ----------
        m1 : np.ndarray
            First motif matrix of shape (N, L1).
        m2 : np.ndarray
            Second motif matrix of shape (N, L2).

        Returns
        -------
        tuple
            Tuple containing (best_score, best_offset) where:
            best_score : Best alignment score found (higher is better;
            for 'ed' the score is the negated mean column distance).
            best_offset : Offset at which best score occurs.
        """
        L1 = m1.shape[1]
        L2 = m2.shape[1]

        # Z-normalize columns for the Euclidean metric
        if self.metric == "ed":
            m1 = (m1 - np.mean(m1, axis=0)) / (np.std(m1, axis=0) + 1e-9)
            m2 = (m2 - np.mean(m2, axis=0)) / (np.std(m2, axis=0) + 1e-9)

        # All metrics are maximized (ED is negated), so a single -inf sentinel suffices.
        best_score = -np.inf
        best_offset = 0

        min_offset = -(L2 - 1)
        max_offset = L1 - 1
        # Require at least half of the shorter motif to overlap
        min_overlap = min(L2, L1) / 2

        for offset in range(min_offset, max_offset + 1):
            if offset < 0:
                len_overlap = min(L1, L2 + offset)
                if len_overlap < min_overlap:
                    continue
                s1, s2 = slice(0, len_overlap), slice(-offset, -offset + len_overlap)
            else:
                len_overlap = min(L1 - offset, L2)
                if len_overlap < min_overlap:
                    continue
                s1, s2 = slice(offset, offset + len_overlap), slice(0, len_overlap)

            cols1, cols2 = m1[:, s1], m2[:, s2]

            if self.metric == "ed":
                # Sum of column-wise Euclidean distances ||col1_i - col2_i||,
                # negated and length-normalized so that higher is better
                column_distances = np.sqrt(np.sum((cols1 - cols2) ** 2, axis=0))
                current_score = -np.sum(column_distances) / len_overlap
            elif self.metric == "pcc":
                # Mean column-wise Pearson correlation
                correlations = self._vectorized_pcc(cols1, cols2)
                current_score = np.sum(correlations) / len_overlap
            elif self.metric == "cosine":
                # Mean column-wise cosine similarity
                correlations = self._vectorized_cosine(cols1, cols2)
                current_score = np.sum(correlations) / len_overlap
            else:
                # Defensive fallback (metric is validated in __init__):
                # behave like the Euclidean branch
                column_distances = np.sqrt(np.sum((cols1 - cols2) ** 2, axis=0))
                current_score = -np.sum(column_distances) / len_overlap
            if current_score > best_score:
                best_score = current_score
                best_offset = offset

        return best_score, best_offset

    def _run_single_permutation(self, m1_flat: np.ndarray, m2_orig_matrix: np.ndarray, seed: int):
        """
        Worker function for parallel execution.

        Generates one surrogate for m2 and compares it with m1.

        Parameters
        ----------
        m1_flat : np.ndarray
            Flattened version of the first motif matrix.
        m2_orig_matrix : np.ndarray
            Original matrix for the second motif (before flattening).
        seed : int
            Random seed for this permutation.

        Returns
        -------
        float
            Maximum alignment score between m1 and the randomized m2
            over both strand orientations.
        """
        rng = np.random.default_rng(seed)

        # 1. Randomize the original m2 matrix (full dimensionality)
        m2_rand_matrix = self._randomize_matrix(m2_orig_matrix, rng)

        # 2. Prepare randomized matrix (flatten + rc)
        m2_rand_flat, m2_rand_rc_flat = self._prepare_matrix(m2_rand_matrix)

        # 3. Compare against both orientations and keep the best
        score_pp, _ = self._align_motifs(m1_flat, m2_rand_flat)
        score_pm, _ = self._align_motifs(m1_flat, m2_rand_rc_flat)

        return max(score_pp, score_pm)

    def compare(
        self,
        motif_1: MotifModel,
        motif_2: MotifModel,
        sequences: RaggedData | None = None,
    ) -> dict:
        """
        Compare two motif models with optional p-value calculation.

        Parameters
        ----------
        motif_1 : MotifModel
            First motif model to compare.
        motif_2 : MotifModel
            Second motif model to compare.
        sequences : RaggedData or None
            Sequences for comparison (required if pfm_mode is True).

        Returns
        -------
        dict
            Dictionary with keys 'query', 'target', 'score', 'offset',
            'orientation', 'metric' and, when n_permutations > 0, also
            'p-value', 'z-score', 'null_mean', 'null_std'.

        Raises
        ------
        ValueError
            If pfm_mode is enabled but no sequences are provided.
        """

        if self.pfm_mode:
            if sequences is None:
                raise ValueError("sequences are required for pfm_mode")
            m1_flat, _ = self._prepare_matrix(pfm_to_pwm(motif_1.get_pfm(sequences, top_fraction=0.1)))
            (
                m2_flat,
                m2_rc_flat,
            ) = self._prepare_matrix(pfm_to_pwm(motif_2.get_pfm(sequences, top_fraction=0.1)))
        else:
            m1_flat, _ = self._prepare_matrix(motif_1.matrix)
            m2_flat, m2_rc_flat = self._prepare_matrix(motif_2.matrix)

        # --- Observed Score ---
        obs_score_pp, obs_off_pp = self._align_motifs(m1_flat, m2_flat)
        obs_score_pm, obs_off_pm = self._align_motifs(m1_flat, m2_rc_flat)

        if obs_score_pm > obs_score_pp:
            obs_score = obs_score_pm
            obs_offset = obs_off_pm
            orientation = "+-"
        else:
            obs_score = obs_score_pp
            obs_offset = obs_off_pp
            orientation = "++"

        result = {
            "query": motif_1.name,
            "target": motif_2.name,
            "score": float(obs_score),
            "offset": int(obs_offset),
            "orientation": orientation,
            "metric": self.metric,
        }

        # --- Monte Carlo Permutations ---
        if self.n_permutations > 0:
            base_rng = np.random.default_rng(self.seed)
            seeds = base_rng.integers(0, 2**31, size=self.n_permutations)

            # NOTE(review): the observed score uses get_pfm(sequences, top_fraction=0.1)
            # while the null model uses motif_2.pfm — confirm this asymmetry is intended.
            m2 = pfm_to_pwm(motif_2.pfm) if self.pfm_mode else motif_2.matrix
            # Run permutations in parallel
            null_scores = Parallel(n_jobs=self.n_jobs, backend="loky")(
                delayed(self._run_single_permutation)(m1_flat, m2, int(seeds[i])) for i in range(self.n_permutations)
            )

            null_scores = np.array(null_scores)

            # --- P-value calculation ---
            mean_null = np.mean(null_scores)
            std_null = np.std(null_scores)
            z_score = (obs_score - mean_null) / (std_null + 1e-9)

            # Empirical p-value with +1 pseudo-count (Phipson & Smyth correction)
            n_ge = int(np.sum(null_scores >= obs_score))
            p_value = (n_ge + 1.0) / (self.n_permutations + 1.0)

            result.update(
                {
                    "p-value": float(p_value),
                    "z-score": float(z_score),
                    "null_mean": float(mean_null),
                    "null_std": float(std_null),
                }
            )

        return result
467
+
468
+
469
class DataComparator:
    """
    Comparator implementation using Jaccard or Overlap metrics
    with permutation-based statistics, working directly with RaggedData objects.
    """

    def __init__(
        self,
        name: str = "DataComparator",
        metric: str = "cj",
        n_permutations: int = 1000,
        distortion_level: float = 0.4,
        n_jobs: int = -1,
        seed: Optional[int] = None,
        search_range: int = 10,
        min_kernel_size: int = 3,
        max_kernel_size: int = 11,
    ) -> None:
        """
        Initialize the comparator.

        Parameters
        ----------
        name : str
            Name of the comparator instance.
        metric : str
            Similarity metric: 'cj' (Continuous Jaccard), 'co' (Continuous
            Overlap), or 'corr' (Pearson correlation).
        n_permutations : int
            Number of permutations for significance testing. Set to 0 to disable.
        distortion_level : float
            Level of distortion for surrogate generation (0.0 to 1.0).
        n_jobs : int
            Number of parallel jobs for permutations. -1 to use all cores.
        seed : int, optional
            Random seed for reproducibility.
        search_range : int
            Range to search for the optimal offset alignment.
        min_kernel_size : int
            Minimum size of the surrogate convolution kernel.
        max_kernel_size : int
            Maximum size of the surrogate convolution kernel.
        """
        # NOTE(review): unlike TomtomComparator, the metric is not validated
        # here; an unknown value only fails later in _compute_metric_internal.
        self.name = name
        self.metric = metric
        self.n_permutations = n_permutations
        self.distortion_level = distortion_level
        self.n_jobs = n_jobs
        self.seed = seed
        self.search_range = search_range
        self.min_kernel_size = min_kernel_size
        self.max_kernel_size = max_kernel_size

    @staticmethod
    def _compute_metric_internal(S1: RaggedData, S2: RaggedData, search_range: int, metric: str):
        """Internal dispatcher for metric computation kernels."""
        if metric == "cj":
            # Returns (best_cj, best_offset)
            return _fast_cj_kernel_numba(S1.data, S1.offsets, S2.data, S2.offsets, search_range)
        elif metric == "co":
            # Returns (best_co, best_offset)
            return _fast_overlap_kernel_numba(S1.data, S1.offsets, S2.data, S2.offsets, search_range)
        elif metric == "corr":
            # Returns (correlation, p-value, best_offset)
            return _fast_pearson_kernel(S1.data, S1.offsets, S2.data, S2.offsets, search_range)
        else:
            raise ValueError(f"Unknown metric: {metric}")

    def _single_compare(self, profile1: RaggedData, profile2: RaggedData):
        """
        Perform a single comparison between two profiles.

        Returns a dict with 'score', 'offset', 'orientation', 'metric' and,
        when n_permutations > 0, permutation statistics as well.
        """
        # Observed score; raw profiles carry no strand, so orientation is "."
        obs_res = self._compute_metric_internal(profile1, profile2, self.search_range, self.metric)

        orientation = "."
        obs_score = float(obs_res[0])
        obs_offset = int(obs_res[-1])  # Offset is always the last element

        result = {"score": obs_score, "offset": obs_offset, "orientation": orientation, "metric": self.metric}

        if self.n_permutations > 0:
            base_rng = np.random.default_rng(self.seed)
            seeds = base_rng.integers(0, 2**31, size=self.n_permutations)

            # Surrogate-based null distribution: the same convolution-distortion
            # surrogate logic is used for 'cj', 'co', and 'corr'.
            results = Parallel(n_jobs=self.n_jobs, backend="loky")(
                delayed(self._compute_surrogate_score)(
                    profile1,
                    profile2,
                    np.random.default_rng(int(seeds[i])),
                )
                for i in range(self.n_permutations)
            )

            null_scores = np.array([r[0] for r in results if r is not None], dtype=np.float32)
            null_mean = float(np.mean(null_scores))
            null_std = float(np.std(null_scores))

            z_score = (obs_score - null_mean) / (null_std + 1e-9)
            # Empirical p-value with +1 pseudo-count
            n_ge = int(np.sum(null_scores >= obs_score))
            p_value = (n_ge + 1.0) / (self.n_permutations + 1.0)

            result.update(
                {
                    "p-value": p_value,
                    "z-score": z_score,
                    "null_mean": null_mean,
                    "null_std": null_std,
                }
            )

        return result

    def _compute_surrogate_score(self, freq1: RaggedData, freq2: RaggedData, rng: np.random.Generator):
        """Helper for parallel permutation execution: score freq1 vs a surrogate of freq2."""
        # NOTE(review): this surrogate logic is duplicated in
        # UniversalMotifComparator — keep the two implementations in sync.
        surrogate = self._generate_single_surrogate(
            freq2,
            rng,
            min_kernel_size=self.min_kernel_size,
            max_kernel_size=self.max_kernel_size,
            distortion_level=self.distortion_level,
        )
        return self._compute_metric_internal(freq1, surrogate, self.search_range, self.metric)

    @staticmethod
    def _generate_single_surrogate(
        frequencies: RaggedData,
        rng: np.random.Generator,
        min_kernel_size: int = 3,
        max_kernel_size: int = 11,
        distortion_level: float = 1.0,
    ) -> RaggedData:
        """
        Generate a single surrogate frequency profile using convolution with a distorted kernel.

        This function implements a sophisticated surrogate generation algorithm that creates
        distorted versions of the input frequency profiles. The "distortion" logic refers to
        how the identity kernel is systematically modified through several techniques:

        1. Base kernel selection (smooth, edge, double_peak patterns)
        2. Noise addition with controlled amplitude
        3. Gradient application to introduce directional bias
        4. Smoothing to reduce artifacts
        5. Convex combination with identity kernel based on distortion level
        6. Sign flipping for additional variation

        Parameters
        ----------
        frequencies : RaggedData
            Input frequency profile to generate surrogate from.
        rng : np.random.Generator
            Random number generator instance.
        min_kernel_size : int, optional
            Minimum size of the convolution kernel (default is 3).
        max_kernel_size : int, optional
            Maximum size of the convolution kernel (default is 11).
        distortion_level : float, optional
            Level of distortion to apply (0.0 to 1.0, default is 1.0).

        Returns
        -------
        RaggedData
            Surrogate frequency profile generated from the input.
        """
        # For simplicity in surrogate generation, we use dense adapter
        dense_adapter = RaggedScores.from_numba(frequencies)
        X = dense_adapter.values
        lengths = dense_adapter.lengths

        # Odd kernel size so the kernel has a well-defined center
        kernel_size = int(rng.integers(min_kernel_size, max_kernel_size + 1))
        if kernel_size % 2 == 0:
            kernel_size += 1
        center = kernel_size // 2

        kernel_types = ["smooth", "edge", "double_peak"]
        kernel_type = str(rng.choice(kernel_types))

        identity_kernel = np.zeros(kernel_size, dtype=np.float32)
        identity_kernel[center] = 1.0

        if kernel_type == "smooth":
            # Gaussian-shaped low-pass kernel
            x = np.linspace(-3, 3, kernel_size)
            base = np.exp(-0.5 * x**2).astype(np.float32)
        elif kernel_type == "edge":
            # Antisymmetric edge-detector around the center
            base = np.zeros(kernel_size, dtype=np.float32)
            base[max(center - 1, 0)] = -1.0
            base[min(center + 1, kernel_size - 1)] = 1.0
        elif kernel_type == "double_peak":
            # Two positive flanks with a negative center
            base = np.zeros(kernel_size, dtype=np.float32)
            base[0] = 0.5
            base[-1] = 0.5
            base[center] = -1.0
        else:
            # Defensive fallback; kernel_type is drawn from kernel_types above
            base = identity_kernel.copy()

        noise = rng.normal(0, 1, size=kernel_size).astype(np.float32)
        slope = float(rng.uniform(-1.0, 1.0)) * distortion_level * 2.0
        gradient = np.linspace(-slope, slope, kernel_size).astype(np.float32)

        distorted_kernel = base + distortion_level * noise + gradient

        # Light smoothing to reduce high-frequency artifacts
        if kernel_size >= 3:
            smooth_filter = np.array([0.25, 0.5, 0.25], dtype=np.float32)
            distorted_kernel = np.convolve(distorted_kernel, smooth_filter, mode="same")

        distorted_kernel /= np.linalg.norm(distorted_kernel) + 1e-8

        # Convex combination with the identity: alpha=0 keeps the input intact
        alpha = max(0.0, min(1.0, distortion_level))
        final_kernel = (1.0 - alpha) * identity_kernel + alpha * distorted_kernel
        if rng.uniform() < 0.5:
            final_kernel = -final_kernel
        final_kernel /= np.linalg.norm(final_kernel) + 1e-8

        convolved = convolve1d(X, final_kernel, axis=1, mode="constant", cval=0.0).astype(np.float32)

        # Convert back to RaggedData, trimming each row to its true length
        convolved_list = [convolved[i, : lengths[i]] for i in range(len(lengths))]
        convolved_ragged = ragged_from_list(convolved_list, dtype=np.float32)

        return scores_to_frequencies(convolved_ragged)

    def compare(
        self,
        profile1: RaggedData,
        profile2: RaggedData,
    ) -> dict | None:
        """
        Compare two RaggedData objects directly.

        Parameters
        ----------
        profile1 : RaggedData
            First profile to compare.
        profile2 : RaggedData
            Second profile to compare.

        Returns
        -------
        dict or None
            Comparison results under generic identifiers 'Data1'/'Data2',
            including score, offset, and permutation statistics when enabled.
        """

        # Calculate comparison metrics
        out = self._single_compare(
            profile1=profile1,
            profile2=profile2,
        )

        # Create generic identifiers for the data
        result = {"query": "Data1", "target": "Data2"}
        result.update(out)

        return result
687
+
688
+
689
+ class UniversalMotifComparator(GeneralMotifComparator):
690
+ """
691
+ Universal comparator implementation that integrates functionality from both
692
+ MotifComparator and CorrelationComparator. Supports Jaccard ('cj'),
693
+ Overlap ('co'), and Pearson Correlation ('corr') metrics with optional
694
+ permutation-based statistics.
695
+ """
696
+
697
+ def __init__(
698
+ self,
699
+ name: str = "UniversalMotifComparator",
700
+ metric: str = "cj",
701
+ n_permutations: int = 1000,
702
+ distortion_level: float = 0.4,
703
+ n_jobs: int = -1,
704
+ seed: Optional[int] = None,
705
+ min_kernel_size: int = 3,
706
+ max_kernel_size: int = 11,
707
+ search_range: int = 10,
708
+ ) -> None:
709
+ """
710
+ Initialize the unified comparator.
711
+
712
+ Parameters
713
+ ----------
714
+ name : str
715
+ Name of the comparator instance.
716
+ metric : str
717
+ Similarity metric to use: 'cj' (Continuous Jaccard),
718
+ 'co' (Continuous Overlap), or 'corr' (Pearson Correlation).
719
+ n_permutations : int
720
+ Number of permutations for statistical significance testing.
721
+ distortion_level : float
722
+ Level of distortion for surrogate generation (used for 'cj' and 'co').
723
+ n_jobs : int
724
+ Number of parallel jobs for permutations.
725
+ seed : int, optional
726
+ Random seed for reproducibility.
727
+ search_range : int
728
+ Range to search for optimal offset alignment.
729
+ """
730
+ super().__init__(name)
731
+ self.metric = metric.lower()
732
+ if self.metric not in ["cj", "co", "corr"]:
733
+ raise ValueError(f"Unsupported metric: {metric}. Use 'cj', 'co', or 'corr'.")
734
+
735
+ self.n_permutations = n_permutations
736
+ self.distortion_level = distortion_level
737
+ self.n_jobs = n_jobs
738
+ self.seed = seed
739
+ self.min_kernel_size = min_kernel_size
740
+ self.max_kernel_size = max_kernel_size
741
+ self.search_range = search_range
742
+
743
+ @staticmethod
744
+ def _compute_metric_internal(S1: RaggedData, S2: RaggedData, search_range: int, metric: str):
745
+ """Internal dispatcher for metric computation kernels."""
746
+ if metric == "cj":
747
+ return _fast_cj_kernel_numba(S1.data, S1.offsets, S2.data, S2.offsets, search_range)
748
+ elif metric == "co":
749
+ return _fast_overlap_kernel_numba(S1.data, S1.offsets, S2.data, S2.offsets, search_range)
750
+ elif metric == "corr":
751
+ # Returns (correlation, p-value, offset)
752
+ return _fast_pearson_kernel(S1.data, S1.offsets, S2.data, S2.offsets, search_range)
753
+ else:
754
+ raise ValueError(f"Unknown metric: {metric}")
755
+
756
+ def _single_compare(self, motif1: MotifModel, motif2: MotifModel, sequences: RaggedData):
757
+ """Perform a single comparison between two motifs."""
758
+ freq1_plus = motif1.get_frequencies(sequences, strand="+")
759
+ freq2_plus = motif2.get_frequencies(sequences, strand="+")
760
+ freq2_minus = motif2.get_frequencies(sequences, strand="-")
761
+
762
+ # Observed scores for both orientations
763
+ res_pp = self._compute_metric_internal(freq1_plus, freq2_plus, self.search_range, self.metric)
764
+ res_pm = self._compute_metric_internal(freq1_plus, freq2_minus, self.search_range, self.metric)
765
+
766
+ # Extract scores for comparison (first element of return tuple for all kernels)
767
+ score_pp = res_pp[0]
768
+ score_pm = res_pm[0]
769
+
770
+ if score_pm > score_pp:
771
+ orientation = "+-"
772
+ obs_res = res_pm
773
+ freq1, freq2 = freq1_plus, freq2_minus
774
+ else:
775
+ orientation = "++"
776
+ obs_res = res_pp
777
+ freq1, freq2 = freq1_plus, freq2_plus
778
+
779
+ obs_score = float(obs_res[0])
780
+ obs_offset = int(obs_res[-1]) # Offset is always the last element
781
+
782
+ result = {"score": obs_score, "offset": obs_offset, "orientation": orientation, "metric": self.metric}
783
+
784
+ if self.n_permutations > 0:
785
+ base_rng = np.random.default_rng(self.seed)
786
+ seeds = base_rng.integers(0, 2**31, size=self.n_permutations)
787
+
788
+ # Use MotifComparator's surrogate logic for 'cj' and 'co'
789
+ # For 'corr', we use the same surrogate logic if permutations are requested
790
+ results = Parallel(n_jobs=self.n_jobs, backend="loky")(
791
+ delayed(self._compute_surrogate_score)(
792
+ freq1,
793
+ freq2,
794
+ np.random.default_rng(int(seeds[i])),
795
+ )
796
+ for i in range(self.n_permutations)
797
+ )
798
+
799
+ null_scores = np.array([r[0] for r in results if r is not None], dtype=np.float32)
800
+ null_mean = float(np.mean(null_scores))
801
+ null_std = float(np.std(null_scores))
802
+
803
+ z_score = (obs_score - null_mean) / (null_std + 1e-9)
804
+ n_ge = int(np.sum(null_scores >= obs_score))
805
+ p_value = (n_ge + 1.0) / (self.n_permutations + 1.0)
806
+ # p_value = stats.norm.sf(abs(z_score))
807
+
808
+ result.update(
809
+ {
810
+ "p-value": p_value,
811
+ "z-score": z_score,
812
+ "null_mean": null_mean,
813
+ "null_std": null_std,
814
+ }
815
+ )
816
+
817
+ return result
818
+
819
+ def _compute_surrogate_score(self, freq1: RaggedData, freq2: RaggedData, rng: np.random.Generator):
820
+ """Helper for parallel permutation execution."""
821
+ # Reusing the static method from MotifComparator as requested by "preserving algorithmic nuances"
822
+ surrogate = self._generate_single_surrogate(
823
+ freq2,
824
+ rng,
825
+ min_kernel_size=self.min_kernel_size,
826
+ max_kernel_size=self.max_kernel_size,
827
+ distortion_level=self.distortion_level,
828
+ )
829
+ return self._compute_metric_internal(freq1, surrogate, self.search_range, self.metric)
830
+
831
+ @staticmethod
832
+ def _generate_single_surrogate(
833
+ frequencies: RaggedData,
834
+ rng: np.random.Generator,
835
+ min_kernel_size: int = 3,
836
+ max_kernel_size: int = 11,
837
+ distortion_level: float = 1.0,
838
+ ) -> RaggedData:
839
+ """
840
+ Generate a single surrogate frequency profile using convolution with a distorted kernel.
841
+
842
+ This function implements a sophisticated surrogate generation algorithm that creates
843
+ distorted versions of the input frequency profiles. The "distortion" logic refers to
844
+ how the identity kernel is systematically modified through several techniques:
845
+
846
+ 1. Base kernel selection (smooth, edge, double_peak patterns)
847
+ 2. Noise addition with controlled amplitude
848
+ 3. Gradient application to introduce directional bias
849
+ 4. Smoothing to reduce artifacts
850
+ 5. Convex combination with identity kernel based on distortion level
851
+ 6. Sign flipping for additional variation
852
+
853
+ Parameters
854
+ ----------
855
+ frequencies : RaggedData
856
+ Input frequency profile to generate surrogate from.
857
+ rng : np.random.Generator
858
+ Random number generator instance.
859
+ min_kernel_size : int, optional
860
+ Minimum size of the convolution kernel (default is 3).
861
+ max_kernel_size : int, optional
862
+ Maximum size of the convolution kernel (default is 11).
863
+ distortion_level : float, optional
864
+ Level of distortion to apply (0.0 to 1.0, default is 1.0).
865
+
866
+ Returns
867
+ -------
868
+ RaggedData
869
+ Surrogate frequency profile generated from the input.
870
+ """
871
+ # For simplicity in surrogate generation, we use dense adapter
872
+ dense_adapter = RaggedScores.from_numba(frequencies)
873
+ X = dense_adapter.values
874
+ lengths = dense_adapter.lengths
875
+
876
+ kernel_size = int(rng.integers(min_kernel_size, max_kernel_size + 1))
877
+ if kernel_size % 2 == 0:
878
+ kernel_size += 1
879
+ center = kernel_size // 2
880
+
881
+ kernel_types = ["smooth", "edge", "double_peak"]
882
+ kernel_type = str(rng.choice(kernel_types))
883
+
884
+ identity_kernel = np.zeros(kernel_size, dtype=np.float32)
885
+ identity_kernel[center] = 1.0
886
+
887
+ if kernel_type == "smooth":
888
+ x = np.linspace(-3, 3, kernel_size)
889
+ base = np.exp(-0.5 * x**2).astype(np.float32)
890
+ elif kernel_type == "edge":
891
+ base = np.zeros(kernel_size, dtype=np.float32)
892
+ base[max(center - 1, 0)] = -1.0
893
+ base[min(center + 1, kernel_size - 1)] = 1.0
894
+ elif kernel_type == "double_peak":
895
+ base = np.zeros(kernel_size, dtype=np.float32)
896
+ base[0] = 0.5
897
+ base[-1] = 0.5
898
+ base[center] = -1.0
899
+ else:
900
+ base = identity_kernel.copy()
901
+
902
+ noise = rng.normal(0, 1, size=kernel_size).astype(np.float32)
903
+ slope = float(rng.uniform(-1.0, 1.0)) * distortion_level * 2.0
904
+ gradient = np.linspace(-slope, slope, kernel_size).astype(np.float32)
905
+
906
+ distorted_kernel = base + distortion_level * noise + gradient
907
+
908
+ if kernel_size >= 3:
909
+ smooth_filter = np.array([0.25, 0.5, 0.25], dtype=np.float32)
910
+ distorted_kernel = np.convolve(distorted_kernel, smooth_filter, mode="same")
911
+
912
+ distorted_kernel /= np.linalg.norm(distorted_kernel) + 1e-8
913
+
914
+ alpha = max(0.0, min(1.0, distortion_level))
915
+ final_kernel = (1.0 - alpha) * identity_kernel + alpha * distorted_kernel
916
+ if rng.uniform() < 0.5:
917
+ final_kernel = -final_kernel
918
+ final_kernel /= np.linalg.norm(final_kernel) + 1e-8
919
+
920
+ convolved = convolve1d(X, final_kernel, axis=1, mode="constant", cval=0.0).astype(np.float32)
921
+
922
+ # Convert back to RaggedData
923
+ convolved_list = [convolved[i, : lengths[i]] for i in range(len(lengths))]
924
+ convolved_ragged = ragged_from_list(convolved_list, dtype=np.float32)
925
+
926
+ return scores_to_frequencies(convolved_ragged)
927
+
928
+ def compare(self, motif_1: MotifModel, motif_2: MotifModel, sequences: RaggedData | None = None) -> dict:
929
+ """
930
+ Compare two motif models pairwise.
931
+
932
+ Parameters
933
+ ----------
934
+ motif_1 : MotifModel
935
+ First motif to compare.
936
+ motif_2 : MotifModel
937
+ Second motif to compare.
938
+ sequences : RaggedData or None
939
+ Sequences for frequency calculation.
940
+
941
+ Returns
942
+ -------
943
+ dict
944
+ Dictionary containing comparison results with statistical information.
945
+ """
946
+ if sequences is None:
947
+ raise ValueError("Sequences list is required for this comparator.")
948
+
949
+ # Calculate comparison metrics
950
+ out = self._single_compare(
951
+ motif1=motif_1,
952
+ motif2=motif_2,
953
+ sequences=sequences,
954
+ )
955
+
956
+ # Merge identification info with results
957
+ result = {"query": motif_1.name, "target": motif_2.name}
958
+ result.update(out)
959
+
960
+ return result
961
+
962
+
963
+ class MotaliComparator(GeneralMotifComparator):
964
+ """Comparator that wraps the Motali program.
965
+
966
+ This comparator uses an external Motali program to compute similarity
967
+ between Position Frequency Matrices (PFMs).
968
+ """
969
+
970
+ def __init__(self, fasta_path: str, threshold: float = 0.95, tmp_directory: str = ".") -> None:
971
+ """
972
+ Initialize the MotaliComparator.
973
+
974
+ Parameters
975
+ ----------
976
+ fasta_path : str
977
+ Path to the FASTA file containing sequences for comparison.
978
+ threshold : float, optional
979
+ Minimum score threshold for filtering results (default is 0.95).
980
+ tmp_directory : str, optional
981
+ Directory for temporary files (default is '.', the current working directory).
982
+ """
983
+ super().__init__(name="motali")
984
+ self.threshold = threshold
985
+ self.tmp_directory = tmp_directory
986
+ self.fasta_path = fasta_path
987
+
988
+ def compare(self, motif_1: MotifModel, motif_2: MotifModel, sequences: RaggedData | None = None) -> dict:
989
+ """
990
+ Compare two motif models using the Motali program.
991
+
992
+ Parameters
993
+ ----------
994
+ motif_1 : MotifModel
995
+ First motif to compare.
996
+ motif_2 : MotifModel
997
+ Second motif to compare.
998
+ sequences : RaggedData or None
999
+ Sequences for comparison (not used in this implementation).
1000
+
1001
+ Returns
1002
+ -------
1003
+ dict or None
1004
+ Dictionary containing comparison results with columns:
1005
+ - query: name of the first motif
1006
+ - target: name of the second motif
1007
+ - score: similarity score computed by Motali
1008
+ Returns None if the score is below the threshold.
1009
+ """
1010
+ with tempfile.TemporaryDirectory(dir=self.tmp_directory, ignore_cleanup_errors=True) as tmp:
1011
+ # Determine file extensions based on model types
1012
+ type_1 = motif_1.model_type
1013
+ type_2 = motif_2.model_type
1014
+
1015
+ if type_1 == "sitega":
1016
+ type_1 = "sga"
1017
+
1018
+ if type_2 == "sitega":
1019
+ type_2 = "sga"
1020
+
1021
+ # Set file extensions based on model types
1022
+ ext_1 = ".pfm" if type_1 == "pwm" else ".mat"
1023
+ ext_2 = ".pfm" if type_2 == "pwm" else ".mat"
1024
+
1025
+ m1_path = os.path.join(tmp, f"motif_1{ext_1}")
1026
+ m2_path = os.path.join(tmp, f"motif_2{ext_2}")
1027
+
1028
+ d1_path = os.path.join(tmp, "thresholds_1.dist")
1029
+ d2_path = os.path.join(tmp, "thresholds_2.dist")
1030
+
1031
+ overlap_path = os.path.join(tmp, "overlap.txt")
1032
+ all_path = os.path.join(tmp, "all.txt")
1033
+ sta_path = os.path.join(tmp, "sta.txt")
1034
+ prc_path = os.path.join(tmp, "prc_pass.txt")
1035
+ hist_path = os.path.join(tmp, "hist_pass.txt")
1036
+
1037
+ # Write motifs using polymorphic write method
1038
+ motif_1.write(m1_path)
1039
+ motif_2.write(m2_path)
1040
+
1041
+ # Write distance thresholds
1042
+ motif_1.write_dist(d1_path)
1043
+ motif_2.write_dist(d2_path)
1044
+
1045
+ fasta_path = self.fasta_path
1046
+ if fasta_path is None and sequences is not None:
1047
+ fasta_path = os.path.join(tmp, "sequences.fa")
1048
+ write_fasta(sequences, fasta_path)
1049
+ score = run_motali(
1050
+ fasta_path,
1051
+ m1_path,
1052
+ m2_path,
1053
+ type_1,
1054
+ type_2,
1055
+ d1_path,
1056
+ d2_path,
1057
+ overlap_path,
1058
+ all_path,
1059
+ prc_path,
1060
+ hist_path,
1061
+ sta_path,
1062
+ )
1063
+
1064
+ result = {"query": motif_1.name, "target": motif_2.name, "score": score}
1065
+
1066
+ return result