supremo-lite 0.5.4__py3-none-any.whl → 0.5.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supremo_lite/__init__.py +1 -1
- supremo_lite/mock_models/testmodel_2d.py +7 -3
- supremo_lite/mutagenesis.py +16 -8
- supremo_lite/personalize.py +25 -17
- supremo_lite/prediction_alignment.py +123 -33
- {supremo_lite-0.5.4.dist-info → supremo_lite-0.5.5.dist-info}/METADATA +4 -6
- supremo_lite-0.5.5.dist-info/RECORD +15 -0
- supremo_lite-0.5.4.dist-info/RECORD +0 -15
- {supremo_lite-0.5.4.dist-info → supremo_lite-0.5.5.dist-info}/WHEEL +0 -0
- {supremo_lite-0.5.4.dist-info → supremo_lite-0.5.5.dist-info}/licenses/LICENSE +0 -0
supremo_lite/__init__.py
CHANGED
|
@@ -52,7 +52,7 @@ from .prediction_alignment import align_predictions_by_coordinate
|
|
|
52
52
|
# This allows users who don't have PyTorch to still use the main package
|
|
53
53
|
|
|
54
54
|
# Version
|
|
55
|
-
__version__ = "0.5.4"
|
|
55
|
+
__version__ = "0.5.5"
|
|
56
56
|
# Package metadata
|
|
57
57
|
__description__ = (
|
|
58
58
|
"A module for generating personalized genome sequences and in-silico mutagenesis"
|
|
@@ -126,9 +126,13 @@ if TORCH_AVAILABLE:
|
|
|
126
126
|
)
|
|
127
127
|
|
|
128
128
|
# Crop bins from all edges to focus loss function
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
129
|
+
if self.crop_bins > 0:
|
|
130
|
+
y_hat = y_hat[
|
|
131
|
+
:,
|
|
132
|
+
:,
|
|
133
|
+
self.crop_bins : -self.crop_bins,
|
|
134
|
+
self.crop_bins : -self.crop_bins,
|
|
135
|
+
]
|
|
132
136
|
|
|
133
137
|
# Return full contact matrix
|
|
134
138
|
return y_hat
|
supremo_lite/mutagenesis.py
CHANGED
|
@@ -146,7 +146,8 @@ def get_sm_sequences(chrom, start, end, reference_fasta, encoder=None):
|
|
|
146
146
|
|
|
147
147
|
# Create a DataFrame for the metadata
|
|
148
148
|
metadata_df = pd.DataFrame(
|
|
149
|
-
metadata,
|
|
149
|
+
metadata,
|
|
150
|
+
columns=["chrom", "window_start", "window_end", "variant_pos0", "ref", "alt"],
|
|
150
151
|
)
|
|
151
152
|
|
|
152
153
|
return ref_1h, alt_seqs_stacked, metadata_df
|
|
@@ -239,9 +240,7 @@ def get_sm_subsequences(
|
|
|
239
240
|
)
|
|
240
241
|
elif not has_bed:
|
|
241
242
|
# Neither approach was specified
|
|
242
|
-
raise ValueError(
|
|
243
|
-
"Must provide either (anchor + anchor_radius) or bed_regions."
|
|
244
|
-
)
|
|
243
|
+
raise ValueError("Must provide either (anchor + anchor_radius) or bed_regions.")
|
|
245
244
|
|
|
246
245
|
alt_seqs = []
|
|
247
246
|
metadata = []
|
|
@@ -331,7 +330,11 @@ def get_sm_subsequences(
|
|
|
331
330
|
|
|
332
331
|
# Adjust window to stay within chromosome bounds
|
|
333
332
|
chrom_obj = reference_fasta[chrom]
|
|
334
|
-
chrom_len =
|
|
333
|
+
chrom_len = (
|
|
334
|
+
len(chrom_obj)
|
|
335
|
+
if hasattr(chrom_obj, "__len__")
|
|
336
|
+
else len(chrom_obj.seq)
|
|
337
|
+
)
|
|
335
338
|
if window_start < 0:
|
|
336
339
|
window_start = 0
|
|
337
340
|
window_end = min(seq_len, chrom_len)
|
|
@@ -377,13 +380,17 @@ def get_sm_subsequences(
|
|
|
377
380
|
# Create a clone and substitute the base
|
|
378
381
|
if TORCH_AVAILABLE and isinstance(region_1h, torch.Tensor):
|
|
379
382
|
alt_1h = region_1h.clone()
|
|
380
|
-
alt_1h[:, i] = torch.tensor(
|
|
383
|
+
alt_1h[:, i] = torch.tensor(
|
|
384
|
+
nt_to_1h[alt], dtype=alt_1h.dtype
|
|
385
|
+
)
|
|
381
386
|
else:
|
|
382
387
|
alt_1h = region_1h.copy()
|
|
383
388
|
alt_1h[:, i] = nt_to_1h[alt]
|
|
384
389
|
|
|
385
390
|
alt_seqs.append(alt_1h)
|
|
386
|
-
metadata.append(
|
|
391
|
+
metadata.append(
|
|
392
|
+
[chrom, window_start, window_end, i, ref_nt, alt]
|
|
393
|
+
)
|
|
387
394
|
|
|
388
395
|
# If no regions were processed, create empty ref_1h
|
|
389
396
|
if ref_1h is None:
|
|
@@ -408,7 +415,8 @@ def get_sm_subsequences(
|
|
|
408
415
|
|
|
409
416
|
# Create a DataFrame for the metadata
|
|
410
417
|
metadata_df = pd.DataFrame(
|
|
411
|
-
metadata,
|
|
418
|
+
metadata,
|
|
419
|
+
columns=["chrom", "window_start", "window_end", "variant_pos0", "ref", "alt"],
|
|
412
420
|
)
|
|
413
421
|
|
|
414
422
|
return ref_1h, alt_seqs_stacked, metadata_df
|
supremo_lite/personalize.py
CHANGED
|
@@ -42,7 +42,7 @@ IUPAC_CODES = {
|
|
|
42
42
|
"D": "[AGT]",
|
|
43
43
|
"H": "[ACT]",
|
|
44
44
|
"V": "[ACG]",
|
|
45
|
-
"N": "[ACGT]"
|
|
45
|
+
"N": "[ACGT]",
|
|
46
46
|
}
|
|
47
47
|
|
|
48
48
|
|
|
@@ -2811,6 +2811,7 @@ def get_pam_disrupting_alt_sequences(
|
|
|
2811
2811
|
... ref, vcf, seq_len=50, max_pam_distance=10, n_chunks=5):
|
|
2812
2812
|
... predictions = model.predict(alt_seqs, ref_seqs)
|
|
2813
2813
|
"""
|
|
2814
|
+
|
|
2814
2815
|
# Helper function to find PAM sites in a sequence
|
|
2815
2816
|
def _find_pam_sites(sequence, pam_pattern):
|
|
2816
2817
|
"""Find all PAM site positions in a sequence using IUPAC codes.
|
|
@@ -2830,7 +2831,7 @@ def get_pam_disrupting_alt_sequences(
|
|
|
2830
2831
|
pat_base = pat_upper[j]
|
|
2831
2832
|
|
|
2832
2833
|
# Sequence 'N' (padding or unknown) matches any pattern base
|
|
2833
|
-
if seq_base ==
|
|
2834
|
+
if seq_base == "N":
|
|
2834
2835
|
continue # Always matches
|
|
2835
2836
|
|
|
2836
2837
|
# Get allowed bases for this pattern position
|
|
@@ -3003,9 +3004,7 @@ def get_pam_disrupting_alt_sequences(
|
|
|
3003
3004
|
ref_allele = var.get("ref", "")
|
|
3004
3005
|
alt_allele = var.get("alt", "")
|
|
3005
3006
|
is_indel = (
|
|
3006
|
-
len(ref_allele) != len(alt_allele)
|
|
3007
|
-
or ref_allele == "-"
|
|
3008
|
-
or alt_allele == "-"
|
|
3007
|
+
len(ref_allele) != len(alt_allele) or ref_allele == "-" or alt_allele == "-"
|
|
3009
3008
|
)
|
|
3010
3009
|
|
|
3011
3010
|
truly_disrupted_pam_sites = []
|
|
@@ -3053,8 +3052,12 @@ def get_pam_disrupting_alt_sequences(
|
|
|
3053
3052
|
# For each disrupted PAM site, create a metadata entry
|
|
3054
3053
|
for pam_site_pos in truly_disrupted_pam_sites:
|
|
3055
3054
|
# Extract PAM sequences
|
|
3056
|
-
ref_pam_seq = ref_window_seq[
|
|
3057
|
-
|
|
3055
|
+
ref_pam_seq = ref_window_seq[
|
|
3056
|
+
pam_site_pos : pam_site_pos + len(pam_sequence)
|
|
3057
|
+
]
|
|
3058
|
+
alt_pam_seq = modified_window[
|
|
3059
|
+
pam_site_pos : pam_site_pos + len(pam_sequence)
|
|
3060
|
+
]
|
|
3058
3061
|
|
|
3059
3062
|
# Calculate distance from variant to PAM
|
|
3060
3063
|
pam_distance = abs(pam_site_pos - variant_pos_in_window)
|
|
@@ -3063,12 +3066,14 @@ def get_pam_disrupting_alt_sequences(
|
|
|
3063
3066
|
pam_disrupting_variants_list.append(var)
|
|
3064
3067
|
|
|
3065
3068
|
# Store PAM-specific metadata
|
|
3066
|
-
pam_metadata_list.append(
|
|
3067
|
-
|
|
3068
|
-
|
|
3069
|
-
|
|
3070
|
-
|
|
3071
|
-
|
|
3069
|
+
pam_metadata_list.append(
|
|
3070
|
+
{
|
|
3071
|
+
"pam_site_pos": pam_site_pos,
|
|
3072
|
+
"pam_ref_sequence": ref_pam_seq,
|
|
3073
|
+
"pam_alt_sequence": alt_pam_seq,
|
|
3074
|
+
"pam_distance": pam_distance,
|
|
3075
|
+
}
|
|
3076
|
+
)
|
|
3072
3077
|
|
|
3073
3078
|
# If no PAM-disrupting variants found, yield empty results
|
|
3074
3079
|
if not pam_disrupting_variants_list:
|
|
@@ -3076,7 +3081,9 @@ def get_pam_disrupting_alt_sequences(
|
|
|
3076
3081
|
return
|
|
3077
3082
|
|
|
3078
3083
|
# Create DataFrame with filtered PAM-disrupting variants
|
|
3079
|
-
filtered_variants_df = pd.DataFrame(pam_disrupting_variants_list).reset_index(
|
|
3084
|
+
filtered_variants_df = pd.DataFrame(pam_disrupting_variants_list).reset_index(
|
|
3085
|
+
drop=True
|
|
3086
|
+
)
|
|
3080
3087
|
pam_metadata_df = pd.DataFrame(pam_metadata_list)
|
|
3081
3088
|
|
|
3082
3089
|
# Call get_alt_ref_sequences with the filtered variants
|
|
@@ -3087,12 +3094,13 @@ def get_pam_disrupting_alt_sequences(
|
|
|
3087
3094
|
encode,
|
|
3088
3095
|
n_chunks,
|
|
3089
3096
|
encoder,
|
|
3090
|
-
auto_map_chromosomes
|
|
3097
|
+
auto_map_chromosomes,
|
|
3091
3098
|
):
|
|
3092
3099
|
# Merge PAM-specific metadata with base metadata
|
|
3093
3100
|
# Both should have the same number of rows since we created one entry per PAM site
|
|
3094
|
-
enriched_metadata = pd.concat(
|
|
3095
|
-
|
|
3101
|
+
enriched_metadata = pd.concat(
|
|
3102
|
+
[base_metadata.reset_index(drop=True), pam_metadata_df], axis=1
|
|
3103
|
+
)
|
|
3096
3104
|
|
|
3097
3105
|
# Yield the chunk with enriched metadata
|
|
3098
3106
|
yield (alt_seqs, ref_seqs, enriched_metadata)
|
|
@@ -40,19 +40,49 @@ class VariantPosition:
|
|
|
40
40
|
svlen: int # Length of structural variant (base pairs, signed for DEL/INS)
|
|
41
41
|
variant_type: str # Type of variant ('SNV', 'INS', 'DEL', 'DUP', 'INV', 'BND')
|
|
42
42
|
|
|
43
|
-
def get_bin_positions(
|
|
43
|
+
def get_bin_positions(
|
|
44
|
+
self, bin_size: int, window_start: int, crop_length: int
|
|
45
|
+
) -> Tuple[int, int, int]:
|
|
44
46
|
"""
|
|
45
|
-
Convert base pair positions to bin indices.
|
|
47
|
+
Convert base pair positions to bin indices relative to window.
|
|
46
48
|
|
|
47
49
|
Args:
|
|
48
50
|
bin_size: Number of base pairs per prediction bin
|
|
51
|
+
window_start: Start position of the sequence window (0-based genomic coord).
|
|
52
|
+
crop_length: Number of base pairs cropped from each edge by the model.
|
|
53
|
+
This accounts for edge bases removed before prediction.
|
|
49
54
|
|
|
50
55
|
Returns:
|
|
51
|
-
Tuple of (ref_bin, alt_start_bin, alt_end_bin)
|
|
56
|
+
Tuple of (ref_bin, alt_start_bin, alt_end_bin) as bin indices
|
|
57
|
+
relative to the prediction vector. For centered masking, these
|
|
58
|
+
represent the center and extent of the masked region.
|
|
59
|
+
|
|
60
|
+
Notes:
|
|
61
|
+
- Positions are calculated relative to window_start, not absolute genomic coords
|
|
62
|
+
- crop_length accounts for edge bases removed before prediction
|
|
63
|
+
- Masked bins are centered on the variant position
|
|
52
64
|
"""
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
65
|
+
# Calculate positions relative to window
|
|
66
|
+
rel_ref_pos = self.ref_pos - window_start
|
|
67
|
+
rel_alt_pos = self.alt_pos - window_start
|
|
68
|
+
|
|
69
|
+
# Account for cropping (bases removed from start of window before prediction)
|
|
70
|
+
rel_ref_pos -= crop_length
|
|
71
|
+
rel_alt_pos -= crop_length
|
|
72
|
+
|
|
73
|
+
# Convert to bin indices using floor division (not ceil!)
|
|
74
|
+
ref_bin_center = rel_ref_pos // bin_size
|
|
75
|
+
alt_bin_center = rel_alt_pos // bin_size
|
|
76
|
+
|
|
77
|
+
# Calculate number of bins to mask
|
|
78
|
+
svlen_bins = int(np.ceil(abs(self.svlen) / bin_size))
|
|
79
|
+
half_bins = svlen_bins // 2
|
|
80
|
+
|
|
81
|
+
# Center the masked region on the variant
|
|
82
|
+
ref_bin = ref_bin_center - half_bins
|
|
83
|
+
alt_start_bin = alt_bin_center - half_bins
|
|
84
|
+
alt_end_bin = alt_bin_center + (svlen_bins - half_bins)
|
|
85
|
+
|
|
56
86
|
return ref_bin, alt_start_bin, alt_end_bin
|
|
57
87
|
|
|
58
88
|
|
|
@@ -71,24 +101,27 @@ class PredictionAligner1D:
|
|
|
71
101
|
Args:
|
|
72
102
|
target_size: Expected number of bins in the prediction output
|
|
73
103
|
bin_size: Number of base pairs per prediction bin (model-specific)
|
|
104
|
+
crop_length: Number of base pairs cropped from each edge by the model
|
|
74
105
|
|
|
75
106
|
Example:
|
|
76
|
-
>>> aligner = PredictionAligner1D(target_size=896, bin_size=128)
|
|
107
|
+
>>> aligner = PredictionAligner1D(target_size=896, bin_size=128, crop_length=0)
|
|
77
108
|
>>> ref_aligned, alt_aligned = aligner.align_predictions(
|
|
78
109
|
... ref_pred, alt_pred, 'INS', variant_position
|
|
79
110
|
... )
|
|
80
111
|
"""
|
|
81
112
|
|
|
82
|
-
def __init__(self, target_size: int, bin_size: int):
|
|
113
|
+
def __init__(self, target_size: int, bin_size: int, crop_length: int):
|
|
83
114
|
"""
|
|
84
115
|
Initialize the 1D prediction aligner.
|
|
85
116
|
|
|
86
117
|
Args:
|
|
87
118
|
target_size: Expected number of bins in prediction (e.g., 896 for Enformer)
|
|
88
119
|
bin_size: Base pairs per bin (e.g., 128 for Enformer)
|
|
120
|
+
crop_length: Number of base pairs cropped from each edge by the model
|
|
89
121
|
"""
|
|
90
122
|
self.target_size = target_size
|
|
91
123
|
self.bin_size = bin_size
|
|
124
|
+
self.crop_length = crop_length
|
|
92
125
|
|
|
93
126
|
def align_predictions(
|
|
94
127
|
self,
|
|
@@ -96,6 +129,7 @@ class PredictionAligner1D:
|
|
|
96
129
|
alt_pred: Union[np.ndarray, "torch.Tensor"],
|
|
97
130
|
svtype: str,
|
|
98
131
|
var_pos: VariantPosition,
|
|
132
|
+
window_start: int = 0,
|
|
99
133
|
) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
|
|
100
134
|
"""
|
|
101
135
|
Main entry point for 1D prediction alignment.
|
|
@@ -105,6 +139,8 @@ class PredictionAligner1D:
|
|
|
105
139
|
alt_pred: Alternate prediction vector (length N)
|
|
106
140
|
svtype: Variant type ('DEL', 'DUP', 'INS', 'INV', 'SNV')
|
|
107
141
|
var_pos: Variant position information
|
|
142
|
+
window_start: Start position of sequence window (0-based genomic coord).
|
|
143
|
+
Required for correct bin calculation. Defaults to 0.
|
|
108
144
|
|
|
109
145
|
Returns:
|
|
110
146
|
Tuple of (aligned_ref, aligned_alt) vectors with NaN masking applied
|
|
@@ -120,10 +156,12 @@ class PredictionAligner1D:
|
|
|
120
156
|
|
|
121
157
|
if svtype_normalized in ["DEL", "DUP", "INS"]:
|
|
122
158
|
return self._align_indel_predictions(
|
|
123
|
-
ref_pred, alt_pred, svtype_normalized, var_pos
|
|
159
|
+
ref_pred, alt_pred, svtype_normalized, var_pos, window_start
|
|
124
160
|
)
|
|
125
161
|
elif svtype_normalized == "INV":
|
|
126
|
-
return self._align_inversion_predictions(
|
|
162
|
+
return self._align_inversion_predictions(
|
|
163
|
+
ref_pred, alt_pred, var_pos, window_start
|
|
164
|
+
)
|
|
127
165
|
elif svtype_normalized in ["SNV", "MNV"]:
|
|
128
166
|
# SNVs don't change coordinates, direct alignment
|
|
129
167
|
is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
|
|
@@ -140,18 +178,22 @@ class PredictionAligner1D:
|
|
|
140
178
|
alt_pred: Union[np.ndarray, "torch.Tensor"],
|
|
141
179
|
svtype: str,
|
|
142
180
|
var_pos: VariantPosition,
|
|
181
|
+
window_start: int = 0,
|
|
143
182
|
) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
|
|
144
183
|
"""
|
|
145
184
|
Align predictions for insertions, deletions, and duplications.
|
|
146
185
|
|
|
147
186
|
Strategy:
|
|
148
187
|
1. For DEL: Swap REF/ALT (deletion removes from REF)
|
|
149
|
-
2. Insert NaN bins in shorter sequence
|
|
188
|
+
2. Insert NaN bins in shorter sequence (centered on variant)
|
|
150
189
|
3. Crop edges to maintain target size
|
|
151
190
|
4. For DEL: Swap back
|
|
152
191
|
|
|
153
192
|
This ensures that positions present in one sequence but not the other
|
|
154
193
|
are marked with NaN, enabling fair comparison of overlapping regions.
|
|
194
|
+
|
|
195
|
+
Args:
|
|
196
|
+
window_start: Start position of sequence window (0-based genomic coord)
|
|
155
197
|
"""
|
|
156
198
|
is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
|
|
157
199
|
|
|
@@ -170,8 +212,10 @@ class PredictionAligner1D:
|
|
|
170
212
|
var_pos.alt_pos, var_pos.ref_pos, var_pos.svlen, svtype
|
|
171
213
|
)
|
|
172
214
|
|
|
173
|
-
# Get bin positions
|
|
174
|
-
ref_bin, alt_start_bin, alt_end_bin = var_pos.get_bin_positions(
|
|
215
|
+
# Get bin positions (window-relative, centered)
|
|
216
|
+
ref_bin, alt_start_bin, alt_end_bin = var_pos.get_bin_positions(
|
|
217
|
+
self.bin_size, window_start, self.crop_length
|
|
218
|
+
)
|
|
175
219
|
bins_to_add = alt_end_bin - alt_start_bin
|
|
176
220
|
|
|
177
221
|
# Insert NaN bins in REF where variant exists in ALT
|
|
@@ -248,6 +292,7 @@ class PredictionAligner1D:
|
|
|
248
292
|
ref_pred: Union[np.ndarray, "torch.Tensor"],
|
|
249
293
|
alt_pred: Union[np.ndarray, "torch.Tensor"],
|
|
250
294
|
var_pos: VariantPosition,
|
|
295
|
+
window_start: int = 0,
|
|
251
296
|
) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
|
|
252
297
|
"""
|
|
253
298
|
Align predictions for inversions.
|
|
@@ -259,6 +304,9 @@ class PredictionAligner1D:
|
|
|
259
304
|
For strand-aware models, inversions can significantly affect predictions
|
|
260
305
|
because regulatory elements now appear on the opposite strand. We mask
|
|
261
306
|
the inverted region to focus comparison on unaffected flanking sequences.
|
|
307
|
+
|
|
308
|
+
Args:
|
|
309
|
+
window_start: Start position of sequence window (0-based genomic coord)
|
|
262
310
|
"""
|
|
263
311
|
is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
|
|
264
312
|
|
|
@@ -270,7 +318,9 @@ class PredictionAligner1D:
|
|
|
270
318
|
ref_np = ref_pred.copy()
|
|
271
319
|
alt_np = alt_pred.copy()
|
|
272
320
|
|
|
273
|
-
var_start, _, var_end = var_pos.get_bin_positions(
|
|
321
|
+
var_start, _, var_end = var_pos.get_bin_positions(
|
|
322
|
+
self.bin_size, window_start, self.crop_length
|
|
323
|
+
)
|
|
274
324
|
|
|
275
325
|
# Mask inverted region in both REF and ALT
|
|
276
326
|
ref_np[var_start : var_end + 1] = np.nan
|
|
@@ -373,19 +423,23 @@ class PredictionAligner2D:
|
|
|
373
423
|
target_size: Expected matrix dimension (NxN)
|
|
374
424
|
bin_size: Number of base pairs per matrix bin (model-specific)
|
|
375
425
|
diag_offset: Number of diagonal bins to mask (model-specific)
|
|
426
|
+
crop_length: Number of base pairs cropped from each edge by the model
|
|
376
427
|
|
|
377
428
|
Example:
|
|
378
429
|
>>> aligner = PredictionAligner2D(
|
|
379
430
|
... target_size=448,
|
|
380
431
|
... bin_size=2048,
|
|
381
|
-
... diag_offset=2
|
|
432
|
+
... diag_offset=2,
|
|
433
|
+
... crop_length=0
|
|
382
434
|
... )
|
|
383
435
|
>>> ref_aligned, alt_aligned = aligner.align_predictions(
|
|
384
436
|
... ref_matrix, alt_matrix, 'DEL', variant_position
|
|
385
437
|
... )
|
|
386
438
|
"""
|
|
387
439
|
|
|
388
|
-
def __init__(
|
|
440
|
+
def __init__(
|
|
441
|
+
self, target_size: int, bin_size: int, diag_offset: int, crop_length: int
|
|
442
|
+
):
|
|
389
443
|
"""
|
|
390
444
|
Initialize the 2D prediction aligner.
|
|
391
445
|
|
|
@@ -393,10 +447,12 @@ class PredictionAligner2D:
|
|
|
393
447
|
target_size: Matrix dimension (e.g., 448 for Akita)
|
|
394
448
|
bin_size: Base pairs per bin (e.g., 2048 for Akita)
|
|
395
449
|
diag_offset: Diagonal masking offset (e.g., 2 for Akita)
|
|
450
|
+
crop_length: Number of base pairs cropped from each edge by the model
|
|
396
451
|
"""
|
|
397
452
|
self.target_size = target_size
|
|
398
453
|
self.bin_size = bin_size
|
|
399
454
|
self.diag_offset = diag_offset
|
|
455
|
+
self.crop_length = crop_length
|
|
400
456
|
|
|
401
457
|
def align_predictions(
|
|
402
458
|
self,
|
|
@@ -404,6 +460,7 @@ class PredictionAligner2D:
|
|
|
404
460
|
alt_pred: Union[np.ndarray, "torch.Tensor"],
|
|
405
461
|
svtype: str,
|
|
406
462
|
var_pos: VariantPosition,
|
|
463
|
+
window_start: int = 0,
|
|
407
464
|
) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
|
|
408
465
|
"""
|
|
409
466
|
Main entry point for 2D matrix alignment.
|
|
@@ -413,6 +470,8 @@ class PredictionAligner2D:
|
|
|
413
470
|
alt_pred: Alternate prediction matrix (NxN)
|
|
414
471
|
svtype: Variant type ('DEL', 'DUP', 'INS', 'INV', 'SNV')
|
|
415
472
|
var_pos: Variant position information
|
|
473
|
+
window_start: Start position of sequence window (0-based genomic coord).
|
|
474
|
+
Required for correct bin calculation. Defaults to 0.
|
|
416
475
|
|
|
417
476
|
Returns:
|
|
418
477
|
Tuple of (aligned_ref, aligned_alt) matrices with NaN masking applied
|
|
@@ -428,10 +487,12 @@ class PredictionAligner2D:
|
|
|
428
487
|
|
|
429
488
|
if svtype_normalized in ["DEL", "DUP", "INS"]:
|
|
430
489
|
return self._align_indel_matrices(
|
|
431
|
-
ref_pred, alt_pred, svtype_normalized, var_pos
|
|
490
|
+
ref_pred, alt_pred, svtype_normalized, var_pos, window_start
|
|
432
491
|
)
|
|
433
492
|
elif svtype_normalized == "INV":
|
|
434
|
-
return self._align_inversion_matrices(
|
|
493
|
+
return self._align_inversion_matrices(
|
|
494
|
+
ref_pred, alt_pred, var_pos, window_start
|
|
495
|
+
)
|
|
435
496
|
elif svtype_normalized in ["SNV", "MNV"]:
|
|
436
497
|
# SNVs don't change coordinates, direct alignment
|
|
437
498
|
is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
|
|
@@ -448,15 +509,19 @@ class PredictionAligner2D:
|
|
|
448
509
|
alt_pred: Union[np.ndarray, "torch.Tensor"],
|
|
449
510
|
svtype: str,
|
|
450
511
|
var_pos: VariantPosition,
|
|
512
|
+
window_start: int = 0,
|
|
451
513
|
) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
|
|
452
514
|
"""
|
|
453
515
|
Align matrices for insertions, deletions, and duplications.
|
|
454
516
|
|
|
455
517
|
Strategy:
|
|
456
518
|
1. For DEL: Swap REF/ALT (deletion removes from REF)
|
|
457
|
-
2. Insert NaN bins (rows AND columns) in shorter matrix
|
|
519
|
+
2. Insert NaN bins (rows AND columns) in shorter matrix (centered on variant)
|
|
458
520
|
3. Crop edges to maintain target size
|
|
459
521
|
4. For DEL: Swap back
|
|
522
|
+
|
|
523
|
+
Args:
|
|
524
|
+
window_start: Start position of sequence window (0-based genomic coord)
|
|
460
525
|
"""
|
|
461
526
|
is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
|
|
462
527
|
|
|
@@ -475,8 +540,10 @@ class PredictionAligner2D:
|
|
|
475
540
|
var_pos.alt_pos, var_pos.ref_pos, var_pos.svlen, svtype
|
|
476
541
|
)
|
|
477
542
|
|
|
478
|
-
# Get bin positions
|
|
479
|
-
ref_bin, alt_start_bin, alt_end_bin = var_pos.get_bin_positions(
|
|
543
|
+
# Get bin positions (window-relative, centered)
|
|
544
|
+
ref_bin, alt_start_bin, alt_end_bin = var_pos.get_bin_positions(
|
|
545
|
+
self.bin_size, window_start, self.crop_length
|
|
546
|
+
)
|
|
480
547
|
bins_to_add = alt_end_bin - alt_start_bin
|
|
481
548
|
|
|
482
549
|
# Insert NaN bins in REF where variant exists in ALT
|
|
@@ -541,6 +608,7 @@ class PredictionAligner2D:
|
|
|
541
608
|
ref_pred: Union[np.ndarray, "torch.Tensor"],
|
|
542
609
|
alt_pred: Union[np.ndarray, "torch.Tensor"],
|
|
543
610
|
var_pos: VariantPosition,
|
|
611
|
+
window_start: int = 0,
|
|
544
612
|
) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
|
|
545
613
|
"""
|
|
546
614
|
Align matrices for inversions.
|
|
@@ -556,6 +624,9 @@ class PredictionAligner2D:
|
|
|
556
624
|
|
|
557
625
|
The same NaN pattern is mirrored to ALT so both matrices have identical
|
|
558
626
|
masked regions, enabling fair comparison of the unaffected areas.
|
|
627
|
+
|
|
628
|
+
Args:
|
|
629
|
+
window_start: Start position of sequence window (0-based genomic coord)
|
|
559
630
|
"""
|
|
560
631
|
is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
|
|
561
632
|
|
|
@@ -567,7 +638,9 @@ class PredictionAligner2D:
|
|
|
567
638
|
ref_np = ref_pred.copy()
|
|
568
639
|
alt_np = alt_pred.copy()
|
|
569
640
|
|
|
570
|
-
var_start, _, var_end = var_pos.get_bin_positions(
|
|
641
|
+
var_start, _, var_end = var_pos.get_bin_positions(
|
|
642
|
+
self.bin_size, window_start, self.crop_length
|
|
643
|
+
)
|
|
571
644
|
|
|
572
645
|
# Mask inverted region in REF (cross-pattern: rows + columns)
|
|
573
646
|
ref_np[var_start : var_end + 1, :] = np.nan
|
|
@@ -802,6 +875,7 @@ def align_predictions_by_coordinate(
|
|
|
802
875
|
metadata_row: dict,
|
|
803
876
|
bin_size: int,
|
|
804
877
|
prediction_type: str,
|
|
878
|
+
crop_length: int,
|
|
805
879
|
matrix_size: Optional[int] = None,
|
|
806
880
|
diag_offset: int = 0,
|
|
807
881
|
) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
|
|
@@ -812,7 +886,7 @@ def align_predictions_by_coordinate(
|
|
|
812
886
|
vectors (e.g., chromatin accessibility, TF binding) and 2D matrices (e.g., Hi-C contact maps),
|
|
813
887
|
routing to the appropriate alignment strategy based on variant type.
|
|
814
888
|
|
|
815
|
-
IMPORTANT: Model-specific parameters (bin_size, matrix_size) must be explicitly
|
|
889
|
+
IMPORTANT: Model-specific parameters (bin_size, crop_length, matrix_size) must be explicitly
|
|
816
890
|
provided by the user. There are no defaults because these vary across different models.
|
|
817
891
|
|
|
818
892
|
Args:
|
|
@@ -824,10 +898,13 @@ def align_predictions_by_coordinate(
|
|
|
824
898
|
- 'variant_pos0': Variant position (0-based, absolute genomic coordinate)
|
|
825
899
|
- 'svlen': Length of structural variant (optional, for symbolic alleles)
|
|
826
900
|
bin_size: Number of base pairs per prediction bin (REQUIRED, model-specific)
|
|
827
|
-
Examples: 2048 for Akita
|
|
901
|
+
Examples: 2048 for Akita, 128 for Enformer
|
|
828
902
|
prediction_type: Type of predictions ("1D" or "2D")
|
|
829
903
|
- "1D": Vector predictions (chromatin accessibility, TF binding, etc.)
|
|
830
904
|
- "2D": Matrix predictions (Hi-C contact maps, Micro-C, etc.)
|
|
905
|
+
crop_length: Number of base pairs cropped from each edge by the model (REQUIRED)
|
|
906
|
+
This accounts for edge bases removed before prediction.
|
|
907
|
+
Examples: 0 for models without cropping
|
|
831
908
|
matrix_size: Size of contact matrix (REQUIRED for 2D type)
|
|
832
909
|
Examples: 448 for Akita
|
|
833
910
|
diag_offset: Number of diagonal bins to mask (default: 0 for no masking)
|
|
@@ -849,7 +926,8 @@ def align_predictions_by_coordinate(
|
|
|
849
926
|
... metadata_row={'variant_type': 'INS', 'window_start': 0,
|
|
850
927
|
... 'variant_pos0': 500, 'svlen': 100},
|
|
851
928
|
... bin_size=128,
|
|
852
|
-
... prediction_type="1D"
|
|
929
|
+
... prediction_type="1D",
|
|
930
|
+
... crop_length=0
|
|
853
931
|
... )
|
|
854
932
|
|
|
855
933
|
Example (2D contact maps with diagonal masking):
|
|
@@ -860,6 +938,7 @@ def align_predictions_by_coordinate(
|
|
|
860
938
|
... 'variant_pos0': 50000, 'svlen': -2048},
|
|
861
939
|
... bin_size=2048,
|
|
862
940
|
... prediction_type="2D",
|
|
941
|
+
... crop_length=0,
|
|
863
942
|
... matrix_size=448,
|
|
864
943
|
... diag_offset=2 # Optional: use 0 if no diagonal masking
|
|
865
944
|
... )
|
|
@@ -872,6 +951,7 @@ def align_predictions_by_coordinate(
|
|
|
872
951
|
... 'variant_pos0': 1000, 'svlen': 500},
|
|
873
952
|
... bin_size=1000,
|
|
874
953
|
... prediction_type="2D",
|
|
954
|
+
... crop_length=0,
|
|
875
955
|
... matrix_size=512
|
|
876
956
|
... # diag_offset defaults to 0 (no masking)
|
|
877
957
|
... )
|
|
@@ -937,7 +1017,9 @@ def align_predictions_by_coordinate(
|
|
|
937
1017
|
# Handle multi-target predictions [n_targets, n_bins]
|
|
938
1018
|
if ndim > 1:
|
|
939
1019
|
target_size = ref_preds.shape[-1] # Number of bins
|
|
940
|
-
aligner = PredictionAligner1D(
|
|
1020
|
+
aligner = PredictionAligner1D(
|
|
1021
|
+
target_size=target_size, bin_size=bin_size, crop_length=crop_length
|
|
1022
|
+
)
|
|
941
1023
|
|
|
942
1024
|
# Align each target separately
|
|
943
1025
|
n_targets = ref_preds.shape[0]
|
|
@@ -948,7 +1030,7 @@ def align_predictions_by_coordinate(
|
|
|
948
1030
|
ref_target = ref_preds[target_idx]
|
|
949
1031
|
alt_target = alt_preds[target_idx]
|
|
950
1032
|
ref_aligned, alt_aligned = aligner.align_predictions(
|
|
951
|
-
ref_target, alt_target, variant_type, var_pos
|
|
1033
|
+
ref_target, alt_target, variant_type, var_pos, window_start
|
|
952
1034
|
)
|
|
953
1035
|
ref_aligned_list.append(ref_aligned)
|
|
954
1036
|
alt_aligned_list.append(alt_aligned)
|
|
@@ -965,9 +1047,11 @@ def align_predictions_by_coordinate(
|
|
|
965
1047
|
else:
|
|
966
1048
|
# Single target prediction [n_bins]
|
|
967
1049
|
target_size = len(ref_preds)
|
|
968
|
-
aligner = PredictionAligner1D(
|
|
1050
|
+
aligner = PredictionAligner1D(
|
|
1051
|
+
target_size=target_size, bin_size=bin_size, crop_length=crop_length
|
|
1052
|
+
)
|
|
969
1053
|
return aligner.align_predictions(
|
|
970
|
-
ref_preds, alt_preds, variant_type, var_pos
|
|
1054
|
+
ref_preds, alt_preds, variant_type, var_pos, window_start
|
|
971
1055
|
)
|
|
972
1056
|
else: # 2D
|
|
973
1057
|
# Check if predictions are 1D (flattened upper triangular) or 2D (full matrix)
|
|
@@ -989,10 +1073,13 @@ def align_predictions_by_coordinate(
|
|
|
989
1073
|
|
|
990
1074
|
# Align matrices
|
|
991
1075
|
aligner = PredictionAligner2D(
|
|
992
|
-
target_size=matrix_size,
|
|
1076
|
+
target_size=matrix_size,
|
|
1077
|
+
bin_size=bin_size,
|
|
1078
|
+
diag_offset=diag_offset,
|
|
1079
|
+
crop_length=crop_length,
|
|
993
1080
|
)
|
|
994
1081
|
aligned_ref_matrix, aligned_alt_matrix = aligner.align_predictions(
|
|
995
|
-
ref_matrix, alt_matrix, variant_type, var_pos
|
|
1082
|
+
ref_matrix, alt_matrix, variant_type, var_pos, window_start
|
|
996
1083
|
)
|
|
997
1084
|
|
|
998
1085
|
# Convert back to flattened format
|
|
@@ -1007,8 +1094,11 @@ def align_predictions_by_coordinate(
|
|
|
1007
1094
|
else:
|
|
1008
1095
|
# Already 2D matrices
|
|
1009
1096
|
aligner = PredictionAligner2D(
|
|
1010
|
-
target_size=matrix_size,
|
|
1097
|
+
target_size=matrix_size,
|
|
1098
|
+
bin_size=bin_size,
|
|
1099
|
+
diag_offset=diag_offset,
|
|
1100
|
+
crop_length=crop_length,
|
|
1011
1101
|
)
|
|
1012
1102
|
return aligner.align_predictions(
|
|
1013
|
-
ref_preds, alt_preds, variant_type, var_pos
|
|
1103
|
+
ref_preds, alt_preds, variant_type, var_pos, window_start
|
|
1014
1104
|
)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: supremo_lite
|
|
3
|
-
Version: 0.5.4
|
|
3
|
+
Version: 0.5.5
|
|
4
4
|
Summary: A lightweight memory first, model agnostic version of SuPreMo
|
|
5
5
|
License: MIT
|
|
6
6
|
License-File: LICENSE
|
|
@@ -42,13 +42,10 @@ For the latest features and bug fixes:
|
|
|
42
42
|
|
|
43
43
|
```bash
|
|
44
44
|
# Install directly latest release
|
|
45
|
-
pip install
|
|
45
|
+
pip install supremo-lite
|
|
46
46
|
|
|
47
47
|
# Or install a specific version/tag
|
|
48
48
|
pip install git+https://github.com/gladstone-institutes/supremo_lite.git@v0.5.0
|
|
49
|
-
|
|
50
|
-
# Or install from a specific branch
|
|
51
|
-
pip install git+https://github.com/gladstone-institutes/supremo_lite.git@main
|
|
52
49
|
```
|
|
53
50
|
|
|
54
51
|
### Dependencies
|
|
@@ -60,7 +57,7 @@ Required dependencies will be installed automatically:
|
|
|
60
57
|
|
|
61
58
|
Optional dependencies:
|
|
62
59
|
- `torch` - For PyTorch tensor support (automatically detected)
|
|
63
|
-
- [https://github.com/gladstone-institutes/brisket
|
|
60
|
+
- [brisket](https://github.com/gladstone-institutes/brisket) - Cython powered faster 1 hot encoding for DNA sequences (automatically detected)
|
|
64
61
|
|
|
65
62
|
## Quick Start
|
|
66
63
|
|
|
@@ -214,3 +211,4 @@ Interested in contributing? Check out the contributing guidelines. Please note t
|
|
|
214
211
|
## Credits
|
|
215
212
|
|
|
216
213
|
`supremo_lite` was created with [`cookiecutter`](https://cookiecutter.readthedocs.io/en/latest/) and the `py-pkgs-cookiecutter` [template](https://github.com/py-pkgs/py-pkgs-cookiecutter).
|
|
214
|
+
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
supremo_lite/__init__.py,sha256=mKgstoebv8o-3d0h1RasW5OpBmUaOlxqt7pJNYcMxRU,1660
|
|
2
|
+
supremo_lite/chromosome_utils.py,sha256=rOjS3IQXmjBYZG949C5eG1zWprTOdhPVg4oW8GLbCek,10938
|
|
3
|
+
supremo_lite/core.py,sha256=-OmQEAS5J2rhocXC4aHYWaZZE2N8MX942SePMLoQ6Xc,1190
|
|
4
|
+
supremo_lite/mock_models/__init__.py,sha256=YQcL3oOoe0WJW5y_LtpuhmYIlx-_xS8raB1Mty5wtF4,3672
|
|
5
|
+
supremo_lite/mock_models/testmodel_1d.py,sha256=0CqLuAwxthz_sn_2v0C5XHu8q42dZO7EIzo9aX1hJ2U,6056
|
|
6
|
+
supremo_lite/mock_models/testmodel_2d.py,sha256=swSEEkORf7sjlZQ7XakY3qGGeRomVpiLZZbice6zziw,7011
|
|
7
|
+
supremo_lite/mutagenesis.py,sha256=p5oc4apcCm4AnGoYmSlbTKG6uS7xVNX2EoJ8322cu8M,16260
|
|
8
|
+
supremo_lite/personalize.py,sha256=w3Bv0xwikHbpZlXkgXWB4lo6XzzbzrrKSJqrrC-rTRs,126389
|
|
9
|
+
supremo_lite/prediction_alignment.py,sha256=rmpZDE-PK9-CqsXjlVw0J0KJqZXuPPl55IceF76gj-s,43020
|
|
10
|
+
supremo_lite/sequence_utils.py,sha256=yl-ghw9mGEGjiIYCBZ-4-S-CpXDjsP6suGuVVdww1mY,4147
|
|
11
|
+
supremo_lite/variant_utils.py,sha256=9B-IiUBIWYipoB2Sa7OJmhmLNVCjykMmdL3Mm4n9wvw,60663
|
|
12
|
+
supremo_lite-0.5.5.dist-info/METADATA,sha256=adDo5VH2nXYF-J6R1gT3rqelY-LUFFhv3V3DG-CTXhg,9025
|
|
13
|
+
supremo_lite-0.5.5.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
|
|
14
|
+
supremo_lite-0.5.5.dist-info/licenses/LICENSE,sha256=QoRjddrQkzdNXNq7EQbRtWGvOKv1h031CG8wreXDa00,1079
|
|
15
|
+
supremo_lite-0.5.5.dist-info/RECORD,,
|
|
@@ -1,15 +0,0 @@
|
|
|
1
|
-
supremo_lite/__init__.py,sha256=-s37P1FPl_xgrjv7Oh4W3bFetJCtk9i4B1vMlwzkd5Y,1660
|
|
2
|
-
supremo_lite/chromosome_utils.py,sha256=rOjS3IQXmjBYZG949C5eG1zWprTOdhPVg4oW8GLbCek,10938
|
|
3
|
-
supremo_lite/core.py,sha256=-OmQEAS5J2rhocXC4aHYWaZZE2N8MX942SePMLoQ6Xc,1190
|
|
4
|
-
supremo_lite/mock_models/__init__.py,sha256=YQcL3oOoe0WJW5y_LtpuhmYIlx-_xS8raB1Mty5wtF4,3672
|
|
5
|
-
supremo_lite/mock_models/testmodel_1d.py,sha256=0CqLuAwxthz_sn_2v0C5XHu8q42dZO7EIzo9aX1hJ2U,6056
|
|
6
|
-
supremo_lite/mock_models/testmodel_2d.py,sha256=S3diAgBjCHIFlmdVQKn8bGB9mYKbjamwuO3Zl3VR174,6903
|
|
7
|
-
supremo_lite/mutagenesis.py,sha256=hhuA548RuHZUF8w03Ft75LFTVerbzCar6w1sRFxZ1bA,16068
|
|
8
|
-
supremo_lite/personalize.py,sha256=77FzD4JIW2rM9vJcanmd77i1-2Y_FAqKlkP4nz9iQNs,126302
|
|
9
|
-
supremo_lite/prediction_alignment.py,sha256=EMPm4vIymiE6Y1Zmu8Z3CPajv02tvuwReOwujLQ6QjI,39164
|
|
10
|
-
supremo_lite/sequence_utils.py,sha256=yl-ghw9mGEGjiIYCBZ-4-S-CpXDjsP6suGuVVdww1mY,4147
|
|
11
|
-
supremo_lite/variant_utils.py,sha256=9B-IiUBIWYipoB2Sa7OJmhmLNVCjykMmdL3Mm4n9wvw,60663
|
|
12
|
-
supremo_lite-0.5.4.dist-info/METADATA,sha256=3k8__zQ9iti2qKCpF5I1I3Iy8gIQI8fwyynr7-jqgLI,9139
|
|
13
|
-
supremo_lite-0.5.4.dist-info/WHEEL,sha256=zp0Cn7JsFoX2ATtOhtaFYIiE2rmFAD4OcMhtUki8W3U,88
|
|
14
|
-
supremo_lite-0.5.4.dist-info/licenses/LICENSE,sha256=QoRjddrQkzdNXNq7EQbRtWGvOKv1h031CG8wreXDa00,1079
|
|
15
|
-
supremo_lite-0.5.4.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|