supremo-lite 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,137 @@
1
+ """
2
+ Sequence transformation utilities for supremo_lite.
3
+
4
+ This module provides functions for encoding, decoding, and manipulating
5
+ DNA sequences.
6
+ """
7
+
8
+ import numpy as np
9
+ from .core import nt_to_1h, nts, TORCH_AVAILABLE, BRISKET_AVAILABLE
10
+
11
+ try:
12
+ import torch
13
+ except ImportError:
14
+ pass # Already handled in core module
15
+
16
+ if BRISKET_AVAILABLE:
17
+ try:
18
+ from brisket import encode_seq as brisket_encode_seq
19
+ except ImportError:
20
+ pass # Already handled in core module
21
+
22
+
23
def _encode_with_lookup(seq):
    """Pure-NumPy fallback: one-hot encode ``seq`` via the ``nt_to_1h`` table, shape (4, L)."""
    # Upper-case first so this path accepts the same input as the brisket path,
    # which is also fed ``seq.upper()``.
    rows = [nt_to_1h[nt] for nt in seq.upper()]
    if not rows:
        # np.array([]).T has shape (0,), not the documented (4, 0).
        return np.zeros((4, 0))
    return np.array(rows).T


def _encode_single(seq, encoder=None):
    """Encode one nucleotide string to an array-like of shape (4, L) without any torch conversion."""
    if encoder is not None:
        encoded = encoder(seq)
        # Accept custom encoders that return a torch tensor or a plain nested
        # list, not only a NumPy array.
        if TORCH_AVAILABLE and isinstance(encoded, torch.Tensor):
            return encoded.detach().cpu().numpy()
        return np.asarray(encoded)

    if BRISKET_AVAILABLE:
        try:
            # Use brisket for fast encoding
            return brisket_encode_seq(seq.upper())
        except Exception:
            # Deliberate best-effort: fall back to the NumPy implementation
            # if brisket fails for any reason.
            import warnings

            warnings.warn(
                "Fast encoding with brisket failed, falling back to slower NumPy implementation. "
                "This may impact performance for large sequences.",
                UserWarning,
                # helper -> encode_seq -> caller
                stacklevel=3,
            )
    return _encode_with_lookup(seq)


def encode_seq(seq, encoder=None):
    """
    Convert a nucleotide string to a one-hot encoded tensor/array.

    Args:
        seq: A string of nucleotides or a list of such strings. All strings
            in a list must have the same length so they can be stacked.
        encoder: Optional custom encoding function. If provided, should accept
            a single sequence string and return an array-like (NumPy array,
            nested list or torch tensor) with shape (4, L).

    Returns:
        A tensor/array with shape (4, L) for a single sequence or (N, 4, L)
        for a list, where L is the sequence length and N is the number of
        sequences. When torch is available the result is a float32
        ``torch.Tensor``; otherwise it is a NumPy array.

    Raises:
        ValueError: if ``seq`` is an empty list or its strings differ in
            length (raised by ``np.stack``).
        KeyError: on the NumPy fallback path, for a character that is not in
            the nucleotide lookup table.

    Encoding scheme (default):
        'A' = [1, 0, 0, 0] (first channel)
        'C' = [0, 1, 0, 0] (second channel)
        'G' = [0, 0, 1, 0] (third channel)
        'T' = [0, 0, 0, 1] (fourth channel)
        'N' = [0, 0, 0, 0] (all channels zero)
    """
    if isinstance(seq, list):
        # Stack plain NumPy results; convert to torch only once at the end
        # instead of per sequence.
        encoded = np.stack([np.asarray(_encode_single(s, encoder)) for s in seq])
    else:
        encoded = _encode_single(seq, encoder)

    if TORCH_AVAILABLE:
        # torch.from_numpy rejects negative strides, so make the array
        # contiguous first (e.g. for encoders that return flipped views).
        return torch.from_numpy(np.ascontiguousarray(encoded)).float()
    return encoded
81
+
82
+
83
def decode_seq(seq_1h):
    """
    Convert a one-hot encoded tensor/array back to a nucleotide string.

    Each position decodes to the nucleotide of its largest channel, looked up
    in ``nts``. Positions whose channels are all zero (the encoding used for
    'N') decode to "N".

    Args:
        seq_1h: A tensor/array with shape (4, L) or (N, 4, L). Torch tensors
            may live on any device and may require grad.

    Returns:
        A string for a (4, L) input, or a list of strings for an (N, 4, L)
        input.
    """
    if TORCH_AVAILABLE and isinstance(seq_1h, torch.Tensor):
        # .numpy() alone fails for tensors that require grad or live on a GPU.
        seq_1h = seq_1h.detach().cpu().numpy()
    else:
        seq_1h = np.asarray(seq_1h)

    # Handle batch dimension if present
    if seq_1h.ndim == 3:
        return [decode_seq(s) for s in seq_1h]

    # Index of the maximum value along the channel dimension (first axis).
    indices = seq_1h.argmax(axis=0)
    chars = nts[indices]

    # argmax of an all-zero column is 0, which would silently turn an encoded
    # 'N' into the first nucleotide; restore it explicitly.
    is_n = ~np.any(seq_1h, axis=0)
    chars = np.where(is_n, "N", chars)

    return "".join(chars)
106
+
107
+
108
def rc(seq_1h):
    """
    Return the reverse complement of a one-hot encoded tensor/array.

    Channels are reordered as [A, C, G, T] -> [T, G, C, A] to complement each
    base. The last (sequence) axis is then reversed.

    Args:
        seq_1h: A torch tensor or NumPy array with shape (4, L) or (N, 4, L).

    Returns:
        A tensor/array of the same shape and kind holding the reverse
        complement.
    """
    complement_order = [3, 2, 1, 0]
    is_tensor = TORCH_AVAILABLE and isinstance(seq_1h, torch.Tensor)

    complemented = seq_1h[..., complement_order, :]
    if is_tensor:
        return complemented.flip(dims=[-1])
    return np.flip(complemented, axis=-1)
124
+
125
+
126
def rc_str(seq):
    """
    Return the reverse complement of a nucleotide string.

    A<->T and C<->G are swapped, preserving case. Any other character
    (e.g. 'N') is kept unchanged. The result reads in the opposite direction
    to the input.

    Args:
        seq: A string of nucleotides.

    Returns:
        The reverse-complemented string.
    """
    complement_table = str.maketrans("ACGTacgt", "TGCAtgca")
    reversed_seq = seq[::-1]
    return reversed_seq.translate(complement_table)