supremo-lite 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supremo_lite/__init__.py +59 -0
- supremo_lite/chromosome_utils.py +322 -0
- supremo_lite/core.py +41 -0
- supremo_lite/mock_models/__init__.py +110 -0
- supremo_lite/mock_models/testmodel_1d.py +184 -0
- supremo_lite/mock_models/testmodel_2d.py +203 -0
- supremo_lite/mutagenesis.py +414 -0
- supremo_lite/personalize.py +3098 -0
- supremo_lite/prediction_alignment.py +1014 -0
- supremo_lite/sequence_utils.py +137 -0
- supremo_lite/variant_utils.py +1645 -0
- supremo_lite-0.5.4.dist-info/METADATA +216 -0
- supremo_lite-0.5.4.dist-info/RECORD +15 -0
- supremo_lite-0.5.4.dist-info/WHEEL +4 -0
- supremo_lite-0.5.4.dist-info/licenses/LICENSE +22 -0
@@ -0,0 +1,1014 @@
"""
Utilities for aligning model predictions between reference and variant sequences.

This module provides functions to handle the alignment of ML model predictions
when reference and variant sequences have position offsets due to structural variants.

The alignment logic properly handles:
- 1D predictions (chromatin accessibility, TF binding, etc.)
- 2D contact maps (Hi-C, Micro-C predictions)
- All variant types: SNV, INS, DEL, DUP, INV, BND

Key principle: Users must specify all model-specific parameters (bin_size, diag_offset)
as these vary across different prediction models.
"""

import numpy as np
import warnings
from typing import Tuple, List, Optional, Union
from dataclasses import dataclass

try:
    import torch

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False


@dataclass
class VariantPosition:
    """
    Container for variant position information in both REF and ALT sequences.

    This class encapsulates the essential positional information needed to align
    predictions across reference and alternate sequences that may differ in length.
    """

    ref_pos: int  # Position in reference sequence (base pairs, 0-based)
    alt_pos: int  # Position in alternate sequence (base pairs, 0-based)
    svlen: int  # Length of structural variant (base pairs, signed for DEL/INS)
    variant_type: str  # Type of variant ('SNV', 'INS', 'DEL', 'DUP', 'INV', 'BND')

    def get_bin_positions(self, bin_size: int) -> Tuple[int, int, int]:
        """
        Convert base pair positions to bin indices.

        Args:
            bin_size: Number of base pairs per prediction bin

        Returns:
            Tuple of (ref_bin, alt_start_bin, alt_end_bin)
        """
        ref_bin = int(np.ceil(self.ref_pos / bin_size))
        alt_start_bin = int(np.ceil(self.alt_pos / bin_size))
        alt_end_bin = int(np.ceil((self.alt_pos + abs(self.svlen)) / bin_size))
        return ref_bin, alt_start_bin, alt_end_bin
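

# --- Illustrative sketch, not part of the released module ---
# The bin math above uses np.ceil: a variant at bp 300 with bin_size=128 maps to
# bin ceil(300/128) = 3, and a 100 bp insertion ends at bin ceil(400/128) = 4, so a
# single NaN bin would be inserted downstream. A quick self-check, assuming only numpy:
def _sketch_bin_positions():
    """Hypothetical helper: show how base-pair coordinates map to prediction bins."""
    vp = VariantPosition(ref_pos=300, alt_pos=300, svlen=100, variant_type="INS")
    ref_bin, alt_start_bin, alt_end_bin = vp.get_bin_positions(bin_size=128)
    # ceil(300/128) = 3 and ceil(400/128) = 4 -> one bin of NaN padding for this INS
    assert (ref_bin, alt_start_bin, alt_end_bin) == (3, 3, 4)
    return ref_bin, alt_start_bin, alt_end_bin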


class PredictionAligner1D:
    """
    Aligns reference and alternate 1D prediction vectors for variant comparison.

    Handles alignment of 1D genomic predictions (e.g., chromatin accessibility,
    transcription factor binding, epigenetic marks) between reference and variant
    sequences that may differ in length due to structural variants.

    The aligner uses a masking strategy where positions that exist in one sequence
    but not the other are marked with NaN values, enabling direct comparison of
    corresponding genomic positions.

    Args:
        target_size: Expected number of bins in the prediction output
        bin_size: Number of base pairs per prediction bin (model-specific)

    Example:
        >>> aligner = PredictionAligner1D(target_size=896, bin_size=128)
        >>> ref_aligned, alt_aligned = aligner.align_predictions(
        ...     ref_pred, alt_pred, 'INS', variant_position
        ... )
    """

    def __init__(self, target_size: int, bin_size: int):
        """
        Initialize the 1D prediction aligner.

        Args:
            target_size: Expected number of bins in prediction (e.g., 896 for Enformer)
            bin_size: Base pairs per bin (e.g., 128 for Enformer)
        """
        self.target_size = target_size
        self.bin_size = bin_size

    def align_predictions(
        self,
        ref_pred: Union[np.ndarray, "torch.Tensor"],
        alt_pred: Union[np.ndarray, "torch.Tensor"],
        svtype: str,
        var_pos: VariantPosition,
    ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
        """
        Main entry point for 1D prediction alignment.

        Args:
            ref_pred: Reference prediction vector (length N)
            alt_pred: Alternate prediction vector (length N)
            svtype: Variant type ('DEL', 'DUP', 'INS', 'INV', 'SNV')
            var_pos: Variant position information

        Returns:
            Tuple of (aligned_ref, aligned_alt) vectors with NaN masking applied

        Raises:
            ValueError: For unsupported variant types or if using BND (use align_bnd_predictions)
        """
        if svtype == "BND" or svtype == "SV_BND":
            raise ValueError("Use align_bnd_predictions() for breakends")

        # Normalize variant type names
        svtype_normalized = svtype.replace("SV_", "")

        if svtype_normalized in ["DEL", "DUP", "INS"]:
            return self._align_indel_predictions(
                ref_pred, alt_pred, svtype_normalized, var_pos
            )
        elif svtype_normalized == "INV":
            return self._align_inversion_predictions(ref_pred, alt_pred, var_pos)
        elif svtype_normalized in ["SNV", "MNV"]:
            # SNVs don't change coordinates, direct alignment
            is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
            if is_torch:
                return ref_pred.clone(), alt_pred.clone()
            else:
                return ref_pred.copy(), alt_pred.copy()
        else:
            raise ValueError(f"Unknown variant type: {svtype}")

    def _align_indel_predictions(
        self,
        ref_pred: Union[np.ndarray, "torch.Tensor"],
        alt_pred: Union[np.ndarray, "torch.Tensor"],
        svtype: str,
        var_pos: VariantPosition,
    ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
        """
        Align predictions for insertions, deletions, and duplications.

        Strategy:
        1. For DEL: Swap REF/ALT (deletion removes from REF)
        2. Insert NaN bins in shorter sequence
        3. Crop edges to maintain target size
        4. For DEL: Swap back

        This ensures that positions present in one sequence but not the other
        are marked with NaN, enabling fair comparison of overlapping regions.
        """
        is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)

        # Convert to numpy for manipulation
        if is_torch:
            ref_np = ref_pred.detach().cpu().numpy()
            alt_np = alt_pred.detach().cpu().numpy()
        else:
            ref_np = ref_pred
            alt_np = alt_pred

        # Swap for deletions (treat as insertion in reverse)
        if svtype == "DEL":
            ref_np, alt_np = alt_np, ref_np
            var_pos = VariantPosition(
                var_pos.alt_pos, var_pos.ref_pos, var_pos.svlen, svtype
            )

        # Get bin positions
        ref_bin, alt_start_bin, alt_end_bin = var_pos.get_bin_positions(self.bin_size)
        bins_to_add = alt_end_bin - alt_start_bin

        # Insert NaN bins in REF where variant exists in ALT
        ref_masked = self._insert_nan_bins(ref_np, ref_bin, bins_to_add)

        # Crop to maintain target size
        ref_masked = self._crop_vector(ref_masked, ref_bin, alt_start_bin)
        alt_masked = alt_np.copy()

        # Swap back for deletions
        if svtype == "DEL":
            ref_masked, alt_masked = alt_masked, ref_masked

        self._validate_size(ref_masked, alt_masked)

        # Convert back to torch if needed
        if is_torch:
            ref_masked = (
                torch.from_numpy(ref_masked).to(ref_pred.device).type(ref_pred.dtype)
            )
            alt_masked = (
                torch.from_numpy(alt_masked).to(alt_pred.device).type(alt_pred.dtype)
            )

        return ref_masked, alt_masked

    def _insert_nan_bins(
        self, vector: np.ndarray, position: int, num_bins: int
    ) -> np.ndarray:
        """
        Insert NaN values at specified position in vector.

        Args:
            vector: Input prediction vector
            position: Position to insert NaN values (bin index)
            num_bins: Number of NaN values to insert

        Returns:
            Vector with NaN values inserted
        """
        result = vector.copy()
        for offset in range(num_bins):
            insert_pos = position + offset
            result = np.insert(result, insert_pos, np.nan)
        return result

    def _crop_vector(
        self, vector: np.ndarray, ref_bin: int, alt_bin: int
    ) -> np.ndarray:
        """
        Crop vector edges to maintain target size.

        After inserting NaN bins, the vector is longer than expected.
        This function crops from edges proportionally to center the variant.

        Args:
            vector: Vector to crop
            ref_bin: Reference bin position
            alt_bin: Alternate bin position

        Returns:
            Cropped vector of target_size length
        """
        remove_left = ref_bin - alt_bin
        remove_right = len(vector) - self.target_size - remove_left

        # Apply cropping
        start = max(0, remove_left)
        end = len(vector) - max(0, remove_right)
        return vector[start:end]

    def _align_inversion_predictions(
        self,
        ref_pred: Union[np.ndarray, "torch.Tensor"],
        alt_pred: Union[np.ndarray, "torch.Tensor"],
        var_pos: VariantPosition,
    ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
        """
        Align predictions for inversions.

        Strategy:
        1. Mask the inverted region in both vectors with NaN
        2. This allows comparison of only the flanking (unaffected) regions

        For strand-aware models, inversions can significantly affect predictions
        because regulatory elements now appear on the opposite strand. We mask
        the inverted region to focus comparison on unaffected flanking sequences.
        """
        is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)

        # Convert to numpy for manipulation
        if is_torch:
            ref_np = ref_pred.detach().cpu().numpy()
            alt_np = alt_pred.detach().cpu().numpy()
        else:
            ref_np = ref_pred.copy()
            alt_np = alt_pred.copy()

        var_start, _, var_end = var_pos.get_bin_positions(self.bin_size)

        # Mask inverted region in both REF and ALT
        ref_np[var_start : var_end + 1] = np.nan
        alt_np[var_start : var_end + 1] = np.nan

        self._validate_size(ref_np, alt_np)

        # Convert back to torch if needed
        if is_torch:
            ref_np = torch.from_numpy(ref_np).to(ref_pred.device).type(ref_pred.dtype)
            alt_np = torch.from_numpy(alt_np).to(alt_pred.device).type(alt_pred.dtype)

        return ref_np, alt_np

    def align_bnd_predictions(
        self,
        left_ref: Union[np.ndarray, "torch.Tensor"],
        right_ref: Union[np.ndarray, "torch.Tensor"],
        bnd_alt: Union[np.ndarray, "torch.Tensor"],
        breakpoint_bin: int,
    ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
        """
        Align predictions for breakends (chromosomal rearrangements).

        BNDs join two distant loci, so we create a chimeric reference
        prediction from the two separate loci for comparison with the fusion ALT.

        Args:
            left_ref: Prediction from left locus
            right_ref: Prediction from right locus
            bnd_alt: Prediction from joined (alternate) sequence
            breakpoint_bin: Bin position of breakpoint

        Returns:
            Tuple of (chimeric_ref, alt) vectors
        """
        is_torch = TORCH_AVAILABLE and torch.is_tensor(left_ref)

        # Convert to numpy for manipulation
        if is_torch:
            left_np = left_ref.detach().cpu().numpy()
            right_np = right_ref.detach().cpu().numpy()
            alt_np = bnd_alt.detach().cpu().numpy()
        else:
            left_np = left_ref
            right_np = right_ref
            alt_np = bnd_alt

        # Extract the relevant portions from each reference
        left_portion = left_np[:breakpoint_bin]
        right_portion = right_np[-(self.target_size - breakpoint_bin) :]

        # Assemble chimeric reference (continuous, no masking)
        ref_chimeric = np.concatenate([left_portion, right_portion])

        self._validate_size(ref_chimeric, alt_np)

        # Convert back to torch if needed
        if is_torch:
            ref_chimeric = (
                torch.from_numpy(ref_chimeric).to(left_ref.device).type(left_ref.dtype)
            )
            alt_np = torch.from_numpy(alt_np).to(bnd_alt.device).type(bnd_alt.dtype)

        return ref_chimeric, alt_np

    def _validate_size(self, ref_vector: np.ndarray, alt_vector: np.ndarray):
        """
        Validate that vectors are the correct size.

        Args:
            ref_vector: Reference prediction vector
            alt_vector: Alternate prediction vector

        Raises:
            ValueError: If either vector has incorrect size
        """
        if len(ref_vector) != self.target_size:
            raise ValueError(
                f"Reference vector wrong size: {len(ref_vector)} vs {self.target_size}"
            )
        if len(alt_vector) != self.target_size:
            raise ValueError(
                f"Alternate vector wrong size: {len(alt_vector)} vs {self.target_size}"
            )
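

# --- Illustrative sketch, not part of the released module ---
# Minimal end-to-end use of PredictionAligner1D on toy numpy vectors, assuming a tiny
# model with 8 bins of 10 bp each. For a 20 bp insertion at bp 40, two NaN bins are
# inserted into REF at the insertion point and the padded vector is cropped back to
# 8 bins, so REF and ALT stay directly comparable bin-by-bin.
def _sketch_1d_insertion_alignment():
    """Hypothetical demo: align toy REF/ALT vectors around a 20 bp insertion."""
    import numpy as np

    ref = np.arange(8, dtype=float)        # pretend model output for the REF window
    alt = np.arange(8, dtype=float) + 0.5  # pretend model output for the ALT window
    vp = VariantPosition(ref_pos=40, alt_pos=40, svlen=20, variant_type="INS")
    aligner = PredictionAligner1D(target_size=8, bin_size=10)
    ref_aligned, alt_aligned = aligner.align_predictions(ref, alt, "INS", vp)
    # Both outputs keep length 8; the inserted-region bins of REF become NaN.
    assert len(ref_aligned) == len(alt_aligned) == 8
    assert np.isnan(ref_aligned).sum() == 2
    return ref_aligned, alt_aligned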


class PredictionAligner2D:
    """
    Aligns reference and alternate prediction matrices for variant comparison.

    Handles alignment of 2D genomic predictions (e.g., Hi-C contact maps,
    Micro-C predictions) between reference and variant sequences that may
    differ in length due to structural variants.

    The aligner uses a masking strategy where matrix rows and columns that
    exist in one sequence but not the other are marked with NaN values.

    Args:
        target_size: Expected matrix dimension (NxN)
        bin_size: Number of base pairs per matrix bin (model-specific)
        diag_offset: Number of diagonal bins to mask (model-specific)

    Example:
        >>> aligner = PredictionAligner2D(
        ...     target_size=448,
        ...     bin_size=2048,
        ...     diag_offset=2
        ... )
        >>> ref_aligned, alt_aligned = aligner.align_predictions(
        ...     ref_matrix, alt_matrix, 'DEL', variant_position
        ... )
    """

    def __init__(self, target_size: int, bin_size: int, diag_offset: int):
        """
        Initialize the 2D prediction aligner.

        Args:
            target_size: Matrix dimension (e.g., 448 for Akita)
            bin_size: Base pairs per bin (e.g., 2048 for Akita)
            diag_offset: Diagonal masking offset (e.g., 2 for Akita)
        """
        self.target_size = target_size
        self.bin_size = bin_size
        self.diag_offset = diag_offset

    def align_predictions(
        self,
        ref_pred: Union[np.ndarray, "torch.Tensor"],
        alt_pred: Union[np.ndarray, "torch.Tensor"],
        svtype: str,
        var_pos: VariantPosition,
    ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
        """
        Main entry point for 2D matrix alignment.

        Args:
            ref_pred: Reference prediction matrix (NxN)
            alt_pred: Alternate prediction matrix (NxN)
            svtype: Variant type ('DEL', 'DUP', 'INS', 'INV', 'SNV')
            var_pos: Variant position information

        Returns:
            Tuple of (aligned_ref, aligned_alt) matrices with NaN masking applied

        Raises:
            ValueError: For unsupported variant types or if using BND (use align_bnd_matrices)
        """
        if svtype == "BND" or svtype == "SV_BND":
            raise ValueError("Use align_bnd_matrices() for breakends")

        # Normalize variant type names
        svtype_normalized = svtype.replace("SV_", "")

        if svtype_normalized in ["DEL", "DUP", "INS"]:
            return self._align_indel_matrices(
                ref_pred, alt_pred, svtype_normalized, var_pos
            )
        elif svtype_normalized == "INV":
            return self._align_inversion_matrices(ref_pred, alt_pred, var_pos)
        elif svtype_normalized in ["SNV", "MNV"]:
            # SNVs don't change coordinates, direct alignment
            is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
            if is_torch:
                return ref_pred.clone(), alt_pred.clone()
            else:
                return ref_pred.copy(), alt_pred.copy()
        else:
            raise ValueError(f"Unknown variant type: {svtype}")

    def _align_indel_matrices(
        self,
        ref_pred: Union[np.ndarray, "torch.Tensor"],
        alt_pred: Union[np.ndarray, "torch.Tensor"],
        svtype: str,
        var_pos: VariantPosition,
    ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
        """
        Align matrices for insertions, deletions, and duplications.

        Strategy:
        1. For DEL: Swap REF/ALT (deletion removes from REF)
        2. Insert NaN bins (rows AND columns) in shorter matrix
        3. Crop edges to maintain target size
        4. For DEL: Swap back
        """
        is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)

        # Convert to numpy for manipulation
        if is_torch:
            ref_np = ref_pred.detach().cpu().numpy()
            alt_np = alt_pred.detach().cpu().numpy()
        else:
            ref_np = ref_pred
            alt_np = alt_pred

        # Swap for deletions (treat as insertion in reverse)
        if svtype == "DEL":
            ref_np, alt_np = alt_np, ref_np
            var_pos = VariantPosition(
                var_pos.alt_pos, var_pos.ref_pos, var_pos.svlen, svtype
            )

        # Get bin positions
        ref_bin, alt_start_bin, alt_end_bin = var_pos.get_bin_positions(self.bin_size)
        bins_to_add = alt_end_bin - alt_start_bin

        # Insert NaN bins in REF where variant exists in ALT
        ref_masked = self._insert_nan_bins(ref_np, ref_bin, bins_to_add)

        # Crop to maintain target size
        ref_masked = self._crop_matrix(ref_masked, ref_bin, alt_start_bin)
        alt_masked = alt_np.copy()

        # Swap back for deletions
        if svtype == "DEL":
            ref_masked, alt_masked = alt_masked, ref_masked

        self._validate_size(ref_masked, alt_masked)

        # Convert back to torch if needed
        if is_torch:
            ref_masked = (
                torch.from_numpy(ref_masked).to(ref_pred.device).type(ref_pred.dtype)
            )
            alt_masked = (
                torch.from_numpy(alt_masked).to(alt_pred.device).type(alt_pred.dtype)
            )

        return ref_masked, alt_masked

    def _insert_nan_bins(
        self, matrix: np.ndarray, position: int, num_bins: int
    ) -> np.ndarray:
        """
        Insert NaN bins (rows and columns) at specified position.

        For 2D matrices, we must insert both rows AND columns to maintain
        the square matrix structure and properly mask interactions.
        """
        result = matrix.copy()
        for offset in range(num_bins):
            insert_pos = position + offset
            result = np.insert(result, insert_pos, np.nan, axis=0)
            result = np.insert(result, insert_pos, np.nan, axis=1)
        return result

    def _crop_matrix(
        self, matrix: np.ndarray, ref_bin: int, alt_bin: int
    ) -> np.ndarray:
        """
        Crop matrix edges to maintain target size.

        After inserting NaN bins, the matrix is larger than expected.
        This function crops from edges proportionally to center the variant.
        """
        remove_left = ref_bin - alt_bin
        remove_right = len(matrix) - self.target_size - remove_left

        # Apply cropping
        start = max(0, remove_left)
        end = len(matrix) - max(0, remove_right)
        return matrix[start:end, start:end]

    def _align_inversion_matrices(
        self,
        ref_pred: Union[np.ndarray, "torch.Tensor"],
        alt_pred: Union[np.ndarray, "torch.Tensor"],
        var_pos: VariantPosition,
    ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
        """
        Align matrices for inversions.

        Strategy: Mask the inverted region in both REF and ALT matrices using
        a cross-pattern (mask entire rows AND columns at the inversion position).

        Why cross-masking?
        - Inversions reverse the sequence, creating geometric rotation in contact maps
        - Masking rows removes interactions with the inverted region (one dimension)
        - Masking columns removes interactions from the inverted region (other dimension)
        - This ensures only flanking regions (unaffected by inversion) are compared

        The same NaN pattern is mirrored to ALT so both matrices have identical
        masked regions, enabling fair comparison of the unaffected areas.
        """
        is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)

        # Convert to numpy for manipulation
        if is_torch:
            ref_np = ref_pred.detach().cpu().numpy()
            alt_np = alt_pred.detach().cpu().numpy()
        else:
            ref_np = ref_pred.copy()
            alt_np = alt_pred.copy()

        var_start, _, var_end = var_pos.get_bin_positions(self.bin_size)

        # Mask inverted region in REF (cross-pattern: rows + columns)
        ref_np[var_start : var_end + 1, :] = np.nan
        ref_np[:, var_start : var_end + 1] = np.nan

        # Mirror NaN pattern to ALT (correct approach)
        nan_mask = ref_np.copy()
        nan_mask[np.invert(np.isnan(nan_mask))] = 0  # Non-NaN → 0, NaN stays NaN
        alt_np = alt_np + nan_mask  # Adding NaN propagates NaN to ALT

        self._validate_size(ref_np, alt_np)

        # Convert back to torch if needed
        if is_torch:
            ref_np = torch.from_numpy(ref_np).to(ref_pred.device).type(ref_pred.dtype)
            alt_np = torch.from_numpy(alt_np).to(alt_pred.device).type(alt_pred.dtype)

        return ref_np, alt_np

    def align_bnd_matrices(
        self,
        left_ref: Union[np.ndarray, "torch.Tensor"],
        right_ref: Union[np.ndarray, "torch.Tensor"],
        bnd_alt: Union[np.ndarray, "torch.Tensor"],
        breakpoint_bin: int,
    ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
        """
        Align matrices for breakends (chromosomal rearrangements).

        BNDs join two distant loci, so we create a chimeric reference
        matrix from the two separate loci.

        Args:
            left_ref: Prediction from left locus
            right_ref: Prediction from right locus
            bnd_alt: Prediction from joined (alternate) sequence
            breakpoint_bin: Bin position of breakpoint

        Returns:
            Tuple of (chimeric_ref, alt) matrices
        """
        is_torch = TORCH_AVAILABLE and torch.is_tensor(left_ref)

        # Convert to numpy for manipulation
        if is_torch:
            left_np = left_ref.detach().cpu().numpy()
            right_np = right_ref.detach().cpu().numpy()
            alt_np = bnd_alt.detach().cpu().numpy()
        else:
            left_np = left_ref
            right_np = right_ref
            alt_np = bnd_alt

        # Assemble chimeric matrix from two loci
        ref_chimeric = self._assemble_chimeric_matrix(left_np, right_np, breakpoint_bin)

        self._validate_size(ref_chimeric, alt_np)

        # Convert back to torch if needed
        if is_torch:
            ref_chimeric = (
                torch.from_numpy(ref_chimeric).to(left_ref.device).type(left_ref.dtype)
            )
            alt_np = torch.from_numpy(alt_np).to(bnd_alt.device).type(bnd_alt.dtype)

        return ref_chimeric, alt_np

    def _assemble_chimeric_matrix(
        self, left_matrix: np.ndarray, right_matrix: np.ndarray, breakpoint: int
    ) -> np.ndarray:
        """
        Assemble chimeric matrix from two loci.

        Structure:
        - Upper left quadrant: left locus
        - Lower right quadrant: right locus
        - Upper right/lower left quadrants: NaN (no trans prediction)
        """
        matrix = np.zeros((self.target_size, self.target_size))

        # Fill upper left quadrant (left locus)
        matrix[:breakpoint, :breakpoint] = left_matrix[:breakpoint, :breakpoint]

        # Fill lower right quadrant (right locus)
        matrix[breakpoint:, breakpoint:] = right_matrix[breakpoint:, breakpoint:]

        # Fill transition quadrants with NaN
        matrix[:breakpoint, breakpoint:] = np.nan
        matrix[breakpoint:, :breakpoint] = np.nan

        # Mask diagonals as specified by model
        for offset in range(-self.diag_offset + 1, self.diag_offset):
            if offset < 0:
                np.fill_diagonal(matrix[abs(offset) :, :], np.nan)
            else:
                np.fill_diagonal(matrix[:, offset:], np.nan)

        return matrix

    def _validate_size(self, ref_matrix: np.ndarray, alt_matrix: np.ndarray):
        """
        Validate that matrices are the correct size.

        Args:
            ref_matrix: Reference prediction matrix
            alt_matrix: Alternate prediction matrix

        Raises:
            ValueError: If either matrix has incorrect size
        """
        if ref_matrix.shape[0] != self.target_size:
            raise ValueError(
                f"Reference matrix wrong size: {ref_matrix.shape[0]} vs {self.target_size}"
            )
        if alt_matrix.shape[0] != self.target_size:
            raise ValueError(
                f"Alternate matrix wrong size: {alt_matrix.shape[0]} vs {self.target_size}"
            )
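

# --- Illustrative sketch, not part of the released module ---
# How the BND path builds its chimeric reference map: the upper-left quadrant comes
# from the left locus, the lower-right quadrant from the right locus, and the two
# trans quadrants are NaN because no reference prediction exists for that pairing.
# Toy numbers only (6x6 map, breakpoint at bin 3, no diagonal masking assumed).
def _sketch_bnd_chimeric_matrix():
    """Hypothetical demo: assemble a chimeric REF contact map for a breakend."""
    import numpy as np

    left = np.ones((6, 6))
    right = np.full((6, 6), 2.0)
    fused_alt = np.full((6, 6), 1.5)  # pretend prediction on the fused ALT sequence
    aligner = PredictionAligner2D(target_size=6, bin_size=10, diag_offset=0)
    ref_chimeric, alt_out = aligner.align_bnd_matrices(
        left, right, fused_alt, breakpoint_bin=3
    )
    # 3x3 upper-left block = left locus, 3x3 lower-right block = right locus,
    # and the two 3x3 trans quadrants (18 cells) are NaN.
    assert ref_chimeric.shape == (6, 6)
    assert np.isnan(ref_chimeric).sum() == 18
    return ref_chimeric, alt_out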


# =============================================================================
# Utility Functions for Contact Maps
# =============================================================================


def vector_to_contact_matrix(
    vector: Union[np.ndarray, "torch.Tensor"], matrix_size: int, diag_offset: int = 0
) -> Union[np.ndarray, "torch.Tensor"]:
    """
    Convert flattened upper triangular vector to full contact matrix.

    This function reconstructs a full symmetric contact matrix from its upper
    triangular representation, following the pattern used in genomic contact map models.
    Supports diagonal masking where near-diagonal elements are excluded.

    Args:
        vector: Flattened upper triangular matrix
            Expected length depends on diag_offset:
            - diag_offset=0: matrix_size * (matrix_size + 1) / 2 (includes diagonal)
            - diag_offset=k: (matrix_size - k) * (matrix_size - k + 1) / 2
        matrix_size: Dimension of the output square matrix
        diag_offset: Diagonal offset for masking (default=0, no masking)
            diag_offset=2 means skip main diagonal and first off-diagonal

    Returns:
        Full symmetric contact matrix of shape (matrix_size, matrix_size)
        Elements within diag_offset of the diagonal are set to NaN

    Example:
        >>> # Full upper triangle (diag_offset=0, default)
        >>> vector = np.array([1, 2, 3, 4, 5, 6])
        >>> matrix = vector_to_contact_matrix(vector, 3)
        >>> # Result: [[1, 2, 3], [2, 4, 5], [3, 5, 6]]
    """
    # Handle both PyTorch tensors and NumPy arrays
    is_torch = TORCH_AVAILABLE and torch.is_tensor(vector)

    if is_torch:
        # Initialize matrix with NaN
        matrix = torch.full(
            (matrix_size, matrix_size),
            float("nan"),
            dtype=vector.dtype,
            device=vector.device,
        )
        # Get upper triangle indices with diagonal offset
        triu_indices = torch.triu_indices(matrix_size, matrix_size, offset=diag_offset)
        matrix[triu_indices[0], triu_indices[1]] = vector
        # Make symmetric by copying upper triangle to lower (preserving NaN)
        for i in range(matrix_size):
            for j in range(i + diag_offset, matrix_size):
                if not torch.isnan(matrix[i, j]):
                    matrix[j, i] = matrix[i, j]
    else:
        # Initialize matrix with NaN
        matrix = np.full((matrix_size, matrix_size), np.nan, dtype=vector.dtype)
        # Get upper triangle indices with diagonal offset
        triu_indices = np.triu_indices(matrix_size, k=diag_offset)
        matrix[triu_indices] = vector
        # Make symmetric by copying upper triangle to lower (preserving NaN)
        for i in range(matrix_size):
            for j in range(i + diag_offset, matrix_size):
                if not np.isnan(matrix[i, j]):
                    matrix[j, i] = matrix[i, j]

    return matrix


def contact_matrix_to_vector(
    matrix: Union[np.ndarray, "torch.Tensor"], diag_offset: int = 0
) -> Union[np.ndarray, "torch.Tensor"]:
    """
    Convert full contact matrix to flattened upper triangular vector.

    This function extracts the upper triangular portion of a contact matrix,
    which is the standard representation for genomic contact maps. Supports
    diagonal masking to exclude near-diagonal elements.

    Args:
        matrix: Full symmetric contact matrix of shape (N, N)
        diag_offset: Diagonal offset for extraction (default=0, includes diagonal)
            diag_offset=2 means skip main diagonal and first off-diagonal

    Returns:
        Flattened upper triangular vector
        Length depends on diag_offset:
        - diag_offset=0: N * (N + 1) / 2 (includes diagonal)
        - diag_offset=k: (N - k) * (N - k + 1) / 2

    Example:
        >>> # Full upper triangle (diag_offset=0, default)
        >>> matrix = np.array([[1, 2, 3], [2, 4, 5], [3, 5, 6]])
        >>> vector = contact_matrix_to_vector(matrix)
        >>> # Result: [1, 2, 3, 4, 5, 6]
    """
    # Handle both PyTorch tensors and NumPy arrays
    is_torch = TORCH_AVAILABLE and torch.is_tensor(matrix)

    if is_torch:
        triu_indices = torch.triu_indices(
            matrix.shape[0], matrix.shape[1], offset=diag_offset
        )
        return matrix[triu_indices[0], triu_indices[1]]
    else:
        triu_indices = np.triu_indices(matrix.shape[0], k=diag_offset)
        return matrix[triu_indices]
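

# --- Illustrative sketch, not part of the released module ---
# Round trip between the flattened upper-triangular form and the full symmetric map.
# With diag_offset=1 on a 4x4 map the vector holds (4-1)*(4-1+1)/2 = 6 values; after
# reconstruction only the main diagonal stays NaN, and flattening recovers the input.
def _sketch_contact_map_round_trip():
    """Hypothetical demo: vector -> matrix -> vector with a masked main diagonal."""
    import numpy as np

    vec = np.arange(1.0, 7.0)  # six upper-triangle values (float so the NaN fill works)
    mat = vector_to_contact_matrix(vec, matrix_size=4, diag_offset=1)
    back = contact_matrix_to_vector(mat, diag_offset=1)
    assert mat.shape == (4, 4)
    assert np.isnan(mat).sum() == 4  # only the masked main diagonal remains NaN
    assert np.allclose(back, vec)
    return mat, back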


def align_predictions_by_coordinate(
    ref_preds: Union[np.ndarray, "torch.Tensor"],
    alt_preds: Union[np.ndarray, "torch.Tensor"],
    metadata_row: dict,
    bin_size: int,
    prediction_type: str,
    matrix_size: Optional[int] = None,
    diag_offset: int = 0,
) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
    """
    Align reference and alt predictions using coordinate transformation and variant type awareness.

    This is the main public API for prediction alignment. It handles both 1D prediction
    vectors (e.g., chromatin accessibility, TF binding) and 2D matrices (e.g., Hi-C contact maps),
    routing to the appropriate alignment strategy based on variant type.

    IMPORTANT: Model-specific parameters (bin_size, matrix_size) must be explicitly
    provided by the user. There are no defaults because these vary across different models.

    Args:
        ref_preds: Reference predictions array (from model with edge cropping)
        alt_preds: Alt predictions array (same shape as ref_preds)
        metadata_row: Dictionary with variant information containing:
            - 'variant_type': Type of variant (SNV, INS, DEL, DUP, INV, BND)
            - 'window_start': Start position of window (0-based)
            - 'variant_pos0': Variant position (0-based, absolute genomic coordinate)
            - 'svlen': Length of structural variant (optional, for symbolic alleles)
        bin_size: Number of base pairs per prediction bin (REQUIRED, model-specific)
            Examples: 2048 for Akita
        prediction_type: Type of predictions ("1D" or "2D")
            - "1D": Vector predictions (chromatin accessibility, TF binding, etc.)
            - "2D": Matrix predictions (Hi-C contact maps, Micro-C, etc.)
        matrix_size: Size of contact matrix (REQUIRED for 2D type)
            Examples: 448 for Akita
        diag_offset: Number of diagonal bins to mask (default: 0 for no masking)
            Set to 0 if your model doesn't mask diagonals, or to model-specific value
            Examples: 2 for Akita, 0 for models without diagonal masking

    Returns:
        Tuple of (aligned_ref_preds, aligned_alt_preds) with NaN masking applied
        at positions that differ between reference and alternate sequences

    Raises:
        ValueError: If prediction_type is invalid, required parameters are missing,
            or variant type is unsupported

    Example (1D predictions):
        >>> ref_aligned, alt_aligned = align_predictions_by_coordinate(
        ...     ref_preds=ref_scores,
        ...     alt_preds=alt_scores,
        ...     metadata_row={'variant_type': 'INS', 'window_start': 0,
        ...                   'variant_pos0': 500, 'svlen': 100},
        ...     bin_size=128,
        ...     prediction_type="1D"
        ... )

    Example (2D contact maps with diagonal masking):
        >>> ref_aligned, alt_aligned = align_predictions_by_coordinate(
        ...     ref_preds=ref_contact_map,
        ...     alt_preds=alt_contact_map,
        ...     metadata_row={'variant_type': 'DEL', 'window_start': 0,
        ...                   'variant_pos0': 50000, 'svlen': -2048},
        ...     bin_size=2048,
        ...     prediction_type="2D",
        ...     matrix_size=448,
        ...     diag_offset=2  # Optional: use 0 if no diagonal masking
        ... )

    Example (2D contact maps without diagonal masking):
        >>> ref_aligned, alt_aligned = align_predictions_by_coordinate(
        ...     ref_preds=ref_contact_map,
        ...     alt_preds=alt_contact_map,
        ...     metadata_row={'variant_type': 'INS', 'window_start': 0,
        ...                   'variant_pos0': 1000, 'svlen': 500},
        ...     bin_size=1000,
        ...     prediction_type="2D",
        ...     matrix_size=512
        ...     # diag_offset defaults to 0 (no masking)
        ... )
    """
    # Validate prediction type and parameters
    if prediction_type not in ["1D", "2D"]:
        raise ValueError(
            f"prediction_type must be '1D' or '2D', got '{prediction_type}'"
        )

    if prediction_type == "2D":
        if matrix_size is None:
            raise ValueError("matrix_size must be provided for 2D prediction type")

    # Extract variant information from metadata
    variant_type = metadata_row.get("variant_type", "unknown")
    window_start = metadata_row.get("window_start", 0)
    variant_pos0 = metadata_row.get("variant_pos0")

    # For backward compatibility, check for effective_variant_start (deprecated)
    if variant_pos0 is not None:
        abs_variant_pos = variant_pos0
    else:
        # Fall back to old field name if present (for backward compatibility)
        effective_variant_start = metadata_row.get("effective_variant_start", 0)
        abs_variant_pos = window_start + effective_variant_start

    svlen = metadata_row.get("svlen", None)

    # Calculate svlen from alleles if not present (for non-symbolic DEL/INS variants)
    # Symbolic variants (SV_DEL, SV_INS, SV_INV, SV_DUP) have svlen in INFO field
    # Regular variants (DEL, INS) need svlen calculated from allele lengths
    if svlen is None or svlen == 0:
        ref_allele = metadata_row.get("ref", "")
        alt_allele = metadata_row.get("alt", "")
        if ref_allele and alt_allele:
            # svlen = len(alt) - len(ref)
            # For DEL: negative (e.g., 1 - 13 = -12)
            # For INS: positive (e.g., 7 - 1 = 6)
            svlen = len(alt_allele) - len(ref_allele)

    # Create VariantPosition object
    var_pos = VariantPosition(
        ref_pos=abs_variant_pos,
        alt_pos=abs_variant_pos,
        svlen=svlen if svlen is not None else 0,
        variant_type=variant_type,
    )

    # Determine target size from predictions
    if prediction_type == "1D":
        # Check if predictions are multi-dimensional (multiple targets)
        is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_preds)
        is_numpy = isinstance(ref_preds, np.ndarray)

        if is_torch:
            ndim = len(ref_preds.shape)
        elif is_numpy:
            ndim = ref_preds.ndim
        else:
            ndim = 1

        # Handle multi-target predictions [n_targets, n_bins]
        if ndim > 1:
            target_size = ref_preds.shape[-1]  # Number of bins
            aligner = PredictionAligner1D(target_size=target_size, bin_size=bin_size)

            # Align each target separately
            n_targets = ref_preds.shape[0]
            ref_aligned_list = []
            alt_aligned_list = []

            for target_idx in range(n_targets):
                ref_target = ref_preds[target_idx]
                alt_target = alt_preds[target_idx]
                ref_aligned, alt_aligned = aligner.align_predictions(
                    ref_target, alt_target, variant_type, var_pos
                )
                ref_aligned_list.append(ref_aligned)
                alt_aligned_list.append(alt_aligned)

            # Stack results back into multi-target format
            if is_torch:
                ref_result = torch.stack(ref_aligned_list)
                alt_result = torch.stack(alt_aligned_list)
            else:
                ref_result = np.stack(ref_aligned_list)
                alt_result = np.stack(alt_aligned_list)

            return ref_result, alt_result
        else:
            # Single target prediction [n_bins]
            target_size = len(ref_preds)
            aligner = PredictionAligner1D(target_size=target_size, bin_size=bin_size)
            return aligner.align_predictions(
                ref_preds, alt_preds, variant_type, var_pos
            )
    else:  # 2D
        # Check if predictions are 1D (flattened upper triangular) or 2D (full matrix)
        is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_preds)

        if is_torch:
            ndim = len(ref_preds.shape)
        else:
            ndim = ref_preds.ndim

        # If 1D, convert to 2D matrix
        if ndim == 1:
            ref_matrix = vector_to_contact_matrix(
                ref_preds, matrix_size, diag_offset=diag_offset
            )
            alt_matrix = vector_to_contact_matrix(
                alt_preds, matrix_size, diag_offset=diag_offset
            )

            # Align matrices
            aligner = PredictionAligner2D(
                target_size=matrix_size, bin_size=bin_size, diag_offset=diag_offset
            )
            aligned_ref_matrix, aligned_alt_matrix = aligner.align_predictions(
                ref_matrix, alt_matrix, variant_type, var_pos
            )

            # Convert back to flattened format
            aligned_ref_vector = contact_matrix_to_vector(
                aligned_ref_matrix, diag_offset=diag_offset
            )
            aligned_alt_vector = contact_matrix_to_vector(
                aligned_alt_matrix, diag_offset=diag_offset
            )

            return aligned_ref_vector, aligned_alt_vector
        else:
            # Already 2D matrices
            aligner = PredictionAligner2D(
                target_size=matrix_size, bin_size=bin_size, diag_offset=diag_offset
            )
            return aligner.align_predictions(
                ref_preds, alt_preds, variant_type, var_pos
            )
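

# --- Illustrative sketch, not part of the released module ---
# Putting the public entry point together on toy data: a 20 bp deletion described by a
# minimal metadata row (svlen is derived from the 'ref'/'alt' allele lengths, as the
# docstring above describes). The aligned ALT track receives NaN bins where the deleted
# reference sequence has no counterpart, while the aligned REF track is returned intact.
def _sketch_align_by_coordinate_deletion():
    """Hypothetical demo: end-to-end 1D alignment for a 20 bp deletion."""
    import numpy as np

    ref_scores = np.linspace(0.0, 1.0, 8)
    alt_scores = np.linspace(0.0, 1.0, 8) * 0.9
    metadata_row = {
        "variant_type": "DEL",
        "window_start": 0,
        "variant_pos0": 40,
        "ref": "A" * 21,  # 21 bp REF allele
        "alt": "A",       # 1 bp ALT allele -> svlen = 1 - 21 = -20
    }
    ref_aligned, alt_aligned = align_predictions_by_coordinate(
        ref_preds=ref_scores,
        alt_preds=alt_scores,
        metadata_row=metadata_row,
        bin_size=10,
        prediction_type="1D",
    )
    assert len(ref_aligned) == len(alt_aligned) == 8
    assert np.isnan(ref_aligned).sum() == 0
    assert np.isnan(alt_aligned).sum() == 2  # two bins deleted relative to REF
    return ref_aligned, alt_aligned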