supremo-lite 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1014 @@
+ """
+ Utilities for aligning model predictions between reference and variant sequences.
+
+ This module provides functions to handle the alignment of ML model predictions
+ when reference and variant sequences have position offsets due to structural variants.
+
+ The alignment logic properly handles:
+ - 1D predictions (chromatin accessibility, TF binding, etc.)
+ - 2D contact maps (Hi-C, Micro-C predictions)
+ - All variant types: SNV, INS, DEL, DUP, INV, BND
+
+ Key principle: Users must specify all model-specific parameters (bin_size, diag_offset)
+ as these vary across different prediction models.
+ """
+
+ import numpy as np
+ import warnings
+ from typing import Tuple, List, Optional, Union
+ from dataclasses import dataclass
+
+ try:
+     import torch
+
+     TORCH_AVAILABLE = True
+ except ImportError:
+     TORCH_AVAILABLE = False
+
+
+ @dataclass
+ class VariantPosition:
+     """
+     Container for variant position information in both REF and ALT sequences.
+
+     This class encapsulates the essential positional information needed to align
+     predictions across reference and alternate sequences that may differ in length.
+     """
+
+     ref_pos: int  # Position in reference sequence (base pairs, 0-based)
+     alt_pos: int  # Position in alternate sequence (base pairs, 0-based)
+     svlen: int  # Length of structural variant (base pairs, signed for DEL/INS)
+     variant_type: str  # Type of variant ('SNV', 'INS', 'DEL', 'DUP', 'INV', 'BND')
+
+     def get_bin_positions(self, bin_size: int) -> Tuple[int, int, int]:
+         """
+         Convert base pair positions to bin indices.
+
+         Args:
+             bin_size: Number of base pairs per prediction bin
+
+         Returns:
+             Tuple of (ref_bin, alt_start_bin, alt_end_bin)
+         """
+         ref_bin = int(np.ceil(self.ref_pos / bin_size))
+         alt_start_bin = int(np.ceil(self.alt_pos / bin_size))
+         alt_end_bin = int(np.ceil((self.alt_pos + abs(self.svlen)) / bin_size))
+         return ref_bin, alt_start_bin, alt_end_bin
+
+
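A minimal sketch of the bin arithmetic above, assuming VariantPosition (defined in this module) is in scope; the toy numbers are not taken from any particular model.

    # Hypothetical 256 bp insertion at position 1000, scored with 128 bp bins.
    var = VariantPosition(ref_pos=1000, alt_pos=1000, svlen=256, variant_type="INS")
    ref_bin, alt_start, alt_end = var.get_bin_positions(bin_size=128)
    # ceil(1000 / 128) = 8 and ceil((1000 + 256) / 128) = 10
    assert (ref_bin, alt_start, alt_end) == (8, 8, 10)
    # alt_end - alt_start = 2 bins is what the aligners below would NaN-mask.
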
+ class PredictionAligner1D:
+     """
+     Aligns reference and alternate 1D prediction vectors for variant comparison.
+
+     Handles alignment of 1D genomic predictions (e.g., chromatin accessibility,
+     transcription factor binding, epigenetic marks) between reference and variant
+     sequences that may differ in length due to structural variants.
+
+     The aligner uses a masking strategy where positions that exist in one sequence
+     but not the other are marked with NaN values, enabling direct comparison of
+     corresponding genomic positions.
+
+     Args:
+         target_size: Expected number of bins in the prediction output
+         bin_size: Number of base pairs per prediction bin (model-specific)
+
+     Example:
+         >>> aligner = PredictionAligner1D(target_size=896, bin_size=128)
+         >>> ref_aligned, alt_aligned = aligner.align_predictions(
+         ...     ref_pred, alt_pred, 'INS', variant_position
+         ... )
+     """
+
+     def __init__(self, target_size: int, bin_size: int):
+         """
+         Initialize the 1D prediction aligner.
+
+         Args:
+             target_size: Expected number of bins in prediction (e.g., 896 for Enformer)
+             bin_size: Base pairs per bin (e.g., 128 for Enformer)
+         """
+         self.target_size = target_size
+         self.bin_size = bin_size
+
+     def align_predictions(
+         self,
+         ref_pred: Union[np.ndarray, "torch.Tensor"],
+         alt_pred: Union[np.ndarray, "torch.Tensor"],
+         svtype: str,
+         var_pos: VariantPosition,
+     ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
+         """
+         Main entry point for 1D prediction alignment.
+
+         Args:
+             ref_pred: Reference prediction vector (length N)
+             alt_pred: Alternate prediction vector (length N)
+             svtype: Variant type ('DEL', 'DUP', 'INS', 'INV', 'SNV')
+             var_pos: Variant position information
+
+         Returns:
+             Tuple of (aligned_ref, aligned_alt) vectors with NaN masking applied
+
+         Raises:
+             ValueError: For unsupported variant types or if using BND (use align_bnd_predictions)
+         """
+         if svtype == "BND" or svtype == "SV_BND":
+             raise ValueError("Use align_bnd_predictions() for breakends")
+
+         # Normalize variant type names
+         svtype_normalized = svtype.replace("SV_", "")
+
+         if svtype_normalized in ["DEL", "DUP", "INS"]:
+             return self._align_indel_predictions(
+                 ref_pred, alt_pred, svtype_normalized, var_pos
+             )
+         elif svtype_normalized == "INV":
+             return self._align_inversion_predictions(ref_pred, alt_pred, var_pos)
+         elif svtype_normalized in ["SNV", "MNV"]:
+             # SNVs don't change coordinates, direct alignment
+             is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
+             if is_torch:
+                 return ref_pred.clone(), alt_pred.clone()
+             else:
+                 return ref_pred.copy(), alt_pred.copy()
+         else:
+             raise ValueError(f"Unknown variant type: {svtype}")
+
+     def _align_indel_predictions(
+         self,
+         ref_pred: Union[np.ndarray, "torch.Tensor"],
+         alt_pred: Union[np.ndarray, "torch.Tensor"],
+         svtype: str,
+         var_pos: VariantPosition,
+     ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
+         """
+         Align predictions for insertions, deletions, and duplications.
+
+         Strategy:
+         1. For DEL: Swap REF/ALT (deletion removes from REF)
+         2. Insert NaN bins in shorter sequence
+         3. Crop edges to maintain target size
+         4. For DEL: Swap back
+
+         This ensures that positions present in one sequence but not the other
+         are marked with NaN, enabling fair comparison of overlapping regions.
+         """
+         is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
+
+         # Convert to numpy for manipulation
+         if is_torch:
+             ref_np = ref_pred.detach().cpu().numpy()
+             alt_np = alt_pred.detach().cpu().numpy()
+         else:
+             ref_np = ref_pred
+             alt_np = alt_pred
+
+         # Swap for deletions (treat as insertion in reverse)
+         if svtype == "DEL":
+             ref_np, alt_np = alt_np, ref_np
+             var_pos = VariantPosition(
+                 var_pos.alt_pos, var_pos.ref_pos, var_pos.svlen, svtype
+             )
+
+         # Get bin positions
+         ref_bin, alt_start_bin, alt_end_bin = var_pos.get_bin_positions(self.bin_size)
+         bins_to_add = alt_end_bin - alt_start_bin
+
+         # Insert NaN bins in REF where variant exists in ALT
+         ref_masked = self._insert_nan_bins(ref_np, ref_bin, bins_to_add)
+
+         # Crop to maintain target size
+         ref_masked = self._crop_vector(ref_masked, ref_bin, alt_start_bin)
+         alt_masked = alt_np.copy()
+
+         # Swap back for deletions
+         if svtype == "DEL":
+             ref_masked, alt_masked = alt_masked, ref_masked
+
+         self._validate_size(ref_masked, alt_masked)
+
+         # Convert back to torch if needed
+         if is_torch:
+             ref_masked = (
+                 torch.from_numpy(ref_masked).to(ref_pred.device).type(ref_pred.dtype)
+             )
+             alt_masked = (
+                 torch.from_numpy(alt_masked).to(alt_pred.device).type(alt_pred.dtype)
+             )
+
+         return ref_masked, alt_masked
+
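A small numeric sketch of the insertion strategy described above, using toy sizes rather than any real model's dimensions and assuming PredictionAligner1D and VariantPosition (defined in this module) are in scope.

    import numpy as np

    # 8 bins of 128 bp, and a 256 bp insertion at position 256 (bins 2-3 in ALT).
    aligner = PredictionAligner1D(target_size=8, bin_size=128)
    ref = np.arange(8, dtype=float)        # [0, 1, 2, 3, 4, 5, 6, 7]
    alt = np.arange(10, 18, dtype=float)   # stand-in ALT prediction, same length
    var = VariantPosition(ref_pos=256, alt_pos=256, svlen=256, variant_type="INS")

    ref_aligned, alt_aligned = aligner.align_predictions(ref, alt, "INS", var)
    # Two NaN bins are inserted into REF where the inserted sequence exists only
    # in ALT, then the right edge is cropped back to 8 bins:
    # ref_aligned -> [0, 1, nan, nan, 2, 3, 4, 5]; alt_aligned is an unchanged copy.
    assert np.isnan(ref_aligned[2:4]).all() and len(ref_aligned) == 8
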
+     def _insert_nan_bins(
+         self, vector: np.ndarray, position: int, num_bins: int
+     ) -> np.ndarray:
+         """
+         Insert NaN values at specified position in vector.
+
+         Args:
+             vector: Input prediction vector
+             position: Position to insert NaN values (bin index)
+             num_bins: Number of NaN values to insert
+
+         Returns:
+             Vector with NaN values inserted
+         """
+         result = vector.copy()
+         for offset in range(num_bins):
+             insert_pos = position + offset
+             result = np.insert(result, insert_pos, np.nan)
+         return result
+
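For illustration only, a toy call to the helper above (the aligner's sizes are irrelevant to this method):

    import numpy as np

    aligner = PredictionAligner1D(target_size=5, bin_size=1)
    out = aligner._insert_nan_bins(np.array([1.0, 2.0, 3.0]), position=1, num_bins=2)
    # out -> [1.0, nan, nan, 2.0, 3.0]
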
+     def _crop_vector(
+         self, vector: np.ndarray, ref_bin: int, alt_bin: int
+     ) -> np.ndarray:
+         """
+         Crop vector edges to maintain target size.
+
+         After inserting NaN bins, the vector is longer than expected.
+         This function crops from both edges, based on the REF/ALT bin offset,
+         to keep the variant centered.
+
+         Args:
+             vector: Vector to crop
+             ref_bin: Reference bin position
+             alt_bin: Alternate bin position
+
+         Returns:
+             Cropped vector of target_size length
+         """
+         remove_left = ref_bin - alt_bin
+         remove_right = len(vector) - self.target_size - remove_left
+
+         # Apply cropping
+         start = max(0, remove_left)
+         end = len(vector) - max(0, remove_right)
+         return vector[start:end]
+
+     def _align_inversion_predictions(
+         self,
+         ref_pred: Union[np.ndarray, "torch.Tensor"],
+         alt_pred: Union[np.ndarray, "torch.Tensor"],
+         var_pos: VariantPosition,
+     ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
+         """
+         Align predictions for inversions.
+
+         Strategy:
+         1. Mask the inverted region in both vectors with NaN
+         2. This allows comparison of only the flanking (unaffected) regions
+
+         For strand-aware models, inversions can significantly affect predictions
+         because regulatory elements now appear on the opposite strand. We mask
+         the inverted region to focus comparison on unaffected flanking sequences.
+         """
+         is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
+
+         # Convert to numpy for manipulation
+         if is_torch:
+             ref_np = ref_pred.detach().cpu().numpy()
+             alt_np = alt_pred.detach().cpu().numpy()
+         else:
+             ref_np = ref_pred.copy()
+             alt_np = alt_pred.copy()
+
+         var_start, _, var_end = var_pos.get_bin_positions(self.bin_size)
+
+         # Mask inverted region in both REF and ALT
+         ref_np[var_start : var_end + 1] = np.nan
+         alt_np[var_start : var_end + 1] = np.nan
+
+         self._validate_size(ref_np, alt_np)
+
+         # Convert back to torch if needed
+         if is_torch:
+             ref_np = torch.from_numpy(ref_np).to(ref_pred.device).type(ref_pred.dtype)
+             alt_np = torch.from_numpy(alt_np).to(alt_pred.device).type(alt_pred.dtype)
+
+         return ref_np, alt_np
+
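A toy sketch of the inversion masking above, assuming the classes defined in this module are in scope; only the flanking bins remain comparable afterwards.

    import numpy as np

    aligner = PredictionAligner1D(target_size=8, bin_size=128)
    ref, alt = np.ones(8), np.full(8, 2.0)
    # 256 bp inversion starting at position 256 -> bins 2 through 4 are masked.
    var = VariantPosition(ref_pos=256, alt_pos=256, svlen=256, variant_type="INV")

    ref_aligned, alt_aligned = aligner.align_predictions(ref, alt, "INV", var)
    assert np.isnan(ref_aligned[2:5]).all() and np.isnan(alt_aligned[2:5]).all()
    assert not np.isnan(ref_aligned[:2]).any() and not np.isnan(ref_aligned[5:]).any()
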
+     def align_bnd_predictions(
+         self,
+         left_ref: Union[np.ndarray, "torch.Tensor"],
+         right_ref: Union[np.ndarray, "torch.Tensor"],
+         bnd_alt: Union[np.ndarray, "torch.Tensor"],
+         breakpoint_bin: int,
+     ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
+         """
+         Align predictions for breakends (chromosomal rearrangements).
+
+         BNDs join two distant loci, so we create a chimeric reference
+         prediction from the two separate loci for comparison with the fusion ALT.
+
+         Args:
+             left_ref: Prediction from left locus
+             right_ref: Prediction from right locus
+             bnd_alt: Prediction from joined (alternate) sequence
+             breakpoint_bin: Bin position of breakpoint
+
+         Returns:
+             Tuple of (chimeric_ref, alt) vectors
+         """
+         is_torch = TORCH_AVAILABLE and torch.is_tensor(left_ref)
+
+         # Convert to numpy for manipulation
+         if is_torch:
+             left_np = left_ref.detach().cpu().numpy()
+             right_np = right_ref.detach().cpu().numpy()
+             alt_np = bnd_alt.detach().cpu().numpy()
+         else:
+             left_np = left_ref
+             right_np = right_ref
+             alt_np = bnd_alt
+
+         # Extract the relevant portions from each reference
+         left_portion = left_np[:breakpoint_bin]
+         right_portion = right_np[-(self.target_size - breakpoint_bin) :]
+
+         # Assemble chimeric reference (continuous, no masking)
+         ref_chimeric = np.concatenate([left_portion, right_portion])
+
+         self._validate_size(ref_chimeric, alt_np)
+
+         # Convert back to torch if needed
+         if is_torch:
+             ref_chimeric = (
+                 torch.from_numpy(ref_chimeric).to(left_ref.device).type(left_ref.dtype)
+             )
+             alt_np = torch.from_numpy(alt_np).to(bnd_alt.device).type(bnd_alt.dtype)
+
+         return ref_chimeric, alt_np
+
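A toy sketch of the chimeric-reference construction above; in practice left_ref and right_ref would come from two separate reference windows around the two breakend partners.

    import numpy as np

    aligner = PredictionAligner1D(target_size=8, bin_size=128)
    left_ref, right_ref = np.zeros(8), np.ones(8)
    bnd_alt = np.full(8, 0.5)  # prediction on the fused ALT sequence

    chimeric_ref, alt = aligner.align_bnd_predictions(
        left_ref, right_ref, bnd_alt, breakpoint_bin=3
    )
    # First 3 bins come from the left locus, the remaining 5 from the right locus.
    assert (chimeric_ref == np.array([0, 0, 0, 1, 1, 1, 1, 1])).all()
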
+     def _validate_size(self, ref_vector: np.ndarray, alt_vector: np.ndarray):
+         """
+         Validate that vectors are the correct size.
+
+         Args:
+             ref_vector: Reference prediction vector
+             alt_vector: Alternate prediction vector
+
+         Raises:
+             ValueError: If either vector has incorrect size
+         """
+         if len(ref_vector) != self.target_size:
+             raise ValueError(
+                 f"Reference vector wrong size: {len(ref_vector)} vs {self.target_size}"
+             )
+         if len(alt_vector) != self.target_size:
+             raise ValueError(
+                 f"Alternate vector wrong size: {len(alt_vector)} vs {self.target_size}"
+             )
+
+
+ class PredictionAligner2D:
+     """
+     Aligns reference and alternate prediction matrices for variant comparison.
+
+     Handles alignment of 2D genomic predictions (e.g., Hi-C contact maps,
+     Micro-C predictions) between reference and variant sequences that may
+     differ in length due to structural variants.
+
+     The aligner uses a masking strategy where matrix rows and columns that
+     exist in one sequence but not the other are marked with NaN values.
+
+     Args:
+         target_size: Expected matrix dimension (NxN)
+         bin_size: Number of base pairs per matrix bin (model-specific)
+         diag_offset: Number of diagonal bins to mask (model-specific)
+
+     Example:
+         >>> aligner = PredictionAligner2D(
+         ...     target_size=448,
+         ...     bin_size=2048,
+         ...     diag_offset=2
+         ... )
+         >>> ref_aligned, alt_aligned = aligner.align_predictions(
+         ...     ref_matrix, alt_matrix, 'DEL', variant_position
+         ... )
+     """
+
+     def __init__(self, target_size: int, bin_size: int, diag_offset: int):
+         """
+         Initialize the 2D prediction aligner.
+
+         Args:
+             target_size: Matrix dimension (e.g., 448 for Akita)
+             bin_size: Base pairs per bin (e.g., 2048 for Akita)
+             diag_offset: Diagonal masking offset (e.g., 2 for Akita)
+         """
+         self.target_size = target_size
+         self.bin_size = bin_size
+         self.diag_offset = diag_offset
+
+     def align_predictions(
+         self,
+         ref_pred: Union[np.ndarray, "torch.Tensor"],
+         alt_pred: Union[np.ndarray, "torch.Tensor"],
+         svtype: str,
+         var_pos: VariantPosition,
+     ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
+         """
+         Main entry point for 2D matrix alignment.
+
+         Args:
+             ref_pred: Reference prediction matrix (NxN)
+             alt_pred: Alternate prediction matrix (NxN)
+             svtype: Variant type ('DEL', 'DUP', 'INS', 'INV', 'SNV')
+             var_pos: Variant position information
+
+         Returns:
+             Tuple of (aligned_ref, aligned_alt) matrices with NaN masking applied
+
+         Raises:
+             ValueError: For unsupported variant types or if using BND (use align_bnd_matrices)
+         """
+         if svtype == "BND" or svtype == "SV_BND":
+             raise ValueError("Use align_bnd_matrices() for breakends")
+
+         # Normalize variant type names
+         svtype_normalized = svtype.replace("SV_", "")
+
+         if svtype_normalized in ["DEL", "DUP", "INS"]:
+             return self._align_indel_matrices(
+                 ref_pred, alt_pred, svtype_normalized, var_pos
+             )
+         elif svtype_normalized == "INV":
+             return self._align_inversion_matrices(ref_pred, alt_pred, var_pos)
+         elif svtype_normalized in ["SNV", "MNV"]:
+             # SNVs don't change coordinates, direct alignment
+             is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
+             if is_torch:
+                 return ref_pred.clone(), alt_pred.clone()
+             else:
+                 return ref_pred.copy(), alt_pred.copy()
+         else:
+             raise ValueError(f"Unknown variant type: {svtype}")
+
+     def _align_indel_matrices(
+         self,
+         ref_pred: Union[np.ndarray, "torch.Tensor"],
+         alt_pred: Union[np.ndarray, "torch.Tensor"],
+         svtype: str,
+         var_pos: VariantPosition,
+     ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
+         """
+         Align matrices for insertions, deletions, and duplications.
+
+         Strategy:
+         1. For DEL: Swap REF/ALT (deletion removes from REF)
+         2. Insert NaN bins (rows AND columns) in shorter matrix
+         3. Crop edges to maintain target size
+         4. For DEL: Swap back
+         """
+         is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
+
+         # Convert to numpy for manipulation
+         if is_torch:
+             ref_np = ref_pred.detach().cpu().numpy()
+             alt_np = alt_pred.detach().cpu().numpy()
+         else:
+             ref_np = ref_pred
+             alt_np = alt_pred
+
+         # Swap for deletions (treat as insertion in reverse)
+         if svtype == "DEL":
+             ref_np, alt_np = alt_np, ref_np
+             var_pos = VariantPosition(
+                 var_pos.alt_pos, var_pos.ref_pos, var_pos.svlen, svtype
+             )
+
+         # Get bin positions
+         ref_bin, alt_start_bin, alt_end_bin = var_pos.get_bin_positions(self.bin_size)
+         bins_to_add = alt_end_bin - alt_start_bin
+
+         # Insert NaN bins in REF where variant exists in ALT
+         ref_masked = self._insert_nan_bins(ref_np, ref_bin, bins_to_add)
+
+         # Crop to maintain target size
+         ref_masked = self._crop_matrix(ref_masked, ref_bin, alt_start_bin)
+         alt_masked = alt_np.copy()
+
+         # Swap back for deletions
+         if svtype == "DEL":
+             ref_masked, alt_masked = alt_masked, ref_masked
+
+         self._validate_size(ref_masked, alt_masked)
+
+         # Convert back to torch if needed
+         if is_torch:
+             ref_masked = (
+                 torch.from_numpy(ref_masked).to(ref_pred.device).type(ref_pred.dtype)
+             )
+             alt_masked = (
+                 torch.from_numpy(alt_masked).to(alt_pred.device).type(alt_pred.dtype)
+             )
+
+         return ref_masked, alt_masked
+
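The 2D version mirrors the 1D logic but inserts whole NaN rows and columns; a toy sketch with an 8 x 8 map, assuming the classes above are in scope:

    import numpy as np

    aligner = PredictionAligner2D(target_size=8, bin_size=128, diag_offset=0)
    ref_map, alt_map = np.ones((8, 8)), np.full((8, 8), 2.0)
    var = VariantPosition(ref_pos=256, alt_pos=256, svlen=256, variant_type="INS")

    ref_aligned, alt_aligned = aligner.align_predictions(ref_map, alt_map, "INS", var)
    # Rows and columns 2-3 of the REF map are NaN (those bins exist only in ALT),
    # and the map is cropped back to 8 x 8.
    assert np.isnan(ref_aligned[2:4, :]).all() and np.isnan(ref_aligned[:, 2:4]).all()
    assert ref_aligned.shape == (8, 8) and alt_aligned.shape == (8, 8)
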
+     def _insert_nan_bins(
+         self, matrix: np.ndarray, position: int, num_bins: int
+     ) -> np.ndarray:
+         """
+         Insert NaN bins (rows and columns) at specified position.
+
+         For 2D matrices, we must insert both rows AND columns to maintain
+         the square matrix structure and properly mask interactions.
+         """
+         result = matrix.copy()
+         for offset in range(num_bins):
+             insert_pos = position + offset
+             result = np.insert(result, insert_pos, np.nan, axis=0)
+             result = np.insert(result, insert_pos, np.nan, axis=1)
+         return result
+
+     def _crop_matrix(
+         self, matrix: np.ndarray, ref_bin: int, alt_bin: int
+     ) -> np.ndarray:
+         """
+         Crop matrix edges to maintain target size.
+
+         After inserting NaN bins, the matrix is larger than expected.
+         This function crops from both edges, based on the REF/ALT bin offset,
+         to keep the variant centered.
+         """
+         remove_left = ref_bin - alt_bin
+         remove_right = len(matrix) - self.target_size - remove_left
+
+         # Apply cropping
+         start = max(0, remove_left)
+         end = len(matrix) - max(0, remove_right)
+         return matrix[start:end, start:end]
+
+     def _align_inversion_matrices(
+         self,
+         ref_pred: Union[np.ndarray, "torch.Tensor"],
+         alt_pred: Union[np.ndarray, "torch.Tensor"],
+         var_pos: VariantPosition,
+     ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
+         """
+         Align matrices for inversions.
+
+         Strategy: Mask the inverted region in both REF and ALT matrices using
+         a cross-pattern (mask entire rows AND columns at the inversion position).
+
+         Why cross-masking?
+         - Inversions reverse the sequence, creating geometric rotation in contact maps
+         - Masking rows removes interactions with the inverted region (one dimension)
+         - Masking columns removes interactions from the inverted region (other dimension)
+         - This ensures only flanking regions (unaffected by inversion) are compared
+
+         The same NaN pattern is mirrored to ALT so both matrices have identical
+         masked regions, enabling fair comparison of the unaffected areas.
+         """
+         is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
+
+         # Convert to numpy for manipulation
+         if is_torch:
+             ref_np = ref_pred.detach().cpu().numpy()
+             alt_np = alt_pred.detach().cpu().numpy()
+         else:
+             ref_np = ref_pred.copy()
+             alt_np = alt_pred.copy()
+
+         var_start, _, var_end = var_pos.get_bin_positions(self.bin_size)
+
+         # Mask inverted region in REF (cross-pattern: rows + columns)
+         ref_np[var_start : var_end + 1, :] = np.nan
+         ref_np[:, var_start : var_end + 1] = np.nan
+
+         # Mirror the NaN pattern to ALT so both matrices mask the same bins
+         nan_mask = ref_np.copy()
+         nan_mask[np.invert(np.isnan(nan_mask))] = 0  # Non-NaN → 0, NaN stays NaN
+         alt_np = alt_np + nan_mask  # Adding NaN propagates NaN to ALT
+
+         self._validate_size(ref_np, alt_np)
+
+         # Convert back to torch if needed
+         if is_torch:
+             ref_np = torch.from_numpy(ref_np).to(ref_pred.device).type(ref_pred.dtype)
+             alt_np = torch.from_numpy(alt_np).to(alt_pred.device).type(alt_pred.dtype)
+
+         return ref_np, alt_np
+
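A toy sketch of the cross-masking above; the identical NaN cross ends up in both maps, so only flank-flank contacts are compared (classes assumed in scope):

    import numpy as np

    aligner = PredictionAligner2D(target_size=8, bin_size=128, diag_offset=0)
    ref_map, alt_map = np.ones((8, 8)), np.full((8, 8), 2.0)
    var = VariantPosition(ref_pos=256, alt_pos=256, svlen=256, variant_type="INV")

    ref_aligned, alt_aligned = aligner.align_predictions(ref_map, alt_map, "INV", var)
    # Rows and columns 2-4 are NaN in REF, and the same pattern is mirrored to ALT.
    assert np.array_equal(np.isnan(ref_aligned), np.isnan(alt_aligned))
    assert np.isnan(alt_aligned[3, 0]) and not np.isnan(alt_aligned[0, 7])
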
+     def align_bnd_matrices(
+         self,
+         left_ref: Union[np.ndarray, "torch.Tensor"],
+         right_ref: Union[np.ndarray, "torch.Tensor"],
+         bnd_alt: Union[np.ndarray, "torch.Tensor"],
+         breakpoint_bin: int,
+     ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
+         """
+         Align matrices for breakends (chromosomal rearrangements).
+
+         BNDs join two distant loci, so we create a chimeric reference
+         matrix from the two separate loci.
+
+         Args:
+             left_ref: Prediction from left locus
+             right_ref: Prediction from right locus
+             bnd_alt: Prediction from joined (alternate) sequence
+             breakpoint_bin: Bin position of breakpoint
+
+         Returns:
+             Tuple of (chimeric_ref, alt) matrices
+         """
+         is_torch = TORCH_AVAILABLE and torch.is_tensor(left_ref)
+
+         # Convert to numpy for manipulation
+         if is_torch:
+             left_np = left_ref.detach().cpu().numpy()
+             right_np = right_ref.detach().cpu().numpy()
+             alt_np = bnd_alt.detach().cpu().numpy()
+         else:
+             left_np = left_ref
+             right_np = right_ref
+             alt_np = bnd_alt
+
+         # Assemble chimeric matrix from two loci
+         ref_chimeric = self._assemble_chimeric_matrix(left_np, right_np, breakpoint_bin)
+
+         self._validate_size(ref_chimeric, alt_np)
+
+         # Convert back to torch if needed
+         if is_torch:
+             ref_chimeric = (
+                 torch.from_numpy(ref_chimeric).to(left_ref.device).type(left_ref.dtype)
+             )
+             alt_np = torch.from_numpy(alt_np).to(bnd_alt.device).type(bnd_alt.dtype)
+
+         return ref_chimeric, alt_np
+
+     def _assemble_chimeric_matrix(
+         self, left_matrix: np.ndarray, right_matrix: np.ndarray, breakpoint: int
+     ) -> np.ndarray:
+         """
+         Assemble chimeric matrix from two loci.
+
+         Structure:
+         - Upper left quadrant: left locus
+         - Lower right quadrant: right locus
+         - Upper right/lower left quadrants: NaN (no trans prediction)
+         """
+         matrix = np.zeros((self.target_size, self.target_size))
+
+         # Fill upper left quadrant (left locus)
+         matrix[:breakpoint, :breakpoint] = left_matrix[:breakpoint, :breakpoint]
+
+         # Fill lower right quadrant (right locus)
+         matrix[breakpoint:, breakpoint:] = right_matrix[breakpoint:, breakpoint:]
+
+         # Fill transition quadrants with NaN
+         matrix[:breakpoint, breakpoint:] = np.nan
+         matrix[breakpoint:, :breakpoint] = np.nan
+
+         # Mask diagonals as specified by model
+         for offset in range(-self.diag_offset + 1, self.diag_offset):
+             if offset < 0:
+                 np.fill_diagonal(matrix[abs(offset) :, :], np.nan)
+             else:
+                 np.fill_diagonal(matrix[:, offset:], np.nan)
+
+         return matrix
+
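A toy sketch of the quadrant layout produced by the two methods above; with diag_offset=0 the near-diagonal band is left unmasked (classes assumed in scope).

    import numpy as np

    aligner = PredictionAligner2D(target_size=8, bin_size=128, diag_offset=0)
    left_map, right_map = np.zeros((8, 8)), np.ones((8, 8))
    bnd_alt_map = np.full((8, 8), 0.5)

    chimeric, alt = aligner.align_bnd_matrices(
        left_map, right_map, bnd_alt_map, breakpoint_bin=3
    )
    # Upper-left 3x3 block comes from the left locus, lower-right 5x5 from the
    # right locus; the two cross ("trans") blocks are NaN since no REF exists there.
    assert (chimeric[:3, :3] == 0).all() and (chimeric[3:, 3:] == 1).all()
    assert np.isnan(chimeric[:3, 3:]).all() and np.isnan(chimeric[3:, :3]).all()
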
+     def _validate_size(self, ref_matrix: np.ndarray, alt_matrix: np.ndarray):
+         """
+         Validate that matrices are the correct size.
+
+         Args:
+             ref_matrix: Reference prediction matrix
+             alt_matrix: Alternate prediction matrix
+
+         Raises:
+             ValueError: If either matrix has incorrect size
+         """
+         if ref_matrix.shape[0] != self.target_size:
+             raise ValueError(
+                 f"Reference matrix wrong size: {ref_matrix.shape[0]} vs {self.target_size}"
+             )
+         if alt_matrix.shape[0] != self.target_size:
+             raise ValueError(
+                 f"Alternate matrix wrong size: {alt_matrix.shape[0]} vs {self.target_size}"
+             )
+
+
+ # =============================================================================
+ # Utility Functions for Contact Maps
+ # =============================================================================
+
+
+ def vector_to_contact_matrix(
+     vector: Union[np.ndarray, "torch.Tensor"], matrix_size: int, diag_offset: int = 0
+ ) -> Union[np.ndarray, "torch.Tensor"]:
+     """
+     Convert flattened upper triangular vector to full contact matrix.
+
+     This function reconstructs a full symmetric contact matrix from its upper
+     triangular representation, following the pattern used in genomic contact map models.
+     Supports diagonal masking where near-diagonal elements are excluded.
+
+     Args:
+         vector: Flattened upper triangular matrix (floating-point, since masked
+             entries are filled with NaN)
+             Expected length depends on diag_offset:
+             - diag_offset=0: matrix_size * (matrix_size + 1) / 2 (includes diagonal)
+             - diag_offset=k: (matrix_size - k) * (matrix_size - k + 1) / 2
+         matrix_size: Dimension of the output square matrix
+         diag_offset: Diagonal offset for masking (default=0, no masking)
+             diag_offset=2 means skip main diagonal and first off-diagonal
+
+     Returns:
+         Full symmetric contact matrix of shape (matrix_size, matrix_size)
+         Elements within diag_offset of the diagonal are set to NaN
+
+     Example:
+         >>> # Full upper triangle (diag_offset=0, default)
+         >>> vector = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+         >>> matrix = vector_to_contact_matrix(vector, 3)
+         >>> # Result: [[1., 2., 3.], [2., 4., 5.], [3., 5., 6.]]
+     """
+     # Handle both PyTorch tensors and NumPy arrays
+     is_torch = TORCH_AVAILABLE and torch.is_tensor(vector)
+
+     if is_torch:
+         # Initialize matrix with NaN
+         matrix = torch.full(
+             (matrix_size, matrix_size),
+             float("nan"),
+             dtype=vector.dtype,
+             device=vector.device,
+         )
+         # Get upper triangle indices with diagonal offset
+         triu_indices = torch.triu_indices(matrix_size, matrix_size, offset=diag_offset)
+         matrix[triu_indices[0], triu_indices[1]] = vector
+         # Make symmetric by copying upper triangle to lower (preserving NaN)
+         for i in range(matrix_size):
+             for j in range(i + diag_offset, matrix_size):
+                 if not torch.isnan(matrix[i, j]):
+                     matrix[j, i] = matrix[i, j]
+     else:
+         # Initialize matrix with NaN
+         matrix = np.full((matrix_size, matrix_size), np.nan, dtype=vector.dtype)
+         # Get upper triangle indices with diagonal offset
+         triu_indices = np.triu_indices(matrix_size, k=diag_offset)
+         matrix[triu_indices] = vector
+         # Make symmetric by copying upper triangle to lower (preserving NaN)
+         for i in range(matrix_size):
+             for j in range(i + diag_offset, matrix_size):
+                 if not np.isnan(matrix[i, j]):
+                     matrix[j, i] = matrix[i, j]
+
+     return matrix
+
+
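A small check of the reconstruction above, using float values because the NaN initialization requires a floating-point dtype; with a nonzero diag_offset the masked band stays NaN.

    import numpy as np

    # Full upper triangle of a 3x3 map, diagonal included (diag_offset=0).
    vec = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    mat = vector_to_contact_matrix(vec, matrix_size=3)
    assert np.array_equal(
        mat, np.array([[1.0, 2.0, 3.0], [2.0, 4.0, 5.0], [3.0, 5.0, 6.0]])
    )

    # With diag_offset=2 on a 4x4 map, only (4 - 2) * (4 - 2 + 1) / 2 = 3 values
    # are stored, and everything within 2 bins of the diagonal remains NaN.
    mat2 = vector_to_contact_matrix(
        np.array([7.0, 8.0, 9.0]), matrix_size=4, diag_offset=2
    )
    assert np.isnan(mat2[0, 0]) and np.isnan(mat2[0, 1]) and mat2[0, 2] == 7.0
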
+ def contact_matrix_to_vector(
+     matrix: Union[np.ndarray, "torch.Tensor"], diag_offset: int = 0
+ ) -> Union[np.ndarray, "torch.Tensor"]:
+     """
+     Convert full contact matrix to flattened upper triangular vector.
+
+     This function extracts the upper triangular portion of a contact matrix,
+     which is the standard representation for genomic contact maps. Supports
+     diagonal masking to exclude near-diagonal elements.
+
+     Args:
+         matrix: Full symmetric contact matrix of shape (N, N)
+         diag_offset: Diagonal offset for extraction (default=0, includes diagonal)
+             diag_offset=2 means skip main diagonal and first off-diagonal
+
+     Returns:
+         Flattened upper triangular vector
+         Length depends on diag_offset:
+         - diag_offset=0: N * (N + 1) / 2 (includes diagonal)
+         - diag_offset=k: (N - k) * (N - k + 1) / 2
+
+     Example:
+         >>> # Full upper triangle (diag_offset=0, default)
+         >>> matrix = np.array([[1, 2, 3], [2, 4, 5], [3, 5, 6]])
+         >>> vector = contact_matrix_to_vector(matrix)
+         >>> # Result: [1, 2, 3, 4, 5, 6]
+     """
+     # Handle both PyTorch tensors and NumPy arrays
+     is_torch = TORCH_AVAILABLE and torch.is_tensor(matrix)
+
+     if is_torch:
+         triu_indices = torch.triu_indices(
+             matrix.shape[0], matrix.shape[1], offset=diag_offset
+         )
+         return matrix[triu_indices[0], triu_indices[1]]
+     else:
+         triu_indices = np.triu_indices(matrix.shape[0], k=diag_offset)
+         return matrix[triu_indices]
+
+
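The two helpers above invert each other for a fixed diag_offset; a quick round-trip check with toy values:

    import numpy as np

    vec = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    mat = vector_to_contact_matrix(vec, matrix_size=3, diag_offset=0)
    assert np.array_equal(contact_matrix_to_vector(mat, diag_offset=0), vec)
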
+ def align_predictions_by_coordinate(
+     ref_preds: Union[np.ndarray, "torch.Tensor"],
+     alt_preds: Union[np.ndarray, "torch.Tensor"],
+     metadata_row: dict,
+     bin_size: int,
+     prediction_type: str,
+     matrix_size: Optional[int] = None,
+     diag_offset: int = 0,
+ ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
+     """
+     Align reference and alt predictions using coordinate transformation and variant type awareness.
+
+     This is the main public API for prediction alignment. It handles both 1D prediction
+     vectors (e.g., chromatin accessibility, TF binding) and 2D matrices (e.g., Hi-C contact maps),
+     routing to the appropriate alignment strategy based on variant type.
+
+     IMPORTANT: Model-specific parameters (bin_size, matrix_size) must be explicitly
+     provided by the user. There are no defaults because these vary across different models.
+
+     Args:
+         ref_preds: Reference predictions array (from model with edge cropping)
+         alt_preds: Alt predictions array (same shape as ref_preds)
+         metadata_row: Dictionary with variant information containing:
+             - 'variant_type': Type of variant (SNV, INS, DEL, DUP, INV, BND)
+             - 'window_start': Start position of window (0-based)
+             - 'variant_pos0': Variant position (0-based, absolute genomic coordinate)
+             - 'svlen': Length of structural variant (optional, for symbolic alleles)
+         bin_size: Number of base pairs per prediction bin (REQUIRED, model-specific)
+             Examples: 2048 for Akita
+         prediction_type: Type of predictions ("1D" or "2D")
+             - "1D": Vector predictions (chromatin accessibility, TF binding, etc.)
+             - "2D": Matrix predictions (Hi-C contact maps, Micro-C, etc.)
+         matrix_size: Size of contact matrix (REQUIRED for 2D type)
+             Examples: 448 for Akita
+         diag_offset: Number of diagonal bins to mask (default: 0 for no masking)
+             Set to 0 if your model doesn't mask diagonals, or to model-specific value
+             Examples: 2 for Akita, 0 for models without diagonal masking
+
+     Returns:
+         Tuple of (aligned_ref_preds, aligned_alt_preds) with NaN masking applied
+         at positions that differ between reference and alternate sequences
+
+     Raises:
+         ValueError: If prediction_type is invalid, required parameters are missing,
+             or variant type is unsupported
+
+     Example (1D predictions):
+         >>> ref_aligned, alt_aligned = align_predictions_by_coordinate(
+         ...     ref_preds=ref_scores,
+         ...     alt_preds=alt_scores,
+         ...     metadata_row={'variant_type': 'INS', 'window_start': 0,
+         ...                   'variant_pos0': 500, 'svlen': 100},
+         ...     bin_size=128,
+         ...     prediction_type="1D"
+         ... )
+
+     Example (2D contact maps with diagonal masking):
+         >>> ref_aligned, alt_aligned = align_predictions_by_coordinate(
+         ...     ref_preds=ref_contact_map,
+         ...     alt_preds=alt_contact_map,
+         ...     metadata_row={'variant_type': 'DEL', 'window_start': 0,
+         ...                   'variant_pos0': 50000, 'svlen': -2048},
+         ...     bin_size=2048,
+         ...     prediction_type="2D",
+         ...     matrix_size=448,
+         ...     diag_offset=2  # Optional: use 0 if no diagonal masking
+         ... )
+
+     Example (2D contact maps without diagonal masking):
+         >>> ref_aligned, alt_aligned = align_predictions_by_coordinate(
+         ...     ref_preds=ref_contact_map,
+         ...     alt_preds=alt_contact_map,
+         ...     metadata_row={'variant_type': 'INS', 'window_start': 0,
+         ...                   'variant_pos0': 1000, 'svlen': 500},
+         ...     bin_size=1000,
+         ...     prediction_type="2D",
+         ...     matrix_size=512
+         ...     # diag_offset defaults to 0 (no masking)
+         ... )
+     """
+     # Validate prediction type and parameters
+     if prediction_type not in ["1D", "2D"]:
+         raise ValueError(
+             f"prediction_type must be '1D' or '2D', got '{prediction_type}'"
+         )
+
+     if prediction_type == "2D":
+         if matrix_size is None:
+             raise ValueError("matrix_size must be provided for 2D prediction type")
+
+     # Extract variant information from metadata
+     variant_type = metadata_row.get("variant_type", "unknown")
+     window_start = metadata_row.get("window_start", 0)
+     variant_pos0 = metadata_row.get("variant_pos0")
+
+     # For backward compatibility, check for effective_variant_start (deprecated)
+     if variant_pos0 is not None:
+         abs_variant_pos = variant_pos0
+     else:
+         # Fall back to old field name if present (for backward compatibility)
+         effective_variant_start = metadata_row.get("effective_variant_start", 0)
+         abs_variant_pos = window_start + effective_variant_start
+
+     svlen = metadata_row.get("svlen", None)
+
+     # Calculate svlen from alleles if not present (for non-symbolic DEL/INS variants)
+     # Symbolic variants (SV_DEL, SV_INS, SV_INV, SV_DUP) have svlen in INFO field
+     # Regular variants (DEL, INS) need svlen calculated from allele lengths
+     if svlen is None or svlen == 0:
+         ref_allele = metadata_row.get("ref", "")
+         alt_allele = metadata_row.get("alt", "")
+         if ref_allele and alt_allele:
+             # svlen = len(alt) - len(ref)
+             # For DEL: negative (e.g., 1 - 13 = -12)
+             # For INS: positive (e.g., 7 - 1 = 6)
+             svlen = len(alt_allele) - len(ref_allele)
+
+     # Create VariantPosition object
+     var_pos = VariantPosition(
+         ref_pos=abs_variant_pos,
+         alt_pos=abs_variant_pos,
+         svlen=svlen if svlen is not None else 0,
+         variant_type=variant_type,
+     )
+
+     # Determine target size from predictions
+     if prediction_type == "1D":
+         # Check if predictions are multi-dimensional (multiple targets)
+         is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_preds)
+         is_numpy = isinstance(ref_preds, np.ndarray)
+
+         if is_torch:
+             ndim = len(ref_preds.shape)
+         elif is_numpy:
+             ndim = ref_preds.ndim
+         else:
+             ndim = 1
+
+         # Handle multi-target predictions [n_targets, n_bins]
+         if ndim > 1:
+             target_size = ref_preds.shape[-1]  # Number of bins
+             aligner = PredictionAligner1D(target_size=target_size, bin_size=bin_size)
+
+             # Align each target separately
+             n_targets = ref_preds.shape[0]
+             ref_aligned_list = []
+             alt_aligned_list = []
+
+             for target_idx in range(n_targets):
+                 ref_target = ref_preds[target_idx]
+                 alt_target = alt_preds[target_idx]
+                 ref_aligned, alt_aligned = aligner.align_predictions(
+                     ref_target, alt_target, variant_type, var_pos
+                 )
+                 ref_aligned_list.append(ref_aligned)
+                 alt_aligned_list.append(alt_aligned)
+
+             # Stack results back into multi-target format
+             if is_torch:
+                 ref_result = torch.stack(ref_aligned_list)
+                 alt_result = torch.stack(alt_aligned_list)
+             else:
+                 ref_result = np.stack(ref_aligned_list)
+                 alt_result = np.stack(alt_aligned_list)
+
+             return ref_result, alt_result
+         else:
+             # Single target prediction [n_bins]
+             target_size = len(ref_preds)
+             aligner = PredictionAligner1D(target_size=target_size, bin_size=bin_size)
+             return aligner.align_predictions(
+                 ref_preds, alt_preds, variant_type, var_pos
+             )
+     else:  # 2D
+         # Check if predictions are 1D (flattened upper triangular) or 2D (full matrix)
+         is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_preds)
+
+         if is_torch:
+             ndim = len(ref_preds.shape)
+         else:
+             ndim = ref_preds.ndim
+
+         # If 1D, convert to 2D matrix
+         if ndim == 1:
+             ref_matrix = vector_to_contact_matrix(
+                 ref_preds, matrix_size, diag_offset=diag_offset
+             )
+             alt_matrix = vector_to_contact_matrix(
+                 alt_preds, matrix_size, diag_offset=diag_offset
+             )
+
+             # Align matrices
+             aligner = PredictionAligner2D(
+                 target_size=matrix_size, bin_size=bin_size, diag_offset=diag_offset
+             )
+             aligned_ref_matrix, aligned_alt_matrix = aligner.align_predictions(
+                 ref_matrix, alt_matrix, variant_type, var_pos
+             )
+
+             # Convert back to flattened format
+             aligned_ref_vector = contact_matrix_to_vector(
+                 aligned_ref_matrix, diag_offset=diag_offset
+             )
+             aligned_alt_vector = contact_matrix_to_vector(
+                 aligned_alt_matrix, diag_offset=diag_offset
+             )
+
+             return aligned_ref_vector, aligned_alt_vector
+         else:
+             # Already 2D matrices
+             aligner = PredictionAligner2D(
+                 target_size=matrix_size, bin_size=bin_size, diag_offset=diag_offset
+             )
+             return aligner.align_predictions(
+                 ref_preds, alt_preds, variant_type, var_pos
+             )