rxnn-0.1.48-py3-none-any.whl → rxnn-0.1.50-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/experimental/attention.py +15 -15
- rxnn/experimental/models.py +3 -0
- rxnn/transformers/layers.py +12 -7
- rxnn/transformers/models.py +3 -3
- {rxnn-0.1.48.dist-info → rxnn-0.1.50.dist-info}/METADATA +1 -1
- {rxnn-0.1.48.dist-info → rxnn-0.1.50.dist-info}/RECORD +8 -8
- {rxnn-0.1.48.dist-info → rxnn-0.1.50.dist-info}/LICENSE +0 -0
- {rxnn-0.1.48.dist-info → rxnn-0.1.50.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -94,16 +94,17 @@ class GroupedMoeAttention(GroupedQueryAttention):
         B, S, D = key.shape
         key_flat = key.reshape(-1, D)
         weights, indices = self.router(key_flat)  # (B*S, num_groups), (B*S, num_groups)
-        weights = weights.view(B, self.num_groups,
-        indices = indices.view(B, self.num_groups
+        weights = weights.view(B, S, self.num_groups, 1)
+        indices = indices.view(B, S, self.num_groups)

         # Compute all experts' projections
-        k_all = self.k_proj(key_flat).view(B, S, self.num_experts, -1)
-        v_all = self.v_proj(value).view(B, S, self.num_experts, -1)
+        k_all = self.k_proj(key_flat).view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
+        v_all = self.v_proj(value).view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]

         # Gather top-k experts using expanded indices
-
-
+        expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))  # [B, num_groups, S, head_dim]
+        selected_k = torch.gather(k_all, 2, expanded_indices)  # [B, num_groups, S, head_dim]
+        selected_v = torch.gather(v_all, 2, expanded_indices)  # [B, num_groups, S, head_dim]

         # Weighted
         weighted_k = (selected_k * weights).to(selected_k.device, dtype=selected_k.dtype)  # [B, S, num_groups, head_dim]
@@ -122,8 +123,6 @@ class GroupedMoeAttention(GroupedQueryAttention):
         k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
         v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)

-        print(q.size(), k.size(), v.size())
-
         return q, k, v


@@ -208,17 +207,18 @@ class DeepMoeAttention(GroupedMoeAttention):
     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
         B, T, D = query.shape
         query_flat = query.reshape(-1, D)
-
-
-
+        weights_q, indices_q = self.query_router(query_flat)
+        weights_q = weights_q.view(B, T, self.num_query_groups, 1)
+        indices_q = indices_q.view(B, T, self.num_query_groups)

-        q_all = self.q_proj(query_flat).view(B, T, self.num_query_experts, -1)
+        q_all = self.q_proj(query_flat).view(B, T, self.num_query_experts, -1)  # [B, num_groups, S, head_dim]

-        # Gather top-k experts
-
+        # Gather top-k experts
+        expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1))  # [B, T, num_query_groups, head_dim]
+        selected_q = torch.gather(q_all, 2, expanded_indices)  # [B, T, num_query_groups, head_dim]

         # Weighted sum
-        q = (selected_q *
+        q = (selected_q * weights_q).to(selected_q.device, dtype=selected_q.dtype)  # [B, T, num_query_groups, head_dim]
         q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3)  # [B, num_query_groups, T, head_dim]

         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
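The substantive change in both `_forward_qkv` implementations is the router-output reshape: weights and indices now keep the sequence dimension (`[B, S, num_groups]`) before the per-token expert projections are gathered and weighted. Below is a minimal, self-contained sketch of that gather-and-weight pattern; all tensor names and sizes are illustrative and are not the library's actual modules:

```python
import torch

B, S, num_experts, num_groups, head_dim = 2, 4, 8, 2, 16

# Per-token projections from every expert: [B, S, num_experts, head_dim]
k_all = torch.randn(B, S, num_experts, head_dim)

# Router output reshaped to keep the sequence dimension
weights = torch.rand(B, S, num_groups, 1)                    # [B, S, num_groups, 1]
indices = torch.randint(0, num_experts, (B, S, num_groups))  # [B, S, num_groups]

# Expand indices over head_dim and gather the selected experts along the expert axis
expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)  # [B, S, num_groups, head_dim]
selected_k = torch.gather(k_all, 2, expanded_indices)                  # [B, S, num_groups, head_dim]

# Scale each selected projection by its router weight
weighted_k = selected_k * weights                                      # [B, S, num_groups, head_dim]
print(weighted_k.shape)  # torch.Size([2, 4, 2, 16])
```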
rxnn/experimental/models.py
CHANGED
@@ -83,6 +83,8 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="tex
            num_query_experts=att_num_query_experts,
            num_query_groups=att_num_query_groups)

+        use_moe_att = att_type in ['gma', 'dma', 'gma_s', 'dma_s']
+
         self.model = ClassicTransformerDecoder(
             embed_dim,
             vocab_size,
@@ -99,6 +101,7 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="tex
                     ff_dropout=ff_dropout,
                     use_rms_norm=use_rms_norm,
                     self_attention=att_init(),
+                    use_moe_att=use_moe_att,
                 ) for _ in range(num_layers)
             ]),
             use_flash_attention=use_flash_attention,
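The new `use_moe_att` flag is derived once from the attention type and then threaded into every decoder layer, so each layer knows whether its attention modules carry MoE routers. A rough sketch of that wiring, using a hypothetical `DummyLayer` as a stand-in for the real layer class:

```python
import torch.nn as nn

class DummyLayer(nn.Module):
    """Illustrative stand-in for a transformer layer: it just records the flag."""
    def __init__(self, embed_dim: int, use_moe_att: bool = False):
        super().__init__()
        self.use_moe_att = use_moe_att
        self.proj = nn.Linear(embed_dim, embed_dim)

att_type = 'gma'  # e.g. one of 'mha', 'gqa', 'gma', 'dma', 'gma_s', 'dma_s'
use_moe_att = att_type in ['gma', 'dma', 'gma_s', 'dma_s']

layers = nn.ModuleList([DummyLayer(embed_dim=128, use_moe_att=use_moe_att) for _ in range(4)])
print(all(layer.use_moe_att for layer in layers))  # True for the MoE attention variants
```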
rxnn/transformers/layers.py
CHANGED
@@ -22,6 +22,7 @@ class ReactiveTransformerLayer(nn.Module):
             use_moe: bool = False,
             num_experts: int = 1,
             moe_top_k: int = 1,
+            use_moe_att: bool = False,
             *args,
             **kwargs,
     ):
@@ -54,6 +55,7 @@ class ReactiveTransformerLayer(nn.Module):
         self.norm3 = nn.LayerNorm(embed_dim)
         self.use_post_norm = use_post_norm
         self.use_moe = use_moe
+        self.use_moe_att = use_moe_att

     def trainable_cross_attention_(self, is_trainable: bool):
         for param in self.memory_cross_attention.parameters():
@@ -62,12 +64,13 @@ class ReactiveTransformerLayer(nn.Module):
     def moe_router_loss(self):
         ff_router_loss = self.ff.router_loss() if self.use_moe else None
         att_router_loss = None
-        if self.
-
-
-
-
-
+        if self.use_moe_att:
+            if self.attention.router_loss is not None and self.memory_cross_attention.router_loss is not None:
+                att_router_loss = (self.attention.router_loss() + self.memory_cross_attention.router_loss()) / 2
+            elif self.attention.router_loss is not None:
+                att_router_loss = self.attention.router_loss()
+            elif self.memory_cross_attention.router_loss is not None:
+                att_router_loss = self.memory_cross_attention.router_loss()

         if ff_router_loss is not None and att_router_loss is not None:
             return (ff_router_loss + att_router_loss) / 2
@@ -123,6 +126,7 @@ class ClassicTransformerLayer(nn.Module):
             use_moe: bool = False,
             num_experts: int = 1,
             moe_top_k: int = 1,
+            use_moe_att: bool = False,
             *args,
             **kwargs,
     ):
@@ -151,10 +155,11 @@
         self.norm2 = nn.LayerNorm(embed_dim)
         self.use_post_norm = use_post_norm
         self.use_moe = use_moe
+        self.use_moe_att = use_moe_att

     def moe_router_loss(self):
         ff_router_loss = self.ff.router_loss() if self.use_moe else None
-        att_router_loss = self.attention.router_loss() if self.attention.router_loss is not None else None
+        att_router_loss = self.attention.router_loss() if self.use_moe_att and self.attention.router_loss is not None else None

         if ff_router_loss is not None and att_router_loss is not None:
             return (ff_router_loss + att_router_loss) / 2
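The `moe_router_loss` changes guard the attention router loss behind the new `use_moe_att` flag and average whichever losses are actually present. A simplified, standalone version of that selection logic; the function name and the final fallback behavior are illustrative assumptions, not the library's exact code:

```python
from typing import Optional
import torch

def combine_router_losses(ff_loss: Optional[torch.Tensor],
                          att_loss: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Average the router losses that exist; return None when neither is present."""
    if ff_loss is not None and att_loss is not None:
        return (ff_loss + att_loss) / 2
    if ff_loss is not None:
        return ff_loss
    return att_loss

# Example: only the feed-forward MoE router produced a loss, then both
print(combine_router_losses(torch.tensor(0.4), None))               # tensor(0.4000)
print(combine_router_losses(torch.tensor(0.4), torch.tensor(0.2)))  # tensor(0.3000)
```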
rxnn/transformers/models.py
CHANGED
@@ -38,8 +38,8 @@ class ReactiveTransformerBase(nn.Module):
             self.layers[i].trainable_cross_attention_(is_trainable)

     def moe_router_loss(self):
-        return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe] + [
-            self.shared_layers[i].moe_router_loss() for i in range(self.num_shared_layers) if self.shared_layers[i].use_moe]).mean()
+        return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe or self.layers[i].use_moe_att] + [
+            self.shared_layers[i].moe_router_loss() for i in range(self.num_shared_layers) if self.shared_layers[i].use_moe or self.shared_layers[i].use_moe_att]).mean()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Shared logic for encoders and decoders - apply embeddings and positional encoding
@@ -124,7 +124,7 @@ class ClassicTransformerBase(nn.Module):
         self.num_layers = len(layers) if layers else 0

     def moe_router_loss(self):
-        return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_layers) if self.layers[i].use_moe]).mean()
+        return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_layers) if self.layers[i].use_moe or self.layers[i].use_moe_att]).mean()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Shared logic for encoders and decoders - apply embeddings and positional encoding
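At the model level, `moe_router_loss` now collects the loss from any layer whose feed-forward or attention uses MoE and averages them with `torch.stack(...).mean()`. A small sketch of that aggregation over hypothetical per-layer losses and flags:

```python
import torch

# Hypothetical per-layer router losses and MoE flags
layer_losses = [torch.tensor(0.30), torch.tensor(0.10), torch.tensor(0.20)]
use_moe =      [True,               False,              False]
use_moe_att =  [False,              False,              True]

# Keep only layers where either the feed-forward or the attention is MoE-based
selected = [loss for loss, ff, att in zip(layer_losses, use_moe, use_moe_att) if ff or att]

model_router_loss = torch.stack(selected).mean()
print(model_router_loss)  # tensor(0.2500)
```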
{rxnn-0.1.48.dist-info → rxnn-0.1.50.dist-info}/RECORD
CHANGED
@@ -1,7 +1,7 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
-rxnn/experimental/models.py,sha256
+rxnn/experimental/attention.py,sha256=22Qb4jYN6QaqibTU8bwD8x2FaOKCxvWglM2eK9EuOlo,29468
+rxnn/experimental/models.py,sha256=-BQn7gWlSHLpkAQdthPW5L9ZNzIBqSJS9tkm2N88jgw,4711
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
@@ -18,14 +18,14 @@ rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,80
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=dC0UmC-_kjX8US6Sf0Fi5zw5kJ-P6orH3JDHeBB5gI8,15695
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
-rxnn/transformers/layers.py,sha256=
+rxnn/transformers/layers.py,sha256=OX8CsFY9A7uqH1SLwyexR_5BNlwheYrJHCGXjF8Q7HU,7186
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=
+rxnn/transformers/models.py,sha256=_w5C7xvjT4-BFeMfzi57BQ51_fgaYZ4UK0SqUDE5Ooo,7266
 rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.50.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.50.dist-info/METADATA,sha256=bIeDbrlcclSfD9oHf26i_sYepOTvTkpcwQMWpOm2jWc,16627
+rxnn-0.1.50.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.50.dist-info/RECORD,,
{rxnn-0.1.48.dist-info → rxnn-0.1.50.dist-info}/LICENSE
File without changes
{rxnn-0.1.48.dist-info → rxnn-0.1.50.dist-info}/WHEEL
File without changes