codon-model 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
codon/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from typing import Optional
2
+
3
# Package version string; matches the published wheel (codon-model 0.0.1).
__version__ = '0.0.1'

# Module-level seed slot, initialised to None. Nothing visible in this package
# reads it here — presumably set by callers before seeding RNGs (TODO confirm).
__seed__: Optional[int] = None
codon/base.py ADDED
@@ -0,0 +1,167 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from typing import Callable, Any, Iterator, Union
5
+
6
+ from safetensors.torch import save_model as safe_save_model
7
+ from safetensors.torch import load_model as safe_load_model
8
+
9
+
10
class BasicModel(nn.Module):
    '''
    Base class for all models, providing common functionality such as gradient
    checkpointing, parameter counting and state-dict save/load helpers.
    '''

    def __init__(self):
        '''
        Initialize the BasicModel with gradient checkpointing disabled.
        '''
        super(BasicModel, self).__init__()
        self.gradient_checkpointing: bool = False

    @property
    def device(self) -> torch.device:
        '''
        Get the device of the model.

        Returns:
            torch.device: The device of the first registered parameter, or
            ``cpu`` if the model has no parameters.
        '''
        try:
            return next(self.parameters()).device
        except StopIteration:
            return torch.device('cpu')

    def set_checkpoint(self, value: bool) -> None:
        '''
        Enable or disable gradient checkpointing for the model and every
        ``BasicModel`` sub-module it contains.

        Args:
            value (bool): True to enable gradient checkpointing, False to disable.
        '''
        self.gradient_checkpointing = value
        for model in self.modules():
            if isinstance(model, BasicModel) and model is not self:
                model.gradient_checkpointing = value

    def checkpoint(self, function: Callable, *args, **kwargs) -> Any:
        '''
        Apply gradient checkpointing to a function if enabled and in training mode.

        Args:
            function (Callable): The function to be checkpointed.
            *args: Positional arguments for the function.
            **kwargs: Keyword arguments for the function.

        Returns:
            Any: The output of the function.
        '''
        if self.gradient_checkpointing and self.training:
            # Import the submodule explicitly: a bare ``import torch`` is not
            # guaranteed to expose ``torch.utils.checkpoint`` as an attribute.
            from torch.utils.checkpoint import checkpoint as _checkpoint
            return _checkpoint(function, *args, use_reentrant=False, **kwargs)
        return function(*args, **kwargs)

    def get_params(self, trainable_only: bool = False) -> Iterator[torch.nn.Parameter]:
        '''
        Get an iterator over the model parameters.

        Args:
            trainable_only (bool, optional): If True, only yield parameters that
                require gradients. Defaults to False.

        Returns:
            Iterator[torch.nn.Parameter]: An iterator over the model parameters.
        '''
        if trainable_only:
            return (p for p in self.parameters() if p.requires_grad)
        return self.parameters()

    def count_params(self, trainable_only: bool = False, active_only: bool = False, human_readable: bool = False, seen: Union[set, None] = None) -> Union[int, str]:
        '''
        Count the number of parameters in the model.

        Args:
            trainable_only (bool, optional): If True, count only trainable parameters.
                Defaults to False.
            active_only (bool, optional): If True, walk the module tree so that
                ``BasicModel`` children can report only their active parameters
                (e.g. for MoE). Defaults to False.
            human_readable (bool, optional): If True, return a string with a K/M/B
                suffix. Defaults to False.
            seen (set, optional): Parameters already counted; shared across the
                recursion so tied/shared parameters are counted once. Defaults to None.

        Returns:
            Union[int, str]: The total number of parameters (a string when
            ``human_readable`` is True).
        '''
        if seen is None:
            seen = set()

        if not active_only:
            total = 0
            for p in self.get_params(trainable_only):
                if p not in seen:
                    seen.add(p)
                    total += p.numel()
        else:
            total = self._count_params_recursive(self, trainable_only, active_only, seen)

        if human_readable:
            if total >= 1e9:
                return f'{total / 1e9:.2f}B'
            elif total >= 1e6:
                return f'{total / 1e6:.2f}M'
            elif total >= 1e3:
                return f'{total / 1e3:.2f}K'
            return str(total)

        return total

    @staticmethod
    def _count_params_recursive(module: nn.Module, trainable_only: bool, active_only: bool, seen: set) -> int:
        '''
        Count ``module``'s own parameters, then recurse into its children.

        ``BasicModel`` children are delegated to their own ``count_params`` so
        subclasses can override what "active" means; other modules are walked here.
        '''
        total = 0
        for p in module.parameters(recurse=False):
            if p not in seen:
                if not trainable_only or p.requires_grad:
                    seen.add(p)
                    total += p.numel()

        for child in module.children():
            if isinstance(child, BasicModel):
                total += child.count_params(trainable_only, active_only, seen=seen)
            else:
                total += BasicModel._count_params_recursive(child, trainable_only, active_only, seen)

        return total

    def load_pretrained(self, path: str) -> None:
        '''
        Load pretrained weights from a file into this model.

        ``.safetensors`` files are loaded with safetensors. Any other file is read
        with ``torch.load``; if it holds a dict containing ``model_state_dict`` or
        ``state_dict``, that entry is used as the state dict.

        Args:
            path (str): The path to the model file.
        '''
        if path.endswith('.safetensors'):
            safe_load_model(self, path)
            return

        # NOTE(security): torch.load unpickles the file and can execute arbitrary
        # code; only load checkpoints from trusted sources (consider passing
        # ``weights_only=True`` if checkpoints only contain tensors/containers).
        state_dict = torch.load(path, map_location=self.device)

        if isinstance(state_dict, dict):
            if 'model_state_dict' in state_dict:
                state_dict = state_dict['model_state_dict']
            elif 'state_dict' in state_dict:
                state_dict = state_dict['state_dict']

        self.load_state_dict(state_dict)

    def save_pretrained(self, path: str) -> None:
        '''
        Save the model weights to a file.

        Uses safetensors for paths ending in ``.safetensors``, otherwise
        ``torch.save`` of the plain state dict.

        Args:
            path (str): The path to save the model file.
        '''
        if path.endswith('.safetensors'):
            safe_save_model(self, path)
        else:
            state_dict = self.state_dict()
            torch.save(state_dict, path)
codon/exp/__init__.py ADDED
File without changes
codon/exp/moe.py ADDED
@@ -0,0 +1,307 @@
1
+ import torch.nn.functional as F
2
+
3
+ from codon.base import *
4
+ from codon.block.mlp import MLP
5
+ from codon.block.moe import *
6
+
7
+ import math
8
+
9
+
10
class ParallelExpert(nn.Module):
    '''
    A module that computes multiple expert outputs in parallel.

    This module manages weights for multiple experts and processes inputs efficiently using batch matrix multiplication.

    Attributes:
        use_gate (bool): Whether to use Gated Linear Unit (GLU) variants.
        num_experts (int): The number of experts.
        weight1 (nn.Parameter): First weight matrix with shape (num_experts, in_features, inter_dim).
        weight2 (nn.Parameter): Second weight matrix with shape (num_experts, hidden_features, out_features).
        act (nn.Module): Activation function.
        dropout (nn.Dropout): Dropout layer.
    '''

    def __init__(self, num_experts: int, in_features: int, hidden_features: int, out_features: int, use_gate: bool = False, dropout: float = 0.1) -> None:
        '''
        Initializes the ParallelExpert module.

        Args:
            num_experts (int): The number of experts.
            in_features (int): Size of each input sample.
            hidden_features (int): Size of the hidden layer.
            out_features (int): Size of each output sample.
            use_gate (bool): If True, uses SiLU activation with gating; otherwise, uses GELU.
            dropout (float): Dropout probability.
        '''
        super().__init__()
        self.use_gate = use_gate
        self.num_experts = num_experts
        self.dropout_p = dropout

        # With gating, weight1 produces the gate and value halves in one matmul.
        inter_dim = hidden_features * 2 if use_gate else hidden_features

        # Stored as [Experts, In, Out] so torch.bmm applies them directly.
        # Note this is the transpose of nn.Linear's [Out, In] weight layout.
        self.weight1 = nn.Parameter(torch.empty(num_experts, in_features, inter_dim))
        self.weight2 = nn.Parameter(torch.empty(num_experts, hidden_features, out_features))

        self.act = nn.SiLU() if use_gate else nn.GELU()
        self.dropout = nn.Dropout(dropout)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        '''
        Resets each expert's weights with Kaiming Uniform (a=sqrt(5)), like nn.Linear.

        ``kaiming_uniform_`` derives fan_in from ``size(1)`` of a 2-D tensor, which
        matches nn.Linear's [Out, In] layout. The per-expert matrices here are
        [In, Out], so the input fan is ``size(0)``; ``mode='fan_out'`` selects it.
        This gives each layer the bound 1/sqrt(its input width).
        '''
        for i in range(self.num_experts):
            nn.init.kaiming_uniform_(self.weight1[i], a=math.sqrt(5), mode='fan_out')
            nn.init.kaiming_uniform_(self.weight2[i], a=math.sqrt(5), mode='fan_out')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Performs the forward pass for all experts in parallel.

        Args:
            x (torch.Tensor): Input tensor with shape (num_experts, capacity, in_features).

        Returns:
            torch.Tensor: Output tensor with shape (num_experts, capacity, out_features).
        '''
        # [Num_Experts, Capacity, In] @ [Num_Experts, In, Inter] -> [Num_Experts, Capacity, Inter]
        h = torch.bmm(x, self.weight1)

        if self.use_gate:
            # GLU: the first half gates the second half.
            gate, val = h.chunk(2, dim=-1)
            h = self.act(gate) * val
        else:
            h = self.act(h)

        h = self.dropout(h)

        # [Num_Experts, Capacity, Hidden] @ [Num_Experts, Hidden, Out] -> [Num_Experts, Capacity, Out]
        out = torch.bmm(h, self.weight2)

        return out
89
+
90
class ParallelMoE(BasicModel):
    '''
    A Parallel Mixture-of-Experts (MoE) model.

    This model routes tokens to the top-k experts and computes their outputs in parallel. It also supports shared experts and auxiliary loss for load balancing.

    Attributes:
        model_dim (int): The dimension of the model.
        top_k (int): The number of experts to route each token to.
        num_experts (int): The total number of experts.
        num_shared_experts (int): The number of shared experts that process all tokens.
        use_aux_loss (bool): Whether to use auxiliary loss for load balancing.
        capacity_factor (float): Factor to determine the capacity of each expert.
        use_gate (bool): Whether to use Gated Linear Unit (GLU) variants in experts.
        gate (nn.Linear): The gating network to route tokens to experts.
        parallel_experts (ParallelExpert): The parallel experts module.
        shared_experts (nn.ModuleList): List of shared experts, if any.
    '''

    def __init__(
        self,
        model_dim: int,
        top_k: int,
        num_experts: int,
        num_shared_experts: int = 0,
        use_aux_loss: bool = False,
        use_gate: bool = False,
        capacity_factor: float = 1.25,
    ) -> None:
        '''
        Initializes the ParallelMoE model.

        Args:
            model_dim (int): The dimension of the model.
            top_k (int): The number of experts to route each token to.
            num_experts (int): The total number of experts.
            num_shared_experts (int): The number of shared experts that process all tokens.
            use_aux_loss (bool): Whether to use auxiliary loss for load balancing.
            use_gate (bool): Whether to use Gated Linear Unit (GLU) variants in experts.
            capacity_factor (float): Factor to determine the capacity of each expert.
        '''
        super().__init__()
        self.model_dim = model_dim
        self.top_k = top_k
        self.num_experts = num_experts
        self.num_shared_experts = num_shared_experts
        self.use_aux_loss = use_aux_loss
        self.capacity_factor = capacity_factor
        self.use_gate = use_gate

        hidden_dim = model_dim * 4

        self.gate = nn.Linear(model_dim, num_experts, bias=False)

        self.parallel_experts = ParallelExpert(
            num_experts, model_dim, hidden_dim, model_dim, use_gate=use_gate
        )

        self.shared_experts = None
        if num_shared_experts > 0:
            act_layer = "silu" if use_gate else "gelu"
            self.shared_experts = nn.ModuleList([
                MLP(
                    in_features=model_dim,
                    hidden_features=hidden_dim,
                    out_features=model_dim,
                    use_gate=use_gate,
                    act_layer=act_layer
                ) for _ in range(num_shared_experts)
            ])

    def count_params(self, trainable_only: bool = False, active_only: bool = False, human_readable: bool = False, seen: Union[set, None] = None) -> Union[int, str]:
        '''
        Counts the number of parameters in the model.

        The signature matches ``BasicModel.count_params`` so this module can be
        nested inside another ``BasicModel``. The parent's recursive counter calls
        ``child.count_params(..., seen=seen)``.

        Args:
            trainable_only (bool): If True, counts only trainable parameters.
            active_only (bool): If True, counts only active parameters: the router,
                all shared experts, and ``top_k`` of the ``num_experts`` routed experts.
            human_readable (bool): If True, return a string with a K/M/B suffix.
            seen (set, optional): Parameters already counted elsewhere (skipped here).

        Returns:
            Union[int, str]: The number of parameters.
        '''
        if not active_only:
            return super().count_params(
                trainable_only, active_only, human_readable=human_readable, seen=seen
            )

        if seen is None:
            seen = set()

        def _take(params) -> int:
            # Sum sizes of not-yet-counted parameters, honouring trainable_only.
            count = 0
            for p in params:
                if p in seen or (trainable_only and not p.requires_grad):
                    continue
                seen.add(p)
                count += p.numel()
            return count

        # The router and the shared experts run for every token.
        total = _take(self.gate.parameters())
        if self.shared_experts is not None:
            total += _take(self.shared_experts.parameters())

        # Experts are stacked along dim 0 with equal size, so one token touches
        # top_k / num_experts of the stacked weights.
        total += _take(self.parallel_experts.parameters()) // self.num_experts * self.top_k

        if human_readable:
            for threshold, suffix in ((1e9, 'B'), (1e6, 'M'), (1e3, 'K')):
                if total >= threshold:
                    return f'{total / threshold:.2f}{suffix}'
            return str(total)

        return total

    @property
    def info(self) -> MoEInfo:
        '''
        Returns information about the MoE model's parameters.

        Returns:
            MoEInfo: An object containing total and active parameter counts.
        '''
        total = self.count_params(active_only=False)
        active = self.count_params(active_only=True)
        return MoEInfo(total_count=total, active_count=active)

    def forward(self, x: torch.Tensor) -> MoEOutput:
        '''
        Performs the forward pass of the ParallelMoE model.

        Tokens are routed to their top-k experts. Each expert processes at most
        ``capacity`` tokens; assignments beyond that are dropped and contribute
        nothing from that expert. Expert outputs are combined with the renormalised
        router weights and added to the shared-expert output.

        Args:
            x (torch.Tensor): Input tensor of shape (..., model_dim), typically
                (batch_size, seq_len, model_dim).

        Returns:
            MoEOutput: The output of the MoE model, with the same shape as ``x``,
            and the auxiliary loss. ``aux_loss`` is None unless ``use_aux_loss``
            is set and the module is training.
        '''
        original_shape = x.shape
        dim = original_shape[-1]
        x_flat = x.reshape(-1, dim)  # [Tokens, Dim]
        num_tokens = x_flat.size(0)

        # Shared experts process every token.
        shared_output = torch.zeros_like(x_flat)
        if self.shared_experts is not None:
            for expert in self.shared_experts:
                shared_output = shared_output + expert(x_flat)

        # Gating
        router_logits = self.gate(x_flat)  # [Tokens, Experts]
        routing_probs = F.softmax(router_logits, dim=-1)

        # weights / indices: [Tokens, TopK]
        topk_weights, topk_indices = torch.topk(routing_probs, self.top_k, dim=-1)
        # Renormalise so each token's selected weights sum to 1.
        # NOTE(review): with top_k == 1 every weight becomes exactly 1.0, so the router
        # gets no gradient from the main loss (only via aux_loss) — confirm intended.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        # Load-balancing loss: num_experts * sum(routed fraction * mean probability).
        aux_loss = None
        if self.use_aux_loss and self.training:
            routed_mask = torch.zeros_like(routing_probs).scatter_(1, topk_indices, 1.0)
            density = routed_mask.mean(dim=0)
            density_proxy = routing_probs.mean(dim=0)
            aux_loss = self.num_experts * (density * density_proxy).sum()

        capacity = int(num_tokens * self.top_k / self.num_experts * self.capacity_factor)
        capacity = max(capacity, 4)

        # Flatten assignments to [Tokens * TopK]; position p belongs to token p // top_k.
        flat_topk_indices = topk_indices.reshape(-1)
        # Stable sort groups assignments by expert while keeping token order, so
        # earlier tokens win capacity slots deterministically.
        sort_vals, sort_indices = flat_topk_indices.sort(stable=True)

        # Exact per-expert assignment counts: [Num_Experts]
        expert_counts = torch.bincount(flat_topk_indices, minlength=self.num_experts)

        # Slot of each sorted assignment within its expert's queue.
        cumsum_counts = torch.cat([expert_counts.new_zeros(1), expert_counts.cumsum(0)])
        expert_starts = cumsum_counts[sort_vals]
        slot_in_expert = torch.arange(sort_vals.numel(), device=x.device) - expert_starts

        keep = slot_in_expert < capacity
        valid_slots = slot_in_expert[keep]         # [Valid_Count]
        valid_experts = sort_vals[keep]            # [Valid_Count]
        original_positions = sort_indices[keep]    # [Valid_Count]
        token_ids = original_positions.div(self.top_k, rounding_mode='floor')

        # Dispatch: scatter kept tokens into [Num_Experts, Capacity, Dim].
        parallel_inputs = torch.zeros(
            self.num_experts, capacity, dim,
            dtype=x.dtype, device=x.device
        )
        parallel_inputs[valid_experts, valid_slots] = x_flat.index_select(0, token_ids)

        parallel_outputs = self.parallel_experts(parallel_inputs)  # [Num_Experts, Capacity, Dim]

        # Combine: gather kept outputs, weight them, and sum back per token.
        valid_outputs = parallel_outputs[valid_experts, valid_slots]
        valid_weights = topk_weights.reshape(-1)[original_positions].unsqueeze(-1)
        # Cast so index_add_ works when expert outputs and inputs differ in dtype (e.g. autocast).
        weighted_output = (valid_outputs * valid_weights).to(x_flat.dtype)

        final_output = torch.zeros_like(x_flat)
        final_output.index_add_(0, token_ids, weighted_output)

        final_output = final_output + shared_output

        return MoEOutput(
            output=final_output.reshape(original_shape),
            aux_loss=aux_loss
        )
File without changes
@@ -0,0 +1 @@
1
+ from .motif_a1 import MotifA1, MotifA1Output
@@ -0,0 +1,121 @@
1
+ from codon.base import *
2
+
3
+ from codon.block.transformer import TransformerMoEDecoder
4
+ from codon.block.embedding import RotaryEmbedding
5
+
6
+ from typing import Optional, List, Tuple
7
+ from dataclasses import dataclass
8
+
9
@dataclass
class MotifA1Output:
    '''
    Result container returned by ``MotifA1.forward``.

    Attributes:
        logits: Output of the final vocabulary projection.
        past_key_values: One cache entry per decoder layer when ``use_cache`` was
            requested, otherwise None.
        aux_loss: Sum of the per-layer auxiliary losses, or None when no layer
            produced one.
        attentions: Per-layer attention weights when ``output_attentions`` was
            requested, otherwise None.
    '''

    # Required; always populated.
    logits: torch.Tensor
    # Optional extras, defaulting to None.
    past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None
    aux_loss: Optional[torch.Tensor] = None
    attentions: Optional[List[torch.Tensor]] = None
15
+
16
+
17
class MotifA1(BasicModel):
    '''
    Decoder-only language model: token embedding -> stack of
    ``TransformerMoEDecoder`` layers -> RMSNorm -> vocabulary projection.
    '''

    def __init__(
        self,
        vocab_size: int = 32000,
        model_dim: int = 768,
        num_layers: int = 8,
        num_heads: int = 8,
        num_kv_heads: int = 2,
        dropout: float = 0.1,
        tie_weights: bool = False
    ):
        '''
        Build the model.

        Args:
            vocab_size (int): Number of token ids (embedding rows / logits width).
            model_dim (int): Hidden size.
            num_layers (int): Number of decoder layers.
            num_heads (int): Attention heads per layer.
            num_kv_heads (int): Key/value heads per layer.
            dropout (float): Dropout applied to embeddings and passed to each layer.
            tie_weights (bool): If True, the output projection shares the token
                embedding weight.
        '''
        super().__init__()

        self.token_emb = nn.Embedding(vocab_size, model_dim)
        # Rotary embedding sized to the per-head dimension.
        self.position_emb = RotaryEmbedding(model_dim // num_heads)
        self.dropout = nn.Dropout(dropout)

        self.decoder = nn.ModuleList([
            TransformerMoEDecoder(
                model_dim=model_dim,
                num_heads=num_heads,
                num_kv_heads=num_kv_heads,
                top_k=1,
                num_experts=3,
                num_shared_experts=1,
                use_expert_gate=True,
                use_qk_norm=True,
                use_attn_gate=True,
                dropout=dropout,
                idx=str(idx)
            )
            for idx in range(num_layers)
        ])

        self.norm = nn.RMSNorm(model_dim)
        self.proj_out = nn.Linear(model_dim, vocab_size, bias=False)

        if tie_weights:
            self.proj_out.weight = self.token_emb.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        '''
        Initialise Linear and Embedding weights from N(0, 0.02).

        Linear biases are zeroed, as is the padding row of an Embedding.
        Other module types are left untouched.
        '''
        std = 0.02
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.padding_idx is not None:
                torch.nn.init.zeros_(module.weight[module.padding_idx])

    def forward(
        self,
        input_ids: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        start_pos: int = 0,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
        output_attentions: bool = False
    ) -> MotifA1Output:
        '''
        Run the model on a batch of token ids.

        Args:
            input_ids (torch.Tensor): Token ids fed to the embedding.
            mask (torch.Tensor, optional): Attention mask forwarded unchanged to
                every decoder layer.
            start_pos (int): Forwarded to each layer as ``embedding_start``
                (presumably the position offset for rotary embeddings — confirm).
            past_key_values (list, optional): Per-layer cache, indexed by layer.
            use_cache (bool): Collect each layer's ``past_key_value`` into the output.
            output_attentions (bool): Collect each layer's attention weights.

        Returns:
            MotifA1Output: Logits plus the optional cache, the summed auxiliary
            loss (None if no layer reported one) and the optional attentions.
        '''
        x = self.token_emb(input_ids)
        x = self.dropout(x)

        new_kv_cache = [] if use_cache else None
        all_attentions = [] if output_attentions else None
        aux_loss = None

        for i, layer in enumerate(self.decoder):
            layer_past = past_key_values[i] if past_key_values is not None else None

            out = layer(
                hidden_states=x,
                attention_mask=mask,
                output_attentions=output_attentions,
                position_emb=self.position_emb,
                embedding_start=start_pos,
                past_key_value=layer_past,
                use_cache=use_cache
            )

            x = out.output

            if use_cache:
                new_kv_cache.append(out.past_key_value)

            if output_attentions:
                all_attentions.append(out.attention_weights)

            if out.aux_loss is not None:
                # Out-of-place add: the first term aliases a layer's own aux_loss
                # tensor, so ``+=`` would silently mutate that layer's output.
                aux_loss = out.aux_loss if aux_loss is None else aux_loss + out.aux_loss

        x = self.norm(x)
        logits = self.proj_out(x)

        return MotifA1Output(
            logits=logits,
            past_key_values=new_kv_cache,
            aux_loss=aux_loss,
            attentions=all_attentions
        )