supremo-lite 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- supremo_lite/__init__.py +59 -0
- supremo_lite/chromosome_utils.py +322 -0
- supremo_lite/core.py +41 -0
- supremo_lite/mock_models/__init__.py +110 -0
- supremo_lite/mock_models/testmodel_1d.py +184 -0
- supremo_lite/mock_models/testmodel_2d.py +203 -0
- supremo_lite/mutagenesis.py +414 -0
- supremo_lite/personalize.py +3098 -0
- supremo_lite/prediction_alignment.py +1014 -0
- supremo_lite/sequence_utils.py +137 -0
- supremo_lite/variant_utils.py +1645 -0
- supremo_lite-0.5.4.dist-info/METADATA +216 -0
- supremo_lite-0.5.4.dist-info/RECORD +15 -0
- supremo_lite-0.5.4.dist-info/WHEEL +4 -0
- supremo_lite-0.5.4.dist-info/licenses/LICENSE +22 -0
supremo_lite/__init__.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"""
|
|
2
|
+
supremo_lite: A module for generating personalized genome sequences from a reference
|
|
3
|
+
fasta and a variants file, or sequences for in-silico mutagenesis.
|
|
4
|
+
|
|
5
|
+
This package provides functionality for:
|
|
6
|
+
- Sequence encoding and transformation
|
|
7
|
+
- Variant reading and application
|
|
8
|
+
- In-silico mutagenesis
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
# Import core components
|
|
12
|
+
from .core import TORCH_AVAILABLE, BRISKET_AVAILABLE, nt_to_1h, nts
|
|
13
|
+
|
|
14
|
+
# Import sequence transformation utilities
|
|
15
|
+
from .sequence_utils import encode_seq, decode_seq, rc, rc_str
|
|
16
|
+
|
|
17
|
+
# Import variant reading utilities
|
|
18
|
+
from .variant_utils import (
|
|
19
|
+
read_vcf,
|
|
20
|
+
read_vcf_chunked,
|
|
21
|
+
get_vcf_chromosomes,
|
|
22
|
+
read_vcf_chromosome,
|
|
23
|
+
classify_variant_type,
|
|
24
|
+
parse_vcf_info,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
# Import chromosome matching utilities
|
|
28
|
+
from .chromosome_utils import (
|
|
29
|
+
normalize_chromosome_name,
|
|
30
|
+
create_chromosome_mapping,
|
|
31
|
+
match_chromosomes_with_report,
|
|
32
|
+
ChromosomeMismatchError,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
# Import personalize functions
|
|
36
|
+
from .personalize import (
|
|
37
|
+
get_personal_genome,
|
|
38
|
+
get_alt_sequences,
|
|
39
|
+
get_ref_sequences,
|
|
40
|
+
get_pam_disrupting_alt_sequences,
|
|
41
|
+
get_alt_ref_sequences,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
# Import mutagenesis functions
|
|
45
|
+
from .mutagenesis import get_sm_sequences, get_sm_subsequences
|
|
46
|
+
|
|
47
|
+
# Import prediction alignment functions
|
|
48
|
+
from .prediction_alignment import align_predictions_by_coordinate
|
|
49
|
+
|
|
50
|
+
# Mock models are available in a separate submodule
|
|
51
|
+
# Import with: from supremo_lite.mock_models import TestModel, TestModel2D
|
|
52
|
+
# This allows users who don't have PyTorch to still use the main package
|
|
53
|
+
|
|
54
|
+
# Package version string (kept in sync with the wheel metadata).
__version__ = "0.5.4"

# One-line summary of what the package does.
__description__ = (
    "A module for generating personalized genome sequences "
    "and in-silico mutagenesis"
)
|
|
@@ -0,0 +1,322 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Chromosome name matching utilities for supremo_lite.
|
|
3
|
+
|
|
4
|
+
This module provides functions for handling mismatches in chromosome naming
|
|
5
|
+
between FASTA references and VCF files using intelligent heuristics.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import re
|
|
9
|
+
import warnings
|
|
10
|
+
from typing import Dict, Set, Optional, List, Tuple
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ChromosomeMismatchError(Exception):
    """
    Signal that chromosome names in the VCF and the reference disagree.

    This is the default outcome whenever the names are not an exact match and
    automatic chromosome name mapping has not been enabled.
    """
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def normalize_chromosome_name(chrom_name: str) -> str:
    """
    Reduce a chromosome name to a canonical form for fuzzy comparison.

    The value is converted with str() and stripped of surrounding whitespace.
    A leading 'chr' prefix is then removed regardless of case, and the result
    is upper-cased. Finally, the mitochondrial aliases 'M', 'MITO' and
    'MITOCHONDRION' are collapsed to 'MT'.

    Args:
        chrom_name: Raw chromosome name from VCF or FASTA (non-string values
            are converted with str()).

    Returns:
        Canonical chromosome name (no 'chr' prefix, upper-case).

    Examples:
        'chr1' -> '1'
        'CHR1' -> '1'
        'chrX' -> 'X'
        'chrMT' -> 'MT'
        'M' -> 'MT'  # Mitochondrial normalization
    """
    without_prefix = re.sub(r"^chr", "", str(chrom_name).strip(), flags=re.IGNORECASE)
    canonical = without_prefix.upper()

    # Different assemblies spell the mitochondrial contig differently.
    if canonical in ("M", "MITO", "MITOCHONDRION"):
        return "MT"
    return canonical
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def create_chromosome_mapping(
    reference_chroms: Set[str], vcf_chroms: Set[str]
) -> Tuple[Dict[str, str], Set[str]]:
    """
    Create a mapping from VCF chromosome names to reference chromosome names.

    Each VCF chromosome is matched against the reference with these heuristics.
    They are tried in order, and the first one that succeeds wins:

    1. Exact match (case sensitive)
    2. Exact match (case insensitive)
    3. Removing or adding a 'chr' prefix (case sensitive, then case insensitive)
    4. Normalized match via normalize_chromosome_name (strips 'chr' in any case
       and collapses mitochondrial aliases such as 'M'/'MITO' to 'MT')

    Several reference names can qualify for the same heuristic (e.g. both
    'chrX' and 'CHRX' are present). In that case the smallest one in sorted
    order is chosen, so the result does not depend on set iteration order.

    Args:
        reference_chroms: Set of chromosome names from reference FASTA
        vcf_chroms: Set of chromosome names from VCF file. Non-string names
            (e.g. integer chromosomes from a DataFrame) are compared through
            str(); the original value is kept as the mapping key.

    Returns:
        Tuple of (mapping dict, unmatched set). The mapping goes from VCF name
        to reference name. The unmatched set holds the VCF names for which no
        heuristic found a reference match.

    Example:
        >>> mapping, unmatched = create_chromosome_mapping(
        ...     {'1', '2', 'X', 'Y', 'MT'},
        ...     {'chr1', 'chr2', 'chrX', 'chrY', 'chrM'},
        ... )
        >>> mapping == {'chr1': '1', 'chr2': '2', 'chrX': 'X', 'chrY': 'Y', 'chrM': 'MT'}
        True
        >>> unmatched
        set()
    """
    # Build the case-insensitive and normalized lookup tables once, instead of
    # rescanning (and re-normalizing) every reference name for each VCF name.
    # Walking the reference in sorted order and keeping only the first entry
    # per key makes ties deterministic; plain set iteration order of strings
    # changes between interpreter runs because of hash randomization.
    by_lower: Dict[str, str] = {}
    by_normalized: Dict[str, str] = {}
    for ref_chrom in sorted(reference_chroms, key=str):
        by_lower.setdefault(str(ref_chrom).lower(), ref_chrom)
        by_normalized.setdefault(normalize_chromosome_name(ref_chrom), ref_chrom)

    def _lookup(candidate: str) -> Optional[str]:
        # Prefer a case-sensitive hit, then fall back to a case-insensitive one.
        if candidate in reference_chroms:
            return candidate
        return by_lower.get(candidate.lower())

    mapping: Dict[str, str] = {}
    unmatched_vcf: Set[str] = set()

    for vcf_chrom in vcf_chroms:
        # Compare on the string form so non-string names do not crash on
        # .lower(). Keep vcf_chrom itself as the key so that
        # apply_chromosome_mapping can still find it in the variants table.
        name = str(vcf_chrom)

        # 1-2. Exact match, case sensitive then case insensitive
        matched_ref = vcf_chrom if vcf_chrom in reference_chroms else _lookup(name)

        # 3. Remove the 'chr' prefix if present, otherwise try adding it
        if matched_ref is None:
            if name.lower().startswith("chr"):
                matched_ref = _lookup(name[3:])
            else:
                matched_ref = _lookup(f"chr{name}")

        # 4. Normalized match (handles mitochondrial variants)
        if matched_ref is None:
            matched_ref = by_normalized.get(normalize_chromosome_name(vcf_chrom))

        # Record result
        if matched_ref is None:
            unmatched_vcf.add(vcf_chrom)
        else:
            mapping[vcf_chrom] = matched_ref

    return mapping, unmatched_vcf
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def apply_chromosome_mapping(variants_df, mapping: Dict[str, str]):
    """
    Return a copy of a variants table with its chromosome names translated.

    Args:
        variants_df: Pandas DataFrame with a 'chrom' column
        mapping: Dictionary from original chromosome name to replacement name.
            Names absent from the mapping are left unchanged.

    Returns:
        A new DataFrame; the input DataFrame is not modified.
    """

    def _rename(chrom):
        # Fall back to the original name when no translation is known.
        return mapping.get(chrom, chrom)

    renamed_df = variants_df.copy()
    renamed_df["chrom"] = renamed_df["chrom"].map(_rename)
    return renamed_df
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def get_chromosome_match_report(
    reference_chroms: Set[str],
    vcf_chroms: Set[str],
    mapping: Dict[str, str],
    unmatched: Set[str],
) -> str:
    """
    Render the outcome of chromosome matching as a human-readable report.

    Args:
        reference_chroms: Set of reference chromosome names
        vcf_chroms: Set of VCF chromosome names
        mapping: Successful mappings (VCF name -> reference name)
        unmatched: Unmatched VCF chromosomes

    Returns:
        A newline-joined report. It contains a header and the sorted reference
        and VCF chromosome lists. It then lists each successful mapping, shown
        as "(exact match)" when the name is unchanged, and any unmatched VCF
        chromosomes. It ends with the percentage of VCF chromosomes that were
        matched (100% when there are no VCF chromosomes).
    """
    n_vcf = len(vcf_chroms)
    n_mapped = len(mapping)

    lines = [
        "Chromosome Matching Report",
        "=" * 40,
        f"Reference chromosomes ({len(reference_chroms)}): {sorted(reference_chroms)}",
        f"VCF chromosomes ({n_vcf}): {sorted(vcf_chroms)}",
        "",
    ]

    if mapping:
        lines.append(f"Successfully matched ({n_mapped}):")
        for vcf_name, ref_name in sorted(mapping.items()):
            lines.append(
                f" '{vcf_name}' -> '{ref_name}'"
                if vcf_name != ref_name
                else f" '{vcf_name}' (exact match)"
            )

    if unmatched:
        lines += ["", f"Unmatched VCF chromosomes ({len(unmatched)}):"]
        lines += [
            f" '{name}' (no suitable reference match found)"
            for name in sorted(unmatched)
        ]

    # An empty set of VCF chromosomes counts as fully covered.
    coverage = 100 if not vcf_chroms else n_mapped / n_vcf * 100
    lines += ["", f"Matching coverage: {coverage:.1f}% ({n_mapped}/{n_vcf})"]

    return "\n".join(lines)
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def match_chromosomes_with_report(
    reference_chroms: Set[str],
    vcf_chroms: Set[str],
    verbose: bool = True,
    auto_map_chromosomes: bool = False,
) -> Tuple[Dict[str, str], Set[str]]:
    """
    Match chromosomes and optionally print a detailed report.

    If every VCF chromosome already appears verbatim in the reference, an
    identity mapping is returned immediately, with no report and no warning.
    Otherwise the behavior depends on ``auto_map_chromosomes``.

    Args:
        reference_chroms: Set of reference chromosome names
        vcf_chroms: Set of VCF chromosome names
        verbose: Whether to print the matching report. The report is printed
            only when automatic mapping ran and either renamed at least one
            chromosome or left some unmatched.
        auto_map_chromosomes: Whether to automatically map chromosome names when
            they don't match exactly (default: False). When False, raises
            ChromosomeMismatchError if names don't match.

    Returns:
        Tuple of (mapping dict, unmatched set). When auto-mapping leaves VCF
        chromosomes unmatched, a warning listing them is also emitted via
        warnings.warn().

    Raises:
        ChromosomeMismatchError: If auto_map_chromosomes=False and chromosome
            names don't match exactly between VCF and reference
    """

    def _join_sorted(names) -> str:
        # str() each name first, so that non-string chromosomes (e.g. integers
        # from a DataFrame) do not raise TypeError in sorted()/str.join and hide
        # the intended error or warning.
        return ", ".join(sorted(str(name) for name in names))

    # Check for exact matches first
    exact_matches = reference_chroms & vcf_chroms
    needs_mapping = vcf_chroms - exact_matches

    # If all chromosomes match exactly, no mapping needed
    if not needs_mapping:
        mapping = {chrom: chrom for chrom in vcf_chroms}
        return mapping, set()

    # If mapping is needed but not enabled, raise error
    if not auto_map_chromosomes:
        vcf_list = _join_sorted(vcf_chroms)
        ref_list = _join_sorted(reference_chroms)

        error_msg = (
            "Chromosome names in VCF and reference do not match.\n\n"
            f"VCF chromosomes: {vcf_list}\n"
            f"Reference chromosomes: {ref_list}\n\n"
            "To enable automatic chromosome name mapping, add auto_map_chromosomes=True:\n"
            " \n"
            " get_personal_genome(..., auto_map_chromosomes=True)\n\n"
            "Alternatively, rename chromosomes in your VCF to match the reference."
        )
        raise ChromosomeMismatchError(error_msg)

    # Automatic mapping is enabled - use heuristics
    mapping, unmatched = create_chromosome_mapping(reference_chroms, vcf_chroms)

    # Only report when something noteworthy happened: a rename or a miss.
    if verbose and (
        len(mapping) < len(vcf_chroms) or any(k != v for k, v in mapping.items())
    ):
        report = get_chromosome_match_report(
            reference_chroms, vcf_chroms, mapping, unmatched
        )
        print(report)

    if unmatched:
        chrom_list = _join_sorted(unmatched)
        warnings.warn(
            f"Skipped {len(unmatched)} chromosome(s) not in reference: {chrom_list}"
        )

    return mapping, unmatched
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def validate_chromosomes_early(reference, variants_fn):
    """
    Efficiently validate chromosome compatibility before loading all variant data.

    This function optimizes chromosome checking by:
    - For VCF file paths (str or os.PathLike, e.g. pathlib.Path): reading only
      the chromosome names via get_vcf_chromosomes
    - For DataFrames (or any object whose ["chrom"] item supports .unique()):
      using the existing data without reloading
    - Returning chromosome sets for reuse in subsequent mapping operations

    Args:
        reference: Reference genome (dict-like object with .keys())
        variants_fn: VCF file path (str or os.PathLike) or DataFrame with variants

    Returns:
        Tuple of (ref_chroms, vcf_chroms) as sets

    Note:
        This function does NOT raise errors or perform mapping - it only extracts
        chromosome names efficiently. Use match_chromosomes_with_report() for
        actual validation and mapping.
    """
    import os

    # Get reference chromosomes
    ref_chroms = set(reference.keys())

    if isinstance(variants_fn, (str, os.PathLike)):
        # File path: the VCF reader is imported lazily because only this
        # branch needs it. os.fspath turns Path objects into plain strings.
        from .variant_utils import get_vcf_chromosomes

        vcf_chroms = set(get_vcf_chromosomes(os.fspath(variants_fn)))
    else:
        # DataFrame (or other object exposing a 'chrom' column): reuse the
        # already-loaded data instead of reading anything from disk.
        vcf_chroms = set(variants_fn["chrom"].unique())

    return ref_chroms, vcf_chroms
|
supremo_lite/core.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Core utilities, constants and common functions for supremo_lite.
|
|
3
|
+
|
|
4
|
+
This module provides the basic constants and utility functions used throughout
|
|
5
|
+
the package.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from collections import defaultdict
|
|
10
|
+
import warnings
|
|
11
|
+
|
|
12
|
+
# Optional dependency: PyTorch. When it is missing, numpy arrays are returned
# instead of tensors.
try:
    import torch
except ImportError:
    TORCH_AVAILABLE = False
    warnings.warn("PyTorch not found. Will return numpy arrays instead of tensors.")
else:
    TORCH_AVAILABLE = True

# Optional dependency: brisket. When it is missing, a slower sequence encoding
# implementation is used.
try:
    import brisket
except ImportError:
    BRISKET_AVAILABLE = False
    warnings.warn("Brisket not found. Using slower sequence encoding implementation.")
else:
    BRISKET_AVAILABLE = True
|
|
29
|
+
|
|
30
|
+
# Nucleotide -> one-hot lookup table in A, C, G, T column order, covering both
# upper- and lower-case bases. Any other character (e.g. 'N') falls back to an
# all-zero vector through the defaultdict factory.
# NOTE: being a defaultdict, looking up an unknown character also inserts it.
nt_to_1h = defaultdict(lambda: np.array([0, 0, 0, 0]))
for _position, _base in enumerate("ACGT"):
    _onehot = [0, 0, 0, 0]
    _onehot[_position] = 1
    # Separate array objects for each case, as with explicit per-key entries.
    nt_to_1h[_base] = np.array(_onehot)
    nt_to_1h[_base.lower()] = np.array(_onehot)
del _position, _base, _onehot

# Nucleotide alphabet in the same order as the one-hot columns.
nts = np.array(list("ACGT"))
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Mock models for testing and demonstration purposes.
|
|
3
|
+
|
|
4
|
+
This module provides simple PyTorch models that mimic realistic genomic deep learning
|
|
5
|
+
architectures without requiring actual training. These models are intended for:
|
|
6
|
+
|
|
7
|
+
1. **Testing**: Verifying that prediction alignment functions work correctly with
|
|
8
|
+
realistic model outputs (binned predictions, edge cropping, diagonal masking)
|
|
9
|
+
|
|
10
|
+
2. **Documentation**: Providing immediately runnable examples for users who want to
|
|
11
|
+
understand the package workflow without training their own models
|
|
12
|
+
|
|
13
|
+
**Important**: These models return constant values and should NOT be used for actual
|
|
14
|
+
genomic predictions or biological interpretation.
|
|
15
|
+
|
|
16
|
+
Available Models
|
|
17
|
+
----------------
|
|
18
|
+
TestModel : nn.Module
|
|
19
|
+
Mock 1D genomic prediction model
|
|
20
|
+
- Output shape: (batch_size, n_targets, n_final_bins)
|
|
21
|
+
- Features: binning, edge cropping
|
|
22
|
+
|
|
23
|
+
TestModel2D : nn.Module
|
|
24
|
+
Mock 2D contact map prediction model
|
|
25
|
+
- Output shape: (batch_size, n_targets, n_flattened_ut_bins)
|
|
26
|
+
- Features: binning, edge cropping, diagonal masking, flattened output
|
|
27
|
+
|
|
28
|
+
Examples
|
|
29
|
+
--------
|
|
30
|
+
Using TestModel for 1D predictions:
|
|
31
|
+
|
|
32
|
+
>>> from supremo_lite.mock_models import TestModel, TORCH_AVAILABLE
|
|
33
|
+
>>> if TORCH_AVAILABLE:
|
|
34
|
+
... import torch
|
|
35
|
+
... model = TestModel(seq_length=1024, bin_length=32, crop_length=128)
|
|
36
|
+
... x = torch.randn(4, 4, 1024)
|
|
37
|
+
... predictions = model(x)
|
|
38
|
+
... print(predictions.shape)
|
|
39
|
+
torch.Size([4, 1, 24])
|
|
40
|
+
|
|
41
|
+
Using TestModel2D for contact maps:
|
|
42
|
+
|
|
43
|
+
>>> from supremo_lite.mock_models import TestModel2D
|
|
44
|
+
>>> if TORCH_AVAILABLE:
|
|
45
|
+
... import torch
|
|
46
|
+
... model = TestModel2D(seq_length=2048, bin_length=64, crop_length=256)
|
|
47
|
+
... x = torch.randn(4, 4, 2048)
|
|
48
|
+
... predictions = model(x)
|
|
49
|
+
... print(predictions.shape)
|
|
50
|
+
torch.Size([4, 1, 276])
|
|
51
|
+
|
|
52
|
+
Checking PyTorch Availability
|
|
53
|
+
------------------------------
|
|
54
|
+
>>> from supremo_lite.mock_models import TORCH_AVAILABLE
|
|
55
|
+
>>> if not TORCH_AVAILABLE:
|
|
56
|
+
... print("Please install PyTorch to use mock models")
|
|
57
|
+
|
|
58
|
+
Notes
|
|
59
|
+
-----
|
|
60
|
+
- Requires PyTorch to be installed
|
|
61
|
+
- If PyTorch is not available, attempting to instantiate models will raise ImportError
|
|
62
|
+
- Check TORCH_AVAILABLE before using models
|
|
63
|
+
- See individual model documentation for architecture details
|
|
64
|
+
"""
|
|
65
|
+
|
|
66
|
+
try:
    from .testmodel_1d import TestModel, TORCH_AVAILABLE as TORCH_AVAILABLE_1D
    from .testmodel_2d import TestModel2D, TORCH_AVAILABLE as TORCH_AVAILABLE_2D
except ImportError as import_error:
    # Fallback when the model submodules cannot be imported. Emit an
    # ImportWarning and expose stand-in classes that only fail when someone
    # tries to instantiate them.
    # NOTE(review): ImportWarning is ignored by Python's default warning
    # filters, so most users will not see this message.
    import warnings

    warnings.warn(
        f"Could not import mock models: {import_error}\n"
        "Mock models require PyTorch. Install with: pip install torch",
        ImportWarning,
    )

    class TestModel:
        """TestModel requires PyTorch. Please install with: pip install torch"""

        def __init__(self, *args, **kwargs):
            raise ImportError(
                "TestModel requires PyTorch. Install with: pip install torch\n"
                "See https://pytorch.org/get-started/locally/"
            )

    class TestModel2D:
        """TestModel2D requires PyTorch. Please install with: pip install torch"""

        def __init__(self, *args, **kwargs):
            raise ImportError(
                "TestModel2D requires PyTorch. Install with: pip install torch\n"
                "See https://pytorch.org/get-started/locally/"
            )

    TORCH_AVAILABLE = False
else:
    # Each submodule reports its own PyTorch check; require both to agree.
    TORCH_AVAILABLE = TORCH_AVAILABLE_1D and TORCH_AVAILABLE_2D


__all__ = [
    "TestModel",
    "TestModel2D",
    "TORCH_AVAILABLE",
]
|