supremo-lite 0.5.4.tar.gz → 1.0.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: supremo_lite
3
- Version: 0.5.4
3
+ Version: 1.0.0
4
4
  Summary: A lightweight, memory-first, model-agnostic version of SuPreMo
5
5
  License: MIT
6
6
  License-File: LICENSE
@@ -42,13 +42,10 @@ For the latest features and bug fixes:
42
42
 
43
43
  ```bash
44
44
  # Install the latest release directly
45
- pip install supremo_lite
45
+ pip install supremo-lite
46
46
 
47
47
  # Or install a specific version/tag
48
48
  pip install git+https://github.com/gladstone-institutes/supremo_lite.git@v0.5.0
49
-
50
- # Or install from a specific branch
51
- pip install git+https://github.com/gladstone-institutes/supremo_lite.git@main
52
49
  ```
53
50
 
54
51
  ### Dependencies
@@ -60,7 +57,7 @@ Required dependencies will be installed automatically:
60
57
 
61
58
  Optional dependencies:
62
59
  - `torch` - For PyTorch tensor support (automatically detected)
63
- - [https://github.com/gladstone-institutes/brisket](brisket) - Cython powered faster 1 hot encoding for DNA sequences (automatically detected)
60
+ [brisket](https://github.com/gladstone-institutes/brisket) - Cython-powered, faster one-hot encoding for DNA sequences (automatically detected)
64
61
 
65
62
  ## Quick Start
66
63
 
@@ -214,3 +211,4 @@ Interested in contributing? Check out the contributing guidelines. Please note t
214
211
  ## Credits
215
212
 
216
213
  `supremo_lite` was created with [`cookiecutter`](https://cookiecutter.readthedocs.io/en/latest/) and the `py-pkgs-cookiecutter` [template](https://github.com/py-pkgs/py-pkgs-cookiecutter).
214
+
@@ -20,13 +20,10 @@ For the latest features and bug fixes:
20
20
 
21
21
  ```bash
22
22
  # Install the latest release directly
23
- pip install supremo_lite
23
+ pip install supremo-lite
24
24
 
25
25
  # Or install a specific version/tag
26
26
  pip install git+https://github.com/gladstone-institutes/supremo_lite.git@v0.5.0
27
-
28
- # Or install from a specific branch
29
- pip install git+https://github.com/gladstone-institutes/supremo_lite.git@main
30
27
  ```
31
28
 
32
29
  ### Dependencies
@@ -38,7 +35,7 @@ Required dependencies will be installed automatically:
38
35
 
39
36
  Optional dependencies:
40
37
  - `torch` - For PyTorch tensor support (automatically detected)
41
- - [https://github.com/gladstone-institutes/brisket](brisket) - Cython powered faster 1 hot encoding for DNA sequences (automatically detected)
38
+ [brisket](https://github.com/gladstone-institutes/brisket) - Cython-powered, faster one-hot encoding for DNA sequences (automatically detected)
42
39
 
43
40
  ## Quick Start
44
41
 
@@ -191,4 +188,4 @@ Interested in contributing? Check out the contributing guidelines. Please note t
191
188
 
192
189
  ## Credits
193
190
 
194
- `supremo_lite` was created with [`cookiecutter`](https://cookiecutter.readthedocs.io/en/latest/) and the `py-pkgs-cookiecutter` [template](https://github.com/py-pkgs/py-pkgs-cookiecutter).
191
+ `supremo_lite` was created with [`cookiecutter`](https://cookiecutter.readthedocs.io/en/latest/) and the `py-pkgs-cookiecutter` [template](https://github.com/py-pkgs/py-pkgs-cookiecutter).
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "supremo_lite"
3
- version = "0.5.4"
3
+ version = "1.0.0"
4
4
  description = "A lightweight, memory-first, model-agnostic version of SuPreMo"
5
5
  authors = ["Natalie Gill", "Sean Whalen"]
6
6
  license = "MIT"
@@ -42,7 +42,11 @@ from .personalize import (
42
42
  )
43
43
 
44
44
  # Import mutagenesis functions
45
- from .mutagenesis import get_sm_sequences, get_sm_subsequences
45
+ from .mutagenesis import (
46
+ get_sm_sequences,
47
+ get_sm_subsequences,
48
+ get_scrambled_subsequences,
49
+ )
46
50
 
47
51
  # Import prediction alignment functions
48
52
  from .prediction_alignment import align_predictions_by_coordinate
@@ -52,7 +56,7 @@ from .prediction_alignment import align_predictions_by_coordinate
52
56
  # This allows users who don't have PyTorch to still use the main package
53
57
 
54
58
  # Version
55
- __version__ = "0.5.4"
59
+ __version__ = "1.0.0"
56
60
  # Package metadata
57
61
  __description__ = (
58
62
  "A module for generating personalized genome sequences and in-silico mutagenesis"
@@ -126,9 +126,13 @@ if TORCH_AVAILABLE:
126
126
  )
127
127
 
128
128
  # Crop bins from all edges to focus loss function
129
- y_hat = y_hat[
130
- :, :, self.crop_bins : -self.crop_bins, self.crop_bins : -self.crop_bins
131
- ]
129
+ if self.crop_bins > 0:
130
+ y_hat = y_hat[
131
+ :,
132
+ :,
133
+ self.crop_bins : -self.crop_bins,
134
+ self.crop_bins : -self.crop_bins,
135
+ ]
132
136
 
133
137
  # Return full contact matrix
134
138
  return y_hat
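
The guard added here matters because a zero crop would otherwise slice to nothing (`[0:-0]` is the empty range `[0:0]`). A minimal sketch with dummy shapes (batch, channel, and bin counts are illustrative assumptions, not the model's real dimensions):

```python
import torch

# Hypothetical contact-map prediction: batch=1, channel=1, 8x8 bins
y_hat = torch.arange(64, dtype=torch.float32).reshape(1, 1, 8, 8)
crop_bins = 2

# Mirrors the patched logic: skip the slice entirely when crop_bins == 0,
# since y_hat[..., 0:-0, 0:-0] would be empty
if crop_bins > 0:
    y_hat = y_hat[:, :, crop_bins:-crop_bins, crop_bins:-crop_bins]

print(y_hat.shape)  # torch.Size([1, 1, 4, 4])
```
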
@@ -19,6 +19,94 @@ except ImportError:
19
19
  pass # Already handled in core
20
20
 
21
21
 
22
+ def _kmer_shuffle(sequence: str, k: int = 1, random_state=None) -> str:
23
+ """
24
+ Shuffle a sequence by k-mer chunks, preserving k-mer composition.
25
+
26
+ Breaks the sequence into non-overlapping k-mers and shuffles these chunks.
27
+ This preserves the k-mer frequency counts in the shuffled sequence:
28
+ - k=1: Shuffle individual nucleotides (preserves mononucleotide/GC composition)
29
+ - k=2: Shuffle 2-mers (preserves dinucleotide frequencies)
30
+ - k=3: Shuffle 3-mers (preserves trinucleotide frequencies)
31
+
32
+ Note: If sequence length is not divisible by k, the remainder bases are
33
+ treated as a partial k-mer and shuffled along with the complete k-mers.
34
+
35
+ Args:
36
+ sequence: Input DNA sequence string (ACGT only)
37
+ k: Size of k-mers to shuffle (default: 1)
38
+ random_state: Optional numpy random state or seed for reproducibility
39
+
40
+ Returns:
41
+ Shuffled sequence with preserved k-mer composition
42
+
43
+ Raises:
44
+ ValueError: If k < 1
45
+ """
46
+ if k < 1:
47
+ raise ValueError(f"k must be >= 1, got {k}")
48
+
49
+ if len(sequence) < k:
50
+ return sequence
51
+
52
+ # Handle random state
53
+ if random_state is None:
54
+ rng = np.random.default_rng()
55
+ elif isinstance(random_state, (int, np.integer)):
56
+ rng = np.random.default_rng(random_state)
57
+ else:
58
+ rng = random_state
59
+
60
+ seq = sequence.upper()
61
+
62
+ # Calculate how many complete k-mers we can make
63
+ n_complete_kmers = len(seq) // k
64
+ kmer_portion_len = n_complete_kmers * k
65
+
66
+ # Split into k-mers
67
+ kmers = [seq[i : i + k] for i in range(0, kmer_portion_len, k)]
68
+
69
+ # Include leftover bases as an additional chunk to shuffle
70
+ leftover = seq[kmer_portion_len:]
71
+ if leftover:
72
+ kmers.append(leftover)
73
+
74
+ # Shuffle all chunks (including leftover if present)
75
+ rng.shuffle(kmers)
76
+
77
+ return "".join(kmers)
78
+
79
+
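
As a quick check of the invariant the docstring promises, a hedged sketch assuming the module-private `_kmer_shuffle` above is importable: shuffling non-overlapping 2-mer chunks leaves the chunk composition untouched.

```python
from collections import Counter

def chunk_counts(seq: str, k: int) -> Counter:
    # Count non-overlapping k-mers, matching how _kmer_shuffle chunks the input
    return Counter(seq[i : i + k] for i in range(0, len(seq) // k * k, k))

seq = "ACGTACGTAACC"
shuffled = _kmer_shuffle(seq, k=2, random_state=42)

assert len(shuffled) == len(seq)
assert chunk_counts(seq, 2) == chunk_counts(shuffled, 2)  # 2-mer composition preserved
```
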
80
+ def _scramble_region(
81
+ sequence: str, start: int, end: int, k: int = 1, random_state=None
82
+ ) -> str:
83
+ """
84
+ Scramble a specific region within a sequence using k-mer shuffle.
85
+
86
+ Args:
87
+ sequence: Full sequence string
88
+ start: Start position of region to scramble (0-based)
89
+ end: End position of region to scramble (exclusive)
90
+ k: Size of k-mers to shuffle (default: 1 for mononucleotide shuffle)
91
+ random_state: Optional random state for reproducibility
92
+
93
+ Returns:
94
+ Sequence with the specified region scrambled
95
+ """
96
+ if start < 0 or end > len(sequence) or start >= end:
97
+ raise ValueError(
98
+ f"Invalid region [{start}, {end}) for sequence of length {len(sequence)}"
99
+ )
100
+
101
+ prefix = sequence[:start]
102
+ region = sequence[start:end]
103
+ suffix = sequence[end:]
104
+
105
+ scrambled_region = _kmer_shuffle(region, k=k, random_state=random_state)
106
+
107
+ return prefix + scrambled_region + suffix
108
+
109
+
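
Likewise for the wrapper, assuming `_scramble_region` is in scope: only the requested slice changes, and its base composition survives.

```python
seq = "AAAACGTACGTTTTT"
out = _scramble_region(seq, start=4, end=11, k=1, random_state=0)

assert out[:4] == seq[:4] and out[11:] == seq[11:]  # flanks untouched
assert sorted(out[4:11]) == sorted(seq[4:11])       # base composition preserved inside
```
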
22
110
  def _read_bed_file(bed_regions: Union[str, pd.DataFrame]) -> pd.DataFrame:
23
111
  """
24
112
  Read BED file or validate BED DataFrame format.
@@ -146,7 +234,15 @@ def get_sm_sequences(chrom, start, end, reference_fasta, encoder=None):
146
234
 
147
235
  # Create a DataFrame for the metadata
148
236
  metadata_df = pd.DataFrame(
149
- metadata, columns=["chrom", "window_start", "window_end", "variant_pos0", "ref", "alt"]
237
+ metadata,
238
+ columns=[
239
+ "chrom",
240
+ "window_start",
241
+ "window_end",
242
+ "variant_offset0",
243
+ "ref",
244
+ "alt",
245
+ ],
150
246
  )
151
247
 
152
248
  return ref_1h, alt_seqs_stacked, metadata_df
@@ -239,9 +335,7 @@ def get_sm_subsequences(
239
335
  )
240
336
  elif not has_bed:
241
337
  # Neither approach was specified
242
- raise ValueError(
243
- "Must provide either (anchor + anchor_radius) or bed_regions."
244
- )
338
+ raise ValueError("Must provide either (anchor + anchor_radius) or bed_regions.")
245
339
 
246
340
  alt_seqs = []
247
341
  metadata = []
@@ -331,7 +425,11 @@ def get_sm_subsequences(
331
425
 
332
426
  # Adjust window to stay within chromosome bounds
333
427
  chrom_obj = reference_fasta[chrom]
334
- chrom_len = len(chrom_obj) if hasattr(chrom_obj, '__len__') else len(chrom_obj.seq)
428
+ chrom_len = (
429
+ len(chrom_obj)
430
+ if hasattr(chrom_obj, "__len__")
431
+ else len(chrom_obj.seq)
432
+ )
335
433
  if window_start < 0:
336
434
  window_start = 0
337
435
  window_end = min(seq_len, chrom_len)
@@ -377,13 +475,17 @@ def get_sm_subsequences(
377
475
  # Create a clone and substitute the base
378
476
  if TORCH_AVAILABLE and isinstance(region_1h, torch.Tensor):
379
477
  alt_1h = region_1h.clone()
380
- alt_1h[:, i] = torch.tensor(nt_to_1h[alt], dtype=alt_1h.dtype)
478
+ alt_1h[:, i] = torch.tensor(
479
+ nt_to_1h[alt], dtype=alt_1h.dtype
480
+ )
381
481
  else:
382
482
  alt_1h = region_1h.copy()
383
483
  alt_1h[:, i] = nt_to_1h[alt]
384
484
 
385
485
  alt_seqs.append(alt_1h)
386
- metadata.append([chrom, window_start, window_end, i, ref_nt, alt])
486
+ metadata.append(
487
+ [chrom, window_start, window_end, i, ref_nt, alt]
488
+ )
387
489
 
388
490
  # If no regions were processed, create empty ref_1h
389
491
  if ref_1h is None:
@@ -408,7 +510,263 @@ def get_sm_subsequences(
408
510
 
409
511
  # Create a DataFrame for the metadata
410
512
  metadata_df = pd.DataFrame(
411
- metadata, columns=["chrom", "window_start", "window_end", "variant_pos0", "ref", "alt"]
513
+ metadata,
514
+ columns=[
515
+ "chrom",
516
+ "window_start",
517
+ "window_end",
518
+ "variant_offset0",
519
+ "ref",
520
+ "alt",
521
+ ],
412
522
  )
413
523
 
414
524
  return ref_1h, alt_seqs_stacked, metadata_df
525
+
526
+
527
+ def get_scrambled_subsequences(
528
+ chrom: str,
529
+ seq_len: int,
530
+ reference_fasta,
531
+ bed_regions: Union[str, pd.DataFrame],
532
+ n_scrambles: int = 1,
533
+ kmer_size: int = 1,
534
+ encoder=None,
535
+ auto_map_chromosomes: bool = False,
536
+ random_state=None,
537
+ ):
538
+ """
539
+ Generate sequences with BED-defined regions scrambled using k-mer shuffle.
540
+
541
+ This function creates control sequences where specific regions (defined by a BED file)
542
+ are scrambled while preserving k-mer composition. Useful for generating
543
+ negative controls that maintain sequence composition properties.
544
+
545
+ Args:
546
+ chrom: Chromosome name
547
+ seq_len: Total sequence length for each window
548
+ reference_fasta: Reference genome object (pyfaidx.Fasta or dict-like)
549
+ bed_regions: BED file path or DataFrame defining regions to scramble.
550
+ BED format: chrom, start, end (0-based, half-open intervals).
551
+ Each BED region is scrambled within its centered seq_len window.
552
+ n_scrambles: Number of scrambled versions to generate per region (default: 1)
553
+ kmer_size: Size of k-mers to shuffle (default: 1).
554
+ - kmer_size=1: Shuffle individual nucleotides (preserves mononucleotide/GC composition)
555
+ - kmer_size=2: Shuffle 2-mers (preserves dinucleotide frequencies)
556
+ - kmer_size=3: Shuffle 3-mers (preserves trinucleotide frequencies)
557
+ Higher values preserve more local sequence context.
558
+ encoder: Optional custom encoding function
559
+ auto_map_chromosomes: Automatically map chromosome names between reference
560
+ and BED file (e.g., 'chr1' <-> '1'). Default: False.
561
+ random_state: Random seed or numpy random generator for reproducibility.
562
+
563
+ Returns:
564
+ Tuple of (ref_seqs, scrambled_seqs, metadata):
565
+ - ref_seqs: One-hot encoded reference sequences, shape (N, 4, seq_len)
566
+ - scrambled_seqs: Scrambled sequences, shape (N * n_scrambles, 4, seq_len)
567
+ - metadata: DataFrame with columns:
568
+ - chrom: Chromosome name
569
+ - window_start: Start of sequence window (0-based)
570
+ - window_end: End of sequence window (0-based, exclusive)
571
+ - scramble_start: Start of scrambled region within window (0-based)
572
+ - scramble_end: End of scrambled region within window (0-based, exclusive)
573
+ - scramble_idx: Index of this scramble (0 to n_scrambles-1)
574
+ - ref: Original/reference sequence in scrambled region
575
+ - alt: Scrambled/alternate sequence in that region
576
+
577
+ Raises:
578
+ ValueError: If bed_regions is not provided, has invalid format, or kmer_size < 1
579
+ """
580
+ if bed_regions is None:
581
+ raise ValueError("bed_regions is required for get_scrambled_subsequences()")
582
+
583
+ if kmer_size < 1:
584
+ raise ValueError(f"kmer_size must be >= 1, got {kmer_size}")
585
+
586
+ # Handle random state
587
+ if random_state is None:
588
+ rng = np.random.default_rng()
589
+ elif isinstance(random_state, (int, np.integer)):
590
+ rng = np.random.default_rng(random_state)
591
+ else:
592
+ rng = random_state
593
+
594
+ # Parse BED file
595
+ bed_df = _read_bed_file(bed_regions)
596
+
597
+ # Apply chromosome name matching
598
+ ref_chroms = {chrom}
599
+ bed_chroms = set(bed_df["chrom"].unique())
600
+
601
+ mapping, unmatched = match_chromosomes_with_report(
602
+ ref_chroms,
603
+ bed_chroms,
604
+ verbose=False,
605
+ auto_map_chromosomes=auto_map_chromosomes,
606
+ )
607
+
608
+ if mapping:
609
+ bed_df = apply_chromosome_mapping(bed_df, mapping)
610
+
611
+ # Filter to target chromosome
612
+ chrom_bed_regions = bed_df[bed_df["chrom"] == chrom].copy()
613
+
614
+ if len(chrom_bed_regions) == 0:
615
+ warnings.warn(
616
+ f"No BED regions found for chromosome {chrom}. "
617
+ f"Returning original unshuffled sequence."
618
+ )
619
+ # Return original sequence (unshuffled) centered on chromosome
620
+ chrom_obj = reference_fasta[chrom]
621
+ if hasattr(chrom_obj, "__len__"):
622
+ chrom_len = len(chrom_obj)
623
+ else:
624
+ chrom_len = len(str(chrom_obj))
625
+
626
+ # Center window on chromosome
627
+ chrom_center = chrom_len // 2
628
+ window_start = max(0, chrom_center - seq_len // 2)
629
+ window_end = min(chrom_len, window_start + seq_len)
630
+
631
+ # Adjust if we hit the end
632
+ if window_end - window_start < seq_len:
633
+ window_start = max(0, window_end - seq_len)
634
+
635
+ # Get reference sequence
636
+ ref_seq_obj = reference_fasta[chrom][window_start:window_end]
637
+ if hasattr(ref_seq_obj, "seq"):
638
+ ref_seq = str(ref_seq_obj.seq)
639
+ else:
640
+ ref_seq = str(ref_seq_obj)
641
+
642
+ ref_1h = encode_seq(ref_seq, encoder)
643
+
644
+ if TORCH_AVAILABLE and isinstance(ref_1h, torch.Tensor):
645
+ ref_stacked = torch.stack([ref_1h])
646
+ # Return same sequence for all "scrambled" outputs (but unshuffled)
647
+ scrambled_stacked = torch.stack([ref_1h] * n_scrambles)
648
+ else:
649
+ ref_stacked = np.stack([ref_1h])
650
+ scrambled_stacked = np.stack([ref_1h] * n_scrambles)
651
+
652
+ # Create metadata indicating no scrambling occurred
653
+ meta_rows = []
654
+ for i in range(n_scrambles):
655
+ meta_rows.append(
656
+ {
657
+ "chrom": chrom,
658
+ "window_start": window_start,
659
+ "window_end": window_end,
660
+ "scramble_start": 0,
661
+ "scramble_end": 0, # Empty region indicates no scrambling
662
+ "scramble_idx": i,
663
+ "ref": ref_seq,
664
+ "alt": ref_seq, # Same as ref when no scrambling
665
+ }
666
+ )
667
+
668
+ return ref_stacked, scrambled_stacked, pd.DataFrame(meta_rows)
669
+
670
+ ref_sequences = []
671
+ scrambled_sequences = []
672
+ metadata = []
673
+
674
+ # Process each BED region
675
+ for _, bed_region in chrom_bed_regions.iterrows():
676
+ region_start = int(bed_region["start"])
677
+ region_end = int(bed_region["end"])
678
+ region_center = (region_start + region_end) // 2
679
+
680
+ # Calculate sequence window centered on BED region
681
+ window_start = region_center - seq_len // 2
682
+ window_end = window_start + seq_len
683
+
684
+ # Adjust window to stay within chromosome bounds
685
+ chrom_obj = reference_fasta[chrom]
686
+ chrom_len = len(chrom_obj) if hasattr(chrom_obj, "__len__") else len(str(chrom_obj))
687
+
688
+ if window_start < 0:
689
+ window_start = 0
690
+ window_end = min(seq_len, chrom_len)
691
+ elif window_end > chrom_len:
692
+ window_end = chrom_len
693
+ window_start = max(0, chrom_len - seq_len)
694
+
695
+ # Get reference sequence
696
+ ref_seq_obj = reference_fasta[chrom][window_start:window_end]
697
+ if hasattr(ref_seq_obj, "seq"):
698
+ ref_seq = str(ref_seq_obj.seq)
699
+ else:
700
+ ref_seq = str(ref_seq_obj)
701
+
702
+ if len(ref_seq) != seq_len:
703
+ warnings.warn(
704
+ f"Region {chrom}:{region_start}-{region_end} produces sequence of length "
705
+ f"{len(ref_seq)} instead of {seq_len}. Skipping."
706
+ )
707
+ continue
708
+
709
+ # Calculate scramble region relative to window
710
+ scramble_start_rel = max(0, region_start - window_start)
711
+ scramble_end_rel = min(seq_len, region_end - window_start)
712
+
713
+ if scramble_start_rel >= scramble_end_rel:
714
+ warnings.warn(
715
+ f"BED region {chrom}:{region_start}-{region_end} is outside window bounds. Skipping."
716
+ )
717
+ continue
718
+
719
+ # Store reference sequence
720
+ ref_1h = encode_seq(ref_seq, encoder)
721
+ ref_sequences.append(ref_1h)
722
+
723
+ # Get original region sequence for metadata
724
+ original_region = ref_seq[scramble_start_rel:scramble_end_rel]
725
+
726
+ # Generate n_scrambles scrambled versions
727
+ for scramble_idx in range(n_scrambles):
728
+ scrambled_seq = _scramble_region(
729
+ ref_seq,
730
+ scramble_start_rel,
731
+ scramble_end_rel,
732
+ k=kmer_size,
733
+ random_state=rng,
734
+ )
735
+
736
+ scrambled_1h = encode_seq(scrambled_seq, encoder)
737
+ scrambled_sequences.append(scrambled_1h)
738
+
739
+ scrambled_region = scrambled_seq[scramble_start_rel:scramble_end_rel]
740
+
741
+ metadata.append(
742
+ {
743
+ "chrom": chrom,
744
+ "window_start": window_start,
745
+ "window_end": window_end,
746
+ "scramble_start": scramble_start_rel,
747
+ "scramble_end": scramble_end_rel,
748
+ "scramble_idx": scramble_idx,
749
+ "ref": original_region,
750
+ "alt": scrambled_region,
751
+ }
752
+ )
753
+
754
+ # Stack sequences
755
+ if ref_sequences:
756
+ if TORCH_AVAILABLE and isinstance(ref_sequences[0], torch.Tensor):
757
+ ref_stacked = torch.stack(ref_sequences)
758
+ scrambled_stacked = torch.stack(scrambled_sequences)
759
+ else:
760
+ ref_stacked = np.stack(ref_sequences)
761
+ scrambled_stacked = np.stack(scrambled_sequences)
762
+ else:
763
+ if TORCH_AVAILABLE:
764
+ ref_stacked = torch.empty((0, 4, seq_len), dtype=torch.float32)
765
+ scrambled_stacked = torch.empty((0, 4, seq_len), dtype=torch.float32)
766
+ else:
767
+ ref_stacked = np.empty((0, 4, seq_len), dtype=np.float32)
768
+ scrambled_stacked = np.empty((0, 4, seq_len), dtype=np.float32)
769
+
770
+ metadata_df = pd.DataFrame(metadata)
771
+
772
+ return ref_stacked, scrambled_stacked, metadata_df
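
A hedged end-to-end sketch of the new public function (re-exported from the package per the `__init__.py` hunk earlier in this diff). The plain dict standing in for `pyfaidx.Fasta` and the three-column BED DataFrame are assumptions read off the docstring; whether the outputs are numpy arrays or torch tensors depends on the installed backend.

```python
import pandas as pd
import supremo_lite as sl

# Dict-like stand-in for a reference genome; one 200 bp toy chromosome
reference = {"chr1": "ACGT" * 50}
bed = pd.DataFrame({"chrom": ["chr1"], "start": [90], "end": [110]})

ref_1h, scrambled_1h, meta = sl.get_scrambled_subsequences(
    chrom="chr1",
    seq_len=100,
    reference_fasta=reference,
    bed_regions=bed,
    n_scrambles=3,
    kmer_size=2,     # shuffle dinucleotide chunks inside the BED region
    random_state=0,  # reproducible scrambles
)

print(ref_1h.shape, scrambled_1h.shape)  # expected: (1, 4, 100) and (3, 4, 100)
print(meta[["scramble_start", "scramble_end", "scramble_idx"]])
```
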
@@ -42,7 +42,7 @@ IUPAC_CODES = {
42
42
  "D": "[AGT]",
43
43
  "H": "[ACT]",
44
44
  "V": "[ACG]",
45
- "N": "[ACGT]"
45
+ "N": "[ACGT]",
46
46
  }
47
47
 
48
48
 
@@ -2811,6 +2811,7 @@ def get_pam_disrupting_alt_sequences(
2811
2811
  ... ref, vcf, seq_len=50, max_pam_distance=10, n_chunks=5):
2812
2812
  ... predictions = model.predict(alt_seqs, ref_seqs)
2813
2813
  """
2814
+
2814
2815
  # Helper function to find PAM sites in a sequence
2815
2816
  def _find_pam_sites(sequence, pam_pattern):
2816
2817
  """Find all PAM site positions in a sequence using IUPAC codes.
@@ -2830,7 +2831,7 @@ def get_pam_disrupting_alt_sequences(
2830
2831
  pat_base = pat_upper[j]
2831
2832
 
2832
2833
  # Sequence 'N' (padding or unknown) matches any pattern base
2833
- if seq_base == 'N':
2834
+ if seq_base == "N":
2834
2835
  continue # Always matches
2835
2836
 
2836
2837
  # Get allowed bases for this pattern position
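
For illustration, the same IUPAC expansion can be sketched with a regex; this is a simplified stand-in, not the package's `_find_pam_sites` (which matches base by base), and the table below is only the subset of IUPAC_CODES needed for 'NGG':

```python
import re

IUPAC = {"A": "A", "C": "C", "G": "G", "T": "T", "N": "[ACGT]"}

def pam_sites(sequence: str, pam: str = "NGG") -> list:
    # Expand the IUPAC pattern to a regex; the lookahead keeps overlapping hits
    pattern = "".join(IUPAC[b] for b in pam.upper())
    return [m.start() for m in re.finditer(f"(?=({pattern}))", sequence.upper())]

print(pam_sites("ACGGTGGAGGT"))  # [1, 4, 7] -> CGG, TGG, AGG
```
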
@@ -3003,9 +3004,7 @@ def get_pam_disrupting_alt_sequences(
3003
3004
  ref_allele = var.get("ref", "")
3004
3005
  alt_allele = var.get("alt", "")
3005
3006
  is_indel = (
3006
- len(ref_allele) != len(alt_allele)
3007
- or ref_allele == "-"
3008
- or alt_allele == "-"
3007
+ len(ref_allele) != len(alt_allele) or ref_allele == "-" or alt_allele == "-"
3009
3008
  )
3010
3009
 
3011
3010
  truly_disrupted_pam_sites = []
@@ -3053,8 +3052,12 @@ def get_pam_disrupting_alt_sequences(
3053
3052
  # For each disrupted PAM site, create a metadata entry
3054
3053
  for pam_site_pos in truly_disrupted_pam_sites:
3055
3054
  # Extract PAM sequences
3056
- ref_pam_seq = ref_window_seq[pam_site_pos : pam_site_pos + len(pam_sequence)]
3057
- alt_pam_seq = modified_window[pam_site_pos : pam_site_pos + len(pam_sequence)]
3055
+ ref_pam_seq = ref_window_seq[
3056
+ pam_site_pos : pam_site_pos + len(pam_sequence)
3057
+ ]
3058
+ alt_pam_seq = modified_window[
3059
+ pam_site_pos : pam_site_pos + len(pam_sequence)
3060
+ ]
3058
3061
 
3059
3062
  # Calculate distance from variant to PAM
3060
3063
  pam_distance = abs(pam_site_pos - variant_pos_in_window)
@@ -3063,12 +3066,14 @@ def get_pam_disrupting_alt_sequences(
3063
3066
  pam_disrupting_variants_list.append(var)
3064
3067
 
3065
3068
  # Store PAM-specific metadata
3066
- pam_metadata_list.append({
3067
- 'pam_site_pos': pam_site_pos,
3068
- 'pam_ref_sequence': ref_pam_seq,
3069
- 'pam_alt_sequence': alt_pam_seq,
3070
- 'pam_distance': pam_distance
3071
- })
3069
+ pam_metadata_list.append(
3070
+ {
3071
+ "pam_site_pos": pam_site_pos,
3072
+ "pam_ref_sequence": ref_pam_seq,
3073
+ "pam_alt_sequence": alt_pam_seq,
3074
+ "pam_distance": pam_distance,
3075
+ }
3076
+ )
3072
3077
 
3073
3078
  # If no PAM-disrupting variants found, yield empty results
3074
3079
  if not pam_disrupting_variants_list:
@@ -3076,7 +3081,9 @@ def get_pam_disrupting_alt_sequences(
3076
3081
  return
3077
3082
 
3078
3083
  # Create DataFrame with filtered PAM-disrupting variants
3079
- filtered_variants_df = pd.DataFrame(pam_disrupting_variants_list).reset_index(drop=True)
3084
+ filtered_variants_df = pd.DataFrame(pam_disrupting_variants_list).reset_index(
3085
+ drop=True
3086
+ )
3080
3087
  pam_metadata_df = pd.DataFrame(pam_metadata_list)
3081
3088
 
3082
3089
  # Call get_alt_ref_sequences with the filtered variants
@@ -3087,12 +3094,13 @@ def get_pam_disrupting_alt_sequences(
3087
3094
  encode,
3088
3095
  n_chunks,
3089
3096
  encoder,
3090
- auto_map_chromosomes
3097
+ auto_map_chromosomes,
3091
3098
  ):
3092
3099
  # Merge PAM-specific metadata with base metadata
3093
3100
  # Both should have the same number of rows since we created one entry per PAM site
3094
- enriched_metadata = pd.concat([base_metadata.reset_index(drop=True),
3095
- pam_metadata_df], axis=1)
3101
+ enriched_metadata = pd.concat(
3102
+ [base_metadata.reset_index(drop=True), pam_metadata_df], axis=1
3103
+ )
3096
3104
 
3097
3105
  # Yield the chunk with enriched metadata
3098
3106
  yield (alt_seqs, ref_seqs, enriched_metadata)
@@ -40,19 +40,49 @@ class VariantPosition:
40
40
  svlen: int # Length of structural variant (base pairs, signed for DEL/INS)
41
41
  variant_type: str # Type of variant ('SNV', 'INS', 'DEL', 'DUP', 'INV', 'BND')
42
42
 
43
- def get_bin_positions(self, bin_size: int) -> Tuple[int, int, int]:
43
+ def get_bin_positions(
44
+ self, bin_size: int, window_start: int, crop_length: int
45
+ ) -> Tuple[int, int, int]:
44
46
  """
45
- Convert base pair positions to bin indices.
47
+ Convert base pair positions to bin indices relative to window.
46
48
 
47
49
  Args:
48
50
  bin_size: Number of base pairs per prediction bin
51
+ window_start: Start position of the sequence window (0-based genomic coord).
52
+ crop_length: Number of base pairs cropped from each edge by the model.
53
+ This accounts for edge bases removed before prediction.
49
54
 
50
55
  Returns:
51
- Tuple of (ref_bin, alt_start_bin, alt_end_bin)
56
+ Tuple of (ref_bin, alt_start_bin, alt_end_bin) as bin indices
57
+ relative to the prediction vector. For centered masking, these
58
+ represent the center and extent of the masked region.
59
+
60
+ Notes:
61
+ - Positions are calculated relative to window_start, not absolute genomic coords
62
+ - crop_length accounts for edge bases removed before prediction
63
+ - Masked bins are centered on the variant position
52
64
  """
53
- ref_bin = int(np.ceil(self.ref_pos / bin_size))
54
- alt_start_bin = int(np.ceil(self.alt_pos / bin_size))
55
- alt_end_bin = int(np.ceil((self.alt_pos + abs(self.svlen)) / bin_size))
65
+ # Calculate positions relative to window
66
+ rel_ref_pos = self.ref_pos - window_start
67
+ rel_alt_pos = self.alt_pos - window_start
68
+
69
+ # Account for cropping (bases removed from start of window before prediction)
70
+ rel_ref_pos -= crop_length
71
+ rel_alt_pos -= crop_length
72
+
73
+ # Convert to bin indices using floor division (not ceil!)
74
+ ref_bin_center = rel_ref_pos // bin_size
75
+ alt_bin_center = rel_alt_pos // bin_size
76
+
77
+ # Calculate number of bins to mask
78
+ svlen_bins = int(np.ceil(abs(self.svlen) / bin_size))
79
+ half_bins = svlen_bins // 2
80
+
81
+ # Center the masked region on the variant
82
+ ref_bin = ref_bin_center - half_bins
83
+ alt_start_bin = alt_bin_center - half_bins
84
+ alt_end_bin = alt_bin_center + (svlen_bins - half_bins)
85
+
56
86
  return ref_bin, alt_start_bin, alt_end_bin
57
87
 
58
88
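
A worked example of the new window-relative arithmetic; the numbers are hypothetical, and the steps mirror `get_bin_positions` above:

```python
# Hypothetical: 100 bp deletion at genomic position 12,800; window starts at
# 10,000; bin_size=128; the model crops nothing (crop_length=0)
window_start, crop_length, bin_size = 10_000, 0, 128
ref_pos = alt_pos = 12_800
svlen = -100  # signed, as in VariantPosition

rel_pos = ref_pos - window_start - crop_length  # 2800
bin_center = rel_pos // bin_size                # floor division -> bin 21

svlen_bins = -(-abs(svlen) // bin_size)         # ceil(100 / 128) = 1 bin to mask
half_bins = svlen_bins // 2                     # 0

ref_bin = bin_center - half_bins                     # 21
alt_start_bin = bin_center - half_bins               # 21
alt_end_bin = bin_center + (svlen_bins - half_bins)  # 22
print(ref_bin, alt_start_bin, alt_end_bin)           # 21 21 22
```
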
 
@@ -71,24 +101,27 @@ class PredictionAligner1D:
71
101
  Args:
72
102
  target_size: Expected number of bins in the prediction output
73
103
  bin_size: Number of base pairs per prediction bin (model-specific)
104
+ crop_length: Number of base pairs cropped from each edge by the model
74
105
 
75
106
  Example:
76
- >>> aligner = PredictionAligner1D(target_size=896, bin_size=128)
107
+ >>> aligner = PredictionAligner1D(target_size=896, bin_size=128, crop_length=0)
77
108
  >>> ref_aligned, alt_aligned = aligner.align_predictions(
78
109
  ... ref_pred, alt_pred, 'INS', variant_position
79
110
  ... )
80
111
  """
81
112
 
82
- def __init__(self, target_size: int, bin_size: int):
113
+ def __init__(self, target_size: int, bin_size: int, crop_length: int):
83
114
  """
84
115
  Initialize the 1D prediction aligner.
85
116
 
86
117
  Args:
87
118
  target_size: Expected number of bins in prediction (e.g., 896 for Enformer)
88
119
  bin_size: Base pairs per bin (e.g., 128 for Enformer)
120
+ crop_length: Number of base pairs cropped from each edge by the model
89
121
  """
90
122
  self.target_size = target_size
91
123
  self.bin_size = bin_size
124
+ self.crop_length = crop_length
92
125
 
93
126
  def align_predictions(
94
127
  self,
@@ -96,6 +129,7 @@ class PredictionAligner1D:
96
129
  alt_pred: Union[np.ndarray, "torch.Tensor"],
97
130
  svtype: str,
98
131
  var_pos: VariantPosition,
132
+ window_start: int = 0,
99
133
  ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
100
134
  """
101
135
  Main entry point for 1D prediction alignment.
@@ -105,6 +139,8 @@ class PredictionAligner1D:
105
139
  alt_pred: Alternate prediction vector (length N)
106
140
  svtype: Variant type ('DEL', 'DUP', 'INS', 'INV', 'SNV')
107
141
  var_pos: Variant position information
142
+ window_start: Start position of sequence window (0-based genomic coord).
143
+ Required for correct bin calculation. Defaults to 0.
108
144
 
109
145
  Returns:
110
146
  Tuple of (aligned_ref, aligned_alt) vectors with NaN masking applied
@@ -120,10 +156,12 @@ class PredictionAligner1D:
120
156
 
121
157
  if svtype_normalized in ["DEL", "DUP", "INS"]:
122
158
  return self._align_indel_predictions(
123
- ref_pred, alt_pred, svtype_normalized, var_pos
159
+ ref_pred, alt_pred, svtype_normalized, var_pos, window_start
124
160
  )
125
161
  elif svtype_normalized == "INV":
126
- return self._align_inversion_predictions(ref_pred, alt_pred, var_pos)
162
+ return self._align_inversion_predictions(
163
+ ref_pred, alt_pred, var_pos, window_start
164
+ )
127
165
  elif svtype_normalized in ["SNV", "MNV"]:
128
166
  # SNVs don't change coordinates, direct alignment
129
167
  is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
@@ -140,18 +178,22 @@ class PredictionAligner1D:
140
178
  alt_pred: Union[np.ndarray, "torch.Tensor"],
141
179
  svtype: str,
142
180
  var_pos: VariantPosition,
181
+ window_start: int = 0,
143
182
  ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
144
183
  """
145
184
  Align predictions for insertions, deletions, and duplications.
146
185
 
147
186
  Strategy:
148
187
  1. For DEL: Swap REF/ALT (deletion removes from REF)
149
- 2. Insert NaN bins in shorter sequence
188
+ 2. Insert NaN bins in shorter sequence (centered on variant)
150
189
  3. Crop edges to maintain target size
151
190
  4. For DEL: Swap back
152
191
 
153
192
  This ensures that positions present in one sequence but not the other
154
193
  are marked with NaN, enabling fair comparison of overlapping regions.
194
+
195
+ Args:
196
+ window_start: Start position of sequence window (0-based genomic coord)
155
197
  """
156
198
  is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
157
199
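
The strategy boils down to a few numpy operations. A toy sketch for an insertion that added two bins (not the class itself, which also handles torch tensors, DEL swapping, and centered cropping):

```python
import numpy as np

target_size = 8
ref = np.arange(8, dtype=float)   # REF prediction, 8 bins
alt = np.arange(10, dtype=float)  # ALT gained 2 bins at the insertion (bin 3)

# Step 2: insert NaN bins into the shorter track where ALT has novel sequence
ref_padded = np.insert(ref, 3, [np.nan, np.nan])

# Step 3: crop one bin from each edge to return both tracks to target size
ref_aligned, alt_aligned = ref_padded[1:-1], alt[1:-1]
assert len(ref_aligned) == len(alt_aligned) == target_size
print(ref_aligned)  # [ 1.  2. nan nan  3.  4.  5.  6.]
```
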
 
@@ -170,8 +212,10 @@ class PredictionAligner1D:
170
212
  var_pos.alt_pos, var_pos.ref_pos, var_pos.svlen, svtype
171
213
  )
172
214
 
173
- # Get bin positions
174
- ref_bin, alt_start_bin, alt_end_bin = var_pos.get_bin_positions(self.bin_size)
215
+ # Get bin positions (window-relative, centered)
216
+ ref_bin, alt_start_bin, alt_end_bin = var_pos.get_bin_positions(
217
+ self.bin_size, window_start, self.crop_length
218
+ )
175
219
  bins_to_add = alt_end_bin - alt_start_bin
176
220
 
177
221
  # Insert NaN bins in REF where variant exists in ALT
@@ -248,6 +292,7 @@ class PredictionAligner1D:
248
292
  ref_pred: Union[np.ndarray, "torch.Tensor"],
249
293
  alt_pred: Union[np.ndarray, "torch.Tensor"],
250
294
  var_pos: VariantPosition,
295
+ window_start: int = 0,
251
296
  ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
252
297
  """
253
298
  Align predictions for inversions.
@@ -259,6 +304,9 @@ class PredictionAligner1D:
259
304
  For strand-aware models, inversions can significantly affect predictions
260
305
  because regulatory elements now appear on the opposite strand. We mask
261
306
  the inverted region to focus comparison on unaffected flanking sequences.
307
+
308
+ Args:
309
+ window_start: Start position of sequence window (0-based genomic coord)
262
310
  """
263
311
  is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
264
312
 
@@ -270,7 +318,9 @@ class PredictionAligner1D:
270
318
  ref_np = ref_pred.copy()
271
319
  alt_np = alt_pred.copy()
272
320
 
273
- var_start, _, var_end = var_pos.get_bin_positions(self.bin_size)
321
+ var_start, _, var_end = var_pos.get_bin_positions(
322
+ self.bin_size, window_start, self.crop_length
323
+ )
274
324
 
275
325
  # Mask inverted region in both REF and ALT
276
326
  ref_np[var_start : var_end + 1] = np.nan
@@ -373,19 +423,23 @@ class PredictionAligner2D:
373
423
  target_size: Expected matrix dimension (NxN)
374
424
  bin_size: Number of base pairs per matrix bin (model-specific)
375
425
  diag_offset: Number of diagonal bins to mask (model-specific)
426
+ crop_length: Number of base pairs cropped from each edge by the model
376
427
 
377
428
  Example:
378
429
  >>> aligner = PredictionAligner2D(
379
430
  ... target_size=448,
380
431
  ... bin_size=2048,
381
- ... diag_offset=2
432
+ ... diag_offset=2,
433
+ ... crop_length=0
382
434
  ... )
383
435
  >>> ref_aligned, alt_aligned = aligner.align_predictions(
384
436
  ... ref_matrix, alt_matrix, 'DEL', variant_position
385
437
  ... )
386
438
  """
387
439
 
388
- def __init__(self, target_size: int, bin_size: int, diag_offset: int):
440
+ def __init__(
441
+ self, target_size: int, bin_size: int, diag_offset: int, crop_length: int
442
+ ):
389
443
  """
390
444
  Initialize the 2D prediction aligner.
391
445
 
@@ -393,10 +447,12 @@ class PredictionAligner2D:
393
447
  target_size: Matrix dimension (e.g., 448 for Akita)
394
448
  bin_size: Base pairs per bin (e.g., 2048 for Akita)
395
449
  diag_offset: Diagonal masking offset (e.g., 2 for Akita)
450
+ crop_length: Number of base pairs cropped from each edge by the model
396
451
  """
397
452
  self.target_size = target_size
398
453
  self.bin_size = bin_size
399
454
  self.diag_offset = diag_offset
455
+ self.crop_length = crop_length
400
456
 
401
457
  def align_predictions(
402
458
  self,
@@ -404,6 +460,7 @@ class PredictionAligner2D:
404
460
  alt_pred: Union[np.ndarray, "torch.Tensor"],
405
461
  svtype: str,
406
462
  var_pos: VariantPosition,
463
+ window_start: int = 0,
407
464
  ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
408
465
  """
409
466
  Main entry point for 2D matrix alignment.
@@ -413,6 +470,8 @@ class PredictionAligner2D:
413
470
  alt_pred: Alternate prediction matrix (NxN)
414
471
  svtype: Variant type ('DEL', 'DUP', 'INS', 'INV', 'SNV')
415
472
  var_pos: Variant position information
473
+ window_start: Start position of sequence window (0-based genomic coord).
474
+ Required for correct bin calculation. Defaults to 0.
416
475
 
417
476
  Returns:
418
477
  Tuple of (aligned_ref, aligned_alt) matrices with NaN masking applied
@@ -428,10 +487,12 @@ class PredictionAligner2D:
428
487
 
429
488
  if svtype_normalized in ["DEL", "DUP", "INS"]:
430
489
  return self._align_indel_matrices(
431
- ref_pred, alt_pred, svtype_normalized, var_pos
490
+ ref_pred, alt_pred, svtype_normalized, var_pos, window_start
432
491
  )
433
492
  elif svtype_normalized == "INV":
434
- return self._align_inversion_matrices(ref_pred, alt_pred, var_pos)
493
+ return self._align_inversion_matrices(
494
+ ref_pred, alt_pred, var_pos, window_start
495
+ )
435
496
  elif svtype_normalized in ["SNV", "MNV"]:
436
497
  # SNVs don't change coordinates, direct alignment
437
498
  is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
@@ -448,15 +509,19 @@ class PredictionAligner2D:
448
509
  alt_pred: Union[np.ndarray, "torch.Tensor"],
449
510
  svtype: str,
450
511
  var_pos: VariantPosition,
512
+ window_start: int = 0,
451
513
  ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
452
514
  """
453
515
  Align matrices for insertions, deletions, and duplications.
454
516
 
455
517
  Strategy:
456
518
  1. For DEL: Swap REF/ALT (deletion removes from REF)
457
- 2. Insert NaN bins (rows AND columns) in shorter matrix
519
+ 2. Insert NaN bins (rows AND columns) in shorter matrix (centered on variant)
458
520
  3. Crop edges to maintain target size
459
521
  4. For DEL: Swap back
522
+
523
+ Args:
524
+ window_start: Start position of sequence window (0-based genomic coord)
460
525
  """
461
526
  is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
462
527
 
@@ -475,8 +540,10 @@ class PredictionAligner2D:
475
540
  var_pos.alt_pos, var_pos.ref_pos, var_pos.svlen, svtype
476
541
  )
477
542
 
478
- # Get bin positions
479
- ref_bin, alt_start_bin, alt_end_bin = var_pos.get_bin_positions(self.bin_size)
543
+ # Get bin positions (window-relative, centered)
544
+ ref_bin, alt_start_bin, alt_end_bin = var_pos.get_bin_positions(
545
+ self.bin_size, window_start, self.crop_length
546
+ )
480
547
  bins_to_add = alt_end_bin - alt_start_bin
481
548
 
482
549
  # Insert NaN bins in REF where variant exists in ALT
@@ -541,6 +608,7 @@ class PredictionAligner2D:
541
608
  ref_pred: Union[np.ndarray, "torch.Tensor"],
542
609
  alt_pred: Union[np.ndarray, "torch.Tensor"],
543
610
  var_pos: VariantPosition,
611
+ window_start: int = 0,
544
612
  ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
545
613
  """
546
614
  Align matrices for inversions.
@@ -556,6 +624,9 @@ class PredictionAligner2D:
556
624
 
557
625
  The same NaN pattern is mirrored to ALT so both matrices have identical
558
626
  masked regions, enabling fair comparison of the unaffected areas.
627
+
628
+ Args:
629
+ window_start: Start position of sequence window (0-based genomic coord)
559
630
  """
560
631
  is_torch = TORCH_AVAILABLE and torch.is_tensor(ref_pred)
561
632
 
@@ -567,7 +638,9 @@ class PredictionAligner2D:
567
638
  ref_np = ref_pred.copy()
568
639
  alt_np = alt_pred.copy()
569
640
 
570
- var_start, _, var_end = var_pos.get_bin_positions(self.bin_size)
641
+ var_start, _, var_end = var_pos.get_bin_positions(
642
+ self.bin_size, window_start, self.crop_length
643
+ )
571
644
 
572
645
  # Mask inverted region in REF (cross-pattern: rows + columns)
573
646
  ref_np[var_start : var_end + 1, :] = np.nan
@@ -802,6 +875,7 @@ def align_predictions_by_coordinate(
802
875
  metadata_row: dict,
803
876
  bin_size: int,
804
877
  prediction_type: str,
878
+ crop_length: int,
805
879
  matrix_size: Optional[int] = None,
806
880
  diag_offset: int = 0,
807
881
  ) -> Tuple[Union[np.ndarray, "torch.Tensor"], Union[np.ndarray, "torch.Tensor"]]:
@@ -812,7 +886,7 @@ def align_predictions_by_coordinate(
812
886
  vectors (e.g., chromatin accessibility, TF binding) and 2D matrices (e.g., Hi-C contact maps),
813
887
  routing to the appropriate alignment strategy based on variant type.
814
888
 
815
- IMPORTANT: Model-specific parameters (bin_size, matrix_size) must be explicitly
889
+ IMPORTANT: Model-specific parameters (bin_size, crop_length, matrix_size) must be explicitly
816
890
  provided by the user. There are no defaults because these vary across different models.
817
891
 
818
892
  Args:
@@ -824,10 +898,13 @@ def align_predictions_by_coordinate(
824
898
  - 'variant_pos0': Variant position (0-based, absolute genomic coordinate)
825
899
  - 'svlen': Length of structural variant (optional, for symbolic alleles)
826
900
  bin_size: Number of base pairs per prediction bin (REQUIRED, model-specific)
827
- Examples: 2048 for Akita
901
+ Examples: 2048 for Akita, 128 for Enformer
828
902
  prediction_type: Type of predictions ("1D" or "2D")
829
903
  - "1D": Vector predictions (chromatin accessibility, TF binding, etc.)
830
904
  - "2D": Matrix predictions (Hi-C contact maps, Micro-C, etc.)
905
+ crop_length: Number of base pairs cropped from each edge by the model (REQUIRED)
906
+ This accounts for edge bases removed before prediction.
907
+ Examples: 0 for models without cropping
831
908
  matrix_size: Size of contact matrix (REQUIRED for 2D type)
832
909
  Examples: 448 for Akita
833
910
  diag_offset: Number of diagonal bins to mask (default: 0 for no masking)
@@ -849,7 +926,8 @@ def align_predictions_by_coordinate(
849
926
  ... metadata_row={'variant_type': 'INS', 'window_start': 0,
850
927
  ... 'variant_pos0': 500, 'svlen': 100},
851
928
  ... bin_size=128,
852
- ... prediction_type="1D"
929
+ ... prediction_type="1D",
930
+ ... crop_length=0
853
931
  ... )
854
932
 
855
933
  Example (2D contact maps with diagonal masking):
@@ -860,6 +938,7 @@ def align_predictions_by_coordinate(
860
938
  ... 'variant_pos0': 50000, 'svlen': -2048},
861
939
  ... bin_size=2048,
862
940
  ... prediction_type="2D",
941
+ ... crop_length=0,
863
942
  ... matrix_size=448,
864
943
  ... diag_offset=2 # Optional: use 0 if no diagonal masking
865
944
  ... )
@@ -872,6 +951,7 @@ def align_predictions_by_coordinate(
872
951
  ... 'variant_pos0': 1000, 'svlen': 500},
873
952
  ... bin_size=1000,
874
953
  ... prediction_type="2D",
954
+ ... crop_length=0,
875
955
  ... matrix_size=512
876
956
  ... # diag_offset defaults to 0 (no masking)
877
957
  ... )
@@ -937,7 +1017,9 @@ def align_predictions_by_coordinate(
937
1017
  # Handle multi-target predictions [n_targets, n_bins]
938
1018
  if ndim > 1:
939
1019
  target_size = ref_preds.shape[-1] # Number of bins
940
- aligner = PredictionAligner1D(target_size=target_size, bin_size=bin_size)
1020
+ aligner = PredictionAligner1D(
1021
+ target_size=target_size, bin_size=bin_size, crop_length=crop_length
1022
+ )
941
1023
 
942
1024
  # Align each target separately
943
1025
  n_targets = ref_preds.shape[0]
@@ -948,7 +1030,7 @@ def align_predictions_by_coordinate(
948
1030
  ref_target = ref_preds[target_idx]
949
1031
  alt_target = alt_preds[target_idx]
950
1032
  ref_aligned, alt_aligned = aligner.align_predictions(
951
- ref_target, alt_target, variant_type, var_pos
1033
+ ref_target, alt_target, variant_type, var_pos, window_start
952
1034
  )
953
1035
  ref_aligned_list.append(ref_aligned)
954
1036
  alt_aligned_list.append(alt_aligned)
@@ -965,9 +1047,11 @@ def align_predictions_by_coordinate(
965
1047
  else:
966
1048
  # Single target prediction [n_bins]
967
1049
  target_size = len(ref_preds)
968
- aligner = PredictionAligner1D(target_size=target_size, bin_size=bin_size)
1050
+ aligner = PredictionAligner1D(
1051
+ target_size=target_size, bin_size=bin_size, crop_length=crop_length
1052
+ )
969
1053
  return aligner.align_predictions(
970
- ref_preds, alt_preds, variant_type, var_pos
1054
+ ref_preds, alt_preds, variant_type, var_pos, window_start
971
1055
  )
972
1056
  else: # 2D
973
1057
  # Check if predictions are 1D (flattened upper triangular) or 2D (full matrix)
@@ -989,10 +1073,13 @@ def align_predictions_by_coordinate(
989
1073
 
990
1074
  # Align matrices
991
1075
  aligner = PredictionAligner2D(
992
- target_size=matrix_size, bin_size=bin_size, diag_offset=diag_offset
1076
+ target_size=matrix_size,
1077
+ bin_size=bin_size,
1078
+ diag_offset=diag_offset,
1079
+ crop_length=crop_length,
993
1080
  )
994
1081
  aligned_ref_matrix, aligned_alt_matrix = aligner.align_predictions(
995
- ref_matrix, alt_matrix, variant_type, var_pos
1082
+ ref_matrix, alt_matrix, variant_type, var_pos, window_start
996
1083
  )
997
1084
 
998
1085
  # Convert back to flattened format
@@ -1007,8 +1094,11 @@ def align_predictions_by_coordinate(
1007
1094
  else:
1008
1095
  # Already 2D matrices
1009
1096
  aligner = PredictionAligner2D(
1010
- target_size=matrix_size, bin_size=bin_size, diag_offset=diag_offset
1097
+ target_size=matrix_size,
1098
+ bin_size=bin_size,
1099
+ diag_offset=diag_offset,
1100
+ crop_length=crop_length,
1011
1101
  )
1012
1102
  return aligner.align_predictions(
1013
- ref_preds, alt_preds, variant_type, var_pos
1103
+ ref_preds, alt_preds, variant_type, var_pos, window_start
1014
1104
  )
@@ -5,6 +5,7 @@ This module provides functions for reading variants from VCF files
5
5
  and other related operations.
6
6
  """
7
7
 
8
+ import gzip
8
9
  import io
9
10
  import pandas as pd
10
11
  import numpy as np
@@ -14,6 +15,22 @@ from typing import Dict, Optional, List, Tuple, Union
14
15
  from dataclasses import dataclass
15
16
 
16
17
 
18
+ def _open_vcf(path: str, mode: str = "rt"):
19
+ """
20
+ Open a VCF file, automatically detecting gzip compression.
21
+
22
+ Args:
23
+ path: Path to VCF file (may be .vcf or .vcf.gz)
24
+ mode: File mode. Use 'rt' for text reading (default).
25
+
26
+ Returns:
27
+ File handle (context manager compatible)
28
+ """
29
+ if path.endswith(".gz"):
30
+ return gzip.open(path, mode)
31
+ return open(path, mode.replace("t", "") if "t" in mode else mode)
32
+
33
+
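
A quick self-contained check, assuming the module-private `_open_vcf` above is importable:

```python
import gzip

# Write a minimal compressed VCF header, then read it back transparently
with gzip.open("demo.vcf.gz", "wt") as f:
    f.write("##fileformat=VCFv4.2\n#CHROM\tPOS\tID\tREF\tALT\n")

with _open_vcf("demo.vcf.gz") as f:
    print(f.readline().rstrip())  # ##fileformat=VCFv4.2
```
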
17
34
  @dataclass
18
35
  class BreakendVariant:
19
36
  """
@@ -625,13 +642,15 @@ def _count_vcf_header_lines(path: str) -> int:
625
642
  - Lines starting with ## (metadata)
626
643
  - Line starting with #CHROM (column header)
627
644
 
645
+ Supports both uncompressed (.vcf) and gzip-compressed (.vcf.gz) files.
646
+
628
647
  Args:
629
648
  path: Path to VCF file
630
649
 
631
650
  Returns:
632
651
  Number of lines to skip (all ## lines + the #CHROM line)
633
652
  """
634
- with open(path, "r") as f:
653
+ with _open_vcf(path, "rt") as f:
635
654
  header_count = 0
636
655
  for line in f:
637
656
  if line.startswith("##"):
@@ -648,6 +667,8 @@ def read_vcf(path, include_info=True, classify_variants=True):
648
667
  """
649
668
  Read VCF file into pandas DataFrame with enhanced variant classification.
650
669
 
670
+ Supports both uncompressed (.vcf) and gzip-compressed (.vcf.gz) files.
671
+
651
672
  Args:
652
673
  path: Path to VCF file
653
674
  include_info: Whether to include INFO field (default: True)
@@ -656,11 +677,21 @@ def read_vcf(path, include_info=True, classify_variants=True):
656
677
  Returns:
657
678
  DataFrame with columns: chrom, pos1, id, ref, alt, [info], [variant_type]
658
679
 
680
+ Raises:
681
+ FileNotFoundError: If VCF file does not exist
682
+ ValueError: If VCF file has invalid format or no valid header
683
+
659
684
  Notes:
660
685
  - INFO field parsing enables structural variant classification
661
686
  - variant_type column uses VCF 4.2 compliant classification
662
687
  - Compatible with existing code expecting basic 5-column format
663
688
  """
689
+ import os
690
+
691
+ # Validate file exists
692
+ if not os.path.exists(path):
693
+ raise FileNotFoundError(f"VCF file not found: {path}")
694
+
664
695
  # Determine columns to read based on parameters
665
696
  if include_info:
666
697
  usecols = [0, 1, 2, 3, 4, 7] # Include INFO field
@@ -670,12 +701,38 @@ def read_vcf(path, include_info=True, classify_variants=True):
670
701
  base_columns = ["chrom", "pos1", "id", "ref", "alt"]
671
702
 
672
703
  # Count header lines for VCF line tracking (needed for vcf_line column)
673
- header_count = _count_vcf_header_lines(path)
704
+ try:
705
+ header_count = _count_vcf_header_lines(path)
706
+ except Exception as e:
707
+ raise ValueError(f"Failed to parse VCF header in {path}: {e}")
708
+
709
+ if header_count == 0:
710
+ raise ValueError(
711
+ f"VCF file {path} appears to have no header lines. "
712
+ "Valid VCF files must start with ##fileformat or #CHROM header."
713
+ )
674
714
 
675
715
  # Read VCF using pandas with comment='#' to skip all header lines automatically
676
- df = pd.read_table(
677
- path, comment="#", header=None, names=base_columns, usecols=usecols
678
- )
716
+ try:
717
+ df = pd.read_table(
718
+ path,
719
+ comment="#",
720
+ header=None,
721
+ names=base_columns,
722
+ usecols=usecols,
723
+ on_bad_lines="warn",
724
+ )
725
+ except pd.errors.EmptyDataError:
726
+ warnings.warn(f"VCF file {path} contains no data rows after header.")
727
+ empty_cols = base_columns + (["variant_type"] if classify_variants else [])
728
+ return pd.DataFrame(columns=empty_cols)
729
+
730
+ # Handle empty DataFrame
731
+ if len(df) == 0:
732
+ warnings.warn(f"VCF file {path} contains no variant records.")
733
+ if classify_variants:
734
+ df["variant_type"] = pd.Series(dtype=str)
735
+ return df
679
736
 
680
737
  # Add VCF line numbers for debugging (1-indexed, accounting for header lines)
681
738
  # Line number = header_lines + 1 (for 1-indexing) + row_index
@@ -683,9 +740,22 @@ def read_vcf(path, include_info=True, classify_variants=True):
683
740
 
684
741
  # Validate that pos1 column is numeric
685
742
  if not pd.api.types.is_numeric_dtype(df["pos1"]):
686
- raise ValueError(
687
- f"Position column (second column) must be numeric, got {df['pos1'].dtype}"
688
- )
743
+ # Try to convert, providing helpful error message
744
+ try:
745
+ df["pos1"] = pd.to_numeric(df["pos1"], errors="coerce")
746
+ invalid_rows = df[df["pos1"].isna()]
747
+ if len(invalid_rows) > 0:
748
+ warnings.warn(
749
+ f"Found {len(invalid_rows)} rows with non-numeric positions in {path}. "
750
+ f"First invalid at VCF line {invalid_rows.iloc[0]['vcf_line']}. "
751
+ "These rows will be removed."
752
+ )
753
+ df = df.dropna(subset=["pos1"])
754
+ df["pos1"] = df["pos1"].astype(int)
755
+ except Exception as e:
756
+ raise ValueError(
757
+ f"Position column must be numeric in {path}, conversion failed: {e}"
758
+ )
689
759
 
690
760
  # Filter out multiallelic variants (ALT alleles containing commas)
691
761
  df = _filter_multiallelic_variants(df)
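
Putting the reader together, a hedged usage sketch; the top-level re-export of `read_vcf` is an assumption, so adjust the import to wherever this VCF module lives in the package:

```python
import supremo_lite as sl  # assumes read_vcf is re-exported at the top level

vcf_text = (
    "##fileformat=VCFv4.2\n"
    "#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\n"
    "chr1\t100\trs1\tA\tG\t.\t.\t.\n"
)
with open("demo.vcf", "w") as f:
    f.write(vcf_text)

df = sl.read_vcf("demo.vcf")  # a demo.vcf.gz would be handled the same way
print(df[["chrom", "pos1", "ref", "alt", "variant_type"]])
```
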
@@ -775,6 +845,8 @@ def get_vcf_chromosomes(path):
775
845
  """
776
846
  Get list of chromosomes in VCF file without loading all variants.
777
847
 
848
+ Supports both uncompressed (.vcf) and gzip-compressed (.vcf.gz) files.
849
+
778
850
  Args:
779
851
  path: Path to VCF file
780
852
 
@@ -782,7 +854,7 @@ def get_vcf_chromosomes(path):
782
854
  Set of chromosome names found in the VCF file
783
855
  """
784
856
  chromosomes = set()
785
- with open(path, "r") as f:
857
+ with _open_vcf(path, "rt") as f:
786
858
  for line in f:
787
859
  if line.startswith("##"):
788
860
  continue
@@ -800,6 +872,8 @@ def read_vcf_chromosome(
800
872
  """
801
873
  Read VCF file for a specific chromosome only with enhanced variant classification.
802
874
 
875
+ Supports both uncompressed (.vcf) and gzip-compressed (.vcf.gz) files.
876
+
803
877
  Args:
804
878
  path: Path to VCF file
805
879
  target_chromosome: Chromosome name to filter for
@@ -813,7 +887,7 @@ def read_vcf_chromosome(
813
887
  chromosome_lines = []
814
888
  header_line = None
815
889
 
816
- with open(path, "r") as f:
890
+ with _open_vcf(path, "rt") as f:
817
891
  for line in f:
818
892
  if line.startswith("##"):
819
893
  continue
File without changes