rxnn 0.1.14__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/transformers/moe.py CHANGED
@@ -106,34 +106,44 @@ class MoeFeedForward(nn.Module):
         x = x.view(-1, self.embed_dim) # [batch*seq_len, embed_dim]
 
         # Get routing weights and indices
-        weights, indices = self.router(x) # [batch*seq_len, top_k], [batch*seq_len, top_k]
+        weights, indices = self.router(x) # [B*T, top_k], [B*T, top_k]
 
         # Flatten indices and weights
-        batch_size = x.size(0)
-        top_k = indices.size(1)
-        indices = indices.view(-1) # [batch*seq_len * top_k]
-        weights = weights.view(-1, 1) # [batch*seq_len * top_k, 1]
-
-        # Select only the relevant experts for each token
-        selected_w1 = self.w1[indices] # [batch*seq_len * top_k, embed_dim, hidden_dim]
-        selected_b1 = self.b1[indices] # [batch*seq_len * top_k, hidden_dim]
-        selected_w2 = self.w2[indices] # [batch*seq_len * top_k, hidden_dim, embed_dim]
-        selected_b2 = self.b2[indices] # [batch*seq_len * top_k, embed_dim]
-
-        # Reshape x for batched computation
-        x_expanded = x.unsqueeze(1).repeat(1, top_k, 1).view(-1, self.embed_dim) # [batch*seq_len * top_k, embed_dim]
-
-        # Compute only the selected experts
-        h = torch.einsum('be, beh -> bh', x_expanded, selected_w1) + selected_b1
+        batch_size = x.shape[0]
+        top_k = indices.shape[1]
+        indices_flat = indices.view(-1) # [B*T * top_k]
+
+        # Compute contributions for selected experts without materializing large tensors
+        # First Layer:
+        # Compute all expert contributions first (but this may still be memory-heavy)
+        # Alternative: Compute contributions for selected experts directly
+        # ... (see detailed steps below)
+
+        # Alternative approach using gather and batched operations
+        x_expanded = x.unsqueeze(1).repeat(1, top_k, 1).view(-1, self.embed_dim) # [B*T*top_k, D]
+
+        # Compute first layer contributions using gather
+        # indices_flat has shape [B*T*top_k]
+        # selected_w1 is self.w1[indices_flat], but we compute the product inline
+        h = torch.einsum(
+            'be, eih -> bh',
+            x_expanded,
+            self.w1[indices_flat]
+        ) + self.b1[indices_flat]
         h = self._activate(h)
         h = self.dropout(h)
 
-        out = torch.einsum('bh, bhe -> be', h, selected_w2) + selected_b2
-
-        # Reshape back and apply weights
-        out = out.view(batch_size, top_k, -1) # [batch*seq_len, top_k, embed_dim]
-        weights = weights.view(batch_size, top_k, 1) # [batch*seq_len, top_k, 1]
-        out = (out * weights).sum(dim=1) # Weighted sum over top_k experts
+        # Second layer:
+        out = torch.einsum(
+            'bh, eho -> beo',
+            h,
+            self.w2[indices_flat]
+        ).squeeze(-1) + self.b2[indices_flat]
+
+        # Reshape and apply weights
+        out = out.view(batch_size, top_k, -1)
+        weights = weights.view(batch_size, top_k, 1)
+        out = (out * weights).sum(dim=1)
 
         return out.view(*orig_shape)
 
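For context on what the new moe.py code does: the change keeps the same sparse mixture-of-experts dispatch (each token is processed only by its top_k routed experts) but indexes the stacked expert parameters inline inside torch.einsum instead of first binding them to selected_w1/selected_b1/selected_w2/selected_b2 temporaries. The sketch below is a minimal standalone illustration of that gather-and-einsum pattern, not the released rxnn code; the function name moe_ff_sketch and the assumed parameter shapes (w1: [num_experts, embed_dim, hidden_dim], b1: [num_experts, hidden_dim], w2: [num_experts, hidden_dim, embed_dim], b2: [num_experts, embed_dim]) are illustrative, and the einsum subscripts here are chosen so each row of the expanded token batch is contracted with its own gathered expert.

import torch

# Hypothetical sketch of per-token expert dispatch with gathered weights.
# Shapes are assumptions, not taken from the rxnn source:
#   x:       [N, D]      flattened tokens, N = batch*seq_len
#   weights: [N, top_k]  routing weights from the router
#   indices: [N, top_k]  expert ids selected per token
#   w1/b1/w2/b2: stacked per-expert parameters as described above
def moe_ff_sketch(x, weights, indices, w1, b1, w2, b2):
    n_tokens, embed_dim = x.shape
    top_k = indices.shape[1]
    idx = indices.reshape(-1)                                   # [N*top_k]

    # Repeat each token once per selected expert: [N*top_k, D]
    x_rep = x.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, embed_dim)

    # First layer: gather the selected experts' weights and contract row-wise.
    h = torch.einsum('nd,ndh->nh', x_rep, w1[idx]) + b1[idx]    # [N*top_k, H]
    h = torch.nn.functional.gelu(h)                             # stand-in activation

    # Second layer back to the embedding dimension: [N*top_k, D]
    out = torch.einsum('nh,nhd->nd', h, w2[idx]) + b2[idx]

    # Weighted sum of the top_k expert outputs per token: [N, D]
    out = out.view(n_tokens, top_k, embed_dim)
    return (out * weights.unsqueeze(-1)).sum(dim=1)

The trade-off hinted at by the in-diff comments is still present in this pattern: w1[idx] and w2[idx] materialize gathered tensors of shape [N*top_k, D, H] and [N*top_k, H, D], so inline indexing removes the named temporaries but not the gathered copies themselves.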
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.1.14
+Version: 0.1.15
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -19,11 +19,11 @@ rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
 rxnn/transformers/layers.py,sha256=HhIiykmrBgdsV4AbMQXr9t0cSo4gSIeN0dPtc8mDyOo,5629
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=w-zB_8QB9-Fae-GkGgmVDNY-Ts_0gBeWcevpl9qzZVM,7169
-rxnn/transformers/moe.py,sha256=fFPTRcctCSc9OwHd0PhNb0nwHgNJY7dXfUtGreXtaho,6720
+rxnn/transformers/moe.py,sha256=s2yeBsAg-JIqKp7tLlXPdLNar9FXZ14LgbHyXlUKk6o,6758
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.14.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.1.14.dist-info/METADATA,sha256=YQDNMaHDrfVdOk44qEUczgLaNcrXApoqVmNX50yQDdM,14629
-rxnn-0.1.14.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
-rxnn-0.1.14.dist-info/RECORD,,
+rxnn-0.1.15.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.15.dist-info/METADATA,sha256=r3sjBGoGAsIcNqrNEC1tDuG6blEuNRVrQ_3fyy-yWJY,14629
+rxnn-0.1.15.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.15.dist-info/RECORD,,