supremo-lite 0.5.4__tar.gz → 0.5.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: supremo_lite
3
- Version: 0.5.4
3
+ Version: 0.5.5
4
4
  Summary: A lightweight memory first, model agnostic version of SuPreMo
5
5
  License: MIT
6
6
  License-File: LICENSE
@@ -42,13 +42,10 @@ For the latest features and bug fixes:
42
42
 
43
43
  ```bash
44
44
  # Install directly latest release
45
- pip install supremo_lite
45
+ pip install supremo-lite
46
46
 
47
47
  # Or install a specific version/tag
48
48
  pip install git+https://github.com/gladstone-institutes/supremo_lite.git@v0.5.0
49
-
50
- # Or install from a specific branch
51
- pip install git+https://github.com/gladstone-institutes/supremo_lite.git@main
52
49
  ```
53
50
 
54
51
  ### Dependencies
@@ -60,7 +57,7 @@ Required dependencies will be installed automatically:
60
57
 
61
58
  Optional dependencies:
62
59
  - `torch` - For PyTorch tensor support (automatically detected)
63
- - [https://github.com/gladstone-institutes/brisket](brisket) - Cython powered faster 1 hot encoding for DNA sequences (automatically detected)
60
+ - [brisket](https://github.com/gladstone-institutes/brisket) - Cython powered faster 1 hot encoding for DNA sequences (automatically detected)
64
61
 
65
62
  ## Quick Start
66
63
 
@@ -214,3 +211,4 @@ Interested in contributing? Check out the contributing guidelines. Please note t
214
211
  ## Credits
215
212
 
216
213
  `supremo_lite` was created with [`cookiecutter`](https://cookiecutter.readthedocs.io/en/latest/) and the `py-pkgs-cookiecutter` [template](https://github.com/py-pkgs/py-pkgs-cookiecutter).
214
+
@@ -20,13 +20,10 @@ For the latest features and bug fixes:
20
20
 
21
21
  ```bash
22
22
  # Install directly latest release
23
- pip install supremo_lite
23
+ pip install supremo-lite
24
24
 
25
25
  # Or install a specific version/tag
26
26
  pip install git+https://github.com/gladstone-institutes/supremo_lite.git@v0.5.0
27
-
28
- # Or install from a specific branch
29
- pip install git+https://github.com/gladstone-institutes/supremo_lite.git@main
30
27
  ```
31
28
 
32
29
  ### Dependencies
@@ -38,7 +35,7 @@ Required dependencies will be installed automatically:
38
35
 
39
36
  Optional dependencies:
40
37
  - `torch` - For PyTorch tensor support (automatically detected)
41
- - [https://github.com/gladstone-institutes/brisket](brisket) - Cython powered faster 1 hot encoding for DNA sequences (automatically detected)
38
+ - [brisket](https://github.com/gladstone-institutes/brisket) - Cython powered faster 1 hot encoding for DNA sequences (automatically detected)
42
39
 
43
40
  ## Quick Start
44
41
 
@@ -191,4 +188,4 @@ Interested in contributing? Check out the contributing guidelines. Please note t
191
188
 
192
189
  ## Credits
193
190
 
194
- `supremo_lite` was created with [`cookiecutter`](https://cookiecutter.readthedocs.io/en/latest/) and the `py-pkgs-cookiecutter` [template](https://github.com/py-pkgs/py-pkgs-cookiecutter).
191
+ `supremo_lite` was created with [`cookiecutter`](https://cookiecutter.readthedocs.io/en/latest/) and the `py-pkgs-cookiecutter` [template](https://github.com/py-pkgs/py-pkgs-cookiecutter).
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "supremo_lite"
3
- version = "0.5.4"
3
+ version = "0.5.5"
4
4
  description = "A lightweight memory first, model agnostic version of SuPreMo"
5
5
  authors = ["Natalie Gill", "Sean Whalen"]
6
6
  license = "MIT"
@@ -52,7 +52,7 @@ from .prediction_alignment import align_predictions_by_coordinate
52
52
  # This allows users who don't have PyTorch to still use the main package
53
53
 
54
54
  # Version
55
- __version__ = "0.5.4"
55
+ __version__ = "0.5.5"
56
56
  # Package metadata
57
57
  __description__ = (
58
58
  "A module for generating personalized genome sequences and in-silico mutagenesis"
@@ -126,9 +126,13 @@ if TORCH_AVAILABLE:
126
126
  )
127
127
 
128
128
  # Crop bins from all edges to focus loss function
129
- y_hat = y_hat[
130
- :, :, self.crop_bins : -self.crop_bins, self.crop_bins : -self.crop_bins
131
- ]
129
+ if self.crop_bins > 0:
130
+ y_hat = y_hat[
131
+ :,
132
+ :,
133
+ self.crop_bins : -self.crop_bins,
134
+ self.crop_bins : -self.crop_bins,
135
+ ]
132
136
 
133
137
  # Return full contact matrix
134
138
  return y_hat
@@ -146,7 +146,8 @@ def get_sm_sequences(chrom, start, end, reference_fasta, encoder=None):
146
146
 
147
147
  # Create a DataFrame for the metadata
148
148
  metadata_df = pd.DataFrame(
149
- metadata, columns=["chrom", "window_start", "window_end", "variant_pos0", "ref", "alt"]
149
+ metadata,
150
+ columns=["chrom", "window_start", "window_end", "variant_pos0", "ref", "alt"],
150
151
  )
151
152
 
152
153
  return ref_1h, alt_seqs_stacked, metadata_df
@@ -239,9 +240,7 @@ def get_sm_subsequences(
239
240
  )
240
241
  elif not has_bed:
241
242
  # Neither approach was specified
242
- raise ValueError(
243
- "Must provide either (anchor + anchor_radius) or bed_regions."
244
- )
243
+ raise ValueError("Must provide either (anchor + anchor_radius) or bed_regions.")
245
244
 
246
245
  alt_seqs = []
247
246
  metadata = []
@@ -331,7 +330,11 @@ def get_sm_subsequences(
331
330
 
332
331
  # Adjust window to stay within chromosome bounds
333
332
  chrom_obj = reference_fasta[chrom]
334
- chrom_len = len(chrom_obj) if hasattr(chrom_obj, '__len__') else len(chrom_obj.seq)
333
+ chrom_len = (
334
+ len(chrom_obj)
335
+ if hasattr(chrom_obj, "__len__")
336
+ else len(chrom_obj.seq)
337
+ )
335
338
  if window_start < 0:
336
339
  window_start = 0
337
340
  window_end = min(seq_len, chrom_len)
@@ -377,13 +380,17 @@ def get_sm_subsequences(
377
380
  # Create a clone and substitute the base
378
381
  if TORCH_AVAILABLE and isinstance(region_1h, torch.Tensor):
379
382
  alt_1h = region_1h.clone()
380
- alt_1h[:, i] = torch.tensor(nt_to_1h[alt], dtype=alt_1h.dtype)
383
+ alt_1h[:, i] = torch.tensor(
384
+ nt_to_1h[alt], dtype=alt_1h.dtype
385
+ )
381
386
  else:
382
387
  alt_1h = region_1h.copy()
383
388
  alt_1h[:, i] = nt_to_1h[alt]
384
389
 
385
390
  alt_seqs.append(alt_1h)
386
- metadata.append([chrom, window_start, window_end, i, ref_nt, alt])
391
+ metadata.append(
392
+ [chrom, window_start, window_end, i, ref_nt, alt]
393
+ )
387
394
 
388
395
  # If no regions were processed, create empty ref_1h
389
396
  if ref_1h is None:
@@ -408,7 +415,8 @@ def get_sm_subsequences(
408
415
 
409
416
  # Create a DataFrame for the metadata
410
417
  metadata_df = pd.DataFrame(
411
- metadata, columns=["chrom", "window_start", "window_end", "variant_pos0", "ref", "alt"]
418
+ metadata,
419
+ columns=["chrom", "window_start", "window_end", "variant_pos0", "ref", "alt"],
412
420
  )
413
421
 
414
422
  return ref_1h, alt_seqs_stacked, metadata_df
@@ -42,7 +42,7 @@ IUPAC_CODES = {
42
42
  "D": "[AGT]",
43
43
  "H": "[ACT]",
44
44
  "V": "[ACG]",
45
- "N": "[ACGT]"
45
+ "N": "[ACGT]",
46
46
  }
47
47
 
48
48
 
@@ -2811,6 +2811,7 @@ def get_pam_disrupting_alt_sequences(
2811
2811
  ... ref, vcf, seq_len=50, max_pam_distance=10, n_chunks=5):
2812
2812
  ... predictions = model.predict(alt_seqs, ref_seqs)
2813
2813
  """
2814
+
2814
2815
  # Helper function to find PAM sites in a sequence
2815
2816
  def _find_pam_sites(sequence, pam_pattern):
2816
2817
  """Find all PAM site positions in a sequence using IUPAC codes.
@@ -2830,7 +2831,7 @@ def get_pam_disrupting_alt_sequences(
2830
2831
  pat_base = pat_upper[j]
2831
2832
 
2832
2833
  # Sequence 'N' (padding or unknown) matches any pattern base
2833
- if seq_base == 'N':
2834
+ if seq_base == "N":
2834
2835
  continue # Always matches
2835
2836
 
2836
2837
  # Get allowed bases for this pattern position
@@ -3003,9 +3004,7 @@ def get_pam_disrupting_alt_sequences(
3003
3004
  ref_allele = var.get("ref", "")
3004
3005
  alt_allele = var.get("alt", "")
3005
3006
  is_indel = (
3006
- len(ref_allele) != len(alt_allele)
3007
- or ref_allele == "-"
3008
- or alt_allele == "-"
3007
+ len(ref_allele) != len(alt_allele) or ref_allele == "-" or alt_allele == "-"
3009
3008
  )
3010
3009
 
3011
3010
  truly_disrupted_pam_sites = []
@@ -3053,8 +3052,12 @@ def get_pam_disrupting_alt_sequences(
3053
3052
  # For each disrupted PAM site, create a metadata entry
3054
3053
  for pam_site_pos in truly_disrupted_pam_sites:
3055
3054
  # Extract PAM sequences
3056
- ref_pam_seq = ref_window_seq[pam_site_pos : pam_site_pos + len(pam_sequence)]
3057
- alt_pam_seq = modified_window[pam_site_pos : pam_site_pos + len(pam_sequence)]
3055
+ ref_pam_seq = ref_window_seq[
3056
+ pam_site_pos : pam_site_pos + len(pam_sequence)
3057
+ ]
3058
+ alt_pam_seq = modified_window[
3059
+ pam_site_pos : pam_site_pos + len(pam_sequence)
3060
+ ]
3058
3061
 
3059
3062
  # Calculate distance from variant to PAM
3060
3063
  pam_distance = abs(pam_site_pos - variant_pos_in_window)
@@ -3063,12 +3066,14 @@ def get_pam_disrupting_alt_sequences(
3063
3066
  pam_disrupting_variants_list.append(var)
3064
3067
 
3065
3068
  # Store PAM-specific metadata
3066
- pam_metadata_list.append({
3067
- 'pam_site_pos': pam_site_pos,
3068
- 'pam_ref_sequence': ref_pam_seq,
3069
- 'pam_alt_sequence': alt_pam_seq,
3070
- 'pam_distance': pam_distance
3071
- })
3069
+ pam_metadata_list.append(
3070
+ {
3071
+ "pam_site_pos": pam_site_pos,
3072
+ "pam_ref_sequence": ref_pam_seq,
3073
+ "pam_alt_sequence": alt_pam_seq,
3074
+ "pam_distance": pam_distance,
3075
+ }
3076
+ )
3072
3077
 
3073
3078
  # If no PAM-disrupting variants found, yield empty results
3074
3079
  if not pam_disrupting_variants_list:
@@ -3076,7 +3081,9 @@ def get_pam_disrupting_alt_sequences(
3076
3081
  return
3077
3082
 
3078
3083
  # Create DataFrame with filtered PAM-disrupting variants
3079
- filtered_variants_df = pd.DataFrame(pam_disrupting_variants_list).reset_index(drop=True)
3084
+ filtered_variants_df = pd.DataFrame(pam_disrupting_variants_list).reset_index(
3085
+ drop=True
3086
+ )
3080
3087
  pam_metadata_df = pd.DataFrame(pam_metadata_list)
3081
3088
 
3082
3089
  # Call get_alt_ref_sequences with the filtered variants
@@ -3087,12 +3094,13 @@ def get_pam_disrupting_alt_sequences(
3087
3094
  encode,
3088
3095
  n_chunks,
3089
3096
  encoder,
3090
- auto_map_chromosomes
3097
+ auto_map_chromosomes,
3091
3098
  ):
3092
3099
  # Merge PAM-specific metadata with base metadata
3093
3100
  # Both should have the same number of rows since we created one entry per PAM site
3094
- enriched_metadata = pd.concat([base_metadata.reset_index(drop=True),
3095
- pam_metadata_df], axis=1)
3101
+ enriched_metadata = pd.concat(
3102
+ [base_metadata.reset_index(drop=True), pam_metadata_df], axis=1
3103
+ )
3096
3104
 
3097
3105
  # Yield the chunk with enriched metadata
3098
3106
  yield (alt_seqs, ref_seqs, enriched_metadata)
@@ -40,19 +40,49 @@ class VariantPosition:
40
40
  svlen: int # Length of structural variant (base pairs, signed for DEL/INS)
41
41
  variant_type: str # Type of variant ('SNV', 'INS', 'DEL', 'DUP', 'INV', 'BND')
42
42
 
43
- def get_bin_positions(self, bin_size: int) -> Tuple[int, int, int]:
43
+ def get_bin_positions(
44
+ self, bin_size: int, window_start: int, crop_length: int
45
+ ) -> Tuple[int, int, int]:
44
46
  """
45
- Convert base pair positions to bin indices.
47
+ Convert base pair positions to bin indices relative to window.
46
48
 
47
49
  Args:
48
50
  bin_size: Number of base pairs per prediction bin
51
+ window_start: Start position of the sequence window (0-based genomic coord).
52
+ crop_length: Number of base pairs cropped from each edge by the model.
53
+ This accounts for edge bases removed before prediction.
49
54
 
50
55
  Returns:
51
- Tuple of (ref_bin, alt_start_bin, alt_end_bin)
56
+ Tuple of (ref_bin, alt_start_bin, alt_end_bin) as bin indices
57
+ relative to the prediction vector. For centered masking, these
58
+ represent the center and extent of the masked region.
59
+
60
+ Notes:
61
+ - Positions are calculated relative to window_start, not absolute genomic coords
62
+ - crop_length accounts for edge bases removed before prediction
63
+ - Masked bins are centered on the variant position
52
64
  """
53
- ref_bin = int(np.ceil(self.ref_pos / bin_size))
54
- alt_start_bin = int(np.ceil(self.alt_pos / bin_size))
55
- alt_end_bin = int(np.ceil((self.alt_pos + abs(self.svlen)) / bin_size))
65
+ # Calculate positions relative to window
66
+ rel_ref_pos = self.ref_pos - window_start
67
+ rel_alt_pos = self.alt_pos - window_start
68
+
69
+ # Account for cropping (bases removed from start of window before prediction)
70
+ rel_ref_pos -= crop_length
71
+ rel_alt_pos -= crop_length
72
+
73
+ # Convert to bin indices using floor division (not ceil!)
74
+ ref_bin_center = rel_ref_pos // bin_size
75
+ alt_bin_center = rel_alt_pos // bin_size
76
+
77
+ # Calculate number of bins to mask
78
+ svlen_bins = int(np.ceil(abs(self.svlen) / bin_size))
79
+ half_bins = svlen_bins // 2
80
+
81
+ # Center the masked region on the variant
82
+ ref_bin = ref_bin_center - half_bins
83
+ alt_start_bin = alt_bin_center - half_bins
84
+ alt_end_bin = alt_bin_center + (svlen_bins - half_bins)
85
+
56
86
  return ref_bin, alt_start_bin, alt_end_bin
57
87
 
58
88
 
@@ -71,24 +101,27 @@ class PredictionAligner1D:
71
101
  Args:
72
102
  target_size: Expected number of bins in the prediction output
73
103
  bin_size: Number of base pairs per prediction bin (model-specific)
104
+ crop_length: Number of base pairs cropped from each edge by the model
74
105
 
75
106
  Example:
76
- >>> aligner = PredictionAligner1D(target_size=896, bin_size=128)
107
+ >>> aligner = PredictionAligner1D(target_size=896, bin_size=128, crop_length=0)
77
108
  >>> ref_aligned, alt_aligned = aligner.align_predictions(
78
109
  ... ref_pred, alt_pred, 'INS', variant_position
79
110
  ... )
80
111
  """
81
112
 
82
- def __init__(self, target_size: int, bin_size: int):
113
+ def __init__(self, target_size: int, bin_size: int, crop_length: int):
83
114
  """
84
115
  Initialize the 1D prediction aligner.
85
116
 
86
117
  Args:
87
118
  target_size: Expected number of bins in prediction (e.g., 896 for Enformer)
88
119
  bin_size: Base pairs per bin (e.g., 128 for Enformer)
120
+ crop_length: Number of base pairs cropped from each edge by the model
89
121
  """
90
122
  self.target_size = target_size
91
123
  self.bin_size = bin_size
124
+ self.crop_length = crop_length
92
125
 
93
126
  def align_predictions(
94
127
  self,
@@ -96,6 +129,7 @@ class PredictionAligner1D:
96
129
  alt_pred: Union[np.ndarray, "torch.Tensor"],
97
130
  svtype: str,
98
131
  var_pos: VariantPosition,
132
+ window_start: int = 0,
99
133
  ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
100
134
  """
101
135
  Main entry point for 1D prediction alignment.
@@ -105,6 +139,8 @@ class PredictionAligner1D:
105
139
  alt_pred: Alternate prediction vector (length N)
106
140
  svtype: Variant type ('DEL', 'DUP', 'INS', 'INV', 'SNV')
107
141
  var_pos: Variant position information
142
+ window_start: Start position of sequence window (0-based genomic coord).
143
+ Required for correct bin calculation. Defaults to 0.
108
144
 
109
145
  Returns:
110
146
  Tuple of (aligned_ref, aligned_alt) vectors with NaN masking applied
@@ -120,10 +156,12 @@ class PredictionAligner1D:
120
156
 
121
157
  if svtype_normalized in ["DEL", "DUP", "INS"]:
122
158
  return self._align_indel_predictions(
123
- ref_pred, alt_pred, svtype_normalized, var_pos
159
+ ref_pred, alt_pred, svtype_normalized, var_pos, window_start
124
160
  )
125
161
  elif svtype_normalized == "INV":
126
- return self._align_inversion_predictions(ref_pred, alt_pred, var_pos)
162
+ return self._align_inversion_predictions(
163
+ ref_pred, alt_pred, var_pos, window_start
164
+ )
127
165
  elif svtype_normalized in ["SNV", "MNV"]:
128
166
  # SNVs don't change coordinates, direct alignment
129
167
  is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
@@ -140,18 +178,22 @@ class PredictionAligner1D:
140
178
  alt_pred: Union[np.ndarray, "torch.Tensor"],
141
179
  svtype: str,
142
180
  var_pos: VariantPosition,
181
+ window_start: int = 0,
143
182
  ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
144
183
  """
145
184
  Align predictions for insertions, deletions, and duplications.
146
185
 
147
186
  Strategy:
148
187
  1. For DEL: Swap REF/ALT (deletion removes from REF)
149
- 2. Insert NaN bins in shorter sequence
188
+ 2. Insert NaN bins in shorter sequence (centered on variant)
150
189
  3. Crop edges to maintain target size
151
190
  4. For DEL: Swap back
152
191
 
153
192
  This ensures that positions present in one sequence but not the other
154
193
  are marked with NaN, enabling fair comparison of overlapping regions.
194
+
195
+ Args:
196
+ window_start: Start position of sequence window (0-based genomic coord)
155
197
  """
156
198
  is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
157
199
 
@@ -170,8 +212,10 @@ class PredictionAligner1D:
170
212
  var_pos.alt_pos, var_pos.ref_pos, var_pos.svlen, svtype
171
213
  )
172
214
 
173
- # Get bin positions
174
- ref_bin, alt_start_bin, alt_end_bin = var_pos.get_bin_positions(self.bin_size)
215
+ # Get bin positions (window-relative, centered)
216
+ ref_bin, alt_start_bin, alt_end_bin = var_pos.get_bin_positions(
217
+ self.bin_size, window_start, self.crop_length
218
+ )
175
219
  bins_to_add = alt_end_bin - alt_start_bin
176
220
 
177
221
  # Insert NaN bins in REF where variant exists in ALT
@@ -248,6 +292,7 @@ class PredictionAligner1D:
248
292
  ref_pred: Union[np.ndarray, "torch.Tensor"],
249
293
  alt_pred: Union[np.ndarray, "torch.Tensor"],
250
294
  var_pos: VariantPosition,
295
+ window_start: int = 0,
251
296
  ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
252
297
  """
253
298
  Align predictions for inversions.
@@ -259,6 +304,9 @@ class PredictionAligner1D:
259
304
  For strand-aware models, inversions can significantly affect predictions
260
305
  because regulatory elements now appear on the opposite strand. We mask
261
306
  the inverted region to focus comparison on unaffected flanking sequences.
307
+
308
+ Args:
309
+ window_start: Start position of sequence window (0-based genomic coord)
262
310
  """
263
311
  is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
264
312
 
@@ -270,7 +318,9 @@ class PredictionAligner1D:
270
318
  ref_np = ref_pred.copy()
271
319
  alt_np = alt_pred.copy()
272
320
 
273
- var_start, _, var_end = var_pos.get_bin_positions(self.bin_size)
321
+ var_start, _, var_end = var_pos.get_bin_positions(
322
+ self.bin_size, window_start, self.crop_length
323
+ )
274
324
 
275
325
  # Mask inverted region in both REF and ALT
276
326
  ref_np[var_start : var_end + 1] = np.nan
@@ -373,19 +423,23 @@ class PredictionAligner2D:
373
423
  target_size: Expected matrix dimension (NxN)
374
424
  bin_size: Number of base pairs per matrix bin (model-specific)
375
425
  diag_offset: Number of diagonal bins to mask (model-specific)
426
+ crop_length: Number of base pairs cropped from each edge by the model
376
427
 
377
428
  Example:
378
429
  >>> aligner = PredictionAligner2D(
379
430
  ... target_size=448,
380
431
  ... bin_size=2048,
381
- ... diag_offset=2
432
+ ... diag_offset=2,
433
+ ... crop_length=0
382
434
  ... )
383
435
  >>> ref_aligned, alt_aligned = aligner.align_predictions(
384
436
  ... ref_matrix, alt_matrix, 'DEL', variant_position
385
437
  ... )
386
438
  """
387
439
 
388
- def __init__(self, target_size: int, bin_size: int, diag_offset: int):
440
+ def __init__(
441
+ self, target_size: int, bin_size: int, diag_offset: int, crop_length: int
442
+ ):
389
443
  """
390
444
  Initialize the 2D prediction aligner.
391
445
 
@@ -393,10 +447,12 @@ class PredictionAligner2D:
393
447
  target_size: Matrix dimension (e.g., 448 for Akita)
394
448
  bin_size: Base pairs per bin (e.g., 2048 for Akita)
395
449
  diag_offset: Diagonal masking offset (e.g., 2 for Akita)
450
+ crop_length: Number of base pairs cropped from each edge by the model
396
451
  """
397
452
  self.target_size = target_size
398
453
  self.bin_size = bin_size
399
454
  self.diag_offset = diag_offset
455
+ self.crop_length = crop_length
400
456
 
401
457
  def align_predictions(
402
458
  self,
@@ -404,6 +460,7 @@ class PredictionAligner2D:
404
460
  alt_pred: Union[np.ndarray, "torch.Tensor"],
405
461
  svtype: str,
406
462
  var_pos: VariantPosition,
463
+ window_start: int = 0,
407
464
  ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
408
465
  """
409
466
  Main entry point for 2D matrix alignment.
@@ -413,6 +470,8 @@ class PredictionAligner2D:
413
470
  alt_pred: Alternate prediction matrix (NxN)
414
471
  svtype: Variant type ('DEL', 'DUP', 'INS', 'INV', 'SNV')
415
472
  var_pos: Variant position information
473
+ window_start: Start position of sequence window (0-based genomic coord).
474
+ Required for correct bin calculation. Defaults to 0.
416
475
 
417
476
  Returns:
418
477
  Tuple of (aligned_ref, aligned_alt) matrices with NaN masking applied
@@ -428,10 +487,12 @@ class PredictionAligner2D:
428
487
 
429
488
  if svtype_normalized in ["DEL", "DUP", "INS"]:
430
489
  return self._align_indel_matrices(
431
- ref_pred, alt_pred, svtype_normalized, var_pos
490
+ ref_pred, alt_pred, svtype_normalized, var_pos, window_start
432
491
  )
433
492
  elif svtype_normalized == "INV":
434
- return self._align_inversion_matrices(ref_pred, alt_pred, var_pos)
493
+ return self._align_inversion_matrices(
494
+ ref_pred, alt_pred, var_pos, window_start
495
+ )
435
496
  elif svtype_normalized in ["SNV", "MNV"]:
436
497
  # SNVs don't change coordinates, direct alignment
437
498
  is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
@@ -448,15 +509,19 @@ class PredictionAligner2D:
448
509
  alt_pred: Union[np.ndarray, "torch.Tensor"],
449
510
  svtype: str,
450
511
  var_pos: VariantPosition,
512
+ window_start: int = 0,
451
513
  ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
452
514
  """
453
515
  Align matrices for insertions, deletions, and duplications.
454
516
 
455
517
  Strategy:
456
518
  1. For DEL: Swap REF/ALT (deletion removes from REF)
457
- 2. Insert NaN bins (rows AND columns) in shorter matrix
519
+ 2. Insert NaN bins (rows AND columns) in shorter matrix (centered on variant)
458
520
  3. Crop edges to maintain target size
459
521
  4. For DEL: Swap back
522
+
523
+ Args:
524
+ window_start: Start position of sequence window (0-based genomic coord)
460
525
  """
461
526
  is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
462
527
 
@@ -475,8 +540,10 @@ class PredictionAligner2D:
475
540
  var_pos.alt_pos, var_pos.ref_pos, var_pos.svlen, svtype
476
541
  )
477
542
 
478
- # Get bin positions
479
- ref_bin, alt_start_bin, alt_end_bin = var_pos.get_bin_positions(self.bin_size)
543
+ # Get bin positions (window-relative, centered)
544
+ ref_bin, alt_start_bin, alt_end_bin = var_pos.get_bin_positions(
545
+ self.bin_size, window_start, self.crop_length
546
+ )
480
547
  bins_to_add = alt_end_bin - alt_start_bin
481
548
 
482
549
  # Insert NaN bins in REF where variant exists in ALT
@@ -541,6 +608,7 @@ class PredictionAligner2D:
541
608
  ref_pred: Union[np.ndarray, "torch.Tensor"],
542
609
  alt_pred: Union[np.ndarray, "torch.Tensor"],
543
610
  var_pos: VariantPosition,
611
+ window_start: int = 0,
544
612
  ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
545
613
  """
546
614
  Align matrices for inversions.
@@ -556,6 +624,9 @@ class PredictionAligner2D:
556
624
 
557
625
  The same NaN pattern is mirrored to ALT so both matrices have identical
558
626
  masked regions, enabling fair comparison of the unaffected areas.
627
+
628
+ Args:
629
+ window_start: Start position of sequence window (0-based genomic coord)
559
630
  """
560
631
  is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
561
632
 
@@ -567,7 +638,9 @@ class PredictionAligner2D:
567
638
  ref_np = ref_pred.copy()
568
639
  alt_np = alt_pred.copy()
569
640
 
570
- var_start, _, var_end = var_pos.get_bin_positions(self.bin_size)
641
+ var_start, _, var_end = var_pos.get_bin_positions(
642
+ self.bin_size, window_start, self.crop_length
643
+ )
571
644
 
572
645
  # Mask inverted region in REF (cross-pattern: rows + columns)
573
646
  ref_np[var_start : var_end + 1, :] = np.nan
@@ -802,6 +875,7 @@ def align_predictions_by_coordinate(
802
875
  metadata_row: dict,
803
876
  bin_size: int,
804
877
  prediction_type: str,
878
+ crop_length: int,
805
879
  matrix_size: Optional[int] = None,
806
880
  diag_offset: int = 0,
807
881
  ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
@@ -812,7 +886,7 @@ def align_predictions_by_coordinate(
812
886
  vectors (e.g., chromatin accessibility, TF binding) and 2D matrices (e.g., Hi-C contact maps),
813
887
  routing to the appropriate alignment strategy based on variant type.
814
888
 
815
- IMPORTANT: Model-specific parameters (bin_size, matrix_size) must be explicitly
889
+ IMPORTANT: Model-specific parameters (bin_size, crop_length, matrix_size) must be explicitly
816
890
  provided by the user. There are no defaults because these vary across different models.
817
891
 
818
892
  Args:
@@ -824,10 +898,13 @@ def align_predictions_by_coordinate(
824
898
  - 'variant_pos0': Variant position (0-based, absolute genomic coordinate)
825
899
  - 'svlen': Length of structural variant (optional, for symbolic alleles)
826
900
  bin_size: Number of base pairs per prediction bin (REQUIRED, model-specific)
827
- Examples: 2048 for Akita
901
+ Examples: 2048 for Akita, 128 for Enformer
828
902
  prediction_type: Type of predictions ("1D" or "2D")
829
903
  - "1D": Vector predictions (chromatin accessibility, TF binding, etc.)
830
904
  - "2D": Matrix predictions (Hi-C contact maps, Micro-C, etc.)
905
+ crop_length: Number of base pairs cropped from each edge by the model (REQUIRED)
906
+ This accounts for edge bases removed before prediction.
907
+ Examples: 0 for models without cropping
831
908
  matrix_size: Size of contact matrix (REQUIRED for 2D type)
832
909
  Examples: 448 for Akita
833
910
  diag_offset: Number of diagonal bins to mask (default: 0 for no masking)
@@ -849,7 +926,8 @@ def align_predictions_by_coordinate(
849
926
  ... metadata_row={'variant_type': 'INS', 'window_start': 0,
850
927
  ... 'variant_pos0': 500, 'svlen': 100},
851
928
  ... bin_size=128,
852
- ... prediction_type="1D"
929
+ ... prediction_type="1D",
930
+ ... crop_length=0
853
931
  ... )
854
932
 
855
933
  Example (2D contact maps with diagonal masking):
@@ -860,6 +938,7 @@ def align_predictions_by_coordinate(
860
938
  ... 'variant_pos0': 50000, 'svlen': -2048},
861
939
  ... bin_size=2048,
862
940
  ... prediction_type="2D",
941
+ ... crop_length=0,
863
942
  ... matrix_size=448,
864
943
  ... diag_offset=2 # Optional: use 0 if no diagonal masking
865
944
  ... )
@@ -872,6 +951,7 @@ def align_predictions_by_coordinate(
872
951
  ... 'variant_pos0': 1000, 'svlen': 500},
873
952
  ... bin_size=1000,
874
953
  ... prediction_type="2D",
954
+ ... crop_length=0,
875
955
  ... matrix_size=512
876
956
  ... # diag_offset defaults to 0 (no masking)
877
957
  ... )
@@ -937,7 +1017,9 @@ def align_predictions_by_coordinate(
937
1017
  # Handle multi-target predictions [n_targets, n_bins]
938
1018
  if ndim > 1:
939
1019
  target_size = ref_preds.shape[-1] # Number of bins
940
- aligner = PredictionAligner1D(target_size=target_size, bin_size=bin_size)
1020
+ aligner = PredictionAligner1D(
1021
+ target_size=target_size, bin_size=bin_size, crop_length=crop_length
1022
+ )
941
1023
 
942
1024
  # Align each target separately
943
1025
  n_targets = ref_preds.shape[0]
@@ -948,7 +1030,7 @@ def align_predictions_by_coordinate(
948
1030
  ref_target = ref_preds[target_idx]
949
1031
  alt_target = alt_preds[target_idx]
950
1032
  ref_aligned, alt_aligned = aligner.align_predictions(
951
- ref_target, alt_target, variant_type, var_pos
1033
+ ref_target, alt_target, variant_type, var_pos, window_start
952
1034
  )
953
1035
  ref_aligned_list.append(ref_aligned)
954
1036
  alt_aligned_list.append(alt_aligned)
@@ -965,9 +1047,11 @@ def align_predictions_by_coordinate(
965
1047
  else:
966
1048
  # Single target prediction [n_bins]
967
1049
  target_size = len(ref_preds)
968
- aligner = PredictionAligner1D(target_size=target_size, bin_size=bin_size)
1050
+ aligner = PredictionAligner1D(
1051
+ target_size=target_size, bin_size=bin_size, crop_length=crop_length
1052
+ )
969
1053
  return aligner.align_predictions(
970
- ref_preds, alt_preds, variant_type, var_pos
1054
+ ref_preds, alt_preds, variant_type, var_pos, window_start
971
1055
  )
972
1056
  else: # 2D
973
1057
  # Check if predictions are 1D (flattened upper triangular) or 2D (full matrix)
@@ -989,10 +1073,13 @@ def align_predictions_by_coordinate(
989
1073
 
990
1074
  # Align matrices
991
1075
  aligner = PredictionAligner2D(
992
- target_size=matrix_size, bin_size=bin_size, diag_offset=diag_offset
1076
+ target_size=matrix_size,
1077
+ bin_size=bin_size,
1078
+ diag_offset=diag_offset,
1079
+ crop_length=crop_length,
993
1080
  )
994
1081
  aligned_ref_matrix, aligned_alt_matrix = aligner.align_predictions(
995
- ref_matrix, alt_matrix, variant_type, var_pos
1082
+ ref_matrix, alt_matrix, variant_type, var_pos, window_start
996
1083
  )
997
1084
 
998
1085
  # Convert back to flattened format
@@ -1007,8 +1094,11 @@ def align_predictions_by_coordinate(
1007
1094
  else:
1008
1095
  # Already 2D matrices
1009
1096
  aligner = PredictionAligner2D(
1010
- target_size=matrix_size, bin_size=bin_size, diag_offset=diag_offset
1097
+ target_size=matrix_size,
1098
+ bin_size=bin_size,
1099
+ diag_offset=diag_offset,
1100
+ crop_length=crop_length,
1011
1101
  )
1012
1102
  return aligner.align_predictions(
1013
- ref_preds, alt_preds, variant_type, var_pos
1103
+ ref_preds, alt_preds, variant_type, var_pos, window_start
1014
1104
  )
File without changes