codon-model 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codon/__init__.py +5 -0
- codon/base.py +167 -0
- codon/exp/__init__.py +0 -0
- codon/exp/moe.py +307 -0
- codon/model/__init__.py +0 -0
- codon/model/motif/__init__.py +1 -0
- codon/model/motif/motif_a1.py +121 -0
- codon/model/patch_disc.py +151 -0
- codon/model/tcn.py +124 -0
- codon/ops/__init__.py +3 -0
- codon/ops/attention.py +107 -0
- codon/ops/bio.py +0 -0
- codon/utils/__init__.py +0 -0
- codon/utils/dataset/__init__.py +3 -0
- codon/utils/dataset/base.py +46 -0
- codon/utils/dataset/corpus.py +478 -0
- codon/utils/dataset/dataviewer.py +196 -0
- codon/utils/dataset/flatdata.py +455 -0
- codon/utils/mask.py +266 -0
- codon/utils/safecode.py +24 -0
- codon/utils/seed.py +75 -0
- codon/utils/theta.py +55 -0
- codon/utils/token.py +276 -0
- codon_model-0.0.1.dist-info/METADATA +17 -0
- codon_model-0.0.1.dist-info/RECORD +28 -0
- codon_model-0.0.1.dist-info/WHEEL +5 -0
- codon_model-0.0.1.dist-info/licenses/LICENSE +201 -0
- codon_model-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
from codon.base import *
|
|
4
|
+
from codon.block.conv import ConvBlock
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class PatchDiscriminator(BasicModel):
    '''
    PatchGAN discriminator assembled from a stack of 2D `ConvBlock` layers.

    The last block emits a single channel without normalization or
    activation, so the result is a map of per-patch scores (one value per
    patch, real vs. fake) rather than one scalar per image.

    Attributes:
        main (nn.Sequential): All convolution blocks, applied in order.
    '''

    def __init__(
        self,
        in_channels: int = 3,
        hidden_dim: int = 64,
        num_layers: int = 3,
        norm: str = 'batch',
        activation: str = 'leaky_relu',
        leaky_relu: float = 0.2
    ) -> None:
        '''
        Build the convolution stack.

        Args:
            in_channels (int): Number of input channels. Defaults to 3.
            hidden_dim (int): Base number of filters; deeper blocks use a multiple of it. Defaults to 64.
            num_layers (int): Controls how many stride-2 blocks are stacked and how wide the last ones are. Defaults to 3.
            norm (str, optional): Normalization type for the intermediate blocks. Defaults to 'batch'.
            activation (str, optional): Activation function type. Defaults to 'leaky_relu'.
            leaky_relu (float, optional): Negative slope for LeakyReLU. Defaults to 0.2.
        '''
        super().__init__()

        def normed_block(c_in: int, c_out: int, stride: int) -> ConvBlock:
            # Intermediate blocks only differ in channel counts and stride.
            return ConvBlock(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=4,
                stride=stride,
                padding=1,
                dim=2,
                norm=norm,
                activation=activation,
                leaky_relu=leaky_relu,
                bias=False
            )

        # Channel multipliers after the stem and after each stride-2 block;
        # the multiplier doubles per block and is capped at 8.
        mults = [1] + [min(2 ** n, 8) for n in range(1, num_layers)]
        final_mult = min(2 ** num_layers, 8)

        # Stem: no normalization, bias enabled.
        blocks = [
            ConvBlock(
                in_channels=in_channels,
                out_channels=hidden_dim,
                kernel_size=4,
                stride=2,
                padding=1,
                dim=2,
                norm=None,
                activation=activation,
                leaky_relu=leaky_relu,
                bias=True
            )
        ]

        # Stride-2 blocks between consecutive multipliers.
        for prev_mult, next_mult in zip(mults, mults[1:]):
            blocks.append(
                normed_block(hidden_dim * prev_mult, hidden_dim * next_mult, stride=2)
            )

        # One stride-1 block that widens to the final multiplier.
        blocks.append(
            normed_block(hidden_dim * mults[-1], hidden_dim * final_mult, stride=1)
        )

        # Output head: single channel, no normalization, no activation.
        blocks.append(
            ConvBlock(
                in_channels=hidden_dim * final_mult,
                out_channels=1,
                kernel_size=4,
                stride=1,
                padding=1,
                dim=2,
                norm=None,
                activation=None,
                bias=True
            )
        )

        self.main = nn.Sequential(*blocks)

    @staticmethod
    def auto_build(
        in_channels: int,
        hidden_dim: int,
        image_size: int,
        norm: str = 'batch',
        activation: str = 'leaky_relu',
        leaky_relu: float = 0.2
    ) -> 'PatchDiscriminator':
        '''
        Create a PatchDiscriminator whose depth is derived from the image size.

        The depth is `int(log2(image_size / 32))`, with a minimum of 1.

        Args:
            in_channels (int): Number of input channels.
            hidden_dim (int): Base number of filters (channels).
            image_size (int): Size of the input image.
            norm (str, optional): Normalization type. Defaults to 'batch'.
            activation (str, optional): Activation function type. Defaults to 'leaky_relu'.
            leaky_relu (float, optional): Negative slope for LeakyReLU. Defaults to 0.2.

        Returns:
            PatchDiscriminator: The constructed PatchDiscriminator.
        '''
        depth = max(1, int(math.log2(image_size / 32)))

        return PatchDiscriminator(
            in_channels=in_channels,
            hidden_dim=hidden_dim,
            num_layers=depth,
            norm=norm,
            activation=activation,
            leaky_relu=leaky_relu
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        '''
        Run the input through the convolution stack.

        Args:
            input_tensor (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The output of the discriminator.
        '''
        patch_scores = self.main(input_tensor)
        return patch_scores
|
codon/model/tcn.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
from codon.base import *
|
|
2
|
+
from typing import List, Optional
|
|
3
|
+
|
|
4
|
+
from codon.block.conv import CausalConv1d, calculate_causal_layer
|
|
5
|
+
|
|
6
|
+
class TemporalConvNet(BasicModel):
    '''
    Temporal Convolutional Network.

    Wraps the layer stack produced by `CausalConv1d.manual_block`, with an
    explicit output channel count per layer.
    Use `TemporalConvNet.auto_build` to derive the number of layers from a
    target receptive field instead.
    '''

    def __init__(
        self,
        in_channels: int,
        num_channels: List[int],
        kernel_size: int = 3,
        dropout: float = 0.2,
        use_res: bool = True,
        norm: str = None,
        activation: str = 'leaky_relu',
        leaky_relu: float = 0.1,
        channel_first: bool = True
    ):
        '''
        Build the TCN from an explicit list of per-layer channel counts.

        Args:
            in_channels (int): Number of input channels.
            num_channels (List[int]): List of output channels for each layer.
            kernel_size (int, optional): Kernel size. Defaults to 3.
            dropout (float, optional): Dropout probability. Defaults to 0.2.
            use_res (bool, optional): Whether to use residual connections. Defaults to True.
            norm (str, optional): Normalization type (passed to CausalConv1d/ConvBlock). Defaults to None.
            activation (str, optional): Activation function type. Defaults to 'leaky_relu'.
            leaky_relu (float, optional): Negative slope for LeakyReLU. Defaults to 0.1.
            channel_first (bool, optional): Whether input is (Batch, Channels, Seq_Len). Defaults to True.
        '''
        super().__init__()

        # Keep the construction settings that forward() and callers read.
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.dropout = dropout
        self.channel_first = channel_first

        block_settings = dict(
            in_channels=in_channels,
            num_channels=num_channels,
            kernel_size=kernel_size,
            norm=norm,
            activation=activation,
            leaky_relu=leaky_relu,
            use_res=use_res,
            dropout=dropout
        )
        self.network = CausalConv1d.manual_block(**block_settings)

        # The last entry of the channel list is the width of the network output.
        self.out_channels = num_channels[-1]

    @staticmethod
    def auto_build(
        in_channels: int,
        out_channels: int,
        receptive_field: int,
        kernel_size: int = 3,
        dropout: float = 0.2,
        use_res: bool = True,
        norm: str = None,
        activation: str = 'leaky_relu',
        leaky_relu: float = 0.1,
        channel_first: bool = True
    ) -> 'TemporalConvNet':
        '''
        Build a TCN whose layer count comes from `calculate_causal_layer`.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Unified output channels for each layer.
            receptive_field (int): Target receptive field (time steps).
            kernel_size (int, optional): Kernel size. Defaults to 3.
            dropout (float, optional): Dropout probability. Defaults to 0.2.
            use_res (bool, optional): Whether to use residual connections. Defaults to True.
            norm (str, optional): Normalization type. Defaults to None.
            activation (str, optional): Activation function type. Defaults to 'leaky_relu'.
            leaky_relu (float, optional): Negative slope for LeakyReLU. Defaults to 0.1.
            channel_first (bool, optional): Whether input is (Batch, Channels, Seq_Len). Defaults to True.

        Returns:
            TemporalConvNet: An initialized TCN module.
        '''
        # Only the first element of the helper's result (the layer count) is used.
        depth, _ = calculate_causal_layer(receptive_field, kernel_size)
        per_layer_channels = [out_channels] * depth

        return TemporalConvNet(
            in_channels=in_channels,
            num_channels=per_layer_channels,
            kernel_size=kernel_size,
            dropout=dropout,
            use_res=use_res,
            norm=norm,
            activation=activation,
            leaky_relu=leaky_relu,
            channel_first=channel_first
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Forward pass.

        Args:
            x (torch.Tensor): Input tensor. Shape: [Batch, in_channels, Seq_Len] or [Batch, Seq_Len, in_channels] if channel_first=False.

        Returns:
            torch.Tensor: Output tensor, in the same axis layout as the input.
        '''
        if self.channel_first:
            return self.network(x)

        # Channel-last input: swap axes 1 and 2 for the network, then swap back.
        features = self.network(x.transpose(1, 2))
        return features.transpose(1, 2)
|
codon/ops/__init__.py
ADDED
codon/ops/attention.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
import torch.nn.functional as F
|
|
2
|
+
import torch
|
|
3
|
+
import math
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Optional, Tuple
|
|
7
|
+
|
|
8
|
+
from torch.nn.attention import SDPBackend, sdpa_kernel
|
|
9
|
+
|
|
10
|
+
@dataclass
|
|
11
|
+
class AttentionOutput:
|
|
12
|
+
'''
|
|
13
|
+
Output of the attention mechanism.
|
|
14
|
+
|
|
15
|
+
Args:
|
|
16
|
+
output (torch.Tensor): The output tensor from the attention mechanism.
|
|
17
|
+
attention_weights (Optional[torch.Tensor], optional): Attention weights. Defaults to None.
|
|
18
|
+
past_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]], optional):
|
|
19
|
+
Cached key and value tensors for autoregressive generation. Defaults to None.
|
|
20
|
+
'''
|
|
21
|
+
output: torch.Tensor
|
|
22
|
+
attention_weights: Optional[torch.Tensor] = None
|
|
23
|
+
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def apply_attention(
|
|
27
|
+
query_states: torch.Tensor,
|
|
28
|
+
key_states: torch.Tensor,
|
|
29
|
+
value_states: torch.Tensor,
|
|
30
|
+
attention_mask: torch.Tensor = None,
|
|
31
|
+
output_attentions: bool = False,
|
|
32
|
+
is_causal: bool = None,
|
|
33
|
+
dropout: float = 0.0
|
|
34
|
+
) -> AttentionOutput:
|
|
35
|
+
'''
|
|
36
|
+
Compute scaled dot-product attention.
|
|
37
|
+
|
|
38
|
+
Args:
|
|
39
|
+
query_states (torch.Tensor): Query states tensor.
|
|
40
|
+
key_states (torch.Tensor): Key states tensor.
|
|
41
|
+
value_states (torch.Tensor): Value states tensor.
|
|
42
|
+
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
|
|
43
|
+
output_attentions (bool, optional): Whether to output attention weights. Defaults to False.
|
|
44
|
+
is_causal (bool, optional): Whether to apply a causal mask. Defaults to None.
|
|
45
|
+
dropout (float, optional): Dropout probability. Defaults to 0.0.
|
|
46
|
+
|
|
47
|
+
Returns:
|
|
48
|
+
AttentionOutput: Object containing attention output and optional weights.
|
|
49
|
+
'''
|
|
50
|
+
|
|
51
|
+
if attention_mask is not None:
|
|
52
|
+
if attention_mask.dtype != torch.float32:
|
|
53
|
+
attention_mask = attention_mask.float()
|
|
54
|
+
|
|
55
|
+
if attention_mask.max() <= 1.0:
|
|
56
|
+
attention_mask = torch.where(attention_mask == 0, float('-inf'), 0.0)
|
|
57
|
+
|
|
58
|
+
if is_causal:
|
|
59
|
+
tgt_len = query_states.size(-2)
|
|
60
|
+
src_len = key_states.size(-2)
|
|
61
|
+
|
|
62
|
+
causal_mask = torch.tril(
|
|
63
|
+
torch.ones((tgt_len, src_len), device=query_states.device, dtype=query_states.dtype)
|
|
64
|
+
).view(1, 1, tgt_len, src_len)
|
|
65
|
+
|
|
66
|
+
causal_mask = torch.where(causal_mask == 0, float('-inf'), 0.0)
|
|
67
|
+
|
|
68
|
+
if attention_mask is not None:
|
|
69
|
+
attention_mask = attention_mask + causal_mask
|
|
70
|
+
else:
|
|
71
|
+
attention_mask = causal_mask
|
|
72
|
+
|
|
73
|
+
is_causal = False
|
|
74
|
+
|
|
75
|
+
if not output_attentions:
|
|
76
|
+
if attention_mask is None and is_causal is None:
|
|
77
|
+
is_causal = True
|
|
78
|
+
|
|
79
|
+
try:
|
|
80
|
+
with sdpa_kernel([
|
|
81
|
+
SDPBackend.FLASH_ATTENTION,
|
|
82
|
+
SDPBackend.CUDNN_ATTENTION
|
|
83
|
+
]):
|
|
84
|
+
output = F.scaled_dot_product_attention(
|
|
85
|
+
query_states,
|
|
86
|
+
key_states,
|
|
87
|
+
value_states,
|
|
88
|
+
attn_mask=attention_mask,
|
|
89
|
+
is_causal=is_causal,
|
|
90
|
+
dropout_p=dropout
|
|
91
|
+
)
|
|
92
|
+
return AttentionOutput(output=output, attention_weights=None)
|
|
93
|
+
except RuntimeError:
|
|
94
|
+
pass
|
|
95
|
+
# Manual Fallback Path
|
|
96
|
+
d_k = query_states.size(-1)
|
|
97
|
+
scores = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(d_k)
|
|
98
|
+
|
|
99
|
+
if attention_mask is not None:
|
|
100
|
+
scores = scores + attention_mask
|
|
101
|
+
attention_weights = torch.softmax(scores, dim=-1)
|
|
102
|
+
|
|
103
|
+
if dropout > 0.0:
|
|
104
|
+
attention_weights = F.dropout(attention_weights, p=dropout)
|
|
105
|
+
output = torch.matmul(attention_weights, value_states)
|
|
106
|
+
|
|
107
|
+
return AttentionOutput(output=output, attention_weights=attention_weights)
|
codon/ops/bio.py
ADDED
|
File without changes
|
codon/utils/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
class CodonDataset:
    '''
    Common base for Codon datasets.

    Subclasses are expected to override `__len__` and `__getitem__`; the
    versions defined here only raise `NotImplementedError`. The `row`
    property is derived from `__len__`.
    '''

    @property
    def row(self) -> int:
        '''
        Number of rows in the dataset.

        Returns:
            int: Same value as `len(self)`.
        '''
        total_rows = len(self)
        return total_rows

    def __len__(self) -> int:
        '''
        Length of the dataset.

        Returns:
            int: The length of the dataset.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        '''
        raise NotImplementedError()

    def __getitem__(self, idx: Any) -> Any:
        '''
        Fetch the item stored at the given index.

        Args:
            idx (Any): The index of the item to retrieve.

        Returns:
            Any: The item at the specified index.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        '''
        raise NotImplementedError()
|