codon-model 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- codon/__init__.py +5 -0
- codon/base.py +167 -0
- codon/exp/__init__.py +0 -0
- codon/exp/moe.py +307 -0
- codon/model/__init__.py +0 -0
- codon/model/motif/__init__.py +1 -0
- codon/model/motif/motif_a1.py +121 -0
- codon/model/patch_disc.py +151 -0
- codon/model/tcn.py +124 -0
- codon/ops/__init__.py +3 -0
- codon/ops/attention.py +107 -0
- codon/ops/bio.py +0 -0
- codon/utils/__init__.py +0 -0
- codon/utils/dataset/__init__.py +3 -0
- codon/utils/dataset/base.py +46 -0
- codon/utils/dataset/corpus.py +478 -0
- codon/utils/dataset/dataviewer.py +196 -0
- codon/utils/dataset/flatdata.py +455 -0
- codon/utils/mask.py +266 -0
- codon/utils/safecode.py +24 -0
- codon/utils/seed.py +75 -0
- codon/utils/theta.py +55 -0
- codon/utils/token.py +276 -0
- codon_model-0.0.1.dist-info/METADATA +17 -0
- codon_model-0.0.1.dist-info/RECORD +28 -0
- codon_model-0.0.1.dist-info/WHEEL +5 -0
- codon_model-0.0.1.dist-info/licenses/LICENSE +201 -0
- codon_model-0.0.1.dist-info/top_level.txt +1 -0
codon/__init__.py
ADDED
codon/base.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
|
|
4
|
+
from typing import Callable, Any, Iterator, Union
|
|
5
|
+
|
|
6
|
+
from safetensors.torch import save_model as safe_save_model
|
|
7
|
+
from safetensors.torch import load_model as safe_load_model
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BasicModel(nn.Module):
    '''
    Base class for all models, providing common functionality like gradient checkpointing and parameter counting.
    '''

    def __init__(self):
        '''
        Initialize the BasicModel with gradient checkpointing disabled.
        '''
        super().__init__()
        self.gradient_checkpointing: bool = False

    @property
    def device(self) -> torch.device:
        '''
        Get the device of the model.

        Returns:
            torch.device: The device of the first parameter of the model.
                Returns 'cpu' if the model has no parameters.
        '''
        try:
            return next(self.parameters()).device
        except StopIteration:
            return torch.device('cpu')

    def set_checkpoint(self, value: bool) -> None:
        '''
        Enable or disable gradient checkpointing for the model and every BasicModel sub-module.

        Args:
            value (bool): True to enable gradient checkpointing, False to disable.
        '''
        self.gradient_checkpointing = value
        for model in self.modules():
            if isinstance(model, BasicModel) and model is not self:
                model.gradient_checkpointing = value

    def checkpoint(self, function: Callable, *args, **kwargs) -> Any:
        '''
        Apply gradient checkpointing to a function if enabled and in training mode.

        Args:
            function (Callable): The function to be checkpointed.
            *args: Positional arguments for the function.
            **kwargs: Keyword arguments for the function (forwarded to the checkpointed call).

        Returns:
            Any: The output of the function.
        '''
        if self.gradient_checkpointing and self.training:
            # Import the submodule explicitly: a plain `import torch` is not guaranteed to
            # expose `torch.utils.checkpoint` as an attribute.
            from torch.utils.checkpoint import checkpoint as torch_checkpoint
            return torch_checkpoint(function, *args, use_reentrant=False, **kwargs)
        return function(*args, **kwargs)

    def get_params(self, trainable_only: bool = False) -> Iterator[torch.nn.Parameter]:
        '''
        Get an iterator over the model parameters.

        Args:
            trainable_only (bool, optional): If True, only yield parameters that require gradients.
                Defaults to False.

        Returns:
            Iterator[torch.nn.Parameter]: An iterator over the model parameters.
        '''
        if trainable_only:
            return (p for p in self.parameters() if p.requires_grad)
        return self.parameters()

    def count_params(
        self,
        trainable_only: bool = False,
        active_only: bool = False,
        human_readable: bool = False,
        seen: Union[set, None] = None,
    ) -> Union[int, str]:
        '''
        Count the number of parameters in the model.

        Args:
            trainable_only (bool, optional): If True, count only trainable parameters.
                Defaults to False.
            active_only (bool, optional): If True, count only active parameters (e.g. for MoE).
                Sub-modules that are BasicModel instances are asked for their own active count.
                Defaults to False.
            human_readable (bool, optional): If True, return a string representation with units (e.g. M, B).
                Defaults to False.
            seen (set, optional): A set of already counted parameters to avoid duplicates.
                It is updated in place. Defaults to None (a fresh set).

        Returns:
            Union[int, str]: The total number of parameters (a string if human_readable).
        '''
        if seen is None:
            seen = set()

        if not active_only:
            total = 0
            for p in self.get_params(trainable_only):
                if p not in seen:
                    seen.add(p)
                    total += p.numel()
        else:
            total = self._count_params_recursive(self, trainable_only, active_only, seen)

        if human_readable:
            if total >= 1e9:
                return f'{total / 1e9:.2f}B'
            elif total >= 1e6:
                return f'{total / 1e6:.2f}M'
            elif total >= 1e3:
                return f'{total / 1e3:.2f}K'
            return str(total)

        return total

    @staticmethod
    def _count_params_recursive(module: nn.Module, trainable_only: bool, active_only: bool, seen: set) -> int:
        '''
        Count `module`'s own parameters, then recurse into its children.

        Children that are BasicModel instances delegate to their own `count_params`, so
        subclasses can override how their active parameters are counted.
        '''
        total = 0
        for p in module.parameters(recurse=False):
            if p not in seen:
                if not trainable_only or p.requires_grad:
                    seen.add(p)
                    total += p.numel()

        for child in module.children():
            if isinstance(child, BasicModel):
                total += child.count_params(trainable_only, active_only, seen=seen)
            else:
                total += BasicModel._count_params_recursive(child, trainable_only, active_only, seen)

        return total

    def load_pretrained(self, path: str) -> None:
        '''
        Load pretrained weights from a file into this model.

        `.safetensors` files are loaded with safetensors. Any other file is read with
        `torch.load` and may be a raw state dict or a dict wrapping one under
        'model_state_dict' or 'state_dict'.

        Args:
            path (str): The path to the model file.
        '''
        if path.endswith('.safetensors'):
            safe_load_model(self, path)
            return

        # NOTE(security): torch.load unpickles the file; only load checkpoints from trusted sources.
        state_dict = torch.load(path, map_location=self.device)

        if isinstance(state_dict, dict):
            if 'model_state_dict' in state_dict:
                state_dict = state_dict['model_state_dict']
            elif 'state_dict' in state_dict:
                state_dict = state_dict['state_dict']

        self.load_state_dict(state_dict)

    def save_pretrained(self, path: str) -> None:
        '''
        Save the model weights to a file.

        Uses safetensors when `path` ends with '.safetensors', otherwise `torch.save` of the state dict.

        Args:
            path (str): The path to save the model file.
        '''
        if path.endswith('.safetensors'):
            safe_save_model(self, path)
        else:
            state_dict = self.state_dict()
            torch.save(state_dict, path)
|
codon/exp/__init__.py
ADDED
|
File without changes
|
codon/exp/moe.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
1
|
+
import torch.nn.functional as F
|
|
2
|
+
|
|
3
|
+
from codon.base import *
|
|
4
|
+
from codon.block.mlp import MLP
|
|
5
|
+
from codon.block.moe import *
|
|
6
|
+
|
|
7
|
+
import math
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ParallelExpert(nn.Module):
    '''
    A module that computes multiple expert outputs in parallel.

    This module manages weights for multiple experts and processes inputs efficiently using batch matrix multiplication.

    Attributes:
        use_gate (bool): Whether to use Gated Linear Unit (GLU) variants.
        num_experts (int): The number of experts.
        weight1 (nn.Parameter): First weight matrix with shape (num_experts, in_features, inter_dim),
            where inter_dim is 2 * hidden_features when use_gate else hidden_features.
        weight2 (nn.Parameter): Second weight matrix with shape (num_experts, hidden_features, out_features).
        act (nn.Module): Activation function.
        dropout (nn.Dropout): Dropout layer.
    '''

    def __init__(self, num_experts: int, in_features: int, hidden_features: int, out_features: int, use_gate: bool = False, dropout: float = 0.1) -> None:
        '''
        Initializes the ParallelExpert module.

        Args:
            num_experts (int): The number of experts.
            in_features (int): Size of each input sample.
            hidden_features (int): Size of the hidden layer.
            out_features (int): Size of each output sample.
            use_gate (bool): If True, uses SiLU activation with gating; otherwise, uses GELU.
            dropout (float): Dropout probability.
        '''
        super().__init__()
        self.use_gate = use_gate
        self.num_experts = num_experts
        self.dropout_p = dropout

        # The gated variant projects to twice the hidden width and splits it into gate/value halves.
        inter_dim = hidden_features * 2 if use_gate else hidden_features

        # Stored as [Experts, In, Out] (transposed relative to nn.Linear) so that bmm(x, w) works directly.
        self.weight1 = nn.Parameter(torch.empty(num_experts, in_features, inter_dim))
        self.weight2 = nn.Parameter(torch.empty(num_experts, hidden_features, out_features))

        self.act = nn.SiLU() if use_gate else nn.GELU()
        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        '''
        Resets the parameters of the experts using Kaiming Uniform initialization,
        matching nn.Linear's default scheme (a=sqrt(5), bound = 1/sqrt(fan_in)).
        '''
        for i in range(self.num_experts):
            # kaiming_uniform_ reads fan_in from dim 1 of a 2-D tensor (nn.Linear's (out, in)
            # layout). These slices are stored as (in, out), so the true fan-in is dim 0,
            # which kaiming_uniform_ reports as 'fan_out'.
            nn.init.kaiming_uniform_(self.weight1[i], a=math.sqrt(5), mode='fan_out')
            nn.init.kaiming_uniform_(self.weight2[i], a=math.sqrt(5), mode='fan_out')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Performs the forward pass for all experts in parallel.

        Args:
            x (torch.Tensor): Input tensor with shape (num_experts, capacity, in_features).

        Returns:
            torch.Tensor: Output tensor with shape (num_experts, capacity, out_features).
        '''
        # [Num_Experts, Capacity, Inter_Dim]
        h = torch.bmm(x, self.weight1)

        if self.use_gate:
            gate, val = h.chunk(2, dim=-1)
            h = self.act(gate) * val
        else:
            h = self.act(h)

        h = self.dropout(h)

        # [Num_Experts, Capacity, Out_Features]
        out = torch.bmm(h, self.weight2)

        return out
|
|
89
|
+
|
|
90
|
+
class ParallelMoE(BasicModel):
    '''
    A Parallel Mixture-of-Experts (MoE) model.

    Tokens are routed to their top-k experts by a linear gate. The routed experts run in a
    single batched call through ParallelExpert with a fixed per-expert capacity. Assignments
    that overflow an expert's capacity get no contribution from that expert. Optional shared
    experts process every token. An optional auxiliary load-balancing loss is produced
    during training.

    Attributes:
        model_dim (int): The dimension of the model.
        top_k (int): The number of experts to route each token to.
        num_experts (int): The total number of experts.
        num_shared_experts (int): The number of shared experts that process all tokens.
        use_aux_loss (bool): Whether to use auxiliary loss for load balancing.
        capacity_factor (float): Factor to determine the capacity of each expert.
        use_gate (bool): Whether to use Gated Linear Unit (GLU) variants in experts.
        gate (nn.Linear): The gating network to route tokens to experts.
        parallel_experts (ParallelExpert): The parallel experts module.
        shared_experts (nn.ModuleList): List of shared experts, if any (None otherwise).
    '''

    def __init__(
        self,
        model_dim: int,
        top_k: int,
        num_experts: int,
        num_shared_experts: int = 0,
        use_aux_loss: bool = False,
        use_gate: bool = False,
        capacity_factor: float = 1.25,
    ) -> None:
        '''
        Initializes the ParallelMoE model.

        Args:
            model_dim (int): The dimension of the model.
            top_k (int): The number of experts to route each token to.
            num_experts (int): The total number of experts.
            num_shared_experts (int): The number of shared experts that process all tokens.
            use_aux_loss (bool): Whether to use auxiliary loss for load balancing.
            use_gate (bool): Whether to use Gated Linear Unit (GLU) variants in experts.
            capacity_factor (float): Factor to determine the capacity of each expert.

        Raises:
            ValueError: If num_experts < 1 or top_k is not in [1, num_experts].
        '''
        super().__init__()
        if num_experts < 1:
            raise ValueError(f'num_experts must be >= 1, got {num_experts}')
        if not 1 <= top_k <= num_experts:
            raise ValueError(f'top_k must be in [1, num_experts={num_experts}], got {top_k}')

        self.model_dim = model_dim
        self.top_k = top_k
        self.num_experts = num_experts
        self.num_shared_experts = num_shared_experts
        self.use_aux_loss = use_aux_loss
        self.capacity_factor = capacity_factor
        self.use_gate = use_gate

        hidden_dim = model_dim * 4

        self.gate = nn.Linear(model_dim, num_experts, bias=False)

        self.parallel_experts = ParallelExpert(
            num_experts, model_dim, hidden_dim, model_dim, use_gate=use_gate
        )

        self.shared_experts = None
        if num_shared_experts > 0:
            act_layer = "silu" if use_gate else "gelu"
            self.shared_experts = nn.ModuleList([
                MLP(
                    in_features=model_dim,
                    hidden_features=hidden_dim,
                    out_features=model_dim,
                    use_gate=use_gate,
                    act_layer=act_layer
                ) for _ in range(num_shared_experts)
            ])

    def count_params(
        self,
        trainable_only: bool = False,
        active_only: bool = False,
        human_readable: bool = False,
        seen: Union[set, None] = None,
    ) -> Union[int, str]:
        '''
        Counts the number of parameters in the model.

        The signature matches BasicModel.count_params, so a parent BasicModel can recurse
        into this module (it passes `seen=`).

        Args:
            trainable_only (bool): If True, counts only trainable parameters.
            active_only (bool): If True, counts only active parameters: the gate, the shared
                experts, and top_k / num_experts of the routed-expert weights.
            human_readable (bool): If True, returns a string with a unit suffix (K, M, B).
            seen (set, optional): Parameters already counted elsewhere; updated in place.

        Returns:
            Union[int, str]: The number of parameters (a string if human_readable).
        '''
        if seen is None:
            seen = set()

        if not active_only:
            return super().count_params(trainable_only, active_only, human_readable, seen)

        def _tally(params) -> int:
            # Sum numel of parameters not yet seen (honouring trainable_only) and mark them as seen.
            n = 0
            for p in params:
                if p in seen or (trainable_only and not p.requires_grad):
                    continue
                seen.add(p)
                n += p.numel()
            return n

        total = _tally(self.gate.parameters())

        if self.shared_experts is not None:
            total += _tally(self.shared_experts.parameters())

        # All routed experts live in one stacked tensor pair; each token activates top_k slices of it.
        parallel_params = _tally(self.parallel_experts.parameters())
        total += parallel_params // self.num_experts * self.top_k

        return self._format_count(total) if human_readable else total

    @staticmethod
    def _format_count(total: int) -> str:
        # Same unit scheme as BasicModel.count_params(human_readable=True).
        for scale, suffix in ((1e9, 'B'), (1e6, 'M'), (1e3, 'K')):
            if total >= scale:
                return f'{total / scale:.2f}{suffix}'
        return str(total)

    @property
    def info(self) -> MoEInfo:
        '''
        Returns information about the MoE model's parameters.

        Returns:
            MoEInfo: An object containing total and active parameter counts.
        '''
        total = self.count_params(active_only=False)
        active = self.count_params(active_only=True)
        return MoEInfo(total_count=total, active_count=active)

    def forward(self, x: torch.Tensor) -> MoEOutput:
        '''
        Performs the forward pass of the ParallelMoE model.

        This method routes tokens to experts, computes expert outputs in parallel, adds shared
        expert outputs, and returns the combined result.

        Args:
            x (torch.Tensor): Input tensor of shape (..., model_dim), typically
                (batch_size, seq_len, model_dim). All leading dimensions are treated as tokens.

        Returns:
            MoEOutput: `output` has the same shape as `x`. `aux_loss` is a scalar tensor when
                use_aux_loss is set and the module is training, otherwise None.
        '''
        original_shape = x.shape
        dim = original_shape[-1]
        x_flat = x.reshape(-1, dim)  # [Tokens, Dim]
        num_tokens = x_flat.size(0)

        # Shared experts process every token.
        shared_output = torch.zeros_like(x_flat)
        if self.shared_experts is not None:
            for expert in self.shared_experts:
                shared_output = shared_output + expert(x_flat)

        # Gating
        router_logits = self.gate(x_flat)  # [Tokens, Experts]
        routing_probs = F.softmax(router_logits, dim=-1)

        # weights: [Tokens, TopK], indices: [Tokens, TopK]
        topk_weights, topk_indices = torch.topk(routing_probs, self.top_k, dim=-1)
        if self.top_k > 1:
            # Renormalise over the selected experts. Skipped for top_k == 1: there p / p is
            # identically 1, which would give the gate zero gradient from the task loss.
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        # Aux Loss
        aux_loss = None
        if self.use_aux_loss and self.training:
            selected = torch.zeros_like(routing_probs).scatter_(1, topk_indices, 1.0)
            density = selected.mean(dim=0)
            density_proxy = routing_probs.mean(dim=0)
            aux_loss = self.num_experts * (density * density_proxy).sum()

        capacity = int(num_tokens * self.top_k / self.num_experts * self.capacity_factor)
        capacity = max(capacity, 4)

        # Group the Tokens*TopK assignments by expert. A stable sort keeps token order inside
        # each expert, so when an expert overflows the earliest tokens are the ones kept.
        flat_topk_indices = topk_indices.reshape(-1)  # [Tokens * TopK]
        sort_vals, sort_indices = flat_topk_indices.sort(stable=True)

        # Flat position p belongs to token p // top_k.
        x_expanded = x_flat.index_select(0, sort_indices // self.top_k)  # [Tokens*TopK, Dim]

        # Exact per-expert assignment counts: [Num_Experts]
        expert_counts = torch.bincount(flat_topk_indices, minlength=self.num_experts)

        # Slot of each sorted assignment within its expert = sorted position - expert's start offset.
        cumsum_counts = torch.cat([expert_counts.new_zeros(1), expert_counts.cumsum(0)])
        expert_starts = cumsum_counts[sort_vals]
        range_indices = torch.arange(sort_vals.size(0), device=x.device)
        indices_in_expert = range_indices - expert_starts

        keep = indices_in_expert < capacity  # assignments beyond capacity are dropped

        valid_indices = indices_in_expert[keep]  # [Valid_Count]
        valid_experts = sort_vals[keep]  # [Valid_Count]

        parallel_inputs = torch.zeros(
            self.num_experts, capacity, dim,
            dtype=x.dtype, device=x.device
        )
        # index: (Expert_ID, Capacity_ID)
        parallel_inputs[valid_experts, valid_indices] = x_expanded[keep]

        parallel_outputs = self.parallel_experts(parallel_inputs)  # [Num_Experts, Capacity, Dim]
        valid_outputs = parallel_outputs[valid_experts, valid_indices]  # [Valid_Count, Dim]

        original_positions = sort_indices[keep]  # [Valid_Count]
        token_ids = original_positions.div(self.top_k, rounding_mode='floor')
        valid_weights = topk_weights.reshape(-1)[original_positions].unsqueeze(-1)

        # Scatter-add each weighted expert output back to its token.
        final_output = torch.zeros_like(x_flat)
        final_output.index_add_(0, token_ids, valid_outputs * valid_weights)

        final_output = final_output + shared_output

        return MoEOutput(
            output=final_output.reshape(original_shape),
            aux_loss=aux_loss
        )
|
codon/model/__init__.py
ADDED
|
File without changes
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .motif_a1 import MotifA1, MotifA1Output
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
from codon.base import *
|
|
2
|
+
|
|
3
|
+
from codon.block.transformer import TransformerMoEDecoder
|
|
4
|
+
from codon.block.embedding import RotaryEmbedding
|
|
5
|
+
|
|
6
|
+
from typing import Optional, List, Tuple
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
|
|
9
|
+
@dataclass
class MotifA1Output:
    '''
    Result container returned by MotifA1.forward.

    Fields, in positional order:
        logits: Output of the model's vocabulary projection.
        past_key_values: Per-layer cache entries collected when use_cache=True, otherwise None.
        aux_loss: Sum of the auxiliary losses reported by the decoder layers, or None if none were reported.
        attentions: Per-layer attention weights collected when output_attentions=True, otherwise None.
    '''

    # Required field; everything below it defaults to None.
    logits: torch.Tensor

    # One pair of tensors per decoder layer.
    past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None

    aux_loss: Optional[torch.Tensor] = None

    attentions: Optional[List[torch.Tensor]] = None
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class MotifA1(BasicModel):
    '''
    Language model built from a stack of TransformerMoEDecoder layers.

    Pipeline: token embedding -> dropout -> `num_layers` TransformerMoEDecoder layers (all
    sharing one RotaryEmbedding of size `model_dim // num_heads`) -> RMSNorm -> bias-free
    linear projection to `vocab_size` logits. Auxiliary losses reported by the layers are summed.
    '''

    def __init__(
        self,
        vocab_size: int = 32000,
        model_dim: int = 768,
        num_layers: int = 8,
        num_heads: int = 8,
        num_kv_heads: int = 2,
        dropout: float = 0.1,
        tie_weights: bool = False,
        top_k: int = 1,
        num_experts: int = 3,
        num_shared_experts: int = 1,
    ):
        '''
        Build the model.

        Args:
            vocab_size: Number of token ids for the embedding and output projection.
            model_dim: Hidden size.
            num_layers: Number of TransformerMoEDecoder layers.
            num_heads: Number of attention heads; the rotary embedding uses model_dim // num_heads.
            num_kv_heads: Number of key/value heads passed to each layer.
            dropout: Dropout probability (embedding dropout and per-layer dropout).
            tie_weights: If True, the output projection shares its weight with the token embedding.
            top_k: Experts selected per token in each layer's MoE (previously hard-coded to 1).
            num_experts: Routed experts per layer (previously hard-coded to 3).
            num_shared_experts: Shared experts per layer (previously hard-coded to 1).
        '''
        super().__init__()

        self.token_emb = nn.Embedding(vocab_size, model_dim)
        self.position_emb = RotaryEmbedding(model_dim // num_heads)
        self.dropout = nn.Dropout(dropout)

        self.decoder = nn.ModuleList([
            TransformerMoEDecoder(
                model_dim=model_dim,
                num_heads=num_heads,
                num_kv_heads=num_kv_heads,
                top_k=top_k,
                num_experts=num_experts,
                num_shared_experts=num_shared_experts,
                use_expert_gate=True,
                use_qk_norm=True,
                use_attn_gate=True,
                dropout=dropout,
                idx=str(idx)
            )
            for idx in range(num_layers)
        ])

        self.norm = nn.RMSNorm(model_dim)
        self.proj_out = nn.Linear(model_dim, vocab_size, bias=False)

        if tie_weights:
            self.proj_out.weight = self.token_emb.weight

        # Re-initialises every nn.Linear / nn.Embedding in the model, including those inside decoder layers.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        '''
        Initialise nn.Linear and nn.Embedding modules with N(0, 0.02) weights.

        Linear biases are zeroed. If an embedding has a padding_idx, that row is zeroed.
        Other module types are left untouched.
        '''
        std = 0.02
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.padding_idx is not None:
                torch.nn.init.zeros_(module.weight[module.padding_idx])

    def forward(
        self,
        input_ids: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        start_pos: int = 0,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
        output_attentions: bool = False
    ) -> MotifA1Output:
        '''
        Run the model on a batch of token ids.

        Args:
            input_ids: Token ids fed to the embedding (presumably (batch, seq_len); TODO confirm).
            mask: Passed unchanged to every layer as `attention_mask`.
            start_pos: Passed to every layer as `embedding_start` (rotary position offset).
            past_key_values: Optional per-layer cache; entry i is given to layer i.
            use_cache: If True, collect each layer's `past_key_value` into the output.
            output_attentions: If True, collect each layer's `attention_weights`.

        Returns:
            MotifA1Output with logits, the new cache (or None), the summed aux loss
            (or None if no layer reported one), and attentions (or None).
        '''
        x = self.token_emb(input_ids)
        x = self.dropout(x)

        new_kv_cache = [] if use_cache else None
        all_attentions = [] if output_attentions else None
        aux_loss = None

        for i, layer in enumerate(self.decoder):
            layer_past = past_key_values[i] if past_key_values is not None else None

            out = layer(
                hidden_states=x,
                attention_mask=mask,
                output_attentions=output_attentions,
                position_emb=self.position_emb,
                embedding_start=start_pos,
                past_key_value=layer_past,
                use_cache=use_cache
            )

            x = out.output

            if use_cache:
                new_kv_cache.append(out.past_key_value)

            if output_attentions:
                all_attentions.append(out.attention_weights)

            if out.aux_loss is not None:
                # Out-of-place accumulation: `+=` would mutate the tensor returned by the
                # first layer, which may alias module state or autograd graph tensors.
                aux_loss = out.aux_loss if aux_loss is None else aux_loss + out.aux_loss

        x = self.norm(x)
        logits = self.proj_out(x)

        return MotifA1Output(
            logits=logits,
            past_key_values=new_kv_cache,
            aux_loss=aux_loss,
            attentions=all_attentions
        )
|