rxnn 0.1.14__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/transformers/moe.py
CHANGED
@@ -106,34 +106,44 @@ class MoeFeedForward(nn.Module):
         x = x.view(-1, self.embed_dim) # [batch*seq_len, embed_dim]

         # Get routing weights and indices
-        weights, indices = self.router(x) # [
+        weights, indices = self.router(x) # [B*T, top_k], [B*T, top_k]

         # Flatten indices and weights
-        batch_size = x.
-        top_k = indices.
-
-
-
-        #
-
-
-
-
-
-
-
-        #
-
+        batch_size = x.shape[0]
+        top_k = indices.shape[1]
+        indices_flat = indices.view(-1) # [B*T * top_k]
+
+        # Compute contributions for selected experts without materializing large tensors
+        # First Layer:
+        # Compute all expert contributions first (but this may still be memory-heavy)
+        # Alternative: Compute contributions for selected experts directly
+        # ... (see detailed steps below)
+
+        # Alternative approach using gather and batched operations
+        x_expanded = x.unsqueeze(1).repeat(1, top_k, 1).view(-1, self.embed_dim) # [B*T*top_k, D]
+
+        # Compute first layer contributions using gather
+        # indices_flat has shape [B*T*top_k]
+        # selected_w1 is self.w1[indices_flat], but we compute the product inline
+        h = torch.einsum(
+            'be, eih -> bh',
+            x_expanded,
+            self.w1[indices_flat]
+        ) + self.b1[indices_flat]
         h = self._activate(h)
         h = self.dropout(h)

-
-
-
-
-
-
+        # Second layer:
+        out = torch.einsum(
+            'bh, eho -> beo',
+            h,
+            self.w2[indices_flat]
+        ).squeeze(-1) + self.b2[indices_flat]
+
+        # Reshape and apply weights
+        out = out.view(batch_size, top_k, -1)
+        weights = weights.view(batch_size, top_k, 1)
+        out = (out * weights).sum(dim=1)

         return out.view(*orig_shape)

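The new forward pass avoids running every expert: it gathers only the weight slices of the experts the router selected for each token and contracts them with einsum. Below is a minimal, self-contained sketch of that gather-and-einsum pattern, not the package's code: the stacked parameter shapes, the ReLU stand-in for the module's activation, and the einsum subscripts are all assumptions for illustration (the subscripts shown in the extracted diff do not type-check against the commented shapes, so the sketch uses per-row labels that do).

import torch

# Assumed shapes, for illustration only (not taken from the package):
#   x:  [N, D]                flattened tokens, N = batch * seq_len
#   w1: [E, D, H], b1: [E, H] stacked per-expert first-layer parameters
#   w2: [E, H, D], b2: [E, D] stacked per-expert second-layer parameters
N, D, H, E, top_k = 8, 16, 32, 4, 2
x = torch.randn(N, D)
w1, b1 = torch.randn(E, D, H), torch.randn(E, H)
w2, b2 = torch.randn(E, H, D), torch.randn(E, D)
weights = torch.softmax(torch.randn(N, top_k), dim=-1)  # router weights
indices = torch.randint(0, E, (N, top_k))               # routed expert ids

indices_flat = indices.view(-1)                                   # [N*top_k]
x_expanded = x.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, D)  # [N*top_k, D]

# First layer: contract each token row with its own gathered expert matrix.
# 'nd,ndh->nh' pairs row n of x_expanded with slice n of w1[indices_flat].
h = torch.relu(torch.einsum('nd,ndh->nh', x_expanded, w1[indices_flat])
               + b1[indices_flat])

# Second layer: the same per-row contraction, back to the embedding dimension.
out = torch.einsum('nh,nhd->nd', h, w2[indices_flat]) + b2[indices_flat]

# Weight the top_k expert outputs by the router probabilities and sum.
out = out.view(N, top_k, D)
out = (out * weights.unsqueeze(-1)).sum(dim=1)  # [N, D]

Compared with computing all E expert outputs and then selecting, this materializes only [N*top_k, ...] intermediates, which is what the diff's "without materializing large tensors" comment is after; note that the gathered weight tensor w1[indices_flat] is still [N*top_k, D, H], so the saving shrinks as top_k approaches E.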
rxnn-0.1.14.dist-info/RECORD → rxnn-0.1.15.dist-info/RECORD
RENAMED
@@ -19,11 +19,11 @@ rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
 rxnn/transformers/layers.py,sha256=HhIiykmrBgdsV4AbMQXr9t0cSo4gSIeN0dPtc8mDyOo,5629
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=w-zB_8QB9-Fae-GkGgmVDNY-Ts_0gBeWcevpl9qzZVM,7169
-rxnn/transformers/moe.py,sha256=
+rxnn/transformers/moe.py,sha256=s2yeBsAg-JIqKp7tLlXPdLNar9FXZ14LgbHyXlUKk6o,6758
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.15.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.15.dist-info/METADATA,sha256=r3sjBGoGAsIcNqrNEC1tDuG6blEuNRVrQ_3fyy-yWJY,14629
+rxnn-0.1.15.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.15.dist-info/RECORD,,