codon-model 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,151 @@
1
+ import math
2
+
3
+ from codon.base import *
4
+ from codon.block.conv import ConvBlock
5
+
6
+
7
class PatchDiscriminator(BasicModel):
    '''
    PatchGAN discriminator.

    Instead of a single real/fake scalar, the network emits a one-channel
    N x N map in which every location scores one patch of the input.

    Layout built by `__init__` (all convolutions use kernel 4, padding 1):
        1. one stride-2 stem block without normalization,
        2. `num_layers - 1` stride-2 blocks whose width doubles up to
           `8 * hidden_dim`,
        3. one stride-1 block,
        4. one stride-1 block with a single output channel, no norm and no
           activation.

    Attributes:
        main (nn.Sequential): The stacked convolution blocks.
    '''

    def __init__(
        self,
        in_channels: int = 3,
        hidden_dim: int = 64,
        num_layers: int = 3,
        norm: str = 'batch',
        activation: str = 'leaky_relu',
        leaky_relu: float = 0.2
    ) -> None:
        '''
        Initialize the PatchDiscriminator.

        Args:
            in_channels (int): Number of input channels. Defaults to 3.
            hidden_dim (int): Base number of filters; wider blocks use a
                multiple of it, capped at 8x. Defaults to 64.
            num_layers (int): Number of stride-2 (downsampling) blocks,
                counting the stem. Two stride-1 blocks are always appended
                after them. Defaults to 3.
            norm (str, optional): Normalization type forwarded to every block
                except the stem and the output block. Defaults to 'batch'.
            activation (str, optional): Activation type forwarded to every
                block except the output block. Defaults to 'leaky_relu'.
            leaky_relu (float, optional): Negative slope for LeakyReLU.
                Defaults to 0.2.
        '''
        super().__init__()

        # Stem: no normalization, convolution keeps its bias.
        sequence = [
            ConvBlock(
                in_channels=in_channels,
                out_channels=hidden_dim,
                kernel_size=4,
                stride=2,
                padding=1,
                dim=2,
                norm=None,
                activation=activation,
                leaky_relu=leaky_relu,
                bias=True
            )
        ]

        # Remaining downsampling blocks: width doubles each step, capped at 8x.
        channel_mult = 1
        for n in range(1, num_layers):
            channel_mult_prev, channel_mult = channel_mult, min(2 ** n, 8)
            sequence.append(
                ConvBlock(
                    in_channels=hidden_dim * channel_mult_prev,
                    out_channels=hidden_dim * channel_mult,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                    dim=2,
                    norm=norm,
                    activation=activation,
                    leaky_relu=leaky_relu,
                    bias=False
                )
            )

        channel_mult_prev, channel_mult = channel_mult, min(2 ** num_layers, 8)

        sequence += [
            # Stride-1 block: widens once more without reducing resolution by 2.
            ConvBlock(
                in_channels=hidden_dim * channel_mult_prev,
                out_channels=hidden_dim * channel_mult,
                kernel_size=4,
                stride=1,
                padding=1,
                dim=2,
                norm=norm,
                activation=activation,
                leaky_relu=leaky_relu,
                bias=False
            ),
            # Output block: one score channel, raw values (no norm, no activation).
            ConvBlock(
                in_channels=hidden_dim * channel_mult,
                out_channels=1,
                kernel_size=4,
                stride=1,
                padding=1,
                dim=2,
                norm=None,
                activation=None,
                bias=True
            )
        ]

        self.main = nn.Sequential(*sequence)

    @staticmethod
    def auto_build(
        in_channels: int,
        hidden_dim: int,
        image_size: int,
        norm: str = 'batch',
        activation: str = 'leaky_relu',
        leaky_relu: float = 0.2
    ) -> 'PatchDiscriminator':
        '''
        Automatically builds a PatchDiscriminator based on the image size.

        The depth is `floor(log2(image_size / 32))`, clamped to at least 1
        (for example 64 -> 1, 128 -> 2, 256 -> 3).

        Args:
            in_channels (int): Number of input channels.
            hidden_dim (int): Base number of filters (channels).
            image_size (int): Size of the input image. Must be positive.
            norm (str, optional): Normalization type. Defaults to 'batch'.
            activation (str, optional): Activation function type. Defaults to 'leaky_relu'.
            leaky_relu (float, optional): Negative slope for LeakyReLU. Defaults to 0.2.

        Returns:
            PatchDiscriminator: The constructed PatchDiscriminator.

        Raises:
            ValueError: If `image_size` is not positive.
        '''
        # math.log2 would raise an opaque "math domain error" here; fail with
        # a message that names the offending argument instead.
        if image_size <= 0:
            raise ValueError(f'image_size must be positive, got {image_size}')

        num_layers = max(1, int(math.log2(image_size / 32)))

        return PatchDiscriminator(
            in_channels=in_channels,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            norm=norm,
            activation=activation,
            leaky_relu=leaky_relu
        )

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        '''
        Run the input through the convolution stack.

        Args:
            input_tensor (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The output of the discriminator.
        '''
        return self.main(input_tensor)
codon/model/tcn.py ADDED
@@ -0,0 +1,124 @@
1
+ from codon.base import *
2
+ from typing import List, Optional
3
+
4
+ from codon.block.conv import CausalConv1d, calculate_causal_layer
5
+
6
class TemporalConvNet(BasicModel):
    '''
    Temporal Convolutional Network.

    A thin wrapper around the stack produced by `CausalConv1d.manual_block`.
    Pass an explicit per-layer channel list to the constructor, or call
    `TemporalConvNet.auto_build` to derive the layer count from a target
    receptive field.

    Attributes:
        in_channels (int): Number of input channels.
        kernel_size (int): Kernel size handed to the block builder.
        dropout (float): Dropout probability handed to the block builder.
        channel_first (bool): Whether `forward` receives channels on dim 1.
        network: The module returned by `CausalConv1d.manual_block`.
        out_channels (int): Last entry of `num_channels`.
    '''

    def __init__(
        self,
        in_channels: int,
        num_channels: List[int],
        kernel_size: int = 3,
        dropout: float = 0.2,
        use_res: bool = True,
        norm: str = None,
        activation: str = 'leaky_relu',
        leaky_relu: float = 0.1,
        channel_first: bool = True
    ):
        '''
        Build the TCN from an explicit list of layer widths.

        Args:
            in_channels (int): Number of input channels.
            num_channels (List[int]): Output channels of each layer, in order.
            kernel_size (int, optional): Kernel size. Defaults to 3.
            dropout (float, optional): Dropout probability. Defaults to 0.2.
            use_res (bool, optional): Whether to use residual connections. Defaults to True.
            norm (str, optional): Normalization type forwarded to the block builder. Defaults to None.
            activation (str, optional): Activation function type. Defaults to 'leaky_relu'.
            leaky_relu (float, optional): Negative slope for LeakyReLU. Defaults to 0.1.
            channel_first (bool, optional): Whether input is (Batch, Channels, Seq_Len). Defaults to True.
        '''
        super().__init__()

        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.dropout = dropout
        self.channel_first = channel_first

        # Everything the block builder needs, gathered in one place.
        block_options = dict(
            in_channels=in_channels,
            num_channels=num_channels,
            kernel_size=kernel_size,
            norm=norm,
            activation=activation,
            leaky_relu=leaky_relu,
            use_res=use_res,
            dropout=dropout,
        )
        self.network = CausalConv1d.manual_block(**block_options)

        # The width of the final layer is what downstream modules consume.
        self.out_channels = num_channels[-1]

    @staticmethod
    def auto_build(
        in_channels: int,
        out_channels: int,
        receptive_field: int,
        kernel_size: int = 3,
        dropout: float = 0.2,
        use_res: bool = True,
        norm: str = None,
        activation: str = 'leaky_relu',
        leaky_relu: float = 0.1,
        channel_first: bool = True
    ) -> 'TemporalConvNet':
        '''
        Build a TCN whose depth is chosen from a target receptive field.

        The layer count is the first value returned by
        `calculate_causal_layer(receptive_field, kernel_size)`; every layer
        gets `out_channels` output channels.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Unified output channels for each layer.
            receptive_field (int): Target receptive field (time steps).
            kernel_size (int, optional): Kernel size. Defaults to 3.
            dropout (float, optional): Dropout probability. Defaults to 0.2.
            use_res (bool, optional): Whether to use residual connections. Defaults to True.
            norm (str, optional): Normalization type. Defaults to None.
            activation (str, optional): Activation function type. Defaults to 'leaky_relu'.
            leaky_relu (float, optional): Negative slope for LeakyReLU. Defaults to 0.1.
            channel_first (bool, optional): Whether input is (Batch, Channels, Seq_Len). Defaults to True.

        Returns:
            TemporalConvNet: An initialized TCN module.
        '''
        depth, _ignored = calculate_causal_layer(receptive_field, kernel_size)
        widths = [out_channels] * depth

        return TemporalConvNet(
            in_channels=in_channels,
            num_channels=widths,
            kernel_size=kernel_size,
            dropout=dropout,
            use_res=use_res,
            norm=norm,
            activation=activation,
            leaky_relu=leaky_relu,
            channel_first=channel_first
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Forward pass.

        Args:
            x (torch.Tensor): Input tensor. Shape: [Batch, in_channels, Seq_Len] or [Batch, Seq_Len, in_channels] if channel_first=False.

        Returns:
            torch.Tensor: Output tensor, in the same axis layout as the input.
        '''
        if self.channel_first:
            return self.network(x)

        # Channel-last input: swap dims 1 and 2 for the network, then swap back.
        channel_first_out = self.network(x.transpose(1, 2))
        return channel_first_out.transpose(1, 2)
codon/ops/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+
2
+
3
# Module-level seed slot for `codon.ops`; initialised to None (no seed stored yet).
# NOTE(review): presumably assigned by a seeding helper elsewhere in the package and
# read by ops that need reproducible randomness - confirm against callers. The
# annotation says `int`, but the initial value is None, so readers must handle None.
+ __seed__: int = None
codon/ops/attention.py ADDED
@@ -0,0 +1,107 @@
1
+ import torch.nn.functional as F
2
+ import torch
3
+ import math
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Optional, Tuple
7
+
8
+ from torch.nn.attention import SDPBackend, sdpa_kernel
9
+
10
+ @dataclass
11
+ class AttentionOutput:
12
+ '''
13
+ Output of the attention mechanism.
14
+
15
+ Args:
16
+ output (torch.Tensor): The output tensor from the attention mechanism.
17
+ attention_weights (Optional[torch.Tensor], optional): Attention weights. Defaults to None.
18
+ past_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]], optional):
19
+ Cached key and value tensors for autoregressive generation. Defaults to None.
20
+ '''
21
+ output: torch.Tensor
22
+ attention_weights: Optional[torch.Tensor] = None
23
+ past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
24
+
25
+
26
+ def apply_attention(
27
+ query_states: torch.Tensor,
28
+ key_states: torch.Tensor,
29
+ value_states: torch.Tensor,
30
+ attention_mask: torch.Tensor = None,
31
+ output_attentions: bool = False,
32
+ is_causal: bool = None,
33
+ dropout: float = 0.0
34
+ ) -> AttentionOutput:
35
+ '''
36
+ Compute scaled dot-product attention.
37
+
38
+ Args:
39
+ query_states (torch.Tensor): Query states tensor.
40
+ key_states (torch.Tensor): Key states tensor.
41
+ value_states (torch.Tensor): Value states tensor.
42
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
43
+ output_attentions (bool, optional): Whether to output attention weights. Defaults to False.
44
+ is_causal (bool, optional): Whether to apply a causal mask. Defaults to None.
45
+ dropout (float, optional): Dropout probability. Defaults to 0.0.
46
+
47
+ Returns:
48
+ AttentionOutput: Object containing attention output and optional weights.
49
+ '''
50
+
51
+ if attention_mask is not None:
52
+ if attention_mask.dtype != torch.float32:
53
+ attention_mask = attention_mask.float()
54
+
55
+ if attention_mask.max() <= 1.0:
56
+ attention_mask = torch.where(attention_mask == 0, float('-inf'), 0.0)
57
+
58
+ if is_causal:
59
+ tgt_len = query_states.size(-2)
60
+ src_len = key_states.size(-2)
61
+
62
+ causal_mask = torch.tril(
63
+ torch.ones((tgt_len, src_len), device=query_states.device, dtype=query_states.dtype)
64
+ ).view(1, 1, tgt_len, src_len)
65
+
66
+ causal_mask = torch.where(causal_mask == 0, float('-inf'), 0.0)
67
+
68
+ if attention_mask is not None:
69
+ attention_mask = attention_mask + causal_mask
70
+ else:
71
+ attention_mask = causal_mask
72
+
73
+ is_causal = False
74
+
75
+ if not output_attentions:
76
+ if attention_mask is None and is_causal is None:
77
+ is_causal = True
78
+
79
+ try:
80
+ with sdpa_kernel([
81
+ SDPBackend.FLASH_ATTENTION,
82
+ SDPBackend.CUDNN_ATTENTION
83
+ ]):
84
+ output = F.scaled_dot_product_attention(
85
+ query_states,
86
+ key_states,
87
+ value_states,
88
+ attn_mask=attention_mask,
89
+ is_causal=is_causal,
90
+ dropout_p=dropout
91
+ )
92
+ return AttentionOutput(output=output, attention_weights=None)
93
+ except RuntimeError:
94
+ pass
95
+ # Manual Fallback Path
96
+ d_k = query_states.size(-1)
97
+ scores = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(d_k)
98
+
99
+ if attention_mask is not None:
100
+ scores = scores + attention_mask
101
+ attention_weights = torch.softmax(scores, dim=-1)
102
+
103
+ if dropout > 0.0:
104
+ attention_weights = F.dropout(attention_weights, p=dropout)
105
+ output = torch.matmul(attention_weights, value_states)
106
+
107
+ return AttentionOutput(output=output, attention_weights=attention_weights)
codon/ops/bio.py ADDED
File without changes
File without changes
@@ -0,0 +1,3 @@
1
+ from .flatdata import FlatDataset, FlatColumnDataset, MappedFlatDataset
2
+ from .corpus import FileType, CorpusData, CorpusDataset
3
+ from .dataviewer import DataViewer, preview_fields
@@ -0,0 +1,46 @@
1
+ from typing import Any
2
+
3
class CodonDataset:
    '''
    Common root of every Codon dataset.

    The class fixes the interface only: subclasses are expected to override
    `__len__` and `__getitem__`. Both raise `NotImplementedError` here, and
    `row` is a read-only alias for `len(self)`.
    '''

    def __len__(self) -> int:
        '''
        Report how many items the dataset holds.

        Returns:
            int: The length of the dataset.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this method.
        '''
        raise NotImplementedError()

    def __getitem__(self, idx: Any) -> Any:
        '''
        Fetch the item stored under `idx`.

        Args:
            idx (Any): The index of the item to retrieve.

        Returns:
            Any: The item at the specified index.

        Raises:
            NotImplementedError: Always, unless a subclass overrides this method.
        '''
        raise NotImplementedError()

    @property
    def row(self) -> int:
        '''
        Number of rows in the dataset, i.e. `len(self)`.

        Returns:
            int: The total number of rows.
        '''
        return len(self)