rxnn 0.1.42__py3-none-any.whl → 0.1.43__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/experimental/attention.py +1 -29
- rxnn/transformers/attention.py +0 -7
- {rxnn-0.1.42.dist-info → rxnn-0.1.43.dist-info}/METADATA +1 -1
- {rxnn-0.1.42.dist-info → rxnn-0.1.43.dist-info}/RECORD +6 -6
- {rxnn-0.1.42.dist-info → rxnn-0.1.43.dist-info}/LICENSE +0 -0
- {rxnn-0.1.42.dist-info → rxnn-0.1.43.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -83,20 +83,16 @@ class GroupedMoeAttention(GroupedQueryAttention):
 
     def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
                      skip_query_processing: bool = False):
-        head_dim = d // self.num_heads
-
         # Process Query as in GQA
         q = self.q_proj(query).view(b, t, self.num_heads, -1).transpose(1, 2) if not skip_query_processing else query
 
         # Key/Value MoE routing
         B, S, D = key.shape
-        print('key/value type', key.dtype, value.dtype)
         key_flat = key.reshape(-1, D)
-        print('key flat type', key_flat.dtype)
         weights, indices = self.router(key_flat)  # (B*S, num_groups), (B*S, num_groups)
         weights = weights.view(B, S, self.num_groups, 1)
         indices = indices.view(B, S, self.num_groups)
-
+
         # Compute all experts' projections
         # Shape: (B*S, num_experts, head_dim)
         k_all = torch.einsum('bd,edh->beh', key_flat, self.wk)  # [B*S, num_experts, head_dim]
@@ -106,43 +102,23 @@ class GroupedMoeAttention(GroupedQueryAttention):
         k_all += self.bk
         v_all += self.bv
 
-        print('k all/v all before get all')
-        print(k_all.size(), k_all.dtype)
-        print(v_all.size(), v_all.dtype)
-
         # Get results for all heads
         k_all = k_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
         v_all = v_all.view(B, S, self.num_experts, -1)  # [B, S, num_experts, head_dim]
 
-        print('k all/v all get all')
-        print(k_all.size(), k_all.dtype)
-        print(v_all.size(), v_all.dtype)
-
         # Gather top-k experts using expanded indices
         expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))  # [B, S, num_groups, head_dim]
         selected_k = torch.gather(k_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
         selected_v = torch.gather(v_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
 
-        print('selected k/selected v')
-        print(selected_k.size(), selected_k.dtype)
-        print(selected_v.size(), selected_v.dtype)
-
         # Weighted
         weighted_k = (selected_k * weights).to(selected_k.device, dtype=selected_k.dtype)  # [B, S, num_groups, head_dim]
         weighted_v = (selected_v * weights).to(selected_k.device, dtype=selected_k.dtype)  # [B, S, num_groups, head_dim]
 
-        print('weighted')
-        print(weighted_k.size(), weighted_k.dtype)
-        print(weighted_v.size(), weighted_v.dtype)
-
         # Reshape to GQA format
         k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, num_groups, S, head_dim]
         v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, num_groups, S, head_dim]
 
-        print('out 1')
-        print(k.size(), k.dtype)
-        print(v.size(), v.dtype)
-
         if self.rel_embed:
             group_heads = self.num_heads // self.num_groups
 
@@ -152,10 +128,6 @@ class GroupedMoeAttention(GroupedQueryAttention):
             k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
             v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
 
-        print('out 2')
-        print(k.size(), k.dtype)
-        print(v.size(), v.dtype)
-
         return q, k, v
 
 
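The hunks above strip debug print statements (and an unused head_dim computation) out of GroupedMoeAttention._forward_qkv. The surviving logic routes each key/value token through a router, projects it with every expert at once via einsum, gathers the top-scoring experts, and blends them with the router weights before reshaping to the GQA layout. Below is a minimal standalone sketch of that routing pattern for keys only; the route_kv helper and the plain linear-plus-top-k router are illustrative assumptions, not the package's actual router module.

```python
import torch

def route_kv(key: torch.Tensor, wk: torch.Tensor, router: torch.nn.Linear, num_groups: int):
    # key: (B, S, D), wk: (num_experts, D, head_dim). Hypothetical sketch only.
    B, S, D = key.shape
    key_flat = key.reshape(-1, D)                               # (B*S, D)

    # Score every expert per token and keep the top num_groups of them.
    logits = router(key_flat)                                   # (B*S, num_experts)
    weights, indices = torch.topk(torch.softmax(logits, dim=-1), num_groups, dim=-1)
    weights = weights.view(B, S, num_groups, 1)
    indices = indices.view(B, S, num_groups)

    # Project the flat keys with all experts at once.
    k_all = torch.einsum('bd,edh->beh', key_flat, wk)           # (B*S, num_experts, head_dim)
    k_all = k_all.view(B, S, wk.size(0), -1)                    # (B, S, num_experts, head_dim)

    # Gather only the selected experts and blend them with the router weights.
    expanded = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
    selected_k = torch.gather(k_all, 2, expanded)               # (B, S, num_groups, head_dim)
    weighted_k = selected_k * weights                           # (B, S, num_groups, head_dim)

    # Reshape to the GQA layout expected downstream: (B, num_groups, S, head_dim).
    return weighted_k.permute(0, 2, 1, 3)

# Example with made-up sizes: 4 experts, 2 active groups, head_dim 16.
B, S, D, E, G, H = 2, 8, 32, 4, 2, 16
k = route_kv(torch.randn(B, S, D), torch.randn(E, D, H), torch.nn.Linear(D, E), G)
print(k.shape)  # torch.Size([2, 2, 8, 16])
```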
rxnn/transformers/attention.py
CHANGED
@@ -106,9 +106,6 @@ class MultiHeadAttention(nn.Module):
         # with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
         #     return self._torch_attention(q, k, v, b, t, d, mask=mask, enable_gqa=enable_gqa)
         from flash_attn import flash_attn_func
-        print(q.size(), q.dtype)
-        print(k.size(), k.dtype)
-        print(v.size(), v.dtype)
         attn_output = flash_attn_func(q, k, v, dropout_p=self.dropout.p if self.training else 0.0, causal=self.is_causal)
         return self._transpose_output(attn_output, b, t, d)
 
@@ -137,10 +134,6 @@ class MultiHeadAttention(nn.Module):
         return self._calculate_output(attn_weights, v, b, t, d)
 
     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
-        print('MHA forward')
-        print(query.size(), query.dtype)
-        print(key.size(), key.dtype)
-        print(value.size(), value.dtype)
         b, t, d = query.size()
         q, k, v = self._forward_qkv(query, key, value, b, t, d)
         if not self.rel_embed:
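Both hunks above only remove debug print statements around the flash-attention path and at the start of MultiHeadAttention.forward. For context, here is a minimal sketch of how the retained flash_attn_func call is typically invoked; the flash_attention wrapper and the concrete sizes are illustrative assumptions, and it presumes the flash-attn package plus a CUDA device, with q/k/v in fp16 or bf16 and (batch, seq_len, num_heads, head_dim) layout.

```python
import torch
from flash_attn import flash_attn_func  # requires the flash-attn package and a CUDA GPU

def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
    # q, k, v: (batch, seq_len, num_heads, head_dim), half precision, on GPU.
    # dropout_p and causal mirror how the diff passes self.dropout.p and self.is_causal.
    return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)

if torch.cuda.is_available():
    q = torch.randn(2, 128, 8, 64, device='cuda', dtype=torch.float16)
    k = torch.randn(2, 128, 8, 64, device='cuda', dtype=torch.float16)
    v = torch.randn(2, 128, 8, 64, device='cuda', dtype=torch.float16)
    out = flash_attention(q, k, v, dropout_p=0.0, causal=True)
    print(out.shape)  # torch.Size([2, 128, 8, 64])
```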
{rxnn-0.1.42.dist-info → rxnn-0.1.43.dist-info}/RECORD
CHANGED
@@ -1,6 +1,6 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
+rxnn/experimental/attention.py,sha256=tyeZ2xkZfFrlTWzqjKS13NnIRSLGe1-oR7J20hN6BgI,29602
 rxnn/experimental/models.py,sha256=IzUVc5s-cA__8jsG2mVvzUDmzPRcfBcI5btaOjnPYhA,4598
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,7 +16,7 @@ rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
 rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
 rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/transformers/attention.py,sha256=
+rxnn/transformers/attention.py,sha256=dC0UmC-_kjX8US6Sf0Fi5zw5kJ-P6orH3JDHeBB5gI8,15695
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
 rxnn/transformers/layers.py,sha256=n_jZTqEF_vLkF31AkB5XGErfm2sQFd9CRqJUHKRFkKI,6956
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.43.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.43.dist-info/METADATA,sha256=7iCT0KM5FyGKZwe9ZbxnULu9z-rwyko4Xbklok1ZsOs,16627
+rxnn-0.1.43.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.43.dist-info/RECORD,,
{rxnn-0.1.42.dist-info → rxnn-0.1.43.dist-info}/LICENSE
File without changes

{rxnn-0.1.42.dist-info → rxnn-0.1.43.dist-info}/WHEEL
File without changes