yamcot-1.0.0-cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
yamcot/functions.py ADDED
@@ -0,0 +1,787 @@
import numpy as np
from numba import njit, prange
from scipy.stats import pearsonr

from yamcot.ragged import RaggedData

RC_TABLE = np.array([3, 2, 1, 0, 4], dtype=np.int8)
BACKGROUND_FREQ = 0.25  # Background frequency for PWM calculation
PFM_TO_PWM_PSEUDOCOUNT = 0.0001  # Pseudocount added to PFM values
PCM_TO_PFM_NUCLEOTIDE_PSEUDOCOUNT = 0.25  # Pseudocount for nucleotide frequency
PCM_TO_PFM_DENOMINATOR_CONSTANT = 1  # Constant added to denominator in PCM to PFM conversion
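
# Illustrative sketch (not part of the package; assumes the integer encoding
# A=0, C=1, G=2, T=3, N=4, as implied by the kernels' 'N' padding value 4):
# indexing RC_TABLE complements a base while leaving N fixed.
#   >>> seq = np.array([0, 1, 2, 3, 4], dtype=np.int8)  # "ACGTN"
#   >>> RC_TABLE[seq[::-1]]                             # reverse complement
#   array([4, 0, 1, 2, 3], dtype=int8)                  # "NACGT"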


def pfm_to_pwm(pfm):
    """
    Convert Position Frequency Matrix to Position Weight Matrix.

    Parameters
    ----------
    pfm : np.ndarray
        Position Frequency Matrix of shape (4, L) where L is the motif length.

    Returns
    -------
    np.ndarray
        Position Weight Matrix computed as log((PFM + pseudo_count) / background).
    """
    background = BACKGROUND_FREQ
    pwm = np.log((pfm + PFM_TO_PWM_PSEUDOCOUNT) / background)
    return pwm


def pcm_to_pfm(pcm):
    """
    Convert Position Count Matrix to Position Frequency Matrix.

    Parameters
    ----------
    pcm : np.ndarray
        Position Count Matrix of shape (4, L) where L is the motif length.

    Returns
    -------
    np.ndarray
        Position Frequency Matrix with pseudo-counts added.
    """
    number_of_sites = pcm.sum(axis=0)
    nuc_pseudo = PCM_TO_PFM_NUCLEOTIDE_PSEUDOCOUNT
    pfm = (pcm + nuc_pseudo) / (number_of_sites + PCM_TO_PFM_DENOMINATOR_CONSTANT)
    return pfm
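
# Usage sketch (illustrative, not part of the package): a 4-row count matrix
# (A, C, G, T) is normalized to frequencies, then log-scored against background.
#   >>> pcm = np.array([[8, 0], [0, 9], [1, 0], [0, 0]], dtype=float)
#   >>> pfm = pcm_to_pfm(pcm)  # (pcm + 0.25) / (9 + 1); each column sums to 1
#   >>> pwm = pfm_to_pwm(pfm)  # log(freq / 0.25); positive where a base is enriched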


@njit
def score_seq(num_site, kmer, model):
    """
    Compute score for a sequence site using a k-mer model.

    Parameters
    ----------
    num_site : np.ndarray
        Numerical representation of the DNA sequence site.
    kmer : int
        Length of the k-mer used for indexing.
    model : np.ndarray
        Scoring model matrix.

    Returns
    -------
    float
        Computed score for the sequence site.
    """
    score = 0.0
    seq_len = num_site.shape[0]
    for i in range(seq_len - kmer + 1):
        score_idx = 0
        for j in range(kmer):
            score_idx = score_idx * 5 + num_site[i + j]  # Convert k-mer to single base-5 index
        score += model.flat[score_idx * model.shape[-1] + i]  # Access via flat index

    return score
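
# Indexing sketch (illustrative): with the base-5 alphabet, each k-mer is packed
# into a single row index, so for a 2-D model of shape (5**kmer, L) the flat
# access above is equivalent to model[score_idx, i]. For kmer=1 this reduces to
# plain PWM scoring:
#   >>> model = np.zeros((5, 3), dtype=np.float32)  # kmer=1, motif length 3
#   >>> model[2, 1] = 1.5                           # reward G at position 1
#   >>> score_seq(np.array([0, 2, 3], dtype=np.int8), 1, model)  # "AGT"
#   1.5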


@njit(inline="always")
def _fill_rc_buffer(data, start, length, buffer):
    """
    Fill the buffer with the reverse complement, without allocating.
    """
    for j in range(length):
        val = data[start + length - 1 - j]
        buffer[j] = RC_TABLE[val]


@njit(parallel=True, fastmath=True, cache=True)
def _batch_all_scores_jit(data, offsets, matrix, kmer, is_revcomp):
    """
    Specialized kernel for PWM scoring (no BaMM logic).
    Maximizes speed by avoiding extra checks and buffers (on the forward strand).
    """
    n_seq = len(offsets) - 1
    m = matrix.shape[-1]

    # 1. Compute the new offsets
    new_offsets = np.zeros(n_seq + 1, dtype=np.int64)
    for i in range(n_seq):
        seq_len = offsets[i + 1] - offsets[i]
        if seq_len >= m:
            new_offsets[i + 1] = seq_len - m + 1

    for i in range(n_seq):
        new_offsets[i + 1] += new_offsets[i]

    total_scores = new_offsets[n_seq]
    results = np.zeros(total_scores, dtype=np.float32)

    # 2. Main loop
    for i in prange(n_seq):
        start = offsets[i]
        out_start = new_offsets[i]
        n_scores = new_offsets[i + 1] - out_start

        if n_scores > 0:
            site_buffer = np.empty(m, dtype=data.dtype)

            for k in range(n_scores):
                if not is_revcomp:
                    # Zero-copy view for the forward strand
                    num_site = data[start + k : start + k + m]
                    results[out_start + k] = score_seq(num_site, kmer, matrix)
                else:
                    # Fill the buffer for the reverse strand
                    _fill_rc_buffer(data, start + k, m, site_buffer)
                    results[out_start + k] = score_seq(site_buffer, kmer, matrix)

    return results, new_offsets


@njit(parallel=True, fastmath=True, cache=True)
def _batch_all_scores_with_context_jit(data, offsets, matrix, kmer, is_revcomp):
    """
    Specialized kernel for BaMM scoring.
    Handles the (kmer - 1) context and 'N' padding.
    """
    n_seq = len(offsets) - 1
    m = matrix.shape[-1]
    context_len = kmer - 1
    window_size = m + context_len

    # Offset computation is the same as for the PWM kernel
    new_offsets = np.zeros(n_seq + 1, dtype=np.int64)
    for i in range(n_seq):
        seq_len = offsets[i + 1] - offsets[i]
        if seq_len >= m:
            new_offsets[i + 1] = seq_len - m + 1

    for i in range(n_seq):
        new_offsets[i + 1] += new_offsets[i]

    total_scores = new_offsets[n_seq]
    results = np.zeros(total_scores, dtype=np.float32)

    for i in prange(n_seq):
        start = offsets[i]
        seq_len = offsets[i + 1] - start
        out_start = new_offsets[i]
        n_scores = new_offsets[i + 1] - out_start

        # A buffer is mandatory for BaMM because of the padding
        site_buffer = np.full(window_size, 4, dtype=data.dtype)

        if n_scores > 0:
            for k in range(n_scores):
                # Reset the buffer to 'N' (4)
                site_buffer[:] = 4

                if not is_revcomp:
                    # Forward: copy while respecting sequence boundaries
                    s_idx = k - context_len
                    e_idx = k + m

                    actual_start = max(0, s_idx)
                    actual_end = min(seq_len, e_idx)
                    dest_start = max(0, -s_idx)

                    copy_len = actual_end - actual_start
                    if copy_len > 0:
                        site_buffer[dest_start : dest_start + copy_len] = data[
                            start + actual_start : start + actual_end
                        ]
                else:
                    # Reverse: reverse-complement logic with padding
                    r_start = k
                    # Fill the buffer with RC values
                    for t in range(window_size):
                        data_idx = start + r_start + (window_size - 1 - t)
                        if start <= data_idx < start + seq_len:
                            site_buffer[t] = RC_TABLE[data[data_idx]]

                results[out_start + k] = score_seq(site_buffer, kmer, matrix)

    return results, new_offsets


def batch_all_scores(
    sequences: RaggedData, matrix: np.ndarray, kmer: int = 1, is_revcomp: bool = False, with_context: bool = False
) -> RaggedData:
    """
    Compute scores for all sequences in RaggedData.
    Supports both PWM (with_context=False) and BaMM models.

    Parameters
    ----------
    sequences : RaggedData
        Input sequences in RaggedData format.
    matrix : np.ndarray
        Scoring matrix for motif evaluation.
    kmer : int, optional
        K-mer length parameter for scoring (default is 1).
    is_revcomp : bool, optional
        Whether to score the reverse complement strand (default is False).
    with_context : bool, optional
        Whether to extend each site with (kmer - 1) context positions,
        giving windows of length (kmer - 1) + site length (default is False).

    Returns
    -------
    RaggedData
        RaggedData object containing computed scores.
    """
    if with_context:
        data, offsets = _batch_all_scores_with_context_jit(sequences.data, sequences.offsets, matrix, kmer, is_revcomp)
    else:
        data, offsets = _batch_all_scores_jit(sequences.data, sequences.offsets, matrix, kmer, is_revcomp)
    return RaggedData(data, offsets)
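
# Usage sketch (illustrative; the RaggedData layout is inferred from this module:
# a flat data array plus offsets delimiting each sequence):
#   >>> seqs = RaggedData(np.array([0, 1, 2, 3, 0, 2], dtype=np.int8),
#   ...                   np.array([0, 4, 6], dtype=np.int64))  # "ACGT", "AG"
#   >>> pwm = np.zeros((5, 2), dtype=np.float32)                # motif length 2
#   >>> batch_all_scores(seqs, pwm, kmer=1).offsets  # 3 windows, then 1 window
#   array([0, 3, 4])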


@njit
def precision_recall_curve(classification, scores):
    """Compute precision-recall curve (JIT-compiled)."""
    n = len(scores)
    if n == 0:
        return np.array([1.0]), np.array([0.0]), np.array([np.inf])

    # Get indices for sorting scores in descending order
    indexes = np.argsort(scores)[::-1]
    sorted_scores = scores[indexes]
    sorted_classification = classification[indexes]

    # Initialize arrays (with +1 buffer for initial point)
    max_size = n + 1

    precision = np.zeros(max_size)
    recall = np.zeros(max_size)
    uniq_scores = np.zeros(max_size)

    # Initial point: (recall=0, precision=1, threshold=inf)
    precision[0] = 1.0
    recall[0] = 0.0
    uniq_scores[0] = np.inf

    TP, FP = 0, 0
    number_of_true = np.sum(classification == 1)
    number_of_false = np.sum(classification == 0)

    if number_of_false == 0:
        true_false_ratio = 1.0
    else:
        true_false_ratio = number_of_true / number_of_false

    position = 1
    score = sorted_scores[0]

    for i in range(len(scores)):
        _score = sorted_scores[i]
        _flag = sorted_classification[i]

        # Update TP and FP
        if _flag == 1:
            TP += 1
        else:
            FP += 1

        # Check if score changed
        if i == len(scores) - 1 or score != sorted_scores[i + 1]:
            uniq_scores[position] = _score

            if TP + FP > 0:
                precision[position] = TP / (TP + true_false_ratio * FP)
            else:
                precision[position] = 1.0

            if number_of_true > 0:
                recall[position] = TP / number_of_true
            else:
                recall[position] = 0.0

            position += 1
            if i < len(scores) - 1:
                score = sorted_scores[i + 1]

    return precision[:position], recall[:position], uniq_scores[:position]


@njit
def roc_curve(classification, scores):
    """Compute ROC curve (JIT-compiled)."""
    n = len(scores)
    if n == 0:
        return np.array([0.0]), np.array([0.0]), np.array([np.inf])

    # Get indices for sorting scores in descending order
    indexes = np.argsort(scores)[::-1]
    sorted_scores = scores[indexes]
    sorted_classification = classification[indexes]

    # Initialize arrays
    max_size = n + 1

    tpr = np.zeros(max_size)
    fpr = np.zeros(max_size)
    uniq_scores = np.zeros(max_size)

    # Initial point: (fpr=0, tpr=0, threshold=inf)
    tpr[0] = 0.0
    fpr[0] = 0.0
    uniq_scores[0] = np.inf

    TP, FP = 0, 0
    number_of_true = np.sum(classification == 1)
    number_of_false = np.sum(classification == 0)
    position = 1
    score = sorted_scores[0]

    for i in range(len(scores)):
        _score = sorted_scores[i]
        _flag = sorted_classification[i]

        # Update TP and FP
        if _flag == 1:
            TP += 1
        else:
            FP += 1

        # Check if score changed
        if i == len(scores) - 1 or score != sorted_scores[i + 1]:
            uniq_scores[position] = _score

            if number_of_true > 0:
                tpr[position] = TP / number_of_true
            else:
                tpr[position] = 0.0

            if number_of_false > 0:
                fpr[position] = FP / number_of_false
            else:
                fpr[position] = 0.0

            position += 1
            if i < len(scores) - 1:
                score = sorted_scores[i + 1]

    return tpr[:position], fpr[:position], uniq_scores[:position]
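
# Usage sketch (illustrative): note the return order is (tpr, fpr, thresholds),
# so an AUC integrates tpr over fpr:
#   >>> y = np.array([1, 1, 0, 0])
#   >>> s = np.array([0.9, 0.8, 0.3, 0.1])
#   >>> tpr, fpr, _ = roc_curve(y, s)
#   >>> float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1])) / 2)  # trapezoid rule
#   1.0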


def cut_roc(tpr: np.ndarray, fpr: np.ndarray, thr: np.ndarray, score_cutoff: float):
    """
    Truncate ROC curve at a specific score threshold.

    This function truncates the ROC curve (True Positive Rate vs False Positive Rate)
    at a given score threshold. If interpolation is needed between points, it
    performs linear interpolation to determine the TPR and FPR values at the exact
    score cutoff.

    Parameters
    ----------
    tpr : np.ndarray
        True Positive Rate values from the ROC curve.
    fpr : np.ndarray
        False Positive Rate values from the ROC curve.
    thr : np.ndarray
        Threshold values corresponding to each TPR/FPR pair.
    score_cutoff : float
        The score threshold at which to truncate the ROC curve.

    Returns
    -------
    tuple
        Tuple containing (truncated_tpr, truncated_fpr, truncated_thresholds).
    """
    if score_cutoff == -np.inf:
        return tpr, fpr, thr

    # thr starts with inf, then decreases
    mask = thr >= score_cutoff
    if not np.any(mask):
        return (
            np.array([tpr[0]], dtype=tpr.dtype),
            np.array([0.0], dtype=fpr.dtype),
            np.array([score_cutoff], dtype=thr.dtype),
        )

    last = int(np.where(mask)[0][-1])

    if thr[last] == score_cutoff or last == len(thr) - 1:
        return tpr[: last + 1], fpr[: last + 1], thr[: last + 1]

    # Score interpolation
    s0, s1 = float(thr[last]), float(thr[last + 1])
    t0, t1 = float(tpr[last]), float(tpr[last + 1])
    f0, f1 = float(fpr[last]), float(fpr[last + 1])

    alpha = 0.0 if s0 == s1 else (score_cutoff - s0) / (s1 - s0)
    t_cut = t0 + alpha * (t1 - t0)
    f_cut = f0 + alpha * (f1 - f0)

    tpr_cut = np.concatenate([tpr[: last + 1], np.array([t_cut], dtype=tpr.dtype)])
    fpr_cut = np.concatenate([fpr[: last + 1], np.array([f_cut], dtype=fpr.dtype)])
    thr_cut = np.concatenate([thr[: last + 1], np.array([score_cutoff], dtype=thr.dtype)])

    return tpr_cut, fpr_cut, thr_cut


def cut_prc(rec: np.ndarray, prec: np.ndarray, thr: np.ndarray, score_cutoff: float):
    """
    Truncate Precision-Recall curve at a specific score threshold.

    This function truncates the Precision-Recall curve at a given score threshold.
    If interpolation is needed between points, it performs linear interpolation
    to determine the precision and recall values at the exact score cutoff.

    Parameters
    ----------
    rec : np.ndarray
        Recall values from the Precision-Recall curve.
    prec : np.ndarray
        Precision values from the Precision-Recall curve.
    thr : np.ndarray
        Threshold values corresponding to each precision/recall pair.
    score_cutoff : float
        The score threshold at which to truncate the PRC.

    Returns
    -------
    tuple
        Tuple containing (truncated_recall, truncated_precision, truncated_thresholds).
    """
    if score_cutoff == -np.inf:
        return rec, prec, thr

    # thr starts with inf, then decreases
    # find i: last index where thr[i] >= score_cutoff
    mask = thr >= score_cutoff
    if not np.any(mask):
        # threshold too high -> almost empty
        return (
            np.array([0.0], dtype=rec.dtype),
            np.array([prec[0]], dtype=prec.dtype),
            np.array([score_cutoff], dtype=thr.dtype),
        )

    last = int(np.where(mask)[0][-1])

    # if we hit the node exactly - just truncate
    if thr[last] == score_cutoff or last == len(thr) - 1:
        return rec[: last + 1], prec[: last + 1], thr[: last + 1]

    # otherwise interpolate between last and last+1 by score
    s0, s1 = float(thr[last]), float(thr[last + 1])  # s0 > cutoff > s1
    r0, r1 = float(rec[last]), float(rec[last + 1])
    p0, p1 = float(prec[last]), float(prec[last + 1])

    alpha = 0.0 if s0 == s1 else (score_cutoff - s0) / (s1 - s0)
    r_cut = r0 + alpha * (r1 - r0)
    p_cut = p0 + alpha * (p1 - p0)

    rec_cut = np.concatenate([rec[: last + 1], np.array([r_cut], dtype=rec.dtype)])
    prec_cut = np.concatenate([prec[: last + 1], np.array([p_cut], dtype=prec.dtype)])
    thr_cut = np.concatenate([thr[: last + 1], np.array([score_cutoff], dtype=thr.dtype)])

    return rec_cut, prec_cut, thr_cut


def standardized_pauc(pauc_raw: float, pauc_min: float, pauc_max: float) -> float:
    """
    Standardize partial AUC value to range [0.5, 1].

    This function standardizes a raw partial AUC value to a range between 0.5 and 1,
    where 0.5 represents random performance and 1 represents perfect performance.
    This standardization accounts for the theoretical minimum and maximum possible
    partial AUC values for the given conditions.

    Parameters
    ----------
    pauc_raw : float
        Raw partial AUC value to standardize.
    pauc_min : float
        Minimum possible partial AUC value for the given conditions.
    pauc_max : float
        Maximum possible partial AUC value for the given conditions.

    Returns
    -------
    float
        Standardized partial AUC value in range [0.5, 1].
    """
    denom = pauc_max - pauc_min
    if denom <= 0:
        return 0.5
    return 0.5 * (1.0 + (pauc_raw - pauc_min) / denom)
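
# Pipeline sketch (illustrative; labels, scores, score_cutoff, and the
# pauc_min/pauc_max bounds are placeholders a caller would supply for the
# truncated range):
#   tpr, fpr, thr = roc_curve(labels, scores)
#   tpr_c, fpr_c, thr_c = cut_roc(tpr, fpr, thr, score_cutoff)
#   pauc_raw = float(np.sum(np.diff(fpr_c) * (tpr_c[1:] + tpr_c[:-1])) / 2)
#   pauc_std = standardized_pauc(pauc_raw, pauc_min, pauc_max)  # in [0.5, 1]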


def scores_to_frequencies(ragged_scores: RaggedData) -> RaggedData:
    """
    Convert RaggedData containing scores to frequency representation.

    This function applies a log-frequency transformation in which each unique
    score value is replaced by the negative log10 of its survival frequency
    (the fraction of scores greater than or equal to it) across all sequences
    in the RaggedData structure.

    Parameters
    ----------
    ragged_scores : RaggedData
        Input RaggedData containing score values.

    Returns
    -------
    RaggedData
        RaggedData with transformed frequency values.
    """
    flat = ragged_scores.data
    n = flat.size

    if n == 0:
        return RaggedData(np.zeros(0, dtype=np.float32), ragged_scores.offsets.copy())

    _, inv, cnt = np.unique(flat, return_inverse=True, return_counts=True)
    surv = np.cumsum(cnt[::-1])[::-1]  # Count of scores >= each unique value

    # To avoid log10(0)
    eps = 1e-12
    log_p = np.log10(n + eps) - np.log10(surv + eps)

    new_data = log_p[inv].astype(np.float32)
    return RaggedData(new_data, ragged_scores.offsets.copy())
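
# Worked sketch (illustrative): each score becomes log10(n / m), where m counts
# the scores greater than or equal to it (a -log10 survival frequency):
#   >>> r = RaggedData(np.array([1.0, 2.0, 2.0, 3.0], dtype=np.float32),
#   ...                np.array([0, 2, 4], dtype=np.int64))
#   >>> out = scores_to_frequencies(r).data  # n=4; m = 4, 3, 3, 1
#   >>> # out ~ [log10(4/4), log10(4/3), log10(4/3), log10(4/1)]
#   >>> #     ~ [0.0, 0.125, 0.125, 0.602]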


@njit(fastmath=True, cache=True)
def _fast_overlap_kernel_numba(data1, offsets1, data2, offsets2, search_range):
    """
    Fast overlap coefficient kernel for RaggedData using JIT compilation.

    This kernel computes the overlap coefficient (Szymkiewicz-Simpson coefficient)
    between two sets of ragged sequences, finding the best alignment within
    the specified search range.

    Parameters
    ----------
    data1 : np.ndarray
        Flattened data array for first sequence collection.
    offsets1 : np.ndarray
        Offsets for first sequence collection.
    data2 : np.ndarray
        Flattened data array for second sequence collection.
    offsets2 : np.ndarray
        Offsets for second sequence collection.
    search_range : int
        Range of offsets to search for best alignment (from -search_range to +search_range).

    Returns
    -------
    tuple
        Tuple containing (best_overlap, best_offset) where:
            best_overlap : Maximum overlap coefficient found.
            best_offset : Offset at which maximum overlap occurs.
    """
    n_seq = len(offsets1) - 1
    n_offsets = 2 * search_range + 1

    inters = np.zeros(n_offsets, dtype=np.float32)
    sum1s = np.zeros(n_offsets, dtype=np.float32)
    sum2s = np.zeros(n_offsets, dtype=np.float32)

    for i in range(n_seq):
        s1 = data1[offsets1[i] : offsets1[i + 1]]
        s2 = data2[offsets2[i] : offsets2[i + 1]]
        vlen1 = s1.size
        vlen2 = s2.size

        for k in range(n_offsets):
            offset = k - search_range
            idx1_start = 0 if offset < 0 else offset
            idx2_start = -offset if offset < 0 else 0

            if idx1_start >= vlen1 or idx2_start >= vlen2:
                continue

            overlap = min(vlen1 - idx1_start, vlen2 - idx2_start)
            if overlap <= 0:
                continue

            local_inter = np.float32(0.0)
            local_s1 = np.float32(0.0)
            local_s2 = np.float32(0.0)

            for j in range(overlap):
                v1 = s1[idx1_start + j]
                v2 = s2[idx2_start + j]
                local_s1 += v1
                local_s2 += v2
                # min(v1, v2) = 0.5 * (v1 + v2 - |v1 - v2|)
                local_inter += np.float32(0.5) * (v1 + v2 - abs(v1 - v2))

            inters[k] += local_inter
            sum1s[k] += local_s1
            sum2s[k] += local_s2

    best = -1.0
    best_offset = 0
    eps = 1e-6

    for k in range(n_offsets):
        denom = min(sum1s[k], sum2s[k])
        if denom > eps:
            val = inters[k] / denom
            if val > best:
                best = val
                best_offset = k - search_range

    return best, best_offset


@njit(fastmath=True, cache=True)
def _fast_cj_kernel_numba(data1, offsets1, data2, offsets2, search_range):
    """
    Fast Continuous Jaccard (CJ) coefficient kernel for RaggedData using JIT compilation.

    This kernel computes the continuous Jaccard coefficient between two sets of
    ragged sequences, finding the best alignment within the specified search range.

    Parameters
    ----------
    data1 : np.ndarray
        Flattened data array for first sequence collection.
    offsets1 : np.ndarray
        Offsets for first sequence collection.
    data2 : np.ndarray
        Flattened data array for second sequence collection.
    offsets2 : np.ndarray
        Offsets for second sequence collection.
    search_range : int
        Range of offsets to search for best alignment (from -search_range to +search_range).

    Returns
    -------
    tuple
        Tuple containing (best_cj, best_offset) where:
            best_cj : Maximum continuous Jaccard coefficient found.
            best_offset : Offset at which maximum coefficient occurs.
    """
    n_seq = len(offsets1) - 1
    n_offsets = 2 * search_range + 1

    sums = np.zeros(n_offsets, dtype=np.float32)
    diffs = np.zeros(n_offsets, dtype=np.float32)

    for i in range(n_seq):
        s1 = data1[offsets1[i] : offsets1[i + 1]]
        s2 = data2[offsets2[i] : offsets2[i + 1]]
        vlen1 = s1.size
        vlen2 = s2.size

        for k in range(n_offsets):
            offset = k - search_range
            idx1_start = 0 if offset < 0 else offset
            idx2_start = -offset if offset < 0 else 0

            if idx1_start >= vlen1 or idx2_start >= vlen2:
                continue

            overlap = min(vlen1 - idx1_start, vlen2 - idx2_start)
            if overlap <= 0:
                continue

            local_sum = np.float32(0.0)
            local_diff = np.float32(0.0)

            for j in range(overlap):
                v1 = s1[idx1_start + j]
                v2 = s2[idx2_start + j]
                local_sum += v1 + v2
                local_diff += abs(v1 - v2)

            sums[k] += local_sum
            diffs[k] += local_diff

    best_cj = -1.0
    best_offset = 0
    eps = 1e-6

    for k in range(n_offsets):
        S = sums[k]
        D = diffs[k]
        denom = S + D
        if denom > eps:
            # sum(min) / sum(max) = (S - D) / (S + D)
            cj = (S - D) / denom
            if cj > best_cj:
                best_cj = cj
                best_offset = k - search_range

    return best_cj, best_offset


def _fast_pearson_kernel(data1, offsets1, data2, offsets2, search_range):
    """Pearson correlation kernel for RaggedData using scipy.stats.pearsonr."""

    n_seq = len(offsets1) - 1
    n_offsets = 2 * search_range + 1

    correlations = np.zeros(n_offsets, dtype=np.float64)
    pvalues = np.ones(n_offsets, dtype=np.float64)
    valid_correlations = np.zeros(n_offsets, dtype=np.bool_)

    for k in range(n_offsets):
        offset = k - search_range

        all_x_values = []
        all_y_values = []

        for i in range(n_seq):
            s1 = data1[offsets1[i] : offsets1[i + 1]]
            s2 = data2[offsets2[i] : offsets2[i + 1]]
            vlen1 = s1.size
            vlen2 = s2.size

            idx1_start = 0 if offset < 0 else offset
            idx2_start = -offset if offset < 0 else 0

            if idx1_start >= vlen1 or idx2_start >= vlen2:
                continue

            overlap = min(vlen1 - idx1_start, vlen2 - idx2_start)
            if overlap <= 0:
                continue

            x_vals = s1[idx1_start : idx1_start + overlap]
            y_vals = s2[idx2_start : idx2_start + overlap]

            all_x_values.extend(x_vals)
            all_y_values.extend(y_vals)

        if len(all_x_values) > 1:  # Need at least 2 points for correlation
            x_array = np.array(all_x_values, dtype=np.float64)
            y_array = np.array(all_y_values, dtype=np.float64)

            # Check if either array has zero variance
            if np.var(x_array) > 1e-10 and np.var(y_array) > 1e-10:
                corr_val, pvalue = pearsonr(x_array, y_array)
                correlations[k] = corr_val
                pvalues[k] = pvalue
                valid_correlations[k] = True
            else:
                # If one variable has no variance, correlation is undefined (set to 0)
                correlations[k] = 0.0
                pvalues[k] = 1.0
                valid_correlations[k] = True

    # Find the best correlation among valid ones
    best_corr = -2.0  # Pearson correlation ranges from -1 to 1
    best_offset = 0
    found_valid = False
    best_pvalue = 1.0

    for k in range(n_offsets):
        if valid_correlations[k] and correlations[k] > best_corr:
            best_corr = correlations[k]
            best_pvalue = pvalues[k]
            best_offset = k - search_range
            found_valid = True

    if not found_valid:
        best_corr = 0.0
        best_pvalue = 1.0

    return best_corr, best_pvalue, best_offset
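
# Usage sketch (illustrative; profiles1 and profiles2 are placeholder RaggedData
# objects with one score profile per sequence, paired by index): all three
# kernels scan shifts in [-search_range, +search_range] and return the best
# similarity together with the shift at which it occurs.
#   best_ov, shift = _fast_overlap_kernel_numba(
#       profiles1.data, profiles1.offsets, profiles2.data, profiles2.offsets, 3
#   )
#   best_cj, shift = _fast_cj_kernel_numba(
#       profiles1.data, profiles1.offsets, profiles2.data, profiles2.offsets, 3
#   )
#   best_r, pvalue, shift = _fast_pearson_kernel(
#       profiles1.data, profiles1.offsets, profiles2.data, profiles2.offsets, 3
#   )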


def format_params(params: dict) -> str:
    """Serialize a params dict as a deterministic "key-value" string, sorted by key."""
    keys = sorted(params.keys())
    return "_".join(f"{k}-{params[k]}" for k in keys)